feat(A_memorix): 收紧稀疏尾部召回并改进 PPR 增益

This commit is contained in:
A-Dawn
2026-04-21 14:08:14 +08:00
parent c6e2c6e003
commit e41922e24c
2 changed files with 70 additions and 4 deletions

View File

@@ -588,6 +588,7 @@ class DualPathRetriever:
candidate_k = max(top_k, self.config.sparse.candidate_k)
candidate_k = self._cap_temporal_scan_k(candidate_k, temporal)
sparse_rows = self.sparse_index.search(query=query, k=candidate_k)
sparse_rows = self._filter_sparse_paragraph_rows(sparse_rows)
results: List[RetrievalResult] = []
for row in sparse_rows:
hash_value = row["hash"]
@@ -614,6 +615,53 @@ class DualPathRetriever:
self._normalize_scores_minmax(results)
return results
def _filter_sparse_paragraph_rows(
self,
rows: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
"""
过滤 paragraph sparse tail。
目标不是压缩强 lexical hit而是避免只命中一个弱 token 的尾部结果
在 weighted RRF 中拿到过高的 rank credit。
"""
if len(rows) <= 2:
return rows
top_score = max(0.0, float(rows[0].get("score", 0.0) or 0.0))
if top_score <= 0.0:
return rows[:2]
relative_floor = top_score * 0.2
filtered_rows: List[Dict[str, Any]] = []
removed_count = 0
for index, row in enumerate(rows):
if index < 2:
filtered_rows.append(row)
continue
raw_score = float(row.get("score", 0.0) or 0.0)
matched_token_count = int(row.get("matched_token_count", 0) or 0)
matched_token_ratio = float(row.get("matched_token_ratio", 0.0) or 0.0)
if (
raw_score >= relative_floor
or matched_token_count >= 3
or (matched_token_count >= 2 and matched_token_ratio >= 0.12)
):
filtered_rows.append(row)
continue
removed_count += 1
if removed_count > 0:
logger.debug(
"sparse_paragraph_tail_pruned=1 "
f"removed_count={removed_count} "
f"kept_count={len(filtered_rows)}"
)
return filtered_rows
def _search_relations_sparse(
self,
query: str,
@@ -1560,9 +1608,20 @@ class DualPathRetriever:
entity_scores.append(ppr_scores_by_name[ent_name])
if entity_scores:
avg_ppr = np.mean(entity_scores)
# 融合原始分数和PPR分数
result.score = result.score * 0.7 + avg_ppr * 0.3
# 只使用命中的高价值图实体做正向增益,避免把原本高分的正确段落
# 因为“实体多但非全部命中”而反向压低。
focus_scores = sorted(entity_scores, reverse=True)[:2]
ppr_signal = float(np.mean(focus_scores))
boost_weight = 0.12 if len(focus_scores) >= 2 else 0.06
boost = ppr_signal * boost_weight
metadata = result.metadata if isinstance(result.metadata, dict) else {}
metadata["ppr_signal"] = round(ppr_signal, 4)
metadata["ppr_focus_entity_count"] = len(focus_scores)
metadata["ppr_boost"] = round(boost, 4)
result.metadata = metadata
result.score = float(result.score) + float(boost)
# 重新排序
results.sort(key=lambda x: x.score, reverse=True)

View File

@@ -306,15 +306,22 @@ class SparseBM25Index:
rows = self._fallback_substring_search(tokens=tokens, limit=limit)
results: List[Dict[str, Any]] = []
token_count = max(1, len(tokens))
for rank, row in enumerate(rows, start=1):
bm25_score = float(row.get("bm25_score", 0.0))
content = str(row.get("content", "") or "")
content_low = content.lower()
matched_tokens = [token for token in tokens if token in content_low]
matched_token_count = len(dict.fromkeys(matched_tokens))
results.append(
{
"hash": row["hash"],
"content": row["content"],
"content": content,
"rank": rank,
"bm25_score": bm25_score,
"score": -bm25_score, # bm25 越小越相关,这里取反作为相对分数
"matched_token_count": matched_token_count,
"matched_token_ratio": matched_token_count / float(token_count),
}
)
return results