添加 A_Memorix 插件 v2.0.0(包含运行时与文档)

引入 A_Memorix 插件 v2.0.0:新增大量运行时组件、存储/模式更新、检索能力提升、管理工具、导入/调优工作流以及相关文档。关键新增内容包括:lifecycle_orchestrator、SDKMemoryKernel/运行时初始化器、新的存储层与 metadata_store 变更(SCHEMA_VERSION v8)、检索增强(双路径检索、图关系召回、稀疏 BM25),以及多种工具服务(episode/person_profile/relation/segmentation/tuning/search execution)。同时新增 Web 导入/摘要导入器及大量维护脚本。还更新了插件清单、embedding API 适配器、plugin.py、requirements/pyproject,以及主入口文件,使新插件接入项目。该变更为 2.0.0 版本发布做好准备,实现统一的 SDK Tool 接口并扩展整体运行能力。
This commit is contained in:
DawnARC
2026-03-19 00:09:04 +08:00
parent eb257345dd
commit 71b3a828c6
44 changed files with 18193 additions and 405 deletions

View File

@@ -302,7 +302,7 @@ class AggregateQueryService:
)
for (branch_name, _), payload in zip(scheduled, done):
if isinstance(payload, Exception):
logger.error("aggregate branch failed: branch=%s error=%s", branch_name, payload)
logger.error(f"aggregate branch failed: branch={branch_name} error={payload}")
normalized = self._normalize_branch_payload(
branch_name,
{

View File

@@ -70,7 +70,7 @@ class EpisodeRetrievalService:
temporal=temporal,
)
except Exception as exc:
logger.warning("episode evidence retrieval failed, fallback to lexical only: %s", exc)
logger.warning(f"episode evidence retrieval failed, fallback to lexical only: {exc}")
else:
paragraph_rank_map: Dict[str, int] = {}
relation_rank_map: Dict[str, int] = {}

View File

@@ -0,0 +1,304 @@
"""
Episode 语义切分服务(LLM 主路径)。
职责:
1. 组装语义切分提示词
2. 调用 LLM 生成结构化 episode JSON
3. 严格校验输出结构,返回标准化结果
"""
from __future__ import annotations
import json
from typing import Any, Dict, List, Optional, Tuple
from src.common.logger import get_logger
from src.config.model_configs import TaskConfig
from src.config.config import model_config as host_model_config
from src.services import llm_service as llm_api
logger = get_logger("A_Memorix.EpisodeSegmentationService")
class EpisodeSegmentationService:
    """LLM-backed semantic segmentation of paragraphs into episodes.

    Responsibilities:
    1. Build the segmentation prompt.
    2. Call the LLM to produce structured episode JSON.
    3. Strictly validate the output and return a normalized result.
    """

    SEGMENTATION_VERSION = "episode_mvp_v1"

    def __init__(self, plugin_config: Optional[dict] = None):
        """Store the (possibly empty) nested plugin configuration dict."""
        self.plugin_config = plugin_config or {}

    def _cfg(self, key: str, default: Any = None) -> Any:
        """Look up a dotted ``key`` in the nested config dict.

        Returns ``default`` as soon as any path segment is missing or the
        current node is not a dict.
        """
        current: Any = self.plugin_config
        for part in key.split("."):
            if isinstance(current, dict) and part in current:
                current = current[part]
            else:
                return default
        return current

    @staticmethod
    def _is_task_config(obj: Any) -> bool:
        """Return True when ``obj`` exposes a non-empty ``model_list``."""
        return hasattr(obj, "model_list") and bool(getattr(obj, "model_list", []))

    def _build_single_model_task(self, model_name: str, template: TaskConfig) -> TaskConfig:
        """Clone ``template``'s sampling settings onto a task pinned to one model."""
        return TaskConfig(
            model_list=[model_name],
            max_tokens=template.max_tokens,
            temperature=template.temperature,
            slow_threshold=template.slow_threshold,
            selection_strategy=template.selection_strategy,
        )

    def _pick_template_task(self, available_tasks: Dict[str, Any]) -> Optional[TaskConfig]:
        """Pick a usable task config.

        Preference order: the well-known task names, then any task other than
        ``embedding``, then any task at all. Returns ``None`` if nothing is usable.
        """
        preferred = ("utils", "replyer", "planner", "tool_use")
        for task_name in preferred:
            cfg = available_tasks.get(task_name)
            if self._is_task_config(cfg):
                return cfg
        for task_name, cfg in available_tasks.items():
            if task_name != "embedding" and self._is_task_config(cfg):
                return cfg
        for cfg in available_tasks.values():
            if self._is_task_config(cfg):
                return cfg
        return None

    def _resolve_model_config(self) -> Tuple[Optional[Any], str]:
        """Resolve the model/task config to use for segmentation.

        Returns:
            ``(config, label)`` where ``label`` is the selector or task name
            that was chosen, ``"auto"`` for the generic fallback, or
            ``(None, "unavailable")`` when no model can be used.
        """
        available_tasks = llm_api.get_available_models() or {}
        if not available_tasks:
            return None, "unavailable"

        selector = str(self._cfg("episode.segmentation_model", "auto") or "auto").strip()
        model_dict = getattr(host_model_config, "models_dict", {}) or {}

        if selector and selector.lower() != "auto":
            # The selector may name a task directly ...
            direct_task = available_tasks.get(selector)
            if self._is_task_config(direct_task):
                return direct_task, selector
            # ... or a single model, which is wrapped using a template task.
            if selector in model_dict:
                template = self._pick_template_task(available_tasks)
                if template is not None:
                    return self._build_single_model_task(selector, template), selector
            logger.warning(f"episode.segmentation_model='{selector}' 不可用,回退 auto")

        for task_name in ("utils", "replyer", "planner", "tool_use"):
            cfg = available_tasks.get(task_name)
            if self._is_task_config(cfg):
                return cfg, task_name

        fallback = self._pick_template_task(available_tasks)
        if fallback is not None:
            return fallback, "auto"
        return None, "unavailable"

    @staticmethod
    def _clamp_score(value: Any, default: float = 0.0) -> float:
        """Coerce ``value`` to a float clamped into ``[0.0, 1.0]``.

        Unparseable values and NaN are replaced by ``default`` before clamping.
        """
        try:
            num = float(value)
        except Exception:
            num = default
        # NaN compares False against everything, so it would slip through both
        # range checks below; json.loads also accepts a bare ``NaN`` literal.
        if num != num:
            num = default
        if num < 0.0:
            return 0.0
        if num > 1.0:
            return 1.0
        return num

    @staticmethod
    def _clean_str_list(raw: Any, limit: int) -> List[str]:
        """Normalize an LLM-provided list field into unique, stripped strings.

        Anything that is not a list/tuple yields ``[]``; this prevents a bare
        string from being iterated character by character.
        """
        if not isinstance(raw, (list, tuple)):
            return []
        cleaned: List[str] = []
        seen = set()
        for value in raw:
            token = str(value or "").strip()
            if token and token not in seen:
                seen.add(token)
                cleaned.append(token)
        return cleaned[:limit]

    @staticmethod
    def _safe_json_loads(text: str) -> Dict[str, Any]:
        """Parse a JSON object out of a raw LLM response.

        Handles fenced code blocks and surrounding chatter by falling back to
        the outermost ``{...}`` span.

        Raises:
            ValueError: ``"empty_response"`` for blank input, otherwise
                ``"invalid_json_response"`` when no JSON object can be parsed.
        """
        raw = str(text or "").strip()
        if not raw:
            raise ValueError("empty_response")

        if "```" in raw:
            raw = raw.replace("```json", "```").replace("```JSON", "```")
            for part in raw.split("```"):
                part = part.strip()
                if part.startswith("{") and part.endswith("}"):
                    raw = part
                    break

        try:
            data = json.loads(raw)
            if isinstance(data, dict):
                return data
        except (ValueError, RecursionError):
            pass

        # Second attempt: salvage the outermost brace-delimited span.
        start = raw.find("{")
        end = raw.rfind("}")
        if start >= 0 and end > start:
            candidate = raw[start : end + 1]
            try:
                data = json.loads(candidate)
            except (ValueError, RecursionError) as exc:
                # Surface the same stable error code as every other failure.
                raise ValueError("invalid_json_response") from exc
            if isinstance(data, dict):
                return data
        raise ValueError("invalid_json_response")

    def _build_prompt(
        self,
        *,
        source: str,
        window_start: Optional[float],
        window_end: Optional[float],
        paragraphs: List[Dict[str, Any]],
    ) -> str:
        """Render the segmentation prompt for one paragraph group.

        Each paragraph's content is normalized to ``\\n`` newlines and truncated
        to 800 characters.
        """
        rows: List[str] = []
        for idx, item in enumerate(paragraphs, 1):
            p_hash = str(item.get("hash", "") or "").strip()
            content = str(item.get("content", "") or "").strip().replace("\r\n", "\n")
            content = content[:800]
            event_start = item.get("event_time_start")
            event_end = item.get("event_time_end")
            event_time = item.get("event_time")
            rows.append(
                (
                    f"[{idx}] hash={p_hash}\n"
                    f"event_time={event_time}\n"
                    f"event_time_start={event_start}\n"
                    f"event_time_end={event_end}\n"
                    f"content={content}"
                )
            )

        source_text = str(source or "").strip() or "unknown"
        return (
            "You are an episode segmentation engine.\n"
            "Group the given paragraphs into one or more coherent episodes.\n"
            "Return JSON ONLY. No markdown, no explanation.\n"
            "\n"
            "Hard JSON schema:\n"
            "{\n"
            ' "episodes": [\n'
            " {\n"
            ' "title": "string",\n'
            ' "summary": "string",\n'
            ' "paragraph_hashes": ["hash1", "hash2"],\n'
            ' "participants": ["person1", "person2"],\n'
            ' "keywords": ["kw1", "kw2"],\n'
            ' "time_confidence": 0.0,\n'
            ' "llm_confidence": 0.0\n'
            " }\n"
            " ]\n"
            "}\n"
            "\n"
            "Rules:\n"
            "1) paragraph_hashes must come from input only.\n"
            "2) title and summary must be non-empty.\n"
            "3) keep participants/keywords concise and deduplicated.\n"
            "4) if uncertain, still provide best effort confidence values.\n"
            "\n"
            f"source={source_text}\n"
            f"window_start={window_start}\n"
            f"window_end={window_end}\n"
            "paragraphs:\n"
            + "\n\n".join(rows)
        )

    def _normalize_episodes(
        self,
        *,
        payload: Dict[str, Any],
        input_hashes: List[str],
    ) -> List[Dict[str, Any]]:
        """Validate and normalize the ``episodes`` array of an LLM payload.

        Episodes are dropped when they are not dicts, lack a title or summary,
        or reference no hash from ``input_hashes``.

        Raises:
            ValueError: ``"episodes_missing_or_not_list"`` when the field is
                absent or mistyped, ``"episodes_all_invalid"`` when every
                episode was dropped.
        """
        raw_episodes = payload.get("episodes")
        if not isinstance(raw_episodes, list):
            raise ValueError("episodes_missing_or_not_list")

        valid_hashes = set(input_hashes)
        normalized: List[Dict[str, Any]] = []
        for item in raw_episodes:
            if not isinstance(item, dict):
                continue
            title = str(item.get("title", "") or "").strip()
            summary = str(item.get("summary", "") or "").strip()
            if not title or not summary:
                continue

            raw_hashes = item.get("paragraph_hashes")
            if not isinstance(raw_hashes, list):
                continue
            # Keep only known hashes, first occurrence wins.
            dedup_hashes: List[str] = []
            seen_hashes = set()
            for h in raw_hashes:
                token = str(h or "").strip()
                if not token or token in seen_hashes or token not in valid_hashes:
                    continue
                seen_hashes.add(token)
                dedup_hashes.append(token)
            if not dedup_hashes:
                continue

            normalized.append(
                {
                    "title": title,
                    "summary": summary,
                    "paragraph_hashes": dedup_hashes,
                    "participants": self._clean_str_list(item.get("participants"), 16),
                    "keywords": self._clean_str_list(item.get("keywords"), 20),
                    "time_confidence": self._clamp_score(item.get("time_confidence"), default=1.0),
                    "llm_confidence": self._clamp_score(item.get("llm_confidence"), default=0.5),
                }
            )

        if not normalized:
            raise ValueError("episodes_all_invalid")
        return normalized

    async def segment(
        self,
        *,
        source: str,
        window_start: Optional[float],
        window_end: Optional[float],
        paragraphs: List[Dict[str, Any]],
    ) -> Dict[str, Any]:
        """Segment ``paragraphs`` into episodes with the configured LLM.

        Returns:
            A dict with ``episodes``, ``segmentation_model`` and
            ``segmentation_version``.

        Raises:
            ValueError: for empty input or an unusable LLM response.
            RuntimeError: when no model is available or generation fails.
        """
        if not paragraphs:
            raise ValueError("paragraphs_empty")

        model_config, model_label = self._resolve_model_config()
        if model_config is None:
            raise RuntimeError("episode segmentation model unavailable")

        prompt = self._build_prompt(
            source=source,
            window_start=window_start,
            window_end=window_end,
            paragraphs=paragraphs,
        )
        success, response, _, _ = await llm_api.generate_with_model(
            prompt=prompt,
            model_config=model_config,
            request_type="A_Memorix.EpisodeSegmentation",
        )
        if not success or not response:
            raise RuntimeError("llm_generate_failed")

        payload = self._safe_json_loads(str(response))
        input_hashes = [str(p.get("hash", "") or "").strip() for p in paragraphs]
        episodes = self._normalize_episodes(payload=payload, input_hashes=input_hashes)
        return {
            "episodes": episodes,
            "segmentation_model": model_label,
            "segmentation_version": self.SEGMENTATION_VERSION,
        }

View File

@@ -0,0 +1,558 @@
"""
Episode 聚合与落库服务。
流程:
1. 从 pending 队列读取段落并组批
2. 按 source + 时间窗口切组
3. 调用 LLM 语义切分
4. 写入 episodes + episode_paragraphs
5. LLM 失败时使用确定性 fallback
"""
from __future__ import annotations
import json
import re
from collections import Counter
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
from src.common.logger import get_logger
from .episode_segmentation_service import EpisodeSegmentationService
from .hash import compute_hash
logger = get_logger("A_Memorix.EpisodeService")
class EpisodeService:
    """Background service that aggregates paragraphs into episodes and stores them.

    Flow:
    1. Expand pending queue rows into paragraph contexts.
    2. Group paragraphs by source and time window.
    3. Ask the LLM segmentation service to split each group into episodes.
    4. Persist episodes and their paragraph bindings.
    5. Fall back to a deterministic rule-based episode if the LLM path fails.
    """

    # Compiled once: word-like tokens of ASCII letters/digits/underscore or CJK
    # ideographs, at least two characters long.
    _KEYWORD_TOKEN_PATTERN = re.compile(r"[A-Za-z0-9_\u4e00-\u9fff]{2,}")
    _KEYWORD_STOP_WORDS = frozenset(
        {
            "the",
            "and",
            "that",
            "this",
            "with",
            "from",
            "for",
            "have",
            "will",
            "your",
            "you",
            "我们",
            "你们",
            "他们",
            "以及",
            "一个",
            "这个",
            "那个",
            "然后",
            "因为",
            "所以",
        }
    )

    def __init__(
        self,
        *,
        metadata_store: Any,
        plugin_config: Optional[Any] = None,
        segmentation_service: Optional[EpisodeSegmentationService] = None,
    ):
        """Wire up the store, config and (optionally injected) segmentation service."""
        self.metadata_store = metadata_store
        self.plugin_config = plugin_config or {}
        self.segmentation_service = segmentation_service or EpisodeSegmentationService(
            plugin_config=self._config_dict(),
        )

    def _config_dict(self) -> Dict[str, Any]:
        """Return the plugin config if it is a plain dict, else an empty dict."""
        if isinstance(self.plugin_config, dict):
            return self.plugin_config
        return {}

    def _cfg(self, key: str, default: Any = None) -> Any:
        """Read a dotted config ``key``.

        Uses ``plugin_config.get_config(key, default)`` when the config object
        provides it, otherwise walks nested dicts.
        """
        getter = getattr(self.plugin_config, "get_config", None)
        if callable(getter):
            return getter(key, default)
        current: Any = self.plugin_config
        for part in key.split("."):
            if isinstance(current, dict) and part in current:
                current = current[part]
            else:
                return default
        return current

    @staticmethod
    def _to_optional_float(value: Any) -> Optional[float]:
        """Convert ``value`` to float, returning ``None`` if missing or unparseable."""
        if value is None:
            return None
        try:
            return float(value)
        except (TypeError, ValueError, OverflowError):
            return None

    @staticmethod
    def _clamp_score(value: Any, default: float = 1.0) -> float:
        """Coerce ``value`` to a float clamped into ``[0.0, 1.0]``.

        Unparseable values and NaN are replaced by ``default`` before clamping.
        """
        try:
            num = float(value)
        except Exception:
            num = default
        # NaN fails both comparisons below and would otherwise be returned as-is.
        if num != num:
            num = default
        if num < 0.0:
            return 0.0
        if num > 1.0:
            return 1.0
        return num

    @staticmethod
    def _paragraph_anchor(paragraph: Dict[str, Any]) -> float:
        """Return the first usable timestamp of a paragraph, or ``0.0``.

        Fields are tried in order: ``event_time_end``, ``event_time_start``,
        ``event_time``, ``created_at``.
        """
        for key in ("event_time_end", "event_time_start", "event_time", "created_at"):
            value = paragraph.get(key)
            if value is None:
                continue
            try:
                return float(value)
            except (TypeError, ValueError, OverflowError):
                continue
        return 0.0

    @staticmethod
    def _paragraph_sort_key(paragraph: Dict[str, Any]) -> Tuple[float, str]:
        """Sort key: anchor timestamp first, then hash for a stable order."""
        return (
            EpisodeService._paragraph_anchor(paragraph),
            str(paragraph.get("hash", "") or ""),
        )

    def load_pending_paragraphs(
        self,
        pending_rows: List[Dict[str, Any]],
    ) -> Tuple[List[Dict[str, Any]], List[str]]:
        """Expand pending rows into paragraph context dicts.

        Returns:
            ``(loaded_paragraphs, missing_hashes)`` where ``missing_hashes``
            lists the hashes the metadata store could not resolve.
        """
        loaded: List[Dict[str, Any]] = []
        missing: List[str] = []
        for row in pending_rows or []:
            p_hash = str(row.get("paragraph_hash", "") or "").strip()
            if not p_hash:
                continue
            paragraph = self.metadata_store.get_paragraph(p_hash)
            if not paragraph:
                missing.append(p_hash)
                continue
            loaded.append(
                {
                    "hash": p_hash,
                    "source": str(row.get("source") or paragraph.get("source") or "").strip(),
                    "content": str(paragraph.get("content", "") or ""),
                    "created_at": self._to_optional_float(paragraph.get("created_at"))
                    or self._to_optional_float(row.get("created_at"))
                    or 0.0,
                    "event_time": self._to_optional_float(paragraph.get("event_time")),
                    "event_time_start": self._to_optional_float(paragraph.get("event_time_start")),
                    "event_time_end": self._to_optional_float(paragraph.get("event_time_end")),
                    "time_granularity": str(paragraph.get("time_granularity", "") or "").strip() or None,
                    "time_confidence": self._clamp_score(paragraph.get("time_confidence"), default=1.0),
                }
            )
        return loaded, missing

    def group_paragraphs(self, paragraphs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Batch paragraphs by source and time proximity.

        A new group is started when the current one reaches the paragraph
        limit, would exceed the character limit, or when the gap to the
        previous paragraph's anchor exceeds the configured time window.
        Groups are returned ordered by the anchor of their first paragraph.
        """
        if not paragraphs:
            return []

        max_paragraphs = max(1, int(self._cfg("episode.max_paragraphs_per_call", 20)))
        max_chars = max(200, int(self._cfg("episode.max_chars_per_call", 6000)))
        window_seconds = max(
            60.0,
            float(self._cfg("episode.source_time_window_hours", 24)) * 3600.0,
        )

        by_source: Dict[str, List[Dict[str, Any]]] = {}
        for paragraph in paragraphs:
            source = str(paragraph.get("source", "") or "").strip()
            by_source.setdefault(source, []).append(paragraph)

        groups: List[Dict[str, Any]] = []
        for source, items in by_source.items():
            current: List[Dict[str, Any]] = []
            current_chars = 0
            last_anchor: Optional[float] = None
            # Paragraphs are appended in sorted order, so every emitted group
            # is already sorted by the same key.
            for paragraph in sorted(items, key=self._paragraph_sort_key):
                anchor = self._paragraph_anchor(paragraph)
                content_len = len(str(paragraph.get("content", "") or ""))
                if current and (
                    len(current) >= max_paragraphs
                    or current_chars + content_len > max_chars
                    or (last_anchor is not None and abs(anchor - last_anchor) > window_seconds)
                ):
                    groups.append({"source": source, "paragraphs": current})
                    current = []
                    current_chars = 0
                current.append(paragraph)
                current_chars += content_len
                last_anchor = anchor
            if current:
                groups.append({"source": source, "paragraphs": current})

        groups.sort(
            key=lambda g: self._paragraph_anchor(g["paragraphs"][0]) if g.get("paragraphs") else 0.0
        )
        return groups

    def _compute_time_meta(
        self, paragraphs: List[Dict[str, Any]]
    ) -> Tuple[Optional[float], Optional[float], Optional[str], float]:
        """Aggregate time metadata over ``paragraphs``.

        Returns:
            ``(time_start, time_end, granularity, time_confidence)``: the
            earliest start, latest end, the finest granularity seen
            (minute > hour > day > month > year) and the mean confidence.
        """
        starts: List[float] = []
        ends: List[float] = []
        granularity_priority = {
            "minute": 4,
            "hour": 3,
            "day": 2,
            "month": 1,
            "year": 0,
        }
        granularity = None
        granularity_rank = -1
        conf_values: List[float] = []

        for p in paragraphs:
            s = self._to_optional_float(p.get("event_time_start"))
            e = self._to_optional_float(p.get("event_time_end"))
            t = self._to_optional_float(p.get("event_time"))
            c = self._to_optional_float(p.get("created_at"))
            # Each bound falls back through the other event fields and finally
            # to the creation time.
            start_candidate = s if s is not None else (t if t is not None else (e if e is not None else c))
            end_candidate = e if e is not None else (t if t is not None else (s if s is not None else c))
            if start_candidate is not None:
                starts.append(start_candidate)
            if end_candidate is not None:
                ends.append(end_candidate)

            g = str(p.get("time_granularity", "") or "").strip().lower()
            if g in granularity_priority and granularity_priority[g] > granularity_rank:
                granularity_rank = granularity_priority[g]
                granularity = g
            conf_values.append(self._clamp_score(p.get("time_confidence"), default=1.0))

        time_start = min(starts) if starts else None
        time_end = max(ends) if ends else None
        time_conf = sum(conf_values) / len(conf_values) if conf_values else 1.0
        return time_start, time_end, granularity, self._clamp_score(time_conf, default=1.0)

    def _collect_participants(self, paragraph_hashes: List[str], limit: int = 16) -> List[str]:
        """Collect up to ``limit`` distinct entity names (case-insensitive) for the hashes.

        Store lookup failures are treated as "no entities" (best effort).
        """
        seen = set()
        participants: List[str] = []
        for p_hash in paragraph_hashes:
            try:
                entities = self.metadata_store.get_paragraph_entities(p_hash)
            except Exception:
                entities = []
            for item in entities or []:
                name = str(item.get("name", "") or "").strip()
                if not name:
                    continue
                key = name.lower()
                if key in seen:
                    continue
                seen.add(key)
                participants.append(name)
                if len(participants) >= limit:
                    return participants
        return participants

    @staticmethod
    def _derive_keywords(paragraphs: List[Dict[str, Any]], limit: int = 12) -> List[str]:
        """Return the ``limit`` most frequent non-stop-word tokens in the contents."""
        token_counter: Counter[str] = Counter()
        pattern = EpisodeService._KEYWORD_TOKEN_PATTERN
        stop_words = EpisodeService._KEYWORD_STOP_WORDS
        for p in paragraphs:
            text = str(p.get("content", "") or "").lower()
            token_counter.update(tok for tok in pattern.findall(text) if tok not in stop_words)
        return [token for token, _ in token_counter.most_common(limit)]

    def _build_fallback_episode(self, group: Dict[str, Any]) -> Dict[str, Any]:
        """Build a deterministic episode covering the whole group.

        Used when LLM segmentation fails; the summary is stitched from the
        first three paragraphs and ``llm_confidence`` is ``0.0``.
        """
        paragraphs = group.get("paragraphs", []) or []
        source = str(group.get("source", "") or "").strip()
        hashes = [
            token
            for token in (str(p.get("hash", "") or "").strip() for p in paragraphs)
            if token
        ]

        snippets = []
        for p in paragraphs[:3]:
            text = str(p.get("content", "") or "").strip().replace("\n", " ")
            if text:
                snippets.append(text[:140])
        summary = "".join(snippets)[:500] if snippets else "自动回退生成的情景记忆。"

        time_start, time_end, granularity, time_conf = self._compute_time_meta(paragraphs)
        participants = self._collect_participants(hashes, limit=12)
        keywords = self._derive_keywords(paragraphs, limit=10)

        day_text = ""
        if time_start is not None:
            try:
                # NOTE(review): renders the day in the process's local timezone.
                day_text = datetime.fromtimestamp(time_start).strftime("%Y-%m-%d")
            except (OverflowError, OSError, ValueError):
                # Out-of-range timestamps must not break the fallback path.
                day_text = ""
        if day_text:
            title = f"{source or 'unknown'} {day_text} 情景片段"
        else:
            title = f"{source or 'unknown'} 情景片段"

        return {
            "title": title[:80],
            "summary": summary,
            "paragraph_hashes": hashes,
            "participants": participants,
            "keywords": keywords,
            "time_confidence": time_conf,
            "llm_confidence": 0.0,
            "event_time_start": time_start,
            "event_time_end": time_end,
            "time_granularity": granularity,
            "segmentation_model": "fallback_rule",
            "segmentation_version": EpisodeSegmentationService.SEGMENTATION_VERSION,
        }

    @staticmethod
    def _normalize_episode_hashes(episode_hashes: List[str], group_hashes_ordered: List[str]) -> List[str]:
        """Strip, de-duplicate and filter ``episode_hashes`` to those present in the group."""
        in_group = set(group_hashes_ordered)
        dedup: List[str] = []
        seen = set()
        for h in episode_hashes or []:
            token = str(h or "").strip()
            if not token or token not in in_group or token in seen:
                continue
            seen.add(token)
            dedup.append(token)
        return dedup

    async def _build_episode_payloads_for_group(self, group: Dict[str, Any]) -> Dict[str, Any]:
        """Segment one group and build the episode payloads to store (no writes).

        Returns:
            A dict with ``payloads``, ``done_hashes`` (every hash in the
            group), ``episode_count`` and ``fallback_count`` (0 or 1).
        """
        paragraphs = group.get("paragraphs", []) or []
        if not paragraphs:
            return {
                "payloads": [],
                "done_hashes": [],
                "episode_count": 0,
                "fallback_count": 0,
            }

        source = str(group.get("source", "") or "").strip()
        group_hashes = [
            token
            for token in (str(p.get("hash", "") or "").strip() for p in paragraphs)
            if token
        ]
        group_start, group_end, _, _ = self._compute_time_meta(paragraphs)

        fallback_used = False
        segmentation_model = "fallback_rule"
        segmentation_version = EpisodeSegmentationService.SEGMENTATION_VERSION
        try:
            llm_result = await self.segmentation_service.segment(
                source=source,
                window_start=group_start,
                window_end=group_end,
                paragraphs=paragraphs,
            )
            episodes = list(llm_result.get("episodes") or [])
            segmentation_model = str(llm_result.get("segmentation_model", "") or "").strip() or "auto"
            segmentation_version = (
                str(llm_result.get("segmentation_version", "") or "").strip()
                or EpisodeSegmentationService.SEGMENTATION_VERSION
            )
            if not episodes:
                raise ValueError("llm_empty_episodes")
        except Exception as e:
            # Any LLM-path failure degrades to the single rule-based episode.
            logger.warning(
                "Episode segmentation fallback: "
                f"source={source} "
                f"size={len(group_hashes)} "
                f"err={e}"
            )
            episodes = [self._build_fallback_episode(group)]
            fallback_used = True

        stored_payloads: List[Dict[str, Any]] = []
        for episode in episodes:
            ordered_hashes = self._normalize_episode_hashes(
                episode_hashes=episode.get("paragraph_hashes", []),
                group_hashes_ordered=group_hashes,
            )
            if not ordered_hashes:
                continue

            # Build the lookup set once instead of once per paragraph.
            hash_lookup = set(ordered_hashes)
            sub_paragraphs = [p for p in paragraphs if str(p.get("hash", "") or "") in hash_lookup]
            event_start, event_end, granularity, time_conf_default = self._compute_time_meta(sub_paragraphs)

            participants = [str(x).strip() for x in (episode.get("participants", []) or []) if str(x).strip()]
            keywords = [str(x).strip() for x in (episode.get("keywords", []) or []) if str(x).strip()]
            if not participants:
                participants = self._collect_participants(ordered_hashes, limit=16)
            if not keywords:
                keywords = self._derive_keywords(sub_paragraphs, limit=12)

            title = str(episode.get("title", "") or "").strip()[:120]
            summary = str(episode.get("summary", "") or "").strip()[:2000]
            if not title or not summary:
                continue

            # The id is derived from source, evidence hashes and version, so
            # re-processing the same evidence yields the same episode id.
            seed = json.dumps(
                {
                    "source": source,
                    "hashes": ordered_hashes,
                    "version": segmentation_version,
                },
                ensure_ascii=False,
                sort_keys=True,
            )
            episode_id = compute_hash(seed)

            payload = {
                "episode_id": episode_id,
                "source": source or None,
                "title": title,
                "summary": summary,
                "event_time_start": episode.get("event_time_start", event_start),
                "event_time_end": episode.get("event_time_end", event_end),
                "time_granularity": episode.get("time_granularity", granularity),
                "time_confidence": self._clamp_score(
                    episode.get("time_confidence"),
                    default=time_conf_default,
                ),
                "participants": participants[:16],
                "keywords": keywords[:20],
                "evidence_ids": ordered_hashes,
                "paragraph_count": len(ordered_hashes),
                "llm_confidence": self._clamp_score(
                    episode.get("llm_confidence"),
                    default=0.0 if fallback_used else 0.6,
                ),
                "segmentation_model": (
                    str(episode.get("segmentation_model", "") or "").strip()
                    or ("fallback_rule" if fallback_used else segmentation_model)
                ),
                "segmentation_version": (
                    str(episode.get("segmentation_version", "") or "").strip()
                    or segmentation_version
                ),
            }
            stored_payloads.append(payload)

        return {
            "payloads": stored_payloads,
            "done_hashes": group_hashes,
            "episode_count": len(stored_payloads),
            "fallback_count": 1 if fallback_used else 0,
        }

    async def process_group(self, group: Dict[str, Any]) -> Dict[str, Any]:
        """Segment one group and persist its episodes and paragraph bindings.

        Returns:
            A dict with ``done_hashes``, ``episode_count`` (number of
            ``upsert_episode`` calls made) and ``fallback_count``.
        """
        result = await self._build_episode_payloads_for_group(group)
        stored_count = 0
        for payload in result.get("payloads") or []:
            # Tolerate a store that returns nothing from the upsert.
            stored = self.metadata_store.upsert_episode(payload) or {}
            final_id = str(stored.get("episode_id") or payload.get("episode_id") or "")
            if final_id:
                self.metadata_store.bind_episode_paragraphs(
                    final_id,
                    list(payload.get("evidence_ids") or []),
                )
            stored_count += 1
        return {
            "done_hashes": list(result.get("done_hashes") or []),
            "episode_count": stored_count,
            "fallback_count": int(result.get("fallback_count") or 0),
        }

    async def process_pending_rows(self, pending_rows: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Process a batch of pending rows end to end.

        Hashes missing from the store are reported as done. A group that
        raises has all of its hashes recorded in ``failed_hashes`` with the
        truncated error text; other groups are still processed.
        """
        loaded, missing_hashes = self.load_pending_paragraphs(pending_rows)
        groups = self.group_paragraphs(loaded)

        done_hashes: List[str] = list(missing_hashes)
        failed_hashes: Dict[str, str] = {}
        episode_count = 0
        fallback_count = 0

        for group in groups:
            group_hashes = [str(p.get("hash", "") or "").strip() for p in (group.get("paragraphs") or [])]
            try:
                result = await self.process_group(group)
            except Exception as e:
                err = str(e)[:500]
                logger.warning(f"episode group processing failed: size={len(group_hashes)} err={err}")
                for h in group_hashes:
                    if h:
                        failed_hashes[h] = err
                continue
            done_hashes.extend(result.get("done_hashes") or [])
            episode_count += int(result.get("episode_count") or 0)
            fallback_count += int(result.get("fallback_count") or 0)

        # Order-preserving de-duplication of non-empty hashes.
        dedup_done = list(dict.fromkeys(h for h in done_hashes if h))
        return {
            "done_hashes": dedup_done,
            "failed_hashes": failed_hashes,
            "episode_count": episode_count,
            "fallback_count": fallback_count,
            "missing_count": len(missing_hashes),
            "group_count": len(groups),
        }

    async def rebuild_source(self, source: str) -> Dict[str, Any]:
        """Recompute every episode of ``source`` and replace the stored set.

        An empty source is a no-op; a source without live paragraphs has its
        episodes replaced by an empty list.
        """
        token = str(source or "").strip()
        if not token:
            return {
                "source": "",
                "episode_count": 0,
                "fallback_count": 0,
                "group_count": 0,
                "paragraph_count": 0,
            }

        paragraphs = self.metadata_store.get_live_paragraphs_by_source(token)
        if not paragraphs:
            replace_result = self.metadata_store.replace_episodes_for_source(token, [])
            return {
                "source": token,
                "episode_count": int(replace_result.get("episode_count") or 0),
                "fallback_count": 0,
                "group_count": 0,
                "paragraph_count": 0,
            }

        groups = self.group_paragraphs(paragraphs)
        payloads: List[Dict[str, Any]] = []
        fallback_count = 0
        for group in groups:
            result = await self._build_episode_payloads_for_group(group)
            payloads.extend(result.get("payloads") or [])
            fallback_count += int(result.get("fallback_count") or 0)

        replace_result = self.metadata_store.replace_episodes_for_source(token, payloads)
        return {
            "source": token,
            "episode_count": int(replace_result.get("episode_count") or 0),
            "fallback_count": fallback_count,
            "group_count": len(groups),
            "paragraph_count": len(paragraphs),
        }

View File

@@ -9,7 +9,11 @@ import json
import time
from typing import Any, Dict, List, Optional, Tuple
from sqlalchemy import or_
from sqlmodel import select
from src.common.logger import get_logger
from src.common.database.database import get_db_session
from src.common.database.database_model import PersonInfo
from ..embedding import EmbeddingAPIAdapter
@@ -120,31 +124,40 @@ class PersonProfileService:
if not key:
return ""
try:
with get_db_session(auto_commit=False) as session:
record = session.exec(
select(PersonInfo.person_id).where(PersonInfo.person_id == key).limit(1)
).first()
if record:
return str(record)
record = session.exec(
select(PersonInfo.person_id)
.where(
or_(
PersonInfo.person_name == key,
PersonInfo.user_nickname == key,
)
)
.limit(1)
).first()
if record:
return str(record)
record = session.exec(
select(PersonInfo.person_id)
.where(PersonInfo.group_cardname.contains(key))
.limit(1)
).first()
if record:
return str(record)
except Exception as e:
logger.warning(f"按别名解析 person_id 失败: identifier={key}, err={e}")
if len(key) == 32 and all(ch in "0123456789abcdefABCDEF" for ch in key):
return key.lower()
try:
record = (
PersonInfo.select(PersonInfo.person_id)
.where((PersonInfo.person_name == key) | (PersonInfo.nickname == key))
.first()
)
if record and record.person_id:
return str(record.person_id)
except Exception:
pass
try:
record = (
PersonInfo.select(PersonInfo.person_id)
.where(PersonInfo.group_nick_name.contains(key))
.first()
)
if record and record.person_id:
return str(record.person_id)
except Exception:
pass
return ""
def _parse_group_nicks(self, raw_value: Any) -> List[str]:
@@ -160,7 +173,7 @@ class PersonProfileService:
names: List[str] = []
for item in items:
if isinstance(item, dict):
value = str(item.get("group_nick_name", "")).strip()
value = str(item.get("group_cardname") or item.get("group_nick_name") or "").strip()
if value:
names.append(value)
elif isinstance(item, str):
@@ -193,6 +206,42 @@ class PersonProfileService:
traits.append(text)
return traits[:10]
def _recover_aliases_from_memory(self, person_id: str) -> Tuple[List[str], str]:
"""当人物主档案缺失时,从已有记忆证据里回捞可用别名。"""
if not person_id:
return [], ""
aliases: List[str] = []
primary_name = ""
seen = set()
try:
paragraphs = self.metadata_store.get_paragraphs_by_entity(person_id)
except Exception as e:
logger.warning(f"从记忆证据回捞人物别名失败: person_id={person_id}, err={e}")
return [], ""
for paragraph in paragraphs[:20]:
paragraph_hash = str(paragraph.get("hash", "") or "").strip()
if not paragraph_hash:
continue
try:
paragraph_entities = self.metadata_store.get_paragraph_entities(paragraph_hash)
except Exception:
paragraph_entities = []
for entity in paragraph_entities:
name = str(entity.get("name", "") or "").strip()
if not name or name == person_id:
continue
key = name.lower()
if key in seen:
continue
seen.add(key)
aliases.append(name)
if not primary_name:
primary_name = name
return aliases, primary_name
def get_person_aliases(self, person_id: str) -> Tuple[List[str], str, List[str]]:
"""获取人物别名集合、主展示名、记忆特征。"""
aliases: List[str] = []
@@ -200,18 +249,28 @@ class PersonProfileService:
memory_traits: List[str] = []
if not person_id:
return aliases, primary_name, memory_traits
recovered_aliases, recovered_primary_name = self._recover_aliases_from_memory(person_id)
try:
record = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
if not record:
return aliases, primary_name, memory_traits
with get_db_session(auto_commit=False) as session:
record = session.exec(
select(PersonInfo).where(PersonInfo.person_id == person_id).limit(1)
).first()
if not record:
return recovered_aliases, recovered_primary_name or person_id, memory_traits
person_name = str(getattr(record, "person_name", "") or "").strip()
nickname = str(getattr(record, "nickname", "") or "").strip()
group_nicks = self._parse_group_nicks(getattr(record, "group_nick_name", None))
nickname = str(getattr(record, "user_nickname", "") or "").strip()
group_nicks = self._parse_group_nicks(getattr(record, "group_cardname", None))
memory_traits = self._parse_memory_traits(getattr(record, "memory_points", None))
primary_name = person_name or nickname or str(getattr(record, "user_id", "") or "").strip() or person_id
primary_name = (
person_name
or nickname
or recovered_primary_name
or str(getattr(record, "user_id", "") or "").strip()
or person_id
)
candidates = [person_name, nickname] + group_nicks
candidates = [person_name, nickname] + group_nicks + recovered_aliases
seen = set()
for item in candidates:
norm = str(item or "").strip()

View File

@@ -82,8 +82,9 @@ class RelationWriteService:
)
self.metadata_store.set_relation_vector_state(hash_value, "ready")
logger.info(
"metric.relation_vector_write_success=1 metric.relation_vector_write_success_count=1 hash=%s",
hash_value[:16],
"metric.relation_vector_write_success=1 "
"metric.relation_vector_write_success_count=1 "
f"hash={hash_value[:16]}"
)
return RelationWriteResult(
hash_value=hash_value,
@@ -109,9 +110,10 @@ class RelationWriteService:
bump_retry=True,
)
logger.warning(
"metric.relation_vector_write_fail=1 metric.relation_vector_write_fail_count=1 hash=%s err=%s",
hash_value[:16],
err,
"metric.relation_vector_write_fail=1 "
"metric.relation_vector_write_fail_count=1 "
f"hash={hash_value[:16]} "
f"err={err}"
)
return RelationWriteResult(
hash_value=hash_value,

File diff suppressed because it is too large Load Diff

View File

@@ -61,6 +61,29 @@ def _build_report(
}
def _normalize_encoded_vector(encoded: Any) -> np.ndarray:
if encoded is None:
raise ValueError("embedding encode returned None")
if isinstance(encoded, np.ndarray):
array = encoded
else:
array = np.asarray(encoded, dtype=np.float32)
if array.ndim == 2:
if array.shape[0] != 1:
raise ValueError(f"embedding encode returned batched output: shape={tuple(array.shape)}")
array = array[0]
if array.ndim != 1:
raise ValueError(f"embedding encode returned invalid ndim={array.ndim}")
if array.size <= 0:
raise ValueError("embedding encode returned empty vector")
if not np.all(np.isfinite(array)):
raise ValueError("embedding encode returned non-finite values")
return array.astype(np.float32, copy=False)
async def run_embedding_runtime_self_check(
*,
config: Any,
@@ -91,13 +114,11 @@ async def run_embedding_runtime_self_check(
try:
detected_dimension = _safe_int(await embedding_manager._detect_dimension(), 0)
encoded = await embedding_manager.encode(sample_text)
if isinstance(encoded, np.ndarray):
encoded_dimension = int(encoded.shape[0]) if encoded.ndim == 1 else int(encoded.shape[-1])
else:
encoded_dimension = len(encoded) if encoded is not None else 0
encoded_array = _normalize_encoded_vector(encoded)
encoded_dimension = int(encoded_array.shape[0])
except Exception as exc:
elapsed_ms = (time.perf_counter() - start) * 1000.0
logger.warning("embedding runtime self-check failed: %s", exc)
logger.warning(f"embedding runtime self-check failed: {exc}")
return _build_report(
ok=False,
code="embedding_probe_failed",

View File

@@ -0,0 +1,442 @@
"""
统一检索执行服务。
用于收敛 Action/Tool 在 search/time 上的核心执行流程,避免重复实现。
"""
from __future__ import annotations
import hashlib
import json
import time
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple
from src.common.logger import get_logger
from ..retrieval import TemporalQueryOptions
from .search_postprocess import (
apply_safe_content_dedup,
maybe_apply_smart_path_fallback,
)
from .time_parser import parse_query_time_range
logger = get_logger("A_Memorix.SearchExecutionService")
def _get_config_value(config: Optional[dict], key: str, default: Any = None) -> Any:
if not isinstance(config, dict):
return default
current: Any = config
for part in key.split("."):
if isinstance(current, dict) and part in current:
current = current[part]
else:
return default
return current
def _sanitize_text(value: Any) -> str:
if value is None:
return ""
return str(value).strip()
@dataclass
class SearchExecutionRequest:
    """Input parameters for one ``SearchExecutionService.execute`` call."""

    # Label of the calling component. It is written to log lines, and the
    # value "retrieval_tuning" makes ``execute`` bypass request-level dedup.
    caller: str

    # Chat-scope identifiers passed to the plugin's ``is_chat_enabled`` check.
    stream_id: Optional[str] = None
    group_id: Optional[str] = None
    user_id: Optional[str] = None

    query_type: str = "search"  # search|semantic|time|hybrid
    query: str = ""

    # ``None`` lets the service pick a default; resolved values are clamped to 1..50.
    top_k: Optional[int] = None

    # Raw time-window bounds; parsed by ``parse_query_time_range`` for time/hybrid queries.
    time_from: Optional[str] = None
    time_to: Optional[str] = None

    # Optional filters copied into ``TemporalQueryOptions`` for time/hybrid queries.
    person: Optional[str] = None
    source: Optional[str] = None

    # Whether ``threshold_filter.filter`` may be applied to the retrieved results.
    use_threshold: bool = True

    # Value written to ``retriever.config.enable_ppr`` while the request runs.
    enable_ppr: bool = True
@dataclass
class SearchExecutionResult:
    """Outcome of one ``SearchExecutionService.execute`` call."""

    # False when validation or retrieval failed; ``error`` then carries the message.
    success: bool
    error: str = ""

    # Normalized request parameters echoed back to the caller.
    query_type: str = "search"
    query: str = ""
    top_k: int = 10
    time_from: Optional[str] = None
    time_to: Optional[str] = None
    person: Optional[str] = None
    source: Optional[str] = None

    # Parsed temporal options; ``None`` for plain search requests.
    temporal: Optional[TemporalQueryOptions] = None

    # Retrieved items after post-processing (a fresh list per instance).
    results: List[Any] = field(default_factory=list)

    # Elapsed time in milliseconds as reported by the executor.
    elapsed_ms: float = 0.0

    # True when the plugin's chat filter blocked the request.
    chat_filtered: bool = False

    # True when the plugin's request dedup reported a hit for this request.
    dedup_hit: bool = False

    @property
    def count(self) -> int:
        """Number of items in ``results``."""
        return len(self.results)
class SearchExecutionService:
"""统一检索执行服务。"""
@staticmethod
def _resolve_plugin_instance(plugin_config: Optional[dict]) -> Optional[Any]:
if isinstance(plugin_config, dict):
plugin_instance = plugin_config.get("plugin_instance")
if plugin_instance is not None:
return plugin_instance
try:
from ...plugin import AMemorixPlugin
return getattr(AMemorixPlugin, "get_global_instance", lambda: None)()
except Exception:
return None
@staticmethod
def _normalize_query_type(raw_query_type: str) -> str:
    """Lower-case and trim the query type; empty input and ``"semantic"`` map to ``"search"``."""
    normalized = _sanitize_text(raw_query_type).lower()
    if not normalized or normalized == "semantic":
        return "search"
    return normalized
@staticmethod
def _resolve_runtime_component(
plugin_config: Optional[dict],
plugin_instance: Optional[Any],
key: str,
) -> Optional[Any]:
if isinstance(plugin_config, dict):
value = plugin_config.get(key)
if value is not None:
return value
if plugin_instance is not None:
value = getattr(plugin_instance, key, None)
if value is not None:
return value
return None
@staticmethod
def _resolve_top_k(
    plugin_config: Optional[dict],
    query_type: str,
    top_k_raw: Optional[Any],
) -> Tuple[bool, int, str]:
    """Resolve the effective ``top_k`` for a request.

    Args:
        plugin_config: Source of ``retrieval.temporal.default_top_k``.
        query_type: Normalized query type; ``time``/``hybrid`` use the temporal default.
        top_k_raw: Caller-supplied value, or ``None`` to use the default.

    Returns:
        ``(ok, top_k, error)``. On success ``top_k`` is clamped to ``1..50``
        and ``error`` is empty; when ``top_k_raw`` is not convertible to an
        int the result is ``(False, 0, <message>)``.
    """
    # The temporal default is read (and converted) regardless of query type.
    temporal_default = int(
        _get_config_value(plugin_config, "retrieval.temporal.default_top_k", 10)
    )
    if top_k_raw is None:
        requested = temporal_default if query_type in {"time", "hybrid"} else 10
    else:
        try:
            requested = int(top_k_raw)
        except (TypeError, ValueError):
            return False, 0, "top_k 参数必须为整数"
    return True, min(50, max(1, requested)), ""
@staticmethod
def _build_temporal(
plugin_config: Optional[dict],
query_type: str,
time_from_raw: Optional[str],
time_to_raw: Optional[str],
person: Optional[str],
source: Optional[str],
) -> Tuple[bool, Optional[TemporalQueryOptions], str]:
if query_type not in {"time", "hybrid"}:
return True, None, ""
temporal_enabled = bool(_get_config_value(plugin_config, "retrieval.temporal.enabled", True))
if not temporal_enabled:
return False, None, "时序检索已禁用retrieval.temporal.enabled=false"
if not time_from_raw and not time_to_raw:
return False, None, "time/hybrid 模式至少需要 time_from 或 time_to"
try:
ts_from, ts_to = parse_query_time_range(
str(time_from_raw) if time_from_raw is not None else None,
str(time_to_raw) if time_to_raw is not None else None,
)
except ValueError as e:
return False, None, f"时间参数错误: {e}"
temporal = TemporalQueryOptions(
time_from=ts_from,
time_to=ts_to,
person=_sanitize_text(person) or None,
source=_sanitize_text(source) or None,
allow_created_fallback=bool(
_get_config_value(plugin_config, "retrieval.temporal.allow_created_fallback", True)
),
candidate_multiplier=int(
_get_config_value(plugin_config, "retrieval.temporal.candidate_multiplier", 8)
),
max_scan=int(_get_config_value(plugin_config, "retrieval.temporal.max_scan", 1000)),
)
return True, temporal, ""
@staticmethod
def _build_request_key(
    request: SearchExecutionRequest,
    query_type: str,
    top_k: int,
    temporal: Optional[TemporalQueryOptions],
) -> str:
    """Derive the key handed to the plugin's request dedup.

    The key is the SHA-1 hex digest of a key-sorted JSON document holding the
    stream id, the normalized query parameters, the parsed time bounds and the
    threshold/PPR switches. ``group_id``, ``user_id`` and ``caller`` are not
    part of the key.
    """
    text_fields = {
        name: _sanitize_text(getattr(request, name))
        for name in ("stream_id", "query", "time_from", "time_to", "person", "source")
    }
    payload = {
        **text_fields,
        "query_type": query_type,
        "time_from_ts": temporal.time_from if temporal else None,
        "time_to_ts": temporal.time_to if temporal else None,
        "top_k": int(top_k),
        "use_threshold": bool(request.use_threshold),
        "enable_ppr": bool(request.enable_ppr),
    }
    # sort_keys makes the serialization independent of insertion order.
    encoded = json.dumps(payload, ensure_ascii=False, sort_keys=True).encode("utf-8")
    return hashlib.sha1(encoded).hexdigest()
@staticmethod
async def execute(
*,
retriever: Any,
threshold_filter: Optional[Any],
plugin_config: Optional[dict],
request: SearchExecutionRequest,
enforce_chat_filter: bool = True,
reinforce_access: bool = True,
) -> SearchExecutionResult:
# Execute one retrieval request and wrap the outcome in a SearchExecutionResult.
#
# Steps, in order: validate the retriever, query type and query text; resolve
# top_k and the temporal options; optionally short-circuit through the
# plugin's chat filter; then run the retriever (through the plugin's request
# dedup when available) and post-process the hits.
#
# Args:
#   retriever: used via `retriever.config` and
#       `await retriever.retrieve(query=, top_k=, temporal=)`.
#   threshold_filter: optional object whose `filter(results)` is applied when
#       `request.use_threshold` is set.
#   plugin_config: plugin configuration dict; it is also searched for
#       "plugin_instance", "graph_store" and "metadata_store".
#   request: the search parameters.
#   enforce_chat_filter: when True, honour `plugin_instance.is_chat_enabled(...)`.
#   reinforce_access: when True, pass the hashes of relation hits to
#       `plugin_instance.reinforce_access(...)`.
#
# Returns:
#   A SearchExecutionResult. Validation failures and exceptions raised while
#   running the executor are reported as success=False with an error message.
#   A request blocked by the chat filter is reported as success=True with no
#   results and chat_filtered=True.
if retriever is None:
return SearchExecutionResult(success=False, error="知识检索器未初始化")
query_type = SearchExecutionService._normalize_query_type(request.query_type)
query = _sanitize_text(request.query)
if query_type not in {"search", "time", "hybrid"}:
return SearchExecutionResult(
success=False,
error=f"query_type 无效: {query_type}(仅支持 search/time/hybrid",
)
# Only a pure "time" query may run without query text.
if query_type in {"search", "hybrid"} and not query:
return SearchExecutionResult(
success=False,
error="search/hybrid 模式必须提供 query",
)
top_k_ok, top_k, top_k_error = SearchExecutionService._resolve_top_k(
plugin_config, query_type, request.top_k
)
if not top_k_ok:
return SearchExecutionResult(success=False, error=top_k_error)
temporal_ok, temporal, temporal_error = SearchExecutionService._build_temporal(
plugin_config=plugin_config,
query_type=query_type,
time_from_raw=request.time_from,
time_to_raw=request.time_to,
person=request.person,
source=request.source,
)
if not temporal_ok:
return SearchExecutionResult(success=False, error=temporal_error)
plugin_instance = SearchExecutionService._resolve_plugin_instance(plugin_config)
# Chat filter: a disabled chat yields an empty but successful result.
if (
enforce_chat_filter
and plugin_instance is not None
and hasattr(plugin_instance, "is_chat_enabled")
):
if not plugin_instance.is_chat_enabled(
stream_id=request.stream_id,
group_id=request.group_id,
user_id=request.user_id,
):
logger.info(
"检索请求被聊天过滤拦截: "
f"caller={request.caller}, "
f"stream_id={request.stream_id}"
)
return SearchExecutionResult(
success=True,
query_type=query_type,
query=query,
top_k=top_k,
time_from=request.time_from,
time_to=request.time_to,
person=request.person,
source=request.source,
temporal=temporal,
results=[],
elapsed_ms=0.0,
chat_filtered=True,
dedup_hit=False,
)
# Key identifying this request for the plugin's request-level dedup.
request_key = SearchExecutionService._build_request_key(
request=request,
query_type=query_type,
top_k=top_k,
temporal=temporal,
)
# The actual retrieval work; returns {"results": ..., "elapsed_ms": ...}.
async def _executor() -> Dict[str, Any]:
# Temporarily override retriever.config.enable_ppr for this request; the
# previous value is written back in the finally clause below.
original_ppr = bool(getattr(retriever.config, "enable_ppr", True))
setattr(retriever.config, "enable_ppr", bool(request.enable_ppr))
started_at = time.time()
try:
retrieved = await retriever.retrieve(
query=query,
top_k=top_k,
temporal=temporal,
)
should_apply_threshold = bool(request.use_threshold) and threshold_filter is not None
# A time query without query text skips the threshold filter unless
# retrieval.time.skip_threshold_when_query_empty is turned off.
if (
query_type == "time"
and not query
and bool(
_get_config_value(
plugin_config,
"retrieval.time.skip_threshold_when_query_empty",
True,
)
)
):
should_apply_threshold = False
if should_apply_threshold:
retrieved = threshold_filter.filter(retrieved)
# Report the hashes of relation-type hits to the plugin.
if (
reinforce_access
and plugin_instance is not None
and hasattr(plugin_instance, "reinforce_access")
):
relation_hashes = [
item.hash_value
for item in retrieved
if getattr(item, "result_type", "") == "relation"
]
if relation_hashes:
await plugin_instance.reinforce_access(relation_hashes)
# Search-mode post-processing: optional smart path fallback using the
# graph and metadata stores.
if query_type == "search":
graph_store = SearchExecutionService._resolve_runtime_component(
plugin_config, plugin_instance, "graph_store"
)
metadata_store = SearchExecutionService._resolve_runtime_component(
plugin_config, plugin_instance, "metadata_store"
)
fallback_enabled = bool(
_get_config_value(
plugin_config,
"retrieval.search.smart_fallback.enabled",
True,
)
)
fallback_threshold = float(
_get_config_value(
plugin_config,
"retrieval.search.smart_fallback.threshold",
0.6,
)
)
retrieved, fallback_triggered, fallback_added = maybe_apply_smart_path_fallback(
query=query,
results=list(retrieved),
graph_store=graph_store,
metadata_store=metadata_store,
enabled=fallback_enabled,
threshold=fallback_threshold,
)
if fallback_triggered:
logger.info(
"metric.smart_fallback_triggered_count=1 "
f"caller={request.caller} "
f"added={fallback_added}"
)
# Optional content-level dedup of the hits, controlled by
# retrieval.search.safe_content_dedup.enabled.
dedup_enabled = bool(
_get_config_value(
plugin_config,
"retrieval.search.safe_content_dedup.enabled",
True,
)
)
if dedup_enabled:
retrieved, removed_count = apply_safe_content_dedup(list(retrieved))
if removed_count > 0:
logger.info(
f"metric.safe_dedup_removed_count={removed_count} "
f"caller={request.caller}"
)
elapsed_ms = (time.time() - started_at) * 1000.0
return {"results": retrieved, "elapsed_ms": elapsed_ms}
finally:
setattr(retriever.config, "enable_ppr", original_ppr)
dedup_hit = False
try:
# Tuning evaluation needs a real execution on every round and should avoid
# extra contention on the dedup lock, so that caller bypasses request dedup.
bypass_request_dedup = str(request.caller or "").strip().lower() == "retrieval_tuning"
if (
not bypass_request_dedup
and
plugin_instance is not None
and hasattr(plugin_instance, "execute_request_with_dedup")
):
dedup_hit, payload = await plugin_instance.execute_request_with_dedup(
request_key,
_executor,
)
else:
payload = await _executor()
except Exception as e:
# Any failure while running the executor is reported to the caller as an error result.
return SearchExecutionResult(success=False, error=f"知识检索失败: {e}")
if dedup_hit:
logger.info(f"metric.search_execution_dedup_hit_count=1 caller={request.caller}")
return SearchExecutionResult(
success=True,
query_type=query_type,
query=query,
top_k=top_k,
time_from=request.time_from,
time_to=request.time_to,
person=request.person,
source=request.source,
temporal=temporal,
results=payload.get("results", []),
elapsed_ms=float(payload.get("elapsed_ms", 0.0)),
chat_filtered=False,
dedup_hit=bool(dedup_hit),
)
@staticmethod
def to_serializable_results(results: List[Any]) -> List[Dict[str, Any]]:
serialized: List[Dict[str, Any]] = []
for item in results:
metadata = dict(getattr(item, "metadata", {}) or {})
if "time_meta" not in metadata:
metadata["time_meta"] = {}
serialized.append(
{
"hash": getattr(item, "hash_value", ""),
"type": getattr(item, "result_type", ""),
"score": float(getattr(item, "score", 0.0)),
"content": getattr(item, "content", ""),
"metadata": metadata,
}
)
return serialized

View File

@@ -0,0 +1,425 @@
"""
聊天总结与知识导入工具
该模块负责从聊天记录中提取信息,生成总结,并将总结内容及提取的实体/关系
导入到 A_memorix 的存储组件中。
"""
import time
import json
import re
import traceback
from typing import List, Dict, Any, Tuple, Optional
from pathlib import Path
from src.common.logger import get_logger
from src.services import llm_service as llm_api
from src.services import message_service as message_api
from src.config.config import global_config, model_config as host_model_config
from src.config.model_configs import TaskConfig
from ..storage import (
KnowledgeType,
VectorStore,
GraphStore,
MetadataStore,
resolve_stored_knowledge_type,
)
from ..embedding import EmbeddingAPIAdapter
from .relation_write_service import RelationWriteService
from .runtime_self_check import ensure_runtime_self_check, run_embedding_runtime_self_check
logger = get_logger("A_Memorix.SummaryImporter")
# Default prompt template for chat summarization. It is filled with
# str.format using bot_name, personality_context and chat_history; the
# literal braces of the JSON example are escaped as {{ }}.
SUMMARY_PROMPT_TEMPLATE = """
你是 {bot_name}{personality_context}
现在你需要对以下一段聊天记录进行总结,并提取其中的重要知识。
聊天记录内容:
{chat_history}
请完成以下任务:
1. **生成总结**:以第三人称或机器人的视角,简洁明了地总结这段对话的主要内容、发生的事件或讨论的主题。
2. **提取实体与关系**:识别并提取对话中提到的重要实体以及它们之间的关系。
请严格以 JSON 格式输出,格式如下:
{{
"summary": "总结文本内容",
"entities": ["张三", "李四"],
"relations": [
{{"subject": "张三", "predicate": "认识", "object": "李四"}}
]
}}
注意:总结应具有叙事性,能够作为长程记忆的一部分。直接使用实体的实际名称,不要使用 e1/e2 等代号。
"""
class SummaryImporter:
"""总结并导入知识的工具类"""
def __init__(
    self,
    vector_store: VectorStore,
    graph_store: GraphStore,
    metadata_store: MetadataStore,
    embedding_manager: EmbeddingAPIAdapter,
    plugin_config: dict
):
    """Keep references to the storage components, the embedder and the config.

    Args:
        vector_store: Receives the summary embedding.
        graph_store: Receives extracted entities and relation edges.
        metadata_store: Receives the summary paragraph and relation records.
        embedding_manager: Used to encode the summary text.
        plugin_config: Plugin configuration; when it is a dict, its
            ``relation_write_service`` entry (if any) is used for relation writes.
    """
    self.vector_store = vector_store
    self.graph_store = graph_store
    self.metadata_store = metadata_store
    self.embedding_manager = embedding_manager
    self.plugin_config = plugin_config

    # Optional shared writer; None makes _execute_import write relations directly.
    relation_writer: Optional[RelationWriteService] = None
    if isinstance(plugin_config, dict):
        relation_writer = plugin_config.get("relation_write_service")
    self.relation_write_service: Optional[RelationWriteService] = relation_writer
def _normalize_summary_model_selectors(self, raw_value: Any) -> List[str]:
"""标准化 summarization.model_name 配置vNext 仅接受字符串数组)。"""
if raw_value is None:
return ["auto"]
if isinstance(raw_value, list):
selectors = [str(x).strip() for x in raw_value if str(x).strip()]
return selectors or ["auto"]
raise ValueError(
"summarization.model_name 在 vNext 必须为 List[str]。"
" 请执行 scripts/release_vnext_migrate.py migrate。"
)
def _pick_default_summary_task(self, available_tasks: Dict[str, TaskConfig]) -> Tuple[Optional[str], Optional[TaskConfig]]:
"""
选择总结默认任务,避免错误落到 embedding 任务。
优先级replyer > utils > planner > tool_use > 其他非 embedding。
"""
preferred = ("replyer", "utils", "planner", "tool_use")
for name in preferred:
cfg = available_tasks.get(name)
if cfg and cfg.model_list:
return name, cfg
for name, cfg in available_tasks.items():
if name != "embedding" and cfg.model_list:
return name, cfg
for name, cfg in available_tasks.items():
if cfg.model_list:
return name, cfg
return None, None
def _resolve_summary_model_config(self) -> Optional[TaskConfig]:
    """
    Resolve ``summarization.model_name`` into a ``TaskConfig`` for the summary LLM call.

    The option is a list of selectors; a missing option is treated as
    ``["auto"]``. Each selector may be:
    - ``"auto"``: the models of the default task chosen by ``_pick_default_summary_task``
    - ``"replyer"``: a task name, contributing that task's model list
    - ``"some-model-name"``: a concrete model name known to the host model config
    - ``"utils:model1"`` / ``"utils:auto"``: one model (or all models) of a given task

    Unknown selectors are logged and skipped. When nothing was selected, the
    default task (or, failing that, the first available task) is used.

    Returns:
        A new ``TaskConfig`` whose ``model_list`` is the de-duplicated selection
        in selector order and whose other settings are copied from the first
        task that matched; ``None`` when no task or model is available.

    Raises:
        ValueError: if the configured value is present but is not a list.
    """
    available_tasks = llm_api.get_available_models()
    if not available_tasks:
        return None

    # Use None (not the string "auto") as the missing-key default: the
    # normalizer maps None to ["auto"] but rejects bare strings with ValueError.
    raw_cfg = self.plugin_config.get("summarization", {}).get("model_name")
    selectors = self._normalize_summary_model_selectors(raw_cfg)
    _, default_task_cfg = self._pick_default_summary_task(available_tasks)

    selected_models: List[str] = []
    base_cfg: Optional[TaskConfig] = None
    model_dict = getattr(host_model_config, "models_dict", {})

    def _append_models(models: List[str]):
        # Preserve first-seen order while dropping blanks and duplicates.
        for model_name in models:
            if model_name and model_name not in selected_models:
                selected_models.append(model_name)

    for raw_selector in selectors:
        selector = raw_selector.strip()
        if not selector:
            continue

        if selector.lower() == "auto":
            if default_task_cfg:
                _append_models(default_task_cfg.model_list)
                if base_cfg is None:
                    base_cfg = default_task_cfg
            continue

        if ":" in selector:
            # "task:model" form.
            task_name, model_name = selector.split(":", 1)
            task_name = task_name.strip()
            model_name = model_name.strip()
            task_cfg = available_tasks.get(task_name)
            if not task_cfg:
                logger.warning(f"总结模型选择器 '{selector}' 的任务 '{task_name}' 不存在,已跳过")
                continue
            if base_cfg is None:
                base_cfg = task_cfg
            if not model_name or model_name.lower() == "auto":
                _append_models(task_cfg.model_list)
                continue
            if model_name in model_dict or model_name in task_cfg.model_list:
                _append_models([model_name])
            else:
                logger.warning(f"总结模型选择器 '{selector}' 的模型 '{model_name}' 不存在,已跳过")
            continue

        # Bare selector: a task name takes precedence over a model name.
        task_cfg = available_tasks.get(selector)
        if task_cfg:
            _append_models(task_cfg.model_list)
            if base_cfg is None:
                base_cfg = task_cfg
            continue
        if selector in model_dict:
            _append_models([selector])
            continue
        logger.warning(f"总结模型选择器 '{selector}' 无法识别,已跳过")

    if not selected_models:
        # Nothing usable was selected: fall back to the default or first task.
        if default_task_cfg:
            _append_models(default_task_cfg.model_list)
            if base_cfg is None:
                base_cfg = default_task_cfg
        else:
            first_cfg = next(iter(available_tasks.values()))
            _append_models(first_cfg.model_list)
            if base_cfg is None:
                base_cfg = first_cfg

    if not selected_models:
        return None

    template_cfg = base_cfg or default_task_cfg or next(iter(available_tasks.values()))
    return TaskConfig(
        model_list=selected_models,
        max_tokens=template_cfg.max_tokens,
        temperature=template_cfg.temperature,
        slow_threshold=template_cfg.slow_threshold,
        selection_strategy=template_cfg.selection_strategy,
    )
async def import_from_stream(
    self,
    stream_id: str,
    context_length: Optional[int] = None,
    include_personality: Optional[bool] = None
) -> Tuple[bool, str]:
    """
    Summarize the recent messages of a chat stream and import the result.

    Args:
        stream_id: Chat stream ID.
        context_length: Number of history messages to summarize; defaults to
            ``summarization.context_length`` (50).
        include_personality: Whether to add the bot personality to the prompt;
            defaults to ``summarization.include_personality`` (True).

    Returns:
        Tuple[bool, str]: (success flag, result or error message). Exceptions
        are caught, logged with a traceback and reported as a failure.
    """
    try:
        self_check_ok, self_check_msg = await self._ensure_runtime_self_check()
        if not self_check_ok:
            return False, f"导入前自检失败: {self_check_msg}"

        # 1. Resolve options from the plugin config when not given explicitly.
        if context_length is None:
            context_length = self.plugin_config.get("summarization", {}).get("context_length", 50)
        if include_personality is None:
            include_personality = self.plugin_config.get("summarization", {}).get("include_personality", True)

        # 2. Fetch the messages that precede the current time.
        now = time.time()
        messages = message_api.get_messages_before_time_in_chat(
            chat_id=stream_id,
            timestamp=now,
            limit=context_length
        )
        if not messages:
            return False, "未找到有效的聊天记录进行总结"

        # Convert to readable text.
        chat_history_text = message_api.build_readable_messages(messages)

        # 3. Prepare the prompt inputs.
        bot_name = global_config.bot.nickname or "机器人"
        personality_context = ""
        if include_personality:
            personality = getattr(global_config.bot, "personality", "")
            if personality:
                personality_context = f"你的性格设定是:{personality}"

        # 4. Call the LLM.
        prompt = SUMMARY_PROMPT_TEMPLATE.format(
            bot_name=bot_name,
            personality_context=personality_context,
            chat_history=chat_history_text
        )
        model_config_to_use = self._resolve_summary_model_config()
        if model_config_to_use is None:
            return False, "未找到可用的总结模型配置"

        logger.info(f"正在为流 {stream_id} 执行总结,消息条数: {len(messages)}")
        logger.info(f"总结模型候选列表: {model_config_to_use.model_list}")
        success, response, _, _ = await llm_api.generate_with_model(
            prompt=prompt,
            model_config=model_config_to_use,
            request_type="A_Memorix.ChatSummarization"
        )
        if not success or not response:
            return False, "LLM 生成总结失败"

        # 5. Parse and validate the LLM output.
        data = self._parse_llm_response(response)
        if not data or "summary" not in data:
            return False, "解析 LLM 响应失败或总结为空"

        summary_text = data["summary"]
        # The error message promises to reject empty summaries, so enforce it.
        if not isinstance(summary_text, str) or not summary_text.strip():
            return False, "解析 LLM 响应失败或总结为空"

        # Keep only well-formed entries; the model output is untrusted.
        raw_entities = data.get("entities", [])
        entities = [
            item
            for item in (raw_entities if isinstance(raw_entities, list) else [])
            if isinstance(item, str) and item.strip()
        ]
        raw_relations = data.get("relations", [])
        relations = [
            item
            for item in (raw_relations if isinstance(raw_relations, list) else [])
            if isinstance(item, dict)
        ]

        # Collect message times. A message may expose `timestamp` (an object
        # with a .timestamp() method, e.g. datetime, or a number) or a numeric
        # `time`; read whichever is present instead of filtering on one
        # attribute and reading the other.
        msg_times: List[float] = []
        for msg in messages:
            raw_time = getattr(msg, "timestamp", None)
            if raw_time is None:
                raw_time = getattr(msg, "time", None)
            if raw_time is None:
                continue
            to_epoch = getattr(raw_time, "timestamp", None)
            try:
                msg_times.append(float(to_epoch() if callable(to_epoch) else raw_time))
            except (TypeError, ValueError):
                continue

        time_meta = {}
        if msg_times:
            time_meta = {
                "event_time_start": min(msg_times),
                "event_time_end": max(msg_times),
                "time_granularity": "minute",
                "time_confidence": 0.95,
            }

        # 6. Write to the stores.
        await self._execute_import(summary_text, entities, relations, stream_id, time_meta=time_meta)

        # 7. Persist.
        self.vector_store.save()
        self.graph_store.save()

        result_msg = (
            f"✅ 总结导入成功\n"
            f"📝 总结长度: {len(summary_text)}\n"
            f"📌 提取实体: {len(entities)}\n"
            f"🔗 提取关系: {len(relations)}"
        )
        return True, result_msg
    except Exception as e:
        logger.error(f"总结导入过程中出错: {e}\n{traceback.format_exc()}")
        return False, f"错误: {str(e)}"
async def _ensure_runtime_self_check(self) -> Tuple[bool, str]:
    """Run the embedding runtime self-check before importing.

    With a ``plugin_instance`` in the plugin config the check is delegated to
    ``ensure_runtime_self_check``; otherwise ``run_embedding_runtime_self_check``
    is called with this importer's own config, vector store and embedder.

    Returns:
        ``(True, "")`` when the report is ok, else ``(False, <diagnostic text>)``
        containing the report message and the configured/store/encoded dimensions.
    """
    host = None
    if isinstance(self.plugin_config, dict):
        host = self.plugin_config.get("plugin_instance")

    if host is None:
        report = await run_embedding_runtime_self_check(
            config=self.plugin_config,
            vector_store=self.vector_store,
            embedding_manager=self.embedding_manager,
        )
    else:
        report = await ensure_runtime_self_check(host)

    if bool(report.get("ok", False)):
        return True, ""

    message = report.get('message', 'unknown')
    configured = report.get('configured_dimension', 0)
    stored = report.get('vector_store_dimension', 0)
    encoded = report.get('encoded_dimension', 0)
    detail = f"{message} (configured={configured}, store={stored}, encoded={encoded})"
    return False, detail
def _parse_llm_response(self, response: str) -> Dict[str, Any]:
"""解析 LLM 返回的 JSON"""
try:
# 尝试查找 JSON
json_match = re.search(r"\{.*\}", response, re.DOTALL)
if json_match:
return json.loads(json_match.group())
return {}
except Exception as e:
logger.warning(f"解析总结 JSON 失败: {e}")
return {}
async def _execute_import(
    self,
    summary: str,
    entities: List[str],
    relations: List[Dict[str, str]],
    stream_id: str,
    time_meta: Optional[Dict[str, Any]] = None,
):
    """Write the summary, its entities and its relations to the stores.

    Args:
        summary: Summary text; stored as a paragraph and embedded.
        entities: Entity names added as graph nodes.
        relations: Dicts with ``subject``/``predicate``/``object``; entries that
            are not dicts or lack any of the three values are skipped.
        stream_id: Chat stream ID, recorded in the paragraph source.
        time_meta: Optional time metadata stored with the paragraph.
    """
    # Resolve the knowledge type, falling back to narrative on an invalid setting.
    type_str = self.plugin_config.get("summarization", {}).get("default_knowledge_type", "narrative")
    try:
        knowledge_type = resolve_stored_knowledge_type(type_str, content=summary)
    except ValueError:
        logger.warning(f"非法 summarization.default_knowledge_type={type_str},回退 narrative")
        knowledge_type = KnowledgeType.NARRATIVE

    # Store the summary paragraph and its vector.
    hash_value = self.metadata_store.add_paragraph(
        content=summary,
        source=f"chat_summary:{stream_id}",
        knowledge_type=knowledge_type.value,
        time_meta=time_meta,
    )
    embedding = await self.embedding_manager.encode(summary)
    self.vector_store.add(
        vectors=embedding.reshape(1, -1),
        ids=[hash_value]
    )

    # Store entities.
    if entities:
        self.graph_store.add_nodes(entities)

    # Store relations.
    rv_cfg = self.plugin_config.get("retrieval", {}).get("relation_vectorization", {})
    if not isinstance(rv_cfg, dict):
        rv_cfg = {}
    write_vector = bool(rv_cfg.get("enabled", False)) and bool(rv_cfg.get("write_on_import", True))

    for rel in relations or []:
        if not isinstance(rel, dict):
            # LLM output is untrusted; skip malformed entries instead of crashing.
            logger.warning(f"跳过非法关系条目: {rel!r}")
            continue
        s, p, o = rel.get("subject"), rel.get("predicate"), rel.get("object")
        if not all([s, p, o]):
            continue

        if self.relation_write_service is not None:
            await self.relation_write_service.upsert_relation_with_vector(
                subject=s,
                predicate=p,
                obj=o,
                confidence=1.0,
                source_paragraph=summary,
                write_vector=write_vector,
            )
            continue

        # Write the relation metadata.
        rel_hash = self.metadata_store.add_relation(
            subject=s,
            predicate=p,
            obj=o,
            confidence=1.0,
            source_paragraph=summary
        )
        # Write the graph edge with its relation hash so the relation can be
        # pruned precisely later.
        self.graph_store.add_edges([(s, o)], relation_hashes=[rel_hash])
        try:
            self.metadata_store.set_relation_vector_state(rel_hash, "none")
        except Exception as exc:
            # Best effort: the relation itself is already stored, so only report it.
            logger.warning(f"设置关系向量状态失败(已忽略): hash={rel_hash}, error={exc}")

    logger.info(f"总结导入完成: hash={hash_value[:8]}")

File diff suppressed because it is too large Load Diff