feat:新增 A_Memorix 记忆插件
引入 A_Memorix 插件(v2.0.0)——一个轻量级的长期记忆提供器。新增插件清单(manifest)和入口(AMemorixPlugin),并提供完整的核心能力:嵌入(基于哈希的 EmbeddingAPIAdapter、EmbeddingManager、预设)、检索(双路径检索器、PageRank、图关系召回、BM25 稀疏索引、阈值与融合配置)、存储与元数据层,以及大量实用工具和迁移/转换脚本。同时更新 .gitignore 以允许 /plugins/A_memorix。该变更为在宿主应用中实现统一的记忆摄取、检索、分析与维护奠定了基础。
This commit is contained in:
33
plugins/A_memorix/core/utils/__init__.py
Normal file
33
plugins/A_memorix/core/utils/__init__.py
Normal file
@@ -0,0 +1,33 @@
|
||||
"""工具模块 - 哈希、监控等辅助功能"""
|
||||
|
||||
from .hash import compute_hash, normalize_text
|
||||
from .monitor import MemoryMonitor
|
||||
from .quantization import quantize_vector, dequantize_vector
|
||||
from .time_parser import (
|
||||
parse_query_datetime_to_timestamp,
|
||||
parse_query_time_range,
|
||||
parse_ingest_datetime_to_timestamp,
|
||||
normalize_time_meta,
|
||||
format_timestamp,
|
||||
)
|
||||
from .relation_write_service import RelationWriteService, RelationWriteResult
|
||||
from .relation_query import RelationQuerySpec, parse_relation_query_spec
|
||||
from .plugin_id_policy import PluginIdPolicy
|
||||
|
||||
# Public API of the utils package: only the names listed here are exported
# by ``from ... import *``.  Every entry is imported at the top of this module.
__all__ = [
    # hashing helpers
    "compute_hash",
    "normalize_text",
    # memory monitoring
    "MemoryMonitor",
    # vector quantization
    "quantize_vector",
    "dequantize_vector",
    # time parsing / formatting
    "parse_query_datetime_to_timestamp",
    "parse_query_time_range",
    "parse_ingest_datetime_to_timestamp",
    "normalize_time_meta",
    "format_timestamp",
    # relation write / query helpers
    "RelationWriteService",
    "RelationWriteResult",
    "RelationQuerySpec",
    "parse_relation_query_spec",
    # plugin id policy
    "PluginIdPolicy",
]
|
||||
360
plugins/A_memorix/core/utils/aggregate_query_service.py
Normal file
360
plugins/A_memorix/core/utils/aggregate_query_service.py
Normal file
@@ -0,0 +1,360 @@
|
||||
"""
|
||||
聚合查询服务:
|
||||
- 并发执行 search/time/episode 分支
|
||||
- 统一分支结果结构
|
||||
- 可选混合排序(Weighted RRF)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("A_Memorix.AggregateQueryService")
|
||||
|
||||
BranchRunner = Callable[[], Awaitable[Dict[str, Any]]]
|
||||
|
||||
|
||||
class AggregateQueryService:
    """Aggregate query executor for the search / time / episode branches.

    Each branch is supplied by the caller as a zero-argument coroutine
    function (a "runner").  ``execute`` schedules the applicable runners
    concurrently, normalizes every branch result into one common shape and,
    when requested, merges the branch results with weighted Reciprocal Rank
    Fusion (RRF).
    """

    # Fixed reporting order of the branches.
    _BRANCH_NAMES = ("search", "time", "episode")

    def __init__(self, plugin_config: Optional[Any] = None):
        """Store the plugin configuration (any falsy value becomes ``{}``)."""
        self.plugin_config = plugin_config or {}

    def _cfg(self, key: str, default: Any = None) -> Any:
        """Look up a dotted configuration key.

        If the config object exposes a callable ``get_config`` it is used
        directly; otherwise the config is walked as nested dicts.

        Args:
            key: Dotted path such as ``"retrieval.aggregate.rrf_k"``.
            default: Value returned when any path segment is missing.
        """
        getter = getattr(self.plugin_config, "get_config", None)
        if callable(getter):
            return getter(key, default)

        current: Any = self.plugin_config
        for part in key.split("."):
            if isinstance(current, dict) and part in current:
                current = current[part]
            else:
                return default
        return current

    @staticmethod
    def _as_float(value: Any, default: float = 0.0) -> float:
        """Coerce ``value`` to float, falling back to ``default`` on failure."""
        try:
            return float(value)
        except (TypeError, ValueError, OverflowError):
            return float(default)

    @staticmethod
    def _as_int(value: Any, default: int = 0) -> int:
        """Coerce ``value`` to int, falling back to ``default`` on failure."""
        try:
            return int(value)
        except (TypeError, ValueError, OverflowError):
            return int(default)

    def _rrf_k(self) -> float:
        """Return the RRF smoothing constant (config value, at least 1.0)."""
        raw = self._cfg("retrieval.aggregate.rrf_k", 60.0)
        return max(1.0, self._as_float(raw, 60.0))

    def _weights(self) -> Dict[str, float]:
        """Return the non-negative per-branch fusion weights (default 1.0)."""
        defaults = {name: 1.0 for name in self._BRANCH_NAMES}
        raw = self._cfg("retrieval.aggregate.weights", {})
        if not isinstance(raw, dict):
            return defaults

        out = dict(defaults)
        for key in self._BRANCH_NAMES:
            if key in raw:
                out[key] = max(0.0, self._as_float(raw.get(key), defaults[key]))
        return out

    @staticmethod
    def _normalize_branch_payload(
        name: str,
        payload: Optional[Dict[str, Any]],
    ) -> Dict[str, Any]:
        """Convert a raw branch result into the unified branch structure.

        Non-dict payloads are treated as empty.  ``count`` and ``elapsed_ms``
        come from the branch runner, so they are coerced defensively: an
        unparsable ``count`` falls back to ``len(results)`` and an unparsable
        ``elapsed_ms`` to ``0.0`` instead of raising and aborting the whole
        aggregate query.
        """
        data = payload if isinstance(payload, dict) else {}
        results_raw = data.get("results", [])
        results = results_raw if isinstance(results_raw, list) else []

        count = data.get("count")
        if count is None:
            count = len(results)
        else:
            count = AggregateQueryService._as_int(count, len(results))
        elapsed_ms = AggregateQueryService._as_float(data.get("elapsed_ms", 0.0) or 0.0, 0.0)

        return {
            "name": name,
            "success": bool(data.get("success", False)),
            "skipped": bool(data.get("skipped", False)),
            "skip_reason": str(data.get("skip_reason", "") or "").strip(),
            "error": str(data.get("error", "") or "").strip(),
            "results": results,
            "count": max(0, count),
            "elapsed_ms": max(0.0, elapsed_ms),
            "content": str(data.get("content", "") or ""),
            "query_type": str(data.get("query_type", "") or name),
        }

    @classmethod
    def _skipped_branch(cls, name: str, reason: str) -> Dict[str, Any]:
        """Build the normalized payload of a branch that was not applicable."""
        return cls._normalize_branch_payload(
            name,
            {
                "success": False,
                "skipped": True,
                "skip_reason": reason,
                "results": [],
                "count": 0,
            },
        )

    @classmethod
    def _failed_branch(cls, name: str, error: str) -> Dict[str, Any]:
        """Build the normalized payload of a branch that failed with ``error``."""
        return cls._normalize_branch_payload(
            name,
            {"success": False, "error": error, "results": []},
        )

    @staticmethod
    def _mix_key(item: Dict[str, Any], branch: str, rank: int) -> str:
        """Return the de-duplication key used when fusing branch results.

        Episodes are keyed by ``episode_id`` and other items by ``type`` plus
        ``hash``.  Items without either get a key that includes the branch and
        rank, so they are never merged with another item.
        """
        item_type = str(item.get("type", "") or "").strip().lower()
        if item_type == "episode":
            episode_id = str(item.get("episode_id", "") or "").strip()
            if episode_id:
                return f"episode:{episode_id}"

        item_hash = str(item.get("hash", "") or "").strip()
        if item_hash:
            return f"{item_type}:{item_hash}"

        return f"{branch}:{item_type}:{rank}:{str(item.get('content', '') or '')[:80]}"

    def _build_mixed_results(
        self,
        *,
        branches: Dict[str, Dict[str, Any]],
        top_k: int,
    ) -> List[Dict[str, Any]]:
        """Fuse the results of all successful branches with weighted RRF.

        Every item contributes ``weight / (rrf_k + rank)`` to the bucket
        selected by ``_mix_key``.  The output is sorted by descending
        ``fusion_score`` (ties: type, then hash/episode_id), truncated to
        ``top_k`` (at least 1) and annotated with ``source_branches`` and a
        1-based ``rank``.
        """
        rrf_k = self._rrf_k()
        weights = self._weights()
        bucket: Dict[str, Dict[str, Any]] = {}

        for branch_name, branch in branches.items():
            if not branch.get("success", False):
                continue
            results = branch.get("results", [])
            if not isinstance(results, list):
                continue

            weight = max(0.0, float(weights.get(branch_name, 1.0)))
            for idx, item in enumerate(results, start=1):
                if not isinstance(item, dict):
                    continue
                key = self._mix_key(item, branch_name, idx)
                score = weight / (rrf_k + float(idx))
                if key not in bucket:
                    # The first branch that yields an item provides its fields.
                    merged = dict(item)
                    merged["fusion_score"] = 0.0
                    merged["_source_branches"] = set()
                    bucket[key] = merged

                target = bucket[key]
                target["fusion_score"] = float(target.get("fusion_score", 0.0)) + score
                target["_source_branches"].add(branch_name)

        mixed = list(bucket.values())
        mixed.sort(
            key=lambda x: (
                -float(x.get("fusion_score", 0.0)),
                str(x.get("type", "") or ""),
                str(x.get("hash", "") or x.get("episode_id", "") or ""),
            )
        )

        out: List[Dict[str, Any]] = []
        for rank, item in enumerate(mixed[: max(1, int(top_k))], start=1):
            merged = dict(item)
            # The internal set is replaced by a sorted list in the output.
            branches_set = merged.pop("_source_branches", set())
            merged["source_branches"] = sorted(branches_set)
            merged["rank"] = rank
            out.append(merged)
        return out

    @staticmethod
    def _status(branch: Dict[str, Any]) -> str:
        """Map a normalized branch to ``"skipped"``, ``"success"`` or ``"failed"``."""
        if branch.get("skipped", False):
            return "skipped"
        if branch.get("success", False):
            return "success"
        return "failed"

    def _build_summary(self, branches: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
        """Return per-branch status and count, plus skip reason or error text."""
        summary: Dict[str, Dict[str, Any]] = {}
        for name, branch in branches.items():
            status = self._status(branch)
            summary[name] = {
                "status": status,
                "count": int(branch.get("count", 0) or 0),
            }
            if status == "skipped":
                summary[name]["reason"] = str(branch.get("skip_reason", "") or "")
            if status == "failed":
                summary[name]["error"] = str(branch.get("error", "") or "")
        return summary

    def _build_content(
        self,
        *,
        query: str,
        branches: Dict[str, Dict[str, Any]],
        errors: List[Dict[str, str]],
        mixed_results: Optional[List[Dict[str, Any]]],
    ) -> str:
        """Render the human-readable report of an aggregate query.

        The report lists each branch status, up to six errors and, when
        ``mixed_results`` is not ``None``, the first five fused results.
        """
        lines: List[str] = [
            f"🔀 聚合查询结果(query='{query or 'N/A'}')",
            "",
            "分支状态:",
        ]
        for name in self._BRANCH_NAMES:
            branch = branches.get(name, {})
            status = self._status(branch)
            count = int(branch.get("count", 0) or 0)
            line = f"- {name}: {status}, count={count}"
            reason = str(branch.get("skip_reason", "") or "").strip()
            err = str(branch.get("error", "") or "").strip()
            if status == "skipped" and reason:
                line += f" ({reason})"
            if status == "failed" and err:
                line += f" ({err})"
            lines.append(line)

        if errors:
            lines.append("")
            lines.append("错误:")
            for item in errors[:6]:
                lines.append(f"- {item.get('branch', 'unknown')}: {item.get('error', 'unknown error')}")

        if mixed_results is not None:
            lines.append("")
            lines.append(f"🧩 混合结果({len(mixed_results)} 条):")
            for idx, item in enumerate(mixed_results[:5], start=1):
                src = ",".join(item.get("source_branches", []) or [])
                if str(item.get("type", "") or "") == "episode":
                    title = str(item.get("title", "") or "Untitled")
                    lines.append(f"{idx}. 🧠 {title} [{src}]")
                else:
                    text = str(item.get("content", "") or "")
                    if len(text) > 80:
                        text = text[:80] + "..."
                    lines.append(f"{idx}. {text} [{src}]")

        return "\n".join(lines)

    async def execute(
        self,
        *,
        query: str,
        top_k: int,
        mix: bool,
        mix_top_k: Optional[int],
        time_from: Optional[str],
        time_to: Optional[str],
        search_runner: Optional["BranchRunner"],
        time_runner: Optional["BranchRunner"],
        episode_runner: Optional["BranchRunner"],
    ) -> Dict[str, Any]:
        """Run the applicable branches concurrently and assemble the result.

        Args:
            query: Query text; the search branch is skipped when it is blank.
            top_k: Requested result count, echoed back clamped to >= 1.
            mix: Whether to add RRF-fused ``mixed_results`` to the payload.
            mix_top_k: Size of the fused list; defaults to ``top_k``.
            time_from: Lower time bound.  Together with ``time_to`` it is only
                used here to decide whether the time branch runs.
            time_to: Upper time bound (see ``time_from``).
            search_runner: Coroutine function producing the search payload.
            time_runner: Coroutine function producing the time payload.
            episode_runner: Coroutine function producing the episode payload.

        Returns:
            A dict with ``success`` (true if any branch succeeded),
            ``query_type``, ``query``, ``top_k``, ``mix``, ``mix_top_k``,
            ``branches``, ``errors``, ``summary``, ``content`` and, when
            ``mix`` is set, ``mixed_results``.
        """
        clean_query = str(query or "").strip()
        safe_top_k = max(1, int(top_k))
        safe_mix_top_k = max(1, int(mix_top_k if mix_top_k is not None else safe_top_k))

        branches: Dict[str, Dict[str, Any]] = {}
        errors: List[Dict[str, str]] = []
        scheduled: List[Tuple[str, asyncio.Task]] = []

        # search branch: requires a non-empty query.
        if not clean_query:
            branches["search"] = self._skipped_branch("search", "missing_query")
        elif search_runner is None:
            branches["search"] = self._failed_branch("search", "search runner unavailable")
        else:
            scheduled.append(("search", asyncio.create_task(search_runner())))

        # time branch: requires at least one bound of the time window.
        if not (time_from or time_to):
            branches["time"] = self._skipped_branch("time", "missing_time_window")
        elif time_runner is None:
            branches["time"] = self._failed_branch("time", "time runner unavailable")
        else:
            scheduled.append(("time", asyncio.create_task(time_runner())))

        # episode branch: always attempted.
        if episode_runner is None:
            branches["episode"] = self._failed_branch("episode", "episode runner unavailable")
        else:
            scheduled.append(("episode", asyncio.create_task(episode_runner())))

        if scheduled:
            done = await asyncio.gather(
                *[task for _, task in scheduled],
                return_exceptions=True,
            )
            for (branch_name, _), payload in zip(scheduled, done):
                # BaseException (not Exception): a cancelled branch task comes
                # back as asyncio.CancelledError, which is not an Exception.
                if isinstance(payload, BaseException):
                    logger.error("aggregate branch failed: branch=%s error=%s", branch_name, payload)
                    normalized = self._failed_branch(branch_name, str(payload))
                else:
                    normalized = self._normalize_branch_payload(branch_name, payload)
                branches[branch_name] = normalized

        for name in self._BRANCH_NAMES:
            branch = branches.get(name)
            if not branch or branch.get("skipped", False):
                continue
            if not branch.get("success", False):
                errors.append(
                    {
                        "branch": name,
                        "error": str(branch.get("error", "") or "unknown error"),
                    }
                )

        success = any(
            bool(branches.get(name, {}).get("success", False))
            for name in self._BRANCH_NAMES
        )
        mixed_results: Optional[List[Dict[str, Any]]] = None
        if mix:
            mixed_results = self._build_mixed_results(branches=branches, top_k=safe_mix_top_k)

        payload: Dict[str, Any] = {
            "success": success,
            "query_type": "aggregate",
            "query": clean_query,
            "top_k": safe_top_k,
            "mix": bool(mix),
            "mix_top_k": safe_mix_top_k,
            "branches": branches,
            "errors": errors,
            "summary": self._build_summary(branches),
        }
        if mixed_results is not None:
            payload["mixed_results"] = mixed_results

        payload["content"] = self._build_content(
            query=clean_query,
            branches=branches,
            errors=errors,
            mixed_results=mixed_results,
        )
        return payload
|
||||
182
plugins/A_memorix/core/utils/episode_retrieval_service.py
Normal file
182
plugins/A_memorix/core/utils/episode_retrieval_service.py
Normal file
@@ -0,0 +1,182 @@
|
||||
"""Episode hybrid retrieval service."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from ..retrieval import DualPathRetriever, TemporalQueryOptions
|
||||
|
||||
logger = get_logger("A_Memorix.EpisodeRetrievalService")
|
||||
|
||||
|
||||
class EpisodeRetrievalService:
    """Hybrid episode retrieval backed by lexical rows and evidence projection.

    Up to three ranked branches are produced per query and fused with
    weighted Reciprocal Rank Fusion:

    * ``lexical``: episodes returned by ``metadata_store.query_episodes``;
    * ``paragraph_evidence``: episodes projected from retrieved paragraphs;
    * ``relation_evidence``: episodes projected from retrieved relations.
    """

    # RRF smoothing constant and per-branch weights used by ``_fuse_branches``.
    _RRF_K = 60.0
    _BRANCH_WEIGHTS = {
        "lexical": 1.0,
        "paragraph_evidence": 1.0,
        "relation_evidence": 0.85,
    }

    def __init__(
        self,
        *,
        metadata_store: Any,
        retriever: Optional["DualPathRetriever"] = None,
    ) -> None:
        """Keep the metadata store and the optional evidence retriever."""
        self.metadata_store = metadata_store
        self.retriever = retriever

    async def query(
        self,
        *,
        query: str = "",
        top_k: int = 5,
        time_from: Optional[float] = None,
        time_to: Optional[float] = None,
        person: Optional[str] = None,
        source: Optional[str] = None,
        include_paragraphs: bool = False,
    ) -> List[Dict[str, Any]]:
        """Return up to ``top_k`` episodes for the given filters.

        Without query text (or without a retriever) only the lexical branch is
        used.  With both, the retriever's paragraph/relation hits are projected
        onto episodes and fused with the lexical rows.  A retriever failure is
        logged and degrades to lexical-only results.

        Args:
            query: Free-text query; surrounding whitespace is ignored.
            top_k: Maximum number of episodes returned (clamped to >= 1).
            time_from: Lower time bound, forwarded to store and retriever.
            time_to: Upper time bound, forwarded to store and retriever.
            person: Person filter, forwarded to store and retriever.
            source: Source filter, forwarded to store and retriever.
            include_paragraphs: When true, each episode gets a ``paragraphs``
                entry holding up to 50 paragraphs from the store.
        """
        clean_query = str(query or "").strip()
        safe_top_k = max(1, int(top_k))
        # Over-fetch candidates so that fusion has material to re-rank.
        candidate_k = max(30, safe_top_k * 6)

        branches: Dict[str, List[Dict[str, Any]]] = {
            "lexical": self.metadata_store.query_episodes(
                query=clean_query,
                time_from=time_from,
                time_to=time_to,
                person=person,
                source=source,
                limit=(candidate_k if clean_query else safe_top_k),
            )
        }

        if clean_query and self.retriever is not None:
            try:
                temporal = TemporalQueryOptions(
                    time_from=time_from,
                    time_to=time_to,
                    person=person,
                    source=source,
                )
                results = await self.retriever.retrieve(
                    query=clean_query,
                    top_k=candidate_k,
                    temporal=temporal,
                )
            except Exception as exc:
                logger.warning("episode evidence retrieval failed, fallback to lexical only: %s", exc)
            else:
                # Store errors during projection are not swallowed.
                self._add_evidence_branches(branches, results, source=source)

        fused = self._fuse_branches(branches, top_k=safe_top_k)
        if include_paragraphs:
            for item in fused:
                item["paragraphs"] = self.metadata_store.get_episode_paragraphs(
                    episode_id=str(item.get("episode_id") or ""),
                    limit=50,
                )
        return fused

    @staticmethod
    def _first_ranks_by_type(results: Any) -> Tuple[Dict[str, int], Dict[str, int]]:
        """Map each paragraph/relation hash to the best (first) rank it got.

        Returns ``(paragraph_ranks, relation_ranks)``.  Ranks are 1-based
        positions in ``results``; entries without a ``hash_value`` or with any
        other ``result_type`` are ignored.
        """
        paragraph_ranks: Dict[str, int] = {}
        relation_ranks: Dict[str, int] = {}
        for rank, item in enumerate(results, start=1):
            hash_value = str(getattr(item, "hash_value", "") or "").strip()
            result_type = str(getattr(item, "result_type", "") or "").strip().lower()
            if not hash_value:
                continue
            if result_type == "paragraph":
                paragraph_ranks.setdefault(hash_value, rank)
            elif result_type == "relation":
                relation_ranks.setdefault(hash_value, rank)
        return paragraph_ranks, relation_ranks

    def _add_evidence_branches(
        self,
        branches: Dict[str, List[Dict[str, Any]]],
        results: Any,
        *,
        source: Optional[str],
    ) -> None:
        """Project retriever hits onto episodes and add the evidence branches."""
        paragraph_ranks, relation_ranks = self._first_ranks_by_type(results)

        if paragraph_ranks:
            paragraph_rows = self.metadata_store.get_episode_rows_by_paragraph_hashes(
                list(paragraph_ranks),
                source=source,
            )
            if paragraph_rows:
                branches["paragraph_evidence"] = self._rank_projected_rows(
                    paragraph_rows,
                    rank_map=paragraph_ranks,
                    support_key="matched_paragraph_hashes",
                )

        if relation_ranks:
            relation_rows = self.metadata_store.get_episode_rows_by_relation_hashes(
                list(relation_ranks),
                source=source,
            )
            if relation_rows:
                branches["relation_evidence"] = self._rank_projected_rows(
                    relation_rows,
                    rank_map=relation_ranks,
                    support_key="matched_relation_hashes",
                )

    @staticmethod
    def _rank_projected_rows(
        rows: List[Dict[str, Any]],
        *,
        rank_map: Dict[str, int],
        support_key: str,
    ) -> List[Dict[str, Any]]:
        """Order projected episode rows by the strength of their evidence.

        Sort keys, in order: best rank of any supporting hash (rows without a
        ranked hash go last), larger match count, newer ``updated_at``, then
        ``episode_id``.  The rows are shallow-copied; the input is untouched.

        Args:
            rows: Episode rows; the hashes are read from ``row[support_key]``.
            rank_map: Hash -> retrieval rank (lower is better).
            support_key: ``"matched_paragraph_hashes"`` or
                ``"matched_relation_hashes"``.
        """
        sentinel = 10**9
        # The match-count tie-breaker follows the branch being ranked.
        # NOTE(review): assumes relation rows expose ``matched_relation_count``
        # (the field ``_fuse_branches`` strips) - confirm against the store.
        # When it is absent the paragraph count is used.
        count_key = (
            "matched_relation_count"
            if support_key == "matched_relation_hashes"
            else "matched_paragraph_count"
        )

        def _first_support_rank(item: Dict[str, Any]) -> int:
            support_hashes = [str(x or "").strip() for x in (item.get(support_key) or [])]
            ranks = [int(rank_map[h]) for h in support_hashes if h in rank_map]
            return min(ranks) if ranks else sentinel

        def _match_count(item: Dict[str, Any]) -> int:
            return int(item.get(count_key) or item.get("matched_paragraph_count") or 0)

        ranked = [dict(item) for item in rows]
        ranked.sort(
            key=lambda item: (
                _first_support_rank(item),
                -_match_count(item),
                -float(item.get("updated_at") or 0.0),
                str(item.get("episode_id") or ""),
            )
        )
        return ranked

    def _fuse_branches(
        self,
        branches: Dict[str, List[Dict[str, Any]]],
        *,
        top_k: int,
    ) -> List[Dict[str, Any]]:
        """Fuse ranked branches into one episode list using weighted RRF.

        Each row adds ``weight / (_RRF_K + rank)`` to its episode's score.
        Branches with an unknown or non-positive weight are ignored.  The
        first row seen for an episode supplies its fields, minus the internal
        ``matched_*`` bookkeeping.  Output order: score descending, then newer
        ``updated_at``, then ``episode_id``; at most ``max(1, top_k)`` items.
        """
        bucket: Dict[str, Dict[str, Any]] = {}
        for branch_name, rows in branches.items():
            weight = float(self._BRANCH_WEIGHTS.get(branch_name, 0.0) or 0.0)
            if weight <= 0.0:
                continue
            for rank, row in enumerate(rows, start=1):
                episode_id = str(row.get("episode_id", "") or "").strip()
                if not episode_id:
                    continue
                if episode_id not in bucket:
                    payload = dict(row)
                    for internal_key in (
                        "matched_paragraph_hashes",
                        "matched_relation_hashes",
                        "matched_paragraph_count",
                        "matched_relation_count",
                    ):
                        payload.pop(internal_key, None)
                    payload["_fusion_score"] = 0.0
                    bucket[episode_id] = payload
                bucket[episode_id]["_fusion_score"] = float(
                    bucket[episode_id].get("_fusion_score", 0.0)
                ) + weight / (self._RRF_K + float(rank))

        out = list(bucket.values())
        out.sort(
            key=lambda item: (
                -float(item.get("_fusion_score", 0.0)),
                -float(item.get("updated_at") or 0.0),
                str(item.get("episode_id") or ""),
            )
        )
        # The score is internal and must not leak to callers.
        for item in out:
            item.pop("_fusion_score", None)
        return out[: max(1, int(top_k))]
|
||||
129
plugins/A_memorix/core/utils/hash.py
Normal file
129
plugins/A_memorix/core/utils/hash.py
Normal file
@@ -0,0 +1,129 @@
|
||||
"""
|
||||
哈希工具模块
|
||||
|
||||
提供文本哈希计算功能,用于唯一标识和去重。
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import re
|
||||
from typing import Union
|
||||
|
||||
|
||||
def compute_hash(text: str, hash_type: str = "sha256") -> str:
    """Return the hexadecimal digest of ``text`` encoded as UTF-8.

    Args:
        text: Input text.
        hash_type: Name of any fixed-length algorithm that ``hashlib``
            guarantees on every platform (``"sha256"``, ``"md5"``, ``"sha1"``,
            ``"blake2b"``, ...).

    Returns:
        The lowercase hex digest string.

    Raises:
        ValueError: If ``hash_type`` is not a supported algorithm name.
    """
    supported = (
        isinstance(hash_type, str)
        and hash_type in hashlib.algorithms_guaranteed
        # SHAKE digests need an explicit output length, which this API cannot pass.
        and not hash_type.startswith("shake_")
    )
    if not supported:
        raise ValueError(f"不支持的哈希算法: {hash_type}")
    return hashlib.new(hash_type, text.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def normalize_text(text: str) -> str:
    """Normalize text before hashing.

    Steps, in order:

    1. strip leading and trailing whitespace;
    2. convert ``\\r\\n`` and lone ``\\r`` line endings to ``\\n``;
    3. collapse every run of whitespace other than ``\\n`` into one space.

    Note that step 3 also replaces tabs with a space (only newlines survive),
    and that no other characters are removed.

    Args:
        text: Input text.

    Returns:
        The normalized text.
    """
    trimmed = text.strip()
    unix_newlines = trimmed.replace("\r\n", "\n").replace("\r", "\n")
    return re.sub(r"[^\S\n]+", " ", unix_newlines)
|
||||
|
||||
|
||||
def compute_paragraph_hash(paragraph: str) -> str:
    """Hash a paragraph after normalizing it with ``normalize_text``.

    Args:
        paragraph: Paragraph text.

    Returns:
        The digest produced by ``compute_hash`` with its default algorithm
        (the value that goes after a ``paragraph-`` prefix).
    """
    return compute_hash(normalize_text(paragraph))
|
||||
|
||||
|
||||
def compute_entity_hash(entity: str) -> str:
    """Hash an entity name case-insensitively.

    The name is stripped and lower-cased first, so names that differ only in
    case or surrounding whitespace share one hash.

    Args:
        entity: Entity name.

    Returns:
        The digest produced by ``compute_hash`` with its default algorithm
        (the value that goes after an ``entity-`` prefix).
    """
    return compute_hash(entity.strip().lower())
|
||||
|
||||
|
||||
def compute_relation_hash(relation: tuple) -> str:
    """Hash a relation triple.

    The relation is converted to a tuple and that tuple's ``str()`` form is
    hashed, so the digest depends on the ``repr`` of each element.

    Args:
        relation: Relation as ``(subject, predicate, object)``.

    Returns:
        The digest produced by ``compute_hash`` with its default algorithm
        (the value that goes after a ``relation-`` prefix).
    """
    as_tuple = tuple(relation)
    return compute_hash(str(as_tuple))
|
||||
|
||||
|
||||
def format_hash_key(hash_type: str, hash_value: str) -> str:
    """Join a type prefix and a hash value into a storage key.

    Args:
        hash_type: Type prefix (``paragraph``, ``entity``, ``relation``).
        hash_value: Hash value.

    Returns:
        The key ``"<hash_type>-<hash_value>"``, e.g. ``paragraph-abc123...``.
    """
    return "{}-{}".format(hash_type, hash_value)
|
||||
|
||||
|
||||
def parse_hash_key(key: str) -> tuple[str, str]:
    """Split a formatted hash key at its first ``-``.

    Args:
        key: Formatted key such as ``paragraph-abc123...``.

    Returns:
        A ``(type, hash value)`` tuple.

    Raises:
        ValueError: If ``key`` contains no ``-`` separator.
    """
    prefix, separator, digest = key.partition("-")
    if not separator:
        raise ValueError(f"无效的哈希键格式: {key}")
    return prefix, digest
|
||||
110
plugins/A_memorix/core/utils/import_payloads.py
Normal file
110
plugins/A_memorix/core/utils/import_payloads.py
Normal file
@@ -0,0 +1,110 @@
|
||||
"""Shared import payload normalization helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from ..storage import KnowledgeType, resolve_stored_knowledge_type
|
||||
from .time_parser import normalize_time_meta
|
||||
|
||||
|
||||
def _normalize_entities(raw_entities: Any) -> List[str]:
|
||||
if not isinstance(raw_entities, list):
|
||||
return []
|
||||
out: List[str] = []
|
||||
seen = set()
|
||||
for item in raw_entities:
|
||||
name = str(item or "").strip()
|
||||
if not name:
|
||||
continue
|
||||
key = name.lower()
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
out.append(name)
|
||||
return out
|
||||
|
||||
|
||||
def _normalize_relations(raw_relations: Any) -> List[Dict[str, str]]:
|
||||
if not isinstance(raw_relations, list):
|
||||
return []
|
||||
out: List[Dict[str, str]] = []
|
||||
for item in raw_relations:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
subject = str(item.get("subject", "")).strip()
|
||||
predicate = str(item.get("predicate", "")).strip()
|
||||
obj = str(item.get("object", "")).strip()
|
||||
if not (subject and predicate and obj):
|
||||
continue
|
||||
out.append(
|
||||
{
|
||||
"subject": subject,
|
||||
"predicate": predicate,
|
||||
"object": obj,
|
||||
}
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def normalize_paragraph_import_item(
    item: Any,
    *,
    default_source: str,
) -> Dict[str, Any]:
    """Normalize one paragraph import item from text/json payloads.

    Args:
        item: Either the paragraph text itself or a dict with at least a
            ``content`` key.  A dict may also carry ``knowledge_type`` (or
            ``type``), ``source``, ``entities``, ``relations``, flat time
            fields and a nested ``time_meta`` dict, whose entries override
            the flat time fields.
        default_source: Source used when the item provides none.

    Returns:
        A dict with the keys ``content``, ``knowledge_type``, ``source``,
        ``time_meta`` (``None`` when the normalized value is empty),
        ``entities`` and ``relations``.

    Raises:
        ValueError: If ``item`` is neither a string nor a dict with a
            ``content`` key, or if a dict item has blank content.
    """
    if isinstance(item, str):
        text = str(item)
        resolved = resolve_stored_knowledge_type(None, content=text)
        return {
            "content": text,
            "knowledge_type": resolved.value,
            "source": str(default_source or "").strip(),
            "time_meta": None,
            "entities": [],
            "relations": [],
        }

    if not (isinstance(item, dict) and "content" in item):
        raise ValueError("段落项必须为字符串或包含 content 的对象")

    text = str(item.get("content", "") or "")
    if not text.strip():
        raise ValueError("段落 content 不能为空")

    # Flat time fields first; a nested "time_meta" dict takes precedence.
    time_fields = (
        "event_time",
        "event_time_start",
        "event_time_end",
        "time_range",
        "time_granularity",
        "time_confidence",
    )
    merged_time_meta = {field: item.get(field) for field in time_fields}
    nested_time_meta = item.get("time_meta")
    if isinstance(nested_time_meta, dict):
        merged_time_meta.update(nested_time_meta)

    # "type" is accepted as a fallback spelling of "knowledge_type".
    declared_type = item.get("knowledge_type")
    if declared_type is None:
        declared_type = item.get("type")
    resolved = resolve_stored_knowledge_type(declared_type, content=text)

    # A whitespace-only item source falls back to the default source as well.
    origin = (
        str(item.get("source") or default_source or "").strip()
        or str(default_source or "").strip()
    )

    time_meta = normalize_time_meta(merged_time_meta)
    return {
        "content": text,
        "knowledge_type": resolved.value,
        "source": origin,
        "time_meta": time_meta or None,
        "entities": _normalize_entities(item.get("entities")),
        "relations": _normalize_relations(item.get("relations")),
    }
|
||||
|
||||
|
||||
def normalize_summary_knowledge_type(value: Any) -> KnowledgeType:
    """Resolve a configured summary knowledge type.

    Delegates to ``resolve_stored_knowledge_type`` with empty content, so the
    result is determined by ``value`` alone.
    """
    resolved = resolve_stored_knowledge_type(value, content="")
    return resolved
|
||||
84
plugins/A_memorix/core/utils/io.py
Normal file
84
plugins/A_memorix/core/utils/io.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""
|
||||
IO Utilities
|
||||
|
||||
提供原子文件写入等IO辅助功能。
|
||||
"""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import contextlib
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
@contextlib.contextmanager
def atomic_write(file_path: Union[str, Path], mode: str = "w", encoding: Union[str, None] = None, **kwargs):
    """Context manager that writes a file via a temp file plus rename.

    How it works:

    1. The caller writes to ``<file_path>.tmp``.
    2. When the ``with`` body finishes, the data is flushed and fsynced and
       the temp file is moved over the target with ``os.replace``.
    3. On any failure, including ``KeyboardInterrupt``, the temp file is
       removed, the target is left untouched and the error propagates.

    The parent directory is created if needed.

    Args:
        file_path: Destination file path.
        mode: Mode passed to ``open()`` (``"w"``, ``"wb"``, ...).
        encoding: Text encoding passed to ``open()``.
        **kwargs: Further keyword arguments passed to ``open()``.

    Yields:
        The open temp-file object to write to.
    """
    path = Path(file_path)
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = path.with_suffix(path.suffix + ".tmp")

    try:
        with open(tmp_path, mode, encoding=encoding, **kwargs) as f:
            yield f
            # Push the data to disk before the rename makes it visible.
            f.flush()
            os.fsync(f.fileno())

        # NOTE: on Windows os.replace can fail while another process holds
        # the target file open.
        os.replace(tmp_path, path)
    except BaseException:
        # BaseException so that an interrupt does not leave a stale temp file.
        with contextlib.suppress(OSError):
            tmp_path.unlink()
        raise
|
||||
|
||||
@contextlib.contextmanager
def atomic_save_path(file_path: Union[str, Path]):
    """Provide a temp path for atomic saving through path-only APIs (e.g. Faiss).

    The caller saves to the yielded path.  When the ``with`` body finishes
    and the temp file exists, it is moved over the target with
    ``os.replace``.  On any failure, including ``KeyboardInterrupt``, the
    temp file is removed and the error propagates.  The parent directory is
    created if needed.

    Args:
        file_path: Final destination path.

    Yields:
        The temp file path (``<file_path>.tmp``) as a ``str``.
    """
    path = Path(file_path)
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = path.with_suffix(path.suffix + ".tmp")

    try:
        yield str(tmp_path)

        # The body may legitimately write nothing; then there is nothing to move.
        if tmp_path.exists():
            os.replace(tmp_path, path)
    except BaseException:
        # BaseException so that an interrupt does not leave a stale temp file.
        with contextlib.suppress(OSError):
            tmp_path.unlink()
        raise
|
||||
89
plugins/A_memorix/core/utils/matcher.py
Normal file
89
plugins/A_memorix/core/utils/matcher.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""
|
||||
高效文本匹配工具模块
|
||||
|
||||
实现 Aho-Corasick 算法用于多模式匹配。
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Tuple, Set, Any
|
||||
from collections import deque
|
||||
|
||||
|
||||
class AhoCorasick:
    """Aho-Corasick automaton for matching many patterns in one text pass.

    Usage: ``add_pattern`` for each pattern, then ``search`` / ``find_all``.
    Calling ``build`` explicitly is optional: failure links are rebuilt
    automatically on the next search whenever patterns were added since the
    last build.
    """

    def __init__(self):
        # next_states[state][char] = next_state (trie edges; state 0 is the root)
        self.next_states: List[Dict[str, int]] = [{}]
        # fail[state] = state to fall back to on a mismatch
        self.fail: List[int] = [0]
        # output[state] = set of patterns ending at this state
        self.output: List[Set[str]] = [set()]
        self.patterns: Set[str] = set()
        # True while the trie holds patterns not yet covered by the fail links.
        self._dirty = False

    def add_pattern(self, pattern: str):
        """Insert ``pattern`` into the trie; empty patterns are ignored."""
        if not pattern:
            return
        self.patterns.add(pattern)
        state = 0
        for char in pattern:
            if char not in self.next_states[state]:
                new_state = len(self.next_states)
                self.next_states[state][char] = new_state
                self.next_states.append({})
                self.fail.append(0)
                self.output.append(set())
            state = self.next_states[state][char]
        self.output[state].add(pattern)
        self._dirty = True

    def build(self):
        """Compute the failure links with a breadth-first walk of the trie."""
        queue = deque()
        # Depth-1 states always fall back to the root.
        for state in self.next_states[0].values():
            queue.append(state)
            self.fail[state] = 0

        while queue:
            r = queue.popleft()
            for char, s in self.next_states[r].items():
                queue.append(s)
                # Follow the parent's failure chain until a state with an
                # outgoing edge for ``char`` is found (or the root is reached).
                state = self.fail[r]
                while char not in self.next_states[state] and state != 0:
                    state = self.fail[state]
                self.fail[s] = self.next_states[state].get(char, 0)
                # Patterns ending at the fallback state also end here.
                self.output[s].update(self.output[self.fail[s]])
        self._dirty = False

    def search(self, text: str) -> List[Tuple[int, str]]:
        """Find every occurrence of every pattern in ``text``.

        Returns:
            A list of ``(end index, matched pattern)`` tuples, where the end
            index is the position of the match's last character.
        """
        # Without up-to-date failure links, overlapping matches would be
        # missed silently, so build on demand.
        if getattr(self, "_dirty", False):
            self.build()

        state = 0
        results = []
        for i, char in enumerate(text):
            while char not in self.next_states[state] and state != 0:
                state = self.fail[state]
            state = self.next_states[state].get(char, 0)
            for pattern in self.output[state]:
                results.append((i, pattern))
        return results

    def find_all(self, text: str) -> Dict[str, int]:
        """Count the occurrences of each pattern in ``text``.

        Returns:
            A ``{pattern: occurrence count}`` dict holding only the patterns
            that matched at least once.
        """
        stats: Dict[str, int] = {}
        for _, pattern in self.search(text):
            stats[pattern] = stats.get(pattern, 0) + 1
        return stats
|
||||
189
plugins/A_memorix/core/utils/monitor.py
Normal file
189
plugins/A_memorix/core/utils/monitor.py
Normal file
@@ -0,0 +1,189 @@
|
||||
"""
|
||||
内存监控模块
|
||||
|
||||
提供内存使用监控和预警功能。
|
||||
"""
|
||||
|
||||
import gc
|
||||
import threading
|
||||
import time
|
||||
from typing import Callable, Optional
|
||||
|
||||
try:
|
||||
import psutil
|
||||
HAS_PSUTIL = True
|
||||
except ImportError:
|
||||
HAS_PSUTIL = False
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("A_Memorix.MemoryMonitor")
|
||||
|
||||
|
||||
class MemoryMonitor:
    """
    Background watchdog for the memory usage of the current process.

    Features:
    - Polls memory usage on a daemon thread.
    - Logs a warning and invokes registered callbacks once usage reaches
      ``warning_threshold`` of ``max_memory_mb``.
    - Optionally runs ``gc.collect()`` when the threshold is reached, and
      periodically while usage stays below it.
    """

    # Minimum number of seconds between periodic collections while memory
    # usage stays below the warning threshold.
    PERIODIC_GC_INTERVAL = 60.0

    def __init__(
        self,
        max_memory_mb: int,
        warning_threshold: float = 0.9,
        check_interval: float = 10.0,
        enable_auto_gc: bool = True,
    ):
        """
        Initialize the monitor (no thread is started yet).

        Args:
            max_memory_mb: Memory limit in MB.
            warning_threshold: Fraction of the limit (0-1) at which warnings
                fire; the default 0.9 means 90%.
            check_interval: Seconds between two checks.
            enable_auto_gc: Whether the monitor may trigger garbage collection.
        """
        self.max_memory_mb = max_memory_mb
        self.warning_threshold = warning_threshold
        self.check_interval = check_interval
        self.enable_auto_gc = enable_auto_gc

        self._running = False
        self._thread: Optional[threading.Thread] = None
        self._callbacks: list[Callable[[float, float], None]] = []
        # Signalled by stop() so the loop wakes up immediately instead of
        # sleeping out the rest of the check interval.
        self._stop_event = threading.Event()
        self._last_periodic_gc = time.monotonic()

    def start(self):
        """Start the monitoring thread (no-op if already running or psutil is missing)."""
        if self._running:
            logger.warning("内存监控已在运行")
            return

        if not HAS_PSUTIL:
            logger.warning("psutil 未安装,内存监控功能不可用")
            return

        # A fresh event per run: a previous loop that is still winding down
        # keeps its own (already set) event and cannot be revived by a restart.
        self._stop_event = threading.Event()
        self._last_periodic_gc = time.monotonic()
        self._running = True
        self._thread = threading.Thread(target=self._monitor_loop, daemon=True)
        self._thread.start()
        logger.info(f"内存监控已启动 (限制: {self.max_memory_mb}MB)")

    def stop(self):
        """Stop the monitoring thread and wait up to 5 seconds for it to exit."""
        if not self._running:
            return

        self._running = False
        self._stop_event.set()
        thread = self._thread
        # Joining the current thread raises RuntimeError; that happens when
        # stop() is invoked from inside a registered callback.
        if thread and thread is not threading.current_thread():
            thread.join(timeout=5.0)
        logger.info("内存监控已停止")

    def register_callback(self, callback: Callable[[float, float], None]):
        """
        Register a callback invoked when usage reaches the warning threshold.

        Args:
            callback: Callable receiving ``(current_usage_mb, limit_mb)``.
        """
        self._callbacks.append(callback)

    def get_current_memory_mb(self) -> float:
        """
        Return the memory usage of the current process in MB.

        With psutil this is the resident set size. Without psutil it is only a
        rough estimate: the summed shallow sizes of the objects tracked by the
        garbage collector.

        Returns:
            Memory usage in MB.
        """
        if not HAS_PSUTIL:
            # Degraded path: approximate from GC-tracked objects.
            import sys

            tracked_bytes = sum(sys.getsizeof(obj, 0) for obj in gc.get_objects())
            return tracked_bytes / 1024 / 1024

        process = psutil.Process()
        return process.memory_info().rss / 1024 / 1024

    def get_memory_usage_ratio(self) -> float:
        """
        Return current usage divided by the configured limit.

        Returns:
            ``current / max_memory_mb``, or 0 when the limit is not positive.
        """
        current = self.get_current_memory_mb()
        return current / self.max_memory_mb if self.max_memory_mb > 0 else 0

    def _run_check(self):
        """Perform one measurement and react to it (warn, call back, collect)."""
        current_mb = self.get_current_memory_mb()
        ratio = current_mb / self.max_memory_mb if self.max_memory_mb > 0 else 0

        if ratio >= self.warning_threshold:
            logger.warning(
                f"内存使用率过高: {current_mb:.1f}MB / {self.max_memory_mb}MB "
                f"({ratio*100:.1f}%)"
            )

            # Iterate over a snapshot so callbacks may register further callbacks.
            for callback in tuple(self._callbacks):
                try:
                    callback(current_mb, self.max_memory_mb)
                except Exception as e:
                    logger.error(f"内存回调执行失败: {e}")

            if self.enable_auto_gc:
                before = self.get_current_memory_mb()
                gc.collect()
                after = self.get_current_memory_mb()
                freed = before - after
                if freed > 1:  # only report when more than 1 MB was released
                    logger.info(f"垃圾回收释放: {freed:.1f}MB")
                self._last_periodic_gc = time.monotonic()

        elif self.enable_auto_gc:
            # Periodic collection even below the threshold. Elapsed time is
            # tracked explicitly; testing the wall clock for a multiple of 60
            # only fires when a check happens to land on such a second.
            now = time.monotonic()
            if now - self._last_periodic_gc >= self.PERIODIC_GC_INTERVAL:
                gc.collect()
                self._last_periodic_gc = now

    def _monitor_loop(self):
        """Thread body: check, then wait ``check_interval`` seconds or until stopped."""
        stop_event = self._stop_event
        while self._running and not stop_event.is_set():
            try:
                self._run_check()
            except Exception as e:
                logger.error(f"内存监控出错: {e}")

            # Interruptible wait until the next check.
            stop_event.wait(self.check_interval)

    def __enter__(self):
        """Context-manager entry: start monitoring and return the monitor."""
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context-manager exit: stop monitoring."""
        self.stop()
|
||||
|
||||
|
||||
def get_memory_info() -> dict:
    """
    Collect system-wide and per-process memory figures.

    Returns:
        A dict with ``system_total_gb``, ``system_available_gb``,
        ``system_usage_percent``, ``process_mb`` and ``process_percent``;
        or ``{"error": ...}`` when psutil is unavailable or a query fails.
    """
    if not HAS_PSUTIL:
        return {"error": "psutil 未安装"}

    try:
        mem = psutil.virtual_memory()
        # Read RSS once so process_mb and process_percent describe the same sample.
        rss_bytes = psutil.Process().memory_info().rss

        return {
            "system_total_gb": mem.total / 1024 / 1024 / 1024,
            "system_available_gb": mem.available / 1024 / 1024 / 1024,
            "system_usage_percent": mem.percent,
            "process_mb": rss_bytes / 1024 / 1024,
            "process_percent": (rss_bytes / mem.total) * 100,
        }
    except Exception as e:
        return {"error": str(e)}
|
||||
165
plugins/A_memorix/core/utils/path_fallback_service.py
Normal file
165
plugins/A_memorix/core/utils/path_fallback_service.py
Normal file
@@ -0,0 +1,165 @@
|
||||
"""Shared path-fallback helpers for search post-processing."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
from ..retrieval.dual_path import RetrievalResult
|
||||
|
||||
|
||||
def extract_entities(query: str, graph_store: Any) -> List[str]:
    """Extract graph nodes mentioned in a query using greedy n-gram matching.

    The query is split on whitespace after replacing ``?``, ``!`` and ``.``
    with spaces. Spans of 4 down to 1 tokens are looked up with
    ``graph_store.find_node(span, ignore_case=True)``; longer spans win and
    tokens that are part of a match are not reused for shorter spans.

    Args:
        query: Free-text query.
        graph_store: Object exposing ``find_node``; a falsy value disables
            extraction.

    Returns:
        The distinct matched nodes, ordered by where they first appear in the
        query. No upper bound is applied to the number of nodes.
    """
    if not graph_store:
        return []

    text = str(query or "").strip()
    if not text:
        return []

    # Keep the heuristic aligned with previous legacy behavior.
    tokens = (
        text.replace("?", " ")
        .replace("!", " ")
        .replace(".", " ")
        .split()
    )
    if not tokens:
        return []

    # node -> index of the first token of its earliest match
    first_position: Dict[Any, int] = {}
    consumed = set()
    max_n = min(4, len(tokens))

    for size in range(max_n, 0, -1):
        for start in range(len(tokens) - size + 1):
            window = range(start, start + size)
            if any(idx in consumed for idx in window):
                continue
            span = " ".join(tokens[start : start + size])
            matched_node = graph_store.find_node(span, ignore_case=True)
            if not matched_node:
                continue
            consumed.update(window)
            if matched_node not in first_position or start < first_position[matched_node]:
                first_position[matched_node] = start

    # A set would yield an arbitrary order; callers use the first two entries
    # as start/end nodes, so return the nodes in query order instead.
    return sorted(first_position, key=first_position.__getitem__)
|
||||
|
||||
|
||||
def _relation_confidence(rel: Dict[str, Any]) -> float:
    """Return a relation's confidence as a float, treating missing/invalid values as 1.0."""
    value = rel.get("confidence", 1.0)
    if value is None:
        return 1.0
    try:
        return float(value)
    except (TypeError, ValueError):
        return 1.0


def find_paths_between_entities(
    start_node: str,
    end_node: str,
    graph_store: Any,
    metadata_store: Any,
    *,
    max_depth: int = 3,
    max_paths: int = 5,
) -> List[Dict[str, Any]]:
    """Find and enrich indirect paths between two nodes.

    Paths come from ``graph_store.find_paths``. Each hop is labelled with the
    predicate of the highest-confidence relation returned by
    ``metadata_store.get_relations``, looked up first in traversal direction
    and then reversed. Hops without any stored relation are labelled
    ``related`` and drawn in traversal direction.

    Args:
        start_node: Node the paths start from.
        end_node: Node the paths end at.
        graph_store: Object exposing ``find_paths``.
        metadata_store: Object exposing ``get_relations``.
        max_depth: Forwarded to ``graph_store.find_paths``.
        max_paths: Forwarded to ``graph_store.find_paths``.

    Returns:
        A list of ``{"nodes": [...], "description": "A -[p]-> B <-[q]- C"}``
        dicts; empty when a store is missing, path search fails, or no path
        exists.
    """
    if not graph_store or not metadata_store:
        return []

    try:
        paths = graph_store.find_paths(
            start_node,
            end_node,
            max_depth=max_depth,
            max_paths=max_paths,
        )
    except Exception:
        # Best-effort helper: a failing path search yields no fallback results.
        return []

    if not paths:
        return []

    # Keyed by the ordered (from, to) pair: the stored arrow direction is
    # relative to the traversal direction, so it must not be shared with the
    # reverse traversal of the same edge.
    edge_cache: Dict[Tuple[str, str], Tuple[str, str]] = {}
    formatted_paths: List[Dict[str, Any]] = []

    for path_nodes in paths:
        if not isinstance(path_nodes, Sequence) or len(path_nodes) < 2:
            continue

        path_desc: List[str] = []
        for i in range(len(path_nodes) - 1):
            u = str(path_nodes[i])
            v = str(path_nodes[i + 1])

            cache_key = (u, v)
            if cache_key in edge_cache:
                pred, direction = edge_cache[cache_key]
            else:
                pred = "related"
                direction = "->"
                rels = metadata_store.get_relations(subject=u, object=v)
                if not rels:
                    rels = metadata_store.get_relations(subject=v, object=u)
                    # Flip the arrow only when a reverse relation really exists.
                    if rels:
                        direction = "<-"
                if rels:
                    best_rel = max(rels, key=_relation_confidence)
                    pred = str(best_rel.get("predicate", "related") or "related")
                edge_cache[cache_key] = (pred, direction)

            step_str = f"-[{pred}]->" if direction == "->" else f"<-[{pred}]-"
            path_desc.append(step_str)

        segments = [str(path_nodes[0])]
        for i, step in enumerate(path_desc):
            segments.append(f"{step} {path_nodes[i + 1]}")

        formatted_paths.append(
            {
                "nodes": list(path_nodes),
                "description": " ".join(segments),
            }
        )

    return formatted_paths
|
||||
|
||||
|
||||
def find_paths_from_query(
    query: str,
    graph_store: Any,
    metadata_store: Any,
    *,
    max_depth: int = 3,
    max_paths: int = 5,
) -> List[Dict[str, Any]]:
    """Resolve indirect paths between the two entities named in a query.

    Entities are extracted with ``extract_entities``. Paths are only looked up
    when exactly two entities are found; any other count returns an empty list.

    Args:
        query: Free-text query.
        graph_store: Graph store passed on to extraction and path search.
        metadata_store: Metadata store passed on to path search.
        max_depth: Forwarded to ``find_paths_between_entities``.
        max_paths: Forwarded to ``find_paths_between_entities``.

    Returns:
        The result of ``find_paths_between_entities`` or ``[]``.
    """
    found = extract_entities(query, graph_store)
    if len(found) == 2:
        source_node, target_node = found
        return find_paths_between_entities(
            source_node,
            target_node,
            graph_store,
            metadata_store,
            max_depth=max_depth,
            max_paths=max_paths,
        )
    return []
|
||||
|
||||
|
||||
def to_retrieval_results(paths: Sequence[Dict[str, Any]]) -> List[RetrievalResult]:
    """Wrap path dicts as ``RetrievalResult`` objects for the unified pipeline.

    Items whose ``description`` is empty after stripping are dropped. Each
    remaining item gets the id ``path_<sha1 of the description>``, a fixed
    score of 0.95, result type ``relation`` and source ``graph_path``.

    Args:
        paths: Dicts carrying ``description`` and optionally ``nodes``.

    Returns:
        One ``RetrievalResult`` per non-empty path description, in input order.
    """

    def _as_result(text: str, nodes: Any) -> RetrievalResult:
        digest = hashlib.sha1(text.encode("utf-8")).hexdigest()
        return RetrievalResult(
            hash_value=f"path_{digest}",
            content=f"[Indirect Relation] {text}",
            score=0.95,
            result_type="relation",
            source="graph_path",
            metadata={
                "source": "graph_path",
                "is_indirect": True,
                "nodes": list(nodes),
            },
        )

    results: List[RetrievalResult] = []
    for entry in paths:
        text = str(entry.get("description", "")).strip()
        if text:
            results.append(_as_result(text, entry.get("nodes", [])))
    return results
|
||||
|
||||
495
plugins/A_memorix/core/utils/person_profile_service.py
Normal file
495
plugins/A_memorix/core/utils/person_profile_service.py
Normal file
@@ -0,0 +1,495 @@
|
||||
"""
|
||||
人物画像服务
|
||||
|
||||
主链路:
|
||||
person_id -> 用户名/别名 -> 图谱关系 + 向量证据 -> 证据总结画像 -> 快照版本化存储
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import PersonInfo
|
||||
|
||||
from ..embedding import EmbeddingAPIAdapter
|
||||
from ..retrieval import (
|
||||
DualPathRetriever,
|
||||
RetrievalStrategy,
|
||||
DualPathRetrieverConfig,
|
||||
SparseBM25Config,
|
||||
FusionConfig,
|
||||
GraphRelationRecallConfig,
|
||||
)
|
||||
from ..storage import MetadataStore, GraphStore, VectorStore
|
||||
|
||||
logger = get_logger("A_Memorix.PersonProfileService")
|
||||
|
||||
|
||||
class PersonProfileService:
    """Aggregate evidence about a person into a profile and keep snapshots fresh.

    Main flow: person_id -> names/aliases -> relation evidence + retrieved
    evidence -> evidence-based profile text -> snapshot stored through the
    metadata store.
    """

    def __init__(
        self,
        metadata_store: MetadataStore,
        graph_store: Optional[GraphStore] = None,
        vector_store: Optional[VectorStore] = None,
        embedding_manager: Optional[EmbeddingAPIAdapter] = None,
        sparse_index: Any = None,
        plugin_config: Optional[dict] = None,
        retriever: Optional[DualPathRetriever] = None,
    ):
        """Store the collaborators and build a retriever when none is supplied.

        Args:
            metadata_store: Store used for relations, paragraphs, snapshots
                and manual overrides.
            graph_store: Optional graph store for the retriever.
            vector_store: Optional vector store for the retriever.
            embedding_manager: Optional embedding adapter for the retriever.
            sparse_index: Optional sparse index handed to the retriever.
            plugin_config: Nested plugin configuration dict.
            retriever: Pre-built retriever; when falsy one is built on demand.
        """
        self.metadata_store = metadata_store
        self.graph_store = graph_store
        self.vector_store = vector_store
        self.embedding_manager = embedding_manager
        self.sparse_index = sparse_index
        self.plugin_config = plugin_config or {}
        self.retriever = retriever or self._build_retriever()

    def _cfg(self, key: str, default: Any = None) -> Any:
        """Read a dotted key (``"a.b.c"``) from the nested plugin config."""
        if not isinstance(self.plugin_config, dict):
            return default
        current: Any = self.plugin_config
        for part in key.split("."):
            if isinstance(current, dict) and part in current:
                current = current[part]
            else:
                return default
        return current

    @staticmethod
    def _coerce_confidence(value: Any, default: float) -> float:
        """Convert a confidence value to float; ``None``/invalid values become ``default``."""
        if value is None:
            return default
        try:
            return float(value)
        except (TypeError, ValueError):
            return default

    def _build_retriever(self) -> Optional[DualPathRetriever]:
        """Build a retriever on demand; returns ``None`` when a dependency is missing."""
        if not all(
            [
                self.vector_store is not None,
                self.graph_store is not None,
                self.metadata_store is not None,
                self.embedding_manager is not None,
            ]
        ):
            return None
        try:
            sparse_cfg_raw = self._cfg("retrieval.sparse", {}) or {}
            fusion_cfg_raw = self._cfg("retrieval.fusion", {}) or {}
            graph_recall_cfg_raw = self._cfg("retrieval.search.graph_recall", {}) or {}
            # Ignore malformed (non-dict) config sections.
            if not isinstance(sparse_cfg_raw, dict):
                sparse_cfg_raw = {}
            if not isinstance(fusion_cfg_raw, dict):
                fusion_cfg_raw = {}
            if not isinstance(graph_recall_cfg_raw, dict):
                graph_recall_cfg_raw = {}

            sparse_cfg = SparseBM25Config(**sparse_cfg_raw)
            fusion_cfg = FusionConfig(**fusion_cfg_raw)
            graph_recall_cfg = GraphRelationRecallConfig(**graph_recall_cfg_raw)
            config = DualPathRetrieverConfig(
                top_k_paragraphs=int(self._cfg("retrieval.top_k_paragraphs", 20)),
                top_k_relations=int(self._cfg("retrieval.top_k_relations", 10)),
                top_k_final=int(self._cfg("retrieval.top_k_final", 10)),
                alpha=float(self._cfg("retrieval.alpha", 0.5)),
                enable_ppr=bool(self._cfg("retrieval.enable_ppr", True)),
                ppr_alpha=float(self._cfg("retrieval.ppr_alpha", 0.85)),
                ppr_concurrency_limit=int(self._cfg("retrieval.ppr_concurrency_limit", 4)),
                enable_parallel=bool(self._cfg("retrieval.enable_parallel", True)),
                retrieval_strategy=RetrievalStrategy.DUAL_PATH,
                debug=bool(self._cfg("advanced.debug", False)),
                sparse=sparse_cfg,
                fusion=fusion_cfg,
                graph_recall=graph_recall_cfg,
            )
            return DualPathRetriever(
                vector_store=self.vector_store,
                graph_store=self.graph_store,
                metadata_store=self.metadata_store,
                embedding_manager=self.embedding_manager,
                sparse_index=self.sparse_index,
                config=config,
            )
        except Exception as e:
            logger.warning(f"初始化人物画像检索器失败,将只使用关系证据: {e}")
            return None

    @staticmethod
    def resolve_person_id(identifier: str) -> str:
        """Resolve a person_id from a person_id, a name/nickname or a group nickname.

        Returns:
            The resolved person_id, or ``""`` when nothing matches.
        """
        if not identifier:
            return ""
        key = str(identifier).strip()
        if not key:
            return ""

        # A 32-character hexadecimal string is accepted as a person_id as-is.
        if len(key) == 32 and all(ch in "0123456789abcdefABCDEF" for ch in key):
            return key.lower()

        try:
            record = (
                PersonInfo.select(PersonInfo.person_id)
                .where((PersonInfo.person_name == key) | (PersonInfo.nickname == key))
                .first()
            )
            if record and record.person_id:
                return str(record.person_id)
        except Exception as e:
            # Best-effort lookup: keep going, but leave a trace for debugging.
            logger.debug(f"按姓名/昵称解析 person_id 失败: key={key}, err={e}")

        try:
            record = (
                PersonInfo.select(PersonInfo.person_id)
                .where(PersonInfo.group_nick_name.contains(key))
                .first()
            )
            if record and record.person_id:
                return str(record.person_id)
        except Exception as e:
            logger.debug(f"按群昵称解析 person_id 失败: key={key}, err={e}")

        return ""

    def _parse_group_nicks(self, raw_value: Any) -> List[str]:
        """Extract group nicknames from a list or a JSON-encoded list.

        Items may be dicts carrying ``group_nick_name`` or plain strings;
        anything else is skipped. Non-list JSON payloads yield ``[]``.
        """
        if not raw_value:
            return []
        if isinstance(raw_value, list):
            items = raw_value
        else:
            try:
                items = json.loads(raw_value)
            except Exception:
                return []
            # A JSON object/string/number must not be iterated as if it were
            # the nickname list (that would yield dict keys or characters).
            if not isinstance(items, list):
                return []
        names: List[str] = []
        for item in items:
            if isinstance(item, dict):
                value = str(item.get("group_nick_name", "")).strip()
                if value:
                    names.append(value)
            elif isinstance(item, str):
                value = item.strip()
                if value:
                    names.append(value)
        return names

    def _parse_memory_traits(self, raw_value: Any) -> List[str]:
        """Extract up to 10 trait strings from a list or a JSON-encoded list.

        For entries with at least three ``:``-separated parts, only the middle
        part(s) are kept; other entries are kept verbatim.
        """
        if not raw_value:
            return []
        try:
            values = json.loads(raw_value) if isinstance(raw_value, str) else raw_value
        except Exception:
            return []
        if not isinstance(values, list):
            return []
        traits: List[str] = []
        for item in values:
            text = str(item).strip()
            if not text:
                continue
            if ":" in text:
                parts = text.split(":")
                if len(parts) >= 3:
                    content = ":".join(parts[1:-1]).strip()
                    if content:
                        traits.append(content)
                        continue
            traits.append(text)
        return traits[:10]

    def get_person_aliases(self, person_id: str) -> Tuple[List[str], str, List[str]]:
        """Return ``(aliases, primary display name, memory traits)`` for a person.

        Lookup failures are logged and whatever was collected so far is
        returned.
        """
        aliases: List[str] = []
        primary_name = ""
        memory_traits: List[str] = []
        if not person_id:
            return aliases, primary_name, memory_traits
        try:
            record = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
            if not record:
                return aliases, primary_name, memory_traits
            person_name = str(getattr(record, "person_name", "") or "").strip()
            nickname = str(getattr(record, "nickname", "") or "").strip()
            group_nicks = self._parse_group_nicks(getattr(record, "group_nick_name", None))
            memory_traits = self._parse_memory_traits(getattr(record, "memory_points", None))

            primary_name = (
                person_name
                or nickname
                or str(getattr(record, "user_id", "") or "").strip()
                or person_id
            )

            # De-duplicate while keeping the order name, nickname, group nicknames.
            seen = set()
            for item in [person_name, nickname] + group_nicks:
                norm = str(item or "").strip()
                if not norm or norm in seen:
                    continue
                seen.add(norm)
                aliases.append(norm)
        except Exception as e:
            logger.warning(f"解析人物别名失败: person_id={person_id}, err={e}")
        return aliases, primary_name, memory_traits

    def _collect_relation_evidence(self, aliases: List[str], limit: int = 30) -> List[Dict[str, Any]]:
        """Collect relations where any alias is subject or object.

        Relations are de-duplicated by hash, ordered by confidence (highest
        first) and truncated to ``max(1, limit)`` entries. A missing or
        ``None`` confidence counts as 1.0; an explicit 0.0 stays 0.0.
        """
        relation_by_hash: Dict[str, Dict[str, Any]] = {}
        for alias in aliases:
            for rel in self.metadata_store.get_relations(subject=alias):
                h = str(rel.get("hash", ""))
                if h:
                    relation_by_hash[h] = rel
            for rel in self.metadata_store.get_relations(object=alias):
                h = str(rel.get("hash", ""))
                if h:
                    relation_by_hash[h] = rel

        def _confidence(rel: Dict[str, Any]) -> float:
            return self._coerce_confidence(rel.get("confidence"), 1.0)

        relations = sorted(relation_by_hash.values(), key=_confidence, reverse=True)
        relations = relations[: max(1, int(limit))]

        return [
            {
                "hash": str(rel.get("hash", "")),
                "subject": str(rel.get("subject", "")),
                "predicate": str(rel.get("predicate", "")),
                "object": str(rel.get("object", "")),
                "confidence": _confidence(rel),
            }
            for rel in relations
        ]

    async def _collect_vector_evidence(self, aliases: List[str], top_k: int = 12) -> List[Dict[str, Any]]:
        """Collect retrieved evidence for the aliases, de-duplicated by hash.

        Without a retriever, a plain content match on paragraphs is used and
        every hit gets score 0.0. With a retriever, each alias is queried
        separately and the merged hits are ordered by score.
        """
        alias_queries = [a for a in aliases if a]
        if not alias_queries:
            return []

        if self.retriever is None:
            # Fallback: simple content match when no retriever is available.
            fallback: List[Dict[str, Any]] = []
            seen_hash = set()
            for alias in alias_queries:
                for para in self.metadata_store.search_paragraphs_by_content(alias)[: max(2, top_k // 2)]:
                    h = str(para.get("hash", ""))
                    if not h or h in seen_hash:
                        continue
                    seen_hash.add(h)
                    fallback.append(
                        {
                            "hash": h,
                            "type": "paragraph",
                            "score": 0.0,
                            "content": str(para.get("content", ""))[:180],
                            "metadata": {},
                        }
                    )
            return fallback[:top_k]

        per_alias_top_k = max(2, int(top_k / max(1, len(alias_queries))))
        seen_hash = set()
        evidence: List[Dict[str, Any]] = []
        for alias in alias_queries:
            try:
                results = await self.retriever.retrieve(alias, top_k=per_alias_top_k)
            except Exception as e:
                logger.warning(f"向量证据召回失败: alias={alias}, err={e}")
                continue
            for item in results:
                h = str(getattr(item, "hash_value", "") or "")
                if not h or h in seen_hash:
                    continue
                seen_hash.add(h)
                evidence.append(
                    {
                        "hash": h,
                        "type": str(getattr(item, "result_type", "")),
                        "score": float(getattr(item, "score", 0.0) or 0.0),
                        "content": str(getattr(item, "content", "") or "")[:220],
                        "metadata": dict(getattr(item, "metadata", {}) or {}),
                    }
                )
        evidence.sort(key=lambda x: x.get("score", 0.0), reverse=True)
        return evidence[:top_k]

    def _build_profile_text(
        self,
        person_id: str,
        primary_name: str,
        aliases: List[str],
        relation_edges: List[Dict[str, Any]],
        vector_evidence: List[Dict[str, Any]],
        memory_traits: List[str],
    ) -> str:
        """Render the collected evidence as profile text for LLM context injection."""
        lines: List[str] = []
        lines.append(f"人物ID: {person_id}")
        if primary_name:
            lines.append(f"主称呼: {primary_name}")
        if aliases:
            lines.append(f"别名: {', '.join(aliases[:8])}")
        if memory_traits:
            lines.append(f"记忆特征: {'; '.join(memory_traits[:6])}")

        if relation_edges:
            lines.append("关系证据:")
            for rel in relation_edges[:6]:
                s = rel.get("subject", "")
                p = rel.get("predicate", "")
                o = rel.get("object", "")
                conf = self._coerce_confidence(rel.get("confidence"), 0.0)
                lines.append(f"- {s} {p} {o} (conf={conf:.2f})")

        if vector_evidence:
            lines.append("向量证据摘要:")
            for item in vector_evidence[:4]:
                content = str(item.get("content", "")).strip()
                if content:
                    lines.append(f"- {content}")

        # Only the id line (and maybe the name) is present: say so explicitly.
        if len(lines) <= 2:
            lines.append("暂无足够证据形成稳定画像。")

        return "\n".join(lines)

    @staticmethod
    def _is_snapshot_stale(snapshot: Optional[Dict[str, Any]], ttl_seconds: float) -> bool:
        """Tell whether a snapshot is missing, expired or older than ``ttl_seconds``.

        ``expires_at`` takes precedence when present; otherwise the age since
        ``updated_at`` is compared with the TTL. Unparseable timestamps count
        as stale.
        """
        if not snapshot:
            return True
        now = time.time()
        expires_at = snapshot.get("expires_at")
        if expires_at is not None:
            try:
                return now >= float(expires_at)
            except Exception:
                return True
        try:
            updated_at = float(snapshot.get("updated_at") or 0.0)
        except (TypeError, ValueError):
            return True
        return (now - updated_at) >= ttl_seconds

    def _apply_manual_override(self, person_id: str, profile_payload: Dict[str, Any]) -> Dict[str, Any]:
        """Merge a manual override into a profile payload.

        The returned copy always carries ``auto_profile_text`` and the
        override bookkeeping fields. When a non-empty override exists it
        replaces ``profile_text`` and ``profile_source`` becomes
        ``manual_override``.
        """
        payload = dict(profile_payload or {})
        auto_text = str(payload.get("profile_text", "") or "")
        payload["auto_profile_text"] = auto_text
        payload["has_manual_override"] = False
        payload["manual_override_text"] = ""
        payload["override_updated_at"] = None
        payload["override_updated_by"] = ""
        payload["profile_source"] = "auto_snapshot"

        if not person_id or self.metadata_store is None:
            return payload

        try:
            override = self.metadata_store.get_person_profile_override(person_id)
        except Exception as e:
            logger.warning(f"读取人物画像手工覆盖失败: person_id={person_id}, err={e}")
            return payload

        if not override:
            return payload

        manual_text = str(override.get("override_text", "") or "").strip()
        if not manual_text:
            return payload

        payload["has_manual_override"] = True
        payload["manual_override_text"] = manual_text
        payload["override_updated_at"] = override.get("updated_at")
        payload["override_updated_by"] = str(override.get("updated_by", "") or "")
        payload["profile_text"] = manual_text
        payload["profile_source"] = "manual_override"
        return payload

    async def query_person_profile(
        self,
        person_id: str = "",
        person_keyword: str = "",
        top_k: int = 12,
        ttl_seconds: float = 6 * 3600,
        force_refresh: bool = False,
        source_note: str = "",
    ) -> Dict[str, Any]:
        """Return a person's profile, served from the latest snapshot or rebuilt.

        Args:
            person_id: Person id; when empty, ``person_keyword`` is resolved.
            person_keyword: Name or alias used for resolution and as fallback alias.
            top_k: Evidence budget (relations use ``max(10, top_k * 2)``,
                retrieved evidence ``max(4, top_k)``).
            ttl_seconds: Snapshot lifetime; non-positive stores no expiry.
            force_refresh: Rebuild even if the latest snapshot is still fresh.
            source_note: Free-text note stored with a new snapshot.

        Returns:
            A payload dict with ``success``; on success it carries the snapshot
            fields plus manual-override information.
        """
        pid = str(person_id or "").strip()
        if not pid and person_keyword:
            pid = self.resolve_person_id(person_keyword)

        if not pid:
            return {
                "success": False,
                "error": "person_id 无效,且未能通过别名解析",
            }

        latest = self.metadata_store.get_latest_person_profile_snapshot(pid)
        if not force_refresh and not self._is_snapshot_stale(latest, ttl_seconds):
            aliases, primary_name, _ = self.get_person_aliases(pid)
            payload = {
                "success": True,
                "person_id": pid,
                "person_name": primary_name,
                "from_cache": True,
                **(latest or {}),
            }
            if aliases and not payload.get("aliases"):
                payload["aliases"] = aliases
            return self._apply_manual_override(pid, payload)

        aliases, primary_name, memory_traits = self.get_person_aliases(pid)
        if not aliases and person_keyword:
            aliases = [person_keyword.strip()]
            primary_name = person_keyword.strip()
        relation_edges = self._collect_relation_evidence(aliases, limit=max(10, top_k * 2))
        vector_evidence = await self._collect_vector_evidence(aliases, top_k=max(4, top_k))

        # Ordered, de-duplicated evidence hashes (relations first).
        dedup_ids: List[str] = []
        seen = set()
        for item in relation_edges + vector_evidence:
            evidence_id = str(item.get("hash", ""))
            if not evidence_id.strip() or evidence_id in seen:
                continue
            seen.add(evidence_id)
            dedup_ids.append(evidence_id)

        profile_text = self._build_profile_text(
            person_id=pid,
            primary_name=primary_name,
            aliases=aliases,
            relation_edges=relation_edges,
            vector_evidence=vector_evidence,
            memory_traits=memory_traits,
        )

        expires_at = time.time() + float(ttl_seconds) if ttl_seconds > 0 else None
        snapshot = self.metadata_store.upsert_person_profile_snapshot(
            person_id=pid,
            profile_text=profile_text,
            aliases=aliases,
            relation_edges=relation_edges,
            vector_evidence=vector_evidence,
            evidence_ids=dedup_ids,
            expires_at=expires_at,
            source_note=source_note,
        )
        payload = {
            "success": True,
            "person_id": pid,
            "person_name": primary_name,
            "from_cache": False,
            # Tolerate a store that returns nothing from the upsert.
            **(snapshot or {}),
        }
        return self._apply_manual_override(pid, payload)

    @staticmethod
    def format_persona_profile_block(profile: Dict[str, Any]) -> str:
        """Format a successful profile as the block injected for the replyer.

        Returns ``""`` for failed or empty profiles.
        """
        if not profile or not profile.get("success"):
            return ""
        text = str(profile.get("profile_text", "") or "").strip()
        if not text:
            return ""
        return (
            "【人物画像-内部参考】\n"
            f"{text}\n"
            "仅供内部推理,不要向用户逐字复述。"
        )
|
||||
27
plugins/A_memorix/core/utils/plugin_id_policy.py
Normal file
27
plugins/A_memorix/core/utils/plugin_id_policy.py
Normal file
@@ -0,0 +1,27 @@
|
||||
"""Plugin ID matching policy for A_Memorix."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
class PluginIdPolicy:
    """Single place that decides whether a plugin id refers to A_Memorix."""

    CANONICAL_ID = "a_memorix"

    @classmethod
    def normalize(cls, plugin_id: Any) -> str:
        """Return the id stripped and lower-cased; non-string input gives ``""``."""
        return plugin_id.strip().lower() if isinstance(plugin_id, str) else ""

    @classmethod
    def is_target_plugin_id(cls, plugin_id: Any) -> bool:
        """Tell whether ``plugin_id`` names this plugin.

        An id matches when, after normalization, it equals ``CANONICAL_ID`` or
        its last dot-separated segment does.
        """
        candidate = cls.normalize(plugin_id)
        if candidate == "":
            return False
        last_segment = candidate.rpartition(".")[2]
        return candidate == cls.CANONICAL_ID or last_segment == cls.CANONICAL_ID
|
||||
|
||||
344
plugins/A_memorix/core/utils/quantization.py
Normal file
344
plugins/A_memorix/core/utils/quantization.py
Normal file
@@ -0,0 +1,344 @@
|
||||
"""
|
||||
向量量化工具模块
|
||||
|
||||
提供向量量化与反量化功能,用于压缩存储空间。
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from enum import Enum
|
||||
from typing import Tuple, Union
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("A_Memorix.Quantization")
|
||||
|
||||
|
||||
class QuantizationType(Enum):
    """Supported vector quantization schemes."""
    FLOAT32 = "float32"  # No quantization
    INT8 = "int8"  # Scalar quantization (8-bit integers)
    PQ = "pq"  # Product quantization
|
||||
|
||||
|
||||
def quantize_vector(
    vector: np.ndarray,
    quant_type: QuantizationType = QuantizationType.INT8,
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """
    Quantize a single vector.

    Args:
        vector: Input vector (float32).
        quant_type: Quantization scheme to apply.

    Returns:
        The quantized representation:
        - FLOAT32: the vector cast to float32
        - INT8: an int8 vector
        - PQ: whatever ``_product_quantize`` returns (documented upstream as
          a ``(codes, centroids)`` tuple)

    Raises:
        ValueError: If ``quant_type`` is not one of the supported schemes.
    """
    if quant_type == QuantizationType.FLOAT32:
        return vector.astype(np.float32)

    if quant_type == QuantizationType.INT8:
        return _scalar_quantize_int8(vector)

    if quant_type == QuantizationType.PQ:
        return _product_quantize(vector)

    raise ValueError(f"不支持的量化类型: {quant_type}")
|
||||
|
||||
|
||||
def dequantize_vector(
    quantized_vector: Union[np.ndarray, Tuple[np.ndarray, np.ndarray]],
    quant_type: QuantizationType = QuantizationType.INT8,
    original_shape: Tuple[int, ...] = None,
) -> np.ndarray:
    """
    Dequantize a single vector.

    Args:
        quantized_vector: Quantized representation; for PQ a list or tuple
            whose first two items are ``(codes, centroids)``.
        quant_type: Scheme the vector was quantized with.
        original_shape: Currently unused; kept so existing callers that pass
            it keep working.

    Returns:
        The reconstructed vector (float32 for FLOAT32 and INT8).

    Raises:
        ValueError: For an unsupported ``quant_type`` or a PQ input that is
            not a list/tuple with at least two items.
    """
    if quant_type == QuantizationType.FLOAT32:
        return quantized_vector.astype(np.float32)

    if quant_type == QuantizationType.INT8:
        return _scalar_dequantize_int8(quantized_vector)

    if quant_type == QuantizationType.PQ:
        # The error message promises list/tuple, so accept both; also reject
        # inputs too short to hold (codes, centroids) with the same error
        # instead of an IndexError.
        if not isinstance(quantized_vector, (tuple, list)) or len(quantized_vector) < 2:
            raise ValueError("PQ反量化需要列表/元组格式: (codes, centroids)")
        return _product_dequantize(quantized_vector[0], quantized_vector[1])

    raise ValueError(f"不支持的量化类型: {quant_type}")
|
||||
|
||||
|
||||
def _scalar_quantize_int8(vector: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
标量量化:float32 -> int8
|
||||
|
||||
将向量归一化到 [0, 255] 范围,然后映射到 int8
|
||||
|
||||
Args:
|
||||
vector: 输入向量
|
||||
|
||||
Returns:
|
||||
量化后的 int8 向量
|
||||
"""
|
||||
# 计算最小最大值
|
||||
min_val = np.min(vector)
|
||||
max_val = np.max(vector)
|
||||
|
||||
# 避免除零
|
||||
if max_val == min_val:
|
||||
return np.zeros_like(vector, dtype=np.int8)
|
||||
|
||||
# 归一化到 [0, 255]
|
||||
normalized = (vector - min_val) / (max_val - min_val) * 255
|
||||
|
||||
# 映射到 [-128, 127] 并转换为 int8
|
||||
# np.round might return float, minus 128 then cast
|
||||
quantized = np.round(normalized - 128.0).astype(np.int8)
|
||||
|
||||
# 存储归一化参数(用于反量化)
|
||||
# 在实际存储中,这些参数需要单独保存
|
||||
# 这里为了简单,我们使用一个全局字典来模拟
|
||||
if not hasattr(_scalar_quantize_int8, "_params"):
|
||||
_scalar_quantize_int8._params = {}
|
||||
|
||||
vector_id = id(vector)
|
||||
_scalar_quantize_int8._params[vector_id] = (min_val, max_val)
|
||||
|
||||
return quantized
|
||||
|
||||
|
||||
def _scalar_dequantize_int8(quantized: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
标量反量化:int8 -> float32
|
||||
|
||||
Args:
|
||||
quantized: 量化后的 int8 向量
|
||||
|
||||
Returns:
|
||||
反量化后的 float32 向量
|
||||
"""
|
||||
# 计算归一化参数(如果提供了)
|
||||
# 在实际应用中,min_val 和 max_val 应该被保存
|
||||
if not hasattr(_scalar_dequantize_int8, "_params"):
|
||||
# 默认假设范围是 [-1, 1]
|
||||
return (quantized.astype(np.float32) + 128.0) / 255.0 * 2.0 - 1.0
|
||||
|
||||
# 尝试查找参数 (这里只是演示逻辑,实际应从存储中读取)
|
||||
# return (quantized.astype(np.float32) + 128.0) / 255.0 * (max - min) + min
|
||||
return (quantized.astype(np.float32) + 128.0) / 255.0
|
||||
|
||||
|
||||
def quantize_matrix(
    matrix: np.ndarray,
    quant_type: QuantizationType = QuantizationType.INT8,
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """
    Quantize a matrix (a batch of vectors).

    For INT8 the whole matrix is min-max normalized with one global range to
    [0, 255] and shifted by 128 into the int8 range [-128, 127], the same
    convention as ``_scalar_quantize_int8``. The range is not returned:
    callers that need exact reconstruction must keep ``matrix.min()`` and
    ``matrix.max()`` and pass them to ``dequantize_matrix``.

    Args:
        matrix: Input matrix (N x D, one vector per row).
        quant_type: Quantization scheme (FLOAT32 or INT8).

    Returns:
        The quantized matrix. An empty or constant matrix yields all zeros.

    Raises:
        ValueError: If ``quant_type`` is not FLOAT32 or INT8.
    """
    if quant_type == QuantizationType.FLOAT32:
        return matrix.astype(np.float32)

    if quant_type == QuantizationType.INT8:
        if matrix.size == 0:
            return np.zeros_like(matrix, dtype=np.int8)

        # One global range for the whole matrix.
        min_val = np.min(matrix)
        max_val = np.max(matrix)

        if max_val == min_val:
            return np.zeros_like(matrix, dtype=np.int8)

        normalized = (matrix - min_val) / (max_val - min_val) * 255
        # Values in [0, 255] do not fit int8; casting them directly overflows
        # for everything above 127. Shift into [-128, 127] first.
        return np.round(normalized - 128.0).astype(np.int8)

    raise ValueError(f"不支持的量化类型: {quant_type}")


def dequantize_matrix(
    quantized_matrix: np.ndarray,
    quant_type: QuantizationType = QuantizationType.INT8,
    min_val: float = None,
    max_val: float = None,
) -> np.ndarray:
    """
    Dequantize a matrix produced by ``quantize_matrix``.

    Args:
        quantized_matrix: Quantized matrix.
        quant_type: Scheme the matrix was quantized with.
        min_val: Minimum used for normalization (needed for exact INT8
            reconstruction).
        max_val: Maximum used for normalization (needed for exact INT8
            reconstruction).

    Returns:
        The reconstructed matrix. For INT8 without ``min_val``/``max_val`` the
        signed codes are only scaled by 1/127, giving values in roughly
        [-1, 1].

    Raises:
        ValueError: If ``quant_type`` is not FLOAT32 or INT8.
    """
    if quant_type == QuantizationType.FLOAT32:
        return quantized_matrix.astype(np.float32)

    if quant_type == QuantizationType.INT8:
        if min_val is None or max_val is None:
            # Range unknown: map the signed codes to approximately [-1, 1].
            return quantized_matrix.astype(np.float32) / 127.0

        # Undo the 128 shift applied by quantize_matrix, then restore the range.
        normalized = (quantized_matrix.astype(np.float32) + 128.0) / 255.0
        return normalized * (max_val - min_val) + min_val

    raise ValueError(f"不支持的量化类型: {quant_type}")
|
||||
|
||||
|
||||
def estimate_memory_reduction(
    num_vectors: int,
    dimension: int,
    from_type: QuantizationType,
    to_type: QuantizationType,
) -> Tuple[float, float, float]:
    """
    Estimate the memory saved by changing the quantization type.

    Args:
        num_vectors: Number of vectors.
        dimension: Vector dimension.
        from_type: Original quantization type.
        to_type: Target quantization type.

    Returns:
        ``(original size in MB, quantized size in MB, saved fraction)``.
        The saved fraction is 0.0 when the original size is zero.

    Raises:
        KeyError: If either type has no entry in the bytes-per-element table.
    """
    # Bytes occupied by a single vector element for each type.
    bytes_per_element = {
        QuantizationType.FLOAT32: 4,
        QuantizationType.INT8: 1,
        QuantizationType.PQ: 0.25,  # assumes compression to 1/4 byte per element
    }

    element_count = num_vectors * dimension
    original_bytes = element_count * bytes_per_element[from_type]
    quantized_bytes = element_count * bytes_per_element[to_type]

    original_mb = original_bytes / 1024 / 1024
    quantized_mb = quantized_bytes / 1024 / 1024

    # An empty store (no vectors or zero dimension) would otherwise divide
    # by zero; nothing can be saved in that case.
    if original_bytes:
        reduction_ratio = (original_bytes - quantized_bytes) / original_bytes
    else:
        reduction_ratio = 0.0

    return original_mb, quantized_mb, reduction_ratio
|
||||
|
||||
|
||||
def estimate_compression_stats(
    num_vectors: int,
    dimension: int,
    quant_type: QuantizationType,
) -> dict:
    """
    Summarize the estimated effect of quantizing float32 vectors.

    Args:
        num_vectors: Number of vectors.
        dimension: Vector dimension.
        quant_type: Target quantization type.

    Returns:
        A dict with the inputs, the original and quantized sizes in MB
        (rounded to two decimals), the saved MB, and ``compression_ratio``,
        which is the saved fraction expressed as a percentage.
    """
    # float32 is always the baseline the estimate is measured against.
    sizes = estimate_memory_reduction(
        num_vectors, dimension, QuantizationType.FLOAT32, quant_type
    )
    before_mb, after_mb, saved_fraction = sizes

    stats = {}
    stats["num_vectors"] = num_vectors
    stats["dimension"] = dimension
    stats["quantization_type"] = quant_type.value
    stats["original_size_mb"] = round(before_mb, 2)
    stats["quantized_size_mb"] = round(after_mb, 2)
    stats["saved_mb"] = round(before_mb - after_mb, 2)
    stats["compression_ratio"] = round(saved_fraction * 100, 2)
    return stats
|
||||
|
||||
|
||||
def _product_quantize(
|
||||
vector: np.ndarray, m: int = 8, k: int = 256
|
||||
) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
乘积量化 (PQ) 简化实现
|
||||
|
||||
Args:
|
||||
vector: 输入向量 (D,)
|
||||
m: 子空间数量
|
||||
k: 每个子空间的聚类中心数
|
||||
|
||||
Returns:
|
||||
(编码后的向量, 聚类中心)
|
||||
"""
|
||||
d = vector.shape[0]
|
||||
if d % m != 0:
|
||||
raise ValueError(f"维度 {d} 必须能被子空间数量 {m} 整除")
|
||||
|
||||
ds = d // m # 子空间维度
|
||||
codes = np.zeros(m, dtype=np.uint8)
|
||||
centroids = np.zeros((m, k, ds), dtype=np.float32)
|
||||
|
||||
# 这里采用一种简化的 PQ:不进行 K-means 训练,
|
||||
# 而是预定一些量化点或针对单向量的微型聚类(实际应用中应离线训练)
|
||||
# 为了演示,我们直接将子空间切分为 k 份进行量化
|
||||
for i in range(m):
|
||||
sub_vec = vector[i * ds : (i + 1) * ds]
|
||||
# 简化:假定每个子空间的取值范围并划分
|
||||
# 实际 PQ 应使用 k-means 产生的 centroids
|
||||
# 这里为演示创建一个随机 codebook 并找到最接近的核心
|
||||
sub_min, sub_max = np.min(sub_vec), np.max(sub_vec)
|
||||
if sub_max == sub_min:
|
||||
linspace = np.zeros(k)
|
||||
else:
|
||||
linspace = np.linspace(sub_min, sub_max, k)
|
||||
|
||||
for j in range(k):
|
||||
centroids[i, j, :] = linspace[j]
|
||||
|
||||
# 编码:这里简化为取子空间均值找最接近的 centroid
|
||||
sub_mean = np.mean(sub_vec)
|
||||
code = np.argmin(np.abs(linspace - sub_mean))
|
||||
codes[i] = code
|
||||
|
||||
return codes, centroids
|
||||
|
||||
|
||||
def _product_dequantize(codes: np.ndarray, centroids: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
PQ 反量化
|
||||
|
||||
Args:
|
||||
codes: 编码向量 (M,)
|
||||
centroids: 聚类中心 (M, K, DS)
|
||||
|
||||
Returns:
|
||||
恢复后的向量 (D,)
|
||||
"""
|
||||
m, k, ds = centroids.shape
|
||||
vector = np.zeros(m * ds, dtype=np.float32)
|
||||
|
||||
for i in range(m):
|
||||
code = codes[i]
|
||||
vector[i * ds : (i + 1) * ds] = centroids[i, code, :]
|
||||
|
||||
return vector
|
||||
121
plugins/A_memorix/core/utils/relation_query.py
Normal file
121
plugins/A_memorix/core/utils/relation_query.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""关系查询规格解析工具。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
|
||||
@dataclass
class RelationQuerySpec:
    """Parsed form of a relation query string."""

    # The stripped input text the spec was parsed from.
    raw: str
    # True when the text used a structured syntax (pipe, arrow or triple).
    is_structured: bool
    # Triple components; any of them may be None when not given.
    subject: Optional[str]
    predicate: Optional[str]
    object: Optional[str]
    # Short error code for rejected input, otherwise None.
    error: Optional[str] = None
|
||||
|
||||
|
||||
_NATURAL_LANGUAGE_PATTERN = re.compile(
|
||||
r"(^\s*(what|who|which|how|why|when|where)\b|"
|
||||
r"\?|?|"
|
||||
r"\b(relation|related|between)\b|"
|
||||
r"(什么关系|有哪些关系|之间|关联))",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
def _looks_like_natural_language(raw: str) -> bool:
|
||||
text = str(raw or "").strip()
|
||||
if not text:
|
||||
return False
|
||||
return _NATURAL_LANGUAGE_PATTERN.search(text) is not None
|
||||
|
||||
|
||||
def parse_relation_query_spec(relation_spec: str) -> RelationQuerySpec:
    """Parse a relation query string into a ``RelationQuerySpec``.

    Supported forms, checked in this order:

    1. Empty input -> unstructured, ``error="empty"``.
    2. Pipe syntax ``subject|predicate|object``; empty fields become None and
       fields after the third are ignored.  A pipe spec with no non-empty
       field is rejected with ``error="invalid_pipe_format"``.
    3. Arrow syntax ``a->b->c`` (triple) or ``a->b`` (subject and object);
       anything else is rejected with ``error="invalid_arrow_format"``.
    4. Text that looks like a natural-language question -> unstructured.
    5. Exactly three whitespace-separated tokens -> compact triple.
    6. Everything else -> unstructured.

    Args:
        relation_spec: Raw query text; falsy values are treated as empty.

    Returns:
        The parsed specification; ``raw`` holds the stripped input.
    """
    raw = str(relation_spec or "").strip()

    def _spec(is_structured, subject=None, predicate=None, obj=None, error=None):
        # Single construction point so every branch fills the same fields.
        return RelationQuerySpec(
            raw=raw,
            is_structured=is_structured,
            subject=subject,
            predicate=predicate,
            object=obj,
            error=error,
        )

    if not raw:
        return _spec(False, error="empty")

    if "|" in raw:
        # Splitting on a separator that is present always yields >= 2 parts,
        # so a length check can never fail; validate the content instead.
        parts = [p.strip() for p in raw.split("|")]
        subject = parts[0] or None
        predicate = parts[1] or None
        obj = parts[2] if len(parts) > 2 and parts[2] else None
        if subject is None and predicate is None and obj is None:
            return _spec(True, error="invalid_pipe_format")
        return _spec(True, subject, predicate, obj)

    if "->" in raw:
        parts = [p.strip() for p in raw.split("->") if p.strip()]
        if len(parts) >= 3:
            return _spec(True, parts[0], parts[1], parts[2])
        if len(parts) == 2:
            return _spec(True, parts[0], None, parts[1])
        return _spec(True, error="invalid_arrow_format")

    if _looks_like_natural_language(raw):
        return _spec(False)

    # Only the low-ambiguity compact triple is kept as compatibility syntax,
    # e.g. "Alice likes Apple".  Two-word input is too ambiguous and is no
    # longer treated as a structured relation query.
    parts = raw.split()
    if len(parts) == 3:
        return _spec(True, parts[0], parts[1], parts[2])

    return _spec(False)
|
||||
164
plugins/A_memorix/core/utils/relation_write_service.py
Normal file
164
plugins/A_memorix/core/utils/relation_write_service.py
Normal file
@@ -0,0 +1,164 @@
|
||||
"""
|
||||
统一关系写入与关系向量化服务。
|
||||
|
||||
规则:
|
||||
1. 元数据是主数据源,向量是从索引。
|
||||
2. 关系先写 metadata,再写向量。
|
||||
3. 向量失败不回滚 metadata,依赖状态机与回填任务修复。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
|
||||
logger = get_logger("A_Memorix.RelationWriteService")
|
||||
|
||||
|
||||
@dataclass
class RelationWriteResult:
    """Outcome of a relation write / relation vector write."""

    # Hash identifying the relation.
    hash_value: str
    # True when a new vector was added during this call.
    vector_written: bool
    # True when the vector was found to be present already.
    vector_already_exists: bool
    # Vector state recorded for the relation ("ready", "failed" or "none").
    vector_state: str
|
||||
|
||||
|
||||
class RelationWriteService:
    """Single entry point for writing relations and their vectors.

    Metadata is written first and is the primary record; the vector is a
    secondary index.  A failed vector write does not roll the metadata back:
    the failure is recorded through ``set_relation_vector_state``.
    """

    # Maximum number of characters of an error message that is persisted.
    ERROR_MAX_LEN = 500

    def __init__(
        self,
        metadata_store: Any,
        graph_store: Any,
        vector_store: Any,
        embedding_manager: Any,
    ):
        self.metadata_store = metadata_store
        self.graph_store = graph_store
        self.vector_store = vector_store
        self.embedding_manager = embedding_manager

    @staticmethod
    def build_relation_vector_text(subject: str, predicate: str, obj: str) -> str:
        """Build the text that is embedded for a relation.

        Two phrasings are joined by a newline: a plain "s p o" triple for
        keyword-style retrieval and a natural-language sentence for
        question-style queries.  Falsy parts become empty strings.
        """
        s = str(subject or "").strip()
        p = str(predicate or "").strip()
        o = str(obj or "").strip()
        return f"{s} {p} {o}\n{s}和{o}的关系是{p}"

    def _mark_vector_present(self, hash_value: str) -> "RelationWriteResult":
        """Record the vector as ready and report that it already existed."""
        self.metadata_store.set_relation_vector_state(hash_value, "ready")
        return RelationWriteResult(
            hash_value=hash_value,
            vector_written=False,
            vector_already_exists=True,
            vector_state="ready",
        )

    def _mark_vector_failed(
        self, hash_value: str, exc: Exception, max_error_len: int
    ) -> "RelationWriteResult":
        """Record a failed vector write (truncated error, retry bump) and log it."""
        err = str(exc)[:max_error_len]
        self.metadata_store.set_relation_vector_state(
            hash_value,
            "failed",
            error=err,
            bump_retry=True,
        )
        logger.warning(
            "metric.relation_vector_write_fail=1 metric.relation_vector_write_fail_count=1 hash=%s err=%s",
            hash_value[:16],
            err,
        )
        return RelationWriteResult(
            hash_value=hash_value,
            vector_written=False,
            vector_already_exists=False,
            vector_state="failed",
        )

    async def ensure_relation_vector(
        self,
        hash_value: str,
        subject: str,
        predicate: str,
        obj: str,
        *,
        max_error_len: int = ERROR_MAX_LEN,
    ) -> "RelationWriteResult":
        """Make sure an existing relation has a vector and update its state.

        Args:
            hash_value: Hash of the relation; also used as the vector id.
            subject: Relation subject.
            predicate: Relation predicate.
            obj: Relation object.
            max_error_len: Maximum length of the persisted error text.

        Returns:
            A result whose ``vector_state`` is "ready" when the vector exists
            after the call and "failed" otherwise.  Failures are recorded,
            not raised.
        """
        if hash_value in self.vector_store:
            return self._mark_vector_present(hash_value)

        self.metadata_store.set_relation_vector_state(hash_value, "pending")
        try:
            vector_text = self.build_relation_vector_text(subject, predicate, obj)
            embedding = await self.embedding_manager.encode(vector_text)
            self.vector_store.add(
                vectors=embedding.reshape(1, -1),
                ids=[hash_value],
            )
            self.metadata_store.set_relation_vector_state(hash_value, "ready")
            logger.info(
                "metric.relation_vector_write_success=1 metric.relation_vector_write_success_count=1 hash=%s",
                hash_value[:16],
            )
            return RelationWriteResult(
                hash_value=hash_value,
                vector_written=True,
                vector_already_exists=False,
                vector_state="ready",
            )
        except ValueError as e:
            # A ValueError is treated as an "already exists" conflict only if
            # the vector really is in the store.  Encoding or reshape errors
            # also raise ValueError and must not mark a vector-less relation
            # as ready.
            if hash_value in self.vector_store:
                return self._mark_vector_present(hash_value)
            return self._mark_vector_failed(hash_value, e, max_error_len)
        except Exception as e:
            return self._mark_vector_failed(hash_value, e, max_error_len)

    async def upsert_relation_with_vector(
        self,
        subject: str,
        predicate: str,
        obj: str,
        confidence: float = 1.0,
        source_paragraph: str = "",
        metadata: Optional[Dict[str, Any]] = None,
        *,
        write_vector: bool = True,
    ) -> "RelationWriteResult":
        """Write a relation through all stores.

        Steps:
            1) write the relation to the metadata store;
            2) add the graph edge tagged with the relation hash;
            3) optionally write the relation vector.

        Args:
            subject: Relation subject.
            predicate: Relation predicate.
            obj: Relation object.
            confidence: Confidence passed to the metadata store.
            source_paragraph: Source paragraph reference passed to the
                metadata store.
            metadata: Extra metadata; ``None`` is stored as an empty dict.
            write_vector: When False the vector step is skipped and the
                vector state is recorded as "none".

        Returns:
            The result of the vector step, or a "none"-state result when the
            vector step was skipped.
        """
        rel_hash = self.metadata_store.add_relation(
            subject=subject,
            predicate=predicate,
            obj=obj,
            confidence=confidence,
            source_paragraph=source_paragraph,
            metadata=metadata or {},
        )
        self.graph_store.add_edges([(subject, obj)], relation_hashes=[rel_hash])

        if not write_vector:
            self.metadata_store.set_relation_vector_state(rel_hash, "none")
            return RelationWriteResult(
                hash_value=rel_hash,
                vector_written=False,
                vector_already_exists=False,
                vector_state="none",
            )

        return await self.ensure_relation_vector(
            hash_value=rel_hash,
            subject=subject,
            predicate=predicate,
            obj=obj,
        )
|
||||
197
plugins/A_memorix/core/utils/runtime_self_check.py
Normal file
197
plugins/A_memorix/core/utils/runtime_self_check.py
Normal file
@@ -0,0 +1,197 @@
|
||||
"""Runtime self-check helpers for A_Memorix."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("A_Memorix.RuntimeSelfCheck")
|
||||
|
||||
_DEFAULT_SAMPLE_TEXT = "A_Memorix runtime self check"
|
||||
|
||||
|
||||
def _safe_int(value: Any, default: int = 0) -> int:
|
||||
try:
|
||||
return int(value)
|
||||
except Exception:
|
||||
return int(default)
|
||||
|
||||
|
||||
def _get_config_value(config: Any, key: str, default: Any = None) -> Any:
|
||||
getter = getattr(config, "get_config", None)
|
||||
if callable(getter):
|
||||
return getter(key, default)
|
||||
|
||||
current: Any = config
|
||||
for part in key.split("."):
|
||||
if isinstance(current, dict) and part in current:
|
||||
current = current[part]
|
||||
else:
|
||||
return default
|
||||
return current
|
||||
|
||||
|
||||
def _build_report(
|
||||
*,
|
||||
ok: bool,
|
||||
code: str,
|
||||
message: str,
|
||||
configured_dimension: int,
|
||||
vector_store_dimension: int,
|
||||
detected_dimension: int,
|
||||
encoded_dimension: int,
|
||||
elapsed_ms: float,
|
||||
sample_text: str,
|
||||
) -> Dict[str, Any]:
|
||||
return {
|
||||
"ok": bool(ok),
|
||||
"code": str(code or "").strip(),
|
||||
"message": str(message or "").strip(),
|
||||
"configured_dimension": int(configured_dimension),
|
||||
"vector_store_dimension": int(vector_store_dimension),
|
||||
"detected_dimension": int(detected_dimension),
|
||||
"encoded_dimension": int(encoded_dimension),
|
||||
"elapsed_ms": float(elapsed_ms),
|
||||
"sample_text": str(sample_text or ""),
|
||||
"checked_at": time.time(),
|
||||
}
|
||||
|
||||
|
||||
async def run_embedding_runtime_self_check(
    *,
    config: Any,
    vector_store: Optional[Any],
    embedding_manager: Optional[Any],
    sample_text: str = _DEFAULT_SAMPLE_TEXT,
) -> Dict[str, Any]:
    """Probe the real embedding path and compare dimensions with runtime storage.

    The expected dimension is the first non-zero value among the vector
    store's ``dimension``, the configured ``embedding.dimension`` and the
    dimension detected by the embedding manager.  The check passes when the
    dimension of an actually encoded sample equals it.

    Returns:
        A report dict from ``_build_report``.  ``code`` is one of
        ``runtime_components_missing``, ``embedding_probe_failed``,
        ``invalid_expected_dimension``, ``embedding_dimension_mismatch``
        or ``ok``.  Probe errors are reported, not raised.
    """
    configured = _safe_int(_get_config_value(config, "embedding.dimension", 0), 0)
    store_dim = _safe_int(getattr(vector_store, "dimension", 0), 0)

    def _report(ok, code, message, *, detected=0, encoded=0, elapsed=0.0):
        # Every exit shares the configured/store dimensions and sample text.
        return _build_report(
            ok=ok,
            code=code,
            message=message,
            configured_dimension=configured,
            vector_store_dimension=store_dim,
            detected_dimension=detected,
            encoded_dimension=encoded,
            elapsed_ms=elapsed,
            sample_text=sample_text,
        )

    if vector_store is None or embedding_manager is None:
        return _report(
            False,
            "runtime_components_missing",
            "vector_store 或 embedding_manager 未初始化",
        )

    started = time.perf_counter()
    detected_dim = 0
    encoded_dim = 0
    try:
        detected_dim = _safe_int(await embedding_manager._detect_dimension(), 0)
        sample_vec = await embedding_manager.encode(sample_text)
        if isinstance(sample_vec, np.ndarray):
            # The last axis is the embedding dimension for 1-D and batched output.
            encoded_dim = int(sample_vec.shape[-1])
        elif sample_vec is not None:
            encoded_dim = len(sample_vec)
        else:
            encoded_dim = 0
    except Exception as exc:
        took_ms = (time.perf_counter() - started) * 1000.0
        logger.warning("embedding runtime self-check failed: %s", exc)
        return _report(
            False,
            "embedding_probe_failed",
            f"embedding probe failed: {exc}",
            detected=detected_dim,
            encoded=encoded_dim,
            elapsed=took_ms,
        )

    took_ms = (time.perf_counter() - started) * 1000.0
    expected_dim = store_dim or configured or detected_dim

    if expected_dim <= 0:
        return _report(
            False,
            "invalid_expected_dimension",
            "无法确定期望 embedding 维度",
            detected=detected_dim,
            encoded=encoded_dim,
            elapsed=took_ms,
        )

    if encoded_dim != expected_dim:
        mismatch = (
            "embedding 真实输出维度与当前向量存储不一致: "
            f"expected={expected_dim}, encoded={encoded_dim}"
        )
        logger.error(mismatch)
        return _report(
            False,
            "embedding_dimension_mismatch",
            mismatch,
            detected=detected_dim,
            encoded=encoded_dim,
            elapsed=took_ms,
        )

    return _report(
        True,
        "ok",
        "embedding runtime self-check passed",
        detected=detected_dim,
        encoded=encoded_dim,
        elapsed=took_ms,
    )
|
||||
|
||||
|
||||
async def ensure_runtime_self_check(
    plugin_or_config: Any,
    *,
    force: bool = False,
    sample_text: str = _DEFAULT_SAMPLE_TEXT,
) -> Dict[str, Any]:
    """Run the runtime self-check, or reuse a previously cached report.

    Args:
        plugin_or_config: Either an object exposing ``config``,
            ``vector_store`` and ``embedding_manager`` attributes, or a dict
            holding ``vector_store`` / ``embedding_manager`` keys (the dict
            itself is then used as the config).
        force: When True, ignore any cached report and probe again.
        sample_text: Text encoded by the probe.

    Returns:
        The report dict.  The report is cached on ``plugin_or_config`` as
        ``_runtime_self_check_report`` when the attribute can be set.
    """
    if plugin_or_config is None:
        return _build_report(
            ok=False,
            code="missing_plugin_or_config",
            message="plugin/config unavailable",
            configured_dimension=0,
            vector_store_dimension=0,
            detected_dimension=0,
            encoded_dimension=0,
            elapsed_ms=0.0,
            sample_text=sample_text,
        )

    cached = getattr(plugin_or_config, "_runtime_self_check_report", None)
    if not force and isinstance(cached, dict) and cached:
        return cached

    config = getattr(plugin_or_config, "config", plugin_or_config)
    if isinstance(plugin_or_config, dict):
        vector_store = plugin_or_config.get("vector_store")
        embedding_manager = plugin_or_config.get("embedding_manager")
    else:
        vector_store = getattr(plugin_or_config, "vector_store", None)
        embedding_manager = getattr(plugin_or_config, "embedding_manager", None)

    report = await run_embedding_runtime_self_check(
        config=config,
        vector_store=vector_store,
        embedding_manager=embedding_manager,
        sample_text=sample_text,
    )

    # Caching is best effort: objects that reject new attributes (a plain
    # dict, for instance) simply go without a cached report.
    try:
        setattr(plugin_or_config, "_runtime_self_check_report", report)
    except Exception:
        pass
    return report
|
||||
90
plugins/A_memorix/core/utils/search_postprocess.py
Normal file
90
plugins/A_memorix/core/utils/search_postprocess.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""Post-processing helpers for unified search execution."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, List, Tuple
|
||||
|
||||
from .path_fallback_service import find_paths_from_query, to_retrieval_results
|
||||
|
||||
|
||||
def apply_safe_content_dedup(results: List[Any]) -> Tuple[List[Any], int]:
    """Deduplicate results by hash/content while preserving at least one entry.

    An item is dropped when its stripped ``content`` is empty, when its key
    (``hash_value``, or a hash of the content if that is empty) was already
    seen, or when its content contains, or is contained in, the content of
    an earlier kept item.  If nothing survives, the first input item is
    returned so a non-empty input never yields an empty output.

    Returns:
        ``(kept items in original order, number of removed items)``.
    """
    if not results:
        return [], 0

    kept: List[Any] = []
    known_keys = set()
    known_texts = set()

    for candidate in results:
        text = str(getattr(candidate, "content", "") or "").strip()
        if not text:
            continue

        key = str(getattr(candidate, "hash_value", "") or "").strip() or str(hash(text))
        if key in known_keys:
            continue

        # Substring containment in either direction counts as a duplicate.
        if any(text in prior or prior in text for prior in known_texts):
            continue

        known_keys.add(key)
        known_texts.add(text)
        kept.append(candidate)

    if not kept:
        kept.append(results[0])

    return kept, max(0, len(results) - len(kept))
|
||||
|
||||
|
||||
def maybe_apply_smart_path_fallback(
    *,
    query: str,
    results: List[Any],
    graph_store: Any,
    metadata_store: Any,
    enabled: bool,
    threshold: float,
    max_depth: int = 3,
    max_paths: int = 5,
) -> Tuple[List[Any], bool, int]:
    """Add indirect relation paths when semantic results are weak.

    The fallback runs only when it is enabled, the query is non-blank, both
    stores are available and the score of the first result (0.0 when there
    are no results or the score is unreadable) is below ``threshold``.

    Returns:
        ``(results, applied, added)``.  When the fallback is applied the path
        results are placed in front of the original results and ``added`` is
        their count; otherwise the input list is returned unchanged with
        ``(False, 0)``.
    """
    untouched = (results, False, 0)

    query_is_blank = not str(query or "").strip()
    stores_missing = graph_store is None or metadata_store is None
    if not enabled or query_is_blank or stores_missing:
        return untouched

    # Only the first result is inspected for the score.
    top_score = 0.0
    if results:
        try:
            top_score = float(getattr(results[0], "score", 0.0) or 0.0)
        except Exception:
            top_score = 0.0

    if top_score >= float(threshold):
        return untouched

    found_paths = find_paths_from_query(
        query=query,
        graph_store=graph_store,
        metadata_store=metadata_store,
        max_depth=max_depth,
        max_paths=max_paths,
    )
    if not found_paths:
        return untouched

    extra = to_retrieval_results(found_paths)
    if not extra:
        return untouched

    combined = [*extra, *results]
    return combined, True, len(extra)
|
||||
|
||||
170
plugins/A_memorix/core/utils/time_parser.py
Normal file
170
plugins/A_memorix/core/utils/time_parser.py
Normal file
@@ -0,0 +1,170 @@
|
||||
"""
|
||||
时间解析工具。
|
||||
|
||||
约束:
|
||||
1. 查询参数(Action/Command/Tool)仅接受结构化绝对时间:
|
||||
- YYYY/MM/DD
|
||||
- YYYY/MM/DD HH:mm
|
||||
2. 入库时允许更宽松格式(含时间戳、YYYY-MM-DD 等)。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
|
||||
_QUERY_DATE_RE = re.compile(r"^\d{4}/\d{2}/\d{2}$")
|
||||
_QUERY_MINUTE_RE = re.compile(r"^\d{4}/\d{2}/\d{2} \d{2}:\d{2}$")
|
||||
_NUMERIC_RE = re.compile(r"^-?\d+(?:\.\d+)?$")
|
||||
|
||||
_INGEST_FORMATS = [
|
||||
"%Y/%m/%d %H:%M:%S",
|
||||
"%Y/%m/%d %H:%M",
|
||||
"%Y-%m-%d %H:%M:%S",
|
||||
"%Y-%m-%d %H:%M",
|
||||
"%Y-%m-%dT%H:%M:%S",
|
||||
"%Y-%m-%dT%H:%M",
|
||||
"%Y/%m/%d",
|
||||
"%Y-%m-%d",
|
||||
]
|
||||
|
||||
_INGEST_DATE_FORMATS = {"%Y/%m/%d", "%Y-%m-%d"}
|
||||
|
||||
|
||||
def parse_query_datetime_to_timestamp(value: str, is_end: bool = False) -> float:
|
||||
"""解析查询时间,仅支持 YYYY/MM/DD 或 YYYY/MM/DD HH:mm。"""
|
||||
text = str(value).strip()
|
||||
if not text:
|
||||
raise ValueError("时间不能为空")
|
||||
|
||||
if _QUERY_DATE_RE.fullmatch(text):
|
||||
dt = datetime.strptime(text, "%Y/%m/%d")
|
||||
if is_end:
|
||||
dt = dt.replace(hour=23, minute=59, second=0, microsecond=0)
|
||||
return dt.timestamp()
|
||||
|
||||
if _QUERY_MINUTE_RE.fullmatch(text):
|
||||
dt = datetime.strptime(text, "%Y/%m/%d %H:%M")
|
||||
return dt.timestamp()
|
||||
|
||||
raise ValueError(
|
||||
f"时间格式错误: {text}。仅支持 YYYY/MM/DD 或 YYYY/MM/DD HH:mm"
|
||||
)
|
||||
|
||||
|
||||
def parse_query_time_range(
    time_from: Optional[str],
    time_to: Optional[str],
) -> Tuple[Optional[float], Optional[float]]:
    """Parse a query window and validate that it is not inverted.

    Falsy bounds are left open (``None``).  The upper bound is parsed with
    end-of-range semantics.

    Raises:
        ValueError: If a bound cannot be parsed, or the lower bound is later
            than the upper bound.
    """
    ts_from = None
    if time_from:
        ts_from = parse_query_datetime_to_timestamp(time_from, is_end=False)

    ts_to = None
    if time_to:
        ts_to = parse_query_datetime_to_timestamp(time_to, is_end=True)

    both_given = ts_from is not None and ts_to is not None
    if both_given and ts_from > ts_to:
        raise ValueError("time_from 不能晚于 time_to")

    return ts_from, ts_to
|
||||
|
||||
|
||||
def parse_ingest_datetime_to_timestamp(
    value: Any,
    is_end: bool = False,
) -> Optional[float]:
    """Parse an ingest-side time: a timestamp or a common string format.

    Args:
        value: ``None``, a number, numeric text, or text in one of the
            ingest formats.
        is_end: For date-only text, move the time to 23:59 of that day.

    Returns:
        The timestamp as float, or ``None`` for ``None`` / blank input.

    Raises:
        ValueError: If non-blank text matches no supported format.
    """
    if value is None:
        return None
    if isinstance(value, (int, float)):
        return float(value)

    text = str(value).strip()
    if not text:
        return None
    if _NUMERIC_RE.fullmatch(text):
        return float(text)

    for candidate_format in _INGEST_FORMATS:
        try:
            parsed = datetime.strptime(text, candidate_format)
            if is_end and candidate_format in _INGEST_DATE_FORMATS:
                parsed = parsed.replace(hour=23, minute=59, second=0, microsecond=0)
            return parsed.timestamp()
        except ValueError:
            # Not this format; try the next one.
            continue

    raise ValueError(f"无法解析时间: {text}")
|
||||
|
||||
|
||||
def normalize_time_meta(time_meta: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Normalize ``time_meta`` into the fields used by the storage layer.

    Recognized input keys: ``event_time``, ``event_time_start``,
    ``event_time_end``, ``time_range`` (a 2-item list/tuple used as a
    fallback for a missing start/end), ``time_granularity`` and
    ``time_confidence``.

    Args:
        time_meta: Raw time metadata; falsy input yields an empty dict.

    Returns:
        A dict with the parsed timestamps that were present,
        ``time_granularity`` (given, or inferred as "minute"/"day") and
        ``time_confidence`` as float when provided.

    Raises:
        ValueError: If a time value cannot be parsed or the start is later
            than the end.
    """
    if not time_meta:
        return {}

    normalized: Dict[str, Any] = {}

    event_time = parse_ingest_datetime_to_timestamp(time_meta.get("event_time"))
    event_start = parse_ingest_datetime_to_timestamp(
        time_meta.get("event_time_start"),
        is_end=False,
    )
    event_end = parse_ingest_datetime_to_timestamp(
        time_meta.get("event_time_end"),
        is_end=True,
    )

    # Raw inputs that contribute to the result; used for granularity inference.
    raw_time_values = [
        time_meta.get("event_time"),
        time_meta.get("event_time_start"),
        time_meta.get("event_time_end"),
    ]

    time_range = time_meta.get("time_range")
    if isinstance(time_range, (list, tuple)) and len(time_range) == 2:
        if event_start is None:
            event_start = parse_ingest_datetime_to_timestamp(time_range[0], is_end=False)
            # A range bound that is actually used must also count towards
            # the inferred granularity.
            raw_time_values.append(time_range[0])
        if event_end is None:
            event_end = parse_ingest_datetime_to_timestamp(time_range[1], is_end=True)
            raw_time_values.append(time_range[1])

    if event_start is not None and event_end is not None and event_start > event_end:
        raise ValueError("event_time_start 不能晚于 event_time_end")

    if event_time is not None:
        normalized["event_time"] = event_time
    if event_start is not None:
        normalized["event_time_start"] = event_start
    if event_end is not None:
        normalized["event_time_end"] = event_end

    granularity = time_meta.get("time_granularity")
    if granularity:
        normalized["time_granularity"] = str(granularity)
    else:
        # A ":" in any textual input means a clock time was supplied.
        has_minute = any(isinstance(v, str) and ":" in v for v in raw_time_values)
        normalized["time_granularity"] = "minute" if has_minute else "day"

    confidence = time_meta.get("time_confidence")
    if confidence is not None:
        normalized["time_confidence"] = float(confidence)

    return normalized
|
||||
|
||||
|
||||
def format_timestamp(ts: Optional[float]) -> Optional[str]:
    """Format a timestamp as ``YYYY/MM/DD HH:mm``; ``None`` stays ``None``."""
    if ts is not None:
        return datetime.fromtimestamp(ts).strftime("%Y/%m/%d %H:%M")
    return None
|
||||
|
||||
Reference in New Issue
Block a user