Files
mai-bot/plugins/A_memorix/scripts/convert_lpmm.py
DawnARC 999e7246e2 feat:新增 A_Memorix 记忆插件
引入 A_Memorix 插件(v2.0.0)——一个轻量级的长期记忆提供器。新增插件清单(manifest)和入口(AMemorixPlugin),并提供完整的核心能力:嵌入(基于哈希的 EmbeddingAPIAdapter、EmbeddingManager、预设)、检索(双路径检索器、PageRank、图关系召回、BM25 稀疏索引、阈值与融合配置)、存储与元数据层,以及大量实用工具和迁移/转换脚本。同时更新 .gitignore 以允许 /plugins/A_memorix。该变更为在宿主应用中实现统一的记忆摄取、检索、分析与维护奠定了基础。
2026-03-18 21:33:15 +08:00

536 lines
20 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
LPMM to A_memorix storage converter.

What it does:
1. Reads LPMM parquet files (paragraph.parquet, entity.parquet, relation.parquet)
2. Reads the LPMM graph file (graph.graphml or graph_structure.pkl)
3. Writes directly into the A_memorix binary VectorStore and sparse GraphStore
4. Bypasses embedding generation to save tokens (reuses LPMM's embeddings)
"""
import sys
import os
import json
import argparse
import asyncio
import pickle
import logging
from pathlib import Path
from typing import Dict, Any, List, Tuple
import numpy as np
import tomlkit

# Path setup: this script lives in <plugin>/scripts/, so plugin_root is its
# parent and project_root is two levels above the plugin directory.
current_dir = Path(__file__).resolve().parent
plugin_root = current_dir.parent
project_root = plugin_root.parent.parent
sys.path.insert(0, str(project_root))
def _build_arg_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(description="将 LPMM 数据转换为 A_memorix 格式")
parser.add_argument("--input", "-i", required=True, help="包含 LPMM 数据的输入目录 (parquet, graphml)")
parser.add_argument("--output", "-o", required=True, help="A_memorix 数据的输出目录")
parser.add_argument("--dim", type=int, default=384, help="Embedding 维度 (必须与 LPMM 模型匹配)")
parser.add_argument("--batch-size", type=int, default=1024, help="Parquet 分批读取大小 (默认 1024)")
parser.add_argument(
"--skip-relation-vector-rebuild",
action="store_true",
help="跳过按关系元数据重建关系向量(默认开启)",
)
return parser
# --help/-h fast path: print usage without paying the heavy host/plugin
# bootstrap cost (dependency imports below can be slow or fail entirely).
if any(arg in {"-h", "--help"} for arg in sys.argv[1:]):
    _build_arg_parser().print_help()
    sys.exit(0)

# Logging setup for the whole script.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger("LPMM_Converter")

# Third-party dependencies; exit with an install hint if any are missing.
try:
    import networkx as nx
    from scipy import sparse
    import pyarrow.parquet as pq
except ImportError as e:
    logger.error(f"缺少依赖: {e}")
    logger.error("请安装: pip install pandas pyarrow networkx scipy")
    sys.exit(1)

# A_memorix core modules.
try:
    # Prefer relative-style imports (add the plugin root to sys.path) so we
    # avoid hard-coding the plugin package name (plugins.A_memorix).
    if str(plugin_root) not in sys.path:
        sys.path.insert(0, str(plugin_root))
    from core.storage.vector_store import VectorStore
    from core.storage.graph_store import GraphStore
    from core.storage.metadata_store import MetadataStore
    from core.storage import QuantizationType, SparseMatrixFormat
    from core.embedding import create_embedding_api_adapter
    from core.utils.relation_write_service import RelationWriteService
except ImportError as e:
    logger.error(f"无法导入 A_memorix 核心模块: {e}")
    logger.error("请确保在正确的环境中运行,且已安装所有依赖。")
    sys.exit(1)
class LPMMConverter:
    """Converts an LPMM data directory into A_memorix stores.

    Pipeline: initialize empty stores -> import paragraph/entity vectors and
    relation metadata from parquet -> import the NetworkX graph -> optionally
    re-embed relation triples. Pre-computed LPMM embeddings are reused for
    paragraphs/entities, so no embedding API calls are needed for those.
    """

    def __init__(
        self,
        lpmm_data_dir: Path,
        output_dir: Path,
        dimension: int = 384,
        batch_size: int = 1024,
        rebuild_relation_vectors: bool = True,
    ):
        """
        Args:
            lpmm_data_dir: directory containing the LPMM parquet/graph files.
            output_dir: target directory for A_memorix stores (subdirs
                vectors/, graph/, metadata/ are derived from it).
            dimension: embedding dimension; must match the LPMM model.
            batch_size: parquet batch read size, clamped to >= 1.
            rebuild_relation_vectors: whether to re-embed relation triples
                after import using the plugin's embedding configuration.
        """
        self.lpmm_dir = lpmm_data_dir
        self.output_dir = output_dir
        self.dimension = dimension
        self.batch_size = max(1, int(batch_size))
        self.rebuild_relation_vectors = bool(rebuild_relation_vectors)
        self.vector_dir = output_dir / "vectors"
        self.graph_dir = output_dir / "graph"
        self.metadata_dir = output_dir / "metadata"
        # Stores are created lazily in initialize_stores().
        self.vector_store = None
        self.graph_store = None
        self.metadata_store = None
        self.embedding_manager = None
        self.relation_write_service = None
        # LPMM original ID -> A_memorix ID mapping (used for graph rewriting)
        self.id_mapping: Dict[str, str] = {}

    def _register_id_mapping(self, raw_id: Any, mapped_id: str, p_type: str) -> None:
        """Record an ID mapping, accepting both prefixed and unprefixed forms."""
        if raw_id is None:
            return
        raw = str(raw_id).strip()
        if not raw:
            return
        self.id_mapping[raw] = mapped_id
        # Also register the complementary form ("<type>-<hash>" vs bare hash)
        # so graph node IDs resolve regardless of LPMM's prefixing convention.
        prefix = f"{p_type}-"
        if raw.startswith(prefix):
            self.id_mapping[raw[len(prefix):]] = mapped_id
        else:
            self.id_mapping[prefix + raw] = mapped_id

    def _map_node_id(self, node: Any) -> str:
        """Map a graph node ID to its converted A_memorix ID.

        Falls back to the raw node ID (stringified) when no mapping exists.
        """
        node_key = str(node)
        return self.id_mapping.get(node_key, node_key)

    def initialize_stores(self):
        """Initialize empty A_memorix stores under the output directory."""
        logger.info(f"正在初始化存储于 {self.output_dir}...")
        # VectorStore: A_memorix defaults to INT8 quantization.
        self.vector_store = VectorStore(
            dimension=self.dimension,
            quantization_type=QuantizationType.INT8,
            data_dir=self.vector_dir
        )
        self.vector_store.clear()  # drop any previous data
        # GraphStore using the CSR sparse-matrix format.
        self.graph_store = GraphStore(
            matrix_format=SparseMatrixFormat.CSR,
            data_dir=self.graph_dir
        )
        self.graph_store.clear()
        # MetadataStore (SQLite-backed in A_memorix).
        self.metadata_store = MetadataStore(data_dir=self.metadata_dir)
        self.metadata_store.connect()
        # Metadata tables are NOT cleared here: the conversion assumes a fresh
        # start or a deliberate overwrite; for a new directory we rely on the
        # store creating new files itself.
        if self.rebuild_relation_vectors:
            self._init_relation_vector_service()

    def _load_plugin_config(self) -> Dict[str, Any]:
        """Load the plugin's config.toml; return {} on any failure."""
        config_path = plugin_root / "config.toml"
        if not config_path.exists():
            return {}
        try:
            with open(config_path, "r", encoding="utf-8") as f:
                parsed = tomlkit.load(f)
            return dict(parsed) if isinstance(parsed, dict) else {}
        except Exception as e:
            logger.warning(f"读取 config.toml 失败,使用默认 embedding 配置: {e}")
            return {}

    def _init_relation_vector_service(self) -> None:
        """Create the embedding adapter and relation write service (best-effort).

        On failure both are reset to None and relation-vector backfill is
        skipped with a warning rather than aborting the conversion.
        """
        if not self.rebuild_relation_vectors:
            return
        cfg = self._load_plugin_config()
        emb_cfg = cfg.get("embedding", {}) if isinstance(cfg, dict) else {}
        if not isinstance(emb_cfg, dict):
            emb_cfg = {}
        try:
            self.embedding_manager = create_embedding_api_adapter(
                batch_size=int(emb_cfg.get("batch_size", 32)),
                max_concurrent=int(emb_cfg.get("max_concurrent", 5)),
                default_dimension=int(emb_cfg.get("dimension", self.dimension)),
                model_name=str(emb_cfg.get("model_name", "auto")),
                retry_config=emb_cfg.get("retry", {}) if isinstance(emb_cfg.get("retry", {}), dict) else {},
            )
            self.relation_write_service = RelationWriteService(
                metadata_store=self.metadata_store,
                graph_store=self.graph_store,
                vector_store=self.vector_store,
                embedding_manager=self.embedding_manager,
            )
        except Exception as e:
            self.embedding_manager = None
            self.relation_write_service = None
            logger.warning(f"初始化关系向量重建服务失败,将跳过关系向量回填: {e}")

    async def _rebuild_relation_vectors(self) -> None:
        """Re-embed relation triples recorded in the MetadataStore.

        Counts successes (vector newly written), skips (state "ready" but
        nothing written, i.e. already present) and failures, then logs a
        summary. No-op if disabled or the write service is unavailable.
        """
        if not self.rebuild_relation_vectors:
            return
        if self.relation_write_service is None:
            logger.warning("关系向量重建已启用,但写入服务不可用,已跳过。")
            return
        rows = self.metadata_store.get_relations()
        if not rows:
            logger.info("未发现关系元数据,无需重建关系向量。")
            return
        success = 0
        failed = 0
        skipped = 0
        for row in rows:
            # NOTE(review): assumes rows behave like dicts with a mandatory
            # "hash" key — confirm against MetadataStore.get_relations().
            result = await self.relation_write_service.ensure_relation_vector(
                hash_value=str(row["hash"]),
                subject=str(row.get("subject", "")),
                predicate=str(row.get("predicate", "")),
                obj=str(row.get("object", "")),
            )
            if result.vector_state == "ready":
                if result.vector_written:
                    success += 1
                else:
                    skipped += 1
            else:
                failed += 1
        logger.info(
            "关系向量重建完成: total=%s success=%s skipped=%s failed=%s",
            len(rows),
            success,
            skipped,
            failed,
        )

    @staticmethod
    def _parse_relation_text(text: str) -> Tuple[str, str, str]:
        """Parse a free-form relation string into (subject, predicate, object).

        Tries '|'-separated, then '->'-separated, then whitespace-separated
        formats in that order; returns ("", "", "") when no 3-part split works.
        """
        raw = str(text or "").strip()
        if not raw:
            return "", "", ""
        if "|" in raw:
            parts = [p.strip() for p in raw.split("|") if p.strip()]
            if len(parts) >= 3:
                return parts[0], parts[1], parts[2]
        if "->" in raw:
            parts = [p.strip() for p in raw.split("->") if p.strip()]
            if len(parts) >= 3:
                return parts[0], parts[1], parts[2]
        pieces = raw.split()
        if len(pieces) >= 3:
            # Everything after the second token becomes the object.
            return pieces[0], pieces[1], " ".join(pieces[2:])
        return "", "", ""

    def _import_relation_metadata_from_parquet(self, relation_path: Path) -> int:
        """Import relation triples (metadata rows + graph edges) from parquet.

        Supports either explicit subject/predicate/object columns or a single
        text column ('str' or 'content') parsed via _parse_relation_text.
        Returns the number of unique relations imported. Vectors are NOT
        imported here; vector state is set to "none" for the later rebuild.
        """
        if not relation_path.exists():
            return 0
        try:
            parquet_file = pq.ParquetFile(relation_path)
        except Exception as e:
            logger.warning(f"读取 relation.parquet 失败,跳过关系元数据导入: {e}")
            return 0
        cols = set(parquet_file.schema_arrow.names)
        has_triple_cols = {"subject", "predicate", "object"}.issubset(cols)
        content_col = "str" if "str" in cols else ("content" if "content" in cols else "")
        imported_hashes = set()
        imported = 0
        for record_batch in parquet_file.iter_batches(batch_size=self.batch_size):
            df_batch = record_batch.to_pandas()
            for _, row in df_batch.iterrows():
                subject = ""
                predicate = ""
                obj = ""
                if has_triple_cols:
                    subject = str(row.get("subject", "") or "").strip()
                    predicate = str(row.get("predicate", "") or "").strip()
                    obj = str(row.get("object", "") or "").strip()
                elif content_col:
                    subject, predicate, obj = self._parse_relation_text(row.get(content_col, ""))
                # Skip incomplete triples.
                if not (subject and predicate and obj):
                    continue
                rel_hash = self.metadata_store.add_relation(
                    subject=subject,
                    predicate=predicate,
                    obj=obj,
                    source_paragraph=None,
                )
                # Deduplicate on the hash returned by the metadata store.
                if rel_hash in imported_hashes:
                    continue
                imported_hashes.add(rel_hash)
                self.graph_store.add_edges([(subject, obj)], relation_hashes=[rel_hash])
                try:
                    # Mark the relation as having no vector yet so the rebuild
                    # pass knows to embed it; best-effort by design.
                    self.metadata_store.set_relation_vector_state(rel_hash, "none")
                except Exception:
                    pass
                imported += 1
        return imported

    def convert_vectors(self):
        """Convert parquet embeddings into the VectorStore (plus metadata rows)."""
        # Default LPMM file names.
        parquet_files = {
            "paragraph": self.lpmm_dir / "paragraph.parquet",
            "entity": self.lpmm_dir / "entity.parquet",
            "relation": self.lpmm_dir / "relation.parquet"
        }
        total_vectors = 0
        for p_type, p_path in parquet_files.items():
            # Relation vectors cannot be guaranteed to map one-to-one onto
            # MetadataStore relation records in this script; importing them
            # directly would pollute recall results (hits could not be traced
            # back to relation metadata), so only metadata is imported.
            if p_type == "relation":
                relation_count = self._import_relation_metadata_from_parquet(p_path)
                logger.warning(
                    "跳过 relation.parquet 向量导入(保持一致性);已导入关系元数据: %s",
                    relation_count,
                )
                continue
            if not p_path.exists():
                logger.warning(f"文件未找到: {p_path}, 跳过 {p_type} 向量。")
                continue
            logger.info(f"正在处理 {p_type} 向量,来源: {p_path}...")
            try:
                parquet_file = pq.ParquetFile(p_path)
                total_rows = parquet_file.metadata.num_rows
                if total_rows == 0:
                    logger.info(f"{p_path} 为空,跳过。")
                    continue
                # LPMM schema: 'hash', 'embedding', 'str'
                cols = parquet_file.schema_arrow.names
                # Compatibility: some exports use 'content' instead of 'str'.
                content_col = 'str' if 'str' in cols else 'content'
                emb_col = 'embedding'
                hash_col = 'hash'
                if content_col not in cols or emb_col not in cols:
                    logger.error(f"{p_path} 中缺少必要列 (需包含 {content_col}, {emb_col})。发现: {cols}")
                    continue
                batch_columns = [content_col, emb_col]
                if hash_col in cols:
                    batch_columns.append(hash_col)
                processed_rows = 0
                added_for_type = 0
                batch_idx = 0
                for record_batch in parquet_file.iter_batches(
                    batch_size=self.batch_size,
                    columns=batch_columns,
                ):
                    batch_idx += 1
                    df_batch = record_batch.to_pandas()
                    embeddings_list = []
                    ids_list = []
                    # Build the metadata rows and ID mapping while scanning.
                    for _, row in df_batch.iterrows():
                        processed_rows += 1
                        content = row[content_col]
                        emb = row[emb_col]
                        # Skip null/NaN or blank content.
                        if content is None or (isinstance(content, float) and np.isnan(content)):
                            continue
                        content = str(content).strip()
                        if not content:
                            continue
                        # NOTE(review): a scalar NaN in the embedding column
                        # would raise TypeError on len() here — confirm the
                        # upstream parquet never stores scalar embeddings.
                        if emb is None or len(emb) == 0:
                            continue
                        # Write MetadataStore first and use the real hash it
                        # returns as the vector ID, so retrieval results can be
                        # mapped straight back to metadata.
                        store_id = None
                        if p_type == "paragraph":
                            store_id = self.metadata_store.add_paragraph(
                                content=content,
                                source="lpmm_import",
                                knowledge_type="factual",
                            )
                        elif p_type == "entity":
                            store_id = self.metadata_store.add_entity(name=content)
                        else:
                            continue
                        raw_hash = row[hash_col] if hash_col in df_batch.columns else None
                        if raw_hash is not None and not (isinstance(raw_hash, float) and np.isnan(raw_hash)):
                            self._register_id_mapping(raw_hash, store_id, p_type)
                        # Ensure the embedding is a float32 numpy array.
                        emb_np = np.array(emb, dtype=np.float32)
                        if emb_np.shape[0] != self.dimension:
                            logger.error(f"维度不匹配: {emb_np.shape[0]} vs {self.dimension}")
                            continue
                        embeddings_list.append(emb_np)
                        ids_list.append(store_id)
                    if embeddings_list:
                        # Add to the vector store one parquet batch at a time.
                        vectors_np = np.stack(embeddings_list)
                        count = self.vector_store.add(vectors_np, ids_list)
                        added_for_type += count
                        total_vectors += count
                    # Progress log on the first batch, then every 10th.
                    if batch_idx == 1 or batch_idx % 10 == 0:
                        logger.info(
                            f"[{p_type}] 批次 {batch_idx}: 已扫描 {processed_rows}/{total_rows}, 已导入 {added_for_type}"
                        )
                logger.info(
                    f"{p_type} 向量处理完成:总扫描 {processed_rows},总导入 {added_for_type}"
                )
            except Exception as e:
                logger.error(f"处理 {p_path} 时出错: {e}")
        # Persist the vector store.
        self.vector_store.save()
        logger.info(f"向量转换完成。总向量数: {total_vectors}")

    def convert_graph(self):
        """Convert the LPMM NetworkX graph into the GraphStore."""
        # LPMM's default graph file name is rag-graph.graphml; fall back to
        # alternatives in priority order.
        graph_files = [
            self.lpmm_dir / "rag-graph.graphml",
            self.lpmm_dir / "graph.graphml",
            self.lpmm_dir / "graph_structure.pkl"
        ]
        nx_graph = None
        for g_path in graph_files:
            if g_path.exists():
                logger.info(f"发现图文件: {g_path}")
                try:
                    if g_path.suffix == ".graphml":
                        nx_graph = nx.read_graphml(g_path)
                    elif g_path.suffix == ".pkl":
                        # SECURITY NOTE: pickle.load can execute arbitrary code
                        # from an untrusted file — only convert trusted data.
                        with open(g_path, "rb") as f:
                            data = pickle.load(f)
                        # LPMM may store the graph inside a wrapper object.
                        if hasattr(data, "graph") and isinstance(data.graph, nx.Graph):
                            nx_graph = data.graph
                        elif isinstance(data, nx.Graph):
                            nx_graph = data
                    # NOTE(review): this break fires even if the pkl branch did
                    # not yield a usable graph (nx_graph stays None) — confirm
                    # that skipping later candidates in that case is intended.
                    break
                except Exception as e:
                    logger.error(f"加载 {g_path} 失败: {e}")
        if nx_graph is None:
            logger.warning("未找到有效的图文件。跳过图转换。")
            return
        logger.info(f"已加载图,包含 {nx_graph.number_of_nodes()} 个节点和 {nx_graph.number_of_edges()} 条边。")
        # 1. Add nodes.
        # LPMM node IDs are typically hashes or prefixed strings; map them to
        # the converted A_memorix IDs (unmapped IDs pass through unchanged, so
        # "entity-HASH"-style names already match the A_memorix convention).
        nodes_to_add = []
        node_attrs = {}
        for node, attrs in nx_graph.nodes(data=True):
            mapped_node = self._map_node_id(node)
            nodes_to_add.append(mapped_node)
            if attrs:
                node_attrs[mapped_node] = attrs
        self.graph_store.add_nodes(nodes_to_add, node_attrs)
        # 2. Add edges.
        edges_to_add = []
        weights = []
        for u, v, data in nx_graph.edges(data=True):
            weight = data.get("weight", 1.0)
            edges_to_add.append((self._map_node_id(u), self._map_node_id(v)))
            weights.append(float(weight))
            # Relations could be synced to the MetadataStore here, but graph
            # edges do not always carry a predicate; in LPMM's simple weighted
            # graphs that information is usually lost, so only weights are kept.
        if edges_to_add:
            # NOTE(review): weights passed positionally here, while the
            # relation import path calls add_edges(..., relation_hashes=...) —
            # confirm GraphStore.add_edges' second positional param is weights.
            self.graph_store.add_edges(edges_to_add, weights)
        self.graph_store.save()
        logger.info("图转换完成。")

    def run(self):
        """Execute the full conversion pipeline and persist all stores."""
        self.initialize_stores()
        self.convert_vectors()
        self.convert_graph()
        # Relation-vector backfill is async; no-op when disabled/unavailable.
        asyncio.run(self._rebuild_relation_vectors())
        # Redundant with the save in convert_vectors, but harmless.
        self.vector_store.save()
        self.graph_store.save()
        self.metadata_store.close()
        logger.info("所有转换成功完成。")
def main():
    """CLI entry point: parse arguments, validate input, run the converter."""
    args = _build_arg_parser().parse_args()

    input_path = Path(args.input)

    # Guard clause: refuse to run against a missing input directory.
    if not input_path.exists():
        logger.error(f"输入目录不存在: {input_path}")
        sys.exit(1)

    # The CLI flag is a "skip" switch; the converter takes the positive form.
    LPMMConverter(
        input_path,
        Path(args.output),
        dimension=args.dim,
        batch_size=args.batch_size,
        rebuild_relation_vectors=not bool(args.skip_relation_vector_rebuild),
    ).run()


if __name__ == "__main__":
    main()