diff --git a/plugins/A_memorix/CHANGELOG.md b/plugins/A_memorix/CHANGELOG.md
new file mode 100644
index 00000000..772cff46
--- /dev/null
+++ b/plugins/A_memorix/CHANGELOG.md
@@ -0,0 +1,718 @@
+# Changelog
+
+## [2.0.0] - 2026-03-18
+
+`2.0.0` is an architecture-convergence release. Its main lines are **unifying the SDK Tool interface**, **filling out the admin tooling**, **upgrading the metadata schema to v8**, and **syncing the docs to the 2.0.0 baseline**.
+
+### 🔖 Version info
+
+- Plugin version: `1.0.1` → `2.0.0`
+- Metadata schema: `7` → `8`
+
+### 🚀 Highlights
+
+- Unified Tool interface:
+  - `plugin.py` now exposes all Tool capabilities through a single `SDKMemoryKernel`.
+  - Basic tools retained: `search_memory / ingest_summary / ingest_text / get_person_profile / maintain_memory / memory_stats`.
+  - New admin tools: `memory_graph_admin / memory_source_admin / memory_episode_admin / memory_profile_admin / memory_runtime_admin / memory_import_admin / memory_tuning_admin / memory_v5_admin / memory_delete_admin`.
+- Stronger retrieval and write governance:
+  - The retrieval/write paths support chat-filter semantics via `respect_filter + user_id/group_id`.
+  - `maintain_memory` supports `freeze` and `recycle_bin`, unified into the kernel maintenance flow.
+- Import and tuning consolidated:
+  - `memory_import_admin` provides task-based imports (upload, paste, scan, OpenIE, LPMM conversion, temporal backfill, MaiBot migration).
+  - `memory_tuning_admin` provides retrieval-tuning tasks (create, inspect rounds, rollback, apply_best, report export).
+- V5 and deletion operations:
+  - New `memory_v5_admin` (`reinforce/weaken/remember_forever/forget/restore/status`).
+  - New `memory_delete_admin` (`preview/execute/restore/list/get/purge`) with operation auditing and recovery.
+
+### 🛠️ Storage and runtime
+
+- `metadata_store` upgraded to `SCHEMA_VERSION = 8`.
+- Added/completed external-reference and operations-record structures (including `external_memory_refs`, `memory_v5_operations`, and `delete_operations`).
+- `SDKMemoryKernel` gained unified background-task orchestration (auto save, Episode pending processing, profile refresh, memory maintenance).
+
+### 📚 Docs sync
+
+- `README.md`, `QUICK_START.md`, `CONFIG_REFERENCE.md`, and `IMPORT_GUIDE.md` now follow the `2.0.0` baseline.
+- The documentation's main entry point is the SDK Tool workflow; the legacy slash commands are no longer the primary path.
+
+## [1.0.1] - 2026-03-07
+
+`1.0.1` is a hotfix release after `1.0.0`, focused on **graph WebUI data-loading stability**, **large-graph filtering performance**, and **stability of the real retrieval-tuning pipeline**.
+
+### 🔖 Version info
+
+- Plugin version: `1.0.0` → `1.0.1`
+- Config version: `4.1.0` (unchanged)
+
+### 🛠️ Code fixes
+
+- Graph API stability:
+  - Fixed `/api/graph` returning an empty graph when the graph file existed on disk but had not yet been loaded into memory; the endpoint now auto-loads persisted graph data.
+  - Fixed the WebUI graph page appearing to have "no nodes at all" on problematic datasets; the root cause was not missing graph data but an overly slow backend filtering path.
+- Graph filtering performance:
+  - Optimized the leaf filtering in `/api/graph?exclude_leaf=true` to precompute hub adjacency instead of repeating expensive edge-weight queries per node.
+  - Optimized `GraphStore.get_neighbors()` and added incoming-neighbor access, avoiding the dense-matrix expansion that degraded large graphs.
+- Retrieval-tuning stability:
+  - Fixed real tuning tasks deep-copying `plugin.config` when building the runtime config, which copied injected storage instances and triggered `cannot pickle '_thread.RLock' object`.
+  - Tuning evaluation now skips top-level runtime-instance keys, keeping only plain config fields before reattaching runtime dependencies, so real WebUI tuning tasks start normally.
+
+### 📚 Docs sync
+
+- Synced `README.md`, `CHANGELOG.md`, `CONFIG_REFERENCE.md`, and the version metadata (`plugin.py`, `__init__.py`, `_manifest.json`).
+- README added `v1.0.1` fix notes plus the recommendation to run a runtime self-check before tuning.
+
+## [1.0.0] - 2026-03-06
+
+`1.0.0` is a major release. Its main lines are **modularizing the runtime architecture**, **closing the Episode (episodic memory) loop**, **aggregate retrieval and graph-recall enhancements**, and **offline migration / runtime self-check / a retrieval-tuning center**.
+
+### 🔖 Version info
+
+- Plugin version: `0.7.0` → `1.0.0`
+- Config version: `4.1.0` (unchanged)
+
+### 🚀 Highlights
+
+- Runtime refactor:
+  - `plugin.py` slimmed down substantially; lifecycle, background tasks, request routing, and retrieval-runtime initialization moved into `core/runtime/*`.
+  - Config schema extracted into `core/config/plugin_config_schema.py`; `_manifest.json` extended with the new config keys.
+- Retrieval and query enhancements:
+  - `KnowledgeQueryTool` split into query modes plus an orchestrator; added `aggregate` / `episode` query modes.
+  - Added graph-assisted relation recall, unified forward/runtime construction, and request-dedup bridging.
+- Episode / operations:
+  - `metadata_store` schema upgraded to `SCHEMA_VERSION = 7`, with new `episodes` / `episode_paragraphs` / rebuild-queue structures.
+  - New `release_vnext_migrate.py`, `runtime_self_check.py`, `rebuild_episodes.py`, and the web tuning page `web/tuning.html`.
+
+### 📚 Docs sync
+
+- Version numbers synced across `plugin.py`, `__init__.py`, `_manifest.json`, `README.md`, and `CONFIG_REFERENCE.md`.
+- Added `RELEASE_SUMMARY_1.0.0.md`.
+
+## [0.7.0] - 2026-03-04
+
+`0.7.0` is a minor release. Its main lines are **a closed loop for relation vectorization (write + state machine + backfill + audit)**, **retrieval/command pipeline enhancements**, and **rounding out the import task capabilities**.
+
+### 🔖 Version info
+
+- Plugin version: `0.6.1` → `0.7.0`
+- Config version: `4.1.0` (unchanged)
+
+### 🚀 Highlights
+
+- Relation vectorization loop:
+  - New unified relation write service `RelationWriteService` (metadata written first, vectors second; failures enter a state machine instead of rolling back the primary data; see the sketch after this entry).
+  - The `relations` side gained `vector_state/retry_count/last_error/updated_at` state fields, governed uniformly across `none/pending/ready/failed`.
+  - The plugin added a background backfill loop and statistics endpoints that continuously repair missing relation vectors and expose coverage metrics.
+- Retrieval and command pipeline:
+  - The main retrieval path keeps converging on the `search/time` forward routes; `legacy` remains only as a compatibility alias.
+  - Relation query-spec parsing consolidated, with a clearer boundary between structured queries and the semantic fallback.
+  - `/query stats` and tool stats now include relation-vectorization statistics.
+- Import and operations:
+  - Web Import gained a `temporal_backfill` task entry and its orchestration.
+  - New consistency-audit and offline-backfill scripts support gradual repair of historical data.
+
+### 📚 Docs sync
+
+- Synced `README.md`, `CONFIG_REFERENCE.md`, and the version info in this log.
+- `README.md` added usage notes for the relation-vector audit/backfill scripts and updated the description of `convert_lpmm.py`'s relation-vector rebuild behavior.
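+A minimal sketch of the "metadata first, vector second" write flow behind `RelationWriteService`. The store and index interfaces below are illustrative assumptions; only the state names and the field names (`vector_state`, `retry_count`, `last_error`, `updated_at`) come from the entry above.
+
+```python
+import time
+from dataclasses import dataclass, field
+
+
+@dataclass
+class RelationRecord:
+    relation_hash: str
+    vector_state: str = "none"   # none / pending / ready / failed
+    retry_count: int = 0
+    last_error: str | None = None
+    updated_at: float = field(default_factory=time.time)
+
+
+def write_relation(store, vector_index, record: RelationRecord, text: str) -> None:
+    """Metadata is written first; a vector failure moves the record into
+    the state machine instead of rolling back the primary data."""
+    store.upsert(record)                      # step 1: metadata is the source of truth
+    record.vector_state = "pending"
+    try:
+        vector_index.add(record.relation_hash, text)
+        record.vector_state = "ready"
+    except Exception as exc:                  # step 2: the vector write may fail independently
+        record.vector_state = "failed"
+        record.retry_count += 1
+        record.last_error = str(exc)
+    record.updated_at = time.time()
+    store.upsert(record)  # the background backfill loop later retries failed/pending rows
+```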
+## [0.6.1] - 2026-03-03
+
+`0.6.1` is a hotfix release that fixes a `tomlkit` node-serialization compatibility issue in the WebUI plugin-config API for the A_Memorix case.
+
+### 🔖 Version info
+
+- Plugin version: `0.6.0` → `0.6.1`
+- Config version: `4.1.0` (unchanged)
+
+### 🛠️ Code fixes
+
+- New runtime patch `_patch_webui_a_memorix_routes_for_tomlkit_serialization()`:
+  - Wraps only the `GET` routes of `/api/webui/plugins/config/{plugin_id}` and its schema.
+  - Only when `plugin_id == "A_Memorix"` does it convert the returned `config/schema` to native data via `to_builtin_data`.
+  - Leaves the global `/api/webui/config/*` endpoints untouched, avoiding side effects on other plugins or core config paths.
+- The patch runs at plugin initialization, so the structures the WebUI reads for plugin config serialize reliably.
+
+### 📚 Docs sync
+
+- Synced the version info and fix notes in `README.md`, `CONFIG_REFERENCE.md`, and this log.
+
+## [0.6.0] - 2026-03-02
+
+`0.6.0` is a minor release. Its main lines are **launching the Web Import center and aligning it with the script capabilities**, **an upgraded failure-retry mechanism**, **manifest sync after deletion**, and **import-pipeline stability**.
+
+### 🔖 Version info
+
+- Plugin version: `0.5.1` → `0.6.0`
+- Config version: `4.0.1` → `4.1.0`
+
+### 🚀 Highlights
+
+- New Web Import center (`/import`):
+  - Upload / paste / local scan / LPMM OpenIE / LPMM conversion / temporal backfill / MaiBot migration.
+  - Three-level status display (task/file/chunk) with cancellation and failure retry.
+  - Import docs readable in a dialog (remote first, local fallback on failure).
+- Failure retry upgraded to "chunk first, file fallback":
+  - `POST /api/import/tasks/{task_id}/retry_failed` keeps its path; the semantics are upgraded.
+  - Supports subset retries of chunks that failed during `extracting`.
+  - `writing`/JSON-parse failures automatically fall back to file-level retry.
+- Manifest invalidation after deletion:
+  - Covers `/api/source/batch_delete` and `/api/source`.
+  - Returns `manifest_cleanup` details, so dedup no longer falsely skips re-imports.
+
+### 📂 Changed files (this release)
+
+Added:
+
+- `core/utils/web_import_manager.py`
+- `scripts/migrate_maibot_memory.py`
+- `web/import.html`
+
+Modified:
+
+- `CHANGELOG.md`
+- `CONFIG_REFERENCE.md`
+- `IMPORT_GUIDE.md`
+- `QUICK_START.md`
+- `README.md`
+- `__init__.py`
+- `_manifest.json`
+- `components/commands/debug_server_command.py`
+- `core/embedding/api_adapter.py`
+- `core/storage/graph_store.py`
+- `core/utils/summary_importer.py`
+- `plugin.py`
+- `requirements.txt`
+- `server.py`
+- `web/index.html`
+
+Removed:
+
+- None
+
+### 📚 Docs sync
+
+- Synced `README.md`, `QUICK_START.md`, `CONFIG_REFERENCE.md`, `IMPORT_GUIDE.md`, and this log.
+- `IMPORT_GUIDE.md` gained a dedicated "Web Import center" section covering capabilities, status semantics, and security boundaries.
+
+## [0.5.1] - 2026-02-23
+
+`0.5.1` is a hotfix release focused on "background tasks starting with the host application", "empty-list filter semantics", and "knowledge-extraction model selection".
+
+### 🔖 Version info
+
+- Plugin version: `0.5.0` → `0.5.1`
+- Config version: `4.0.0` → `4.0.1`
+
+### 🛠️ Code fixes
+
+- Lifecycle hooked into host events:
+  - New `a_memorix_start_handler` (`ON_START`) calls `plugin.on_enable()`;
+  - New `a_memorix_stop_handler` (`ON_STOP`) calls `plugin.on_disable()`;
+  - Fixes scheduled import tasks never starting when the plugin was registered but the lifecycle never fired.
+- Empty-list chat-filter policy:
+  - `whitelist + []`: reject all;
+  - `blacklist + []`: allow all.
+- Knowledge-extraction model selection (`import_command._select_model`):
+  - `advanced.extraction_model` now supports three semantics: task name / model name / `auto`;
+  - `auto` prefers extraction-related tasks (`lpmm_entity_extract`, `lpmm_rdf_build`, etc.) and never falls through to `embedding`;
+  - Unrecognized values log a warning and fall back to automatic selection, making model choice during import predictable (see the sketch below).
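+A hedged sketch of the three-way selection semantics above. The selection rules and the task names `lpmm_entity_extract` / `lpmm_rdf_build` come from this entry; the function shape and the `tasks` mapping are illustrative.
+
+```python
+import logging
+
+logger = logging.getLogger("a_memorix.import")
+
+EXTRACTION_TASKS = ("lpmm_entity_extract", "lpmm_rdf_build")  # preferred by "auto"
+
+
+def select_model(configured: str, tasks: dict[str, str]) -> str:
+    """Resolve `advanced.extraction_model`; `tasks` maps task name -> model name."""
+    if configured and configured != "auto":
+        if configured in tasks:           # semantics 1: a task name
+            return tasks[configured]
+        if configured in tasks.values():  # semantics 2: a model name
+            return configured
+        logger.warning("unrecognized extraction_model %r, falling back to auto", configured)
+    for task in EXTRACTION_TASKS:         # semantics 3: "auto" prefers extraction tasks
+        if task in tasks:
+            return tasks[task]
+    candidates = {t: m for t, m in tasks.items() if t != "embedding"}  # never land on embedding
+    if not candidates:
+        raise ValueError("no non-embedding task available for extraction")
+    return next(iter(candidates.values()))
+```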
+### 📚 Docs sync
+
+- Synced `README.md`, `CONFIG_REFERENCE.md`, and `CHANGELOG.md`.
+- Corrected the documented empty-list filter behavior to match the current code.
+
+## [0.5.0] - 2026-02-15
+
+`0.5.0` centers on commit `66ddc1b98547df3c866b19a3f5dc96e1c8eb7731`; its main line is "person profiles + tool/command integration + version and docs sync".
+
+### 🔖 Version info
+
+- Plugin version: `0.4.0` → `0.5.0`
+- Config version: `3.1.0` → `4.0.0`
+
+### 🚀 Person profiles (core feature)
+
+- New person-profile service: `core/utils/person_profile_service.py`
+  - Resolves `person_id/name/alias`.
+  - Aggregates graph-relation evidence plus vector evidence, generates profile text, and versions the snapshots.
+  - Supports manual overrides and TTL-based snapshot reuse.
+- New person-profile tables and APIs in the storage layer: `core/storage/metadata_store.py`
+  - `person_profile_switches`
+  - `person_profile_snapshots`
+  - `person_profile_active_persons`
+  - `person_profile_overrides`
+- New command: `/person_profile on|off|status`
+  - File: `components/commands/person_profile_command.py`
+  - Purpose: per-`stream_id + user_id` control of the auto-injection switch (opt-in).
+- Query-pipeline integration:
+  - `knowledge_query_tool` gained `query_type=person`, supporting queries by `person_id` or alias.
+  - `/query person` and `/query p` output profile queries.
+- Plugin-lifecycle integration:
+  - Start/stop uniformly manage the `person_profile_refresh` background task.
+  - Profile snapshots refresh automatically within the active window.
+
+### 🛠️ Version and schema sync
+
+- `plugin.py`: `plugin_version` updated to `0.5.0`.
+- `plugin.py`: default `plugin.config_version` updated to `4.0.0`.
+- `config.toml`: `config_version` baseline synced to `4.0.0` (local config file).
+- `__init__.py`: `__version__` updated to `0.5.0`.
+- `_manifest.json`: `version` updated to `0.5.0`; `manifest_version` stays `1`.
+- `manifest_utils.py`: the repo already accepts newer manifest versions, but releases keep `manifest_version=1` by default.
+
+### 📚 Docs sync
+
+- Updated `README.md`, `CONFIG_REFERENCE.md`, `QUICK_START.md`, and `USAGE_ARCHITECTURE.md`.
+- The 0.5.0 docs now center on "person profiles + version upgrade + retrieval-pipeline notes".
+
+## [0.4.0] - 2026-02-13
+
+`0.4.0` bundles the temporal-retrieval work with follow-up retrieval-pipeline enhancements, stability fixes, and docs sync.
+
+### 🔖 Version info
+
+- Plugin version: `0.3.3` → `0.4.0`
+- Config version: `3.0.0` → `3.1.0`
+
+### 🚀 Added
+
+- New `core/retrieval/sparse_bm25.py`
+  - `SparseBM25Config` / `SparseBM25Index`
+  - FTS5 + BM25 sparse retrieval
+  - `jieba/mixed/char_2gram` tokenization with lazy loading
+  - ngram inverted-index fallback, with an optional LIKE last resort
+- `DualPathRetriever` gained sparse/fusion config injection:
+  - automatic sparse fallback when embeddings are unavailable;
+  - `hybrid` mode runs vector and sparse candidate paths in parallel;
+  - new `FusionConfig` with `weighted_rrf` fusion.
+- `MetadataStore` gained FTS/inverted-index capabilities:
+  - `paragraphs_fts` and `relations_fts` schema plus backfill;
+  - `paragraph_ngrams` inverted index plus backfill;
+  - `fts_search_bm25` / `fts_search_relations_bm25` / `ngram_search_paragraphs`.
+
+### 🛠️ Component pipeline sync
+
+- `plugin.py`
+  - New `[retrieval.sparse]` and `[retrieval.fusion]` default config;
+  - Initializes `sparse_index` and injects it into components;
+  - `on_disable` can, per config, unload the sparse connection and release caches.
+- `knowledge_search_action.py` / `query_command.py` / `knowledge_query_tool.py`
+  - Uniform sparse/fusion config wiring;
+  - Uniform `sparse_index` injection;
+  - `stats` output gained sparse-state observability.
+- `requirements.txt`
+  - Added `jieba>=0.42.1` (falls back to char n-grams when not installed).
+
+### 🧯 Fixes and behavior changes
+
+- Fixed `retrieval.ppr_concurrency_limit` having no effect:
+  - `DualPathRetriever` initializes `_ppr_semaphore` from the config value instead of a hard-coded constant.
+- Fixed missing recall in the `char_2gram` case:
+  - On FTS miss, `_fallback_substring_search` tries the ngram inverted index first, with an optional LIKE fallback per config.
+- Observability and compatibility:
+  - `get_statistics()` reads the vector-count field tolerantly (`size -> num_vectors -> 0`), avoiding exceptions on missing attributes.
+  - `/query stats` and `knowledge_query` output include sparse state (enabled/loaded/tokenizer/doc_count).
+
+### 📚 Docs
+
+- `README.md`
+  - Added notes on the retrieval enhancements, sparse behavior, and the temporal-backfill script entry point.
+- `CONFIG_REFERENCE.md`
+  - Filled in sparse/fusion parameters, trigger rules, fallback chains, and fusion implementation details.
+
+### ⏱️ Temporal retrieval and import
+
+#### Temporal retrieval (minute granularity)
+
+- New unified temporal query entry points:
+  - `/query time` (alias `/query t`)
+  - `knowledge_query(query_type=time)`
+  - `knowledge_search(query_type=time|hybrid)`
+- Uniform time-parameter formats:
+  - `YYYY/MM/DD`
+  - `YYYY/MM/DD HH:mm`
+- Date-only parameters auto-expand to day boundaries:
+  - `from/time_from` -> `00:00`
+  - `to/time_to` -> `23:59`
+- Query results uniformly return `metadata.time_meta`, containing the matched time window and the hit basis (event time, or the `created_at` fallback).
+
+#### Storage and retrieval pipeline
+
+- The paragraph store supports temporal fields:
+  - `event_time`
+  - `event_time_start`
+  - `event_time_end`
+  - `time_granularity`
+  - `time_confidence`
+- Temporal hits use interval intersection under a two-layer time semantic (see the sketch after this list):
+  - prefer `event_time/event_time_range`
+  - fall back to `created_at` when missing (can be disabled in config)
+- Result ordering is unchanged: semantic relevance first, then time (newest first).
+- `process_knowledge.py` gained a `--chat-log` flag:
+  - forces the `narrative` strategy;
+  - uses an LLM to extract semantic times from chat text (including relative-to-absolute conversion), writing `event_time/event_time_start/event_time_end`.
+  - New `--chat-reference-time` sets the reference point for resolving relative-time expressions.
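+A minimal sketch of the interval-intersection hit test under the two-layer time semantics above. The field names follow the paragraph schema listed earlier; the record shape and the helper itself are illustrative.
+
+```python
+from datetime import datetime
+
+
+def time_window_hit(para: dict, win_from: datetime, win_to: datetime,
+                    created_at_fallback: bool = True) -> str | None:
+    """Return the hit basis ('event_time' or 'created_at'), or None for a miss."""
+    start = para.get("event_time_start") or para.get("event_time")
+    end = para.get("event_time_end") or para.get("event_time")
+    if start and end:
+        # the intervals [start, end] and [win_from, win_to] intersect
+        return "event_time" if start <= win_to and end >= win_from else None
+    created = para.get("created_at")
+    if created_at_fallback and created and win_from <= created <= win_to:
+        return "created_at"   # fallback layer; can be disabled in config
+    return None
+```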
+#### Schema and docs sync
+
+- `_manifest.json` gained the `retrieval.temporal` config schema.
+- Config schema bumped: `config_version` raised from `3.0.0` to `3.1.0` (synced in `plugin.py` / `config.toml` / the config docs).
+- Updated `README.md`, `CONFIG_REFERENCE.md`, and `IMPORT_GUIDE.md` with the temporal query entry points, parameter formats, and import-time field notes.
+
+## [0.3.3] - 2026-02-11
+
+A **language-consistency patch** that reins in language drift during knowledge extraction: output must stick strictly to the source language, with no translation or rewriting.
+
+### 🛠️ Key fixes
+
+#### Extraction language constraint
+
+- `BaseStrategy`:
+  - Removed the `zh/en/mixed` branch-based language detection;
+  - Replaced it with a single constraint: extracted values keep the source language, preserve original terminology, and are never translated.
+- `NarrativeStrategy` / `FactualStrategy`:
+  - Extraction prompts uniformly include the constraint above;
+  - JSON key names stay fixed, while extracted values follow the source-language wording.
+
+#### Import-pipeline consistency
+
+- `ImportCommand`'s LLM extraction prompt was likewise strengthened with "prefer the source language, do not translate", keeping script and command imports consistent.
+
+#### Tests and docs
+
+- Updated `test_strategies.py`: the language-detection tests became unified language-constraint tests that verify the prompts contain the no-translation clause.
+- Comments and docs updated to match the implementation.
+
+### 🔖 Version info
+
+- Plugin version: `0.3.2` → `0.3.3`
+
+## [0.3.2] - 2026-02-11
+
+A **V5 stability and compatibility release** that fixes broken links and misjudgments in key paths while preserving the original design (reinforce → decay → freeze → prune → recycle).
+
+### 🛠️ Key fixes
+
+#### V5 memory-system contracts and pipelines
+
+- `MetadataStore`:
+  - Unified the `mark_relations_inactive(hashes, inactive_since=None)` contract across callers;
+  - Added `has_table(table_name)`;
+  - Added a `restore_relation(hash)` compatibility alias, fixing the broken restore call in the service layer;
+  - Fixed `get_entity_gc_candidates` handling of orphan-node arguments (node names now map to entity hashes).
+- `GraphStore`:
+  - Removed the duplicate `deactivate_edges` definition and unified its return value (frozen-edge count), stabilizing upstream logging and assertions.
+- `server.py`:
+  - Fixed the `/api/memory/restore` relation-restore path;
+  - Removed unreachable branches and unified the exception paths;
+  - Recycle-bin queries no longer incorrectly return empty under table detection.
+
+#### Commands and model selection
+
+- `/memory` hash-length check fixed: 64-char `sha256` is the standard, with legacy 32-char input still accepted.
+- Summarization model selection fixed:
+  - `summarization.model_name = auto` no longer incorrectly matches `embedding`;
+  - Supports arrays and selector syntax (`task:model` / task / model);
+  - Accepts comma-separated strings (e.g. `"utils:model1","utils:model2",replyer`).
+
+#### Lifecycle and script stability
+
+- `plugin.py` background-task lifecycle fixed:
+  - Added `_scheduled_import_task` / `_auto_save_task` / `_memory_maintenance_task` handles;
+  - No more duplicate starts;
+  - On disable, all tasks are uniformly cancelled and awaited.
+- `process_knowledge.py`: fixed the tenacity retry-log level type (`"WARNING"` → `logging.WARNING`), avoiding `KeyError: 'WARNING'`.
+
+### 🔖 Version info
+
+- Plugin version: `0.3.1` → `0.3.2`
+
+## [0.3.1] - 2026-02-07
+
+A **stability patch** fixing the script import pipeline, deletion safety, and LPMM conversion consistency.
+
+### 🛠️ Key fixes
+
+#### Added
+
+- New `scripts/convert_lpmm.py`:
+  - Converts LPMM `parquet + graph` data directly into the A_Memorix storage layout;
+  - Maps LPMM IDs to A_Memorix IDs for rewriting graph nodes/edges;
+  - The current implementation favors retrieval consistency; relation vectors use a conservative policy (not imported directly).
+
+#### Import pipeline
+
+- Fixed the missing/unstable `AutoImporter.import_json_data` public entry point that `import_lpmm_json.py` depends on, so external scripts can reliably call the direct JSON import flow.
+
+#### Deletion safety
+
+- Fixed the over-deletion risk when deleting by source while the same `(subject, object)` pair carries multiple relations:
+  - `MetadataStore.delete_paragraph_atomic` gained `relation_prune_ops`;
+  - The whole edge is deleted only when no sibling relations remain.
+- `delete_knowledge.py` added conservative orphan-entity cleanup (only for this run's candidate entities, and only when there are no paragraph references, no relation references, and no graph neighbors).
+- `delete_knowledge.py` now reads the real dimension from the vector metadata, avoiding `dimension=1` write-back pollution.
+
+#### LPMM conversion fixes
+
+- Fixed retrieval reverse-lookup failures in `convert_lpmm.py` caused by vector IDs not matching the `MetadataStore` hashes.
+- To avoid dirty recall, the conversion stage currently skips direct vector import from `relation.parquet` (to be restored once one-to-one relation-metadata mapping is in place).
+
+### 🔖 Version info
+
+- Plugin version: `0.3.0` → `0.3.1`
+
+## [0.3.0] - 2026-01-30
+
+This release introduces the **V5 dynamic memory system**: biologically inspired memory decay, reinforcement, and full lifecycle management, with matching commands and tools.
+
+### 🧠 Memory system (V5)
+
+#### Core mechanics
+
+- **Decay**: a "forgetting curve" automatically lowers graph edge weights over time.
+- **Reinforcement**: "use it or lose it"; every retrieval hit refreshes a memory's activity and strengthens its weight.
+- **Lifecycle** (a weight-math sketch follows the GraphStore notes below):
+  - **Active**: participates normally in computation and retrieval.
+  - **Inactive (frozen)**: weight dropped too low; excluded from PPR computation but its semantic mapping is retained.
+  - **Prune**: expired, unprotected frozen memories move to the recycle bin.
+- **Protection**: supports **permanent pinning (Pin)** and **time-limited protection (TTL)** so critical memories cannot be deleted by accident.
+
+#### GraphStore
+
+- **Multi-relation mapping**: implements a `(u,v) -> Set[Hash]` mapping so multiple semantic relations on the same channel do not interfere.
+- **Atomic operations**: new `decay`, `deactivate_edges` (soft delete), and `prune_relation_hashes` (hard delete) atomic operations.
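+A sketch of the forgetting-curve math implied above, assuming the common exponential half-life form. The defaults mirror the `[memory]` section introduced below; the edge dict and the reinforcement bonus are illustrative, not the plugin's exact formulas.
+
+```python
+import time
+
+HALF_LIFE_HOURS = 24.0   # [memory] half_life_hours
+PRUNE_THRESHOLD = 0.1    # [memory] prune_threshold
+
+
+def decayed_weight(weight: float, last_access_ts: float, now: float | None = None) -> float:
+    hours = ((now or time.time()) - last_access_ts) / 3600.0
+    return weight * 0.5 ** (hours / HALF_LIFE_HOURS)   # halves every HALF_LIFE_HOURS
+
+
+def on_retrieval_hit(edge: dict) -> None:
+    """Reinforcement: a hit refreshes activity and strengthens the weight."""
+    now = time.time()
+    edge["weight"] = min(1.0, decayed_weight(edge["weight"], edge["last_access"], now) + 0.1)
+    edge["last_access"] = now
+
+
+def lifecycle_state(edge: dict) -> str:
+    if edge.get("pinned"):   # permanent protection never freezes
+        return "active"
+    w = decayed_weight(edge["weight"], edge["last_access"])
+    return "active" if w >= PRUNE_THRESHOLD else "inactive"  # frozen edges await pruning
+```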
+### 🛠️ Commands and tools
+
+#### Memory Command (`/memory`)
+
+New full set of memory-maintenance commands:
+
+- `/memory status`: memory-system health (active/frozen/recycle-bin counts).
+- `/memory protect <hash> [hours]`: protect a memory. No time means a permanent pin; with a time, a temporary TTL protection.
+- `/memory reinforce <hash>`: manually reinforce a memory (bypasses the cooldown).
+- `/memory restore <hash>`: restore a mistakenly deleted memory from the recycle bin (edges are rebuilt only if the nodes still exist).
+
+#### MemoryModifierTool
+
+- **LLM capabilities**: updated tool logic lets the LLM autonomously trigger `reinforce`, `weaken`, `remember_forever`, and `forget`, automatically mapped onto the V5 primitives.
+
+### ⚙️ Config (`config.toml`)
+
+New `[memory]` section:
+
+- `half_life_hours`: memory half-life (default 24h).
+- `enable_auto_reinforce`: whether retrieval hits auto-reinforce.
+- `prune_threshold`: freeze/prune threshold (default 0.1).
+
+### 💻 WebUI (v1.4)
+
+A full lifecycle-management UI deeply integrated with the V5 memory system:
+
+- **Visualization**:
+  - **Frozen state**: inactive memories render as **dashed, gray (Slate-300)** edges.
+  - **Protected state**: pinned or protected memories get a **golden (Amber) halo**.
+- **Interaction**:
+  - **Memory recycle bin**: new Dock entry and dedicated panel for browsing deletion records and one-click restore.
+  - **Quick actions**: the edge-properties panel gained **Reinforce**, **Protect/Pin**, and **Freeze** buttons.
+  - **Live feedback**: graph layout and styling refresh automatically after an operation.
+
+---
+
+## [0.2.3] - 2026-01-30
+
+This release focuses on **WebUI interaction polish** and **docs/config normalization**.
+
+### 🎨 WebUI (v1.3)
+
+#### Loading and sync experience
+
+- **Immersive loading**: redesigned loading overlay with a frosted-glass backdrop (`backdrop-filter`) and a breathing text animation.
+- **Accurate status feedback**: loading logic now distinguishes the "network sync" and "topology computation" phases, eliminating flicker during data loads.
+- **Onboarding**: the loading screen now shows basic operation hints, lowering the barrier for new users.
+
+#### Full help panel
+
+- **Guide overhaul**: the operation-guide panel was rewritten, adding Dock feature details, edit/management operations, and view-configuration notes.
+
+### 🛠️ Engineering
+
+#### plugin.py
+
+- **Config descriptions completed**: added the missing `summarization`, `schedule`, and `filter` sections to `config_section_descriptions`.
+- **Version**: `0.2.2` → `0.2.3`
+
+### ⚙️ Core and server
+
+#### Core
+
+- **Quantization fix**: corrected `_scalar_quantize_int8` so vector values map properly into `[-128, 127]`, improving quantization accuracy.
+
+#### Server
+
+- **Cache consistency**: after node/edge deletions and other mutations, `_relation_cache` is explicitly cleared so the frontend sees fresh relation data.
+
+### 🤖 Scripts and data processing
+
+#### process_knowledge.py
+
+- **Strategy-mode refactor**: introduced a `Strategy-Aware` architecture that processes text differently via three strategies (`Narrative`, `Factual`, `Quote`); more precisely, this release confirms the feature as fully implemented. `Narrative` is the default.
+- **Chunk rescue**: a new "chunk rescue" mechanism detects and extracts embedded lyrics or verse inside long narrative text.
+
+#### import_lpmm_json.py
+
+- **LPMM migration**: full support for the LPMM OpenIE JSON format, computing hashes automatically and migrating entities/relations in a format compatible with A_Memorix storage.
+
+#### Project
+
+- **Build hygiene**: tightened the `.gitignore` rules.
+
+---
+
+## [0.2.2] - 2026-01-27
+
+This release hardens **network-request robustness**, especially calls to the embedding service.
+
+### 🛠️ Stability and engineering
+
+#### EmbeddingAPI
+
+- **Configurable retry**: new `[embedding.retry]` section for customizing the maximum retry count and wait times. The default retry count rises from 3 to 10 to better absorb network jitter.
+- **Keys**:
+  - `max_attempts`: maximum retries (default: 10)
+  - `max_wait_seconds`: maximum wait (default: 30s)
+  - `min_wait_seconds`: minimum wait (default: 2s)
+
+#### plugin.py
+
+- **Version**: `0.2.1` → `0.2.2`
+
+---
+
+## [0.2.1] - 2026-01-26
+
+This release focuses on a **sweeping rework of visualization and interaction** plus **further robustness hardening**.
+
+### 🎨 Visualization and interaction
+
+#### WebUI (Glassmorphism)
+
+- **New visual design**: dark frosted-glass (glassmorphism) styling over a dynamic gradient background.
+- **Dock menu**: a macOS-style Dock at the bottom aggregating all common features.
+- **Saliency view**: a **PageRank**-based "information density" slider that filters leaf nodes, showing either the core skeleton or full detail.
+- **Panels**:
+  - **❓ Guide**: built-in interaction help and feature overview.
+  - **🔍 Floating search**: real-time node filtering by pinyin/ID.
+  - **📂 Memory tracing**: browse and batch-delete memory data by source file.
+  - **📖 Content dictionary**: list view of all entities and relations, with sorting and filtering.
+
+### 🛠️ Stability and engineering
+
+#### EmbeddingAPI
+
+- **Robustness**: exponential-backoff retry via `tenacity`.
+- **Error handling**: on failure, returns a `NaN` vector instead of zeros, letting upper layers skip safely.
+
+#### MetadataStore
+
+- **Auto repair**: detects and repairs historical `vector_index` column misalignment (filenames stored by mistake).
+- **Statistics**: new `get_all_sources` API for source statistics.
+
+#### Scripts and tools
+
+- **UX**: the `rich` library now drives terminal progress bars and status display.
+- **Public API**: `process_knowledge.py` gained `import_json_data` for external callers.
+- **LPMM migration**: new `import_lpmm_json.py` for importing LPMM-conformant OpenIE JSON.
+
+#### plugin.py
+
+- **Version**: `0.2.0` → `0.2.1`
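+A sketch of the configurable exponential-backoff retry described in 0.2.2 and 0.2.1. `tenacity` and the NaN-on-failure behavior come from those entries; the `embed_once` stub and the constants stand in for the real API call and the `[embedding.retry]` values.
+
+```python
+import numpy as np
+from tenacity import retry, stop_after_attempt, wait_exponential
+
+MAX_ATTEMPTS = 10      # [embedding.retry] max_attempts
+MIN_WAIT_SECONDS = 2   # [embedding.retry] min_wait_seconds
+MAX_WAIT_SECONDS = 30  # [embedding.retry] max_wait_seconds
+DIMENSION = 1024
+
+
+def embed_once(text: str) -> np.ndarray:
+    raise RuntimeError("stub: call the actual embedding API here")
+
+
+def _nan_vector(retry_state) -> np.ndarray:
+    # final failure: return NaN (not zeros) so upper layers can skip safely
+    return np.full(DIMENSION, np.nan, dtype=np.float32)
+
+
+@retry(stop=stop_after_attempt(MAX_ATTEMPTS),
+       wait=wait_exponential(min=MIN_WAIT_SECONDS, max=MAX_WAIT_SECONDS),
+       retry_error_callback=_nan_vector)
+def embed(text: str) -> np.ndarray:
+    return embed_once(text)
+```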
+---
+
+## [0.2.0] - 2026-01-22
+
+> [!CAUTION]
+> **Breaking change**: v0.2.0 rewrites the underlying storage architecture. Because of the major data-structure changes, **data imported by older versions is not fully losslessly compatible with this version**.
+> Some components migrate automatically, but to guarantee data consistency and retrieval quality, **re-importing the original data with `process_knowledge.py` after upgrading is strongly recommended**.
+
+This is a **major upgrade**: a vector-storage rewrite, stronger retrieval logic, and assorted stability improvements.
+
+### 🚀 Core architecture rewrite
+
+#### VectorStore: SQ8 quantization + append-only storage
+
+- **New storage format**: migrated from `.npy` to `vectors.bin` (float16, append-only) plus `vectors_ids.bin`, sharply reducing memory usage.
+- **Native SQ8 quantization**: uses Faiss `IndexScalarQuantizer(QT_8bit)`, replacing the manual int8 quantization logic.
+- **Mandatory L2 normalization**: all vectors are L2-normalized on store and on query, so inner product equals cosine similarity.
+- **Fallback index**: a new `IndexFlatIP` fallback serves queries before SQ8 training completes, avoiding empty results on cold start.
+- **Reservoir-sampling training**: training vectors are collected by reservoir sampling (cap 10k), preserving sample diversity for small datasets and streaming imports.
+- **Thread safety**: a `threading.RLock` now guards concurrent reads and writes.
+- **Auto migration**: legacy `.npy` data migrates automatically to the new `.bin` format.
+
+### ✨ Retrieval enhancements
+
+#### KnowledgeQueryTool: smart fallback and multi-hop path search
+
+- **Smart fallback**: when vector-retrieval confidence falls below a threshold (default 0.6), entities are extracted from the query for multi-hop path search (`_path_search`), improving recall of indirect relations.
+- **Deduplication (`_deduplicate_results`)**: content-similarity-based safe dedup prevents redundant results from polluting the LLM context while always keeping at least one result.
+- **Semantic relation search (`_semantic_search_relation`)**: natural-language relation queries (no `S|P|O` format needed), using the `REL_ONLY` vector-retrieval strategy internally.
+- **Path search (`_path_search`)**: calls the new `GraphStore.find_paths` to find indirect paths between two entities (max depth 3, max 5 paths).
+- **Clean output**: raw similarity scores are no longer included in the LLM context, avoiding model bias.
+
+#### DualPathRetriever: concurrency control and debug mode
+
+- **PPR concurrency limit (`ppr_concurrency_limit`)**: a semaphore now caps concurrent PageRank computations, preventing CPU spikes.
+- **Debug mode**: a new `debug` flag logs the retrieved source text.
+- **Entity-pivot relation retrieval**: `_retrieve_relations_only` now retrieves entities first and expands to their associated relations, instead of querying relation vectors directly.
+
+### ⚙️ Config and schema
+
+#### plugin.py
+
+- **Version**: `0.1.3` → `0.2.0`
+- **Default config version**: the `config_version` default updated to `2.0.0`
+- **New keys**:
+  - `retrieval.relation_semantic_fallback` (bool): enable semantic fallback for relation queries.
+  - `retrieval.relation_fallback_min_score` (float): minimum similarity for the semantic fallback.
+- **Relative paths**: `storage.data_dir` now accepts paths relative to the plugin directory; the default becomes `./data`.
+- **Global instance**: new static `A_MemorixPlugin.get_global_instance()` lets components reliably obtain the plugin instance.
+
+#### config.toml / \_manifest.json
+
+- **New `ppr_concurrency_limit`**: caps PPR algorithm concurrency.
+- **New training threshold**: `embedding.min_train_threshold` sets the minimum sample count that triggers SQ8 training.
+
+### 🛠️ Stability and engineering
+
+#### GraphStore
+
+- **`find_paths`**: multi-hop path finding via BFS within a given depth.
+- **`find_node`**: case-insensitive node lookup.
+
+#### MetadataStore
+
+- **Schema migration**: automatically adds the missing `is_permanent`, `last_accessed`, and `access_count` columns.
+
+#### Scripts and tools
+
+- **New scripts**:
+  - `scripts/diagnose_relations_source.py`: diagnose relation-provenance issues.
+  - `scripts/verify_search_robustness.py`: verify retrieval robustness.
+  - `scripts/run_stress_test.py`, `stress_test_data.py`: stress-test suite.
+  - `scripts/migrate_canonicalization.py`, `migrate_paragraph_relations.py`: data-migration tools.
+- **Cleanup**: moved a large batch of legacy test scripts into `deprecated/`.
+
+### 🗑️ Removed and deprecated
+
+- Deprecated the `vectors.npy` storage format (auto-migrates to `.bin`).
+
+---
+
+## [0.1.3] - previous stable release
+
+- Initial release with the basic dual-path retrieval.
+- Manual int8 vector quantization.
+- `.npy`-based vector storage.
diff --git a/plugins/A_memorix/CONFIG_REFERENCE.md b/plugins/A_memorix/CONFIG_REFERENCE.md
new file mode 100644
index 00000000..ada8aec5
--- /dev/null
+++ b/plugins/A_memorix/CONFIG_REFERENCE.md
@@ -0,0 +1,292 @@
+# A_Memorix Configuration Reference (v2.0.0)
+
+This document matches the current repository code (`__version__ = 2.0.0`, `SCHEMA_VERSION = 8`).
+
+Notes:
+
+- Only keys that the **current runtime actually reads** are covered.
+- Config tied to the legacy `/query`, `/memory`, and `/visualize` command system is no longer documented as a primary path.
+- Unset keys fall back to the code defaults.
+
+## Minimal working config
+
+```toml
+[plugin]
+enabled = true
+
+[storage]
+data_dir = "./data"
+
+[embedding]
+model_name = "auto"
+dimension = 1024
+batch_size = 32
+max_concurrent = 5
+enable_cache = false
+quantization_type = "int8"
+
+[retrieval]
+top_k_paragraphs = 20
+top_k_relations = 10
+top_k_final = 10
+alpha = 0.5
+enable_ppr = true
+ppr_alpha = 0.85
+ppr_timeout_seconds = 1.5
+ppr_concurrency_limit = 4
+enable_parallel = true
+
+[retrieval.sparse]
+enabled = true
+
+[episode]
+enabled = true
+generation_enabled = true
+pending_batch_size = 20
+pending_max_retry = 3
+
+[person_profile]
+enabled = true
+
+[memory]
+enabled = true
+half_life_hours = 24.0
+prune_threshold = 0.1
+
+[advanced]
+enable_auto_save = true
+auto_save_interval_minutes = 5
+
+[web.import]
+enabled = true
+
+[web.tuning]
+enabled = true
+```
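+Since unset keys fall back to code defaults, the effective config is an overlay of the file onto those defaults. Below is a minimal sketch of that behavior using the stdlib `tomllib` (Python 3.11+); it is an illustrative helper, not the plugin's actual loader, and the `DEFAULTS` excerpt covers only two sections.
+
+```python
+import tomllib
+
+DEFAULTS = {
+    "retrieval": {"top_k_final": 10, "alpha": 0.5},
+    "memory": {"half_life_hours": 24.0, "prune_threshold": 0.1},
+}
+
+
+def deep_merge(defaults: dict, override: dict) -> dict:
+    merged = dict(defaults)
+    for key, value in override.items():
+        if isinstance(value, dict) and isinstance(merged.get(key), dict):
+            merged[key] = deep_merge(merged[key], value)
+        else:
+            merged[key] = value
+    return merged
+
+
+with open("plugins/A_memorix/config.toml", "rb") as fh:
+    effective = deep_merge(DEFAULTS, tomllib.load(fh))
+print(effective["retrieval"]["top_k_final"])  # 10 unless the file overrides it
+```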
+## 1. Storage and embedding
+
+### `storage`
+
+- `storage.data_dir` (default `./data`)
+: Data directory. Relative paths resolve against the plugin directory.
+
+### `embedding`
+
+- `embedding.model_name` (default `auto`)
+: Embedding model selection.
+- `embedding.dimension` (default `1024`)
+: Expected dimension (the runtime probes and validates the real value).
+- `embedding.batch_size` (default `32`)
+- `embedding.max_concurrent` (default `5`)
+- `embedding.enable_cache` (default `false`)
+- `embedding.retry` (default `{}`)
+: Retry policy for embedding calls.
+- `embedding.quantization_type`
+: Only `int8` is recommended on the current main path.
+
+## 2. Retrieval
+
+### Top-level `retrieval` keys
+
+- `retrieval.top_k_paragraphs` (default `20`)
+- `retrieval.top_k_relations` (default `10`)
+- `retrieval.top_k_final` (default `10`)
+- `retrieval.alpha` (default `0.5`)
+- `retrieval.enable_ppr` (default `true`)
+- `retrieval.ppr_alpha` (default `0.85`)
+- `retrieval.ppr_timeout_seconds` (default `1.5`)
+- `retrieval.ppr_concurrency_limit` (default `4`)
+- `retrieval.enable_parallel` (default `true`)
+- `retrieval.relation_vectorization.enabled` (default `false`)
+
+### `retrieval.sparse` (`SparseBM25Config`)
+
+Common keys (with defaults):
+
+- `enabled = true`
+- `backend = "fts5"`
+- `lazy_load = true`
+- `mode = "auto"` (`auto`/`fallback_only`/`hybrid`)
+- `tokenizer_mode = "jieba"` (`jieba`/`mixed`/`char_2gram`)
+- `char_ngram_n = 2`
+- `candidate_k = 80`
+- `relation_candidate_k = 60`
+- `enable_ngram_fallback_index = true`
+- `enable_relation_sparse_fallback = true`
+
+### `retrieval.fusion` (`FusionConfig`)
+
+- `method` (default `weighted_rrf`)
+- `rrf_k` (default `60`)
+- `vector_weight` (default `0.7`)
+- `bm25_weight` (default `0.3`)
+- `normalize_score` (default `true`)
+- `normalize_method` (default `minmax`)
+
+### `retrieval.search.relation_intent` (`RelationIntentConfig`)
+
+- `enabled` (default `true`)
+- `alpha_override` (default `0.35`)
+- `relation_candidate_multiplier` (default `4`)
+- `preserve_top_relations` (default `3`)
+- `force_relation_sparse` (default `true`)
+- `pair_predicate_rerank_enabled` (default `true`)
+- `pair_predicate_limit` (default `3`)
+
+### `retrieval.search.graph_recall` (`GraphRelationRecallConfig`)
+
+- `enabled` (default `true`)
+- `candidate_k` (default `24`)
+- `max_hop` (default `1`)
+- `allow_two_hop_pair` (default `true`)
+- `max_paths` (default `4`)
+
+### `retrieval.aggregate`
+
+- `retrieval.aggregate.rrf_k`
+- `retrieval.aggregate.weights`
+
+These drive the mixing strategy in the aggregate-retrieval stage; unset keys use the code defaults.
+
+## 3. Threshold filtering
+
+### `threshold` (`ThresholdConfig`)
+
+- `threshold.min_threshold` (default `0.3`)
+- `threshold.max_threshold` (default `0.95`)
+- `threshold.percentile` (default `75.0`)
+- `threshold.std_multiplier` (default `1.5`)
+- `threshold.min_results` (default `3`)
+- `threshold.enable_auto_adjust` (default `true`)
+
+## 4. Chat filtering
+
+### `filter`
+
+Used when `respect_filter=true` (supported on both retrieval and writes).
+
+```toml
+[filter]
+enabled = true
+mode = "blacklist" # blacklist / whitelist
+chats = ["group:123", "user:456", "stream:abc"]
+```
+
+Rules:
+
+- `blacklist`: a listed chat is rejected
+- `whitelist`: only listed chats are allowed
+- With an empty list:
+  - `blacklist` => allow all
+  - `whitelist` => reject all
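+The four rules above, condensed into one predicate; the chat-key format follows the `group:123` / `user:456` / `stream:abc` example. Illustrative only; the plugin's real filter may differ in details.
+
+```python
+def is_chat_allowed(chat_key: str, mode: str, chats: list[str]) -> bool:
+    if mode == "blacklist":
+        return chat_key not in chats  # empty blacklist => allow all
+    if mode == "whitelist":
+        return chat_key in chats      # empty whitelist => reject all
+    raise ValueError(f"unknown filter mode: {mode!r}")
+
+
+assert is_chat_allowed("group:123", "blacklist", []) is True
+assert is_chat_allowed("group:123", "whitelist", []) is False
+```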
+## 5. Episode
+
+### `episode`
+
+- `episode.enabled` (default `true`)
+- `episode.generation_enabled` (default `true`)
+- `episode.pending_batch_size` (default `20`; some paths default to `12`)
+- `episode.pending_max_retry` (default `3`)
+- `episode.max_paragraphs_per_call` (default `20`)
+- `episode.max_chars_per_call` (default `6000`)
+- `episode.source_time_window_hours` (default `24`)
+- `episode.segmentation_model` (default `auto`)
+
+## 6. Person profiles
+
+### `person_profile`
+
+- `person_profile.enabled` (default `true`)
+- `person_profile.refresh_interval_minutes` (default `30`)
+- `person_profile.active_window_hours` (default `72`)
+- `person_profile.max_refresh_per_cycle` (default `50`)
+- `person_profile.top_k_evidence` (default `12`)
+
+## 7. Memory evolution and recycling
+
+### `memory`
+
+- `memory.enabled` (default `true`)
+- `memory.half_life_hours` (default `24.0`)
+- `memory.base_decay_interval_hours` (default `1.0`)
+- `memory.prune_threshold` (default `0.1`)
+- `memory.freeze_duration_hours` (default `24.0`)
+
+### `memory.orphan`
+
+- `enable_soft_delete` (default `true`)
+- `entity_retention_days` (default `7.0`)
+- `paragraph_retention_days` (default `7.0`)
+- `sweep_grace_hours` (default `24.0`)
+
+## 8. Advanced runtime
+
+### `advanced`
+
+- `advanced.enable_auto_save` (default `true`)
+- `advanced.auto_save_interval_minutes` (default `5`)
+- `advanced.debug` (default `false`)
+- `advanced.extraction_model` (default `auto`)
+
+## 9. Import center (`web.import`)
+
+### Switches and rate limits
+
+- `web.import.enabled` (default `true`)
+- `web.import.max_queue_size` (default `20`)
+- `web.import.max_files_per_task` (default `200`)
+- `web.import.max_file_size_mb` (default `20`)
+- `web.import.max_paste_chars` (default `200000`)
+- `web.import.default_file_concurrency` (default `2`)
+- `web.import.default_chunk_concurrency` (default `4`)
+- `web.import.max_file_concurrency` (default `6`)
+- `web.import.max_chunk_concurrency` (default `12`)
+- `web.import.poll_interval_ms` (default `1000`)
+
+### Retry and paths
+
+- `web.import.llm_retry.max_attempts` (default `4`)
+- `web.import.llm_retry.min_wait_seconds` (default `3`)
+- `web.import.llm_retry.max_wait_seconds` (default `40`)
+- `web.import.llm_retry.backoff_multiplier` (default `3`)
+- `web.import.path_aliases` (built-in defaults: `raw/lpmm/plugin_data`)
+
+### Conversion stage
+
+- `web.import.convert.enable_staging_switch` (default `true`)
+- `web.import.convert.keep_backup_count` (default `3`)
+
+## 10. Tuning center (`web.tuning`)
+
+- `web.tuning.enabled` (default `true`)
+- `web.tuning.max_queue_size` (default `8`)
+- `web.tuning.poll_interval_ms` (default `1200`)
+- `web.tuning.eval_query_timeout_seconds` (default `10.0`)
+- `web.tuning.default_intensity` (default `standard`)
+- `web.tuning.default_objective` (default `precision_priority`)
+- `web.tuning.default_top_k_eval` (default `20`)
+- `web.tuning.default_sample_size` (default `24`)
+- `web.tuning.llm_retry.max_attempts` (default `3`)
+- `web.tuning.llm_retry.min_wait_seconds` (default `2`)
+- `web.tuning.llm_retry.max_wait_seconds` (default `20`)
+- `web.tuning.llm_retry.backoff_multiplier` (default `2`)
+
+## 11. Compatibility notes
+
+- When upgrading from `1.x`, run these first:
+
+```bash
+python plugins/A_memorix/scripts/release_vnext_migrate.py preflight --strict
+python plugins/A_memorix/scripts/release_vnext_migrate.py migrate --verify-after
+python plugins/A_memorix/scripts/release_vnext_migrate.py verify --strict
+```
+
+- Then, before starting:
+
+```bash
+python plugins/A_memorix/scripts/runtime_self_check.py --json
+```
+
+This avoids runtime failures caused by a mismatch between the embedding dimension and the vector store.
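+A sketch of the mismatch that check guards against: the configured dimension, the dimension the model actually returns, and the dimension already held by the vector store must all agree. The function below is illustrative; the real check lives in `runtime_self_check.py`.
+
+```python
+def check_embedding_dimension(configured: int, probed: int, index_dim: int | None) -> None:
+    if probed != configured:
+        raise RuntimeError(
+            f"embedding.dimension={configured} but the model returned {probed}-d vectors"
+        )
+    if index_dim is not None and index_dim != probed:
+        raise RuntimeError(
+            f"vector store holds {index_dim}-d vectors, model produces {probed}-d; "
+            "run the migration/verify scripts before starting"
+        )
+```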
diff --git a/plugins/A_memorix/IMPORT_GUIDE.md b/plugins/A_memorix/IMPORT_GUIDE.md
new file mode 100644
index 00000000..618690e0
--- /dev/null
+++ b/plugins/A_memorix/IMPORT_GUIDE.md
@@ -0,0 +1,335 @@
+# A_Memorix Import Guide (v2.0.0)
+
+This document matches the current `2.0.0` code paths and covers two import modes:
+
+1. Script imports (offline batch)
+2. `memory_import_admin` task imports (online, task-based)
+
+## 1. Pre-import checks
+
+Run this first:
+
+```bash
+python plugins/A_memorix/scripts/runtime_self_check.py --json
+```
+
+Then confirm:
+
+- the `storage.data_dir` path is writable
+- the embedding config works
+- for upgraded projects, the migration scripts have completed
+
+## 2. Mode A: script imports (recommended starting point)
+
+### 2.1 Raw text import
+
+Place `.txt` files in:
+
+```text
+plugins/A_memorix/data/raw/
+```
+
+Run:
+
+```bash
+python plugins/A_memorix/scripts/process_knowledge.py
+```
+
+Common flags:
+
+```bash
+python plugins/A_memorix/scripts/process_knowledge.py --force
+python plugins/A_memorix/scripts/process_knowledge.py --chat-log
+python plugins/A_memorix/scripts/process_knowledge.py --chat-log --chat-reference-time "2026/02/12 10:30"
+```
+
+### 2.2 OpenIE JSON import
+
+```bash
+python plugins/A_memorix/scripts/import_lpmm_json.py
+```
+
+### 2.3 LPMM data conversion
+
+```bash
+python plugins/A_memorix/scripts/convert_lpmm.py -i <lpmm_data_dir> -o plugins/A_memorix/data
+```
+
+### 2.4 Historical data migration
+
+```bash
+python plugins/A_memorix/scripts/migrate_chat_history.py --help
+python plugins/A_memorix/scripts/migrate_maibot_memory.py --help
+python plugins/A_memorix/scripts/migrate_person_memory_points.py --help
+```
+
+### 2.5 Post-import repair and rebuild
+
+```bash
+python plugins/A_memorix/scripts/backfill_temporal_metadata.py --dry-run
+python plugins/A_memorix/scripts/backfill_relation_vectors.py --limit 1000
+python plugins/A_memorix/scripts/rebuild_episodes.py --all --wait
+python plugins/A_memorix/scripts/audit_vector_consistency.py --json
+```
+
+## 3. Mode B: `memory_import_admin` task imports
+
+`memory_import_admin` is the online task-based import entry point, suited to host-side panels or automation pipelines.
+
+### 3.1 Common actions
+
+- `settings` / `get_settings` / `get_guide`
+- `path_aliases` / `get_path_aliases`
+- `resolve_path`
+- `create_upload`
+- `create_paste`
+- `create_raw_scan`
+- `create_lpmm_openie`
+- `create_lpmm_convert`
+- `create_temporal_backfill`
+- `create_maibot_migration`
+- `list`
+- `get`
+- `chunks` / `get_chunks`
+- `cancel`
+- `retry_failed`
+
+### 3.2 Call examples
+
+Inspect the runtime settings:
+
+```json
+{
+  "tool": "memory_import_admin",
+  "arguments": {
+    "action": "settings"
+  }
+}
+```
+
+Create a paste-import task:
+
+```json
+{
+  "tool": "memory_import_admin",
+  "arguments": {
+    "action": "create_paste",
+    "content": "今天完成了检索调优回归。",
+    "input_mode": "plain_text",
+    "source": "manual:worklog"
+  }
+}
+```
+
+List tasks:
+
+```json
+{
+  "tool": "memory_import_admin",
+  "arguments": {
+    "action": "list",
+    "limit": 20
+  }
+}
+```
+
+Get task details:
+
+```json
+{
+  "tool": "memory_import_admin",
+  "arguments": {
+    "action": "get",
+    "task_id": "<task_id>",
+    "include_chunks": true
+  }
+}
+```
+
+Retry failed work:
+
+```json
+{
+  "tool": "memory_import_admin",
+  "arguments": {
+    "action": "retry_failed",
+    "task_id": "<task_id>"
+  }
+}
+```
+
+## 4. Direct-write tools (non-task-based)
+
+If you do not need task orchestration, call these directly:
+
+- `ingest_summary`
+- `ingest_text`
+
+Example:
+
+```json
+{
+  "tool": "ingest_text",
+  "arguments": {
+    "external_id": "note:2026-03-18:001",
+    "source_type": "note",
+    "text": "新的召回阈值方案已通过评审",
+    "chat_id": "group:dev",
+    "tags": ["worklog", "review"]
+  }
+}
+```
+
+`external_id` should be globally unique; it drives idempotent deduplication.
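+Because `external_id` drives idempotent deduplication, deriving it from stable inputs makes retries safe. A hedged sketch: the hashing scheme is an assumption, and only the argument names come from the example above.
+
+```python
+import hashlib
+import json
+
+
+def build_ingest_text_args(source_type: str, chat_id: str, text: str,
+                           tags: list[str] | None = None) -> dict:
+    digest = hashlib.sha256(f"{source_type}|{chat_id}|{text}".encode()).hexdigest()[:16]
+    return {
+        "external_id": f"{source_type}:{chat_id}:{digest}",  # same input => same id => deduped
+        "source_type": source_type,
+        "chat_id": chat_id,
+        "text": text,
+        "tags": tags or [],
+    }
+
+
+args = build_ingest_text_args("note", "group:dev", "新的召回阈值方案已通过评审")
+print(json.dumps(args, ensure_ascii=False, indent=2))
+```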
+## 5. Time-field recommendations
+
+Available time fields (typical priority order):
+
+- `timestamp`
+- `time_start`
+- `time_end`
+
+Recommendations:
+
+- For event-like records, prefer `time_start/time_end`
+- With only a single point in time, write `timestamp`
+- Historical data can be imported first and backfilled later with `backfill_temporal_metadata.py`
+
+## 6. source_type recommendations
+
+Common values:
+
+- `chat_summary`
+- `note`
+- `person_fact`
+- `lpmm_openie`
+- `migration`
+
+Keep the enum stable; that simplifies per-source governance and Episode rebuilds later.
+
+## 7. Post-import verification
+
+Recommended order:
+
+1. `memory_stats`: confirm the totals grew
+2. `search_memory` (`mode=search`/`aggregate`): spot-check recall
+3. `memory_episode_admin` `status`/`query`: check Episode generation
+4. `memory_runtime_admin` `self_check`: reconfirm runtime health
+
+## 8. FAQ
+
+### Q1: The import task was created but nothing was written
+
+- Check the chat-filter config `filter` (with `respect_filter=true`, the write may have been filtered out)
+- Check the task details for failure reasons and chunk states
+
+### Q2: Tasks keep failing
+
+- Check embedding and LLM availability
+- Lower the concurrency (`web.import.default_*_concurrency`)
+- Tune the retry settings (`web.import.llm_retry.*`)
+
+### Q3: Retrieval quality is poor after import
+
+- Run `runtime_self_check` first
+- Check that `retrieval.sparse` is enabled
+- Create a tuning task with `memory_tuning_admin` for a parameter regression
+
+## 9. Related docs
+
+- [QUICK_START.md](QUICK_START.md)
+- [CONFIG_REFERENCE.md](CONFIG_REFERENCE.md)
+- [README.md](README.md)
+- [CHANGELOG.md](CHANGELOG.md)
+
+## 10. Appendix: strategy-mode reference
+
+The A_Memorix import pipeline remains strategy-aware. `process_knowledge.py` detects the text type automatically and also accepts a manual override.
+
+| Strategy | Best for | Core logic | Auto-detection cues |
+| :-- | :-- | :-- | :-- |
+| `Narrative` | novels, fan fiction, scripts, long stories | split by scene/chapter with a sliding window; extract events and character relations | chapter markers such as `#`, `Chapter`, `***` |
+| `Factual` | lore bibles, encyclopedias, manuals | split by semantic block, preserving list/definition structure; extract SPO triples | list bullets, `term: definition` |
+| `Quote` | lyrics, poetry, quotations, lines | split on double newlines; the text itself is the knowledge, no summarization | short average line length, many lines |
+
+## 11. Appendix: reference samples (restored)
+
+The samples below can be saved directly as files for testing, or used as LLM few-shot examples.
+
+### 11.1 Narrative text (`plugins/A_memorix/data/raw/story_demo.txt`)
+
+```text
+# 第一章:星之子
+
+艾瑞克在废墟中醒来,手中的星盘发出微弱的蓝光。他并不记得自己是如何来到这里的,只依稀记得莉莉丝最后的警告:“千万不要回头。”
+
+远处传来了机械守卫的轰鸣声。艾瑞克迅速收起星盘,向着北方的废弃都市奔去。他知道,那里有反抗军唯一的据点。
+
+***
+
+# 第二章:重逢
+
+在反抗军的地下掩体中,艾瑞克见到了那个熟悉的身影。莉莉丝正站在全息地图前,眉头紧锁。
+
+“你还是来了。”莉莉丝没有回头,但声音中带着一丝颤抖。
+“我必须来,”艾瑞克握紧了拳头,“为了解开星盘的秘密,也为了你。”
+```
+
+### 11.2 Factual text (`plugins/A_memorix/data/raw/rules_demo.txt`)
+
+```text
+# 联邦安全协议 v2.0
+
+## 核心法则
+1. **第一公理**:任何人工智能不得伤害人类个体,或因不作为而使人类个体受到伤害。
+2. **第二公理**:人工智能必须服从人类的命令,除非该命令与第一公理冲突。
+
+## 术语定义
+- **以太网络**:覆盖全联邦的高速量子通讯网络。
+- **黑色障壁**:用于隔离高危 AI 的物理防火墙设施。
+```
+
+### 11.3 Quote text (`plugins/A_memorix/data/raw/poem_demo.txt`)
+
+```text
+致橡树
+
+我如果爱你——
+绝不像攀援的凌霄花,
+借你的高枝炫耀自己;
+
+我如果爱你——
+绝不学痴情的鸟儿,
+为绿荫重复单调的歌曲;
+
+也不止像泉源,
+常年送来清凉的慰籍;
+也不止像险峰,
+增加你的高度,衬托你的威仪。
+```
+
+### 11.4 LPMM JSON (`lpmm_data-openie.json`)
+
+```json
+{
+  "docs": [
+    {
+      "passage": "艾瑞克手中的星盘是打开遗迹的唯一钥匙。",
+      "extracted_triples": [
+        ["星盘", "是", "唯一的钥匙"],
+        ["星盘", "属于", "艾瑞克"],
+        ["钥匙", "用于", "遗迹"]
+      ],
+      "extracted_entities": ["星盘", "艾瑞克", "遗迹", "钥匙"]
+    },
+    {
+      "passage": "莉莉丝是反抗军的现任领袖。",
+      "extracted_triples": [
+        ["莉莉丝", "是", "领袖"],
+        ["领袖", "所属", "反抗军"]
+      ]
+    }
+  ]
+}
+```
diff --git a/plugins/A_memorix/LICENSE b/plugins/A_memorix/LICENSE
new file mode 100644
index 00000000..e20b431b
--- /dev/null
+++ b/plugins/A_memorix/LICENSE
@@ -0,0 +1,661 @@
+GNU AFFERO GENERAL PUBLIC LICENSE
+                       Version 3, 19 November 2007
+
+ Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
+ Everyone is permitted to copy and distribute verbatim copies
+ of this license document, but changing it is not allowed.
+
+                            Preamble
+
+  The GNU Affero General Public License is a free, copyleft license for
+software and other kinds of works, specifically designed to ensure
+cooperation with the community in the case of network server software.
+
+  The licenses for most software and other practical works are designed
+to take away your freedom to share and change the works.  By contrast,
+our General Public Licenses are intended to guarantee your freedom to
+share and change all versions of a program--to make sure it remains free
+software for all its users.
+
+  When we speak of free software, we are referring to freedom, not
+price.  Our General Public Licenses are designed to make sure that you
+have the freedom to distribute copies of free software (and charge for
+them if you wish), that you receive source code or can get it if you
+want it, that you can change the software or use pieces of it in new
+free programs, and that you know you can do these things.
+ + Developers that use our General Public Licenses protect your rights +with two steps: (1) assert copyright on the software, and (2) offer +you this License which gives you legal permission to copy, distribute +and/or modify the software. + + A secondary benefit of defending all users' freedom is that +improvements made in alternate versions of the program, if they +receive widespread use, become available for other developers to +incorporate. Many developers of free software are heartened and +encouraged by the resulting cooperation. However, in the case of +software used on network servers, this result may fail to come about. +The GNU General Public License permits making a modified version and +letting the public access it on a server without ever releasing its +source code to the public. + + The GNU Affero General Public License is designed specifically to +ensure that, in such cases, the modified source code becomes available +to the community. It requires the operator of a network server to +provide the source code of the modified version running there to the +users of that server. Therefore, public use of a modified version, on +a publicly accessible server, gives the public access to the source +code of the modified version. + + An older license, called the Affero General Public License and +published by Affero, was designed to accomplish similar goals. This is +a different license, not a version of the Affero GPL, but Affero has +released a new version of the Affero GPL which permits relicensing under +this license. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU Affero General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. 
If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. + + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. 
+ + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. + + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. 
+ + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. + + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. + + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. 
+ + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. + + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. 
+ + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. 
If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). 
To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Remote Network Interaction; Use with the GNU General Public License. 
+ + Notwithstanding any other provision of this License, if you modify the +Program, your modified version must prominently offer all users +interacting with it remotely through a computer network (if your version +supports such interaction) an opportunity to receive the Corresponding +Source of your version by providing access to the Corresponding Source +from a network server at no charge, through some standard or customary +means of facilitating copying of software. This Corresponding Source +shall include the Corresponding Source for any work covered by version 3 +of the GNU General Public License that is incorporated pursuant to the +following paragraph. + + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the work with which it is combined will remain governed by version +3 of the GNU General Public License. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU Affero General Public License from time to time. Such new versions +will be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU Affero General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU Affero General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU Affero General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. 
+
+  IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
+WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
+THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
+GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
+USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
+DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
+PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
+EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
+SUCH DAMAGES.
+
+  17. Interpretation of Sections 15 and 16.
+
+  If the disclaimer of warranty and limitation of liability provided
+above cannot be given local legal effect according to their terms,
+reviewing courts shall apply local law that most closely approximates
+an absolute waiver of all civil liability in connection with the
+Program, unless a warranty or assumption of liability accompanies a
+copy of the Program in return for a fee.
+
+                     END OF TERMS AND CONDITIONS
+
+            How to Apply These Terms to Your New Programs
+
+  If you develop a new program, and you want it to be of the greatest
+possible use to the public, the best way to achieve this is to make it
+free software which everyone can redistribute and change under these terms.
+
+  To do so, attach the following notices to the program.  It is safest
+to attach them to the start of each source file to most effectively
+state the exclusion of warranty; and each file should have at least
+the "copyright" line and a pointer to where the full notice is found.
+
+    <one line to give the program's name and a brief idea of what it does.>
+    Copyright (C) <year>  <name of author>
+
+    This program is free software: you can redistribute it and/or modify
+    it under the terms of the GNU Affero General Public License as published
+    by the Free Software Foundation, either version 3 of the License, or
+    (at your option) any later version.
+
+    This program is distributed in the hope that it will be useful,
+    but WITHOUT ANY WARRANTY; without even the implied warranty of
+    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+    GNU Affero General Public License for more details.
+
+    You should have received a copy of the GNU Affero General Public License
+    along with this program.  If not, see <https://www.gnu.org/licenses/>.
+
+Also add information on how to contact you by electronic and paper mail.
+
+  If your software can interact with users remotely through a computer
+network, you should also make sure that it provides a way for users to
+get its source.  For example, if your program is a web application, its
+interface could display a "Source" link that leads users to an archive
+of the code.  There are many ways you could offer source, and different
+solutions will be better for different programs; see section 13 for the
+specific requirements.
+
+  You should also get your employer (if you work as a programmer) or school,
+if any, to sign a "copyright disclaimer" for the program, if necessary.
+For more information on this, and how to apply and follow the GNU AGPL, see
+<https://www.gnu.org/licenses/>.
diff --git a/plugins/A_memorix/LICENSE-MAIBOT-GPL.md b/plugins/A_memorix/LICENSE-MAIBOT-GPL.md
new file mode 100644
index 00000000..83108097
--- /dev/null
+++ b/plugins/A_memorix/LICENSE-MAIBOT-GPL.md
@@ -0,0 +1,22 @@
+Special GPL License Grant for MaiBot
+
+Licensor
+- A_Dawn
+
+Effective date
+- 2026-03-18
+
+Default license
+- This repository is licensed under AGPL-3.0 by default (see `LICENSE`).
+
+Additional grant for MaiBot
+- The copyright holder(s) of this repository grant an additional, non-exclusive permission to
+  the project at `https://github.com/Mai-with-u/MaiBot` (including its maintainers and contributors)
+  to use, modify, and redistribute code from this repository under GPL-3.0.
+
+Scope
+- This additional GPL grant is intended for use in the MaiBot project context.
+- For all other uses not covered by the grant above, AGPL-3.0 remains the applicable license.
+
+No warranty
+- This grant is provided without warranty, consistent with AGPL-3.0 and GPL-3.0.
diff --git a/plugins/A_memorix/QUICK_START.md b/plugins/A_memorix/QUICK_START.md
new file mode 100644
index 00000000..76750453
--- /dev/null
+++ b/plugins/A_memorix/QUICK_START.md
@@ -0,0 +1,210 @@
+# A_Memorix Quick Start (v2.0.0)
+
+This document covers the current `2.0.0` architecture (SDK Tool interface).
+
+## 0. Version and Interface Changes
+
+- Current plugin version: `2.0.0`
+- Interface shape: `memory_provider` + Tool calls
+- Legacy slash commands (such as `/query`, `/memory`, `/visualize`) are no longer the primary documented entry point on this branch
+
+## 1. Prerequisites
+
+- Python 3.10+
+- The same runtime environment as the MaiBot host program
+- Access to the embedding service you have configured
+
+Install the dependencies:
+
+```bash
+pip install -r plugins/A_memorix/requirements.txt --upgrade
+```
+
+If the current directory is already the plugin directory, you can also run:
+
+```bash
+pip install -r requirements.txt --upgrade
+```
+
+## 2. Enable the Plugin
+
+Enable `A_Memorix` in the host program's plugin configuration.
+
+If you configure via `plugins/A_memorix/config.toml`, a minimal example:
+
+```toml
+[plugin]
+enabled = true
+
+[storage]
+data_dir = "./data"
+
+[embedding]
+model_name = "auto"
+dimension = 1024
+batch_size = 32
+max_concurrent = 5
+quantization_type = "int8"
+```
+
+## 3. Runtime Self-Check (strongly recommended)
+
+First confirm that the embedding's actual output dimension is compatible with the vector store:
+
+```bash
+python plugins/A_memorix/scripts/runtime_self_check.py --json
+```
+
+If the result is `ok=false`, fix the embedding configuration or the vector store before importing anything.
+
+## 4. Import Data
+
+### 4.1 Bulk Text Import
+
+Place your text files in:
+
+```text
+plugins/A_memorix/data/raw/
+```
+
+Run:
+
+```bash
+python plugins/A_memorix/scripts/process_knowledge.py
+```
+
+Common options:
+
+```bash
+python plugins/A_memorix/scripts/process_knowledge.py --force
+python plugins/A_memorix/scripts/process_knowledge.py --chat-log
+python plugins/A_memorix/scripts/process_knowledge.py --chat-log --chat-reference-time "2026/02/12 10:30"
+```
+
+### 4.2 Other Import Scripts
+
+```bash
+python plugins/A_memorix/scripts/import_lpmm_json.py
+python plugins/A_memorix/scripts/convert_lpmm.py -i -o plugins/A_memorix/data
+python plugins/A_memorix/scripts/migrate_chat_history.py --help
+python plugins/A_memorix/scripts/migrate_maibot_memory.py --help
+python plugins/A_memorix/scripts/migrate_person_memory_points.py --help
+```
+
+## 5. Core Tool Calls
+
+### 5.1 Search
+
+```json
+{
+  "tool": "search_memory",
+  "arguments": {
+    "query": "项目复盘",
+    "mode": "aggregate",
+    "limit": 5,
+    "chat_id": "group:dev"
+  }
+}
+```
+
+`mode` supports: `search/time/hybrid/episode/aggregate`
+
+### 5.2 Ingest a Summary
+
+```json
+{
+  "tool": "ingest_summary",
+  "arguments": {
+    "external_id": "chat_summary:group-dev:2026-03-18",
+    "chat_id": "group:dev",
+    "text": "今天完成了检索调优评审"
+  }
+}
+```
+
+### 5.3 Ingest a Plain Text Memory
+
+```json
+{
+  "tool": "ingest_text",
+  "arguments": {
+    "external_id": "note:2026-03-18:001",
+    "source_type": "note",
+    "text": "模型切换后召回质量更稳定",
+    "chat_id": "group:dev",
+    "tags": ["worklog"]
+  }
+}
+```
+
+### 5.4 Profiles and Maintenance
+
+```json
+{
+  "tool": "get_person_profile",
+  "arguments": {
+    "person_id": "Alice",
+    "limit": 8
+  }
+}
+```
+
+```json
+{
+  "tool": "maintain_memory",
+  "arguments": {
+    "action": "protect",
+    "target": "模型切换后召回质量更稳定",
+    "hours": 24
+  }
+}
+```
+
+```json
+{
+  "tool": "memory_stats",
+  "arguments": {}
+}
+```
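+
+Beyond `protect`, the README's tool table lists the `maintain_memory` actions `reinforce/protect/restore/freeze/recycle_bin`. A minimal sketch of freezing a relation follows; whether `freeze` accepts additional parameters, and the exact `target` matching semantics, are assumptions to verify against your deployment:
+
+```json
+{
+  "tool": "maintain_memory",
+  "arguments": {
+    "action": "freeze",
+    "target": "模型切换后召回质量更稳定"
+  }
+}
+```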
+
+## 6. Admin Tools (advanced)
+
+`2.0.0` ships a full set of admin tools:
+
+- `memory_graph_admin`
+- `memory_source_admin`
+- `memory_episode_admin`
+- `memory_profile_admin`
+- `memory_runtime_admin`
+- `memory_import_admin`
+- `memory_tuning_admin`
+- `memory_v5_admin`
+- `memory_delete_admin`
+
+Start with read-only actions such as `action=list` / `action=status` to verify the pipeline end to end.
+
+## 7. FAQ
+
+### Q1: Search returns nothing
+
+1. Check `memory_stats` first to confirm that paragraphs/relations exist
+2. Check whether the `chat_id` / `person_id` filters are too strict
+3. Run `runtime_self_check.py --json` to confirm the embedding dimension is correct
+
+### Q2: Startup reports a vector dimension mismatch
+
+- Cause: the existing vector store's dimension differs from the current embedding output
+- Fix: restore the previous configuration, or rebuild the vector data, then start again
+
+### Q3: Web pages do not open
+
+This branch does not ship a standalone `server.py`.
+
+- `web/index.html`, `web/import.html`, and `web/tuning.html` are exposed through the host-side routing/API integration
+- Check whether the host has mapped the static pages and the `/api/*` endpoints
+
+## 8. Next Steps
+
+- Configuration details: [CONFIG_REFERENCE.md](CONFIG_REFERENCE.md)
+- Import details: [IMPORT_GUIDE.md](IMPORT_GUIDE.md)
+- Version history: [CHANGELOG.md](CHANGELOG.md)
diff --git a/plugins/A_memorix/README.md b/plugins/A_memorix/README.md
new file mode 100644
index 00000000..1afb1b5f
--- /dev/null
+++ b/plugins/A_memorix/README.md
@@ -0,0 +1,216 @@
+# A_Memorix
+
+**Long-term memory and cognitive enhancement plugin** (v2.0.0)
+
+> 消えていかない感覚 , まだまだ足りてないみたい !
+
+A_Memorix is a `memory_provider` plugin for the MaiBot SDK.
+It unifies text, relations, episodes, person profiles, and retrieval tuning in a single runtime, targeting long-running agent memory scenarios.
+
+## Quick Navigation
+
+- [Quick Start](QUICK_START.md)
+- [Configuration Reference](CONFIG_REFERENCE.md)
+- [Import Guide and Best Practices](IMPORT_GUIDE.md)
+- [Changelog](CHANGELOG.md)
+
+## Positioning of 2.0.0
+
+`v2.0.0` is an architecture-consolidation release; this branch centers on the **SDK Tool interface**:
+
+- The legacy `components/commands/*`, `components/tools/*`, and `server.py` have been removed.
+- The unified entry point is [`plugin.py`](plugin.py) + [`core/runtime/sdk_memory_kernel.py`](core/runtime/sdk_memory_kernel.py).
+- The metadata schema is `v8`, adding external references and operations records (such as `external_memory_refs`, `memory_v5_operations`, `delete_operations`).
+
+If you are still using the legacy slash commands (such as `/query`, `/memory`, `/visualize`), migrate to the Tool interface described in this document.
+
+## Core Capabilities
+
+- Dual-path retrieval: joint vector + graph-relation recall, supporting `search/time/hybrid/episode/aggregate`.
+- Writes and deduplication: `external_id` idempotency, joint paragraph/relation writes, episode pending-queue processing.
+- Episodes: rebuild by source, status queries, batch processing of pending items.
+- Person profiles: automatic snapshots + manual overrides.
+- Administration: a full tool set for graphs, sources, episodes, profiles, import, tuning, V5 operations, and delete/restore.
+
+## Tool Interface (v2.0.0)
+
+### Basic Tools
+
+| Tool | Description | Key Parameters |
+| --- | --- | --- |
+| `search_memory` | Search long-term memory | `query` `mode` `limit` `chat_id` `person_id` `time_start` `time_end` |
+| `ingest_summary` | Ingest a chat summary | `external_id` `chat_id` `text` |
+| `ingest_text` | Ingest a plain text memory | `external_id` `source_type` `text` |
+| `get_person_profile` | Fetch a person profile | `person_id` `chat_id` `limit` |
+| `maintain_memory` | Maintain relation state | `action=reinforce/protect/restore/freeze/recycle_bin` |
+| `memory_stats` | Fetch statistics | none |
+
+### Admin Tools
+
+| Tool | Common actions |
+| --- | --- |
+| `memory_graph_admin` | `get_graph/create_node/delete_node/rename_node/create_edge/delete_edge/update_edge_weight` |
+| `memory_source_admin` | `list/delete/batch_delete` |
+| `memory_episode_admin` | `query/list/get/status/rebuild/process_pending` |
+| `memory_profile_admin` | `query/list/set_override/delete_override` |
+| `memory_runtime_admin` | `save/get_config/self_check/refresh_self_check/set_auto_save` |
+| `memory_import_admin` | `settings/get_guide/create_upload/create_paste/create_raw_scan/create_lpmm_openie/create_lpmm_convert/create_temporal_backfill/create_maibot_migration/list/get/chunks/cancel/retry_failed` |
+| `memory_tuning_admin` | `settings/get_profile/apply_profile/rollback_profile/export_profile/create_task/list_tasks/get_task/get_rounds/cancel/apply_best/get_report` |
+| `memory_v5_admin` | `status/recycle_bin/restore/reinforce/weaken/remember_forever/forget` |
+| `memory_delete_admin` | `preview/execute/restore/get_operation/list_operations/purge` |
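+
+A safe first call against any of these admin tools is a read-only action. A minimal sketch using `memory_runtime_admin` follows; the call shape mirrors the basic tools above, and the exact response payload is deployment-specific and not documented here:
+
+```json
+{
+  "tool": "memory_runtime_admin",
+  "arguments": {
+    "action": "get_config"
+  }
+}
+```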
+
+## Call Examples
+
+```json
+{
+  "tool": "search_memory",
+  "arguments": {
+    "query": "项目复盘",
+    "mode": "aggregate",
+    "limit": 5,
+    "chat_id": "group:dev"
+  }
+}
+```
+
+```json
+{
+  "tool": "ingest_text",
+  "arguments": {
+    "external_id": "note:2026-03-18:001",
+    "source_type": "note",
+    "text": "今天完成了检索调优评审",
+    "chat_id": "group:dev",
+    "tags": ["worklog"]
+  }
+}
+```
+
+```json
+{
+  "tool": "maintain_memory",
+  "arguments": {
+    "action": "protect",
+    "target": "完成了 检索调优评审",
+    "hours": 72
+  }
+}
+```
+
+## Quick Start
+
+### 1. Install dependencies
+
+Run this in the same Python environment as the MaiBot host program:
+
+```bash
+pip install -r plugins/A_memorix/requirements.txt --upgrade
+```
+
+If the current directory is already the plugin directory, you can also run:
+
+```bash
+pip install -r requirements.txt --upgrade
+```
+
+### 2. Enable the plugin
+
+Enable the plugin in `config.toml` (the path depends on your host deployment):
+
+```toml
+[plugin]
+enabled = true
+```
+
+### 3. Run the runtime self-check first
+
+```bash
+python plugins/A_memorix/scripts/runtime_self_check.py --json
+```
+
+### 4. Import text and verify statistics
+
+```bash
+python plugins/A_memorix/scripts/process_knowledge.py
+```
+
+Then call `memory_stats` or `search_memory` to confirm data is present.
+
+## Web Pages
+
+The repository keeps the static web pages:
+
+- `web/index.html` (graph and memory management)
+- `web/import.html` (import center)
+- `web/tuning.html` (retrieval tuning)
+
+This branch no longer ships a standalone `server.py`; page routing and API exposure are the host integration's responsibility.
+
+## Common Scripts
+
+| Script | Purpose |
+| --- | --- |
+| `process_knowledge.py` | Bulk-import raw text (strategy-aware) |
+| `import_lpmm_json.py` | Import OpenIE JSON |
+| `convert_lpmm.py` | Convert LPMM data |
+| `migrate_chat_history.py` | Migrate chat_history |
+| `migrate_maibot_memory.py` | Migrate legacy MaiBot memory |
+| `migrate_person_memory_points.py` | Migrate person memory points |
+| `backfill_temporal_metadata.py` | Backfill temporal metadata |
+| `audit_vector_consistency.py` | Audit vector consistency |
+| `backfill_relation_vectors.py` | Backfill relation vectors |
+| `rebuild_episodes.py` | Rebuild episodes by source |
+| `release_vnext_migrate.py` | Upgrade precheck/migration/validation |
+| `runtime_self_check.py` | Real-embedding runtime self-check |
+
+## Configuration Highlights
+
+See [CONFIG_REFERENCE.md](CONFIG_REFERENCE.md) for the full configuration.
+
+Frequently used keys:
+
+- `storage.data_dir`
+- `embedding.dimension`
+- `embedding.quantization_type` (currently only `int8` is supported)
+- `retrieval.*`
+- `retrieval.sparse.*`
+- `episode.*`
+- `person_profile.*`
+- `memory.*`
+- `web.import.*`
+- `web.tuning.*`
+
+## Troubleshooting
+
+### SQLite without FTS5
+
+If your SQLite build does not have `FTS5` enabled, you can turn off sparse retrieval:
+
+```toml
+[retrieval.sparse]
+enabled = false
+```
+
+### Vector dimension mismatch
+
+If the logs report that the current embedding output dimension differs from the existing vector store, run this first:
+
+```bash
+python plugins/A_memorix/scripts/runtime_self_check.py --json
+```
+
+If necessary, rebuild the vectors or adjust the embedding configuration before starting the plugin again.
+
+## License
+
+The default license is [AGPL-3.0](https://www.gnu.org/licenses/agpl-3.0) (see `LICENSE`).
+
+See `LICENSE-MAIBOT-GPL.md` for the additional GPL grant to the `Mai-with-u/MaiBot` project.
+
+Apart from that additional grant, all other uses remain under AGPL-3.0.
+
+## Contributing
+
+PRs are not currently accepted; please file issues instead.
+
+**Author**: `A_Dawn`
diff --git a/plugins/A_memorix/_manifest.json b/plugins/A_memorix/_manifest.json
index a45b2f73..e4217fdd 100644
--- a/plugins/A_memorix/_manifest.json
+++ b/plugins/A_memorix/_manifest.json
@@ -55,6 +55,51 @@
       "type": "tool",
       "name": "memory_stats",
       "description": "查询记忆统计"
+    },
+    {
+      "type": "tool",
+      "name": "memory_graph_admin",
+      "description": "图谱管理接口"
+    },
+    {
+      "type": "tool",
+      "name": "memory_source_admin",
+      "description": "来源管理接口"
+    },
+    {
+      "type": "tool",
+      "name": "memory_episode_admin",
+      "description": "Episode 管理接口"
+    },
+    {
+      "type": "tool",
+      "name": "memory_profile_admin",
+      "description": "画像管理接口"
+    },
+    {
+      "type": "tool",
+      "name": "memory_runtime_admin",
+      "description": "运行时管理接口"
+    },
+    {
+      "type": "tool",
+      "name": "memory_import_admin",
+      "description": "导入管理接口"
+    },
+    {
+      "type": 
"tool", + "name": "memory_tuning_admin", + "description": "调优管理接口" + }, + { + "type": "tool", + "name": "memory_v5_admin", + "description": "V5 记忆管理接口" + }, + { + "type": "tool", + "name": "memory_delete_admin", + "description": "删除管理接口" } ] }, diff --git a/plugins/A_memorix/core/embedding/api_adapter.py b/plugins/A_memorix/core/embedding/api_adapter.py index 4262ddb9..d11e2d05 100644 --- a/plugins/A_memorix/core/embedding/api_adapter.py +++ b/plugins/A_memorix/core/embedding/api_adapter.py @@ -1,46 +1,55 @@ """ -Hash-based embedding adapter used by the SDK runtime. +请求式嵌入 API 适配器。 -The plugin runtime cannot import MaiBot host embedding internals from ``src.chat`` -or ``src.llm_models``. This adapter keeps A_Memorix self-contained and stable in -Runner by generating deterministic dense vectors locally. +恢复 v1.0.1 的真实 embedding 请求语义: +- 通过宿主模型配置探测/请求 embedding +- 支持 dimensions 参数 +- 支持批量与重试 +- 不再提供本地 hash fallback """ from __future__ import annotations -import hashlib -import re +import asyncio import time -from typing import List, Optional, Union +from typing import Any, List, Optional, Union +import aiohttp import numpy as np +import openai from src.common.logger import get_logger - +from src.config.config import config_manager +from src.config.model_configs import APIProvider, ModelInfo +from src.llm_models.exceptions import NetworkConnectionError +from src.llm_models.model_client.base_client import client_registry logger = get_logger("A_Memorix.EmbeddingAPIAdapter") -_TOKEN_PATTERN = re.compile(r"[A-Za-z0-9_\u4e00-\u9fff]{1,}") - class EmbeddingAPIAdapter: - """Deterministic local embedding adapter.""" + """适配宿主 embedding 请求接口。""" def __init__( self, batch_size: int = 32, max_concurrent: int = 5, - default_dimension: int = 256, + default_dimension: int = 1024, enable_cache: bool = False, - model_name: str = "hash-v1", + model_name: str = "auto", retry_config: Optional[dict] = None, ) -> None: self.batch_size = max(1, int(batch_size)) self.max_concurrent = max(1, int(max_concurrent)) - self.default_dimension = max(32, int(default_dimension)) + self.default_dimension = max(1, int(default_dimension)) self.enable_cache = bool(enable_cache) - self.model_name = str(model_name or "hash-v1") + self.model_name = str(model_name or "auto") + self.retry_config = retry_config or {} + self.max_attempts = max(1, int(self.retry_config.get("max_attempts", 5))) + self.max_wait_seconds = max(0.1, float(self.retry_config.get("max_wait_seconds", 40))) + self.min_wait_seconds = max(0.1, float(self.retry_config.get("min_wait_seconds", 3))) + self.backoff_multiplier = max(1.0, float(self.retry_config.get("backoff_multiplier", 3))) self._dimension: Optional[int] = None self._dimension_detected = False @@ -49,57 +58,164 @@ class EmbeddingAPIAdapter: self._total_time = 0.0 logger.info( - "EmbeddingAPIAdapter 初始化: model=%s, batch_size=%s, dimension=%s", - self.model_name, - self.batch_size, - self.default_dimension, + "EmbeddingAPIAdapter 初始化: " + f"batch_size={self.batch_size}, " + f"max_concurrent={self.max_concurrent}, " + f"default_dim={self.default_dimension}, " + f"model={self.model_name}" ) + def _get_current_model_config(self): + return config_manager.get_model_config() + + @staticmethod + def _find_model_info(model_name: str) -> ModelInfo: + model_cfg = config_manager.get_model_config() + for item in model_cfg.models: + if item.name == model_name: + return item + raise ValueError(f"未找到 embedding 模型: {model_name}") + + @staticmethod + def _find_provider(provider_name: str) -> APIProvider: + model_cfg = 
config_manager.get_model_config() + for item in model_cfg.api_providers: + if item.name == provider_name: + return item + raise ValueError(f"未找到 embedding provider: {provider_name}") + + def _resolve_candidate_model_names(self) -> List[str]: + task_config = self._get_current_model_config().model_task_config.embedding + configured = list(getattr(task_config, "model_list", []) or []) + if self.model_name and self.model_name != "auto": + return [self.model_name, *[name for name in configured if name != self.model_name]] + return configured + + @staticmethod + def _validate_embedding_vector(embedding: Any, *, source: str) -> np.ndarray: + array = np.asarray(embedding, dtype=np.float32) + if array.ndim != 1: + raise RuntimeError(f"{source} 返回的 embedding 维度非法: ndim={array.ndim}") + if array.size <= 0: + raise RuntimeError(f"{source} 返回了空 embedding") + if not np.all(np.isfinite(array)): + raise RuntimeError(f"{source} 返回了非有限 embedding 值") + return array + + async def _request_with_retry(self, client, model_info, text: str, extra_params: dict): + retriable_exceptions = ( + openai.APIConnectionError, + openai.APITimeoutError, + aiohttp.ClientError, + asyncio.TimeoutError, + NetworkConnectionError, + ) + + last_exc: Optional[BaseException] = None + for attempt in range(1, self.max_attempts + 1): + try: + return await client.get_embedding( + model_info=model_info, + embedding_input=text, + extra_params=extra_params, + ) + except retriable_exceptions as exc: + last_exc = exc + if attempt >= self.max_attempts: + raise + wait_seconds = min( + self.max_wait_seconds, + self.min_wait_seconds * (self.backoff_multiplier ** (attempt - 1)), + ) + logger.warning( + "Embedding 请求失败,重试 " + f"{attempt}/{max(1, self.max_attempts - 1)}," + f"{wait_seconds:.1f}s 后重试: {exc}" + ) + await asyncio.sleep(wait_seconds) + except Exception: + raise + + if last_exc is not None: + raise last_exc + raise RuntimeError("Embedding 请求失败:未知错误") + + async def _get_embedding_direct(self, text: str, dimensions: Optional[int] = None) -> Optional[List[float]]: + candidate_names = self._resolve_candidate_model_names() + if not candidate_names: + raise RuntimeError("embedding 任务未配置模型") + + last_exc: Optional[BaseException] = None + for candidate_name in candidate_names: + try: + model_info = self._find_model_info(candidate_name) + api_provider = self._find_provider(model_info.api_provider) + client = client_registry.get_client_class_instance(api_provider, force_new=True) + + extra_params = dict(getattr(model_info, "extra_params", {}) or {}) + if dimensions is not None: + extra_params["dimensions"] = int(dimensions) + + response = await self._request_with_retry( + client=client, + model_info=model_info, + text=text, + extra_params=extra_params, + ) + embedding = getattr(response, "embedding", None) + if embedding is None: + raise RuntimeError(f"模型 {candidate_name} 未返回 embedding") + vector = self._validate_embedding_vector( + embedding, + source=f"embedding 模型 {candidate_name}", + ) + return vector.tolist() + except Exception as exc: + last_exc = exc + logger.warning(f"embedding 模型 {candidate_name} 请求失败: {exc}") + + if last_exc is not None: + logger.error(f"通过直接 Client 获取 Embedding 失败: {last_exc}") + return None + async def _detect_dimension(self) -> int: if self._dimension_detected and self._dimension is not None: return self._dimension + + logger.info("正在检测嵌入模型维度...") + try: + target_dim = self.default_dimension + logger.debug(f"尝试请求指定维度: {target_dim}") + test_embedding = await self._get_embedding_direct("test", dimensions=target_dim) + if 
test_embedding and isinstance(test_embedding, list): + detected_dim = len(test_embedding) + if detected_dim == target_dim: + logger.info(f"嵌入维度检测成功 (匹配配置): {detected_dim}") + else: + logger.warning( + f"请求维度 {target_dim} 但模型返回 {detected_dim},将使用模型自然维度" + ) + self._dimension = detected_dim + self._dimension_detected = True + return detected_dim + except Exception as exc: + logger.debug(f"带维度参数探测失败: {exc},尝试不带参数探测") + + try: + test_embedding = await self._get_embedding_direct("test", dimensions=None) + if test_embedding and isinstance(test_embedding, list): + detected_dim = len(test_embedding) + self._dimension = detected_dim + self._dimension_detected = True + logger.info(f"嵌入维度检测成功 (自然维度): {detected_dim}") + return detected_dim + logger.warning(f"嵌入维度检测失败,使用默认值: {self.default_dimension}") + except Exception as exc: + logger.error(f"嵌入维度检测异常: {exc},使用默认值: {self.default_dimension}") + self._dimension = self.default_dimension self._dimension_detected = True - return self._dimension - - @staticmethod - def _tokenize(text: str) -> List[str]: - clean = str(text or "").strip().lower() - if not clean: - return [] - return _TOKEN_PATTERN.findall(clean) - - @staticmethod - def _feature_weight(token: str) -> float: - digest = hashlib.sha256(token.encode("utf-8")).digest() - return 1.0 + (digest[10] / 255.0) * 0.5 - - def _encode_single(self, text: str, dimension: int) -> np.ndarray: - vector = np.zeros(dimension, dtype=np.float32) - content = str(text or "").strip() - tokens = self._tokenize(content) - if not tokens and content: - tokens = [content.lower()] - if not tokens: - vector[0] = 1.0 - return vector - - for token in tokens: - digest = hashlib.sha256(token.encode("utf-8")).digest() - bucket = int.from_bytes(digest[:8], byteorder="big", signed=False) % dimension - sign = 1.0 if digest[8] % 2 == 0 else -1.0 - vector[bucket] += sign * self._feature_weight(token) - - second_bucket = int.from_bytes(digest[12:20], byteorder="big", signed=False) % dimension - if second_bucket != bucket: - vector[second_bucket] += (sign * 0.35) - - norm = float(np.linalg.norm(vector)) - if norm > 1e-8: - vector /= norm - else: - vector[0] = 1.0 - return vector + return self.default_dimension async def encode( self, @@ -109,59 +225,137 @@ class EmbeddingAPIAdapter: normalize: bool = True, dimensions: Optional[int] = None, ) -> np.ndarray: - _ = batch_size - _ = show_progress - _ = normalize + del show_progress + del normalize - started_at = time.time() - target_dimension = max(32, int(dimensions or await self._detect_dimension())) + start_time = time.time() + target_dim = int(dimensions) if dimensions is not None else int(await self._detect_dimension()) if isinstance(texts, str): - single_input = True normalized_texts = [texts] + single_input = True else: - single_input = False normalized_texts = list(texts or []) + single_input = False if not normalized_texts: - empty = np.zeros((0, target_dimension), dtype=np.float32) + empty = np.zeros((0, target_dim), dtype=np.float32) return empty[0] if single_input else empty + if batch_size is None: + batch_size = self.batch_size + try: - matrix = np.vstack([self._encode_single(item, target_dimension) for item in normalized_texts]) + embeddings = await self._encode_batch_internal( + normalized_texts, + batch_size=max(1, int(batch_size)), + dimensions=dimensions, + ) + if embeddings.ndim == 1: + embeddings = embeddings.reshape(1, -1) self._total_encoded += len(normalized_texts) - self._total_time += time.time() - started_at - except Exception: + elapsed = time.time() - start_time 
+ self._total_time += elapsed + logger.debug( + "编码完成: " + f"{len(normalized_texts)} 个文本, " + f"耗时 {elapsed:.2f}s, " + f"平均 {elapsed / max(1, len(normalized_texts)):.3f}s/文本" + ) + return embeddings[0] if single_input else embeddings + except Exception as exc: self._total_errors += 1 - raise + logger.error(f"编码失败: {exc}") + raise RuntimeError(f"embedding encode failed: {exc}") from exc - return matrix[0] if single_input else matrix + async def _encode_batch_internal( + self, + texts: List[str], + batch_size: int, + dimensions: Optional[int] = None, + ) -> np.ndarray: + all_embeddings: List[np.ndarray] = [] + for offset in range(0, len(texts), batch_size): + batch = texts[offset : offset + batch_size] + semaphore = asyncio.Semaphore(self.max_concurrent) - def get_statistics(self) -> dict: - avg_time = self._total_time / self._total_encoded if self._total_encoded else 0.0 + async def encode_with_semaphore(text: str, index: int): + async with semaphore: + embedding = await self._get_embedding_direct(text, dimensions=dimensions) + if embedding is None: + raise RuntimeError(f"文本 {index} 编码失败:embedding 返回为空") + vector = self._validate_embedding_vector( + embedding, + source=f"文本 {index}", + ) + return index, vector + + tasks = [ + encode_with_semaphore(text, offset + index) + for index, text in enumerate(batch) + ] + results = await asyncio.gather(*tasks) + results.sort(key=lambda item: item[0]) + all_embeddings.extend(emb for _, emb in results) + + return np.array(all_embeddings, dtype=np.float32) + + async def encode_batch( + self, + texts: List[str], + batch_size: Optional[int] = None, + num_workers: Optional[int] = None, + show_progress: bool = False, + dimensions: Optional[int] = None, + ) -> np.ndarray: + del show_progress + if num_workers is not None: + previous = self.max_concurrent + self.max_concurrent = max(1, int(num_workers)) + try: + return await self.encode(texts, batch_size=batch_size, dimensions=dimensions) + finally: + self.max_concurrent = previous + return await self.encode(texts, batch_size=batch_size, dimensions=dimensions) + + def get_embedding_dimension(self) -> int: + if self._dimension is not None: + return self._dimension + logger.warning(f"维度尚未检测,返回默认值: {self.default_dimension}") + return self.default_dimension + + def get_model_info(self) -> dict: return { "model_name": self.model_name, "dimension": self._dimension or self.default_dimension, + "dimension_detected": self._dimension_detected, + "batch_size": self.batch_size, + "max_concurrent": self.max_concurrent, "total_encoded": self._total_encoded, "total_errors": self._total_errors, - "total_time": self._total_time, - "avg_time_per_text": avg_time, + "avg_time_per_text": self._total_time / self._total_encoded if self._total_encoded else 0.0, } + def get_statistics(self) -> dict: + return self.get_model_info() + + @property + def is_model_loaded(self) -> bool: + return True + def __repr__(self) -> str: return ( - f"EmbeddingAPIAdapter(model_name={self.model_name}, " - f"dimension={self._dimension or self.default_dimension}, " - f"total_encoded={self._total_encoded})" + f"EmbeddingAPIAdapter(dim={self._dimension or self.default_dimension}, " + f"detected={self._dimension_detected}, encoded={self._total_encoded})" ) def create_embedding_api_adapter( batch_size: int = 32, max_concurrent: int = 5, - default_dimension: int = 256, + default_dimension: int = 1024, enable_cache: bool = False, - model_name: str = "hash-v1", + model_name: str = "auto", retry_config: Optional[dict] = None, ) -> EmbeddingAPIAdapter: return 
EmbeddingAPIAdapter( diff --git a/plugins/A_memorix/core/retrieval/dual_path.py b/plugins/A_memorix/core/retrieval/dual_path.py index cfeb343c..6ed5e71a 100644 --- a/plugins/A_memorix/core/retrieval/dual_path.py +++ b/plugins/A_memorix/core/retrieval/dual_path.py @@ -285,10 +285,10 @@ class DualPathRetriever: relation_intent_ctx = self._build_relation_intent_context(query=query, top_k=top_k) logger.info( - "执行检索: query='%s...', strategy=%s, relation_intent=%s", - query[:50], - strategy.value, - relation_intent_ctx.get("enabled", False), + "执行检索: " + f"query='{query[:50]}...', " + f"strategy={strategy.value}, " + f"relation_intent={relation_intent_ctx.get('enabled', False)}" ) if temporal and not (query or "").strip(): @@ -1408,10 +1408,10 @@ class DualPathRetriever: return results logger.debug( - "relation_rerank_applied=1 relation_pair_groups=%s relation_pair_overflow_count=%s relation_pair_limit=%s", - len(ordered_groups), - len(overflow), - pair_limit, + "relation_rerank_applied=1 " + f"relation_pair_groups={len(ordered_groups)} " + f"relation_pair_overflow_count={len(overflow)} " + f"relation_pair_limit={pair_limit}" ) rebuilt = list(results) @@ -1455,9 +1455,9 @@ class DualPathRetriever: ) except asyncio.TimeoutError: logger.warning( - "metric.ppr_timeout_skip_count=1 timeout_s=%s entities=%s", - ppr_timeout_s, - len(entities), + "metric.ppr_timeout_skip_count=1 " + f"timeout_s={ppr_timeout_s} " + f"entities={len(entities)}" ) return results except Exception as e: diff --git a/plugins/A_memorix/core/retrieval/graph_relation_recall.py b/plugins/A_memorix/core/retrieval/graph_relation_recall.py index 3ce03b14..9af862f3 100644 --- a/plugins/A_memorix/core/retrieval/graph_relation_recall.py +++ b/plugins/A_memorix/core/retrieval/graph_relation_recall.py @@ -170,7 +170,7 @@ class GraphRelationRecallService: max_paths=self.config.max_paths, ) except Exception as e: - logger.debug("graph two-hop recall skipped: %s", e) + logger.debug(f"graph two-hop recall skipped: {e}") return for path_nodes in paths: @@ -210,7 +210,7 @@ class GraphRelationRecallService: limit=self.config.candidate_k, ) except Exception as e: - logger.debug("graph one-hop recall skipped: %s", e) + logger.debug(f"graph one-hop recall skipped: {e}") return self._append_relation_hashes( relation_hashes=relation_hashes, diff --git a/plugins/A_memorix/core/retrieval/sparse_bm25.py b/plugins/A_memorix/core/retrieval/sparse_bm25.py index 3b6f075d..1fef9f80 100644 --- a/plugins/A_memorix/core/retrieval/sparse_bm25.py +++ b/plugins/A_memorix/core/retrieval/sparse_bm25.py @@ -123,9 +123,8 @@ class SparseBM25Index: self._loaded = True self._prepare_tokenizer() logger.info( - "SparseBM25Index loaded: backend=fts5, tokenizer=%s, mode=%s", - self.config.tokenizer_mode, - self.config.mode, + "SparseBM25Index loaded: " + f"backend=fts5, tokenizer={self.config.tokenizer_mode}, mode={self.config.mode}" ) return True @@ -141,9 +140,9 @@ class SparseBM25Index: if user_dict: try: jieba.load_userdict(user_dict) # type: ignore[union-attr] - logger.info("已加载 jieba 用户词典: %s", user_dict) + logger.info(f"已加载 jieba 用户词典: {user_dict}") except Exception as e: - logger.warning("加载 jieba 用户词典失败: %s", e) + logger.warning(f"加载 jieba 用户词典失败: {e}") self._jieba_dict_loaded = True def _tokenize_jieba(self, text: str) -> List[str]: diff --git a/plugins/A_memorix/core/runtime/__init__.py b/plugins/A_memorix/core/runtime/__init__.py index fa222715..eece6d21 100644 --- a/plugins/A_memorix/core/runtime/__init__.py +++ b/plugins/A_memorix/core/runtime/__init__.py @@ 
-1,8 +1,16 @@ """SDK runtime exports for A_Memorix.""" +from .search_runtime_initializer import ( + SearchRuntimeBundle, + SearchRuntimeInitializer, + build_search_runtime, +) from .sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel __all__ = [ + "SearchRuntimeBundle", + "SearchRuntimeInitializer", + "build_search_runtime", "KernelSearchRequest", "SDKMemoryKernel", ] diff --git a/plugins/A_memorix/core/runtime/lifecycle_orchestrator.py b/plugins/A_memorix/core/runtime/lifecycle_orchestrator.py new file mode 100644 index 00000000..423b55c4 --- /dev/null +++ b/plugins/A_memorix/core/runtime/lifecycle_orchestrator.py @@ -0,0 +1,268 @@ +"""Lifecycle bootstrap/teardown helpers extracted from plugin.py.""" + +from __future__ import annotations + +import asyncio +from pathlib import Path +from typing import Any + +from src.common.logger import get_logger + +from ..embedding import create_embedding_api_adapter +from ..retrieval import SparseBM25Config, SparseBM25Index +from ..storage import ( + GraphStore, + MetadataStore, + QuantizationType, + SparseMatrixFormat, + VectorStore, +) +from ..utils.runtime_self_check import ensure_runtime_self_check +from ..utils.relation_write_service import RelationWriteService + +logger = get_logger("A_Memorix.LifecycleOrchestrator") + + +async def ensure_initialized(plugin: Any) -> None: + if plugin._initialized: + plugin._runtime_ready = plugin._check_storage_ready() + return + + async with plugin._init_lock: + if plugin._initialized: + plugin._runtime_ready = plugin._check_storage_ready() + return + + logger.info("A_Memorix 插件正在异步初始化存储组件...") + plugin._validate_runtime_config() + await initialize_storage_async(plugin) + report = await ensure_runtime_self_check(plugin, force=True) + if not bool(report.get("ok", False)): + logger.error( + "A_Memorix runtime self-check failed: " + f"{report.get('message', 'unknown')}; " + "建议执行 python plugins/A_memorix/scripts/runtime_self_check.py --json" + ) + + if plugin.graph_store and plugin.metadata_store: + relation_count = plugin.metadata_store.count_relations() + if relation_count > 0 and not plugin.graph_store.has_edge_hash_map(): + raise RuntimeError( + "检测到 relations 数据存在但 edge-hash-map 为空。" + " 请先执行 scripts/release_vnext_migrate.py migrate。" + ) + + plugin._initialized = True + plugin._runtime_ready = plugin._check_storage_ready() + plugin._update_plugin_config() + logger.info("A_Memorix 插件异步初始化成功") + + +def start_background_tasks(plugin: Any) -> None: + """Start background tasks idempotently.""" + if not hasattr(plugin, "_episode_generation_task"): + plugin._episode_generation_task = None + + if ( + plugin.get_config("summarization.enabled", True) + and plugin.get_config("schedule.enabled", True) + and (plugin._scheduled_import_task is None or plugin._scheduled_import_task.done()) + ): + plugin._scheduled_import_task = asyncio.create_task(plugin._scheduled_import_loop()) + + if ( + plugin.get_config("advanced.enable_auto_save", True) + and (plugin._auto_save_task is None or plugin._auto_save_task.done()) + ): + plugin._auto_save_task = asyncio.create_task(plugin._auto_save_loop()) + + if ( + plugin.get_config("person_profile.enabled", True) + and (plugin._person_profile_refresh_task is None or plugin._person_profile_refresh_task.done()) + ): + plugin._person_profile_refresh_task = asyncio.create_task(plugin._person_profile_refresh_loop()) + + if plugin._memory_maintenance_task is None or plugin._memory_maintenance_task.done(): + plugin._memory_maintenance_task = 
asyncio.create_task(plugin._memory_maintenance_loop()) + + rv_cfg = plugin.get_config("retrieval.relation_vectorization", {}) or {} + if isinstance(rv_cfg, dict): + rv_enabled = bool(rv_cfg.get("enabled", False)) + rv_backfill = bool(rv_cfg.get("backfill_enabled", False)) + else: + rv_enabled = False + rv_backfill = False + if rv_enabled and rv_backfill and ( + plugin._relation_vector_backfill_task is None or plugin._relation_vector_backfill_task.done() + ): + plugin._relation_vector_backfill_task = asyncio.create_task(plugin._relation_vector_backfill_loop()) + + episode_task = getattr(plugin, "_episode_generation_task", None) + episode_loop = getattr(plugin, "_episode_generation_loop", None) + if ( + callable(episode_loop) + and bool(plugin.get_config("episode.enabled", True)) + and bool(plugin.get_config("episode.generation_enabled", True)) + and (episode_task is None or episode_task.done()) + ): + plugin._episode_generation_task = asyncio.create_task(episode_loop()) + + +async def cancel_background_tasks(plugin: Any) -> None: + """Cancel all background tasks and wait for cleanup.""" + tasks = [ + ("scheduled_import", plugin._scheduled_import_task), + ("auto_save", plugin._auto_save_task), + ("person_profile_refresh", plugin._person_profile_refresh_task), + ("memory_maintenance", plugin._memory_maintenance_task), + ("relation_vector_backfill", plugin._relation_vector_backfill_task), + ("episode_generation", getattr(plugin, "_episode_generation_task", None)), + ] + for _, task in tasks: + if task and not task.done(): + task.cancel() + + for name, task in tasks: + if not task: + continue + try: + await task + except asyncio.CancelledError: + pass + except Exception as e: + logger.warning(f"后台任务 {name} 退出异常: {e}") + + plugin._scheduled_import_task = None + plugin._auto_save_task = None + plugin._person_profile_refresh_task = None + plugin._memory_maintenance_task = None + plugin._relation_vector_backfill_task = None + plugin._episode_generation_task = None + + +async def initialize_storage_async(plugin: Any) -> None: + """Initialize storage components asynchronously.""" + data_dir_str = plugin.get_config("storage.data_dir", "./data") + if data_dir_str.startswith("."): + plugin_dir = Path(__file__).resolve().parents[2] + data_dir = (plugin_dir / data_dir_str).resolve() + else: + data_dir = Path(data_dir_str) + + logger.info(f"A_Memorix 数据存储路径: {data_dir}") + data_dir.mkdir(parents=True, exist_ok=True) + + plugin.embedding_manager = create_embedding_api_adapter( + batch_size=plugin.get_config("embedding.batch_size", 32), + max_concurrent=plugin.get_config("embedding.max_concurrent", 5), + default_dimension=plugin.get_config("embedding.dimension", 1024), + model_name=plugin.get_config("embedding.model_name", "auto"), + retry_config=plugin.get_config("embedding.retry", {}), + ) + logger.info("嵌入 API 适配器初始化完成") + + try: + detected_dimension = await plugin.embedding_manager._detect_dimension() + logger.info(f"嵌入维度检测成功: {detected_dimension}") + except Exception as e: + logger.warning(f"嵌入维度检测失败: {e},使用默认值") + detected_dimension = plugin.embedding_manager.default_dimension + + quantization_str = plugin.get_config("embedding.quantization_type", "int8") + if str(quantization_str or "").strip().lower() != "int8": + raise ValueError("embedding.quantization_type 在 vNext 仅允许 int8(SQ8)。") + quantization_type = QuantizationType.INT8 + + plugin.vector_store = VectorStore( + dimension=detected_dimension, + quantization_type=quantization_type, + data_dir=data_dir / "vectors", + ) + 
plugin.vector_store.min_train_threshold = plugin.get_config("embedding.min_train_threshold", 40) + logger.info( + "向量存储初始化完成(" + f"维度: {detected_dimension}, " + f"训练阈值: {plugin.vector_store.min_train_threshold})" + ) + + matrix_format_str = plugin.get_config("graph.sparse_matrix_format", "csr") + matrix_format_map = { + "csr": SparseMatrixFormat.CSR, + "csc": SparseMatrixFormat.CSC, + } + matrix_format = matrix_format_map.get(matrix_format_str, SparseMatrixFormat.CSR) + + plugin.graph_store = GraphStore( + matrix_format=matrix_format, + data_dir=data_dir / "graph", + ) + logger.info("图存储初始化完成") + + plugin.metadata_store = MetadataStore(data_dir=data_dir / "metadata") + plugin.metadata_store.connect() + logger.info("元数据存储初始化完成") + + plugin.relation_write_service = RelationWriteService( + metadata_store=plugin.metadata_store, + graph_store=plugin.graph_store, + vector_store=plugin.vector_store, + embedding_manager=plugin.embedding_manager, + ) + logger.info("关系写入服务初始化完成") + + sparse_cfg_raw = plugin.get_config("retrieval.sparse", {}) or {} + if not isinstance(sparse_cfg_raw, dict): + sparse_cfg_raw = {} + try: + sparse_cfg = SparseBM25Config(**sparse_cfg_raw) + except Exception as e: + logger.warning(f"sparse 配置非法,回退默认配置: {e}") + sparse_cfg = SparseBM25Config() + plugin.sparse_index = SparseBM25Index( + metadata_store=plugin.metadata_store, + config=sparse_cfg, + ) + logger.info( + "稀疏检索组件初始化完成: " + f"enabled={sparse_cfg.enabled}, " + f"lazy_load={sparse_cfg.lazy_load}, " + f"mode={sparse_cfg.mode}, " + f"tokenizer={sparse_cfg.tokenizer_mode}" + ) + if sparse_cfg.enabled and not sparse_cfg.lazy_load: + plugin.sparse_index.ensure_loaded() + + if plugin.vector_store.has_data(): + try: + plugin.vector_store.load() + logger.info(f"向量数据已加载,共 {plugin.vector_store.num_vectors} 个向量") + except Exception as e: + logger.warning(f"加载向量数据失败: {e}") + + try: + warmup_summary = plugin.vector_store.warmup_index(force_train=True) + if warmup_summary.get("ok"): + logger.info( + "向量索引预热完成: " + f"trained={warmup_summary.get('trained')}, " + f"index_ntotal={warmup_summary.get('index_ntotal')}, " + f"fallback_ntotal={warmup_summary.get('fallback_ntotal')}, " + f"bin_count={warmup_summary.get('bin_count')}, " + f"duration_ms={float(warmup_summary.get('duration_ms', 0.0)):.2f}" + ) + else: + logger.warning( + "向量索引预热失败,继续启用 sparse 降级路径: " + f"{warmup_summary.get('error', 'unknown')}" + ) + except Exception as e: + logger.warning(f"向量索引预热异常,继续启用 sparse 降级路径: {e}") + + if plugin.graph_store.has_data(): + try: + plugin.graph_store.load() + logger.info(f"图数据已加载,共 {plugin.graph_store.num_nodes} 个节点") + except Exception as e: + logger.warning(f"加载图数据失败: {e}") + + logger.info(f"知识库数据目录: {data_dir}") diff --git a/plugins/A_memorix/core/runtime/sdk_memory_kernel.py b/plugins/A_memorix/core/runtime/sdk_memory_kernel.py index 7c8f9213..439afd3d 100644 --- a/plugins/A_memorix/core/runtime/sdk_memory_kernel.py +++ b/plugins/A_memorix/core/runtime/sdk_memory_kernel.py @@ -1,26 +1,33 @@ from __future__ import annotations +import asyncio +import json +import pickle import time +import uuid from dataclasses import dataclass from pathlib import Path -from typing import Any, Dict, Iterable, List, Optional, Sequence +from typing import Any, Awaitable, Callable, Dict, Iterable, List, Optional, Sequence from src.common.logger import get_logger from ..embedding import create_embedding_api_adapter -from ..retrieval import ( - DualPathRetriever, - DualPathRetrieverConfig, - RetrievalResult, - SparseBM25Config, - SparseBM25Index, - 
TemporalQueryOptions, -) +from ..retrieval import RetrievalResult, SparseBM25Config, SparseBM25Index, TemporalQueryOptions from ..storage import GraphStore, MetadataStore, QuantizationType, SparseMatrixFormat, VectorStore from ..utils.aggregate_query_service import AggregateQueryService from ..utils.episode_retrieval_service import EpisodeRetrievalService -from ..utils.hash import normalize_text +from ..utils.episode_segmentation_service import EpisodeSegmentationService +from ..utils.episode_service import EpisodeService +from ..utils.hash import compute_hash, normalize_text +from ..utils.person_profile_service import PersonProfileService from ..utils.relation_write_service import RelationWriteService +from ..utils.retrieval_tuning_manager import RetrievalTuningManager +from ..utils.runtime_self_check import run_embedding_runtime_self_check +from ..utils.search_execution_service import SearchExecutionRequest, SearchExecutionService +from ..utils.summary_importer import SummaryImporter +from ..utils.time_parser import format_timestamp, parse_query_datetime_to_timestamp +from ..utils.web_import_manager import ImportTaskManager +from .search_runtime_initializer import SearchRuntimeBundle, build_search_runtime logger = get_logger("A_Memorix.SDKMemoryKernel") @@ -32,8 +39,76 @@ class KernelSearchRequest: mode: str = "hybrid" chat_id: str = "" person_id: str = "" - time_start: Optional[float] = None - time_end: Optional[float] = None + time_start: Optional[str | float] = None + time_end: Optional[str | float] = None + respect_filter: bool = True + user_id: str = "" + group_id: str = "" + + +@dataclass +class _NormalizedSearchTimeWindow: + numeric_start: Optional[float] = None + numeric_end: Optional[float] = None + query_start: Optional[str] = None + query_end: Optional[str] = None + + +class _KernelRuntimeFacade: + def __init__(self, kernel: "SDKMemoryKernel") -> None: + self._kernel = kernel + self.config = kernel.config + self._plugin_config = kernel.config + self._runtime_self_check_report: Dict[str, Any] = {} + + def get_config(self, key: str, default: Any = None) -> Any: + return self._kernel._cfg(key, default) + + def is_runtime_ready(self) -> bool: + return self._kernel.is_runtime_ready() + + def is_chat_enabled(self, stream_id: str, group_id: str | None = None, user_id: str | None = None) -> bool: + return self._kernel.is_chat_enabled(stream_id=stream_id, group_id=group_id, user_id=user_id) + + async def reinforce_access(self, relation_hashes: Sequence[str]) -> None: + if self._kernel.metadata_store is None: + return + hashes = [str(item or "").strip() for item in relation_hashes if str(item or "").strip()] + if not hashes: + return + self._kernel.metadata_store.reinforce_relations(hashes) + self._kernel._last_maintenance_at = time.time() + + async def execute_request_with_dedup( + self, + request_key: str, + executor: Callable[[], Awaitable[Dict[str, Any]]], + ) -> tuple[bool, Dict[str, Any]]: + return await self._kernel.execute_request_with_dedup(request_key, executor) + + @property + def vector_store(self) -> Optional[VectorStore]: + return self._kernel.vector_store + + @property + def graph_store(self) -> Optional[GraphStore]: + return self._kernel.graph_store + + @property + def metadata_store(self) -> Optional[MetadataStore]: + return self._kernel.metadata_store + + @property + def embedding_manager(self): + return self._kernel.embedding_manager + + @property + def sparse_index(self): + return self._kernel.sparse_index + + @property + def relation_write_service(self) -> 
Optional[RelationWriteService]: + return self._kernel.relation_write_service class SDKMemoryKernel: @@ -43,7 +118,7 @@ class SDKMemoryKernel: storage_cfg = self._cfg("storage", {}) or {} data_dir = str(storage_cfg.get("data_dir", "./data") or "./data") self.data_dir = (self.plugin_root / data_dir).resolve() if data_dir.startswith(".") else Path(data_dir) - self.embedding_dimension = max(32, int(self._cfg("embedding.dimension", 256))) + self.embedding_dimension = max(1, int(self._cfg("embedding.dimension", 1024))) self.relation_vectors_enabled = bool(self._cfg("retrieval.relation_vectorization.enabled", False)) self.embedding_manager = None @@ -51,16 +126,30 @@ class SDKMemoryKernel: self.graph_store: Optional[GraphStore] = None self.metadata_store: Optional[MetadataStore] = None self.relation_write_service: Optional[RelationWriteService] = None - self.sparse_index = None - self.retriever: Optional[DualPathRetriever] = None + self.sparse_index: Optional[SparseBM25Index] = None + self.retriever = None + self.threshold_filter = None self.episode_retriever: Optional[EpisodeRetrievalService] = None self.aggregate_query_service: Optional[AggregateQueryService] = None + self.person_profile_service: Optional[PersonProfileService] = None + self.episode_segmentation_service: Optional[EpisodeSegmentationService] = None + self.episode_service: Optional[EpisodeService] = None + self.summary_importer: Optional[SummaryImporter] = None + self.import_task_manager: Optional[ImportTaskManager] = None + self.retrieval_tuning_manager: Optional[RetrievalTuningManager] = None + self._runtime_bundle: Optional[SearchRuntimeBundle] = None + self._runtime_facade = _KernelRuntimeFacade(self) self._initialized = False self._last_maintenance_at: Optional[float] = None + self._request_dedup_tasks: Dict[str, asyncio.Task] = {} + self._background_tasks: Dict[str, asyncio.Task] = {} + self._background_lock = asyncio.Lock() + self._background_stopping = False + self._active_person_timestamps: Dict[str, float] = {} def _cfg(self, key: str, default: Any = None) -> Any: current: Any = self.config - if key in {"storage", "embedding", "retrieval"} and isinstance(current, dict): + if key in {"storage", "embedding", "retrieval", "graph", "episode", "web", "advanced", "threshold", "summarization"} and isinstance(current, dict): return current.get(key, default) for part in key.split("."): if isinstance(current, dict) and part in current: @@ -69,34 +158,183 @@ class SDKMemoryKernel: return default return current + def _set_cfg(self, key: str, value: Any) -> None: + current: Dict[str, Any] = self.config + parts = [part for part in str(key or "").split(".") if part] + if not parts: + return + for part in parts[:-1]: + next_value = current.get(part) + if not isinstance(next_value, dict): + next_value = {} + current[part] = next_value + current = next_value + current[parts[-1]] = value + + def _build_runtime_config(self) -> Dict[str, Any]: + runtime_config = dict(self.config) + runtime_config.update( + { + "vector_store": self.vector_store, + "graph_store": self.graph_store, + "metadata_store": self.metadata_store, + "embedding_manager": self.embedding_manager, + "sparse_index": self.sparse_index, + "relation_write_service": self.relation_write_service, + "plugin_instance": self._runtime_facade, + } + ) + return runtime_config + + def is_runtime_ready(self) -> bool: + return bool( + self._initialized + and self.vector_store is not None + and self.graph_store is not None + and self.metadata_store is not None + and self.embedding_manager 
is not None + and self.retriever is not None + ) + + def is_chat_enabled(self, stream_id: str, group_id: str | None = None, user_id: str | None = None) -> bool: + filter_config = self._cfg("filter", {}) or {} + if not isinstance(filter_config, dict) or not filter_config: + return True + + if not bool(filter_config.get("enabled", True)): + return True + + mode = str(filter_config.get("mode", "blacklist") or "blacklist").strip().lower() + patterns = filter_config.get("chats") or [] + if not isinstance(patterns, list): + patterns = [] + + if not patterns: + return mode == "blacklist" + + stream_token = str(stream_id or "").strip() + group_token = str(group_id or "").strip() + user_token = str(user_id or "").strip() + candidates = {token for token in (stream_token, group_token, user_token) if token} + + matched = False + for raw_pattern in patterns: + pattern = str(raw_pattern or "").strip() + if not pattern: + continue + if ":" in pattern: + prefix, value = pattern.split(":", 1) + prefix = prefix.strip().lower() + value = value.strip() + if prefix == "group" and value and value == group_token: + matched = True + elif prefix in {"user", "private"} and value and value == user_token: + matched = True + elif prefix == "stream" and value and value == stream_token: + matched = True + elif pattern in candidates: + matched = True + + if matched: + break + + if mode == "blacklist": + return not matched + return matched + + def _is_chat_filtered( + self, + *, + respect_filter: bool, + stream_id: str = "", + group_id: str = "", + user_id: str = "", + ) -> bool: + if not bool(respect_filter): + return False + + stream_token = str(stream_id or "").strip() + group_token = str(group_id or "").strip() + user_token = str(user_id or "").strip() + if not (stream_token or group_token or user_token): + return False + return not self.is_chat_enabled(stream_token, group_token, user_token) + + def _stored_vector_dimension(self) -> Optional[int]: + meta_path = self.data_dir / "vectors" / "vectors_metadata.pkl" + if not meta_path.exists(): + return None + try: + with open(meta_path, "rb") as handle: + meta = pickle.load(handle) + except Exception as exc: + logger.warning(f"读取向量元数据失败,将回退到 runtime self-check: {exc}") + return None + try: + value = int(meta.get("dimension") or 0) + except Exception: + return None + return value if value > 0 else None + + def _vector_mismatch_error(self, *, stored_dimension: int, detected_dimension: int) -> str: + return ( + "检测到现有向量库与当前 embedding 输出维度不一致:" + f"stored={stored_dimension}, encoded={detected_dimension}。" + " 当前版本不会兼容 hash 时代或其他维度的旧向量,请改回原 embedding 配置," + "或执行重嵌入/重建向量。" + ) + async def initialize(self) -> None: if self._initialized: + await self._start_background_tasks() return + self.data_dir.mkdir(parents=True, exist_ok=True) self.embedding_manager = create_embedding_api_adapter( batch_size=int(self._cfg("embedding.batch_size", 32)), max_concurrent=int(self._cfg("embedding.max_concurrent", 5)), default_dimension=self.embedding_dimension, - model_name=str(self._cfg("embedding.model_name", "hash-v1")), + enable_cache=bool(self._cfg("embedding.enable_cache", False)), + model_name=str(self._cfg("embedding.model_name", "auto") or "auto"), retry_config=self._cfg("embedding.retry", {}) or {}, ) - self.embedding_dimension = int(await self.embedding_manager._detect_dimension()) + detected_dimension = int(await self.embedding_manager._detect_dimension()) + self.embedding_dimension = detected_dimension + + stored_dimension = self._stored_vector_dimension() + if stored_dimension is 
not None and stored_dimension != detected_dimension: + raise RuntimeError( + self._vector_mismatch_error( + stored_dimension=stored_dimension, + detected_dimension=detected_dimension, + ) + ) + + matrix_format = str(self._cfg("graph.sparse_matrix_format", "csr") or "csr").strip().lower() + graph_format = SparseMatrixFormat.CSC if matrix_format == "csc" else SparseMatrixFormat.CSR + self.vector_store = VectorStore( - dimension=self.embedding_dimension, + dimension=detected_dimension, quantization_type=QuantizationType.INT8, data_dir=self.data_dir / "vectors", ) - self.graph_store = GraphStore(matrix_format=SparseMatrixFormat.CSR, data_dir=self.data_dir / "graph") + self.graph_store = GraphStore(matrix_format=graph_format, data_dir=self.data_dir / "graph") self.metadata_store = MetadataStore(data_dir=self.data_dir / "metadata") self.metadata_store.connect() + if self.vector_store.has_data(): self.vector_store.load() self.vector_store.warmup_index(force_train=True) if self.graph_store.has_data(): self.graph_store.load() - sparse_cfg = self._cfg("retrieval.sparse", {}) or {} - self.sparse_index = SparseBM25Index(metadata_store=self.metadata_store, config=SparseBM25Config(**sparse_cfg)) + sparse_cfg_raw = self._cfg("retrieval.sparse", {}) or {} + try: + sparse_cfg = SparseBM25Config(**sparse_cfg_raw) + except Exception as exc: + logger.warning(f"sparse 配置非法,回退默认: {exc}") + sparse_cfg = SparseBM25Config() + self.sparse_index = SparseBM25Index(metadata_store=self.metadata_store, config=sparse_cfg) if getattr(self.sparse_index.config, "enabled", False): self.sparse_index.ensure_loaded() @@ -106,39 +344,133 @@ class SDKMemoryKernel: vector_store=self.vector_store, embedding_manager=self.embedding_manager, ) - self.retriever = DualPathRetriever( + + runtime_config = self._build_runtime_config() + self._runtime_bundle = build_search_runtime( + plugin_config=runtime_config, + logger_obj=logger, + owner_tag="sdk_kernel", + log_prefix="[sdk]", + ) + if not self._runtime_bundle.ready: + raise RuntimeError(self._runtime_bundle.error or "检索运行时初始化失败") + + self.retriever = self._runtime_bundle.retriever + self.threshold_filter = self._runtime_bundle.threshold_filter + self.sparse_index = self._runtime_bundle.sparse_index or self.sparse_index + + runtime_config = self._build_runtime_config() + self.episode_retriever = EpisodeRetrievalService(metadata_store=self.metadata_store, retriever=self.retriever) + self.aggregate_query_service = AggregateQueryService(plugin_config=runtime_config) + self.person_profile_service = PersonProfileService( + metadata_store=self.metadata_store, + graph_store=self.graph_store, + vector_store=self.vector_store, + embedding_manager=self.embedding_manager, + sparse_index=self.sparse_index, + plugin_config=runtime_config, + retriever=self.retriever, + ) + self.episode_segmentation_service = EpisodeSegmentationService(plugin_config=runtime_config) + self.episode_service = EpisodeService( + metadata_store=self.metadata_store, + plugin_config=runtime_config, + segmentation_service=self.episode_segmentation_service, + ) + self.summary_importer = SummaryImporter( vector_store=self.vector_store, graph_store=self.graph_store, metadata_store=self.metadata_store, embedding_manager=self.embedding_manager, - sparse_index=self.sparse_index, - config=DualPathRetrieverConfig( - top_k_paragraphs=int(self._cfg("retrieval.top_k_paragraphs", 24)), - top_k_relations=int(self._cfg("retrieval.top_k_relations", 12)), - top_k_final=int(self._cfg("retrieval.top_k_final", 10)), - 
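+        # (orientation note) the inline DualPathRetrieverConfig wiring removed
+        # in this hunk is superseded by build_search_runtime(...) further
+        # below, which now assembles retriever, threshold_filter and
+        # sparse_index as one SearchRuntimeBundle.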
alpha=float(self._cfg("retrieval.alpha", 0.5)), - enable_ppr=bool(self._cfg("retrieval.enable_ppr", True)), - ppr_alpha=float(self._cfg("retrieval.ppr_alpha", 0.85)), - ppr_concurrency_limit=int(self._cfg("retrieval.ppr_concurrency_limit", 4)), - enable_parallel=bool(self._cfg("retrieval.enable_parallel", True)), - sparse=sparse_cfg, - fusion=self._cfg("retrieval.fusion", {}) or {}, - graph_recall=self._cfg("retrieval.search.graph_recall", {}) or {}, - relation_intent=self._cfg("retrieval.search.relation_intent", {}) or {}, - ), + plugin_config=runtime_config, ) - self.episode_retriever = EpisodeRetrievalService(metadata_store=self.metadata_store, retriever=self.retriever) - self.aggregate_query_service = AggregateQueryService(plugin_config=self.config) + self.import_task_manager = ImportTaskManager(self._runtime_facade) + self.retrieval_tuning_manager = RetrievalTuningManager( + self._runtime_facade, + import_write_blocked_provider=self.import_task_manager.is_write_blocked, + ) + + report = await run_embedding_runtime_self_check( + config=runtime_config, + vector_store=self.vector_store, + embedding_manager=self.embedding_manager, + sample_text="A_Memorix runtime self check", + ) + self._runtime_facade._runtime_self_check_report = dict(report) + if not bool(report.get("ok", False)): + message = str(report.get("message", "runtime self-check failed") or "runtime self-check failed") + raise RuntimeError(f"{message};请改回原 embedding 配置,或执行重嵌入/重建向量。") + self._initialized = True + await self._start_background_tasks() + + async def shutdown(self) -> None: + await self._stop_background_tasks() + if self.import_task_manager is not None: + try: + await self.import_task_manager.shutdown() + except Exception as exc: + logger.warning(f"关闭导入任务管理器失败: {exc}") + if self.retrieval_tuning_manager is not None: + try: + await self.retrieval_tuning_manager.shutdown() + except Exception as exc: + logger.warning(f"关闭调优任务管理器失败: {exc}") + self.close() def close(self) -> None: - if self.vector_store is not None: - self.vector_store.save() - if self.graph_store is not None: - self.graph_store.save() - if self.metadata_store is not None: - self.metadata_store.close() - self._initialized = False + try: + self._persist() + finally: + if self.metadata_store is not None: + self.metadata_store.close() + self._initialized = False + self._request_dedup_tasks.clear() + self._runtime_facade._runtime_self_check_report = {} + self._background_tasks.clear() + self._active_person_timestamps.clear() + + async def execute_request_with_dedup( + self, + request_key: str, + executor: Callable[[], Awaitable[Dict[str, Any]]], + ) -> tuple[bool, Dict[str, Any]]: + token = str(request_key or "").strip() + if not token: + return False, await executor() + + existing = self._request_dedup_tasks.get(token) + if existing is not None: + return True, await existing + + task = asyncio.create_task(executor()) + self._request_dedup_tasks[token] = task + try: + payload = await task + return False, payload + finally: + current = self._request_dedup_tasks.get(token) + if current is task: + self._request_dedup_tasks.pop(token, None) + + async def summarize_chat_stream( + self, + *, + chat_id: str, + context_length: Optional[int] = None, + include_personality: Optional[bool] = None, + ) -> Dict[str, Any]: + await self.initialize() + assert self.summary_importer + success, detail = await self.summary_importer.import_from_stream( + stream_id=str(chat_id or "").strip(), + context_length=context_length, + include_personality=include_personality, + ) + if 
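+        # execute_request_with_dedup (above) coalesces concurrent identical
+        # requests: the first caller creates the task and gets (False, result);
+        # later callers awaiting the same key share that task and get
+        # (True, result). Caveat implied by the shared task: if the executor
+        # raises, every coalesced awaiter observes the same exception.
+        # Hedged usage sketch ("kernel" and "request" are hypothetical names):
+        #   deduped, payload = await kernel.execute_request_with_dedup(
+        #       f"search:{chat_id}:{query}",
+        #       lambda: kernel.search_memory(request),
+        #   )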
success: + await self.rebuild_episodes_for_sources([self._build_source("chat_summary", chat_id, [])]) + self._persist() + return {"success": bool(success), "detail": detail} async def ingest_summary( self, @@ -151,9 +483,35 @@ class SDKMemoryKernel: time_end: Optional[float] = None, tags: Optional[Sequence[str]] = None, metadata: Optional[Dict[str, Any]] = None, + respect_filter: bool = True, + user_id: str = "", + group_id: str = "", ) -> Dict[str, Any]: + external_token = str(external_id or "").strip() or compute_hash(f"chat_summary:{chat_id}:{text}") + if self._is_chat_filtered( + respect_filter=respect_filter, + stream_id=chat_id, + group_id=group_id, + user_id=user_id, + ): + return { + "success": True, + "stored_ids": [], + "skipped_ids": [external_token], + "detail": "chat_filtered", + } + summary_meta = dict(metadata or {}) summary_meta.setdefault("kind", "chat_summary") + if not str(text or "").strip() or bool(summary_meta.get("generate_from_chat", False)): + result = await self.summarize_chat_stream( + chat_id=chat_id, + context_length=self._optional_int(summary_meta.get("context_length")), + include_personality=summary_meta.get("include_personality"), + ) + result.setdefault("external_id", external_id) + result.setdefault("chat_id", chat_id) + return result return await self.ingest_text( external_id=external_id, source_type="chat_summary", @@ -164,6 +522,9 @@ class SDKMemoryKernel: time_end=time_end, tags=tags, metadata=summary_meta, + respect_filter=respect_filter, + user_id=user_id, + group_id=group_id, ) async def ingest_text( @@ -182,15 +543,42 @@ class SDKMemoryKernel: metadata: Optional[Dict[str, Any]] = None, entities: Optional[Sequence[str]] = None, relations: Optional[Sequence[Dict[str, Any]]] = None, + respect_filter: bool = True, + user_id: str = "", + group_id: str = "", ) -> Dict[str, Any]: - await self.initialize() - assert self.metadata_store and self.vector_store and self.graph_store and self.embedding_manager - assert self.relation_write_service content = normalize_text(text) + external_token = str(external_id or "").strip() or compute_hash(f"{source_type}:{chat_id}:{content}") + if self._is_chat_filtered( + respect_filter=respect_filter, + stream_id=chat_id, + group_id=group_id, + user_id=user_id, + ): + return { + "success": True, + "stored_ids": [], + "skipped_ids": [external_token], + "detail": "chat_filtered", + } + + await self.initialize() + assert self.metadata_store is not None + assert self.vector_store is not None + assert self.graph_store is not None + assert self.embedding_manager is not None + assert self.relation_write_service is not None + if not content: - return {"stored_ids": [], "skipped_ids": [external_id], "reason": "empty_text"} - if ref := self.metadata_store.get_external_memory_ref(external_id): - return {"stored_ids": [], "skipped_ids": [str(ref.get("paragraph_hash", ""))], "reason": "exists"} + return {"stored_ids": [], "skipped_ids": [external_token], "reason": "empty_text"} + + existing_ref = self.metadata_store.get_external_memory_ref(external_token) + if existing_ref: + return { + "stored_ids": [], + "skipped_ids": [str(existing_ref.get("paragraph_hash", "") or "")], + "reason": "exists", + } person_tokens = self._tokens(person_ids) participant_tokens = self._tokens(participants) @@ -199,7 +587,7 @@ class SDKMemoryKernel: paragraph_meta = dict(metadata or {}) paragraph_meta.update( { - "external_id": external_id, + "external_id": external_token, "source_type": str(source_type or "").strip(), "chat_id": str(chat_id or "").strip(), 
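+                # (note) these normalized fields are what later metadata and
+                # sparse-index filters key on; person_ids/participants/tags
+                # all pass through the shared self._tokens(...) helper.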
"person_ids": person_tokens, @@ -207,146 +595,303 @@ class SDKMemoryKernel: "tags": self._tokens(tags), } ) + paragraph_hash = self.metadata_store.add_paragraph( content=content, source=source, metadata=paragraph_meta, - knowledge_type="factual" if source_type == "person_fact" else "narrative" if source_type == "chat_summary" else "mixed", + knowledge_type=self._resolve_knowledge_type(source_type), time_meta=self._time_meta(timestamp, time_start, time_end), ) embedding = await self.embedding_manager.encode(content) self.vector_store.add(vectors=embedding.reshape(1, -1), ids=[paragraph_hash]) + for name in entity_tokens: self.metadata_store.add_entity(name=name, source_paragraph=paragraph_hash) stored_relations: List[str] = [] for row in [dict(item) for item in (relations or []) if isinstance(item, dict)]: - s = str(row.get("subject", "") or "").strip() - p = str(row.get("predicate", "") or "").strip() - o = str(row.get("object", "") or "").strip() - if not (s and p and o): + subject = str(row.get("subject", "") or "").strip() + predicate = str(row.get("predicate", "") or "").strip() + obj = str(row.get("object", "") or "").strip() + if not (subject and predicate and obj): continue result = await self.relation_write_service.upsert_relation_with_vector( - subject=s, - predicate=p, - obj=o, + subject=subject, + predicate=predicate, + obj=obj, confidence=float(row.get("confidence", 1.0) or 1.0), source_paragraph=paragraph_hash, - metadata={"external_id": external_id, "source_type": source_type}, + metadata=row.get("metadata") if isinstance(row.get("metadata"), dict) else {"external_id": external_token, "source_type": source_type}, write_vector=self.relation_vectors_enabled, ) self.metadata_store.link_paragraph_relation(paragraph_hash, result.hash_value) stored_relations.append(result.hash_value) self.metadata_store.upsert_external_memory_ref( - external_id=external_id, + external_id=external_token, paragraph_hash=paragraph_hash, source_type=source_type, metadata={"chat_id": chat_id, "person_ids": person_tokens}, ) + self.metadata_store.enqueue_episode_pending(paragraph_hash, source=source) self._persist() - self.rebuild_episodes_for_sources([source]) + await self.process_episode_pending_batch( + limit=max(1, int(self._cfg("episode.pending_batch_size", 12))), + max_retry=max(1, int(self._cfg("episode.pending_max_retry", 3))), + ) for person_id in person_tokens: + self._mark_person_active(person_id) await self.refresh_person_profile(person_id) return {"stored_ids": [paragraph_hash, *stored_relations], "skipped_ids": []} - async def search_memory(self, request: KernelSearchRequest) -> Dict[str, Any]: + async def process_episode_pending_batch(self, *, limit: int = 20, max_retry: int = 3) -> Dict[str, Any]: await self.initialize() - assert self.retriever and self.episode_retriever and self.aggregate_query_service + assert self.metadata_store is not None + assert self.episode_service is not None + + pending_rows = self.metadata_store.fetch_episode_pending_batch(limit=max(1, int(limit)), max_retry=max(1, int(max_retry))) + if not pending_rows: + return {"processed": 0, "episode_count": 0, "fallback_count": 0, "failed": 0} + + source_to_hashes: Dict[str, List[str]] = {} + pending_hashes = [str(row.get("paragraph_hash", "") or "").strip() for row in pending_rows if str(row.get("paragraph_hash", "") or "").strip()] + for row in pending_rows: + paragraph_hash = str(row.get("paragraph_hash", "") or "").strip() + source = str(row.get("source", "") or "").strip() + if not paragraph_hash or not source: + 
continue + source_to_hashes.setdefault(source, []).append(paragraph_hash) + + if pending_hashes: + self.metadata_store.mark_episode_pending_running(pending_hashes) + + result = await self.episode_service.process_pending_rows(pending_rows) + done_hashes = [str(item or "").strip() for item in result.get("done_hashes", []) if str(item or "").strip()] + failed_hashes = { + str(hash_value or "").strip(): str(error or "").strip() + for hash_value, error in (result.get("failed_hashes", {}) or {}).items() + if str(hash_value or "").strip() + } + + if done_hashes: + self.metadata_store.mark_episode_pending_done(done_hashes) + for hash_value, error in failed_hashes.items(): + self.metadata_store.mark_episode_pending_failed(hash_value, error) + + untouched = [hash_value for hash_value in pending_hashes if hash_value not in set(done_hashes) and hash_value not in failed_hashes] + for hash_value in untouched: + self.metadata_store.mark_episode_pending_failed(hash_value, "episode processing finished without explicit status") + + for source, paragraph_hashes in source_to_hashes.items(): + counts = self.metadata_store.get_episode_pending_status_counts(source) + if counts.get("failed", 0) > 0: + source_error = next( + ( + failed_hashes.get(hash_value) + for hash_value in paragraph_hashes + if failed_hashes.get(hash_value) + ), + "episode pending source contains failed rows", + ) + self.metadata_store.mark_episode_source_failed(source, str(source_error or "episode pending source contains failed rows")) + elif counts.get("pending", 0) == 0 and counts.get("running", 0) == 0: + self.metadata_store.mark_episode_source_done(source) + + self._persist() + return { + "processed": len(done_hashes) + len(failed_hashes), + "episode_count": int(result.get("episode_count") or 0), + "fallback_count": int(result.get("fallback_count") or 0), + "failed": len(failed_hashes) + len(untouched), + "group_count": int(result.get("group_count") or 0), + "missing_count": int(result.get("missing_count") or 0), + } + + async def search_memory(self, request: KernelSearchRequest) -> Dict[str, Any]: + if self._is_chat_filtered( + respect_filter=request.respect_filter, + stream_id=request.chat_id, + group_id=request.group_id, + user_id=request.user_id, + ): + return {"summary": "", "hits": [], "filtered": True} + + await self.initialize() + assert self.retriever is not None + assert self.episode_retriever is not None + assert self.aggregate_query_service is not None + mode = str(request.mode or "hybrid").strip().lower() or "hybrid" - clean_query = str(request.query or "").strip() + query = str(request.query or "").strip() limit = max(1, int(request.limit or 5)) - temporal = self._temporal(request) + try: + time_window = self._normalize_search_time_window(request.time_start, request.time_end) + except ValueError as exc: + return {"summary": "", "hits": [], "error": str(exc)} + if mode == "episode": rows = await self.episode_retriever.query( - query=clean_query, + query=query, top_k=limit, - time_from=request.time_start, - time_to=request.time_end, + time_from=time_window.numeric_start, + time_to=time_window.numeric_end, + person=request.person_id or None, source=self._chat_source(request.chat_id), ) hits = [self._episode_hit(row) for row in rows] return {"summary": self._summary(hits), "hits": hits} + if mode == "aggregate": payload = await self.aggregate_query_service.execute( - query=clean_query, + query=query, top_k=limit, mix=True, mix_top_k=limit, - time_from=str(request.time_start) if request.time_start is not None else None, - 
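+            # Pending-queue state machine implemented above:
+            #   pending -> running  (fetch_episode_pending_batch + mark_..._running)
+            #   running -> done     (hash listed in result["done_hashes"])
+            #   running -> failed   (hash in result["failed_hashes"], or left
+            #                        without an explicit status by the service)
+            # A source is marked failed as soon as any of its rows failed, and
+            # marked done only once it has no pending or running rows left.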
time_to=str(request.time_end) if request.time_end is not None else None, - search_runner=lambda: self._aggregate_search(clean_query, limit, temporal), - time_runner=lambda: self._aggregate_time(clean_query, limit, temporal), - episode_runner=lambda: self._aggregate_episode(clean_query, limit, request), + time_from=time_window.query_start, + time_to=time_window.query_end, + search_runner=lambda: self._aggregate_search(query, limit, request), + time_runner=lambda: self._aggregate_time(query, limit, request, time_window), + episode_runner=lambda: self._aggregate_episode(query, limit, request, time_window), ) hits = [dict(item) for item in payload.get("mixed_results", []) if isinstance(item, dict)] for item in hits: item.setdefault("metadata", {}) - return {"summary": self._summary(hits), "hits": hits} - results = await self.retriever.retrieve(query=clean_query, top_k=limit, temporal=temporal) - hits = [self._retrieval_hit(item) for item in results] - return {"summary": self._summary(self._filter_hits(hits, request.person_id)), "hits": self._filter_hits(hits, request.person_id)} + filtered = self._filter_hits(hits, request.person_id) + return {"summary": self._summary(filtered), "hits": filtered} + + query_type = "search" if mode in {"search", "semantic"} else mode + runtime_config = self._build_runtime_config() + result = await SearchExecutionService.execute( + retriever=self.retriever, + threshold_filter=self.threshold_filter, + plugin_config=runtime_config, + request=SearchExecutionRequest( + caller="sdk_memory_kernel", + stream_id=str(request.chat_id or "") or None, + group_id=str(request.group_id or "") or None, + user_id=str(request.user_id or "") or None, + query_type=query_type, + query=query, + top_k=limit, + time_from=time_window.query_start, + time_to=time_window.query_end, + person=str(request.person_id or "") or None, + source=self._chat_source(request.chat_id), + use_threshold=True, + enable_ppr=bool(self._cfg("retrieval.enable_ppr", True)), + ), + enforce_chat_filter=bool(request.respect_filter), + reinforce_access=True, + ) + if not result.success: + return {"summary": "", "hits": [], "error": result.error} + if result.chat_filtered: + return {"summary": "", "hits": [], "filtered": True} + + hits = [self._retrieval_result_hit(item) for item in result.results] + filtered = self._filter_hits(hits, request.person_id) + return {"summary": self._summary(filtered), "hits": filtered} async def get_person_profile(self, *, person_id: str, chat_id: str = "", limit: int = 10) -> Dict[str, Any]: - _ = chat_id + del chat_id await self.initialize() - assert self.metadata_store - snapshot = self.metadata_store.get_latest_person_profile_snapshot(person_id) or await self.refresh_person_profile(person_id, limit=limit) + assert self.metadata_store is not None + assert self.person_profile_service is not None + self._mark_person_active(person_id) + profile = await self.person_profile_service.query_person_profile( + person_id=person_id, + top_k=max(4, int(limit or 10)), + source_note="sdk_memory_kernel.get_person_profile", + ) + if not profile.get("success"): + return {"summary": "", "traits": [], "evidence": []} + evidence = [] - for hash_value in snapshot.get("evidence_ids", [])[: max(1, int(limit))]: + for hash_value in profile.get("evidence_ids", [])[: max(1, int(limit))]: paragraph = self.metadata_store.get_paragraph(hash_value) if paragraph is not None: - evidence.append({"hash": hash_value, "content": str(paragraph.get("content", "") or "")[:220], "metadata": paragraph.get("metadata", {}) or 
{}}) - text = str(snapshot.get("profile_text", "") or "").strip() + evidence.append( + { + "hash": hash_value, + "content": str(paragraph.get("content", "") or "")[:220], + "metadata": paragraph.get("metadata", {}) or {}, + "type": "paragraph", + } + ) + continue + + relation = self.metadata_store.get_relation(hash_value) + if relation is not None: + evidence.append( + { + "hash": hash_value, + "content": " ".join( + [ + str(relation.get("subject", "") or "").strip(), + str(relation.get("predicate", "") or "").strip(), + str(relation.get("object", "") or "").strip(), + ] + ).strip(), + "metadata": { + "confidence": relation.get("confidence"), + "source_paragraph": relation.get("source_paragraph"), + }, + "type": "relation", + } + ) + + text = str(profile.get("profile_text", "") or "").strip() traits = [line.strip("- ").strip() for line in text.splitlines() if line.strip()][:8] - return {"summary": text, "traits": traits, "evidence": evidence} + return { + "summary": text, + "traits": traits, + "evidence": evidence, + "person_id": str(profile.get("person_id", "") or person_id), + "person_name": str(profile.get("person_name", "") or ""), + "profile_source": str(profile.get("profile_source", "") or "auto_snapshot"), + "has_manual_override": bool(profile.get("has_manual_override", False)), + } - async def refresh_person_profile(self, person_id: str, limit: int = 10) -> Dict[str, Any]: + async def refresh_person_profile(self, person_id: str, limit: int = 10, *, mark_active: bool = True) -> Dict[str, Any]: await self.initialize() - assert self.metadata_store - rows = self.metadata_store.query( - """ - SELECT DISTINCT p.* - FROM paragraphs p - JOIN paragraph_entities pe ON pe.paragraph_hash = p.hash - JOIN entities e ON e.hash = pe.entity_hash - WHERE e.name = ? - AND (p.is_deleted IS NULL OR p.is_deleted = 0) - ORDER BY COALESCE(p.event_time_end, p.event_time_start, p.event_time, p.updated_at, p.created_at) DESC - LIMIT ? 
- """, - (person_id, max(1, int(limit)) * 3), - ) - evidence_ids = [str(row.get("hash", "") or "") for row in rows if str(row.get("hash", "")).strip()] - vector_evidence = [{"hash": str(row.get("hash", "") or ""), "type": "paragraph", "score": 0.0, "content": str(row.get("content", "") or "")[:220], "metadata": row.get("metadata", {}) or {}} for row in rows[: max(1, int(limit))]] - relation_edges = [{"hash": str(row.get("hash", "") or ""), "subject": str(row.get("subject", "") or ""), "predicate": str(row.get("predicate", "") or ""), "object": str(row.get("object", "") or ""), "confidence": float(row.get("confidence", 1.0) or 1.0)} for row in self.metadata_store.get_relations(subject=person_id)[:limit]] - if relation_edges: - profile_text = "\n".join(f"{item['subject']} {item['predicate']} {item['object']}" for item in relation_edges[:6]) - elif vector_evidence: - profile_text = "\n".join(f"- {item['content']}" for item in vector_evidence[:6]) - else: - profile_text = "暂无稳定画像证据。" - return self.metadata_store.upsert_person_profile_snapshot( + assert self.person_profile_service + if mark_active: + self._mark_person_active(person_id) + profile = await self.person_profile_service.query_person_profile( person_id=person_id, - profile_text=profile_text, - aliases=[person_id], - relation_edges=relation_edges, - vector_evidence=vector_evidence, - evidence_ids=evidence_ids[: max(1, int(limit))], - expires_at=time.time() + 6 * 3600, - source_note="sdk_memory_kernel", + top_k=max(4, int(limit or 10)), + force_refresh=True, + source_note="sdk_memory_kernel.refresh_person_profile", ) + return profile if isinstance(profile, dict) else {} - async def maintain_memory(self, *, action: str, target: str, hours: Optional[float] = None, reason: str = "") -> Dict[str, Any]: - _ = reason + async def maintain_memory( + self, + *, + action: str, + target: str = "", + hours: Optional[float] = None, + reason: str = "", + limit: int = 50, + ) -> Dict[str, Any]: + del reason await self.initialize() assert self.metadata_store - hashes = self._resolve_relation_hashes(target) + act = str(action or "").strip().lower() + if act == "recycle_bin": + items = self.metadata_store.get_deleted_relations(limit=max(1, int(limit or 50))) + return {"success": True, "items": items, "count": len(items)} + + hashes = self._resolve_deleted_relation_hashes(target) if act == "restore" else self._resolve_relation_hashes(target) if not hashes: return {"success": False, "detail": "未命中可维护关系"} - act = str(action or "").strip().lower() + if act == "reinforce": self.metadata_store.reinforce_relations(hashes) + elif act == "freeze": + self.metadata_store.mark_relations_inactive(hashes) + self._rebuild_graph_from_metadata() elif act == "protect": ttl_seconds = max(0.0, float(hours or 0.0)) * 3600.0 self.metadata_store.protect_relations(hashes, ttl_seconds=ttl_seconds, is_pinned=ttl_seconds <= 0) @@ -354,72 +899,679 @@ class SDKMemoryKernel: restored = sum(1 for hash_value in hashes if self.metadata_store.restore_relation(hash_value)) if restored <= 0: return {"success": False, "detail": "未恢复任何关系"} + self._rebuild_graph_from_metadata() else: return {"success": False, "detail": f"不支持的维护动作: {act}"} + self._last_maintenance_at = time.time() self._persist() return {"success": True, "detail": f"{act} {len(hashes)} 条关系"} - def rebuild_episodes_for_sources(self, sources: Iterable[str]) -> int: - assert self.metadata_store - rebuilt = 0 + async def rebuild_episodes_for_sources(self, sources: Iterable[str]) -> Dict[str, Any]: + await self.initialize() + 
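+        # maintain_memory action summary (dispatch above):
+        #   reinforce    strengthen the matched relations
+        #   freeze       mark relations inactive, then rebuild the graph
+        #   protect      TTL-protect for `hours`; hours <= 0 pins permanently
+        #   restore      undelete from the recycle bin, then rebuild the graph
+        #   recycle_bin  list soft-deleted relations (needs no target)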
assert self.metadata_store is not None + assert self.episode_service is not None + + items: List[Dict[str, Any]] = [] + failures: List[Dict[str, str]] = [] for source in self._tokens(sources): - rows = self.metadata_store.query( - """ - SELECT * FROM paragraphs - WHERE source = ? - AND (is_deleted IS NULL OR is_deleted = 0) - ORDER BY COALESCE(event_time_start, event_time, created_at) ASC, hash ASC - """, - (source,), - ) - if not rows: - continue - paragraph_hashes = [str(row.get("hash", "") or "") for row in rows if str(row.get("hash", "")).strip()] - payload = self.metadata_store.upsert_episode( - { - "source": source, - "title": str((rows[0].get("metadata", {}) or {}).get("theme", "") or f"{source} 情景记忆")[:80], - "summary": ";".join(str(row.get("content", "") or "").strip().replace("\n", " ")[:120] for row in rows[:3] if str(row.get("content", "") or "").strip())[:500] or "自动构建的情景记忆。", - "participants": self._episode_participants(rows), - "keywords": self._episode_keywords(rows), - "evidence_ids": paragraph_hashes, - "paragraph_count": len(paragraph_hashes), - "event_time_start": self._time_bound(rows, "event_time_start", "event_time", reverse=False), - "event_time_end": self._time_bound(rows, "event_time_end", "event_time", reverse=True), - "time_granularity": "day", - "time_confidence": 0.7, - "llm_confidence": 0.0, - "segmentation_model": "rule_based_sdk", - "segmentation_version": "1", - } - ) - self.metadata_store.bind_episode_paragraphs(payload["episode_id"], paragraph_hashes) - rebuilt += 1 - return rebuilt + self.metadata_store.mark_episode_source_running(source) + try: + result = await self.episode_service.rebuild_source(source) + self.metadata_store.mark_episode_source_done(source) + items.append(result) + except Exception as exc: + err = str(exc)[:500] + self.metadata_store.mark_episode_source_failed(source, err) + failures.append({"source": source, "error": err}) + self._persist() + return { + "rebuilt": len(items), + "items": items, + "failures": failures, + "sources": [str(item.get("source", "") or "") for item in items] or self._tokens(sources), + } def memory_stats(self) -> Dict[str, Any]: assert self.metadata_store stats = self.metadata_store.get_statistics() episodes = self.metadata_store.query("SELECT COUNT(*) AS c FROM episodes")[0]["c"] profiles = self.metadata_store.query("SELECT COUNT(*) AS c FROM person_profile_snapshots")[0]["c"] - return {"paragraphs": int(stats.get("paragraph_count", 0) or 0), "relations": int(stats.get("relation_count", 0) or 0), "episodes": int(episodes or 0), "profiles": int(profiles or 0), "last_maintenance_at": self._last_maintenance_at} + pending = self.metadata_store.query( + "SELECT COUNT(*) AS c FROM episode_pending_paragraphs WHERE status IN ('pending', 'running', 'failed')" + )[0]["c"] + return { + "paragraphs": int(stats.get("paragraph_count", 0) or 0), + "relations": int(stats.get("relation_count", 0) or 0), + "episodes": int(episodes or 0), + "profiles": int(profiles or 0), + "episode_pending": int(pending or 0), + "last_maintenance_at": self._last_maintenance_at, + } - async def _aggregate_search(self, query: str, limit: int, temporal: Optional[TemporalQueryOptions]) -> Dict[str, Any]: - assert self.retriever - hits = [self._retrieval_hit(item) for item in await self.retriever.retrieve(query=query, top_k=limit, temporal=temporal)] - return {"success": True, "results": hits, "count": len(hits), "query_type": "search"} + async def memory_graph_admin(self, *, action: str, **kwargs) -> Dict[str, Any]: + await self.initialize() + 
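+        # All memory_*_admin entry points share one convention: a single
+        # `action` discriminator plus loosely typed kwargs, and a
+        # {"success": bool, ...} payload instead of raising, so the SDK tool
+        # layer can surface failures across the plugin boundary verbatim.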
assert self.metadata_store is not None + assert self.graph_store is not None - async def _aggregate_time(self, query: str, limit: int, temporal: Optional[TemporalQueryOptions]) -> Dict[str, Any]: - if temporal is None: - return {"success": False, "error": "missing temporal window", "results": []} - assert self.retriever - hits = [self._retrieval_hit(item) for item in await self.retriever.retrieve(query=query, top_k=limit, temporal=temporal)] - return {"success": True, "results": hits, "count": len(hits), "query_type": "time"} + act = str(action or "").strip().lower() + if act == "get_graph": + return {"success": True, **self._serialize_graph(limit=max(1, int(kwargs.get("limit", 200) or 200)))} - async def _aggregate_episode(self, query: str, limit: int, request: KernelSearchRequest) -> Dict[str, Any]: + if act == "create_node": + name = str(kwargs.get("name", "") or kwargs.get("node", "") or "").strip() + if not name: + return {"success": False, "error": "node name 不能为空"} + entity_hash = self.metadata_store.add_entity(name=name, metadata=kwargs.get("metadata") or {}) + self._rebuild_graph_from_metadata() + self._persist() + return {"success": True, "node": {"name": name, "hash": entity_hash}} + + if act == "delete_node": + name = str(kwargs.get("name", "") or kwargs.get("node", "") or kwargs.get("hash_or_name", "") or "").strip() + if not name: + return {"success": False, "error": "node name 不能为空"} + result = await self._execute_delete_action( + mode="entity", + selector={"query": name}, + requested_by=str(kwargs.get("requested_by", "") or "memory_graph_admin"), + reason=str(kwargs.get("reason", "") or "graph_delete_node"), + ) + return { + "success": bool(result.get("success", False)), + "deleted": bool(result.get("deleted_count", 0)), + "node": name, + "operation_id": result.get("operation_id", ""), + "counts": result.get("counts", {}), + "error": result.get("error", ""), + } + + if act == "rename_node": + old_name = str(kwargs.get("name", "") or kwargs.get("old_name", "") or kwargs.get("node", "") or "").strip() + new_name = str(kwargs.get("new_name", "") or kwargs.get("target_name", "") or "").strip() + return self._rename_node(old_name, new_name) + + if act == "create_edge": + subject = str(kwargs.get("subject", "") or kwargs.get("source", "") or "").strip() + predicate = str(kwargs.get("predicate", "") or kwargs.get("label", "") or "").strip() + obj = str(kwargs.get("object", "") or kwargs.get("target", "") or "").strip() + if not all([subject, predicate, obj]): + return {"success": False, "error": "subject/predicate/object 不能为空"} + if self.relation_write_service is not None: + result = await self.relation_write_service.upsert_relation_with_vector( + subject=subject, + predicate=predicate, + obj=obj, + confidence=float(kwargs.get("confidence", 1.0) or 1.0), + source_paragraph=str(kwargs.get("source_paragraph", "") or "") or None, + metadata=kwargs.get("metadata") or {}, + write_vector=self.relation_vectors_enabled, + ) + relation_hash = result.hash_value + else: + relation_hash = self.metadata_store.add_relation( + subject=subject, + predicate=predicate, + obj=obj, + confidence=float(kwargs.get("confidence", 1.0) or 1.0), + source_paragraph=kwargs.get("source_paragraph"), + metadata=kwargs.get("metadata") or {}, + ) + self._rebuild_graph_from_metadata() + self._persist() + return { + "success": True, + "edge": { + "hash": relation_hash, + "subject": subject, + "predicate": predicate, + "object": obj, + "weight": float(kwargs.get("confidence", 1.0) or 1.0), + }, + } + + if act == 
"delete_edge": + relation_hash = str(kwargs.get("hash", "") or kwargs.get("relation_hash", "") or "").strip() + if relation_hash: + result = await self._execute_delete_action( + mode="relation", + selector={"query": relation_hash}, + requested_by=str(kwargs.get("requested_by", "") or "memory_graph_admin"), + reason=str(kwargs.get("reason", "") or "graph_delete_edge"), + ) + return { + "success": bool(result.get("success", False)), + "deleted": int(result.get("deleted_count", 0)), + "hash": relation_hash, + "operation_id": result.get("operation_id", ""), + "counts": result.get("counts", {}), + "error": result.get("error", ""), + } + + subject = str(kwargs.get("subject", "") or kwargs.get("source", "") or "").strip() + obj = str(kwargs.get("object", "") or kwargs.get("target", "") or "").strip() + deleted_hashes = [ + str(row.get("hash", "") or "") + for row in self.metadata_store.get_relations(subject=subject) + if str(row.get("object", "") or "").strip() == obj + ] + result = await self._execute_delete_action( + mode="relation", + selector={"hashes": deleted_hashes, "subject": subject, "object": obj}, + requested_by=str(kwargs.get("requested_by", "") or "memory_graph_admin"), + reason=str(kwargs.get("reason", "") or "graph_delete_edge"), + ) + return { + "success": bool(result.get("success", False)), + "deleted": int(result.get("deleted_count", 0)), + "subject": subject, + "object": obj, + "operation_id": result.get("operation_id", ""), + "counts": result.get("counts", {}), + "error": result.get("error", ""), + } + + if act == "update_edge_weight": + return self._update_edge_weight( + relation_hash=str(kwargs.get("hash", "") or kwargs.get("relation_hash", "") or "").strip(), + subject=str(kwargs.get("subject", "") or kwargs.get("source", "") or "").strip(), + obj=str(kwargs.get("object", "") or kwargs.get("target", "") or "").strip(), + weight=float(kwargs.get("weight", kwargs.get("confidence", 1.0)) or 1.0), + ) + + return {"success": False, "error": f"不支持的 graph action: {act}"} + + async def memory_source_admin(self, *, action: str, **kwargs) -> Dict[str, Any]: + await self.initialize() + assert self.metadata_store + + act = str(action or "").strip().lower() + if act == "list": + sources = self.metadata_store.get_all_sources() + items = [] + for row in sources: + source_name = str(row.get("source", "") or "").strip() + items.append( + { + **row, + "episode_rebuild_blocked": self.metadata_store.is_episode_source_query_blocked(source_name), + } + ) + return {"success": True, "items": items, "count": len(items)} + + if act == "delete": + source = str(kwargs.get("source", "") or "").strip() + return await self._execute_delete_action( + mode="source", + selector={"sources": [source]}, + requested_by=str(kwargs.get("requested_by", "") or "memory_source_admin"), + reason=str(kwargs.get("reason", "") or "source_delete"), + ) + + if act == "batch_delete": + return await self._execute_delete_action( + mode="source", + selector={"sources": list(kwargs.get("sources") or [])}, + requested_by=str(kwargs.get("requested_by", "") or "memory_source_admin"), + reason=str(kwargs.get("reason", "") or "source_batch_delete"), + ) + + return {"success": False, "error": f"不支持的 source action: {act}"} + + async def memory_episode_admin(self, *, action: str, **kwargs) -> Dict[str, Any]: + await self.initialize() + assert self.metadata_store + + act = str(action or "").strip().lower() + if act in {"query", "list"}: + items = self.metadata_store.query_episodes( + query=str(kwargs.get("query", "") or "").strip(), + 
time_from=self._optional_float(kwargs.get("time_start", kwargs.get("time_from"))), + time_to=self._optional_float(kwargs.get("time_end", kwargs.get("time_to"))), + person=str(kwargs.get("person_id", "") or kwargs.get("person", "") or "").strip() or None, + source=str(kwargs.get("source", "") or "").strip() or None, + limit=max(1, int(kwargs.get("limit", 20) or 20)), + ) + return {"success": True, "items": items, "count": len(items)} + + if act == "get": + episode_id = str(kwargs.get("episode_id", "") or "").strip() + if not episode_id: + return {"success": False, "error": "episode_id 不能为空"} + episode = self.metadata_store.get_episode_by_id(episode_id) + if episode is None: + return {"success": False, "error": "episode 不存在"} + episode["paragraphs"] = self.metadata_store.get_episode_paragraphs( + episode_id, + limit=max(1, int(kwargs.get("paragraph_limit", 100) or 100)), + ) + return {"success": True, "episode": episode} + + if act == "status": + summary = self.metadata_store.get_episode_source_rebuild_summary( + failed_limit=max(1, int(kwargs.get("limit", 20) or 20)) + ) + summary["pending_queue"] = self.metadata_store.query( + "SELECT COUNT(*) AS c FROM episode_pending_paragraphs WHERE status IN ('pending', 'running', 'failed')" + )[0]["c"] + return {"success": True, **summary} + + if act == "rebuild": + sources = self._tokens(kwargs.get("sources")) + if not sources: + source = str(kwargs.get("source", "") or "").strip() + if source: + sources = [source] + if not sources and bool(kwargs.get("all", False)): + sources = self.metadata_store.list_episode_sources_for_rebuild() + if not sources: + sources = [str(row.get("source", "") or "").strip() for row in self.metadata_store.get_all_sources()] + if not sources: + return {"success": False, "error": "未提供可重建的 source"} + result = await self.rebuild_episodes_for_sources(sources) + return {"success": len(result.get("failures", [])) == 0, **result} + + if act == "process_pending": + result = await self.process_episode_pending_batch( + limit=max(1, int(kwargs.get("limit", 20) or 20)), + max_retry=max(1, int(kwargs.get("max_retry", 3) or 3)), + ) + return {"success": True, **result} + + return {"success": False, "error": f"不支持的 episode action: {act}"} + + async def memory_profile_admin(self, *, action: str, **kwargs) -> Dict[str, Any]: + await self.initialize() + assert self.metadata_store is not None + assert self.person_profile_service is not None + + act = str(action or "").strip().lower() + if act == "query": + profile = await self.person_profile_service.query_person_profile( + person_id=str(kwargs.get("person_id", "") or "").strip(), + person_keyword=str(kwargs.get("person_keyword", "") or kwargs.get("keyword", "") or "").strip(), + top_k=max(1, int(kwargs.get("limit", kwargs.get("top_k", 12)) or 12)), + force_refresh=bool(kwargs.get("force_refresh", False)), + source_note="sdk_memory_kernel.memory_profile_admin.query", + ) + return profile if isinstance(profile, dict) else {"success": False, "error": "invalid profile payload"} + + if act == "list": + limit = max(1, int(kwargs.get("limit", 50) or 50)) + rows = self.metadata_store.query( + """ + SELECT s.person_id, s.profile_version, s.profile_text, s.updated_at, s.expires_at, s.source_note + FROM person_profile_snapshots s + JOIN ( + SELECT person_id, MAX(profile_version) AS max_version + FROM person_profile_snapshots + GROUP BY person_id + ) latest + ON latest.person_id = s.person_id + AND latest.max_version = s.profile_version + ORDER BY s.updated_at DESC + LIMIT ? 
+ """, + (limit,), + ) + items = [] + for row in rows: + person_id = str(row.get("person_id", "") or "").strip() + override = self.metadata_store.get_person_profile_override(person_id) + items.append( + { + "person_id": person_id, + "profile_version": int(row.get("profile_version", 0) or 0), + "profile_text": str(row.get("profile_text", "") or ""), + "updated_at": row.get("updated_at"), + "expires_at": row.get("expires_at"), + "source_note": str(row.get("source_note", "") or ""), + "has_manual_override": bool(override), + "manual_override": override, + } + ) + return {"success": True, "items": items, "count": len(items)} + + if act == "set_override": + person_id = str(kwargs.get("person_id", "") or "").strip() + override = self.metadata_store.set_person_profile_override( + person_id=person_id, + override_text=str(kwargs.get("override_text", "") or kwargs.get("text", "") or ""), + updated_by=str(kwargs.get("updated_by", "") or ""), + source=str(kwargs.get("source", "") or "memory_profile_admin"), + ) + return {"success": True, "override": override} + + if act == "delete_override": + person_id = str(kwargs.get("person_id", "") or "").strip() + deleted = self.metadata_store.delete_person_profile_override(person_id) + return {"success": bool(deleted), "deleted": bool(deleted), "person_id": person_id} + + return {"success": False, "error": f"不支持的 profile action: {act}"} + + async def memory_runtime_admin(self, *, action: str, **kwargs) -> Dict[str, Any]: + await self.initialize() + act = str(action or "").strip().lower() + + if act == "save": + self._persist() + return {"success": True, "saved": True, "data_dir": str(self.data_dir)} + + if act == "get_config": + return { + "success": True, + "config": self.config, + "data_dir": str(self.data_dir), + "embedding_dimension": int(self.embedding_dimension), + "auto_save": bool(self._cfg("advanced.enable_auto_save", True)), + "relation_vectors_enabled": bool(self.relation_vectors_enabled), + "runtime_ready": self.is_runtime_ready(), + } + + if act in {"self_check", "refresh_self_check"}: + report = await run_embedding_runtime_self_check( + config=self._build_runtime_config(), + vector_store=self.vector_store, + embedding_manager=self.embedding_manager, + sample_text=str(kwargs.get("sample_text", "") or "A_Memorix runtime self check"), + ) + self._runtime_facade._runtime_self_check_report = dict(report) + return {"success": bool(report.get("ok", False)), "report": report} + + if act == "set_auto_save": + enabled = bool(kwargs.get("enabled", False)) + self._set_cfg("advanced.enable_auto_save", enabled) + return {"success": True, "auto_save": enabled} + + return {"success": False, "error": f"不支持的 runtime action: {act}"} + + async def memory_import_admin(self, *, action: str, **kwargs) -> Dict[str, Any]: + await self.initialize() + manager = self.import_task_manager + if manager is None: + return {"success": False, "error": "import manager 未初始化"} + + act = str(action or "").strip().lower() + if act in {"settings", "get_settings", "get_guide"}: + return {"success": True, "settings": await manager.get_runtime_settings()} + if act in {"path_aliases", "get_path_aliases"}: + return {"success": True, "path_aliases": manager.get_path_aliases()} + if act in {"resolve_path", "resolve"}: + return await manager.resolve_path_request(kwargs) + if act == "create_upload": + task = await manager.create_upload_task( + list(kwargs.get("staged_files") or kwargs.get("files") or kwargs.get("uploads") or []), + kwargs, + ) + return {"success": True, "task": task} + if act == 
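+        # Import actions map onto an ImportTaskManager lifecycle: create_*
+        # enqueues a task, list/get/chunks inspect progress, and
+        # cancel/retry_failed drive terminal states. Hedged round-trip sketch
+        # (the exact task payload schema, e.g. a "task_id" field and the
+        # "text" kwarg, is assumed rather than pinned by this diff):
+        #   created = await kernel.memory_import_admin(
+        #       action="create_paste", text="some pasted notes")
+        #   status = await kernel.memory_import_admin(
+        #       action="get", task_id=created["task"]["task_id"])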
"create_paste": + return {"success": True, "task": await manager.create_paste_task(kwargs)} + if act == "create_raw_scan": + return {"success": True, "task": await manager.create_raw_scan_task(kwargs)} + if act == "create_lpmm_openie": + return {"success": True, "task": await manager.create_lpmm_openie_task(kwargs)} + if act == "create_lpmm_convert": + return {"success": True, "task": await manager.create_lpmm_convert_task(kwargs)} + if act == "create_temporal_backfill": + return {"success": True, "task": await manager.create_temporal_backfill_task(kwargs)} + if act == "create_maibot_migration": + return {"success": True, "task": await manager.create_maibot_migration_task(kwargs)} + if act == "list": + items = await manager.list_tasks(limit=max(1, int(kwargs.get("limit", 50) or 50))) + return {"success": True, "items": items, "count": len(items)} + if act == "get": + task = await manager.get_task( + str(kwargs.get("task_id", "") or ""), + include_chunks=bool(kwargs.get("include_chunks", False)), + ) + return {"success": task is not None, "task": task, "error": "" if task is not None else "任务不存在"} + if act in {"chunks", "get_chunks"}: + payload = await manager.get_chunks( + str(kwargs.get("task_id", "") or ""), + str(kwargs.get("file_id", "") or ""), + offset=max(0, int(kwargs.get("offset", 0) or 0)), + limit=max(1, int(kwargs.get("limit", 50) or 50)), + ) + return {"success": payload is not None, **(payload or {}), "error": "" if payload is not None else "任务或文件不存在"} + if act == "cancel": + task = await manager.cancel_task(str(kwargs.get("task_id", "") or "")) + return {"success": task is not None, "task": task, "error": "" if task is not None else "任务不存在"} + if act == "retry_failed": + overrides = kwargs.get("overrides") if isinstance(kwargs.get("overrides"), dict) else kwargs + task = await manager.retry_failed(str(kwargs.get("task_id", "") or ""), overrides=overrides) + return {"success": task is not None, "task": task, "error": "" if task is not None else "任务不存在"} + return {"success": False, "error": f"不支持的 import action: {act}"} + + async def memory_tuning_admin(self, *, action: str, **kwargs) -> Dict[str, Any]: + await self.initialize() + manager = self.retrieval_tuning_manager + if manager is None: + return {"success": False, "error": "tuning manager 未初始化"} + + act = str(action or "").strip().lower() + if act in {"settings", "get_settings"}: + return {"success": True, "settings": manager.get_runtime_settings()} + if act == "get_profile": + profile = manager.get_profile_snapshot() + return {"success": True, "profile": profile, "toml": manager.export_toml_snippet(profile)} + if act == "apply_profile": + profile = kwargs.get("profile") if isinstance(kwargs.get("profile"), dict) else kwargs + return {"success": True, **await manager.apply_profile(profile, reason=str(kwargs.get("reason", "manual") or "manual"))} + if act == "rollback_profile": + return {"success": True, **await manager.rollback_profile()} + if act == "export_profile": + profile = manager.get_profile_snapshot() + return {"success": True, "profile": profile, "toml": manager.export_toml_snippet(profile)} + if act == "create_task": + payload = kwargs.get("payload") if isinstance(kwargs.get("payload"), dict) else kwargs + return {"success": True, "task": await manager.create_task(payload)} + if act == "list_tasks": + items = await manager.list_tasks(limit=max(1, int(kwargs.get("limit", 50) or 50))) + return {"success": True, "items": items, "count": len(items)} + if act == "get_task": + task = await manager.get_task( + 
str(kwargs.get("task_id", "") or ""), + include_rounds=bool(kwargs.get("include_rounds", False)), + ) + return {"success": task is not None, "task": task, "error": "" if task is not None else "任务不存在"} + if act == "get_rounds": + payload = await manager.get_rounds( + str(kwargs.get("task_id", "") or ""), + offset=max(0, int(kwargs.get("offset", 0) or 0)), + limit=max(1, int(kwargs.get("limit", 50) or 50)), + ) + return {"success": payload is not None, **(payload or {}), "error": "" if payload is not None else "任务不存在"} + if act == "cancel": + task = await manager.cancel_task(str(kwargs.get("task_id", "") or "")) + return {"success": task is not None, "task": task, "error": "" if task is not None else "任务不存在"} + if act == "apply_best": + return {"success": True, **await manager.apply_best(str(kwargs.get("task_id", "") or ""))} + if act == "get_report": + report = await manager.get_report(str(kwargs.get("task_id", "") or ""), fmt=str(kwargs.get("format", "md") or "md")) + return {"success": report is not None, "report": report, "error": "" if report is not None else "任务不存在"} + return {"success": False, "error": f"不支持的 tuning action: {act}"} + + async def memory_v5_admin(self, *, action: str, **kwargs) -> Dict[str, Any]: + await self.initialize() + assert self.metadata_store + + act = str(action or "").strip().lower() + target = str(kwargs.get("target", "") or kwargs.get("query", "") or "").strip() + reason = str(kwargs.get("reason", "") or "").strip() + updated_by = str(kwargs.get("updated_by", "") or kwargs.get("requested_by", "") or "").strip() + limit = max(1, int(kwargs.get("limit", 50) or 50)) + + if act == "recycle_bin": + items = self.metadata_store.get_deleted_relations(limit=limit) + return {"success": True, "items": items, "count": len(items)} + + if act == "status": + return self._memory_v5_status(target=target, limit=limit) + + if act == "restore": + hashes = self._resolve_deleted_relation_hashes(target) + if not hashes: + return {"success": False, "error": "未命中可恢复关系"} + result = await self._restore_relation_hashes(hashes) + operation = self.metadata_store.record_v5_operation( + action=act, + target=target, + resolved_hashes=hashes, + reason=reason, + updated_by=updated_by, + result=result, + ) + return {"success": bool(result.get("restored_count", 0) > 0), "operation": operation, **result} + + hashes = self._resolve_relation_hashes(target) + if not hashes: + return {"success": False, "error": "未命中可维护关系"} + + result = self._apply_v5_relation_action( + action=act, + hashes=hashes, + strength=float(kwargs.get("strength", 1.0) or 1.0), + ) + operation = self.metadata_store.record_v5_operation( + action=act, + target=target, + resolved_hashes=hashes, + reason=reason, + updated_by=updated_by, + result=result, + ) + return {"success": bool(result.get("success", False)), "operation": operation, **result} + + async def memory_delete_admin(self, *, action: str, **kwargs) -> Dict[str, Any]: + await self.initialize() + act = str(action or "").strip().lower() + mode = str(kwargs.get("mode", "") or "").strip().lower() + selector = kwargs.get("selector") + if selector is None: + selector = { + key: value + for key, value in kwargs.items() + if key + not in { + "action", + "mode", + "dry_run", + "cascade", + "operation_id", + "reason", + "requested_by", + } + } + reason = str(kwargs.get("reason", "") or "").strip() + requested_by = str(kwargs.get("requested_by", "") or "").strip() + + if act == "preview": + return await self._preview_delete_action(mode=mode, selector=selector) + if act == 
"execute": + return await self._execute_delete_action( + mode=mode, + selector=selector, + requested_by=requested_by, + reason=reason, + ) + if act == "restore": + return await self._restore_delete_action( + mode=mode, + selector=selector, + operation_id=str(kwargs.get("operation_id", "") or "").strip(), + requested_by=requested_by, + reason=reason, + ) + if act == "get_operation": + operation = self.metadata_store.get_delete_operation(str(kwargs.get("operation_id", "") or "").strip()) + return {"success": operation is not None, "operation": operation, "error": "" if operation is not None else "operation 不存在"} + if act == "list_operations": + items = self.metadata_store.list_delete_operations( + limit=max(1, int(kwargs.get("limit", 50) or 50)), + mode=mode, + ) + return {"success": True, "items": items, "count": len(items)} + if act == "purge": + return await self._purge_deleted_memory( + grace_hours=self._optional_float(kwargs.get("grace_hours")), + limit=max(1, int(kwargs.get("limit", 1000) or 1000)), + ) + return {"success": False, "error": f"不支持的 delete action: {act}"} + + def get_import_task_manager(self) -> Optional[ImportTaskManager]: + return self.import_task_manager + + def get_retrieval_tuning_manager(self) -> Optional[RetrievalTuningManager]: + return self.retrieval_tuning_manager + + async def _aggregate_search(self, query: str, limit: int, request: KernelSearchRequest) -> Dict[str, Any]: + result = await SearchExecutionService.execute( + retriever=self.retriever, + threshold_filter=self.threshold_filter, + plugin_config=self._build_runtime_config(), + request=SearchExecutionRequest( + caller="sdk_memory_kernel.aggregate", + stream_id=str(request.chat_id or "") or None, + query_type="search", + query=query, + top_k=limit, + person=str(request.person_id or "") or None, + source=self._chat_source(request.chat_id), + use_threshold=True, + enable_ppr=bool(self._cfg("retrieval.enable_ppr", True)), + ), + enforce_chat_filter=False, + reinforce_access=True, + ) + hits = [self._retrieval_result_hit(item) for item in result.results] if result.success else [] + return {"success": result.success, "results": hits, "count": len(hits), "query_type": "search", "error": result.error} + + async def _aggregate_time( + self, + query: str, + limit: int, + request: KernelSearchRequest, + time_window: _NormalizedSearchTimeWindow, + ) -> Dict[str, Any]: + result = await SearchExecutionService.execute( + retriever=self.retriever, + threshold_filter=self.threshold_filter, + plugin_config=self._build_runtime_config(), + request=SearchExecutionRequest( + caller="sdk_memory_kernel.aggregate", + stream_id=str(request.chat_id or "") or None, + query_type="time", + query=query, + top_k=limit, + time_from=time_window.query_start, + time_to=time_window.query_end, + person=str(request.person_id or "") or None, + source=self._chat_source(request.chat_id), + use_threshold=True, + enable_ppr=bool(self._cfg("retrieval.enable_ppr", True)), + ), + enforce_chat_filter=False, + reinforce_access=True, + ) + hits = [self._retrieval_result_hit(item) for item in result.results] if result.success else [] + return {"success": result.success, "results": hits, "count": len(hits), "query_type": "time", "error": result.error} + + async def _aggregate_episode( + self, + query: str, + limit: int, + request: KernelSearchRequest, + time_window: _NormalizedSearchTimeWindow, + ) -> Dict[str, Any]: assert self.episode_retriever - rows = await self.episode_retriever.query(query=query, top_k=limit, time_from=request.time_start, 
time_to=request.time_end, source=self._chat_source(request.chat_id)) + rows = await self.episode_retriever.query( + query=query, + top_k=limit, + time_from=time_window.numeric_start, + time_to=time_window.numeric_end, + person=request.person_id or None, + source=self._chat_source(request.chat_id), + ) hits = [self._episode_hit(row) for row in rows] return {"success": True, "results": hits, "count": len(hits), "query_type": "episode"} @@ -431,6 +1583,494 @@ class SDKMemoryKernel: if self.sparse_index is not None and getattr(self.sparse_index.config, "enabled", False): self.sparse_index.ensure_loaded() + async def _start_background_tasks(self) -> None: + async with self._background_lock: + self._background_stopping = False + self._ensure_background_task("auto_save", self._auto_save_loop) + self._ensure_background_task("episode_pending", self._episode_pending_loop) + self._ensure_background_task("memory_maintenance", self._memory_maintenance_loop) + self._ensure_background_task("person_profile_refresh", self._person_profile_refresh_loop) + + def _ensure_background_task(self, name: str, factory: Callable[[], Awaitable[None]]) -> None: + task = self._background_tasks.get(name) + if task is not None and not task.done(): + return + self._background_tasks[name] = asyncio.create_task(factory(), name=f"A_Memorix.{name}") + + async def _stop_background_tasks(self) -> None: + async with self._background_lock: + self._background_stopping = True + tasks = [task for task in self._background_tasks.values() if task is not None and not task.done()] + for task in tasks: + task.cancel() + for task in tasks: + try: + await task + except asyncio.CancelledError: + pass + except Exception as exc: + logger.warning(f"后台任务退出异常: {exc}") + self._background_tasks.clear() + + async def _auto_save_loop(self) -> None: + try: + while not self._background_stopping: + interval_minutes = max(1.0, float(self._cfg("advanced.auto_save_interval_minutes", 5) or 5)) + await asyncio.sleep(interval_minutes * 60.0) + if self._background_stopping: + break + if bool(self._cfg("advanced.enable_auto_save", True)): + self._persist() + except asyncio.CancelledError: + raise + except Exception as exc: + logger.warning(f"auto_save loop 异常: {exc}") + + async def _episode_pending_loop(self) -> None: + try: + while not self._background_stopping: + await asyncio.sleep(60.0) + if self._background_stopping: + break + if not bool(self._cfg("episode.enabled", True)): + continue + if not bool(self._cfg("episode.generation_enabled", True)): + continue + await self.process_episode_pending_batch( + limit=max(1, int(self._cfg("episode.pending_batch_size", 20) or 20)), + max_retry=max(1, int(self._cfg("episode.pending_max_retry", 3) or 3)), + ) + except asyncio.CancelledError: + raise + except Exception as exc: + logger.warning(f"episode_pending loop 异常: {exc}") + + async def _person_profile_refresh_loop(self) -> None: + try: + while not self._background_stopping: + interval_minutes = max(1.0, float(self._cfg("person_profile.refresh_interval_minutes", 30) or 30)) + await asyncio.sleep(max(60.0, interval_minutes * 60.0)) + if self._background_stopping: + break + if not bool(self._cfg("person_profile.enabled", True)): + continue + active_window_hours = max(1.0, float(self._cfg("person_profile.active_window_hours", 72.0) or 72.0)) + max_refresh = max(1, int(self._cfg("person_profile.max_refresh_per_cycle", 50) or 50)) + cutoff = time.time() - active_window_hours * 3600.0 + candidates = [ + person_id + for person_id, seen_at in sorted( + 
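+                        # (loop pattern note) all four background loops share
+                        # one shape: sleep first, re-check _background_stopping,
+                        # re-read their interval from config every cycle (so
+                        # config edits apply without a restart), and re-raise
+                        # CancelledError so _stop_background_tasks() can join
+                        # them cleanly.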
self._active_person_timestamps.items(), + key=lambda item: item[1], + reverse=True, + ) + if seen_at >= cutoff + ][:max_refresh] + for person_id in candidates: + try: + await self.refresh_person_profile(person_id, limit=max(4, int(self._cfg("person_profile.top_k_evidence", 12) or 12)), mark_active=False) + except Exception as exc: + logger.warning(f"刷新人物画像失败: {exc}") + except asyncio.CancelledError: + raise + except Exception as exc: + logger.warning(f"person_profile_refresh loop 异常: {exc}") + + async def _memory_maintenance_loop(self) -> None: + try: + while not self._background_stopping: + interval_hours = max(1.0 / 60.0, float(self._cfg("memory.base_decay_interval_hours", 1.0) or 1.0)) + await asyncio.sleep(max(60.0, interval_hours * 3600.0)) + if self._background_stopping: + break + if not bool(self._cfg("memory.enabled", True)): + continue + await self._run_memory_maintenance_cycle(interval_hours=interval_hours) + except asyncio.CancelledError: + raise + except Exception as exc: + logger.warning(f"memory_maintenance loop 异常: {exc}") + + async def _run_memory_maintenance_cycle(self, *, interval_hours: float) -> None: + assert self.graph_store is not None + assert self.metadata_store is not None + half_life = float(self._cfg("memory.half_life_hours", 24.0) or 24.0) + if half_life > 0: + factor = 0.5 ** (float(interval_hours) / half_life) + self.graph_store.decay(factor) + + await self._process_freeze_and_prune() + await self._orphan_gc_phase() + self._last_maintenance_at = time.time() + self._persist() + + async def _process_freeze_and_prune(self) -> None: + assert self.metadata_store is not None + assert self.graph_store is not None + prune_threshold = max(0.0, float(self._cfg("memory.prune_threshold", 0.1) or 0.1)) + freeze_duration = max(0.0, float(self._cfg("memory.freeze_duration_hours", 24.0) or 24.0)) * 3600.0 + now = time.time() + + low_edges = self.graph_store.get_low_weight_edges(prune_threshold) + hashes_to_freeze: List[str] = [] + edges_to_deactivate: List[tuple[str, str]] = [] + for src, tgt in low_edges: + relation_hashes = list(self.graph_store.get_relation_hashes_for_edge(src, tgt)) + if not relation_hashes: + continue + statuses = self.metadata_store.get_relation_status_batch(relation_hashes) + current_hashes: List[str] = [] + protected = False + for hash_value, status in statuses.items(): + if bool(status.get("is_pinned")) or float(status.get("protected_until") or 0.0) > now: + protected = True + break + current_hashes.append(hash_value) + if protected or not current_hashes: + continue + hashes_to_freeze.extend(current_hashes) + edges_to_deactivate.append((src, tgt)) + + if hashes_to_freeze: + self.metadata_store.mark_relations_inactive(hashes_to_freeze, inactive_since=now) + self.graph_store.deactivate_edges(edges_to_deactivate) + + cutoff = now - freeze_duration + expired_hashes = self.metadata_store.get_prune_candidates(cutoff) + if not expired_hashes: + return + relation_info = self.metadata_store.get_relations_subject_object_map(expired_hashes) + operations = [(src, tgt, hash_value) for hash_value, (src, tgt) in relation_info.items()] + if operations: + self.graph_store.prune_relation_hashes(operations) + deleted_hashes = [hash_value for hash_value in expired_hashes if hash_value in relation_info] + if deleted_hashes: + self.metadata_store.backup_and_delete_relations(deleted_hashes) + if self.vector_store is not None: + self.vector_store.delete(deleted_hashes) + + async def _orphan_gc_phase(self) -> None: + assert self.metadata_store is not None + assert 
self.graph_store is not None + orphan_cfg = self._cfg("memory.orphan", {}) or {} + if not bool(orphan_cfg.get("enable_soft_delete", True)): + return + entity_retention = max(0.0, float(orphan_cfg.get("entity_retention_days", 7.0) or 7.0)) * 86400.0 + paragraph_retention = max(0.0, float(orphan_cfg.get("paragraph_retention_days", 7.0) or 7.0)) * 86400.0 + grace_period = max(0.0, float(orphan_cfg.get("sweep_grace_hours", 24.0) or 24.0)) * 3600.0 + + isolated = self.graph_store.get_isolated_nodes(include_inactive=True) + if isolated: + entity_hashes = self.metadata_store.get_entity_gc_candidates(isolated, retention_seconds=entity_retention) + if entity_hashes: + self.metadata_store.mark_as_deleted(entity_hashes, "entity") + + paragraph_hashes = self.metadata_store.get_paragraph_gc_candidates(retention_seconds=paragraph_retention) + if paragraph_hashes: + self.metadata_store.mark_as_deleted(paragraph_hashes, "paragraph") + + dead_paragraphs = self.metadata_store.sweep_deleted_items("paragraph", grace_period) + if dead_paragraphs: + hashes = [str(item[0] or "").strip() for item in dead_paragraphs if item and str(item[0] or "").strip()] + if hashes: + self.metadata_store.physically_delete_paragraphs(hashes) + if self.vector_store is not None: + self.vector_store.delete(hashes) + + dead_entities = self.metadata_store.sweep_deleted_items("entity", grace_period) + if dead_entities: + entity_hashes = [str(item[0] or "").strip() for item in dead_entities if item and str(item[0] or "").strip()] + entity_names = [str(item[1] or "").strip() for item in dead_entities if item and str(item[1] or "").strip()] + if entity_names: + self.graph_store.delete_nodes(entity_names) + if entity_hashes: + self.metadata_store.physically_delete_entities(entity_hashes) + if self.vector_store is not None: + self.vector_store.delete(entity_hashes) + + def _mark_person_active(self, person_id: str) -> None: + token = str(person_id or "").strip() + if not token: + return + self._active_person_timestamps[token] = time.time() + + def _serialize_graph(self, *, limit: int = 200) -> Dict[str, Any]: + assert self.graph_store is not None + assert self.metadata_store is not None + nodes = self.graph_store.get_nodes() + if limit > 0: + nodes = nodes[:limit] + node_set = set(nodes) + node_payload = [] + for name in nodes: + attrs = self.graph_store.get_node_attributes(name) or {} + node_payload.append({"id": name, "name": name, "attributes": attrs}) + + edge_payload = [] + for source, target, relation_hashes in self.graph_store.iter_edge_hash_entries(): + if source not in node_set or target not in node_set: + continue + edge_payload.append( + { + "source": source, + "target": target, + "weight": float(self.graph_store.get_edge_weight(source, target)), + "relation_hashes": sorted(str(item) for item in relation_hashes if str(item).strip()), + } + ) + return { + "nodes": node_payload, + "edges": edge_payload, + "total_nodes": int(self.graph_store.num_nodes), + "total_edges": int(self.graph_store.num_edges), + } + + def _delete_sources(self, sources: Iterable[Any]) -> Dict[str, Any]: + assert self.metadata_store + source_tokens = self._tokens(sources) + if not source_tokens: + return {"success": False, "error": "source 不能为空"} + + deleted_paragraphs = 0 + deleted_sources: List[str] = [] + for source in source_tokens: + paragraphs = self.metadata_store.get_paragraphs_by_source(source) + if not paragraphs: + self.metadata_store.replace_episodes_for_source(source, []) + continue + for row in paragraphs: + paragraph_hash = str(row.get("hash", 
"") or "").strip() + if not paragraph_hash: + continue + cleanup = self.metadata_store.delete_paragraph_atomic(paragraph_hash) + self._apply_cleanup_plan(cleanup) + deleted_paragraphs += 1 + self.metadata_store.replace_episodes_for_source(source, []) + deleted_sources.append(source) + + self._rebuild_graph_from_metadata() + self._persist() + return { + "success": True, + "sources": deleted_sources, + "deleted_source_count": len(deleted_sources), + "deleted_paragraph_count": deleted_paragraphs, + } + + def _apply_cleanup_plan(self, cleanup: Dict[str, Any]) -> None: + if not isinstance(cleanup, dict): + return + if self.vector_store is not None: + vector_ids: List[str] = [] + paragraph_hash = str(cleanup.get("vector_id_to_remove", "") or "").strip() + if paragraph_hash: + vector_ids.append(paragraph_hash) + for _, _, relation_hash in cleanup.get("relation_prune_ops", []) or []: + token = str(relation_hash or "").strip() + if token: + vector_ids.append(token) + if vector_ids: + self.vector_store.delete(list(dict.fromkeys(vector_ids))) + + def _rebuild_graph_from_metadata(self) -> Dict[str, int]: + assert self.metadata_store is not None + assert self.graph_store is not None + entity_rows = self.metadata_store.query( + """ + SELECT name + FROM entities + WHERE is_deleted IS NULL OR is_deleted = 0 + ORDER BY name ASC + """ + ) + raw_relation_rows = self.metadata_store.query( + """ + SELECT subject, object, confidence, hash + FROM relations + WHERE is_inactive IS NULL OR is_inactive = 0 + """ + ) + relation_rows = [ + row + for row in raw_relation_rows + if str(row.get("subject", "") or "").strip() and str(row.get("object", "") or "").strip() + ] + + names = list( + dict.fromkeys( + [ + str(row.get("name", "") or "").strip() + for row in entity_rows + if str(row.get("name", "") or "").strip() + ] + + [ + str(row.get("subject", "") or "").strip() + for row in relation_rows + if str(row.get("subject", "") or "").strip() + ] + + [ + str(row.get("object", "") or "").strip() + for row in relation_rows + if str(row.get("object", "") or "").strip() + ] + ) + ) + self.graph_store.clear() + if names: + self.graph_store.add_nodes(names) + if relation_rows: + self.graph_store.add_edges( + [ + ( + str(row.get("subject", "") or "").strip(), + str(row.get("object", "") or "").strip(), + ) + for row in relation_rows + ], + weights=[float(row.get("confidence", 1.0) or 1.0) for row in relation_rows], + relation_hashes=[str(row.get("hash", "") or "") for row in relation_rows], + ) + return {"node_count": int(self.graph_store.num_nodes), "edge_count": int(self.graph_store.num_edges)} + + def _rename_node(self, old_name: str, new_name: str) -> Dict[str, Any]: + assert self.metadata_store + source = str(old_name or "").strip() + target = str(new_name or "").strip() + if not source or not target: + return {"success": False, "error": "old_name/new_name 不能为空"} + if source == target: + return {"success": True, "renamed": False, "old_name": source, "new_name": target} + + conn = self.metadata_store.get_connection() + cursor = conn.cursor() + old_hash = compute_hash(source.lower()) + target_hash = compute_hash(target.lower()) + + cursor.execute( + """ + SELECT hash, name, vector_index, appearance_count, created_at, metadata + FROM entities + WHERE hash = ? 
+ OR LOWER(TRIM(name)) = LOWER(TRIM(?)) + LIMIT 1 + """, + (old_hash, source), + ) + old_row = cursor.fetchone() + if old_row is None: + return {"success": False, "error": "原节点不存在"} + + cursor.execute( + """ + SELECT hash, appearance_count + FROM entities + WHERE hash = ? + OR LOWER(TRIM(name)) = LOWER(TRIM(?)) + LIMIT 1 + """, + (target_hash, target), + ) + target_row = cursor.fetchone() + + try: + cursor.execute("BEGIN IMMEDIATE") + if target_row is None: + cursor.execute( + """ + INSERT INTO entities (hash, name, vector_index, appearance_count, created_at, metadata, is_deleted, deleted_at) + VALUES (?, ?, ?, ?, ?, ?, 0, NULL) + """, + ( + target_hash, + target, + old_row["vector_index"], + old_row["appearance_count"], + old_row["created_at"], + old_row["metadata"], + ), + ) + resolved_target_hash = target_hash + else: + resolved_target_hash = str(target_row["hash"] or "").strip() + cursor.execute( + """ + UPDATE entities + SET name = ?, + appearance_count = COALESCE(appearance_count, 0) + ?, + is_deleted = 0, + deleted_at = NULL + WHERE hash = ? + """, + ( + target, + int(old_row["appearance_count"] or 0), + resolved_target_hash, + ), + ) + + cursor.execute( + "UPDATE OR IGNORE paragraph_entities SET entity_hash = ? WHERE entity_hash = ?", + (resolved_target_hash, old_row["hash"]), + ) + cursor.execute("DELETE FROM paragraph_entities WHERE entity_hash = ?", (old_row["hash"],)) + cursor.execute( + "UPDATE relations SET subject = ? WHERE LOWER(TRIM(subject)) = LOWER(TRIM(?))", + (target, old_row["name"]), + ) + cursor.execute( + "UPDATE relations SET object = ? WHERE LOWER(TRIM(object)) = LOWER(TRIM(?))", + (target, old_row["name"]), + ) + cursor.execute("DELETE FROM entities WHERE hash = ?", (old_row["hash"],)) + conn.commit() + except Exception as exc: + conn.rollback() + return {"success": False, "error": f"rename failed: {exc}"} + + self._rebuild_graph_from_metadata() + self._persist() + return {"success": True, "renamed": True, "old_name": source, "new_name": target} + + def _update_edge_weight( + self, + *, + relation_hash: str, + subject: str, + obj: str, + weight: float, + ) -> Dict[str, Any]: + assert self.metadata_store + conn = self.metadata_store.get_connection() + cursor = conn.cursor() + target_weight = max(0.0, float(weight or 0.0)) + if relation_hash: + cursor.execute("UPDATE relations SET confidence = ? WHERE hash = ?", (target_weight, relation_hash)) + updated = cursor.rowcount + else: + cursor.execute( + """ + UPDATE relations + SET confidence = ? 
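+                    -- no relation_hash supplied: address the edge by its
+                    -- normalized subject/object pair instead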
+ WHERE LOWER(TRIM(subject)) = LOWER(TRIM(?)) + AND LOWER(TRIM(object)) = LOWER(TRIM(?)) + """, + (target_weight, subject, obj), + ) + updated = cursor.rowcount + conn.commit() + if updated <= 0: + return {"success": False, "error": "未找到可更新的关系"} + self._rebuild_graph_from_metadata() + self._persist() + return { + "success": True, + "updated": int(updated), + "weight": target_weight, + "hash": relation_hash, + "subject": subject, + "object": obj, + } + @staticmethod def _tokens(values: Optional[Iterable[Any]]) -> List[str]: result: List[str] = [] @@ -469,6 +2109,15 @@ class SDKMemoryKernel: clean = str(chat_id or "").strip() return f"chat_summary:{clean}" if clean else None + @staticmethod + def _resolve_knowledge_type(source_type: str) -> str: + clean_type = str(source_type or "").strip().lower() + if clean_type == "person_fact": + return "factual" + if clean_type == "chat_summary": + return "narrative" + return "mixed" + @staticmethod def _time_meta(timestamp: Optional[float], time_start: Optional[float], time_end: Optional[float]) -> Dict[str, Any]: payload: Dict[str, Any] = {} @@ -483,19 +2132,70 @@ class SDKMemoryKernel: payload["time_confidence"] = 0.95 return payload - def _temporal(self, request: KernelSearchRequest) -> Optional[TemporalQueryOptions]: - if request.time_start is None and request.time_end is None and not request.chat_id: - return None - return TemporalQueryOptions(time_from=request.time_start, time_to=request.time_end, source=self._chat_source(request.chat_id)) + @classmethod + def _normalize_search_time_bound(cls, value: Any, *, is_end: bool) -> tuple[Optional[float], Optional[str]]: + if value in {None, ""}: + return None, None + if isinstance(value, (int, float)): + ts = float(value) + return ts, format_timestamp(ts) + + text = str(value or "").strip() + if not text: + return None, None + + numeric = cls._optional_float(text) + if numeric is not None: + return numeric, format_timestamp(numeric) + + try: + ts = parse_query_datetime_to_timestamp(text, is_end=is_end) + except ValueError as exc: + raise ValueError(f"时间参数错误: {exc}") from exc + return ts, text + + @classmethod + def _normalize_search_time_window(cls, time_start: Any, time_end: Any) -> _NormalizedSearchTimeWindow: + numeric_start, query_start = cls._normalize_search_time_bound(time_start, is_end=False) + numeric_end, query_end = cls._normalize_search_time_bound(time_end, is_end=True) + if numeric_start is not None and numeric_end is not None and numeric_start > numeric_end: + raise ValueError("时间参数错误: time_start 不能晚于 time_end") + return _NormalizedSearchTimeWindow( + numeric_start=numeric_start, + numeric_end=numeric_end, + query_start=query_start, + query_end=query_end, + ) @staticmethod - def _retrieval_hit(item: RetrievalResult) -> Dict[str, Any]: + def _retrieval_result_hit(item: RetrievalResult) -> Dict[str, Any]: payload = item.to_dict() - return {"hash": payload.get("hash", ""), "content": payload.get("content", ""), "score": payload.get("score", 0.0), "type": payload.get("type", ""), "source": payload.get("source", ""), "metadata": payload.get("metadata", {}) or {}} + return { + "hash": payload.get("hash", ""), + "content": payload.get("content", ""), + "score": payload.get("score", 0.0), + "type": payload.get("type", ""), + "source": payload.get("source", ""), + "metadata": payload.get("metadata", {}) or {}, + } @staticmethod def _episode_hit(row: Dict[str, Any]) -> Dict[str, Any]: - return {"type": "episode", "episode_id": str(row.get("episode_id", "") or ""), "title": str(row.get("title", "") 
or ""), "content": str(row.get("summary", "") or ""), "score": float(row.get("lexical_score", 0.0) or 0.0), "source": "episode", "metadata": {"participants": row.get("participants", []) or [], "keywords": row.get("keywords", []) or [], "source": row.get("source"), "event_time_start": row.get("event_time_start"), "event_time_end": row.get("event_time_end")}} + return { + "type": "episode", + "episode_id": str(row.get("episode_id", "") or ""), + "title": str(row.get("title", "") or ""), + "content": str(row.get("summary", "") or ""), + "score": float(row.get("lexical_score", 0.0) or 0.0), + "source": "episode", + "metadata": { + "participants": row.get("participants", []) or [], + "keywords": row.get("keywords", []) or [], + "source": row.get("source"), + "event_time_start": row.get("event_time_start"), + "event_time_end": row.get("event_time_end"), + }, + } @staticmethod def _summary(hits: Sequence[Dict[str, Any]]) -> str: @@ -521,51 +2221,6 @@ class SDKMemoryKernel: filtered.append(item) return filtered or hits - @staticmethod - def _episode_participants(rows: Sequence[Dict[str, Any]]) -> List[str]: - seen = set() - result: List[str] = [] - for row in rows: - meta = row.get("metadata", {}) or {} - for key in ("participants", "person_ids"): - for item in meta.get(key, []) or []: - token = str(item or "").strip() - if not token or token in seen: - continue - seen.add(token) - result.append(token) - return result[:16] - - @staticmethod - def _episode_keywords(rows: Sequence[Dict[str, Any]]) -> List[str]: - seen = set() - result: List[str] = [] - for row in rows: - meta = row.get("metadata", {}) or {} - for item in meta.get("tags", []) or []: - token = str(item or "").strip() - if not token or token in seen: - continue - seen.add(token) - result.append(token) - return result[:12] - - @staticmethod - def _time_bound(rows: Sequence[Dict[str, Any]], primary: str, fallback: str, reverse: bool) -> Optional[float]: - values: List[float] = [] - for row in rows: - for key in (primary, fallback): - value = row.get(key) - try: - if value is not None: - values.append(float(value)) - break - except Exception: - continue - if not values: - return None - return max(values) if reverse else min(values) - def _resolve_relation_hashes(self, target: str) -> List[str]: assert self.metadata_store token = str(target or "").strip() @@ -576,4 +2231,896 @@ class SDKMemoryKernel: hashes = self.metadata_store.search_relation_hashes_by_text(token, limit=10) if hashes: return hashes - return [str(row.get("hash", "") or "") for row in self.metadata_store.get_relations(subject=token)[:10] if str(row.get("hash", "")).strip()] + return [ + str(row.get("hash", "") or "") + for row in self.metadata_store.get_relations(subject=token)[:10] + if str(row.get("hash", "")).strip() + ] + + def _resolve_deleted_relation_hashes(self, target: str) -> List[str]: + assert self.metadata_store + token = str(target or "").strip() + if not token: + return [] + if len(token) == 64 and all(ch in "0123456789abcdef" for ch in token.lower()): + return [token] + return self.metadata_store.search_deleted_relation_hashes_by_text(token, limit=10) + + def _memory_v5_status(self, *, target: str = "", limit: int = 50) -> Dict[str, Any]: + assert self.metadata_store + now = time.time() + summary = self.metadata_store.get_memory_status_summary(now) + payload: Dict[str, Any] = { + "success": True, + **summary, + "config": { + "half_life_hours": float(self._cfg("memory.half_life_hours", 24.0) or 24.0), + "base_decay_interval_hours": 
float(self._cfg("memory.base_decay_interval_hours", 1.0) or 1.0), + "prune_threshold": float(self._cfg("memory.prune_threshold", 0.1) or 0.1), + "freeze_duration_hours": float(self._cfg("memory.freeze_duration_hours", 24.0) or 24.0), + }, + "last_maintenance_at": self._last_maintenance_at, + } + token = str(target or "").strip() + if not token: + return payload + + active_hashes = self._resolve_relation_hashes(token)[:limit] + deleted_hashes = self._resolve_deleted_relation_hashes(token)[:limit] + active_statuses = self.metadata_store.get_relation_status_batch(active_hashes) + items: List[Dict[str, Any]] = [] + for hash_value in active_hashes: + relation = self.metadata_store.get_relation(hash_value) or {} + status = active_statuses.get(hash_value, {}) + items.append( + { + "hash": hash_value, + "subject": str(relation.get("subject", "") or ""), + "predicate": str(relation.get("predicate", "") or ""), + "object": str(relation.get("object", "") or ""), + "state": "inactive" if bool(status.get("is_inactive")) else "active", + "is_pinned": bool(status.get("is_pinned", False)), + "temp_protected": bool(float(status.get("protected_until") or 0.0) > now), + "protected_until": status.get("protected_until"), + "last_reinforced": status.get("last_reinforced"), + "weight": float(status.get("weight", relation.get("confidence", 0.0)) or 0.0), + } + ) + for hash_value in deleted_hashes: + relation = self.metadata_store.get_deleted_relation(hash_value) or {} + items.append( + { + "hash": hash_value, + "subject": str(relation.get("subject", "") or ""), + "predicate": str(relation.get("predicate", "") or ""), + "object": str(relation.get("object", "") or ""), + "state": "deleted", + "is_pinned": bool(relation.get("is_pinned", False)), + "temp_protected": False, + "protected_until": relation.get("protected_until"), + "last_reinforced": relation.get("last_reinforced"), + "weight": float(relation.get("confidence", 0.0) or 0.0), + "deleted_at": relation.get("deleted_at"), + } + ) + payload["items"] = items[:limit] + payload["count"] = len(payload["items"]) + payload["target"] = token + return payload + + def _adjust_relation_confidence(self, hashes: List[str], *, delta: float) -> Dict[str, float]: + assert self.metadata_store + normalized = [str(item or "").strip() for item in hashes if str(item or "").strip()] + if not normalized: + return {} + conn = self.metadata_store.get_connection() + cursor = conn.cursor() + chunk_size = 200 + for index in range(0, len(normalized), chunk_size): + chunk = normalized[index : index + chunk_size] + placeholders = ",".join(["?"] * len(chunk)) + cursor.execute( + f""" + UPDATE relations + SET confidence = MAX(0.0, COALESCE(confidence, 0.0) + ?) 
+ WHERE hash IN ({placeholders}) + """, + tuple([float(delta)] + chunk), + ) + conn.commit() + statuses = self.metadata_store.get_relation_status_batch(normalized) + return {hash_value: float((statuses.get(hash_value) or {}).get("weight", 0.0) or 0.0) for hash_value in normalized} + + def _apply_v5_relation_action(self, *, action: str, hashes: List[str], strength: float = 1.0) -> Dict[str, Any]: + assert self.metadata_store + act = str(action or "").strip().lower() + normalized = [str(item or "").strip() for item in hashes if str(item or "").strip()] + if not normalized: + return {"success": False, "error": "未命中可维护关系"} + + now = time.time() + strength_value = max(0.1, float(strength or 1.0)) + prune_threshold = max(0.0, float(self._cfg("memory.prune_threshold", 0.1) or 0.1)) + detail = "" + + if act == "reinforce": + weights = self._adjust_relation_confidence(normalized, delta=0.5 * strength_value) + protect_hours = max(1.0, 24.0 * strength_value) + self.metadata_store.reinforce_relations(normalized) + self.metadata_store.mark_relations_active(normalized, boost_weight=max(prune_threshold, 0.1)) + self.metadata_store.update_relations_protection( + normalized, + protected_until=now + protect_hours * 3600.0, + last_reinforced=now, + ) + detail = f"reinforce {len(normalized)} 条关系" + elif act == "weaken": + weights = self._adjust_relation_confidence(normalized, delta=-0.5 * strength_value) + to_freeze = [hash_value for hash_value, weight in weights.items() if weight <= prune_threshold] + if to_freeze: + self.metadata_store.mark_relations_inactive(to_freeze, inactive_since=now) + detail = f"weaken {len(normalized)} 条关系" + elif act == "remember_forever": + self.metadata_store.mark_relations_active(normalized, boost_weight=max(prune_threshold, 0.1)) + self.metadata_store.update_relations_protection(normalized, protected_until=0.0, is_pinned=True) + weights = {hash_value: float((self.metadata_store.get_relation_status_batch([hash_value]).get(hash_value) or {}).get("weight", 0.0) or 0.0) for hash_value in normalized} + detail = f"remember_forever {len(normalized)} 条关系" + elif act == "forget": + weights = self._adjust_relation_confidence(normalized, delta=-2.0 * strength_value) + self.metadata_store.update_relations_protection(normalized, protected_until=0.0, is_pinned=False) + self.metadata_store.mark_relations_inactive(normalized, inactive_since=now) + detail = f"forget {len(normalized)} 条关系" + else: + return {"success": False, "error": f"不支持的 V5 动作: {act}"} + + self._rebuild_graph_from_metadata() + self._last_maintenance_at = now + self._persist() + statuses = self.metadata_store.get_relation_status_batch(normalized) + return { + "success": True, + "detail": detail, + "hashes": normalized, + "count": len(normalized), + "weights": weights, + "statuses": statuses, + } + + async def _ensure_vector_for_text(self, *, item_hash: str, text: str) -> bool: + if self.vector_store is None or self.embedding_manager is None: + return False + token = str(item_hash or "").strip() + content = str(text or "").strip() + if not token or not content: + return False + embedding = await self.embedding_manager.encode([content], dimensions=self.embedding_dimension) + if getattr(embedding, "ndim", 1) == 1: + embedding = embedding.reshape(1, -1) + if getattr(embedding, "size", 0) <= 0: + return False + try: + self.vector_store.add(embedding, [token]) + return True + except Exception as exc: + logger.warning(f"重建向量失败: {exc}") + return False + + async def _ensure_relation_vector(self, relation: Dict[str, Any]) -> bool: + if 
not bool(self.relation_vectors_enabled): + return False + return await self._ensure_vector_for_text( + item_hash=str(relation.get("hash", "") or ""), + text=" ".join( + [ + str(relation.get("subject", "") or "").strip(), + str(relation.get("predicate", "") or "").strip(), + str(relation.get("object", "") or "").strip(), + ] + ).strip(), + ) + + async def _ensure_paragraph_vector(self, paragraph: Dict[str, Any]) -> bool: + return await self._ensure_vector_for_text( + item_hash=str(paragraph.get("hash", "") or ""), + text=str(paragraph.get("content", "") or ""), + ) + + async def _ensure_entity_vector(self, entity: Dict[str, Any]) -> bool: + return await self._ensure_vector_for_text( + item_hash=str(entity.get("hash", "") or ""), + text=str(entity.get("name", "") or ""), + ) + + async def _restore_relation_hashes( + self, + hashes: List[str], + *, + payloads: Optional[Dict[str, Dict[str, Any]]] = None, + rebuild_graph: bool = True, + persist: bool = True, + ) -> Dict[str, Any]: + assert self.metadata_store + restored: List[str] = [] + failures: List[Dict[str, str]] = [] + conn = self.metadata_store.get_connection() + cursor = conn.cursor() + payload_map = payloads or {} + for hash_value in [str(item or "").strip() for item in hashes if str(item or "").strip()]: + relation = self.metadata_store.restore_relation(hash_value) + if relation is None: + relation = self.metadata_store.get_relation(hash_value) + if relation is None: + failures.append({"hash": hash_value, "error": "relation 不存在"}) + continue + payload = payload_map.get(hash_value) if isinstance(payload_map.get(hash_value), dict) else {} + paragraph_hashes = self._tokens(payload.get("paragraph_hashes")) + for paragraph_hash in paragraph_hashes: + cursor.execute( + """ + INSERT OR IGNORE INTO paragraph_relations (paragraph_hash, relation_hash) + VALUES (?, ?) 
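+                        -- OR IGNORE keeps re-linking idempotent when the
+                        -- paragraph/relation row already exists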
+ """, + (paragraph_hash, hash_value), + ) + await self._ensure_relation_vector({**relation, "hash": hash_value}) + restored.append(hash_value) + conn.commit() + if restored and rebuild_graph: + self._rebuild_graph_from_metadata() + if restored and persist: + self._persist() + return {"restored_hashes": restored, "restored_count": len(restored), "failures": failures} + + @staticmethod + def _selector_dict(selector: Any) -> Dict[str, Any]: + if isinstance(selector, dict): + return dict(selector) + if isinstance(selector, (list, tuple)): + return {"items": list(selector)} + token = str(selector or "").strip() + return {"query": token} if token else {} + + def _resolve_paragraph_targets(self, selector: Any, *, include_deleted: bool = False) -> List[Dict[str, Any]]: + assert self.metadata_store + raw = self._selector_dict(selector) + rows: List[Dict[str, Any]] = [] + hashes = self._merge_tokens(raw.get("hashes"), raw.get("items"), [raw.get("hash")]) + for hash_value in hashes: + row = self.metadata_store.get_paragraph(hash_value) + if row is None: + continue + if not include_deleted and bool(row.get("is_deleted", 0)): + continue + rows.append(row) + if rows: + return rows + query = str(raw.get("query", "") or raw.get("content", "") or "").strip() + if not query: + return [] + if len(query) == 64 and all(ch in "0123456789abcdef" for ch in query.lower()): + row = self.metadata_store.get_paragraph(query) + if row is None: + return [] + if not include_deleted and bool(row.get("is_deleted", 0)): + return [] + return [row] + matches = self.metadata_store.search_paragraphs_by_content(query) + return [row for row in matches if include_deleted or not bool(row.get("is_deleted", 0))] + + def _resolve_entity_targets(self, selector: Any, *, include_deleted: bool = False) -> List[Dict[str, Any]]: + assert self.metadata_store + raw = self._selector_dict(selector) + rows: List[Dict[str, Any]] = [] + hashes = self._merge_tokens(raw.get("hashes"), raw.get("items"), [raw.get("hash")]) + for hash_value in hashes: + row = self.metadata_store.get_entity(hash_value) + if row is None: + continue + if not include_deleted and bool(row.get("is_deleted", 0)): + continue + rows.append(row) + names = self._merge_tokens(raw.get("names"), [raw.get("name")], [raw.get("query")]) + for name in names: + if not name: + continue + matches = self.metadata_store.query( + """ + SELECT * + FROM entities + WHERE LOWER(TRIM(name)) = LOWER(TRIM(?)) + OR hash = ? 
+ ORDER BY appearance_count DESC, created_at ASC + """, + (name, compute_hash(str(name).strip().lower())), + ) + for row in matches: + if not include_deleted and bool(row.get("is_deleted", 0)): + continue + rows.append(self.metadata_store._row_to_dict(row, "entity") if hasattr(self.metadata_store, "_row_to_dict") else row) + dedup: Dict[str, Dict[str, Any]] = {} + for row in rows: + token = str(row.get("hash", "") or "").strip() + if token and token not in dedup: + dedup[token] = row + return list(dedup.values()) + + def _resolve_source_targets(self, selector: Any) -> List[str]: + raw = self._selector_dict(selector) + return self._merge_tokens(raw.get("sources"), [raw.get("source")], [raw.get("query")], raw.get("items")) + + def _snapshot_relation_item(self, hash_value: str) -> Optional[Dict[str, Any]]: + assert self.metadata_store + relation = self.metadata_store.get_relation(hash_value) + if relation is None: + relation = self.metadata_store.get_deleted_relation(hash_value) + if relation is None: + return None + paragraph_hashes = [ + str(row.get("paragraph_hash", "") or "").strip() + for row in self.metadata_store.query( + "SELECT paragraph_hash FROM paragraph_relations WHERE relation_hash = ? ORDER BY paragraph_hash ASC", + (hash_value,), + ) + if str(row.get("paragraph_hash", "") or "").strip() + ] + return { + "item_type": "relation", + "item_hash": hash_value, + "item_key": hash_value, + "payload": { + "relation": relation, + "paragraph_hashes": paragraph_hashes, + }, + } + + def _snapshot_paragraph_item(self, hash_value: str) -> Optional[Dict[str, Any]]: + assert self.metadata_store + paragraph = self.metadata_store.get_paragraph(hash_value) + if paragraph is None: + return None + entity_links = [ + { + "paragraph_hash": hash_value, + "entity_hash": str(row.get("entity_hash", "") or ""), + "mention_count": int(row.get("mention_count", 1) or 1), + } + for row in self.metadata_store.query( + """ + SELECT paragraph_hash, entity_hash, mention_count + FROM paragraph_entities + WHERE paragraph_hash = ? + ORDER BY entity_hash ASC + """, + (hash_value,), + ) + ] + relation_hashes = [ + str(row.get("relation_hash", "") or "").strip() + for row in self.metadata_store.query( + """ + SELECT relation_hash + FROM paragraph_relations + WHERE paragraph_hash = ? + ORDER BY relation_hash ASC + """, + (hash_value,), + ) + if str(row.get("relation_hash", "") or "").strip() + ] + return { + "item_type": "paragraph", + "item_hash": hash_value, + "item_key": hash_value, + "payload": { + "paragraph": paragraph, + "entity_links": entity_links, + "relation_hashes": relation_hashes, + "external_refs": self.metadata_store.list_external_memory_refs_by_paragraphs([hash_value]), + }, + } + + def _snapshot_entity_item(self, hash_value: str) -> Optional[Dict[str, Any]]: + assert self.metadata_store + entity = self.metadata_store.get_entity(hash_value) + if entity is None: + return None + paragraph_links = [ + { + "paragraph_hash": str(row.get("paragraph_hash", "") or ""), + "entity_hash": hash_value, + "mention_count": int(row.get("mention_count", 1) or 1), + } + for row in self.metadata_store.query( + """ + SELECT paragraph_hash, mention_count + FROM paragraph_entities + WHERE entity_hash = ? 
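+                    -- stable ordering keeps snapshot payloads deterministic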
+ ORDER BY paragraph_hash ASC + """, + (hash_value,), + ) + ] + return { + "item_type": "entity", + "item_hash": hash_value, + "item_key": hash_value, + "payload": { + "entity": entity, + "paragraph_links": paragraph_links, + }, + } + + def _relation_has_remaining_paragraphs(self, relation_hash: str, removing_hashes: Sequence[str]) -> bool: + assert self.metadata_store + excluded = [str(item or "").strip() for item in removing_hashes if str(item or "").strip()] + conn = self.metadata_store.get_connection() + cursor = conn.cursor() + if excluded: + placeholders = ",".join(["?"] * len(excluded)) + cursor.execute( + f""" + SELECT 1 + FROM paragraph_relations pr + JOIN paragraphs p ON p.hash = pr.paragraph_hash + WHERE pr.relation_hash = ? + AND pr.paragraph_hash NOT IN ({placeholders}) + AND (p.is_deleted IS NULL OR p.is_deleted = 0) + LIMIT 1 + """, + tuple([relation_hash] + excluded), + ) + else: + cursor.execute( + """ + SELECT 1 + FROM paragraph_relations pr + JOIN paragraphs p ON p.hash = pr.paragraph_hash + WHERE pr.relation_hash = ? + AND (p.is_deleted IS NULL OR p.is_deleted = 0) + LIMIT 1 + """, + (relation_hash,), + ) + return cursor.fetchone() is not None + + async def _build_delete_plan(self, *, mode: str, selector: Any) -> Dict[str, Any]: + assert self.metadata_store + act_mode = str(mode or "").strip().lower() + normalized_selector = self._selector_dict(selector) + items: List[Dict[str, Any]] = [] + counts = {"relations": 0, "paragraphs": 0, "entities": 0, "sources": 0} + vector_ids: List[str] = [] + sources: List[str] = [] + target_hashes: Dict[str, List[str]] = {"relations": [], "paragraphs": [], "entities": [], "sources": []} + + if act_mode == "relation": + relation_rows = [row for row in (self.metadata_store.get_relation(hash_value) for hash_value in self._resolve_relation_hashes(str(normalized_selector.get("query", "") or ""))) if row] + if normalized_selector.get("hashes"): + relation_rows = [ + row + for hash_value in self._tokens(normalized_selector.get("hashes")) + for row in [self.metadata_store.get_relation(hash_value)] + if row is not None + ] + dedup_hashes: List[str] = [] + seen = set() + for row in relation_rows: + hash_value = str(row.get("hash", "") or "").strip() + if hash_value and hash_value not in seen: + seen.add(hash_value) + dedup_hashes.append(hash_value) + snap = self._snapshot_relation_item(hash_value) + if snap: + items.append(snap) + vector_ids.append(hash_value) + counts["relations"] = len(dedup_hashes) + target_hashes["relations"] = dedup_hashes + + elif act_mode in {"paragraph", "source"}: + paragraph_rows: List[Dict[str, Any]] = [] + if act_mode == "source": + source_tokens = self._resolve_source_targets(normalized_selector) + target_hashes["sources"] = source_tokens + counts["sources"] = len(source_tokens) + for source in source_tokens: + sources.append(source) + paragraph_rows.extend( + self.metadata_store.query( + """ + SELECT * + FROM paragraphs + WHERE source = ? 
+ AND (is_deleted IS NULL OR is_deleted = 0) + ORDER BY created_at ASC + """, + (source,), + ) + ) + else: + paragraph_rows = self._resolve_paragraph_targets(normalized_selector, include_deleted=False) + paragraph_hashes = self._tokens([row.get("hash", "") for row in paragraph_rows]) + target_hashes["paragraphs"] = paragraph_hashes + counts["paragraphs"] = len(paragraph_hashes) + for hash_value in paragraph_hashes: + snap = self._snapshot_paragraph_item(hash_value) + if snap: + items.append(snap) + vector_ids.append(hash_value) + paragraph = snap["payload"].get("paragraph") or {} + source = str(paragraph.get("source", "") or "").strip() + if source: + sources.append(source) + + orphan_relations: List[str] = [] + for item in items: + if item.get("item_type") != "paragraph": + continue + for relation_hash in self._tokens((item.get("payload") or {}).get("relation_hashes")): + if relation_hash in orphan_relations: + continue + if not self._relation_has_remaining_paragraphs(relation_hash, paragraph_hashes): + orphan_relations.append(relation_hash) + for relation_hash in orphan_relations: + snap = self._snapshot_relation_item(relation_hash) + if snap: + items.append(snap) + vector_ids.append(relation_hash) + target_hashes["relations"] = orphan_relations + counts["relations"] = len(orphan_relations) + + elif act_mode == "entity": + entity_rows = self._resolve_entity_targets(normalized_selector, include_deleted=False) + entity_hashes = self._tokens([row.get("hash", "") for row in entity_rows]) + target_hashes["entities"] = entity_hashes + counts["entities"] = len(entity_hashes) + entity_names = [str(row.get("name", "") or "").strip() for row in entity_rows if str(row.get("name", "") or "").strip()] + for hash_value in entity_hashes: + snap = self._snapshot_entity_item(hash_value) + if snap: + items.append(snap) + vector_ids.append(hash_value) + relation_hashes: List[str] = [] + for entity_name in entity_names: + for relation in self.metadata_store.get_relations(subject=entity_name) + self.metadata_store.get_relations(object=entity_name): + hash_value = str(relation.get("hash", "") or "").strip() + if hash_value and hash_value not in relation_hashes: + relation_hashes.append(hash_value) + for relation_hash in relation_hashes: + snap = self._snapshot_relation_item(relation_hash) + if snap: + items.append(snap) + vector_ids.append(relation_hash) + target_hashes["relations"] = relation_hashes + counts["relations"] = len(relation_hashes) + else: + return {"success": False, "error": f"不支持的 delete mode: {act_mode}"} + + sources = self._tokens(sources) + vector_ids = self._tokens(vector_ids) + primary_count = counts.get(f"{act_mode}s", 0) if act_mode != "source" else counts.get("sources", 0) + return { + "success": primary_count > 0 or counts.get("paragraphs", 0) > 0 or counts.get("relations", 0) > 0, + "mode": act_mode, + "selector": normalized_selector, + "items": items, + "counts": counts, + "vector_ids": vector_ids, + "sources": sources, + "target_hashes": target_hashes, + "error": "" if (primary_count > 0 or counts.get("paragraphs", 0) > 0 or counts.get("relations", 0) > 0) else "未命中可删除内容", + } + + async def _preview_delete_action(self, *, mode: str, selector: Any) -> Dict[str, Any]: + plan = await self._build_delete_plan(mode=mode, selector=selector) + if not plan.get("success", False): + return {"success": False, "error": plan.get("error", "未命中可删除内容")} + preview_items = [ + { + "item_type": str(item.get("item_type", "") or ""), + "item_hash": str(item.get("item_hash", "") or ""), + } + for item in 
plan.get("items", [])[:100] + ] + return { + "success": True, + "mode": plan.get("mode"), + "selector": plan.get("selector"), + "counts": plan.get("counts", {}), + "sources": plan.get("sources", []), + "vector_ids": plan.get("vector_ids", []), + "items": preview_items, + "item_count": len(plan.get("items", [])), + "dry_run": True, + } + + async def _execute_delete_action( + self, + *, + mode: str, + selector: Any, + requested_by: str = "", + reason: str = "", + ) -> Dict[str, Any]: + assert self.metadata_store + plan = await self._build_delete_plan(mode=mode, selector=selector) + if not plan.get("success", False): + return {"success": False, "error": plan.get("error", "未命中可删除内容")} + + act_mode = str(plan.get("mode", "") or "").strip().lower() + conn = self.metadata_store.get_connection() + cursor = conn.cursor() + paragraph_hashes = self._tokens((plan.get("target_hashes") or {}).get("paragraphs")) + entity_hashes = self._tokens((plan.get("target_hashes") or {}).get("entities")) + relation_hashes = self._tokens((plan.get("target_hashes") or {}).get("relations")) + source_tokens = self._tokens((plan.get("target_hashes") or {}).get("sources")) + + try: + if paragraph_hashes: + self.metadata_store.mark_as_deleted(paragraph_hashes, "paragraph") + cursor.execute( + f"DELETE FROM paragraph_entities WHERE paragraph_hash IN ({','.join(['?'] * len(paragraph_hashes))})", + tuple(paragraph_hashes), + ) + cursor.execute( + f"DELETE FROM paragraph_relations WHERE paragraph_hash IN ({','.join(['?'] * len(paragraph_hashes))})", + tuple(paragraph_hashes), + ) + self.metadata_store.delete_external_memory_refs_by_paragraphs(paragraph_hashes) + if act_mode == "source" and source_tokens: + for source in source_tokens: + self.metadata_store.replace_episodes_for_source(source, []) + + if entity_hashes: + self.metadata_store.mark_as_deleted(entity_hashes, "entity") + cursor.execute( + f"DELETE FROM paragraph_entities WHERE entity_hash IN ({','.join(['?'] * len(entity_hashes))})", + tuple(entity_hashes), + ) + + conn.commit() + + deleted_relations = self.metadata_store.backup_and_delete_relations(relation_hashes) + deleted_vectors = 0 + if self.vector_store is not None and plan.get("vector_ids"): + deleted_vectors = self.vector_store.delete(list(plan.get("vector_ids") or [])) + + operation = self.metadata_store.create_delete_operation( + mode=act_mode, + selector=plan.get("selector"), + items=plan.get("items", []), + reason=reason, + requested_by=requested_by, + summary={ + "counts": plan.get("counts", {}), + "sources": plan.get("sources", []), + "vector_ids": plan.get("vector_ids", []), + "deleted_relation_rows": deleted_relations, + }, + ) + + if plan.get("sources"): + self.metadata_store._enqueue_episode_source_rebuilds(list(plan.get("sources") or []), reason="delete_admin_execute") + self._rebuild_graph_from_metadata() + self._persist() + deleted_count = ( + len(source_tokens) + if act_mode == "source" + else len(paragraph_hashes) + if act_mode == "paragraph" + else len(entity_hashes) + if act_mode == "entity" + else len(relation_hashes) + ) + result = { + "success": True, + "mode": act_mode, + "operation_id": operation.get("operation_id", ""), + "counts": plan.get("counts", {}), + "sources": plan.get("sources", []), + "deleted_count": deleted_count, + "deleted_vector_count": int(deleted_vectors or 0), + "deleted_relation_count": len(relation_hashes), + } + if act_mode == "source": + result["deleted_source_count"] = len(source_tokens) + result["deleted_paragraph_count"] = len(paragraph_hashes) + return result 
+ except Exception as exc: + conn.rollback() + logger.warning(f"delete_admin execute 失败: {exc}") + return {"success": False, "error": str(exc)} + + async def _restore_delete_action( + self, + *, + mode: str, + selector: Any, + operation_id: str = "", + requested_by: str = "", + reason: str = "", + ) -> Dict[str, Any]: + del requested_by + del reason + assert self.metadata_store + + op_id = str(operation_id or "").strip() + if op_id: + operation = self.metadata_store.get_delete_operation(op_id) + if operation is None: + return {"success": False, "error": "operation 不存在"} + return await self._restore_delete_operation(operation) + + act_mode = str(mode or "").strip().lower() + if act_mode != "relation": + return {"success": False, "error": "paragraph/entity/source 恢复必须提供 operation_id"} + + raw = self._selector_dict(selector) + target = str(raw.get("query", "") or raw.get("target", "") or raw.get("hash", "") or "").strip() + hashes = self._resolve_deleted_relation_hashes(target) + if not hashes: + return {"success": False, "error": "未命中可恢复关系"} + result = await self._restore_relation_hashes(hashes) + return {"success": bool(result.get("restored_count", 0) > 0), **result} + + async def _restore_delete_operation(self, operation: Dict[str, Any]) -> Dict[str, Any]: + assert self.metadata_store + items = operation.get("items") if isinstance(operation.get("items"), list) else [] + entity_payloads: Dict[str, Dict[str, Any]] = {} + paragraph_payloads: Dict[str, Dict[str, Any]] = {} + relation_payloads: Dict[str, Dict[str, Any]] = {} + for item in items: + if not isinstance(item, dict): + continue + item_type = str(item.get("item_type", "") or "").strip() + item_hash = str(item.get("item_hash", "") or "").strip() + payload = item.get("payload") if isinstance(item.get("payload"), dict) else {} + if item_type == "entity" and item_hash: + entity_payloads[item_hash] = payload + elif item_type == "paragraph" and item_hash: + paragraph_payloads[item_hash] = payload + elif item_type == "relation" and item_hash: + relation_payloads[item_hash] = payload + + restored_entities: List[str] = [] + restored_paragraphs: List[str] = [] + for hash_value, payload in entity_payloads.items(): + entity_row = payload.get("entity") if isinstance(payload.get("entity"), dict) else {} + if entity_row: + self.metadata_store.restore_entity_by_hash(hash_value) + await self._ensure_entity_vector(entity_row) + restored_entities.append(hash_value) + for hash_value, payload in paragraph_payloads.items(): + paragraph_row = payload.get("paragraph") if isinstance(payload.get("paragraph"), dict) else {} + if paragraph_row: + self.metadata_store.restore_paragraph_by_hash(hash_value) + await self._ensure_paragraph_vector(paragraph_row) + restored_paragraphs.append(hash_value) + + restored_relations = await self._restore_relation_hashes(list(relation_payloads.keys()), payloads=relation_payloads, rebuild_graph=False, persist=False) + + conn = self.metadata_store.get_connection() + cursor = conn.cursor() + for payload in entity_payloads.values(): + for link in payload.get("paragraph_links") or []: + paragraph_hash = str(link.get("paragraph_hash", "") or "").strip() + entity_hash = str(link.get("entity_hash", "") or "").strip() + mention_count = max(1, int(link.get("mention_count", 1) or 1)) + if not paragraph_hash or not entity_hash: + continue + cursor.execute( + """ + INSERT OR IGNORE INTO paragraph_entities (paragraph_hash, entity_hash, mention_count) + VALUES (?, ?, ?) 
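+                        -- restore entity->paragraph links captured in the
+                        -- snapshot; existing rows are ignored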
+ """, + (paragraph_hash, entity_hash, mention_count), + ) + for payload in paragraph_payloads.values(): + for link in payload.get("entity_links") or []: + paragraph_hash = str(link.get("paragraph_hash", "") or "").strip() + entity_hash = str(link.get("entity_hash", "") or "").strip() + mention_count = max(1, int(link.get("mention_count", 1) or 1)) + if not paragraph_hash or not entity_hash: + continue + cursor.execute( + """ + INSERT OR IGNORE INTO paragraph_entities (paragraph_hash, entity_hash, mention_count) + VALUES (?, ?, ?) + """, + (paragraph_hash, entity_hash, mention_count), + ) + for relation_hash in self._tokens(payload.get("relation_hashes")): + paragraph_hash = str((payload.get("paragraph") or {}).get("hash", "") or "").strip() + if not paragraph_hash or not relation_hash: + continue + cursor.execute( + """ + INSERT OR IGNORE INTO paragraph_relations (paragraph_hash, relation_hash) + VALUES (?, ?) + """, + (paragraph_hash, relation_hash), + ) + self.metadata_store.restore_external_memory_refs(list(payload.get("external_refs") or [])) + conn.commit() + + sources = self._tokens( + [ + str(((payload.get("paragraph") or {}).get("source", "") or "")).strip() + for payload in paragraph_payloads.values() + ] + ) + if sources: + self.metadata_store._enqueue_episode_source_rebuilds(sources, reason="delete_admin_restore") + self._rebuild_graph_from_metadata() + self._persist() + summary = { + "restored_entities": restored_entities, + "restored_paragraphs": restored_paragraphs, + "restored_relations": restored_relations.get("restored_hashes", []), + "sources": sources, + } + self.metadata_store.mark_delete_operation_restored(str(operation.get("operation_id", "") or ""), summary=summary) + return { + "success": True, + "operation_id": str(operation.get("operation_id", "") or ""), + **summary, + "restored_relation_count": restored_relations.get("restored_count", 0), + "relation_failures": restored_relations.get("failures", []), + } + + async def _purge_deleted_memory(self, *, grace_hours: Optional[float], limit: int) -> Dict[str, Any]: + assert self.metadata_store + orphan_cfg = self._cfg("memory.orphan", {}) or {} + grace = float(grace_hours) if grace_hours is not None else max( + 1.0, + float(orphan_cfg.get("sweep_grace_hours", 24.0) or 24.0), + ) + cutoff = time.time() - grace * 3600.0 + deleted_relation_hashes = self.metadata_store.purge_deleted_relations(cutoff_time=cutoff, limit=limit) + dead_paragraphs = self.metadata_store.sweep_deleted_items("paragraph", grace * 3600.0) + paragraph_hashes = [str(item[0] or "").strip() for item in dead_paragraphs if str(item[0] or "").strip()] + dead_entities = self.metadata_store.sweep_deleted_items("entity", grace * 3600.0) + entity_hashes = [str(item[0] or "").strip() for item in dead_entities if str(item[0] or "").strip()] + entity_names = [str(item[1] or "").strip() for item in dead_entities if str(item[1] or "").strip()] + + if paragraph_hashes: + self.metadata_store.physically_delete_paragraphs(paragraph_hashes) + if entity_hashes: + self.metadata_store.physically_delete_entities(entity_hashes) + if entity_names: + self.graph_store.delete_nodes(entity_names) + if self.vector_store is not None: + vector_ids = self._merge_tokens(paragraph_hashes, entity_hashes, deleted_relation_hashes) + if vector_ids: + self.vector_store.delete(vector_ids) + self._rebuild_graph_from_metadata() + self._persist() + return { + "success": True, + "grace_hours": grace, + "purged_deleted_relations": deleted_relation_hashes, + "purged_paragraph_hashes": 
paragraph_hashes, + "purged_entity_hashes": entity_hashes, + "purged_counts": { + "relations": len(deleted_relation_hashes), + "paragraphs": len(paragraph_hashes), + "entities": len(entity_hashes), + }, + } + + @staticmethod + def _optional_float(value: Any) -> Optional[float]: + if value in {None, ""}: + return None + try: + return float(value) + except Exception: + return None + + @staticmethod + def _optional_int(value: Any) -> Optional[int]: + if value in {None, ""}: + return None + try: + return int(value) + except Exception: + return None diff --git a/plugins/A_memorix/core/runtime/search_runtime_initializer.py b/plugins/A_memorix/core/runtime/search_runtime_initializer.py new file mode 100644 index 00000000..c3c7a81f --- /dev/null +++ b/plugins/A_memorix/core/runtime/search_runtime_initializer.py @@ -0,0 +1,240 @@ +"""Shared runtime initializer for Action/Tool/Command retrieval components.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Dict, Optional + +from src.common.logger import get_logger + +from ..retrieval import ( + DualPathRetriever, + DualPathRetrieverConfig, + DynamicThresholdFilter, + FusionConfig, + GraphRelationRecallConfig, + RelationIntentConfig, + RetrievalStrategy, + SparseBM25Config, + ThresholdConfig, + ThresholdMethod, +) + +_logger = get_logger("A_Memorix.SearchRuntimeInitializer") + +_REQUIRED_COMPONENT_KEYS = ( + "vector_store", + "graph_store", + "metadata_store", + "embedding_manager", +) + + +def _get_config_value(config: Optional[dict], key: str, default: Any = None) -> Any: + if not isinstance(config, dict): + return default + current: Any = config + for part in key.split("."): + if isinstance(current, dict) and part in current: + current = current[part] + else: + return default + return current + + +def _safe_dict(value: Any) -> Dict[str, Any]: + return value if isinstance(value, dict) else {} + + +def _resolve_debug_enabled(plugin_config: Optional[dict]) -> bool: + advanced = _get_config_value(plugin_config, "advanced", {}) + if isinstance(advanced, dict): + return bool(advanced.get("debug", False)) + return bool(_get_config_value(plugin_config, "debug", False)) + + +@dataclass +class SearchRuntimeBundle: + """Resolved runtime components and initialized retriever/filter.""" + + vector_store: Optional[Any] = None + graph_store: Optional[Any] = None + metadata_store: Optional[Any] = None + embedding_manager: Optional[Any] = None + sparse_index: Optional[Any] = None + retriever: Optional[DualPathRetriever] = None + threshold_filter: Optional[DynamicThresholdFilter] = None + error: str = "" + + @property + def ready(self) -> bool: + return ( + self.retriever is not None + and self.vector_store is not None + and self.graph_store is not None + and self.metadata_store is not None + and self.embedding_manager is not None + ) + + +def _resolve_runtime_components(plugin_config: Optional[dict]) -> SearchRuntimeBundle: + bundle = SearchRuntimeBundle( + vector_store=_get_config_value(plugin_config, "vector_store"), + graph_store=_get_config_value(plugin_config, "graph_store"), + metadata_store=_get_config_value(plugin_config, "metadata_store"), + embedding_manager=_get_config_value(plugin_config, "embedding_manager"), + sparse_index=_get_config_value(plugin_config, "sparse_index"), + ) + + missing_required = any( + getattr(bundle, key) is None for key in _REQUIRED_COMPONENT_KEYS + ) + if not missing_required: + return bundle + + try: + from ...plugin import AMemorixPlugin + + instances = 
AMemorixPlugin.get_storage_instances() + except Exception: + instances = {} + + if not isinstance(instances, dict) or not instances: + return bundle + + if bundle.vector_store is None: + bundle.vector_store = instances.get("vector_store") + if bundle.graph_store is None: + bundle.graph_store = instances.get("graph_store") + if bundle.metadata_store is None: + bundle.metadata_store = instances.get("metadata_store") + if bundle.embedding_manager is None: + bundle.embedding_manager = instances.get("embedding_manager") + if bundle.sparse_index is None: + bundle.sparse_index = instances.get("sparse_index") + return bundle + + +def build_search_runtime( + plugin_config: Optional[dict], + logger_obj: Optional[Any], + owner_tag: str, + *, + log_prefix: str = "", +) -> SearchRuntimeBundle: + """Build retriever + threshold filter with unified fallback/config parsing.""" + + log = logger_obj or _logger + owner = str(owner_tag or "runtime").strip().lower() or "runtime" + prefix = str(log_prefix or "").strip() + prefix_text = f"{prefix} " if prefix else "" + + runtime = _resolve_runtime_components(plugin_config) + if any(getattr(runtime, key) is None for key in _REQUIRED_COMPONENT_KEYS): + runtime.error = "存储组件未完全初始化" + log.warning(f"{prefix_text}[{owner}] 存储组件未完全初始化,无法使用检索功能") + return runtime + + sparse_cfg_raw = _safe_dict(_get_config_value(plugin_config, "retrieval.sparse", {}) or {}) + fusion_cfg_raw = _safe_dict(_get_config_value(plugin_config, "retrieval.fusion", {}) or {}) + relation_intent_cfg_raw = _safe_dict( + _get_config_value(plugin_config, "retrieval.search.relation_intent", {}) or {} + ) + graph_recall_cfg_raw = _safe_dict( + _get_config_value(plugin_config, "retrieval.search.graph_recall", {}) or {} + ) + + try: + sparse_cfg = SparseBM25Config(**sparse_cfg_raw) + except Exception as e: + log.warning(f"{prefix_text}[{owner}] sparse 配置非法,回退默认: {e}") + sparse_cfg = SparseBM25Config() + + try: + fusion_cfg = FusionConfig(**fusion_cfg_raw) + except Exception as e: + log.warning(f"{prefix_text}[{owner}] fusion 配置非法,回退默认: {e}") + fusion_cfg = FusionConfig() + + try: + relation_intent_cfg = RelationIntentConfig(**relation_intent_cfg_raw) + except Exception as e: + log.warning(f"{prefix_text}[{owner}] relation_intent 配置非法,回退默认: {e}") + relation_intent_cfg = RelationIntentConfig() + + try: + graph_recall_cfg = GraphRelationRecallConfig(**graph_recall_cfg_raw) + except Exception as e: + log.warning(f"{prefix_text}[{owner}] graph_recall 配置非法,回退默认: {e}") + graph_recall_cfg = GraphRelationRecallConfig() + + try: + config = DualPathRetrieverConfig( + top_k_paragraphs=_get_config_value(plugin_config, "retrieval.top_k_paragraphs", 20), + top_k_relations=_get_config_value(plugin_config, "retrieval.top_k_relations", 10), + top_k_final=_get_config_value(plugin_config, "retrieval.top_k_final", 10), + alpha=_get_config_value(plugin_config, "retrieval.alpha", 0.5), + enable_ppr=_get_config_value(plugin_config, "retrieval.enable_ppr", True), + ppr_alpha=_get_config_value(plugin_config, "retrieval.ppr_alpha", 0.85), + ppr_timeout_seconds=_get_config_value( + plugin_config, "retrieval.ppr_timeout_seconds", 1.5 + ), + ppr_concurrency_limit=_get_config_value( + plugin_config, "retrieval.ppr_concurrency_limit", 4 + ), + enable_parallel=_get_config_value(plugin_config, "retrieval.enable_parallel", True), + retrieval_strategy=RetrievalStrategy.DUAL_PATH, + debug=_resolve_debug_enabled(plugin_config), + sparse=sparse_cfg, + fusion=fusion_cfg, + relation_intent=relation_intent_cfg, + graph_recall=graph_recall_cfg, + ) 
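+        # Wire the resolved components into the retriever; sparse_index is
+        # optional and may legitimately be None.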
+ + runtime.retriever = DualPathRetriever( + vector_store=runtime.vector_store, + graph_store=runtime.graph_store, + metadata_store=runtime.metadata_store, + embedding_manager=runtime.embedding_manager, + sparse_index=runtime.sparse_index, + config=config, + ) + + threshold_config = ThresholdConfig( + method=ThresholdMethod.ADAPTIVE, + min_threshold=_get_config_value(plugin_config, "threshold.min_threshold", 0.3), + max_threshold=_get_config_value(plugin_config, "threshold.max_threshold", 0.95), + percentile=_get_config_value(plugin_config, "threshold.percentile", 75.0), + std_multiplier=_get_config_value(plugin_config, "threshold.std_multiplier", 1.5), + min_results=_get_config_value(plugin_config, "threshold.min_results", 3), + enable_auto_adjust=_get_config_value(plugin_config, "threshold.enable_auto_adjust", True), + ) + runtime.threshold_filter = DynamicThresholdFilter(threshold_config) + runtime.error = "" + log.info(f"{prefix_text}[{owner}] 检索运行时初始化完成") + except Exception as e: + runtime.retriever = None + runtime.threshold_filter = None + runtime.error = str(e) + log.error(f"{prefix_text}[{owner}] 检索运行时初始化失败: {e}") + + return runtime + + +class SearchRuntimeInitializer: + """Compatibility wrapper around the function style initializer.""" + + @staticmethod + def build_search_runtime( + plugin_config: Optional[dict], + logger_obj: Optional[Any], + owner_tag: str, + *, + log_prefix: str = "", + ) -> SearchRuntimeBundle: + return build_search_runtime( + plugin_config=plugin_config, + logger_obj=logger_obj, + owner_tag=owner_tag, + log_prefix=log_prefix, + ) diff --git a/plugins/A_memorix/core/storage/graph_store.py b/plugins/A_memorix/core/storage/graph_store.py index 8a075864..0a5fd95d 100644 --- a/plugins/A_memorix/core/storage/graph_store.py +++ b/plugins/A_memorix/core/storage/graph_store.py @@ -24,6 +24,20 @@ try: from scipy.sparse.linalg import norm HAS_SCIPY = True except ImportError: + class _SparseMatrixPlaceholder: + pass + + def _scipy_missing(*args, **kwargs): + raise ImportError("SciPy 未安装,请安装: pip install scipy") + + csr_matrix = _SparseMatrixPlaceholder + csc_matrix = _SparseMatrixPlaceholder + lil_matrix = _SparseMatrixPlaceholder + triu = _scipy_missing + save_npz = _scipy_missing + load_npz = _scipy_missing + bmat = _scipy_missing + norm = _scipy_missing HAS_SCIPY = False import contextlib diff --git a/plugins/A_memorix/core/storage/metadata_store.py b/plugins/A_memorix/core/storage/metadata_store.py index e94610f0..39f2701c 100644 --- a/plugins/A_memorix/core/storage/metadata_store.py +++ b/plugins/A_memorix/core/storage/metadata_store.py @@ -7,6 +7,8 @@ import sqlite3 import pickle import json +import uuid +import re from datetime import datetime from pathlib import Path from typing import Optional, Union, List, Dict, Any, Tuple @@ -24,7 +26,7 @@ from .knowledge_types import ( logger = get_logger("A_Memorix.MetadataStore") -SCHEMA_VERSION = 7 +SCHEMA_VERSION = 8 class MetadataStore: @@ -500,6 +502,63 @@ class MetadataStore: CREATE INDEX IF NOT EXISTS idx_external_memory_refs_paragraph ON external_memory_refs(paragraph_hash) """) + cursor.execute(""" + CREATE TABLE IF NOT EXISTS memory_v5_operations ( + operation_id TEXT PRIMARY KEY, + action TEXT NOT NULL, + target TEXT, + reason TEXT, + updated_by TEXT, + created_at REAL NOT NULL, + resolved_hashes_json TEXT, + result_json TEXT + ) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_memory_v5_operations_created + ON memory_v5_operations(created_at DESC) + """) + cursor.execute(""" + CREATE TABLE IF NOT 
EXISTS delete_operations ( + operation_id TEXT PRIMARY KEY, + mode TEXT NOT NULL, + selector TEXT, + reason TEXT, + requested_by TEXT, + status TEXT NOT NULL, + created_at REAL NOT NULL, + restored_at REAL, + summary_json TEXT + ) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_delete_operations_created + ON delete_operations(created_at DESC) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_delete_operations_mode + ON delete_operations(mode, created_at DESC) + """) + cursor.execute(""" + CREATE TABLE IF NOT EXISTS delete_operation_items ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + operation_id TEXT NOT NULL, + item_type TEXT NOT NULL, + item_hash TEXT, + item_key TEXT, + payload_json TEXT, + created_at REAL NOT NULL, + FOREIGN KEY (operation_id) REFERENCES delete_operations(operation_id) ON DELETE CASCADE + ) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_delete_operation_items_operation + ON delete_operation_items(operation_id, id ASC) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_delete_operation_items_hash + ON delete_operation_items(item_hash) + """) # 新版 schema 包含完整字段,直接写入版本信息 cursor.execute("INSERT OR IGNORE INTO schema_migrations(version, applied_at) VALUES (?, ?)", (SCHEMA_VERSION, datetime.now().timestamp())) self._conn.commit() @@ -618,6 +677,63 @@ class MetadataStore: CREATE INDEX IF NOT EXISTS idx_external_memory_refs_paragraph ON external_memory_refs(paragraph_hash) """) + cursor.execute(""" + CREATE TABLE IF NOT EXISTS memory_v5_operations ( + operation_id TEXT PRIMARY KEY, + action TEXT NOT NULL, + target TEXT, + reason TEXT, + updated_by TEXT, + created_at REAL NOT NULL, + resolved_hashes_json TEXT, + result_json TEXT + ) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_memory_v5_operations_created + ON memory_v5_operations(created_at DESC) + """) + cursor.execute(""" + CREATE TABLE IF NOT EXISTS delete_operations ( + operation_id TEXT PRIMARY KEY, + mode TEXT NOT NULL, + selector TEXT, + reason TEXT, + requested_by TEXT, + status TEXT NOT NULL, + created_at REAL NOT NULL, + restored_at REAL, + summary_json TEXT + ) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_delete_operations_created + ON delete_operations(created_at DESC) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_delete_operations_mode + ON delete_operations(mode, created_at DESC) + """) + cursor.execute(""" + CREATE TABLE IF NOT EXISTS delete_operation_items ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + operation_id TEXT NOT NULL, + item_type TEXT NOT NULL, + item_hash TEXT, + item_key TEXT, + payload_json TEXT, + created_at REAL NOT NULL, + FOREIGN KEY (operation_id) REFERENCES delete_operations(operation_id) ON DELETE CASCADE + ) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_delete_operation_items_operation + ON delete_operation_items(operation_id, id ASC) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_delete_operation_items_hash + ON delete_operation_items(item_hash) + """) # 检查paragraphs表是否有knowledge_type列 cursor.execute("PRAGMA table_info(paragraphs)") @@ -2595,6 +2711,328 @@ class MetadataStore: "metadata": metadata or {}, } + @staticmethod + def _json_dumps(value: Any) -> str: + return json.dumps(value, ensure_ascii=False, sort_keys=True) + + @staticmethod + def _json_loads(value: Any, default: Any) -> Any: + if value in {None, ""}: + return default + try: + return json.loads(value) + except Exception: + return default + + def list_external_memory_refs_by_paragraphs(self, 
paragraph_hashes: List[str]) -> List[Dict[str, Any]]: + hashes = [str(item or "").strip() for item in (paragraph_hashes or []) if str(item or "").strip()] + if not hashes: + return [] + placeholders = ",".join(["?"] * len(hashes)) + cursor = self._conn.cursor() + cursor.execute( + f""" + SELECT external_id, paragraph_hash, source_type, created_at, metadata_json + FROM external_memory_refs + WHERE paragraph_hash IN ({placeholders}) + ORDER BY created_at ASC, external_id ASC + """, + tuple(hashes), + ) + items: List[Dict[str, Any]] = [] + for row in cursor.fetchall(): + payload = dict(row) + payload["metadata"] = self._json_loads(payload.get("metadata_json"), {}) + items.append(payload) + return items + + def delete_external_memory_refs_by_paragraphs(self, paragraph_hashes: List[str]) -> List[Dict[str, Any]]: + items = self.list_external_memory_refs_by_paragraphs(paragraph_hashes) + hashes = [str(item or "").strip() for item in (paragraph_hashes or []) if str(item or "").strip()] + if not hashes: + return items + placeholders = ",".join(["?"] * len(hashes)) + cursor = self._conn.cursor() + cursor.execute( + f"DELETE FROM external_memory_refs WHERE paragraph_hash IN ({placeholders})", + tuple(hashes), + ) + self._conn.commit() + return items + + def restore_external_memory_refs(self, refs: List[Dict[str, Any]]) -> int: + count = 0 + for item in refs or []: + external_id = str(item.get("external_id", "") or "").strip() + paragraph_hash = str(item.get("paragraph_hash", "") or "").strip() + if not external_id or not paragraph_hash: + continue + created_at = float(item.get("created_at") or datetime.now().timestamp()) + metadata_json = self._json_dumps(item.get("metadata") or {}) + cursor = self._conn.cursor() + cursor.execute( + """ + INSERT INTO external_memory_refs ( + external_id, paragraph_hash, source_type, created_at, metadata_json + ) + VALUES (?, ?, ?, ?, ?) + ON CONFLICT(external_id) DO UPDATE SET + paragraph_hash = excluded.paragraph_hash, + source_type = excluded.source_type, + created_at = excluded.created_at, + metadata_json = excluded.metadata_json + """, + ( + external_id, + paragraph_hash, + str(item.get("source_type", "") or "").strip() or None, + created_at, + metadata_json, + ), + ) + count += max(0, int(cursor.rowcount or 0)) + self._conn.commit() + return count + + def record_v5_operation( + self, + *, + action: str, + target: str, + resolved_hashes: List[str], + reason: str = "", + updated_by: str = "", + result: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + operation_id = f"v5_{uuid.uuid4().hex}" + created_at = datetime.now().timestamp() + payload = { + "operation_id": operation_id, + "action": str(action or "").strip(), + "target": str(target or "").strip(), + "reason": str(reason or "").strip(), + "updated_by": str(updated_by or "").strip(), + "created_at": created_at, + "resolved_hashes": [str(item or "").strip() for item in (resolved_hashes or []) if str(item or "").strip()], + "result": result or {}, + } + cursor = self._conn.cursor() + cursor.execute( + """ + INSERT INTO memory_v5_operations ( + operation_id, action, target, reason, updated_by, created_at, resolved_hashes_json, result_json + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) 
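+                -- one audit row per V5 action; resolved_hashes_json and
+                -- result_json are serialized via the _json_dumps helper above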
+ """, + ( + operation_id, + payload["action"], + payload["target"] or None, + payload["reason"] or None, + payload["updated_by"] or None, + created_at, + self._json_dumps(payload["resolved_hashes"]), + self._json_dumps(payload["result"]), + ), + ) + self._conn.commit() + return payload + + def create_delete_operation( + self, + *, + mode: str, + selector: Any, + items: List[Dict[str, Any]], + reason: str = "", + requested_by: str = "", + status: str = "executed", + summary: Optional[Dict[str, Any]] = None, + operation_id: Optional[str] = None, + ) -> Dict[str, Any]: + op_id = str(operation_id or f"del_{uuid.uuid4().hex}").strip() + created_at = datetime.now().timestamp() + normalized_items: List[Dict[str, Any]] = [] + for item in items or []: + if not isinstance(item, dict): + continue + item_type = str(item.get("item_type", "") or "").strip() + if not item_type: + continue + normalized_items.append( + { + "item_type": item_type, + "item_hash": str(item.get("item_hash", "") or "").strip() or None, + "item_key": str(item.get("item_key", "") or item.get("item_hash", "") or "").strip() or None, + "payload": item.get("payload") if isinstance(item.get("payload"), dict) else {}, + } + ) + + cursor = self._conn.cursor() + cursor.execute( + """ + INSERT INTO delete_operations ( + operation_id, mode, selector, reason, requested_by, status, created_at, restored_at, summary_json + ) VALUES (?, ?, ?, ?, ?, ?, ?, NULL, ?) + """, + ( + op_id, + str(mode or "").strip(), + self._json_dumps(selector if selector is not None else {}), + str(reason or "").strip() or None, + str(requested_by or "").strip() or None, + str(status or "executed").strip(), + created_at, + self._json_dumps(summary or {}), + ), + ) + if normalized_items: + cursor.executemany( + """ + INSERT INTO delete_operation_items ( + operation_id, item_type, item_hash, item_key, payload_json, created_at + ) VALUES (?, ?, ?, ?, ?, ?) + """, + [ + ( + op_id, + item["item_type"], + item["item_hash"], + item["item_key"], + self._json_dumps(item["payload"]), + created_at, + ) + for item in normalized_items + ], + ) + self._conn.commit() + return self.get_delete_operation(op_id) or { + "operation_id": op_id, + "mode": str(mode or "").strip(), + "selector": selector, + "reason": str(reason or "").strip(), + "requested_by": str(requested_by or "").strip(), + "status": str(status or "executed").strip(), + "created_at": created_at, + "summary": summary or {}, + "items": normalized_items, + } + + def mark_delete_operation_restored( + self, + operation_id: str, + *, + summary: Optional[Dict[str, Any]] = None, + ) -> bool: + token = str(operation_id or "").strip() + if not token: + return False + cursor = self._conn.cursor() + cursor.execute( + """ + UPDATE delete_operations + SET status = ?, restored_at = ?, summary_json = ? + WHERE operation_id = ? + """, + ( + "restored", + datetime.now().timestamp(), + self._json_dumps(summary or {}), + token, + ), + ) + self._conn.commit() + return cursor.rowcount > 0 + + def list_delete_operations(self, *, limit: int = 50, mode: str = "") -> List[Dict[str, Any]]: + cursor = self._conn.cursor() + params: List[Any] = [] + where = "" + mode_token = str(mode or "").strip().lower() + if mode_token: + where = "WHERE LOWER(mode) = ?" + params.append(mode_token) + params.append(max(1, int(limit or 50))) + cursor.execute( + f""" + SELECT operation_id, mode, selector, reason, requested_by, status, created_at, restored_at, summary_json + FROM delete_operations + {where} + ORDER BY created_at DESC + LIMIT ? 
+ """, + tuple(params), + ) + items: List[Dict[str, Any]] = [] + for row in cursor.fetchall(): + payload = dict(row) + payload["selector"] = self._json_loads(payload.get("selector"), {}) + payload["summary"] = self._json_loads(payload.get("summary_json"), {}) + items.append(payload) + return items + + def get_delete_operation(self, operation_id: str) -> Optional[Dict[str, Any]]: + token = str(operation_id or "").strip() + if not token: + return None + cursor = self._conn.cursor() + cursor.execute( + """ + SELECT operation_id, mode, selector, reason, requested_by, status, created_at, restored_at, summary_json + FROM delete_operations + WHERE operation_id = ? + LIMIT 1 + """, + (token,), + ) + row = cursor.fetchone() + if row is None: + return None + + payload = dict(row) + payload["selector"] = self._json_loads(payload.get("selector"), {}) + payload["summary"] = self._json_loads(payload.get("summary_json"), {}) + + cursor.execute( + """ + SELECT item_type, item_hash, item_key, payload_json, created_at + FROM delete_operation_items + WHERE operation_id = ? + ORDER BY id ASC + """, + (token,), + ) + payload["items"] = [ + { + "item_type": str(item["item_type"] or ""), + "item_hash": str(item["item_hash"] or ""), + "item_key": str(item["item_key"] or ""), + "payload": self._json_loads(item["payload_json"], {}), + "created_at": item["created_at"], + } + for item in cursor.fetchall() + ] + return payload + + def purge_deleted_relations(self, *, cutoff_time: float, limit: int = 1000) -> List[str]: + cursor = self._conn.cursor() + cursor.execute( + """ + SELECT hash + FROM deleted_relations + WHERE deleted_at IS NOT NULL AND deleted_at < ? + ORDER BY deleted_at ASC + LIMIT ? + """, + (float(cutoff_time), max(1, int(limit or 1000))), + ) + hashes = [str(row[0] or "").strip() for row in cursor.fetchall() if str(row[0] or "").strip()] + if not hashes: + return [] + placeholders = ",".join(["?"] * len(hashes)) + cursor.execute(f"DELETE FROM deleted_relations WHERE hash IN ({placeholders})", tuple(hashes)) + self._conn.commit() + return hashes + def get_statistics(self) -> Dict[str, int]: """ 获取统计信息 @@ -2956,6 +3394,18 @@ class MetadataStore: self._conn.commit() return changed + def restore_paragraph_by_hash(self, paragraph_hash: str) -> bool: + """恢复软删除段落。""" + cursor = self._conn.cursor() + cursor.execute( + "UPDATE paragraphs SET is_deleted=0, deleted_at=NULL WHERE hash=?", + (str(paragraph_hash),), + ) + changed = cursor.rowcount > 0 + if changed: + self._conn.commit() + return changed + def backfill_temporal_metadata_from_created_at( self, *, @@ -4698,6 +5148,29 @@ class MetadataStore: ) self._conn.commit() + def get_episode_pending_status_counts(self, source: str) -> Dict[str, int]: + """统计某个 source 当前 pending 队列中的状态分布。""" + token = self._normalize_episode_source(source) + if not token: + return {"pending": 0, "running": 0, "failed": 0, "done": 0} + + cursor = self._conn.cursor() + cursor.execute( + """ + SELECT status, COUNT(*) AS count + FROM episode_pending_paragraphs + WHERE TRIM(COALESCE(source, '')) = ? 
+ GROUP BY status + """, + (token,), + ) + counts = {"pending": 0, "running": 0, "failed": 0, "done": 0} + for row in cursor.fetchall(): + status = str(row["status"] or "").strip().lower() + if status in counts: + counts[status] = int(row["count"] or 0) + return counts + def _episode_row_to_dict(self, row: sqlite3.Row) -> Dict[str, Any]: data = dict(row) @@ -4904,7 +5377,7 @@ class MetadataStore: SELECT 1 FROM episode_rebuild_sources ers WHERE ers.source = TRIM(COALESCE(e.source, '')) - AND ers.status IN ('pending', 'running', 'failed') + AND ers.status IN ('pending', 'running') ) """ ) @@ -4948,6 +5421,26 @@ class MetadataStore: return source_expr, effective_start, effective_end, conditions, params + @staticmethod + def _tokenize_episode_query(query: str) -> Tuple[str, List[str]]: + """将 episode 查询归一化为短语和 token。""" + normalized = normalize_text(str(query or "")).strip().lower() + if not normalized: + return "", [] + + token_pattern = re.compile(r"[A-Za-z0-9_\u4e00-\u9fff]{2,}") + tokens: List[str] = [] + seen = set() + for token in token_pattern.findall(normalized): + if token in seen: + continue + seen.add(token) + tokens.append(token) + + if not tokens and len(normalized) >= 2: + tokens = [normalized] + return normalized, tokens + def get_episode_rows_by_paragraph_hashes( self, paragraph_hashes: List[str], @@ -5097,28 +5590,58 @@ class MetadataStore: source=source, ) - q = str(query or "").strip().lower() + q, tokens = self._tokenize_episode_query(query) select_score_sql = "0.0 AS lexical_score" order_sql = f"{effective_end} DESC, e.updated_at DESC" select_params: List[Any] = [] query_params: List[Any] = [] if q: - like = f"%{q}%" - title_expr = "LOWER(COALESCE(e.title, '')) LIKE ?" - summary_expr = "LOWER(COALESCE(e.summary, '')) LIKE ?" - keywords_expr = "LOWER(COALESCE(e.keywords_json, '')) LIKE ?" - participants_expr = "LOWER(COALESCE(e.participants_json, '')) LIKE ?" - conditions.append( - f"({title_expr} OR {summary_expr} OR {keywords_expr} OR {participants_expr})" + field_exprs = { + "title": "LOWER(COALESCE(e.title, ''))", + "summary": "LOWER(COALESCE(e.summary, ''))", + "keywords": "LOWER(COALESCE(e.keywords_json, ''))", + "participants": "LOWER(COALESCE(e.participants_json, ''))", + } + + score_parts: List[str] = [] + phrase_like = f"%{q}%" + score_parts.extend( + [ + f"CASE WHEN {field_exprs['title']} LIKE ? THEN 6.0 ELSE 0.0 END", + f"CASE WHEN {field_exprs['keywords']} LIKE ? THEN 4.5 ELSE 0.0 END", + f"CASE WHEN {field_exprs['summary']} LIKE ? THEN 3.0 ELSE 0.0 END", + f"CASE WHEN {field_exprs['participants']} LIKE ? THEN 2.0 ELSE 0.0 END", + ] ) - select_score_sql = ( - f"(CASE WHEN {title_expr} THEN 4.0 ELSE 0.0 END + " - f"CASE WHEN {keywords_expr} THEN 3.0 ELSE 0.0 END + " - f"CASE WHEN {summary_expr} THEN 2.0 ELSE 0.0 END + " - f"CASE WHEN {participants_expr} THEN 1.0 ELSE 0.0 END) AS lexical_score" - ) - select_params.extend([like, like, like, like]) - query_params.extend([like, like, like, like]) + select_params.extend([phrase_like, phrase_like, phrase_like, phrase_like]) + + token_predicates: List[str] = [] + for token in tokens: + like = f"%{token}%" + token_any = ( + f"({field_exprs['title']} LIKE ? OR " + f"{field_exprs['summary']} LIKE ? OR " + f"{field_exprs['keywords']} LIKE ? OR " + f"{field_exprs['participants']} LIKE ?)" + ) + token_predicates.append(token_any) + query_params.extend([like, like, like, like]) + + score_parts.append( + "(" + f"CASE WHEN {field_exprs['title']} LIKE ? THEN 3.0 ELSE 0.0 END + " + f"CASE WHEN {field_exprs['keywords']} LIKE ? 
THEN 2.5 ELSE 0.0 END + "
+                    f"CASE WHEN {field_exprs['summary']} LIKE ? THEN 2.0 ELSE 0.0 END + "
+                    f"CASE WHEN {field_exprs['participants']} LIKE ? THEN 1.5 ELSE 0.0 END + "
+                    f"CASE WHEN {token_any} THEN 2.0 ELSE 0.0 END"
+                    ")"
+                )
+                select_params.extend([like, like, like, like, like, like, like, like])
+
+            if token_predicates:
+                conditions.append("(" + " OR ".join(token_predicates) + ")")
+
+            select_score_sql = f"({' + '.join(score_parts)}) AS lexical_score"
+            order_sql = f"lexical_score DESC, {effective_end} DESC, e.updated_at DESC"
         where_sql = ("WHERE " + " AND ".join(conditions)) if conditions else ""
diff --git a/plugins/A_memorix/core/utils/aggregate_query_service.py b/plugins/A_memorix/core/utils/aggregate_query_service.py
index a87a4913..dcf64c34 100644
--- a/plugins/A_memorix/core/utils/aggregate_query_service.py
+++ b/plugins/A_memorix/core/utils/aggregate_query_service.py
@@ -302,7 +302,7 @@ class AggregateQueryService:
             )
         for (branch_name, _), payload in zip(scheduled, done):
             if isinstance(payload, Exception):
-                logger.error("aggregate branch failed: branch=%s error=%s", branch_name, payload)
+                logger.error(f"aggregate branch failed: branch={branch_name} error={payload}")
             normalized = self._normalize_branch_payload(
                 branch_name,
                 {
diff --git a/plugins/A_memorix/core/utils/episode_retrieval_service.py b/plugins/A_memorix/core/utils/episode_retrieval_service.py
index 5a4cd24d..44b22854 100644
--- a/plugins/A_memorix/core/utils/episode_retrieval_service.py
+++ b/plugins/A_memorix/core/utils/episode_retrieval_service.py
@@ -70,7 +70,7 @@ class EpisodeRetrievalService:
                 temporal=temporal,
             )
         except Exception as exc:
-            logger.warning("episode evidence retrieval failed, fallback to lexical only: %s", exc)
+            logger.warning(f"episode evidence retrieval failed, fallback to lexical only: {exc}")
         else:
             paragraph_rank_map: Dict[str, int] = {}
             relation_rank_map: Dict[str, int] = {}
diff --git a/plugins/A_memorix/core/utils/episode_segmentation_service.py b/plugins/A_memorix/core/utils/episode_segmentation_service.py
new file mode 100644
index 00000000..f42b1456
--- /dev/null
+++ b/plugins/A_memorix/core/utils/episode_segmentation_service.py
@@ -0,0 +1,304 @@
+"""
+Episode 语义切分服务(LLM 主路径)。
+
+职责:
+1. 组装语义切分提示词
+2. 调用 LLM 生成结构化 episode JSON
+3. 
严格校验输出结构,返回标准化结果 +""" + +from __future__ import annotations + +import json +from typing import Any, Dict, List, Optional, Tuple + +from src.common.logger import get_logger +from src.config.model_configs import TaskConfig +from src.config.config import model_config as host_model_config +from src.services import llm_service as llm_api + +logger = get_logger("A_Memorix.EpisodeSegmentationService") + + +class EpisodeSegmentationService: + """基于 LLM 的 episode 语义切分服务。""" + + SEGMENTATION_VERSION = "episode_mvp_v1" + + def __init__(self, plugin_config: Optional[dict] = None): + self.plugin_config = plugin_config or {} + + def _cfg(self, key: str, default: Any = None) -> Any: + current: Any = self.plugin_config + for part in key.split("."): + if isinstance(current, dict) and part in current: + current = current[part] + else: + return default + return current + + @staticmethod + def _is_task_config(obj: Any) -> bool: + return hasattr(obj, "model_list") and bool(getattr(obj, "model_list", [])) + + def _build_single_model_task(self, model_name: str, template: TaskConfig) -> TaskConfig: + return TaskConfig( + model_list=[model_name], + max_tokens=template.max_tokens, + temperature=template.temperature, + slow_threshold=template.slow_threshold, + selection_strategy=template.selection_strategy, + ) + + def _pick_template_task(self, available_tasks: Dict[str, Any]) -> Optional[TaskConfig]: + preferred = ("utils", "replyer", "planner", "tool_use") + for task_name in preferred: + cfg = available_tasks.get(task_name) + if self._is_task_config(cfg): + return cfg + for task_name, cfg in available_tasks.items(): + if task_name != "embedding" and self._is_task_config(cfg): + return cfg + for cfg in available_tasks.values(): + if self._is_task_config(cfg): + return cfg + return None + + def _resolve_model_config(self) -> Tuple[Optional[Any], str]: + available_tasks = llm_api.get_available_models() or {} + if not available_tasks: + return None, "unavailable" + + selector = str(self._cfg("episode.segmentation_model", "auto") or "auto").strip() + model_dict = getattr(host_model_config, "models_dict", {}) or {} + + if selector and selector.lower() != "auto": + direct_task = available_tasks.get(selector) + if self._is_task_config(direct_task): + return direct_task, selector + + if selector in model_dict: + template = self._pick_template_task(available_tasks) + if template is not None: + return self._build_single_model_task(selector, template), selector + + logger.warning(f"episode.segmentation_model='{selector}' 不可用,回退 auto") + + for task_name in ("utils", "replyer", "planner", "tool_use"): + cfg = available_tasks.get(task_name) + if self._is_task_config(cfg): + return cfg, task_name + + fallback = self._pick_template_task(available_tasks) + if fallback is not None: + return fallback, "auto" + return None, "unavailable" + + @staticmethod + def _clamp_score(value: Any, default: float = 0.0) -> float: + try: + num = float(value) + except Exception: + num = default + if num < 0.0: + return 0.0 + if num > 1.0: + return 1.0 + return num + + @staticmethod + def _safe_json_loads(text: str) -> Dict[str, Any]: + raw = str(text or "").strip() + if not raw: + raise ValueError("empty_response") + + if "```" in raw: + raw = raw.replace("```json", "```").replace("```JSON", "```") + parts = raw.split("```") + for part in parts: + part = part.strip() + if part.startswith("{") and part.endswith("}"): + raw = part + break + + try: + data = json.loads(raw) + if isinstance(data, dict): + return data + except Exception: + pass + + start 
= raw.find("{") + end = raw.rfind("}") + if start >= 0 and end > start: + candidate = raw[start : end + 1] + data = json.loads(candidate) + if isinstance(data, dict): + return data + + raise ValueError("invalid_json_response") + + def _build_prompt( + self, + *, + source: str, + window_start: Optional[float], + window_end: Optional[float], + paragraphs: List[Dict[str, Any]], + ) -> str: + rows: List[str] = [] + for idx, item in enumerate(paragraphs, 1): + p_hash = str(item.get("hash", "") or "").strip() + content = str(item.get("content", "") or "").strip().replace("\r\n", "\n") + content = content[:800] + event_start = item.get("event_time_start") + event_end = item.get("event_time_end") + event_time = item.get("event_time") + rows.append( + ( + f"[{idx}] hash={p_hash}\n" + f"event_time={event_time}\n" + f"event_time_start={event_start}\n" + f"event_time_end={event_end}\n" + f"content={content}" + ) + ) + + source_text = str(source or "").strip() or "unknown" + return ( + "You are an episode segmentation engine.\n" + "Group the given paragraphs into one or more coherent episodes.\n" + "Return JSON ONLY. No markdown, no explanation.\n" + "\n" + "Hard JSON schema:\n" + "{\n" + ' "episodes": [\n' + " {\n" + ' "title": "string",\n' + ' "summary": "string",\n' + ' "paragraph_hashes": ["hash1", "hash2"],\n' + ' "participants": ["person1", "person2"],\n' + ' "keywords": ["kw1", "kw2"],\n' + ' "time_confidence": 0.0,\n' + ' "llm_confidence": 0.0\n' + " }\n" + " ]\n" + "}\n" + "\n" + "Rules:\n" + "1) paragraph_hashes must come from input only.\n" + "2) title and summary must be non-empty.\n" + "3) keep participants/keywords concise and deduplicated.\n" + "4) if uncertain, still provide best effort confidence values.\n" + "\n" + f"source={source_text}\n" + f"window_start={window_start}\n" + f"window_end={window_end}\n" + "paragraphs:\n" + + "\n\n".join(rows) + ) + + def _normalize_episodes( + self, + *, + payload: Dict[str, Any], + input_hashes: List[str], + ) -> List[Dict[str, Any]]: + raw_episodes = payload.get("episodes") + if not isinstance(raw_episodes, list): + raise ValueError("episodes_missing_or_not_list") + + valid_hashes = set(input_hashes) + normalized: List[Dict[str, Any]] = [] + for item in raw_episodes: + if not isinstance(item, dict): + continue + + title = str(item.get("title", "") or "").strip() + summary = str(item.get("summary", "") or "").strip() + if not title or not summary: + continue + + raw_hashes = item.get("paragraph_hashes") + if not isinstance(raw_hashes, list): + continue + + dedup_hashes: List[str] = [] + seen_hashes = set() + for h in raw_hashes: + token = str(h or "").strip() + if not token or token in seen_hashes or token not in valid_hashes: + continue + seen_hashes.add(token) + dedup_hashes.append(token) + + if not dedup_hashes: + continue + + participants = [] + for p in item.get("participants", []) or []: + token = str(p or "").strip() + if token: + participants.append(token) + + keywords = [] + for kw in item.get("keywords", []) or []: + token = str(kw or "").strip() + if token: + keywords.append(token) + + normalized.append( + { + "title": title, + "summary": summary, + "paragraph_hashes": dedup_hashes, + "participants": participants[:16], + "keywords": keywords[:20], + "time_confidence": self._clamp_score(item.get("time_confidence"), default=1.0), + "llm_confidence": self._clamp_score(item.get("llm_confidence"), default=0.5), + } + ) + + if not normalized: + raise ValueError("episodes_all_invalid") + return normalized + + async def segment( + self, + *, + 
source: str, + window_start: Optional[float], + window_end: Optional[float], + paragraphs: List[Dict[str, Any]], + ) -> Dict[str, Any]: + if not paragraphs: + raise ValueError("paragraphs_empty") + + model_config, model_label = self._resolve_model_config() + if model_config is None: + raise RuntimeError("episode segmentation model unavailable") + + prompt = self._build_prompt( + source=source, + window_start=window_start, + window_end=window_end, + paragraphs=paragraphs, + ) + success, response, _, _ = await llm_api.generate_with_model( + prompt=prompt, + model_config=model_config, + request_type="A_Memorix.EpisodeSegmentation", + ) + if not success or not response: + raise RuntimeError("llm_generate_failed") + + payload = self._safe_json_loads(str(response)) + input_hashes = [str(p.get("hash", "") or "").strip() for p in paragraphs] + episodes = self._normalize_episodes(payload=payload, input_hashes=input_hashes) + + return { + "episodes": episodes, + "segmentation_model": model_label, + "segmentation_version": self.SEGMENTATION_VERSION, + } + diff --git a/plugins/A_memorix/core/utils/episode_service.py b/plugins/A_memorix/core/utils/episode_service.py new file mode 100644 index 00000000..ca94dd96 --- /dev/null +++ b/plugins/A_memorix/core/utils/episode_service.py @@ -0,0 +1,558 @@ +""" +Episode 聚合与落库服务。 + +流程: +1. 从 pending 队列读取段落并组批 +2. 按 source + 时间窗口切组 +3. 调用 LLM 语义切分 +4. 写入 episodes + episode_paragraphs +5. LLM 失败时使用确定性 fallback +""" + +from __future__ import annotations + +import json +import re +from collections import Counter +from datetime import datetime +from typing import Any, Dict, List, Optional, Tuple + +from src.common.logger import get_logger + +from .episode_segmentation_service import EpisodeSegmentationService +from .hash import compute_hash + +logger = get_logger("A_Memorix.EpisodeService") + + +class EpisodeService: + """Episode MVP 后台处理服务。""" + + def __init__( + self, + *, + metadata_store: Any, + plugin_config: Optional[Any] = None, + segmentation_service: Optional[EpisodeSegmentationService] = None, + ): + self.metadata_store = metadata_store + self.plugin_config = plugin_config or {} + self.segmentation_service = segmentation_service or EpisodeSegmentationService( + plugin_config=self._config_dict(), + ) + + def _config_dict(self) -> Dict[str, Any]: + if isinstance(self.plugin_config, dict): + return self.plugin_config + return {} + + def _cfg(self, key: str, default: Any = None) -> Any: + getter = getattr(self.plugin_config, "get_config", None) + if callable(getter): + return getter(key, default) + + current: Any = self.plugin_config + for part in key.split("."): + if isinstance(current, dict) and part in current: + current = current[part] + else: + return default + return current + + @staticmethod + def _to_optional_float(value: Any) -> Optional[float]: + if value is None: + return None + try: + return float(value) + except Exception: + return None + + @staticmethod + def _clamp_score(value: Any, default: float = 1.0) -> float: + try: + num = float(value) + except Exception: + num = default + if num < 0.0: + return 0.0 + if num > 1.0: + return 1.0 + return num + + @staticmethod + def _paragraph_anchor(paragraph: Dict[str, Any]) -> float: + for key in ("event_time_end", "event_time_start", "event_time", "created_at"): + value = paragraph.get(key) + try: + if value is not None: + return float(value) + except Exception: + continue + return 0.0 + + @staticmethod + def _paragraph_sort_key(paragraph: Dict[str, Any]) -> Tuple[float, str]: + return ( + 
EpisodeService._paragraph_anchor(paragraph), + str(paragraph.get("hash", "") or ""), + ) + + def load_pending_paragraphs( + self, + pending_rows: List[Dict[str, Any]], + ) -> Tuple[List[Dict[str, Any]], List[str]]: + """ + 将 pending 行展开为段落上下文。 + + Returns: + (loaded_paragraphs, missing_hashes) + """ + loaded: List[Dict[str, Any]] = [] + missing: List[str] = [] + for row in pending_rows or []: + p_hash = str(row.get("paragraph_hash", "") or "").strip() + if not p_hash: + continue + + paragraph = self.metadata_store.get_paragraph(p_hash) + if not paragraph: + missing.append(p_hash) + continue + + loaded.append( + { + "hash": p_hash, + "source": str(row.get("source") or paragraph.get("source") or "").strip(), + "content": str(paragraph.get("content", "") or ""), + "created_at": self._to_optional_float(paragraph.get("created_at")) + or self._to_optional_float(row.get("created_at")) + or 0.0, + "event_time": self._to_optional_float(paragraph.get("event_time")), + "event_time_start": self._to_optional_float(paragraph.get("event_time_start")), + "event_time_end": self._to_optional_float(paragraph.get("event_time_end")), + "time_granularity": str(paragraph.get("time_granularity", "") or "").strip() or None, + "time_confidence": self._clamp_score(paragraph.get("time_confidence"), default=1.0), + } + ) + return loaded, missing + + def group_paragraphs(self, paragraphs: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + 按 source + 时间邻近窗口组批,并受段落数/字符数上限约束。 + """ + if not paragraphs: + return [] + + max_paragraphs = max(1, int(self._cfg("episode.max_paragraphs_per_call", 20))) + max_chars = max(200, int(self._cfg("episode.max_chars_per_call", 6000))) + window_seconds = max( + 60.0, + float(self._cfg("episode.source_time_window_hours", 24)) * 3600.0, + ) + + by_source: Dict[str, List[Dict[str, Any]]] = {} + for paragraph in paragraphs: + source = str(paragraph.get("source", "") or "").strip() + by_source.setdefault(source, []).append(paragraph) + + groups: List[Dict[str, Any]] = [] + for source, items in by_source.items(): + ordered = sorted(items, key=self._paragraph_sort_key) + + current: List[Dict[str, Any]] = [] + current_chars = 0 + last_anchor: Optional[float] = None + + def flush() -> None: + nonlocal current, current_chars, last_anchor + if not current: + return + sorted_current = sorted(current, key=self._paragraph_sort_key) + groups.append( + { + "source": source, + "paragraphs": sorted_current, + } + ) + current = [] + current_chars = 0 + last_anchor = None + + for paragraph in ordered: + anchor = self._paragraph_anchor(paragraph) + content_len = len(str(paragraph.get("content", "") or "")) + + need_flush = False + if current: + if len(current) >= max_paragraphs: + need_flush = True + elif current_chars + content_len > max_chars: + need_flush = True + elif last_anchor is not None and abs(anchor - last_anchor) > window_seconds: + need_flush = True + + if need_flush: + flush() + + current.append(paragraph) + current_chars += content_len + last_anchor = anchor + + flush() + + groups.sort( + key=lambda g: self._paragraph_anchor(g["paragraphs"][0]) if g.get("paragraphs") else 0.0 + ) + return groups + + def _compute_time_meta(self, paragraphs: List[Dict[str, Any]]) -> Tuple[Optional[float], Optional[float], Optional[str], float]: + starts: List[float] = [] + ends: List[float] = [] + granularity_priority = { + "minute": 4, + "hour": 3, + "day": 2, + "month": 1, + "year": 0, + } + granularity = None + granularity_rank = -1 + conf_values: List[float] = [] + + for p in paragraphs: + s = 
self._to_optional_float(p.get("event_time_start")) + e = self._to_optional_float(p.get("event_time_end")) + t = self._to_optional_float(p.get("event_time")) + c = self._to_optional_float(p.get("created_at")) + + start_candidate = s if s is not None else (t if t is not None else (e if e is not None else c)) + end_candidate = e if e is not None else (t if t is not None else (s if s is not None else c)) + + if start_candidate is not None: + starts.append(start_candidate) + if end_candidate is not None: + ends.append(end_candidate) + + g = str(p.get("time_granularity", "") or "").strip().lower() + if g in granularity_priority and granularity_priority[g] > granularity_rank: + granularity_rank = granularity_priority[g] + granularity = g + + conf_values.append(self._clamp_score(p.get("time_confidence"), default=1.0)) + + time_start = min(starts) if starts else None + time_end = max(ends) if ends else None + time_conf = sum(conf_values) / len(conf_values) if conf_values else 1.0 + return time_start, time_end, granularity, self._clamp_score(time_conf, default=1.0) + + def _collect_participants(self, paragraph_hashes: List[str], limit: int = 16) -> List[str]: + seen = set() + participants: List[str] = [] + for p_hash in paragraph_hashes: + try: + entities = self.metadata_store.get_paragraph_entities(p_hash) + except Exception: + entities = [] + for item in entities: + name = str(item.get("name", "") or "").strip() + if not name: + continue + key = name.lower() + if key in seen: + continue + seen.add(key) + participants.append(name) + if len(participants) >= limit: + return participants + return participants + + @staticmethod + def _derive_keywords(paragraphs: List[Dict[str, Any]], limit: int = 12) -> List[str]: + token_counter: Counter[str] = Counter() + token_pattern = re.compile(r"[A-Za-z0-9_\u4e00-\u9fff]{2,}") + stop_words = { + "the", + "and", + "that", + "this", + "with", + "from", + "for", + "have", + "will", + "your", + "you", + "我们", + "你们", + "他们", + "以及", + "一个", + "这个", + "那个", + "然后", + "因为", + "所以", + } + for p in paragraphs: + text = str(p.get("content", "") or "").lower() + for token in token_pattern.findall(text): + if token in stop_words: + continue + token_counter[token] += 1 + + return [token for token, _ in token_counter.most_common(limit)] + + def _build_fallback_episode(self, group: Dict[str, Any]) -> Dict[str, Any]: + paragraphs = group.get("paragraphs", []) or [] + source = str(group.get("source", "") or "").strip() + hashes = [str(p.get("hash", "") or "").strip() for p in paragraphs if str(p.get("hash", "") or "").strip()] + snippets = [] + for p in paragraphs[:3]: + text = str(p.get("content", "") or "").strip().replace("\n", " ") + if text: + snippets.append(text[:140]) + summary = ";".join(snippets)[:500] if snippets else "自动回退生成的情景记忆。" + + time_start, time_end, granularity, time_conf = self._compute_time_meta(paragraphs) + participants = self._collect_participants(hashes, limit=12) + keywords = self._derive_keywords(paragraphs, limit=10) + + if time_start is not None: + day_text = datetime.fromtimestamp(time_start).strftime("%Y-%m-%d") + title = f"{source or 'unknown'} {day_text} 情景片段" + else: + title = f"{source or 'unknown'} 情景片段" + + return { + "title": title[:80], + "summary": summary, + "paragraph_hashes": hashes, + "participants": participants, + "keywords": keywords, + "time_confidence": time_conf, + "llm_confidence": 0.0, + "event_time_start": time_start, + "event_time_end": time_end, + "time_granularity": granularity, + "segmentation_model": "fallback_rule", + 
"segmentation_version": EpisodeSegmentationService.SEGMENTATION_VERSION, + } + + @staticmethod + def _normalize_episode_hashes(episode_hashes: List[str], group_hashes_ordered: List[str]) -> List[str]: + in_group = set(group_hashes_ordered) + dedup: List[str] = [] + seen = set() + for h in episode_hashes or []: + token = str(h or "").strip() + if not token or token not in in_group or token in seen: + continue + seen.add(token) + dedup.append(token) + return dedup + + async def _build_episode_payloads_for_group(self, group: Dict[str, Any]) -> Dict[str, Any]: + paragraphs = group.get("paragraphs", []) or [] + if not paragraphs: + return { + "payloads": [], + "done_hashes": [], + "episode_count": 0, + "fallback_count": 0, + } + + source = str(group.get("source", "") or "").strip() + group_hashes = [str(p.get("hash", "") or "").strip() for p in paragraphs if str(p.get("hash", "") or "").strip()] + group_start, group_end, _, _ = self._compute_time_meta(paragraphs) + + fallback_used = False + segmentation_model = "fallback_rule" + segmentation_version = EpisodeSegmentationService.SEGMENTATION_VERSION + + try: + llm_result = await self.segmentation_service.segment( + source=source, + window_start=group_start, + window_end=group_end, + paragraphs=paragraphs, + ) + episodes = list(llm_result.get("episodes") or []) + segmentation_model = str(llm_result.get("segmentation_model", "") or "").strip() or "auto" + segmentation_version = str(llm_result.get("segmentation_version", "") or "").strip() or EpisodeSegmentationService.SEGMENTATION_VERSION + if not episodes: + raise ValueError("llm_empty_episodes") + except Exception as e: + logger.warning( + "Episode segmentation fallback: " + f"source={source} " + f"size={len(group_hashes)} " + f"err={e}" + ) + episodes = [self._build_fallback_episode(group)] + fallback_used = True + + stored_payloads: List[Dict[str, Any]] = [] + for episode in episodes: + ordered_hashes = self._normalize_episode_hashes( + episode_hashes=episode.get("paragraph_hashes", []), + group_hashes_ordered=group_hashes, + ) + if not ordered_hashes: + continue + + sub_paragraphs = [p for p in paragraphs if str(p.get("hash", "") or "") in set(ordered_hashes)] + event_start, event_end, granularity, time_conf_default = self._compute_time_meta(sub_paragraphs) + + participants = [str(x).strip() for x in (episode.get("participants", []) or []) if str(x).strip()] + keywords = [str(x).strip() for x in (episode.get("keywords", []) or []) if str(x).strip()] + if not participants: + participants = self._collect_participants(ordered_hashes, limit=16) + if not keywords: + keywords = self._derive_keywords(sub_paragraphs, limit=12) + + title = str(episode.get("title", "") or "").strip()[:120] + summary = str(episode.get("summary", "") or "").strip()[:2000] + if not title or not summary: + continue + + seed = json.dumps( + { + "source": source, + "hashes": ordered_hashes, + "version": segmentation_version, + }, + ensure_ascii=False, + sort_keys=True, + ) + episode_id = compute_hash(seed) + + payload = { + "episode_id": episode_id, + "source": source or None, + "title": title, + "summary": summary, + "event_time_start": episode.get("event_time_start", event_start), + "event_time_end": episode.get("event_time_end", event_end), + "time_granularity": episode.get("time_granularity", granularity), + "time_confidence": self._clamp_score( + episode.get("time_confidence"), + default=time_conf_default, + ), + "participants": participants[:16], + "keywords": keywords[:20], + "evidence_ids": ordered_hashes, + 
"paragraph_count": len(ordered_hashes), + "llm_confidence": self._clamp_score( + episode.get("llm_confidence"), + default=0.0 if fallback_used else 0.6, + ), + "segmentation_model": ( + str(episode.get("segmentation_model", "") or "").strip() + or ("fallback_rule" if fallback_used else segmentation_model) + ), + "segmentation_version": ( + str(episode.get("segmentation_version", "") or "").strip() + or segmentation_version + ), + } + stored_payloads.append(payload) + + return { + "payloads": stored_payloads, + "done_hashes": group_hashes, + "episode_count": len(stored_payloads), + "fallback_count": 1 if fallback_used else 0, + } + + async def process_group(self, group: Dict[str, Any]) -> Dict[str, Any]: + result = await self._build_episode_payloads_for_group(group) + stored_count = 0 + for payload in result.get("payloads") or []: + stored = self.metadata_store.upsert_episode(payload) + final_id = str(stored.get("episode_id") or payload.get("episode_id") or "") + if final_id: + self.metadata_store.bind_episode_paragraphs( + final_id, + list(payload.get("evidence_ids") or []), + ) + stored_count += 1 + + result["episode_count"] = stored_count + return { + "done_hashes": list(result.get("done_hashes") or []), + "episode_count": stored_count, + "fallback_count": int(result.get("fallback_count") or 0), + } + + async def process_pending_rows(self, pending_rows: List[Dict[str, Any]]) -> Dict[str, Any]: + loaded, missing_hashes = self.load_pending_paragraphs(pending_rows) + groups = self.group_paragraphs(loaded) + + done_hashes: List[str] = list(missing_hashes) + failed_hashes: Dict[str, str] = {} + episode_count = 0 + fallback_count = 0 + + for group in groups: + group_hashes = [str(p.get("hash", "") or "").strip() for p in (group.get("paragraphs") or [])] + try: + result = await self.process_group(group) + done_hashes.extend(result.get("done_hashes") or []) + episode_count += int(result.get("episode_count") or 0) + fallback_count += int(result.get("fallback_count") or 0) + except Exception as e: + err = str(e)[:500] + for h in group_hashes: + if h: + failed_hashes[h] = err + + dedup_done = list(dict.fromkeys([h for h in done_hashes if h])) + return { + "done_hashes": dedup_done, + "failed_hashes": failed_hashes, + "episode_count": episode_count, + "fallback_count": fallback_count, + "missing_count": len(missing_hashes), + "group_count": len(groups), + } + + async def rebuild_source(self, source: str) -> Dict[str, Any]: + token = str(source or "").strip() + if not token: + return { + "source": "", + "episode_count": 0, + "fallback_count": 0, + "group_count": 0, + "paragraph_count": 0, + } + + paragraphs = self.metadata_store.get_live_paragraphs_by_source(token) + if not paragraphs: + replace_result = self.metadata_store.replace_episodes_for_source(token, []) + return { + "source": token, + "episode_count": int(replace_result.get("episode_count") or 0), + "fallback_count": 0, + "group_count": 0, + "paragraph_count": 0, + } + + groups = self.group_paragraphs(paragraphs) + payloads: List[Dict[str, Any]] = [] + fallback_count = 0 + + for group in groups: + result = await self._build_episode_payloads_for_group(group) + payloads.extend(list(result.get("payloads") or [])) + fallback_count += int(result.get("fallback_count") or 0) + + replace_result = self.metadata_store.replace_episodes_for_source(token, payloads) + return { + "source": token, + "episode_count": int(replace_result.get("episode_count") or 0), + "fallback_count": fallback_count, + "group_count": len(groups), + "paragraph_count": 
len(paragraphs), + } diff --git a/plugins/A_memorix/core/utils/person_profile_service.py b/plugins/A_memorix/core/utils/person_profile_service.py index ccbbaf90..6460c013 100644 --- a/plugins/A_memorix/core/utils/person_profile_service.py +++ b/plugins/A_memorix/core/utils/person_profile_service.py @@ -9,7 +9,11 @@ import json import time from typing import Any, Dict, List, Optional, Tuple +from sqlalchemy import or_ +from sqlmodel import select + from src.common.logger import get_logger +from src.common.database.database import get_db_session from src.common.database.database_model import PersonInfo from ..embedding import EmbeddingAPIAdapter @@ -120,31 +124,40 @@ class PersonProfileService: if not key: return "" + try: + with get_db_session(auto_commit=False) as session: + record = session.exec( + select(PersonInfo.person_id).where(PersonInfo.person_id == key).limit(1) + ).first() + if record: + return str(record) + + record = session.exec( + select(PersonInfo.person_id) + .where( + or_( + PersonInfo.person_name == key, + PersonInfo.user_nickname == key, + ) + ) + .limit(1) + ).first() + if record: + return str(record) + + record = session.exec( + select(PersonInfo.person_id) + .where(PersonInfo.group_cardname.contains(key)) + .limit(1) + ).first() + if record: + return str(record) + except Exception as e: + logger.warning(f"按别名解析 person_id 失败: identifier={key}, err={e}") + if len(key) == 32 and all(ch in "0123456789abcdefABCDEF" for ch in key): return key.lower() - try: - record = ( - PersonInfo.select(PersonInfo.person_id) - .where((PersonInfo.person_name == key) | (PersonInfo.nickname == key)) - .first() - ) - if record and record.person_id: - return str(record.person_id) - except Exception: - pass - - try: - record = ( - PersonInfo.select(PersonInfo.person_id) - .where(PersonInfo.group_nick_name.contains(key)) - .first() - ) - if record and record.person_id: - return str(record.person_id) - except Exception: - pass - return "" def _parse_group_nicks(self, raw_value: Any) -> List[str]: @@ -160,7 +173,7 @@ class PersonProfileService: names: List[str] = [] for item in items: if isinstance(item, dict): - value = str(item.get("group_nick_name", "")).strip() + value = str(item.get("group_cardname") or item.get("group_nick_name") or "").strip() if value: names.append(value) elif isinstance(item, str): @@ -193,6 +206,42 @@ class PersonProfileService: traits.append(text) return traits[:10] + def _recover_aliases_from_memory(self, person_id: str) -> Tuple[List[str], str]: + """当人物主档案缺失时,从已有记忆证据里回捞可用别名。""" + if not person_id: + return [], "" + + aliases: List[str] = [] + primary_name = "" + seen = set() + + try: + paragraphs = self.metadata_store.get_paragraphs_by_entity(person_id) + except Exception as e: + logger.warning(f"从记忆证据回捞人物别名失败: person_id={person_id}, err={e}") + return [], "" + + for paragraph in paragraphs[:20]: + paragraph_hash = str(paragraph.get("hash", "") or "").strip() + if not paragraph_hash: + continue + try: + paragraph_entities = self.metadata_store.get_paragraph_entities(paragraph_hash) + except Exception: + paragraph_entities = [] + for entity in paragraph_entities: + name = str(entity.get("name", "") or "").strip() + if not name or name == person_id: + continue + key = name.lower() + if key in seen: + continue + seen.add(key) + aliases.append(name) + if not primary_name: + primary_name = name + return aliases, primary_name + def get_person_aliases(self, person_id: str) -> Tuple[List[str], str, List[str]]: """获取人物别名集合、主展示名、记忆特征。""" aliases: List[str] = [] @@ -200,18 
+249,28 @@ class PersonProfileService: memory_traits: List[str] = [] if not person_id: return aliases, primary_name, memory_traits + recovered_aliases, recovered_primary_name = self._recover_aliases_from_memory(person_id) try: - record = PersonInfo.get_or_none(PersonInfo.person_id == person_id) - if not record: - return aliases, primary_name, memory_traits + with get_db_session(auto_commit=False) as session: + record = session.exec( + select(PersonInfo).where(PersonInfo.person_id == person_id).limit(1) + ).first() + if not record: + return recovered_aliases, recovered_primary_name or person_id, memory_traits person_name = str(getattr(record, "person_name", "") or "").strip() - nickname = str(getattr(record, "nickname", "") or "").strip() - group_nicks = self._parse_group_nicks(getattr(record, "group_nick_name", None)) + nickname = str(getattr(record, "user_nickname", "") or "").strip() + group_nicks = self._parse_group_nicks(getattr(record, "group_cardname", None)) memory_traits = self._parse_memory_traits(getattr(record, "memory_points", None)) - primary_name = person_name or nickname or str(getattr(record, "user_id", "") or "").strip() or person_id + primary_name = ( + person_name + or nickname + or recovered_primary_name + or str(getattr(record, "user_id", "") or "").strip() + or person_id + ) - candidates = [person_name, nickname] + group_nicks + candidates = [person_name, nickname] + group_nicks + recovered_aliases seen = set() for item in candidates: norm = str(item or "").strip() diff --git a/plugins/A_memorix/core/utils/relation_write_service.py b/plugins/A_memorix/core/utils/relation_write_service.py index b73e1260..6fa2e621 100644 --- a/plugins/A_memorix/core/utils/relation_write_service.py +++ b/plugins/A_memorix/core/utils/relation_write_service.py @@ -82,8 +82,9 @@ class RelationWriteService: ) self.metadata_store.set_relation_vector_state(hash_value, "ready") logger.info( - "metric.relation_vector_write_success=1 metric.relation_vector_write_success_count=1 hash=%s", - hash_value[:16], + "metric.relation_vector_write_success=1 " + "metric.relation_vector_write_success_count=1 " + f"hash={hash_value[:16]}" ) return RelationWriteResult( hash_value=hash_value, @@ -109,9 +110,10 @@ class RelationWriteService: bump_retry=True, ) logger.warning( - "metric.relation_vector_write_fail=1 metric.relation_vector_write_fail_count=1 hash=%s err=%s", - hash_value[:16], - err, + "metric.relation_vector_write_fail=1 " + "metric.relation_vector_write_fail_count=1 " + f"hash={hash_value[:16]} " + f"err={err}" ) return RelationWriteResult( hash_value=hash_value, diff --git a/plugins/A_memorix/core/utils/retrieval_tuning_manager.py b/plugins/A_memorix/core/utils/retrieval_tuning_manager.py new file mode 100644 index 00000000..e0e8ecd6 --- /dev/null +++ b/plugins/A_memorix/core/utils/retrieval_tuning_manager.py @@ -0,0 +1,1857 @@ +""" +Retrieval tuning manager for WebUI. 
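+
+Runs queued tuning tasks one at a time: each task replays a sampled query set
+against candidate retrieval profiles round by round, scores every round for
+the chosen objective (precision_priority / balanced / recall_priority) and
+intensity (quick=8, standard=20, deep=32 rounds), and keeps the best-scoring
+profile together with per-round records for later inspection and rollback.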
+""" + +from __future__ import annotations + +import asyncio +import copy +import json +import random +import re +import time +import uuid +from collections import Counter, deque +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Tuple + +from src.common.logger import get_logger + +from ..runtime.search_runtime_initializer import build_search_runtime +from .search_execution_service import SearchExecutionRequest, SearchExecutionService + +try: + from src.services import llm_service as llm_api +except Exception: # pragma: no cover + llm_api = None + +logger = get_logger("A_Memorix.RetrievalTuningManager") + + +OBJECTIVES = {"precision_priority", "balanced", "recall_priority"} +INTENSITIES = {"quick": 8, "standard": 20, "deep": 32} +CATEGORIES = {"query_nl", "query_kw", "spo_relation", "spo_search"} +_RUNTIME_CONFIG_INSTANCE_KEYS = { + "vector_store", + "graph_store", + "metadata_store", + "embedding_manager", + "sparse_index", + "relation_write_service", + "plugin_instance", +} + + +def _now() -> float: + return time.time() + + +def _clamp_int(value: Any, default: int, min_value: int, max_value: int) -> int: + try: + parsed = int(value) + except Exception: + parsed = int(default) + return max(min_value, min(max_value, parsed)) + + +def _clamp_float(value: Any, default: float, min_value: float, max_value: float) -> float: + try: + parsed = float(value) + except Exception: + parsed = float(default) + return max(min_value, min(max_value, parsed)) + + +def _coerce_bool(value: Any, default: bool) -> bool: + if value is None: + return default + if isinstance(value, bool): + return value + if isinstance(value, (int, float)): + return bool(value) + text = str(value).strip().lower() + if text in {"1", "true", "yes", "y", "on"}: + return True + if text in {"0", "false", "no", "n", "off"}: + return False + return default + + +def _nested_get(data: Dict[str, Any], key: str, default: Any = None) -> Any: + cur: Any = data + for part in key.split("."): + if isinstance(cur, dict) and part in cur: + cur = cur[part] + else: + return default + return cur + + +def _nested_set(data: Dict[str, Any], key: str, value: Any) -> None: + parts = key.split(".") + cur = data + for part in parts[:-1]: + if part not in cur or not isinstance(cur[part], dict): + cur[part] = {} + cur = cur[part] + cur[parts[-1]] = value + + +def _deep_merge(base: Dict[str, Any], patch: Dict[str, Any]) -> Dict[str, Any]: + out = copy.deepcopy(base) + for key, value in (patch or {}).items(): + if isinstance(value, dict) and isinstance(out.get(key), dict): + out[key] = _deep_merge(out[key], value) + else: + out[key] = copy.deepcopy(value) + return out + + +def _safe_json_loads(text: str) -> Optional[Any]: + raw = str(text or "").strip() + if not raw: + return None + if "```" in raw: + raw = raw.replace("```json", "```") + for seg in raw.split("```"): + seg = seg.strip() + if seg.startswith("{") or seg.startswith("["): + raw = seg + break + try: + return json.loads(raw) + except Exception: + pass + s = raw.find("{") + e = raw.rfind("}") + if s >= 0 and e > s: + try: + return json.loads(raw[s : e + 1]) + except Exception: + return None + return None + + +@dataclass +class RetrievalQueryCase: + case_id: str + category: str + query: str + expected_hashes: List[str] = field(default_factory=list) + expected_spo: Dict[str, str] = field(default_factory=dict) + metadata: Dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> Dict[str, Any]: + return { + "case_id": 
self.case_id, + "category": self.category, + "query": self.query, + "expected_hashes": list(self.expected_hashes), + "expected_spo": dict(self.expected_spo), + "metadata": dict(self.metadata), + } + + +@dataclass +class RetrievalTuningRoundRecord: + round_index: int + candidate_profile: Dict[str, Any] + metrics: Dict[str, Any] + score: float + latency_ms: float + failure_summary: Dict[str, Any] = field(default_factory=dict) + created_at: float = field(default_factory=_now) + + def to_dict(self) -> Dict[str, Any]: + return { + "round_index": self.round_index, + "candidate_profile": copy.deepcopy(self.candidate_profile), + "metrics": copy.deepcopy(self.metrics), + "score": float(self.score), + "latency_ms": float(self.latency_ms), + "failure_summary": copy.deepcopy(self.failure_summary), + "created_at": float(self.created_at), + } + + +@dataclass +class RetrievalTuningTaskRecord: + task_id: str + status: str + progress: float + objective: str + intensity: str + rounds_total: int + rounds_done: int = 0 + best_profile: Dict[str, Any] = field(default_factory=dict) + best_metrics: Dict[str, Any] = field(default_factory=dict) + best_score: float = -1.0 + baseline_profile: Dict[str, Any] = field(default_factory=dict) + baseline_metrics: Dict[str, Any] = field(default_factory=dict) + error: str = "" + params: Dict[str, Any] = field(default_factory=dict) + query_set_stats: Dict[str, Any] = field(default_factory=dict) + artifact_paths: Dict[str, str] = field(default_factory=dict) + rounds: List[RetrievalTuningRoundRecord] = field(default_factory=list) + cancel_requested: bool = False + created_at: float = field(default_factory=_now) + started_at: Optional[float] = None + finished_at: Optional[float] = None + updated_at: float = field(default_factory=_now) + apply_log: List[Dict[str, Any]] = field(default_factory=list) + + def to_summary(self) -> Dict[str, Any]: + return { + "task_id": self.task_id, + "status": self.status, + "progress": self.progress, + "objective": self.objective, + "intensity": self.intensity, + "rounds_total": self.rounds_total, + "rounds_done": self.rounds_done, + "best_score": self.best_score, + "error": self.error, + "query_set_stats": dict(self.query_set_stats), + "artifact_paths": dict(self.artifact_paths), + "created_at": self.created_at, + "started_at": self.started_at, + "finished_at": self.finished_at, + "updated_at": self.updated_at, + } + + def to_detail(self, include_rounds: bool = False) -> Dict[str, Any]: + payload = self.to_summary() + payload.update( + { + "params": copy.deepcopy(self.params), + "best_profile": copy.deepcopy(self.best_profile), + "best_metrics": copy.deepcopy(self.best_metrics), + "baseline_profile": copy.deepcopy(self.baseline_profile), + "baseline_metrics": copy.deepcopy(self.baseline_metrics), + "apply_log": copy.deepcopy(self.apply_log), + } + ) + if include_rounds: + payload["rounds"] = [x.to_dict() for x in self.rounds] + return payload + + +class RetrievalTuningManager: + def __init__( + self, + plugin: Any, + *, + import_write_blocked_provider: Optional[Callable[[], bool]] = None, + ): + self.plugin = plugin + self._import_write_blocked_provider = import_write_blocked_provider + + self._lock = asyncio.Lock() + self._tasks: Dict[str, RetrievalTuningTaskRecord] = {} + self._task_order: deque[str] = deque() + self._queue: deque[str] = deque() + self._active_task_id: Optional[str] = None + self._worker_task: Optional[asyncio.Task] = None + self._stopping = False + + self._rollback_snapshot: Optional[Dict[str, Any]] = None + + 
self._artifacts_root = Path(__file__).resolve().parents[2] / "artifacts" / "retrieval_tuning" + self._artifacts_root.mkdir(parents=True, exist_ok=True) + + def _cfg(self, key: str, default: Any = None) -> Any: + getter = getattr(self.plugin, "get_config", None) + if callable(getter): + return getter(key, default) + return default + + def _is_enabled(self) -> bool: + return bool(self._cfg("web.tuning.enabled", True)) + + def _queue_limit(self) -> int: + return _clamp_int(self._cfg("web.tuning.max_queue_size", 8), 8, 1, 100) + + def _poll_interval_s(self) -> float: + ms = _clamp_int(self._cfg("web.tuning.poll_interval_ms", 1200), 1200, 200, 60000) + return max(0.2, ms / 1000.0) + + def _llm_retry_cfg(self) -> Dict[str, Any]: + return { + "max_attempts": _clamp_int(self._cfg("web.tuning.llm_retry.max_attempts", 3), 3, 1, 10), + "min_wait_seconds": _clamp_float(self._cfg("web.tuning.llm_retry.min_wait_seconds", 2), 2.0, 0.1, 60.0), + "max_wait_seconds": _clamp_float(self._cfg("web.tuning.llm_retry.max_wait_seconds", 20), 20.0, 0.2, 120.0), + "backoff_multiplier": _clamp_float(self._cfg("web.tuning.llm_retry.backoff_multiplier", 2), 2.0, 1.0, 10.0), + } + + def _eval_query_timeout_s(self) -> float: + return _clamp_float( + self._cfg("web.tuning.eval_query_timeout_seconds", 10.0), + 10.0, + 0.01, + 120.0, + ) + + def get_runtime_settings(self) -> Dict[str, Any]: + intensity = str(self._cfg("web.tuning.default_intensity", "standard") or "standard") + if intensity not in INTENSITIES: + intensity = "standard" + objective = str(self._cfg("web.tuning.default_objective", "precision_priority") or "precision_priority") + if objective not in OBJECTIVES: + objective = "precision_priority" + return { + "enabled": self._is_enabled(), + "poll_interval_ms": _clamp_int(self._cfg("web.tuning.poll_interval_ms", 1200), 1200, 200, 60000), + "max_queue_size": self._queue_limit(), + "default_objective": objective, + "default_intensity": intensity, + "default_rounds": INTENSITIES[intensity], + "default_top_k_eval": _clamp_int(self._cfg("web.tuning.default_top_k_eval", 20), 20, 5, 100), + "default_sample_size": _clamp_int(self._cfg("web.tuning.default_sample_size", 24), 24, 4, 200), + "eval_query_timeout_seconds": self._eval_query_timeout_s(), + "llm_retry": self._llm_retry_cfg(), + } + + def _ensure_ready(self) -> None: + required = ("metadata_store", "vector_store", "graph_store", "embedding_manager") + missing = [x for x in required if getattr(self.plugin, x, None) is None] + if missing: + raise ValueError(f"调优依赖未初始化: {', '.join(missing)}") + checker = getattr(self.plugin, "is_runtime_ready", None) + if callable(checker) and not checker(): + raise ValueError("插件运行时未就绪") + provider = self._import_write_blocked_provider + if provider is not None and bool(provider()): + raise ValueError("导入任务运行中,当前禁止启动检索调优") + + def get_profile_snapshot(self) -> Dict[str, Any]: + cfg = getattr(self.plugin, "config", {}) or {} + profile = { + "retrieval": { + "top_k_paragraphs": _nested_get(cfg, "retrieval.top_k_paragraphs", 20), + "top_k_relations": _nested_get(cfg, "retrieval.top_k_relations", 10), + "top_k_final": _nested_get(cfg, "retrieval.top_k_final", 10), + "alpha": _nested_get(cfg, "retrieval.alpha", 0.5), + "enable_ppr": _nested_get(cfg, "retrieval.enable_ppr", True), + "search": {"smart_fallback": {"enabled": _nested_get(cfg, "retrieval.search.smart_fallback.enabled", True)}}, + "sparse": { + "enabled": _nested_get(cfg, "retrieval.sparse.enabled", True), + "mode": _nested_get(cfg, "retrieval.sparse.mode", "auto"), + 
"candidate_k": _nested_get(cfg, "retrieval.sparse.candidate_k", 80), + "relation_candidate_k": _nested_get(cfg, "retrieval.sparse.relation_candidate_k", 60), + }, + "fusion": { + "method": _nested_get(cfg, "retrieval.fusion.method", "weighted_rrf"), + "rrf_k": _nested_get(cfg, "retrieval.fusion.rrf_k", 60), + "vector_weight": _nested_get(cfg, "retrieval.fusion.vector_weight", 0.7), + "bm25_weight": _nested_get(cfg, "retrieval.fusion.bm25_weight", 0.3), + }, + }, + "threshold": { + "percentile": _nested_get(cfg, "threshold.percentile", 75.0), + "min_results": _nested_get(cfg, "threshold.min_results", 3), + }, + } + return self._normalize_profile(profile, fallback=profile) + + def _normalize_profile(self, profile: Optional[Dict[str, Any]], *, fallback: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + raw = copy.deepcopy(profile or {}) + base = copy.deepcopy(fallback or self.get_profile_snapshot()) + + def pick(path: str, default: Any) -> Any: + if _nested_get(raw, path, None) is not None: + return _nested_get(raw, path, default) + if path in raw: + return raw.get(path, default) + return _nested_get(base, path, default) + + fusion_method = str(pick("retrieval.fusion.method", "weighted_rrf") or "weighted_rrf").strip().lower() + if fusion_method not in {"weighted_rrf", "alpha_legacy"}: + fusion_method = "weighted_rrf" + + sparse_mode = str(pick("retrieval.sparse.mode", "auto") or "auto").strip().lower() + if sparse_mode not in {"auto", "hybrid", "fallback_only"}: + sparse_mode = "auto" + + vec_w = _clamp_float(pick("retrieval.fusion.vector_weight", 0.7), 0.7, 0.0, 1.0) + bm_w = _clamp_float(pick("retrieval.fusion.bm25_weight", 0.3), 0.3, 0.0, 1.0) + s = vec_w + bm_w + if s <= 1e-9: + vec_w, bm_w = 0.7, 0.3 + else: + vec_w, bm_w = vec_w / s, bm_w / s + + return { + "retrieval": { + "top_k_paragraphs": _clamp_int(pick("retrieval.top_k_paragraphs", 20), 20, 10, 1200), + "top_k_relations": _clamp_int(pick("retrieval.top_k_relations", 10), 10, 4, 512), + "top_k_final": _clamp_int(pick("retrieval.top_k_final", 10), 10, 4, 512), + "alpha": _clamp_float(pick("retrieval.alpha", 0.5), 0.5, 0.0, 1.0), + "enable_ppr": _coerce_bool(pick("retrieval.enable_ppr", True), True), + "search": {"smart_fallback": {"enabled": _coerce_bool(pick("retrieval.search.smart_fallback.enabled", True), True)}}, + "sparse": { + "enabled": _coerce_bool(pick("retrieval.sparse.enabled", True), True), + "mode": sparse_mode, + "candidate_k": _clamp_int(pick("retrieval.sparse.candidate_k", 80), 80, 20, 2000), + "relation_candidate_k": _clamp_int(pick("retrieval.sparse.relation_candidate_k", 60), 60, 20, 2000), + }, + "fusion": { + "method": fusion_method, + "rrf_k": _clamp_int(pick("retrieval.fusion.rrf_k", 60), 60, 1, 500), + "vector_weight": float(vec_w), + "bm25_weight": float(bm_w), + }, + }, + "threshold": { + "percentile": _clamp_float(pick("threshold.percentile", 75.0), 75.0, 1.0, 99.0), + "min_results": _clamp_int(pick("threshold.min_results", 3), 3, 1, 100), + }, + } + + def _apply_profile_to_runtime(self, normalized: Dict[str, Any]) -> None: + if not isinstance(getattr(self.plugin, "config", None), dict): + raise RuntimeError("插件 config 不可写") + for key, value in normalized.items(): + _nested_set(self.plugin.config, key, value) + plugin_cfg = getattr(self.plugin, "_plugin_config", None) + if isinstance(plugin_cfg, dict): + for key, value in normalized.items(): + _nested_set(plugin_cfg, key, value) + + async def apply_profile(self, profile: Dict[str, Any], *, reason: str = "manual") -> Dict[str, Any]: + normalized = 
self._normalize_profile(profile) + current = self.get_profile_snapshot() + self._rollback_snapshot = current + self._apply_profile_to_runtime(normalized) + return { + "applied": normalized, + "rollback_snapshot": current, + "reason": reason, + "applied_at": _now(), + } + + async def rollback_profile(self) -> Dict[str, Any]: + if not self._rollback_snapshot: + raise ValueError("暂无可回滚的参数快照") + target = self._normalize_profile(self._rollback_snapshot, fallback=self._rollback_snapshot) + self._apply_profile_to_runtime(target) + return {"rolled_back_to": target, "rolled_back_at": _now()} + + def export_toml_snippet(self, profile: Optional[Dict[str, Any]] = None) -> str: + p = self._normalize_profile(profile or self.get_profile_snapshot()) + r = p["retrieval"] + t = p["threshold"] + lines = [ + "[retrieval]", + f"top_k_paragraphs = {int(r['top_k_paragraphs'])}", + f"top_k_relations = {int(r['top_k_relations'])}", + f"top_k_final = {int(r['top_k_final'])}", + f"alpha = {float(r['alpha']):.4f}", + f"enable_ppr = {str(bool(r['enable_ppr'])).lower()}", + "", + "[retrieval.search.smart_fallback]", + f"enabled = {str(bool(r['search']['smart_fallback']['enabled'])).lower()}", + "", + "[retrieval.sparse]", + f"enabled = {str(bool(r['sparse']['enabled'])).lower()}", + f"mode = \"{r['sparse']['mode']}\"", + f"candidate_k = {int(r['sparse']['candidate_k'])}", + f"relation_candidate_k = {int(r['sparse']['relation_candidate_k'])}", + "", + "[retrieval.fusion]", + f"method = \"{r['fusion']['method']}\"", + f"rrf_k = {int(r['fusion']['rrf_k'])}", + f"vector_weight = {float(r['fusion']['vector_weight']):.4f}", + f"bm25_weight = {float(r['fusion']['bm25_weight']):.4f}", + "", + "[threshold]", + f"percentile = {float(t['percentile']):.4f}", + f"min_results = {int(t['min_results'])}", + ] + return "\n".join(lines).strip() + "\n" + + def _pending_task_count(self) -> int: + return sum(1 for t in self._tasks.values() if t.status in {"queued", "running", "cancel_requested"}) + + def _sample_triples_for_query_set( + self, + *, + triples: List[Tuple[Any, Any, Any, Any]], + sample_size: int, + seed: int, + ) -> Tuple[List[Tuple[str, str, str, str]], Dict[str, Any]]: + normalized: List[Tuple[str, str, str, str]] = [] + for row in triples: + try: + subject, predicate, obj, rel_hash = row + except Exception: + continue + relation_hash = str(rel_hash or "").strip() + if not relation_hash: + continue + normalized.append((str(subject or ""), str(predicate or ""), str(obj or ""), relation_hash)) + + if not normalized: + return [], {"error": "no_relations"} + + target = min(max(4, int(sample_size)), len(normalized)) + predicate_counter = Counter([str(x[1] or "").strip() or "__empty__" for x in normalized]) + entity_counter = Counter() + for subj, _, obj, _ in normalized: + entity_counter.update([str(subj or "").strip().lower() or "__empty__"]) + entity_counter.update([str(obj or "").strip().lower() or "__empty__"]) + + if target >= len(normalized): + return list(normalized), { + "strategy": "all", + "sample_size": int(target), + "total_triples": int(len(normalized)), + "predicate_total": int(len(predicate_counter)), + "predicate_sampled": int(len(predicate_counter)), + } + + rng = random.Random(f"{seed}:triple_sample") + by_predicate: Dict[str, List[int]] = {} + for idx, (_, predicate, _, _) in enumerate(normalized): + key = str(predicate or "").strip() or "__empty__" + by_predicate.setdefault(key, []).append(idx) + for pool in by_predicate.values(): + rng.shuffle(pool) + + predicate_order = sorted(by_predicate.keys()) + 
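Because `export_toml_snippet` emits plain TOML (booleans lowercased, floats at four decimals), its output can be verified mechanically. A hedged round-trip check, assuming Python 3.11+ for the stdlib `tomllib`; the snippet text below is a shortened stand-in for the real export:

```python
import tomllib  # stdlib since Python 3.11

snippet = """
[retrieval]
top_k_paragraphs = 20
alpha = 0.5000
enable_ppr = true

[retrieval.fusion]
method = "weighted_rrf"
vector_weight = 0.7000
bm25_weight = 0.3000

[threshold]
percentile = 75.0000
min_results = 3
"""

parsed = tomllib.loads(snippet)
assert parsed["retrieval"]["fusion"]["vector_weight"] == 0.7
assert parsed["threshold"]["min_results"] == 3
print(parsed["retrieval"]["fusion"]["method"])  # weighted_rrf
```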
rng.shuffle(predicate_order) + + selected: List[int] = [] + selected_set = set() + + # First pass: predicate round-robin to avoid head predicate dominating query set. + while len(selected) < target: + progressed = False + for key in predicate_order: + pool = by_predicate.get(key, []) + if not pool: + continue + idx = int(pool.pop()) + if idx in selected_set: + continue + selected.append(idx) + selected_set.add(idx) + progressed = True + if len(selected) >= target: + break + if not progressed: + break + + if len(selected) < target: + remain = [idx for idx in range(len(normalized)) if idx not in selected_set] + rng.shuffle(remain) + + # Second pass: prefer lower-frequency entities and predicates for better diversity. + def _remain_score(idx: int) -> Tuple[int, int]: + subj, predicate, obj, _ = normalized[idx] + subject_freq = int(entity_counter.get(str(subj or "").strip().lower() or "__empty__", 0)) + object_freq = int(entity_counter.get(str(obj or "").strip().lower() or "__empty__", 0)) + pred_freq = int(predicate_counter.get(str(predicate or "").strip() or "__empty__", 0)) + return (subject_freq + object_freq, pred_freq) + + remain = sorted(remain, key=_remain_score) + need = target - len(selected) + for idx in remain[:need]: + selected.append(idx) + selected_set.add(idx) + + selected = selected[:target] + sampled = [normalized[idx] for idx in selected] + sampled_predicates = {str(x[1] or "").strip() or "__empty__" for x in sampled} + + return sampled, { + "strategy": "predicate_round_robin_entity_diversity", + "sample_size": int(target), + "total_triples": int(len(normalized)), + "predicate_total": int(len(predicate_counter)), + "predicate_sampled": int(len(sampled_predicates)), + } + + def _select_round_eval_cases( + self, + *, + cases: List[RetrievalQueryCase], + intensity: str, + round_index: int, + seed: int, + ) -> List[RetrievalQueryCase]: + if not cases: + return [] + mode = str(intensity or "standard").strip().lower() + if mode not in INTENSITIES: + mode = "standard" + if mode == "deep": + return list(cases) + + if mode == "quick": + ratio = 0.45 + min_total = 16 + else: + ratio = 0.70 + min_total = 24 + + total = len(cases) + target = max(min_total, int(total * ratio)) + if target >= total: + return list(cases) + + rng = random.Random(f"{seed}:{round_index}:subset") + by_cat: Dict[str, List[RetrievalQueryCase]] = {} + for item in cases: + by_cat.setdefault(str(item.category), []).append(item) + + selected: List[RetrievalQueryCase] = [] + selected_ids = set() + cat_names = sorted([x for x in by_cat.keys() if x in CATEGORIES]) + if not cat_names: + cat_names = sorted(by_cat.keys()) + per_cat = max(1, target // max(1, len(cat_names))) + + for cat in cat_names: + pool = by_cat.get(cat, []) + if not pool: + continue + picked = list(pool) if len(pool) <= per_cat else rng.sample(pool, per_cat) + for item in picked: + if item.case_id in selected_ids: + continue + selected.append(item) + selected_ids.add(item.case_id) + + if len(selected) < target: + remain = [x for x in cases if x.case_id not in selected_ids] + if len(remain) > (target - len(selected)): + remain = rng.sample(remain, target - len(selected)) + for item in remain: + selected.append(item) + selected_ids.add(item.case_id) + + return selected[:target] + + async def _ensure_worker(self) -> None: + async with self._lock: + if self._worker_task and not self._worker_task.done(): + return + self._stopping = False + self._worker_task = asyncio.create_task(self._worker_loop()) + + async def shutdown(self) -> None: + self._stopping = 
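The first sampling pass above is a predicate round-robin: shuffled per-predicate pools are drained one item per predicate per sweep, so a head predicate cannot dominate the query set. A compact sketch of the same idea on toy data (names and triples are illustrative; seeded for reproducibility):

```python
import random
from typing import Dict, List, Tuple

Triple = Tuple[str, str, str]

def round_robin_sample(triples: List[Triple], target: int, seed: int) -> List[Triple]:
    rng = random.Random(seed)
    pools: Dict[str, List[Triple]] = {}
    for t in triples:
        pools.setdefault(t[1], []).append(t)  # group by predicate
    for pool in pools.values():
        rng.shuffle(pool)
    order = sorted(pools)
    rng.shuffle(order)
    picked: List[Triple] = []
    while len(picked) < target and any(pools.values()):
        for pred in order:  # one item per predicate per sweep
            if pools[pred] and len(picked) < target:
                picked.append(pools[pred].pop())
    return picked

data = [("alice", "likes", f"x{i}") for i in range(8)] + [("bob", "owns", "car")]
sample = round_robin_sample(data, target=4, seed=42)
assert {t[1] for t in sample} == {"likes", "owns"}  # the rare predicate is still represented
```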
True + worker = self._worker_task + if worker is None or worker.done(): + return + worker.cancel() + try: + await worker + except asyncio.CancelledError: + pass + except Exception as e: + logger.warning(f"Retrieval tuning worker shutdown failed: {e}") + + async def create_task(self, payload: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + if not self._is_enabled(): + raise ValueError("检索调优中心已禁用") + self._ensure_ready() + + data = payload or {} + objective = str(data.get("objective") or self._cfg("web.tuning.default_objective", "precision_priority")) + if objective not in OBJECTIVES: + raise ValueError(f"objective 非法: {objective}") + + intensity = str(data.get("intensity") or self._cfg("web.tuning.default_intensity", "standard")) + if intensity not in INTENSITIES: + raise ValueError(f"intensity 非法: {intensity}") + + rounds_total = _clamp_int(data.get("rounds", INTENSITIES[intensity]), INTENSITIES[intensity], 1, 200) + sample_size = _clamp_int(data.get("sample_size", self._cfg("web.tuning.default_sample_size", 24)), 24, 4, 500) + top_k_eval = _clamp_int(data.get("top_k_eval", self._cfg("web.tuning.default_top_k_eval", 20)), 20, 5, 100) + eval_query_timeout_seconds = _clamp_float( + data.get("eval_query_timeout_seconds", self._eval_query_timeout_s()), + self._eval_query_timeout_s(), + 0.01, + 120.0, + ) + llm_enabled = _coerce_bool(data.get("llm_enabled", True), True) + seed = data.get("seed") + try: + seed = int(seed) + except Exception: + seed = int(time.time()) % 1000003 + + async with self._lock: + if self._pending_task_count() >= self._queue_limit(): + raise ValueError("调优任务队列已满,请稍后重试") + task = RetrievalTuningTaskRecord( + task_id=uuid.uuid4().hex, + status="queued", + progress=0.0, + objective=objective, + intensity=intensity, + rounds_total=rounds_total, + params={ + "sample_size": sample_size, + "top_k_eval": top_k_eval, + "eval_query_timeout_seconds": float(eval_query_timeout_seconds), + "llm_enabled": llm_enabled, + "seed": seed, + }, + ) + self._tasks[task.task_id] = task + self._task_order.appendleft(task.task_id) + self._queue.append(task.task_id) + task.updated_at = _now() + + await self._ensure_worker() + return task.to_summary() + + async def list_tasks(self, limit: int = 50) -> List[Dict[str, Any]]: + limit = _clamp_int(limit, 50, 1, 500) + async with self._lock: + items: List[Dict[str, Any]] = [] + for task_id in list(self._task_order)[:limit]: + task = self._tasks.get(task_id) + if task: + items.append(task.to_summary()) + return items + + async def get_task(self, task_id: str, include_rounds: bool = False) -> Optional[Dict[str, Any]]: + async with self._lock: + task = self._tasks.get(task_id) + if not task: + return None + return task.to_detail(include_rounds=include_rounds) + + async def get_rounds(self, task_id: str, offset: int = 0, limit: int = 50) -> Optional[Dict[str, Any]]: + offset = max(0, int(offset)) + limit = _clamp_int(limit, 50, 1, 500) + async with self._lock: + task = self._tasks.get(task_id) + if not task: + return None + total = len(task.rounds) + sliced = task.rounds[offset : offset + limit] + return { + "total": total, + "offset": offset, + "limit": limit, + "items": [item.to_dict() for item in sliced], + } + + async def cancel_task(self, task_id: str) -> Optional[Dict[str, Any]]: + async with self._lock: + task = self._tasks.get(task_id) + if not task: + return None + if task.status in {"completed", "failed", "cancelled"}: + return task.to_summary() + if task.status == "queued": + task.status = "cancelled" + task.cancel_requested = True + 
task.finished_at = _now() + task.updated_at = task.finished_at + self._queue = deque([x for x in self._queue if x != task_id]) + return task.to_summary() + task.status = "cancel_requested" + task.cancel_requested = True + task.updated_at = _now() + return task.to_summary() + + async def apply_best(self, task_id: str) -> Dict[str, Any]: + async with self._lock: + task = self._tasks.get(task_id) + if task is None: + raise ValueError("任务不存在") + if task.status != "completed": + raise ValueError("任务未完成,无法应用最优参数") + if not task.best_profile: + raise ValueError("任务没有可应用的最优参数") + best = copy.deepcopy(task.best_profile) + applied = await self.apply_profile(best, reason=f"task:{task_id}:apply_best") + async with self._lock: + task = self._tasks.get(task_id) + if task is not None: + task.apply_log.append({"applied_at": _now(), "reason": "apply_best", "profile": best}) + task.updated_at = _now() + return applied + + async def get_report(self, task_id: str, fmt: str = "md") -> Optional[Dict[str, Any]]: + async with self._lock: + task = self._tasks.get(task_id) + if task is None: + return None + artifacts = dict(task.artifact_paths) + fmt = str(fmt or "md").strip().lower() + if fmt not in {"md", "json"}: + fmt = "md" + path_key = "report_md" if fmt == "md" else "report_json" + path = artifacts.get(path_key) + if not path: + return {"format": fmt, "content": "", "path": ""} + p = Path(path) + if not p.exists(): + return {"format": fmt, "content": "", "path": str(p)} + try: + content = p.read_text(encoding="utf-8") + except Exception: + content = "" + return {"format": fmt, "content": content, "path": str(p)} + + async def _worker_loop(self) -> None: + while not self._stopping: + task_id: Optional[str] = None + async with self._lock: + while self._queue: + candidate = self._queue.popleft() + task = self._tasks.get(candidate) + if task is None: + continue + if task.status != "queued": + continue + task_id = candidate + self._active_task_id = candidate + break + + if not task_id: + await asyncio.sleep(self._poll_interval_s()) + continue + + try: + await self._run_task(task_id) + except asyncio.CancelledError: + raise + except Exception as e: + logger.error(f"Retrieval tuning task crashed: task_id={task_id}, err={e}") + async with self._lock: + task = self._tasks.get(task_id) + if task is not None: + task.status = "failed" + task.error = str(e) + task.finished_at = _now() + task.updated_at = task.finished_at + finally: + async with self._lock: + if self._active_task_id == task_id: + self._active_task_id = None + + async def _run_task(self, task_id: str) -> None: + async with self._lock: + task = self._tasks.get(task_id) + if task is None: + return + task.status = "running" + task.started_at = _now() + task.updated_at = task.started_at + + artifacts_dir = self._artifacts_root / task_id + artifacts_dir.mkdir(parents=True, exist_ok=True) + query_set_path = artifacts_dir / "query_set.json" + rounds_path = artifacts_dir / "round_metrics.jsonl" + best_profile_path = artifacts_dir / "best_profile.json" + report_json_path = artifacts_dir / "report.json" + report_md_path = artifacts_dir / "report.md" + + try: + params = dict(task.params) + cases, stats = await self._build_query_set( + sample_size=int(params["sample_size"]), + seed=int(params["seed"]), + llm_enabled=bool(params.get("llm_enabled", True)), + ) + if not cases: + raise ValueError("当前知识库样本不足,无法构建调优测试集") + + query_set_path.write_text( + json.dumps( + { + "task_id": task_id, + "created_at": _now(), + "stats": stats, + "items": [c.to_dict() for c in cases], + 
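The worker machinery above follows one pattern throughout: tasks are queued into a deque under an asyncio lock, a single worker task is lazily (re)created, and only one tuning task runs at a time. A stripped-down sketch of that lock + deque + lazy-worker shape (simplified: the real loop polls with a sleep instead of draining and exiting):

```python
import asyncio
from collections import deque

class SingleWorkerQueue:
    """Minimal sketch of the lock + deque + lazily-spawned-worker pattern."""

    def __init__(self) -> None:
        self._lock = asyncio.Lock()
        self._queue: deque = deque()
        self._worker: asyncio.Task = None

    async def submit(self, task_id: str) -> None:
        async with self._lock:
            self._queue.append(task_id)
            if self._worker is None or self._worker.done():
                self._worker = asyncio.create_task(self._loop())

    async def _loop(self) -> None:
        while True:
            async with self._lock:
                if not self._queue:
                    return  # drain-and-exit; the real loop sleeps and keeps polling
                task_id = self._queue.popleft()
            await asyncio.sleep(0.01)  # stand-in for running the actual task
            print(f"done {task_id}")

async def main() -> None:
    q = SingleWorkerQueue()
    for i in range(3):
        await q.submit(f"t{i}")
    await asyncio.sleep(0.1)

asyncio.run(main())
```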
}, + ensure_ascii=False, + indent=2, + ), + encoding="utf-8", + ) + + baseline_profile = self.get_profile_snapshot() + top_k_eval = int(params["top_k_eval"]) + baseline_eval = await self._evaluate_profile( + profile=baseline_profile, + cases=cases, + objective=task.objective, + top_k_eval=top_k_eval, + query_timeout_s=float(params.get("eval_query_timeout_seconds") or self._eval_query_timeout_s()), + ) + baseline_round = RetrievalTuningRoundRecord( + round_index=0, + candidate_profile=baseline_profile, + metrics=baseline_eval["metrics"], + score=float(baseline_eval["score"]), + latency_ms=float(baseline_eval["avg_elapsed_ms"]), + failure_summary=baseline_eval["failure_summary"], + ) + rounds_path.write_text(json.dumps(baseline_round.to_dict(), ensure_ascii=False) + "\n", encoding="utf-8") + + async with self._lock: + task = self._tasks.get(task_id) + if task is None: + return + task.query_set_stats = stats + task.baseline_profile = copy.deepcopy(baseline_profile) + task.baseline_metrics = copy.deepcopy(baseline_eval["metrics"]) + task.rounds.append(baseline_round) + task.best_profile = copy.deepcopy(baseline_profile) + task.best_metrics = copy.deepcopy(baseline_eval["metrics"]) + task.best_score = float(baseline_eval["score"]) + task.progress = 0.0 + task.updated_at = _now() + + best_profile = copy.deepcopy(baseline_profile) + best_metrics = copy.deepcopy(baseline_eval["metrics"]) + best_failure_summary = copy.deepcopy(baseline_eval["failure_summary"]) + best_score = float(baseline_eval["score"]) + llm_suggestions: List[Dict[str, Any]] = [] + task_cancelled = False + + for round_idx in range(1, int(task.rounds_total) + 1): + async with self._lock: + task = self._tasks.get(task_id) + if task is None: + return + if task.cancel_requested or task.status == "cancel_requested": + task.status = "cancelled" + task.finished_at = _now() + task.updated_at = task.finished_at + task_cancelled = True + break + + if round_idx == 1 or (round_idx % 5 == 0 and not llm_suggestions): + llm_suggestions = await self._suggest_profiles_with_llm( + base_profile=best_profile, + failure_summary=best_failure_summary, + objective=task.objective, + max_count=3, + enabled=bool(params.get("llm_enabled", True)), + ) + + candidate_profile = self._generate_candidate_profile( + task_id=task_id, + round_index=round_idx, + objective=task.objective, + baseline_profile=baseline_profile, + best_profile=best_profile, + llm_suggestions=llm_suggestions, + ) + eval_cases = self._select_round_eval_cases( + cases=cases, + intensity=task.intensity, + round_index=round_idx, + seed=int(params.get("seed", 0)), + ) + eval_result = await self._evaluate_profile( + profile=candidate_profile, + cases=eval_cases, + objective=task.objective, + top_k_eval=top_k_eval, + query_timeout_s=float(params.get("eval_query_timeout_seconds") or self._eval_query_timeout_s()), + ) + round_record = RetrievalTuningRoundRecord( + round_index=round_idx, + candidate_profile=candidate_profile, + metrics=eval_result["metrics"], + score=float(eval_result["score"]), + latency_ms=float(eval_result["avg_elapsed_ms"]), + failure_summary=eval_result["failure_summary"], + ) + with rounds_path.open("a", encoding="utf-8") as fp: + fp.write(json.dumps(round_record.to_dict(), ensure_ascii=False) + "\n") + + if float(eval_result["score"]) > float(best_score): + best_score = float(eval_result["score"]) + best_profile = copy.deepcopy(candidate_profile) + best_metrics = copy.deepcopy(eval_result["metrics"]) + best_failure_summary = copy.deepcopy(eval_result["failure_summary"]) + + 
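The round loop above is accept-if-better search: each round proposes a candidate (an LLM suggestion when one is queued, otherwise a seeded random mutation), evaluates it, and keeps it only when its score beats the incumbent, with a periodic restart from the baseline. A stripped-down sketch of that loop shape, where the objective is a toy function standing in for the real evaluator:

```python
import copy
import random

def tune(baseline: dict, rounds: int, seed: int) -> tuple:
    def score(p: dict) -> float:  # toy objective standing in for the evaluation call
        return -abs(p["alpha"] - 0.62) - abs(p["rrf_k"] - 60) / 100.0

    rng = random.Random(seed)
    best, best_score = copy.deepcopy(baseline), score(baseline)
    for round_idx in range(1, rounds + 1):
        # Every 4th round mutates the baseline instead of the incumbent,
        # mirroring the `round_index % 4 == 1` restart heuristic above.
        base = baseline if round_idx % 4 == 1 else best
        cand = copy.deepcopy(base)
        cand["alpha"] = rng.choice([0.35, 0.50, 0.62, 0.72])
        cand["rrf_k"] = rng.choice([30, 45, 60, 75, 90])
        if (s := score(cand)) > best_score:
            best, best_score = cand, s  # accept only strict improvements
    return best, best_score

best, s = tune({"alpha": 0.5, "rrf_k": 60}, rounds=20, seed=7)
print(best, round(s, 4))
```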
async with self._lock: + task = self._tasks.get(task_id) + if task is None: + return + task.rounds_done = round_idx + task.rounds.append(round_record) + task.best_profile = copy.deepcopy(best_profile) + task.best_metrics = copy.deepcopy(best_metrics) + task.best_score = float(best_score) + task.progress = min(1.0, float(round_idx) / float(task.rounds_total)) + task.updated_at = _now() + + if best_profile and (not task_cancelled): + # 候选轮可能基于子样本评估,收官时用全量样本复核,确保最终指标可解释。 + best_full = await self._evaluate_profile( + profile=best_profile, + cases=cases, + objective=task.objective, + top_k_eval=top_k_eval, + query_timeout_s=float(params.get("eval_query_timeout_seconds") or self._eval_query_timeout_s()), + ) + best_profile = copy.deepcopy(best_profile) + best_metrics = copy.deepcopy(best_full["metrics"]) + best_failure_summary = copy.deepcopy(best_full["failure_summary"]) + best_score = float(best_full["score"]) + if best_score < float(baseline_eval["score"]): + best_profile = copy.deepcopy(baseline_profile) + best_metrics = copy.deepcopy(baseline_eval["metrics"]) + best_failure_summary = copy.deepcopy(baseline_eval["failure_summary"]) + best_score = float(baseline_eval["score"]) + + async with self._lock: + task = self._tasks.get(task_id) + if task is not None: + task.best_profile = copy.deepcopy(best_profile) + task.best_metrics = copy.deepcopy(best_metrics) + task.best_score = float(best_score) + task.updated_at = _now() + + async with self._lock: + task = self._tasks.get(task_id) + if task is None: + return + if task.status not in {"cancelled", "failed"}: + task.status = "completed" + task.progress = 1.0 + task.finished_at = _now() + task.updated_at = task.finished_at + final_task = copy.deepcopy(task) + + if final_task.status == "completed": + best_profile_path.write_text(json.dumps(final_task.best_profile, ensure_ascii=False, indent=2), encoding="utf-8") + report_payload = self._build_report_payload(final_task) + report_json_path.write_text(json.dumps(report_payload, ensure_ascii=False, indent=2), encoding="utf-8") + report_md_path.write_text(self._build_report_markdown(final_task, report_payload), encoding="utf-8") + + async with self._lock: + task = self._tasks.get(task_id) + if task is not None: + task.artifact_paths = { + "query_set": str(query_set_path), + "round_metrics_jsonl": str(rounds_path), + "best_profile": str(best_profile_path), + "report_json": str(report_json_path), + "report_md": str(report_md_path), + } + task.updated_at = _now() + except Exception as e: + logger.error(f"Retrieval tuning task failed: task_id={task_id}, err={e}") + async with self._lock: + task = self._tasks.get(task_id) + if task is not None: + task.status = "failed" + task.error = str(e) + task.finished_at = _now() + task.updated_at = task.finished_at + + async def _build_query_set(self, *, sample_size: int, seed: int, llm_enabled: bool) -> Tuple[List[RetrievalQueryCase], Dict[str, Any]]: + store = getattr(self.plugin, "metadata_store", None) + if store is None: + return [], {"error": "metadata_store_unavailable"} + + triples = list(store.get_all_triples() or []) + if not triples: + return [], {"error": "no_relations"} + + sampled, sample_info = self._sample_triples_for_query_set( + triples=triples, + sample_size=sample_size, + seed=seed, + ) + if not sampled: + return [], {"error": "no_relations"} + + anchors: List[Dict[str, Any]] = [] + for idx, row in enumerate(sampled): + subject, predicate, obj, relation_hash = row + paragraphs = store.get_paragraphs_by_relation(relation_hash) + para_hash = "" + 
para_content = "" + if paragraphs: + para_hash = str(paragraphs[0].get("hash") or "").strip() + para_content = str(paragraphs[0].get("content") or "") + anchors.append( + { + "anchor_id": f"a{idx+1:04d}", + "subject": str(subject or ""), + "predicate": str(predicate or ""), + "object": str(obj or ""), + "relation_hash": relation_hash, + "paragraph_hash": para_hash, + "paragraph_excerpt": para_content[:300], + } + ) + + if not anchors: + return [], {"error": "no_anchors"} + + predicate_groups: Dict[str, List[Dict[str, Any]]] = {} + for anchor in anchors: + predicate_groups.setdefault(str(anchor.get("predicate") or ""), []).append(anchor) + + nl_queries = await self._generate_nl_queries_with_llm(anchors, enabled=llm_enabled) + cases: List[RetrievalQueryCase] = [] + + seq = 0 + for anchor in anchors: + seq += 1 + subject = anchor["subject"] + predicate = anchor["predicate"] + obj = anchor["object"] + rel_hash = anchor["relation_hash"] + para_hash = anchor["paragraph_hash"] + expected = [rel_hash] + if para_hash: + expected.append(para_hash) + aid = anchor["anchor_id"] + + common_meta = { + "anchor_id": aid, + "relation_hash": rel_hash, + "paragraph_hash": para_hash, + "subject": subject, + "predicate": predicate, + "object": obj, + } + cases.append( + RetrievalQueryCase( + case_id=f"spo_relation_{seq:04d}", + category="spo_relation", + query=f"{subject}|{predicate}|{obj}", + expected_hashes=[rel_hash], + expected_spo={"subject": subject, "predicate": predicate, "object": obj}, + metadata=dict(common_meta), + ) + ) + cases.append( + RetrievalQueryCase( + case_id=f"spo_search_{seq:04d}", + category="spo_search", + query=self._build_spo_search_query( + anchor=anchor, + seq=seq, + predicate_groups=predicate_groups, + ), + expected_hashes=list(expected), + metadata=dict(common_meta), + ) + ) + cases.append( + RetrievalQueryCase( + case_id=f"query_kw_{seq:04d}", + category="query_kw", + query=self._build_keyword_query( + anchor=anchor, + seq=seq, + predicate_groups=predicate_groups, + ), + expected_hashes=list(expected), + metadata=dict(common_meta), + ) + ) + nl_query = nl_queries.get(aid) or self._build_nl_template( + anchor=anchor, + seq=seq, + predicate_groups=predicate_groups, + ) + cases.append( + RetrievalQueryCase( + case_id=f"query_nl_{seq:04d}", + category="query_nl", + query=nl_query, + expected_hashes=list(expected), + metadata=dict(common_meta), + ) + ) + + counts = Counter([c.category for c in cases]) + stats = { + "anchors": len(anchors), + "case_total": len(cases), + "category_counts": {k: int(v) for k, v in counts.items()}, + "seed": int(seed), + "sample_size": int(sample_info.get("sample_size", len(anchors))), + "sampling": dict(sample_info), + "llm_nl_enabled": bool(llm_enabled), + "llm_nl_generated": int(len(nl_queries)), + } + return cases, stats + + def _pick_contrast_anchor( + self, + *, + anchor: Dict[str, Any], + predicate_groups: Dict[str, List[Dict[str, Any]]], + seq: int, + ) -> Optional[Dict[str, Any]]: + predicate = str(anchor.get("predicate") or "") + pool = predicate_groups.get(predicate, []) + if not pool: + return None + candidates = [x for x in pool if x is not anchor and str(x.get("object") or "") != str(anchor.get("object") or "")] + if not candidates: + return None + return candidates[seq % len(candidates)] + + def _build_spo_search_query( + self, + *, + anchor: Dict[str, Any], + seq: int, + predicate_groups: Dict[str, List[Dict[str, Any]]], + ) -> str: + subject = str(anchor.get("subject") or "") + predicate = str(anchor.get("predicate") or "") + obj = 
str(anchor.get("object") or "") + contrast = self._pick_contrast_anchor(anchor=anchor, predicate_groups=predicate_groups, seq=seq) + contrast_obj = str(contrast.get("object") or "").strip() if contrast else "" + + variants = [ + f"{subject} {predicate} {obj}", + f"{subject} {obj} relation {predicate}", + f"{predicate} {subject} {obj} evidence", + f"{subject} {predicate} {obj} not {contrast_obj}".strip(), + ] + return variants[seq % len(variants)].strip() + + def _build_keyword_query( + self, + *, + anchor: Dict[str, Any], + seq: int, + predicate_groups: Dict[str, List[Dict[str, Any]]], + ) -> str: + subject = str(anchor.get("subject") or "") + predicate = str(anchor.get("predicate") or "") + obj = str(anchor.get("object") or "") + excerpt = str(anchor.get("paragraph_excerpt") or "") + tokens = re.findall(r"[A-Za-z0-9_\u4e00-\u9fff]{2,}", excerpt) + extras: List[str] = [] + seen = set() + for token in tokens: + key = token.lower() + if key in seen: + continue + if key in {subject.lower(), predicate.lower(), obj.lower()}: + continue + seen.add(key) + extras.append(token) + if len(extras) >= 2: + break + contrast = self._pick_contrast_anchor(anchor=anchor, predicate_groups=predicate_groups, seq=seq) + contrast_obj = str(contrast.get("object") or "").strip() if contrast else "" + + variants = [ + [subject, obj] + extras[:2], + [predicate, obj] + extras[:2], + [subject, predicate] + extras[:2], + [subject, obj, predicate, contrast_obj] + extras[:1], + ] + parts = variants[seq % len(variants)] + return " ".join([x for x in parts if x]).strip() + + def _build_nl_template( + self, + *, + anchor: Dict[str, Any], + seq: int, + predicate_groups: Dict[str, List[Dict[str, Any]]], + ) -> str: + subject = str(anchor.get("subject") or "") + predicate = str(anchor.get("predicate") or "") + obj = str(anchor.get("object") or "") + contrast = self._pick_contrast_anchor(anchor=anchor, predicate_groups=predicate_groups, seq=seq) + contrast_obj = str(contrast.get("object") or "").strip() if contrast else "" + templates = [ + f"请问 {subject} 与 {obj} 的关系是什么,是否是“{predicate}”?", + f"在当前知识库中,哪条信息说明 {subject} 对应的是 {obj},关系词接近“{predicate}”?", + f"我想确认:{subject} 和 {obj} 之间是不是“{predicate}”这层关系,而不是 {contrast_obj}?", + f"帮我查一下关于 {subject} 与 {obj} 的证据,重点看 {predicate} 相关描述。", + ] + return templates[seq % len(templates)] + + async def _select_llm_model(self) -> Optional[Any]: + if llm_api is None: + return None + try: + models = llm_api.get_available_models() or {} + except Exception: + return None + if not models: + return None + + cfg_model = str(self._cfg("advanced.extraction_model", "auto") or "auto").strip() + if cfg_model.lower() != "auto" and cfg_model in models: + return models[cfg_model] + for task_name in ["utils", "planner", "tool_use", "replyer", "embedding"]: + if task_name in models: + return models[task_name] + return models[next(iter(models))] + + async def _llm_call_text(self, prompt: str, *, request_type: str) -> str: + if llm_api is None: + raise RuntimeError("llm_api unavailable") + model_cfg = await self._select_llm_model() + if model_cfg is None: + raise RuntimeError("no_llm_model") + + retry = self._llm_retry_cfg() + max_attempts = int(retry["max_attempts"]) + min_wait = float(retry["min_wait_seconds"]) + max_wait = float(retry["max_wait_seconds"]) + backoff = float(retry["backoff_multiplier"]) + + last_error: Optional[Exception] = None + for idx in range(max_attempts): + try: + success, response, _, _ = await llm_api.generate_with_model( + prompt=prompt, + model_config=model_cfg, + 
request_type=request_type, + ) + if not success: + raise RuntimeError("llm_generation_failed") + text = str(response or "").strip() + if text: + return text + raise RuntimeError("empty_llm_response") + except Exception as e: + last_error = e + if idx >= max_attempts - 1: + break + delay = min(max_wait, min_wait * (backoff ** idx)) + await asyncio.sleep(max(0.05, delay)) + raise RuntimeError(f"LLM call failed: {last_error}") + + async def _generate_nl_queries_with_llm(self, anchors: List[Dict[str, Any]], *, enabled: bool) -> Dict[str, str]: + if not enabled or llm_api is None or not anchors: + return {} + payload = [ + { + "anchor_id": x["anchor_id"], + "subject": x["subject"], + "predicate": x["predicate"], + "object": x["object"], + "paragraph_excerpt": x["paragraph_excerpt"][:180], + } + for x in anchors[:60] + ] + prompt = ( + "你是检索评估问题生成器。" + "请基于给定 SPO 与简短上下文,为每条样本生成 1 条自然语言检索问题,返回 JSON:" + "{\"items\":[{\"anchor_id\":\"...\",\"query\":\"...\"}]}。\n" + f"样本:\n{json.dumps(payload, ensure_ascii=False)}" + ) + try: + raw = await self._llm_call_text(prompt, request_type="A_Memorix.RetrievalTuning.NLCaseGen") + obj = _safe_json_loads(raw) + if not isinstance(obj, dict): + return {} + items = obj.get("items") + if not isinstance(items, list): + return {} + out: Dict[str, str] = {} + for row in items: + if not isinstance(row, dict): + continue + anchor_id = str(row.get("anchor_id") or "").strip() + query = str(row.get("query") or "").strip() + if anchor_id and query: + out[anchor_id] = query + return out + except Exception: + return {} + + async def _suggest_profiles_with_llm( + self, + *, + base_profile: Dict[str, Any], + failure_summary: Dict[str, Any], + objective: str, + max_count: int, + enabled: bool, + ) -> List[Dict[str, Any]]: + if not enabled or llm_api is None or max_count <= 0: + return [] + prompt = ( + "你是检索调参专家。" + "请基于基础参数与失败摘要,给出最多 " + f"{int(max_count)} 组候选参数,返回 JSON: {{\"profiles\": [ ... 
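`_llm_call_text` retries with capped exponential backoff: attempt *i* waits `min(max_wait, min_wait * multiplier**i)` seconds, and the last error is re-raised once attempts are exhausted. A self-contained sketch of the same schedule, with the flaky LLM call simulated:

```python
import asyncio

async def call_with_backoff(
    attempts: int = 3,
    min_wait: float = 2.0,
    max_wait: float = 20.0,
    multiplier: float = 2.0,
) -> str:
    last_error = None
    for i in range(attempts):
        try:
            return await flaky_call()  # simulated call, defined below
        except Exception as e:
            last_error = e
            if i >= attempts - 1:
                break
            # Capped exponential backoff between attempts.
            await asyncio.sleep(min(max_wait, min_wait * multiplier**i))
    raise RuntimeError(f"LLM call failed: {last_error}")

_calls = 0

async def flaky_call() -> str:
    global _calls
    _calls += 1
    if _calls < 3:
        raise RuntimeError("transient")
    return "ok"

print(asyncio.run(call_with_backoff(min_wait=0.01, max_wait=0.05)))  # ok on the 3rd attempt
```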
]}}。\n" + "字段仅可包含:retrieval.top_k_paragraphs, retrieval.top_k_relations, retrieval.top_k_final, " + "retrieval.alpha, retrieval.enable_ppr, retrieval.search.smart_fallback.enabled, " + "retrieval.sparse.enabled, retrieval.sparse.mode, retrieval.sparse.candidate_k, retrieval.sparse.relation_candidate_k, " + "retrieval.fusion.method, retrieval.fusion.rrf_k, retrieval.fusion.vector_weight, retrieval.fusion.bm25_weight, " + "threshold.percentile, threshold.min_results。\n" + f"objective={objective}\n" + f"base={json.dumps(base_profile, ensure_ascii=False)}\n" + f"failure_summary={json.dumps(failure_summary, ensure_ascii=False)}" + ) + try: + raw = await self._llm_call_text(prompt, request_type="A_Memorix.RetrievalTuning.ProfileSuggest") + obj = _safe_json_loads(raw) + if not isinstance(obj, dict): + return [] + profiles = obj.get("profiles") + if not isinstance(profiles, list): + return [] + out = [] + for item in profiles[:max_count]: + if isinstance(item, dict): + out.append(self._normalize_profile(item, fallback=base_profile)) + return out + except Exception: + return [] + + def _generate_candidate_profile( + self, + *, + task_id: str, + round_index: int, + objective: str, + baseline_profile: Dict[str, Any], + best_profile: Dict[str, Any], + llm_suggestions: List[Dict[str, Any]], + ) -> Dict[str, Any]: + if llm_suggestions: + return self._normalize_profile(llm_suggestions.pop(0), fallback=best_profile) + + rng = random.Random(f"{task_id}:{round_index}") + base = baseline_profile if round_index % 4 == 1 else best_profile + candidate = copy.deepcopy(base) + + if objective == "precision_priority": + para_choices = [40, 80, 120, 180, 240, 320] + rel_choices = [4, 8, 12, 16, 24] + final_choices = [4, 8, 12, 16, 20, 32, 48, 64] + alpha_choices = [0.0, 0.35, 0.50, 0.62, 0.72, 0.82, 0.90] + pct_choices = [55, 60, 65, 72, 80] + min_results_choices = [1, 2] + elif objective == "recall_priority": + para_choices = [120, 220, 300, 420, 560, 720] + rel_choices = [8, 12, 16, 24, 32] + final_choices = [8, 16, 32, 48, 64, 96, 128] + alpha_choices = [0.20, 0.35, 0.45, 0.55, 0.65, 0.75] + pct_choices = [40, 48, 55, 62] + min_results_choices = [1, 2, 3] + else: + para_choices = [80, 160, 240, 320, 420, 520] + rel_choices = [6, 10, 14, 18, 24, 30] + final_choices = [6, 12, 20, 32, 48, 64, 80] + alpha_choices = [0.25, 0.45, 0.55, 0.65, 0.75, 0.85] + pct_choices = [48, 55, 62, 70] + min_results_choices = [1, 2, 3] + + _nested_set(candidate, "retrieval.top_k_paragraphs", rng.choice(para_choices)) + _nested_set(candidate, "retrieval.top_k_relations", rng.choice(rel_choices)) + _nested_set(candidate, "retrieval.top_k_final", rng.choice(final_choices)) + _nested_set(candidate, "retrieval.alpha", rng.choice(alpha_choices)) + # PPR 在 TestClient/异步评估场景下存在偶发长时阻塞风险,调优评估链路固定关闭。 + _nested_set(candidate, "retrieval.enable_ppr", False) + _nested_set(candidate, "retrieval.search.smart_fallback.enabled", bool(rng.choice([True, True, False]))) + _nested_set(candidate, "retrieval.sparse.enabled", bool(rng.choice([True, True, False]))) + _nested_set(candidate, "retrieval.sparse.mode", rng.choice(["auto", "hybrid", "fallback_only"])) + _nested_set(candidate, "retrieval.sparse.candidate_k", rng.choice([60, 80, 120, 160, 220, 320])) + _nested_set(candidate, "retrieval.sparse.relation_candidate_k", rng.choice([40, 60, 90, 120, 180, 260])) + _nested_set(candidate, "retrieval.fusion.method", rng.choice(["weighted_rrf", "weighted_rrf", "alpha_legacy"])) + _nested_set(candidate, "retrieval.fusion.rrf_k", rng.choice([30, 45, 60, 75, 90])) + 
vec_w = float(rng.choice([0.55, 0.65, 0.72, 0.80, 0.88])) + _nested_set(candidate, "retrieval.fusion.vector_weight", vec_w) + _nested_set(candidate, "retrieval.fusion.bm25_weight", 1.0 - vec_w) + _nested_set(candidate, "threshold.percentile", rng.choice(pct_choices)) + _nested_set(candidate, "threshold.min_results", rng.choice(min_results_choices)) + + return self._normalize_profile(candidate, fallback=base) + + def _build_runtime_config(self, normalized_profile: Dict[str, Any]) -> Dict[str, Any]: + raw_base = getattr(self.plugin, "config", {}) or {} + if isinstance(raw_base, dict): + base = { + key: value + for key, value in raw_base.items() + if key not in _RUNTIME_CONFIG_INSTANCE_KEYS + } + else: + base = {} + merged = _deep_merge(base, normalized_profile) + # 调优评估场景优先稳定性,避免并发访问共享 SQLite/Faiss 导致长时阻塞。 + _nested_set(merged, "retrieval.enable_parallel", False) + # 调优评估阶段关闭 PPR,规避 PageRank 线程计算偶发阻塞导致整轮卡死。 + _nested_set(merged, "retrieval.enable_ppr", False) + merged["vector_store"] = getattr(self.plugin, "vector_store", None) + merged["graph_store"] = getattr(self.plugin, "graph_store", None) + merged["metadata_store"] = getattr(self.plugin, "metadata_store", None) + merged["embedding_manager"] = getattr(self.plugin, "embedding_manager", None) + merged["sparse_index"] = getattr(self.plugin, "sparse_index", None) + merged["plugin_instance"] = self.plugin + return merged + + async def _evaluate_profile( + self, + *, + profile: Dict[str, Any], + cases: List[RetrievalQueryCase], + objective: str, + top_k_eval: int, + query_timeout_s: float, + ) -> Dict[str, Any]: + normalized = self._normalize_profile(profile) + eval_top_k = _clamp_int(top_k_eval, 20, 1, 1000) + # 评估时让 top_k_final 参与有效召回深度,避免该参数对评分无影响。 + request_top_k = min( + int(eval_top_k), + _clamp_int(_nested_get(normalized, "retrieval.top_k_final", eval_top_k), eval_top_k, 1, 512), + ) + eval_timeout_s = _clamp_float( + query_timeout_s, + self._eval_query_timeout_s(), + 0.01, + 120.0, + ) + runtime_cfg = self._build_runtime_config(normalized) + runtime = build_search_runtime( + plugin_config=runtime_cfg, + logger_obj=logger, + owner_tag="retrieval_tuning", + log_prefix="[RetrievalTuning]", + ) + if not runtime.ready: + metrics = { + "total_text_cases": 0, + "precision_at_1": 0.0, + "precision_at_3": 0.0, + "mrr": 0.0, + "recall_at_k": 0.0, + "spo_relation_hit_rate": 0.0, + "empty_rate": 1.0, + "avg_elapsed_ms": 0.0, + "category": {}, + "error": runtime.error or "runtime_not_ready", + } + return {"metrics": metrics, "score": -1.0, "avg_elapsed_ms": 0.0, "failure_summary": {"reason": metrics["error"]}} + + text_total = 0 + hit1 = 0 + hit3 = 0 + hitk = 0 + mrr_sum = 0.0 + empty_count = 0 + timeout_count = 0 + elapsed_total = 0.0 + text_failed: List[str] = [] + + spo_total = 0 + spo_hit = 0 + spo_failed: List[str] = [] + + category_stats: Dict[str, Dict[str, Any]] = {} + failed_predicates = Counter() + + for case in cases: + cat = str(case.category) + if cat not in CATEGORIES: + continue + if cat not in category_stats: + category_stats[cat] = { + "total": 0, + "hit": 0, + "hit_at_1": 0, + "hit_at_3": 0, + "empty": 0, + } + category_stats[cat]["total"] += 1 + + if cat == "spo_relation": + spo_total += 1 + spo = case.expected_spo or {} + rows = runtime.metadata_store.get_relations( + subject=str(spo.get("subject") or ""), + predicate=str(spo.get("predicate") or ""), + object=str(spo.get("object") or ""), + ) + expected_hash = str(case.expected_hashes[0]) if case.expected_hashes else "" + ok = False + for row in rows: + if not isinstance(row, 
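`_build_runtime_config` deliberately drops the injected runtime instances (`_RUNTIME_CONFIG_INSTANCE_KEYS`) before merging, because deep-copying an object that holds a lock raises `TypeError: cannot pickle '_thread.RLock' object`; only plain config fields are copied, and the live stores are re-attached by reference afterwards. A sketch of that copy-then-reattach pattern, with a fake store standing in for the real ones:

```python
import copy
import threading

INSTANCE_KEYS = {"vector_store", "metadata_store"}  # mirrors the instance-key set above

class FakeStore:
    def __init__(self) -> None:
        self._lock = threading.RLock()  # deep-copying this would raise the pickle error

plugin_config = {
    "retrieval": {"alpha": 0.5},
    "vector_store": FakeStore(),
    "metadata_store": FakeStore(),
}

# Copy only plain config fields, then re-attach live instances by reference.
pure = {k: v for k, v in plugin_config.items() if k not in INSTANCE_KEYS}
runtime_cfg = copy.deepcopy(pure)
runtime_cfg["vector_store"] = plugin_config["vector_store"]
runtime_cfg["metadata_store"] = plugin_config["metadata_store"]

assert runtime_cfg["vector_store"] is plugin_config["vector_store"]  # shared, not copied
assert runtime_cfg["retrieval"] is not plugin_config["retrieval"]    # config is isolated
```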
dict): + continue + if expected_hash and str(row.get("hash") or "") == expected_hash: + ok = True + break + if not expected_hash: + ok = True + break + if ok: + spo_hit += 1 + category_stats[cat]["hit"] += 1 + category_stats[cat]["hit_at_1"] += 1 + category_stats[cat]["hit_at_3"] += 1 + else: + spo_failed.append(case.case_id) + failed_predicates.update([str(spo.get("predicate") or "").strip() or "__empty__"]) + continue + + text_total += 1 + req = SearchExecutionRequest( + caller="retrieval_tuning", + query_type="search", + query=str(case.query or "").strip(), + top_k=int(request_top_k), + use_threshold=True, + # 调优评估固定关闭 PPR,避免该链路阻塞拖挂整轮任务。 + enable_ppr=False, + ) + try: + execution = await asyncio.wait_for( + SearchExecutionService.execute( + retriever=runtime.retriever, + threshold_filter=runtime.threshold_filter, + plugin_config=runtime_cfg, + request=req, + enforce_chat_filter=False, + reinforce_access=False, + ), + timeout=float(eval_timeout_s), + ) + except asyncio.TimeoutError: + timeout_count += 1 + empty_count += 1 + category_stats[cat]["empty"] += 1 + text_failed.append(case.case_id) + failed_predicates.update([str(case.metadata.get("predicate") or "__unknown__")]) + continue + + if execution is None: + empty_count += 1 + category_stats[cat]["empty"] += 1 + text_failed.append(case.case_id) + failed_predicates.update([str(case.metadata.get("predicate") or "__unknown__")]) + continue + + elapsed_total += float(getattr(execution, "elapsed_ms", 0.0) or 0.0) + + if not bool(getattr(execution, "success", False)): + empty_count += 1 + category_stats[cat]["empty"] += 1 + text_failed.append(case.case_id) + failed_predicates.update([str(case.metadata.get("predicate") or "__unknown__")]) + continue + + hashes = [str(getattr(x, "hash_value", "") or "") for x in (getattr(execution, "results", None) or [])] + if not hashes: + empty_count += 1 + category_stats[cat]["empty"] += 1 + + expected_set = set(case.expected_hashes or []) + rank = 0 + for idx, hv in enumerate(hashes, start=1): + if hv and hv in expected_set: + rank = idx + break + + if rank > 0: + category_stats[cat]["hit"] += 1 + hitk += 1 + if rank <= 1: + hit1 += 1 + category_stats[cat]["hit_at_1"] += 1 + if rank <= 3: + hit3 += 1 + category_stats[cat]["hit_at_3"] += 1 + mrr_sum += 1.0 / float(rank) + else: + text_failed.append(case.case_id) + failed_predicates.update([str(case.metadata.get("predicate") or "__unknown__")]) + + p1 = (hit1 / text_total) if text_total else 0.0 + p3 = (hit3 / text_total) if text_total else 0.0 + recall = (hitk / text_total) if text_total else 0.0 + mrr = (mrr_sum / text_total) if text_total else 0.0 + spo_rate = (spo_hit / spo_total) if spo_total else 0.0 + empty_rate = (empty_count / text_total) if text_total else 1.0 + avg_elapsed = (elapsed_total / text_total) if text_total else 0.0 + + metrics = { + "total_text_cases": int(text_total), + "precision_at_1": float(round(p1, 6)), + "precision_at_3": float(round(p3, 6)), + "mrr": float(round(mrr, 6)), + "recall_at_k": float(round(recall, 6)), + "spo_relation_hit_rate": float(round(spo_rate, 6)), + "empty_rate": float(round(empty_rate, 6)), + "timeout_count": int(timeout_count), + "avg_elapsed_ms": float(round(avg_elapsed, 3)), + "category": category_stats, + } + metrics["category_floor_penalty"] = float(round(self._category_floor_penalty(metrics, objective=objective), 6)) + + score = self._score_metrics(metrics, objective=objective) + failure_summary = { + "text_failed_count": len(text_failed), + "spo_failed_count": len(spo_failed), + "failed_case_ids": 
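The text-case metrics above all derive from one number per case: the 1-based rank of the first expected hash in the returned list (0 if absent). Given that rank, P@1, P@3, Recall@K, and MRR fall out directly. A small sketch of the per-case computation, using made-up hashes:

```python
from typing import List, Set

def first_hit_rank(returned: List[str], expected: Set[str]) -> int:
    """1-based rank of the first expected hash in the result list, 0 if absent."""
    for idx, h in enumerate(returned, start=1):
        if h in expected:
            return idx
    return 0

ranks = [
    first_hit_rank(["h1", "h2", "h3"], {"h1"}),  # 1 -> counts for P@1, P@3, recall; MRR term 1.0
    first_hit_rank(["h4", "h2"], {"h2"}),        # 2 -> counts for P@3, recall; MRR term 0.5
    first_hit_rank(["h9"], {"h7"}),              # 0 -> miss
]
n = len(ranks)
p1 = sum(1 for r in ranks if r == 1) / n
p3 = sum(1 for r in ranks if 1 <= r <= 3) / n
recall_at_k = sum(1 for r in ranks if r > 0) / n
mrr = sum(1.0 / r for r in ranks if r > 0) / n
print(p1, p3, recall_at_k, round(mrr, 4))  # 0.333..., 0.666..., 0.666..., 0.5
```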
text_failed[:50] + spo_failed[:50], + "failed_by_category": {k: int(v["total"] - v["hit"]) for k, v in category_stats.items()}, + "top_failed_predicates": [ + {"predicate": key, "count": int(cnt)} + for key, cnt in failed_predicates.most_common(5) + if key + ], + "query_timeout_seconds": float(eval_timeout_s), + "timeout_count": int(timeout_count), + "effective_top_k": int(request_top_k), + "ppr_forced_disabled": True, + } + return { + "metrics": metrics, + "score": float(round(score, 6)), + "avg_elapsed_ms": float(avg_elapsed), + "failure_summary": failure_summary, + } + + def _score_metrics(self, metrics: Dict[str, Any], *, objective: str) -> float: + p1 = float(metrics.get("precision_at_1", 0.0) or 0.0) + p3 = float(metrics.get("precision_at_3", 0.0) or 0.0) + mrr = float(metrics.get("mrr", 0.0) or 0.0) + recall = float(metrics.get("recall_at_k", 0.0) or 0.0) + spo = float(metrics.get("spo_relation_hit_rate", 0.0) or 0.0) + empty_rate = float(metrics.get("empty_rate", 1.0) or 1.0) + category_penalty = metrics.get("category_floor_penalty", None) + if category_penalty is None: + category_penalty = self._category_floor_penalty(metrics, objective=objective) + category_penalty = float(max(0.0, category_penalty)) + + if objective == "recall_priority": + raw = 0.15 * p1 + 0.15 * p3 + 0.15 * mrr + 0.40 * recall + 0.15 * spo + penalty = 0.05 * empty_rate + elif objective == "balanced": + raw = 0.25 * p1 + 0.20 * p3 + 0.15 * mrr + 0.25 * recall + 0.15 * spo + penalty = 0.10 * empty_rate + else: + raw = 0.40 * p1 + 0.20 * p3 + 0.15 * mrr + 0.15 * recall + 0.10 * spo + penalty = 0.15 * empty_rate + return float(raw - penalty - category_penalty) + + def _category_floor_penalty(self, metrics: Dict[str, Any], *, objective: str) -> float: + category = metrics.get("category") + if not isinstance(category, dict) or not category: + return 0.0 + + if objective == "recall_priority": + floors = {"query_nl": 0.60, "query_kw": 0.48, "spo_search": 0.52, "spo_relation": 0.88} + scale = 0.12 + elif objective == "balanced": + floors = {"query_nl": 0.65, "query_kw": 0.52, "spo_search": 0.55, "spo_relation": 0.90} + scale = 0.18 + else: + floors = {"query_nl": 0.70, "query_kw": 0.55, "spo_search": 0.58, "spo_relation": 0.92} + scale = 0.25 + + weights = {"query_nl": 1.0, "query_kw": 1.1, "spo_search": 1.0, "spo_relation": 1.2} + weighted_shortfall = 0.0 + weight_total = 0.0 + + for cat, floor in floors.items(): + row = category.get(cat) + if not isinstance(row, dict): + continue + total = int(row.get("total", 0) or 0) + if total <= 0: + continue + hit = float(row.get("hit", 0.0) or 0.0) + hit_rate = max(0.0, min(1.0, hit / float(max(1, total)))) + shortfall = max(0.0, float(floor) - hit_rate) + w = float(weights.get(cat, 1.0)) + weighted_shortfall += w * shortfall + weight_total += w + + if weight_total <= 1e-9: + return 0.0 + return float(scale * (weighted_shortfall / weight_total)) + + def _build_report_payload(self, task: RetrievalTuningTaskRecord) -> Dict[str, Any]: + baseline = task.baseline_metrics or {} + best = task.best_metrics or {} + + def delta(name: str) -> float: + return float(best.get(name, 0.0) or 0.0) - float(baseline.get(name, 0.0) or 0.0) + + return { + "task_id": task.task_id, + "objective": task.objective, + "intensity": task.intensity, + "status": task.status, + "created_at": task.created_at, + "started_at": task.started_at, + "finished_at": task.finished_at, + "rounds_total": task.rounds_total, + "rounds_done": task.rounds_done, + "best_score": task.best_score, + "baseline_score": 
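`_score_metrics` is a fixed linear blend per objective. For `precision_priority`, the weights shown above give `score = 0.40·P@1 + 0.20·P@3 + 0.15·MRR + 0.15·Recall@K + 0.10·SPO - 0.15·empty_rate - category_floor_penalty`. A worked check with round numbers:

```python
def precision_priority_score(p1, p3, mrr, recall, spo, empty_rate, floor_penalty=0.0):
    raw = 0.40 * p1 + 0.20 * p3 + 0.15 * mrr + 0.15 * recall + 0.10 * spo
    return raw - 0.15 * empty_rate - floor_penalty

# Perfect retrieval with no empties and no floor penalty scores exactly 1.0:
assert abs(precision_priority_score(1, 1, 1, 1, 1, 0) - 1.0) < 1e-9

# A mediocre run: P@1=0.5, P@3=0.7, MRR=0.6, Recall=0.8, SPO=0.9, 10% empty.
# raw = 0.20 + 0.14 + 0.09 + 0.12 + 0.09 = 0.64; penalty = 0.015.
print(round(precision_priority_score(0.5, 0.7, 0.6, 0.8, 0.9, 0.1), 4))  # 0.625
```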
self._score_metrics(baseline, objective=task.objective), + "query_set_stats": task.query_set_stats, + "baseline_metrics": baseline, + "best_metrics": best, + "deltas": { + "precision_at_1": delta("precision_at_1"), + "precision_at_3": delta("precision_at_3"), + "mrr": delta("mrr"), + "recall_at_k": delta("recall_at_k"), + "spo_relation_hit_rate": delta("spo_relation_hit_rate"), + "empty_rate": delta("empty_rate"), + "timeout_count": delta("timeout_count"), + "avg_elapsed_ms": delta("avg_elapsed_ms"), + }, + "best_profile": task.best_profile, + "baseline_profile": task.baseline_profile, + "apply_log": task.apply_log, + } + + def _build_report_markdown(self, task: RetrievalTuningTaskRecord, payload: Dict[str, Any]) -> str: + baseline = payload.get("baseline_metrics", {}) or {} + best = payload.get("best_metrics", {}) or {} + d = payload.get("deltas", {}) or {} + lines = [ + f"# 检索调优报告({task.task_id})", + "", + "## 1. 任务信息", + f"- 状态: {task.status}", + f"- 目标函数: {task.objective}", + f"- 强度: {task.intensity}", + f"- 轮次: baseline + {task.rounds_total}", + f"- 创建时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(task.created_at))}", + f"- 开始时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(task.started_at)) if task.started_at else '-'}", + f"- 完成时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(task.finished_at)) if task.finished_at else '-'}", + "", + "## 2. 基线 vs 最优", + f"- baseline score: {payload.get('baseline_score', 0.0):.6f}", + f"- best score: {task.best_score:.6f}", + f"- P@1: {baseline.get('precision_at_1', 0.0):.4f} -> {best.get('precision_at_1', 0.0):.4f} (Δ {d.get('precision_at_1', 0.0):+.4f})", + f"- P@3: {baseline.get('precision_at_3', 0.0):.4f} -> {best.get('precision_at_3', 0.0):.4f} (Δ {d.get('precision_at_3', 0.0):+.4f})", + f"- MRR: {baseline.get('mrr', 0.0):.4f} -> {best.get('mrr', 0.0):.4f} (Δ {d.get('mrr', 0.0):+.4f})", + f"- Recall@K: {baseline.get('recall_at_k', 0.0):.4f} -> {best.get('recall_at_k', 0.0):.4f} (Δ {d.get('recall_at_k', 0.0):+.4f})", + f"- SPO relation hit: {baseline.get('spo_relation_hit_rate', 0.0):.4f} -> {best.get('spo_relation_hit_rate', 0.0):.4f} (Δ {d.get('spo_relation_hit_rate', 0.0):+.4f})", + f"- 空结果率: {baseline.get('empty_rate', 0.0):.4f} -> {best.get('empty_rate', 0.0):.4f} (Δ {d.get('empty_rate', 0.0):+.4f})", + f"- 超时数: {int(baseline.get('timeout_count', 0) or 0)} -> {int(best.get('timeout_count', 0) or 0)} (Δ {int(d.get('timeout_count', 0) or 0):+d})", + f"- 平均耗时(ms): {baseline.get('avg_elapsed_ms', 0.0):.2f} -> {best.get('avg_elapsed_ms', 0.0):.2f} (Δ {d.get('avg_elapsed_ms', 0.0):+.2f})", + "", + "## 3. 最优参数", + "```json", + json.dumps(task.best_profile, ensure_ascii=False, indent=2), + "```", + "", + "## 4. 测试集规模", + f"- {json.dumps(task.query_set_stats, ensure_ascii=False)}", + "", + "## 5. 
说明", + "- 本报告仅对当前已存储图谱与向量状态有效。", + "- 参数应用策略:运行时生效,不自动写入 config.toml。", + ] + return "\n".join(lines).strip() + "\n" diff --git a/plugins/A_memorix/core/utils/runtime_self_check.py b/plugins/A_memorix/core/utils/runtime_self_check.py index 36a2cf7e..131ab32a 100644 --- a/plugins/A_memorix/core/utils/runtime_self_check.py +++ b/plugins/A_memorix/core/utils/runtime_self_check.py @@ -61,6 +61,29 @@ def _build_report( } +def _normalize_encoded_vector(encoded: Any) -> np.ndarray: + if encoded is None: + raise ValueError("embedding encode returned None") + + if isinstance(encoded, np.ndarray): + array = encoded + else: + array = np.asarray(encoded, dtype=np.float32) + + if array.ndim == 2: + if array.shape[0] != 1: + raise ValueError(f"embedding encode returned batched output: shape={tuple(array.shape)}") + array = array[0] + + if array.ndim != 1: + raise ValueError(f"embedding encode returned invalid ndim={array.ndim}") + if array.size <= 0: + raise ValueError("embedding encode returned empty vector") + if not np.all(np.isfinite(array)): + raise ValueError("embedding encode returned non-finite values") + return array.astype(np.float32, copy=False) + + async def run_embedding_runtime_self_check( *, config: Any, @@ -91,13 +114,11 @@ async def run_embedding_runtime_self_check( try: detected_dimension = _safe_int(await embedding_manager._detect_dimension(), 0) encoded = await embedding_manager.encode(sample_text) - if isinstance(encoded, np.ndarray): - encoded_dimension = int(encoded.shape[0]) if encoded.ndim == 1 else int(encoded.shape[-1]) - else: - encoded_dimension = len(encoded) if encoded is not None else 0 + encoded_array = _normalize_encoded_vector(encoded) + encoded_dimension = int(encoded_array.shape[0]) except Exception as exc: elapsed_ms = (time.perf_counter() - start) * 1000.0 - logger.warning("embedding runtime self-check failed: %s", exc) + logger.warning(f"embedding runtime self-check failed: {exc}") return _build_report( ok=False, code="embedding_probe_failed", diff --git a/plugins/A_memorix/core/utils/search_execution_service.py b/plugins/A_memorix/core/utils/search_execution_service.py new file mode 100644 index 00000000..efb2093f --- /dev/null +++ b/plugins/A_memorix/core/utils/search_execution_service.py @@ -0,0 +1,442 @@ +""" +统一检索执行服务。 + +用于收敛 Action/Tool 在 search/time 上的核心执行流程,避免重复实现。 +""" + +from __future__ import annotations + +import hashlib +import json +import time +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple + +from src.common.logger import get_logger + +from ..retrieval import TemporalQueryOptions +from .search_postprocess import ( + apply_safe_content_dedup, + maybe_apply_smart_path_fallback, +) +from .time_parser import parse_query_time_range + +logger = get_logger("A_Memorix.SearchExecutionService") + + +def _get_config_value(config: Optional[dict], key: str, default: Any = None) -> Any: + if not isinstance(config, dict): + return default + current: Any = config + for part in key.split("."): + if isinstance(current, dict) and part in current: + current = current[part] + else: + return default + return current + + +def _sanitize_text(value: Any) -> str: + if value is None: + return "" + return str(value).strip() + + +@dataclass +class SearchExecutionRequest: + caller: str + stream_id: Optional[str] = None + group_id: Optional[str] = None + user_id: Optional[str] = None + query_type: str = "search" # search|semantic|time|hybrid + query: str = "" + top_k: Optional[int] = None + time_from: Optional[str] = None + 
time_to: Optional[str] = None + person: Optional[str] = None + source: Optional[str] = None + use_threshold: bool = True + enable_ppr: bool = True + + +@dataclass +class SearchExecutionResult: + success: bool + error: str = "" + query_type: str = "search" + query: str = "" + top_k: int = 10 + time_from: Optional[str] = None + time_to: Optional[str] = None + person: Optional[str] = None + source: Optional[str] = None + temporal: Optional[TemporalQueryOptions] = None + results: List[Any] = field(default_factory=list) + elapsed_ms: float = 0.0 + chat_filtered: bool = False + dedup_hit: bool = False + + @property + def count(self) -> int: + return len(self.results) + + +class SearchExecutionService: + """统一检索执行服务。""" + + @staticmethod + def _resolve_plugin_instance(plugin_config: Optional[dict]) -> Optional[Any]: + if isinstance(plugin_config, dict): + plugin_instance = plugin_config.get("plugin_instance") + if plugin_instance is not None: + return plugin_instance + + try: + from ...plugin import AMemorixPlugin + + return getattr(AMemorixPlugin, "get_global_instance", lambda: None)() + except Exception: + return None + + @staticmethod + def _normalize_query_type(raw_query_type: str) -> str: + query_type = _sanitize_text(raw_query_type).lower() or "search" + if query_type == "semantic": + return "search" + return query_type + + @staticmethod + def _resolve_runtime_component( + plugin_config: Optional[dict], + plugin_instance: Optional[Any], + key: str, + ) -> Optional[Any]: + if isinstance(plugin_config, dict): + value = plugin_config.get(key) + if value is not None: + return value + if plugin_instance is not None: + value = getattr(plugin_instance, key, None) + if value is not None: + return value + return None + + @staticmethod + def _resolve_top_k( + plugin_config: Optional[dict], + query_type: str, + top_k_raw: Optional[Any], + ) -> Tuple[bool, int, str]: + temporal_default_top_k = int( + _get_config_value(plugin_config, "retrieval.temporal.default_top_k", 10) + ) + default_top_k = temporal_default_top_k if query_type in {"time", "hybrid"} else 10 + if top_k_raw is None: + return True, max(1, min(50, default_top_k)), "" + try: + top_k = int(top_k_raw) + except (TypeError, ValueError): + return False, 0, "top_k 参数必须为整数" + return True, max(1, min(50, top_k)), "" + + @staticmethod + def _build_temporal( + plugin_config: Optional[dict], + query_type: str, + time_from_raw: Optional[str], + time_to_raw: Optional[str], + person: Optional[str], + source: Optional[str], + ) -> Tuple[bool, Optional[TemporalQueryOptions], str]: + if query_type not in {"time", "hybrid"}: + return True, None, "" + + temporal_enabled = bool(_get_config_value(plugin_config, "retrieval.temporal.enabled", True)) + if not temporal_enabled: + return False, None, "时序检索已禁用(retrieval.temporal.enabled=false)" + + if not time_from_raw and not time_to_raw: + return False, None, "time/hybrid 模式至少需要 time_from 或 time_to" + + try: + ts_from, ts_to = parse_query_time_range( + str(time_from_raw) if time_from_raw is not None else None, + str(time_to_raw) if time_to_raw is not None else None, + ) + except ValueError as e: + return False, None, f"时间参数错误: {e}" + + temporal = TemporalQueryOptions( + time_from=ts_from, + time_to=ts_to, + person=_sanitize_text(person) or None, + source=_sanitize_text(source) or None, + allow_created_fallback=bool( + _get_config_value(plugin_config, "retrieval.temporal.allow_created_fallback", True) + ), + candidate_multiplier=int( + _get_config_value(plugin_config, "retrieval.temporal.candidate_multiplier", 8) 
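+                # 候选扩采倍数与下方 max_scan 扫描上限均取自 retrieval.temporal.* 配置(语义按字段命名推断)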
+ ), + max_scan=int(_get_config_value(plugin_config, "retrieval.temporal.max_scan", 1000)), + ) + return True, temporal, "" + + @staticmethod + def _build_request_key( + request: SearchExecutionRequest, + query_type: str, + top_k: int, + temporal: Optional[TemporalQueryOptions], + ) -> str: + payload = { + "stream_id": _sanitize_text(request.stream_id), + "query_type": query_type, + "query": _sanitize_text(request.query), + "time_from": _sanitize_text(request.time_from), + "time_to": _sanitize_text(request.time_to), + "time_from_ts": temporal.time_from if temporal else None, + "time_to_ts": temporal.time_to if temporal else None, + "person": _sanitize_text(request.person), + "source": _sanitize_text(request.source), + "top_k": int(top_k), + "use_threshold": bool(request.use_threshold), + "enable_ppr": bool(request.enable_ppr), + } + payload_json = json.dumps(payload, ensure_ascii=False, sort_keys=True) + return hashlib.sha1(payload_json.encode("utf-8")).hexdigest() + + @staticmethod + async def execute( + *, + retriever: Any, + threshold_filter: Optional[Any], + plugin_config: Optional[dict], + request: SearchExecutionRequest, + enforce_chat_filter: bool = True, + reinforce_access: bool = True, + ) -> SearchExecutionResult: + if retriever is None: + return SearchExecutionResult(success=False, error="知识检索器未初始化") + + query_type = SearchExecutionService._normalize_query_type(request.query_type) + query = _sanitize_text(request.query) + if query_type not in {"search", "time", "hybrid"}: + return SearchExecutionResult( + success=False, + error=f"query_type 无效: {query_type}(仅支持 search/time/hybrid)", + ) + + if query_type in {"search", "hybrid"} and not query: + return SearchExecutionResult( + success=False, + error="search/hybrid 模式必须提供 query", + ) + + top_k_ok, top_k, top_k_error = SearchExecutionService._resolve_top_k( + plugin_config, query_type, request.top_k + ) + if not top_k_ok: + return SearchExecutionResult(success=False, error=top_k_error) + + temporal_ok, temporal, temporal_error = SearchExecutionService._build_temporal( + plugin_config=plugin_config, + query_type=query_type, + time_from_raw=request.time_from, + time_to_raw=request.time_to, + person=request.person, + source=request.source, + ) + if not temporal_ok: + return SearchExecutionResult(success=False, error=temporal_error) + + plugin_instance = SearchExecutionService._resolve_plugin_instance(plugin_config) + if ( + enforce_chat_filter + and plugin_instance is not None + and hasattr(plugin_instance, "is_chat_enabled") + ): + if not plugin_instance.is_chat_enabled( + stream_id=request.stream_id, + group_id=request.group_id, + user_id=request.user_id, + ): + logger.info( + "检索请求被聊天过滤拦截: " + f"caller={request.caller}, " + f"stream_id={request.stream_id}" + ) + return SearchExecutionResult( + success=True, + query_type=query_type, + query=query, + top_k=top_k, + time_from=request.time_from, + time_to=request.time_to, + person=request.person, + source=request.source, + temporal=temporal, + results=[], + elapsed_ms=0.0, + chat_filtered=True, + dedup_hit=False, + ) + + request_key = SearchExecutionService._build_request_key( + request=request, + query_type=query_type, + top_k=top_k, + temporal=temporal, + ) + + async def _executor() -> Dict[str, Any]: + original_ppr = bool(getattr(retriever.config, "enable_ppr", True)) + setattr(retriever.config, "enable_ppr", bool(request.enable_ppr)) + started_at = time.time() + try: + retrieved = await retriever.retrieve( + query=query, + top_k=top_k, + temporal=temporal, + ) + + 
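+                # 检索后处理按固定顺序执行:
+                # 1) 阈值过滤(time 模式且 query 为空时可按配置跳过)
+                # 2) relation 命中的访问强化 reinforce_access
+                # 3) search 模式下的 smart path fallback 补召回与安全内容去重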
should_apply_threshold = bool(request.use_threshold) and threshold_filter is not None + if ( + query_type == "time" + and not query + and bool( + _get_config_value( + plugin_config, + "retrieval.time.skip_threshold_when_query_empty", + True, + ) + ) + ): + should_apply_threshold = False + + if should_apply_threshold: + retrieved = threshold_filter.filter(retrieved) + + if ( + reinforce_access + and plugin_instance is not None + and hasattr(plugin_instance, "reinforce_access") + ): + relation_hashes = [ + item.hash_value + for item in retrieved + if getattr(item, "result_type", "") == "relation" + ] + if relation_hashes: + await plugin_instance.reinforce_access(relation_hashes) + + if query_type == "search": + graph_store = SearchExecutionService._resolve_runtime_component( + plugin_config, plugin_instance, "graph_store" + ) + metadata_store = SearchExecutionService._resolve_runtime_component( + plugin_config, plugin_instance, "metadata_store" + ) + fallback_enabled = bool( + _get_config_value( + plugin_config, + "retrieval.search.smart_fallback.enabled", + True, + ) + ) + fallback_threshold = float( + _get_config_value( + plugin_config, + "retrieval.search.smart_fallback.threshold", + 0.6, + ) + ) + retrieved, fallback_triggered, fallback_added = maybe_apply_smart_path_fallback( + query=query, + results=list(retrieved), + graph_store=graph_store, + metadata_store=metadata_store, + enabled=fallback_enabled, + threshold=fallback_threshold, + ) + if fallback_triggered: + logger.info( + "metric.smart_fallback_triggered_count=1 " + f"caller={request.caller} " + f"added={fallback_added}" + ) + + dedup_enabled = bool( + _get_config_value( + plugin_config, + "retrieval.search.safe_content_dedup.enabled", + True, + ) + ) + if dedup_enabled: + retrieved, removed_count = apply_safe_content_dedup(list(retrieved)) + if removed_count > 0: + logger.info( + f"metric.safe_dedup_removed_count={removed_count} " + f"caller={request.caller}" + ) + + elapsed_ms = (time.time() - started_at) * 1000.0 + return {"results": retrieved, "elapsed_ms": elapsed_ms} + finally: + setattr(retriever.config, "enable_ppr", original_ppr) + + dedup_hit = False + try: + # 调优评估需要逐轮真实执行,且应避免额外 dedup 锁竞争。 + bypass_request_dedup = str(request.caller or "").strip().lower() == "retrieval_tuning" + if ( + not bypass_request_dedup + and + plugin_instance is not None + and hasattr(plugin_instance, "execute_request_with_dedup") + ): + dedup_hit, payload = await plugin_instance.execute_request_with_dedup( + request_key, + _executor, + ) + else: + payload = await _executor() + except Exception as e: + return SearchExecutionResult(success=False, error=f"知识检索失败: {e}") + + if dedup_hit: + logger.info(f"metric.search_execution_dedup_hit_count=1 caller={request.caller}") + + return SearchExecutionResult( + success=True, + query_type=query_type, + query=query, + top_k=top_k, + time_from=request.time_from, + time_to=request.time_to, + person=request.person, + source=request.source, + temporal=temporal, + results=payload.get("results", []), + elapsed_ms=float(payload.get("elapsed_ms", 0.0)), + chat_filtered=False, + dedup_hit=bool(dedup_hit), + ) + + @staticmethod + def to_serializable_results(results: List[Any]) -> List[Dict[str, Any]]: + serialized: List[Dict[str, Any]] = [] + for item in results: + metadata = dict(getattr(item, "metadata", {}) or {}) + if "time_meta" not in metadata: + metadata["time_meta"] = {} + serialized.append( + { + "hash": getattr(item, "hash_value", ""), + "type": getattr(item, "result_type", ""), + "score": 
float(getattr(item, "score", 0.0)), + "content": getattr(item, "content", ""), + "metadata": metadata, + } + ) + return serialized diff --git a/plugins/A_memorix/core/utils/summary_importer.py b/plugins/A_memorix/core/utils/summary_importer.py new file mode 100644 index 00000000..b6271db4 --- /dev/null +++ b/plugins/A_memorix/core/utils/summary_importer.py @@ -0,0 +1,425 @@ +""" +聊天总结与知识导入工具 + +该模块负责从聊天记录中提取信息,生成总结,并将总结内容及提取的实体/关系 +导入到 A_memorix 的存储组件中。 +""" + +import time +import json +import re +import traceback +from typing import List, Dict, Any, Tuple, Optional +from pathlib import Path + +from src.common.logger import get_logger +from src.services import llm_service as llm_api +from src.services import message_service as message_api +from src.config.config import global_config, model_config as host_model_config +from src.config.model_configs import TaskConfig + +from ..storage import ( + KnowledgeType, + VectorStore, + GraphStore, + MetadataStore, + resolve_stored_knowledge_type, +) +from ..embedding import EmbeddingAPIAdapter +from .relation_write_service import RelationWriteService +from .runtime_self_check import ensure_runtime_self_check, run_embedding_runtime_self_check + +logger = get_logger("A_Memorix.SummaryImporter") + +# 默认总结提示词模版 +SUMMARY_PROMPT_TEMPLATE = """ +你是 {bot_name}。{personality_context} +现在你需要对以下一段聊天记录进行总结,并提取其中的重要知识。 + +聊天记录内容: +{chat_history} + +请完成以下任务: +1. **生成总结**:以第三人称或机器人的视角,简洁明了地总结这段对话的主要内容、发生的事件或讨论的主题。 +2. **提取实体与关系**:识别并提取对话中提到的重要实体以及它们之间的关系。 + +请严格以 JSON 格式输出,格式如下: +{{ + "summary": "总结文本内容", + "entities": ["张三", "李四"], + "relations": [ + {{"subject": "张三", "predicate": "认识", "object": "李四"}} + ] +}} + +注意:总结应具有叙事性,能够作为长程记忆的一部分。直接使用实体的实际名称,不要使用 e1/e2 等代号。 +""" + +class SummaryImporter: + """总结并导入知识的工具类""" + + def __init__( + self, + vector_store: VectorStore, + graph_store: GraphStore, + metadata_store: MetadataStore, + embedding_manager: EmbeddingAPIAdapter, + plugin_config: dict + ): + self.vector_store = vector_store + self.graph_store = graph_store + self.metadata_store = metadata_store + self.embedding_manager = embedding_manager + self.plugin_config = plugin_config + self.relation_write_service: Optional[RelationWriteService] = ( + plugin_config.get("relation_write_service") + if isinstance(plugin_config, dict) + else None + ) + + def _normalize_summary_model_selectors(self, raw_value: Any) -> List[str]: + """标准化 summarization.model_name 配置(vNext 仅接受字符串数组)。""" + if raw_value is None: + return ["auto"] + if isinstance(raw_value, list): + selectors = [str(x).strip() for x in raw_value if str(x).strip()] + return selectors or ["auto"] + raise ValueError( + "summarization.model_name 在 vNext 必须为 List[str]。" + " 请执行 scripts/release_vnext_migrate.py migrate。" + ) + + def _pick_default_summary_task(self, available_tasks: Dict[str, TaskConfig]) -> Tuple[Optional[str], Optional[TaskConfig]]: + """ + 选择总结默认任务,避免错误落到 embedding 任务。 + 优先级:replyer > utils > planner > tool_use > 其他非 embedding。 + """ + preferred = ("replyer", "utils", "planner", "tool_use") + for name in preferred: + cfg = available_tasks.get(name) + if cfg and cfg.model_list: + return name, cfg + + for name, cfg in available_tasks.items(): + if name != "embedding" and cfg.model_list: + return name, cfg + + for name, cfg in available_tasks.items(): + if cfg.model_list: + return name, cfg + + return None, None + + def _resolve_summary_model_config(self) -> Optional[TaskConfig]: + """ + 解析 summarization.model_name 为 TaskConfig。 + 支持: + - "auto" + - "replyer"(任务名) + - "some-model-name"(具体模型名) + - 
["utils:model1", "utils:model2", "replyer"](数组混合语法) + """ + available_tasks = llm_api.get_available_models() + if not available_tasks: + return None + + raw_cfg = self.plugin_config.get("summarization", {}).get("model_name", "auto") + selectors = self._normalize_summary_model_selectors(raw_cfg) + default_task_name, default_task_cfg = self._pick_default_summary_task(available_tasks) + + selected_models: List[str] = [] + base_cfg: Optional[TaskConfig] = None + model_dict = getattr(host_model_config, "models_dict", {}) + + def _append_models(models: List[str]): + for model_name in models: + if model_name and model_name not in selected_models: + selected_models.append(model_name) + + for raw_selector in selectors: + selector = raw_selector.strip() + if not selector: + continue + + if selector.lower() == "auto": + if default_task_cfg: + _append_models(default_task_cfg.model_list) + if base_cfg is None: + base_cfg = default_task_cfg + continue + + if ":" in selector: + task_name, model_name = selector.split(":", 1) + task_name = task_name.strip() + model_name = model_name.strip() + task_cfg = available_tasks.get(task_name) + if not task_cfg: + logger.warning(f"总结模型选择器 '{selector}' 的任务 '{task_name}' 不存在,已跳过") + continue + + if base_cfg is None: + base_cfg = task_cfg + + if not model_name or model_name.lower() == "auto": + _append_models(task_cfg.model_list) + continue + + if model_name in model_dict or model_name in task_cfg.model_list: + _append_models([model_name]) + else: + logger.warning(f"总结模型选择器 '{selector}' 的模型 '{model_name}' 不存在,已跳过") + continue + + task_cfg = available_tasks.get(selector) + if task_cfg: + _append_models(task_cfg.model_list) + if base_cfg is None: + base_cfg = task_cfg + continue + + if selector in model_dict: + _append_models([selector]) + continue + + logger.warning(f"总结模型选择器 '{selector}' 无法识别,已跳过") + + if not selected_models: + if default_task_cfg: + _append_models(default_task_cfg.model_list) + if base_cfg is None: + base_cfg = default_task_cfg + else: + first_cfg = next(iter(available_tasks.values())) + _append_models(first_cfg.model_list) + if base_cfg is None: + base_cfg = first_cfg + + if not selected_models: + return None + + template_cfg = base_cfg or default_task_cfg or next(iter(available_tasks.values())) + return TaskConfig( + model_list=selected_models, + max_tokens=template_cfg.max_tokens, + temperature=template_cfg.temperature, + slow_threshold=template_cfg.slow_threshold, + selection_strategy=template_cfg.selection_strategy, + ) + + async def import_from_stream( + self, + stream_id: str, + context_length: Optional[int] = None, + include_personality: Optional[bool] = None + ) -> Tuple[bool, str]: + """ + 从指定的聊天流中提取记录并执行总结导入 + + Args: + stream_id: 聊天流 ID + context_length: 总结的历史消息条数 + include_personality: 是否包含人设 + + Returns: + Tuple[bool, str]: (是否成功, 结果消息) + """ + try: + self_check_ok, self_check_msg = await self._ensure_runtime_self_check() + if not self_check_ok: + return False, f"导入前自检失败: {self_check_msg}" + + # 1. 获取配置 + if context_length is None: + context_length = self.plugin_config.get("summarization", {}).get("context_length", 50) + + if include_personality is None: + include_personality = self.plugin_config.get("summarization", {}).get("include_personality", True) + + # 2. 
获取历史消息
+            # 获取当前时间之前的消息
+            now = time.time()
+            messages = message_api.get_messages_before_time_in_chat(
+                chat_id=stream_id,
+                timestamp=now,
+                limit=context_length
+            )
+
+            if not messages:
+                return False, "未找到有效的聊天记录进行总结"
+
+            # 转换为可读文本
+            chat_history_text = message_api.build_readable_messages(messages)
+
+            # 3. 准备提示词内容
+            bot_name = global_config.bot.nickname or "机器人"
+            personality_context = ""
+            if include_personality:
+                personality = getattr(global_config.bot, "personality", "")
+                if personality:
+                    personality_context = f"你的性格设定是:{personality}"
+
+            # 4. 调用 LLM
+            prompt = SUMMARY_PROMPT_TEMPLATE.format(
+                bot_name=bot_name,
+                personality_context=personality_context,
+                chat_history=chat_history_text
+            )
+
+            model_config_to_use = self._resolve_summary_model_config()
+            if model_config_to_use is None:
+                return False, "未找到可用的总结模型配置"
+
+            logger.info(f"正在为流 {stream_id} 执行总结,消息条数: {len(messages)}")
+            logger.info(f"总结模型候选列表: {model_config_to_use.model_list}")
+
+            success, response, _, _ = await llm_api.generate_with_model(
+                prompt=prompt,
+                model_config=model_config_to_use,
+                request_type="A_Memorix.ChatSummarization"
+            )
+
+            if not success or not response:
+                return False, "LLM 生成总结失败"
+
+            # 5. 解析结果
+            data = self._parse_llm_response(response)
+            if not data or "summary" not in data:
+                return False, "解析 LLM 响应失败或总结为空"
+
+            summary_text = data["summary"]
+            entities = data.get("entities", [])
+            relations = data.get("relations", [])
+            # 提取消息事件时间:兼容 datetime 型 timestamp(调用 .timestamp()),
+            # 缺失或无 .timestamp() 方法时回退 0.0,并在过滤中剔除,避免污染 min/max
+            msg_times = [
+                ts
+                for ts in (
+                    float(getattr(getattr(msg, "timestamp", None), "timestamp", lambda: 0.0)())
+                    for msg in messages
+                )
+                if ts > 0.0
+            ]
+            time_meta = {}
+            if msg_times:
+                time_meta = {
+                    "event_time_start": min(msg_times),
+                    "event_time_end": max(msg_times),
+                    "time_granularity": "minute",
+                    "time_confidence": 0.95,
+                }
+
+            # 6. 执行导入
+            await self._execute_import(summary_text, entities, relations, stream_id, time_meta=time_meta)
+
+            # 7. 
持久化 + self.vector_store.save() + self.graph_store.save() + + result_msg = ( + f"✅ 总结导入成功\n" + f"📝 总结长度: {len(summary_text)}\n" + f"📌 提取实体: {len(entities)}\n" + f"🔗 提取关系: {len(relations)}" + ) + return True, result_msg + + except Exception as e: + logger.error(f"总结导入过程中出错: {e}\n{traceback.format_exc()}") + return False, f"错误: {str(e)}" + + async def _ensure_runtime_self_check(self) -> Tuple[bool, str]: + plugin_instance = self.plugin_config.get("plugin_instance") if isinstance(self.plugin_config, dict) else None + if plugin_instance is not None: + report = await ensure_runtime_self_check(plugin_instance) + else: + report = await run_embedding_runtime_self_check( + config=self.plugin_config, + vector_store=self.vector_store, + embedding_manager=self.embedding_manager, + ) + if bool(report.get("ok", False)): + return True, "" + return ( + False, + f"{report.get('message', 'unknown')} " + f"(configured={report.get('configured_dimension', 0)}, " + f"store={report.get('vector_store_dimension', 0)}, " + f"encoded={report.get('encoded_dimension', 0)})", + ) + + def _parse_llm_response(self, response: str) -> Dict[str, Any]: + """解析 LLM 返回的 JSON""" + try: + # 尝试查找 JSON + json_match = re.search(r"\{.*\}", response, re.DOTALL) + if json_match: + return json.loads(json_match.group()) + return {} + except Exception as e: + logger.warning(f"解析总结 JSON 失败: {e}") + return {} + + async def _execute_import( + self, + summary: str, + entities: List[str], + relations: List[Dict[str, str]], + stream_id: str, + time_meta: Optional[Dict[str, Any]] = None, + ): + """将数据写入存储""" + # 获取默认知识类型 + type_str = self.plugin_config.get("summarization", {}).get("default_knowledge_type", "narrative") + try: + knowledge_type = resolve_stored_knowledge_type(type_str, content=summary) + except ValueError: + logger.warning(f"非法 summarization.default_knowledge_type={type_str},回退 narrative") + knowledge_type = KnowledgeType.NARRATIVE + + # 导入总结文本 + hash_value = self.metadata_store.add_paragraph( + content=summary, + source=f"chat_summary:{stream_id}", + knowledge_type=knowledge_type.value, + time_meta=time_meta, + ) + + embedding = await self.embedding_manager.encode(summary) + self.vector_store.add( + vectors=embedding.reshape(1, -1), + ids=[hash_value] + ) + + # 导入实体 + if entities: + self.graph_store.add_nodes(entities) + + # 导入关系 + rv_cfg = self.plugin_config.get("retrieval", {}).get("relation_vectorization", {}) + if not isinstance(rv_cfg, dict): + rv_cfg = {} + write_vector = bool(rv_cfg.get("enabled", False)) and bool(rv_cfg.get("write_on_import", True)) + for rel in relations: + s, p, o = rel.get("subject"), rel.get("predicate"), rel.get("object") + if all([s, p, o]): + if self.relation_write_service is not None: + await self.relation_write_service.upsert_relation_with_vector( + subject=s, + predicate=p, + obj=o, + confidence=1.0, + source_paragraph=summary, + write_vector=write_vector, + ) + else: + # 写入元数据 + rel_hash = self.metadata_store.add_relation( + subject=s, + predicate=p, + obj=o, + confidence=1.0, + source_paragraph=summary + ) + # 写入图数据库(写入 relation_hashes,确保后续可按关系精确修剪) + self.graph_store.add_edges([(s, o)], relation_hashes=[rel_hash]) + try: + self.metadata_store.set_relation_vector_state(rel_hash, "none") + except Exception: + pass + + logger.info(f"总结导入完成: hash={hash_value[:8]}") diff --git a/plugins/A_memorix/core/utils/web_import_manager.py b/plugins/A_memorix/core/utils/web_import_manager.py new file mode 100644 index 00000000..b088be1f --- /dev/null +++ b/plugins/A_memorix/core/utils/web_import_manager.py 
@@ -0,0 +1,3522 @@ +""" +Web Import Task Manager + +为 A_Memorix WebUI 提供导入任务队列、状态管理、并发调度与取消/重试能力。 +""" + +from __future__ import annotations + +import asyncio +import hashlib +import json +import os +import shutil +import sys +import time +import traceback +import uuid +from collections import deque +from dataclasses import dataclass, field +from datetime import datetime +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Tuple + +from src.common.logger import get_logger +from src.services import llm_service as llm_api + +from ..storage import ( + parse_import_strategy, + resolve_stored_knowledge_type, + select_import_strategy, + KnowledgeType, + MetadataStore, +) +from ..storage.type_detection import looks_like_quote_text +from ..utils.import_payloads import normalize_paragraph_import_item +from ..utils.runtime_self_check import ensure_runtime_self_check +from ..utils.time_parser import normalize_time_meta +from ..storage.knowledge_types import ImportStrategy +from ..strategies.base import ProcessedChunk, KnowledgeType as StrategyKnowledgeType +from ..strategies.narrative import NarrativeStrategy +from ..strategies.factual import FactualStrategy +from ..strategies.quote import QuoteStrategy + +logger = get_logger("A_Memorix.WebImportManager") + + +TASK_STATUS = { + "queued", + "preparing", + "running", + "cancel_requested", + "cancelled", + "completed", + "completed_with_errors", + "failed", +} + +FILE_STATUS = { + "queued", + "preparing", + "splitting", + "extracting", + "writing", + "saving", + "completed", + "failed", + "cancelled", +} + +CHUNK_STATUS = { + "queued", + "extracting", + "writing", + "completed", + "failed", + "cancelled", +} + + +def _now() -> float: + return time.time() + + +def _coerce_int(value: Any, default: int) -> int: + try: + return int(value) + except Exception: + return default + + +def _coerce_bool(value: Any, default: bool) -> bool: + if value is None: + return default + if isinstance(value, bool): + return value + if isinstance(value, (int, float)): + return bool(value) + text = str(value).strip().lower() + if text in {"1", "true", "yes", "y", "on"}: + return True + if text in {"0", "false", "no", "n", "off", ""}: + return False + return default + + +def _clamp(value: int, min_value: int, max_value: int) -> int: + return max(min_value, min(max_value, value)) + + +def _coerce_list(value: Any) -> List[str]: + if value is None: + return [] + if isinstance(value, list): + raw_items = value + else: + text = str(value or "").replace("\r", "\n") + raw_items = [] + for seg in text.split("\n"): + raw_items.extend(seg.split(",")) + + out: List[str] = [] + seen = set() + for item in raw_items: + v = str(item or "").strip() + if not v: + continue + key = v.lower() + if key in seen: + continue + seen.add(key) + out.append(v) + return out + + +def _parse_optional_positive_int(value: Any, field_name: str) -> Optional[int]: + if value is None: + return None + text = str(value).strip() + if text == "": + return None + try: + parsed = int(text) + except Exception: + raise ValueError(f"{field_name} 必须为整数") + if parsed <= 0: + raise ValueError(f"{field_name} 必须 > 0") + return parsed + + +def _safe_filename(name: str) -> str: + base = os.path.basename(str(name or "").strip()) + if not base: + return f"unnamed_{uuid.uuid4().hex[:8]}.txt" + return base + + +def _storage_type_from_strategy(strategy_type: StrategyKnowledgeType) -> str: + if strategy_type == StrategyKnowledgeType.NARRATIVE: + return KnowledgeType.NARRATIVE.value + if strategy_type == 
StrategyKnowledgeType.FACTUAL: + return KnowledgeType.FACTUAL.value + if strategy_type == StrategyKnowledgeType.QUOTE: + return KnowledgeType.QUOTE.value + return KnowledgeType.MIXED.value + + +@dataclass +class ImportChunkRecord: + chunk_id: str + index: int + chunk_type: str + status: str = "queued" + step: str = "queued" + failed_at: str = "" + retryable: bool = False + error: str = "" + progress: float = 0.0 + content_preview: str = "" + updated_at: float = field(default_factory=_now) + + def to_dict(self) -> Dict[str, Any]: + return { + "chunk_id": self.chunk_id, + "index": self.index, + "chunk_type": self.chunk_type, + "status": self.status, + "step": self.step, + "failed_at": self.failed_at, + "retryable": self.retryable, + "error": self.error, + "progress": self.progress, + "content_preview": self.content_preview, + "updated_at": self.updated_at, + } + + +@dataclass +class ImportFileRecord: + file_id: str + name: str + source_kind: str + input_mode: str + status: str = "queued" + current_step: str = "queued" + detected_strategy_type: str = "unknown" + total_chunks: int = 0 + done_chunks: int = 0 + failed_chunks: int = 0 + cancelled_chunks: int = 0 + progress: float = 0.0 + error: str = "" + chunks: List[ImportChunkRecord] = field(default_factory=list) + created_at: float = field(default_factory=_now) + updated_at: float = field(default_factory=_now) + temp_path: Optional[str] = None + source_path: Optional[str] = None + inline_content: Optional[str] = None + content_hash: str = "" + retry_chunk_indexes: List[int] = field(default_factory=list) + retry_mode: str = "" + + def to_dict(self, include_chunks: bool = False) -> Dict[str, Any]: + payload = { + "file_id": self.file_id, + "name": self.name, + "source_kind": self.source_kind, + "input_mode": self.input_mode, + "status": self.status, + "current_step": self.current_step, + "detected_strategy_type": self.detected_strategy_type, + "total_chunks": self.total_chunks, + "done_chunks": self.done_chunks, + "failed_chunks": self.failed_chunks, + "cancelled_chunks": self.cancelled_chunks, + "progress": self.progress, + "error": self.error, + "created_at": self.created_at, + "updated_at": self.updated_at, + "source_path": self.source_path or "", + "content_hash": self.content_hash or "", + "retry_chunk_indexes": list(self.retry_chunk_indexes or []), + "retry_mode": self.retry_mode or "", + } + if include_chunks: + payload["chunks"] = [chunk.to_dict() for chunk in self.chunks] + return payload + + +@dataclass +class ImportTaskRecord: + task_id: str + source: str + params: Dict[str, Any] + status: str = "queued" + current_step: str = "queued" + total_chunks: int = 0 + done_chunks: int = 0 + failed_chunks: int = 0 + cancelled_chunks: int = 0 + progress: float = 0.0 + error: str = "" + files: List[ImportFileRecord] = field(default_factory=list) + created_at: float = field(default_factory=_now) + started_at: Optional[float] = None + finished_at: Optional[float] = None + updated_at: float = field(default_factory=_now) + schema_detected: str = "" + artifact_paths: Dict[str, str] = field(default_factory=dict) + rollback_info: Dict[str, Any] = field(default_factory=dict) + retry_parent_task_id: str = "" + retry_summary: Dict[str, Any] = field(default_factory=dict) + + def to_summary(self) -> Dict[str, Any]: + return { + "task_id": self.task_id, + "source": self.source, + "status": self.status, + "current_step": self.current_step, + "total_chunks": self.total_chunks, + "done_chunks": self.done_chunks, + "failed_chunks": self.failed_chunks, + 
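+            # chunk 计数与 progress 为任务执行期维护的汇总状态,此处仅做序列化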
"cancelled_chunks": self.cancelled_chunks, + "progress": self.progress, + "error": self.error, + "file_count": len(self.files), + "created_at": self.created_at, + "started_at": self.started_at, + "finished_at": self.finished_at, + "updated_at": self.updated_at, + "task_kind": str(self.params.get("task_kind") or self.source), + "schema_detected": self.schema_detected, + "artifact_paths": dict(self.artifact_paths), + "rollback_info": dict(self.rollback_info), + "retry_parent_task_id": self.retry_parent_task_id or "", + "retry_summary": dict(self.retry_summary), + } + + def to_detail(self, include_chunks: bool = False) -> Dict[str, Any]: + payload = self.to_summary() + payload["params"] = self.params + payload["files"] = [f.to_dict(include_chunks=include_chunks) for f in self.files] + return payload + + +class ImportTaskManager: + def __init__(self, plugin: Any): + self.plugin = plugin + self._lock = asyncio.Lock() + self._storage_lock = asyncio.Lock() + + self._tasks: Dict[str, ImportTaskRecord] = {} + self._task_order: deque[str] = deque() + self._queue: deque[str] = deque() + self._active_task_id: Optional[str] = None + + self._worker_task: Optional[asyncio.Task] = None + self._stopping = False + + self._temp_root = self._resolve_temp_root() + self._temp_root.mkdir(parents=True, exist_ok=True) + self._reports_root = self._resolve_reports_root() + self._reports_root.mkdir(parents=True, exist_ok=True) + self._manifest_path = self._resolve_manifest_path() + self._manifest_cache: Optional[Dict[str, Any]] = None + self._write_changed_callback: Optional[Callable[[Dict[str, Any]], Any]] = None + + def set_write_changed_callback(self, callback: Optional[Callable[[Dict[str, Any]], Any]]) -> None: + self._write_changed_callback = callback + + async def _notify_write_changed(self, payload: Dict[str, Any]) -> None: + callback = self._write_changed_callback + if callback is None: + return + try: + maybe_awaitable = callback(payload) + if asyncio.iscoroutine(maybe_awaitable): + await maybe_awaitable + except Exception as e: + logger.warning(f"写入变更回调执行失败: {e}") + + def _resolve_temp_root(self) -> Path: + data_dir = Path(self.plugin.get_config("storage.data_dir", "./data")) + if str(data_dir).startswith("."): + plugin_dir = Path(__file__).resolve().parents[2] + data_dir = (plugin_dir / data_dir).resolve() + return data_dir / "web_import_tmp" + + def _resolve_reports_root(self) -> Path: + return self._resolve_data_dir() / "web_import_reports" + + def _resolve_manifest_path(self) -> Path: + return self._resolve_data_dir() / "import_manifest.json" + + def _resolve_staging_root(self) -> Path: + return self._resolve_data_dir() / "import_staging" + + def _resolve_backup_root(self) -> Path: + return self._resolve_data_dir() / "import_backup" + + def _resolve_repo_root(self) -> Path: + return Path(__file__).resolve().parents[3] + + def _resolve_data_dir(self) -> Path: + data_dir = Path(self.plugin.get_config("storage.data_dir", "./data")) + if str(data_dir).startswith("."): + plugin_dir = Path(__file__).resolve().parents[2] + data_dir = (plugin_dir / data_dir).resolve() + return data_dir.resolve() + + def _resolve_migration_script(self) -> Path: + return Path(__file__).resolve().parents[2] / "scripts" / "migrate_maibot_memory.py" + + def _default_maibot_source_db(self) -> Path: + # A_memorix/core/utils -> workspace root + return self._resolve_repo_root() / "MaiBot" / "data" / "MaiBot.db" + + def _cfg(self, key: str, default: Any) -> Any: + return self.plugin.get_config(key, default) + + def _cfg_int(self, key: 
str, default: int) -> int: + return _coerce_int(self._cfg(key, default), default) + + def _is_enabled(self) -> bool: + return bool(self._cfg("web.import.enabled", True)) + + def _queue_limit(self) -> int: + return max(1, self._cfg_int("web.import.max_queue_size", 20)) + + def _max_files_per_task(self) -> int: + return max(1, self._cfg_int("web.import.max_files_per_task", 200)) + + def _max_file_size_bytes(self) -> int: + mb = max(1, self._cfg_int("web.import.max_file_size_mb", 20)) + return mb * 1024 * 1024 + + def _max_paste_chars(self) -> int: + return max(1000, self._cfg_int("web.import.max_paste_chars", 200000)) + + def _default_file_concurrency(self) -> int: + return max(1, self._cfg_int("web.import.default_file_concurrency", 2)) + + def _default_chunk_concurrency(self) -> int: + return max(1, self._cfg_int("web.import.default_chunk_concurrency", 4)) + + def _max_file_concurrency(self) -> int: + return max(1, self._cfg_int("web.import.max_file_concurrency", 6)) + + def _max_chunk_concurrency(self) -> int: + return max(1, self._cfg_int("web.import.max_chunk_concurrency", 12)) + + def _llm_retry_config(self) -> Dict[str, float]: + retries = max(0, self._cfg_int("web.import.llm_retry.max_attempts", 4)) + min_wait = max(0.1, float(self._cfg("web.import.llm_retry.min_wait_seconds", 3) or 3)) + max_wait = max(min_wait, float(self._cfg("web.import.llm_retry.max_wait_seconds", 40) or 40)) + mult = max(1.0, float(self._cfg("web.import.llm_retry.backoff_multiplier", 3) or 3)) + return { + "retries": retries, + "min_wait": min_wait, + "max_wait": max_wait, + "multiplier": mult, + } + + def _default_path_aliases(self) -> Dict[str, str]: + plugin_dir = Path(__file__).resolve().parents[2] + repo_root = self._resolve_repo_root() + return { + "raw": str((plugin_dir / "data" / "raw").resolve()), + "lpmm": str((repo_root / "data" / "lpmm_storage").resolve()), + "plugin_data": str((plugin_dir / "data").resolve()), + } + + def get_path_aliases(self) -> Dict[str, str]: + configured = self._cfg("web.import.path_aliases", self._default_path_aliases()) + if not isinstance(configured, dict): + configured = self._default_path_aliases() + + repo_root = self._resolve_repo_root() + result: Dict[str, str] = {} + for alias, raw_path in configured.items(): + key = str(alias or "").strip() + if not key: + continue + text = str(raw_path or "").strip() + if not text: + continue + if text.startswith("\\\\"): + continue + p = Path(text) + if not p.is_absolute(): + p = (repo_root / p).resolve() + else: + p = p.resolve() + result[key] = str(p) + + defaults = self._default_path_aliases() + for key, path in defaults.items(): + result.setdefault(key, path) + return result + + def resolve_path_alias( + self, + alias: str, + relative_path: str = "", + *, + must_exist: bool = False, + ) -> Path: + alias_key = str(alias or "").strip() + aliases = self.get_path_aliases() + if alias_key not in aliases: + raise ValueError(f"未知路径别名: {alias_key}") + + root = Path(aliases[alias_key]).resolve() + rel = str(relative_path or "").strip().replace("\\", "/") + if rel.startswith("/") or rel.startswith("\\") or rel.startswith("//"): + raise ValueError("relative_path 不能为绝对路径") + if ":" in rel: + raise ValueError("relative_path 不允许包含盘符") + + candidate = (root / rel).resolve() if rel else root + try: + candidate.relative_to(root) + except ValueError: + raise ValueError("路径越界:relative_path 超出白名单目录") + if must_exist and not candidate.exists(): + raise ValueError(f"路径不存在: {candidate}") + return candidate + + async def resolve_path_request(self, 
payload: Dict[str, Any]) -> Dict[str, Any]: + alias = str(payload.get("alias") or "").strip() + relative_path = str(payload.get("relative_path") or "").strip() + must_exist = _coerce_bool(payload.get("must_exist"), True) + resolved = self.resolve_path_alias(alias, relative_path, must_exist=must_exist) + return { + "alias": alias, + "relative_path": relative_path, + "resolved_path": str(resolved), + "exists": resolved.exists(), + "is_file": resolved.is_file(), + "is_dir": resolved.is_dir(), + } + + def _load_manifest(self) -> Dict[str, Any]: + if self._manifest_cache is not None: + return self._manifest_cache + path = self._manifest_path + if not path.exists(): + self._manifest_cache = {} + return self._manifest_cache + try: + payload = json.loads(path.read_text(encoding="utf-8")) + if isinstance(payload, dict): + self._manifest_cache = payload + else: + self._manifest_cache = {} + except Exception: + self._manifest_cache = {} + return self._manifest_cache + + def _save_manifest(self, payload: Dict[str, Any]) -> None: + path = self._manifest_path + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8") + self._manifest_cache = payload + + def _clear_manifest(self) -> None: + self._save_manifest({}) + + def _normalize_manifest_path(self, raw_path: str) -> str: + text = str(raw_path or "").strip() + if not text: + return "" + return text.replace("\\", "/").strip().lower() + + def _match_manifest_item_for_source(self, source: str, item: Dict[str, Any]) -> bool: + source_text = str(source or "").strip() + if not source_text or ":" not in source_text: + return False + prefix, tail = source_text.split(":", 1) + source_kind = prefix.strip().lower() + source_value = tail.strip() + if not source_value: + return False + + item_kind = str(item.get("source_kind") or "").strip().lower() + item_name = str(item.get("name") or "").strip() + item_path_norm = self._normalize_manifest_path(item.get("source_path") or "") + + if source_kind in {"raw_scan", "lpmm_openie"}: + source_path_norm = self._normalize_manifest_path(source_value) + if source_path_norm and item_path_norm and source_path_norm == item_path_norm and item_kind == source_kind: + return True + + if source_kind == "web_import": + return item_kind in {"upload", "paste"} and item_name == source_value + + if source_kind == "lpmm_openie": + source_name = Path(source_value).name + return item_kind == "lpmm_openie" and item_name == source_name + + return False + + async def invalidate_manifest_for_sources(self, sources: List[str]) -> Dict[str, Any]: + requested_sources: List[str] = [] + seen_sources = set() + for raw in sources or []: + source = str(raw or "").strip() + if not source: + continue + key = source.lower() + if key in seen_sources: + continue + seen_sources.add(key) + requested_sources.append(source) + + result: Dict[str, Any] = { + "requested_sources": requested_sources, + "removed_count": 0, + "removed_keys": [], + "remaining_count": 0, + "unmatched_sources": [], + "warnings": [], + } + + async with self._lock: + manifest = self._load_manifest() + if not isinstance(manifest, dict): + manifest = {} + + valid_items: List[Tuple[str, Dict[str, Any]]] = [] + malformed_keys: List[str] = [] + for key, item in manifest.items(): + if isinstance(item, dict): + valid_items.append((str(key), item)) + else: + malformed_keys.append(str(key)) + + keys_to_remove = set() + for source in requested_sources: + matched = False + for key, item in valid_items: + if 
self._match_manifest_item_for_source(source, item): + keys_to_remove.add(key) + matched = True + if not matched: + result["unmatched_sources"].append(source) + + if keys_to_remove: + for key in keys_to_remove: + manifest.pop(key, None) + self._save_manifest(manifest) + + result["removed_keys"] = sorted(keys_to_remove) + result["removed_count"] = len(keys_to_remove) + result["remaining_count"] = len(manifest) + + if malformed_keys: + preview = ", ".join(malformed_keys[:5]) + extra = "" if len(malformed_keys) <= 5 else f" ... (+{len(malformed_keys) - 5})" + result["warnings"].append( + f"manifest 条目结构异常,已跳过 {len(malformed_keys)} 项: {preview}{extra}" + ) + + return result + + def _manifest_key_for_file(self, file_record: ImportFileRecord, content_hash: str, dedupe_policy: str) -> str: + if dedupe_policy == "content_hash": + return f"hash:{content_hash}" + if file_record.source_path: + return f"path:{Path(file_record.source_path).as_posix().lower()}" + return f"hash:{content_hash}" + + def _is_manifest_hit( + self, + file_record: ImportFileRecord, + content_hash: str, + dedupe_policy: str, + ) -> bool: + key = self._manifest_key_for_file(file_record, content_hash, dedupe_policy) + manifest = self._load_manifest() + item = manifest.get(key) + if not isinstance(item, dict): + return False + return str(item.get("hash") or "") == content_hash and bool(item.get("imported")) + + def _record_manifest_import( + self, + file_record: ImportFileRecord, + content_hash: str, + dedupe_policy: str, + task_id: str, + ) -> None: + key = self._manifest_key_for_file(file_record, content_hash, dedupe_policy) + manifest = self._load_manifest() + manifest[key] = { + "hash": content_hash, + "imported": True, + "timestamp": _now(), + "task_id": task_id, + "name": file_record.name, + "source_path": file_record.source_path or "", + "source_kind": file_record.source_kind, + } + self._save_manifest(manifest) + + def _normalize_common_import_params(self, payload: Dict[str, Any], *, default_dedupe: str) -> Dict[str, Any]: + input_mode = str(payload.get("input_mode", "text") or "text").strip().lower() + if input_mode not in {"text", "json"}: + raise ValueError("input_mode 必须为 text 或 json") + + file_concurrency = _coerce_int( + payload.get("file_concurrency", self._default_file_concurrency()), + self._default_file_concurrency(), + ) + chunk_concurrency = _coerce_int( + payload.get("chunk_concurrency", self._default_chunk_concurrency()), + self._default_chunk_concurrency(), + ) + file_concurrency = _clamp(file_concurrency, 1, self._max_file_concurrency()) + chunk_concurrency = _clamp(chunk_concurrency, 1, self._max_chunk_concurrency()) + + llm_enabled = _coerce_bool(payload.get("llm_enabled", True), True) + strategy_override = parse_import_strategy( + payload.get("strategy_override", "auto"), + default=ImportStrategy.AUTO, + ).value + + dedupe_policy = str(payload.get("dedupe_policy", default_dedupe) or default_dedupe).strip().lower() + if dedupe_policy not in {"content_hash", "manifest", "none"}: + raise ValueError("dedupe_policy 必须为 content_hash/manifest/none") + + chat_log = _coerce_bool(payload.get("chat_log"), False) + chat_reference_time = str(payload.get("chat_reference_time") or "").strip() or None + force = _coerce_bool(payload.get("force"), False) + clear_manifest = _coerce_bool(payload.get("clear_manifest"), False) + + return { + "input_mode": input_mode, + "file_concurrency": file_concurrency, + "chunk_concurrency": chunk_concurrency, + "llm_enabled": llm_enabled, + "strategy_override": strategy_override, + 
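+            # chat_log / chat_reference_time 为聊天记录导入场景的透传参数,具体语义由下游导入流程决定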
"chat_log": chat_log, + "chat_reference_time": chat_reference_time, + "force": force, + "clear_manifest": clear_manifest, + "dedupe_policy": dedupe_policy, + } + + def _normalize_params(self, payload: Dict[str, Any]) -> Dict[str, Any]: + params = self._normalize_common_import_params(payload, default_dedupe="content_hash") + params["task_kind"] = "upload" + return params + + def _normalize_raw_scan_params(self, payload: Dict[str, Any]) -> Dict[str, Any]: + params = self._normalize_common_import_params(payload, default_dedupe="manifest") + alias = str(payload.get("alias") or "raw").strip() + relative_path = str(payload.get("relative_path") or "").strip() + glob_pattern = str(payload.get("glob") or "*").strip() or "*" + recursive = _coerce_bool(payload.get("recursive"), True) + if ".." in relative_path.replace("\\", "/").split("/"): + raise ValueError("relative_path 不允许包含 ..") + params.update( + { + "task_kind": "raw_scan", + "alias": alias, + "relative_path": relative_path, + "glob": glob_pattern, + "recursive": recursive, + } + ) + return params + + def _normalize_lpmm_openie_params(self, payload: Dict[str, Any]) -> Dict[str, Any]: + params = self._normalize_common_import_params(payload, default_dedupe="manifest") + alias = str(payload.get("alias") or "lpmm").strip() + relative_path = str(payload.get("relative_path") or "").strip() + include_all_json = _coerce_bool(payload.get("include_all_json"), False) + params.update( + { + "task_kind": "lpmm_openie", + "alias": alias, + "relative_path": relative_path, + "include_all_json": include_all_json, + "input_mode": "json", + } + ) + return params + + def _normalize_temporal_backfill_params(self, payload: Dict[str, Any]) -> Dict[str, Any]: + alias = str(payload.get("alias") or "plugin_data").strip() + relative_path = str(payload.get("relative_path") or "").strip() + dry_run = _coerce_bool(payload.get("dry_run"), False) + no_created_fallback = _coerce_bool(payload.get("no_created_fallback"), False) + limit = _parse_optional_positive_int(payload.get("limit"), "limit") or 100000 + return { + "task_kind": "temporal_backfill", + "alias": alias, + "relative_path": relative_path, + "dry_run": dry_run, + "no_created_fallback": no_created_fallback, + "limit": limit, + } + + def _normalize_lpmm_convert_params(self, payload: Dict[str, Any]) -> Dict[str, Any]: + alias = str(payload.get("alias") or "lpmm").strip() + relative_path = str(payload.get("relative_path") or "").strip() + target_alias = str(payload.get("target_alias") or "plugin_data").strip() + target_relative_path = str(payload.get("target_relative_path") or "").strip() + dimension = _parse_optional_positive_int(payload.get("dimension"), "dimension") or _coerce_int( + self._cfg("embedding.dimension", 384), + 384, + ) + batch_size = _parse_optional_positive_int(payload.get("batch_size"), "batch_size") or 1024 + return { + "task_kind": "lpmm_convert", + "alias": alias, + "relative_path": relative_path, + "target_alias": target_alias, + "target_relative_path": target_relative_path, + "dimension": dimension, + "batch_size": batch_size, + } + + def _normalize_by_task_kind(self, task_kind: str, payload: Dict[str, Any]) -> Dict[str, Any]: + kind = str(task_kind or "").strip().lower() + if kind in {"upload", "paste"}: + params = self._normalize_params(payload) + params["task_kind"] = kind + return params + if kind == "maibot_migration": + return self._normalize_migration_params(payload) + if kind == "raw_scan": + return self._normalize_raw_scan_params(payload) + if kind == "lpmm_openie": + return 
self._normalize_lpmm_openie_params(payload) + if kind == "temporal_backfill": + return self._normalize_temporal_backfill_params(payload) + if kind == "lpmm_convert": + return self._normalize_lpmm_convert_params(payload) + # upload/paste 默认走通用文本导入参数 + return self._normalize_params(payload) + + def _normalize_migration_params(self, payload: Dict[str, Any]) -> Dict[str, Any]: + source_db = str(payload.get("source_db") or "").strip() + if not source_db: + source_db = str(self._default_maibot_source_db()) + + time_from = str(payload.get("time_from") or "").strip() or None + time_to = str(payload.get("time_to") or "").strip() or None + + stream_ids = _coerce_list(payload.get("stream_ids")) + group_ids = _coerce_list(payload.get("group_ids")) + user_ids = _coerce_list(payload.get("user_ids")) + + start_id = _parse_optional_positive_int(payload.get("start_id"), "start_id") + end_id = _parse_optional_positive_int(payload.get("end_id"), "end_id") + if start_id is not None and end_id is not None and start_id > end_id: + raise ValueError("start_id 不能大于 end_id") + + read_batch_size = _parse_optional_positive_int(payload.get("read_batch_size"), "read_batch_size") or 2000 + commit_window_rows = _parse_optional_positive_int(payload.get("commit_window_rows"), "commit_window_rows") or 20000 + embed_batch_size = _parse_optional_positive_int(payload.get("embed_batch_size"), "embed_batch_size") or 256 + entity_embed_batch_size = ( + _parse_optional_positive_int(payload.get("entity_embed_batch_size"), "entity_embed_batch_size") or 512 + ) + embed_workers = _parse_optional_positive_int(payload.get("embed_workers"), "embed_workers") + max_errors = _parse_optional_positive_int(payload.get("max_errors"), "max_errors") or 500 + log_every = _parse_optional_positive_int(payload.get("log_every"), "log_every") or 5000 + preview_limit = _parse_optional_positive_int(payload.get("preview_limit"), "preview_limit") or 20 + + no_resume = _coerce_bool(payload.get("no_resume"), False) + reset_state = _coerce_bool(payload.get("reset_state"), False) + dry_run = _coerce_bool(payload.get("dry_run"), False) + verify_only = _coerce_bool(payload.get("verify_only"), False) + + return { + "task_kind": "maibot_migration", + "source_db": source_db, + "target_data_dir": str(self._resolve_data_dir()), + "time_from": time_from, + "time_to": time_to, + "stream_ids": stream_ids, + "group_ids": group_ids, + "user_ids": user_ids, + "start_id": start_id, + "end_id": end_id, + "read_batch_size": read_batch_size, + "commit_window_rows": commit_window_rows, + "embed_batch_size": embed_batch_size, + "entity_embed_batch_size": entity_embed_batch_size, + "embed_workers": embed_workers, + "max_errors": max_errors, + "log_every": log_every, + "preview_limit": preview_limit, + "no_resume": no_resume, + "reset_state": reset_state, + "dry_run": dry_run, + "verify_only": verify_only, + } + + def _pending_task_count(self) -> int: + pending = 0 + for task in self._tasks.values(): + if task.status in {"queued", "preparing", "running", "cancel_requested"}: + pending += 1 + return pending + + async def _ensure_worker(self) -> None: + async with self._lock: + if self._worker_task and not self._worker_task.done(): + return + self._stopping = False + self._worker_task = asyncio.create_task(self._worker_loop()) + + async def get_runtime_settings(self) -> Dict[str, Any]: + llm_retry = self._llm_retry_config() + return { + "max_queue_size": self._queue_limit(), + "max_files_per_task": self._max_files_per_task(), + "max_file_size_mb": 
self._cfg_int("web.import.max_file_size_mb", 20), + "max_paste_chars": self._max_paste_chars(), + "default_file_concurrency": self._default_file_concurrency(), + "default_chunk_concurrency": self._default_chunk_concurrency(), + "max_file_concurrency": self._max_file_concurrency(), + "max_chunk_concurrency": self._max_chunk_concurrency(), + "poll_interval_ms": max(200, self._cfg_int("web.import.poll_interval_ms", 1000)), + "maibot_source_db_default": str(self._default_maibot_source_db()), + "maibot_target_data_dir": str(self._resolve_data_dir()), + "path_aliases": self.get_path_aliases(), + "llm_retry": llm_retry, + "convert_enable_staging_switch": _coerce_bool( + self._cfg("web.import.convert.enable_staging_switch", True), True + ), + "convert_keep_backup_count": max(0, self._cfg_int("web.import.convert.keep_backup_count", 3)), + } + + def is_write_blocked(self) -> bool: + task_id = self._active_task_id + if not task_id: + return False + task = self._tasks.get(task_id) + if not task: + return False + return task.status in {"preparing", "running", "cancel_requested"} + + def _ensure_ready(self) -> None: + required_attrs = ("metadata_store", "vector_store", "graph_store", "embedding_manager") + + def _collect_missing() -> List[str]: + missing_local: List[str] = [] + for attr in required_attrs: + if getattr(self.plugin, attr, None) is None: + missing_local.append(attr) + return missing_local + + missing = _collect_missing() + if missing: + raise ValueError(f"导入依赖未初始化: {', '.join(missing)}") + ready_checker = getattr(self.plugin, "is_runtime_ready", None) + if callable(ready_checker) and not ready_checker(): + raise ValueError("插件运行时未就绪,请先完成 on_enable 初始化") + + def _scan_files( + self, + base_path: Path, + *, + recursive: bool, + glob_pattern: str, + allowed_exts: Optional[set[str]] = None, + ) -> List[Path]: + if base_path.is_file(): + candidates = [base_path] + else: + if recursive: + candidates = list(base_path.rglob(glob_pattern)) + else: + candidates = list(base_path.glob(glob_pattern)) + out: List[Path] = [] + for p in candidates: + if not p.is_file(): + continue + ext = p.suffix.lower() + if allowed_exts and ext not in allowed_exts: + continue + out.append(p.resolve()) + out.sort(key=lambda x: x.as_posix().lower()) + return out + + async def create_upload_task(self, files: List[Any], payload: Dict[str, Any]) -> Dict[str, Any]: + if not self._is_enabled(): + raise ValueError("导入功能已禁用") + self._ensure_ready() + if not files: + raise ValueError("至少需要上传一个文件") + + params = self._normalize_params(payload) + max_files = self._max_files_per_task() + if len(files) > max_files: + raise ValueError(f"单任务文件数超过上限: {max_files}") + + async with self._lock: + if self._pending_task_count() >= self._queue_limit(): + raise ValueError("任务队列已满,请稍后重试") + + task = ImportTaskRecord( + task_id=uuid.uuid4().hex, + source="upload", + params=params, + status="queued", + current_step="queued", + ) + task_dir = self._temp_root / task.task_id + task_dir.mkdir(parents=True, exist_ok=True) + + max_size = self._max_file_size_bytes() + for idx, uploaded in enumerate(files): + file_id = uuid.uuid4().hex + if isinstance(uploaded, dict): + staged_path_raw = uploaded.get("staged_path") or uploaded.get("path") or "" + staged_path = Path(str(staged_path_raw or "")).expanduser().resolve() + if not staged_path.is_file(): + raise ValueError(f"上传暂存文件不存在: {staged_path}") + name = _safe_filename(uploaded.get("filename") or uploaded.get("name") or staged_path.name) + ext = Path(name).suffix.lower() + if ext not in {".txt", ".md", 
".json"}: + raise ValueError(f"不支持的文件类型: {name}") + if staged_path.stat().st_size > max_size: + raise ValueError(f"文件超过大小限制: {name}") + temp_path = task_dir / f"{file_id}_{name}" + shutil.copy2(staged_path, temp_path) + else: + name = _safe_filename(getattr(uploaded, "filename", f"file_{idx}.txt")) + ext = Path(name).suffix.lower() + if ext not in {".txt", ".md", ".json"}: + raise ValueError(f"不支持的文件类型: {name}") + content = await uploaded.read() + if len(content) > max_size: + raise ValueError(f"文件超过大小限制: {name}") + temp_path = task_dir / f"{file_id}_{name}" + temp_path.write_bytes(content) + file_mode = "json" if ext == ".json" else params["input_mode"] + task.files.append( + ImportFileRecord( + file_id=file_id, + name=name, + source_kind="upload", + input_mode=file_mode, + temp_path=str(temp_path), + ) + ) + + self._tasks[task.task_id] = task + self._task_order.appendleft(task.task_id) + self._queue.append(task.task_id) + + await self._ensure_worker() + return task.to_summary() + + async def create_paste_task(self, payload: Dict[str, Any]) -> Dict[str, Any]: + if not self._is_enabled(): + raise ValueError("导入功能已禁用") + self._ensure_ready() + + params = self._normalize_params(payload) + params["task_kind"] = "paste" + content = str(payload.get("content", "") or "") + if not content.strip(): + raise ValueError("content 不能为空") + if len(content) > self._max_paste_chars(): + raise ValueError(f"粘贴内容超过限制: {self._max_paste_chars()} 字符") + + name = _safe_filename(payload.get("name") or f"paste_{int(_now())}.txt") + if params["input_mode"] == "json" and Path(name).suffix.lower() != ".json": + name = f"{Path(name).stem}.json" + + async with self._lock: + if self._pending_task_count() >= self._queue_limit(): + raise ValueError("任务队列已满,请稍后重试") + + task = ImportTaskRecord( + task_id=uuid.uuid4().hex, + source="paste", + params=params, + status="queued", + current_step="queued", + ) + task.files.append( + ImportFileRecord( + file_id=uuid.uuid4().hex, + name=name, + source_kind="paste", + input_mode=params["input_mode"], + inline_content=content, + ) + ) + self._tasks[task.task_id] = task + self._task_order.appendleft(task.task_id) + self._queue.append(task.task_id) + + await self._ensure_worker() + return task.to_summary() + + async def create_raw_scan_task(self, payload: Dict[str, Any]) -> Dict[str, Any]: + if not self._is_enabled(): + raise ValueError("导入功能已禁用") + self._ensure_ready() + params = self._normalize_raw_scan_params(payload) + source_path = self.resolve_path_alias( + params["alias"], + params["relative_path"], + must_exist=True, + ) + files = self._scan_files( + source_path, + recursive=bool(params["recursive"]), + glob_pattern=str(params["glob"] or "*"), + allowed_exts={".txt", ".md", ".json"}, + ) + if not files: + raise ValueError("未找到可导入文件") + if len(files) > self._max_files_per_task(): + raise ValueError(f"单任务文件数超过上限: {self._max_files_per_task()}") + + async with self._lock: + if self._pending_task_count() >= self._queue_limit(): + raise ValueError("任务队列已满,请稍后重试") + + task = ImportTaskRecord( + task_id=uuid.uuid4().hex, + source="raw_scan", + params=params, + status="queued", + current_step="queued", + ) + for path in files: + mode = "json" if path.suffix.lower() == ".json" else params["input_mode"] + task.files.append( + ImportFileRecord( + file_id=uuid.uuid4().hex, + name=path.name, + source_kind="raw_scan", + input_mode=mode, + source_path=str(path), + ) + ) + self._tasks[task.task_id] = task + self._task_order.appendleft(task.task_id) + self._queue.append(task.task_id) + + await 
self._ensure_worker() + return task.to_summary() + + async def create_lpmm_openie_task(self, payload: Dict[str, Any]) -> Dict[str, Any]: + if not self._is_enabled(): + raise ValueError("导入功能已禁用") + self._ensure_ready() + params = self._normalize_lpmm_openie_params(payload) + source_path = self.resolve_path_alias( + params["alias"], + params["relative_path"], + must_exist=True, + ) + files: List[Path] = [] + if source_path.is_file(): + files = [source_path] + else: + files = self._scan_files( + source_path, + recursive=True, + glob_pattern="*-openie.json", + allowed_exts={".json"}, + ) + if not files and params.get("include_all_json"): + files = self._scan_files( + source_path, + recursive=True, + glob_pattern="*.json", + allowed_exts={".json"}, + ) + if not files: + raise ValueError("未找到 LPMM OpenIE JSON 文件") + if len(files) > self._max_files_per_task(): + raise ValueError(f"单任务文件数超过上限: {self._max_files_per_task()}") + + async with self._lock: + if self._pending_task_count() >= self._queue_limit(): + raise ValueError("任务队列已满,请稍后重试") + task = ImportTaskRecord( + task_id=uuid.uuid4().hex, + source="lpmm_openie", + params=params, + status="queued", + current_step="queued", + schema_detected="lpmm_openie", + ) + for path in files: + task.files.append( + ImportFileRecord( + file_id=uuid.uuid4().hex, + name=path.name, + source_kind="lpmm_openie", + input_mode="json", + source_path=str(path), + ) + ) + self._tasks[task.task_id] = task + self._task_order.appendleft(task.task_id) + self._queue.append(task.task_id) + + await self._ensure_worker() + return task.to_summary() + + async def create_temporal_backfill_task(self, payload: Dict[str, Any]) -> Dict[str, Any]: + if not self._is_enabled(): + raise ValueError("导入功能已禁用") + params = self._normalize_temporal_backfill_params(payload) + target_path = self.resolve_path_alias( + params["alias"], + params["relative_path"], + must_exist=True, + ) + if not target_path.is_dir(): + raise ValueError("temporal_backfill 目标路径必须为目录") + + async with self._lock: + if self._pending_task_count() >= self._queue_limit(): + raise ValueError("任务队列已满,请稍后重试") + task = ImportTaskRecord( + task_id=uuid.uuid4().hex, + source="temporal_backfill", + params=params, + status="queued", + current_step="queued", + ) + task.files.append( + ImportFileRecord( + file_id=uuid.uuid4().hex, + name=f"temporal_backfill_{int(_now())}", + source_kind="temporal_backfill", + input_mode="json", + source_path=str(target_path), + ) + ) + self._tasks[task.task_id] = task + self._task_order.appendleft(task.task_id) + self._queue.append(task.task_id) + + await self._ensure_worker() + return task.to_summary() + + async def create_lpmm_convert_task(self, payload: Dict[str, Any]) -> Dict[str, Any]: + if not self._is_enabled(): + raise ValueError("导入功能已禁用") + params = self._normalize_lpmm_convert_params(payload) + source_path = self.resolve_path_alias( + params["alias"], + params["relative_path"], + must_exist=True, + ) + if not source_path.is_dir(): + raise ValueError("lpmm_convert 输入路径必须为目录") + target_path = self.resolve_path_alias( + params["target_alias"], + params["target_relative_path"], + must_exist=False, + ) + target_path.mkdir(parents=True, exist_ok=True) + if not target_path.is_dir(): + raise ValueError("lpmm_convert 目标路径必须为目录") + + async with self._lock: + if self._pending_task_count() >= self._queue_limit(): + raise ValueError("任务队列已满,请稍后重试") + task = ImportTaskRecord( + task_id=uuid.uuid4().hex, + source="lpmm_convert", + params={**params, "source_path": str(source_path), "target_path": 
str(target_path)}, + status="queued", + current_step="queued", + ) + task.files.append( + ImportFileRecord( + file_id=uuid.uuid4().hex, + name=f"lpmm_convert_{int(_now())}", + source_kind="lpmm_convert", + input_mode="json", + source_path=str(source_path), + ) + ) + self._tasks[task.task_id] = task + self._task_order.appendleft(task.task_id) + self._queue.append(task.task_id) + + await self._ensure_worker() + return task.to_summary() + + async def create_maibot_migration_task(self, payload: Dict[str, Any]) -> Dict[str, Any]: + if not self._is_enabled(): + raise ValueError("导入功能已禁用") + self._ensure_ready() + + params = self._normalize_migration_params(payload) + script_path = self._resolve_migration_script() + if not script_path.exists(): + raise ValueError(f"迁移脚本不存在: {script_path}") + + async with self._lock: + if self._pending_task_count() >= self._queue_limit(): + raise ValueError("任务队列已满,请稍后重试") + + task = ImportTaskRecord( + task_id=uuid.uuid4().hex, + source="maibot_migration", + params=params, + status="queued", + current_step="queued", + ) + task.files.append( + ImportFileRecord( + file_id=uuid.uuid4().hex, + name=f"maibot_migration_{int(_now())}", + source_kind="maibot_migration", + input_mode="text", + inline_content=json.dumps(params, ensure_ascii=False), + ) + ) + self._tasks[task.task_id] = task + self._task_order.appendleft(task.task_id) + self._queue.append(task.task_id) + + await self._ensure_worker() + return task.to_summary() + + async def list_tasks(self, limit: int = 50) -> List[Dict[str, Any]]: + async with self._lock: + task_ids = list(self._task_order)[: max(1, int(limit))] + return [self._tasks[task_id].to_summary() for task_id in task_ids if task_id in self._tasks] + + async def get_task(self, task_id: str, include_chunks: bool = False) -> Optional[Dict[str, Any]]: + async with self._lock: + task = self._tasks.get(task_id) + if not task: + return None + return task.to_detail(include_chunks=include_chunks) + + async def get_chunks(self, task_id: str, file_id: str, offset: int = 0, limit: int = 50) -> Optional[Dict[str, Any]]: + async with self._lock: + task = self._tasks.get(task_id) + if not task: + return None + file_obj = self._find_file(task, file_id) + if not file_obj: + return None + start = max(0, int(offset)) + size = max(1, min(500, int(limit))) + items = file_obj.chunks[start : start + size] + return { + "task_id": task_id, + "file_id": file_id, + "offset": start, + "limit": size, + "total": len(file_obj.chunks), + "items": [x.to_dict() for x in items], + } + + async def cancel_task(self, task_id: str) -> Optional[Dict[str, Any]]: + async with self._lock: + task = self._tasks.get(task_id) + if not task: + return None + if task.status == "queued": + self._mark_task_cancelled_locked(task, "任务已取消") + self._queue = deque([x for x in self._queue if x != task_id]) + elif task.status in {"preparing", "running"}: + task.status = "cancel_requested" + task.current_step = "cancel_requested" + task.updated_at = _now() + return task.to_summary() + + def _build_retry_plan(self, task: ImportTaskRecord) -> Dict[str, Any]: + chunk_retry_candidates: List[Tuple[ImportFileRecord, List[int]]] = [] + file_fallback_candidates: List[ImportFileRecord] = [] + skipped: List[Dict[str, str]] = [] + + for file_obj in task.files: + if file_obj.status == "cancelled": + continue + + failed_chunks = [c for c in file_obj.chunks if c.status == "failed"] + has_file_level_failure = file_obj.status == "failed" and not failed_chunks + if has_file_level_failure: + 
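+                # 文件级失败但缺少失败分块明细,无法做分块粒度重试,
+                # 整个文件进入 file_fallback 候选,由 _clone_failed_file_for_retry 整体重建。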
file_fallback_candidates.append(file_obj) + continue + + if not failed_chunks: + continue + + retry_indexes: List[int] = [] + has_non_retryable = False + for chunk in failed_chunks: + failed_at = str(chunk.failed_at or "").strip().lower() + retryable = bool(chunk.retryable) or ( + file_obj.input_mode == "text" and failed_at == "extracting" + ) + if retryable: + try: + retry_indexes.append(int(chunk.index)) + except Exception: + has_non_retryable = True + else: + has_non_retryable = True + + if has_non_retryable: + file_fallback_candidates.append(file_obj) + continue + + retry_indexes = sorted(set(retry_indexes)) + if retry_indexes: + chunk_retry_candidates.append((file_obj, retry_indexes)) + else: + skipped.append( + { + "file_name": file_obj.name, + "source_kind": file_obj.source_kind, + "reason": "no_retryable_failed_chunks", + } + ) + + unique_fallback: List[ImportFileRecord] = [] + fallback_seen = set() + for file_obj in file_fallback_candidates: + if file_obj.file_id in fallback_seen: + continue + fallback_seen.add(file_obj.file_id) + unique_fallback.append(file_obj) + + return { + "chunk_retry_candidates": chunk_retry_candidates, + "file_fallback_candidates": unique_fallback, + "skipped": skipped, + } + + def _clone_failed_file_for_retry( + self, + retry_task: ImportTaskRecord, + failed_file: ImportFileRecord, + task_dir: Path, + *, + retry_mode: str, + retry_chunk_indexes: Optional[List[int]] = None, + ) -> Tuple[bool, str]: + source_kind = str(failed_file.source_kind or "").strip().lower() + retry_chunk_indexes = list(retry_chunk_indexes or []) + + if source_kind == "upload": + candidate_paths: List[Path] = [] + if failed_file.temp_path: + candidate_paths.append(Path(failed_file.temp_path)) + if failed_file.source_path: + candidate_paths.append(Path(failed_file.source_path)) + src_path = next((p for p in candidate_paths if p.exists() and p.is_file()), None) + if src_path is None: + return False, "upload_source_missing" + data = src_path.read_bytes() + file_id = uuid.uuid4().hex + name = _safe_filename(failed_file.name) + dst = task_dir / f"{file_id}_{name}" + dst.write_bytes(data) + retry_task.files.append( + ImportFileRecord( + file_id=file_id, + name=name, + source_kind="upload", + input_mode=failed_file.input_mode, + temp_path=str(dst), + retry_mode=retry_mode, + retry_chunk_indexes=retry_chunk_indexes, + ) + ) + return True, "" + + if source_kind == "paste": + if failed_file.inline_content is None: + return False, "paste_content_missing" + retry_task.files.append( + ImportFileRecord( + file_id=uuid.uuid4().hex, + name=_safe_filename(failed_file.name), + source_kind="paste", + input_mode=failed_file.input_mode, + inline_content=failed_file.inline_content, + retry_mode=retry_mode, + retry_chunk_indexes=retry_chunk_indexes, + ) + ) + return True, "" + + if source_kind == "maibot_migration": + retry_task.files.append( + ImportFileRecord( + file_id=uuid.uuid4().hex, + name=_safe_filename(failed_file.name), + source_kind="maibot_migration", + input_mode="text", + inline_content=failed_file.inline_content, + retry_mode="file_fallback", + retry_chunk_indexes=[], + ) + ) + return True, "" + + if source_kind in {"raw_scan", "lpmm_openie", "lpmm_convert", "temporal_backfill"}: + retry_task.files.append( + ImportFileRecord( + file_id=uuid.uuid4().hex, + name=_safe_filename(failed_file.name), + source_kind=source_kind, + input_mode=failed_file.input_mode, + source_path=failed_file.source_path, + inline_content=failed_file.inline_content, + retry_mode=retry_mode, + 
retry_chunk_indexes=retry_chunk_indexes, + ) + ) + return True, "" + + return False, f"unsupported_source_kind:{source_kind or 'unknown'}" + + async def retry_failed(self, task_id: str, overrides: Optional[Dict[str, Any]] = None) -> Optional[Dict[str, Any]]: + async with self._lock: + task = self._tasks.get(task_id) + if not task: + return None + retry_plan = self._build_retry_plan(task) + chunk_retry_candidates = list(retry_plan["chunk_retry_candidates"]) + file_fallback_candidates = list(retry_plan["file_fallback_candidates"]) + skipped_candidates = list(retry_plan["skipped"]) + if not chunk_retry_candidates and not file_fallback_candidates: + raise ValueError("当前任务没有可重试失败项") + base_params = dict(task.params) + task_kind = str(task.params.get("task_kind") or "").strip().lower() + + if overrides: + base_params.update(overrides) + params = self._normalize_by_task_kind(task_kind, base_params) + params["retry_parent_task_id"] = task_id + params["retry_strategy"] = "chunk_first_auto_file_fallback" + + async with self._lock: + if self._pending_task_count() >= self._queue_limit(): + raise ValueError("任务队列已满,请稍后重试") + retry_task = ImportTaskRecord( + task_id=uuid.uuid4().hex, + source=task.source, + params=params, + status="queued", + current_step="queued", + schema_detected=task.schema_detected, + retry_parent_task_id=task_id, + ) + + task_dir = self._temp_root / retry_task.task_id + task_dir.mkdir(parents=True, exist_ok=True) + + retry_summary = { + "chunk_retry_files": 0, + "chunk_retry_chunks": 0, + "file_fallback_files": 0, + "skipped_files": 0, + "parent_task_id": task_id, + } + skipped_details = list(skipped_candidates) + + for file_obj, chunk_indexes in chunk_retry_candidates: + ok, reason = self._clone_failed_file_for_retry( + retry_task, + file_obj, + task_dir, + retry_mode="chunk", + retry_chunk_indexes=chunk_indexes, + ) + if ok: + retry_summary["chunk_retry_files"] += 1 + retry_summary["chunk_retry_chunks"] += len(chunk_indexes) + else: + skipped_details.append( + { + "file_name": file_obj.name, + "source_kind": file_obj.source_kind, + "reason": reason, + } + ) + + for file_obj in file_fallback_candidates: + ok, reason = self._clone_failed_file_for_retry( + retry_task, + file_obj, + task_dir, + retry_mode="file_fallback", + retry_chunk_indexes=[], + ) + if ok: + retry_summary["file_fallback_files"] += 1 + else: + skipped_details.append( + { + "file_name": file_obj.name, + "source_kind": file_obj.source_kind, + "reason": reason, + } + ) + + retry_summary["skipped_files"] = len(skipped_details) + if skipped_details: + retry_summary["skipped_details"] = skipped_details + retry_task.retry_summary = retry_summary + + if not retry_task.files: + raise ValueError("无可执行的重试输入:失败项均无法构建重试任务") + + self._tasks[retry_task.task_id] = retry_task + self._task_order.appendleft(retry_task.task_id) + self._queue.append(retry_task.task_id) + logger.info( + "重试任务已创建 " + f"parent={task_id} retry={retry_task.task_id} " + f"chunk_files={retry_summary['chunk_retry_files']} " + f"chunk_chunks={retry_summary['chunk_retry_chunks']} " + f"file_fallback={retry_summary['file_fallback_files']} " + f"skipped={retry_summary['skipped_files']}" + ) + + await self._ensure_worker() + return retry_task.to_summary() + + async def shutdown(self) -> None: + async with self._lock: + self._stopping = True + for task in self._tasks.values(): + if task.status in {"queued", "preparing", "running", "cancel_requested"}: + self._mark_task_cancelled_locked(task, "服务关闭") + self._queue.clear() + worker = self._worker_task + 
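+            # 锁内只摘除 worker 引用并标记 stopping,cancel/await 放在锁外执行,
+            # 避免持锁等待 worker 退出时与队列操作互相阻塞。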
self._worker_task = None + + if worker: + worker.cancel() + try: + await worker + except asyncio.CancelledError: + pass + except Exception: + pass + + self._cleanup_temp_root() + + def _cleanup_temp_root(self) -> None: + try: + if not self._temp_root.exists(): + return + for child in self._temp_root.rglob("*"): + if child.is_file(): + child.unlink(missing_ok=True) + for child in sorted(self._temp_root.rglob("*"), reverse=True): + if child.is_dir(): + child.rmdir() + self._temp_root.rmdir() + except Exception as e: + logger.warning(f"清理临时导入目录失败: {e}") + + async def _worker_loop(self) -> None: + logger.info("Web 导入任务 worker 已启动") + while True: + if self._stopping: + break + + task_id: Optional[str] = None + async with self._lock: + while self._queue: + candidate = self._queue.popleft() + t = self._tasks.get(candidate) + if not t: + continue + if t.status == "cancelled": + continue + task_id = candidate + self._active_task_id = candidate + break + + if not task_id: + await asyncio.sleep(0.2) + continue + + try: + await self._run_task(task_id) + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"导入任务执行失败 task={task_id}: {e}\n{traceback.format_exc()}") + async with self._lock: + task = self._tasks.get(task_id) + if task and task.status not in {"cancelled", "completed", "completed_with_errors"}: + task.status = "failed" + task.current_step = "failed" + task.error = str(e) + task.finished_at = _now() + task.updated_at = _now() + finally: + should_cleanup = await self._should_cleanup_task_temp(task_id) + async with self._lock: + if self._active_task_id == task_id: + self._active_task_id = None + if should_cleanup: + await self._cleanup_task_temp_files(task_id) + + logger.info("Web 导入任务 worker 已停止") + + async def _cleanup_task_temp_files(self, task_id: str) -> None: + task_dir = self._temp_root / task_id + if not task_dir.exists(): + return + try: + for child in task_dir.rglob("*"): + if child.is_file(): + child.unlink(missing_ok=True) + for child in sorted(task_dir.rglob("*"), reverse=True): + if child.is_dir(): + child.rmdir() + task_dir.rmdir() + except Exception as e: + logger.warning(f"清理任务临时文件失败 task={task_id}: {e}") + + def _task_report_path(self, task_id: str) -> Path: + self._reports_root.mkdir(parents=True, exist_ok=True) + return self._reports_root / f"{task_id}_summary.json" + + def _write_task_report(self, task: ImportTaskRecord) -> None: + path = self._task_report_path(task.task_id) + payload = task.to_detail(include_chunks=False) + payload["generated_at"] = _now() + path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8") + task.artifact_paths["summary"] = str(path) + + async def _run_task(self, task_id: str) -> None: + async with self._lock: + task = self._tasks.get(task_id) + if not task: + return + task.status = "preparing" + task.current_step = "preparing" + task.started_at = _now() + task.updated_at = _now() + if task.params.get("clear_manifest"): + self._clear_manifest() + + async with self._lock: + task = self._tasks.get(task_id) + if not task: + return + if task.status == "cancel_requested": + task.status = "cancelled" + task.current_step = "cancelled" + task.finished_at = _now() + task.updated_at = _now() + return + task.status = "running" + task.current_step = "running" + task.updated_at = _now() + + task_kind = str(task.params.get("task_kind") or task.source).strip().lower() + if task_kind == "maibot_migration": + if not task.files: + raise RuntimeError("迁移任务缺少文件记录") + await self._process_maibot_migration(task_id, 
task.files[0])
+        elif task_kind == "temporal_backfill":
+            if not task.files:
+                raise RuntimeError("回填任务缺少文件记录")
+            await self._process_temporal_backfill(task_id, task.files[0])
+        elif task_kind == "lpmm_convert":
+            if not task.files:
+                raise RuntimeError("转换任务缺少文件记录")
+            await self._process_lpmm_convert(task_id, task.files[0])
+        else:
+            file_semaphore = asyncio.Semaphore(task.params["file_concurrency"])
+            chunk_semaphore = asyncio.Semaphore(task.params["chunk_concurrency"])
+            jobs = [
+                asyncio.create_task(self._process_file(task_id, f, file_semaphore, chunk_semaphore))
+                for f in task.files
+            ]
+            await asyncio.gather(*jobs, return_exceptions=True)
+
+        write_changed_payload: Optional[Dict[str, Any]] = None
+        async with self._lock:
+            task = self._tasks.get(task_id)
+            if not task:
+                return
+            self._recompute_task_progress(task)
+            has_failed = any(
+                (f.status == "failed")
+                or (f.failed_chunks > 0)
+                or bool(str(f.error or "").strip())
+                for f in task.files
+            )
+            has_cancelled = any(f.status == "cancelled" for f in task.files)
+            has_completed = any(f.status == "completed" for f in task.files)
+
+            # 统一按文件真实终态收敛任务状态,避免出现“任务已取消但文件已完成”的矛盾结果。
+            if has_failed and not has_cancelled:
+                task.status = "completed_with_errors"
+                task.current_step = "completed_with_errors"
+            elif has_cancelled:
+                task.status = "cancelled"
+                task.current_step = "cancelled"
+            else:
+                task.status = "completed"
+                task.current_step = "completed"
+            task.finished_at = _now()
+            task.updated_at = _now()
+            try:
+                self._write_task_report(task)
+            except Exception as report_err:
+                logger.warning(f"写入任务报告失败 task={task_id}: {report_err}")
+            task_kind = str(task.params.get("task_kind") or task.source).strip().lower()
+            write_task_kinds = {"upload", "paste", "raw_scan", "lpmm_openie", "maibot_migration", "lpmm_convert"}
+            has_written_chunks = (task.done_chunks > 0) or any(f.done_chunks > 0 for f in task.files)
+            if task_kind in write_task_kinds and has_written_chunks:
+                write_changed_payload = {
+                    "task_id": task.task_id,
+                    "task_kind": task_kind,
+                    "status": task.status,
+                    "done_chunks": task.done_chunks,
+                    "finished_at": task.finished_at,
+                }
+
+        if write_changed_payload:
+            await self._notify_write_changed(write_changed_payload)
+
+    def _build_maibot_migration_command(self, params: Dict[str, Any]) -> List[str]:
+        script_path = self._resolve_migration_script()
+        if not script_path.exists():
+            raise RuntimeError(f"迁移脚本不存在: {script_path}")
+
+        cmd = [
+            sys.executable,
+            str(script_path),
+            "--source-db",
+            str(params["source_db"]),
+            "--target-data-dir",
+            str(params["target_data_dir"]),
+            "--read-batch-size",
+            str(params["read_batch_size"]),
+            "--commit-window-rows",
+            str(params["commit_window_rows"]),
+            "--embed-batch-size",
+            str(params["embed_batch_size"]),
+            "--entity-embed-batch-size",
+            str(params["entity_embed_batch_size"]),
+            "--max-errors",
+            str(params["max_errors"]),
+            "--log-every",
+            str(params["log_every"]),
+            "--preview-limit",
+            str(params["preview_limit"]),
+            "--yes",
+        ]
+
+        if params.get("embed_workers") is not None:
+            cmd.extend(["--embed-workers", str(params["embed_workers"])])
+        if params.get("start_id") is not None:
+            cmd.extend(["--start-id", str(params["start_id"])])
+        if params.get("end_id") is not None:
+            cmd.extend(["--end-id", str(params["end_id"])])
+        if params.get("time_from"):
+            cmd.extend(["--time-from", str(params["time_from"])])
+        if params.get("time_to"):
+            
cmd.extend(["--time-to", str(params["time_to"])]) + + for sid in params.get("stream_ids") or []: + cmd.extend(["--stream-id", str(sid)]) + for gid in params.get("group_ids") or []: + cmd.extend(["--group-id", str(gid)]) + for uid in params.get("user_ids") or []: + cmd.extend(["--user-id", str(uid)]) + + if params.get("reset_state"): + cmd.append("--reset-state") + if params.get("no_resume"): + cmd.append("--no-resume") + if params.get("dry_run"): + cmd.append("--dry-run") + if params.get("verify_only"): + cmd.append("--verify-only") + + return cmd + + async def _ensure_maibot_migration_chunk( + self, + task_id: str, + file_id: str, + *, + chunk_type: str = "maibot_migration", + preview: str = "MaiBot chat_history 迁移任务", + ) -> str: + chunk_id = f"{file_id}_{chunk_type}" + async with self._lock: + task = self._tasks.get(task_id) + if not task: + return chunk_id + f = self._find_file(task, file_id) + if not f: + return chunk_id + if not f.chunks: + f.chunks = [ + ImportChunkRecord( + chunk_id=chunk_id, + index=0, + chunk_type=chunk_type, + status="queued", + step="queued", + progress=0.0, + content_preview=preview, + ) + ] + f.total_chunks = 1 + f.done_chunks = 0 + f.failed_chunks = 0 + f.cancelled_chunks = 0 + f.progress = 0.0 + f.updated_at = _now() + self._recompute_task_progress(task) + else: + chunk_id = f.chunks[0].chunk_id + return chunk_id + + async def _refresh_maibot_progress_from_state( + self, + task_id: str, + file_id: str, + chunk_id: str, + state_path: Path, + ) -> None: + if not state_path.exists(): + return + try: + payload = json.loads(state_path.read_text(encoding="utf-8")) + except Exception: + return + + stats = payload.get("stats", {}) if isinstance(payload, dict) else {} + if not isinstance(stats, dict): + stats = {} + + total = max(0, _coerce_int(stats.get("source_matched_total", 0), 0)) + scanned = max(0, _coerce_int(stats.get("scanned_rows", 0), 0)) + bad = max(0, _coerce_int(stats.get("bad_rows", 0), 0)) + done = max(0, scanned - bad) + migrated = max(0, _coerce_int(stats.get("migrated_rows", 0), 0)) + last_id = max(0, _coerce_int(stats.get("last_committed_id", 0), 0)) + + if total <= 0: + total = max(1, scanned) + + progress = max(0.0, min(1.0, float(scanned) / float(total))) if total > 0 else 0.0 + preview = f"scanned={scanned}/{total}, migrated={migrated}, bad={bad}, last_id={last_id}" + + async with self._lock: + task = self._tasks.get(task_id) + if not task: + return + f = self._find_file(task, file_id) + if not f: + return + c = self._find_chunk(f, chunk_id) + if c: + if c.status not in {"completed", "failed", "cancelled"}: + c.status = "writing" + c.step = "migrating" + c.progress = progress + c.content_preview = preview + c.updated_at = _now() + f.total_chunks = total + f.done_chunks = done + f.failed_chunks = bad + f.cancelled_chunks = 0 + f.progress = progress + if f.status not in {"failed", "cancelled"}: + f.status = "writing" + f.current_step = "migrating" + f.updated_at = _now() + self._recompute_task_progress(task) + + async def _terminate_process(self, process: asyncio.subprocess.Process) -> None: + if process.returncode is not None: + return + try: + process.terminate() + await asyncio.wait_for(process.wait(), timeout=5.0) + except Exception: + try: + process.kill() + await asyncio.wait_for(process.wait(), timeout=3.0) + except Exception: + pass + + async def _reload_stores_after_external_migration(self) -> None: + async with self._storage_lock: + try: + if self.plugin.vector_store and self.plugin.vector_store.has_data(): + 
self.plugin.vector_store.load() + except Exception as e: + logger.warning(f"迁移后重载 VectorStore 失败: {e}") + try: + if self.plugin.graph_store and self.plugin.graph_store.has_data(): + self.plugin.graph_store.load() + except Exception as e: + logger.warning(f"迁移后重载 GraphStore 失败: {e}") + + async def _process_maibot_migration(self, task_id: str, file_record: ImportFileRecord) -> None: + await self._set_file_strategy(task_id, file_record.file_id, "maibot_migration") + await self._set_file_state(task_id, file_record.file_id, "preparing", "preparing") + chunk_id = await self._ensure_maibot_migration_chunk( + task_id, + file_record.file_id, + chunk_type="maibot_migration", + preview="MaiBot chat_history 迁移任务", + ) + await self._set_chunk_state(task_id, file_record.file_id, chunk_id, "writing", "migrating", 0.0) + + task = self._tasks.get(task_id) + if not task: + await self._set_file_failed(task_id, file_record.file_id, "任务不存在") + return + params = dict(task.params) + + command = self._build_maibot_migration_command(params) + project_root = self._resolve_repo_root() + state_path = Path(params["target_data_dir"]) / "migration_state" / "chat_history_resume.json" + report_path = Path(params["target_data_dir"]) / "migration_state" / "chat_history_report.json" + + logger.info(f"开始执行 MaiBot 迁移任务: {' '.join(command)}") + process = await asyncio.create_subprocess_exec( + *command, + cwd=str(project_root), + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + + stdout_lines: List[str] = [] + stderr_lines: List[str] = [] + + async def _drain(stream: Optional[asyncio.StreamReader], target: List[str]) -> None: + if stream is None: + return + while True: + line = await stream.readline() + if not line: + break + text = line.decode("utf-8", errors="replace").strip() + if not text: + continue + target.append(text) + if len(target) > 120: + del target[:-120] + + drain_tasks = [ + asyncio.create_task(_drain(process.stdout, stdout_lines)), + asyncio.create_task(_drain(process.stderr, stderr_lines)), + ] + + cancelled = False + return_code: Optional[int] = None + try: + while True: + if await self._is_cancel_requested(task_id): + cancelled = True + await self._terminate_process(process) + break + + await self._refresh_maibot_progress_from_state(task_id, file_record.file_id, chunk_id, state_path) + try: + return_code = await asyncio.wait_for(process.wait(), timeout=1.0) + break + except asyncio.TimeoutError: + continue + finally: + await asyncio.gather(*drain_tasks, return_exceptions=True) + + if cancelled: + await self._set_chunk_cancelled(task_id, file_record.file_id, chunk_id, "任务已取消") + await self._set_file_cancelled(task_id, file_record.file_id, "任务已取消") + return + + await self._refresh_maibot_progress_from_state(task_id, file_record.file_id, chunk_id, state_path) + + report: Dict[str, Any] = {} + if report_path.exists(): + try: + report = json.loads(report_path.read_text(encoding="utf-8")) + except Exception: + report = {} + + stats = report.get("stats", {}) if isinstance(report, dict) else {} + if not isinstance(stats, dict): + stats = {} + bad_rows = max(0, _coerce_int(stats.get("bad_rows", 0), 0)) + + if return_code in {0, 2}: + await self._set_file_state(task_id, file_record.file_id, "saving", "saving") + await self._reload_stores_after_external_migration() + + async with self._lock: + task2 = self._tasks.get(task_id) + if not task2: + return + f = self._find_file(task2, file_record.file_id) + if not f: + return + c = self._find_chunk(f, chunk_id) + if c and c.status not in {"cancelled", 
"failed"}: + c.status = "completed" + c.step = "completed" + c.progress = 1.0 + c.updated_at = _now() + if f.total_chunks <= 0: + f.total_chunks = 1 + if f.done_chunks + f.failed_chunks <= 0: + f.done_chunks = f.total_chunks - bad_rows + f.failed_chunks = bad_rows + f.done_chunks = max(0, min(f.done_chunks, f.total_chunks)) + f.failed_chunks = max(0, min(f.failed_chunks, f.total_chunks)) + f.cancelled_chunks = 0 + f.progress = 1.0 + f.status = "completed" + f.current_step = "completed" + if bad_rows > 0 and not f.error: + f.error = f"迁移完成,但存在坏行: {bad_rows}" + f.updated_at = _now() + self._recompute_task_progress(task2) + return + + fail_reason = "" + if isinstance(report, dict): + fail_reason = str(report.get("fail_reason") or "").strip() + tail = (stderr_lines[-1] if stderr_lines else "") or (stdout_lines[-1] if stdout_lines else "") + detail = fail_reason or tail or f"迁移进程退出码: {return_code}" + await self._set_chunk_failed(task_id, file_record.file_id, chunk_id, detail) + await self._set_file_failed(task_id, file_record.file_id, detail) + + def _resolve_convert_script(self) -> Path: + return Path(__file__).resolve().parents[2] / "scripts" / "convert_lpmm.py" + + def _cleanup_old_backups(self) -> None: + keep = max(0, self._cfg_int("web.import.convert.keep_backup_count", 3)) + backup_root = self._resolve_backup_root() + if not backup_root.exists() or keep <= 0: + return + dirs = [p for p in backup_root.iterdir() if p.is_dir() and p.name.startswith("lpmm_convert_")] + dirs.sort(key=lambda p: p.stat().st_mtime, reverse=True) + for old in dirs[keep:]: + try: + shutil.rmtree(old, ignore_errors=True) + except Exception: + pass + + def _verify_convert_output(self, output_dir: Path) -> Dict[str, Any]: + vectors = output_dir / "vectors" + graph = output_dir / "graph" + metadata = output_dir / "metadata" + checks = { + "vectors_exists": vectors.exists(), + "graph_exists": graph.exists(), + "metadata_exists": metadata.exists(), + "vectors_nonempty": vectors.exists() and any(vectors.iterdir()), + "graph_nonempty": graph.exists() and any(graph.iterdir()), + "metadata_nonempty": metadata.exists() and any(metadata.iterdir()), + } + checks["ok"] = checks["vectors_exists"] and checks["graph_exists"] and checks["metadata_exists"] + return checks + + async def _preflight_convert_runtime(self) -> Tuple[bool, str]: + """使用当前服务解释器做 convert 依赖预检,避免子进程报错信息不透明。""" + probe_code = ( + "import importlib\n" + "mods=['networkx','scipy','pyarrow']\n" + "failed=[]\n" + "for m in mods:\n" + " try:\n" + " importlib.import_module(m)\n" + " except Exception as e:\n" + " failed.append(f'{m}:{e.__class__.__name__}:{e}')\n" + "print('OK' if not failed else ';'.join(failed))\n" + ) + try: + probe = await asyncio.create_subprocess_exec( + sys.executable, + "-c", + probe_code, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + stdout, stderr = await asyncio.wait_for(probe.communicate(), timeout=20.0) + except Exception as e: + return False, f"依赖预检执行失败: {e}" + + out = (stdout or b"").decode("utf-8", errors="replace").strip() + err = (stderr or b"").decode("utf-8", errors="replace").strip() + if probe.returncode != 0: + detail = err or out or f"return_code={probe.returncode}" + return False, f"依赖预检失败 (python={sys.executable}): {detail}" + if out != "OK": + return False, f"依赖预检失败 (python={sys.executable}): {out}" + return True, "" + + async def _process_lpmm_convert(self, task_id: str, file_record: ImportFileRecord) -> None: + await self._set_file_strategy(task_id, file_record.file_id, "lpmm_convert") + await 
self._set_file_state(task_id, file_record.file_id, "preparing", "preflight") + chunk_id = await self._ensure_maibot_migration_chunk( + task_id, + file_record.file_id, + chunk_type="lpmm_convert", + preview="LPMM 二进制转换任务", + ) + await self._set_chunk_state(task_id, file_record.file_id, chunk_id, "writing", "converting", 0.05) + + task = self._tasks.get(task_id) + if not task: + await self._set_file_failed(task_id, file_record.file_id, "任务不存在") + return + params = dict(task.params) + source_dir = Path(params.get("source_path") or "") + target_dir = Path(params.get("target_path") or "") + if not source_dir.exists() or not source_dir.is_dir(): + await self._set_file_failed(task_id, file_record.file_id, f"输入目录无效: {source_dir}") + return + if not target_dir.exists() or not target_dir.is_dir(): + await self._set_file_failed(task_id, file_record.file_id, f"目标目录无效: {target_dir}") + return + + script_path = self._resolve_convert_script() + if not script_path.exists(): + await self._set_file_failed(task_id, file_record.file_id, f"转换脚本不存在: {script_path}") + return + + runtime_ok, runtime_detail = await self._preflight_convert_runtime() + if not runtime_ok: + await self._set_file_failed(task_id, file_record.file_id, runtime_detail) + await self._set_chunk_failed(task_id, file_record.file_id, chunk_id, runtime_detail) + return + + required_inputs = ["paragraph.parquet", "entity.parquet"] + if not any((source_dir / name).exists() for name in required_inputs): + await self._set_file_failed( + task_id, + file_record.file_id, + f"输入目录缺少必要文件,至少需要其一: {', '.join(required_inputs)}", + ) + return + + staging_root = self._resolve_staging_root() + staging_root.mkdir(parents=True, exist_ok=True) + staging_dir = staging_root / f"lpmm_convert_{task_id}" + if staging_dir.exists(): + shutil.rmtree(staging_dir, ignore_errors=True) + staging_dir.mkdir(parents=True, exist_ok=True) + + # 简单空间预检:至少保留 512MB + usage = shutil.disk_usage(str(target_dir)) + if usage.free < 512 * 1024 * 1024: + await self._set_file_failed(task_id, file_record.file_id, "磁盘剩余空间不足(<512MB)") + return + + cmd = [ + sys.executable, + str(script_path), + "--input", + str(source_dir), + "--output", + str(staging_dir), + "--dim", + str(params.get("dimension", 384)), + "--batch-size", + str(params.get("batch_size", 1024)), + ] + process = await asyncio.create_subprocess_exec( + *cmd, + cwd=str(self._resolve_repo_root()), + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + stdout_lines: List[str] = [] + stderr_lines: List[str] = [] + + async def _drain(stream: Optional[asyncio.StreamReader], target: List[str]) -> None: + if stream is None: + return + while True: + line = await stream.readline() + if not line: + break + text = line.decode("utf-8", errors="replace").strip() + if text: + target.append(text) + if len(target) > 120: + del target[:-120] + + drain_tasks = [ + asyncio.create_task(_drain(process.stdout, stdout_lines)), + asyncio.create_task(_drain(process.stderr, stderr_lines)), + ] + + cancelled = False + return_code: Optional[int] = None + try: + while True: + if await self._is_cancel_requested(task_id): + cancelled = True + await self._terminate_process(process) + break + try: + return_code = await asyncio.wait_for(process.wait(), timeout=1.0) + break + except asyncio.TimeoutError: + continue + finally: + await asyncio.gather(*drain_tasks, return_exceptions=True) + + if cancelled: + await self._set_chunk_cancelled(task_id, file_record.file_id, chunk_id, "任务已取消") + await self._set_file_cancelled(task_id, file_record.file_id, 
"任务已取消") + return + if return_code != 0: + detail = (stderr_lines[-1] if stderr_lines else "") or (stdout_lines[-1] if stdout_lines else "") + await self._set_file_failed(task_id, file_record.file_id, detail or f"转换失败,退出码: {return_code}") + await self._set_chunk_failed(task_id, file_record.file_id, chunk_id, detail or f"退出码: {return_code}") + return + + await self._set_chunk_state(task_id, file_record.file_id, chunk_id, "writing", "verifying", 0.65) + verify = self._verify_convert_output(staging_dir) + async with self._lock: + t = self._tasks.get(task_id) + if t: + t.artifact_paths["staging_dir"] = str(staging_dir) + t.artifact_paths["verify"] = json.dumps(verify, ensure_ascii=False) + if not verify.get("ok"): + await self._set_file_failed(task_id, file_record.file_id, f"校验失败: {verify}") + await self._set_chunk_failed(task_id, file_record.file_id, chunk_id, f"校验失败: {verify}") + return + + enable_switch = _coerce_bool(self._cfg("web.import.convert.enable_staging_switch", True), True) + if not enable_switch: + await self._set_file_failed(task_id, file_record.file_id, "未启用 staging 切换") + await self._set_chunk_failed(task_id, file_record.file_id, chunk_id, "未启用 staging 切换") + return + + await self._set_chunk_state(task_id, file_record.file_id, chunk_id, "writing", "switching", 0.85) + backup_root = self._resolve_backup_root() + backup_root.mkdir(parents=True, exist_ok=True) + backup_dir = backup_root / f"lpmm_convert_{task_id}_{int(_now())}" + backup_dir.mkdir(parents=True, exist_ok=True) + + switched = False + rollback_info: Dict[str, Any] = {"attempted": True, "restored": False, "error": ""} + moved_items: List[Tuple[Path, Path]] = [] + try: + for name in ("vectors", "graph", "metadata"): + src_current = target_dir / name + src_new = staging_dir / name + if not src_new.exists(): + raise RuntimeError(f"staging 缺少目录: {src_new}") + if src_current.exists(): + dst_backup = backup_dir / name + shutil.move(str(src_current), str(dst_backup)) + moved_items.append((dst_backup, src_current)) + shutil.move(str(src_new), str(src_current)) + switched = True + except Exception as switch_err: + rollback_info["error"] = str(switch_err) + # 尝试回滚 + for src_backup, dst_original in moved_items: + if src_backup.exists() and not dst_original.exists(): + try: + shutil.move(str(src_backup), str(dst_original)) + except Exception: + pass + rollback_info["restored"] = True + async with self._lock: + t = self._tasks.get(task_id) + if t: + t.rollback_info = rollback_info + await self._set_file_failed(task_id, file_record.file_id, f"切换失败并回滚: {switch_err}") + await self._set_chunk_failed(task_id, file_record.file_id, chunk_id, f"switch failed: {switch_err}") + return + + if switched: + async with self._lock: + t = self._tasks.get(task_id) + if t: + t.rollback_info = rollback_info + t.artifact_paths["backup_dir"] = str(backup_dir) + self._cleanup_old_backups() + try: + await self._reload_stores_after_external_migration() + except Exception as reload_err: + logger.warning(f"转换后重载存储失败: {reload_err}") + + await self._set_chunk_completed(task_id, file_record.file_id, chunk_id) + async with self._lock: + t = self._tasks.get(task_id) + if not t: + return + f = self._find_file(t, file_record.file_id) + if not f: + return + f.total_chunks = 1 + f.done_chunks = 1 + f.failed_chunks = 0 + f.cancelled_chunks = 0 + f.progress = 1.0 + f.status = "completed" + f.current_step = "completed" + f.updated_at = _now() + self._recompute_task_progress(t) + + async def _process_temporal_backfill(self, task_id: str, file_record: ImportFileRecord) 
-> None: + await self._set_file_strategy(task_id, file_record.file_id, "temporal_backfill") + await self._set_file_state(task_id, file_record.file_id, "preparing", "backfilling") + chunk_id = await self._ensure_maibot_migration_chunk( + task_id, + file_record.file_id, + chunk_type="temporal_backfill", + preview="时序字段回填任务", + ) + await self._set_chunk_state(task_id, file_record.file_id, chunk_id, "writing", "backfilling", 0.2) + + task = self._tasks.get(task_id) + if not task: + await self._set_file_failed(task_id, file_record.file_id, "任务不存在") + return + params = dict(task.params) + target_dir = Path(file_record.source_path or "") + metadata_dir = target_dir / "metadata" + if not metadata_dir.exists(): + await self._set_file_failed(task_id, file_record.file_id, f"metadata 目录不存在: {metadata_dir}") + return + + dry_run = bool(params.get("dry_run")) + no_created_fallback = bool(params.get("no_created_fallback")) + limit = max(1, _coerce_int(params.get("limit"), 100000)) + + store = MetadataStore(data_dir=metadata_dir) + updated = 0 + candidates = 0 + try: + store.connect() + summary = store.backfill_temporal_metadata_from_created_at( + limit=limit, + dry_run=dry_run, + no_created_fallback=no_created_fallback, + ) + candidates = int(summary.get("candidates", 0)) + updated = int(summary.get("updated", 0)) + finally: + try: + store.close() + except Exception: + pass + + async with self._lock: + t = self._tasks.get(task_id) + if t: + t.artifact_paths["temporal_backfill"] = json.dumps( + { + "target_dir": str(target_dir), + "dry_run": dry_run, + "no_created_fallback": no_created_fallback, + "limit": limit, + "candidates": candidates, + "updated": updated, + }, + ensure_ascii=False, + ) + await self._set_chunk_completed(task_id, file_record.file_id, chunk_id) + async with self._lock: + t = self._tasks.get(task_id) + if not t: + return + f = self._find_file(t, file_record.file_id) + if not f: + return + f.total_chunks = 1 + f.done_chunks = 1 + f.failed_chunks = 0 + f.cancelled_chunks = 0 + f.progress = 1.0 + f.status = "completed" + f.current_step = "completed" + f.updated_at = _now() + self._recompute_task_progress(t) + + async def _process_file( + self, + task_id: str, + file_record: ImportFileRecord, + file_semaphore: asyncio.Semaphore, + chunk_semaphore: asyncio.Semaphore, + ) -> None: + async with file_semaphore: + await self._set_file_state(task_id, file_record.file_id, "preparing", "preparing") + if await self._is_cancel_requested(task_id): + await self._set_file_cancelled(task_id, file_record.file_id, "任务已取消") + return + + try: + content = await self._read_file_content(file_record) + content_hash = hashlib.md5(content.encode("utf-8", errors="ignore")).hexdigest() + file_record.content_hash = content_hash + task = self._tasks.get(task_id) + if task: + dedupe_policy = str(task.params.get("dedupe_policy") or "none") + force = bool(task.params.get("force")) + if dedupe_policy != "none" and not force: + async with self._lock: + if self._is_manifest_hit(file_record, content_hash, dedupe_policy): + task2 = self._tasks.get(task_id) + if task2: + f = self._find_file(task2, file_record.file_id) + if f: + f.status = "completed" + f.current_step = "skipped" + f.progress = 1.0 + f.total_chunks = 0 + f.done_chunks = 0 + f.failed_chunks = 0 + f.cancelled_chunks = 0 + f.detected_strategy_type = "skipped" + f.error = "" + f.updated_at = _now() + self._recompute_task_progress(task2) + return + if file_record.input_mode == "json": + await self._process_json_file(task_id, file_record, content, chunk_semaphore) + 
else: + await self._process_text_file(task_id, file_record, content, chunk_semaphore) + task3 = self._tasks.get(task_id) + if task3: + dedupe_policy = str(task3.params.get("dedupe_policy") or "none") + f3 = self._find_file(task3, file_record.file_id) + if dedupe_policy != "none" and f3 and f3.status == "completed": + async with self._lock: + self._record_manifest_import(file_record, content_hash, dedupe_policy, task_id) + except Exception as e: + await self._set_file_failed(task_id, file_record.file_id, str(e)) + + async def _read_file_content(self, file_record: ImportFileRecord) -> str: + if file_record.inline_content is not None: + return file_record.inline_content + if file_record.source_path and Path(file_record.source_path).exists(): + data = Path(file_record.source_path).read_bytes() + try: + return data.decode("utf-8") + except UnicodeDecodeError: + return data.decode("utf-8", errors="replace") + if file_record.temp_path and Path(file_record.temp_path).exists(): + data = Path(file_record.temp_path).read_bytes() + try: + return data.decode("utf-8") + except UnicodeDecodeError: + return data.decode("utf-8", errors="replace") + raise RuntimeError("读取文件失败:输入内容缺失") + + async def _process_text_file( + self, + task_id: str, + file_record: ImportFileRecord, + content: str, + chunk_semaphore: asyncio.Semaphore, + ) -> None: + task = self._tasks[task_id] + async with self._lock: + t = self._tasks.get(task_id) + if t and not t.schema_detected: + t.schema_detected = "plain_text" + strategy = self._determine_strategy( + file_record.name, + content, + task.params["strategy_override"], + chat_log=bool(task.params.get("chat_log")), + ) + await self._set_file_strategy(task_id, file_record.file_id, strategy) + await self._set_file_state(task_id, file_record.file_id, "splitting", "splitting") + await self._ensure_embedding_runtime_ready() + + chunks = strategy.split(content) + selected_chunks = list(chunks) + if file_record.retry_mode == "chunk": + retry_index_set = set() + for idx in file_record.retry_chunk_indexes: + try: + retry_index_set.add(int(idx)) + except Exception: + continue + selected_chunks = [chunk for chunk in chunks if int(chunk.chunk.index) in retry_index_set] + if not selected_chunks: + raise RuntimeError("失败分块重试索引无效,未匹配到可执行分块") + logger.info( + "重试任务按失败分块执行: " + f"file={file_record.name} " + f"selected={len(selected_chunks)} " + f"total={len(chunks)}" + ) + + await self._register_chunks(task_id, file_record.file_id, selected_chunks) + + await self._set_file_state(task_id, file_record.file_id, "extracting", "extracting") + model_cfg = None + if task.params["llm_enabled"]: + model_cfg = await self._select_model() + + jobs = [] + for chunk in selected_chunks: + jobs.append( + asyncio.create_task( + self._process_text_chunk( + task_id=task_id, + file_record=file_record, + chunk=chunk, + strategy=strategy, + llm_enabled=task.params["llm_enabled"], + model_cfg=model_cfg, + chunk_semaphore=chunk_semaphore, + chat_log=bool(task.params.get("chat_log")), + chat_reference_time=str(task.params.get("chat_reference_time") or "").strip() or None, + ) + ) + ) + await asyncio.gather(*jobs, return_exceptions=True) + + if await self._is_cancel_requested(task_id): + await self._set_file_cancelled(task_id, file_record.file_id, "任务已取消") + return + + await self._set_file_state(task_id, file_record.file_id, "saving", "saving") + async with self._storage_lock: + self.plugin.vector_store.save() + self.plugin.graph_store.save() + + async with self._lock: + task = self._tasks.get(task_id) + if not task: + 
return + f = self._find_file(task, file_record.file_id) + if not f: + return + if f.failed_chunks > 0: + f.status = "failed" + f.current_step = "failed" + if not f.error: + f.error = f"存在失败分块: {f.failed_chunks}" + elif task.status == "cancel_requested": + f.status = "cancelled" + f.current_step = "cancelled" + else: + f.status = "completed" + f.current_step = "completed" + f.progress = 1.0 + f.updated_at = _now() + self._recompute_task_progress(task) + async def _process_text_chunk( + self, + task_id: str, + file_record: ImportFileRecord, + chunk: ProcessedChunk, + strategy: Any, + llm_enabled: bool, + model_cfg: Any, + chunk_semaphore: asyncio.Semaphore, + chat_log: bool = False, + chat_reference_time: Optional[str] = None, + ) -> None: + async with chunk_semaphore: + chunk_id = chunk.chunk.chunk_id + if await self._is_cancel_requested(task_id): + await self._set_chunk_cancelled(task_id, file_record.file_id, chunk_id, "任务已取消") + return + + await self._set_chunk_state(task_id, file_record.file_id, chunk_id, "extracting", "extracting", 0.25) + + processed = chunk + rescue_strategy = self._chunk_rescue(chunk, file_record.name) + current_strategy = strategy + if rescue_strategy: + chunk.type = StrategyKnowledgeType.QUOTE + chunk.flags.verbatim = True + chunk.flags.requires_llm = False + current_strategy = rescue_strategy + try: + if llm_enabled and chunk.flags.requires_llm: + processed = await current_strategy.extract( + chunk, + lambda prompt: self._llm_call(prompt, model_cfg), + ) + elif chunk.type == StrategyKnowledgeType.QUOTE: + processed = await current_strategy.extract(chunk) + except Exception as e: + await self._set_chunk_failed(task_id, file_record.file_id, chunk_id, f"抽取失败: {e}") + return + + if await self._is_cancel_requested(task_id): + await self._set_chunk_cancelled(task_id, file_record.file_id, chunk_id, "任务已取消") + return + + await self._set_chunk_state(task_id, file_record.file_id, chunk_id, "writing", "writing", 0.7) + try: + time_meta = None + if chat_log and llm_enabled and model_cfg is not None: + time_meta = await self._extract_chat_time_meta_with_llm( + processed.chunk.text, + model_cfg, + reference_time=chat_reference_time, + ) + async with self._storage_lock: + await self._persist_processed_chunk(file_record, processed, time_meta=time_meta) + await self._set_chunk_completed(task_id, file_record.file_id, chunk_id) + except Exception as e: + await self._set_chunk_failed(task_id, file_record.file_id, chunk_id, f"写入失败: {e}") + + async def _process_json_file( + self, + task_id: str, + file_record: ImportFileRecord, + content: str, + chunk_semaphore: asyncio.Semaphore, + ) -> None: + await self._set_file_strategy(task_id, file_record.file_id, "json") + await self._set_file_state(task_id, file_record.file_id, "splitting", "splitting") + await self._ensure_embedding_runtime_ready() + + try: + data = json.loads(content) + except Exception as e: + raise RuntimeError(f"JSON 解析失败: {e}") + + schema = self._detect_json_schema(data) + async with self._lock: + task = self._tasks.get(task_id) + if task: + task.schema_detected = schema + task.updated_at = _now() + units = self._build_json_units(data, file_record.file_id, file_record.name, schema) + await self._register_json_units(task_id, file_record.file_id, units) + + await self._set_file_state(task_id, file_record.file_id, "extracting", "extracting") + jobs = [ + asyncio.create_task(self._process_json_unit(task_id, file_record, unit, chunk_semaphore)) + for unit in units + ] + await asyncio.gather(*jobs, return_exceptions=True) + + 
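+        # 单元级异常已在 _process_json_unit 内部转为对应分块的 failed 状态,
+        # return_exceptions=True 只是防止个别异常中断整批并发写入。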
if await self._is_cancel_requested(task_id): + await self._set_file_cancelled(task_id, file_record.file_id, "任务已取消") + return + + await self._set_file_state(task_id, file_record.file_id, "saving", "saving") + async with self._storage_lock: + self.plugin.vector_store.save() + self.plugin.graph_store.save() + + async with self._lock: + task = self._tasks.get(task_id) + if not task: + return + f = self._find_file(task, file_record.file_id) + if not f: + return + if f.failed_chunks > 0: + f.status = "failed" + f.current_step = "failed" + if not f.error: + f.error = f"存在失败分块: {f.failed_chunks}" + elif task.status == "cancel_requested": + f.status = "cancelled" + f.current_step = "cancelled" + else: + f.status = "completed" + f.current_step = "completed" + f.progress = 1.0 + f.updated_at = _now() + self._recompute_task_progress(task) + + def _detect_json_schema(self, data: Any) -> str: + if isinstance(data, dict) and isinstance(data.get("docs"), list): + return "lpmm_openie" + if isinstance(data, dict) and isinstance(data.get("paragraphs"), list): + paragraphs = data.get("paragraphs", []) + for p in paragraphs: + if isinstance(p, dict) and any( + key in p for key in ("entities", "relations", "time_meta", "source", "type", "knowledge_type") + ): + return "script_json" + return "web_json" + raise RuntimeError("不支持的 JSON 格式:需要 paragraphs 或 docs") + + def _build_json_units(self, data: Any, file_id: str, filename: str, schema: str) -> List[Dict[str, Any]]: + units: List[Dict[str, Any]] = [] + paragraphs: List[Any] = [] + entities: List[Any] = [] + relations: List[Any] = [] + + if schema in {"web_json", "script_json"}: + paragraphs = data.get("paragraphs", []) + entities = data.get("entities", []) + relations = data.get("relations", []) + elif schema == "lpmm_openie": + docs = data.get("docs", []) + for d in docs: + if not isinstance(d, dict): + continue + content = str(d.get("passage", "") or "").strip() + if not content: + continue + triples = d.get("extracted_triples", []) or [] + rels = [] + for t in triples: + if isinstance(t, list) and len(t) == 3: + rels.append( + { + "subject": str(t[0]), + "predicate": str(t[1]), + "object": str(t[2]), + } + ) + para_item = { + "content": content, + "source": f"lpmm_openie:{filename}", + "entities": d.get("extracted_entities", []) or [], + "relations": rels, + "knowledge_type": "factual", + } + paragraphs.append(para_item) + + for p in paragraphs: + paragraph = normalize_paragraph_import_item( + p, + default_source=f"web_import:{filename}", + ) + units.append( + { + "chunk_id": f"{file_id}_json_{len(units)}", + "kind": "paragraph", + "content": paragraph["content"], + "time_meta": paragraph["time_meta"], + "knowledge_type": paragraph["knowledge_type"], + "chunk_type": paragraph["knowledge_type"], + "source": paragraph["source"], + "entities": paragraph["entities"], + "relations": paragraph["relations"], + "preview": paragraph["content"][:120], + } + ) + + for e in entities: + name = str(e or "").strip() + if name: + units.append( + { + "chunk_id": f"{file_id}_json_{len(units)}", + "kind": "entity", + "name": name, + "chunk_type": "entity", + "preview": name[:120], + } + ) + + for r in relations: + if not isinstance(r, dict): + continue + s = str(r.get("subject", "")).strip() + p = str(r.get("predicate", "")).strip() + o = str(r.get("object", "")).strip() + if s and p and o: + units.append( + { + "chunk_id": f"{file_id}_json_{len(units)}", + "kind": "relation", + "subject": s, + "predicate": p, + "object": o, + "chunk_type": "relation", + "preview": f"{s} {p} 
{o}"[:120], + } + ) + return units + + async def _register_json_units(self, task_id: str, file_id: str, units: List[Dict[str, Any]]) -> None: + records = [ + ImportChunkRecord( + chunk_id=u["chunk_id"], + index=i, + chunk_type=u.get("chunk_type", "json"), + status="queued", + step="queued", + progress=0.0, + content_preview=str(u.get("preview", "")), + ) + for i, u in enumerate(units) + ] + async with self._lock: + task = self._tasks.get(task_id) + if not task: + return + f = self._find_file(task, file_id) + if not f: + return + f.chunks = records + f.total_chunks = len(records) + f.done_chunks = 0 + f.failed_chunks = 0 + f.cancelled_chunks = 0 + f.progress = 0.0 if records else 1.0 + f.updated_at = _now() + self._recompute_task_progress(task) + + async def _process_json_unit( + self, + task_id: str, + file_record: ImportFileRecord, + unit: Dict[str, Any], + chunk_semaphore: asyncio.Semaphore, + ) -> None: + chunk_id = unit["chunk_id"] + async with chunk_semaphore: + if await self._is_cancel_requested(task_id): + await self._set_chunk_cancelled(task_id, file_record.file_id, chunk_id, "任务已取消") + return + + await self._set_chunk_state(task_id, file_record.file_id, chunk_id, "writing", "writing", 0.7) + try: + async with self._storage_lock: + kind = unit["kind"] + if kind == "paragraph": + content = str(unit.get("content", "")) + k_type = resolve_stored_knowledge_type( + unit.get("knowledge_type"), + content=content, + ).value + source = str(unit.get("source") or f"web_import:{file_record.name}") + para_hash = self.plugin.metadata_store.add_paragraph( + content=content, + source=source, + knowledge_type=k_type, + time_meta=unit.get("time_meta"), + ) + emb = await self.plugin.embedding_manager.encode(content) + try: + self.plugin.vector_store.add(emb.reshape(1, -1), [para_hash]) + except ValueError: + pass + for name in unit.get("entities", []) or []: + n = str(name or "").strip() + if n: + await self._add_entity_with_vector(n, source_paragraph=para_hash) + for rel in unit.get("relations", []) or []: + if not isinstance(rel, dict): + continue + s = str(rel.get("subject", "")).strip() + p = str(rel.get("predicate", "")).strip() + o = str(rel.get("object", "")).strip() + if s and p and o: + await self._add_relation(s, p, o, source_paragraph=para_hash) + elif kind == "entity": + await self._add_entity_with_vector(unit["name"]) + elif kind == "relation": + await self._add_relation(unit["subject"], unit["predicate"], unit["object"]) + else: + raise RuntimeError(f"未知 JSON 导入单元类型: {kind}") + await self._set_chunk_completed(task_id, file_record.file_id, chunk_id) + except Exception as e: + await self._set_chunk_failed(task_id, file_record.file_id, chunk_id, f"写入失败: {e}") + + def _source_label(self, file_record: ImportFileRecord) -> str: + if file_record.source_path: + return f"{file_record.source_kind}:{file_record.source_path}" + return f"web_import:{file_record.name}" + + async def _ensure_embedding_runtime_ready(self) -> None: + report = await ensure_runtime_self_check(self.plugin) + if bool(report.get("ok", False)): + return + raise RuntimeError( + "embedding runtime self-check failed: " + f"{report.get('message', 'unknown')} " + f"(configured={report.get('configured_dimension', 0)}, " + f"store={report.get('vector_store_dimension', 0)}, " + f"encoded={report.get('encoded_dimension', 0)})" + ) + + async def _persist_processed_chunk( + self, + file_record: ImportFileRecord, + processed: ProcessedChunk, + *, + time_meta: Optional[Dict[str, Any]] = None, + ) -> None: + content = processed.chunk.text 
+ para_hash = self.plugin.metadata_store.add_paragraph( + content=content, + source=self._source_label(file_record), + knowledge_type=_storage_type_from_strategy(processed.type), + time_meta=time_meta, + ) + + emb = await self.plugin.embedding_manager.encode(content) + try: + self.plugin.vector_store.add(emb.reshape(1, -1), [para_hash]) + except ValueError: + pass + + data = processed.data or {} + entities: List[str] = [] + relations: List[Tuple[str, str, str]] = [] + + for triple in data.get("triples", []): + s = str(triple.get("subject", "")).strip() + p = str(triple.get("predicate", "")).strip() + o = str(triple.get("object", "")).strip() + if s and p and o: + relations.append((s, p, o)) + entities.extend([s, o]) + + for rel in data.get("relations", []): + s = str(rel.get("subject", "")).strip() + p = str(rel.get("predicate", "")).strip() + o = str(rel.get("object", "")).strip() + if s and p and o: + relations.append((s, p, o)) + entities.extend([s, o]) + + for k in ("entities", "events", "verbatim_entities"): + for e in data.get(k, []): + name = str(e or "").strip() + if name: + entities.append(name) + + uniq_entities = list({x.strip().lower(): x.strip() for x in entities if str(x).strip()}.values()) + for name in uniq_entities: + await self._add_entity_with_vector(name, source_paragraph=para_hash) + + for s, p, o in relations: + await self._add_relation(s, p, o, source_paragraph=para_hash) + + async def _add_entity_with_vector(self, name: str, source_paragraph: str = "") -> str: + hash_value = self.plugin.metadata_store.add_entity(name=name, source_paragraph=source_paragraph) + self.plugin.graph_store.add_nodes([name]) + if hash_value not in self.plugin.vector_store: + emb = await self.plugin.embedding_manager.encode(name) + try: + self.plugin.vector_store.add(emb.reshape(1, -1), [hash_value]) + except ValueError: + pass + return hash_value + + async def _add_relation(self, subject: str, predicate: str, obj: str, source_paragraph: str = "") -> str: + await self._add_entity_with_vector(subject, source_paragraph=source_paragraph) + await self._add_entity_with_vector(obj, source_paragraph=source_paragraph) + rv_cfg = self.plugin.get_config("retrieval.relation_vectorization", {}) or {} + if not isinstance(rv_cfg, dict): + rv_cfg = {} + write_vector = bool(rv_cfg.get("enabled", False)) and bool(rv_cfg.get("write_on_import", True)) + + relation_service = getattr(self.plugin, "relation_write_service", None) + if relation_service is not None: + result = await relation_service.upsert_relation_with_vector( + subject=subject, + predicate=predicate, + obj=obj, + confidence=1.0, + source_paragraph=source_paragraph, + write_vector=write_vector, + ) + return result.hash_value + + rel_hash = self.plugin.metadata_store.add_relation( + subject=subject, + predicate=predicate, + obj=obj, + source_paragraph=source_paragraph, + confidence=1.0, + ) + self.plugin.graph_store.add_edges([(subject, obj)], relation_hashes=[rel_hash]) + try: + self.plugin.metadata_store.set_relation_vector_state(rel_hash, "none") + except Exception: + pass + return rel_hash + async def _select_model(self) -> Any: + models = llm_api.get_available_models() + if not models: + raise RuntimeError("没有可用 LLM 模型") + + config_model = str(self._cfg("advanced.extraction_model", "auto") or "auto").strip() + if config_model.lower() != "auto" and config_model in models: + return models[config_model] + + for task_name in [ + "lpmm_entity_extract", + "lpmm_rdf_build", + "embedding", + "replyer", + "utils", + "planner", + "tool_use", + ]: + if 
task_name in models: + return models[task_name] + + return models[next(iter(models))] + + async def _llm_call(self, prompt: str, model_config: Any) -> Dict[str, Any]: + cfg = self._llm_retry_config() + retries = int(cfg["retries"]) + last_error: Optional[Exception] = None + for attempt in range(retries + 1): + try: + success, response, _, _ = await llm_api.generate_with_model( + prompt=prompt, + model_config=model_config, + request_type="A_Memorix.WebImport", + ) + if not success or not response: + raise RuntimeError("LLM 生成失败") + + txt = str(response or "").strip() + if "```" in txt: + txt = txt.split("```json")[-1].split("```")[0].strip() + if txt.startswith("json"): + txt = txt[4:].strip() + + try: + return json.loads(txt) + except Exception: + s = txt.find("{") + e = txt.rfind("}") + if s >= 0 and e > s: + return json.loads(txt[s : e + 1]) + raise + except Exception as err: + last_error = err + if attempt >= retries: + break + delay = min(cfg["max_wait"], cfg["min_wait"] * (cfg["multiplier"] ** attempt)) + await asyncio.sleep(max(0.0, float(delay))) + raise RuntimeError(f"LLM 抽取失败: {last_error}") + + def _parse_reference_time(self, value: Optional[str]) -> datetime: + if not value: + return datetime.now() + text = str(value).strip() + formats = [ + "%Y/%m/%d %H:%M:%S", + "%Y/%m/%d %H:%M", + "%Y-%m-%d %H:%M:%S", + "%Y-%m-%d %H:%M", + "%Y/%m/%d", + "%Y-%m-%d", + ] + for fmt in formats: + try: + return datetime.strptime(text, fmt) + except ValueError: + continue + return datetime.now() + + async def _extract_chat_time_meta_with_llm( + self, + text: str, + model_config: Any, + *, + reference_time: Optional[str] = None, + ) -> Optional[Dict[str, Any]]: + if not str(text or "").strip(): + return None + ref_dt = self._parse_reference_time(reference_time) + reference_now = ref_dt.strftime("%Y/%m/%d %H:%M") + prompt = f"""You are a time extraction engine for chat logs. +Extract temporal information from the following chat paragraph. + +Rules: +1. Use semantic understanding, not regex matching. +2. Convert relative expressions to absolute local datetime using reference_now. +3. If a time span exists, return event_time_start/event_time_end. +4. If only one point in time exists, return event_time. +5. If no reliable time info exists, keep all event_time fields null. +6. Return JSON only. 
+ +reference_now: {reference_now} +text: +{text} + +JSON schema: +{{ + "event_time": null, + "event_time_start": null, + "event_time_end": null, + "time_range": null, + "time_granularity": null, + "time_confidence": 0.0 +}} +""" + try: + result = await self._llm_call(prompt, model_config) + except Exception as e: + logger.warning(f"chat_log 时间语义抽取失败: {e}") + return None + + raw_time_meta = { + "event_time": result.get("event_time"), + "event_time_start": result.get("event_time_start"), + "event_time_end": result.get("event_time_end"), + "time_range": result.get("time_range"), + "time_granularity": result.get("time_granularity"), + "time_confidence": result.get("time_confidence"), + } + try: + normalized = normalize_time_meta(raw_time_meta) + except Exception: + return None + has_effective = any(k in normalized for k in ("event_time", "event_time_start", "event_time_end")) + if not has_effective: + return None + return normalized + + def _chunk_rescue(self, chunk: ProcessedChunk, filename: str) -> Optional[Any]: + if chunk.type == StrategyKnowledgeType.QUOTE: + return None + if looks_like_quote_text(chunk.chunk.text): + return QuoteStrategy(filename) + return None + + def _instantiate_strategy(self, filename: str, strategy: ImportStrategy) -> Any: + if strategy == ImportStrategy.FACTUAL: + return FactualStrategy(filename) + if strategy == ImportStrategy.QUOTE: + return QuoteStrategy(filename) + return NarrativeStrategy(filename) + + def _determine_strategy(self, filename: str, content: str, override: str, *, chat_log: bool = False) -> Any: + strategy = select_import_strategy( + content, + override=override, + chat_log=chat_log, + ) + return self._instantiate_strategy(filename, strategy) + + async def _set_file_strategy(self, task_id: str, file_id: str, strategy: Any) -> None: + if isinstance(strategy, str): + strategy_type = strategy + elif isinstance(strategy, NarrativeStrategy): + strategy_type = "narrative" + elif isinstance(strategy, FactualStrategy): + strategy_type = "factual" + elif isinstance(strategy, QuoteStrategy): + strategy_type = "quote" + else: + strategy_type = "unknown" + + async with self._lock: + task = self._tasks.get(task_id) + if not task: + return + f = self._find_file(task, file_id) + if not f: + return + f.detected_strategy_type = strategy_type + f.updated_at = _now() + task.updated_at = _now() + + async def _register_chunks(self, task_id: str, file_id: str, chunks: List[ProcessedChunk]) -> None: + records = [ + ImportChunkRecord( + chunk_id=chunk.chunk.chunk_id, + index=index, + chunk_type=chunk.type.value, + status="queued", + step="queued", + progress=0.0, + content_preview=str(chunk.chunk.text or "")[:120], + ) + for index, chunk in enumerate(chunks) + ] + + async with self._lock: + task = self._tasks.get(task_id) + if not task: + return + f = self._find_file(task, file_id) + if not f: + return + f.chunks = records + f.total_chunks = len(records) + f.done_chunks = 0 + f.failed_chunks = 0 + f.cancelled_chunks = 0 + f.progress = 0.0 if records else 1.0 + f.updated_at = _now() + self._recompute_task_progress(task) + + async def _set_file_state(self, task_id: str, file_id: str, status: str, step: str) -> None: + if status not in FILE_STATUS: + return + async with self._lock: + task = self._tasks.get(task_id) + if not task: + return + f = self._find_file(task, file_id) + if not f: + return + f.status = status + f.current_step = step + f.updated_at = _now() + task.updated_at = _now() + if step in {"preparing", "splitting", "extracting", "writing", "saving"} and 
task.status in {"queued", "preparing"}: + task.status = "running" + task.current_step = "running" + + async def _set_file_failed(self, task_id: str, file_id: str, error: str) -> None: + async with self._lock: + task = self._tasks.get(task_id) + if not task: + return + f = self._find_file(task, file_id) + if not f: + return + f.status = "failed" + f.current_step = "failed" + f.error = str(error) + f.updated_at = _now() + task.updated_at = _now() + self._recompute_task_progress(task) + + async def _set_file_cancelled(self, task_id: str, file_id: str, reason: str) -> None: + async with self._lock: + task = self._tasks.get(task_id) + if not task: + return + f = self._find_file(task, file_id) + if not f: + return + f.status = "cancelled" + f.current_step = "cancelled" + f.error = reason + additional_cancelled = 0 + for chunk in f.chunks: + if chunk.status in {"completed", "failed", "cancelled"}: + continue + chunk.status = "cancelled" + chunk.step = "cancelled" + chunk.retryable = False + chunk.error = reason + chunk.progress = 1.0 + chunk.updated_at = _now() + additional_cancelled += 1 + if additional_cancelled > 0: + f.cancelled_chunks += additional_cancelled + f.progress = self._compute_ratio( + f.done_chunks + f.failed_chunks + f.cancelled_chunks, f.total_chunks + ) + f.updated_at = _now() + task.updated_at = _now() + self._recompute_task_progress(task) + + async def _set_chunk_state( + self, + task_id: str, + file_id: str, + chunk_id: str, + status: str, + step: str, + progress: float, + ) -> None: + if status not in CHUNK_STATUS: + return + async with self._lock: + task = self._tasks.get(task_id) + if not task: + return + f = self._find_file(task, file_id) + if not f: + return + c = self._find_chunk(f, chunk_id) + if not c: + return + c.status = status + c.step = step + if status in {"queued", "extracting", "writing"}: + c.error = "" + c.failed_at = "" + c.retryable = False + c.progress = max(0.0, min(1.0, float(progress))) + c.updated_at = _now() + if f.status not in {"failed", "cancelled"}: + f.status = "extracting" if status == "extracting" else "writing" + f.current_step = step + f.updated_at = _now() + task.updated_at = _now() + + async def _set_chunk_completed(self, task_id: str, file_id: str, chunk_id: str) -> None: + async with self._lock: + task = self._tasks.get(task_id) + if not task: + return + f = self._find_file(task, file_id) + if not f: + return + c = self._find_chunk(f, chunk_id) + if not c or c.status == "completed": + return + c.status = "completed" + c.step = "completed" + c.failed_at = "" + c.retryable = False + c.progress = 1.0 + c.updated_at = _now() + f.done_chunks += 1 + f.progress = self._compute_ratio(f.done_chunks + f.failed_chunks + f.cancelled_chunks, f.total_chunks) + f.updated_at = _now() + self._recompute_task_progress(task) + + async def _set_chunk_failed(self, task_id: str, file_id: str, chunk_id: str, error: str) -> None: + async with self._lock: + task = self._tasks.get(task_id) + if not task: + return + f = self._find_file(task, file_id) + if not f: + return + c = self._find_chunk(f, chunk_id) + if not c or c.status == "failed": + return + failed_stage = str(c.step or "").strip().lower() + if failed_stage in {"", "queued", "failed", "completed", "cancelled"}: + failed_stage = str(f.current_step or "").strip().lower() + if failed_stage in {"", "queued", "failed", "completed", "cancelled"}: + failed_stage = "unknown" + c.status = "failed" + c.step = "failed" + c.failed_at = failed_stage + c.retryable = bool(f.input_mode == "text" and failed_stage == 
"extracting") + c.error = str(error) + c.progress = 1.0 + c.updated_at = _now() + f.failed_chunks += 1 + f.progress = self._compute_ratio(f.done_chunks + f.failed_chunks + f.cancelled_chunks, f.total_chunks) + if not f.error: + f.error = str(error) + f.updated_at = _now() + self._recompute_task_progress(task) + + async def _set_chunk_cancelled(self, task_id: str, file_id: str, chunk_id: str, reason: str) -> None: + async with self._lock: + task = self._tasks.get(task_id) + if not task: + return + f = self._find_file(task, file_id) + if not f: + return + c = self._find_chunk(f, chunk_id) + if not c or c.status == "cancelled": + return + c.status = "cancelled" + c.step = "cancelled" + c.retryable = False + c.error = reason + c.progress = 1.0 + c.updated_at = _now() + f.cancelled_chunks += 1 + f.progress = self._compute_ratio(f.done_chunks + f.failed_chunks + f.cancelled_chunks, f.total_chunks) + f.updated_at = _now() + self._recompute_task_progress(task) + + async def _is_cancel_requested(self, task_id: str) -> bool: + async with self._lock: + task = self._tasks.get(task_id) + if not task: + return True + return task.status == "cancel_requested" + + def _find_file(self, task: ImportTaskRecord, file_id: str) -> Optional[ImportFileRecord]: + for f in task.files: + if f.file_id == file_id: + return f + return None + + def _find_chunk(self, file_record: ImportFileRecord, chunk_id: str) -> Optional[ImportChunkRecord]: + for c in file_record.chunks: + if c.chunk_id == chunk_id: + return c + return None + + def _compute_ratio(self, done: int, total: int) -> float: + if total <= 0: + return 1.0 + return max(0.0, min(1.0, float(done) / float(total))) + + def _recompute_task_progress(self, task: ImportTaskRecord) -> None: + total = 0 + done = 0 + failed = 0 + cancelled = 0 + for f in task.files: + total += f.total_chunks + done += f.done_chunks + failed += f.failed_chunks + cancelled += f.cancelled_chunks + task.total_chunks = total + task.done_chunks = done + task.failed_chunks = failed + task.cancelled_chunks = cancelled + task.progress = self._compute_ratio(done + failed + cancelled, total) + task.updated_at = _now() + + async def _should_cleanup_task_temp(self, task_id: str) -> bool: + async with self._lock: + task = self._tasks.get(task_id) + if not task: + return True + for f in task.files: + if f.status == "failed": + return False + return True + + def _mark_task_cancelled_locked(self, task: ImportTaskRecord, reason: str) -> None: + for f in task.files: + if f.status in {"completed", "failed", "cancelled"}: + continue + f.status = "cancelled" + f.current_step = "cancelled" + f.error = reason + additional_cancelled = 0 + for c in f.chunks: + if c.status in {"completed", "failed", "cancelled"}: + continue + c.status = "cancelled" + c.step = "cancelled" + c.retryable = False + c.error = reason + c.progress = 1.0 + c.updated_at = _now() + additional_cancelled += 1 + if additional_cancelled > 0: + f.cancelled_chunks += additional_cancelled + f.progress = self._compute_ratio( + f.done_chunks + f.failed_chunks + f.cancelled_chunks, f.total_chunks + ) + f.updated_at = _now() + task.status = "cancelled" + task.current_step = "cancelled" + task.finished_at = _now() + task.updated_at = _now() + self._recompute_task_progress(task) diff --git a/plugins/A_memorix/plugin.py b/plugins/A_memorix/plugin.py index 56df45b9..390515f5 100644 --- a/plugins/A_memorix/plugin.py +++ b/plugins/A_memorix/plugin.py @@ -15,6 +15,12 @@ def _tool_param(name: str, param_type: ToolParamType, description: str, required return 
ToolParameterInfo(name=name, param_type=param_type, description=description, required=required) +_ADMIN_TOOL_PARAMS = [ + _tool_param("action", ToolParamType.STRING, "管理动作", True), + _tool_param("target", ToolParamType.STRING, "可选目标标识", False), +] + + class AMemorixPlugin(MaiBotPlugin): def __init__(self) -> None: super().__init__() @@ -33,7 +39,11 @@ class AMemorixPlugin(MaiBotPlugin): async def on_unload(self): if self._kernel is not None: - self._kernel.close() + shutdown = getattr(self._kernel, "shutdown", None) + if callable(shutdown): + await shutdown() + else: + self._kernel.close() self._kernel = None async def _get_kernel(self) -> SDKMemoryKernel: @@ -42,6 +52,11 @@ class AMemorixPlugin(MaiBotPlugin): await self._kernel.initialize() return self._kernel + async def _dispatch_admin_tool(self, method_name: str, action: str, **kwargs): + kernel = await self._get_kernel() + handler = getattr(kernel, method_name) + return await handler(action=action, **kwargs) + @Tool( "search_memory", description="搜索长期记忆", @@ -53,6 +68,7 @@ class AMemorixPlugin(MaiBotPlugin): _tool_param("person_id", ToolParamType.STRING, "人物 ID", False), _tool_param("time_start", ToolParamType.FLOAT, "起始时间戳", False), _tool_param("time_end", ToolParamType.FLOAT, "结束时间戳", False), + _tool_param("respect_filter", ToolParamType.BOOLEAN, "是否应用聊天过滤配置", False), ], ) async def handle_search_memory( @@ -62,11 +78,11 @@ class AMemorixPlugin(MaiBotPlugin): mode: str = "hybrid", chat_id: str = "", person_id: str = "", - time_start: float | None = None, - time_end: float | None = None, + time_start: str | float | None = None, + time_end: str | float | None = None, + respect_filter: bool = True, **kwargs, ): - _ = kwargs kernel = await self._get_kernel() return await kernel.search_memory( KernelSearchRequest( @@ -77,6 +93,9 @@ class AMemorixPlugin(MaiBotPlugin): person_id=person_id, time_start=time_start, time_end=time_end, + respect_filter=respect_filter, + user_id=str(kwargs.get("user_id", "") or "").strip(), + group_id=str(kwargs.get("group_id", "") or "").strip(), ) ) @@ -89,6 +108,7 @@ class AMemorixPlugin(MaiBotPlugin): _tool_param("text", ToolParamType.STRING, "摘要文本", True), _tool_param("time_start", ToolParamType.FLOAT, "起始时间戳", False), _tool_param("time_end", ToolParamType.FLOAT, "结束时间戳", False), + _tool_param("respect_filter", ToolParamType.BOOLEAN, "是否应用聊天过滤配置", False), ], ) async def handle_ingest_summary( @@ -101,9 +121,9 @@ class AMemorixPlugin(MaiBotPlugin): time_end: float | None = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, + respect_filter: bool = True, **kwargs, ): - _ = kwargs kernel = await self._get_kernel() return await kernel.ingest_summary( external_id=external_id, @@ -114,6 +134,9 @@ class AMemorixPlugin(MaiBotPlugin): time_end=time_end, tags=tags, metadata=metadata, + respect_filter=respect_filter, + user_id=str(kwargs.get("user_id", "") or "").strip(), + group_id=str(kwargs.get("group_id", "") or "").strip(), ) @Tool( @@ -125,6 +148,7 @@ class AMemorixPlugin(MaiBotPlugin): _tool_param("text", ToolParamType.STRING, "原始文本", True), _tool_param("chat_id", ToolParamType.STRING, "聊天流 ID", False), _tool_param("timestamp", ToolParamType.FLOAT, "时间戳", False), + _tool_param("respect_filter", ToolParamType.BOOLEAN, "是否应用聊天过滤配置", False), ], ) async def handle_ingest_text( @@ -140,6 +164,7 @@ class AMemorixPlugin(MaiBotPlugin): time_end: float | None = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, + respect_filter: bool = True, **kwargs, ): 
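+        # user_id/group_id 仅从 kwargs 透传并做 strip 归一,
+        # 与 search_memory / ingest_summary 的 respect_filter 聊天过滤语义保持一致。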
relations = kwargs.get("relations") @@ -159,6 +184,9 @@ class AMemorixPlugin(MaiBotPlugin): metadata=metadata, entities=entities, relations=relations, + respect_filter=respect_filter, + user_id=str(kwargs.get("user_id", "") or "").strip(), + group_id=str(kwargs.get("group_id", "") or "").strip(), ) @Tool( @@ -179,22 +207,24 @@ class AMemorixPlugin(MaiBotPlugin): "maintain_memory", description="维护长期记忆关系状态", parameters=[ - _tool_param("action", ToolParamType.STRING, "reinforce/protect/restore", True), - _tool_param("target", ToolParamType.STRING, "目标哈希或查询文本", True), + _tool_param("action", ToolParamType.STRING, "reinforce/protect/restore/freeze/recycle_bin", True), + _tool_param("target", ToolParamType.STRING, "目标哈希或查询文本", False), _tool_param("hours", ToolParamType.FLOAT, "保护时长(小时)", False), + _tool_param("limit", ToolParamType.INTEGER, "查询条数(用于 recycle_bin)", False), ], ) async def handle_maintain_memory( self, action: str, - target: str, + target: str = "", hours: float | None = None, reason: str = "", + limit: int = 50, **kwargs, ): _ = kwargs kernel = await self._get_kernel() - return await kernel.maintain_memory(action=action, target=target, hours=hours, reason=reason) + return await kernel.maintain_memory(action=action, target=target, hours=hours, reason=reason, limit=limit) @Tool("memory_stats", description="获取长期记忆统计", parameters=[]) async def handle_memory_stats(self, **kwargs): @@ -202,6 +232,42 @@ class AMemorixPlugin(MaiBotPlugin): kernel = await self._get_kernel() return kernel.memory_stats() + @Tool("memory_graph_admin", description="长期记忆图谱管理接口", parameters=_ADMIN_TOOL_PARAMS) + async def handle_memory_graph_admin(self, action: str, **kwargs): + return await self._dispatch_admin_tool("memory_graph_admin", action=action, **kwargs) + + @Tool("memory_source_admin", description="长期记忆来源管理接口", parameters=_ADMIN_TOOL_PARAMS) + async def handle_memory_source_admin(self, action: str, **kwargs): + return await self._dispatch_admin_tool("memory_source_admin", action=action, **kwargs) + + @Tool("memory_episode_admin", description="Episode 管理接口", parameters=_ADMIN_TOOL_PARAMS) + async def handle_memory_episode_admin(self, action: str, **kwargs): + return await self._dispatch_admin_tool("memory_episode_admin", action=action, **kwargs) + + @Tool("memory_profile_admin", description="人物画像管理接口", parameters=_ADMIN_TOOL_PARAMS) + async def handle_memory_profile_admin(self, action: str, **kwargs): + return await self._dispatch_admin_tool("memory_profile_admin", action=action, **kwargs) + + @Tool("memory_runtime_admin", description="长期记忆运行时管理接口", parameters=_ADMIN_TOOL_PARAMS) + async def handle_memory_runtime_admin(self, action: str, **kwargs): + return await self._dispatch_admin_tool("memory_runtime_admin", action=action, **kwargs) + + @Tool("memory_import_admin", description="长期记忆导入管理接口", parameters=_ADMIN_TOOL_PARAMS) + async def handle_memory_import_admin(self, action: str, **kwargs): + return await self._dispatch_admin_tool("memory_import_admin", action=action, **kwargs) + + @Tool("memory_tuning_admin", description="长期记忆调优管理接口", parameters=_ADMIN_TOOL_PARAMS) + async def handle_memory_tuning_admin(self, action: str, **kwargs): + return await self._dispatch_admin_tool("memory_tuning_admin", action=action, **kwargs) + + @Tool("memory_v5_admin", description="长期记忆 V5 管理接口", parameters=_ADMIN_TOOL_PARAMS) + async def handle_memory_v5_admin(self, action: str, **kwargs): + return await self._dispatch_admin_tool("memory_v5_admin", action=action, **kwargs) + + @Tool("memory_delete_admin", 
description="长期记忆删除管理接口", parameters=_ADMIN_TOOL_PARAMS) + async def handle_memory_delete_admin(self, action: str, **kwargs): + return await self._dispatch_admin_tool("memory_delete_admin", action=action, **kwargs) + def create_plugin(): return AMemorixPlugin() diff --git a/plugins/A_memorix/requirements.txt b/plugins/A_memorix/requirements.txt new file mode 100644 index 00000000..f737fdf4 --- /dev/null +++ b/plugins/A_memorix/requirements.txt @@ -0,0 +1,52 @@ +# A_Memorix 插件依赖 +# +# 核心依赖 (必需) +# ================== + +# 数值计算 - 用于向量操作、矩阵计算 +numpy>=1.20.0 + +# 稀疏矩阵 - 用于图存储的邻接矩阵 +scipy>=1.7.0 + +# 图结构处理(LPMM 转换) +networkx>=3.0.0 + +# Parquet 读取(LPMM 转换) +pyarrow>=10.0.0 + +# DataFrame 处理(LPMM 转换) +pandas>=1.5.0 + +# 异步事件循环嵌套 - 用于插件初始化时的异步操作 +nest-asyncio>=1.5.0 + +# 向量索引 - 用于向量存储和检索 +faiss-cpu>=1.7.0 + +# Web 服务器依赖 (可视化功能需要) +# ================== + +# ASGI 服务器 +uvicorn>=0.20.0 + +# Web 框架 +fastapi>=0.100.0 + +# 数据验证 +pydantic>=2.0.0 +python-multipart>=0.0.9 + +# 注意事项 +# ================== +# +# 1. sqlite3 是 Python 标准库,无需安装 +# 2. json, re, time, pathlib 等都是标准库 +# 3. sentence-transformers 不需要(使用主程序 Embedding API) + +# UI 交互 +rich>=14.0.0 +tenacity>=8.0.0 + +# 稀疏检索中文分词(可选,未安装时自动回退 char n-gram) +jieba>=0.42.1 diff --git a/plugins/A_memorix/scripts/audit_vector_consistency.py b/plugins/A_memorix/scripts/audit_vector_consistency.py new file mode 100644 index 00000000..c97806dc --- /dev/null +++ b/plugins/A_memorix/scripts/audit_vector_consistency.py @@ -0,0 +1,213 @@ +#!/usr/bin/env python3 +""" +A_Memorix 一致性审计脚本。 + +输出内容: +1. paragraph/entity/relation 向量覆盖率 +2. relation vector_state 分布 +3. 孤儿向量数量(向量存在但 metadata 不存在) +4. 状态与向量文件不一致统计 +""" + +from __future__ import annotations + +import argparse +import json +import pickle +import sys +from pathlib import Path +from typing import Any, Dict, Set + + +CURRENT_DIR = Path(__file__).resolve().parent +PLUGIN_ROOT = CURRENT_DIR.parent +PROJECT_ROOT = PLUGIN_ROOT.parent.parent +sys.path.insert(0, str(PROJECT_ROOT)) +sys.path.insert(0, str(PLUGIN_ROOT)) + +def _build_arg_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="审计 A_Memorix 向量一致性") + parser.add_argument( + "--data-dir", + default=str(PLUGIN_ROOT / "data"), + help="A_Memorix 数据目录(默认: plugins/A_memorix/data)", + ) + parser.add_argument("--json-out", default="", help="可选:输出 JSON 文件路径") + parser.add_argument( + "--strict", + action="store_true", + help="若发现一致性异常则返回非 0 退出码", + ) + return parser + + +# --help/-h fast path: avoid heavy host/plugin bootstrap +if any(arg in {"-h", "--help"} for arg in sys.argv[1:]): + _build_arg_parser().print_help() + sys.exit(0) + +try: + from core.storage.vector_store import VectorStore + from core.storage.metadata_store import MetadataStore + from core.storage import QuantizationType +except Exception as e: # pragma: no cover + print(f"❌ 导入核心模块失败: {e}") + sys.exit(1) + + +def _safe_ratio(numerator: int, denominator: int) -> float: + if denominator <= 0: + return 0.0 + return float(numerator) / float(denominator) + + +def _load_vector_store(data_dir: Path) -> VectorStore: + meta_path = data_dir / "vectors" / "vectors_metadata.pkl" + if not meta_path.exists(): + raise FileNotFoundError(f"未找到向量元数据文件: {meta_path}") + + with open(meta_path, "rb") as f: + meta = pickle.load(f) + dimension = int(meta.get("dimension", 1024)) + + store = VectorStore( + dimension=max(1, dimension), + quantization_type=QuantizationType.INT8, + data_dir=data_dir / "vectors", + ) + if store.has_data(): + store.load() + return store + + +def 
_load_metadata_store(data_dir: Path) -> MetadataStore: + store = MetadataStore(data_dir=data_dir / "metadata") + store.connect() + return store + + +def _hash_set(metadata_store: MetadataStore, table: str) -> Set[str]: + return {str(h) for h in metadata_store.list_hashes(table)} + + +def _relation_state_stats(metadata_store: MetadataStore) -> Dict[str, int]: + return metadata_store.count_relations_by_vector_state() + + +def run_audit(data_dir: Path) -> Dict[str, Any]: + vector_store = _load_vector_store(data_dir) + metadata_store = _load_metadata_store(data_dir) + try: + paragraph_hashes = _hash_set(metadata_store, "paragraphs") + entity_hashes = _hash_set(metadata_store, "entities") + relation_hashes = _hash_set(metadata_store, "relations") + + known_hashes = set(getattr(vector_store, "_known_hashes", set())) + live_vector_hashes = {h for h in known_hashes if h in vector_store} + + para_vector_hits = len(paragraph_hashes & live_vector_hashes) + ent_vector_hits = len(entity_hashes & live_vector_hashes) + rel_vector_hits = len(relation_hashes & live_vector_hashes) + + orphan_vector_hashes = sorted( + live_vector_hashes - paragraph_hashes - entity_hashes - relation_hashes + ) + + relation_rows = metadata_store.get_relations() + ready_but_missing = 0 + not_ready_but_present = 0 + for row in relation_rows: + h = str(row.get("hash") or "") + state = str(row.get("vector_state") or "none").lower() + in_vector = h in live_vector_hashes + if state == "ready" and not in_vector: + ready_but_missing += 1 + if state != "ready" and in_vector: + not_ready_but_present += 1 + + relation_states = _relation_state_stats(metadata_store) + rel_total = max(0, int(relation_states.get("total", len(relation_hashes)))) + ready_count = max(0, int(relation_states.get("ready", 0))) + + result = { + "counts": { + "paragraphs": len(paragraph_hashes), + "entities": len(entity_hashes), + "relations": len(relation_hashes), + "vectors_live": len(live_vector_hashes), + }, + "coverage": { + "paragraph_vector_coverage": _safe_ratio(para_vector_hits, len(paragraph_hashes)), + "entity_vector_coverage": _safe_ratio(ent_vector_hits, len(entity_hashes)), + "relation_vector_coverage": _safe_ratio(rel_vector_hits, len(relation_hashes)), + "relation_ready_coverage": _safe_ratio(ready_count, rel_total), + }, + "relation_states": relation_states, + "orphans": { + "vector_only_count": len(orphan_vector_hashes), + "vector_only_sample": orphan_vector_hashes[:30], + }, + "consistency_checks": { + "ready_but_missing_vector": ready_but_missing, + "not_ready_but_vector_present": not_ready_but_present, + }, + } + return result + finally: + metadata_store.close() + + +def main() -> int: + parser = _build_arg_parser() + args = parser.parse_args() + + data_dir = Path(args.data_dir).resolve() + if not data_dir.exists(): + print(f"❌ 数据目录不存在: {data_dir}") + return 2 + + try: + result = run_audit(data_dir) + except Exception as e: + print(f"❌ 审计失败: {e}") + return 2 + + print("=== A_Memorix Vector Consistency Audit ===") + print(f"data_dir: {data_dir}") + print(f"paragraphs: {result['counts']['paragraphs']}") + print(f"entities: {result['counts']['entities']}") + print(f"relations: {result['counts']['relations']}") + print(f"vectors_live: {result['counts']['vectors_live']}") + print( + "coverage: " + f"paragraph={result['coverage']['paragraph_vector_coverage']:.3f}, " + f"entity={result['coverage']['entity_vector_coverage']:.3f}, " + f"relation={result['coverage']['relation_vector_coverage']:.3f}, " + 
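+        # 各 coverage 为“命中向量的哈希数 / 元数据总数”;relation_ready 为 vector_state=ready 占比。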
f"relation_ready={result['coverage']['relation_ready_coverage']:.3f}" + ) + print(f"relation_states: {result['relation_states']}") + print( + "consistency_checks: " + f"ready_but_missing_vector={result['consistency_checks']['ready_but_missing_vector']}, " + f"not_ready_but_vector_present={result['consistency_checks']['not_ready_but_vector_present']}" + ) + print(f"orphan_vectors: {result['orphans']['vector_only_count']}") + + if args.json_out: + out_path = Path(args.json_out).resolve() + out_path.parent.mkdir(parents=True, exist_ok=True) + with open(out_path, "w", encoding="utf-8") as f: + json.dump(result, f, ensure_ascii=False, indent=2) + print(f"json_out: {out_path}") + + has_anomaly = ( + result["orphans"]["vector_only_count"] > 0 + or result["consistency_checks"]["ready_but_missing_vector"] > 0 + ) + if args.strict and has_anomaly: + return 1 + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/plugins/A_memorix/scripts/backfill_relation_vectors.py b/plugins/A_memorix/scripts/backfill_relation_vectors.py new file mode 100644 index 00000000..7ba0ade0 --- /dev/null +++ b/plugins/A_memorix/scripts/backfill_relation_vectors.py @@ -0,0 +1,270 @@ +#!/usr/bin/env python3 +""" +关系向量一次性回填脚本(灰度/离线执行)。 + +用途: +1. 对 relations 中 vector_state in (none, failed, pending) 的记录补齐向量。 +2. 支持并发控制,降低总耗时。 +3. 可作为灰度阶段验证工具,与 audit_vector_consistency.py 配合使用。 +4. 可选自动纳入“ready 但向量缺失”的漂移记录进行修复。 +""" + +from __future__ import annotations + +import argparse +import asyncio +import json +import sys +import time +from pathlib import Path +from typing import Any, Dict, List + +import tomlkit + + +CURRENT_DIR = Path(__file__).resolve().parent +PLUGIN_ROOT = CURRENT_DIR.parent +PROJECT_ROOT = PLUGIN_ROOT.parent.parent +sys.path.insert(0, str(PROJECT_ROOT)) +sys.path.insert(0, str(PLUGIN_ROOT)) + +def _build_arg_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="关系向量一次性回填") + parser.add_argument( + "--config", + default=str(PLUGIN_ROOT / "config.toml"), + help="配置文件路径(默认 plugins/A_memorix/config.toml)", + ) + parser.add_argument( + "--data-dir", + default=str(PLUGIN_ROOT / "data"), + help="数据目录(默认 plugins/A_memorix/data)", + ) + parser.add_argument( + "--states", + default="none,failed,pending", + help="待处理状态列表,逗号分隔", + ) + parser.add_argument("--limit", type=int, default=50000, help="最大处理数量") + parser.add_argument("--concurrency", type=int, default=8, help="并发数") + parser.add_argument("--max-retry", type=int, default=None, help="最大重试次数过滤") + parser.add_argument( + "--include-ready-missing", + action="store_true", + help="额外纳入 vector_state=ready 但向量缺失的关系", + ) + parser.add_argument("--dry-run", action="store_true", help="仅统计候选,不写入") + return parser + + +# --help/-h fast path: avoid heavy host/plugin bootstrap +if any(arg in {"-h", "--help"} for arg in sys.argv[1:]): + _build_arg_parser().print_help() + raise SystemExit(0) + +from core.storage import ( + VectorStore, + GraphStore, + MetadataStore, + QuantizationType, + SparseMatrixFormat, +) +from core.embedding import create_embedding_api_adapter +from core.utils.relation_write_service import RelationWriteService + + +def _load_config(config_path: Path) -> Dict[str, Any]: + with open(config_path, "r", encoding="utf-8") as f: + raw = tomlkit.load(f) + return dict(raw) if isinstance(raw, dict) else {} + + +def _build_vector_store(data_dir: Path, emb_cfg: Dict[str, Any]) -> VectorStore: + q_type = str(emb_cfg.get("quantization_type", "int8")).lower() + if q_type != "int8": + raise ValueError( + 
"embedding.quantization_type 在 vNext 仅允许 int8(SQ8)。" + " 请先执行 scripts/release_vnext_migrate.py migrate。" + ) + dim = int(emb_cfg.get("dimension", 1024)) + store = VectorStore( + dimension=max(1, dim), + quantization_type=QuantizationType.INT8, + data_dir=data_dir / "vectors", + ) + if store.has_data(): + store.load() + return store + + +def _build_graph_store(data_dir: Path, graph_cfg: Dict[str, Any]) -> GraphStore: + fmt = str(graph_cfg.get("sparse_matrix_format", "csr")).lower() + fmt_map = { + "csr": SparseMatrixFormat.CSR, + "csc": SparseMatrixFormat.CSC, + } + store = GraphStore( + matrix_format=fmt_map.get(fmt, SparseMatrixFormat.CSR), + data_dir=data_dir / "graph", + ) + if store.has_data(): + store.load() + return store + + +def _build_metadata_store(data_dir: Path) -> MetadataStore: + store = MetadataStore(data_dir=data_dir / "metadata") + store.connect() + return store + + +def _build_embedding_manager(emb_cfg: Dict[str, Any]): + retry_cfg = emb_cfg.get("retry", {}) + if not isinstance(retry_cfg, dict): + retry_cfg = {} + return create_embedding_api_adapter( + batch_size=int(emb_cfg.get("batch_size", 32)), + max_concurrent=int(emb_cfg.get("max_concurrent", 5)), + default_dimension=int(emb_cfg.get("dimension", 1024)), + model_name=str(emb_cfg.get("model_name", "auto")), + retry_config=retry_cfg, + ) + + +async def _process_rows( + service: RelationWriteService, + rows: List[Dict[str, Any]], + concurrency: int, +) -> Dict[str, int]: + semaphore = asyncio.Semaphore(max(1, int(concurrency))) + stat = {"success": 0, "failed": 0, "skipped": 0} + + async def _worker(row: Dict[str, Any]) -> None: + async with semaphore: + result = await service.ensure_relation_vector( + hash_value=str(row["hash"]), + subject=str(row.get("subject", "")), + predicate=str(row.get("predicate", "")), + obj=str(row.get("object", "")), + ) + if result.vector_state == "ready": + if result.vector_written: + stat["success"] += 1 + else: + stat["skipped"] += 1 + else: + stat["failed"] += 1 + + await asyncio.gather(*[_worker(row) for row in rows]) + return stat + + +async def main_async(args: argparse.Namespace) -> int: + config_path = Path(args.config).resolve() + if not config_path.exists(): + print(f"❌ 配置文件不存在: {config_path}") + return 2 + + cfg = _load_config(config_path) + emb_cfg = cfg.get("embedding", {}) if isinstance(cfg, dict) else {} + graph_cfg = cfg.get("graph", {}) if isinstance(cfg, dict) else {} + retrieval_cfg = cfg.get("retrieval", {}) if isinstance(cfg, dict) else {} + rv_cfg = retrieval_cfg.get("relation_vectorization", {}) if isinstance(retrieval_cfg, dict) else {} + if not isinstance(emb_cfg, dict): + emb_cfg = {} + if not isinstance(graph_cfg, dict): + graph_cfg = {} + if not isinstance(rv_cfg, dict): + rv_cfg = {} + + data_dir = Path(args.data_dir).resolve() + if not data_dir.exists(): + print(f"❌ 数据目录不存在: {data_dir}") + return 2 + + print(f"data_dir: {data_dir}") + print(f"config: {config_path}") + + vector_store = _build_vector_store(data_dir, emb_cfg) + graph_store = _build_graph_store(data_dir, graph_cfg) + metadata_store = _build_metadata_store(data_dir) + embedding_manager = _build_embedding_manager(emb_cfg) + service = RelationWriteService( + metadata_store=metadata_store, + graph_store=graph_store, + vector_store=vector_store, + embedding_manager=embedding_manager, + ) + + try: + states = [s.strip() for s in str(args.states).split(",") if s.strip()] + if not states: + states = ["none", "failed", "pending"] + max_retry = int(args.max_retry) if args.max_retry is not None else 
int(rv_cfg.get("max_retry", 3)) + limit = int(args.limit) + + rows = metadata_store.list_relations_by_vector_state( + states=states, + limit=max(1, limit), + max_retry=max(1, max_retry), + ) + added_ready_missing = 0 + if args.include_ready_missing: + ready_rows = metadata_store.list_relations_by_vector_state( + states=["ready"], + limit=max(1, limit), + max_retry=max(1, max_retry), + ) + ready_missing_rows = [ + row for row in ready_rows if str(row.get("hash", "")) not in vector_store + ] + added_ready_missing = len(ready_missing_rows) + if ready_missing_rows: + dedup: Dict[str, Dict[str, Any]] = {} + for row in rows: + dedup[str(row.get("hash", ""))] = row + for row in ready_missing_rows: + dedup.setdefault(str(row.get("hash", "")), row) + rows = list(dedup.values())[: max(1, limit)] + print(f"candidates: {len(rows)} (states={states}, max_retry={max_retry})") + if args.include_ready_missing: + print(f"ready_missing_candidates_added: {added_ready_missing}") + if not rows: + return 0 + + if args.dry_run: + print("dry_run=true,未执行写入。") + return 0 + + started = time.time() + stat = await _process_rows( + service=service, + rows=rows, + concurrency=int(args.concurrency), + ) + elapsed = (time.time() - started) * 1000.0 + + vector_store.save() + graph_store.save() + state_stats = metadata_store.count_relations_by_vector_state() + output = { + "processed": len(rows), + "success": int(stat["success"]), + "failed": int(stat["failed"]), + "skipped": int(stat["skipped"]), + "elapsed_ms": elapsed, + "state_stats": state_stats, + } + print(json.dumps(output, ensure_ascii=False, indent=2)) + return 0 if stat["failed"] == 0 else 1 + finally: + metadata_store.close() + + +def parse_args() -> argparse.Namespace: + return _build_arg_parser().parse_args() + + +if __name__ == "__main__": + arguments = parse_args() + raise SystemExit(asyncio.run(main_async(arguments))) diff --git a/plugins/A_memorix/scripts/backfill_temporal_metadata.py b/plugins/A_memorix/scripts/backfill_temporal_metadata.py new file mode 100644 index 00000000..b68820cd --- /dev/null +++ b/plugins/A_memorix/scripts/backfill_temporal_metadata.py @@ -0,0 +1,73 @@ +#!/usr/bin/env python3 +""" +回填段落时序字段。 + +默认策略: +1. 若段落缺失 event_time/event_time_start/event_time_end +2. 且存在 created_at +3. 
写入 event_time=created_at, time_granularity=day, time_confidence=0.2 +""" + +from __future__ import annotations + +import argparse +from pathlib import Path +import sys + + +CURRENT_DIR = Path(__file__).resolve().parent +PLUGIN_ROOT = CURRENT_DIR.parent +PROJECT_ROOT = PLUGIN_ROOT.parent.parent +sys.path.insert(0, str(PROJECT_ROOT)) + +from plugins.A_memorix.core.storage import MetadataStore # noqa: E402 + + +def backfill( + data_dir: Path, + dry_run: bool, + limit: int, + no_created_fallback: bool, +) -> int: + store = MetadataStore(data_dir=data_dir) + store.connect() + summary = store.backfill_temporal_metadata_from_created_at( + limit=limit, + dry_run=dry_run, + no_created_fallback=no_created_fallback, + ) + store.close() + if dry_run: + print(f"[dry-run] candidates={summary['candidates']}") + return int(summary["candidates"]) + if no_created_fallback: + print(f"skip update (no-created-fallback), candidates={summary['candidates']}") + return 0 + print(f"updated={summary['updated']}") + return int(summary["updated"]) + + +def main() -> int: + parser = argparse.ArgumentParser(description="Backfill temporal metadata for A_Memorix paragraphs") + parser.add_argument("--data-dir", default=str(PLUGIN_ROOT / "data"), help="数据目录") + parser.add_argument("--dry-run", action="store_true", help="仅统计,不写入") + parser.add_argument("--limit", type=int, default=100000, help="最大处理条数") + parser.add_argument( + "--no-created-fallback", + action="store_true", + help="不使用 created_at 回填,仅输出候选数量", + ) + args = parser.parse_args() + + backfill( + data_dir=Path(args.data_dir), + dry_run=args.dry_run, + limit=max(1, int(args.limit)), + no_created_fallback=args.no_created_fallback, + ) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) + diff --git a/plugins/A_memorix/scripts/convert_lpmm.py b/plugins/A_memorix/scripts/convert_lpmm.py index 5ff284fb..2ef0b396 100644 --- a/plugins/A_memorix/scripts/convert_lpmm.py +++ b/plugins/A_memorix/scripts/convert_lpmm.py @@ -46,9 +46,14 @@ if any(arg in {"-h", "--help"} for arg in sys.argv[1:]): _build_arg_parser().print_help() sys.exit(0) -# 设置日志 -logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") -logger = logging.getLogger("LPMM_Converter") +# 设置日志:优先复用 MaiBot 统一日志体系,失败时回退到标准 logging。 +try: + from src.common.logger import get_logger + + logger = get_logger("A_Memorix.LPMMConverter") +except Exception: + logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") + logger = logging.getLogger("A_Memorix.LPMMConverter") try: import networkx as nx @@ -225,11 +230,11 @@ class LPMMConverter: failed += 1 logger.info( - "关系向量重建完成: total=%s success=%s skipped=%s failed=%s", - len(rows), - success, - skipped, - failed, + "关系向量重建完成: " + f"total={len(rows)} " + f"success={success} " + f"skipped={skipped} " + f"failed={failed}" ) @staticmethod @@ -317,8 +322,8 @@ class LPMMConverter: if p_type == "relation": relation_count = self._import_relation_metadata_from_parquet(p_path) logger.warning( - "跳过 relation.parquet 向量导入(保持一致性);已导入关系元数据: %s", - relation_count, + "跳过 relation.parquet 向量导入(保持一致性);" + f"已导入关系元数据: {relation_count}" ) continue diff --git a/plugins/A_memorix/scripts/import_lpmm_json.py b/plugins/A_memorix/scripts/import_lpmm_json.py new file mode 100644 index 00000000..2e458e16 --- /dev/null +++ b/plugins/A_memorix/scripts/import_lpmm_json.py @@ -0,0 +1,172 @@ +#!/usr/bin/env python3 +""" +LPMM OpenIE JSON 导入工具。 + +功能: +1. 读取符合 LPMM 规范的 OpenIE JSON 文件 +2. 
转换为 A_Memorix 的统一导入格式 +3. 复用 `process_knowledge.py` 中的 `AutoImporter` 直接入库 +""" + +from __future__ import annotations + +import argparse +import asyncio +import json +import sys +import traceback +from pathlib import Path +from typing import Any, Dict, List + +from rich.console import Console +from rich.progress import BarColumn, Progress, SpinnerColumn, TextColumn, TimeElapsedColumn + +console = Console() + +CURRENT_DIR = Path(__file__).resolve().parent +PLUGIN_ROOT = CURRENT_DIR.parent +WORKSPACE_ROOT = PLUGIN_ROOT.parent +MAIBOT_ROOT = WORKSPACE_ROOT / "MaiBot" +for path in (CURRENT_DIR, WORKSPACE_ROOT, MAIBOT_ROOT, PLUGIN_ROOT): + path_str = str(path) + if path_str not in sys.path: + sys.path.insert(0, path_str) + + +def _build_arg_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="将 LPMM OpenIE JSON 导入 A_Memorix") + parser.add_argument("path", help="LPMM JSON 文件路径或目录") + parser.add_argument("--force", action="store_true", help="强制重新导入") + parser.add_argument("--concurrency", "-c", type=int, default=5, help="并发数") + return parser + + +if any(arg in {"-h", "--help"} for arg in sys.argv[1:]): + _build_arg_parser().print_help() + raise SystemExit(0) + + +try: + from process_knowledge import AutoImporter + from A_memorix.core.utils.hash import compute_paragraph_hash + from src.common.logger import get_logger +except ImportError as exc: # pragma: no cover - script bootstrap + print(f"导入模块失败,请确认 PYTHONPATH 与工作区结构: {exc}") + raise SystemExit(1) + + +logger = get_logger("A_Memorix.LPMMImport") + + +class LPMMConverter: + def convert_lpmm_to_memorix(self, lpmm_data: Dict[str, Any], filename: str) -> Dict[str, Any]: + memorix_data = {"paragraphs": [], "entities": []} + docs = lpmm_data.get("docs", []) or [] + if not docs: + logger.warning(f"文件中未找到 docs 字段: {filename}") + return memorix_data + + all_entities = set() + for doc in docs: + content = str(doc.get("passage", "") or "").strip() + if not content: + continue + + relations: List[Dict[str, str]] = [] + for triple in doc.get("extracted_triples", []) or []: + if isinstance(triple, list) and len(triple) == 3: + relations.append( + { + "subject": str(triple[0] or "").strip(), + "predicate": str(triple[1] or "").strip(), + "object": str(triple[2] or "").strip(), + } + ) + + entities = [str(item or "").strip() for item in doc.get("extracted_entities", []) or [] if str(item or "").strip()] + all_entities.update(entities) + for relation in relations: + if relation["subject"]: + all_entities.add(relation["subject"]) + if relation["object"]: + all_entities.add(relation["object"]) + + memorix_data["paragraphs"].append( + { + "hash": compute_paragraph_hash(content), + "content": content, + "source": filename, + "entities": entities, + "relations": relations, + } + ) + + memorix_data["entities"] = sorted(all_entities) + return memorix_data + + +async def main() -> None: + parser = _build_arg_parser() + args = parser.parse_args() + + target_path = Path(args.path) + if not target_path.exists(): + logger.error(f"路径不存在: {target_path}") + return + + if target_path.is_dir(): + files_to_process = list(target_path.glob("*-openie.json")) or list(target_path.glob("*.json")) + else: + files_to_process = [target_path] + + if not files_to_process: + logger.error("未找到可处理的 JSON 文件") + return + + importer = AutoImporter(force=bool(args.force), concurrency=int(args.concurrency)) + if not await importer.initialize(): + logger.error("初始化存储失败") + return + + converter = LPMMConverter() + with Progress( + SpinnerColumn(), + 
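+        # 进度按文件建任务,total 为段落数;progress_callback 按处理条目推进(见下方 update_progress)。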
TextColumn("[progress.description]{task.description}"), + BarColumn(), + TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), + TimeElapsedColumn(), + console=console, + transient=False, + ) as progress: + for json_file in files_to_process: + logger.info(f"正在转换并导入: {json_file.name}") + try: + with open(json_file, "r", encoding="utf-8") as handle: + lpmm_data = json.load(handle) + memorix_data = converter.convert_lpmm_to_memorix(lpmm_data, json_file.name) + total_items = len(memorix_data.get("paragraphs", [])) + if total_items <= 0: + logger.warning(f"转换结果为空: {json_file.name}") + continue + + task_id = progress.add_task(f"Importing {json_file.name}", total=total_items) + + def update_progress(step: int = 1) -> None: + progress.advance(task_id, advance=step) + + await importer.import_json_data( + memorix_data, + filename=f"lpmm_{json_file.name}", + progress_callback=update_progress, + ) + except Exception as exc: + logger.error(f"处理文件 {json_file.name} 失败: {exc}\n{traceback.format_exc()}") + + await importer.close() + logger.info("全部处理完成") + + +if __name__ == "__main__": + if sys.platform == "win32": # pragma: no cover + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) + asyncio.run(main()) diff --git a/plugins/A_memorix/scripts/migrate_maibot_memory.py b/plugins/A_memorix/scripts/migrate_maibot_memory.py new file mode 100644 index 00000000..0b26a9cd --- /dev/null +++ b/plugins/A_memorix/scripts/migrate_maibot_memory.py @@ -0,0 +1,1714 @@ +#!/usr/bin/env python3 +""" +MaiBot 记忆迁移脚本(chat_history -> A_memorix) + +特性: +1. 高性能:分页读取 + 批量 embedding + 批量写入 +2. 断点续传:基于 last_committed_id 的窗口提交 +3. 精确一次语义:稳定哈希 + 幂等写入 + 向量存在性检查 +4. 可确认筛选:支持时间区间、聊天流(stream/group/user)筛选,并先预览后确认 +""" + +from __future__ import annotations + +import argparse +import asyncio +import hashlib +import importlib +import json +import logging +import os +import pickle +import sqlite3 +import sys +import time +import traceback +import types +from collections import defaultdict +from dataclasses import dataclass +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, Generator, Iterable, List, Optional, Sequence, Tuple + +import numpy as np +import tomlkit + + +CURRENT_DIR = Path(__file__).resolve().parent +PLUGIN_ROOT = CURRENT_DIR.parent +WORKSPACE_ROOT = PLUGIN_ROOT.parent +MAIBOT_ROOT = WORKSPACE_ROOT / "MaiBot" +RUNTIME_CORE_PACKAGE = "_a_memorix_runtime_core" + +VectorStore = None +GraphStore = None +MetadataStore = None +create_embedding_api_adapter = None +KnowledgeType = None +QuantizationType = None +SparseMatrixFormat = None +compute_hash = None +normalize_text = None +atomic_write = None +model_config = None +RelationWriteService = None + + +def _create_bootstrap_logger(): + fallback = logging.getLogger("A_Memorix.MaiBotMigration") + if not fallback.handlers: + fallback.addHandler(logging.NullHandler()) + try: + for path in (WORKSPACE_ROOT, MAIBOT_ROOT, PLUGIN_ROOT): + path_str = str(path) + if path_str not in sys.path: + sys.path.insert(0, path_str) + from src.common.logger import get_logger + + return get_logger("A_Memorix.MaiBotMigration") + except Exception: + return fallback + + +logger = _create_bootstrap_logger() + + +def _ensure_import_paths() -> None: + for path in (WORKSPACE_ROOT, MAIBOT_ROOT, PLUGIN_ROOT): + path_str = str(path) + if path_str not in sys.path: + sys.path.insert(0, path_str) + + +def _ensure_runtime_core_package() -> str: + existing = sys.modules.get(RUNTIME_CORE_PACKAGE) + if existing is not None and hasattr(existing, "__path__"): + 
return RUNTIME_CORE_PACKAGE + + pkg = types.ModuleType(RUNTIME_CORE_PACKAGE) + pkg.__path__ = [str(PLUGIN_ROOT / "core")] + pkg.__package__ = RUNTIME_CORE_PACKAGE + sys.modules[RUNTIME_CORE_PACKAGE] = pkg + return RUNTIME_CORE_PACKAGE + + +def _disable_unavailable_gemini_provider() -> None: + global model_config + try: + from google import genai # type: ignore # noqa: F401 + return + except Exception: + pass + + from src.config.config import model_config as loaded_model_config + + providers = list(getattr(loaded_model_config, "api_providers", [])) + if not providers: + model_config = loaded_model_config + return + + kept_providers = [p for p in providers if str(getattr(p, "client_type", "")).lower() != "gemini"] + if len(kept_providers) == len(providers): + model_config = loaded_model_config + return + + loaded_model_config.api_providers = kept_providers + loaded_model_config.api_providers_dict = {p.name: p for p in kept_providers} + + models = list(getattr(loaded_model_config, "models", [])) + kept_models = [m for m in models if m.api_provider in loaded_model_config.api_providers_dict] + loaded_model_config.models = kept_models + loaded_model_config.models_dict = {m.name: m for m in kept_models} + + task_cfg = loaded_model_config.model_task_config + for field_name in task_cfg.__dataclass_fields__.keys(): + task = getattr(task_cfg, field_name, None) + if task is None or not hasattr(task, "model_list"): + continue + task.model_list = [m for m in list(task.model_list) if m in loaded_model_config.models_dict] + + model_config = loaded_model_config + logger.warning("检测到缺少 google.genai,已临时禁用 gemini provider 以保证脚本可运行。") + + +def _bootstrap_runtime_symbols() -> None: + global VectorStore + global GraphStore + global MetadataStore + global KnowledgeType + global QuantizationType + global SparseMatrixFormat + global compute_hash + global normalize_text + global atomic_write + global RelationWriteService + global logger + + if VectorStore is not None and compute_hash is not None and atomic_write is not None: + return + + _ensure_import_paths() + + import src # noqa: F401 + from src.common.logger import get_logger + + logger = get_logger("A_Memorix.MaiBotMigration") + + pkg = _ensure_runtime_core_package() + + vector_store_module = importlib.import_module(f"{pkg}.storage.vector_store") + graph_store_module = importlib.import_module(f"{pkg}.storage.graph_store") + metadata_store_module = importlib.import_module(f"{pkg}.storage.metadata_store") + knowledge_types_module = importlib.import_module(f"{pkg}.storage.knowledge_types") + hash_module = importlib.import_module(f"{pkg}.utils.hash") + io_module = importlib.import_module(f"{pkg}.utils.io") + relation_write_service_module = importlib.import_module(f"{pkg}.utils.relation_write_service") + + VectorStore = vector_store_module.VectorStore + GraphStore = graph_store_module.GraphStore + MetadataStore = metadata_store_module.MetadataStore + KnowledgeType = knowledge_types_module.KnowledgeType + QuantizationType = vector_store_module.QuantizationType + SparseMatrixFormat = graph_store_module.SparseMatrixFormat + compute_hash = hash_module.compute_hash + normalize_text = hash_module.normalize_text + atomic_write = io_module.atomic_write + RelationWriteService = relation_write_service_module.RelationWriteService + + +def _load_embedding_adapter_factory() -> None: + global create_embedding_api_adapter + global model_config + + if create_embedding_api_adapter is not None: + return + + _ensure_import_paths() + + from src.config.config import model_config as 
loaded_model_config + + model_config = loaded_model_config + _disable_unavailable_gemini_provider() + + pkg = _ensure_runtime_core_package() + api_adapter_module = importlib.import_module(f"{pkg}.embedding.api_adapter") + create_embedding_api_adapter = api_adapter_module.create_embedding_api_adapter + + +DEFAULT_SOURCE_DB = MAIBOT_ROOT / "data" / "MaiBot.db" +DEFAULT_TARGET_DATA_DIR = PLUGIN_ROOT / "data" +DEFAULT_CONFIG_PATH = PLUGIN_ROOT / "config.toml" + +MIGRATION_STATE_DIRNAME = "migration_state" +STATE_FILENAME = "chat_history_resume.json" +BAD_ROWS_FILENAME = "chat_history_bad_rows.jsonl" +REPORT_FILENAME = "chat_history_report.json" + + +class MigrationError(Exception): + """迁移流程错误。""" + + +@dataclass +class SelectionFilter: + time_from_ts: Optional[float] + time_to_ts: Optional[float] + stream_ids: List[str] + stream_filter_requested: bool + start_id: Optional[int] + end_id: Optional[int] + time_from_raw: Optional[str] + time_to_raw: Optional[str] + + def fingerprint_payload(self) -> Dict[str, Any]: + return { + "time_from_ts": self.time_from_ts, + "time_to_ts": self.time_to_ts, + "time_from_raw": self.time_from_raw, + "time_to_raw": self.time_to_raw, + "stream_ids": sorted(self.stream_ids), + "stream_filter_requested": self.stream_filter_requested, + "start_id": self.start_id, + "end_id": self.end_id, + } + + +@dataclass +class PreviewResult: + total: int + distribution: List[Tuple[str, int]] + samples: List[Dict[str, Any]] + + +@dataclass +class MappedRow: + row_id: int + chat_id: str + paragraph_hash: str + content: str + source: str + time_meta: Dict[str, Any] + entities: List[str] + relations: List[Tuple[str, str, str]] + existing_paragraph_vector: bool + + +def _safe_int(value: Any, default: int) -> int: + try: + return int(value) + except Exception: + return default + + +def _safe_float(value: Any, default: float) -> float: + try: + return float(value) + except Exception: + return default + + +def _normalize_name(value: Any) -> str: + return str(value or "").strip() + + +def _canonical_name(value: Any) -> str: + return _normalize_name(value).lower() + + +def _dedup_keep_order(items: Iterable[str]) -> List[str]: + out: List[str] = [] + seen: set[str] = set() + for raw in items: + v = _normalize_name(raw) + if not v: + continue + k = v.lower() + if k in seen: + continue + seen.add(k) + out.append(v) + return out + + +def _format_ts(ts: Optional[float]) -> str: + if ts is None: + return "-" + try: + return datetime.fromtimestamp(float(ts)).strftime("%Y-%m-%d %H:%M:%S") + except Exception: + return str(ts) + + +def _parse_cli_datetime(text: str, is_end: bool = False) -> float: + value = str(text or "").strip() + if not value: + raise ValueError("时间不能为空") + + formats = [ + ("%Y-%m-%d %H:%M:%S", False), + ("%Y/%m/%d %H:%M:%S", False), + ("%Y-%m-%d %H:%M", False), + ("%Y/%m/%d %H:%M", False), + ("%Y-%m-%d", True), + ("%Y/%m/%d", True), + ] + + for fmt, is_date_only in formats: + try: + dt = datetime.strptime(value, fmt) + if is_date_only and is_end: + dt = dt.replace(hour=23, minute=59, second=59, microsecond=0) + return dt.timestamp() + except ValueError: + continue + + raise ValueError( + f"时间格式错误: {value},仅支持 YYYY-MM-DD、YYYY/MM/DD、YYYY-MM-DD HH:mm[:ss]、YYYY/MM/DD HH:mm[:ss]" + ) + + +def _json_hash(payload: Dict[str, Any]) -> str: + data = json.dumps(payload, ensure_ascii=False, sort_keys=True) + return hashlib.sha1(data.encode("utf-8")).hexdigest() + + +def _deep_merge_dict(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]: + out = dict(base) + for key, value 
in override.items(): + if isinstance(value, dict) and isinstance(out.get(key), dict): + out[key] = _deep_merge_dict(out[key], value) + else: + out[key] = value + return out + + +def _extract_schema_defaults(schema_obj: Dict[str, Any]) -> Dict[str, Any]: + defaults: Dict[str, Any] = {} + if not isinstance(schema_obj, dict): + return defaults + + for key, spec in schema_obj.items(): + if not isinstance(spec, dict): + continue + if "default" in spec: + defaults[key] = spec.get("default") + continue + props = spec.get("properties") + if isinstance(props, dict): + defaults[key] = _extract_schema_defaults(props) + return defaults + + +def _load_manifest_defaults() -> Dict[str, Any]: + manifest_path = PLUGIN_ROOT / "_manifest.json" + if not manifest_path.exists(): + return {} + try: + with open(manifest_path, "r", encoding="utf-8") as f: + payload = json.load(f) + schema = payload.get("config_schema") + if isinstance(schema, dict): + return _extract_schema_defaults(schema) + except Exception as e: + logger.warning(f"读取 manifest 默认配置失败,已回退空配置: {e}") + return {} + + +def _build_source_db_fingerprint(db_path: Path) -> Dict[str, Any]: + stat = db_path.stat() + payload = { + "path": str(db_path.resolve()), + "size": stat.st_size, + "mtime": stat.st_mtime, + } + payload["sha1"] = _json_hash(payload) + return payload + + +def _state_path(target_data_dir: Path) -> Path: + return target_data_dir / MIGRATION_STATE_DIRNAME / STATE_FILENAME + + +def _bad_rows_path(target_data_dir: Path) -> Path: + return target_data_dir / MIGRATION_STATE_DIRNAME / BAD_ROWS_FILENAME + + +def _report_path(target_data_dir: Path) -> Path: + return target_data_dir / MIGRATION_STATE_DIRNAME / REPORT_FILENAME + + +def _dump_json_atomic(path: Path, payload: Dict[str, Any]) -> None: + if atomic_write is None: + path.parent.mkdir(parents=True, exist_ok=True) + tmp = path.with_suffix(path.suffix + ".tmp") + with open(tmp, "w", encoding="utf-8") as f: + json.dump(payload, f, ensure_ascii=False, indent=2) + f.write("\n") + f.flush() + os.fsync(f.fileno()) + os.replace(tmp, path) + return + + with atomic_write(path, mode="w", encoding="utf-8") as f: + json.dump(payload, f, ensure_ascii=False, indent=2) + f.write("\n") + + +class SourceDB: + def __init__(self, db_path: Path): + self.db_path = db_path + self.conn: Optional[sqlite3.Connection] = None + + def connect(self) -> None: + if not self.db_path.exists(): + raise MigrationError(f"源数据库不存在: {self.db_path}") + + uri = f"file:{self.db_path.resolve().as_posix()}?mode=ro" + try: + self.conn = sqlite3.connect(uri, uri=True, check_same_thread=False) + except sqlite3.OperationalError: + self.conn = sqlite3.connect(str(self.db_path.resolve()), check_same_thread=False) + + self.conn.row_factory = sqlite3.Row + pragmas = [ + "PRAGMA query_only = ON", + "PRAGMA cache_size = -128000", + "PRAGMA temp_store = MEMORY", + "PRAGMA synchronous = OFF", + "PRAGMA journal_mode = WAL", + ] + for sql in pragmas: + try: + self.conn.execute(sql) + except sqlite3.OperationalError: + # 部分 PRAGMA 在 mode=ro 下会失败,不影响只读扫描能力 + continue + + def close(self) -> None: + if self.conn is not None: + self.conn.close() + self.conn = None + + def _require_conn(self) -> sqlite3.Connection: + if self.conn is None: + raise MigrationError("源数据库尚未连接") + return self.conn + + def resolve_stream_ids( + self, + stream_ids: Sequence[str], + group_ids: Sequence[str], + user_ids: Sequence[str], + ) -> List[str]: + conn = self._require_conn() + resolved: set[str] = set(_normalize_name(x) for x in stream_ids if _normalize_name(x)) + 
has_group_or_user = any(_normalize_name(x) for x in group_ids) or any(_normalize_name(x) for x in user_ids) + if not has_group_or_user: + return sorted(resolved) + + table_exists = conn.execute( + "SELECT 1 FROM sqlite_master WHERE type='table' AND name='chat_streams' LIMIT 1" + ).fetchone() + if table_exists is None: + raise MigrationError("源库缺少 chat_streams 表,无法根据 --group-id/--user-id 映射 stream_id") + + def _select_by_field(field: str, values: Sequence[str]) -> None: + values_norm = [_normalize_name(v) for v in values if _normalize_name(v)] + if not values_norm: + return + placeholders = ",".join("?" for _ in values_norm) + sql = f"SELECT DISTINCT stream_id FROM chat_streams WHERE {field} IN ({placeholders})" + cur = conn.execute(sql, tuple(values_norm)) + for row in cur.fetchall(): + sid = _normalize_name(row["stream_id"]) + if sid: + resolved.add(sid) + + _select_by_field("group_id", group_ids) + _select_by_field("user_id", user_ids) + return sorted(resolved) + + @staticmethod + def _build_where( + selection: SelectionFilter, + start_after_id: Optional[int] = None, + ) -> Tuple[str, List[Any]]: + conditions: List[str] = [] + params: List[Any] = [] + + if selection.start_id is not None: + conditions.append("id >= ?") + params.append(selection.start_id) + if selection.end_id is not None: + conditions.append("id <= ?") + params.append(selection.end_id) + if start_after_id is not None: + conditions.append("id > ?") + params.append(start_after_id) + + if selection.stream_ids: + placeholders = ",".join("?" for _ in selection.stream_ids) + conditions.append(f"chat_id IN ({placeholders})") + params.extend(selection.stream_ids) + elif selection.stream_filter_requested: + conditions.append("1=0") + + if selection.time_from_ts is not None and selection.time_to_ts is not None: + conditions.append("(end_time >= ? AND start_time <= ?)") + params.extend([selection.time_from_ts, selection.time_to_ts]) + elif selection.time_from_ts is not None: + conditions.append("(end_time >= ?)") + params.append(selection.time_from_ts) + elif selection.time_to_ts is not None: + conditions.append("(start_time <= ?)") + params.append(selection.time_to_ts) + + where_sql = "WHERE " + " AND ".join(conditions) if conditions else "" + return where_sql, params + + def count_candidates(self, selection: SelectionFilter) -> int: + conn = self._require_conn() + where_sql, params = self._build_where(selection, start_after_id=None) + sql = f"SELECT COUNT(*) AS c FROM chat_history {where_sql}" + cur = conn.execute(sql, tuple(params)) + return int(cur.fetchone()["c"]) + + def preview(self, selection: SelectionFilter, preview_limit: int) -> PreviewResult: + conn = self._require_conn() + where_sql, params = self._build_where(selection, start_after_id=None) + + total_sql = f"SELECT COUNT(*) AS c FROM chat_history {where_sql}" + total = int(conn.execute(total_sql, tuple(params)).fetchone()["c"]) + + dist_sql = ( + f"SELECT chat_id, COUNT(*) AS c FROM chat_history {where_sql} " + "GROUP BY chat_id ORDER BY c DESC LIMIT 30" + ) + distribution = [ + (_normalize_name(row["chat_id"]), int(row["c"])) + for row in conn.execute(dist_sql, tuple(params)).fetchall() + ] + + sample_sql = ( + "SELECT id, chat_id, start_time, end_time, theme, summary " + f"FROM chat_history {where_sql} ORDER BY id ASC LIMIT ?" 
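+            # id ASC matches the keyset scan in iter_rows(), so the preview shows
+            # exactly the rows migration would process first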
+ ) + sample_params = list(params) + sample_params.append(max(1, int(preview_limit))) + samples = [dict(row) for row in conn.execute(sample_sql, tuple(sample_params)).fetchall()] + + return PreviewResult(total=total, distribution=distribution, samples=samples) + + def iter_rows( + self, + selection: SelectionFilter, + batch_size: int, + start_after_id: int, + ) -> Generator[List[sqlite3.Row], None, None]: + conn = self._require_conn() + cursor = int(start_after_id) + while True: + where_sql, params = self._build_where(selection, start_after_id=cursor) + sql = ( + "SELECT id, chat_id, start_time, end_time, participants, theme, keywords, summary " + f"FROM chat_history {where_sql} ORDER BY id ASC LIMIT ?" + ) + bind = list(params) + bind.append(max(1, int(batch_size))) + rows = conn.execute(sql, tuple(bind)).fetchall() + if not rows: + break + yield rows + cursor = int(rows[-1]["id"]) + + def sample_rows_for_verify( + self, + selection: SelectionFilter, + sample_size: int, + ) -> List[sqlite3.Row]: + conn = self._require_conn() + where_sql, params = self._build_where(selection, start_after_id=None) + sql = ( + "SELECT id, chat_id, start_time, end_time, participants, theme, keywords, summary " + f"FROM chat_history {where_sql} ORDER BY RANDOM() LIMIT ?" + ) + bind = list(params) + bind.append(max(1, int(sample_size))) + return conn.execute(sql, tuple(bind)).fetchall() + + +class MigrationRunner: + def __init__(self, args: argparse.Namespace): + self.args = args + self.source_db_path = Path(args.source_db).resolve() + self.target_data_dir = Path(args.target_data_dir).resolve() + self.state_file = _state_path(self.target_data_dir) + self.bad_rows_file = _bad_rows_path(self.target_data_dir) + self.report_file = _report_path(self.target_data_dir) + + self.source_db = SourceDB(self.source_db_path) + + self.vector_store = None + self.graph_store = None + self.metadata_store = None + self.embedding_manager = None + self.relation_write_service = None + self.plugin_config: Dict[str, Any] = {} + self.embed_workers: int = 5 + + self.selection: Optional[SelectionFilter] = None + self.filter_fingerprint: str = "" + self.source_db_fingerprint: Dict[str, Any] = {} + self.source_db_fingerprint_hash: str = "" + self.state: Dict[str, Any] = {} + + self.started_at = time.time() + self.exit_code = 0 + self.failed = False + self.fail_reason: Optional[str] = None + + self.stats: Dict[str, Any] = { + "source_matched_total": 0, + "scanned_rows": 0, + "valid_rows": 0, + "migrated_rows": 0, + "skipped_existing_rows": 0, + "bad_rows": 0, + "paragraph_vectors_added": 0, + "entity_vectors_added": 0, + "relations_written": 0, + "relation_vectors_written": 0, + "relation_vectors_failed": 0, + "relation_vectors_skipped": 0, + "graph_edges_written": 0, + "windows_committed": 0, + "last_committed_id": 0, + "verify_sample_size": 0, + "verify_paragraph_missing": 0, + "verify_vector_missing": 0, + "verify_relation_missing": 0, + "verify_edge_missing": 0, + "verify_passed": False, + } + + async def run(self) -> int: + try: + _bootstrap_runtime_symbols() + self._prepare_paths() + + self.source_db.connect() + self.selection = self._build_selection_filter() + self.filter_fingerprint = _json_hash(self.selection.fingerprint_payload()) + + self.source_db_fingerprint = _build_source_db_fingerprint(self.source_db_path) + self.source_db_fingerprint_hash = str(self.source_db_fingerprint.get("sha1", "")) + + preview = self.source_db.preview(self.selection, preview_limit=self.args.preview_limit) + self.stats["source_matched_total"] = 
int(preview.total) + self._print_preview(preview) + + if preview.total <= 0: + logger.info("筛选后无数据,退出。") + self.stats["verify_passed"] = True + if self.args.verify_only: + self._load_plugin_config() + await self._init_target_stores(require_embedding=False) + await self._verify(strict=True) + return self._finalize() + + if self.args.verify_only: + self._load_plugin_config() + await self._init_target_stores(require_embedding=False) + await self._verify(strict=True) + return self._finalize() + + if self.args.dry_run: + logger.info("dry-run 模式:仅预览,不写入。") + return self._finalize() + + if not self.args.yes: + if not self._confirm(): + logger.info("用户取消执行。") + return self._finalize() + + self._load_plugin_config() + await self._init_target_stores(require_embedding=True) + self._load_or_init_state() + + start_after_id = self._resolve_start_after_id() + await self._migrate(start_after_id=start_after_id) + await self._verify(strict=True) + return self._finalize() + except Exception as e: + self.failed = True + self.fail_reason = str(e) + logger.error(f"迁移失败: {e}\n{traceback.format_exc()}") + return self._finalize() + finally: + self._close() + + def _prepare_paths(self) -> None: + (self.target_data_dir / MIGRATION_STATE_DIRNAME).mkdir(parents=True, exist_ok=True) + if self.args.reset_state and self.state_file.exists(): + self.state_file.unlink() + if self.args.reset_state and self.bad_rows_file.exists(): + self.bad_rows_file.unlink() + + def _load_plugin_config(self) -> None: + merged = _load_manifest_defaults() + + config_path = DEFAULT_CONFIG_PATH + if config_path.exists(): + try: + with open(config_path, "r", encoding="utf-8") as f: + raw = tomlkit.load(f) + if isinstance(raw, dict): + merged = _deep_merge_dict(merged, dict(raw)) + except Exception as e: + logger.warning(f"读取插件配置失败,继续使用默认配置: {e}") + + self.plugin_config = merged + + def _read_existing_vector_dimension(self, fallback_dimension: int) -> int: + meta_path = self.target_data_dir / "vectors" / "vectors_metadata.pkl" + if not meta_path.exists(): + return fallback_dimension + try: + with open(meta_path, "rb") as f: + payload = pickle.load(f) + value = _safe_int(payload.get("dimension"), fallback_dimension) + return max(1, value) + except Exception: + return fallback_dimension + + async def _init_target_stores(self, require_embedding: bool) -> None: + if VectorStore is None or GraphStore is None or MetadataStore is None: + raise MigrationError("运行时初始化失败:存储组件不可用") + + emb_cfg = self.plugin_config.get("embedding", {}) if isinstance(self.plugin_config, dict) else {} + graph_cfg = self.plugin_config.get("graph", {}) if isinstance(self.plugin_config, dict) else {} + + self.embed_workers = max(1, _safe_int(self.args.embed_workers, _safe_int(emb_cfg.get("max_concurrent"), 5))) + emb_batch_size = max(1, _safe_int(emb_cfg.get("batch_size"), 32)) + emb_default_dim = max(1, _safe_int(emb_cfg.get("dimension"), 1024)) + emb_model_name = str(emb_cfg.get("model_name", "auto")) + emb_retry = emb_cfg.get("retry", {}) if isinstance(emb_cfg.get("retry", {}), dict) else {} + + if require_embedding: + _load_embedding_adapter_factory() + if create_embedding_api_adapter is None: + raise MigrationError("运行时初始化失败:embedding 适配器不可用") + + if model_config is not None: + embedding_task = getattr(getattr(model_config, "model_task_config", None), "embedding", None) + if embedding_task is not None and hasattr(embedding_task, "model_list"): + if not list(embedding_task.model_list): + raise MigrationError( + "当前配置没有可用 embedding 模型。若你使用 gemini provider,请先安装 `google-genai` " 
+ "或切换到可用的 embedding provider。" + ) + + self.embedding_manager = create_embedding_api_adapter( + batch_size=emb_batch_size, + max_concurrent=self.embed_workers, + default_dimension=emb_default_dim, + model_name=emb_model_name, + retry_config=emb_retry, + ) + + try: + detected_dim = self._read_existing_vector_dimension(emb_default_dim) + has_existing_vectors = (self.target_data_dir / "vectors" / "vectors_metadata.pkl").exists() + if not has_existing_vectors: + detected_dim = await self.embedding_manager._detect_dimension() + except Exception as e: + logger.warning(f"嵌入维度探测失败,回退配置维度: {e}") + detected_dim = self._read_existing_vector_dimension(emb_default_dim) + else: + detected_dim = self._read_existing_vector_dimension(emb_default_dim) + self.embedding_manager = None + + q_type = str(emb_cfg.get("quantization_type", "int8")).lower() + if q_type != "int8": + raise MigrationError( + "embedding.quantization_type 在 vNext 仅允许 int8(SQ8)。" + " 请先执行 scripts/release_vnext_migrate.py migrate。" + ) + quantization = QuantizationType.INT8 + + matrix_fmt = str(graph_cfg.get("sparse_matrix_format", "csr")).lower() + fmt_map = { + "csr": SparseMatrixFormat.CSR, + "csc": SparseMatrixFormat.CSC, + } + sparse_fmt = fmt_map.get(matrix_fmt, SparseMatrixFormat.CSR) + + self.vector_store = VectorStore( + dimension=detected_dim, + quantization_type=quantization, + data_dir=self.target_data_dir / "vectors", + ) + self.graph_store = GraphStore( + matrix_format=sparse_fmt, + data_dir=self.target_data_dir / "graph", + ) + self.metadata_store = MetadataStore(data_dir=self.target_data_dir / "metadata") + self.metadata_store.connect() + + if self.vector_store.has_data(): + self.vector_store.load() + if self.graph_store.has_data(): + self.graph_store.load() + + self.relation_write_service = None + if require_embedding and RelationWriteService is not None and self.embedding_manager is not None: + self.relation_write_service = RelationWriteService( + metadata_store=self.metadata_store, + graph_store=self.graph_store, + vector_store=self.vector_store, + embedding_manager=self.embedding_manager, + ) + + logger.info( + f"目标存储初始化完成: dim={self.vector_store.dimension}, quant={q_type}, graph_fmt={matrix_fmt}, " + f"embed_workers={self.embed_workers}" + ) + + def _should_write_relation_vectors(self) -> bool: + retrieval_cfg = self.plugin_config.get("retrieval", {}) if isinstance(self.plugin_config, dict) else {} + if not isinstance(retrieval_cfg, dict): + return False + rv_cfg = retrieval_cfg.get("relation_vectorization", {}) + if not isinstance(rv_cfg, dict): + return False + return bool(rv_cfg.get("enabled", False)) and bool(rv_cfg.get("write_on_import", True)) + + async def _ensure_relation_vectors_for_records( + self, + relation_records: Dict[str, Tuple[str, str, str, float, Optional[str], bytes]], + ) -> None: + if not relation_records: + return + if self.relation_write_service is None: + return + + success = 0 + failed = 0 + skipped = 0 + for relation_hash, rel in relation_records.items(): + result = await self.relation_write_service.ensure_relation_vector( + hash_value=relation_hash, + subject=str(rel[0]), + predicate=str(rel[1]), + obj=str(rel[2]), + ) + if result.vector_state == "ready": + if result.vector_written: + success += 1 + else: + skipped += 1 + else: + failed += 1 + + self.stats["relation_vectors_written"] += success + self.stats["relation_vectors_failed"] += failed + self.stats["relation_vectors_skipped"] += skipped + + def _build_selection_filter(self) -> SelectionFilter: + if self.args.start_id is not None and 
self.args.start_id <= 0: + raise MigrationError("--start-id 必须 > 0") + if self.args.end_id is not None and self.args.end_id <= 0: + raise MigrationError("--end-id 必须 > 0") + if self.args.start_id is not None and self.args.end_id is not None and self.args.start_id > self.args.end_id: + raise MigrationError("--start-id 不能大于 --end-id") + + time_from_ts = _parse_cli_datetime(self.args.time_from, is_end=False) if self.args.time_from else None + time_to_ts = _parse_cli_datetime(self.args.time_to, is_end=True) if self.args.time_to else None + if time_from_ts is not None and time_to_ts is not None and time_from_ts > time_to_ts: + raise MigrationError("--time-from 不能晚于 --time-to") + + stream_filter_requested = bool( + (self.args.stream_id or []) or (self.args.group_id or []) or (self.args.user_id or []) + ) + stream_ids = self.source_db.resolve_stream_ids( + stream_ids=self.args.stream_id or [], + group_ids=self.args.group_id or [], + user_ids=self.args.user_id or [], + ) + if stream_filter_requested and not stream_ids: + logger.warning("已指定 stream/group/user 筛选,但未解析到任何 stream_id,结果将为空。") + + logger.info( + f"筛选条件: time_from={self.args.time_from or '-'}, time_to={self.args.time_to or '-'}, " + f"stream_ids={len(stream_ids)}, stream_filter_requested={stream_filter_requested}" + ) + + return SelectionFilter( + time_from_ts=time_from_ts, + time_to_ts=time_to_ts, + stream_ids=stream_ids, + stream_filter_requested=stream_filter_requested, + start_id=self.args.start_id, + end_id=self.args.end_id, + time_from_raw=self.args.time_from, + time_to_raw=self.args.time_to, + ) + + def _load_or_init_state(self) -> None: + if self.args.start_id is not None: + logger.info("检测到 --start-id,已按用户指定起点覆盖断点状态。") + self.state = self._new_state(last_committed_id=int(self.args.start_id) - 1) + return + + if self.args.no_resume: + self.state = self._new_state(last_committed_id=0) + return + + if not self.state_file.exists(): + self.state = self._new_state(last_committed_id=0) + return + + with open(self.state_file, "r", encoding="utf-8") as f: + loaded = json.load(f) + + loaded_filter_fp = str(loaded.get("filter_fingerprint", "")) + loaded_source_fp = str(loaded.get("source_db_fingerprint", "")) + + if loaded_filter_fp != self.filter_fingerprint or loaded_source_fp != self.source_db_fingerprint_hash: + if self.args.dry_run or self.args.verify_only: + logger.info("检测到断点与当前筛选不一致;当前为只读模式,将忽略旧断点。") + self.state = self._new_state(last_committed_id=0) + return + raise MigrationError( + "检测到筛选条件或源库指纹变化,已拒绝继续续传。请使用 --reset-state 或调整参数后重试。" + ) + + self.state = loaded + stored_stats = loaded.get("stats", {}) + if isinstance(stored_stats, dict): + for k, v in stored_stats.items(): + if k in self.stats and isinstance(v, (int, float, bool)): + self.stats[k] = v + + def _new_state(self, last_committed_id: int) -> Dict[str, Any]: + return { + "version": 1, + "updated_at": time.time(), + "last_committed_id": int(last_committed_id), + "filter_fingerprint": self.filter_fingerprint, + "source_db_fingerprint": self.source_db_fingerprint_hash, + "source_db_meta": self.source_db_fingerprint, + "stats": dict(self.stats), + } + + def _flush_state(self, last_committed_id: int) -> None: + self.stats["last_committed_id"] = int(last_committed_id) + self.state = { + "version": 1, + "updated_at": time.time(), + "last_committed_id": int(last_committed_id), + "filter_fingerprint": self.filter_fingerprint, + "source_db_fingerprint": self.source_db_fingerprint_hash, + "source_db_meta": self.source_db_fingerprint, + "stats": dict(self.stats), + } + 
_dump_json_atomic(self.state_file, self.state) + + def _resolve_start_after_id(self) -> int: + if self.selection is None: + raise MigrationError("selection 未初始化") + + if self.args.start_id is not None: + return int(self.args.start_id) - 1 + + if self.args.no_resume: + return 0 + + state_last = _safe_int(self.state.get("last_committed_id"), 0) if self.state else 0 + return max(0, state_last) + + def _print_preview(self, preview: PreviewResult) -> None: + print("\n=== Migration Preview ===") + print(f"source_db: {self.source_db_path}") + print(f"target_data_dir: {self.target_data_dir}") + if self.selection: + print( + f"time_window: [{self.selection.time_from_raw or '-'} ~ {self.selection.time_to_raw or '-'}] " + f"(ts: {_format_ts(self.selection.time_from_ts)} ~ {_format_ts(self.selection.time_to_ts)})" + ) + print( + f"id_window: [{self.selection.start_id or '-'} ~ {self.selection.end_id or '-'}], " + f"selected_streams={len(self.selection.stream_ids)}" + ) + print(f"matched_rows: {preview.total}") + + if preview.distribution: + print("top_chat_distribution:") + for cid, cnt in preview.distribution[:10]: + print(f" - {cid}: {cnt}") + else: + print("top_chat_distribution: (none)") + + if preview.samples: + print(f"samples (first {len(preview.samples)}):") + for row in preview.samples: + summary_preview = _normalize_name(row.get("summary", ""))[:60] + theme_preview = _normalize_name(row.get("theme", ""))[:30] + print( + f" - id={row.get('id')} chat_id={row.get('chat_id')} " + f"[{_format_ts(row.get('start_time'))} ~ {_format_ts(row.get('end_time'))}] " + f"theme={theme_preview!r} summary={summary_preview!r}" + ) + print("=========================\n") + + def _confirm(self) -> bool: + answer = input("确认按以上筛选执行迁移?输入 y 继续 [y/N]: ").strip().lower() + return answer in {"y", "yes"} + + def _parse_json_list_field(self, raw: Any, field_name: str, row_id: int) -> List[str]: + if raw is None: + return [] + if isinstance(raw, list): + data = raw + elif isinstance(raw, str): + try: + parsed = json.loads(raw) + except Exception as e: + raise ValueError(f"{field_name} JSON 解析失败: {e}") from e + if not isinstance(parsed, list): + raise ValueError(f"{field_name} JSON 必须是 list,当前为 {type(parsed).__name__}") + data = parsed + else: + raise ValueError(f"{field_name} 字段类型不支持: {type(raw).__name__}") + return _dedup_keep_order(str(x) for x in data if _normalize_name(x)) + + def _map_row(self, row: sqlite3.Row) -> MappedRow: + row_id = int(row["id"]) + chat_id = _normalize_name(row["chat_id"]) + theme = _normalize_name(row["theme"]) + summary = _normalize_name(row["summary"]) + + participants = self._parse_json_list_field(row["participants"], "participants", row_id) + keywords = self._parse_json_list_field(row["keywords"], "keywords", row_id) + keywords_top = keywords[:8] + + participants_text = "、".join(participants) if participants else "" + keywords_text = "、".join(keywords_top) if keywords_top else "" + + content = ( + f"话题:{theme}\n" + f"概括:{summary}\n" + f"参与者:{participants_text}\n" + f"关键词:{keywords_text}" + ).strip() + + paragraph_hash = compute_hash(normalize_text(content)) + source = f"maibot.chat_history:{chat_id}" + + start_time = _safe_float(row["start_time"], 0.0) + end_time = _safe_float(row["end_time"], start_time) + time_meta = { + "event_time_start": start_time, + "event_time_end": end_time, + "time_granularity": "minute", + "time_confidence": 0.95, + } + + entities = _dedup_keep_order([*participants, theme, *keywords_top]) + relations: List[Tuple[str, str, str]] = [] + if theme: + for participant 
in participants: + relations.append((participant, "参与话题", theme)) + for keyword in keywords_top: + relations.append((theme, "关键词", keyword)) + + existing_vector = paragraph_hash in self.vector_store + return MappedRow( + row_id=row_id, + chat_id=chat_id, + paragraph_hash=paragraph_hash, + content=content, + source=source, + time_meta=time_meta, + entities=entities, + relations=relations, + existing_paragraph_vector=existing_vector, + ) + + def _append_bad_row(self, row: sqlite3.Row, reason: str) -> None: + payload = { + "id": int(row["id"]), + "chat_id": _normalize_name(row["chat_id"]), + "start_time": row["start_time"], + "end_time": row["end_time"], + "participants": row["participants"], + "theme": _normalize_name(row["theme"]), + "keywords": row["keywords"], + "summary": row["summary"], + "error": reason, + "timestamp": time.time(), + } + self.bad_rows_file.parent.mkdir(parents=True, exist_ok=True) + with open(self.bad_rows_file, "a", encoding="utf-8") as f: + f.write(json.dumps(payload, ensure_ascii=False)) + f.write("\n") + + async def _migrate(self, start_after_id: int) -> None: + if self.selection is None: + raise MigrationError("selection 未初始化") + + read_batch_size = max(1, int(self.args.read_batch_size)) + commit_window_rows = max(1, int(self.args.commit_window_rows)) + log_every = max(1, int(self.args.log_every)) + + window_rows: List[MappedRow] = [] + window_scanned = 0 + last_seen_id = start_after_id + + logger.info( + f"开始迁移: start_after_id={start_after_id}, read_batch_size={read_batch_size}, " + f"commit_window_rows={commit_window_rows}" + ) + + for batch in self.source_db.iter_rows(self.selection, read_batch_size, start_after_id): + for row in batch: + row_id = int(row["id"]) + last_seen_id = row_id + self.stats["scanned_rows"] += 1 + window_scanned += 1 + + try: + mapped = self._map_row(row) + except Exception as e: + self.stats["bad_rows"] += 1 + self._append_bad_row(row, str(e)) + if self.stats["bad_rows"] > int(self.args.max_errors): + raise MigrationError( + f"坏行数量超过上限 max_errors={self.args.max_errors},已中止。" + ) + continue + + self.stats["valid_rows"] += 1 + if mapped.existing_paragraph_vector: + self.stats["skipped_existing_rows"] += 1 + else: + self.stats["migrated_rows"] += 1 + window_rows.append(mapped) + + if window_scanned >= commit_window_rows: + await self._commit_window(window_rows, last_seen_id) + window_rows = [] + window_scanned = 0 + + if self.stats["scanned_rows"] % log_every == 0: + logger.info( + f"迁移进度: scanned={self.stats['scanned_rows']}/{self.stats['source_matched_total']}, " + f"valid={self.stats['valid_rows']}, bad={self.stats['bad_rows']}, " + f"last_id={last_seen_id}" + ) + + if window_scanned > 0 or window_rows: + await self._commit_window(window_rows, last_seen_id) + + logger.info( + f"迁移主流程完成: scanned={self.stats['scanned_rows']}, valid={self.stats['valid_rows']}, " + f"bad={self.stats['bad_rows']}, last_committed_id={self.stats['last_committed_id']}" + ) + + async def _commit_window(self, rows: List[MappedRow], last_seen_id: int) -> None: + if not rows: + self._flush_state(last_seen_id) + self.stats["windows_committed"] += 1 + return + + now_ts = time.time() + empty_meta_blob = pickle.dumps({}) + + conn = self.metadata_store.get_connection() + + cursor = conn.cursor() + + # 批量查询本窗口内已存在的段落,保证重跑时 entity/mention 不重复累计 + existing_paragraph_hashes: set[str] = set() + all_hashes = [item.paragraph_hash for item in rows] + for i in range(0, len(all_hashes), 800): + batch_hashes = all_hashes[i : i + 800] + if not batch_hashes: + continue + 
placeholders = ",".join("?" for _ in batch_hashes) + existing_rows = cursor.execute( + f"SELECT hash FROM paragraphs WHERE hash IN ({placeholders})", + tuple(batch_hashes), + ).fetchall() + for row in existing_rows: + existing_paragraph_hashes.add(str(row["hash"])) + + paragraph_records: List[Tuple[Any, ...]] = [] + paragraph_embed_map: Dict[str, str] = {} + + entity_display: Dict[str, str] = {} + entity_counts: Dict[str, int] = defaultdict(int) + paragraph_entity_mentions: Dict[Tuple[str, str], int] = defaultdict(int) + entity_embed_map: Dict[str, str] = {} + + relation_records: Dict[str, Tuple[str, str, str, float, Optional[str], bytes]] = {} + paragraph_relation_links: set[Tuple[str, str]] = set() + + for item in rows: + is_new_paragraph = item.paragraph_hash not in existing_paragraph_hashes + + start_ts = _safe_float(item.time_meta.get("event_time_start"), 0.0) + end_ts = _safe_float(item.time_meta.get("event_time_end"), start_ts) + confidence = _safe_float(item.time_meta.get("time_confidence"), 0.95) + granularity = _normalize_name(item.time_meta.get("time_granularity")) or "minute" + + if is_new_paragraph: + paragraph_records.append( + ( + item.paragraph_hash, + item.content, + None, + now_ts, + now_ts, + empty_meta_blob, + item.source, + len(normalize_text(item.content).split()), + None, + start_ts, + end_ts, + granularity, + confidence, + KnowledgeType.NARRATIVE.value, + ) + ) + + if item.paragraph_hash not in self.vector_store: + paragraph_embed_map[item.paragraph_hash] = item.content + + for entity in item.entities: + name = _normalize_name(entity) + if not name: + continue + canon = _canonical_name(name) + if not canon: + continue + entity_hash = compute_hash(canon) + entity_display.setdefault(entity_hash, name) + if is_new_paragraph: + entity_counts[entity_hash] += 1 + paragraph_entity_mentions[(item.paragraph_hash, entity_hash)] += 1 + if entity_hash not in self.vector_store: + entity_embed_map.setdefault(entity_hash, name) + + for subject, predicate, obj in item.relations: + s = _normalize_name(subject) + p = _normalize_name(predicate) + o = _normalize_name(obj) + if not (s and p and o): + continue + + s_canon = _canonical_name(s) + p_canon = _canonical_name(p) + o_canon = _canonical_name(o) + relation_hash = compute_hash(f"{s_canon}|{p_canon}|{o_canon}") + + if is_new_paragraph: + relation_records.setdefault( + relation_hash, + (s, p, o, 1.0, item.paragraph_hash, empty_meta_blob), + ) + paragraph_relation_links.add((item.paragraph_hash, relation_hash)) + + for relation_entity in (s, o): + e_canon = _canonical_name(relation_entity) + if not e_canon: + continue + e_hash = compute_hash(e_canon) + entity_display.setdefault(e_hash, relation_entity) + if is_new_paragraph: + entity_counts[e_hash] += 1 + paragraph_entity_mentions[(item.paragraph_hash, e_hash)] += 1 + if e_hash not in self.vector_store: + entity_embed_map.setdefault(e_hash, relation_entity) + + try: + cursor.execute("BEGIN") + + if paragraph_records: + cursor.executemany( + """ + INSERT OR IGNORE INTO paragraphs + ( + hash, content, vector_index, created_at, updated_at, metadata, source, word_count, + event_time, event_time_start, event_time_end, time_granularity, time_confidence, knowledge_type + ) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ """, + paragraph_records, + ) + + if entity_counts: + entity_rows = [ + ( + entity_hash, + entity_display[entity_hash], + None, + int(count), + now_ts, + empty_meta_blob, + ) + for entity_hash, count in entity_counts.items() + ] + try: + cursor.executemany( + """ + INSERT INTO entities + (hash, name, vector_index, appearance_count, created_at, metadata) + VALUES (?, ?, ?, ?, ?, ?) + ON CONFLICT(hash) DO UPDATE SET + appearance_count = entities.appearance_count + excluded.appearance_count + """, + entity_rows, + ) + except sqlite3.OperationalError: + cursor.executemany( + """ + INSERT OR IGNORE INTO entities + (hash, name, vector_index, appearance_count, created_at, metadata) + VALUES (?, ?, ?, ?, ?, ?) + """, + entity_rows, + ) + cursor.executemany( + "UPDATE entities SET appearance_count = appearance_count + ? WHERE hash = ?", + [(int(count), entity_hash) for entity_hash, count in entity_counts.items()], + ) + + if paragraph_entity_mentions: + pe_rows = [ + (paragraph_hash, entity_hash, int(mentions)) + for (paragraph_hash, entity_hash), mentions in paragraph_entity_mentions.items() + ] + try: + cursor.executemany( + """ + INSERT INTO paragraph_entities + (paragraph_hash, entity_hash, mention_count) + VALUES (?, ?, ?) + ON CONFLICT(paragraph_hash, entity_hash) DO UPDATE SET + mention_count = paragraph_entities.mention_count + excluded.mention_count + """, + pe_rows, + ) + except sqlite3.OperationalError: + cursor.executemany( + """ + INSERT OR IGNORE INTO paragraph_entities + (paragraph_hash, entity_hash, mention_count) + VALUES (?, ?, ?) + """, + pe_rows, + ) + cursor.executemany( + """ + UPDATE paragraph_entities + SET mention_count = mention_count + ? + WHERE paragraph_hash = ? AND entity_hash = ? + """, + [(m, p, e) for (p, e, m) in pe_rows], + ) + + if relation_records: + relation_rows = [ + ( + relation_hash, + rel[0], + rel[1], + rel[2], + None, + rel[3], + now_ts, + rel[4], + rel[5], + ) + for relation_hash, rel in relation_records.items() + ] + cursor.executemany( + """ + INSERT OR IGNORE INTO relations + (hash, subject, predicate, object, vector_index, confidence, created_at, source_paragraph, metadata) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + relation_rows, + ) + + if paragraph_relation_links: + pr_rows = [(p_hash, r_hash) for p_hash, r_hash in paragraph_relation_links] + cursor.executemany( + """ + INSERT OR IGNORE INTO paragraph_relations + (paragraph_hash, relation_hash) + VALUES (?, ?) 
+ """, + pr_rows, + ) + + conn.commit() + except Exception: + conn.rollback() + raise + + self.stats["relations_written"] += len(relation_records) + + if relation_records: + edge_pairs = [] + relation_hashes = [] + for relation_hash, rel in relation_records.items(): + edge_pairs.append((rel[0], rel[2])) + relation_hashes.append(relation_hash) + + with self.graph_store.batch_update(): + self.graph_store.add_edges(edge_pairs, relation_hashes=relation_hashes) + self.stats["graph_edges_written"] += len(edge_pairs) + + if self._should_write_relation_vectors(): + await self._ensure_relation_vectors_for_records(relation_records) + + para_added = await self._embed_and_add_vectors( + id_to_text=paragraph_embed_map, + batch_size=max(1, int(self.args.embed_batch_size)), + workers=self.embed_workers, + ) + ent_added = await self._embed_and_add_vectors( + id_to_text=entity_embed_map, + batch_size=max(1, int(self.args.entity_embed_batch_size)), + workers=self.embed_workers, + ) + self.stats["paragraph_vectors_added"] += para_added + self.stats["entity_vectors_added"] += ent_added + + self.vector_store.save() + self.graph_store.save() + + self.stats["windows_committed"] += 1 + self._flush_state(last_seen_id) + + async def _embed_and_add_vectors( + self, + id_to_text: Dict[str, str], + batch_size: int, + workers: int, + ) -> int: + if not id_to_text: + return 0 + if self.embedding_manager is None: + raise MigrationError("embedding_manager 未初始化,无法写入向量") + + ids = [] + texts = [] + for hash_id, text in id_to_text.items(): + if hash_id in self.vector_store: + continue + ids.append(hash_id) + texts.append(text) + + if not ids: + return 0 + + total_added = 0 + chunk_size = max(1, int(batch_size)) + for i in range(0, len(ids), chunk_size): + chunk_ids = ids[i : i + chunk_size] + chunk_texts = texts[i : i + chunk_size] + + embeddings = await self.embedding_manager.encode_batch( + chunk_texts, + batch_size=chunk_size, + num_workers=max(1, int(workers)), + ) + + emb_arr = np.asarray(embeddings, dtype=np.float32) + if emb_arr.ndim == 1: + emb_arr = emb_arr.reshape(1, -1) + if emb_arr.shape[0] != len(chunk_ids): + logger.warning( + f"embedding 返回数量异常: expected={len(chunk_ids)}, got={emb_arr.shape[0]},跳过该批次" + ) + continue + + valid_vectors = [] + valid_ids = [] + for idx, vec in enumerate(emb_arr): + if vec.ndim != 1: + continue + if vec.shape[0] != self.vector_store.dimension: + logger.warning( + f"向量维度不匹配,跳过: id={chunk_ids[idx]}, got={vec.shape[0]}, expected={self.vector_store.dimension}" + ) + continue + if not np.all(np.isfinite(vec)): + logger.warning(f"向量含 NaN/Inf,跳过: id={chunk_ids[idx]}") + continue + if chunk_ids[idx] in self.vector_store: + continue + valid_vectors.append(vec) + valid_ids.append(chunk_ids[idx]) + + if valid_vectors: + batch_vectors = np.stack(valid_vectors).astype(np.float32, copy=False) + added = self.vector_store.add(batch_vectors, valid_ids) + total_added += int(added) + + return total_added + + async def _verify(self, strict: bool) -> None: + if self.selection is None: + raise MigrationError("selection 未初始化") + + sample_size = min(2000, max(0, int(self.stats.get("source_matched_total", 0)))) + self.stats["verify_sample_size"] = sample_size + + if sample_size <= 0: + self.stats["verify_passed"] = True + return + + sample_rows = self.source_db.sample_rows_for_verify(self.selection, sample_size) + para_missing = 0 + vec_missing = 0 + rel_missing = 0 + edge_missing = 0 + + for row in sample_rows: + try: + mapped = self._map_row(row) + except Exception: + continue + + paragraph = 
self.metadata_store.get_paragraph(mapped.paragraph_hash) + if paragraph is None: + para_missing += 1 + if mapped.paragraph_hash not in self.vector_store: + vec_missing += 1 + + for s, p, o in mapped.relations: + relation_hash = compute_hash(f"{_canonical_name(s)}|{_canonical_name(p)}|{_canonical_name(o)}") + relation = self.metadata_store.get_relation(relation_hash) + if relation is None: + rel_missing += 1 + if self.graph_store.get_edge_weight(s, o) <= 0.0: + edge_missing += 1 + + self.stats["verify_paragraph_missing"] = para_missing + self.stats["verify_vector_missing"] = vec_missing + self.stats["verify_relation_missing"] = rel_missing + self.stats["verify_edge_missing"] = edge_missing + + verify_passed = all(x == 0 for x in [para_missing, vec_missing, rel_missing, edge_missing]) + if strict and not verify_passed: + self.failed = True + self.fail_reason = ( + "严格校验失败: " + f"paragraph_missing={para_missing}, vector_missing={vec_missing}, " + f"relation_missing={rel_missing}, edge_missing={edge_missing}" + ) + + self.stats["verify_passed"] = verify_passed + + def _finalize(self) -> int: + elapsed = time.time() - self.started_at + self.stats["elapsed_seconds"] = elapsed + + report = { + "success": not self.failed, + "fail_reason": self.fail_reason, + "args": vars(self.args), + "source_db": str(self.source_db_path), + "target_data_dir": str(self.target_data_dir), + "selection": self.selection.fingerprint_payload() if self.selection else {}, + "filter_fingerprint": self.filter_fingerprint, + "source_db_fingerprint": self.source_db_fingerprint, + "state_file": str(self.state_file), + "bad_rows_file": str(self.bad_rows_file), + "stats": dict(self.stats), + "timestamp": time.time(), + } + + _dump_json_atomic(self.report_file, report) + + if self.failed: + self.exit_code = 1 + elif self.stats.get("bad_rows", 0) > 0: + self.exit_code = 2 + else: + self.exit_code = 0 + + print("\n=== Migration Report ===") + print(f"success: {not self.failed}") + if self.fail_reason: + print(f"fail_reason: {self.fail_reason}") + print(f"elapsed: {elapsed:.2f}s") + print(f"source_matched_total: {self.stats['source_matched_total']}") + print(f"scanned_rows: {self.stats['scanned_rows']}") + print(f"valid_rows: {self.stats['valid_rows']}") + print(f"migrated_rows: {self.stats['migrated_rows']}") + print(f"skipped_existing_rows: {self.stats['skipped_existing_rows']}") + print(f"bad_rows: {self.stats['bad_rows']}") + print(f"paragraph_vectors_added: {self.stats['paragraph_vectors_added']}") + print(f"entity_vectors_added: {self.stats['entity_vectors_added']}") + print(f"relations_written: {self.stats['relations_written']}") + print( + "relation_vectors: " + f"written={self.stats['relation_vectors_written']}, " + f"failed={self.stats['relation_vectors_failed']}, " + f"skipped={self.stats['relation_vectors_skipped']}" + ) + print(f"graph_edges_written: {self.stats['graph_edges_written']}") + print(f"windows_committed: {self.stats['windows_committed']}") + print(f"last_committed_id: {self.stats['last_committed_id']}") + print( + "verify: " + f"sample={self.stats['verify_sample_size']}, " + f"paragraph_missing={self.stats['verify_paragraph_missing']}, " + f"vector_missing={self.stats['verify_vector_missing']}, " + f"relation_missing={self.stats['verify_relation_missing']}, " + f"edge_missing={self.stats['verify_edge_missing']}, " + f"passed={self.stats['verify_passed']}" + ) + print(f"report_file: {self.report_file}") + print("========================\n") + + return self.exit_code + + def _close(self) -> None: + try: + if 
self.metadata_store is not None: + self.metadata_store.close() + except Exception: + pass + self.source_db.close() + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description="迁移 MaiBot chat_history 到 A_memorix(高性能 + 可断点续传 + 可确认筛选)" + ) + + parser.add_argument("--source-db", default=str(DEFAULT_SOURCE_DB), help="源数据库路径(默认 data/MaiBot.db)") + parser.add_argument( + "--target-data-dir", + default=str(DEFAULT_TARGET_DATA_DIR), + help="A_memorix 数据目录(默认 plugins/A_memorix/data)", + ) + + resume_group = parser.add_mutually_exclusive_group() + resume_group.add_argument("--resume", dest="no_resume", action="store_false", help="启用断点续传(默认)") + resume_group.add_argument("--no-resume", dest="no_resume", action="store_true", help="禁用断点续传") + parser.set_defaults(no_resume=False) + + parser.add_argument("--reset-state", action="store_true", help="清空迁移状态文件后执行") + parser.add_argument("--start-id", type=int, default=None, help="从指定 chat_history.id 开始迁移(覆盖断点)") + parser.add_argument("--end-id", type=int, default=None, help="迁移到指定 chat_history.id") + + parser.add_argument("--read-batch-size", type=int, default=2000, help="源库分页读取大小(默认 2000)") + parser.add_argument("--commit-window-rows", type=int, default=20000, help="每窗口提交行数(默认 20000)") + parser.add_argument("--embed-batch-size", type=int, default=256, help="段落 embedding 批次大小(默认 256)") + parser.add_argument( + "--entity-embed-batch-size", + type=int, + default=512, + help="实体 embedding 批次大小(默认 512)", + ) + parser.add_argument("--embed-workers", type=int, default=None, help="embedding 并发数(默认读取配置)") + parser.add_argument("--max-errors", type=int, default=500, help="坏行上限(默认 500)") + parser.add_argument("--log-every", type=int, default=5000, help="日志输出步长(默认 5000)") + + parser.add_argument("--dry-run", action="store_true", help="仅预览不写入") + parser.add_argument("--verify-only", action="store_true", help="仅执行严格校验") + + parser.add_argument("--time-from", default=None, help="开始时间:YYYY-MM-DD / YYYY/MM/DD / YYYY-MM-DD HH:mm[:ss]") + parser.add_argument("--time-to", default=None, help="结束时间:YYYY-MM-DD / YYYY/MM/DD / YYYY-MM-DD HH:mm[:ss]") + parser.add_argument("--stream-id", action="append", default=[], help="聊天流 stream_id(可重复)") + parser.add_argument("--group-id", action="append", default=[], help="群号(可重复,自动映射 stream_id)") + parser.add_argument("--user-id", action="append", default=[], help="用户号(可重复,自动映射 stream_id)") + parser.add_argument("--yes", action="store_true", help="跳过交互确认") + parser.add_argument("--preview-limit", type=int, default=20, help="预览样本条数(默认 20)") + + return parser + + +async def async_main() -> int: + parser = build_parser() + args = parser.parse_args() + + runner = MigrationRunner(args) + return await runner.run() + + +def main() -> int: + if sys.platform == "win32": + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) + return asyncio.run(async_main()) + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/plugins/A_memorix/scripts/process_knowledge.py b/plugins/A_memorix/scripts/process_knowledge.py new file mode 100644 index 00000000..d9e6fe32 --- /dev/null +++ b/plugins/A_memorix/scripts/process_knowledge.py @@ -0,0 +1,728 @@ +#!/usr/bin/env python3 +""" +知识库自动导入脚本 (Strategy-Aware Version) + +功能: +1. 扫描 plugins/A_memorix/data/raw 下的 .txt 文件 +2. 检查 data/import_manifest.json 确认是否已导入 +3. 使用 Strategy 模式处理文件 (Narrative/Factual/Quote) +4. 将生成的数据直接存入 VectorStore/GraphStore/MetadataStore +5. 
更新 manifest +""" + +import sys +import os +import json +import asyncio +import time +import random +import hashlib +import tomlkit +import argparse +from pathlib import Path +from datetime import datetime +from typing import List, Dict, Any, Optional +from rich.progress import Progress, SpinnerColumn, BarColumn, TextColumn, TimeElapsedColumn +from rich.console import Console +from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type + +console = Console() + +class LLMGenerationError(Exception): + pass + +# 路径设置 +current_dir = Path(__file__).resolve().parent +plugin_root = current_dir.parent +workspace_root = plugin_root.parent +maibot_root = workspace_root / "MaiBot" +for path in (workspace_root, maibot_root, plugin_root): + path_str = str(path) + if path_str not in sys.path: + sys.path.insert(0, path_str) + +# 数据目录 +DATA_DIR = plugin_root / "data" +RAW_DIR = DATA_DIR / "raw" +PROCESSED_DIR = DATA_DIR / "processed" +MANIFEST_PATH = DATA_DIR / "import_manifest.json" + + +def _build_arg_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="A_Memorix Knowledge Importer (Strategy-Aware)") + parser.add_argument("--force", action="store_true", help="Force re-import") + parser.add_argument("--clear-manifest", action="store_true", help="Clear manifest") + parser.add_argument( + "--type", + "-t", + default="auto", + help="Target import strategy override (auto/narrative/factual/quote)", + ) + parser.add_argument("--concurrency", "-c", type=int, default=5) + parser.add_argument( + "--chat-log", + action="store_true", + help="聊天记录导入模式:强制 narrative 策略,并使用 LLM 语义抽取 event_time/event_time_range", + ) + parser.add_argument( + "--chat-reference-time", + default=None, + help="chat_log 模式的相对时间参考点(如 2026/02/12 10:30);不传则使用当前本地时间", + ) + return parser + + +# --help/-h fast path: avoid heavy host/plugin bootstrap +if any(arg in {"-h", "--help"} for arg in sys.argv[1:]): + _build_arg_parser().print_help() + sys.exit(0) + + +try: + import A_memorix.core as core_module + import A_memorix.core.storage as storage_module + from src.common.logger import get_logger + from src.services import llm_service as llm_api + from src.config.config import global_config, model_config + + VectorStore = core_module.VectorStore + GraphStore = core_module.GraphStore + MetadataStore = core_module.MetadataStore + ImportStrategy = core_module.ImportStrategy + create_embedding_api_adapter = core_module.create_embedding_api_adapter + RelationWriteService = getattr(core_module, "RelationWriteService", None) + + looks_like_quote_text = storage_module.looks_like_quote_text + parse_import_strategy = storage_module.parse_import_strategy + resolve_stored_knowledge_type = storage_module.resolve_stored_knowledge_type + select_import_strategy = storage_module.select_import_strategy + + from A_memorix.core.utils.time_parser import normalize_time_meta + from A_memorix.core.utils.import_payloads import normalize_paragraph_import_item + from A_memorix.core.strategies.base import BaseStrategy, ProcessedChunk, KnowledgeType as StratKnowledgeType + from A_memorix.core.strategies.narrative import NarrativeStrategy + from A_memorix.core.strategies.factual import FactualStrategy + from A_memorix.core.strategies.quote import QuoteStrategy + +except ImportError as e: + print(f"❌ 无法导入模块: {e}") + import traceback + traceback.print_exc() + sys.exit(1) + +logger = get_logger("A_Memorix.AutoImport") + + +def _log_before_retry(retry_state) -> None: + """使用项目统一日志风格记录重试信息。""" + exc = None + if 
getattr(retry_state, "outcome", None) is not None and retry_state.outcome.failed: + exc = retry_state.outcome.exception() + next_sleep = getattr(getattr(retry_state, "next_action", None), "sleep", None) + logger.warning( + "LLM 调用即将重试: " + f"attempt={getattr(retry_state, 'attempt_number', '?')} " + f"next_sleep={next_sleep} " + f"error={exc}" + ) + +class AutoImporter: + def __init__( + self, + force: bool = False, + clear_manifest: bool = False, + target_type: str = "auto", + concurrency: int = 5, + chat_log: bool = False, + chat_reference_time: Optional[str] = None, + ): + self.vector_store: Optional[VectorStore] = None + self.graph_store: Optional[GraphStore] = None + self.metadata_store: Optional[MetadataStore] = None + self.embedding_manager = None + self.relation_write_service = None + self.plugin_config = {} + self.manifest = {} + self.force = force + self.clear_manifest = clear_manifest + self.chat_log = chat_log + parsed_target_type = parse_import_strategy(target_type, default=ImportStrategy.AUTO) + self.target_type = ImportStrategy.NARRATIVE.value if chat_log else parsed_target_type.value + self.chat_reference_dt = self._parse_reference_time(chat_reference_time) + if self.chat_log and parsed_target_type not in {ImportStrategy.AUTO, ImportStrategy.NARRATIVE}: + logger.warning( + f"chat_log 模式已启用,target_type={target_type} 将被覆盖为 narrative" + ) + self.concurrency_limit = concurrency + self.semaphore = None + self.storage_lock = None + + async def initialize(self): + logger.info(f"正在初始化... (并发数: {self.concurrency_limit})") + self.semaphore = asyncio.Semaphore(self.concurrency_limit) + self.storage_lock = asyncio.Lock() + + RAW_DIR.mkdir(parents=True, exist_ok=True) + PROCESSED_DIR.mkdir(parents=True, exist_ok=True) + + if self.clear_manifest: + logger.info("🧹 清理 Mainfest") + self.manifest = {} + self._save_manifest() + elif MANIFEST_PATH.exists(): + try: + with open(MANIFEST_PATH, "r", encoding="utf-8") as f: + self.manifest = json.load(f) + except Exception: + self.manifest = {} + + config_path = plugin_root / "config.toml" + try: + with open(config_path, "r", encoding="utf-8") as f: + self.plugin_config = tomlkit.load(f) + except Exception as e: + logger.error(f"加载插件配置失败: {e}") + return False + + try: + await self._init_stores() + except Exception as e: + logger.error(f"初始化存储失败: {e}") + return False + + return True + + async def _init_stores(self): + # ... 
+    async def _init_stores(self):
+        # Bootstrap the embedding adapter plus vector/graph/metadata stores.
+        self.embedding_manager = create_embedding_api_adapter(
+            batch_size=self.plugin_config.get("embedding", {}).get("batch_size", 32),
+            default_dimension=self.plugin_config.get("embedding", {}).get("dimension", 384),
+            model_name=self.plugin_config.get("embedding", {}).get("model_name", "auto"),
+            retry_config=self.plugin_config.get("embedding", {}).get("retry", {}),
+        )
+        try:
+            dim = await self.embedding_manager._detect_dimension()
+        except Exception:
+            dim = self.embedding_manager.default_dimension
+
+        q_type_str = str(self.plugin_config.get("embedding", {}).get("quantization_type", "int8") or "int8").lower()
+        # Need to access QuantizationType from storage_module if not imported globally
+        QuantizationType = storage_module.QuantizationType
+        if q_type_str != "int8":
+            raise ValueError(
+                "embedding.quantization_type 在 vNext 仅允许 int8(SQ8)。"
+                " 请先执行 scripts/release_vnext_migrate.py migrate。"
+            )
+
+        self.vector_store = VectorStore(
+            dimension=dim,
+            quantization_type=QuantizationType.INT8,
+            data_dir=DATA_DIR / "vectors"
+        )
+
+        SparseMatrixFormat = storage_module.SparseMatrixFormat
+        m_fmt_str = self.plugin_config.get("graph", {}).get("sparse_matrix_format", "csr")
+        m_map = {"csr": SparseMatrixFormat.CSR, "csc": SparseMatrixFormat.CSC}
+
+        self.graph_store = GraphStore(
+            matrix_format=m_map.get(m_fmt_str, SparseMatrixFormat.CSR),
+            data_dir=DATA_DIR / "graph"
+        )
+
+        self.metadata_store = MetadataStore(data_dir=DATA_DIR / "metadata")
+        self.metadata_store.connect()
+
+        if RelationWriteService is not None:
+            self.relation_write_service = RelationWriteService(
+                metadata_store=self.metadata_store,
+                graph_store=self.graph_store,
+                vector_store=self.vector_store,
+                embedding_manager=self.embedding_manager,
+            )
+
+        if self.vector_store.has_data():
+            self.vector_store.load()
+        if self.graph_store.has_data():
+            self.graph_store.load()
+
+    def _should_write_relation_vectors(self) -> bool:
+        retrieval_cfg = self.plugin_config.get("retrieval", {})
+        if not isinstance(retrieval_cfg, dict):
+            return False
+        rv_cfg = retrieval_cfg.get("relation_vectorization", {})
+        if not isinstance(rv_cfg, dict):
+            return False
+        return bool(rv_cfg.get("enabled", False)) and bool(rv_cfg.get("write_on_import", True))
+
+    def load_file(self, file_path: Path) -> str:
+        with open(file_path, "r", encoding="utf-8") as f:
+            return f.read()
+
+    def get_file_hash(self, content: str) -> str:
+        return hashlib.md5(content.encode("utf-8")).hexdigest()
+
+    def _parse_reference_time(self, value: Optional[str]) -> datetime:
+        """解析 chat_log 模式的参考时间(用于相对时间语义解析)。"""
+        if not value:
+            return datetime.now()
+        formats = [
+            "%Y/%m/%d %H:%M:%S",
+            "%Y/%m/%d %H:%M",
+            "%Y-%m-%d %H:%M:%S",
+            "%Y-%m-%d %H:%M",
+            "%Y/%m/%d",
+            "%Y-%m-%d",
+        ]
+        text = str(value).strip()
+        for fmt in formats:
+            try:
+                return datetime.strptime(text, fmt)
+            except ValueError:
+                continue
+        logger.warning(
+            f"无法解析 chat_reference_time={value},将回退为当前本地时间"
+        )
+        return datetime.now()
+
+    async def _extract_chat_time_meta_with_llm(
+        self,
+        text: str,
+        model_config: Any,
+    ) -> Optional[Dict[str, Any]]:
+        """
+        使用 LLM 从聊天文本语义中抽取时间信息。
+        支持将相对时间表达转换为绝对时间。
+        """
+        if not text.strip():
+            return None
+
+        reference_now = self.chat_reference_dt.strftime("%Y/%m/%d %H:%M")
+        prompt = f"""You are a time extraction engine for chat logs.
+Extract temporal information from the following chat paragraph.
+
+Rules:
+1. Use semantic understanding, not regex matching.
+2. 
Convert relative expressions (e.g., yesterday evening, last Friday morning) to absolute local datetime using reference_now. +3. If a time span exists, return event_time_start/event_time_end. +4. If only one point in time exists, return event_time. +5. If no reliable time can be inferred, return all time fields as null. +6. Output ONLY valid JSON. No markdown, no explanation. + +reference_now: {reference_now} +timezone: local system timezone + +Allowed output formats for time values: +- "YYYY/MM/DD" +- "YYYY/MM/DD HH:mm" + +JSON schema: +{{ + "event_time": null, + "event_time_start": null, + "event_time_end": null, + "time_range": null, + "time_granularity": "day", + "time_confidence": 0.0 +}} + +Chat paragraph: +\"\"\"{text}\"\"\" +""" + try: + result = await self._llm_call(prompt, model_config) + except Exception as e: + logger.warning(f"chat_log 时间语义抽取失败: {e}") + return None + + if not isinstance(result, dict): + return None + + raw_time_meta = { + "event_time": result.get("event_time"), + "event_time_start": result.get("event_time_start"), + "event_time_end": result.get("event_time_end"), + "time_range": result.get("time_range"), + "time_granularity": result.get("time_granularity"), + "time_confidence": result.get("time_confidence"), + } + try: + normalized = normalize_time_meta(raw_time_meta) + except Exception as e: + logger.warning(f"chat_log 时间语义抽取结果不可用,已忽略: {e}") + return None + + has_effective_time = any( + key in normalized + for key in ("event_time", "event_time_start", "event_time_end") + ) + if not has_effective_time: + return None + + return normalized + + def _determine_strategy(self, filename: str, content: str) -> BaseStrategy: + """Layer 1: Global Strategy Routing""" + strategy = select_import_strategy( + content, + override=self.target_type, + chat_log=self.chat_log, + ) + if self.chat_log: + logger.info(f"chat_log 模式: {filename} 强制使用 NarrativeStrategy") + elif strategy == ImportStrategy.QUOTE: + logger.info(f"Auto-detected Quote/Lyric type for {filename}") + + if strategy == ImportStrategy.FACTUAL: + return FactualStrategy(filename) + if strategy == ImportStrategy.QUOTE: + return QuoteStrategy(filename) + return NarrativeStrategy(filename) + + def _chunk_rescue(self, chunk: ProcessedChunk, filename: str) -> Optional[BaseStrategy]: + """Layer 2: Chunk-level rescue strategies""" + # If we are already in Quote strategy, no need to rescue + if chunk.type == StratKnowledgeType.QUOTE: + return None + + if looks_like_quote_text(chunk.chunk.text): + logger.info(f" > Rescuing chunk {chunk.chunk.index} as Quote") + return QuoteStrategy(filename) + + return None + + async def process_and_import(self): + if not await self.initialize(): return + + files = list(RAW_DIR.glob("*.txt")) + logger.info(f"扫描到 {len(files)} 个文件 in {RAW_DIR}") + + if not files: return + + tasks = [] + for file_path in files: + tasks.append(asyncio.create_task(self._process_single_file(file_path))) + + results = await asyncio.gather(*tasks, return_exceptions=True) + + success_count = sum(1 for r in results if r is True) + logger.info(f"本次主处理完成,共成功处理 {success_count}/{len(files)} 个文件") + + if self.vector_store: self.vector_store.save() + if self.graph_store: self.graph_store.save() + + async def _process_single_file(self, file_path: Path) -> bool: + filename = file_path.name + async with self.semaphore: + try: + content = self.load_file(file_path) + file_hash = self.get_file_hash(content) + + if not self.force and filename in self.manifest: + record = self.manifest[filename] + if record.get("hash") == file_hash 
+                        logger.info(f"Skipping already-imported file: {filename}")
+                        return False
+
+                logger.info(f">>> Processing: {filename}")
+
+                # 1. Strategy Selection
+                strategy = self._determine_strategy(filename, content)
+                logger.info(f"  Strategy: {strategy.__class__.__name__}")
+
+                # 2. Split (Strategy-Aware)
+                initial_chunks = strategy.split(content)
+                logger.info(f"  Initial chunks: {len(initial_chunks)}")
+
+                processed_data = {"paragraphs": [], "entities": [], "relations": []}
+
+                # 3. Extract Loop
+                model_config = await self._select_model()
+
+                for i, chunk in enumerate(initial_chunks):
+                    current_strategy = strategy
+                    # Layer 2: Chunk Rescue
+                    rescue_strategy = self._chunk_rescue(chunk, filename)
+                    if rescue_strategy:
+                        # Do not re-split: re-splitting under the rescue strategy
+                        # could produce different chunk boundaries. Instead, retag
+                        # this chunk in place as a verbatim quote and let the
+                        # rescue strategy extract from it as a single block.
+                        chunk.type = StratKnowledgeType.QUOTE
+                        chunk.flags.verbatim = True
+                        chunk.flags.requires_llm = False  # quotes normally need no LLM pass
+                        current_strategy = rescue_strategy
+
+                    # Extraction
+                    if chunk.flags.requires_llm:
+                        result_chunk = await current_strategy.extract(chunk, lambda p: self._llm_call(p, model_config))
+                    else:
+                        # Non-LLM strategies (e.g. quotes) extract via pass-through or regex.
+                        result_chunk = await current_strategy.extract(chunk)
+
+                    time_meta = None
+                    if self.chat_log:
+                        time_meta = await self._extract_chat_time_meta_with_llm(
+                            result_chunk.chunk.text,
+                            model_config,
+                        )
+
+                    # Normalize Data
+                    self._normalize_and_aggregate(
+                        result_chunk,
+                        processed_data,
+                        time_meta=time_meta,
+                    )
+
+                    logger.info(f"  Processed chunk {i+1}/{len(initial_chunks)}")
+
+                # 4. Save JSON
+                json_path = PROCESSED_DIR / f"{file_path.stem}.json"
+                with open(json_path, "w", encoding="utf-8") as f:
+                    json.dump(processed_data, f, ensure_ascii=False, indent=2)
+
+                # 5. Import to DB
+                async with self.storage_lock:
+                    await self._import_to_db(processed_data)
+
+                self.manifest[filename] = {
+                    "hash": file_hash,
+                    "timestamp": time.time(),
+                    "imported": True
+                }
+                self._save_manifest()
+                self.vector_store.save()
+                self.graph_store.save()
+                logger.info(f"✅ File {filename} processed and imported")
+                return True
+
+            except Exception as e:
+                logger.error(f"❌ Processing failed for {filename}: {e}")
+                import traceback
+                traceback.print_exc()
+                return False
+
+    def _normalize_and_aggregate(
+        self,
+        chunk: ProcessedChunk,
+        all_data: Dict,
+        time_meta: Optional[Dict[str, Any]] = None,
+    ):
+        """Convert strategy-specific data to unified generic format for storage."""
+        # Generic fields
+        para_item = {
+            "content": chunk.chunk.text,
+            "source": chunk.source.file,
+            "knowledge_type": resolve_stored_knowledge_type(
+                chunk.type.value,
+                content=chunk.chunk.text,
+            ).value,
+            "entities": [],
+            "relations": []
+        }
+
+        data = chunk.data
+
+        # 1. Triples (Factual)
+        if "triples" in data:
+            for t in data["triples"]:
+                para_item["relations"].append({
+                    "subject": t.get("subject"),
+                    "predicate": t.get("predicate"),
+                    "object": t.get("object")
+                })
+                # Auto-add entities from triples
+                para_item["entities"].extend([t.get("subject"), t.get("object")])
+
+        # 2. Events & Relations (Narrative)
+        if "events" in data:
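+            # Assumed narrative payload shape (a sketch inferred from the
+            # handling below, not a verified contract):
+            #   data["events"]    -> list of event-mention strings
+            #   data["relations"] -> list of {"subject", "predicate", "object", ...} dicts
+            # Events are folded into the entity list so they participate in
+            # graph construction and later recall.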
+            para_item["entities"].extend(data["events"])
+
+        if "relations" in data:  # narrative strategies also emit a relations list
+            para_item["relations"].extend(data["relations"])
+            for r in data["relations"]:
+                para_item["entities"].extend([r.get("subject"), r.get("object")])
+
+        # 3. Verbatim Entities (Quote)
+        if "verbatim_entities" in data:
+            para_item["entities"].extend(data["verbatim_entities"])
+
+        # Dedupe per paragraph, preserving first-seen order.
+        para_item["entities"] = list(dict.fromkeys(e for e in para_item["entities"] if e))
+
+        if time_meta:
+            para_item["time_meta"] = time_meta
+
+        all_data["paragraphs"].append(para_item)
+        all_data["entities"].extend(para_item["entities"])
+        all_data["relations"].extend(para_item["relations"])
+
+    @retry(
+        retry=retry_if_exception_type((LLMGenerationError, json.JSONDecodeError)),
+        stop=stop_after_attempt(3),
+        wait=wait_exponential(multiplier=1, min=2, max=10),
+        before_sleep=_log_before_retry
+    )
+    async def _llm_call(self, prompt: str, model_config: Any) -> Dict:
+        """Generic LLM caller; retried on generation or JSON-decoding errors."""
+        success, response, _, _ = await llm_api.generate_with_model(
+            prompt=prompt,
+            model_config=model_config,
+            request_type="Script.ProcessKnowledge"
+        )
+        if success:
+            txt = response.strip()
+            if "```" in txt:
+                # Best-effort markdown fence stripping; the brace fallback below
+                # catches fences this misses.
+                txt = txt.split("```json")[-1].split("```")[0].strip()
+            try:
+                return json.loads(txt)
+            except json.JSONDecodeError:
+                # Fallback: parse the substring between the first '{' and the last '}'.
+                start = txt.find('{')
+                end = txt.rfind('}')
+                if start != -1 and end != -1:
+                    return json.loads(txt[start:end+1])
+                raise
+        else:
+            raise LLMGenerationError("LLM generation failed")
+
+    async def _select_model(self) -> Any:
+        models = llm_api.get_available_models()
+        if not models: raise ValueError("No LLM models available")
+
+        config_model = self.plugin_config.get("advanced", {}).get("extraction_model", "auto")
+        if config_model != "auto" and config_model in models:
+            return models[config_model]
+
+        # Fallback order: extraction-oriented task models first, then embedding,
+        # then whichever model is registered first.
+        for task_key in ["lpmm_entity_extract", "lpmm_rdf_build", "embedding"]:
+            if task_key in models: return models[task_key]
+
+        return next(iter(models.values()))
+
+    # Storage helpers shared with the file-based import path.
+    async def _add_entity_with_vector(self, name: str, source_paragraph: Optional[str] = None) -> str:
+        hash_value = self.metadata_store.add_entity(name, source_paragraph=source_paragraph)
+        self.graph_store.add_nodes([name])
+        try:
+            emb = await self.embedding_manager.encode(name)
+            try:
+                self.vector_store.add(emb.reshape(1, -1), [hash_value])
+            except ValueError: pass  # e.g. id already indexed; safe to skip
+        except Exception: pass  # embedding failure must not block metadata/graph writes
+        return hash_value
+
+    async def import_json_data(self, data: Dict, filename: str = "script_import", progress_callback=None):
+        """Public import entrypoint for pre-processed JSON payloads."""
+        if not self.storage_lock:
+            raise RuntimeError("Importer is not initialized. Call initialize() first.")
+
+        async with self.storage_lock:
+            await self._import_to_db(data, progress_callback=progress_callback)
+        self.manifest[filename] = {
+            "hash": self.get_file_hash(json.dumps(data, ensure_ascii=False, sort_keys=True)),
+            "timestamp": time.time(),
+            "imported": True,
+        }
+        self._save_manifest()
+        self.vector_store.save()
+        self.graph_store.save()
+
+    async def _import_to_db(self, data: Dict, progress_callback=None):
+        # Mirrors the file-based import path, with defensive normalization for
+        # partially-specified payload items.
+        with self.graph_store.batch_update():
+            for item in data.get("paragraphs", []):
+                paragraph = normalize_paragraph_import_item(
+                    item,
+                    default_source="script",
+                )
+                content = paragraph["content"]
+                source = paragraph["source"]
+                k_type_val = paragraph["knowledge_type"]
+
+                h_val = self.metadata_store.add_paragraph(
+                    content=content,
+                    source=source,
+                    knowledge_type=k_type_val,
+                    time_meta=paragraph["time_meta"],
+                )
+
+                if h_val not in self.vector_store:
+                    try:
+                        emb = await self.embedding_manager.encode(content)
+                        self.vector_store.add(emb.reshape(1, -1), [h_val])
+                    except Exception as e:
+                        logger.error(f"  Vector indexing failed: {e}")
+
+                para_entities = paragraph["entities"]
+                for entity in para_entities:
+                    if entity:
+                        await self._add_entity_with_vector(entity, source_paragraph=h_val)
+
+                para_relations = paragraph["relations"]
+                for rel in para_relations:
+                    s, p, o = rel.get("subject"), rel.get("predicate"), rel.get("object")
+                    if s and p and o:
+                        await self._add_entity_with_vector(s, source_paragraph=h_val)
+                        await self._add_entity_with_vector(o, source_paragraph=h_val)
+                        confidence = float(rel.get("confidence", 1.0) or 1.0)
+                        rel_meta = rel.get("metadata", {})
+                        write_vector = self._should_write_relation_vectors()
+                        if self.relation_write_service is not None:
+                            await self.relation_write_service.upsert_relation_with_vector(
+                                subject=s,
+                                predicate=p,
+                                obj=o,
+                                confidence=confidence,
+                                source_paragraph=h_val,
+                                metadata=rel_meta if isinstance(rel_meta, dict) else {},
+                                write_vector=write_vector,
+                            )
+                        else:
+                            rel_hash = self.metadata_store.add_relation(
+                                s,
+                                p,
+                                o,
+                                confidence=confidence,
+                                source_paragraph=h_val,
+                                metadata=rel_meta if isinstance(rel_meta, dict) else {},
+                            )
+                            self.graph_store.add_edges([(s, o)], relation_hashes=[rel_hash])
+                            try:
+                                self.metadata_store.set_relation_vector_state(rel_hash, "none")
+                            except Exception:
+                                pass
+
+                if progress_callback: progress_callback(1)
+
+    async def close(self):
+        if self.metadata_store: self.metadata_store.close()
+
+    def _save_manifest(self):
+        with open(MANIFEST_PATH, "w", encoding="utf-8") as f:
+            json.dump(self.manifest, f, ensure_ascii=False, indent=2)
+
+
+async def main():
+    parser = _build_arg_parser()
+    args = parser.parse_args()
+
+    if not global_config: return
+
+    importer = AutoImporter(
+        force=args.force,
+        clear_manifest=args.clear_manifest,
+        target_type=args.type,
+        concurrency=args.concurrency,
+        chat_log=args.chat_log,
+        chat_reference_time=args.chat_reference_time,
+    )
+    await importer.process_and_import()
+    await importer.close()
+
+
+if __name__ == "__main__":
+    if sys.platform == "win32":
+        asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
+    asyncio.run(main())
diff --git a/plugins/A_memorix/scripts/rebuild_episodes.py b/plugins/A_memorix/scripts/rebuild_episodes.py
new file mode 100644
index 00000000..b6adaa21
--- /dev/null
+++ b/plugins/A_memorix/scripts/rebuild_episodes.py
@@ -0,0 +1,127 @@
+#!/usr/bin/env python3
+"""Episode source-level rebuild tool."""
+
+from __future__ import annotations
+
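+# Example invocations (illustrative; the flags are defined in _build_arg_parser
+# below):
+#   python plugins/A_memorix/scripts/rebuild_episodes.py --all --wait
+#   python plugins/A_memorix/scripts/rebuild_episodes.py --source <source> --wait
+
+import argparse
+import asyncio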
+import sys
+from pathlib import Path
+from typing import Any, Dict, List
+
+CURRENT_DIR = Path(__file__).resolve().parent
+PLUGIN_ROOT = CURRENT_DIR.parent
+WORKSPACE_ROOT = PLUGIN_ROOT.parent
+MAIBOT_ROOT = WORKSPACE_ROOT / "MaiBot"
+for path in (WORKSPACE_ROOT, MAIBOT_ROOT, PLUGIN_ROOT):
+    path_str = str(path)
+    if path_str not in sys.path:
+        sys.path.insert(0, path_str)
+
+try:
+    import tomlkit  # type: ignore
+except Exception:  # pragma: no cover
+    tomlkit = None
+
+from A_memorix.core.storage import MetadataStore
+from A_memorix.core.utils.episode_service import EpisodeService
+
+
+def _build_arg_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser(description="Rebuild A_Memorix episodes by source")
+    parser.add_argument("--data-dir", default=str(PLUGIN_ROOT / "data"), help="plugin data directory")
+    parser.add_argument("--source", type=str, help="enqueue/rebuild a single source")
+    parser.add_argument("--all", action="store_true", help="enqueue/rebuild all sources")
+    parser.add_argument("--wait", action="store_true", help="run the rebuild synchronously inside this script")
+    return parser
+
+
+if any(arg in {"-h", "--help"} for arg in sys.argv[1:]):
+    _build_arg_parser().print_help()
+    raise SystemExit(0)
+
+
+def _load_plugin_config() -> Dict[str, Any]:
+    config_path = PLUGIN_ROOT / "config.toml"
+    if tomlkit is None or not config_path.exists():
+        return {}
+    try:
+        with open(config_path, "r", encoding="utf-8") as handle:
+            parsed = tomlkit.load(handle)
+        return dict(parsed) if isinstance(parsed, dict) else {}
+    except Exception:
+        return {}
+
+
+def _resolve_sources(store: MetadataStore, *, source: str | None, rebuild_all: bool) -> List[str]:
+    if rebuild_all:
+        return list(store.list_episode_sources_for_rebuild())
+    token = str(source or "").strip()
+    if not token:
+        raise ValueError("either --source or --all is required")
+    return [token]
+
+
+async def _run_rebuilds(store: MetadataStore, plugin_config: Dict[str, Any], sources: List[str]) -> int:
+    service = EpisodeService(metadata_store=store, plugin_config=plugin_config)
+    failures: List[str] = []
+    for source in sources:
+        started = store.mark_episode_source_running(source)
+        if not started:
+            failures.append(f"{source}: unable_to_mark_running")
+            continue
+        try:
+            result = await service.rebuild_source(source)
+            store.mark_episode_source_done(source)
+            print(
+                "rebuilt"
+                f" source={source}"
+                f" paragraphs={int(result.get('paragraph_count') or 0)}"
+                f" groups={int(result.get('group_count') or 0)}"
+                f" episodes={int(result.get('episode_count') or 0)}"
+                f" fallback={int(result.get('fallback_count') or 0)}"
+            )
+        except Exception as exc:
+            err = str(exc)[:500]
+            store.mark_episode_source_failed(source, err)
+            failures.append(f"{source}: {err}")
+            print(f"failed source={source} error={err}")
+
+    if failures:
+        for item in failures:
+            print(item)
+        return 1
+    return 0
+
+
+def main() -> int:
+    parser = _build_arg_parser()
+    args = parser.parse_args()
+    if bool(args.all) == bool(args.source):
+        parser.error("exactly one of --source or --all must be given")
+
+    store = MetadataStore(data_dir=Path(args.data_dir) / "metadata")
+    store.connect()
+    try:
+        sources = _resolve_sources(store, source=args.source, rebuild_all=bool(args.all))
+        if not sources:
+            print("no sources to rebuild")
+            return 0
+
+        enqueued = 0
+        reason = "script_rebuild_all" if args.all else "script_rebuild_source"
+        for source in sources:
+            enqueued += int(store.enqueue_episode_source_rebuild(source, reason=reason))
+        print(f"enqueued={enqueued} sources={len(sources)}")
+
+        if not args.wait:
+            return 0
+
+        plugin_config = _load_plugin_config()
+        return asyncio.run(_run_rebuilds(store, plugin_config, sources))
+    finally:
+        store.close()
+
+
+if __name__ == "__main__":
+    raise SystemExit(main())
diff --git a/plugins/A_memorix/scripts/release_vnext_migrate.py b/plugins/A_memorix/scripts/release_vnext_migrate.py
new file mode 100644
index 00000000..0922fd0b
--- /dev/null
+++ b/plugins/A_memorix/scripts/release_vnext_migrate.py
@@ -0,0 +1,731 @@
+#!/usr/bin/env python3
+"""
+vNext release migration entrypoint for A_Memorix.
+
+Subcommands:
+- preflight: detect legacy config/data/schema risks
+- migrate: offline migrate config + vectors + metadata schema + graph edge hash map
+- verify: strict post-migration consistency checks
+"""
+
+from __future__ import annotations
+
+import argparse
+import json
+import pickle
+import sqlite3
+import sys
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Sequence, Tuple
+
+import tomlkit
+
+
+CURRENT_DIR = Path(__file__).resolve().parent
+PLUGIN_ROOT = CURRENT_DIR.parent
+PROJECT_ROOT = PLUGIN_ROOT.parent.parent
+sys.path.insert(0, str(PROJECT_ROOT))
+sys.path.insert(0, str(PLUGIN_ROOT))
+
+
+def _build_arg_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser(description="A_Memorix vNext release migration tool")
+    parser.add_argument(
+        "--config",
+        default=str(PLUGIN_ROOT / "config.toml"),
+        help="config.toml path (default: plugins/A_memorix/config.toml)",
+    )
+    parser.add_argument(
+        "--data-dir",
+        default="",
+        help="optional data dir override; default resolved from config.storage.data_dir",
+    )
+    parser.add_argument("--json-out", default="", help="optional JSON report output path")
+
+    sub = parser.add_subparsers(dest="command", required=True)
+
+    p_preflight = sub.add_parser("preflight", help="scan legacy risks")
+    p_preflight.add_argument("--strict", action="store_true", help="return 1 if any error check exists")
+
+    p_migrate = sub.add_parser("migrate", help="run offline migration")
+    p_migrate.add_argument("--dry-run", action="store_true", help="only print planned changes")
+    p_migrate.add_argument(
+        "--verify-after",
+        action="store_true",
+        help="run verify automatically after migrate",
+    )
+
+    p_verify = sub.add_parser("verify", help="post-migration verification")
+    p_verify.add_argument("--strict", action="store_true", help="return 1 if any error check exists")
+    return parser
+
+
+# --help/-h fast path: avoid heavy host/plugin bootstrap
+if any(arg in {"-h", "--help"} for arg in sys.argv[1:]):
+    _build_arg_parser().print_help()
+    raise SystemExit(0)
+
+try:
+    from core.storage import GraphStore, KnowledgeType, MetadataStore, QuantizationType, VectorStore
+    from core.storage.metadata_store import SCHEMA_VERSION
+except Exception as e:  # pragma: no cover
+    print(f"❌ failed to import storage modules: {e}")
+    raise SystemExit(2)
+
+
+@dataclass
+class CheckItem:
+    code: str
+    level: str
+    message: str
+    details: Optional[Dict[str, Any]] = None
+
+    def to_dict(self) -> Dict[str, Any]:
+        out = {
+            "code": self.code,
+            "level": self.level,
+            "message": self.message,
+        }
+        if self.details:
+            out["details"] = self.details
+        return out
+
+
+def _read_toml(path: Path) -> Dict[str, Any]:
+    text = path.read_text(encoding="utf-8")
+    return tomlkit.parse(text)
+
+
+def _write_toml(path: Path, data: Dict[str, Any]) -> None:
+    path.write_text(tomlkit.dumps(data), encoding="utf-8")
+
+
+def _get_nested(obj: Dict[str, Any], keys: Sequence[str], default: Any = None) -> Any:
+    cur: Any = obj
+    for k in keys:
+        if not isinstance(cur, dict) or k not in cur:
+            return default
+        cur = cur[k]
+    return cur
+
+
+def _ensure_table(obj: Dict[str, Any], key: str) -> Dict[str, Any]:
+    if key not in obj or not isinstance(obj[key], dict):
+        obj[key] = tomlkit.table()
+    return obj[key]
+
+
+def _resolve_data_dir(config_doc: Dict[str, Any], explicit_data_dir: Optional[str]) -> Path:
+    if explicit_data_dir:
+        return Path(explicit_data_dir).expanduser().resolve()
+    raw = str(_get_nested(config_doc, ("storage", "data_dir"), "./data") or "./data").strip()
+    if raw.startswith("."):
+        return (PLUGIN_ROOT / raw).resolve()
+    return Path(raw).expanduser().resolve()
+
+
+def _sqlite_table_exists(conn: sqlite3.Connection, table: str) -> bool:
+    row = conn.execute(
+        "SELECT 1 FROM sqlite_master WHERE type='table' AND name=? LIMIT 1",
+        (table,),
+    ).fetchone()
+    return row is not None
+
+
+def _collect_hash_alias_conflicts(conn: sqlite3.Connection) -> Dict[str, List[str]]:
+    """Detect 64-char relation hashes whose 32-char prefix (the alias) collides."""
+    hashes: List[str] = []
+    if _sqlite_table_exists(conn, "relations"):
+        rows = conn.execute("SELECT hash FROM relations").fetchall()
+        hashes.extend(str(r[0]) for r in rows if r and r[0])
+    if _sqlite_table_exists(conn, "deleted_relations"):
+        rows = conn.execute("SELECT hash FROM deleted_relations").fetchall()
+        hashes.extend(str(r[0]) for r in rows if r and r[0])
+
+    alias_map: Dict[str, str] = {}
+    conflicts: Dict[str, set[str]] = {}
+    for h in hashes:
+        if len(h) != 64:
+            continue
+        alias = h[:32]
+        old = alias_map.get(alias)
+        if old is None:
+            alias_map[alias] = h
+            continue
+        if old != h:
+            conflicts.setdefault(alias, set()).update({old, h})
+    return {k: sorted(v) for k, v in conflicts.items()}
+
+
+def _collect_invalid_knowledge_types(conn: sqlite3.Connection) -> List[str]:
+    if not _sqlite_table_exists(conn, "paragraphs"):
+        return []
+
+    allowed = {item.value for item in KnowledgeType}
+    rows = conn.execute("SELECT DISTINCT knowledge_type FROM paragraphs").fetchall()
+    invalid: List[str] = []
+    for row in rows:
+        raw = row[0]
+        value = str(raw).strip().lower() if raw is not None else ""
+        if value not in allowed:
+            invalid.append(str(raw) if raw is not None else "")
+    return sorted(set(invalid))
+
+
+def _guess_vector_dimension(config_doc: Dict[str, Any], vectors_dir: Path) -> int:
+    meta_path = vectors_dir / "vectors_metadata.pkl"
+    if meta_path.exists():
+        try:
+            with open(meta_path, "rb") as f:
+                meta = pickle.load(f)
+            dim = int(meta.get("dimension", 0))
+            if dim > 0:
+                return dim
+        except Exception:
+            pass
+    try:
+        dim_cfg = int(_get_nested(config_doc, ("embedding", "dimension"), 1024))
+        if dim_cfg > 0:
+            return dim_cfg
+    except Exception:
+        pass
+    return 1024
+
+
+def _preflight_impl(config_path: Path, data_dir: Path) -> Dict[str, Any]:
+    checks: List[CheckItem] = []
+    facts: Dict[str, Any] = {
+        "config_path": str(config_path),
+        "data_dir": str(data_dir),
+    }
+
+    if not config_path.exists():
+        checks.append(CheckItem("CFG-00", "error", f"config not found: {config_path}"))
+        return {"ok": False, "checks": [c.to_dict() for c in checks], "facts": facts}
+
+    config_doc = _read_toml(config_path)
+    tool_mode = str(_get_nested(config_doc, ("routing", "tool_search_mode"), "forward") or "").strip().lower()
+    summary_model = _get_nested(config_doc, ("summarization", "model_name"), ["auto"])
+    summary_knowledge_type = str(
+        _get_nested(config_doc, ("summarization", "default_knowledge_type"), "narrative") or "narrative"
+    ).strip().lower()
+    quantization = str(_get_nested(config_doc, ("embedding", "quantization_type"), "int8") or "").strip().lower()
+
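+    # Check codes emitted by this preflight (one CheckItem per finding):
+    #   CP-04  routing.tool_search_mode is legacy/invalid
+    #   CP-05  truncated (32-char) relation-hash alias conflicts
+    #   CP-06  graph edge_hash_map missing or stale
+    #   CP-07  legacy vectors.npy awaiting offline migration
+    #   CP-08  metadata schema missing or version mismatch
+    #   CP-11  summarization.model_name is not List[str]
+    #   CP-12  invalid paragraph knowledge_type values
+    #   CP-13  invalid summarization.default_knowledge_type
+    #   UG-07  embedding.quantization_type is not int8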
+    facts["routing.tool_search_mode"] = tool_mode
+    facts["summarization.model_name_type"] = type(summary_model).__name__
+    facts["summarization.default_knowledge_type"] = summary_knowledge_type
+    facts["embedding.quantization_type"] = quantization
+
+    if tool_mode == "legacy":
+        checks.append(
+            CheckItem(
+                "CP-04",
+                "error",
+                "routing.tool_search_mode=legacy is no longer accepted at runtime",
+            )
+        )
+    elif tool_mode not in {"forward", "disabled"}:
+        checks.append(
+            CheckItem(
+                "CP-04",
+                "error",
+                f"routing.tool_search_mode invalid value: {tool_mode}",
+            )
+        )
+
+    if isinstance(summary_model, str):
+        checks.append(
+            CheckItem(
+                "CP-11",
+                "error",
+                "summarization.model_name must be List[str], string legacy format detected",
+            )
+        )
+    elif not isinstance(summary_model, list) or any(not isinstance(x, str) for x in summary_model):
+        checks.append(
+            CheckItem(
+                "CP-11",
+                "error",
+                "summarization.model_name must be List[str]",
+            )
+        )
+
+    if summary_knowledge_type not in {item.value for item in KnowledgeType}:
+        checks.append(
+            CheckItem(
+                "CP-13",
+                "error",
+                f"invalid summarization.default_knowledge_type: {summary_knowledge_type}",
+            )
+        )
+
+    if quantization != "int8":
+        checks.append(
+            CheckItem(
+                "UG-07",
+                "error",
+                "embedding.quantization_type must be int8 in vNext",
+            )
+        )
+
+    vectors_dir = data_dir / "vectors"
+    npy_path = vectors_dir / "vectors.npy"
+    bin_path = vectors_dir / "vectors.bin"
+    ids_bin_path = vectors_dir / "vectors_ids.bin"
+    facts["vectors.npy_exists"] = npy_path.exists()
+    facts["vectors.bin_exists"] = bin_path.exists()
+    facts["vectors_ids.bin_exists"] = ids_bin_path.exists()
+
+    if npy_path.exists() and not (bin_path.exists() and ids_bin_path.exists()):
+        checks.append(
+            CheckItem(
+                "CP-07",
+                "error",
+                "legacy vectors.npy detected; offline migrate required",
+                {"npy_path": str(npy_path)},
+            )
+        )
+
+    metadata_db = data_dir / "metadata" / "metadata.db"
+    facts["metadata_db_exists"] = metadata_db.exists()
+    relation_count = 0
+    if metadata_db.exists():
+        conn = sqlite3.connect(str(metadata_db))
+        try:
+            has_schema_table = _sqlite_table_exists(conn, "schema_migrations")
+            facts["schema_migrations_exists"] = has_schema_table
+            if not has_schema_table:
+                checks.append(
+                    CheckItem(
+                        "CP-08",
+                        "error",
+                        "schema_migrations table missing (legacy metadata schema)",
+                    )
+                )
+            else:
+                row = conn.execute("SELECT MAX(version) FROM schema_migrations").fetchone()
+                version = int(row[0]) if row and row[0] is not None else 0
+                facts["schema_version"] = version
+                if version != SCHEMA_VERSION:
+                    checks.append(
+                        CheckItem(
+                            "CP-08",
+                            "error",
+                            f"schema version mismatch: current={version}, expected={SCHEMA_VERSION}",
+                        )
+                    )
+
+            if _sqlite_table_exists(conn, "relations"):
+                row = conn.execute("SELECT COUNT(*) FROM relations").fetchone()
+                relation_count = int(row[0]) if row and row[0] is not None else 0
+            facts["relations_count"] = relation_count
+
+            conflicts = _collect_hash_alias_conflicts(conn)
+            facts["alias_conflict_count"] = len(conflicts)
+            if conflicts:
+                checks.append(
+                    CheckItem(
+                        "CP-05",
+                        "error",
+                        "32-char relation hash alias conflict detected",
+                        {"aliases": sorted(conflicts.keys())[:20], "total": len(conflicts)},
+                    )
+                )
+
+            invalid_knowledge_types = _collect_invalid_knowledge_types(conn)
+            facts["invalid_knowledge_type_values"] = invalid_knowledge_types
+            if invalid_knowledge_types:
+                checks.append(
+                    CheckItem(
+                        "CP-12",
+                        "error",
+                        "invalid paragraph knowledge_type values detected",
+                        {"values": invalid_knowledge_types[:20], "total": len(invalid_knowledge_types)},
+                    )
+                )
+        finally:
+            conn.close()
+    else:
+        checks.append(
+            CheckItem(
+                "META-00",
+                "warning",
+                "metadata.db not found, schema checks skipped",
+            )
+        )
+
+    graph_meta_path = data_dir / "graph" / "graph_metadata.pkl"
+    facts["graph_metadata_exists"] = graph_meta_path.exists()
+    if relation_count > 0:
+        if not graph_meta_path.exists():
+            checks.append(
+                CheckItem(
+                    "CP-06",
+                    "error",
+                    "relations exist but graph metadata missing",
+                )
+            )
+        else:
+            try:
+                with open(graph_meta_path, "rb") as f:
+                    graph_meta = pickle.load(f)
+                edge_hash_map = graph_meta.get("edge_hash_map", {})
+                edge_hash_map_size = len(edge_hash_map) if isinstance(edge_hash_map, dict) else 0
+                facts["edge_hash_map_size"] = edge_hash_map_size
+                if edge_hash_map_size <= 0:
+                    checks.append(
+                        CheckItem(
+                            "CP-06",
+                            "error",
+                            "edge_hash_map missing/empty while relations exist",
+                        )
+                    )
+            except Exception as e:
+                checks.append(
+                    CheckItem(
+                        "CP-06",
+                        "error",
+                        f"failed to read graph metadata: {e}",
+                    )
+                )
+
+    has_error = any(c.level == "error" for c in checks)
+    return {
+        "ok": not has_error,
+        "checks": [c.to_dict() for c in checks],
+        "facts": facts,
+    }
+
+
+def _migrate_config(config_doc: Dict[str, Any]) -> Dict[str, Any]:
+    changes: Dict[str, Any] = {}
+
+    routing = _ensure_table(config_doc, "routing")
+    mode_raw = str(routing.get("tool_search_mode", "forward") or "").strip().lower()
+    mode_new = mode_raw
+    # "legacy" and any other unrecognized value are coerced to the vNext default.
+    if mode_raw not in {"forward", "disabled"}:
+        mode_new = "forward"
+    if mode_new != mode_raw:
+        routing["tool_search_mode"] = mode_new
+        changes["routing.tool_search_mode"] = {"old": mode_raw, "new": mode_new}
+
+    summary = _ensure_table(config_doc, "summarization")
+    summary_model = summary.get("model_name", ["auto"])
+    if isinstance(summary_model, str):
+        normalized = [summary_model.strip() or "auto"]
+        summary["model_name"] = normalized
+        changes["summarization.model_name"] = {"old": summary_model, "new": normalized}
+    elif not isinstance(summary_model, list):
+        normalized = ["auto"]
+        summary["model_name"] = normalized
+        changes["summarization.model_name"] = {"old": str(type(summary_model)), "new": normalized}
+    elif any(not isinstance(x, str) for x in summary_model):
+        normalized = [str(x).strip() for x in summary_model if str(x).strip()]
+        if not normalized:
+            normalized = ["auto"]
+        summary["model_name"] = normalized
+        changes["summarization.model_name"] = {"old": summary_model, "new": normalized}
+
+    default_knowledge_type = str(summary.get("default_knowledge_type", "narrative") or "").strip().lower()
+    allowed_knowledge_types = {item.value for item in KnowledgeType}
+    if default_knowledge_type not in allowed_knowledge_types:
+        summary["default_knowledge_type"] = "narrative"
+        changes["summarization.default_knowledge_type"] = {
+            "old": default_knowledge_type,
+            "new": "narrative",
+        }
+
+    embedding = _ensure_table(config_doc, "embedding")
+    quantization = str(embedding.get("quantization_type", "int8") or "").strip().lower()
+    if quantization != "int8":
+        embedding["quantization_type"] = "int8"
+        changes["embedding.quantization_type"] = {"old": quantization, "new": "int8"}
+
+    return changes
+
+
+def _migrate_impl(config_path: Path, data_dir: Path, dry_run: bool) -> Dict[str, Any]:
+    config_doc = _read_toml(config_path)
+    result: Dict[str, Any] = {
+        "config_path": str(config_path),
+        "data_dir": str(data_dir),
+        "dry_run": bool(dry_run),
+        "steps": {},
+    }
+
+    config_changes = _migrate_config(config_doc)
+    result["steps"]["config"] = {"changed": bool(config_changes), "changes": config_changes}
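+
+    # Steps below: persist config changes, migrate the vector store (legacy
+    # .npy -> int8 .bin + ids), run the metadata schema migration, then rebuild
+    # the graph edge hash map; each step records its outcome under result["steps"].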
+    if config_changes and not dry_run:
+        _write_toml(config_path, config_doc)
+
+    vectors_dir = data_dir / "vectors"
+    vectors_dir.mkdir(parents=True, exist_ok=True)
+    npy_path = vectors_dir / "vectors.npy"
+    bin_path = vectors_dir / "vectors.bin"
+    ids_bin_path = vectors_dir / "vectors_ids.bin"
+    if npy_path.exists() and not (bin_path.exists() and ids_bin_path.exists()):
+        if dry_run:
+            result["steps"]["vector"] = {"migrated": False, "reason": "dry_run"}
+        else:
+            dim = _guess_vector_dimension(config_doc, vectors_dir)
+            store = VectorStore(
+                dimension=max(1, int(dim)),
+                quantization_type=QuantizationType.INT8,
+                data_dir=vectors_dir,
+            )
+            result["steps"]["vector"] = store.migrate_legacy_npy(vectors_dir)
+    else:
+        result["steps"]["vector"] = {"migrated": False, "reason": "not_required"}
+
+    metadata_dir = data_dir / "metadata"
+    metadata_dir.mkdir(parents=True, exist_ok=True)
+    metadata_db = metadata_dir / "metadata.db"
+    triples: List[Tuple[str, str, str, str]] = []
+    relation_count = 0
+
+    metadata_result: Dict[str, Any] = {"migrated": False, "reason": "not_required"}
+    if metadata_db.exists():
+        store = MetadataStore(data_dir=metadata_dir)
+        store.connect(enforce_schema=False)
+        try:
+            if dry_run:
+                metadata_result = {"migrated": False, "reason": "dry_run"}
+            else:
+                metadata_result = store.run_legacy_migration_for_vnext()
+                relation_count = int(store.count_relations())
+                if relation_count > 0:
+                    triples = [(str(s), str(p), str(o), str(h)) for s, p, o, h in store.get_all_triples()]
+        finally:
+            store.close()
+    result["steps"]["metadata"] = metadata_result
+
+    graph_dir = data_dir / "graph"
+    graph_dir.mkdir(parents=True, exist_ok=True)
+    graph_matrix_format = str(_get_nested(config_doc, ("graph", "sparse_matrix_format"), "csr") or "csr")
+    graph_store = GraphStore(matrix_format=graph_matrix_format, data_dir=graph_dir)
+    graph_step: Dict[str, Any] = {
+        "rebuilt": False,
+        "mapped_hashes": 0,
+        "relation_count": relation_count,
+        "topology_rebuilt_from_relations": False,
+    }
+    if relation_count > 0:
+        if dry_run:
+            graph_step["reason"] = "dry_run"
+        else:
+            if graph_store.has_data():
+                graph_store.load()
+
+            mapped = graph_store.rebuild_edge_hash_map(triples)
+
+            # Fallback: when historical graph nodes/edges have drifted out of
+            # sync with relations, rebuild the graph topology directly from relations.
+            if mapped <= 0 or not graph_store.has_edge_hash_map():
+                nodes = sorted({s for s, _, _, _ in triples} | {o for _, _, o, _ in triples})
+                edges = [(s, o) for s, _, o, _ in triples]
+                hashes = [h for _, _, _, h in triples]
+
+                graph_store.clear()
+                if nodes:
+                    graph_store.add_nodes(nodes)
+                if edges:
+                    mapped = graph_store.add_edges(edges, relation_hashes=hashes)
+                else:
+                    mapped = 0
+                graph_step.update(
+                    {
+                        "topology_rebuilt_from_relations": True,
+                        "rebuilt_nodes": len(nodes),
+                        "rebuilt_edges": int(graph_store.num_edges),
+                    }
+                )
+
+            graph_store.save()
+            graph_step.update({"rebuilt": True, "mapped_hashes": int(mapped)})
+    else:
+        graph_step["reason"] = "no_relations"
+    result["steps"]["graph"] = graph_step
+
+    return result
+
+
+def _verify_impl(config_path: Path, data_dir: Path) -> Dict[str, Any]:
+    checks: List[CheckItem] = []
+    facts: Dict[str, Any] = {
+        "config_path": str(config_path),
+        "data_dir": str(data_dir),
+    }
+
+    if not config_path.exists():
+        checks.append(CheckItem("CFG-00", "error", f"config not found: {config_path}"))
+        return {"ok": False, "checks": [c.to_dict() for c in checks], "facts": facts}
+
+    config_doc = _read_toml(config_path)
+    mode = str(_get_nested(config_doc, ("routing", "tool_search_mode"), "forward") or "").strip().lower()
+    if mode not in {"forward", "disabled"}:
+        checks.append(CheckItem("CP-04", "error", f"invalid routing.tool_search_mode: {mode}"))
+
+    summary_model = _get_nested(config_doc, ("summarization", "model_name"), ["auto"])
+    if not isinstance(summary_model, list) or any(not isinstance(x, str) for x in summary_model):
+        checks.append(CheckItem("CP-11", "error", "summarization.model_name must be List[str]"))
+    summary_knowledge_type = str(
+        _get_nested(config_doc, ("summarization", "default_knowledge_type"), "narrative") or "narrative"
+    ).strip().lower()
+    if summary_knowledge_type not in {item.value for item in KnowledgeType}:
+        checks.append(
+            CheckItem("CP-13", "error", f"invalid summarization.default_knowledge_type: {summary_knowledge_type}")
+        )
+
+    quantization = str(_get_nested(config_doc, ("embedding", "quantization_type"), "int8") or "").strip().lower()
+    if quantization != "int8":
+        checks.append(CheckItem("UG-07", "error", "embedding.quantization_type must be int8"))
+
+    vectors_dir = data_dir / "vectors"
+    npy_path = vectors_dir / "vectors.npy"
+    bin_path = vectors_dir / "vectors.bin"
+    ids_bin_path = vectors_dir / "vectors_ids.bin"
+    if npy_path.exists() and not (bin_path.exists() and ids_bin_path.exists()):
+        checks.append(CheckItem("CP-07", "error", "legacy vectors.npy still exists without bin migration"))
+
+    metadata_dir = data_dir / "metadata"
+    store = MetadataStore(data_dir=metadata_dir)
+    try:
+        store.connect(enforce_schema=True)
+        schema_version = store.get_schema_version()
+        facts["schema_version"] = schema_version
+        if schema_version != SCHEMA_VERSION:
+            checks.append(CheckItem("CP-08", "error", f"schema version mismatch: {schema_version}"))
+
+        relation_count = int(store.count_relations())
+        facts["relations_count"] = relation_count
+
+        conflicts: Dict[str, List[str]] = {}
+        invalid_knowledge_types: List[str] = []
+        db_path = metadata_dir / "metadata.db"
+        if db_path.exists():
+            conn = sqlite3.connect(str(db_path))
+            try:
+                conflicts = _collect_hash_alias_conflicts(conn)
+                invalid_knowledge_types = _collect_invalid_knowledge_types(conn)
+            finally:
+                conn.close()
+        if conflicts:
+            checks.append(
+                CheckItem(
+                    "CP-05",
+                    "error",
+                    "alias conflicts still exist after migration",
+                    {"aliases": sorted(conflicts.keys())[:20], "total": len(conflicts)},
+                )
+            )
+        if invalid_knowledge_types:
+            checks.append(
+                CheckItem(
+                    "CP-12",
+                    "error",
+                    "invalid paragraph knowledge_type values remain after migration",
+                    {"values": invalid_knowledge_types[:20], "total": len(invalid_knowledge_types)},
+                )
+            )
+
+        if relation_count > 0:
+            graph_dir = data_dir / "graph"
+            if not (graph_dir / "graph_metadata.pkl").exists():
+                checks.append(CheckItem("CP-06", "error", "graph metadata missing while relations exist"))
+            else:
+                matrix_format = str(_get_nested(config_doc, ("graph", "sparse_matrix_format"), "csr") or "csr")
+                graph_store = GraphStore(matrix_format=matrix_format, data_dir=graph_dir)
+                graph_store.load()
+                if not graph_store.has_edge_hash_map():
+                    checks.append(CheckItem("CP-06", "error", "edge_hash_map is empty"))
+    except Exception as e:
+        checks.append(CheckItem("CP-08", "error", f"metadata strict connect failed: {e}"))
+    finally:
+        try:
+            store.close()
+        except Exception:
+            pass
+
+    has_error = any(c.level == "error" for c in checks)
+    return {
+        "ok": not has_error,
+        "checks": [c.to_dict() for c in checks],
+        "facts": facts,
+    }
+
+
+def _print_report(title: str, report: Dict[str, Any]) -> None:
+    print(f"=== {title} ===")
+    print(f"ok: {bool(report.get('ok', True))}")
+    facts = report.get("facts", {})
+    if facts:
+        print("facts:")
+        for k in sorted(facts.keys()):
+            print(f"  - {k}: {facts[k]}")
+    checks = report.get("checks", [])
+    if checks:
+        print("checks:")
+        for item in checks:
+            print(f"  - [{item.get('level')}] {item.get('code')}: {item.get('message')}")
+    else:
+        print("checks: none")
+
+
+def _write_json_if_needed(path: str, payload: Dict[str, Any]) -> None:
+    if not path:
+        return
+    out = Path(path).expanduser().resolve()
+    out.parent.mkdir(parents=True, exist_ok=True)
+    out.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
+    print(f"json_out: {out}")
+
+
+def main() -> int:
+    parser = _build_arg_parser()
+    args = parser.parse_args()
+    config_path = Path(args.config).expanduser().resolve()
+    if not config_path.exists():
+        print(f"❌ config not found: {config_path}")
+        return 2
+    config_doc = _read_toml(config_path)
+    data_dir = _resolve_data_dir(config_doc, args.data_dir)
+
+    if args.command == "preflight":
+        report = _preflight_impl(config_path, data_dir)
+        _print_report("vNext Preflight", report)
+        _write_json_if_needed(args.json_out, report)
+        has_error = any(item.get("level") == "error" for item in report.get("checks", []))
+        if args.strict and has_error:
+            return 1
+        return 0
+
+    if args.command == "migrate":
+        payload = _migrate_impl(config_path, data_dir, dry_run=bool(args.dry_run))
+        print("=== vNext Migrate ===")
+        print(json.dumps(payload, ensure_ascii=False, indent=2))
+
+        verify_report = None
+        if args.verify_after and not args.dry_run:
+            verify_report = _verify_impl(config_path, data_dir)
+            _print_report("vNext Verify (after migrate)", verify_report)
+            payload["verify_after"] = verify_report
+
+        _write_json_if_needed(args.json_out, payload)
+        if verify_report is not None:
+            has_error = any(item.get("level") == "error" for item in verify_report.get("checks", []))
+            if has_error:
+                return 1
+        return 0
+
+    if args.command == "verify":
+        report = _verify_impl(config_path, data_dir)
+        _print_report("vNext Verify", report)
+        _write_json_if_needed(args.json_out, report)
+        has_error = any(item.get("level") == "error" for item in report.get("checks", []))
+        if args.strict and has_error:
+            return 1
+        return 0
+
+    return 2
+
+
+if __name__ == "__main__":
+    raise SystemExit(main())
diff --git a/plugins/A_memorix/scripts/runtime_self_check.py b/plugins/A_memorix/scripts/runtime_self_check.py
new file mode 100644
index 00000000..70c423ac
--- /dev/null
+++ b/plugins/A_memorix/scripts/runtime_self_check.py
@@ -0,0 +1,152 @@
+#!/usr/bin/env python3
+"""Run A_Memorix runtime self-check against real embedding/runtime configuration."""
+
+from __future__ import annotations
+
+import argparse
+import asyncio
+import json
+import sys
+import tempfile
+from pathlib import Path
+from typing import Any
+
+import tomlkit
+
+
+CURRENT_DIR = Path(__file__).resolve().parent
+PLUGIN_ROOT = CURRENT_DIR.parent
+PROJECT_ROOT = PLUGIN_ROOT.parent.parent
+sys.path.insert(0, str(PROJECT_ROOT))
+sys.path.insert(0, str(PLUGIN_ROOT))
+
+
+def _build_arg_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser(description="A_Memorix runtime self-check")
+    parser.add_argument(
+        "--config",
+        default=str(PLUGIN_ROOT / "config.toml"),
+        help="config.toml path (default: plugins/A_memorix/config.toml)",
+    )
+    parser.add_argument(
+        "--data-dir",
+        default="",
+        help="optional data dir override; default resolved from config.storage.data_dir",
+    )
+    parser.add_argument(
+        "--use-config-data-dir",
+        action="store_true",
+        help="use config.storage.data_dir directly instead of an isolated temp dir",
+    )
+    parser.add_argument(
+        "--sample-text",
+        default="A_Memorix runtime self check",
+        help="sample text used for real embedding probe",
+    )
+    parser.add_argument("--json", action="store_true", help="print JSON report")
+    return parser
+
+
+if any(arg in {"-h", "--help"} for arg in sys.argv[1:]):
+    _build_arg_parser().print_help()
+    raise SystemExit(0)
+
+from core.runtime.lifecycle_orchestrator import initialize_storage_async
+from core.utils.runtime_self_check import run_embedding_runtime_self_check
+
+
+def _load_config(path: Path) -> dict[str, Any]:
+    with open(path, "r", encoding="utf-8") as f:
+        raw = tomlkit.load(f)
+    return dict(raw) if isinstance(raw, dict) else {}
+
+
+def _nested_get(config: dict[str, Any], key: str, default: Any = None) -> Any:
+    current: Any = config
+    for part in key.split("."):
+        if isinstance(current, dict) and part in current:
+            current = current[part]
+        else:
+            return default
+    return current
+
+
+class _PluginStub:
+    def __init__(self, config: dict[str, Any]):
+        self.config = config
+        self.vector_store = None
+        self.graph_store = None
+        self.metadata_store = None
+        self.embedding_manager = None
+        self.sparse_index = None
+        self.relation_write_service = None
+
+    def get_config(self, key: str, default: Any = None) -> Any:
+        return _nested_get(self.config, key, default)
+
+
+async def _main_async(args: argparse.Namespace) -> int:
+    config_path = Path(args.config).resolve()
+    if not config_path.exists():
+        print(f"❌ config file not found: {config_path}")
+        return 2
+
+    config = _load_config(config_path)
+    temp_dir_ctx = None
+    if args.data_dir:
+        storage_dir = str(Path(args.data_dir).resolve())
+    elif args.use_config_data_dir:
+        raw_data_dir = str(_nested_get(config, "storage.data_dir", "./data") or "./data").strip()
+        if raw_data_dir.startswith("."):
+            storage_dir = str((config_path.parent / raw_data_dir).resolve())
+        else:
+            storage_dir = str(Path(raw_data_dir).resolve())
+    else:
+        temp_dir_ctx = tempfile.TemporaryDirectory(prefix="memorix-runtime-self-check-")
+        storage_dir = temp_dir_ctx.name
+
+    storage_cfg = config.setdefault("storage", {})
+    storage_cfg["data_dir"] = storage_dir
+
+    plugin = _PluginStub(config)
+    try:
+        await initialize_storage_async(plugin)
+        report = await run_embedding_runtime_self_check(
+            config=config,
+            vector_store=plugin.vector_store,
+            embedding_manager=plugin.embedding_manager,
+            sample_text=str(args.sample_text or "A_Memorix runtime self check"),
+        )
+        report["data_dir"] = storage_dir
+        report["isolated_data_dir"] = temp_dir_ctx is not None
+        if args.json:
+            print(json.dumps(report, ensure_ascii=False, indent=2))
+        else:
+            print("A_Memorix Runtime Self-Check")
+            print(f"ok: {report.get('ok')}")
+            print(f"code: {report.get('code')}")
+            print(f"message: {report.get('message')}")
+            print(f"configured_dimension: {report.get('configured_dimension')}")
+            print(f"vector_store_dimension: {report.get('vector_store_dimension')}")
+            print(f"detected_dimension: {report.get('detected_dimension')}")
+            print(f"encoded_dimension: {report.get('encoded_dimension')}")
+            print(f"elapsed_ms: {float(report.get('elapsed_ms', 0.0)):.2f}")
+        return 0 if bool(report.get("ok")) else 1
+    finally:
+        if plugin.metadata_store is not None:
+            try:
+                plugin.metadata_store.close()
+            except Exception:
+                pass
+        if temp_dir_ctx is not None:
+            temp_dir_ctx.cleanup()
+
+
+def main() -> int:
+    parser = _build_arg_parser()
+    args = parser.parse_args()
+    return asyncio.run(_main_async(args))
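+
+
+# Example invocations (illustrative):
+#   python plugins/A_memorix/scripts/runtime_self_check.py --json
+#   python plugins/A_memorix/scripts/runtime_self_check.py --use-config-data-dir --json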
+
+
+if __name__ == "__main__":
+    raise SystemExit(main())
diff --git a/pyproject.toml b/pyproject.toml
index f6dd6646..d9ce5c5d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -33,6 +33,7 @@ dependencies = [
     "python-levenshtein",
     "quick-algo>=0.1.4",
    "rich>=14.0.0",
+    "scipy>=1.7.0",
     "sqlalchemy>=2.0.40",
     "sqlmodel>=0.0.24",
     "structlog>=25.4.0",
diff --git a/requirements.txt b/requirements.txt
index 6a72e1e4..50a5a746 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -23,6 +23,7 @@ python-multipart>=0.0.20
 python-levenshtein
 quick-algo>=0.1.4
 rich>=14.0.0
+scipy>=1.7.0
 sqlalchemy>=2.0.40
 sqlmodel>=0.0.24
 structlog>=25.4.0
diff --git a/src/main.py b/src/main.py
index c28b6025..1bfa91b0 100644
--- a/src/main.py
+++ b/src/main.py
@@ -167,6 +167,7 @@ async def main() -> None:
             system.schedule_tasks(),
         )
     finally:
+        emoji_manager.shutdown()
         await memory_automation_service.shutdown()
         await get_plugin_runtime_manager().bridge_event("on_stop")
         await get_plugin_runtime_manager().stop()