feat:新增 A_Memorix 记忆插件
引入 A_Memorix 插件(v2.0.0)——一个轻量级的长期记忆提供器。新增插件清单(manifest)和入口(AMemorixPlugin),并提供完整的核心能力:嵌入(基于哈希的 EmbeddingAPIAdapter、EmbeddingManager、预设)、检索(双路径检索器、PageRank、图关系召回、BM25 稀疏索引、阈值与融合配置)、存储与元数据层,以及大量实用工具和迁移/转换脚本。同时更新 .gitignore 以允许 /plugins/A_memorix。该变更为在宿主应用中实现统一的记忆摄取、检索、分析与维护奠定了基础。
This commit is contained in:
54
plugins/A_memorix/core/retrieval/__init__.py
Normal file
54
plugins/A_memorix/core/retrieval/__init__.py
Normal file
@@ -0,0 +1,54 @@
|
||||
"""检索模块 - 双路检索与排序"""
|
||||
|
||||
from .dual_path import (
|
||||
DualPathRetriever,
|
||||
RetrievalStrategy,
|
||||
RetrievalResult,
|
||||
DualPathRetrieverConfig,
|
||||
TemporalQueryOptions,
|
||||
FusionConfig,
|
||||
RelationIntentConfig,
|
||||
)
|
||||
from .pagerank import (
|
||||
PersonalizedPageRank,
|
||||
PageRankConfig,
|
||||
create_ppr_from_graph,
|
||||
)
|
||||
from .threshold import (
|
||||
DynamicThresholdFilter,
|
||||
ThresholdMethod,
|
||||
ThresholdConfig,
|
||||
)
|
||||
from .sparse_bm25 import (
|
||||
SparseBM25Index,
|
||||
SparseBM25Config,
|
||||
)
|
||||
from .graph_relation_recall import (
|
||||
GraphRelationRecallConfig,
|
||||
GraphRelationRecallService,
|
||||
)
|
||||
|
||||
# Public names re-exported by the retrieval package, grouped by source module.
__all__ = [
    # dual_path: dual-path retriever and its option/config types
    "DualPathRetriever",
    "RetrievalStrategy",
    "RetrievalResult",
    "DualPathRetrieverConfig",
    "TemporalQueryOptions",
    "FusionConfig",
    "RelationIntentConfig",
    # pagerank: Personalized PageRank
    "PersonalizedPageRank",
    "PageRankConfig",
    "create_ppr_from_graph",
    # threshold: dynamic threshold filtering
    "DynamicThresholdFilter",
    "ThresholdMethod",
    "ThresholdConfig",
    # sparse_bm25: sparse BM25 index
    "SparseBM25Index",
    "SparseBM25Config",
    # graph_relation_recall: graph relation recall
    "GraphRelationRecallConfig",
    "GraphRelationRecallService",
]
|
||||
1796
plugins/A_memorix/core/retrieval/dual_path.py
Normal file
1796
plugins/A_memorix/core/retrieval/dual_path.py
Normal file
File diff suppressed because it is too large
Load Diff
272
plugins/A_memorix/core/retrieval/graph_relation_recall.py
Normal file
272
plugins/A_memorix/core/retrieval/graph_relation_recall.py
Normal file
@@ -0,0 +1,272 @@
|
||||
"""Graph-assisted relation candidate recall for relation-oriented queries."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Sequence, Set
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("A_Memorix.GraphRelationRecall")
|
||||
|
||||
|
||||
@dataclass
class GraphRelationRecallConfig:
    """Configuration for controlled graph relation recall.

    Every field is coerced after construction: the two switches become real
    ``bool`` values, and the three counters become ``int`` values clamped to
    a minimum of 1.
    """

    enabled: bool = True
    candidate_k: int = 24
    max_hop: int = 1
    allow_two_hop_pair: bool = True
    max_paths: int = 4

    def __post_init__(self) -> None:
        """Coerce field types and clamp the counters to at least 1."""

        def _at_least_one(raw: Any) -> int:
            return max(1, int(raw))

        # Fields are normalized in declaration order.
        coercions = (
            ("enabled", bool),
            ("candidate_k", _at_least_one),
            ("max_hop", _at_least_one),
            ("allow_two_hop_pair", bool),
            ("max_paths", _at_least_one),
        )
        for field_name, coerce in coercions:
            setattr(self, field_name, coerce(getattr(self, field_name)))
|
||||
|
||||
|
||||
@dataclass
class GraphRelationCandidate:
    """A graph-derived relation candidate before retriever-side fusion."""

    hash_value: str
    subject: str
    predicate: str
    object: str
    confidence: float
    graph_seed_entities: List[str]
    graph_hops: int
    graph_candidate_type: str
    supporting_paragraph_count: int

    def to_payload(self) -> Dict[str, Any]:
        """Return the candidate as a plain dict.

        ``content`` is the space-joined ``subject predicate object`` triple,
        the hash is exposed under the key ``"hash"``, the seed entities are
        copied into a new list and the two counters are cast to ``int``.
        """
        triple_text = "{} {} {}".format(self.subject, self.predicate, self.object)
        payload: Dict[str, Any] = {"hash": self.hash_value, "content": triple_text}
        # These fields are copied through unchanged under their own names.
        for field_name in ("subject", "predicate", "object", "confidence"):
            payload[field_name] = getattr(self, field_name)
        payload["graph_seed_entities"] = list(self.graph_seed_entities)
        payload["graph_hops"] = int(self.graph_hops)
        payload["graph_candidate_type"] = self.graph_candidate_type
        payload["supporting_paragraph_count"] = int(self.supporting_paragraph_count)
        return payload
|
||||
|
||||
|
||||
class GraphRelationRecallService:
    """Collect relation candidates from the entity graph in a controlled way.

    Given up to two seed entities, the service looks up relation hashes on
    the graph store, resolves them through the metadata store and returns at
    most ``config.candidate_k`` ``GraphRelationCandidate`` objects.

    Args:
        graph_store: Object providing ``find_node``,
            ``get_relation_hashes_for_edge``, ``find_paths`` and
            ``get_incident_relation_hashes``.
        metadata_store: Object providing ``get_relation`` and
            ``get_paragraphs_by_relation``.
        config: Recall limits; a default ``GraphRelationRecallConfig`` is
            used when omitted.
    """

    # When direct-pair recall yields fewer candidates than this, the two-hop
    # expansion is attempted (if the configuration allows it).
    _MIN_DIRECT_PAIR_CANDIDATES = 3

    def __init__(
        self,
        *,
        graph_store: Any,
        metadata_store: Any,
        config: Optional[GraphRelationRecallConfig] = None,
    ) -> None:
        self.graph_store = graph_store
        self.metadata_store = metadata_store
        self.config = config or GraphRelationRecallConfig()

    def recall(
        self,
        *,
        seed_entities: Sequence[str],
    ) -> List[GraphRelationCandidate]:
        """Return relation candidates for the given seed entities.

        With two resolved seeds, relations on the direct edge between them
        are collected first, optionally followed by relations along two-hop
        paths. With a single resolved seed, its incident relations are
        collected.

        Args:
            seed_entities: Raw entity names; only the first two are examined.

        Returns:
            At most ``config.candidate_k`` candidates, or an empty list when
            recall is disabled, a store is missing, or no seed resolves.
        """
        if not self.config.enabled:
            return []
        if self.graph_store is None or self.metadata_store is None:
            return []

        seeds = self._normalize_seed_entities(seed_entities)
        if not seeds:
            return []

        seen_hashes: Set[str] = set()
        candidates: List[GraphRelationCandidate] = []

        if len(seeds) >= 2:
            self._collect_direct_pair_candidates(
                seed_a=seeds[0],
                seed_b=seeds[1],
                seen_hashes=seen_hashes,
                out=candidates,
            )
            if (
                len(candidates) < self._MIN_DIRECT_PAIR_CANDIDATES
                and self.config.allow_two_hop_pair
                and len(candidates) < self.config.candidate_k
            ):
                self._collect_two_hop_pair_candidates(
                    seed_a=seeds[0],
                    seed_b=seeds[1],
                    seen_hashes=seen_hashes,
                    out=candidates,
                )
        else:
            self._collect_one_hop_seed_candidates(
                seed=seeds[0],
                seen_hashes=seen_hashes,
                out=candidates,
            )

        return candidates[: self.config.candidate_k]

    def _normalize_seed_entities(self, seed_entities: Sequence[str]) -> List[str]:
        """Resolve the first two raw seeds to graph node names.

        Seeds that the graph store cannot resolve (or whose lookup raises)
        are dropped, and seeds resolving to the same node, compared
        case-insensitively, are kept once.
        """
        out: List[str] = []
        seen = set()
        for raw in list(seed_entities)[:2]:
            try:
                resolved = self.graph_store.find_node(str(raw), ignore_case=True)
            except Exception as e:
                # Best-effort: an unresolvable seed must not abort recall.
                logger.debug("graph seed resolution skipped for %r: %s", raw, e)
                resolved = None
            if not resolved:
                continue
            canon = str(resolved).strip().lower()
            if not canon or canon in seen:
                continue
            seen.add(canon)
            out.append(str(resolved))
        return out

    def _edge_relation_hashes(self, u: str, v: str) -> List[str]:
        """Return relation hashes stored on the edge between u and v, both directions."""
        hashes: List[str] = []
        hashes.extend(self.graph_store.get_relation_hashes_for_edge(u, v))
        hashes.extend(self.graph_store.get_relation_hashes_for_edge(v, u))
        return hashes

    def _collect_direct_pair_candidates(
        self,
        *,
        seed_a: str,
        seed_b: str,
        seen_hashes: Set[str],
        out: List[GraphRelationCandidate],
    ) -> None:
        """Append candidates for relations directly linking the two seeds."""
        self._append_relation_hashes(
            relation_hashes=self._edge_relation_hashes(seed_a, seed_b),
            seen_hashes=seen_hashes,
            out=out,
            candidate_type="direct_pair",
            graph_hops=1,
            graph_seed_entities=[seed_a, seed_b],
        )

    def _collect_two_hop_pair_candidates(
        self,
        *,
        seed_a: str,
        seed_b: str,
        seen_hashes: Set[str],
        out: List[GraphRelationCandidate],
    ) -> None:
        """Append candidates found along two-hop paths between the seeds.

        Only paths of exactly three nodes are used; for each such path the
        relations on both of its edges are appended. A failing path lookup
        is logged at debug level and skipped.
        """
        try:
            paths = self.graph_store.find_paths(
                seed_a,
                seed_b,
                max_depth=2,
                max_paths=self.config.max_paths,
            )
        except Exception as e:
            logger.debug("graph two-hop recall skipped: %s", e)
            return

        for path_nodes in paths or ():
            if len(out) >= self.config.candidate_k:
                break
            # A two-hop path is seed -> intermediate -> seed: exactly 3 nodes.
            if not isinstance(path_nodes, Sequence) or len(path_nodes) != 3:
                continue
            for idx in range(len(path_nodes) - 1):
                if len(out) >= self.config.candidate_k:
                    break
                u = str(path_nodes[idx])
                v = str(path_nodes[idx + 1])
                self._append_relation_hashes(
                    relation_hashes=self._edge_relation_hashes(u, v),
                    seen_hashes=seen_hashes,
                    out=out,
                    candidate_type="two_hop_pair",
                    graph_hops=2,
                    graph_seed_entities=[seed_a, seed_b],
                )

    def _collect_one_hop_seed_candidates(
        self,
        *,
        seed: str,
        seen_hashes: Set[str],
        out: List[GraphRelationCandidate],
    ) -> None:
        """Append candidates for relations incident to a single seed.

        A failing lookup is logged at debug level and skipped.
        """
        try:
            relation_hashes = self.graph_store.get_incident_relation_hashes(
                seed,
                limit=self.config.candidate_k,
            )
        except Exception as e:
            logger.debug("graph one-hop recall skipped: %s", e)
            return
        self._append_relation_hashes(
            relation_hashes=relation_hashes,
            seen_hashes=seen_hashes,
            out=out,
            candidate_type="one_hop_seed",
            graph_hops=min(1, self.config.max_hop),
            graph_seed_entities=[seed],
        )

    def _append_relation_hashes(
        self,
        *,
        relation_hashes: Sequence[str],
        seen_hashes: Set[str],
        out: List[GraphRelationCandidate],
        candidate_type: str,
        graph_hops: int,
        graph_seed_entities: Sequence[str],
    ) -> None:
        """Resolve hashes into candidates and append the new ones to ``out``.

        Hashes are stringified, blank ones dropped, and the rest processed
        in sorted order so the result is deterministic. Appending stops once
        ``out`` holds ``config.candidate_k`` items; hashes already in
        ``seen_hashes`` or unknown to the metadata store are skipped.
        """
        for relation_hash in sorted({str(h) for h in relation_hashes if str(h).strip()}):
            if len(out) >= self.config.candidate_k:
                break
            if relation_hash in seen_hashes:
                continue
            candidate = self._build_candidate(
                relation_hash=relation_hash,
                candidate_type=candidate_type,
                graph_hops=graph_hops,
                graph_seed_entities=graph_seed_entities,
            )
            if candidate is None:
                continue
            seen_hashes.add(relation_hash)
            out.append(candidate)

    @staticmethod
    def _coerce_confidence(raw: Any) -> float:
        """Convert a stored confidence to float, defaulting only when absent.

        ``None`` and the empty string count as "no value" and map to 1.0.
        A real ``0.0`` is preserved rather than being promoted to 1.0.
        """
        if raw is None or raw == "":
            return 1.0
        return float(raw)

    def _build_candidate(
        self,
        *,
        relation_hash: str,
        candidate_type: str,
        graph_hops: int,
        graph_seed_entities: Sequence[str],
    ) -> Optional[GraphRelationCandidate]:
        """Build a candidate from metadata, or return ``None`` if the hash is unknown."""
        relation = self.metadata_store.get_relation(relation_hash)
        if relation is None:
            return None
        supporting_paragraphs = self.metadata_store.get_paragraphs_by_relation(relation_hash)
        return GraphRelationCandidate(
            hash_value=relation_hash,
            subject=str(relation.get("subject", "")),
            predicate=str(relation.get("predicate", "")),
            object=str(relation.get("object", "")),
            confidence=self._coerce_confidence(relation.get("confidence")),
            graph_seed_entities=[str(x) for x in graph_seed_entities],
            graph_hops=int(graph_hops),
            graph_candidate_type=str(candidate_type),
            # Tolerate a store that returns None instead of an empty list.
            supporting_paragraph_count=len(supporting_paragraphs or ()),
        )
|
||||
482
plugins/A_memorix/core/retrieval/pagerank.py
Normal file
482
plugins/A_memorix/core/retrieval/pagerank.py
Normal file
@@ -0,0 +1,482 @@
|
||||
"""
|
||||
Personalized PageRank实现
|
||||
|
||||
提供个性化的图节点排序功能。
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional, Tuple, Union, Any
|
||||
from dataclasses import dataclass
|
||||
import numpy as np
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from ..storage import GraphStore
|
||||
from ..utils.matcher import AhoCorasick
|
||||
|
||||
logger = get_logger("A_Memorix.PersonalizedPageRank")
|
||||
|
||||
|
||||
@dataclass
class PageRankConfig:
    """Settings for a PageRank run.

    Attributes:
        alpha: Damping factor; must satisfy ``0 <= alpha < 1``.
        max_iter: Maximum number of iterations; must be greater than 0.
        tol: Convergence tolerance; must be greater than 0.
        normalize: Whether results are normalized.
        min_iterations: Minimum number of iterations; must be non-negative
            and strictly smaller than ``max_iter``.
    """

    alpha: float = 0.85
    max_iter: int = 100
    tol: float = 1e-6
    normalize: bool = True
    min_iterations: int = 20

    def __post_init__(self):
        """Validate the settings, raising ``ValueError`` for the first violation."""

        def _first_problem():
            # Checks run in a fixed order; the first failing one wins.
            if not 0 <= self.alpha < 1:
                return f"alpha必须在[0, 1)之间: {self.alpha}"
            if self.max_iter <= 0:
                return f"max_iter必须大于0: {self.max_iter}"
            if self.tol <= 0:
                return f"tol必须大于0: {self.tol}"
            if self.min_iterations < 0:
                return f"min_iterations必须大于等于0: {self.min_iterations}"
            if self.min_iterations >= self.max_iter:
                return "min_iterations必须小于max_iter"
            return None

        problem = _first_problem()
        if problem is not None:
            raise ValueError(problem)
|
||||
|
||||
|
||||
class PersonalizedPageRank:
    """Personalized PageRank front end over a graph store.

    The PageRank iteration itself is delegated to
    ``graph_store.compute_pagerank``. This class adds parameter defaults
    taken from ``PageRankConfig``, optional result normalization, helpers
    for batches / entity lists / free-text queries, ranking and comparison
    utilities, and simple usage counters.

    Args:
        graph_store: Graph storage backend.
        config: PageRank settings; a default ``PageRankConfig`` is used
            when omitted.
    """

    def __init__(
        self,
        graph_store: GraphStore,
        config: Optional[PageRankConfig] = None,
    ):
        """Store the graph backend and settings and reset all counters."""
        self.graph_store = graph_store
        self.config = config or PageRankConfig()

        # Usage counters. NOTE(review): nothing in this class ever updates
        # _total_iterations or _convergence_history, so the iteration
        # statistics stay at their initial values.
        self._total_computations = 0
        self._total_iterations = 0
        self._convergence_history: List[int] = []

        logger.info(
            f"PersonalizedPageRank 初始化: "
            f"alpha={self.config.alpha}, "
            f"max_iter={self.config.max_iter}"
        )

        # Cached Aho-Corasick matcher plus the lower-cased pattern set it was
        # built from; the set is what decides whether a rebuild is needed.
        self._ac_matcher: Optional[AhoCorasick] = None
        self._ac_nodes_count = 0
        self._ac_patterns: frozenset = frozenset()

    def compute(
        self,
        personalization: Optional[Dict[str, float]] = None,
        alpha: Optional[float] = None,
        max_iter: Optional[int] = None,
        normalize: Optional[bool] = None,
    ) -> Dict[str, float]:
        """Compute Personalized PageRank scores.

        Args:
            personalization: Personalization vector ``{node: weight}``.
            alpha: Damping factor; overrides the configured value.
            max_iter: Maximum iterations; overrides the configured value.
            normalize: Whether to normalize; overrides the configured value.

        Returns:
            ``{node: score}``. When normalization is on and the scores have a
            positive sum, they are rescaled to sum to 1.
        """
        alpha = alpha if alpha is not None else self.config.alpha
        max_iter = max_iter if max_iter is not None else self.config.max_iter
        normalize = normalize if normalize is not None else self.config.normalize

        scores = self.graph_store.compute_pagerank(
            personalization=personalization,
            alpha=alpha,
            max_iter=max_iter,
            tol=self.config.tol,
        )

        if normalize and scores:
            total = sum(scores.values())
            if total > 0:
                scores = {node: score / total for node, score in scores.items()}

        self._total_computations += 1

        logger.debug(
            f"PPR计算完成: {len(scores)} 个节点, "
            f"personalization_nodes={len(personalization) if personalization else 0}"
        )

        return scores

    def compute_batch(
        self,
        personalization_list: List[Dict[str, float]],
        normalize: bool = True,
    ) -> List[Dict[str, float]]:
        """Run ``compute`` once per personalization vector.

        Args:
            personalization_list: Personalization vectors, one per run.
            normalize: Whether each result is normalized.

        Returns:
            Score dictionaries in the same order as the input.
        """
        results = []
        for i, personalization in enumerate(personalization_list):
            logger.debug(f"计算第 {i+1}/{len(personalization_list)} 个PPR")
            results.append(self.compute(personalization=personalization, normalize=normalize))
        return results

    def compute_for_entities(
        self,
        entities: List[str],
        weights: Optional[List[float]] = None,
        normalize: bool = True,
    ) -> Dict[str, float]:
        """Compute PPR personalized on a list of entities.

        Args:
            entities: Entity (node) names.
            weights: One weight per entity; defaults to 1.0 for each.
            normalize: Whether the result is normalized.

        Returns:
            Score dictionary. An empty entity list falls back to an
            unpersonalized computation.

        Raises:
            ValueError: If ``weights`` and ``entities`` differ in length.
        """
        if not entities:
            logger.warning("实体列表为空,返回均匀PPR")
            return self.compute(personalization=None, normalize=normalize)

        if weights is None:
            weights = [1.0] * len(entities)

        if len(weights) != len(entities):
            raise ValueError(f"权重数量与实体数量不匹配: {len(weights)} vs {len(entities)}")

        personalization = dict(zip(entities, weights))
        return self.compute(personalization=personalization, normalize=normalize)

    def compute_for_query(
        self,
        query: str,
        entity_extractor: Optional[callable] = None,
        normalize: bool = True,
    ) -> Dict[str, float]:
        """Compute PPR personalized on the entities found in a query.

        Args:
            query: Query text.
            entity_extractor: Optional callable taking the query and
                returning entity names; when omitted, graph node names
                occurring in the query are used.
            normalize: Whether the result is normalized.

        Returns:
            Score dictionary. When no entity is found the computation is
            unpersonalized.
        """
        if entity_extractor is not None:
            entities = entity_extractor(query)
        else:
            entities = self._extract_entities_from_query(query)

        if not entities:
            logger.debug(f"未从查询中提取到实体: '{query}'")
            return self.compute(personalization=None, normalize=normalize)

        return self.compute_for_entities(entities, normalize=normalize)

    def rank_nodes(
        self,
        scores: Dict[str, float],
        top_k: Optional[int] = None,
        min_score: float = 0.0,
    ) -> List[Tuple[str, float]]:
        """Sort nodes by descending score.

        Args:
            scores: ``{node: score}``.
            top_k: Keep only the first ``top_k`` nodes; ``None`` keeps all.
            min_score: Nodes scoring below this value are dropped.

        Returns:
            ``[(node, score), ...]`` in descending score order.
        """
        filtered = [(node, score) for node, score in scores.items() if score >= min_score]
        sorted_nodes = sorted(filtered, key=lambda x: x[1], reverse=True)
        if top_k is not None:
            sorted_nodes = sorted_nodes[:top_k]
        return sorted_nodes

    def get_personalization_vector(
        self,
        nodes: List[str],
        method: str = "uniform",
    ) -> Dict[str, float]:
        """Build a personalization vector over ``nodes``.

        Args:
            nodes: Node names.
            method: ``"uniform"`` (equal weights), ``"degree"`` (weight
                proportional to neighbor count) or ``"inverse_degree"``
                (weight proportional to ``1 / (neighbor count + 1)``).

        Returns:
            ``{node: weight}``; empty when ``nodes`` is empty. With
            ``"degree"`` and no neighbors at all, uniform weights are used.

        Raises:
            ValueError: If ``method`` is not one of the supported names.
        """
        if not nodes:
            return {}

        if method == "uniform":
            weight = 1.0 / len(nodes)
            return {node: weight for node in nodes}

        if method not in ("degree", "inverse_degree"):
            raise ValueError(f"不支持的个性化向量生成方法: {method}")

        # Both degree-based methods start from the neighbor count per node.
        node_degrees = {node: len(self.graph_store.get_neighbors(node)) for node in nodes}
        if method == "degree":
            raw_weights = node_degrees
        else:
            # +1 keeps isolated nodes (degree 0) from dividing by zero.
            raw_weights = {node: 1.0 / (degree + 1) for node, degree in node_degrees.items()}

        total = sum(raw_weights.values())
        if total > 0:
            return {node: value / total for node, value in raw_weights.items()}
        return {node: 1.0 / len(nodes) for node in nodes}

    def compare_scores(
        self,
        scores1: Dict[str, float],
        scores2: Dict[str, float],
    ) -> Dict[str, Dict[str, Any]]:
        """Split two score dictionaries into shared and exclusive nodes.

        Args:
            scores1: First score dictionary.
            scores2: Second score dictionary.

        Returns:
            A dict with three entries: ``"common_nodes"`` maps each shared
            node to ``(score1, score2)``, ``"only_in_1"`` maps nodes present
            only in ``scores1`` to their score, and ``"only_in_2"`` does the
            same for ``scores2``.
        """
        common_nodes = {}
        only_in_1 = {}
        only_in_2 = {}

        for node in set(scores1.keys()) | set(scores2.keys()):
            if node in scores1 and node in scores2:
                common_nodes[node] = (scores1[node], scores2[node])
            elif node in scores1:
                only_in_1[node] = scores1[node]
            else:
                only_in_2[node] = scores2[node]

        return {
            "common_nodes": common_nodes,
            "only_in_1": only_in_1,
            "only_in_2": only_in_2,
        }

    def get_statistics(self) -> Dict[str, Any]:
        """Return the configuration, usage counters and graph size as a dict."""
        avg_iterations = (
            self._total_iterations / self._total_computations
            if self._total_computations > 0
            else 0
        )

        return {
            "config": {
                "alpha": self.config.alpha,
                "max_iter": self.config.max_iter,
                "tol": self.config.tol,
                "normalize": self.config.normalize,
                "min_iterations": self.config.min_iterations,
            },
            "statistics": {
                "total_computations": self._total_computations,
                "total_iterations": self._total_iterations,
                "avg_iterations": avg_iterations,
                "convergence_history": self._convergence_history.copy(),
            },
            "graph": {
                "num_nodes": self.graph_store.num_nodes,
                "num_edges": self.graph_store.num_edges,
            },
        }

    def reset_statistics(self) -> None:
        """Reset all usage counters."""
        self._total_computations = 0
        self._total_iterations = 0
        self._convergence_history.clear()
        logger.info("统计信息已重置")

    def _extract_entities_from_query(self, query: str) -> List[str]:
        """Return the graph node names that occur in ``query``.

        Matching is case-insensitive: node names and the query are
        lower-cased, matched with an Aho-Corasick automaton, and hits are
        mapped back to the original node names.

        Args:
            query: Query text.

        Returns:
            Matched node names (original casing); empty when the graph has
            no nodes.
        """
        all_nodes = self.graph_store.get_nodes()
        if not all_nodes:
            return []

        # Lower-cased name -> original name (the last one wins on collisions).
        node_map = {node.lower(): node for node in all_nodes}
        patterns = frozenset(node_map)

        # Rebuild when the pattern set changed. Comparing only the node
        # count would keep a stale automaton after nodes were replaced.
        if self._ac_matcher is None or patterns != getattr(self, "_ac_patterns", None):
            matcher = AhoCorasick()
            for pattern in node_map:
                matcher.add_pattern(pattern)
            matcher.build()
            self._ac_matcher = matcher
            self._ac_patterns = patterns
            self._ac_nodes_count = len(all_nodes)

        stats = self._ac_matcher.find_all(query.lower())

        # Skip hits without a current node instead of raising KeyError.
        return [node_map[low_name] for low_name in stats.keys() if low_name in node_map]

    @property
    def num_computations(self) -> int:
        """Number of ``compute`` calls performed so far."""
        return self._total_computations

    @property
    def avg_iterations(self) -> float:
        """Average recorded iterations per computation (0.0 before any run)."""
        if self._total_computations == 0:
            return 0.0
        return self._total_iterations / self._total_computations

    def __repr__(self) -> str:
        return (
            f"PersonalizedPageRank("
            f"alpha={self.config.alpha}, "
            f"computations={self._total_computations})"
        )
|
||||
|
||||
|
||||
def create_ppr_from_graph(
    graph_store: GraphStore,
    alpha: float = 0.85,
    max_iter: int = 100,
) -> PersonalizedPageRank:
    """Create a ``PersonalizedPageRank`` calculator for a graph store.

    Args:
        graph_store: Graph storage backend.
        alpha: Damping factor.
        max_iter: Maximum number of iterations.

    Returns:
        A ``PersonalizedPageRank`` bound to ``graph_store``.

    Raises:
        ValueError: If ``alpha`` or ``max_iter`` fail ``PageRankConfig``
            validation.
    """
    # PageRankConfig requires min_iterations < max_iter. Its default
    # min_iterations would reject any max_iter at or below that default,
    # even though the caller never set it, so cap it below max_iter here.
    default_min_iterations = PageRankConfig.min_iterations
    config = PageRankConfig(
        alpha=alpha,
        max_iter=max_iter,
        min_iterations=max(0, min(default_min_iterations, max_iter - 1)),
    )

    return PersonalizedPageRank(
        graph_store=graph_store,
        config=config,
    )
|
||||
402
plugins/A_memorix/core/retrieval/sparse_bm25.py
Normal file
402
plugins/A_memorix/core/retrieval/sparse_bm25.py
Normal file
@@ -0,0 +1,402 @@
|
||||
"""
|
||||
稀疏检索组件(FTS5 + BM25)
|
||||
|
||||
支持:
|
||||
- 懒加载索引连接
|
||||
- jieba / char n-gram 分词
|
||||
- 可卸载并收缩 SQLite 内存缓存
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import sqlite3
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from ..storage import MetadataStore
|
||||
|
||||
logger = get_logger("A_Memorix.SparseBM25")
|
||||
|
||||
try:
|
||||
import jieba # type: ignore
|
||||
|
||||
HAS_JIEBA = True
|
||||
except Exception:
|
||||
HAS_JIEBA = False
|
||||
jieba = None
|
||||
|
||||
|
||||
@dataclass
class SparseBM25Config:
    """Settings for BM25 sparse retrieval.

    After construction the three string options are stripped and
    lower-cased (falling back to their defaults when empty), the integer
    options are cast and clamped to their minimum, and the string options
    are validated against the supported values.
    """

    enabled: bool = True
    backend: str = "fts5"
    lazy_load: bool = True
    mode: str = "auto"  # auto | fallback_only | hybrid
    tokenizer_mode: str = "jieba"  # jieba | mixed | char_2gram
    jieba_user_dict: str = ""
    char_ngram_n: int = 2
    candidate_k: int = 80
    max_doc_len: int = 2000
    enable_ngram_fallback_index: bool = True
    enable_like_fallback: bool = False
    enable_relation_sparse_fallback: bool = True
    relation_candidate_k: int = 60
    relation_max_doc_len: int = 512
    unload_on_disable: bool = True
    shrink_memory_on_unload: bool = True

    def __post_init__(self) -> None:
        """Normalize the options, then raise ``ValueError`` on unsupported values."""
        # String options: (field, fallback used when the value is empty).
        text_options = (
            ("backend", "fts5"),
            ("mode", "auto"),
            ("tokenizer_mode", "jieba"),
        )
        for field_name, fallback in text_options:
            cleaned = str(getattr(self, field_name) or fallback).strip().lower()
            setattr(self, field_name, cleaned)

        # Integer options: (field, smallest allowed value).
        int_options = (
            ("char_ngram_n", 1),
            ("candidate_k", 1),
            ("max_doc_len", 0),
            ("relation_candidate_k", 1),
            ("relation_max_doc_len", 0),
        )
        for field_name, floor in int_options:
            setattr(self, field_name, max(floor, int(getattr(self, field_name))))

        if self.backend != "fts5":
            raise ValueError(f"sparse.backend 暂仅支持 fts5: {self.backend}")
        if self.mode not in ("auto", "fallback_only", "hybrid"):
            raise ValueError(f"sparse.mode 非法: {self.mode}")
        if self.tokenizer_mode not in ("jieba", "mixed", "char_2gram"):
            raise ValueError(f"sparse.tokenizer_mode 非法: {self.tokenizer_mode}")
|
||||
|
||||
|
||||
class SparseBM25Index:
|
||||
"""
|
||||
基于 SQLite FTS5 的 BM25 检索适配层。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
metadata_store: MetadataStore,
|
||||
config: Optional[SparseBM25Config] = None,
|
||||
):
|
||||
self.metadata_store = metadata_store
|
||||
self.config = config or SparseBM25Config()
|
||||
self._conn: Optional[sqlite3.Connection] = None
|
||||
self._loaded: bool = False
|
||||
self._jieba_dict_loaded: bool = False
|
||||
|
||||
@property
|
||||
def loaded(self) -> bool:
|
||||
return self._loaded and self._conn is not None
|
||||
|
||||
def ensure_loaded(self) -> bool:
|
||||
"""按需加载 FTS 连接与索引。"""
|
||||
if not self.config.enabled:
|
||||
return False
|
||||
if self.loaded:
|
||||
return True
|
||||
|
||||
db_path = self.metadata_store.get_db_path()
|
||||
conn = sqlite3.connect(
|
||||
str(db_path),
|
||||
check_same_thread=False,
|
||||
timeout=30.0,
|
||||
)
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
conn.execute("PRAGMA synchronous=NORMAL")
|
||||
conn.execute("PRAGMA temp_store=MEMORY")
|
||||
|
||||
if not self.metadata_store.ensure_fts_schema(conn=conn):
|
||||
conn.close()
|
||||
return False
|
||||
self.metadata_store.ensure_fts_backfilled(conn=conn)
|
||||
# 关系稀疏检索按独立开关加载,避免不必要的初始化开销。
|
||||
if self.config.enable_relation_sparse_fallback:
|
||||
self.metadata_store.ensure_relations_fts_schema(conn=conn)
|
||||
self.metadata_store.ensure_relations_fts_backfilled(conn=conn)
|
||||
if self.config.enable_ngram_fallback_index:
|
||||
self.metadata_store.ensure_paragraph_ngram_schema(conn=conn)
|
||||
self.metadata_store.ensure_paragraph_ngram_backfilled(
|
||||
n=self.config.char_ngram_n,
|
||||
conn=conn,
|
||||
)
|
||||
|
||||
self._conn = conn
|
||||
self._loaded = True
|
||||
self._prepare_tokenizer()
|
||||
logger.info(
|
||||
"SparseBM25Index loaded: backend=fts5, tokenizer=%s, mode=%s",
|
||||
self.config.tokenizer_mode,
|
||||
self.config.mode,
|
||||
)
|
||||
return True
|
||||
|
||||
def _prepare_tokenizer(self) -> None:
|
||||
if self._jieba_dict_loaded:
|
||||
return
|
||||
if self.config.tokenizer_mode not in {"jieba", "mixed"}:
|
||||
return
|
||||
if not HAS_JIEBA:
|
||||
logger.warning("jieba 不可用,tokenizer 将退化为 char n-gram")
|
||||
return
|
||||
user_dict = str(self.config.jieba_user_dict or "").strip()
|
||||
if user_dict:
|
||||
try:
|
||||
jieba.load_userdict(user_dict) # type: ignore[union-attr]
|
||||
logger.info("已加载 jieba 用户词典: %s", user_dict)
|
||||
except Exception as e:
|
||||
logger.warning("加载 jieba 用户词典失败: %s", e)
|
||||
self._jieba_dict_loaded = True
|
||||
|
||||
def _tokenize_jieba(self, text: str) -> List[str]:
    """Tokenize ``text`` with jieba's search-mode segmentation.

    Returns:
        Stripped, lower-cased, non-blank tokens; an empty list when jieba
        is unavailable or segmentation raises.
    """
    if HAS_JIEBA:
        try:
            pieces = list(jieba.cut_for_search(text))  # type: ignore[union-attr]
            return [piece.strip().lower() for piece in pieces if piece and piece.strip()]
        except Exception:
            # Best-effort: callers treat an empty result as "no jieba tokens".
            return []
    return []
|
||||
|
||||
def _tokenize_char_ngram(self, text: str, n: int) -> List[str]:
|
||||
compact = re.sub(r"\s+", "", text.lower())
|
||||
if not compact:
|
||||
return []
|
||||
if len(compact) < n:
|
||||
return [compact]
|
||||
return [compact[i : i + n] for i in range(0, len(compact) - n + 1)]
|
||||
|
||||
def _tokenize(self, text: str) -> List[str]:
|
||||
text = str(text or "").strip()
|
||||
if not text:
|
||||
return []
|
||||
|
||||
mode = self.config.tokenizer_mode
|
||||
if mode == "jieba":
|
||||
tokens = self._tokenize_jieba(text)
|
||||
if tokens:
|
||||
return list(dict.fromkeys(tokens))
|
||||
return self._tokenize_char_ngram(text, self.config.char_ngram_n)
|
||||
|
||||
if mode == "mixed":
|
||||
toks = self._tokenize_jieba(text)
|
||||
toks.extend(self._tokenize_char_ngram(text, self.config.char_ngram_n))
|
||||
return list(dict.fromkeys([t for t in toks if t]))
|
||||
|
||||
return list(dict.fromkeys(self._tokenize_char_ngram(text, self.config.char_ngram_n)))
|
||||
|
||||
def _build_match_query(self, tokens: List[str]) -> str:
|
||||
safe_tokens: List[str] = []
|
||||
for token in tokens:
|
||||
t = token.replace('"', '""').strip()
|
||||
if not t:
|
||||
continue
|
||||
safe_tokens.append(f'"{t}"')
|
||||
if not safe_tokens:
|
||||
return ""
|
||||
# 采用 OR 提升召回,再交由 RRF 和阈值做稳健排序。
|
||||
return " OR ".join(safe_tokens[:64])
|
||||
|
||||
def _fallback_substring_search(
|
||||
self,
|
||||
tokens: List[str],
|
||||
limit: int,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
当 FTS5 因分词不一致召回为空时,退化为子串匹配召回。
|
||||
|
||||
说明:
|
||||
- FTS 索引当前采用 unicode61 tokenizer。
|
||||
- 若查询 token 来源为 char n-gram 或中文词元,可能与索引 token 不一致。
|
||||
- 这里使用 SQL LIKE 做兜底,按命中 token 覆盖度打分。
|
||||
"""
|
||||
if not tokens:
|
||||
return []
|
||||
|
||||
# 去重并裁剪 token 数量,避免生成超长 SQL。
|
||||
uniq_tokens = [t for t in dict.fromkeys(tokens) if t]
|
||||
uniq_tokens = uniq_tokens[:32]
|
||||
if not uniq_tokens:
|
||||
return []
|
||||
|
||||
if self.config.enable_ngram_fallback_index:
|
||||
try:
|
||||
# 允许运行时切换开关后按需补齐 schema/回填。
|
||||
self.metadata_store.ensure_paragraph_ngram_schema(conn=self._conn)
|
||||
self.metadata_store.ensure_paragraph_ngram_backfilled(
|
||||
n=self.config.char_ngram_n,
|
||||
conn=self._conn,
|
||||
)
|
||||
rows = self.metadata_store.ngram_search_paragraphs(
|
||||
tokens=uniq_tokens,
|
||||
limit=limit,
|
||||
max_doc_len=self.config.max_doc_len,
|
||||
conn=self._conn,
|
||||
)
|
||||
if rows:
|
||||
return rows
|
||||
except Exception as e:
|
||||
logger.warning(f"ngram 倒排回退失败,将按配置决定是否使用 LIKE 回退: {e}")
|
||||
|
||||
if not self.config.enable_like_fallback:
|
||||
return []
|
||||
|
||||
conditions = " OR ".join(["p.content LIKE ?"] * len(uniq_tokens))
|
||||
params: List[Any] = [f"%{tok}%" for tok in uniq_tokens]
|
||||
scan_limit = max(int(limit) * 8, 200)
|
||||
params.append(scan_limit)
|
||||
|
||||
sql = f"""
|
||||
SELECT p.hash, p.content
|
||||
FROM paragraphs p
|
||||
WHERE (p.is_deleted IS NULL OR p.is_deleted = 0)
|
||||
AND ({conditions})
|
||||
LIMIT ?
|
||||
"""
|
||||
rows = self.metadata_store.query(sql, tuple(params))
|
||||
if not rows:
|
||||
return []
|
||||
|
||||
scored: List[Dict[str, Any]] = []
|
||||
token_count = max(1, len(uniq_tokens))
|
||||
for row in rows:
|
||||
content = str(row.get("content") or "")
|
||||
content_low = content.lower()
|
||||
matched = [tok for tok in uniq_tokens if tok in content_low]
|
||||
if not matched:
|
||||
continue
|
||||
coverage = len(matched) / token_count
|
||||
length_bonus = sum(len(tok) for tok in matched) / max(1, len(content_low))
|
||||
# 兜底路径使用相对分,保持与上层接口兼容。
|
||||
fallback_score = coverage * 0.8 + length_bonus * 0.2
|
||||
scored.append(
|
||||
{
|
||||
"hash": row["hash"],
|
||||
"content": content[: self.config.max_doc_len] if self.config.max_doc_len > 0 else content,
|
||||
"bm25_score": -float(fallback_score),
|
||||
"fallback_score": float(fallback_score),
|
||||
}
|
||||
)
|
||||
|
||||
scored.sort(key=lambda x: x["fallback_score"], reverse=True)
|
||||
return scored[:limit]
|
||||
|
||||
def search(self, query: str, k: int = 20) -> List[Dict[str, Any]]:
    """Run a BM25 paragraph search and return ranked hits.

    Args:
        query: Raw query text; it is tokenized with ``self._tokenize``.
        k: Number of hits requested; values below 1 are raised to 1.

    Returns:
        A list of dicts with the keys ``hash``, ``content``, ``rank``
        (1-based position), ``bm25_score`` and ``score`` (the negated
        ``bm25_score``). The list is empty when the index is disabled,
        cannot be loaded, or the query produces no match expression.
    """
    if not self.config.enabled:
        return []
    if self.config.lazy_load and not self.loaded:
        if not self.ensure_loaded():
            return []
    if not self.loaded:
        return []

    # Relation sparse retrieval has its own switch. Only maintain its FTS
    # schema/backfill when that switch is on, so a paragraph search neither
    # pays for nor fails on a feature that is disabled.
    if self.config.enable_relation_sparse_fallback:
        self.metadata_store.ensure_relations_fts_schema(conn=self._conn)
        self.metadata_store.ensure_relations_fts_backfilled(conn=self._conn)

    tokens = self._tokenize(query)
    match_query = self._build_match_query(tokens)
    if not match_query:
        return []

    limit = max(1, int(k))
    rows = self.metadata_store.fts_search_bm25(
        match_query=match_query,
        limit=limit,
        max_doc_len=self.config.max_doc_len,
        conn=self._conn,
    )
    if not rows:
        # Nothing from the FTS index: try the substring-based fallback.
        rows = self._fallback_substring_search(tokens=tokens, limit=limit)

    results: List[Dict[str, Any]] = []
    for rank, row in enumerate(rows, start=1):
        bm25_score = float(row.get("bm25_score", 0.0))
        results.append(
            {
                "hash": row["hash"],
                "content": row["content"],
                "rank": rank,
                "bm25_score": bm25_score,
                # A smaller bm25 value means a better match, so negate it
                # to expose a "higher is better" relative score.
                "score": -bm25_score,
            }
        )
    return results
|
||||
|
||||
def search_relations(self, query: str, k: int = 20) -> List[Dict[str, Any]]:
    """Run the sparse relation search (FTS5 + BM25).

    Args:
        query: Raw query text; it is tokenized with ``self._tokenize``.
        k: Number of hits requested; values below 1 are raised to 1.

    Returns:
        A list of dicts with the keys ``hash``, ``subject``, ``predicate``,
        ``object``, ``content``, ``rank`` (1-based position), ``bm25_score``
        and ``score`` (the negated ``bm25_score``). The list is empty when
        the index or the relation fallback is disabled, the index cannot be
        loaded, or the query produces no match expression.
    """
    if not self.config.enabled or not self.config.enable_relation_sparse_fallback:
        return []
    if self.config.lazy_load and not self.loaded:
        if not self.ensure_loaded():
            return []
    if not self.loaded:
        return []

    # The relation switch can be flipped at runtime, so make sure the
    # relation FTS schema and backfill exist before querying instead of
    # relying on some earlier call having prepared them.
    self.metadata_store.ensure_relations_fts_schema(conn=self._conn)
    self.metadata_store.ensure_relations_fts_backfilled(conn=self._conn)

    tokens = self._tokenize(query)
    match_query = self._build_match_query(tokens)
    if not match_query:
        return []

    rows = self.metadata_store.fts_search_relations_bm25(
        match_query=match_query,
        limit=max(1, int(k)),
        max_doc_len=self.config.relation_max_doc_len,
        conn=self._conn,
    )
    out: List[Dict[str, Any]] = []
    for rank, row in enumerate(rows, start=1):
        bm25_score = float(row.get("bm25_score", 0.0))
        out.append(
            {
                "hash": row["hash"],
                "subject": row["subject"],
                "predicate": row["predicate"],
                "object": row["object"],
                "content": row["content"],
                "rank": rank,
                "bm25_score": bm25_score,
                # Negated so that a larger score means a better match.
                "score": -bm25_score,
            }
        )
    return out
|
||||
|
||||
def upsert_paragraph(self, paragraph_hash: str) -> bool:
    """Forward an upsert for one paragraph to the metadata store's FTS layer.

    Args:
        paragraph_hash: Hash identifying the paragraph.

    Returns:
        ``False`` without touching the store when the index is not loaded;
        otherwise whatever ``metadata_store.fts_upsert_paragraph`` returns.
    """
    if self.loaded:
        return self.metadata_store.fts_upsert_paragraph(paragraph_hash, conn=self._conn)
    return False
|
||||
|
||||
def delete_paragraph(self, paragraph_hash: str) -> bool:
    """Forward a delete for one paragraph to the metadata store's FTS layer.

    Args:
        paragraph_hash: Hash identifying the paragraph.

    Returns:
        ``False`` without touching the store when the index is not loaded;
        otherwise whatever ``metadata_store.fts_delete_paragraph`` returns.
    """
    if self.loaded:
        return self.metadata_store.fts_delete_paragraph(paragraph_hash, conn=self._conn)
    return False
|
||||
|
||||
def unload(self) -> None:
    """Close the BM25 connection and release memory on a best-effort basis.

    Failures while shrinking memory or closing the connection never
    propagate, but they are logged as warnings instead of being dropped
    silently. The connection reference and the loaded flag are always
    reset afterwards.
    """
    conn = self._conn
    if conn is not None:
        try:
            if self.config.shrink_memory_on_unload:
                self.metadata_store.shrink_memory(conn=conn)
        except Exception as e:
            # Shrinking is optional; keep going so the connection still closes.
            logger.warning(f"SparseBM25Index shrink_memory failed during unload: {e}")
        try:
            conn.close()
        except Exception as e:
            logger.warning(f"SparseBM25Index connection close failed during unload: {e}")
    self._conn = None
    self._loaded = False
    logger.info("SparseBM25Index unloaded")
|
||||
|
||||
def stats(self) -> Dict[str, Any]:
    """Report the index configuration together with its current state.

    Returns:
        A dict holding the selected config switches, the ``loaded`` flag,
        the module's ``HAS_JIEBA`` flag and ``doc_count``. ``doc_count`` is
        0 unless the index is loaded, in which case it comes from
        ``metadata_store.fts_doc_count``.
    """
    # Only a loaded index is asked for its document count.
    doc_count = self.metadata_store.fts_doc_count(conn=self._conn) if self.loaded else 0

    config_fields = (
        "enabled",
        "backend",
        "mode",
        "tokenizer_mode",
        "enable_ngram_fallback_index",
        "enable_like_fallback",
        "enable_relation_sparse_fallback",
    )
    summary: Dict[str, Any] = {name: getattr(self.config, name) for name in config_fields}
    summary["loaded"] = self.loaded
    summary["has_jieba"] = HAS_JIEBA
    summary["doc_count"] = doc_count
    return summary
|
||||
450
plugins/A_memorix/core/retrieval/threshold.py
Normal file
450
plugins/A_memorix/core/retrieval/threshold.py
Normal file
@@ -0,0 +1,450 @@
|
||||
"""
|
||||
动态阈值过滤器
|
||||
|
||||
根据检索结果的分布特征自适应调整过滤阈值。
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from typing import List, Dict, Any, Optional, Tuple, Union
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from .dual_path import RetrievalResult
|
||||
|
||||
logger = get_logger("A_Memorix.DynamicThresholdFilter")
|
||||
|
||||
|
||||
class ThresholdMethod(Enum):
    """Strategies available for computing the filtering threshold."""

    # Cut at a percentile of the scores.
    PERCENTILE = "percentile"
    # Cut relative to the mean using the standard deviation.
    STD_DEV = "std_dev"
    # Cut at a jump between neighbouring sorted scores.
    GAP_DETECTION = "gap_detection"
    # Adaptive: combine several of the methods above.
    ADAPTIVE = "adaptive"
|
||||
|
||||
|
||||
@dataclass
class ThresholdConfig:
    """Settings for the dynamic threshold filter.

    Attributes:
        method: How the threshold is computed.
        min_threshold: Lower bound for the threshold (absolute value).
        max_threshold: Upper bound for the threshold (absolute value).
        percentile: Percentile used by the percentile method.
        std_multiplier: Standard-deviation multiplier used by the std_dev method.
        min_results: Minimum number of results to keep.
        enable_auto_adjust: Whether the threshold is adjusted automatically.

    Raises:
        ValueError: From ``__post_init__`` when any field is out of range.
    """

    method: ThresholdMethod = ThresholdMethod.ADAPTIVE
    min_threshold: float = 0.3
    max_threshold: float = 0.95
    percentile: float = 75.0
    std_multiplier: float = 1.5
    min_results: int = 3
    enable_auto_adjust: bool = True

    def __post_init__(self):
        """Validate the field values right after construction."""
        if not 0 <= self.min_threshold <= 1:
            raise ValueError(f"min_threshold必须在[0, 1]之间: {self.min_threshold}")

        if not 0 <= self.max_threshold <= 1:
            raise ValueError(f"max_threshold必须在[0, 1]之间: {self.max_threshold}")

        if self.min_threshold >= self.max_threshold:
            # Include both values, like the sibling messages do, so the
            # error is actionable without a debugger.
            raise ValueError(
                f"min_threshold必须小于max_threshold: "
                f"{self.min_threshold} >= {self.max_threshold}"
            )

        if not 0 <= self.percentile <= 100:
            raise ValueError(f"percentile必须在[0, 100]之间: {self.percentile}")

        if self.std_multiplier <= 0:
            raise ValueError(f"std_multiplier必须大于0: {self.std_multiplier}")

        if self.min_results < 0:
            raise ValueError(f"min_results必须大于等于0: {self.min_results}")
|
||||
|
||||
|
||||
class DynamicThresholdFilter:
    """Filter retrieval results with a threshold derived from their scores.

    Features:
        - Computes the threshold from the score distribution of each batch.
        - Supports several threshold methods (see ``ThresholdMethod``).
        - Optionally smooths the threshold against recent history.
        - Collects running statistics.

    Args:
        config: Threshold configuration; a default ``ThresholdConfig`` is
            used when omitted.
    """

    def __init__(
        self,
        config: Optional[ThresholdConfig] = None,
    ):
        """Initialize the filter.

        Args:
            config: Threshold configuration.
        """
        self.config = config or ThresholdConfig()

        # Running statistics.
        self._total_filtered = 0
        self._total_processed = 0
        self._threshold_history: List[float] = []

        logger.info(
            f"DynamicThresholdFilter 初始化: "
            f"method={self.config.method.value}, "
            f"min_threshold={self.config.min_threshold}"
        )

    def filter(
        self,
        results: List[RetrievalResult],
        return_threshold: bool = False,
    ) -> Union[List[RetrievalResult], Tuple[List[RetrievalResult], float]]:
        """Filter retrieval results by a dynamically computed threshold.

        Args:
            results: Retrieval results to filter.
            return_threshold: Whether to also return the threshold used.

        Returns:
            The filtered list, or a ``(filtered list, threshold)`` tuple when
            ``return_threshold`` is true. Empty input yields ``[]`` (or
            ``([], 0.0)``).
        """
        if not results:
            logger.debug("结果列表为空,无需过滤")
            return ([], 0.0) if return_threshold else []

        self._total_processed += len(results)

        scores = np.array([r.score for r in results])
        threshold = self._compute_threshold(scores, results)

        # The history stores the computed threshold, i.e. the value before
        # any min_results fallback below replaces it.
        self._threshold_history.append(threshold)

        filtered_results = [r for r in results if r.score >= threshold]

        # Guarantee at least min_results entries by falling back to the
        # top-scoring ones.
        if len(filtered_results) < self.config.min_results:
            sorted_results = sorted(results, key=lambda x: x.score, reverse=True)
            filtered_results = sorted_results[:self.config.min_results]
            threshold = filtered_results[-1].score if filtered_results else 0.0

        self._total_filtered += len(results) - len(filtered_results)

        logger.info(
            f"过滤完成: {len(results)} -> {len(filtered_results)} "
            f"(threshold={threshold:.3f})"
        )

        if return_threshold:
            return filtered_results, threshold
        return filtered_results

    def _compute_threshold(
        self,
        scores: np.ndarray,
        results: List[RetrievalResult],
    ) -> float:
        """Compute the threshold for one batch of scores.

        Args:
            scores: Array of scores.
            results: Retrieval results (not used by the computation).

        Returns:
            The threshold, clipped to ``[min_threshold, max_threshold]`` and
            then optionally auto-adjusted.
        """
        if self.config.method == ThresholdMethod.PERCENTILE:
            threshold = self._percentile_threshold(scores)
        elif self.config.method == ThresholdMethod.STD_DEV:
            threshold = self._std_dev_threshold(scores)
        elif self.config.method == ThresholdMethod.GAP_DETECTION:
            threshold = self._gap_detection_threshold(scores)
        else:  # ADAPTIVE
            # Adaptive: take the median of all three methods.
            thresholds = [
                self._percentile_threshold(scores),
                self._std_dev_threshold(scores),
                self._gap_detection_threshold(scores),
            ]
            threshold = float(np.median(thresholds))

        # Keep the value inside the configured bounds.
        threshold = np.clip(
            threshold,
            self.config.min_threshold,
            self.config.max_threshold,
        )

        if self.config.enable_auto_adjust:
            threshold = self._auto_adjust_threshold(threshold, scores)

        return float(threshold)

    def _percentile_threshold(self, scores: np.ndarray) -> float:
        """Compute the threshold as the configured percentile of the scores.

        Args:
            scores: Array of scores.

        Returns:
            The threshold.
        """
        percentile = self.config.percentile
        threshold = float(np.percentile(scores, percentile))

        logger.debug(f"百分位数阈值: {threshold:.3f} (percentile={percentile})")
        return threshold

    def _std_dev_threshold(self, scores: np.ndarray) -> float:
        """Compute the threshold from the mean and standard deviation.

        threshold = mean - std_multiplier * std

        Args:
            scores: Array of scores.

        Returns:
            The threshold.
        """
        mean = float(np.mean(scores))
        std = float(np.std(scores))
        multiplier = self.config.std_multiplier

        threshold = mean - multiplier * std

        logger.debug(f"标准差阈值: {threshold:.3f} (mean={mean:.3f}, std={std:.3f})")
        return threshold

    def _gap_detection_threshold(self, scores: np.ndarray) -> float:
        """Compute the threshold at the largest jump between sorted scores.

        The scores are sorted in descending order; the threshold is the
        score located right after the biggest drop between neighbours.

        Args:
            scores: Array of scores (any order; sorted internally).

        Returns:
            The threshold. A single score is returned as-is and an empty
            array yields 0.0.
        """
        sorted_scores = np.sort(scores)[::-1]

        if len(sorted_scores) < 2:
            return float(sorted_scores[0]) if len(sorted_scores) > 0 else 0.0

        # In descending order every neighbour difference is <= 0, so measure
        # the drops as non-negative magnitudes; argmax then really selects
        # the largest jump rather than the smallest one.
        drops = sorted_scores[:-1] - sorted_scores[1:]
        max_gap_idx = int(np.argmax(drops))

        # The threshold is the score right after the jump.
        threshold = float(sorted_scores[max_gap_idx + 1])

        logger.debug(
            f"跳变检测阈值: {threshold:.3f} "
            f"(gap={drops[max_gap_idx]:.3f}, idx={max_gap_idx})"
        )
        return threshold

    def _auto_adjust_threshold(
        self,
        threshold: float,
        scores: np.ndarray,
    ) -> float:
        """Pull the threshold towards the recent history when it deviates.

        Args:
            threshold: Current threshold.
            scores: Array of scores (not used by the adjustment).

        Returns:
            The adjusted threshold, or the input unchanged when there is no
            history or the deviation is at most 0.2.
        """
        if not self._threshold_history:
            return threshold

        # Moving average over the last 10 thresholds.
        recent_thresholds = self._threshold_history[-10:]
        avg_threshold = float(np.mean(recent_thresholds))

        diff = threshold - avg_threshold

        # A deviation above 0.2 is halved, moving the value towards the
        # historical average.
        if abs(diff) > 0.2:
            adjusted_threshold = avg_threshold + diff * 0.5
            logger.debug(
                f"阈值调整: {threshold:.3f} -> {adjusted_threshold:.3f} "
                f"(历史平均={avg_threshold:.3f})"
            )
            return adjusted_threshold

        return threshold

    def filter_by_confidence(
        self,
        results: List[RetrievalResult],
        min_confidence: float = 0.5,
    ) -> List[RetrievalResult]:
        """Filter results by a minimum confidence.

        Args:
            results: Retrieval results to filter.
            min_confidence: Minimum confidence to keep a result.

        Returns:
            The results that pass, in their original order.
        """
        filtered = []
        for result in results:
            if result.result_type == "relation":
                # Relations carry a confidence in their metadata (1.0 if absent).
                value = result.metadata.get("confidence", 1.0)
            else:
                # Other results are judged by their score directly.
                value = result.score
            if value >= min_confidence:
                filtered.append(result)

        logger.info(
            f"置信度过滤: {len(results)} -> {len(filtered)} "
            f"(min_confidence={min_confidence})"
        )

        return filtered

    def filter_by_diversity(
        self,
        results: List[RetrievalResult],
        similarity_threshold: float = 0.9,
        top_k: int = 10,
    ) -> List[RetrievalResult]:
        """Filter results for diversity by dropping duplicates.

        Duplicates are detected by comparing the first 8 characters of
        ``hash_value``; ``similarity_threshold`` is only reported in the log.

        Args:
            results: Retrieval results to filter.
            similarity_threshold: Similarity above which results count as
                duplicates (currently not used for the comparison).
            top_k: Maximum number of results to keep.

        Returns:
            Up to ``top_k`` results in descending score order.
        """
        if not results:
            return []

        sorted_results = sorted(results, key=lambda x: x.score, reverse=True)

        # Greedy pick: walk from the best score down and skip any result
        # whose hash prefix was already taken.
        selected = []
        seen_prefixes = set()

        for result in sorted_results:
            if len(selected) >= top_k:
                break

            prefix = result.hash_value[:8]
            if prefix in seen_prefixes:
                continue

            selected.append(result)
            seen_prefixes.add(prefix)

        logger.info(
            f"多样性过滤: {len(results)} -> {len(selected)} "
            f"(similarity_threshold={similarity_threshold})"
        )

        return selected

    def get_statistics(self) -> Dict[str, Any]:
        """Return the configuration and the running statistics.

        Returns:
            A dict with a ``config`` section and a ``statistics`` section.
            ``min_threshold_used`` and ``max_threshold_used`` are present
            only once at least one threshold has been recorded.
        """
        filter_rate = (
            self._total_filtered / self._total_processed
            if self._total_processed > 0
            else 0.0
        )

        stats = {
            "config": {
                "method": self.config.method.value,
                "min_threshold": self.config.min_threshold,
                "max_threshold": self.config.max_threshold,
                "percentile": self.config.percentile,
                "std_multiplier": self.config.std_multiplier,
                "min_results": self.config.min_results,
                "enable_auto_adjust": self.config.enable_auto_adjust,
            },
            "statistics": {
                "total_processed": self._total_processed,
                "total_filtered": self._total_filtered,
                "filter_rate": filter_rate,
                "avg_threshold": float(np.mean(self._threshold_history))
                if self._threshold_history else 0.0,
                "threshold_count": len(self._threshold_history),
            },
        }

        if self._threshold_history:
            stats["statistics"]["min_threshold_used"] = float(np.min(self._threshold_history))
            stats["statistics"]["max_threshold_used"] = float(np.max(self._threshold_history))

        return stats

    def reset_statistics(self) -> None:
        """Reset the counters and the threshold history."""
        self._total_filtered = 0
        self._total_processed = 0
        self._threshold_history.clear()
        logger.info("统计信息已重置")

    def __repr__(self) -> str:
        return (
            f"DynamicThresholdFilter("
            f"method={self.config.method.value}, "
            f"min_threshold={self.config.min_threshold}, "
            f"filtered={self._total_filtered}/{self._total_processed})"
        )
|
||||
Reference in New Issue
Block a user