From e41922e24c0013786193cb30bae4a868ae1267a2 Mon Sep 17 00:00:00 2001 From: A-Dawn <67786671+A-Dawn@users.noreply.github.com> Date: Tue, 21 Apr 2026 14:08:14 +0800 Subject: [PATCH] =?UTF-8?q?feat(A=5Fmemorix):=20=E6=94=B6=E7=B4=A7?= =?UTF-8?q?=E7=A8=80=E7=96=8F=E5=B0=BE=E9=83=A8=E5=8F=AC=E5=9B=9E=E5=B9=B6?= =?UTF-8?q?=E6=94=B9=E8=BF=9B=20PPR=20=E5=A2=9E=E7=9B=8A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/A_memorix/core/retrieval/dual_path.py | 65 ++++++++++++++++++++- src/A_memorix/core/retrieval/sparse_bm25.py | 9 ++- 2 files changed, 70 insertions(+), 4 deletions(-) diff --git a/src/A_memorix/core/retrieval/dual_path.py b/src/A_memorix/core/retrieval/dual_path.py index 437f3dd7..c03be548 100644 --- a/src/A_memorix/core/retrieval/dual_path.py +++ b/src/A_memorix/core/retrieval/dual_path.py @@ -588,6 +588,7 @@ class DualPathRetriever: candidate_k = max(top_k, self.config.sparse.candidate_k) candidate_k = self._cap_temporal_scan_k(candidate_k, temporal) sparse_rows = self.sparse_index.search(query=query, k=candidate_k) + sparse_rows = self._filter_sparse_paragraph_rows(sparse_rows) results: List[RetrievalResult] = [] for row in sparse_rows: hash_value = row["hash"] @@ -614,6 +615,53 @@ class DualPathRetriever: self._normalize_scores_minmax(results) return results + def _filter_sparse_paragraph_rows( + self, + rows: List[Dict[str, Any]], + ) -> List[Dict[str, Any]]: + """ + 过滤 paragraph sparse tail。 + + 目标不是压缩强 lexical hit,而是避免只命中一个弱 token 的尾部结果 + 在 weighted RRF 中拿到过高的 rank credit。 + """ + if len(rows) <= 2: + return rows + + top_score = max(0.0, float(rows[0].get("score", 0.0) or 0.0)) + if top_score <= 0.0: + return rows[:2] + + relative_floor = top_score * 0.2 + filtered_rows: List[Dict[str, Any]] = [] + removed_count = 0 + for index, row in enumerate(rows): + if index < 2: + filtered_rows.append(row) + continue + + raw_score = float(row.get("score", 0.0) or 0.0) + matched_token_count = int(row.get("matched_token_count", 0) or 0) + matched_token_ratio = float(row.get("matched_token_ratio", 0.0) or 0.0) + + if ( + raw_score >= relative_floor + or matched_token_count >= 3 + or (matched_token_count >= 2 and matched_token_ratio >= 0.12) + ): + filtered_rows.append(row) + continue + + removed_count += 1 + + if removed_count > 0: + logger.debug( + "sparse_paragraph_tail_pruned=1 " + f"removed_count={removed_count} " + f"kept_count={len(filtered_rows)}" + ) + return filtered_rows + def _search_relations_sparse( self, query: str, @@ -1560,9 +1608,20 @@ class DualPathRetriever: entity_scores.append(ppr_scores_by_name[ent_name]) if entity_scores: - avg_ppr = np.mean(entity_scores) - # 融合原始分数和PPR分数 - result.score = result.score * 0.7 + avg_ppr * 0.3 + # 只使用命中的高价值图实体做正向增益,避免把原本高分的正确段落 + # 因为“实体多但非全部命中”而反向压低。 + focus_scores = sorted(entity_scores, reverse=True)[:2] + ppr_signal = float(np.mean(focus_scores)) + boost_weight = 0.12 if len(focus_scores) >= 2 else 0.06 + boost = ppr_signal * boost_weight + + metadata = result.metadata if isinstance(result.metadata, dict) else {} + metadata["ppr_signal"] = round(ppr_signal, 4) + metadata["ppr_focus_entity_count"] = len(focus_scores) + metadata["ppr_boost"] = round(boost, 4) + result.metadata = metadata + + result.score = float(result.score) + float(boost) # 重新排序 results.sort(key=lambda x: x.score, reverse=True) diff --git a/src/A_memorix/core/retrieval/sparse_bm25.py b/src/A_memorix/core/retrieval/sparse_bm25.py index 276e8778..7808a516 100644 --- a/src/A_memorix/core/retrieval/sparse_bm25.py +++ b/src/A_memorix/core/retrieval/sparse_bm25.py @@ -306,15 +306,22 @@ class SparseBM25Index: rows = self._fallback_substring_search(tokens=tokens, limit=limit) results: List[Dict[str, Any]] = [] + token_count = max(1, len(tokens)) for rank, row in enumerate(rows, start=1): bm25_score = float(row.get("bm25_score", 0.0)) + content = str(row.get("content", "") or "") + content_low = content.lower() + matched_tokens = [token for token in tokens if token in content_low] + matched_token_count = len(dict.fromkeys(matched_tokens)) results.append( { "hash": row["hash"], - "content": row["content"], + "content": content, "rank": rank, "bm25_score": bm25_score, "score": -bm25_score, # bm25 越小越相关,这里取反作为相对分数 + "matched_token_count": matched_token_count, + "matched_token_ratio": matched_token_count / float(token_count), } ) return results