添加 A_Memorix 插件 v2.0.0(包含运行时与文档)
引入 A_Memorix 插件 v2.0.0:新增大量运行时组件、存储/模式更新、检索能力提升、管理工具、导入/调优工作流以及相关文档。关键新增内容包括:lifecycle_orchestrator、SDKMemoryKernel/运行时初始化器、新的存储层与 metadata_store 变更(SCHEMA_VERSION v8)、检索增强(双路径检索、图关系召回、稀疏 BM25),以及多种工具服务(episode/person_profile/relation/segmentation/tuning/search execution)。同时新增 Web 导入/摘要导入器及大量维护脚本。还更新了插件清单、embedding API 适配器、plugin.py、requirements/pyproject,以及主入口文件,使新插件接入项目。该变更为 2.0.0 版本发布做好准备,实现统一的 SDK Tool 接口并扩展整体运行能力。
This commit is contained in:
@@ -1,46 +1,55 @@
|
||||
"""
|
||||
Hash-based embedding adapter used by the SDK runtime.
|
||||
请求式嵌入 API 适配器。
|
||||
|
||||
The plugin runtime cannot import MaiBot host embedding internals from ``src.chat``
|
||||
or ``src.llm_models``. This adapter keeps A_Memorix self-contained and stable in
|
||||
Runner by generating deterministic dense vectors locally.
|
||||
恢复 v1.0.1 的真实 embedding 请求语义:
|
||||
- 通过宿主模型配置探测/请求 embedding
|
||||
- 支持 dimensions 参数
|
||||
- 支持批量与重试
|
||||
- 不再提供本地 hash fallback
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import re
|
||||
import asyncio
|
||||
import time
|
||||
from typing import List, Optional, Union
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
import aiohttp
|
||||
import numpy as np
|
||||
import openai
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from src.config.config import config_manager
|
||||
from src.config.model_configs import APIProvider, ModelInfo
|
||||
from src.llm_models.exceptions import NetworkConnectionError
|
||||
from src.llm_models.model_client.base_client import client_registry
|
||||
|
||||
logger = get_logger("A_Memorix.EmbeddingAPIAdapter")
|
||||
|
||||
_TOKEN_PATTERN = re.compile(r"[A-Za-z0-9_\u4e00-\u9fff]{1,}")
|
||||
|
||||
|
||||
class EmbeddingAPIAdapter:
|
||||
"""Deterministic local embedding adapter."""
|
||||
"""适配宿主 embedding 请求接口。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
batch_size: int = 32,
|
||||
max_concurrent: int = 5,
|
||||
default_dimension: int = 256,
|
||||
default_dimension: int = 1024,
|
||||
enable_cache: bool = False,
|
||||
model_name: str = "hash-v1",
|
||||
model_name: str = "auto",
|
||||
retry_config: Optional[dict] = None,
|
||||
) -> None:
|
||||
self.batch_size = max(1, int(batch_size))
|
||||
self.max_concurrent = max(1, int(max_concurrent))
|
||||
self.default_dimension = max(32, int(default_dimension))
|
||||
self.default_dimension = max(1, int(default_dimension))
|
||||
self.enable_cache = bool(enable_cache)
|
||||
self.model_name = str(model_name or "hash-v1")
|
||||
self.model_name = str(model_name or "auto")
|
||||
|
||||
self.retry_config = retry_config or {}
|
||||
self.max_attempts = max(1, int(self.retry_config.get("max_attempts", 5)))
|
||||
self.max_wait_seconds = max(0.1, float(self.retry_config.get("max_wait_seconds", 40)))
|
||||
self.min_wait_seconds = max(0.1, float(self.retry_config.get("min_wait_seconds", 3)))
|
||||
self.backoff_multiplier = max(1.0, float(self.retry_config.get("backoff_multiplier", 3)))
|
||||
|
||||
self._dimension: Optional[int] = None
|
||||
self._dimension_detected = False
|
||||
@@ -49,57 +58,164 @@ class EmbeddingAPIAdapter:
|
||||
self._total_time = 0.0
|
||||
|
||||
logger.info(
|
||||
"EmbeddingAPIAdapter 初始化: model=%s, batch_size=%s, dimension=%s",
|
||||
self.model_name,
|
||||
self.batch_size,
|
||||
self.default_dimension,
|
||||
"EmbeddingAPIAdapter 初始化: "
|
||||
f"batch_size={self.batch_size}, "
|
||||
f"max_concurrent={self.max_concurrent}, "
|
||||
f"default_dim={self.default_dimension}, "
|
||||
f"model={self.model_name}"
|
||||
)
|
||||
|
||||
def _get_current_model_config(self):
|
||||
return config_manager.get_model_config()
|
||||
|
||||
@staticmethod
|
||||
def _find_model_info(model_name: str) -> ModelInfo:
|
||||
model_cfg = config_manager.get_model_config()
|
||||
for item in model_cfg.models:
|
||||
if item.name == model_name:
|
||||
return item
|
||||
raise ValueError(f"未找到 embedding 模型: {model_name}")
|
||||
|
||||
@staticmethod
|
||||
def _find_provider(provider_name: str) -> APIProvider:
|
||||
model_cfg = config_manager.get_model_config()
|
||||
for item in model_cfg.api_providers:
|
||||
if item.name == provider_name:
|
||||
return item
|
||||
raise ValueError(f"未找到 embedding provider: {provider_name}")
|
||||
|
||||
def _resolve_candidate_model_names(self) -> List[str]:
|
||||
task_config = self._get_current_model_config().model_task_config.embedding
|
||||
configured = list(getattr(task_config, "model_list", []) or [])
|
||||
if self.model_name and self.model_name != "auto":
|
||||
return [self.model_name, *[name for name in configured if name != self.model_name]]
|
||||
return configured
|
||||
|
||||
@staticmethod
|
||||
def _validate_embedding_vector(embedding: Any, *, source: str) -> np.ndarray:
|
||||
array = np.asarray(embedding, dtype=np.float32)
|
||||
if array.ndim != 1:
|
||||
raise RuntimeError(f"{source} 返回的 embedding 维度非法: ndim={array.ndim}")
|
||||
if array.size <= 0:
|
||||
raise RuntimeError(f"{source} 返回了空 embedding")
|
||||
if not np.all(np.isfinite(array)):
|
||||
raise RuntimeError(f"{source} 返回了非有限 embedding 值")
|
||||
return array
|
||||
|
||||
async def _request_with_retry(self, client, model_info, text: str, extra_params: dict):
|
||||
retriable_exceptions = (
|
||||
openai.APIConnectionError,
|
||||
openai.APITimeoutError,
|
||||
aiohttp.ClientError,
|
||||
asyncio.TimeoutError,
|
||||
NetworkConnectionError,
|
||||
)
|
||||
|
||||
last_exc: Optional[BaseException] = None
|
||||
for attempt in range(1, self.max_attempts + 1):
|
||||
try:
|
||||
return await client.get_embedding(
|
||||
model_info=model_info,
|
||||
embedding_input=text,
|
||||
extra_params=extra_params,
|
||||
)
|
||||
except retriable_exceptions as exc:
|
||||
last_exc = exc
|
||||
if attempt >= self.max_attempts:
|
||||
raise
|
||||
wait_seconds = min(
|
||||
self.max_wait_seconds,
|
||||
self.min_wait_seconds * (self.backoff_multiplier ** (attempt - 1)),
|
||||
)
|
||||
logger.warning(
|
||||
"Embedding 请求失败,重试 "
|
||||
f"{attempt}/{max(1, self.max_attempts - 1)},"
|
||||
f"{wait_seconds:.1f}s 后重试: {exc}"
|
||||
)
|
||||
await asyncio.sleep(wait_seconds)
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
if last_exc is not None:
|
||||
raise last_exc
|
||||
raise RuntimeError("Embedding 请求失败:未知错误")
|
||||
|
||||
async def _get_embedding_direct(self, text: str, dimensions: Optional[int] = None) -> Optional[List[float]]:
|
||||
candidate_names = self._resolve_candidate_model_names()
|
||||
if not candidate_names:
|
||||
raise RuntimeError("embedding 任务未配置模型")
|
||||
|
||||
last_exc: Optional[BaseException] = None
|
||||
for candidate_name in candidate_names:
|
||||
try:
|
||||
model_info = self._find_model_info(candidate_name)
|
||||
api_provider = self._find_provider(model_info.api_provider)
|
||||
client = client_registry.get_client_class_instance(api_provider, force_new=True)
|
||||
|
||||
extra_params = dict(getattr(model_info, "extra_params", {}) or {})
|
||||
if dimensions is not None:
|
||||
extra_params["dimensions"] = int(dimensions)
|
||||
|
||||
response = await self._request_with_retry(
|
||||
client=client,
|
||||
model_info=model_info,
|
||||
text=text,
|
||||
extra_params=extra_params,
|
||||
)
|
||||
embedding = getattr(response, "embedding", None)
|
||||
if embedding is None:
|
||||
raise RuntimeError(f"模型 {candidate_name} 未返回 embedding")
|
||||
vector = self._validate_embedding_vector(
|
||||
embedding,
|
||||
source=f"embedding 模型 {candidate_name}",
|
||||
)
|
||||
return vector.tolist()
|
||||
except Exception as exc:
|
||||
last_exc = exc
|
||||
logger.warning(f"embedding 模型 {candidate_name} 请求失败: {exc}")
|
||||
|
||||
if last_exc is not None:
|
||||
logger.error(f"通过直接 Client 获取 Embedding 失败: {last_exc}")
|
||||
return None
|
||||
|
||||
async def _detect_dimension(self) -> int:
|
||||
if self._dimension_detected and self._dimension is not None:
|
||||
return self._dimension
|
||||
|
||||
logger.info("正在检测嵌入模型维度...")
|
||||
try:
|
||||
target_dim = self.default_dimension
|
||||
logger.debug(f"尝试请求指定维度: {target_dim}")
|
||||
test_embedding = await self._get_embedding_direct("test", dimensions=target_dim)
|
||||
if test_embedding and isinstance(test_embedding, list):
|
||||
detected_dim = len(test_embedding)
|
||||
if detected_dim == target_dim:
|
||||
logger.info(f"嵌入维度检测成功 (匹配配置): {detected_dim}")
|
||||
else:
|
||||
logger.warning(
|
||||
f"请求维度 {target_dim} 但模型返回 {detected_dim},将使用模型自然维度"
|
||||
)
|
||||
self._dimension = detected_dim
|
||||
self._dimension_detected = True
|
||||
return detected_dim
|
||||
except Exception as exc:
|
||||
logger.debug(f"带维度参数探测失败: {exc},尝试不带参数探测")
|
||||
|
||||
try:
|
||||
test_embedding = await self._get_embedding_direct("test", dimensions=None)
|
||||
if test_embedding and isinstance(test_embedding, list):
|
||||
detected_dim = len(test_embedding)
|
||||
self._dimension = detected_dim
|
||||
self._dimension_detected = True
|
||||
logger.info(f"嵌入维度检测成功 (自然维度): {detected_dim}")
|
||||
return detected_dim
|
||||
logger.warning(f"嵌入维度检测失败,使用默认值: {self.default_dimension}")
|
||||
except Exception as exc:
|
||||
logger.error(f"嵌入维度检测异常: {exc},使用默认值: {self.default_dimension}")
|
||||
|
||||
self._dimension = self.default_dimension
|
||||
self._dimension_detected = True
|
||||
return self._dimension
|
||||
|
||||
@staticmethod
|
||||
def _tokenize(text: str) -> List[str]:
|
||||
clean = str(text or "").strip().lower()
|
||||
if not clean:
|
||||
return []
|
||||
return _TOKEN_PATTERN.findall(clean)
|
||||
|
||||
@staticmethod
|
||||
def _feature_weight(token: str) -> float:
|
||||
digest = hashlib.sha256(token.encode("utf-8")).digest()
|
||||
return 1.0 + (digest[10] / 255.0) * 0.5
|
||||
|
||||
def _encode_single(self, text: str, dimension: int) -> np.ndarray:
|
||||
vector = np.zeros(dimension, dtype=np.float32)
|
||||
content = str(text or "").strip()
|
||||
tokens = self._tokenize(content)
|
||||
if not tokens and content:
|
||||
tokens = [content.lower()]
|
||||
if not tokens:
|
||||
vector[0] = 1.0
|
||||
return vector
|
||||
|
||||
for token in tokens:
|
||||
digest = hashlib.sha256(token.encode("utf-8")).digest()
|
||||
bucket = int.from_bytes(digest[:8], byteorder="big", signed=False) % dimension
|
||||
sign = 1.0 if digest[8] % 2 == 0 else -1.0
|
||||
vector[bucket] += sign * self._feature_weight(token)
|
||||
|
||||
second_bucket = int.from_bytes(digest[12:20], byteorder="big", signed=False) % dimension
|
||||
if second_bucket != bucket:
|
||||
vector[second_bucket] += (sign * 0.35)
|
||||
|
||||
norm = float(np.linalg.norm(vector))
|
||||
if norm > 1e-8:
|
||||
vector /= norm
|
||||
else:
|
||||
vector[0] = 1.0
|
||||
return vector
|
||||
return self.default_dimension
|
||||
|
||||
async def encode(
|
||||
self,
|
||||
@@ -109,59 +225,137 @@ class EmbeddingAPIAdapter:
|
||||
normalize: bool = True,
|
||||
dimensions: Optional[int] = None,
|
||||
) -> np.ndarray:
|
||||
_ = batch_size
|
||||
_ = show_progress
|
||||
_ = normalize
|
||||
del show_progress
|
||||
del normalize
|
||||
|
||||
started_at = time.time()
|
||||
target_dimension = max(32, int(dimensions or await self._detect_dimension()))
|
||||
start_time = time.time()
|
||||
target_dim = int(dimensions) if dimensions is not None else int(await self._detect_dimension())
|
||||
|
||||
if isinstance(texts, str):
|
||||
single_input = True
|
||||
normalized_texts = [texts]
|
||||
single_input = True
|
||||
else:
|
||||
single_input = False
|
||||
normalized_texts = list(texts or [])
|
||||
single_input = False
|
||||
|
||||
if not normalized_texts:
|
||||
empty = np.zeros((0, target_dimension), dtype=np.float32)
|
||||
empty = np.zeros((0, target_dim), dtype=np.float32)
|
||||
return empty[0] if single_input else empty
|
||||
|
||||
if batch_size is None:
|
||||
batch_size = self.batch_size
|
||||
|
||||
try:
|
||||
matrix = np.vstack([self._encode_single(item, target_dimension) for item in normalized_texts])
|
||||
embeddings = await self._encode_batch_internal(
|
||||
normalized_texts,
|
||||
batch_size=max(1, int(batch_size)),
|
||||
dimensions=dimensions,
|
||||
)
|
||||
if embeddings.ndim == 1:
|
||||
embeddings = embeddings.reshape(1, -1)
|
||||
self._total_encoded += len(normalized_texts)
|
||||
self._total_time += time.time() - started_at
|
||||
except Exception:
|
||||
elapsed = time.time() - start_time
|
||||
self._total_time += elapsed
|
||||
logger.debug(
|
||||
"编码完成: "
|
||||
f"{len(normalized_texts)} 个文本, "
|
||||
f"耗时 {elapsed:.2f}s, "
|
||||
f"平均 {elapsed / max(1, len(normalized_texts)):.3f}s/文本"
|
||||
)
|
||||
return embeddings[0] if single_input else embeddings
|
||||
except Exception as exc:
|
||||
self._total_errors += 1
|
||||
raise
|
||||
logger.error(f"编码失败: {exc}")
|
||||
raise RuntimeError(f"embedding encode failed: {exc}") from exc
|
||||
|
||||
return matrix[0] if single_input else matrix
|
||||
async def _encode_batch_internal(
|
||||
self,
|
||||
texts: List[str],
|
||||
batch_size: int,
|
||||
dimensions: Optional[int] = None,
|
||||
) -> np.ndarray:
|
||||
all_embeddings: List[np.ndarray] = []
|
||||
for offset in range(0, len(texts), batch_size):
|
||||
batch = texts[offset : offset + batch_size]
|
||||
semaphore = asyncio.Semaphore(self.max_concurrent)
|
||||
|
||||
def get_statistics(self) -> dict:
|
||||
avg_time = self._total_time / self._total_encoded if self._total_encoded else 0.0
|
||||
async def encode_with_semaphore(text: str, index: int):
|
||||
async with semaphore:
|
||||
embedding = await self._get_embedding_direct(text, dimensions=dimensions)
|
||||
if embedding is None:
|
||||
raise RuntimeError(f"文本 {index} 编码失败:embedding 返回为空")
|
||||
vector = self._validate_embedding_vector(
|
||||
embedding,
|
||||
source=f"文本 {index}",
|
||||
)
|
||||
return index, vector
|
||||
|
||||
tasks = [
|
||||
encode_with_semaphore(text, offset + index)
|
||||
for index, text in enumerate(batch)
|
||||
]
|
||||
results = await asyncio.gather(*tasks)
|
||||
results.sort(key=lambda item: item[0])
|
||||
all_embeddings.extend(emb for _, emb in results)
|
||||
|
||||
return np.array(all_embeddings, dtype=np.float32)
|
||||
|
||||
async def encode_batch(
|
||||
self,
|
||||
texts: List[str],
|
||||
batch_size: Optional[int] = None,
|
||||
num_workers: Optional[int] = None,
|
||||
show_progress: bool = False,
|
||||
dimensions: Optional[int] = None,
|
||||
) -> np.ndarray:
|
||||
del show_progress
|
||||
if num_workers is not None:
|
||||
previous = self.max_concurrent
|
||||
self.max_concurrent = max(1, int(num_workers))
|
||||
try:
|
||||
return await self.encode(texts, batch_size=batch_size, dimensions=dimensions)
|
||||
finally:
|
||||
self.max_concurrent = previous
|
||||
return await self.encode(texts, batch_size=batch_size, dimensions=dimensions)
|
||||
|
||||
def get_embedding_dimension(self) -> int:
|
||||
if self._dimension is not None:
|
||||
return self._dimension
|
||||
logger.warning(f"维度尚未检测,返回默认值: {self.default_dimension}")
|
||||
return self.default_dimension
|
||||
|
||||
def get_model_info(self) -> dict:
|
||||
return {
|
||||
"model_name": self.model_name,
|
||||
"dimension": self._dimension or self.default_dimension,
|
||||
"dimension_detected": self._dimension_detected,
|
||||
"batch_size": self.batch_size,
|
||||
"max_concurrent": self.max_concurrent,
|
||||
"total_encoded": self._total_encoded,
|
||||
"total_errors": self._total_errors,
|
||||
"total_time": self._total_time,
|
||||
"avg_time_per_text": avg_time,
|
||||
"avg_time_per_text": self._total_time / self._total_encoded if self._total_encoded else 0.0,
|
||||
}
|
||||
|
||||
def get_statistics(self) -> dict:
|
||||
return self.get_model_info()
|
||||
|
||||
@property
|
||||
def is_model_loaded(self) -> bool:
|
||||
return True
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"EmbeddingAPIAdapter(model_name={self.model_name}, "
|
||||
f"dimension={self._dimension or self.default_dimension}, "
|
||||
f"total_encoded={self._total_encoded})"
|
||||
f"EmbeddingAPIAdapter(dim={self._dimension or self.default_dimension}, "
|
||||
f"detected={self._dimension_detected}, encoded={self._total_encoded})"
|
||||
)
|
||||
|
||||
|
||||
def create_embedding_api_adapter(
|
||||
batch_size: int = 32,
|
||||
max_concurrent: int = 5,
|
||||
default_dimension: int = 256,
|
||||
default_dimension: int = 1024,
|
||||
enable_cache: bool = False,
|
||||
model_name: str = "hash-v1",
|
||||
model_name: str = "auto",
|
||||
retry_config: Optional[dict] = None,
|
||||
) -> EmbeddingAPIAdapter:
|
||||
return EmbeddingAPIAdapter(
|
||||
|
||||
@@ -285,10 +285,10 @@ class DualPathRetriever:
|
||||
relation_intent_ctx = self._build_relation_intent_context(query=query, top_k=top_k)
|
||||
|
||||
logger.info(
|
||||
"执行检索: query='%s...', strategy=%s, relation_intent=%s",
|
||||
query[:50],
|
||||
strategy.value,
|
||||
relation_intent_ctx.get("enabled", False),
|
||||
"执行检索: "
|
||||
f"query='{query[:50]}...', "
|
||||
f"strategy={strategy.value}, "
|
||||
f"relation_intent={relation_intent_ctx.get('enabled', False)}"
|
||||
)
|
||||
|
||||
if temporal and not (query or "").strip():
|
||||
@@ -1408,10 +1408,10 @@ class DualPathRetriever:
|
||||
return results
|
||||
|
||||
logger.debug(
|
||||
"relation_rerank_applied=1 relation_pair_groups=%s relation_pair_overflow_count=%s relation_pair_limit=%s",
|
||||
len(ordered_groups),
|
||||
len(overflow),
|
||||
pair_limit,
|
||||
"relation_rerank_applied=1 "
|
||||
f"relation_pair_groups={len(ordered_groups)} "
|
||||
f"relation_pair_overflow_count={len(overflow)} "
|
||||
f"relation_pair_limit={pair_limit}"
|
||||
)
|
||||
|
||||
rebuilt = list(results)
|
||||
@@ -1455,9 +1455,9 @@ class DualPathRetriever:
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(
|
||||
"metric.ppr_timeout_skip_count=1 timeout_s=%s entities=%s",
|
||||
ppr_timeout_s,
|
||||
len(entities),
|
||||
"metric.ppr_timeout_skip_count=1 "
|
||||
f"timeout_s={ppr_timeout_s} "
|
||||
f"entities={len(entities)}"
|
||||
)
|
||||
return results
|
||||
except Exception as e:
|
||||
|
||||
@@ -170,7 +170,7 @@ class GraphRelationRecallService:
|
||||
max_paths=self.config.max_paths,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("graph two-hop recall skipped: %s", e)
|
||||
logger.debug(f"graph two-hop recall skipped: {e}")
|
||||
return
|
||||
|
||||
for path_nodes in paths:
|
||||
@@ -210,7 +210,7 @@ class GraphRelationRecallService:
|
||||
limit=self.config.candidate_k,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("graph one-hop recall skipped: %s", e)
|
||||
logger.debug(f"graph one-hop recall skipped: {e}")
|
||||
return
|
||||
self._append_relation_hashes(
|
||||
relation_hashes=relation_hashes,
|
||||
|
||||
@@ -123,9 +123,8 @@ class SparseBM25Index:
|
||||
self._loaded = True
|
||||
self._prepare_tokenizer()
|
||||
logger.info(
|
||||
"SparseBM25Index loaded: backend=fts5, tokenizer=%s, mode=%s",
|
||||
self.config.tokenizer_mode,
|
||||
self.config.mode,
|
||||
"SparseBM25Index loaded: "
|
||||
f"backend=fts5, tokenizer={self.config.tokenizer_mode}, mode={self.config.mode}"
|
||||
)
|
||||
return True
|
||||
|
||||
@@ -141,9 +140,9 @@ class SparseBM25Index:
|
||||
if user_dict:
|
||||
try:
|
||||
jieba.load_userdict(user_dict) # type: ignore[union-attr]
|
||||
logger.info("已加载 jieba 用户词典: %s", user_dict)
|
||||
logger.info(f"已加载 jieba 用户词典: {user_dict}")
|
||||
except Exception as e:
|
||||
logger.warning("加载 jieba 用户词典失败: %s", e)
|
||||
logger.warning(f"加载 jieba 用户词典失败: {e}")
|
||||
self._jieba_dict_loaded = True
|
||||
|
||||
def _tokenize_jieba(self, text: str) -> List[str]:
|
||||
|
||||
@@ -1,8 +1,16 @@
|
||||
"""SDK runtime exports for A_Memorix."""
|
||||
|
||||
from .search_runtime_initializer import (
|
||||
SearchRuntimeBundle,
|
||||
SearchRuntimeInitializer,
|
||||
build_search_runtime,
|
||||
)
|
||||
from .sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel
|
||||
|
||||
__all__ = [
|
||||
"SearchRuntimeBundle",
|
||||
"SearchRuntimeInitializer",
|
||||
"build_search_runtime",
|
||||
"KernelSearchRequest",
|
||||
"SDKMemoryKernel",
|
||||
]
|
||||
|
||||
268
plugins/A_memorix/core/runtime/lifecycle_orchestrator.py
Normal file
268
plugins/A_memorix/core/runtime/lifecycle_orchestrator.py
Normal file
@@ -0,0 +1,268 @@
|
||||
"""Lifecycle bootstrap/teardown helpers extracted from plugin.py."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from ..embedding import create_embedding_api_adapter
|
||||
from ..retrieval import SparseBM25Config, SparseBM25Index
|
||||
from ..storage import (
|
||||
GraphStore,
|
||||
MetadataStore,
|
||||
QuantizationType,
|
||||
SparseMatrixFormat,
|
||||
VectorStore,
|
||||
)
|
||||
from ..utils.runtime_self_check import ensure_runtime_self_check
|
||||
from ..utils.relation_write_service import RelationWriteService
|
||||
|
||||
logger = get_logger("A_Memorix.LifecycleOrchestrator")
|
||||
|
||||
|
||||
async def ensure_initialized(plugin: Any) -> None:
|
||||
if plugin._initialized:
|
||||
plugin._runtime_ready = plugin._check_storage_ready()
|
||||
return
|
||||
|
||||
async with plugin._init_lock:
|
||||
if plugin._initialized:
|
||||
plugin._runtime_ready = plugin._check_storage_ready()
|
||||
return
|
||||
|
||||
logger.info("A_Memorix 插件正在异步初始化存储组件...")
|
||||
plugin._validate_runtime_config()
|
||||
await initialize_storage_async(plugin)
|
||||
report = await ensure_runtime_self_check(plugin, force=True)
|
||||
if not bool(report.get("ok", False)):
|
||||
logger.error(
|
||||
"A_Memorix runtime self-check failed: "
|
||||
f"{report.get('message', 'unknown')}; "
|
||||
"建议执行 python plugins/A_memorix/scripts/runtime_self_check.py --json"
|
||||
)
|
||||
|
||||
if plugin.graph_store and plugin.metadata_store:
|
||||
relation_count = plugin.metadata_store.count_relations()
|
||||
if relation_count > 0 and not plugin.graph_store.has_edge_hash_map():
|
||||
raise RuntimeError(
|
||||
"检测到 relations 数据存在但 edge-hash-map 为空。"
|
||||
" 请先执行 scripts/release_vnext_migrate.py migrate。"
|
||||
)
|
||||
|
||||
plugin._initialized = True
|
||||
plugin._runtime_ready = plugin._check_storage_ready()
|
||||
plugin._update_plugin_config()
|
||||
logger.info("A_Memorix 插件异步初始化成功")
|
||||
|
||||
|
||||
def start_background_tasks(plugin: Any) -> None:
|
||||
"""Start background tasks idempotently."""
|
||||
if not hasattr(plugin, "_episode_generation_task"):
|
||||
plugin._episode_generation_task = None
|
||||
|
||||
if (
|
||||
plugin.get_config("summarization.enabled", True)
|
||||
and plugin.get_config("schedule.enabled", True)
|
||||
and (plugin._scheduled_import_task is None or plugin._scheduled_import_task.done())
|
||||
):
|
||||
plugin._scheduled_import_task = asyncio.create_task(plugin._scheduled_import_loop())
|
||||
|
||||
if (
|
||||
plugin.get_config("advanced.enable_auto_save", True)
|
||||
and (plugin._auto_save_task is None or plugin._auto_save_task.done())
|
||||
):
|
||||
plugin._auto_save_task = asyncio.create_task(plugin._auto_save_loop())
|
||||
|
||||
if (
|
||||
plugin.get_config("person_profile.enabled", True)
|
||||
and (plugin._person_profile_refresh_task is None or plugin._person_profile_refresh_task.done())
|
||||
):
|
||||
plugin._person_profile_refresh_task = asyncio.create_task(plugin._person_profile_refresh_loop())
|
||||
|
||||
if plugin._memory_maintenance_task is None or plugin._memory_maintenance_task.done():
|
||||
plugin._memory_maintenance_task = asyncio.create_task(plugin._memory_maintenance_loop())
|
||||
|
||||
rv_cfg = plugin.get_config("retrieval.relation_vectorization", {}) or {}
|
||||
if isinstance(rv_cfg, dict):
|
||||
rv_enabled = bool(rv_cfg.get("enabled", False))
|
||||
rv_backfill = bool(rv_cfg.get("backfill_enabled", False))
|
||||
else:
|
||||
rv_enabled = False
|
||||
rv_backfill = False
|
||||
if rv_enabled and rv_backfill and (
|
||||
plugin._relation_vector_backfill_task is None or plugin._relation_vector_backfill_task.done()
|
||||
):
|
||||
plugin._relation_vector_backfill_task = asyncio.create_task(plugin._relation_vector_backfill_loop())
|
||||
|
||||
episode_task = getattr(plugin, "_episode_generation_task", None)
|
||||
episode_loop = getattr(plugin, "_episode_generation_loop", None)
|
||||
if (
|
||||
callable(episode_loop)
|
||||
and bool(plugin.get_config("episode.enabled", True))
|
||||
and bool(plugin.get_config("episode.generation_enabled", True))
|
||||
and (episode_task is None or episode_task.done())
|
||||
):
|
||||
plugin._episode_generation_task = asyncio.create_task(episode_loop())
|
||||
|
||||
|
||||
async def cancel_background_tasks(plugin: Any) -> None:
|
||||
"""Cancel all background tasks and wait for cleanup."""
|
||||
tasks = [
|
||||
("scheduled_import", plugin._scheduled_import_task),
|
||||
("auto_save", plugin._auto_save_task),
|
||||
("person_profile_refresh", plugin._person_profile_refresh_task),
|
||||
("memory_maintenance", plugin._memory_maintenance_task),
|
||||
("relation_vector_backfill", plugin._relation_vector_backfill_task),
|
||||
("episode_generation", getattr(plugin, "_episode_generation_task", None)),
|
||||
]
|
||||
for _, task in tasks:
|
||||
if task and not task.done():
|
||||
task.cancel()
|
||||
|
||||
for name, task in tasks:
|
||||
if not task:
|
||||
continue
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(f"后台任务 {name} 退出异常: {e}")
|
||||
|
||||
plugin._scheduled_import_task = None
|
||||
plugin._auto_save_task = None
|
||||
plugin._person_profile_refresh_task = None
|
||||
plugin._memory_maintenance_task = None
|
||||
plugin._relation_vector_backfill_task = None
|
||||
plugin._episode_generation_task = None
|
||||
|
||||
|
||||
async def initialize_storage_async(plugin: Any) -> None:
|
||||
"""Initialize storage components asynchronously."""
|
||||
data_dir_str = plugin.get_config("storage.data_dir", "./data")
|
||||
if data_dir_str.startswith("."):
|
||||
plugin_dir = Path(__file__).resolve().parents[2]
|
||||
data_dir = (plugin_dir / data_dir_str).resolve()
|
||||
else:
|
||||
data_dir = Path(data_dir_str)
|
||||
|
||||
logger.info(f"A_Memorix 数据存储路径: {data_dir}")
|
||||
data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
plugin.embedding_manager = create_embedding_api_adapter(
|
||||
batch_size=plugin.get_config("embedding.batch_size", 32),
|
||||
max_concurrent=plugin.get_config("embedding.max_concurrent", 5),
|
||||
default_dimension=plugin.get_config("embedding.dimension", 1024),
|
||||
model_name=plugin.get_config("embedding.model_name", "auto"),
|
||||
retry_config=plugin.get_config("embedding.retry", {}),
|
||||
)
|
||||
logger.info("嵌入 API 适配器初始化完成")
|
||||
|
||||
try:
|
||||
detected_dimension = await plugin.embedding_manager._detect_dimension()
|
||||
logger.info(f"嵌入维度检测成功: {detected_dimension}")
|
||||
except Exception as e:
|
||||
logger.warning(f"嵌入维度检测失败: {e},使用默认值")
|
||||
detected_dimension = plugin.embedding_manager.default_dimension
|
||||
|
||||
quantization_str = plugin.get_config("embedding.quantization_type", "int8")
|
||||
if str(quantization_str or "").strip().lower() != "int8":
|
||||
raise ValueError("embedding.quantization_type 在 vNext 仅允许 int8(SQ8)。")
|
||||
quantization_type = QuantizationType.INT8
|
||||
|
||||
plugin.vector_store = VectorStore(
|
||||
dimension=detected_dimension,
|
||||
quantization_type=quantization_type,
|
||||
data_dir=data_dir / "vectors",
|
||||
)
|
||||
plugin.vector_store.min_train_threshold = plugin.get_config("embedding.min_train_threshold", 40)
|
||||
logger.info(
|
||||
"向量存储初始化完成("
|
||||
f"维度: {detected_dimension}, "
|
||||
f"训练阈值: {plugin.vector_store.min_train_threshold})"
|
||||
)
|
||||
|
||||
matrix_format_str = plugin.get_config("graph.sparse_matrix_format", "csr")
|
||||
matrix_format_map = {
|
||||
"csr": SparseMatrixFormat.CSR,
|
||||
"csc": SparseMatrixFormat.CSC,
|
||||
}
|
||||
matrix_format = matrix_format_map.get(matrix_format_str, SparseMatrixFormat.CSR)
|
||||
|
||||
plugin.graph_store = GraphStore(
|
||||
matrix_format=matrix_format,
|
||||
data_dir=data_dir / "graph",
|
||||
)
|
||||
logger.info("图存储初始化完成")
|
||||
|
||||
plugin.metadata_store = MetadataStore(data_dir=data_dir / "metadata")
|
||||
plugin.metadata_store.connect()
|
||||
logger.info("元数据存储初始化完成")
|
||||
|
||||
plugin.relation_write_service = RelationWriteService(
|
||||
metadata_store=plugin.metadata_store,
|
||||
graph_store=plugin.graph_store,
|
||||
vector_store=plugin.vector_store,
|
||||
embedding_manager=plugin.embedding_manager,
|
||||
)
|
||||
logger.info("关系写入服务初始化完成")
|
||||
|
||||
sparse_cfg_raw = plugin.get_config("retrieval.sparse", {}) or {}
|
||||
if not isinstance(sparse_cfg_raw, dict):
|
||||
sparse_cfg_raw = {}
|
||||
try:
|
||||
sparse_cfg = SparseBM25Config(**sparse_cfg_raw)
|
||||
except Exception as e:
|
||||
logger.warning(f"sparse 配置非法,回退默认配置: {e}")
|
||||
sparse_cfg = SparseBM25Config()
|
||||
plugin.sparse_index = SparseBM25Index(
|
||||
metadata_store=plugin.metadata_store,
|
||||
config=sparse_cfg,
|
||||
)
|
||||
logger.info(
|
||||
"稀疏检索组件初始化完成: "
|
||||
f"enabled={sparse_cfg.enabled}, "
|
||||
f"lazy_load={sparse_cfg.lazy_load}, "
|
||||
f"mode={sparse_cfg.mode}, "
|
||||
f"tokenizer={sparse_cfg.tokenizer_mode}"
|
||||
)
|
||||
if sparse_cfg.enabled and not sparse_cfg.lazy_load:
|
||||
plugin.sparse_index.ensure_loaded()
|
||||
|
||||
if plugin.vector_store.has_data():
|
||||
try:
|
||||
plugin.vector_store.load()
|
||||
logger.info(f"向量数据已加载,共 {plugin.vector_store.num_vectors} 个向量")
|
||||
except Exception as e:
|
||||
logger.warning(f"加载向量数据失败: {e}")
|
||||
|
||||
try:
|
||||
warmup_summary = plugin.vector_store.warmup_index(force_train=True)
|
||||
if warmup_summary.get("ok"):
|
||||
logger.info(
|
||||
"向量索引预热完成: "
|
||||
f"trained={warmup_summary.get('trained')}, "
|
||||
f"index_ntotal={warmup_summary.get('index_ntotal')}, "
|
||||
f"fallback_ntotal={warmup_summary.get('fallback_ntotal')}, "
|
||||
f"bin_count={warmup_summary.get('bin_count')}, "
|
||||
f"duration_ms={float(warmup_summary.get('duration_ms', 0.0)):.2f}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"向量索引预热失败,继续启用 sparse 降级路径: "
|
||||
f"{warmup_summary.get('error', 'unknown')}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"向量索引预热异常,继续启用 sparse 降级路径: {e}")
|
||||
|
||||
if plugin.graph_store.has_data():
|
||||
try:
|
||||
plugin.graph_store.load()
|
||||
logger.info(f"图数据已加载,共 {plugin.graph_store.num_nodes} 个节点")
|
||||
except Exception as e:
|
||||
logger.warning(f"加载图数据失败: {e}")
|
||||
|
||||
logger.info(f"知识库数据目录: {data_dir}")
|
||||
File diff suppressed because it is too large
Load Diff
240
plugins/A_memorix/core/runtime/search_runtime_initializer.py
Normal file
240
plugins/A_memorix/core/runtime/search_runtime_initializer.py
Normal file
@@ -0,0 +1,240 @@
|
||||
"""Shared runtime initializer for Action/Tool/Command retrieval components."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from ..retrieval import (
|
||||
DualPathRetriever,
|
||||
DualPathRetrieverConfig,
|
||||
DynamicThresholdFilter,
|
||||
FusionConfig,
|
||||
GraphRelationRecallConfig,
|
||||
RelationIntentConfig,
|
||||
RetrievalStrategy,
|
||||
SparseBM25Config,
|
||||
ThresholdConfig,
|
||||
ThresholdMethod,
|
||||
)
|
||||
|
||||
_logger = get_logger("A_Memorix.SearchRuntimeInitializer")
|
||||
|
||||
_REQUIRED_COMPONENT_KEYS = (
|
||||
"vector_store",
|
||||
"graph_store",
|
||||
"metadata_store",
|
||||
"embedding_manager",
|
||||
)
|
||||
|
||||
|
||||
def _get_config_value(config: Optional[dict], key: str, default: Any = None) -> Any:
|
||||
if not isinstance(config, dict):
|
||||
return default
|
||||
current: Any = config
|
||||
for part in key.split("."):
|
||||
if isinstance(current, dict) and part in current:
|
||||
current = current[part]
|
||||
else:
|
||||
return default
|
||||
return current
|
||||
|
||||
|
||||
def _safe_dict(value: Any) -> Dict[str, Any]:
|
||||
return value if isinstance(value, dict) else {}
|
||||
|
||||
|
||||
def _resolve_debug_enabled(plugin_config: Optional[dict]) -> bool:
|
||||
advanced = _get_config_value(plugin_config, "advanced", {})
|
||||
if isinstance(advanced, dict):
|
||||
return bool(advanced.get("debug", False))
|
||||
return bool(_get_config_value(plugin_config, "debug", False))
|
||||
|
||||
|
||||
@dataclass
|
||||
class SearchRuntimeBundle:
|
||||
"""Resolved runtime components and initialized retriever/filter."""
|
||||
|
||||
vector_store: Optional[Any] = None
|
||||
graph_store: Optional[Any] = None
|
||||
metadata_store: Optional[Any] = None
|
||||
embedding_manager: Optional[Any] = None
|
||||
sparse_index: Optional[Any] = None
|
||||
retriever: Optional[DualPathRetriever] = None
|
||||
threshold_filter: Optional[DynamicThresholdFilter] = None
|
||||
error: str = ""
|
||||
|
||||
@property
|
||||
def ready(self) -> bool:
|
||||
return (
|
||||
self.retriever is not None
|
||||
and self.vector_store is not None
|
||||
and self.graph_store is not None
|
||||
and self.metadata_store is not None
|
||||
and self.embedding_manager is not None
|
||||
)
|
||||
|
||||
|
||||
def _resolve_runtime_components(plugin_config: Optional[dict]) -> SearchRuntimeBundle:
|
||||
bundle = SearchRuntimeBundle(
|
||||
vector_store=_get_config_value(plugin_config, "vector_store"),
|
||||
graph_store=_get_config_value(plugin_config, "graph_store"),
|
||||
metadata_store=_get_config_value(plugin_config, "metadata_store"),
|
||||
embedding_manager=_get_config_value(plugin_config, "embedding_manager"),
|
||||
sparse_index=_get_config_value(plugin_config, "sparse_index"),
|
||||
)
|
||||
|
||||
missing_required = any(
|
||||
getattr(bundle, key) is None for key in _REQUIRED_COMPONENT_KEYS
|
||||
)
|
||||
if not missing_required:
|
||||
return bundle
|
||||
|
||||
try:
|
||||
from ...plugin import AMemorixPlugin
|
||||
|
||||
instances = AMemorixPlugin.get_storage_instances()
|
||||
except Exception:
|
||||
instances = {}
|
||||
|
||||
if not isinstance(instances, dict) or not instances:
|
||||
return bundle
|
||||
|
||||
if bundle.vector_store is None:
|
||||
bundle.vector_store = instances.get("vector_store")
|
||||
if bundle.graph_store is None:
|
||||
bundle.graph_store = instances.get("graph_store")
|
||||
if bundle.metadata_store is None:
|
||||
bundle.metadata_store = instances.get("metadata_store")
|
||||
if bundle.embedding_manager is None:
|
||||
bundle.embedding_manager = instances.get("embedding_manager")
|
||||
if bundle.sparse_index is None:
|
||||
bundle.sparse_index = instances.get("sparse_index")
|
||||
return bundle
|
||||
|
||||
|
||||
def build_search_runtime(
|
||||
plugin_config: Optional[dict],
|
||||
logger_obj: Optional[Any],
|
||||
owner_tag: str,
|
||||
*,
|
||||
log_prefix: str = "",
|
||||
) -> SearchRuntimeBundle:
|
||||
"""Build retriever + threshold filter with unified fallback/config parsing."""
|
||||
|
||||
log = logger_obj or _logger
|
||||
owner = str(owner_tag or "runtime").strip().lower() or "runtime"
|
||||
prefix = str(log_prefix or "").strip()
|
||||
prefix_text = f"{prefix} " if prefix else ""
|
||||
|
||||
runtime = _resolve_runtime_components(plugin_config)
|
||||
if any(getattr(runtime, key) is None for key in _REQUIRED_COMPONENT_KEYS):
|
||||
runtime.error = "存储组件未完全初始化"
|
||||
log.warning(f"{prefix_text}[{owner}] 存储组件未完全初始化,无法使用检索功能")
|
||||
return runtime
|
||||
|
||||
sparse_cfg_raw = _safe_dict(_get_config_value(plugin_config, "retrieval.sparse", {}) or {})
|
||||
fusion_cfg_raw = _safe_dict(_get_config_value(plugin_config, "retrieval.fusion", {}) or {})
|
||||
relation_intent_cfg_raw = _safe_dict(
|
||||
_get_config_value(plugin_config, "retrieval.search.relation_intent", {}) or {}
|
||||
)
|
||||
graph_recall_cfg_raw = _safe_dict(
|
||||
_get_config_value(plugin_config, "retrieval.search.graph_recall", {}) or {}
|
||||
)
|
||||
|
||||
try:
|
||||
sparse_cfg = SparseBM25Config(**sparse_cfg_raw)
|
||||
except Exception as e:
|
||||
log.warning(f"{prefix_text}[{owner}] sparse 配置非法,回退默认: {e}")
|
||||
sparse_cfg = SparseBM25Config()
|
||||
|
||||
try:
|
||||
fusion_cfg = FusionConfig(**fusion_cfg_raw)
|
||||
except Exception as e:
|
||||
log.warning(f"{prefix_text}[{owner}] fusion 配置非法,回退默认: {e}")
|
||||
fusion_cfg = FusionConfig()
|
||||
|
||||
try:
|
||||
relation_intent_cfg = RelationIntentConfig(**relation_intent_cfg_raw)
|
||||
except Exception as e:
|
||||
log.warning(f"{prefix_text}[{owner}] relation_intent 配置非法,回退默认: {e}")
|
||||
relation_intent_cfg = RelationIntentConfig()
|
||||
|
||||
try:
|
||||
graph_recall_cfg = GraphRelationRecallConfig(**graph_recall_cfg_raw)
|
||||
except Exception as e:
|
||||
log.warning(f"{prefix_text}[{owner}] graph_recall 配置非法,回退默认: {e}")
|
||||
graph_recall_cfg = GraphRelationRecallConfig()
|
||||
|
||||
try:
|
||||
config = DualPathRetrieverConfig(
|
||||
top_k_paragraphs=_get_config_value(plugin_config, "retrieval.top_k_paragraphs", 20),
|
||||
top_k_relations=_get_config_value(plugin_config, "retrieval.top_k_relations", 10),
|
||||
top_k_final=_get_config_value(plugin_config, "retrieval.top_k_final", 10),
|
||||
alpha=_get_config_value(plugin_config, "retrieval.alpha", 0.5),
|
||||
enable_ppr=_get_config_value(plugin_config, "retrieval.enable_ppr", True),
|
||||
ppr_alpha=_get_config_value(plugin_config, "retrieval.ppr_alpha", 0.85),
|
||||
ppr_timeout_seconds=_get_config_value(
|
||||
plugin_config, "retrieval.ppr_timeout_seconds", 1.5
|
||||
),
|
||||
ppr_concurrency_limit=_get_config_value(
|
||||
plugin_config, "retrieval.ppr_concurrency_limit", 4
|
||||
),
|
||||
enable_parallel=_get_config_value(plugin_config, "retrieval.enable_parallel", True),
|
||||
retrieval_strategy=RetrievalStrategy.DUAL_PATH,
|
||||
debug=_resolve_debug_enabled(plugin_config),
|
||||
sparse=sparse_cfg,
|
||||
fusion=fusion_cfg,
|
||||
relation_intent=relation_intent_cfg,
|
||||
graph_recall=graph_recall_cfg,
|
||||
)
|
||||
|
||||
runtime.retriever = DualPathRetriever(
|
||||
vector_store=runtime.vector_store,
|
||||
graph_store=runtime.graph_store,
|
||||
metadata_store=runtime.metadata_store,
|
||||
embedding_manager=runtime.embedding_manager,
|
||||
sparse_index=runtime.sparse_index,
|
||||
config=config,
|
||||
)
|
||||
|
||||
threshold_config = ThresholdConfig(
|
||||
method=ThresholdMethod.ADAPTIVE,
|
||||
min_threshold=_get_config_value(plugin_config, "threshold.min_threshold", 0.3),
|
||||
max_threshold=_get_config_value(plugin_config, "threshold.max_threshold", 0.95),
|
||||
percentile=_get_config_value(plugin_config, "threshold.percentile", 75.0),
|
||||
std_multiplier=_get_config_value(plugin_config, "threshold.std_multiplier", 1.5),
|
||||
min_results=_get_config_value(plugin_config, "threshold.min_results", 3),
|
||||
enable_auto_adjust=_get_config_value(plugin_config, "threshold.enable_auto_adjust", True),
|
||||
)
|
||||
runtime.threshold_filter = DynamicThresholdFilter(threshold_config)
|
||||
runtime.error = ""
|
||||
log.info(f"{prefix_text}[{owner}] 检索运行时初始化完成")
|
||||
except Exception as e:
|
||||
runtime.retriever = None
|
||||
runtime.threshold_filter = None
|
||||
runtime.error = str(e)
|
||||
log.error(f"{prefix_text}[{owner}] 检索运行时初始化失败: {e}")
|
||||
|
||||
return runtime
|
||||
|
||||
|
||||
class SearchRuntimeInitializer:
|
||||
"""Compatibility wrapper around the function style initializer."""
|
||||
|
||||
@staticmethod
|
||||
def build_search_runtime(
|
||||
plugin_config: Optional[dict],
|
||||
logger_obj: Optional[Any],
|
||||
owner_tag: str,
|
||||
*,
|
||||
log_prefix: str = "",
|
||||
) -> SearchRuntimeBundle:
|
||||
return build_search_runtime(
|
||||
plugin_config=plugin_config,
|
||||
logger_obj=logger_obj,
|
||||
owner_tag=owner_tag,
|
||||
log_prefix=log_prefix,
|
||||
)
|
||||
@@ -24,6 +24,20 @@ try:
|
||||
from scipy.sparse.linalg import norm
|
||||
HAS_SCIPY = True
|
||||
except ImportError:
|
||||
class _SparseMatrixPlaceholder:
|
||||
pass
|
||||
|
||||
def _scipy_missing(*args, **kwargs):
|
||||
raise ImportError("SciPy 未安装,请安装: pip install scipy")
|
||||
|
||||
csr_matrix = _SparseMatrixPlaceholder
|
||||
csc_matrix = _SparseMatrixPlaceholder
|
||||
lil_matrix = _SparseMatrixPlaceholder
|
||||
triu = _scipy_missing
|
||||
save_npz = _scipy_missing
|
||||
load_npz = _scipy_missing
|
||||
bmat = _scipy_missing
|
||||
norm = _scipy_missing
|
||||
HAS_SCIPY = False
|
||||
|
||||
import contextlib
|
||||
|
||||
@@ -7,6 +7,8 @@
|
||||
import sqlite3
|
||||
import pickle
|
||||
import json
|
||||
import uuid
|
||||
import re
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union, List, Dict, Any, Tuple
|
||||
@@ -24,7 +26,7 @@ from .knowledge_types import (
|
||||
logger = get_logger("A_Memorix.MetadataStore")
|
||||
|
||||
|
||||
SCHEMA_VERSION = 7
|
||||
SCHEMA_VERSION = 8
|
||||
|
||||
|
||||
class MetadataStore:
|
||||
@@ -500,6 +502,63 @@ class MetadataStore:
|
||||
CREATE INDEX IF NOT EXISTS idx_external_memory_refs_paragraph
|
||||
ON external_memory_refs(paragraph_hash)
|
||||
""")
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS memory_v5_operations (
|
||||
operation_id TEXT PRIMARY KEY,
|
||||
action TEXT NOT NULL,
|
||||
target TEXT,
|
||||
reason TEXT,
|
||||
updated_by TEXT,
|
||||
created_at REAL NOT NULL,
|
||||
resolved_hashes_json TEXT,
|
||||
result_json TEXT
|
||||
)
|
||||
""")
|
||||
cursor.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_memory_v5_operations_created
|
||||
ON memory_v5_operations(created_at DESC)
|
||||
""")
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS delete_operations (
|
||||
operation_id TEXT PRIMARY KEY,
|
||||
mode TEXT NOT NULL,
|
||||
selector TEXT,
|
||||
reason TEXT,
|
||||
requested_by TEXT,
|
||||
status TEXT NOT NULL,
|
||||
created_at REAL NOT NULL,
|
||||
restored_at REAL,
|
||||
summary_json TEXT
|
||||
)
|
||||
""")
|
||||
cursor.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_delete_operations_created
|
||||
ON delete_operations(created_at DESC)
|
||||
""")
|
||||
cursor.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_delete_operations_mode
|
||||
ON delete_operations(mode, created_at DESC)
|
||||
""")
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS delete_operation_items (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
operation_id TEXT NOT NULL,
|
||||
item_type TEXT NOT NULL,
|
||||
item_hash TEXT,
|
||||
item_key TEXT,
|
||||
payload_json TEXT,
|
||||
created_at REAL NOT NULL,
|
||||
FOREIGN KEY (operation_id) REFERENCES delete_operations(operation_id) ON DELETE CASCADE
|
||||
)
|
||||
""")
|
||||
cursor.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_delete_operation_items_operation
|
||||
ON delete_operation_items(operation_id, id ASC)
|
||||
""")
|
||||
cursor.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_delete_operation_items_hash
|
||||
ON delete_operation_items(item_hash)
|
||||
""")
|
||||
# 新版 schema 包含完整字段,直接写入版本信息
|
||||
cursor.execute("INSERT OR IGNORE INTO schema_migrations(version, applied_at) VALUES (?, ?)", (SCHEMA_VERSION, datetime.now().timestamp()))
|
||||
self._conn.commit()
|
||||
@@ -618,6 +677,63 @@ class MetadataStore:
|
||||
CREATE INDEX IF NOT EXISTS idx_external_memory_refs_paragraph
|
||||
ON external_memory_refs(paragraph_hash)
|
||||
""")
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS memory_v5_operations (
|
||||
operation_id TEXT PRIMARY KEY,
|
||||
action TEXT NOT NULL,
|
||||
target TEXT,
|
||||
reason TEXT,
|
||||
updated_by TEXT,
|
||||
created_at REAL NOT NULL,
|
||||
resolved_hashes_json TEXT,
|
||||
result_json TEXT
|
||||
)
|
||||
""")
|
||||
cursor.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_memory_v5_operations_created
|
||||
ON memory_v5_operations(created_at DESC)
|
||||
""")
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS delete_operations (
|
||||
operation_id TEXT PRIMARY KEY,
|
||||
mode TEXT NOT NULL,
|
||||
selector TEXT,
|
||||
reason TEXT,
|
||||
requested_by TEXT,
|
||||
status TEXT NOT NULL,
|
||||
created_at REAL NOT NULL,
|
||||
restored_at REAL,
|
||||
summary_json TEXT
|
||||
)
|
||||
""")
|
||||
cursor.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_delete_operations_created
|
||||
ON delete_operations(created_at DESC)
|
||||
""")
|
||||
cursor.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_delete_operations_mode
|
||||
ON delete_operations(mode, created_at DESC)
|
||||
""")
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS delete_operation_items (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
operation_id TEXT NOT NULL,
|
||||
item_type TEXT NOT NULL,
|
||||
item_hash TEXT,
|
||||
item_key TEXT,
|
||||
payload_json TEXT,
|
||||
created_at REAL NOT NULL,
|
||||
FOREIGN KEY (operation_id) REFERENCES delete_operations(operation_id) ON DELETE CASCADE
|
||||
)
|
||||
""")
|
||||
cursor.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_delete_operation_items_operation
|
||||
ON delete_operation_items(operation_id, id ASC)
|
||||
""")
|
||||
cursor.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_delete_operation_items_hash
|
||||
ON delete_operation_items(item_hash)
|
||||
""")
|
||||
|
||||
# 检查paragraphs表是否有knowledge_type列
|
||||
cursor.execute("PRAGMA table_info(paragraphs)")
|
||||
@@ -2595,6 +2711,328 @@ class MetadataStore:
|
||||
"metadata": metadata or {},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _json_dumps(value: Any) -> str:
|
||||
return json.dumps(value, ensure_ascii=False, sort_keys=True)
|
||||
|
||||
@staticmethod
|
||||
def _json_loads(value: Any, default: Any) -> Any:
|
||||
if value in {None, ""}:
|
||||
return default
|
||||
try:
|
||||
return json.loads(value)
|
||||
except Exception:
|
||||
return default
|
||||
|
||||
def list_external_memory_refs_by_paragraphs(self, paragraph_hashes: List[str]) -> List[Dict[str, Any]]:
|
||||
hashes = [str(item or "").strip() for item in (paragraph_hashes or []) if str(item or "").strip()]
|
||||
if not hashes:
|
||||
return []
|
||||
placeholders = ",".join(["?"] * len(hashes))
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
f"""
|
||||
SELECT external_id, paragraph_hash, source_type, created_at, metadata_json
|
||||
FROM external_memory_refs
|
||||
WHERE paragraph_hash IN ({placeholders})
|
||||
ORDER BY created_at ASC, external_id ASC
|
||||
""",
|
||||
tuple(hashes),
|
||||
)
|
||||
items: List[Dict[str, Any]] = []
|
||||
for row in cursor.fetchall():
|
||||
payload = dict(row)
|
||||
payload["metadata"] = self._json_loads(payload.get("metadata_json"), {})
|
||||
items.append(payload)
|
||||
return items
|
||||
|
||||
def delete_external_memory_refs_by_paragraphs(self, paragraph_hashes: List[str]) -> List[Dict[str, Any]]:
|
||||
items = self.list_external_memory_refs_by_paragraphs(paragraph_hashes)
|
||||
hashes = [str(item or "").strip() for item in (paragraph_hashes or []) if str(item or "").strip()]
|
||||
if not hashes:
|
||||
return items
|
||||
placeholders = ",".join(["?"] * len(hashes))
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
f"DELETE FROM external_memory_refs WHERE paragraph_hash IN ({placeholders})",
|
||||
tuple(hashes),
|
||||
)
|
||||
self._conn.commit()
|
||||
return items
|
||||
|
||||
def restore_external_memory_refs(self, refs: List[Dict[str, Any]]) -> int:
|
||||
count = 0
|
||||
for item in refs or []:
|
||||
external_id = str(item.get("external_id", "") or "").strip()
|
||||
paragraph_hash = str(item.get("paragraph_hash", "") or "").strip()
|
||||
if not external_id or not paragraph_hash:
|
||||
continue
|
||||
created_at = float(item.get("created_at") or datetime.now().timestamp())
|
||||
metadata_json = self._json_dumps(item.get("metadata") or {})
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO external_memory_refs (
|
||||
external_id, paragraph_hash, source_type, created_at, metadata_json
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
ON CONFLICT(external_id) DO UPDATE SET
|
||||
paragraph_hash = excluded.paragraph_hash,
|
||||
source_type = excluded.source_type,
|
||||
created_at = excluded.created_at,
|
||||
metadata_json = excluded.metadata_json
|
||||
""",
|
||||
(
|
||||
external_id,
|
||||
paragraph_hash,
|
||||
str(item.get("source_type", "") or "").strip() or None,
|
||||
created_at,
|
||||
metadata_json,
|
||||
),
|
||||
)
|
||||
count += max(0, int(cursor.rowcount or 0))
|
||||
self._conn.commit()
|
||||
return count
|
||||
|
||||
def record_v5_operation(
|
||||
self,
|
||||
*,
|
||||
action: str,
|
||||
target: str,
|
||||
resolved_hashes: List[str],
|
||||
reason: str = "",
|
||||
updated_by: str = "",
|
||||
result: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
operation_id = f"v5_{uuid.uuid4().hex}"
|
||||
created_at = datetime.now().timestamp()
|
||||
payload = {
|
||||
"operation_id": operation_id,
|
||||
"action": str(action or "").strip(),
|
||||
"target": str(target or "").strip(),
|
||||
"reason": str(reason or "").strip(),
|
||||
"updated_by": str(updated_by or "").strip(),
|
||||
"created_at": created_at,
|
||||
"resolved_hashes": [str(item or "").strip() for item in (resolved_hashes or []) if str(item or "").strip()],
|
||||
"result": result or {},
|
||||
}
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO memory_v5_operations (
|
||||
operation_id, action, target, reason, updated_by, created_at, resolved_hashes_json, result_json
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
operation_id,
|
||||
payload["action"],
|
||||
payload["target"] or None,
|
||||
payload["reason"] or None,
|
||||
payload["updated_by"] or None,
|
||||
created_at,
|
||||
self._json_dumps(payload["resolved_hashes"]),
|
||||
self._json_dumps(payload["result"]),
|
||||
),
|
||||
)
|
||||
self._conn.commit()
|
||||
return payload
|
||||
|
||||
def create_delete_operation(
|
||||
self,
|
||||
*,
|
||||
mode: str,
|
||||
selector: Any,
|
||||
items: List[Dict[str, Any]],
|
||||
reason: str = "",
|
||||
requested_by: str = "",
|
||||
status: str = "executed",
|
||||
summary: Optional[Dict[str, Any]] = None,
|
||||
operation_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
op_id = str(operation_id or f"del_{uuid.uuid4().hex}").strip()
|
||||
created_at = datetime.now().timestamp()
|
||||
normalized_items: List[Dict[str, Any]] = []
|
||||
for item in items or []:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
item_type = str(item.get("item_type", "") or "").strip()
|
||||
if not item_type:
|
||||
continue
|
||||
normalized_items.append(
|
||||
{
|
||||
"item_type": item_type,
|
||||
"item_hash": str(item.get("item_hash", "") or "").strip() or None,
|
||||
"item_key": str(item.get("item_key", "") or item.get("item_hash", "") or "").strip() or None,
|
||||
"payload": item.get("payload") if isinstance(item.get("payload"), dict) else {},
|
||||
}
|
||||
)
|
||||
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO delete_operations (
|
||||
operation_id, mode, selector, reason, requested_by, status, created_at, restored_at, summary_json
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, NULL, ?)
|
||||
""",
|
||||
(
|
||||
op_id,
|
||||
str(mode or "").strip(),
|
||||
self._json_dumps(selector if selector is not None else {}),
|
||||
str(reason or "").strip() or None,
|
||||
str(requested_by or "").strip() or None,
|
||||
str(status or "executed").strip(),
|
||||
created_at,
|
||||
self._json_dumps(summary or {}),
|
||||
),
|
||||
)
|
||||
if normalized_items:
|
||||
cursor.executemany(
|
||||
"""
|
||||
INSERT INTO delete_operation_items (
|
||||
operation_id, item_type, item_hash, item_key, payload_json, created_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
[
|
||||
(
|
||||
op_id,
|
||||
item["item_type"],
|
||||
item["item_hash"],
|
||||
item["item_key"],
|
||||
self._json_dumps(item["payload"]),
|
||||
created_at,
|
||||
)
|
||||
for item in normalized_items
|
||||
],
|
||||
)
|
||||
self._conn.commit()
|
||||
return self.get_delete_operation(op_id) or {
|
||||
"operation_id": op_id,
|
||||
"mode": str(mode or "").strip(),
|
||||
"selector": selector,
|
||||
"reason": str(reason or "").strip(),
|
||||
"requested_by": str(requested_by or "").strip(),
|
||||
"status": str(status or "executed").strip(),
|
||||
"created_at": created_at,
|
||||
"summary": summary or {},
|
||||
"items": normalized_items,
|
||||
}
|
||||
|
||||
def mark_delete_operation_restored(
|
||||
self,
|
||||
operation_id: str,
|
||||
*,
|
||||
summary: Optional[Dict[str, Any]] = None,
|
||||
) -> bool:
|
||||
token = str(operation_id or "").strip()
|
||||
if not token:
|
||||
return False
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
UPDATE delete_operations
|
||||
SET status = ?, restored_at = ?, summary_json = ?
|
||||
WHERE operation_id = ?
|
||||
""",
|
||||
(
|
||||
"restored",
|
||||
datetime.now().timestamp(),
|
||||
self._json_dumps(summary or {}),
|
||||
token,
|
||||
),
|
||||
)
|
||||
self._conn.commit()
|
||||
return cursor.rowcount > 0
|
||||
|
||||
def list_delete_operations(self, *, limit: int = 50, mode: str = "") -> List[Dict[str, Any]]:
|
||||
cursor = self._conn.cursor()
|
||||
params: List[Any] = []
|
||||
where = ""
|
||||
mode_token = str(mode or "").strip().lower()
|
||||
if mode_token:
|
||||
where = "WHERE LOWER(mode) = ?"
|
||||
params.append(mode_token)
|
||||
params.append(max(1, int(limit or 50)))
|
||||
cursor.execute(
|
||||
f"""
|
||||
SELECT operation_id, mode, selector, reason, requested_by, status, created_at, restored_at, summary_json
|
||||
FROM delete_operations
|
||||
{where}
|
||||
ORDER BY created_at DESC
|
||||
LIMIT ?
|
||||
""",
|
||||
tuple(params),
|
||||
)
|
||||
items: List[Dict[str, Any]] = []
|
||||
for row in cursor.fetchall():
|
||||
payload = dict(row)
|
||||
payload["selector"] = self._json_loads(payload.get("selector"), {})
|
||||
payload["summary"] = self._json_loads(payload.get("summary_json"), {})
|
||||
items.append(payload)
|
||||
return items
|
||||
|
||||
def get_delete_operation(self, operation_id: str) -> Optional[Dict[str, Any]]:
|
||||
token = str(operation_id or "").strip()
|
||||
if not token:
|
||||
return None
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT operation_id, mode, selector, reason, requested_by, status, created_at, restored_at, summary_json
|
||||
FROM delete_operations
|
||||
WHERE operation_id = ?
|
||||
LIMIT 1
|
||||
""",
|
||||
(token,),
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
|
||||
payload = dict(row)
|
||||
payload["selector"] = self._json_loads(payload.get("selector"), {})
|
||||
payload["summary"] = self._json_loads(payload.get("summary_json"), {})
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT item_type, item_hash, item_key, payload_json, created_at
|
||||
FROM delete_operation_items
|
||||
WHERE operation_id = ?
|
||||
ORDER BY id ASC
|
||||
""",
|
||||
(token,),
|
||||
)
|
||||
payload["items"] = [
|
||||
{
|
||||
"item_type": str(item["item_type"] or ""),
|
||||
"item_hash": str(item["item_hash"] or ""),
|
||||
"item_key": str(item["item_key"] or ""),
|
||||
"payload": self._json_loads(item["payload_json"], {}),
|
||||
"created_at": item["created_at"],
|
||||
}
|
||||
for item in cursor.fetchall()
|
||||
]
|
||||
return payload
|
||||
|
||||
def purge_deleted_relations(self, *, cutoff_time: float, limit: int = 1000) -> List[str]:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT hash
|
||||
FROM deleted_relations
|
||||
WHERE deleted_at IS NOT NULL AND deleted_at < ?
|
||||
ORDER BY deleted_at ASC
|
||||
LIMIT ?
|
||||
""",
|
||||
(float(cutoff_time), max(1, int(limit or 1000))),
|
||||
)
|
||||
hashes = [str(row[0] or "").strip() for row in cursor.fetchall() if str(row[0] or "").strip()]
|
||||
if not hashes:
|
||||
return []
|
||||
placeholders = ",".join(["?"] * len(hashes))
|
||||
cursor.execute(f"DELETE FROM deleted_relations WHERE hash IN ({placeholders})", tuple(hashes))
|
||||
self._conn.commit()
|
||||
return hashes
|
||||
|
||||
def get_statistics(self) -> Dict[str, int]:
|
||||
"""
|
||||
获取统计信息
|
||||
@@ -2956,6 +3394,18 @@ class MetadataStore:
|
||||
self._conn.commit()
|
||||
return changed
|
||||
|
||||
def restore_paragraph_by_hash(self, paragraph_hash: str) -> bool:
|
||||
"""恢复软删除段落。"""
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"UPDATE paragraphs SET is_deleted=0, deleted_at=NULL WHERE hash=?",
|
||||
(str(paragraph_hash),),
|
||||
)
|
||||
changed = cursor.rowcount > 0
|
||||
if changed:
|
||||
self._conn.commit()
|
||||
return changed
|
||||
|
||||
def backfill_temporal_metadata_from_created_at(
|
||||
self,
|
||||
*,
|
||||
@@ -4698,6 +5148,29 @@ class MetadataStore:
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
def get_episode_pending_status_counts(self, source: str) -> Dict[str, int]:
|
||||
"""统计某个 source 当前 pending 队列中的状态分布。"""
|
||||
token = self._normalize_episode_source(source)
|
||||
if not token:
|
||||
return {"pending": 0, "running": 0, "failed": 0, "done": 0}
|
||||
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT status, COUNT(*) AS count
|
||||
FROM episode_pending_paragraphs
|
||||
WHERE TRIM(COALESCE(source, '')) = ?
|
||||
GROUP BY status
|
||||
""",
|
||||
(token,),
|
||||
)
|
||||
counts = {"pending": 0, "running": 0, "failed": 0, "done": 0}
|
||||
for row in cursor.fetchall():
|
||||
status = str(row["status"] or "").strip().lower()
|
||||
if status in counts:
|
||||
counts[status] = int(row["count"] or 0)
|
||||
return counts
|
||||
|
||||
def _episode_row_to_dict(self, row: sqlite3.Row) -> Dict[str, Any]:
|
||||
data = dict(row)
|
||||
|
||||
@@ -4904,7 +5377,7 @@ class MetadataStore:
|
||||
SELECT 1
|
||||
FROM episode_rebuild_sources ers
|
||||
WHERE ers.source = TRIM(COALESCE(e.source, ''))
|
||||
AND ers.status IN ('pending', 'running', 'failed')
|
||||
AND ers.status IN ('pending', 'running')
|
||||
)
|
||||
"""
|
||||
)
|
||||
@@ -4948,6 +5421,26 @@ class MetadataStore:
|
||||
|
||||
return source_expr, effective_start, effective_end, conditions, params
|
||||
|
||||
@staticmethod
|
||||
def _tokenize_episode_query(query: str) -> Tuple[str, List[str]]:
|
||||
"""将 episode 查询归一化为短语和 token。"""
|
||||
normalized = normalize_text(str(query or "")).strip().lower()
|
||||
if not normalized:
|
||||
return "", []
|
||||
|
||||
token_pattern = re.compile(r"[A-Za-z0-9_\u4e00-\u9fff]{2,}")
|
||||
tokens: List[str] = []
|
||||
seen = set()
|
||||
for token in token_pattern.findall(normalized):
|
||||
if token in seen:
|
||||
continue
|
||||
seen.add(token)
|
||||
tokens.append(token)
|
||||
|
||||
if not tokens and len(normalized) >= 2:
|
||||
tokens = [normalized]
|
||||
return normalized, tokens
|
||||
|
||||
def get_episode_rows_by_paragraph_hashes(
|
||||
self,
|
||||
paragraph_hashes: List[str],
|
||||
@@ -5097,28 +5590,58 @@ class MetadataStore:
|
||||
source=source,
|
||||
)
|
||||
|
||||
q = str(query or "").strip().lower()
|
||||
q, tokens = self._tokenize_episode_query(query)
|
||||
select_score_sql = "0.0 AS lexical_score"
|
||||
order_sql = f"{effective_end} DESC, e.updated_at DESC"
|
||||
select_params: List[Any] = []
|
||||
query_params: List[Any] = []
|
||||
if q:
|
||||
like = f"%{q}%"
|
||||
title_expr = "LOWER(COALESCE(e.title, '')) LIKE ?"
|
||||
summary_expr = "LOWER(COALESCE(e.summary, '')) LIKE ?"
|
||||
keywords_expr = "LOWER(COALESCE(e.keywords_json, '')) LIKE ?"
|
||||
participants_expr = "LOWER(COALESCE(e.participants_json, '')) LIKE ?"
|
||||
conditions.append(
|
||||
f"({title_expr} OR {summary_expr} OR {keywords_expr} OR {participants_expr})"
|
||||
field_exprs = {
|
||||
"title": "LOWER(COALESCE(e.title, ''))",
|
||||
"summary": "LOWER(COALESCE(e.summary, ''))",
|
||||
"keywords": "LOWER(COALESCE(e.keywords_json, ''))",
|
||||
"participants": "LOWER(COALESCE(e.participants_json, ''))",
|
||||
}
|
||||
|
||||
score_parts: List[str] = []
|
||||
phrase_like = f"%{q}%"
|
||||
score_parts.extend(
|
||||
[
|
||||
f"CASE WHEN {field_exprs['title']} LIKE ? THEN 6.0 ELSE 0.0 END",
|
||||
f"CASE WHEN {field_exprs['keywords']} LIKE ? THEN 4.5 ELSE 0.0 END",
|
||||
f"CASE WHEN {field_exprs['summary']} LIKE ? THEN 3.0 ELSE 0.0 END",
|
||||
f"CASE WHEN {field_exprs['participants']} LIKE ? THEN 2.0 ELSE 0.0 END",
|
||||
]
|
||||
)
|
||||
select_score_sql = (
|
||||
f"(CASE WHEN {title_expr} THEN 4.0 ELSE 0.0 END + "
|
||||
f"CASE WHEN {keywords_expr} THEN 3.0 ELSE 0.0 END + "
|
||||
f"CASE WHEN {summary_expr} THEN 2.0 ELSE 0.0 END + "
|
||||
f"CASE WHEN {participants_expr} THEN 1.0 ELSE 0.0 END) AS lexical_score"
|
||||
)
|
||||
select_params.extend([like, like, like, like])
|
||||
query_params.extend([like, like, like, like])
|
||||
select_params.extend([phrase_like, phrase_like, phrase_like, phrase_like])
|
||||
|
||||
token_predicates: List[str] = []
|
||||
for token in tokens:
|
||||
like = f"%{token}%"
|
||||
token_any = (
|
||||
f"({field_exprs['title']} LIKE ? OR "
|
||||
f"{field_exprs['summary']} LIKE ? OR "
|
||||
f"{field_exprs['keywords']} LIKE ? OR "
|
||||
f"{field_exprs['participants']} LIKE ?)"
|
||||
)
|
||||
token_predicates.append(token_any)
|
||||
query_params.extend([like, like, like, like])
|
||||
|
||||
score_parts.append(
|
||||
"("
|
||||
f"CASE WHEN {field_exprs['title']} LIKE ? THEN 3.0 ELSE 0.0 END + "
|
||||
f"CASE WHEN {field_exprs['keywords']} LIKE ? THEN 2.5 ELSE 0.0 END + "
|
||||
f"CASE WHEN {field_exprs['summary']} LIKE ? THEN 2.0 ELSE 0.0 END + "
|
||||
f"CASE WHEN {field_exprs['participants']} LIKE ? THEN 1.5 ELSE 0.0 END + "
|
||||
f"CASE WHEN {token_any.replace('?', '?')} THEN 2.0 ELSE 0.0 END"
|
||||
")"
|
||||
)
|
||||
select_params.extend([like, like, like, like, like, like, like, like])
|
||||
|
||||
if token_predicates:
|
||||
conditions.append("(" + " OR ".join(token_predicates) + ")")
|
||||
|
||||
select_score_sql = f"({' + '.join(score_parts)}) AS lexical_score"
|
||||
order_sql = f"lexical_score DESC, {effective_end} DESC, e.updated_at DESC"
|
||||
|
||||
where_sql = ("WHERE " + " AND ".join(conditions)) if conditions else ""
|
||||
|
||||
@@ -302,7 +302,7 @@ class AggregateQueryService:
|
||||
)
|
||||
for (branch_name, _), payload in zip(scheduled, done):
|
||||
if isinstance(payload, Exception):
|
||||
logger.error("aggregate branch failed: branch=%s error=%s", branch_name, payload)
|
||||
logger.error(f"aggregate branch failed: branch={branch_name} error={payload}")
|
||||
normalized = self._normalize_branch_payload(
|
||||
branch_name,
|
||||
{
|
||||
|
||||
@@ -70,7 +70,7 @@ class EpisodeRetrievalService:
|
||||
temporal=temporal,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("episode evidence retrieval failed, fallback to lexical only: %s", exc)
|
||||
logger.warning(f"episode evidence retrieval failed, fallback to lexical only: {exc}")
|
||||
else:
|
||||
paragraph_rank_map: Dict[str, int] = {}
|
||||
relation_rank_map: Dict[str, int] = {}
|
||||
|
||||
304
plugins/A_memorix/core/utils/episode_segmentation_service.py
Normal file
304
plugins/A_memorix/core/utils/episode_segmentation_service.py
Normal file
@@ -0,0 +1,304 @@
|
||||
"""
|
||||
Episode 语义切分服务(LLM 主路径)。
|
||||
|
||||
职责:
|
||||
1. 组装语义切分提示词
|
||||
2. 调用 LLM 生成结构化 episode JSON
|
||||
3. 严格校验输出结构,返回标准化结果
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.model_configs import TaskConfig
|
||||
from src.config.config import model_config as host_model_config
|
||||
from src.services import llm_service as llm_api
|
||||
|
||||
logger = get_logger("A_Memorix.EpisodeSegmentationService")
|
||||
|
||||
|
||||
class EpisodeSegmentationService:
|
||||
"""基于 LLM 的 episode 语义切分服务。"""
|
||||
|
||||
SEGMENTATION_VERSION = "episode_mvp_v1"
|
||||
|
||||
def __init__(self, plugin_config: Optional[dict] = None):
|
||||
self.plugin_config = plugin_config or {}
|
||||
|
||||
def _cfg(self, key: str, default: Any = None) -> Any:
|
||||
current: Any = self.plugin_config
|
||||
for part in key.split("."):
|
||||
if isinstance(current, dict) and part in current:
|
||||
current = current[part]
|
||||
else:
|
||||
return default
|
||||
return current
|
||||
|
||||
@staticmethod
|
||||
def _is_task_config(obj: Any) -> bool:
|
||||
return hasattr(obj, "model_list") and bool(getattr(obj, "model_list", []))
|
||||
|
||||
def _build_single_model_task(self, model_name: str, template: TaskConfig) -> TaskConfig:
|
||||
return TaskConfig(
|
||||
model_list=[model_name],
|
||||
max_tokens=template.max_tokens,
|
||||
temperature=template.temperature,
|
||||
slow_threshold=template.slow_threshold,
|
||||
selection_strategy=template.selection_strategy,
|
||||
)
|
||||
|
||||
def _pick_template_task(self, available_tasks: Dict[str, Any]) -> Optional[TaskConfig]:
|
||||
preferred = ("utils", "replyer", "planner", "tool_use")
|
||||
for task_name in preferred:
|
||||
cfg = available_tasks.get(task_name)
|
||||
if self._is_task_config(cfg):
|
||||
return cfg
|
||||
for task_name, cfg in available_tasks.items():
|
||||
if task_name != "embedding" and self._is_task_config(cfg):
|
||||
return cfg
|
||||
for cfg in available_tasks.values():
|
||||
if self._is_task_config(cfg):
|
||||
return cfg
|
||||
return None
|
||||
|
||||
def _resolve_model_config(self) -> Tuple[Optional[Any], str]:
|
||||
available_tasks = llm_api.get_available_models() or {}
|
||||
if not available_tasks:
|
||||
return None, "unavailable"
|
||||
|
||||
selector = str(self._cfg("episode.segmentation_model", "auto") or "auto").strip()
|
||||
model_dict = getattr(host_model_config, "models_dict", {}) or {}
|
||||
|
||||
if selector and selector.lower() != "auto":
|
||||
direct_task = available_tasks.get(selector)
|
||||
if self._is_task_config(direct_task):
|
||||
return direct_task, selector
|
||||
|
||||
if selector in model_dict:
|
||||
template = self._pick_template_task(available_tasks)
|
||||
if template is not None:
|
||||
return self._build_single_model_task(selector, template), selector
|
||||
|
||||
logger.warning(f"episode.segmentation_model='{selector}' 不可用,回退 auto")
|
||||
|
||||
for task_name in ("utils", "replyer", "planner", "tool_use"):
|
||||
cfg = available_tasks.get(task_name)
|
||||
if self._is_task_config(cfg):
|
||||
return cfg, task_name
|
||||
|
||||
fallback = self._pick_template_task(available_tasks)
|
||||
if fallback is not None:
|
||||
return fallback, "auto"
|
||||
return None, "unavailable"
|
||||
|
||||
@staticmethod
|
||||
def _clamp_score(value: Any, default: float = 0.0) -> float:
|
||||
try:
|
||||
num = float(value)
|
||||
except Exception:
|
||||
num = default
|
||||
if num < 0.0:
|
||||
return 0.0
|
||||
if num > 1.0:
|
||||
return 1.0
|
||||
return num
|
||||
|
||||
@staticmethod
|
||||
def _safe_json_loads(text: str) -> Dict[str, Any]:
|
||||
raw = str(text or "").strip()
|
||||
if not raw:
|
||||
raise ValueError("empty_response")
|
||||
|
||||
if "```" in raw:
|
||||
raw = raw.replace("```json", "```").replace("```JSON", "```")
|
||||
parts = raw.split("```")
|
||||
for part in parts:
|
||||
part = part.strip()
|
||||
if part.startswith("{") and part.endswith("}"):
|
||||
raw = part
|
||||
break
|
||||
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
if isinstance(data, dict):
|
||||
return data
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
start = raw.find("{")
|
||||
end = raw.rfind("}")
|
||||
if start >= 0 and end > start:
|
||||
candidate = raw[start : end + 1]
|
||||
data = json.loads(candidate)
|
||||
if isinstance(data, dict):
|
||||
return data
|
||||
|
||||
raise ValueError("invalid_json_response")
|
||||
|
||||
def _build_prompt(
|
||||
self,
|
||||
*,
|
||||
source: str,
|
||||
window_start: Optional[float],
|
||||
window_end: Optional[float],
|
||||
paragraphs: List[Dict[str, Any]],
|
||||
) -> str:
|
||||
rows: List[str] = []
|
||||
for idx, item in enumerate(paragraphs, 1):
|
||||
p_hash = str(item.get("hash", "") or "").strip()
|
||||
content = str(item.get("content", "") or "").strip().replace("\r\n", "\n")
|
||||
content = content[:800]
|
||||
event_start = item.get("event_time_start")
|
||||
event_end = item.get("event_time_end")
|
||||
event_time = item.get("event_time")
|
||||
rows.append(
|
||||
(
|
||||
f"[{idx}] hash={p_hash}\n"
|
||||
f"event_time={event_time}\n"
|
||||
f"event_time_start={event_start}\n"
|
||||
f"event_time_end={event_end}\n"
|
||||
f"content={content}"
|
||||
)
|
||||
)
|
||||
|
||||
source_text = str(source or "").strip() or "unknown"
|
||||
return (
|
||||
"You are an episode segmentation engine.\n"
|
||||
"Group the given paragraphs into one or more coherent episodes.\n"
|
||||
"Return JSON ONLY. No markdown, no explanation.\n"
|
||||
"\n"
|
||||
"Hard JSON schema:\n"
|
||||
"{\n"
|
||||
' "episodes": [\n'
|
||||
" {\n"
|
||||
' "title": "string",\n'
|
||||
' "summary": "string",\n'
|
||||
' "paragraph_hashes": ["hash1", "hash2"],\n'
|
||||
' "participants": ["person1", "person2"],\n'
|
||||
' "keywords": ["kw1", "kw2"],\n'
|
||||
' "time_confidence": 0.0,\n'
|
||||
' "llm_confidence": 0.0\n'
|
||||
" }\n"
|
||||
" ]\n"
|
||||
"}\n"
|
||||
"\n"
|
||||
"Rules:\n"
|
||||
"1) paragraph_hashes must come from input only.\n"
|
||||
"2) title and summary must be non-empty.\n"
|
||||
"3) keep participants/keywords concise and deduplicated.\n"
|
||||
"4) if uncertain, still provide best effort confidence values.\n"
|
||||
"\n"
|
||||
f"source={source_text}\n"
|
||||
f"window_start={window_start}\n"
|
||||
f"window_end={window_end}\n"
|
||||
"paragraphs:\n"
|
||||
+ "\n\n".join(rows)
|
||||
)
|
||||
|
||||
def _normalize_episodes(
|
||||
self,
|
||||
*,
|
||||
payload: Dict[str, Any],
|
||||
input_hashes: List[str],
|
||||
) -> List[Dict[str, Any]]:
|
||||
raw_episodes = payload.get("episodes")
|
||||
if not isinstance(raw_episodes, list):
|
||||
raise ValueError("episodes_missing_or_not_list")
|
||||
|
||||
valid_hashes = set(input_hashes)
|
||||
normalized: List[Dict[str, Any]] = []
|
||||
for item in raw_episodes:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
|
||||
title = str(item.get("title", "") or "").strip()
|
||||
summary = str(item.get("summary", "") or "").strip()
|
||||
if not title or not summary:
|
||||
continue
|
||||
|
||||
raw_hashes = item.get("paragraph_hashes")
|
||||
if not isinstance(raw_hashes, list):
|
||||
continue
|
||||
|
||||
dedup_hashes: List[str] = []
|
||||
seen_hashes = set()
|
||||
for h in raw_hashes:
|
||||
token = str(h or "").strip()
|
||||
if not token or token in seen_hashes or token not in valid_hashes:
|
||||
continue
|
||||
seen_hashes.add(token)
|
||||
dedup_hashes.append(token)
|
||||
|
||||
if not dedup_hashes:
|
||||
continue
|
||||
|
||||
participants = []
|
||||
for p in item.get("participants", []) or []:
|
||||
token = str(p or "").strip()
|
||||
if token:
|
||||
participants.append(token)
|
||||
|
||||
keywords = []
|
||||
for kw in item.get("keywords", []) or []:
|
||||
token = str(kw or "").strip()
|
||||
if token:
|
||||
keywords.append(token)
|
||||
|
||||
normalized.append(
|
||||
{
|
||||
"title": title,
|
||||
"summary": summary,
|
||||
"paragraph_hashes": dedup_hashes,
|
||||
"participants": participants[:16],
|
||||
"keywords": keywords[:20],
|
||||
"time_confidence": self._clamp_score(item.get("time_confidence"), default=1.0),
|
||||
"llm_confidence": self._clamp_score(item.get("llm_confidence"), default=0.5),
|
||||
}
|
||||
)
|
||||
|
||||
if not normalized:
|
||||
raise ValueError("episodes_all_invalid")
|
||||
return normalized
|
||||
|
||||
async def segment(
|
||||
self,
|
||||
*,
|
||||
source: str,
|
||||
window_start: Optional[float],
|
||||
window_end: Optional[float],
|
||||
paragraphs: List[Dict[str, Any]],
|
||||
) -> Dict[str, Any]:
|
||||
if not paragraphs:
|
||||
raise ValueError("paragraphs_empty")
|
||||
|
||||
model_config, model_label = self._resolve_model_config()
|
||||
if model_config is None:
|
||||
raise RuntimeError("episode segmentation model unavailable")
|
||||
|
||||
prompt = self._build_prompt(
|
||||
source=source,
|
||||
window_start=window_start,
|
||||
window_end=window_end,
|
||||
paragraphs=paragraphs,
|
||||
)
|
||||
success, response, _, _ = await llm_api.generate_with_model(
|
||||
prompt=prompt,
|
||||
model_config=model_config,
|
||||
request_type="A_Memorix.EpisodeSegmentation",
|
||||
)
|
||||
if not success or not response:
|
||||
raise RuntimeError("llm_generate_failed")
|
||||
|
||||
payload = self._safe_json_loads(str(response))
|
||||
input_hashes = [str(p.get("hash", "") or "").strip() for p in paragraphs]
|
||||
episodes = self._normalize_episodes(payload=payload, input_hashes=input_hashes)
|
||||
|
||||
return {
|
||||
"episodes": episodes,
|
||||
"segmentation_model": model_label,
|
||||
"segmentation_version": self.SEGMENTATION_VERSION,
|
||||
}
|
||||
|
||||
558
plugins/A_memorix/core/utils/episode_service.py
Normal file
558
plugins/A_memorix/core/utils/episode_service.py
Normal file
@@ -0,0 +1,558 @@
|
||||
"""
|
||||
Episode 聚合与落库服务。
|
||||
|
||||
流程:
|
||||
1. 从 pending 队列读取段落并组批
|
||||
2. 按 source + 时间窗口切组
|
||||
3. 调用 LLM 语义切分
|
||||
4. 写入 episodes + episode_paragraphs
|
||||
5. LLM 失败时使用确定性 fallback
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from collections import Counter
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from .episode_segmentation_service import EpisodeSegmentationService
|
||||
from .hash import compute_hash
|
||||
|
||||
logger = get_logger("A_Memorix.EpisodeService")
|
||||
|
||||
|
||||
class EpisodeService:
|
||||
"""Episode MVP 后台处理服务。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
metadata_store: Any,
|
||||
plugin_config: Optional[Any] = None,
|
||||
segmentation_service: Optional[EpisodeSegmentationService] = None,
|
||||
):
|
||||
self.metadata_store = metadata_store
|
||||
self.plugin_config = plugin_config or {}
|
||||
self.segmentation_service = segmentation_service or EpisodeSegmentationService(
|
||||
plugin_config=self._config_dict(),
|
||||
)
|
||||
|
||||
def _config_dict(self) -> Dict[str, Any]:
|
||||
if isinstance(self.plugin_config, dict):
|
||||
return self.plugin_config
|
||||
return {}
|
||||
|
||||
def _cfg(self, key: str, default: Any = None) -> Any:
|
||||
getter = getattr(self.plugin_config, "get_config", None)
|
||||
if callable(getter):
|
||||
return getter(key, default)
|
||||
|
||||
current: Any = self.plugin_config
|
||||
for part in key.split("."):
|
||||
if isinstance(current, dict) and part in current:
|
||||
current = current[part]
|
||||
else:
|
||||
return default
|
||||
return current
|
||||
|
||||
@staticmethod
|
||||
def _to_optional_float(value: Any) -> Optional[float]:
|
||||
if value is None:
|
||||
return None
|
||||
try:
|
||||
return float(value)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _clamp_score(value: Any, default: float = 1.0) -> float:
|
||||
try:
|
||||
num = float(value)
|
||||
except Exception:
|
||||
num = default
|
||||
if num < 0.0:
|
||||
return 0.0
|
||||
if num > 1.0:
|
||||
return 1.0
|
||||
return num
|
||||
|
||||
@staticmethod
|
||||
def _paragraph_anchor(paragraph: Dict[str, Any]) -> float:
|
||||
for key in ("event_time_end", "event_time_start", "event_time", "created_at"):
|
||||
value = paragraph.get(key)
|
||||
try:
|
||||
if value is not None:
|
||||
return float(value)
|
||||
except Exception:
|
||||
continue
|
||||
return 0.0
|
||||
|
||||
@staticmethod
|
||||
def _paragraph_sort_key(paragraph: Dict[str, Any]) -> Tuple[float, str]:
|
||||
return (
|
||||
EpisodeService._paragraph_anchor(paragraph),
|
||||
str(paragraph.get("hash", "") or ""),
|
||||
)
|
||||
|
||||
def load_pending_paragraphs(
|
||||
self,
|
||||
pending_rows: List[Dict[str, Any]],
|
||||
) -> Tuple[List[Dict[str, Any]], List[str]]:
|
||||
"""
|
||||
将 pending 行展开为段落上下文。
|
||||
|
||||
Returns:
|
||||
(loaded_paragraphs, missing_hashes)
|
||||
"""
|
||||
loaded: List[Dict[str, Any]] = []
|
||||
missing: List[str] = []
|
||||
for row in pending_rows or []:
|
||||
p_hash = str(row.get("paragraph_hash", "") or "").strip()
|
||||
if not p_hash:
|
||||
continue
|
||||
|
||||
paragraph = self.metadata_store.get_paragraph(p_hash)
|
||||
if not paragraph:
|
||||
missing.append(p_hash)
|
||||
continue
|
||||
|
||||
loaded.append(
|
||||
{
|
||||
"hash": p_hash,
|
||||
"source": str(row.get("source") or paragraph.get("source") or "").strip(),
|
||||
"content": str(paragraph.get("content", "") or ""),
|
||||
"created_at": self._to_optional_float(paragraph.get("created_at"))
|
||||
or self._to_optional_float(row.get("created_at"))
|
||||
or 0.0,
|
||||
"event_time": self._to_optional_float(paragraph.get("event_time")),
|
||||
"event_time_start": self._to_optional_float(paragraph.get("event_time_start")),
|
||||
"event_time_end": self._to_optional_float(paragraph.get("event_time_end")),
|
||||
"time_granularity": str(paragraph.get("time_granularity", "") or "").strip() or None,
|
||||
"time_confidence": self._clamp_score(paragraph.get("time_confidence"), default=1.0),
|
||||
}
|
||||
)
|
||||
return loaded, missing
|
||||
|
||||
def group_paragraphs(self, paragraphs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
按 source + 时间邻近窗口组批,并受段落数/字符数上限约束。
|
||||
"""
|
||||
if not paragraphs:
|
||||
return []
|
||||
|
||||
max_paragraphs = max(1, int(self._cfg("episode.max_paragraphs_per_call", 20)))
|
||||
max_chars = max(200, int(self._cfg("episode.max_chars_per_call", 6000)))
|
||||
window_seconds = max(
|
||||
60.0,
|
||||
float(self._cfg("episode.source_time_window_hours", 24)) * 3600.0,
|
||||
)
|
||||
|
||||
by_source: Dict[str, List[Dict[str, Any]]] = {}
|
||||
for paragraph in paragraphs:
|
||||
source = str(paragraph.get("source", "") or "").strip()
|
||||
by_source.setdefault(source, []).append(paragraph)
|
||||
|
||||
groups: List[Dict[str, Any]] = []
|
||||
for source, items in by_source.items():
|
||||
ordered = sorted(items, key=self._paragraph_sort_key)
|
||||
|
||||
current: List[Dict[str, Any]] = []
|
||||
current_chars = 0
|
||||
last_anchor: Optional[float] = None
|
||||
|
||||
def flush() -> None:
|
||||
nonlocal current, current_chars, last_anchor
|
||||
if not current:
|
||||
return
|
||||
sorted_current = sorted(current, key=self._paragraph_sort_key)
|
||||
groups.append(
|
||||
{
|
||||
"source": source,
|
||||
"paragraphs": sorted_current,
|
||||
}
|
||||
)
|
||||
current = []
|
||||
current_chars = 0
|
||||
last_anchor = None
|
||||
|
||||
for paragraph in ordered:
|
||||
anchor = self._paragraph_anchor(paragraph)
|
||||
content_len = len(str(paragraph.get("content", "") or ""))
|
||||
|
||||
need_flush = False
|
||||
if current:
|
||||
if len(current) >= max_paragraphs:
|
||||
need_flush = True
|
||||
elif current_chars + content_len > max_chars:
|
||||
need_flush = True
|
||||
elif last_anchor is not None and abs(anchor - last_anchor) > window_seconds:
|
||||
need_flush = True
|
||||
|
||||
if need_flush:
|
||||
flush()
|
||||
|
||||
current.append(paragraph)
|
||||
current_chars += content_len
|
||||
last_anchor = anchor
|
||||
|
||||
flush()
|
||||
|
||||
groups.sort(
|
||||
key=lambda g: self._paragraph_anchor(g["paragraphs"][0]) if g.get("paragraphs") else 0.0
|
||||
)
|
||||
return groups
|
||||
|
||||
def _compute_time_meta(self, paragraphs: List[Dict[str, Any]]) -> Tuple[Optional[float], Optional[float], Optional[str], float]:
|
||||
starts: List[float] = []
|
||||
ends: List[float] = []
|
||||
granularity_priority = {
|
||||
"minute": 4,
|
||||
"hour": 3,
|
||||
"day": 2,
|
||||
"month": 1,
|
||||
"year": 0,
|
||||
}
|
||||
granularity = None
|
||||
granularity_rank = -1
|
||||
conf_values: List[float] = []
|
||||
|
||||
for p in paragraphs:
|
||||
s = self._to_optional_float(p.get("event_time_start"))
|
||||
e = self._to_optional_float(p.get("event_time_end"))
|
||||
t = self._to_optional_float(p.get("event_time"))
|
||||
c = self._to_optional_float(p.get("created_at"))
|
||||
|
||||
start_candidate = s if s is not None else (t if t is not None else (e if e is not None else c))
|
||||
end_candidate = e if e is not None else (t if t is not None else (s if s is not None else c))
|
||||
|
||||
if start_candidate is not None:
|
||||
starts.append(start_candidate)
|
||||
if end_candidate is not None:
|
||||
ends.append(end_candidate)
|
||||
|
||||
g = str(p.get("time_granularity", "") or "").strip().lower()
|
||||
if g in granularity_priority and granularity_priority[g] > granularity_rank:
|
||||
granularity_rank = granularity_priority[g]
|
||||
granularity = g
|
||||
|
||||
conf_values.append(self._clamp_score(p.get("time_confidence"), default=1.0))
|
||||
|
||||
time_start = min(starts) if starts else None
|
||||
time_end = max(ends) if ends else None
|
||||
time_conf = sum(conf_values) / len(conf_values) if conf_values else 1.0
|
||||
return time_start, time_end, granularity, self._clamp_score(time_conf, default=1.0)
|
||||
|
||||
def _collect_participants(self, paragraph_hashes: List[str], limit: int = 16) -> List[str]:
|
||||
seen = set()
|
||||
participants: List[str] = []
|
||||
for p_hash in paragraph_hashes:
|
||||
try:
|
||||
entities = self.metadata_store.get_paragraph_entities(p_hash)
|
||||
except Exception:
|
||||
entities = []
|
||||
for item in entities:
|
||||
name = str(item.get("name", "") or "").strip()
|
||||
if not name:
|
||||
continue
|
||||
key = name.lower()
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
participants.append(name)
|
||||
if len(participants) >= limit:
|
||||
return participants
|
||||
return participants
|
||||
|
||||
@staticmethod
|
||||
def _derive_keywords(paragraphs: List[Dict[str, Any]], limit: int = 12) -> List[str]:
|
||||
token_counter: Counter[str] = Counter()
|
||||
token_pattern = re.compile(r"[A-Za-z0-9_\u4e00-\u9fff]{2,}")
|
||||
stop_words = {
|
||||
"the",
|
||||
"and",
|
||||
"that",
|
||||
"this",
|
||||
"with",
|
||||
"from",
|
||||
"for",
|
||||
"have",
|
||||
"will",
|
||||
"your",
|
||||
"you",
|
||||
"我们",
|
||||
"你们",
|
||||
"他们",
|
||||
"以及",
|
||||
"一个",
|
||||
"这个",
|
||||
"那个",
|
||||
"然后",
|
||||
"因为",
|
||||
"所以",
|
||||
}
|
||||
for p in paragraphs:
|
||||
text = str(p.get("content", "") or "").lower()
|
||||
for token in token_pattern.findall(text):
|
||||
if token in stop_words:
|
||||
continue
|
||||
token_counter[token] += 1
|
||||
|
||||
return [token for token, _ in token_counter.most_common(limit)]
|
||||
|
||||
def _build_fallback_episode(self, group: Dict[str, Any]) -> Dict[str, Any]:
|
||||
paragraphs = group.get("paragraphs", []) or []
|
||||
source = str(group.get("source", "") or "").strip()
|
||||
hashes = [str(p.get("hash", "") or "").strip() for p in paragraphs if str(p.get("hash", "") or "").strip()]
|
||||
snippets = []
|
||||
for p in paragraphs[:3]:
|
||||
text = str(p.get("content", "") or "").strip().replace("\n", " ")
|
||||
if text:
|
||||
snippets.append(text[:140])
|
||||
summary = ";".join(snippets)[:500] if snippets else "自动回退生成的情景记忆。"
|
||||
|
||||
time_start, time_end, granularity, time_conf = self._compute_time_meta(paragraphs)
|
||||
participants = self._collect_participants(hashes, limit=12)
|
||||
keywords = self._derive_keywords(paragraphs, limit=10)
|
||||
|
||||
if time_start is not None:
|
||||
day_text = datetime.fromtimestamp(time_start).strftime("%Y-%m-%d")
|
||||
title = f"{source or 'unknown'} {day_text} 情景片段"
|
||||
else:
|
||||
title = f"{source or 'unknown'} 情景片段"
|
||||
|
||||
return {
|
||||
"title": title[:80],
|
||||
"summary": summary,
|
||||
"paragraph_hashes": hashes,
|
||||
"participants": participants,
|
||||
"keywords": keywords,
|
||||
"time_confidence": time_conf,
|
||||
"llm_confidence": 0.0,
|
||||
"event_time_start": time_start,
|
||||
"event_time_end": time_end,
|
||||
"time_granularity": granularity,
|
||||
"segmentation_model": "fallback_rule",
|
||||
"segmentation_version": EpisodeSegmentationService.SEGMENTATION_VERSION,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _normalize_episode_hashes(episode_hashes: List[str], group_hashes_ordered: List[str]) -> List[str]:
|
||||
in_group = set(group_hashes_ordered)
|
||||
dedup: List[str] = []
|
||||
seen = set()
|
||||
for h in episode_hashes or []:
|
||||
token = str(h or "").strip()
|
||||
if not token or token not in in_group or token in seen:
|
||||
continue
|
||||
seen.add(token)
|
||||
dedup.append(token)
|
||||
return dedup
|
||||
|
||||
async def _build_episode_payloads_for_group(self, group: Dict[str, Any]) -> Dict[str, Any]:
|
||||
paragraphs = group.get("paragraphs", []) or []
|
||||
if not paragraphs:
|
||||
return {
|
||||
"payloads": [],
|
||||
"done_hashes": [],
|
||||
"episode_count": 0,
|
||||
"fallback_count": 0,
|
||||
}
|
||||
|
||||
source = str(group.get("source", "") or "").strip()
|
||||
group_hashes = [str(p.get("hash", "") or "").strip() for p in paragraphs if str(p.get("hash", "") or "").strip()]
|
||||
group_start, group_end, _, _ = self._compute_time_meta(paragraphs)
|
||||
|
||||
fallback_used = False
|
||||
segmentation_model = "fallback_rule"
|
||||
segmentation_version = EpisodeSegmentationService.SEGMENTATION_VERSION
|
||||
|
||||
try:
|
||||
llm_result = await self.segmentation_service.segment(
|
||||
source=source,
|
||||
window_start=group_start,
|
||||
window_end=group_end,
|
||||
paragraphs=paragraphs,
|
||||
)
|
||||
episodes = list(llm_result.get("episodes") or [])
|
||||
segmentation_model = str(llm_result.get("segmentation_model", "") or "").strip() or "auto"
|
||||
segmentation_version = str(llm_result.get("segmentation_version", "") or "").strip() or EpisodeSegmentationService.SEGMENTATION_VERSION
|
||||
if not episodes:
|
||||
raise ValueError("llm_empty_episodes")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Episode segmentation fallback: "
|
||||
f"source={source} "
|
||||
f"size={len(group_hashes)} "
|
||||
f"err={e}"
|
||||
)
|
||||
episodes = [self._build_fallback_episode(group)]
|
||||
fallback_used = True
|
||||
|
||||
stored_payloads: List[Dict[str, Any]] = []
|
||||
for episode in episodes:
|
||||
ordered_hashes = self._normalize_episode_hashes(
|
||||
episode_hashes=episode.get("paragraph_hashes", []),
|
||||
group_hashes_ordered=group_hashes,
|
||||
)
|
||||
if not ordered_hashes:
|
||||
continue
|
||||
|
||||
sub_paragraphs = [p for p in paragraphs if str(p.get("hash", "") or "") in set(ordered_hashes)]
|
||||
event_start, event_end, granularity, time_conf_default = self._compute_time_meta(sub_paragraphs)
|
||||
|
||||
participants = [str(x).strip() for x in (episode.get("participants", []) or []) if str(x).strip()]
|
||||
keywords = [str(x).strip() for x in (episode.get("keywords", []) or []) if str(x).strip()]
|
||||
if not participants:
|
||||
participants = self._collect_participants(ordered_hashes, limit=16)
|
||||
if not keywords:
|
||||
keywords = self._derive_keywords(sub_paragraphs, limit=12)
|
||||
|
||||
title = str(episode.get("title", "") or "").strip()[:120]
|
||||
summary = str(episode.get("summary", "") or "").strip()[:2000]
|
||||
if not title or not summary:
|
||||
continue
|
||||
|
||||
seed = json.dumps(
|
||||
{
|
||||
"source": source,
|
||||
"hashes": ordered_hashes,
|
||||
"version": segmentation_version,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
sort_keys=True,
|
||||
)
|
||||
episode_id = compute_hash(seed)
|
||||
|
||||
payload = {
|
||||
"episode_id": episode_id,
|
||||
"source": source or None,
|
||||
"title": title,
|
||||
"summary": summary,
|
||||
"event_time_start": episode.get("event_time_start", event_start),
|
||||
"event_time_end": episode.get("event_time_end", event_end),
|
||||
"time_granularity": episode.get("time_granularity", granularity),
|
||||
"time_confidence": self._clamp_score(
|
||||
episode.get("time_confidence"),
|
||||
default=time_conf_default,
|
||||
),
|
||||
"participants": participants[:16],
|
||||
"keywords": keywords[:20],
|
||||
"evidence_ids": ordered_hashes,
|
||||
"paragraph_count": len(ordered_hashes),
|
||||
"llm_confidence": self._clamp_score(
|
||||
episode.get("llm_confidence"),
|
||||
default=0.0 if fallback_used else 0.6,
|
||||
),
|
||||
"segmentation_model": (
|
||||
str(episode.get("segmentation_model", "") or "").strip()
|
||||
or ("fallback_rule" if fallback_used else segmentation_model)
|
||||
),
|
||||
"segmentation_version": (
|
||||
str(episode.get("segmentation_version", "") or "").strip()
|
||||
or segmentation_version
|
||||
),
|
||||
}
|
||||
stored_payloads.append(payload)
|
||||
|
||||
return {
|
||||
"payloads": stored_payloads,
|
||||
"done_hashes": group_hashes,
|
||||
"episode_count": len(stored_payloads),
|
||||
"fallback_count": 1 if fallback_used else 0,
|
||||
}
|
||||
|
||||
async def process_group(self, group: Dict[str, Any]) -> Dict[str, Any]:
|
||||
result = await self._build_episode_payloads_for_group(group)
|
||||
stored_count = 0
|
||||
for payload in result.get("payloads") or []:
|
||||
stored = self.metadata_store.upsert_episode(payload)
|
||||
final_id = str(stored.get("episode_id") or payload.get("episode_id") or "")
|
||||
if final_id:
|
||||
self.metadata_store.bind_episode_paragraphs(
|
||||
final_id,
|
||||
list(payload.get("evidence_ids") or []),
|
||||
)
|
||||
stored_count += 1
|
||||
|
||||
result["episode_count"] = stored_count
|
||||
return {
|
||||
"done_hashes": list(result.get("done_hashes") or []),
|
||||
"episode_count": stored_count,
|
||||
"fallback_count": int(result.get("fallback_count") or 0),
|
||||
}
|
||||
|
||||
async def process_pending_rows(self, pending_rows: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
loaded, missing_hashes = self.load_pending_paragraphs(pending_rows)
|
||||
groups = self.group_paragraphs(loaded)
|
||||
|
||||
done_hashes: List[str] = list(missing_hashes)
|
||||
failed_hashes: Dict[str, str] = {}
|
||||
episode_count = 0
|
||||
fallback_count = 0
|
||||
|
||||
for group in groups:
|
||||
group_hashes = [str(p.get("hash", "") or "").strip() for p in (group.get("paragraphs") or [])]
|
||||
try:
|
||||
result = await self.process_group(group)
|
||||
done_hashes.extend(result.get("done_hashes") or [])
|
||||
episode_count += int(result.get("episode_count") or 0)
|
||||
fallback_count += int(result.get("fallback_count") or 0)
|
||||
except Exception as e:
|
||||
err = str(e)[:500]
|
||||
for h in group_hashes:
|
||||
if h:
|
||||
failed_hashes[h] = err
|
||||
|
||||
dedup_done = list(dict.fromkeys([h for h in done_hashes if h]))
|
||||
return {
|
||||
"done_hashes": dedup_done,
|
||||
"failed_hashes": failed_hashes,
|
||||
"episode_count": episode_count,
|
||||
"fallback_count": fallback_count,
|
||||
"missing_count": len(missing_hashes),
|
||||
"group_count": len(groups),
|
||||
}
|
||||
|
||||
async def rebuild_source(self, source: str) -> Dict[str, Any]:
|
||||
token = str(source or "").strip()
|
||||
if not token:
|
||||
return {
|
||||
"source": "",
|
||||
"episode_count": 0,
|
||||
"fallback_count": 0,
|
||||
"group_count": 0,
|
||||
"paragraph_count": 0,
|
||||
}
|
||||
|
||||
paragraphs = self.metadata_store.get_live_paragraphs_by_source(token)
|
||||
if not paragraphs:
|
||||
replace_result = self.metadata_store.replace_episodes_for_source(token, [])
|
||||
return {
|
||||
"source": token,
|
||||
"episode_count": int(replace_result.get("episode_count") or 0),
|
||||
"fallback_count": 0,
|
||||
"group_count": 0,
|
||||
"paragraph_count": 0,
|
||||
}
|
||||
|
||||
groups = self.group_paragraphs(paragraphs)
|
||||
payloads: List[Dict[str, Any]] = []
|
||||
fallback_count = 0
|
||||
|
||||
for group in groups:
|
||||
result = await self._build_episode_payloads_for_group(group)
|
||||
payloads.extend(list(result.get("payloads") or []))
|
||||
fallback_count += int(result.get("fallback_count") or 0)
|
||||
|
||||
replace_result = self.metadata_store.replace_episodes_for_source(token, payloads)
|
||||
return {
|
||||
"source": token,
|
||||
"episode_count": int(replace_result.get("episode_count") or 0),
|
||||
"fallback_count": fallback_count,
|
||||
"group_count": len(groups),
|
||||
"paragraph_count": len(paragraphs),
|
||||
}
|
||||
@@ -9,7 +9,11 @@ import json
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from sqlalchemy import or_
|
||||
from sqlmodel import select
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import PersonInfo
|
||||
|
||||
from ..embedding import EmbeddingAPIAdapter
|
||||
@@ -120,31 +124,40 @@ class PersonProfileService:
|
||||
if not key:
|
||||
return ""
|
||||
|
||||
try:
|
||||
with get_db_session(auto_commit=False) as session:
|
||||
record = session.exec(
|
||||
select(PersonInfo.person_id).where(PersonInfo.person_id == key).limit(1)
|
||||
).first()
|
||||
if record:
|
||||
return str(record)
|
||||
|
||||
record = session.exec(
|
||||
select(PersonInfo.person_id)
|
||||
.where(
|
||||
or_(
|
||||
PersonInfo.person_name == key,
|
||||
PersonInfo.user_nickname == key,
|
||||
)
|
||||
)
|
||||
.limit(1)
|
||||
).first()
|
||||
if record:
|
||||
return str(record)
|
||||
|
||||
record = session.exec(
|
||||
select(PersonInfo.person_id)
|
||||
.where(PersonInfo.group_cardname.contains(key))
|
||||
.limit(1)
|
||||
).first()
|
||||
if record:
|
||||
return str(record)
|
||||
except Exception as e:
|
||||
logger.warning(f"按别名解析 person_id 失败: identifier={key}, err={e}")
|
||||
|
||||
if len(key) == 32 and all(ch in "0123456789abcdefABCDEF" for ch in key):
|
||||
return key.lower()
|
||||
|
||||
try:
|
||||
record = (
|
||||
PersonInfo.select(PersonInfo.person_id)
|
||||
.where((PersonInfo.person_name == key) | (PersonInfo.nickname == key))
|
||||
.first()
|
||||
)
|
||||
if record and record.person_id:
|
||||
return str(record.person_id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
record = (
|
||||
PersonInfo.select(PersonInfo.person_id)
|
||||
.where(PersonInfo.group_nick_name.contains(key))
|
||||
.first()
|
||||
)
|
||||
if record and record.person_id:
|
||||
return str(record.person_id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return ""
|
||||
|
||||
def _parse_group_nicks(self, raw_value: Any) -> List[str]:
|
||||
@@ -160,7 +173,7 @@ class PersonProfileService:
|
||||
names: List[str] = []
|
||||
for item in items:
|
||||
if isinstance(item, dict):
|
||||
value = str(item.get("group_nick_name", "")).strip()
|
||||
value = str(item.get("group_cardname") or item.get("group_nick_name") or "").strip()
|
||||
if value:
|
||||
names.append(value)
|
||||
elif isinstance(item, str):
|
||||
@@ -193,6 +206,42 @@ class PersonProfileService:
|
||||
traits.append(text)
|
||||
return traits[:10]
|
||||
|
||||
def _recover_aliases_from_memory(self, person_id: str) -> Tuple[List[str], str]:
|
||||
"""当人物主档案缺失时,从已有记忆证据里回捞可用别名。"""
|
||||
if not person_id:
|
||||
return [], ""
|
||||
|
||||
aliases: List[str] = []
|
||||
primary_name = ""
|
||||
seen = set()
|
||||
|
||||
try:
|
||||
paragraphs = self.metadata_store.get_paragraphs_by_entity(person_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"从记忆证据回捞人物别名失败: person_id={person_id}, err={e}")
|
||||
return [], ""
|
||||
|
||||
for paragraph in paragraphs[:20]:
|
||||
paragraph_hash = str(paragraph.get("hash", "") or "").strip()
|
||||
if not paragraph_hash:
|
||||
continue
|
||||
try:
|
||||
paragraph_entities = self.metadata_store.get_paragraph_entities(paragraph_hash)
|
||||
except Exception:
|
||||
paragraph_entities = []
|
||||
for entity in paragraph_entities:
|
||||
name = str(entity.get("name", "") or "").strip()
|
||||
if not name or name == person_id:
|
||||
continue
|
||||
key = name.lower()
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
aliases.append(name)
|
||||
if not primary_name:
|
||||
primary_name = name
|
||||
return aliases, primary_name
|
||||
|
||||
def get_person_aliases(self, person_id: str) -> Tuple[List[str], str, List[str]]:
|
||||
"""获取人物别名集合、主展示名、记忆特征。"""
|
||||
aliases: List[str] = []
|
||||
@@ -200,18 +249,28 @@ class PersonProfileService:
|
||||
memory_traits: List[str] = []
|
||||
if not person_id:
|
||||
return aliases, primary_name, memory_traits
|
||||
recovered_aliases, recovered_primary_name = self._recover_aliases_from_memory(person_id)
|
||||
try:
|
||||
record = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
|
||||
if not record:
|
||||
return aliases, primary_name, memory_traits
|
||||
with get_db_session(auto_commit=False) as session:
|
||||
record = session.exec(
|
||||
select(PersonInfo).where(PersonInfo.person_id == person_id).limit(1)
|
||||
).first()
|
||||
if not record:
|
||||
return recovered_aliases, recovered_primary_name or person_id, memory_traits
|
||||
person_name = str(getattr(record, "person_name", "") or "").strip()
|
||||
nickname = str(getattr(record, "nickname", "") or "").strip()
|
||||
group_nicks = self._parse_group_nicks(getattr(record, "group_nick_name", None))
|
||||
nickname = str(getattr(record, "user_nickname", "") or "").strip()
|
||||
group_nicks = self._parse_group_nicks(getattr(record, "group_cardname", None))
|
||||
memory_traits = self._parse_memory_traits(getattr(record, "memory_points", None))
|
||||
|
||||
primary_name = person_name or nickname or str(getattr(record, "user_id", "") or "").strip() or person_id
|
||||
primary_name = (
|
||||
person_name
|
||||
or nickname
|
||||
or recovered_primary_name
|
||||
or str(getattr(record, "user_id", "") or "").strip()
|
||||
or person_id
|
||||
)
|
||||
|
||||
candidates = [person_name, nickname] + group_nicks
|
||||
candidates = [person_name, nickname] + group_nicks + recovered_aliases
|
||||
seen = set()
|
||||
for item in candidates:
|
||||
norm = str(item or "").strip()
|
||||
|
||||
@@ -82,8 +82,9 @@ class RelationWriteService:
|
||||
)
|
||||
self.metadata_store.set_relation_vector_state(hash_value, "ready")
|
||||
logger.info(
|
||||
"metric.relation_vector_write_success=1 metric.relation_vector_write_success_count=1 hash=%s",
|
||||
hash_value[:16],
|
||||
"metric.relation_vector_write_success=1 "
|
||||
"metric.relation_vector_write_success_count=1 "
|
||||
f"hash={hash_value[:16]}"
|
||||
)
|
||||
return RelationWriteResult(
|
||||
hash_value=hash_value,
|
||||
@@ -109,9 +110,10 @@ class RelationWriteService:
|
||||
bump_retry=True,
|
||||
)
|
||||
logger.warning(
|
||||
"metric.relation_vector_write_fail=1 metric.relation_vector_write_fail_count=1 hash=%s err=%s",
|
||||
hash_value[:16],
|
||||
err,
|
||||
"metric.relation_vector_write_fail=1 "
|
||||
"metric.relation_vector_write_fail_count=1 "
|
||||
f"hash={hash_value[:16]} "
|
||||
f"err={err}"
|
||||
)
|
||||
return RelationWriteResult(
|
||||
hash_value=hash_value,
|
||||
|
||||
1857
plugins/A_memorix/core/utils/retrieval_tuning_manager.py
Normal file
1857
plugins/A_memorix/core/utils/retrieval_tuning_manager.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -61,6 +61,29 @@ def _build_report(
|
||||
}
|
||||
|
||||
|
||||
def _normalize_encoded_vector(encoded: Any) -> np.ndarray:
|
||||
if encoded is None:
|
||||
raise ValueError("embedding encode returned None")
|
||||
|
||||
if isinstance(encoded, np.ndarray):
|
||||
array = encoded
|
||||
else:
|
||||
array = np.asarray(encoded, dtype=np.float32)
|
||||
|
||||
if array.ndim == 2:
|
||||
if array.shape[0] != 1:
|
||||
raise ValueError(f"embedding encode returned batched output: shape={tuple(array.shape)}")
|
||||
array = array[0]
|
||||
|
||||
if array.ndim != 1:
|
||||
raise ValueError(f"embedding encode returned invalid ndim={array.ndim}")
|
||||
if array.size <= 0:
|
||||
raise ValueError("embedding encode returned empty vector")
|
||||
if not np.all(np.isfinite(array)):
|
||||
raise ValueError("embedding encode returned non-finite values")
|
||||
return array.astype(np.float32, copy=False)
|
||||
|
||||
|
||||
async def run_embedding_runtime_self_check(
|
||||
*,
|
||||
config: Any,
|
||||
@@ -91,13 +114,11 @@ async def run_embedding_runtime_self_check(
|
||||
try:
|
||||
detected_dimension = _safe_int(await embedding_manager._detect_dimension(), 0)
|
||||
encoded = await embedding_manager.encode(sample_text)
|
||||
if isinstance(encoded, np.ndarray):
|
||||
encoded_dimension = int(encoded.shape[0]) if encoded.ndim == 1 else int(encoded.shape[-1])
|
||||
else:
|
||||
encoded_dimension = len(encoded) if encoded is not None else 0
|
||||
encoded_array = _normalize_encoded_vector(encoded)
|
||||
encoded_dimension = int(encoded_array.shape[0])
|
||||
except Exception as exc:
|
||||
elapsed_ms = (time.perf_counter() - start) * 1000.0
|
||||
logger.warning("embedding runtime self-check failed: %s", exc)
|
||||
logger.warning(f"embedding runtime self-check failed: {exc}")
|
||||
return _build_report(
|
||||
ok=False,
|
||||
code="embedding_probe_failed",
|
||||
|
||||
442
plugins/A_memorix/core/utils/search_execution_service.py
Normal file
442
plugins/A_memorix/core/utils/search_execution_service.py
Normal file
@@ -0,0 +1,442 @@
|
||||
"""
|
||||
统一检索执行服务。
|
||||
|
||||
用于收敛 Action/Tool 在 search/time 上的核心执行流程,避免重复实现。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from ..retrieval import TemporalQueryOptions
|
||||
from .search_postprocess import (
|
||||
apply_safe_content_dedup,
|
||||
maybe_apply_smart_path_fallback,
|
||||
)
|
||||
from .time_parser import parse_query_time_range
|
||||
|
||||
logger = get_logger("A_Memorix.SearchExecutionService")
|
||||
|
||||
|
||||
def _get_config_value(config: Optional[dict], key: str, default: Any = None) -> Any:
|
||||
if not isinstance(config, dict):
|
||||
return default
|
||||
current: Any = config
|
||||
for part in key.split("."):
|
||||
if isinstance(current, dict) and part in current:
|
||||
current = current[part]
|
||||
else:
|
||||
return default
|
||||
return current
|
||||
|
||||
|
||||
def _sanitize_text(value: Any) -> str:
|
||||
if value is None:
|
||||
return ""
|
||||
return str(value).strip()
|
||||
|
||||
|
||||
@dataclass
|
||||
class SearchExecutionRequest:
|
||||
caller: str
|
||||
stream_id: Optional[str] = None
|
||||
group_id: Optional[str] = None
|
||||
user_id: Optional[str] = None
|
||||
query_type: str = "search" # search|semantic|time|hybrid
|
||||
query: str = ""
|
||||
top_k: Optional[int] = None
|
||||
time_from: Optional[str] = None
|
||||
time_to: Optional[str] = None
|
||||
person: Optional[str] = None
|
||||
source: Optional[str] = None
|
||||
use_threshold: bool = True
|
||||
enable_ppr: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class SearchExecutionResult:
|
||||
success: bool
|
||||
error: str = ""
|
||||
query_type: str = "search"
|
||||
query: str = ""
|
||||
top_k: int = 10
|
||||
time_from: Optional[str] = None
|
||||
time_to: Optional[str] = None
|
||||
person: Optional[str] = None
|
||||
source: Optional[str] = None
|
||||
temporal: Optional[TemporalQueryOptions] = None
|
||||
results: List[Any] = field(default_factory=list)
|
||||
elapsed_ms: float = 0.0
|
||||
chat_filtered: bool = False
|
||||
dedup_hit: bool = False
|
||||
|
||||
@property
|
||||
def count(self) -> int:
|
||||
return len(self.results)
|
||||
|
||||
|
||||
class SearchExecutionService:
|
||||
"""统一检索执行服务。"""
|
||||
|
||||
@staticmethod
|
||||
def _resolve_plugin_instance(plugin_config: Optional[dict]) -> Optional[Any]:
|
||||
if isinstance(plugin_config, dict):
|
||||
plugin_instance = plugin_config.get("plugin_instance")
|
||||
if plugin_instance is not None:
|
||||
return plugin_instance
|
||||
|
||||
try:
|
||||
from ...plugin import AMemorixPlugin
|
||||
|
||||
return getattr(AMemorixPlugin, "get_global_instance", lambda: None)()
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _normalize_query_type(raw_query_type: str) -> str:
|
||||
query_type = _sanitize_text(raw_query_type).lower() or "search"
|
||||
if query_type == "semantic":
|
||||
return "search"
|
||||
return query_type
|
||||
|
||||
@staticmethod
|
||||
def _resolve_runtime_component(
|
||||
plugin_config: Optional[dict],
|
||||
plugin_instance: Optional[Any],
|
||||
key: str,
|
||||
) -> Optional[Any]:
|
||||
if isinstance(plugin_config, dict):
|
||||
value = plugin_config.get(key)
|
||||
if value is not None:
|
||||
return value
|
||||
if plugin_instance is not None:
|
||||
value = getattr(plugin_instance, key, None)
|
||||
if value is not None:
|
||||
return value
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _resolve_top_k(
|
||||
plugin_config: Optional[dict],
|
||||
query_type: str,
|
||||
top_k_raw: Optional[Any],
|
||||
) -> Tuple[bool, int, str]:
|
||||
temporal_default_top_k = int(
|
||||
_get_config_value(plugin_config, "retrieval.temporal.default_top_k", 10)
|
||||
)
|
||||
default_top_k = temporal_default_top_k if query_type in {"time", "hybrid"} else 10
|
||||
if top_k_raw is None:
|
||||
return True, max(1, min(50, default_top_k)), ""
|
||||
try:
|
||||
top_k = int(top_k_raw)
|
||||
except (TypeError, ValueError):
|
||||
return False, 0, "top_k 参数必须为整数"
|
||||
return True, max(1, min(50, top_k)), ""
|
||||
|
||||
@staticmethod
|
||||
def _build_temporal(
|
||||
plugin_config: Optional[dict],
|
||||
query_type: str,
|
||||
time_from_raw: Optional[str],
|
||||
time_to_raw: Optional[str],
|
||||
person: Optional[str],
|
||||
source: Optional[str],
|
||||
) -> Tuple[bool, Optional[TemporalQueryOptions], str]:
|
||||
if query_type not in {"time", "hybrid"}:
|
||||
return True, None, ""
|
||||
|
||||
temporal_enabled = bool(_get_config_value(plugin_config, "retrieval.temporal.enabled", True))
|
||||
if not temporal_enabled:
|
||||
return False, None, "时序检索已禁用(retrieval.temporal.enabled=false)"
|
||||
|
||||
if not time_from_raw and not time_to_raw:
|
||||
return False, None, "time/hybrid 模式至少需要 time_from 或 time_to"
|
||||
|
||||
try:
|
||||
ts_from, ts_to = parse_query_time_range(
|
||||
str(time_from_raw) if time_from_raw is not None else None,
|
||||
str(time_to_raw) if time_to_raw is not None else None,
|
||||
)
|
||||
except ValueError as e:
|
||||
return False, None, f"时间参数错误: {e}"
|
||||
|
||||
temporal = TemporalQueryOptions(
|
||||
time_from=ts_from,
|
||||
time_to=ts_to,
|
||||
person=_sanitize_text(person) or None,
|
||||
source=_sanitize_text(source) or None,
|
||||
allow_created_fallback=bool(
|
||||
_get_config_value(plugin_config, "retrieval.temporal.allow_created_fallback", True)
|
||||
),
|
||||
candidate_multiplier=int(
|
||||
_get_config_value(plugin_config, "retrieval.temporal.candidate_multiplier", 8)
|
||||
),
|
||||
max_scan=int(_get_config_value(plugin_config, "retrieval.temporal.max_scan", 1000)),
|
||||
)
|
||||
return True, temporal, ""
|
||||
|
||||
@staticmethod
|
||||
def _build_request_key(
|
||||
request: SearchExecutionRequest,
|
||||
query_type: str,
|
||||
top_k: int,
|
||||
temporal: Optional[TemporalQueryOptions],
|
||||
) -> str:
|
||||
payload = {
|
||||
"stream_id": _sanitize_text(request.stream_id),
|
||||
"query_type": query_type,
|
||||
"query": _sanitize_text(request.query),
|
||||
"time_from": _sanitize_text(request.time_from),
|
||||
"time_to": _sanitize_text(request.time_to),
|
||||
"time_from_ts": temporal.time_from if temporal else None,
|
||||
"time_to_ts": temporal.time_to if temporal else None,
|
||||
"person": _sanitize_text(request.person),
|
||||
"source": _sanitize_text(request.source),
|
||||
"top_k": int(top_k),
|
||||
"use_threshold": bool(request.use_threshold),
|
||||
"enable_ppr": bool(request.enable_ppr),
|
||||
}
|
||||
payload_json = json.dumps(payload, ensure_ascii=False, sort_keys=True)
|
||||
return hashlib.sha1(payload_json.encode("utf-8")).hexdigest()
|
||||
|
||||
@staticmethod
|
||||
async def execute(
|
||||
*,
|
||||
retriever: Any,
|
||||
threshold_filter: Optional[Any],
|
||||
plugin_config: Optional[dict],
|
||||
request: SearchExecutionRequest,
|
||||
enforce_chat_filter: bool = True,
|
||||
reinforce_access: bool = True,
|
||||
) -> SearchExecutionResult:
|
||||
if retriever is None:
|
||||
return SearchExecutionResult(success=False, error="知识检索器未初始化")
|
||||
|
||||
query_type = SearchExecutionService._normalize_query_type(request.query_type)
|
||||
query = _sanitize_text(request.query)
|
||||
if query_type not in {"search", "time", "hybrid"}:
|
||||
return SearchExecutionResult(
|
||||
success=False,
|
||||
error=f"query_type 无效: {query_type}(仅支持 search/time/hybrid)",
|
||||
)
|
||||
|
||||
if query_type in {"search", "hybrid"} and not query:
|
||||
return SearchExecutionResult(
|
||||
success=False,
|
||||
error="search/hybrid 模式必须提供 query",
|
||||
)
|
||||
|
||||
top_k_ok, top_k, top_k_error = SearchExecutionService._resolve_top_k(
|
||||
plugin_config, query_type, request.top_k
|
||||
)
|
||||
if not top_k_ok:
|
||||
return SearchExecutionResult(success=False, error=top_k_error)
|
||||
|
||||
temporal_ok, temporal, temporal_error = SearchExecutionService._build_temporal(
|
||||
plugin_config=plugin_config,
|
||||
query_type=query_type,
|
||||
time_from_raw=request.time_from,
|
||||
time_to_raw=request.time_to,
|
||||
person=request.person,
|
||||
source=request.source,
|
||||
)
|
||||
if not temporal_ok:
|
||||
return SearchExecutionResult(success=False, error=temporal_error)
|
||||
|
||||
plugin_instance = SearchExecutionService._resolve_plugin_instance(plugin_config)
|
||||
if (
|
||||
enforce_chat_filter
|
||||
and plugin_instance is not None
|
||||
and hasattr(plugin_instance, "is_chat_enabled")
|
||||
):
|
||||
if not plugin_instance.is_chat_enabled(
|
||||
stream_id=request.stream_id,
|
||||
group_id=request.group_id,
|
||||
user_id=request.user_id,
|
||||
):
|
||||
logger.info(
|
||||
"检索请求被聊天过滤拦截: "
|
||||
f"caller={request.caller}, "
|
||||
f"stream_id={request.stream_id}"
|
||||
)
|
||||
return SearchExecutionResult(
|
||||
success=True,
|
||||
query_type=query_type,
|
||||
query=query,
|
||||
top_k=top_k,
|
||||
time_from=request.time_from,
|
||||
time_to=request.time_to,
|
||||
person=request.person,
|
||||
source=request.source,
|
||||
temporal=temporal,
|
||||
results=[],
|
||||
elapsed_ms=0.0,
|
||||
chat_filtered=True,
|
||||
dedup_hit=False,
|
||||
)
|
||||
|
||||
request_key = SearchExecutionService._build_request_key(
|
||||
request=request,
|
||||
query_type=query_type,
|
||||
top_k=top_k,
|
||||
temporal=temporal,
|
||||
)
|
||||
|
||||
async def _executor() -> Dict[str, Any]:
|
||||
original_ppr = bool(getattr(retriever.config, "enable_ppr", True))
|
||||
setattr(retriever.config, "enable_ppr", bool(request.enable_ppr))
|
||||
started_at = time.time()
|
||||
try:
|
||||
retrieved = await retriever.retrieve(
|
||||
query=query,
|
||||
top_k=top_k,
|
||||
temporal=temporal,
|
||||
)
|
||||
|
||||
should_apply_threshold = bool(request.use_threshold) and threshold_filter is not None
|
||||
if (
|
||||
query_type == "time"
|
||||
and not query
|
||||
and bool(
|
||||
_get_config_value(
|
||||
plugin_config,
|
||||
"retrieval.time.skip_threshold_when_query_empty",
|
||||
True,
|
||||
)
|
||||
)
|
||||
):
|
||||
should_apply_threshold = False
|
||||
|
||||
if should_apply_threshold:
|
||||
retrieved = threshold_filter.filter(retrieved)
|
||||
|
||||
if (
|
||||
reinforce_access
|
||||
and plugin_instance is not None
|
||||
and hasattr(plugin_instance, "reinforce_access")
|
||||
):
|
||||
relation_hashes = [
|
||||
item.hash_value
|
||||
for item in retrieved
|
||||
if getattr(item, "result_type", "") == "relation"
|
||||
]
|
||||
if relation_hashes:
|
||||
await plugin_instance.reinforce_access(relation_hashes)
|
||||
|
||||
if query_type == "search":
|
||||
graph_store = SearchExecutionService._resolve_runtime_component(
|
||||
plugin_config, plugin_instance, "graph_store"
|
||||
)
|
||||
metadata_store = SearchExecutionService._resolve_runtime_component(
|
||||
plugin_config, plugin_instance, "metadata_store"
|
||||
)
|
||||
fallback_enabled = bool(
|
||||
_get_config_value(
|
||||
plugin_config,
|
||||
"retrieval.search.smart_fallback.enabled",
|
||||
True,
|
||||
)
|
||||
)
|
||||
fallback_threshold = float(
|
||||
_get_config_value(
|
||||
plugin_config,
|
||||
"retrieval.search.smart_fallback.threshold",
|
||||
0.6,
|
||||
)
|
||||
)
|
||||
retrieved, fallback_triggered, fallback_added = maybe_apply_smart_path_fallback(
|
||||
query=query,
|
||||
results=list(retrieved),
|
||||
graph_store=graph_store,
|
||||
metadata_store=metadata_store,
|
||||
enabled=fallback_enabled,
|
||||
threshold=fallback_threshold,
|
||||
)
|
||||
if fallback_triggered:
|
||||
logger.info(
|
||||
"metric.smart_fallback_triggered_count=1 "
|
||||
f"caller={request.caller} "
|
||||
f"added={fallback_added}"
|
||||
)
|
||||
|
||||
dedup_enabled = bool(
|
||||
_get_config_value(
|
||||
plugin_config,
|
||||
"retrieval.search.safe_content_dedup.enabled",
|
||||
True,
|
||||
)
|
||||
)
|
||||
if dedup_enabled:
|
||||
retrieved, removed_count = apply_safe_content_dedup(list(retrieved))
|
||||
if removed_count > 0:
|
||||
logger.info(
|
||||
f"metric.safe_dedup_removed_count={removed_count} "
|
||||
f"caller={request.caller}"
|
||||
)
|
||||
|
||||
elapsed_ms = (time.time() - started_at) * 1000.0
|
||||
return {"results": retrieved, "elapsed_ms": elapsed_ms}
|
||||
finally:
|
||||
setattr(retriever.config, "enable_ppr", original_ppr)
|
||||
|
||||
dedup_hit = False
|
||||
try:
|
||||
# 调优评估需要逐轮真实执行,且应避免额外 dedup 锁竞争。
|
||||
bypass_request_dedup = str(request.caller or "").strip().lower() == "retrieval_tuning"
|
||||
if (
|
||||
not bypass_request_dedup
|
||||
and
|
||||
plugin_instance is not None
|
||||
and hasattr(plugin_instance, "execute_request_with_dedup")
|
||||
):
|
||||
dedup_hit, payload = await plugin_instance.execute_request_with_dedup(
|
||||
request_key,
|
||||
_executor,
|
||||
)
|
||||
else:
|
||||
payload = await _executor()
|
||||
except Exception as e:
|
||||
return SearchExecutionResult(success=False, error=f"知识检索失败: {e}")
|
||||
|
||||
if dedup_hit:
|
||||
logger.info(f"metric.search_execution_dedup_hit_count=1 caller={request.caller}")
|
||||
|
||||
return SearchExecutionResult(
|
||||
success=True,
|
||||
query_type=query_type,
|
||||
query=query,
|
||||
top_k=top_k,
|
||||
time_from=request.time_from,
|
||||
time_to=request.time_to,
|
||||
person=request.person,
|
||||
source=request.source,
|
||||
temporal=temporal,
|
||||
results=payload.get("results", []),
|
||||
elapsed_ms=float(payload.get("elapsed_ms", 0.0)),
|
||||
chat_filtered=False,
|
||||
dedup_hit=bool(dedup_hit),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def to_serializable_results(results: List[Any]) -> List[Dict[str, Any]]:
|
||||
serialized: List[Dict[str, Any]] = []
|
||||
for item in results:
|
||||
metadata = dict(getattr(item, "metadata", {}) or {})
|
||||
if "time_meta" not in metadata:
|
||||
metadata["time_meta"] = {}
|
||||
serialized.append(
|
||||
{
|
||||
"hash": getattr(item, "hash_value", ""),
|
||||
"type": getattr(item, "result_type", ""),
|
||||
"score": float(getattr(item, "score", 0.0)),
|
||||
"content": getattr(item, "content", ""),
|
||||
"metadata": metadata,
|
||||
}
|
||||
)
|
||||
return serialized
|
||||
425
plugins/A_memorix/core/utils/summary_importer.py
Normal file
425
plugins/A_memorix/core/utils/summary_importer.py
Normal file
@@ -0,0 +1,425 @@
|
||||
"""
|
||||
聊天总结与知识导入工具
|
||||
|
||||
该模块负责从聊天记录中提取信息,生成总结,并将总结内容及提取的实体/关系
|
||||
导入到 A_memorix 的存储组件中。
|
||||
"""
|
||||
|
||||
import time
|
||||
import json
|
||||
import re
|
||||
import traceback
|
||||
from typing import List, Dict, Any, Tuple, Optional
|
||||
from pathlib import Path
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.services import llm_service as llm_api
|
||||
from src.services import message_service as message_api
|
||||
from src.config.config import global_config, model_config as host_model_config
|
||||
from src.config.model_configs import TaskConfig
|
||||
|
||||
from ..storage import (
|
||||
KnowledgeType,
|
||||
VectorStore,
|
||||
GraphStore,
|
||||
MetadataStore,
|
||||
resolve_stored_knowledge_type,
|
||||
)
|
||||
from ..embedding import EmbeddingAPIAdapter
|
||||
from .relation_write_service import RelationWriteService
|
||||
from .runtime_self_check import ensure_runtime_self_check, run_embedding_runtime_self_check
|
||||
|
||||
logger = get_logger("A_Memorix.SummaryImporter")
|
||||
|
||||
# 默认总结提示词模版
|
||||
SUMMARY_PROMPT_TEMPLATE = """
|
||||
你是 {bot_name}。{personality_context}
|
||||
现在你需要对以下一段聊天记录进行总结,并提取其中的重要知识。
|
||||
|
||||
聊天记录内容:
|
||||
{chat_history}
|
||||
|
||||
请完成以下任务:
|
||||
1. **生成总结**:以第三人称或机器人的视角,简洁明了地总结这段对话的主要内容、发生的事件或讨论的主题。
|
||||
2. **提取实体与关系**:识别并提取对话中提到的重要实体以及它们之间的关系。
|
||||
|
||||
请严格以 JSON 格式输出,格式如下:
|
||||
{{
|
||||
"summary": "总结文本内容",
|
||||
"entities": ["张三", "李四"],
|
||||
"relations": [
|
||||
{{"subject": "张三", "predicate": "认识", "object": "李四"}}
|
||||
]
|
||||
}}
|
||||
|
||||
注意:总结应具有叙事性,能够作为长程记忆的一部分。直接使用实体的实际名称,不要使用 e1/e2 等代号。
|
||||
"""
|
||||
|
||||
class SummaryImporter:
|
||||
"""总结并导入知识的工具类"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vector_store: VectorStore,
|
||||
graph_store: GraphStore,
|
||||
metadata_store: MetadataStore,
|
||||
embedding_manager: EmbeddingAPIAdapter,
|
||||
plugin_config: dict
|
||||
):
|
||||
self.vector_store = vector_store
|
||||
self.graph_store = graph_store
|
||||
self.metadata_store = metadata_store
|
||||
self.embedding_manager = embedding_manager
|
||||
self.plugin_config = plugin_config
|
||||
self.relation_write_service: Optional[RelationWriteService] = (
|
||||
plugin_config.get("relation_write_service")
|
||||
if isinstance(plugin_config, dict)
|
||||
else None
|
||||
)
|
||||
|
||||
def _normalize_summary_model_selectors(self, raw_value: Any) -> List[str]:
|
||||
"""标准化 summarization.model_name 配置(vNext 仅接受字符串数组)。"""
|
||||
if raw_value is None:
|
||||
return ["auto"]
|
||||
if isinstance(raw_value, list):
|
||||
selectors = [str(x).strip() for x in raw_value if str(x).strip()]
|
||||
return selectors or ["auto"]
|
||||
raise ValueError(
|
||||
"summarization.model_name 在 vNext 必须为 List[str]。"
|
||||
" 请执行 scripts/release_vnext_migrate.py migrate。"
|
||||
)
|
||||
|
||||
def _pick_default_summary_task(self, available_tasks: Dict[str, TaskConfig]) -> Tuple[Optional[str], Optional[TaskConfig]]:
|
||||
"""
|
||||
选择总结默认任务,避免错误落到 embedding 任务。
|
||||
优先级:replyer > utils > planner > tool_use > 其他非 embedding。
|
||||
"""
|
||||
preferred = ("replyer", "utils", "planner", "tool_use")
|
||||
for name in preferred:
|
||||
cfg = available_tasks.get(name)
|
||||
if cfg and cfg.model_list:
|
||||
return name, cfg
|
||||
|
||||
for name, cfg in available_tasks.items():
|
||||
if name != "embedding" and cfg.model_list:
|
||||
return name, cfg
|
||||
|
||||
for name, cfg in available_tasks.items():
|
||||
if cfg.model_list:
|
||||
return name, cfg
|
||||
|
||||
return None, None
|
||||
|
||||
def _resolve_summary_model_config(self) -> Optional[TaskConfig]:
|
||||
"""
|
||||
解析 summarization.model_name 为 TaskConfig。
|
||||
支持:
|
||||
- "auto"
|
||||
- "replyer"(任务名)
|
||||
- "some-model-name"(具体模型名)
|
||||
- ["utils:model1", "utils:model2", "replyer"](数组混合语法)
|
||||
"""
|
||||
available_tasks = llm_api.get_available_models()
|
||||
if not available_tasks:
|
||||
return None
|
||||
|
||||
raw_cfg = self.plugin_config.get("summarization", {}).get("model_name", "auto")
|
||||
selectors = self._normalize_summary_model_selectors(raw_cfg)
|
||||
default_task_name, default_task_cfg = self._pick_default_summary_task(available_tasks)
|
||||
|
||||
selected_models: List[str] = []
|
||||
base_cfg: Optional[TaskConfig] = None
|
||||
model_dict = getattr(host_model_config, "models_dict", {})
|
||||
|
||||
def _append_models(models: List[str]):
|
||||
for model_name in models:
|
||||
if model_name and model_name not in selected_models:
|
||||
selected_models.append(model_name)
|
||||
|
||||
for raw_selector in selectors:
|
||||
selector = raw_selector.strip()
|
||||
if not selector:
|
||||
continue
|
||||
|
||||
if selector.lower() == "auto":
|
||||
if default_task_cfg:
|
||||
_append_models(default_task_cfg.model_list)
|
||||
if base_cfg is None:
|
||||
base_cfg = default_task_cfg
|
||||
continue
|
||||
|
||||
if ":" in selector:
|
||||
task_name, model_name = selector.split(":", 1)
|
||||
task_name = task_name.strip()
|
||||
model_name = model_name.strip()
|
||||
task_cfg = available_tasks.get(task_name)
|
||||
if not task_cfg:
|
||||
logger.warning(f"总结模型选择器 '{selector}' 的任务 '{task_name}' 不存在,已跳过")
|
||||
continue
|
||||
|
||||
if base_cfg is None:
|
||||
base_cfg = task_cfg
|
||||
|
||||
if not model_name or model_name.lower() == "auto":
|
||||
_append_models(task_cfg.model_list)
|
||||
continue
|
||||
|
||||
if model_name in model_dict or model_name in task_cfg.model_list:
|
||||
_append_models([model_name])
|
||||
else:
|
||||
logger.warning(f"总结模型选择器 '{selector}' 的模型 '{model_name}' 不存在,已跳过")
|
||||
continue
|
||||
|
||||
task_cfg = available_tasks.get(selector)
|
||||
if task_cfg:
|
||||
_append_models(task_cfg.model_list)
|
||||
if base_cfg is None:
|
||||
base_cfg = task_cfg
|
||||
continue
|
||||
|
||||
if selector in model_dict:
|
||||
_append_models([selector])
|
||||
continue
|
||||
|
||||
logger.warning(f"总结模型选择器 '{selector}' 无法识别,已跳过")
|
||||
|
||||
if not selected_models:
|
||||
if default_task_cfg:
|
||||
_append_models(default_task_cfg.model_list)
|
||||
if base_cfg is None:
|
||||
base_cfg = default_task_cfg
|
||||
else:
|
||||
first_cfg = next(iter(available_tasks.values()))
|
||||
_append_models(first_cfg.model_list)
|
||||
if base_cfg is None:
|
||||
base_cfg = first_cfg
|
||||
|
||||
if not selected_models:
|
||||
return None
|
||||
|
||||
template_cfg = base_cfg or default_task_cfg or next(iter(available_tasks.values()))
|
||||
return TaskConfig(
|
||||
model_list=selected_models,
|
||||
max_tokens=template_cfg.max_tokens,
|
||||
temperature=template_cfg.temperature,
|
||||
slow_threshold=template_cfg.slow_threshold,
|
||||
selection_strategy=template_cfg.selection_strategy,
|
||||
)
|
||||
|
||||
async def import_from_stream(
|
||||
self,
|
||||
stream_id: str,
|
||||
context_length: Optional[int] = None,
|
||||
include_personality: Optional[bool] = None
|
||||
) -> Tuple[bool, str]:
|
||||
"""
|
||||
从指定的聊天流中提取记录并执行总结导入
|
||||
|
||||
Args:
|
||||
stream_id: 聊天流 ID
|
||||
context_length: 总结的历史消息条数
|
||||
include_personality: 是否包含人设
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str]: (是否成功, 结果消息)
|
||||
"""
|
||||
try:
|
||||
self_check_ok, self_check_msg = await self._ensure_runtime_self_check()
|
||||
if not self_check_ok:
|
||||
return False, f"导入前自检失败: {self_check_msg}"
|
||||
|
||||
# 1. 获取配置
|
||||
if context_length is None:
|
||||
context_length = self.plugin_config.get("summarization", {}).get("context_length", 50)
|
||||
|
||||
if include_personality is None:
|
||||
include_personality = self.plugin_config.get("summarization", {}).get("include_personality", True)
|
||||
|
||||
# 2. 获取历史消息
|
||||
# 获取当前时间之前的消息
|
||||
now = time.time()
|
||||
messages = message_api.get_messages_before_time_in_chat(
|
||||
chat_id=stream_id,
|
||||
timestamp=now,
|
||||
limit=context_length
|
||||
)
|
||||
|
||||
if not messages:
|
||||
return False, "未找到有效的聊天记录进行总结"
|
||||
|
||||
# 转换为可读文本
|
||||
chat_history_text = message_api.build_readable_messages(messages)
|
||||
|
||||
# 3. 准备提示词内容
|
||||
bot_name = global_config.bot.nickname or "机器人"
|
||||
personality_context = ""
|
||||
if include_personality:
|
||||
personality = getattr(global_config.bot, "personality", "")
|
||||
if personality:
|
||||
personality_context = f"你的性格设定是:{personality}"
|
||||
|
||||
# 4. 调用 LLM
|
||||
prompt = SUMMARY_PROMPT_TEMPLATE.format(
|
||||
bot_name=bot_name,
|
||||
personality_context=personality_context,
|
||||
chat_history=chat_history_text
|
||||
)
|
||||
|
||||
model_config_to_use = self._resolve_summary_model_config()
|
||||
if model_config_to_use is None:
|
||||
return False, "未找到可用的总结模型配置"
|
||||
|
||||
logger.info(f"正在为流 {stream_id} 执行总结,消息条数: {len(messages)}")
|
||||
logger.info(f"总结模型候选列表: {model_config_to_use.model_list}")
|
||||
|
||||
success, response, _, _ = await llm_api.generate_with_model(
|
||||
prompt=prompt,
|
||||
model_config=model_config_to_use,
|
||||
request_type="A_Memorix.ChatSummarization"
|
||||
)
|
||||
|
||||
if not success or not response:
|
||||
return False, "LLM 生成总结失败"
|
||||
|
||||
# 5. 解析结果
|
||||
data = self._parse_llm_response(response)
|
||||
if not data or "summary" not in data:
|
||||
return False, "解析 LLM 响应失败或总结为空"
|
||||
|
||||
summary_text = data["summary"]
|
||||
entities = data.get("entities", [])
|
||||
relations = data.get("relations", [])
|
||||
msg_times = [
|
||||
float(getattr(getattr(msg, "timestamp", None), "timestamp", lambda: 0.0)())
|
||||
for msg in messages
|
||||
if getattr(msg, "time", None) is not None
|
||||
]
|
||||
time_meta = {}
|
||||
if msg_times:
|
||||
time_meta = {
|
||||
"event_time_start": min(msg_times),
|
||||
"event_time_end": max(msg_times),
|
||||
"time_granularity": "minute",
|
||||
"time_confidence": 0.95,
|
||||
}
|
||||
|
||||
# 6. 执行导入
|
||||
await self._execute_import(summary_text, entities, relations, stream_id, time_meta=time_meta)
|
||||
|
||||
# 7. 持久化
|
||||
self.vector_store.save()
|
||||
self.graph_store.save()
|
||||
|
||||
result_msg = (
|
||||
f"✅ 总结导入成功\n"
|
||||
f"📝 总结长度: {len(summary_text)}\n"
|
||||
f"📌 提取实体: {len(entities)}\n"
|
||||
f"🔗 提取关系: {len(relations)}"
|
||||
)
|
||||
return True, result_msg
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"总结导入过程中出错: {e}\n{traceback.format_exc()}")
|
||||
return False, f"错误: {str(e)}"
|
||||
|
||||
async def _ensure_runtime_self_check(self) -> Tuple[bool, str]:
|
||||
plugin_instance = self.plugin_config.get("plugin_instance") if isinstance(self.plugin_config, dict) else None
|
||||
if plugin_instance is not None:
|
||||
report = await ensure_runtime_self_check(plugin_instance)
|
||||
else:
|
||||
report = await run_embedding_runtime_self_check(
|
||||
config=self.plugin_config,
|
||||
vector_store=self.vector_store,
|
||||
embedding_manager=self.embedding_manager,
|
||||
)
|
||||
if bool(report.get("ok", False)):
|
||||
return True, ""
|
||||
return (
|
||||
False,
|
||||
f"{report.get('message', 'unknown')} "
|
||||
f"(configured={report.get('configured_dimension', 0)}, "
|
||||
f"store={report.get('vector_store_dimension', 0)}, "
|
||||
f"encoded={report.get('encoded_dimension', 0)})",
|
||||
)
|
||||
|
||||
def _parse_llm_response(self, response: str) -> Dict[str, Any]:
|
||||
"""解析 LLM 返回的 JSON"""
|
||||
try:
|
||||
# 尝试查找 JSON
|
||||
json_match = re.search(r"\{.*\}", response, re.DOTALL)
|
||||
if json_match:
|
||||
return json.loads(json_match.group())
|
||||
return {}
|
||||
except Exception as e:
|
||||
logger.warning(f"解析总结 JSON 失败: {e}")
|
||||
return {}
|
||||
|
||||
async def _execute_import(
|
||||
self,
|
||||
summary: str,
|
||||
entities: List[str],
|
||||
relations: List[Dict[str, str]],
|
||||
stream_id: str,
|
||||
time_meta: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
"""将数据写入存储"""
|
||||
# 获取默认知识类型
|
||||
type_str = self.plugin_config.get("summarization", {}).get("default_knowledge_type", "narrative")
|
||||
try:
|
||||
knowledge_type = resolve_stored_knowledge_type(type_str, content=summary)
|
||||
except ValueError:
|
||||
logger.warning(f"非法 summarization.default_knowledge_type={type_str},回退 narrative")
|
||||
knowledge_type = KnowledgeType.NARRATIVE
|
||||
|
||||
# 导入总结文本
|
||||
hash_value = self.metadata_store.add_paragraph(
|
||||
content=summary,
|
||||
source=f"chat_summary:{stream_id}",
|
||||
knowledge_type=knowledge_type.value,
|
||||
time_meta=time_meta,
|
||||
)
|
||||
|
||||
embedding = await self.embedding_manager.encode(summary)
|
||||
self.vector_store.add(
|
||||
vectors=embedding.reshape(1, -1),
|
||||
ids=[hash_value]
|
||||
)
|
||||
|
||||
# 导入实体
|
||||
if entities:
|
||||
self.graph_store.add_nodes(entities)
|
||||
|
||||
# 导入关系
|
||||
rv_cfg = self.plugin_config.get("retrieval", {}).get("relation_vectorization", {})
|
||||
if not isinstance(rv_cfg, dict):
|
||||
rv_cfg = {}
|
||||
write_vector = bool(rv_cfg.get("enabled", False)) and bool(rv_cfg.get("write_on_import", True))
|
||||
for rel in relations:
|
||||
s, p, o = rel.get("subject"), rel.get("predicate"), rel.get("object")
|
||||
if all([s, p, o]):
|
||||
if self.relation_write_service is not None:
|
||||
await self.relation_write_service.upsert_relation_with_vector(
|
||||
subject=s,
|
||||
predicate=p,
|
||||
obj=o,
|
||||
confidence=1.0,
|
||||
source_paragraph=summary,
|
||||
write_vector=write_vector,
|
||||
)
|
||||
else:
|
||||
# 写入元数据
|
||||
rel_hash = self.metadata_store.add_relation(
|
||||
subject=s,
|
||||
predicate=p,
|
||||
obj=o,
|
||||
confidence=1.0,
|
||||
source_paragraph=summary
|
||||
)
|
||||
# 写入图数据库(写入 relation_hashes,确保后续可按关系精确修剪)
|
||||
self.graph_store.add_edges([(s, o)], relation_hashes=[rel_hash])
|
||||
try:
|
||||
self.metadata_store.set_relation_vector_state(rel_hash, "none")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.info(f"总结导入完成: hash={hash_value[:8]}")
|
||||
3522
plugins/A_memorix/core/utils/web_import_manager.py
Normal file
3522
plugins/A_memorix/core/utils/web_import_manager.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user