feat:新增 A_Memorix 记忆插件

引入 A_Memorix 插件(v2.0.0)——一个轻量级的长期记忆提供器。新增插件清单(manifest)和入口(AMemorixPlugin),并提供完整的核心能力:嵌入(基于哈希的 EmbeddingAPIAdapter、EmbeddingManager、预设)、检索(双路径检索器、PageRank、图关系召回、BM25 稀疏索引、阈值与融合配置)、存储与元数据层,以及大量实用工具和迁移/转换脚本。同时更新 .gitignore 以允许 /plugins/A_memorix。该变更为在宿主应用中实现统一的记忆摄取、检索、分析与维护奠定了基础。
This commit is contained in:
DawnARC
2026-03-18 21:33:15 +08:00
parent a5a6d2cb26
commit 999e7246e2
48 changed files with 17070 additions and 0 deletions

View File

@@ -0,0 +1,33 @@
"""工具模块 - 哈希、监控等辅助功能"""
from .hash import compute_hash, normalize_text
from .monitor import MemoryMonitor
from .quantization import quantize_vector, dequantize_vector
from .time_parser import (
parse_query_datetime_to_timestamp,
parse_query_time_range,
parse_ingest_datetime_to_timestamp,
normalize_time_meta,
format_timestamp,
)
from .relation_write_service import RelationWriteService, RelationWriteResult
from .relation_query import RelationQuerySpec, parse_relation_query_spec
from .plugin_id_policy import PluginIdPolicy
__all__ = [
"compute_hash",
"normalize_text",
"MemoryMonitor",
"quantize_vector",
"dequantize_vector",
"parse_query_datetime_to_timestamp",
"parse_query_time_range",
"parse_ingest_datetime_to_timestamp",
"normalize_time_meta",
"format_timestamp",
"RelationWriteService",
"RelationWriteResult",
"RelationQuerySpec",
"parse_relation_query_spec",
"PluginIdPolicy",
]

View File

@@ -0,0 +1,360 @@
"""
聚合查询服务:
- 并发执行 search/time/episode 分支
- 统一分支结果结构
- 可选混合排序Weighted RRF
"""
from __future__ import annotations
import asyncio
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
from src.common.logger import get_logger
logger = get_logger("A_Memorix.AggregateQueryService")
BranchRunner = Callable[[], Awaitable[Dict[str, Any]]]
class AggregateQueryService:
"""聚合查询执行服务search/time/episode"""
def __init__(self, plugin_config: Optional[Any] = None):
self.plugin_config = plugin_config or {}
def _cfg(self, key: str, default: Any = None) -> Any:
getter = getattr(self.plugin_config, "get_config", None)
if callable(getter):
return getter(key, default)
current: Any = self.plugin_config
for part in key.split("."):
if isinstance(current, dict) and part in current:
current = current[part]
else:
return default
return current
@staticmethod
def _as_float(value: Any, default: float = 0.0) -> float:
try:
return float(value)
except Exception:
return float(default)
@staticmethod
def _as_int(value: Any, default: int = 0) -> int:
try:
return int(value)
except Exception:
return int(default)
def _rrf_k(self) -> float:
raw = self._cfg("retrieval.aggregate.rrf_k", 60.0)
value = self._as_float(raw, 60.0)
return max(1.0, value)
def _weights(self) -> Dict[str, float]:
defaults = {"search": 1.0, "time": 1.0, "episode": 1.0}
raw = self._cfg("retrieval.aggregate.weights", {})
if not isinstance(raw, dict):
return defaults
out = dict(defaults)
for key in ("search", "time", "episode"):
if key in raw:
out[key] = max(0.0, self._as_float(raw.get(key), defaults[key]))
return out
@staticmethod
def _normalize_branch_payload(
name: str,
payload: Optional[Dict[str, Any]],
) -> Dict[str, Any]:
data = payload if isinstance(payload, dict) else {}
results_raw = data.get("results", [])
results = results_raw if isinstance(results_raw, list) else []
count = data.get("count")
if count is None:
count = len(results)
return {
"name": name,
"success": bool(data.get("success", False)),
"skipped": bool(data.get("skipped", False)),
"skip_reason": str(data.get("skip_reason", "") or "").strip(),
"error": str(data.get("error", "") or "").strip(),
"results": results,
"count": max(0, int(count)),
"elapsed_ms": max(0.0, float(data.get("elapsed_ms", 0.0) or 0.0)),
"content": str(data.get("content", "") or ""),
"query_type": str(data.get("query_type", "") or name),
}
@staticmethod
def _mix_key(item: Dict[str, Any], branch: str, rank: int) -> str:
item_type = str(item.get("type", "") or "").strip().lower()
if item_type == "episode":
episode_id = str(item.get("episode_id", "") or "").strip()
if episode_id:
return f"episode:{episode_id}"
item_hash = str(item.get("hash", "") or "").strip()
if item_hash:
return f"{item_type}:{item_hash}"
return f"{branch}:{item_type}:{rank}:{str(item.get('content', '') or '')[:80]}"
def _build_mixed_results(
self,
*,
branches: Dict[str, Dict[str, Any]],
top_k: int,
) -> List[Dict[str, Any]]:
rrf_k = self._rrf_k()
weights = self._weights()
bucket: Dict[str, Dict[str, Any]] = {}
for branch_name, branch in branches.items():
if not branch.get("success", False):
continue
results = branch.get("results", [])
if not isinstance(results, list):
continue
weight = max(0.0, float(weights.get(branch_name, 1.0)))
for idx, item in enumerate(results, start=1):
if not isinstance(item, dict):
continue
key = self._mix_key(item, branch_name, idx)
score = weight / (rrf_k + float(idx))
if key not in bucket:
merged = dict(item)
merged["fusion_score"] = 0.0
merged["_source_branches"] = set()
bucket[key] = merged
target = bucket[key]
target["fusion_score"] = float(target.get("fusion_score", 0.0)) + score
target["_source_branches"].add(branch_name)
mixed = list(bucket.values())
mixed.sort(
key=lambda x: (
-float(x.get("fusion_score", 0.0)),
str(x.get("type", "") or ""),
str(x.get("hash", "") or x.get("episode_id", "") or ""),
)
)
out: List[Dict[str, Any]] = []
for rank, item in enumerate(mixed[: max(1, int(top_k))], start=1):
merged = dict(item)
branches_set = merged.pop("_source_branches", set())
merged["source_branches"] = sorted(list(branches_set))
merged["rank"] = rank
out.append(merged)
return out
@staticmethod
def _status(branch: Dict[str, Any]) -> str:
if branch.get("skipped", False):
return "skipped"
if branch.get("success", False):
return "success"
return "failed"
def _build_summary(self, branches: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
summary: Dict[str, Dict[str, Any]] = {}
for name, branch in branches.items():
status = self._status(branch)
summary[name] = {
"status": status,
"count": int(branch.get("count", 0) or 0),
}
if status == "skipped":
summary[name]["reason"] = str(branch.get("skip_reason", "") or "")
if status == "failed":
summary[name]["error"] = str(branch.get("error", "") or "")
return summary
def _build_content(
self,
*,
query: str,
branches: Dict[str, Dict[str, Any]],
errors: List[Dict[str, str]],
mixed_results: Optional[List[Dict[str, Any]]],
) -> str:
lines: List[str] = [
f"🔀 聚合查询结果query='{query or 'N/A'}'",
"",
"分支状态:",
]
for name in ("search", "time", "episode"):
branch = branches.get(name, {})
status = self._status(branch)
count = int(branch.get("count", 0) or 0)
line = f"- {name}: {status}, count={count}"
reason = str(branch.get("skip_reason", "") or "").strip()
err = str(branch.get("error", "") or "").strip()
if status == "skipped" and reason:
line += f" ({reason})"
if status == "failed" and err:
line += f" ({err})"
lines.append(line)
if errors:
lines.append("")
lines.append("错误:")
for item in errors[:6]:
lines.append(f"- {item.get('branch', 'unknown')}: {item.get('error', 'unknown error')}")
if mixed_results is not None:
lines.append("")
lines.append(f"🧩 混合结果({len(mixed_results)} 条):")
for idx, item in enumerate(mixed_results[:5], start=1):
src = ",".join(item.get("source_branches", []) or [])
if str(item.get("type", "") or "") == "episode":
title = str(item.get("title", "") or "Untitled")
lines.append(f"{idx}. 🧠 {title} [{src}]")
else:
text = str(item.get("content", "") or "")
if len(text) > 80:
text = text[:80] + "..."
lines.append(f"{idx}. {text} [{src}]")
return "\n".join(lines)
async def execute(
self,
*,
query: str,
top_k: int,
mix: bool,
mix_top_k: Optional[int],
time_from: Optional[str],
time_to: Optional[str],
search_runner: Optional[BranchRunner],
time_runner: Optional[BranchRunner],
episode_runner: Optional[BranchRunner],
) -> Dict[str, Any]:
clean_query = str(query or "").strip()
safe_top_k = max(1, int(top_k))
safe_mix_top_k = max(1, int(mix_top_k if mix_top_k is not None else safe_top_k))
branches: Dict[str, Dict[str, Any]] = {}
errors: List[Dict[str, str]] = []
scheduled: List[Tuple[str, asyncio.Task]] = []
if clean_query:
if search_runner is not None:
scheduled.append(("search", asyncio.create_task(search_runner())))
else:
branches["search"] = self._normalize_branch_payload(
"search",
{"success": False, "error": "search runner unavailable", "results": []},
)
else:
branches["search"] = self._normalize_branch_payload(
"search",
{
"success": False,
"skipped": True,
"skip_reason": "missing_query",
"results": [],
"count": 0,
},
)
if time_from or time_to:
if time_runner is not None:
scheduled.append(("time", asyncio.create_task(time_runner())))
else:
branches["time"] = self._normalize_branch_payload(
"time",
{"success": False, "error": "time runner unavailable", "results": []},
)
else:
branches["time"] = self._normalize_branch_payload(
"time",
{
"success": False,
"skipped": True,
"skip_reason": "missing_time_window",
"results": [],
"count": 0,
},
)
if episode_runner is not None:
scheduled.append(("episode", asyncio.create_task(episode_runner())))
else:
branches["episode"] = self._normalize_branch_payload(
"episode",
{"success": False, "error": "episode runner unavailable", "results": []},
)
if scheduled:
done = await asyncio.gather(
*[task for _, task in scheduled],
return_exceptions=True,
)
for (branch_name, _), payload in zip(scheduled, done):
if isinstance(payload, Exception):
logger.error("aggregate branch failed: branch=%s error=%s", branch_name, payload)
normalized = self._normalize_branch_payload(
branch_name,
{
"success": False,
"error": str(payload),
"results": [],
},
)
else:
normalized = self._normalize_branch_payload(branch_name, payload)
branches[branch_name] = normalized
for name in ("search", "time", "episode"):
branch = branches.get(name)
if not branch:
continue
if branch.get("skipped", False):
continue
if not branch.get("success", False):
errors.append(
{
"branch": name,
"error": str(branch.get("error", "") or "unknown error"),
}
)
success = any(
bool(branches.get(name, {}).get("success", False))
for name in ("search", "time", "episode")
)
mixed_results: Optional[List[Dict[str, Any]]] = None
if mix:
mixed_results = self._build_mixed_results(branches=branches, top_k=safe_mix_top_k)
payload: Dict[str, Any] = {
"success": success,
"query_type": "aggregate",
"query": clean_query,
"top_k": safe_top_k,
"mix": bool(mix),
"mix_top_k": safe_mix_top_k,
"branches": branches,
"errors": errors,
"summary": self._build_summary(branches),
}
if mixed_results is not None:
payload["mixed_results"] = mixed_results
payload["content"] = self._build_content(
query=clean_query,
branches=branches,
errors=errors,
mixed_results=mixed_results,
)
return payload

View File

@@ -0,0 +1,182 @@
"""Episode hybrid retrieval service."""
from __future__ import annotations
from typing import Any, Dict, List, Optional
from src.common.logger import get_logger
from ..retrieval import DualPathRetriever, TemporalQueryOptions
logger = get_logger("A_Memorix.EpisodeRetrievalService")
class EpisodeRetrievalService:
"""Hybrid episode retrieval backed by lexical rows and evidence projection."""
_RRF_K = 60.0
_BRANCH_WEIGHTS = {
"lexical": 1.0,
"paragraph_evidence": 1.0,
"relation_evidence": 0.85,
}
def __init__(
self,
*,
metadata_store: Any,
retriever: Optional[DualPathRetriever] = None,
) -> None:
self.metadata_store = metadata_store
self.retriever = retriever
async def query(
self,
*,
query: str = "",
top_k: int = 5,
time_from: Optional[float] = None,
time_to: Optional[float] = None,
person: Optional[str] = None,
source: Optional[str] = None,
include_paragraphs: bool = False,
) -> List[Dict[str, Any]]:
clean_query = str(query or "").strip()
safe_top_k = max(1, int(top_k))
candidate_k = max(30, safe_top_k * 6)
branches: Dict[str, List[Dict[str, Any]]] = {
"lexical": self.metadata_store.query_episodes(
query=clean_query,
time_from=time_from,
time_to=time_to,
person=person,
source=source,
limit=(candidate_k if clean_query else safe_top_k),
)
}
if clean_query and self.retriever is not None:
try:
temporal = TemporalQueryOptions(
time_from=time_from,
time_to=time_to,
person=person,
source=source,
)
results = await self.retriever.retrieve(
query=clean_query,
top_k=candidate_k,
temporal=temporal,
)
except Exception as exc:
logger.warning("episode evidence retrieval failed, fallback to lexical only: %s", exc)
else:
paragraph_rank_map: Dict[str, int] = {}
relation_rank_map: Dict[str, int] = {}
for rank, item in enumerate(results, start=1):
hash_value = str(getattr(item, "hash_value", "") or "").strip()
result_type = str(getattr(item, "result_type", "") or "").strip().lower()
if not hash_value:
continue
if result_type == "paragraph" and hash_value not in paragraph_rank_map:
paragraph_rank_map[hash_value] = rank
elif result_type == "relation" and hash_value not in relation_rank_map:
relation_rank_map[hash_value] = rank
if paragraph_rank_map:
paragraph_rows = self.metadata_store.get_episode_rows_by_paragraph_hashes(
list(paragraph_rank_map.keys()),
source=source,
)
if paragraph_rows:
branches["paragraph_evidence"] = self._rank_projected_rows(
paragraph_rows,
rank_map=paragraph_rank_map,
support_key="matched_paragraph_hashes",
)
if relation_rank_map:
relation_rows = self.metadata_store.get_episode_rows_by_relation_hashes(
list(relation_rank_map.keys()),
source=source,
)
if relation_rows:
branches["relation_evidence"] = self._rank_projected_rows(
relation_rows,
rank_map=relation_rank_map,
support_key="matched_relation_hashes",
)
fused = self._fuse_branches(branches, top_k=safe_top_k)
if include_paragraphs:
for item in fused:
item["paragraphs"] = self.metadata_store.get_episode_paragraphs(
episode_id=str(item.get("episode_id") or ""),
limit=50,
)
return fused
@staticmethod
def _rank_projected_rows(
rows: List[Dict[str, Any]],
*,
rank_map: Dict[str, int],
support_key: str,
) -> List[Dict[str, Any]]:
sentinel = 10**9
ranked = [dict(item) for item in rows]
def _first_support_rank(item: Dict[str, Any]) -> int:
support_hashes = [str(x or "").strip() for x in (item.get(support_key) or [])]
ranks = [int(rank_map[h]) for h in support_hashes if h in rank_map]
return min(ranks) if ranks else sentinel
ranked.sort(
key=lambda item: (
_first_support_rank(item),
-int(item.get("matched_paragraph_count") or 0),
-float(item.get("updated_at") or 0.0),
str(item.get("episode_id") or ""),
)
)
return ranked
def _fuse_branches(
self,
branches: Dict[str, List[Dict[str, Any]]],
*,
top_k: int,
) -> List[Dict[str, Any]]:
bucket: Dict[str, Dict[str, Any]] = {}
for branch_name, rows in branches.items():
weight = float(self._BRANCH_WEIGHTS.get(branch_name, 0.0) or 0.0)
if weight <= 0.0:
continue
for rank, row in enumerate(rows, start=1):
episode_id = str(row.get("episode_id", "") or "").strip()
if not episode_id:
continue
if episode_id not in bucket:
payload = dict(row)
payload.pop("matched_paragraph_hashes", None)
payload.pop("matched_relation_hashes", None)
payload.pop("matched_paragraph_count", None)
payload.pop("matched_relation_count", None)
payload["_fusion_score"] = 0.0
bucket[episode_id] = payload
bucket[episode_id]["_fusion_score"] = float(
bucket[episode_id].get("_fusion_score", 0.0)
) + weight / (self._RRF_K + float(rank))
out = list(bucket.values())
out.sort(
key=lambda item: (
-float(item.get("_fusion_score", 0.0)),
-float(item.get("updated_at") or 0.0),
str(item.get("episode_id") or ""),
)
)
for item in out:
item.pop("_fusion_score", None)
return out[: max(1, int(top_k))]

View File

@@ -0,0 +1,129 @@
"""
哈希工具模块
提供文本哈希计算功能,用于唯一标识和去重。
"""
import hashlib
import re
from typing import Union
def compute_hash(text: str, hash_type: str = "sha256") -> str:
"""
计算文本的哈希值
Args:
text: 输入文本
hash_type: 哈希算法类型sha256, md5等
Returns:
哈希值字符串
"""
if hash_type == "sha256":
return hashlib.sha256(text.encode("utf-8")).hexdigest()
elif hash_type == "md5":
return hashlib.md5(text.encode("utf-8")).hexdigest()
else:
raise ValueError(f"不支持的哈希算法: {hash_type}")
def normalize_text(text: str) -> str:
"""
规范化文本用于哈希计算
执行以下操作:
- 去除首尾空白
- 统一换行符为\\n
- 压缩多个连续空格
- 去除不可见字符(保留换行和制表符)
Args:
text: 输入文本
Returns:
规范化后的文本
"""
# 去除首尾空白
text = text.strip()
# 统一换行符
text = text.replace("\r\n", "\n").replace("\r", "\n")
# 压缩多个连续空格为一个(但保留换行和制表符)
text = re.sub(r"[^\S\n]+", " ", text)
return text
def compute_paragraph_hash(paragraph: str) -> str:
"""
计算段落的哈希值
Args:
paragraph: 段落文本
Returns:
段落哈希值用于paragraph-前缀)
"""
normalized = normalize_text(paragraph)
return compute_hash(normalized)
def compute_entity_hash(entity: str) -> str:
"""
计算实体的哈希值
Args:
entity: 实体名称
Returns:
实体哈希值用于entity-前缀)
"""
normalized = entity.strip().lower()
return compute_hash(normalized)
def compute_relation_hash(relation: tuple) -> str:
"""
计算关系的哈希值
Args:
relation: 关系元组 (subject, predicate, object)
Returns:
关系哈希值用于relation-前缀)
"""
# 将关系元组转为字符串
relation_str = str(tuple(relation))
return compute_hash(relation_str)
def format_hash_key(hash_type: str, hash_value: str) -> str:
"""
格式化哈希键
Args:
hash_type: 类型前缀paragraph, entity, relation
hash_value: 哈希值
Returns:
格式化的键(如 paragraph-abc123...
"""
return f"{hash_type}-{hash_value}"
def parse_hash_key(key: str) -> tuple[str, str]:
"""
解析哈希键
Args:
key: 格式化的键(如 paragraph-abc123...
Returns:
(类型, 哈希值) 元组
"""
parts = key.split("-", 1)
if len(parts) != 2:
raise ValueError(f"无效的哈希键格式: {key}")
return parts[0], parts[1]

View File

@@ -0,0 +1,110 @@
"""Shared import payload normalization helpers."""
from __future__ import annotations
from typing import Any, Dict, List, Optional
from ..storage import KnowledgeType, resolve_stored_knowledge_type
from .time_parser import normalize_time_meta
def _normalize_entities(raw_entities: Any) -> List[str]:
if not isinstance(raw_entities, list):
return []
out: List[str] = []
seen = set()
for item in raw_entities:
name = str(item or "").strip()
if not name:
continue
key = name.lower()
if key in seen:
continue
seen.add(key)
out.append(name)
return out
def _normalize_relations(raw_relations: Any) -> List[Dict[str, str]]:
if not isinstance(raw_relations, list):
return []
out: List[Dict[str, str]] = []
for item in raw_relations:
if not isinstance(item, dict):
continue
subject = str(item.get("subject", "")).strip()
predicate = str(item.get("predicate", "")).strip()
obj = str(item.get("object", "")).strip()
if not (subject and predicate and obj):
continue
out.append(
{
"subject": subject,
"predicate": predicate,
"object": obj,
}
)
return out
def normalize_paragraph_import_item(
item: Any,
*,
default_source: str,
) -> Dict[str, Any]:
"""Normalize one paragraph import item from text/json payloads."""
if isinstance(item, str):
content = str(item)
knowledge_type = resolve_stored_knowledge_type(None, content=content)
return {
"content": content,
"knowledge_type": knowledge_type.value,
"source": str(default_source or "").strip(),
"time_meta": None,
"entities": [],
"relations": [],
}
if not isinstance(item, dict) or "content" not in item:
raise ValueError("段落项必须为字符串或包含 content 的对象")
content = str(item.get("content", "") or "")
if not content.strip():
raise ValueError("段落 content 不能为空")
raw_time_meta = {
"event_time": item.get("event_time"),
"event_time_start": item.get("event_time_start"),
"event_time_end": item.get("event_time_end"),
"time_range": item.get("time_range"),
"time_granularity": item.get("time_granularity"),
"time_confidence": item.get("time_confidence"),
}
time_meta_field = item.get("time_meta")
if isinstance(time_meta_field, dict):
raw_time_meta.update(time_meta_field)
knowledge_type_raw = item.get("knowledge_type")
if knowledge_type_raw is None:
knowledge_type_raw = item.get("type")
knowledge_type = resolve_stored_knowledge_type(knowledge_type_raw, content=content)
source = str(item.get("source") or default_source or "").strip()
if not source:
source = str(default_source or "").strip()
normalized_time_meta = normalize_time_meta(raw_time_meta)
return {
"content": content,
"knowledge_type": knowledge_type.value,
"source": source,
"time_meta": normalized_time_meta if normalized_time_meta else None,
"entities": _normalize_entities(item.get("entities")),
"relations": _normalize_relations(item.get("relations")),
}
def normalize_summary_knowledge_type(value: Any) -> KnowledgeType:
"""Normalize config-driven summary knowledge type."""
return resolve_stored_knowledge_type(value, content="")

View File

@@ -0,0 +1,84 @@
"""
IO Utilities
提供原子文件写入等IO辅助功能。
"""
import os
import shutil
import contextlib
from pathlib import Path
from typing import Union
@contextlib.contextmanager
def atomic_write(file_path: Union[str, Path], mode: str = "w", encoding: str = None, **kwargs):
"""
原子文件写入上下文管理器
原理:
1. 写入 .tmp 临时文件
2. 写入成功后,使用 os.replace 原子替换目标文件
3. 如果失败,自动删除临时文件
Args:
file_path: 目标文件路径
mode: 打开模式 ('w', 'wb' 等)
encoding: 编码
**kwargs: 传给 open() 的其他参数
"""
path = Path(file_path)
# 确保父目录存在
path.parent.mkdir(parents=True, exist_ok=True)
# 临时文件路径
tmp_path = path.with_suffix(path.suffix + ".tmp")
try:
with open(tmp_path, mode, encoding=encoding, **kwargs) as f:
yield f
# 确保写入磁盘
f.flush()
os.fsync(f.fileno())
# 原子替换 (Windows下可能需要先删除目标文件但 os.replace 在 Py3.3+ 尽可能原子)
# 注意: Windows 上如果有其他进程占用文件os.replace 可能会失败
os.replace(tmp_path, path)
except Exception as e:
# 清理临时文件
if tmp_path.exists():
try:
os.remove(tmp_path)
except:
pass
raise e
@contextlib.contextmanager
def atomic_save_path(file_path: Union[str, Path]):
"""
提供临时路径用于原子保存 (针对只接受路径的API如Faiss)
Args:
file_path: 最终目标路径
Yields:
tmp_path: 临时文件路径 (str)
"""
path = Path(file_path)
path.parent.mkdir(parents=True, exist_ok=True)
tmp_path = path.with_suffix(path.suffix + ".tmp")
try:
yield str(tmp_path)
if Path(tmp_path).exists():
os.replace(tmp_path, path)
except Exception as e:
if Path(tmp_path).exists():
try:
os.remove(tmp_path)
except:
pass
raise e

View File

@@ -0,0 +1,89 @@
"""
高效文本匹配工具模块
实现 Aho-Corasick 算法用于多模式匹配。
"""
from typing import List, Dict, Tuple, Set, Any
from collections import deque
class AhoCorasick:
"""
Aho-Corasick 自动机实现高效多模式匹配
"""
def __init__(self):
# next_states[state][char] = next_state
self.next_states: List[Dict[str, int]] = [{}]
# fail[state] = fail_state
self.fail: List[int] = [0]
# output[state] = set of patterns ending at this state
self.output: List[Set[str]] = [set()]
self.patterns: Set[str] = set()
def add_pattern(self, pattern: str):
"""添加模式"""
if not pattern:
return
self.patterns.add(pattern)
state = 0
for char in pattern:
if char not in self.next_states[state]:
new_state = len(self.next_states)
self.next_states[state][char] = new_state
self.next_states.append({})
self.fail.append(0)
self.output.append(set())
state = self.next_states[state][char]
self.output[state].add(pattern)
def build(self):
"""构建失败指针"""
queue = deque()
# 处理第一层
for char, state in self.next_states[0].items():
queue.append(state)
self.fail[state] = 0
while queue:
r = queue.popleft()
for char, s in self.next_states[r].items():
queue.append(s)
# 找到失败路径
state = self.fail[r]
while char not in self.next_states[state] and state != 0:
state = self.fail[state]
self.fail[s] = self.next_states[state].get(char, 0)
# 合并输出
self.output[s].update(self.output[self.fail[s]])
def search(self, text: str) -> List[Tuple[int, str]]:
"""
在文本中搜索所有模式
Returns:
[(结束索引, 匹配到的模式), ...]
"""
state = 0
results = []
for i, char in enumerate(text):
while char not in self.next_states[state] and state != 0:
state = self.fail[state]
state = self.next_states[state].get(char, 0)
for pattern in self.output[state]:
results.append((i, pattern))
return results
def find_all(self, text: str) -> Dict[str, int]:
"""
查找并统计所有模式出现次数
Returns:
{模式: 出现次数}
"""
results = self.search(text)
stats = {}
for _, pattern in results:
stats[pattern] = stats.get(pattern, 0) + 1
return stats

View File

@@ -0,0 +1,189 @@
"""
内存监控模块
提供内存使用监控和预警功能。
"""
import gc
import threading
import time
from typing import Callable, Optional
try:
import psutil
HAS_PSUTIL = True
except ImportError:
HAS_PSUTIL = False
from src.common.logger import get_logger
logger = get_logger("A_Memorix.MemoryMonitor")
class MemoryMonitor:
"""
内存监控器
功能:
- 实时监控内存使用
- 超过阈值时触发警告
- 支持自动垃圾回收
"""
def __init__(
self,
max_memory_mb: int,
warning_threshold: float = 0.9,
check_interval: float = 10.0,
enable_auto_gc: bool = True,
):
"""
初始化内存监控器
Args:
max_memory_mb: 最大内存限制MB
warning_threshold: 警告阈值0-1之间默认0.9表示90%
check_interval: 检查间隔(秒)
enable_auto_gc: 是否启用自动垃圾回收
"""
self.max_memory_mb = max_memory_mb
self.warning_threshold = warning_threshold
self.check_interval = check_interval
self.enable_auto_gc = enable_auto_gc
self._running = False
self._thread: Optional[threading.Thread] = None
self._callbacks: list[Callable[[float, float], None]] = []
def start(self):
"""启动监控"""
if self._running:
logger.warning("内存监控已在运行")
return
if not HAS_PSUTIL:
logger.warning("psutil 未安装,内存监控功能不可用")
return
self._running = True
self._thread = threading.Thread(target=self._monitor_loop, daemon=True)
self._thread.start()
logger.info(f"内存监控已启动 (限制: {self.max_memory_mb}MB)")
def stop(self):
"""停止监控"""
if not self._running:
return
self._running = False
if self._thread:
self._thread.join(timeout=5.0)
logger.info("内存监控已停止")
def register_callback(self, callback: Callable[[float, float], None]):
"""
注册内存超限回调函数
Args:
callback: 回调函数,接收 (当前使用MB, 限制MB) 参数
"""
self._callbacks.append(callback)
def get_current_memory_mb(self) -> float:
"""
获取当前进程内存使用量MB
Returns:
内存使用量MB
"""
if not HAS_PSUTIL:
# 降级方案:使用内置函数
import sys
return sys.getsizeof(gc.get_objects()) / 1024 / 1024
process = psutil.Process()
return process.memory_info().rss / 1024 / 1024
def get_memory_usage_ratio(self) -> float:
"""
获取内存使用率
Returns:
使用率0-1之间
"""
current = self.get_current_memory_mb()
return current / self.max_memory_mb if self.max_memory_mb > 0 else 0
def _monitor_loop(self):
"""监控循环"""
while self._running:
try:
current_mb = self.get_current_memory_mb()
ratio = current_mb / self.max_memory_mb if self.max_memory_mb > 0 else 0
# 检查是否超过阈值
if ratio >= self.warning_threshold:
logger.warning(
f"内存使用率过高: {current_mb:.1f}MB / {self.max_memory_mb}MB "
f"({ratio*100:.1f}%)"
)
# 触发回调
for callback in self._callbacks:
try:
callback(current_mb, self.max_memory_mb)
except Exception as e:
logger.error(f"内存回调执行失败: {e}")
# 自动垃圾回收
if self.enable_auto_gc:
before = self.get_current_memory_mb()
gc.collect()
after = self.get_current_memory_mb()
freed = before - after
if freed > 1: # 释放超过1MB才记录
logger.info(f"垃圾回收释放: {freed:.1f}MB")
# 定期垃圾回收(即使未超限)
elif self.enable_auto_gc and int(time.time()) % 60 == 0:
gc.collect()
except Exception as e:
logger.error(f"内存监控出错: {e}")
# 等待下次检查
time.sleep(self.check_interval)
def __enter__(self):
"""上下文管理器入口"""
self.start()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""上下文管理器出口"""
self.stop()
def get_memory_info() -> dict:
"""
获取系统内存信息
Returns:
内存信息字典
"""
if not HAS_PSUTIL:
return {"error": "psutil 未安装"}
try:
mem = psutil.virtual_memory()
process = psutil.Process()
return {
"system_total_gb": mem.total / 1024 / 1024 / 1024,
"system_available_gb": mem.available / 1024 / 1024 / 1024,
"system_usage_percent": mem.percent,
"process_mb": process.memory_info().rss / 1024 / 1024,
"process_percent": (process.memory_info().rss / mem.total) * 100,
}
except Exception as e:
return {"error": str(e)}

View File

@@ -0,0 +1,165 @@
"""Shared path-fallback helpers for search post-processing."""
from __future__ import annotations
import hashlib
from typing import Any, Dict, List, Optional, Sequence, Tuple
from ..retrieval.dual_path import RetrievalResult
def extract_entities(query: str, graph_store: Any) -> List[str]:
"""Extract up to two graph nodes from a query using n-gram matching."""
if not graph_store:
return []
text = str(query or "").strip()
if not text:
return []
# Keep the heuristic aligned with previous legacy behavior.
tokens = (
text.replace("?", " ")
.replace("!", " ")
.replace(".", " ")
.split()
)
if not tokens:
return []
found_entities = set()
skip_indices = set()
max_n = min(4, len(tokens))
for size in range(max_n, 0, -1):
for i in range(len(tokens) - size + 1):
if any(idx in skip_indices for idx in range(i, i + size)):
continue
span = " ".join(tokens[i : i + size])
matched_node = graph_store.find_node(span, ignore_case=True)
if not matched_node:
continue
found_entities.add(matched_node)
for idx in range(i, i + size):
skip_indices.add(idx)
return list(found_entities)
def find_paths_between_entities(
start_node: str,
end_node: str,
graph_store: Any,
metadata_store: Any,
*,
max_depth: int = 3,
max_paths: int = 5,
) -> List[Dict[str, Any]]:
"""Find and enrich indirect paths between two nodes."""
if not graph_store or not metadata_store:
return []
try:
paths = graph_store.find_paths(
start_node,
end_node,
max_depth=max_depth,
max_paths=max_paths,
)
except Exception:
return []
if not paths:
return []
edge_cache: Dict[Tuple[str, str], Tuple[str, str]] = {}
formatted_paths: List[Dict[str, Any]] = []
for path_nodes in paths:
if not isinstance(path_nodes, Sequence) or len(path_nodes) < 2:
continue
path_desc: List[str] = []
for i in range(len(path_nodes) - 1):
u = str(path_nodes[i])
v = str(path_nodes[i + 1])
cache_key = tuple(sorted((u, v)))
if cache_key in edge_cache:
pred, direction = edge_cache[cache_key]
else:
pred = "related"
direction = "->"
rels = metadata_store.get_relations(subject=u, object=v)
if not rels:
rels = metadata_store.get_relations(subject=v, object=u)
direction = "<-"
if rels:
best_rel = max(rels, key=lambda x: x.get("confidence", 1.0))
pred = str(best_rel.get("predicate", "related") or "related")
edge_cache[cache_key] = (pred, direction)
step_str = f"-[{pred}]->" if direction == "->" else f"<-[{pred}]-"
path_desc.append(step_str)
full_path_str = str(path_nodes[0])
for i, step in enumerate(path_desc):
full_path_str += f" {step} {path_nodes[i + 1]}"
formatted_paths.append(
{
"nodes": list(path_nodes),
"description": full_path_str,
}
)
return formatted_paths
def find_paths_from_query(
query: str,
graph_store: Any,
metadata_store: Any,
*,
max_depth: int = 3,
max_paths: int = 5,
) -> List[Dict[str, Any]]:
"""Extract entities from query and resolve indirect paths."""
entities = extract_entities(query, graph_store)
if len(entities) != 2:
return []
return find_paths_between_entities(
entities[0],
entities[1],
graph_store,
metadata_store,
max_depth=max_depth,
max_paths=max_paths,
)
def to_retrieval_results(paths: Sequence[Dict[str, Any]]) -> List[RetrievalResult]:
"""Convert path results into retrieval results for the unified pipeline."""
converted: List[RetrievalResult] = []
for item in paths:
description = str(item.get("description", "")).strip()
if not description:
continue
hash_seed = description.encode("utf-8")
path_hash = f"path_{hashlib.sha1(hash_seed).hexdigest()}"
converted.append(
RetrievalResult(
hash_value=path_hash,
content=f"[Indirect Relation] {description}",
score=0.95,
result_type="relation",
source="graph_path",
metadata={
"source": "graph_path",
"is_indirect": True,
"nodes": list(item.get("nodes", [])),
},
)
)
return converted

View File

@@ -0,0 +1,495 @@
"""
人物画像服务
主链路:
person_id -> 用户名/别名 -> 图谱关系 + 向量证据 -> 证据总结画像 -> 快照版本化存储
"""
import json
import time
from typing import Any, Dict, List, Optional, Tuple
from src.common.logger import get_logger
from src.common.database.database_model import PersonInfo
from ..embedding import EmbeddingAPIAdapter
from ..retrieval import (
DualPathRetriever,
RetrievalStrategy,
DualPathRetrieverConfig,
SparseBM25Config,
FusionConfig,
GraphRelationRecallConfig,
)
from ..storage import MetadataStore, GraphStore, VectorStore
logger = get_logger("A_Memorix.PersonProfileService")
class PersonProfileService:
"""人物画像聚合/刷新服务。"""
def __init__(
self,
metadata_store: MetadataStore,
graph_store: Optional[GraphStore] = None,
vector_store: Optional[VectorStore] = None,
embedding_manager: Optional[EmbeddingAPIAdapter] = None,
sparse_index: Any = None,
plugin_config: Optional[dict] = None,
retriever: Optional[DualPathRetriever] = None,
):
self.metadata_store = metadata_store
self.graph_store = graph_store
self.vector_store = vector_store
self.embedding_manager = embedding_manager
self.sparse_index = sparse_index
self.plugin_config = plugin_config or {}
self.retriever = retriever or self._build_retriever()
def _cfg(self, key: str, default: Any = None) -> Any:
"""读取嵌套配置。"""
if not isinstance(self.plugin_config, dict):
return default
current: Any = self.plugin_config
for part in key.split("."):
if isinstance(current, dict) and part in current:
current = current[part]
else:
return default
return current
def _build_retriever(self) -> Optional[DualPathRetriever]:
"""按需构建检索器(无依赖时返回 None"""
if not all(
[
self.vector_store is not None,
self.graph_store is not None,
self.metadata_store is not None,
self.embedding_manager is not None,
]
):
return None
try:
sparse_cfg_raw = self._cfg("retrieval.sparse", {}) or {}
fusion_cfg_raw = self._cfg("retrieval.fusion", {}) or {}
graph_recall_cfg_raw = self._cfg("retrieval.search.graph_recall", {}) or {}
if not isinstance(sparse_cfg_raw, dict):
sparse_cfg_raw = {}
if not isinstance(fusion_cfg_raw, dict):
fusion_cfg_raw = {}
if not isinstance(graph_recall_cfg_raw, dict):
graph_recall_cfg_raw = {}
sparse_cfg = SparseBM25Config(**sparse_cfg_raw)
fusion_cfg = FusionConfig(**fusion_cfg_raw)
graph_recall_cfg = GraphRelationRecallConfig(**graph_recall_cfg_raw)
config = DualPathRetrieverConfig(
top_k_paragraphs=int(self._cfg("retrieval.top_k_paragraphs", 20)),
top_k_relations=int(self._cfg("retrieval.top_k_relations", 10)),
top_k_final=int(self._cfg("retrieval.top_k_final", 10)),
alpha=float(self._cfg("retrieval.alpha", 0.5)),
enable_ppr=bool(self._cfg("retrieval.enable_ppr", True)),
ppr_alpha=float(self._cfg("retrieval.ppr_alpha", 0.85)),
ppr_concurrency_limit=int(self._cfg("retrieval.ppr_concurrency_limit", 4)),
enable_parallel=bool(self._cfg("retrieval.enable_parallel", True)),
retrieval_strategy=RetrievalStrategy.DUAL_PATH,
debug=bool(self._cfg("advanced.debug", False)),
sparse=sparse_cfg,
fusion=fusion_cfg,
graph_recall=graph_recall_cfg,
)
return DualPathRetriever(
vector_store=self.vector_store,
graph_store=self.graph_store,
metadata_store=self.metadata_store,
embedding_manager=self.embedding_manager,
sparse_index=self.sparse_index,
config=config,
)
except Exception as e:
logger.warning(f"初始化人物画像检索器失败,将只使用关系证据: {e}")
return None
@staticmethod
def resolve_person_id(identifier: str) -> str:
"""按 person_id 或姓名/别名解析 person_id。"""
if not identifier:
return ""
key = str(identifier).strip()
if not key:
return ""
if len(key) == 32 and all(ch in "0123456789abcdefABCDEF" for ch in key):
return key.lower()
try:
record = (
PersonInfo.select(PersonInfo.person_id)
.where((PersonInfo.person_name == key) | (PersonInfo.nickname == key))
.first()
)
if record and record.person_id:
return str(record.person_id)
except Exception:
pass
try:
record = (
PersonInfo.select(PersonInfo.person_id)
.where(PersonInfo.group_nick_name.contains(key))
.first()
)
if record and record.person_id:
return str(record.person_id)
except Exception:
pass
return ""
def _parse_group_nicks(self, raw_value: Any) -> List[str]:
if not raw_value:
return []
if isinstance(raw_value, list):
items = raw_value
else:
try:
items = json.loads(raw_value)
except Exception:
return []
names: List[str] = []
for item in items:
if isinstance(item, dict):
value = str(item.get("group_nick_name", "")).strip()
if value:
names.append(value)
elif isinstance(item, str):
value = item.strip()
if value:
names.append(value)
return names
def _parse_memory_traits(self, raw_value: Any) -> List[str]:
if not raw_value:
return []
try:
values = json.loads(raw_value) if isinstance(raw_value, str) else raw_value
except Exception:
return []
if not isinstance(values, list):
return []
traits: List[str] = []
for item in values:
text = str(item).strip()
if not text:
continue
if ":" in text:
parts = text.split(":")
if len(parts) >= 3:
content = ":".join(parts[1:-1]).strip()
if content:
traits.append(content)
continue
traits.append(text)
return traits[:10]
def get_person_aliases(self, person_id: str) -> Tuple[List[str], str, List[str]]:
"""获取人物别名集合、主展示名、记忆特征。"""
aliases: List[str] = []
primary_name = ""
memory_traits: List[str] = []
if not person_id:
return aliases, primary_name, memory_traits
try:
record = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
if not record:
return aliases, primary_name, memory_traits
person_name = str(getattr(record, "person_name", "") or "").strip()
nickname = str(getattr(record, "nickname", "") or "").strip()
group_nicks = self._parse_group_nicks(getattr(record, "group_nick_name", None))
memory_traits = self._parse_memory_traits(getattr(record, "memory_points", None))
primary_name = person_name or nickname or str(getattr(record, "user_id", "") or "").strip() or person_id
candidates = [person_name, nickname] + group_nicks
seen = set()
for item in candidates:
norm = str(item or "").strip()
if not norm or norm in seen:
continue
seen.add(norm)
aliases.append(norm)
except Exception as e:
logger.warning(f"解析人物别名失败: person_id={person_id}, err={e}")
return aliases, primary_name, memory_traits
def _collect_relation_evidence(self, aliases: List[str], limit: int = 30) -> List[Dict[str, Any]]:
relation_by_hash: Dict[str, Dict[str, Any]] = {}
for alias in aliases:
for rel in self.metadata_store.get_relations(subject=alias):
h = str(rel.get("hash", ""))
if h:
relation_by_hash[h] = rel
for rel in self.metadata_store.get_relations(object=alias):
h = str(rel.get("hash", ""))
if h:
relation_by_hash[h] = rel
relations = list(relation_by_hash.values())
relations.sort(key=lambda item: float(item.get("confidence", 0.0)), reverse=True)
relations = relations[: max(1, int(limit))]
edges: List[Dict[str, Any]] = []
for rel in relations:
edges.append(
{
"hash": str(rel.get("hash", "")),
"subject": str(rel.get("subject", "")),
"predicate": str(rel.get("predicate", "")),
"object": str(rel.get("object", "")),
"confidence": float(rel.get("confidence", 1.0) or 1.0),
}
)
return edges
async def _collect_vector_evidence(self, aliases: List[str], top_k: int = 12) -> List[Dict[str, Any]]:
alias_queries = [a for a in aliases if a]
if not alias_queries:
return []
if self.retriever is None:
# 回退:无检索器时只做简单内容匹配
fallback: List[Dict[str, Any]] = []
seen_hash = set()
for alias in alias_queries:
for para in self.metadata_store.search_paragraphs_by_content(alias)[: max(2, top_k // 2)]:
h = str(para.get("hash", ""))
if not h or h in seen_hash:
continue
seen_hash.add(h)
fallback.append(
{
"hash": h,
"type": "paragraph",
"score": 0.0,
"content": str(para.get("content", ""))[:180],
"metadata": {},
}
)
return fallback[:top_k]
per_alias_top_k = max(2, int(top_k / max(1, len(alias_queries))))
seen_hash = set()
evidence: List[Dict[str, Any]] = []
for alias in alias_queries:
try:
results = await self.retriever.retrieve(alias, top_k=per_alias_top_k)
except Exception as e:
logger.warning(f"向量证据召回失败: alias={alias}, err={e}")
continue
for item in results:
h = str(getattr(item, "hash_value", "") or "")
if not h or h in seen_hash:
continue
seen_hash.add(h)
evidence.append(
{
"hash": h,
"type": str(getattr(item, "result_type", "")),
"score": float(getattr(item, "score", 0.0) or 0.0),
"content": str(getattr(item, "content", "") or "")[:220],
"metadata": dict(getattr(item, "metadata", {}) or {}),
}
)
evidence.sort(key=lambda x: x.get("score", 0.0), reverse=True)
return evidence[:top_k]
def _build_profile_text(
self,
person_id: str,
primary_name: str,
aliases: List[str],
relation_edges: List[Dict[str, Any]],
vector_evidence: List[Dict[str, Any]],
memory_traits: List[str],
) -> str:
"""基于证据构建画像文本(供 LLM 上下文注入)。"""
lines: List[str] = []
lines.append(f"人物ID: {person_id}")
if primary_name:
lines.append(f"主称呼: {primary_name}")
if aliases:
lines.append(f"别名: {', '.join(aliases[:8])}")
if memory_traits:
lines.append(f"记忆特征: {'; '.join(memory_traits[:6])}")
if relation_edges:
lines.append("关系证据:")
for rel in relation_edges[:6]:
s = rel.get("subject", "")
p = rel.get("predicate", "")
o = rel.get("object", "")
conf = float(rel.get("confidence", 0.0))
lines.append(f"- {s} {p} {o} (conf={conf:.2f})")
if vector_evidence:
lines.append("向量证据摘要:")
for item in vector_evidence[:4]:
content = str(item.get("content", "")).strip()
if content:
lines.append(f"- {content}")
if len(lines) <= 2:
lines.append("暂无足够证据形成稳定画像。")
return "\n".join(lines)
@staticmethod
def _is_snapshot_stale(snapshot: Optional[Dict[str, Any]], ttl_seconds: float) -> bool:
if not snapshot:
return True
now = time.time()
expires_at = snapshot.get("expires_at")
if expires_at is not None:
try:
return now >= float(expires_at)
except Exception:
return True
updated_at = float(snapshot.get("updated_at") or 0.0)
return (now - updated_at) >= ttl_seconds
def _apply_manual_override(self, person_id: str, profile_payload: Dict[str, Any]) -> Dict[str, Any]:
"""将手工覆盖并入画像结果(覆盖 profile_text同时保留 auto_profile_text"""
payload = dict(profile_payload or {})
auto_text = str(payload.get("profile_text", "") or "")
payload["auto_profile_text"] = auto_text
payload["has_manual_override"] = False
payload["manual_override_text"] = ""
payload["override_updated_at"] = None
payload["override_updated_by"] = ""
payload["profile_source"] = "auto_snapshot"
if not person_id or self.metadata_store is None:
return payload
try:
override = self.metadata_store.get_person_profile_override(person_id)
except Exception as e:
logger.warning(f"读取人物画像手工覆盖失败: person_id={person_id}, err={e}")
return payload
if not override:
return payload
manual_text = str(override.get("override_text", "") or "").strip()
if not manual_text:
return payload
payload["has_manual_override"] = True
payload["manual_override_text"] = manual_text
payload["override_updated_at"] = override.get("updated_at")
payload["override_updated_by"] = str(override.get("updated_by", "") or "")
payload["profile_text"] = manual_text
payload["profile_source"] = "manual_override"
return payload
async def query_person_profile(
self,
person_id: str = "",
person_keyword: str = "",
top_k: int = 12,
ttl_seconds: float = 6 * 3600,
force_refresh: bool = False,
source_note: str = "",
) -> Dict[str, Any]:
"""查询或刷新人物画像。"""
pid = str(person_id or "").strip()
if not pid and person_keyword:
pid = self.resolve_person_id(person_keyword)
if not pid:
return {
"success": False,
"error": "person_id 无效,且未能通过别名解析",
}
latest = self.metadata_store.get_latest_person_profile_snapshot(pid)
if not force_refresh and not self._is_snapshot_stale(latest, ttl_seconds):
aliases, primary_name, _ = self.get_person_aliases(pid)
payload = {
"success": True,
"person_id": pid,
"person_name": primary_name,
"from_cache": True,
**(latest or {}),
}
if aliases and not payload.get("aliases"):
payload["aliases"] = aliases
return {
**self._apply_manual_override(pid, payload),
}
aliases, primary_name, memory_traits = self.get_person_aliases(pid)
if not aliases and person_keyword:
aliases = [person_keyword.strip()]
primary_name = person_keyword.strip()
relation_edges = self._collect_relation_evidence(aliases, limit=max(10, top_k * 2))
vector_evidence = await self._collect_vector_evidence(aliases, top_k=max(4, top_k))
evidence_ids = [
str(item.get("hash", ""))
for item in (relation_edges + vector_evidence)
if str(item.get("hash", "")).strip()
]
dedup_ids: List[str] = []
seen = set()
for item in evidence_ids:
if item in seen:
continue
seen.add(item)
dedup_ids.append(item)
profile_text = self._build_profile_text(
person_id=pid,
primary_name=primary_name,
aliases=aliases,
relation_edges=relation_edges,
vector_evidence=vector_evidence,
memory_traits=memory_traits,
)
expires_at = time.time() + float(ttl_seconds) if ttl_seconds > 0 else None
snapshot = self.metadata_store.upsert_person_profile_snapshot(
person_id=pid,
profile_text=profile_text,
aliases=aliases,
relation_edges=relation_edges,
vector_evidence=vector_evidence,
evidence_ids=dedup_ids,
expires_at=expires_at,
source_note=source_note,
)
payload = {
"success": True,
"person_id": pid,
"person_name": primary_name,
"from_cache": False,
**snapshot,
}
return {
**self._apply_manual_override(pid, payload),
}
@staticmethod
def format_persona_profile_block(profile: Dict[str, Any]) -> str:
"""格式化给 replyer 的注入块。"""
if not profile or not profile.get("success"):
return ""
text = str(profile.get("profile_text", "") or "").strip()
if not text:
return ""
return (
"【人物画像-内部参考】\n"
f"{text}\n"
"仅供内部推理,不要向用户逐字复述。"
)

View File

@@ -0,0 +1,27 @@
"""Plugin ID matching policy for A_Memorix."""
from __future__ import annotations
from typing import Any
class PluginIdPolicy:
"""Centralized plugin id normalization/matching policy."""
CANONICAL_ID = "a_memorix"
@classmethod
def normalize(cls, plugin_id: Any) -> str:
if not isinstance(plugin_id, str):
return ""
return plugin_id.strip().lower()
@classmethod
def is_target_plugin_id(cls, plugin_id: Any) -> bool:
normalized = cls.normalize(plugin_id)
if not normalized:
return False
if normalized == cls.CANONICAL_ID:
return True
return normalized.split(".")[-1] == cls.CANONICAL_ID

View File

@@ -0,0 +1,344 @@
"""
向量量化工具模块
提供向量量化与反量化功能,用于压缩存储空间。
"""
import numpy as np
from enum import Enum
from typing import Tuple, Union
from src.common.logger import get_logger
logger = get_logger("A_Memorix.Quantization")
class QuantizationType(Enum):
"""量化类型枚举"""
FLOAT32 = "float32" # 无量化
INT8 = "int8" # 标量量化8位整数
PQ = "pq" # 乘积量化Product Quantization
def quantize_vector(
vector: np.ndarray,
quant_type: QuantizationType = QuantizationType.INT8,
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
"""
量化向量
Args:
vector: 输入向量float32
quant_type: 量化类型
Returns:
量化后的向量:
- INT8: int8向量
- PQ: (编码向量, 聚类中心) 元组
"""
if quant_type == QuantizationType.FLOAT32:
return vector.astype(np.float32)
elif quant_type == QuantizationType.INT8:
return _scalar_quantize_int8(vector)
elif quant_type == QuantizationType.PQ:
return _product_quantize(vector)
else:
raise ValueError(f"不支持的量化类型: {quant_type}")
def dequantize_vector(
quantized_vector: Union[np.ndarray, Tuple[np.ndarray, np.ndarray]],
quant_type: QuantizationType = QuantizationType.INT8,
original_shape: Tuple[int, ...] = None,
) -> np.ndarray:
"""
反量化向量
Args:
quantized_vector: 量化后的向量
quant_type: 量化类型
original_shape: 原始向量形状用于PQ
Returns:
反量化后的向量float32
"""
if quant_type == QuantizationType.FLOAT32:
return quantized_vector.astype(np.float32)
elif quant_type == QuantizationType.INT8:
return _scalar_dequantize_int8(quantized_vector)
elif quant_type == QuantizationType.PQ:
if not isinstance(quantized_vector, tuple):
raise ValueError("PQ反量化需要列表/元组格式: (codes, centroids)")
return _product_dequantize(quantized_vector[0], quantized_vector[1])
else:
raise ValueError(f"不支持的量化类型: {quant_type}")
def _scalar_quantize_int8(vector: np.ndarray) -> np.ndarray:
"""
标量量化float32 -> int8
将向量归一化到 [0, 255] 范围,然后映射到 int8
Args:
vector: 输入向量
Returns:
量化后的 int8 向量
"""
# 计算最小最大值
min_val = np.min(vector)
max_val = np.max(vector)
# 避免除零
if max_val == min_val:
return np.zeros_like(vector, dtype=np.int8)
# 归一化到 [0, 255]
normalized = (vector - min_val) / (max_val - min_val) * 255
# 映射到 [-128, 127] 并转换为 int8
# np.round might return float, minus 128 then cast
quantized = np.round(normalized - 128.0).astype(np.int8)
# 存储归一化参数(用于反量化)
# 在实际存储中,这些参数需要单独保存
# 这里为了简单,我们使用一个全局字典来模拟
if not hasattr(_scalar_quantize_int8, "_params"):
_scalar_quantize_int8._params = {}
vector_id = id(vector)
_scalar_quantize_int8._params[vector_id] = (min_val, max_val)
return quantized
def _scalar_dequantize_int8(quantized: np.ndarray) -> np.ndarray:
"""
标量反量化int8 -> float32
Args:
quantized: 量化后的 int8 向量
Returns:
反量化后的 float32 向量
"""
# 计算归一化参数(如果提供了)
# 在实际应用中min_val 和 max_val 应该被保存
if not hasattr(_scalar_dequantize_int8, "_params"):
# 默认假设范围是 [-1, 1]
return (quantized.astype(np.float32) + 128.0) / 255.0 * 2.0 - 1.0
# 尝试查找参数 (这里只是演示逻辑,实际应从存储中读取)
# return (quantized.astype(np.float32) + 128.0) / 255.0 * (max - min) + min
return (quantized.astype(np.float32) + 128.0) / 255.0
def quantize_matrix(
matrix: np.ndarray,
quant_type: QuantizationType = QuantizationType.INT8,
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
"""
量化矩阵(批量量化向量)
Args:
matrix: 输入矩阵N x D每行是一个向量
quant_type: 量化类型
Returns:
量化后的矩阵
"""
if quant_type == QuantizationType.FLOAT32:
return matrix.astype(np.float32)
elif quant_type == QuantizationType.INT8:
# 对整个矩阵进行全局归一化
min_val = np.min(matrix)
max_val = np.max(matrix)
if max_val == min_val:
return np.zeros_like(matrix, dtype=np.int8)
# 归一化到 [0, 255]
normalized = (matrix - min_val) / (max_val - min_val) * 255
quantized = np.round(normalized).astype(np.int8)
return quantized
else:
raise ValueError(f"不支持的量化类型: {quant_type}")
def dequantize_matrix(
quantized_matrix: np.ndarray,
quant_type: QuantizationType = QuantizationType.INT8,
min_val: float = None,
max_val: float = None,
) -> np.ndarray:
"""
反量化矩阵
Args:
quantized_matrix: 量化后的矩阵
quant_type: 量化类型
min_val: 归一化最小值int8反量化需要
max_val: 归一化最大值int8反量化需要
Returns:
反量化后的矩阵
"""
if quant_type == QuantizationType.FLOAT32:
return quantized_matrix.astype(np.float32)
elif quant_type == QuantizationType.INT8:
# 使用提供的归一化参数反量化
if min_val is None or max_val is None:
# 默认假设范围是 [0, 255] -> [-1, 1]
return quantized_matrix.astype(np.float32) / 127.0
else:
# 恢复到原始范围
normalized = quantized_matrix.astype(np.float32) / 255.0
return normalized * (max_val - min_val) + min_val
else:
raise ValueError(f"不支持的量化类型: {quant_type}")
def estimate_memory_reduction(
num_vectors: int,
dimension: int,
from_type: QuantizationType,
to_type: QuantizationType,
) -> Tuple[float, float]:
"""
估算内存节省量
Args:
num_vectors: 向量数量
dimension: 向量维度
from_type: 原始量化类型
to_type: 目标量化类型
Returns:
(原始大小MB, 量化后大小MB, 节省比例)
"""
# 计算每个向量占用的字节数
bytes_per_element = {
QuantizationType.FLOAT32: 4,
QuantizationType.INT8: 1,
QuantizationType.PQ: 0.25, # 假设压缩到1/4
}
original_bytes = num_vectors * dimension * bytes_per_element[from_type]
quantized_bytes = num_vectors * dimension * bytes_per_element[to_type]
original_mb = original_bytes / 1024 / 1024
quantized_mb = quantized_bytes / 1024 / 1024
reduction_ratio = (original_bytes - quantized_bytes) / original_bytes
return original_mb, quantized_mb, reduction_ratio
def estimate_compression_stats(
num_vectors: int,
dimension: int,
quant_type: QuantizationType,
) -> dict:
"""
估算压缩统计信息
Args:
num_vectors: 向量数量
dimension: 向量维度
quant_type: 量化类型
Returns:
统计信息字典
"""
original_mb, quantized_mb, ratio = estimate_memory_reduction(
num_vectors, dimension, QuantizationType.FLOAT32, quant_type
)
return {
"num_vectors": num_vectors,
"dimension": dimension,
"quantization_type": quant_type.value,
"original_size_mb": round(original_mb, 2),
"quantized_size_mb": round(quantized_mb, 2),
"saved_mb": round(original_mb - quantized_mb, 2),
"compression_ratio": round(ratio * 100, 2),
}
def _product_quantize(
vector: np.ndarray, m: int = 8, k: int = 256
) -> Tuple[np.ndarray, np.ndarray]:
"""
乘积量化 (PQ) 简化实现
Args:
vector: 输入向量 (D,)
m: 子空间数量
k: 每个子空间的聚类中心数
Returns:
(编码后的向量, 聚类中心)
"""
d = vector.shape[0]
if d % m != 0:
raise ValueError(f"维度 {d} 必须能被子空间数量 {m} 整除")
ds = d // m # 子空间维度
codes = np.zeros(m, dtype=np.uint8)
centroids = np.zeros((m, k, ds), dtype=np.float32)
# 这里采用一种简化的 PQ不进行 K-means 训练,
# 而是预定一些量化点或针对单向量的微型聚类(实际应用中应离线训练)
# 为了演示,我们直接将子空间切分为 k 份进行量化
for i in range(m):
sub_vec = vector[i * ds : (i + 1) * ds]
# 简化:假定每个子空间的取值范围并划分
# 实际 PQ 应使用 k-means 产生的 centroids
# 这里为演示创建一个随机 codebook 并找到最接近的核心
sub_min, sub_max = np.min(sub_vec), np.max(sub_vec)
if sub_max == sub_min:
linspace = np.zeros(k)
else:
linspace = np.linspace(sub_min, sub_max, k)
for j in range(k):
centroids[i, j, :] = linspace[j]
# 编码:这里简化为取子空间均值找最接近的 centroid
sub_mean = np.mean(sub_vec)
code = np.argmin(np.abs(linspace - sub_mean))
codes[i] = code
return codes, centroids
def _product_dequantize(codes: np.ndarray, centroids: np.ndarray) -> np.ndarray:
"""
PQ 反量化
Args:
codes: 编码向量 (M,)
centroids: 聚类中心 (M, K, DS)
Returns:
恢复后的向量 (D,)
"""
m, k, ds = centroids.shape
vector = np.zeros(m * ds, dtype=np.float32)
for i in range(m):
code = codes[i]
vector[i * ds : (i + 1) * ds] = centroids[i, code, :]
return vector

View File

@@ -0,0 +1,121 @@
"""关系查询规格解析工具。"""
from __future__ import annotations
import re
from dataclasses import dataclass
from typing import Optional
@dataclass
class RelationQuerySpec:
raw: str
is_structured: bool
subject: Optional[str]
predicate: Optional[str]
object: Optional[str]
error: Optional[str] = None
_NATURAL_LANGUAGE_PATTERN = re.compile(
r"(^\s*(what|who|which|how|why|when|where)\b|"
r"\?||"
r"\b(relation|related|between)\b|"
r"(什么关系|有哪些关系|之间|关联))",
re.IGNORECASE,
)
def _looks_like_natural_language(raw: str) -> bool:
text = str(raw or "").strip()
if not text:
return False
return _NATURAL_LANGUAGE_PATTERN.search(text) is not None
def parse_relation_query_spec(relation_spec: str) -> RelationQuerySpec:
raw = str(relation_spec or "").strip()
if not raw:
return RelationQuerySpec(
raw=raw,
is_structured=False,
subject=None,
predicate=None,
object=None,
error="empty",
)
if "|" in raw:
parts = [p.strip() for p in raw.split("|")]
if len(parts) < 2:
return RelationQuerySpec(
raw=raw,
is_structured=True,
subject=None,
predicate=None,
object=None,
error="invalid_pipe_format",
)
return RelationQuerySpec(
raw=raw,
is_structured=True,
subject=parts[0] or None,
predicate=parts[1] or None,
object=parts[2] if len(parts) > 2 and parts[2] else None,
)
if "->" in raw:
parts = [p.strip() for p in raw.split("->") if p.strip()]
if len(parts) >= 3:
return RelationQuerySpec(
raw=raw,
is_structured=True,
subject=parts[0],
predicate=parts[1],
object=parts[2],
)
if len(parts) == 2:
return RelationQuerySpec(
raw=raw,
is_structured=True,
subject=parts[0],
predicate=None,
object=parts[1],
)
return RelationQuerySpec(
raw=raw,
is_structured=True,
subject=None,
predicate=None,
object=None,
error="invalid_arrow_format",
)
if _looks_like_natural_language(raw):
return RelationQuerySpec(
raw=raw,
is_structured=False,
subject=None,
predicate=None,
object=None,
)
# 仅保留低歧义的紧凑三元组作为兼容语法,例如 "Alice likes Apple"。
# 两词形式过于模糊,不再视为结构化关系查询。
parts = raw.split()
if len(parts) == 3:
return RelationQuerySpec(
raw=raw,
is_structured=True,
subject=parts[0],
predicate=parts[1],
object=parts[2],
)
return RelationQuerySpec(
raw=raw,
is_structured=False,
subject=None,
predicate=None,
object=None,
)

View File

@@ -0,0 +1,164 @@
"""
统一关系写入与关系向量化服务。
规则:
1. 元数据是主数据源,向量是从索引。
2. 关系先写 metadata再写向量。
3. 向量失败不回滚 metadata依赖状态机与回填任务修复。
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Dict, Optional
from src.common.logger import get_logger
logger = get_logger("A_Memorix.RelationWriteService")
@dataclass
class RelationWriteResult:
hash_value: str
vector_written: bool
vector_already_exists: bool
vector_state: str
class RelationWriteService:
"""关系写入收口服务。"""
ERROR_MAX_LEN = 500
def __init__(
self,
metadata_store: Any,
graph_store: Any,
vector_store: Any,
embedding_manager: Any,
):
self.metadata_store = metadata_store
self.graph_store = graph_store
self.vector_store = vector_store
self.embedding_manager = embedding_manager
@staticmethod
def build_relation_vector_text(subject: str, predicate: str, obj: str) -> str:
s = str(subject or "").strip()
p = str(predicate or "").strip()
o = str(obj or "").strip()
# 双表达:兼容关键词检索与自然语言问句
return f"{s} {p} {o}\n{s}{o}的关系是{p}"
async def ensure_relation_vector(
self,
hash_value: str,
subject: str,
predicate: str,
obj: str,
*,
max_error_len: int = ERROR_MAX_LEN,
) -> RelationWriteResult:
"""
为已有关系确保向量存在并更新状态。
"""
if hash_value in self.vector_store:
self.metadata_store.set_relation_vector_state(hash_value, "ready")
return RelationWriteResult(
hash_value=hash_value,
vector_written=False,
vector_already_exists=True,
vector_state="ready",
)
self.metadata_store.set_relation_vector_state(hash_value, "pending")
try:
vector_text = self.build_relation_vector_text(subject, predicate, obj)
embedding = await self.embedding_manager.encode(vector_text)
self.vector_store.add(
vectors=embedding.reshape(1, -1),
ids=[hash_value],
)
self.metadata_store.set_relation_vector_state(hash_value, "ready")
logger.info(
"metric.relation_vector_write_success=1 metric.relation_vector_write_success_count=1 hash=%s",
hash_value[:16],
)
return RelationWriteResult(
hash_value=hash_value,
vector_written=True,
vector_already_exists=False,
vector_state="ready",
)
except ValueError:
# 向量已存在冲突,按成功处理
self.metadata_store.set_relation_vector_state(hash_value, "ready")
return RelationWriteResult(
hash_value=hash_value,
vector_written=False,
vector_already_exists=True,
vector_state="ready",
)
except Exception as e:
err = str(e)[:max_error_len]
self.metadata_store.set_relation_vector_state(
hash_value,
"failed",
error=err,
bump_retry=True,
)
logger.warning(
"metric.relation_vector_write_fail=1 metric.relation_vector_write_fail_count=1 hash=%s err=%s",
hash_value[:16],
err,
)
return RelationWriteResult(
hash_value=hash_value,
vector_written=False,
vector_already_exists=False,
vector_state="failed",
)
async def upsert_relation_with_vector(
self,
subject: str,
predicate: str,
obj: str,
confidence: float = 1.0,
source_paragraph: str = "",
metadata: Optional[Dict[str, Any]] = None,
*,
write_vector: bool = True,
) -> RelationWriteResult:
"""
统一关系写入:
1) 写 metadata relation
2) 写 graph edge relation_hash
3) 按需写 relation vector
"""
rel_hash = self.metadata_store.add_relation(
subject=subject,
predicate=predicate,
obj=obj,
confidence=confidence,
source_paragraph=source_paragraph,
metadata=metadata or {},
)
self.graph_store.add_edges([(subject, obj)], relation_hashes=[rel_hash])
if not write_vector:
self.metadata_store.set_relation_vector_state(rel_hash, "none")
return RelationWriteResult(
hash_value=rel_hash,
vector_written=False,
vector_already_exists=False,
vector_state="none",
)
return await self.ensure_relation_vector(
hash_value=rel_hash,
subject=subject,
predicate=predicate,
obj=obj,
)

View File

@@ -0,0 +1,197 @@
"""Runtime self-check helpers for A_Memorix."""
from __future__ import annotations
import time
from typing import Any, Dict, Optional
import numpy as np
from src.common.logger import get_logger
logger = get_logger("A_Memorix.RuntimeSelfCheck")
_DEFAULT_SAMPLE_TEXT = "A_Memorix runtime self check"
def _safe_int(value: Any, default: int = 0) -> int:
try:
return int(value)
except Exception:
return int(default)
def _get_config_value(config: Any, key: str, default: Any = None) -> Any:
getter = getattr(config, "get_config", None)
if callable(getter):
return getter(key, default)
current: Any = config
for part in key.split("."):
if isinstance(current, dict) and part in current:
current = current[part]
else:
return default
return current
def _build_report(
*,
ok: bool,
code: str,
message: str,
configured_dimension: int,
vector_store_dimension: int,
detected_dimension: int,
encoded_dimension: int,
elapsed_ms: float,
sample_text: str,
) -> Dict[str, Any]:
return {
"ok": bool(ok),
"code": str(code or "").strip(),
"message": str(message or "").strip(),
"configured_dimension": int(configured_dimension),
"vector_store_dimension": int(vector_store_dimension),
"detected_dimension": int(detected_dimension),
"encoded_dimension": int(encoded_dimension),
"elapsed_ms": float(elapsed_ms),
"sample_text": str(sample_text or ""),
"checked_at": time.time(),
}
async def run_embedding_runtime_self_check(
*,
config: Any,
vector_store: Optional[Any],
embedding_manager: Optional[Any],
sample_text: str = _DEFAULT_SAMPLE_TEXT,
) -> Dict[str, Any]:
"""Probe the real embedding path and compare dimensions with runtime storage."""
configured_dimension = _safe_int(_get_config_value(config, "embedding.dimension", 0), 0)
vector_store_dimension = _safe_int(getattr(vector_store, "dimension", 0), 0)
if vector_store is None or embedding_manager is None:
return _build_report(
ok=False,
code="runtime_components_missing",
message="vector_store 或 embedding_manager 未初始化",
configured_dimension=configured_dimension,
vector_store_dimension=vector_store_dimension,
detected_dimension=0,
encoded_dimension=0,
elapsed_ms=0.0,
sample_text=sample_text,
)
start = time.perf_counter()
detected_dimension = 0
encoded_dimension = 0
try:
detected_dimension = _safe_int(await embedding_manager._detect_dimension(), 0)
encoded = await embedding_manager.encode(sample_text)
if isinstance(encoded, np.ndarray):
encoded_dimension = int(encoded.shape[0]) if encoded.ndim == 1 else int(encoded.shape[-1])
else:
encoded_dimension = len(encoded) if encoded is not None else 0
except Exception as exc:
elapsed_ms = (time.perf_counter() - start) * 1000.0
logger.warning("embedding runtime self-check failed: %s", exc)
return _build_report(
ok=False,
code="embedding_probe_failed",
message=f"embedding probe failed: {exc}",
configured_dimension=configured_dimension,
vector_store_dimension=vector_store_dimension,
detected_dimension=detected_dimension,
encoded_dimension=encoded_dimension,
elapsed_ms=elapsed_ms,
sample_text=sample_text,
)
elapsed_ms = (time.perf_counter() - start) * 1000.0
expected_dimension = vector_store_dimension or configured_dimension or detected_dimension
if expected_dimension <= 0:
return _build_report(
ok=False,
code="invalid_expected_dimension",
message="无法确定期望 embedding 维度",
configured_dimension=configured_dimension,
vector_store_dimension=vector_store_dimension,
detected_dimension=detected_dimension,
encoded_dimension=encoded_dimension,
elapsed_ms=elapsed_ms,
sample_text=sample_text,
)
if encoded_dimension != expected_dimension:
msg = (
"embedding 真实输出维度与当前向量存储不一致: "
f"expected={expected_dimension}, encoded={encoded_dimension}"
)
logger.error(msg)
return _build_report(
ok=False,
code="embedding_dimension_mismatch",
message=msg,
configured_dimension=configured_dimension,
vector_store_dimension=vector_store_dimension,
detected_dimension=detected_dimension,
encoded_dimension=encoded_dimension,
elapsed_ms=elapsed_ms,
sample_text=sample_text,
)
return _build_report(
ok=True,
code="ok",
message="embedding runtime self-check passed",
configured_dimension=configured_dimension,
vector_store_dimension=vector_store_dimension,
detected_dimension=detected_dimension,
encoded_dimension=encoded_dimension,
elapsed_ms=elapsed_ms,
sample_text=sample_text,
)
async def ensure_runtime_self_check(
plugin_or_config: Any,
*,
force: bool = False,
sample_text: str = _DEFAULT_SAMPLE_TEXT,
) -> Dict[str, Any]:
"""Run or reuse cached runtime self-check report."""
if plugin_or_config is None:
return _build_report(
ok=False,
code="missing_plugin_or_config",
message="plugin/config unavailable",
configured_dimension=0,
vector_store_dimension=0,
detected_dimension=0,
encoded_dimension=0,
elapsed_ms=0.0,
sample_text=sample_text,
)
cache = getattr(plugin_or_config, "_runtime_self_check_report", None)
if isinstance(cache, dict) and cache and not force:
return cache
report = await run_embedding_runtime_self_check(
config=getattr(plugin_or_config, "config", plugin_or_config),
vector_store=getattr(plugin_or_config, "vector_store", None)
if not isinstance(plugin_or_config, dict)
else plugin_or_config.get("vector_store"),
embedding_manager=getattr(plugin_or_config, "embedding_manager", None)
if not isinstance(plugin_or_config, dict)
else plugin_or_config.get("embedding_manager"),
sample_text=sample_text,
)
try:
setattr(plugin_or_config, "_runtime_self_check_report", report)
except Exception:
pass
return report

View File

@@ -0,0 +1,90 @@
"""Post-processing helpers for unified search execution."""
from __future__ import annotations
from typing import Any, List, Tuple
from .path_fallback_service import find_paths_from_query, to_retrieval_results
def apply_safe_content_dedup(results: List[Any]) -> Tuple[List[Any], int]:
"""Deduplicate results by hash/content while preserving at least one entry."""
if not results:
return [], 0
unique_results: List[Any] = []
seen_hashes = set()
seen_contents = set()
for item in results:
content = str(getattr(item, "content", "") or "").strip()
if not content:
continue
hash_value = str(getattr(item, "hash_value", "") or "").strip() or str(hash(content))
if hash_value in seen_hashes:
continue
is_dup = False
for seen in seen_contents:
if content in seen or seen in content:
is_dup = True
break
if is_dup:
continue
seen_hashes.add(hash_value)
seen_contents.add(content)
unique_results.append(item)
if not unique_results:
unique_results.append(results[0])
removed_count = max(0, len(results) - len(unique_results))
return unique_results, removed_count
def maybe_apply_smart_path_fallback(
*,
query: str,
results: List[Any],
graph_store: Any,
metadata_store: Any,
enabled: bool,
threshold: float,
max_depth: int = 3,
max_paths: int = 5,
) -> Tuple[List[Any], bool, int]:
"""Append indirect relation paths when semantic results are weak."""
if not enabled or not str(query or "").strip():
return results, False, 0
if graph_store is None or metadata_store is None:
return results, False, 0
max_score = 0.0
if results:
try:
max_score = float(getattr(results[0], "score", 0.0) or 0.0)
except Exception:
max_score = 0.0
if max_score >= float(threshold):
return results, False, 0
paths = find_paths_from_query(
query=query,
graph_store=graph_store,
metadata_store=metadata_store,
max_depth=max_depth,
max_paths=max_paths,
)
if not paths:
return results, False, 0
path_results = to_retrieval_results(paths)
if not path_results:
return results, False, 0
merged = list(path_results) + list(results)
return merged, True, len(path_results)

View File

@@ -0,0 +1,170 @@
"""
时间解析工具。
约束:
1. 查询参数Action/Command/Tool仅接受结构化绝对时间
- YYYY/MM/DD
- YYYY/MM/DD HH:mm
2. 入库时允许更宽松格式含时间戳、YYYY-MM-DD 等)。
"""
from __future__ import annotations
import re
from datetime import datetime
from typing import Any, Dict, Optional, Tuple
_QUERY_DATE_RE = re.compile(r"^\d{4}/\d{2}/\d{2}$")
_QUERY_MINUTE_RE = re.compile(r"^\d{4}/\d{2}/\d{2} \d{2}:\d{2}$")
_NUMERIC_RE = re.compile(r"^-?\d+(?:\.\d+)?$")
_INGEST_FORMATS = [
"%Y/%m/%d %H:%M:%S",
"%Y/%m/%d %H:%M",
"%Y-%m-%d %H:%M:%S",
"%Y-%m-%d %H:%M",
"%Y-%m-%dT%H:%M:%S",
"%Y-%m-%dT%H:%M",
"%Y/%m/%d",
"%Y-%m-%d",
]
_INGEST_DATE_FORMATS = {"%Y/%m/%d", "%Y-%m-%d"}
def parse_query_datetime_to_timestamp(value: str, is_end: bool = False) -> float:
"""解析查询时间,仅支持 YYYY/MM/DD 或 YYYY/MM/DD HH:mm。"""
text = str(value).strip()
if not text:
raise ValueError("时间不能为空")
if _QUERY_DATE_RE.fullmatch(text):
dt = datetime.strptime(text, "%Y/%m/%d")
if is_end:
dt = dt.replace(hour=23, minute=59, second=0, microsecond=0)
return dt.timestamp()
if _QUERY_MINUTE_RE.fullmatch(text):
dt = datetime.strptime(text, "%Y/%m/%d %H:%M")
return dt.timestamp()
raise ValueError(
f"时间格式错误: {text}。仅支持 YYYY/MM/DD 或 YYYY/MM/DD HH:mm"
)
def parse_query_time_range(
time_from: Optional[str],
time_to: Optional[str],
) -> Tuple[Optional[float], Optional[float]]:
"""解析查询窗口并验证区间。"""
ts_from = (
parse_query_datetime_to_timestamp(time_from, is_end=False)
if time_from
else None
)
ts_to = (
parse_query_datetime_to_timestamp(time_to, is_end=True)
if time_to
else None
)
if ts_from is not None and ts_to is not None and ts_from > ts_to:
raise ValueError("time_from 不能晚于 time_to")
return ts_from, ts_to
def parse_ingest_datetime_to_timestamp(
value: Any,
is_end: bool = False,
) -> Optional[float]:
"""解析入库时间,允许 timestamp/常见字符串格式。"""
if value is None:
return None
if isinstance(value, (int, float)):
return float(value)
text = str(value).strip()
if not text:
return None
if _NUMERIC_RE.fullmatch(text):
return float(text)
for fmt in _INGEST_FORMATS:
try:
dt = datetime.strptime(text, fmt)
if fmt in _INGEST_DATE_FORMATS and is_end:
dt = dt.replace(hour=23, minute=59, second=0, microsecond=0)
return dt.timestamp()
except ValueError:
continue
raise ValueError(f"无法解析时间: {text}")
def normalize_time_meta(time_meta: Optional[Dict[str, Any]]) -> Dict[str, Any]:
"""归一化 time_meta 到存储层字段。"""
if not time_meta:
return {}
normalized: Dict[str, Any] = {}
event_time = parse_ingest_datetime_to_timestamp(time_meta.get("event_time"))
event_start = parse_ingest_datetime_to_timestamp(
time_meta.get("event_time_start"),
is_end=False,
)
event_end = parse_ingest_datetime_to_timestamp(
time_meta.get("event_time_end"),
is_end=True,
)
time_range = time_meta.get("time_range")
if (
isinstance(time_range, (list, tuple))
and len(time_range) == 2
):
if event_start is None:
event_start = parse_ingest_datetime_to_timestamp(time_range[0], is_end=False)
if event_end is None:
event_end = parse_ingest_datetime_to_timestamp(time_range[1], is_end=True)
if event_start is not None and event_end is not None and event_start > event_end:
raise ValueError("event_time_start 不能晚于 event_time_end")
if event_time is not None:
normalized["event_time"] = event_time
if event_start is not None:
normalized["event_time_start"] = event_start
if event_end is not None:
normalized["event_time_end"] = event_end
granularity = time_meta.get("time_granularity")
if granularity:
normalized["time_granularity"] = str(granularity)
else:
raw_time_values = [
time_meta.get("event_time"),
time_meta.get("event_time_start"),
time_meta.get("event_time_end"),
]
has_minute = any(isinstance(v, str) and ":" in v for v in raw_time_values if v is not None)
normalized["time_granularity"] = "minute" if has_minute else "day"
confidence = time_meta.get("time_confidence")
if confidence is not None:
normalized["time_confidence"] = float(confidence)
return normalized
def format_timestamp(ts: Optional[float]) -> Optional[str]:
"""将 timestamp 格式化为 YYYY/MM/DD HH:mm。"""
if ts is None:
return None
return datetime.fromtimestamp(ts).strftime("%Y/%m/%d %H:%M")