Merge pull request #1598 from A-Dawn/dev
feat&doc(A_memorix): introduce a new algorithm and clarify modification boundaries

.github/pull_request_template.md (vendored) — 5 changes
@@ -10,8 +10,9 @@
 3. - [ ] Type of this update: bug fix
    - [ ] Type of this update: new feature
 4. - [ ] This update has been tested
-5. Please describe the specifics of any breaking changes (if any):
-6. Briefly describe the content and purpose of this update:
+5. - [ ] If this change touches `src/A_memorix`, I confirm I have read `src/A_memorix/MODIFICATION_POLICY.md` (leave unchecked if not applicable)
+6. Please describe the specifics of any breaking changes (if any):
+7. Briefly describe the content and purpose of this update:
 # Other information
 - **Related issue**: Close #
 - **Screenshots/GIF**:

AGENTS.md — 10 changes
@@ -44,5 +44,13 @@
 # About webui modifications
 Do not modify anything under dashboard; that content is built from another repository.
 
+# About A_memorix modifications
+If a change touches `src/A_memorix`, read `src/A_memorix/MODIFICATION_POLICY.md` first.
+
+Default principles:
+1. Implementation-level changes in `src/A_memorix` should first respect the ownership constraints in `src/A_memorix/MODIFICATION_POLICY.md`.
+2. Do not submit unbounded `ruff`, formatting, import-sorting, or large-scale implementation cleanups.
+3. Local experiment directories, and tests that depend on them, must not enter shared history unless explicitly stated and confirmed.
+
 # maibot plugin development docs
 https://github.com/Mai-with-u/maibot-plugin-sdk/blob/main/docs/guide.md
@@ -120,7 +120,7 @@ default_sample_size = 24
 
 - Long-term memory console: best for high-frequency settings, e.g. the common switches for embedding, retrieval, Episode, persona profiles, import, and tuning.
 - Raw TOML: best for copying a whole configuration, batch-adjusting parameters, or editing advanced items not exposed in the visual form.
-- raw-only advanced items still include: `retrieval.fusion.*`, `retrieval.search.relation_intent.*`, `retrieval.search.graph_recall.*`, `retrieval.aggregate.*`, `memory.orphan.*`, `advanced.extraction_model`, `web.import.llm_retry.*`, `web.import.path_aliases`, `web.import.convert.*`, `web.tuning.llm_retry.*`, `web.tuning.eval_query_timeout_seconds`.
+- raw-only advanced items still include: `retrieval.fusion.*`, `retrieval.search.relation_intent.*`, `retrieval.search.graph_recall.*`, `retrieval.search.posterior_graph.*`, `retrieval.aggregate.*`, `memory.orphan.*`, `advanced.extraction_model`, `web.import.llm_retry.*`, `web.import.path_aliases`, `web.import.convert.*`, `web.tuning.llm_retry.*`, `web.tuning.eval_query_timeout_seconds`.
 
 ## 1. Storage and embedding
 
@@ -213,6 +213,26 @@ default_sample_size = 24
 - `allow_two_hop_pair` (default `true`)
 - `max_paths` (default `4`)
 
+### `retrieval.search.posterior_graph` (`PosteriorGraphConfig`)
+
+- `enabled` (default `true`)
+- `drop_ratio` (default `0.15`)
+- `min_core_results` (default `2`)
+- `max_graph_slots` (default `2`)
+- `gate_scan_top_k` (default `5`)
+- `grounded_confidence_threshold` (default `0.48`)
+- `incidental_confidence_threshold` (default `0.22`)
+- `min_query_token_coverage` (default `0.78`)
+- `incidental_query_relevance_threshold` (default `0.68`)
+- `incidental_core_overlap_threshold` (default `0.34`)
+- `incidental_specificity_threshold` (default `0.42`)
+
+Notes:
+
+- This group controls "posterior graph backfill": run the normal dual-path retrieval first, then decide whether to pull a small batch of relation candidates from the graph structure into the tail competition.
+- The design goal is primarily recall, not forcing relation results to the top rank.
+- If the final answer still passes through LLM summarization, this capability is better used to ensure evidence reaches the top candidates rather than for aggressive rank rewriting.
 ### `retrieval.aggregate`
 
 - `retrieval.aggregate.rrf_k`
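Pieced together from the defaults documented above, a raw-TOML sketch of this section might look as follows. The exact table path is an assumption inferred from the `retrieval.search.posterior_graph.*` key prefix; the values are the documented defaults.

```toml
# Assumed table path, based on the raw-only key list above.
[retrieval.search.posterior_graph]
enabled = true
drop_ratio = 0.15
min_core_results = 2
max_graph_slots = 2
gate_scan_top_k = 5
grounded_confidence_threshold = 0.48
incidental_confidence_threshold = 0.22
min_query_token_coverage = 0.78
incidental_query_relevance_threshold = 0.68
incidental_core_overlap_threshold = 0.34
incidental_specificity_threshold = 0.42
```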
@@ -23,7 +23,10 @@ from src.common.logger import get_logger
 from .presets import (
     EmbeddingModelConfig,
     get_custom_config,
+    validate_config_compatibility,
+    are_models_compatible,
 )
+from ..utils.quantization import QuantizationType
 
 logger = get_logger("A_Memorix.EmbeddingManager")
 
@@ -3,7 +3,7 @@
 """
 
 from dataclasses import dataclass
-from typing import Optional, Union
+from typing import Optional, Dict, Any, Union
 from pathlib import Path
 
 
@@ -9,6 +9,7 @@ from .dual_path import (
     FusionConfig,
     RelationIntentConfig,
 )
+from .posterior_graph import PosteriorGraphConfig
 from .pagerank import (
     PersonalizedPageRank,
     PageRankConfig,
@@ -37,6 +38,7 @@ __all__ = [
     "TemporalQueryOptions",
     "FusionConfig",
     "RelationIntentConfig",
+    "PosteriorGraphConfig",
     # PersonalizedPageRank
     "PersonalizedPageRank",
     "PageRankConfig",
@@ -7,7 +7,7 @@
 import asyncio
 import re
 from dataclasses import dataclass, field
-from typing import Optional, List, Dict, Any, Tuple
+from typing import Optional, List, Dict, Any, Tuple, Union
 from enum import Enum
 
 import numpy as np
@@ -19,6 +19,7 @@ from ..utils.matcher import AhoCorasick
 from ..utils.time_parser import format_timestamp
 from .graph_relation_recall import GraphRelationRecallConfig, GraphRelationRecallService
 from .pagerank import PersonalizedPageRank, PageRankConfig
+from .posterior_graph import PosteriorGraphConfig, apply_posterior_graph_gate
 from .sparse_bm25 import SparseBM25Config, SparseBM25Index
 
 logger = get_logger("A_Memorix.DualPathRetriever")
@@ -101,6 +102,7 @@ class DualPathRetrieverConfig:
     fusion: "FusionConfig" = field(default_factory=lambda: FusionConfig())
     relation_intent: "RelationIntentConfig" = field(default_factory=lambda: RelationIntentConfig())
     graph_recall: GraphRelationRecallConfig = field(default_factory=GraphRelationRecallConfig)
+    posterior_graph: PosteriorGraphConfig = field(default_factory=PosteriorGraphConfig)
 
     def __post_init__(self):
         """Validate the configuration."""
@@ -112,6 +114,8 @@ class DualPathRetrieverConfig:
             self.relation_intent = RelationIntentConfig(**self.relation_intent)
         if isinstance(self.graph_recall, dict):
             self.graph_recall = GraphRelationRecallConfig(**self.graph_recall)
+        if isinstance(self.posterior_graph, dict):
+            self.posterior_graph = PosteriorGraphConfig(**self.posterior_graph)
 
         if not 0 <= self.alpha <= 1:
             raise ValueError(f"alpha must be in [0, 1]: {self.alpha}")
@@ -320,7 +324,7 @@ class DualPathRetriever:
 
         # Debug mode: print the raw retrieval results
         if self.config.debug:
-            logger.info("[DEBUG] Raw retrieval result contents:")
+            logger.info(f"[DEBUG] Raw retrieval result contents:")
             for i, res in enumerate(results):
                 logger.info(f"  {i+1}. [{res.result_type}] (Score: {res.score:.4f}) {res.content}")
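The `__post_init__` pattern shown here, where nested configs parsed from TOML arrive as plain dicts and are coerced into dataclasses, can be sketched in isolation. Class names are trimmed stand-ins for illustration, not the project's real classes:

```python
from dataclasses import dataclass, field

@dataclass
class PosteriorGraphCfg:
    # hypothetical stand-in for PosteriorGraphConfig, trimmed to two fields
    enabled: bool = True
    max_graph_slots: int = 2

@dataclass
class RetrieverCfg:
    # mirrors DualPathRetrieverConfig: nested sections loaded from TOML
    # arrive as dicts and are coerced into dataclasses on construction
    posterior_graph: PosteriorGraphCfg = field(default_factory=PosteriorGraphCfg)

    def __post_init__(self):
        if isinstance(self.posterior_graph, dict):
            self.posterior_graph = PosteriorGraphCfg(**self.posterior_graph)

cfg = RetrieverCfg(posterior_graph={"enabled": False, "max_graph_slots": 1})
```

With this coercion in place, callers can pass either a ready-made dataclass or the raw dict from the config file.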
@@ -588,6 +592,7 @@ class DualPathRetriever:
         candidate_k = max(top_k, self.config.sparse.candidate_k)
         candidate_k = self._cap_temporal_scan_k(candidate_k, temporal)
         sparse_rows = self.sparse_index.search(query=query, k=candidate_k)
+        sparse_rows = self._filter_sparse_paragraph_rows(sparse_rows)
         results: List[RetrievalResult] = []
         for row in sparse_rows:
             hash_value = row["hash"]
@@ -614,6 +619,53 @@ class DualPathRetriever:
         self._normalize_scores_minmax(results)
         return results
 
+    def _filter_sparse_paragraph_rows(
+        self,
+        rows: List[Dict[str, Any]],
+    ) -> List[Dict[str, Any]]:
+        """
+        Filter the paragraph sparse tail.
+
+        The goal is not to suppress strong lexical hits, but to prevent tail results
+        that match only a single weak token from receiving too much rank credit
+        in weighted RRF.
+        """
+        if len(rows) <= 2:
+            return rows
+
+        top_score = max(0.0, float(rows[0].get("score", 0.0) or 0.0))
+        if top_score <= 0.0:
+            return rows[:2]
+
+        relative_floor = top_score * 0.2
+        filtered_rows: List[Dict[str, Any]] = []
+        removed_count = 0
+        for index, row in enumerate(rows):
+            if index < 2:
+                filtered_rows.append(row)
+                continue
+
+            raw_score = float(row.get("score", 0.0) or 0.0)
+            matched_token_count = int(row.get("matched_token_count", 0) or 0)
+            matched_token_ratio = float(row.get("matched_token_ratio", 0.0) or 0.0)
+
+            if (
+                raw_score >= relative_floor
+                or matched_token_count >= 3
+                or (matched_token_count >= 2 and matched_token_ratio >= 0.12)
+            ):
+                filtered_rows.append(row)
+                continue
+
+            removed_count += 1
+
+        if removed_count > 0:
+            logger.debug(
+                "sparse_paragraph_tail_pruned=1 "
+                f"removed_count={removed_count} "
+                f"kept_count={len(filtered_rows)}"
+            )
+        return filtered_rows
+
     def _search_relations_sparse(
         self,
         query: str,
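The pruning rule added in this hunk can be exercised as a standalone function with the same thresholds, the `RetrievalResult` machinery omitted:

```python
from typing import Any, Dict, List

def prune_sparse_tail(rows: List[Dict[str, Any]], floor_ratio: float = 0.2) -> List[Dict[str, Any]]:
    # always keep the top-2 rows, then drop tail rows that clear neither the
    # relative score floor nor the matched-token thresholds
    if len(rows) <= 2:
        return rows
    top_score = max(0.0, float(rows[0].get("score", 0.0) or 0.0))
    if top_score <= 0.0:
        return rows[:2]
    floor = top_score * floor_ratio
    kept = list(rows[:2])
    for row in rows[2:]:
        score = float(row.get("score", 0.0) or 0.0)
        count = int(row.get("matched_token_count", 0) or 0)
        ratio = float(row.get("matched_token_ratio", 0.0) or 0.0)
        if score >= floor or count >= 3 or (count >= 2 and ratio >= 0.12):
            kept.append(row)
    return kept

rows = [
    {"score": 10.0, "matched_token_count": 4, "matched_token_ratio": 0.5},
    {"score": 6.0, "matched_token_count": 3, "matched_token_ratio": 0.3},
    {"score": 0.5, "matched_token_count": 1, "matched_token_ratio": 0.05},  # pruned: weak single-token hit
    {"score": 3.0, "matched_token_count": 1, "matched_token_ratio": 0.05},  # kept: above the floor of 2.0
]
pruned = prune_sparse_tail(rows)
```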
@@ -1025,6 +1077,14 @@ class DualPathRetriever:
         )
         if temporal:
             fused_results = self._sort_results_with_temporal(fused_results, temporal)
+        fused_results = apply_posterior_graph_gate(
+            self,
+            query=query,
+            base_results=fused_results,
+            top_k=top_k,
+            temporal=temporal,
+            relation_intent=relation_intent,
+        )
         fused_results = self._apply_relation_intent_pair_rerank(
             fused_results,
             enabled=bool(relation_intent.get("enabled", False)),
@@ -1126,6 +1186,15 @@ class DualPathRetriever:
         if temporal:
             fused_results = self._sort_results_with_temporal(fused_results, temporal)
 
+        fused_results = apply_posterior_graph_gate(
+            self,
+            query=query,
+            base_results=fused_results,
+            top_k=top_k,
+            temporal=temporal,
+            relation_intent=relation_intent,
+        )
+
         fused_results = self._apply_relation_intent_pair_rerank(
             fused_results,
             enabled=bool(relation_intent.get("enabled", False)),
@@ -1150,37 +1219,13 @@ class DualPathRetriever:
         Returns:
             (paragraph results, relation results)
         """
-        # Run the two search tasks concurrently with asyncio.gather.
-        # _search_paragraphs and _search_relations are CPU-bound sync functions,
-        # so execute them in the thread pool via asyncio.to_thread.
         try:
-            para_task = asyncio.to_thread(
-                self._search_paragraphs,
+            return await asyncio.to_thread(
+                self._collect_mixed_candidates,
                 query_emb,
-                self.config.top_k_paragraphs,
                 temporal,
+                relation_top_k,
             )
-            rel_task = asyncio.to_thread(
-                self._search_relations,
-                query_emb,
-                relation_top_k if relation_top_k is not None else self.config.top_k_relations,
-                temporal,
-            )
-
-            para_results, rel_results = await asyncio.gather(
-                para_task, rel_task, return_exceptions=True
-            )
-
-            # Handle exceptions
-            if isinstance(para_results, Exception):
-                logger.error(f"Paragraph retrieval failed: {para_results}")
-                para_results = []
-            if isinstance(rel_results, Exception):
-                logger.error(f"Relation retrieval failed: {rel_results}")
-                rel_results = []
-
-            return para_results, rel_results
-
         except Exception as e:
             logger.error(f"Parallel retrieval failed: {e}")
             return [], []
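This hunk collapses the gather-based pair of thread tasks into a single `asyncio.to_thread` call over one combined collection function. A minimal standalone sketch of that pattern, with hypothetical names standing in for the retriever's methods:

```python
import asyncio

def cpu_bound_search(query: str) -> list:
    # hypothetical stand-in for the synchronous _collect_mixed_candidates call
    return [f"{query}:hit{i}" for i in range(3)]

async def parallel_search(query: str) -> list:
    # one asyncio.to_thread call offloads the whole CPU-bound collection to the
    # default thread pool, instead of gathering two separate thread tasks
    try:
        return await asyncio.to_thread(cpu_bound_search, query)
    except Exception:
        return []

results = asyncio.run(parallel_search("q"))
```

Because the combined function produces both result lists in one pass, the per-task exception bookkeeping from the old version is no longer needed.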
@@ -1200,18 +1245,125 @@ class DualPathRetriever:
         Returns:
             (paragraph results, relation results)
         """
-        para_results = self._search_paragraphs(
+        return self._collect_mixed_candidates(
             query_emb,
-            self.config.top_k_paragraphs,
             temporal,
+            relation_top_k,
         )
 
-        rel_results = self._search_relations(
-            query_emb,
-            relation_top_k if relation_top_k is not None else self.config.top_k_relations,
-            temporal,
-        )
+    def _mixed_candidate_budget(
+        self,
+        para_top_k: int,
+        rel_top_k: int,
+        temporal: Optional[TemporalQueryOptions],
+    ) -> int:
+        multiplier = max(1, temporal.candidate_multiplier) if temporal else 1
+        base = max(para_top_k + rel_top_k, max(para_top_k, rel_top_k) * 2)
+        return max(base * 6 * multiplier, 48)
+
+    def _merge_backfilled_results(
+        self,
+        *,
+        primary_results: List[RetrievalResult],
+        backfill_results: List[RetrievalResult],
+        top_k: int,
+    ) -> List[RetrievalResult]:
+        merged: Dict[str, RetrievalResult] = {}
+        for item in primary_results:
+            merged[item.hash_value] = item
+        for item in backfill_results:
+            existing = merged.get(item.hash_value)
+            if existing is None or float(item.score) > float(existing.score):
+                merged[item.hash_value] = item
+
+        results = list(merged.values())
+        results.sort(key=lambda item: item.score, reverse=True)
+        return results[:top_k]
+
+    def _collect_mixed_candidates(
+        self,
+        query_emb: np.ndarray,
+        temporal: Optional[TemporalQueryOptions] = None,
+        relation_top_k: Optional[int] = None,
+    ) -> Tuple[List[RetrievalResult], List[RetrievalResult]]:
+        para_top_k = self.config.top_k_paragraphs
+        rel_top_k = relation_top_k if relation_top_k is not None else self.config.top_k_relations
+        candidate_k = self._mixed_candidate_budget(para_top_k, rel_top_k, temporal)
+        candidate_k = self._cap_temporal_scan_k(candidate_k, temporal)
+        ids, scores = self.vector_store.search(query_emb, k=candidate_k)
+
+        para_candidates: List[RetrievalResult] = []
+        rel_candidates: List[RetrievalResult] = []
+        seen_para = set()
+        seen_rel = set()
+
+        for hash_value, score in zip(ids, scores):
+            paragraph = self.metadata_store.get_paragraph(hash_value)
+            if paragraph is not None and hash_value not in seen_para:
+                seen_para.add(hash_value)
+                para_candidates.append(
+                    RetrievalResult(
+                        hash_value=hash_value,
+                        content=paragraph["content"],
+                        score=float(score),
+                        result_type="paragraph",
+                        source="paragraph_search",
+                        metadata={
+                            "word_count": paragraph.get("word_count", 0),
+                            "time_meta": self._build_time_meta_from_paragraph(
+                                paragraph,
+                                temporal=temporal,
+                            ),
+                        },
+                    )
+                )
+                continue
+
+            relation = self.metadata_store.get_relation(hash_value, include_inactive=False)
+            if relation is None or hash_value in seen_rel:
+                continue
+
+            relation_time_meta = None
+            if temporal:
+                relation_time_meta = self._best_supporting_time_meta(hash_value, temporal)
+                if relation_time_meta is None:
+                    continue
+
+            seen_rel.add(hash_value)
+            rel_candidates.append(
+                RetrievalResult(
+                    hash_value=hash_value,
+                    content=f"{relation['subject']} {relation['predicate']} {relation['object']}",
+                    score=float(score),
+                    result_type="relation",
+                    source="relation_search",
+                    metadata={
+                        "subject": relation["subject"],
+                        "predicate": relation["predicate"],
+                        "object": relation["object"],
+                        "confidence": relation.get("confidence", 1.0),
+                        "time_meta": relation_time_meta,
+                    },
+                )
+            )
+
+        para_results = self._apply_temporal_filter_to_paragraphs(para_candidates, temporal)
+        rel_results = self._apply_temporal_filter_to_relations(rel_candidates, temporal)
+
+        # In the dual scheme the vector backbone mainly fixes "insufficient recall",
+        # so the main retrieval uses a shared candidate pool, plus an extra per-type
+        # backfill pass so neither the paragraph nor the relation side is starved.
+        para_backfill = self._search_paragraphs(query_emb, para_top_k, temporal)
+        rel_backfill = self._search_relations(query_emb, rel_top_k, temporal)
+        para_results = self._merge_backfilled_results(
+            primary_results=para_results,
+            backfill_results=para_backfill,
+            top_k=para_top_k,
+        )
+        rel_results = self._merge_backfilled_results(
+            primary_results=rel_results,
+            backfill_results=rel_backfill,
+            top_k=rel_top_k,
+        )
         return para_results, rel_results
 
     def _search_paragraphs(
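The `_merge_backfilled_results` dedup rule in this hunk (dedupe by hash, keep the higher score, truncate to the per-type budget) can be exercised standalone. `Hit` is a hypothetical stand-in for `RetrievalResult` with just the fields the merge uses:

```python
from dataclasses import dataclass
from typing import Dict, List

@dataclass
class Hit:
    # hypothetical stand-in for RetrievalResult
    hash_value: str
    score: float

def merge_backfilled(primary: List[Hit], backfill: List[Hit], top_k: int) -> List[Hit]:
    # dedupe by hash, keeping whichever copy carries the higher score,
    # then sort descending and truncate to the per-type budget
    merged: Dict[str, Hit] = {h.hash_value: h for h in primary}
    for h in backfill:
        existing = merged.get(h.hash_value)
        if existing is None or h.score > existing.score:
            merged[h.hash_value] = h
    out = sorted(merged.values(), key=lambda h: h.score, reverse=True)
    return out[:top_k]

merged = merge_backfilled(
    [Hit("a", 0.9), Hit("b", 0.4)],
    [Hit("b", 0.7), Hit("c", 0.6)],
    top_k=2,
)
```

Here the backfilled copy of `b` wins over the shared-pool copy because its score is higher, and `c` is cut by the `top_k=2` budget.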
@@ -1560,9 +1712,20 @@ class DualPathRetriever:
                 entity_scores.append(ppr_scores_by_name[ent_name])
 
             if entity_scores:
-                avg_ppr = np.mean(entity_scores)
-                # Blend the original score with the PPR score
-                result.score = result.score * 0.7 + avg_ppr * 0.3
+                # Use only the matched high-value graph entities as a positive boost,
+                # so paragraphs that already score high are not pushed down just
+                # because they contain many entities of which not all matched.
+                focus_scores = sorted(entity_scores, reverse=True)[:2]
+                ppr_signal = float(np.mean(focus_scores))
+                boost_weight = 0.12 if len(focus_scores) >= 2 else 0.06
+                boost = ppr_signal * boost_weight
+
+                metadata = result.metadata if isinstance(result.metadata, dict) else {}
+                metadata["ppr_signal"] = round(ppr_signal, 4)
+                metadata["ppr_focus_entity_count"] = len(focus_scores)
+                metadata["ppr_boost"] = round(boost, 4)
+                result.metadata = metadata
+
+                result.score = float(result.score) + float(boost)
 
         # Re-sort
         results.sort(key=lambda x: x.score, reverse=True)
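The new boost rule in the hunk above (mean of the top-2 matched entity PPR scores times a small weight, added to the base score instead of re-mixing it) reduces to a pure function. A minimal sketch, not the project's API, using plain Python in place of `np.mean`:

```python
from typing import List

def ppr_boost(entity_scores: List[float]) -> float:
    # average the top-2 matched entity PPR scores and apply a small positive
    # weight; an empty match list contributes no boost at all
    if not entity_scores:
        return 0.0
    focus = sorted(entity_scores, reverse=True)[:2]
    signal = sum(focus) / len(focus)
    weight = 0.12 if len(focus) >= 2 else 0.06
    return signal * weight

boosted = 0.9 + ppr_boost([0.5, 0.3, 0.1])  # 0.9 + mean(0.5, 0.3) * 0.12
```

Unlike the old `0.7 * score + 0.3 * avg_ppr` blend, this is strictly additive, so a high base score can never be dragged down by a low PPR signal.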
@@ -4,8 +4,9 @@ Personalized PageRank implementation
 Provides personalized ranking for graph nodes.
 """
 
-from typing import Dict, List, Optional, Tuple, Any
+from typing import Dict, List, Optional, Tuple, Union, Any
 from dataclasses import dataclass
+import numpy as np
 
 from src.common.logger import get_logger
 from ..storage import GraphStore
@@ -48,7 +49,7 @@ class PageRankConfig:
             raise ValueError(f"min_iterations must be >= 0: {self.min_iterations}")
 
         if self.min_iterations >= self.max_iter:
-            raise ValueError("min_iterations must be less than max_iter")
+            raise ValueError(f"min_iterations must be less than max_iter")
 
 
 class PersonalizedPageRank:
src/A_memorix/core/retrieval/posterior_graph.py (new file) — 792 lines
@@ -0,0 +1,792 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Set, Tuple
+
+import re
+
+import jieba
+
+if TYPE_CHECKING:
+    from .dual_path import DualPathRetriever, RetrievalResult
+
+
+_TOKEN_PATTERN = re.compile(r"[A-Za-z0-9_]+|[\u4e00-\u9fff]+")
+_BROAD_PREDICATES = {
+    "contains_fact",
+    "describe",
+    "describes",
+    "description",
+    "mention",
+    "mentions",
+    "summary",
+    "summarizes",
+}
+
+
+@dataclass
+class PosteriorGraphConfig:
+    """Posterior graph backfill configuration for the dual scheme."""
+
+    enabled: bool = True
+    drop_ratio: float = 0.15
+    min_core_results: int = 2
+    max_graph_slots: int = 2
+    gate_scan_top_k: int = 5
+    grounded_confidence_threshold: float = 0.48
+    incidental_confidence_threshold: float = 0.22
+    min_query_token_coverage: float = 0.78
+    incidental_query_relevance_threshold: float = 0.68
+    incidental_core_overlap_threshold: float = 0.34
+    incidental_specificity_threshold: float = 0.42
+    query_weight: float = 0.28
+    novelty_weight: float = 0.18
+    complementarity_weight: float = 0.16
+    specificity_weight: float = 0.12
+    gap_fill_weight: float = 0.26
+    max_candidate_tokens: int = 48
+
+    def __post_init__(self) -> None:
+        self.enabled = bool(self.enabled)
+        self.drop_ratio = max(0.0, float(self.drop_ratio))
+        self.min_core_results = max(1, int(self.min_core_results))
+        self.max_graph_slots = max(0, int(self.max_graph_slots))
+        self.gate_scan_top_k = max(1, int(self.gate_scan_top_k))
+        self.grounded_confidence_threshold = _clip_score(self.grounded_confidence_threshold)
+        self.incidental_confidence_threshold = _clip_score(self.incidental_confidence_threshold)
+        self.min_query_token_coverage = _clip_score(self.min_query_token_coverage)
+        self.incidental_query_relevance_threshold = _clip_score(
+            self.incidental_query_relevance_threshold
+        )
+        self.incidental_core_overlap_threshold = _clip_score(
+            self.incidental_core_overlap_threshold
+        )
+        self.incidental_specificity_threshold = _clip_score(
+            self.incidental_specificity_threshold
+        )
+        self.query_weight = max(0.0, float(self.query_weight))
+        self.novelty_weight = max(0.0, float(self.novelty_weight))
+        self.complementarity_weight = max(0.0, float(self.complementarity_weight))
+        self.specificity_weight = max(0.0, float(self.specificity_weight))
+        self.gap_fill_weight = max(0.0, float(self.gap_fill_weight))
+        self.max_candidate_tokens = max(8, int(self.max_candidate_tokens))
+
+
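`__post_init__` normalizes every threshold through two small helpers defined further down the file; as standalone functions they behave like this:

```python
def clip_score(value: float) -> float:
    # mirrors _clip_score: clamp any threshold into [0, 1]
    return max(0.0, min(1.0, float(value)))

def safe_ratio(numerator: int, denominator: int) -> float:
    # mirrors _safe_ratio: overlap against an empty set counts as 0 instead
    # of raising ZeroDivisionError
    if denominator <= 0:
        return 0.0
    return float(numerator) / float(denominator)
```

Clamping at construction time means every downstream comparison against these thresholds can assume a value in `[0, 1]`.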
|
@dataclass
|
||||||
|
class _CompetitionProfile:
|
||||||
|
text: str
|
||||||
|
tokens: Set[str]
|
||||||
|
entities: Set[str]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _SeedEvidence:
|
||||||
|
name: str
|
||||||
|
strength: str
|
||||||
|
support_count: int
|
||||||
|
rank_hint: int
|
||||||
|
|
||||||
|
|
||||||
|
def _safe_ratio(numerator: int, denominator: int) -> float:
|
||||||
|
if denominator <= 0:
|
||||||
|
return 0.0
|
||||||
|
return float(numerator) / float(denominator)
|
||||||
|
|
||||||
|
|
||||||
|
def _clip_score(value: float) -> float:
|
||||||
|
return max(0.0, min(1.0, float(value)))
|
||||||
|
|
||||||
|
|
||||||
|
def _is_cjk_chunk(token: str) -> bool:
|
||||||
|
return bool(token) and all("\u4e00" <= char <= "\u9fff" for char in token)
|
||||||
|
|
||||||
|
|
||||||
|
def _tokenize_for_competition(text: str, *, max_tokens: int) -> List[str]:
|
||||||
|
normalized = str(text or "").lower().strip()
|
||||||
|
if not normalized:
|
||||||
|
return []
|
||||||
|
|
||||||
|
tokens: List[str] = []
|
||||||
|
for chunk in _TOKEN_PATTERN.findall(normalized):
|
||||||
|
if _is_cjk_chunk(chunk):
|
||||||
|
tokens.extend(
|
||||||
|
item.strip().lower()
|
||||||
|
for item in jieba.lcut_for_search(chunk)
|
||||||
|
if item.strip()
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
tokens.append(chunk)
|
||||||
|
|
||||||
|
filtered: List[str] = []
|
||||||
|
for token in tokens:
|
||||||
|
if len(token) <= 1:
|
||||||
|
continue
|
||||||
|
filtered.append(token)
|
||||||
|
if len(filtered) >= max_tokens:
|
||||||
|
break
|
||||||
|
return filtered
|
||||||
|
|
||||||
|
|
||||||
|
def _result_text_for_entity_match(result: RetrievalResult) -> str:
|
||||||
|
metadata = result.metadata if isinstance(result.metadata, dict) else {}
|
||||||
|
parts = [
|
||||||
|
str(result.content or ""),
|
||||||
|
str(metadata.get("subject", "") or ""),
|
||||||
|
str(metadata.get("object", "") or ""),
|
||||||
|
str(metadata.get("context_title", "") or ""),
|
||||||
|
str(metadata.get("benchmark_title", "") or ""),
|
||||||
|
]
|
||||||
|
return "\n".join(part for part in parts if part)
|
||||||
|
|
||||||
|
|
||||||
|
def _candidate_text_for_competition(result: RetrievalResult) -> str:
|
||||||
|
metadata = result.metadata if isinstance(result.metadata, dict) else {}
|
||||||
|
parts = [
|
||||||
|
_result_text_for_entity_match(result),
|
||||||
|
str(metadata.get("predicate", "") or ""),
|
||||||
|
]
|
||||||
|
return "\n".join(part for part in parts if part)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_candidate_entities(
|
||||||
|
retriever: DualPathRetriever,
|
||||||
|
result: RetrievalResult,
|
||||||
|
) -> Set[str]:
|
||||||
|
metadata = result.metadata if isinstance(result.metadata, dict) else {}
|
||||||
|
entities: Set[str] = set()
|
||||||
|
|
||||||
|
for name in retriever._extract_entities(_candidate_text_for_competition(result)).keys():
|
||||||
|
normalized = str(name or "").strip().lower()
|
||||||
|
if normalized:
|
||||||
|
entities.add(normalized)
|
||||||
|
|
||||||
|
for key in ("benchmark_title", "context_title", "object", "subject"):
|
||||||
|
normalized = str(metadata.get(key, "") or "").strip().lower()
|
||||||
|
if normalized:
|
||||||
|
entities.add(normalized)
|
||||||
|
|
||||||
|
return entities
|
||||||
|
|
||||||
|
|
||||||
|
def _build_query_profile(
|
||||||
|
retriever: DualPathRetriever,
|
||||||
|
query: str,
|
||||||
|
*,
|
||||||
|
max_tokens: int,
|
||||||
|
) -> _CompetitionProfile:
|
||||||
|
text = str(query or "")
|
||||||
|
tokens = set(_tokenize_for_competition(text, max_tokens=max_tokens))
|
||||||
|
entities = {
|
||||||
|
str(name or "").strip().lower()
|
||||||
|
for name in retriever._extract_entities(text).keys()
|
||||||
|
if str(name or "").strip()
|
||||||
|
}
|
||||||
|
return _CompetitionProfile(text=text, tokens=tokens, entities=entities)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_candidate_profile(
|
||||||
|
retriever: DualPathRetriever,
|
||||||
|
result: RetrievalResult,
|
||||||
|
*,
|
||||||
|
max_tokens: int,
|
||||||
|
) -> _CompetitionProfile:
|
||||||
|
text = _candidate_text_for_competition(result)
|
||||||
|
return _CompetitionProfile(
|
||||||
|
text=text,
|
||||||
|
tokens=set(_tokenize_for_competition(text, max_tokens=max_tokens)),
|
||||||
|
entities=_extract_candidate_entities(retriever, result),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_core_profile(
|
||||||
|
retriever: DualPathRetriever,
|
||||||
|
results: Sequence[RetrievalResult],
|
||||||
|
*,
|
||||||
|
max_tokens: int,
|
||||||
|
) -> _CompetitionProfile:
|
||||||
|
parts: List[str] = []
|
||||||
|
tokens: Set[str] = set()
|
||||||
|
entities: Set[str] = set()
|
||||||
|
|
||||||
|
for result in results:
|
||||||
|
profile = _build_candidate_profile(retriever, result, max_tokens=max_tokens)
|
||||||
|
parts.append(profile.text)
|
||||||
|
tokens.update(profile.tokens)
|
||||||
|
entities.update(profile.entities)
|
||||||
|
|
||||||
|
return _CompetitionProfile(text="\n".join(parts), tokens=tokens, entities=entities)
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_query_relevance(candidate: _CompetitionProfile, query: _CompetitionProfile) -> float:
|
||||||
|
entity_hit = _safe_ratio(len(candidate.entities & query.entities), len(query.entities))
|
||||||
|
token_hit = _safe_ratio(len(candidate.tokens & query.tokens), len(query.tokens))
|
||||||
|
if query.entities:
|
||||||
|
return _clip_score(0.65 * entity_hit + 0.35 * token_hit)
|
||||||
|
return _clip_score(max(entity_hit, token_hit))
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_novelty(candidate: _CompetitionProfile, core: _CompetitionProfile) -> float:
|
||||||
|
entity_novelty = _safe_ratio(len(candidate.entities - core.entities), len(candidate.entities))
|
||||||
|
token_novelty = _safe_ratio(len(candidate.tokens - core.tokens), len(candidate.tokens))
|
||||||
|
return _clip_score(0.5 * entity_novelty + 0.5 * token_novelty)


def _compute_complementarity(
    candidate: _CompetitionProfile,
    core: _CompetitionProfile,
    query_relevance: float,
) -> float:
    if not core.tokens and not core.entities:
        return _clip_score(query_relevance)

    entity_overlap = _safe_ratio(len(candidate.entities & core.entities), len(candidate.entities))
    token_overlap = _safe_ratio(len(candidate.tokens & core.tokens), len(candidate.tokens))
    core_overlap = 0.5 * entity_overlap + 0.5 * token_overlap
    sweet_spot = 1.0 - abs(core_overlap - 0.4) / 0.4
    return _clip_score(max(0.0, sweet_spot) * max(query_relevance, 0.2))


def _compute_specificity(candidate: _CompetitionProfile, result: RetrievalResult) -> float:
    token_count = max(1, len(candidate.tokens))
    entity_density = _clip_score(_safe_ratio(len(candidate.entities), token_count) * 4.0)
    brevity = 1.0 - min(1.0, max(0, token_count - 16) / 16.0)
    predicate_bonus = 0.0

    metadata = result.metadata if isinstance(result.metadata, dict) else {}
    predicate = str(metadata.get("predicate", "") or "").strip().lower()
    if predicate:
        if predicate in _BROAD_PREDICATES:
            predicate_bonus = -0.25
        elif result.result_type == "relation":
            predicate_bonus = 0.10

    return _clip_score(0.6 * entity_density + 0.4 * brevity + predicate_bonus)


def _compute_gap_fill(
    candidate: _CompetitionProfile,
    query: _CompetitionProfile,
    core: _CompetitionProfile,
) -> float:
    missing_entities = query.entities - core.entities
    missing_tokens = query.tokens - core.tokens

    entity_fill = _safe_ratio(len(candidate.entities & missing_entities), len(missing_entities))
    token_fill = _safe_ratio(len(candidate.tokens & missing_tokens), len(missing_tokens))

    if missing_entities:
        return _clip_score(0.7 * entity_fill + 0.3 * token_fill)
    return _clip_score(max(entity_fill, token_fill))


def _core_overlap(candidate: _CompetitionProfile, core: _CompetitionProfile) -> float:
    entity_overlap = _safe_ratio(len(candidate.entities & core.entities), len(candidate.entities))
    token_overlap = _safe_ratio(len(candidate.tokens & core.tokens), len(candidate.tokens))
    return _clip_score(0.5 * entity_overlap + 0.5 * token_overlap)


def _compute_competition_score(
    retriever: DualPathRetriever,
    candidate: RetrievalResult,
    *,
    query_profile: _CompetitionProfile,
    core_profile: _CompetitionProfile,
    cfg: PosteriorGraphConfig,
) -> Tuple[float, Dict[str, float]]:
    candidate_profile = _build_candidate_profile(
        retriever,
        candidate,
        max_tokens=cfg.max_candidate_tokens,
    )
    query_relevance = _compute_query_relevance(candidate_profile, query_profile)
    novelty = _compute_novelty(candidate_profile, core_profile)
    complementarity = _compute_complementarity(candidate_profile, core_profile, query_relevance)
    specificity = _compute_specificity(candidate_profile, candidate)
    gap_fill = _compute_gap_fill(candidate_profile, query_profile, core_profile)

    final_score = (
        cfg.query_weight * query_relevance
        + cfg.novelty_weight * novelty
        + cfg.complementarity_weight * complementarity
        + cfg.specificity_weight * specificity
        + cfg.gap_fill_weight * gap_fill
    )
    breakdown = {
        "query_relevance": round(query_relevance, 4),
        "novelty": round(novelty, 4),
        "complementarity": round(complementarity, 4),
        "specificity": round(specificity, 4),
        "gap_fill": round(gap_fill, 4),
        "competition_score": round(_clip_score(final_score), 4),
    }
    return _clip_score(final_score), breakdown


def _top_score(results: Sequence[RetrievalResult]) -> float:
    if not results:
        return 0.0
    return max(float(item.score) for item in results)


def find_score_cliff(
    results: Sequence[RetrievalResult],
    *,
    drop_ratio: float,
    min_core_results: int,
) -> int:
    ranked = list(results)
    if not ranked:
        return 0
    if len(ranked) <= min_core_results:
        return len(ranked)

    for index in range(1, len(ranked)):
        prev_score = max(float(ranked[index - 1].score), 1e-8)
        current_score = float(ranked[index].score)
        score_drop = prev_score - current_score
        if score_drop / prev_score > float(drop_ratio):
            return max(min_core_results, index)

    fallback = max(min_core_results, len(ranked) // 2)
    return min(len(ranked), fallback)
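The cliff detection above depends only on each result's `score` attribute, so it can be exercised in isolation. A condensed restatement of the same logic, with `SimpleNamespace` standing in for `RetrievalResult`:

```python
from types import SimpleNamespace

# Condensed restatement of find_score_cliff above: find the first index where
# the relative score drop exceeds drop_ratio, keeping at least min_core_results.
def find_score_cliff(results, *, drop_ratio, min_core_results):
    ranked = list(results)
    if not ranked:
        return 0
    if len(ranked) <= min_core_results:
        return len(ranked)
    for index in range(1, len(ranked)):
        prev_score = max(float(ranked[index - 1].score), 1e-8)
        if (prev_score - float(ranked[index].score)) / prev_score > float(drop_ratio):
            return max(min_core_results, index)
    # No cliff found: keep at least min_core_results, at most half the list.
    return min(len(ranked), max(min_core_results, len(ranked) // 2))

results = [SimpleNamespace(score=s) for s in (0.90, 0.88, 0.30, 0.28)]
cliff = find_score_cliff(results, drop_ratio=0.5, min_core_results=1)
print(cliff)  # 2: the 0.88 -> 0.30 drop exceeds 50% of the previous score
```

Everything before the cliff index is treated as the "core" result set; everything after it competes with graph candidates for the remaining slots.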


def _extract_seed_evidence(
    retriever: DualPathRetriever,
    query_profile: _CompetitionProfile,
    results: Sequence[RetrievalResult],
    *,
    scan_top_k: int,
    max_tokens: int,
) -> List[_SeedEvidence]:
    score_map: Dict[Tuple[str, str], _SeedEvidence] = {}
    top_results = list(results)[: max(1, int(scan_top_k))]

    for rank, item in enumerate(top_results, start=1):
        profile = _build_candidate_profile(retriever, item, max_tokens=max_tokens)
        for entity in profile.entities:
            strength = "grounded" if entity in query_profile.entities else "incidental"
            key = (entity, strength)
            existing = score_map.get(key)
            if existing is None:
                score_map[key] = _SeedEvidence(
                    name=entity,
                    strength=strength,
                    support_count=1,
                    rank_hint=rank,
                )
            else:
                existing.support_count += 1
                existing.rank_hint = min(existing.rank_hint, rank)

    return sorted(
        score_map.values(),
        key=lambda item: (
            0 if item.strength == "grounded" else 1,
            -int(item.support_count),
            int(item.rank_hint),
            -len(item.name),
            item.name,
        ),
    )
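The sort key at the end of `_extract_seed_evidence` orders seeds: grounded before incidental, then by support count, earliest rank, longer (more specific) name, and finally name as a stable tiebreak. A self-contained illustration, using a local dataclass that mirrors only the fields the key touches (the real `_SeedEvidence` is defined earlier in the module):

```python
from dataclasses import dataclass

@dataclass
class SeedEvidence:
    name: str
    strength: str
    support_count: int
    rank_hint: int

seeds = [
    SeedEvidence("osaka", "incidental", support_count=3, rank_hint=1),
    SeedEvidence("alice", "grounded", support_count=1, rank_hint=4),
    SeedEvidence("tokyo", "grounded", support_count=2, rank_hint=2),
]
ordered = sorted(
    seeds,
    key=lambda item: (
        0 if item.strength == "grounded" else 1,  # grounded seeds first
        -int(item.support_count),                 # then by evidence count
        int(item.rank_hint),                      # then by earliest rank
        -len(item.name),                          # then longer names
        item.name,                                # stable alphabetical tiebreak
    ),
)
print([s.name for s in ordered])  # ['tokyo', 'alice', 'osaka']
```

Note that a heavily supported incidental seed (`osaka`) still sorts after every grounded seed.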


def _grounded_seed_names(seed_evidence: Sequence[_SeedEvidence]) -> List[str]:
    return [item.name for item in seed_evidence if item.strength == "grounded"]


def _incidental_seed_names(seed_evidence: Sequence[_SeedEvidence]) -> List[str]:
    return [item.name for item in seed_evidence if item.strength == "incidental"]


def _need_for_graph(
    *,
    query_profile: _CompetitionProfile,
    core_profile: _CompetitionProfile,
    core_profiles: Sequence[_CompetitionProfile],
    grounded_seeds: Sequence[str],
    rag_confidence: float,
    cfg: PosteriorGraphConfig,
) -> Tuple[bool, str]:
    uncovered_query_entities = query_profile.entities - core_profile.entities
    if uncovered_query_entities:
        return True, "uncovered_query_entities"

    if len(grounded_seeds) >= 2:
        bridge_targets = set(list(grounded_seeds)[:2])
        same_core_hit = any(len(profile.entities & bridge_targets) >= 2 for profile in core_profiles)
        if not same_core_hit:
            return True, "grounded_bridge_gap"

    token_coverage = _safe_ratio(len(core_profile.tokens & query_profile.tokens), len(query_profile.tokens))
    if grounded_seeds and token_coverage < float(cfg.min_query_token_coverage):
        return True, "low_core_query_coverage"

    if grounded_seeds and float(rag_confidence) < float(cfg.grounded_confidence_threshold):
        return True, "low_confidence_grounded"

    return False, "core_already_sufficient"


def _passes_incidental_high_bar(
    retriever: DualPathRetriever,
    candidate: RetrievalResult,
    *,
    query_profile: _CompetitionProfile,
    core_profile: _CompetitionProfile,
    cfg: PosteriorGraphConfig,
) -> bool:
    candidate_profile = _build_candidate_profile(
        retriever,
        candidate,
        max_tokens=cfg.max_candidate_tokens,
    )
    uncovered_query_entities = query_profile.entities - core_profile.entities
    if candidate_profile.entities & uncovered_query_entities:
        return True

    query_relevance = _compute_query_relevance(candidate_profile, query_profile)
    specificity = _compute_specificity(candidate_profile, candidate)
    overlap = _core_overlap(candidate_profile, core_profile)
    gap_fill = _compute_gap_fill(candidate_profile, query_profile, core_profile)

    return bool(
        query_relevance >= float(cfg.incidental_query_relevance_threshold)
        and specificity >= float(cfg.incidental_specificity_threshold)
        and overlap <= float(cfg.incidental_core_overlap_threshold)
        and gap_fill > 0.0
    )


def _linked_core_paragraph_hashes(
    retriever: DualPathRetriever,
    relation_hash: str,
) -> Set[str]:
    rows = retriever.metadata_store.query(
        """
        SELECT paragraph_hash FROM paragraph_relations
        WHERE relation_hash = ?
        """,
        (relation_hash,),
    )
    return {
        str(row.get("paragraph_hash", "") or "").strip()
        for row in rows
        if str(row.get("paragraph_hash", "") or "").strip()
    }


def _build_graph_results_from_seeds(
    retriever: DualPathRetriever,
    *,
    seed_entities: Sequence[str],
    temporal: Any,
    alpha: float,
) -> List[RetrievalResult]:
    from .dual_path import RetrievalResult

    service = getattr(retriever, "_graph_relation_recall", None)
    if service is None:
        return []

    payloads = service.recall(seed_entities=seed_entities)
    if not payloads:
        return []

    graph_results: List[RetrievalResult] = []
    for payload in payloads:
        meta = payload.to_payload()
        graph_results.append(
            RetrievalResult(
                hash_value=str(meta["hash"]),
                content=str(meta["content"]),
                score=0.0,
                result_type="relation",
                source="posterior_graph_recall",
                metadata={
                    "subject": meta["subject"],
                    "predicate": meta["predicate"],
                    "object": meta["object"],
                    "confidence": float(meta["confidence"]),
                    "graph_seed_entities": list(meta["graph_seed_entities"]),
                    "graph_hops": int(meta["graph_hops"]),
                    "graph_candidate_type": str(meta["graph_candidate_type"]),
                    "supporting_paragraph_count": int(meta["supporting_paragraph_count"]),
                },
            )
        )

    graph_results = retriever._apply_temporal_filter_to_relations(graph_results, temporal)
    graph_results = retriever._merge_relation_results_graph_enhanced([], [], graph_results)
    relation_weight = max(0.0, 1.0 - float(alpha))
    for item in graph_results:
        item.score = float(item.score) * relation_weight
        item.source = "posterior_graph_competition"
    return graph_results


def _competition_merge(
    retriever: DualPathRetriever,
    *,
    query: str,
    base_results: Sequence[RetrievalResult],
    graph_results: Sequence[RetrievalResult],
    top_k: int,
    cfg: PosteriorGraphConfig,
) -> List[RetrievalResult]:
    ranked = list(base_results)[: max(1, int(top_k))]
    if not ranked or not graph_results:
        return ranked

    cliff = find_score_cliff(
        ranked,
        drop_ratio=cfg.drop_ratio,
        min_core_results=cfg.min_core_results,
    )
    core_results = ranked[:cliff]
    replaceable_slots = min(
        max(0, int(top_k) - len(core_results)),
        int(cfg.max_graph_slots),
    )
    if replaceable_slots <= 0:
        return ranked[:top_k]

    core_paragraph_hashes = {
        item.hash_value for item in core_results if item.result_type == "paragraph"
    }
    selected_hashes = {item.hash_value for item in core_results}
    filtered_graph_results: List[RetrievalResult] = []
    for item in graph_results:
        if item.hash_value in selected_hashes:
            continue
        linked_hashes = _linked_core_paragraph_hashes(retriever, item.hash_value)
        if core_paragraph_hashes & linked_hashes:
            continue
        filtered_graph_results.append(item)

    tail_candidates: List[RetrievalResult] = []
    for item in ranked[cliff:top_k]:
        if item.hash_value not in selected_hashes:
            tail_candidates.append(item)
    tail_candidates.extend(filtered_graph_results)

    query_profile = _build_query_profile(
        retriever,
        query,
        max_tokens=cfg.max_candidate_tokens,
    )
    core_profile = _build_core_profile(
        retriever,
        core_results,
        max_tokens=cfg.max_candidate_tokens,
    )

    scored_candidates: List[Tuple[RetrievalResult, float]] = []
    for item in tail_candidates:
        competition_score, breakdown = _compute_competition_score(
            retriever,
            item,
            query_profile=query_profile,
            core_profile=core_profile,
            cfg=cfg,
        )
        metadata = dict(item.metadata) if isinstance(item.metadata, dict) else {}
        metadata["posterior_original_score"] = round(float(item.score), 4)
        metadata["posterior_competition_breakdown"] = breakdown
        metadata["posterior_competition_source"] = "posterior_graph_gate"
        item.metadata = metadata
        scored_candidates.append((item, competition_score))

    scored_candidates.sort(
        key=lambda payload: (
            float(payload[1]),
            1 if payload[0].result_type == "relation" else 0,
        ),
        reverse=True,
    )

    tail_winners: List[RetrievalResult] = []
    seen_hashes = set(selected_hashes)
    for item, _ in scored_candidates:
        if item.hash_value in seen_hashes:
            continue
        tail_winners.append(item)
        seen_hashes.add(item.hash_value)
        if len(tail_winners) >= replaceable_slots:
            break

    return (core_results + tail_winners)[:top_k]


def apply_posterior_graph_gate(
    retriever: DualPathRetriever,
    *,
    query: str,
    base_results: Sequence[RetrievalResult],
    top_k: int,
    temporal: Any,
    relation_intent: Dict[str, Any],
) -> List[RetrievalResult]:
    cfg = getattr(retriever.config, "posterior_graph", None)
    if not isinstance(cfg, PosteriorGraphConfig) or not cfg.enabled:
        return list(base_results)[:top_k]
    if not base_results:
        setattr(
            retriever,
            "_posterior_graph_gate_last_decision",
            {
                "scheme": "posterior_graph_gate",
                "query": str(query or ""),
                "enabled": False,
                "bucket": "posterior_gate_empty",
            },
        )
        return []

    top_k_int = max(1, int(top_k))
    alpha_override = relation_intent.get("alpha_override") if isinstance(relation_intent, dict) else None
    alpha = float(alpha_override) if alpha_override is not None else float(retriever.config.alpha)
    rag_confidence = _top_score(list(base_results)[:top_k_int])

    query_profile = _build_query_profile(
        retriever,
        query,
        max_tokens=cfg.max_candidate_tokens,
    )
    seed_evidence = _extract_seed_evidence(
        retriever,
        query_profile,
        base_results,
        scan_top_k=cfg.gate_scan_top_k,
        max_tokens=cfg.max_candidate_tokens,
    )
    grounded_seeds = _grounded_seed_names(seed_evidence)[:2]
    incidental_seeds = _incidental_seed_names(seed_evidence)[:2]

    core_results = list(base_results)[
        : find_score_cliff(
            list(base_results)[:top_k_int],
            drop_ratio=cfg.drop_ratio,
            min_core_results=cfg.min_core_results,
        )
    ]
    core_profile = _build_core_profile(
        retriever,
        core_results,
        max_tokens=cfg.max_candidate_tokens,
    )
    core_profiles = [
        _build_candidate_profile(retriever, item, max_tokens=cfg.max_candidate_tokens)
        for item in core_results
    ]
    need_for_graph, need_reason = _need_for_graph(
        query_profile=query_profile,
        core_profile=core_profile,
        core_profiles=core_profiles,
        grounded_seeds=grounded_seeds,
        rag_confidence=rag_confidence,
        cfg=cfg,
    )

    seed_type = "none"
    seed_names: List[str] = []
    if grounded_seeds and need_for_graph:
        seed_type = "grounded"
        seed_names = grounded_seeds
    elif (
        not grounded_seeds
        and incidental_seeds
        and rag_confidence < float(cfg.incidental_confidence_threshold)
    ):
        seed_type = "incidental"
        seed_names = incidental_seeds

    if not seed_names:
        setattr(
            retriever,
            "_posterior_graph_gate_last_decision",
            {
                "scheme": "posterior_graph_gate",
                "query": str(query or ""),
                "enabled": False,
                "bucket": "posterior_gate_none",
                "grounded_seeds": list(grounded_seeds),
                "incidental_seeds": list(incidental_seeds),
                "selected_seed_type": seed_type,
                "need_for_graph": bool(need_for_graph),
                "need_reason": str(need_reason),
                "rag_confidence": round(float(rag_confidence), 4),
            },
        )
        return list(base_results)[:top_k_int]

    graph_results = _build_graph_results_from_seeds(
        retriever,
        seed_entities=seed_names,
        temporal=temporal,
        alpha=alpha,
    )
    raw_graph_count = len(graph_results)
    if seed_type == "incidental":
        graph_results = [
            item
            for item in graph_results
            if _passes_incidental_high_bar(
                retriever,
                item,
                query_profile=query_profile,
                core_profile=core_profile,
                cfg=cfg,
            )
        ]

    if not graph_results:
        setattr(
            retriever,
            "_posterior_graph_gate_last_decision",
            {
                "scheme": "posterior_graph_gate",
                "query": str(query or ""),
                "enabled": False,
                "bucket": "posterior_gate_graph_filtered",
                "grounded_seeds": list(grounded_seeds),
                "incidental_seeds": list(incidental_seeds),
                "selected_seed_type": seed_type,
                "need_for_graph": bool(need_for_graph),
                "need_reason": str(need_reason),
                "rag_confidence": round(float(rag_confidence), 4),
                "graph_result_count": int(raw_graph_count),
            },
        )
        return list(base_results)[:top_k_int]

    final_results = _competition_merge(
        retriever,
        query=query,
        base_results=base_results,
        graph_results=graph_results,
        top_k=top_k_int,
        cfg=cfg,
    )
    selected_hashes = {item.hash_value for item in final_results}
    graph_selected = any(item.hash_value in selected_hashes for item in graph_results)
    setattr(
        retriever,
        "_posterior_graph_gate_last_decision",
        {
            "scheme": "posterior_graph_gate",
            "query": str(query or ""),
            "enabled": bool(graph_selected),
            "bucket": "posterior_gate_enabled" if graph_selected else "posterior_gate_tail_rejected",
            "grounded_seeds": list(grounded_seeds),
            "incidental_seeds": list(incidental_seeds),
            "selected_seed_type": seed_type,
            "need_for_graph": bool(need_for_graph),
            "need_reason": str(need_reason),
            "rag_confidence": round(float(rag_confidence), 4),
            "graph_result_count": int(raw_graph_count),
            "filtered_graph_count": max(0, raw_graph_count - len(graph_results)),
            "base_top_k_count": min(len(base_results), top_k_int),
        },
    )
    return final_results[:top_k_int]
@@ -306,15 +306,22 @@ class SparseBM25Index:
         rows = self._fallback_substring_search(tokens=tokens, limit=limit)

         results: List[Dict[str, Any]] = []
+        token_count = max(1, len(tokens))
         for rank, row in enumerate(rows, start=1):
             bm25_score = float(row.get("bm25_score", 0.0))
+            content = str(row.get("content", "") or "")
+            content_low = content.lower()
+            matched_tokens = [token for token in tokens if token in content_low]
+            matched_token_count = len(dict.fromkeys(matched_tokens))
             results.append(
                 {
                     "hash": row["hash"],
-                    "content": row["content"],
+                    "content": content,
                     "rank": rank,
                     "bm25_score": bm25_score,
                     "score": -bm25_score,  # bm25 越小越相关,这里取反作为相对分数
+                    "matched_token_count": matched_token_count,
+                    "matched_token_ratio": matched_token_count / float(token_count),
                 }
             )
         return results
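The hunk above adds matched-token bookkeeping to the BM25 fallback path. The counting itself is a one-liner idiom, sketched here in isolation; `dict.fromkeys()` de-duplicates while preserving order, so a query token repeated in `tokens` is only counted once against the content:

```python
# Hypothetical query tokens and hit content for illustration.
tokens = ["vector", "store", "vector"]
content = "VectorStore persists vectors on disk"

content_low = content.lower()
matched_tokens = [token for token in tokens if token in content_low]
matched_token_count = len(dict.fromkeys(matched_tokens))  # order-preserving dedup
token_count = max(1, len(tokens))  # guard against an empty token list

print(matched_token_count, round(matched_token_count / float(token_count), 2))  # 2 0.67
```

Note the ratio is taken against the raw (non-deduplicated) token list length, matching the diff.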
@@ -56,7 +56,7 @@ class ThresholdConfig:
             raise ValueError(f"max_threshold必须在[0, 1]之间: {self.max_threshold}")

         if self.min_threshold >= self.max_threshold:
-            raise ValueError("min_threshold必须小于max_threshold")
+            raise ValueError(f"min_threshold必须小于max_threshold")

         if not 0 <= self.percentile <= 100:
             raise ValueError(f"percentile必须在[0, 100]之间: {self.percentile}")

@@ -3,6 +3,7 @@
 from __future__ import annotations

 import asyncio
+from pathlib import Path
 from typing import Any, Callable, Coroutine, cast

 from src.common.logger import get_logger

@@ -4,6 +4,7 @@ import asyncio
 import json
 import pickle
 import time
+import uuid
 from dataclasses import dataclass
 from datetime import datetime, timedelta
 from pathlib import Path
@@ -18,7 +19,7 @@ from src.services.llm_service import LLMServiceClient

 from ...paths import default_data_dir, resolve_repo_path
 from ..embedding import create_embedding_api_adapter
-from ..retrieval import RetrievalResult, SparseBM25Config, SparseBM25Index
+from ..retrieval import RetrievalResult, SparseBM25Config, SparseBM25Index, TemporalQueryOptions
 from ..storage import GraphStore, MetadataStore, QuantizationType, SparseMatrixFormat, VectorStore
 from ..utils.aggregate_query_service import AggregateQueryService
 from ..utils.episode_retrieval_service import EpisodeRetrievalService

@@ -13,6 +13,7 @@ from ..retrieval import (
     DynamicThresholdFilter,
     FusionConfig,
     GraphRelationRecallConfig,
+    PosteriorGraphConfig,
     RelationIntentConfig,
     RetrievalStrategy,
     SparseBM25Config,
@@ -143,6 +144,9 @@ def build_search_runtime(
     graph_recall_cfg_raw = _safe_dict(
         _get_config_value(plugin_config, "retrieval.search.graph_recall", {}) or {}
     )
+    posterior_graph_cfg_raw = _safe_dict(
+        _get_config_value(plugin_config, "retrieval.search.posterior_graph", {}) or {}
+    )

     try:
         sparse_cfg = SparseBM25Config(**sparse_cfg_raw)
@@ -168,6 +172,12 @@ def build_search_runtime(
         log.warning(f"{prefix_text}[{owner}] graph_recall 配置非法,回退默认: {e}")
         graph_recall_cfg = GraphRelationRecallConfig()

+    try:
+        posterior_graph_cfg = PosteriorGraphConfig(**posterior_graph_cfg_raw)
+    except Exception as e:
+        log.warning(f"{prefix_text}[{owner}] posterior_graph 配置非法,回退默认: {e}")
+        posterior_graph_cfg = PosteriorGraphConfig()
+
     try:
         config = DualPathRetrieverConfig(
             top_k_paragraphs=_get_config_value(plugin_config, "retrieval.top_k_paragraphs", 20),
@@ -189,6 +199,7 @@ def build_search_runtime(
             fusion=fusion_cfg,
             relation_intent=relation_intent_cfg,
             graph_recall=graph_recall_cfg,
+            posterior_graph=posterior_graph_cfg,
         )

     runtime.retriever = DualPathRetriever(
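The `build_search_runtime` hunks above wire in a new `retrieval.search.posterior_graph` config section using the same parse-or-fallback pattern as the other sections. A minimal sketch of that pattern, with a hypothetical one-field dataclass (the real `PosteriorGraphConfig` carries many more knobs, whose names are not shown in this excerpt):

```python
from dataclasses import dataclass

# Illustrative stand-in; field name "enabled" is an assumption for this sketch.
@dataclass
class PosteriorGraphConfig:
    enabled: bool = False

def load_posterior_graph(raw: dict) -> PosteriorGraphConfig:
    try:
        return PosteriorGraphConfig(**raw)
    except Exception:
        # Unknown keys or bad types: fall back to defaults instead of failing startup.
        return PosteriorGraphConfig()

print(load_posterior_graph({"enabled": True}).enabled)  # True
print(load_posterior_graph({"bogus_key": 1}).enabled)   # False (fallback to defaults)
```

The design choice mirrors the surrounding code: a malformed TOML section degrades to defaults with a warning rather than aborting retriever construction.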
@@ -9,6 +9,7 @@ from enum import Enum
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Union, Tuple, List, Dict, Set, Any
|
from typing import Optional, Union, Tuple, List, Dict, Set, Any
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
import threading
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -41,6 +42,7 @@ except ImportError:
|
|||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
from ..utils.hash import compute_hash
|
||||||
from ..utils.io import atomic_write
|
from ..utils.io import atomic_write
|
||||||
|
|
||||||
logger = get_logger("A_Memorix.GraphStore")
|
logger = get_logger("A_Memorix.GraphStore")
|
||||||
|
|||||||
@@ -4,6 +4,7 @@
|
|||||||
基于Faiss的高效向量存储与检索,支持SQ8量化、Append-Only磁盘存储和内存映射。
|
基于Faiss的高效向量存储与检索,支持SQ8量化、Append-Only磁盘存储和内存映射。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
import hashlib
|
import hashlib
|
||||||
import shutil
|
import shutil
|
||||||
@@ -190,7 +191,7 @@ class VectorStore:
|
|||||||
self._update_reservoir(batch_vecs)
|
self._update_reservoir(batch_vecs)
|
||||||
# 这里的 TRAIN_SIZE 取默认 10k,或者根据当前数据量动态判断
|
# 这里的 TRAIN_SIZE 取默认 10k,或者根据当前数据量动态判断
|
||||||
if len(self._reservoir_buffer) >= 10000:
|
if len(self._reservoir_buffer) >= 10000:
|
||||||
logger.info("训练样本达到上限,开始训练...")
|
logger.info(f"训练样本达到上限,开始训练...")
|
||||||
self._train_and_replay_unlocked()
|
self._train_and_replay_unlocked()
|
||||||
|
|
||||||
self._total_added += len(batch_ids)
|
self._total_added += len(batch_ids)
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import List, Dict, Any
|
from typing import List, Dict, Any, Optional, Union
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import hashlib
|
import hashlib
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import re
|
import re
|
||||||
from typing import List
|
from typing import List, Dict, Any
|
||||||
from .base import BaseStrategy, ProcessedChunk, KnowledgeType, SourceInfo, ChunkContext
|
from .base import BaseStrategy, ProcessedChunk, KnowledgeType, SourceInfo, ChunkContext
|
||||||
|
|
||||||
class FactualStrategy(BaseStrategy):
|
class FactualStrategy(BaseStrategy):
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import re
|
import re
|
||||||
from typing import List
|
from typing import List, Dict, Any
|
||||||
from .base import BaseStrategy, ProcessedChunk, KnowledgeType, SourceInfo, ChunkContext
|
from .base import BaseStrategy, ProcessedChunk, KnowledgeType, SourceInfo, ChunkContext
|
||||||
|
|
||||||
class NarrativeStrategy(BaseStrategy):
|
class NarrativeStrategy(BaseStrategy):
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import List
|
from typing import List, Dict, Any
|
||||||
from .base import BaseStrategy, ProcessedChunk, KnowledgeType, SourceInfo, ChunkContext, ChunkFlags
|
from .base import BaseStrategy, ProcessedChunk, KnowledgeType, SourceInfo, ChunkContext, ChunkFlags
|
||||||
|
|
||||||
class QuoteStrategy(BaseStrategy):
|
class QuoteStrategy(BaseStrategy):
|
||||||
|
|||||||
@@ -6,6 +6,7 @@
|
|||||||
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import re
|
import re
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
|
||||||
def compute_hash(text: str, hash_type: str = "sha256") -> str:
|
def compute_hash(text: str, hash_type: str = "sha256") -> str:
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ IO Utilities
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import shutil
|
||||||
import contextlib
|
import contextlib
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|||||||
@@ -4,7 +4,7 @@
|
|||||||
实现 Aho-Corasick 算法用于多模式匹配。
|
实现 Aho-Corasick 算法用于多模式匹配。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List, Dict, Tuple, Set
|
from typing import List, Dict, Tuple, Set, Any
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
 from __future__ import annotations

 import hashlib
-from typing import Any, Dict, List, Sequence, Tuple
+from typing import Any, Dict, List, Optional, Sequence, Tuple

 from ..retrieval.dual_path import RetrievalResult

@@ -234,7 +234,7 @@ async def ensure_runtime_self_check(
         sample_text=sample_text,
     )
     try:
-        plugin_or_config._runtime_self_check_report = report
+        setattr(plugin_or_config, "_runtime_self_check_report", report)
     except Exception:
         pass
     return report
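Replacing a direct attribute assignment with `setattr` is a common way to set a dynamically named or undeclared attribute without tripping static analyzers; at runtime the two forms are equivalent. A minimal illustration (not the project's code):

```python
class Box:
    """Plain object with no declared attributes."""
    pass


b = Box()
b.value = 1                     # direct assignment; linters may flag unknown attrs
setattr(b, "value", 2)          # equivalent dynamic form, quieter for type checkers
```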
@@ -287,7 +287,7 @@ class SearchExecutionService:

         async def _executor() -> Dict[str, Any]:
             original_ppr = bool(getattr(retriever.config, "enable_ppr", True))
-            retriever.config.enable_ppr = bool(request.enable_ppr)
+            setattr(retriever.config, "enable_ppr", bool(request.enable_ppr))
             started_at = time.time()
             try:
                 retrieved = await retriever.retrieve(
@@ -380,7 +380,7 @@ class SearchExecutionService:
                 elapsed_ms = (time.time() - started_at) * 1000.0
                 return {"results": retrieved, "elapsed_ms": elapsed_ms}
             finally:
-                retriever.config.enable_ppr = original_ppr
+                setattr(retriever.config, "enable_ppr", original_ppr)

         dedup_hit = False
         try:
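The toggle-and-restore pattern in these two hunks (save `enable_ppr`, override it for the call, restore it in `finally`) is often packaged as a context manager. A hedged sketch of that refactoring, with hypothetical names:

```python
from contextlib import contextmanager


@contextmanager
def override_attr(obj, name, value):
    """Temporarily set obj.<name> to value, restoring the original on exit."""
    original = getattr(obj, name)
    setattr(obj, name, value)
    try:
        yield obj
    finally:
        setattr(obj, name, original)   # restored even if the body raises


class Config:
    enable_ppr = True


cfg = Config()
with override_attr(cfg, "enable_ppr", False):
    assert cfg.enable_ppr is False     # overridden inside the block
assert cfg.enable_ppr is True          # restored afterwards
```

Wrapping the save/restore in one place guarantees the flag is reset even on exceptions, which is exactly what the explicit `try`/`finally` in the diff achieves.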
@@ -5,6 +5,7 @@
 导入到 A_memorix 的存储组件中。
 """

+from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple

 import json
@@ -1,6 +1,6 @@
 from __future__ import annotations

-from typing import Any, Dict
+from typing import Any, Dict, Optional

 _runtime_kernel: Any = None

@@ -15,7 +15,7 @@ for _path in (SRC_ROOT, PROJECT_ROOT, PLUGIN_ROOT):
     if _path_str not in sys.path:
         sys.path.insert(0, _path_str)

-from A_memorix.paths import config_path, default_data_dir
+from A_memorix.paths import config_path, default_data_dir, resolve_repo_path

 DEFAULT_CONFIG_PATH = config_path()
 DEFAULT_DATA_DIR = default_data_dir()
@@ -10,12 +10,14 @@ LPMM 到 A_memorix 存储转换器
 """

 import sys
+import os
+import json
 import argparse
 import asyncio
 import pickle
 import logging
 from pathlib import Path
-from typing import Dict, Any, Tuple
+from typing import Dict, Any, List, Tuple
 import numpy as np
 import tomlkit

@@ -12,14 +12,17 @@

 from datetime import datetime
 from pathlib import Path
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional

 from rich.console import Console
+from rich.progress import BarColumn, Progress, SpinnerColumn, TextColumn, TimeElapsedColumn
 from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
 import argparse
 import asyncio
 import hashlib
 import json
+import os
+import random
 import sys
 import time
 import tomlkit
|||||||
@@ -17,7 +17,7 @@ import sqlite3
|
|||||||
import sys
|
import sys
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional, Sequence, Tuple
|
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple
|
||||||
|
|
||||||
import tomlkit
|
import tomlkit
|
||||||
|
|
||||||
|
|||||||