添加 A_Memorix 插件 v2.0.0(包含运行时与文档)

引入 A_Memorix 插件 v2.0.0:新增大量运行时组件、存储/模式更新、检索能力提升、管理工具、导入/调优工作流以及相关文档。关键新增内容包括:lifecycle_orchestrator、SDKMemoryKernel/运行时初始化器、新的存储层与 metadata_store 变更(SCHEMA_VERSION v8)、检索增强(双路径检索、图关系召回、稀疏 BM25),以及多种工具服务(episode/person_profile/relation/segmentation/tuning/search execution)。同时新增 Web 导入/摘要导入器及大量维护脚本。还更新了插件清单、embedding API 适配器、plugin.py、requirements/pyproject,以及主入口文件,使新插件接入项目。该变更为 2.0.0 版本发布做好准备,实现统一的 SDK Tool 接口并扩展整体运行能力。
This commit is contained in:
DawnARC
2026-03-19 00:09:04 +08:00
parent eb257345dd
commit 71b3a828c6
44 changed files with 18193 additions and 405 deletions

View File

@@ -1,46 +1,55 @@
"""
Hash-based embedding adapter used by the SDK runtime.
请求式嵌入 API 适配器。
The plugin runtime cannot import MaiBot host embedding internals from ``src.chat``
or ``src.llm_models``. This adapter keeps A_Memorix self-contained and stable in
Runner by generating deterministic dense vectors locally.
恢复 v1.0.1 的真实 embedding 请求语义:
- 通过宿主模型配置探测/请求 embedding
- 支持 dimensions 参数
- 支持批量与重试
- 不再提供本地 hash fallback
"""
from __future__ import annotations
import hashlib
import re
import asyncio
import time
from typing import List, Optional, Union
from typing import Any, List, Optional, Union
import aiohttp
import numpy as np
import openai
from src.common.logger import get_logger
from src.config.config import config_manager
from src.config.model_configs import APIProvider, ModelInfo
from src.llm_models.exceptions import NetworkConnectionError
from src.llm_models.model_client.base_client import client_registry
logger = get_logger("A_Memorix.EmbeddingAPIAdapter")
_TOKEN_PATTERN = re.compile(r"[A-Za-z0-9_\u4e00-\u9fff]{1,}")
class EmbeddingAPIAdapter:
"""Deterministic local embedding adapter."""
"""适配宿主 embedding 请求接口。"""
def __init__(
self,
batch_size: int = 32,
max_concurrent: int = 5,
default_dimension: int = 256,
default_dimension: int = 1024,
enable_cache: bool = False,
model_name: str = "hash-v1",
model_name: str = "auto",
retry_config: Optional[dict] = None,
) -> None:
self.batch_size = max(1, int(batch_size))
self.max_concurrent = max(1, int(max_concurrent))
self.default_dimension = max(32, int(default_dimension))
self.default_dimension = max(1, int(default_dimension))
self.enable_cache = bool(enable_cache)
self.model_name = str(model_name or "hash-v1")
self.model_name = str(model_name or "auto")
self.retry_config = retry_config or {}
self.max_attempts = max(1, int(self.retry_config.get("max_attempts", 5)))
self.max_wait_seconds = max(0.1, float(self.retry_config.get("max_wait_seconds", 40)))
self.min_wait_seconds = max(0.1, float(self.retry_config.get("min_wait_seconds", 3)))
self.backoff_multiplier = max(1.0, float(self.retry_config.get("backoff_multiplier", 3)))
self._dimension: Optional[int] = None
self._dimension_detected = False
@@ -49,57 +58,164 @@ class EmbeddingAPIAdapter:
self._total_time = 0.0
logger.info(
"EmbeddingAPIAdapter 初始化: model=%s, batch_size=%s, dimension=%s",
self.model_name,
self.batch_size,
self.default_dimension,
"EmbeddingAPIAdapter 初始化: "
f"batch_size={self.batch_size}, "
f"max_concurrent={self.max_concurrent}, "
f"default_dim={self.default_dimension}, "
f"model={self.model_name}"
)
def _get_current_model_config(self):
return config_manager.get_model_config()
@staticmethod
def _find_model_info(model_name: str) -> ModelInfo:
model_cfg = config_manager.get_model_config()
for item in model_cfg.models:
if item.name == model_name:
return item
raise ValueError(f"未找到 embedding 模型: {model_name}")
@staticmethod
def _find_provider(provider_name: str) -> APIProvider:
model_cfg = config_manager.get_model_config()
for item in model_cfg.api_providers:
if item.name == provider_name:
return item
raise ValueError(f"未找到 embedding provider: {provider_name}")
def _resolve_candidate_model_names(self) -> List[str]:
task_config = self._get_current_model_config().model_task_config.embedding
configured = list(getattr(task_config, "model_list", []) or [])
if self.model_name and self.model_name != "auto":
return [self.model_name, *[name for name in configured if name != self.model_name]]
return configured
@staticmethod
def _validate_embedding_vector(embedding: Any, *, source: str) -> np.ndarray:
array = np.asarray(embedding, dtype=np.float32)
if array.ndim != 1:
raise RuntimeError(f"{source} 返回的 embedding 维度非法: ndim={array.ndim}")
if array.size <= 0:
raise RuntimeError(f"{source} 返回了空 embedding")
if not np.all(np.isfinite(array)):
raise RuntimeError(f"{source} 返回了非有限 embedding 值")
return array
async def _request_with_retry(self, client, model_info, text: str, extra_params: dict):
retriable_exceptions = (
openai.APIConnectionError,
openai.APITimeoutError,
aiohttp.ClientError,
asyncio.TimeoutError,
NetworkConnectionError,
)
last_exc: Optional[BaseException] = None
for attempt in range(1, self.max_attempts + 1):
try:
return await client.get_embedding(
model_info=model_info,
embedding_input=text,
extra_params=extra_params,
)
except retriable_exceptions as exc:
last_exc = exc
if attempt >= self.max_attempts:
raise
wait_seconds = min(
self.max_wait_seconds,
self.min_wait_seconds * (self.backoff_multiplier ** (attempt - 1)),
)
logger.warning(
"Embedding 请求失败,重试 "
f"{attempt}/{max(1, self.max_attempts - 1)}"
f"{wait_seconds:.1f}s 后重试: {exc}"
)
await asyncio.sleep(wait_seconds)
except Exception:
raise
if last_exc is not None:
raise last_exc
raise RuntimeError("Embedding 请求失败:未知错误")
async def _get_embedding_direct(self, text: str, dimensions: Optional[int] = None) -> Optional[List[float]]:
candidate_names = self._resolve_candidate_model_names()
if not candidate_names:
raise RuntimeError("embedding 任务未配置模型")
last_exc: Optional[BaseException] = None
for candidate_name in candidate_names:
try:
model_info = self._find_model_info(candidate_name)
api_provider = self._find_provider(model_info.api_provider)
client = client_registry.get_client_class_instance(api_provider, force_new=True)
extra_params = dict(getattr(model_info, "extra_params", {}) or {})
if dimensions is not None:
extra_params["dimensions"] = int(dimensions)
response = await self._request_with_retry(
client=client,
model_info=model_info,
text=text,
extra_params=extra_params,
)
embedding = getattr(response, "embedding", None)
if embedding is None:
raise RuntimeError(f"模型 {candidate_name} 未返回 embedding")
vector = self._validate_embedding_vector(
embedding,
source=f"embedding 模型 {candidate_name}",
)
return vector.tolist()
except Exception as exc:
last_exc = exc
logger.warning(f"embedding 模型 {candidate_name} 请求失败: {exc}")
if last_exc is not None:
logger.error(f"通过直接 Client 获取 Embedding 失败: {last_exc}")
return None
async def _detect_dimension(self) -> int:
if self._dimension_detected and self._dimension is not None:
return self._dimension
logger.info("正在检测嵌入模型维度...")
try:
target_dim = self.default_dimension
logger.debug(f"尝试请求指定维度: {target_dim}")
test_embedding = await self._get_embedding_direct("test", dimensions=target_dim)
if test_embedding and isinstance(test_embedding, list):
detected_dim = len(test_embedding)
if detected_dim == target_dim:
logger.info(f"嵌入维度检测成功 (匹配配置): {detected_dim}")
else:
logger.warning(
f"请求维度 {target_dim} 但模型返回 {detected_dim},将使用模型自然维度"
)
self._dimension = detected_dim
self._dimension_detected = True
return detected_dim
except Exception as exc:
logger.debug(f"带维度参数探测失败: {exc},尝试不带参数探测")
try:
test_embedding = await self._get_embedding_direct("test", dimensions=None)
if test_embedding and isinstance(test_embedding, list):
detected_dim = len(test_embedding)
self._dimension = detected_dim
self._dimension_detected = True
logger.info(f"嵌入维度检测成功 (自然维度): {detected_dim}")
return detected_dim
logger.warning(f"嵌入维度检测失败,使用默认值: {self.default_dimension}")
except Exception as exc:
logger.error(f"嵌入维度检测异常: {exc},使用默认值: {self.default_dimension}")
self._dimension = self.default_dimension
self._dimension_detected = True
return self._dimension
@staticmethod
def _tokenize(text: str) -> List[str]:
clean = str(text or "").strip().lower()
if not clean:
return []
return _TOKEN_PATTERN.findall(clean)
@staticmethod
def _feature_weight(token: str) -> float:
digest = hashlib.sha256(token.encode("utf-8")).digest()
return 1.0 + (digest[10] / 255.0) * 0.5
def _encode_single(self, text: str, dimension: int) -> np.ndarray:
vector = np.zeros(dimension, dtype=np.float32)
content = str(text or "").strip()
tokens = self._tokenize(content)
if not tokens and content:
tokens = [content.lower()]
if not tokens:
vector[0] = 1.0
return vector
for token in tokens:
digest = hashlib.sha256(token.encode("utf-8")).digest()
bucket = int.from_bytes(digest[:8], byteorder="big", signed=False) % dimension
sign = 1.0 if digest[8] % 2 == 0 else -1.0
vector[bucket] += sign * self._feature_weight(token)
second_bucket = int.from_bytes(digest[12:20], byteorder="big", signed=False) % dimension
if second_bucket != bucket:
vector[second_bucket] += (sign * 0.35)
norm = float(np.linalg.norm(vector))
if norm > 1e-8:
vector /= norm
else:
vector[0] = 1.0
return vector
return self.default_dimension
async def encode(
self,
@@ -109,59 +225,137 @@ class EmbeddingAPIAdapter:
normalize: bool = True,
dimensions: Optional[int] = None,
) -> np.ndarray:
_ = batch_size
_ = show_progress
_ = normalize
del show_progress
del normalize
started_at = time.time()
target_dimension = max(32, int(dimensions or await self._detect_dimension()))
start_time = time.time()
target_dim = int(dimensions) if dimensions is not None else int(await self._detect_dimension())
if isinstance(texts, str):
single_input = True
normalized_texts = [texts]
single_input = True
else:
single_input = False
normalized_texts = list(texts or [])
single_input = False
if not normalized_texts:
empty = np.zeros((0, target_dimension), dtype=np.float32)
empty = np.zeros((0, target_dim), dtype=np.float32)
return empty[0] if single_input else empty
if batch_size is None:
batch_size = self.batch_size
try:
matrix = np.vstack([self._encode_single(item, target_dimension) for item in normalized_texts])
embeddings = await self._encode_batch_internal(
normalized_texts,
batch_size=max(1, int(batch_size)),
dimensions=dimensions,
)
if embeddings.ndim == 1:
embeddings = embeddings.reshape(1, -1)
self._total_encoded += len(normalized_texts)
self._total_time += time.time() - started_at
except Exception:
elapsed = time.time() - start_time
self._total_time += elapsed
logger.debug(
"编码完成: "
f"{len(normalized_texts)} 个文本, "
f"耗时 {elapsed:.2f}s, "
f"平均 {elapsed / max(1, len(normalized_texts)):.3f}s/文本"
)
return embeddings[0] if single_input else embeddings
except Exception as exc:
self._total_errors += 1
raise
logger.error(f"编码失败: {exc}")
raise RuntimeError(f"embedding encode failed: {exc}") from exc
return matrix[0] if single_input else matrix
async def _encode_batch_internal(
self,
texts: List[str],
batch_size: int,
dimensions: Optional[int] = None,
) -> np.ndarray:
all_embeddings: List[np.ndarray] = []
for offset in range(0, len(texts), batch_size):
batch = texts[offset : offset + batch_size]
semaphore = asyncio.Semaphore(self.max_concurrent)
def get_statistics(self) -> dict:
avg_time = self._total_time / self._total_encoded if self._total_encoded else 0.0
async def encode_with_semaphore(text: str, index: int):
async with semaphore:
embedding = await self._get_embedding_direct(text, dimensions=dimensions)
if embedding is None:
raise RuntimeError(f"文本 {index} 编码失败embedding 返回为空")
vector = self._validate_embedding_vector(
embedding,
source=f"文本 {index}",
)
return index, vector
tasks = [
encode_with_semaphore(text, offset + index)
for index, text in enumerate(batch)
]
results = await asyncio.gather(*tasks)
results.sort(key=lambda item: item[0])
all_embeddings.extend(emb for _, emb in results)
return np.array(all_embeddings, dtype=np.float32)
async def encode_batch(
self,
texts: List[str],
batch_size: Optional[int] = None,
num_workers: Optional[int] = None,
show_progress: bool = False,
dimensions: Optional[int] = None,
) -> np.ndarray:
del show_progress
if num_workers is not None:
previous = self.max_concurrent
self.max_concurrent = max(1, int(num_workers))
try:
return await self.encode(texts, batch_size=batch_size, dimensions=dimensions)
finally:
self.max_concurrent = previous
return await self.encode(texts, batch_size=batch_size, dimensions=dimensions)
def get_embedding_dimension(self) -> int:
if self._dimension is not None:
return self._dimension
logger.warning(f"维度尚未检测,返回默认值: {self.default_dimension}")
return self.default_dimension
def get_model_info(self) -> dict:
return {
"model_name": self.model_name,
"dimension": self._dimension or self.default_dimension,
"dimension_detected": self._dimension_detected,
"batch_size": self.batch_size,
"max_concurrent": self.max_concurrent,
"total_encoded": self._total_encoded,
"total_errors": self._total_errors,
"total_time": self._total_time,
"avg_time_per_text": avg_time,
"avg_time_per_text": self._total_time / self._total_encoded if self._total_encoded else 0.0,
}
def get_statistics(self) -> dict:
return self.get_model_info()
@property
def is_model_loaded(self) -> bool:
return True
def __repr__(self) -> str:
return (
f"EmbeddingAPIAdapter(model_name={self.model_name}, "
f"dimension={self._dimension or self.default_dimension}, "
f"total_encoded={self._total_encoded})"
f"EmbeddingAPIAdapter(dim={self._dimension or self.default_dimension}, "
f"detected={self._dimension_detected}, encoded={self._total_encoded})"
)
def create_embedding_api_adapter(
batch_size: int = 32,
max_concurrent: int = 5,
default_dimension: int = 256,
default_dimension: int = 1024,
enable_cache: bool = False,
model_name: str = "hash-v1",
model_name: str = "auto",
retry_config: Optional[dict] = None,
) -> EmbeddingAPIAdapter:
return EmbeddingAPIAdapter(

View File

@@ -285,10 +285,10 @@ class DualPathRetriever:
relation_intent_ctx = self._build_relation_intent_context(query=query, top_k=top_k)
logger.info(
"执行检索: query='%s...', strategy=%s, relation_intent=%s",
query[:50],
strategy.value,
relation_intent_ctx.get("enabled", False),
"执行检索: "
f"query='{query[:50]}...', "
f"strategy={strategy.value}, "
f"relation_intent={relation_intent_ctx.get('enabled', False)}"
)
if temporal and not (query or "").strip():
@@ -1408,10 +1408,10 @@ class DualPathRetriever:
return results
logger.debug(
"relation_rerank_applied=1 relation_pair_groups=%s relation_pair_overflow_count=%s relation_pair_limit=%s",
len(ordered_groups),
len(overflow),
pair_limit,
"relation_rerank_applied=1 "
f"relation_pair_groups={len(ordered_groups)} "
f"relation_pair_overflow_count={len(overflow)} "
f"relation_pair_limit={pair_limit}"
)
rebuilt = list(results)
@@ -1455,9 +1455,9 @@ class DualPathRetriever:
)
except asyncio.TimeoutError:
logger.warning(
"metric.ppr_timeout_skip_count=1 timeout_s=%s entities=%s",
ppr_timeout_s,
len(entities),
"metric.ppr_timeout_skip_count=1 "
f"timeout_s={ppr_timeout_s} "
f"entities={len(entities)}"
)
return results
except Exception as e:

View File

@@ -170,7 +170,7 @@ class GraphRelationRecallService:
max_paths=self.config.max_paths,
)
except Exception as e:
logger.debug("graph two-hop recall skipped: %s", e)
logger.debug(f"graph two-hop recall skipped: {e}")
return
for path_nodes in paths:
@@ -210,7 +210,7 @@ class GraphRelationRecallService:
limit=self.config.candidate_k,
)
except Exception as e:
logger.debug("graph one-hop recall skipped: %s", e)
logger.debug(f"graph one-hop recall skipped: {e}")
return
self._append_relation_hashes(
relation_hashes=relation_hashes,

View File

@@ -123,9 +123,8 @@ class SparseBM25Index:
self._loaded = True
self._prepare_tokenizer()
logger.info(
"SparseBM25Index loaded: backend=fts5, tokenizer=%s, mode=%s",
self.config.tokenizer_mode,
self.config.mode,
"SparseBM25Index loaded: "
f"backend=fts5, tokenizer={self.config.tokenizer_mode}, mode={self.config.mode}"
)
return True
@@ -141,9 +140,9 @@ class SparseBM25Index:
if user_dict:
try:
jieba.load_userdict(user_dict) # type: ignore[union-attr]
logger.info("已加载 jieba 用户词典: %s", user_dict)
logger.info(f"已加载 jieba 用户词典: {user_dict}")
except Exception as e:
logger.warning("加载 jieba 用户词典失败: %s", e)
logger.warning(f"加载 jieba 用户词典失败: {e}")
self._jieba_dict_loaded = True
def _tokenize_jieba(self, text: str) -> List[str]:

View File

@@ -1,8 +1,16 @@
"""SDK runtime exports for A_Memorix."""
from .search_runtime_initializer import (
SearchRuntimeBundle,
SearchRuntimeInitializer,
build_search_runtime,
)
from .sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel
__all__ = [
"SearchRuntimeBundle",
"SearchRuntimeInitializer",
"build_search_runtime",
"KernelSearchRequest",
"SDKMemoryKernel",
]

View File

@@ -0,0 +1,268 @@
"""Lifecycle bootstrap/teardown helpers extracted from plugin.py."""
from __future__ import annotations
import asyncio
from pathlib import Path
from typing import Any
from src.common.logger import get_logger
from ..embedding import create_embedding_api_adapter
from ..retrieval import SparseBM25Config, SparseBM25Index
from ..storage import (
GraphStore,
MetadataStore,
QuantizationType,
SparseMatrixFormat,
VectorStore,
)
from ..utils.runtime_self_check import ensure_runtime_self_check
from ..utils.relation_write_service import RelationWriteService
logger = get_logger("A_Memorix.LifecycleOrchestrator")
async def ensure_initialized(plugin: Any) -> None:
if plugin._initialized:
plugin._runtime_ready = plugin._check_storage_ready()
return
async with plugin._init_lock:
if plugin._initialized:
plugin._runtime_ready = plugin._check_storage_ready()
return
logger.info("A_Memorix 插件正在异步初始化存储组件...")
plugin._validate_runtime_config()
await initialize_storage_async(plugin)
report = await ensure_runtime_self_check(plugin, force=True)
if not bool(report.get("ok", False)):
logger.error(
"A_Memorix runtime self-check failed: "
f"{report.get('message', 'unknown')}; "
"建议执行 python plugins/A_memorix/scripts/runtime_self_check.py --json"
)
if plugin.graph_store and plugin.metadata_store:
relation_count = plugin.metadata_store.count_relations()
if relation_count > 0 and not plugin.graph_store.has_edge_hash_map():
raise RuntimeError(
"检测到 relations 数据存在但 edge-hash-map 为空。"
" 请先执行 scripts/release_vnext_migrate.py migrate。"
)
plugin._initialized = True
plugin._runtime_ready = plugin._check_storage_ready()
plugin._update_plugin_config()
logger.info("A_Memorix 插件异步初始化成功")
def start_background_tasks(plugin: Any) -> None:
"""Start background tasks idempotently."""
if not hasattr(plugin, "_episode_generation_task"):
plugin._episode_generation_task = None
if (
plugin.get_config("summarization.enabled", True)
and plugin.get_config("schedule.enabled", True)
and (plugin._scheduled_import_task is None or plugin._scheduled_import_task.done())
):
plugin._scheduled_import_task = asyncio.create_task(plugin._scheduled_import_loop())
if (
plugin.get_config("advanced.enable_auto_save", True)
and (plugin._auto_save_task is None or plugin._auto_save_task.done())
):
plugin._auto_save_task = asyncio.create_task(plugin._auto_save_loop())
if (
plugin.get_config("person_profile.enabled", True)
and (plugin._person_profile_refresh_task is None or plugin._person_profile_refresh_task.done())
):
plugin._person_profile_refresh_task = asyncio.create_task(plugin._person_profile_refresh_loop())
if plugin._memory_maintenance_task is None or plugin._memory_maintenance_task.done():
plugin._memory_maintenance_task = asyncio.create_task(plugin._memory_maintenance_loop())
rv_cfg = plugin.get_config("retrieval.relation_vectorization", {}) or {}
if isinstance(rv_cfg, dict):
rv_enabled = bool(rv_cfg.get("enabled", False))
rv_backfill = bool(rv_cfg.get("backfill_enabled", False))
else:
rv_enabled = False
rv_backfill = False
if rv_enabled and rv_backfill and (
plugin._relation_vector_backfill_task is None or plugin._relation_vector_backfill_task.done()
):
plugin._relation_vector_backfill_task = asyncio.create_task(plugin._relation_vector_backfill_loop())
episode_task = getattr(plugin, "_episode_generation_task", None)
episode_loop = getattr(plugin, "_episode_generation_loop", None)
if (
callable(episode_loop)
and bool(plugin.get_config("episode.enabled", True))
and bool(plugin.get_config("episode.generation_enabled", True))
and (episode_task is None or episode_task.done())
):
plugin._episode_generation_task = asyncio.create_task(episode_loop())
async def cancel_background_tasks(plugin: Any) -> None:
"""Cancel all background tasks and wait for cleanup."""
tasks = [
("scheduled_import", plugin._scheduled_import_task),
("auto_save", plugin._auto_save_task),
("person_profile_refresh", plugin._person_profile_refresh_task),
("memory_maintenance", plugin._memory_maintenance_task),
("relation_vector_backfill", plugin._relation_vector_backfill_task),
("episode_generation", getattr(plugin, "_episode_generation_task", None)),
]
for _, task in tasks:
if task and not task.done():
task.cancel()
for name, task in tasks:
if not task:
continue
try:
await task
except asyncio.CancelledError:
pass
except Exception as e:
logger.warning(f"后台任务 {name} 退出异常: {e}")
plugin._scheduled_import_task = None
plugin._auto_save_task = None
plugin._person_profile_refresh_task = None
plugin._memory_maintenance_task = None
plugin._relation_vector_backfill_task = None
plugin._episode_generation_task = None
async def initialize_storage_async(plugin: Any) -> None:
"""Initialize storage components asynchronously."""
data_dir_str = plugin.get_config("storage.data_dir", "./data")
if data_dir_str.startswith("."):
plugin_dir = Path(__file__).resolve().parents[2]
data_dir = (plugin_dir / data_dir_str).resolve()
else:
data_dir = Path(data_dir_str)
logger.info(f"A_Memorix 数据存储路径: {data_dir}")
data_dir.mkdir(parents=True, exist_ok=True)
plugin.embedding_manager = create_embedding_api_adapter(
batch_size=plugin.get_config("embedding.batch_size", 32),
max_concurrent=plugin.get_config("embedding.max_concurrent", 5),
default_dimension=plugin.get_config("embedding.dimension", 1024),
model_name=plugin.get_config("embedding.model_name", "auto"),
retry_config=plugin.get_config("embedding.retry", {}),
)
logger.info("嵌入 API 适配器初始化完成")
try:
detected_dimension = await plugin.embedding_manager._detect_dimension()
logger.info(f"嵌入维度检测成功: {detected_dimension}")
except Exception as e:
logger.warning(f"嵌入维度检测失败: {e},使用默认值")
detected_dimension = plugin.embedding_manager.default_dimension
quantization_str = plugin.get_config("embedding.quantization_type", "int8")
if str(quantization_str or "").strip().lower() != "int8":
raise ValueError("embedding.quantization_type 在 vNext 仅允许 int8(SQ8)。")
quantization_type = QuantizationType.INT8
plugin.vector_store = VectorStore(
dimension=detected_dimension,
quantization_type=quantization_type,
data_dir=data_dir / "vectors",
)
plugin.vector_store.min_train_threshold = plugin.get_config("embedding.min_train_threshold", 40)
logger.info(
"向量存储初始化完成("
f"维度: {detected_dimension}, "
f"训练阈值: {plugin.vector_store.min_train_threshold}"
)
matrix_format_str = plugin.get_config("graph.sparse_matrix_format", "csr")
matrix_format_map = {
"csr": SparseMatrixFormat.CSR,
"csc": SparseMatrixFormat.CSC,
}
matrix_format = matrix_format_map.get(matrix_format_str, SparseMatrixFormat.CSR)
plugin.graph_store = GraphStore(
matrix_format=matrix_format,
data_dir=data_dir / "graph",
)
logger.info("图存储初始化完成")
plugin.metadata_store = MetadataStore(data_dir=data_dir / "metadata")
plugin.metadata_store.connect()
logger.info("元数据存储初始化完成")
plugin.relation_write_service = RelationWriteService(
metadata_store=plugin.metadata_store,
graph_store=plugin.graph_store,
vector_store=plugin.vector_store,
embedding_manager=plugin.embedding_manager,
)
logger.info("关系写入服务初始化完成")
sparse_cfg_raw = plugin.get_config("retrieval.sparse", {}) or {}
if not isinstance(sparse_cfg_raw, dict):
sparse_cfg_raw = {}
try:
sparse_cfg = SparseBM25Config(**sparse_cfg_raw)
except Exception as e:
logger.warning(f"sparse 配置非法,回退默认配置: {e}")
sparse_cfg = SparseBM25Config()
plugin.sparse_index = SparseBM25Index(
metadata_store=plugin.metadata_store,
config=sparse_cfg,
)
logger.info(
"稀疏检索组件初始化完成: "
f"enabled={sparse_cfg.enabled}, "
f"lazy_load={sparse_cfg.lazy_load}, "
f"mode={sparse_cfg.mode}, "
f"tokenizer={sparse_cfg.tokenizer_mode}"
)
if sparse_cfg.enabled and not sparse_cfg.lazy_load:
plugin.sparse_index.ensure_loaded()
if plugin.vector_store.has_data():
try:
plugin.vector_store.load()
logger.info(f"向量数据已加载,共 {plugin.vector_store.num_vectors} 个向量")
except Exception as e:
logger.warning(f"加载向量数据失败: {e}")
try:
warmup_summary = plugin.vector_store.warmup_index(force_train=True)
if warmup_summary.get("ok"):
logger.info(
"向量索引预热完成: "
f"trained={warmup_summary.get('trained')}, "
f"index_ntotal={warmup_summary.get('index_ntotal')}, "
f"fallback_ntotal={warmup_summary.get('fallback_ntotal')}, "
f"bin_count={warmup_summary.get('bin_count')}, "
f"duration_ms={float(warmup_summary.get('duration_ms', 0.0)):.2f}"
)
else:
logger.warning(
"向量索引预热失败,继续启用 sparse 降级路径: "
f"{warmup_summary.get('error', 'unknown')}"
)
except Exception as e:
logger.warning(f"向量索引预热异常,继续启用 sparse 降级路径: {e}")
if plugin.graph_store.has_data():
try:
plugin.graph_store.load()
logger.info(f"图数据已加载,共 {plugin.graph_store.num_nodes} 个节点")
except Exception as e:
logger.warning(f"加载图数据失败: {e}")
logger.info(f"知识库数据目录: {data_dir}")

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,240 @@
"""Shared runtime initializer for Action/Tool/Command retrieval components."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Dict, Optional
from src.common.logger import get_logger
from ..retrieval import (
DualPathRetriever,
DualPathRetrieverConfig,
DynamicThresholdFilter,
FusionConfig,
GraphRelationRecallConfig,
RelationIntentConfig,
RetrievalStrategy,
SparseBM25Config,
ThresholdConfig,
ThresholdMethod,
)
_logger = get_logger("A_Memorix.SearchRuntimeInitializer")
_REQUIRED_COMPONENT_KEYS = (
"vector_store",
"graph_store",
"metadata_store",
"embedding_manager",
)
def _get_config_value(config: Optional[dict], key: str, default: Any = None) -> Any:
if not isinstance(config, dict):
return default
current: Any = config
for part in key.split("."):
if isinstance(current, dict) and part in current:
current = current[part]
else:
return default
return current
def _safe_dict(value: Any) -> Dict[str, Any]:
return value if isinstance(value, dict) else {}
def _resolve_debug_enabled(plugin_config: Optional[dict]) -> bool:
advanced = _get_config_value(plugin_config, "advanced", {})
if isinstance(advanced, dict):
return bool(advanced.get("debug", False))
return bool(_get_config_value(plugin_config, "debug", False))
@dataclass
class SearchRuntimeBundle:
"""Resolved runtime components and initialized retriever/filter."""
vector_store: Optional[Any] = None
graph_store: Optional[Any] = None
metadata_store: Optional[Any] = None
embedding_manager: Optional[Any] = None
sparse_index: Optional[Any] = None
retriever: Optional[DualPathRetriever] = None
threshold_filter: Optional[DynamicThresholdFilter] = None
error: str = ""
@property
def ready(self) -> bool:
return (
self.retriever is not None
and self.vector_store is not None
and self.graph_store is not None
and self.metadata_store is not None
and self.embedding_manager is not None
)
def _resolve_runtime_components(plugin_config: Optional[dict]) -> SearchRuntimeBundle:
bundle = SearchRuntimeBundle(
vector_store=_get_config_value(plugin_config, "vector_store"),
graph_store=_get_config_value(plugin_config, "graph_store"),
metadata_store=_get_config_value(plugin_config, "metadata_store"),
embedding_manager=_get_config_value(plugin_config, "embedding_manager"),
sparse_index=_get_config_value(plugin_config, "sparse_index"),
)
missing_required = any(
getattr(bundle, key) is None for key in _REQUIRED_COMPONENT_KEYS
)
if not missing_required:
return bundle
try:
from ...plugin import AMemorixPlugin
instances = AMemorixPlugin.get_storage_instances()
except Exception:
instances = {}
if not isinstance(instances, dict) or not instances:
return bundle
if bundle.vector_store is None:
bundle.vector_store = instances.get("vector_store")
if bundle.graph_store is None:
bundle.graph_store = instances.get("graph_store")
if bundle.metadata_store is None:
bundle.metadata_store = instances.get("metadata_store")
if bundle.embedding_manager is None:
bundle.embedding_manager = instances.get("embedding_manager")
if bundle.sparse_index is None:
bundle.sparse_index = instances.get("sparse_index")
return bundle
def build_search_runtime(
plugin_config: Optional[dict],
logger_obj: Optional[Any],
owner_tag: str,
*,
log_prefix: str = "",
) -> SearchRuntimeBundle:
"""Build retriever + threshold filter with unified fallback/config parsing."""
log = logger_obj or _logger
owner = str(owner_tag or "runtime").strip().lower() or "runtime"
prefix = str(log_prefix or "").strip()
prefix_text = f"{prefix} " if prefix else ""
runtime = _resolve_runtime_components(plugin_config)
if any(getattr(runtime, key) is None for key in _REQUIRED_COMPONENT_KEYS):
runtime.error = "存储组件未完全初始化"
log.warning(f"{prefix_text}[{owner}] 存储组件未完全初始化,无法使用检索功能")
return runtime
sparse_cfg_raw = _safe_dict(_get_config_value(plugin_config, "retrieval.sparse", {}) or {})
fusion_cfg_raw = _safe_dict(_get_config_value(plugin_config, "retrieval.fusion", {}) or {})
relation_intent_cfg_raw = _safe_dict(
_get_config_value(plugin_config, "retrieval.search.relation_intent", {}) or {}
)
graph_recall_cfg_raw = _safe_dict(
_get_config_value(plugin_config, "retrieval.search.graph_recall", {}) or {}
)
try:
sparse_cfg = SparseBM25Config(**sparse_cfg_raw)
except Exception as e:
log.warning(f"{prefix_text}[{owner}] sparse 配置非法,回退默认: {e}")
sparse_cfg = SparseBM25Config()
try:
fusion_cfg = FusionConfig(**fusion_cfg_raw)
except Exception as e:
log.warning(f"{prefix_text}[{owner}] fusion 配置非法,回退默认: {e}")
fusion_cfg = FusionConfig()
try:
relation_intent_cfg = RelationIntentConfig(**relation_intent_cfg_raw)
except Exception as e:
log.warning(f"{prefix_text}[{owner}] relation_intent 配置非法,回退默认: {e}")
relation_intent_cfg = RelationIntentConfig()
try:
graph_recall_cfg = GraphRelationRecallConfig(**graph_recall_cfg_raw)
except Exception as e:
log.warning(f"{prefix_text}[{owner}] graph_recall 配置非法,回退默认: {e}")
graph_recall_cfg = GraphRelationRecallConfig()
try:
config = DualPathRetrieverConfig(
top_k_paragraphs=_get_config_value(plugin_config, "retrieval.top_k_paragraphs", 20),
top_k_relations=_get_config_value(plugin_config, "retrieval.top_k_relations", 10),
top_k_final=_get_config_value(plugin_config, "retrieval.top_k_final", 10),
alpha=_get_config_value(plugin_config, "retrieval.alpha", 0.5),
enable_ppr=_get_config_value(plugin_config, "retrieval.enable_ppr", True),
ppr_alpha=_get_config_value(plugin_config, "retrieval.ppr_alpha", 0.85),
ppr_timeout_seconds=_get_config_value(
plugin_config, "retrieval.ppr_timeout_seconds", 1.5
),
ppr_concurrency_limit=_get_config_value(
plugin_config, "retrieval.ppr_concurrency_limit", 4
),
enable_parallel=_get_config_value(plugin_config, "retrieval.enable_parallel", True),
retrieval_strategy=RetrievalStrategy.DUAL_PATH,
debug=_resolve_debug_enabled(plugin_config),
sparse=sparse_cfg,
fusion=fusion_cfg,
relation_intent=relation_intent_cfg,
graph_recall=graph_recall_cfg,
)
runtime.retriever = DualPathRetriever(
vector_store=runtime.vector_store,
graph_store=runtime.graph_store,
metadata_store=runtime.metadata_store,
embedding_manager=runtime.embedding_manager,
sparse_index=runtime.sparse_index,
config=config,
)
threshold_config = ThresholdConfig(
method=ThresholdMethod.ADAPTIVE,
min_threshold=_get_config_value(plugin_config, "threshold.min_threshold", 0.3),
max_threshold=_get_config_value(plugin_config, "threshold.max_threshold", 0.95),
percentile=_get_config_value(plugin_config, "threshold.percentile", 75.0),
std_multiplier=_get_config_value(plugin_config, "threshold.std_multiplier", 1.5),
min_results=_get_config_value(plugin_config, "threshold.min_results", 3),
enable_auto_adjust=_get_config_value(plugin_config, "threshold.enable_auto_adjust", True),
)
runtime.threshold_filter = DynamicThresholdFilter(threshold_config)
runtime.error = ""
log.info(f"{prefix_text}[{owner}] 检索运行时初始化完成")
except Exception as e:
runtime.retriever = None
runtime.threshold_filter = None
runtime.error = str(e)
log.error(f"{prefix_text}[{owner}] 检索运行时初始化失败: {e}")
return runtime
class SearchRuntimeInitializer:
"""Compatibility wrapper around the function style initializer."""
@staticmethod
def build_search_runtime(
plugin_config: Optional[dict],
logger_obj: Optional[Any],
owner_tag: str,
*,
log_prefix: str = "",
) -> SearchRuntimeBundle:
return build_search_runtime(
plugin_config=plugin_config,
logger_obj=logger_obj,
owner_tag=owner_tag,
log_prefix=log_prefix,
)

View File

@@ -24,6 +24,20 @@ try:
from scipy.sparse.linalg import norm
HAS_SCIPY = True
except ImportError:
class _SparseMatrixPlaceholder:
pass
def _scipy_missing(*args, **kwargs):
raise ImportError("SciPy 未安装,请安装: pip install scipy")
csr_matrix = _SparseMatrixPlaceholder
csc_matrix = _SparseMatrixPlaceholder
lil_matrix = _SparseMatrixPlaceholder
triu = _scipy_missing
save_npz = _scipy_missing
load_npz = _scipy_missing
bmat = _scipy_missing
norm = _scipy_missing
HAS_SCIPY = False
import contextlib

View File

@@ -7,6 +7,8 @@
import sqlite3
import pickle
import json
import uuid
import re
from datetime import datetime
from pathlib import Path
from typing import Optional, Union, List, Dict, Any, Tuple
@@ -24,7 +26,7 @@ from .knowledge_types import (
logger = get_logger("A_Memorix.MetadataStore")
SCHEMA_VERSION = 7
SCHEMA_VERSION = 8
class MetadataStore:
@@ -500,6 +502,63 @@ class MetadataStore:
CREATE INDEX IF NOT EXISTS idx_external_memory_refs_paragraph
ON external_memory_refs(paragraph_hash)
""")
cursor.execute("""
CREATE TABLE IF NOT EXISTS memory_v5_operations (
operation_id TEXT PRIMARY KEY,
action TEXT NOT NULL,
target TEXT,
reason TEXT,
updated_by TEXT,
created_at REAL NOT NULL,
resolved_hashes_json TEXT,
result_json TEXT
)
""")
cursor.execute("""
CREATE INDEX IF NOT EXISTS idx_memory_v5_operations_created
ON memory_v5_operations(created_at DESC)
""")
cursor.execute("""
CREATE TABLE IF NOT EXISTS delete_operations (
operation_id TEXT PRIMARY KEY,
mode TEXT NOT NULL,
selector TEXT,
reason TEXT,
requested_by TEXT,
status TEXT NOT NULL,
created_at REAL NOT NULL,
restored_at REAL,
summary_json TEXT
)
""")
cursor.execute("""
CREATE INDEX IF NOT EXISTS idx_delete_operations_created
ON delete_operations(created_at DESC)
""")
cursor.execute("""
CREATE INDEX IF NOT EXISTS idx_delete_operations_mode
ON delete_operations(mode, created_at DESC)
""")
cursor.execute("""
CREATE TABLE IF NOT EXISTS delete_operation_items (
id INTEGER PRIMARY KEY AUTOINCREMENT,
operation_id TEXT NOT NULL,
item_type TEXT NOT NULL,
item_hash TEXT,
item_key TEXT,
payload_json TEXT,
created_at REAL NOT NULL,
FOREIGN KEY (operation_id) REFERENCES delete_operations(operation_id) ON DELETE CASCADE
)
""")
cursor.execute("""
CREATE INDEX IF NOT EXISTS idx_delete_operation_items_operation
ON delete_operation_items(operation_id, id ASC)
""")
cursor.execute("""
CREATE INDEX IF NOT EXISTS idx_delete_operation_items_hash
ON delete_operation_items(item_hash)
""")
# 新版 schema 包含完整字段,直接写入版本信息
cursor.execute("INSERT OR IGNORE INTO schema_migrations(version, applied_at) VALUES (?, ?)", (SCHEMA_VERSION, datetime.now().timestamp()))
self._conn.commit()
@@ -618,6 +677,63 @@ class MetadataStore:
CREATE INDEX IF NOT EXISTS idx_external_memory_refs_paragraph
ON external_memory_refs(paragraph_hash)
""")
cursor.execute("""
CREATE TABLE IF NOT EXISTS memory_v5_operations (
operation_id TEXT PRIMARY KEY,
action TEXT NOT NULL,
target TEXT,
reason TEXT,
updated_by TEXT,
created_at REAL NOT NULL,
resolved_hashes_json TEXT,
result_json TEXT
)
""")
cursor.execute("""
CREATE INDEX IF NOT EXISTS idx_memory_v5_operations_created
ON memory_v5_operations(created_at DESC)
""")
cursor.execute("""
CREATE TABLE IF NOT EXISTS delete_operations (
operation_id TEXT PRIMARY KEY,
mode TEXT NOT NULL,
selector TEXT,
reason TEXT,
requested_by TEXT,
status TEXT NOT NULL,
created_at REAL NOT NULL,
restored_at REAL,
summary_json TEXT
)
""")
cursor.execute("""
CREATE INDEX IF NOT EXISTS idx_delete_operations_created
ON delete_operations(created_at DESC)
""")
cursor.execute("""
CREATE INDEX IF NOT EXISTS idx_delete_operations_mode
ON delete_operations(mode, created_at DESC)
""")
cursor.execute("""
CREATE TABLE IF NOT EXISTS delete_operation_items (
id INTEGER PRIMARY KEY AUTOINCREMENT,
operation_id TEXT NOT NULL,
item_type TEXT NOT NULL,
item_hash TEXT,
item_key TEXT,
payload_json TEXT,
created_at REAL NOT NULL,
FOREIGN KEY (operation_id) REFERENCES delete_operations(operation_id) ON DELETE CASCADE
)
""")
cursor.execute("""
CREATE INDEX IF NOT EXISTS idx_delete_operation_items_operation
ON delete_operation_items(operation_id, id ASC)
""")
cursor.execute("""
CREATE INDEX IF NOT EXISTS idx_delete_operation_items_hash
ON delete_operation_items(item_hash)
""")
# 检查paragraphs表是否有knowledge_type列
cursor.execute("PRAGMA table_info(paragraphs)")
@@ -2595,6 +2711,328 @@ class MetadataStore:
"metadata": metadata or {},
}
@staticmethod
def _json_dumps(value: Any) -> str:
return json.dumps(value, ensure_ascii=False, sort_keys=True)
@staticmethod
def _json_loads(value: Any, default: Any) -> Any:
if value in {None, ""}:
return default
try:
return json.loads(value)
except Exception:
return default
def list_external_memory_refs_by_paragraphs(self, paragraph_hashes: List[str]) -> List[Dict[str, Any]]:
hashes = [str(item or "").strip() for item in (paragraph_hashes or []) if str(item or "").strip()]
if not hashes:
return []
placeholders = ",".join(["?"] * len(hashes))
cursor = self._conn.cursor()
cursor.execute(
f"""
SELECT external_id, paragraph_hash, source_type, created_at, metadata_json
FROM external_memory_refs
WHERE paragraph_hash IN ({placeholders})
ORDER BY created_at ASC, external_id ASC
""",
tuple(hashes),
)
items: List[Dict[str, Any]] = []
for row in cursor.fetchall():
payload = dict(row)
payload["metadata"] = self._json_loads(payload.get("metadata_json"), {})
items.append(payload)
return items
def delete_external_memory_refs_by_paragraphs(self, paragraph_hashes: List[str]) -> List[Dict[str, Any]]:
items = self.list_external_memory_refs_by_paragraphs(paragraph_hashes)
hashes = [str(item or "").strip() for item in (paragraph_hashes or []) if str(item or "").strip()]
if not hashes:
return items
placeholders = ",".join(["?"] * len(hashes))
cursor = self._conn.cursor()
cursor.execute(
f"DELETE FROM external_memory_refs WHERE paragraph_hash IN ({placeholders})",
tuple(hashes),
)
self._conn.commit()
return items
def restore_external_memory_refs(self, refs: List[Dict[str, Any]]) -> int:
count = 0
for item in refs or []:
external_id = str(item.get("external_id", "") or "").strip()
paragraph_hash = str(item.get("paragraph_hash", "") or "").strip()
if not external_id or not paragraph_hash:
continue
created_at = float(item.get("created_at") or datetime.now().timestamp())
metadata_json = self._json_dumps(item.get("metadata") or {})
cursor = self._conn.cursor()
cursor.execute(
"""
INSERT INTO external_memory_refs (
external_id, paragraph_hash, source_type, created_at, metadata_json
)
VALUES (?, ?, ?, ?, ?)
ON CONFLICT(external_id) DO UPDATE SET
paragraph_hash = excluded.paragraph_hash,
source_type = excluded.source_type,
created_at = excluded.created_at,
metadata_json = excluded.metadata_json
""",
(
external_id,
paragraph_hash,
str(item.get("source_type", "") or "").strip() or None,
created_at,
metadata_json,
),
)
count += max(0, int(cursor.rowcount or 0))
self._conn.commit()
return count
def record_v5_operation(
self,
*,
action: str,
target: str,
resolved_hashes: List[str],
reason: str = "",
updated_by: str = "",
result: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
operation_id = f"v5_{uuid.uuid4().hex}"
created_at = datetime.now().timestamp()
payload = {
"operation_id": operation_id,
"action": str(action or "").strip(),
"target": str(target or "").strip(),
"reason": str(reason or "").strip(),
"updated_by": str(updated_by or "").strip(),
"created_at": created_at,
"resolved_hashes": [str(item or "").strip() for item in (resolved_hashes or []) if str(item or "").strip()],
"result": result or {},
}
cursor = self._conn.cursor()
cursor.execute(
"""
INSERT INTO memory_v5_operations (
operation_id, action, target, reason, updated_by, created_at, resolved_hashes_json, result_json
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""",
(
operation_id,
payload["action"],
payload["target"] or None,
payload["reason"] or None,
payload["updated_by"] or None,
created_at,
self._json_dumps(payload["resolved_hashes"]),
self._json_dumps(payload["result"]),
),
)
self._conn.commit()
return payload
def create_delete_operation(
self,
*,
mode: str,
selector: Any,
items: List[Dict[str, Any]],
reason: str = "",
requested_by: str = "",
status: str = "executed",
summary: Optional[Dict[str, Any]] = None,
operation_id: Optional[str] = None,
) -> Dict[str, Any]:
op_id = str(operation_id or f"del_{uuid.uuid4().hex}").strip()
created_at = datetime.now().timestamp()
normalized_items: List[Dict[str, Any]] = []
for item in items or []:
if not isinstance(item, dict):
continue
item_type = str(item.get("item_type", "") or "").strip()
if not item_type:
continue
normalized_items.append(
{
"item_type": item_type,
"item_hash": str(item.get("item_hash", "") or "").strip() or None,
"item_key": str(item.get("item_key", "") or item.get("item_hash", "") or "").strip() or None,
"payload": item.get("payload") if isinstance(item.get("payload"), dict) else {},
}
)
cursor = self._conn.cursor()
cursor.execute(
"""
INSERT INTO delete_operations (
operation_id, mode, selector, reason, requested_by, status, created_at, restored_at, summary_json
) VALUES (?, ?, ?, ?, ?, ?, ?, NULL, ?)
""",
(
op_id,
str(mode or "").strip(),
self._json_dumps(selector if selector is not None else {}),
str(reason or "").strip() or None,
str(requested_by or "").strip() or None,
str(status or "executed").strip(),
created_at,
self._json_dumps(summary or {}),
),
)
if normalized_items:
cursor.executemany(
"""
INSERT INTO delete_operation_items (
operation_id, item_type, item_hash, item_key, payload_json, created_at
) VALUES (?, ?, ?, ?, ?, ?)
""",
[
(
op_id,
item["item_type"],
item["item_hash"],
item["item_key"],
self._json_dumps(item["payload"]),
created_at,
)
for item in normalized_items
],
)
self._conn.commit()
return self.get_delete_operation(op_id) or {
"operation_id": op_id,
"mode": str(mode or "").strip(),
"selector": selector,
"reason": str(reason or "").strip(),
"requested_by": str(requested_by or "").strip(),
"status": str(status or "executed").strip(),
"created_at": created_at,
"summary": summary or {},
"items": normalized_items,
}
def mark_delete_operation_restored(
self,
operation_id: str,
*,
summary: Optional[Dict[str, Any]] = None,
) -> bool:
token = str(operation_id or "").strip()
if not token:
return False
cursor = self._conn.cursor()
cursor.execute(
"""
UPDATE delete_operations
SET status = ?, restored_at = ?, summary_json = ?
WHERE operation_id = ?
""",
(
"restored",
datetime.now().timestamp(),
self._json_dumps(summary or {}),
token,
),
)
self._conn.commit()
return cursor.rowcount > 0
def list_delete_operations(self, *, limit: int = 50, mode: str = "") -> List[Dict[str, Any]]:
cursor = self._conn.cursor()
params: List[Any] = []
where = ""
mode_token = str(mode or "").strip().lower()
if mode_token:
where = "WHERE LOWER(mode) = ?"
params.append(mode_token)
params.append(max(1, int(limit or 50)))
cursor.execute(
f"""
SELECT operation_id, mode, selector, reason, requested_by, status, created_at, restored_at, summary_json
FROM delete_operations
{where}
ORDER BY created_at DESC
LIMIT ?
""",
tuple(params),
)
items: List[Dict[str, Any]] = []
for row in cursor.fetchall():
payload = dict(row)
payload["selector"] = self._json_loads(payload.get("selector"), {})
payload["summary"] = self._json_loads(payload.get("summary_json"), {})
items.append(payload)
return items
def get_delete_operation(self, operation_id: str) -> Optional[Dict[str, Any]]:
token = str(operation_id or "").strip()
if not token:
return None
cursor = self._conn.cursor()
cursor.execute(
"""
SELECT operation_id, mode, selector, reason, requested_by, status, created_at, restored_at, summary_json
FROM delete_operations
WHERE operation_id = ?
LIMIT 1
""",
(token,),
)
row = cursor.fetchone()
if row is None:
return None
payload = dict(row)
payload["selector"] = self._json_loads(payload.get("selector"), {})
payload["summary"] = self._json_loads(payload.get("summary_json"), {})
cursor.execute(
"""
SELECT item_type, item_hash, item_key, payload_json, created_at
FROM delete_operation_items
WHERE operation_id = ?
ORDER BY id ASC
""",
(token,),
)
payload["items"] = [
{
"item_type": str(item["item_type"] or ""),
"item_hash": str(item["item_hash"] or ""),
"item_key": str(item["item_key"] or ""),
"payload": self._json_loads(item["payload_json"], {}),
"created_at": item["created_at"],
}
for item in cursor.fetchall()
]
return payload
def purge_deleted_relations(self, *, cutoff_time: float, limit: int = 1000) -> List[str]:
cursor = self._conn.cursor()
cursor.execute(
"""
SELECT hash
FROM deleted_relations
WHERE deleted_at IS NOT NULL AND deleted_at < ?
ORDER BY deleted_at ASC
LIMIT ?
""",
(float(cutoff_time), max(1, int(limit or 1000))),
)
hashes = [str(row[0] or "").strip() for row in cursor.fetchall() if str(row[0] or "").strip()]
if not hashes:
return []
placeholders = ",".join(["?"] * len(hashes))
cursor.execute(f"DELETE FROM deleted_relations WHERE hash IN ({placeholders})", tuple(hashes))
self._conn.commit()
return hashes
def get_statistics(self) -> Dict[str, int]:
"""
获取统计信息
@@ -2956,6 +3394,18 @@ class MetadataStore:
self._conn.commit()
return changed
def restore_paragraph_by_hash(self, paragraph_hash: str) -> bool:
"""恢复软删除段落。"""
cursor = self._conn.cursor()
cursor.execute(
"UPDATE paragraphs SET is_deleted=0, deleted_at=NULL WHERE hash=?",
(str(paragraph_hash),),
)
changed = cursor.rowcount > 0
if changed:
self._conn.commit()
return changed
def backfill_temporal_metadata_from_created_at(
self,
*,
@@ -4698,6 +5148,29 @@ class MetadataStore:
)
self._conn.commit()
def get_episode_pending_status_counts(self, source: str) -> Dict[str, int]:
"""统计某个 source 当前 pending 队列中的状态分布。"""
token = self._normalize_episode_source(source)
if not token:
return {"pending": 0, "running": 0, "failed": 0, "done": 0}
cursor = self._conn.cursor()
cursor.execute(
"""
SELECT status, COUNT(*) AS count
FROM episode_pending_paragraphs
WHERE TRIM(COALESCE(source, '')) = ?
GROUP BY status
""",
(token,),
)
counts = {"pending": 0, "running": 0, "failed": 0, "done": 0}
for row in cursor.fetchall():
status = str(row["status"] or "").strip().lower()
if status in counts:
counts[status] = int(row["count"] or 0)
return counts
def _episode_row_to_dict(self, row: sqlite3.Row) -> Dict[str, Any]:
data = dict(row)
@@ -4904,7 +5377,7 @@ class MetadataStore:
SELECT 1
FROM episode_rebuild_sources ers
WHERE ers.source = TRIM(COALESCE(e.source, ''))
AND ers.status IN ('pending', 'running', 'failed')
AND ers.status IN ('pending', 'running')
)
"""
)
@@ -4948,6 +5421,26 @@ class MetadataStore:
return source_expr, effective_start, effective_end, conditions, params
@staticmethod
def _tokenize_episode_query(query: str) -> Tuple[str, List[str]]:
"""将 episode 查询归一化为短语和 token。"""
normalized = normalize_text(str(query or "")).strip().lower()
if not normalized:
return "", []
token_pattern = re.compile(r"[A-Za-z0-9_\u4e00-\u9fff]{2,}")
tokens: List[str] = []
seen = set()
for token in token_pattern.findall(normalized):
if token in seen:
continue
seen.add(token)
tokens.append(token)
if not tokens and len(normalized) >= 2:
tokens = [normalized]
return normalized, tokens
def get_episode_rows_by_paragraph_hashes(
self,
paragraph_hashes: List[str],
@@ -5097,28 +5590,58 @@ class MetadataStore:
source=source,
)
q = str(query or "").strip().lower()
q, tokens = self._tokenize_episode_query(query)
select_score_sql = "0.0 AS lexical_score"
order_sql = f"{effective_end} DESC, e.updated_at DESC"
select_params: List[Any] = []
query_params: List[Any] = []
if q:
like = f"%{q}%"
title_expr = "LOWER(COALESCE(e.title, '')) LIKE ?"
summary_expr = "LOWER(COALESCE(e.summary, '')) LIKE ?"
keywords_expr = "LOWER(COALESCE(e.keywords_json, '')) LIKE ?"
participants_expr = "LOWER(COALESCE(e.participants_json, '')) LIKE ?"
conditions.append(
f"({title_expr} OR {summary_expr} OR {keywords_expr} OR {participants_expr})"
field_exprs = {
"title": "LOWER(COALESCE(e.title, ''))",
"summary": "LOWER(COALESCE(e.summary, ''))",
"keywords": "LOWER(COALESCE(e.keywords_json, ''))",
"participants": "LOWER(COALESCE(e.participants_json, ''))",
}
score_parts: List[str] = []
phrase_like = f"%{q}%"
score_parts.extend(
[
f"CASE WHEN {field_exprs['title']} LIKE ? THEN 6.0 ELSE 0.0 END",
f"CASE WHEN {field_exprs['keywords']} LIKE ? THEN 4.5 ELSE 0.0 END",
f"CASE WHEN {field_exprs['summary']} LIKE ? THEN 3.0 ELSE 0.0 END",
f"CASE WHEN {field_exprs['participants']} LIKE ? THEN 2.0 ELSE 0.0 END",
]
)
select_score_sql = (
f"(CASE WHEN {title_expr} THEN 4.0 ELSE 0.0 END + "
f"CASE WHEN {keywords_expr} THEN 3.0 ELSE 0.0 END + "
f"CASE WHEN {summary_expr} THEN 2.0 ELSE 0.0 END + "
f"CASE WHEN {participants_expr} THEN 1.0 ELSE 0.0 END) AS lexical_score"
)
select_params.extend([like, like, like, like])
query_params.extend([like, like, like, like])
select_params.extend([phrase_like, phrase_like, phrase_like, phrase_like])
token_predicates: List[str] = []
for token in tokens:
like = f"%{token}%"
token_any = (
f"({field_exprs['title']} LIKE ? OR "
f"{field_exprs['summary']} LIKE ? OR "
f"{field_exprs['keywords']} LIKE ? OR "
f"{field_exprs['participants']} LIKE ?)"
)
token_predicates.append(token_any)
query_params.extend([like, like, like, like])
score_parts.append(
"("
f"CASE WHEN {field_exprs['title']} LIKE ? THEN 3.0 ELSE 0.0 END + "
f"CASE WHEN {field_exprs['keywords']} LIKE ? THEN 2.5 ELSE 0.0 END + "
f"CASE WHEN {field_exprs['summary']} LIKE ? THEN 2.0 ELSE 0.0 END + "
f"CASE WHEN {field_exprs['participants']} LIKE ? THEN 1.5 ELSE 0.0 END + "
f"CASE WHEN {token_any.replace('?', '?')} THEN 2.0 ELSE 0.0 END"
")"
)
select_params.extend([like, like, like, like, like, like, like, like])
if token_predicates:
conditions.append("(" + " OR ".join(token_predicates) + ")")
select_score_sql = f"({' + '.join(score_parts)}) AS lexical_score"
order_sql = f"lexical_score DESC, {effective_end} DESC, e.updated_at DESC"
where_sql = ("WHERE " + " AND ".join(conditions)) if conditions else ""

View File

@@ -302,7 +302,7 @@ class AggregateQueryService:
)
for (branch_name, _), payload in zip(scheduled, done):
if isinstance(payload, Exception):
logger.error("aggregate branch failed: branch=%s error=%s", branch_name, payload)
logger.error(f"aggregate branch failed: branch={branch_name} error={payload}")
normalized = self._normalize_branch_payload(
branch_name,
{

View File

@@ -70,7 +70,7 @@ class EpisodeRetrievalService:
temporal=temporal,
)
except Exception as exc:
logger.warning("episode evidence retrieval failed, fallback to lexical only: %s", exc)
logger.warning(f"episode evidence retrieval failed, fallback to lexical only: {exc}")
else:
paragraph_rank_map: Dict[str, int] = {}
relation_rank_map: Dict[str, int] = {}

View File

@@ -0,0 +1,304 @@
"""
Episode 语义切分服务LLM 主路径)。
职责:
1. 组装语义切分提示词
2. 调用 LLM 生成结构化 episode JSON
3. 严格校验输出结构,返回标准化结果
"""
from __future__ import annotations
import json
from typing import Any, Dict, List, Optional, Tuple
from src.common.logger import get_logger
from src.config.model_configs import TaskConfig
from src.config.config import model_config as host_model_config
from src.services import llm_service as llm_api
logger = get_logger("A_Memorix.EpisodeSegmentationService")
class EpisodeSegmentationService:
"""基于 LLM 的 episode 语义切分服务。"""
SEGMENTATION_VERSION = "episode_mvp_v1"
def __init__(self, plugin_config: Optional[dict] = None):
self.plugin_config = plugin_config or {}
def _cfg(self, key: str, default: Any = None) -> Any:
current: Any = self.plugin_config
for part in key.split("."):
if isinstance(current, dict) and part in current:
current = current[part]
else:
return default
return current
@staticmethod
def _is_task_config(obj: Any) -> bool:
return hasattr(obj, "model_list") and bool(getattr(obj, "model_list", []))
def _build_single_model_task(self, model_name: str, template: TaskConfig) -> TaskConfig:
return TaskConfig(
model_list=[model_name],
max_tokens=template.max_tokens,
temperature=template.temperature,
slow_threshold=template.slow_threshold,
selection_strategy=template.selection_strategy,
)
def _pick_template_task(self, available_tasks: Dict[str, Any]) -> Optional[TaskConfig]:
preferred = ("utils", "replyer", "planner", "tool_use")
for task_name in preferred:
cfg = available_tasks.get(task_name)
if self._is_task_config(cfg):
return cfg
for task_name, cfg in available_tasks.items():
if task_name != "embedding" and self._is_task_config(cfg):
return cfg
for cfg in available_tasks.values():
if self._is_task_config(cfg):
return cfg
return None
def _resolve_model_config(self) -> Tuple[Optional[Any], str]:
available_tasks = llm_api.get_available_models() or {}
if not available_tasks:
return None, "unavailable"
selector = str(self._cfg("episode.segmentation_model", "auto") or "auto").strip()
model_dict = getattr(host_model_config, "models_dict", {}) or {}
if selector and selector.lower() != "auto":
direct_task = available_tasks.get(selector)
if self._is_task_config(direct_task):
return direct_task, selector
if selector in model_dict:
template = self._pick_template_task(available_tasks)
if template is not None:
return self._build_single_model_task(selector, template), selector
logger.warning(f"episode.segmentation_model='{selector}' 不可用,回退 auto")
for task_name in ("utils", "replyer", "planner", "tool_use"):
cfg = available_tasks.get(task_name)
if self._is_task_config(cfg):
return cfg, task_name
fallback = self._pick_template_task(available_tasks)
if fallback is not None:
return fallback, "auto"
return None, "unavailable"
@staticmethod
def _clamp_score(value: Any, default: float = 0.0) -> float:
try:
num = float(value)
except Exception:
num = default
if num < 0.0:
return 0.0
if num > 1.0:
return 1.0
return num
@staticmethod
def _safe_json_loads(text: str) -> Dict[str, Any]:
raw = str(text or "").strip()
if not raw:
raise ValueError("empty_response")
if "```" in raw:
raw = raw.replace("```json", "```").replace("```JSON", "```")
parts = raw.split("```")
for part in parts:
part = part.strip()
if part.startswith("{") and part.endswith("}"):
raw = part
break
try:
data = json.loads(raw)
if isinstance(data, dict):
return data
except Exception:
pass
start = raw.find("{")
end = raw.rfind("}")
if start >= 0 and end > start:
candidate = raw[start : end + 1]
data = json.loads(candidate)
if isinstance(data, dict):
return data
raise ValueError("invalid_json_response")
def _build_prompt(
self,
*,
source: str,
window_start: Optional[float],
window_end: Optional[float],
paragraphs: List[Dict[str, Any]],
) -> str:
rows: List[str] = []
for idx, item in enumerate(paragraphs, 1):
p_hash = str(item.get("hash", "") or "").strip()
content = str(item.get("content", "") or "").strip().replace("\r\n", "\n")
content = content[:800]
event_start = item.get("event_time_start")
event_end = item.get("event_time_end")
event_time = item.get("event_time")
rows.append(
(
f"[{idx}] hash={p_hash}\n"
f"event_time={event_time}\n"
f"event_time_start={event_start}\n"
f"event_time_end={event_end}\n"
f"content={content}"
)
)
source_text = str(source or "").strip() or "unknown"
return (
"You are an episode segmentation engine.\n"
"Group the given paragraphs into one or more coherent episodes.\n"
"Return JSON ONLY. No markdown, no explanation.\n"
"\n"
"Hard JSON schema:\n"
"{\n"
' "episodes": [\n'
" {\n"
' "title": "string",\n'
' "summary": "string",\n'
' "paragraph_hashes": ["hash1", "hash2"],\n'
' "participants": ["person1", "person2"],\n'
' "keywords": ["kw1", "kw2"],\n'
' "time_confidence": 0.0,\n'
' "llm_confidence": 0.0\n'
" }\n"
" ]\n"
"}\n"
"\n"
"Rules:\n"
"1) paragraph_hashes must come from input only.\n"
"2) title and summary must be non-empty.\n"
"3) keep participants/keywords concise and deduplicated.\n"
"4) if uncertain, still provide best effort confidence values.\n"
"\n"
f"source={source_text}\n"
f"window_start={window_start}\n"
f"window_end={window_end}\n"
"paragraphs:\n"
+ "\n\n".join(rows)
)
def _normalize_episodes(
self,
*,
payload: Dict[str, Any],
input_hashes: List[str],
) -> List[Dict[str, Any]]:
raw_episodes = payload.get("episodes")
if not isinstance(raw_episodes, list):
raise ValueError("episodes_missing_or_not_list")
valid_hashes = set(input_hashes)
normalized: List[Dict[str, Any]] = []
for item in raw_episodes:
if not isinstance(item, dict):
continue
title = str(item.get("title", "") or "").strip()
summary = str(item.get("summary", "") or "").strip()
if not title or not summary:
continue
raw_hashes = item.get("paragraph_hashes")
if not isinstance(raw_hashes, list):
continue
dedup_hashes: List[str] = []
seen_hashes = set()
for h in raw_hashes:
token = str(h or "").strip()
if not token or token in seen_hashes or token not in valid_hashes:
continue
seen_hashes.add(token)
dedup_hashes.append(token)
if not dedup_hashes:
continue
participants = []
for p in item.get("participants", []) or []:
token = str(p or "").strip()
if token:
participants.append(token)
keywords = []
for kw in item.get("keywords", []) or []:
token = str(kw or "").strip()
if token:
keywords.append(token)
normalized.append(
{
"title": title,
"summary": summary,
"paragraph_hashes": dedup_hashes,
"participants": participants[:16],
"keywords": keywords[:20],
"time_confidence": self._clamp_score(item.get("time_confidence"), default=1.0),
"llm_confidence": self._clamp_score(item.get("llm_confidence"), default=0.5),
}
)
if not normalized:
raise ValueError("episodes_all_invalid")
return normalized
async def segment(
self,
*,
source: str,
window_start: Optional[float],
window_end: Optional[float],
paragraphs: List[Dict[str, Any]],
) -> Dict[str, Any]:
if not paragraphs:
raise ValueError("paragraphs_empty")
model_config, model_label = self._resolve_model_config()
if model_config is None:
raise RuntimeError("episode segmentation model unavailable")
prompt = self._build_prompt(
source=source,
window_start=window_start,
window_end=window_end,
paragraphs=paragraphs,
)
success, response, _, _ = await llm_api.generate_with_model(
prompt=prompt,
model_config=model_config,
request_type="A_Memorix.EpisodeSegmentation",
)
if not success or not response:
raise RuntimeError("llm_generate_failed")
payload = self._safe_json_loads(str(response))
input_hashes = [str(p.get("hash", "") or "").strip() for p in paragraphs]
episodes = self._normalize_episodes(payload=payload, input_hashes=input_hashes)
return {
"episodes": episodes,
"segmentation_model": model_label,
"segmentation_version": self.SEGMENTATION_VERSION,
}

View File

@@ -0,0 +1,558 @@
"""
Episode 聚合与落库服务。
流程:
1. 从 pending 队列读取段落并组批
2. 按 source + 时间窗口切组
3. 调用 LLM 语义切分
4. 写入 episodes + episode_paragraphs
5. LLM 失败时使用确定性 fallback
"""
from __future__ import annotations
import json
import re
from collections import Counter
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
from src.common.logger import get_logger
from .episode_segmentation_service import EpisodeSegmentationService
from .hash import compute_hash
logger = get_logger("A_Memorix.EpisodeService")
class EpisodeService:
"""Episode MVP 后台处理服务。"""
def __init__(
self,
*,
metadata_store: Any,
plugin_config: Optional[Any] = None,
segmentation_service: Optional[EpisodeSegmentationService] = None,
):
self.metadata_store = metadata_store
self.plugin_config = plugin_config or {}
self.segmentation_service = segmentation_service or EpisodeSegmentationService(
plugin_config=self._config_dict(),
)
def _config_dict(self) -> Dict[str, Any]:
if isinstance(self.plugin_config, dict):
return self.plugin_config
return {}
def _cfg(self, key: str, default: Any = None) -> Any:
getter = getattr(self.plugin_config, "get_config", None)
if callable(getter):
return getter(key, default)
current: Any = self.plugin_config
for part in key.split("."):
if isinstance(current, dict) and part in current:
current = current[part]
else:
return default
return current
@staticmethod
def _to_optional_float(value: Any) -> Optional[float]:
if value is None:
return None
try:
return float(value)
except Exception:
return None
@staticmethod
def _clamp_score(value: Any, default: float = 1.0) -> float:
try:
num = float(value)
except Exception:
num = default
if num < 0.0:
return 0.0
if num > 1.0:
return 1.0
return num
@staticmethod
def _paragraph_anchor(paragraph: Dict[str, Any]) -> float:
for key in ("event_time_end", "event_time_start", "event_time", "created_at"):
value = paragraph.get(key)
try:
if value is not None:
return float(value)
except Exception:
continue
return 0.0
@staticmethod
def _paragraph_sort_key(paragraph: Dict[str, Any]) -> Tuple[float, str]:
return (
EpisodeService._paragraph_anchor(paragraph),
str(paragraph.get("hash", "") or ""),
)
def load_pending_paragraphs(
self,
pending_rows: List[Dict[str, Any]],
) -> Tuple[List[Dict[str, Any]], List[str]]:
"""
将 pending 行展开为段落上下文。
Returns:
(loaded_paragraphs, missing_hashes)
"""
loaded: List[Dict[str, Any]] = []
missing: List[str] = []
for row in pending_rows or []:
p_hash = str(row.get("paragraph_hash", "") or "").strip()
if not p_hash:
continue
paragraph = self.metadata_store.get_paragraph(p_hash)
if not paragraph:
missing.append(p_hash)
continue
loaded.append(
{
"hash": p_hash,
"source": str(row.get("source") or paragraph.get("source") or "").strip(),
"content": str(paragraph.get("content", "") or ""),
"created_at": self._to_optional_float(paragraph.get("created_at"))
or self._to_optional_float(row.get("created_at"))
or 0.0,
"event_time": self._to_optional_float(paragraph.get("event_time")),
"event_time_start": self._to_optional_float(paragraph.get("event_time_start")),
"event_time_end": self._to_optional_float(paragraph.get("event_time_end")),
"time_granularity": str(paragraph.get("time_granularity", "") or "").strip() or None,
"time_confidence": self._clamp_score(paragraph.get("time_confidence"), default=1.0),
}
)
return loaded, missing
def group_paragraphs(self, paragraphs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
按 source + 时间邻近窗口组批,并受段落数/字符数上限约束。
"""
if not paragraphs:
return []
max_paragraphs = max(1, int(self._cfg("episode.max_paragraphs_per_call", 20)))
max_chars = max(200, int(self._cfg("episode.max_chars_per_call", 6000)))
window_seconds = max(
60.0,
float(self._cfg("episode.source_time_window_hours", 24)) * 3600.0,
)
by_source: Dict[str, List[Dict[str, Any]]] = {}
for paragraph in paragraphs:
source = str(paragraph.get("source", "") or "").strip()
by_source.setdefault(source, []).append(paragraph)
groups: List[Dict[str, Any]] = []
for source, items in by_source.items():
ordered = sorted(items, key=self._paragraph_sort_key)
current: List[Dict[str, Any]] = []
current_chars = 0
last_anchor: Optional[float] = None
def flush() -> None:
nonlocal current, current_chars, last_anchor
if not current:
return
sorted_current = sorted(current, key=self._paragraph_sort_key)
groups.append(
{
"source": source,
"paragraphs": sorted_current,
}
)
current = []
current_chars = 0
last_anchor = None
for paragraph in ordered:
anchor = self._paragraph_anchor(paragraph)
content_len = len(str(paragraph.get("content", "") or ""))
need_flush = False
if current:
if len(current) >= max_paragraphs:
need_flush = True
elif current_chars + content_len > max_chars:
need_flush = True
elif last_anchor is not None and abs(anchor - last_anchor) > window_seconds:
need_flush = True
if need_flush:
flush()
current.append(paragraph)
current_chars += content_len
last_anchor = anchor
flush()
groups.sort(
key=lambda g: self._paragraph_anchor(g["paragraphs"][0]) if g.get("paragraphs") else 0.0
)
return groups
def _compute_time_meta(self, paragraphs: List[Dict[str, Any]]) -> Tuple[Optional[float], Optional[float], Optional[str], float]:
starts: List[float] = []
ends: List[float] = []
granularity_priority = {
"minute": 4,
"hour": 3,
"day": 2,
"month": 1,
"year": 0,
}
granularity = None
granularity_rank = -1
conf_values: List[float] = []
for p in paragraphs:
s = self._to_optional_float(p.get("event_time_start"))
e = self._to_optional_float(p.get("event_time_end"))
t = self._to_optional_float(p.get("event_time"))
c = self._to_optional_float(p.get("created_at"))
start_candidate = s if s is not None else (t if t is not None else (e if e is not None else c))
end_candidate = e if e is not None else (t if t is not None else (s if s is not None else c))
if start_candidate is not None:
starts.append(start_candidate)
if end_candidate is not None:
ends.append(end_candidate)
g = str(p.get("time_granularity", "") or "").strip().lower()
if g in granularity_priority and granularity_priority[g] > granularity_rank:
granularity_rank = granularity_priority[g]
granularity = g
conf_values.append(self._clamp_score(p.get("time_confidence"), default=1.0))
time_start = min(starts) if starts else None
time_end = max(ends) if ends else None
time_conf = sum(conf_values) / len(conf_values) if conf_values else 1.0
return time_start, time_end, granularity, self._clamp_score(time_conf, default=1.0)
def _collect_participants(self, paragraph_hashes: List[str], limit: int = 16) -> List[str]:
seen = set()
participants: List[str] = []
for p_hash in paragraph_hashes:
try:
entities = self.metadata_store.get_paragraph_entities(p_hash)
except Exception:
entities = []
for item in entities:
name = str(item.get("name", "") or "").strip()
if not name:
continue
key = name.lower()
if key in seen:
continue
seen.add(key)
participants.append(name)
if len(participants) >= limit:
return participants
return participants
@staticmethod
def _derive_keywords(paragraphs: List[Dict[str, Any]], limit: int = 12) -> List[str]:
token_counter: Counter[str] = Counter()
token_pattern = re.compile(r"[A-Za-z0-9_\u4e00-\u9fff]{2,}")
stop_words = {
"the",
"and",
"that",
"this",
"with",
"from",
"for",
"have",
"will",
"your",
"you",
"我们",
"你们",
"他们",
"以及",
"一个",
"这个",
"那个",
"然后",
"因为",
"所以",
}
for p in paragraphs:
text = str(p.get("content", "") or "").lower()
for token in token_pattern.findall(text):
if token in stop_words:
continue
token_counter[token] += 1
return [token for token, _ in token_counter.most_common(limit)]
def _build_fallback_episode(self, group: Dict[str, Any]) -> Dict[str, Any]:
paragraphs = group.get("paragraphs", []) or []
source = str(group.get("source", "") or "").strip()
hashes = [str(p.get("hash", "") or "").strip() for p in paragraphs if str(p.get("hash", "") or "").strip()]
snippets = []
for p in paragraphs[:3]:
text = str(p.get("content", "") or "").strip().replace("\n", " ")
if text:
snippets.append(text[:140])
summary = "".join(snippets)[:500] if snippets else "自动回退生成的情景记忆。"
time_start, time_end, granularity, time_conf = self._compute_time_meta(paragraphs)
participants = self._collect_participants(hashes, limit=12)
keywords = self._derive_keywords(paragraphs, limit=10)
if time_start is not None:
day_text = datetime.fromtimestamp(time_start).strftime("%Y-%m-%d")
title = f"{source or 'unknown'} {day_text} 情景片段"
else:
title = f"{source or 'unknown'} 情景片段"
return {
"title": title[:80],
"summary": summary,
"paragraph_hashes": hashes,
"participants": participants,
"keywords": keywords,
"time_confidence": time_conf,
"llm_confidence": 0.0,
"event_time_start": time_start,
"event_time_end": time_end,
"time_granularity": granularity,
"segmentation_model": "fallback_rule",
"segmentation_version": EpisodeSegmentationService.SEGMENTATION_VERSION,
}
@staticmethod
def _normalize_episode_hashes(episode_hashes: List[str], group_hashes_ordered: List[str]) -> List[str]:
in_group = set(group_hashes_ordered)
dedup: List[str] = []
seen = set()
for h in episode_hashes or []:
token = str(h or "").strip()
if not token or token not in in_group or token in seen:
continue
seen.add(token)
dedup.append(token)
return dedup
async def _build_episode_payloads_for_group(self, group: Dict[str, Any]) -> Dict[str, Any]:
paragraphs = group.get("paragraphs", []) or []
if not paragraphs:
return {
"payloads": [],
"done_hashes": [],
"episode_count": 0,
"fallback_count": 0,
}
source = str(group.get("source", "") or "").strip()
group_hashes = [str(p.get("hash", "") or "").strip() for p in paragraphs if str(p.get("hash", "") or "").strip()]
group_start, group_end, _, _ = self._compute_time_meta(paragraphs)
fallback_used = False
segmentation_model = "fallback_rule"
segmentation_version = EpisodeSegmentationService.SEGMENTATION_VERSION
try:
llm_result = await self.segmentation_service.segment(
source=source,
window_start=group_start,
window_end=group_end,
paragraphs=paragraphs,
)
episodes = list(llm_result.get("episodes") or [])
segmentation_model = str(llm_result.get("segmentation_model", "") or "").strip() or "auto"
segmentation_version = str(llm_result.get("segmentation_version", "") or "").strip() or EpisodeSegmentationService.SEGMENTATION_VERSION
if not episodes:
raise ValueError("llm_empty_episodes")
except Exception as e:
logger.warning(
"Episode segmentation fallback: "
f"source={source} "
f"size={len(group_hashes)} "
f"err={e}"
)
episodes = [self._build_fallback_episode(group)]
fallback_used = True
stored_payloads: List[Dict[str, Any]] = []
for episode in episodes:
ordered_hashes = self._normalize_episode_hashes(
episode_hashes=episode.get("paragraph_hashes", []),
group_hashes_ordered=group_hashes,
)
if not ordered_hashes:
continue
sub_paragraphs = [p for p in paragraphs if str(p.get("hash", "") or "") in set(ordered_hashes)]
event_start, event_end, granularity, time_conf_default = self._compute_time_meta(sub_paragraphs)
participants = [str(x).strip() for x in (episode.get("participants", []) or []) if str(x).strip()]
keywords = [str(x).strip() for x in (episode.get("keywords", []) or []) if str(x).strip()]
if not participants:
participants = self._collect_participants(ordered_hashes, limit=16)
if not keywords:
keywords = self._derive_keywords(sub_paragraphs, limit=12)
title = str(episode.get("title", "") or "").strip()[:120]
summary = str(episode.get("summary", "") or "").strip()[:2000]
if not title or not summary:
continue
seed = json.dumps(
{
"source": source,
"hashes": ordered_hashes,
"version": segmentation_version,
},
ensure_ascii=False,
sort_keys=True,
)
episode_id = compute_hash(seed)
payload = {
"episode_id": episode_id,
"source": source or None,
"title": title,
"summary": summary,
"event_time_start": episode.get("event_time_start", event_start),
"event_time_end": episode.get("event_time_end", event_end),
"time_granularity": episode.get("time_granularity", granularity),
"time_confidence": self._clamp_score(
episode.get("time_confidence"),
default=time_conf_default,
),
"participants": participants[:16],
"keywords": keywords[:20],
"evidence_ids": ordered_hashes,
"paragraph_count": len(ordered_hashes),
"llm_confidence": self._clamp_score(
episode.get("llm_confidence"),
default=0.0 if fallback_used else 0.6,
),
"segmentation_model": (
str(episode.get("segmentation_model", "") or "").strip()
or ("fallback_rule" if fallback_used else segmentation_model)
),
"segmentation_version": (
str(episode.get("segmentation_version", "") or "").strip()
or segmentation_version
),
}
stored_payloads.append(payload)
return {
"payloads": stored_payloads,
"done_hashes": group_hashes,
"episode_count": len(stored_payloads),
"fallback_count": 1 if fallback_used else 0,
}
async def process_group(self, group: Dict[str, Any]) -> Dict[str, Any]:
result = await self._build_episode_payloads_for_group(group)
stored_count = 0
for payload in result.get("payloads") or []:
stored = self.metadata_store.upsert_episode(payload)
final_id = str(stored.get("episode_id") or payload.get("episode_id") or "")
if final_id:
self.metadata_store.bind_episode_paragraphs(
final_id,
list(payload.get("evidence_ids") or []),
)
stored_count += 1
result["episode_count"] = stored_count
return {
"done_hashes": list(result.get("done_hashes") or []),
"episode_count": stored_count,
"fallback_count": int(result.get("fallback_count") or 0),
}
async def process_pending_rows(self, pending_rows: List[Dict[str, Any]]) -> Dict[str, Any]:
loaded, missing_hashes = self.load_pending_paragraphs(pending_rows)
groups = self.group_paragraphs(loaded)
done_hashes: List[str] = list(missing_hashes)
failed_hashes: Dict[str, str] = {}
episode_count = 0
fallback_count = 0
for group in groups:
group_hashes = [str(p.get("hash", "") or "").strip() for p in (group.get("paragraphs") or [])]
try:
result = await self.process_group(group)
done_hashes.extend(result.get("done_hashes") or [])
episode_count += int(result.get("episode_count") or 0)
fallback_count += int(result.get("fallback_count") or 0)
except Exception as e:
err = str(e)[:500]
for h in group_hashes:
if h:
failed_hashes[h] = err
dedup_done = list(dict.fromkeys([h for h in done_hashes if h]))
return {
"done_hashes": dedup_done,
"failed_hashes": failed_hashes,
"episode_count": episode_count,
"fallback_count": fallback_count,
"missing_count": len(missing_hashes),
"group_count": len(groups),
}
async def rebuild_source(self, source: str) -> Dict[str, Any]:
token = str(source or "").strip()
if not token:
return {
"source": "",
"episode_count": 0,
"fallback_count": 0,
"group_count": 0,
"paragraph_count": 0,
}
paragraphs = self.metadata_store.get_live_paragraphs_by_source(token)
if not paragraphs:
replace_result = self.metadata_store.replace_episodes_for_source(token, [])
return {
"source": token,
"episode_count": int(replace_result.get("episode_count") or 0),
"fallback_count": 0,
"group_count": 0,
"paragraph_count": 0,
}
groups = self.group_paragraphs(paragraphs)
payloads: List[Dict[str, Any]] = []
fallback_count = 0
for group in groups:
result = await self._build_episode_payloads_for_group(group)
payloads.extend(list(result.get("payloads") or []))
fallback_count += int(result.get("fallback_count") or 0)
replace_result = self.metadata_store.replace_episodes_for_source(token, payloads)
return {
"source": token,
"episode_count": int(replace_result.get("episode_count") or 0),
"fallback_count": fallback_count,
"group_count": len(groups),
"paragraph_count": len(paragraphs),
}

View File

@@ -9,7 +9,11 @@ import json
import time
from typing import Any, Dict, List, Optional, Tuple
from sqlalchemy import or_
from sqlmodel import select
from src.common.logger import get_logger
from src.common.database.database import get_db_session
from src.common.database.database_model import PersonInfo
from ..embedding import EmbeddingAPIAdapter
@@ -120,31 +124,40 @@ class PersonProfileService:
if not key:
return ""
try:
with get_db_session(auto_commit=False) as session:
record = session.exec(
select(PersonInfo.person_id).where(PersonInfo.person_id == key).limit(1)
).first()
if record:
return str(record)
record = session.exec(
select(PersonInfo.person_id)
.where(
or_(
PersonInfo.person_name == key,
PersonInfo.user_nickname == key,
)
)
.limit(1)
).first()
if record:
return str(record)
record = session.exec(
select(PersonInfo.person_id)
.where(PersonInfo.group_cardname.contains(key))
.limit(1)
).first()
if record:
return str(record)
except Exception as e:
logger.warning(f"按别名解析 person_id 失败: identifier={key}, err={e}")
if len(key) == 32 and all(ch in "0123456789abcdefABCDEF" for ch in key):
return key.lower()
try:
record = (
PersonInfo.select(PersonInfo.person_id)
.where((PersonInfo.person_name == key) | (PersonInfo.nickname == key))
.first()
)
if record and record.person_id:
return str(record.person_id)
except Exception:
pass
try:
record = (
PersonInfo.select(PersonInfo.person_id)
.where(PersonInfo.group_nick_name.contains(key))
.first()
)
if record and record.person_id:
return str(record.person_id)
except Exception:
pass
return ""
def _parse_group_nicks(self, raw_value: Any) -> List[str]:
@@ -160,7 +173,7 @@ class PersonProfileService:
names: List[str] = []
for item in items:
if isinstance(item, dict):
value = str(item.get("group_nick_name", "")).strip()
value = str(item.get("group_cardname") or item.get("group_nick_name") or "").strip()
if value:
names.append(value)
elif isinstance(item, str):
@@ -193,6 +206,42 @@ class PersonProfileService:
traits.append(text)
return traits[:10]
def _recover_aliases_from_memory(self, person_id: str) -> Tuple[List[str], str]:
"""当人物主档案缺失时,从已有记忆证据里回捞可用别名。"""
if not person_id:
return [], ""
aliases: List[str] = []
primary_name = ""
seen = set()
try:
paragraphs = self.metadata_store.get_paragraphs_by_entity(person_id)
except Exception as e:
logger.warning(f"从记忆证据回捞人物别名失败: person_id={person_id}, err={e}")
return [], ""
for paragraph in paragraphs[:20]:
paragraph_hash = str(paragraph.get("hash", "") or "").strip()
if not paragraph_hash:
continue
try:
paragraph_entities = self.metadata_store.get_paragraph_entities(paragraph_hash)
except Exception:
paragraph_entities = []
for entity in paragraph_entities:
name = str(entity.get("name", "") or "").strip()
if not name or name == person_id:
continue
key = name.lower()
if key in seen:
continue
seen.add(key)
aliases.append(name)
if not primary_name:
primary_name = name
return aliases, primary_name
def get_person_aliases(self, person_id: str) -> Tuple[List[str], str, List[str]]:
"""获取人物别名集合、主展示名、记忆特征。"""
aliases: List[str] = []
@@ -200,18 +249,28 @@ class PersonProfileService:
memory_traits: List[str] = []
if not person_id:
return aliases, primary_name, memory_traits
recovered_aliases, recovered_primary_name = self._recover_aliases_from_memory(person_id)
try:
record = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
if not record:
return aliases, primary_name, memory_traits
with get_db_session(auto_commit=False) as session:
record = session.exec(
select(PersonInfo).where(PersonInfo.person_id == person_id).limit(1)
).first()
if not record:
return recovered_aliases, recovered_primary_name or person_id, memory_traits
person_name = str(getattr(record, "person_name", "") or "").strip()
nickname = str(getattr(record, "nickname", "") or "").strip()
group_nicks = self._parse_group_nicks(getattr(record, "group_nick_name", None))
nickname = str(getattr(record, "user_nickname", "") or "").strip()
group_nicks = self._parse_group_nicks(getattr(record, "group_cardname", None))
memory_traits = self._parse_memory_traits(getattr(record, "memory_points", None))
primary_name = person_name or nickname or str(getattr(record, "user_id", "") or "").strip() or person_id
primary_name = (
person_name
or nickname
or recovered_primary_name
or str(getattr(record, "user_id", "") or "").strip()
or person_id
)
candidates = [person_name, nickname] + group_nicks
candidates = [person_name, nickname] + group_nicks + recovered_aliases
seen = set()
for item in candidates:
norm = str(item or "").strip()

View File

@@ -82,8 +82,9 @@ class RelationWriteService:
)
self.metadata_store.set_relation_vector_state(hash_value, "ready")
logger.info(
"metric.relation_vector_write_success=1 metric.relation_vector_write_success_count=1 hash=%s",
hash_value[:16],
"metric.relation_vector_write_success=1 "
"metric.relation_vector_write_success_count=1 "
f"hash={hash_value[:16]}"
)
return RelationWriteResult(
hash_value=hash_value,
@@ -109,9 +110,10 @@ class RelationWriteService:
bump_retry=True,
)
logger.warning(
"metric.relation_vector_write_fail=1 metric.relation_vector_write_fail_count=1 hash=%s err=%s",
hash_value[:16],
err,
"metric.relation_vector_write_fail=1 "
"metric.relation_vector_write_fail_count=1 "
f"hash={hash_value[:16]} "
f"err={err}"
)
return RelationWriteResult(
hash_value=hash_value,

File diff suppressed because it is too large Load Diff

View File

@@ -61,6 +61,29 @@ def _build_report(
}
def _normalize_encoded_vector(encoded: Any) -> np.ndarray:
if encoded is None:
raise ValueError("embedding encode returned None")
if isinstance(encoded, np.ndarray):
array = encoded
else:
array = np.asarray(encoded, dtype=np.float32)
if array.ndim == 2:
if array.shape[0] != 1:
raise ValueError(f"embedding encode returned batched output: shape={tuple(array.shape)}")
array = array[0]
if array.ndim != 1:
raise ValueError(f"embedding encode returned invalid ndim={array.ndim}")
if array.size <= 0:
raise ValueError("embedding encode returned empty vector")
if not np.all(np.isfinite(array)):
raise ValueError("embedding encode returned non-finite values")
return array.astype(np.float32, copy=False)
async def run_embedding_runtime_self_check(
*,
config: Any,
@@ -91,13 +114,11 @@ async def run_embedding_runtime_self_check(
try:
detected_dimension = _safe_int(await embedding_manager._detect_dimension(), 0)
encoded = await embedding_manager.encode(sample_text)
if isinstance(encoded, np.ndarray):
encoded_dimension = int(encoded.shape[0]) if encoded.ndim == 1 else int(encoded.shape[-1])
else:
encoded_dimension = len(encoded) if encoded is not None else 0
encoded_array = _normalize_encoded_vector(encoded)
encoded_dimension = int(encoded_array.shape[0])
except Exception as exc:
elapsed_ms = (time.perf_counter() - start) * 1000.0
logger.warning("embedding runtime self-check failed: %s", exc)
logger.warning(f"embedding runtime self-check failed: {exc}")
return _build_report(
ok=False,
code="embedding_probe_failed",

View File

@@ -0,0 +1,442 @@
"""
统一检索执行服务。
用于收敛 Action/Tool 在 search/time 上的核心执行流程,避免重复实现。
"""
from __future__ import annotations
import hashlib
import json
import time
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple
from src.common.logger import get_logger
from ..retrieval import TemporalQueryOptions
from .search_postprocess import (
apply_safe_content_dedup,
maybe_apply_smart_path_fallback,
)
from .time_parser import parse_query_time_range
logger = get_logger("A_Memorix.SearchExecutionService")
def _get_config_value(config: Optional[dict], key: str, default: Any = None) -> Any:
if not isinstance(config, dict):
return default
current: Any = config
for part in key.split("."):
if isinstance(current, dict) and part in current:
current = current[part]
else:
return default
return current
def _sanitize_text(value: Any) -> str:
if value is None:
return ""
return str(value).strip()
@dataclass
class SearchExecutionRequest:
caller: str
stream_id: Optional[str] = None
group_id: Optional[str] = None
user_id: Optional[str] = None
query_type: str = "search" # search|semantic|time|hybrid
query: str = ""
top_k: Optional[int] = None
time_from: Optional[str] = None
time_to: Optional[str] = None
person: Optional[str] = None
source: Optional[str] = None
use_threshold: bool = True
enable_ppr: bool = True
@dataclass
class SearchExecutionResult:
success: bool
error: str = ""
query_type: str = "search"
query: str = ""
top_k: int = 10
time_from: Optional[str] = None
time_to: Optional[str] = None
person: Optional[str] = None
source: Optional[str] = None
temporal: Optional[TemporalQueryOptions] = None
results: List[Any] = field(default_factory=list)
elapsed_ms: float = 0.0
chat_filtered: bool = False
dedup_hit: bool = False
@property
def count(self) -> int:
return len(self.results)
class SearchExecutionService:
"""统一检索执行服务。"""
@staticmethod
def _resolve_plugin_instance(plugin_config: Optional[dict]) -> Optional[Any]:
if isinstance(plugin_config, dict):
plugin_instance = plugin_config.get("plugin_instance")
if plugin_instance is not None:
return plugin_instance
try:
from ...plugin import AMemorixPlugin
return getattr(AMemorixPlugin, "get_global_instance", lambda: None)()
except Exception:
return None
@staticmethod
def _normalize_query_type(raw_query_type: str) -> str:
query_type = _sanitize_text(raw_query_type).lower() or "search"
if query_type == "semantic":
return "search"
return query_type
@staticmethod
def _resolve_runtime_component(
plugin_config: Optional[dict],
plugin_instance: Optional[Any],
key: str,
) -> Optional[Any]:
if isinstance(plugin_config, dict):
value = plugin_config.get(key)
if value is not None:
return value
if plugin_instance is not None:
value = getattr(plugin_instance, key, None)
if value is not None:
return value
return None
@staticmethod
def _resolve_top_k(
plugin_config: Optional[dict],
query_type: str,
top_k_raw: Optional[Any],
) -> Tuple[bool, int, str]:
temporal_default_top_k = int(
_get_config_value(plugin_config, "retrieval.temporal.default_top_k", 10)
)
default_top_k = temporal_default_top_k if query_type in {"time", "hybrid"} else 10
if top_k_raw is None:
return True, max(1, min(50, default_top_k)), ""
try:
top_k = int(top_k_raw)
except (TypeError, ValueError):
return False, 0, "top_k 参数必须为整数"
return True, max(1, min(50, top_k)), ""
@staticmethod
def _build_temporal(
plugin_config: Optional[dict],
query_type: str,
time_from_raw: Optional[str],
time_to_raw: Optional[str],
person: Optional[str],
source: Optional[str],
) -> Tuple[bool, Optional[TemporalQueryOptions], str]:
if query_type not in {"time", "hybrid"}:
return True, None, ""
temporal_enabled = bool(_get_config_value(plugin_config, "retrieval.temporal.enabled", True))
if not temporal_enabled:
return False, None, "时序检索已禁用retrieval.temporal.enabled=false"
if not time_from_raw and not time_to_raw:
return False, None, "time/hybrid 模式至少需要 time_from 或 time_to"
try:
ts_from, ts_to = parse_query_time_range(
str(time_from_raw) if time_from_raw is not None else None,
str(time_to_raw) if time_to_raw is not None else None,
)
except ValueError as e:
return False, None, f"时间参数错误: {e}"
temporal = TemporalQueryOptions(
time_from=ts_from,
time_to=ts_to,
person=_sanitize_text(person) or None,
source=_sanitize_text(source) or None,
allow_created_fallback=bool(
_get_config_value(plugin_config, "retrieval.temporal.allow_created_fallback", True)
),
candidate_multiplier=int(
_get_config_value(plugin_config, "retrieval.temporal.candidate_multiplier", 8)
),
max_scan=int(_get_config_value(plugin_config, "retrieval.temporal.max_scan", 1000)),
)
return True, temporal, ""
@staticmethod
def _build_request_key(
request: SearchExecutionRequest,
query_type: str,
top_k: int,
temporal: Optional[TemporalQueryOptions],
) -> str:
payload = {
"stream_id": _sanitize_text(request.stream_id),
"query_type": query_type,
"query": _sanitize_text(request.query),
"time_from": _sanitize_text(request.time_from),
"time_to": _sanitize_text(request.time_to),
"time_from_ts": temporal.time_from if temporal else None,
"time_to_ts": temporal.time_to if temporal else None,
"person": _sanitize_text(request.person),
"source": _sanitize_text(request.source),
"top_k": int(top_k),
"use_threshold": bool(request.use_threshold),
"enable_ppr": bool(request.enable_ppr),
}
payload_json = json.dumps(payload, ensure_ascii=False, sort_keys=True)
return hashlib.sha1(payload_json.encode("utf-8")).hexdigest()
@staticmethod
async def execute(
*,
retriever: Any,
threshold_filter: Optional[Any],
plugin_config: Optional[dict],
request: SearchExecutionRequest,
enforce_chat_filter: bool = True,
reinforce_access: bool = True,
) -> SearchExecutionResult:
if retriever is None:
return SearchExecutionResult(success=False, error="知识检索器未初始化")
query_type = SearchExecutionService._normalize_query_type(request.query_type)
query = _sanitize_text(request.query)
if query_type not in {"search", "time", "hybrid"}:
return SearchExecutionResult(
success=False,
error=f"query_type 无效: {query_type}(仅支持 search/time/hybrid",
)
if query_type in {"search", "hybrid"} and not query:
return SearchExecutionResult(
success=False,
error="search/hybrid 模式必须提供 query",
)
top_k_ok, top_k, top_k_error = SearchExecutionService._resolve_top_k(
plugin_config, query_type, request.top_k
)
if not top_k_ok:
return SearchExecutionResult(success=False, error=top_k_error)
temporal_ok, temporal, temporal_error = SearchExecutionService._build_temporal(
plugin_config=plugin_config,
query_type=query_type,
time_from_raw=request.time_from,
time_to_raw=request.time_to,
person=request.person,
source=request.source,
)
if not temporal_ok:
return SearchExecutionResult(success=False, error=temporal_error)
plugin_instance = SearchExecutionService._resolve_plugin_instance(plugin_config)
if (
enforce_chat_filter
and plugin_instance is not None
and hasattr(plugin_instance, "is_chat_enabled")
):
if not plugin_instance.is_chat_enabled(
stream_id=request.stream_id,
group_id=request.group_id,
user_id=request.user_id,
):
logger.info(
"检索请求被聊天过滤拦截: "
f"caller={request.caller}, "
f"stream_id={request.stream_id}"
)
return SearchExecutionResult(
success=True,
query_type=query_type,
query=query,
top_k=top_k,
time_from=request.time_from,
time_to=request.time_to,
person=request.person,
source=request.source,
temporal=temporal,
results=[],
elapsed_ms=0.0,
chat_filtered=True,
dedup_hit=False,
)
request_key = SearchExecutionService._build_request_key(
request=request,
query_type=query_type,
top_k=top_k,
temporal=temporal,
)
async def _executor() -> Dict[str, Any]:
original_ppr = bool(getattr(retriever.config, "enable_ppr", True))
setattr(retriever.config, "enable_ppr", bool(request.enable_ppr))
started_at = time.time()
try:
retrieved = await retriever.retrieve(
query=query,
top_k=top_k,
temporal=temporal,
)
should_apply_threshold = bool(request.use_threshold) and threshold_filter is not None
if (
query_type == "time"
and not query
and bool(
_get_config_value(
plugin_config,
"retrieval.time.skip_threshold_when_query_empty",
True,
)
)
):
should_apply_threshold = False
if should_apply_threshold:
retrieved = threshold_filter.filter(retrieved)
if (
reinforce_access
and plugin_instance is not None
and hasattr(plugin_instance, "reinforce_access")
):
relation_hashes = [
item.hash_value
for item in retrieved
if getattr(item, "result_type", "") == "relation"
]
if relation_hashes:
await plugin_instance.reinforce_access(relation_hashes)
if query_type == "search":
graph_store = SearchExecutionService._resolve_runtime_component(
plugin_config, plugin_instance, "graph_store"
)
metadata_store = SearchExecutionService._resolve_runtime_component(
plugin_config, plugin_instance, "metadata_store"
)
fallback_enabled = bool(
_get_config_value(
plugin_config,
"retrieval.search.smart_fallback.enabled",
True,
)
)
fallback_threshold = float(
_get_config_value(
plugin_config,
"retrieval.search.smart_fallback.threshold",
0.6,
)
)
retrieved, fallback_triggered, fallback_added = maybe_apply_smart_path_fallback(
query=query,
results=list(retrieved),
graph_store=graph_store,
metadata_store=metadata_store,
enabled=fallback_enabled,
threshold=fallback_threshold,
)
if fallback_triggered:
logger.info(
"metric.smart_fallback_triggered_count=1 "
f"caller={request.caller} "
f"added={fallback_added}"
)
dedup_enabled = bool(
_get_config_value(
plugin_config,
"retrieval.search.safe_content_dedup.enabled",
True,
)
)
if dedup_enabled:
retrieved, removed_count = apply_safe_content_dedup(list(retrieved))
if removed_count > 0:
logger.info(
f"metric.safe_dedup_removed_count={removed_count} "
f"caller={request.caller}"
)
elapsed_ms = (time.time() - started_at) * 1000.0
return {"results": retrieved, "elapsed_ms": elapsed_ms}
finally:
setattr(retriever.config, "enable_ppr", original_ppr)
dedup_hit = False
try:
# 调优评估需要逐轮真实执行,且应避免额外 dedup 锁竞争。
bypass_request_dedup = str(request.caller or "").strip().lower() == "retrieval_tuning"
if (
not bypass_request_dedup
and
plugin_instance is not None
and hasattr(plugin_instance, "execute_request_with_dedup")
):
dedup_hit, payload = await plugin_instance.execute_request_with_dedup(
request_key,
_executor,
)
else:
payload = await _executor()
except Exception as e:
return SearchExecutionResult(success=False, error=f"知识检索失败: {e}")
if dedup_hit:
logger.info(f"metric.search_execution_dedup_hit_count=1 caller={request.caller}")
return SearchExecutionResult(
success=True,
query_type=query_type,
query=query,
top_k=top_k,
time_from=request.time_from,
time_to=request.time_to,
person=request.person,
source=request.source,
temporal=temporal,
results=payload.get("results", []),
elapsed_ms=float(payload.get("elapsed_ms", 0.0)),
chat_filtered=False,
dedup_hit=bool(dedup_hit),
)
@staticmethod
def to_serializable_results(results: List[Any]) -> List[Dict[str, Any]]:
serialized: List[Dict[str, Any]] = []
for item in results:
metadata = dict(getattr(item, "metadata", {}) or {})
if "time_meta" not in metadata:
metadata["time_meta"] = {}
serialized.append(
{
"hash": getattr(item, "hash_value", ""),
"type": getattr(item, "result_type", ""),
"score": float(getattr(item, "score", 0.0)),
"content": getattr(item, "content", ""),
"metadata": metadata,
}
)
return serialized

View File

@@ -0,0 +1,425 @@
"""
聊天总结与知识导入工具
该模块负责从聊天记录中提取信息,生成总结,并将总结内容及提取的实体/关系
导入到 A_memorix 的存储组件中。
"""
import time
import json
import re
import traceback
from typing import List, Dict, Any, Tuple, Optional
from pathlib import Path
from src.common.logger import get_logger
from src.services import llm_service as llm_api
from src.services import message_service as message_api
from src.config.config import global_config, model_config as host_model_config
from src.config.model_configs import TaskConfig
from ..storage import (
KnowledgeType,
VectorStore,
GraphStore,
MetadataStore,
resolve_stored_knowledge_type,
)
from ..embedding import EmbeddingAPIAdapter
from .relation_write_service import RelationWriteService
from .runtime_self_check import ensure_runtime_self_check, run_embedding_runtime_self_check
logger = get_logger("A_Memorix.SummaryImporter")
# 默认总结提示词模版
SUMMARY_PROMPT_TEMPLATE = """
你是 {bot_name}{personality_context}
现在你需要对以下一段聊天记录进行总结,并提取其中的重要知识。
聊天记录内容:
{chat_history}
请完成以下任务:
1. **生成总结**:以第三人称或机器人的视角,简洁明了地总结这段对话的主要内容、发生的事件或讨论的主题。
2. **提取实体与关系**:识别并提取对话中提到的重要实体以及它们之间的关系。
请严格以 JSON 格式输出,格式如下:
{{
"summary": "总结文本内容",
"entities": ["张三", "李四"],
"relations": [
{{"subject": "张三", "predicate": "认识", "object": "李四"}}
]
}}
注意:总结应具有叙事性,能够作为长程记忆的一部分。直接使用实体的实际名称,不要使用 e1/e2 等代号。
"""
class SummaryImporter:
"""总结并导入知识的工具类"""
def __init__(
self,
vector_store: VectorStore,
graph_store: GraphStore,
metadata_store: MetadataStore,
embedding_manager: EmbeddingAPIAdapter,
plugin_config: dict
):
self.vector_store = vector_store
self.graph_store = graph_store
self.metadata_store = metadata_store
self.embedding_manager = embedding_manager
self.plugin_config = plugin_config
self.relation_write_service: Optional[RelationWriteService] = (
plugin_config.get("relation_write_service")
if isinstance(plugin_config, dict)
else None
)
def _normalize_summary_model_selectors(self, raw_value: Any) -> List[str]:
"""标准化 summarization.model_name 配置vNext 仅接受字符串数组)。"""
if raw_value is None:
return ["auto"]
if isinstance(raw_value, list):
selectors = [str(x).strip() for x in raw_value if str(x).strip()]
return selectors or ["auto"]
raise ValueError(
"summarization.model_name 在 vNext 必须为 List[str]。"
" 请执行 scripts/release_vnext_migrate.py migrate。"
)
def _pick_default_summary_task(self, available_tasks: Dict[str, TaskConfig]) -> Tuple[Optional[str], Optional[TaskConfig]]:
"""
选择总结默认任务,避免错误落到 embedding 任务。
优先级replyer > utils > planner > tool_use > 其他非 embedding。
"""
preferred = ("replyer", "utils", "planner", "tool_use")
for name in preferred:
cfg = available_tasks.get(name)
if cfg and cfg.model_list:
return name, cfg
for name, cfg in available_tasks.items():
if name != "embedding" and cfg.model_list:
return name, cfg
for name, cfg in available_tasks.items():
if cfg.model_list:
return name, cfg
return None, None
def _resolve_summary_model_config(self) -> Optional[TaskConfig]:
"""
解析 summarization.model_name 为 TaskConfig。
支持:
- "auto"
- "replyer"(任务名)
- "some-model-name"(具体模型名)
- ["utils:model1", "utils:model2", "replyer"](数组混合语法)
"""
available_tasks = llm_api.get_available_models()
if not available_tasks:
return None
raw_cfg = self.plugin_config.get("summarization", {}).get("model_name", "auto")
selectors = self._normalize_summary_model_selectors(raw_cfg)
default_task_name, default_task_cfg = self._pick_default_summary_task(available_tasks)
selected_models: List[str] = []
base_cfg: Optional[TaskConfig] = None
model_dict = getattr(host_model_config, "models_dict", {})
def _append_models(models: List[str]):
for model_name in models:
if model_name and model_name not in selected_models:
selected_models.append(model_name)
for raw_selector in selectors:
selector = raw_selector.strip()
if not selector:
continue
if selector.lower() == "auto":
if default_task_cfg:
_append_models(default_task_cfg.model_list)
if base_cfg is None:
base_cfg = default_task_cfg
continue
if ":" in selector:
task_name, model_name = selector.split(":", 1)
task_name = task_name.strip()
model_name = model_name.strip()
task_cfg = available_tasks.get(task_name)
if not task_cfg:
logger.warning(f"总结模型选择器 '{selector}' 的任务 '{task_name}' 不存在,已跳过")
continue
if base_cfg is None:
base_cfg = task_cfg
if not model_name or model_name.lower() == "auto":
_append_models(task_cfg.model_list)
continue
if model_name in model_dict or model_name in task_cfg.model_list:
_append_models([model_name])
else:
logger.warning(f"总结模型选择器 '{selector}' 的模型 '{model_name}' 不存在,已跳过")
continue
task_cfg = available_tasks.get(selector)
if task_cfg:
_append_models(task_cfg.model_list)
if base_cfg is None:
base_cfg = task_cfg
continue
if selector in model_dict:
_append_models([selector])
continue
logger.warning(f"总结模型选择器 '{selector}' 无法识别,已跳过")
if not selected_models:
if default_task_cfg:
_append_models(default_task_cfg.model_list)
if base_cfg is None:
base_cfg = default_task_cfg
else:
first_cfg = next(iter(available_tasks.values()))
_append_models(first_cfg.model_list)
if base_cfg is None:
base_cfg = first_cfg
if not selected_models:
return None
template_cfg = base_cfg or default_task_cfg or next(iter(available_tasks.values()))
return TaskConfig(
model_list=selected_models,
max_tokens=template_cfg.max_tokens,
temperature=template_cfg.temperature,
slow_threshold=template_cfg.slow_threshold,
selection_strategy=template_cfg.selection_strategy,
)
async def import_from_stream(
self,
stream_id: str,
context_length: Optional[int] = None,
include_personality: Optional[bool] = None
) -> Tuple[bool, str]:
"""
从指定的聊天流中提取记录并执行总结导入
Args:
stream_id: 聊天流 ID
context_length: 总结的历史消息条数
include_personality: 是否包含人设
Returns:
Tuple[bool, str]: (是否成功, 结果消息)
"""
try:
self_check_ok, self_check_msg = await self._ensure_runtime_self_check()
if not self_check_ok:
return False, f"导入前自检失败: {self_check_msg}"
# 1. 获取配置
if context_length is None:
context_length = self.plugin_config.get("summarization", {}).get("context_length", 50)
if include_personality is None:
include_personality = self.plugin_config.get("summarization", {}).get("include_personality", True)
# 2. 获取历史消息
# 获取当前时间之前的消息
now = time.time()
messages = message_api.get_messages_before_time_in_chat(
chat_id=stream_id,
timestamp=now,
limit=context_length
)
if not messages:
return False, "未找到有效的聊天记录进行总结"
# 转换为可读文本
chat_history_text = message_api.build_readable_messages(messages)
# 3. 准备提示词内容
bot_name = global_config.bot.nickname or "机器人"
personality_context = ""
if include_personality:
personality = getattr(global_config.bot, "personality", "")
if personality:
personality_context = f"你的性格设定是:{personality}"
# 4. 调用 LLM
prompt = SUMMARY_PROMPT_TEMPLATE.format(
bot_name=bot_name,
personality_context=personality_context,
chat_history=chat_history_text
)
model_config_to_use = self._resolve_summary_model_config()
if model_config_to_use is None:
return False, "未找到可用的总结模型配置"
logger.info(f"正在为流 {stream_id} 执行总结,消息条数: {len(messages)}")
logger.info(f"总结模型候选列表: {model_config_to_use.model_list}")
success, response, _, _ = await llm_api.generate_with_model(
prompt=prompt,
model_config=model_config_to_use,
request_type="A_Memorix.ChatSummarization"
)
if not success or not response:
return False, "LLM 生成总结失败"
# 5. 解析结果
data = self._parse_llm_response(response)
if not data or "summary" not in data:
return False, "解析 LLM 响应失败或总结为空"
summary_text = data["summary"]
entities = data.get("entities", [])
relations = data.get("relations", [])
msg_times = [
float(getattr(getattr(msg, "timestamp", None), "timestamp", lambda: 0.0)())
for msg in messages
if getattr(msg, "time", None) is not None
]
time_meta = {}
if msg_times:
time_meta = {
"event_time_start": min(msg_times),
"event_time_end": max(msg_times),
"time_granularity": "minute",
"time_confidence": 0.95,
}
# 6. 执行导入
await self._execute_import(summary_text, entities, relations, stream_id, time_meta=time_meta)
# 7. 持久化
self.vector_store.save()
self.graph_store.save()
result_msg = (
f"✅ 总结导入成功\n"
f"📝 总结长度: {len(summary_text)}\n"
f"📌 提取实体: {len(entities)}\n"
f"🔗 提取关系: {len(relations)}"
)
return True, result_msg
except Exception as e:
logger.error(f"总结导入过程中出错: {e}\n{traceback.format_exc()}")
return False, f"错误: {str(e)}"
async def _ensure_runtime_self_check(self) -> Tuple[bool, str]:
plugin_instance = self.plugin_config.get("plugin_instance") if isinstance(self.plugin_config, dict) else None
if plugin_instance is not None:
report = await ensure_runtime_self_check(plugin_instance)
else:
report = await run_embedding_runtime_self_check(
config=self.plugin_config,
vector_store=self.vector_store,
embedding_manager=self.embedding_manager,
)
if bool(report.get("ok", False)):
return True, ""
return (
False,
f"{report.get('message', 'unknown')} "
f"(configured={report.get('configured_dimension', 0)}, "
f"store={report.get('vector_store_dimension', 0)}, "
f"encoded={report.get('encoded_dimension', 0)})",
)
def _parse_llm_response(self, response: str) -> Dict[str, Any]:
"""解析 LLM 返回的 JSON"""
try:
# 尝试查找 JSON
json_match = re.search(r"\{.*\}", response, re.DOTALL)
if json_match:
return json.loads(json_match.group())
return {}
except Exception as e:
logger.warning(f"解析总结 JSON 失败: {e}")
return {}
async def _execute_import(
self,
summary: str,
entities: List[str],
relations: List[Dict[str, str]],
stream_id: str,
time_meta: Optional[Dict[str, Any]] = None,
):
"""将数据写入存储"""
# 获取默认知识类型
type_str = self.plugin_config.get("summarization", {}).get("default_knowledge_type", "narrative")
try:
knowledge_type = resolve_stored_knowledge_type(type_str, content=summary)
except ValueError:
logger.warning(f"非法 summarization.default_knowledge_type={type_str},回退 narrative")
knowledge_type = KnowledgeType.NARRATIVE
# 导入总结文本
hash_value = self.metadata_store.add_paragraph(
content=summary,
source=f"chat_summary:{stream_id}",
knowledge_type=knowledge_type.value,
time_meta=time_meta,
)
embedding = await self.embedding_manager.encode(summary)
self.vector_store.add(
vectors=embedding.reshape(1, -1),
ids=[hash_value]
)
# 导入实体
if entities:
self.graph_store.add_nodes(entities)
# 导入关系
rv_cfg = self.plugin_config.get("retrieval", {}).get("relation_vectorization", {})
if not isinstance(rv_cfg, dict):
rv_cfg = {}
write_vector = bool(rv_cfg.get("enabled", False)) and bool(rv_cfg.get("write_on_import", True))
for rel in relations:
s, p, o = rel.get("subject"), rel.get("predicate"), rel.get("object")
if all([s, p, o]):
if self.relation_write_service is not None:
await self.relation_write_service.upsert_relation_with_vector(
subject=s,
predicate=p,
obj=o,
confidence=1.0,
source_paragraph=summary,
write_vector=write_vector,
)
else:
# 写入元数据
rel_hash = self.metadata_store.add_relation(
subject=s,
predicate=p,
obj=o,
confidence=1.0,
source_paragraph=summary
)
# 写入图数据库(写入 relation_hashes确保后续可按关系精确修剪
self.graph_store.add_edges([(s, o)], relation_hashes=[rel_hash])
try:
self.metadata_store.set_relation_vector_state(rel_hash, "none")
except Exception:
pass
logger.info(f"总结导入完成: hash={hash_value[:8]}")

File diff suppressed because it is too large Load Diff