feat:新增 A_Memorix 记忆插件

引入 A_Memorix 插件(v2.0.0)——一个轻量级的长期记忆提供器。新增插件清单(manifest)和入口(AMemorixPlugin),并提供完整的核心能力:嵌入(基于哈希的 EmbeddingAPIAdapter、EmbeddingManager、预设)、检索(双路径检索器、PageRank、图关系召回、BM25 稀疏索引、阈值与融合配置)、存储与元数据层,以及大量实用工具和迁移/转换脚本。同时更新 .gitignore 以允许 /plugins/A_memorix。该变更为在宿主应用中实现统一的记忆摄取、检索、分析与维护奠定了基础。
This commit is contained in:
DawnARC
2026-03-18 21:33:15 +08:00
parent a5a6d2cb26
commit 999e7246e2
48 changed files with 17070 additions and 0 deletions

View File

@@ -0,0 +1,54 @@
"""检索模块 - 双路检索与排序"""
from .dual_path import (
DualPathRetriever,
RetrievalStrategy,
RetrievalResult,
DualPathRetrieverConfig,
TemporalQueryOptions,
FusionConfig,
RelationIntentConfig,
)
from .pagerank import (
PersonalizedPageRank,
PageRankConfig,
create_ppr_from_graph,
)
from .threshold import (
DynamicThresholdFilter,
ThresholdMethod,
ThresholdConfig,
)
from .sparse_bm25 import (
SparseBM25Index,
SparseBM25Config,
)
from .graph_relation_recall import (
GraphRelationRecallConfig,
GraphRelationRecallService,
)
__all__ = [
# DualPathRetriever
"DualPathRetriever",
"RetrievalStrategy",
"RetrievalResult",
"DualPathRetrieverConfig",
"TemporalQueryOptions",
"FusionConfig",
"RelationIntentConfig",
# PersonalizedPageRank
"PersonalizedPageRank",
"PageRankConfig",
"create_ppr_from_graph",
# DynamicThresholdFilter
"DynamicThresholdFilter",
"ThresholdMethod",
"ThresholdConfig",
# Sparse BM25
"SparseBM25Index",
"SparseBM25Config",
# Graph relation recall
"GraphRelationRecallConfig",
"GraphRelationRecallService",
]

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,272 @@
"""Graph-assisted relation candidate recall for relation-oriented queries."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Sequence, Set
from src.common.logger import get_logger
logger = get_logger("A_Memorix.GraphRelationRecall")
@dataclass
class GraphRelationRecallConfig:
"""Configuration for controlled graph relation recall."""
enabled: bool = True
candidate_k: int = 24
max_hop: int = 1
allow_two_hop_pair: bool = True
max_paths: int = 4
def __post_init__(self) -> None:
self.enabled = bool(self.enabled)
self.candidate_k = max(1, int(self.candidate_k))
self.max_hop = max(1, int(self.max_hop))
self.allow_two_hop_pair = bool(self.allow_two_hop_pair)
self.max_paths = max(1, int(self.max_paths))
@dataclass
class GraphRelationCandidate:
"""A graph-derived relation candidate before retriever-side fusion."""
hash_value: str
subject: str
predicate: str
object: str
confidence: float
graph_seed_entities: List[str]
graph_hops: int
graph_candidate_type: str
supporting_paragraph_count: int
def to_payload(self) -> Dict[str, Any]:
content = f"{self.subject} {self.predicate} {self.object}"
return {
"hash": self.hash_value,
"content": content,
"subject": self.subject,
"predicate": self.predicate,
"object": self.object,
"confidence": self.confidence,
"graph_seed_entities": list(self.graph_seed_entities),
"graph_hops": int(self.graph_hops),
"graph_candidate_type": self.graph_candidate_type,
"supporting_paragraph_count": int(self.supporting_paragraph_count),
}
class GraphRelationRecallService:
"""Collect relation candidates from the entity graph in a controlled way."""
def __init__(
self,
*,
graph_store: Any,
metadata_store: Any,
config: Optional[GraphRelationRecallConfig] = None,
) -> None:
self.graph_store = graph_store
self.metadata_store = metadata_store
self.config = config or GraphRelationRecallConfig()
def recall(
self,
*,
seed_entities: Sequence[str],
) -> List[GraphRelationCandidate]:
if not self.config.enabled:
return []
if self.graph_store is None or self.metadata_store is None:
return []
seeds = self._normalize_seed_entities(seed_entities)
if not seeds:
return []
seen_hashes: Set[str] = set()
candidates: List[GraphRelationCandidate] = []
if len(seeds) >= 2:
self._collect_direct_pair_candidates(
seed_a=seeds[0],
seed_b=seeds[1],
seen_hashes=seen_hashes,
out=candidates,
)
if (
len(candidates) < 3
and self.config.allow_two_hop_pair
and len(candidates) < self.config.candidate_k
):
self._collect_two_hop_pair_candidates(
seed_a=seeds[0],
seed_b=seeds[1],
seen_hashes=seen_hashes,
out=candidates,
)
else:
self._collect_one_hop_seed_candidates(
seed=seeds[0],
seen_hashes=seen_hashes,
out=candidates,
)
return candidates[: self.config.candidate_k]
def _normalize_seed_entities(self, seed_entities: Sequence[str]) -> List[str]:
out: List[str] = []
seen = set()
for raw in list(seed_entities)[:2]:
resolved = None
try:
resolved = self.graph_store.find_node(str(raw), ignore_case=True)
except Exception:
resolved = None
if not resolved:
continue
canon = str(resolved).strip().lower()
if not canon or canon in seen:
continue
seen.add(canon)
out.append(str(resolved))
return out
def _collect_direct_pair_candidates(
self,
*,
seed_a: str,
seed_b: str,
seen_hashes: Set[str],
out: List[GraphRelationCandidate],
) -> None:
relation_hashes = []
relation_hashes.extend(self.graph_store.get_relation_hashes_for_edge(seed_a, seed_b))
relation_hashes.extend(self.graph_store.get_relation_hashes_for_edge(seed_b, seed_a))
self._append_relation_hashes(
relation_hashes=relation_hashes,
seen_hashes=seen_hashes,
out=out,
candidate_type="direct_pair",
graph_hops=1,
graph_seed_entities=[seed_a, seed_b],
)
def _collect_two_hop_pair_candidates(
self,
*,
seed_a: str,
seed_b: str,
seen_hashes: Set[str],
out: List[GraphRelationCandidate],
) -> None:
try:
paths = self.graph_store.find_paths(
seed_a,
seed_b,
max_depth=2,
max_paths=self.config.max_paths,
)
except Exception as e:
logger.debug("graph two-hop recall skipped: %s", e)
return
for path_nodes in paths:
if len(out) >= self.config.candidate_k:
break
if not isinstance(path_nodes, Sequence) or len(path_nodes) < 3:
continue
if len(path_nodes) != 3:
continue
for idx in range(len(path_nodes) - 1):
if len(out) >= self.config.candidate_k:
break
u = str(path_nodes[idx])
v = str(path_nodes[idx + 1])
relation_hashes = []
relation_hashes.extend(self.graph_store.get_relation_hashes_for_edge(u, v))
relation_hashes.extend(self.graph_store.get_relation_hashes_for_edge(v, u))
self._append_relation_hashes(
relation_hashes=relation_hashes,
seen_hashes=seen_hashes,
out=out,
candidate_type="two_hop_pair",
graph_hops=2,
graph_seed_entities=[seed_a, seed_b],
)
def _collect_one_hop_seed_candidates(
self,
*,
seed: str,
seen_hashes: Set[str],
out: List[GraphRelationCandidate],
) -> None:
try:
relation_hashes = self.graph_store.get_incident_relation_hashes(
seed,
limit=self.config.candidate_k,
)
except Exception as e:
logger.debug("graph one-hop recall skipped: %s", e)
return
self._append_relation_hashes(
relation_hashes=relation_hashes,
seen_hashes=seen_hashes,
out=out,
candidate_type="one_hop_seed",
graph_hops=min(1, self.config.max_hop),
graph_seed_entities=[seed],
)
def _append_relation_hashes(
self,
*,
relation_hashes: Sequence[str],
seen_hashes: Set[str],
out: List[GraphRelationCandidate],
candidate_type: str,
graph_hops: int,
graph_seed_entities: Sequence[str],
) -> None:
for relation_hash in sorted({str(h) for h in relation_hashes if str(h).strip()}):
if len(out) >= self.config.candidate_k:
break
if relation_hash in seen_hashes:
continue
candidate = self._build_candidate(
relation_hash=relation_hash,
candidate_type=candidate_type,
graph_hops=graph_hops,
graph_seed_entities=graph_seed_entities,
)
if candidate is None:
continue
seen_hashes.add(relation_hash)
out.append(candidate)
def _build_candidate(
self,
*,
relation_hash: str,
candidate_type: str,
graph_hops: int,
graph_seed_entities: Sequence[str],
) -> Optional[GraphRelationCandidate]:
relation = self.metadata_store.get_relation(relation_hash)
if relation is None:
return None
supporting_paragraphs = self.metadata_store.get_paragraphs_by_relation(relation_hash)
return GraphRelationCandidate(
hash_value=relation_hash,
subject=str(relation.get("subject", "")),
predicate=str(relation.get("predicate", "")),
object=str(relation.get("object", "")),
confidence=float(relation.get("confidence", 1.0) or 1.0),
graph_seed_entities=[str(x) for x in graph_seed_entities],
graph_hops=int(graph_hops),
graph_candidate_type=str(candidate_type),
supporting_paragraph_count=len(supporting_paragraphs),
)

View File

@@ -0,0 +1,482 @@
"""
Personalized PageRank实现
提供个性化的图节点排序功能。
"""
from typing import Dict, List, Optional, Tuple, Union, Any
from dataclasses import dataclass
import numpy as np
from src.common.logger import get_logger
from ..storage import GraphStore
from ..utils.matcher import AhoCorasick
logger = get_logger("A_Memorix.PersonalizedPageRank")
@dataclass
class PageRankConfig:
"""
PageRank配置
属性:
alpha: 阻尼系数0-1之间
max_iter: 最大迭代次数
tol: 收敛阈值
normalize: 是否归一化结果
min_iterations: 最小迭代次数
"""
alpha: float = 0.85
max_iter: int = 100
tol: float = 1e-6
normalize: bool = True
min_iterations: int = 20
def __post_init__(self):
"""验证配置"""
if not 0 <= self.alpha < 1:
raise ValueError(f"alpha必须在[0, 1)之间: {self.alpha}")
if self.max_iter <= 0:
raise ValueError(f"max_iter必须大于0: {self.max_iter}")
if self.tol <= 0:
raise ValueError(f"tol必须大于0: {self.tol}")
if self.min_iterations < 0:
raise ValueError(f"min_iterations必须大于等于0: {self.min_iterations}")
if self.min_iterations >= self.max_iter:
raise ValueError(f"min_iterations必须小于max_iter")
class PersonalizedPageRank:
"""
Personalized PageRank计算器
功能:
- 个性化向量支持
- 快速收敛检测
- 结果归一化
- 批量计算
- 统计信息
参数:
graph_store: 图存储
config: PageRank配置
"""
def __init__(
self,
graph_store: GraphStore,
config: Optional[PageRankConfig] = None,
):
"""
初始化PPR计算器
Args:
graph_store: 图存储
config: PageRank配置
"""
self.graph_store = graph_store
self.config = config or PageRankConfig()
# 统计信息
self._total_computations = 0
self._total_iterations = 0
self._convergence_history: List[int] = []
logger.info(
f"PersonalizedPageRank 初始化: "
f"alpha={self.config.alpha}, "
f"max_iter={self.config.max_iter}"
)
# 缓存 Aho-Corasick 匹配器
self._ac_matcher: Optional[AhoCorasick] = None
self._ac_nodes_count = 0
def compute(
self,
personalization: Optional[Dict[str, float]] = None,
alpha: Optional[float] = None,
max_iter: Optional[int] = None,
normalize: Optional[bool] = None,
) -> Dict[str, float]:
"""
计算Personalized PageRank
Args:
personalization: 个性化向量 {节点名: 权重}
alpha: 阻尼系数(覆盖配置值)
max_iter: 最大迭代次数(覆盖配置值)
normalize: 是否归一化(覆盖配置值)
Returns:
节点PageRank值字典 {节点名: 分数}
"""
# 使用覆盖值或配置值
alpha = alpha if alpha is not None else self.config.alpha
max_iter = max_iter if max_iter is not None else self.config.max_iter
normalize = normalize if normalize is not None else self.config.normalize
# 调用GraphStore的compute_pagerank
scores = self.graph_store.compute_pagerank(
personalization=personalization,
alpha=alpha,
max_iter=max_iter,
tol=self.config.tol,
)
# 归一化(如果需要)
if normalize and scores:
total = sum(scores.values())
if total > 0:
scores = {node: score / total for node, score in scores.items()}
# 更新统计
self._total_computations += 1
logger.debug(
f"PPR计算完成: {len(scores)} 个节点, "
f"personalization_nodes={len(personalization) if personalization else 0}"
)
return scores
def compute_batch(
self,
personalization_list: List[Dict[str, float]],
normalize: bool = True,
) -> List[Dict[str, float]]:
"""
批量计算PPR
Args:
personalization_list: 个性化向量列表
normalize: 是否归一化
Returns:
PageRank值字典列表
"""
results = []
for i, personalization in enumerate(personalization_list):
logger.debug(f"计算第 {i+1}/{len(personalization_list)} 个PPR")
scores = self.compute(personalization=personalization, normalize=normalize)
results.append(scores)
return results
def compute_for_entities(
self,
entities: List[str],
weights: Optional[List[float]] = None,
normalize: bool = True,
) -> Dict[str, float]:
"""
为实体列表计算PPR
Args:
entities: 实体列表
weights: 权重列表(默认均匀权重)
normalize: 是否归一化
Returns:
PageRank值字典
"""
if not entities:
logger.warning("实体列表为空返回均匀PPR")
return self.compute(personalization=None, normalize=normalize)
# 构建个性化向量
if weights is None:
weights = [1.0] * len(entities)
if len(weights) != len(entities):
raise ValueError(f"权重数量与实体数量不匹配: {len(weights)} vs {len(entities)}")
personalization = {entity: weight for entity, weight in zip(entities, weights)}
return self.compute(personalization=personalization, normalize=normalize)
def compute_for_query(
self,
query: str,
entity_extractor: Optional[callable] = None,
normalize: bool = True,
) -> Dict[str, float]:
"""
为查询计算PPR
Args:
query: 查询文本
entity_extractor: 实体提取函数(可选)
normalize: 是否归一化
Returns:
PageRank值字典
"""
# 提取实体
if entity_extractor is not None:
entities = entity_extractor(query)
else:
# 简单实现:基于图中的节点匹配
entities = self._extract_entities_from_query(query)
if not entities:
logger.debug(f"未从查询中提取到实体: '{query}'")
return self.compute(personalization=None, normalize=normalize)
# 计算PPR
return self.compute_for_entities(entities, normalize=normalize)
def rank_nodes(
self,
scores: Dict[str, float],
top_k: Optional[int] = None,
min_score: float = 0.0,
) -> List[Tuple[str, float]]:
"""
对节点排序
Args:
scores: PageRank分数字典
top_k: 返回前k个节点None表示全部
min_score: 最小分数阈值
Returns:
排序后的节点列表 [(节点名, 分数), ...]
"""
# 过滤低分节点
filtered = [(node, score) for node, score in scores.items() if score >= min_score]
# 按分数降序排序
sorted_nodes = sorted(filtered, key=lambda x: x[1], reverse=True)
# 返回top_k
if top_k is not None:
sorted_nodes = sorted_nodes[:top_k]
return sorted_nodes
def get_personalization_vector(
self,
nodes: List[str],
method: str = "uniform",
) -> Dict[str, float]:
"""
生成个性化向量
Args:
nodes: 节点列表
method: 生成方法
- "uniform": 均匀权重
- "degree": 按度数加权
- "inverse_degree": 按度数反比加权
Returns:
个性化向量 {节点名: 权重}
"""
if not nodes:
return {}
if method == "uniform":
# 均匀权重
weight = 1.0 / len(nodes)
return {node: weight for node in nodes}
elif method == "degree":
# 按度数加权
node_degrees = {}
for node in nodes:
neighbors = self.graph_store.get_neighbors(node)
node_degrees[node] = len(neighbors)
total_degree = sum(node_degrees.values())
if total_degree > 0:
return {node: degree / total_degree for node, degree in node_degrees.items()}
else:
return {node: 1.0 / len(nodes) for node in nodes}
elif method == "inverse_degree":
# 按度数反比加权
node_degrees = {}
for node in nodes:
neighbors = self.graph_store.get_neighbors(node)
node_degrees[node] = len(neighbors)
# 反度数
inv_degrees = {node: 1.0 / (degree + 1) for node, degree in node_degrees.items()}
total_inv = sum(inv_degrees.values())
if total_inv > 0:
return {node: inv / total_inv for node, inv in inv_degrees.items()}
else:
return {node: 1.0 / len(nodes) for node in nodes}
else:
raise ValueError(f"不支持的个性化向量生成方法: {method}")
def compare_scores(
self,
scores1: Dict[str, float],
scores2: Dict[str, float],
) -> Dict[str, Dict[str, float]]:
"""
比较两组PPR分数
Args:
scores1: 第一组分数
scores2: 第二组分数
Returns:
比较结果 {
"common_nodes": {节点: (score1, score2)},
"only_in_1": {节点: score1},
"only_in_2": {节点: score2},
}
"""
common_nodes = {}
only_in_1 = {}
only_in_2 = {}
all_nodes = set(scores1.keys()) | set(scores2.keys())
for node in all_nodes:
if node in scores1 and node in scores2:
common_nodes[node] = (scores1[node], scores2[node])
elif node in scores1:
only_in_1[node] = scores1[node]
else:
only_in_2[node] = scores2[node]
return {
"common_nodes": common_nodes,
"only_in_1": only_in_1,
"only_in_2": only_in_2,
}
def get_statistics(self) -> Dict[str, Any]:
"""
获取统计信息
Returns:
统计信息字典
"""
avg_iterations = (
self._total_iterations / self._total_computations
if self._total_computations > 0
else 0
)
return {
"config": {
"alpha": self.config.alpha,
"max_iter": self.config.max_iter,
"tol": self.config.tol,
"normalize": self.config.normalize,
"min_iterations": self.config.min_iterations,
},
"statistics": {
"total_computations": self._total_computations,
"total_iterations": self._total_iterations,
"avg_iterations": avg_iterations,
"convergence_history": self._convergence_history.copy(),
},
"graph": {
"num_nodes": self.graph_store.num_nodes,
"num_edges": self.graph_store.num_edges,
},
}
def reset_statistics(self) -> None:
"""重置统计信息"""
self._total_computations = 0
self._total_iterations = 0
self._convergence_history.clear()
logger.info("统计信息已重置")
def _extract_entities_from_query(self, query: str) -> List[str]:
"""
从查询中提取实体(简化实现)
Args:
query: 查询文本
Returns:
实体列表
"""
# 获取所有节点
all_nodes = self.graph_store.get_nodes()
if not all_nodes:
return []
# 检查是否需要更新 Aho-Corasick 匹配器
if self._ac_matcher is None or self._ac_nodes_count != len(all_nodes):
self._ac_matcher = AhoCorasick()
for node in all_nodes:
# 统一转为小写进行不区分大小写匹配
self._ac_matcher.add_pattern(node.lower())
self._ac_matcher.build()
self._ac_nodes_count = len(all_nodes)
# 执行匹配
query_lower = query.lower()
stats = self._ac_matcher.find_all(query_lower)
# 转换回原始的大小写(这里简化为从 all_nodes 中找,或者 AC 存原始值)
# 为了简单AC 中 add_pattern 存的是小写
# 我们需要一个映射:小写 -> 原始
node_map = {node.lower(): node for node in all_nodes}
entities = [node_map[low_name] for low_name in stats.keys()]
return entities
@property
def num_computations(self) -> int:
"""计算次数"""
return self._total_computations
@property
def avg_iterations(self) -> float:
"""平均迭代次数"""
if self._total_computations == 0:
return 0.0
return self._total_iterations / self._total_computations
def __repr__(self) -> str:
return (
f"PersonalizedPageRank("
f"alpha={self.config.alpha}, "
f"computations={self._total_computations})"
)
def create_ppr_from_graph(
graph_store: GraphStore,
alpha: float = 0.85,
max_iter: int = 100,
) -> PersonalizedPageRank:
"""
从图存储创建PPR计算器
Args:
graph_store: 图存储
alpha: 阻尼系数
max_iter: 最大迭代次数
Returns:
PPR计算器实例
"""
config = PageRankConfig(
alpha=alpha,
max_iter=max_iter,
)
return PersonalizedPageRank(
graph_store=graph_store,
config=config,
)

View File

@@ -0,0 +1,402 @@
"""
稀疏检索组件FTS5 + BM25
支持:
- 懒加载索引连接
- jieba / char n-gram 分词
- 可卸载并收缩 SQLite 内存缓存
"""
from __future__ import annotations
import re
import sqlite3
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from src.common.logger import get_logger
from ..storage import MetadataStore
logger = get_logger("A_Memorix.SparseBM25")
try:
import jieba # type: ignore
HAS_JIEBA = True
except Exception:
HAS_JIEBA = False
jieba = None
@dataclass
class SparseBM25Config:
"""BM25 稀疏检索配置。"""
enabled: bool = True
backend: str = "fts5"
lazy_load: bool = True
mode: str = "auto" # auto | fallback_only | hybrid
tokenizer_mode: str = "jieba" # jieba | mixed | char_2gram
jieba_user_dict: str = ""
char_ngram_n: int = 2
candidate_k: int = 80
max_doc_len: int = 2000
enable_ngram_fallback_index: bool = True
enable_like_fallback: bool = False
enable_relation_sparse_fallback: bool = True
relation_candidate_k: int = 60
relation_max_doc_len: int = 512
unload_on_disable: bool = True
shrink_memory_on_unload: bool = True
def __post_init__(self) -> None:
self.backend = str(self.backend or "fts5").strip().lower()
self.mode = str(self.mode or "auto").strip().lower()
self.tokenizer_mode = str(self.tokenizer_mode or "jieba").strip().lower()
self.char_ngram_n = max(1, int(self.char_ngram_n))
self.candidate_k = max(1, int(self.candidate_k))
self.max_doc_len = max(0, int(self.max_doc_len))
self.relation_candidate_k = max(1, int(self.relation_candidate_k))
self.relation_max_doc_len = max(0, int(self.relation_max_doc_len))
if self.backend != "fts5":
raise ValueError(f"sparse.backend 暂仅支持 fts5: {self.backend}")
if self.mode not in {"auto", "fallback_only", "hybrid"}:
raise ValueError(f"sparse.mode 非法: {self.mode}")
if self.tokenizer_mode not in {"jieba", "mixed", "char_2gram"}:
raise ValueError(f"sparse.tokenizer_mode 非法: {self.tokenizer_mode}")
class SparseBM25Index:
"""
基于 SQLite FTS5 的 BM25 检索适配层。
"""
def __init__(
self,
metadata_store: MetadataStore,
config: Optional[SparseBM25Config] = None,
):
self.metadata_store = metadata_store
self.config = config or SparseBM25Config()
self._conn: Optional[sqlite3.Connection] = None
self._loaded: bool = False
self._jieba_dict_loaded: bool = False
@property
def loaded(self) -> bool:
return self._loaded and self._conn is not None
def ensure_loaded(self) -> bool:
"""按需加载 FTS 连接与索引。"""
if not self.config.enabled:
return False
if self.loaded:
return True
db_path = self.metadata_store.get_db_path()
conn = sqlite3.connect(
str(db_path),
check_same_thread=False,
timeout=30.0,
)
conn.row_factory = sqlite3.Row
conn.execute("PRAGMA journal_mode=WAL")
conn.execute("PRAGMA synchronous=NORMAL")
conn.execute("PRAGMA temp_store=MEMORY")
if not self.metadata_store.ensure_fts_schema(conn=conn):
conn.close()
return False
self.metadata_store.ensure_fts_backfilled(conn=conn)
# 关系稀疏检索按独立开关加载,避免不必要的初始化开销。
if self.config.enable_relation_sparse_fallback:
self.metadata_store.ensure_relations_fts_schema(conn=conn)
self.metadata_store.ensure_relations_fts_backfilled(conn=conn)
if self.config.enable_ngram_fallback_index:
self.metadata_store.ensure_paragraph_ngram_schema(conn=conn)
self.metadata_store.ensure_paragraph_ngram_backfilled(
n=self.config.char_ngram_n,
conn=conn,
)
self._conn = conn
self._loaded = True
self._prepare_tokenizer()
logger.info(
"SparseBM25Index loaded: backend=fts5, tokenizer=%s, mode=%s",
self.config.tokenizer_mode,
self.config.mode,
)
return True
def _prepare_tokenizer(self) -> None:
if self._jieba_dict_loaded:
return
if self.config.tokenizer_mode not in {"jieba", "mixed"}:
return
if not HAS_JIEBA:
logger.warning("jieba 不可用tokenizer 将退化为 char n-gram")
return
user_dict = str(self.config.jieba_user_dict or "").strip()
if user_dict:
try:
jieba.load_userdict(user_dict) # type: ignore[union-attr]
logger.info("已加载 jieba 用户词典: %s", user_dict)
except Exception as e:
logger.warning("加载 jieba 用户词典失败: %s", e)
self._jieba_dict_loaded = True
def _tokenize_jieba(self, text: str) -> List[str]:
if not HAS_JIEBA:
return []
try:
tokens = list(jieba.cut_for_search(text)) # type: ignore[union-attr]
return [t.strip().lower() for t in tokens if t and t.strip()]
except Exception:
return []
def _tokenize_char_ngram(self, text: str, n: int) -> List[str]:
compact = re.sub(r"\s+", "", text.lower())
if not compact:
return []
if len(compact) < n:
return [compact]
return [compact[i : i + n] for i in range(0, len(compact) - n + 1)]
def _tokenize(self, text: str) -> List[str]:
text = str(text or "").strip()
if not text:
return []
mode = self.config.tokenizer_mode
if mode == "jieba":
tokens = self._tokenize_jieba(text)
if tokens:
return list(dict.fromkeys(tokens))
return self._tokenize_char_ngram(text, self.config.char_ngram_n)
if mode == "mixed":
toks = self._tokenize_jieba(text)
toks.extend(self._tokenize_char_ngram(text, self.config.char_ngram_n))
return list(dict.fromkeys([t for t in toks if t]))
return list(dict.fromkeys(self._tokenize_char_ngram(text, self.config.char_ngram_n)))
def _build_match_query(self, tokens: List[str]) -> str:
safe_tokens: List[str] = []
for token in tokens:
t = token.replace('"', '""').strip()
if not t:
continue
safe_tokens.append(f'"{t}"')
if not safe_tokens:
return ""
# 采用 OR 提升召回,再交由 RRF 和阈值做稳健排序。
return " OR ".join(safe_tokens[:64])
def _fallback_substring_search(
self,
tokens: List[str],
limit: int,
) -> List[Dict[str, Any]]:
"""
当 FTS5 因分词不一致召回为空时,退化为子串匹配召回。
说明:
- FTS 索引当前采用 unicode61 tokenizer。
- 若查询 token 来源为 char n-gram 或中文词元,可能与索引 token 不一致。
- 这里使用 SQL LIKE 做兜底,按命中 token 覆盖度打分。
"""
if not tokens:
return []
# 去重并裁剪 token 数量,避免生成超长 SQL。
uniq_tokens = [t for t in dict.fromkeys(tokens) if t]
uniq_tokens = uniq_tokens[:32]
if not uniq_tokens:
return []
if self.config.enable_ngram_fallback_index:
try:
# 允许运行时切换开关后按需补齐 schema/回填。
self.metadata_store.ensure_paragraph_ngram_schema(conn=self._conn)
self.metadata_store.ensure_paragraph_ngram_backfilled(
n=self.config.char_ngram_n,
conn=self._conn,
)
rows = self.metadata_store.ngram_search_paragraphs(
tokens=uniq_tokens,
limit=limit,
max_doc_len=self.config.max_doc_len,
conn=self._conn,
)
if rows:
return rows
except Exception as e:
logger.warning(f"ngram 倒排回退失败,将按配置决定是否使用 LIKE 回退: {e}")
if not self.config.enable_like_fallback:
return []
conditions = " OR ".join(["p.content LIKE ?"] * len(uniq_tokens))
params: List[Any] = [f"%{tok}%" for tok in uniq_tokens]
scan_limit = max(int(limit) * 8, 200)
params.append(scan_limit)
sql = f"""
SELECT p.hash, p.content
FROM paragraphs p
WHERE (p.is_deleted IS NULL OR p.is_deleted = 0)
AND ({conditions})
LIMIT ?
"""
rows = self.metadata_store.query(sql, tuple(params))
if not rows:
return []
scored: List[Dict[str, Any]] = []
token_count = max(1, len(uniq_tokens))
for row in rows:
content = str(row.get("content") or "")
content_low = content.lower()
matched = [tok for tok in uniq_tokens if tok in content_low]
if not matched:
continue
coverage = len(matched) / token_count
length_bonus = sum(len(tok) for tok in matched) / max(1, len(content_low))
# 兜底路径使用相对分,保持与上层接口兼容。
fallback_score = coverage * 0.8 + length_bonus * 0.2
scored.append(
{
"hash": row["hash"],
"content": content[: self.config.max_doc_len] if self.config.max_doc_len > 0 else content,
"bm25_score": -float(fallback_score),
"fallback_score": float(fallback_score),
}
)
scored.sort(key=lambda x: x["fallback_score"], reverse=True)
return scored[:limit]
def search(self, query: str, k: int = 20) -> List[Dict[str, Any]]:
"""执行 BM25 检索。"""
if not self.config.enabled:
return []
if self.config.lazy_load and not self.loaded:
if not self.ensure_loaded():
return []
if not self.loaded:
return []
# 关系稀疏检索可独立开关,运行时开启后也能按需补齐 schema/回填。
self.metadata_store.ensure_relations_fts_schema(conn=self._conn)
self.metadata_store.ensure_relations_fts_backfilled(conn=self._conn)
tokens = self._tokenize(query)
match_query = self._build_match_query(tokens)
if not match_query:
return []
limit = max(1, int(k))
rows = self.metadata_store.fts_search_bm25(
match_query=match_query,
limit=limit,
max_doc_len=self.config.max_doc_len,
conn=self._conn,
)
if not rows:
rows = self._fallback_substring_search(tokens=tokens, limit=limit)
results: List[Dict[str, Any]] = []
for rank, row in enumerate(rows, start=1):
bm25_score = float(row.get("bm25_score", 0.0))
results.append(
{
"hash": row["hash"],
"content": row["content"],
"rank": rank,
"bm25_score": bm25_score,
"score": -bm25_score, # bm25 越小越相关,这里取反作为相对分数
}
)
return results
def search_relations(self, query: str, k: int = 20) -> List[Dict[str, Any]]:
"""执行关系稀疏检索FTS5 + BM25"""
if not self.config.enabled or not self.config.enable_relation_sparse_fallback:
return []
if self.config.lazy_load and not self.loaded:
if not self.ensure_loaded():
return []
if not self.loaded:
return []
tokens = self._tokenize(query)
match_query = self._build_match_query(tokens)
if not match_query:
return []
rows = self.metadata_store.fts_search_relations_bm25(
match_query=match_query,
limit=max(1, int(k)),
max_doc_len=self.config.relation_max_doc_len,
conn=self._conn,
)
out: List[Dict[str, Any]] = []
for rank, row in enumerate(rows, start=1):
bm25_score = float(row.get("bm25_score", 0.0))
out.append(
{
"hash": row["hash"],
"subject": row["subject"],
"predicate": row["predicate"],
"object": row["object"],
"content": row["content"],
"rank": rank,
"bm25_score": bm25_score,
"score": -bm25_score,
}
)
return out
def upsert_paragraph(self, paragraph_hash: str) -> bool:
if not self.loaded:
return False
return self.metadata_store.fts_upsert_paragraph(paragraph_hash, conn=self._conn)
def delete_paragraph(self, paragraph_hash: str) -> bool:
if not self.loaded:
return False
return self.metadata_store.fts_delete_paragraph(paragraph_hash, conn=self._conn)
def unload(self) -> None:
"""卸载 BM25 连接并尽量释放内存。"""
if self._conn is not None:
try:
if self.config.shrink_memory_on_unload:
self.metadata_store.shrink_memory(conn=self._conn)
except Exception:
pass
try:
self._conn.close()
except Exception:
pass
self._conn = None
self._loaded = False
logger.info("SparseBM25Index unloaded")
def stats(self) -> Dict[str, Any]:
doc_count = 0
if self.loaded:
doc_count = self.metadata_store.fts_doc_count(conn=self._conn)
return {
"enabled": self.config.enabled,
"backend": self.config.backend,
"mode": self.config.mode,
"tokenizer_mode": self.config.tokenizer_mode,
"enable_ngram_fallback_index": self.config.enable_ngram_fallback_index,
"enable_like_fallback": self.config.enable_like_fallback,
"enable_relation_sparse_fallback": self.config.enable_relation_sparse_fallback,
"loaded": self.loaded,
"has_jieba": HAS_JIEBA,
"doc_count": doc_count,
}

View File

@@ -0,0 +1,450 @@
"""
动态阈值过滤器
根据检索结果的分布特征自适应调整过滤阈值。
"""
import numpy as np
from typing import List, Dict, Any, Optional, Tuple, Union
from dataclasses import dataclass
from enum import Enum
from src.common.logger import get_logger
from .dual_path import RetrievalResult
logger = get_logger("A_Memorix.DynamicThresholdFilter")
class ThresholdMethod(Enum):
"""阈值计算方法"""
PERCENTILE = "percentile" # 百分位数
STD_DEV = "std_dev" # 标准差
GAP_DETECTION = "gap_detection" # 跳变检测
ADAPTIVE = "adaptive" # 自适应(综合多种方法)
@dataclass
class ThresholdConfig:
"""
阈值配置
属性:
method: 阈值计算方法
min_threshold: 最小阈值(绝对值)
max_threshold: 最大阈值(绝对值)
percentile: 百分位数用于percentile方法
std_multiplier: 标准差倍数用于std_dev方法
min_results: 最少保留结果数
enable_auto_adjust: 是否自动调整参数
"""
method: ThresholdMethod = ThresholdMethod.ADAPTIVE
min_threshold: float = 0.3
max_threshold: float = 0.95
percentile: float = 75.0 # 百分位数
std_multiplier: float = 1.5 # 标准差倍数
min_results: int = 3 # 最少保留结果数
enable_auto_adjust: bool = True
def __post_init__(self):
"""验证配置"""
if not 0 <= self.min_threshold <= 1:
raise ValueError(f"min_threshold必须在[0, 1]之间: {self.min_threshold}")
if not 0 <= self.max_threshold <= 1:
raise ValueError(f"max_threshold必须在[0, 1]之间: {self.max_threshold}")
if self.min_threshold >= self.max_threshold:
raise ValueError(f"min_threshold必须小于max_threshold")
if not 0 <= self.percentile <= 100:
raise ValueError(f"percentile必须在[0, 100]之间: {self.percentile}")
if self.std_multiplier <= 0:
raise ValueError(f"std_multiplier必须大于0: {self.std_multiplier}")
if self.min_results < 0:
raise ValueError(f"min_results必须大于等于0: {self.min_results}")
class DynamicThresholdFilter:
"""
动态阈值过滤器
功能:
- 基于结果分布自适应计算阈值
- 多种阈值计算方法
- 自动参数调整
- 统计信息收集
参数:
config: 阈值配置
"""
def __init__(
self,
config: Optional[ThresholdConfig] = None,
):
"""
初始化动态阈值过滤器
Args:
config: 阈值配置
"""
self.config = config or ThresholdConfig()
# 统计信息
self._total_filtered = 0
self._total_processed = 0
self._threshold_history: List[float] = []
logger.info(
f"DynamicThresholdFilter 初始化: "
f"method={self.config.method.value}, "
f"min_threshold={self.config.min_threshold}"
)
def filter(
self,
results: List[RetrievalResult],
return_threshold: bool = False,
) -> Union[List[RetrievalResult], Tuple[List[RetrievalResult], float]]:
"""
过滤检索结果
Args:
results: 检索结果列表
return_threshold: 是否返回使用的阈值
Returns:
过滤后的结果列表,或 (结果列表, 阈值) 元组
"""
if not results:
logger.debug("结果列表为空,无需过滤")
return ([], 0.0) if return_threshold else []
self._total_processed += len(results)
# 提取分数
scores = np.array([r.score for r in results])
# 计算阈值
threshold = self._compute_threshold(scores, results)
# 记录阈值
self._threshold_history.append(threshold)
# 应用阈值过滤
filtered_results = [
r for r in results
if r.score >= threshold
]
# 确保至少保留min_results个结果
if len(filtered_results) < self.config.min_results:
# 按分数排序取前min_results个
sorted_results = sorted(results, key=lambda x: x.score, reverse=True)
filtered_results = sorted_results[:self.config.min_results]
threshold = filtered_results[-1].score if filtered_results else 0.0
self._total_filtered += len(results) - len(filtered_results)
logger.info(
f"过滤完成: {len(results)} -> {len(filtered_results)} "
f"(threshold={threshold:.3f})"
)
if return_threshold:
return filtered_results, threshold
return filtered_results
def _compute_threshold(
self,
scores: np.ndarray,
results: List[RetrievalResult],
) -> float:
"""
计算阈值
Args:
scores: 分数数组
results: 检索结果列表
Returns:
阈值
"""
if self.config.method == ThresholdMethod.PERCENTILE:
threshold = self._percentile_threshold(scores)
elif self.config.method == ThresholdMethod.STD_DEV:
threshold = self._std_dev_threshold(scores)
elif self.config.method == ThresholdMethod.GAP_DETECTION:
threshold = self._gap_detection_threshold(scores)
else: # ADAPTIVE
# 自适应方法:综合多种方法
thresholds = [
self._percentile_threshold(scores),
self._std_dev_threshold(scores),
self._gap_detection_threshold(scores),
]
# 使用中位数作为最终阈值
threshold = float(np.median(thresholds))
# 限制在[min_threshold, max_threshold]范围内
threshold = np.clip(
threshold,
self.config.min_threshold,
self.config.max_threshold,
)
# 自动调整
if self.config.enable_auto_adjust:
threshold = self._auto_adjust_threshold(threshold, scores)
return float(threshold)
def _percentile_threshold(self, scores: np.ndarray) -> float:
"""
基于百分位数计算阈值
Args:
scores: 分数数组
Returns:
阈值
"""
percentile = self.config.percentile
threshold = float(np.percentile(scores, percentile))
logger.debug(f"百分位数阈值: {threshold:.3f} (percentile={percentile})")
return threshold
def _std_dev_threshold(self, scores: np.ndarray) -> float:
"""
基于标准差计算阈值
threshold = mean - std_multiplier * std
Args:
scores: 分数数组
Returns:
阈值
"""
mean = float(np.mean(scores))
std = float(np.std(scores))
multiplier = self.config.std_multiplier
threshold = mean - multiplier * std
logger.debug(f"标准差阈值: {threshold:.3f} (mean={mean:.3f}, std={std:.3f})")
return threshold
def _gap_detection_threshold(self, scores: np.ndarray) -> float:
"""
基于跳变检测计算阈值
找到分数分布中最大的"跳变"位置,以此为阈值
Args:
scores: 分数数组(降序排列)
Returns:
阈值
"""
# 降序排列
sorted_scores = np.sort(scores)[::-1]
if len(sorted_scores) < 2:
return float(sorted_scores[0]) if len(sorted_scores) > 0 else 0.0
# 计算相邻分数的差值
gaps = np.diff(sorted_scores)
# 找到最大的跳变位置
max_gap_idx = int(np.argmax(gaps))
# 阈值为跳变后的分数
threshold = float(sorted_scores[max_gap_idx + 1])
logger.debug(
f"跳变检测阈值: {threshold:.3f} "
f"(gap={gaps[max_gap_idx]:.3f}, idx={max_gap_idx})"
)
return threshold
def _auto_adjust_threshold(
self,
threshold: float,
scores: np.ndarray,
) -> float:
"""
自动调整阈值
基于历史阈值和当前分数分布调整
Args:
threshold: 当前阈值
scores: 分数数组
Returns:
调整后的阈值
"""
if not self._threshold_history:
return threshold
# 计算历史阈值的移动平均
recent_thresholds = self._threshold_history[-10:] # 最近10次
avg_threshold = float(np.mean(recent_thresholds))
# 当前阈值与历史平均的差异
diff = threshold - avg_threshold
# 如果差异过大(>0.2),向历史平均靠拢
if abs(diff) > 0.2:
adjusted_threshold = avg_threshold + diff * 0.5 # 向中间靠拢50%
logger.debug(
f"阈值调整: {threshold:.3f} -> {adjusted_threshold:.3f} "
f"(历史平均={avg_threshold:.3f})"
)
return adjusted_threshold
return threshold
def filter_by_confidence(
self,
results: List[RetrievalResult],
min_confidence: float = 0.5,
) -> List[RetrievalResult]:
"""
基于置信度过滤结果
Args:
results: 检索结果列表
min_confidence: 最小置信度
Returns:
过滤后的结果列表
"""
filtered = []
for result in results:
# 对于关系结果使用confidence字段
if result.result_type == "relation":
confidence = result.metadata.get("confidence", 1.0)
if confidence >= min_confidence:
filtered.append(result)
else:
# 对于段落结果,直接使用分数
if result.score >= min_confidence:
filtered.append(result)
logger.info(
f"置信度过滤: {len(results)} -> {len(filtered)} "
f"(min_confidence={min_confidence})"
)
return filtered
def filter_by_diversity(
self,
results: List[RetrievalResult],
similarity_threshold: float = 0.9,
top_k: int = 10,
) -> List[RetrievalResult]:
"""
基于多样性过滤结果(去除重复)
Args:
results: 检索结果列表
similarity_threshold: 相似度阈值(高于此值视为重复)
top_k: 最多保留结果数
Returns:
过滤后的结果列表
"""
if not results:
return []
# 按分数排序
sorted_results = sorted(results, key=lambda x: x.score, reverse=True)
# 贪心选择:选择与已选结果相似度低的结果
selected = []
selected_hashes = []
for result in sorted_results:
if len(selected) >= top_k:
break
# 检查与已选结果的相似度
is_duplicate = False
for selected_hash in selected_hashes:
# 简单判断基于hash的前缀
if result.hash_value[:8] == selected_hash[:8]:
is_duplicate = True
break
if not is_duplicate:
selected.append(result)
selected_hashes.append(result.hash_value)
logger.info(
f"多样性过滤: {len(results)} -> {len(selected)} "
f"(similarity_threshold={similarity_threshold})"
)
return selected
def get_statistics(self) -> Dict[str, Any]:
"""
获取统计信息
Returns:
统计信息字典
"""
filter_rate = (
self._total_filtered / self._total_processed
if self._total_processed > 0
else 0.0
)
stats = {
"config": {
"method": self.config.method.value,
"min_threshold": self.config.min_threshold,
"max_threshold": self.config.max_threshold,
"percentile": self.config.percentile,
"std_multiplier": self.config.std_multiplier,
"min_results": self.config.min_results,
"enable_auto_adjust": self.config.enable_auto_adjust,
},
"statistics": {
"total_processed": self._total_processed,
"total_filtered": self._total_filtered,
"filter_rate": filter_rate,
"avg_threshold": float(np.mean(self._threshold_history))
if self._threshold_history else 0.0,
"threshold_count": len(self._threshold_history),
},
}
if self._threshold_history:
stats["statistics"]["min_threshold_used"] = float(np.min(self._threshold_history))
stats["statistics"]["max_threshold_used"] = float(np.max(self._threshold_history))
return stats
def reset_statistics(self) -> None:
"""重置统计信息"""
self._total_filtered = 0
self._total_processed = 0
self._threshold_history.clear()
logger.info("统计信息已重置")
def __repr__(self) -> str:
return (
f"DynamicThresholdFilter("
f"method={self.config.method.value}, "
f"min_threshold={self.config.min_threshold}, "
f"filtered={self._total_filtered}/{self._total_processed})"
)