From c6e2c6e003783c32a100afcd2b550c20554bbb0c Mon Sep 17 00:00:00 2001 From: A-Dawn <67786671+A-Dawn@users.noreply.github.com> Date: Tue, 21 Apr 2026 14:06:46 +0800 Subject: [PATCH 1/4] =?UTF-8?q?chore(A=5Fmemorix):=20=E5=9B=9E=E9=80=80?= =?UTF-8?q?=E5=AE=9E=E7=8E=B0=E6=95=B4=E7=90=86=E4=BB=A5=E6=81=A2=E5=A4=8D?= =?UTF-8?q?=E6=9C=AC=E5=9C=B0=E4=B8=BB=E5=AF=BC=E5=9F=BA=E7=BA=BF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/A_memorix/core/embedding/manager.py | 3 +++ src/A_memorix/core/embedding/presets.py | 2 +- src/A_memorix/core/retrieval/dual_path.py | 4 ++-- src/A_memorix/core/retrieval/pagerank.py | 5 +++-- src/A_memorix/core/retrieval/threshold.py | 2 +- src/A_memorix/core/runtime/lifecycle_orchestrator.py | 1 + src/A_memorix/core/runtime/sdk_memory_kernel.py | 3 ++- src/A_memorix/core/storage/graph_store.py | 2 ++ src/A_memorix/core/storage/vector_store.py | 3 ++- src/A_memorix/core/strategies/base.py | 2 +- src/A_memorix/core/strategies/factual.py | 2 +- src/A_memorix/core/strategies/narrative.py | 2 +- src/A_memorix/core/strategies/quote.py | 2 +- src/A_memorix/core/utils/hash.py | 1 + src/A_memorix/core/utils/io.py | 1 + src/A_memorix/core/utils/matcher.py | 2 +- src/A_memorix/core/utils/path_fallback_service.py | 2 +- src/A_memorix/core/utils/runtime_self_check.py | 2 +- src/A_memorix/core/utils/search_execution_service.py | 4 ++-- src/A_memorix/core/utils/summary_importer.py | 1 + src/A_memorix/runtime_registry.py | 2 +- src/A_memorix/scripts/_bootstrap.py | 2 +- src/A_memorix/scripts/convert_lpmm.py | 4 +++- src/A_memorix/scripts/process_knowledge.py | 5 ++++- src/A_memorix/scripts/release_vnext_migrate.py | 2 +- 25 files changed, 39 insertions(+), 22 deletions(-) diff --git a/src/A_memorix/core/embedding/manager.py b/src/A_memorix/core/embedding/manager.py index a65689ac..d161e23b 100644 --- a/src/A_memorix/core/embedding/manager.py +++ b/src/A_memorix/core/embedding/manager.py @@ -23,7 +23,10 @@ from src.common.logger import get_logger from .presets import ( EmbeddingModelConfig, get_custom_config, + validate_config_compatibility, + are_models_compatible, ) +from ..utils.quantization import QuantizationType logger = get_logger("A_Memorix.EmbeddingManager") diff --git a/src/A_memorix/core/embedding/presets.py b/src/A_memorix/core/embedding/presets.py index 88714b22..54e6f8b4 100644 --- a/src/A_memorix/core/embedding/presets.py +++ b/src/A_memorix/core/embedding/presets.py @@ -3,7 +3,7 @@ """ from dataclasses import dataclass -from typing import Optional, Union +from typing import Optional, Dict, Any, Union from pathlib import Path diff --git a/src/A_memorix/core/retrieval/dual_path.py b/src/A_memorix/core/retrieval/dual_path.py index ae701906..437f3dd7 100644 --- a/src/A_memorix/core/retrieval/dual_path.py +++ b/src/A_memorix/core/retrieval/dual_path.py @@ -7,7 +7,7 @@ import asyncio import re from dataclasses import dataclass, field -from typing import Optional, List, Dict, Any, Tuple +from typing import Optional, List, Dict, Any, Tuple, Union from enum import Enum import numpy as np @@ -320,7 +320,7 @@ class DualPathRetriever: # 调试模式:打印结果原文 if self.config.debug: - logger.info("[DEBUG] 检索结果内容原文:") + logger.info(f"[DEBUG] 检索结果内容原文:") for i, res in enumerate(results): logger.info(f" {i+1}. [{res.result_type}] (Score: {res.score:.4f}) {res.content}") diff --git a/src/A_memorix/core/retrieval/pagerank.py b/src/A_memorix/core/retrieval/pagerank.py index 36e456d8..c8ee48bb 100644 --- a/src/A_memorix/core/retrieval/pagerank.py +++ b/src/A_memorix/core/retrieval/pagerank.py @@ -4,8 +4,9 @@ Personalized PageRank实现 提供个性化的图节点排序功能。 """ -from typing import Dict, List, Optional, Tuple, Any +from typing import Dict, List, Optional, Tuple, Union, Any from dataclasses import dataclass +import numpy as np from src.common.logger import get_logger from ..storage import GraphStore @@ -48,7 +49,7 @@ class PageRankConfig: raise ValueError(f"min_iterations必须大于等于0: {self.min_iterations}") if self.min_iterations >= self.max_iter: - raise ValueError("min_iterations必须小于max_iter") + raise ValueError(f"min_iterations必须小于max_iter") class PersonalizedPageRank: diff --git a/src/A_memorix/core/retrieval/threshold.py b/src/A_memorix/core/retrieval/threshold.py index fc342b52..87a0094b 100644 --- a/src/A_memorix/core/retrieval/threshold.py +++ b/src/A_memorix/core/retrieval/threshold.py @@ -56,7 +56,7 @@ class ThresholdConfig: raise ValueError(f"max_threshold必须在[0, 1]之间: {self.max_threshold}") if self.min_threshold >= self.max_threshold: - raise ValueError("min_threshold必须小于max_threshold") + raise ValueError(f"min_threshold必须小于max_threshold") if not 0 <= self.percentile <= 100: raise ValueError(f"percentile必须在[0, 100]之间: {self.percentile}") diff --git a/src/A_memorix/core/runtime/lifecycle_orchestrator.py b/src/A_memorix/core/runtime/lifecycle_orchestrator.py index 64746205..a421b05a 100644 --- a/src/A_memorix/core/runtime/lifecycle_orchestrator.py +++ b/src/A_memorix/core/runtime/lifecycle_orchestrator.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio +from pathlib import Path from typing import Any, Callable, Coroutine, cast from src.common.logger import get_logger diff --git a/src/A_memorix/core/runtime/sdk_memory_kernel.py b/src/A_memorix/core/runtime/sdk_memory_kernel.py index ed2a60ed..12681e05 100644 --- a/src/A_memorix/core/runtime/sdk_memory_kernel.py +++ b/src/A_memorix/core/runtime/sdk_memory_kernel.py @@ -4,6 +4,7 @@ import asyncio import json import pickle import time +import uuid from dataclasses import dataclass from datetime import datetime, timedelta from pathlib import Path @@ -18,7 +19,7 @@ from src.services.llm_service import LLMServiceClient from ...paths import default_data_dir, resolve_repo_path from ..embedding import create_embedding_api_adapter -from ..retrieval import RetrievalResult, SparseBM25Config, SparseBM25Index +from ..retrieval import RetrievalResult, SparseBM25Config, SparseBM25Index, TemporalQueryOptions from ..storage import GraphStore, MetadataStore, QuantizationType, SparseMatrixFormat, VectorStore from ..utils.aggregate_query_service import AggregateQueryService from ..utils.episode_retrieval_service import EpisodeRetrievalService diff --git a/src/A_memorix/core/storage/graph_store.py b/src/A_memorix/core/storage/graph_store.py index f036b6e4..e338ffc5 100644 --- a/src/A_memorix/core/storage/graph_store.py +++ b/src/A_memorix/core/storage/graph_store.py @@ -9,6 +9,7 @@ from enum import Enum from pathlib import Path from typing import Optional, Union, Tuple, List, Dict, Set, Any from collections import defaultdict +import threading import asyncio import numpy as np @@ -41,6 +42,7 @@ except ImportError: import contextlib from src.common.logger import get_logger +from ..utils.hash import compute_hash from ..utils.io import atomic_write logger = get_logger("A_Memorix.GraphStore") diff --git a/src/A_memorix/core/storage/vector_store.py b/src/A_memorix/core/storage/vector_store.py index 3590dba1..787e625a 100644 --- a/src/A_memorix/core/storage/vector_store.py +++ b/src/A_memorix/core/storage/vector_store.py @@ -4,6 +4,7 @@ 基于Faiss的高效向量存储与检索,支持SQ8量化、Append-Only磁盘存储和内存映射。 """ +import os import pickle import hashlib import shutil @@ -190,7 +191,7 @@ class VectorStore: self._update_reservoir(batch_vecs) # 这里的 TRAIN_SIZE 取默认 10k,或者根据当前数据量动态判断 if len(self._reservoir_buffer) >= 10000: - logger.info("训练样本达到上限,开始训练...") + logger.info(f"训练样本达到上限,开始训练...") self._train_and_replay_unlocked() self._total_added += len(batch_ids) diff --git a/src/A_memorix/core/strategies/base.py b/src/A_memorix/core/strategies/base.py index 58e05303..ff250cdf 100644 --- a/src/A_memorix/core/strategies/base.py +++ b/src/A_memorix/core/strategies/base.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import List, Dict, Any +from typing import List, Dict, Any, Optional, Union from dataclasses import dataclass, field from enum import Enum import hashlib diff --git a/src/A_memorix/core/strategies/factual.py b/src/A_memorix/core/strategies/factual.py index b0444ccd..4b7d6e56 100644 --- a/src/A_memorix/core/strategies/factual.py +++ b/src/A_memorix/core/strategies/factual.py @@ -1,5 +1,5 @@ import re -from typing import List +from typing import List, Dict, Any from .base import BaseStrategy, ProcessedChunk, KnowledgeType, SourceInfo, ChunkContext class FactualStrategy(BaseStrategy): diff --git a/src/A_memorix/core/strategies/narrative.py b/src/A_memorix/core/strategies/narrative.py index 18fa6f93..731414f7 100644 --- a/src/A_memorix/core/strategies/narrative.py +++ b/src/A_memorix/core/strategies/narrative.py @@ -1,5 +1,5 @@ import re -from typing import List +from typing import List, Dict, Any from .base import BaseStrategy, ProcessedChunk, KnowledgeType, SourceInfo, ChunkContext class NarrativeStrategy(BaseStrategy): diff --git a/src/A_memorix/core/strategies/quote.py b/src/A_memorix/core/strategies/quote.py index d4d62ce0..10733d64 100644 --- a/src/A_memorix/core/strategies/quote.py +++ b/src/A_memorix/core/strategies/quote.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Dict, Any from .base import BaseStrategy, ProcessedChunk, KnowledgeType, SourceInfo, ChunkContext, ChunkFlags class QuoteStrategy(BaseStrategy): diff --git a/src/A_memorix/core/utils/hash.py b/src/A_memorix/core/utils/hash.py index cfcc4ec4..b6363257 100644 --- a/src/A_memorix/core/utils/hash.py +++ b/src/A_memorix/core/utils/hash.py @@ -6,6 +6,7 @@ import hashlib import re +from typing import Union def compute_hash(text: str, hash_type: str = "sha256") -> str: diff --git a/src/A_memorix/core/utils/io.py b/src/A_memorix/core/utils/io.py index 4d2f84a1..ed14df43 100644 --- a/src/A_memorix/core/utils/io.py +++ b/src/A_memorix/core/utils/io.py @@ -5,6 +5,7 @@ IO Utilities """ import os +import shutil import contextlib from pathlib import Path from typing import Union diff --git a/src/A_memorix/core/utils/matcher.py b/src/A_memorix/core/utils/matcher.py index de84c83c..bddff5ee 100644 --- a/src/A_memorix/core/utils/matcher.py +++ b/src/A_memorix/core/utils/matcher.py @@ -4,7 +4,7 @@ 实现 Aho-Corasick 算法用于多模式匹配。 """ -from typing import List, Dict, Tuple, Set +from typing import List, Dict, Tuple, Set, Any from collections import deque diff --git a/src/A_memorix/core/utils/path_fallback_service.py b/src/A_memorix/core/utils/path_fallback_service.py index 9d7de787..c8ef0be8 100644 --- a/src/A_memorix/core/utils/path_fallback_service.py +++ b/src/A_memorix/core/utils/path_fallback_service.py @@ -3,7 +3,7 @@ from __future__ import annotations import hashlib -from typing import Any, Dict, List, Sequence, Tuple +from typing import Any, Dict, List, Optional, Sequence, Tuple from ..retrieval.dual_path import RetrievalResult diff --git a/src/A_memorix/core/utils/runtime_self_check.py b/src/A_memorix/core/utils/runtime_self_check.py index ae1a6635..6e9a41f2 100644 --- a/src/A_memorix/core/utils/runtime_self_check.py +++ b/src/A_memorix/core/utils/runtime_self_check.py @@ -234,7 +234,7 @@ async def ensure_runtime_self_check( sample_text=sample_text, ) try: - plugin_or_config._runtime_self_check_report = report + setattr(plugin_or_config, "_runtime_self_check_report", report) except Exception: pass return report diff --git a/src/A_memorix/core/utils/search_execution_service.py b/src/A_memorix/core/utils/search_execution_service.py index 3206d2f0..ace051e9 100644 --- a/src/A_memorix/core/utils/search_execution_service.py +++ b/src/A_memorix/core/utils/search_execution_service.py @@ -287,7 +287,7 @@ class SearchExecutionService: async def _executor() -> Dict[str, Any]: original_ppr = bool(getattr(retriever.config, "enable_ppr", True)) - retriever.config.enable_ppr = bool(request.enable_ppr) + setattr(retriever.config, "enable_ppr", bool(request.enable_ppr)) started_at = time.time() try: retrieved = await retriever.retrieve( @@ -380,7 +380,7 @@ class SearchExecutionService: elapsed_ms = (time.time() - started_at) * 1000.0 return {"results": retrieved, "elapsed_ms": elapsed_ms} finally: - retriever.config.enable_ppr = original_ppr + setattr(retriever.config, "enable_ppr", original_ppr) dedup_hit = False try: diff --git a/src/A_memorix/core/utils/summary_importer.py b/src/A_memorix/core/utils/summary_importer.py index 7de7839a..0728e4dc 100644 --- a/src/A_memorix/core/utils/summary_importer.py +++ b/src/A_memorix/core/utils/summary_importer.py @@ -5,6 +5,7 @@ 导入到 A_memorix 的存储组件中。 """ +from pathlib import Path from typing import Any, Dict, List, Optional, Tuple import json diff --git a/src/A_memorix/runtime_registry.py b/src/A_memorix/runtime_registry.py index cac8893e..d389cba9 100644 --- a/src/A_memorix/runtime_registry.py +++ b/src/A_memorix/runtime_registry.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Dict +from typing import Any, Dict, Optional _runtime_kernel: Any = None diff --git a/src/A_memorix/scripts/_bootstrap.py b/src/A_memorix/scripts/_bootstrap.py index 6bcb0f36..f8b72d72 100644 --- a/src/A_memorix/scripts/_bootstrap.py +++ b/src/A_memorix/scripts/_bootstrap.py @@ -15,7 +15,7 @@ for _path in (SRC_ROOT, PROJECT_ROOT, PLUGIN_ROOT): if _path_str not in sys.path: sys.path.insert(0, _path_str) -from A_memorix.paths import config_path, default_data_dir +from A_memorix.paths import config_path, default_data_dir, resolve_repo_path DEFAULT_CONFIG_PATH = config_path() DEFAULT_DATA_DIR = default_data_dir() diff --git a/src/A_memorix/scripts/convert_lpmm.py b/src/A_memorix/scripts/convert_lpmm.py index e7b028f1..8772fbff 100644 --- a/src/A_memorix/scripts/convert_lpmm.py +++ b/src/A_memorix/scripts/convert_lpmm.py @@ -10,12 +10,14 @@ LPMM 到 A_memorix 存储转换器 """ import sys +import os +import json import argparse import asyncio import pickle import logging from pathlib import Path -from typing import Dict, Any, Tuple +from typing import Dict, Any, List, Tuple import numpy as np import tomlkit diff --git a/src/A_memorix/scripts/process_knowledge.py b/src/A_memorix/scripts/process_knowledge.py index 7ea6f114..1dfec182 100644 --- a/src/A_memorix/scripts/process_knowledge.py +++ b/src/A_memorix/scripts/process_knowledge.py @@ -12,14 +12,17 @@ from datetime import datetime from pathlib import Path -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional from rich.console import Console +from rich.progress import BarColumn, Progress, SpinnerColumn, TextColumn, TimeElapsedColumn from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type import argparse import asyncio import hashlib import json +import os +import random import sys import time import tomlkit diff --git a/src/A_memorix/scripts/release_vnext_migrate.py b/src/A_memorix/scripts/release_vnext_migrate.py index cd16bf65..49c7517c 100644 --- a/src/A_memorix/scripts/release_vnext_migrate.py +++ b/src/A_memorix/scripts/release_vnext_migrate.py @@ -17,7 +17,7 @@ import sqlite3 import sys from dataclasses import dataclass from pathlib import Path -from typing import Any, Dict, List, Optional, Sequence, Tuple +from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple import tomlkit From e41922e24c0013786193cb30bae4a868ae1267a2 Mon Sep 17 00:00:00 2001 From: A-Dawn <67786671+A-Dawn@users.noreply.github.com> Date: Tue, 21 Apr 2026 14:08:14 +0800 Subject: [PATCH 2/4] =?UTF-8?q?feat(A=5Fmemorix):=20=E6=94=B6=E7=B4=A7?= =?UTF-8?q?=E7=A8=80=E7=96=8F=E5=B0=BE=E9=83=A8=E5=8F=AC=E5=9B=9E=E5=B9=B6?= =?UTF-8?q?=E6=94=B9=E8=BF=9B=20PPR=20=E5=A2=9E=E7=9B=8A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/A_memorix/core/retrieval/dual_path.py | 65 ++++++++++++++++++++- src/A_memorix/core/retrieval/sparse_bm25.py | 9 ++- 2 files changed, 70 insertions(+), 4 deletions(-) diff --git a/src/A_memorix/core/retrieval/dual_path.py b/src/A_memorix/core/retrieval/dual_path.py index 437f3dd7..c03be548 100644 --- a/src/A_memorix/core/retrieval/dual_path.py +++ b/src/A_memorix/core/retrieval/dual_path.py @@ -588,6 +588,7 @@ class DualPathRetriever: candidate_k = max(top_k, self.config.sparse.candidate_k) candidate_k = self._cap_temporal_scan_k(candidate_k, temporal) sparse_rows = self.sparse_index.search(query=query, k=candidate_k) + sparse_rows = self._filter_sparse_paragraph_rows(sparse_rows) results: List[RetrievalResult] = [] for row in sparse_rows: hash_value = row["hash"] @@ -614,6 +615,53 @@ class DualPathRetriever: self._normalize_scores_minmax(results) return results + def _filter_sparse_paragraph_rows( + self, + rows: List[Dict[str, Any]], + ) -> List[Dict[str, Any]]: + """ + 过滤 paragraph sparse tail。 + + 目标不是压缩强 lexical hit,而是避免只命中一个弱 token 的尾部结果 + 在 weighted RRF 中拿到过高的 rank credit。 + """ + if len(rows) <= 2: + return rows + + top_score = max(0.0, float(rows[0].get("score", 0.0) or 0.0)) + if top_score <= 0.0: + return rows[:2] + + relative_floor = top_score * 0.2 + filtered_rows: List[Dict[str, Any]] = [] + removed_count = 0 + for index, row in enumerate(rows): + if index < 2: + filtered_rows.append(row) + continue + + raw_score = float(row.get("score", 0.0) or 0.0) + matched_token_count = int(row.get("matched_token_count", 0) or 0) + matched_token_ratio = float(row.get("matched_token_ratio", 0.0) or 0.0) + + if ( + raw_score >= relative_floor + or matched_token_count >= 3 + or (matched_token_count >= 2 and matched_token_ratio >= 0.12) + ): + filtered_rows.append(row) + continue + + removed_count += 1 + + if removed_count > 0: + logger.debug( + "sparse_paragraph_tail_pruned=1 " + f"removed_count={removed_count} " + f"kept_count={len(filtered_rows)}" + ) + return filtered_rows + def _search_relations_sparse( self, query: str, @@ -1560,9 +1608,20 @@ class DualPathRetriever: entity_scores.append(ppr_scores_by_name[ent_name]) if entity_scores: - avg_ppr = np.mean(entity_scores) - # 融合原始分数和PPR分数 - result.score = result.score * 0.7 + avg_ppr * 0.3 + # 只使用命中的高价值图实体做正向增益,避免把原本高分的正确段落 + # 因为“实体多但非全部命中”而反向压低。 + focus_scores = sorted(entity_scores, reverse=True)[:2] + ppr_signal = float(np.mean(focus_scores)) + boost_weight = 0.12 if len(focus_scores) >= 2 else 0.06 + boost = ppr_signal * boost_weight + + metadata = result.metadata if isinstance(result.metadata, dict) else {} + metadata["ppr_signal"] = round(ppr_signal, 4) + metadata["ppr_focus_entity_count"] = len(focus_scores) + metadata["ppr_boost"] = round(boost, 4) + result.metadata = metadata + + result.score = float(result.score) + float(boost) # 重新排序 results.sort(key=lambda x: x.score, reverse=True) diff --git a/src/A_memorix/core/retrieval/sparse_bm25.py b/src/A_memorix/core/retrieval/sparse_bm25.py index 276e8778..7808a516 100644 --- a/src/A_memorix/core/retrieval/sparse_bm25.py +++ b/src/A_memorix/core/retrieval/sparse_bm25.py @@ -306,15 +306,22 @@ class SparseBM25Index: rows = self._fallback_substring_search(tokens=tokens, limit=limit) results: List[Dict[str, Any]] = [] + token_count = max(1, len(tokens)) for rank, row in enumerate(rows, start=1): bm25_score = float(row.get("bm25_score", 0.0)) + content = str(row.get("content", "") or "") + content_low = content.lower() + matched_tokens = [token for token in tokens if token in content_low] + matched_token_count = len(dict.fromkeys(matched_tokens)) results.append( { "hash": row["hash"], - "content": row["content"], + "content": content, "rank": rank, "bm25_score": bm25_score, "score": -bm25_score, # bm25 越小越相关,这里取反作为相对分数 + "matched_token_count": matched_token_count, + "matched_token_ratio": matched_token_count / float(token_count), } ) return results From 611a0a575dd1103e34ba30df38d48408455c1a6f Mon Sep 17 00:00:00 2001 From: A-Dawn <67786671+A-Dawn@users.noreply.github.com> Date: Tue, 21 Apr 2026 14:21:43 +0800 Subject: [PATCH 3/4] =?UTF-8?q?feat(A=5Fmemorix):=20=E9=9B=86=E6=88=90?= =?UTF-8?q?=E5=90=91=E9=87=8F=E5=8F=AC=E5=9B=9E=E4=B8=8E=E5=90=8E=E9=AA=8C?= =?UTF-8?q?=E5=9B=BE=E8=A1=A5=E4=BD=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .github/pull_request_template.md | 5 +++-- AGENTS.md | 10 +++++++++- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index 3ed6800e..22f722d1 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -10,8 +10,9 @@ 3. - [ ] 本次更新类型为:BUG修复 - [ ] 本次更新类型为:功能新增 4. - [ ] 本次更新是否经过测试 -5. 请填写破坏性更新的具体内容(如有): -6. 请简要说明本次更新的内容和目的: +5. - [ ] 如果本次修改涉及 `src/A_memorix`,我确认已阅读 `src/A_memorix/MODIFICATION_POLICY.md`,不涉及则无需勾选 +6. 请填写破坏性更新的具体内容(如有): +7. 请简要说明本次更新的内容和目的: # 其他信息 - **关联 Issue**:Close # - **截图/GIF**: diff --git a/AGENTS.md b/AGENTS.md index fead0c13..5cec7c47 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -44,5 +44,13 @@ # 关于webui修改 不要修改dashboard下的内容,因为这部分内容由另一个仓库build +# 关于 A_memorix 修改 +如果修改涉及 `src/A_memorix`,请先阅读 `src/A_memorix/MODIFICATION_POLICY.md`。 + +默认原则: +1. `src/A_memorix` 的实现层改动应优先遵守 `src/A_memorix/MODIFICATION_POLICY.md` 中的归属约束。 +2. 不要提交无边界的 `ruff`、格式化、导入整理或大面积实现整理。 +3. 本地实验目录或依赖其运行的测试,除非明确说明并确认,否则不要进入共享历史。 + # maibot插件开发文档 -https://github.com/Mai-with-u/maibot-plugin-sdk/blob/main/docs/guide.md \ No newline at end of file +https://github.com/Mai-with-u/maibot-plugin-sdk/blob/main/docs/guide.md From 0eba6186c157d14a94a49fd128fe4d83a89997af Mon Sep 17 00:00:00 2001 From: A-Dawn <67786671+A-Dawn@users.noreply.github.com> Date: Tue, 21 Apr 2026 23:20:17 +0800 Subject: [PATCH 4/4] =?UTF-8?q?feat(A=5Fmemorix):=20=E4=B8=BA=E5=8F=8C?= =?UTF-8?q?=E8=B7=AF=E6=A3=80=E7=B4=A2=E6=8E=A5=E5=85=A5=E5=85=B1=E4=BA=AB?= =?UTF-8?q?=E5=80=99=E9=80=89=E6=B1=A0=E4=B8=8E=E5=9B=BE=E5=90=8E=E9=AA=8C?= =?UTF-8?q?=E5=99=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 让段落与关系检索先复用共享向量候选池,再按类型回填,缓解单侧候选竞争导致的召回不足。 新增 posterior_graph 尾部补位竞争逻辑,并补齐运行时配置解析与 CONFIG_REFERENCE 说明 --- src/A_memorix/CONFIG_REFERENCE.md | 22 +- src/A_memorix/core/retrieval/__init__.py | 2 + src/A_memorix/core/retrieval/dual_path.py | 172 +++- .../core/retrieval/posterior_graph.py | 792 ++++++++++++++++++ .../runtime/search_runtime_initializer.py | 11 + 5 files changed, 964 insertions(+), 35 deletions(-) create mode 100644 src/A_memorix/core/retrieval/posterior_graph.py diff --git a/src/A_memorix/CONFIG_REFERENCE.md b/src/A_memorix/CONFIG_REFERENCE.md index 1a79f858..76163ac8 100644 --- a/src/A_memorix/CONFIG_REFERENCE.md +++ b/src/A_memorix/CONFIG_REFERENCE.md @@ -120,7 +120,7 @@ default_sample_size = 24 - 长期记忆控制台:适合修改高频项,例如 embedding、检索、Episode、人物画像、导入与调优的常用开关。 - 原始 TOML:适合复制整份配置、批量调整参数,或修改未在可视化表单中展示的高级项。 -- raw-only 高级项仍包括:`retrieval.fusion.*`、`retrieval.search.relation_intent.*`、`retrieval.search.graph_recall.*`、`retrieval.aggregate.*`、`memory.orphan.*`、`advanced.extraction_model`、`web.import.llm_retry.*`、`web.import.path_aliases`、`web.import.convert.*`、`web.tuning.llm_retry.*`、`web.tuning.eval_query_timeout_seconds`。 +- raw-only 高级项仍包括:`retrieval.fusion.*`、`retrieval.search.relation_intent.*`、`retrieval.search.graph_recall.*`、`retrieval.search.posterior_graph.*`、`retrieval.aggregate.*`、`memory.orphan.*`、`advanced.extraction_model`、`web.import.llm_retry.*`、`web.import.path_aliases`、`web.import.convert.*`、`web.tuning.llm_retry.*`、`web.tuning.eval_query_timeout_seconds`。 ## 1. 存储与嵌入 @@ -213,6 +213,26 @@ default_sample_size = 24 - `allow_two_hop_pair` (默认 `true`) - `max_paths` (默认 `4`) +### `retrieval.search.posterior_graph` (`PosteriorGraphConfig`) + +- `enabled` (默认 `true`) +- `drop_ratio` (默认 `0.15`) +- `min_core_results` (默认 `2`) +- `max_graph_slots` (默认 `2`) +- `gate_scan_top_k` (默认 `5`) +- `grounded_confidence_threshold` (默认 `0.48`) +- `incidental_confidence_threshold` (默认 `0.22`) +- `min_query_token_coverage` (默认 `0.78`) +- `incidental_query_relevance_threshold` (默认 `0.68`) +- `incidental_core_overlap_threshold` (默认 `0.34`) +- `incidental_specificity_threshold` (默认 `0.42`) + +说明: + +- 这组配置控制“后验图补位”,即先跑正常双路检索,再判断是否需要从图结构补一小批 relation 候选进入尾部竞争。 +- 设计目标以 `recall` 为主,而不是强行把 relation 顶到第一名。 +- 如果你的最终回答仍会经过 LLM 汇总,这组能力更适合用于“保证证据进入前排候选”,而不是做激进排序改写。 + ### `retrieval.aggregate` - `retrieval.aggregate.rrf_k` diff --git a/src/A_memorix/core/retrieval/__init__.py b/src/A_memorix/core/retrieval/__init__.py index 6efce7f6..2bd84a4d 100644 --- a/src/A_memorix/core/retrieval/__init__.py +++ b/src/A_memorix/core/retrieval/__init__.py @@ -9,6 +9,7 @@ from .dual_path import ( FusionConfig, RelationIntentConfig, ) +from .posterior_graph import PosteriorGraphConfig from .pagerank import ( PersonalizedPageRank, PageRankConfig, @@ -37,6 +38,7 @@ __all__ = [ "TemporalQueryOptions", "FusionConfig", "RelationIntentConfig", + "PosteriorGraphConfig", # PersonalizedPageRank "PersonalizedPageRank", "PageRankConfig", diff --git a/src/A_memorix/core/retrieval/dual_path.py b/src/A_memorix/core/retrieval/dual_path.py index c03be548..ff379978 100644 --- a/src/A_memorix/core/retrieval/dual_path.py +++ b/src/A_memorix/core/retrieval/dual_path.py @@ -19,6 +19,7 @@ from ..utils.matcher import AhoCorasick from ..utils.time_parser import format_timestamp from .graph_relation_recall import GraphRelationRecallConfig, GraphRelationRecallService from .pagerank import PersonalizedPageRank, PageRankConfig +from .posterior_graph import PosteriorGraphConfig, apply_posterior_graph_gate from .sparse_bm25 import SparseBM25Config, SparseBM25Index logger = get_logger("A_Memorix.DualPathRetriever") @@ -101,6 +102,7 @@ class DualPathRetrieverConfig: fusion: "FusionConfig" = field(default_factory=lambda: FusionConfig()) relation_intent: "RelationIntentConfig" = field(default_factory=lambda: RelationIntentConfig()) graph_recall: GraphRelationRecallConfig = field(default_factory=GraphRelationRecallConfig) + posterior_graph: PosteriorGraphConfig = field(default_factory=PosteriorGraphConfig) def __post_init__(self): """验证配置""" @@ -112,6 +114,8 @@ class DualPathRetrieverConfig: self.relation_intent = RelationIntentConfig(**self.relation_intent) if isinstance(self.graph_recall, dict): self.graph_recall = GraphRelationRecallConfig(**self.graph_recall) + if isinstance(self.posterior_graph, dict): + self.posterior_graph = PosteriorGraphConfig(**self.posterior_graph) if not 0 <= self.alpha <= 1: raise ValueError(f"alpha必须在[0, 1]之间: {self.alpha}") @@ -1073,6 +1077,14 @@ class DualPathRetriever: ) if temporal: fused_results = self._sort_results_with_temporal(fused_results, temporal) + fused_results = apply_posterior_graph_gate( + self, + query=query, + base_results=fused_results, + top_k=top_k, + temporal=temporal, + relation_intent=relation_intent, + ) fused_results = self._apply_relation_intent_pair_rerank( fused_results, enabled=bool(relation_intent.get("enabled", False)), @@ -1174,6 +1186,15 @@ class DualPathRetriever: if temporal: fused_results = self._sort_results_with_temporal(fused_results, temporal) + fused_results = apply_posterior_graph_gate( + self, + query=query, + base_results=fused_results, + top_k=top_k, + temporal=temporal, + relation_intent=relation_intent, + ) + fused_results = self._apply_relation_intent_pair_rerank( fused_results, enabled=bool(relation_intent.get("enabled", False)), @@ -1198,37 +1219,13 @@ class DualPathRetriever: Returns: (段落结果, 关系结果) """ - # 使用 asyncio.gather 并发执行两个搜索任务 - # 由于 _search_paragraphs 和 _search_relations 是 CPU 密集型同步函数, - # 使用 asyncio.to_thread 在线程池中执行 try: - para_task = asyncio.to_thread( - self._search_paragraphs, + return await asyncio.to_thread( + self._collect_mixed_candidates, query_emb, - self.config.top_k_paragraphs, temporal, + relation_top_k, ) - rel_task = asyncio.to_thread( - self._search_relations, - query_emb, - relation_top_k if relation_top_k is not None else self.config.top_k_relations, - temporal, - ) - - para_results, rel_results = await asyncio.gather( - para_task, rel_task, return_exceptions=True - ) - - # 处理异常 - if isinstance(para_results, Exception): - logger.error(f"段落检索失败: {para_results}") - para_results = [] - if isinstance(rel_results, Exception): - logger.error(f"关系检索失败: {rel_results}") - rel_results = [] - - return para_results, rel_results - except Exception as e: logger.error(f"并行检索失败: {e}") return [], [] @@ -1248,18 +1245,125 @@ class DualPathRetriever: Returns: (段落结果, 关系结果) """ - para_results = self._search_paragraphs( + return self._collect_mixed_candidates( query_emb, - self.config.top_k_paragraphs, temporal, + relation_top_k, ) - rel_results = self._search_relations( - query_emb, - relation_top_k if relation_top_k is not None else self.config.top_k_relations, - temporal, - ) + def _mixed_candidate_budget( + self, + para_top_k: int, + rel_top_k: int, + temporal: Optional[TemporalQueryOptions], + ) -> int: + multiplier = max(1, temporal.candidate_multiplier) if temporal else 1 + base = max(para_top_k + rel_top_k, max(para_top_k, rel_top_k) * 2) + return max(base * 6 * multiplier, 48) + def _merge_backfilled_results( + self, + *, + primary_results: List[RetrievalResult], + backfill_results: List[RetrievalResult], + top_k: int, + ) -> List[RetrievalResult]: + merged: Dict[str, RetrievalResult] = {} + for item in primary_results: + merged[item.hash_value] = item + for item in backfill_results: + existing = merged.get(item.hash_value) + if existing is None or float(item.score) > float(existing.score): + merged[item.hash_value] = item + + results = list(merged.values()) + results.sort(key=lambda item: item.score, reverse=True) + return results[:top_k] + + def _collect_mixed_candidates( + self, + query_emb: np.ndarray, + temporal: Optional[TemporalQueryOptions] = None, + relation_top_k: Optional[int] = None, + ) -> Tuple[List[RetrievalResult], List[RetrievalResult]]: + para_top_k = self.config.top_k_paragraphs + rel_top_k = relation_top_k if relation_top_k is not None else self.config.top_k_relations + candidate_k = self._mixed_candidate_budget(para_top_k, rel_top_k, temporal) + candidate_k = self._cap_temporal_scan_k(candidate_k, temporal) + ids, scores = self.vector_store.search(query_emb, k=candidate_k) + + para_candidates: List[RetrievalResult] = [] + rel_candidates: List[RetrievalResult] = [] + seen_para = set() + seen_rel = set() + + for hash_value, score in zip(ids, scores): + paragraph = self.metadata_store.get_paragraph(hash_value) + if paragraph is not None and hash_value not in seen_para: + seen_para.add(hash_value) + para_candidates.append( + RetrievalResult( + hash_value=hash_value, + content=paragraph["content"], + score=float(score), + result_type="paragraph", + source="paragraph_search", + metadata={ + "word_count": paragraph.get("word_count", 0), + "time_meta": self._build_time_meta_from_paragraph( + paragraph, + temporal=temporal, + ), + }, + ) + ) + continue + + relation = self.metadata_store.get_relation(hash_value, include_inactive=False) + if relation is None or hash_value in seen_rel: + continue + + relation_time_meta = None + if temporal: + relation_time_meta = self._best_supporting_time_meta(hash_value, temporal) + if relation_time_meta is None: + continue + + seen_rel.add(hash_value) + rel_candidates.append( + RetrievalResult( + hash_value=hash_value, + content=f"{relation['subject']} {relation['predicate']} {relation['object']}", + score=float(score), + result_type="relation", + source="relation_search", + metadata={ + "subject": relation["subject"], + "predicate": relation["predicate"], + "object": relation["object"], + "confidence": relation.get("confidence", 1.0), + "time_meta": relation_time_meta, + }, + ) + ) + + para_results = self._apply_temporal_filter_to_paragraphs(para_candidates, temporal) + rel_results = self._apply_temporal_filter_to_relations(rel_candidates, temporal) + + # 双重方案里,向量主干优先解决“召回不够”,因此主检索走共享候选池, + # 但再补一层按类型回填,避免 paragraph / relation 任一侧被饿死。 + para_backfill = self._search_paragraphs(query_emb, para_top_k, temporal) + rel_backfill = self._search_relations(query_emb, rel_top_k, temporal) + para_results = self._merge_backfilled_results( + primary_results=para_results, + backfill_results=para_backfill, + top_k=para_top_k, + ) + rel_results = self._merge_backfilled_results( + primary_results=rel_results, + backfill_results=rel_backfill, + top_k=rel_top_k, + ) return para_results, rel_results def _search_paragraphs( diff --git a/src/A_memorix/core/retrieval/posterior_graph.py b/src/A_memorix/core/retrieval/posterior_graph.py new file mode 100644 index 00000000..5ac663fe --- /dev/null +++ b/src/A_memorix/core/retrieval/posterior_graph.py @@ -0,0 +1,792 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Set, Tuple + +import re + +import jieba + +if TYPE_CHECKING: + from .dual_path import DualPathRetriever, RetrievalResult + + +_TOKEN_PATTERN = re.compile(r"[A-Za-z0-9_]+|[\u4e00-\u9fff]+") +_BROAD_PREDICATES = { + "contains_fact", + "describe", + "describes", + "description", + "mention", + "mentions", + "summary", + "summarizes", +} + + +@dataclass +class PosteriorGraphConfig: + """双重方案中的后验图补位配置。""" + + enabled: bool = True + drop_ratio: float = 0.15 + min_core_results: int = 2 + max_graph_slots: int = 2 + gate_scan_top_k: int = 5 + grounded_confidence_threshold: float = 0.48 + incidental_confidence_threshold: float = 0.22 + min_query_token_coverage: float = 0.78 + incidental_query_relevance_threshold: float = 0.68 + incidental_core_overlap_threshold: float = 0.34 + incidental_specificity_threshold: float = 0.42 + query_weight: float = 0.28 + novelty_weight: float = 0.18 + complementarity_weight: float = 0.16 + specificity_weight: float = 0.12 + gap_fill_weight: float = 0.26 + max_candidate_tokens: int = 48 + + def __post_init__(self) -> None: + self.enabled = bool(self.enabled) + self.drop_ratio = max(0.0, float(self.drop_ratio)) + self.min_core_results = max(1, int(self.min_core_results)) + self.max_graph_slots = max(0, int(self.max_graph_slots)) + self.gate_scan_top_k = max(1, int(self.gate_scan_top_k)) + self.grounded_confidence_threshold = _clip_score(self.grounded_confidence_threshold) + self.incidental_confidence_threshold = _clip_score(self.incidental_confidence_threshold) + self.min_query_token_coverage = _clip_score(self.min_query_token_coverage) + self.incidental_query_relevance_threshold = _clip_score( + self.incidental_query_relevance_threshold + ) + self.incidental_core_overlap_threshold = _clip_score( + self.incidental_core_overlap_threshold + ) + self.incidental_specificity_threshold = _clip_score( + self.incidental_specificity_threshold + ) + self.query_weight = max(0.0, float(self.query_weight)) + self.novelty_weight = max(0.0, float(self.novelty_weight)) + self.complementarity_weight = max(0.0, float(self.complementarity_weight)) + self.specificity_weight = max(0.0, float(self.specificity_weight)) + self.gap_fill_weight = max(0.0, float(self.gap_fill_weight)) + self.max_candidate_tokens = max(8, int(self.max_candidate_tokens)) + + +@dataclass +class _CompetitionProfile: + text: str + tokens: Set[str] + entities: Set[str] + + +@dataclass +class _SeedEvidence: + name: str + strength: str + support_count: int + rank_hint: int + + +def _safe_ratio(numerator: int, denominator: int) -> float: + if denominator <= 0: + return 0.0 + return float(numerator) / float(denominator) + + +def _clip_score(value: float) -> float: + return max(0.0, min(1.0, float(value))) + + +def _is_cjk_chunk(token: str) -> bool: + return bool(token) and all("\u4e00" <= char <= "\u9fff" for char in token) + + +def _tokenize_for_competition(text: str, *, max_tokens: int) -> List[str]: + normalized = str(text or "").lower().strip() + if not normalized: + return [] + + tokens: List[str] = [] + for chunk in _TOKEN_PATTERN.findall(normalized): + if _is_cjk_chunk(chunk): + tokens.extend( + item.strip().lower() + for item in jieba.lcut_for_search(chunk) + if item.strip() + ) + else: + tokens.append(chunk) + + filtered: List[str] = [] + for token in tokens: + if len(token) <= 1: + continue + filtered.append(token) + if len(filtered) >= max_tokens: + break + return filtered + + +def _result_text_for_entity_match(result: RetrievalResult) -> str: + metadata = result.metadata if isinstance(result.metadata, dict) else {} + parts = [ + str(result.content or ""), + str(metadata.get("subject", "") or ""), + str(metadata.get("object", "") or ""), + str(metadata.get("context_title", "") or ""), + str(metadata.get("benchmark_title", "") or ""), + ] + return "\n".join(part for part in parts if part) + + +def _candidate_text_for_competition(result: RetrievalResult) -> str: + metadata = result.metadata if isinstance(result.metadata, dict) else {} + parts = [ + _result_text_for_entity_match(result), + str(metadata.get("predicate", "") or ""), + ] + return "\n".join(part for part in parts if part) + + +def _extract_candidate_entities( + retriever: DualPathRetriever, + result: RetrievalResult, +) -> Set[str]: + metadata = result.metadata if isinstance(result.metadata, dict) else {} + entities: Set[str] = set() + + for name in retriever._extract_entities(_candidate_text_for_competition(result)).keys(): + normalized = str(name or "").strip().lower() + if normalized: + entities.add(normalized) + + for key in ("benchmark_title", "context_title", "object", "subject"): + normalized = str(metadata.get(key, "") or "").strip().lower() + if normalized: + entities.add(normalized) + + return entities + + +def _build_query_profile( + retriever: DualPathRetriever, + query: str, + *, + max_tokens: int, +) -> _CompetitionProfile: + text = str(query or "") + tokens = set(_tokenize_for_competition(text, max_tokens=max_tokens)) + entities = { + str(name or "").strip().lower() + for name in retriever._extract_entities(text).keys() + if str(name or "").strip() + } + return _CompetitionProfile(text=text, tokens=tokens, entities=entities) + + +def _build_candidate_profile( + retriever: DualPathRetriever, + result: RetrievalResult, + *, + max_tokens: int, +) -> _CompetitionProfile: + text = _candidate_text_for_competition(result) + return _CompetitionProfile( + text=text, + tokens=set(_tokenize_for_competition(text, max_tokens=max_tokens)), + entities=_extract_candidate_entities(retriever, result), + ) + + +def _build_core_profile( + retriever: DualPathRetriever, + results: Sequence[RetrievalResult], + *, + max_tokens: int, +) -> _CompetitionProfile: + parts: List[str] = [] + tokens: Set[str] = set() + entities: Set[str] = set() + + for result in results: + profile = _build_candidate_profile(retriever, result, max_tokens=max_tokens) + parts.append(profile.text) + tokens.update(profile.tokens) + entities.update(profile.entities) + + return _CompetitionProfile(text="\n".join(parts), tokens=tokens, entities=entities) + + +def _compute_query_relevance(candidate: _CompetitionProfile, query: _CompetitionProfile) -> float: + entity_hit = _safe_ratio(len(candidate.entities & query.entities), len(query.entities)) + token_hit = _safe_ratio(len(candidate.tokens & query.tokens), len(query.tokens)) + if query.entities: + return _clip_score(0.65 * entity_hit + 0.35 * token_hit) + return _clip_score(max(entity_hit, token_hit)) + + +def _compute_novelty(candidate: _CompetitionProfile, core: _CompetitionProfile) -> float: + entity_novelty = _safe_ratio(len(candidate.entities - core.entities), len(candidate.entities)) + token_novelty = _safe_ratio(len(candidate.tokens - core.tokens), len(candidate.tokens)) + return _clip_score(0.5 * entity_novelty + 0.5 * token_novelty) + + +def _compute_complementarity( + candidate: _CompetitionProfile, + core: _CompetitionProfile, + query_relevance: float, +) -> float: + if not core.tokens and not core.entities: + return _clip_score(query_relevance) + + entity_overlap = _safe_ratio(len(candidate.entities & core.entities), len(candidate.entities)) + token_overlap = _safe_ratio(len(candidate.tokens & core.tokens), len(candidate.tokens)) + core_overlap = 0.5 * entity_overlap + 0.5 * token_overlap + sweet_spot = 1.0 - abs(core_overlap - 0.4) / 0.4 + return _clip_score(max(0.0, sweet_spot) * max(query_relevance, 0.2)) + + +def _compute_specificity(candidate: _CompetitionProfile, result: RetrievalResult) -> float: + token_count = max(1, len(candidate.tokens)) + entity_density = _clip_score(_safe_ratio(len(candidate.entities), token_count) * 4.0) + brevity = 1.0 - min(1.0, max(0, token_count - 16) / 16.0) + predicate_bonus = 0.0 + + metadata = result.metadata if isinstance(result.metadata, dict) else {} + predicate = str(metadata.get("predicate", "") or "").strip().lower() + if predicate: + if predicate in _BROAD_PREDICATES: + predicate_bonus = -0.25 + elif result.result_type == "relation": + predicate_bonus = 0.10 + + return _clip_score(0.6 * entity_density + 0.4 * brevity + predicate_bonus) + + +def _compute_gap_fill( + candidate: _CompetitionProfile, + query: _CompetitionProfile, + core: _CompetitionProfile, +) -> float: + missing_entities = query.entities - core.entities + missing_tokens = query.tokens - core.tokens + + entity_fill = _safe_ratio(len(candidate.entities & missing_entities), len(missing_entities)) + token_fill = _safe_ratio(len(candidate.tokens & missing_tokens), len(missing_tokens)) + + if missing_entities: + return _clip_score(0.7 * entity_fill + 0.3 * token_fill) + return _clip_score(max(entity_fill, token_fill)) + + +def _core_overlap(candidate: _CompetitionProfile, core: _CompetitionProfile) -> float: + entity_overlap = _safe_ratio(len(candidate.entities & core.entities), len(candidate.entities)) + token_overlap = _safe_ratio(len(candidate.tokens & core.tokens), len(candidate.tokens)) + return _clip_score(0.5 * entity_overlap + 0.5 * token_overlap) + + +def _compute_competition_score( + retriever: DualPathRetriever, + candidate: RetrievalResult, + *, + query_profile: _CompetitionProfile, + core_profile: _CompetitionProfile, + cfg: PosteriorGraphConfig, +) -> Tuple[float, Dict[str, float]]: + candidate_profile = _build_candidate_profile( + retriever, + candidate, + max_tokens=cfg.max_candidate_tokens, + ) + query_relevance = _compute_query_relevance(candidate_profile, query_profile) + novelty = _compute_novelty(candidate_profile, core_profile) + complementarity = _compute_complementarity(candidate_profile, core_profile, query_relevance) + specificity = _compute_specificity(candidate_profile, candidate) + gap_fill = _compute_gap_fill(candidate_profile, query_profile, core_profile) + + final_score = ( + cfg.query_weight * query_relevance + + cfg.novelty_weight * novelty + + cfg.complementarity_weight * complementarity + + cfg.specificity_weight * specificity + + cfg.gap_fill_weight * gap_fill + ) + breakdown = { + "query_relevance": round(query_relevance, 4), + "novelty": round(novelty, 4), + "complementarity": round(complementarity, 4), + "specificity": round(specificity, 4), + "gap_fill": round(gap_fill, 4), + "competition_score": round(_clip_score(final_score), 4), + } + return _clip_score(final_score), breakdown + + +def _top_score(results: Sequence[RetrievalResult]) -> float: + if not results: + return 0.0 + return max(float(item.score) for item in results) + + +def find_score_cliff( + results: Sequence[RetrievalResult], + *, + drop_ratio: float, + min_core_results: int, +) -> int: + ranked = list(results) + if not ranked: + return 0 + if len(ranked) <= min_core_results: + return len(ranked) + + for index in range(1, len(ranked)): + prev_score = max(float(ranked[index - 1].score), 1e-8) + current_score = float(ranked[index].score) + score_drop = prev_score - current_score + if score_drop / prev_score > float(drop_ratio): + return max(min_core_results, index) + + fallback = max(min_core_results, len(ranked) // 2) + return min(len(ranked), fallback) + + +def _extract_seed_evidence( + retriever: DualPathRetriever, + query_profile: _CompetitionProfile, + results: Sequence[RetrievalResult], + *, + scan_top_k: int, + max_tokens: int, +) -> List[_SeedEvidence]: + score_map: Dict[Tuple[str, str], _SeedEvidence] = {} + top_results = list(results)[: max(1, int(scan_top_k))] + + for rank, item in enumerate(top_results, start=1): + profile = _build_candidate_profile(retriever, item, max_tokens=max_tokens) + for entity in profile.entities: + strength = "grounded" if entity in query_profile.entities else "incidental" + key = (entity, strength) + existing = score_map.get(key) + if existing is None: + score_map[key] = _SeedEvidence( + name=entity, + strength=strength, + support_count=1, + rank_hint=rank, + ) + else: + existing.support_count += 1 + existing.rank_hint = min(existing.rank_hint, rank) + + return sorted( + score_map.values(), + key=lambda item: ( + 0 if item.strength == "grounded" else 1, + -int(item.support_count), + int(item.rank_hint), + -len(item.name), + item.name, + ), + ) + + +def _grounded_seed_names(seed_evidence: Sequence[_SeedEvidence]) -> List[str]: + return [item.name for item in seed_evidence if item.strength == "grounded"] + + +def _incidental_seed_names(seed_evidence: Sequence[_SeedEvidence]) -> List[str]: + return [item.name for item in seed_evidence if item.strength == "incidental"] + + +def _need_for_graph( + *, + query_profile: _CompetitionProfile, + core_profile: _CompetitionProfile, + core_profiles: Sequence[_CompetitionProfile], + grounded_seeds: Sequence[str], + rag_confidence: float, + cfg: PosteriorGraphConfig, +) -> Tuple[bool, str]: + uncovered_query_entities = query_profile.entities - core_profile.entities + if uncovered_query_entities: + return True, "uncovered_query_entities" + + if len(grounded_seeds) >= 2: + bridge_targets = set(list(grounded_seeds)[:2]) + same_core_hit = any(len(profile.entities & bridge_targets) >= 2 for profile in core_profiles) + if not same_core_hit: + return True, "grounded_bridge_gap" + + token_coverage = _safe_ratio(len(core_profile.tokens & query_profile.tokens), len(query_profile.tokens)) + if grounded_seeds and token_coverage < float(cfg.min_query_token_coverage): + return True, "low_core_query_coverage" + + if grounded_seeds and float(rag_confidence) < float(cfg.grounded_confidence_threshold): + return True, "low_confidence_grounded" + + return False, "core_already_sufficient" + + +def _passes_incidental_high_bar( + retriever: DualPathRetriever, + candidate: RetrievalResult, + *, + query_profile: _CompetitionProfile, + core_profile: _CompetitionProfile, + cfg: PosteriorGraphConfig, +) -> bool: + candidate_profile = _build_candidate_profile( + retriever, + candidate, + max_tokens=cfg.max_candidate_tokens, + ) + uncovered_query_entities = query_profile.entities - core_profile.entities + if candidate_profile.entities & uncovered_query_entities: + return True + + query_relevance = _compute_query_relevance(candidate_profile, query_profile) + specificity = _compute_specificity(candidate_profile, candidate) + overlap = _core_overlap(candidate_profile, core_profile) + gap_fill = _compute_gap_fill(candidate_profile, query_profile, core_profile) + + return bool( + query_relevance >= float(cfg.incidental_query_relevance_threshold) + and specificity >= float(cfg.incidental_specificity_threshold) + and overlap <= float(cfg.incidental_core_overlap_threshold) + and gap_fill > 0.0 + ) + + +def _linked_core_paragraph_hashes( + retriever: DualPathRetriever, + relation_hash: str, +) -> Set[str]: + rows = retriever.metadata_store.query( + """ + SELECT paragraph_hash FROM paragraph_relations + WHERE relation_hash = ? + """, + (relation_hash,), + ) + return { + str(row.get("paragraph_hash", "") or "").strip() + for row in rows + if str(row.get("paragraph_hash", "") or "").strip() + } + + +def _build_graph_results_from_seeds( + retriever: DualPathRetriever, + *, + seed_entities: Sequence[str], + temporal: Any, + alpha: float, +) -> List[RetrievalResult]: + from .dual_path import RetrievalResult + + service = getattr(retriever, "_graph_relation_recall", None) + if service is None: + return [] + + payloads = service.recall(seed_entities=seed_entities) + if not payloads: + return [] + + graph_results: List[RetrievalResult] = [] + for payload in payloads: + meta = payload.to_payload() + graph_results.append( + RetrievalResult( + hash_value=str(meta["hash"]), + content=str(meta["content"]), + score=0.0, + result_type="relation", + source="posterior_graph_recall", + metadata={ + "subject": meta["subject"], + "predicate": meta["predicate"], + "object": meta["object"], + "confidence": float(meta["confidence"]), + "graph_seed_entities": list(meta["graph_seed_entities"]), + "graph_hops": int(meta["graph_hops"]), + "graph_candidate_type": str(meta["graph_candidate_type"]), + "supporting_paragraph_count": int(meta["supporting_paragraph_count"]), + }, + ) + ) + + graph_results = retriever._apply_temporal_filter_to_relations(graph_results, temporal) + graph_results = retriever._merge_relation_results_graph_enhanced([], [], graph_results) + relation_weight = max(0.0, 1.0 - float(alpha)) + for item in graph_results: + item.score = float(item.score) * relation_weight + item.source = "posterior_graph_competition" + return graph_results + + +def _competition_merge( + retriever: DualPathRetriever, + *, + query: str, + base_results: Sequence[RetrievalResult], + graph_results: Sequence[RetrievalResult], + top_k: int, + cfg: PosteriorGraphConfig, +) -> List[RetrievalResult]: + ranked = list(base_results)[: max(1, int(top_k))] + if not ranked or not graph_results: + return ranked + + cliff = find_score_cliff( + ranked, + drop_ratio=cfg.drop_ratio, + min_core_results=cfg.min_core_results, + ) + core_results = ranked[:cliff] + replaceable_slots = min( + max(0, int(top_k) - len(core_results)), + int(cfg.max_graph_slots), + ) + if replaceable_slots <= 0: + return ranked[:top_k] + + core_paragraph_hashes = { + item.hash_value for item in core_results if item.result_type == "paragraph" + } + selected_hashes = {item.hash_value for item in core_results} + filtered_graph_results: List[RetrievalResult] = [] + for item in graph_results: + if item.hash_value in selected_hashes: + continue + linked_hashes = _linked_core_paragraph_hashes(retriever, item.hash_value) + if core_paragraph_hashes & linked_hashes: + continue + filtered_graph_results.append(item) + + tail_candidates: List[RetrievalResult] = [] + for item in ranked[cliff:top_k]: + if item.hash_value not in selected_hashes: + tail_candidates.append(item) + tail_candidates.extend(filtered_graph_results) + + query_profile = _build_query_profile( + retriever, + query, + max_tokens=cfg.max_candidate_tokens, + ) + core_profile = _build_core_profile( + retriever, + core_results, + max_tokens=cfg.max_candidate_tokens, + ) + + scored_candidates: List[Tuple[RetrievalResult, float]] = [] + for item in tail_candidates: + competition_score, breakdown = _compute_competition_score( + retriever, + item, + query_profile=query_profile, + core_profile=core_profile, + cfg=cfg, + ) + metadata = dict(item.metadata) if isinstance(item.metadata, dict) else {} + metadata["posterior_original_score"] = round(float(item.score), 4) + metadata["posterior_competition_breakdown"] = breakdown + metadata["posterior_competition_source"] = "posterior_graph_gate" + item.metadata = metadata + scored_candidates.append((item, competition_score)) + + scored_candidates.sort( + key=lambda payload: ( + float(payload[1]), + 1 if payload[0].result_type == "relation" else 0, + ), + reverse=True, + ) + + tail_winners: List[RetrievalResult] = [] + seen_hashes = set(selected_hashes) + for item, _ in scored_candidates: + if item.hash_value in seen_hashes: + continue + tail_winners.append(item) + seen_hashes.add(item.hash_value) + if len(tail_winners) >= replaceable_slots: + break + + return (core_results + tail_winners)[:top_k] + + +def apply_posterior_graph_gate( + retriever: DualPathRetriever, + *, + query: str, + base_results: Sequence[RetrievalResult], + top_k: int, + temporal: Any, + relation_intent: Dict[str, Any], +) -> List[RetrievalResult]: + cfg = getattr(retriever.config, "posterior_graph", None) + if not isinstance(cfg, PosteriorGraphConfig) or not cfg.enabled: + return list(base_results)[:top_k] + if not base_results: + setattr( + retriever, + "_posterior_graph_gate_last_decision", + { + "scheme": "posterior_graph_gate", + "query": str(query or ""), + "enabled": False, + "bucket": "posterior_gate_empty", + }, + ) + return [] + + top_k_int = max(1, int(top_k)) + alpha_override = relation_intent.get("alpha_override") if isinstance(relation_intent, dict) else None + alpha = float(alpha_override) if alpha_override is not None else float(retriever.config.alpha) + rag_confidence = _top_score(list(base_results)[:top_k_int]) + + query_profile = _build_query_profile( + retriever, + query, + max_tokens=cfg.max_candidate_tokens, + ) + seed_evidence = _extract_seed_evidence( + retriever, + query_profile, + base_results, + scan_top_k=cfg.gate_scan_top_k, + max_tokens=cfg.max_candidate_tokens, + ) + grounded_seeds = _grounded_seed_names(seed_evidence)[:2] + incidental_seeds = _incidental_seed_names(seed_evidence)[:2] + + core_results = list(base_results)[ + : find_score_cliff( + list(base_results)[:top_k_int], + drop_ratio=cfg.drop_ratio, + min_core_results=cfg.min_core_results, + ) + ] + core_profile = _build_core_profile( + retriever, + core_results, + max_tokens=cfg.max_candidate_tokens, + ) + core_profiles = [ + _build_candidate_profile(retriever, item, max_tokens=cfg.max_candidate_tokens) + for item in core_results + ] + need_for_graph, need_reason = _need_for_graph( + query_profile=query_profile, + core_profile=core_profile, + core_profiles=core_profiles, + grounded_seeds=grounded_seeds, + rag_confidence=rag_confidence, + cfg=cfg, + ) + + seed_type = "none" + seed_names: List[str] = [] + if grounded_seeds and need_for_graph: + seed_type = "grounded" + seed_names = grounded_seeds + elif ( + not grounded_seeds + and incidental_seeds + and rag_confidence < float(cfg.incidental_confidence_threshold) + ): + seed_type = "incidental" + seed_names = incidental_seeds + + if not seed_names: + setattr( + retriever, + "_posterior_graph_gate_last_decision", + { + "scheme": "posterior_graph_gate", + "query": str(query or ""), + "enabled": False, + "bucket": "posterior_gate_none", + "grounded_seeds": list(grounded_seeds), + "incidental_seeds": list(incidental_seeds), + "selected_seed_type": seed_type, + "need_for_graph": bool(need_for_graph), + "need_reason": str(need_reason), + "rag_confidence": round(float(rag_confidence), 4), + }, + ) + return list(base_results)[:top_k_int] + + graph_results = _build_graph_results_from_seeds( + retriever, + seed_entities=seed_names, + temporal=temporal, + alpha=alpha, + ) + raw_graph_count = len(graph_results) + if seed_type == "incidental": + graph_results = [ + item + for item in graph_results + if _passes_incidental_high_bar( + retriever, + item, + query_profile=query_profile, + core_profile=core_profile, + cfg=cfg, + ) + ] + + if not graph_results: + setattr( + retriever, + "_posterior_graph_gate_last_decision", + { + "scheme": "posterior_graph_gate", + "query": str(query or ""), + "enabled": False, + "bucket": "posterior_gate_graph_filtered", + "grounded_seeds": list(grounded_seeds), + "incidental_seeds": list(incidental_seeds), + "selected_seed_type": seed_type, + "need_for_graph": bool(need_for_graph), + "need_reason": str(need_reason), + "rag_confidence": round(float(rag_confidence), 4), + "graph_result_count": int(raw_graph_count), + }, + ) + return list(base_results)[:top_k_int] + + final_results = _competition_merge( + retriever, + query=query, + base_results=base_results, + graph_results=graph_results, + top_k=top_k_int, + cfg=cfg, + ) + selected_hashes = {item.hash_value for item in final_results} + graph_selected = any(item.hash_value in selected_hashes for item in graph_results) + setattr( + retriever, + "_posterior_graph_gate_last_decision", + { + "scheme": "posterior_graph_gate", + "query": str(query or ""), + "enabled": bool(graph_selected), + "bucket": "posterior_gate_enabled" if graph_selected else "posterior_gate_tail_rejected", + "grounded_seeds": list(grounded_seeds), + "incidental_seeds": list(incidental_seeds), + "selected_seed_type": seed_type, + "need_for_graph": bool(need_for_graph), + "need_reason": str(need_reason), + "rag_confidence": round(float(rag_confidence), 4), + "graph_result_count": int(raw_graph_count), + "filtered_graph_count": max(0, raw_graph_count - len(graph_results)), + "base_top_k_count": min(len(base_results), top_k_int), + }, + ) + return final_results[:top_k_int] diff --git a/src/A_memorix/core/runtime/search_runtime_initializer.py b/src/A_memorix/core/runtime/search_runtime_initializer.py index 5afcd5a3..0c6146c6 100644 --- a/src/A_memorix/core/runtime/search_runtime_initializer.py +++ b/src/A_memorix/core/runtime/search_runtime_initializer.py @@ -13,6 +13,7 @@ from ..retrieval import ( DynamicThresholdFilter, FusionConfig, GraphRelationRecallConfig, + PosteriorGraphConfig, RelationIntentConfig, RetrievalStrategy, SparseBM25Config, @@ -143,6 +144,9 @@ def build_search_runtime( graph_recall_cfg_raw = _safe_dict( _get_config_value(plugin_config, "retrieval.search.graph_recall", {}) or {} ) + posterior_graph_cfg_raw = _safe_dict( + _get_config_value(plugin_config, "retrieval.search.posterior_graph", {}) or {} + ) try: sparse_cfg = SparseBM25Config(**sparse_cfg_raw) @@ -168,6 +172,12 @@ def build_search_runtime( log.warning(f"{prefix_text}[{owner}] graph_recall 配置非法,回退默认: {e}") graph_recall_cfg = GraphRelationRecallConfig() + try: + posterior_graph_cfg = PosteriorGraphConfig(**posterior_graph_cfg_raw) + except Exception as e: + log.warning(f"{prefix_text}[{owner}] posterior_graph 配置非法,回退默认: {e}") + posterior_graph_cfg = PosteriorGraphConfig() + try: config = DualPathRetrieverConfig( top_k_paragraphs=_get_config_value(plugin_config, "retrieval.top_k_paragraphs", 20), @@ -189,6 +199,7 @@ def build_search_runtime( fusion=fusion_cfg, relation_intent=relation_intent_cfg, graph_recall=graph_recall_cfg, + posterior_graph=posterior_graph_cfg, ) runtime.retriever = DualPathRetriever(