添加 A_Memorix 插件 v2.0.0(包含运行时与文档)
引入 A_Memorix 插件 v2.0.0:新增大量运行时组件、存储/模式更新、检索能力提升、管理工具、导入/调优工作流以及相关文档。关键新增内容包括:lifecycle_orchestrator、SDKMemoryKernel/运行时初始化器、新的存储层与 metadata_store 变更(SCHEMA_VERSION v8)、检索增强(双路径检索、图关系召回、稀疏 BM25),以及多种工具服务(episode/person_profile/relation/segmentation/tuning/search execution)。同时新增 Web 导入/摘要导入器及大量维护脚本。还更新了插件清单、embedding API 适配器、plugin.py、requirements/pyproject,以及主入口文件,使新插件接入项目。该变更为 2.0.0 版本发布做好准备,实现统一的 SDK Tool 接口并扩展整体运行能力。
This commit is contained in:
@@ -302,7 +302,7 @@ class AggregateQueryService:
|
||||
)
|
||||
for (branch_name, _), payload in zip(scheduled, done):
|
||||
if isinstance(payload, Exception):
|
||||
logger.error("aggregate branch failed: branch=%s error=%s", branch_name, payload)
|
||||
logger.error(f"aggregate branch failed: branch={branch_name} error={payload}")
|
||||
normalized = self._normalize_branch_payload(
|
||||
branch_name,
|
||||
{
|
||||
|
||||
@@ -70,7 +70,7 @@ class EpisodeRetrievalService:
|
||||
temporal=temporal,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("episode evidence retrieval failed, fallback to lexical only: %s", exc)
|
||||
logger.warning(f"episode evidence retrieval failed, fallback to lexical only: {exc}")
|
||||
else:
|
||||
paragraph_rank_map: Dict[str, int] = {}
|
||||
relation_rank_map: Dict[str, int] = {}
|
||||
|
||||
304
plugins/A_memorix/core/utils/episode_segmentation_service.py
Normal file
304
plugins/A_memorix/core/utils/episode_segmentation_service.py
Normal file
@@ -0,0 +1,304 @@
|
||||
"""
|
||||
Episode 语义切分服务(LLM 主路径)。
|
||||
|
||||
职责:
|
||||
1. 组装语义切分提示词
|
||||
2. 调用 LLM 生成结构化 episode JSON
|
||||
3. 严格校验输出结构,返回标准化结果
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.model_configs import TaskConfig
|
||||
from src.config.config import model_config as host_model_config
|
||||
from src.services import llm_service as llm_api
|
||||
|
||||
logger = get_logger("A_Memorix.EpisodeSegmentationService")
|
||||
|
||||
|
||||
class EpisodeSegmentationService:
    """Segment paragraphs into episodes by asking an LLM for structured JSON.

    Responsibilities:
      1. Assemble the segmentation prompt.
      2. Call the host LLM API to generate episode JSON.
      3. Strictly validate the reply and return a normalized result.
    """

    SEGMENTATION_VERSION = "episode_mvp_v1"

    # Maximum number of characters of each paragraph copied into the prompt.
    MAX_PARAGRAPH_CHARS = 800
    # Caps applied to the normalized participant / keyword lists.
    MAX_PARTICIPANTS = 16
    MAX_KEYWORDS = 20

    def __init__(self, plugin_config: Optional[dict] = None):
        """Store the plugin configuration (a nested dict; ``None`` means empty)."""
        self.plugin_config = plugin_config or {}

    def _cfg(self, key: str, default: Any = None) -> Any:
        """Look up a dotted ``key`` in the nested config dict.

        Returns ``default`` as soon as any path component is missing or the
        current node is not a dict.
        """
        current: Any = self.plugin_config
        for part in key.split("."):
            if isinstance(current, dict) and part in current:
                current = current[part]
            else:
                return default
        return current

    @staticmethod
    def _is_task_config(obj: Any) -> bool:
        """Return True when ``obj`` exposes a non-empty ``model_list`` attribute."""
        return hasattr(obj, "model_list") and bool(getattr(obj, "model_list", []))

    def _build_single_model_task(self, model_name: str, template: TaskConfig) -> TaskConfig:
        """Clone ``template`` into a task config that targets only ``model_name``."""
        return TaskConfig(
            model_list=[model_name],
            max_tokens=template.max_tokens,
            temperature=template.temperature,
            slow_threshold=template.slow_threshold,
            selection_strategy=template.selection_strategy,
        )

    def _pick_template_task(self, available_tasks: Dict[str, Any]) -> Optional[TaskConfig]:
        """Pick a usable task config to serve as a template.

        Preference order: the well-known task names, then any usable task
        other than ``embedding``, then any usable task at all.
        """
        preferred = ("utils", "replyer", "planner", "tool_use")
        for task_name in preferred:
            cfg = available_tasks.get(task_name)
            if self._is_task_config(cfg):
                return cfg
        for task_name, cfg in available_tasks.items():
            if task_name != "embedding" and self._is_task_config(cfg):
                return cfg
        for cfg in available_tasks.values():
            if self._is_task_config(cfg):
                return cfg
        return None

    def _resolve_model_config(self) -> Tuple[Optional[Any], str]:
        """Resolve which model/task configuration to use for segmentation.

        Returns:
            ``(task_config, label)``; ``(None, "unavailable")`` when the host
            reports no usable task.
        """
        available_tasks = llm_api.get_available_models() or {}
        if not available_tasks:
            return None, "unavailable"

        selector = str(self._cfg("episode.segmentation_model", "auto") or "auto").strip()
        model_dict = getattr(host_model_config, "models_dict", {}) or {}

        if selector and selector.lower() != "auto":
            # The selector may name a task directly ...
            direct_task = available_tasks.get(selector)
            if self._is_task_config(direct_task):
                return direct_task, selector

            # ... or a single model, which is wrapped using a template task.
            if selector in model_dict:
                template = self._pick_template_task(available_tasks)
                if template is not None:
                    return self._build_single_model_task(selector, template), selector

            logger.warning(f"episode.segmentation_model='{selector}' 不可用,回退 auto")

        for task_name in ("utils", "replyer", "planner", "tool_use"):
            cfg = available_tasks.get(task_name)
            if self._is_task_config(cfg):
                return cfg, task_name

        fallback = self._pick_template_task(available_tasks)
        if fallback is not None:
            return fallback, "auto"
        return None, "unavailable"

    @staticmethod
    def _clamp_score(value: Any, default: float = 0.0) -> float:
        """Convert ``value`` to a float clamped into ``[0.0, 1.0]``.

        Values that cannot be converted, and NaN, are replaced by ``default``
        (itself clamped). ``json.loads`` accepts a bare ``NaN`` token, so NaN
        can really arrive here from an LLM reply.
        """
        import math  # local import: keeps this block self-contained

        try:
            num = float(value)
        except (TypeError, ValueError, OverflowError):
            num = default
        if math.isnan(num):
            num = default
            if math.isnan(num):
                return 0.0
        if num < 0.0:
            return 0.0
        if num > 1.0:
            return 1.0
        return num

    @staticmethod
    def _safe_json_loads(text: str) -> Dict[str, Any]:
        """Parse an LLM reply into a JSON object, tolerating markdown fences.

        Raises:
            ValueError: ``"empty_response"`` for blank input, otherwise
                ``"invalid_json_response"`` when no JSON object can be decoded.
        """
        raw = str(text or "").strip()
        if not raw:
            raise ValueError("empty_response")

        if "```" in raw:
            # Strip the language tag, then take the first fenced chunk that
            # looks like a JSON object.
            raw = raw.replace("```json", "```").replace("```JSON", "```")
            for part in raw.split("```"):
                part = part.strip()
                if part.startswith("{") and part.endswith("}"):
                    raw = part
                    break

        try:
            data = json.loads(raw)
        except (ValueError, RecursionError):
            data = None
        if isinstance(data, dict):
            return data

        # Last resort: decode the span from the first "{" to the last "}".
        start = raw.find("{")
        end = raw.rfind("}")
        if start >= 0 and end > start:
            try:
                data = json.loads(raw[start : end + 1])
            except (ValueError, RecursionError) as exc:
                # Keep the error contract uniform instead of leaking the
                # decoder's own message.
                raise ValueError("invalid_json_response") from exc
            if isinstance(data, dict):
                return data

        raise ValueError("invalid_json_response")

    def _build_prompt(
        self,
        *,
        source: str,
        window_start: Optional[float],
        window_end: Optional[float],
        paragraphs: List[Dict[str, Any]],
    ) -> str:
        """Render the segmentation prompt for one batch of paragraphs.

        Each paragraph contributes its hash, its event-time fields and at most
        ``MAX_PARAGRAPH_CHARS`` characters of content.
        """
        rows: List[str] = []
        for idx, item in enumerate(paragraphs, 1):
            p_hash = str(item.get("hash", "") or "").strip()
            content = str(item.get("content", "") or "").strip().replace("\r\n", "\n")
            content = content[: self.MAX_PARAGRAPH_CHARS]
            event_start = item.get("event_time_start")
            event_end = item.get("event_time_end")
            event_time = item.get("event_time")
            rows.append(
                (
                    f"[{idx}] hash={p_hash}\n"
                    f"event_time={event_time}\n"
                    f"event_time_start={event_start}\n"
                    f"event_time_end={event_end}\n"
                    f"content={content}"
                )
            )

        source_text = str(source or "").strip() or "unknown"
        return (
            "You are an episode segmentation engine.\n"
            "Group the given paragraphs into one or more coherent episodes.\n"
            "Return JSON ONLY. No markdown, no explanation.\n"
            "\n"
            "Hard JSON schema:\n"
            "{\n"
            '  "episodes": [\n'
            "    {\n"
            '      "title": "string",\n'
            '      "summary": "string",\n'
            '      "paragraph_hashes": ["hash1", "hash2"],\n'
            '      "participants": ["person1", "person2"],\n'
            '      "keywords": ["kw1", "kw2"],\n'
            '      "time_confidence": 0.0,\n'
            '      "llm_confidence": 0.0\n'
            "    }\n"
            "  ]\n"
            "}\n"
            "\n"
            "Rules:\n"
            "1) paragraph_hashes must come from input only.\n"
            "2) title and summary must be non-empty.\n"
            "3) keep participants/keywords concise and deduplicated.\n"
            "4) if uncertain, still provide best effort confidence values.\n"
            "\n"
            f"source={source_text}\n"
            f"window_start={window_start}\n"
            f"window_end={window_end}\n"
            "paragraphs:\n"
            + "\n\n".join(rows)
        )

    @staticmethod
    def _clean_str_list(raw: Any, limit: int) -> List[str]:
        """Return up to ``limit`` stripped, non-empty, de-duplicated strings.

        A bare string is treated as a single item (instead of being iterated
        character by character); any other non-list value yields ``[]``.
        Duplicates are detected case-insensitively; the first spelling wins.
        """
        if isinstance(raw, str):
            raw = [raw]
        elif not isinstance(raw, (list, tuple)):
            return []

        cleaned: List[str] = []
        seen = set()
        for entry in raw:
            token = str(entry or "").strip()
            key = token.lower()
            if not token or key in seen:
                continue
            seen.add(key)
            cleaned.append(token)
        return cleaned[:limit]

    def _normalize_episodes(
        self,
        *,
        payload: Dict[str, Any],
        input_hashes: List[str],
    ) -> List[Dict[str, Any]]:
        """Validate the decoded LLM payload and return normalized episodes.

        Episodes are dropped when they are not dicts, lack a title or summary,
        or reference no hash from ``input_hashes``.

        Raises:
            ValueError: ``"episodes_missing_or_not_list"`` when the payload has
                no ``episodes`` list, ``"episodes_all_invalid"`` when every
                episode was dropped.
        """
        raw_episodes = payload.get("episodes")
        if not isinstance(raw_episodes, list):
            raise ValueError("episodes_missing_or_not_list")

        valid_hashes = set(input_hashes)
        normalized: List[Dict[str, Any]] = []
        for item in raw_episodes:
            if not isinstance(item, dict):
                continue

            title = str(item.get("title", "") or "").strip()
            summary = str(item.get("summary", "") or "").strip()
            if not title or not summary:
                continue

            raw_hashes = item.get("paragraph_hashes")
            if not isinstance(raw_hashes, list):
                continue

            # Keep only known hashes, first occurrence wins, order preserved.
            dedup_hashes: List[str] = []
            seen_hashes = set()
            for h in raw_hashes:
                token = str(h or "").strip()
                if not token or token in seen_hashes or token not in valid_hashes:
                    continue
                seen_hashes.add(token)
                dedup_hashes.append(token)

            if not dedup_hashes:
                continue

            normalized.append(
                {
                    "title": title,
                    "summary": summary,
                    "paragraph_hashes": dedup_hashes,
                    "participants": self._clean_str_list(
                        item.get("participants"), self.MAX_PARTICIPANTS
                    ),
                    "keywords": self._clean_str_list(
                        item.get("keywords"), self.MAX_KEYWORDS
                    ),
                    "time_confidence": self._clamp_score(item.get("time_confidence"), default=1.0),
                    "llm_confidence": self._clamp_score(item.get("llm_confidence"), default=0.5),
                }
            )

        if not normalized:
            raise ValueError("episodes_all_invalid")
        return normalized

    async def segment(
        self,
        *,
        source: str,
        window_start: Optional[float],
        window_end: Optional[float],
        paragraphs: List[Dict[str, Any]],
    ) -> Dict[str, Any]:
        """Run LLM segmentation over ``paragraphs``.

        Returns:
            A dict with ``episodes``, ``segmentation_model`` and
            ``segmentation_version``.

        Raises:
            ValueError: for empty input or an unusable LLM reply.
            RuntimeError: when no model is available or generation fails.
        """
        if not paragraphs:
            raise ValueError("paragraphs_empty")

        model_config, model_label = self._resolve_model_config()
        if model_config is None:
            raise RuntimeError("episode segmentation model unavailable")

        prompt = self._build_prompt(
            source=source,
            window_start=window_start,
            window_end=window_end,
            paragraphs=paragraphs,
        )
        success, response, _, _ = await llm_api.generate_with_model(
            prompt=prompt,
            model_config=model_config,
            request_type="A_Memorix.EpisodeSegmentation",
        )
        if not success or not response:
            raise RuntimeError("llm_generate_failed")

        payload = self._safe_json_loads(str(response))
        input_hashes = [str(p.get("hash", "") or "").strip() for p in paragraphs]
        episodes = self._normalize_episodes(payload=payload, input_hashes=input_hashes)

        return {
            "episodes": episodes,
            "segmentation_model": model_label,
            "segmentation_version": self.SEGMENTATION_VERSION,
        }
|
||||
|
||||
558
plugins/A_memorix/core/utils/episode_service.py
Normal file
558
plugins/A_memorix/core/utils/episode_service.py
Normal file
@@ -0,0 +1,558 @@
|
||||
"""
|
||||
Episode 聚合与落库服务。
|
||||
|
||||
流程:
|
||||
1. 从 pending 队列读取段落并组批
|
||||
2. 按 source + 时间窗口切组
|
||||
3. 调用 LLM 语义切分
|
||||
4. 写入 episodes + episode_paragraphs
|
||||
5. LLM 失败时使用确定性 fallback
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from collections import Counter
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from .episode_segmentation_service import EpisodeSegmentationService
|
||||
from .hash import compute_hash
|
||||
|
||||
logger = get_logger("A_Memorix.EpisodeService")
|
||||
|
||||
|
||||
class EpisodeService:
    """Background service that aggregates paragraphs into episodes and stores them.

    Flow:
      1. Expand pending-queue rows into paragraphs and batch them.
      2. Group by source and time proximity.
      3. Ask the segmentation service (LLM) to split each group into episodes.
      4. Write episodes and their paragraph bindings to the metadata store.
      5. Use a deterministic rule-based fallback when segmentation fails.
    """

    # Compiled once: runs of 2+ ASCII word characters or CJK ideographs.
    _KEYWORD_TOKEN_PATTERN = re.compile(r"[A-Za-z0-9_\u4e00-\u9fff]{2,}")
    _KEYWORD_STOP_WORDS = frozenset(
        {
            "the",
            "and",
            "that",
            "this",
            "with",
            "from",
            "for",
            "have",
            "will",
            "your",
            "you",
            "我们",
            "你们",
            "他们",
            "以及",
            "一个",
            "这个",
            "那个",
            "然后",
            "因为",
            "所以",
        }
    )

    def __init__(
        self,
        *,
        metadata_store: Any,
        plugin_config: Optional[Any] = None,
        segmentation_service: Optional[EpisodeSegmentationService] = None,
    ):
        """Wire the service to its store, configuration and segmentation backend.

        Args:
            metadata_store: Store object used to read paragraphs and write episodes.
            plugin_config: Nested dict, or an object exposing ``get_config(key, default)``.
            segmentation_service: Optional pre-built segmentation service; one is
                created from ``plugin_config`` when omitted.
        """
        self.metadata_store = metadata_store
        self.plugin_config = plugin_config or {}
        self.segmentation_service = segmentation_service or EpisodeSegmentationService(
            plugin_config=self._segmentation_config(),
        )

    def _config_dict(self) -> Dict[str, Any]:
        """Return the plugin config when it is a plain dict, otherwise ``{}``."""
        if isinstance(self.plugin_config, dict):
            return self.plugin_config
        return {}

    def _segmentation_config(self) -> Dict[str, Any]:
        """Build the dict config handed to the default segmentation service.

        The segmentation service only understands nested dicts. When the plugin
        config is an object with ``get_config``, the setting that service reads
        (``episode.segmentation_model``) is copied into a dict so it is not
        silently lost. The value is a snapshot taken at construction time.
        """
        if isinstance(self.plugin_config, dict):
            return self.plugin_config
        model = self._cfg("episode.segmentation_model", None)
        if model is None:
            return {}
        return {"episode": {"segmentation_model": model}}

    def _cfg(self, key: str, default: Any = None) -> Any:
        """Read a dotted config ``key`` via ``get_config`` or nested-dict lookup."""
        getter = getattr(self.plugin_config, "get_config", None)
        if callable(getter):
            return getter(key, default)

        current: Any = self.plugin_config
        for part in key.split("."):
            if isinstance(current, dict) and part in current:
                current = current[part]
            else:
                return default
        return current

    @staticmethod
    def _to_optional_float(value: Any) -> Optional[float]:
        """Convert ``value`` to float; ``None`` or unconvertible input gives ``None``."""
        if value is None:
            return None
        try:
            return float(value)
        except (TypeError, ValueError, OverflowError):
            return None

    @staticmethod
    def _clamp_score(value: Any, default: float = 1.0) -> float:
        """Convert ``value`` to a float clamped into ``[0.0, 1.0]``.

        Unconvertible values and NaN are replaced by ``default`` (itself
        clamped), so NaN never reaches stored confidence fields.
        """
        import math  # local import: keeps this block self-contained

        try:
            num = float(value)
        except (TypeError, ValueError, OverflowError):
            num = default
        if math.isnan(num):
            num = default
            if math.isnan(num):
                return 0.0
        if num < 0.0:
            return 0.0
        if num > 1.0:
            return 1.0
        return num

    @staticmethod
    def _paragraph_anchor(paragraph: Dict[str, Any]) -> float:
        """Return the timestamp used to order and window a paragraph.

        The first convertible value among ``event_time_end``,
        ``event_time_start``, ``event_time`` and ``created_at`` wins; ``0.0``
        when none is usable.
        """
        for key in ("event_time_end", "event_time_start", "event_time", "created_at"):
            value = paragraph.get(key)
            if value is None:
                continue
            try:
                return float(value)
            except (TypeError, ValueError, OverflowError):
                continue
        return 0.0

    @staticmethod
    def _paragraph_sort_key(paragraph: Dict[str, Any]) -> Tuple[float, str]:
        """Sort key: anchor timestamp first, hash as a deterministic tie-breaker."""
        return (
            EpisodeService._paragraph_anchor(paragraph),
            str(paragraph.get("hash", "") or ""),
        )

    def load_pending_paragraphs(
        self,
        pending_rows: List[Dict[str, Any]],
    ) -> Tuple[List[Dict[str, Any]], List[str]]:
        """Expand pending rows into paragraph context dicts.

        Rows without a ``paragraph_hash`` are skipped.

        Returns:
            ``(loaded_paragraphs, missing_hashes)`` where ``missing_hashes``
            lists hashes the metadata store could not resolve.
        """
        loaded: List[Dict[str, Any]] = []
        missing: List[str] = []
        for row in pending_rows or []:
            p_hash = str(row.get("paragraph_hash", "") or "").strip()
            if not p_hash:
                continue

            paragraph = self.metadata_store.get_paragraph(p_hash)
            if not paragraph:
                missing.append(p_hash)
                continue

            loaded.append(
                {
                    "hash": p_hash,
                    "source": str(row.get("source") or paragraph.get("source") or "").strip(),
                    "content": str(paragraph.get("content", "") or ""),
                    "created_at": self._to_optional_float(paragraph.get("created_at"))
                    or self._to_optional_float(row.get("created_at"))
                    or 0.0,
                    "event_time": self._to_optional_float(paragraph.get("event_time")),
                    "event_time_start": self._to_optional_float(paragraph.get("event_time_start")),
                    "event_time_end": self._to_optional_float(paragraph.get("event_time_end")),
                    "time_granularity": str(paragraph.get("time_granularity", "") or "").strip() or None,
                    "time_confidence": self._clamp_score(paragraph.get("time_confidence"), default=1.0),
                }
            )
        return loaded, missing

    def group_paragraphs(self, paragraphs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Batch paragraphs by source and time proximity.

        Within one source, paragraphs are ordered by anchor time. A new group
        starts when the current one is full (paragraph count or character
        budget) or when the gap to the previous paragraph exceeds the
        configured window. Groups are returned ordered by their first anchor.
        """
        if not paragraphs:
            return []

        max_paragraphs = max(1, int(self._cfg("episode.max_paragraphs_per_call", 20)))
        max_chars = max(200, int(self._cfg("episode.max_chars_per_call", 6000)))
        window_seconds = max(
            60.0,
            float(self._cfg("episode.source_time_window_hours", 24)) * 3600.0,
        )

        by_source: Dict[str, List[Dict[str, Any]]] = {}
        for paragraph in paragraphs:
            source = str(paragraph.get("source", "") or "").strip()
            by_source.setdefault(source, []).append(paragraph)

        groups: List[Dict[str, Any]] = []
        for source, items in by_source.items():
            batch: List[Dict[str, Any]] = []
            batch_chars = 0
            last_anchor: Optional[float] = None

            for paragraph in sorted(items, key=self._paragraph_sort_key):
                anchor = self._paragraph_anchor(paragraph)
                content_len = len(str(paragraph.get("content", "") or ""))

                must_split = bool(batch) and (
                    len(batch) >= max_paragraphs
                    or batch_chars + content_len > max_chars
                    or (last_anchor is not None and abs(anchor - last_anchor) > window_seconds)
                )
                if must_split:
                    groups.append({"source": source, "paragraphs": batch})
                    batch = []
                    batch_chars = 0

                batch.append(paragraph)
                batch_chars += content_len
                last_anchor = anchor

            if batch:
                groups.append({"source": source, "paragraphs": batch})

        groups.sort(
            key=lambda g: self._paragraph_anchor(g["paragraphs"][0]) if g.get("paragraphs") else 0.0
        )
        return groups

    def _compute_time_meta(
        self, paragraphs: List[Dict[str, Any]]
    ) -> Tuple[Optional[float], Optional[float], Optional[str], float]:
        """Aggregate time metadata over ``paragraphs``.

        Returns:
            ``(time_start, time_end, granularity, time_confidence)``: the
            earliest start candidate, the latest end candidate, the finest
            granularity seen, and the mean clamped confidence (``1.0`` for no
            paragraphs).
        """
        starts: List[float] = []
        ends: List[float] = []
        granularity_priority = {
            "minute": 4,
            "hour": 3,
            "day": 2,
            "month": 1,
            "year": 0,
        }
        granularity = None
        granularity_rank = -1
        conf_values: List[float] = []

        for p in paragraphs:
            s = self._to_optional_float(p.get("event_time_start"))
            e = self._to_optional_float(p.get("event_time_end"))
            t = self._to_optional_float(p.get("event_time"))
            c = self._to_optional_float(p.get("created_at"))

            # Start prefers start > point > end > creation; end mirrors it.
            start_candidate = s if s is not None else (t if t is not None else (e if e is not None else c))
            end_candidate = e if e is not None else (t if t is not None else (s if s is not None else c))

            if start_candidate is not None:
                starts.append(start_candidate)
            if end_candidate is not None:
                ends.append(end_candidate)

            g = str(p.get("time_granularity", "") or "").strip().lower()
            if g in granularity_priority and granularity_priority[g] > granularity_rank:
                granularity_rank = granularity_priority[g]
                granularity = g

            conf_values.append(self._clamp_score(p.get("time_confidence"), default=1.0))

        time_start = min(starts) if starts else None
        time_end = max(ends) if ends else None
        time_conf = sum(conf_values) / len(conf_values) if conf_values else 1.0
        return time_start, time_end, granularity, self._clamp_score(time_conf, default=1.0)

    def _collect_participants(self, paragraph_hashes: List[str], limit: int = 16) -> List[str]:
        """Collect up to ``limit`` distinct entity names linked to the paragraphs.

        Names are de-duplicated case-insensitively. A failing store lookup for
        one paragraph is skipped so the remaining paragraphs still contribute.
        """
        seen = set()
        participants: List[str] = []
        for p_hash in paragraph_hashes:
            try:
                entities = self.metadata_store.get_paragraph_entities(p_hash) or []
            except Exception as e:
                # Best effort: participants are optional enrichment.
                logger.debug(f"get_paragraph_entities failed: hash={p_hash} err={e}")
                entities = []
            for item in entities:
                name = str(item.get("name", "") or "").strip()
                if not name:
                    continue
                key = name.lower()
                if key in seen:
                    continue
                seen.add(key)
                participants.append(name)
                if len(participants) >= limit:
                    return participants
        return participants

    @staticmethod
    def _derive_keywords(paragraphs: List[Dict[str, Any]], limit: int = 12) -> List[str]:
        """Return the ``limit`` most frequent non-stop-word tokens of the contents."""
        pattern = EpisodeService._KEYWORD_TOKEN_PATTERN
        stop_words = EpisodeService._KEYWORD_STOP_WORDS
        token_counter: Counter[str] = Counter()
        for p in paragraphs:
            text = str(p.get("content", "") or "").lower()
            token_counter.update(
                token for token in pattern.findall(text) if token not in stop_words
            )
        return [token for token, _ in token_counter.most_common(limit)]

    @staticmethod
    def _fallback_title(source: str, time_start: Optional[float]) -> str:
        """Build the rule-based episode title, at most 80 characters long.

        The day (local time) is included only when ``time_start`` converts to a
        date; out-of-range timestamps must not break the fallback path.
        """
        label = source or "unknown"
        day_text = ""
        if time_start is not None:
            try:
                day_text = datetime.fromtimestamp(time_start).strftime("%Y-%m-%d")
            except (OverflowError, OSError, ValueError):
                day_text = ""
        if day_text:
            return f"{label} {day_text} 情景片段"[:80]
        return f"{label} 情景片段"[:80]

    def _build_fallback_episode(self, group: Dict[str, Any]) -> Dict[str, Any]:
        """Build one deterministic episode covering the whole group (no LLM)."""
        paragraphs = group.get("paragraphs", []) or []
        source = str(group.get("source", "") or "").strip()
        hashes = [h for h in (str(p.get("hash", "") or "").strip() for p in paragraphs) if h]

        # The summary is stitched from the first three non-empty paragraphs.
        snippets = []
        for p in paragraphs[:3]:
            text = str(p.get("content", "") or "").strip().replace("\n", " ")
            if text:
                snippets.append(text[:140])
        summary = ";".join(snippets)[:500] if snippets else "自动回退生成的情景记忆。"

        time_start, time_end, granularity, time_conf = self._compute_time_meta(paragraphs)
        participants = self._collect_participants(hashes, limit=12)
        keywords = self._derive_keywords(paragraphs, limit=10)

        return {
            "title": self._fallback_title(source, time_start),
            "summary": summary,
            "paragraph_hashes": hashes,
            "participants": participants,
            "keywords": keywords,
            "time_confidence": time_conf,
            "llm_confidence": 0.0,
            "event_time_start": time_start,
            "event_time_end": time_end,
            "time_granularity": granularity,
            "segmentation_model": "fallback_rule",
            "segmentation_version": EpisodeSegmentationService.SEGMENTATION_VERSION,
        }

    @staticmethod
    def _normalize_episode_hashes(episode_hashes: List[str], group_hashes_ordered: List[str]) -> List[str]:
        """Keep stripped, unique episode hashes that belong to the group.

        NOTE(review): despite the parameter name, the result follows the order
        of ``episode_hashes``, not the group order — confirm that is intended,
        since the episode id is derived from this ordering.
        """
        in_group = set(group_hashes_ordered)
        dedup: List[str] = []
        seen = set()
        for h in episode_hashes or []:
            token = str(h or "").strip()
            if not token or token not in in_group or token in seen:
                continue
            seen.add(token)
            dedup.append(token)
        return dedup

    async def _build_episode_payloads_for_group(self, group: Dict[str, Any]) -> Dict[str, Any]:
        """Turn one paragraph group into storable episode payloads.

        Runs LLM segmentation and falls back to a single rule-based episode on
        any failure.

        Returns:
            A dict with ``payloads``, ``done_hashes`` (all group hashes),
            ``episode_count`` and ``fallback_count`` (0 or 1).
        """
        paragraphs = group.get("paragraphs", []) or []
        if not paragraphs:
            return {
                "payloads": [],
                "done_hashes": [],
                "episode_count": 0,
                "fallback_count": 0,
            }

        source = str(group.get("source", "") or "").strip()
        group_hashes = [h for h in (str(p.get("hash", "") or "").strip() for p in paragraphs) if h]
        group_start, group_end, _, _ = self._compute_time_meta(paragraphs)

        fallback_used = False
        segmentation_model = "fallback_rule"
        segmentation_version = EpisodeSegmentationService.SEGMENTATION_VERSION

        try:
            llm_result = await self.segmentation_service.segment(
                source=source,
                window_start=group_start,
                window_end=group_end,
                paragraphs=paragraphs,
            )
            episodes = list(llm_result.get("episodes") or [])
            segmentation_model = str(llm_result.get("segmentation_model", "") or "").strip() or "auto"
            segmentation_version = (
                str(llm_result.get("segmentation_version", "") or "").strip()
                or EpisodeSegmentationService.SEGMENTATION_VERSION
            )
            if not episodes:
                raise ValueError("llm_empty_episodes")
        except Exception as e:
            logger.warning(
                "Episode segmentation fallback: "
                f"source={source} "
                f"size={len(group_hashes)} "
                f"err={e}"
            )
            episodes = [self._build_fallback_episode(group)]
            fallback_used = True

        stored_payloads: List[Dict[str, Any]] = []
        for episode in episodes:
            ordered_hashes = self._normalize_episode_hashes(
                episode_hashes=episode.get("paragraph_hashes", []),
                group_hashes_ordered=group_hashes,
            )
            if not ordered_hashes:
                continue

            # Build the membership set once and compare stripped hashes, the
            # same normalization used to produce ``ordered_hashes``.
            wanted = set(ordered_hashes)
            sub_paragraphs = [
                p for p in paragraphs if str(p.get("hash", "") or "").strip() in wanted
            ]
            event_start, event_end, granularity, time_conf_default = self._compute_time_meta(sub_paragraphs)

            participants = [str(x).strip() for x in (episode.get("participants", []) or []) if str(x).strip()]
            keywords = [str(x).strip() for x in (episode.get("keywords", []) or []) if str(x).strip()]
            if not participants:
                participants = self._collect_participants(ordered_hashes, limit=16)
            if not keywords:
                keywords = self._derive_keywords(sub_paragraphs, limit=12)

            title = str(episode.get("title", "") or "").strip()[:120]
            summary = str(episode.get("summary", "") or "").strip()[:2000]
            if not title or not summary:
                continue

            # The id is a hash of source + member hashes + version, so the
            # same segmentation yields the same episode id.
            seed = json.dumps(
                {
                    "source": source,
                    "hashes": ordered_hashes,
                    "version": segmentation_version,
                },
                ensure_ascii=False,
                sort_keys=True,
            )
            episode_id = compute_hash(seed)

            payload = {
                "episode_id": episode_id,
                "source": source or None,
                "title": title,
                "summary": summary,
                "event_time_start": episode.get("event_time_start", event_start),
                "event_time_end": episode.get("event_time_end", event_end),
                "time_granularity": episode.get("time_granularity", granularity),
                "time_confidence": self._clamp_score(
                    episode.get("time_confidence"),
                    default=time_conf_default,
                ),
                "participants": participants[:16],
                "keywords": keywords[:20],
                "evidence_ids": ordered_hashes,
                "paragraph_count": len(ordered_hashes),
                "llm_confidence": self._clamp_score(
                    episode.get("llm_confidence"),
                    default=0.0 if fallback_used else 0.6,
                ),
                "segmentation_model": (
                    str(episode.get("segmentation_model", "") or "").strip()
                    or ("fallback_rule" if fallback_used else segmentation_model)
                ),
                "segmentation_version": (
                    str(episode.get("segmentation_version", "") or "").strip()
                    or segmentation_version
                ),
            }
            stored_payloads.append(payload)

        return {
            "payloads": stored_payloads,
            "done_hashes": group_hashes,
            "episode_count": len(stored_payloads),
            "fallback_count": 1 if fallback_used else 0,
        }

    async def process_group(self, group: Dict[str, Any]) -> Dict[str, Any]:
        """Segment one group, upsert its episodes and bind their paragraphs.

        Returns:
            A dict with ``done_hashes``, ``episode_count`` (episodes actually
            stored) and ``fallback_count``.
        """
        result = await self._build_episode_payloads_for_group(group)
        stored_count = 0
        for payload in result.get("payloads") or []:
            # Tolerate a store that returns nothing from the upsert.
            stored = self.metadata_store.upsert_episode(payload) or {}
            final_id = str(stored.get("episode_id") or payload.get("episode_id") or "")
            if final_id:
                self.metadata_store.bind_episode_paragraphs(
                    final_id,
                    list(payload.get("evidence_ids") or []),
                )
                stored_count += 1

        return {
            "done_hashes": list(result.get("done_hashes") or []),
            "episode_count": stored_count,
            "fallback_count": int(result.get("fallback_count") or 0),
        }

    async def process_pending_rows(self, pending_rows: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Process a batch of pending rows end to end.

        Hashes missing from the store count as done. A group that raises is
        logged and reported per hash in ``failed_hashes`` with the error text
        truncated to 500 characters.

        Returns:
            A dict with ``done_hashes``, ``failed_hashes``, ``episode_count``,
            ``fallback_count``, ``missing_count`` and ``group_count``.
        """
        loaded, missing_hashes = self.load_pending_paragraphs(pending_rows)
        groups = self.group_paragraphs(loaded)

        done_hashes: List[str] = list(missing_hashes)
        failed_hashes: Dict[str, str] = {}
        episode_count = 0
        fallback_count = 0

        for group in groups:
            group_hashes = [str(p.get("hash", "") or "").strip() for p in (group.get("paragraphs") or [])]
            try:
                result = await self.process_group(group)
            except Exception as e:
                err = str(e)[:500]
                logger.warning(
                    f"Episode group processing failed: size={len(group_hashes)} err={err}"
                )
                for h in group_hashes:
                    if h:
                        failed_hashes[h] = err
                continue

            done_hashes.extend(result.get("done_hashes") or [])
            episode_count += int(result.get("episode_count") or 0)
            fallback_count += int(result.get("fallback_count") or 0)

        dedup_done = list(dict.fromkeys([h for h in done_hashes if h]))
        return {
            "done_hashes": dedup_done,
            "failed_hashes": failed_hashes,
            "episode_count": episode_count,
            "fallback_count": fallback_count,
            "missing_count": len(missing_hashes),
            "group_count": len(groups),
        }

    async def rebuild_source(self, source: str) -> Dict[str, Any]:
        """Rebuild every episode of ``source`` from its live paragraphs.

        A blank source is a no-op. A source without live paragraphs has its
        episodes replaced by an empty set.

        Returns:
            A dict with ``source``, ``episode_count``, ``fallback_count``,
            ``group_count`` and ``paragraph_count``.
        """
        token = str(source or "").strip()
        if not token:
            return {
                "source": "",
                "episode_count": 0,
                "fallback_count": 0,
                "group_count": 0,
                "paragraph_count": 0,
            }

        paragraphs = self.metadata_store.get_live_paragraphs_by_source(token)
        if not paragraphs:
            replace_result = self.metadata_store.replace_episodes_for_source(token, [])
            return {
                "source": token,
                "episode_count": int(replace_result.get("episode_count") or 0),
                "fallback_count": 0,
                "group_count": 0,
                "paragraph_count": 0,
            }

        groups = self.group_paragraphs(paragraphs)
        payloads: List[Dict[str, Any]] = []
        fallback_count = 0

        for group in groups:
            result = await self._build_episode_payloads_for_group(group)
            payloads.extend(list(result.get("payloads") or []))
            fallback_count += int(result.get("fallback_count") or 0)

        replace_result = self.metadata_store.replace_episodes_for_source(token, payloads)
        return {
            "source": token,
            "episode_count": int(replace_result.get("episode_count") or 0),
            "fallback_count": fallback_count,
            "group_count": len(groups),
            "paragraph_count": len(paragraphs),
        }
|
||||
@@ -9,7 +9,11 @@ import json
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from sqlalchemy import or_
|
||||
from sqlmodel import select
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import PersonInfo
|
||||
|
||||
from ..embedding import EmbeddingAPIAdapter
|
||||
@@ -120,31 +124,40 @@ class PersonProfileService:
|
||||
if not key:
|
||||
return ""
|
||||
|
||||
try:
|
||||
with get_db_session(auto_commit=False) as session:
|
||||
record = session.exec(
|
||||
select(PersonInfo.person_id).where(PersonInfo.person_id == key).limit(1)
|
||||
).first()
|
||||
if record:
|
||||
return str(record)
|
||||
|
||||
record = session.exec(
|
||||
select(PersonInfo.person_id)
|
||||
.where(
|
||||
or_(
|
||||
PersonInfo.person_name == key,
|
||||
PersonInfo.user_nickname == key,
|
||||
)
|
||||
)
|
||||
.limit(1)
|
||||
).first()
|
||||
if record:
|
||||
return str(record)
|
||||
|
||||
record = session.exec(
|
||||
select(PersonInfo.person_id)
|
||||
.where(PersonInfo.group_cardname.contains(key))
|
||||
.limit(1)
|
||||
).first()
|
||||
if record:
|
||||
return str(record)
|
||||
except Exception as e:
|
||||
logger.warning(f"按别名解析 person_id 失败: identifier={key}, err={e}")
|
||||
|
||||
if len(key) == 32 and all(ch in "0123456789abcdefABCDEF" for ch in key):
|
||||
return key.lower()
|
||||
|
||||
try:
|
||||
record = (
|
||||
PersonInfo.select(PersonInfo.person_id)
|
||||
.where((PersonInfo.person_name == key) | (PersonInfo.nickname == key))
|
||||
.first()
|
||||
)
|
||||
if record and record.person_id:
|
||||
return str(record.person_id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
record = (
|
||||
PersonInfo.select(PersonInfo.person_id)
|
||||
.where(PersonInfo.group_nick_name.contains(key))
|
||||
.first()
|
||||
)
|
||||
if record and record.person_id:
|
||||
return str(record.person_id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return ""
|
||||
|
||||
def _parse_group_nicks(self, raw_value: Any) -> List[str]:
|
||||
@@ -160,7 +173,7 @@ class PersonProfileService:
|
||||
names: List[str] = []
|
||||
for item in items:
|
||||
if isinstance(item, dict):
|
||||
value = str(item.get("group_nick_name", "")).strip()
|
||||
value = str(item.get("group_cardname") or item.get("group_nick_name") or "").strip()
|
||||
if value:
|
||||
names.append(value)
|
||||
elif isinstance(item, str):
|
||||
@@ -193,6 +206,42 @@ class PersonProfileService:
|
||||
traits.append(text)
|
||||
return traits[:10]
|
||||
|
||||
def _recover_aliases_from_memory(self, person_id: str) -> Tuple[List[str], str]:
|
||||
"""当人物主档案缺失时,从已有记忆证据里回捞可用别名。"""
|
||||
if not person_id:
|
||||
return [], ""
|
||||
|
||||
aliases: List[str] = []
|
||||
primary_name = ""
|
||||
seen = set()
|
||||
|
||||
try:
|
||||
paragraphs = self.metadata_store.get_paragraphs_by_entity(person_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"从记忆证据回捞人物别名失败: person_id={person_id}, err={e}")
|
||||
return [], ""
|
||||
|
||||
for paragraph in paragraphs[:20]:
|
||||
paragraph_hash = str(paragraph.get("hash", "") or "").strip()
|
||||
if not paragraph_hash:
|
||||
continue
|
||||
try:
|
||||
paragraph_entities = self.metadata_store.get_paragraph_entities(paragraph_hash)
|
||||
except Exception:
|
||||
paragraph_entities = []
|
||||
for entity in paragraph_entities:
|
||||
name = str(entity.get("name", "") or "").strip()
|
||||
if not name or name == person_id:
|
||||
continue
|
||||
key = name.lower()
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
aliases.append(name)
|
||||
if not primary_name:
|
||||
primary_name = name
|
||||
return aliases, primary_name
|
||||
|
||||
def get_person_aliases(self, person_id: str) -> Tuple[List[str], str, List[str]]:
|
||||
"""获取人物别名集合、主展示名、记忆特征。"""
|
||||
aliases: List[str] = []
|
||||
@@ -200,18 +249,28 @@ class PersonProfileService:
|
||||
memory_traits: List[str] = []
|
||||
if not person_id:
|
||||
return aliases, primary_name, memory_traits
|
||||
recovered_aliases, recovered_primary_name = self._recover_aliases_from_memory(person_id)
|
||||
try:
|
||||
record = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
|
||||
if not record:
|
||||
return aliases, primary_name, memory_traits
|
||||
with get_db_session(auto_commit=False) as session:
|
||||
record = session.exec(
|
||||
select(PersonInfo).where(PersonInfo.person_id == person_id).limit(1)
|
||||
).first()
|
||||
if not record:
|
||||
return recovered_aliases, recovered_primary_name or person_id, memory_traits
|
||||
person_name = str(getattr(record, "person_name", "") or "").strip()
|
||||
nickname = str(getattr(record, "nickname", "") or "").strip()
|
||||
group_nicks = self._parse_group_nicks(getattr(record, "group_nick_name", None))
|
||||
nickname = str(getattr(record, "user_nickname", "") or "").strip()
|
||||
group_nicks = self._parse_group_nicks(getattr(record, "group_cardname", None))
|
||||
memory_traits = self._parse_memory_traits(getattr(record, "memory_points", None))
|
||||
|
||||
primary_name = person_name or nickname or str(getattr(record, "user_id", "") or "").strip() or person_id
|
||||
primary_name = (
|
||||
person_name
|
||||
or nickname
|
||||
or recovered_primary_name
|
||||
or str(getattr(record, "user_id", "") or "").strip()
|
||||
or person_id
|
||||
)
|
||||
|
||||
candidates = [person_name, nickname] + group_nicks
|
||||
candidates = [person_name, nickname] + group_nicks + recovered_aliases
|
||||
seen = set()
|
||||
for item in candidates:
|
||||
norm = str(item or "").strip()
|
||||
|
||||
@@ -82,8 +82,9 @@ class RelationWriteService:
|
||||
)
|
||||
self.metadata_store.set_relation_vector_state(hash_value, "ready")
|
||||
logger.info(
|
||||
"metric.relation_vector_write_success=1 metric.relation_vector_write_success_count=1 hash=%s",
|
||||
hash_value[:16],
|
||||
"metric.relation_vector_write_success=1 "
|
||||
"metric.relation_vector_write_success_count=1 "
|
||||
f"hash={hash_value[:16]}"
|
||||
)
|
||||
return RelationWriteResult(
|
||||
hash_value=hash_value,
|
||||
@@ -109,9 +110,10 @@ class RelationWriteService:
|
||||
bump_retry=True,
|
||||
)
|
||||
logger.warning(
|
||||
"metric.relation_vector_write_fail=1 metric.relation_vector_write_fail_count=1 hash=%s err=%s",
|
||||
hash_value[:16],
|
||||
err,
|
||||
"metric.relation_vector_write_fail=1 "
|
||||
"metric.relation_vector_write_fail_count=1 "
|
||||
f"hash={hash_value[:16]} "
|
||||
f"err={err}"
|
||||
)
|
||||
return RelationWriteResult(
|
||||
hash_value=hash_value,
|
||||
|
||||
1857
plugins/A_memorix/core/utils/retrieval_tuning_manager.py
Normal file
1857
plugins/A_memorix/core/utils/retrieval_tuning_manager.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -61,6 +61,29 @@ def _build_report(
|
||||
}
|
||||
|
||||
|
||||
def _normalize_encoded_vector(encoded: Any) -> np.ndarray:
|
||||
if encoded is None:
|
||||
raise ValueError("embedding encode returned None")
|
||||
|
||||
if isinstance(encoded, np.ndarray):
|
||||
array = encoded
|
||||
else:
|
||||
array = np.asarray(encoded, dtype=np.float32)
|
||||
|
||||
if array.ndim == 2:
|
||||
if array.shape[0] != 1:
|
||||
raise ValueError(f"embedding encode returned batched output: shape={tuple(array.shape)}")
|
||||
array = array[0]
|
||||
|
||||
if array.ndim != 1:
|
||||
raise ValueError(f"embedding encode returned invalid ndim={array.ndim}")
|
||||
if array.size <= 0:
|
||||
raise ValueError("embedding encode returned empty vector")
|
||||
if not np.all(np.isfinite(array)):
|
||||
raise ValueError("embedding encode returned non-finite values")
|
||||
return array.astype(np.float32, copy=False)
|
||||
|
||||
|
||||
async def run_embedding_runtime_self_check(
|
||||
*,
|
||||
config: Any,
|
||||
@@ -91,13 +114,11 @@ async def run_embedding_runtime_self_check(
|
||||
try:
|
||||
detected_dimension = _safe_int(await embedding_manager._detect_dimension(), 0)
|
||||
encoded = await embedding_manager.encode(sample_text)
|
||||
if isinstance(encoded, np.ndarray):
|
||||
encoded_dimension = int(encoded.shape[0]) if encoded.ndim == 1 else int(encoded.shape[-1])
|
||||
else:
|
||||
encoded_dimension = len(encoded) if encoded is not None else 0
|
||||
encoded_array = _normalize_encoded_vector(encoded)
|
||||
encoded_dimension = int(encoded_array.shape[0])
|
||||
except Exception as exc:
|
||||
elapsed_ms = (time.perf_counter() - start) * 1000.0
|
||||
logger.warning("embedding runtime self-check failed: %s", exc)
|
||||
logger.warning(f"embedding runtime self-check failed: {exc}")
|
||||
return _build_report(
|
||||
ok=False,
|
||||
code="embedding_probe_failed",
|
||||
|
||||
442
plugins/A_memorix/core/utils/search_execution_service.py
Normal file
442
plugins/A_memorix/core/utils/search_execution_service.py
Normal file
@@ -0,0 +1,442 @@
|
||||
"""
|
||||
统一检索执行服务。
|
||||
|
||||
用于收敛 Action/Tool 在 search/time 上的核心执行流程,避免重复实现。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from ..retrieval import TemporalQueryOptions
|
||||
from .search_postprocess import (
|
||||
apply_safe_content_dedup,
|
||||
maybe_apply_smart_path_fallback,
|
||||
)
|
||||
from .time_parser import parse_query_time_range
|
||||
|
||||
logger = get_logger("A_Memorix.SearchExecutionService")
|
||||
|
||||
|
||||
def _get_config_value(config: Optional[dict], key: str, default: Any = None) -> Any:
|
||||
if not isinstance(config, dict):
|
||||
return default
|
||||
current: Any = config
|
||||
for part in key.split("."):
|
||||
if isinstance(current, dict) and part in current:
|
||||
current = current[part]
|
||||
else:
|
||||
return default
|
||||
return current
|
||||
|
||||
|
||||
def _sanitize_text(value: Any) -> str:
|
||||
if value is None:
|
||||
return ""
|
||||
return str(value).strip()
|
||||
|
||||
|
||||
@dataclass
class SearchExecutionRequest:
    """Input parameters for one search execution."""

    # Label of the component issuing the request; it appears in log lines.
    caller: str

    # Chat scope forwarded to the plugin's chat filter.
    stream_id: Optional[str] = None
    group_id: Optional[str] = None
    user_id: Optional[str] = None

    # One of: search | semantic | time | hybrid
    query_type: str = "search"
    query: str = ""

    # None means "use the default for this query type".
    top_k: Optional[int] = None

    # Raw time-range bounds, parsed later for time/hybrid queries.
    time_from: Optional[str] = None
    time_to: Optional[str] = None

    # Optional temporal filters.
    person: Optional[str] = None
    source: Optional[str] = None

    # Execution switches.
    use_threshold: bool = True
    enable_ppr: bool = True
|
||||
|
||||
|
||||
@dataclass
class SearchExecutionResult:
    """Outcome of one search execution, echoing the effective parameters."""

    # Whether the execution completed; ``error`` carries the reason otherwise.
    success: bool
    error: str = ""

    # Effective (normalized) query parameters.
    query_type: str = "search"
    query: str = ""
    top_k: int = 10
    time_from: Optional[str] = None
    time_to: Optional[str] = None
    person: Optional[str] = None
    source: Optional[str] = None

    # Parsed temporal options, when the query type uses them.
    temporal: Optional[TemporalQueryOptions] = None

    # Retrieved items and how long retrieval took, in milliseconds.
    results: List[Any] = field(default_factory=list)
    elapsed_ms: float = 0.0

    # Flags describing how the result was produced.
    chat_filtered: bool = False
    dedup_hit: bool = False

    @property
    def count(self) -> int:
        """Number of items in ``results``."""
        return len(self.results)
|
||||
|
||||
|
||||
class SearchExecutionService:
    """Unified search execution service.

    Bundles parameter validation, chat filtering, retrieval, threshold
    filtering, access reinforcement, smart-path fallback and safe content
    de-duplication behind a single ``execute`` entry point. Every method is a
    static method; the class holds no state of its own.
    """

    @staticmethod
    def _resolve_plugin_instance(plugin_config: Optional[dict]) -> Optional[Any]:
        """Return the plugin instance from the config, else the global one, else ``None``."""
        if isinstance(plugin_config, dict):
            plugin_instance = plugin_config.get("plugin_instance")
            if plugin_instance is not None:
                return plugin_instance

        try:
            # Imported lazily and guarded: any failure simply means "no instance".
            from ...plugin import AMemorixPlugin

            return getattr(AMemorixPlugin, "get_global_instance", lambda: None)()
        except Exception:
            return None

    @staticmethod
    def _normalize_query_type(raw_query_type: str) -> str:
        """Lower-case the query type, default to ``"search"``, and map ``"semantic"`` onto it."""
        query_type = _sanitize_text(raw_query_type).lower() or "search"
        if query_type == "semantic":
            return "search"
        return query_type

    @staticmethod
    def _resolve_runtime_component(
        plugin_config: Optional[dict],
        plugin_instance: Optional[Any],
        key: str,
    ) -> Optional[Any]:
        """Find a runtime component by name.

        The config dict is consulted first, then the attribute of the same
        name on the plugin instance. ``None`` values are treated as absent.
        """
        if isinstance(plugin_config, dict):
            value = plugin_config.get(key)
            if value is not None:
                return value
        if plugin_instance is not None:
            value = getattr(plugin_instance, key, None)
            if value is not None:
                return value
        return None

    @staticmethod
    def _resolve_top_k(
        plugin_config: Optional[dict],
        query_type: str,
        top_k_raw: Optional[Any],
    ) -> Tuple[bool, int, str]:
        """Resolve the effective ``top_k``, clamped to the range 1..50.

        Returns:
            ``(ok, top_k, error)``. When ``top_k_raw`` is ``None`` the default
            is used: ``retrieval.temporal.default_top_k`` for time/hybrid
            queries, otherwise 10. A value that cannot be converted to an
            integer yields ``(False, 0, <message>)``.
        """
        temporal_default_top_k = int(
            _get_config_value(plugin_config, "retrieval.temporal.default_top_k", 10)
        )
        default_top_k = temporal_default_top_k if query_type in {"time", "hybrid"} else 10
        if top_k_raw is None:
            return True, max(1, min(50, default_top_k)), ""
        try:
            top_k = int(top_k_raw)
        except (TypeError, ValueError, OverflowError):
            # OverflowError covers int(float("inf")), which is as invalid as a
            # non-numeric string and must not escape as an exception.
            return False, 0, "top_k 参数必须为整数"
        return True, max(1, min(50, top_k)), ""

    @staticmethod
    def _build_temporal(
        plugin_config: Optional[dict],
        query_type: str,
        time_from_raw: Optional[str],
        time_to_raw: Optional[str],
        person: Optional[str],
        source: Optional[str],
    ) -> Tuple[bool, Optional[TemporalQueryOptions], str]:
        """Build the temporal options for time/hybrid queries.

        Returns:
            ``(ok, temporal, error)``. Other query types yield
            ``(True, None, "")``. The call fails when temporal retrieval is
            disabled in the config, when neither bound is given, or when the
            time parser raises ``ValueError``.
        """
        if query_type not in {"time", "hybrid"}:
            return True, None, ""

        temporal_enabled = bool(_get_config_value(plugin_config, "retrieval.temporal.enabled", True))
        if not temporal_enabled:
            return False, None, "时序检索已禁用(retrieval.temporal.enabled=false)"

        if not time_from_raw and not time_to_raw:
            return False, None, "time/hybrid 模式至少需要 time_from 或 time_to"

        try:
            ts_from, ts_to = parse_query_time_range(
                str(time_from_raw) if time_from_raw is not None else None,
                str(time_to_raw) if time_to_raw is not None else None,
            )
        except ValueError as e:
            return False, None, f"时间参数错误: {e}"

        temporal = TemporalQueryOptions(
            time_from=ts_from,
            time_to=ts_to,
            person=_sanitize_text(person) or None,
            source=_sanitize_text(source) or None,
            allow_created_fallback=bool(
                _get_config_value(plugin_config, "retrieval.temporal.allow_created_fallback", True)
            ),
            candidate_multiplier=int(
                _get_config_value(plugin_config, "retrieval.temporal.candidate_multiplier", 8)
            ),
            max_scan=int(_get_config_value(plugin_config, "retrieval.temporal.max_scan", 1000)),
        )
        return True, temporal, ""

    @staticmethod
    def _build_request_key(
        request: SearchExecutionRequest,
        query_type: str,
        top_k: int,
        temporal: Optional[TemporalQueryOptions],
    ) -> str:
        """Return a SHA-1 hex digest identifying the request for de-duplication.

        The digest covers the stream id, the effective query parameters, the
        raw and parsed time bounds, and the threshold/PPR switches, serialized
        as JSON with sorted keys.
        """
        payload = {
            "stream_id": _sanitize_text(request.stream_id),
            "query_type": query_type,
            "query": _sanitize_text(request.query),
            "time_from": _sanitize_text(request.time_from),
            "time_to": _sanitize_text(request.time_to),
            "time_from_ts": temporal.time_from if temporal else None,
            "time_to_ts": temporal.time_to if temporal else None,
            "person": _sanitize_text(request.person),
            "source": _sanitize_text(request.source),
            "top_k": int(top_k),
            "use_threshold": bool(request.use_threshold),
            "enable_ppr": bool(request.enable_ppr),
        }
        payload_json = json.dumps(payload, ensure_ascii=False, sort_keys=True)
        return hashlib.sha1(payload_json.encode("utf-8")).hexdigest()

    @staticmethod
    async def execute(
        *,
        retriever: Any,
        threshold_filter: Optional[Any],
        plugin_config: Optional[dict],
        request: SearchExecutionRequest,
        enforce_chat_filter: bool = True,
        reinforce_access: bool = True,
    ) -> SearchExecutionResult:
        """Validate a request, run retrieval and post-processing, and wrap the outcome.

        Args:
            retriever: Object exposing ``config`` and an awaitable
                ``retrieve(query=..., top_k=..., temporal=...)``.
            threshold_filter: Optional object whose ``filter(results)`` is
                applied when the request enables thresholding.
            plugin_config: Nested configuration dict; may also carry runtime
                components such as ``plugin_instance`` or ``graph_store``.
            request: The search parameters.
            enforce_chat_filter: When true and the plugin instance offers
                ``is_chat_enabled``, a disabled chat produces an empty,
                successful result flagged ``chat_filtered``.
            reinforce_access: When true and the plugin instance offers
                ``reinforce_access``, hashes of retrieved relation items are
                passed to it.

        Returns:
            A ``SearchExecutionResult``. Validation failures and retrieval
            exceptions are reported through ``success=False`` and ``error``.
        """
        if retriever is None:
            return SearchExecutionResult(success=False, error="知识检索器未初始化")

        query_type = SearchExecutionService._normalize_query_type(request.query_type)
        query = _sanitize_text(request.query)
        if query_type not in {"search", "time", "hybrid"}:
            return SearchExecutionResult(
                success=False,
                error=f"query_type 无效: {query_type}(仅支持 search/time/hybrid)",
            )

        if query_type in {"search", "hybrid"} and not query:
            return SearchExecutionResult(
                success=False,
                error="search/hybrid 模式必须提供 query",
            )

        top_k_ok, top_k, top_k_error = SearchExecutionService._resolve_top_k(
            plugin_config, query_type, request.top_k
        )
        if not top_k_ok:
            return SearchExecutionResult(success=False, error=top_k_error)

        temporal_ok, temporal, temporal_error = SearchExecutionService._build_temporal(
            plugin_config=plugin_config,
            query_type=query_type,
            time_from_raw=request.time_from,
            time_to_raw=request.time_to,
            person=request.person,
            source=request.source,
        )
        if not temporal_ok:
            return SearchExecutionResult(success=False, error=temporal_error)

        plugin_instance = SearchExecutionService._resolve_plugin_instance(plugin_config)
        if (
            enforce_chat_filter
            and plugin_instance is not None
            and hasattr(plugin_instance, "is_chat_enabled")
        ):
            if not plugin_instance.is_chat_enabled(
                stream_id=request.stream_id,
                group_id=request.group_id,
                user_id=request.user_id,
            ):
                logger.info(
                    "检索请求被聊天过滤拦截: "
                    f"caller={request.caller}, "
                    f"stream_id={request.stream_id}"
                )
                return SearchExecutionResult(
                    success=True,
                    query_type=query_type,
                    query=query,
                    top_k=top_k,
                    time_from=request.time_from,
                    time_to=request.time_to,
                    person=request.person,
                    source=request.source,
                    temporal=temporal,
                    results=[],
                    elapsed_ms=0.0,
                    chat_filtered=True,
                    dedup_hit=False,
                )

        request_key = SearchExecutionService._build_request_key(
            request=request,
            query_type=query_type,
            top_k=top_k,
            temporal=temporal,
        )

        async def _executor() -> Dict[str, Any]:
            # The PPR switch lives on the retriever's config, so it is
            # overridden for this call only and restored in ``finally``.
            original_ppr = bool(getattr(retriever.config, "enable_ppr", True))
            setattr(retriever.config, "enable_ppr", bool(request.enable_ppr))
            # Monotonic clock: wall-clock adjustments must not skew the duration.
            started_at = time.perf_counter()
            try:
                retrieved = await retriever.retrieve(
                    query=query,
                    top_k=top_k,
                    temporal=temporal,
                )

                should_apply_threshold = bool(request.use_threshold) and threshold_filter is not None
                if (
                    query_type == "time"
                    and not query
                    and bool(
                        _get_config_value(
                            plugin_config,
                            "retrieval.time.skip_threshold_when_query_empty",
                            True,
                        )
                    )
                ):
                    # A pure time query has no text to score against.
                    should_apply_threshold = False

                if should_apply_threshold:
                    retrieved = threshold_filter.filter(retrieved)

                if (
                    reinforce_access
                    and plugin_instance is not None
                    and hasattr(plugin_instance, "reinforce_access")
                ):
                    relation_hashes = [
                        item.hash_value
                        for item in retrieved
                        if getattr(item, "result_type", "") == "relation"
                    ]
                    if relation_hashes:
                        await plugin_instance.reinforce_access(relation_hashes)

                if query_type == "search":
                    graph_store = SearchExecutionService._resolve_runtime_component(
                        plugin_config, plugin_instance, "graph_store"
                    )
                    metadata_store = SearchExecutionService._resolve_runtime_component(
                        plugin_config, plugin_instance, "metadata_store"
                    )
                    fallback_enabled = bool(
                        _get_config_value(
                            plugin_config,
                            "retrieval.search.smart_fallback.enabled",
                            True,
                        )
                    )
                    fallback_threshold = float(
                        _get_config_value(
                            plugin_config,
                            "retrieval.search.smart_fallback.threshold",
                            0.6,
                        )
                    )
                    retrieved, fallback_triggered, fallback_added = maybe_apply_smart_path_fallback(
                        query=query,
                        results=list(retrieved),
                        graph_store=graph_store,
                        metadata_store=metadata_store,
                        enabled=fallback_enabled,
                        threshold=fallback_threshold,
                    )
                    if fallback_triggered:
                        logger.info(
                            "metric.smart_fallback_triggered_count=1 "
                            f"caller={request.caller} "
                            f"added={fallback_added}"
                        )

                dedup_enabled = bool(
                    _get_config_value(
                        plugin_config,
                        "retrieval.search.safe_content_dedup.enabled",
                        True,
                    )
                )
                if dedup_enabled:
                    retrieved, removed_count = apply_safe_content_dedup(list(retrieved))
                    if removed_count > 0:
                        logger.info(
                            f"metric.safe_dedup_removed_count={removed_count} "
                            f"caller={request.caller}"
                        )

                elapsed_ms = (time.perf_counter() - started_at) * 1000.0
                return {"results": retrieved, "elapsed_ms": elapsed_ms}
            finally:
                setattr(retriever.config, "enable_ppr", original_ppr)

        dedup_hit = False
        try:
            # Tuning evaluation needs a real execution on every round and
            # should avoid extra contention on the de-duplication lock.
            bypass_request_dedup = str(request.caller or "").strip().lower() == "retrieval_tuning"
            if (
                not bypass_request_dedup
                and plugin_instance is not None
                and hasattr(plugin_instance, "execute_request_with_dedup")
            ):
                dedup_hit, payload = await plugin_instance.execute_request_with_dedup(
                    request_key,
                    _executor,
                )
            else:
                payload = await _executor()
        except Exception as e:
            return SearchExecutionResult(success=False, error=f"知识检索失败: {e}")

        if dedup_hit:
            logger.info(f"metric.search_execution_dedup_hit_count=1 caller={request.caller}")

        return SearchExecutionResult(
            success=True,
            query_type=query_type,
            query=query,
            top_k=top_k,
            time_from=request.time_from,
            time_to=request.time_to,
            person=request.person,
            source=request.source,
            temporal=temporal,
            results=payload.get("results", []),
            elapsed_ms=float(payload.get("elapsed_ms", 0.0)),
            chat_filtered=False,
            dedup_hit=bool(dedup_hit),
        )

    @staticmethod
    def to_serializable_results(results: List[Any]) -> List[Dict[str, Any]]:
        """Convert result items into plain dicts.

        Each dict has the keys ``hash``, ``type``, ``score``, ``content`` and
        ``metadata``. ``metadata`` is a shallow copy that always contains a
        ``time_meta`` entry; a missing or ``None`` score becomes ``0.0``.
        """
        serialized: List[Dict[str, Any]] = []
        for item in results:
            metadata = dict(getattr(item, "metadata", {}) or {})
            if "time_meta" not in metadata:
                metadata["time_meta"] = {}
            raw_score = getattr(item, "score", 0.0)
            serialized.append(
                {
                    "hash": getattr(item, "hash_value", ""),
                    "type": getattr(item, "result_type", ""),
                    # float(None) would raise; treat an absent score as 0.0.
                    "score": float(raw_score) if raw_score is not None else 0.0,
                    "content": getattr(item, "content", ""),
                    "metadata": metadata,
                }
            )
        return serialized
|
||||
425
plugins/A_memorix/core/utils/summary_importer.py
Normal file
425
plugins/A_memorix/core/utils/summary_importer.py
Normal file
@@ -0,0 +1,425 @@
|
||||
"""
|
||||
聊天总结与知识导入工具
|
||||
|
||||
该模块负责从聊天记录中提取信息,生成总结,并将总结内容及提取的实体/关系
|
||||
导入到 A_memorix 的存储组件中。
|
||||
"""
|
||||
|
||||
import time
|
||||
import json
|
||||
import re
|
||||
import traceback
|
||||
from typing import List, Dict, Any, Tuple, Optional
|
||||
from pathlib import Path
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.services import llm_service as llm_api
|
||||
from src.services import message_service as message_api
|
||||
from src.config.config import global_config, model_config as host_model_config
|
||||
from src.config.model_configs import TaskConfig
|
||||
|
||||
from ..storage import (
|
||||
KnowledgeType,
|
||||
VectorStore,
|
||||
GraphStore,
|
||||
MetadataStore,
|
||||
resolve_stored_knowledge_type,
|
||||
)
|
||||
from ..embedding import EmbeddingAPIAdapter
|
||||
from .relation_write_service import RelationWriteService
|
||||
from .runtime_self_check import ensure_runtime_self_check, run_embedding_runtime_self_check
|
||||
|
||||
logger = get_logger("A_Memorix.SummaryImporter")
|
||||
|
||||
# 默认总结提示词模版
|
||||
SUMMARY_PROMPT_TEMPLATE = """
|
||||
你是 {bot_name}。{personality_context}
|
||||
现在你需要对以下一段聊天记录进行总结,并提取其中的重要知识。
|
||||
|
||||
聊天记录内容:
|
||||
{chat_history}
|
||||
|
||||
请完成以下任务:
|
||||
1. **生成总结**:以第三人称或机器人的视角,简洁明了地总结这段对话的主要内容、发生的事件或讨论的主题。
|
||||
2. **提取实体与关系**:识别并提取对话中提到的重要实体以及它们之间的关系。
|
||||
|
||||
请严格以 JSON 格式输出,格式如下:
|
||||
{{
|
||||
"summary": "总结文本内容",
|
||||
"entities": ["张三", "李四"],
|
||||
"relations": [
|
||||
{{"subject": "张三", "predicate": "认识", "object": "李四"}}
|
||||
]
|
||||
}}
|
||||
|
||||
注意:总结应具有叙事性,能够作为长程记忆的一部分。直接使用实体的实际名称,不要使用 e1/e2 等代号。
|
||||
"""
|
||||
|
||||
class SummaryImporter:
|
||||
"""总结并导入知识的工具类"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vector_store: VectorStore,
|
||||
graph_store: GraphStore,
|
||||
metadata_store: MetadataStore,
|
||||
embedding_manager: EmbeddingAPIAdapter,
|
||||
plugin_config: dict
|
||||
):
|
||||
self.vector_store = vector_store
|
||||
self.graph_store = graph_store
|
||||
self.metadata_store = metadata_store
|
||||
self.embedding_manager = embedding_manager
|
||||
self.plugin_config = plugin_config
|
||||
self.relation_write_service: Optional[RelationWriteService] = (
|
||||
plugin_config.get("relation_write_service")
|
||||
if isinstance(plugin_config, dict)
|
||||
else None
|
||||
)
|
||||
|
||||
def _normalize_summary_model_selectors(self, raw_value: Any) -> List[str]:
|
||||
"""标准化 summarization.model_name 配置(vNext 仅接受字符串数组)。"""
|
||||
if raw_value is None:
|
||||
return ["auto"]
|
||||
if isinstance(raw_value, list):
|
||||
selectors = [str(x).strip() for x in raw_value if str(x).strip()]
|
||||
return selectors or ["auto"]
|
||||
raise ValueError(
|
||||
"summarization.model_name 在 vNext 必须为 List[str]。"
|
||||
" 请执行 scripts/release_vnext_migrate.py migrate。"
|
||||
)
|
||||
|
||||
def _pick_default_summary_task(self, available_tasks: Dict[str, TaskConfig]) -> Tuple[Optional[str], Optional[TaskConfig]]:
|
||||
"""
|
||||
选择总结默认任务,避免错误落到 embedding 任务。
|
||||
优先级:replyer > utils > planner > tool_use > 其他非 embedding。
|
||||
"""
|
||||
preferred = ("replyer", "utils", "planner", "tool_use")
|
||||
for name in preferred:
|
||||
cfg = available_tasks.get(name)
|
||||
if cfg and cfg.model_list:
|
||||
return name, cfg
|
||||
|
||||
for name, cfg in available_tasks.items():
|
||||
if name != "embedding" and cfg.model_list:
|
||||
return name, cfg
|
||||
|
||||
for name, cfg in available_tasks.items():
|
||||
if cfg.model_list:
|
||||
return name, cfg
|
||||
|
||||
return None, None
|
||||
|
||||
def _resolve_summary_model_config(self) -> Optional[TaskConfig]:
    """Resolve ``summarization.model_name`` into a ``TaskConfig``.

    The setting is a list of selectors (a missing value means ``["auto"]``).
    Each selector may be:
    - ``"auto"``: the models of the default summary task
    - a task name such as ``"replyer"``: all models of that task
    - a concrete model name known to the host model config
    - ``"task:model"``: one model, validated against the task or the host
      model config; ``"task:"`` or ``"task:auto"`` means all models of the task

    Unknown selectors are logged and skipped. If nothing was selected, the
    default task (or else the first available task) supplies the models.
    The remaining settings of the returned config are copied from the first
    task referenced, falling back to the default or first available task.

    Returns:
        The assembled ``TaskConfig``, or ``None`` when no task or model is available.
    """
    task_table = llm_api.get_available_models()
    if not task_table:
        return None

    configured = self.plugin_config.get("summarization", {}).get("model_name", "auto")
    selectors = self._normalize_summary_model_selectors(configured)
    _, fallback_cfg = self._pick_default_summary_task(task_table)

    chosen: List[str] = []
    template: Optional[TaskConfig] = None
    known_models = getattr(host_model_config, "models_dict", {})

    def _take(names: List[str]):
        # Append in order, skipping blanks and anything already chosen.
        for candidate in names:
            if candidate and candidate not in chosen:
                chosen.append(candidate)

    for raw_selector in selectors:
        selector = raw_selector.strip()
        if not selector:
            continue

        if selector.lower() == "auto":
            if fallback_cfg:
                _take(fallback_cfg.model_list)
                if template is None:
                    template = fallback_cfg
            continue

        if ":" in selector:
            task_part, model_part = selector.split(":", 1)
            task_name = task_part.strip()
            model_name = model_part.strip()
            scoped_cfg = task_table.get(task_name)
            if not scoped_cfg:
                logger.warning(f"总结模型选择器 '{selector}' 的任务 '{task_name}' 不存在,已跳过")
                continue

            # The task becomes the template even if its model turns out to be unknown.
            if template is None:
                template = scoped_cfg

            if not model_name or model_name.lower() == "auto":
                _take(scoped_cfg.model_list)
            elif model_name in known_models or model_name in scoped_cfg.model_list:
                _take([model_name])
            else:
                logger.warning(f"总结模型选择器 '{selector}' 的模型 '{model_name}' 不存在,已跳过")
            continue

        named_cfg = task_table.get(selector)
        if named_cfg:
            _take(named_cfg.model_list)
            if template is None:
                template = named_cfg
            continue

        if selector in known_models:
            _take([selector])
            continue

        logger.warning(f"总结模型选择器 '{selector}' 无法识别,已跳过")

    if not chosen:
        # Nothing usable was configured: fall back to the default task, or
        # failing that to whichever task comes first.
        backup_cfg = fallback_cfg if fallback_cfg else next(iter(task_table.values()))
        _take(backup_cfg.model_list)
        if template is None:
            template = backup_cfg

    if not chosen:
        return None

    settings_source = template or fallback_cfg or next(iter(task_table.values()))
    return TaskConfig(
        model_list=chosen,
        max_tokens=settings_source.max_tokens,
        temperature=settings_source.temperature,
        slow_threshold=settings_source.slow_threshold,
        selection_strategy=settings_source.selection_strategy,
    )
|
||||
|
||||
async def import_from_stream(
    self,
    stream_id: str,
    context_length: Optional[int] = None,
    include_personality: Optional[bool] = None
) -> Tuple[bool, str]:
    """Summarize the recent history of a chat stream and import the result.

    Args:
        stream_id: Chat stream ID.
        context_length: Number of history messages to summarize; defaults to
            ``summarization.context_length`` from the plugin config (50).
        include_personality: Whether to add the bot personality to the prompt;
            defaults to ``summarization.include_personality`` (True).

    Returns:
        Tuple[bool, str]: ``(success, message)``. Every failure, including
        unexpected exceptions, is reported through the tuple, not raised.
    """

    def _message_epoch(msg: Any) -> Optional[float]:
        # Return the message time as epoch seconds, or None when unavailable.
        # NOTE(review): assumes ``timestamp`` is datetime-like (exposes
        # ``.timestamp()``) or numeric, and that older message objects carry a
        # numeric ``time`` instead; confirm against the message model.
        for attr in ("timestamp", "time"):
            raw = getattr(msg, attr, None)
            if raw is None:
                continue
            to_epoch = getattr(raw, "timestamp", None)
            try:
                return float(to_epoch() if callable(to_epoch) else raw)
            except (TypeError, ValueError):
                continue
        return None

    try:
        self_check_ok, self_check_msg = await self._ensure_runtime_self_check()
        if not self_check_ok:
            return False, f"导入前自检失败: {self_check_msg}"

        # 1. Resolve settings
        if context_length is None:
            context_length = self.plugin_config.get("summarization", {}).get("context_length", 50)

        if include_personality is None:
            include_personality = self.plugin_config.get("summarization", {}).get("include_personality", True)

        # 2. Fetch the messages that precede the current time
        now = time.time()
        messages = message_api.get_messages_before_time_in_chat(
            chat_id=stream_id,
            timestamp=now,
            limit=context_length
        )

        if not messages:
            return False, "未找到有效的聊天记录进行总结"

        # Render them as readable text
        chat_history_text = message_api.build_readable_messages(messages)

        # 3. Prepare the prompt inputs
        bot_name = global_config.bot.nickname or "机器人"
        personality_context = ""
        if include_personality:
            personality = getattr(global_config.bot, "personality", "")
            if personality:
                personality_context = f"你的性格设定是:{personality}"

        # 4. Call the LLM
        prompt = SUMMARY_PROMPT_TEMPLATE.format(
            bot_name=bot_name,
            personality_context=personality_context,
            chat_history=chat_history_text
        )

        model_config_to_use = self._resolve_summary_model_config()
        if model_config_to_use is None:
            return False, "未找到可用的总结模型配置"

        logger.info(f"正在为流 {stream_id} 执行总结,消息条数: {len(messages)}")
        logger.info(f"总结模型候选列表: {model_config_to_use.model_list}")

        success, response, _, _ = await llm_api.generate_with_model(
            prompt=prompt,
            model_config=model_config_to_use,
            request_type="A_Memorix.ChatSummarization"
        )

        if not success or not response:
            return False, "LLM 生成总结失败"

        # 5. Parse the reply
        data = self._parse_llm_response(response)
        summary_text = data.get("summary") if data else None
        # The failure message promises an emptiness check, so enforce it:
        # a blank or non-text summary must not be stored as a paragraph.
        if not isinstance(summary_text, str) or not summary_text.strip():
            return False, "解析 LLM 响应失败或总结为空"

        # The model may emit null or a wrong type for the optional lists.
        entities = data.get("entities")
        if not isinstance(entities, list):
            entities = []
        relations = data.get("relations")
        if not isinstance(relations, list):
            relations = []

        # Derive the event time range from the messages that carry a time.
        msg_times = [
            epoch
            for epoch in (_message_epoch(msg) for msg in messages)
            if epoch is not None
        ]
        time_meta = {}
        if msg_times:
            time_meta = {
                "event_time_start": min(msg_times),
                "event_time_end": max(msg_times),
                "time_granularity": "minute",
                "time_confidence": 0.95,
            }

        # 6. Write to the stores
        await self._execute_import(summary_text, entities, relations, stream_id, time_meta=time_meta)

        # 7. Persist
        self.vector_store.save()
        self.graph_store.save()

        result_msg = (
            f"✅ 总结导入成功\n"
            f"📝 总结长度: {len(summary_text)}\n"
            f"📌 提取实体: {len(entities)}\n"
            f"🔗 提取关系: {len(relations)}"
        )
        return True, result_msg

    except Exception as e:
        logger.error(f"总结导入过程中出错: {e}\n{traceback.format_exc()}")
        return False, f"错误: {str(e)}"
|
||||
|
||||
async def _ensure_runtime_self_check(self) -> Tuple[bool, str]:
    """Run the embedding runtime self-check and summarize its outcome.

    When ``plugin_config`` is a dict carrying a ``plugin_instance`` entry, the
    check is delegated to ``ensure_runtime_self_check`` for that instance;
    otherwise ``run_embedding_runtime_self_check`` is invoked directly with
    this object's config, vector store and embedding manager.

    Returns:
        ``(True, "")`` when the report's ``ok`` field is truthy, otherwise
        ``(False, detail)`` where ``detail`` is the report message followed by
        the configured / stored / encoded dimensions found in the report.
    """
    plugin_instance = None
    if isinstance(self.plugin_config, dict):
        plugin_instance = self.plugin_config.get("plugin_instance")

    if plugin_instance is None:
        # No plugin instance available: run the check against our own components.
        report = await run_embedding_runtime_self_check(
            config=self.plugin_config,
            vector_store=self.vector_store,
            embedding_manager=self.embedding_manager,
        )
    else:
        report = await ensure_runtime_self_check(plugin_instance)

    if report.get("ok", False):
        return True, ""

    message = report.get("message", "unknown")
    configured = report.get("configured_dimension", 0)
    stored = report.get("vector_store_dimension", 0)
    encoded = report.get("encoded_dimension", 0)
    detail = f"{message} (configured={configured}, store={stored}, encoded={encoded})"
    return False, detail
|
||||
|
||||
def _parse_llm_response(self, response: str) -> Dict[str, Any]:
|
||||
"""解析 LLM 返回的 JSON"""
|
||||
try:
|
||||
# 尝试查找 JSON
|
||||
json_match = re.search(r"\{.*\}", response, re.DOTALL)
|
||||
if json_match:
|
||||
return json.loads(json_match.group())
|
||||
return {}
|
||||
except Exception as e:
|
||||
logger.warning(f"解析总结 JSON 失败: {e}")
|
||||
return {}
|
||||
|
||||
async def _execute_import(
    self,
    summary: str,
    entities: List[str],
    relations: List[Dict[str, str]],
    stream_id: str,
    time_meta: Optional[Dict[str, Any]] = None,
):
    """Write a generated summary, its entities and its relations to storage.

    Steps performed:
      1. Resolve the knowledge type from ``summarization.default_knowledge_type``
         (falling back to ``KnowledgeType.NARRATIVE`` on an invalid value).
      2. Store the summary as a paragraph with source ``chat_summary:<stream_id>``
         and add its embedding to the vector store under the paragraph hash.
      3. Add the entities as graph nodes.
      4. Store each complete (subject, predicate, object) relation, either via
         ``relation_write_service`` when available, or directly through the
         metadata store and graph store.

    Args:
        summary: Summary text to store and embed.
        entities: Entity names extracted by the LLM.
        relations: Relation dicts with ``subject`` / ``predicate`` / ``object``
            keys. As this comes from LLM output, ``None`` and entries that are
            not dicts are tolerated and skipped.
        stream_id: Identifier of the chat stream the summary belongs to.
        time_meta: Optional time metadata forwarded to ``add_paragraph``.
    """
    # Resolve the default knowledge type from config.
    type_str = self.plugin_config.get("summarization", {}).get("default_knowledge_type", "narrative")
    try:
        knowledge_type = resolve_stored_knowledge_type(type_str, content=summary)
    except ValueError:
        logger.warning(f"非法 summarization.default_knowledge_type={type_str},回退 narrative")
        knowledge_type = KnowledgeType.NARRATIVE

    # Import the summary text.
    hash_value = self.metadata_store.add_paragraph(
        content=summary,
        source=f"chat_summary:{stream_id}",
        knowledge_type=knowledge_type.value,
        time_meta=time_meta,
    )

    embedding = await self.embedding_manager.encode(summary)
    self.vector_store.add(
        vectors=embedding.reshape(1, -1),
        ids=[hash_value],
    )

    # Import entities.
    if entities:
        self.graph_store.add_nodes(entities)

    # Import relations.
    rv_cfg = self.plugin_config.get("retrieval", {}).get("relation_vectorization", {})
    if not isinstance(rv_cfg, dict):
        rv_cfg = {}
    write_vector = bool(rv_cfg.get("enabled", False)) and bool(rv_cfg.get("write_on_import", True))

    for rel in relations or []:
        if not isinstance(rel, dict):
            # Malformed LLM output must not abort the import after the
            # paragraph has already been written.
            logger.warning(f"跳过格式非法的关系条目: {rel!r}")
            continue

        s, p, o = rel.get("subject"), rel.get("predicate"), rel.get("object")
        if not (s and p and o):
            continue

        if self.relation_write_service is not None:
            await self.relation_write_service.upsert_relation_with_vector(
                subject=s,
                predicate=p,
                obj=o,
                confidence=1.0,
                source_paragraph=summary,
                write_vector=write_vector,
            )
            continue

        # Fallback path: write relation metadata directly.
        rel_hash = self.metadata_store.add_relation(
            subject=s,
            predicate=p,
            obj=o,
            confidence=1.0,
            source_paragraph=summary,
        )
        # Write the edge with its relation hash so the relation can later be
        # pruned precisely.
        self.graph_store.add_edges([(s, o)], relation_hashes=[rel_hash])
        try:
            self.metadata_store.set_relation_vector_state(rel_hash, "none")
        except Exception as exc:
            # Best-effort bookkeeping: keep importing, but leave a trace
            # instead of silently discarding the failure.
            logger.warning(f"设置关系向量状态失败: hash={rel_hash} error={exc}")

    logger.info(f"总结导入完成: hash={hash_value[:8]}")
|
||||
3522
plugins/A_memorix/core/utils/web_import_manager.py
Normal file
3522
plugins/A_memorix/core/utils/web_import_manager.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user