feat: add the A_Memorix memory plugin
Introduces the A_Memorix plugin (v2.0.0), a lightweight long-term memory provider. Adds the plugin manifest and entry point (AMemorixPlugin) together with the full set of core capabilities: embedding (a hash-based EmbeddingAPIAdapter, EmbeddingManager, and presets), retrieval (dual-path retriever, PageRank, graph relation recall, a BM25 sparse index, and threshold/fusion configuration), the storage and metadata layers, plus a large set of utilities and migration/conversion scripts. Also updates .gitignore to allow /plugins/A_memorix. This change lays the groundwork for unified memory ingestion, retrieval, analysis, and maintenance in the host application.
.gitignore (vendored, 1 line added)
@@ -339,6 +339,7 @@ run_pet.bat
 /plugins/*
 !/plugins
+!/plugins/A_memorix
 !/plugins/hello_world_plugin
 !/plugins/emoji_manage_plugin
 !/plugins/take_picture_plugin
plugins/A_memorix/__init__.py (new file, 12 lines)
@@ -0,0 +1,12 @@
"""
A_Memorix - 轻量级知识库插件

完全独立的记忆增强系统,优化低资源环境下的知识存储与检索。
"""

__version__ = "2.0.0"
__author__ = "A_Dawn"

from .plugin import AMemorixPlugin

__all__ = ["AMemorixPlugin"]
plugins/A_memorix/_manifest.json (new file, 62 lines)
@@ -0,0 +1,62 @@
{
  "manifest_version": 1,
  "name": "A_Memorix",
  "version": "2.0.0",
  "description": "MaiBot SDK 长期记忆插件,负责统一检索、写入、画像与记忆维护。",
  "author": {
    "name": "A_Dawn"
  },
  "license": "AGPL-3.0",
  "repository_url": "https://github.com/A-Dawn/A_memorix/",
  "host_application": {
    "min_version": "1.0.0"
  },
  "keywords": [
    "memory",
    "knowledge",
    "retrieval",
    "profile",
    "episode"
  ],
  "categories": [
    "Memory",
    "Data"
  ],
  "plugin_info": {
    "is_built_in": false,
    "plugin_type": "memory_provider",
    "components": [
      {
        "type": "tool",
        "name": "search_memory",
        "description": "搜索长期记忆"
      },
      {
        "type": "tool",
        "name": "ingest_summary",
        "description": "写入聊天摘要"
      },
      {
        "type": "tool",
        "name": "ingest_text",
        "description": "写入普通长期记忆文本"
      },
      {
        "type": "tool",
        "name": "get_person_profile",
        "description": "查询人物画像"
      },
      {
        "type": "tool",
        "name": "maintain_memory",
        "description": "维护记忆关系"
      },
      {
        "type": "tool",
        "name": "memory_stats",
        "description": "查询记忆统计"
      }
    ]
  },
  "capabilities": []
}
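Example (not part of this commit): the declared tool components can be read straight from the manifest, which is a quick way to confirm what the host will see before it loads the plugin. The path assumes a checkout of this repository.

import json
from pathlib import Path

manifest = json.loads(Path("plugins/A_memorix/_manifest.json").read_text(encoding="utf-8"))
for component in manifest["plugin_info"]["components"]:
    # prints e.g. "tool search_memory - 搜索长期记忆"
    print(component["type"], component["name"], "-", component["description"])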
plugins/A_memorix/core/__init__.py (new file, 84 lines)
@@ -0,0 +1,84 @@
"""核心模块 - 存储、嵌入、检索引擎"""

# 存储模块(已实现)
from .storage import (
    VectorStore,
    GraphStore,
    MetadataStore,
    ImportStrategy,
    KnowledgeType,
    parse_import_strategy,
    resolve_stored_knowledge_type,
    detect_knowledge_type,
    select_import_strategy,
    should_extract_relations,
    get_type_display_name,
)

# 嵌入模块(使用主程序 API)
from .embedding import (
    EmbeddingAPIAdapter,
    create_embedding_api_adapter,
)

# 检索模块(已实现)
from .retrieval import (
    DualPathRetriever,
    RetrievalStrategy,
    RetrievalResult,
    DualPathRetrieverConfig,
    TemporalQueryOptions,
    FusionConfig,
    GraphRelationRecallConfig,
    RelationIntentConfig,
    PersonalizedPageRank,
    PageRankConfig,
    create_ppr_from_graph,
    DynamicThresholdFilter,
    ThresholdMethod,
    ThresholdConfig,
    SparseBM25Index,
    SparseBM25Config,
)
from .utils import (
    RelationWriteService,
    RelationWriteResult,
)

__all__ = [
    # Storage
    "VectorStore",
    "GraphStore",
    "MetadataStore",
    "ImportStrategy",
    "KnowledgeType",
    "parse_import_strategy",
    "resolve_stored_knowledge_type",
    "detect_knowledge_type",
    "select_import_strategy",
    "should_extract_relations",
    "get_type_display_name",
    # Embedding
    "EmbeddingAPIAdapter",
    "create_embedding_api_adapter",
    # Retrieval
    "DualPathRetriever",
    "RetrievalStrategy",
    "RetrievalResult",
    "DualPathRetrieverConfig",
    "TemporalQueryOptions",
    "FusionConfig",
    "GraphRelationRecallConfig",
    "RelationIntentConfig",
    "PersonalizedPageRank",
    "PageRankConfig",
    "create_ppr_from_graph",
    "DynamicThresholdFilter",
    "ThresholdMethod",
    "ThresholdConfig",
    "SparseBM25Index",
    "SparseBM25Config",
    "RelationWriteService",
    "RelationWriteResult",
]
plugins/A_memorix/core/embedding/__init__.py (new file, 18 lines)
@@ -0,0 +1,18 @@
"""嵌入模块 - 向量生成与量化"""

# 新的 API 适配器(主程序嵌入 API)
from .api_adapter import (
    EmbeddingAPIAdapter,
    create_embedding_api_adapter,
)

from ..utils.quantization import QuantizationType

__all__ = [
    # 新的 API 适配器(推荐使用)
    "EmbeddingAPIAdapter",
    "create_embedding_api_adapter",
    # 量化
    "QuantizationType",
]
plugins/A_memorix/core/embedding/api_adapter.py (new file, 174 lines)
@@ -0,0 +1,174 @@
"""
Hash-based embedding adapter used by the SDK runtime.

The plugin runtime cannot import MaiBot host embedding internals from ``src.chat``
or ``src.llm_models``. This adapter keeps A_Memorix self-contained and stable in
Runner by generating deterministic dense vectors locally.
"""

from __future__ import annotations

import hashlib
import re
import time
from typing import List, Optional, Union

import numpy as np

from src.common.logger import get_logger


logger = get_logger("A_Memorix.EmbeddingAPIAdapter")

_TOKEN_PATTERN = re.compile(r"[A-Za-z0-9_\u4e00-\u9fff]{1,}")


class EmbeddingAPIAdapter:
    """Deterministic local embedding adapter."""

    def __init__(
        self,
        batch_size: int = 32,
        max_concurrent: int = 5,
        default_dimension: int = 256,
        enable_cache: bool = False,
        model_name: str = "hash-v1",
        retry_config: Optional[dict] = None,
    ) -> None:
        self.batch_size = max(1, int(batch_size))
        self.max_concurrent = max(1, int(max_concurrent))
        self.default_dimension = max(32, int(default_dimension))
        self.enable_cache = bool(enable_cache)
        self.model_name = str(model_name or "hash-v1")
        self.retry_config = retry_config or {}

        self._dimension: Optional[int] = None
        self._dimension_detected = False
        self._total_encoded = 0
        self._total_errors = 0
        self._total_time = 0.0

        logger.info(
            "EmbeddingAPIAdapter 初始化: model=%s, batch_size=%s, dimension=%s",
            self.model_name,
            self.batch_size,
            self.default_dimension,
        )

    async def _detect_dimension(self) -> int:
        if self._dimension_detected and self._dimension is not None:
            return self._dimension
        self._dimension = self.default_dimension
        self._dimension_detected = True
        return self._dimension

    @staticmethod
    def _tokenize(text: str) -> List[str]:
        clean = str(text or "").strip().lower()
        if not clean:
            return []
        return _TOKEN_PATTERN.findall(clean)

    @staticmethod
    def _feature_weight(token: str) -> float:
        digest = hashlib.sha256(token.encode("utf-8")).digest()
        return 1.0 + (digest[10] / 255.0) * 0.5

    def _encode_single(self, text: str, dimension: int) -> np.ndarray:
        vector = np.zeros(dimension, dtype=np.float32)
        content = str(text or "").strip()
        tokens = self._tokenize(content)
        if not tokens and content:
            tokens = [content.lower()]
        if not tokens:
            vector[0] = 1.0
            return vector

        for token in tokens:
            digest = hashlib.sha256(token.encode("utf-8")).digest()
            bucket = int.from_bytes(digest[:8], byteorder="big", signed=False) % dimension
            sign = 1.0 if digest[8] % 2 == 0 else -1.0
            vector[bucket] += sign * self._feature_weight(token)

            second_bucket = int.from_bytes(digest[12:20], byteorder="big", signed=False) % dimension
            if second_bucket != bucket:
                vector[second_bucket] += (sign * 0.35)

        norm = float(np.linalg.norm(vector))
        if norm > 1e-8:
            vector /= norm
        else:
            vector[0] = 1.0
        return vector

    async def encode(
        self,
        texts: Union[str, List[str]],
        batch_size: Optional[int] = None,
        show_progress: bool = False,
        normalize: bool = True,
        dimensions: Optional[int] = None,
    ) -> np.ndarray:
        _ = batch_size
        _ = show_progress
        _ = normalize

        started_at = time.time()
        target_dimension = max(32, int(dimensions or await self._detect_dimension()))

        if isinstance(texts, str):
            single_input = True
            normalized_texts = [texts]
        else:
            single_input = False
            normalized_texts = list(texts or [])

        if not normalized_texts:
            empty = np.zeros((0, target_dimension), dtype=np.float32)
            return empty[0] if single_input else empty

        try:
            matrix = np.vstack([self._encode_single(item, target_dimension) for item in normalized_texts])
            self._total_encoded += len(normalized_texts)
            self._total_time += time.time() - started_at
        except Exception:
            self._total_errors += 1
            raise

        return matrix[0] if single_input else matrix

    def get_statistics(self) -> dict:
        avg_time = self._total_time / self._total_encoded if self._total_encoded else 0.0
        return {
            "model_name": self.model_name,
            "dimension": self._dimension or self.default_dimension,
            "total_encoded": self._total_encoded,
            "total_errors": self._total_errors,
            "total_time": self._total_time,
            "avg_time_per_text": avg_time,
        }

    def __repr__(self) -> str:
        return (
            f"EmbeddingAPIAdapter(model_name={self.model_name}, "
            f"dimension={self._dimension or self.default_dimension}, "
            f"total_encoded={self._total_encoded})"
        )


def create_embedding_api_adapter(
    batch_size: int = 32,
    max_concurrent: int = 5,
    default_dimension: int = 256,
    enable_cache: bool = False,
    model_name: str = "hash-v1",
    retry_config: Optional[dict] = None,
) -> EmbeddingAPIAdapter:
    return EmbeddingAPIAdapter(
        batch_size=batch_size,
        max_concurrent=max_concurrent,
        default_dimension=default_dimension,
        enable_cache=enable_cache,
        model_name=model_name,
        retry_config=retry_config,
    )
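A minimal usage sketch (not part of the commit) for the hash-based adapter above. It assumes the plugin package is importable as plugins.A_memorix inside the host repository, so that src.common.logger resolves; the texts are arbitrary.

import asyncio

import numpy as np

from plugins.A_memorix.core.embedding.api_adapter import create_embedding_api_adapter


async def main() -> None:
    adapter = create_embedding_api_adapter(default_dimension=256)

    # The adapter is deterministic: the same text always hashes to the same vector.
    v1 = await adapter.encode("MaiBot 长期记忆")
    v2 = await adapter.encode("MaiBot 长期记忆")
    assert np.allclose(v1, v2) and v1.shape == (256,)

    # A list of texts yields an (N, D) matrix whose rows are L2-normalized.
    matrix = await adapter.encode(["记忆检索", "人物画像"])
    print(matrix.shape, np.linalg.norm(matrix, axis=1))
    print(adapter.get_statistics())


asyncio.run(main())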
plugins/A_memorix/core/embedding/manager.py (new file, 510 lines)
@@ -0,0 +1,510 @@
|
||||
"""
|
||||
嵌入管理器
|
||||
|
||||
负责嵌入模型的加载、缓存和批量生成。
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import pickle
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union, List, Dict, Any, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
from sentence_transformers import SentenceTransformer
|
||||
HAS_SENTENCE_TRANSFORMERS = True
|
||||
except ImportError:
|
||||
HAS_SENTENCE_TRANSFORMERS = False
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from .presets import (
|
||||
EmbeddingModelConfig,
|
||||
get_custom_config,
|
||||
validate_config_compatibility,
|
||||
are_models_compatible,
|
||||
)
|
||||
from ..utils.quantization import QuantizationType
|
||||
|
||||
logger = get_logger("A_Memorix.EmbeddingManager")
|
||||
|
||||
|
||||
class EmbeddingManager:
|
||||
"""
|
||||
嵌入管理器
|
||||
|
||||
功能:
|
||||
- 模型加载与缓存
|
||||
- 批量生成嵌入
|
||||
- 多线程/多进程支持
|
||||
- 模型一致性检查
|
||||
- 智能分批
|
||||
|
||||
参数:
|
||||
config: 模型配置
|
||||
cache_dir: 缓存目录
|
||||
enable_cache: 是否启用缓存
|
||||
num_workers: 工作线程数
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: EmbeddingModelConfig,
|
||||
cache_dir: Optional[Union[str, Path]] = None,
|
||||
enable_cache: bool = True,
|
||||
num_workers: int = 1,
|
||||
):
|
||||
"""
|
||||
初始化嵌入管理器
|
||||
|
||||
Args:
|
||||
config: 模型配置
|
||||
cache_dir: 缓存目录
|
||||
enable_cache: 是否启用缓存
|
||||
num_workers: 工作线程数
|
||||
"""
|
||||
if not HAS_SENTENCE_TRANSFORMERS:
|
||||
raise ImportError(
|
||||
"sentence-transformers 未安装,请安装: "
|
||||
"pip install sentence-transformers"
|
||||
)
|
||||
|
||||
self.config = config
|
||||
self.cache_dir = Path(cache_dir) if cache_dir else None
|
||||
self.enable_cache = enable_cache
|
||||
self.num_workers = max(1, num_workers)
|
||||
|
||||
# 模型实例
|
||||
self._model: Optional[SentenceTransformer] = None
|
||||
self._model_lock = threading.Lock()
|
||||
|
||||
# 缓存
|
||||
self._embedding_cache: Dict[str, np.ndarray] = {}
|
||||
self._cache_lock = threading.Lock()
|
||||
|
||||
# 统计
|
||||
self._total_encoded = 0
|
||||
self._cache_hits = 0
|
||||
self._cache_misses = 0
|
||||
|
||||
logger.info(
|
||||
f"EmbeddingManager 初始化: model={config.model_name}, "
|
||||
f"dim={config.dimension}, workers={num_workers}"
|
||||
)
|
||||
|
||||
def load_model(self) -> None:
|
||||
"""加载模型(懒加载)"""
|
||||
if self._model is not None:
|
||||
return
|
||||
|
||||
with self._model_lock:
|
||||
# 双重检查
|
||||
if self._model is not None:
|
||||
return
|
||||
|
||||
logger.info(f"正在加载模型: {self.config.model_name}")
|
||||
|
||||
try:
|
||||
# 构建模型参数
|
||||
model_kwargs = {}
|
||||
if self.config.cache_dir:
|
||||
model_kwargs["cache_folder"] = self.config.cache_dir
|
||||
|
||||
# 加载模型
|
||||
self._model = SentenceTransformer(
|
||||
self.config.model_path,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
logger.info(f"模型加载成功: {self.config.model_name}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"模型加载失败: {e}")
|
||||
raise
|
||||
|
||||
def encode(
|
||||
self,
|
||||
texts: Union[str, List[str]],
|
||||
batch_size: Optional[int] = None,
|
||||
show_progress: bool = False,
|
||||
normalize: bool = True,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
生成文本嵌入
|
||||
|
||||
Args:
|
||||
texts: 文本或文本列表
|
||||
batch_size: 批次大小(默认使用配置值)
|
||||
show_progress: 是否显示进度条
|
||||
normalize: 是否归一化
|
||||
|
||||
Returns:
|
||||
嵌入向量 (N x D)
|
||||
"""
|
||||
# 确保模型已加载
|
||||
self.load_model()
|
||||
|
||||
# 标准化输入
|
||||
if isinstance(texts, str):
|
||||
texts = [texts]
|
||||
single_input = True
|
||||
else:
|
||||
single_input = False
|
||||
|
||||
if not texts:
|
||||
return np.zeros((0, self.config.dimension), dtype=np.float32)
|
||||
|
||||
# 使用配置的批次大小
|
||||
if batch_size is None:
|
||||
batch_size = self.config.batch_size
|
||||
|
||||
# 生成嵌入
|
||||
try:
|
||||
embeddings = self._model.encode(
|
||||
texts,
|
||||
batch_size=batch_size,
|
||||
show_progress_bar=show_progress,
|
||||
normalize_embeddings=normalize and self.config.normalization,
|
||||
convert_to_numpy=True,
|
||||
)
|
||||
|
||||
# 确保是2D数组
|
||||
if embeddings.ndim == 1:
|
||||
embeddings = embeddings.reshape(1, -1)
|
||||
|
||||
self._total_encoded += len(texts)
|
||||
|
||||
# 如果是单个输入,返回1D数组
|
||||
if single_input:
|
||||
return embeddings[0]
|
||||
|
||||
return embeddings
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"生成嵌入失败: {e}")
|
||||
raise
|
||||
|
||||
def encode_batch(
|
||||
self,
|
||||
texts: List[str],
|
||||
batch_size: Optional[int] = None,
|
||||
num_workers: Optional[int] = None,
|
||||
show_progress: bool = False,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
批量生成嵌入(多线程优化)
|
||||
|
||||
Args:
|
||||
texts: 文本列表
|
||||
batch_size: 批次大小
|
||||
num_workers: 工作线程数(默认使用初始化时的值)
|
||||
show_progress: 是否显示进度条
|
||||
|
||||
Returns:
|
||||
嵌入向量 (N x D)
|
||||
"""
|
||||
if not texts:
|
||||
return np.zeros((0, self.config.dimension), dtype=np.float32)
|
||||
|
||||
# 单线程模式
|
||||
num_workers = num_workers if num_workers is not None else self.num_workers
|
||||
if num_workers == 1:
|
||||
return self.encode(texts, batch_size=batch_size, show_progress=show_progress)
|
||||
|
||||
# 多线程模式
|
||||
logger.info(f"使用 {num_workers} 个线程生成 {len(texts)} 个嵌入")
|
||||
|
||||
# 分批
|
||||
batch_size = batch_size or self.config.batch_size
|
||||
batches = [
|
||||
texts[i:i + batch_size]
|
||||
for i in range(0, len(texts), batch_size)
|
||||
]
|
||||
|
||||
# 多线程生成
|
||||
all_embeddings = []
|
||||
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
||||
# 提交任务
|
||||
future_to_batch = {
|
||||
executor.submit(
|
||||
self.encode,
|
||||
batch,
|
||||
batch_size,
|
||||
False, # 不显示进度条(多线程时会混乱)
|
||||
): i
|
||||
for i, batch in enumerate(batches)
|
||||
}
|
||||
|
||||
# 收集结果
|
||||
for future in as_completed(future_to_batch):
|
||||
batch_idx = future_to_batch[future]
|
||||
try:
|
||||
embeddings = future.result()
|
||||
all_embeddings.append((batch_idx, embeddings))
|
||||
except Exception as e:
|
||||
logger.error(f"批次 {batch_idx} 生成嵌入失败: {e}")
|
||||
raise
|
||||
|
||||
# 按顺序合并
|
||||
all_embeddings.sort(key=lambda x: x[0])
|
||||
final_embeddings = np.concatenate([emb for _, emb in all_embeddings], axis=0)
|
||||
|
||||
return final_embeddings
|
||||
|
||||
def encode_with_cache(
|
||||
self,
|
||||
texts: List[str],
|
||||
batch_size: Optional[int] = None,
|
||||
show_progress: bool = False,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
生成嵌入(带缓存)
|
||||
|
||||
Args:
|
||||
texts: 文本列表
|
||||
batch_size: 批次大小
|
||||
show_progress: 是否显示进度条
|
||||
|
||||
Returns:
|
||||
嵌入向量 (N x D)
|
||||
"""
|
||||
if not self.enable_cache:
|
||||
return self.encode(texts, batch_size, show_progress)
|
||||
|
||||
# 分离缓存命中和未命中的文本
|
||||
cached_embeddings = []
|
||||
uncached_texts = []
|
||||
uncached_indices = []
|
||||
|
||||
for i, text in enumerate(texts):
|
||||
cache_key = self._get_cache_key(text)
|
||||
|
||||
with self._cache_lock:
|
||||
if cache_key in self._embedding_cache:
|
||||
cached_embeddings.append((i, self._embedding_cache[cache_key]))
|
||||
self._cache_hits += 1
|
||||
else:
|
||||
uncached_texts.append(text)
|
||||
uncached_indices.append(i)
|
||||
self._cache_misses += 1
|
||||
|
||||
# 生成未缓存的嵌入
|
||||
if uncached_texts:
|
||||
new_embeddings = self.encode(
|
||||
uncached_texts,
|
||||
batch_size,
|
||||
show_progress,
|
||||
)
|
||||
|
||||
# 更新缓存
|
||||
with self._cache_lock:
|
||||
for text, embedding in zip(uncached_texts, new_embeddings):
|
||||
cache_key = self._get_cache_key(text)
|
||||
self._embedding_cache[cache_key] = embedding.copy()
|
||||
|
||||
# 合并结果
|
||||
for idx, embedding in zip(uncached_indices, new_embeddings):
|
||||
cached_embeddings.append((idx, embedding))
|
||||
|
||||
# 按原始顺序排序
|
||||
cached_embeddings.sort(key=lambda x: x[0])
|
||||
final_embeddings = np.array([emb for _, emb in cached_embeddings])
|
||||
|
||||
return final_embeddings
|
||||
|
||||
def save_cache(self, cache_path: Optional[Union[str, Path]] = None) -> None:
|
||||
"""
|
||||
保存缓存到磁盘
|
||||
|
||||
Args:
|
||||
cache_path: 缓存文件路径(默认使用cache_dir/embeddings_cache.pkl)
|
||||
"""
|
||||
if cache_path is None:
|
||||
if self.cache_dir is None:
|
||||
raise ValueError("未指定缓存目录")
|
||||
cache_path = self.cache_dir / "embeddings_cache.pkl"
|
||||
|
||||
cache_path = Path(cache_path)
|
||||
cache_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with self._cache_lock:
|
||||
with open(cache_path, "wb") as f:
|
||||
pickle.dump(self._embedding_cache, f)
|
||||
|
||||
logger.info(f"缓存已保存: {cache_path} ({len(self._embedding_cache)} 条)")
|
||||
|
||||
def load_cache(self, cache_path: Optional[Union[str, Path]] = None) -> None:
|
||||
"""
|
||||
从磁盘加载缓存
|
||||
|
||||
Args:
|
||||
cache_path: 缓存文件路径(默认使用cache_dir/embeddings_cache.pkl)
|
||||
"""
|
||||
if cache_path is None:
|
||||
if self.cache_dir is None:
|
||||
raise ValueError("未指定缓存目录")
|
||||
cache_path = self.cache_dir / "embeddings_cache.pkl"
|
||||
|
||||
cache_path = Path(cache_path)
|
||||
if not cache_path.exists():
|
||||
logger.warning(f"缓存文件不存在: {cache_path}")
|
||||
return
|
||||
|
||||
with self._cache_lock:
|
||||
with open(cache_path, "rb") as f:
|
||||
self._embedding_cache = pickle.load(f)
|
||||
|
||||
logger.info(f"缓存已加载: {cache_path} ({len(self._embedding_cache)} 条)")
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
"""清空缓存"""
|
||||
with self._cache_lock:
|
||||
count = len(self._embedding_cache)
|
||||
self._embedding_cache.clear()
|
||||
logger.info(f"已清空缓存: {count} 条")
|
||||
|
||||
def check_model_consistency(
|
||||
self,
|
||||
stored_embeddings: np.ndarray,
|
||||
sample_texts: Optional[List[str]] = None,
|
||||
) -> Tuple[bool, str]:
|
||||
"""
|
||||
检查模型一致性
|
||||
|
||||
Args:
|
||||
stored_embeddings: 存储的嵌入向量
|
||||
sample_texts: 样本文本(用于重新生成对比)
|
||||
|
||||
Returns:
|
||||
(是否一致, 详细信息)
|
||||
"""
|
||||
# 检查维度
|
||||
if stored_embeddings.shape[1] != self.config.dimension:
|
||||
return False, f"维度不匹配: 期望 {self.config.dimension}, 实际 {stored_embeddings.shape[1]}"
|
||||
|
||||
# 如果提供了样本文本,重新生成并比较
|
||||
if sample_texts:
|
||||
try:
|
||||
new_embeddings = self.encode(sample_texts[:5]) # 只比较前5个
|
||||
|
||||
# 计算相似度
|
||||
similarities = np.dot(
|
||||
stored_embeddings[:5],
|
||||
new_embeddings.T,
|
||||
).diagonal()
|
||||
|
||||
# 检查相似度
|
||||
if np.mean(similarities) < 0.95:
|
||||
return False, f"模型可能已更改,平均相似度: {np.mean(similarities):.3f}"
|
||||
|
||||
return True, f"模型一致,平均相似度: {np.mean(similarities):.3f}"
|
||||
|
||||
except Exception as e:
|
||||
return False, f"一致性检查失败: {e}"
|
||||
|
||||
return True, "维度匹配"
|
||||
|
||||
def get_model_info(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取模型信息
|
||||
|
||||
Returns:
|
||||
模型信息字典
|
||||
"""
|
||||
return {
|
||||
"model_name": self.config.model_name,
|
||||
"dimension": self.config.dimension,
|
||||
"max_seq_length": self.config.max_seq_length,
|
||||
"batch_size": self.config.batch_size,
|
||||
"normalization": self.config.normalization,
|
||||
"pooling": self.config.pooling,
|
||||
"model_loaded": self._model is not None,
|
||||
"cache_enabled": self.enable_cache,
|
||||
"cache_size": len(self._embedding_cache),
|
||||
"total_encoded": self._total_encoded,
|
||||
"cache_hits": self._cache_hits,
|
||||
"cache_misses": self._cache_misses,
|
||||
}
|
||||
|
||||
def get_embedding_dimension(self) -> int:
|
||||
"""获取嵌入维度"""
|
||||
return self.config.dimension
|
||||
|
||||
def _get_cache_key(self, text: str) -> str:
|
||||
"""
|
||||
生成缓存键
|
||||
|
||||
Args:
|
||||
text: 文本内容
|
||||
|
||||
Returns:
|
||||
缓存键(SHA256哈希)
|
||||
"""
|
||||
return hashlib.sha256(text.encode("utf-8")).hexdigest()
|
||||
|
||||
@property
|
||||
def is_model_loaded(self) -> bool:
|
||||
"""模型是否已加载"""
|
||||
return self._model is not None
|
||||
|
||||
@property
|
||||
def cache_hit_rate(self) -> float:
|
||||
"""缓存命中率"""
|
||||
total = self._cache_hits + self._cache_misses
|
||||
if total == 0:
|
||||
return 0.0
|
||||
return self._cache_hits / total
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"EmbeddingManager(model={self.config.model_name}, "
|
||||
f"dim={self.config.dimension}, "
|
||||
f"loaded={self.is_model_loaded}, "
|
||||
f"cache={len(self._embedding_cache)})"
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
def create_embedding_manager_from_config(
|
||||
model_name: str,
|
||||
model_path: str,
|
||||
dimension: int,
|
||||
cache_dir: Optional[Union[str, Path]] = None,
|
||||
enable_cache: bool = True,
|
||||
num_workers: int = 1,
|
||||
**config_kwargs,
|
||||
) -> EmbeddingManager:
|
||||
"""
|
||||
从自定义配置创建嵌入管理器
|
||||
|
||||
Args:
|
||||
model_name: 模型名称
|
||||
model_path: HuggingFace模型路径
|
||||
dimension: 输出维度
|
||||
cache_dir: 缓存目录
|
||||
enable_cache: 是否启用缓存
|
||||
num_workers: 工作线程数
|
||||
**config_kwargs: 其他配置参数
|
||||
|
||||
Returns:
|
||||
嵌入管理器实例
|
||||
"""
|
||||
# 创建自定义配置
|
||||
config = get_custom_config(
|
||||
model_name=model_name,
|
||||
model_path=model_path,
|
||||
dimension=dimension,
|
||||
cache_dir=cache_dir,
|
||||
**config_kwargs,
|
||||
)
|
||||
|
||||
# 创建管理器
|
||||
return EmbeddingManager(
|
||||
config=config,
|
||||
cache_dir=cache_dir,
|
||||
enable_cache=enable_cache,
|
||||
num_workers=num_workers,
|
||||
)
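A usage sketch (not part of the commit) for the EmbeddingManager defined above. It needs sentence-transformers installed; the model id, dimension, and directories below are illustrative placeholders only.

from plugins.A_memorix.core.embedding.manager import create_embedding_manager_from_config

manager = create_embedding_manager_from_config(
    model_name="minilm-demo",  # descriptive label (placeholder)
    model_path="sentence-transformers/all-MiniLM-L6-v2",  # any SentenceTransformer id
    dimension=384,
    cache_dir="./data/embedding_cache",
    enable_cache=True,
    num_workers=2,
)

# The model is loaded lazily on the first encode; encode_with_cache() reuses
# previously computed vectors, and the in-memory cache can be persisted to disk.
vectors = manager.encode_with_cache(["长期记忆", "人物画像"])
print(vectors.shape, f"hit_rate={manager.cache_hit_rate:.2f}")
manager.save_cache()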
plugins/A_memorix/core/embedding/presets.py (new file, 72 lines)
@@ -0,0 +1,72 @@
"""
嵌入模型配置模块
"""

from dataclasses import dataclass
from typing import Optional, Dict, Any, Union
from pathlib import Path


@dataclass
class EmbeddingModelConfig:
    """
    嵌入模型配置

    属性:
        model_name: 模型描述名称
        model_path: 实际加载路径(Local or HF)
        dimension: 嵌入向量维度
        max_seq_length: 最大序列长度
        batch_size: 编码批次大小
        model_size_mb: 估计显存占用
        description: 模型说明
        normalization: 是否自动归一化
        pooling: 池化策略 (mean, cls, max)
        cache_dir: 模型缓存目录
    """

    model_name: str
    model_path: str
    dimension: int
    max_seq_length: int = 512
    batch_size: int = 32
    model_size_mb: int = 100
    description: str = ""
    normalization: bool = True
    pooling: str = "mean"
    cache_dir: Optional[Union[str, Path]] = None


def validate_config_compatibility(
    config1: EmbeddingModelConfig, config2: EmbeddingModelConfig
) -> bool:
    """检查两个配置是否兼容(主要看维度)"""
    return config1.dimension == config2.dimension


def are_models_compatible(
    config1: EmbeddingModelConfig, config2: EmbeddingModelConfig
) -> bool:
    """检查模型是否完全相同(用于热切换判断)"""
    return (
        config1.model_path == config2.model_path
        and config1.dimension == config2.dimension
        and config1.pooling == config2.pooling
    )


def get_custom_config(
    model_name: str,
    model_path: str,
    dimension: int,
    cache_dir: Optional[Union[str, Path]] = None,
    **kwargs,
) -> EmbeddingModelConfig:
    """创建自定义模型配置"""
    return EmbeddingModelConfig(
        model_name=model_name,
        model_path=model_path,
        dimension=dimension,
        cache_dir=cache_dir,
        **kwargs,
    )
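A short sketch (not part of the commit) of how the preset helpers compose; the model ids and dimensions are illustrative.

from plugins.A_memorix.core.embedding.presets import (
    are_models_compatible,
    get_custom_config,
    validate_config_compatibility,
)

remote_cfg = get_custom_config("bge-small", "BAAI/bge-small-zh-v1.5", dimension=512)
local_cfg = get_custom_config("bge-small-local", "./models/bge-small-zh", dimension=512)

# Same dimension: stored vectors stay usable; different path: not a hot-swap candidate.
assert validate_config_compatibility(remote_cfg, local_cfg) is True
assert are_models_compatible(remote_cfg, local_cfg) is False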
plugins/A_memorix/core/retrieval/__init__.py (new file, 54 lines)
@@ -0,0 +1,54 @@
"""检索模块 - 双路检索与排序"""

from .dual_path import (
    DualPathRetriever,
    RetrievalStrategy,
    RetrievalResult,
    DualPathRetrieverConfig,
    TemporalQueryOptions,
    FusionConfig,
    RelationIntentConfig,
)
from .pagerank import (
    PersonalizedPageRank,
    PageRankConfig,
    create_ppr_from_graph,
)
from .threshold import (
    DynamicThresholdFilter,
    ThresholdMethod,
    ThresholdConfig,
)
from .sparse_bm25 import (
    SparseBM25Index,
    SparseBM25Config,
)
from .graph_relation_recall import (
    GraphRelationRecallConfig,
    GraphRelationRecallService,
)

__all__ = [
    # DualPathRetriever
    "DualPathRetriever",
    "RetrievalStrategy",
    "RetrievalResult",
    "DualPathRetrieverConfig",
    "TemporalQueryOptions",
    "FusionConfig",
    "RelationIntentConfig",
    # PersonalizedPageRank
    "PersonalizedPageRank",
    "PageRankConfig",
    "create_ppr_from_graph",
    # DynamicThresholdFilter
    "DynamicThresholdFilter",
    "ThresholdMethod",
    "ThresholdConfig",
    # Sparse BM25
    "SparseBM25Index",
    "SparseBM25Config",
    # Graph relation recall
    "GraphRelationRecallConfig",
    "GraphRelationRecallService",
]
plugins/A_memorix/core/retrieval/dual_path.py (new file, 1796 lines)
(File diff suppressed because it is too large.)
plugins/A_memorix/core/retrieval/graph_relation_recall.py (new file, 272 lines)
@@ -0,0 +1,272 @@
|
||||
"""Graph-assisted relation candidate recall for relation-oriented queries."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Sequence, Set
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("A_Memorix.GraphRelationRecall")
|
||||
|
||||
|
||||
@dataclass
|
||||
class GraphRelationRecallConfig:
|
||||
"""Configuration for controlled graph relation recall."""
|
||||
|
||||
enabled: bool = True
|
||||
candidate_k: int = 24
|
||||
max_hop: int = 1
|
||||
allow_two_hop_pair: bool = True
|
||||
max_paths: int = 4
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.enabled = bool(self.enabled)
|
||||
self.candidate_k = max(1, int(self.candidate_k))
|
||||
self.max_hop = max(1, int(self.max_hop))
|
||||
self.allow_two_hop_pair = bool(self.allow_two_hop_pair)
|
||||
self.max_paths = max(1, int(self.max_paths))
|
||||
|
||||
|
||||
@dataclass
|
||||
class GraphRelationCandidate:
|
||||
"""A graph-derived relation candidate before retriever-side fusion."""
|
||||
|
||||
hash_value: str
|
||||
subject: str
|
||||
predicate: str
|
||||
object: str
|
||||
confidence: float
|
||||
graph_seed_entities: List[str]
|
||||
graph_hops: int
|
||||
graph_candidate_type: str
|
||||
supporting_paragraph_count: int
|
||||
|
||||
def to_payload(self) -> Dict[str, Any]:
|
||||
content = f"{self.subject} {self.predicate} {self.object}"
|
||||
return {
|
||||
"hash": self.hash_value,
|
||||
"content": content,
|
||||
"subject": self.subject,
|
||||
"predicate": self.predicate,
|
||||
"object": self.object,
|
||||
"confidence": self.confidence,
|
||||
"graph_seed_entities": list(self.graph_seed_entities),
|
||||
"graph_hops": int(self.graph_hops),
|
||||
"graph_candidate_type": self.graph_candidate_type,
|
||||
"supporting_paragraph_count": int(self.supporting_paragraph_count),
|
||||
}
|
||||
|
||||
|
||||
class GraphRelationRecallService:
|
||||
"""Collect relation candidates from the entity graph in a controlled way."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
graph_store: Any,
|
||||
metadata_store: Any,
|
||||
config: Optional[GraphRelationRecallConfig] = None,
|
||||
) -> None:
|
||||
self.graph_store = graph_store
|
||||
self.metadata_store = metadata_store
|
||||
self.config = config or GraphRelationRecallConfig()
|
||||
|
||||
def recall(
|
||||
self,
|
||||
*,
|
||||
seed_entities: Sequence[str],
|
||||
) -> List[GraphRelationCandidate]:
|
||||
if not self.config.enabled:
|
||||
return []
|
||||
if self.graph_store is None or self.metadata_store is None:
|
||||
return []
|
||||
|
||||
seeds = self._normalize_seed_entities(seed_entities)
|
||||
if not seeds:
|
||||
return []
|
||||
|
||||
seen_hashes: Set[str] = set()
|
||||
candidates: List[GraphRelationCandidate] = []
|
||||
|
||||
if len(seeds) >= 2:
|
||||
self._collect_direct_pair_candidates(
|
||||
seed_a=seeds[0],
|
||||
seed_b=seeds[1],
|
||||
seen_hashes=seen_hashes,
|
||||
out=candidates,
|
||||
)
|
||||
if (
|
||||
len(candidates) < 3
|
||||
and self.config.allow_two_hop_pair
|
||||
and len(candidates) < self.config.candidate_k
|
||||
):
|
||||
self._collect_two_hop_pair_candidates(
|
||||
seed_a=seeds[0],
|
||||
seed_b=seeds[1],
|
||||
seen_hashes=seen_hashes,
|
||||
out=candidates,
|
||||
)
|
||||
else:
|
||||
self._collect_one_hop_seed_candidates(
|
||||
seed=seeds[0],
|
||||
seen_hashes=seen_hashes,
|
||||
out=candidates,
|
||||
)
|
||||
|
||||
return candidates[: self.config.candidate_k]
|
||||
|
||||
def _normalize_seed_entities(self, seed_entities: Sequence[str]) -> List[str]:
|
||||
out: List[str] = []
|
||||
seen = set()
|
||||
for raw in list(seed_entities)[:2]:
|
||||
resolved = None
|
||||
try:
|
||||
resolved = self.graph_store.find_node(str(raw), ignore_case=True)
|
||||
except Exception:
|
||||
resolved = None
|
||||
if not resolved:
|
||||
continue
|
||||
canon = str(resolved).strip().lower()
|
||||
if not canon or canon in seen:
|
||||
continue
|
||||
seen.add(canon)
|
||||
out.append(str(resolved))
|
||||
return out
|
||||
|
||||
def _collect_direct_pair_candidates(
|
||||
self,
|
||||
*,
|
||||
seed_a: str,
|
||||
seed_b: str,
|
||||
seen_hashes: Set[str],
|
||||
out: List[GraphRelationCandidate],
|
||||
) -> None:
|
||||
relation_hashes = []
|
||||
relation_hashes.extend(self.graph_store.get_relation_hashes_for_edge(seed_a, seed_b))
|
||||
relation_hashes.extend(self.graph_store.get_relation_hashes_for_edge(seed_b, seed_a))
|
||||
self._append_relation_hashes(
|
||||
relation_hashes=relation_hashes,
|
||||
seen_hashes=seen_hashes,
|
||||
out=out,
|
||||
candidate_type="direct_pair",
|
||||
graph_hops=1,
|
||||
graph_seed_entities=[seed_a, seed_b],
|
||||
)
|
||||
|
||||
def _collect_two_hop_pair_candidates(
|
||||
self,
|
||||
*,
|
||||
seed_a: str,
|
||||
seed_b: str,
|
||||
seen_hashes: Set[str],
|
||||
out: List[GraphRelationCandidate],
|
||||
) -> None:
|
||||
try:
|
||||
paths = self.graph_store.find_paths(
|
||||
seed_a,
|
||||
seed_b,
|
||||
max_depth=2,
|
||||
max_paths=self.config.max_paths,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("graph two-hop recall skipped: %s", e)
|
||||
return
|
||||
|
||||
for path_nodes in paths:
|
||||
if len(out) >= self.config.candidate_k:
|
||||
break
|
||||
if not isinstance(path_nodes, Sequence) or len(path_nodes) < 3:
|
||||
continue
|
||||
if len(path_nodes) != 3:
|
||||
continue
|
||||
for idx in range(len(path_nodes) - 1):
|
||||
if len(out) >= self.config.candidate_k:
|
||||
break
|
||||
u = str(path_nodes[idx])
|
||||
v = str(path_nodes[idx + 1])
|
||||
relation_hashes = []
|
||||
relation_hashes.extend(self.graph_store.get_relation_hashes_for_edge(u, v))
|
||||
relation_hashes.extend(self.graph_store.get_relation_hashes_for_edge(v, u))
|
||||
self._append_relation_hashes(
|
||||
relation_hashes=relation_hashes,
|
||||
seen_hashes=seen_hashes,
|
||||
out=out,
|
||||
candidate_type="two_hop_pair",
|
||||
graph_hops=2,
|
||||
graph_seed_entities=[seed_a, seed_b],
|
||||
)
|
||||
|
||||
def _collect_one_hop_seed_candidates(
|
||||
self,
|
||||
*,
|
||||
seed: str,
|
||||
seen_hashes: Set[str],
|
||||
out: List[GraphRelationCandidate],
|
||||
) -> None:
|
||||
try:
|
||||
relation_hashes = self.graph_store.get_incident_relation_hashes(
|
||||
seed,
|
||||
limit=self.config.candidate_k,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("graph one-hop recall skipped: %s", e)
|
||||
return
|
||||
self._append_relation_hashes(
|
||||
relation_hashes=relation_hashes,
|
||||
seen_hashes=seen_hashes,
|
||||
out=out,
|
||||
candidate_type="one_hop_seed",
|
||||
graph_hops=min(1, self.config.max_hop),
|
||||
graph_seed_entities=[seed],
|
||||
)
|
||||
|
||||
def _append_relation_hashes(
|
||||
self,
|
||||
*,
|
||||
relation_hashes: Sequence[str],
|
||||
seen_hashes: Set[str],
|
||||
out: List[GraphRelationCandidate],
|
||||
candidate_type: str,
|
||||
graph_hops: int,
|
||||
graph_seed_entities: Sequence[str],
|
||||
) -> None:
|
||||
for relation_hash in sorted({str(h) for h in relation_hashes if str(h).strip()}):
|
||||
if len(out) >= self.config.candidate_k:
|
||||
break
|
||||
if relation_hash in seen_hashes:
|
||||
continue
|
||||
candidate = self._build_candidate(
|
||||
relation_hash=relation_hash,
|
||||
candidate_type=candidate_type,
|
||||
graph_hops=graph_hops,
|
||||
graph_seed_entities=graph_seed_entities,
|
||||
)
|
||||
if candidate is None:
|
||||
continue
|
||||
seen_hashes.add(relation_hash)
|
||||
out.append(candidate)
|
||||
|
||||
def _build_candidate(
|
||||
self,
|
||||
*,
|
||||
relation_hash: str,
|
||||
candidate_type: str,
|
||||
graph_hops: int,
|
||||
graph_seed_entities: Sequence[str],
|
||||
) -> Optional[GraphRelationCandidate]:
|
||||
relation = self.metadata_store.get_relation(relation_hash)
|
||||
if relation is None:
|
||||
return None
|
||||
supporting_paragraphs = self.metadata_store.get_paragraphs_by_relation(relation_hash)
|
||||
return GraphRelationCandidate(
|
||||
hash_value=relation_hash,
|
||||
subject=str(relation.get("subject", "")),
|
||||
predicate=str(relation.get("predicate", "")),
|
||||
object=str(relation.get("object", "")),
|
||||
confidence=float(relation.get("confidence", 1.0) or 1.0),
|
||||
graph_seed_entities=[str(x) for x in graph_seed_entities],
|
||||
graph_hops=int(graph_hops),
|
||||
graph_candidate_type=str(candidate_type),
|
||||
supporting_paragraph_count=len(supporting_paragraphs),
|
||||
)
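A self-contained sketch (not part of the commit): the service only duck-types its stores, so one recall round can be exercised with in-memory stubs whose methods mirror the calls made above (find_node, get_relation_hashes_for_edge, find_paths, get_incident_relation_hashes, get_relation, get_paragraphs_by_relation). The entities and relation data are made up.

from plugins.A_memorix.core.retrieval.graph_relation_recall import (
    GraphRelationRecallConfig,
    GraphRelationRecallService,
)


class StubGraphStore:
    def find_node(self, name, ignore_case=True):
        return name if name in {"Alice", "Bob"} else None

    def get_relation_hashes_for_edge(self, a, b):
        return ["rel-1"] if {a, b} == {"Alice", "Bob"} else []

    def find_paths(self, a, b, max_depth=2, max_paths=4):
        return []

    def get_incident_relation_hashes(self, node, limit=24):
        return ["rel-1"]


class StubMetadataStore:
    def get_relation(self, relation_hash):
        return {"subject": "Alice", "predicate": "knows", "object": "Bob", "confidence": 0.9}

    def get_paragraphs_by_relation(self, relation_hash):
        return ["p-1"]


service = GraphRelationRecallService(
    graph_store=StubGraphStore(),
    metadata_store=StubMetadataStore(),
    config=GraphRelationRecallConfig(candidate_k=8),
)
# Two resolvable seeds -> direct-pair recall produces "Alice knows Bob" (direct_pair).
for cand in service.recall(seed_entities=["Alice", "Bob"]):
    print(cand.to_payload()["content"], cand.graph_candidate_type)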
plugins/A_memorix/core/retrieval/pagerank.py (new file, 482 lines)
@@ -0,0 +1,482 @@
|
||||
"""
|
||||
Personalized PageRank实现
|
||||
|
||||
提供个性化的图节点排序功能。
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional, Tuple, Union, Any
|
||||
from dataclasses import dataclass
|
||||
import numpy as np
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from ..storage import GraphStore
|
||||
from ..utils.matcher import AhoCorasick
|
||||
|
||||
logger = get_logger("A_Memorix.PersonalizedPageRank")
|
||||
|
||||
|
||||
@dataclass
|
||||
class PageRankConfig:
|
||||
"""
|
||||
PageRank配置
|
||||
|
||||
属性:
|
||||
alpha: 阻尼系数(0-1之间)
|
||||
max_iter: 最大迭代次数
|
||||
tol: 收敛阈值
|
||||
normalize: 是否归一化结果
|
||||
min_iterations: 最小迭代次数
|
||||
"""
|
||||
|
||||
alpha: float = 0.85
|
||||
max_iter: int = 100
|
||||
tol: float = 1e-6
|
||||
normalize: bool = True
|
||||
min_iterations: int = 20
|
||||
|
||||
def __post_init__(self):
|
||||
"""验证配置"""
|
||||
if not 0 <= self.alpha < 1:
|
||||
raise ValueError(f"alpha必须在[0, 1)之间: {self.alpha}")
|
||||
|
||||
if self.max_iter <= 0:
|
||||
raise ValueError(f"max_iter必须大于0: {self.max_iter}")
|
||||
|
||||
if self.tol <= 0:
|
||||
raise ValueError(f"tol必须大于0: {self.tol}")
|
||||
|
||||
if self.min_iterations < 0:
|
||||
raise ValueError(f"min_iterations必须大于等于0: {self.min_iterations}")
|
||||
|
||||
if self.min_iterations >= self.max_iter:
|
||||
raise ValueError(f"min_iterations必须小于max_iter")
|
||||
|
||||
|
||||
class PersonalizedPageRank:
|
||||
"""
|
||||
Personalized PageRank计算器
|
||||
|
||||
功能:
|
||||
- 个性化向量支持
|
||||
- 快速收敛检测
|
||||
- 结果归一化
|
||||
- 批量计算
|
||||
- 统计信息
|
||||
|
||||
参数:
|
||||
graph_store: 图存储
|
||||
config: PageRank配置
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
graph_store: GraphStore,
|
||||
config: Optional[PageRankConfig] = None,
|
||||
):
|
||||
"""
|
||||
初始化PPR计算器
|
||||
|
||||
Args:
|
||||
graph_store: 图存储
|
||||
config: PageRank配置
|
||||
"""
|
||||
self.graph_store = graph_store
|
||||
self.config = config or PageRankConfig()
|
||||
|
||||
# 统计信息
|
||||
self._total_computations = 0
|
||||
self._total_iterations = 0
|
||||
self._convergence_history: List[int] = []
|
||||
|
||||
logger.info(
|
||||
f"PersonalizedPageRank 初始化: "
|
||||
f"alpha={self.config.alpha}, "
|
||||
f"max_iter={self.config.max_iter}"
|
||||
)
|
||||
|
||||
# 缓存 Aho-Corasick 匹配器
|
||||
self._ac_matcher: Optional[AhoCorasick] = None
|
||||
self._ac_nodes_count = 0
|
||||
|
||||
def compute(
|
||||
self,
|
||||
personalization: Optional[Dict[str, float]] = None,
|
||||
alpha: Optional[float] = None,
|
||||
max_iter: Optional[int] = None,
|
||||
normalize: Optional[bool] = None,
|
||||
) -> Dict[str, float]:
|
||||
"""
|
||||
计算Personalized PageRank
|
||||
|
||||
Args:
|
||||
personalization: 个性化向量 {节点名: 权重}
|
||||
alpha: 阻尼系数(覆盖配置值)
|
||||
max_iter: 最大迭代次数(覆盖配置值)
|
||||
normalize: 是否归一化(覆盖配置值)
|
||||
|
||||
Returns:
|
||||
节点PageRank值字典 {节点名: 分数}
|
||||
"""
|
||||
# 使用覆盖值或配置值
|
||||
alpha = alpha if alpha is not None else self.config.alpha
|
||||
max_iter = max_iter if max_iter is not None else self.config.max_iter
|
||||
normalize = normalize if normalize is not None else self.config.normalize
|
||||
|
||||
# 调用GraphStore的compute_pagerank
|
||||
scores = self.graph_store.compute_pagerank(
|
||||
personalization=personalization,
|
||||
alpha=alpha,
|
||||
max_iter=max_iter,
|
||||
tol=self.config.tol,
|
||||
)
|
||||
|
||||
# 归一化(如果需要)
|
||||
if normalize and scores:
|
||||
total = sum(scores.values())
|
||||
if total > 0:
|
||||
scores = {node: score / total for node, score in scores.items()}
|
||||
|
||||
# 更新统计
|
||||
self._total_computations += 1
|
||||
|
||||
logger.debug(
|
||||
f"PPR计算完成: {len(scores)} 个节点, "
|
||||
f"personalization_nodes={len(personalization) if personalization else 0}"
|
||||
)
|
||||
|
||||
return scores
|
||||
|
||||
def compute_batch(
|
||||
self,
|
||||
personalization_list: List[Dict[str, float]],
|
||||
normalize: bool = True,
|
||||
) -> List[Dict[str, float]]:
|
||||
"""
|
||||
批量计算PPR
|
||||
|
||||
Args:
|
||||
personalization_list: 个性化向量列表
|
||||
normalize: 是否归一化
|
||||
|
||||
Returns:
|
||||
PageRank值字典列表
|
||||
"""
|
||||
results = []
|
||||
|
||||
for i, personalization in enumerate(personalization_list):
|
||||
logger.debug(f"计算第 {i+1}/{len(personalization_list)} 个PPR")
|
||||
scores = self.compute(personalization=personalization, normalize=normalize)
|
||||
results.append(scores)
|
||||
|
||||
return results
|
||||
|
||||
def compute_for_entities(
|
||||
self,
|
||||
entities: List[str],
|
||||
weights: Optional[List[float]] = None,
|
||||
normalize: bool = True,
|
||||
) -> Dict[str, float]:
|
||||
"""
|
||||
为实体列表计算PPR
|
||||
|
||||
Args:
|
||||
entities: 实体列表
|
||||
weights: 权重列表(默认均匀权重)
|
||||
normalize: 是否归一化
|
||||
|
||||
Returns:
|
||||
PageRank值字典
|
||||
"""
|
||||
if not entities:
|
||||
logger.warning("实体列表为空,返回均匀PPR")
|
||||
return self.compute(personalization=None, normalize=normalize)
|
||||
|
||||
# 构建个性化向量
|
||||
if weights is None:
|
||||
weights = [1.0] * len(entities)
|
||||
|
||||
if len(weights) != len(entities):
|
||||
raise ValueError(f"权重数量与实体数量不匹配: {len(weights)} vs {len(entities)}")
|
||||
|
||||
personalization = {entity: weight for entity, weight in zip(entities, weights)}
|
||||
|
||||
return self.compute(personalization=personalization, normalize=normalize)
|
||||
|
||||
def compute_for_query(
|
||||
self,
|
||||
query: str,
|
||||
entity_extractor: Optional[callable] = None,
|
||||
normalize: bool = True,
|
||||
) -> Dict[str, float]:
|
||||
"""
|
||||
为查询计算PPR
|
||||
|
||||
Args:
|
||||
query: 查询文本
|
||||
entity_extractor: 实体提取函数(可选)
|
||||
normalize: 是否归一化
|
||||
|
||||
Returns:
|
||||
PageRank值字典
|
||||
"""
|
||||
# 提取实体
|
||||
if entity_extractor is not None:
|
||||
entities = entity_extractor(query)
|
||||
else:
|
||||
# 简单实现:基于图中的节点匹配
|
||||
entities = self._extract_entities_from_query(query)
|
||||
|
||||
if not entities:
|
||||
logger.debug(f"未从查询中提取到实体: '{query}'")
|
||||
return self.compute(personalization=None, normalize=normalize)
|
||||
|
||||
# 计算PPR
|
||||
return self.compute_for_entities(entities, normalize=normalize)
|
||||
|
||||
def rank_nodes(
|
||||
self,
|
||||
scores: Dict[str, float],
|
||||
top_k: Optional[int] = None,
|
||||
min_score: float = 0.0,
|
||||
) -> List[Tuple[str, float]]:
|
||||
"""
|
||||
对节点排序
|
||||
|
||||
Args:
|
||||
scores: PageRank分数字典
|
||||
top_k: 返回前k个节点(None表示全部)
|
||||
min_score: 最小分数阈值
|
||||
|
||||
Returns:
|
||||
排序后的节点列表 [(节点名, 分数), ...]
|
||||
"""
|
||||
# 过滤低分节点
|
||||
filtered = [(node, score) for node, score in scores.items() if score >= min_score]
|
||||
|
||||
# 按分数降序排序
|
||||
sorted_nodes = sorted(filtered, key=lambda x: x[1], reverse=True)
|
||||
|
||||
# 返回top_k
|
||||
if top_k is not None:
|
||||
sorted_nodes = sorted_nodes[:top_k]
|
||||
|
||||
return sorted_nodes
|
||||
|
||||
def get_personalization_vector(
|
||||
self,
|
||||
nodes: List[str],
|
||||
method: str = "uniform",
|
||||
) -> Dict[str, float]:
|
||||
"""
|
||||
生成个性化向量
|
||||
|
||||
Args:
|
||||
nodes: 节点列表
|
||||
method: 生成方法
|
||||
- "uniform": 均匀权重
|
||||
- "degree": 按度数加权
|
||||
- "inverse_degree": 按度数反比加权
|
||||
|
||||
Returns:
|
||||
个性化向量 {节点名: 权重}
|
||||
"""
|
||||
if not nodes:
|
||||
return {}
|
||||
|
||||
if method == "uniform":
|
||||
# 均匀权重
|
||||
weight = 1.0 / len(nodes)
|
||||
return {node: weight for node in nodes}
|
||||
|
||||
elif method == "degree":
|
||||
# 按度数加权
|
||||
node_degrees = {}
|
||||
for node in nodes:
|
||||
neighbors = self.graph_store.get_neighbors(node)
|
||||
node_degrees[node] = len(neighbors)
|
||||
|
||||
total_degree = sum(node_degrees.values())
|
||||
if total_degree > 0:
|
||||
return {node: degree / total_degree for node, degree in node_degrees.items()}
|
||||
else:
|
||||
return {node: 1.0 / len(nodes) for node in nodes}
|
||||
|
||||
elif method == "inverse_degree":
|
||||
# 按度数反比加权
|
||||
node_degrees = {}
|
||||
for node in nodes:
|
||||
neighbors = self.graph_store.get_neighbors(node)
|
||||
node_degrees[node] = len(neighbors)
|
||||
|
||||
# 反度数
|
||||
inv_degrees = {node: 1.0 / (degree + 1) for node, degree in node_degrees.items()}
|
||||
total_inv = sum(inv_degrees.values())
|
||||
|
||||
if total_inv > 0:
|
||||
return {node: inv / total_inv for node, inv in inv_degrees.items()}
|
||||
else:
|
||||
return {node: 1.0 / len(nodes) for node in nodes}
|
||||
|
||||
else:
|
||||
raise ValueError(f"不支持的个性化向量生成方法: {method}")
|
||||
|
||||
def compare_scores(
|
||||
self,
|
||||
scores1: Dict[str, float],
|
||||
scores2: Dict[str, float],
|
||||
) -> Dict[str, Dict[str, float]]:
|
||||
"""
|
||||
比较两组PPR分数
|
||||
|
||||
Args:
|
||||
scores1: 第一组分数
|
||||
scores2: 第二组分数
|
||||
|
||||
Returns:
|
||||
比较结果 {
|
||||
"common_nodes": {节点: (score1, score2)},
|
||||
"only_in_1": {节点: score1},
|
||||
"only_in_2": {节点: score2},
|
||||
}
|
||||
"""
|
||||
common_nodes = {}
|
||||
only_in_1 = {}
|
||||
only_in_2 = {}
|
||||
|
||||
all_nodes = set(scores1.keys()) | set(scores2.keys())
|
||||
|
||||
for node in all_nodes:
|
||||
if node in scores1 and node in scores2:
|
||||
common_nodes[node] = (scores1[node], scores2[node])
|
||||
elif node in scores1:
|
||||
only_in_1[node] = scores1[node]
|
||||
else:
|
||||
only_in_2[node] = scores2[node]
|
||||
|
||||
return {
|
||||
"common_nodes": common_nodes,
|
||||
"only_in_1": only_in_1,
|
||||
"only_in_2": only_in_2,
|
||||
}
|
||||
|
||||
def get_statistics(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取统计信息
|
||||
|
||||
Returns:
|
||||
统计信息字典
|
||||
"""
|
||||
avg_iterations = (
|
||||
self._total_iterations / self._total_computations
|
||||
if self._total_computations > 0
|
||||
else 0
|
||||
)
|
||||
|
||||
return {
|
||||
"config": {
|
||||
"alpha": self.config.alpha,
|
||||
"max_iter": self.config.max_iter,
|
||||
"tol": self.config.tol,
|
||||
"normalize": self.config.normalize,
|
||||
"min_iterations": self.config.min_iterations,
|
||||
},
|
||||
"statistics": {
|
||||
"total_computations": self._total_computations,
|
||||
"total_iterations": self._total_iterations,
|
||||
"avg_iterations": avg_iterations,
|
||||
"convergence_history": self._convergence_history.copy(),
|
||||
},
|
||||
"graph": {
|
||||
"num_nodes": self.graph_store.num_nodes,
|
||||
"num_edges": self.graph_store.num_edges,
|
||||
},
|
||||
}
|
||||
|
||||
def reset_statistics(self) -> None:
|
||||
"""重置统计信息"""
|
||||
self._total_computations = 0
|
||||
self._total_iterations = 0
|
||||
self._convergence_history.clear()
|
||||
logger.info("统计信息已重置")
|
||||
|
||||
def _extract_entities_from_query(self, query: str) -> List[str]:
|
||||
"""
|
||||
从查询中提取实体(简化实现)
|
||||
|
||||
Args:
|
||||
query: 查询文本
|
||||
|
||||
Returns:
|
||||
实体列表
|
||||
"""
|
||||
# 获取所有节点
|
||||
all_nodes = self.graph_store.get_nodes()
|
||||
if not all_nodes:
|
||||
return []
|
||||
|
||||
# 检查是否需要更新 Aho-Corasick 匹配器
|
||||
if self._ac_matcher is None or self._ac_nodes_count != len(all_nodes):
|
||||
self._ac_matcher = AhoCorasick()
|
||||
for node in all_nodes:
|
||||
# 统一转为小写进行不区分大小写匹配
|
||||
self._ac_matcher.add_pattern(node.lower())
|
||||
self._ac_matcher.build()
|
||||
self._ac_nodes_count = len(all_nodes)
|
||||
|
||||
# 执行匹配
|
||||
query_lower = query.lower()
|
||||
stats = self._ac_matcher.find_all(query_lower)
|
||||
|
||||
# 转换回原始的大小写(这里简化为从 all_nodes 中找,或者 AC 存原始值)
|
||||
# 为了简单,AC 中 add_pattern 存的是小写
|
||||
# 我们需要一个映射:小写 -> 原始
|
||||
node_map = {node.lower(): node for node in all_nodes}
|
||||
entities = [node_map[low_name] for low_name in stats.keys()]
|
||||
|
||||
return entities
|
||||
|
||||
@property
|
||||
def num_computations(self) -> int:
|
||||
"""计算次数"""
|
||||
return self._total_computations
|
||||
|
||||
@property
|
||||
def avg_iterations(self) -> float:
|
||||
"""平均迭代次数"""
|
||||
if self._total_computations == 0:
|
||||
return 0.0
|
||||
return self._total_iterations / self._total_computations
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"PersonalizedPageRank("
|
||||
f"alpha={self.config.alpha}, "
|
||||
f"computations={self._total_computations})"
|
||||
)
|
||||
|
||||
|
||||
def create_ppr_from_graph(
|
||||
graph_store: GraphStore,
|
||||
alpha: float = 0.85,
|
||||
max_iter: int = 100,
|
||||
) -> PersonalizedPageRank:
|
||||
"""
|
||||
从图存储创建PPR计算器
|
||||
|
||||
Args:
|
||||
graph_store: 图存储
|
||||
alpha: 阻尼系数
|
||||
max_iter: 最大迭代次数
|
||||
|
||||
Returns:
|
||||
PPR计算器实例
|
||||
"""
|
||||
config = PageRankConfig(
|
||||
alpha=alpha,
|
||||
max_iter=max_iter,
|
||||
)
|
||||
|
||||
return PersonalizedPageRank(
|
||||
graph_store=graph_store,
|
||||
config=config,
|
||||
)
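A usage sketch (not part of the commit) for seeding Personalized PageRank from query entities. It assumes an already-populated GraphStore from the storage layer, which is not shown in this diff.

from plugins.A_memorix.core.retrieval.pagerank import PageRankConfig, PersonalizedPageRank

graph_store = ...  # an existing, populated GraphStore instance (see core/storage)

ppr = PersonalizedPageRank(graph_store, config=PageRankConfig(alpha=0.85, max_iter=100))

# Bias the random walk toward entities mentioned in the query, then keep the top nodes.
scores = ppr.compute_for_entities(["Alice", "Bob"], weights=[0.7, 0.3])
for node, score in ppr.rank_nodes(scores, top_k=10, min_score=1e-4):
    print(f"{node}\t{score:.4f}")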
plugins/A_memorix/core/retrieval/sparse_bm25.py (new file, 402 lines)
@@ -0,0 +1,402 @@
|
||||
"""
|
||||
稀疏检索组件(FTS5 + BM25)
|
||||
|
||||
支持:
|
||||
- 懒加载索引连接
|
||||
- jieba / char n-gram 分词
|
||||
- 可卸载并收缩 SQLite 内存缓存
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import sqlite3
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from ..storage import MetadataStore
|
||||
|
||||
logger = get_logger("A_Memorix.SparseBM25")
|
||||
|
||||
try:
|
||||
import jieba # type: ignore
|
||||
|
||||
HAS_JIEBA = True
|
||||
except Exception:
|
||||
HAS_JIEBA = False
|
||||
jieba = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class SparseBM25Config:
|
||||
"""BM25 稀疏检索配置。"""
|
||||
|
||||
enabled: bool = True
|
||||
backend: str = "fts5"
|
||||
lazy_load: bool = True
|
||||
mode: str = "auto" # auto | fallback_only | hybrid
|
||||
tokenizer_mode: str = "jieba" # jieba | mixed | char_2gram
|
||||
jieba_user_dict: str = ""
|
||||
char_ngram_n: int = 2
|
||||
candidate_k: int = 80
|
||||
max_doc_len: int = 2000
|
||||
enable_ngram_fallback_index: bool = True
|
||||
enable_like_fallback: bool = False
|
||||
enable_relation_sparse_fallback: bool = True
|
||||
relation_candidate_k: int = 60
|
||||
relation_max_doc_len: int = 512
|
||||
unload_on_disable: bool = True
|
||||
shrink_memory_on_unload: bool = True
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.backend = str(self.backend or "fts5").strip().lower()
|
||||
self.mode = str(self.mode or "auto").strip().lower()
|
||||
self.tokenizer_mode = str(self.tokenizer_mode or "jieba").strip().lower()
|
||||
self.char_ngram_n = max(1, int(self.char_ngram_n))
|
||||
self.candidate_k = max(1, int(self.candidate_k))
|
||||
self.max_doc_len = max(0, int(self.max_doc_len))
|
||||
self.relation_candidate_k = max(1, int(self.relation_candidate_k))
|
||||
self.relation_max_doc_len = max(0, int(self.relation_max_doc_len))
|
||||
if self.backend != "fts5":
|
||||
raise ValueError(f"sparse.backend 暂仅支持 fts5: {self.backend}")
|
||||
if self.mode not in {"auto", "fallback_only", "hybrid"}:
|
||||
raise ValueError(f"sparse.mode 非法: {self.mode}")
|
||||
if self.tokenizer_mode not in {"jieba", "mixed", "char_2gram"}:
|
||||
raise ValueError(f"sparse.tokenizer_mode 非法: {self.tokenizer_mode}")
|
||||
|
||||
|
||||
class SparseBM25Index:
|
||||
"""
|
||||
基于 SQLite FTS5 的 BM25 检索适配层。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
metadata_store: MetadataStore,
|
||||
config: Optional[SparseBM25Config] = None,
|
||||
):
|
||||
self.metadata_store = metadata_store
|
||||
self.config = config or SparseBM25Config()
|
||||
self._conn: Optional[sqlite3.Connection] = None
|
||||
self._loaded: bool = False
|
||||
self._jieba_dict_loaded: bool = False
|
||||
|
||||
@property
|
||||
def loaded(self) -> bool:
|
||||
return self._loaded and self._conn is not None
|
||||
|
||||
def ensure_loaded(self) -> bool:
|
||||
"""按需加载 FTS 连接与索引。"""
|
||||
if not self.config.enabled:
|
||||
return False
|
||||
if self.loaded:
|
||||
return True
|
||||
|
||||
db_path = self.metadata_store.get_db_path()
|
||||
conn = sqlite3.connect(
|
||||
str(db_path),
|
||||
check_same_thread=False,
|
||||
timeout=30.0,
|
||||
)
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
conn.execute("PRAGMA synchronous=NORMAL")
|
||||
conn.execute("PRAGMA temp_store=MEMORY")
|
||||
|
||||
if not self.metadata_store.ensure_fts_schema(conn=conn):
|
||||
conn.close()
|
||||
return False
|
||||
self.metadata_store.ensure_fts_backfilled(conn=conn)
|
||||
# 关系稀疏检索按独立开关加载,避免不必要的初始化开销。
|
||||
if self.config.enable_relation_sparse_fallback:
|
||||
self.metadata_store.ensure_relations_fts_schema(conn=conn)
|
||||
self.metadata_store.ensure_relations_fts_backfilled(conn=conn)
|
||||
if self.config.enable_ngram_fallback_index:
|
||||
self.metadata_store.ensure_paragraph_ngram_schema(conn=conn)
|
||||
self.metadata_store.ensure_paragraph_ngram_backfilled(
|
||||
n=self.config.char_ngram_n,
|
||||
conn=conn,
|
||||
)
|
||||
|
||||
self._conn = conn
|
||||
self._loaded = True
|
||||
self._prepare_tokenizer()
|
||||
logger.info(
|
||||
"SparseBM25Index loaded: backend=fts5, tokenizer=%s, mode=%s",
|
||||
self.config.tokenizer_mode,
|
||||
self.config.mode,
|
||||
)
|
||||
return True
|
||||
|
||||
def _prepare_tokenizer(self) -> None:
|
||||
if self._jieba_dict_loaded:
|
||||
return
|
||||
if self.config.tokenizer_mode not in {"jieba", "mixed"}:
|
||||
return
|
||||
if not HAS_JIEBA:
|
||||
logger.warning("jieba 不可用,tokenizer 将退化为 char n-gram")
|
||||
return
|
||||
user_dict = str(self.config.jieba_user_dict or "").strip()
|
||||
if user_dict:
|
||||
try:
|
||||
jieba.load_userdict(user_dict) # type: ignore[union-attr]
|
||||
logger.info("已加载 jieba 用户词典: %s", user_dict)
|
||||
except Exception as e:
|
||||
logger.warning("加载 jieba 用户词典失败: %s", e)
|
||||
self._jieba_dict_loaded = True
|
||||
|
||||
def _tokenize_jieba(self, text: str) -> List[str]:
|
||||
if not HAS_JIEBA:
|
||||
return []
|
||||
try:
|
||||
tokens = list(jieba.cut_for_search(text)) # type: ignore[union-attr]
|
||||
return [t.strip().lower() for t in tokens if t and t.strip()]
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
def _tokenize_char_ngram(self, text: str, n: int) -> List[str]:
|
||||
compact = re.sub(r"\s+", "", text.lower())
|
||||
if not compact:
|
||||
return []
|
||||
if len(compact) < n:
|
||||
return [compact]
|
||||
return [compact[i : i + n] for i in range(0, len(compact) - n + 1)]
|
||||
|
||||
def _tokenize(self, text: str) -> List[str]:
|
||||
text = str(text or "").strip()
|
||||
if not text:
|
||||
return []
|
||||
|
||||
mode = self.config.tokenizer_mode
|
||||
if mode == "jieba":
|
||||
tokens = self._tokenize_jieba(text)
|
||||
if tokens:
|
||||
return list(dict.fromkeys(tokens))
|
||||
return self._tokenize_char_ngram(text, self.config.char_ngram_n)
|
||||
|
||||
if mode == "mixed":
|
||||
toks = self._tokenize_jieba(text)
|
||||
toks.extend(self._tokenize_char_ngram(text, self.config.char_ngram_n))
|
||||
return list(dict.fromkeys([t for t in toks if t]))
|
||||
|
||||
return list(dict.fromkeys(self._tokenize_char_ngram(text, self.config.char_ngram_n)))
|
||||
|
||||
def _build_match_query(self, tokens: List[str]) -> str:
|
||||
safe_tokens: List[str] = []
|
||||
for token in tokens:
|
||||
t = token.replace('"', '""').strip()
|
||||
if not t:
|
||||
continue
|
||||
safe_tokens.append(f'"{t}"')
|
||||
if not safe_tokens:
|
||||
return ""
|
||||
# 采用 OR 提升召回,再交由 RRF 和阈值做稳健排序。
|
||||
return " OR ".join(safe_tokens[:64])
|
||||
|
||||
def _fallback_substring_search(
|
||||
self,
|
||||
tokens: List[str],
|
||||
limit: int,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
当 FTS5 因分词不一致召回为空时,退化为子串匹配召回。
|
||||
|
||||
说明:
|
||||
- FTS 索引当前采用 unicode61 tokenizer。
|
||||
- 若查询 token 来源为 char n-gram 或中文词元,可能与索引 token 不一致。
|
||||
- 这里使用 SQL LIKE 做兜底,按命中 token 覆盖度打分。
|
||||
"""
|
||||
if not tokens:
|
||||
return []
|
||||
|
||||
# 去重并裁剪 token 数量,避免生成超长 SQL。
|
||||
uniq_tokens = [t for t in dict.fromkeys(tokens) if t]
|
||||
uniq_tokens = uniq_tokens[:32]
|
||||
if not uniq_tokens:
|
||||
return []
|
||||
|
||||
if self.config.enable_ngram_fallback_index:
|
||||
try:
|
||||
# 允许运行时切换开关后按需补齐 schema/回填。
|
||||
self.metadata_store.ensure_paragraph_ngram_schema(conn=self._conn)
|
||||
self.metadata_store.ensure_paragraph_ngram_backfilled(
|
||||
n=self.config.char_ngram_n,
|
||||
conn=self._conn,
|
||||
)
|
||||
rows = self.metadata_store.ngram_search_paragraphs(
|
||||
tokens=uniq_tokens,
|
||||
limit=limit,
|
||||
max_doc_len=self.config.max_doc_len,
|
||||
conn=self._conn,
|
||||
)
|
||||
if rows:
|
||||
return rows
|
||||
except Exception as e:
|
||||
logger.warning(f"ngram 倒排回退失败,将按配置决定是否使用 LIKE 回退: {e}")
|
||||
|
||||
if not self.config.enable_like_fallback:
|
||||
return []
|
||||
|
||||
conditions = " OR ".join(["p.content LIKE ?"] * len(uniq_tokens))
|
||||
params: List[Any] = [f"%{tok}%" for tok in uniq_tokens]
|
||||
scan_limit = max(int(limit) * 8, 200)
|
||||
params.append(scan_limit)
|
||||
|
||||
sql = f"""
|
||||
SELECT p.hash, p.content
|
||||
FROM paragraphs p
|
||||
WHERE (p.is_deleted IS NULL OR p.is_deleted = 0)
|
||||
AND ({conditions})
|
||||
LIMIT ?
|
||||
"""
|
||||
rows = self.metadata_store.query(sql, tuple(params))
|
||||
if not rows:
|
||||
return []
|
||||
|
||||
scored: List[Dict[str, Any]] = []
|
||||
token_count = max(1, len(uniq_tokens))
|
||||
for row in rows:
|
||||
content = str(row.get("content") or "")
|
||||
content_low = content.lower()
|
||||
matched = [tok for tok in uniq_tokens if tok in content_low]
|
||||
if not matched:
|
||||
continue
|
||||
coverage = len(matched) / token_count
|
||||
length_bonus = sum(len(tok) for tok in matched) / max(1, len(content_low))
|
||||
# 兜底路径使用相对分,保持与上层接口兼容。
|
||||
fallback_score = coverage * 0.8 + length_bonus * 0.2
|
||||
scored.append(
|
||||
{
|
||||
"hash": row["hash"],
|
||||
"content": content[: self.config.max_doc_len] if self.config.max_doc_len > 0 else content,
|
||||
"bm25_score": -float(fallback_score),
|
||||
"fallback_score": float(fallback_score),
|
||||
}
|
||||
)
|
||||
|
||||
scored.sort(key=lambda x: x["fallback_score"], reverse=True)
|
||||
return scored[:limit]
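The LIKE fallback ranks rows by a weighted mix of token coverage (the share of query tokens found in the row) and a small bonus for long matched tokens relative to document length. A minimal sketch of the same scoring on plain strings (hypothetical data, not the real store):

from typing import List

def fallback_score(content: str, tokens: List[str]) -> float:
    # Coverage carries 0.8 of the weight, matched-token length 0.2, as in the code above.
    content_low = content.lower()
    matched = [tok for tok in tokens if tok in content_low]
    if not matched:
        return 0.0
    coverage = len(matched) / max(1, len(tokens))
    length_bonus = sum(len(tok) for tok in matched) / max(1, len(content_low))
    return coverage * 0.8 + length_bonus * 0.2

# fallback_score("今天北京天气晴", ["北京", "天气", "上海"]) ≈ 0.65 (2/3 coverage, 4/7 length bonus)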
|
||||
|
||||
def search(self, query: str, k: int = 20) -> List[Dict[str, Any]]:
|
||||
"""执行 BM25 检索。"""
|
||||
if not self.config.enabled:
|
||||
return []
|
||||
if self.config.lazy_load and not self.loaded:
|
||||
if not self.ensure_loaded():
|
||||
return []
|
||||
if not self.loaded:
|
||||
return []
|
||||
# 关系稀疏检索可独立开关,运行时开启后也能按需补齐 schema/回填。
if self.config.enable_relation_sparse_fallback:
    self.metadata_store.ensure_relations_fts_schema(conn=self._conn)
    self.metadata_store.ensure_relations_fts_backfilled(conn=self._conn)
|
||||
|
||||
tokens = self._tokenize(query)
|
||||
match_query = self._build_match_query(tokens)
|
||||
if not match_query:
|
||||
return []
|
||||
|
||||
limit = max(1, int(k))
|
||||
rows = self.metadata_store.fts_search_bm25(
|
||||
match_query=match_query,
|
||||
limit=limit,
|
||||
max_doc_len=self.config.max_doc_len,
|
||||
conn=self._conn,
|
||||
)
|
||||
if not rows:
|
||||
rows = self._fallback_substring_search(tokens=tokens, limit=limit)
|
||||
|
||||
results: List[Dict[str, Any]] = []
|
||||
for rank, row in enumerate(rows, start=1):
|
||||
bm25_score = float(row.get("bm25_score", 0.0))
|
||||
results.append(
|
||||
{
|
||||
"hash": row["hash"],
|
||||
"content": row["content"],
|
||||
"rank": rank,
|
||||
"bm25_score": bm25_score,
|
||||
"score": -bm25_score, # bm25 越小越相关,这里取反作为相对分数
|
||||
}
|
||||
)
|
||||
return results
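SQLite's bm25() rank is smaller for more relevant rows, so the result dict keeps the raw value in bm25_score and exposes its negation as score (larger is better). A tiny illustration of the convention with made-up ranks:

# Hypothetical bm25() outputs: smaller (more negative) means more relevant.
raw = [("doc_a", -7.2), ("doc_b", -3.1), ("doc_c", -9.8)]

# Negate so that "larger is better", matching the score field built above.
ranked = sorted(((h, -s) for h, s in raw), key=lambda x: x[1], reverse=True)
# ranked -> [("doc_c", 9.8), ("doc_a", 7.2), ("doc_b", 3.1)]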
|
||||
|
||||
def search_relations(self, query: str, k: int = 20) -> List[Dict[str, Any]]:
|
||||
"""执行关系稀疏检索(FTS5 + BM25)。"""
|
||||
if not self.config.enabled or not self.config.enable_relation_sparse_fallback:
|
||||
return []
|
||||
if self.config.lazy_load and not self.loaded:
|
||||
if not self.ensure_loaded():
|
||||
return []
|
||||
if not self.loaded:
|
||||
return []
|
||||
|
||||
tokens = self._tokenize(query)
|
||||
match_query = self._build_match_query(tokens)
|
||||
if not match_query:
|
||||
return []
|
||||
|
||||
rows = self.metadata_store.fts_search_relations_bm25(
|
||||
match_query=match_query,
|
||||
limit=max(1, int(k)),
|
||||
max_doc_len=self.config.relation_max_doc_len,
|
||||
conn=self._conn,
|
||||
)
|
||||
out: List[Dict[str, Any]] = []
|
||||
for rank, row in enumerate(rows, start=1):
|
||||
bm25_score = float(row.get("bm25_score", 0.0))
|
||||
out.append(
|
||||
{
|
||||
"hash": row["hash"],
|
||||
"subject": row["subject"],
|
||||
"predicate": row["predicate"],
|
||||
"object": row["object"],
|
||||
"content": row["content"],
|
||||
"rank": rank,
|
||||
"bm25_score": bm25_score,
|
||||
"score": -bm25_score,
|
||||
}
|
||||
)
|
||||
return out
|
||||
|
||||
def upsert_paragraph(self, paragraph_hash: str) -> bool:
|
||||
if not self.loaded:
|
||||
return False
|
||||
return self.metadata_store.fts_upsert_paragraph(paragraph_hash, conn=self._conn)
|
||||
|
||||
def delete_paragraph(self, paragraph_hash: str) -> bool:
|
||||
if not self.loaded:
|
||||
return False
|
||||
return self.metadata_store.fts_delete_paragraph(paragraph_hash, conn=self._conn)
|
||||
|
||||
def unload(self) -> None:
|
||||
"""卸载 BM25 连接并尽量释放内存。"""
|
||||
if self._conn is not None:
|
||||
try:
|
||||
if self.config.shrink_memory_on_unload:
|
||||
self.metadata_store.shrink_memory(conn=self._conn)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
self._conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
self._conn = None
|
||||
self._loaded = False
|
||||
logger.info("SparseBM25Index unloaded")
|
||||
|
||||
def stats(self) -> Dict[str, Any]:
|
||||
doc_count = 0
|
||||
if self.loaded:
|
||||
doc_count = self.metadata_store.fts_doc_count(conn=self._conn)
|
||||
return {
|
||||
"enabled": self.config.enabled,
|
||||
"backend": self.config.backend,
|
||||
"mode": self.config.mode,
|
||||
"tokenizer_mode": self.config.tokenizer_mode,
|
||||
"enable_ngram_fallback_index": self.config.enable_ngram_fallback_index,
|
||||
"enable_like_fallback": self.config.enable_like_fallback,
|
||||
"enable_relation_sparse_fallback": self.config.enable_relation_sparse_fallback,
|
||||
"loaded": self.loaded,
|
||||
"has_jieba": HAS_JIEBA,
|
||||
"doc_count": doc_count,
|
||||
}
|
||||
450
plugins/A_memorix/core/retrieval/threshold.py
Normal file
@@ -0,0 +1,450 @@
|
||||
"""
|
||||
动态阈值过滤器
|
||||
|
||||
根据检索结果的分布特征自适应调整过滤阈值。
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from typing import List, Dict, Any, Optional, Tuple, Union
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from .dual_path import RetrievalResult
|
||||
|
||||
logger = get_logger("A_Memorix.DynamicThresholdFilter")
|
||||
|
||||
|
||||
class ThresholdMethod(Enum):
|
||||
"""阈值计算方法"""
|
||||
|
||||
PERCENTILE = "percentile" # 百分位数
|
||||
STD_DEV = "std_dev" # 标准差
|
||||
GAP_DETECTION = "gap_detection" # 跳变检测
|
||||
ADAPTIVE = "adaptive" # 自适应(综合多种方法)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ThresholdConfig:
|
||||
"""
|
||||
阈值配置
|
||||
|
||||
属性:
|
||||
method: 阈值计算方法
|
||||
min_threshold: 最小阈值(绝对值)
|
||||
max_threshold: 最大阈值(绝对值)
|
||||
percentile: 百分位数(用于percentile方法)
|
||||
std_multiplier: 标准差倍数(用于std_dev方法)
|
||||
min_results: 最少保留结果数
|
||||
enable_auto_adjust: 是否自动调整参数
|
||||
"""
|
||||
|
||||
method: ThresholdMethod = ThresholdMethod.ADAPTIVE
|
||||
min_threshold: float = 0.3
|
||||
max_threshold: float = 0.95
|
||||
percentile: float = 75.0 # 百分位数
|
||||
std_multiplier: float = 1.5 # 标准差倍数
|
||||
min_results: int = 3 # 最少保留结果数
|
||||
enable_auto_adjust: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
"""验证配置"""
|
||||
if not 0 <= self.min_threshold <= 1:
|
||||
raise ValueError(f"min_threshold必须在[0, 1]之间: {self.min_threshold}")
|
||||
|
||||
if not 0 <= self.max_threshold <= 1:
|
||||
raise ValueError(f"max_threshold必须在[0, 1]之间: {self.max_threshold}")
|
||||
|
||||
if self.min_threshold >= self.max_threshold:
|
||||
raise ValueError(f"min_threshold必须小于max_threshold")
|
||||
|
||||
if not 0 <= self.percentile <= 100:
|
||||
raise ValueError(f"percentile必须在[0, 100]之间: {self.percentile}")
|
||||
|
||||
if self.std_multiplier <= 0:
|
||||
raise ValueError(f"std_multiplier必须大于0: {self.std_multiplier}")
|
||||
|
||||
if self.min_results < 0:
|
||||
raise ValueError(f"min_results必须大于等于0: {self.min_results}")
|
||||
|
||||
|
||||
class DynamicThresholdFilter:
|
||||
"""
|
||||
动态阈值过滤器
|
||||
|
||||
功能:
|
||||
- 基于结果分布自适应计算阈值
|
||||
- 多种阈值计算方法
|
||||
- 自动参数调整
|
||||
- 统计信息收集
|
||||
|
||||
参数:
|
||||
config: 阈值配置
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Optional[ThresholdConfig] = None,
|
||||
):
|
||||
"""
|
||||
初始化动态阈值过滤器
|
||||
|
||||
Args:
|
||||
config: 阈值配置
|
||||
"""
|
||||
self.config = config or ThresholdConfig()
|
||||
|
||||
# 统计信息
|
||||
self._total_filtered = 0
|
||||
self._total_processed = 0
|
||||
self._threshold_history: List[float] = []
|
||||
|
||||
logger.info(
|
||||
f"DynamicThresholdFilter 初始化: "
|
||||
f"method={self.config.method.value}, "
|
||||
f"min_threshold={self.config.min_threshold}"
|
||||
)
|
||||
|
||||
def filter(
|
||||
self,
|
||||
results: List[RetrievalResult],
|
||||
return_threshold: bool = False,
|
||||
) -> Union[List[RetrievalResult], Tuple[List[RetrievalResult], float]]:
|
||||
"""
|
||||
过滤检索结果
|
||||
|
||||
Args:
|
||||
results: 检索结果列表
|
||||
return_threshold: 是否返回使用的阈值
|
||||
|
||||
Returns:
|
||||
过滤后的结果列表,或 (结果列表, 阈值) 元组
|
||||
"""
|
||||
if not results:
|
||||
logger.debug("结果列表为空,无需过滤")
|
||||
return ([], 0.0) if return_threshold else []
|
||||
|
||||
self._total_processed += len(results)
|
||||
|
||||
# 提取分数
|
||||
scores = np.array([r.score for r in results])
|
||||
|
||||
# 计算阈值
|
||||
threshold = self._compute_threshold(scores, results)
|
||||
|
||||
# 记录阈值
|
||||
self._threshold_history.append(threshold)
|
||||
|
||||
# 应用阈值过滤
|
||||
filtered_results = [
|
||||
r for r in results
|
||||
if r.score >= threshold
|
||||
]
|
||||
|
||||
# 确保至少保留min_results个结果
|
||||
if len(filtered_results) < self.config.min_results:
|
||||
# 按分数排序,取前min_results个
|
||||
sorted_results = sorted(results, key=lambda x: x.score, reverse=True)
|
||||
filtered_results = sorted_results[:self.config.min_results]
|
||||
threshold = filtered_results[-1].score if filtered_results else 0.0
|
||||
|
||||
self._total_filtered += len(results) - len(filtered_results)
|
||||
|
||||
logger.info(
|
||||
f"过滤完成: {len(results)} -> {len(filtered_results)} "
|
||||
f"(threshold={threshold:.3f})"
|
||||
)
|
||||
|
||||
if return_threshold:
|
||||
return filtered_results, threshold
|
||||
return filtered_results
|
||||
|
||||
def _compute_threshold(
|
||||
self,
|
||||
scores: np.ndarray,
|
||||
results: List[RetrievalResult],
|
||||
) -> float:
|
||||
"""
|
||||
计算阈值
|
||||
|
||||
Args:
|
||||
scores: 分数数组
|
||||
results: 检索结果列表
|
||||
|
||||
Returns:
|
||||
阈值
|
||||
"""
|
||||
if self.config.method == ThresholdMethod.PERCENTILE:
|
||||
threshold = self._percentile_threshold(scores)
|
||||
elif self.config.method == ThresholdMethod.STD_DEV:
|
||||
threshold = self._std_dev_threshold(scores)
|
||||
elif self.config.method == ThresholdMethod.GAP_DETECTION:
|
||||
threshold = self._gap_detection_threshold(scores)
|
||||
else: # ADAPTIVE
|
||||
# 自适应方法:综合多种方法
|
||||
thresholds = [
|
||||
self._percentile_threshold(scores),
|
||||
self._std_dev_threshold(scores),
|
||||
self._gap_detection_threshold(scores),
|
||||
]
|
||||
# 使用中位数作为最终阈值
|
||||
threshold = float(np.median(thresholds))
|
||||
|
||||
# 限制在[min_threshold, max_threshold]范围内
|
||||
threshold = np.clip(
|
||||
threshold,
|
||||
self.config.min_threshold,
|
||||
self.config.max_threshold,
|
||||
)
|
||||
|
||||
# 自动调整
|
||||
if self.config.enable_auto_adjust:
|
||||
threshold = self._auto_adjust_threshold(threshold, scores)
|
||||
|
||||
return float(threshold)
|
||||
|
||||
def _percentile_threshold(self, scores: np.ndarray) -> float:
|
||||
"""
|
||||
基于百分位数计算阈值
|
||||
|
||||
Args:
|
||||
scores: 分数数组
|
||||
|
||||
Returns:
|
||||
阈值
|
||||
"""
|
||||
percentile = self.config.percentile
|
||||
threshold = float(np.percentile(scores, percentile))
|
||||
|
||||
logger.debug(f"百分位数阈值: {threshold:.3f} (percentile={percentile})")
|
||||
return threshold
|
||||
|
||||
def _std_dev_threshold(self, scores: np.ndarray) -> float:
|
||||
"""
|
||||
基于标准差计算阈值
|
||||
|
||||
threshold = mean - std_multiplier * std
|
||||
|
||||
Args:
|
||||
scores: 分数数组
|
||||
|
||||
Returns:
|
||||
阈值
|
||||
"""
|
||||
mean = float(np.mean(scores))
|
||||
std = float(np.std(scores))
|
||||
multiplier = self.config.std_multiplier
|
||||
|
||||
threshold = mean - multiplier * std
|
||||
|
||||
logger.debug(f"标准差阈值: {threshold:.3f} (mean={mean:.3f}, std={std:.3f})")
|
||||
return threshold
|
||||
|
||||
def _gap_detection_threshold(self, scores: np.ndarray) -> float:
|
||||
"""
|
||||
基于跳变检测计算阈值
|
||||
|
||||
找到分数分布中最大的"跳变"位置,以此为阈值
|
||||
|
||||
Args:
|
||||
scores: 分数数组(降序排列)
|
||||
|
||||
Returns:
|
||||
阈值
|
||||
"""
|
||||
# 降序排列
|
||||
sorted_scores = np.sort(scores)[::-1]
|
||||
|
||||
if len(sorted_scores) < 2:
|
||||
return float(sorted_scores[0]) if len(sorted_scores) > 0 else 0.0
|
||||
|
||||
# 计算相邻分数的差值
|
||||
gaps = -np.diff(sorted_scores)  # 降序排列,取负得到相邻分数的正向跳变量
|
||||
|
||||
# 找到最大的跳变位置
|
||||
max_gap_idx = int(np.argmax(gaps))
|
||||
|
||||
# 阈值为跳变后的分数
|
||||
threshold = float(sorted_scores[max_gap_idx + 1])
|
||||
|
||||
logger.debug(
|
||||
f"跳变检测阈值: {threshold:.3f} "
|
||||
f"(gap={gaps[max_gap_idx]:.3f}, idx={max_gap_idx})"
|
||||
)
|
||||
return threshold
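With all three strategies defined, the adaptive mode simply takes their median and clips it to [min_threshold, max_threshold]. A self-contained numpy sketch on a toy score array (parameter values are illustrative):

import numpy as np

scores = np.array([0.92, 0.90, 0.88, 0.55, 0.52, 0.20])

t_pct = float(np.percentile(scores, 75.0))                # percentile method
t_std = float(np.mean(scores) - 1.5 * np.std(scores))     # mean - k * std
sorted_scores = np.sort(scores)[::-1]
drops = -np.diff(sorted_scores)                           # positive drop sizes
t_gap = float(sorted_scores[int(np.argmax(drops)) + 1])   # score just after the biggest drop -> 0.55

threshold = float(np.clip(np.median([t_pct, t_std, t_gap]), 0.3, 0.95))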
|
||||
|
||||
def _auto_adjust_threshold(
|
||||
self,
|
||||
threshold: float,
|
||||
scores: np.ndarray,
|
||||
) -> float:
|
||||
"""
|
||||
自动调整阈值
|
||||
|
||||
基于历史阈值和当前分数分布调整
|
||||
|
||||
Args:
|
||||
threshold: 当前阈值
|
||||
scores: 分数数组
|
||||
|
||||
Returns:
|
||||
调整后的阈值
|
||||
"""
|
||||
if not self._threshold_history:
|
||||
return threshold
|
||||
|
||||
# 计算历史阈值的移动平均
|
||||
recent_thresholds = self._threshold_history[-10:] # 最近10次
|
||||
avg_threshold = float(np.mean(recent_thresholds))
|
||||
|
||||
# 当前阈值与历史平均的差异
|
||||
diff = threshold - avg_threshold
|
||||
|
||||
# 如果差异过大(>0.2),向历史平均靠拢
|
||||
if abs(diff) > 0.2:
|
||||
adjusted_threshold = avg_threshold + diff * 0.5 # 向中间靠拢50%
|
||||
logger.debug(
|
||||
f"阈值调整: {threshold:.3f} -> {adjusted_threshold:.3f} "
|
||||
f"(历史平均={avg_threshold:.3f})"
|
||||
)
|
||||
return adjusted_threshold
|
||||
|
||||
return threshold
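In other words, once the new threshold deviates from the mean of the last ten thresholds by more than 0.2, it is pulled halfway back toward that mean. A compact sketch of the rule:

import numpy as np

history = [0.52, 0.55, 0.50, 0.53]
current = 0.85                               # unusually high for this history

avg = float(np.mean(history[-10:]))          # 0.525
diff = current - avg                         # 0.325 > 0.2, so pull halfway back
adjusted = avg + diff * 0.5 if abs(diff) > 0.2 else current
# adjusted -> 0.6875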
|
||||
|
||||
def filter_by_confidence(
|
||||
self,
|
||||
results: List[RetrievalResult],
|
||||
min_confidence: float = 0.5,
|
||||
) -> List[RetrievalResult]:
|
||||
"""
|
||||
基于置信度过滤结果
|
||||
|
||||
Args:
|
||||
results: 检索结果列表
|
||||
min_confidence: 最小置信度
|
||||
|
||||
Returns:
|
||||
过滤后的结果列表
|
||||
"""
|
||||
filtered = []
|
||||
for result in results:
|
||||
# 对于关系结果,使用confidence字段
|
||||
if result.result_type == "relation":
|
||||
confidence = result.metadata.get("confidence", 1.0)
|
||||
if confidence >= min_confidence:
|
||||
filtered.append(result)
|
||||
else:
|
||||
# 对于段落结果,直接使用分数
|
||||
if result.score >= min_confidence:
|
||||
filtered.append(result)
|
||||
|
||||
logger.info(
|
||||
f"置信度过滤: {len(results)} -> {len(filtered)} "
|
||||
f"(min_confidence={min_confidence})"
|
||||
)
|
||||
|
||||
return filtered
|
||||
|
||||
def filter_by_diversity(
|
||||
self,
|
||||
results: List[RetrievalResult],
|
||||
similarity_threshold: float = 0.9,
|
||||
top_k: int = 10,
|
||||
) -> List[RetrievalResult]:
|
||||
"""
|
||||
基于多样性过滤结果(去除重复)
|
||||
|
||||
Args:
|
||||
results: 检索结果列表
|
||||
similarity_threshold: 相似度阈值(高于此值视为重复)
|
||||
top_k: 最多保留结果数
|
||||
|
||||
Returns:
|
||||
过滤后的结果列表
|
||||
"""
|
||||
if not results:
|
||||
return []
|
||||
|
||||
# 按分数排序
|
||||
sorted_results = sorted(results, key=lambda x: x.score, reverse=True)
|
||||
|
||||
# 贪心选择:选择与已选结果相似度低的结果
|
||||
selected = []
|
||||
selected_hashes = []
|
||||
|
||||
for result in sorted_results:
|
||||
if len(selected) >= top_k:
|
||||
break
|
||||
|
||||
# 检查与已选结果的相似度
|
||||
is_duplicate = False
|
||||
for selected_hash in selected_hashes:
|
||||
# 简单判断:基于hash的前缀
|
||||
if result.hash_value[:8] == selected_hash[:8]:
|
||||
is_duplicate = True
|
||||
break
|
||||
|
||||
if not is_duplicate:
|
||||
selected.append(result)
|
||||
selected_hashes.append(result.hash_value)
|
||||
|
||||
logger.info(
|
||||
f"多样性过滤: {len(results)} -> {len(selected)} "
|
||||
f"(similarity_threshold={similarity_threshold})"
|
||||
)
|
||||
|
||||
return selected
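Deduplication here is a greedy pass over score-sorted results that treats two hits as duplicates when the first 8 hex characters of their hashes collide; similarity_threshold is reported in the log but does not enter the comparison. A standalone sketch over (hash, score) tuples with hypothetical data:

from typing import List, Tuple

def dedup_by_hash_prefix(hits: List[Tuple[str, float]], top_k: int = 10, prefix_len: int = 8) -> List[Tuple[str, float]]:
    # Greedy selection: keep only the best-scored hit per hash prefix.
    selected: List[Tuple[str, float]] = []
    seen = set()
    for hash_value, score in sorted(hits, key=lambda x: x[1], reverse=True):
        if len(selected) >= top_k:
            break
        prefix = hash_value[:prefix_len]
        if prefix in seen:
            continue
        seen.add(prefix)
        selected.append((hash_value, score))
    return selected

# dedup_by_hash_prefix([("aaaa1111x", 0.9), ("aaaa1111y", 0.7), ("bbbb2222z", 0.6)])
# -> [("aaaa1111x", 0.9), ("bbbb2222z", 0.6)]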
|
||||
|
||||
def get_statistics(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取统计信息
|
||||
|
||||
Returns:
|
||||
统计信息字典
|
||||
"""
|
||||
filter_rate = (
|
||||
self._total_filtered / self._total_processed
|
||||
if self._total_processed > 0
|
||||
else 0.0
|
||||
)
|
||||
|
||||
stats = {
|
||||
"config": {
|
||||
"method": self.config.method.value,
|
||||
"min_threshold": self.config.min_threshold,
|
||||
"max_threshold": self.config.max_threshold,
|
||||
"percentile": self.config.percentile,
|
||||
"std_multiplier": self.config.std_multiplier,
|
||||
"min_results": self.config.min_results,
|
||||
"enable_auto_adjust": self.config.enable_auto_adjust,
|
||||
},
|
||||
"statistics": {
|
||||
"total_processed": self._total_processed,
|
||||
"total_filtered": self._total_filtered,
|
||||
"filter_rate": filter_rate,
|
||||
"avg_threshold": float(np.mean(self._threshold_history))
|
||||
if self._threshold_history else 0.0,
|
||||
"threshold_count": len(self._threshold_history),
|
||||
},
|
||||
}
|
||||
|
||||
if self._threshold_history:
|
||||
stats["statistics"]["min_threshold_used"] = float(np.min(self._threshold_history))
|
||||
stats["statistics"]["max_threshold_used"] = float(np.max(self._threshold_history))
|
||||
|
||||
return stats
|
||||
|
||||
def reset_statistics(self) -> None:
|
||||
"""重置统计信息"""
|
||||
self._total_filtered = 0
|
||||
self._total_processed = 0
|
||||
self._threshold_history.clear()
|
||||
logger.info("统计信息已重置")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"DynamicThresholdFilter("
|
||||
f"method={self.config.method.value}, "
|
||||
f"min_threshold={self.config.min_threshold}, "
|
||||
f"filtered={self._total_filtered}/{self._total_processed})"
|
||||
)
|
||||
8
plugins/A_memorix/core/runtime/__init__.py
Normal file
@@ -0,0 +1,8 @@
"""SDK runtime exports for A_Memorix."""

from .sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel

__all__ = [
    "KernelSearchRequest",
    "SDKMemoryKernel",
]
579
plugins/A_memorix/core/runtime/sdk_memory_kernel.py
Normal file
@@ -0,0 +1,579 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterable, List, Optional, Sequence
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from ..embedding import create_embedding_api_adapter
|
||||
from ..retrieval import (
|
||||
DualPathRetriever,
|
||||
DualPathRetrieverConfig,
|
||||
RetrievalResult,
|
||||
SparseBM25Config,
|
||||
SparseBM25Index,
|
||||
TemporalQueryOptions,
|
||||
)
|
||||
from ..storage import GraphStore, MetadataStore, QuantizationType, SparseMatrixFormat, VectorStore
|
||||
from ..utils.aggregate_query_service import AggregateQueryService
|
||||
from ..utils.episode_retrieval_service import EpisodeRetrievalService
|
||||
from ..utils.hash import normalize_text
|
||||
from ..utils.relation_write_service import RelationWriteService
|
||||
|
||||
logger = get_logger("A_Memorix.SDKMemoryKernel")
|
||||
|
||||
|
||||
@dataclass
|
||||
class KernelSearchRequest:
|
||||
query: str = ""
|
||||
limit: int = 5
|
||||
mode: str = "hybrid"
|
||||
chat_id: str = ""
|
||||
person_id: str = ""
|
||||
time_start: Optional[float] = None
|
||||
time_end: Optional[float] = None
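A usage sketch of the request object for a time-bounded hybrid query; the field values are illustrative only, and the kernel must be initialized before the request is executed:

import time

request = KernelSearchRequest(
    query="上周讨论过的旅行计划",
    limit=5,
    mode="hybrid",                        # or "episode" / "aggregate"
    chat_id="group_123",
    time_start=time.time() - 7 * 86400,   # last seven days
    time_end=time.time(),
)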
|
||||
|
||||
|
||||
class SDKMemoryKernel:
|
||||
def __init__(self, *, plugin_root: Path, config: Optional[Dict[str, Any]] = None) -> None:
|
||||
self.plugin_root = Path(plugin_root).resolve()
|
||||
self.config = config or {}
|
||||
storage_cfg = self._cfg("storage", {}) or {}
|
||||
data_dir = str(storage_cfg.get("data_dir", "./data") or "./data")
|
||||
self.data_dir = (self.plugin_root / data_dir).resolve() if data_dir.startswith(".") else Path(data_dir)
|
||||
self.embedding_dimension = max(32, int(self._cfg("embedding.dimension", 256)))
|
||||
self.relation_vectors_enabled = bool(self._cfg("retrieval.relation_vectorization.enabled", False))
|
||||
|
||||
self.embedding_manager = None
|
||||
self.vector_store: Optional[VectorStore] = None
|
||||
self.graph_store: Optional[GraphStore] = None
|
||||
self.metadata_store: Optional[MetadataStore] = None
|
||||
self.relation_write_service: Optional[RelationWriteService] = None
|
||||
self.sparse_index = None
|
||||
self.retriever: Optional[DualPathRetriever] = None
|
||||
self.episode_retriever: Optional[EpisodeRetrievalService] = None
|
||||
self.aggregate_query_service: Optional[AggregateQueryService] = None
|
||||
self._initialized = False
|
||||
self._last_maintenance_at: Optional[float] = None
|
||||
|
||||
def _cfg(self, key: str, default: Any = None) -> Any:
|
||||
current: Any = self.config
|
||||
if key in {"storage", "embedding", "retrieval"} and isinstance(current, dict):
|
||||
return current.get(key, default)
|
||||
for part in key.split("."):
|
||||
if isinstance(current, dict) and part in current:
|
||||
current = current[part]
|
||||
else:
|
||||
return default
|
||||
return current
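_cfg resolves a dotted key against nested dicts, returning the default at the first missing segment (the top-level sections storage/embedding/retrieval are short-circuited). A standalone sketch of the dotted lookup:

from typing import Any, Dict

def cfg_get(config: Dict[str, Any], key: str, default: Any = None) -> Any:
    # Walk 'a.b.c' through nested dicts; stop at the first missing segment.
    current: Any = config
    for part in key.split("."):
        if isinstance(current, dict) and part in current:
            current = current[part]
        else:
            return default
    return current

conf = {"retrieval": {"top_k_final": 10, "alpha": 0.5}}
# cfg_get(conf, "retrieval.alpha")       -> 0.5
# cfg_get(conf, "retrieval.missing", 7)  -> 7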
|
||||
|
||||
async def initialize(self) -> None:
|
||||
if self._initialized:
|
||||
return
|
||||
self.data_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.embedding_manager = create_embedding_api_adapter(
|
||||
batch_size=int(self._cfg("embedding.batch_size", 32)),
|
||||
max_concurrent=int(self._cfg("embedding.max_concurrent", 5)),
|
||||
default_dimension=self.embedding_dimension,
|
||||
model_name=str(self._cfg("embedding.model_name", "hash-v1")),
|
||||
retry_config=self._cfg("embedding.retry", {}) or {},
|
||||
)
|
||||
self.embedding_dimension = int(await self.embedding_manager._detect_dimension())
|
||||
self.vector_store = VectorStore(
|
||||
dimension=self.embedding_dimension,
|
||||
quantization_type=QuantizationType.INT8,
|
||||
data_dir=self.data_dir / "vectors",
|
||||
)
|
||||
self.graph_store = GraphStore(matrix_format=SparseMatrixFormat.CSR, data_dir=self.data_dir / "graph")
|
||||
self.metadata_store = MetadataStore(data_dir=self.data_dir / "metadata")
|
||||
self.metadata_store.connect()
|
||||
if self.vector_store.has_data():
|
||||
self.vector_store.load()
|
||||
self.vector_store.warmup_index(force_train=True)
|
||||
if self.graph_store.has_data():
|
||||
self.graph_store.load()
|
||||
|
||||
sparse_cfg = self._cfg("retrieval.sparse", {}) or {}
|
||||
self.sparse_index = SparseBM25Index(metadata_store=self.metadata_store, config=SparseBM25Config(**sparse_cfg))
|
||||
if getattr(self.sparse_index.config, "enabled", False):
|
||||
self.sparse_index.ensure_loaded()
|
||||
|
||||
self.relation_write_service = RelationWriteService(
|
||||
metadata_store=self.metadata_store,
|
||||
graph_store=self.graph_store,
|
||||
vector_store=self.vector_store,
|
||||
embedding_manager=self.embedding_manager,
|
||||
)
|
||||
self.retriever = DualPathRetriever(
|
||||
vector_store=self.vector_store,
|
||||
graph_store=self.graph_store,
|
||||
metadata_store=self.metadata_store,
|
||||
embedding_manager=self.embedding_manager,
|
||||
sparse_index=self.sparse_index,
|
||||
config=DualPathRetrieverConfig(
|
||||
top_k_paragraphs=int(self._cfg("retrieval.top_k_paragraphs", 24)),
|
||||
top_k_relations=int(self._cfg("retrieval.top_k_relations", 12)),
|
||||
top_k_final=int(self._cfg("retrieval.top_k_final", 10)),
|
||||
alpha=float(self._cfg("retrieval.alpha", 0.5)),
|
||||
enable_ppr=bool(self._cfg("retrieval.enable_ppr", True)),
|
||||
ppr_alpha=float(self._cfg("retrieval.ppr_alpha", 0.85)),
|
||||
ppr_concurrency_limit=int(self._cfg("retrieval.ppr_concurrency_limit", 4)),
|
||||
enable_parallel=bool(self._cfg("retrieval.enable_parallel", True)),
|
||||
sparse=sparse_cfg,
|
||||
fusion=self._cfg("retrieval.fusion", {}) or {},
|
||||
graph_recall=self._cfg("retrieval.search.graph_recall", {}) or {},
|
||||
relation_intent=self._cfg("retrieval.search.relation_intent", {}) or {},
|
||||
),
|
||||
)
|
||||
self.episode_retriever = EpisodeRetrievalService(metadata_store=self.metadata_store, retriever=self.retriever)
|
||||
self.aggregate_query_service = AggregateQueryService(plugin_config=self.config)
|
||||
self._initialized = True
|
||||
|
||||
def close(self) -> None:
|
||||
if self.vector_store is not None:
|
||||
self.vector_store.save()
|
||||
if self.graph_store is not None:
|
||||
self.graph_store.save()
|
||||
if self.metadata_store is not None:
|
||||
self.metadata_store.close()
|
||||
self._initialized = False
|
||||
|
||||
async def ingest_summary(
|
||||
self,
|
||||
*,
|
||||
external_id: str,
|
||||
chat_id: str,
|
||||
text: str,
|
||||
participants: Optional[Sequence[str]] = None,
|
||||
time_start: Optional[float] = None,
|
||||
time_end: Optional[float] = None,
|
||||
tags: Optional[Sequence[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
summary_meta = dict(metadata or {})
|
||||
summary_meta.setdefault("kind", "chat_summary")
|
||||
return await self.ingest_text(
|
||||
external_id=external_id,
|
||||
source_type="chat_summary",
|
||||
text=text,
|
||||
chat_id=chat_id,
|
||||
participants=participants,
|
||||
time_start=time_start,
|
||||
time_end=time_end,
|
||||
tags=tags,
|
||||
metadata=summary_meta,
|
||||
)
|
||||
|
||||
async def ingest_text(
|
||||
self,
|
||||
*,
|
||||
external_id: str,
|
||||
source_type: str,
|
||||
text: str,
|
||||
chat_id: str = "",
|
||||
person_ids: Optional[Sequence[str]] = None,
|
||||
participants: Optional[Sequence[str]] = None,
|
||||
timestamp: Optional[float] = None,
|
||||
time_start: Optional[float] = None,
|
||||
time_end: Optional[float] = None,
|
||||
tags: Optional[Sequence[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
entities: Optional[Sequence[str]] = None,
|
||||
relations: Optional[Sequence[Dict[str, Any]]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
await self.initialize()
|
||||
assert self.metadata_store and self.vector_store and self.graph_store and self.embedding_manager
|
||||
assert self.relation_write_service
|
||||
content = normalize_text(text)
|
||||
if not content:
|
||||
return {"stored_ids": [], "skipped_ids": [external_id], "reason": "empty_text"}
|
||||
if ref := self.metadata_store.get_external_memory_ref(external_id):
|
||||
return {"stored_ids": [], "skipped_ids": [str(ref.get("paragraph_hash", ""))], "reason": "exists"}
|
||||
|
||||
person_tokens = self._tokens(person_ids)
|
||||
participant_tokens = self._tokens(participants)
|
||||
entity_tokens = self._merge_tokens(entities, person_tokens, participant_tokens)
|
||||
source = self._build_source(source_type, chat_id, person_tokens)
|
||||
paragraph_meta = dict(metadata or {})
|
||||
paragraph_meta.update(
|
||||
{
|
||||
"external_id": external_id,
|
||||
"source_type": str(source_type or "").strip(),
|
||||
"chat_id": str(chat_id or "").strip(),
|
||||
"person_ids": person_tokens,
|
||||
"participants": participant_tokens,
|
||||
"tags": self._tokens(tags),
|
||||
}
|
||||
)
|
||||
paragraph_hash = self.metadata_store.add_paragraph(
|
||||
content=content,
|
||||
source=source,
|
||||
metadata=paragraph_meta,
|
||||
knowledge_type="factual" if source_type == "person_fact" else "narrative" if source_type == "chat_summary" else "mixed",
|
||||
time_meta=self._time_meta(timestamp, time_start, time_end),
|
||||
)
|
||||
embedding = await self.embedding_manager.encode(content)
|
||||
self.vector_store.add(vectors=embedding.reshape(1, -1), ids=[paragraph_hash])
|
||||
for name in entity_tokens:
|
||||
self.metadata_store.add_entity(name=name, source_paragraph=paragraph_hash)
|
||||
|
||||
stored_relations: List[str] = []
|
||||
for row in [dict(item) for item in (relations or []) if isinstance(item, dict)]:
|
||||
s = str(row.get("subject", "") or "").strip()
|
||||
p = str(row.get("predicate", "") or "").strip()
|
||||
o = str(row.get("object", "") or "").strip()
|
||||
if not (s and p and o):
|
||||
continue
|
||||
result = await self.relation_write_service.upsert_relation_with_vector(
|
||||
subject=s,
|
||||
predicate=p,
|
||||
obj=o,
|
||||
confidence=float(row.get("confidence", 1.0) or 1.0),
|
||||
source_paragraph=paragraph_hash,
|
||||
metadata={"external_id": external_id, "source_type": source_type},
|
||||
write_vector=self.relation_vectors_enabled,
|
||||
)
|
||||
self.metadata_store.link_paragraph_relation(paragraph_hash, result.hash_value)
|
||||
stored_relations.append(result.hash_value)
|
||||
|
||||
self.metadata_store.upsert_external_memory_ref(
|
||||
external_id=external_id,
|
||||
paragraph_hash=paragraph_hash,
|
||||
source_type=source_type,
|
||||
metadata={"chat_id": chat_id, "person_ids": person_tokens},
|
||||
)
|
||||
self._persist()
|
||||
self.rebuild_episodes_for_sources([source])
|
||||
for person_id in person_tokens:
|
||||
await self.refresh_person_profile(person_id)
|
||||
return {"stored_ids": [paragraph_hash, *stored_relations], "skipped_ids": []}
|
||||
|
||||
async def search_memory(self, request: KernelSearchRequest) -> Dict[str, Any]:
|
||||
await self.initialize()
|
||||
assert self.retriever and self.episode_retriever and self.aggregate_query_service
|
||||
mode = str(request.mode or "hybrid").strip().lower() or "hybrid"
|
||||
clean_query = str(request.query or "").strip()
|
||||
limit = max(1, int(request.limit or 5))
|
||||
temporal = self._temporal(request)
|
||||
if mode == "episode":
|
||||
rows = await self.episode_retriever.query(
|
||||
query=clean_query,
|
||||
top_k=limit,
|
||||
time_from=request.time_start,
|
||||
time_to=request.time_end,
|
||||
source=self._chat_source(request.chat_id),
|
||||
)
|
||||
hits = [self._episode_hit(row) for row in rows]
|
||||
return {"summary": self._summary(hits), "hits": hits}
|
||||
if mode == "aggregate":
|
||||
payload = await self.aggregate_query_service.execute(
|
||||
query=clean_query,
|
||||
top_k=limit,
|
||||
mix=True,
|
||||
mix_top_k=limit,
|
||||
time_from=str(request.time_start) if request.time_start is not None else None,
|
||||
time_to=str(request.time_end) if request.time_end is not None else None,
|
||||
search_runner=lambda: self._aggregate_search(clean_query, limit, temporal),
|
||||
time_runner=lambda: self._aggregate_time(clean_query, limit, temporal),
|
||||
episode_runner=lambda: self._aggregate_episode(clean_query, limit, request),
|
||||
)
|
||||
hits = [dict(item) for item in payload.get("mixed_results", []) if isinstance(item, dict)]
|
||||
for item in hits:
|
||||
item.setdefault("metadata", {})
|
||||
return {"summary": self._summary(hits), "hits": hits}
|
||||
results = await self.retriever.retrieve(query=clean_query, top_k=limit, temporal=temporal)
|
||||
hits = [self._retrieval_hit(item) for item in results]
|
||||
return {"summary": self._summary(self._filter_hits(hits, request.person_id)), "hits": self._filter_hits(hits, request.person_id)}
|
||||
|
||||
async def get_person_profile(self, *, person_id: str, chat_id: str = "", limit: int = 10) -> Dict[str, Any]:
|
||||
_ = chat_id
|
||||
await self.initialize()
|
||||
assert self.metadata_store
|
||||
snapshot = self.metadata_store.get_latest_person_profile_snapshot(person_id) or await self.refresh_person_profile(person_id, limit=limit)
|
||||
evidence = []
|
||||
for hash_value in snapshot.get("evidence_ids", [])[: max(1, int(limit))]:
|
||||
paragraph = self.metadata_store.get_paragraph(hash_value)
|
||||
if paragraph is not None:
|
||||
evidence.append({"hash": hash_value, "content": str(paragraph.get("content", "") or "")[:220], "metadata": paragraph.get("metadata", {}) or {}})
|
||||
text = str(snapshot.get("profile_text", "") or "").strip()
|
||||
traits = [line.strip("- ").strip() for line in text.splitlines() if line.strip()][:8]
|
||||
return {"summary": text, "traits": traits, "evidence": evidence}
|
||||
|
||||
async def refresh_person_profile(self, person_id: str, limit: int = 10) -> Dict[str, Any]:
|
||||
await self.initialize()
|
||||
assert self.metadata_store
|
||||
rows = self.metadata_store.query(
|
||||
"""
|
||||
SELECT DISTINCT p.*
|
||||
FROM paragraphs p
|
||||
JOIN paragraph_entities pe ON pe.paragraph_hash = p.hash
|
||||
JOIN entities e ON e.hash = pe.entity_hash
|
||||
WHERE e.name = ?
|
||||
AND (p.is_deleted IS NULL OR p.is_deleted = 0)
|
||||
ORDER BY COALESCE(p.event_time_end, p.event_time_start, p.event_time, p.updated_at, p.created_at) DESC
|
||||
LIMIT ?
|
||||
""",
|
||||
(person_id, max(1, int(limit)) * 3),
|
||||
)
|
||||
evidence_ids = [str(row.get("hash", "") or "") for row in rows if str(row.get("hash", "")).strip()]
|
||||
vector_evidence = [{"hash": str(row.get("hash", "") or ""), "type": "paragraph", "score": 0.0, "content": str(row.get("content", "") or "")[:220], "metadata": row.get("metadata", {}) or {}} for row in rows[: max(1, int(limit))]]
|
||||
relation_edges = [{"hash": str(row.get("hash", "") or ""), "subject": str(row.get("subject", "") or ""), "predicate": str(row.get("predicate", "") or ""), "object": str(row.get("object", "") or ""), "confidence": float(row.get("confidence", 1.0) or 1.0)} for row in self.metadata_store.get_relations(subject=person_id)[:limit]]
|
||||
if relation_edges:
|
||||
profile_text = "\n".join(f"{item['subject']} {item['predicate']} {item['object']}" for item in relation_edges[:6])
|
||||
elif vector_evidence:
|
||||
profile_text = "\n".join(f"- {item['content']}" for item in vector_evidence[:6])
|
||||
else:
|
||||
profile_text = "暂无稳定画像证据。"
|
||||
return self.metadata_store.upsert_person_profile_snapshot(
|
||||
person_id=person_id,
|
||||
profile_text=profile_text,
|
||||
aliases=[person_id],
|
||||
relation_edges=relation_edges,
|
||||
vector_evidence=vector_evidence,
|
||||
evidence_ids=evidence_ids[: max(1, int(limit))],
|
||||
expires_at=time.time() + 6 * 3600,
|
||||
source_note="sdk_memory_kernel",
|
||||
)
|
||||
|
||||
async def maintain_memory(self, *, action: str, target: str, hours: Optional[float] = None, reason: str = "") -> Dict[str, Any]:
|
||||
_ = reason
|
||||
await self.initialize()
|
||||
assert self.metadata_store
|
||||
hashes = self._resolve_relation_hashes(target)
|
||||
if not hashes:
|
||||
return {"success": False, "detail": "未命中可维护关系"}
|
||||
act = str(action or "").strip().lower()
|
||||
if act == "reinforce":
|
||||
self.metadata_store.reinforce_relations(hashes)
|
||||
elif act == "protect":
|
||||
ttl_seconds = max(0.0, float(hours or 0.0)) * 3600.0
|
||||
self.metadata_store.protect_relations(hashes, ttl_seconds=ttl_seconds, is_pinned=ttl_seconds <= 0)
|
||||
elif act == "restore":
|
||||
restored = sum(1 for hash_value in hashes if self.metadata_store.restore_relation(hash_value))
|
||||
if restored <= 0:
|
||||
return {"success": False, "detail": "未恢复任何关系"}
|
||||
else:
|
||||
return {"success": False, "detail": f"不支持的维护动作: {act}"}
|
||||
self._last_maintenance_at = time.time()
|
||||
self._persist()
|
||||
return {"success": True, "detail": f"{act} {len(hashes)} 条关系"}
|
||||
|
||||
def rebuild_episodes_for_sources(self, sources: Iterable[str]) -> int:
|
||||
assert self.metadata_store
|
||||
rebuilt = 0
|
||||
for source in self._tokens(sources):
|
||||
rows = self.metadata_store.query(
|
||||
"""
|
||||
SELECT * FROM paragraphs
|
||||
WHERE source = ?
|
||||
AND (is_deleted IS NULL OR is_deleted = 0)
|
||||
ORDER BY COALESCE(event_time_start, event_time, created_at) ASC, hash ASC
|
||||
""",
|
||||
(source,),
|
||||
)
|
||||
if not rows:
|
||||
continue
|
||||
paragraph_hashes = [str(row.get("hash", "") or "") for row in rows if str(row.get("hash", "")).strip()]
|
||||
payload = self.metadata_store.upsert_episode(
|
||||
{
|
||||
"source": source,
|
||||
"title": str((rows[0].get("metadata", {}) or {}).get("theme", "") or f"{source} 情景记忆")[:80],
|
||||
"summary": ";".join(str(row.get("content", "") or "").strip().replace("\n", " ")[:120] for row in rows[:3] if str(row.get("content", "") or "").strip())[:500] or "自动构建的情景记忆。",
|
||||
"participants": self._episode_participants(rows),
|
||||
"keywords": self._episode_keywords(rows),
|
||||
"evidence_ids": paragraph_hashes,
|
||||
"paragraph_count": len(paragraph_hashes),
|
||||
"event_time_start": self._time_bound(rows, "event_time_start", "event_time", reverse=False),
|
||||
"event_time_end": self._time_bound(rows, "event_time_end", "event_time", reverse=True),
|
||||
"time_granularity": "day",
|
||||
"time_confidence": 0.7,
|
||||
"llm_confidence": 0.0,
|
||||
"segmentation_model": "rule_based_sdk",
|
||||
"segmentation_version": "1",
|
||||
}
|
||||
)
|
||||
self.metadata_store.bind_episode_paragraphs(payload["episode_id"], paragraph_hashes)
|
||||
rebuilt += 1
|
||||
return rebuilt
|
||||
|
||||
def memory_stats(self) -> Dict[str, Any]:
|
||||
assert self.metadata_store
|
||||
stats = self.metadata_store.get_statistics()
|
||||
episodes = self.metadata_store.query("SELECT COUNT(*) AS c FROM episodes")[0]["c"]
|
||||
profiles = self.metadata_store.query("SELECT COUNT(*) AS c FROM person_profile_snapshots")[0]["c"]
|
||||
return {"paragraphs": int(stats.get("paragraph_count", 0) or 0), "relations": int(stats.get("relation_count", 0) or 0), "episodes": int(episodes or 0), "profiles": int(profiles or 0), "last_maintenance_at": self._last_maintenance_at}
|
||||
|
||||
async def _aggregate_search(self, query: str, limit: int, temporal: Optional[TemporalQueryOptions]) -> Dict[str, Any]:
|
||||
assert self.retriever
|
||||
hits = [self._retrieval_hit(item) for item in await self.retriever.retrieve(query=query, top_k=limit, temporal=temporal)]
|
||||
return {"success": True, "results": hits, "count": len(hits), "query_type": "search"}
|
||||
|
||||
async def _aggregate_time(self, query: str, limit: int, temporal: Optional[TemporalQueryOptions]) -> Dict[str, Any]:
|
||||
if temporal is None:
|
||||
return {"success": False, "error": "missing temporal window", "results": []}
|
||||
assert self.retriever
|
||||
hits = [self._retrieval_hit(item) for item in await self.retriever.retrieve(query=query, top_k=limit, temporal=temporal)]
|
||||
return {"success": True, "results": hits, "count": len(hits), "query_type": "time"}
|
||||
|
||||
async def _aggregate_episode(self, query: str, limit: int, request: KernelSearchRequest) -> Dict[str, Any]:
|
||||
assert self.episode_retriever
|
||||
rows = await self.episode_retriever.query(query=query, top_k=limit, time_from=request.time_start, time_to=request.time_end, source=self._chat_source(request.chat_id))
|
||||
hits = [self._episode_hit(row) for row in rows]
|
||||
return {"success": True, "results": hits, "count": len(hits), "query_type": "episode"}
|
||||
|
||||
def _persist(self) -> None:
|
||||
if self.vector_store is not None:
|
||||
self.vector_store.save()
|
||||
if self.graph_store is not None:
|
||||
self.graph_store.save()
|
||||
if self.sparse_index is not None and getattr(self.sparse_index.config, "enabled", False):
|
||||
self.sparse_index.ensure_loaded()
|
||||
|
||||
@staticmethod
|
||||
def _tokens(values: Optional[Iterable[Any]]) -> List[str]:
|
||||
result: List[str] = []
|
||||
seen = set()
|
||||
for item in values or []:
|
||||
token = str(item or "").strip()
|
||||
if not token or token in seen:
|
||||
continue
|
||||
seen.add(token)
|
||||
result.append(token)
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def _merge_tokens(cls, *groups: Optional[Iterable[Any]]) -> List[str]:
|
||||
merged: List[str] = []
|
||||
seen = set()
|
||||
for group in groups:
|
||||
for item in cls._tokens(group):
|
||||
if item in seen:
|
||||
continue
|
||||
seen.add(item)
|
||||
merged.append(item)
|
||||
return merged
|
||||
|
||||
@staticmethod
|
||||
def _build_source(source_type: str, chat_id: str, person_ids: Sequence[str]) -> str:
|
||||
clean_type = str(source_type or "").strip() or "memory"
|
||||
if clean_type == "chat_summary" and chat_id:
|
||||
return f"chat_summary:{chat_id}"
|
||||
if clean_type == "person_fact" and person_ids:
|
||||
return f"person_fact:{person_ids[0]}"
|
||||
return f"{clean_type}:{chat_id}" if chat_id else clean_type
|
||||
|
||||
@staticmethod
|
||||
def _chat_source(chat_id: str) -> Optional[str]:
|
||||
clean = str(chat_id or "").strip()
|
||||
return f"chat_summary:{clean}" if clean else None
|
||||
|
||||
@staticmethod
|
||||
def _time_meta(timestamp: Optional[float], time_start: Optional[float], time_end: Optional[float]) -> Dict[str, Any]:
|
||||
payload: Dict[str, Any] = {}
|
||||
if timestamp is not None:
|
||||
payload["event_time"] = float(timestamp)
|
||||
if time_start is not None:
|
||||
payload["event_time_start"] = float(time_start)
|
||||
if time_end is not None:
|
||||
payload["event_time_end"] = float(time_end)
|
||||
if payload:
|
||||
payload["time_granularity"] = "minute"
|
||||
payload["time_confidence"] = 0.95
|
||||
return payload
|
||||
|
||||
def _temporal(self, request: KernelSearchRequest) -> Optional[TemporalQueryOptions]:
|
||||
if request.time_start is None and request.time_end is None and not request.chat_id:
|
||||
return None
|
||||
return TemporalQueryOptions(time_from=request.time_start, time_to=request.time_end, source=self._chat_source(request.chat_id))
|
||||
|
||||
@staticmethod
|
||||
def _retrieval_hit(item: RetrievalResult) -> Dict[str, Any]:
|
||||
payload = item.to_dict()
|
||||
return {"hash": payload.get("hash", ""), "content": payload.get("content", ""), "score": payload.get("score", 0.0), "type": payload.get("type", ""), "source": payload.get("source", ""), "metadata": payload.get("metadata", {}) or {}}
|
||||
|
||||
@staticmethod
|
||||
def _episode_hit(row: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return {"type": "episode", "episode_id": str(row.get("episode_id", "") or ""), "title": str(row.get("title", "") or ""), "content": str(row.get("summary", "") or ""), "score": float(row.get("lexical_score", 0.0) or 0.0), "source": "episode", "metadata": {"participants": row.get("participants", []) or [], "keywords": row.get("keywords", []) or [], "source": row.get("source"), "event_time_start": row.get("event_time_start"), "event_time_end": row.get("event_time_end")}}
|
||||
|
||||
@staticmethod
|
||||
def _summary(hits: Sequence[Dict[str, Any]]) -> str:
|
||||
if not hits:
|
||||
return ""
|
||||
lines = []
|
||||
for index, item in enumerate(hits[:5], start=1):
|
||||
content = str(item.get("content", "") or "").strip().replace("\n", " ")
|
||||
lines.append(f"{index}. {(content[:120] + '...') if len(content) > 120 else content}")
|
||||
return "\n".join(lines)
|
||||
|
||||
@staticmethod
|
||||
def _filter_hits(hits: List[Dict[str, Any]], person_id: str) -> List[Dict[str, Any]]:
|
||||
if not person_id:
|
||||
return hits
|
||||
filtered = []
|
||||
for item in hits:
|
||||
metadata = item.get("metadata", {}) or {}
|
||||
if person_id in (metadata.get("person_ids", []) or []):
|
||||
filtered.append(item)
|
||||
continue
|
||||
if person_id and person_id in str(item.get("content", "") or ""):
|
||||
filtered.append(item)
|
||||
return filtered or hits
|
||||
|
||||
@staticmethod
|
||||
def _episode_participants(rows: Sequence[Dict[str, Any]]) -> List[str]:
|
||||
seen = set()
|
||||
result: List[str] = []
|
||||
for row in rows:
|
||||
meta = row.get("metadata", {}) or {}
|
||||
for key in ("participants", "person_ids"):
|
||||
for item in meta.get(key, []) or []:
|
||||
token = str(item or "").strip()
|
||||
if not token or token in seen:
|
||||
continue
|
||||
seen.add(token)
|
||||
result.append(token)
|
||||
return result[:16]
|
||||
|
||||
@staticmethod
|
||||
def _episode_keywords(rows: Sequence[Dict[str, Any]]) -> List[str]:
|
||||
seen = set()
|
||||
result: List[str] = []
|
||||
for row in rows:
|
||||
meta = row.get("metadata", {}) or {}
|
||||
for item in meta.get("tags", []) or []:
|
||||
token = str(item or "").strip()
|
||||
if not token or token in seen:
|
||||
continue
|
||||
seen.add(token)
|
||||
result.append(token)
|
||||
return result[:12]
|
||||
|
||||
@staticmethod
|
||||
def _time_bound(rows: Sequence[Dict[str, Any]], primary: str, fallback: str, reverse: bool) -> Optional[float]:
|
||||
values: List[float] = []
|
||||
for row in rows:
|
||||
for key in (primary, fallback):
|
||||
value = row.get(key)
|
||||
try:
|
||||
if value is not None:
|
||||
values.append(float(value))
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
if not values:
|
||||
return None
|
||||
return max(values) if reverse else min(values)
|
||||
|
||||
def _resolve_relation_hashes(self, target: str) -> List[str]:
|
||||
assert self.metadata_store
|
||||
token = str(target or "").strip()
|
||||
if not token:
|
||||
return []
|
||||
if len(token) == 64 and all(ch in "0123456789abcdef" for ch in token.lower()):
|
||||
return [token]
|
||||
hashes = self.metadata_store.search_relation_hashes_by_text(token, limit=10)
|
||||
if hashes:
|
||||
return hashes
|
||||
return [str(row.get("hash", "") or "") for row in self.metadata_store.get_relations(subject=token)[:10] if str(row.get("hash", "")).strip()]
|
||||
53
plugins/A_memorix/core/storage/__init__.py
Normal file
@@ -0,0 +1,53 @@
"""存储层"""

from .vector_store import VectorStore, QuantizationType
from .graph_store import GraphStore, SparseMatrixFormat
from .metadata_store import MetadataStore
from .knowledge_types import (
    ImportStrategy,
    KnowledgeType,
    allowed_import_strategy_values,
    allowed_knowledge_type_values,
    get_knowledge_type_from_string,
    get_import_strategy_from_string,
    parse_import_strategy,
    resolve_stored_knowledge_type,
    should_extract_relations,
    get_default_chunk_size,
    get_type_display_name,
    validate_stored_knowledge_type,
)
from .type_detection import (
    detect_knowledge_type,
    get_type_from_user_input,
    looks_like_factual_text,
    looks_like_quote_text,
    looks_like_structured_text,
    select_import_strategy,
)

__all__ = [
    "VectorStore",
    "GraphStore",
    "MetadataStore",
    "QuantizationType",
    "SparseMatrixFormat",
    "ImportStrategy",
    "KnowledgeType",
    "allowed_import_strategy_values",
    "allowed_knowledge_type_values",
    "get_knowledge_type_from_string",
    "get_import_strategy_from_string",
    "parse_import_strategy",
    "resolve_stored_knowledge_type",
    "should_extract_relations",
    "get_default_chunk_size",
    "get_type_display_name",
    "validate_stored_knowledge_type",
    "detect_knowledge_type",
    "get_type_from_user_input",
    "looks_like_factual_text",
    "looks_like_quote_text",
    "looks_like_structured_text",
    "select_import_strategy",
]
1434
plugins/A_memorix/core/storage/graph_store.py
Normal file
File diff suppressed because it is too large
183
plugins/A_memorix/core/storage/knowledge_types.py
Normal file
@@ -0,0 +1,183 @@
|
||||
"""Knowledge type and import strategy helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
class KnowledgeType(str, Enum):
|
||||
"""持久化到 paragraphs.knowledge_type 的合法类型。"""
|
||||
|
||||
STRUCTURED = "structured"
|
||||
NARRATIVE = "narrative"
|
||||
FACTUAL = "factual"
|
||||
QUOTE = "quote"
|
||||
MIXED = "mixed"
|
||||
|
||||
|
||||
class ImportStrategy(str, Enum):
|
||||
"""文本导入阶段的策略选择。"""
|
||||
|
||||
AUTO = "auto"
|
||||
NARRATIVE = "narrative"
|
||||
FACTUAL = "factual"
|
||||
QUOTE = "quote"
|
||||
|
||||
|
||||
def allowed_knowledge_type_values() -> tuple[str, ...]:
|
||||
return tuple(item.value for item in KnowledgeType)
|
||||
|
||||
|
||||
def allowed_import_strategy_values() -> tuple[str, ...]:
|
||||
return tuple(item.value for item in ImportStrategy)
|
||||
|
||||
|
||||
def get_knowledge_type_from_string(type_str: Any) -> Optional[KnowledgeType]:
|
||||
"""从字符串解析合法的落库知识类型。"""
|
||||
|
||||
if not isinstance(type_str, str):
|
||||
return None
|
||||
normalized = type_str.lower().strip()
|
||||
try:
|
||||
return KnowledgeType(normalized)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
def get_import_strategy_from_string(value: Any) -> Optional[ImportStrategy]:
|
||||
"""从字符串解析文本导入策略。"""
|
||||
|
||||
if not isinstance(value, str):
|
||||
return None
|
||||
normalized = value.lower().strip()
|
||||
try:
|
||||
return ImportStrategy(normalized)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
def parse_import_strategy(value: Any, default: ImportStrategy = ImportStrategy.AUTO) -> ImportStrategy:
|
||||
"""解析 import strategy;非法值直接报错。"""
|
||||
|
||||
if value is None:
|
||||
return default
|
||||
if isinstance(value, ImportStrategy):
|
||||
return value
|
||||
|
||||
normalized = str(value or "").strip().lower()
|
||||
if not normalized:
|
||||
return default
|
||||
|
||||
strategy = get_import_strategy_from_string(normalized)
|
||||
if strategy is None:
|
||||
allowed = "/".join(allowed_import_strategy_values())
|
||||
raise ValueError(f"strategy_override 必须为 {allowed}")
|
||||
return strategy
|
||||
|
||||
|
||||
def validate_stored_knowledge_type(value: Any) -> KnowledgeType:
|
||||
"""校验写库 knowledge_type,仅允许合法落库类型。"""
|
||||
|
||||
if isinstance(value, KnowledgeType):
|
||||
return value
|
||||
|
||||
resolved = get_knowledge_type_from_string(value)
|
||||
if resolved is None:
|
||||
allowed = "/".join(allowed_knowledge_type_values())
|
||||
raise ValueError(f"knowledge_type 必须为 {allowed}")
|
||||
return resolved
|
||||
|
||||
|
||||
def resolve_stored_knowledge_type(
|
||||
value: Any,
|
||||
*,
|
||||
content: str = "",
|
||||
allow_legacy: bool = False,
|
||||
unknown_fallback: Optional[KnowledgeType] = None,
|
||||
) -> KnowledgeType:
|
||||
"""
|
||||
将策略/字符串/旧值解析为合法落库类型。
|
||||
|
||||
`allow_legacy=True` 仅供迁移使用。
|
||||
"""
|
||||
|
||||
if isinstance(value, KnowledgeType):
|
||||
return value
|
||||
|
||||
if isinstance(value, ImportStrategy):
|
||||
if value == ImportStrategy.AUTO:
|
||||
if not str(content or "").strip():
|
||||
raise ValueError("knowledge_type=auto 需要 content 才能推断")
|
||||
from .type_detection import detect_knowledge_type
|
||||
|
||||
return detect_knowledge_type(content)
|
||||
return KnowledgeType(value.value)
|
||||
|
||||
raw = str(value or "").strip()
|
||||
if not raw:
|
||||
if str(content or "").strip():
|
||||
from .type_detection import detect_knowledge_type
|
||||
|
||||
return detect_knowledge_type(content)
|
||||
raise ValueError("knowledge_type 不能为空")
|
||||
|
||||
direct = get_knowledge_type_from_string(raw)
|
||||
if direct is not None:
|
||||
return direct
|
||||
|
||||
strategy = get_import_strategy_from_string(raw)
|
||||
if strategy is not None:
|
||||
return resolve_stored_knowledge_type(strategy, content=content)
|
||||
|
||||
if allow_legacy:
|
||||
normalized = raw.lower()
|
||||
if normalized == "imported":
|
||||
return KnowledgeType.FACTUAL
|
||||
if str(content or "").strip():
|
||||
from .type_detection import detect_knowledge_type
|
||||
|
||||
detected = detect_knowledge_type(content)
|
||||
if detected is not None:
|
||||
return detected
|
||||
if unknown_fallback is not None:
|
||||
return unknown_fallback
|
||||
|
||||
allowed = "/".join(allowed_knowledge_type_values())
|
||||
raise ValueError(f"非法 knowledge_type: {raw}(仅允许 {allowed})")
|
||||
|
||||
|
||||
def should_extract_relations(knowledge_type: KnowledgeType) -> bool:
|
||||
"""判断是否应该做关系抽取。"""
|
||||
|
||||
return knowledge_type in [
|
||||
KnowledgeType.STRUCTURED,
|
||||
KnowledgeType.FACTUAL,
|
||||
KnowledgeType.MIXED,
|
||||
]
|
||||
|
||||
|
||||
def get_default_chunk_size(knowledge_type: KnowledgeType) -> int:
|
||||
"""获取默认分块大小。"""
|
||||
|
||||
chunk_sizes = {
|
||||
KnowledgeType.STRUCTURED: 300,
|
||||
KnowledgeType.NARRATIVE: 800,
|
||||
KnowledgeType.FACTUAL: 500,
|
||||
KnowledgeType.QUOTE: 400,
|
||||
KnowledgeType.MIXED: 500,
|
||||
}
|
||||
return chunk_sizes.get(knowledge_type, 500)
|
||||
|
||||
|
||||
def get_type_display_name(knowledge_type: KnowledgeType) -> str:
|
||||
"""获取知识类型中文名称。"""
|
||||
|
||||
display_names = {
|
||||
KnowledgeType.STRUCTURED: "结构化知识",
|
||||
KnowledgeType.NARRATIVE: "叙事性文本",
|
||||
KnowledgeType.FACTUAL: "事实陈述",
|
||||
KnowledgeType.QUOTE: "引用文本",
|
||||
KnowledgeType.MIXED: "混合类型",
|
||||
}
|
||||
return display_names.get(knowledge_type, "未知类型")
|
||||
5225
plugins/A_memorix/core/storage/metadata_store.py
Normal file
File diff suppressed because it is too large
137
plugins/A_memorix/core/storage/type_detection.py
Normal file
@@ -0,0 +1,137 @@
|
||||
"""Heuristic detection for import strategies and stored knowledge types."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
from .knowledge_types import (
|
||||
ImportStrategy,
|
||||
KnowledgeType,
|
||||
parse_import_strategy,
|
||||
resolve_stored_knowledge_type,
|
||||
)
|
||||
|
||||
|
||||
_NARRATIVE_MARKERS = [
|
||||
r"然后",
|
||||
r"接着",
|
||||
r"于是",
|
||||
r"后来",
|
||||
r"最后",
|
||||
r"突然",
|
||||
r"一天",
|
||||
r"曾经",
|
||||
r"有一次",
|
||||
r"从前",
|
||||
r"说道",
|
||||
r"问道",
|
||||
r"想着",
|
||||
r"觉得",
|
||||
]
|
||||
_FACTUAL_MARKERS = [
|
||||
r"是",
|
||||
r"有",
|
||||
r"在",
|
||||
r"为",
|
||||
r"属于",
|
||||
r"位于",
|
||||
r"包含",
|
||||
r"拥有",
|
||||
r"成立于",
|
||||
r"出生于",
|
||||
]
|
||||
|
||||
|
||||
def _non_empty_lines(content: str) -> list[str]:
|
||||
return [line for line in str(content or "").splitlines() if line.strip()]
|
||||
|
||||
|
||||
def looks_like_structured_text(content: str) -> bool:
|
||||
text = str(content or "").strip()
|
||||
if "|" not in text or text.count("|") < 2:
|
||||
return False
|
||||
parts = text.split("|")
|
||||
return len(parts) == 3 and all(part.strip() for part in parts)
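A text is considered structured only when it is exactly three non-empty '|'-separated fields, i.e. a subject|predicate|object triple. A couple of illustrative checks on sample strings:

def is_spo_triple(text: str) -> bool:
    text = text.strip()
    if text.count("|") < 2:
        return False
    parts = text.split("|")
    return len(parts) == 3 and all(part.strip() for part in parts)

# is_spo_triple("小明 | 喜欢 | 旅行")  -> True
# is_spo_triple("a | b | c | d")      -> False (four fields)
# is_spo_triple("只是普通的一句话")     -> False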
|
||||
|
||||
|
||||
def looks_like_quote_text(content: str) -> bool:
|
||||
lines = _non_empty_lines(content)
|
||||
if len(lines) < 5:
|
||||
return False
|
||||
avg_len = sum(len(line) for line in lines) / len(lines)
|
||||
return avg_len < 20
|
||||
|
||||
|
||||
def looks_like_narrative_text(content: str) -> bool:
|
||||
text = str(content or "").strip()
|
||||
if not text:
|
||||
return False
|
||||
|
||||
narrative_score = sum(1 for marker in _NARRATIVE_MARKERS if re.search(marker, text))
|
||||
has_dialogue = bool(re.search(r'["「『].*?["」』]', text))
|
||||
has_chapter = any(token in text[:500] for token in ("Chapter", "CHAPTER", "###"))
|
||||
return has_chapter or has_dialogue or narrative_score >= 2
|
||||
|
||||
|
||||
def looks_like_factual_text(content: str) -> bool:
|
||||
text = str(content or "").strip()
|
||||
if not text:
|
||||
return False
|
||||
if looks_like_structured_text(text) or looks_like_quote_text(text):
|
||||
return False
|
||||
|
||||
factual_score = sum(1 for marker in _FACTUAL_MARKERS if re.search(r"\s*" + marker + r"\s*", text))
|
||||
if factual_score <= 0:
|
||||
return False
|
||||
|
||||
if len(text) <= 240:
|
||||
return True
|
||||
return factual_score >= 2 and not looks_like_narrative_text(text)
|
||||
|
||||
|
||||
def select_import_strategy(
|
||||
content: str,
|
||||
*,
|
||||
override: Optional[str | ImportStrategy] = None,
|
||||
chat_log: bool = False,
|
||||
) -> ImportStrategy:
|
||||
"""文本导入策略选择:override > quote > factual > narrative。"""
|
||||
|
||||
if chat_log:
|
||||
return ImportStrategy.NARRATIVE
|
||||
|
||||
strategy = parse_import_strategy(override, default=ImportStrategy.AUTO)
|
||||
if strategy != ImportStrategy.AUTO:
|
||||
return strategy
|
||||
|
||||
if looks_like_quote_text(content):
|
||||
return ImportStrategy.QUOTE
|
||||
if looks_like_factual_text(content):
|
||||
return ImportStrategy.FACTUAL
|
||||
return ImportStrategy.NARRATIVE
|
||||
|
||||
|
||||
def detect_knowledge_type(content: str) -> KnowledgeType:
|
||||
"""自动检测落库 knowledge_type;无法可靠判断时回退 mixed。"""
|
||||
|
||||
text = str(content or "").strip()
|
||||
if not text:
|
||||
return KnowledgeType.MIXED
|
||||
if looks_like_structured_text(text):
|
||||
return KnowledgeType.STRUCTURED
|
||||
if looks_like_quote_text(text):
|
||||
return KnowledgeType.QUOTE
|
||||
if looks_like_factual_text(text):
|
||||
return KnowledgeType.FACTUAL
|
||||
if looks_like_narrative_text(text):
|
||||
return KnowledgeType.NARRATIVE
|
||||
return KnowledgeType.MIXED
|
||||
|
||||
|
||||
def get_type_from_user_input(type_hint: Optional[str], content: str) -> KnowledgeType:
|
||||
"""优先使用显式 type_hint,否则自动检测。"""
|
||||
|
||||
if type_hint:
|
||||
return resolve_stored_knowledge_type(type_hint, content=content)
|
||||
return detect_knowledge_type(content)
|
||||
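A minimal usage sketch for the heuristics above, assuming the plugin package is importable as plugins.A_memorix (the import path and sample sentence are illustrative, not part of this commit):

from plugins.A_memorix.core.storage.type_detection import (
    detect_knowledge_type,
    select_import_strategy,
)

sample = "小明是一名工程师,出生于北京,属于研发部门。"
# A short declarative sentence should score on the factual markers.
print(detect_knowledge_type(sample))                    # expected: KnowledgeType.FACTUAL
print(select_import_strategy(sample))                   # expected: ImportStrategy.FACTUAL
print(select_import_strategy(sample, chat_log=True))    # chat logs always map to NARRATIVE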
776
plugins/A_memorix/core/storage/vector_store.py
Normal file
@@ -0,0 +1,776 @@
|
||||
"""
|
||||
向量存储模块
|
||||
|
||||
基于Faiss的高效向量存储与检索,支持SQ8量化、Append-Only磁盘存储和内存映射。
|
||||
"""
|
||||
|
||||
import os
|
||||
import pickle
|
||||
import hashlib
|
||||
import shutil
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union, Tuple, List, Dict, Set, Any
|
||||
import random
|
||||
import threading
|
||||
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
import faiss
|
||||
HAS_FAISS = True
|
||||
except ImportError:
|
||||
HAS_FAISS = False
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from ..utils.quantization import QuantizationType
|
||||
from ..utils.io import atomic_write, atomic_save_path
|
||||
|
||||
logger = get_logger("A_Memorix.VectorStore")
|
||||
|
||||
|
||||
class VectorStore:
|
||||
"""
|
||||
向量存储类 (SQ8 + Append-Only Disk)
|
||||
|
||||
特性:
|
||||
- 索引: IndexIDMap2(IndexScalarQuantizer(QT_8bit))
|
||||
- 存储: float16 on-disk binary (vectors.bin)
|
||||
- 内存: 仅索引常驻 RAM (<512MB for 100k vectors)
|
||||
- ID: SHA1-based stable int64 IDs
|
||||
- 一致性: 强制 L2 Normalization (IP == Cosine)
|
||||
"""
|
||||
|
||||
# 默认训练触发阈值 (40 样本,过大可能导致小数据集不生效,过小可能量化退化)
|
||||
DEFAULT_MIN_TRAIN = 40
|
||||
# 强制训练样本量
|
||||
TRAIN_SIZE = 10000
|
||||
# 储水池采样上限 (流式处理前 50k 数据)
|
||||
RESERVOIR_CAPACITY = 10000
|
||||
RESERVOIR_SAMPLE_SCOPE = 50000
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dimension: int,
|
||||
quantization_type: QuantizationType = QuantizationType.INT8,
|
||||
index_type: str = "sq8",
|
||||
data_dir: Optional[Union[str, Path]] = None,
|
||||
use_mmap: bool = True,
|
||||
buffer_size: int = 1024,
|
||||
):
|
||||
if not HAS_FAISS:
|
||||
raise ImportError("Faiss 未安装,请安装: pip install faiss-cpu")
|
||||
|
||||
self.dimension = dimension
|
||||
self.data_dir = Path(data_dir) if data_dir else None
|
||||
if self.data_dir:
|
||||
self.data_dir.mkdir(parents=True, exist_ok=True)
|
||||
if quantization_type != QuantizationType.INT8:
|
||||
raise ValueError(
|
||||
"vNext 仅支持 quantization_type=int8(SQ8)。"
|
||||
" 请更新配置并执行 scripts/release_vnext_migrate.py migrate。"
|
||||
)
|
||||
normalized_index_type = str(index_type or "sq8").strip().lower()
|
||||
if normalized_index_type not in {"sq8", "int8"}:
|
||||
raise ValueError(
|
||||
"vNext 仅支持 index_type=sq8。"
|
||||
" 请更新配置并执行 scripts/release_vnext_migrate.py migrate。"
|
||||
)
|
||||
self.quantization_type = QuantizationType.INT8
|
||||
self.index_type = "sq8"
|
||||
self.buffer_size = buffer_size
|
||||
|
||||
self._index: Optional[faiss.IndexIDMap2] = None
|
||||
self._init_index()
|
||||
|
||||
self._is_trained = False
|
||||
self._vector_norm = "l2"
|
||||
|
||||
# Fallback Index (Flat) - 用于在 SQ8 训练完成前提供检索能力
|
||||
# 必须使用 IndexIDMap2 以保证 ID 与主索引一致
|
||||
self._fallback_index: Optional[faiss.IndexIDMap2] = None
|
||||
self._init_fallback_index()
|
||||
|
||||
self._known_hashes: Set[str] = set()
|
||||
self._deleted_ids: Set[int] = set()
|
||||
|
||||
self._reservoir_buffer: List[np.ndarray] = []
|
||||
self._seen_count_for_reservoir = 0
|
||||
|
||||
self._write_buffer_vecs: List[np.ndarray] = []
|
||||
self._write_buffer_ids: List[int] = []
|
||||
|
||||
self._total_added = 0
|
||||
self._total_deleted = 0
|
||||
self._bin_count = 0
|
||||
|
||||
# Thread safety lock
|
||||
self._lock = threading.RLock()
|
||||
|
||||
logger.info(f"VectorStore Init: dim={dimension}, SQ8 Mode, Append-Only Storage")
|
||||
|
||||
def _init_index(self):
|
||||
"""初始化空的 Faiss 索引"""
|
||||
quantizer = faiss.IndexScalarQuantizer(
|
||||
self.dimension,
|
||||
faiss.ScalarQuantizer.QT_8bit,
|
||||
faiss.METRIC_INNER_PRODUCT
|
||||
)
|
||||
self._index = faiss.IndexIDMap2(quantizer)
|
||||
self._is_trained = False
|
||||
|
||||
def _init_fallback_index(self):
|
||||
"""初始化 Flat 回退索引"""
|
||||
flat_index = faiss.IndexFlatIP(self.dimension)
|
||||
self._fallback_index = faiss.IndexIDMap2(flat_index)
|
||||
logger.debug("Fallback index (Flat) initialized.")
|
||||
|
||||
@staticmethod
|
||||
def _generate_id(key: str) -> int:
|
||||
"""生成稳定的 int64 ID (SHA1 截断)"""
|
||||
h = hashlib.sha1(key.encode("utf-8")).digest()
|
||||
val = int.from_bytes(h[:8], byteorder="big", signed=False)
|
||||
return val & 0x7FFFFFFFFFFFFFFF
|
||||
|
||||
@property
|
||||
def _bin_path(self) -> Path:
|
||||
return self.data_dir / "vectors.bin"
|
||||
|
||||
@property
|
||||
def _ids_bin_path(self) -> Path:
|
||||
return self.data_dir / "vectors_ids.bin"
|
||||
|
||||
@property
|
||||
def _int_to_str_map(self) -> Dict[int, str]:
|
||||
"""Lazy build volatile map from known hashes"""
|
||||
# Note: this property is read-heavy and cached; add/delete now hold the lock,
# so rebuilding the cache on a length mismatch is an acceptable (if crude) invalidation check.
|
||||
if not hasattr(self, "_cached_map") or len(self._cached_map) != len(self._known_hashes):
|
||||
with self._lock: # Protect cache rebuild
|
||||
self._cached_map = {self._generate_id(k): k for k in self._known_hashes}
|
||||
return self._cached_map
|
||||
|
||||
def add(self, vectors: np.ndarray, ids: List[str]) -> int:
|
||||
with self._lock:
|
||||
if vectors.shape[1] != self.dimension:
|
||||
raise ValueError(f"Dimension mismatch: {vectors.shape[1]} vs {self.dimension}")
|
||||
|
||||
vectors = np.ascontiguousarray(vectors, dtype=np.float32)
|
||||
faiss.normalize_L2(vectors)
|
||||
|
||||
processed_vecs = []
|
||||
processed_int_ids = []
|
||||
|
||||
for i, str_id in enumerate(ids):
|
||||
if str_id in self._known_hashes:
|
||||
continue
|
||||
|
||||
int_id = self._generate_id(str_id)
|
||||
self._known_hashes.add(str_id)
|
||||
|
||||
processed_vecs.append(vectors[i])
|
||||
processed_int_ids.append(int_id)
|
||||
|
||||
if not processed_vecs:
|
||||
return 0
|
||||
|
||||
batch_vecs = np.array(processed_vecs, dtype=np.float32)
|
||||
batch_ids = np.array(processed_int_ids, dtype=np.int64)
|
||||
|
||||
self._write_buffer_vecs.append(batch_vecs)
|
||||
self._write_buffer_ids.extend(processed_int_ids)
|
||||
|
||||
if len(self._write_buffer_ids) >= self.buffer_size:
|
||||
self._flush_write_buffer_unlocked()
|
||||
|
||||
if not self._is_trained:
|
||||
# 双写到回退索引
|
||||
self._fallback_index.add_with_ids(batch_vecs, batch_ids)
|
||||
|
||||
self._update_reservoir(batch_vecs)
|
||||
# 这里的 TRAIN_SIZE 取默认 10k,或者根据当前数据量动态判断
|
||||
if len(self._reservoir_buffer) >= self.TRAIN_SIZE:
|
||||
logger.info(f"训练样本达到上限,开始训练...")
|
||||
self._train_and_replay_unlocked()
|
||||
|
||||
self._total_added += len(batch_ids)
|
||||
return len(batch_ids)
|
||||
|
||||
def _flush_write_buffer(self):
|
||||
with self._lock:
|
||||
self._flush_write_buffer_unlocked()
|
||||
|
||||
def _flush_write_buffer_unlocked(self):
|
||||
if not self._write_buffer_vecs:
|
||||
return
|
||||
|
||||
batch_vecs = np.concatenate(self._write_buffer_vecs, axis=0)
|
||||
batch_ids = np.array(self._write_buffer_ids, dtype=np.int64)
|
||||
|
||||
vecs_fp16 = batch_vecs.astype(np.float16)
|
||||
|
||||
with open(self._bin_path, "ab") as f:
|
||||
f.write(vecs_fp16.tobytes())
|
||||
|
||||
ids_bytes = batch_ids.astype('>i8').tobytes()
|
||||
with open(self._ids_bin_path, "ab") as f:
|
||||
f.write(ids_bytes)
|
||||
|
||||
self._bin_count += len(batch_ids)
|
||||
|
||||
if self._is_trained and self._index.is_trained:
|
||||
self._index.add_with_ids(batch_vecs, batch_ids)
|
||||
else:
|
||||
# 即使在 flush 时,如果未训练,也要同步到 fallback
|
||||
self._fallback_index.add_with_ids(batch_vecs, batch_ids)
|
||||
|
||||
self._write_buffer_vecs.clear()
|
||||
self._write_buffer_ids.clear()
|
||||
|
||||
def _update_reservoir(self, vectors: np.ndarray):
|
||||
for vec in vectors:
|
||||
self._seen_count_for_reservoir += 1
|
||||
if len(self._reservoir_buffer) < self.RESERVOIR_CAPACITY:
|
||||
self._reservoir_buffer.append(vec)
|
||||
else:
|
||||
if self._seen_count_for_reservoir <= self.RESERVOIR_SAMPLE_SCOPE:
|
||||
r = random.randint(0, self._seen_count_for_reservoir - 1)
|
||||
if r < self.RESERVOIR_CAPACITY:
|
||||
self._reservoir_buffer[r] = vec
|
||||
|
||||
def _train_and_replay(self):
|
||||
with self._lock:
|
||||
self._train_and_replay_unlocked()
|
||||
|
||||
def _train_and_replay_unlocked(self):
|
||||
if not self._reservoir_buffer:
|
||||
logger.warning("No training data available.")
|
||||
return
|
||||
|
||||
train_data = np.array(self._reservoir_buffer, dtype=np.float32)
|
||||
logger.info(f"Training Index with {len(train_data)} samples...")
|
||||
|
||||
try:
|
||||
self._index.train(train_data)
|
||||
except Exception as e:
|
||||
logger.error(f"SQ8 Training failed: {e}. Staying in fallback mode.")
|
||||
return
|
||||
|
||||
self._is_trained = True
|
||||
self._reservoir_buffer = []
|
||||
|
||||
logger.info("Replaying data from disk to populate index...")
|
||||
try:
|
||||
self._replay_vectors_to_index()
|
||||
# 只有当 replay 成功且数据量一致时,才释放回退索引
|
||||
if self._index.ntotal >= self._bin_count:
|
||||
logger.info(f"Replay successful ({self._index.ntotal}/{self._bin_count}). Releasing fallback index.")
|
||||
self._fallback_index.reset()
|
||||
else:
|
||||
logger.warning(f"Replay count mismatch: {self._index.ntotal} vs {self._bin_count}. Keeping fallback index.")
|
||||
except Exception as e:
|
||||
logger.error(f"Replay failed: {e}. Keeping fallback index as backup.")
|
||||
|
||||
def _replay_vectors_to_index(self) -> None:
|
||||
"""从 vectors.bin 读取并添加到 index"""
|
||||
if not self._bin_path.exists() or not self._ids_bin_path.exists():
|
||||
return 0
|
||||
|
||||
vec_item_size = self.dimension * 2
|
||||
id_item_size = 8
|
||||
chunk_size = 10000
|
||||
|
||||
with open(self._bin_path, "rb") as f_vec, open(self._ids_bin_path, "rb") as f_id:
|
||||
while True:
|
||||
vec_data = f_vec.read(chunk_size * vec_item_size)
|
||||
id_data = f_id.read(chunk_size * id_item_size)
|
||||
|
||||
if not vec_data:
|
||||
break
|
||||
|
||||
batch_fp16 = np.frombuffer(vec_data, dtype=np.float16).reshape(-1, self.dimension)
|
||||
batch_fp32 = batch_fp16.astype(np.float32)
|
||||
faiss.normalize_L2(batch_fp32)
|
||||
|
||||
batch_ids = np.frombuffer(id_data, dtype='>i8').astype(np.int64)
|
||||
|
||||
valid_mask = [id_ not in self._deleted_ids for id_ in batch_ids]
|
||||
if not all(valid_mask):
|
||||
batch_fp32 = batch_fp32[valid_mask]
|
||||
batch_ids = batch_ids[valid_mask]
|
||||
|
||||
if len(batch_ids) > 0:
|
||||
self._index.add_with_ids(batch_fp32, batch_ids)
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: np.ndarray,
|
||||
k: int = 10,
|
||||
filter_deleted: bool = True,
|
||||
) -> Tuple[List[str], List[float]]:
|
||||
query_local = np.array(query, dtype=np.float32, order="C", copy=True)
|
||||
if query_local.ndim == 1:
|
||||
got_dim = int(query_local.shape[0])
|
||||
query_local = query_local.reshape(1, -1)
|
||||
elif query_local.ndim == 2:
|
||||
if query_local.shape[0] != 1:
|
||||
raise ValueError(
|
||||
f"query embedding must have shape (D,) or (1, D), got {tuple(query_local.shape)}"
|
||||
)
|
||||
got_dim = int(query_local.shape[1])
|
||||
else:
|
||||
raise ValueError(
|
||||
f"query embedding must have shape (D,) or (1, D), got {tuple(query_local.shape)}"
|
||||
)
|
||||
|
||||
if got_dim != self.dimension:
|
||||
raise ValueError(
|
||||
f"query embedding dimension mismatch: expected={self.dimension} got={got_dim}"
|
||||
)
|
||||
if not np.all(np.isfinite(query_local)):
|
||||
raise ValueError("query embedding contains non-finite values")
|
||||
|
||||
faiss.normalize_L2(query_local)
|
||||
|
||||
# 查询路径仅负责检索,不在此触发训练/回放。
|
||||
# 训练/回放前置到 warmup_index(),并由插件启动阶段触发。
|
||||
# Faiss 索引在并发 search 下可能出现阻塞,这里串行化检索调用保证稳定性。
|
||||
with self._lock:
|
||||
self._flush_write_buffer_unlocked()
|
||||
search_index = self._index if (self._is_trained and self._index.ntotal > 0) else self._fallback_index
|
||||
if search_index.ntotal == 0:
|
||||
logger.warning("Indices are empty. No data to search.")
|
||||
return [], []
|
||||
# 执行检索
|
||||
dists, ids = search_index.search(query_local, k * 2)
|
||||
|
||||
# Faiss search 返回的是 (1, K) 的数组,取第一行
|
||||
dists = dists[0]
|
||||
ids = ids[0]
|
||||
|
||||
results = []
|
||||
for id_val, score in zip(ids, dists):
|
||||
if id_val == -1: continue
|
||||
if filter_deleted and id_val in self._deleted_ids:
|
||||
continue
|
||||
|
||||
str_id = self._int_to_str_map.get(id_val)
|
||||
if str_id:
|
||||
results.append((str_id, float(score)))
|
||||
|
||||
# Sort and trim just in case filtering reduced count
|
||||
results.sort(key=lambda x: x[1], reverse=True)
|
||||
results = results[:k]
|
||||
|
||||
if not results:
|
||||
return [], []
|
||||
|
||||
return [r[0] for r in results], [r[1] for r in results]
|
||||
|
||||
def warmup_index(self, force_train: bool = True) -> Dict[str, Any]:
|
||||
"""
|
||||
预热向量索引(训练/回放前置),避免首个线上查询触发重初始化。
|
||||
|
||||
Args:
|
||||
force_train: 是否在满足阈值时强制训练 SQ8 索引
|
||||
|
||||
Returns:
|
||||
预热状态摘要
|
||||
"""
|
||||
started = time.perf_counter()
|
||||
logger.info(f"metric.vector_index_prewarm_started=1 force_train={bool(force_train)}")
|
||||
|
||||
try:
|
||||
with self._lock:
|
||||
self._flush_write_buffer()
|
||||
|
||||
if self._bin_path.exists():
|
||||
self._bin_count = self._bin_path.stat().st_size // (self.dimension * 2)
|
||||
else:
|
||||
self._bin_count = 0
|
||||
|
||||
needs_fallback_bootstrap = (
|
||||
self._bin_count > 0
|
||||
and self._fallback_index.ntotal == 0
|
||||
and (not self._is_trained or self._index.ntotal == 0)
|
||||
)
|
||||
if needs_fallback_bootstrap:
|
||||
self._bootstrap_fallback_from_disk()
|
||||
|
||||
min_train = max(1, int(getattr(self, "min_train_threshold", self.DEFAULT_MIN_TRAIN)))
|
||||
needs_train = (
|
||||
bool(force_train)
|
||||
and self._bin_count >= min_train
|
||||
and not self._is_trained
|
||||
)
|
||||
if needs_train:
|
||||
self._force_train_small_data()
|
||||
|
||||
duration_ms = (time.perf_counter() - started) * 1000.0
|
||||
summary = {
|
||||
"ok": True,
|
||||
"trained": bool(self._is_trained),
|
||||
"index_ntotal": int(self._index.ntotal),
|
||||
"fallback_ntotal": int(self._fallback_index.ntotal),
|
||||
"bin_count": int(self._bin_count),
|
||||
"duration_ms": duration_ms,
|
||||
"error": None,
|
||||
}
|
||||
except Exception as e:
|
||||
duration_ms = (time.perf_counter() - started) * 1000.0
|
||||
summary = {
|
||||
"ok": False,
|
||||
"trained": bool(self._is_trained),
|
||||
"index_ntotal": int(self._index.ntotal) if self._index is not None else 0,
|
||||
"fallback_ntotal": int(self._fallback_index.ntotal) if self._fallback_index is not None else 0,
|
||||
"bin_count": int(getattr(self, "_bin_count", 0)),
|
||||
"duration_ms": duration_ms,
|
||||
"error": str(e),
|
||||
}
|
||||
logger.error(
|
||||
"metric.vector_index_prewarm_fail=1 "
|
||||
f"metric.vector_index_prewarm_duration_ms={duration_ms:.2f} "
|
||||
f"error={e}"
|
||||
)
|
||||
return summary
|
||||
|
||||
logger.info(
|
||||
"metric.vector_index_prewarm_success=1 "
|
||||
f"metric.vector_index_prewarm_duration_ms={summary['duration_ms']:.2f} "
|
||||
f"trained={summary['trained']} "
|
||||
f"index_ntotal={summary['index_ntotal']} "
|
||||
f"fallback_ntotal={summary['fallback_ntotal']} "
|
||||
f"bin_count={summary['bin_count']}"
|
||||
)
|
||||
return summary
|
||||
|
||||
def _bootstrap_fallback_from_disk(self):
|
||||
with self._lock:
|
||||
self._bootstrap_fallback_from_disk_unlocked()
|
||||
|
||||
def _bootstrap_fallback_from_disk_unlocked(self):
|
||||
"""重启后自举:从磁盘 vectors.bin 加载数据到 fallback 索引"""
|
||||
if not self._bin_path.exists() or not self._ids_bin_path.exists():
|
||||
return
|
||||
|
||||
logger.info("Replaying all disk vectors to fallback index...")
|
||||
vec_item_size = self.dimension * 2
|
||||
id_item_size = 8
|
||||
chunk_size = 10000
|
||||
|
||||
with open(self._bin_path, "rb") as f_vec, open(self._ids_bin_path, "rb") as f_id:
|
||||
while True:
|
||||
vec_data = f_vec.read(chunk_size * vec_item_size)
|
||||
id_data = f_id.read(chunk_size * id_item_size)
|
||||
if not vec_data: break
|
||||
|
||||
batch_fp16 = np.frombuffer(vec_data, dtype=np.float16).reshape(-1, self.dimension)
|
||||
batch_fp32 = batch_fp16.astype(np.float32)
|
||||
faiss.normalize_L2(batch_fp32)
|
||||
batch_ids = np.frombuffer(id_data, dtype='>i8').astype(np.int64)
|
||||
|
||||
valid_mask = [id_ not in self._deleted_ids for id_ in batch_ids]
|
||||
if any(valid_mask):
|
||||
self._fallback_index.add_with_ids(batch_fp32[valid_mask], batch_ids[valid_mask])
|
||||
|
||||
logger.info(f"Fallback index self-bootstrapped with {self._fallback_index.ntotal} items.")
|
||||
|
||||
def _force_train_small_data(self):
|
||||
with self._lock:
|
||||
self._force_train_small_data_unlocked()
|
||||
|
||||
def _force_train_small_data_unlocked(self):
|
||||
logger.info("Forcing training on small dataset...")
|
||||
self._reservoir_buffer = []
|
||||
|
||||
chunk_size = 10000
|
||||
vec_item_size = self.dimension * 2
|
||||
|
||||
with open(self._bin_path, "rb") as f:
|
||||
while len(self._reservoir_buffer) < self.TRAIN_SIZE:
|
||||
data = f.read(chunk_size * vec_item_size)
|
||||
if not data: break
|
||||
fp16 = np.frombuffer(data, dtype=np.float16).reshape(-1, self.dimension)
|
||||
fp32 = fp16.astype(np.float32)
|
||||
faiss.normalize_L2(fp32)
|
||||
|
||||
for vec in fp32:
|
||||
self._reservoir_buffer.append(vec)
|
||||
if len(self._reservoir_buffer) >= self.TRAIN_SIZE:
|
||||
break
|
||||
|
||||
self._train_and_replay_unlocked()
|
||||
|
||||
def delete(self, ids: List[str]) -> int:
|
||||
with self._lock:
|
||||
count = 0
|
||||
for str_id in ids:
|
||||
if str_id not in self._known_hashes:
|
||||
continue
|
||||
int_id = self._generate_id(str_id)
|
||||
if int_id not in self._deleted_ids:
|
||||
self._deleted_ids.add(int_id)
|
||||
if self._index.is_trained:
|
||||
self._index.remove_ids(np.array([int_id], dtype=np.int64))
|
||||
# 同步从 fallback 移除
|
||||
if self._fallback_index.ntotal > 0:
|
||||
self._fallback_index.remove_ids(np.array([int_id], dtype=np.int64))
|
||||
count += 1
|
||||
self._total_deleted += count
|
||||
|
||||
# Check GC
|
||||
self._check_rebuild_needed()
|
||||
return count
|
||||
|
||||
def _check_rebuild_needed(self):
|
||||
"""GC Excution Check"""
|
||||
if self._bin_count == 0: return
|
||||
ratio = len(self._deleted_ids) / self._bin_count
|
||||
if ratio > 0.3 and len(self._deleted_ids) > 1000:
|
||||
logger.info(f"Triggering GC/Rebuild (deleted ratio: {ratio:.2f})")
|
||||
self.rebuild_index()
|
||||
|
||||
def rebuild_index(self):
|
||||
"""GC: 重建索引,压缩 bin 文件"""
|
||||
with self._lock:
|
||||
self._rebuild_index_locked()
|
||||
|
||||
def _rebuild_index_locked(self):
|
||||
"""实际 GC 重建逻辑。"""
|
||||
logger.info("Starting Compaction (GC)...")
|
||||
|
||||
tmp_bin = self.data_dir / "vectors.bin.tmp"
|
||||
tmp_ids = self.data_dir / "vectors_ids.bin.tmp"
|
||||
|
||||
vec_item_size = self.dimension * 2
|
||||
id_item_size = 8
|
||||
chunk_size = 10000
|
||||
|
||||
new_count = 0
|
||||
|
||||
# 1. Compact Files
|
||||
with open(self._bin_path, "rb") as f_vec, open(self._ids_bin_path, "rb") as f_id, \
|
||||
open(tmp_bin, "wb") as w_vec, open(tmp_ids, "wb") as w_id:
|
||||
while True:
|
||||
vec_data = f_vec.read(chunk_size * vec_item_size)
|
||||
id_data = f_id.read(chunk_size * id_item_size)
|
||||
if not vec_data: break
|
||||
|
||||
batch_fp16 = np.frombuffer(vec_data, dtype=np.float16).reshape(-1, self.dimension)
|
||||
batch_ids = np.frombuffer(id_data, dtype='>i8').astype(np.int64)
|
||||
|
||||
keep_mask = [id_ not in self._deleted_ids for id_ in batch_ids]
|
||||
|
||||
if any(keep_mask):
|
||||
keep_vecs = batch_fp16[keep_mask]
|
||||
keep_ids = batch_ids[keep_mask]
|
||||
|
||||
w_vec.write(keep_vecs.tobytes())
|
||||
w_id.write(keep_ids.astype('>i8').tobytes())
|
||||
new_count += len(keep_ids)
|
||||
|
||||
# 2. Reset State & Atomic Swap
|
||||
self._bin_count = new_count
|
||||
|
||||
# Close current index
|
||||
self._index.reset()
|
||||
if self._fallback_index: self._fallback_index.reset() # Also clear fallback
|
||||
self._is_trained = False
|
||||
|
||||
# Swap files
|
||||
shutil.move(str(tmp_bin), str(self._bin_path))
|
||||
shutil.move(str(tmp_ids), str(self._ids_bin_path))
|
||||
|
||||
# Reset Tombstones (Critical)
|
||||
self._deleted_ids.clear()
|
||||
|
||||
# 3. Reload/Rebuild Index (Fresh Train)
|
||||
# We need to re-train because data distribution might have changed significantly after deletion
|
||||
self._init_index()
|
||||
self._init_fallback_index() # Re-init fallback too
|
||||
self._force_train_small_data() # This will train and replay from the NEW compact file
|
||||
|
||||
logger.info("Compaction Complete.")
|
||||
|
||||
def save(self, data_dir: Optional[Union[str, Path]] = None) -> None:
|
||||
with self._lock:
|
||||
if not data_dir:
|
||||
data_dir = self.data_dir
|
||||
if not data_dir:
|
||||
raise ValueError("No data_dir")
|
||||
|
||||
data_dir = Path(data_dir)
|
||||
data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self._flush_write_buffer_unlocked()
|
||||
|
||||
if self._is_trained:
|
||||
index_path = data_dir / "vectors.index"
|
||||
with atomic_save_path(index_path) as tmp:
|
||||
faiss.write_index(self._index, tmp)
|
||||
|
||||
meta = {
|
||||
"dimension": self.dimension,
|
||||
"quantization_type": self.quantization_type.value,
|
||||
"is_trained": self._is_trained,
|
||||
"vector_norm": self._vector_norm,
|
||||
"deleted_ids": list(self._deleted_ids),
|
||||
"known_hashes": list(self._known_hashes),
|
||||
}
|
||||
|
||||
with atomic_write(data_dir / "vectors_metadata.pkl", "wb") as f:
|
||||
pickle.dump(meta, f)
|
||||
|
||||
logger.info("VectorStore saved.")
|
||||
|
||||
def migrate_legacy_npy(self, data_dir: Optional[Union[str, Path]] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
离线迁移入口:将 legacy vectors.npy 转为 vNext 二进制格式。
|
||||
"""
|
||||
with self._lock:
|
||||
target_dir = Path(data_dir) if data_dir else self.data_dir
|
||||
if target_dir is None:
|
||||
raise ValueError("No data_dir")
|
||||
target_dir = Path(target_dir)
|
||||
npy_path = target_dir / "vectors.npy"
|
||||
idx_path = target_dir / "vectors.index"
|
||||
bin_path = target_dir / "vectors.bin"
|
||||
ids_bin_path = target_dir / "vectors_ids.bin"
|
||||
meta_path = target_dir / "vectors_metadata.pkl"
|
||||
|
||||
if not npy_path.exists():
|
||||
return {"migrated": False, "reason": "npy_missing"}
|
||||
if not meta_path.exists():
|
||||
raise RuntimeError("legacy vectors.npy migration requires vectors_metadata.pkl")
|
||||
if bin_path.exists() and ids_bin_path.exists():
|
||||
return {"migrated": False, "reason": "bin_exists"}
|
||||
|
||||
# Reset in-memory state to avoid appending to stale runtime buffers.
|
||||
self._known_hashes.clear()
|
||||
self._deleted_ids.clear()
|
||||
self._write_buffer_vecs.clear()
|
||||
self._write_buffer_ids.clear()
|
||||
self._init_index()
|
||||
self._init_fallback_index()
|
||||
self._is_trained = False
|
||||
self._bin_count = 0
|
||||
|
||||
self._migrate_from_npy_unlocked(npy_path, idx_path, target_dir)
|
||||
self.save(target_dir)
|
||||
return {"migrated": True, "reason": "ok"}
|
||||
|
||||
def load(self, data_dir: Optional[Union[str, Path]] = None) -> None:
|
||||
with self._lock:
|
||||
if not data_dir: data_dir = self.data_dir
|
||||
data_dir = Path(data_dir)
|
||||
|
||||
npy_path = data_dir / "vectors.npy"
|
||||
idx_path = data_dir / "vectors.index"
|
||||
bin_path = data_dir / "vectors.bin"
|
||||
|
||||
if npy_path.exists() and not bin_path.exists():
|
||||
raise RuntimeError(
|
||||
"检测到 legacy vectors.npy,vNext 不再支持运行时自动迁移。"
|
||||
" 请先执行 scripts/release_vnext_migrate.py migrate。"
|
||||
)
|
||||
|
||||
meta_path = data_dir / "vectors_metadata.pkl"
|
||||
if not meta_path.exists():
|
||||
logger.warning("No metadata found, initialized empty.")
|
||||
return
|
||||
|
||||
with open(meta_path, "rb") as f:
|
||||
meta = pickle.load(f)
|
||||
|
||||
if meta.get("vector_norm") != "l2":
|
||||
logger.warning("Index IDMap2 version mismatch (L2 Norm), forcing rebuild...")
|
||||
self._known_hashes = set(meta.get("ids", [])) | set(meta.get("known_hashes", []))
|
||||
self._deleted_ids = set(meta.get("deleted_ids", []))
|
||||
self._init_index()
|
||||
self._force_train_small_data()
|
||||
return
|
||||
|
||||
self._is_trained = meta.get("is_trained", False)
|
||||
self._vector_norm = meta.get("vector_norm", "l2")
|
||||
self._deleted_ids = set(meta.get("deleted_ids", []))
|
||||
self._known_hashes = set(meta.get("known_hashes", []))
|
||||
|
||||
if self._is_trained:
|
||||
if idx_path.exists():
|
||||
try:
|
||||
self._index = faiss.read_index(str(idx_path))
|
||||
if not isinstance(self._index, faiss.IndexIDMap2):
|
||||
logger.warning("Loaded index type mismatch. Rebuilding...")
|
||||
self._init_index()
|
||||
self._force_train_small_data()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load index: {e}. Rebuilding...")
|
||||
self._init_index()
|
||||
self._force_train_small_data()
|
||||
else:
|
||||
logger.warning("Index file missing despite metadata indicating trained. Rebuilding from bin...")
|
||||
self._init_index()
|
||||
self._force_train_small_data()
|
||||
|
||||
if bin_path.exists():
|
||||
self._bin_count = bin_path.stat().st_size // (self.dimension * 2)
|
||||
|
||||
def _migrate_from_npy(self, npy_path, idx_path, data_dir):
|
||||
with self._lock:
|
||||
self._migrate_from_npy_unlocked(npy_path, idx_path, data_dir)
|
||||
|
||||
def _migrate_from_npy_unlocked(self, npy_path, idx_path, data_dir):
|
||||
try:
|
||||
arr = np.load(npy_path, mmap_mode="r")
|
||||
except Exception:
|
||||
arr = np.load(npy_path)
|
||||
|
||||
meta_path = data_dir / "vectors_metadata.pkl"
|
||||
old_ids = []
|
||||
if meta_path.exists():
|
||||
with open(meta_path, "rb") as f:
|
||||
m = pickle.load(f)
|
||||
old_ids = m.get("ids", [])
|
||||
|
||||
if len(arr) != len(old_ids):
|
||||
logger.error(f"Migration mismatch: arr {len(arr)} != ids {len(old_ids)}")
|
||||
return
|
||||
|
||||
logger.info(f"Migrating {len(arr)} vectors...")
|
||||
|
||||
chunk = 1000
|
||||
for i in range(0, len(arr), chunk):
|
||||
sub_arr = arr[i : i+chunk]
|
||||
sub_ids = old_ids[i : i+chunk]
|
||||
self.add(sub_arr, sub_ids)
|
||||
|
||||
if not self._is_trained:
|
||||
self._force_train_small_data()
|
||||
|
||||
shutil.move(str(npy_path), str(npy_path) + ".bak")
|
||||
if idx_path.exists():
|
||||
shutil.move(str(idx_path), str(idx_path) + ".bak")
|
||||
|
||||
logger.info("Migration complete.")
|
||||
|
||||
def clear(self) -> None:
|
||||
with self._lock:
|
||||
self._ids_bin_path.unlink(missing_ok=True)
|
||||
self._bin_path.unlink(missing_ok=True)
|
||||
self._init_index()
|
||||
self._known_hashes.clear()
|
||||
self._deleted_ids.clear()
|
||||
self._bin_count = 0
|
||||
logger.info("VectorStore cleared.")
|
||||
|
||||
def has_data(self) -> bool:
|
||||
return (self.data_dir / "vectors_metadata.pkl").exists()
|
||||
|
||||
@property
|
||||
def num_vectors(self) -> int:
|
||||
return len(self._known_hashes) - len(self._deleted_ids)
|
||||
|
||||
def __contains__(self, hash_value: str) -> bool:
|
||||
"""Check if a hash exists in the store"""
|
||||
return hash_value in self._known_hashes and self._generate_id(hash_value) not in self._deleted_ids
|
||||
|
||||
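An illustrative end-to-end sketch of the VectorStore API defined above. It assumes faiss-cpu and numpy are installed and that the host application's environment resolves src.common.logger; the dimension, path and ids are placeholders:

import numpy as np
from plugins.A_memorix.core.storage.vector_store import VectorStore

store = VectorStore(dimension=128, data_dir="./memorix_data/vectors")
vectors = np.random.rand(5, 128).astype(np.float32)       # L2-normalized internally by add()
store.add(vectors, ids=[f"doc-{i}" for i in range(5)])    # buffered, deduplicated by string id
store.warmup_index()                                       # flush + (when enough data) train SQ8
ids, scores = store.search(vectors[0], k=3)                # served by the Flat fallback until trained
store.save()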
0
plugins/A_memorix/core/strategies/__init__.py
Normal file
89
plugins/A_memorix/core/strategies/base.py
Normal file
@@ -0,0 +1,89 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Dict, Any, Optional, Union
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
import hashlib
|
||||
|
||||
class KnowledgeType(str, Enum):
|
||||
NARRATIVE = "narrative"
|
||||
FACTUAL = "factual"
|
||||
QUOTE = "quote"
|
||||
MIXED = "mixed"
|
||||
|
||||
@dataclass
|
||||
class SourceInfo:
|
||||
file: str
|
||||
offset_start: int
|
||||
offset_end: int
|
||||
checksum: str = ""
|
||||
|
||||
@dataclass
|
||||
class ChunkContext:
|
||||
chunk_id: str
|
||||
index: int
|
||||
context: Dict[str, Any] = field(default_factory=dict)
|
||||
text: str = ""
|
||||
|
||||
@dataclass
|
||||
class ChunkFlags:
|
||||
verbatim: bool = False
|
||||
requires_llm: bool = True
|
||||
|
||||
@dataclass
|
||||
class ProcessedChunk:
|
||||
type: KnowledgeType
|
||||
source: SourceInfo
|
||||
chunk: ChunkContext
|
||||
data: Dict[str, Any] = field(default_factory=dict) # triples、events、verbatim_entities
|
||||
flags: ChunkFlags = field(default_factory=ChunkFlags)
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
return {
|
||||
"type": self.type.value,
|
||||
"source": {
|
||||
"file": self.source.file,
|
||||
"offset_start": self.source.offset_start,
|
||||
"offset_end": self.source.offset_end,
|
||||
"checksum": self.source.checksum
|
||||
},
|
||||
"chunk": {
|
||||
"text": self.chunk.text,
|
||||
"chunk_id": self.chunk.chunk_id,
|
||||
"index": self.chunk.index,
|
||||
"context": self.chunk.context
|
||||
},
|
||||
"data": self.data,
|
||||
"flags": {
|
||||
"verbatim": self.flags.verbatim,
|
||||
"requires_llm": self.flags.requires_llm
|
||||
}
|
||||
}
|
||||
|
||||
class BaseStrategy(ABC):
|
||||
def __init__(self, filename: str):
|
||||
self.filename = filename
|
||||
|
||||
@abstractmethod
|
||||
def split(self, text: str) -> List[ProcessedChunk]:
|
||||
"""按策略将文本切分为块。"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def extract(self, chunk: ProcessedChunk, llm_func=None) -> ProcessedChunk:
|
||||
"""从文本块中抽取结构化信息。"""
|
||||
pass
|
||||
|
||||
def calculate_checksum(self, text: str) -> str:
|
||||
return hashlib.sha256(text.encode("utf-8")).hexdigest()
|
||||
|
||||
def build_language_guard(self, text: str) -> str:
|
||||
"""
|
||||
构建统一的输出语言约束。
|
||||
不区分语言类型,仅要求抽取值保持原文语言,不做翻译。
|
||||
"""
|
||||
_ = text # 预留参数,便于后续按需扩展
|
||||
return (
|
||||
"Focus on the original source language. Keep extracted events, entities, predicates "
|
||||
"and objects in the same language as the source text, preserve names/terms as-is, "
|
||||
"and do not translate."
|
||||
)
|
||||
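A minimal sketch of a custom strategy built on the BaseStrategy contract above; WholeTextStrategy is a hypothetical example for illustration only, not part of the plugin:

from plugins.A_memorix.core.strategies.base import (
    BaseStrategy, ProcessedChunk, KnowledgeType, SourceInfo, ChunkContext,
)

class WholeTextStrategy(BaseStrategy):
    """Treats the whole input as a single factual chunk (illustration only)."""

    def split(self, text: str):
        return [ProcessedChunk(
            type=KnowledgeType.FACTUAL,
            source=SourceInfo(file=self.filename, offset_start=0,
                              offset_end=len(text),
                              checksum=self.calculate_checksum(text)),
            chunk=ChunkContext(chunk_id=f"{self.filename}_0", index=0, text=text),
        )]

    async def extract(self, chunk: ProcessedChunk, llm_func=None):
        return chunk  # pass-through: no LLM extraction in this sketch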
98
plugins/A_memorix/core/strategies/factual.py
Normal file
@@ -0,0 +1,98 @@
|
||||
import re
|
||||
from typing import List, Dict, Any
|
||||
from .base import BaseStrategy, ProcessedChunk, KnowledgeType, SourceInfo, ChunkContext
|
||||
|
||||
class FactualStrategy(BaseStrategy):
|
||||
def split(self, text: str) -> List[ProcessedChunk]:
|
||||
# 结构感知切分
|
||||
lines = text.split('\n')
|
||||
chunks = []
|
||||
current_chunk_lines = []
|
||||
current_len = 0
|
||||
target_size = 600
|
||||
|
||||
for i, line in enumerate(lines):
|
||||
# 判断是否应当切分
|
||||
# 若当前行为列表项/定义/表格行,则尽量不切分
|
||||
is_structure = self._is_structural_line(line)
|
||||
|
||||
current_len += len(line) + 1
|
||||
current_chunk_lines.append(line)
|
||||
|
||||
# 达到目标长度且不在紧凑结构块内时切分(过长时强制切分)
|
||||
if current_len >= target_size and not is_structure:
|
||||
chunks.append(self._create_chunk(current_chunk_lines, len(chunks)))
|
||||
current_chunk_lines = []
|
||||
current_len = 0
|
||||
elif current_len >= target_size * 2: # 超长时强制切分
|
||||
chunks.append(self._create_chunk(current_chunk_lines, len(chunks)))
|
||||
current_chunk_lines = []
|
||||
current_len = 0
|
||||
|
||||
if current_chunk_lines:
|
||||
chunks.append(self._create_chunk(current_chunk_lines, len(chunks)))
|
||||
|
||||
return chunks
|
||||
|
||||
def _is_structural_line(self, line: str) -> bool:
|
||||
line = line.strip()
|
||||
if not line: return False
|
||||
# 列表项
|
||||
if re.match(r'^[\-\*]\s+', line) or re.match(r'^\d+\.\s+', line):
|
||||
return True
|
||||
# 定义项(术语: 定义)
|
||||
if re.match(r'^[^::]+[::].+', line):
|
||||
return True
|
||||
# 表格行(按 markdown 语法假设)
|
||||
if line.startswith('|') and line.endswith('|'):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _create_chunk(self, lines: List[str], index: int) -> ProcessedChunk:
|
||||
text = "\n".join(lines)
|
||||
return ProcessedChunk(
|
||||
type=KnowledgeType.FACTUAL,
|
||||
source=SourceInfo(
|
||||
file=self.filename,
|
||||
offset_start=0, # 简化处理:真实偏移跟踪需要额外状态
|
||||
offset_end=0,
|
||||
checksum=self.calculate_checksum(text)
|
||||
),
|
||||
chunk=ChunkContext(
|
||||
chunk_id=f"{self.filename}_{index}",
|
||||
index=index,
|
||||
text=text
|
||||
)
|
||||
)
|
||||
|
||||
async def extract(self, chunk: ProcessedChunk, llm_func=None) -> ProcessedChunk:
|
||||
if not llm_func:
|
||||
raise ValueError("LLM function required for Factual extraction")
|
||||
|
||||
language_guard = self.build_language_guard(chunk.chunk.text)
|
||||
prompt = f"""You are a factual knowledge extraction engine.
|
||||
Extract factual triples and entities from the text.
|
||||
Preserve lists and definitions accurately.
|
||||
|
||||
Language constraints:
|
||||
- {language_guard}
|
||||
- Preserve original names and domain terms exactly when possible.
|
||||
- JSON keys must stay exactly as: triples, entities, subject, predicate, object.
|
||||
|
||||
Text:
|
||||
{chunk.chunk.text}
|
||||
|
||||
Return ONLY valid JSON:
|
||||
{{
|
||||
"triples": [
|
||||
{{"subject": "Entity", "predicate": "Relationship", "object": "Entity"}}
|
||||
],
|
||||
"entities": ["Entity1", "Entity2"]
|
||||
}}
|
||||
"""
|
||||
result = await llm_func(prompt)
|
||||
|
||||
# 结果保持原样存入 data,后续统一归一化流程会处理
|
||||
# vector_store 侧期望关系字段为 subject/predicate/object 映射形式
|
||||
chunk.data = result
|
||||
return chunk
|
||||
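A short usage sketch of the structure-aware splitter above; the filename and sample text are placeholders, and the LLM extraction step is skipped:

from plugins.A_memorix.core.strategies.factual import FactualStrategy

strategy = FactualStrategy("glossary.md")
chunks = strategy.split("- Term A: definition A\n- Term B: definition B\n")
for c in chunks:
    print(c.chunk.chunk_id, c.type.value, len(c.chunk.text))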
126
plugins/A_memorix/core/strategies/narrative.py
Normal file
@@ -0,0 +1,126 @@
|
||||
import re
|
||||
from typing import List, Dict, Any
|
||||
from .base import BaseStrategy, ProcessedChunk, KnowledgeType, SourceInfo, ChunkContext
|
||||
|
||||
class NarrativeStrategy(BaseStrategy):
|
||||
def split(self, text: str) -> List[ProcessedChunk]:
|
||||
scenes = self._split_into_scenes(text)
|
||||
chunks = []
|
||||
|
||||
for scene_idx, (scene_text, scene_title) in enumerate(scenes):
|
||||
scene_chunks = self._sliding_window(scene_text, scene_title, scene_idx)
|
||||
chunks.extend(scene_chunks)
|
||||
|
||||
return chunks
|
||||
|
||||
def _split_into_scenes(self, text: str) -> List[tuple[str, str]]:
|
||||
"""按标题或分隔符把文本切分为场景。"""
|
||||
# 简单启发式:按 markdown 标题或特定分隔符切分
|
||||
# 该正则匹配以 #、Chapter 或 *** / === 开头的分隔行
|
||||
scene_pattern_str = r'^(?:#{1,6}\s+.*|Chapter\s+\d+|^\*{3,}$|^={3,}$)'
|
||||
|
||||
# 保留分隔符,以便识别场景起点
|
||||
parts = re.split(f"({scene_pattern_str})", text, flags=re.MULTILINE)
|
||||
|
||||
scenes = []
|
||||
current_scene_title = "Start"
|
||||
current_scene_content = []
|
||||
|
||||
if parts and parts[0].strip() == "":
|
||||
parts = parts[1:]
|
||||
|
||||
for part in parts:
|
||||
if re.match(scene_pattern_str, part, re.MULTILINE):
|
||||
# 先保存上一段场景
|
||||
if current_scene_content:
|
||||
scenes.append(("".join(current_scene_content), current_scene_title))
|
||||
current_scene_content = []
|
||||
current_scene_title = part.strip()
|
||||
else:
|
||||
current_scene_content.append(part)
|
||||
|
||||
if current_scene_content:
|
||||
scenes.append(("".join(current_scene_content), current_scene_title))
|
||||
|
||||
# 若未识别到场景,则把全文视作单一场景
|
||||
if not scenes:
|
||||
scenes = [(text, "Whole Text")]
|
||||
|
||||
return scenes
|
||||
|
||||
def _sliding_window(self, text: str, scene_id: str, scene_idx: int, window_size=800, overlap=200) -> List[ProcessedChunk]:
|
||||
chunks = []
|
||||
if len(text) <= window_size:
|
||||
chunks.append(self._create_chunk(text, scene_id, scene_idx, 0, 0))
|
||||
return chunks
|
||||
|
||||
stride = window_size - overlap
|
||||
start = 0
|
||||
local_idx = 0
|
||||
while start < len(text):
|
||||
end = min(start + window_size, len(text))
|
||||
chunk_text = text[start:end]
|
||||
|
||||
# 尽量对齐到最近换行,避免生硬截断句子
|
||||
# 仅在未到文本尾部时进行回退
|
||||
if end < len(text):
|
||||
last_newline = chunk_text.rfind('\n')
|
||||
if last_newline > window_size // 2: # 仅在回退距离可接受时启用
|
||||
end = start + last_newline + 1
|
||||
chunk_text = text[start:end]
|
||||
|
||||
chunks.append(self._create_chunk(chunk_text, scene_id, scene_idx, local_idx, start))
|
||||
|
||||
start += len(chunk_text) - overlap if end < len(text) else len(chunk_text)
|
||||
local_idx += 1
|
||||
|
||||
return chunks
|
||||
|
||||
def _create_chunk(self, text: str, scene_id: str, scene_idx: int, local_idx: int, offset: int) -> ProcessedChunk:
|
||||
return ProcessedChunk(
|
||||
type=KnowledgeType.NARRATIVE,
|
||||
source=SourceInfo(
|
||||
file=self.filename,
|
||||
offset_start=offset,
|
||||
offset_end=offset + len(text),
|
||||
checksum=self.calculate_checksum(text)
|
||||
),
|
||||
chunk=ChunkContext(
|
||||
chunk_id=f"{self.filename}_{scene_idx}_{local_idx}",
|
||||
index=local_idx,
|
||||
text=text,
|
||||
context={"scene_id": scene_id}
|
||||
)
|
||||
)
|
||||
|
||||
async def extract(self, chunk: ProcessedChunk, llm_func=None) -> ProcessedChunk:
|
||||
if not llm_func:
|
||||
raise ValueError("LLM function required for Narrative extraction")
|
||||
|
||||
language_guard = self.build_language_guard(chunk.chunk.text)
|
||||
prompt = f"""You are a narrative knowledge extraction engine.
|
||||
Extract key events and character relations from the scene text.
|
||||
|
||||
Language constraints:
|
||||
- {language_guard}
|
||||
- Preserve original names and terms exactly when possible.
|
||||
- JSON keys must stay exactly as: events, relations, subject, predicate, object.
|
||||
|
||||
Scene:
|
||||
{chunk.chunk.context.get('scene_id')}
|
||||
|
||||
Text:
|
||||
{chunk.chunk.text}
|
||||
|
||||
Return ONLY valid JSON:
|
||||
{{
|
||||
"events": ["event description 1", "event description 2"],
|
||||
"relations": [
|
||||
{{"subject": "CharacterA", "predicate": "relation", "object": "CharacterB"}}
|
||||
]
|
||||
}}
|
||||
"""
|
||||
result = await llm_func(prompt)
|
||||
chunk.data = result
|
||||
return chunk
|
||||
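An illustrative call of the scene splitter plus sliding window defined above; the headings and text are made up:

from plugins.A_memorix.core.strategies.narrative import NarrativeStrategy

strategy = NarrativeStrategy("story.txt")
text = "# Chapter 1\nAlice met Bob by the river.\n# Chapter 2\nThey travelled north."
for c in strategy.split(text):
    print(c.chunk.chunk_id, c.chunk.context.get("scene_id"))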
52
plugins/A_memorix/core/strategies/quote.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from typing import List, Dict, Any
|
||||
from .base import BaseStrategy, ProcessedChunk, KnowledgeType, SourceInfo, ChunkContext, ChunkFlags
|
||||
|
||||
class QuoteStrategy(BaseStrategy):
|
||||
def split(self, text: str) -> List[ProcessedChunk]:
|
||||
# Split by double newlines (stanzas)
|
||||
stanzas = text.split("\n\n")
|
||||
chunks = []
|
||||
offset = 0
|
||||
|
||||
for idx, stanza in enumerate(stanzas):
|
||||
if not stanza.strip():
|
||||
offset += len(stanza) + 2
|
||||
continue
|
||||
|
||||
chunk = ProcessedChunk(
|
||||
type=KnowledgeType.QUOTE,
|
||||
source=SourceInfo(
|
||||
file=self.filename,
|
||||
offset_start=offset,
|
||||
offset_end=offset + len(stanza),
|
||||
checksum=self.calculate_checksum(stanza)
|
||||
),
|
||||
chunk=ChunkContext(
|
||||
chunk_id=f"{self.filename}_{idx}",
|
||||
index=idx,
|
||||
text=stanza
|
||||
),
|
||||
flags=ChunkFlags(
|
||||
verbatim=True,
|
||||
requires_llm=False # Default to no LLM, but can be overridden
|
||||
)
|
||||
)
|
||||
chunks.append(chunk)
|
||||
offset += len(stanza) + 2 # +2 for \n\n
|
||||
|
||||
return chunks
|
||||
|
||||
async def extract(self, chunk: ProcessedChunk, llm_func=None) -> ProcessedChunk:
|
||||
# For quotes, the text itself is the entity/knowledge
|
||||
# We might use LLM to extract headers/metadata if requested, but core logic is pass-through
|
||||
|
||||
# Treat the whole chunk text as a verbatim entity
|
||||
chunk.data = {
|
||||
"verbatim_entities": [chunk.chunk.text]
|
||||
}
|
||||
|
||||
if llm_func and chunk.flags.requires_llm:
|
||||
# Optional: Extract metadata
|
||||
pass
|
||||
|
||||
return chunk
|
||||
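A tiny sketch of the stanza-based split above; quote chunks are marked verbatim and skip LLM extraction by default (filename and text are placeholders):

from plugins.A_memorix.core.strategies.quote import QuoteStrategy

chunks = QuoteStrategy("poem.txt").split("First stanza, line one\nline two\n\nSecond stanza")
print(len(chunks), chunks[0].flags.verbatim, chunks[0].flags.requires_llm)  # 2 True False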
33
plugins/A_memorix/core/utils/__init__.py
Normal file
@@ -0,0 +1,33 @@
|
||||
"""工具模块 - 哈希、监控等辅助功能"""
|
||||
|
||||
from .hash import compute_hash, normalize_text
|
||||
from .monitor import MemoryMonitor
|
||||
from .quantization import quantize_vector, dequantize_vector
|
||||
from .time_parser import (
|
||||
parse_query_datetime_to_timestamp,
|
||||
parse_query_time_range,
|
||||
parse_ingest_datetime_to_timestamp,
|
||||
normalize_time_meta,
|
||||
format_timestamp,
|
||||
)
|
||||
from .relation_write_service import RelationWriteService, RelationWriteResult
|
||||
from .relation_query import RelationQuerySpec, parse_relation_query_spec
|
||||
from .plugin_id_policy import PluginIdPolicy
|
||||
|
||||
__all__ = [
|
||||
"compute_hash",
|
||||
"normalize_text",
|
||||
"MemoryMonitor",
|
||||
"quantize_vector",
|
||||
"dequantize_vector",
|
||||
"parse_query_datetime_to_timestamp",
|
||||
"parse_query_time_range",
|
||||
"parse_ingest_datetime_to_timestamp",
|
||||
"normalize_time_meta",
|
||||
"format_timestamp",
|
||||
"RelationWriteService",
|
||||
"RelationWriteResult",
|
||||
"RelationQuerySpec",
|
||||
"parse_relation_query_spec",
|
||||
"PluginIdPolicy",
|
||||
]
|
||||
360
plugins/A_memorix/core/utils/aggregate_query_service.py
Normal file
@@ -0,0 +1,360 @@
|
||||
"""
|
||||
聚合查询服务:
|
||||
- 并发执行 search/time/episode 分支
|
||||
- 统一分支结果结构
|
||||
- 可选混合排序(Weighted RRF)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("A_Memorix.AggregateQueryService")
|
||||
|
||||
BranchRunner = Callable[[], Awaitable[Dict[str, Any]]]
|
||||
|
||||
|
||||
class AggregateQueryService:
|
||||
"""聚合查询执行服务(search/time/episode)。"""
|
||||
|
||||
def __init__(self, plugin_config: Optional[Any] = None):
|
||||
self.plugin_config = plugin_config or {}
|
||||
|
||||
def _cfg(self, key: str, default: Any = None) -> Any:
|
||||
getter = getattr(self.plugin_config, "get_config", None)
|
||||
if callable(getter):
|
||||
return getter(key, default)
|
||||
|
||||
current: Any = self.plugin_config
|
||||
for part in key.split("."):
|
||||
if isinstance(current, dict) and part in current:
|
||||
current = current[part]
|
||||
else:
|
||||
return default
|
||||
return current
|
||||
|
||||
@staticmethod
|
||||
def _as_float(value: Any, default: float = 0.0) -> float:
|
||||
try:
|
||||
return float(value)
|
||||
except Exception:
|
||||
return float(default)
|
||||
|
||||
@staticmethod
|
||||
def _as_int(value: Any, default: int = 0) -> int:
|
||||
try:
|
||||
return int(value)
|
||||
except Exception:
|
||||
return int(default)
|
||||
|
||||
def _rrf_k(self) -> float:
|
||||
raw = self._cfg("retrieval.aggregate.rrf_k", 60.0)
|
||||
value = self._as_float(raw, 60.0)
|
||||
return max(1.0, value)
|
||||
|
||||
def _weights(self) -> Dict[str, float]:
|
||||
defaults = {"search": 1.0, "time": 1.0, "episode": 1.0}
|
||||
raw = self._cfg("retrieval.aggregate.weights", {})
|
||||
if not isinstance(raw, dict):
|
||||
return defaults
|
||||
|
||||
out = dict(defaults)
|
||||
for key in ("search", "time", "episode"):
|
||||
if key in raw:
|
||||
out[key] = max(0.0, self._as_float(raw.get(key), defaults[key]))
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def _normalize_branch_payload(
|
||||
name: str,
|
||||
payload: Optional[Dict[str, Any]],
|
||||
) -> Dict[str, Any]:
|
||||
data = payload if isinstance(payload, dict) else {}
|
||||
results_raw = data.get("results", [])
|
||||
results = results_raw if isinstance(results_raw, list) else []
|
||||
count = data.get("count")
|
||||
if count is None:
|
||||
count = len(results)
|
||||
return {
|
||||
"name": name,
|
||||
"success": bool(data.get("success", False)),
|
||||
"skipped": bool(data.get("skipped", False)),
|
||||
"skip_reason": str(data.get("skip_reason", "") or "").strip(),
|
||||
"error": str(data.get("error", "") or "").strip(),
|
||||
"results": results,
|
||||
"count": max(0, int(count)),
|
||||
"elapsed_ms": max(0.0, float(data.get("elapsed_ms", 0.0) or 0.0)),
|
||||
"content": str(data.get("content", "") or ""),
|
||||
"query_type": str(data.get("query_type", "") or name),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _mix_key(item: Dict[str, Any], branch: str, rank: int) -> str:
|
||||
item_type = str(item.get("type", "") or "").strip().lower()
|
||||
if item_type == "episode":
|
||||
episode_id = str(item.get("episode_id", "") or "").strip()
|
||||
if episode_id:
|
||||
return f"episode:{episode_id}"
|
||||
|
||||
item_hash = str(item.get("hash", "") or "").strip()
|
||||
if item_hash:
|
||||
return f"{item_type}:{item_hash}"
|
||||
|
||||
return f"{branch}:{item_type}:{rank}:{str(item.get('content', '') or '')[:80]}"
|
||||
|
||||
def _build_mixed_results(
|
||||
self,
|
||||
*,
|
||||
branches: Dict[str, Dict[str, Any]],
|
||||
top_k: int,
|
||||
) -> List[Dict[str, Any]]:
|
||||
rrf_k = self._rrf_k()
|
||||
weights = self._weights()
|
||||
bucket: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
for branch_name, branch in branches.items():
|
||||
if not branch.get("success", False):
|
||||
continue
|
||||
results = branch.get("results", [])
|
||||
if not isinstance(results, list):
|
||||
continue
|
||||
|
||||
weight = max(0.0, float(weights.get(branch_name, 1.0)))
|
||||
for idx, item in enumerate(results, start=1):
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
key = self._mix_key(item, branch_name, idx)
|
||||
score = weight / (rrf_k + float(idx))
|
||||
if key not in bucket:
|
||||
merged = dict(item)
|
||||
merged["fusion_score"] = 0.0
|
||||
merged["_source_branches"] = set()
|
||||
bucket[key] = merged
|
||||
|
||||
target = bucket[key]
|
||||
target["fusion_score"] = float(target.get("fusion_score", 0.0)) + score
|
||||
target["_source_branches"].add(branch_name)
|
||||
|
||||
mixed = list(bucket.values())
|
||||
mixed.sort(
|
||||
key=lambda x: (
|
||||
-float(x.get("fusion_score", 0.0)),
|
||||
str(x.get("type", "") or ""),
|
||||
str(x.get("hash", "") or x.get("episode_id", "") or ""),
|
||||
)
|
||||
)
|
||||
|
||||
out: List[Dict[str, Any]] = []
|
||||
for rank, item in enumerate(mixed[: max(1, int(top_k))], start=1):
|
||||
merged = dict(item)
|
||||
branches_set = merged.pop("_source_branches", set())
|
||||
merged["source_branches"] = sorted(list(branches_set))
|
||||
merged["rank"] = rank
|
||||
out.append(merged)
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def _status(branch: Dict[str, Any]) -> str:
|
||||
if branch.get("skipped", False):
|
||||
return "skipped"
|
||||
if branch.get("success", False):
|
||||
return "success"
|
||||
return "failed"
|
||||
|
||||
def _build_summary(self, branches: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
|
||||
summary: Dict[str, Dict[str, Any]] = {}
|
||||
for name, branch in branches.items():
|
||||
status = self._status(branch)
|
||||
summary[name] = {
|
||||
"status": status,
|
||||
"count": int(branch.get("count", 0) or 0),
|
||||
}
|
||||
if status == "skipped":
|
||||
summary[name]["reason"] = str(branch.get("skip_reason", "") or "")
|
||||
if status == "failed":
|
||||
summary[name]["error"] = str(branch.get("error", "") or "")
|
||||
return summary
|
||||
|
||||
def _build_content(
|
||||
self,
|
||||
*,
|
||||
query: str,
|
||||
branches: Dict[str, Dict[str, Any]],
|
||||
errors: List[Dict[str, str]],
|
||||
mixed_results: Optional[List[Dict[str, Any]]],
|
||||
) -> str:
|
||||
lines: List[str] = [
|
||||
f"🔀 聚合查询结果(query='{query or 'N/A'}')",
|
||||
"",
|
||||
"分支状态:",
|
||||
]
|
||||
for name in ("search", "time", "episode"):
|
||||
branch = branches.get(name, {})
|
||||
status = self._status(branch)
|
||||
count = int(branch.get("count", 0) or 0)
|
||||
line = f"- {name}: {status}, count={count}"
|
||||
reason = str(branch.get("skip_reason", "") or "").strip()
|
||||
err = str(branch.get("error", "") or "").strip()
|
||||
if status == "skipped" and reason:
|
||||
line += f" ({reason})"
|
||||
if status == "failed" and err:
|
||||
line += f" ({err})"
|
||||
lines.append(line)
|
||||
|
||||
if errors:
|
||||
lines.append("")
|
||||
lines.append("错误:")
|
||||
for item in errors[:6]:
|
||||
lines.append(f"- {item.get('branch', 'unknown')}: {item.get('error', 'unknown error')}")
|
||||
|
||||
if mixed_results is not None:
|
||||
lines.append("")
|
||||
lines.append(f"🧩 混合结果({len(mixed_results)} 条):")
|
||||
for idx, item in enumerate(mixed_results[:5], start=1):
|
||||
src = ",".join(item.get("source_branches", []) or [])
|
||||
if str(item.get("type", "") or "") == "episode":
|
||||
title = str(item.get("title", "") or "Untitled")
|
||||
lines.append(f"{idx}. 🧠 {title} [{src}]")
|
||||
else:
|
||||
text = str(item.get("content", "") or "")
|
||||
if len(text) > 80:
|
||||
text = text[:80] + "..."
|
||||
lines.append(f"{idx}. {text} [{src}]")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
*,
|
||||
query: str,
|
||||
top_k: int,
|
||||
mix: bool,
|
||||
mix_top_k: Optional[int],
|
||||
time_from: Optional[str],
|
||||
time_to: Optional[str],
|
||||
search_runner: Optional[BranchRunner],
|
||||
time_runner: Optional[BranchRunner],
|
||||
episode_runner: Optional[BranchRunner],
|
||||
) -> Dict[str, Any]:
|
||||
clean_query = str(query or "").strip()
|
||||
safe_top_k = max(1, int(top_k))
|
||||
safe_mix_top_k = max(1, int(mix_top_k if mix_top_k is not None else safe_top_k))
|
||||
|
||||
branches: Dict[str, Dict[str, Any]] = {}
|
||||
errors: List[Dict[str, str]] = []
|
||||
scheduled: List[Tuple[str, asyncio.Task]] = []
|
||||
|
||||
if clean_query:
|
||||
if search_runner is not None:
|
||||
scheduled.append(("search", asyncio.create_task(search_runner())))
|
||||
else:
|
||||
branches["search"] = self._normalize_branch_payload(
|
||||
"search",
|
||||
{"success": False, "error": "search runner unavailable", "results": []},
|
||||
)
|
||||
else:
|
||||
branches["search"] = self._normalize_branch_payload(
|
||||
"search",
|
||||
{
|
||||
"success": False,
|
||||
"skipped": True,
|
||||
"skip_reason": "missing_query",
|
||||
"results": [],
|
||||
"count": 0,
|
||||
},
|
||||
)
|
||||
|
||||
if time_from or time_to:
|
||||
if time_runner is not None:
|
||||
scheduled.append(("time", asyncio.create_task(time_runner())))
|
||||
else:
|
||||
branches["time"] = self._normalize_branch_payload(
|
||||
"time",
|
||||
{"success": False, "error": "time runner unavailable", "results": []},
|
||||
)
|
||||
else:
|
||||
branches["time"] = self._normalize_branch_payload(
|
||||
"time",
|
||||
{
|
||||
"success": False,
|
||||
"skipped": True,
|
||||
"skip_reason": "missing_time_window",
|
||||
"results": [],
|
||||
"count": 0,
|
||||
},
|
||||
)
|
||||
|
||||
if episode_runner is not None:
|
||||
scheduled.append(("episode", asyncio.create_task(episode_runner())))
|
||||
else:
|
||||
branches["episode"] = self._normalize_branch_payload(
|
||||
"episode",
|
||||
{"success": False, "error": "episode runner unavailable", "results": []},
|
||||
)
|
||||
|
||||
if scheduled:
|
||||
done = await asyncio.gather(
|
||||
*[task for _, task in scheduled],
|
||||
return_exceptions=True,
|
||||
)
|
||||
for (branch_name, _), payload in zip(scheduled, done):
|
||||
if isinstance(payload, Exception):
|
||||
logger.error("aggregate branch failed: branch=%s error=%s", branch_name, payload)
|
||||
normalized = self._normalize_branch_payload(
|
||||
branch_name,
|
||||
{
|
||||
"success": False,
|
||||
"error": str(payload),
|
||||
"results": [],
|
||||
},
|
||||
)
|
||||
else:
|
||||
normalized = self._normalize_branch_payload(branch_name, payload)
|
||||
branches[branch_name] = normalized
|
||||
|
||||
for name in ("search", "time", "episode"):
|
||||
branch = branches.get(name)
|
||||
if not branch:
|
||||
continue
|
||||
if branch.get("skipped", False):
|
||||
continue
|
||||
if not branch.get("success", False):
|
||||
errors.append(
|
||||
{
|
||||
"branch": name,
|
||||
"error": str(branch.get("error", "") or "unknown error"),
|
||||
}
|
||||
)
|
||||
|
||||
success = any(
|
||||
bool(branches.get(name, {}).get("success", False))
|
||||
for name in ("search", "time", "episode")
|
||||
)
|
||||
mixed_results: Optional[List[Dict[str, Any]]] = None
|
||||
if mix:
|
||||
mixed_results = self._build_mixed_results(branches=branches, top_k=safe_mix_top_k)
|
||||
|
||||
payload: Dict[str, Any] = {
|
||||
"success": success,
|
||||
"query_type": "aggregate",
|
||||
"query": clean_query,
|
||||
"top_k": safe_top_k,
|
||||
"mix": bool(mix),
|
||||
"mix_top_k": safe_mix_top_k,
|
||||
"branches": branches,
|
||||
"errors": errors,
|
||||
"summary": self._build_summary(branches),
|
||||
}
|
||||
if mixed_results is not None:
|
||||
payload["mixed_results"] = mixed_results
|
||||
|
||||
payload["content"] = self._build_content(
|
||||
query=clean_query,
|
||||
branches=branches,
|
||||
errors=errors,
|
||||
mixed_results=mixed_results,
|
||||
)
|
||||
return payload
|
||||
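A hedged sketch of driving AggregateQueryService with stub branch runners. The config dict, runner payload and event-loop setup are assumptions for illustration; the real plugin wires actual search/time/episode handlers, and importing this module requires the host environment that provides src.common.logger:

import asyncio
from plugins.A_memorix.core.utils.aggregate_query_service import AggregateQueryService

async def fake_search():
    return {"success": True, "results": [{"type": "paragraph", "hash": "h1", "content": "sample hit"}]}

async def main():
    svc = AggregateQueryService({"retrieval": {"aggregate": {"rrf_k": 60.0}}})
    payload = await svc.execute(
        query="weekend trip", top_k=5, mix=True, mix_top_k=5,
        time_from=None, time_to=None,
        search_runner=fake_search, time_runner=None, episode_runner=None,
    )
    print(payload["summary"])   # per-branch status/count; time branch is skipped, episode reports an error

asyncio.run(main())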
182
plugins/A_memorix/core/utils/episode_retrieval_service.py
Normal file
@@ -0,0 +1,182 @@
|
||||
"""Episode hybrid retrieval service."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from ..retrieval import DualPathRetriever, TemporalQueryOptions
|
||||
|
||||
logger = get_logger("A_Memorix.EpisodeRetrievalService")
|
||||
|
||||
|
||||
class EpisodeRetrievalService:
|
||||
"""Hybrid episode retrieval backed by lexical rows and evidence projection."""
|
||||
|
||||
_RRF_K = 60.0
|
||||
_BRANCH_WEIGHTS = {
|
||||
"lexical": 1.0,
|
||||
"paragraph_evidence": 1.0,
|
||||
"relation_evidence": 0.85,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
metadata_store: Any,
|
||||
retriever: Optional[DualPathRetriever] = None,
|
||||
) -> None:
|
||||
self.metadata_store = metadata_store
|
||||
self.retriever = retriever
|
||||
|
||||
async def query(
|
||||
self,
|
||||
*,
|
||||
query: str = "",
|
||||
top_k: int = 5,
|
||||
time_from: Optional[float] = None,
|
||||
time_to: Optional[float] = None,
|
||||
person: Optional[str] = None,
|
||||
source: Optional[str] = None,
|
||||
include_paragraphs: bool = False,
|
||||
) -> List[Dict[str, Any]]:
|
||||
clean_query = str(query or "").strip()
|
||||
safe_top_k = max(1, int(top_k))
|
||||
candidate_k = max(30, safe_top_k * 6)
|
||||
|
||||
branches: Dict[str, List[Dict[str, Any]]] = {
|
||||
"lexical": self.metadata_store.query_episodes(
|
||||
query=clean_query,
|
||||
time_from=time_from,
|
||||
time_to=time_to,
|
||||
person=person,
|
||||
source=source,
|
||||
limit=(candidate_k if clean_query else safe_top_k),
|
||||
)
|
||||
}
|
||||
|
||||
if clean_query and self.retriever is not None:
|
||||
try:
|
||||
temporal = TemporalQueryOptions(
|
||||
time_from=time_from,
|
||||
time_to=time_to,
|
||||
person=person,
|
||||
source=source,
|
||||
)
|
||||
results = await self.retriever.retrieve(
|
||||
query=clean_query,
|
||||
top_k=candidate_k,
|
||||
temporal=temporal,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("episode evidence retrieval failed, fallback to lexical only: %s", exc)
|
||||
else:
|
||||
paragraph_rank_map: Dict[str, int] = {}
|
||||
relation_rank_map: Dict[str, int] = {}
|
||||
for rank, item in enumerate(results, start=1):
|
||||
hash_value = str(getattr(item, "hash_value", "") or "").strip()
|
||||
result_type = str(getattr(item, "result_type", "") or "").strip().lower()
|
||||
if not hash_value:
|
||||
continue
|
||||
if result_type == "paragraph" and hash_value not in paragraph_rank_map:
|
||||
paragraph_rank_map[hash_value] = rank
|
||||
elif result_type == "relation" and hash_value not in relation_rank_map:
|
||||
relation_rank_map[hash_value] = rank
|
||||
|
||||
if paragraph_rank_map:
|
||||
paragraph_rows = self.metadata_store.get_episode_rows_by_paragraph_hashes(
|
||||
list(paragraph_rank_map.keys()),
|
||||
source=source,
|
||||
)
|
||||
if paragraph_rows:
|
||||
branches["paragraph_evidence"] = self._rank_projected_rows(
|
||||
paragraph_rows,
|
||||
rank_map=paragraph_rank_map,
|
||||
support_key="matched_paragraph_hashes",
|
||||
)
|
||||
|
||||
if relation_rank_map:
|
||||
relation_rows = self.metadata_store.get_episode_rows_by_relation_hashes(
|
||||
list(relation_rank_map.keys()),
|
||||
source=source,
|
||||
)
|
||||
if relation_rows:
|
||||
branches["relation_evidence"] = self._rank_projected_rows(
|
||||
relation_rows,
|
||||
rank_map=relation_rank_map,
|
||||
support_key="matched_relation_hashes",
|
||||
)
|
||||
|
||||
fused = self._fuse_branches(branches, top_k=safe_top_k)
|
||||
if include_paragraphs:
|
||||
for item in fused:
|
||||
item["paragraphs"] = self.metadata_store.get_episode_paragraphs(
|
||||
episode_id=str(item.get("episode_id") or ""),
|
||||
limit=50,
|
||||
)
|
||||
return fused
|
||||
|
||||
@staticmethod
|
||||
def _rank_projected_rows(
|
||||
rows: List[Dict[str, Any]],
|
||||
*,
|
||||
rank_map: Dict[str, int],
|
||||
support_key: str,
|
||||
) -> List[Dict[str, Any]]:
|
||||
sentinel = 10**9
|
||||
ranked = [dict(item) for item in rows]
|
||||
|
||||
def _first_support_rank(item: Dict[str, Any]) -> int:
|
||||
support_hashes = [str(x or "").strip() for x in (item.get(support_key) or [])]
|
||||
ranks = [int(rank_map[h]) for h in support_hashes if h in rank_map]
|
||||
return min(ranks) if ranks else sentinel
|
||||
|
||||
ranked.sort(
|
||||
key=lambda item: (
|
||||
_first_support_rank(item),
|
||||
-int(item.get("matched_paragraph_count") or 0),
|
||||
-float(item.get("updated_at") or 0.0),
|
||||
str(item.get("episode_id") or ""),
|
||||
)
|
||||
)
|
||||
return ranked
|
||||
|
||||
def _fuse_branches(
|
||||
self,
|
||||
branches: Dict[str, List[Dict[str, Any]]],
|
||||
*,
|
||||
top_k: int,
|
||||
) -> List[Dict[str, Any]]:
|
||||
bucket: Dict[str, Dict[str, Any]] = {}
|
||||
for branch_name, rows in branches.items():
|
||||
weight = float(self._BRANCH_WEIGHTS.get(branch_name, 0.0) or 0.0)
|
||||
if weight <= 0.0:
|
||||
continue
|
||||
for rank, row in enumerate(rows, start=1):
|
||||
episode_id = str(row.get("episode_id", "") or "").strip()
|
||||
if not episode_id:
|
||||
continue
|
||||
if episode_id not in bucket:
|
||||
payload = dict(row)
|
||||
payload.pop("matched_paragraph_hashes", None)
|
||||
payload.pop("matched_relation_hashes", None)
|
||||
payload.pop("matched_paragraph_count", None)
|
||||
payload.pop("matched_relation_count", None)
|
||||
payload["_fusion_score"] = 0.0
|
||||
bucket[episode_id] = payload
|
||||
bucket[episode_id]["_fusion_score"] = float(
|
||||
bucket[episode_id].get("_fusion_score", 0.0)
|
||||
) + weight / (self._RRF_K + float(rank))
|
||||
|
||||
out = list(bucket.values())
|
||||
out.sort(
|
||||
key=lambda item: (
|
||||
-float(item.get("_fusion_score", 0.0)),
|
||||
-float(item.get("updated_at") or 0.0),
|
||||
str(item.get("episode_id") or ""),
|
||||
)
|
||||
)
|
||||
for item in out:
|
||||
item.pop("_fusion_score", None)
|
||||
return out[: max(1, int(top_k))]
|
||||
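The fusion step above is a weighted reciprocal-rank fusion. Below is a minimal standalone sketch of the same scoring rule, using the _RRF_K = 60 constant and the branch weights defined in this file; the rrf_score helper is illustrative only and not part of the plugin.

# Illustrative only: weighted RRF score, mirroring _fuse_branches above.
def rrf_score(ranks_by_branch: dict, weights: dict, k: float = 60.0) -> float:
    # Each branch contributes weight / (k + rank); branches that missed the episode add nothing.
    return sum(weights.get(branch, 0.0) / (k + rank) for branch, rank in ranks_by_branch.items())

weights = {"lexical": 1.0, "paragraph_evidence": 1.0, "relation_evidence": 0.85}
# An episode ranked 1st in "lexical" and 3rd in "paragraph_evidence":
score = rrf_score({"lexical": 1, "paragraph_evidence": 3}, weights)
print(round(score, 4))  # 1/61 + 1/63 ~= 0.0323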
129
plugins/A_memorix/core/utils/hash.py
Normal file
@@ -0,0 +1,129 @@
|
||||
"""
|
||||
哈希工具模块
|
||||
|
||||
提供文本哈希计算功能,用于唯一标识和去重。
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import re
|
||||
from typing import Union
|
||||
|
||||
|
||||
def compute_hash(text: str, hash_type: str = "sha256") -> str:
|
||||
"""
|
||||
计算文本的哈希值
|
||||
|
||||
Args:
|
||||
text: 输入文本
|
||||
hash_type: 哈希算法类型(sha256, md5等)
|
||||
|
||||
Returns:
|
||||
哈希值字符串
|
||||
"""
|
||||
if hash_type == "sha256":
|
||||
return hashlib.sha256(text.encode("utf-8")).hexdigest()
|
||||
elif hash_type == "md5":
|
||||
return hashlib.md5(text.encode("utf-8")).hexdigest()
|
||||
else:
|
||||
raise ValueError(f"不支持的哈希算法: {hash_type}")
|
||||
|
||||
|
||||
def normalize_text(text: str) -> str:
|
||||
"""
|
||||
规范化文本用于哈希计算
|
||||
|
||||
执行以下操作:
|
||||
- 去除首尾空白
|
||||
- 统一换行符为\\n
|
||||
- 压缩多个连续空格
|
||||
- 去除不可见字符(保留换行和制表符)
|
||||
|
||||
Args:
|
||||
text: 输入文本
|
||||
|
||||
Returns:
|
||||
规范化后的文本
|
||||
"""
|
||||
# 去除首尾空白
|
||||
text = text.strip()
|
||||
|
||||
# 统一换行符
|
||||
text = text.replace("\r\n", "\n").replace("\r", "\n")
|
||||
|
||||
# 压缩多个连续空格为一个(但保留换行和制表符)
|
||||
text = re.sub(r"[^\S\n]+", " ", text)
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def compute_paragraph_hash(paragraph: str) -> str:
|
||||
"""
|
||||
计算段落的哈希值
|
||||
|
||||
Args:
|
||||
paragraph: 段落文本
|
||||
|
||||
Returns:
|
||||
段落哈希值(用于paragraph-前缀)
|
||||
"""
|
||||
normalized = normalize_text(paragraph)
|
||||
return compute_hash(normalized)
|
||||
|
||||
|
||||
def compute_entity_hash(entity: str) -> str:
|
||||
"""
|
||||
计算实体的哈希值
|
||||
|
||||
Args:
|
||||
entity: 实体名称
|
||||
|
||||
Returns:
|
||||
实体哈希值(用于entity-前缀)
|
||||
"""
|
||||
normalized = entity.strip().lower()
|
||||
return compute_hash(normalized)
|
||||
|
||||
|
||||
def compute_relation_hash(relation: tuple) -> str:
|
||||
"""
|
||||
计算关系的哈希值
|
||||
|
||||
Args:
|
||||
relation: 关系元组 (subject, predicate, object)
|
||||
|
||||
Returns:
|
||||
关系哈希值(用于relation-前缀)
|
||||
"""
|
||||
# 将关系元组转为字符串
|
||||
relation_str = str(tuple(relation))
|
||||
return compute_hash(relation_str)
|
||||
|
||||
|
||||
def format_hash_key(hash_type: str, hash_value: str) -> str:
|
||||
"""
|
||||
格式化哈希键
|
||||
|
||||
Args:
|
||||
hash_type: 类型前缀(paragraph, entity, relation)
|
||||
hash_value: 哈希值
|
||||
|
||||
Returns:
|
||||
格式化的键(如 paragraph-abc123...)
|
||||
"""
|
||||
return f"{hash_type}-{hash_value}"
|
||||
|
||||
|
||||
def parse_hash_key(key: str) -> tuple[str, str]:
|
||||
"""
|
||||
解析哈希键
|
||||
|
||||
Args:
|
||||
key: 格式化的键(如 paragraph-abc123...)
|
||||
|
||||
Returns:
|
||||
(类型, 哈希值) 元组
|
||||
"""
|
||||
parts = key.split("-", 1)
|
||||
if len(parts) != 2:
|
||||
raise ValueError(f"无效的哈希键格式: {key}")
|
||||
return parts[0], parts[1]
|
||||
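A short round-trip sketch for the hash helpers defined above; the sample paragraph text is invented.

text = "Hello   world\r\nsecond line"      # made-up input
h = compute_paragraph_hash(text)           # sha256 of the normalized text
key = format_hash_key("paragraph", h)      # "paragraph-<64 hex chars>"
kind, value = parse_hash_key(key)
assert kind == "paragraph" and value == h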
110
plugins/A_memorix/core/utils/import_payloads.py
Normal file
@@ -0,0 +1,110 @@
|
||||
"""Shared import payload normalization helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from ..storage import KnowledgeType, resolve_stored_knowledge_type
|
||||
from .time_parser import normalize_time_meta
|
||||
|
||||
|
||||
def _normalize_entities(raw_entities: Any) -> List[str]:
|
||||
if not isinstance(raw_entities, list):
|
||||
return []
|
||||
out: List[str] = []
|
||||
seen = set()
|
||||
for item in raw_entities:
|
||||
name = str(item or "").strip()
|
||||
if not name:
|
||||
continue
|
||||
key = name.lower()
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
out.append(name)
|
||||
return out
|
||||
|
||||
|
||||
def _normalize_relations(raw_relations: Any) -> List[Dict[str, str]]:
|
||||
if not isinstance(raw_relations, list):
|
||||
return []
|
||||
out: List[Dict[str, str]] = []
|
||||
for item in raw_relations:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
subject = str(item.get("subject", "")).strip()
|
||||
predicate = str(item.get("predicate", "")).strip()
|
||||
obj = str(item.get("object", "")).strip()
|
||||
if not (subject and predicate and obj):
|
||||
continue
|
||||
out.append(
|
||||
{
|
||||
"subject": subject,
|
||||
"predicate": predicate,
|
||||
"object": obj,
|
||||
}
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def normalize_paragraph_import_item(
|
||||
item: Any,
|
||||
*,
|
||||
default_source: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""Normalize one paragraph import item from text/json payloads."""
|
||||
|
||||
if isinstance(item, str):
|
||||
content = str(item)
|
||||
knowledge_type = resolve_stored_knowledge_type(None, content=content)
|
||||
return {
|
||||
"content": content,
|
||||
"knowledge_type": knowledge_type.value,
|
||||
"source": str(default_source or "").strip(),
|
||||
"time_meta": None,
|
||||
"entities": [],
|
||||
"relations": [],
|
||||
}
|
||||
|
||||
if not isinstance(item, dict) or "content" not in item:
|
||||
raise ValueError("段落项必须为字符串或包含 content 的对象")
|
||||
|
||||
content = str(item.get("content", "") or "")
|
||||
if not content.strip():
|
||||
raise ValueError("段落 content 不能为空")
|
||||
|
||||
raw_time_meta = {
|
||||
"event_time": item.get("event_time"),
|
||||
"event_time_start": item.get("event_time_start"),
|
||||
"event_time_end": item.get("event_time_end"),
|
||||
"time_range": item.get("time_range"),
|
||||
"time_granularity": item.get("time_granularity"),
|
||||
"time_confidence": item.get("time_confidence"),
|
||||
}
|
||||
time_meta_field = item.get("time_meta")
|
||||
if isinstance(time_meta_field, dict):
|
||||
raw_time_meta.update(time_meta_field)
|
||||
|
||||
knowledge_type_raw = item.get("knowledge_type")
|
||||
if knowledge_type_raw is None:
|
||||
knowledge_type_raw = item.get("type")
|
||||
knowledge_type = resolve_stored_knowledge_type(knowledge_type_raw, content=content)
|
||||
source = str(item.get("source") or default_source or "").strip()
|
||||
if not source:
|
||||
source = str(default_source or "").strip()
|
||||
|
||||
normalized_time_meta = normalize_time_meta(raw_time_meta)
|
||||
return {
|
||||
"content": content,
|
||||
"knowledge_type": knowledge_type.value,
|
||||
"source": source,
|
||||
"time_meta": normalized_time_meta if normalized_time_meta else None,
|
||||
"entities": _normalize_entities(item.get("entities")),
|
||||
"relations": _normalize_relations(item.get("relations")),
|
||||
}
|
||||
|
||||
|
||||
def normalize_summary_knowledge_type(value: Any) -> KnowledgeType:
|
||||
"""Normalize config-driven summary knowledge type."""
|
||||
|
||||
return resolve_stored_knowledge_type(value, content="")
|
||||
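A minimal sketch of normalizing one paragraph import item; the field values are invented, and the resulting knowledge_type depends on resolve_stored_knowledge_type.

item = {
    "content": "Alice joined the project team on 2024-01-01",
    "source": "chat_import",
    "entities": ["Alice", "project team", "alice"],   # duplicates are dropped case-insensitively
    "relations": [{"subject": "Alice", "predicate": "joined", "object": "project team"}],
    "event_time": "2024-01-01",
}
normalized = normalize_paragraph_import_item(item, default_source="manual_import")
# normalized["entities"] == ["Alice", "project team"]; normalized["source"] == "chat_import"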
84
plugins/A_memorix/core/utils/io.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""
|
||||
IO Utilities
|
||||
|
||||
提供原子文件写入等IO辅助功能。
|
||||
"""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import contextlib
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
@contextlib.contextmanager
|
||||
def atomic_write(file_path: Union[str, Path], mode: str = "w", encoding: str = None, **kwargs):
|
||||
"""
|
||||
原子文件写入上下文管理器
|
||||
|
||||
原理:
|
||||
1. 写入 .tmp 临时文件
|
||||
2. 写入成功后,使用 os.replace 原子替换目标文件
|
||||
3. 如果失败,自动删除临时文件
|
||||
|
||||
Args:
|
||||
file_path: 目标文件路径
|
||||
mode: 打开模式 ('w', 'wb' 等)
|
||||
encoding: 编码
|
||||
**kwargs: 传给 open() 的其他参数
|
||||
"""
|
||||
path = Path(file_path)
|
||||
# 确保父目录存在
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 临时文件路径
|
||||
tmp_path = path.with_suffix(path.suffix + ".tmp")
|
||||
|
||||
try:
|
||||
with open(tmp_path, mode, encoding=encoding, **kwargs) as f:
|
||||
yield f
|
||||
|
||||
# 确保写入磁盘
|
||||
f.flush()
|
||||
os.fsync(f.fileno())
|
||||
|
||||
# 原子替换:os.replace(Python 3.3+)在同一文件系统内是原子操作
# 注意:Windows 上如果目标文件被其他进程占用,os.replace 可能失败
|
||||
os.replace(tmp_path, path)
|
||||
|
||||
except Exception as e:
|
||||
# 清理临时文件
|
||||
if tmp_path.exists():
|
||||
try:
|
||||
os.remove(tmp_path)
|
||||
except OSError:
|
||||
pass
|
||||
raise e
|
||||
|
||||
@contextlib.contextmanager
|
||||
def atomic_save_path(file_path: Union[str, Path]):
|
||||
"""
|
||||
提供临时路径用于原子保存 (针对只接受路径的API,如Faiss)
|
||||
|
||||
Args:
|
||||
file_path: 最终目标路径
|
||||
|
||||
Yields:
|
||||
tmp_path: 临时文件路径 (str)
|
||||
"""
|
||||
path = Path(file_path)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
tmp_path = path.with_suffix(path.suffix + ".tmp")
|
||||
|
||||
try:
|
||||
yield str(tmp_path)
|
||||
|
||||
if Path(tmp_path).exists():
|
||||
os.replace(tmp_path, path)
|
||||
|
||||
except Exception as e:
|
||||
if Path(tmp_path).exists():
|
||||
try:
|
||||
os.remove(tmp_path)
|
||||
except OSError:
|
||||
pass
|
||||
raise e
|
||||
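A brief usage sketch of atomic_write and atomic_save_path; the target paths are hypothetical. The point is that readers never observe a half-written file.

import json

with atomic_write("data/a_memorix/meta.json", "w", encoding="utf-8") as f:
    json.dump({"version": 1, "items": []}, f, ensure_ascii=False)

# For path-based APIs (e.g. an index object's save(path)), write to the temp path;
# the context manager swaps it into place on success.
with atomic_save_path("data/a_memorix/index.bin") as tmp_path:
    with open(tmp_path, "wb") as f:
        f.write(b"\x00" * 16)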
89
plugins/A_memorix/core/utils/matcher.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""
|
||||
高效文本匹配工具模块
|
||||
|
||||
实现 Aho-Corasick 算法用于多模式匹配。
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Tuple, Set, Any
|
||||
from collections import deque
|
||||
|
||||
|
||||
class AhoCorasick:
|
||||
"""
|
||||
Aho-Corasick 自动机实现高效多模式匹配
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# next_states[state][char] = next_state
|
||||
self.next_states: List[Dict[str, int]] = [{}]
|
||||
# fail[state] = fail_state
|
||||
self.fail: List[int] = [0]
|
||||
# output[state] = set of patterns ending at this state
|
||||
self.output: List[Set[str]] = [set()]
|
||||
self.patterns: Set[str] = set()
|
||||
|
||||
def add_pattern(self, pattern: str):
|
||||
"""添加模式"""
|
||||
if not pattern:
|
||||
return
|
||||
self.patterns.add(pattern)
|
||||
state = 0
|
||||
for char in pattern:
|
||||
if char not in self.next_states[state]:
|
||||
new_state = len(self.next_states)
|
||||
self.next_states[state][char] = new_state
|
||||
self.next_states.append({})
|
||||
self.fail.append(0)
|
||||
self.output.append(set())
|
||||
state = self.next_states[state][char]
|
||||
self.output[state].add(pattern)
|
||||
|
||||
def build(self):
|
||||
"""构建失败指针"""
|
||||
queue = deque()
|
||||
# 处理第一层
|
||||
for char, state in self.next_states[0].items():
|
||||
queue.append(state)
|
||||
self.fail[state] = 0
|
||||
|
||||
while queue:
|
||||
r = queue.popleft()
|
||||
for char, s in self.next_states[r].items():
|
||||
queue.append(s)
|
||||
# 找到失败路径
|
||||
state = self.fail[r]
|
||||
while char not in self.next_states[state] and state != 0:
|
||||
state = self.fail[state]
|
||||
self.fail[s] = self.next_states[state].get(char, 0)
|
||||
# 合并输出
|
||||
self.output[s].update(self.output[self.fail[s]])
|
||||
|
||||
def search(self, text: str) -> List[Tuple[int, str]]:
|
||||
"""
|
||||
在文本中搜索所有模式
|
||||
|
||||
Returns:
|
||||
[(结束索引, 匹配到的模式), ...]
|
||||
"""
|
||||
state = 0
|
||||
results = []
|
||||
for i, char in enumerate(text):
|
||||
while char not in self.next_states[state] and state != 0:
|
||||
state = self.fail[state]
|
||||
state = self.next_states[state].get(char, 0)
|
||||
for pattern in self.output[state]:
|
||||
results.append((i, pattern))
|
||||
return results
|
||||
|
||||
def find_all(self, text: str) -> Dict[str, int]:
|
||||
"""
|
||||
查找并统计所有模式出现次数
|
||||
|
||||
Returns:
|
||||
{模式: 出现次数}
|
||||
"""
|
||||
results = self.search(text)
|
||||
stats = {}
|
||||
for _, pattern in results:
|
||||
stats[pattern] = stats.get(pattern, 0) + 1
|
||||
return stats
|
||||
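A minimal usage sketch of the AhoCorasick class above; the pattern and text strings are invented.

ac = AhoCorasick()
for pattern in ("memory", "memo", "retrieval"):
    ac.add_pattern(pattern)
ac.build()  # failure links must be built after all patterns are added

hits = ac.search("memory retrieval uses a memo table")
# hits is a list of (end_index, pattern); overlapping matches like "memo"/"memory" both appear
counts = ac.find_all("memory retrieval uses a memo table")
# e.g. {"memo": 2, "memory": 1, "retrieval": 1}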
189
plugins/A_memorix/core/utils/monitor.py
Normal file
@@ -0,0 +1,189 @@
|
||||
"""
|
||||
内存监控模块
|
||||
|
||||
提供内存使用监控和预警功能。
|
||||
"""
|
||||
|
||||
import gc
|
||||
import threading
|
||||
import time
|
||||
from typing import Callable, Optional
|
||||
|
||||
try:
|
||||
import psutil
|
||||
HAS_PSUTIL = True
|
||||
except ImportError:
|
||||
HAS_PSUTIL = False
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("A_Memorix.MemoryMonitor")
|
||||
|
||||
|
||||
class MemoryMonitor:
|
||||
"""
|
||||
内存监控器
|
||||
|
||||
功能:
|
||||
- 实时监控内存使用
|
||||
- 超过阈值时触发警告
|
||||
- 支持自动垃圾回收
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_memory_mb: int,
|
||||
warning_threshold: float = 0.9,
|
||||
check_interval: float = 10.0,
|
||||
enable_auto_gc: bool = True,
|
||||
):
|
||||
"""
|
||||
初始化内存监控器
|
||||
|
||||
Args:
|
||||
max_memory_mb: 最大内存限制(MB)
|
||||
warning_threshold: 警告阈值(0-1之间,默认0.9表示90%)
|
||||
check_interval: 检查间隔(秒)
|
||||
enable_auto_gc: 是否启用自动垃圾回收
|
||||
"""
|
||||
self.max_memory_mb = max_memory_mb
|
||||
self.warning_threshold = warning_threshold
|
||||
self.check_interval = check_interval
|
||||
self.enable_auto_gc = enable_auto_gc
|
||||
|
||||
self._running = False
|
||||
self._thread: Optional[threading.Thread] = None
|
||||
self._callbacks: list[Callable[[float, float], None]] = []
|
||||
|
||||
def start(self):
|
||||
"""启动监控"""
|
||||
if self._running:
|
||||
logger.warning("内存监控已在运行")
|
||||
return
|
||||
|
||||
if not HAS_PSUTIL:
|
||||
logger.warning("psutil 未安装,内存监控功能不可用")
|
||||
return
|
||||
|
||||
self._running = True
|
||||
self._thread = threading.Thread(target=self._monitor_loop, daemon=True)
|
||||
self._thread.start()
|
||||
logger.info(f"内存监控已启动 (限制: {self.max_memory_mb}MB)")
|
||||
|
||||
def stop(self):
|
||||
"""停止监控"""
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
self._running = False
|
||||
if self._thread:
|
||||
self._thread.join(timeout=5.0)
|
||||
logger.info("内存监控已停止")
|
||||
|
||||
def register_callback(self, callback: Callable[[float, float], None]):
|
||||
"""
|
||||
注册内存超限回调函数
|
||||
|
||||
Args:
|
||||
callback: 回调函数,接收 (当前使用MB, 限制MB) 参数
|
||||
"""
|
||||
self._callbacks.append(callback)
|
||||
|
||||
def get_current_memory_mb(self) -> float:
|
||||
"""
|
||||
获取当前进程内存使用量(MB)
|
||||
|
||||
Returns:
|
||||
内存使用量(MB)
|
||||
"""
|
||||
if not HAS_PSUTIL:
|
||||
# 降级方案:仅统计 gc 对象列表本身的大小,远小于真实 RSS,只能作为粗略占位
import sys
return sys.getsizeof(gc.get_objects()) / 1024 / 1024
|
||||
|
||||
process = psutil.Process()
|
||||
return process.memory_info().rss / 1024 / 1024
|
||||
|
||||
def get_memory_usage_ratio(self) -> float:
|
||||
"""
|
||||
获取内存使用率
|
||||
|
||||
Returns:
|
||||
使用率(0-1之间)
|
||||
"""
|
||||
current = self.get_current_memory_mb()
|
||||
return current / self.max_memory_mb if self.max_memory_mb > 0 else 0
|
||||
|
||||
def _monitor_loop(self):
|
||||
"""监控循环"""
|
||||
while self._running:
|
||||
try:
|
||||
current_mb = self.get_current_memory_mb()
|
||||
ratio = current_mb / self.max_memory_mb if self.max_memory_mb > 0 else 0
|
||||
|
||||
# 检查是否超过阈值
|
||||
if ratio >= self.warning_threshold:
|
||||
logger.warning(
|
||||
f"内存使用率过高: {current_mb:.1f}MB / {self.max_memory_mb}MB "
|
||||
f"({ratio*100:.1f}%)"
|
||||
)
|
||||
|
||||
# 触发回调
|
||||
for callback in self._callbacks:
|
||||
try:
|
||||
callback(current_mb, self.max_memory_mb)
|
||||
except Exception as e:
|
||||
logger.error(f"内存回调执行失败: {e}")
|
||||
|
||||
# 自动垃圾回收
|
||||
if self.enable_auto_gc:
|
||||
before = self.get_current_memory_mb()
|
||||
gc.collect()
|
||||
after = self.get_current_memory_mb()
|
||||
freed = before - after
|
||||
if freed > 1: # 释放超过1MB才记录
|
||||
logger.info(f"垃圾回收释放: {freed:.1f}MB")
|
||||
|
||||
# 定期垃圾回收(即使未超限;只有当检查恰好落在整分钟时才触发,属于近似策略)
elif self.enable_auto_gc and int(time.time()) % 60 == 0:
|
||||
gc.collect()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"内存监控出错: {e}")
|
||||
|
||||
# 等待下次检查
|
||||
time.sleep(self.check_interval)
|
||||
|
||||
def __enter__(self):
|
||||
"""上下文管理器入口"""
|
||||
self.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""上下文管理器出口"""
|
||||
self.stop()
|
||||
|
||||
|
||||
def get_memory_info() -> dict:
|
||||
"""
|
||||
获取系统内存信息
|
||||
|
||||
Returns:
|
||||
内存信息字典
|
||||
"""
|
||||
if not HAS_PSUTIL:
|
||||
return {"error": "psutil 未安装"}
|
||||
|
||||
try:
|
||||
mem = psutil.virtual_memory()
|
||||
process = psutil.Process()
|
||||
|
||||
return {
|
||||
"system_total_gb": mem.total / 1024 / 1024 / 1024,
|
||||
"system_available_gb": mem.available / 1024 / 1024 / 1024,
|
||||
"system_usage_percent": mem.percent,
|
||||
"process_mb": process.memory_info().rss / 1024 / 1024,
|
||||
"process_percent": (process.memory_info().rss / mem.total) * 100,
|
||||
}
|
||||
except Exception as e:
|
||||
return {"error": str(e)}
|
||||
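A usage sketch of MemoryMonitor, assuming psutil is installed; the 512 MB budget, the callback body, and run_ingestion_job are all illustrative.

def on_over_limit(current_mb: float, limit_mb: float) -> None:
    # e.g. shrink caches or batch sizes under memory pressure
    print(f"memory pressure: {current_mb:.1f}MB / {limit_mb:.1f}MB")

monitor = MemoryMonitor(max_memory_mb=512, warning_threshold=0.9, check_interval=10.0)
monitor.register_callback(on_over_limit)
with monitor:               # start() on enter, stop() on exit
    run_ingestion_job()     # hypothetical workload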
165
plugins/A_memorix/core/utils/path_fallback_service.py
Normal file
@@ -0,0 +1,165 @@
|
||||
"""Shared path-fallback helpers for search post-processing."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
from ..retrieval.dual_path import RetrievalResult
|
||||
|
||||
|
||||
def extract_entities(query: str, graph_store: Any) -> List[str]:
|
||||
"""Extract up to two graph nodes from a query using n-gram matching."""
|
||||
if not graph_store:
|
||||
return []
|
||||
|
||||
text = str(query or "").strip()
|
||||
if not text:
|
||||
return []
|
||||
|
||||
# Keep the heuristic aligned with previous legacy behavior.
|
||||
tokens = (
|
||||
text.replace("?", " ")
|
||||
.replace("!", " ")
|
||||
.replace(".", " ")
|
||||
.split()
|
||||
)
|
||||
if not tokens:
|
||||
return []
|
||||
|
||||
found_entities = set()
|
||||
skip_indices = set()
|
||||
max_n = min(4, len(tokens))
|
||||
|
||||
for size in range(max_n, 0, -1):
|
||||
for i in range(len(tokens) - size + 1):
|
||||
if any(idx in skip_indices for idx in range(i, i + size)):
|
||||
continue
|
||||
span = " ".join(tokens[i : i + size])
|
||||
matched_node = graph_store.find_node(span, ignore_case=True)
|
||||
if not matched_node:
|
||||
continue
|
||||
found_entities.add(matched_node)
|
||||
for idx in range(i, i + size):
|
||||
skip_indices.add(idx)
|
||||
|
||||
return list(found_entities)
|
||||
|
||||
|
||||
def find_paths_between_entities(
|
||||
start_node: str,
|
||||
end_node: str,
|
||||
graph_store: Any,
|
||||
metadata_store: Any,
|
||||
*,
|
||||
max_depth: int = 3,
|
||||
max_paths: int = 5,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Find and enrich indirect paths between two nodes."""
|
||||
if not graph_store or not metadata_store:
|
||||
return []
|
||||
|
||||
try:
|
||||
paths = graph_store.find_paths(
|
||||
start_node,
|
||||
end_node,
|
||||
max_depth=max_depth,
|
||||
max_paths=max_paths,
|
||||
)
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
if not paths:
|
||||
return []
|
||||
|
||||
edge_cache: Dict[Tuple[str, str], Tuple[str, str]] = {}
|
||||
formatted_paths: List[Dict[str, Any]] = []
|
||||
|
||||
for path_nodes in paths:
|
||||
if not isinstance(path_nodes, Sequence) or len(path_nodes) < 2:
|
||||
continue
|
||||
|
||||
path_desc: List[str] = []
|
||||
for i in range(len(path_nodes) - 1):
|
||||
u = str(path_nodes[i])
|
||||
v = str(path_nodes[i + 1])
|
||||
|
||||
cache_key = tuple(sorted((u, v)))
|
||||
if cache_key in edge_cache:
|
||||
pred, direction = edge_cache[cache_key]
|
||||
else:
|
||||
pred = "related"
|
||||
direction = "->"
|
||||
rels = metadata_store.get_relations(subject=u, object=v)
|
||||
if not rels:
|
||||
rels = metadata_store.get_relations(subject=v, object=u)
|
||||
direction = "<-"
|
||||
if rels:
|
||||
best_rel = max(rels, key=lambda x: x.get("confidence", 1.0))
|
||||
pred = str(best_rel.get("predicate", "related") or "related")
|
||||
edge_cache[cache_key] = (pred, direction)
|
||||
|
||||
step_str = f"-[{pred}]->" if direction == "->" else f"<-[{pred}]-"
|
||||
path_desc.append(step_str)
|
||||
|
||||
full_path_str = str(path_nodes[0])
|
||||
for i, step in enumerate(path_desc):
|
||||
full_path_str += f" {step} {path_nodes[i + 1]}"
|
||||
|
||||
formatted_paths.append(
|
||||
{
|
||||
"nodes": list(path_nodes),
|
||||
"description": full_path_str,
|
||||
}
|
||||
)
|
||||
|
||||
return formatted_paths
|
||||
|
||||
|
||||
def find_paths_from_query(
|
||||
query: str,
|
||||
graph_store: Any,
|
||||
metadata_store: Any,
|
||||
*,
|
||||
max_depth: int = 3,
|
||||
max_paths: int = 5,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Extract entities from query and resolve indirect paths."""
|
||||
entities = extract_entities(query, graph_store)
|
||||
if len(entities) != 2:
|
||||
return []
|
||||
return find_paths_between_entities(
|
||||
entities[0],
|
||||
entities[1],
|
||||
graph_store,
|
||||
metadata_store,
|
||||
max_depth=max_depth,
|
||||
max_paths=max_paths,
|
||||
)
|
||||
|
||||
|
||||
def to_retrieval_results(paths: Sequence[Dict[str, Any]]) -> List[RetrievalResult]:
|
||||
"""Convert path results into retrieval results for the unified pipeline."""
|
||||
converted: List[RetrievalResult] = []
|
||||
for item in paths:
|
||||
description = str(item.get("description", "")).strip()
|
||||
if not description:
|
||||
continue
|
||||
hash_seed = description.encode("utf-8")
|
||||
path_hash = f"path_{hashlib.sha1(hash_seed).hexdigest()}"
|
||||
converted.append(
|
||||
RetrievalResult(
|
||||
hash_value=path_hash,
|
||||
content=f"[Indirect Relation] {description}",
|
||||
score=0.95,
|
||||
result_type="relation",
|
||||
source="graph_path",
|
||||
metadata={
|
||||
"source": "graph_path",
|
||||
"is_indirect": True,
|
||||
"nodes": list(item.get("nodes", [])),
|
||||
},
|
||||
)
|
||||
)
|
||||
return converted
|
||||
|
||||
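A small sketch of the path descriptions this module produces and how they are wrapped for the unified retrieval pipeline; node and predicate names are invented.

paths = [{
    "nodes": ["Alice", "TeamX", "Bob"],
    "description": "Alice -[member_of]-> TeamX <-[leads]- Bob",
}]
results = to_retrieval_results(paths)
# one RetrievalResult: result_type="relation", source="graph_path", score=0.95,
# content="[Indirect Relation] Alice -[member_of]-> TeamX <-[leads]- Bob"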
495
plugins/A_memorix/core/utils/person_profile_service.py
Normal file
@@ -0,0 +1,495 @@
|
||||
"""
|
||||
人物画像服务
|
||||
|
||||
主链路:
|
||||
person_id -> 用户名/别名 -> 图谱关系 + 向量证据 -> 证据总结画像 -> 快照版本化存储
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import PersonInfo
|
||||
|
||||
from ..embedding import EmbeddingAPIAdapter
|
||||
from ..retrieval import (
|
||||
DualPathRetriever,
|
||||
RetrievalStrategy,
|
||||
DualPathRetrieverConfig,
|
||||
SparseBM25Config,
|
||||
FusionConfig,
|
||||
GraphRelationRecallConfig,
|
||||
)
|
||||
from ..storage import MetadataStore, GraphStore, VectorStore
|
||||
|
||||
logger = get_logger("A_Memorix.PersonProfileService")
|
||||
|
||||
|
||||
class PersonProfileService:
|
||||
"""人物画像聚合/刷新服务。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
metadata_store: MetadataStore,
|
||||
graph_store: Optional[GraphStore] = None,
|
||||
vector_store: Optional[VectorStore] = None,
|
||||
embedding_manager: Optional[EmbeddingAPIAdapter] = None,
|
||||
sparse_index: Any = None,
|
||||
plugin_config: Optional[dict] = None,
|
||||
retriever: Optional[DualPathRetriever] = None,
|
||||
):
|
||||
self.metadata_store = metadata_store
|
||||
self.graph_store = graph_store
|
||||
self.vector_store = vector_store
|
||||
self.embedding_manager = embedding_manager
|
||||
self.sparse_index = sparse_index
|
||||
self.plugin_config = plugin_config or {}
|
||||
self.retriever = retriever or self._build_retriever()
|
||||
|
||||
def _cfg(self, key: str, default: Any = None) -> Any:
|
||||
"""读取嵌套配置。"""
|
||||
if not isinstance(self.plugin_config, dict):
|
||||
return default
|
||||
current: Any = self.plugin_config
|
||||
for part in key.split("."):
|
||||
if isinstance(current, dict) and part in current:
|
||||
current = current[part]
|
||||
else:
|
||||
return default
|
||||
return current
|
||||
|
||||
def _build_retriever(self) -> Optional[DualPathRetriever]:
|
||||
"""按需构建检索器(无依赖时返回 None)。"""
|
||||
if not all(
|
||||
[
|
||||
self.vector_store is not None,
|
||||
self.graph_store is not None,
|
||||
self.metadata_store is not None,
|
||||
self.embedding_manager is not None,
|
||||
]
|
||||
):
|
||||
return None
|
||||
try:
|
||||
sparse_cfg_raw = self._cfg("retrieval.sparse", {}) or {}
|
||||
fusion_cfg_raw = self._cfg("retrieval.fusion", {}) or {}
|
||||
graph_recall_cfg_raw = self._cfg("retrieval.search.graph_recall", {}) or {}
|
||||
if not isinstance(sparse_cfg_raw, dict):
|
||||
sparse_cfg_raw = {}
|
||||
if not isinstance(fusion_cfg_raw, dict):
|
||||
fusion_cfg_raw = {}
|
||||
if not isinstance(graph_recall_cfg_raw, dict):
|
||||
graph_recall_cfg_raw = {}
|
||||
|
||||
sparse_cfg = SparseBM25Config(**sparse_cfg_raw)
|
||||
fusion_cfg = FusionConfig(**fusion_cfg_raw)
|
||||
graph_recall_cfg = GraphRelationRecallConfig(**graph_recall_cfg_raw)
|
||||
config = DualPathRetrieverConfig(
|
||||
top_k_paragraphs=int(self._cfg("retrieval.top_k_paragraphs", 20)),
|
||||
top_k_relations=int(self._cfg("retrieval.top_k_relations", 10)),
|
||||
top_k_final=int(self._cfg("retrieval.top_k_final", 10)),
|
||||
alpha=float(self._cfg("retrieval.alpha", 0.5)),
|
||||
enable_ppr=bool(self._cfg("retrieval.enable_ppr", True)),
|
||||
ppr_alpha=float(self._cfg("retrieval.ppr_alpha", 0.85)),
|
||||
ppr_concurrency_limit=int(self._cfg("retrieval.ppr_concurrency_limit", 4)),
|
||||
enable_parallel=bool(self._cfg("retrieval.enable_parallel", True)),
|
||||
retrieval_strategy=RetrievalStrategy.DUAL_PATH,
|
||||
debug=bool(self._cfg("advanced.debug", False)),
|
||||
sparse=sparse_cfg,
|
||||
fusion=fusion_cfg,
|
||||
graph_recall=graph_recall_cfg,
|
||||
)
|
||||
return DualPathRetriever(
|
||||
vector_store=self.vector_store,
|
||||
graph_store=self.graph_store,
|
||||
metadata_store=self.metadata_store,
|
||||
embedding_manager=self.embedding_manager,
|
||||
sparse_index=self.sparse_index,
|
||||
config=config,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"初始化人物画像检索器失败,将只使用关系证据: {e}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def resolve_person_id(identifier: str) -> str:
|
||||
"""按 person_id 或姓名/别名解析 person_id。"""
|
||||
if not identifier:
|
||||
return ""
|
||||
key = str(identifier).strip()
|
||||
if not key:
|
||||
return ""
|
||||
|
||||
if len(key) == 32 and all(ch in "0123456789abcdefABCDEF" for ch in key):
|
||||
return key.lower()
|
||||
|
||||
try:
|
||||
record = (
|
||||
PersonInfo.select(PersonInfo.person_id)
|
||||
.where((PersonInfo.person_name == key) | (PersonInfo.nickname == key))
|
||||
.first()
|
||||
)
|
||||
if record and record.person_id:
|
||||
return str(record.person_id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
record = (
|
||||
PersonInfo.select(PersonInfo.person_id)
|
||||
.where(PersonInfo.group_nick_name.contains(key))
|
||||
.first()
|
||||
)
|
||||
if record and record.person_id:
|
||||
return str(record.person_id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return ""
|
||||
|
||||
def _parse_group_nicks(self, raw_value: Any) -> List[str]:
|
||||
if not raw_value:
|
||||
return []
|
||||
if isinstance(raw_value, list):
|
||||
items = raw_value
|
||||
else:
|
||||
try:
|
||||
items = json.loads(raw_value)
|
||||
except Exception:
|
||||
return []
|
||||
names: List[str] = []
|
||||
for item in items:
|
||||
if isinstance(item, dict):
|
||||
value = str(item.get("group_nick_name", "")).strip()
|
||||
if value:
|
||||
names.append(value)
|
||||
elif isinstance(item, str):
|
||||
value = item.strip()
|
||||
if value:
|
||||
names.append(value)
|
||||
return names
|
||||
|
||||
def _parse_memory_traits(self, raw_value: Any) -> List[str]:
|
||||
if not raw_value:
|
||||
return []
|
||||
try:
|
||||
values = json.loads(raw_value) if isinstance(raw_value, str) else raw_value
|
||||
except Exception:
|
||||
return []
|
||||
if not isinstance(values, list):
|
||||
return []
|
||||
traits: List[str] = []
|
||||
for item in values:
|
||||
text = str(item).strip()
|
||||
if not text:
|
||||
continue
|
||||
if ":" in text:
|
||||
parts = text.split(":")
|
||||
if len(parts) >= 3:
|
||||
content = ":".join(parts[1:-1]).strip()
|
||||
if content:
|
||||
traits.append(content)
|
||||
continue
|
||||
traits.append(text)
|
||||
return traits[:10]
|
||||
|
||||
def get_person_aliases(self, person_id: str) -> Tuple[List[str], str, List[str]]:
|
||||
"""获取人物别名集合、主展示名、记忆特征。"""
|
||||
aliases: List[str] = []
|
||||
primary_name = ""
|
||||
memory_traits: List[str] = []
|
||||
if not person_id:
|
||||
return aliases, primary_name, memory_traits
|
||||
try:
|
||||
record = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
|
||||
if not record:
|
||||
return aliases, primary_name, memory_traits
|
||||
person_name = str(getattr(record, "person_name", "") or "").strip()
|
||||
nickname = str(getattr(record, "nickname", "") or "").strip()
|
||||
group_nicks = self._parse_group_nicks(getattr(record, "group_nick_name", None))
|
||||
memory_traits = self._parse_memory_traits(getattr(record, "memory_points", None))
|
||||
|
||||
primary_name = person_name or nickname or str(getattr(record, "user_id", "") or "").strip() or person_id
|
||||
|
||||
candidates = [person_name, nickname] + group_nicks
|
||||
seen = set()
|
||||
for item in candidates:
|
||||
norm = str(item or "").strip()
|
||||
if not norm or norm in seen:
|
||||
continue
|
||||
seen.add(norm)
|
||||
aliases.append(norm)
|
||||
except Exception as e:
|
||||
logger.warning(f"解析人物别名失败: person_id={person_id}, err={e}")
|
||||
return aliases, primary_name, memory_traits
|
||||
|
||||
def _collect_relation_evidence(self, aliases: List[str], limit: int = 30) -> List[Dict[str, Any]]:
|
||||
relation_by_hash: Dict[str, Dict[str, Any]] = {}
|
||||
for alias in aliases:
|
||||
for rel in self.metadata_store.get_relations(subject=alias):
|
||||
h = str(rel.get("hash", ""))
|
||||
if h:
|
||||
relation_by_hash[h] = rel
|
||||
for rel in self.metadata_store.get_relations(object=alias):
|
||||
h = str(rel.get("hash", ""))
|
||||
if h:
|
||||
relation_by_hash[h] = rel
|
||||
|
||||
relations = list(relation_by_hash.values())
|
||||
relations.sort(key=lambda item: float(item.get("confidence", 0.0)), reverse=True)
|
||||
relations = relations[: max(1, int(limit))]
|
||||
|
||||
edges: List[Dict[str, Any]] = []
|
||||
for rel in relations:
|
||||
edges.append(
|
||||
{
|
||||
"hash": str(rel.get("hash", "")),
|
||||
"subject": str(rel.get("subject", "")),
|
||||
"predicate": str(rel.get("predicate", "")),
|
||||
"object": str(rel.get("object", "")),
|
||||
"confidence": float(rel.get("confidence", 1.0) or 1.0),
|
||||
}
|
||||
)
|
||||
return edges
|
||||
|
||||
async def _collect_vector_evidence(self, aliases: List[str], top_k: int = 12) -> List[Dict[str, Any]]:
|
||||
alias_queries = [a for a in aliases if a]
|
||||
if not alias_queries:
|
||||
return []
|
||||
|
||||
if self.retriever is None:
|
||||
# 回退:无检索器时只做简单内容匹配
|
||||
fallback: List[Dict[str, Any]] = []
|
||||
seen_hash = set()
|
||||
for alias in alias_queries:
|
||||
for para in self.metadata_store.search_paragraphs_by_content(alias)[: max(2, top_k // 2)]:
|
||||
h = str(para.get("hash", ""))
|
||||
if not h or h in seen_hash:
|
||||
continue
|
||||
seen_hash.add(h)
|
||||
fallback.append(
|
||||
{
|
||||
"hash": h,
|
||||
"type": "paragraph",
|
||||
"score": 0.0,
|
||||
"content": str(para.get("content", ""))[:180],
|
||||
"metadata": {},
|
||||
}
|
||||
)
|
||||
return fallback[:top_k]
|
||||
|
||||
per_alias_top_k = max(2, int(top_k / max(1, len(alias_queries))))
|
||||
seen_hash = set()
|
||||
evidence: List[Dict[str, Any]] = []
|
||||
for alias in alias_queries:
|
||||
try:
|
||||
results = await self.retriever.retrieve(alias, top_k=per_alias_top_k)
|
||||
except Exception as e:
|
||||
logger.warning(f"向量证据召回失败: alias={alias}, err={e}")
|
||||
continue
|
||||
for item in results:
|
||||
h = str(getattr(item, "hash_value", "") or "")
|
||||
if not h or h in seen_hash:
|
||||
continue
|
||||
seen_hash.add(h)
|
||||
evidence.append(
|
||||
{
|
||||
"hash": h,
|
||||
"type": str(getattr(item, "result_type", "")),
|
||||
"score": float(getattr(item, "score", 0.0) or 0.0),
|
||||
"content": str(getattr(item, "content", "") or "")[:220],
|
||||
"metadata": dict(getattr(item, "metadata", {}) or {}),
|
||||
}
|
||||
)
|
||||
evidence.sort(key=lambda x: x.get("score", 0.0), reverse=True)
|
||||
return evidence[:top_k]
|
||||
|
||||
def _build_profile_text(
|
||||
self,
|
||||
person_id: str,
|
||||
primary_name: str,
|
||||
aliases: List[str],
|
||||
relation_edges: List[Dict[str, Any]],
|
||||
vector_evidence: List[Dict[str, Any]],
|
||||
memory_traits: List[str],
|
||||
) -> str:
|
||||
"""基于证据构建画像文本(供 LLM 上下文注入)。"""
|
||||
lines: List[str] = []
|
||||
lines.append(f"人物ID: {person_id}")
|
||||
if primary_name:
|
||||
lines.append(f"主称呼: {primary_name}")
|
||||
if aliases:
|
||||
lines.append(f"别名: {', '.join(aliases[:8])}")
|
||||
if memory_traits:
|
||||
lines.append(f"记忆特征: {'; '.join(memory_traits[:6])}")
|
||||
|
||||
if relation_edges:
|
||||
lines.append("关系证据:")
|
||||
for rel in relation_edges[:6]:
|
||||
s = rel.get("subject", "")
|
||||
p = rel.get("predicate", "")
|
||||
o = rel.get("object", "")
|
||||
conf = float(rel.get("confidence", 0.0))
|
||||
lines.append(f"- {s} {p} {o} (conf={conf:.2f})")
|
||||
|
||||
if vector_evidence:
|
||||
lines.append("向量证据摘要:")
|
||||
for item in vector_evidence[:4]:
|
||||
content = str(item.get("content", "")).strip()
|
||||
if content:
|
||||
lines.append(f"- {content}")
|
||||
|
||||
if len(lines) <= 2:
|
||||
lines.append("暂无足够证据形成稳定画像。")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
@staticmethod
|
||||
def _is_snapshot_stale(snapshot: Optional[Dict[str, Any]], ttl_seconds: float) -> bool:
|
||||
if not snapshot:
|
||||
return True
|
||||
now = time.time()
|
||||
expires_at = snapshot.get("expires_at")
|
||||
if expires_at is not None:
|
||||
try:
|
||||
return now >= float(expires_at)
|
||||
except Exception:
|
||||
return True
|
||||
updated_at = float(snapshot.get("updated_at") or 0.0)
|
||||
return (now - updated_at) >= ttl_seconds
|
||||
|
||||
def _apply_manual_override(self, person_id: str, profile_payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""将手工覆盖并入画像结果(覆盖 profile_text,同时保留 auto_profile_text)。"""
|
||||
payload = dict(profile_payload or {})
|
||||
auto_text = str(payload.get("profile_text", "") or "")
|
||||
payload["auto_profile_text"] = auto_text
|
||||
payload["has_manual_override"] = False
|
||||
payload["manual_override_text"] = ""
|
||||
payload["override_updated_at"] = None
|
||||
payload["override_updated_by"] = ""
|
||||
payload["profile_source"] = "auto_snapshot"
|
||||
|
||||
if not person_id or self.metadata_store is None:
|
||||
return payload
|
||||
|
||||
try:
|
||||
override = self.metadata_store.get_person_profile_override(person_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"读取人物画像手工覆盖失败: person_id={person_id}, err={e}")
|
||||
return payload
|
||||
|
||||
if not override:
|
||||
return payload
|
||||
|
||||
manual_text = str(override.get("override_text", "") or "").strip()
|
||||
if not manual_text:
|
||||
return payload
|
||||
|
||||
payload["has_manual_override"] = True
|
||||
payload["manual_override_text"] = manual_text
|
||||
payload["override_updated_at"] = override.get("updated_at")
|
||||
payload["override_updated_by"] = str(override.get("updated_by", "") or "")
|
||||
payload["profile_text"] = manual_text
|
||||
payload["profile_source"] = "manual_override"
|
||||
return payload
|
||||
|
||||
async def query_person_profile(
|
||||
self,
|
||||
person_id: str = "",
|
||||
person_keyword: str = "",
|
||||
top_k: int = 12,
|
||||
ttl_seconds: float = 6 * 3600,
|
||||
force_refresh: bool = False,
|
||||
source_note: str = "",
|
||||
) -> Dict[str, Any]:
|
||||
"""查询或刷新人物画像。"""
|
||||
pid = str(person_id or "").strip()
|
||||
if not pid and person_keyword:
|
||||
pid = self.resolve_person_id(person_keyword)
|
||||
|
||||
if not pid:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "person_id 无效,且未能通过别名解析",
|
||||
}
|
||||
|
||||
latest = self.metadata_store.get_latest_person_profile_snapshot(pid)
|
||||
if not force_refresh and not self._is_snapshot_stale(latest, ttl_seconds):
|
||||
aliases, primary_name, _ = self.get_person_aliases(pid)
|
||||
payload = {
|
||||
"success": True,
|
||||
"person_id": pid,
|
||||
"person_name": primary_name,
|
||||
"from_cache": True,
|
||||
**(latest or {}),
|
||||
}
|
||||
if aliases and not payload.get("aliases"):
|
||||
payload["aliases"] = aliases
|
||||
return {
|
||||
**self._apply_manual_override(pid, payload),
|
||||
}
|
||||
|
||||
aliases, primary_name, memory_traits = self.get_person_aliases(pid)
|
||||
if not aliases and person_keyword:
|
||||
aliases = [person_keyword.strip()]
|
||||
primary_name = person_keyword.strip()
|
||||
relation_edges = self._collect_relation_evidence(aliases, limit=max(10, top_k * 2))
|
||||
vector_evidence = await self._collect_vector_evidence(aliases, top_k=max(4, top_k))
|
||||
|
||||
evidence_ids = [
|
||||
str(item.get("hash", ""))
|
||||
for item in (relation_edges + vector_evidence)
|
||||
if str(item.get("hash", "")).strip()
|
||||
]
|
||||
dedup_ids: List[str] = []
|
||||
seen = set()
|
||||
for item in evidence_ids:
|
||||
if item in seen:
|
||||
continue
|
||||
seen.add(item)
|
||||
dedup_ids.append(item)
|
||||
|
||||
profile_text = self._build_profile_text(
|
||||
person_id=pid,
|
||||
primary_name=primary_name,
|
||||
aliases=aliases,
|
||||
relation_edges=relation_edges,
|
||||
vector_evidence=vector_evidence,
|
||||
memory_traits=memory_traits,
|
||||
)
|
||||
|
||||
expires_at = time.time() + float(ttl_seconds) if ttl_seconds > 0 else None
|
||||
snapshot = self.metadata_store.upsert_person_profile_snapshot(
|
||||
person_id=pid,
|
||||
profile_text=profile_text,
|
||||
aliases=aliases,
|
||||
relation_edges=relation_edges,
|
||||
vector_evidence=vector_evidence,
|
||||
evidence_ids=dedup_ids,
|
||||
expires_at=expires_at,
|
||||
source_note=source_note,
|
||||
)
|
||||
payload = {
|
||||
"success": True,
|
||||
"person_id": pid,
|
||||
"person_name": primary_name,
|
||||
"from_cache": False,
|
||||
**snapshot,
|
||||
}
|
||||
return {
|
||||
**self._apply_manual_override(pid, payload),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def format_persona_profile_block(profile: Dict[str, Any]) -> str:
|
||||
"""格式化给 replyer 的注入块。"""
|
||||
if not profile or not profile.get("success"):
|
||||
return ""
|
||||
text = str(profile.get("profile_text", "") or "").strip()
|
||||
if not text:
|
||||
return ""
|
||||
return (
|
||||
"【人物画像-内部参考】\n"
|
||||
f"{text}\n"
|
||||
"仅供内部推理,不要向用户逐字复述。"
|
||||
)
|
||||
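A usage sketch of the profile query flow, assuming a PersonProfileService instance already wired to the plugin's stores; the keyword is invented.

import asyncio

async def demo(service: PersonProfileService) -> None:
    profile = await service.query_person_profile(
        person_keyword="Alice",    # resolved to a person_id via name/alias lookup
        ttl_seconds=6 * 3600,      # snapshots younger than 6 hours are served from cache
    )
    if profile.get("success"):
        print(PersonProfileService.format_persona_profile_block(profile))

# asyncio.run(demo(service))  # "service" must be constructed with the plugin's stores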
27
plugins/A_memorix/core/utils/plugin_id_policy.py
Normal file
@@ -0,0 +1,27 @@
|
||||
"""Plugin ID matching policy for A_Memorix."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
class PluginIdPolicy:
|
||||
"""Centralized plugin id normalization/matching policy."""
|
||||
|
||||
CANONICAL_ID = "a_memorix"
|
||||
|
||||
@classmethod
|
||||
def normalize(cls, plugin_id: Any) -> str:
|
||||
if not isinstance(plugin_id, str):
|
||||
return ""
|
||||
return plugin_id.strip().lower()
|
||||
|
||||
@classmethod
|
||||
def is_target_plugin_id(cls, plugin_id: Any) -> bool:
|
||||
normalized = cls.normalize(plugin_id)
|
||||
if not normalized:
|
||||
return False
|
||||
if normalized == cls.CANONICAL_ID:
|
||||
return True
|
||||
return normalized.split(".")[-1] == cls.CANONICAL_ID
|
||||
|
||||
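A tiny sketch of how PluginIdPolicy matches identifiers; the dotted id is a hypothetical host-side form.

assert PluginIdPolicy.is_target_plugin_id("A_Memorix")          # case-insensitive
assert PluginIdPolicy.is_target_plugin_id("plugins.a_memorix")  # last dotted segment matches
assert not PluginIdPolicy.is_target_plugin_id("a_memorix_extra")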
344
plugins/A_memorix/core/utils/quantization.py
Normal file
@@ -0,0 +1,344 @@
|
||||
"""
|
||||
向量量化工具模块
|
||||
|
||||
提供向量量化与反量化功能,用于压缩存储空间。
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from enum import Enum
|
||||
from typing import Tuple, Union
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("A_Memorix.Quantization")
|
||||
|
||||
|
||||
class QuantizationType(Enum):
|
||||
"""量化类型枚举"""
|
||||
FLOAT32 = "float32" # 无量化
|
||||
INT8 = "int8" # 标量量化(8位整数)
|
||||
PQ = "pq" # 乘积量化(Product Quantization)
|
||||
|
||||
|
||||
def quantize_vector(
|
||||
vector: np.ndarray,
|
||||
quant_type: QuantizationType = QuantizationType.INT8,
|
||||
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
|
||||
"""
|
||||
量化向量
|
||||
|
||||
Args:
|
||||
vector: 输入向量(float32)
|
||||
quant_type: 量化类型
|
||||
|
||||
Returns:
|
||||
量化后的向量:
|
||||
- INT8: int8向量
|
||||
- PQ: (编码向量, 聚类中心) 元组
|
||||
"""
|
||||
if quant_type == QuantizationType.FLOAT32:
|
||||
return vector.astype(np.float32)
|
||||
|
||||
elif quant_type == QuantizationType.INT8:
|
||||
return _scalar_quantize_int8(vector)
|
||||
|
||||
elif quant_type == QuantizationType.PQ:
|
||||
return _product_quantize(vector)
|
||||
|
||||
else:
|
||||
raise ValueError(f"不支持的量化类型: {quant_type}")
|
||||
|
||||
|
||||
def dequantize_vector(
|
||||
quantized_vector: Union[np.ndarray, Tuple[np.ndarray, np.ndarray]],
|
||||
quant_type: QuantizationType = QuantizationType.INT8,
|
||||
original_shape: Tuple[int, ...] = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
反量化向量
|
||||
|
||||
Args:
|
||||
quantized_vector: 量化后的向量
|
||||
quant_type: 量化类型
|
||||
original_shape: 原始向量形状(用于PQ)
|
||||
|
||||
Returns:
|
||||
反量化后的向量(float32)
|
||||
"""
|
||||
if quant_type == QuantizationType.FLOAT32:
|
||||
return quantized_vector.astype(np.float32)
|
||||
|
||||
elif quant_type == QuantizationType.INT8:
|
||||
return _scalar_dequantize_int8(quantized_vector)
|
||||
|
||||
elif quant_type == QuantizationType.PQ:
|
||||
if not isinstance(quantized_vector, tuple):
|
||||
raise ValueError("PQ反量化需要列表/元组格式: (codes, centroids)")
|
||||
return _product_dequantize(quantized_vector[0], quantized_vector[1])
|
||||
|
||||
else:
|
||||
raise ValueError(f"不支持的量化类型: {quant_type}")
|
||||
|
||||
|
||||
def _scalar_quantize_int8(vector: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
标量量化:float32 -> int8
|
||||
|
||||
将向量归一化到 [0, 255] 范围,然后映射到 int8
|
||||
|
||||
Args:
|
||||
vector: 输入向量
|
||||
|
||||
Returns:
|
||||
量化后的 int8 向量
|
||||
"""
|
||||
# 计算最小最大值
|
||||
min_val = np.min(vector)
|
||||
max_val = np.max(vector)
|
||||
|
||||
# 避免除零
|
||||
if max_val == min_val:
|
||||
return np.zeros_like(vector, dtype=np.int8)
|
||||
|
||||
# 归一化到 [0, 255]
|
||||
normalized = (vector - min_val) / (max_val - min_val) * 255
|
||||
|
||||
# 映射到 [-128, 127] 并转换为 int8
|
||||
# np.round might return float, minus 128 then cast
|
||||
quantized = np.round(normalized - 128.0).astype(np.int8)
|
||||
|
||||
# 存储归一化参数(用于反量化)
|
||||
# 在实际存储中,这些参数需要单独保存
|
||||
# 这里为了简单,我们使用一个全局字典来模拟
|
||||
if not hasattr(_scalar_quantize_int8, "_params"):
|
||||
_scalar_quantize_int8._params = {}
|
||||
|
||||
vector_id = id(vector)
|
||||
_scalar_quantize_int8._params[vector_id] = (min_val, max_val)
|
||||
|
||||
return quantized
|
||||
|
||||
|
||||
def _scalar_dequantize_int8(quantized: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
标量反量化:int8 -> float32
|
||||
|
||||
Args:
|
||||
quantized: 量化后的 int8 向量
|
||||
|
||||
Returns:
|
||||
反量化后的 float32 向量
|
||||
"""
|
||||
# 反量化需要量化时的 min/max 归一化参数;量化函数用 id(vector) 暂存的参数在此无法可靠取回。
# 实际应用中应将 (min_val, max_val) 与量化结果一起持久化,再按下式还原:
# return (quantized.astype(np.float32) + 128.0) / 255.0 * (max_val - min_val) + min_val
# 这里仅作演示,统一按默认范围 [-1, 1] 近似反量化
return (quantized.astype(np.float32) + 128.0) / 255.0 * 2.0 - 1.0
|
||||
|
||||
|
||||
def quantize_matrix(
|
||||
matrix: np.ndarray,
|
||||
quant_type: QuantizationType = QuantizationType.INT8,
|
||||
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
|
||||
"""
|
||||
量化矩阵(批量量化向量)
|
||||
|
||||
Args:
|
||||
matrix: 输入矩阵(N x D,每行是一个向量)
|
||||
quant_type: 量化类型
|
||||
|
||||
Returns:
|
||||
量化后的矩阵
|
||||
"""
|
||||
if quant_type == QuantizationType.FLOAT32:
|
||||
return matrix.astype(np.float32)
|
||||
|
||||
elif quant_type == QuantizationType.INT8:
|
||||
# 对整个矩阵进行全局归一化
|
||||
min_val = np.min(matrix)
|
||||
max_val = np.max(matrix)
|
||||
|
||||
if max_val == min_val:
|
||||
return np.zeros_like(matrix, dtype=np.int8)
|
||||
|
||||
# 归一化到 [0, 255],再平移到 int8 的 [-128, 127] 区间,避免 astype 溢出
normalized = (matrix - min_val) / (max_val - min_val) * 255
quantized = np.round(normalized - 128.0).astype(np.int8)
|
||||
|
||||
return quantized
|
||||
|
||||
else:
|
||||
raise ValueError(f"不支持的量化类型: {quant_type}")
|
||||
|
||||
|
||||
def dequantize_matrix(
|
||||
quantized_matrix: np.ndarray,
|
||||
quant_type: QuantizationType = QuantizationType.INT8,
|
||||
min_val: float = None,
|
||||
max_val: float = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
反量化矩阵
|
||||
|
||||
Args:
|
||||
quantized_matrix: 量化后的矩阵
|
||||
quant_type: 量化类型
|
||||
min_val: 归一化最小值(int8反量化需要)
|
||||
max_val: 归一化最大值(int8反量化需要)
|
||||
|
||||
Returns:
|
||||
反量化后的矩阵
|
||||
"""
|
||||
if quant_type == QuantizationType.FLOAT32:
|
||||
return quantized_matrix.astype(np.float32)
|
||||
|
||||
elif quant_type == QuantizationType.INT8:
|
||||
# 使用提供的归一化参数反量化
|
||||
if min_val is None or max_val is None:
|
||||
# 未提供归一化参数时,近似假设原始取值范围为 [-1, 1]
return quantized_matrix.astype(np.float32) / 127.0
|
||||
else:
# 先平移回 [0, 255],再恢复到原始范围(与 quantize_matrix 的 -128 平移保持一致)
normalized = (quantized_matrix.astype(np.float32) + 128.0) / 255.0
return normalized * (max_val - min_val) + min_val
|
||||
|
||||
else:
|
||||
raise ValueError(f"不支持的量化类型: {quant_type}")
|
||||
|
||||
|
||||
def estimate_memory_reduction(
|
||||
num_vectors: int,
|
||||
dimension: int,
|
||||
from_type: QuantizationType,
|
||||
to_type: QuantizationType,
|
||||
) -> Tuple[float, float, float]:
|
||||
"""
|
||||
估算内存节省量
|
||||
|
||||
Args:
|
||||
num_vectors: 向量数量
|
||||
dimension: 向量维度
|
||||
from_type: 原始量化类型
|
||||
to_type: 目标量化类型
|
||||
|
||||
Returns:
|
||||
(原始大小MB, 量化后大小MB, 节省比例)
|
||||
"""
|
||||
# 计算每个向量占用的字节数
|
||||
bytes_per_element = {
|
||||
QuantizationType.FLOAT32: 4,
|
||||
QuantizationType.INT8: 1,
|
||||
QuantizationType.PQ: 0.25, # 假设压缩到1/4
|
||||
}
|
||||
|
||||
original_bytes = num_vectors * dimension * bytes_per_element[from_type]
|
||||
quantized_bytes = num_vectors * dimension * bytes_per_element[to_type]
|
||||
|
||||
original_mb = original_bytes / 1024 / 1024
|
||||
quantized_mb = quantized_bytes / 1024 / 1024
|
||||
reduction_ratio = (original_bytes - quantized_bytes) / original_bytes
|
||||
|
||||
return original_mb, quantized_mb, reduction_ratio
|
||||
|
||||
|
||||
def estimate_compression_stats(
|
||||
num_vectors: int,
|
||||
dimension: int,
|
||||
quant_type: QuantizationType,
|
||||
) -> dict:
|
||||
"""
|
||||
估算压缩统计信息
|
||||
|
||||
Args:
|
||||
num_vectors: 向量数量
|
||||
dimension: 向量维度
|
||||
quant_type: 量化类型
|
||||
|
||||
Returns:
|
||||
统计信息字典
|
||||
"""
|
||||
original_mb, quantized_mb, ratio = estimate_memory_reduction(
|
||||
num_vectors, dimension, QuantizationType.FLOAT32, quant_type
|
||||
)
|
||||
|
||||
return {
|
||||
"num_vectors": num_vectors,
|
||||
"dimension": dimension,
|
||||
"quantization_type": quant_type.value,
|
||||
"original_size_mb": round(original_mb, 2),
|
||||
"quantized_size_mb": round(quantized_mb, 2),
|
||||
"saved_mb": round(original_mb - quantized_mb, 2),
|
||||
"compression_ratio": round(ratio * 100, 2),
|
||||
}
|
||||
|
||||
|
||||
def _product_quantize(
|
||||
vector: np.ndarray, m: int = 8, k: int = 256
|
||||
) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
乘积量化 (PQ) 简化实现
|
||||
|
||||
Args:
|
||||
vector: 输入向量 (D,)
|
||||
m: 子空间数量
|
||||
k: 每个子空间的聚类中心数
|
||||
|
||||
Returns:
|
||||
(编码后的向量, 聚类中心)
|
||||
"""
|
||||
d = vector.shape[0]
|
||||
if d % m != 0:
|
||||
raise ValueError(f"维度 {d} 必须能被子空间数量 {m} 整除")
|
||||
|
||||
ds = d // m # 子空间维度
|
||||
codes = np.zeros(m, dtype=np.uint8)
|
||||
centroids = np.zeros((m, k, ds), dtype=np.float32)
|
||||
|
||||
# 这里采用一种简化的 PQ:不进行 K-means 训练,
# 而是预设一些量化点(实际应用中应离线训练聚类中心)
# 为了演示,直接在每个子空间的取值范围内均匀划分出 k 个量化点
|
||||
for i in range(m):
|
||||
sub_vec = vector[i * ds : (i + 1) * ds]
|
||||
# 简化:假定每个子空间的取值范围并划分
|
||||
# 实际 PQ 应使用 k-means 产生的 centroids
|
||||
# 这里为演示按取值范围线性等分构建 codebook,并找到最接近的中心
|
||||
sub_min, sub_max = np.min(sub_vec), np.max(sub_vec)
|
||||
if sub_max == sub_min:
|
||||
linspace = np.zeros(k)
|
||||
else:
|
||||
linspace = np.linspace(sub_min, sub_max, k)
|
||||
|
||||
for j in range(k):
|
||||
centroids[i, j, :] = linspace[j]
|
||||
|
||||
# 编码:这里简化为取子空间均值找最接近的 centroid
|
||||
sub_mean = np.mean(sub_vec)
|
||||
code = np.argmin(np.abs(linspace - sub_mean))
|
||||
codes[i] = code
|
||||
|
||||
return codes, centroids
|
||||
|
||||
|
||||
def _product_dequantize(codes: np.ndarray, centroids: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
PQ 反量化
|
||||
|
||||
Args:
|
||||
codes: 编码向量 (M,)
|
||||
centroids: 聚类中心 (M, K, DS)
|
||||
|
||||
Returns:
|
||||
恢复后的向量 (D,)
|
||||
"""
|
||||
m, k, ds = centroids.shape
|
||||
vector = np.zeros(m * ds, dtype=np.float32)
|
||||
|
||||
for i in range(m):
|
||||
code = codes[i]
|
||||
vector[i * ds : (i + 1) * ds] = centroids[i, code, :]
|
||||
|
||||
return vector
|
||||
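A quick sketch of the int8 round-trip and the size estimate helpers above; the shapes are arbitrary examples.

import numpy as np

vecs = np.random.rand(4, 1024).astype(np.float32)
q = quantize_matrix(vecs, QuantizationType.INT8)          # int8 takes 1/4 the bytes of float32
approx = dequantize_matrix(q, QuantizationType.INT8,
                           min_val=float(vecs.min()), max_val=float(vecs.max()))

stats = estimate_compression_stats(10_000, 1024, QuantizationType.INT8)
# original_size_mb ~= 39.06, quantized_size_mb ~= 9.77, compression_ratio == 75.0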
121
plugins/A_memorix/core/utils/relation_query.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""关系查询规格解析工具。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class RelationQuerySpec:
|
||||
raw: str
|
||||
is_structured: bool
|
||||
subject: Optional[str]
|
||||
predicate: Optional[str]
|
||||
object: Optional[str]
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
_NATURAL_LANGUAGE_PATTERN = re.compile(
|
||||
r"(^\s*(what|who|which|how|why|when|where)\b|"
|
||||
r"\?|?|"
|
||||
r"\b(relation|related|between)\b|"
|
||||
r"(什么关系|有哪些关系|之间|关联))",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
def _looks_like_natural_language(raw: str) -> bool:
|
||||
text = str(raw or "").strip()
|
||||
if not text:
|
||||
return False
|
||||
return _NATURAL_LANGUAGE_PATTERN.search(text) is not None
|
||||
|
||||
|
||||
def parse_relation_query_spec(relation_spec: str) -> RelationQuerySpec:
|
||||
raw = str(relation_spec or "").strip()
|
||||
if not raw:
|
||||
return RelationQuerySpec(
|
||||
raw=raw,
|
||||
is_structured=False,
|
||||
subject=None,
|
||||
predicate=None,
|
||||
object=None,
|
||||
error="empty",
|
||||
)
|
||||
|
||||
if "|" in raw:
|
||||
parts = [p.strip() for p in raw.split("|")]
|
||||
if len(parts) < 2:
|
||||
return RelationQuerySpec(
|
||||
raw=raw,
|
||||
is_structured=True,
|
||||
subject=None,
|
||||
predicate=None,
|
||||
object=None,
|
||||
error="invalid_pipe_format",
|
||||
)
|
||||
return RelationQuerySpec(
|
||||
raw=raw,
|
||||
is_structured=True,
|
||||
subject=parts[0] or None,
|
||||
predicate=parts[1] or None,
|
||||
object=parts[2] if len(parts) > 2 and parts[2] else None,
|
||||
)
|
||||
|
||||
if "->" in raw:
|
||||
parts = [p.strip() for p in raw.split("->") if p.strip()]
|
||||
if len(parts) >= 3:
|
||||
return RelationQuerySpec(
|
||||
raw=raw,
|
||||
is_structured=True,
|
||||
subject=parts[0],
|
||||
predicate=parts[1],
|
||||
object=parts[2],
|
||||
)
|
||||
if len(parts) == 2:
|
||||
return RelationQuerySpec(
|
||||
raw=raw,
|
||||
is_structured=True,
|
||||
subject=parts[0],
|
||||
predicate=None,
|
||||
object=parts[1],
|
||||
)
|
||||
return RelationQuerySpec(
|
||||
raw=raw,
|
||||
is_structured=True,
|
||||
subject=None,
|
||||
predicate=None,
|
||||
object=None,
|
||||
error="invalid_arrow_format",
|
||||
)
|
||||
|
||||
if _looks_like_natural_language(raw):
|
||||
return RelationQuerySpec(
|
||||
raw=raw,
|
||||
is_structured=False,
|
||||
subject=None,
|
||||
predicate=None,
|
||||
object=None,
|
||||
)
|
||||
|
||||
# 仅保留低歧义的紧凑三元组作为兼容语法,例如 "Alice likes Apple"。
|
||||
# 两词形式过于模糊,不再视为结构化关系查询。
|
||||
parts = raw.split()
|
||||
if len(parts) == 3:
|
||||
return RelationQuerySpec(
|
||||
raw=raw,
|
||||
is_structured=True,
|
||||
subject=parts[0],
|
||||
predicate=parts[1],
|
||||
object=parts[2],
|
||||
)
|
||||
|
||||
return RelationQuerySpec(
|
||||
raw=raw,
|
||||
is_structured=False,
|
||||
subject=None,
|
||||
predicate=None,
|
||||
object=None,
|
||||
)
|
||||
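A short sketch of how parse_relation_query_spec treats the supported input forms; the query strings are invented.

spec = parse_relation_query_spec("Alice | likes | Apple")
# structured: subject="Alice", predicate="likes", object="Apple"

spec = parse_relation_query_spec("Alice -> Apple")
# structured: subject="Alice", object="Apple", predicate=None

spec = parse_relation_query_spec("Alice 和 Apple 是什么关系")
# natural language: is_structured=False, left to the semantic retriever instead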
164
plugins/A_memorix/core/utils/relation_write_service.py
Normal file
@@ -0,0 +1,164 @@
|
||||
"""
|
||||
统一关系写入与关系向量化服务。
|
||||
|
||||
规则:
|
||||
1. 元数据是主数据源,向量是从索引。
|
||||
2. 关系先写 metadata,再写向量。
|
||||
3. 向量失败不回滚 metadata,依赖状态机与回填任务修复。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
|
||||
logger = get_logger("A_Memorix.RelationWriteService")
|
||||
|
||||
|
||||
@dataclass
|
||||
class RelationWriteResult:
|
||||
hash_value: str
|
||||
vector_written: bool
|
||||
vector_already_exists: bool
|
||||
vector_state: str
|
||||
|
||||
|
||||
class RelationWriteService:
|
||||
"""关系写入收口服务。"""
|
||||
|
||||
ERROR_MAX_LEN = 500
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
metadata_store: Any,
|
||||
graph_store: Any,
|
||||
vector_store: Any,
|
||||
embedding_manager: Any,
|
||||
):
|
||||
self.metadata_store = metadata_store
|
||||
self.graph_store = graph_store
|
||||
self.vector_store = vector_store
|
||||
self.embedding_manager = embedding_manager
|
||||
|
||||
@staticmethod
|
||||
def build_relation_vector_text(subject: str, predicate: str, obj: str) -> str:
|
||||
s = str(subject or "").strip()
|
||||
p = str(predicate or "").strip()
|
||||
o = str(obj or "").strip()
|
||||
# 双表达:兼容关键词检索与自然语言问句
|
||||
return f"{s} {p} {o}\n{s}和{o}的关系是{p}"
|
||||
|
||||
async def ensure_relation_vector(
|
||||
self,
|
||||
hash_value: str,
|
||||
subject: str,
|
||||
predicate: str,
|
||||
obj: str,
|
||||
*,
|
||||
max_error_len: int = ERROR_MAX_LEN,
|
||||
) -> RelationWriteResult:
|
||||
"""
|
||||
为已有关系确保向量存在并更新状态。
|
||||
"""
|
||||
if hash_value in self.vector_store:
|
||||
self.metadata_store.set_relation_vector_state(hash_value, "ready")
|
||||
return RelationWriteResult(
|
||||
hash_value=hash_value,
|
||||
vector_written=False,
|
||||
vector_already_exists=True,
|
||||
vector_state="ready",
|
||||
)
|
||||
|
||||
self.metadata_store.set_relation_vector_state(hash_value, "pending")
|
||||
try:
|
||||
vector_text = self.build_relation_vector_text(subject, predicate, obj)
|
||||
embedding = await self.embedding_manager.encode(vector_text)
|
||||
self.vector_store.add(
|
||||
vectors=embedding.reshape(1, -1),
|
||||
ids=[hash_value],
|
||||
)
|
||||
self.metadata_store.set_relation_vector_state(hash_value, "ready")
|
||||
logger.info(
|
||||
"metric.relation_vector_write_success=1 metric.relation_vector_write_success_count=1 hash=%s",
|
||||
hash_value[:16],
|
||||
)
|
||||
return RelationWriteResult(
|
||||
hash_value=hash_value,
|
||||
vector_written=True,
|
||||
vector_already_exists=False,
|
||||
vector_state="ready",
|
||||
)
|
||||
except ValueError:
|
||||
# 向量已存在冲突,按成功处理
|
||||
self.metadata_store.set_relation_vector_state(hash_value, "ready")
|
||||
return RelationWriteResult(
|
||||
hash_value=hash_value,
|
||||
vector_written=False,
|
||||
vector_already_exists=True,
|
||||
vector_state="ready",
|
||||
)
|
||||
except Exception as e:
|
||||
err = str(e)[:max_error_len]
|
||||
self.metadata_store.set_relation_vector_state(
|
||||
hash_value,
|
||||
"failed",
|
||||
error=err,
|
||||
bump_retry=True,
|
||||
)
|
||||
logger.warning(
|
||||
"metric.relation_vector_write_fail=1 metric.relation_vector_write_fail_count=1 hash=%s err=%s",
|
||||
hash_value[:16],
|
||||
err,
|
||||
)
|
||||
return RelationWriteResult(
|
||||
hash_value=hash_value,
|
||||
vector_written=False,
|
||||
vector_already_exists=False,
|
||||
vector_state="failed",
|
||||
)
|
||||
|
||||
async def upsert_relation_with_vector(
|
||||
self,
|
||||
subject: str,
|
||||
predicate: str,
|
||||
obj: str,
|
||||
confidence: float = 1.0,
|
||||
source_paragraph: str = "",
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
*,
|
||||
write_vector: bool = True,
|
||||
) -> RelationWriteResult:
|
||||
"""
|
||||
统一关系写入:
|
||||
1) 写 metadata relation
|
||||
2) 写 graph edge relation_hash
|
||||
3) 按需写 relation vector
|
||||
"""
|
||||
rel_hash = self.metadata_store.add_relation(
|
||||
subject=subject,
|
||||
predicate=predicate,
|
||||
obj=obj,
|
||||
confidence=confidence,
|
||||
source_paragraph=source_paragraph,
|
||||
metadata=metadata or {},
|
||||
)
|
||||
self.graph_store.add_edges([(subject, obj)], relation_hashes=[rel_hash])
|
||||
|
||||
if not write_vector:
|
||||
self.metadata_store.set_relation_vector_state(rel_hash, "none")
|
||||
return RelationWriteResult(
|
||||
hash_value=rel_hash,
|
||||
vector_written=False,
|
||||
vector_already_exists=False,
|
||||
vector_state="none",
|
||||
)
|
||||
|
||||
return await self.ensure_relation_vector(
|
||||
hash_value=rel_hash,
|
||||
subject=subject,
|
||||
predicate=predicate,
|
||||
obj=obj,
|
||||
)
|
||||
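以下是 RelationWriteService 的最小用法示意,演示"先写 metadata、再写向量、失败不回滚"的调用方式。各 Fake* 类仅为内存桩,真实环境应传入插件初始化得到的存储实例;导入路径假设项目根目录已在 sys.path 中,且 src.common.logger 可用:

```python
# 用法示意(假设在宿主项目环境中运行;各 Fake* 类仅为演示桩)。
import asyncio
import numpy as np

from plugins.A_memorix.core.utils.relation_write_service import RelationWriteService


class FakeMetadataStore:
    def add_relation(self, **kwargs):
        return "demo-relation-hash"

    def set_relation_vector_state(self, hash_value, state, **kwargs):
        print(f"relation {hash_value[:8]} -> {state}")


class FakeGraphStore:
    def add_edges(self, edges, relation_hashes=None, weights=None):
        pass


class FakeVectorStore:
    def __contains__(self, hash_value):
        return False  # 模拟向量尚不存在

    def add(self, vectors, ids):
        print(f"write {len(ids)} vector(s)")


class FakeEmbeddingManager:
    async def encode(self, text):
        return np.zeros(8, dtype=np.float32)  # 真实实现返回模型向量


async def demo():
    service = RelationWriteService(
        metadata_store=FakeMetadataStore(),
        graph_store=FakeGraphStore(),
        vector_store=FakeVectorStore(),
        embedding_manager=FakeEmbeddingManager(),
    )
    result = await service.upsert_relation_with_vector("Alice", "likes", "Apple")
    print(result.vector_state)  # 预期输出 "ready"


if __name__ == "__main__":
    asyncio.run(demo())
```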
197
plugins/A_memorix/core/utils/runtime_self_check.py
Normal file
@@ -0,0 +1,197 @@
|
||||
"""Runtime self-check helpers for A_Memorix."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("A_Memorix.RuntimeSelfCheck")
|
||||
|
||||
_DEFAULT_SAMPLE_TEXT = "A_Memorix runtime self check"
|
||||
|
||||
|
||||
def _safe_int(value: Any, default: int = 0) -> int:
|
||||
try:
|
||||
return int(value)
|
||||
except Exception:
|
||||
return int(default)
|
||||
|
||||
|
||||
def _get_config_value(config: Any, key: str, default: Any = None) -> Any:
|
||||
getter = getattr(config, "get_config", None)
|
||||
if callable(getter):
|
||||
return getter(key, default)
|
||||
|
||||
current: Any = config
|
||||
for part in key.split("."):
|
||||
if isinstance(current, dict) and part in current:
|
||||
current = current[part]
|
||||
else:
|
||||
return default
|
||||
return current
|
||||
|
||||
|
||||
def _build_report(
|
||||
*,
|
||||
ok: bool,
|
||||
code: str,
|
||||
message: str,
|
||||
configured_dimension: int,
|
||||
vector_store_dimension: int,
|
||||
detected_dimension: int,
|
||||
encoded_dimension: int,
|
||||
elapsed_ms: float,
|
||||
sample_text: str,
|
||||
) -> Dict[str, Any]:
|
||||
return {
|
||||
"ok": bool(ok),
|
||||
"code": str(code or "").strip(),
|
||||
"message": str(message or "").strip(),
|
||||
"configured_dimension": int(configured_dimension),
|
||||
"vector_store_dimension": int(vector_store_dimension),
|
||||
"detected_dimension": int(detected_dimension),
|
||||
"encoded_dimension": int(encoded_dimension),
|
||||
"elapsed_ms": float(elapsed_ms),
|
||||
"sample_text": str(sample_text or ""),
|
||||
"checked_at": time.time(),
|
||||
}
|
||||
|
||||
|
||||
async def run_embedding_runtime_self_check(
|
||||
*,
|
||||
config: Any,
|
||||
vector_store: Optional[Any],
|
||||
embedding_manager: Optional[Any],
|
||||
sample_text: str = _DEFAULT_SAMPLE_TEXT,
|
||||
) -> Dict[str, Any]:
|
||||
"""Probe the real embedding path and compare dimensions with runtime storage."""
|
||||
configured_dimension = _safe_int(_get_config_value(config, "embedding.dimension", 0), 0)
|
||||
vector_store_dimension = _safe_int(getattr(vector_store, "dimension", 0), 0)
|
||||
|
||||
if vector_store is None or embedding_manager is None:
|
||||
return _build_report(
|
||||
ok=False,
|
||||
code="runtime_components_missing",
|
||||
message="vector_store 或 embedding_manager 未初始化",
|
||||
configured_dimension=configured_dimension,
|
||||
vector_store_dimension=vector_store_dimension,
|
||||
detected_dimension=0,
|
||||
encoded_dimension=0,
|
||||
elapsed_ms=0.0,
|
||||
sample_text=sample_text,
|
||||
)
|
||||
|
||||
start = time.perf_counter()
|
||||
detected_dimension = 0
|
||||
encoded_dimension = 0
|
||||
try:
|
||||
detected_dimension = _safe_int(await embedding_manager._detect_dimension(), 0)
|
||||
encoded = await embedding_manager.encode(sample_text)
|
||||
if isinstance(encoded, np.ndarray):
|
||||
encoded_dimension = int(encoded.shape[0]) if encoded.ndim == 1 else int(encoded.shape[-1])
|
||||
else:
|
||||
encoded_dimension = len(encoded) if encoded is not None else 0
|
||||
except Exception as exc:
|
||||
elapsed_ms = (time.perf_counter() - start) * 1000.0
|
||||
logger.warning("embedding runtime self-check failed: %s", exc)
|
||||
return _build_report(
|
||||
ok=False,
|
||||
code="embedding_probe_failed",
|
||||
message=f"embedding probe failed: {exc}",
|
||||
configured_dimension=configured_dimension,
|
||||
vector_store_dimension=vector_store_dimension,
|
||||
detected_dimension=detected_dimension,
|
||||
encoded_dimension=encoded_dimension,
|
||||
elapsed_ms=elapsed_ms,
|
||||
sample_text=sample_text,
|
||||
)
|
||||
|
||||
elapsed_ms = (time.perf_counter() - start) * 1000.0
|
||||
expected_dimension = vector_store_dimension or configured_dimension or detected_dimension
|
||||
if expected_dimension <= 0:
|
||||
return _build_report(
|
||||
ok=False,
|
||||
code="invalid_expected_dimension",
|
||||
message="无法确定期望 embedding 维度",
|
||||
configured_dimension=configured_dimension,
|
||||
vector_store_dimension=vector_store_dimension,
|
||||
detected_dimension=detected_dimension,
|
||||
encoded_dimension=encoded_dimension,
|
||||
elapsed_ms=elapsed_ms,
|
||||
sample_text=sample_text,
|
||||
)
|
||||
|
||||
if encoded_dimension != expected_dimension:
|
||||
msg = (
|
||||
"embedding 真实输出维度与当前向量存储不一致: "
|
||||
f"expected={expected_dimension}, encoded={encoded_dimension}"
|
||||
)
|
||||
logger.error(msg)
|
||||
return _build_report(
|
||||
ok=False,
|
||||
code="embedding_dimension_mismatch",
|
||||
message=msg,
|
||||
configured_dimension=configured_dimension,
|
||||
vector_store_dimension=vector_store_dimension,
|
||||
detected_dimension=detected_dimension,
|
||||
encoded_dimension=encoded_dimension,
|
||||
elapsed_ms=elapsed_ms,
|
||||
sample_text=sample_text,
|
||||
)
|
||||
|
||||
return _build_report(
|
||||
ok=True,
|
||||
code="ok",
|
||||
message="embedding runtime self-check passed",
|
||||
configured_dimension=configured_dimension,
|
||||
vector_store_dimension=vector_store_dimension,
|
||||
detected_dimension=detected_dimension,
|
||||
encoded_dimension=encoded_dimension,
|
||||
elapsed_ms=elapsed_ms,
|
||||
sample_text=sample_text,
|
||||
)
|
||||
|
||||
|
||||
async def ensure_runtime_self_check(
|
||||
plugin_or_config: Any,
|
||||
*,
|
||||
force: bool = False,
|
||||
sample_text: str = _DEFAULT_SAMPLE_TEXT,
|
||||
) -> Dict[str, Any]:
|
||||
"""Run or reuse cached runtime self-check report."""
|
||||
if plugin_or_config is None:
|
||||
return _build_report(
|
||||
ok=False,
|
||||
code="missing_plugin_or_config",
|
||||
message="plugin/config unavailable",
|
||||
configured_dimension=0,
|
||||
vector_store_dimension=0,
|
||||
detected_dimension=0,
|
||||
encoded_dimension=0,
|
||||
elapsed_ms=0.0,
|
||||
sample_text=sample_text,
|
||||
)
|
||||
|
||||
cache = getattr(plugin_or_config, "_runtime_self_check_report", None)
|
||||
if isinstance(cache, dict) and cache and not force:
|
||||
return cache
|
||||
|
||||
report = await run_embedding_runtime_self_check(
|
||||
config=getattr(plugin_or_config, "config", plugin_or_config),
|
||||
vector_store=getattr(plugin_or_config, "vector_store", None)
|
||||
if not isinstance(plugin_or_config, dict)
|
||||
else plugin_or_config.get("vector_store"),
|
||||
embedding_manager=getattr(plugin_or_config, "embedding_manager", None)
|
||||
if not isinstance(plugin_or_config, dict)
|
||||
else plugin_or_config.get("embedding_manager"),
|
||||
sample_text=sample_text,
|
||||
)
|
||||
try:
|
||||
setattr(plugin_or_config, "_runtime_self_check_report", report)
|
||||
except Exception:
|
||||
pass
|
||||
return report
|
||||
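启动阶段可以这样调用 ensure_runtime_self_check,在 embedding 实际输出维度与向量库不一致时尽早失败。以下为示意,假设插件实例带有 config / vector_store / embedding_manager 属性,并运行在可导入该模块的宿主环境中:

```python
# 用法示意:plugin 为任意带有 config / vector_store / embedding_manager 属性的对象。
from plugins.A_memorix.core.utils.runtime_self_check import ensure_runtime_self_check


async def check_on_startup(plugin) -> bool:
    report = await ensure_runtime_self_check(plugin)
    if not report["ok"]:
        # 维度不一致或探测失败时应阻止后续写入,避免污染向量库
        print(f"[A_Memorix] self-check failed: {report['code']} - {report['message']}")
        return False
    print(
        f"[A_Memorix] embedding dim ok ({report['encoded_dimension']}), "
        f"probe took {report['elapsed_ms']:.1f} ms"
    )
    return True
```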
90
plugins/A_memorix/core/utils/search_postprocess.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""Post-processing helpers for unified search execution."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, List, Tuple
|
||||
|
||||
from .path_fallback_service import find_paths_from_query, to_retrieval_results
|
||||
|
||||
|
||||
def apply_safe_content_dedup(results: List[Any]) -> Tuple[List[Any], int]:
|
||||
"""Deduplicate results by hash/content while preserving at least one entry."""
|
||||
if not results:
|
||||
return [], 0
|
||||
|
||||
unique_results: List[Any] = []
|
||||
seen_hashes = set()
|
||||
seen_contents = set()
|
||||
|
||||
for item in results:
|
||||
content = str(getattr(item, "content", "") or "").strip()
|
||||
if not content:
|
||||
continue
|
||||
|
||||
hash_value = str(getattr(item, "hash_value", "") or "").strip() or str(hash(content))
|
||||
if hash_value in seen_hashes:
|
||||
continue
|
||||
|
||||
is_dup = False
|
||||
for seen in seen_contents:
|
||||
if content in seen or seen in content:
|
||||
is_dup = True
|
||||
break
|
||||
if is_dup:
|
||||
continue
|
||||
|
||||
seen_hashes.add(hash_value)
|
||||
seen_contents.add(content)
|
||||
unique_results.append(item)
|
||||
|
||||
if not unique_results:
|
||||
unique_results.append(results[0])
|
||||
|
||||
removed_count = max(0, len(results) - len(unique_results))
|
||||
return unique_results, removed_count
|
||||
|
||||
|
||||
def maybe_apply_smart_path_fallback(
|
||||
*,
|
||||
query: str,
|
||||
results: List[Any],
|
||||
graph_store: Any,
|
||||
metadata_store: Any,
|
||||
enabled: bool,
|
||||
threshold: float,
|
||||
max_depth: int = 3,
|
||||
max_paths: int = 5,
|
||||
) -> Tuple[List[Any], bool, int]:
|
||||
"""Append indirect relation paths when semantic results are weak."""
|
||||
if not enabled or not str(query or "").strip():
|
||||
return results, False, 0
|
||||
if graph_store is None or metadata_store is None:
|
||||
return results, False, 0
|
||||
|
||||
max_score = 0.0
|
||||
if results:
|
||||
try:
|
||||
max_score = float(getattr(results[0], "score", 0.0) or 0.0)
|
||||
except Exception:
|
||||
max_score = 0.0
|
||||
|
||||
if max_score >= float(threshold):
|
||||
return results, False, 0
|
||||
|
||||
paths = find_paths_from_query(
|
||||
query=query,
|
||||
graph_store=graph_store,
|
||||
metadata_store=metadata_store,
|
||||
max_depth=max_depth,
|
||||
max_paths=max_paths,
|
||||
)
|
||||
if not paths:
|
||||
return results, False, 0
|
||||
|
||||
path_results = to_retrieval_results(paths)
|
||||
if not path_results:
|
||||
return results, False, 0
|
||||
|
||||
merged = list(path_results) + list(results)
|
||||
return merged, True, len(path_results)
|
||||
|
||||
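两个后处理函数可按如下方式组合:先做内容去重,再在语义得分低于阈值时用图路径补召回。示意中的 _Result 是虚构的轻量结果对象,真实检索结果只要具备 content / hash_value / score 属性即可;graph_store / metadata_store 需替换为真实实例:

```python
# 用法示意(假设在宿主项目环境中运行,路径回退依赖的 path_fallback_service 可正常导入)。
from dataclasses import dataclass

from plugins.A_memorix.core.utils.search_postprocess import (
    apply_safe_content_dedup,
    maybe_apply_smart_path_fallback,
)


@dataclass
class _Result:  # 虚构的结果对象,仅演示所需字段
    content: str
    hash_value: str = ""
    score: float = 0.0


results = [
    _Result("Alice likes Apple", hash_value="h1", score=0.31),
    _Result("Alice likes Apple", hash_value="h2", score=0.30),
]
deduped, removed = apply_safe_content_dedup(results)
print(removed)  # 1:第二条与第一条内容重复,被剔除

# 语义得分偏低时补图路径召回(graph_store / metadata_store 为假设的真实存储实例)
# merged, used_fallback, added = maybe_apply_smart_path_fallback(
#     query="Alice 和 Apple 是什么关系",
#     results=deduped,
#     graph_store=graph_store,
#     metadata_store=metadata_store,
#     enabled=True,
#     threshold=0.45,
# )
```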
170
plugins/A_memorix/core/utils/time_parser.py
Normal file
@@ -0,0 +1,170 @@
|
||||
"""
|
||||
时间解析工具。
|
||||
|
||||
约束:
|
||||
1. 查询参数(Action/Command/Tool)仅接受结构化绝对时间:
|
||||
- YYYY/MM/DD
|
||||
- YYYY/MM/DD HH:mm
|
||||
2. 入库时允许更宽松格式(含时间戳、YYYY-MM-DD 等)。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
|
||||
_QUERY_DATE_RE = re.compile(r"^\d{4}/\d{2}/\d{2}$")
|
||||
_QUERY_MINUTE_RE = re.compile(r"^\d{4}/\d{2}/\d{2} \d{2}:\d{2}$")
|
||||
_NUMERIC_RE = re.compile(r"^-?\d+(?:\.\d+)?$")
|
||||
|
||||
_INGEST_FORMATS = [
|
||||
"%Y/%m/%d %H:%M:%S",
|
||||
"%Y/%m/%d %H:%M",
|
||||
"%Y-%m-%d %H:%M:%S",
|
||||
"%Y-%m-%d %H:%M",
|
||||
"%Y-%m-%dT%H:%M:%S",
|
||||
"%Y-%m-%dT%H:%M",
|
||||
"%Y/%m/%d",
|
||||
"%Y-%m-%d",
|
||||
]
|
||||
|
||||
_INGEST_DATE_FORMATS = {"%Y/%m/%d", "%Y-%m-%d"}
|
||||
|
||||
|
||||
def parse_query_datetime_to_timestamp(value: str, is_end: bool = False) -> float:
|
||||
"""解析查询时间,仅支持 YYYY/MM/DD 或 YYYY/MM/DD HH:mm。"""
|
||||
text = str(value).strip()
|
||||
if not text:
|
||||
raise ValueError("时间不能为空")
|
||||
|
||||
if _QUERY_DATE_RE.fullmatch(text):
|
||||
dt = datetime.strptime(text, "%Y/%m/%d")
|
||||
if is_end:
|
||||
dt = dt.replace(hour=23, minute=59, second=0, microsecond=0)
|
||||
return dt.timestamp()
|
||||
|
||||
if _QUERY_MINUTE_RE.fullmatch(text):
|
||||
dt = datetime.strptime(text, "%Y/%m/%d %H:%M")
|
||||
return dt.timestamp()
|
||||
|
||||
raise ValueError(
|
||||
f"时间格式错误: {text}。仅支持 YYYY/MM/DD 或 YYYY/MM/DD HH:mm"
|
||||
)
|
||||
|
||||
|
||||
def parse_query_time_range(
|
||||
time_from: Optional[str],
|
||||
time_to: Optional[str],
|
||||
) -> Tuple[Optional[float], Optional[float]]:
|
||||
"""解析查询窗口并验证区间。"""
|
||||
ts_from = (
|
||||
parse_query_datetime_to_timestamp(time_from, is_end=False)
|
||||
if time_from
|
||||
else None
|
||||
)
|
||||
ts_to = (
|
||||
parse_query_datetime_to_timestamp(time_to, is_end=True)
|
||||
if time_to
|
||||
else None
|
||||
)
|
||||
|
||||
if ts_from is not None and ts_to is not None and ts_from > ts_to:
|
||||
raise ValueError("time_from 不能晚于 time_to")
|
||||
|
||||
return ts_from, ts_to
|
||||
|
||||
|
||||
def parse_ingest_datetime_to_timestamp(
|
||||
value: Any,
|
||||
is_end: bool = False,
|
||||
) -> Optional[float]:
|
||||
"""解析入库时间,允许 timestamp/常见字符串格式。"""
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
if isinstance(value, (int, float)):
|
||||
return float(value)
|
||||
|
||||
text = str(value).strip()
|
||||
if not text:
|
||||
return None
|
||||
|
||||
if _NUMERIC_RE.fullmatch(text):
|
||||
return float(text)
|
||||
|
||||
for fmt in _INGEST_FORMATS:
|
||||
try:
|
||||
dt = datetime.strptime(text, fmt)
|
||||
if fmt in _INGEST_DATE_FORMATS and is_end:
|
||||
dt = dt.replace(hour=23, minute=59, second=0, microsecond=0)
|
||||
return dt.timestamp()
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
raise ValueError(f"无法解析时间: {text}")
|
||||
|
||||
|
||||
def normalize_time_meta(time_meta: Optional[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
"""归一化 time_meta 到存储层字段。"""
|
||||
if not time_meta:
|
||||
return {}
|
||||
|
||||
normalized: Dict[str, Any] = {}
|
||||
|
||||
event_time = parse_ingest_datetime_to_timestamp(time_meta.get("event_time"))
|
||||
event_start = parse_ingest_datetime_to_timestamp(
|
||||
time_meta.get("event_time_start"),
|
||||
is_end=False,
|
||||
)
|
||||
event_end = parse_ingest_datetime_to_timestamp(
|
||||
time_meta.get("event_time_end"),
|
||||
is_end=True,
|
||||
)
|
||||
|
||||
time_range = time_meta.get("time_range")
|
||||
if (
|
||||
isinstance(time_range, (list, tuple))
|
||||
and len(time_range) == 2
|
||||
):
|
||||
if event_start is None:
|
||||
event_start = parse_ingest_datetime_to_timestamp(time_range[0], is_end=False)
|
||||
if event_end is None:
|
||||
event_end = parse_ingest_datetime_to_timestamp(time_range[1], is_end=True)
|
||||
|
||||
if event_start is not None and event_end is not None and event_start > event_end:
|
||||
raise ValueError("event_time_start 不能晚于 event_time_end")
|
||||
|
||||
if event_time is not None:
|
||||
normalized["event_time"] = event_time
|
||||
if event_start is not None:
|
||||
normalized["event_time_start"] = event_start
|
||||
if event_end is not None:
|
||||
normalized["event_time_end"] = event_end
|
||||
|
||||
granularity = time_meta.get("time_granularity")
|
||||
if granularity:
|
||||
normalized["time_granularity"] = str(granularity)
|
||||
else:
|
||||
raw_time_values = [
|
||||
time_meta.get("event_time"),
|
||||
time_meta.get("event_time_start"),
|
||||
time_meta.get("event_time_end"),
|
||||
]
|
||||
has_minute = any(isinstance(v, str) and ":" in v for v in raw_time_values if v is not None)
|
||||
normalized["time_granularity"] = "minute" if has_minute else "day"
|
||||
|
||||
confidence = time_meta.get("time_confidence")
|
||||
if confidence is not None:
|
||||
normalized["time_confidence"] = float(confidence)
|
||||
|
||||
return normalized
|
||||
|
||||
|
||||
def format_timestamp(ts: Optional[float]) -> Optional[str]:
|
||||
"""将 timestamp 格式化为 YYYY/MM/DD HH:mm。"""
|
||||
if ts is None:
|
||||
return None
|
||||
return datetime.fromtimestamp(ts).strftime("%Y/%m/%d %H:%M")
|
||||
|
||||
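查询端与入库端的时间格式差异可以用下面的片段直观验证(导入路径假设项目根目录已在 sys.path 中):

```python
from plugins.A_memorix.core.utils.time_parser import (
    format_timestamp,
    parse_ingest_datetime_to_timestamp,
    parse_query_time_range,
)

# 查询端:严格格式,纯日期自动补全到当天 00:00 / 23:59
ts_from, ts_to = parse_query_time_range("2024/05/01", "2024/05/01 18:30")
print(format_timestamp(ts_from), format_timestamp(ts_to))
# 2024/05/01 00:00  2024/05/01 18:30

# 入库端:宽松格式,时间戳与 YYYY-MM-DD 等均可
print(parse_ingest_datetime_to_timestamp("2024-05-01 08:00") is not None)  # True

try:
    parse_query_time_range("2024-05-01", None)  # 查询端不接受 YYYY-MM-DD
except ValueError as exc:
    print(exc)
```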
207
plugins/A_memorix/plugin.py
Normal file
@@ -0,0 +1,207 @@
|
||||
"""A_Memorix SDK plugin entry."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from maibot_sdk import MaiBotPlugin, Tool
|
||||
from maibot_sdk.types import ToolParameterInfo, ToolParamType
|
||||
|
||||
from A_memorix.core.runtime.sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel
|
||||
|
||||
|
||||
def _tool_param(name: str, param_type: ToolParamType, description: str, required: bool) -> ToolParameterInfo:
|
||||
return ToolParameterInfo(name=name, param_type=param_type, description=description, required=required)
|
||||
|
||||
|
||||
class AMemorixPlugin(MaiBotPlugin):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self._plugin_root = Path(__file__).resolve().parent
|
||||
self._plugin_config: Dict[str, Any] = {}
|
||||
self._kernel: Optional[SDKMemoryKernel] = None
|
||||
|
||||
def set_plugin_config(self, config: Dict[str, Any]) -> None:
|
||||
self._plugin_config = config or {}
|
||||
if self._kernel is not None:
|
||||
self._kernel.close()
|
||||
self._kernel = None
|
||||
|
||||
async def on_load(self):
|
||||
await self._get_kernel()
|
||||
|
||||
async def on_unload(self):
|
||||
if self._kernel is not None:
|
||||
self._kernel.close()
|
||||
self._kernel = None
|
||||
|
||||
async def _get_kernel(self) -> SDKMemoryKernel:
|
||||
if self._kernel is None:
|
||||
self._kernel = SDKMemoryKernel(plugin_root=self._plugin_root, config=self._plugin_config)
|
||||
await self._kernel.initialize()
|
||||
return self._kernel
|
||||
|
||||
@Tool(
|
||||
"search_memory",
|
||||
description="搜索长期记忆",
|
||||
parameters=[
|
||||
_tool_param("query", ToolParamType.STRING, "查询文本", False),
|
||||
_tool_param("limit", ToolParamType.INTEGER, "返回条数", False),
|
||||
_tool_param("mode", ToolParamType.STRING, "search/time/hybrid/episode/aggregate", False),
|
||||
_tool_param("chat_id", ToolParamType.STRING, "聊天流 ID", False),
|
||||
_tool_param("person_id", ToolParamType.STRING, "人物 ID", False),
|
||||
_tool_param("time_start", ToolParamType.FLOAT, "起始时间戳", False),
|
||||
_tool_param("time_end", ToolParamType.FLOAT, "结束时间戳", False),
|
||||
],
|
||||
)
|
||||
async def handle_search_memory(
|
||||
self,
|
||||
query: str = "",
|
||||
limit: int = 5,
|
||||
mode: str = "hybrid",
|
||||
chat_id: str = "",
|
||||
person_id: str = "",
|
||||
time_start: float | None = None,
|
||||
time_end: float | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
_ = kwargs
|
||||
kernel = await self._get_kernel()
|
||||
return await kernel.search_memory(
|
||||
KernelSearchRequest(
|
||||
query=query,
|
||||
limit=limit,
|
||||
mode=mode,
|
||||
chat_id=chat_id,
|
||||
person_id=person_id,
|
||||
time_start=time_start,
|
||||
time_end=time_end,
|
||||
)
|
||||
)
|
||||
|
||||
@Tool(
|
||||
"ingest_summary",
|
||||
description="写入聊天摘要到长期记忆",
|
||||
parameters=[
|
||||
_tool_param("external_id", ToolParamType.STRING, "外部幂等 ID", True),
|
||||
_tool_param("chat_id", ToolParamType.STRING, "聊天流 ID", True),
|
||||
_tool_param("text", ToolParamType.STRING, "摘要文本", True),
|
||||
_tool_param("time_start", ToolParamType.FLOAT, "起始时间戳", False),
|
||||
_tool_param("time_end", ToolParamType.FLOAT, "结束时间戳", False),
|
||||
],
|
||||
)
|
||||
async def handle_ingest_summary(
|
||||
self,
|
||||
external_id: str,
|
||||
chat_id: str,
|
||||
text: str,
|
||||
participants: Optional[List[str]] = None,
|
||||
time_start: float | None = None,
|
||||
time_end: float | None = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
_ = kwargs
|
||||
kernel = await self._get_kernel()
|
||||
return await kernel.ingest_summary(
|
||||
external_id=external_id,
|
||||
chat_id=chat_id,
|
||||
text=text,
|
||||
participants=participants,
|
||||
time_start=time_start,
|
||||
time_end=time_end,
|
||||
tags=tags,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
@Tool(
|
||||
"ingest_text",
|
||||
description="写入普通长期记忆文本",
|
||||
parameters=[
|
||||
_tool_param("external_id", ToolParamType.STRING, "外部幂等 ID", True),
|
||||
_tool_param("source_type", ToolParamType.STRING, "来源类型", True),
|
||||
_tool_param("text", ToolParamType.STRING, "原始文本", True),
|
||||
_tool_param("chat_id", ToolParamType.STRING, "聊天流 ID", False),
|
||||
_tool_param("timestamp", ToolParamType.FLOAT, "时间戳", False),
|
||||
],
|
||||
)
|
||||
async def handle_ingest_text(
|
||||
self,
|
||||
external_id: str,
|
||||
source_type: str,
|
||||
text: str,
|
||||
chat_id: str = "",
|
||||
person_ids: Optional[List[str]] = None,
|
||||
participants: Optional[List[str]] = None,
|
||||
timestamp: float | None = None,
|
||||
time_start: float | None = None,
|
||||
time_end: float | None = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
relations = kwargs.get("relations")
|
||||
entities = kwargs.get("entities")
|
||||
kernel = await self._get_kernel()
|
||||
return await kernel.ingest_text(
|
||||
external_id=external_id,
|
||||
source_type=source_type,
|
||||
text=text,
|
||||
chat_id=chat_id,
|
||||
person_ids=person_ids,
|
||||
participants=participants,
|
||||
timestamp=timestamp,
|
||||
time_start=time_start,
|
||||
time_end=time_end,
|
||||
tags=tags,
|
||||
metadata=metadata,
|
||||
entities=entities,
|
||||
relations=relations,
|
||||
)
|
||||
|
||||
@Tool(
|
||||
"get_person_profile",
|
||||
description="获取人物画像",
|
||||
parameters=[
|
||||
_tool_param("person_id", ToolParamType.STRING, "人物 ID", True),
|
||||
_tool_param("chat_id", ToolParamType.STRING, "聊天流 ID", False),
|
||||
_tool_param("limit", ToolParamType.INTEGER, "证据条数", False),
|
||||
],
|
||||
)
|
||||
async def handle_get_person_profile(self, person_id: str, chat_id: str = "", limit: int = 10, **kwargs):
|
||||
_ = kwargs
|
||||
kernel = await self._get_kernel()
|
||||
return await kernel.get_person_profile(person_id=person_id, chat_id=chat_id, limit=limit)
|
||||
|
||||
@Tool(
|
||||
"maintain_memory",
|
||||
description="维护长期记忆关系状态",
|
||||
parameters=[
|
||||
_tool_param("action", ToolParamType.STRING, "reinforce/protect/restore", True),
|
||||
_tool_param("target", ToolParamType.STRING, "目标哈希或查询文本", True),
|
||||
_tool_param("hours", ToolParamType.FLOAT, "保护时长(小时)", False),
|
||||
],
|
||||
)
|
||||
async def handle_maintain_memory(
|
||||
self,
|
||||
action: str,
|
||||
target: str,
|
||||
hours: float | None = None,
|
||||
reason: str = "",
|
||||
**kwargs,
|
||||
):
|
||||
_ = kwargs
|
||||
kernel = await self._get_kernel()
|
||||
return await kernel.maintain_memory(action=action, target=target, hours=hours, reason=reason)
|
||||
|
||||
@Tool("memory_stats", description="获取长期记忆统计", parameters=[])
|
||||
async def handle_memory_stats(self, **kwargs):
|
||||
_ = kwargs
|
||||
kernel = await self._get_kernel()
|
||||
return kernel.memory_stats()
|
||||
|
||||
|
||||
def create_plugin():
|
||||
return AMemorixPlugin()
|
||||
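下面的片段示意一次典型的调用链:先写入摘要,再检索。直接调用 handler 仅用于演示参数形态,真实环境中这些工具由 MaiBot SDK 注册并经宿主分发,装饰后的 handler 能否直接调用取决于 SDK 实现;external_id、chat_id、数据目录等取值均为假设:

```python
# 调用示意(假设 maibot_sdk 与插件依赖均可导入)。
import asyncio
import time


async def demo():
    plugin = create_plugin()
    plugin.set_plugin_config({"storage": {"data_dir": "./data"}})

    await plugin.handle_ingest_summary(
        external_id="demo:1",
        chat_id="group_42",
        text="主题:周末聚餐\n概括:大家约定周六晚上一起聚餐。",
        time_start=time.time() - 3600,
        time_end=time.time(),
    )
    hits = await plugin.handle_search_memory(query="周六 聚餐", limit=3, mode="hybrid")
    print(hits)


# asyncio.run(demo())
```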
535
plugins/A_memorix/scripts/convert_lpmm.py
Normal file
@@ -0,0 +1,535 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
LPMM 到 A_memorix 存储转换器
|
||||
|
||||
功能:
|
||||
1. 读取 LPMM parquet 文件 (paragraph.parquet, entity.parquet, relation.parquet)
|
||||
2. 读取 LPMM 图文件 (graph.graphml 或 graph_structure.pkl)
|
||||
3. 直接写入 A_memorix 二进制 VectorStore 和稀疏 GraphStore
|
||||
4. 绕过 Embedding 生成以节省 Token
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import json
|
||||
import argparse
|
||||
import asyncio
|
||||
import pickle
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, List, Tuple
|
||||
import numpy as np
|
||||
import tomlkit
|
||||
|
||||
# 设置路径
|
||||
current_dir = Path(__file__).resolve().parent
|
||||
plugin_root = current_dir.parent
|
||||
project_root = plugin_root.parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
def _build_arg_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(description="将 LPMM 数据转换为 A_memorix 格式")
|
||||
parser.add_argument("--input", "-i", required=True, help="包含 LPMM 数据的输入目录 (parquet, graphml)")
|
||||
parser.add_argument("--output", "-o", required=True, help="A_memorix 数据的输出目录")
|
||||
parser.add_argument("--dim", type=int, default=384, help="Embedding 维度 (必须与 LPMM 模型匹配)")
|
||||
parser.add_argument("--batch-size", type=int, default=1024, help="Parquet 分批读取大小 (默认 1024)")
|
||||
parser.add_argument(
|
||||
"--skip-relation-vector-rebuild",
|
||||
action="store_true",
|
||||
help="跳过按关系元数据重建关系向量(默认开启)",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
# --help/-h fast path: avoid heavy host/plugin bootstrap
|
||||
if any(arg in {"-h", "--help"} for arg in sys.argv[1:]):
|
||||
_build_arg_parser().print_help()
|
||||
sys.exit(0)
|
||||
|
||||
# 设置日志
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
logger = logging.getLogger("LPMM_Converter")
|
||||
|
||||
try:
|
||||
import networkx as nx
|
||||
from scipy import sparse
|
||||
import pyarrow.parquet as pq
|
||||
except ImportError as e:
|
||||
logger.error(f"缺少依赖: {e}")
|
||||
logger.error("请安装: pip install pandas pyarrow networkx scipy")
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
# 优先采用相对导入 (将插件根目录加入路径)
|
||||
# 这样可以避免硬编码插件名称 (plugins.A_memorix)
|
||||
if str(plugin_root) not in sys.path:
|
||||
sys.path.insert(0, str(plugin_root))
|
||||
|
||||
from core.storage.vector_store import VectorStore
|
||||
from core.storage.graph_store import GraphStore
|
||||
from core.storage.metadata_store import MetadataStore
|
||||
from core.storage import QuantizationType, SparseMatrixFormat
|
||||
from core.embedding import create_embedding_api_adapter
|
||||
from core.utils.relation_write_service import RelationWriteService
|
||||
|
||||
except ImportError as e:
|
||||
logger.error(f"无法导入 A_memorix 核心模块: {e}")
|
||||
logger.error("请确保在正确的环境中运行,且已安装所有依赖。")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
class LPMMConverter:
|
||||
def __init__(
|
||||
self,
|
||||
lpmm_data_dir: Path,
|
||||
output_dir: Path,
|
||||
dimension: int = 384,
|
||||
batch_size: int = 1024,
|
||||
rebuild_relation_vectors: bool = True,
|
||||
):
|
||||
self.lpmm_dir = lpmm_data_dir
|
||||
self.output_dir = output_dir
|
||||
self.dimension = dimension
|
||||
self.batch_size = max(1, int(batch_size))
|
||||
self.rebuild_relation_vectors = bool(rebuild_relation_vectors)
|
||||
|
||||
self.vector_dir = output_dir / "vectors"
|
||||
self.graph_dir = output_dir / "graph"
|
||||
self.metadata_dir = output_dir / "metadata"
|
||||
|
||||
self.vector_store = None
|
||||
self.graph_store = None
|
||||
self.metadata_store = None
|
||||
self.embedding_manager = None
|
||||
self.relation_write_service = None
|
||||
# LPMM 原 ID -> A_memorix ID 映射(用于图重写)
|
||||
self.id_mapping: Dict[str, str] = {}
|
||||
|
||||
def _register_id_mapping(self, raw_id: Any, mapped_id: str, p_type: str) -> None:
|
||||
"""记录 ID 映射,兼容带/不带类型前缀两种格式。"""
|
||||
if raw_id is None:
|
||||
return
|
||||
|
||||
raw = str(raw_id).strip()
|
||||
if not raw:
|
||||
return
|
||||
|
||||
self.id_mapping[raw] = mapped_id
|
||||
|
||||
prefix = f"{p_type}-"
|
||||
if raw.startswith(prefix):
|
||||
self.id_mapping[raw[len(prefix):]] = mapped_id
|
||||
else:
|
||||
self.id_mapping[prefix + raw] = mapped_id
|
||||
|
||||
def _map_node_id(self, node: Any) -> str:
|
||||
"""将图节点 ID 映射到转换后的 A_memorix ID。"""
|
||||
node_key = str(node)
|
||||
return self.id_mapping.get(node_key, node_key)
|
||||
|
||||
def initialize_stores(self):
|
||||
"""初始化空的 A_memorix 存储"""
|
||||
logger.info(f"正在初始化存储于 {self.output_dir}...")
|
||||
|
||||
# 初始化 VectorStore (A_memorix 默认使用 INT8 量化)
|
||||
self.vector_store = VectorStore(
|
||||
dimension=self.dimension,
|
||||
quantization_type=QuantizationType.INT8,
|
||||
data_dir=self.vector_dir
|
||||
)
|
||||
self.vector_store.clear() # 清空旧数据
|
||||
|
||||
# 初始化 GraphStore (使用 CSR 格式)
|
||||
self.graph_store = GraphStore(
|
||||
matrix_format=SparseMatrixFormat.CSR,
|
||||
data_dir=self.graph_dir
|
||||
)
|
||||
self.graph_store.clear()
|
||||
|
||||
# 初始化 MetadataStore
|
||||
self.metadata_store = MetadataStore(data_dir=self.metadata_dir)
|
||||
self.metadata_store.connect()
|
||||
# 此处不强制清空元数据表:转换默认面向全新目录或整体覆盖的场景。
# MetadataStore 基于 SQLite,目录为新建时会自动创建数据库文件。
|
||||
if self.rebuild_relation_vectors:
|
||||
self._init_relation_vector_service()
|
||||
|
||||
def _load_plugin_config(self) -> Dict[str, Any]:
|
||||
config_path = plugin_root / "config.toml"
|
||||
if not config_path.exists():
|
||||
return {}
|
||||
try:
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
parsed = tomlkit.load(f)
|
||||
return dict(parsed) if isinstance(parsed, dict) else {}
|
||||
except Exception as e:
|
||||
logger.warning(f"读取 config.toml 失败,使用默认 embedding 配置: {e}")
|
||||
return {}
|
||||
|
||||
def _init_relation_vector_service(self) -> None:
|
||||
if not self.rebuild_relation_vectors:
|
||||
return
|
||||
cfg = self._load_plugin_config()
|
||||
emb_cfg = cfg.get("embedding", {}) if isinstance(cfg, dict) else {}
|
||||
if not isinstance(emb_cfg, dict):
|
||||
emb_cfg = {}
|
||||
try:
|
||||
self.embedding_manager = create_embedding_api_adapter(
|
||||
batch_size=int(emb_cfg.get("batch_size", 32)),
|
||||
max_concurrent=int(emb_cfg.get("max_concurrent", 5)),
|
||||
default_dimension=int(emb_cfg.get("dimension", self.dimension)),
|
||||
model_name=str(emb_cfg.get("model_name", "auto")),
|
||||
retry_config=emb_cfg.get("retry", {}) if isinstance(emb_cfg.get("retry", {}), dict) else {},
|
||||
)
|
||||
self.relation_write_service = RelationWriteService(
|
||||
metadata_store=self.metadata_store,
|
||||
graph_store=self.graph_store,
|
||||
vector_store=self.vector_store,
|
||||
embedding_manager=self.embedding_manager,
|
||||
)
|
||||
except Exception as e:
|
||||
self.embedding_manager = None
|
||||
self.relation_write_service = None
|
||||
logger.warning(f"初始化关系向量重建服务失败,将跳过关系向量回填: {e}")
|
||||
|
||||
async def _rebuild_relation_vectors(self) -> None:
|
||||
if not self.rebuild_relation_vectors:
|
||||
return
|
||||
if self.relation_write_service is None:
|
||||
logger.warning("关系向量重建已启用,但写入服务不可用,已跳过。")
|
||||
return
|
||||
|
||||
rows = self.metadata_store.get_relations()
|
||||
if not rows:
|
||||
logger.info("未发现关系元数据,无需重建关系向量。")
|
||||
return
|
||||
|
||||
success = 0
|
||||
failed = 0
|
||||
skipped = 0
|
||||
for row in rows:
|
||||
result = await self.relation_write_service.ensure_relation_vector(
|
||||
hash_value=str(row["hash"]),
|
||||
subject=str(row.get("subject", "")),
|
||||
predicate=str(row.get("predicate", "")),
|
||||
obj=str(row.get("object", "")),
|
||||
)
|
||||
if result.vector_state == "ready":
|
||||
if result.vector_written:
|
||||
success += 1
|
||||
else:
|
||||
skipped += 1
|
||||
else:
|
||||
failed += 1
|
||||
|
||||
logger.info(
|
||||
"关系向量重建完成: total=%s success=%s skipped=%s failed=%s",
|
||||
len(rows),
|
||||
success,
|
||||
skipped,
|
||||
failed,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _parse_relation_text(text: str) -> Tuple[str, str, str]:
|
||||
raw = str(text or "").strip()
|
||||
if not raw:
|
||||
return "", "", ""
|
||||
if "|" in raw:
|
||||
parts = [p.strip() for p in raw.split("|") if p.strip()]
|
||||
if len(parts) >= 3:
|
||||
return parts[0], parts[1], parts[2]
|
||||
if "->" in raw:
|
||||
parts = [p.strip() for p in raw.split("->") if p.strip()]
|
||||
if len(parts) >= 3:
|
||||
return parts[0], parts[1], parts[2]
|
||||
pieces = raw.split()
|
||||
if len(pieces) >= 3:
|
||||
return pieces[0], pieces[1], " ".join(pieces[2:])
|
||||
return "", "", ""
|
||||
|
||||
def _import_relation_metadata_from_parquet(self, relation_path: Path) -> int:
|
||||
if not relation_path.exists():
|
||||
return 0
|
||||
|
||||
try:
|
||||
parquet_file = pq.ParquetFile(relation_path)
|
||||
except Exception as e:
|
||||
logger.warning(f"读取 relation.parquet 失败,跳过关系元数据导入: {e}")
|
||||
return 0
|
||||
|
||||
cols = set(parquet_file.schema_arrow.names)
|
||||
has_triple_cols = {"subject", "predicate", "object"}.issubset(cols)
|
||||
content_col = "str" if "str" in cols else ("content" if "content" in cols else "")
|
||||
|
||||
imported_hashes = set()
|
||||
imported = 0
|
||||
for record_batch in parquet_file.iter_batches(batch_size=self.batch_size):
|
||||
df_batch = record_batch.to_pandas()
|
||||
for _, row in df_batch.iterrows():
|
||||
subject = ""
|
||||
predicate = ""
|
||||
obj = ""
|
||||
if has_triple_cols:
|
||||
subject = str(row.get("subject", "") or "").strip()
|
||||
predicate = str(row.get("predicate", "") or "").strip()
|
||||
obj = str(row.get("object", "") or "").strip()
|
||||
elif content_col:
|
||||
subject, predicate, obj = self._parse_relation_text(row.get(content_col, ""))
|
||||
|
||||
if not (subject and predicate and obj):
|
||||
continue
|
||||
|
||||
rel_hash = self.metadata_store.add_relation(
|
||||
subject=subject,
|
||||
predicate=predicate,
|
||||
obj=obj,
|
||||
source_paragraph=None,
|
||||
)
|
||||
if rel_hash in imported_hashes:
|
||||
continue
|
||||
imported_hashes.add(rel_hash)
|
||||
self.graph_store.add_edges([(subject, obj)], relation_hashes=[rel_hash])
|
||||
try:
|
||||
self.metadata_store.set_relation_vector_state(rel_hash, "none")
|
||||
except Exception:
|
||||
pass
|
||||
imported += 1
|
||||
|
||||
return imported
|
||||
|
||||
def convert_vectors(self):
|
||||
"""将 Parquet 向量转换为 VectorStore"""
|
||||
# LPMM 默认文件名
|
||||
parquet_files = {
|
||||
"paragraph": self.lpmm_dir / "paragraph.parquet",
|
||||
"entity": self.lpmm_dir / "entity.parquet",
|
||||
"relation": self.lpmm_dir / "relation.parquet"
|
||||
}
|
||||
|
||||
total_vectors = 0
|
||||
|
||||
for p_type, p_path in parquet_files.items():
|
||||
# 关系向量在当前脚本中无法保证与 MetadataStore 的关系记录一一对应,
|
||||
# 直接导入会污染召回结果(命中后无法反查 relation 元数据)。
|
||||
if p_type == "relation":
|
||||
relation_count = self._import_relation_metadata_from_parquet(p_path)
|
||||
logger.warning(
|
||||
"跳过 relation.parquet 向量导入(保持一致性);已导入关系元数据: %s",
|
||||
relation_count,
|
||||
)
|
||||
continue
|
||||
|
||||
if not p_path.exists():
|
||||
logger.warning(f"文件未找到: {p_path}, 跳过 {p_type} 向量。")
|
||||
continue
|
||||
|
||||
logger.info(f"正在处理 {p_type} 向量,来源: {p_path}...")
|
||||
try:
|
||||
parquet_file = pq.ParquetFile(p_path)
|
||||
total_rows = parquet_file.metadata.num_rows
|
||||
if total_rows == 0:
|
||||
logger.info(f"{p_path} 为空,跳过。")
|
||||
continue
|
||||
|
||||
# LPMM Schema: 'hash', 'embedding', 'str'
|
||||
cols = parquet_file.schema_arrow.names
|
||||
# 兼容性检查
|
||||
content_col = 'str' if 'str' in cols else 'content'
|
||||
emb_col = 'embedding'
|
||||
hash_col = 'hash'
|
||||
|
||||
if content_col not in cols or emb_col not in cols:
|
||||
logger.error(f"{p_path} 中缺少必要列 (需包含 {content_col}, {emb_col})。发现: {cols}")
|
||||
continue
|
||||
|
||||
batch_columns = [content_col, emb_col]
|
||||
if hash_col in cols:
|
||||
batch_columns.append(hash_col)
|
||||
|
||||
processed_rows = 0
|
||||
added_for_type = 0
|
||||
batch_idx = 0
|
||||
|
||||
for record_batch in parquet_file.iter_batches(
|
||||
batch_size=self.batch_size,
|
||||
columns=batch_columns,
|
||||
):
|
||||
batch_idx += 1
|
||||
df_batch = record_batch.to_pandas()
|
||||
|
||||
embeddings_list = []
|
||||
ids_list = []
|
||||
|
||||
# 同时处理元数据映射
|
||||
for _, row in df_batch.iterrows():
|
||||
processed_rows += 1
|
||||
content = row[content_col]
|
||||
emb = row[emb_col]
|
||||
|
||||
if content is None or (isinstance(content, float) and np.isnan(content)):
|
||||
continue
|
||||
content = str(content).strip()
|
||||
if not content:
|
||||
continue
|
||||
|
||||
if emb is None or len(emb) == 0:
|
||||
continue
|
||||
|
||||
# 先写 MetadataStore,并使用其返回的真实 hash 作为向量 ID
|
||||
# 保证检索返回 ID 可以直接反查元数据。
|
||||
store_id = None
|
||||
if p_type == "paragraph":
|
||||
store_id = self.metadata_store.add_paragraph(
|
||||
content=content,
|
||||
source="lpmm_import",
|
||||
knowledge_type="factual",
|
||||
)
|
||||
elif p_type == "entity":
|
||||
store_id = self.metadata_store.add_entity(name=content)
|
||||
else:
|
||||
continue
|
||||
|
||||
raw_hash = row[hash_col] if hash_col in df_batch.columns else None
|
||||
if raw_hash is not None and not (isinstance(raw_hash, float) and np.isnan(raw_hash)):
|
||||
self._register_id_mapping(raw_hash, store_id, p_type)
|
||||
|
||||
# 确保 embedding 是 numpy 数组
|
||||
emb_np = np.array(emb, dtype=np.float32)
|
||||
if emb_np.shape[0] != self.dimension:
|
||||
logger.error(f"维度不匹配: {emb_np.shape[0]} vs {self.dimension}")
|
||||
continue
|
||||
|
||||
embeddings_list.append(emb_np)
|
||||
ids_list.append(store_id)
|
||||
|
||||
if embeddings_list:
|
||||
# 分批添加到向量存储
|
||||
vectors_np = np.stack(embeddings_list)
|
||||
count = self.vector_store.add(vectors_np, ids_list)
|
||||
added_for_type += count
|
||||
total_vectors += count
|
||||
|
||||
if batch_idx == 1 or batch_idx % 10 == 0:
|
||||
logger.info(
|
||||
f"[{p_type}] 批次 {batch_idx}: 已扫描 {processed_rows}/{total_rows}, 已导入 {added_for_type}"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"{p_type} 向量处理完成:总扫描 {processed_rows},总导入 {added_for_type}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理 {p_path} 时出错: {e}")
|
||||
|
||||
# 提交向量存储
|
||||
self.vector_store.save()
|
||||
logger.info(f"向量转换完成。总向量数: {total_vectors}")
|
||||
|
||||
def convert_graph(self):
|
||||
"""将 LPMM 图转换为 GraphStore"""
|
||||
# LPMM 默认文件名是 rag-graph.graphml
|
||||
graph_files = [
|
||||
self.lpmm_dir / "rag-graph.graphml",
|
||||
self.lpmm_dir / "graph.graphml",
|
||||
self.lpmm_dir / "graph_structure.pkl"
|
||||
]
|
||||
|
||||
nx_graph = None
|
||||
|
||||
for g_path in graph_files:
|
||||
if g_path.exists():
|
||||
logger.info(f"发现图文件: {g_path}")
|
||||
try:
|
||||
if g_path.suffix == ".graphml":
|
||||
nx_graph = nx.read_graphml(g_path)
|
||||
elif g_path.suffix == ".pkl":
|
||||
with open(g_path, "rb") as f:
|
||||
data = pickle.load(f)
|
||||
# LPMM 可能会将图存储在包装类中
|
||||
if hasattr(data, "graph") and isinstance(data.graph, nx.Graph):
|
||||
nx_graph = data.graph
|
||||
elif isinstance(data, nx.Graph):
|
||||
nx_graph = data
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"加载 {g_path} 失败: {e}")
|
||||
|
||||
if nx_graph is None:
|
||||
logger.warning("未找到有效的图文件。跳过图转换。")
|
||||
return
|
||||
|
||||
logger.info(f"已加载图,包含 {nx_graph.number_of_nodes()} 个节点和 {nx_graph.number_of_edges()} 条边。")
|
||||
|
||||
# 1. 添加节点
|
||||
# LPMM 节点通常是哈希或带前缀的字符串。
|
||||
# 我们需要将它们映射到 A_memorix 格式。
|
||||
# 如果 LPMM 使用 "entity-HASH",则与 A_memorix 匹配。
|
||||
|
||||
nodes_to_add = []
|
||||
node_attrs = {}
|
||||
|
||||
for node, attrs in nx_graph.nodes(data=True):
|
||||
# 假设 LPMM 使用一致的命名 "entity-..." 或 "paragraph-..."
|
||||
mapped_node = self._map_node_id(node)
|
||||
nodes_to_add.append(mapped_node)
|
||||
if attrs:
|
||||
node_attrs[mapped_node] = attrs
|
||||
|
||||
self.graph_store.add_nodes(nodes_to_add, node_attrs)
|
||||
|
||||
# 2. 添加边
|
||||
edges_to_add = []
|
||||
weights = []
|
||||
|
||||
for u, v, data in nx_graph.edges(data=True):
|
||||
weight = data.get("weight", 1.0)
|
||||
edges_to_add.append((self._map_node_id(u), self._map_node_id(v)))
|
||||
weights.append(float(weight))
|
||||
|
||||
# 如果可能,将关系同步到 MetadataStore
|
||||
# 但图的边并不总是包含关系谓词
|
||||
# 如果 LPMM 边数据有 'predicate',我们可以添加到元数据
|
||||
# 通常 LPMM 边只携带权重,谓词信息在简单图中可能已丢失
|
||||
|
||||
if edges_to_add:
|
||||
self.graph_store.add_edges(edges_to_add, weights)
|
||||
|
||||
self.graph_store.save()
|
||||
logger.info("图转换完成。")
|
||||
|
||||
def run(self):
|
||||
self.initialize_stores()
|
||||
self.convert_vectors()
|
||||
self.convert_graph()
|
||||
asyncio.run(self._rebuild_relation_vectors())
|
||||
self.vector_store.save()
|
||||
self.graph_store.save()
|
||||
self.metadata_store.close()
|
||||
logger.info("所有转换成功完成。")
|
||||
|
||||
|
||||
def main():
|
||||
parser = _build_arg_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
input_path = Path(args.input)
|
||||
output_path = Path(args.output)
|
||||
|
||||
if not input_path.exists():
|
||||
logger.error(f"输入目录不存在: {input_path}")
|
||||
sys.exit(1)
|
||||
|
||||
converter = LPMMConverter(
|
||||
input_path,
|
||||
output_path,
|
||||
dimension=args.dim,
|
||||
batch_size=args.batch_size,
|
||||
rebuild_relation_vectors=not bool(args.skip_relation_vector_rebuild),
|
||||
)
|
||||
converter.run()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
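转换脚本既可按命令行运行,也可在 Python 中直接构造 LPMMConverter。下面的路径与维度均为假设值;--dim / dimension 必须与生成 LPMM 数据时使用的 embedding 模型一致:

```python
# 等价的命令行(示意):
#   python plugins/A_memorix/scripts/convert_lpmm.py -i ./lpmm_data -o ./data --dim 1024
from pathlib import Path

converter = LPMMConverter(
    Path("./lpmm_data"),   # 含 paragraph.parquet / entity.parquet / relation.parquet 与图文件
    Path("./data"),        # A_memorix 数据输出目录
    dimension=1024,        # 必须与 LPMM embedding 维度一致
    batch_size=2048,
    rebuild_relation_vectors=True,
)
converter.run()
```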
110
plugins/A_memorix/scripts/migrate_chat_history.py
Normal file
@@ -0,0 +1,110 @@
|
||||
#!/usr/bin/env python3
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import sqlite3
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
CURRENT_DIR = Path(__file__).resolve().parent
|
||||
PLUGIN_ROOT = CURRENT_DIR.parent
|
||||
WORKSPACE_ROOT = PLUGIN_ROOT.parent
|
||||
MAIBOT_ROOT = WORKSPACE_ROOT / "MaiBot"
|
||||
DEFAULT_DB_PATH = MAIBOT_ROOT / "data" / "MaiBot.db"
|
||||
|
||||
if str(WORKSPACE_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(WORKSPACE_ROOT))
|
||||
if str(MAIBOT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(MAIBOT_ROOT))
|
||||
|
||||
from A_memorix.core.runtime.sdk_memory_kernel import SDKMemoryKernel # noqa: E402
|
||||
|
||||
|
||||
def _parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description="迁移 MaiBot chat_history 到 A_Memorix")
|
||||
parser.add_argument("--db-path", default=str(DEFAULT_DB_PATH), help="MaiBot SQLite 路径")
|
||||
parser.add_argument("--data-dir", default="./data", help="A_Memorix 数据目录")
|
||||
parser.add_argument("--limit", type=int, default=0, help="限制迁移条数,0 表示全部")
|
||||
parser.add_argument("--dry-run", action="store_true", help="仅预览,不写入")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def _to_timestamp(value: Any) -> float | None:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, (int, float)):
|
||||
return float(value)
|
||||
text = str(value).strip()
|
||||
if not text:
|
||||
return None
|
||||
try:
|
||||
return datetime.fromisoformat(text).timestamp()
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
async def _main() -> int:
|
||||
args = _parse_args()
|
||||
db_path = Path(args.db_path).resolve()
|
||||
if not db_path.exists():
|
||||
print(f"数据库不存在: {db_path}")
|
||||
return 1
|
||||
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
conn.row_factory = sqlite3.Row
|
||||
sql = """
|
||||
SELECT id, session_id, start_timestamp, end_timestamp, participants, theme, keywords, summary
|
||||
FROM chat_history
|
||||
ORDER BY id ASC
|
||||
"""
|
||||
if int(args.limit or 0) > 0:
|
||||
sql += " LIMIT ?"
|
||||
rows = conn.execute(sql, (int(args.limit),)).fetchall()
|
||||
else:
|
||||
rows = conn.execute(sql).fetchall()
|
||||
conn.close()
|
||||
|
||||
print(f"chat_history 待处理: {len(rows)}")
|
||||
if args.dry_run:
|
||||
for row in rows[:5]:
|
||||
print(f"- id={row['id']} session={row['session_id']} theme={row['theme']}")
|
||||
return 0
|
||||
|
||||
kernel = SDKMemoryKernel(plugin_root=PLUGIN_ROOT, config={"storage": {"data_dir": args.data_dir}})
|
||||
await kernel.initialize()
|
||||
migrated = 0
|
||||
skipped = 0
|
||||
for row in rows:
|
||||
participants = json.loads(row["participants"]) if row["participants"] else []
|
||||
keywords = json.loads(row["keywords"]) if row["keywords"] else []
|
||||
theme = str(row["theme"] or "").strip()
|
||||
summary = str(row["summary"] or "").strip()
|
||||
text = f"主题:{theme}\n概括:{summary}".strip()
|
||||
result: Dict[str, Any] = await kernel.ingest_summary(
|
||||
external_id=f"chat_history:{row['id']}",
|
||||
chat_id=str(row["session_id"] or ""),
|
||||
text=text,
|
||||
participants=participants,
|
||||
time_start=_to_timestamp(row["start_timestamp"]),
|
||||
time_end=_to_timestamp(row["end_timestamp"]),
|
||||
tags=keywords,
|
||||
metadata={"theme": theme, "source_row_id": int(row["id"])},
|
||||
)
|
||||
if result.get("stored_ids"):
|
||||
migrated += 1
|
||||
else:
|
||||
skipped += 1
|
||||
|
||||
print(f"迁移完成: migrated={migrated} skipped={skipped}")
|
||||
print(json.dumps(kernel.memory_stats(), ensure_ascii=False))
|
||||
kernel.close()
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(asyncio.run(_main()))
|
||||
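_to_timestamp 的容错行为如下:数值与 ISO 字符串都能解析,无法识别的输入返回 None。实际迁移前建议先加 --dry-run 预览:

```python
print(_to_timestamp(1714500000))             # 1714500000.0
print(_to_timestamp("2024-05-01T08:00:00"))  # 对应本地时区的 POSIX 时间戳
print(_to_timestamp("not-a-date"))           # None
```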
120
plugins/A_memorix/scripts/migrate_person_memory_points.py
Normal file
@@ -0,0 +1,120 @@
|
||||
#!/usr/bin/env python3
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import sqlite3
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
|
||||
|
||||
CURRENT_DIR = Path(__file__).resolve().parent
|
||||
PLUGIN_ROOT = CURRENT_DIR.parent
|
||||
WORKSPACE_ROOT = PLUGIN_ROOT.parent
|
||||
MAIBOT_ROOT = WORKSPACE_ROOT / "MaiBot"
|
||||
DEFAULT_DB_PATH = MAIBOT_ROOT / "data" / "MaiBot.db"
|
||||
|
||||
if str(WORKSPACE_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(WORKSPACE_ROOT))
|
||||
if str(MAIBOT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(MAIBOT_ROOT))
|
||||
|
||||
from A_memorix.core.runtime.sdk_memory_kernel import SDKMemoryKernel # noqa: E402
|
||||
|
||||
|
||||
def _parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description="迁移 MaiBot person_info.memory_points 到 A_Memorix")
|
||||
parser.add_argument("--db-path", default=str(DEFAULT_DB_PATH), help="MaiBot SQLite 路径")
|
||||
parser.add_argument("--data-dir", default="./data", help="A_Memorix 数据目录")
|
||||
parser.add_argument("--limit", type=int, default=0, help="限制迁移人数,0 表示全部")
|
||||
parser.add_argument("--dry-run", action="store_true", help="仅预览,不写入")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def _parse_memory_points(raw_value: Any) -> List[Dict[str, Any]]:
|
||||
try:
|
||||
values = json.loads(raw_value) if raw_value else []
|
||||
except Exception:
|
||||
values = []
|
||||
items: List[Dict[str, Any]] = []
|
||||
for index, item in enumerate(values):
|
||||
text = str(item or "").strip()
|
||||
if not text:
|
||||
continue
|
||||
parts = text.split(":")
|
||||
if len(parts) >= 3:
|
||||
category = parts[0].strip()
|
||||
content = ":".join(parts[1:-1]).strip()
|
||||
weight = parts[-1].strip()
|
||||
else:
|
||||
category = "其他"
|
||||
content = text
|
||||
weight = "1.0"
|
||||
if content:
|
||||
items.append({"index": index, "category": category or "其他", "content": content, "weight": weight or "1.0"})
|
||||
return items
|
||||
|
||||
|
||||
async def _main() -> int:
|
||||
args = _parse_args()
|
||||
db_path = Path(args.db_path).resolve()
|
||||
if not db_path.exists():
|
||||
print(f"数据库不存在: {db_path}")
|
||||
return 1
|
||||
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
conn.row_factory = sqlite3.Row
|
||||
sql = """
|
||||
SELECT person_id, person_name, user_nickname, memory_points
|
||||
FROM person_info
|
||||
WHERE memory_points IS NOT NULL AND memory_points != ''
|
||||
ORDER BY id ASC
|
||||
"""
|
||||
if int(args.limit or 0) > 0:
|
||||
sql += " LIMIT ?"
|
||||
rows = conn.execute(sql, (int(args.limit),)).fetchall()
|
||||
else:
|
||||
rows = conn.execute(sql).fetchall()
|
||||
conn.close()
|
||||
|
||||
preview_total = sum(len(_parse_memory_points(row["memory_points"])) for row in rows)
|
||||
print(f"person_info 待迁移人物: {len(rows)} 记忆点: {preview_total}")
|
||||
if args.dry_run:
|
||||
for row in rows[:5]:
|
||||
print(f"- person_id={row['person_id']} person_name={row['person_name'] or row['user_nickname']}")
|
||||
return 0
|
||||
|
||||
kernel = SDKMemoryKernel(plugin_root=PLUGIN_ROOT, config={"storage": {"data_dir": args.data_dir}})
|
||||
await kernel.initialize()
|
||||
migrated = 0
|
||||
skipped = 0
|
||||
for row in rows:
|
||||
person_id = str(row["person_id"] or "").strip()
|
||||
if not person_id:
|
||||
continue
|
||||
display_name = str(row["person_name"] or row["user_nickname"] or "").strip()
|
||||
for item in _parse_memory_points(row["memory_points"]):
|
||||
result: Dict[str, Any] = await kernel.ingest_text(
|
||||
external_id=f"person_memory:{person_id}:{item['index']}",
|
||||
source_type="person_fact",
|
||||
text=f"[{item['category']}] {item['content']}",
|
||||
person_ids=[person_id],
|
||||
tags=[item["category"]],
|
||||
entities=[person_id, display_name] if display_name else [person_id],
|
||||
metadata={"category": item["category"], "weight": item["weight"], "display_name": display_name},
|
||||
)
|
||||
if result.get("stored_ids"):
|
||||
migrated += 1
|
||||
else:
|
||||
skipped += 1
|
||||
|
||||
print(f"迁移完成: migrated={migrated} skipped={skipped}")
|
||||
print(json.dumps(kernel.memory_stats(), ensure_ascii=False))
|
||||
kernel.close()
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(asyncio.run(_main()))
|
||||
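_parse_memory_points 按 "类别:内容:权重" 拆分记忆点,内容中多出的冒号会被完整保留;不含冒号的条目归入"其他"。以下输入为虚构示例:

```python
raw = '["爱好:喜欢下雨天:散步:0.8", "记性很好"]'
print(_parse_memory_points(raw))
# [{'index': 0, 'category': '爱好', 'content': '喜欢下雨天:散步', 'weight': '0.8'},
#  {'index': 1, 'category': '其他', 'content': '记性很好', 'weight': '1.0'}]
```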