mai-bot/plugins/A_memorix/scripts/migrate_maibot_memory.py
DawnARC 71b3a828c6 Add A_Memorix plugin v2.0.0 (runtime and documentation included)
Introduce A_Memorix plugin v2.0.0: adds a large set of runtime components, storage/schema updates, improved retrieval, admin tooling, import/tuning workflows, and related documentation. Key additions include lifecycle_orchestrator, SDKMemoryKernel/runtime initializers, a new storage layer with metadata_store changes (SCHEMA_VERSION v8), retrieval enhancements (dual-path retrieval, graph-relation recall, sparse BM25), and various tool services (episode/person_profile/relation/segmentation/tuning/search execution). Also adds web/summary importers and many maintenance scripts, and updates the plugin manifest, embedding API adapter, plugin.py, requirements/pyproject, and the main entry point to wire the new plugin into the project. This change prepares the 2.0.0 release, implementing a unified SDK Tool interface and extending overall runtime capability.
2026-03-19 00:09:04 +08:00

#!/usr/bin/env python3
"""
MaiBot memory migration script (chat_history -> A_memorix)
Features:
1. High performance: paged reads + batched embedding + batched writes
2. Resumable: window-based commits keyed on last_committed_id
3. Exactly-once semantics: stable hashing + idempotent writes + vector existence checks
4. Confirmable filtering: time ranges and chat-stream (stream/group/user) filters, with preview before confirmation
"""
from __future__ import annotations
import argparse
import asyncio
import hashlib
import importlib
import json
import logging
import os
import pickle
import sqlite3
import sys
import time
import traceback
import types
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Generator, Iterable, List, Optional, Sequence, Tuple
import numpy as np
import tomlkit
CURRENT_DIR = Path(__file__).resolve().parent
PLUGIN_ROOT = CURRENT_DIR.parent
WORKSPACE_ROOT = PLUGIN_ROOT.parent
MAIBOT_ROOT = WORKSPACE_ROOT / "MaiBot"
RUNTIME_CORE_PACKAGE = "_a_memorix_runtime_core"
VectorStore = None
GraphStore = None
MetadataStore = None
create_embedding_api_adapter = None
KnowledgeType = None
QuantizationType = None
SparseMatrixFormat = None
compute_hash = None
normalize_text = None
atomic_write = None
model_config = None
RelationWriteService = None
def _create_bootstrap_logger():
fallback = logging.getLogger("A_Memorix.MaiBotMigration")
if not fallback.handlers:
fallback.addHandler(logging.NullHandler())
try:
for path in (WORKSPACE_ROOT, MAIBOT_ROOT, PLUGIN_ROOT):
path_str = str(path)
if path_str not in sys.path:
sys.path.insert(0, path_str)
from src.common.logger import get_logger
return get_logger("A_Memorix.MaiBotMigration")
except Exception:
return fallback
logger = _create_bootstrap_logger()
def _ensure_import_paths() -> None:
for path in (WORKSPACE_ROOT, MAIBOT_ROOT, PLUGIN_ROOT):
path_str = str(path)
if path_str not in sys.path:
sys.path.insert(0, path_str)
def _ensure_runtime_core_package() -> str:
existing = sys.modules.get(RUNTIME_CORE_PACKAGE)
if existing is not None and hasattr(existing, "__path__"):
return RUNTIME_CORE_PACKAGE
pkg = types.ModuleType(RUNTIME_CORE_PACKAGE)
pkg.__path__ = [str(PLUGIN_ROOT / "core")]
pkg.__package__ = RUNTIME_CORE_PACKAGE
sys.modules[RUNTIME_CORE_PACKAGE] = pkg
return RUNTIME_CORE_PACKAGE
def _disable_unavailable_gemini_provider() -> None:
global model_config
try:
from google import genai # type: ignore # noqa: F401
return
except Exception:
pass
from src.config.config import model_config as loaded_model_config
providers = list(getattr(loaded_model_config, "api_providers", []))
if not providers:
model_config = loaded_model_config
return
kept_providers = [p for p in providers if str(getattr(p, "client_type", "")).lower() != "gemini"]
if len(kept_providers) == len(providers):
model_config = loaded_model_config
return
loaded_model_config.api_providers = kept_providers
loaded_model_config.api_providers_dict = {p.name: p for p in kept_providers}
models = list(getattr(loaded_model_config, "models", []))
kept_models = [m for m in models if m.api_provider in loaded_model_config.api_providers_dict]
loaded_model_config.models = kept_models
loaded_model_config.models_dict = {m.name: m for m in kept_models}
task_cfg = loaded_model_config.model_task_config
for field_name in task_cfg.__dataclass_fields__.keys():
task = getattr(task_cfg, field_name, None)
if task is None or not hasattr(task, "model_list"):
continue
task.model_list = [m for m in list(task.model_list) if m in loaded_model_config.models_dict]
model_config = loaded_model_config
logger.warning("检测到缺少 google.genai已临时禁用 gemini provider 以保证脚本可运行。")
def _bootstrap_runtime_symbols() -> None:
global VectorStore
global GraphStore
global MetadataStore
global KnowledgeType
global QuantizationType
global SparseMatrixFormat
global compute_hash
global normalize_text
global atomic_write
global RelationWriteService
global logger
if VectorStore is not None and compute_hash is not None and atomic_write is not None:
return
_ensure_import_paths()
import src # noqa: F401
from src.common.logger import get_logger
logger = get_logger("A_Memorix.MaiBotMigration")
pkg = _ensure_runtime_core_package()
vector_store_module = importlib.import_module(f"{pkg}.storage.vector_store")
graph_store_module = importlib.import_module(f"{pkg}.storage.graph_store")
metadata_store_module = importlib.import_module(f"{pkg}.storage.metadata_store")
knowledge_types_module = importlib.import_module(f"{pkg}.storage.knowledge_types")
hash_module = importlib.import_module(f"{pkg}.utils.hash")
io_module = importlib.import_module(f"{pkg}.utils.io")
relation_write_service_module = importlib.import_module(f"{pkg}.utils.relation_write_service")
VectorStore = vector_store_module.VectorStore
GraphStore = graph_store_module.GraphStore
MetadataStore = metadata_store_module.MetadataStore
KnowledgeType = knowledge_types_module.KnowledgeType
QuantizationType = vector_store_module.QuantizationType
SparseMatrixFormat = graph_store_module.SparseMatrixFormat
compute_hash = hash_module.compute_hash
normalize_text = hash_module.normalize_text
atomic_write = io_module.atomic_write
RelationWriteService = relation_write_service_module.RelationWriteService
def _load_embedding_adapter_factory() -> None:
global create_embedding_api_adapter
global model_config
if create_embedding_api_adapter is not None:
return
_ensure_import_paths()
from src.config.config import model_config as loaded_model_config
model_config = loaded_model_config
_disable_unavailable_gemini_provider()
pkg = _ensure_runtime_core_package()
api_adapter_module = importlib.import_module(f"{pkg}.embedding.api_adapter")
create_embedding_api_adapter = api_adapter_module.create_embedding_api_adapter
DEFAULT_SOURCE_DB = MAIBOT_ROOT / "data" / "MaiBot.db"
DEFAULT_TARGET_DATA_DIR = PLUGIN_ROOT / "data"
DEFAULT_CONFIG_PATH = PLUGIN_ROOT / "config.toml"
MIGRATION_STATE_DIRNAME = "migration_state"
STATE_FILENAME = "chat_history_resume.json"
BAD_ROWS_FILENAME = "chat_history_bad_rows.jsonl"
REPORT_FILENAME = "chat_history_report.json"
class MigrationError(Exception):
"""迁移流程错误。"""
@dataclass
class SelectionFilter:
time_from_ts: Optional[float]
time_to_ts: Optional[float]
stream_ids: List[str]
stream_filter_requested: bool
start_id: Optional[int]
end_id: Optional[int]
time_from_raw: Optional[str]
time_to_raw: Optional[str]
def fingerprint_payload(self) -> Dict[str, Any]:
return {
"time_from_ts": self.time_from_ts,
"time_to_ts": self.time_to_ts,
"time_from_raw": self.time_from_raw,
"time_to_raw": self.time_to_raw,
"stream_ids": sorted(self.stream_ids),
"stream_filter_requested": self.stream_filter_requested,
"start_id": self.start_id,
"end_id": self.end_id,
}
@dataclass
class PreviewResult:
total: int
distribution: List[Tuple[str, int]]
samples: List[Dict[str, Any]]
@dataclass
class MappedRow:
row_id: int
chat_id: str
paragraph_hash: str
content: str
source: str
time_meta: Dict[str, Any]
entities: List[str]
relations: List[Tuple[str, str, str]]
existing_paragraph_vector: bool
def _safe_int(value: Any, default: int) -> int:
try:
return int(value)
except Exception:
return default
def _safe_float(value: Any, default: float) -> float:
try:
return float(value)
except Exception:
return default
def _normalize_name(value: Any) -> str:
return str(value or "").strip()
def _canonical_name(value: Any) -> str:
return _normalize_name(value).lower()
def _dedup_keep_order(items: Iterable[str]) -> List[str]:
out: List[str] = []
seen: set[str] = set()
for raw in items:
v = _normalize_name(raw)
if not v:
continue
k = v.lower()
if k in seen:
continue
seen.add(k)
out.append(v)
return out
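# For example (a sketch of the semantics above):
#   _dedup_keep_order(["Alice", "alice", " Bob ", ""]) -> ["Alice", "Bob"]
# The first spelling wins; comparison is case-insensitive after trimming.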
def _format_ts(ts: Optional[float]) -> str:
if ts is None:
return "-"
try:
return datetime.fromtimestamp(float(ts)).strftime("%Y-%m-%d %H:%M:%S")
except Exception:
return str(ts)
def _parse_cli_datetime(text: str, is_end: bool = False) -> float:
value = str(text or "").strip()
if not value:
raise ValueError("时间不能为空")
formats = [
("%Y-%m-%d %H:%M:%S", False),
("%Y/%m/%d %H:%M:%S", False),
("%Y-%m-%d %H:%M", False),
("%Y/%m/%d %H:%M", False),
("%Y-%m-%d", True),
("%Y/%m/%d", True),
]
for fmt, is_date_only in formats:
try:
dt = datetime.strptime(value, fmt)
if is_date_only and is_end:
dt = dt.replace(hour=23, minute=59, second=59, microsecond=0)
return dt.timestamp()
except ValueError:
continue
raise ValueError(
f"时间格式错误: {value},仅支持 YYYY-MM-DD、YYYY/MM/DD、YYYY-MM-DD HH:mm[:ss]、YYYY/MM/DD HH:mm[:ss]"
)
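# Example of the end-of-day expansion above: with is_end=True a date-only input such as
# "2025-01-31" resolves to 2025-01-31 23:59:59, while is_end=False leaves it at 00:00:00.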
def _json_hash(payload: Dict[str, Any]) -> str:
data = json.dumps(payload, ensure_ascii=False, sort_keys=True)
return hashlib.sha1(data.encode("utf-8")).hexdigest()
def _deep_merge_dict(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
out = dict(base)
for key, value in override.items():
if isinstance(value, dict) and isinstance(out.get(key), dict):
out[key] = _deep_merge_dict(out[key], value)
else:
out[key] = value
return out
def _extract_schema_defaults(schema_obj: Dict[str, Any]) -> Dict[str, Any]:
defaults: Dict[str, Any] = {}
if not isinstance(schema_obj, dict):
return defaults
for key, spec in schema_obj.items():
if not isinstance(spec, dict):
continue
if "default" in spec:
defaults[key] = spec.get("default")
continue
props = spec.get("properties")
if isinstance(props, dict):
defaults[key] = _extract_schema_defaults(props)
return defaults
def _load_manifest_defaults() -> Dict[str, Any]:
manifest_path = PLUGIN_ROOT / "_manifest.json"
if not manifest_path.exists():
return {}
try:
with open(manifest_path, "r", encoding="utf-8") as f:
payload = json.load(f)
schema = payload.get("config_schema")
if isinstance(schema, dict):
return _extract_schema_defaults(schema)
except Exception as e:
logger.warning(f"读取 manifest 默认配置失败,已回退空配置: {e}")
return {}
def _build_source_db_fingerprint(db_path: Path) -> Dict[str, Any]:
stat = db_path.stat()
payload = {
"path": str(db_path.resolve()),
"size": stat.st_size,
"mtime": stat.st_mtime,
}
payload["sha1"] = _json_hash(payload)
return payload
def _state_path(target_data_dir: Path) -> Path:
return target_data_dir / MIGRATION_STATE_DIRNAME / STATE_FILENAME
def _bad_rows_path(target_data_dir: Path) -> Path:
return target_data_dir / MIGRATION_STATE_DIRNAME / BAD_ROWS_FILENAME
def _report_path(target_data_dir: Path) -> Path:
return target_data_dir / MIGRATION_STATE_DIRNAME / REPORT_FILENAME
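# State and report files are persisted via write-to-temp + fsync + os.replace, so a crash
# mid-write can never leave a truncated JSON behind; the runtime's atomic_write helper is
# preferred when it has been bootstrapped, with an equivalent local fallback otherwise.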
def _dump_json_atomic(path: Path, payload: Dict[str, Any]) -> None:
if atomic_write is None:
path.parent.mkdir(parents=True, exist_ok=True)
tmp = path.with_suffix(path.suffix + ".tmp")
with open(tmp, "w", encoding="utf-8") as f:
json.dump(payload, f, ensure_ascii=False, indent=2)
f.write("\n")
f.flush()
os.fsync(f.fileno())
os.replace(tmp, path)
return
with atomic_write(path, mode="w", encoding="utf-8") as f:
json.dump(payload, f, ensure_ascii=False, indent=2)
f.write("\n")
class SourceDB:
def __init__(self, db_path: Path):
self.db_path = db_path
self.conn: Optional[sqlite3.Connection] = None
def connect(self) -> None:
if not self.db_path.exists():
raise MigrationError(f"源数据库不存在: {self.db_path}")
uri = f"file:{self.db_path.resolve().as_posix()}?mode=ro"
try:
self.conn = sqlite3.connect(uri, uri=True, check_same_thread=False)
except sqlite3.OperationalError:
self.conn = sqlite3.connect(str(self.db_path.resolve()), check_same_thread=False)
self.conn.row_factory = sqlite3.Row
pragmas = [
"PRAGMA query_only = ON",
"PRAGMA cache_size = -128000",
"PRAGMA temp_store = MEMORY",
"PRAGMA synchronous = OFF",
"PRAGMA journal_mode = WAL",
]
for sql in pragmas:
try:
self.conn.execute(sql)
except sqlite3.OperationalError:
                # Some PRAGMAs fail under mode=ro; that does not affect read-only scanning
continue
def close(self) -> None:
if self.conn is not None:
self.conn.close()
self.conn = None
def _require_conn(self) -> sqlite3.Connection:
if self.conn is None:
raise MigrationError("源数据库尚未连接")
return self.conn
def resolve_stream_ids(
self,
stream_ids: Sequence[str],
group_ids: Sequence[str],
user_ids: Sequence[str],
) -> List[str]:
conn = self._require_conn()
resolved: set[str] = set(_normalize_name(x) for x in stream_ids if _normalize_name(x))
has_group_or_user = any(_normalize_name(x) for x in group_ids) or any(_normalize_name(x) for x in user_ids)
if not has_group_or_user:
return sorted(resolved)
table_exists = conn.execute(
"SELECT 1 FROM sqlite_master WHERE type='table' AND name='chat_streams' LIMIT 1"
).fetchone()
if table_exists is None:
raise MigrationError("源库缺少 chat_streams 表,无法根据 --group-id/--user-id 映射 stream_id")
def _select_by_field(field: str, values: Sequence[str]) -> None:
values_norm = [_normalize_name(v) for v in values if _normalize_name(v)]
if not values_norm:
return
placeholders = ",".join("?" for _ in values_norm)
sql = f"SELECT DISTINCT stream_id FROM chat_streams WHERE {field} IN ({placeholders})"
cur = conn.execute(sql, tuple(values_norm))
for row in cur.fetchall():
sid = _normalize_name(row["stream_id"])
if sid:
resolved.add(sid)
_select_by_field("group_id", group_ids)
_select_by_field("user_id", user_ids)
return sorted(resolved)
@staticmethod
def _build_where(
selection: SelectionFilter,
start_after_id: Optional[int] = None,
) -> Tuple[str, List[Any]]:
conditions: List[str] = []
params: List[Any] = []
if selection.start_id is not None:
conditions.append("id >= ?")
params.append(selection.start_id)
if selection.end_id is not None:
conditions.append("id <= ?")
params.append(selection.end_id)
if start_after_id is not None:
conditions.append("id > ?")
params.append(start_after_id)
if selection.stream_ids:
placeholders = ",".join("?" for _ in selection.stream_ids)
conditions.append(f"chat_id IN ({placeholders})")
params.extend(selection.stream_ids)
elif selection.stream_filter_requested:
conditions.append("1=0")
if selection.time_from_ts is not None and selection.time_to_ts is not None:
conditions.append("(end_time >= ? AND start_time <= ?)")
params.extend([selection.time_from_ts, selection.time_to_ts])
elif selection.time_from_ts is not None:
conditions.append("(end_time >= ?)")
params.append(selection.time_from_ts)
elif selection.time_to_ts is not None:
conditions.append("(start_time <= ?)")
params.append(selection.time_to_ts)
where_sql = "WHERE " + " AND ".join(conditions) if conditions else ""
return where_sql, params
def count_candidates(self, selection: SelectionFilter) -> int:
conn = self._require_conn()
where_sql, params = self._build_where(selection, start_after_id=None)
sql = f"SELECT COUNT(*) AS c FROM chat_history {where_sql}"
cur = conn.execute(sql, tuple(params))
return int(cur.fetchone()["c"])
def preview(self, selection: SelectionFilter, preview_limit: int) -> PreviewResult:
conn = self._require_conn()
where_sql, params = self._build_where(selection, start_after_id=None)
total_sql = f"SELECT COUNT(*) AS c FROM chat_history {where_sql}"
total = int(conn.execute(total_sql, tuple(params)).fetchone()["c"])
dist_sql = (
f"SELECT chat_id, COUNT(*) AS c FROM chat_history {where_sql} "
"GROUP BY chat_id ORDER BY c DESC LIMIT 30"
)
distribution = [
(_normalize_name(row["chat_id"]), int(row["c"]))
for row in conn.execute(dist_sql, tuple(params)).fetchall()
]
sample_sql = (
"SELECT id, chat_id, start_time, end_time, theme, summary "
f"FROM chat_history {where_sql} ORDER BY id ASC LIMIT ?"
)
sample_params = list(params)
sample_params.append(max(1, int(preview_limit)))
samples = [dict(row) for row in conn.execute(sample_sql, tuple(sample_params)).fetchall()]
return PreviewResult(total=total, distribution=distribution, samples=samples)
def iter_rows(
self,
selection: SelectionFilter,
batch_size: int,
start_after_id: int,
) -> Generator[List[sqlite3.Row], None, None]:
conn = self._require_conn()
cursor = int(start_after_id)
while True:
where_sql, params = self._build_where(selection, start_after_id=cursor)
sql = (
"SELECT id, chat_id, start_time, end_time, participants, theme, keywords, summary "
f"FROM chat_history {where_sql} ORDER BY id ASC LIMIT ?"
)
bind = list(params)
bind.append(max(1, int(batch_size)))
rows = conn.execute(sql, tuple(bind)).fetchall()
if not rows:
break
yield rows
cursor = int(rows[-1]["id"])
def sample_rows_for_verify(
self,
selection: SelectionFilter,
sample_size: int,
) -> List[sqlite3.Row]:
conn = self._require_conn()
where_sql, params = self._build_where(selection, start_after_id=None)
sql = (
"SELECT id, chat_id, start_time, end_time, participants, theme, keywords, summary "
f"FROM chat_history {where_sql} ORDER BY RANDOM() LIMIT ?"
)
bind = list(params)
bind.append(max(1, int(sample_size)))
return conn.execute(sql, tuple(bind)).fetchall()
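# For reference, with one stream_id and a closed time window, iter_rows() pages with
# keyset pagination roughly like this (a sketch of the generated SQL, not emitted verbatim):
#
#   SELECT id, chat_id, start_time, end_time, participants, theme, keywords, summary
#   FROM chat_history
#   WHERE id > :last_seen_id
#     AND chat_id IN (:stream_id)
#     AND (end_time >= :time_from AND start_time <= :time_to)
#   ORDER BY id ASC LIMIT :read_batch_size
#
# Seeking on id instead of using OFFSET keeps each page O(batch) no matter how deep the scan goes.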
class MigrationRunner:
def __init__(self, args: argparse.Namespace):
self.args = args
self.source_db_path = Path(args.source_db).resolve()
self.target_data_dir = Path(args.target_data_dir).resolve()
self.state_file = _state_path(self.target_data_dir)
self.bad_rows_file = _bad_rows_path(self.target_data_dir)
self.report_file = _report_path(self.target_data_dir)
self.source_db = SourceDB(self.source_db_path)
self.vector_store = None
self.graph_store = None
self.metadata_store = None
self.embedding_manager = None
self.relation_write_service = None
self.plugin_config: Dict[str, Any] = {}
self.embed_workers: int = 5
self.selection: Optional[SelectionFilter] = None
self.filter_fingerprint: str = ""
self.source_db_fingerprint: Dict[str, Any] = {}
self.source_db_fingerprint_hash: str = ""
self.state: Dict[str, Any] = {}
self.started_at = time.time()
self.exit_code = 0
self.failed = False
self.fail_reason: Optional[str] = None
self.stats: Dict[str, Any] = {
"source_matched_total": 0,
"scanned_rows": 0,
"valid_rows": 0,
"migrated_rows": 0,
"skipped_existing_rows": 0,
"bad_rows": 0,
"paragraph_vectors_added": 0,
"entity_vectors_added": 0,
"relations_written": 0,
"relation_vectors_written": 0,
"relation_vectors_failed": 0,
"relation_vectors_skipped": 0,
"graph_edges_written": 0,
"windows_committed": 0,
"last_committed_id": 0,
"verify_sample_size": 0,
"verify_paragraph_missing": 0,
"verify_vector_missing": 0,
"verify_relation_missing": 0,
"verify_edge_missing": 0,
"verify_passed": False,
}
async def run(self) -> int:
try:
_bootstrap_runtime_symbols()
self._prepare_paths()
self.source_db.connect()
self.selection = self._build_selection_filter()
self.filter_fingerprint = _json_hash(self.selection.fingerprint_payload())
self.source_db_fingerprint = _build_source_db_fingerprint(self.source_db_path)
self.source_db_fingerprint_hash = str(self.source_db_fingerprint.get("sha1", ""))
preview = self.source_db.preview(self.selection, preview_limit=self.args.preview_limit)
self.stats["source_matched_total"] = int(preview.total)
self._print_preview(preview)
if preview.total <= 0:
logger.info("筛选后无数据,退出。")
self.stats["verify_passed"] = True
if self.args.verify_only:
self._load_plugin_config()
await self._init_target_stores(require_embedding=False)
await self._verify(strict=True)
return self._finalize()
if self.args.verify_only:
self._load_plugin_config()
await self._init_target_stores(require_embedding=False)
await self._verify(strict=True)
return self._finalize()
if self.args.dry_run:
logger.info("dry-run 模式:仅预览,不写入。")
return self._finalize()
if not self.args.yes:
if not self._confirm():
logger.info("用户取消执行。")
return self._finalize()
self._load_plugin_config()
await self._init_target_stores(require_embedding=True)
self._load_or_init_state()
start_after_id = self._resolve_start_after_id()
await self._migrate(start_after_id=start_after_id)
await self._verify(strict=True)
return self._finalize()
except Exception as e:
self.failed = True
self.fail_reason = str(e)
logger.error(f"迁移失败: {e}\n{traceback.format_exc()}")
return self._finalize()
finally:
self._close()
def _prepare_paths(self) -> None:
(self.target_data_dir / MIGRATION_STATE_DIRNAME).mkdir(parents=True, exist_ok=True)
if self.args.reset_state and self.state_file.exists():
self.state_file.unlink()
if self.args.reset_state and self.bad_rows_file.exists():
self.bad_rows_file.unlink()
def _load_plugin_config(self) -> None:
merged = _load_manifest_defaults()
config_path = DEFAULT_CONFIG_PATH
if config_path.exists():
try:
with open(config_path, "r", encoding="utf-8") as f:
raw = tomlkit.load(f)
if isinstance(raw, dict):
merged = _deep_merge_dict(merged, dict(raw))
except Exception as e:
logger.warning(f"读取插件配置失败,继续使用默认配置: {e}")
self.plugin_config = merged
def _read_existing_vector_dimension(self, fallback_dimension: int) -> int:
meta_path = self.target_data_dir / "vectors" / "vectors_metadata.pkl"
if not meta_path.exists():
return fallback_dimension
try:
with open(meta_path, "rb") as f:
payload = pickle.load(f)
value = _safe_int(payload.get("dimension"), fallback_dimension)
return max(1, value)
except Exception:
return fallback_dimension
async def _init_target_stores(self, require_embedding: bool) -> None:
if VectorStore is None or GraphStore is None or MetadataStore is None:
raise MigrationError("运行时初始化失败:存储组件不可用")
emb_cfg = self.plugin_config.get("embedding", {}) if isinstance(self.plugin_config, dict) else {}
graph_cfg = self.plugin_config.get("graph", {}) if isinstance(self.plugin_config, dict) else {}
self.embed_workers = max(1, _safe_int(self.args.embed_workers, _safe_int(emb_cfg.get("max_concurrent"), 5)))
emb_batch_size = max(1, _safe_int(emb_cfg.get("batch_size"), 32))
emb_default_dim = max(1, _safe_int(emb_cfg.get("dimension"), 1024))
emb_model_name = str(emb_cfg.get("model_name", "auto"))
emb_retry = emb_cfg.get("retry", {}) if isinstance(emb_cfg.get("retry", {}), dict) else {}
if require_embedding:
_load_embedding_adapter_factory()
if create_embedding_api_adapter is None:
raise MigrationError("运行时初始化失败embedding 适配器不可用")
if model_config is not None:
embedding_task = getattr(getattr(model_config, "model_task_config", None), "embedding", None)
if embedding_task is not None and hasattr(embedding_task, "model_list"):
if not list(embedding_task.model_list):
raise MigrationError(
"当前配置没有可用 embedding 模型。若你使用 gemini provider请先安装 `google-genai` "
"或切换到可用的 embedding provider。"
)
self.embedding_manager = create_embedding_api_adapter(
batch_size=emb_batch_size,
max_concurrent=self.embed_workers,
default_dimension=emb_default_dim,
model_name=emb_model_name,
retry_config=emb_retry,
)
try:
detected_dim = self._read_existing_vector_dimension(emb_default_dim)
has_existing_vectors = (self.target_data_dir / "vectors" / "vectors_metadata.pkl").exists()
if not has_existing_vectors:
detected_dim = await self.embedding_manager._detect_dimension()
except Exception as e:
logger.warning(f"嵌入维度探测失败,回退配置维度: {e}")
detected_dim = self._read_existing_vector_dimension(emb_default_dim)
else:
detected_dim = self._read_existing_vector_dimension(emb_default_dim)
self.embedding_manager = None
q_type = str(emb_cfg.get("quantization_type", "int8")).lower()
if q_type != "int8":
raise MigrationError(
"embedding.quantization_type 在 vNext 仅允许 int8(SQ8)。"
" 请先执行 scripts/release_vnext_migrate.py migrate。"
)
quantization = QuantizationType.INT8
matrix_fmt = str(graph_cfg.get("sparse_matrix_format", "csr")).lower()
fmt_map = {
"csr": SparseMatrixFormat.CSR,
"csc": SparseMatrixFormat.CSC,
}
sparse_fmt = fmt_map.get(matrix_fmt, SparseMatrixFormat.CSR)
self.vector_store = VectorStore(
dimension=detected_dim,
quantization_type=quantization,
data_dir=self.target_data_dir / "vectors",
)
self.graph_store = GraphStore(
matrix_format=sparse_fmt,
data_dir=self.target_data_dir / "graph",
)
self.metadata_store = MetadataStore(data_dir=self.target_data_dir / "metadata")
self.metadata_store.connect()
if self.vector_store.has_data():
self.vector_store.load()
if self.graph_store.has_data():
self.graph_store.load()
self.relation_write_service = None
if require_embedding and RelationWriteService is not None and self.embedding_manager is not None:
self.relation_write_service = RelationWriteService(
metadata_store=self.metadata_store,
graph_store=self.graph_store,
vector_store=self.vector_store,
embedding_manager=self.embedding_manager,
)
logger.info(
f"目标存储初始化完成: dim={self.vector_store.dimension}, quant={q_type}, graph_fmt={matrix_fmt}, "
f"embed_workers={self.embed_workers}"
)
def _should_write_relation_vectors(self) -> bool:
retrieval_cfg = self.plugin_config.get("retrieval", {}) if isinstance(self.plugin_config, dict) else {}
if not isinstance(retrieval_cfg, dict):
return False
rv_cfg = retrieval_cfg.get("relation_vectorization", {})
if not isinstance(rv_cfg, dict):
return False
return bool(rv_cfg.get("enabled", False)) and bool(rv_cfg.get("write_on_import", True))
async def _ensure_relation_vectors_for_records(
self,
relation_records: Dict[str, Tuple[str, str, str, float, Optional[str], bytes]],
) -> None:
if not relation_records:
return
if self.relation_write_service is None:
return
success = 0
failed = 0
skipped = 0
for relation_hash, rel in relation_records.items():
result = await self.relation_write_service.ensure_relation_vector(
hash_value=relation_hash,
subject=str(rel[0]),
predicate=str(rel[1]),
obj=str(rel[2]),
)
if result.vector_state == "ready":
if result.vector_written:
success += 1
else:
skipped += 1
else:
failed += 1
self.stats["relation_vectors_written"] += success
self.stats["relation_vectors_failed"] += failed
self.stats["relation_vectors_skipped"] += skipped
def _build_selection_filter(self) -> SelectionFilter:
        if self.args.start_id is not None and self.args.start_id <= 0:
            raise MigrationError("--start-id must be > 0")
        if self.args.end_id is not None and self.args.end_id <= 0:
            raise MigrationError("--end-id must be > 0")
        if self.args.start_id is not None and self.args.end_id is not None and self.args.start_id > self.args.end_id:
            raise MigrationError("--start-id must not be greater than --end-id")
time_from_ts = _parse_cli_datetime(self.args.time_from, is_end=False) if self.args.time_from else None
time_to_ts = _parse_cli_datetime(self.args.time_to, is_end=True) if self.args.time_to else None
if time_from_ts is not None and time_to_ts is not None and time_from_ts > time_to_ts:
raise MigrationError("--time-from 不能晚于 --time-to")
stream_filter_requested = bool(
(self.args.stream_id or []) or (self.args.group_id or []) or (self.args.user_id or [])
)
stream_ids = self.source_db.resolve_stream_ids(
stream_ids=self.args.stream_id or [],
group_ids=self.args.group_id or [],
user_ids=self.args.user_id or [],
)
if stream_filter_requested and not stream_ids:
logger.warning("已指定 stream/group/user 筛选,但未解析到任何 stream_id结果将为空。")
logger.info(
f"筛选条件: time_from={self.args.time_from or '-'}, time_to={self.args.time_to or '-'}, "
f"stream_ids={len(stream_ids)}, stream_filter_requested={stream_filter_requested}"
)
return SelectionFilter(
time_from_ts=time_from_ts,
time_to_ts=time_to_ts,
stream_ids=stream_ids,
stream_filter_requested=stream_filter_requested,
start_id=self.args.start_id,
end_id=self.args.end_id,
time_from_raw=self.args.time_from,
time_to_raw=self.args.time_to,
)
def _load_or_init_state(self) -> None:
if self.args.start_id is not None:
logger.info("检测到 --start-id已按用户指定起点覆盖断点状态。")
self.state = self._new_state(last_committed_id=int(self.args.start_id) - 1)
return
if self.args.no_resume:
self.state = self._new_state(last_committed_id=0)
return
if not self.state_file.exists():
self.state = self._new_state(last_committed_id=0)
return
with open(self.state_file, "r", encoding="utf-8") as f:
loaded = json.load(f)
loaded_filter_fp = str(loaded.get("filter_fingerprint", ""))
loaded_source_fp = str(loaded.get("source_db_fingerprint", ""))
if loaded_filter_fp != self.filter_fingerprint or loaded_source_fp != self.source_db_fingerprint_hash:
if self.args.dry_run or self.args.verify_only:
logger.info("检测到断点与当前筛选不一致;当前为只读模式,将忽略旧断点。")
self.state = self._new_state(last_committed_id=0)
return
raise MigrationError(
"检测到筛选条件或源库指纹变化,已拒绝继续续传。请使用 --reset-state 或调整参数后重试。"
)
self.state = loaded
stored_stats = loaded.get("stats", {})
if isinstance(stored_stats, dict):
for k, v in stored_stats.items():
if k in self.stats and isinstance(v, (int, float, bool)):
self.stats[k] = v
def _new_state(self, last_committed_id: int) -> Dict[str, Any]:
return {
"version": 1,
"updated_at": time.time(),
"last_committed_id": int(last_committed_id),
"filter_fingerprint": self.filter_fingerprint,
"source_db_fingerprint": self.source_db_fingerprint_hash,
"source_db_meta": self.source_db_fingerprint,
"stats": dict(self.stats),
}
def _flush_state(self, last_committed_id: int) -> None:
self.stats["last_committed_id"] = int(last_committed_id)
self.state = {
"version": 1,
"updated_at": time.time(),
"last_committed_id": int(last_committed_id),
"filter_fingerprint": self.filter_fingerprint,
"source_db_fingerprint": self.source_db_fingerprint_hash,
"source_db_meta": self.source_db_fingerprint,
"stats": dict(self.stats),
}
_dump_json_atomic(self.state_file, self.state)
def _resolve_start_after_id(self) -> int:
if self.selection is None:
raise MigrationError("selection 未初始化")
if self.args.start_id is not None:
return int(self.args.start_id) - 1
if self.args.no_resume:
return 0
state_last = _safe_int(self.state.get("last_committed_id"), 0) if self.state else 0
return max(0, state_last)
def _print_preview(self, preview: PreviewResult) -> None:
print("\n=== Migration Preview ===")
print(f"source_db: {self.source_db_path}")
print(f"target_data_dir: {self.target_data_dir}")
if self.selection:
print(
f"time_window: [{self.selection.time_from_raw or '-'} ~ {self.selection.time_to_raw or '-'}] "
f"(ts: {_format_ts(self.selection.time_from_ts)} ~ {_format_ts(self.selection.time_to_ts)})"
)
print(
f"id_window: [{self.selection.start_id or '-'} ~ {self.selection.end_id or '-'}], "
f"selected_streams={len(self.selection.stream_ids)}"
)
print(f"matched_rows: {preview.total}")
if preview.distribution:
print("top_chat_distribution:")
for cid, cnt in preview.distribution[:10]:
print(f" - {cid}: {cnt}")
else:
print("top_chat_distribution: (none)")
if preview.samples:
print(f"samples (first {len(preview.samples)}):")
for row in preview.samples:
summary_preview = _normalize_name(row.get("summary", ""))[:60]
theme_preview = _normalize_name(row.get("theme", ""))[:30]
print(
f" - id={row.get('id')} chat_id={row.get('chat_id')} "
f"[{_format_ts(row.get('start_time'))} ~ {_format_ts(row.get('end_time'))}] "
f"theme={theme_preview!r} summary={summary_preview!r}"
)
print("=========================\n")
def _confirm(self) -> bool:
        answer = input("Proceed with the migration using the filters above? Enter y to continue [y/N]: ").strip().lower()
return answer in {"y", "yes"}
def _parse_json_list_field(self, raw: Any, field_name: str, row_id: int) -> List[str]:
if raw is None:
return []
if isinstance(raw, list):
data = raw
elif isinstance(raw, str):
try:
parsed = json.loads(raw)
except Exception as e:
raise ValueError(f"{field_name} JSON 解析失败: {e}") from e
if not isinstance(parsed, list):
raise ValueError(f"{field_name} JSON 必须是 list当前为 {type(parsed).__name__}")
data = parsed
else:
raise ValueError(f"{field_name} 字段类型不支持: {type(raw).__name__}")
return _dedup_keep_order(str(x) for x in data if _normalize_name(x))
def _map_row(self, row: sqlite3.Row) -> MappedRow:
row_id = int(row["id"])
chat_id = _normalize_name(row["chat_id"])
theme = _normalize_name(row["theme"])
summary = _normalize_name(row["summary"])
participants = self._parse_json_list_field(row["participants"], "participants", row_id)
keywords = self._parse_json_list_field(row["keywords"], "keywords", row_id)
keywords_top = keywords[:8]
        # The join separator below was evidently stripped by the page's Unicode escaping;
        # restored as the ideographic comma, matching the Chinese content labels that follow.
        participants_text = "、".join(participants) if participants else ""
        keywords_text = "、".join(keywords_top) if keywords_top else ""
content = (
f"话题:{theme}\n"
f"概括:{summary}\n"
f"参与者:{participants_text}\n"
f"关键词:{keywords_text}"
).strip()
paragraph_hash = compute_hash(normalize_text(content))
source = f"maibot.chat_history:{chat_id}"
start_time = _safe_float(row["start_time"], 0.0)
end_time = _safe_float(row["end_time"], start_time)
time_meta = {
"event_time_start": start_time,
"event_time_end": end_time,
"time_granularity": "minute",
"time_confidence": 0.95,
}
entities = _dedup_keep_order([*participants, theme, *keywords_top])
relations: List[Tuple[str, str, str]] = []
if theme:
for participant in participants:
relations.append((participant, "参与话题", theme))
for keyword in keywords_top:
relations.append((theme, "关键词", keyword))
existing_vector = paragraph_hash in self.vector_store
return MappedRow(
row_id=row_id,
chat_id=chat_id,
paragraph_hash=paragraph_hash,
content=content,
source=source,
time_meta=time_meta,
entities=entities,
relations=relations,
existing_paragraph_vector=existing_vector,
)
def _append_bad_row(self, row: sqlite3.Row, reason: str) -> None:
payload = {
"id": int(row["id"]),
"chat_id": _normalize_name(row["chat_id"]),
"start_time": row["start_time"],
"end_time": row["end_time"],
"participants": row["participants"],
"theme": _normalize_name(row["theme"]),
"keywords": row["keywords"],
"summary": row["summary"],
"error": reason,
"timestamp": time.time(),
}
self.bad_rows_file.parent.mkdir(parents=True, exist_ok=True)
with open(self.bad_rows_file, "a", encoding="utf-8") as f:
f.write(json.dumps(payload, ensure_ascii=False))
f.write("\n")
async def _migrate(self, start_after_id: int) -> None:
if self.selection is None:
raise MigrationError("selection 未初始化")
read_batch_size = max(1, int(self.args.read_batch_size))
commit_window_rows = max(1, int(self.args.commit_window_rows))
log_every = max(1, int(self.args.log_every))
window_rows: List[MappedRow] = []
window_scanned = 0
last_seen_id = start_after_id
logger.info(
f"开始迁移: start_after_id={start_after_id}, read_batch_size={read_batch_size}, "
f"commit_window_rows={commit_window_rows}"
)
for batch in self.source_db.iter_rows(self.selection, read_batch_size, start_after_id):
for row in batch:
row_id = int(row["id"])
last_seen_id = row_id
self.stats["scanned_rows"] += 1
window_scanned += 1
try:
mapped = self._map_row(row)
except Exception as e:
self.stats["bad_rows"] += 1
self._append_bad_row(row, str(e))
if self.stats["bad_rows"] > int(self.args.max_errors):
raise MigrationError(
f"坏行数量超过上限 max_errors={self.args.max_errors},已中止。"
)
continue
self.stats["valid_rows"] += 1
if mapped.existing_paragraph_vector:
self.stats["skipped_existing_rows"] += 1
else:
self.stats["migrated_rows"] += 1
window_rows.append(mapped)
if window_scanned >= commit_window_rows:
await self._commit_window(window_rows, last_seen_id)
window_rows = []
window_scanned = 0
if self.stats["scanned_rows"] % log_every == 0:
logger.info(
f"迁移进度: scanned={self.stats['scanned_rows']}/{self.stats['source_matched_total']}, "
f"valid={self.stats['valid_rows']}, bad={self.stats['bad_rows']}, "
f"last_id={last_seen_id}"
)
if window_scanned > 0 or window_rows:
await self._commit_window(window_rows, last_seen_id)
logger.info(
f"迁移主流程完成: scanned={self.stats['scanned_rows']}, valid={self.stats['valid_rows']}, "
f"bad={self.stats['bad_rows']}, last_committed_id={self.stats['last_committed_id']}"
)
async def _commit_window(self, rows: List[MappedRow], last_seen_id: int) -> None:
if not rows:
self._flush_state(last_seen_id)
self.stats["windows_committed"] += 1
return
now_ts = time.time()
empty_meta_blob = pickle.dumps({})
conn = self.metadata_store.get_connection()
cursor = conn.cursor()
        # Batch-query paragraphs that already exist in this window so entity/mention counts are not
        # double-accumulated on re-runs (chunks of 800 stay below SQLite's bound-parameter limit)
existing_paragraph_hashes: set[str] = set()
all_hashes = [item.paragraph_hash for item in rows]
for i in range(0, len(all_hashes), 800):
batch_hashes = all_hashes[i : i + 800]
if not batch_hashes:
continue
placeholders = ",".join("?" for _ in batch_hashes)
existing_rows = cursor.execute(
f"SELECT hash FROM paragraphs WHERE hash IN ({placeholders})",
tuple(batch_hashes),
).fetchall()
for row in existing_rows:
existing_paragraph_hashes.add(str(row["hash"]))
paragraph_records: List[Tuple[Any, ...]] = []
paragraph_embed_map: Dict[str, str] = {}
entity_display: Dict[str, str] = {}
entity_counts: Dict[str, int] = defaultdict(int)
paragraph_entity_mentions: Dict[Tuple[str, str], int] = defaultdict(int)
entity_embed_map: Dict[str, str] = {}
relation_records: Dict[str, Tuple[str, str, str, float, Optional[str], bytes]] = {}
paragraph_relation_links: set[Tuple[str, str]] = set()
for item in rows:
is_new_paragraph = item.paragraph_hash not in existing_paragraph_hashes
start_ts = _safe_float(item.time_meta.get("event_time_start"), 0.0)
end_ts = _safe_float(item.time_meta.get("event_time_end"), start_ts)
confidence = _safe_float(item.time_meta.get("time_confidence"), 0.95)
granularity = _normalize_name(item.time_meta.get("time_granularity")) or "minute"
if is_new_paragraph:
paragraph_records.append(
(
item.paragraph_hash,
item.content,
None,
now_ts,
now_ts,
empty_meta_blob,
item.source,
len(normalize_text(item.content).split()),
None,
start_ts,
end_ts,
granularity,
confidence,
KnowledgeType.NARRATIVE.value,
)
)
if item.paragraph_hash not in self.vector_store:
paragraph_embed_map[item.paragraph_hash] = item.content
for entity in item.entities:
name = _normalize_name(entity)
if not name:
continue
canon = _canonical_name(name)
if not canon:
continue
entity_hash = compute_hash(canon)
entity_display.setdefault(entity_hash, name)
if is_new_paragraph:
entity_counts[entity_hash] += 1
paragraph_entity_mentions[(item.paragraph_hash, entity_hash)] += 1
if entity_hash not in self.vector_store:
entity_embed_map.setdefault(entity_hash, name)
for subject, predicate, obj in item.relations:
s = _normalize_name(subject)
p = _normalize_name(predicate)
o = _normalize_name(obj)
if not (s and p and o):
continue
s_canon = _canonical_name(s)
p_canon = _canonical_name(p)
o_canon = _canonical_name(o)
relation_hash = compute_hash(f"{s_canon}|{p_canon}|{o_canon}")
if is_new_paragraph:
relation_records.setdefault(
relation_hash,
(s, p, o, 1.0, item.paragraph_hash, empty_meta_blob),
)
paragraph_relation_links.add((item.paragraph_hash, relation_hash))
for relation_entity in (s, o):
e_canon = _canonical_name(relation_entity)
if not e_canon:
continue
e_hash = compute_hash(e_canon)
entity_display.setdefault(e_hash, relation_entity)
if is_new_paragraph:
entity_counts[e_hash] += 1
paragraph_entity_mentions[(item.paragraph_hash, e_hash)] += 1
if e_hash not in self.vector_store:
entity_embed_map.setdefault(e_hash, relation_entity)
try:
cursor.execute("BEGIN")
if paragraph_records:
cursor.executemany(
"""
INSERT OR IGNORE INTO paragraphs
(
hash, content, vector_index, created_at, updated_at, metadata, source, word_count,
event_time, event_time_start, event_time_end, time_granularity, time_confidence, knowledge_type
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
paragraph_records,
)
if entity_counts:
entity_rows = [
(
entity_hash,
entity_display[entity_hash],
None,
int(count),
now_ts,
empty_meta_blob,
)
for entity_hash, count in entity_counts.items()
]
try:
cursor.executemany(
"""
INSERT INTO entities
(hash, name, vector_index, appearance_count, created_at, metadata)
VALUES (?, ?, ?, ?, ?, ?)
ON CONFLICT(hash) DO UPDATE SET
appearance_count = entities.appearance_count + excluded.appearance_count
""",
entity_rows,
)
except sqlite3.OperationalError:
cursor.executemany(
"""
INSERT OR IGNORE INTO entities
(hash, name, vector_index, appearance_count, created_at, metadata)
VALUES (?, ?, ?, ?, ?, ?)
""",
entity_rows,
)
cursor.executemany(
"UPDATE entities SET appearance_count = appearance_count + ? WHERE hash = ?",
[(int(count), entity_hash) for entity_hash, count in entity_counts.items()],
)
if paragraph_entity_mentions:
pe_rows = [
(paragraph_hash, entity_hash, int(mentions))
for (paragraph_hash, entity_hash), mentions in paragraph_entity_mentions.items()
]
try:
cursor.executemany(
"""
INSERT INTO paragraph_entities
(paragraph_hash, entity_hash, mention_count)
VALUES (?, ?, ?)
ON CONFLICT(paragraph_hash, entity_hash) DO UPDATE SET
mention_count = paragraph_entities.mention_count + excluded.mention_count
""",
pe_rows,
)
except sqlite3.OperationalError:
cursor.executemany(
"""
INSERT OR IGNORE INTO paragraph_entities
(paragraph_hash, entity_hash, mention_count)
VALUES (?, ?, ?)
""",
pe_rows,
)
cursor.executemany(
"""
UPDATE paragraph_entities
SET mention_count = mention_count + ?
WHERE paragraph_hash = ? AND entity_hash = ?
""",
[(m, p, e) for (p, e, m) in pe_rows],
)
if relation_records:
relation_rows = [
(
relation_hash,
rel[0],
rel[1],
rel[2],
None,
rel[3],
now_ts,
rel[4],
rel[5],
)
for relation_hash, rel in relation_records.items()
]
cursor.executemany(
"""
INSERT OR IGNORE INTO relations
(hash, subject, predicate, object, vector_index, confidence, created_at, source_paragraph, metadata)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
relation_rows,
)
if paragraph_relation_links:
pr_rows = [(p_hash, r_hash) for p_hash, r_hash in paragraph_relation_links]
cursor.executemany(
"""
INSERT OR IGNORE INTO paragraph_relations
(paragraph_hash, relation_hash)
VALUES (?, ?)
""",
pr_rows,
)
conn.commit()
except Exception:
conn.rollback()
raise
self.stats["relations_written"] += len(relation_records)
if relation_records:
edge_pairs = []
relation_hashes = []
for relation_hash, rel in relation_records.items():
edge_pairs.append((rel[0], rel[2]))
relation_hashes.append(relation_hash)
with self.graph_store.batch_update():
self.graph_store.add_edges(edge_pairs, relation_hashes=relation_hashes)
self.stats["graph_edges_written"] += len(edge_pairs)
if self._should_write_relation_vectors():
await self._ensure_relation_vectors_for_records(relation_records)
para_added = await self._embed_and_add_vectors(
id_to_text=paragraph_embed_map,
batch_size=max(1, int(self.args.embed_batch_size)),
workers=self.embed_workers,
)
ent_added = await self._embed_and_add_vectors(
id_to_text=entity_embed_map,
batch_size=max(1, int(self.args.entity_embed_batch_size)),
workers=self.embed_workers,
)
self.stats["paragraph_vectors_added"] += para_added
self.stats["entity_vectors_added"] += ent_added
self.vector_store.save()
self.graph_store.save()
self.stats["windows_committed"] += 1
self._flush_state(last_seen_id)
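    # Commit-ordering note: within a window, SQLite metadata is committed first, then graph
    # edges and vectors are written and saved, and only then is the resume state flushed.
    # After a crash the current window is replayed from the last flushed state; the
    # INSERT OR IGNORE / upsert statements and the vector-existence checks keep that replay idempotent.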
async def _embed_and_add_vectors(
self,
id_to_text: Dict[str, str],
batch_size: int,
workers: int,
) -> int:
if not id_to_text:
return 0
if self.embedding_manager is None:
raise MigrationError("embedding_manager 未初始化,无法写入向量")
ids = []
texts = []
for hash_id, text in id_to_text.items():
if hash_id in self.vector_store:
continue
ids.append(hash_id)
texts.append(text)
if not ids:
return 0
total_added = 0
chunk_size = max(1, int(batch_size))
for i in range(0, len(ids), chunk_size):
chunk_ids = ids[i : i + chunk_size]
chunk_texts = texts[i : i + chunk_size]
embeddings = await self.embedding_manager.encode_batch(
chunk_texts,
batch_size=chunk_size,
num_workers=max(1, int(workers)),
)
emb_arr = np.asarray(embeddings, dtype=np.float32)
if emb_arr.ndim == 1:
emb_arr = emb_arr.reshape(1, -1)
if emb_arr.shape[0] != len(chunk_ids):
logger.warning(
f"embedding 返回数量异常: expected={len(chunk_ids)}, got={emb_arr.shape[0]},跳过该批次"
)
continue
valid_vectors = []
valid_ids = []
for idx, vec in enumerate(emb_arr):
if vec.ndim != 1:
continue
if vec.shape[0] != self.vector_store.dimension:
logger.warning(
f"向量维度不匹配,跳过: id={chunk_ids[idx]}, got={vec.shape[0]}, expected={self.vector_store.dimension}"
)
continue
if not np.all(np.isfinite(vec)):
logger.warning(f"向量含 NaN/Inf跳过: id={chunk_ids[idx]}")
continue
if chunk_ids[idx] in self.vector_store:
continue
valid_vectors.append(vec)
valid_ids.append(chunk_ids[idx])
if valid_vectors:
batch_vectors = np.stack(valid_vectors).astype(np.float32, copy=False)
added = self.vector_store.add(batch_vectors, valid_ids)
total_added += int(added)
return total_added
async def _verify(self, strict: bool) -> None:
if self.selection is None:
raise MigrationError("selection 未初始化")
sample_size = min(2000, max(0, int(self.stats.get("source_matched_total", 0))))
self.stats["verify_sample_size"] = sample_size
if sample_size <= 0:
self.stats["verify_passed"] = True
return
sample_rows = self.source_db.sample_rows_for_verify(self.selection, sample_size)
para_missing = 0
vec_missing = 0
rel_missing = 0
edge_missing = 0
for row in sample_rows:
try:
mapped = self._map_row(row)
except Exception:
continue
paragraph = self.metadata_store.get_paragraph(mapped.paragraph_hash)
if paragraph is None:
para_missing += 1
if mapped.paragraph_hash not in self.vector_store:
vec_missing += 1
for s, p, o in mapped.relations:
relation_hash = compute_hash(f"{_canonical_name(s)}|{_canonical_name(p)}|{_canonical_name(o)}")
relation = self.metadata_store.get_relation(relation_hash)
if relation is None:
rel_missing += 1
if self.graph_store.get_edge_weight(s, o) <= 0.0:
edge_missing += 1
self.stats["verify_paragraph_missing"] = para_missing
self.stats["verify_vector_missing"] = vec_missing
self.stats["verify_relation_missing"] = rel_missing
self.stats["verify_edge_missing"] = edge_missing
verify_passed = all(x == 0 for x in [para_missing, vec_missing, rel_missing, edge_missing])
if strict and not verify_passed:
self.failed = True
self.fail_reason = (
"严格校验失败: "
f"paragraph_missing={para_missing}, vector_missing={vec_missing}, "
f"relation_missing={rel_missing}, edge_missing={edge_missing}"
)
self.stats["verify_passed"] = verify_passed
def _finalize(self) -> int:
elapsed = time.time() - self.started_at
self.stats["elapsed_seconds"] = elapsed
report = {
"success": not self.failed,
"fail_reason": self.fail_reason,
"args": vars(self.args),
"source_db": str(self.source_db_path),
"target_data_dir": str(self.target_data_dir),
"selection": self.selection.fingerprint_payload() if self.selection else {},
"filter_fingerprint": self.filter_fingerprint,
"source_db_fingerprint": self.source_db_fingerprint,
"state_file": str(self.state_file),
"bad_rows_file": str(self.bad_rows_file),
"stats": dict(self.stats),
"timestamp": time.time(),
}
_dump_json_atomic(self.report_file, report)
if self.failed:
self.exit_code = 1
elif self.stats.get("bad_rows", 0) > 0:
self.exit_code = 2
else:
self.exit_code = 0
print("\n=== Migration Report ===")
print(f"success: {not self.failed}")
if self.fail_reason:
print(f"fail_reason: {self.fail_reason}")
print(f"elapsed: {elapsed:.2f}s")
print(f"source_matched_total: {self.stats['source_matched_total']}")
print(f"scanned_rows: {self.stats['scanned_rows']}")
print(f"valid_rows: {self.stats['valid_rows']}")
print(f"migrated_rows: {self.stats['migrated_rows']}")
print(f"skipped_existing_rows: {self.stats['skipped_existing_rows']}")
print(f"bad_rows: {self.stats['bad_rows']}")
print(f"paragraph_vectors_added: {self.stats['paragraph_vectors_added']}")
print(f"entity_vectors_added: {self.stats['entity_vectors_added']}")
print(f"relations_written: {self.stats['relations_written']}")
print(
"relation_vectors: "
f"written={self.stats['relation_vectors_written']}, "
f"failed={self.stats['relation_vectors_failed']}, "
f"skipped={self.stats['relation_vectors_skipped']}"
)
print(f"graph_edges_written: {self.stats['graph_edges_written']}")
print(f"windows_committed: {self.stats['windows_committed']}")
print(f"last_committed_id: {self.stats['last_committed_id']}")
print(
"verify: "
f"sample={self.stats['verify_sample_size']}, "
f"paragraph_missing={self.stats['verify_paragraph_missing']}, "
f"vector_missing={self.stats['verify_vector_missing']}, "
f"relation_missing={self.stats['verify_relation_missing']}, "
f"edge_missing={self.stats['verify_edge_missing']}, "
f"passed={self.stats['verify_passed']}"
)
print(f"report_file: {self.report_file}")
print("========================\n")
return self.exit_code
def _close(self) -> None:
try:
if self.metadata_store is not None:
self.metadata_store.close()
except Exception:
pass
self.source_db.close()
def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(
        description="Migrate MaiBot chat_history into A_memorix (high performance + resumable + confirmable filtering)"
    )
    parser.add_argument("--source-db", default=str(DEFAULT_SOURCE_DB), help="Source database path (default data/MaiBot.db)")
    parser.add_argument(
        "--target-data-dir",
        default=str(DEFAULT_TARGET_DATA_DIR),
        help="A_memorix data directory (default plugins/A_memorix/data)",
    )
    resume_group = parser.add_mutually_exclusive_group()
    resume_group.add_argument("--resume", dest="no_resume", action="store_false", help="Enable resume from checkpoint (default)")
    resume_group.add_argument("--no-resume", dest="no_resume", action="store_true", help="Disable resume from checkpoint")
    parser.set_defaults(no_resume=False)
    parser.add_argument("--reset-state", action="store_true", help="Clear migration state files before running")
    parser.add_argument("--start-id", type=int, default=None, help="Start migrating from this chat_history.id (overrides the checkpoint)")
    parser.add_argument("--end-id", type=int, default=None, help="Migrate up to this chat_history.id")
    parser.add_argument("--read-batch-size", type=int, default=2000, help="Source DB page size (default 2000)")
    parser.add_argument("--commit-window-rows", type=int, default=20000, help="Rows committed per window (default 20000)")
    parser.add_argument("--embed-batch-size", type=int, default=256, help="Paragraph embedding batch size (default 256)")
    parser.add_argument(
        "--entity-embed-batch-size",
        type=int,
        default=512,
        help="Entity embedding batch size (default 512)",
    )
    parser.add_argument("--embed-workers", type=int, default=None, help="Embedding concurrency (defaults to config)")
    parser.add_argument("--max-errors", type=int, default=500, help="Bad-row limit (default 500)")
    parser.add_argument("--log-every", type=int, default=5000, help="Log progress every N rows (default 5000)")
    parser.add_argument("--dry-run", action="store_true", help="Preview only, write nothing")
    parser.add_argument("--verify-only", action="store_true", help="Run strict verification only")
    parser.add_argument("--time-from", default=None, help="Start time (YYYY-MM-DD / YYYY/MM/DD / YYYY-MM-DD HH:mm[:ss])")
    parser.add_argument("--time-to", default=None, help="End time (YYYY-MM-DD / YYYY/MM/DD / YYYY-MM-DD HH:mm[:ss])")
    parser.add_argument("--stream-id", action="append", default=[], help="Chat stream stream_id (repeatable)")
    parser.add_argument("--group-id", action="append", default=[], help="Group id (repeatable, auto-mapped to stream_id)")
    parser.add_argument("--user-id", action="append", default=[], help="User id (repeatable, auto-mapped to stream_id)")
    parser.add_argument("--yes", action="store_true", help="Skip interactive confirmation")
    parser.add_argument("--preview-limit", type=int, default=20, help="Number of preview samples (default 20)")
return parser
async def async_main() -> int:
parser = build_parser()
args = parser.parse_args()
runner = MigrationRunner(args)
return await runner.run()
def main() -> int:
if sys.platform == "win32":
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
return asyncio.run(async_main())
if __name__ == "__main__":
raise SystemExit(main())