Merge branch 'dev' of https://github.com/Mai-with-u/MaiBot into dev
.github/pull_request_template.md (vendored) | 5
@@ -10,8 +10,9 @@
3. - [ ] Update type: bug fix
   - [ ] Update type: new feature
4. - [ ] This update has been tested
-5. Describe the breaking changes, if any:
-6. Briefly describe the content and purpose of this update:
+5. - [ ] If this change touches `src/A_memorix`, I confirm that I have read `src/A_memorix/MODIFICATION_POLICY.md` (leave unchecked if not applicable)
+6. Describe the breaking changes, if any:
+7. Briefly describe the content and purpose of this update:

# Other information

- **Related Issue**: Close #
- **Screenshots/GIF**:
AGENTS.md | 10
@@ -44,5 +44,13 @@
# About webui changes
Do not modify anything under dashboard; that content is built by another repository.

# About A_memorix changes
If a change touches `src/A_memorix`, read `src/A_memorix/MODIFICATION_POLICY.md` first.

Default principles:
1. Implementation-level changes under `src/A_memorix` should follow the ownership constraints in `src/A_memorix/MODIFICATION_POLICY.md`.
2. Do not submit unbounded `ruff` runs, formatting passes, import reordering, or sweeping implementation cleanups.
3. Local experiment directories, and tests that depend on them, must not enter shared history unless explicitly stated and confirmed.

# maibot plugin development docs
-https://github.com/Mai-with-u/maibot-plugin-sdk/blob/main/docs/guide.md
+https://github.com/Mai-with-u/maibot-plugin-sdk/blob/main/docs/guide.md
@@ -120,7 +120,7 @@ default_sample_size = 24
- Long-term memory console: suited to frequently changed items, such as embedding, retrieval, Episode, persona profiles, and the common import and tuning switches.
- Raw TOML: suited to copying the whole config, adjusting parameters in bulk, or editing advanced items not shown in the visual form.
-- Raw-only advanced items still include: `retrieval.fusion.*`, `retrieval.search.relation_intent.*`, `retrieval.search.graph_recall.*`, `retrieval.aggregate.*`, `memory.orphan.*`, `advanced.extraction_model`, `web.import.llm_retry.*`, `web.import.path_aliases`, `web.import.convert.*`, `web.tuning.llm_retry.*`, `web.tuning.eval_query_timeout_seconds`.
+- Raw-only advanced items still include: `retrieval.fusion.*`, `retrieval.search.relation_intent.*`, `retrieval.search.graph_recall.*`, `retrieval.search.posterior_graph.*`, `retrieval.aggregate.*`, `memory.orphan.*`, `advanced.extraction_model`, `web.import.llm_retry.*`, `web.import.path_aliases`, `web.import.convert.*`, `web.tuning.llm_retry.*`, `web.tuning.eval_query_timeout_seconds`.

## 1. Storage and embedding
@@ -213,6 +213,26 @@ default_sample_size = 24
- `allow_two_hop_pair` (default `true`)
- `max_paths` (default `4`)

### `retrieval.search.posterior_graph` (`PosteriorGraphConfig`)

- `enabled` (default `true`)
- `drop_ratio` (default `0.15`)
- `min_core_results` (default `2`)
- `max_graph_slots` (default `2`)
- `gate_scan_top_k` (default `5`)
- `grounded_confidence_threshold` (default `0.48`)
- `incidental_confidence_threshold` (default `0.22`)
- `min_query_token_coverage` (default `0.78`)
- `incidental_query_relevance_threshold` (default `0.68`)
- `incidental_core_overlap_threshold` (default `0.34`)
- `incidental_specificity_threshold` (default `0.42`)

Notes:

- This group controls "posterior graph backfill": run the normal dual-path retrieval first, then decide whether to pull a small batch of relation candidates from the graph structure into the tail competition.
- The design goal is primarily recall, not forcing relations into the top rank.
- If the final answer still passes through LLM summarization, this capability is better used to make sure evidence reaches the front candidates, not for aggressive rank rewriting (see the configuration sketch after this list).

### `retrieval.aggregate`

- `retrieval.aggregate.rrf_k`
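These raw-only keys map one-to-one onto the `PosteriorGraphConfig` dataclass introduced by this commit. A minimal sketch of how overrides reach it, assuming the import path implied by the new file `src/A_memorix/core/retrieval/posterior_graph.py`; the override values are illustrative, not recommended presets:

```python
# The raw [retrieval.search.posterior_graph] TOML table arrives as a plain dict
# and is coerced into the dataclass (see build_search_runtime later in this diff).
from src.A_memorix.core.retrieval.posterior_graph import PosteriorGraphConfig

raw = {
    "drop_ratio": 0.2,      # treat a >20% relative score drop as the core/tail cliff
    "max_graph_slots": 1,   # allow at most one graph-backfilled candidate in the tail
}
cfg = PosteriorGraphConfig(**raw)

# Fields left out keep the documented defaults, and __post_init__ clamps values:
assert cfg.min_core_results == 2
assert cfg.grounded_confidence_threshold == 0.48
```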
@@ -23,7 +23,10 @@ from src.common.logger import get_logger
from .presets import (
    EmbeddingModelConfig,
    get_custom_config,
    validate_config_compatibility,
    are_models_compatible,
)
from ..utils.quantization import QuantizationType

logger = get_logger("A_Memorix.EmbeddingManager")
@@ -3,7 +3,7 @@
"""

from dataclasses import dataclass
-from typing import Optional, Union
+from typing import Optional, Dict, Any, Union
from pathlib import Path
@@ -9,6 +9,7 @@ from .dual_path import (
    FusionConfig,
    RelationIntentConfig,
)
from .posterior_graph import PosteriorGraphConfig
from .pagerank import (
    PersonalizedPageRank,
    PageRankConfig,
@@ -37,6 +38,7 @@ __all__ = [
    "TemporalQueryOptions",
    "FusionConfig",
    "RelationIntentConfig",
    "PosteriorGraphConfig",
    # PersonalizedPageRank
    "PersonalizedPageRank",
    "PageRankConfig",
@@ -7,7 +7,7 @@
import asyncio
import re
from dataclasses import dataclass, field
-from typing import Optional, List, Dict, Any, Tuple
+from typing import Optional, List, Dict, Any, Tuple, Union
from enum import Enum

import numpy as np
@@ -19,6 +19,7 @@ from ..utils.matcher import AhoCorasick
from ..utils.time_parser import format_timestamp
from .graph_relation_recall import GraphRelationRecallConfig, GraphRelationRecallService
from .pagerank import PersonalizedPageRank, PageRankConfig
from .posterior_graph import PosteriorGraphConfig, apply_posterior_graph_gate
from .sparse_bm25 import SparseBM25Config, SparseBM25Index

logger = get_logger("A_Memorix.DualPathRetriever")
@@ -101,6 +102,7 @@ class DualPathRetrieverConfig:
    fusion: "FusionConfig" = field(default_factory=lambda: FusionConfig())
    relation_intent: "RelationIntentConfig" = field(default_factory=lambda: RelationIntentConfig())
    graph_recall: GraphRelationRecallConfig = field(default_factory=GraphRelationRecallConfig)
    posterior_graph: PosteriorGraphConfig = field(default_factory=PosteriorGraphConfig)

    def __post_init__(self):
        """Validate the configuration."""
@@ -112,6 +114,8 @@ class DualPathRetrieverConfig:
            self.relation_intent = RelationIntentConfig(**self.relation_intent)
        if isinstance(self.graph_recall, dict):
            self.graph_recall = GraphRelationRecallConfig(**self.graph_recall)
        if isinstance(self.posterior_graph, dict):
            self.posterior_graph = PosteriorGraphConfig(**self.posterior_graph)

        if not 0 <= self.alpha <= 1:
            raise ValueError(f"alpha必须在[0, 1]之间: {self.alpha}")
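The two-line coercion branch added above lets nested sections arrive either as dataclasses or as plain dicts (for example, parsed TOML). A minimal self-contained sketch of the same pattern, with stand-in class names rather than the real config classes:

```python
from dataclasses import dataclass, field

@dataclass
class _SubConfig:
    enabled: bool = True

@dataclass
class _Config:
    sub: object = field(default_factory=_SubConfig)

    def __post_init__(self):
        if isinstance(self.sub, dict):
            self.sub = _SubConfig(**self.sub)  # plain dicts are coerced in place

assert isinstance(_Config(sub={"enabled": False}).sub, _SubConfig)
```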
@@ -320,7 +324,7 @@ class DualPathRetriever:
        # Debug mode: log the raw content of the retrieved results
        if self.config.debug:
-            logger.info("[DEBUG] 检索结果内容原文:")
+            logger.info(f"[DEBUG] 检索结果内容原文:")
            for i, res in enumerate(results):
                logger.info(f"  {i+1}. [{res.result_type}] (Score: {res.score:.4f}) {res.content}")
@@ -588,6 +592,7 @@ class DualPathRetriever:
        candidate_k = max(top_k, self.config.sparse.candidate_k)
        candidate_k = self._cap_temporal_scan_k(candidate_k, temporal)
        sparse_rows = self.sparse_index.search(query=query, k=candidate_k)
        sparse_rows = self._filter_sparse_paragraph_rows(sparse_rows)
        results: List[RetrievalResult] = []
        for row in sparse_rows:
            hash_value = row["hash"]
@@ -614,6 +619,53 @@ class DualPathRetriever:
        self._normalize_scores_minmax(results)
        return results

    def _filter_sparse_paragraph_rows(
        self,
        rows: List[Dict[str, Any]],
    ) -> List[Dict[str, Any]]:
        """
        Filter the paragraph sparse tail.

        The goal is not to suppress strong lexical hits, but to keep tail results
        that only match a single weak token from collecting outsized rank credit
        in the weighted RRF.
        """
        if len(rows) <= 2:
            return rows

        top_score = max(0.0, float(rows[0].get("score", 0.0) or 0.0))
        if top_score <= 0.0:
            return rows[:2]

        relative_floor = top_score * 0.2
        filtered_rows: List[Dict[str, Any]] = []
        removed_count = 0
        for index, row in enumerate(rows):
            if index < 2:
                filtered_rows.append(row)
                continue

            raw_score = float(row.get("score", 0.0) or 0.0)
            matched_token_count = int(row.get("matched_token_count", 0) or 0)
            matched_token_ratio = float(row.get("matched_token_ratio", 0.0) or 0.0)

            if (
                raw_score >= relative_floor
                or matched_token_count >= 3
                or (matched_token_count >= 2 and matched_token_ratio >= 0.12)
            ):
                filtered_rows.append(row)
                continue

            removed_count += 1

        if removed_count > 0:
            logger.debug(
                "sparse_paragraph_tail_pruned=1 "
                f"removed_count={removed_count} "
                f"kept_count={len(filtered_rows)}"
            )
        return filtered_rows

    def _search_relations_sparse(
        self,
        query: str,
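A quick walkthrough of the pruning rule on synthetic rows (all scores and token stats invented for illustration): the first two rows always survive; a tail row must clear the relative floor, match three tokens, or match two tokens with sufficient ratio.

```python
rows = [
    {"hash": "a", "score": 10.0, "matched_token_count": 4, "matched_token_ratio": 0.40},
    {"hash": "b", "score": 6.0,  "matched_token_count": 3, "matched_token_ratio": 0.30},
    {"hash": "c", "score": 2.5,  "matched_token_count": 1, "matched_token_ratio": 0.05},
    {"hash": "d", "score": 1.0,  "matched_token_count": 2, "matched_token_ratio": 0.15},
    {"hash": "e", "score": 0.5,  "matched_token_count": 1, "matched_token_ratio": 0.04},
]
# relative_floor = 10.0 * 0.2 = 2.0
# "c": score 2.5 >= 2.0                                  -> kept by the relative floor
# "d": score < 2.0, but 2 tokens and ratio 0.15 >= 0.12  -> kept by the 2-token clause
# "e": score < 2.0, one weak token                       -> pruned
```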
@@ -1025,6 +1077,14 @@ class DualPathRetriever:
        )
        if temporal:
            fused_results = self._sort_results_with_temporal(fused_results, temporal)
        fused_results = apply_posterior_graph_gate(
            self,
            query=query,
            base_results=fused_results,
            top_k=top_k,
            temporal=temporal,
            relation_intent=relation_intent,
        )
        fused_results = self._apply_relation_intent_pair_rerank(
            fused_results,
            enabled=bool(relation_intent.get("enabled", False)),
@@ -1126,6 +1186,15 @@ class DualPathRetriever:
        if temporal:
            fused_results = self._sort_results_with_temporal(fused_results, temporal)

        fused_results = apply_posterior_graph_gate(
            self,
            query=query,
            base_results=fused_results,
            top_k=top_k,
            temporal=temporal,
            relation_intent=relation_intent,
        )

        fused_results = self._apply_relation_intent_pair_rerank(
            fused_results,
            enabled=bool(relation_intent.get("enabled", False)),
@@ -1150,37 +1219,13 @@ class DualPathRetriever:
        Returns:
            (paragraph results, relation results)
        """
-        # Run the two searches concurrently with asyncio.gather.
-        # _search_paragraphs and _search_relations are CPU-bound synchronous
-        # functions, so execute them in the thread pool via asyncio.to_thread.
-        try:
-            para_task = asyncio.to_thread(
-                self._search_paragraphs,
+        return await asyncio.to_thread(
+            self._collect_mixed_candidates,
            query_emb,
-            self.config.top_k_paragraphs,
            temporal,
+            relation_top_k,
        )
-            rel_task = asyncio.to_thread(
-                self._search_relations,
-                query_emb,
-                relation_top_k if relation_top_k is not None else self.config.top_k_relations,
-                temporal,
-            )
-
-            para_results, rel_results = await asyncio.gather(
-                para_task, rel_task, return_exceptions=True
-            )
-
-            # Handle exceptions
-            if isinstance(para_results, Exception):
-                logger.error(f"段落检索失败: {para_results}")
-                para_results = []
-            if isinstance(rel_results, Exception):
-                logger.error(f"关系检索失败: {rel_results}")
-                rel_results = []
-
-            return para_results, rel_results
-
-        except Exception as e:
-            logger.error(f"并行检索失败: {e}")
-            return [], []
@@ -1200,18 +1245,125 @@ class DualPathRetriever:
        Returns:
            (paragraph results, relation results)
        """
-        para_results = self._search_paragraphs(
+        return self._collect_mixed_candidates(
            query_emb,
-            self.config.top_k_paragraphs,
            temporal,
+            relation_top_k,
        )
-
-        rel_results = self._search_relations(
-            query_emb,
-            relation_top_k if relation_top_k is not None else self.config.top_k_relations,
-            temporal,
-        )

    def _mixed_candidate_budget(
        self,
        para_top_k: int,
        rel_top_k: int,
        temporal: Optional[TemporalQueryOptions],
    ) -> int:
        multiplier = max(1, temporal.candidate_multiplier) if temporal else 1
        base = max(para_top_k + rel_top_k, max(para_top_k, rel_top_k) * 2)
        return max(base * 6 * multiplier, 48)

    def _merge_backfilled_results(
        self,
        *,
        primary_results: List[RetrievalResult],
        backfill_results: List[RetrievalResult],
        top_k: int,
    ) -> List[RetrievalResult]:
        merged: Dict[str, RetrievalResult] = {}
        for item in primary_results:
            merged[item.hash_value] = item
        for item in backfill_results:
            existing = merged.get(item.hash_value)
            if existing is None or float(item.score) > float(existing.score):
                merged[item.hash_value] = item

        results = list(merged.values())
        results.sort(key=lambda item: item.score, reverse=True)
        return results[:top_k]

    def _collect_mixed_candidates(
        self,
        query_emb: np.ndarray,
        temporal: Optional[TemporalQueryOptions] = None,
        relation_top_k: Optional[int] = None,
    ) -> Tuple[List[RetrievalResult], List[RetrievalResult]]:
        para_top_k = self.config.top_k_paragraphs
        rel_top_k = relation_top_k if relation_top_k is not None else self.config.top_k_relations
        candidate_k = self._mixed_candidate_budget(para_top_k, rel_top_k, temporal)
        candidate_k = self._cap_temporal_scan_k(candidate_k, temporal)
        ids, scores = self.vector_store.search(query_emb, k=candidate_k)

        para_candidates: List[RetrievalResult] = []
        rel_candidates: List[RetrievalResult] = []
        seen_para = set()
        seen_rel = set()

        for hash_value, score in zip(ids, scores):
            paragraph = self.metadata_store.get_paragraph(hash_value)
            if paragraph is not None and hash_value not in seen_para:
                seen_para.add(hash_value)
                para_candidates.append(
                    RetrievalResult(
                        hash_value=hash_value,
                        content=paragraph["content"],
                        score=float(score),
                        result_type="paragraph",
                        source="paragraph_search",
                        metadata={
                            "word_count": paragraph.get("word_count", 0),
                            "time_meta": self._build_time_meta_from_paragraph(
                                paragraph,
                                temporal=temporal,
                            ),
                        },
                    )
                )
                continue

            relation = self.metadata_store.get_relation(hash_value, include_inactive=False)
            if relation is None or hash_value in seen_rel:
                continue

            relation_time_meta = None
            if temporal:
                relation_time_meta = self._best_supporting_time_meta(hash_value, temporal)
                if relation_time_meta is None:
                    continue

            seen_rel.add(hash_value)
            rel_candidates.append(
                RetrievalResult(
                    hash_value=hash_value,
                    content=f"{relation['subject']} {relation['predicate']} {relation['object']}",
                    score=float(score),
                    result_type="relation",
                    source="relation_search",
                    metadata={
                        "subject": relation["subject"],
                        "predicate": relation["predicate"],
                        "object": relation["object"],
                        "confidence": relation.get("confidence", 1.0),
                        "time_meta": relation_time_meta,
                    },
                )
            )

        para_results = self._apply_temporal_filter_to_paragraphs(para_candidates, temporal)
        rel_results = self._apply_temporal_filter_to_relations(rel_candidates, temporal)

        # In the dual scheme the vector backbone mainly fixes "insufficient recall",
        # so the primary retrieval uses a shared candidate pool, with an extra
        # per-type backfill layer so neither paragraphs nor relations get starved.
        para_backfill = self._search_paragraphs(query_emb, para_top_k, temporal)
        rel_backfill = self._search_relations(query_emb, rel_top_k, temporal)
        para_results = self._merge_backfilled_results(
            primary_results=para_results,
            backfill_results=para_backfill,
            top_k=para_top_k,
        )
        rel_results = self._merge_backfilled_results(
            primary_results=rel_results,
            backfill_results=rel_backfill,
            top_k=rel_top_k,
        )
        return para_results, rel_results

    def _search_paragraphs(
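Two worked numbers for the helpers above. `top_k_paragraphs=20` is the default wired in `build_search_runtime` later in this diff; the relation top-k of 10 is just an assumed example.

```python
# _mixed_candidate_budget(20, 10, temporal=None):
#   base   = max(20 + 10, max(20, 10) * 2) = max(30, 40) = 40
#   budget = max(40 * 6 * 1, 48) = 240   # one shared vector scan covers both types
#
# _merge_backfilled_results keeps, per hash, whichever copy scored higher:
primary  = {"h1": 0.82, "h2": 0.40}
backfill = {"h2": 0.55, "h3": 0.31}
merged = dict(primary)
for h, s in backfill.items():
    if h not in merged or s > merged[h]:
        merged[h] = s
assert merged == {"h1": 0.82, "h2": 0.55, "h3": 0.31}  # then sorted and cut to top_k
```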
@@ -1560,9 +1712,20 @@ class DualPathRetriever:
                entity_scores.append(ppr_scores_by_name[ent_name])

            if entity_scores:
-                avg_ppr = np.mean(entity_scores)
-                # Blend the original score with the PPR score
-                result.score = result.score * 0.7 + avg_ppr * 0.3
+                # Use only the matched high-value graph entities as a positive boost,
+                # so a correct, originally high-scoring paragraph is not pushed down
+                # just because it has many entities and not all of them matched.
+                focus_scores = sorted(entity_scores, reverse=True)[:2]
+                ppr_signal = float(np.mean(focus_scores))
+                boost_weight = 0.12 if len(focus_scores) >= 2 else 0.06
+                boost = ppr_signal * boost_weight
+
+                metadata = result.metadata if isinstance(result.metadata, dict) else {}
+                metadata["ppr_signal"] = round(ppr_signal, 4)
+                metadata["ppr_focus_entity_count"] = len(focus_scores)
+                metadata["ppr_boost"] = round(boost, 4)
+                result.metadata = metadata
+
+                result.score = float(result.score) + float(boost)

        # Re-sort
        results.sort(key=lambda x: x.score, reverse=True)
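A worked comparison with invented numbers shows why the additive form is gentler than the old `0.7/0.3` blend: a high-scoring paragraph with little PPR mass is no longer dragged down.

```python
# Original score 0.90, matched entity PPR scores [0.02, 0.01] (invented values).
old = 0.90 * 0.7 + ((0.02 + 0.01) / 2) * 0.3   # = 0.6345: blended *down*
ppr_signal = (0.02 + 0.01) / 2                  # mean of the top-2 focus entities
new = 0.90 + ppr_signal * 0.12                  # = 0.9018: boosted, never lowered
```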
@@ -4,8 +4,9 @@ Personalized PageRank实现
Provides personalized ranking of graph nodes.
"""

-from typing import Dict, List, Optional, Tuple, Any
+from typing import Dict, List, Optional, Tuple, Union, Any
from dataclasses import dataclass
import numpy as np

from src.common.logger import get_logger
from ..storage import GraphStore
@@ -48,7 +49,7 @@ class PageRankConfig:
            raise ValueError(f"min_iterations必须大于等于0: {self.min_iterations}")

        if self.min_iterations >= self.max_iter:
-            raise ValueError("min_iterations必须小于max_iter")
+            raise ValueError(f"min_iterations必须小于max_iter")


class PersonalizedPageRank:
src/A_memorix/core/retrieval/posterior_graph.py (new file) | 792
@@ -0,0 +1,792 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Set, Tuple

import re

import jieba

if TYPE_CHECKING:
    from .dual_path import DualPathRetriever, RetrievalResult


_TOKEN_PATTERN = re.compile(r"[A-Za-z0-9_]+|[\u4e00-\u9fff]+")
_BROAD_PREDICATES = {
    "contains_fact",
    "describe",
    "describes",
    "description",
    "mention",
    "mentions",
    "summary",
    "summarizes",
}


@dataclass
class PosteriorGraphConfig:
    """Posterior graph backfill configuration for the dual scheme."""

    enabled: bool = True
    drop_ratio: float = 0.15
    min_core_results: int = 2
    max_graph_slots: int = 2
    gate_scan_top_k: int = 5
    grounded_confidence_threshold: float = 0.48
    incidental_confidence_threshold: float = 0.22
    min_query_token_coverage: float = 0.78
    incidental_query_relevance_threshold: float = 0.68
    incidental_core_overlap_threshold: float = 0.34
    incidental_specificity_threshold: float = 0.42
    query_weight: float = 0.28
    novelty_weight: float = 0.18
    complementarity_weight: float = 0.16
    specificity_weight: float = 0.12
    gap_fill_weight: float = 0.26
    max_candidate_tokens: int = 48

    def __post_init__(self) -> None:
        self.enabled = bool(self.enabled)
        self.drop_ratio = max(0.0, float(self.drop_ratio))
        self.min_core_results = max(1, int(self.min_core_results))
        self.max_graph_slots = max(0, int(self.max_graph_slots))
        self.gate_scan_top_k = max(1, int(self.gate_scan_top_k))
        self.grounded_confidence_threshold = _clip_score(self.grounded_confidence_threshold)
        self.incidental_confidence_threshold = _clip_score(self.incidental_confidence_threshold)
        self.min_query_token_coverage = _clip_score(self.min_query_token_coverage)
        self.incidental_query_relevance_threshold = _clip_score(
            self.incidental_query_relevance_threshold
        )
        self.incidental_core_overlap_threshold = _clip_score(
            self.incidental_core_overlap_threshold
        )
        self.incidental_specificity_threshold = _clip_score(
            self.incidental_specificity_threshold
        )
        self.query_weight = max(0.0, float(self.query_weight))
        self.novelty_weight = max(0.0, float(self.novelty_weight))
        self.complementarity_weight = max(0.0, float(self.complementarity_weight))
        self.specificity_weight = max(0.0, float(self.specificity_weight))
        self.gap_fill_weight = max(0.0, float(self.gap_fill_weight))
        self.max_candidate_tokens = max(8, int(self.max_candidate_tokens))


@dataclass
class _CompetitionProfile:
    text: str
    tokens: Set[str]
    entities: Set[str]


@dataclass
class _SeedEvidence:
    name: str
    strength: str
    support_count: int
    rank_hint: int


def _safe_ratio(numerator: int, denominator: int) -> float:
    if denominator <= 0:
        return 0.0
    return float(numerator) / float(denominator)


def _clip_score(value: float) -> float:
    return max(0.0, min(1.0, float(value)))


def _is_cjk_chunk(token: str) -> bool:
    return bool(token) and all("\u4e00" <= char <= "\u9fff" for char in token)


def _tokenize_for_competition(text: str, *, max_tokens: int) -> List[str]:
    normalized = str(text or "").lower().strip()
    if not normalized:
        return []

    tokens: List[str] = []
    for chunk in _TOKEN_PATTERN.findall(normalized):
        if _is_cjk_chunk(chunk):
            tokens.extend(
                item.strip().lower()
                for item in jieba.lcut_for_search(chunk)
                if item.strip()
            )
        else:
            tokens.append(chunk)

    filtered: List[str] = []
    for token in tokens:
        if len(token) <= 1:
            continue
        filtered.append(token)
        if len(filtered) >= max_tokens:
            break
    return filtered
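A small standalone sketch of what the tokenizer yields on mixed ASCII/CJK input; jieba's exact segmentation depends on its dictionary, so the outputs shown are indicative only.

```python
import re
import jieba

_TOKEN_PATTERN = re.compile(r"[A-Za-z0-9_]+|[\u4e00-\u9fff]+")

text = "MaiBot 的长期记忆检索".lower()
chunks = _TOKEN_PATTERN.findall(text)        # ['maibot', '的长期记忆检索']
cjk_cut = jieba.lcut_for_search(chunks[1])   # e.g. ['的', '长期', '记忆', '检索']
tokens = [t for t in ([chunks[0]] + cjk_cut) if len(t) > 1]  # drops 1-char tokens like '的'
print(tokens)                                 # e.g. ['maibot', '长期', '记忆', '检索']
```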
def _result_text_for_entity_match(result: RetrievalResult) -> str:
    metadata = result.metadata if isinstance(result.metadata, dict) else {}
    parts = [
        str(result.content or ""),
        str(metadata.get("subject", "") or ""),
        str(metadata.get("object", "") or ""),
        str(metadata.get("context_title", "") or ""),
        str(metadata.get("benchmark_title", "") or ""),
    ]
    return "\n".join(part for part in parts if part)


def _candidate_text_for_competition(result: RetrievalResult) -> str:
    metadata = result.metadata if isinstance(result.metadata, dict) else {}
    parts = [
        _result_text_for_entity_match(result),
        str(metadata.get("predicate", "") or ""),
    ]
    return "\n".join(part for part in parts if part)


def _extract_candidate_entities(
    retriever: DualPathRetriever,
    result: RetrievalResult,
) -> Set[str]:
    metadata = result.metadata if isinstance(result.metadata, dict) else {}
    entities: Set[str] = set()

    for name in retriever._extract_entities(_candidate_text_for_competition(result)).keys():
        normalized = str(name or "").strip().lower()
        if normalized:
            entities.add(normalized)

    for key in ("benchmark_title", "context_title", "object", "subject"):
        normalized = str(metadata.get(key, "") or "").strip().lower()
        if normalized:
            entities.add(normalized)

    return entities


def _build_query_profile(
    retriever: DualPathRetriever,
    query: str,
    *,
    max_tokens: int,
) -> _CompetitionProfile:
    text = str(query or "")
    tokens = set(_tokenize_for_competition(text, max_tokens=max_tokens))
    entities = {
        str(name or "").strip().lower()
        for name in retriever._extract_entities(text).keys()
        if str(name or "").strip()
    }
    return _CompetitionProfile(text=text, tokens=tokens, entities=entities)


def _build_candidate_profile(
    retriever: DualPathRetriever,
    result: RetrievalResult,
    *,
    max_tokens: int,
) -> _CompetitionProfile:
    text = _candidate_text_for_competition(result)
    return _CompetitionProfile(
        text=text,
        tokens=set(_tokenize_for_competition(text, max_tokens=max_tokens)),
        entities=_extract_candidate_entities(retriever, result),
    )


def _build_core_profile(
    retriever: DualPathRetriever,
    results: Sequence[RetrievalResult],
    *,
    max_tokens: int,
) -> _CompetitionProfile:
    parts: List[str] = []
    tokens: Set[str] = set()
    entities: Set[str] = set()

    for result in results:
        profile = _build_candidate_profile(retriever, result, max_tokens=max_tokens)
        parts.append(profile.text)
        tokens.update(profile.tokens)
        entities.update(profile.entities)

    return _CompetitionProfile(text="\n".join(parts), tokens=tokens, entities=entities)


def _compute_query_relevance(candidate: _CompetitionProfile, query: _CompetitionProfile) -> float:
    entity_hit = _safe_ratio(len(candidate.entities & query.entities), len(query.entities))
    token_hit = _safe_ratio(len(candidate.tokens & query.tokens), len(query.tokens))
    if query.entities:
        return _clip_score(0.65 * entity_hit + 0.35 * token_hit)
    return _clip_score(max(entity_hit, token_hit))


def _compute_novelty(candidate: _CompetitionProfile, core: _CompetitionProfile) -> float:
    entity_novelty = _safe_ratio(len(candidate.entities - core.entities), len(candidate.entities))
    token_novelty = _safe_ratio(len(candidate.tokens - core.tokens), len(candidate.tokens))
    return _clip_score(0.5 * entity_novelty + 0.5 * token_novelty)


def _compute_complementarity(
    candidate: _CompetitionProfile,
    core: _CompetitionProfile,
    query_relevance: float,
) -> float:
    if not core.tokens and not core.entities:
        return _clip_score(query_relevance)

    entity_overlap = _safe_ratio(len(candidate.entities & core.entities), len(candidate.entities))
    token_overlap = _safe_ratio(len(candidate.tokens & core.tokens), len(candidate.tokens))
    core_overlap = 0.5 * entity_overlap + 0.5 * token_overlap
    sweet_spot = 1.0 - abs(core_overlap - 0.4) / 0.4
    return _clip_score(max(0.0, sweet_spot) * max(query_relevance, 0.2))


def _compute_specificity(candidate: _CompetitionProfile, result: RetrievalResult) -> float:
    token_count = max(1, len(candidate.tokens))
    entity_density = _clip_score(_safe_ratio(len(candidate.entities), token_count) * 4.0)
    brevity = 1.0 - min(1.0, max(0, token_count - 16) / 16.0)
    predicate_bonus = 0.0

    metadata = result.metadata if isinstance(result.metadata, dict) else {}
    predicate = str(metadata.get("predicate", "") or "").strip().lower()
    if predicate:
        if predicate in _BROAD_PREDICATES:
            predicate_bonus = -0.25
        elif result.result_type == "relation":
            predicate_bonus = 0.10

    return _clip_score(0.6 * entity_density + 0.4 * brevity + predicate_bonus)


def _compute_gap_fill(
    candidate: _CompetitionProfile,
    query: _CompetitionProfile,
    core: _CompetitionProfile,
) -> float:
    missing_entities = query.entities - core.entities
    missing_tokens = query.tokens - core.tokens

    entity_fill = _safe_ratio(len(candidate.entities & missing_entities), len(missing_entities))
    token_fill = _safe_ratio(len(candidate.tokens & missing_tokens), len(missing_tokens))

    if missing_entities:
        return _clip_score(0.7 * entity_fill + 0.3 * token_fill)
    return _clip_score(max(entity_fill, token_fill))


def _core_overlap(candidate: _CompetitionProfile, core: _CompetitionProfile) -> float:
    entity_overlap = _safe_ratio(len(candidate.entities & core.entities), len(candidate.entities))
    token_overlap = _safe_ratio(len(candidate.tokens & core.tokens), len(candidate.tokens))
    return _clip_score(0.5 * entity_overlap + 0.5 * token_overlap)


def _compute_competition_score(
    retriever: DualPathRetriever,
    candidate: RetrievalResult,
    *,
    query_profile: _CompetitionProfile,
    core_profile: _CompetitionProfile,
    cfg: PosteriorGraphConfig,
) -> Tuple[float, Dict[str, float]]:
    candidate_profile = _build_candidate_profile(
        retriever,
        candidate,
        max_tokens=cfg.max_candidate_tokens,
    )
    query_relevance = _compute_query_relevance(candidate_profile, query_profile)
    novelty = _compute_novelty(candidate_profile, core_profile)
    complementarity = _compute_complementarity(candidate_profile, core_profile, query_relevance)
    specificity = _compute_specificity(candidate_profile, candidate)
    gap_fill = _compute_gap_fill(candidate_profile, query_profile, core_profile)

    final_score = (
        cfg.query_weight * query_relevance
        + cfg.novelty_weight * novelty
        + cfg.complementarity_weight * complementarity
        + cfg.specificity_weight * specificity
        + cfg.gap_fill_weight * gap_fill
    )
    breakdown = {
        "query_relevance": round(query_relevance, 4),
        "novelty": round(novelty, 4),
        "complementarity": round(complementarity, 4),
        "specificity": round(specificity, 4),
        "gap_fill": round(gap_fill, 4),
        "competition_score": round(_clip_score(final_score), 4),
    }
    return _clip_score(final_score), breakdown
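Note that the default weights (0.28 + 0.18 + 0.16 + 0.12 + 0.26) sum to 1.0, so with every component clipped to [0, 1] the competition score is a convex combination of the five signals. A worked example with invented component values:

```python
weights = {"query": 0.28, "novelty": 0.18, "complementarity": 0.16,
           "specificity": 0.12, "gap_fill": 0.26}
signals = {"query": 0.9, "novelty": 0.3, "complementarity": 0.5,
           "specificity": 0.6, "gap_fill": 1.0}          # invented example values
score = sum(weights[k] * signals[k] for k in weights)    # = 0.718
```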
def _top_score(results: Sequence[RetrievalResult]) -> float:
    if not results:
        return 0.0
    return max(float(item.score) for item in results)


def find_score_cliff(
    results: Sequence[RetrievalResult],
    *,
    drop_ratio: float,
    min_core_results: int,
) -> int:
    ranked = list(results)
    if not ranked:
        return 0
    if len(ranked) <= min_core_results:
        return len(ranked)

    for index in range(1, len(ranked)):
        prev_score = max(float(ranked[index - 1].score), 1e-8)
        current_score = float(ranked[index].score)
        score_drop = prev_score - current_score
        if score_drop / prev_score > float(drop_ratio):
            return max(min_core_results, index)

    fallback = max(min_core_results, len(ranked) // 2)
    return min(len(ranked), fallback)
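A worked pass of the cliff detection with the default `drop_ratio=0.15` and `min_core_results=2` (scores invented):

```python
scores = [0.92, 0.88, 0.85, 0.52, 0.50]
# index 1: (0.92 - 0.88) / 0.92 ≈ 0.043  -> below 0.15, keep scanning
# index 2: (0.88 - 0.85) / 0.88 ≈ 0.034  -> below 0.15
# index 3: (0.85 - 0.52) / 0.85 ≈ 0.388  -> cliff; return max(2, 3) = 3
# core = the first 3 results; the rest compete in the tail for graph slots
```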
def _extract_seed_evidence(
    retriever: DualPathRetriever,
    query_profile: _CompetitionProfile,
    results: Sequence[RetrievalResult],
    *,
    scan_top_k: int,
    max_tokens: int,
) -> List[_SeedEvidence]:
    score_map: Dict[Tuple[str, str], _SeedEvidence] = {}
    top_results = list(results)[: max(1, int(scan_top_k))]

    for rank, item in enumerate(top_results, start=1):
        profile = _build_candidate_profile(retriever, item, max_tokens=max_tokens)
        for entity in profile.entities:
            strength = "grounded" if entity in query_profile.entities else "incidental"
            key = (entity, strength)
            existing = score_map.get(key)
            if existing is None:
                score_map[key] = _SeedEvidence(
                    name=entity,
                    strength=strength,
                    support_count=1,
                    rank_hint=rank,
                )
            else:
                existing.support_count += 1
                existing.rank_hint = min(existing.rank_hint, rank)

    return sorted(
        score_map.values(),
        key=lambda item: (
            0 if item.strength == "grounded" else 1,
            -int(item.support_count),
            int(item.rank_hint),
            -len(item.name),
            item.name,
        ),
    )
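The sort key orders seeds by strength first (grounded before incidental), then by higher support count, then earlier best rank, then longer name. A sketch of the resulting order (tuples invented):

```python
# ("grounded",   support=1, rank_hint=2, "alice") sorts before
# ("incidental", support=3, rank_hint=1, "bob")   because strength is compared first;
# between two grounded seeds, support_count=2 beats support_count=1, and ties
# fall back to the better (lower) rank_hint, then the longer entity name.
```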
def _grounded_seed_names(seed_evidence: Sequence[_SeedEvidence]) -> List[str]:
    return [item.name for item in seed_evidence if item.strength == "grounded"]


def _incidental_seed_names(seed_evidence: Sequence[_SeedEvidence]) -> List[str]:
    return [item.name for item in seed_evidence if item.strength == "incidental"]


def _need_for_graph(
    *,
    query_profile: _CompetitionProfile,
    core_profile: _CompetitionProfile,
    core_profiles: Sequence[_CompetitionProfile],
    grounded_seeds: Sequence[str],
    rag_confidence: float,
    cfg: PosteriorGraphConfig,
) -> Tuple[bool, str]:
    uncovered_query_entities = query_profile.entities - core_profile.entities
    if uncovered_query_entities:
        return True, "uncovered_query_entities"

    if len(grounded_seeds) >= 2:
        bridge_targets = set(list(grounded_seeds)[:2])
        same_core_hit = any(len(profile.entities & bridge_targets) >= 2 for profile in core_profiles)
        if not same_core_hit:
            return True, "grounded_bridge_gap"

    token_coverage = _safe_ratio(len(core_profile.tokens & query_profile.tokens), len(query_profile.tokens))
    if grounded_seeds and token_coverage < float(cfg.min_query_token_coverage):
        return True, "low_core_query_coverage"

    if grounded_seeds and float(rag_confidence) < float(cfg.grounded_confidence_threshold):
        return True, "low_confidence_grounded"

    return False, "core_already_sufficient"


def _passes_incidental_high_bar(
    retriever: DualPathRetriever,
    candidate: RetrievalResult,
    *,
    query_profile: _CompetitionProfile,
    core_profile: _CompetitionProfile,
    cfg: PosteriorGraphConfig,
) -> bool:
    candidate_profile = _build_candidate_profile(
        retriever,
        candidate,
        max_tokens=cfg.max_candidate_tokens,
    )
    uncovered_query_entities = query_profile.entities - core_profile.entities
    if candidate_profile.entities & uncovered_query_entities:
        return True

    query_relevance = _compute_query_relevance(candidate_profile, query_profile)
    specificity = _compute_specificity(candidate_profile, candidate)
    overlap = _core_overlap(candidate_profile, core_profile)
    gap_fill = _compute_gap_fill(candidate_profile, query_profile, core_profile)

    return bool(
        query_relevance >= float(cfg.incidental_query_relevance_threshold)
        and specificity >= float(cfg.incidental_specificity_threshold)
        and overlap <= float(cfg.incidental_core_overlap_threshold)
        and gap_fill > 0.0
    )
def _linked_core_paragraph_hashes(
    retriever: DualPathRetriever,
    relation_hash: str,
) -> Set[str]:
    rows = retriever.metadata_store.query(
        """
        SELECT paragraph_hash FROM paragraph_relations
        WHERE relation_hash = ?
        """,
        (relation_hash,),
    )
    return {
        str(row.get("paragraph_hash", "") or "").strip()
        for row in rows
        if str(row.get("paragraph_hash", "") or "").strip()
    }


def _build_graph_results_from_seeds(
    retriever: DualPathRetriever,
    *,
    seed_entities: Sequence[str],
    temporal: Any,
    alpha: float,
) -> List[RetrievalResult]:
    from .dual_path import RetrievalResult

    service = getattr(retriever, "_graph_relation_recall", None)
    if service is None:
        return []

    payloads = service.recall(seed_entities=seed_entities)
    if not payloads:
        return []

    graph_results: List[RetrievalResult] = []
    for payload in payloads:
        meta = payload.to_payload()
        graph_results.append(
            RetrievalResult(
                hash_value=str(meta["hash"]),
                content=str(meta["content"]),
                score=0.0,
                result_type="relation",
                source="posterior_graph_recall",
                metadata={
                    "subject": meta["subject"],
                    "predicate": meta["predicate"],
                    "object": meta["object"],
                    "confidence": float(meta["confidence"]),
                    "graph_seed_entities": list(meta["graph_seed_entities"]),
                    "graph_hops": int(meta["graph_hops"]),
                    "graph_candidate_type": str(meta["graph_candidate_type"]),
                    "supporting_paragraph_count": int(meta["supporting_paragraph_count"]),
                },
            )
        )

    graph_results = retriever._apply_temporal_filter_to_relations(graph_results, temporal)
    graph_results = retriever._merge_relation_results_graph_enhanced([], [], graph_results)
    relation_weight = max(0.0, 1.0 - float(alpha))
    for item in graph_results:
        item.score = float(item.score) * relation_weight
        item.source = "posterior_graph_competition"
    return graph_results
def _competition_merge(
    retriever: DualPathRetriever,
    *,
    query: str,
    base_results: Sequence[RetrievalResult],
    graph_results: Sequence[RetrievalResult],
    top_k: int,
    cfg: PosteriorGraphConfig,
) -> List[RetrievalResult]:
    ranked = list(base_results)[: max(1, int(top_k))]
    if not ranked or not graph_results:
        return ranked

    cliff = find_score_cliff(
        ranked,
        drop_ratio=cfg.drop_ratio,
        min_core_results=cfg.min_core_results,
    )
    core_results = ranked[:cliff]
    replaceable_slots = min(
        max(0, int(top_k) - len(core_results)),
        int(cfg.max_graph_slots),
    )
    if replaceable_slots <= 0:
        return ranked[:top_k]

    core_paragraph_hashes = {
        item.hash_value for item in core_results if item.result_type == "paragraph"
    }
    selected_hashes = {item.hash_value for item in core_results}
    filtered_graph_results: List[RetrievalResult] = []
    for item in graph_results:
        if item.hash_value in selected_hashes:
            continue
        linked_hashes = _linked_core_paragraph_hashes(retriever, item.hash_value)
        if core_paragraph_hashes & linked_hashes:
            continue
        filtered_graph_results.append(item)

    tail_candidates: List[RetrievalResult] = []
    for item in ranked[cliff:top_k]:
        if item.hash_value not in selected_hashes:
            tail_candidates.append(item)
    tail_candidates.extend(filtered_graph_results)

    query_profile = _build_query_profile(
        retriever,
        query,
        max_tokens=cfg.max_candidate_tokens,
    )
    core_profile = _build_core_profile(
        retriever,
        core_results,
        max_tokens=cfg.max_candidate_tokens,
    )

    scored_candidates: List[Tuple[RetrievalResult, float]] = []
    for item in tail_candidates:
        competition_score, breakdown = _compute_competition_score(
            retriever,
            item,
            query_profile=query_profile,
            core_profile=core_profile,
            cfg=cfg,
        )
        metadata = dict(item.metadata) if isinstance(item.metadata, dict) else {}
        metadata["posterior_original_score"] = round(float(item.score), 4)
        metadata["posterior_competition_breakdown"] = breakdown
        metadata["posterior_competition_source"] = "posterior_graph_gate"
        item.metadata = metadata
        scored_candidates.append((item, competition_score))

    scored_candidates.sort(
        key=lambda payload: (
            float(payload[1]),
            1 if payload[0].result_type == "relation" else 0,
        ),
        reverse=True,
    )

    tail_winners: List[RetrievalResult] = []
    seen_hashes = set(selected_hashes)
    for item, _ in scored_candidates:
        if item.hash_value in seen_hashes:
            continue
        tail_winners.append(item)
        seen_hashes.add(item.hash_value)
        if len(tail_winners) >= replaceable_slots:
            break

    return (core_results + tail_winners)[:top_k]
def apply_posterior_graph_gate(
    retriever: DualPathRetriever,
    *,
    query: str,
    base_results: Sequence[RetrievalResult],
    top_k: int,
    temporal: Any,
    relation_intent: Dict[str, Any],
) -> List[RetrievalResult]:
    cfg = getattr(retriever.config, "posterior_graph", None)
    if not isinstance(cfg, PosteriorGraphConfig) or not cfg.enabled:
        return list(base_results)[:top_k]
    if not base_results:
        setattr(
            retriever,
            "_posterior_graph_gate_last_decision",
            {
                "scheme": "posterior_graph_gate",
                "query": str(query or ""),
                "enabled": False,
                "bucket": "posterior_gate_empty",
            },
        )
        return []

    top_k_int = max(1, int(top_k))
    alpha_override = relation_intent.get("alpha_override") if isinstance(relation_intent, dict) else None
    alpha = float(alpha_override) if alpha_override is not None else float(retriever.config.alpha)
    rag_confidence = _top_score(list(base_results)[:top_k_int])

    query_profile = _build_query_profile(
        retriever,
        query,
        max_tokens=cfg.max_candidate_tokens,
    )
    seed_evidence = _extract_seed_evidence(
        retriever,
        query_profile,
        base_results,
        scan_top_k=cfg.gate_scan_top_k,
        max_tokens=cfg.max_candidate_tokens,
    )
    grounded_seeds = _grounded_seed_names(seed_evidence)[:2]
    incidental_seeds = _incidental_seed_names(seed_evidence)[:2]

    core_results = list(base_results)[
        : find_score_cliff(
            list(base_results)[:top_k_int],
            drop_ratio=cfg.drop_ratio,
            min_core_results=cfg.min_core_results,
        )
    ]
    core_profile = _build_core_profile(
        retriever,
        core_results,
        max_tokens=cfg.max_candidate_tokens,
    )
    core_profiles = [
        _build_candidate_profile(retriever, item, max_tokens=cfg.max_candidate_tokens)
        for item in core_results
    ]
    need_for_graph, need_reason = _need_for_graph(
        query_profile=query_profile,
        core_profile=core_profile,
        core_profiles=core_profiles,
        grounded_seeds=grounded_seeds,
        rag_confidence=rag_confidence,
        cfg=cfg,
    )

    seed_type = "none"
    seed_names: List[str] = []
    if grounded_seeds and need_for_graph:
        seed_type = "grounded"
        seed_names = grounded_seeds
    elif (
        not grounded_seeds
        and incidental_seeds
        and rag_confidence < float(cfg.incidental_confidence_threshold)
    ):
        seed_type = "incidental"
        seed_names = incidental_seeds

    if not seed_names:
        setattr(
            retriever,
            "_posterior_graph_gate_last_decision",
            {
                "scheme": "posterior_graph_gate",
                "query": str(query or ""),
                "enabled": False,
                "bucket": "posterior_gate_none",
                "grounded_seeds": list(grounded_seeds),
                "incidental_seeds": list(incidental_seeds),
                "selected_seed_type": seed_type,
                "need_for_graph": bool(need_for_graph),
                "need_reason": str(need_reason),
                "rag_confidence": round(float(rag_confidence), 4),
            },
        )
        return list(base_results)[:top_k_int]

    graph_results = _build_graph_results_from_seeds(
        retriever,
        seed_entities=seed_names,
        temporal=temporal,
        alpha=alpha,
    )
    raw_graph_count = len(graph_results)
    if seed_type == "incidental":
        graph_results = [
            item
            for item in graph_results
            if _passes_incidental_high_bar(
                retriever,
                item,
                query_profile=query_profile,
                core_profile=core_profile,
                cfg=cfg,
            )
        ]

    if not graph_results:
        setattr(
            retriever,
            "_posterior_graph_gate_last_decision",
            {
                "scheme": "posterior_graph_gate",
                "query": str(query or ""),
                "enabled": False,
                "bucket": "posterior_gate_graph_filtered",
                "grounded_seeds": list(grounded_seeds),
                "incidental_seeds": list(incidental_seeds),
                "selected_seed_type": seed_type,
                "need_for_graph": bool(need_for_graph),
                "need_reason": str(need_reason),
                "rag_confidence": round(float(rag_confidence), 4),
                "graph_result_count": int(raw_graph_count),
            },
        )
        return list(base_results)[:top_k_int]

    final_results = _competition_merge(
        retriever,
        query=query,
        base_results=base_results,
        graph_results=graph_results,
        top_k=top_k_int,
        cfg=cfg,
    )
    selected_hashes = {item.hash_value for item in final_results}
    graph_selected = any(item.hash_value in selected_hashes for item in graph_results)
    setattr(
        retriever,
        "_posterior_graph_gate_last_decision",
        {
            "scheme": "posterior_graph_gate",
            "query": str(query or ""),
            "enabled": bool(graph_selected),
            "bucket": "posterior_gate_enabled" if graph_selected else "posterior_gate_tail_rejected",
            "grounded_seeds": list(grounded_seeds),
            "incidental_seeds": list(incidental_seeds),
            "selected_seed_type": seed_type,
            "need_for_graph": bool(need_for_graph),
            "need_reason": str(need_reason),
            "rag_confidence": round(float(rag_confidence), 4),
            "graph_result_count": int(raw_graph_count),
            "filtered_graph_count": max(0, raw_graph_count - len(graph_results)),
            "base_top_k_count": min(len(base_results), top_k_int),
        },
    )
    return final_results[:top_k_int]
@@ -306,15 +306,22 @@ class SparseBM25Index:
        rows = self._fallback_substring_search(tokens=tokens, limit=limit)

        results: List[Dict[str, Any]] = []
        token_count = max(1, len(tokens))
        for rank, row in enumerate(rows, start=1):
            bm25_score = float(row.get("bm25_score", 0.0))
            content = str(row.get("content", "") or "")
            content_low = content.lower()
            matched_tokens = [token for token in tokens if token in content_low]
            matched_token_count = len(dict.fromkeys(matched_tokens))
            results.append(
                {
                    "hash": row["hash"],
-                    "content": row["content"],
+                    "content": content,
                    "rank": rank,
                    "bm25_score": bm25_score,
                    "score": -bm25_score,  # smaller bm25 means more relevant here, so negate it as a relative score
                    "matched_token_count": matched_token_count,
                    "matched_token_ratio": matched_token_count / float(token_count),
                }
            )
        return results
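A worked example of the two fields added above (query tokens and content invented): `matched_token_count` deduplicates matched tokens, and the ratio divides by the query token count.

```python
tokens = ["记忆", "检索", "maibot"]                  # 3 query tokens (invented)
content_low = "maibot 的记忆检索与记忆导入".lower()
matched = [t for t in tokens if t in content_low]    # ['记忆', '检索', 'maibot']
matched_token_count = len(dict.fromkeys(matched))    # 3 (duplicates collapsed)
matched_token_ratio = matched_token_count / max(1, len(tokens))  # 1.0
```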
@@ -56,7 +56,7 @@ class ThresholdConfig:
            raise ValueError(f"max_threshold必须在[0, 1]之间: {self.max_threshold}")

        if self.min_threshold >= self.max_threshold:
-            raise ValueError("min_threshold必须小于max_threshold")
+            raise ValueError(f"min_threshold必须小于max_threshold")

        if not 0 <= self.percentile <= 100:
            raise ValueError(f"percentile必须在[0, 100]之间: {self.percentile}")
@@ -3,6 +3,7 @@
from __future__ import annotations

import asyncio
from pathlib import Path
from typing import Any, Callable, Coroutine, cast

from src.common.logger import get_logger
@@ -4,6 +4,7 @@ import asyncio
import json
import pickle
import time
import uuid
from dataclasses import dataclass
from datetime import datetime, timedelta
from pathlib import Path
@@ -18,7 +19,7 @@ from src.services.llm_service import LLMServiceClient

from ...paths import default_data_dir, resolve_repo_path
from ..embedding import create_embedding_api_adapter
-from ..retrieval import RetrievalResult, SparseBM25Config, SparseBM25Index
+from ..retrieval import RetrievalResult, SparseBM25Config, SparseBM25Index, TemporalQueryOptions
from ..storage import GraphStore, MetadataStore, QuantizationType, SparseMatrixFormat, VectorStore
from ..utils.aggregate_query_service import AggregateQueryService
from ..utils.episode_retrieval_service import EpisodeRetrievalService
@@ -13,6 +13,7 @@ from ..retrieval import (
    DynamicThresholdFilter,
    FusionConfig,
    GraphRelationRecallConfig,
    PosteriorGraphConfig,
    RelationIntentConfig,
    RetrievalStrategy,
    SparseBM25Config,
@@ -143,6 +144,9 @@ def build_search_runtime(
    graph_recall_cfg_raw = _safe_dict(
        _get_config_value(plugin_config, "retrieval.search.graph_recall", {}) or {}
    )
    posterior_graph_cfg_raw = _safe_dict(
        _get_config_value(plugin_config, "retrieval.search.posterior_graph", {}) or {}
    )

    try:
        sparse_cfg = SparseBM25Config(**sparse_cfg_raw)
@@ -168,6 +172,12 @@ def build_search_runtime(
        log.warning(f"{prefix_text}[{owner}] graph_recall 配置非法,回退默认: {e}")
        graph_recall_cfg = GraphRelationRecallConfig()

    try:
        posterior_graph_cfg = PosteriorGraphConfig(**posterior_graph_cfg_raw)
    except Exception as e:
        log.warning(f"{prefix_text}[{owner}] posterior_graph 配置非法,回退默认: {e}")
        posterior_graph_cfg = PosteriorGraphConfig()

    try:
        config = DualPathRetrieverConfig(
            top_k_paragraphs=_get_config_value(plugin_config, "retrieval.top_k_paragraphs", 20),
@@ -189,6 +199,7 @@ def build_search_runtime(
            fusion=fusion_cfg,
            relation_intent=relation_intent_cfg,
            graph_recall=graph_recall_cfg,
            posterior_graph=posterior_graph_cfg,
        )

    runtime.retriever = DualPathRetriever(
@@ -9,6 +9,7 @@ from enum import Enum
from pathlib import Path
from typing import Optional, Union, Tuple, List, Dict, Set, Any
from collections import defaultdict
import threading
import asyncio

import numpy as np
@@ -41,6 +42,7 @@ except ImportError:

import contextlib
from src.common.logger import get_logger
from ..utils.hash import compute_hash
from ..utils.io import atomic_write

logger = get_logger("A_Memorix.GraphStore")
@@ -4,6 +4,7 @@
Efficient Faiss-based vector storage and retrieval, supporting SQ8 quantization, append-only disk storage, and memory mapping.
"""

import os
import pickle
import hashlib
import shutil
@@ -190,7 +191,7 @@ class VectorStore:
            self._update_reservoir(batch_vecs)
            # TRAIN_SIZE here defaults to 10k, or can be decided dynamically from the current data volume
            if len(self._reservoir_buffer) >= 10000:
-                logger.info("训练样本达到上限,开始训练...")
+                logger.info(f"训练样本达到上限,开始训练...")
                self._train_and_replay_unlocked()

        self._total_added += len(batch_ids)
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
-from typing import List, Dict, Any
+from typing import List, Dict, Any, Optional, Union
from dataclasses import dataclass, field
from enum import Enum
import hashlib

@@ -1,5 +1,5 @@
import re
-from typing import List
+from typing import List, Dict, Any
from .base import BaseStrategy, ProcessedChunk, KnowledgeType, SourceInfo, ChunkContext

class FactualStrategy(BaseStrategy):

@@ -1,5 +1,5 @@
import re
-from typing import List
+from typing import List, Dict, Any
from .base import BaseStrategy, ProcessedChunk, KnowledgeType, SourceInfo, ChunkContext

class NarrativeStrategy(BaseStrategy):

@@ -1,4 +1,4 @@
-from typing import List
+from typing import List, Dict, Any
from .base import BaseStrategy, ProcessedChunk, KnowledgeType, SourceInfo, ChunkContext, ChunkFlags

class QuoteStrategy(BaseStrategy):

@@ -6,6 +6,7 @@

import hashlib
import re
from typing import Union


def compute_hash(text: str, hash_type: str = "sha256") -> str:

@@ -5,6 +5,7 @@ IO Utilities
"""

import os
import shutil
import contextlib
from pathlib import Path
from typing import Union

@@ -4,7 +4,7 @@
Implements the Aho-Corasick algorithm for multi-pattern matching.
"""

-from typing import List, Dict, Tuple, Set
+from typing import List, Dict, Tuple, Set, Any
from collections import deque


@@ -3,7 +3,7 @@
from __future__ import annotations

import hashlib
-from typing import Any, Dict, List, Sequence, Tuple
+from typing import Any, Dict, List, Optional, Sequence, Tuple

from ..retrieval.dual_path import RetrievalResult
@@ -234,7 +234,7 @@ async def ensure_runtime_self_check(
        sample_text=sample_text,
    )
    try:
-        plugin_or_config._runtime_self_check_report = report
+        setattr(plugin_or_config, "_runtime_self_check_report", report)
    except Exception:
        pass
    return report

@@ -287,7 +287,7 @@ class SearchExecutionService:

        async def _executor() -> Dict[str, Any]:
            original_ppr = bool(getattr(retriever.config, "enable_ppr", True))
-            retriever.config.enable_ppr = bool(request.enable_ppr)
+            setattr(retriever.config, "enable_ppr", bool(request.enable_ppr))
            started_at = time.time()
            try:
                retrieved = await retriever.retrieve(
@@ -380,7 +380,7 @@ class SearchExecutionService:
                elapsed_ms = (time.time() - started_at) * 1000.0
                return {"results": retrieved, "elapsed_ms": elapsed_ms}
            finally:
-                retriever.config.enable_ppr = original_ppr
+                setattr(retriever.config, "enable_ppr", original_ppr)

        dedup_hit = False
        try:
@@ -5,6 +5,7 @@
imports into A_memorix's storage components.
"""

from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import json

@@ -1,6 +1,6 @@
from __future__ import annotations

-from typing import Any, Dict
+from typing import Any, Dict, Optional

_runtime_kernel: Any = None


@@ -15,7 +15,7 @@ for _path in (SRC_ROOT, PROJECT_ROOT, PLUGIN_ROOT):
    if _path_str not in sys.path:
        sys.path.insert(0, _path_str)

-from A_memorix.paths import config_path, default_data_dir
+from A_memorix.paths import config_path, default_data_dir, resolve_repo_path

DEFAULT_CONFIG_PATH = config_path()
DEFAULT_DATA_DIR = default_data_dir()

@@ -10,12 +10,14 @@ LPMM 到 A_memorix 存储转换器
"""

import sys
import os
import json
import argparse
import asyncio
import pickle
import logging
from pathlib import Path
-from typing import Dict, Any, Tuple
+from typing import Dict, Any, List, Tuple
import numpy as np
import tomlkit


@@ -12,14 +12,17 @@

from datetime import datetime
from pathlib import Path
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional

from rich.console import Console
from rich.progress import BarColumn, Progress, SpinnerColumn, TextColumn, TimeElapsedColumn
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
import argparse
import asyncio
import hashlib
import json
import os
import random
import sys
import time
import tomlkit

@@ -17,7 +17,7 @@ import sqlite3
import sys
from dataclasses import dataclass
from pathlib import Path
-from typing import Any, Dict, List, Optional, Sequence, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple

import tomlkit