feat(A_memorix): 收紧稀疏尾部召回并改进 PPR 增益
This commit is contained in:
@@ -588,6 +588,7 @@ class DualPathRetriever:
|
||||
candidate_k = max(top_k, self.config.sparse.candidate_k)
|
||||
candidate_k = self._cap_temporal_scan_k(candidate_k, temporal)
|
||||
sparse_rows = self.sparse_index.search(query=query, k=candidate_k)
|
||||
sparse_rows = self._filter_sparse_paragraph_rows(sparse_rows)
|
||||
results: List[RetrievalResult] = []
|
||||
for row in sparse_rows:
|
||||
hash_value = row["hash"]
|
||||
@@ -614,6 +615,53 @@ class DualPathRetriever:
|
||||
self._normalize_scores_minmax(results)
|
||||
return results
|
||||
|
||||
def _filter_sparse_paragraph_rows(
    self,
    rows: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
    """
    Prune the weak tail of sparse paragraph rows.

    The goal is not to suppress strong lexical hits, but to keep tail
    results that only hit a single weak token from earning too much
    rank credit in weighted RRF.

    The first two rows are always kept. Every later row survives when
    at least one of the following holds:

    * its ``score`` is at least 20% of the score of ``rows[0]``;
    * its ``matched_token_count`` is 3 or more;
    * its ``matched_token_count`` is 2 or more and its
      ``matched_token_ratio`` is at least 0.12.

    Args:
        rows: Sparse search rows. Missing or falsy ``score``,
            ``matched_token_count`` and ``matched_token_ratio`` values
            are treated as zero.

    Returns:
        ``rows`` itself when it holds two rows or fewer; only the first
        two rows when the score of ``rows[0]`` is not positive;
        otherwise a new list with the surviving rows in original order.
    """
    head_size = 2
    if len(rows) <= head_size:
        return rows

    # The reference score is read from the first row only.
    # NOTE(review): presumably rows arrive ordered best-first — confirm
    # against the caller.
    best_score = max(0.0, float(rows[0].get("score", 0.0) or 0.0))
    if best_score <= 0.0:
        # No usable reference score: keep just the always-kept head.
        return rows[:head_size]

    score_floor = best_score * 0.2
    kept: List[Dict[str, Any]] = list(rows[:head_size])

    for candidate in rows[head_size:]:
        # All three fields are coerced up front, before any comparison.
        score = float(candidate.get("score", 0.0) or 0.0)
        token_hits = int(candidate.get("matched_token_count", 0) or 0)
        token_ratio = float(candidate.get("matched_token_ratio", 0.0) or 0.0)

        strong_by_score = score >= score_floor
        strong_by_hits = token_hits >= 3
        strong_by_coverage = token_hits >= 2 and token_ratio >= 0.12

        if strong_by_score or strong_by_hits or strong_by_coverage:
            kept.append(candidate)

    dropped = len(rows) - len(kept)
    if dropped > 0:
        logger.debug(
            "sparse_paragraph_tail_pruned=1 "
            f"removed_count={dropped} "
            f"kept_count={len(kept)}"
        )
    return kept
|
||||
|
||||
def _search_relations_sparse(
|
||||
self,
|
||||
query: str,
|
||||
@@ -1560,9 +1608,20 @@ class DualPathRetriever:
|
||||
entity_scores.append(ppr_scores_by_name[ent_name])
|
||||
|
||||
if entity_scores:
|
||||
avg_ppr = np.mean(entity_scores)
|
||||
# 融合原始分数和PPR分数
|
||||
result.score = result.score * 0.7 + avg_ppr * 0.3
|
||||
# 只使用命中的高价值图实体做正向增益,避免把原本高分的正确段落
|
||||
# 因为“实体多但非全部命中”而反向压低。
|
||||
focus_scores = sorted(entity_scores, reverse=True)[:2]
|
||||
ppr_signal = float(np.mean(focus_scores))
|
||||
boost_weight = 0.12 if len(focus_scores) >= 2 else 0.06
|
||||
boost = ppr_signal * boost_weight
|
||||
|
||||
metadata = result.metadata if isinstance(result.metadata, dict) else {}
|
||||
metadata["ppr_signal"] = round(ppr_signal, 4)
|
||||
metadata["ppr_focus_entity_count"] = len(focus_scores)
|
||||
metadata["ppr_boost"] = round(boost, 4)
|
||||
result.metadata = metadata
|
||||
|
||||
result.score = float(result.score) + float(boost)
|
||||
|
||||
# 重新排序
|
||||
results.sort(key=lambda x: x.score, reverse=True)
|
||||
|
||||
@@ -306,15 +306,22 @@ class SparseBM25Index:
|
||||
rows = self._fallback_substring_search(tokens=tokens, limit=limit)
|
||||
|
||||
results: List[Dict[str, Any]] = []
|
||||
token_count = max(1, len(tokens))
|
||||
for rank, row in enumerate(rows, start=1):
|
||||
bm25_score = float(row.get("bm25_score", 0.0))
|
||||
content = str(row.get("content", "") or "")
|
||||
content_low = content.lower()
|
||||
matched_tokens = [token for token in tokens if token in content_low]
|
||||
matched_token_count = len(dict.fromkeys(matched_tokens))
|
||||
results.append(
|
||||
{
|
||||
"hash": row["hash"],
|
||||
"content": row["content"],
|
||||
"content": content,
|
||||
"rank": rank,
|
||||
"bm25_score": bm25_score,
|
||||
"score": -bm25_score, # bm25 越小越相关,这里取反作为相对分数
|
||||
"matched_token_count": matched_token_count,
|
||||
"matched_token_ratio": matched_token_count / float(token_count),
|
||||
}
|
||||
)
|
||||
return results
|
||||
|
||||
Reference in New Issue
Block a user