diff --git a/plugins/A_memorix b/plugins/A_memorix
new file mode 160000
index 00000000..5fc5026a
--- /dev/null
+++ b/plugins/A_memorix
@@ -0,0 +1 @@
+Subproject commit 5fc5026a540c1cfd55a7b824b43aaeef867e3228
diff --git a/plugins/A_memorix/.gitattributes b/plugins/A_memorix/.gitattributes
deleted file mode 100644
index dfe07704..00000000
--- a/plugins/A_memorix/.gitattributes
+++ /dev/null
@@ -1,2 +0,0 @@
# Auto detect text files and perform LF normalization
* text=auto
diff --git a/plugins/A_memorix/.gitignore b/plugins/A_memorix/.gitignore
deleted file mode 100644
index bb349827..00000000
--- a/plugins/A_memorix/.gitignore
+++ /dev/null
@@ -1,245 +0,0 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
#uv.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# Ruff stuff:
.ruff_cache/

# PyPI configuration file
.pypirc

# Cursor
# Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
# exclude from AI features like autocomplete and code analysis. Recommended for sensitive data;
# refer to https://docs.cursor.com/context/ignore-files
.cursorignore
.cursorindexingignore

# Python
__pycache__/
*.pyc
*.pyo
*.pyd
*.egg-info/

# Data & Storage (Privacy & Runtime)
data/
logs/

# Deprecated / Cleanup (Avoid uploading junk)
deprecated/

# OS / System
.DS_Store
Thumbs.db
ehthumbs.db

# IDE settings
.idea/
.vscode/

# Temporary Verification Scripts
verify_*.py
config.toml

# Test Artifacts & Generated Files
MagicMock/
benchmark_output.txt
e2e_debug.log
e2e_error.log
full_diff.txt

# Large Test Data Files
机娘导论-openie.json
scripts/机娘导论-openie.json

# A_memorix recall/tuning generated artifacts
artifacts/
scripts/run_arc_light_recall_pipeline.py

# Compressed Data Archives
data.zip
scripts/full_feature_smoke_test.py
ACL2026_DEMO_EVAL.md
.probe_write
tests/
temp_verify_v5_data/metadata/metadata.db
sql2/t.db
sql2/t.db-journal
scripts/test.json
scripts/test1.json
scripts/test-sample.json
USAGE_ARCHITECTURE.md
scripts/test_conversion.py
scripts/debug_graph_vis.py
/.tmp_feature_e2e_real
/.tmp_sparse_tests
/.tmp_test_probe
/.tmp_test_sqlite
/.tmp_testdata
/scripts/tmp
diff --git a/plugins/A_memorix/CHANGELOG.md b/plugins/A_memorix/CHANGELOG.md
deleted file mode 100644
index 772cff46..00000000
--- a/plugins/A_memorix/CHANGELOG.md
+++ /dev/null
@@ -1,718 +0,0 @@
# Changelog

## [2.0.0] - 2026-03-18

`2.0.0` is an architecture-consolidation release. Its main lines are **unifying the SDK Tool interface**, **filling out the admin tooling**, **upgrading the metadata schema to v8**, and **bringing the docs in line with 2.0.0**.

### 🔖 Version info

- Plugin version: `1.0.1` → `2.0.0`
- Metadata schema: `7` → `8`

### 🚀 Highlights

- Unified Tool interface:
  - `plugin.py` now exposes all Tool capabilities through a single `SDKMemoryKernel`.
  - Basic tools retained: `search_memory / ingest_summary / ingest_text / get_person_profile / maintain_memory / memory_stats`.
  - New admin tools: `memory_graph_admin / memory_source_admin / memory_episode_admin / memory_profile_admin / memory_runtime_admin / memory_import_admin / memory_tuning_admin / memory_v5_admin / memory_delete_admin`.
- Retrieval and write governance:
  - The retrieval and write paths support chat-filter semantics via `respect_filter + user_id/group_id` (see the sketch after this list).
  - `maintain_memory` supports `freeze` and `recycle_bin`, unified into the kernel maintenance flow.
- Import and tuning consolidation:
  - `memory_import_admin` provides task-based imports (upload, paste, scan, OpenIE, LPMM conversion, temporal backfill, MaiBot migration).
  - `memory_tuning_admin` provides retrieval-tuning tasks (create, inspect rounds, rollback, apply_best, report export).
- V5 and deletion operations:
  - New `memory_v5_admin` (`reinforce/weaken/remember_forever/forget/restore/status`).
  - New `memory_delete_admin` (`preview/execute/restore/list/get/purge`), with operation auditing and recovery.
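As a rough sketch of the filter semantics above, a `search_memory` call that opts into chat filtering might look like this (`respect_filter`, `user_id`, `group_id`, and `mode` are names used in this release; the `query` key is an assumed argument name):

```json
{
  "tool": "search_memory",
  "arguments": {
    "query": "retrieval tuning regression",
    "mode": "search",
    "respect_filter": true,
    "user_id": "456",
    "group_id": "123"
  }
}
```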
### 🛠️ Storage and runtime

- `metadata_store` upgraded to `SCHEMA_VERSION = 8`.
- Added/extended external-reference and operations records (including the `external_memory_refs`, `memory_v5_operations`, and `delete_operations` structures).
- `SDKMemoryKernel` gains unified background-task orchestration (auto-save, Episode pending processing, profile refresh, memory maintenance).

### 📚 Docs

- `README.md`, `QUICK_START.md`, `CONFIG_REFERENCE.md`, and `IMPORT_GUIDE.md` are now on the `2.0.0` baseline.
- The main documentation entry point is the SDK Tool workflow; the legacy slash commands are no longer the primary path.

## [1.0.1] - 2026-03-07

`1.0.1` is a hotfix release on top of `1.0.0`. Its main lines are **stability fixes for graph WebUI data loading**, **a large-graph filtering performance fix**, and **stability fixes for the real retrieval-tuning pipeline**.

### 🔖 Version info

- Plugin version: `1.0.0` → `1.0.1`
- Config version: `4.1.0` (unchanged)

### 🛠️ Fixes

- Graph API stability:
  - Fixed `/api/graph` returning an empty graph when a graph file exists on disk but has not yet been loaded into memory; the endpoint now loads the persisted graph data on demand.
  - Fixed the WebUI graph page appearing to have "no nodes at all" on problematic datasets; the root cause was not missing graph data but an overly slow backend filtering path.
- Graph filtering performance:
  - Optimized the leaf filtering in `/api/graph?exclude_leaf=true` to precompute hub adjacency instead of repeatedly issuing expensive edge-weight queries per node.
  - Optimized `GraphStore.get_neighbors()` and added incoming-neighbor access, avoiding the large-graph slowdown caused by expanding a dense matrix.
- Retrieval-tuning stability:
  - Fixed real tuning tasks deep-copying `plugin.config` when building the runtime configuration, which copied the injected storage instances and raised `cannot pickle '_thread.RLock' object`.
  - Tuning evaluation now skips top-level runtime-instance keys, keeps only plain configuration fields, and attaches runtime dependencies afterwards; real WebUI tuning tasks start normally again.

### 📚 Docs

- Synced `README.md`, `CHANGELOG.md`, `CONFIG_REFERENCE.md`, and the version metadata (`plugin.py`, `__init__.py`, `_manifest.json`).
- README gains a `v1.0.1` fix note plus a recommendation to run a runtime self-check before tuning.

## [1.0.0] - 2026-03-06

`1.0.0` is a major release. Its main lines are **modularizing the runtime architecture**, **closing the Episode (episodic memory) loop**, **aggregate retrieval and graph-recall enhancements**, and **offline migration / runtime self-check / a retrieval-tuning center**.

### 🔖 Version info

- Plugin version: `0.7.0` → `1.0.0`
- Config version: `4.1.0` (unchanged)

### 🚀 Highlights

- Runtime refactor:
  - `plugin.py` slimmed down substantially; lifecycle, background tasks, request routing, and retrieval-runtime initialization moved into `core/runtime/*`.
  - The config schema moved to `core/config/plugin_config_schema.py`; `_manifest.json` extended with the new config keys.
- Retrieval and query enhancements:
  - `KnowledgeQueryTool` split into query modes plus an orchestrator, with new `aggregate` / `episode` query modes.
  - Added graph-assisted relation recall, plus unified forward/runtime construction and request-deduplication bridging.
- Episode / operations:
  - `metadata_store` schema upgraded to `SCHEMA_VERSION = 7`, adding `episodes` / `episode_paragraphs` / rebuild-queue structures.
  - New `release_vnext_migrate.py`, `runtime_self_check.py`, `rebuild_episodes.py`, and the web tuning page `web/tuning.html`.

### 📚 Docs

- Version numbers synced across `plugin.py`, `__init__.py`, `_manifest.json`, `README.md`, and `CONFIG_REFERENCE.md`.
- Added `RELEASE_SUMMARY_1.0.0.md`.

## [0.7.0] - 2026-03-04

`0.7.0` is a minor release. Its main lines are **a closed loop for relation vectorization (write + state machine + backfill + audit)**, **retrieval/command path enhancements**, and **rounding out the import tasks**.

### 🔖 Version info

- Plugin version: `0.6.1` → `0.7.0`
- Config version: `4.1.0` (unchanged)

### 🚀 Highlights

- Relation-vectorization loop:
  - New unified relation write service `RelationWriteService` (metadata written first, vectors second; failures enter the state machine instead of rolling back the primary data).
  - The `relations` side gains `vector_state/retry_count/last_error/updated_at` state fields, governed uniformly as `none/pending/ready/failed`.
  - The plugin adds a background backfill loop and statistics endpoints to continuously repair missing relation vectors and expose coverage metrics.
- Retrieval and command paths:
  - The main retrieval chain keeps converging on the `search/time` forward routes; `legacy` remains only as a compatibility alias.
  - Relation query-spec parsing is consolidated, with a clearer boundary between structured queries and semantic fallback.
  - `/query stats` and tool stats now report relation-vectorization statistics.
- Import and operations:
  - Web Import gains a `temporal_backfill` task entry and orchestration.
  - New consistency-audit and offline-backfill scripts support gradual repair of historical data.

### 📚 Docs

- Synced `README.md`, `CONFIG_REFERENCE.md`, and the version info in this changelog.
- `README.md` adds usage notes for the relation-vector audit/backfill scripts and updates the description of `convert_lpmm.py`'s relation-vector rebuild behavior.

## [0.6.1] - 2026-03-03

`0.6.1` is a hotfix release focused on a `tomlkit` node-serialization compatibility issue in the WebUI plugin-config API in the A_Memorix case.

### 🔖 Version info

- Plugin version: `0.6.0` → `0.6.1`
- Config version: `4.1.0` (unchanged)

### 🛠️ Fixes

- New runtime patch `_patch_webui_a_memorix_routes_for_tomlkit_serialization()`:
  - Wraps only the `GET` routes for `/api/webui/plugins/config/{plugin_id}` and its schema.
  - Only when `plugin_id == "A_Memorix"`, normalizes the returned `config/schema` to builtin types via `to_builtin_data` (illustrated in the sketch after this section).
  - Leaves the global `/api/webui/config/*` endpoints untouched, avoiding side effects on other plugins or core config paths.
- The patch runs at plugin initialization, so the structures returned when the WebUI reads plugin config serialize reliably.

### 📚 Docs

- Synced `README.md`, `CONFIG_REFERENCE.md`, and the version info and fix notes in this changelog.
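To illustrate the kind of normalization `to_builtin_data` performs, here is a minimal sketch; it assumes tomlkit's `unwrap()` API and is not the plugin's actual implementation:

```python
from tomlkit import parse

def to_builtin(value):
    """Recursively unwrap tomlkit nodes into plain dicts/lists/scalars
    so the result survives generic JSON serialization."""
    if hasattr(value, "unwrap"):  # tomlkit containers and items expose unwrap()
        value = value.unwrap()
    if isinstance(value, dict):
        return {key: to_builtin(item) for key, item in value.items()}
    if isinstance(value, list):
        return [to_builtin(item) for item in value]
    return value

doc = parse("[plugin]\nenabled = true\n")
print(to_builtin(doc))  # {'plugin': {'enabled': True}}
```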
## [0.6.0] - 2026-03-02

`0.6.0` is a minor release. Its main lines are **launching the Web Import center and aligning it with the script capabilities**, **an upgraded failure-retry mechanism**, **manifest sync after deletions**, and **import-path stability**.

### 🔖 Version info

- Plugin version: `0.5.1` → `0.6.0`
- Config version: `4.0.1` → `4.1.0`

### 🚀 Highlights

- New Web Import center (`/import`):
  - Upload / paste / local scan / LPMM OpenIE / LPMM conversion / temporal backfill / MaiBot migration.
  - Three-level status display (task / file / chunk), with cancellation and failure retry.
  - Import docs shown in a dialog (remote first, falling back to local on failure).
- Failure retry upgraded to "chunk first, file fallback":
  - `POST /api/import/tasks/{task_id}/retry_failed` keeps its path; the semantics are upgraded.
  - Supports retrying only the subset of chunks that failed during `extracting`.
  - `writing`/JSON-parse failures automatically fall back to file-level retry.
- Manifest invalidation after deletions:
  - Covers `/api/source/batch_delete` and `/api/source`.
  - Returns `manifest_cleanup` details, so deduplication does not wrongly skip a re-import.

### 📂 Files changed in this release

New files:

- `core/utils/web_import_manager.py`
- `scripts/migrate_maibot_memory.py`
- `web/import.html`

Modified files:

- `CHANGELOG.md`
- `CONFIG_REFERENCE.md`
- `IMPORT_GUIDE.md`
- `QUICK_START.md`
- `README.md`
- `__init__.py`
- `_manifest.json`
- `components/commands/debug_server_command.py`
- `core/embedding/api_adapter.py`
- `core/storage/graph_store.py`
- `core/utils/summary_importer.py`
- `plugin.py`
- `requirements.txt`
- `server.py`
- `web/index.html`

Deleted files:

- None

### 📚 Docs

- Synced `README.md`, `QUICK_START.md`, `CONFIG_REFERENCE.md`, `IMPORT_GUIDE.md`, and this changelog.
- `IMPORT_GUIDE.md` gains a "Web Import center" section describing capabilities, status semantics, and security boundaries in one place.

## [0.5.1] - 2026-02-23

`0.5.1` is a hotfix release focused on "background tasks starting with the host program", "empty-list filter semantics", and "knowledge-extraction model selection".

### 🔖 Version info

- Plugin version: `0.5.0` → `0.5.1`
- Config version: `4.0.0` → `4.0.1`

### 🛠️ Fixes

- Lifecycle hooked into host events:
  - New `a_memorix_start_handler` (`ON_START`) calls `plugin.on_enable()`;
  - New `a_memorix_stop_handler` (`ON_STOP`) calls `plugin.on_disable()`;
  - Fixes scheduled import tasks not starting when the plugin was registered but its lifecycle was never triggered.
- Empty-list chat-filter policy:
  - `whitelist + []`: reject everything;
  - `blacklist + []`: allow everything.
- Knowledge-extraction model selection (`import_command._select_model`):
  - `advanced.extraction_model` now accepts three forms: a task name, a model name, or `auto`;
  - `auto` prefers extraction-related tasks (`lpmm_entity_extract`, `lpmm_rdf_build`, etc.) and avoids falling through to `embedding`;
  - Unrecognized values log a warning and fall back to automatic selection, making model choice during import more predictable.

### 📚 Docs

- Synced `README.md`, `CONFIG_REFERENCE.md`, and `CHANGELOG.md`.
- Corrected the documented empty-list filter behavior to match the current code.

## [0.5.0] - 2026-02-15

`0.5.0` centers on commit `66ddc1b98547df3c866b19a3f5dc96e1c8eb7731`. Its main line is "person profiles go live + tool/command integration + version and docs sync".

### 🔖 Version info

- Plugin version: `0.4.0` → `0.5.0`
- Config version: `3.1.0` → `4.0.0`

### 🚀 Person profiles (core feature)

- New person-profile service: `core/utils/person_profile_service.py`
  - Resolves `person_id` / name / alias.
  - Aggregates graph-relation evidence plus vector evidence, generates profile text, and versions it as snapshots.
  - Supports manual overrides and TTL-based snapshot reuse.
- New person-profile tables and APIs in the storage layer: `core/storage/metadata_store.py`
  - `person_profile_switches`
  - `person_profile_snapshots`
  - `person_profile_active_persons`
  - `person_profile_overrides`
- New command: `/person_profile on|off|status`
  - File: `components/commands/person_profile_command.py`
  - Controls the auto-injection switch per `stream_id + user_id` (opt-in).
- Query path integration (see the sketch after this list):
  - `knowledge_query_tool` gains `query_type=person`, queryable by `person_id` or alias.
  - `/query person` and `/query p` output profile queries.
- Plugin lifecycle integration:
  - The `person_profile_refresh` background task is managed on start/stop.
  - Profile snapshots refresh automatically based on the active window.
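A person-profile query through the tool interface might look like the following sketch (`query_type=person` and `person_id` are as documented; the tool name follows the `knowledge_query` spelling used elsewhere in this changelog, and the ID value is a placeholder):

```json
{
  "tool": "knowledge_query",
  "arguments": {
    "query_type": "person",
    "person_id": "p_0001"
  }
}
```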
### 🛠️ Version and schema sync

- `plugin.py`: `plugin_version` updated to `0.5.0`.
- `plugin.py`: default `plugin.config_version` updated to `4.0.0`.
- `config.toml`: `config_version` baseline synced to `4.0.0` (local config file).
- `__init__.py`: `__version__` updated to `0.5.0`.
- `_manifest.json`: `version` updated to `0.5.0`; `manifest_version` stays `1`.
- `manifest_utils.py`: the repo already tolerates newer manifest versions, but releases keep `manifest_version=1` by default.

### 📚 Docs

- Updated `README.md`, `CONFIG_REFERENCE.md`, `QUICK_START.md`, `USAGE_ARCHITECTURE.md`.
- The 0.5.0 documentation narrative is now "person profiles + version upgrade + supplementary retrieval-path notes".

## [0.4.0] - 2026-02-13

`0.4.0` bundles temporal-retrieval enhancements with follow-on retrieval-path improvements, stability fixes, and doc sync.

### 🔖 Version info

- Plugin version: `0.3.3` → `0.4.0`
- Config version: `3.0.0` → `3.1.0`

### 🚀 Added

- New `core/retrieval/sparse_bm25.py`
  - `SparseBM25Config` / `SparseBM25Index`
  - FTS5 + BM25 sparse retrieval
  - `jieba/mixed/char_2gram` tokenization with lazy loading
  - ngram inverted-index fallback, with optional LIKE as a last resort
- `DualPathRetriever` gains sparse/fusion config injection:
  - automatic sparse fallback when embedding is unavailable;
  - `hybrid` mode runs the vector path and the sparse path in parallel for candidates;
  - new `FusionConfig` with `weighted_rrf` fusion.
- `MetadataStore` gains FTS/inverted-index capabilities:
  - `paragraphs_fts`, `relations_fts` schema and backfill;
  - `paragraph_ngrams` inverted index and backfill;
  - `fts_search_bm25` / `fts_search_relations_bm25` / `ngram_search_paragraphs`.

### 🛠️ Component wiring

- `plugin.py`
  - New `[retrieval.sparse]` and `[retrieval.fusion]` default config;
  - Initializes and injects `sparse_index` into components;
  - `on_disable` can detach the sparse connection and release caches per config.
- `knowledge_search_action.py` / `query_command.py` / `knowledge_query_tool.py`
  - Uniformly consume the sparse/fusion config;
  - Uniformly receive the injected `sparse_index`;
  - `stats` output gains sparse status reporting.
- `requirements.txt`
  - Adds `jieba>=0.42.1` (falls back to char n-grams automatically when not installed).

### 🧯 Fixes and behavior changes

- Fixed `retrieval.ppr_concurrency_limit` having no effect:
  - `DualPathRetriever` now initializes `_ppr_semaphore` from the configured value instead of a hard-coded one.
- Fixed `char_2gram` recall gaps:
  - On an FTS miss, `_fallback_substring_search` kicks in: ngram inverted-index first, then optional LIKE per config.
- Observability and compatibility:
  - `get_statistics()` reads the vector-count field tolerantly (`size -> num_vectors -> 0`), avoiding exceptions on missing attributes.
  - `/query stats` and `knowledge_query` output include sparse status (enabled/loaded/tokenizer/doc_count).

### 📚 Docs

- `README.md`
  - Adds notes on the retrieval enhancements, sparse behavior, and the temporal-backfill script entry point.
- `CONFIG_REFERENCE.md`
  - Fills in the sparse/fusion parameters, trigger rules, fallback chain, and fusion implementation details.

### ⏱️ Temporal retrieval and import

#### Temporal retrieval (minute granularity)

- New unified temporal query entry points (see the sketch after this list):
  - `/query time` (alias `/query t`)
  - `knowledge_query(query_type=time)`
  - `knowledge_search(query_type=time|hybrid)`
- Query time parameters accept:
  - `YYYY/MM/DD`
  - `YYYY/MM/DD HH:mm`
- Date-only parameters expand to day boundaries:
  - `from/time_from` -> `00:00`
  - `to/time_to` -> `23:59`
- Results carry `metadata.time_meta`, including the matched time window and the basis of the match (event time, or `created_at` fallback).
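A time-window query using the formats above might look like this sketch (the `time_from`/`time_to` spellings follow the `from/time_from` aliases listed above; per the boundary-expansion rule, the date-only `time_from` expands to `00:00`):

```json
{
  "tool": "knowledge_query",
  "arguments": {
    "query_type": "time",
    "time_from": "2026/02/12",
    "time_to": "2026/02/12 23:59"
  }
}
```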
#### Storage and retrieval path

- The paragraph storage layer supports temporal fields:
  - `event_time`
  - `event_time_start`
  - `event_time_end`
  - `time_granularity`
  - `time_confidence`
- Temporal matching uses interval intersection and follows a two-layer time semantics:
  - `event_time/event_time_range` first;
  - fall back to `created_at` when missing (can be disabled in config).
- Ranking remains semantics first, time second (newest to oldest).
- `process_knowledge.py` gains a `--chat-log` flag:
  - forces the `narrative` strategy;
  - uses an LLM for semantic time extraction over chat text (including converting relative times to absolute), writing `event_time/event_time_start/event_time_end`.
  - New `--chat-reference-time` sets the reference point for resolving relative times.

#### Schema and docs

- `_manifest.json` adds the `retrieval.temporal` config schema.
- Config schema bumped: `config_version` from `3.0.0` to `3.1.0` (synced across `plugin.py` / `config.toml` / the config docs).
- Updated `README.md`, `CONFIG_REFERENCE.md`, `IMPORT_GUIDE.md` with the temporal query entry points, parameter formats, and import-time fields.

## [0.3.3] - 2026-02-11

A **language-consistency patch** that reins in language drift during knowledge extraction: output must stick to the source language, with no translation or rewriting.

### 🛠️ Key fixes

#### Extraction language constraint

- `BaseStrategy`:
  - Removed the `zh/en/mixed` language-type branching;
  - Replaced with a single constraint: extracted values keep the source language, preserve original terminology, and must not be translated.
- `NarrativeStrategy` / `FactualStrategy`:
  - Extraction prompts adopt the constraint above;
  - JSON key names stay fixed, while extracted values follow the source language.

#### Import-path consistency

- `ImportCommand`'s LLM extraction prompt is likewise strengthened with "prefer the source language, do not translate", keeping script and command imports consistent.

#### Tests and docs

- Updated `test_strategies.py`: the language-detection tests became unified language-constraint tests, verifying that the prompts contain the no-translation constraint.
- Comments and docs updated to match the implementation.

### 🔖 Version info

- Plugin version: `0.3.2` → `0.3.3`

## [0.3.2] - 2026-02-11

A **V5 stability and compatibility release**: while keeping the original design (reinforce → decay → freeze → prune → recycle), it fixes broken links and misjudgments in key paths.

### 🛠️ Key fixes

#### V5 memory-system contracts and paths

- `MetadataStore`:
  - Unified the `mark_relations_inactive(hashes, inactive_since=None)` contract across callers;
  - Added `has_table(table_name)`;
  - Added a `restore_relation(hash)` compatibility alias, fixing the broken restore call from the service layer;
  - Fixed `get_entity_gc_candidates` handling of orphan-node arguments (node names now map to entity hashes).
- `GraphStore`:
  - Removed the duplicate `deactivate_edges` definition and unified its return value (number frozen), stabilizing upstream logging and assertions.
- `server.py`:
  - Fixed the relation restore path in `/api/memory/restore`;
  - Removed unreachable branches and unified the exception paths;
  - Recycle-bin queries no longer erroneously return empty when table detection is involved.

#### Commands and model selection

- `/memory` hash-length check fixed: 64-character `sha256` is the standard, with legacy 32-character input still accepted.
- Summarization model selection fixed:
  - `summarization.model_name = auto` no longer mistakenly picks `embedding`;
  - Arrays and selector syntax supported (`task:model` / task / model);
  - Comma-separated strings accepted (e.g. `"utils:model1","utils:model2",replyer`).

#### Lifecycle and script stability

- `plugin.py` background-task lifecycle fixes:
  - Added `_scheduled_import_task` / `_auto_save_task` / `_memory_maintenance_task` handles;
  - No more duplicate starts;
  - On disable, tasks are uniformly cancelled and awaited.
- `process_knowledge.py`: fixed the tenacity retry log-level type (`"WARNING"` → `logging.WARNING`), avoiding `KeyError: 'WARNING'`.

### 🔖 Version info

- Plugin version: `0.3.1` → `0.3.2`

## [0.3.1] - 2026-02-07

A **stability patch** fixing the script import path, deletion safety, and LPMM conversion consistency.

### 🛠️ Key fixes

#### Added

- New `scripts/convert_lpmm.py`:
  - Converts LPMM `parquet + graph` data directly into the A_Memorix storage layout;
  - Maps LPMM IDs to A_Memorix IDs for graph node/edge rewriting;
  - The current implementation prioritizes retrieval consistency; relation vectors use a safe strategy (not imported directly).

#### Import path

- Fixed the missing/unstable `AutoImporter.import_json_data` public entry point that `import_lpmm_json.py` depends on, so external scripts can reliably invoke direct JSON import.

#### Deletion safety

- Fixed a mis-deletion risk when deleting by source while the same `(subject, object)` pair carries multiple relations:
  - `MetadataStore.delete_paragraph_atomic` gains `relation_prune_ops`;
  - The whole edge is deleted only when no sibling relations remain.
- `delete_knowledge.py` adds conservative orphan-entity cleanup (only for entities touched by the current run, and only when they have no paragraph references, no relation references, and no graph neighbors).
- `delete_knowledge.py` now reads the real dimension from vector metadata, avoiding `dimension=1` write-back pollution.

#### LPMM conversion fixes

- Fixed retrieval lookups failing in `convert_lpmm.py` because vector IDs disagreed with `MetadataStore` hashes.
- To avoid dirty recall, the conversion stage temporarily skips direct vector import from `relation.parquet` (to be restored once one-to-one relation-metadata mapping is in place).

### 🔖 Version info

- Plugin version: `0.3.0` → `0.3.1`

## [0.3.0] - 2026-01-30

This release introduces the **V5 dynamic memory system**: biologically inspired memory decay, reinforcement, and full lifecycle management, along with matching commands and tools.

### 🧠 Memory system (V5)

#### Core mechanics

- **Decay**: a "forgetting curve" automatically lowers graph edge weights over time (see the sketch after this list).
- **Reinforcement**: memories grow stronger with use; every retrieval hit refreshes a memory's activity and strengthens its weight.
- **Lifecycle**:
  - **Active**: participates in computation and retrieval normally.
  - **Inactive (frozen)**: weight fell too low; excluded from PPR computation, but the semantic mapping is retained.
  - **Prune**: expired, unprotected frozen memories move to the recycle bin.
- **Protection**: both **permanent pinning (Pin)** and **time-limited protection (TTL)** guard key memories from accidental deletion.
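The decay mechanic above is the standard half-life form. As an illustrative sketch only (the plugin's exact update rule is not spelled out here, so treat the formula as an assumption):

```python
def decayed_weight(weight: float, hours_since_access: float,
                   half_life_hours: float = 24.0) -> float:
    """Exponential forgetting curve: the weight halves every half_life_hours."""
    return weight * 0.5 ** (hours_since_access / half_life_hours)

# An edge at weight 1.0, untouched for 72h with the default 24h half-life,
# decays to 0.125 -- just above the default prune_threshold of 0.1.
print(decayed_weight(1.0, 72.0))  # 0.125
```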
#### GraphStore

- **Multi-relation mapping**: implements a `(u,v) -> Set[Hash]` mapping so multiple semantic relations on the same channel do not interfere.
- **Atomic operations**: new `decay`, `deactivate_edges` (soft delete), and `prune_relation_hashes` (hard delete) primitives.

### 🛠️ Commands and tools

#### Memory Command (`/memory`)

A full set of maintenance commands:

- `/memory status`: memory-system health (active/frozen/recycle-bin counts).
- `/memory protect [hours]`: protect a memory. Without a duration it pins permanently; with one it protects temporarily (TTL).
- `/memory reinforce `: reinforce a memory manually (bypassing the cooldown).
- `/memory restore `: restore a mistakenly deleted memory from the recycle bin (the connection is rebuilt only if the nodes still exist).

#### MemoryModifierTool

- **LLM integration**: the tool now lets the LLM trigger `reinforce`, `weaken`, `remember_forever`, and `forget` on its own, mapping them onto the V5 primitives.

### ⚙️ Config (`config.toml`)

New `[memory]` section:

- `half_life_hours`: memory half-life (default 24h).
- `enable_auto_reinforce`: whether retrieval auto-reinforces.
- `prune_threshold`: freeze/prune threshold (default 0.1).

### 💻 WebUI (v1.4)

A full lifecycle-management UI deeply integrated with V5:

- **Visualization**:
  - **Frozen state**: inactive memories render as **dashed, gray (Slate-300)** edges.
  - **Protected state**: pinned or protected memories get an **amber halo**.
- **Interaction**:
  - **Recycle bin**: new Dock entry and panel for browsing deletions and one-click restore.
  - **Quick actions**: the edge panel gains **Reinforce**, **Protect/Pin**, and **Freeze** buttons.
  - **Live feedback**: graph layout and styling refresh after each operation.

---

## [0.2.3] - 2026-01-30

This release focuses on **WebUI interaction polish** and **doc/config normalization**.

### 🎨 WebUI (v1.3)

#### Loading and sync

- **Immersive loading**: redesigned loading overlay with a frosted-glass backdrop (`backdrop-filter`) and breathing text animation.
- **Accurate status**: loading logic now distinguishes the "network sync" and "topology computation" phases, fixing flicker during data loads.
- **Onboarding**: the loading screen shows basic interaction hints, lowering the entry barrier for new users.

#### Help panel

- **Guide rework**: the "user guide" panel is fully rebuilt, adding Dock feature details, edit/management operations, and view configuration notes.

### 🛠️ Engineering

#### plugin.py

- **Config descriptions**: fixed `config_section_descriptions` missing the `summarization`, `schedule`, and `filter` sections.
- **Version**: `0.2.2` → `0.2.3`

### ⚙️ Core and server

#### Core

- **Quantization fix**: corrected `_scalar_quantize_int8` so vector values map into `[-128, 127]`, improving quantization accuracy.

#### Server

- **Cache consistency**: node/edge deletions and other mutations now explicitly clear `_relation_cache`, so the frontend sees fresh relation data.

### 🤖 Scripts and data processing

#### process_knowledge.py

- **Strategy-aware refactor**: text is processed differently via three strategies, `Narrative`, `Factual`, and `Quote` (strictly speaking, this release confirms the implementation landed; `Narrative` is the default).
- **Chunk rescue**: new mechanism that detects and extracts embedded lyrics or verse inside long narrative text.

#### import_lpmm_json.py

- **LPMM migration**: full support for the LPMM OpenIE JSON format, computing hashes automatically and migrating entities/relations into the A_Memorix storage format.

#### Project

- **Build hygiene**: tightened the `.gitignore` rules.

---

## [0.2.2] - 2026-01-27

This release hardens **network-request robustness**, especially calls to the embedding service.

### 🛠️ Stability and engineering

#### EmbeddingAPI

- **Configurable retries**: new `[embedding.retry]` config for custom retry counts and wait times. The default retry count rises from 3 to 10 to better ride out network flakiness.
- **Keys** (shown in config form below):
  - `max_attempts`: maximum retry attempts (default: 10)
  - `max_wait_seconds`: maximum wait (default: 30s)
  - `min_wait_seconds`: minimum wait (default: 2s)
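In `config.toml` form, the defaults above are:

```toml
[embedding.retry]
max_attempts = 10      # retry up to 10 times (previously 3)
min_wait_seconds = 2   # backoff starts here
max_wait_seconds = 30  # and is capped here
```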
#### plugin.py

- **Version**: `0.2.1` → `0.2.2`

---

## [0.2.1] - 2026-01-26

This release focuses on a **full rework of the visualization and interaction layer** plus **further robustness hardening**.

### 🎨 Visualization and interaction

#### WebUI (Glassmorphism)

- **New visual design**: dark frosted-glass (glassmorphism) styling over an animated gradient background.
- **Dock menu**: a macOS-style Dock at the bottom gathers the common features.
- **Saliency view**: a **PageRank**-based "information density" slider filters leaf nodes, showing anything from the core skeleton to full detail.
- **Panels**:
  - **❓ Guide**: built-in interaction notes and feature intro.
  - **🔍 Floating search**: live node filtering by pinyin/ID.
  - **📂 Memory provenance**: browse and batch-delete memory data by source file.
  - **📖 Content dictionary**: list view of all entities and relations with sorting and filtering.

### 🛠️ Stability and engineering

#### EmbeddingAPI

- **Robustness**: `tenacity`-based exponential-backoff retries.
- **Error handling**: failures return `NaN` vectors instead of zero vectors, letting upper layers skip them safely.

#### MetadataStore

- **Auto-repair**: detects and repairs historical `vector_index` column misalignment (filenames stored by mistake).
- **Statistics**: new `get_all_sources` API for per-source statistics.

#### Scripts and tooling

- **UX**: `rich`-based progress bars and status output in the terminal.
- **Open interfaces**: `process_knowledge.py` gains `import_json_data` for external callers.
- **LPMM migration**: new `import_lpmm_json.py` for importing LPMM-conformant OpenIE JSON.

#### plugin.py

- **Version**: `0.2.0` → `0.2.1`

---

## [0.2.0] - 2026-01-22

> [!CAUTION]
> **Breaking change**: v0.2.0 rewrites the underlying storage architecture. Because of the data-structure overhaul, **data imported with older versions is not fully compatible with this version**.
> Some components migrate automatically, but to guarantee data consistency and retrieval quality, **re-importing the raw data with `process_knowledge.py` after upgrading is strongly recommended**.

A **major release**: vector-storage rewrite, stronger retrieval logic, and assorted stability work.

### 🚀 Core architecture rewrite

#### VectorStore: SQ8 quantization + append-only storage

- **New storage format**: migrated from `.npy` to `vectors.bin` (float16, append-only) plus `vectors_ids.bin`, cutting memory usage substantially.
- **Native SQ8 quantization**: uses Faiss `IndexScalarQuantizer(QT_8bit)`, replacing the manual int8 quantization logic.
- **Mandatory L2 normalization**: all vectors are L2-normalized on store and on query, making inner product equivalent to cosine similarity.
- **Fallback index**: a new `IndexFlatIP` fallback serves queries before SQ8 training completes, avoiding empty results on cold start.
- **Reservoir-sampling training**: training vectors are collected via reservoir sampling (cap 10k), keeping samples diverse on small datasets and streaming imports.
- **Thread safety**: a `threading.RLock` guards concurrent reads and writes.
- **Auto-migration**: legacy `.npy` data migrates to the new `.bin` format automatically.

### ✨ Retrieval enhancements

#### KnowledgeQueryTool: smart fallback and multi-hop path search

- **Smart fallback**: when vector-retrieval confidence falls below a threshold (default 0.6), entities are extracted from the query and a multi-hop path search (`_path_search`) runs, improving recall of indirect relations.
- **Deduplication (`_deduplicate_results`)**: content-similarity-based safe dedup keeps redundant results out of the LLM context while guaranteeing at least one result survives.
- **Semantic relation search (`_semantic_search_relation`)**: natural-language relation queries (no `S|P|O` format needed) via the `REL_ONLY` vector strategy.
- **Path search (`_path_search`)**: calls the new `GraphStore.find_paths` to find indirect paths between two entities (max depth 3, up to 5 paths).
- **Clean output**: raw similarity scores are no longer passed into the LLM context, avoiding model bias.

#### DualPathRetriever: concurrency control and debug mode

- **PPR concurrency limit (`ppr_concurrency_limit`)**: a semaphore caps concurrent PageRank computations, preventing CPU spikes.
- **Debug mode**: a new `debug` flag logs the full text of retrieval results.
- **Entity-pivot relation retrieval**: `_retrieve_relations_only` now retrieves entities first and expands to their relations, replacing direct relation-vector retrieval.

### ⚙️ Config and schema

#### plugin.py

- **Version**: `0.1.3` → `0.2.0`
- **Default config version**: `config_version` default updated to `2.0.0`
- **New keys**:
  - `retrieval.relation_semantic_fallback` (bool): enable semantic fallback for relation queries.
  - `retrieval.relation_fallback_min_score` (float): minimum similarity for the fallback.
- **Relative paths**: `storage.data_dir` now accepts paths relative to the plugin directory; the default becomes `./data`.
- **Global instance**: new static `A_MemorixPlugin.get_global_instance()` gives components reliable access to the plugin instance.

#### config.toml / \_manifest.json

- **New `ppr_concurrency_limit`**: caps PPR concurrency.
- **New training threshold**: `embedding.min_train_threshold` sets the minimum sample count that triggers SQ8 training.

### 🛠️ Stability and engineering

#### GraphStore

- **`find_paths`**: multi-hop path finding via BFS between entities within a given depth.
- **`find_node`**: case-insensitive node lookup.

#### MetadataStore

- **Schema migration**: missing `is_permanent`, `last_accessed`, `access_count` columns are added automatically.

#### Scripts and tooling

- **New scripts**:
  - `scripts/diagnose_relations_source.py`: diagnoses relation provenance.
  - `scripts/verify_search_robustness.py`: verifies retrieval robustness.
  - `scripts/run_stress_test.py`, `stress_test_data.py`: stress-test suite.
  - `scripts/migrate_canonicalization.py`, `migrate_paragraph_relations.py`: data-migration tools.
- **Cleanup**: legacy test scripts moved to `deprecated/`.

### 🗑️ Removed and deprecated

- Deprecated the `vectors.npy` storage format (auto-migrates to `.bin`).

---

## [0.1.3] - previous stable release

- Initial release with basic dual-path retrieval.
- Manual int8 vector quantization.
- `.npy`-based vector storage.
diff --git a/plugins/A_memorix/CONFIG_REFERENCE.md b/plugins/A_memorix/CONFIG_REFERENCE.md
deleted file mode 100644
index ada8aec5..00000000
--- a/plugins/A_memorix/CONFIG_REFERENCE.md
+++ /dev/null
@@ -1,292 +0,0 @@
# A_Memorix configuration reference (v2.0.0)

This document matches the current repository code (`__version__ = 2.0.0`, `SCHEMA_VERSION = 8`).

Notes:

- Only keys that the **current runtime actually reads** are covered.
- Config tied to the legacy `/query`, `/memory`, `/visualize` command system is no longer documented as a primary path.
- Unset keys fall back to the code defaults.

## Minimal working config

```toml
[plugin]
enabled = true

[storage]
data_dir = "./data"

[embedding]
model_name = "auto"
dimension = 1024
batch_size = 32
max_concurrent = 5
enable_cache = false
quantization_type = "int8"

[retrieval]
top_k_paragraphs = 20
top_k_relations = 10
top_k_final = 10
alpha = 0.5
enable_ppr = true
ppr_alpha = 0.85
ppr_timeout_seconds = 1.5
ppr_concurrency_limit = 4
enable_parallel = true

[retrieval.sparse]
enabled = true

[episode]
enabled = true
generation_enabled = true
pending_batch_size = 20
pending_max_retry = 3

[person_profile]
enabled = true

[memory]
enabled = true
half_life_hours = 24.0
prune_threshold = 0.1

[advanced]
enable_auto_save = true
auto_save_interval_minutes = 5

[web.import]
enabled = true

[web.tuning]
enabled = true
```
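Before tuning the individual sections below, it is worth validating a config against the live runtime with the self-check script (see also section 11):

```bash
python plugins/A_memorix/scripts/runtime_self_check.py --json
```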
## 1. Storage and embedding

### `storage`

- `storage.data_dir` (default `./data`)
: Data directory. Relative paths resolve against the plugin directory.

### `embedding`

- `embedding.model_name` (default `auto`)
: Embedding model selection.
- `embedding.dimension` (default `1024`)
: Expected dimension (probed against the real model and validated at runtime).
- `embedding.batch_size` (default `32`)
- `embedding.max_concurrent` (default `5`)
- `embedding.enable_cache` (default `false`)
- `embedding.retry` (default `{}`)
: Retry policy for embedding calls.
- `embedding.quantization_type`
: Only `int8` is recommended on the current main path.

## 2. Retrieval

### Top-level `retrieval` keys

- `retrieval.top_k_paragraphs` (default `20`)
- `retrieval.top_k_relations` (default `10`)
- `retrieval.top_k_final` (default `10`)
- `retrieval.alpha` (default `0.5`)
- `retrieval.enable_ppr` (default `true`)
- `retrieval.ppr_alpha` (default `0.85`)
- `retrieval.ppr_timeout_seconds` (default `1.5`)
- `retrieval.ppr_concurrency_limit` (default `4`)
- `retrieval.enable_parallel` (default `true`)
- `retrieval.relation_vectorization.enabled` (default `false`)

### `retrieval.sparse` (`SparseBM25Config`)

Common keys (with defaults):

- `enabled = true`
- `backend = "fts5"`
- `lazy_load = true`
- `mode = "auto"` (`auto`/`fallback_only`/`hybrid`)
- `tokenizer_mode = "jieba"` (`jieba`/`mixed`/`char_2gram`)
- `char_ngram_n = 2`
- `candidate_k = 80`
- `relation_candidate_k = 60`
- `enable_ngram_fallback_index = true`
- `enable_relation_sparse_fallback = true`

### `retrieval.fusion` (`FusionConfig`)

- `method` (default `weighted_rrf`)
- `rrf_k` (default `60`)
- `vector_weight` (default `0.7`)
- `bm25_weight` (default `0.3`)
- `normalize_score` (default `true`)
- `normalize_method` (default `minmax`)

### `retrieval.search.relation_intent` (`RelationIntentConfig`)

- `enabled` (default `true`)
- `alpha_override` (default `0.35`)
- `relation_candidate_multiplier` (default `4`)
- `preserve_top_relations` (default `3`)
- `force_relation_sparse` (default `true`)
- `pair_predicate_rerank_enabled` (default `true`)
- `pair_predicate_limit` (default `3`)

### `retrieval.search.graph_recall` (`GraphRelationRecallConfig`)

- `enabled` (default `true`)
- `candidate_k` (default `24`)
- `max_hop` (default `1`)
- `allow_two_hop_pair` (default `true`)
- `max_paths` (default `4`)

### `retrieval.aggregate`

- `retrieval.aggregate.rrf_k`
- `retrieval.aggregate.weights`

Mixing strategy for the aggregate retrieval stage; unset keys use the code's default behavior.

## 3. Threshold filtering

### `threshold` (`ThresholdConfig`)

- `threshold.min_threshold` (default `0.3`)
- `threshold.max_threshold` (default `0.95`)
- `threshold.percentile` (default `75.0`)
- `threshold.std_multiplier` (default `1.5`)
- `threshold.min_results` (default `3`)
- `threshold.enable_auto_adjust` (default `true`)

## 4. Chat filtering

### `filter`

Applies when `respect_filter=true` (supported on both retrieval and writes).

```toml
[filter]
enabled = true
mode = "blacklist" # blacklist / whitelist
chats = ["group:123", "user:456", "stream:abc"]
```

Rules:

- `blacklist`: anything on the list is rejected
- `whitelist`: only the list is allowed
- With an empty list:
  - `blacklist` => allow everything
  - `whitelist` => reject everything

## 5. Episode

### `episode`

- `episode.enabled` (default `true`)
- `episode.generation_enabled` (default `true`)
- `episode.pending_batch_size` (default `20`; some paths default to `12`)
- `episode.pending_max_retry` (default `3`)
- `episode.max_paragraphs_per_call` (default `20`)
- `episode.max_chars_per_call` (default `6000`)
- `episode.source_time_window_hours` (default `24`)
- `episode.segmentation_model` (default `auto`)

## 6. Person profiles

### `person_profile`

- `person_profile.enabled` (default `true`)
- `person_profile.refresh_interval_minutes` (default `30`)
- `person_profile.active_window_hours` (default `72`)
- `person_profile.max_refresh_per_cycle` (default `50`)
- `person_profile.top_k_evidence` (default `12`)

## 7. Memory evolution and recycling

### `memory`

- `memory.enabled` (default `true`)
- `memory.half_life_hours` (default `24.0`)
- `memory.base_decay_interval_hours` (default `1.0`)
- `memory.prune_threshold` (default `0.1`)
- `memory.freeze_duration_hours` (default `24.0`)

### `memory.orphan`

- `enable_soft_delete` (default `true`)
- `entity_retention_days` (default `7.0`)
- `paragraph_retention_days` (default `7.0`)
- `sweep_grace_hours` (default `24.0`)
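Spelled out as a `config.toml` fragment, the section-7 defaults above are:

```toml
[memory]
enabled = true
half_life_hours = 24.0
base_decay_interval_hours = 1.0
prune_threshold = 0.1
freeze_duration_hours = 24.0

[memory.orphan]
enable_soft_delete = true
entity_retention_days = 7.0
paragraph_retention_days = 7.0
sweep_grace_hours = 24.0
```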
## 8. Advanced runtime

### `advanced`

- `advanced.enable_auto_save` (default `true`)
- `advanced.auto_save_interval_minutes` (default `5`)
- `advanced.debug` (default `false`)
- `advanced.extraction_model` (default `auto`)

## 9. Import center (`web.import`)

### Switches and rate limits

- `web.import.enabled` (default `true`)
- `web.import.max_queue_size` (default `20`)
- `web.import.max_files_per_task` (default `200`)
- `web.import.max_file_size_mb` (default `20`)
- `web.import.max_paste_chars` (default `200000`)
- `web.import.default_file_concurrency` (default `2`)
- `web.import.default_chunk_concurrency` (default `4`)
- `web.import.max_file_concurrency` (default `6`)
- `web.import.max_chunk_concurrency` (default `12`)
- `web.import.poll_interval_ms` (default `1000`)

### Retries and paths

- `web.import.llm_retry.max_attempts` (default `4`)
- `web.import.llm_retry.min_wait_seconds` (default `3`)
- `web.import.llm_retry.max_wait_seconds` (default `40`)
- `web.import.llm_retry.backoff_multiplier` (default `3`)
- `web.import.path_aliases` (built-in defaults: `raw/lpmm/plugin_data`)

These retry keys imply a capped exponential backoff; see the sketch at the end of this section.

### Conversion stage

- `web.import.convert.enable_staging_switch` (default `true`)
- `web.import.convert.keep_backup_count` (default `3`)
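A sketch of the wait schedule implied by the `llm_retry` keys (an assumption; the actual implementation may add jitter or differ in detail):

```python
def llm_retry_waits(max_attempts: int = 4, min_wait: float = 3.0,
                    max_wait: float = 40.0, multiplier: float = 3.0) -> list[float]:
    """Seconds to wait before each retry: min_wait * multiplier**n, capped at max_wait."""
    return [min(min_wait * multiplier**n, max_wait) for n in range(max_attempts - 1)]

# With the web.import defaults: waits of 3s, 9s, and 27s between the 4 attempts.
print(llm_retry_waits())  # [3.0, 9.0, 27.0]
```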
## 10. Tuning center (`web.tuning`)

- `web.tuning.enabled` (default `true`)
- `web.tuning.max_queue_size` (default `8`)
- `web.tuning.poll_interval_ms` (default `1200`)
- `web.tuning.eval_query_timeout_seconds` (default `10.0`)
- `web.tuning.default_intensity` (default `standard`)
- `web.tuning.default_objective` (default `precision_priority`)
- `web.tuning.default_top_k_eval` (default `20`)
- `web.tuning.default_sample_size` (default `24`)
- `web.tuning.llm_retry.max_attempts` (default `3`)
- `web.tuning.llm_retry.min_wait_seconds` (default `2`)
- `web.tuning.llm_retry.max_wait_seconds` (default `20`)
- `web.tuning.llm_retry.backoff_multiplier` (default `2`)

## 11. Compatibility notes

- When upgrading from `1.x`, run the migration first:

```bash
python plugins/A_memorix/scripts/release_vnext_migrate.py preflight --strict
python plugins/A_memorix/scripts/release_vnext_migrate.py migrate --verify-after
python plugins/A_memorix/scripts/release_vnext_migrate.py verify --strict
```

- Then, before starting:

```bash
python plugins/A_memorix/scripts/runtime_self_check.py --json
```

This avoids runtime failures caused by an embedding dimension that does not match the vector store.
diff --git a/plugins/A_memorix/IMPORT_GUIDE.md b/plugins/A_memorix/IMPORT_GUIDE.md
deleted file mode 100644
index 618690e0..00000000
--- a/plugins/A_memorix/IMPORT_GUIDE.md
+++ /dev/null
@@ -1,335 +0,0 @@
# A_Memorix import guide (v2.0.0)

This document matches the current `2.0.0` code paths and covers two ways to import:

1. Script import (offline batch)
2. `memory_import_admin` task import (online, task-based)

## 1. Pre-import checks

Run the self-check first:

```bash
python plugins/A_memorix/scripts/runtime_self_check.py --json
```

Then confirm:

- the `storage.data_dir` path is writable
- the embedding config works
- for upgraded projects, the migration scripts have been run

## 2. Option A: script import (recommended starting point)

### 2.1 Raw text import

Put `.txt` files into:

```text
plugins/A_memorix/data/raw/
```

Run:

```bash
python plugins/A_memorix/scripts/process_knowledge.py
```

Common flags:

```bash
python plugins/A_memorix/scripts/process_knowledge.py --force
python plugins/A_memorix/scripts/process_knowledge.py --chat-log
python plugins/A_memorix/scripts/process_knowledge.py --chat-log --chat-reference-time "2026/02/12 10:30"
```

### 2.2 OpenIE JSON import

```bash
python plugins/A_memorix/scripts/import_lpmm_json.py
```

### 2.3 LPMM data conversion

```bash
python plugins/A_memorix/scripts/convert_lpmm.py -i -o plugins/A_memorix/data
```

### 2.4 Historical data migration

```bash
python plugins/A_memorix/scripts/migrate_chat_history.py --help
python plugins/A_memorix/scripts/migrate_maibot_memory.py --help
python plugins/A_memorix/scripts/migrate_person_memory_points.py --help
```

### 2.5 Post-import repair and rebuild

```bash
python plugins/A_memorix/scripts/backfill_temporal_metadata.py --dry-run
python plugins/A_memorix/scripts/backfill_relation_vectors.py --limit 1000
python plugins/A_memorix/scripts/rebuild_episodes.py --all --wait
python plugins/A_memorix/scripts/audit_vector_consistency.py --json
```

## 3. Option B: `memory_import_admin` task import

`memory_import_admin` is the online, task-based import entry point, suited to host-side panels or automation pipelines.

### 3.1 Common actions

- `settings` / `get_settings` / `get_guide`
- `path_aliases` / `get_path_aliases`
- `resolve_path`
- `create_upload`
- `create_paste`
- `create_raw_scan`
- `create_lpmm_openie`
- `create_lpmm_convert`
- `create_temporal_backfill`
- `create_maibot_migration`
- `list`
- `get`
- `chunks` / `get_chunks`
- `cancel`
- `retry_failed`

### 3.2 Call examples

Inspect runtime settings:

```json
{
  "tool": "memory_import_admin",
  "arguments": {
    "action": "settings"
  }
}
```

Create a paste-import task:

```json
{
  "tool": "memory_import_admin",
  "arguments": {
    "action": "create_paste",
    "content": "Finished the retrieval-tuning regression today.",
    "input_mode": "plain_text",
    "source": "manual:worklog"
  }
}
```

List tasks:

```json
{
  "tool": "memory_import_admin",
  "arguments": {
    "action": "list",
    "limit": 20
  }
}
```

Get task details:

```json
{
  "tool": "memory_import_admin",
  "arguments": {
    "action": "get",
    "task_id": "",
    "include_chunks": true
  }
}
```

Retry a failed task:

```json
{
  "tool": "memory_import_admin",
  "arguments": {
    "action": "retry_failed",
    "task_id": ""
  }
}
```

## 4. Direct write tools (non-task-based)

If you do not need task orchestration, you can call these directly:

- `ingest_summary`
- `ingest_text`

Example:

```json
{
  "tool": "ingest_text",
  "arguments": {
    "external_id": "note:2026-03-18:001",
    "source_type": "note",
    "text": "The new recall-threshold proposal passed review",
    "chat_id": "group:dev",
    "tags": ["worklog", "review"]
  }
}
```

`external_id` should be globally unique; it is used for idempotent deduplication.

## 5. Time fields

Available time fields (in rough priority order):

- `timestamp`
- `time_start`
- `time_end`

Recommendations:

- For event-like records, prefer `time_start/time_end`
- With only a single point in time, write `timestamp`
- Historical data can be imported first, then backfilled with `backfill_temporal_metadata.py`

## 6. source_type conventions

Common values:

- `chat_summary`
- `note`
- `person_fact`
- `lpmm_openie`
- `migration`

Keep this a stable enumeration; it helps with per-source governance and Episode rebuilds later.

## 7. Post-import verification

Suggested order:

1. `memory_stats` to check that totals grew (call sketch below)
2. `search_memory` (`mode=search`/`aggregate`) to spot-check recall
3. `memory_episode_admin` `status`/`query` to check Episode generation
4. `memory_runtime_admin` `self_check` to reconfirm runtime health
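For step 1, the stats call might look like this (assuming it takes no required arguments):

```json
{
  "tool": "memory_stats",
  "arguments": {}
}
```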
## 8. FAQ

### Q1: The import task was created successfully but nothing was written

- Check the chat-filter config `filter` (with `respect_filter=true` the content may have been filtered out)
- Check the failure reason and chunk states in the task details

### Q2: Tasks keep failing

- Check embedding and LLM availability
- Lower the concurrency (`web.import.default_*_concurrency`)
- Adjust the retry parameters (`web.import.llm_retry.*`)

### Q3: Retrieval is poor after import

- Run `runtime_self_check` first
- Check whether `retrieval.sparse` is enabled
- Use `memory_tuning_admin` to create a tuning task and run a parameter regression

## 9. Related docs

- [QUICK_START.md](QUICK_START.md)
- [CONFIG_REFERENCE.md](CONFIG_REFERENCE.md)
- [README.md](README.md)
- [CHANGELOG.md](CHANGELOG.md)

## 10. Appendix: strategy modes

The A_Memorix import path remains strategy-aware. `process_knowledge.py` detects the text type automatically, and the strategy can also be set manually.

| Strategy | Suited to | Core logic | Auto-detection cues |
| :-- | :-- | :-- | :-- |
| `Narrative` | novels, fan fiction, scripts, long stories | split by scene/chapter with a sliding window; extract events and character relations | chapter markers such as `#`, `Chapter`, `***` |
| `Factual` | lore bibles, encyclopedias, manuals | split by semantic block, preserving list/definition structure; extract SPO triples | list bullets, `term: definition` patterns |
| `Quote` | lyrics, poetry, quotations, lines | split on double newlines; the text itself is the knowledge, no summarization | short average line length, many lines |

## 11. Appendix: reference samples (restored)

The samples below can be saved to files for testing, or used as LLM few-shot examples. They are kept in their original Chinese, since they double as tokenizer and extraction test data.

### 11.1 Narrative text (`plugins/A_memorix/data/raw/story_demo.txt`)

```text
# 第一章:星之子

艾瑞克在废墟中醒来,手中的星盘发出微弱的蓝光。他并不记得自己是如何来到这里的,只依稀记得莉莉丝最后的警告:“千万不要回头。”

远处传来了机械守卫的轰鸣声。艾瑞克迅速收起星盘,向着北方的废弃都市奔去。他知道,那里有反抗军唯一的据点。

***

# 第二章:重逢

在反抗军的地下掩体中,艾瑞克见到了那个熟悉的身影。莉莉丝正站在全息地图前,眉头紧锁。

“你还是来了。”莉莉丝没有回头,但声音中带着一丝颤抖。
“我必须来,”艾瑞克握紧了拳头,“为了解开星盘的秘密,也为了你。”
```

### 11.2 Factual text (`plugins/A_memorix/data/raw/rules_demo.txt`)

```text
# 联邦安全协议 v2.0

## 核心法则
1. **第一公理**:任何人工智能不得伤害人类个体,或因不作为而使人类个体受到伤害。
2. **第二公理**:人工智能必须服从人类的命令,除非该命令与第一公理冲突。

## 术语定义
- **以太网络**:覆盖全联邦的高速量子通讯网络。
- **黑色障壁**:用于隔离高危 AI 的物理防火墙设施。
```

### 11.3 Quote text (`plugins/A_memorix/data/raw/poem_demo.txt`)

```text
致橡树

我如果爱你——
绝不像攀援的凌霄花,
借你的高枝炫耀自己;

我如果爱你——
绝不学痴情的鸟儿,
为绿荫重复单调的歌曲;

也不止像泉源,
常年送来清凉的慰籍;
也不止像险峰,
增加你的高度,衬托你的威仪。
```

### 11.4 LPMM JSON (`lpmm_data-openie.json`)

```json
{
  "docs": [
    {
      "passage": "艾瑞克手中的星盘是打开遗迹的唯一钥匙。",
      "extracted_triples": [
        ["星盘", "是", "唯一的钥匙"],
        ["星盘", "属于", "艾瑞克"],
        ["钥匙", "用于", "遗迹"]
      ],
      "extracted_entities": ["星盘", "艾瑞克", "遗迹", "钥匙"]
    },
    {
      "passage": "莉莉丝是反抗军的现任领袖。",
      "extracted_triples": [
        ["莉莉丝", "是", "领袖"],
        ["领袖", "所属", "反抗军"]
      ]
    }
  ]
}
```
diff --git a/plugins/A_memorix/LICENSE b/plugins/A_memorix/LICENSE deleted file mode 100644 index e20b431b..00000000 --- a/plugins/A_memorix/LICENSE +++ /dev/null @@ -1,661 +0,0 @@ -GNU AFFERO GENERAL PUBLIC LICENSE - Version 3, 19 November 2007 - - Copyright (C) 2007 Free Software Foundation, Inc. - Everyone is permitted to copy and distribute verbatim copies - of this license document, but changing it is not allowed. - - Preamble - - The GNU Affero General Public License is a free, copyleft license for -software and other kinds of works, specifically designed to ensure -cooperation with the community in the case of network server software. - - The licenses for most software and other practical works are designed -to take away your freedom to share and change the works. By contrast, -our General Public Licenses are intended to guarantee your freedom to -share and change all versions of a program--to make sure it remains free -software for all its users. - - When we speak of free software, we are referring to freedom, not -price. Our General Public Licenses are designed to make sure that you -have the freedom to distribute copies of free software (and charge for -them if you wish), that you receive source code or can get it if you -want it, that you can change the software or use pieces of it in new -free programs, and that you know you can do these things. 
- - Developers that use our General Public Licenses protect your rights -with two steps: (1) assert copyright on the software, and (2) offer -you this License which gives you legal permission to copy, distribute -and/or modify the software. - - A secondary benefit of defending all users' freedom is that -improvements made in alternate versions of the program, if they -receive widespread use, become available for other developers to -incorporate. Many developers of free software are heartened and -encouraged by the resulting cooperation. However, in the case of -software used on network servers, this result may fail to come about. -The GNU General Public License permits making a modified version and -letting the public access it on a server without ever releasing its -source code to the public. - - The GNU Affero General Public License is designed specifically to -ensure that, in such cases, the modified source code becomes available -to the community. It requires the operator of a network server to -provide the source code of the modified version running there to the -users of that server. Therefore, public use of a modified version, on -a publicly accessible server, gives the public access to the source -code of the modified version. - - An older license, called the Affero General Public License and -published by Affero, was designed to accomplish similar goals. This is -a different license, not a version of the Affero GPL, but Affero has -released a new version of the Affero GPL which permits relicensing under -this license. - - The precise terms and conditions for copying, distribution and -modification follow. - - TERMS AND CONDITIONS - - 0. Definitions. - - "This License" refers to version 3 of the GNU Affero General Public License. - - "Copyright" also means copyright-like laws that apply to other kinds of -works, such as semiconductor masks. - - "The Program" refers to any copyrightable work licensed under this -License. Each licensee is addressed as "you". "Licensees" and -"recipients" may be individuals or organizations. - - To "modify" a work means to copy from or adapt all or part of the work -in a fashion requiring copyright permission, other than the making of an -exact copy. The resulting work is called a "modified version" of the -earlier work or a work "based on" the earlier work. - - A "covered work" means either the unmodified Program or a work based -on the Program. - - To "propagate" a work means to do anything with it that, without -permission, would make you directly or secondarily liable for -infringement under applicable copyright law, except executing it on a -computer or modifying a private copy. Propagation includes copying, -distribution (with or without modification), making available to the -public, and in some countries other activities as well. - - To "convey" a work means any kind of propagation that enables other -parties to make or receive copies. Mere interaction with a user through -a computer network, with no transfer of a copy, is not conveying. - - An interactive user interface displays "Appropriate Legal Notices" -to the extent that it includes a convenient and prominently visible -feature that (1) displays an appropriate copyright notice, and (2) -tells the user that there is no warranty for the work (except to the -extent that warranties are provided), that licensees may convey the -work under this License, and how to view a copy of this License. 
If -the interface presents a list of user commands or options, such as a -menu, a prominent item in the list meets this criterion. - - 1. Source Code. - - The "source code" for a work means the preferred form of the work -for making modifications to it. "Object code" means any non-source -form of a work. - - A "Standard Interface" means an interface that either is an official -standard defined by a recognized standards body, or, in the case of -interfaces specified for a particular programming language, one that -is widely used among developers working in that language. - - The "System Libraries" of an executable work include anything, other -than the work as a whole, that (a) is included in the normal form of -packaging a Major Component, but which is not part of that Major -Component, and (b) serves only to enable use of the work with that -Major Component, or to implement a Standard Interface for which an -implementation is available to the public in source code form. A -"Major Component", in this context, means a major essential component -(kernel, window system, and so on) of the specific operating system -(if any) on which the executable work runs, or a compiler used to -produce the work, or an object code interpreter used to run it. - - The "Corresponding Source" for a work in object code form means all -the source code needed to generate, install, and (for an executable -work) run the object code and to modify the work, including scripts to -control those activities. However, it does not include the work's -System Libraries, or general-purpose tools or generally available free -programs which are used unmodified in performing those activities but -which are not part of the work. For example, Corresponding Source -includes interface definition files associated with source files for -the work, and the source code for shared libraries and dynamically -linked subprograms that the work is specifically designed to require, -such as by intimate data communication or control flow between those -subprograms and other parts of the work. - - The Corresponding Source need not include anything that users -can regenerate automatically from other parts of the Corresponding -Source. - - The Corresponding Source for a work in source code form is that -same work. - - 2. Basic Permissions. - - All rights granted under this License are granted for the term of -copyright on the Program, and are irrevocable provided the stated -conditions are met. This License explicitly affirms your unlimited -permission to run the unmodified Program. The output from running a -covered work is covered by this License only if the output, given its -content, constitutes a covered work. This License acknowledges your -rights of fair use or other equivalent, as provided by copyright law. - - You may make, run and propagate covered works that you do not -convey, without conditions so long as your license otherwise remains -in force. You may convey covered works to others for the sole purpose -of having them make modifications exclusively for you, or provide you -with facilities for running those works, provided that you comply with -the terms of this License in conveying all material for which you do -not control copyright. Those thus making or running the covered works -for you must do so exclusively on your behalf, under your direction -and control, on terms that prohibit them from making any copies of -your copyrighted material outside their relationship with you. 
- - Conveying under any other circumstances is permitted solely under -the conditions stated below. Sublicensing is not allowed; section 10 -makes it unnecessary. - - 3. Protecting Users' Legal Rights From Anti-Circumvention Law. - - No covered work shall be deemed part of an effective technological -measure under any applicable law fulfilling obligations under article -11 of the WIPO copyright treaty adopted on 20 December 1996, or -similar laws prohibiting or restricting circumvention of such -measures. - - When you convey a covered work, you waive any legal power to forbid -circumvention of technological measures to the extent such circumvention -is effected by exercising rights under this License with respect to -the covered work, and you disclaim any intention to limit operation or -modification of the work as a means of enforcing, against the work's -users, your or third parties' legal rights to forbid circumvention of -technological measures. - - 4. Conveying Verbatim Copies. - - You may convey verbatim copies of the Program's source code as you -receive it, in any medium, provided that you conspicuously and -appropriately publish on each copy an appropriate copyright notice; -keep intact all notices stating that this License and any -non-permissive terms added in accord with section 7 apply to the code; -keep intact all notices of the absence of any warranty; and give all -recipients a copy of this License along with the Program. - - You may charge any price or no price for each copy that you convey, -and you may offer support or warranty protection for a fee. - - 5. Conveying Modified Source Versions. - - You may convey a work based on the Program, or the modifications to -produce it from the Program, in the form of source code under the -terms of section 4, provided that you also meet all of these conditions: - - a) The work must carry prominent notices stating that you modified - it, and giving a relevant date. - - b) The work must carry prominent notices stating that it is - released under this License and any conditions added under section - 7. This requirement modifies the requirement in section 4 to - "keep intact all notices". - - c) You must license the entire work, as a whole, under this - License to anyone who comes into possession of a copy. This - License will therefore apply, along with any applicable section 7 - additional terms, to the whole of the work, and all its parts, - regardless of how they are packaged. This License gives no - permission to license the work in any other way, but it does not - invalidate such permission if you have separately received it. - - d) If the work has interactive user interfaces, each must display - Appropriate Legal Notices; however, if the Program has interactive - interfaces that do not display Appropriate Legal Notices, your - work need not make them do so. - - A compilation of a covered work with other separate and independent -works, which are not by their nature extensions of the covered work, -and which are not combined with it such as to form a larger program, -in or on a volume of a storage or distribution medium, is called an -"aggregate" if the compilation and its resulting copyright are not -used to limit the access or legal rights of the compilation's users -beyond what the individual works permit. Inclusion of a covered work -in an aggregate does not cause this License to apply to the other -parts of the aggregate. - - 6. Conveying Non-Source Forms. 
- - You may convey a covered work in object code form under the terms -of sections 4 and 5, provided that you also convey the -machine-readable Corresponding Source under the terms of this License, -in one of these ways: - - a) Convey the object code in, or embodied in, a physical product - (including a physical distribution medium), accompanied by the - Corresponding Source fixed on a durable physical medium - customarily used for software interchange. - - b) Convey the object code in, or embodied in, a physical product - (including a physical distribution medium), accompanied by a - written offer, valid for at least three years and valid for as - long as you offer spare parts or customer support for that product - model, to give anyone who possesses the object code either (1) a - copy of the Corresponding Source for all the software in the - product that is covered by this License, on a durable physical - medium customarily used for software interchange, for a price no - more than your reasonable cost of physically performing this - conveying of source, or (2) access to copy the - Corresponding Source from a network server at no charge. - - c) Convey individual copies of the object code with a copy of the - written offer to provide the Corresponding Source. This - alternative is allowed only occasionally and noncommercially, and - only if you received the object code with such an offer, in accord - with subsection 6b. - - d) Convey the object code by offering access from a designated - place (gratis or for a charge), and offer equivalent access to the - Corresponding Source in the same way through the same place at no - further charge. You need not require recipients to copy the - Corresponding Source along with the object code. If the place to - copy the object code is a network server, the Corresponding Source - may be on a different server (operated by you or a third party) - that supports equivalent copying facilities, provided you maintain - clear directions next to the object code saying where to find the - Corresponding Source. Regardless of what server hosts the - Corresponding Source, you remain obligated to ensure that it is - available for as long as needed to satisfy these requirements. - - e) Convey the object code using peer-to-peer transmission, provided - you inform other peers where the object code and Corresponding - Source of the work are being offered to the general public at no - charge under subsection 6d. - - A separable portion of the object code, whose source code is excluded -from the Corresponding Source as a System Library, need not be -included in conveying the object code work. - - A "User Product" is either (1) a "consumer product", which means any -tangible personal property which is normally used for personal, family, -or household purposes, or (2) anything designed or sold for incorporation -into a dwelling. In determining whether a product is a consumer product, -doubtful cases shall be resolved in favor of coverage. For a particular -product received by a particular user, "normally used" refers to a -typical or common use of that class of product, regardless of the status -of the particular user or of the way in which the particular user -actually uses, or expects or is expected to use, the product. A product -is a consumer product regardless of whether the product has substantial -commercial, industrial or non-consumer uses, unless such uses represent -the only significant mode of use of the product. 
- - "Installation Information" for a User Product means any methods, -procedures, authorization keys, or other information required to install -and execute modified versions of a covered work in that User Product from -a modified version of its Corresponding Source. The information must -suffice to ensure that the continued functioning of the modified object -code is in no case prevented or interfered with solely because -modification has been made. - - If you convey an object code work under this section in, or with, or -specifically for use in, a User Product, and the conveying occurs as -part of a transaction in which the right of possession and use of the -User Product is transferred to the recipient in perpetuity or for a -fixed term (regardless of how the transaction is characterized), the -Corresponding Source conveyed under this section must be accompanied -by the Installation Information. But this requirement does not apply -if neither you nor any third party retains the ability to install -modified object code on the User Product (for example, the work has -been installed in ROM). - - The requirement to provide Installation Information does not include a -requirement to continue to provide support service, warranty, or updates -for a work that has been modified or installed by the recipient, or for -the User Product in which it has been modified or installed. Access to a -network may be denied when the modification itself materially and -adversely affects the operation of the network or violates the rules and -protocols for communication across the network. - - Corresponding Source conveyed, and Installation Information provided, -in accord with this section must be in a format that is publicly -documented (and with an implementation available to the public in -source code form), and must require no special password or key for -unpacking, reading or copying. - - 7. Additional Terms. - - "Additional permissions" are terms that supplement the terms of this -License by making exceptions from one or more of its conditions. -Additional permissions that are applicable to the entire Program shall -be treated as though they were included in this License, to the extent -that they are valid under applicable law. If additional permissions -apply only to part of the Program, that part may be used separately -under those permissions, but the entire Program remains governed by -this License without regard to the additional permissions. - - When you convey a copy of a covered work, you may at your option -remove any additional permissions from that copy, or from any part of -it. (Additional permissions may be written to require their own -removal in certain cases when you modify the work.) You may place -additional permissions on material, added by you to a covered work, -for which you have or can give appropriate copyright permission. 
- - Notwithstanding any other provision of this License, for material you -add to a covered work, you may (if authorized by the copyright holders of -that material) supplement the terms of this License with terms: - - a) Disclaiming warranty or limiting liability differently from the - terms of sections 15 and 16 of this License; or - - b) Requiring preservation of specified reasonable legal notices or - author attributions in that material or in the Appropriate Legal - Notices displayed by works containing it; or - - c) Prohibiting misrepresentation of the origin of that material, or - requiring that modified versions of such material be marked in - reasonable ways as different from the original version; or - - d) Limiting the use for publicity purposes of names of licensors or - authors of the material; or - - e) Declining to grant rights under trademark law for use of some - trade names, trademarks, or service marks; or - - f) Requiring indemnification of licensors and authors of that - material by anyone who conveys the material (or modified versions of - it) with contractual assumptions of liability to the recipient, for - any liability that these contractual assumptions directly impose on - those licensors and authors. - - All other non-permissive additional terms are considered "further -restrictions" within the meaning of section 10. If the Program as you -received it, or any part of it, contains a notice stating that it is -governed by this License along with a term that is a further -restriction, you may remove that term. If a license document contains -a further restriction but permits relicensing or conveying under this -License, you may add to a covered work material governed by the terms -of that license document, provided that the further restriction does -not survive such relicensing or conveying. - - If you add terms to a covered work in accord with this section, you -must place, in the relevant source files, a statement of the -additional terms that apply to those files, or a notice indicating -where to find the applicable terms. - - Additional terms, permissive or non-permissive, may be stated in the -form of a separately written license, or stated as exceptions; -the above requirements apply either way. - - 8. Termination. - - You may not propagate or modify a covered work except as expressly -provided under this License. Any attempt otherwise to propagate or -modify it is void, and will automatically terminate your rights under -this License (including any patent licenses granted under the third -paragraph of section 11). - - However, if you cease all violation of this License, then your -license from a particular copyright holder is reinstated (a) -provisionally, unless and until the copyright holder explicitly and -finally terminates your license, and (b) permanently, if the copyright -holder fails to notify you of the violation by some reasonable means -prior to 60 days after the cessation. - - Moreover, your license from a particular copyright holder is -reinstated permanently if the copyright holder notifies you of the -violation by some reasonable means, this is the first time you have -received notice of violation of this License (for any work) from that -copyright holder, and you cure the violation prior to 30 days after -your receipt of the notice. - - Termination of your rights under this section does not terminate the -licenses of parties who have received copies or rights from you under -this License. 
If your rights have been terminated and not permanently -reinstated, you do not qualify to receive new licenses for the same -material under section 10. - - 9. Acceptance Not Required for Having Copies. - - You are not required to accept this License in order to receive or -run a copy of the Program. Ancillary propagation of a covered work -occurring solely as a consequence of using peer-to-peer transmission -to receive a copy likewise does not require acceptance. However, -nothing other than this License grants you permission to propagate or -modify any covered work. These actions infringe copyright if you do -not accept this License. Therefore, by modifying or propagating a -covered work, you indicate your acceptance of this License to do so. - - 10. Automatic Licensing of Downstream Recipients. - - Each time you convey a covered work, the recipient automatically -receives a license from the original licensors, to run, modify and -propagate that work, subject to this License. You are not responsible -for enforcing compliance by third parties with this License. - - An "entity transaction" is a transaction transferring control of an -organization, or substantially all assets of one, or subdividing an -organization, or merging organizations. If propagation of a covered -work results from an entity transaction, each party to that -transaction who receives a copy of the work also receives whatever -licenses to the work the party's predecessor in interest had or could -give under the previous paragraph, plus a right to possession of the -Corresponding Source of the work from the predecessor in interest, if -the predecessor has it or can get it with reasonable efforts. - - You may not impose any further restrictions on the exercise of the -rights granted or affirmed under this License. For example, you may -not impose a license fee, royalty, or other charge for exercise of -rights granted under this License, and you may not initiate litigation -(including a cross-claim or counterclaim in a lawsuit) alleging that -any patent claim is infringed by making, using, selling, offering for -sale, or importing the Program or any portion of it. - - 11. Patents. - - A "contributor" is a copyright holder who authorizes use under this -License of the Program or a work on which the Program is based. The -work thus licensed is called the contributor's "contributor version". - - A contributor's "essential patent claims" are all patent claims -owned or controlled by the contributor, whether already acquired or -hereafter acquired, that would be infringed by some manner, permitted -by this License, of making, using, or selling its contributor version, -but do not include claims that would be infringed only as a -consequence of further modification of the contributor version. For -purposes of this definition, "control" includes the right to grant -patent sublicenses in a manner consistent with the requirements of -this License. - - Each contributor grants you a non-exclusive, worldwide, royalty-free -patent license under the contributor's essential patent claims, to -make, use, sell, offer for sale, import and otherwise run, modify and -propagate the contents of its contributor version. - - In the following three paragraphs, a "patent license" is any express -agreement or commitment, however denominated, not to enforce a patent -(such as an express permission to practice a patent or covenant not to -sue for patent infringement). 
To "grant" such a patent license to a -party means to make such an agreement or commitment not to enforce a -patent against the party. - - If you convey a covered work, knowingly relying on a patent license, -and the Corresponding Source of the work is not available for anyone -to copy, free of charge and under the terms of this License, through a -publicly available network server or other readily accessible means, -then you must either (1) cause the Corresponding Source to be so -available, or (2) arrange to deprive yourself of the benefit of the -patent license for this particular work, or (3) arrange, in a manner -consistent with the requirements of this License, to extend the patent -license to downstream recipients. "Knowingly relying" means you have -actual knowledge that, but for the patent license, your conveying the -covered work in a country, or your recipient's use of the covered work -in a country, would infringe one or more identifiable patents in that -country that you have reason to believe are valid. - - If, pursuant to or in connection with a single transaction or -arrangement, you convey, or propagate by procuring conveyance of, a -covered work, and grant a patent license to some of the parties -receiving the covered work authorizing them to use, propagate, modify -or convey a specific copy of the covered work, then the patent license -you grant is automatically extended to all recipients of the covered -work and works based on it. - - A patent license is "discriminatory" if it does not include within -the scope of its coverage, prohibits the exercise of, or is -conditioned on the non-exercise of one or more of the rights that are -specifically granted under this License. You may not convey a covered -work if you are a party to an arrangement with a third party that is -in the business of distributing software, under which you make payment -to the third party based on the extent of your activity of conveying -the work, and under which the third party grants, to any of the -parties who would receive the covered work from you, a discriminatory -patent license (a) in connection with copies of the covered work -conveyed by you (or copies made from those copies), or (b) primarily -for and in connection with specific products or compilations that -contain the covered work, unless you entered into that arrangement, -or that patent license was granted, prior to 28 March 2007. - - Nothing in this License shall be construed as excluding or limiting -any implied license or other defenses to infringement that may -otherwise be available to you under applicable patent law. - - 12. No Surrender of Others' Freedom. - - If conditions are imposed on you (whether by court order, agreement or -otherwise) that contradict the conditions of this License, they do not -excuse you from the conditions of this License. If you cannot convey a -covered work so as to satisfy simultaneously your obligations under this -License and any other pertinent obligations, then as a consequence you may -not convey it at all. For example, if you agree to terms that obligate you -to collect a royalty for further conveying from those to whom you convey -the Program, the only way you could satisfy both those terms and this -License would be to refrain entirely from conveying the Program. - - 13. Remote Network Interaction; Use with the GNU General Public License. 
- - Notwithstanding any other provision of this License, if you modify the -Program, your modified version must prominently offer all users -interacting with it remotely through a computer network (if your version -supports such interaction) an opportunity to receive the Corresponding -Source of your version by providing access to the Corresponding Source -from a network server at no charge, through some standard or customary -means of facilitating copying of software. This Corresponding Source -shall include the Corresponding Source for any work covered by version 3 -of the GNU General Public License that is incorporated pursuant to the -following paragraph. - - Notwithstanding any other provision of this License, you have -permission to link or combine any covered work with a work licensed -under version 3 of the GNU General Public License into a single -combined work, and to convey the resulting work. The terms of this -License will continue to apply to the part which is the covered work, -but the work with which it is combined will remain governed by version -3 of the GNU General Public License. - - 14. Revised Versions of this License. - - The Free Software Foundation may publish revised and/or new versions of -the GNU Affero General Public License from time to time. Such new versions -will be similar in spirit to the present version, but may differ in detail to -address new problems or concerns. - - Each version is given a distinguishing version number. If the -Program specifies that a certain numbered version of the GNU Affero General -Public License "or any later version" applies to it, you have the -option of following the terms and conditions either of that numbered -version or of any later version published by the Free Software -Foundation. If the Program does not specify a version number of the -GNU Affero General Public License, you may choose any version ever published -by the Free Software Foundation. - - If the Program specifies that a proxy can decide which future -versions of the GNU Affero General Public License can be used, that proxy's -public statement of acceptance of a version permanently authorizes you -to choose that version for the Program. - - Later license versions may give you additional or different -permissions. However, no additional obligations are imposed on any -author or copyright holder as a result of your choosing to follow a -later version. - - 15. Disclaimer of Warranty. - - THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY -APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT -HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY -OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, -THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR -PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM -IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF -ALL NECESSARY SERVICING, REPAIR OR CORRECTION. - - 16. Limitation of Liability. 
-
-  IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
-WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
-THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
-GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
-USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
-DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
-PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
-EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
-SUCH DAMAGES.
-
-  17. Interpretation of Sections 15 and 16.
-
-  If the disclaimer of warranty and limitation of liability provided
-above cannot be given local legal effect according to their terms,
-reviewing courts shall apply local law that most closely approximates
-an absolute waiver of all civil liability in connection with the
-Program, unless a warranty or assumption of liability accompanies a
-copy of the Program in return for a fee.
-
-                     END OF TERMS AND CONDITIONS
-
-            How to Apply These Terms to Your New Programs
-
-  If you develop a new program, and you want it to be of the greatest
-possible use to the public, the best way to achieve this is to make it
-free software which everyone can redistribute and change under these terms.
-
-  To do so, attach the following notices to the program.  It is safest
-to attach them to the start of each source file to most effectively
-state the exclusion of warranty; and each file should have at least
-the "copyright" line and a pointer to where the full notice is found.
-
-    <one line to give the program's name and a brief idea of what it does.>
-    Copyright (C) <year>  <name of author>
-
-    This program is free software: you can redistribute it and/or modify
-    it under the terms of the GNU Affero General Public License as published
-    by the Free Software Foundation, either version 3 of the License, or
-    (at your option) any later version.
-
-    This program is distributed in the hope that it will be useful,
-    but WITHOUT ANY WARRANTY; without even the implied warranty of
-    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-    GNU Affero General Public License for more details.
-
-    You should have received a copy of the GNU Affero General Public License
-    along with this program.  If not, see <https://www.gnu.org/licenses/>.
-
-Also add information on how to contact you by electronic and paper mail.
-
-  If your software can interact with users remotely through a computer
-network, you should also make sure that it provides a way for users to
-get its source.  For example, if your program is a web application, its
-interface could display a "Source" link that leads users to an archive
-of the code.  There are many ways you could offer source, and different
-solutions will be better for different programs; see section 13 for the
-specific requirements.
-
-  You should also get your employer (if you work as a programmer) or school,
-if any, to sign a "copyright disclaimer" for the program, if necessary.
-For more information on this, and how to apply and follow the GNU AGPL, see
-<https://www.gnu.org/licenses/>.
diff --git a/plugins/A_memorix/LICENSE-MAIBOT-GPL.md b/plugins/A_memorix/LICENSE-MAIBOT-GPL.md
deleted file mode 100644
index 83108097..00000000
--- a/plugins/A_memorix/LICENSE-MAIBOT-GPL.md
+++ /dev/null
@@ -1,22 +0,0 @@
-Special GPL License Grant for MaiBot
-
-Licensor
-- A_Dawn
-
-Effective date
-- 2026-03-18
-
-Default license
-- This repository is licensed under AGPL-3.0 by default (see `LICENSE`).
-
-Additional grant for MaiBot
-- The copyright holder(s) of this repository grant an additional, non-exclusive permission to
-  the project at `https://github.com/Mai-with-u/MaiBot` (including its maintainers and contributors)
-  to use, modify, and redistribute code from this repository under GPL-3.0.
-
-Scope
-- This additional GPL grant is intended for use in the MaiBot project context.
-- For all other uses not covered by the grant above, AGPL-3.0 remains the applicable license.
-
-No warranty
-- This grant is provided without warranty, consistent with AGPL-3.0 and GPL-3.0.
diff --git a/plugins/A_memorix/QUICK_START.md b/plugins/A_memorix/QUICK_START.md
deleted file mode 100644
index 7159a35b..00000000
--- a/plugins/A_memorix/QUICK_START.md
+++ /dev/null
@@ -1,216 +0,0 @@
-# A_Memorix Quick Start (v2.0.0)
-
-This document targets the current `2.0.0` architecture (SDK Tool interface).
-
-## 0. Version and interface changes
-
-- Current plugin version: `2.0.0`
-- Interface shape: `memory_provider` + Tool calls
-- Legacy slash commands (such as `/query`, `/memory`, `/visualize`) are no longer the primary documented entry point on this branch
-
-## 1. Prerequisites
-
-- Python 3.10+
-- The same runtime environment as the MaiBot host application
-- Network access to your configured embedding service
-
-Install dependencies:
-
-```bash
-pip install -r plugins/A_memorix/requirements.txt --upgrade
-```
-
-If your current directory is already the plugin directory, you can also run:
-
-```bash
-pip install -r requirements.txt --upgrade
-```
-
-## 2. Enable the plugin
-
-Enable `A_Memorix` in the host application's plugin configuration.
-
-If you use the `plugins/A_memorix/config.toml` approach, a minimal example:
-
-```toml
-[plugin]
-enabled = true
-
-[storage]
-data_dir = "./data"
-
-[embedding]
-model_name = "auto"
-dimension = 1024
-batch_size = 32
-max_concurrent = 5
-quantization_type = "int8"
-```
-
-## 3. Runtime self-check (strongly recommended)
-
-First confirm that the actual embedding output dimension is compatible with the vector store:
-
-```bash
-python plugins/A_memorix/scripts/runtime_self_check.py --json
-```
-
-If the result contains `ok=false`, fix the embedding configuration or the vector store before importing any data (a scripted gate is sketched below).
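If you want an automated pipeline to stop on a failed self-check, the `--json` output can be consumed programmatically. A minimal sketch, assuming the script prints a single JSON object with a boolean `ok` field to stdout (the field name follows the `ok=false` wording above; the rest of the report shape is not specified here):

```python
import json
import subprocess
import sys

# Run the self-check and capture its JSON report.
# Assumption: the script emits one JSON object with a boolean "ok" field on stdout.
proc = subprocess.run(
    [sys.executable, "plugins/A_memorix/scripts/runtime_self_check.py", "--json"],
    capture_output=True,
    text=True,
    check=True,
)
report = json.loads(proc.stdout)

# Gate any subsequent import on the embedding/vector-store compatibility result.
if not report.get("ok", False):
    raise SystemExit(f"runtime self-check failed: {report}")
print("self-check passed; safe to import data")
```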
-
-## 4. Import data
-
-### 4.1 Bulk text import
-
-Place your text files in:
-
-```text
-plugins/A_memorix/data/raw/
-```
-
-Then run:
-
-```bash
-python plugins/A_memorix/scripts/process_knowledge.py
-```
-
-Common flags:
-
-```bash
-python plugins/A_memorix/scripts/process_knowledge.py --force
-python plugins/A_memorix/scripts/process_knowledge.py --chat-log
-python plugins/A_memorix/scripts/process_knowledge.py --chat-log --chat-reference-time "2026/02/12 10:30"
-```
-
-### 4.2 Other import scripts
-
-```bash
-python plugins/A_memorix/scripts/import_lpmm_json.py
-python plugins/A_memorix/scripts/convert_lpmm.py -i <input> -o plugins/A_memorix/data
-python plugins/A_memorix/scripts/migrate_chat_history.py --help
-python plugins/A_memorix/scripts/migrate_maibot_memory.py --help
-python plugins/A_memorix/scripts/migrate_person_memory_points.py --help
-```
-
-## 5. Core Tool calls
-
-### 5.1 Retrieval
-
-```json
-{
-  "tool": "search_memory",
-  "arguments": {
-    "query": "项目复盘",
-    "mode": "aggregate",
-    "limit": 5,
-    "chat_id": "group:dev"
-  }
-}
-```
-
-`mode` supports: `search/time/hybrid/episode/aggregate`
-
-Strict semantics:
-
-- The `semantic` mode has been removed; passing it returns a parameter error.
-- The `time/hybrid` modes require `time_start` or `time_end`; otherwise an error is returned (no longer silently treated as a miss). See the example after this list.
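For instance, a valid `time` query must carry at least one time bound. A sketch of such a call; the timestamp format shown mirrors the `--chat-reference-time` format used above, but it is an illustrative assumption, and your build may expect epoch seconds or another format:

```json
{
  "tool": "search_memory",
  "arguments": {
    "query": "项目复盘",
    "mode": "time",
    "limit": 5,
    "chat_id": "group:dev",
    "time_start": "2026/03/01 00:00",
    "time_end": "2026/03/18 23:59"
  }
}
```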
-
-### 5.2 Ingest a summary
-
-```json
-{
-  "tool": "ingest_summary",
-  "arguments": {
-    "external_id": "chat_summary:group-dev:2026-03-18",
-    "chat_id": "group:dev",
-    "text": "今天完成了检索调优评审"
-  }
-}
-```
-
-### 5.3 Ingest a plain memory
-
-```json
-{
-  "tool": "ingest_text",
-  "arguments": {
-    "external_id": "note:2026-03-18:001",
-    "source_type": "note",
-    "text": "模型切换后召回质量更稳定",
-    "chat_id": "group:dev",
-    "tags": ["worklog"]
-  }
-}
-```
-
-### 5.4 Profiles and maintenance
-
-```json
-{
-  "tool": "get_person_profile",
-  "arguments": {
-    "person_id": "Alice",
-    "limit": 8
-  }
-}
-```
-
-```json
-{
-  "tool": "maintain_memory",
-  "arguments": {
-    "action": "protect",
-    "target": "模型切换后召回质量更稳定",
-    "hours": 24
-  }
-}
-```
-
-```json
-{
-  "tool": "memory_stats",
-  "arguments": {}
-}
-```
-
-## 6. Admin Tools (advanced)
-
-`2.0.0` ships a full set of admin tools:
-
-- `memory_graph_admin`
-- `memory_source_admin`
-- `memory_episode_admin`
-- `memory_profile_admin`
-- `memory_runtime_admin`
-- `memory_import_admin`
-- `memory_tuning_admin`
-- `memory_v5_admin`
-- `memory_delete_admin`
-
-Start with read-only actions such as `action=list` / `action=status` to verify the pipeline, as sketched below.
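Before running any destructive admin action, a read-only probe is a cheap sanity check. For example, `memory_runtime_admin` lists a `self_check` action in the admin-action table in README.md; the response shape is implementation-defined, so treat this as a sketch:

```json
{
  "tool": "memory_runtime_admin",
  "arguments": {
    "action": "self_check"
  }
}
```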
-
-A_Memorix is a `memory_provider` plugin for the MaiBot SDK.
-It unifies text, relations, episodes, person profiles, and retrieval tuning in a single runtime, targeting long-running agent memory scenarios.
-
-## Quick navigation
-
-- [Quick start](QUICK_START.md)
-- [Configuration reference](CONFIG_REFERENCE.md)
-- [Import guide and best practices](IMPORT_GUIDE.md)
-- [Changelog](CHANGELOG.md)
-
-## Positioning of 2.0.0
-
-`v2.0.0` is an architecture-consolidation release; this branch is centered on the **SDK Tool interface**:
-
-- The legacy `components/commands/*`, `components/tools/*`, and `server.py` have been removed.
-- The unified entry point is [`plugin.py`](plugin.py) + [`core/runtime/sdk_memory_kernel.py`](core/runtime/sdk_memory_kernel.py).
-- The metadata schema is `v8`, adding external references and operational audit records (such as `external_memory_refs`, `memory_v5_operations`, `delete_operations`).
-
-If you are still using the legacy slash commands (such as `/query`, `/memory`, `/visualize`), migrate to the Tool interface described in this document.
-
-## Core capabilities
-
-- Dual-path retrieval: joint recall over vectors + graph relations, supporting `search/time/hybrid/episode/aggregate`.
-- Ingestion and deduplication: idempotent `external_id`, joint paragraph/relation writes, episode pending-queue processing.
-- Episode features: rebuild by source, status queries, batch processing of pending items.
-- Person profiles: automatic snapshots + manual overrides.
-- Administration: a full toolset covering graph, sources, episodes, profiles, import, tuning, V5 operations, and delete/restore.
-
-## Tool interface (v2.0.0)
-
-### Basic tools
-
-| Tool | Description | Key parameters |
-| --- | --- | --- |
-| `search_memory` | Search long-term memory | `query` `mode` `limit` `chat_id` `person_id` `time_start` `time_end` |
-| `ingest_summary` | Ingest a chat summary | `external_id` `chat_id` `text` |
-| `ingest_text` | Ingest a plain text memory | `external_id` `source_type` `text` |
-| `get_person_profile` | Get a person profile | `person_id` `chat_id` `limit` |
-| `maintain_memory` | Maintain relation state | `action=reinforce/protect/restore/freeze/recycle_bin` |
-| `memory_stats` | Get statistics | None |
-
-### Admin tools
-
-| Tool | Common actions |
-| --- | --- |
-| `memory_graph_admin` | `get_graph/create_node/delete_node/rename_node/create_edge/delete_edge/update_edge_weight` |
-| `memory_source_admin` | `list/delete/batch_delete` |
-| `memory_episode_admin` | `query/list/get/status/rebuild/process_pending` |
-| `memory_profile_admin` | `query/list/set_override/delete_override` |
-| `memory_runtime_admin` | `save/get_config/self_check/refresh_self_check/set_auto_save` |
-| `memory_import_admin` | `settings/get_guide/create_upload/create_paste/create_raw_scan/create_lpmm_openie/create_lpmm_convert/create_temporal_backfill/create_maibot_migration/list/get/chunks/cancel/retry_failed` |
-| `memory_tuning_admin` | `settings/get_profile/apply_profile/rollback_profile/export_profile/create_task/list_tasks/get_task/get_rounds/cancel/apply_best/get_report` |
-| `memory_v5_admin` | `status/recycle_bin/restore/reinforce/weaken/remember_forever/forget` |
-| `memory_delete_admin` | `preview/execute/restore/get_operation/list_operations/purge` |
-
-### Retrieval mode semantics (strict)
-
-- `search_memory.mode` supports only: `search/time/hybrid/episode/aggregate`.
-- The `semantic` mode has been removed; passing it returns a parameter error.
-- The `time/hybrid` modes require `time_start` or `time_end`; otherwise an error is returned instead of silently reporting no hits.
-
-### Delete response semantics (`source` mode)
-
-- `requested_source_count`: number of sources requested for deletion.
-- `matched_source_count`: number of sources actually matched (i.e. having active paragraphs).
-- `deleted_paragraph_count`: number of paragraphs actually deleted.
-- `deleted_count`: consistent with what was actually deleted; in `source` mode it equals `deleted_paragraph_count`.
-- `success`: determined by actual matches and actual deletions; returns `false` when no source is matched (a response sketch follows this list).
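Putting those fields together, a successful source-mode delete response might look like the sketch below. Only the field names come from the list above; the values are illustrative, and the full response may carry additional fields:

```json
{
  "success": true,
  "requested_source_count": 2,
  "matched_source_count": 2,
  "deleted_paragraph_count": 17,
  "deleted_count": 17
}
```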
-
-## Call examples
-
-```json
-{
-  "tool": "search_memory",
-  "arguments": {
-    "query": "项目复盘",
-    "mode": "aggregate",
-    "limit": 5,
-    "chat_id": "group:dev"
-  }
-}
-```
-
-```json
-{
-  "tool": "ingest_text",
-  "arguments": {
-    "external_id": "note:2026-03-18:001",
-    "source_type": "note",
-    "text": "今天完成了检索调优评审",
-    "chat_id": "group:dev",
-    "tags": ["worklog"]
-  }
-}
-```
-
-```json
-{
-  "tool": "maintain_memory",
-  "arguments": {
-    "action": "protect",
-    "target": "完成了 检索调优评审",
-    "hours": 72
-  }
-}
-```
-
-## Quick start
-
-### 1. Install dependencies
-
-Run this in the same Python environment as the MaiBot host application:
-
-```bash
-pip install -r plugins/A_memorix/requirements.txt --upgrade
-```
-
-If your current directory is already the plugin directory, you can also run:
-
-```bash
-pip install -r requirements.txt --upgrade
-```
-
-### 2. Enable the plugin
-
-Enable the plugin in `config.toml` (the path depends on your host deployment):
-
-```toml
-[plugin]
-enabled = true
-```
-
-### 3. Run the runtime self-check first
-
-```bash
-python plugins/A_memorix/scripts/runtime_self_check.py --json
-```
-
-### 4. Import text and verify the statistics
-
-```bash
-python plugins/A_memorix/scripts/process_knowledge.py
-```
-
-Then call `memory_stats` or `search_memory` to check that data is present.
-
-## Web pages
-
-The repository retains static Web pages:
-
-- `web/index.html` (graph and memory management)
-- `web/import.html` (import center)
-- `web/tuning.html` (retrieval tuning)
-
-This branch no longer ships a standalone `server.py`; page routing and API exposure are the responsibility of the host-side integration.
-
-## Common scripts
-
-| Script | Purpose |
-| --- | --- |
-| `process_knowledge.py` | Bulk-import raw text (strategy-aware) |
-| `import_lpmm_json.py` | Import OpenIE JSON |
-| `convert_lpmm.py` | Convert LPMM data |
-| `migrate_chat_history.py` | Migrate chat_history |
-| `migrate_maibot_memory.py` | Migrate legacy MaiBot memory |
-| `migrate_person_memory_points.py` | Migrate person memory points |
-| `backfill_temporal_metadata.py` | Backfill temporal metadata |
-| `audit_vector_consistency.py` | Audit vector consistency |
-| `backfill_relation_vectors.py` | Backfill relation vectors |
-| `rebuild_episodes.py` | Rebuild episodes by source |
-| `release_vnext_migrate.py` | Upgrade preflight / migration / validation |
-| `runtime_self_check.py` | Runtime self-check against the real embedding service |
-
-## Configuration highlights
-
-See [CONFIG_REFERENCE.md](CONFIG_REFERENCE.md) for the full configuration.
-
-Frequently used options:
-
-- `storage.data_dir`
-- `embedding.dimension`
-- `embedding.quantization_type` (currently only `int8` is supported)
-- `retrieval.*`
-- `retrieval.sparse.*`
-- `episode.*`
-- `person_profile.*`
-- `memory.*`
-- `web.import.*`
-- `web.tuning.*`
-
-## Troubleshooting
-
-### SQLite without FTS5
-
-If the SQLite in your environment does not have `FTS5` enabled, you can disable sparse retrieval:
-
-```toml
-[retrieval.sparse]
-enabled = false
-```
-
-### Vector dimension mismatch
-
-If the logs report that the current embedding output dimension does not match the existing vector store, run this first:
-
-```bash
-python plugins/A_memorix/scripts/runtime_self_check.py --json
-```
-
-If necessary, rebuild the vectors or adjust the embedding configuration before starting the plugin again.
-
-## License
-
-The default license is [AGPL-3.0](https://www.gnu.org/licenses/agpl-3.0) (see `LICENSE`).
-
-See `LICENSE-MAIBOT-GPL.md` for the additional GPL grant to the `Mai-with-u/MaiBot` project.
-
-Apart from that additional grant, AGPL-3.0 applies to all other use cases.
-
-## Contributing
-
-PRs are not currently accepted; please file issues instead.
-
-**Author**: `A_Dawn`
diff --git a/plugins/A_memorix/__init__.py b/plugins/A_memorix/__init__.py
deleted file mode 100644
index d23a5bd5..00000000
--- a/plugins/A_memorix/__init__.py
+++ /dev/null
@@ -1,12 +0,0 @@
-"""
-A_Memorix - lightweight knowledge-base plugin
-
-A fully self-contained memory enhancement system, optimized for knowledge storage and retrieval in low-resource environments.
-"""
-
-__version__ = "2.0.0"
-__author__ = "A_Dawn"
-
-from .plugin import AMemorixPlugin
-
-__all__ = ["AMemorixPlugin"]
diff --git a/plugins/A_memorix/_manifest.json b/plugins/A_memorix/_manifest.json
deleted file mode 100644
index e4217fdd..00000000
--- a/plugins/A_memorix/_manifest.json
+++ /dev/null
@@ -1,107 +0,0 @@
-{
-  "manifest_version": 1,
-  "name": "A_Memorix",
-  "version": "2.0.0",
-  "description": "MaiBot SDK 长期记忆插件,负责统一检索、写入、画像与记忆维护。",
-  "author": {
-    "name": "A_Dawn"
-  },
-  "license": "AGPL-3.0",
-  "repository_url": "https://github.com/A-Dawn/A_memorix/",
-  "host_application": {
-    "min_version": "1.0.0"
-  },
-  "keywords": [
-    "memory",
-    "knowledge",
-    "retrieval",
-    "profile",
-    "episode"
-  ],
-  "categories": [
-    "Memory",
-    "Data"
-  ],
-  "plugin_info": {
-    "is_built_in": false,
-    "plugin_type": "memory_provider",
-    "components": [
-      {
-        "type": "tool",
-        "name": "search_memory",
-        "description": "搜索长期记忆"
-      },
-      {
-        "type": "tool",
-        "name": "ingest_summary",
-        "description": "写入聊天摘要"
-      },
-      {
-        "type": "tool",
-        "name": "ingest_text",
-        "description": "写入普通长期记忆文本"
-      },
-      {
-        "type": "tool",
-        "name": 
"get_person_profile", - "description": "查询人物画像" - }, - { - "type": "tool", - "name": "maintain_memory", - "description": "维护记忆关系" - }, - { - "type": "tool", - "name": "memory_stats", - "description": "查询记忆统计" - }, - { - "type": "tool", - "name": "memory_graph_admin", - "description": "图谱管理接口" - }, - { - "type": "tool", - "name": "memory_source_admin", - "description": "来源管理接口" - }, - { - "type": "tool", - "name": "memory_episode_admin", - "description": "Episode 管理接口" - }, - { - "type": "tool", - "name": "memory_profile_admin", - "description": "画像管理接口" - }, - { - "type": "tool", - "name": "memory_runtime_admin", - "description": "运行时管理接口" - }, - { - "type": "tool", - "name": "memory_import_admin", - "description": "导入管理接口" - }, - { - "type": "tool", - "name": "memory_tuning_admin", - "description": "调优管理接口" - }, - { - "type": "tool", - "name": "memory_v5_admin", - "description": "V5 记忆管理接口" - }, - { - "type": "tool", - "name": "memory_delete_admin", - "description": "删除管理接口" - } - ] - }, - "capabilities": [] -} diff --git a/plugins/A_memorix/core/__init__.py b/plugins/A_memorix/core/__init__.py deleted file mode 100644 index 3f87929c..00000000 --- a/plugins/A_memorix/core/__init__.py +++ /dev/null @@ -1,84 +0,0 @@ -"""核心模块 - 存储、嵌入、检索引擎""" - -# 存储模块(已实现) -from .storage import ( - VectorStore, - GraphStore, - MetadataStore, - ImportStrategy, - KnowledgeType, - parse_import_strategy, - resolve_stored_knowledge_type, - detect_knowledge_type, - select_import_strategy, - should_extract_relations, - get_type_display_name, -) - -# 嵌入模块(使用主程序 API) -from .embedding import ( - EmbeddingAPIAdapter, - create_embedding_api_adapter, -) - -# 检索模块(已实现) -from .retrieval import ( - DualPathRetriever, - RetrievalStrategy, - RetrievalResult, - DualPathRetrieverConfig, - TemporalQueryOptions, - FusionConfig, - GraphRelationRecallConfig, - RelationIntentConfig, - PersonalizedPageRank, - PageRankConfig, - create_ppr_from_graph, - DynamicThresholdFilter, - ThresholdMethod, - ThresholdConfig, - SparseBM25Index, - SparseBM25Config, -) -from .utils import ( - RelationWriteService, - RelationWriteResult, -) - -__all__ = [ - # Storage - "VectorStore", - "GraphStore", - "MetadataStore", - "ImportStrategy", - "KnowledgeType", - "parse_import_strategy", - "resolve_stored_knowledge_type", - "detect_knowledge_type", - "select_import_strategy", - "should_extract_relations", - "get_type_display_name", - # Embedding - "EmbeddingAPIAdapter", - "create_embedding_api_adapter", - # Retrieval - "DualPathRetriever", - "RetrievalStrategy", - "RetrievalResult", - "DualPathRetrieverConfig", - "TemporalQueryOptions", - "FusionConfig", - "GraphRelationRecallConfig", - "RelationIntentConfig", - "PersonalizedPageRank", - "PageRankConfig", - "create_ppr_from_graph", - "DynamicThresholdFilter", - "ThresholdMethod", - "ThresholdConfig", - "SparseBM25Index", - "SparseBM25Config", - "RelationWriteService", - "RelationWriteResult", -] - diff --git a/plugins/A_memorix/core/embedding/__init__.py b/plugins/A_memorix/core/embedding/__init__.py deleted file mode 100644 index 11a52db9..00000000 --- a/plugins/A_memorix/core/embedding/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -"""嵌入模块 - 向量生成与量化""" - -# 新的 API 适配器(主程序嵌入 API) -from .api_adapter import ( - EmbeddingAPIAdapter, - create_embedding_api_adapter, -) - -from ..utils.quantization import QuantizationType - -__all__ = [ - # 新的 API 适配器(推荐使用) - "EmbeddingAPIAdapter", - "create_embedding_api_adapter", - # 量化 - "QuantizationType", -] - diff --git a/plugins/A_memorix/core/embedding/api_adapter.py 
b/plugins/A_memorix/core/embedding/api_adapter.py deleted file mode 100644 index d11e2d05..00000000 --- a/plugins/A_memorix/core/embedding/api_adapter.py +++ /dev/null @@ -1,368 +0,0 @@ -""" -请求式嵌入 API 适配器。 - -恢复 v1.0.1 的真实 embedding 请求语义: -- 通过宿主模型配置探测/请求 embedding -- 支持 dimensions 参数 -- 支持批量与重试 -- 不再提供本地 hash fallback -""" - -from __future__ import annotations - -import asyncio -import time -from typing import Any, List, Optional, Union - -import aiohttp -import numpy as np -import openai - -from src.common.logger import get_logger -from src.config.config import config_manager -from src.config.model_configs import APIProvider, ModelInfo -from src.llm_models.exceptions import NetworkConnectionError -from src.llm_models.model_client.base_client import client_registry - -logger = get_logger("A_Memorix.EmbeddingAPIAdapter") - - -class EmbeddingAPIAdapter: - """适配宿主 embedding 请求接口。""" - - def __init__( - self, - batch_size: int = 32, - max_concurrent: int = 5, - default_dimension: int = 1024, - enable_cache: bool = False, - model_name: str = "auto", - retry_config: Optional[dict] = None, - ) -> None: - self.batch_size = max(1, int(batch_size)) - self.max_concurrent = max(1, int(max_concurrent)) - self.default_dimension = max(1, int(default_dimension)) - self.enable_cache = bool(enable_cache) - self.model_name = str(model_name or "auto") - - self.retry_config = retry_config or {} - self.max_attempts = max(1, int(self.retry_config.get("max_attempts", 5))) - self.max_wait_seconds = max(0.1, float(self.retry_config.get("max_wait_seconds", 40))) - self.min_wait_seconds = max(0.1, float(self.retry_config.get("min_wait_seconds", 3))) - self.backoff_multiplier = max(1.0, float(self.retry_config.get("backoff_multiplier", 3))) - - self._dimension: Optional[int] = None - self._dimension_detected = False - self._total_encoded = 0 - self._total_errors = 0 - self._total_time = 0.0 - - logger.info( - "EmbeddingAPIAdapter 初始化: " - f"batch_size={self.batch_size}, " - f"max_concurrent={self.max_concurrent}, " - f"default_dim={self.default_dimension}, " - f"model={self.model_name}" - ) - - def _get_current_model_config(self): - return config_manager.get_model_config() - - @staticmethod - def _find_model_info(model_name: str) -> ModelInfo: - model_cfg = config_manager.get_model_config() - for item in model_cfg.models: - if item.name == model_name: - return item - raise ValueError(f"未找到 embedding 模型: {model_name}") - - @staticmethod - def _find_provider(provider_name: str) -> APIProvider: - model_cfg = config_manager.get_model_config() - for item in model_cfg.api_providers: - if item.name == provider_name: - return item - raise ValueError(f"未找到 embedding provider: {provider_name}") - - def _resolve_candidate_model_names(self) -> List[str]: - task_config = self._get_current_model_config().model_task_config.embedding - configured = list(getattr(task_config, "model_list", []) or []) - if self.model_name and self.model_name != "auto": - return [self.model_name, *[name for name in configured if name != self.model_name]] - return configured - - @staticmethod - def _validate_embedding_vector(embedding: Any, *, source: str) -> np.ndarray: - array = np.asarray(embedding, dtype=np.float32) - if array.ndim != 1: - raise RuntimeError(f"{source} 返回的 embedding 维度非法: ndim={array.ndim}") - if array.size <= 0: - raise RuntimeError(f"{source} 返回了空 embedding") - if not np.all(np.isfinite(array)): - raise RuntimeError(f"{source} 返回了非有限 embedding 值") - return array - - async def _request_with_retry(self, client, model_info, text: 
str, extra_params: dict): - retriable_exceptions = ( - openai.APIConnectionError, - openai.APITimeoutError, - aiohttp.ClientError, - asyncio.TimeoutError, - NetworkConnectionError, - ) - - last_exc: Optional[BaseException] = None - for attempt in range(1, self.max_attempts + 1): - try: - return await client.get_embedding( - model_info=model_info, - embedding_input=text, - extra_params=extra_params, - ) - except retriable_exceptions as exc: - last_exc = exc - if attempt >= self.max_attempts: - raise - wait_seconds = min( - self.max_wait_seconds, - self.min_wait_seconds * (self.backoff_multiplier ** (attempt - 1)), - ) - logger.warning( - "Embedding 请求失败,重试 " - f"{attempt}/{max(1, self.max_attempts - 1)}," - f"{wait_seconds:.1f}s 后重试: {exc}" - ) - await asyncio.sleep(wait_seconds) - except Exception: - raise - - if last_exc is not None: - raise last_exc - raise RuntimeError("Embedding 请求失败:未知错误") - - async def _get_embedding_direct(self, text: str, dimensions: Optional[int] = None) -> Optional[List[float]]: - candidate_names = self._resolve_candidate_model_names() - if not candidate_names: - raise RuntimeError("embedding 任务未配置模型") - - last_exc: Optional[BaseException] = None - for candidate_name in candidate_names: - try: - model_info = self._find_model_info(candidate_name) - api_provider = self._find_provider(model_info.api_provider) - client = client_registry.get_client_class_instance(api_provider, force_new=True) - - extra_params = dict(getattr(model_info, "extra_params", {}) or {}) - if dimensions is not None: - extra_params["dimensions"] = int(dimensions) - - response = await self._request_with_retry( - client=client, - model_info=model_info, - text=text, - extra_params=extra_params, - ) - embedding = getattr(response, "embedding", None) - if embedding is None: - raise RuntimeError(f"模型 {candidate_name} 未返回 embedding") - vector = self._validate_embedding_vector( - embedding, - source=f"embedding 模型 {candidate_name}", - ) - return vector.tolist() - except Exception as exc: - last_exc = exc - logger.warning(f"embedding 模型 {candidate_name} 请求失败: {exc}") - - if last_exc is not None: - logger.error(f"通过直接 Client 获取 Embedding 失败: {last_exc}") - return None - - async def _detect_dimension(self) -> int: - if self._dimension_detected and self._dimension is not None: - return self._dimension - - logger.info("正在检测嵌入模型维度...") - try: - target_dim = self.default_dimension - logger.debug(f"尝试请求指定维度: {target_dim}") - test_embedding = await self._get_embedding_direct("test", dimensions=target_dim) - if test_embedding and isinstance(test_embedding, list): - detected_dim = len(test_embedding) - if detected_dim == target_dim: - logger.info(f"嵌入维度检测成功 (匹配配置): {detected_dim}") - else: - logger.warning( - f"请求维度 {target_dim} 但模型返回 {detected_dim},将使用模型自然维度" - ) - self._dimension = detected_dim - self._dimension_detected = True - return detected_dim - except Exception as exc: - logger.debug(f"带维度参数探测失败: {exc},尝试不带参数探测") - - try: - test_embedding = await self._get_embedding_direct("test", dimensions=None) - if test_embedding and isinstance(test_embedding, list): - detected_dim = len(test_embedding) - self._dimension = detected_dim - self._dimension_detected = True - logger.info(f"嵌入维度检测成功 (自然维度): {detected_dim}") - return detected_dim - logger.warning(f"嵌入维度检测失败,使用默认值: {self.default_dimension}") - except Exception as exc: - logger.error(f"嵌入维度检测异常: {exc},使用默认值: {self.default_dimension}") - - self._dimension = self.default_dimension - self._dimension_detected = True - return self.default_dimension - - async def 
encode( - self, - texts: Union[str, List[str]], - batch_size: Optional[int] = None, - show_progress: bool = False, - normalize: bool = True, - dimensions: Optional[int] = None, - ) -> np.ndarray: - del show_progress - del normalize - - start_time = time.time() - target_dim = int(dimensions) if dimensions is not None else int(await self._detect_dimension()) - - if isinstance(texts, str): - normalized_texts = [texts] - single_input = True - else: - normalized_texts = list(texts or []) - single_input = False - - if not normalized_texts: - empty = np.zeros((0, target_dim), dtype=np.float32) - return empty[0] if single_input else empty - - if batch_size is None: - batch_size = self.batch_size - - try: - embeddings = await self._encode_batch_internal( - normalized_texts, - batch_size=max(1, int(batch_size)), - dimensions=dimensions, - ) - if embeddings.ndim == 1: - embeddings = embeddings.reshape(1, -1) - self._total_encoded += len(normalized_texts) - elapsed = time.time() - start_time - self._total_time += elapsed - logger.debug( - "编码完成: " - f"{len(normalized_texts)} 个文本, " - f"耗时 {elapsed:.2f}s, " - f"平均 {elapsed / max(1, len(normalized_texts)):.3f}s/文本" - ) - return embeddings[0] if single_input else embeddings - except Exception as exc: - self._total_errors += 1 - logger.error(f"编码失败: {exc}") - raise RuntimeError(f"embedding encode failed: {exc}") from exc - - async def _encode_batch_internal( - self, - texts: List[str], - batch_size: int, - dimensions: Optional[int] = None, - ) -> np.ndarray: - all_embeddings: List[np.ndarray] = [] - for offset in range(0, len(texts), batch_size): - batch = texts[offset : offset + batch_size] - semaphore = asyncio.Semaphore(self.max_concurrent) - - async def encode_with_semaphore(text: str, index: int): - async with semaphore: - embedding = await self._get_embedding_direct(text, dimensions=dimensions) - if embedding is None: - raise RuntimeError(f"文本 {index} 编码失败:embedding 返回为空") - vector = self._validate_embedding_vector( - embedding, - source=f"文本 {index}", - ) - return index, vector - - tasks = [ - encode_with_semaphore(text, offset + index) - for index, text in enumerate(batch) - ] - results = await asyncio.gather(*tasks) - results.sort(key=lambda item: item[0]) - all_embeddings.extend(emb for _, emb in results) - - return np.array(all_embeddings, dtype=np.float32) - - async def encode_batch( - self, - texts: List[str], - batch_size: Optional[int] = None, - num_workers: Optional[int] = None, - show_progress: bool = False, - dimensions: Optional[int] = None, - ) -> np.ndarray: - del show_progress - if num_workers is not None: - previous = self.max_concurrent - self.max_concurrent = max(1, int(num_workers)) - try: - return await self.encode(texts, batch_size=batch_size, dimensions=dimensions) - finally: - self.max_concurrent = previous - return await self.encode(texts, batch_size=batch_size, dimensions=dimensions) - - def get_embedding_dimension(self) -> int: - if self._dimension is not None: - return self._dimension - logger.warning(f"维度尚未检测,返回默认值: {self.default_dimension}") - return self.default_dimension - - def get_model_info(self) -> dict: - return { - "model_name": self.model_name, - "dimension": self._dimension or self.default_dimension, - "dimension_detected": self._dimension_detected, - "batch_size": self.batch_size, - "max_concurrent": self.max_concurrent, - "total_encoded": self._total_encoded, - "total_errors": self._total_errors, - "avg_time_per_text": self._total_time / self._total_encoded if self._total_encoded else 0.0, - } - - def 
get_statistics(self) -> dict: - return self.get_model_info() - - @property - def is_model_loaded(self) -> bool: - return True - - def __repr__(self) -> str: - return ( - f"EmbeddingAPIAdapter(dim={self._dimension or self.default_dimension}, " - f"detected={self._dimension_detected}, encoded={self._total_encoded})" - ) - - -def create_embedding_api_adapter( - batch_size: int = 32, - max_concurrent: int = 5, - default_dimension: int = 1024, - enable_cache: bool = False, - model_name: str = "auto", - retry_config: Optional[dict] = None, -) -> EmbeddingAPIAdapter: - return EmbeddingAPIAdapter( - batch_size=batch_size, - max_concurrent=max_concurrent, - default_dimension=default_dimension, - enable_cache=enable_cache, - model_name=model_name, - retry_config=retry_config, - ) diff --git a/plugins/A_memorix/core/embedding/manager.py b/plugins/A_memorix/core/embedding/manager.py deleted file mode 100644 index d161e23b..00000000 --- a/plugins/A_memorix/core/embedding/manager.py +++ /dev/null @@ -1,510 +0,0 @@ -""" -嵌入管理器 - -负责嵌入模型的加载、缓存和批量生成。 -""" - -import hashlib -import pickle -import threading -from concurrent.futures import ThreadPoolExecutor, as_completed -from pathlib import Path -from typing import Optional, Union, List, Dict, Any, Tuple - -import numpy as np - -try: - from sentence_transformers import SentenceTransformer - HAS_SENTENCE_TRANSFORMERS = True -except ImportError: - HAS_SENTENCE_TRANSFORMERS = False - -from src.common.logger import get_logger -from .presets import ( - EmbeddingModelConfig, - get_custom_config, - validate_config_compatibility, - are_models_compatible, -) -from ..utils.quantization import QuantizationType - -logger = get_logger("A_Memorix.EmbeddingManager") - - -class EmbeddingManager: - """ - 嵌入管理器 - - 功能: - - 模型加载与缓存 - - 批量生成嵌入 - - 多线程/多进程支持 - - 模型一致性检查 - - 智能分批 - - 参数: - config: 模型配置 - cache_dir: 缓存目录 - enable_cache: 是否启用缓存 - num_workers: 工作线程数 - """ - - def __init__( - self, - config: EmbeddingModelConfig, - cache_dir: Optional[Union[str, Path]] = None, - enable_cache: bool = True, - num_workers: int = 1, - ): - """ - 初始化嵌入管理器 - - Args: - config: 模型配置 - cache_dir: 缓存目录 - enable_cache: 是否启用缓存 - num_workers: 工作线程数 - """ - if not HAS_SENTENCE_TRANSFORMERS: - raise ImportError( - "sentence-transformers 未安装,请安装: " - "pip install sentence-transformers" - ) - - self.config = config - self.cache_dir = Path(cache_dir) if cache_dir else None - self.enable_cache = enable_cache - self.num_workers = max(1, num_workers) - - # 模型实例 - self._model: Optional[SentenceTransformer] = None - self._model_lock = threading.Lock() - - # 缓存 - self._embedding_cache: Dict[str, np.ndarray] = {} - self._cache_lock = threading.Lock() - - # 统计 - self._total_encoded = 0 - self._cache_hits = 0 - self._cache_misses = 0 - - logger.info( - f"EmbeddingManager 初始化: model={config.model_name}, " - f"dim={config.dimension}, workers={num_workers}" - ) - - def load_model(self) -> None: - """加载模型(懒加载)""" - if self._model is not None: - return - - with self._model_lock: - # 双重检查 - if self._model is not None: - return - - logger.info(f"正在加载模型: {self.config.model_name}") - - try: - # 构建模型参数 - model_kwargs = {} - if self.config.cache_dir: - model_kwargs["cache_folder"] = self.config.cache_dir - - # 加载模型 - self._model = SentenceTransformer( - self.config.model_path, - **model_kwargs, - ) - - logger.info(f"模型加载成功: {self.config.model_name}") - - except Exception as e: - logger.error(f"模型加载失败: {e}") - raise - - def encode( - self, - texts: Union[str, List[str]], - batch_size: Optional[int] = None, - 
show_progress: bool = False, - normalize: bool = True, - ) -> np.ndarray: - """ - 生成文本嵌入 - - Args: - texts: 文本或文本列表 - batch_size: 批次大小(默认使用配置值) - show_progress: 是否显示进度条 - normalize: 是否归一化 - - Returns: - 嵌入向量 (N x D) - """ - # 确保模型已加载 - self.load_model() - - # 标准化输入 - if isinstance(texts, str): - texts = [texts] - single_input = True - else: - single_input = False - - if not texts: - return np.zeros((0, self.config.dimension), dtype=np.float32) - - # 使用配置的批次大小 - if batch_size is None: - batch_size = self.config.batch_size - - # 生成嵌入 - try: - embeddings = self._model.encode( - texts, - batch_size=batch_size, - show_progress_bar=show_progress, - normalize_embeddings=normalize and self.config.normalization, - convert_to_numpy=True, - ) - - # 确保是2D数组 - if embeddings.ndim == 1: - embeddings = embeddings.reshape(1, -1) - - self._total_encoded += len(texts) - - # 如果是单个输入,返回1D数组 - if single_input: - return embeddings[0] - - return embeddings - - except Exception as e: - logger.error(f"生成嵌入失败: {e}") - raise - - def encode_batch( - self, - texts: List[str], - batch_size: Optional[int] = None, - num_workers: Optional[int] = None, - show_progress: bool = False, - ) -> np.ndarray: - """ - 批量生成嵌入(多线程优化) - - Args: - texts: 文本列表 - batch_size: 批次大小 - num_workers: 工作线程数(默认使用初始化时的值) - show_progress: 是否显示进度条 - - Returns: - 嵌入向量 (N x D) - """ - if not texts: - return np.zeros((0, self.config.dimension), dtype=np.float32) - - # 单线程模式 - num_workers = num_workers if num_workers is not None else self.num_workers - if num_workers == 1: - return self.encode(texts, batch_size=batch_size, show_progress=show_progress) - - # 多线程模式 - logger.info(f"使用 {num_workers} 个线程生成 {len(texts)} 个嵌入") - - # 分批 - batch_size = batch_size or self.config.batch_size - batches = [ - texts[i:i + batch_size] - for i in range(0, len(texts), batch_size) - ] - - # 多线程生成 - all_embeddings = [] - with ThreadPoolExecutor(max_workers=num_workers) as executor: - # 提交任务 - future_to_batch = { - executor.submit( - self.encode, - batch, - batch_size, - False, # 不显示进度条(多线程时会混乱) - ): i - for i, batch in enumerate(batches) - } - - # 收集结果 - for future in as_completed(future_to_batch): - batch_idx = future_to_batch[future] - try: - embeddings = future.result() - all_embeddings.append((batch_idx, embeddings)) - except Exception as e: - logger.error(f"批次 {batch_idx} 生成嵌入失败: {e}") - raise - - # 按顺序合并 - all_embeddings.sort(key=lambda x: x[0]) - final_embeddings = np.concatenate([emb for _, emb in all_embeddings], axis=0) - - return final_embeddings - - def encode_with_cache( - self, - texts: List[str], - batch_size: Optional[int] = None, - show_progress: bool = False, - ) -> np.ndarray: - """ - 生成嵌入(带缓存) - - Args: - texts: 文本列表 - batch_size: 批次大小 - show_progress: 是否显示进度条 - - Returns: - 嵌入向量 (N x D) - """ - if not self.enable_cache: - return self.encode(texts, batch_size, show_progress) - - # 分离缓存命中和未命中的文本 - cached_embeddings = [] - uncached_texts = [] - uncached_indices = [] - - for i, text in enumerate(texts): - cache_key = self._get_cache_key(text) - - with self._cache_lock: - if cache_key in self._embedding_cache: - cached_embeddings.append((i, self._embedding_cache[cache_key])) - self._cache_hits += 1 - else: - uncached_texts.append(text) - uncached_indices.append(i) - self._cache_misses += 1 - - # 生成未缓存的嵌入 - if uncached_texts: - new_embeddings = self.encode( - uncached_texts, - batch_size, - show_progress, - ) - - # 更新缓存 - with self._cache_lock: - for text, embedding in zip(uncached_texts, new_embeddings): - cache_key = self._get_cache_key(text) - 
self._embedding_cache[cache_key] = embedding.copy() - - # 合并结果 - for idx, embedding in zip(uncached_indices, new_embeddings): - cached_embeddings.append((idx, embedding)) - - # 按原始顺序排序 - cached_embeddings.sort(key=lambda x: x[0]) - final_embeddings = np.array([emb for _, emb in cached_embeddings]) - - return final_embeddings - - def save_cache(self, cache_path: Optional[Union[str, Path]] = None) -> None: - """ - 保存缓存到磁盘 - - Args: - cache_path: 缓存文件路径(默认使用cache_dir/embeddings_cache.pkl) - """ - if cache_path is None: - if self.cache_dir is None: - raise ValueError("未指定缓存目录") - cache_path = self.cache_dir / "embeddings_cache.pkl" - - cache_path = Path(cache_path) - cache_path.parent.mkdir(parents=True, exist_ok=True) - - with self._cache_lock: - with open(cache_path, "wb") as f: - pickle.dump(self._embedding_cache, f) - - logger.info(f"缓存已保存: {cache_path} ({len(self._embedding_cache)} 条)") - - def load_cache(self, cache_path: Optional[Union[str, Path]] = None) -> None: - """ - 从磁盘加载缓存 - - Args: - cache_path: 缓存文件路径(默认使用cache_dir/embeddings_cache.pkl) - """ - if cache_path is None: - if self.cache_dir is None: - raise ValueError("未指定缓存目录") - cache_path = self.cache_dir / "embeddings_cache.pkl" - - cache_path = Path(cache_path) - if not cache_path.exists(): - logger.warning(f"缓存文件不存在: {cache_path}") - return - - with self._cache_lock: - with open(cache_path, "rb") as f: - self._embedding_cache = pickle.load(f) - - logger.info(f"缓存已加载: {cache_path} ({len(self._embedding_cache)} 条)") - - def clear_cache(self) -> None: - """清空缓存""" - with self._cache_lock: - count = len(self._embedding_cache) - self._embedding_cache.clear() - logger.info(f"已清空缓存: {count} 条") - - def check_model_consistency( - self, - stored_embeddings: np.ndarray, - sample_texts: List[str] = None, - ) -> Tuple[bool, str]: - """ - 检查模型一致性 - - Args: - stored_embeddings: 存储的嵌入向量 - sample_texts: 样本文本(用于重新生成对比) - - Returns: - (是否一致, 详细信息) - """ - # 检查维度 - if stored_embeddings.shape[1] != self.config.dimension: - return False, f"维度不匹配: 期望 {self.config.dimension}, 实际 {stored_embeddings.shape[1]}" - - # 如果提供了样本文本,重新生成并比较 - if sample_texts: - try: - new_embeddings = self.encode(sample_texts[:5]) # 只比较前5个 - - # 计算相似度 - similarities = np.dot( - stored_embeddings[:5], - new_embeddings.T, - ).diagonal() - - # 检查相似度 - if np.mean(similarities) < 0.95: - return False, f"模型可能已更改,平均相似度: {np.mean(similarities):.3f}" - - return True, f"模型一致,平均相似度: {np.mean(similarities):.3f}" - - except Exception as e: - return False, f"一致性检查失败: {e}" - - return True, "维度匹配" - - def get_model_info(self) -> Dict[str, Any]: - """ - 获取模型信息 - - Returns: - 模型信息字典 - """ - return { - "model_name": self.config.model_name, - "dimension": self.config.dimension, - "max_seq_length": self.config.max_seq_length, - "batch_size": self.config.batch_size, - "normalization": self.config.normalization, - "pooling": self.config.pooling, - "model_loaded": self._model is not None, - "cache_enabled": self.enable_cache, - "cache_size": len(self._embedding_cache), - "total_encoded": self._total_encoded, - "cache_hits": self._cache_hits, - "cache_misses": self._cache_misses, - } - - def get_embedding_dimension(self) -> int: - """获取嵌入维度""" - return self.config.dimension - - def _get_cache_key(self, text: str) -> str: - """ - 生成缓存键 - - Args: - text: 文本内容 - - Returns: - 缓存键(SHA256哈希) - """ - return hashlib.sha256(text.encode("utf-8")).hexdigest() - - @property - def is_model_loaded(self) -> bool: - """模型是否已加载""" - return self._model is not None - - @property - def cache_hit_rate(self) -> float: - 
"""缓存命中率""" - total = self._cache_hits + self._cache_misses - if total == 0: - return 0.0 - return self._cache_hits / total - - def __repr__(self) -> str: - return ( - f"EmbeddingManager(model={self.config.model_name}, " - f"dim={self.config.dimension}, " - f"loaded={self.is_model_loaded}, " - f"cache={len(self._embedding_cache)})" - ) - - - - -def create_embedding_manager_from_config( - model_name: str, - model_path: str, - dimension: int, - cache_dir: Optional[Union[str, Path]] = None, - enable_cache: bool = True, - num_workers: int = 1, - **config_kwargs, -) -> EmbeddingManager: - """ - 从自定义配置创建嵌入管理器 - - Args: - model_name: 模型名称 - model_path: HuggingFace模型路径 - dimension: 输出维度 - cache_dir: 缓存目录 - enable_cache: 是否启用缓存 - num_workers: 工作线程数 - **config_kwargs: 其他配置参数 - - Returns: - 嵌入管理器实例 - """ - # 创建自定义配置 - config = get_custom_config( - model_name=model_name, - model_path=model_path, - dimension=dimension, - cache_dir=cache_dir, - **config_kwargs, - ) - - # 创建管理器 - return EmbeddingManager( - config=config, - cache_dir=cache_dir, - enable_cache=enable_cache, - num_workers=num_workers, - ) diff --git a/plugins/A_memorix/core/embedding/presets.py b/plugins/A_memorix/core/embedding/presets.py deleted file mode 100644 index 54e6f8b4..00000000 --- a/plugins/A_memorix/core/embedding/presets.py +++ /dev/null @@ -1,72 +0,0 @@ -""" -嵌入模型配置模块 -""" - -from dataclasses import dataclass -from typing import Optional, Dict, Any, Union -from pathlib import Path - - -@dataclass -class EmbeddingModelConfig: - """ - 嵌入模型配置 - - 属性: - model_name: 模型描述名称 - model_path: 实际加载路径(Local or HF) - dimension: 嵌入向量维度 - max_seq_length: 最大序列长度 - batch_size: 编码批次大小 - model_size_mb: 估计显存占用 - description: 模型说明 - normalization: 是否自动归一化 - pooling: 池化策略 (mean, cls, max) - cache_dir: 模型缓存目录 - """ - - model_name: str - model_path: str - dimension: int - max_seq_length: int = 512 - batch_size: int = 32 - model_size_mb: int = 100 - description: str = "" - normalization: bool = True - pooling: str = "mean" - cache_dir: Optional[Union[str, Path]] = None - - -def validate_config_compatibility( - config1: EmbeddingModelConfig, config2: EmbeddingModelConfig -) -> bool: - """检查两个配置是否兼容(主要看维度)""" - return config1.dimension == config2.dimension - - -def are_models_compatible( - config1: EmbeddingModelConfig, config2: EmbeddingModelConfig -) -> bool: - """检查模型是否完全相同(用于热切换判断)""" - return ( - config1.model_path == config2.model_path - and config1.dimension == config2.dimension - and config1.pooling == config2.pooling - ) - - -def get_custom_config( - model_name: str, - model_path: str, - dimension: int, - cache_dir: Optional[Union[str, Path]] = None, - **kwargs, -) -> EmbeddingModelConfig: - """创建自定义模型配置""" - return EmbeddingModelConfig( - model_name=model_name, - model_path=model_path, - dimension=dimension, - cache_dir=cache_dir, - **kwargs, - ) diff --git a/plugins/A_memorix/core/retrieval/__init__.py b/plugins/A_memorix/core/retrieval/__init__.py deleted file mode 100644 index 6efce7f6..00000000 --- a/plugins/A_memorix/core/retrieval/__init__.py +++ /dev/null @@ -1,54 +0,0 @@ -"""检索模块 - 双路检索与排序""" - -from .dual_path import ( - DualPathRetriever, - RetrievalStrategy, - RetrievalResult, - DualPathRetrieverConfig, - TemporalQueryOptions, - FusionConfig, - RelationIntentConfig, -) -from .pagerank import ( - PersonalizedPageRank, - PageRankConfig, - create_ppr_from_graph, -) -from .threshold import ( - DynamicThresholdFilter, - ThresholdMethod, - ThresholdConfig, -) -from .sparse_bm25 import ( - SparseBM25Index, - SparseBM25Config, -) -from 
.graph_relation_recall import ( - GraphRelationRecallConfig, - GraphRelationRecallService, -) - -__all__ = [ - # DualPathRetriever - "DualPathRetriever", - "RetrievalStrategy", - "RetrievalResult", - "DualPathRetrieverConfig", - "TemporalQueryOptions", - "FusionConfig", - "RelationIntentConfig", - # PersonalizedPageRank - "PersonalizedPageRank", - "PageRankConfig", - "create_ppr_from_graph", - # DynamicThresholdFilter - "DynamicThresholdFilter", - "ThresholdMethod", - "ThresholdConfig", - # Sparse BM25 - "SparseBM25Index", - "SparseBM25Config", - # Graph relation recall - "GraphRelationRecallConfig", - "GraphRelationRecallService", -] diff --git a/plugins/A_memorix/core/retrieval/dual_path.py b/plugins/A_memorix/core/retrieval/dual_path.py deleted file mode 100644 index 6ed5e71a..00000000 --- a/plugins/A_memorix/core/retrieval/dual_path.py +++ /dev/null @@ -1,1796 +0,0 @@ -""" -双路检索器 - -同时检索关系和段落,实现知识图谱增强的检索。 -""" - -import asyncio -import re -from dataclasses import dataclass, field -from typing import Optional, List, Dict, Any, Tuple, Union -from enum import Enum - -import numpy as np - -from src.common.logger import get_logger -from ..storage import VectorStore, GraphStore, MetadataStore -from ..embedding import EmbeddingAPIAdapter -from ..utils.matcher import AhoCorasick -from ..utils.time_parser import format_timestamp -from .graph_relation_recall import GraphRelationRecallConfig, GraphRelationRecallService -from .pagerank import PersonalizedPageRank, PageRankConfig -from .sparse_bm25 import SparseBM25Config, SparseBM25Index - -logger = get_logger("A_Memorix.DualPathRetriever") - - -class RetrievalStrategy(Enum): - """检索策略""" - - PARA_ONLY = "paragraph_only" # 仅段落检索 - REL_ONLY = "relation_only" # 仅关系检索 - DUAL_PATH = "dual_path" # 双路检索(推荐) - - -@dataclass -class RetrievalResult: - """ - 检索结果 - - 属性: - hash_value: 哈希值 - content: 内容(段落或关系) - score: 相似度分数 - result_type: 结果类型(paragraph/relation) - source: 来源(paragraph_search/relation_search/fusion) - metadata: 额外元数据 - """ - - hash_value: str - content: str - score: float - result_type: str # "paragraph" or "relation" - source: str # "paragraph_search", "relation_search", "fusion" - metadata: Dict[str, Any] - - def to_dict(self) -> Dict[str, Any]: - """转换为字典""" - return { - "hash": self.hash_value, - "content": self.content, - "score": self.score, - "type": self.result_type, - "source": self.source, - "metadata": self.metadata, - } - - -@dataclass -class DualPathRetrieverConfig: - """ - 双路检索器配置 - - 属性: - top_k_paragraphs: 段落检索数量 - top_k_relations: 关系检索数量 - top_k_final: 最终返回数量 - alpha: 段落和关系的融合权重(0-1) - - 0: 仅使用关系分数 - - 1: 仅使用段落分数 - - 0.5: 平均融合 - enable_ppr: 是否启用PageRank重排序 - ppr_alpha: PageRank的alpha参数 - ppr_concurrency_limit: PPR计算的最大并发数 - enable_parallel: 是否并行检索 - retrieval_strategy: 检索策略 - debug: 是否启用调试模式(打印搜索结果原文) - """ - - top_k_paragraphs: int = 20 - top_k_relations: int = 10 - top_k_final: int = 10 - alpha: float = 0.5 # 融合权重 - enable_ppr: bool = True - ppr_alpha: float = 0.85 - ppr_timeout_seconds: float = 1.5 - ppr_concurrency_limit: int = 4 - enable_parallel: bool = True - retrieval_strategy: RetrievalStrategy = RetrievalStrategy.DUAL_PATH - debug: bool = False - sparse: SparseBM25Config = field(default_factory=SparseBM25Config) - fusion: "FusionConfig" = field(default_factory=lambda: FusionConfig()) - relation_intent: "RelationIntentConfig" = field(default_factory=lambda: RelationIntentConfig()) - graph_recall: GraphRelationRecallConfig = field(default_factory=GraphRelationRecallConfig) - - def __post_init__(self): - """验证配置""" - 
if isinstance(self.sparse, dict): - self.sparse = SparseBM25Config(**self.sparse) - if isinstance(self.fusion, dict): - self.fusion = FusionConfig(**self.fusion) - if isinstance(self.relation_intent, dict): - self.relation_intent = RelationIntentConfig(**self.relation_intent) - if isinstance(self.graph_recall, dict): - self.graph_recall = GraphRelationRecallConfig(**self.graph_recall) - - if not 0 <= self.alpha <= 1: - raise ValueError(f"alpha必须在[0, 1]之间: {self.alpha}") - - if self.top_k_paragraphs <= 0: - raise ValueError(f"top_k_paragraphs必须大于0: {self.top_k_paragraphs}") - - if self.top_k_relations <= 0: - raise ValueError(f"top_k_relations必须大于0: {self.top_k_relations}") - - if self.top_k_final <= 0: - raise ValueError(f"top_k_final必须大于0: {self.top_k_final}") - if self.ppr_timeout_seconds <= 0: - raise ValueError(f"ppr_timeout_seconds必须大于0: {self.ppr_timeout_seconds}") - - -@dataclass -class TemporalQueryOptions: - """时序查询选项。""" - - time_from: Optional[float] = None - time_to: Optional[float] = None - person: Optional[str] = None - source: Optional[str] = None - allow_created_fallback: bool = True - candidate_multiplier: int = 8 - max_scan: int = 1000 - - -@dataclass -class RelationIntentConfig: - """关系意图增强配置。""" - - enabled: bool = True - alpha_override: float = 0.35 - relation_candidate_multiplier: int = 4 - preserve_top_relations: int = 3 - force_relation_sparse: bool = True - pair_predicate_rerank_enabled: bool = True - pair_predicate_limit: int = 3 - - def __post_init__(self): - self.alpha_override = min(1.0, max(0.0, float(self.alpha_override))) - self.relation_candidate_multiplier = max(1, int(self.relation_candidate_multiplier)) - self.preserve_top_relations = max(0, int(self.preserve_top_relations)) - self.force_relation_sparse = bool(self.force_relation_sparse) - self.pair_predicate_rerank_enabled = bool(self.pair_predicate_rerank_enabled) - self.pair_predicate_limit = max(1, int(self.pair_predicate_limit)) - - -@dataclass -class FusionConfig: - """融合配置。""" - - method: str = "weighted_rrf" # weighted_rrf | alpha_legacy - rrf_k: int = 60 - vector_weight: float = 0.7 - bm25_weight: float = 0.3 - normalize_score: bool = True - normalize_method: str = "minmax" - - def __post_init__(self): - self.method = str(self.method or "weighted_rrf").strip().lower() - self.normalize_method = str(self.normalize_method or "minmax").strip().lower() - self.rrf_k = max(1, int(self.rrf_k)) - self.vector_weight = max(0.0, float(self.vector_weight)) - self.bm25_weight = max(0.0, float(self.bm25_weight)) - s = self.vector_weight + self.bm25_weight - if s <= 0: - self.vector_weight = 0.7 - self.bm25_weight = 0.3 - elif abs(s - 1.0) > 1e-8: - self.vector_weight /= s - self.bm25_weight /= s - - -class DualPathRetriever: - """ - 双路检索器 - - 功能: - - 并行检索段落和关系 - - 结果融合与排序 - - PageRank重排序 - - 实体识别与加权 - - 参数: - vector_store: 向量存储 - graph_store: 图存储 - metadata_store: 元数据存储 - embedding_manager: 嵌入管理器 - config: 检索配置 - """ - - def __init__( - self, - vector_store: VectorStore, - graph_store: GraphStore, - metadata_store: MetadataStore, - embedding_manager: EmbeddingAPIAdapter, - sparse_index: Optional[SparseBM25Index] = None, - config: Optional[DualPathRetrieverConfig] = None, - ): - """ - 初始化双路检索器 - - Args: - vector_store: 向量存储 - graph_store: 图存储 - metadata_store: 元数据存储 - embedding_manager: 嵌入管理器 - config: 检索配置 - """ - self.vector_store = vector_store - self.graph_store = graph_store - self.metadata_store = metadata_store - self.embedding_manager = embedding_manager - self.config = config or 
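
Note: FusionConfig.__post_init__ rescales the two path weights so they always sum to 1, falling back to the 0.7/0.3 default when both are non-positive. A sketch of that rule, assuming nothing beyond the code above:

def normalize_weights(vector_weight: float, bm25_weight: float) -> tuple:
    # Negative weights are clamped to 0; an all-zero pair falls back to the
    # 0.7/0.3 default; anything else is rescaled to sum to exactly 1.
    w_vec, w_bm25 = max(0.0, vector_weight), max(0.0, bm25_weight)
    s = w_vec + w_bm25
    if s <= 0:
        return 0.7, 0.3
    return w_vec / s, w_bm25 / s


assert normalize_weights(2.0, 1.0) == (2.0 / 3.0, 1.0 / 3.0)
assert normalize_weights(-1.0, 0.0) == (0.7, 0.3)
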
DualPathRetrieverConfig() - self.sparse_index = sparse_index - - # PageRank计算器 - ppr_config = PageRankConfig(alpha=self.config.ppr_alpha) - self._ppr = PersonalizedPageRank( - graph_store=graph_store, - config=ppr_config, - ) - self._ppr_semaphore = asyncio.Semaphore(self.config.ppr_concurrency_limit) - self._graph_relation_recall = GraphRelationRecallService( - graph_store=graph_store, - metadata_store=metadata_store, - config=self.config.graph_recall, - ) - - logger.info( - f"DualPathRetriever 初始化: " - f"strategy={self.config.retrieval_strategy.value}, " - f"top_k_para={self.config.top_k_paragraphs}, " - f"top_k_rel={self.config.top_k_relations}" - ) - - # 缓存 Aho-Corasick 匹配器 - self._ac_matcher: Optional[AhoCorasick] = None - self._ac_nodes_count = 0 - self._relation_intent_pattern = re.compile( - r"(什么关系|有哪些关系|和.+关系|关联|关系网|subject|predicate|object|" - r"relation|related|between.+and)", - re.IGNORECASE, - ) - - async def retrieve( - self, - query: str, - top_k: Optional[int] = None, - strategy: Optional[RetrievalStrategy] = None, - temporal: Optional[TemporalQueryOptions] = None, - ) -> List[RetrievalResult]: - """ - 执行检索(异步方法) - - Args: - query: 查询文本 - top_k: 返回结果数量(默认使用配置值) - strategy: 检索策略(默认使用配置值) - temporal: 时序查询选项(可选) - - Returns: - 检索结果列表 - """ - top_k = top_k or self.config.top_k_final - strategy = strategy or self.config.retrieval_strategy - relation_intent_ctx = self._build_relation_intent_context(query=query, top_k=top_k) - - logger.info( - "执行检索: " - f"query='{query[:50]}...', " - f"strategy={strategy.value}, " - f"relation_intent={relation_intent_ctx.get('enabled', False)}" - ) - - if temporal and not (query or "").strip(): - return self._retrieve_temporal_only(temporal, top_k) - - # 根据策略执行检索 - if strategy == RetrievalStrategy.PARA_ONLY: - results = await self._retrieve_paragraphs_only(query, top_k, temporal=temporal) - elif strategy == RetrievalStrategy.REL_ONLY: - results = await self._retrieve_relations_only(query, top_k, temporal=temporal) - else: # DUAL_PATH - results = await self._retrieve_dual_path( - query, - top_k, - temporal=temporal, - relation_intent=relation_intent_ctx, - ) - - logger.info(f"检索完成: 返回 {len(results)} 条结果") - - # 调试模式:打印结果原文 - if self.config.debug: - logger.info(f"[DEBUG] 检索结果内容原文:") - for i, res in enumerate(results): - logger.info(f" {i+1}. 
[{res.result_type}] (Score: {res.score:.4f}) {res.content}") - - return results - - def _is_relation_intent_query(self, query: str) -> bool: - q = str(query or "").strip() - if not q: - return False - if "|" in q or "->" in q: - return True - return self._relation_intent_pattern.search(q) is not None - - def _build_relation_intent_context(self, query: str, top_k: int) -> Dict[str, Any]: - cfg = self.config.relation_intent - enabled = bool(cfg.enabled) and self._is_relation_intent_query(query) - base_relation_k = max(1, int(self.config.top_k_relations)) - relation_top_k = max(base_relation_k, int(top_k)) - if enabled: - relation_top_k = max( - relation_top_k, - relation_top_k * int(cfg.relation_candidate_multiplier), - ) - return { - "enabled": enabled, - "alpha_override": float(cfg.alpha_override) if enabled else None, - "relation_top_k": int(relation_top_k), - "preserve_top_relations": int(cfg.preserve_top_relations) if enabled else 0, - "force_relation_sparse": bool(cfg.force_relation_sparse) if enabled else False, - "pair_predicate_rerank_enabled": bool(cfg.pair_predicate_rerank_enabled) if enabled else False, - "pair_predicate_limit": int(cfg.pair_predicate_limit) if enabled else 0, - } - - def _cap_temporal_scan_k( - self, - candidate_k: int, - temporal: Optional[TemporalQueryOptions], - ) -> int: - """对 temporal 模式候选召回数应用 max_scan 上限。""" - k = max(1, int(candidate_k)) - if temporal and temporal.max_scan and temporal.max_scan > 0: - k = min(k, int(temporal.max_scan)) - return max(1, k) - - def _is_valid_embedding(self, emb: Optional[np.ndarray]) -> bool: - if emb is None: - return False - arr = np.asarray(emb, dtype=np.float32) - if arr.ndim == 0 or arr.size == 0: - return False - return bool(np.all(np.isfinite(arr))) - - def _get_embedding_dim(self, emb: Optional[np.ndarray]) -> Optional[int]: - if emb is None: - return None - arr = np.asarray(emb) - if arr.ndim == 1: - return int(arr.shape[0]) if arr.size > 0 else None - if arr.ndim == 2: - if arr.shape[0] == 0: - return None - return int(arr.shape[1]) - return None - - def _is_embedding_dimension_compatible(self, emb: Optional[np.ndarray]) -> bool: - got_dim = self._get_embedding_dim(emb) - expected_dim = int(getattr(self.vector_store, "dimension", 0) or 0) - if got_dim is None or expected_dim <= 0: - return False - return got_dim == expected_dim - - def _is_embedding_ready_for_vector_search( - self, - emb: Optional[np.ndarray], - *, - stage: str, - ) -> bool: - if not self._is_valid_embedding(emb): - return False - if self._is_embedding_dimension_compatible(emb): - return True - - expected_dim = int(getattr(self.vector_store, "dimension", 0) or 0) - got_dim = self._get_embedding_dim(emb) - logger.warning( - "metric.embedding_dim_mismatch_fallback_count=1 " - f"stage={stage} expected_dim={expected_dim} got_dim={got_dim}" - ) - return False - - def _should_use_sparse( - self, - embedding_ok: bool, - vector_results: Optional[List[RetrievalResult]] = None, - ) -> bool: - if not self.config.sparse.enabled or self.sparse_index is None: - return False - - mode = self.config.sparse.mode - if mode == "hybrid": - return True - if mode == "fallback_only": - return not embedding_ok - # auto - if not embedding_ok: - return True - if not vector_results: - return True - best = max((float(r.score) for r in vector_results), default=0.0) - return best < 0.45 - - def _should_use_sparse_relations( - self, - embedding_ok: bool, - relation_results: Optional[List[RetrievalResult]] = None, - force_enable: bool = False, - ) -> bool: - if force_enable and 
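
Note: _should_use_sparse above implements a three-mode policy for when BM25 joins the recall. A condensed sketch of the decision table (the scores-only signature is a simplification of the RetrievalResult lists used in the real method):

def should_use_sparse(mode: str, embedding_ok: bool, vector_scores: list) -> bool:
    # "hybrid": always; "fallback_only": only when embeddings are unusable;
    # "auto": embeddings failed, vector recall is empty, or the best hit is weak.
    if mode == "hybrid":
        return True
    if mode == "fallback_only":
        return not embedding_ok
    if not embedding_ok or not vector_scores:
        return True
    return max(vector_scores) < 0.45


assert should_use_sparse("auto", True, [0.9]) is False   # strong vector hit: skip BM25
assert should_use_sparse("auto", True, [0.2]) is True    # weak hit: add BM25
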
self.config.sparse.enabled and self.sparse_index is not None: - return True - if not self.config.sparse.enable_relation_sparse_fallback: - return False - return self._should_use_sparse(embedding_ok, relation_results) - - def _normalize_scores_minmax(self, results: List[RetrievalResult]) -> None: - if not results: - return - vals = [float(r.score) for r in results] - lo = min(vals) - hi = max(vals) - if hi - lo < 1e-12: - for r in results: - r.score = 1.0 - return - for r in results: - r.score = (float(r.score) - lo) / (hi - lo) - - def _build_minmax_score_map(self, results: List[RetrievalResult]) -> Dict[str, float]: - if not results: - return {} - vals = [float(r.score) for r in results] - lo = min(vals) - hi = max(vals) - if hi - lo < 1e-12: - return {r.hash_value: 1.0 for r in results} - return { - r.hash_value: (float(r.score) - lo) / (hi - lo) - for r in results - } - - @staticmethod - def _clone_retrieval_result(item: RetrievalResult) -> RetrievalResult: - return RetrievalResult( - hash_value=item.hash_value, - content=item.content, - score=float(item.score), - result_type=item.result_type, - source=item.source, - metadata=dict(item.metadata or {}), - ) - - def _extract_graph_seed_entities(self, query: str, limit: int = 2) -> List[str]: - entities = self._extract_entities(query) - if not entities: - return [] - ranked = sorted( - entities.items(), - key=lambda x: (-float(x[1]), -len(str(x[0])), str(x[0]).lower()), - ) - return [str(name) for name, _ in ranked[: max(0, int(limit))]] - - def _search_relations_graph( - self, - query: str, - temporal: Optional[TemporalQueryOptions] = None, - ) -> List[RetrievalResult]: - service = getattr(self, "_graph_relation_recall", None) - if service is None or not bool(getattr(self.config.graph_recall, "enabled", True)): - return [] - - seed_entities = self._extract_graph_seed_entities(query, limit=2) - if not seed_entities: - return [] - - payloads = service.recall(seed_entities=seed_entities) - results: List[RetrievalResult] = [] - for payload in payloads: - meta = payload.to_payload() - results.append( - RetrievalResult( - hash_value=str(meta["hash"]), - content=str(meta["content"]), - score=0.0, - result_type="relation", - source="graph_relation_recall", - metadata={ - "subject": meta["subject"], - "predicate": meta["predicate"], - "object": meta["object"], - "confidence": float(meta["confidence"]), - "graph_seed_entities": list(meta["graph_seed_entities"]), - "graph_hops": int(meta["graph_hops"]), - "graph_candidate_type": str(meta["graph_candidate_type"]), - "supporting_paragraph_count": int(meta["supporting_paragraph_count"]), - }, - ) - ) - return self._apply_temporal_filter_to_relations(results, temporal) - - def _fuse_ranked_lists_weighted_rrf( - self, - vector_results: List[RetrievalResult], - sparse_results: List[RetrievalResult], - ) -> List[RetrievalResult]: - """按 weighted RRF 融合两路段落召回。""" - if not vector_results: - out = sparse_results[:] - if self.config.fusion.normalize_score: - self._normalize_scores_minmax(out) - return out - if not sparse_results: - out = vector_results[:] - if self.config.fusion.normalize_score: - self._normalize_scores_minmax(out) - return out - - k = self.config.fusion.rrf_k - w_vec = self.config.fusion.vector_weight - w_sparse = self.config.fusion.bm25_weight - merged: Dict[str, RetrievalResult] = {} - score_map: Dict[str, float] = {} - - for rank, item in enumerate(vector_results, start=1): - h = item.hash_value - if h not in merged: - merged[h] = item - merged[h].source = "fusion_rrf" - score_map[h] = 
score_map.get(h, 0.0) + w_vec * (1.0 / (k + rank)) - - for rank, item in enumerate(sparse_results, start=1): - h = item.hash_value - if h not in merged: - merged[h] = item - merged[h].source = "fusion_rrf" - score_map[h] = score_map.get(h, 0.0) + w_sparse * (1.0 / (k + rank)) - - out = list(merged.values()) - for item in out: - item.score = float(score_map.get(item.hash_value, 0.0)) - - out.sort(key=lambda x: x.score, reverse=True) - if self.config.fusion.normalize_score and self.config.fusion.normalize_method == "minmax": - self._normalize_scores_minmax(out) - return out - - def _search_paragraphs_sparse( - self, - query: str, - top_k: int, - temporal: Optional[TemporalQueryOptions] = None, - ) -> List[RetrievalResult]: - """BM25 段落召回。""" - if not self.sparse_index or not self.config.sparse.enabled: - return [] - - candidate_k = max(top_k, self.config.sparse.candidate_k) - candidate_k = self._cap_temporal_scan_k(candidate_k, temporal) - sparse_rows = self.sparse_index.search(query=query, k=candidate_k) - results: List[RetrievalResult] = [] - for row in sparse_rows: - hash_value = row["hash"] - paragraph = self.metadata_store.get_paragraph(hash_value) - if paragraph is None: - continue - time_meta = self._build_time_meta_from_paragraph(paragraph, temporal=temporal) - results.append( - RetrievalResult( - hash_value=hash_value, - content=paragraph["content"], - score=float(row.get("score", 0.0)), - result_type="paragraph", - source="sparse_bm25", - metadata={ - "word_count": paragraph.get("word_count", 0), - "time_meta": time_meta, - "bm25_score": float(row.get("bm25_score", 0.0)), - }, - ) - ) - results = self._apply_temporal_filter_to_paragraphs(results, temporal) - if self.config.fusion.normalize_score and self.config.fusion.normalize_method == "minmax": - self._normalize_scores_minmax(results) - return results - - def _search_relations_sparse( - self, - query: str, - top_k: int, - temporal: Optional[TemporalQueryOptions] = None, - ) -> List[RetrievalResult]: - """关系 BM25 召回。""" - if not self.sparse_index or not self.config.sparse.enabled: - return [] - if not self.config.sparse.enable_relation_sparse_fallback: - return [] - - candidate_k = max(top_k, self.config.sparse.relation_candidate_k) - candidate_k = self._cap_temporal_scan_k(candidate_k, temporal) - rows = self.sparse_index.search_relations(query=query, k=candidate_k) - results: List[RetrievalResult] = [] - for row in rows: - hash_value = row["hash"] - relation = self.metadata_store.get_relation(hash_value) - if relation is None: - continue - - relation_time_meta = None - if temporal: - relation_time_meta = self._best_supporting_time_meta(hash_value, temporal) - if relation_time_meta is None: - continue - - content = f"{relation['subject']} {relation['predicate']} {relation['object']}" - results.append( - RetrievalResult( - hash_value=hash_value, - content=content, - score=float(row.get("score", 0.0)), - result_type="relation", - source="sparse_relation_bm25", - metadata={ - "subject": relation["subject"], - "predicate": relation["predicate"], - "object": relation["object"], - "confidence": relation.get("confidence", 1.0), - "time_meta": relation_time_meta, - "bm25_score": float(row.get("bm25_score", 0.0)), - }, - ) - ) - - if self.config.fusion.normalize_score and self.config.fusion.normalize_method == "minmax": - self._normalize_scores_minmax(results) - return self._apply_temporal_filter_to_relations(results, temporal) - - def _merge_relation_results( - self, - vector_results: List[RetrievalResult], - sparse_results: 
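
Note: weighted RRF scores a document by its rank in each list rather than by raw score: score(d) = w_vec / (k + rank_vec(d)) + w_bm25 / (k + rank_bm25(d)). A standalone sketch with hypothetical doc ids and the defaults from FusionConfig:

from typing import Dict, List


def weighted_rrf(vector_ranked: List[str], sparse_ranked: List[str],
                 k: int = 60, w_vec: float = 0.7, w_sparse: float = 0.3) -> List[tuple]:
    # Ranks are 1-based; a document missing from one list simply gets no
    # contribution from that path.
    scores: Dict[str, float] = {}
    for rank, doc in enumerate(vector_ranked, start=1):
        scores[doc] = scores.get(doc, 0.0) + w_vec / (k + rank)
    for rank, doc in enumerate(sparse_ranked, start=1):
        scores[doc] = scores.get(doc, 0.0) + w_sparse / (k + rank)
    return sorted(scores.items(), key=lambda kv: kv[1], reverse=True)


print(weighted_rrf(["p1", "p2", "p3"], ["p2", "p4"]))
# p2 collects contributions from both paths and overtakes p1.
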
List[RetrievalResult], - ) -> List[RetrievalResult]: - """合并关系候选,按 hash 去重并保留更高分。""" - merged: Dict[str, RetrievalResult] = {} - for item in vector_results: - merged[item.hash_value] = item - for item in sparse_results: - old = merged.get(item.hash_value) - if old is None or float(item.score) > float(old.score): - merged[item.hash_value] = item - elif old is not None and old.source != item.source: - old.source = "relation_fusion" - out = list(merged.values()) - out.sort(key=lambda x: x.score, reverse=True) - return out - - def _merge_relation_results_graph_enhanced( - self, - vector_results: List[RetrievalResult], - sparse_results: List[RetrievalResult], - graph_results: List[RetrievalResult], - ) -> List[RetrievalResult]: - """Graph-aware relation fusion with semantic + graph + evidence scoring.""" - vector_norm = self._build_minmax_score_map(vector_results) - sparse_norm = self._build_minmax_score_map(sparse_results) - graph_score_map = { - "direct_pair": 1.0, - "one_hop_seed": 0.75, - "two_hop_pair": 0.55, - } - - merged: Dict[str, RetrievalResult] = {} - source_sets: Dict[str, set[str]] = {} - support_cache: Dict[str, int] = {} - - for group in (vector_results, sparse_results, graph_results): - for item in group: - existing = merged.get(item.hash_value) - if existing is None: - existing = self._clone_retrieval_result(item) - merged[item.hash_value] = existing - else: - for key, value in dict(item.metadata or {}).items(): - if key not in existing.metadata or existing.metadata.get(key) in (None, "", []): - existing.metadata[key] = value - source_sets.setdefault(item.hash_value, set()).add(str(item.source or "").strip() or "relation_search") - - out = list(merged.values()) - for item in out: - meta = item.metadata if isinstance(item.metadata, dict) else {} - semantic_norm = max( - float(vector_norm.get(item.hash_value, 0.0)), - float(sparse_norm.get(item.hash_value, 0.0)), - ) - graph_candidate_type = str(meta.get("graph_candidate_type", "") or "") - graph_score = float(graph_score_map.get(graph_candidate_type, 0.0)) - - if item.hash_value not in support_cache: - cached = meta.get("supporting_paragraph_count") - if cached is None: - support_cache[item.hash_value] = len( - self.metadata_store.get_paragraphs_by_relation(item.hash_value) - ) - else: - support_cache[item.hash_value] = max(0, int(cached)) - supporting_paragraph_count = support_cache[item.hash_value] - evidence_score = min(1.0, supporting_paragraph_count / 3.0) - - meta["supporting_paragraph_count"] = supporting_paragraph_count - meta["graph_seed_entities"] = list(meta.get("graph_seed_entities") or []) - if "graph_hops" in meta: - meta["graph_hops"] = int(meta.get("graph_hops") or 0) - item.score = 0.60 * semantic_norm + 0.30 * graph_score + 0.10 * evidence_score - - sources = source_sets.get(item.hash_value, set()) - if len(sources) > 1: - item.source = "relation_fusion" - elif sources: - item.source = next(iter(sources)) - - out.sort(key=lambda x: x.score, reverse=True) - return out - - async def _retrieve_paragraphs_only( - self, - query: str, - top_k: int, - temporal: Optional[TemporalQueryOptions] = None, - ) -> List[RetrievalResult]: - """ - 仅检索段落(异步方法) - - Args: - query: 查询文本 - top_k: 返回数量 - - Returns: - 检索结果列表 - """ - query_emb = None - embedding_ok = False - vector_results: List[RetrievalResult] = [] - - try: - query_emb = await self.embedding_manager.encode(query) - embedding_ok = self._is_embedding_ready_for_vector_search( - query_emb, - stage="paragraph_only", - ) - except Exception as e: - logger.warning(f"段落检索 
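
Note: the graph-enhanced merge above blends three normalized signals with fixed weights: 0.60 semantic similarity, 0.30 graph proximity (direct pair > one-hop seed > two-hop pair), and 0.10 textual evidence, where evidence saturates at three supporting paragraphs. As a worked example:

GRAPH_SCORE = {"direct_pair": 1.0, "one_hop_seed": 0.75, "two_hop_pair": 0.55}


def blended_relation_score(semantic_norm: float, candidate_type: str,
                           supporting_paragraphs: int) -> float:
    graph_score = GRAPH_SCORE.get(candidate_type, 0.0)
    evidence = min(1.0, supporting_paragraphs / 3.0)   # saturates at 3 paragraphs
    return 0.60 * semantic_norm + 0.30 * graph_score + 0.10 * evidence


# A direct-pair relation with 0.8 semantic score and 2 supporting paragraphs:
# 0.60 * 0.8 + 0.30 * 1.0 + 0.10 * (2 / 3) = 0.8467
print(blended_relation_score(0.8, "direct_pair", 2))
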
embedding 生成失败,将尝试 sparse 回退: {e}") - - if embedding_ok: - multiplier = max(1, temporal.candidate_multiplier) if temporal else 1 - candidate_k = self._cap_temporal_scan_k(top_k * 2 * multiplier, temporal) - para_ids, para_scores = self.vector_store.search( - query_emb, # type: ignore[arg-type] - k=candidate_k, - ) - - for hash_value, score in zip(para_ids, para_scores): - paragraph = self.metadata_store.get_paragraph(hash_value) - if paragraph is None: - continue - time_meta = self._build_time_meta_from_paragraph(paragraph, temporal=temporal) - vector_results.append( - RetrievalResult( - hash_value=hash_value, - content=paragraph["content"], - score=float(score), - result_type="paragraph", - source="paragraph_search", - metadata={ - "word_count": paragraph.get("word_count", 0), - "time_meta": time_meta, - }, - ) - ) - vector_results = self._apply_temporal_filter_to_paragraphs(vector_results, temporal) - - sparse_results: List[RetrievalResult] = [] - if self._should_use_sparse(embedding_ok, vector_results): - sparse_results = self._search_paragraphs_sparse(query, top_k, temporal=temporal) - - if self.config.fusion.method == "weighted_rrf" and (vector_results and sparse_results): - results = self._fuse_ranked_lists_weighted_rrf(vector_results, sparse_results) - elif vector_results and sparse_results: - results = vector_results + sparse_results - results.sort(key=lambda x: x.score, reverse=True) - else: - results = vector_results if vector_results else sparse_results - - return results[:top_k] - - async def _retrieve_relations_only( - self, - query: str, - top_k: int, - temporal: Optional[TemporalQueryOptions] = None, - ) -> List[RetrievalResult]: - """ - 仅检索关系 (通过实体枢纽 Entity-Pivot) - - 策略: - 1. 检索向量库中的 Top-K 实体 (Entity) - 2. 通过图结构/元数据扩展出与实体关联的关系 (Relation) - 3. 以实体相似度作为基础分返回关系 - - Args: - query: 查询文本 - top_k: 返回数量 - - Returns: - 检索结果列表 - """ - query_emb = None - embedding_ok = False - vector_results: List[RetrievalResult] = [] - try: - query_emb = await self.embedding_manager.encode(query) - embedding_ok = self._is_embedding_ready_for_vector_search( - query_emb, - stage="relation_only", - ) - except Exception as e: - logger.warning(f"关系检索 embedding 生成失败,将尝试 sparse 回退: {e}") - - if embedding_ok: - # 1. 
检索向量 (混合了段落和实体,所以扩大检索范围以召回足够多实体) - multiplier = max(1, temporal.candidate_multiplier) if temporal else 1 - candidate_k = self._cap_temporal_scan_k(top_k * 3 * multiplier, temporal) - ids, scores = self.vector_store.search( - query_emb, # type: ignore[arg-type] - k=candidate_k, - ) - - seen_relations = set() - for hash_value, score in zip(ids, scores): - entity = self.metadata_store.get_entity(hash_value) - if not entity: - continue - entity_name = entity["name"] - - related_rels = [] - related_rels.extend(self.metadata_store.get_relations(subject=entity_name)) - related_rels.extend(self.metadata_store.get_relations(object=entity_name)) - - for rel in related_rels: - if rel["hash"] in seen_relations: - continue - seen_relations.add(rel["hash"]) - - relation_time_meta = None - if temporal: - relation_time_meta = self._best_supporting_time_meta(rel["hash"], temporal) - if relation_time_meta is None: - continue - - content = f"{rel['subject']} {rel['predicate']} {rel['object']}" - vector_results.append( - RetrievalResult( - hash_value=rel["hash"], - content=content, - score=float(score), - result_type="relation", - source="relation_search (via entity)", - metadata={ - "subject": rel["subject"], - "predicate": rel["predicate"], - "object": rel["object"], - "confidence": rel.get("confidence", 1.0), - "pivot_entity": entity_name, - "time_meta": relation_time_meta, - }, - ) - ) - - vector_results = self._apply_temporal_filter_to_relations(vector_results, temporal) - - sparse_results: List[RetrievalResult] = [] - if self._should_use_sparse_relations(embedding_ok, vector_results): - sparse_results = self._search_relations_sparse(query=query, top_k=top_k, temporal=temporal) - - graph_results = self._search_relations_graph(query=query, temporal=temporal) - if graph_results: - results = self._merge_relation_results_graph_enhanced( - vector_results, - sparse_results, - graph_results, - ) - elif vector_results and sparse_results: - results = self._merge_relation_results(vector_results, sparse_results) - else: - results = vector_results if vector_results else sparse_results - - return results[:top_k] - - async def _retrieve_dual_path( - self, - query: str, - top_k: int, - temporal: Optional[TemporalQueryOptions] = None, - relation_intent: Optional[Dict[str, Any]] = None, - ) -> List[RetrievalResult]: - """ - 双路检索(段落+关系)(异步方法) - - Args: - query: 查询文本 - top_k: 返回数量 - - Returns: - 融合后的检索结果列表 - """ - query_emb = None - embedding_ok = False - relation_intent = relation_intent or {} - relation_top_k = max( - 1, - int(relation_intent.get("relation_top_k", self.config.top_k_relations)), - ) - force_relation_sparse = bool(relation_intent.get("force_relation_sparse", False)) - preserve_top_relations = max( - 0, - int(relation_intent.get("preserve_top_relations", 0)), - ) - pair_predicate_rerank_enabled = bool( - relation_intent.get("pair_predicate_rerank_enabled", False) - ) - pair_predicate_limit = max( - 1, - int( - relation_intent.get( - "pair_predicate_limit", - self.config.relation_intent.pair_predicate_limit, - ) - ), - ) - alpha_override = relation_intent.get("alpha_override") - try: - query_emb = await self.embedding_manager.encode(query) - embedding_ok = self._is_embedding_ready_for_vector_search( - query_emb, - stage="dual_path", - ) - except Exception as e: - logger.warning(f"双路检索 embedding 生成失败,将尝试 sparse 回退: {e}") - - para_results: List[RetrievalResult] = [] - rel_results: List[RetrievalResult] = [] - if embedding_ok: - # 并行检索(使用 asyncio) - if self.config.enable_parallel: - para_results, rel_results = 
await self._parallel_retrieve( - query_emb, - temporal=temporal, - relation_top_k=relation_top_k, - ) # type: ignore[arg-type] - else: - para_results, rel_results = self._sequential_retrieve( - query_emb, - temporal=temporal, - relation_top_k=relation_top_k, - ) # type: ignore[arg-type] - else: - logger.warning("embedding 不可用,跳过向量段落/关系召回") - - sparse_para_results: List[RetrievalResult] = [] - if self._should_use_sparse(embedding_ok, para_results): - sparse_para_results = self._search_paragraphs_sparse( - query=query, - top_k=max(top_k * 2, self.config.sparse.candidate_k), - temporal=temporal, - ) - sparse_rel_results: List[RetrievalResult] = [] - if self._should_use_sparse_relations( - embedding_ok, - rel_results, - force_enable=force_relation_sparse, - ): - sparse_rel_results = self._search_relations_sparse( - query=query, - top_k=max( - top_k, - self.config.sparse.relation_candidate_k, - relation_top_k, - ), - temporal=temporal, - ) - - graph_rel_results: List[RetrievalResult] = [] - if bool(relation_intent.get("enabled", False)): - graph_rel_results = self._search_relations_graph(query=query, temporal=temporal) - - if self.config.fusion.method == "weighted_rrf" and para_results and sparse_para_results: - para_results = self._fuse_ranked_lists_weighted_rrf(para_results, sparse_para_results) - elif para_results and sparse_para_results: - para_results = para_results + sparse_para_results - para_results.sort(key=lambda x: x.score, reverse=True) - elif sparse_para_results and (not para_results or not embedding_ok): - para_results = sparse_para_results - - if graph_rel_results: - rel_results = self._merge_relation_results_graph_enhanced( - rel_results, - sparse_rel_results, - graph_rel_results, - ) - elif rel_results and sparse_rel_results: - rel_results = self._merge_relation_results(rel_results, sparse_rel_results) - elif sparse_rel_results and (not rel_results or not embedding_ok): - rel_results = sparse_rel_results - - # 融合结果 - fused_results = self._fuse_results( - para_results, - rel_results, - query_emb, - alpha_override=alpha_override, - preserve_top_relations=preserve_top_relations, - ) - - # PageRank重排序 - if self.config.enable_ppr: - fused_results = await self._rerank_with_ppr( - fused_results, - query, - ) - - if temporal: - fused_results = self._sort_results_with_temporal(fused_results, temporal) - - fused_results = self._apply_relation_intent_pair_rerank( - fused_results, - enabled=bool(relation_intent.get("enabled", False)), - pair_rerank_enabled=pair_predicate_rerank_enabled, - pair_limit=pair_predicate_limit, - ) - - return fused_results[:top_k] - - async def _parallel_retrieve( - self, - query_emb: np.ndarray, - temporal: Optional[TemporalQueryOptions] = None, - relation_top_k: Optional[int] = None, - ) -> Tuple[List[RetrievalResult], List[RetrievalResult]]: - """ - 并行检索段落和关系(异步方法) - - Args: - query_emb: 查询嵌入 - - Returns: - (段落结果, 关系结果) - """ - # 使用 asyncio.gather 并发执行两个搜索任务 - # 由于 _search_paragraphs 和 _search_relations 是 CPU 密集型同步函数, - # 使用 asyncio.to_thread 在线程池中执行 - try: - para_task = asyncio.to_thread( - self._search_paragraphs, - query_emb, - self.config.top_k_paragraphs, - temporal, - ) - rel_task = asyncio.to_thread( - self._search_relations, - query_emb, - relation_top_k if relation_top_k is not None else self.config.top_k_relations, - temporal, - ) - - para_results, rel_results = await asyncio.gather( - para_task, rel_task, return_exceptions=True - ) - - # 处理异常 - if isinstance(para_results, Exception): - logger.error(f"段落检索失败: {para_results}") - para_results = [] - if 
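
Note: because both store lookups are blocking, the dual path runs them through asyncio.to_thread and gathers with return_exceptions=True, so one failing path degrades to an empty list instead of failing the whole query. A minimal runnable sketch; blocking_search is a hypothetical stand-in for the store calls:

import asyncio
import time


def blocking_search(name: str, delay: float) -> list:
    time.sleep(delay)                  # stands in for a synchronous store lookup
    return [f"{name}-hit"]


async def parallel_retrieve() -> tuple:
    para_task = asyncio.to_thread(blocking_search, "paragraph", 0.2)
    rel_task = asyncio.to_thread(blocking_search, "relation", 0.2)
    para, rel = await asyncio.gather(para_task, rel_task, return_exceptions=True)
    para = [] if isinstance(para, Exception) else para   # degrade per path
    rel = [] if isinstance(rel, Exception) else rel
    return para, rel


print(asyncio.run(parallel_retrieve()))   # both finish in ~0.2s, not 0.4s
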
isinstance(rel_results, Exception): - logger.error(f"关系检索失败: {rel_results}") - rel_results = [] - - return para_results, rel_results - - except Exception as e: - logger.error(f"并行检索失败: {e}") - return [], [] - - def _sequential_retrieve( - self, - query_emb: np.ndarray, - temporal: Optional[TemporalQueryOptions] = None, - relation_top_k: Optional[int] = None, - ) -> Tuple[List[RetrievalResult], List[RetrievalResult]]: - """ - 顺序检索段落和关系 - - Args: - query_emb: 查询嵌入 - - Returns: - (段落结果, 关系结果) - """ - para_results = self._search_paragraphs( - query_emb, - self.config.top_k_paragraphs, - temporal, - ) - - rel_results = self._search_relations( - query_emb, - relation_top_k if relation_top_k is not None else self.config.top_k_relations, - temporal, - ) - - return para_results, rel_results - - def _search_paragraphs( - self, - query_emb: np.ndarray, - top_k: int, - temporal: Optional[TemporalQueryOptions] = None, - ) -> List[RetrievalResult]: - """ - 搜索段落 - - Args: - query_emb: 查询嵌入 - top_k: 返回数量 - - Returns: - 段落结果列表 - """ - multiplier = max(1, temporal.candidate_multiplier) if temporal else 1 - candidate_k = self._cap_temporal_scan_k(top_k * multiplier, temporal) - para_ids, para_scores = self.vector_store.search(query_emb, k=candidate_k) - - results = [] - for hash_value, score in zip(para_ids, para_scores): - paragraph = self.metadata_store.get_paragraph(hash_value) - if paragraph is None: - continue - - time_meta = self._build_time_meta_from_paragraph( - paragraph, - temporal=temporal, - ) - results.append(RetrievalResult( - hash_value=hash_value, - content=paragraph["content"], - score=float(score), - result_type="paragraph", - source="paragraph_search", - metadata={ - "word_count": paragraph.get("word_count", 0), - "time_meta": time_meta, - }, - )) - - return self._apply_temporal_filter_to_paragraphs(results, temporal) - - def _search_relations( - self, - query_emb: np.ndarray, - top_k: int, - temporal: Optional[TemporalQueryOptions] = None, - ) -> List[RetrievalResult]: - """ - 搜索关系 - - Args: - query_emb: 查询嵌入 - top_k: 返回数量 - - Returns: - 关系结果列表 - """ - multiplier = max(1, temporal.candidate_multiplier) if temporal else 1 - candidate_k = self._cap_temporal_scan_k(top_k * multiplier, temporal) - rel_ids, rel_scores = self.vector_store.search(query_emb, k=candidate_k) - - results = [] - for hash_value, score in zip(rel_ids, rel_scores): - relation = self.metadata_store.get_relation(hash_value) - if relation is None: - continue - - relation_time_meta = None - if temporal: - relation_time_meta = self._best_supporting_time_meta(hash_value, temporal) - if relation_time_meta is None: - continue - - content = f"{relation['subject']} {relation['predicate']} {relation['object']}" - - results.append(RetrievalResult( - hash_value=hash_value, - content=content, - score=float(score), - result_type="relation", - source="relation_search", - metadata={ - "subject": relation["subject"], - "predicate": relation["predicate"], - "object": relation["object"], - "confidence": relation.get("confidence", 1.0), - "time_meta": relation_time_meta, - }, - )) - - return self._apply_temporal_filter_to_relations(results, temporal) - - def _fuse_results( - self, - para_results: List[RetrievalResult], - rel_results: List[RetrievalResult], - query_emb: Optional[np.ndarray] = None, - alpha_override: Optional[float] = None, - preserve_top_relations: int = 0, - ) -> List[RetrievalResult]: - """ - 融合段落和关系结果 - - 融合策略: - 1. 计算加权分数 - 2. 去重(基于段落和关系的关联) - 3. 
排序 - - Args: - para_results: 段落结果 - rel_results: 关系结果 - query_emb: 查询嵌入(兼容参数,当前未使用) - - Returns: - 融合后的结果列表 - """ - del query_emb # 参数保留用于兼容 - alpha = float(alpha_override) if alpha_override is not None else self.config.alpha - - # 为段落结果计算加权分数 - for result in para_results: - result.score = result.score * alpha - result.source = "fusion" - - # 为关系结果计算加权分数 - for result in rel_results: - result.score = result.score * (1 - alpha) - result.source = "fusion" - - preserve_top_relations = max(0, int(preserve_top_relations)) - preserved_relation_hashes = set() - if preserve_top_relations > 0 and rel_results: - rel_ranked = sorted(rel_results, key=lambda x: x.score, reverse=True) - preserved_relation_hashes = { - item.hash_value for item in rel_ranked[:preserve_top_relations] - } - - # 合并结果 - all_results = para_results + rel_results - all_results.sort(key=lambda x: x.score, reverse=True) - - # 去重:如果段落有关联的关系,只保留分数更高的 - seen_paragraphs = set() - seen_items = set() - deduplicated_results = [] - - for result in all_results: - if result.hash_value in seen_items: - continue - if result.result_type == "paragraph": - hash_val = result.hash_value - if hash_val not in seen_paragraphs: - seen_paragraphs.add(hash_val) - seen_items.add(hash_val) - deduplicated_results.append(result) - else: # relation - if result.hash_value in preserved_relation_hashes: - seen_items.add(result.hash_value) - deduplicated_results.append(result) - continue - # 检查关系关联的段落是否已存在 - relation = self.metadata_store.get_relation(result.hash_value) - if relation: - # 获取关联的段落 - para_rels = self.metadata_store.query(""" - SELECT paragraph_hash FROM paragraph_relations - WHERE relation_hash = ? - """, (result.hash_value,)) - - if para_rels: - # 检查段落是否已在结果中 - for para_rel in para_rels: - if para_rel["paragraph_hash"] in seen_paragraphs: - # 段落已存在,跳过此关系 - break - else: - # 所有段落都不存在,添加关系 - seen_items.add(result.hash_value) - deduplicated_results.append(result) - else: - # 没有关联段落,直接添加 - seen_items.add(result.hash_value) - deduplicated_results.append(result) - else: - seen_items.add(result.hash_value) - deduplicated_results.append(result) - - # 按分数排序 - deduplicated_results.sort(key=lambda x: x.score, reverse=True) - - return deduplicated_results - - def _apply_relation_intent_pair_rerank( - self, - results: List[RetrievalResult], - *, - enabled: bool, - pair_rerank_enabled: bool, - pair_limit: int, - ) -> List[RetrievalResult]: - """仅在 relation-intent 下对关系项执行同主客体多谓词重排。""" - if not enabled or not pair_rerank_enabled: - return results - return self._rerank_relation_items_by_pair(results, pair_limit=pair_limit) - - def _rerank_relation_items_by_pair( - self, - results: List[RetrievalResult], - pair_limit: int, - ) -> List[RetrievalResult]: - """ - 同主客体多谓词重排: - 1. 关系项按 (subject, object) 分组 - 2. 组内按分数降序 + 原始位置升序 - 3. 组间按组最高分降序 + 组最早位置升序 - 4. 先拼接每组前 N 条,再拼接每组 overflow 条目 - 5. 
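
Note: the legacy fusion step above is a straight convex combination: paragraph scores are scaled by alpha, relation scores by (1 - alpha), then both lists are merged and sorted. With the relation-intent override (alpha_override = 0.35) a mid-scoring relation can outrank a stronger paragraph, as this small sketch shows:

def alpha_fuse(para_scored, rel_scored, alpha=0.5):
    fused = [(h, s * alpha, "paragraph") for h, s in para_scored]
    fused += [(h, s * (1 - alpha), "relation") for h, s in rel_scored]
    fused.sort(key=lambda x: x[1], reverse=True)
    return fused


print(alpha_fuse([("p1", 0.9)], [("r1", 0.7)], alpha=0.35))
# r1 -> 0.7 * 0.65 = 0.455 now outranks p1 -> 0.9 * 0.35 = 0.315
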
回填到原关系槽位,段落槽位不变 - """ - if len(results) <= 1: - return results - - relation_positions: List[int] = [] - relation_items: List[Tuple[int, RetrievalResult]] = [] - for idx, item in enumerate(results): - if item.result_type == "relation": - relation_positions.append(idx) - relation_items.append((idx, item)) - - if len(relation_items) <= 1: - return results - - pair_limit = max(1, int(pair_limit)) - - grouped: Dict[Tuple[str, str], List[Tuple[int, RetrievalResult]]] = {} - for original_idx, item in relation_items: - metadata = item.metadata if isinstance(item.metadata, dict) else {} - subject = str(metadata.get("subject", "")).strip().lower() - obj = str(metadata.get("object", "")).strip().lower() - if subject and obj: - key = (subject, obj) - else: - key = ("__missing__", item.hash_value) - grouped.setdefault(key, []).append((original_idx, item)) - - for grouped_items in grouped.values(): - grouped_items.sort(key=lambda x: (-float(x[1].score), x[0])) - - ordered_groups = sorted( - grouped.values(), - key=lambda grouped_items: ( - -float(grouped_items[0][1].score), - grouped_items[0][0], - ), - ) - - prioritized: List[RetrievalResult] = [] - overflow: List[RetrievalResult] = [] - for grouped_items in ordered_groups: - prioritized.extend([item for _, item in grouped_items[:pair_limit]]) - overflow.extend([item for _, item in grouped_items[pair_limit:]]) - - reordered_relations = prioritized + overflow - if len(reordered_relations) != len(relation_items): - return results - - logger.debug( - "relation_rerank_applied=1 " - f"relation_pair_groups={len(ordered_groups)} " - f"relation_pair_overflow_count={len(overflow)} " - f"relation_pair_limit={pair_limit}" - ) - - rebuilt = list(results) - for slot_idx, relation_item in zip(relation_positions, reordered_relations): - rebuilt[slot_idx] = relation_item - return rebuilt - - async def _rerank_with_ppr( - self, - results: List[RetrievalResult], - query: str, - ) -> List[RetrievalResult]: - """ - 使用PageRank重排序结果 (异步 + 线程池) - - Args: - results: 检索结果 - query: 查询文本 - - Returns: - 重排序后的结果 - """ - # 从查询中提取实体 - entities = self._extract_entities(query) - - if not entities: - logger.debug("未识别到实体,跳过PPR重排序") - return results - - # 计算PPR分数 (放入线程池运行,避免阻塞主循环) - ppr_timeout_s = max(0.1, float(getattr(self.config, "ppr_timeout_seconds", 1.5) or 1.5)) - try: - async with self._ppr_semaphore: - ppr_scores = await asyncio.wait_for( - asyncio.to_thread( - self._ppr.compute, - personalization=entities, - normalize=True, - ), - timeout=ppr_timeout_s, - ) - except asyncio.TimeoutError: - logger.warning( - "metric.ppr_timeout_skip_count=1 " - f"timeout_s={ppr_timeout_s} " - f"entities={len(entities)}" - ) - return results - except Exception as e: - logger.warning(f"PPR 重排序失败,回退原排序: {e}") - return results - - # 调整结果分数 - ppr_scores_by_name = { - str(name).strip().lower(): float(score) - for name, score in ppr_scores.items() - } - for result in results: - if result.result_type == "paragraph": - # 获取段落的实体 - para_entities = self.metadata_store.get_paragraph_entities( - result.hash_value - ) - - # 计算实体的平均PPR分数 - if para_entities: - entity_scores = [] - for ent in para_entities: - ent_name = str(ent.get("name", "")).strip().lower() - if ent_name in ppr_scores_by_name: - entity_scores.append(ppr_scores_by_name[ent_name]) - - if entity_scores: - avg_ppr = np.mean(entity_scores) - # 融合原始分数和PPR分数 - result.score = result.score * 0.7 + avg_ppr * 0.3 - - # 重新排序 - results.sort(key=lambda x: x.score, reverse=True) - - return results - - def _retrieve_temporal_only( - self, - temporal: 
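
Note: the PPR rerank wraps the thread-pooled computation in both a semaphore (bounded concurrency) and asyncio.wait_for (hard timeout), returning the original ordering on timeout rather than blocking the chat loop. A runnable sketch; slow_pagerank is a hypothetical stand-in for the real solver:

import asyncio
import time


def slow_pagerank(seeds):
    time.sleep(0.1)                            # stand-in for the real PPR solve
    return {s: 1.0 / len(seeds) for s in seeds}


async def rerank_scores(seeds, sem, timeout_s=1.5):
    try:
        async with sem:                        # caps concurrent PPR jobs
            return await asyncio.wait_for(
                asyncio.to_thread(slow_pagerank, seeds), timeout=timeout_s
            )
    except asyncio.TimeoutError:
        return None                            # caller keeps the original order


async def main():
    sem = asyncio.Semaphore(4)                 # mirrors ppr_concurrency_limit
    ppr = await rerank_scores(["alice", "bob"], sem)
    if ppr is not None:
        print(0.7 * 0.8 + 0.3 * ppr["alice"])  # 0.7*original + 0.3*avg PPR = 0.71


asyncio.run(main())
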
TemporalQueryOptions, - top_k: int, - ) -> List[RetrievalResult]: - """无语义 query 时,直接走时序索引查询。""" - limit = self._cap_temporal_scan_k( - top_k * max(1, temporal.candidate_multiplier), - temporal, - ) - paragraphs = self.metadata_store.query_paragraphs_temporal( - start_ts=temporal.time_from, - end_ts=temporal.time_to, - person=temporal.person, - source=temporal.source, - limit=limit, - allow_created_fallback=temporal.allow_created_fallback, - ) - results: List[RetrievalResult] = [] - for para in paragraphs: - time_meta = self._build_time_meta_from_paragraph(para, temporal=temporal) - results.append( - RetrievalResult( - hash_value=para["hash"], - content=para["content"], - score=1.0, - result_type="paragraph", - source="temporal_scan", - metadata={ - "word_count": para.get("word_count", 0), - "time_meta": time_meta, - }, - ) - ) - - results = self._sort_results_with_temporal(results, temporal) - return results[:top_k] - - def _extract_effective_time( - self, - paragraph: Dict[str, Any], - temporal: Optional[TemporalQueryOptions] = None, - ) -> Tuple[Optional[float], Optional[float], Optional[str]]: - """提取段落有效时间区间与命中依据。""" - event_time = paragraph.get("event_time") - event_start = paragraph.get("event_time_start") - event_end = paragraph.get("event_time_end") - - if event_start is not None or event_end is not None: - effective_start = event_start if event_start is not None else ( - event_time if event_time is not None else event_end - ) - effective_end = event_end if event_end is not None else ( - event_time if event_time is not None else event_start - ) - return effective_start, effective_end, "event_time_range" - - if event_time is not None: - return event_time, event_time, "event_time" - - allow_fallback = True - if temporal is not None: - allow_fallback = temporal.allow_created_fallback - - created_at = paragraph.get("created_at") - if allow_fallback and created_at is not None: - return created_at, created_at, "created_at_fallback" - - return None, None, None - - def _build_time_meta_from_paragraph( - self, - paragraph: Dict[str, Any], - temporal: Optional[TemporalQueryOptions] = None, - ) -> Dict[str, Any]: - """构建统一 time_meta 结构。""" - effective_start, effective_end, match_basis = self._extract_effective_time( - paragraph, - temporal=temporal, - ) - return { - "event_time": paragraph.get("event_time"), - "event_time_start": paragraph.get("event_time_start"), - "event_time_end": paragraph.get("event_time_end"), - "ingest_time": paragraph.get("created_at"), - "time_granularity": paragraph.get("time_granularity"), - "time_confidence": paragraph.get("time_confidence", 1.0), - "effective_start": effective_start, - "effective_end": effective_end, - "effective_start_text": format_timestamp(effective_start), - "effective_end_text": format_timestamp(effective_end), - "match_basis": match_basis or "none", - } - - def _matches_person_filter(self, paragraph_hash: str, person: Optional[str]) -> bool: - if not person: - return True - target = person.strip().lower() - if not target: - return True - para_entities = self.metadata_store.get_paragraph_entities(paragraph_hash) - for ent in para_entities: - name = str(ent.get("name", "")).strip().lower() - if target in name: - return True - return False - - def _is_temporal_match( - self, - paragraph: Dict[str, Any], - temporal: TemporalQueryOptions, - ) -> bool: - """判断段落是否命中时序筛选。""" - if temporal.source and paragraph.get("source") != temporal.source: - return False - - if not self._matches_person_filter(paragraph.get("hash", ""), temporal.person): - 
return False - - effective_start, effective_end, _ = self._extract_effective_time(paragraph, temporal=temporal) - if effective_start is None or effective_end is None: - return False - - if temporal.time_from is not None and temporal.time_to is not None: - return effective_end >= temporal.time_from and effective_start <= temporal.time_to - if temporal.time_from is not None: - return effective_end >= temporal.time_from - if temporal.time_to is not None: - return effective_start <= temporal.time_to - return True - - def _apply_temporal_filter_to_paragraphs( - self, - results: List[RetrievalResult], - temporal: Optional[TemporalQueryOptions], - ) -> List[RetrievalResult]: - if not temporal: - return results - - filtered: List[RetrievalResult] = [] - for result in results: - paragraph = self.metadata_store.get_paragraph(result.hash_value) - if not paragraph: - continue - if not self._is_temporal_match(paragraph, temporal): - continue - result.metadata["time_meta"] = self._build_time_meta_from_paragraph(paragraph, temporal=temporal) - filtered.append(result) - - return self._sort_results_with_temporal(filtered, temporal) - - def _best_supporting_time_meta( - self, - relation_hash: str, - temporal: TemporalQueryOptions, - ) -> Optional[Dict[str, Any]]: - """获取关系在时序窗口内最优支撑段落的 time_meta。""" - supports = self.metadata_store.get_paragraphs_by_relation(relation_hash) - if not supports: - return None - - best_meta: Optional[Dict[str, Any]] = None - best_time = float("-inf") - for para in supports: - if not self._is_temporal_match(para, temporal): - continue - meta = self._build_time_meta_from_paragraph(para, temporal=temporal) - eff = meta.get("effective_end") - score = float(eff) if eff is not None else float("-inf") - if score >= best_time: - best_time = score - best_meta = meta - - return best_meta - - def _apply_temporal_filter_to_relations( - self, - results: List[RetrievalResult], - temporal: Optional[TemporalQueryOptions], - ) -> List[RetrievalResult]: - if not temporal: - return results - - filtered: List[RetrievalResult] = [] - for result in results: - meta = result.metadata.get("time_meta") - if meta is None: - meta = self._best_supporting_time_meta(result.hash_value, temporal) - if meta is None: - continue - result.metadata["time_meta"] = meta - filtered.append(result) - - return self._sort_results_with_temporal(filtered, temporal) - - def _sort_results_with_temporal( - self, - results: List[RetrievalResult], - temporal: TemporalQueryOptions, - ) -> List[RetrievalResult]: - """语义优先,时间次排序(新到旧)。""" - del temporal # temporal 保留给未来扩展,目前只使用结果内 time_meta - - def _temporal_key(item: RetrievalResult) -> float: - time_meta = item.metadata.get("time_meta", {}) - effective = time_meta.get("effective_end") - if effective is None: - effective = time_meta.get("effective_start") - if effective is None: - return float("-inf") - return float(effective) - - results.sort(key=lambda x: (x.score, _temporal_key(x)), reverse=True) - return results - - def _extract_entities(self, text: str) -> Dict[str, float]: - """ - 从文本中提取实体(简化版本) - - Args: - text: 输入文本 - - Returns: - 实体字典 {实体名: 权重} - """ - # 获取所有实体 - all_entities = self.graph_store.get_nodes() - if not all_entities: - return {} - - # 检查是否需要更新 Aho-Corasick 匹配器 - if self._ac_matcher is None or self._ac_nodes_count != len(all_entities): - self._ac_matcher = AhoCorasick() - for entity in all_entities: - self._ac_matcher.add_pattern(entity.lower()) - self._ac_matcher.build() - self._ac_nodes_count = len(all_entities) - - # 执行匹配 - text_lower = text.lower() - stats = 
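
Note: the temporal match above is a closed-interval overlap test: a paragraph with effective span [start, end] matches a window [time_from, time_to] iff end >= time_from and start <= time_to, with missing bounds treated as open. Reduced to its core:

def temporal_match(start, end, time_from=None, time_to=None):
    # Open bounds always pass; otherwise require interval overlap.
    if time_from is not None and end < time_from:
        return False
    if time_to is not None and start > time_to:
        return False
    return True


assert temporal_match(10, 20, time_from=15)        # overlaps an open-ended window
assert not temporal_match(10, 20, time_from=25)    # ends before the window opens
assert temporal_match(10, 20, time_from=5, time_to=12)
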
self._ac_matcher.find_all(text_lower) - - # 映射回原始名称并使用出现次数作为权重 - node_map = {node.lower(): node for node in all_entities} - entities = {node_map[low_name]: float(count) for low_name, count in stats.items()} - - return entities - - def get_statistics(self) -> Dict[str, Any]: - """ - 获取检索统计信息 - - Returns: - 统计信息字典 - """ - vector_size = getattr(self.vector_store, "size", None) - if vector_size is None: - vector_size = getattr(self.vector_store, "num_vectors", 0) - - return { - "config": { - "top_k_paragraphs": self.config.top_k_paragraphs, - "top_k_relations": self.config.top_k_relations, - "top_k_final": self.config.top_k_final, - "alpha": self.config.alpha, - "enable_ppr": self.config.enable_ppr, - "enable_parallel": self.config.enable_parallel, - "strategy": self.config.retrieval_strategy.value, - "sparse_mode": self.config.sparse.mode, - "fusion_method": self.config.fusion.method, - "relation_intent_enabled": self.config.relation_intent.enabled, - "relation_intent_alpha_override": self.config.relation_intent.alpha_override, - "relation_intent_candidate_multiplier": self.config.relation_intent.relation_candidate_multiplier, - "relation_intent_preserve_top_relations": self.config.relation_intent.preserve_top_relations, - "relation_intent_force_sparse": self.config.relation_intent.force_relation_sparse, - "relation_intent_pair_rerank_enabled": self.config.relation_intent.pair_predicate_rerank_enabled, - "relation_intent_pair_predicate_limit": self.config.relation_intent.pair_predicate_limit, - "graph_recall_enabled": self.config.graph_recall.enabled, - "graph_recall_candidate_k": self.config.graph_recall.candidate_k, - "graph_recall_allow_two_hop_pair": self.config.graph_recall.allow_two_hop_pair, - "graph_recall_max_paths": self.config.graph_recall.max_paths, - }, - "vector_store": { - "size": int(vector_size), - }, - "graph_store": { - "num_nodes": self.graph_store.num_nodes, - "num_edges": self.graph_store.num_edges, - }, - "metadata_store": self.metadata_store.get_statistics(), - "sparse": self.sparse_index.stats() if self.sparse_index else None, - } - - def __repr__(self) -> str: - return ( - f"DualPathRetriever(" - f"strategy={self.config.retrieval_strategy.value}, " - f"para_k={self.config.top_k_paragraphs}, " - f"rel_k={self.config.top_k_relations})" - ) diff --git a/plugins/A_memorix/core/retrieval/graph_relation_recall.py b/plugins/A_memorix/core/retrieval/graph_relation_recall.py deleted file mode 100644 index 9af862f3..00000000 --- a/plugins/A_memorix/core/retrieval/graph_relation_recall.py +++ /dev/null @@ -1,272 +0,0 @@ -"""Graph-assisted relation candidate recall for relation-oriented queries.""" - -from __future__ import annotations - -from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Sequence, Set - -from src.common.logger import get_logger - -logger = get_logger("A_Memorix.GraphRelationRecall") - - -@dataclass -class GraphRelationRecallConfig: - """Configuration for controlled graph relation recall.""" - - enabled: bool = True - candidate_k: int = 24 - max_hop: int = 1 - allow_two_hop_pair: bool = True - max_paths: int = 4 - - def __post_init__(self) -> None: - self.enabled = bool(self.enabled) - self.candidate_k = max(1, int(self.candidate_k)) - self.max_hop = max(1, int(self.max_hop)) - self.allow_two_hop_pair = bool(self.allow_two_hop_pair) - self.max_paths = max(1, int(self.max_paths)) - - -@dataclass -class GraphRelationCandidate: - """A graph-derived relation candidate before retriever-side fusion.""" - - hash_value: str - subject: str - 
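
Note: entity extraction scans the query against every graph node with an Aho-Corasick automaton (rebuilt only when the node count changes) and turns occurrence counts into personalization weights. The same mapping, with a naive counter standing in for the automaton:

class NaiveMatcher:
    # Stand-in for the AhoCorasick helper; AC does this in a single O(n) pass.
    def __init__(self, patterns):
        self.patterns = [p.lower() for p in patterns]

    def find_all(self, text):
        return {p: text.count(p) for p in self.patterns if p in text}


nodes = ["Alice", "Bob"]
stats = NaiveMatcher(nodes).find_all("alice met bob, then alice left")
node_map = {n.lower(): n for n in nodes}
entities = {node_map[name]: float(count) for name, count in stats.items()}
print(entities)   # {'Alice': 2.0, 'Bob': 1.0}; counts become PPR weights
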
predicate: str - object: str - confidence: float - graph_seed_entities: List[str] - graph_hops: int - graph_candidate_type: str - supporting_paragraph_count: int - - def to_payload(self) -> Dict[str, Any]: - content = f"{self.subject} {self.predicate} {self.object}" - return { - "hash": self.hash_value, - "content": content, - "subject": self.subject, - "predicate": self.predicate, - "object": self.object, - "confidence": self.confidence, - "graph_seed_entities": list(self.graph_seed_entities), - "graph_hops": int(self.graph_hops), - "graph_candidate_type": self.graph_candidate_type, - "supporting_paragraph_count": int(self.supporting_paragraph_count), - } - - -class GraphRelationRecallService: - """Collect relation candidates from the entity graph in a controlled way.""" - - def __init__( - self, - *, - graph_store: Any, - metadata_store: Any, - config: Optional[GraphRelationRecallConfig] = None, - ) -> None: - self.graph_store = graph_store - self.metadata_store = metadata_store - self.config = config or GraphRelationRecallConfig() - - def recall( - self, - *, - seed_entities: Sequence[str], - ) -> List[GraphRelationCandidate]: - if not self.config.enabled: - return [] - if self.graph_store is None or self.metadata_store is None: - return [] - - seeds = self._normalize_seed_entities(seed_entities) - if not seeds: - return [] - - seen_hashes: Set[str] = set() - candidates: List[GraphRelationCandidate] = [] - - if len(seeds) >= 2: - self._collect_direct_pair_candidates( - seed_a=seeds[0], - seed_b=seeds[1], - seen_hashes=seen_hashes, - out=candidates, - ) - if ( - len(candidates) < 3 - and self.config.allow_two_hop_pair - and len(candidates) < self.config.candidate_k - ): - self._collect_two_hop_pair_candidates( - seed_a=seeds[0], - seed_b=seeds[1], - seen_hashes=seen_hashes, - out=candidates, - ) - else: - self._collect_one_hop_seed_candidates( - seed=seeds[0], - seen_hashes=seen_hashes, - out=candidates, - ) - - return candidates[: self.config.candidate_k] - - def _normalize_seed_entities(self, seed_entities: Sequence[str]) -> List[str]: - out: List[str] = [] - seen = set() - for raw in list(seed_entities)[:2]: - resolved = None - try: - resolved = self.graph_store.find_node(str(raw), ignore_case=True) - except Exception: - resolved = None - if not resolved: - continue - canon = str(resolved).strip().lower() - if not canon or canon in seen: - continue - seen.add(canon) - out.append(str(resolved)) - return out - - def _collect_direct_pair_candidates( - self, - *, - seed_a: str, - seed_b: str, - seen_hashes: Set[str], - out: List[GraphRelationCandidate], - ) -> None: - relation_hashes = [] - relation_hashes.extend(self.graph_store.get_relation_hashes_for_edge(seed_a, seed_b)) - relation_hashes.extend(self.graph_store.get_relation_hashes_for_edge(seed_b, seed_a)) - self._append_relation_hashes( - relation_hashes=relation_hashes, - seen_hashes=seen_hashes, - out=out, - candidate_type="direct_pair", - graph_hops=1, - graph_seed_entities=[seed_a, seed_b], - ) - - def _collect_two_hop_pair_candidates( - self, - *, - seed_a: str, - seed_b: str, - seen_hashes: Set[str], - out: List[GraphRelationCandidate], - ) -> None: - try: - paths = self.graph_store.find_paths( - seed_a, - seed_b, - max_depth=2, - max_paths=self.config.max_paths, - ) - except Exception as e: - logger.debug(f"graph two-hop recall skipped: {e}") - return - - for path_nodes in paths: - if len(out) >= self.config.candidate_k: - break - if not isinstance(path_nodes, Sequence) or len(path_nodes) < 3: - continue - if len(path_nodes) 
!= 3:
-                continue
-            for idx in range(len(path_nodes) - 1):
-                if len(out) >= self.config.candidate_k:
-                    break
-                u = str(path_nodes[idx])
-                v = str(path_nodes[idx + 1])
-                relation_hashes = []
-                relation_hashes.extend(self.graph_store.get_relation_hashes_for_edge(u, v))
-                relation_hashes.extend(self.graph_store.get_relation_hashes_for_edge(v, u))
-                self._append_relation_hashes(
-                    relation_hashes=relation_hashes,
-                    seen_hashes=seen_hashes,
-                    out=out,
-                    candidate_type="two_hop_pair",
-                    graph_hops=2,
-                    graph_seed_entities=[seed_a, seed_b],
-                )
-
-    def _collect_one_hop_seed_candidates(
-        self,
-        *,
-        seed: str,
-        seen_hashes: Set[str],
-        out: List[GraphRelationCandidate],
-    ) -> None:
-        try:
-            relation_hashes = self.graph_store.get_incident_relation_hashes(
-                seed,
-                limit=self.config.candidate_k,
-            )
-        except Exception as e:
-            logger.debug(f"graph one-hop recall skipped: {e}")
-            return
-        self._append_relation_hashes(
-            relation_hashes=relation_hashes,
-            seen_hashes=seen_hashes,
-            out=out,
-            candidate_type="one_hop_seed",
-            graph_hops=min(1, self.config.max_hop),
-            graph_seed_entities=[seed],
-        )
-
-    def _append_relation_hashes(
-        self,
-        *,
-        relation_hashes: Sequence[str],
-        seen_hashes: Set[str],
-        out: List[GraphRelationCandidate],
-        candidate_type: str,
-        graph_hops: int,
-        graph_seed_entities: Sequence[str],
-    ) -> None:
-        for relation_hash in sorted({str(h) for h in relation_hashes if str(h).strip()}):
-            if len(out) >= self.config.candidate_k:
-                break
-            if relation_hash in seen_hashes:
-                continue
-            candidate = self._build_candidate(
-                relation_hash=relation_hash,
-                candidate_type=candidate_type,
-                graph_hops=graph_hops,
-                graph_seed_entities=graph_seed_entities,
-            )
-            if candidate is None:
-                continue
-            seen_hashes.add(relation_hash)
-            out.append(candidate)
-
-    def _build_candidate(
-        self,
-        *,
-        relation_hash: str,
-        candidate_type: str,
-        graph_hops: int,
-        graph_seed_entities: Sequence[str],
-    ) -> Optional[GraphRelationCandidate]:
-        relation = self.metadata_store.get_relation(relation_hash)
-        if relation is None:
-            return None
-        supporting_paragraphs = self.metadata_store.get_paragraphs_by_relation(relation_hash)
-        return GraphRelationCandidate(
-            hash_value=relation_hash,
-            subject=str(relation.get("subject", "")),
-            predicate=str(relation.get("predicate", "")),
-            object=str(relation.get("object", "")),
-            confidence=float(relation.get("confidence", 1.0) or 1.0),
-            graph_seed_entities=[str(x) for x in graph_seed_entities],
-            graph_hops=int(graph_hops),
-            graph_candidate_type=str(candidate_type),
-            supporting_paragraph_count=len(supporting_paragraphs),
-        )
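To make the recall flow above concrete, here is a minimal usage sketch. It is illustrative only: the stub stores and sample data are hypothetical stand-ins that implement just the methods the service calls, and GraphRelationRecallConfig (defined earlier in this file) is assumed to default to enabled, a positive candidate_k, and allow_two_hop_pair.

class _StubGraphStore:
    # Only the graph queries GraphRelationRecallService relies on.
    def find_node(self, name, ignore_case=True):
        return name  # resolve raw seed text to a canonical node name

    def get_relation_hashes_for_edge(self, a, b):
        return ["rel-1"] if (a, b) == ("alice", "bob") else []

    def find_paths(self, a, b, max_depth=2, max_paths=8):
        return [[a, "carol", b]]  # a single two-hop path

    def get_incident_relation_hashes(self, seed, limit=10):
        return ["rel-1"]


class _StubMetadataStore:
    def get_relation(self, relation_hash):
        return {"subject": "alice", "predicate": "knows", "object": "bob", "confidence": 0.9}

    def get_paragraphs_by_relation(self, relation_hash):
        return ["para-1", "para-2"]


service = GraphRelationRecallService(
    graph_store=_StubGraphStore(),
    metadata_store=_StubMetadataStore(),
)
for candidate in service.recall(seed_entities=["alice", "bob"]):
    # -> {"hash": "rel-1", "content": "alice knows bob", "graph_hops": 1, ...}
    print(candidate.to_payload())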
raise ValueError(f"alpha必须在[0, 1)之间: {self.alpha}") - - if self.max_iter <= 0: - raise ValueError(f"max_iter必须大于0: {self.max_iter}") - - if self.tol <= 0: - raise ValueError(f"tol必须大于0: {self.tol}") - - if self.min_iterations < 0: - raise ValueError(f"min_iterations必须大于等于0: {self.min_iterations}") - - if self.min_iterations >= self.max_iter: - raise ValueError(f"min_iterations必须小于max_iter") - - -class PersonalizedPageRank: - """ - Personalized PageRank计算器 - - 功能: - - 个性化向量支持 - - 快速收敛检测 - - 结果归一化 - - 批量计算 - - 统计信息 - - 参数: - graph_store: 图存储 - config: PageRank配置 - """ - - def __init__( - self, - graph_store: GraphStore, - config: Optional[PageRankConfig] = None, - ): - """ - 初始化PPR计算器 - - Args: - graph_store: 图存储 - config: PageRank配置 - """ - self.graph_store = graph_store - self.config = config or PageRankConfig() - - # 统计信息 - self._total_computations = 0 - self._total_iterations = 0 - self._convergence_history: List[int] = [] - - logger.info( - f"PersonalizedPageRank 初始化: " - f"alpha={self.config.alpha}, " - f"max_iter={self.config.max_iter}" - ) - - # 缓存 Aho-Corasick 匹配器 - self._ac_matcher: Optional[AhoCorasick] = None - self._ac_nodes_count = 0 - - def compute( - self, - personalization: Optional[Dict[str, float]] = None, - alpha: Optional[float] = None, - max_iter: Optional[int] = None, - normalize: Optional[bool] = None, - ) -> Dict[str, float]: - """ - 计算Personalized PageRank - - Args: - personalization: 个性化向量 {节点名: 权重} - alpha: 阻尼系数(覆盖配置值) - max_iter: 最大迭代次数(覆盖配置值) - normalize: 是否归一化(覆盖配置值) - - Returns: - 节点PageRank值字典 {节点名: 分数} - """ - # 使用覆盖值或配置值 - alpha = alpha if alpha is not None else self.config.alpha - max_iter = max_iter if max_iter is not None else self.config.max_iter - normalize = normalize if normalize is not None else self.config.normalize - - # 调用GraphStore的compute_pagerank - scores = self.graph_store.compute_pagerank( - personalization=personalization, - alpha=alpha, - max_iter=max_iter, - tol=self.config.tol, - ) - - # 归一化(如果需要) - if normalize and scores: - total = sum(scores.values()) - if total > 0: - scores = {node: score / total for node, score in scores.items()} - - # 更新统计 - self._total_computations += 1 - - logger.debug( - f"PPR计算完成: {len(scores)} 个节点, " - f"personalization_nodes={len(personalization) if personalization else 0}" - ) - - return scores - - def compute_batch( - self, - personalization_list: List[Dict[str, float]], - normalize: bool = True, - ) -> List[Dict[str, float]]: - """ - 批量计算PPR - - Args: - personalization_list: 个性化向量列表 - normalize: 是否归一化 - - Returns: - PageRank值字典列表 - """ - results = [] - - for i, personalization in enumerate(personalization_list): - logger.debug(f"计算第 {i+1}/{len(personalization_list)} 个PPR") - scores = self.compute(personalization=personalization, normalize=normalize) - results.append(scores) - - return results - - def compute_for_entities( - self, - entities: List[str], - weights: Optional[List[float]] = None, - normalize: bool = True, - ) -> Dict[str, float]: - """ - 为实体列表计算PPR - - Args: - entities: 实体列表 - weights: 权重列表(默认均匀权重) - normalize: 是否归一化 - - Returns: - PageRank值字典 - """ - if not entities: - logger.warning("实体列表为空,返回均匀PPR") - return self.compute(personalization=None, normalize=normalize) - - # 构建个性化向量 - if weights is None: - weights = [1.0] * len(entities) - - if len(weights) != len(entities): - raise ValueError(f"权重数量与实体数量不匹配: {len(weights)} vs {len(entities)}") - - personalization = {entity: weight for entity, weight in zip(entities, weights)} - - return self.compute(personalization=personalization, 
-class PersonalizedPageRank:
-    """
-    Personalized PageRank calculator.
-
-    Features:
-    - Personalization vector support
-    - Fast convergence detection
-    - Result normalization
-    - Batch computation
-    - Statistics collection
-
-    Args:
-        graph_store: Graph store.
-        config: PageRank configuration.
-    """
-
-    def __init__(
-        self,
-        graph_store: GraphStore,
-        config: Optional[PageRankConfig] = None,
-    ):
-        """
-        Initialize the PPR calculator.
-
-        Args:
-            graph_store: Graph store.
-            config: PageRank configuration.
-        """
-        self.graph_store = graph_store
-        self.config = config or PageRankConfig()
-
-        # Statistics. Note: _total_iterations is never updated, because
-        # GraphStore.compute_pagerank does not report iteration counts,
-        # so avg_iterations currently always evaluates to 0.
-        self._total_computations = 0
-        self._total_iterations = 0
-        self._convergence_history: List[int] = []
-
-        logger.info(
-            f"PersonalizedPageRank initialized: "
-            f"alpha={self.config.alpha}, "
-            f"max_iter={self.config.max_iter}"
-        )
-
-        # Cached Aho-Corasick matcher
-        self._ac_matcher: Optional[AhoCorasick] = None
-        self._ac_nodes_count = 0
-
-    def compute(
-        self,
-        personalization: Optional[Dict[str, float]] = None,
-        alpha: Optional[float] = None,
-        max_iter: Optional[int] = None,
-        normalize: Optional[bool] = None,
-    ) -> Dict[str, float]:
-        """
-        Compute Personalized PageRank.
-
-        Args:
-            personalization: Personalization vector {node name: weight}.
-            alpha: Damping factor (overrides the configured value).
-            max_iter: Maximum iterations (overrides the configured value).
-            normalize: Whether to normalize (overrides the configured value).
-
-        Returns:
-            PageRank scores per node {node name: score}.
-        """
-        # Use override values when given, otherwise the configured ones.
-        alpha = alpha if alpha is not None else self.config.alpha
-        max_iter = max_iter if max_iter is not None else self.config.max_iter
-        normalize = normalize if normalize is not None else self.config.normalize
-
-        # Delegate to GraphStore.compute_pagerank.
-        scores = self.graph_store.compute_pagerank(
-            personalization=personalization,
-            alpha=alpha,
-            max_iter=max_iter,
-            tol=self.config.tol,
-        )
-
-        # Normalize if requested.
-        if normalize and scores:
-            total = sum(scores.values())
-            if total > 0:
-                scores = {node: score / total for node, score in scores.items()}
-
-        # Update statistics.
-        self._total_computations += 1
-
-        logger.debug(
-            f"PPR computation finished: {len(scores)} nodes, "
-            f"personalization_nodes={len(personalization) if personalization else 0}"
-        )
-
-        return scores
-
-    def compute_batch(
-        self,
-        personalization_list: List[Dict[str, float]],
-        normalize: bool = True,
-    ) -> List[Dict[str, float]]:
-        """
-        Compute PPR in batch.
-
-        Args:
-            personalization_list: List of personalization vectors.
-            normalize: Whether to normalize.
-
-        Returns:
-            List of PageRank score dictionaries.
-        """
-        results = []
-
-        for i, personalization in enumerate(personalization_list):
-            logger.debug(f"Computing PPR {i+1}/{len(personalization_list)}")
-            scores = self.compute(personalization=personalization, normalize=normalize)
-            results.append(scores)
-
-        return results
-
-    def compute_for_entities(
-        self,
-        entities: List[str],
-        weights: Optional[List[float]] = None,
-        normalize: bool = True,
-    ) -> Dict[str, float]:
-        """
-        Compute PPR for a list of entities.
-
-        Args:
-            entities: Entity list.
-            weights: Weight list (defaults to uniform weights).
-            normalize: Whether to normalize.
-
-        Returns:
-            PageRank score dictionary.
-        """
-        if not entities:
-            logger.warning("Entity list is empty; returning uniform PPR")
-            return self.compute(personalization=None, normalize=normalize)
-
-        # Build the personalization vector.
-        if weights is None:
-            weights = [1.0] * len(entities)
-
-        if len(weights) != len(entities):
-            raise ValueError(f"Weight count does not match entity count: {len(weights)} vs {len(entities)}")
-
-        personalization = {entity: weight for entity, weight in zip(entities, weights)}
-
-        return self.compute(personalization=personalization, normalize=normalize)
-
-    def compute_for_query(
-        self,
-        query: str,
-        entity_extractor: Optional[Callable] = None,
-        normalize: bool = True,
-    ) -> Dict[str, float]:
-        """
-        Compute PPR for a query.
-
-        Args:
-            query: Query text.
-            entity_extractor: Optional entity extraction function.
-            normalize: Whether to normalize.
-
-        Returns:
-            PageRank score dictionary.
-        """
-        # Extract entities.
-        if entity_extractor is not None:
-            entities = entity_extractor(query)
-        else:
-            # Simple implementation: match against nodes in the graph.
-            entities = self._extract_entities_from_query(query)
-
-        if not entities:
-            logger.debug(f"No entities extracted from query: '{query}'")
-            return self.compute(personalization=None, normalize=normalize)
-
-        # Compute PPR.
-        return self.compute_for_entities(entities, normalize=normalize)
-
-    def rank_nodes(
-        self,
-        scores: Dict[str, float],
-        top_k: Optional[int] = None,
-        min_score: float = 0.0,
-    ) -> List[Tuple[str, float]]:
-        """
-        Rank nodes by score.
-
-        Args:
-            scores: PageRank score dictionary.
-            top_k: Return only the top k nodes (None means all).
-            min_score: Minimum score threshold.
-
-        Returns:
-            Sorted node list [(node name, score), ...].
-        """
-        # Filter out low-scoring nodes.
-        filtered = [(node, score) for node, score in scores.items() if score >= min_score]
-
-        # Sort by score in descending order.
-        sorted_nodes = sorted(filtered, key=lambda x: x[1], reverse=True)
-
-        # Return the top_k.
-        if top_k is not None:
-            sorted_nodes = sorted_nodes[:top_k]
-
-        return sorted_nodes
-
-    def get_personalization_vector(
-        self,
-        nodes: List[str],
-        method: str = "uniform",
-    ) -> Dict[str, float]:
-        """
-        Generate a personalization vector.
-
-        Args:
-            nodes: Node list.
-            method: Generation method:
-                - "uniform": uniform weights
-                - "degree": weighted by degree
-                - "inverse_degree": weighted by inverse degree
-
-        Returns:
-            Personalization vector {node name: weight}.
-        """
-        if not nodes:
-            return {}
-
-        if method == "uniform":
-            # Uniform weights.
-            weight = 1.0 / len(nodes)
-            return {node: weight for node in nodes}
-
-        elif method == "degree":
-            # Weight by degree.
-            node_degrees = {}
-            for node in nodes:
-                neighbors = self.graph_store.get_neighbors(node)
-                node_degrees[node] = len(neighbors)
-
-            total_degree = sum(node_degrees.values())
-            if total_degree > 0:
-                return {node: degree / total_degree for node, degree in node_degrees.items()}
-            else:
-                return {node: 1.0 / len(nodes) for node in nodes}
-
-        elif method == "inverse_degree":
-            # Weight by inverse degree.
-            node_degrees = {}
-            for node in nodes:
-                neighbors = self.graph_store.get_neighbors(node)
-                node_degrees[node] = len(neighbors)
-
-            # Inverse degrees.
-            inv_degrees = {node: 1.0 / (degree + 1) for node, degree in node_degrees.items()}
-            total_inv = sum(inv_degrees.values())
-
-            if total_inv > 0:
-                return {node: inv / total_inv for node, inv in inv_degrees.items()}
-            else:
-                return {node: 1.0 / len(nodes) for node in nodes}
-
-        else:
-            raise ValueError(f"Unsupported personalization vector method: {method}")
-
-    def compare_scores(
-        self,
-        scores1: Dict[str, float],
-        scores2: Dict[str, float],
-    ) -> Dict[str, Dict[str, float]]:
-        """
-        Compare two sets of PPR scores.
-
-        Args:
-            scores1: First score set.
-            scores2: Second score set.
-
-        Returns:
-            Comparison result {
-                "common_nodes": {node: (score1, score2)},
-                "only_in_1": {node: score1},
-                "only_in_2": {node: score2},
-            }
-        """
-        common_nodes = {}
-        only_in_1 = {}
-        only_in_2 = {}
-
-        all_nodes = set(scores1.keys()) | set(scores2.keys())
-
-        for node in all_nodes:
-            if node in scores1 and node in scores2:
-                common_nodes[node] = (scores1[node], scores2[node])
-            elif node in scores1:
-                only_in_1[node] = scores1[node]
-            else:
-                only_in_2[node] = scores2[node]
-
-        return {
-            "common_nodes": common_nodes,
-            "only_in_1": only_in_1,
-            "only_in_2": only_in_2,
-        }
-
-    def get_statistics(self) -> Dict[str, Any]:
-        """
-        Get statistics.
-
-        Returns:
-            Statistics dictionary.
-        """
-        avg_iterations = (
-            self._total_iterations / self._total_computations
-            if self._total_computations > 0
-            else 0
-        )
-
-        return {
-            "config": {
-                "alpha": self.config.alpha,
-                "max_iter": self.config.max_iter,
-                "tol": self.config.tol,
-                "normalize": self.config.normalize,
-                "min_iterations": self.config.min_iterations,
-            },
-            "statistics": {
-                "total_computations": self._total_computations,
-                "total_iterations": self._total_iterations,
-                "avg_iterations": avg_iterations,
-                "convergence_history": self._convergence_history.copy(),
-            },
-            "graph": {
-                "num_nodes": self.graph_store.num_nodes,
-                "num_edges": self.graph_store.num_edges,
-            },
-        }
-
-    def reset_statistics(self) -> None:
-        """Reset statistics."""
-        self._total_computations = 0
-        self._total_iterations = 0
-        self._convergence_history.clear()
-        logger.info("Statistics reset")
-
-    def _extract_entities_from_query(self, query: str) -> List[str]:
-        """
-        Extract entities from a query (simplified implementation).
-
-        Args:
-            query: Query text.
-
-        Returns:
-            Entity list.
-        """
-        # Fetch all nodes.
-        all_nodes = self.graph_store.get_nodes()
-        if not all_nodes:
-            return []
-
-        # Rebuild the Aho-Corasick matcher if it is missing or stale.
-        if self._ac_matcher is None or self._ac_nodes_count != len(all_nodes):
-            self._ac_matcher = AhoCorasick()
-            for node in all_nodes:
-                # Lowercase uniformly for case-insensitive matching.
-                self._ac_matcher.add_pattern(node.lower())
-            self._ac_matcher.build()
-            self._ac_nodes_count = len(all_nodes)
-
-        # Run the match.
-        query_lower = query.lower()
-        stats = self._ac_matcher.find_all(query_lower)
-
-        # Map matches back to the original casing: add_pattern stored lowercase
-        # patterns, so build a lowercase -> original mapping from all_nodes.
-        node_map = {node.lower(): node for node in all_nodes}
-        entities = [node_map[low_name] for low_name in stats.keys()]
-
-        return entities
-
-    @property
-    def num_computations(self) -> int:
-        """Number of computations performed."""
-        return self._total_computations
-
-    @property
-    def avg_iterations(self) -> float:
-        """Average number of iterations."""
-        if self._total_computations == 0:
-            return 0.0
-        return self._total_iterations / self._total_computations
-
-    def __repr__(self) -> str:
-        return (
-            f"PersonalizedPageRank("
-            f"alpha={self.config.alpha}, "
-            f"computations={self._total_computations})"
-        )
-
-
-def create_ppr_from_graph(
-    graph_store: GraphStore,
-    alpha: float = 0.85,
-    max_iter: int = 100,
-) -> PersonalizedPageRank:
-    """
-    Create a PPR calculator from a graph store.
-
-    Args:
-        graph_store: Graph store.
-        alpha: Damping factor.
-        max_iter: Maximum iterations.
-
-    Returns:
-        A PPR calculator instance.
-    """
-    config = PageRankConfig(
-        alpha=alpha,
-        max_iter=max_iter,
-    )
-
-    return PersonalizedPageRank(
-        graph_store=graph_store,
-        config=config,
-    )
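A small end-to-end sketch of the calculator above. The stub GraphStore is hypothetical and implements only the member this path actually touches (compute_pagerank); the returned scores are made up.

class _StubGraphStore:
    num_nodes = 3
    num_edges = 2

    def compute_pagerank(self, personalization=None, alpha=0.85, max_iter=100, tol=1e-6):
        # A real store would run PPR over its adjacency structure.
        return {"alice": 0.5, "bob": 0.3, "carol": 0.2}


ppr = create_ppr_from_graph(_StubGraphStore(), alpha=0.85, max_iter=100)
scores = ppr.compute_for_entities(["alice"])          # restart mass seeded on alice
top = ppr.rank_nodes(scores, top_k=2, min_score=0.1)  # [("alice", 0.5), ("bob", 0.3)]
print(top)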
diff --git a/plugins/A_memorix/core/retrieval/sparse_bm25.py b/plugins/A_memorix/core/retrieval/sparse_bm25.py
deleted file mode 100644
index 1fef9f80..00000000
--- a/plugins/A_memorix/core/retrieval/sparse_bm25.py
+++ /dev/null
@@ -1,401 +0,0 @@
-"""
-Sparse retrieval component (FTS5 + BM25).
-
-Supports:
-- Lazily loaded index connections
-- jieba / char n-gram tokenization
-- Unloading with SQLite memory-cache shrinking
-"""
-
-from __future__ import annotations
-
-import re
-import sqlite3
-from dataclasses import dataclass
-from typing import Any, Dict, List, Optional
-
-from src.common.logger import get_logger
-from ..storage import MetadataStore
-
-logger = get_logger("A_Memorix.SparseBM25")
-
-try:
-    import jieba  # type: ignore
-
-    HAS_JIEBA = True
-except Exception:
-    HAS_JIEBA = False
-    jieba = None
-
-
-@dataclass
-class SparseBM25Config:
-    """BM25 sparse retrieval configuration."""
-
-    enabled: bool = True
-    backend: str = "fts5"
-    lazy_load: bool = True
-    mode: str = "auto"  # auto | fallback_only | hybrid
-    tokenizer_mode: str = "jieba"  # jieba | mixed | char_2gram
-    jieba_user_dict: str = ""
-    char_ngram_n: int = 2
-    candidate_k: int = 80
-    max_doc_len: int = 2000
-    enable_ngram_fallback_index: bool = True
-    enable_like_fallback: bool = False
-    enable_relation_sparse_fallback: bool = 
True - relation_candidate_k: int = 60 - relation_max_doc_len: int = 512 - unload_on_disable: bool = True - shrink_memory_on_unload: bool = True - - def __post_init__(self) -> None: - self.backend = str(self.backend or "fts5").strip().lower() - self.mode = str(self.mode or "auto").strip().lower() - self.tokenizer_mode = str(self.tokenizer_mode or "jieba").strip().lower() - self.char_ngram_n = max(1, int(self.char_ngram_n)) - self.candidate_k = max(1, int(self.candidate_k)) - self.max_doc_len = max(0, int(self.max_doc_len)) - self.relation_candidate_k = max(1, int(self.relation_candidate_k)) - self.relation_max_doc_len = max(0, int(self.relation_max_doc_len)) - if self.backend != "fts5": - raise ValueError(f"sparse.backend 暂仅支持 fts5: {self.backend}") - if self.mode not in {"auto", "fallback_only", "hybrid"}: - raise ValueError(f"sparse.mode 非法: {self.mode}") - if self.tokenizer_mode not in {"jieba", "mixed", "char_2gram"}: - raise ValueError(f"sparse.tokenizer_mode 非法: {self.tokenizer_mode}") - - -class SparseBM25Index: - """ - 基于 SQLite FTS5 的 BM25 检索适配层。 - """ - - def __init__( - self, - metadata_store: MetadataStore, - config: Optional[SparseBM25Config] = None, - ): - self.metadata_store = metadata_store - self.config = config or SparseBM25Config() - self._conn: Optional[sqlite3.Connection] = None - self._loaded: bool = False - self._jieba_dict_loaded: bool = False - - @property - def loaded(self) -> bool: - return self._loaded and self._conn is not None - - def ensure_loaded(self) -> bool: - """按需加载 FTS 连接与索引。""" - if not self.config.enabled: - return False - if self.loaded: - return True - - db_path = self.metadata_store.get_db_path() - conn = sqlite3.connect( - str(db_path), - check_same_thread=False, - timeout=30.0, - ) - conn.row_factory = sqlite3.Row - conn.execute("PRAGMA journal_mode=WAL") - conn.execute("PRAGMA synchronous=NORMAL") - conn.execute("PRAGMA temp_store=MEMORY") - - if not self.metadata_store.ensure_fts_schema(conn=conn): - conn.close() - return False - self.metadata_store.ensure_fts_backfilled(conn=conn) - # 关系稀疏检索按独立开关加载,避免不必要的初始化开销。 - if self.config.enable_relation_sparse_fallback: - self.metadata_store.ensure_relations_fts_schema(conn=conn) - self.metadata_store.ensure_relations_fts_backfilled(conn=conn) - if self.config.enable_ngram_fallback_index: - self.metadata_store.ensure_paragraph_ngram_schema(conn=conn) - self.metadata_store.ensure_paragraph_ngram_backfilled( - n=self.config.char_ngram_n, - conn=conn, - ) - - self._conn = conn - self._loaded = True - self._prepare_tokenizer() - logger.info( - "SparseBM25Index loaded: " - f"backend=fts5, tokenizer={self.config.tokenizer_mode}, mode={self.config.mode}" - ) - return True - - def _prepare_tokenizer(self) -> None: - if self._jieba_dict_loaded: - return - if self.config.tokenizer_mode not in {"jieba", "mixed"}: - return - if not HAS_JIEBA: - logger.warning("jieba 不可用,tokenizer 将退化为 char n-gram") - return - user_dict = str(self.config.jieba_user_dict or "").strip() - if user_dict: - try: - jieba.load_userdict(user_dict) # type: ignore[union-attr] - logger.info(f"已加载 jieba 用户词典: {user_dict}") - except Exception as e: - logger.warning(f"加载 jieba 用户词典失败: {e}") - self._jieba_dict_loaded = True - - def _tokenize_jieba(self, text: str) -> List[str]: - if not HAS_JIEBA: - return [] - try: - tokens = list(jieba.cut_for_search(text)) # type: ignore[union-attr] - return [t.strip().lower() for t in tokens if t and t.strip()] - except Exception: - return [] - - def _tokenize_char_ngram(self, text: str, n: int) -> List[str]: 
- compact = re.sub(r"\s+", "", text.lower()) - if not compact: - return [] - if len(compact) < n: - return [compact] - return [compact[i : i + n] for i in range(0, len(compact) - n + 1)] - - def _tokenize(self, text: str) -> List[str]: - text = str(text or "").strip() - if not text: - return [] - - mode = self.config.tokenizer_mode - if mode == "jieba": - tokens = self._tokenize_jieba(text) - if tokens: - return list(dict.fromkeys(tokens)) - return self._tokenize_char_ngram(text, self.config.char_ngram_n) - - if mode == "mixed": - toks = self._tokenize_jieba(text) - toks.extend(self._tokenize_char_ngram(text, self.config.char_ngram_n)) - return list(dict.fromkeys([t for t in toks if t])) - - return list(dict.fromkeys(self._tokenize_char_ngram(text, self.config.char_ngram_n))) - - def _build_match_query(self, tokens: List[str]) -> str: - safe_tokens: List[str] = [] - for token in tokens: - t = token.replace('"', '""').strip() - if not t: - continue - safe_tokens.append(f'"{t}"') - if not safe_tokens: - return "" - # 采用 OR 提升召回,再交由 RRF 和阈值做稳健排序。 - return " OR ".join(safe_tokens[:64]) - - def _fallback_substring_search( - self, - tokens: List[str], - limit: int, - ) -> List[Dict[str, Any]]: - """ - 当 FTS5 因分词不一致召回为空时,退化为子串匹配召回。 - - 说明: - - FTS 索引当前采用 unicode61 tokenizer。 - - 若查询 token 来源为 char n-gram 或中文词元,可能与索引 token 不一致。 - - 这里使用 SQL LIKE 做兜底,按命中 token 覆盖度打分。 - """ - if not tokens: - return [] - - # 去重并裁剪 token 数量,避免生成超长 SQL。 - uniq_tokens = [t for t in dict.fromkeys(tokens) if t] - uniq_tokens = uniq_tokens[:32] - if not uniq_tokens: - return [] - - if self.config.enable_ngram_fallback_index: - try: - # 允许运行时切换开关后按需补齐 schema/回填。 - self.metadata_store.ensure_paragraph_ngram_schema(conn=self._conn) - self.metadata_store.ensure_paragraph_ngram_backfilled( - n=self.config.char_ngram_n, - conn=self._conn, - ) - rows = self.metadata_store.ngram_search_paragraphs( - tokens=uniq_tokens, - limit=limit, - max_doc_len=self.config.max_doc_len, - conn=self._conn, - ) - if rows: - return rows - except Exception as e: - logger.warning(f"ngram 倒排回退失败,将按配置决定是否使用 LIKE 回退: {e}") - - if not self.config.enable_like_fallback: - return [] - - conditions = " OR ".join(["p.content LIKE ?"] * len(uniq_tokens)) - params: List[Any] = [f"%{tok}%" for tok in uniq_tokens] - scan_limit = max(int(limit) * 8, 200) - params.append(scan_limit) - - sql = f""" - SELECT p.hash, p.content - FROM paragraphs p - WHERE (p.is_deleted IS NULL OR p.is_deleted = 0) - AND ({conditions}) - LIMIT ? 
- """ - rows = self.metadata_store.query(sql, tuple(params)) - if not rows: - return [] - - scored: List[Dict[str, Any]] = [] - token_count = max(1, len(uniq_tokens)) - for row in rows: - content = str(row.get("content") or "") - content_low = content.lower() - matched = [tok for tok in uniq_tokens if tok in content_low] - if not matched: - continue - coverage = len(matched) / token_count - length_bonus = sum(len(tok) for tok in matched) / max(1, len(content_low)) - # 兜底路径使用相对分,保持与上层接口兼容。 - fallback_score = coverage * 0.8 + length_bonus * 0.2 - scored.append( - { - "hash": row["hash"], - "content": content[: self.config.max_doc_len] if self.config.max_doc_len > 0 else content, - "bm25_score": -float(fallback_score), - "fallback_score": float(fallback_score), - } - ) - - scored.sort(key=lambda x: x["fallback_score"], reverse=True) - return scored[:limit] - - def search(self, query: str, k: int = 20) -> List[Dict[str, Any]]: - """执行 BM25 检索。""" - if not self.config.enabled: - return [] - if self.config.lazy_load and not self.loaded: - if not self.ensure_loaded(): - return [] - if not self.loaded: - return [] - # 关系稀疏检索可独立开关,运行时开启后也能按需补齐 schema/回填。 - self.metadata_store.ensure_relations_fts_schema(conn=self._conn) - self.metadata_store.ensure_relations_fts_backfilled(conn=self._conn) - - tokens = self._tokenize(query) - match_query = self._build_match_query(tokens) - if not match_query: - return [] - - limit = max(1, int(k)) - rows = self.metadata_store.fts_search_bm25( - match_query=match_query, - limit=limit, - max_doc_len=self.config.max_doc_len, - conn=self._conn, - ) - if not rows: - rows = self._fallback_substring_search(tokens=tokens, limit=limit) - - results: List[Dict[str, Any]] = [] - for rank, row in enumerate(rows, start=1): - bm25_score = float(row.get("bm25_score", 0.0)) - results.append( - { - "hash": row["hash"], - "content": row["content"], - "rank": rank, - "bm25_score": bm25_score, - "score": -bm25_score, # bm25 越小越相关,这里取反作为相对分数 - } - ) - return results - - def search_relations(self, query: str, k: int = 20) -> List[Dict[str, Any]]: - """执行关系稀疏检索(FTS5 + BM25)。""" - if not self.config.enabled or not self.config.enable_relation_sparse_fallback: - return [] - if self.config.lazy_load and not self.loaded: - if not self.ensure_loaded(): - return [] - if not self.loaded: - return [] - - tokens = self._tokenize(query) - match_query = self._build_match_query(tokens) - if not match_query: - return [] - - rows = self.metadata_store.fts_search_relations_bm25( - match_query=match_query, - limit=max(1, int(k)), - max_doc_len=self.config.relation_max_doc_len, - conn=self._conn, - ) - out: List[Dict[str, Any]] = [] - for rank, row in enumerate(rows, start=1): - bm25_score = float(row.get("bm25_score", 0.0)) - out.append( - { - "hash": row["hash"], - "subject": row["subject"], - "predicate": row["predicate"], - "object": row["object"], - "content": row["content"], - "rank": rank, - "bm25_score": bm25_score, - "score": -bm25_score, - } - ) - return out - - def upsert_paragraph(self, paragraph_hash: str) -> bool: - if not self.loaded: - return False - return self.metadata_store.fts_upsert_paragraph(paragraph_hash, conn=self._conn) - - def delete_paragraph(self, paragraph_hash: str) -> bool: - if not self.loaded: - return False - return self.metadata_store.fts_delete_paragraph(paragraph_hash, conn=self._conn) - - def unload(self) -> None: - """卸载 BM25 连接并尽量释放内存。""" - if self._conn is not None: - try: - if self.config.shrink_memory_on_unload: - 
self.metadata_store.shrink_memory(conn=self._conn)
-            except Exception:
-                pass
-            try:
-                self._conn.close()
-            except Exception:
-                pass
-        self._conn = None
-        self._loaded = False
-        logger.info("SparseBM25Index unloaded")
-
-    def stats(self) -> Dict[str, Any]:
-        doc_count = 0
-        if self.loaded:
-            doc_count = self.metadata_store.fts_doc_count(conn=self._conn)
-        return {
-            "enabled": self.config.enabled,
-            "backend": self.config.backend,
-            "mode": self.config.mode,
-            "tokenizer_mode": self.config.tokenizer_mode,
-            "enable_ngram_fallback_index": self.config.enable_ngram_fallback_index,
-            "enable_like_fallback": self.config.enable_like_fallback,
-            "enable_relation_sparse_fallback": self.config.enable_relation_sparse_fallback,
-            "loaded": self.loaded,
-            "has_jieba": HAS_JIEBA,
-            "doc_count": doc_count,
-        }
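The FTS5/BM25 contract the adapter above builds on can be exercised with the sqlite3 standard library alone, assuming the interpreter's SQLite build ships the FTS5 extension (which this plugin also requires). Note that bm25() returns lower values for better matches, which is why search() negates it into a relative score. Sample documents and tokens below are made up.

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE VIRTUAL TABLE docs USING fts5(content)")
conn.executemany(
    "INSERT INTO docs(content) VALUES (?)",
    [("alpha beta gamma",), ("beta beta delta",), ("unrelated text",)],
)
# Quoted, OR-joined tokens, mirroring _build_match_query above.
match = '"beta" OR "delta"'
rows = conn.execute(
    "SELECT content, bm25(docs) AS score FROM docs WHERE docs MATCH ? ORDER BY bm25(docs)",
    (match,),
).fetchall()
for content, score in rows:
    print(f"{score:+.3f}  {content}")  # most relevant (lowest bm25) first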
diff --git a/plugins/A_memorix/core/retrieval/threshold.py b/plugins/A_memorix/core/retrieval/threshold.py
deleted file mode 100644
index 87a0094b..00000000
--- a/plugins/A_memorix/core/retrieval/threshold.py
+++ /dev/null
@@ -1,450 +0,0 @@
-"""
-Dynamic threshold filter.
-
-Adaptively adjusts the filtering threshold based on the distribution of
-retrieval scores.
-"""
-
-import numpy as np
-from typing import List, Dict, Any, Optional, Tuple, Union
-from dataclasses import dataclass
-from enum import Enum
-
-from src.common.logger import get_logger
-from .dual_path import RetrievalResult
-
-logger = get_logger("A_Memorix.DynamicThresholdFilter")
-
-
-class ThresholdMethod(Enum):
-    """Threshold computation methods."""
-
-    PERCENTILE = "percentile"        # Percentile
-    STD_DEV = "std_dev"              # Standard deviation
-    GAP_DETECTION = "gap_detection"  # Gap (jump) detection
-    ADAPTIVE = "adaptive"            # Adaptive (combines the other methods)
-
-
-@dataclass
-class ThresholdConfig:
-    """
-    Threshold configuration.
-
-    Attributes:
-        method: Threshold computation method.
-        min_threshold: Minimum threshold (absolute).
-        max_threshold: Maximum threshold (absolute).
-        percentile: Percentile (for the percentile method).
-        std_multiplier: Standard-deviation multiplier (for the std_dev method).
-        min_results: Minimum number of results to keep.
-        enable_auto_adjust: Whether to auto-adjust parameters.
-    """
-
-    method: ThresholdMethod = ThresholdMethod.ADAPTIVE
-    min_threshold: float = 0.3
-    max_threshold: float = 0.95
-    percentile: float = 75.0     # Percentile
-    std_multiplier: float = 1.5  # Standard-deviation multiplier
-    min_results: int = 3         # Minimum number of results to keep
-    enable_auto_adjust: bool = True
-
-    def __post_init__(self):
-        """Validate the configuration."""
-        if not 0 <= self.min_threshold <= 1:
-            raise ValueError(f"min_threshold must be in [0, 1]: {self.min_threshold}")
-
-        if not 0 <= self.max_threshold <= 1:
-            raise ValueError(f"max_threshold must be in [0, 1]: {self.max_threshold}")
-
-        if self.min_threshold >= self.max_threshold:
-            raise ValueError(
-                f"min_threshold must be less than max_threshold: "
-                f"{self.min_threshold} >= {self.max_threshold}"
-            )
-
-        if not 0 <= self.percentile <= 100:
-            raise ValueError(f"percentile must be in [0, 100]: {self.percentile}")
-
-        if self.std_multiplier <= 0:
-            raise ValueError(f"std_multiplier must be positive: {self.std_multiplier}")
-
-        if self.min_results < 0:
-            raise ValueError(f"min_results must be non-negative: {self.min_results}")
-
-
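Before the filter class itself, a small numeric sketch of the three base methods on a toy score array (values made up), mirroring the formulas in the methods below, including the corrected largest-drop selection in _gap_detection_threshold.

import numpy as np

scores = np.array([0.92, 0.88, 0.85, 0.55, 0.50])

# Percentile method: keep roughly the top (100 - percentile)% of results.
t_pct = np.percentile(scores, 75.0)           # 0.88

# Std-dev method: mean minus a multiple of the standard deviation.
t_std = scores.mean() - 1.5 * scores.std()    # ~0.47

# Gap detection: the largest drop in the descending sequence; the diffs are
# non-positive, so the biggest drop is the minimum.
desc = np.sort(scores)[::-1]
gaps = np.diff(desc)                          # [-0.04, -0.03, -0.30, -0.05]
t_gap = desc[np.argmin(gaps) + 1]             # 0.55, the score just below the cliff

# Adaptive method: median of the three, then clipped to [min, max].
t = float(np.clip(np.median([t_pct, t_std, t_gap]), 0.3, 0.95))
print(t_pct, t_std, t_gap, t)                 # the adaptive method picks 0.55 here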
-class DynamicThresholdFilter:
-    """
-    Dynamic threshold filter.
-
-    Features:
-    - Adaptively computes a threshold from the result distribution
-    - Multiple threshold computation methods
-    - Automatic parameter adjustment
-    - Statistics collection
-
-    Args:
-        config: Threshold configuration.
-    """
-
-    def __init__(
-        self,
-        config: Optional[ThresholdConfig] = None,
-    ):
-        """
-        Initialize the dynamic threshold filter.
-
-        Args:
-            config: Threshold configuration.
-        """
-        self.config = config or ThresholdConfig()
-
-        # Statistics
-        self._total_filtered = 0
-        self._total_processed = 0
-        self._threshold_history: List[float] = []
-
-        logger.info(
-            f"DynamicThresholdFilter initialized: "
-            f"method={self.config.method.value}, "
-            f"min_threshold={self.config.min_threshold}"
-        )
-
-    def filter(
-        self,
-        results: List[RetrievalResult],
-        return_threshold: bool = False,
-    ) -> Union[List[RetrievalResult], Tuple[List[RetrievalResult], float]]:
-        """
-        Filter retrieval results.
-
-        Args:
-            results: Retrieval result list.
-            return_threshold: Whether to also return the threshold used.
-
-        Returns:
-            The filtered result list, or a (results, threshold) tuple.
-        """
-        if not results:
-            logger.debug("Result list is empty; nothing to filter")
-            return ([], 0.0) if return_threshold else []
-
-        self._total_processed += len(results)
-
-        # Extract scores.
-        scores = np.array([r.score for r in results])
-
-        # Compute the threshold.
-        threshold = self._compute_threshold(scores, results)
-
-        # Record the threshold.
-        self._threshold_history.append(threshold)
-
-        # Apply the threshold filter.
-        filtered_results = [
-            r for r in results
-            if r.score >= threshold
-        ]
-
-        # Keep at least min_results results.
-        if len(filtered_results) < self.config.min_results:
-            # Sort by score and take the top min_results.
-            sorted_results = sorted(results, key=lambda x: x.score, reverse=True)
-            filtered_results = sorted_results[:self.config.min_results]
-            threshold = filtered_results[-1].score if filtered_results else 0.0
-
-        self._total_filtered += len(results) - len(filtered_results)
-
-        logger.info(
-            f"Filtering finished: {len(results)} -> {len(filtered_results)} "
-            f"(threshold={threshold:.3f})"
-        )
-
-        if return_threshold:
-            return filtered_results, threshold
-        return filtered_results
-
-    def _compute_threshold(
-        self,
-        scores: np.ndarray,
-        results: List[RetrievalResult],
-    ) -> float:
-        """
-        Compute the threshold.
-
-        Args:
-            scores: Score array.
-            results: Retrieval result list.
-
-        Returns:
-            The threshold.
-        """
-        if self.config.method == ThresholdMethod.PERCENTILE:
-            threshold = self._percentile_threshold(scores)
-        elif self.config.method == ThresholdMethod.STD_DEV:
-            threshold = self._std_dev_threshold(scores)
-        elif self.config.method == ThresholdMethod.GAP_DETECTION:
-            threshold = self._gap_detection_threshold(scores)
-        else:  # ADAPTIVE
-            # Adaptive: combine the methods.
-            thresholds = [
-                self._percentile_threshold(scores),
-                self._std_dev_threshold(scores),
-                self._gap_detection_threshold(scores),
-            ]
-            # Use the median as the final threshold.
-            threshold = float(np.median(thresholds))
-
-        # Clamp to [min_threshold, max_threshold].
-        threshold = np.clip(
-            threshold,
-            self.config.min_threshold,
-            self.config.max_threshold,
-        )
-
-        # Auto-adjust.
-        if self.config.enable_auto_adjust:
-            threshold = self._auto_adjust_threshold(threshold, scores)
-
-        return float(threshold)
-
-    def _percentile_threshold(self, scores: np.ndarray) -> float:
-        """
-        Compute a percentile-based threshold.
-
-        Args:
-            scores: Score array.
-
-        Returns:
-            The threshold.
-        """
-        percentile = self.config.percentile
-        threshold = float(np.percentile(scores, percentile))
-
-        logger.debug(f"Percentile threshold: {threshold:.3f} (percentile={percentile})")
-        return threshold
-
-    def _std_dev_threshold(self, scores: np.ndarray) -> float:
-        """
-        Compute a standard-deviation-based threshold.
-
-        threshold = mean - std_multiplier * std
-
-        Args:
-            scores: Score array.
-
-        Returns:
-            The threshold.
-        """
-        mean = float(np.mean(scores))
-        std = float(np.std(scores))
-        multiplier = self.config.std_multiplier
-
-        threshold = mean - multiplier * std
-
-        logger.debug(f"Std-dev threshold: {threshold:.3f} (mean={mean:.3f}, std={std:.3f})")
-        return threshold
-
-    def _gap_detection_threshold(self, scores: np.ndarray) -> float:
-        """
-        Compute a gap-detection-based threshold.
-
-        Finds the largest "jump" in the score distribution and uses it as
-        the threshold.
-
-        Args:
-            scores: Score array (sorted in descending order internally).
-
-        Returns:
-            The threshold.
-        """
-        # Sort in descending order.
-        sorted_scores = np.sort(scores)[::-1]
-
-        if len(sorted_scores) < 2:
-            return float(sorted_scores[0]) if len(sorted_scores) > 0 else 0.0
-
-        # Differences between adjacent scores; non-positive for a descending
-        # sort, so the largest drop is the minimum value.
-        gaps = np.diff(sorted_scores)
-
-        # Locate the largest drop (was np.argmax, which on non-positive
-        # diffs picked the smallest gap instead).
-        max_gap_idx = int(np.argmin(gaps))
-
-        # The threshold is the score just after the jump.
-        threshold = float(sorted_scores[max_gap_idx + 1])
-
-        logger.debug(
-            f"Gap-detection threshold: {threshold:.3f} "
-            f"(gap={gaps[max_gap_idx]:.3f}, idx={max_gap_idx})"
-        )
-        return threshold
-
-    def _auto_adjust_threshold(
-        self,
-        threshold: float,
-        scores: np.ndarray,
-    ) -> float:
-        """
-        Auto-adjust the threshold.
-
-        Adjusts based on the threshold history and the current score
-        distribution.
-
-        Args:
-            threshold: Current threshold.
-            scores: Score array.
-
-        Returns:
-            The adjusted threshold.
-        """
-        if not self._threshold_history:
-            return threshold
-
-        # Moving average over the threshold history.
-        recent_thresholds = self._threshold_history[-10:]  # last 10 runs
-        avg_threshold = float(np.mean(recent_thresholds))
-
-        # Difference between the current threshold and the historical average.
-        diff = threshold - avg_threshold
-
-        # If the difference is large (> 0.2), move toward the historical average.
-        if abs(diff) > 0.2:
-            adjusted_threshold = avg_threshold + diff * 0.5  # move halfway back
-            logger.debug(
-                f"Threshold adjusted: {threshold:.3f} -> {adjusted_threshold:.3f} "
-                f"(history avg={avg_threshold:.3f})"
-            )
-            return adjusted_threshold
-
-        return threshold
-
-    def filter_by_confidence(
-        self,
-        results: List[RetrievalResult],
-        min_confidence: float = 0.5,
-    ) -> List[RetrievalResult]:
-        """
-        Filter results by confidence.
-
-        Args:
-            results: Retrieval result list.
-            min_confidence: Minimum confidence.
-
-        Returns:
-            The filtered result list.
-        """
-        filtered = []
-        for result in results:
-            # Relation results carry an explicit confidence field.
-            if result.result_type == "relation":
-                confidence = result.metadata.get("confidence", 1.0)
-                if confidence >= min_confidence:
-                    filtered.append(result)
-            else:
-                # Paragraph results: use the score directly.
-                if result.score >= min_confidence:
-                    filtered.append(result)
-
-        logger.info(
-            f"Confidence filtering: {len(results)} -> {len(filtered)} "
-            f"(min_confidence={min_confidence})"
-        )
-
-        return filtered
-
-    def filter_by_diversity(
-        self,
-        results: List[RetrievalResult],
-        similarity_threshold: float = 0.9,
-        top_k: int = 10,
-    ) -> List[RetrievalResult]:
-        """
-        Filter results for diversity (deduplication).
-
-        Note: despite its name, similarity_threshold is currently unused;
-        duplicates are detected via an 8-character hash-prefix comparison.
-
-        Args:
-            results: Retrieval result list.
-            similarity_threshold: Similarity threshold (above it counts as duplicate).
-            top_k: Maximum number of results to keep.
-
-        Returns:
-            The filtered result list.
-        """
-        if not results:
-            return []
-
-        # Sort by score.
-        sorted_results = sorted(results, key=lambda x: x.score, reverse=True)
-
-        # Greedy selection: keep results dissimilar to those already selected.
-        selected = []
-        selected_hashes = []
-
-        for result in sorted_results:
-            if len(selected) >= top_k:
-                break
-
-            # Check similarity against the selected results.
-            is_duplicate = False
-            for selected_hash in selected_hashes:
-                # Simplified check: compare hash prefixes.
-                if result.hash_value[:8] == selected_hash[:8]:
-                    is_duplicate = True
-                    break
-
-            if not is_duplicate:
-                selected.append(result)
-                selected_hashes.append(result.hash_value)
-
-        logger.info(
-            f"Diversity filtering: {len(results)} -> {len(selected)} "
-            f"(similarity_threshold={similarity_threshold})"
-        )
-
-        return selected
-
-    def get_statistics(self) -> Dict[str, Any]:
-        """
-        Get statistics.
-
-        Returns:
-            Statistics dictionary.
-        """
-        filter_rate = (
-            self._total_filtered / self._total_processed
-            if self._total_processed > 0
-            else 0.0
-        )
-
-        stats = {
-            "config": {
-                "method": self.config.method.value,
-                "min_threshold": self.config.min_threshold,
-                "max_threshold": self.config.max_threshold,
-                "percentile": self.config.percentile,
-                "std_multiplier": self.config.std_multiplier,
-                "min_results": self.config.min_results,
-                "enable_auto_adjust": self.config.enable_auto_adjust,
-            },
-            "statistics": {
-                "total_processed": self._total_processed,
-                "total_filtered": self._total_filtered,
-                "filter_rate": filter_rate,
-                "avg_threshold": float(np.mean(self._threshold_history))
-                if self._threshold_history else 0.0,
-                "threshold_count": len(self._threshold_history),
-            },
-        }
-
-        if self._threshold_history:
-            stats["statistics"]["min_threshold_used"] = float(np.min(self._threshold_history))
-            stats["statistics"]["max_threshold_used"] = float(np.max(self._threshold_history))
-
-        return stats
-
-    def reset_statistics(self) -> None:
-        """Reset statistics."""
-        self._total_filtered = 0
-        self._total_processed = 0
-        self._threshold_history.clear()
-        logger.info("Statistics reset")
-
-    def __repr__(self) -> str:
-        return (
-            f"DynamicThresholdFilter("
-            f"method={self.config.method.value}, "
-            f"min_threshold={self.config.min_threshold}, "
f"filtered={self._total_filtered}/{self._total_processed})" - ) diff --git a/plugins/A_memorix/core/runtime/__init__.py b/plugins/A_memorix/core/runtime/__init__.py deleted file mode 100644 index eece6d21..00000000 --- a/plugins/A_memorix/core/runtime/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -"""SDK runtime exports for A_Memorix.""" - -from .search_runtime_initializer import ( - SearchRuntimeBundle, - SearchRuntimeInitializer, - build_search_runtime, -) -from .sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel - -__all__ = [ - "SearchRuntimeBundle", - "SearchRuntimeInitializer", - "build_search_runtime", - "KernelSearchRequest", - "SDKMemoryKernel", -] diff --git a/plugins/A_memorix/core/runtime/lifecycle_orchestrator.py b/plugins/A_memorix/core/runtime/lifecycle_orchestrator.py deleted file mode 100644 index 423b55c4..00000000 --- a/plugins/A_memorix/core/runtime/lifecycle_orchestrator.py +++ /dev/null @@ -1,268 +0,0 @@ -"""Lifecycle bootstrap/teardown helpers extracted from plugin.py.""" - -from __future__ import annotations - -import asyncio -from pathlib import Path -from typing import Any - -from src.common.logger import get_logger - -from ..embedding import create_embedding_api_adapter -from ..retrieval import SparseBM25Config, SparseBM25Index -from ..storage import ( - GraphStore, - MetadataStore, - QuantizationType, - SparseMatrixFormat, - VectorStore, -) -from ..utils.runtime_self_check import ensure_runtime_self_check -from ..utils.relation_write_service import RelationWriteService - -logger = get_logger("A_Memorix.LifecycleOrchestrator") - - -async def ensure_initialized(plugin: Any) -> None: - if plugin._initialized: - plugin._runtime_ready = plugin._check_storage_ready() - return - - async with plugin._init_lock: - if plugin._initialized: - plugin._runtime_ready = plugin._check_storage_ready() - return - - logger.info("A_Memorix 插件正在异步初始化存储组件...") - plugin._validate_runtime_config() - await initialize_storage_async(plugin) - report = await ensure_runtime_self_check(plugin, force=True) - if not bool(report.get("ok", False)): - logger.error( - "A_Memorix runtime self-check failed: " - f"{report.get('message', 'unknown')}; " - "建议执行 python plugins/A_memorix/scripts/runtime_self_check.py --json" - ) - - if plugin.graph_store and plugin.metadata_store: - relation_count = plugin.metadata_store.count_relations() - if relation_count > 0 and not plugin.graph_store.has_edge_hash_map(): - raise RuntimeError( - "检测到 relations 数据存在但 edge-hash-map 为空。" - " 请先执行 scripts/release_vnext_migrate.py migrate。" - ) - - plugin._initialized = True - plugin._runtime_ready = plugin._check_storage_ready() - plugin._update_plugin_config() - logger.info("A_Memorix 插件异步初始化成功") - - -def start_background_tasks(plugin: Any) -> None: - """Start background tasks idempotently.""" - if not hasattr(plugin, "_episode_generation_task"): - plugin._episode_generation_task = None - - if ( - plugin.get_config("summarization.enabled", True) - and plugin.get_config("schedule.enabled", True) - and (plugin._scheduled_import_task is None or plugin._scheduled_import_task.done()) - ): - plugin._scheduled_import_task = asyncio.create_task(plugin._scheduled_import_loop()) - - if ( - plugin.get_config("advanced.enable_auto_save", True) - and (plugin._auto_save_task is None or plugin._auto_save_task.done()) - ): - plugin._auto_save_task = asyncio.create_task(plugin._auto_save_loop()) - - if ( - plugin.get_config("person_profile.enabled", True) - and (plugin._person_profile_refresh_task is None or 
plugin._person_profile_refresh_task.done()) - ): - plugin._person_profile_refresh_task = asyncio.create_task(plugin._person_profile_refresh_loop()) - - if plugin._memory_maintenance_task is None or plugin._memory_maintenance_task.done(): - plugin._memory_maintenance_task = asyncio.create_task(plugin._memory_maintenance_loop()) - - rv_cfg = plugin.get_config("retrieval.relation_vectorization", {}) or {} - if isinstance(rv_cfg, dict): - rv_enabled = bool(rv_cfg.get("enabled", False)) - rv_backfill = bool(rv_cfg.get("backfill_enabled", False)) - else: - rv_enabled = False - rv_backfill = False - if rv_enabled and rv_backfill and ( - plugin._relation_vector_backfill_task is None or plugin._relation_vector_backfill_task.done() - ): - plugin._relation_vector_backfill_task = asyncio.create_task(plugin._relation_vector_backfill_loop()) - - episode_task = getattr(plugin, "_episode_generation_task", None) - episode_loop = getattr(plugin, "_episode_generation_loop", None) - if ( - callable(episode_loop) - and bool(plugin.get_config("episode.enabled", True)) - and bool(plugin.get_config("episode.generation_enabled", True)) - and (episode_task is None or episode_task.done()) - ): - plugin._episode_generation_task = asyncio.create_task(episode_loop()) - - -async def cancel_background_tasks(plugin: Any) -> None: - """Cancel all background tasks and wait for cleanup.""" - tasks = [ - ("scheduled_import", plugin._scheduled_import_task), - ("auto_save", plugin._auto_save_task), - ("person_profile_refresh", plugin._person_profile_refresh_task), - ("memory_maintenance", plugin._memory_maintenance_task), - ("relation_vector_backfill", plugin._relation_vector_backfill_task), - ("episode_generation", getattr(plugin, "_episode_generation_task", None)), - ] - for _, task in tasks: - if task and not task.done(): - task.cancel() - - for name, task in tasks: - if not task: - continue - try: - await task - except asyncio.CancelledError: - pass - except Exception as e: - logger.warning(f"后台任务 {name} 退出异常: {e}") - - plugin._scheduled_import_task = None - plugin._auto_save_task = None - plugin._person_profile_refresh_task = None - plugin._memory_maintenance_task = None - plugin._relation_vector_backfill_task = None - plugin._episode_generation_task = None - - -async def initialize_storage_async(plugin: Any) -> None: - """Initialize storage components asynchronously.""" - data_dir_str = plugin.get_config("storage.data_dir", "./data") - if data_dir_str.startswith("."): - plugin_dir = Path(__file__).resolve().parents[2] - data_dir = (plugin_dir / data_dir_str).resolve() - else: - data_dir = Path(data_dir_str) - - logger.info(f"A_Memorix 数据存储路径: {data_dir}") - data_dir.mkdir(parents=True, exist_ok=True) - - plugin.embedding_manager = create_embedding_api_adapter( - batch_size=plugin.get_config("embedding.batch_size", 32), - max_concurrent=plugin.get_config("embedding.max_concurrent", 5), - default_dimension=plugin.get_config("embedding.dimension", 1024), - model_name=plugin.get_config("embedding.model_name", "auto"), - retry_config=plugin.get_config("embedding.retry", {}), - ) - logger.info("嵌入 API 适配器初始化完成") - - try: - detected_dimension = await plugin.embedding_manager._detect_dimension() - logger.info(f"嵌入维度检测成功: {detected_dimension}") - except Exception as e: - logger.warning(f"嵌入维度检测失败: {e},使用默认值") - detected_dimension = plugin.embedding_manager.default_dimension - - quantization_str = plugin.get_config("embedding.quantization_type", "int8") - if str(quantization_str or "").strip().lower() != "int8": - raise 
ValueError("embedding.quantization_type 在 vNext 仅允许 int8(SQ8)。") - quantization_type = QuantizationType.INT8 - - plugin.vector_store = VectorStore( - dimension=detected_dimension, - quantization_type=quantization_type, - data_dir=data_dir / "vectors", - ) - plugin.vector_store.min_train_threshold = plugin.get_config("embedding.min_train_threshold", 40) - logger.info( - "向量存储初始化完成(" - f"维度: {detected_dimension}, " - f"训练阈值: {plugin.vector_store.min_train_threshold})" - ) - - matrix_format_str = plugin.get_config("graph.sparse_matrix_format", "csr") - matrix_format_map = { - "csr": SparseMatrixFormat.CSR, - "csc": SparseMatrixFormat.CSC, - } - matrix_format = matrix_format_map.get(matrix_format_str, SparseMatrixFormat.CSR) - - plugin.graph_store = GraphStore( - matrix_format=matrix_format, - data_dir=data_dir / "graph", - ) - logger.info("图存储初始化完成") - - plugin.metadata_store = MetadataStore(data_dir=data_dir / "metadata") - plugin.metadata_store.connect() - logger.info("元数据存储初始化完成") - - plugin.relation_write_service = RelationWriteService( - metadata_store=plugin.metadata_store, - graph_store=plugin.graph_store, - vector_store=plugin.vector_store, - embedding_manager=plugin.embedding_manager, - ) - logger.info("关系写入服务初始化完成") - - sparse_cfg_raw = plugin.get_config("retrieval.sparse", {}) or {} - if not isinstance(sparse_cfg_raw, dict): - sparse_cfg_raw = {} - try: - sparse_cfg = SparseBM25Config(**sparse_cfg_raw) - except Exception as e: - logger.warning(f"sparse 配置非法,回退默认配置: {e}") - sparse_cfg = SparseBM25Config() - plugin.sparse_index = SparseBM25Index( - metadata_store=plugin.metadata_store, - config=sparse_cfg, - ) - logger.info( - "稀疏检索组件初始化完成: " - f"enabled={sparse_cfg.enabled}, " - f"lazy_load={sparse_cfg.lazy_load}, " - f"mode={sparse_cfg.mode}, " - f"tokenizer={sparse_cfg.tokenizer_mode}" - ) - if sparse_cfg.enabled and not sparse_cfg.lazy_load: - plugin.sparse_index.ensure_loaded() - - if plugin.vector_store.has_data(): - try: - plugin.vector_store.load() - logger.info(f"向量数据已加载,共 {plugin.vector_store.num_vectors} 个向量") - except Exception as e: - logger.warning(f"加载向量数据失败: {e}") - - try: - warmup_summary = plugin.vector_store.warmup_index(force_train=True) - if warmup_summary.get("ok"): - logger.info( - "向量索引预热完成: " - f"trained={warmup_summary.get('trained')}, " - f"index_ntotal={warmup_summary.get('index_ntotal')}, " - f"fallback_ntotal={warmup_summary.get('fallback_ntotal')}, " - f"bin_count={warmup_summary.get('bin_count')}, " - f"duration_ms={float(warmup_summary.get('duration_ms', 0.0)):.2f}" - ) - else: - logger.warning( - "向量索引预热失败,继续启用 sparse 降级路径: " - f"{warmup_summary.get('error', 'unknown')}" - ) - except Exception as e: - logger.warning(f"向量索引预热异常,继续启用 sparse 降级路径: {e}") - - if plugin.graph_store.has_data(): - try: - plugin.graph_store.load() - logger.info(f"图数据已加载,共 {plugin.graph_store.num_nodes} 个节点") - except Exception as e: - logger.warning(f"加载图数据失败: {e}") - - logger.info(f"知识库数据目录: {data_dir}") diff --git a/plugins/A_memorix/core/runtime/sdk_memory_kernel.py b/plugins/A_memorix/core/runtime/sdk_memory_kernel.py deleted file mode 100644 index 93c11bf7..00000000 --- a/plugins/A_memorix/core/runtime/sdk_memory_kernel.py +++ /dev/null @@ -1,3162 +0,0 @@ -from __future__ import annotations - -import asyncio -import json -import pickle -import time -import uuid -from dataclasses import dataclass -from pathlib import Path -from typing import Any, Awaitable, Callable, Dict, Iterable, List, Optional, Sequence - -from src.common.logger import get_logger - -from ..embedding 
import create_embedding_api_adapter -from ..retrieval import RetrievalResult, SparseBM25Config, SparseBM25Index, TemporalQueryOptions -from ..storage import GraphStore, MetadataStore, QuantizationType, SparseMatrixFormat, VectorStore -from ..utils.aggregate_query_service import AggregateQueryService -from ..utils.episode_retrieval_service import EpisodeRetrievalService -from ..utils.episode_segmentation_service import EpisodeSegmentationService -from ..utils.episode_service import EpisodeService -from ..utils.hash import compute_hash, normalize_text -from ..utils.person_profile_service import PersonProfileService -from ..utils.relation_write_service import RelationWriteService -from ..utils.retrieval_tuning_manager import RetrievalTuningManager -from ..utils.runtime_self_check import run_embedding_runtime_self_check -from ..utils.search_execution_service import SearchExecutionRequest, SearchExecutionService -from ..utils.summary_importer import SummaryImporter -from ..utils.time_parser import format_timestamp, parse_query_datetime_to_timestamp -from ..utils.web_import_manager import ImportTaskManager -from .search_runtime_initializer import SearchRuntimeBundle, build_search_runtime - -logger = get_logger("A_Memorix.SDKMemoryKernel") - - -@dataclass -class KernelSearchRequest: - query: str = "" - limit: int = 5 - mode: str = "search" - chat_id: str = "" - person_id: str = "" - time_start: Optional[str | float] = None - time_end: Optional[str | float] = None - respect_filter: bool = True - user_id: str = "" - group_id: str = "" - - -@dataclass -class _NormalizedSearchTimeWindow: - numeric_start: Optional[float] = None - numeric_end: Optional[float] = None - query_start: Optional[str] = None - query_end: Optional[str] = None - - -class _KernelRuntimeFacade: - def __init__(self, kernel: "SDKMemoryKernel") -> None: - self._kernel = kernel - self.config = kernel.config - self._plugin_config = kernel.config - self._runtime_self_check_report: Dict[str, Any] = {} - - def get_config(self, key: str, default: Any = None) -> Any: - return self._kernel._cfg(key, default) - - def is_runtime_ready(self) -> bool: - return self._kernel.is_runtime_ready() - - def is_chat_enabled(self, stream_id: str, group_id: str | None = None, user_id: str | None = None) -> bool: - return self._kernel.is_chat_enabled(stream_id=stream_id, group_id=group_id, user_id=user_id) - - async def reinforce_access(self, relation_hashes: Sequence[str]) -> None: - if self._kernel.metadata_store is None: - return - hashes = [str(item or "").strip() for item in relation_hashes if str(item or "").strip()] - if not hashes: - return - self._kernel.metadata_store.reinforce_relations(hashes) - self._kernel._last_maintenance_at = time.time() - - async def execute_request_with_dedup( - self, - request_key: str, - executor: Callable[[], Awaitable[Dict[str, Any]]], - ) -> tuple[bool, Dict[str, Any]]: - return await self._kernel.execute_request_with_dedup(request_key, executor) - - @property - def vector_store(self) -> Optional[VectorStore]: - return self._kernel.vector_store - - @property - def graph_store(self) -> Optional[GraphStore]: - return self._kernel.graph_store - - @property - def metadata_store(self) -> Optional[MetadataStore]: - return self._kernel.metadata_store - - @property - def embedding_manager(self): - return self._kernel.embedding_manager - - @property - def sparse_index(self): - return self._kernel.sparse_index - - @property - def relation_write_service(self) -> Optional[RelationWriteService]: - return 
self._kernel.relation_write_service - - -class SDKMemoryKernel: - def __init__(self, *, plugin_root: Path, config: Optional[Dict[str, Any]] = None) -> None: - self.plugin_root = Path(plugin_root).resolve() - self.config = config or {} - storage_cfg = self._cfg("storage", {}) or {} - data_dir = str(storage_cfg.get("data_dir", "./data") or "./data") - self.data_dir = (self.plugin_root / data_dir).resolve() if data_dir.startswith(".") else Path(data_dir) - self.embedding_dimension = max(1, int(self._cfg("embedding.dimension", 1024))) - self.relation_vectors_enabled = bool(self._cfg("retrieval.relation_vectorization.enabled", False)) - - self.embedding_manager = None - self.vector_store: Optional[VectorStore] = None - self.graph_store: Optional[GraphStore] = None - self.metadata_store: Optional[MetadataStore] = None - self.relation_write_service: Optional[RelationWriteService] = None - self.sparse_index: Optional[SparseBM25Index] = None - self.retriever = None - self.threshold_filter = None - self.episode_retriever: Optional[EpisodeRetrievalService] = None - self.aggregate_query_service: Optional[AggregateQueryService] = None - self.person_profile_service: Optional[PersonProfileService] = None - self.episode_segmentation_service: Optional[EpisodeSegmentationService] = None - self.episode_service: Optional[EpisodeService] = None - self.summary_importer: Optional[SummaryImporter] = None - self.import_task_manager: Optional[ImportTaskManager] = None - self.retrieval_tuning_manager: Optional[RetrievalTuningManager] = None - self._runtime_bundle: Optional[SearchRuntimeBundle] = None - self._runtime_facade = _KernelRuntimeFacade(self) - self._initialized = False - self._last_maintenance_at: Optional[float] = None - self._request_dedup_tasks: Dict[str, asyncio.Task] = {} - self._background_tasks: Dict[str, asyncio.Task] = {} - self._background_lock = asyncio.Lock() - self._background_stopping = False - self._active_person_timestamps: Dict[str, float] = {} - - def _cfg(self, key: str, default: Any = None) -> Any: - current: Any = self.config - if key in {"storage", "embedding", "retrieval", "graph", "episode", "web", "advanced", "threshold", "summarization"} and isinstance(current, dict): - return current.get(key, default) - for part in key.split("."): - if isinstance(current, dict) and part in current: - current = current[part] - else: - return default - return current - - def _set_cfg(self, key: str, value: Any) -> None: - current: Dict[str, Any] = self.config - parts = [part for part in str(key or "").split(".") if part] - if not parts: - return - for part in parts[:-1]: - next_value = current.get(part) - if not isinstance(next_value, dict): - next_value = {} - current[part] = next_value - current = next_value - current[parts[-1]] = value - - def _build_runtime_config(self) -> Dict[str, Any]: - runtime_config = dict(self.config) - runtime_config.update( - { - "vector_store": self.vector_store, - "graph_store": self.graph_store, - "metadata_store": self.metadata_store, - "embedding_manager": self.embedding_manager, - "sparse_index": self.sparse_index, - "relation_write_service": self.relation_write_service, - "plugin_instance": self._runtime_facade, - } - ) - return runtime_config - - def is_runtime_ready(self) -> bool: - return bool( - self._initialized - and self.vector_store is not None - and self.graph_store is not None - and self.metadata_store is not None - and self.embedding_manager is not None - and self.retriever is not None - ) - - def is_chat_enabled(self, stream_id: str, group_id: str | 
None = None, user_id: str | None = None) -> bool: - filter_config = self._cfg("filter", {}) or {} - if not isinstance(filter_config, dict) or not filter_config: - return True - - if not bool(filter_config.get("enabled", True)): - return True - - mode = str(filter_config.get("mode", "blacklist") or "blacklist").strip().lower() - patterns = filter_config.get("chats") or [] - if not isinstance(patterns, list): - patterns = [] - - if not patterns: - return mode == "blacklist" - - stream_token = str(stream_id or "").strip() - group_token = str(group_id or "").strip() - user_token = str(user_id or "").strip() - candidates = {token for token in (stream_token, group_token, user_token) if token} - - matched = False - for raw_pattern in patterns: - pattern = str(raw_pattern or "").strip() - if not pattern: - continue - if ":" in pattern: - prefix, value = pattern.split(":", 1) - prefix = prefix.strip().lower() - value = value.strip() - if prefix == "group" and value and value == group_token: - matched = True - elif prefix in {"user", "private"} and value and value == user_token: - matched = True - elif prefix == "stream" and value and value == stream_token: - matched = True - elif pattern in candidates: - matched = True - - if matched: - break - - if mode == "blacklist": - return not matched - return matched - - def _is_chat_filtered( - self, - *, - respect_filter: bool, - stream_id: str = "", - group_id: str = "", - user_id: str = "", - ) -> bool: - if not bool(respect_filter): - return False - - stream_token = str(stream_id or "").strip() - group_token = str(group_id or "").strip() - user_token = str(user_id or "").strip() - if not (stream_token or group_token or user_token): - return False - return not self.is_chat_enabled(stream_token, group_token, user_token) - - def _stored_vector_dimension(self) -> Optional[int]: - meta_path = self.data_dir / "vectors" / "vectors_metadata.pkl" - if not meta_path.exists(): - return None - try: - with open(meta_path, "rb") as handle: - meta = pickle.load(handle) - except Exception as exc: - logger.warning(f"读取向量元数据失败,将回退到 runtime self-check: {exc}") - return None - try: - value = int(meta.get("dimension") or 0) - except Exception: - return None - return value if value > 0 else None - - def _vector_mismatch_error(self, *, stored_dimension: int, detected_dimension: int) -> str: - return ( - "检测到现有向量库与当前 embedding 输出维度不一致:" - f"stored={stored_dimension}, encoded={detected_dimension}。" - " 当前版本不会兼容 hash 时代或其他维度的旧向量,请改回原 embedding 配置," - "或执行重嵌入/重建向量。" - ) - - async def initialize(self) -> None: - if self._initialized: - await self._start_background_tasks() - return - - self.data_dir.mkdir(parents=True, exist_ok=True) - self.embedding_manager = create_embedding_api_adapter( - batch_size=int(self._cfg("embedding.batch_size", 32)), - max_concurrent=int(self._cfg("embedding.max_concurrent", 5)), - default_dimension=self.embedding_dimension, - enable_cache=bool(self._cfg("embedding.enable_cache", False)), - model_name=str(self._cfg("embedding.model_name", "auto") or "auto"), - retry_config=self._cfg("embedding.retry", {}) or {}, - ) - detected_dimension = int(await self.embedding_manager._detect_dimension()) - self.embedding_dimension = detected_dimension - - stored_dimension = self._stored_vector_dimension() - if stored_dimension is not None and stored_dimension != detected_dimension: - raise RuntimeError( - self._vector_mismatch_error( - stored_dimension=stored_dimension, - detected_dimension=detected_dimension, - ) - ) - - matrix_format = 
str(self._cfg("graph.sparse_matrix_format", "csr") or "csr").strip().lower() - graph_format = SparseMatrixFormat.CSC if matrix_format == "csc" else SparseMatrixFormat.CSR - - self.vector_store = VectorStore( - dimension=detected_dimension, - quantization_type=QuantizationType.INT8, - data_dir=self.data_dir / "vectors", - ) - self.graph_store = GraphStore(matrix_format=graph_format, data_dir=self.data_dir / "graph") - self.metadata_store = MetadataStore(data_dir=self.data_dir / "metadata") - self.metadata_store.connect() - - if self.vector_store.has_data(): - self.vector_store.load() - self.vector_store.warmup_index(force_train=True) - if self.graph_store.has_data(): - self.graph_store.load() - - sparse_cfg_raw = self._cfg("retrieval.sparse", {}) or {} - try: - sparse_cfg = SparseBM25Config(**sparse_cfg_raw) - except Exception as exc: - logger.warning(f"sparse 配置非法,回退默认: {exc}") - sparse_cfg = SparseBM25Config() - self.sparse_index = SparseBM25Index(metadata_store=self.metadata_store, config=sparse_cfg) - if getattr(self.sparse_index.config, "enabled", False): - self.sparse_index.ensure_loaded() - - self.relation_write_service = RelationWriteService( - metadata_store=self.metadata_store, - graph_store=self.graph_store, - vector_store=self.vector_store, - embedding_manager=self.embedding_manager, - ) - - runtime_config = self._build_runtime_config() - self._runtime_bundle = build_search_runtime( - plugin_config=runtime_config, - logger_obj=logger, - owner_tag="sdk_kernel", - log_prefix="[sdk]", - ) - if not self._runtime_bundle.ready: - raise RuntimeError(self._runtime_bundle.error or "检索运行时初始化失败") - - self.retriever = self._runtime_bundle.retriever - self.threshold_filter = self._runtime_bundle.threshold_filter - self.sparse_index = self._runtime_bundle.sparse_index or self.sparse_index - - runtime_config = self._build_runtime_config() - self.episode_retriever = EpisodeRetrievalService(metadata_store=self.metadata_store, retriever=self.retriever) - self.aggregate_query_service = AggregateQueryService(plugin_config=runtime_config) - self.person_profile_service = PersonProfileService( - metadata_store=self.metadata_store, - graph_store=self.graph_store, - vector_store=self.vector_store, - embedding_manager=self.embedding_manager, - sparse_index=self.sparse_index, - plugin_config=runtime_config, - retriever=self.retriever, - ) - self.episode_segmentation_service = EpisodeSegmentationService(plugin_config=runtime_config) - self.episode_service = EpisodeService( - metadata_store=self.metadata_store, - plugin_config=runtime_config, - segmentation_service=self.episode_segmentation_service, - ) - self.summary_importer = SummaryImporter( - vector_store=self.vector_store, - graph_store=self.graph_store, - metadata_store=self.metadata_store, - embedding_manager=self.embedding_manager, - plugin_config=runtime_config, - ) - self.import_task_manager = ImportTaskManager(self._runtime_facade) - self.retrieval_tuning_manager = RetrievalTuningManager( - self._runtime_facade, - import_write_blocked_provider=self.import_task_manager.is_write_blocked, - ) - - report = await run_embedding_runtime_self_check( - config=runtime_config, - vector_store=self.vector_store, - embedding_manager=self.embedding_manager, - sample_text="A_Memorix runtime self check", - ) - self._runtime_facade._runtime_self_check_report = dict(report) - if not bool(report.get("ok", False)): - message = str(report.get("message", "runtime self-check failed") or "runtime self-check failed") - raise RuntimeError(f"{message};请改回原 embedding 
配置,或执行重嵌入/重建向量。") - - self._initialized = True - await self._start_background_tasks() - - async def shutdown(self) -> None: - await self._stop_background_tasks() - if self.import_task_manager is not None: - try: - await self.import_task_manager.shutdown() - except Exception as exc: - logger.warning(f"关闭导入任务管理器失败: {exc}") - if self.retrieval_tuning_manager is not None: - try: - await self.retrieval_tuning_manager.shutdown() - except Exception as exc: - logger.warning(f"关闭调优任务管理器失败: {exc}") - self.close() - - def close(self) -> None: - try: - self._persist() - finally: - if self.metadata_store is not None: - self.metadata_store.close() - self._initialized = False - self._request_dedup_tasks.clear() - self._runtime_facade._runtime_self_check_report = {} - self._background_tasks.clear() - self._active_person_timestamps.clear() - - async def execute_request_with_dedup( - self, - request_key: str, - executor: Callable[[], Awaitable[Dict[str, Any]]], - ) -> tuple[bool, Dict[str, Any]]: - token = str(request_key or "").strip() - if not token: - return False, await executor() - - existing = self._request_dedup_tasks.get(token) - if existing is not None: - return True, await existing - - task = asyncio.create_task(executor()) - self._request_dedup_tasks[token] = task - try: - payload = await task - return False, payload - finally: - current = self._request_dedup_tasks.get(token) - if current is task: - self._request_dedup_tasks.pop(token, None) - - async def summarize_chat_stream( - self, - *, - chat_id: str, - context_length: Optional[int] = None, - include_personality: Optional[bool] = None, - ) -> Dict[str, Any]: - await self.initialize() - assert self.summary_importer - success, detail = await self.summary_importer.import_from_stream( - stream_id=str(chat_id or "").strip(), - context_length=context_length, - include_personality=include_personality, - ) - if success: - await self.rebuild_episodes_for_sources([self._build_source("chat_summary", chat_id, [])]) - self._persist() - return {"success": bool(success), "detail": detail} - - async def ingest_summary( - self, - *, - external_id: str, - chat_id: str, - text: str, - participants: Optional[Sequence[str]] = None, - time_start: Optional[float] = None, - time_end: Optional[float] = None, - tags: Optional[Sequence[str]] = None, - metadata: Optional[Dict[str, Any]] = None, - respect_filter: bool = True, - user_id: str = "", - group_id: str = "", - ) -> Dict[str, Any]: - external_token = str(external_id or "").strip() or compute_hash(f"chat_summary:{chat_id}:{text}") - if self._is_chat_filtered( - respect_filter=respect_filter, - stream_id=chat_id, - group_id=group_id, - user_id=user_id, - ): - return { - "success": True, - "stored_ids": [], - "skipped_ids": [external_token], - "detail": "chat_filtered", - } - - summary_meta = dict(metadata or {}) - summary_meta.setdefault("kind", "chat_summary") - if not str(text or "").strip() or bool(summary_meta.get("generate_from_chat", False)): - result = await self.summarize_chat_stream( - chat_id=chat_id, - context_length=self._optional_int(summary_meta.get("context_length")), - include_personality=summary_meta.get("include_personality"), - ) - result.setdefault("external_id", external_id) - result.setdefault("chat_id", chat_id) - return result - return await self.ingest_text( - external_id=external_id, - source_type="chat_summary", - text=text, - chat_id=chat_id, - participants=participants, - time_start=time_start, - time_end=time_end, - tags=tags, - metadata=summary_meta, - respect_filter=respect_filter, 
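`execute_request_with_dedup` above collapses concurrent identical requests onto a single in-flight asyncio task: the first caller creates and owns the task, later callers just await it and are reported as duplicates. A self-contained sketch of that pattern (the class and names here are hypothetical):

```python
import asyncio
from typing import Any, Awaitable, Callable, Dict

class Deduper:
    """Share one in-flight task per key; only the owner cleans up."""

    def __init__(self) -> None:
        self._tasks: Dict[str, asyncio.Task] = {}

    async def run(self, key: str, executor: Callable[[], Awaitable[Any]]) -> tuple[bool, Any]:
        existing = self._tasks.get(key)
        if existing is not None:
            return True, await existing          # duplicate: piggyback on the owner's task
        task = asyncio.create_task(executor())
        self._tasks[key] = task
        try:
            return False, await task
        finally:
            if self._tasks.get(key) is task:     # guard against a newer task under the same key
                self._tasks.pop(key, None)

async def _demo() -> None:
    dedup = Deduper()

    async def slow() -> str:
        await asyncio.sleep(0.1)
        return "done"

    results = await asyncio.gather(dedup.run("k", slow), dedup.run("k", slow))
    # Exactly one caller is flagged as the duplicate.
    assert [is_dup for is_dup, _ in results].count(True) == 1

asyncio.run(_demo())
```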
- user_id=user_id, - group_id=group_id, - ) - - async def ingest_text( - self, - *, - external_id: str, - source_type: str, - text: str, - chat_id: str = "", - person_ids: Optional[Sequence[str]] = None, - participants: Optional[Sequence[str]] = None, - timestamp: Optional[float] = None, - time_start: Optional[float] = None, - time_end: Optional[float] = None, - tags: Optional[Sequence[str]] = None, - metadata: Optional[Dict[str, Any]] = None, - entities: Optional[Sequence[str]] = None, - relations: Optional[Sequence[Dict[str, Any]]] = None, - respect_filter: bool = True, - user_id: str = "", - group_id: str = "", - ) -> Dict[str, Any]: - content = normalize_text(text) - external_token = str(external_id or "").strip() or compute_hash(f"{source_type}:{chat_id}:{content}") - if self._is_chat_filtered( - respect_filter=respect_filter, - stream_id=chat_id, - group_id=group_id, - user_id=user_id, - ): - return { - "success": True, - "stored_ids": [], - "skipped_ids": [external_token], - "detail": "chat_filtered", - } - - await self.initialize() - assert self.metadata_store is not None - assert self.vector_store is not None - assert self.graph_store is not None - assert self.embedding_manager is not None - assert self.relation_write_service is not None - - if not content: - return {"stored_ids": [], "skipped_ids": [external_token], "reason": "empty_text"} - - existing_ref = self.metadata_store.get_external_memory_ref(external_token) - if existing_ref: - return { - "stored_ids": [], - "skipped_ids": [str(existing_ref.get("paragraph_hash", "") or "")], - "reason": "exists", - } - - person_tokens = self._tokens(person_ids) - participant_tokens = self._tokens(participants) - entity_tokens = self._merge_tokens(entities, person_tokens, participant_tokens) - source = self._build_source(source_type, chat_id, person_tokens) - paragraph_meta = dict(metadata or {}) - paragraph_meta.update( - { - "external_id": external_token, - "source_type": str(source_type or "").strip(), - "chat_id": str(chat_id or "").strip(), - "person_ids": person_tokens, - "participants": participant_tokens, - "tags": self._tokens(tags), - } - ) - - paragraph_hash = self.metadata_store.add_paragraph( - content=content, - source=source, - metadata=paragraph_meta, - knowledge_type=self._resolve_knowledge_type(source_type), - time_meta=self._time_meta(timestamp, time_start, time_end), - ) - embedding = await self.embedding_manager.encode(content) - self.vector_store.add(vectors=embedding.reshape(1, -1), ids=[paragraph_hash]) - - for name in entity_tokens: - self.metadata_store.add_entity(name=name, source_paragraph=paragraph_hash) - - stored_relations: List[str] = [] - for row in [dict(item) for item in (relations or []) if isinstance(item, dict)]: - subject = str(row.get("subject", "") or "").strip() - predicate = str(row.get("predicate", "") or "").strip() - obj = str(row.get("object", "") or "").strip() - if not (subject and predicate and obj): - continue - result = await self.relation_write_service.upsert_relation_with_vector( - subject=subject, - predicate=predicate, - obj=obj, - confidence=float(row.get("confidence", 1.0) or 1.0), - source_paragraph=paragraph_hash, - metadata=row.get("metadata") if isinstance(row.get("metadata"), dict) else {"external_id": external_token, "source_type": source_type}, - write_vector=self.relation_vectors_enabled, - ) - self.metadata_store.link_paragraph_relation(paragraph_hash, result.hash_value) - stored_relations.append(result.hash_value) - - self.metadata_store.upsert_external_memory_ref( - 
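Both ingest paths above fall back to `compute_hash(...)` over a `source_type:chat_id:text` style string when no `external_id` is supplied, which makes re-ingesting identical content idempotent via the external-memory ref lookup. A sketch of such a fallback key, assuming `compute_hash` is a SHA-256 hex digest (its definition is not part of this diff):

```python
import hashlib

def fallback_external_id(source_type: str, chat_id: str, text: str) -> str:
    """Stable content-derived id, assumed to mirror compute_hash."""
    return hashlib.sha256(f"{source_type}:{chat_id}:{text}".encode("utf-8")).hexdigest()

# The same content always yields the same id, so a second ingest finds the
# existing external-memory ref and is skipped with reason "exists".
assert fallback_external_id("chat_summary", "c1", "hi") == fallback_external_id("chat_summary", "c1", "hi")
```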
external_id=external_token, - paragraph_hash=paragraph_hash, - source_type=source_type, - metadata={"chat_id": chat_id, "person_ids": person_tokens}, - ) - self.metadata_store.enqueue_episode_pending(paragraph_hash, source=source) - self._persist() - await self.process_episode_pending_batch( - limit=max(1, int(self._cfg("episode.pending_batch_size", 12))), - max_retry=max(1, int(self._cfg("episode.pending_max_retry", 3))), - ) - for person_id in person_tokens: - self._mark_person_active(person_id) - await self.refresh_person_profile(person_id) - return {"stored_ids": [paragraph_hash, *stored_relations], "skipped_ids": []} - - async def process_episode_pending_batch(self, *, limit: int = 20, max_retry: int = 3) -> Dict[str, Any]: - await self.initialize() - assert self.metadata_store is not None - assert self.episode_service is not None - - pending_rows = self.metadata_store.fetch_episode_pending_batch(limit=max(1, int(limit)), max_retry=max(1, int(max_retry))) - if not pending_rows: - return {"processed": 0, "episode_count": 0, "fallback_count": 0, "failed": 0} - - source_to_hashes: Dict[str, List[str]] = {} - pending_hashes = [str(row.get("paragraph_hash", "") or "").strip() for row in pending_rows if str(row.get("paragraph_hash", "") or "").strip()] - for row in pending_rows: - paragraph_hash = str(row.get("paragraph_hash", "") or "").strip() - source = str(row.get("source", "") or "").strip() - if not paragraph_hash or not source: - continue - source_to_hashes.setdefault(source, []).append(paragraph_hash) - - if pending_hashes: - self.metadata_store.mark_episode_pending_running(pending_hashes) - - result = await self.episode_service.process_pending_rows(pending_rows) - done_hashes = [str(item or "").strip() for item in result.get("done_hashes", []) if str(item or "").strip()] - failed_hashes = { - str(hash_value or "").strip(): str(error or "").strip() - for hash_value, error in (result.get("failed_hashes", {}) or {}).items() - if str(hash_value or "").strip() - } - - if done_hashes: - self.metadata_store.mark_episode_pending_done(done_hashes) - for hash_value, error in failed_hashes.items(): - self.metadata_store.mark_episode_pending_failed(hash_value, error) - - untouched = [hash_value for hash_value in pending_hashes if hash_value not in set(done_hashes) and hash_value not in failed_hashes] - for hash_value in untouched: - self.metadata_store.mark_episode_pending_failed(hash_value, "episode processing finished without explicit status") - - for source, paragraph_hashes in source_to_hashes.items(): - counts = self.metadata_store.get_episode_pending_status_counts(source) - if counts.get("failed", 0) > 0: - source_error = next( - ( - failed_hashes.get(hash_value) - for hash_value in paragraph_hashes - if failed_hashes.get(hash_value) - ), - "episode pending source contains failed rows", - ) - self.metadata_store.mark_episode_source_failed(source, str(source_error or "episode pending source contains failed rows")) - elif counts.get("pending", 0) == 0 and counts.get("running", 0) == 0: - self.metadata_store.mark_episode_source_done(source) - - self._persist() - return { - "processed": len(done_hashes) + len(failed_hashes), - "episode_count": int(result.get("episode_count") or 0), - "fallback_count": int(result.get("fallback_count") or 0), - "failed": len(failed_hashes) + len(untouched), - "group_count": int(result.get("group_count") or 0), - "missing_count": int(result.get("missing_count") or 0), - } - - async def search_memory(self, request: KernelSearchRequest) -> Dict[str, Any]: - if 
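`process_episode_pending_batch` above moves rows through a pending → running → done/failed lifecycle and then rolls the per-row outcomes up to a per-source status: any failed row fails the source, and a source is only marked done once nothing is pending or running. The rollup rule, sketched with illustrative names:

```python
def rollup_source_status(counts: dict[str, int]) -> str | None:
    """Decide a source's status from its pending-row status counts."""
    if counts.get("failed", 0) > 0:
        return "failed"
    if counts.get("pending", 0) == 0 and counts.get("running", 0) == 0:
        return "done"
    return None  # still in flight; leave the source untouched

assert rollup_source_status({"failed": 1, "pending": 2}) == "failed"
assert rollup_source_status({"pending": 0, "running": 0}) == "done"
assert rollup_source_status({"pending": 3, "running": 1}) is None
```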
self._is_chat_filtered( - respect_filter=request.respect_filter, - stream_id=request.chat_id, - group_id=request.group_id, - user_id=request.user_id, - ): - return {"summary": "", "hits": [], "filtered": True} - - await self.initialize() - assert self.retriever is not None - assert self.episode_retriever is not None - assert self.aggregate_query_service is not None - - mode = str(request.mode or "search").strip().lower() or "search" - query = str(request.query or "").strip() - limit = max(1, int(request.limit or 5)) - supported_modes = {"search", "time", "hybrid", "episode", "aggregate"} - if mode not in supported_modes: - return { - "summary": "", - "hits": [], - "error": ( - f"不支持的检索模式: {mode}(仅支持 search/time/hybrid/episode/aggregate," - "semantic 已移除)" - ), - } - try: - time_window = self._normalize_search_time_window(request.time_start, request.time_end) - except ValueError as exc: - return {"summary": "", "hits": [], "error": str(exc)} - - if mode == "episode": - rows = await self.episode_retriever.query( - query=query, - top_k=limit, - time_from=time_window.numeric_start, - time_to=time_window.numeric_end, - person=request.person_id or None, - source=self._chat_source(request.chat_id), - ) - hits = [self._episode_hit(row) for row in rows] - return {"summary": self._summary(hits), "hits": hits} - - if mode == "aggregate": - payload = await self.aggregate_query_service.execute( - query=query, - top_k=limit, - mix=True, - mix_top_k=limit, - time_from=time_window.query_start, - time_to=time_window.query_end, - search_runner=lambda: self._aggregate_search(query, limit, request), - time_runner=lambda: self._aggregate_time(query, limit, request, time_window), - episode_runner=lambda: self._aggregate_episode(query, limit, request, time_window), - ) - hits = [dict(item) for item in payload.get("mixed_results", []) if isinstance(item, dict)] - for item in hits: - item.setdefault("metadata", {}) - filtered = self._filter_hits(hits, request.person_id) - return {"summary": self._summary(filtered), "hits": filtered} - - query_type = mode - runtime_config = self._build_runtime_config() - result = await SearchExecutionService.execute( - retriever=self.retriever, - threshold_filter=self.threshold_filter, - plugin_config=runtime_config, - request=SearchExecutionRequest( - caller="sdk_memory_kernel", - stream_id=str(request.chat_id or "") or None, - group_id=str(request.group_id or "") or None, - user_id=str(request.user_id or "") or None, - query_type=query_type, - query=query, - top_k=limit, - time_from=time_window.query_start, - time_to=time_window.query_end, - person=str(request.person_id or "") or None, - source=self._chat_source(request.chat_id), - use_threshold=True, - enable_ppr=bool(self._cfg("retrieval.enable_ppr", True)), - ), - enforce_chat_filter=bool(request.respect_filter), - reinforce_access=True, - ) - if not result.success: - return {"summary": "", "hits": [], "error": result.error} - if result.chat_filtered: - return {"summary": "", "hits": [], "filtered": True} - - hits = [self._retrieval_result_hit(item) for item in result.results] - filtered = self._filter_hits(hits, request.person_id) - return {"summary": self._summary(filtered), "hits": filtered} - - async def get_person_profile(self, *, person_id: str, chat_id: str = "", limit: int = 10) -> Dict[str, Any]: - del chat_id - await self.initialize() - assert self.metadata_store is not None - assert self.person_profile_service is not None - self._mark_person_active(person_id) - profile = await 
self.person_profile_service.query_person_profile( - person_id=person_id, - top_k=max(4, int(limit or 10)), - source_note="sdk_memory_kernel.get_person_profile", - ) - if not profile.get("success"): - return {"summary": "", "traits": [], "evidence": []} - - evidence = [] - for hash_value in profile.get("evidence_ids", [])[: max(1, int(limit))]: - paragraph = self.metadata_store.get_paragraph(hash_value) - if paragraph is not None: - evidence.append( - { - "hash": hash_value, - "content": str(paragraph.get("content", "") or "")[:220], - "metadata": paragraph.get("metadata", {}) or {}, - "type": "paragraph", - } - ) - continue - - relation = self.metadata_store.get_relation(hash_value) - if relation is not None: - evidence.append( - { - "hash": hash_value, - "content": " ".join( - [ - str(relation.get("subject", "") or "").strip(), - str(relation.get("predicate", "") or "").strip(), - str(relation.get("object", "") or "").strip(), - ] - ).strip(), - "metadata": { - "confidence": relation.get("confidence"), - "source_paragraph": relation.get("source_paragraph"), - }, - "type": "relation", - } - ) - - text = str(profile.get("profile_text", "") or "").strip() - traits = [line.strip("- ").strip() for line in text.splitlines() if line.strip()][:8] - return { - "summary": text, - "traits": traits, - "evidence": evidence, - "person_id": str(profile.get("person_id", "") or person_id), - "person_name": str(profile.get("person_name", "") or ""), - "profile_source": str(profile.get("profile_source", "") or "auto_snapshot"), - "has_manual_override": bool(profile.get("has_manual_override", False)), - } - - async def refresh_person_profile(self, person_id: str, limit: int = 10, *, mark_active: bool = True) -> Dict[str, Any]: - await self.initialize() - assert self.person_profile_service - if mark_active: - self._mark_person_active(person_id) - profile = await self.person_profile_service.query_person_profile( - person_id=person_id, - top_k=max(4, int(limit or 10)), - force_refresh=True, - source_note="sdk_memory_kernel.refresh_person_profile", - ) - return profile if isinstance(profile, dict) else {} - - async def maintain_memory( - self, - *, - action: str, - target: str = "", - hours: Optional[float] = None, - reason: str = "", - limit: int = 50, - ) -> Dict[str, Any]: - del reason - await self.initialize() - assert self.metadata_store - act = str(action or "").strip().lower() - if act == "recycle_bin": - items = self.metadata_store.get_deleted_relations(limit=max(1, int(limit or 50))) - return {"success": True, "items": items, "count": len(items)} - - hashes = self._resolve_deleted_relation_hashes(target) if act == "restore" else self._resolve_relation_hashes(target) - if not hashes: - return {"success": False, "detail": "未命中可维护关系"} - - if act == "reinforce": - self.metadata_store.reinforce_relations(hashes) - elif act == "freeze": - self.metadata_store.mark_relations_inactive(hashes) - self._rebuild_graph_from_metadata() - elif act == "protect": - ttl_seconds = max(0.0, float(hours or 0.0)) * 3600.0 - self.metadata_store.protect_relations(hashes, ttl_seconds=ttl_seconds, is_pinned=ttl_seconds <= 0) - elif act == "restore": - restored = sum(1 for hash_value in hashes if self.metadata_store.restore_relation(hash_value)) - if restored <= 0: - return {"success": False, "detail": "未恢复任何关系"} - self._rebuild_graph_from_metadata() - else: - return {"success": False, "detail": f"不支持的维护动作: {act}"} - - self._last_maintenance_at = time.time() - self._persist() - return {"success": True, "detail": f"{act} 
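The `protect` branch of `maintain_memory` above converts the optional `hours` argument into a TTL in seconds and treats a missing or non-positive value as a permanent pin. Sketched:

```python
def protect_params(hours: float | None) -> tuple[float, bool]:
    """Return (ttl_seconds, is_pinned) as derived in maintain_memory."""
    ttl_seconds = max(0.0, float(hours or 0.0)) * 3600.0
    return ttl_seconds, ttl_seconds <= 0

assert protect_params(24) == (86400.0, False)   # protected for 24 h
assert protect_params(None) == (0.0, True)      # no window given: pin forever
```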
{len(hashes)} 条关系"} - - async def rebuild_episodes_for_sources(self, sources: Iterable[str]) -> Dict[str, Any]: - await self.initialize() - assert self.metadata_store is not None - assert self.episode_service is not None - - items: List[Dict[str, Any]] = [] - failures: List[Dict[str, str]] = [] - for source in self._tokens(sources): - self.metadata_store.mark_episode_source_running(source) - try: - result = await self.episode_service.rebuild_source(source) - self.metadata_store.mark_episode_source_done(source) - items.append(result) - except Exception as exc: - err = str(exc)[:500] - self.metadata_store.mark_episode_source_failed(source, err) - failures.append({"source": source, "error": err}) - self._persist() - return { - "rebuilt": len(items), - "items": items, - "failures": failures, - "sources": [str(item.get("source", "") or "") for item in items] or self._tokens(sources), - } - - def memory_stats(self) -> Dict[str, Any]: - assert self.metadata_store - stats = self.metadata_store.get_statistics() - episodes = self.metadata_store.query("SELECT COUNT(*) AS c FROM episodes")[0]["c"] - profiles = self.metadata_store.query("SELECT COUNT(*) AS c FROM person_profile_snapshots")[0]["c"] - pending = self.metadata_store.query( - "SELECT COUNT(*) AS c FROM episode_pending_paragraphs WHERE status IN ('pending', 'running', 'failed')" - )[0]["c"] - return { - "paragraphs": int(stats.get("paragraph_count", 0) or 0), - "relations": int(stats.get("relation_count", 0) or 0), - "episodes": int(episodes or 0), - "profiles": int(profiles or 0), - "episode_pending": int(pending or 0), - "last_maintenance_at": self._last_maintenance_at, - } - - async def memory_graph_admin(self, *, action: str, **kwargs) -> Dict[str, Any]: - await self.initialize() - assert self.metadata_store is not None - assert self.graph_store is not None - - act = str(action or "").strip().lower() - if act == "get_graph": - return {"success": True, **self._serialize_graph(limit=max(1, int(kwargs.get("limit", 200) or 200)))} - - if act == "create_node": - name = str(kwargs.get("name", "") or kwargs.get("node", "") or "").strip() - if not name: - return {"success": False, "error": "node name 不能为空"} - entity_hash = self.metadata_store.add_entity(name=name, metadata=kwargs.get("metadata") or {}) - self._rebuild_graph_from_metadata() - self._persist() - return {"success": True, "node": {"name": name, "hash": entity_hash}} - - if act == "delete_node": - name = str(kwargs.get("name", "") or kwargs.get("node", "") or kwargs.get("hash_or_name", "") or "").strip() - if not name: - return {"success": False, "error": "node name 不能为空"} - result = await self._execute_delete_action( - mode="entity", - selector={"query": name}, - requested_by=str(kwargs.get("requested_by", "") or "memory_graph_admin"), - reason=str(kwargs.get("reason", "") or "graph_delete_node"), - ) - return { - "success": bool(result.get("success", False)), - "deleted": bool(result.get("deleted_count", 0)), - "node": name, - "operation_id": result.get("operation_id", ""), - "counts": result.get("counts", {}), - "error": result.get("error", ""), - } - - if act == "rename_node": - old_name = str(kwargs.get("name", "") or kwargs.get("old_name", "") or kwargs.get("node", "") or "").strip() - new_name = str(kwargs.get("new_name", "") or kwargs.get("target_name", "") or "").strip() - return self._rename_node(old_name, new_name) - - if act == "create_edge": - subject = str(kwargs.get("subject", "") or kwargs.get("source", "") or "").strip() - predicate = str(kwargs.get("predicate", "") 
or kwargs.get("label", "") or "").strip() - obj = str(kwargs.get("object", "") or kwargs.get("target", "") or "").strip() - if not all([subject, predicate, obj]): - return {"success": False, "error": "subject/predicate/object 不能为空"} - if self.relation_write_service is not None: - result = await self.relation_write_service.upsert_relation_with_vector( - subject=subject, - predicate=predicate, - obj=obj, - confidence=float(kwargs.get("confidence", 1.0) or 1.0), - source_paragraph=str(kwargs.get("source_paragraph", "") or "") or None, - metadata=kwargs.get("metadata") or {}, - write_vector=self.relation_vectors_enabled, - ) - relation_hash = result.hash_value - else: - relation_hash = self.metadata_store.add_relation( - subject=subject, - predicate=predicate, - obj=obj, - confidence=float(kwargs.get("confidence", 1.0) or 1.0), - source_paragraph=kwargs.get("source_paragraph"), - metadata=kwargs.get("metadata") or {}, - ) - self._rebuild_graph_from_metadata() - self._persist() - return { - "success": True, - "edge": { - "hash": relation_hash, - "subject": subject, - "predicate": predicate, - "object": obj, - "weight": float(kwargs.get("confidence", 1.0) or 1.0), - }, - } - - if act == "delete_edge": - relation_hash = str(kwargs.get("hash", "") or kwargs.get("relation_hash", "") or "").strip() - if relation_hash: - result = await self._execute_delete_action( - mode="relation", - selector={"query": relation_hash}, - requested_by=str(kwargs.get("requested_by", "") or "memory_graph_admin"), - reason=str(kwargs.get("reason", "") or "graph_delete_edge"), - ) - return { - "success": bool(result.get("success", False)), - "deleted": int(result.get("deleted_count", 0)), - "hash": relation_hash, - "operation_id": result.get("operation_id", ""), - "counts": result.get("counts", {}), - "error": result.get("error", ""), - } - - subject = str(kwargs.get("subject", "") or kwargs.get("source", "") or "").strip() - obj = str(kwargs.get("object", "") or kwargs.get("target", "") or "").strip() - deleted_hashes = [ - str(row.get("hash", "") or "") - for row in self.metadata_store.get_relations(subject=subject) - if str(row.get("object", "") or "").strip() == obj - ] - result = await self._execute_delete_action( - mode="relation", - selector={"hashes": deleted_hashes, "subject": subject, "object": obj}, - requested_by=str(kwargs.get("requested_by", "") or "memory_graph_admin"), - reason=str(kwargs.get("reason", "") or "graph_delete_edge"), - ) - return { - "success": bool(result.get("success", False)), - "deleted": int(result.get("deleted_count", 0)), - "subject": subject, - "object": obj, - "operation_id": result.get("operation_id", ""), - "counts": result.get("counts", {}), - "error": result.get("error", ""), - } - - if act == "update_edge_weight": - return self._update_edge_weight( - relation_hash=str(kwargs.get("hash", "") or kwargs.get("relation_hash", "") or "").strip(), - subject=str(kwargs.get("subject", "") or kwargs.get("source", "") or "").strip(), - obj=str(kwargs.get("object", "") or kwargs.get("target", "") or "").strip(), - weight=float(kwargs.get("weight", kwargs.get("confidence", 1.0)) or 1.0), - ) - - return {"success": False, "error": f"不支持的 graph action: {act}"} - - async def memory_source_admin(self, *, action: str, **kwargs) -> Dict[str, Any]: - await self.initialize() - assert self.metadata_store - - act = str(action or "").strip().lower() - if act == "list": - sources = self.metadata_store.get_all_sources() - items = [] - for row in sources: - source_name = str(row.get("source", "") or 
"").strip() - items.append( - { - **row, - "episode_rebuild_blocked": self.metadata_store.is_episode_source_query_blocked(source_name), - } - ) - return {"success": True, "items": items, "count": len(items)} - - if act == "delete": - source = str(kwargs.get("source", "") or "").strip() - return await self._execute_delete_action( - mode="source", - selector={"sources": [source]}, - requested_by=str(kwargs.get("requested_by", "") or "memory_source_admin"), - reason=str(kwargs.get("reason", "") or "source_delete"), - ) - - if act == "batch_delete": - return await self._execute_delete_action( - mode="source", - selector={"sources": list(kwargs.get("sources") or [])}, - requested_by=str(kwargs.get("requested_by", "") or "memory_source_admin"), - reason=str(kwargs.get("reason", "") or "source_batch_delete"), - ) - - return {"success": False, "error": f"不支持的 source action: {act}"} - - async def memory_episode_admin(self, *, action: str, **kwargs) -> Dict[str, Any]: - await self.initialize() - assert self.metadata_store - - act = str(action or "").strip().lower() - if act in {"query", "list"}: - items = self.metadata_store.query_episodes( - query=str(kwargs.get("query", "") or "").strip(), - time_from=self._optional_float(kwargs.get("time_start", kwargs.get("time_from"))), - time_to=self._optional_float(kwargs.get("time_end", kwargs.get("time_to"))), - person=str(kwargs.get("person_id", "") or kwargs.get("person", "") or "").strip() or None, - source=str(kwargs.get("source", "") or "").strip() or None, - limit=max(1, int(kwargs.get("limit", 20) or 20)), - ) - return {"success": True, "items": items, "count": len(items)} - - if act == "get": - episode_id = str(kwargs.get("episode_id", "") or "").strip() - if not episode_id: - return {"success": False, "error": "episode_id 不能为空"} - episode = self.metadata_store.get_episode_by_id(episode_id) - if episode is None: - return {"success": False, "error": "episode 不存在"} - episode["paragraphs"] = self.metadata_store.get_episode_paragraphs( - episode_id, - limit=max(1, int(kwargs.get("paragraph_limit", 100) or 100)), - ) - return {"success": True, "episode": episode} - - if act == "status": - summary = self.metadata_store.get_episode_source_rebuild_summary( - failed_limit=max(1, int(kwargs.get("limit", 20) or 20)) - ) - summary["pending_queue"] = self.metadata_store.query( - "SELECT COUNT(*) AS c FROM episode_pending_paragraphs WHERE status IN ('pending', 'running', 'failed')" - )[0]["c"] - return {"success": True, **summary} - - if act == "rebuild": - sources = self._tokens(kwargs.get("sources")) - if not sources: - source = str(kwargs.get("source", "") or "").strip() - if source: - sources = [source] - if not sources and bool(kwargs.get("all", False)): - sources = self.metadata_store.list_episode_sources_for_rebuild() - if not sources: - sources = [str(row.get("source", "") or "").strip() for row in self.metadata_store.get_all_sources()] - if not sources: - return {"success": False, "error": "未提供可重建的 source"} - result = await self.rebuild_episodes_for_sources(sources) - return {"success": len(result.get("failures", [])) == 0, **result} - - if act == "process_pending": - result = await self.process_episode_pending_batch( - limit=max(1, int(kwargs.get("limit", 20) or 20)), - max_retry=max(1, int(kwargs.get("max_retry", 3) or 3)), - ) - return {"success": True, **result} - - return {"success": False, "error": f"不支持的 episode action: {act}"} - - async def memory_profile_admin(self, *, action: str, **kwargs) -> Dict[str, Any]: - await self.initialize() - assert 
self.metadata_store is not None - assert self.person_profile_service is not None - - act = str(action or "").strip().lower() - if act == "query": - profile = await self.person_profile_service.query_person_profile( - person_id=str(kwargs.get("person_id", "") or "").strip(), - person_keyword=str(kwargs.get("person_keyword", "") or kwargs.get("keyword", "") or "").strip(), - top_k=max(1, int(kwargs.get("limit", kwargs.get("top_k", 12)) or 12)), - force_refresh=bool(kwargs.get("force_refresh", False)), - source_note="sdk_memory_kernel.memory_profile_admin.query", - ) - return profile if isinstance(profile, dict) else {"success": False, "error": "invalid profile payload"} - - if act == "list": - limit = max(1, int(kwargs.get("limit", 50) or 50)) - rows = self.metadata_store.query( - """ - SELECT s.person_id, s.profile_version, s.profile_text, s.updated_at, s.expires_at, s.source_note - FROM person_profile_snapshots s - JOIN ( - SELECT person_id, MAX(profile_version) AS max_version - FROM person_profile_snapshots - GROUP BY person_id - ) latest - ON latest.person_id = s.person_id - AND latest.max_version = s.profile_version - ORDER BY s.updated_at DESC - LIMIT ? - """, - (limit,), - ) - items = [] - for row in rows: - person_id = str(row.get("person_id", "") or "").strip() - override = self.metadata_store.get_person_profile_override(person_id) - items.append( - { - "person_id": person_id, - "profile_version": int(row.get("profile_version", 0) or 0), - "profile_text": str(row.get("profile_text", "") or ""), - "updated_at": row.get("updated_at"), - "expires_at": row.get("expires_at"), - "source_note": str(row.get("source_note", "") or ""), - "has_manual_override": bool(override), - "manual_override": override, - } - ) - return {"success": True, "items": items, "count": len(items)} - - if act == "set_override": - person_id = str(kwargs.get("person_id", "") or "").strip() - override = self.metadata_store.set_person_profile_override( - person_id=person_id, - override_text=str(kwargs.get("override_text", "") or kwargs.get("text", "") or ""), - updated_by=str(kwargs.get("updated_by", "") or ""), - source=str(kwargs.get("source", "") or "memory_profile_admin"), - ) - return {"success": True, "override": override} - - if act == "delete_override": - person_id = str(kwargs.get("person_id", "") or "").strip() - deleted = self.metadata_store.delete_person_profile_override(person_id) - return {"success": bool(deleted), "deleted": bool(deleted), "person_id": person_id} - - return {"success": False, "error": f"不支持的 profile action: {act}"} - - async def memory_runtime_admin(self, *, action: str, **kwargs) -> Dict[str, Any]: - await self.initialize() - act = str(action or "").strip().lower() - - if act == "save": - self._persist() - return {"success": True, "saved": True, "data_dir": str(self.data_dir)} - - if act == "get_config": - return { - "success": True, - "config": self.config, - "data_dir": str(self.data_dir), - "embedding_dimension": int(self.embedding_dimension), - "auto_save": bool(self._cfg("advanced.enable_auto_save", True)), - "relation_vectors_enabled": bool(self.relation_vectors_enabled), - "runtime_ready": self.is_runtime_ready(), - } - - if act in {"self_check", "refresh_self_check"}: - report = await run_embedding_runtime_self_check( - config=self._build_runtime_config(), - vector_store=self.vector_store, - embedding_manager=self.embedding_manager, - sample_text=str(kwargs.get("sample_text", "") or "A_Memorix runtime self check"), - ) - self._runtime_facade._runtime_self_check_report = 
dict(report) - return {"success": bool(report.get("ok", False)), "report": report} - - if act == "set_auto_save": - enabled = bool(kwargs.get("enabled", False)) - self._set_cfg("advanced.enable_auto_save", enabled) - return {"success": True, "auto_save": enabled} - - return {"success": False, "error": f"不支持的 runtime action: {act}"} - - async def memory_import_admin(self, *, action: str, **kwargs) -> Dict[str, Any]: - await self.initialize() - manager = self.import_task_manager - if manager is None: - return {"success": False, "error": "import manager 未初始化"} - - act = str(action or "").strip().lower() - if act in {"settings", "get_settings", "get_guide"}: - return {"success": True, "settings": await manager.get_runtime_settings()} - if act in {"path_aliases", "get_path_aliases"}: - return {"success": True, "path_aliases": manager.get_path_aliases()} - if act in {"resolve_path", "resolve"}: - return await manager.resolve_path_request(kwargs) - if act == "create_upload": - task = await manager.create_upload_task( - list(kwargs.get("staged_files") or kwargs.get("files") or kwargs.get("uploads") or []), - kwargs, - ) - return {"success": True, "task": task} - if act == "create_paste": - return {"success": True, "task": await manager.create_paste_task(kwargs)} - if act == "create_raw_scan": - return {"success": True, "task": await manager.create_raw_scan_task(kwargs)} - if act == "create_lpmm_openie": - return {"success": True, "task": await manager.create_lpmm_openie_task(kwargs)} - if act == "create_lpmm_convert": - return {"success": True, "task": await manager.create_lpmm_convert_task(kwargs)} - if act == "create_temporal_backfill": - return {"success": True, "task": await manager.create_temporal_backfill_task(kwargs)} - if act == "create_maibot_migration": - return {"success": True, "task": await manager.create_maibot_migration_task(kwargs)} - if act == "list": - items = await manager.list_tasks(limit=max(1, int(kwargs.get("limit", 50) or 50))) - return {"success": True, "items": items, "count": len(items)} - if act == "get": - task = await manager.get_task( - str(kwargs.get("task_id", "") or ""), - include_chunks=bool(kwargs.get("include_chunks", False)), - ) - return {"success": task is not None, "task": task, "error": "" if task is not None else "任务不存在"} - if act in {"chunks", "get_chunks"}: - payload = await manager.get_chunks( - str(kwargs.get("task_id", "") or ""), - str(kwargs.get("file_id", "") or ""), - offset=max(0, int(kwargs.get("offset", 0) or 0)), - limit=max(1, int(kwargs.get("limit", 50) or 50)), - ) - return {"success": payload is not None, **(payload or {}), "error": "" if payload is not None else "任务或文件不存在"} - if act == "cancel": - task = await manager.cancel_task(str(kwargs.get("task_id", "") or "")) - return {"success": task is not None, "task": task, "error": "" if task is not None else "任务不存在"} - if act == "retry_failed": - overrides = kwargs.get("overrides") if isinstance(kwargs.get("overrides"), dict) else kwargs - task = await manager.retry_failed(str(kwargs.get("task_id", "") or ""), overrides=overrides) - return {"success": task is not None, "task": task, "error": "" if task is not None else "任务不存在"} - return {"success": False, "error": f"不支持的 import action: {act}"} - - async def memory_tuning_admin(self, *, action: str, **kwargs) -> Dict[str, Any]: - await self.initialize() - manager = self.retrieval_tuning_manager - if manager is None: - return {"success": False, "error": "tuning manager 未初始化"} - - act = str(action or "").strip().lower() - if act in 
{"settings", "get_settings"}: - return {"success": True, "settings": manager.get_runtime_settings()} - if act == "get_profile": - profile = manager.get_profile_snapshot() - return {"success": True, "profile": profile, "toml": manager.export_toml_snippet(profile)} - if act == "apply_profile": - profile = kwargs.get("profile") if isinstance(kwargs.get("profile"), dict) else kwargs - return {"success": True, **await manager.apply_profile(profile, reason=str(kwargs.get("reason", "manual") or "manual"))} - if act == "rollback_profile": - return {"success": True, **await manager.rollback_profile()} - if act == "export_profile": - profile = manager.get_profile_snapshot() - return {"success": True, "profile": profile, "toml": manager.export_toml_snippet(profile)} - if act == "create_task": - payload = kwargs.get("payload") if isinstance(kwargs.get("payload"), dict) else kwargs - return {"success": True, "task": await manager.create_task(payload)} - if act == "list_tasks": - items = await manager.list_tasks(limit=max(1, int(kwargs.get("limit", 50) or 50))) - return {"success": True, "items": items, "count": len(items)} - if act == "get_task": - task = await manager.get_task( - str(kwargs.get("task_id", "") or ""), - include_rounds=bool(kwargs.get("include_rounds", False)), - ) - return {"success": task is not None, "task": task, "error": "" if task is not None else "任务不存在"} - if act == "get_rounds": - payload = await manager.get_rounds( - str(kwargs.get("task_id", "") or ""), - offset=max(0, int(kwargs.get("offset", 0) or 0)), - limit=max(1, int(kwargs.get("limit", 50) or 50)), - ) - return {"success": payload is not None, **(payload or {}), "error": "" if payload is not None else "任务不存在"} - if act == "cancel": - task = await manager.cancel_task(str(kwargs.get("task_id", "") or "")) - return {"success": task is not None, "task": task, "error": "" if task is not None else "任务不存在"} - if act == "apply_best": - return {"success": True, **await manager.apply_best(str(kwargs.get("task_id", "") or ""))} - if act == "get_report": - report = await manager.get_report(str(kwargs.get("task_id", "") or ""), fmt=str(kwargs.get("format", "md") or "md")) - return {"success": report is not None, "report": report, "error": "" if report is not None else "任务不存在"} - return {"success": False, "error": f"不支持的 tuning action: {act}"} - - async def memory_v5_admin(self, *, action: str, **kwargs) -> Dict[str, Any]: - await self.initialize() - assert self.metadata_store - - act = str(action or "").strip().lower() - target = str(kwargs.get("target", "") or kwargs.get("query", "") or "").strip() - reason = str(kwargs.get("reason", "") or "").strip() - updated_by = str(kwargs.get("updated_by", "") or kwargs.get("requested_by", "") or "").strip() - limit = max(1, int(kwargs.get("limit", 50) or 50)) - - if act == "recycle_bin": - items = self.metadata_store.get_deleted_relations(limit=limit) - return {"success": True, "items": items, "count": len(items)} - - if act == "status": - return self._memory_v5_status(target=target, limit=limit) - - if act == "restore": - hashes = self._resolve_deleted_relation_hashes(target) - if not hashes: - return {"success": False, "error": "未命中可恢复关系"} - result = await self._restore_relation_hashes(hashes) - operation = self.metadata_store.record_v5_operation( - action=act, - target=target, - resolved_hashes=hashes, - reason=reason, - updated_by=updated_by, - result=result, - ) - return {"success": bool(result.get("restored_count", 0) > 0), "operation": operation, **result} - - hashes = 
self._resolve_relation_hashes(target) - if not hashes: - return {"success": False, "error": "未命中可维护关系"} - - result = self._apply_v5_relation_action( - action=act, - hashes=hashes, - strength=float(kwargs.get("strength", 1.0) or 1.0), - ) - operation = self.metadata_store.record_v5_operation( - action=act, - target=target, - resolved_hashes=hashes, - reason=reason, - updated_by=updated_by, - result=result, - ) - return {"success": bool(result.get("success", False)), "operation": operation, **result} - - async def memory_delete_admin(self, *, action: str, **kwargs) -> Dict[str, Any]: - await self.initialize() - act = str(action or "").strip().lower() - mode = str(kwargs.get("mode", "") or "").strip().lower() - selector = kwargs.get("selector") - if selector is None: - selector = { - key: value - for key, value in kwargs.items() - if key - not in { - "action", - "mode", - "dry_run", - "cascade", - "operation_id", - "reason", - "requested_by", - } - } - reason = str(kwargs.get("reason", "") or "").strip() - requested_by = str(kwargs.get("requested_by", "") or "").strip() - - if act == "preview": - return await self._preview_delete_action(mode=mode, selector=selector) - if act == "execute": - return await self._execute_delete_action( - mode=mode, - selector=selector, - requested_by=requested_by, - reason=reason, - ) - if act == "restore": - return await self._restore_delete_action( - mode=mode, - selector=selector, - operation_id=str(kwargs.get("operation_id", "") or "").strip(), - requested_by=requested_by, - reason=reason, - ) - if act == "get_operation": - operation = self.metadata_store.get_delete_operation(str(kwargs.get("operation_id", "") or "").strip()) - return {"success": operation is not None, "operation": operation, "error": "" if operation is not None else "operation 不存在"} - if act == "list_operations": - items = self.metadata_store.list_delete_operations( - limit=max(1, int(kwargs.get("limit", 50) or 50)), - mode=mode, - ) - return {"success": True, "items": items, "count": len(items)} - if act == "purge": - return await self._purge_deleted_memory( - grace_hours=self._optional_float(kwargs.get("grace_hours")), - limit=max(1, int(kwargs.get("limit", 1000) or 1000)), - ) - return {"success": False, "error": f"不支持的 delete action: {act}"} - - def get_import_task_manager(self) -> Optional[ImportTaskManager]: - return self.import_task_manager - - def get_retrieval_tuning_manager(self) -> Optional[RetrievalTuningManager]: - return self.retrieval_tuning_manager - - async def _aggregate_search(self, query: str, limit: int, request: KernelSearchRequest) -> Dict[str, Any]: - result = await SearchExecutionService.execute( - retriever=self.retriever, - threshold_filter=self.threshold_filter, - plugin_config=self._build_runtime_config(), - request=SearchExecutionRequest( - caller="sdk_memory_kernel.aggregate", - stream_id=str(request.chat_id or "") or None, - query_type="search", - query=query, - top_k=limit, - person=str(request.person_id or "") or None, - source=self._chat_source(request.chat_id), - use_threshold=True, - enable_ppr=bool(self._cfg("retrieval.enable_ppr", True)), - ), - enforce_chat_filter=False, - reinforce_access=True, - ) - hits = [self._retrieval_result_hit(item) for item in result.results] if result.success else [] - return {"success": result.success, "results": hits, "count": len(hits), "query_type": "search", "error": result.error} - - async def _aggregate_time( - self, - query: str, - limit: int, - request: KernelSearchRequest, - time_window: 
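`memory_delete_admin` above derives a selector implicitly when none is passed: every keyword argument that is not a control key becomes part of the selector. A sketch of that fallback:

```python
CONTROL_KEYS = {"action", "mode", "dry_run", "cascade", "operation_id", "reason", "requested_by"}

def derive_selector(kwargs: dict) -> dict:
    """Use an explicit selector if given, otherwise the leftover kwargs."""
    selector = kwargs.get("selector")
    if selector is not None:
        return selector
    return {key: value for key, value in kwargs.items() if key not in CONTROL_KEYS}

assert derive_selector({"mode": "source", "sources": ["s1"], "reason": "x"}) == {"sources": ["s1"]}
```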
_NormalizedSearchTimeWindow, - ) -> Dict[str, Any]: - result = await SearchExecutionService.execute( - retriever=self.retriever, - threshold_filter=self.threshold_filter, - plugin_config=self._build_runtime_config(), - request=SearchExecutionRequest( - caller="sdk_memory_kernel.aggregate", - stream_id=str(request.chat_id or "") or None, - query_type="time", - query=query, - top_k=limit, - time_from=time_window.query_start, - time_to=time_window.query_end, - person=str(request.person_id or "") or None, - source=self._chat_source(request.chat_id), - use_threshold=True, - enable_ppr=bool(self._cfg("retrieval.enable_ppr", True)), - ), - enforce_chat_filter=False, - reinforce_access=True, - ) - hits = [self._retrieval_result_hit(item) for item in result.results] if result.success else [] - return {"success": result.success, "results": hits, "count": len(hits), "query_type": "time", "error": result.error} - - async def _aggregate_episode( - self, - query: str, - limit: int, - request: KernelSearchRequest, - time_window: _NormalizedSearchTimeWindow, - ) -> Dict[str, Any]: - assert self.episode_retriever - rows = await self.episode_retriever.query( - query=query, - top_k=limit, - time_from=time_window.numeric_start, - time_to=time_window.numeric_end, - person=request.person_id or None, - source=self._chat_source(request.chat_id), - ) - hits = [self._episode_hit(row) for row in rows] - return {"success": True, "results": hits, "count": len(hits), "query_type": "episode"} - - def _persist(self) -> None: - if self.vector_store is not None: - self.vector_store.save() - if self.graph_store is not None: - self.graph_store.save() - if self.sparse_index is not None and getattr(self.sparse_index.config, "enabled", False): - self.sparse_index.ensure_loaded() - - async def _start_background_tasks(self) -> None: - async with self._background_lock: - self._background_stopping = False - self._ensure_background_task("auto_save", self._auto_save_loop) - self._ensure_background_task("episode_pending", self._episode_pending_loop) - self._ensure_background_task("memory_maintenance", self._memory_maintenance_loop) - self._ensure_background_task("person_profile_refresh", self._person_profile_refresh_loop) - - def _ensure_background_task(self, name: str, factory: Callable[[], Awaitable[None]]) -> None: - task = self._background_tasks.get(name) - if task is not None and not task.done(): - return - self._background_tasks[name] = asyncio.create_task(factory(), name=f"A_Memorix.{name}") - - async def _stop_background_tasks(self) -> None: - async with self._background_lock: - self._background_stopping = True - tasks = [task for task in self._background_tasks.values() if task is not None and not task.done()] - for task in tasks: - task.cancel() - for task in tasks: - try: - await task - except asyncio.CancelledError: - pass - except Exception as exc: - logger.warning(f"后台任务退出异常: {exc}") - self._background_tasks.clear() - - async def _auto_save_loop(self) -> None: - try: - while not self._background_stopping: - interval_minutes = max(1.0, float(self._cfg("advanced.auto_save_interval_minutes", 5) or 5)) - await asyncio.sleep(interval_minutes * 60.0) - if self._background_stopping: - break - if bool(self._cfg("advanced.enable_auto_save", True)): - self._persist() - except asyncio.CancelledError: - raise - except Exception as exc: - logger.warning(f"auto_save loop 异常: {exc}") - - async def _episode_pending_loop(self) -> None: - try: - while not self._background_stopping: - await asyncio.sleep(60.0) - if 
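`_ensure_background_task` above is idempotent: a named loop is only (re)created when its previous task is missing or already finished, so repeated `initialize()` calls do not stack duplicate loops, and shutdown cancels and drains everything. A minimal sketch of the registry (assumed names, not the plugin API):

```python
import asyncio
from typing import Awaitable, Callable, Dict

class TaskRegistry:
    """Idempotent named background tasks with a cancel-and-drain stop."""

    def __init__(self) -> None:
        self._tasks: Dict[str, asyncio.Task] = {}

    def ensure(self, name: str, factory: Callable[[], Awaitable[None]]) -> None:
        task = self._tasks.get(name)
        if task is not None and not task.done():
            return  # already running; calling ensure() twice is harmless
        self._tasks[name] = asyncio.create_task(factory(), name=name)

    async def stop(self) -> None:
        tasks = [task for task in self._tasks.values() if not task.done()]
        for task in tasks:
            task.cancel()
        # Swallow CancelledError from the drained tasks, as the loop above does.
        await asyncio.gather(*tasks, return_exceptions=True)
        self._tasks.clear()
```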
self._background_stopping: - break - if not bool(self._cfg("episode.enabled", True)): - continue - if not bool(self._cfg("episode.generation_enabled", True)): - continue - await self.process_episode_pending_batch( - limit=max(1, int(self._cfg("episode.pending_batch_size", 20) or 20)), - max_retry=max(1, int(self._cfg("episode.pending_max_retry", 3) or 3)), - ) - except asyncio.CancelledError: - raise - except Exception as exc: - logger.warning(f"episode_pending loop 异常: {exc}") - - async def _person_profile_refresh_loop(self) -> None: - try: - while not self._background_stopping: - interval_minutes = max(1.0, float(self._cfg("person_profile.refresh_interval_minutes", 30) or 30)) - await asyncio.sleep(max(60.0, interval_minutes * 60.0)) - if self._background_stopping: - break - if not bool(self._cfg("person_profile.enabled", True)): - continue - active_window_hours = max(1.0, float(self._cfg("person_profile.active_window_hours", 72.0) or 72.0)) - max_refresh = max(1, int(self._cfg("person_profile.max_refresh_per_cycle", 50) or 50)) - cutoff = time.time() - active_window_hours * 3600.0 - candidates = [ - person_id - for person_id, seen_at in sorted( - self._active_person_timestamps.items(), - key=lambda item: item[1], - reverse=True, - ) - if seen_at >= cutoff - ][:max_refresh] - for person_id in candidates: - try: - await self.refresh_person_profile(person_id, limit=max(4, int(self._cfg("person_profile.top_k_evidence", 12) or 12)), mark_active=False) - except Exception as exc: - logger.warning(f"刷新人物画像失败: {exc}") - except asyncio.CancelledError: - raise - except Exception as exc: - logger.warning(f"person_profile_refresh loop 异常: {exc}") - - async def _memory_maintenance_loop(self) -> None: - try: - while not self._background_stopping: - interval_hours = max(1.0 / 60.0, float(self._cfg("memory.base_decay_interval_hours", 1.0) or 1.0)) - await asyncio.sleep(max(60.0, interval_hours * 3600.0)) - if self._background_stopping: - break - if not bool(self._cfg("memory.enabled", True)): - continue - await self._run_memory_maintenance_cycle(interval_hours=interval_hours) - except asyncio.CancelledError: - raise - except Exception as exc: - logger.warning(f"memory_maintenance loop 异常: {exc}") - - async def _run_memory_maintenance_cycle(self, *, interval_hours: float) -> None: - assert self.graph_store is not None - assert self.metadata_store is not None - half_life = float(self._cfg("memory.half_life_hours", 24.0) or 24.0) - if half_life > 0: - factor = 0.5 ** (float(interval_hours) / half_life) - self.graph_store.decay(factor) - - await self._process_freeze_and_prune() - await self._orphan_gc_phase() - self._last_maintenance_at = time.time() - self._persist() - - async def _process_freeze_and_prune(self) -> None: - assert self.metadata_store is not None - assert self.graph_store is not None - prune_threshold = max(0.0, float(self._cfg("memory.prune_threshold", 0.1) or 0.1)) - freeze_duration = max(0.0, float(self._cfg("memory.freeze_duration_hours", 24.0) or 24.0)) * 3600.0 - now = time.time() - - low_edges = self.graph_store.get_low_weight_edges(prune_threshold) - hashes_to_freeze: List[str] = [] - edges_to_deactivate: List[tuple[str, str]] = [] - for src, tgt in low_edges: - relation_hashes = list(self.graph_store.get_relation_hashes_for_edge(src, tgt)) - if not relation_hashes: - continue - statuses = self.metadata_store.get_relation_status_batch(relation_hashes) - current_hashes: List[str] = [] - protected = False - for hash_value, status in statuses.items(): - if bool(status.get("is_pinned")) 
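The maintenance cycle above derives its per-cycle decay factor from a half-life, `factor = 0.5 ** (interval_hours / half_life_hours)`, so applying the factor repeatedly for one half-life's worth of cycles halves an edge weight. A quick check:

```python
def decay_factor(interval_hours: float, half_life_hours: float) -> float:
    """Per-cycle multiplier so that weights halve every half_life_hours."""
    return 0.5 ** (interval_hours / half_life_hours)

# Example: hourly cycles with a 24 h half-life compound to exactly one half.
factor = decay_factor(1.0, 24.0)
assert abs(factor ** 24 - 0.5) < 1e-9
```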
or float(status.get("protected_until") or 0.0) > now: - protected = True - break - current_hashes.append(hash_value) - if protected or not current_hashes: - continue - hashes_to_freeze.extend(current_hashes) - edges_to_deactivate.append((src, tgt)) - - if hashes_to_freeze: - self.metadata_store.mark_relations_inactive(hashes_to_freeze, inactive_since=now) - self.graph_store.deactivate_edges(edges_to_deactivate) - - cutoff = now - freeze_duration - expired_hashes = self.metadata_store.get_prune_candidates(cutoff) - if not expired_hashes: - return - relation_info = self.metadata_store.get_relations_subject_object_map(expired_hashes) - operations = [(src, tgt, hash_value) for hash_value, (src, tgt) in relation_info.items()] - if operations: - self.graph_store.prune_relation_hashes(operations) - deleted_hashes = [hash_value for hash_value in expired_hashes if hash_value in relation_info] - if deleted_hashes: - self.metadata_store.backup_and_delete_relations(deleted_hashes) - if self.vector_store is not None: - self.vector_store.delete(deleted_hashes) - - async def _orphan_gc_phase(self) -> None: - assert self.metadata_store is not None - assert self.graph_store is not None - orphan_cfg = self._cfg("memory.orphan", {}) or {} - if not bool(orphan_cfg.get("enable_soft_delete", True)): - return - entity_retention = max(0.0, float(orphan_cfg.get("entity_retention_days", 7.0) or 7.0)) * 86400.0 - paragraph_retention = max(0.0, float(orphan_cfg.get("paragraph_retention_days", 7.0) or 7.0)) * 86400.0 - grace_period = max(0.0, float(orphan_cfg.get("sweep_grace_hours", 24.0) or 24.0)) * 3600.0 - - isolated = self.graph_store.get_isolated_nodes(include_inactive=True) - if isolated: - entity_hashes = self.metadata_store.get_entity_gc_candidates(isolated, retention_seconds=entity_retention) - if entity_hashes: - self.metadata_store.mark_as_deleted(entity_hashes, "entity") - - paragraph_hashes = self.metadata_store.get_paragraph_gc_candidates(retention_seconds=paragraph_retention) - if paragraph_hashes: - self.metadata_store.mark_as_deleted(paragraph_hashes, "paragraph") - - dead_paragraphs = self.metadata_store.sweep_deleted_items("paragraph", grace_period) - if dead_paragraphs: - hashes = [str(item[0] or "").strip() for item in dead_paragraphs if item and str(item[0] or "").strip()] - if hashes: - self.metadata_store.physically_delete_paragraphs(hashes) - if self.vector_store is not None: - self.vector_store.delete(hashes) - - dead_entities = self.metadata_store.sweep_deleted_items("entity", grace_period) - if dead_entities: - entity_hashes = [str(item[0] or "").strip() for item in dead_entities if item and str(item[0] or "").strip()] - entity_names = [str(item[1] or "").strip() for item in dead_entities if item and str(item[1] or "").strip()] - if entity_names: - self.graph_store.delete_nodes(entity_names) - if entity_hashes: - self.metadata_store.physically_delete_entities(entity_hashes) - if self.vector_store is not None: - self.vector_store.delete(entity_hashes) - - def _mark_person_active(self, person_id: str) -> None: - token = str(person_id or "").strip() - if not token: - return - self._active_person_timestamps[token] = time.time() - - def _serialize_graph(self, *, limit: int = 200) -> Dict[str, Any]: - assert self.graph_store is not None - assert self.metadata_store is not None - nodes = self.graph_store.get_nodes() - if limit > 0: - nodes = nodes[:limit] - node_set = set(nodes) - node_payload = [] - for name in nodes: - attrs = self.graph_store.get_node_attributes(name) or {} - 
node_payload.append({"id": name, "name": name, "attributes": attrs}) - - edge_payload = [] - for source, target, relation_hashes in self.graph_store.iter_edge_hash_entries(): - if source not in node_set or target not in node_set: - continue - edge_payload.append( - { - "source": source, - "target": target, - "weight": float(self.graph_store.get_edge_weight(source, target)), - "relation_hashes": sorted(str(item) for item in relation_hashes if str(item).strip()), - } - ) - return { - "nodes": node_payload, - "edges": edge_payload, - "total_nodes": int(self.graph_store.num_nodes), - "total_edges": int(self.graph_store.num_edges), - } - - def _delete_sources(self, sources: Iterable[Any]) -> Dict[str, Any]: - assert self.metadata_store - source_tokens = self._tokens(sources) - if not source_tokens: - return {"success": False, "error": "source 不能为空"} - - deleted_paragraphs = 0 - deleted_sources: List[str] = [] - for source in source_tokens: - paragraphs = self.metadata_store.get_paragraphs_by_source(source) - if not paragraphs: - self.metadata_store.replace_episodes_for_source(source, []) - continue - for row in paragraphs: - paragraph_hash = str(row.get("hash", "") or "").strip() - if not paragraph_hash: - continue - cleanup = self.metadata_store.delete_paragraph_atomic(paragraph_hash) - self._apply_cleanup_plan(cleanup) - deleted_paragraphs += 1 - self.metadata_store.replace_episodes_for_source(source, []) - deleted_sources.append(source) - - self._rebuild_graph_from_metadata() - self._persist() - return { - "success": True, - "sources": deleted_sources, - "deleted_source_count": len(deleted_sources), - "deleted_paragraph_count": deleted_paragraphs, - } - - def _apply_cleanup_plan(self, cleanup: Dict[str, Any]) -> None: - if not isinstance(cleanup, dict): - return - if self.vector_store is not None: - vector_ids: List[str] = [] - paragraph_hash = str(cleanup.get("vector_id_to_remove", "") or "").strip() - if paragraph_hash: - vector_ids.append(paragraph_hash) - for _, _, relation_hash in cleanup.get("relation_prune_ops", []) or []: - token = str(relation_hash or "").strip() - if token: - vector_ids.append(token) - if vector_ids: - self.vector_store.delete(list(dict.fromkeys(vector_ids))) - - def _rebuild_graph_from_metadata(self) -> Dict[str, int]: - assert self.metadata_store is not None - assert self.graph_store is not None - entity_rows = self.metadata_store.query( - """ - SELECT name - FROM entities - WHERE is_deleted IS NULL OR is_deleted = 0 - ORDER BY name ASC - """ - ) - raw_relation_rows = self.metadata_store.query( - """ - SELECT subject, object, confidence, hash - FROM relations - WHERE is_inactive IS NULL OR is_inactive = 0 - """ - ) - relation_rows = [ - row - for row in raw_relation_rows - if str(row.get("subject", "") or "").strip() and str(row.get("object", "") or "").strip() - ] - - names = list( - dict.fromkeys( - [ - str(row.get("name", "") or "").strip() - for row in entity_rows - if str(row.get("name", "") or "").strip() - ] - + [ - str(row.get("subject", "") or "").strip() - for row in relation_rows - if str(row.get("subject", "") or "").strip() - ] - + [ - str(row.get("object", "") or "").strip() - for row in relation_rows - if str(row.get("object", "") or "").strip() - ] - ) - ) - self.graph_store.clear() - if names: - self.graph_store.add_nodes(names) - if relation_rows: - self.graph_store.add_edges( - [ - ( - str(row.get("subject", "") or "").strip(), - str(row.get("object", "") or "").strip(), - ) - for row in relation_rows - ], - weights=[float(row.get("confidence", 
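`_rebuild_graph_from_metadata` above deduplicates node names with `dict.fromkeys`, which preserves first-seen order while merging entity names with relation subjects and objects. Sketched:

```python
def merged_node_names(entities: list[str], subjects: list[str], objects: list[str]) -> list[str]:
    """Order-preserving dedup of all graph node names, as the rebuild does."""
    return list(dict.fromkeys(name.strip() for name in entities + subjects + objects if name.strip()))

assert merged_node_names(["a", "b"], ["b", "c"], ["a"]) == ["a", "b", "c"]
```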
1.0) or 1.0) for row in relation_rows], - relation_hashes=[str(row.get("hash", "") or "") for row in relation_rows], - ) - return {"node_count": int(self.graph_store.num_nodes), "edge_count": int(self.graph_store.num_edges)} - - def _rename_node(self, old_name: str, new_name: str) -> Dict[str, Any]: - assert self.metadata_store - source = str(old_name or "").strip() - target = str(new_name or "").strip() - if not source or not target: - return {"success": False, "error": "old_name/new_name 不能为空"} - if source == target: - return {"success": True, "renamed": False, "old_name": source, "new_name": target} - - conn = self.metadata_store.get_connection() - cursor = conn.cursor() - old_hash = compute_hash(source.lower()) - target_hash = compute_hash(target.lower()) - - cursor.execute( - """ - SELECT hash, name, vector_index, appearance_count, created_at, metadata - FROM entities - WHERE hash = ? - OR LOWER(TRIM(name)) = LOWER(TRIM(?)) - LIMIT 1 - """, - (old_hash, source), - ) - old_row = cursor.fetchone() - if old_row is None: - return {"success": False, "error": "原节点不存在"} - - cursor.execute( - """ - SELECT hash, appearance_count - FROM entities - WHERE hash = ? - OR LOWER(TRIM(name)) = LOWER(TRIM(?)) - LIMIT 1 - """, - (target_hash, target), - ) - target_row = cursor.fetchone() - - try: - cursor.execute("BEGIN IMMEDIATE") - if target_row is None: - cursor.execute( - """ - INSERT INTO entities (hash, name, vector_index, appearance_count, created_at, metadata, is_deleted, deleted_at) - VALUES (?, ?, ?, ?, ?, ?, 0, NULL) - """, - ( - target_hash, - target, - old_row["vector_index"], - old_row["appearance_count"], - old_row["created_at"], - old_row["metadata"], - ), - ) - resolved_target_hash = target_hash - else: - resolved_target_hash = str(target_row["hash"] or "").strip() - cursor.execute( - """ - UPDATE entities - SET name = ?, - appearance_count = COALESCE(appearance_count, 0) + ?, - is_deleted = 0, - deleted_at = NULL - WHERE hash = ? - """, - ( - target, - int(old_row["appearance_count"] or 0), - resolved_target_hash, - ), - ) - - cursor.execute( - "UPDATE OR IGNORE paragraph_entities SET entity_hash = ? WHERE entity_hash = ?", - (resolved_target_hash, old_row["hash"]), - ) - cursor.execute("DELETE FROM paragraph_entities WHERE entity_hash = ?", (old_row["hash"],)) - cursor.execute( - "UPDATE relations SET subject = ? WHERE LOWER(TRIM(subject)) = LOWER(TRIM(?))", - (target, old_row["name"]), - ) - cursor.execute( - "UPDATE relations SET object = ? WHERE LOWER(TRIM(object)) = LOWER(TRIM(?))", - (target, old_row["name"]), - ) - cursor.execute("DELETE FROM entities WHERE hash = ?", (old_row["hash"],)) - conn.commit() - except Exception as exc: - conn.rollback() - return {"success": False, "error": f"rename failed: {exc}"} - - self._rebuild_graph_from_metadata() - self._persist() - return {"success": True, "renamed": True, "old_name": source, "new_name": target} - - def _update_edge_weight( - self, - *, - relation_hash: str, - subject: str, - obj: str, - weight: float, - ) -> Dict[str, Any]: - assert self.metadata_store - conn = self.metadata_store.get_connection() - cursor = conn.cursor() - target_weight = max(0.0, float(weight or 0.0)) - if relation_hash: - cursor.execute("UPDATE relations SET confidence = ? WHERE hash = ?", (target_weight, relation_hash)) - updated = cursor.rowcount - else: - cursor.execute( - """ - UPDATE relations - SET confidence = ? 
- WHERE LOWER(TRIM(subject)) = LOWER(TRIM(?)) - AND LOWER(TRIM(object)) = LOWER(TRIM(?)) - """, - (target_weight, subject, obj), - ) - updated = cursor.rowcount - conn.commit() - if updated <= 0: - return {"success": False, "error": "未找到可更新的关系"} - self._rebuild_graph_from_metadata() - self._persist() - return { - "success": True, - "updated": int(updated), - "weight": target_weight, - "hash": relation_hash, - "subject": subject, - "object": obj, - } - - @staticmethod - def _tokens(values: Optional[Iterable[Any]]) -> List[str]: - result: List[str] = [] - seen = set() - for item in values or []: - token = str(item or "").strip() - if not token or token in seen: - continue - seen.add(token) - result.append(token) - return result - - @classmethod - def _merge_tokens(cls, *groups: Optional[Iterable[Any]]) -> List[str]: - merged: List[str] = [] - seen = set() - for group in groups: - for item in cls._tokens(group): - if item in seen: - continue - seen.add(item) - merged.append(item) - return merged - - @staticmethod - def _build_source(source_type: str, chat_id: str, person_ids: Sequence[str]) -> str: - clean_type = str(source_type or "").strip() or "memory" - if clean_type == "chat_summary" and chat_id: - return f"chat_summary:{chat_id}" - if clean_type == "person_fact" and person_ids: - return f"person_fact:{person_ids[0]}" - return f"{clean_type}:{chat_id}" if chat_id else clean_type - - @staticmethod - def _chat_source(chat_id: str) -> Optional[str]: - clean = str(chat_id or "").strip() - return f"chat_summary:{clean}" if clean else None - - @staticmethod - def _resolve_knowledge_type(source_type: str) -> str: - clean_type = str(source_type or "").strip().lower() - if clean_type == "person_fact": - return "factual" - if clean_type == "chat_summary": - return "narrative" - return "mixed" - - @staticmethod - def _time_meta(timestamp: Optional[float], time_start: Optional[float], time_end: Optional[float]) -> Dict[str, Any]: - payload: Dict[str, Any] = {} - if timestamp is not None: - payload["event_time"] = float(timestamp) - if time_start is not None: - payload["event_time_start"] = float(time_start) - if time_end is not None: - payload["event_time_end"] = float(time_end) - if payload: - payload["time_granularity"] = "minute" - payload["time_confidence"] = 0.95 - return payload - - @classmethod - def _normalize_search_time_bound(cls, value: Any, *, is_end: bool) -> tuple[Optional[float], Optional[str]]: - if value in {None, ""}: - return None, None - if isinstance(value, (int, float)): - ts = float(value) - return ts, format_timestamp(ts) - - text = str(value or "").strip() - if not text: - return None, None - - numeric = cls._optional_float(text) - if numeric is not None: - return numeric, format_timestamp(numeric) - - try: - ts = parse_query_datetime_to_timestamp(text, is_end=is_end) - except ValueError as exc: - raise ValueError(f"时间参数错误: {exc}") from exc - return ts, text - - @classmethod - def _normalize_search_time_window(cls, time_start: Any, time_end: Any) -> _NormalizedSearchTimeWindow: - numeric_start, query_start = cls._normalize_search_time_bound(time_start, is_end=False) - numeric_end, query_end = cls._normalize_search_time_bound(time_end, is_end=True) - if numeric_start is not None and numeric_end is not None and numeric_start > numeric_end: - raise ValueError("时间参数错误: time_start 不能晚于 time_end") - return _NormalizedSearchTimeWindow( - numeric_start=numeric_start, - numeric_end=numeric_end, - query_start=query_start, - query_end=query_end, - ) - - @staticmethod - def 
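# A sketch of the time-bound normalization above: a bound may arrive as
# epoch seconds or as text; numbers pass through, numeric strings are
# coerced, anything else goes to a parser. `parse_text_time` is a stand-in
# for the plugin's real datetime parser.
from typing import Callable, Optional

def normalize_bound(value, parse_text_time: Callable[[str], float]) -> Optional[float]:
    if value in (None, ""):
        return None
    if isinstance(value, (int, float)):
        return float(value)
    text = str(value).strip()
    try:
        return float(text)              # e.g. "1700000000"
    except ValueError:
        return parse_text_time(text)    # e.g. "2026-03-18 09:00"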
_retrieval_result_hit(item: RetrievalResult) -> Dict[str, Any]: - payload = item.to_dict() - return { - "hash": payload.get("hash", ""), - "content": payload.get("content", ""), - "score": payload.get("score", 0.0), - "type": payload.get("type", ""), - "source": payload.get("source", ""), - "metadata": payload.get("metadata", {}) or {}, - } - - @staticmethod - def _episode_hit(row: Dict[str, Any]) -> Dict[str, Any]: - return { - "type": "episode", - "episode_id": str(row.get("episode_id", "") or ""), - "title": str(row.get("title", "") or ""), - "content": str(row.get("summary", "") or ""), - "score": float(row.get("lexical_score", 0.0) or 0.0), - "source": "episode", - "metadata": { - "participants": row.get("participants", []) or [], - "keywords": row.get("keywords", []) or [], - "source": row.get("source"), - "event_time_start": row.get("event_time_start"), - "event_time_end": row.get("event_time_end"), - }, - } - - @staticmethod - def _summary(hits: Sequence[Dict[str, Any]]) -> str: - if not hits: - return "" - lines = [] - for index, item in enumerate(hits[:5], start=1): - content = str(item.get("content", "") or "").strip().replace("\n", " ") - lines.append(f"{index}. {(content[:120] + '...') if len(content) > 120 else content}") - return "\n".join(lines) - - @staticmethod - def _filter_hits(hits: List[Dict[str, Any]], person_id: str) -> List[Dict[str, Any]]: - if not person_id: - return hits - filtered = [] - for item in hits: - metadata = item.get("metadata", {}) or {} - if person_id in (metadata.get("person_ids", []) or []): - filtered.append(item) - continue - if person_id and person_id in str(item.get("content", "") or ""): - filtered.append(item) - return filtered or hits - - def _resolve_relation_hashes(self, target: str) -> List[str]: - assert self.metadata_store - token = str(target or "").strip() - if not token: - return [] - if len(token) == 64 and all(ch in "0123456789abcdef" for ch in token.lower()): - return [token] - hashes = self.metadata_store.search_relation_hashes_by_text(token, limit=10) - if hashes: - return hashes - return [ - str(row.get("hash", "") or "") - for row in self.metadata_store.get_relations(subject=token)[:10] - if str(row.get("hash", "")).strip() - ] - - def _resolve_deleted_relation_hashes(self, target: str) -> List[str]: - assert self.metadata_store - token = str(target or "").strip() - if not token: - return [] - if len(token) == 64 and all(ch in "0123456789abcdef" for ch in token.lower()): - return [token] - return self.metadata_store.search_deleted_relation_hashes_by_text(token, limit=10) - - def _memory_v5_status(self, *, target: str = "", limit: int = 50) -> Dict[str, Any]: - assert self.metadata_store - now = time.time() - summary = self.metadata_store.get_memory_status_summary(now) - payload: Dict[str, Any] = { - "success": True, - **summary, - "config": { - "half_life_hours": float(self._cfg("memory.half_life_hours", 24.0) or 24.0), - "base_decay_interval_hours": float(self._cfg("memory.base_decay_interval_hours", 1.0) or 1.0), - "prune_threshold": float(self._cfg("memory.prune_threshold", 0.1) or 0.1), - "freeze_duration_hours": float(self._cfg("memory.freeze_duration_hours", 24.0) or 24.0), - }, - "last_maintenance_at": self._last_maintenance_at, - } - token = str(target or "").strip() - if not token: - return payload - - active_hashes = self._resolve_relation_hashes(token)[:limit] - deleted_hashes = self._resolve_deleted_relation_hashes(token)[:limit] - active_statuses = self.metadata_store.get_relation_status_batch(active_hashes) - 
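# A sketch of the "already a hash?" test the resolvers above rely on: a
# 64-character lowercase-hex token is treated as a SHA-256 hash, anything
# else as a text query.
_HEX = set("0123456789abcdef")

def looks_like_sha256(token: str) -> bool:
    token = token.strip().lower()
    return len(token) == 64 and all(ch in _HEX for ch in token)

assert looks_like_sha256("ab" * 32)
assert not looks_like_sha256("Alice likes tea")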
items: List[Dict[str, Any]] = [] - for hash_value in active_hashes: - relation = self.metadata_store.get_relation(hash_value) or {} - status = active_statuses.get(hash_value, {}) - items.append( - { - "hash": hash_value, - "subject": str(relation.get("subject", "") or ""), - "predicate": str(relation.get("predicate", "") or ""), - "object": str(relation.get("object", "") or ""), - "state": "inactive" if bool(status.get("is_inactive")) else "active", - "is_pinned": bool(status.get("is_pinned", False)), - "temp_protected": bool(float(status.get("protected_until") or 0.0) > now), - "protected_until": status.get("protected_until"), - "last_reinforced": status.get("last_reinforced"), - "weight": float(status.get("weight", relation.get("confidence", 0.0)) or 0.0), - } - ) - for hash_value in deleted_hashes: - relation = self.metadata_store.get_deleted_relation(hash_value) or {} - items.append( - { - "hash": hash_value, - "subject": str(relation.get("subject", "") or ""), - "predicate": str(relation.get("predicate", "") or ""), - "object": str(relation.get("object", "") or ""), - "state": "deleted", - "is_pinned": bool(relation.get("is_pinned", False)), - "temp_protected": False, - "protected_until": relation.get("protected_until"), - "last_reinforced": relation.get("last_reinforced"), - "weight": float(relation.get("confidence", 0.0) or 0.0), - "deleted_at": relation.get("deleted_at"), - } - ) - payload["items"] = items[:limit] - payload["count"] = len(payload["items"]) - payload["target"] = token - return payload - - def _adjust_relation_confidence(self, hashes: List[str], *, delta: float) -> Dict[str, float]: - assert self.metadata_store - normalized = [str(item or "").strip() for item in hashes if str(item or "").strip()] - if not normalized: - return {} - conn = self.metadata_store.get_connection() - cursor = conn.cursor() - chunk_size = 200 - for index in range(0, len(normalized), chunk_size): - chunk = normalized[index : index + chunk_size] - placeholders = ",".join(["?"] * len(chunk)) - cursor.execute( - f""" - UPDATE relations - SET confidence = MAX(0.0, COALESCE(confidence, 0.0) + ?) 
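# A sketch of the chunked bulk-update pattern above: SQLite caps the number
# of bound parameters per statement, so hashes are updated in fixed-size
# slices. `conn` is an sqlite3 connection; table/column names are illustrative.
def bump_confidence(conn, hashes, delta: float, chunk_size: int = 200) -> None:
    cur = conn.cursor()
    for start in range(0, len(hashes), chunk_size):
        chunk = hashes[start:start + chunk_size]
        marks = ",".join("?" * len(chunk))
        cur.execute(
            "UPDATE relations "
            "SET confidence = MAX(0.0, COALESCE(confidence, 0.0) + ?) "
            f"WHERE hash IN ({marks})",
            [delta, *chunk],
        )
    conn.commit()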
- WHERE hash IN ({placeholders}) - """, - tuple([float(delta)] + chunk), - ) - conn.commit() - statuses = self.metadata_store.get_relation_status_batch(normalized) - return {hash_value: float((statuses.get(hash_value) or {}).get("weight", 0.0) or 0.0) for hash_value in normalized} - - def _apply_v5_relation_action(self, *, action: str, hashes: List[str], strength: float = 1.0) -> Dict[str, Any]: - assert self.metadata_store - act = str(action or "").strip().lower() - normalized = [str(item or "").strip() for item in hashes if str(item or "").strip()] - if not normalized: - return {"success": False, "error": "未命中可维护关系"} - - now = time.time() - strength_value = max(0.1, float(strength or 1.0)) - prune_threshold = max(0.0, float(self._cfg("memory.prune_threshold", 0.1) or 0.1)) - detail = "" - - if act == "reinforce": - weights = self._adjust_relation_confidence(normalized, delta=0.5 * strength_value) - protect_hours = max(1.0, 24.0 * strength_value) - self.metadata_store.reinforce_relations(normalized) - self.metadata_store.mark_relations_active(normalized, boost_weight=max(prune_threshold, 0.1)) - self.metadata_store.update_relations_protection( - normalized, - protected_until=now + protect_hours * 3600.0, - last_reinforced=now, - ) - detail = f"reinforce {len(normalized)} 条关系" - elif act == "weaken": - weights = self._adjust_relation_confidence(normalized, delta=-0.5 * strength_value) - to_freeze = [hash_value for hash_value, weight in weights.items() if weight <= prune_threshold] - if to_freeze: - self.metadata_store.mark_relations_inactive(to_freeze, inactive_since=now) - detail = f"weaken {len(normalized)} 条关系" - elif act == "remember_forever": - self.metadata_store.mark_relations_active(normalized, boost_weight=max(prune_threshold, 0.1)) - self.metadata_store.update_relations_protection(normalized, protected_until=0.0, is_pinned=True) - weights = {hash_value: float((self.metadata_store.get_relation_status_batch([hash_value]).get(hash_value) or {}).get("weight", 0.0) or 0.0) for hash_value in normalized} - detail = f"remember_forever {len(normalized)} 条关系" - elif act == "forget": - weights = self._adjust_relation_confidence(normalized, delta=-2.0 * strength_value) - self.metadata_store.update_relations_protection(normalized, protected_until=0.0, is_pinned=False) - self.metadata_store.mark_relations_inactive(normalized, inactive_since=now) - detail = f"forget {len(normalized)} 条关系" - else: - return {"success": False, "error": f"不支持的 V5 动作: {act}"} - - self._rebuild_graph_from_metadata() - self._last_maintenance_at = now - self._persist() - statuses = self.metadata_store.get_relation_status_batch(normalized) - return { - "success": True, - "detail": detail, - "hashes": normalized, - "count": len(normalized), - "weights": weights, - "statuses": statuses, - } - - async def _ensure_vector_for_text(self, *, item_hash: str, text: str) -> bool: - if self.vector_store is None or self.embedding_manager is None: - return False - token = str(item_hash or "").strip() - content = str(text or "").strip() - if not token or not content: - return False - embedding = await self.embedding_manager.encode([content], dimensions=self.embedding_dimension) - if getattr(embedding, "ndim", 1) == 1: - embedding = embedding.reshape(1, -1) - if getattr(embedding, "size", 0) <= 0: - return False - try: - self.vector_store.add(embedding, [token]) - return True - except Exception as exc: - logger.warning(f"重建向量失败: {exc}") - return False - - async def _ensure_relation_vector(self, relation: Dict[str, Any]) -> bool: - if 
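# A pure-Python model of the reinforce/weaken bookkeeping above, under
# assumed field names: reinforcing raises the weight and extends a
# protection window; weakening lowers it and freezes relations at or below
# the prune threshold. A sketch of the semantics, not the plugin's API.
import time

def apply_action(rel: dict, action: str, strength: float = 1.0,
                 prune_threshold: float = 0.1) -> dict:
    now = time.time()
    strength = max(0.1, strength)
    if action == "reinforce":
        rel["weight"] = rel.get("weight", 0.0) + 0.5 * strength
        rel["protected_until"] = now + max(1.0, 24.0 * strength) * 3600.0
    elif action == "weaken":
        rel["weight"] = max(0.0, rel.get("weight", 0.0) - 0.5 * strength)
        if rel["weight"] <= prune_threshold:
            rel["inactive_since"] = now
    return rel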
not bool(self.relation_vectors_enabled): - return False - return await self._ensure_vector_for_text( - item_hash=str(relation.get("hash", "") or ""), - text=" ".join( - [ - str(relation.get("subject", "") or "").strip(), - str(relation.get("predicate", "") or "").strip(), - str(relation.get("object", "") or "").strip(), - ] - ).strip(), - ) - - async def _ensure_paragraph_vector(self, paragraph: Dict[str, Any]) -> bool: - return await self._ensure_vector_for_text( - item_hash=str(paragraph.get("hash", "") or ""), - text=str(paragraph.get("content", "") or ""), - ) - - async def _ensure_entity_vector(self, entity: Dict[str, Any]) -> bool: - return await self._ensure_vector_for_text( - item_hash=str(entity.get("hash", "") or ""), - text=str(entity.get("name", "") or ""), - ) - - async def _restore_relation_hashes( - self, - hashes: List[str], - *, - payloads: Optional[Dict[str, Dict[str, Any]]] = None, - rebuild_graph: bool = True, - persist: bool = True, - ) -> Dict[str, Any]: - assert self.metadata_store - restored: List[str] = [] - failures: List[Dict[str, str]] = [] - conn = self.metadata_store.get_connection() - cursor = conn.cursor() - payload_map = payloads or {} - for hash_value in [str(item or "").strip() for item in hashes if str(item or "").strip()]: - relation = self.metadata_store.restore_relation(hash_value) - if relation is None: - relation = self.metadata_store.get_relation(hash_value) - if relation is None: - failures.append({"hash": hash_value, "error": "relation 不存在"}) - continue - payload = payload_map.get(hash_value) if isinstance(payload_map.get(hash_value), dict) else {} - paragraph_hashes = self._tokens(payload.get("paragraph_hashes")) - for paragraph_hash in paragraph_hashes: - cursor.execute( - """ - INSERT OR IGNORE INTO paragraph_relations (paragraph_hash, relation_hash) - VALUES (?, ?) 
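# The restore path above leans on INSERT OR IGNORE so re-linking a
# paragraph to a relation is idempotent; a tiny standalone demonstration:
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE paragraph_relations ("
             "paragraph_hash TEXT, relation_hash TEXT, "
             "PRIMARY KEY (paragraph_hash, relation_hash))")
for _ in range(2):  # the second insert is silently ignored
    conn.execute("INSERT OR IGNORE INTO paragraph_relations VALUES (?, ?)",
                 ("p1", "r1"))
assert conn.execute("SELECT COUNT(*) FROM paragraph_relations").fetchone()[0] == 1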
- """, - (paragraph_hash, hash_value), - ) - await self._ensure_relation_vector({**relation, "hash": hash_value}) - restored.append(hash_value) - conn.commit() - if restored and rebuild_graph: - self._rebuild_graph_from_metadata() - if restored and persist: - self._persist() - return {"restored_hashes": restored, "restored_count": len(restored), "failures": failures} - - @staticmethod - def _selector_dict(selector: Any) -> Dict[str, Any]: - if isinstance(selector, dict): - return dict(selector) - if isinstance(selector, (list, tuple)): - return {"items": list(selector)} - token = str(selector or "").strip() - return {"query": token} if token else {} - - def _resolve_paragraph_targets(self, selector: Any, *, include_deleted: bool = False) -> List[Dict[str, Any]]: - assert self.metadata_store - raw = self._selector_dict(selector) - rows: List[Dict[str, Any]] = [] - hashes = self._merge_tokens(raw.get("hashes"), raw.get("items"), [raw.get("hash")]) - for hash_value in hashes: - row = self.metadata_store.get_paragraph(hash_value) - if row is None: - continue - if not include_deleted and bool(row.get("is_deleted", 0)): - continue - rows.append(row) - if rows: - return rows - query = str(raw.get("query", "") or raw.get("content", "") or "").strip() - if not query: - return [] - if len(query) == 64 and all(ch in "0123456789abcdef" for ch in query.lower()): - row = self.metadata_store.get_paragraph(query) - if row is None: - return [] - if not include_deleted and bool(row.get("is_deleted", 0)): - return [] - return [row] - matches = self.metadata_store.search_paragraphs_by_content(query) - return [row for row in matches if include_deleted or not bool(row.get("is_deleted", 0))] - - def _resolve_entity_targets(self, selector: Any, *, include_deleted: bool = False) -> List[Dict[str, Any]]: - assert self.metadata_store - raw = self._selector_dict(selector) - rows: List[Dict[str, Any]] = [] - hashes = self._merge_tokens(raw.get("hashes"), raw.get("items"), [raw.get("hash")]) - for hash_value in hashes: - row = self.metadata_store.get_entity(hash_value) - if row is None: - continue - if not include_deleted and bool(row.get("is_deleted", 0)): - continue - rows.append(row) - names = self._merge_tokens(raw.get("names"), [raw.get("name")], [raw.get("query")]) - for name in names: - if not name: - continue - matches = self.metadata_store.query( - """ - SELECT * - FROM entities - WHERE LOWER(TRIM(name)) = LOWER(TRIM(?)) - OR hash = ? 
- ORDER BY appearance_count DESC, created_at ASC - """, - (name, compute_hash(str(name).strip().lower())), - ) - for row in matches: - if not include_deleted and bool(row.get("is_deleted", 0)): - continue - rows.append(self.metadata_store._row_to_dict(row, "entity") if hasattr(self.metadata_store, "_row_to_dict") else row) - dedup: Dict[str, Dict[str, Any]] = {} - for row in rows: - token = str(row.get("hash", "") or "").strip() - if token and token not in dedup: - dedup[token] = row - return list(dedup.values()) - - def _resolve_source_targets(self, selector: Any) -> List[str]: - raw = self._selector_dict(selector) - return self._merge_tokens(raw.get("sources"), [raw.get("source")], [raw.get("query")], raw.get("items")) - - def _snapshot_relation_item(self, hash_value: str) -> Optional[Dict[str, Any]]: - assert self.metadata_store - relation = self.metadata_store.get_relation(hash_value) - if relation is None: - relation = self.metadata_store.get_deleted_relation(hash_value) - if relation is None: - return None - paragraph_hashes = [ - str(row.get("paragraph_hash", "") or "").strip() - for row in self.metadata_store.query( - "SELECT paragraph_hash FROM paragraph_relations WHERE relation_hash = ? ORDER BY paragraph_hash ASC", - (hash_value,), - ) - if str(row.get("paragraph_hash", "") or "").strip() - ] - return { - "item_type": "relation", - "item_hash": hash_value, - "item_key": hash_value, - "payload": { - "relation": relation, - "paragraph_hashes": paragraph_hashes, - }, - } - - def _snapshot_paragraph_item(self, hash_value: str) -> Optional[Dict[str, Any]]: - assert self.metadata_store - paragraph = self.metadata_store.get_paragraph(hash_value) - if paragraph is None: - return None - entity_links = [ - { - "paragraph_hash": hash_value, - "entity_hash": str(row.get("entity_hash", "") or ""), - "mention_count": int(row.get("mention_count", 1) or 1), - } - for row in self.metadata_store.query( - """ - SELECT paragraph_hash, entity_hash, mention_count - FROM paragraph_entities - WHERE paragraph_hash = ? - ORDER BY entity_hash ASC - """, - (hash_value,), - ) - ] - relation_hashes = [ - str(row.get("relation_hash", "") or "").strip() - for row in self.metadata_store.query( - """ - SELECT relation_hash - FROM paragraph_relations - WHERE paragraph_hash = ? - ORDER BY relation_hash ASC - """, - (hash_value,), - ) - if str(row.get("relation_hash", "") or "").strip() - ] - return { - "item_type": "paragraph", - "item_hash": hash_value, - "item_key": hash_value, - "payload": { - "paragraph": paragraph, - "entity_links": entity_links, - "relation_hashes": relation_hashes, - "external_refs": self.metadata_store.list_external_memory_refs_by_paragraphs([hash_value]), - }, - } - - def _snapshot_entity_item(self, hash_value: str) -> Optional[Dict[str, Any]]: - assert self.metadata_store - entity = self.metadata_store.get_entity(hash_value) - if entity is None: - return None - paragraph_links = [ - { - "paragraph_hash": str(row.get("paragraph_hash", "") or ""), - "entity_hash": hash_value, - "mention_count": int(row.get("mention_count", 1) or 1), - } - for row in self.metadata_store.query( - """ - SELECT paragraph_hash, mention_count - FROM paragraph_entities - WHERE entity_hash = ? 
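# A hedged sketch of the snapshot-before-delete idea above: capture the row
# plus its link-table entries so a later restore can replay them. The
# payload shape mirrors the `_snapshot_*_item` helpers in spirit only.
def snapshot_paragraph(conn, paragraph_hash: str) -> dict:
    paragraph = conn.execute(
        "SELECT * FROM paragraphs WHERE hash = ?", (paragraph_hash,)
    ).fetchone()
    relation_hashes = [
        row[0] for row in conn.execute(
            "SELECT relation_hash FROM paragraph_relations WHERE paragraph_hash = ?",
            (paragraph_hash,),
        )
    ]
    return {
        "item_type": "paragraph",
        "item_hash": paragraph_hash,
        "payload": {"paragraph": paragraph, "relation_hashes": relation_hashes},
    }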
- ORDER BY paragraph_hash ASC - """, - (hash_value,), - ) - ] - return { - "item_type": "entity", - "item_hash": hash_value, - "item_key": hash_value, - "payload": { - "entity": entity, - "paragraph_links": paragraph_links, - }, - } - - def _relation_has_remaining_paragraphs(self, relation_hash: str, removing_hashes: Sequence[str]) -> bool: - assert self.metadata_store - excluded = [str(item or "").strip() for item in removing_hashes if str(item or "").strip()] - conn = self.metadata_store.get_connection() - cursor = conn.cursor() - if excluded: - placeholders = ",".join(["?"] * len(excluded)) - cursor.execute( - f""" - SELECT 1 - FROM paragraph_relations pr - JOIN paragraphs p ON p.hash = pr.paragraph_hash - WHERE pr.relation_hash = ? - AND pr.paragraph_hash NOT IN ({placeholders}) - AND (p.is_deleted IS NULL OR p.is_deleted = 0) - LIMIT 1 - """, - tuple([relation_hash] + excluded), - ) - else: - cursor.execute( - """ - SELECT 1 - FROM paragraph_relations pr - JOIN paragraphs p ON p.hash = pr.paragraph_hash - WHERE pr.relation_hash = ? - AND (p.is_deleted IS NULL OR p.is_deleted = 0) - LIMIT 1 - """, - (relation_hash,), - ) - return cursor.fetchone() is not None - - async def _build_delete_plan(self, *, mode: str, selector: Any) -> Dict[str, Any]: - assert self.metadata_store - act_mode = str(mode or "").strip().lower() - normalized_selector = self._selector_dict(selector) - items: List[Dict[str, Any]] = [] - counts = {"relations": 0, "paragraphs": 0, "entities": 0, "sources": 0} - vector_ids: List[str] = [] - sources: List[str] = [] - target_hashes: Dict[str, List[str]] = { - "relations": [], - "paragraphs": [], - "entities": [], - "sources": [], - "matched_sources": [], - } - - if act_mode == "relation": - relation_rows = [row for row in (self.metadata_store.get_relation(hash_value) for hash_value in self._resolve_relation_hashes(str(normalized_selector.get("query", "") or ""))) if row] - if normalized_selector.get("hashes"): - relation_rows = [ - row - for hash_value in self._tokens(normalized_selector.get("hashes")) - for row in [self.metadata_store.get_relation(hash_value)] - if row is not None - ] - dedup_hashes: List[str] = [] - seen = set() - for row in relation_rows: - hash_value = str(row.get("hash", "") or "").strip() - if hash_value and hash_value not in seen: - seen.add(hash_value) - dedup_hashes.append(hash_value) - snap = self._snapshot_relation_item(hash_value) - if snap: - items.append(snap) - vector_ids.append(hash_value) - counts["relations"] = len(dedup_hashes) - target_hashes["relations"] = dedup_hashes - - elif act_mode in {"paragraph", "source"}: - paragraph_rows: List[Dict[str, Any]] = [] - if act_mode == "source": - source_tokens = self._resolve_source_targets(normalized_selector) - target_hashes["sources"] = source_tokens - counts["requested_sources"] = len(source_tokens) - matched_source_tokens: List[str] = [] - for source in source_tokens: - source_rows = self.metadata_store.query( - """ - SELECT * - FROM paragraphs - WHERE source = ? 
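# A sketch of the orphan test above: a relation is orphaned when no live
# paragraph outside the removal set still references it. The empty-removal
# case is simplified here.
def is_orphan(conn, relation_hash: str, removing: list) -> bool:
    marks = ",".join("?" * len(removing)) or "''"
    row = conn.execute(
        "SELECT 1 FROM paragraph_relations pr "
        "JOIN paragraphs p ON p.hash = pr.paragraph_hash "
        "WHERE pr.relation_hash = ? "
        f"AND pr.paragraph_hash NOT IN ({marks}) "
        "AND (p.is_deleted IS NULL OR p.is_deleted = 0) LIMIT 1",
        [relation_hash, *removing],
    ).fetchone()
    return row is None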
- AND (is_deleted IS NULL OR is_deleted = 0) - ORDER BY created_at ASC - """, - (source,), - ) - if source_rows: - matched_source_tokens.append(source) - sources.append(source) - paragraph_rows.extend(source_rows) - target_hashes["matched_sources"] = matched_source_tokens - counts["sources"] = len(matched_source_tokens) - counts["matched_sources"] = len(matched_source_tokens) - else: - paragraph_rows = self._resolve_paragraph_targets(normalized_selector, include_deleted=False) - paragraph_hashes = self._tokens([row.get("hash", "") for row in paragraph_rows]) - target_hashes["paragraphs"] = paragraph_hashes - counts["paragraphs"] = len(paragraph_hashes) - for hash_value in paragraph_hashes: - snap = self._snapshot_paragraph_item(hash_value) - if snap: - items.append(snap) - vector_ids.append(hash_value) - paragraph = snap["payload"].get("paragraph") or {} - source = str(paragraph.get("source", "") or "").strip() - if source: - sources.append(source) - - orphan_relations: List[str] = [] - for item in items: - if item.get("item_type") != "paragraph": - continue - for relation_hash in self._tokens((item.get("payload") or {}).get("relation_hashes")): - if relation_hash in orphan_relations: - continue - if not self._relation_has_remaining_paragraphs(relation_hash, paragraph_hashes): - orphan_relations.append(relation_hash) - for relation_hash in orphan_relations: - snap = self._snapshot_relation_item(relation_hash) - if snap: - items.append(snap) - vector_ids.append(relation_hash) - target_hashes["relations"] = orphan_relations - counts["relations"] = len(orphan_relations) - - elif act_mode == "entity": - entity_rows = self._resolve_entity_targets(normalized_selector, include_deleted=False) - entity_hashes = self._tokens([row.get("hash", "") for row in entity_rows]) - target_hashes["entities"] = entity_hashes - counts["entities"] = len(entity_hashes) - entity_names = [str(row.get("name", "") or "").strip() for row in entity_rows if str(row.get("name", "") or "").strip()] - for hash_value in entity_hashes: - snap = self._snapshot_entity_item(hash_value) - if snap: - items.append(snap) - vector_ids.append(hash_value) - relation_hashes: List[str] = [] - for entity_name in entity_names: - for relation in self.metadata_store.get_relations(subject=entity_name) + self.metadata_store.get_relations(object=entity_name): - hash_value = str(relation.get("hash", "") or "").strip() - if hash_value and hash_value not in relation_hashes: - relation_hashes.append(hash_value) - for relation_hash in relation_hashes: - snap = self._snapshot_relation_item(relation_hash) - if snap: - items.append(snap) - vector_ids.append(relation_hash) - target_hashes["relations"] = relation_hashes - counts["relations"] = len(relation_hashes) - else: - return {"success": False, "error": f"不支持的 delete mode: {act_mode}"} - - sources = self._tokens(sources) - vector_ids = self._tokens(vector_ids) - primary_count = counts.get(f"{act_mode}s", 0) if act_mode != "source" else counts.get("matched_sources", 0) - success = ( - primary_count > 0 or counts.get("paragraphs", 0) > 0 or counts.get("relations", 0) > 0 - if act_mode != "source" - else (counts.get("matched_sources", 0) > 0 and counts.get("paragraphs", 0) > 0) - ) - return { - "success": success, - "mode": act_mode, - "selector": normalized_selector, - "items": items, - "counts": counts, - "vector_ids": vector_ids, - "sources": sources, - "target_hashes": target_hashes, - "requested_source_count": counts.get("requested_sources", 0) if act_mode == "source" else 0, - 
"matched_source_count": counts.get("matched_sources", 0) if act_mode == "source" else 0, - "error": "" if success else "未命中可删除内容", - } - - async def _preview_delete_action(self, *, mode: str, selector: Any) -> Dict[str, Any]: - plan = await self._build_delete_plan(mode=mode, selector=selector) - if not plan.get("success", False): - return {"success": False, "error": plan.get("error", "未命中可删除内容")} - preview_items = [ - { - "item_type": str(item.get("item_type", "") or ""), - "item_hash": str(item.get("item_hash", "") or ""), - } - for item in plan.get("items", [])[:100] - ] - return { - "success": True, - "mode": plan.get("mode"), - "selector": plan.get("selector"), - "counts": plan.get("counts", {}), - "requested_source_count": int(plan.get("requested_source_count", 0) or 0), - "matched_source_count": int(plan.get("matched_source_count", 0) or 0), - "sources": plan.get("sources", []), - "vector_ids": plan.get("vector_ids", []), - "items": preview_items, - "item_count": len(plan.get("items", [])), - "dry_run": True, - } - - async def _execute_delete_action( - self, - *, - mode: str, - selector: Any, - requested_by: str = "", - reason: str = "", - ) -> Dict[str, Any]: - assert self.metadata_store - plan = await self._build_delete_plan(mode=mode, selector=selector) - if not plan.get("success", False): - return {"success": False, "error": plan.get("error", "未命中可删除内容")} - - act_mode = str(plan.get("mode", "") or "").strip().lower() - conn = self.metadata_store.get_connection() - cursor = conn.cursor() - paragraph_hashes = self._tokens((plan.get("target_hashes") or {}).get("paragraphs")) - entity_hashes = self._tokens((plan.get("target_hashes") or {}).get("entities")) - relation_hashes = self._tokens((plan.get("target_hashes") or {}).get("relations")) - requested_source_tokens = self._tokens((plan.get("target_hashes") or {}).get("sources")) - matched_source_tokens = self._tokens((plan.get("target_hashes") or {}).get("matched_sources")) - - try: - if paragraph_hashes: - self.metadata_store.mark_as_deleted(paragraph_hashes, "paragraph") - cursor.execute( - f"DELETE FROM paragraph_entities WHERE paragraph_hash IN ({','.join(['?'] * len(paragraph_hashes))})", - tuple(paragraph_hashes), - ) - cursor.execute( - f"DELETE FROM paragraph_relations WHERE paragraph_hash IN ({','.join(['?'] * len(paragraph_hashes))})", - tuple(paragraph_hashes), - ) - self.metadata_store.delete_external_memory_refs_by_paragraphs(paragraph_hashes) - if act_mode == "source" and matched_source_tokens: - for source in matched_source_tokens: - self.metadata_store.replace_episodes_for_source(source, []) - - if entity_hashes: - self.metadata_store.mark_as_deleted(entity_hashes, "entity") - cursor.execute( - f"DELETE FROM paragraph_entities WHERE entity_hash IN ({','.join(['?'] * len(entity_hashes))})", - tuple(entity_hashes), - ) - - conn.commit() - - deleted_relations = self.metadata_store.backup_and_delete_relations(relation_hashes) - deleted_vectors = 0 - if self.vector_store is not None and plan.get("vector_ids"): - deleted_vectors = self.vector_store.delete(list(plan.get("vector_ids") or [])) - - operation = self.metadata_store.create_delete_operation( - mode=act_mode, - selector=plan.get("selector"), - items=plan.get("items", []), - reason=reason, - requested_by=requested_by, - summary={ - "counts": plan.get("counts", {}), - "sources": plan.get("sources", []), - "vector_ids": plan.get("vector_ids", []), - "deleted_relation_rows": deleted_relations, - }, - ) - - if plan.get("sources"): - 
self.metadata_store._enqueue_episode_source_rebuilds(list(plan.get("sources") or []), reason="delete_admin_execute") - self._rebuild_graph_from_metadata() - self._persist() - deleted_count = ( - len(paragraph_hashes) - if act_mode == "source" - else len(paragraph_hashes) - if act_mode == "paragraph" - else len(entity_hashes) - if act_mode == "entity" - else len(relation_hashes) - ) - success = bool(deleted_count > 0) - result = { - "success": success, - "mode": act_mode, - "operation_id": operation.get("operation_id", ""), - "counts": plan.get("counts", {}), - "sources": plan.get("sources", []), - "deleted_count": deleted_count, - "deleted_vector_count": int(deleted_vectors or 0), - "deleted_relation_count": len(relation_hashes), - } - if act_mode == "source": - result["requested_source_count"] = len(requested_source_tokens) - result["matched_source_count"] = len(matched_source_tokens) - result["deleted_source_count"] = len(matched_source_tokens) - result["deleted_paragraph_count"] = len(paragraph_hashes) - if not success: - result["error"] = "未命中可删除内容" - return result - except Exception as exc: - conn.rollback() - logger.warning(f"delete_admin execute 失败: {exc}") - return {"success": False, "error": str(exc)} - - async def _restore_delete_action( - self, - *, - mode: str, - selector: Any, - operation_id: str = "", - requested_by: str = "", - reason: str = "", - ) -> Dict[str, Any]: - del requested_by - del reason - assert self.metadata_store - - op_id = str(operation_id or "").strip() - if op_id: - operation = self.metadata_store.get_delete_operation(op_id) - if operation is None: - return {"success": False, "error": "operation 不存在"} - return await self._restore_delete_operation(operation) - - act_mode = str(mode or "").strip().lower() - if act_mode != "relation": - return {"success": False, "error": "paragraph/entity/source 恢复必须提供 operation_id"} - - raw = self._selector_dict(selector) - target = str(raw.get("query", "") or raw.get("target", "") or raw.get("hash", "") or "").strip() - hashes = self._resolve_deleted_relation_hashes(target) - if not hashes: - return {"success": False, "error": "未命中可恢复关系"} - result = await self._restore_relation_hashes(hashes) - return {"success": bool(result.get("restored_count", 0) > 0), **result} - - async def _restore_delete_operation(self, operation: Dict[str, Any]) -> Dict[str, Any]: - assert self.metadata_store - items = operation.get("items") if isinstance(operation.get("items"), list) else [] - entity_payloads: Dict[str, Dict[str, Any]] = {} - paragraph_payloads: Dict[str, Dict[str, Any]] = {} - relation_payloads: Dict[str, Dict[str, Any]] = {} - for item in items: - if not isinstance(item, dict): - continue - item_type = str(item.get("item_type", "") or "").strip() - item_hash = str(item.get("item_hash", "") or "").strip() - payload = item.get("payload") if isinstance(item.get("payload"), dict) else {} - if item_type == "entity" and item_hash: - entity_payloads[item_hash] = payload - elif item_type == "paragraph" and item_hash: - paragraph_payloads[item_hash] = payload - elif item_type == "relation" and item_hash: - relation_payloads[item_hash] = payload - - restored_entities: List[str] = [] - restored_paragraphs: List[str] = [] - for hash_value, payload in entity_payloads.items(): - entity_row = payload.get("entity") if isinstance(payload.get("entity"), dict) else {} - if entity_row: - self.metadata_store.restore_entity_by_hash(hash_value) - await self._ensure_entity_vector(entity_row) - restored_entities.append(hash_value) - for hash_value, 
payload in paragraph_payloads.items(): - paragraph_row = payload.get("paragraph") if isinstance(payload.get("paragraph"), dict) else {} - if paragraph_row: - self.metadata_store.restore_paragraph_by_hash(hash_value) - await self._ensure_paragraph_vector(paragraph_row) - restored_paragraphs.append(hash_value) - - restored_relations = await self._restore_relation_hashes(list(relation_payloads.keys()), payloads=relation_payloads, rebuild_graph=False, persist=False) - - conn = self.metadata_store.get_connection() - cursor = conn.cursor() - for payload in entity_payloads.values(): - for link in payload.get("paragraph_links") or []: - paragraph_hash = str(link.get("paragraph_hash", "") or "").strip() - entity_hash = str(link.get("entity_hash", "") or "").strip() - mention_count = max(1, int(link.get("mention_count", 1) or 1)) - if not paragraph_hash or not entity_hash: - continue - cursor.execute( - """ - INSERT OR IGNORE INTO paragraph_entities (paragraph_hash, entity_hash, mention_count) - VALUES (?, ?, ?) - """, - (paragraph_hash, entity_hash, mention_count), - ) - for payload in paragraph_payloads.values(): - for link in payload.get("entity_links") or []: - paragraph_hash = str(link.get("paragraph_hash", "") or "").strip() - entity_hash = str(link.get("entity_hash", "") or "").strip() - mention_count = max(1, int(link.get("mention_count", 1) or 1)) - if not paragraph_hash or not entity_hash: - continue - cursor.execute( - """ - INSERT OR IGNORE INTO paragraph_entities (paragraph_hash, entity_hash, mention_count) - VALUES (?, ?, ?) - """, - (paragraph_hash, entity_hash, mention_count), - ) - for relation_hash in self._tokens(payload.get("relation_hashes")): - paragraph_hash = str((payload.get("paragraph") or {}).get("hash", "") or "").strip() - if not paragraph_hash or not relation_hash: - continue - cursor.execute( - """ - INSERT OR IGNORE INTO paragraph_relations (paragraph_hash, relation_hash) - VALUES (?, ?) 
- """, - (paragraph_hash, relation_hash), - ) - self.metadata_store.restore_external_memory_refs(list(payload.get("external_refs") or [])) - conn.commit() - - sources = self._tokens( - [ - str(((payload.get("paragraph") or {}).get("source", "") or "")).strip() - for payload in paragraph_payloads.values() - ] - ) - if sources: - self.metadata_store._enqueue_episode_source_rebuilds(sources, reason="delete_admin_restore") - self._rebuild_graph_from_metadata() - self._persist() - summary = { - "restored_entities": restored_entities, - "restored_paragraphs": restored_paragraphs, - "restored_relations": restored_relations.get("restored_hashes", []), - "sources": sources, - } - self.metadata_store.mark_delete_operation_restored(str(operation.get("operation_id", "") or ""), summary=summary) - return { - "success": True, - "operation_id": str(operation.get("operation_id", "") or ""), - **summary, - "restored_relation_count": restored_relations.get("restored_count", 0), - "relation_failures": restored_relations.get("failures", []), - } - - async def _purge_deleted_memory(self, *, grace_hours: Optional[float], limit: int) -> Dict[str, Any]: - assert self.metadata_store - orphan_cfg = self._cfg("memory.orphan", {}) or {} - grace = float(grace_hours) if grace_hours is not None else max( - 1.0, - float(orphan_cfg.get("sweep_grace_hours", 24.0) or 24.0), - ) - cutoff = time.time() - grace * 3600.0 - deleted_relation_hashes = self.metadata_store.purge_deleted_relations(cutoff_time=cutoff, limit=limit) - dead_paragraphs = self.metadata_store.sweep_deleted_items("paragraph", grace * 3600.0) - paragraph_hashes = [str(item[0] or "").strip() for item in dead_paragraphs if str(item[0] or "").strip()] - dead_entities = self.metadata_store.sweep_deleted_items("entity", grace * 3600.0) - entity_hashes = [str(item[0] or "").strip() for item in dead_entities if str(item[0] or "").strip()] - entity_names = [str(item[1] or "").strip() for item in dead_entities if str(item[1] or "").strip()] - - if paragraph_hashes: - self.metadata_store.physically_delete_paragraphs(paragraph_hashes) - if entity_hashes: - self.metadata_store.physically_delete_entities(entity_hashes) - if entity_names: - self.graph_store.delete_nodes(entity_names) - if self.vector_store is not None: - vector_ids = self._merge_tokens(paragraph_hashes, entity_hashes, deleted_relation_hashes) - if vector_ids: - self.vector_store.delete(vector_ids) - self._rebuild_graph_from_metadata() - self._persist() - return { - "success": True, - "grace_hours": grace, - "purged_deleted_relations": deleted_relation_hashes, - "purged_paragraph_hashes": paragraph_hashes, - "purged_entity_hashes": entity_hashes, - "purged_counts": { - "relations": len(deleted_relation_hashes), - "paragraphs": len(paragraph_hashes), - "entities": len(entity_hashes), - }, - } - - @staticmethod - def _optional_float(value: Any) -> Optional[float]: - if value in {None, ""}: - return None - try: - return float(value) - except Exception: - return None - - @staticmethod - def _optional_int(value: Any) -> Optional[int]: - if value in {None, ""}: - return None - try: - return int(value) - except Exception: - return None diff --git a/plugins/A_memorix/core/runtime/search_runtime_initializer.py b/plugins/A_memorix/core/runtime/search_runtime_initializer.py deleted file mode 100644 index c3c7a81f..00000000 --- a/plugins/A_memorix/core/runtime/search_runtime_initializer.py +++ /dev/null @@ -1,240 +0,0 @@ -"""Shared runtime initializer for Action/Tool/Command retrieval components.""" - -from __future__ 
import annotations - -from dataclasses import dataclass -from typing import Any, Dict, Optional - -from src.common.logger import get_logger - -from ..retrieval import ( - DualPathRetriever, - DualPathRetrieverConfig, - DynamicThresholdFilter, - FusionConfig, - GraphRelationRecallConfig, - RelationIntentConfig, - RetrievalStrategy, - SparseBM25Config, - ThresholdConfig, - ThresholdMethod, -) - -_logger = get_logger("A_Memorix.SearchRuntimeInitializer") - -_REQUIRED_COMPONENT_KEYS = ( - "vector_store", - "graph_store", - "metadata_store", - "embedding_manager", -) - - -def _get_config_value(config: Optional[dict], key: str, default: Any = None) -> Any: - if not isinstance(config, dict): - return default - current: Any = config - for part in key.split("."): - if isinstance(current, dict) and part in current: - current = current[part] - else: - return default - return current - - -def _safe_dict(value: Any) -> Dict[str, Any]: - return value if isinstance(value, dict) else {} - - -def _resolve_debug_enabled(plugin_config: Optional[dict]) -> bool: - advanced = _get_config_value(plugin_config, "advanced", {}) - if isinstance(advanced, dict): - return bool(advanced.get("debug", False)) - return bool(_get_config_value(plugin_config, "debug", False)) - - -@dataclass -class SearchRuntimeBundle: - """Resolved runtime components and initialized retriever/filter.""" - - vector_store: Optional[Any] = None - graph_store: Optional[Any] = None - metadata_store: Optional[Any] = None - embedding_manager: Optional[Any] = None - sparse_index: Optional[Any] = None - retriever: Optional[DualPathRetriever] = None - threshold_filter: Optional[DynamicThresholdFilter] = None - error: str = "" - - @property - def ready(self) -> bool: - return ( - self.retriever is not None - and self.vector_store is not None - and self.graph_store is not None - and self.metadata_store is not None - and self.embedding_manager is not None - ) - - -def _resolve_runtime_components(plugin_config: Optional[dict]) -> SearchRuntimeBundle: - bundle = SearchRuntimeBundle( - vector_store=_get_config_value(plugin_config, "vector_store"), - graph_store=_get_config_value(plugin_config, "graph_store"), - metadata_store=_get_config_value(plugin_config, "metadata_store"), - embedding_manager=_get_config_value(plugin_config, "embedding_manager"), - sparse_index=_get_config_value(plugin_config, "sparse_index"), - ) - - missing_required = any( - getattr(bundle, key) is None for key in _REQUIRED_COMPONENT_KEYS - ) - if not missing_required: - return bundle - - try: - from ...plugin import AMemorixPlugin - - instances = AMemorixPlugin.get_storage_instances() - except Exception: - instances = {} - - if not isinstance(instances, dict) or not instances: - return bundle - - if bundle.vector_store is None: - bundle.vector_store = instances.get("vector_store") - if bundle.graph_store is None: - bundle.graph_store = instances.get("graph_store") - if bundle.metadata_store is None: - bundle.metadata_store = instances.get("metadata_store") - if bundle.embedding_manager is None: - bundle.embedding_manager = instances.get("embedding_manager") - if bundle.sparse_index is None: - bundle.sparse_index = instances.get("sparse_index") - return bundle - - -def build_search_runtime( - plugin_config: Optional[dict], - logger_obj: Optional[Any], - owner_tag: str, - *, - log_prefix: str = "", -) -> SearchRuntimeBundle: - """Build retriever + threshold filter with unified fallback/config parsing.""" - - log = logger_obj or _logger - owner = str(owner_tag or 
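# The dotted-key lookup above, restated as a runnable standalone helper:
# walk nested dicts one path segment at a time and return the default on
# any miss.
from typing import Any, Optional

def get_config_value(config: Optional[dict], key: str, default: Any = None) -> Any:
    if not isinstance(config, dict):
        return default
    current: Any = config
    for part in key.split("."):
        if isinstance(current, dict) and part in current:
            current = current[part]
        else:
            return default
    return current

assert get_config_value({"retrieval": {"alpha": 0.5}}, "retrieval.alpha") == 0.5
assert get_config_value({}, "retrieval.alpha", 1.0) == 1.0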
"runtime").strip().lower() or "runtime" - prefix = str(log_prefix or "").strip() - prefix_text = f"{prefix} " if prefix else "" - - runtime = _resolve_runtime_components(plugin_config) - if any(getattr(runtime, key) is None for key in _REQUIRED_COMPONENT_KEYS): - runtime.error = "存储组件未完全初始化" - log.warning(f"{prefix_text}[{owner}] 存储组件未完全初始化,无法使用检索功能") - return runtime - - sparse_cfg_raw = _safe_dict(_get_config_value(plugin_config, "retrieval.sparse", {}) or {}) - fusion_cfg_raw = _safe_dict(_get_config_value(plugin_config, "retrieval.fusion", {}) or {}) - relation_intent_cfg_raw = _safe_dict( - _get_config_value(plugin_config, "retrieval.search.relation_intent", {}) or {} - ) - graph_recall_cfg_raw = _safe_dict( - _get_config_value(plugin_config, "retrieval.search.graph_recall", {}) or {} - ) - - try: - sparse_cfg = SparseBM25Config(**sparse_cfg_raw) - except Exception as e: - log.warning(f"{prefix_text}[{owner}] sparse 配置非法,回退默认: {e}") - sparse_cfg = SparseBM25Config() - - try: - fusion_cfg = FusionConfig(**fusion_cfg_raw) - except Exception as e: - log.warning(f"{prefix_text}[{owner}] fusion 配置非法,回退默认: {e}") - fusion_cfg = FusionConfig() - - try: - relation_intent_cfg = RelationIntentConfig(**relation_intent_cfg_raw) - except Exception as e: - log.warning(f"{prefix_text}[{owner}] relation_intent 配置非法,回退默认: {e}") - relation_intent_cfg = RelationIntentConfig() - - try: - graph_recall_cfg = GraphRelationRecallConfig(**graph_recall_cfg_raw) - except Exception as e: - log.warning(f"{prefix_text}[{owner}] graph_recall 配置非法,回退默认: {e}") - graph_recall_cfg = GraphRelationRecallConfig() - - try: - config = DualPathRetrieverConfig( - top_k_paragraphs=_get_config_value(plugin_config, "retrieval.top_k_paragraphs", 20), - top_k_relations=_get_config_value(plugin_config, "retrieval.top_k_relations", 10), - top_k_final=_get_config_value(plugin_config, "retrieval.top_k_final", 10), - alpha=_get_config_value(plugin_config, "retrieval.alpha", 0.5), - enable_ppr=_get_config_value(plugin_config, "retrieval.enable_ppr", True), - ppr_alpha=_get_config_value(plugin_config, "retrieval.ppr_alpha", 0.85), - ppr_timeout_seconds=_get_config_value( - plugin_config, "retrieval.ppr_timeout_seconds", 1.5 - ), - ppr_concurrency_limit=_get_config_value( - plugin_config, "retrieval.ppr_concurrency_limit", 4 - ), - enable_parallel=_get_config_value(plugin_config, "retrieval.enable_parallel", True), - retrieval_strategy=RetrievalStrategy.DUAL_PATH, - debug=_resolve_debug_enabled(plugin_config), - sparse=sparse_cfg, - fusion=fusion_cfg, - relation_intent=relation_intent_cfg, - graph_recall=graph_recall_cfg, - ) - - runtime.retriever = DualPathRetriever( - vector_store=runtime.vector_store, - graph_store=runtime.graph_store, - metadata_store=runtime.metadata_store, - embedding_manager=runtime.embedding_manager, - sparse_index=runtime.sparse_index, - config=config, - ) - - threshold_config = ThresholdConfig( - method=ThresholdMethod.ADAPTIVE, - min_threshold=_get_config_value(plugin_config, "threshold.min_threshold", 0.3), - max_threshold=_get_config_value(plugin_config, "threshold.max_threshold", 0.95), - percentile=_get_config_value(plugin_config, "threshold.percentile", 75.0), - std_multiplier=_get_config_value(plugin_config, "threshold.std_multiplier", 1.5), - min_results=_get_config_value(plugin_config, "threshold.min_results", 3), - enable_auto_adjust=_get_config_value(plugin_config, "threshold.enable_auto_adjust", True), - ) - runtime.threshold_filter = DynamicThresholdFilter(threshold_config) - runtime.error = "" - 
log.info(f"{prefix_text}[{owner}] 检索运行时初始化完成") - except Exception as e: - runtime.retriever = None - runtime.threshold_filter = None - runtime.error = str(e) - log.error(f"{prefix_text}[{owner}] 检索运行时初始化失败: {e}") - - return runtime - - -class SearchRuntimeInitializer: - """Compatibility wrapper around the function style initializer.""" - - @staticmethod - def build_search_runtime( - plugin_config: Optional[dict], - logger_obj: Optional[Any], - owner_tag: str, - *, - log_prefix: str = "", - ) -> SearchRuntimeBundle: - return build_search_runtime( - plugin_config=plugin_config, - logger_obj=logger_obj, - owner_tag=owner_tag, - log_prefix=log_prefix, - ) diff --git a/plugins/A_memorix/core/storage/__init__.py b/plugins/A_memorix/core/storage/__init__.py deleted file mode 100644 index d878b8e7..00000000 --- a/plugins/A_memorix/core/storage/__init__.py +++ /dev/null @@ -1,53 +0,0 @@ -"""存储层""" - -from .vector_store import VectorStore, QuantizationType -from .graph_store import GraphStore, SparseMatrixFormat -from .metadata_store import MetadataStore -from .knowledge_types import ( - ImportStrategy, - KnowledgeType, - allowed_import_strategy_values, - allowed_knowledge_type_values, - get_knowledge_type_from_string, - get_import_strategy_from_string, - parse_import_strategy, - resolve_stored_knowledge_type, - should_extract_relations, - get_default_chunk_size, - get_type_display_name, - validate_stored_knowledge_type, -) -from .type_detection import ( - detect_knowledge_type, - get_type_from_user_input, - looks_like_factual_text, - looks_like_quote_text, - looks_like_structured_text, - select_import_strategy, -) - -__all__ = [ - "VectorStore", - "GraphStore", - "MetadataStore", - "QuantizationType", - "SparseMatrixFormat", - "ImportStrategy", - "KnowledgeType", - "allowed_import_strategy_values", - "allowed_knowledge_type_values", - "get_knowledge_type_from_string", - "get_import_strategy_from_string", - "parse_import_strategy", - "resolve_stored_knowledge_type", - "should_extract_relations", - "get_default_chunk_size", - "get_type_display_name", - "validate_stored_knowledge_type", - "detect_knowledge_type", - "get_type_from_user_input", - "looks_like_factual_text", - "looks_like_quote_text", - "looks_like_structured_text", - "select_import_strategy", -] diff --git a/plugins/A_memorix/core/storage/graph_store.py b/plugins/A_memorix/core/storage/graph_store.py deleted file mode 100644 index 0a5fd95d..00000000 --- a/plugins/A_memorix/core/storage/graph_store.py +++ /dev/null @@ -1,1448 +0,0 @@ -""" -图存储模块 - -基于SciPy稀疏矩阵的知识图谱存储与计算。 -""" - -import pickle -from enum import Enum -from pathlib import Path -from typing import Optional, Union, Tuple, List, Dict, Set, Any -from collections import defaultdict -import threading -import asyncio - -import numpy as np - -class SparseMatrixFormat(Enum): - """稀疏矩阵格式""" - CSR = "csr" - CSC = "csc" - -try: - from scipy.sparse import csr_matrix, csc_matrix, triu, save_npz, load_npz, bmat, lil_matrix - from scipy.sparse.linalg import norm - HAS_SCIPY = True -except ImportError: - class _SparseMatrixPlaceholder: - pass - - def _scipy_missing(*args, **kwargs): - raise ImportError("SciPy 未安装,请安装: pip install scipy") - - csr_matrix = _SparseMatrixPlaceholder - csc_matrix = _SparseMatrixPlaceholder - lil_matrix = _SparseMatrixPlaceholder - triu = _scipy_missing - save_npz = _scipy_missing - load_npz = _scipy_missing - bmat = _scipy_missing - norm = _scipy_missing - HAS_SCIPY = False - -import contextlib -from src.common.logger import get_logger -from ..utils.hash import 
compute_hash -from ..utils.io import atomic_write - -logger = get_logger("A_Memorix.GraphStore") - - -class GraphModificationMode(Enum): - """图修改模式""" - BATCH = "batch" # 批量模式 (默认, 适合一次性加载) - INCREMENTAL = "incremental" # 增量模式 (适合频繁随机写入, 使用LIL) - READ_ONLY = "read_only" # 只读模式 (适合计算, CSR/CSC) - - -class GraphStore: - """ - 图存储类 - - 功能: - - CSR稀疏矩阵存储图结构 - - 节点和边的CRUD操作 - - Personalized PageRank计算 - - 同义词自动连接 - - 图持久化 - - 参数: - matrix_format: 稀疏矩阵格式(csr/csc) - data_dir: 数据目录 - """ - - def __init__( - self, - matrix_format: str = "csr", - data_dir: Optional[Union[str, Path]] = None, - ): - """ - 初始化图存储 - - Args: - matrix_format: 稀疏矩阵格式(csr/csc) - data_dir: 数据目录 - """ - if not HAS_SCIPY: - raise ImportError("SciPy 未安装,请安装: pip install scipy") - - if isinstance(matrix_format, SparseMatrixFormat): - self.matrix_format = matrix_format.value - else: - self.matrix_format = str(matrix_format).lower() - self.data_dir = Path(data_dir) if data_dir else None - - # 节点管理 - self._nodes: List[str] = [] # 节点列表 - self._node_to_idx: Dict[str, int] = {} # 节点名到索引的映射 - self._node_attrs: Dict[str, Dict[str, Any]] = {} # 节点属性 - - # 边管理(邻接矩阵) - self._adjacency: Optional[Union[csr_matrix, csc_matrix]] = None - - # 统计信息 - self._total_nodes_added = 0 - self._total_edges_added = 0 - self._total_nodes_deleted = 0 - self._total_edges_deleted = 0 - - # 状态管理 - self._modification_mode = GraphModificationMode.BATCH - - # 状态管理 - self._adjacency_T: Optional[Union[csr_matrix, csc_matrix]] = None - self._adjacency_dirty: bool = True - self._saliency_cache: Optional[Dict[str, float]] = None - - # V5: 多关系映射 (src_idx, dst_idx) -> Set[relation_hash] - self._edge_hash_map: Dict[Tuple[int, int], Set[str]] = defaultdict(set) - # V5: 简单的异步锁 (实际上 asyncio 环境下单线程主循环可能不需要,但为了安全保留) - self._lock = asyncio.Lock() - - logger.info(f"GraphStore 初始化: format={matrix_format}") - - def _canonicalize(self, node: str) -> str: - """规范化节点名称 (用于去重和内部索引)""" - if not node: - return "" - return str(node).strip().lower() - - @contextlib.contextmanager - def batch_update(self): - """ - 批量更新上下文管理器 - - 进入时切换到 LIL 格式以优化随机/增量更新 - 退出时恢复到 CSR/CSC 格式以优化存储和计算 - """ - original_mode = self._modification_mode - self._switch_mode(GraphModificationMode.INCREMENTAL) - try: - yield - finally: - self._switch_mode(original_mode) - - def _switch_mode(self, new_mode: GraphModificationMode): - """切换修改模式并转换矩阵格式""" - if new_mode == self._modification_mode: - return - - if self._adjacency is None: - self._modification_mode = new_mode - return - - logger.debug(f"切换图模式: {self._modification_mode.value} -> {new_mode.value}") - - # 转换逻辑 - if new_mode == GraphModificationMode.INCREMENTAL: - # 转换为 LIL 格式 - if not isinstance(self._adjacency, lil_matrix): # 粗略检查是否非 lil - try: - self._adjacency = self._adjacency.tolil() - logger.debug("已转换为 LIL 格式") - except Exception as e: - logger.warning(f"转换为 LIL 失败: {e}") - - elif new_mode in [GraphModificationMode.BATCH, GraphModificationMode.READ_ONLY]: - # 转换回配置的格式 (CSR/CSC) - if self.matrix_format == "csr": - self._adjacency = self._adjacency.tocsr() - elif self.matrix_format == "csc": - self._adjacency = self._adjacency.tocsc() - logger.debug(f"已恢复为 {self.matrix_format.upper()} 格式") - - self._modification_mode = new_mode - - def add_nodes( - self, - nodes: List[str], - attributes: Optional[Dict[str, Dict[str, Any]]] = None, - ) -> int: - """ - 添加节点 - - Args: - nodes: 节点名称列表 - attributes: 节点属性字典 {node_name: {attr: value}} - - Returns: - 成功添加的节点数量 - """ - added = 0 - for node in nodes: - canon = self._canonicalize(node) - if canon in self._node_to_idx: 
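# A minimal sketch of the format-switching context manager above: LIL is
# cheap to mutate, CSR is cheap to read, so bulk edits convert the matrix
# temporarily and convert back on exit. Assumes SciPy is installed.
import contextlib
from scipy.sparse import csr_matrix

class TinyGraph:
    def __init__(self, n: int):
        self.adj = csr_matrix((n, n), dtype="float32")

    @contextlib.contextmanager
    def batch_update(self):
        self.adj = self.adj.tolil()      # fast random writes
        try:
            yield self.adj
        finally:
            self.adj = self.adj.tocsr()  # back to read-optimized form

g = TinyGraph(3)
with g.batch_update() as m:
    m[0, 1] = 1.0
    m[1, 2] = 2.0
assert g.adj.nnz == 2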
- logger.debug(f"节点已存在,跳过: {node}") - continue - - # 添加到节点列表 - idx = len(self._nodes) - self._nodes.append(node) # 存储原始节点名 - self._node_to_idx[canon] = idx # 映射规范化节点名到索引 - self._adjacency_dirty = True - self._saliency_cache = None - - # 添加属性 - if attributes and node in attributes: - self._node_attrs[canon] = attributes[node] - else: - self._node_attrs[canon] = {} - - added += 1 - self._total_nodes_added += 1 - - # 扩展邻接矩阵 - if added > 0: - self._expand_adjacency_matrix(added) - - logger.debug(f"添加 {added} 个节点") - return added - - def add_edges( - self, - edges: List[Tuple[str, str]], - weights: Optional[List[float]] = None, - relation_hashes: Optional[List[str]] = None, # V5: 支持关系哈希映射 (Relation Hash Mapping) - ) -> int: - """ - 添加边 - - Args: - edges: 边列表 [(source, target), ...] - weights: 边权重列表(默认为1.0) - - Returns: - 成功添加的边数量 - """ - if not edges: - return 0 - - # 确保所有节点存在 - nodes_to_add = set() - for src, tgt in edges: - src_canon = self._canonicalize(src) - tgt_canon = self._canonicalize(tgt) - if src_canon not in self._node_to_idx: - nodes_to_add.add(src) - if tgt_canon not in self._node_to_idx: - nodes_to_add.add(tgt) - - if nodes_to_add: - self.add_nodes(list(nodes_to_add)) - - # 处理权重 - if weights is None: - weights = [1.0] * len(edges) - - if len(weights) != len(edges): - raise ValueError(f"边数量与权重数量不匹配: {len(edges)} vs {len(weights)}") - - # 如果仅仅是添加边且处于增量模式 (LIL),直接更新 - if self._modification_mode == GraphModificationMode.INCREMENTAL: - if self._adjacency is None: - # 初始化为空 LIL - n = len(self._nodes) - from scipy.sparse import lil_matrix - self._adjacency = lil_matrix((n, n), dtype=np.float32) - - # 尝试直接使用 LIL 索引更新 - try: - # 批量获取索引 - rows = [self._node_to_idx[self._canonicalize(src)] for src, _ in edges] - cols = [self._node_to_idx[self._canonicalize(tgt)] for _, tgt in edges] - - # 确保矩阵足够大 (如果 add_nodes 没有扩展它) - 通常 add_nodes 会处理 - # 这里直接赋值 - self._adjacency[rows, cols] = weights - - self._total_edges_added += len(edges) - - # V5: Update edge hash map - if relation_hashes: - for (src, tgt), r_hash in zip(edges, relation_hashes): - if r_hash: - s_idx = self._node_to_idx[self._canonicalize(src)] - t_idx = self._node_to_idx[self._canonicalize(tgt)] - self._edge_hash_map[(s_idx, t_idx)].add(r_hash) - - logger.debug(f"增量添加 {len(edges)} 条边 (LIL)") - return len(edges) - except Exception as e: - logger.warning(f"LIL 增量更新失败,回退到通用方法: {e}") - # Fallback to general method below - - # 通用方法 (构建 COO 然后合并) - # 构建边的三元组 - row_indices = [] - col_indices = [] - data_values = [] - - for (src, tgt), weight in zip(edges, weights): - src_idx = self._node_to_idx[self._canonicalize(src)] - tgt_idx = self._node_to_idx[self._canonicalize(tgt)] - - row_indices.append(src_idx) - col_indices.append(tgt_idx) - data_values.append(weight) - - # 创建新的边的矩阵 - n = len(self._nodes) - new_edges = csr_matrix( - (data_values, (row_indices, col_indices)), - shape=(n, n), - ) - - # 合并到邻接矩阵 - if self._adjacency is None: - self._adjacency = new_edges - else: - self._adjacency = self._adjacency + new_edges - - # 转换为指定格式 - if self.matrix_format == "csc" and isinstance(self._adjacency, csr_matrix): - self._adjacency = self._adjacency.tocsc() - elif self.matrix_format == "csr" and isinstance(self._adjacency, csc_matrix): - self._adjacency = self._adjacency.tocsr() - - self._total_edges_added += len(edges) - self._adjacency_dirty = True # 标记脏位 - - # V5: 更新边哈希映射 (Edge Hash Map) - if relation_hashes: - for (src, tgt), r_hash in zip(edges, relation_hashes): - if r_hash: - try: - s_idx = self._node_to_idx[self._canonicalize(src)] - t_idx = 
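# A sketch of the triplet-merge pattern above: build a sparse matrix from
# (row, col, weight) triplets and add it to the existing adjacency. The
# CSR constructor sums duplicate coordinates, which is why repeated edges
# accumulate weight.
import numpy as np
from scipy.sparse import csr_matrix

n = 3
edges = [(0, 1), (1, 2), (0, 1)]   # (source_idx, target_idx)
weights = [1.0, 0.5, 1.0]
rows, cols = zip(*edges)
new_edges = csr_matrix((weights, (rows, cols)), shape=(n, n))
assert float(new_edges[0, 1]) == 2.0          # duplicates summed
adjacency = csr_matrix((n, n), dtype=np.float32)
adjacency = adjacency + new_edges             # merge into existing graph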
self._node_to_idx[self._canonicalize(tgt)] - self._edge_hash_map[(s_idx, t_idx)].add(r_hash) - except KeyError: - pass # 正常情况下节点已在上方添加,此处仅作防错处理 - - logger.debug(f"添加 {len(edges)} 条边") - return len(edges) - - def update_edge_weight( - self, - source: str, - target: str, - delta: float, - min_weight: float = 0.1, - max_weight: float = 10.0, - ) -> float: - """ - 更新边权重 (增量/强化/弱化) - - Args: - source: 源节点 - target: 目标节点 - delta: 权重变化量 (+/-) - min_weight: 最小权重限制 - max_weight: 最大权重限制 - - Returns: - 更新后的权重 - """ - src_canon = self._canonicalize(source) - tgt_canon = self._canonicalize(target) - - if src_canon not in self._node_to_idx or tgt_canon not in self._node_to_idx: - logger.warning(f"节点不存在,无法更新权重: {source} -> {target}") - return 0.0 - - current_weight = self.get_edge_weight(source, target) - if current_weight == 0.0 and delta <= 0: - # 边不存在且试图减少权重,忽略 - return 0.0 - - # 如果边不存在但 delta > 0,相当于添加新边 (默认基础权重0 + delta) - # 但为了逻辑清晰,我们假设只更新存在的边,或者确实添加 - - new_weight = current_weight + delta - new_weight = max(min_weight, min(max_weight, new_weight)) - - # 使用 batch_update 上下文自动处理格式转换 - # 这里我们临时切换到 incremental 模式进行单次更新 - with self.batch_update(): - # add_edges 会覆盖或添加,我们需要覆盖 - self.add_edges([(source, target)], [new_weight]) - - logger.debug(f"更新权重 {source}->{target}: {current_weight:.2f} -> {new_weight:.2f}") - return new_weight - - def delete_nodes(self, nodes: List[str]) -> int: - """ - 删除节点(及相关的边) - - Args: - nodes: 要删除的节点列表 - - Returns: - 成功删除的节点数量 - """ - if not nodes: - return 0 - - # 检查哪些节点存在 - existing_nodes = [node for node in nodes if self._canonicalize(node) in self._node_to_idx] - if not existing_nodes: - logger.warning("所有节点都不存在,无法删除") - return 0 - - # 获取要删除的索引 - indices_to_delete = {self._node_to_idx[self._canonicalize(node)] for node in existing_nodes} - indices_to_keep = [ - i for i in range(len(self._nodes)) - if i not in indices_to_delete - ] - - # 创建索引映射 - old_to_new = {old_idx: new_idx for new_idx, old_idx in enumerate(indices_to_keep)} - - # 重建节点列表 (存储原始节点名) - self._nodes = [self._nodes[i] for i in indices_to_keep] - # 重建规范化节点名到索引的映射 - self._node_to_idx = {self._canonicalize(self._nodes[new_idx]): new_idx for new_idx in range(len(self._nodes))} - - # 删除并重构节点属性 - new_node_attrs = {} - for idx, node_name in enumerate(self._nodes): - canon = self._canonicalize(node_name) - if canon in self._node_attrs: - new_node_attrs[canon] = self._node_attrs[canon] - self._node_attrs = new_node_attrs - - # 重建邻接矩阵 - if self._adjacency is not None: - # 转换为COO格式以进行切片,然后转换回原始格式 - self._adjacency = self._adjacency.tocoo() - mask_rows = np.isin(self._adjacency.row, list(indices_to_keep)) - mask_cols = np.isin(self._adjacency.col, list(indices_to_keep)) - - # 筛选出保留的行和列 - new_rows = self._adjacency.row[mask_rows & mask_cols] - new_cols = self._adjacency.col[mask_rows & mask_cols] - new_data = self._adjacency.data[mask_rows & mask_cols] - - # 更新索引 - new_rows = np.array([old_to_new[r] for r in new_rows]) - new_cols = np.array([old_to_new[c] for c in new_cols]) - - n = len(self._nodes) - if self.matrix_format == "csr": - self._adjacency = csr_matrix((new_data, (new_rows, new_cols)), shape=(n, n)) - else: # csc - self._adjacency = csc_matrix((new_data, (new_rows, new_cols)), shape=(n, n)) - - # 重建关系哈希映射,移除涉及已删除节点的记录并重映射索引。 - if self._edge_hash_map: - new_edge_hash_map: Dict[Tuple[int, int], Set[str]] = defaultdict(set) - for (old_src, old_tgt), hashes in self._edge_hash_map.items(): - if old_src in indices_to_delete or old_tgt in indices_to_delete: - continue - if old_src in old_to_new and old_tgt in 
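delete_nodes rebuilds the matrix rather than mutating it in place: the kept indices are mapped onto a compacted range, the COO triplets are masked, and the surviving coordinates are renumbered. The core mask-and-remap step in isolation (toy matrix):

import numpy as np
from scipy.sparse import csr_matrix

# Toy 4-node graph; drop node 1 and compact the survivors' indices.
adj = csr_matrix(np.array([
    [0, 1, 0, 2],
    [3, 0, 0, 0],
    [0, 0, 0, 4],
    [0, 5, 0, 0],
], dtype=np.float32))

keep = [0, 2, 3]
old_to_new = {old: new for new, old in enumerate(keep)}

coo = adj.tocoo()
mask = np.isin(coo.row, keep) & np.isin(coo.col, keep)
rows = np.array([old_to_new[r] for r in coo.row[mask]])
cols = np.array([old_to_new[c] for c in coo.col[mask]])
adj2 = csr_matrix((coo.data[mask], (rows, cols)), shape=(len(keep), len(keep)))

print(adj2.toarray())  # every edge touching node 1 is gone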
old_to_new and hashes: - new_edge_hash_map[(old_to_new[old_src], old_to_new[old_tgt])] = set(hashes) - self._edge_hash_map = new_edge_hash_map - - deleted_count = len(existing_nodes) - self._total_nodes_deleted += deleted_count - self._adjacency_dirty = True - self._saliency_cache = None - - logger.info(f"删除 {deleted_count} 个节点") - return deleted_count - - def remove_nodes(self, nodes: List[str]) -> int: - """兼容性别名:删除节点""" - return self.delete_nodes(nodes) - - def delete_edges( - self, - edges: List[Tuple[str, str]], - ) -> int: - """ - 删除边 - - Args: - edges: 要删除的边列表 [(source, target), ...] - - Returns: - 成功删除的边数量 - """ - if not edges: - return 0 - - deleted = 0 - # 构建要删除的边的索引集合 - edges_to_delete = set() - for src, tgt in edges: - src_canon = self._canonicalize(src) - tgt_canon = self._canonicalize(tgt) - if src_canon in self._node_to_idx and tgt_canon in self._node_to_idx: - src_idx = self._node_to_idx[src_canon] - tgt_idx = self._node_to_idx[tgt_canon] - edges_to_delete.add((src_idx, tgt_idx)) - - if self._adjacency is not None and edges_to_delete: - # 转换为COO格式便于修改 - adj_coo = self._adjacency.tocoo() - - # 过滤要删除的边 - new_row = [] - new_col = [] - new_data = [] - - for i, j, val in zip(adj_coo.row, adj_coo.col, adj_coo.data): - if (i, j) not in edges_to_delete: - new_row.append(i) - new_col.append(j) - new_data.append(val) - else: - deleted += 1 - - # 重建邻接矩阵 - n = len(self._nodes) - self._adjacency = csr_matrix((new_data, (new_row, new_col)), shape=(n, n)) - - # 转换回指定格式 - if self.matrix_format == "csc": - self._adjacency = self._adjacency.tocsc() - - # delete_edges 是“物理删除”语义,必须同步清理 edge_hash_map。 - if edges_to_delete and self._edge_hash_map: - for key in edges_to_delete: - self._edge_hash_map.pop(key, None) - - self._total_edges_deleted += deleted - self._adjacency_dirty = True - self._saliency_cache = None - logger.info(f"删除 {deleted} 条边") - return deleted - - def remove_edges(self, edges: List[Tuple[str, str]]) -> int: - """兼容性别名:删除边""" - return self.delete_edges(edges) - - def get_nodes(self) -> List[str]: - """ - 获取所有节点 - - Returns: - 节点列表 - """ - return self._nodes.copy() - - def has_node(self, node: str) -> bool: - """ - 检查节点是否存在 - - Args: - node: 节点名称 - """ - return self._canonicalize(node) in self._node_to_idx - - def find_node(self, node: str, ignore_case: bool = False) -> Optional[str]: - """ - 查找节点 (由于底层已统一规范化,ignore_case 始终有效) - - Args: - node: 节点名称 - ignore_case: 是否忽略大小写 (已默认忽略) - - Returns: - 真实节点名称 (如果存在),否则 None - """ - canon = self._canonicalize(node) - if canon in self._node_to_idx: - return self._nodes[self._node_to_idx[canon]] - return None - - def get_node_attributes(self, node: str) -> Optional[Dict[str, Any]]: - """ - 获取节点属性 - - Args: - node: 节点名称 - - Returns: - 节点属性字典,不存在则返回None - """ - canon = self._canonicalize(node) - return self._node_attrs.get(canon) - - def get_neighbors(self, node: str) -> List[str]: - """ - 获取节点的出邻居 - - Args: - node: 节点名称 - - Returns: - 出邻居节点列表 - """ - canon = self._canonicalize(node) - if canon not in self._node_to_idx or self._adjacency is None: - return [] - - idx = self._node_to_idx[canon] - neighbor_indices = self._row_neighbor_indices(self._adjacency, idx) - return [self._nodes[int(i)] for i in neighbor_indices] - - def get_in_neighbors(self, node: str) -> List[str]: - """ - 获取节点的入邻居 - - Args: - node: 节点名称 - - Returns: - 入邻居节点列表 - """ - canon = self._canonicalize(node) - if canon not in self._node_to_idx or self._adjacency is None: - return [] - - self._ensure_adjacency_T() - if self._adjacency_T is None: - return [] - - idx = 
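get_neighbors (and, via the cached transpose, get_in_neighbors) reduces to a single CSR row slice: row i's non-zero column indices live in indices[indptr[i]:indptr[i+1]], which the helper _row_neighbor_indices defined further down reads without materializing the row. A compact illustration:

import numpy as np
from scipy.sparse import csr_matrix

adj = csr_matrix(np.array([
    [0, 1, 1, 0],
    [0, 0, 0, 2],
    [0, 0, 0, 0],
    [3, 0, 0, 0],
], dtype=np.float32))

def out_neighbors(m: csr_matrix, i: int) -> np.ndarray:
    # Non-zero columns of row i, straight from the CSR index arrays.
    return m.indices[m.indptr[i]:m.indptr[i + 1]]

print(out_neighbors(adj, 0))  # [1 2]

# In-neighbors use the same trick on the cached CSR transpose:
adj_T = adj.transpose().tocsr()
print(out_neighbors(adj_T, 3))  # [1] -> node 1 points at node 3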
self._node_to_idx[canon] - neighbor_indices = self._row_neighbor_indices(self._adjacency_T, idx) - return [self._nodes[int(i)] for i in neighbor_indices] - - def get_edge_weight(self, source: str, target: str) -> float: - """ - 获取边的权重 - """ - src_canon = self._canonicalize(source) - tgt_canon = self._canonicalize(target) - - if src_canon not in self._node_to_idx or tgt_canon not in self._node_to_idx: - return 0.0 - - if self._adjacency is None: - return 0.0 - - src_idx = self._node_to_idx[src_canon] - tgt_idx = self._node_to_idx[tgt_canon] - - return float(self._adjacency[src_idx, tgt_idx]) - - def canonicalize_node(self, node: str) -> str: - """公开节点规范化接口,避免外部访问私有方法。""" - return self._canonicalize(node) - - def has_edge_hash_map(self) -> bool: - """是否存在 relation-hash 映射。""" - return bool(self._edge_hash_map) - - def get_relation_hashes_for_edge(self, source: str, target: str) -> Set[str]: - """获取边 (source -> target) 关联的关系哈希集合。""" - src_canon = self._canonicalize(source) - tgt_canon = self._canonicalize(target) - if src_canon not in self._node_to_idx or tgt_canon not in self._node_to_idx: - return set() - src_idx = self._node_to_idx[src_canon] - tgt_idx = self._node_to_idx[tgt_canon] - return set(self._edge_hash_map.get((src_idx, tgt_idx), set())) - - def get_incident_relation_hashes(self, node: str, limit: Optional[int] = None) -> List[str]: - """获取与指定节点关联的关系哈希列表(入边 + 出边)。""" - canon = self._canonicalize(node) - if canon not in self._node_to_idx or not self._edge_hash_map: - return [] - - idx = self._node_to_idx[canon] - limit_val = max(1, int(limit)) if limit is not None else None - collected: List[Tuple[str, str, str]] = [] - idx_to_node = self._nodes - - for (src_idx, tgt_idx), hashes in self._edge_hash_map.items(): - if idx not in {src_idx, tgt_idx}: - continue - src_name = idx_to_node[src_idx] if 0 <= src_idx < len(idx_to_node) else "" - tgt_name = idx_to_node[tgt_idx] if 0 <= tgt_idx < len(idx_to_node) else "" - for hash_value in hashes: - hash_text = str(hash_value).strip() - if not hash_text: - continue - collected.append((src_name, tgt_name, hash_text)) - - collected.sort(key=lambda x: (x[0].lower(), x[1].lower(), x[2])) - out: List[str] = [] - seen = set() - for _, _, hash_value in collected: - if hash_value in seen: - continue - seen.add(hash_value) - out.append(hash_value) - if limit_val is not None and len(out) >= limit_val: - break - return out - - def edge_contains_relation_hash(self, source: str, target: str, hash_value: str) -> bool: - """判断边是否包含指定关系哈希。""" - if not hash_value: - return False - return str(hash_value) in self.get_relation_hashes_for_edge(source, target) - - def iter_edge_hash_entries(self) -> List[Tuple[str, str, Set[str]]]: - """以节点名形式遍历 edge-hash-map。""" - out: List[Tuple[str, str, Set[str]]] = [] - if not self._edge_hash_map: - return out - idx_to_node = self._nodes - for (s_idx, t_idx), hashes in self._edge_hash_map.items(): - if not hashes: - continue - if s_idx < 0 or t_idx < 0: - continue - if s_idx >= len(idx_to_node) or t_idx >= len(idx_to_node): - continue - out.append((idx_to_node[s_idx], idx_to_node[t_idx], set(hashes))) - return out - - def deactivate_edges(self, edges: List[Tuple[str, str]]) -> int: - """ - 冻结边 (将权重设为0.0,使其在计算意义上消失,但保留在Map中) - - Args: - edges: [(s1, t1), (s2, t2)...] - """ - if not edges or self._adjacency is None: - return 0 - - deactivated_count = 0 - with self.batch_update(): - # 我们需要 explicit set to 0. 
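One subtlety behind this "freeze" operation: assigning 0.0 into a CSR/CSC matrix keeps the slot in the sparsity structure as an explicit zero, so nnz still counts it while nonzero() and arithmetic ignore it. That is what lets a frozen edge vanish from computation yet stay addressable in the hash map; physically dropping the slots later takes eliminate_zeros(). A quick check (assuming CSR here; inside batch_update the matrix may temporarily be LIL, which discards zeroed entries instead):

import numpy as np
from scipy.sparse import csr_matrix

adj = csr_matrix(np.array([[0, 1], [2, 0]], dtype=np.float32))
print(adj.nnz)        # 2

adj[0, 1] = 0.0       # freeze the edge: weight gone, slot kept
print(adj.nnz)        # still 2 (explicit zero)
print(adj.nonzero())  # only (1, 0) remains visible

adj.eliminate_zeros() # physically drop stored zeros
print(adj.nnz)        # 1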
- # 使用增量更新模式覆盖 - for s, t in edges: - s_canon = self._canonicalize(s) - t_canon = self._canonicalize(t) - if s_canon in self._node_to_idx and t_canon in self._node_to_idx: - idx_s = self._node_to_idx[s_canon] - idx_t = self._node_to_idx[t_canon] - self._adjacency[idx_s, idx_t] = 0.0 - deactivated_count += 1 - - self._adjacency_dirty = True - return deactivated_count - - def _ensure_adjacency_T(self): - """确保转置邻接矩阵是最新的""" - if self._adjacency is None: - self._adjacency_T = None - return - - if self._adjacency_dirty or self._adjacency_T is None: - # 只有在确实需要时才计算转置 - # find_paths 以“按行读取邻居”为主,因此统一缓存为 CSR,避免 - # CSR->CSC 转置后按行切片读出错误的索引视图。 - self._adjacency_T = self._adjacency.transpose().tocsr() - - self._adjacency_dirty = False - # logger.debug("重建转置邻接矩阵缓存") - - @staticmethod - def _row_neighbor_indices( - matrix: Optional[Union[csr_matrix, csc_matrix]], - row_idx: int, - ) -> np.ndarray: - """返回指定行的非零列索引。""" - if matrix is None: - return np.asarray([], dtype=np.int32) - - if isinstance(matrix, csr_matrix): - return matrix.indices[matrix.indptr[row_idx]:matrix.indptr[row_idx + 1]] - - row = matrix[row_idx, :] - _, indices = row.nonzero() - return np.asarray(indices, dtype=np.int32) - - def find_paths( - self, - start_node: str, - end_node: str, - max_depth: int = 3, - max_paths: int = 5, - max_expansions: int = 20000 - ) -> List[List[str]]: - """ - 查找两个节点之间的路径 (BFS) - 支持有向和无向 (视作双向) 探索 - - Args: - start_node: 起始节点 - end_node: 目标节点 - max_depth: 最大深度 - max_paths: 最大路径数 (找到这么多就停止) - max_expansions: 最大扩展次数 (防止爆炸) - - Returns: - 路径列表 [[n1, n2, n3], ...] - """ - start_canon = self._canonicalize(start_node) - end_canon = self._canonicalize(end_node) - - if start_canon not in self._node_to_idx or end_canon not in self._node_to_idx: - return [] - - if self._adjacency is None: - return [] - - # 确保转置矩阵可用 (用于查找入边) - self._ensure_adjacency_T() - - start_idx = self._node_to_idx[start_canon] - end_idx = self._node_to_idx[end_canon] - - # 队列: (current_idx, path_indices) - queue = [(start_idx, [start_idx])] - found_paths = [] - expansions = 0 - - unique_paths = set() - - while queue: - curr, path = queue.pop(0) - - if len(path) > max_depth + 1: - continue - - if curr == end_idx: - # 找到路径 - # 转换回节点名 - path_names = [self._nodes[i] for i in path] - path_tuple = tuple(path_names) - if path_tuple not in unique_paths: - found_paths.append(path_names) - unique_paths.add(path_tuple) - - if len(found_paths) >= max_paths: - break - continue - - if expansions >= max_expansions: - break - - expansions += 1 - - # 获取邻居 (出边 + 入边) - out_indices = self._row_neighbor_indices(self._adjacency, curr) - - # 2. 
入边 (使用转置矩阵) - if self._adjacency_T is not None: - in_indices = self._row_neighbor_indices(self._adjacency_T, curr) - neighbors = np.concatenate((out_indices, in_indices)) - else: - neighbors = out_indices - - # 去重并过滤已在路径中的节点 (防止环) - # 注意: 这里简单去重,可能包含重复的邻居(如果既是出又是入) - seen_in_path = set(path) - queued_neighbors = set() - - for neighbor_idx in neighbors: - neighbor = int(neighbor_idx) - if neighbor not in seen_in_path and neighbor not in queued_neighbors: - # 只有未访问过的才加入 - queue.append((neighbor, path + [neighbor])) - queued_neighbors.add(neighbor) - - return found_paths - - def compute_pagerank( - self, - personalization: Optional[Dict[str, float]] = None, - alpha: float = 0.85, - max_iter: int = 100, - tol: float = 1e-6, - ) -> Dict[str, float]: - """ - 计算Personalized PageRank - - Args: - personalization: 个性化向量 {node: weight},默认为均匀分布 - alpha: 阻尼系数(0-1之间) - max_iter: 最大迭代次数 - tol: 收敛阈值 - - Returns: - 节点PageRank值字典 {node: score} - """ - if self._adjacency is None or len(self._nodes) == 0: - logger.warning("图为空,无法计算PageRank") - return {} - - n = len(self._nodes) - - # 构建列归一化的转移矩阵 - adj = self._adjacency.astype(np.float32) - - # 计算出度 - out_degrees = np.array(adj.sum(axis=1)).flatten() - - # 处理悬挂节点(出度为0) - dangling = (out_degrees == 0) - out_degrees_inv = np.zeros_like(out_degrees) - out_degrees_inv[~dangling] = 1.0 / out_degrees[~dangling] - - # 归一化 (使用稀疏对角阵避免内存溢出) - from scipy.sparse import diags - D_inv = diags(out_degrees_inv) - M = adj.T @ D_inv # 转移矩阵 - - # 初始化个性化向量 - if personalization is None: - # 均匀分布 - p = np.ones(n) / n - else: - # 使用指定的个性化向量 - p = np.zeros(n) - total_weight = sum(personalization.values()) - for node, weight in personalization.items(): - if node in self._node_to_idx: - idx = self._node_to_idx[node] - p[idx] = weight / total_weight - - # 确保和为1 - if p.sum() == 0: - p = np.ones(n) / n - else: - p = p / p.sum() - - # 幂迭代法 - p_orig = p.copy() - for i in range(max_iter): - # p_new = alpha * M * p + (1-alpha) * personalization - p_new = alpha * (M @ p) + (1 - alpha) * p_orig - - # 处理因为悬挂节点导致的概率流失 - current_sum = p_new.sum() - if current_sum < 1.0: - p_new += (1.0 - current_sum) * p_orig - - # 检查收敛 - diff = np.linalg.norm(p_new - p, 1) - if diff < tol: - logger.debug(f"PageRank在 {i+1} 次迭代后收敛") - p = p_new - break - p = p_new - else: - logger.warning(f"PageRank未在 {max_iter} 次迭代内收敛") - - # 转换为真实节点名称字典 - return {self._nodes[idx]: float(val) for idx, val in enumerate(p)} - - def get_saliency_scores(self) -> Dict[str, float]: - """ - 获取节点显著性得分 (带有缓存机制) - """ - if self._saliency_cache is not None and not self._adjacency_dirty: - return self._saliency_cache - - logger.debug("正在计算节点显著性得分 (PageRank)...") - scores = self.compute_pagerank() - self._saliency_cache = scores - # 注意:这里我们不把 _adjacency_dirty 设为 False,因为其它逻辑(如_adjacency_T)也依赖它 - return scores - - def connect_synonyms( - self, - similarity_matrix: np.ndarray, - node_list: List[str], - threshold: float = 0.85, - ) -> int: - """ - 连接相似节点(同义词) - - Args: - similarity_matrix: 相似度矩阵 (N x N) - node_list: 对应的节点列表(长度为N) - threshold: 相似度阈值 - - Returns: - 添加的边数量 - """ - if len(node_list) != similarity_matrix.shape[0]: - raise ValueError( - f"节点列表长度与相似度矩阵维度不匹配: " - f"{len(node_list)} vs {similarity_matrix.shape[0]}" - ) - - # 找到相似的节点对(上三角,排除对角线) - similar_pairs = np.argwhere( - (triu(similarity_matrix, k=1) >= threshold) & - (triu(similarity_matrix, k=1) < 1.0) # 排除完全相同的 - ) - - # 添加边 - edges = [] - for i, j in similar_pairs: - if i < len(node_list) and j < len(node_list): - src = node_list[i] - tgt = node_list[j] - # 使用相似度作为权重 - weight = 
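compute_pagerank above is a standard power iteration on the column-normalized transition matrix, with the probability mass lost through dangling nodes folded back into the personalization vector each step. A condensed, self-contained version of the same loop (toy graph, uniform personalization; not the plugin's code):

import numpy as np
from scipy.sparse import csr_matrix, diags

adj = csr_matrix(np.array([
    [0, 1, 1],
    [0, 0, 1],
    [0, 0, 0],   # node 2 is dangling (no out-edges)
], dtype=np.float32))

n = adj.shape[0]
out = np.asarray(adj.sum(axis=1)).ravel()
inv = np.divide(1.0, out, out=np.zeros_like(out), where=out > 0)
M = adj.T @ diags(inv)            # column-stochastic except dangling columns

p0 = np.full(n, 1.0 / n)          # uniform personalization
p = p0.copy()
alpha = 0.85
for _ in range(100):
    p_new = alpha * (M @ p) + (1 - alpha) * p0
    p_new += (1.0 - p_new.sum()) * p0   # re-inject mass lost to dangling nodes
    if np.abs(p_new - p).sum() < 1e-6:  # L1 convergence, as above
        p = p_new
        break
    p = p_new

print(p)  # node 2, the common sink, ends up with the highest rank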
float(similarity_matrix[i, j]) - edges.append((src, tgt, weight)) - - if edges: - edge_pairs = [(src, tgt) for src, tgt, _ in edges] - weights = [w for _, _, w in edges] - count = self.add_edges(edge_pairs, weights) - logger.info(f"连接 {count} 对相似节点(阈值={threshold})") - return count - return 0 - - - # ========================================================================= - # V5 Memory System Methods (Graph Level) - # ========================================================================= - - def decay(self, factor: float, min_active_weight: float = 0.0) -> None: - """ - 全图衰减 (Atomic Decay) - - Args: - factor: 衰减因子 (0.0 < factor < 1.0) - min_active_weight: 最小活跃权重 (低于此值可能被视为无效,但在物理修剪前仍保留) - """ - if self._adjacency is None or factor >= 1.0 or factor <= 0.0: - return - - logger.debug(f"正在执行全图衰减,因子: {factor}") - - # 直接矩阵乘法,SciPy CSR/CSC 非常高效 - self._adjacency *= factor - - # 如果需要处理极小值 (可选,防止下溢,但通常浮点数足够小) - # if min_active_weight > 0: - # ... (复杂操作,暂不需要,由 prune 逻辑处理) - - self._adjacency_dirty = True - - def prune_relation_hashes(self, operations: List[Tuple[str, str, str]]) -> None: - """ - 修剪特定关系哈希 (从 _edge_hash_map 移除; 如果边变空则从矩阵移除) - - Args: - operations: List[(src, tgt, relation_hash)] - """ - if not operations: - return - - edges_to_check_removal = set() - - # 1. 更新映射 (Update Map) - for src, tgt, h in operations: - src_canon = self._canonicalize(src) - tgt_canon = self._canonicalize(tgt) - if src_canon in self._node_to_idx and tgt_canon in self._node_to_idx: - s_idx = self._node_to_idx[src_canon] - t_idx = self._node_to_idx[tgt_canon] - - key = (s_idx, t_idx) - if key in self._edge_hash_map: - if h in self._edge_hash_map[key]: - self._edge_hash_map[key].remove(h) - - if not self._edge_hash_map[key]: - del self._edge_hash_map[key] - edges_to_check_removal.add((src, tgt)) - - # 2. 
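decay above is a single scalar multiply; on a CSR/CSC matrix that only touches the stored data array, so a full-graph decay is O(nnz). Scaling .data directly is the equivalent, allocation-free variant:

import numpy as np
from scipy.sparse import csr_matrix

adj = csr_matrix(np.array([[0, 2.0], [4.0, 0]], dtype=np.float32))

factor = 0.9
adj *= factor            # what decay() does; only stored values are scaled
# equivalently, with no intermediate matrix: adj.data *= factor

print(adj.toarray())     # [[0, 1.8], [3.6, 0]]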
从矩阵中移除空边 (Remove Empty Edges from Matrix) - if edges_to_check_removal: - self.deactivate_edges(list(edges_to_check_removal)) - self._total_edges_deleted += len(edges_to_check_removal) - - def get_low_weight_edges(self, threshold: float) -> List[Tuple[str, str]]: - """ - 获取低于阈值的边 (candidates for pruning/freezing) - - Args: - threshold: 权重阈值 - - Returns: - List[(src, tgt)]: 边列表 - """ - if self._adjacency is None: - return [] - - # 获取所有非零元素 - rows, cols = self._adjacency.nonzero() - data = self._adjacency.data - - low_weight_indices = np.where(data < threshold)[0] - - results = [] - for idx in low_weight_indices: - r = rows[idx] - c = cols[idx] - src = self._nodes[r] - tgt = self._nodes[c] - results.append((src, tgt)) - - return results - - def get_isolated_nodes(self, include_inactive: bool = True) -> List[str]: - """ - 获取孤儿节点 (Active Degree = 0) - - Args: - include_inactive: 是否包含参与了inactive边(冻结边)的节点。 - 如果 True (默认推荐): 排除掉虽然active degree=0但存在于_edge_hash_map(冻结边)中的节点。 - 如果 False: 只要在 active matrix 里度为0就返回 (可能会误删冻结节点)。 - - Returns: - 孤儿节点名称列表 - """ - if self._adjacency is None: - # 如果全空,则所有节点都是孤儿 - return self._nodes.copy() - - n = len(self._nodes) - - # 计算 Active Degree (In + Out) - # 用 sum(axis) 会得到 dense matrix/array - active_adj = self._adjacency - out_degrees = np.array(active_adj.sum(axis=1)).flatten() - in_degrees = np.array(active_adj.sum(axis=0)).flatten() - - # 处理悬挂节点 (dangling node check not really needed here, just sum) - total_degrees = out_degrees + in_degrees - - # 找到 active degree = 0 的索引 - isolated_indices = np.where(total_degrees == 0)[0] - - if len(isolated_indices) == 0: - return [] - - isolated_nodes_set = {self._nodes[i] for i in isolated_indices} - - # 如果需要排除 Inactive 参与者 - if include_inactive and self._edge_hash_map: - # 收集所有在冻结边中的 unique 节点索引 - frozen_participant_indices = set() - for (u_idx, v_idx), hashes in self._edge_hash_map.items(): - if hashes: # 只要有 hash 存在 (哪怕 inactive) - frozen_participant_indices.add(u_idx) - frozen_participant_indices.add(v_idx) - - # 过滤 - final_isolated = [] - for idx in isolated_indices: - if idx not in frozen_participant_indices: - final_isolated.append(self._nodes[idx]) - return final_isolated - - else: - return list(isolated_nodes_set) - - def clear(self) -> None: - """清空所有数据""" - self._nodes.clear() - self._node_to_idx.clear() - self._node_attrs.clear() - self._adjacency = None - self._edge_hash_map.clear() - self._adjacency_T = None - self._adjacency_dirty = True - self._total_nodes_added = 0 - self._total_edges_added = 0 - self._total_nodes_deleted = 0 - self._total_edges_deleted = 0 - logger.info("图存储已清空") - - def save(self, data_dir: Optional[Union[str, Path]] = None) -> None: - """ - 保存到磁盘 - - Args: - data_dir: 数据目录(默认使用初始化时的目录) - """ - if data_dir is None: - data_dir = self.data_dir - - if data_dir is None: - raise ValueError("未指定数据目录") - - data_dir = Path(data_dir) - data_dir.mkdir(parents=True, exist_ok=True) - - # 保存邻接矩阵 - if self._adjacency is not None: - matrix_path = data_dir / "graph_adjacency.npz" - with atomic_write(matrix_path, "wb") as f: - save_npz(f, self._adjacency) - logger.debug(f"保存邻接矩阵: {matrix_path}") - - # 保存元数据 - metadata = { - "nodes": self._nodes, - "node_to_idx": self._node_to_idx, - "node_attrs": self._node_attrs, - "matrix_format": self.matrix_format, - "total_nodes_added": self._total_nodes_added, - "total_edges_added": self._total_edges_added, - "total_nodes_deleted": self._total_nodes_deleted, - "total_edges_deleted": self._total_edges_deleted, - "edge_hash_map": dict(self._edge_hash_map), # 持久化 V5 
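get_isolated_nodes derives "active degree" from weight sums per row (out) plus per column (in). Because these are weight sums, a node whose remaining edges were all frozen to explicit zeros also reads as isolated, which is exactly why the edge-hash-map check above can rescue frozen participants from the list. The degree computation alone:

import numpy as np
from scipy.sparse import csr_matrix

adj = csr_matrix(np.array([
    [0, 1, 0],
    [0, 0, 0],
    [0, 0, 0],   # node 2 touches no active edge at all
], dtype=np.float32))

out_deg = np.asarray(adj.sum(axis=1)).ravel()
in_deg = np.asarray(adj.sum(axis=0)).ravel()
isolated = np.where(out_deg + in_deg == 0)[0]
print(isolated)  # [2]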
映射 (将 defaultdict 转换为普通 dict) - } - - metadata_path = data_dir / "graph_metadata.pkl" - with atomic_write(metadata_path, "wb") as f: - pickle.dump(metadata, f) - logger.debug(f"保存元数据: {metadata_path}") - - logger.info(f"图存储已保存到: {data_dir}") - - def load(self, data_dir: Optional[Union[str, Path]] = None) -> None: - """ - 从磁盘加载 - - Args: - data_dir: 数据目录(默认使用初始化时的目录) - """ - if data_dir is None: - data_dir = self.data_dir - - if data_dir is None: - raise ValueError("未指定数据目录") - - data_dir = Path(data_dir) - if not data_dir.exists(): - raise FileNotFoundError(f"数据目录不存在: {data_dir}") - - # 加载元数据 - metadata_path = data_dir / "graph_metadata.pkl" - if not metadata_path.exists(): - raise FileNotFoundError(f"元数据文件不存在: {metadata_path}") - - with open(metadata_path, "rb") as f: - metadata = pickle.load(f) - - # 恢复状态,并通过规范化处理旧数据中的重复项 - self._nodes = metadata["nodes"] - self._node_attrs = {} # 重新构建以确保键名 (Key) 规范化 - self._node_to_idx = {} # 重新构建以确保键名 (Key) 规范化 - - # 重新构建映射,处理旧数据中的碰撞 - for idx, node_name in enumerate(self._nodes): - canon = self._canonicalize(node_name) - if canon not in self._node_to_idx: - self._node_to_idx[canon] = idx - - # 处理属性 (优先保留已有的) - orig_attrs = metadata.get("node_attrs", {}) - if node_name in orig_attrs and canon not in self._node_attrs: - self._node_attrs[canon] = orig_attrs[node_name] - - self.matrix_format = metadata["matrix_format"] - self._total_nodes_added = metadata["total_nodes_added"] - self._total_edges_added = metadata["total_edges_added"] - self._total_nodes_deleted = metadata["total_nodes_deleted"] - self._total_edges_deleted = metadata["total_edges_deleted"] - - # 恢复 V5 边哈希映射 (Restore V5 edge hash map) - edge_map_data = metadata.get("edge_hash_map", {}) - # 重新初始化为 defaultdict(set) - self._edge_hash_map = defaultdict(set) - if edge_map_data: - for k, v in edge_map_data.items(): - self._edge_hash_map[k] = set(v) # 确保类型为 set - - # 加载邻接矩阵 - matrix_path = data_dir / "graph_adjacency.npz" - if matrix_path.exists(): - self._adjacency = load_npz(str(matrix_path)) - - # 确保格式正确 - if self.matrix_format == "csc" and isinstance(self._adjacency, csr_matrix): - self._adjacency = self._adjacency.tocsc() - elif self.matrix_format == "csr" and isinstance(self._adjacency, csc_matrix): - self._adjacency = self._adjacency.tocsr() - - logger.debug(f"加载邻接矩阵: {matrix_path}, shape={self._adjacency.shape}") - - # 检查维度不匹配并修复 - if self._adjacency is not None: - adj_n = self._adjacency.shape[0] - current_n = len(self._nodes) - if current_n > adj_n: - logger.warning(f"检测到图存储维度不匹配: 节点数={current_n}, 矩阵大小={adj_n}. 
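save() funnels both the .npz matrix and the pickled metadata through an atomic_write helper, so a crash mid-write cannot leave a truncated file on disk. The real helper lives in the plugin's utils and is not shown in this hunk; the usual implementation of the pattern writes to a sibling temp file and then swaps it over the target with os.replace. A minimal sketch of that pattern (hypothetical helper, same idea):

import os
import tempfile
from contextlib import contextmanager
from pathlib import Path

@contextmanager
def atomic_write(path, mode="wb"):
    """Sketch of an atomic-write helper (not the plugin's implementation)."""
    path = Path(path)
    fd, tmp = tempfile.mkstemp(dir=path.parent, suffix=".tmp")
    try:
        with os.fdopen(fd, mode) as f:
            yield f
            f.flush()
            os.fsync(f.fileno())     # data durable before the rename
        os.replace(tmp, path)        # atomic swap into place
    except BaseException:
        os.unlink(tmp)               # never leave the temp file behind
        raise

with atomic_write("graph_metadata.pkl") as f:
    f.write(b"demo")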
正在自动修复...") - self._expand_adjacency_matrix(current_n - adj_n) - - self._adjacency_dirty = True - logger.info( - f"图存储已加载: {len(self._nodes)} 个节点, " - f"{self._adjacency.nnz if self._adjacency is not None else 0} 条边" - ) - - def _expand_adjacency_matrix(self, added_nodes: int) -> None: - """ - 扩展邻接矩阵以容纳新节点 - - Args: - added_nodes: 新增节点数量 - """ - if self._adjacency is None: - n = len(self._nodes) - # 根据模式初始化 - - if self._modification_mode == GraphModificationMode.INCREMENTAL: - self._adjacency = lil_matrix((n, n), dtype=np.float32) - else: - self._adjacency = csr_matrix((n, n), dtype=np.float32) - return - - old_n = self._adjacency.shape[0] - new_n = old_n + added_nodes - - # 优化:根据模式选择不同的扩容策略 - if self._modification_mode == GraphModificationMode.INCREMENTAL: - # LIL 格式可以直接 resize,非常高效 - try: - if not isinstance(self._adjacency, lil_matrix): - self._adjacency = self._adjacency.tolil() - - self._adjacency.resize((new_n, new_n)) - # logger.debug(f"扩展 LIL 矩阵: {old_n} -> {new_n}") - except Exception as e: - logger.warning(f"LIL resize 失败,回退到通用方法: {e}") - self._expand_generic(new_n, old_n) - - else: - # CSR/CSC 格式使用 bmat 避免结构破坏警告 - try: - # bmat 需要明确的形状,不能全部依赖 None - added = new_n - old_n - # 创建零矩阵块 - # 注意: 这里统一创建 CSR 零矩阵,bmat 会处理合并 - z_tr = csr_matrix((old_n, added), dtype=np.float32) - z_bl = csr_matrix((added, old_n), dtype=np.float32) - z_br = csr_matrix((added, added), dtype=np.float32) - - self._adjacency = bmat( - [[self._adjacency, z_tr], [z_bl, z_br]], - format=self.matrix_format, - dtype=np.float32 - ) - # logger.debug(f"扩展矩阵 (bmat): {old_n} -> {new_n}") - except Exception as e: - logger.warning(f"bmat 扩展失败: {e}") - self._expand_generic(new_n, old_n) - - def _expand_generic(self, new_n: int, old_n: int): - """通用扩展方法(回退方案)""" - if self.matrix_format == "csr": - new_adjacency = csr_matrix((new_n, new_n), dtype=np.float32) - new_adjacency[:old_n, :old_n] = self._adjacency - else: - new_adjacency = csc_matrix((new_n, new_n), dtype=np.float32) - new_adjacency[:old_n, :old_n] = self._adjacency - self._adjacency = new_adjacency - self._adjacency_dirty = True - - # 如果都在增量模式,确保是LIL - if self._modification_mode == GraphModificationMode.INCREMENTAL: - try: - self._adjacency = self._adjacency.tolil() - except: - pass - - @property - def num_nodes(self) -> int: - """节点数量""" - return len(self._nodes) - - @property - def num_edges(self) -> int: - """边数量""" - if self._adjacency is None: - return 0 - return int(self._adjacency.nnz) - - @property - def density(self) -> float: - """ - 图密度(实际边数 / 可能的最大边数) - - 有向图: E / (V * (V - 1)) - 无向图: 2E / (V * (V - 1)) - - 这里按有向图计算 - """ - if self.num_nodes < 2: - return 0.0 - - max_edges = self.num_nodes * (self.num_nodes - 1) - return self.num_edges / max_edges if max_edges > 0 else 0.0 - - def __len__(self) -> int: - """节点数量""" - return self.num_nodes - - def has_data(self) -> bool: - """检查磁盘上是否存在现有数据""" - if self.data_dir is None: - return False - return (self.data_dir / "graph_metadata.pkl").exists() - - def __repr__(self) -> str: - return ( - f"GraphStore(nodes={self.num_nodes}, edges={self.num_edges}, " - f"density={self.density:.4f}, format={self.matrix_format})" - ) - - def rebuild_edge_hash_map(self, triples: List[Tuple[str, str, str, str]]) -> int: - """ - 从元数据重建 V5 边哈希映射 (Migration Tool) - - Args: - triples: List of (s, p, o, hash) - - Returns: - count of mapped hashes - """ - count = 0 - self._edge_hash_map = defaultdict(set) - - for s, p, o, h in triples: - if not h: continue - - s_canon = self._canonicalize(s) - o_canon = self._canonicalize(o) - - if 
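_expand_adjacency_matrix grows a CSR/CSC matrix by stitching the old one together with three zero blocks via scipy.sparse.bmat, which avoids the slow sliced-assignment fallback. The blocks are passed explicitly because bmat cannot infer the shape of a row or column made entirely of None, as the inline comment notes. The core of that expansion:

import numpy as np
from scipy.sparse import csr_matrix, bmat

old = csr_matrix(np.array([[0, 1], [2, 0]], dtype=np.float32))
old_n, added = old.shape[0], 2

grown = bmat(
    [[old,                        csr_matrix((old_n, added))],
     [csr_matrix((added, old_n)), csr_matrix((added, added))]],
    format="csr", dtype=np.float32,
)
print(grown.shape)           # (4, 4)
print(grown.nnz == old.nnz)  # True: old edges preserved, zero blocks cost nothing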
s_canon in self._node_to_idx and o_canon in self._node_to_idx: - u = self._node_to_idx[s_canon] - v = self._node_to_idx[o_canon] - - # 如果是双向的,通常在元数据中存储为有向,而 GraphStore 也通常是有向的。 - # 映射键对应特定的边方向。 - self._edge_hash_map[(u, v)].add(h) - count += 1 - - self._adjacency_dirty = True - logger.info(f"已从 {count} 条哈希重建边哈希映射,覆盖 {len(self._edge_hash_map)} 条边") - return count - diff --git a/plugins/A_memorix/core/storage/knowledge_types.py b/plugins/A_memorix/core/storage/knowledge_types.py deleted file mode 100644 index 4ab91218..00000000 --- a/plugins/A_memorix/core/storage/knowledge_types.py +++ /dev/null @@ -1,183 +0,0 @@ -"""Knowledge type and import strategy helpers.""" - -from __future__ import annotations - -from enum import Enum -from typing import Any, Optional - - -class KnowledgeType(str, Enum): - """持久化到 paragraphs.knowledge_type 的合法类型。""" - - STRUCTURED = "structured" - NARRATIVE = "narrative" - FACTUAL = "factual" - QUOTE = "quote" - MIXED = "mixed" - - -class ImportStrategy(str, Enum): - """文本导入阶段的策略选择。""" - - AUTO = "auto" - NARRATIVE = "narrative" - FACTUAL = "factual" - QUOTE = "quote" - - -def allowed_knowledge_type_values() -> tuple[str, ...]: - return tuple(item.value for item in KnowledgeType) - - -def allowed_import_strategy_values() -> tuple[str, ...]: - return tuple(item.value for item in ImportStrategy) - - -def get_knowledge_type_from_string(type_str: Any) -> Optional[KnowledgeType]: - """从字符串解析合法的落库知识类型。""" - - if not isinstance(type_str, str): - return None - normalized = type_str.lower().strip() - try: - return KnowledgeType(normalized) - except ValueError: - return None - - -def get_import_strategy_from_string(value: Any) -> Optional[ImportStrategy]: - """从字符串解析文本导入策略。""" - - if not isinstance(value, str): - return None - normalized = value.lower().strip() - try: - return ImportStrategy(normalized) - except ValueError: - return None - - -def parse_import_strategy(value: Any, default: ImportStrategy = ImportStrategy.AUTO) -> ImportStrategy: - """解析 import strategy;非法值直接报错。""" - - if value is None: - return default - if isinstance(value, ImportStrategy): - return value - - normalized = str(value or "").strip().lower() - if not normalized: - return default - - strategy = get_import_strategy_from_string(normalized) - if strategy is None: - allowed = "/".join(allowed_import_strategy_values()) - raise ValueError(f"strategy_override 必须为 {allowed}") - return strategy - - -def validate_stored_knowledge_type(value: Any) -> KnowledgeType: - """校验写库 knowledge_type,仅允许合法落库类型。""" - - if isinstance(value, KnowledgeType): - return value - - resolved = get_knowledge_type_from_string(value) - if resolved is None: - allowed = "/".join(allowed_knowledge_type_values()) - raise ValueError(f"knowledge_type 必须为 {allowed}") - return resolved - - -def resolve_stored_knowledge_type( - value: Any, - *, - content: str = "", - allow_legacy: bool = False, - unknown_fallback: Optional[KnowledgeType] = None, -) -> KnowledgeType: - """ - 将策略/字符串/旧值解析为合法落库类型。 - - `allow_legacy=True` 仅供迁移使用。 - """ - - if isinstance(value, KnowledgeType): - return value - - if isinstance(value, ImportStrategy): - if value == ImportStrategy.AUTO: - if not str(content or "").strip(): - raise ValueError("knowledge_type=auto 需要 content 才能推断") - from .type_detection import detect_knowledge_type - - return detect_knowledge_type(content) - return KnowledgeType(value.value) - - raw = str(value or "").strip() - if not raw: - if str(content or "").strip(): - from .type_detection import detect_knowledge_type - - return 
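KnowledgeType and ImportStrategy subclass both str and Enum, so a raw string normalizes and parses in one try/except, and the enum member itself serializes as plain text when written to SQLite. The round-trip pattern these helpers build on:

from enum import Enum
from typing import Optional

class KnowledgeType(str, Enum):
    STRUCTURED = "structured"
    NARRATIVE = "narrative"
    FACTUAL = "factual"
    QUOTE = "quote"
    MIXED = "mixed"

def parse(raw: object) -> Optional[KnowledgeType]:
    if not isinstance(raw, str):
        return None
    try:
        return KnowledgeType(raw.lower().strip())
    except ValueError:
        return None

print(parse("  Factual "))        # KnowledgeType.FACTUAL
print(parse("imported"))          # None: a legacy value; migration maps it to FACTUAL
print(KnowledgeType.MIXED.value)  # 'mixed', what lands in paragraphs.knowledge_type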
detect_knowledge_type(content) - raise ValueError("knowledge_type 不能为空") - - direct = get_knowledge_type_from_string(raw) - if direct is not None: - return direct - - strategy = get_import_strategy_from_string(raw) - if strategy is not None: - return resolve_stored_knowledge_type(strategy, content=content) - - if allow_legacy: - normalized = raw.lower() - if normalized == "imported": - return KnowledgeType.FACTUAL - if str(content or "").strip(): - from .type_detection import detect_knowledge_type - - detected = detect_knowledge_type(content) - if detected is not None: - return detected - if unknown_fallback is not None: - return unknown_fallback - - allowed = "/".join(allowed_knowledge_type_values()) - raise ValueError(f"非法 knowledge_type: {raw}(仅允许 {allowed})") - - -def should_extract_relations(knowledge_type: KnowledgeType) -> bool: - """判断是否应该做关系抽取。""" - - return knowledge_type in [ - KnowledgeType.STRUCTURED, - KnowledgeType.FACTUAL, - KnowledgeType.MIXED, - ] - - -def get_default_chunk_size(knowledge_type: KnowledgeType) -> int: - """获取默认分块大小。""" - - chunk_sizes = { - KnowledgeType.STRUCTURED: 300, - KnowledgeType.NARRATIVE: 800, - KnowledgeType.FACTUAL: 500, - KnowledgeType.QUOTE: 400, - KnowledgeType.MIXED: 500, - } - return chunk_sizes.get(knowledge_type, 500) - - -def get_type_display_name(knowledge_type: KnowledgeType) -> str: - """获取知识类型中文名称。""" - - display_names = { - KnowledgeType.STRUCTURED: "结构化知识", - KnowledgeType.NARRATIVE: "叙事性文本", - KnowledgeType.FACTUAL: "事实陈述", - KnowledgeType.QUOTE: "引用文本", - KnowledgeType.MIXED: "混合类型", - } - return display_names.get(knowledge_type, "未知类型") diff --git a/plugins/A_memorix/core/storage/metadata_store.py b/plugins/A_memorix/core/storage/metadata_store.py deleted file mode 100644 index 39f2701c..00000000 --- a/plugins/A_memorix/core/storage/metadata_store.py +++ /dev/null @@ -1,5748 +0,0 @@ -""" -元数据存储模块 - -基于SQLite的元数据管理,存储段落、实体、关系等信息。 -""" - -import sqlite3 -import pickle -import json -import uuid -import re -from datetime import datetime -from pathlib import Path -from typing import Optional, Union, List, Dict, Any, Tuple - -from src.common.logger import get_logger -from ..utils.hash import compute_hash, normalize_text -from ..utils.time_parser import normalize_time_meta -from .knowledge_types import ( - KnowledgeType, - allowed_knowledge_type_values, - resolve_stored_knowledge_type, - validate_stored_knowledge_type, -) - -logger = get_logger("A_Memorix.MetadataStore") - - -SCHEMA_VERSION = 8 - - -class MetadataStore: - """ - 元数据存储类 - - 功能: - - SQLite数据库管理 - - 段落/实体/关系元数据存储 - - 增删改查操作 - - 事务支持 - - 索引优化 - - 参数: - data_dir: 数据目录 - db_name: 数据库文件名(默认metadata.db) - """ - - def __init__( - self, - data_dir: Optional[Union[str, Path]] = None, - db_name: str = "metadata.db", - ): - """ - 初始化元数据存储 - - Args: - data_dir: 数据目录 - db_name: 数据库文件名 - """ - self.data_dir = Path(data_dir) if data_dir else None - self.db_name = db_name - self._conn: Optional[sqlite3.Connection] = None - self._is_initialized = False - self._db_path: Optional[Path] = None - - logger.info(f"MetadataStore 初始化: db={db_name}") - - def connect( - self, - data_dir: Optional[Union[str, Path]] = None, - *, - enforce_schema: bool = True, - ) -> None: - """ - 连接到数据库 - - Args: - data_dir: 数据目录(默认使用初始化时的目录) - """ - if data_dir is None: - data_dir = self.data_dir - - if data_dir is None: - raise ValueError("未指定数据目录") - - data_dir = Path(data_dir) - data_dir.mkdir(parents=True, exist_ok=True) - - db_path = data_dir / self.db_name - db_existed = db_path.exists() - self._db_path = 
db_path - - # 连接数据库 - self._conn = sqlite3.connect( - str(db_path), - check_same_thread=False, - timeout=30.0, - ) - self._conn.row_factory = sqlite3.Row # 使用字典式访问 - - # 优化性能 - self._conn.execute("PRAGMA journal_mode=WAL") - self._conn.execute("PRAGMA synchronous=NORMAL") - self._conn.execute("PRAGMA cache_size=-64000") # 64MB缓存 - self._conn.execute("PRAGMA temp_store=MEMORY") - self._conn.execute("PRAGMA foreign_keys = ON") # 开启外键约束支持级联删除 - - logger.info(f"连接到数据库: {db_path}") - - # 初始化或校验 schema - if not self._is_initialized: - if not db_existed: - self._initialize_tables() - if enforce_schema: - self._assert_schema_compatible(db_existed=db_existed) - self._is_initialized = True - - # 初始化 FTS schema(幂等) - try: - self.ensure_fts_schema() - except Exception as e: - logger.warning(f"初始化 FTS schema 失败,将跳过 BM25 检索: {e}") - - def _assert_schema_compatible(self, db_existed: bool) -> None: - """vNext 运行时只做 schema 版本校验,不做隐式迁移。""" - cursor = self._conn.cursor() - cursor.execute( - "SELECT name FROM sqlite_master WHERE type='table' AND name='schema_migrations'" - ) - has_version_table = cursor.fetchone() is not None - if not has_version_table: - if db_existed: - raise RuntimeError( - "检测到旧版 metadata schema(缺少 schema_migrations)。" - " 请先执行 scripts/release_vnext_migrate.py migrate。" - ) - return - - cursor.execute("SELECT MAX(version) FROM schema_migrations") - row = cursor.fetchone() - version = int(row[0]) if row and row[0] is not None else 0 - if version != SCHEMA_VERSION: - raise RuntimeError( - f"metadata schema 版本不匹配: current={version}, expected={SCHEMA_VERSION}。" - " 请执行 scripts/release_vnext_migrate.py migrate。" - ) - - def close(self) -> None: - """关闭数据库连接""" - if self._conn: - self._conn.close() - self._conn = None - logger.info("数据库连接已关闭") - - def _initialize_tables(self) -> None: - """初始化数据库表结构""" - cursor = self._conn.cursor() - - # 段落表 - cursor.execute(""" - CREATE TABLE IF NOT EXISTS paragraphs ( - hash TEXT PRIMARY KEY, - content TEXT NOT NULL, - vector_index INTEGER, - created_at REAL, - updated_at REAL, - metadata TEXT, - source TEXT, - word_count INTEGER, - event_time REAL, - event_time_start REAL, - event_time_end REAL, - time_granularity TEXT, - time_confidence REAL DEFAULT 1.0, - knowledge_type TEXT DEFAULT 'mixed', - is_permanent BOOLEAN DEFAULT 0, - last_accessed REAL, - access_count INTEGER DEFAULT 0, - is_deleted INTEGER DEFAULT 0, - deleted_at REAL - ) - """) - - # 实体表 - cursor.execute(""" - CREATE TABLE IF NOT EXISTS entities ( - hash TEXT PRIMARY KEY, - name TEXT NOT NULL UNIQUE, - vector_index INTEGER, - appearance_count INTEGER DEFAULT 1, - created_at REAL, - metadata TEXT, - is_deleted INTEGER DEFAULT 0, - deleted_at REAL - ) - """) - - # 关系表 - cursor.execute(""" - CREATE TABLE IF NOT EXISTS relations ( - hash TEXT PRIMARY KEY, - subject TEXT NOT NULL, - predicate TEXT NOT NULL, - object TEXT NOT NULL, - vector_index INTEGER, - confidence REAL DEFAULT 1.0, - vector_state TEXT DEFAULT 'none', - vector_updated_at REAL, - vector_error TEXT, - vector_retry_count INTEGER DEFAULT 0, - created_at REAL, - source_paragraph TEXT, - metadata TEXT, - is_permanent BOOLEAN DEFAULT 0, - last_accessed REAL, - access_count INTEGER DEFAULT 0, - is_inactive BOOLEAN DEFAULT 0, - inactive_since REAL, - is_pinned BOOLEAN DEFAULT 0, - protected_until REAL, - last_reinforced REAL, - UNIQUE(subject, predicate, object) - ) - """) - - # 回收站关系表 - cursor.execute(""" - CREATE TABLE IF NOT EXISTS deleted_relations ( - hash TEXT PRIMARY KEY, - subject TEXT NOT NULL, - predicate TEXT NOT NULL, - object 
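connect() tunes the SQLite handle before any query runs: WAL journaling so readers do not block the writer, synchronous=NORMAL as the usual WAL pairing, a roughly 64 MB page cache (a negative cache_size is interpreted as KiB), in-memory temp tables, and foreign_keys=ON so the ON DELETE CASCADE clauses declared below actually fire. The same setup in isolation (file name illustrative):

import sqlite3

conn = sqlite3.connect("metadata.db", check_same_thread=False, timeout=30.0)
conn.row_factory = sqlite3.Row               # dict-style row access

conn.execute("PRAGMA journal_mode=WAL")      # concurrent readers + one writer
conn.execute("PRAGMA synchronous=NORMAL")    # standard durability trade-off under WAL
conn.execute("PRAGMA cache_size=-64000")     # negative = size in KiB (~64 MB)
conn.execute("PRAGMA temp_store=MEMORY")
conn.execute("PRAGMA foreign_keys=ON")       # off by default, per connection

row = conn.execute("SELECT sqlite_version() AS v").fetchone()
print(row["v"])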
TEXT NOT NULL, - vector_index INTEGER, - confidence REAL DEFAULT 1.0, - vector_state TEXT DEFAULT 'none', - vector_updated_at REAL, - vector_error TEXT, - vector_retry_count INTEGER DEFAULT 0, - created_at REAL, - source_paragraph TEXT, - metadata TEXT, - is_permanent BOOLEAN DEFAULT 0, - last_accessed REAL, - access_count INTEGER DEFAULT 0, - is_inactive BOOLEAN DEFAULT 0, - inactive_since REAL, - is_pinned BOOLEAN DEFAULT 0, - protected_until REAL, - last_reinforced REAL, - deleted_at REAL - ) - """) - - # 32位哈希别名映射(用于 vNext 唯一解析) - cursor.execute(""" - CREATE TABLE IF NOT EXISTS relation_hash_aliases ( - alias32 TEXT PRIMARY KEY, - hash TEXT NOT NULL - ) - """) - - # Schema 版本 - cursor.execute(""" - CREATE TABLE IF NOT EXISTS schema_migrations ( - version INTEGER PRIMARY KEY, - applied_at REAL NOT NULL - ) - """) - - # 三元组与段落的关联表 - cursor.execute(""" - CREATE TABLE IF NOT EXISTS paragraph_relations ( - paragraph_hash TEXT NOT NULL, - relation_hash TEXT NOT NULL, - PRIMARY KEY (paragraph_hash, relation_hash), - FOREIGN KEY (paragraph_hash) REFERENCES paragraphs(hash) ON DELETE CASCADE, - FOREIGN KEY (relation_hash) REFERENCES relations(hash) ON DELETE CASCADE - ) - """) - - # 实体与段落的关联表 - cursor.execute(""" - CREATE TABLE IF NOT EXISTS paragraph_entities ( - paragraph_hash TEXT NOT NULL, - entity_hash TEXT NOT NULL, - mention_count INTEGER DEFAULT 1, - PRIMARY KEY (paragraph_hash, entity_hash), - FOREIGN KEY (paragraph_hash) REFERENCES paragraphs(hash) ON DELETE CASCADE, - FOREIGN KEY (entity_hash) REFERENCES entities(hash) ON DELETE CASCADE - ) - """) - - # 创建索引 - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_paragraphs_vector - ON paragraphs(vector_index) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_entities_vector - ON entities(vector_index) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_relations_vector - ON relations(vector_index) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_relations_subject - ON relations(subject) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_relations_object - ON relations(object) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_entities_name - ON entities(name) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_paragraphs_source - ON paragraphs(source) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_paragraphs_deleted - ON paragraphs(is_deleted, deleted_at) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_entities_deleted - ON entities(is_deleted, deleted_at) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_relations_inactive - ON relations(is_inactive, inactive_since) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_relations_protected - ON relations(is_pinned, protected_until) - """) - - # 人物画像开关表(按 stream_id + user_id 维度) - cursor.execute(""" - CREATE TABLE IF NOT EXISTS person_profile_switches ( - stream_id TEXT NOT NULL, - user_id TEXT NOT NULL, - enabled INTEGER NOT NULL DEFAULT 0, - updated_at REAL NOT NULL, - PRIMARY KEY (stream_id, user_id) - ) - """) - - # 人物画像快照表(版本化) - cursor.execute(""" - CREATE TABLE IF NOT EXISTS person_profile_snapshots ( - snapshot_id INTEGER PRIMARY KEY AUTOINCREMENT, - person_id TEXT NOT NULL, - profile_version INTEGER NOT NULL, - profile_text TEXT NOT NULL, - aliases_json TEXT, - relation_edges_json TEXT, - vector_evidence_json TEXT, - evidence_ids_json TEXT, - updated_at REAL NOT NULL, - expires_at REAL, - source_note TEXT, - UNIQUE(person_id, profile_version) - ) - """) - - # 已开启范围内的活跃人物集合 
- cursor.execute(""" - CREATE TABLE IF NOT EXISTS person_profile_active_persons ( - stream_id TEXT NOT NULL, - user_id TEXT NOT NULL, - person_id TEXT NOT NULL, - last_seen_at REAL NOT NULL, - PRIMARY KEY (stream_id, user_id, person_id) - ) - """) - cursor.execute(""" - CREATE TABLE IF NOT EXISTS person_profile_overrides ( - person_id TEXT PRIMARY KEY, - override_text TEXT NOT NULL, - updated_at REAL NOT NULL, - updated_by TEXT, - source TEXT - ) - """) - - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_person_profile_switches_enabled - ON person_profile_switches(enabled) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_person_profile_snapshots_person - ON person_profile_snapshots(person_id, updated_at DESC) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_person_profile_active_seen - ON person_profile_active_persons(last_seen_at DESC) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_person_profile_overrides_updated - ON person_profile_overrides(updated_at DESC) - """) - - # Episode 情景记忆表 - cursor.execute(""" - CREATE TABLE IF NOT EXISTS episodes ( - episode_id TEXT PRIMARY KEY, - source TEXT, - title TEXT NOT NULL, - summary TEXT NOT NULL, - event_time_start REAL, - event_time_end REAL, - time_granularity TEXT, - time_confidence REAL DEFAULT 1.0, - participants_json TEXT, - keywords_json TEXT, - evidence_ids_json TEXT, - paragraph_count INTEGER DEFAULT 0, - llm_confidence REAL DEFAULT 0.0, - segmentation_model TEXT, - segmentation_version TEXT, - created_at REAL NOT NULL, - updated_at REAL NOT NULL - ) - """) - - # Episode -> Paragraph 映射 - cursor.execute(""" - CREATE TABLE IF NOT EXISTS episode_paragraphs ( - episode_id TEXT NOT NULL, - paragraph_hash TEXT NOT NULL, - position INTEGER DEFAULT 0, - PRIMARY KEY (episode_id, paragraph_hash), - FOREIGN KEY (episode_id) REFERENCES episodes(episode_id) ON DELETE CASCADE, - FOREIGN KEY (paragraph_hash) REFERENCES paragraphs(hash) ON DELETE CASCADE - ) - """) - - # Episode 生成队列(异步) - cursor.execute(""" - CREATE TABLE IF NOT EXISTS episode_pending_paragraphs ( - paragraph_hash TEXT PRIMARY KEY, - source TEXT, - created_at REAL, - status TEXT DEFAULT 'pending', - retry_count INTEGER DEFAULT 0, - last_error TEXT, - updated_at REAL NOT NULL - ) - """) - cursor.execute(""" - CREATE TABLE IF NOT EXISTS episode_rebuild_sources ( - source TEXT PRIMARY KEY, - status TEXT DEFAULT 'pending', - retry_count INTEGER DEFAULT 0, - last_error TEXT, - reason TEXT, - requested_at REAL NOT NULL, - updated_at REAL NOT NULL - ) - """) - - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_episodes_source_time_end - ON episodes(source, event_time_end DESC) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_episodes_updated_at - ON episodes(updated_at DESC) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_episode_paragraphs_paragraph - ON episode_paragraphs(paragraph_hash) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_episode_pending_status_updated - ON episode_pending_paragraphs(status, updated_at) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_episode_pending_source_created - ON episode_pending_paragraphs(source, created_at) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_episode_rebuild_status_updated - ON episode_rebuild_sources(status, updated_at) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_episode_rebuild_updated_at - ON episode_rebuild_sources(updated_at DESC) - """) - cursor.execute(""" - CREATE TABLE IF NOT EXISTS external_memory_refs ( - 
external_id TEXT PRIMARY KEY, - paragraph_hash TEXT NOT NULL, - source_type TEXT, - created_at REAL NOT NULL, - metadata_json TEXT - ) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_external_memory_refs_paragraph - ON external_memory_refs(paragraph_hash) - """) - cursor.execute(""" - CREATE TABLE IF NOT EXISTS memory_v5_operations ( - operation_id TEXT PRIMARY KEY, - action TEXT NOT NULL, - target TEXT, - reason TEXT, - updated_by TEXT, - created_at REAL NOT NULL, - resolved_hashes_json TEXT, - result_json TEXT - ) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_memory_v5_operations_created - ON memory_v5_operations(created_at DESC) - """) - cursor.execute(""" - CREATE TABLE IF NOT EXISTS delete_operations ( - operation_id TEXT PRIMARY KEY, - mode TEXT NOT NULL, - selector TEXT, - reason TEXT, - requested_by TEXT, - status TEXT NOT NULL, - created_at REAL NOT NULL, - restored_at REAL, - summary_json TEXT - ) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_delete_operations_created - ON delete_operations(created_at DESC) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_delete_operations_mode - ON delete_operations(mode, created_at DESC) - """) - cursor.execute(""" - CREATE TABLE IF NOT EXISTS delete_operation_items ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - operation_id TEXT NOT NULL, - item_type TEXT NOT NULL, - item_hash TEXT, - item_key TEXT, - payload_json TEXT, - created_at REAL NOT NULL, - FOREIGN KEY (operation_id) REFERENCES delete_operations(operation_id) ON DELETE CASCADE - ) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_delete_operation_items_operation - ON delete_operation_items(operation_id, id ASC) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_delete_operation_items_hash - ON delete_operation_items(item_hash) - """) - # 新版 schema 包含完整字段,直接写入版本信息 - cursor.execute("INSERT OR IGNORE INTO schema_migrations(version, applied_at) VALUES (?, ?)", (SCHEMA_VERSION, datetime.now().timestamp())) - self._conn.commit() - logger.debug("数据库表结构初始化完成") - - def _migrate_schema(self) -> None: - """执行数据库schema迁移""" - cursor = self._conn.cursor() - - # vNext 关键表兜底:历史库可能缺失,需在迁移阶段主动补齐。 - cursor.execute(""" - CREATE TABLE IF NOT EXISTS relation_hash_aliases ( - alias32 TEXT PRIMARY KEY, - hash TEXT NOT NULL - ) - """) - cursor.execute(""" - CREATE TABLE IF NOT EXISTS schema_migrations ( - version INTEGER PRIMARY KEY, - applied_at REAL NOT NULL - ) - """) - - # Episode MVP 表结构补齐 - cursor.execute(""" - CREATE TABLE IF NOT EXISTS episodes ( - episode_id TEXT PRIMARY KEY, - source TEXT, - title TEXT NOT NULL, - summary TEXT NOT NULL, - event_time_start REAL, - event_time_end REAL, - time_granularity TEXT, - time_confidence REAL DEFAULT 1.0, - participants_json TEXT, - keywords_json TEXT, - evidence_ids_json TEXT, - paragraph_count INTEGER DEFAULT 0, - llm_confidence REAL DEFAULT 0.0, - segmentation_model TEXT, - segmentation_version TEXT, - created_at REAL NOT NULL, - updated_at REAL NOT NULL - ) - """) - cursor.execute(""" - CREATE TABLE IF NOT EXISTS episode_paragraphs ( - episode_id TEXT NOT NULL, - paragraph_hash TEXT NOT NULL, - position INTEGER DEFAULT 0, - PRIMARY KEY (episode_id, paragraph_hash), - FOREIGN KEY (episode_id) REFERENCES episodes(episode_id) ON DELETE CASCADE, - FOREIGN KEY (paragraph_hash) REFERENCES paragraphs(hash) ON DELETE CASCADE - ) - """) - cursor.execute(""" - CREATE TABLE IF NOT EXISTS episode_pending_paragraphs ( - paragraph_hash TEXT PRIMARY KEY, - source TEXT, - created_at REAL, - status 
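The association tables in this schema (paragraph_relations, paragraph_entities, episode_paragraphs, delete_operation_items) all declare ON DELETE CASCADE, so removing a parent row cleans up its children in one statement, provided foreign_keys is enabled on the connection. A minimal demonstration with a reduced pair of these tables:

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("PRAGMA foreign_keys=ON")  # cascades are inert without this

conn.executescript("""
CREATE TABLE delete_operations (operation_id TEXT PRIMARY KEY, mode TEXT);
CREATE TABLE delete_operation_items (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    operation_id TEXT NOT NULL,
    item_hash TEXT,
    FOREIGN KEY (operation_id)
        REFERENCES delete_operations(operation_id) ON DELETE CASCADE
);
INSERT INTO delete_operations VALUES ('op-1', 'soft');
INSERT INTO delete_operation_items (operation_id, item_hash) VALUES ('op-1', 'abc');
""")

conn.execute("DELETE FROM delete_operations WHERE operation_id = 'op-1'")
n = conn.execute("SELECT COUNT(*) FROM delete_operation_items").fetchone()[0]
print(n)  # 0 -- the child rows went with the parent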
TEXT DEFAULT 'pending', - retry_count INTEGER DEFAULT 0, - last_error TEXT, - updated_at REAL NOT NULL - ) - """) - cursor.execute(""" - CREATE TABLE IF NOT EXISTS episode_rebuild_sources ( - source TEXT PRIMARY KEY, - status TEXT DEFAULT 'pending', - retry_count INTEGER DEFAULT 0, - last_error TEXT, - reason TEXT, - requested_at REAL NOT NULL, - updated_at REAL NOT NULL - ) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_episodes_source_time_end - ON episodes(source, event_time_end DESC) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_episodes_updated_at - ON episodes(updated_at DESC) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_episode_paragraphs_paragraph - ON episode_paragraphs(paragraph_hash) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_episode_pending_status_updated - ON episode_pending_paragraphs(status, updated_at) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_episode_pending_source_created - ON episode_pending_paragraphs(source, created_at) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_episode_rebuild_status_updated - ON episode_rebuild_sources(status, updated_at) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_episode_rebuild_updated_at - ON episode_rebuild_sources(updated_at DESC) - """) - cursor.execute(""" - CREATE TABLE IF NOT EXISTS external_memory_refs ( - external_id TEXT PRIMARY KEY, - paragraph_hash TEXT NOT NULL, - source_type TEXT, - created_at REAL NOT NULL, - metadata_json TEXT - ) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_external_memory_refs_paragraph - ON external_memory_refs(paragraph_hash) - """) - cursor.execute(""" - CREATE TABLE IF NOT EXISTS memory_v5_operations ( - operation_id TEXT PRIMARY KEY, - action TEXT NOT NULL, - target TEXT, - reason TEXT, - updated_by TEXT, - created_at REAL NOT NULL, - resolved_hashes_json TEXT, - result_json TEXT - ) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_memory_v5_operations_created - ON memory_v5_operations(created_at DESC) - """) - cursor.execute(""" - CREATE TABLE IF NOT EXISTS delete_operations ( - operation_id TEXT PRIMARY KEY, - mode TEXT NOT NULL, - selector TEXT, - reason TEXT, - requested_by TEXT, - status TEXT NOT NULL, - created_at REAL NOT NULL, - restored_at REAL, - summary_json TEXT - ) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_delete_operations_created - ON delete_operations(created_at DESC) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_delete_operations_mode - ON delete_operations(mode, created_at DESC) - """) - cursor.execute(""" - CREATE TABLE IF NOT EXISTS delete_operation_items ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - operation_id TEXT NOT NULL, - item_type TEXT NOT NULL, - item_hash TEXT, - item_key TEXT, - payload_json TEXT, - created_at REAL NOT NULL, - FOREIGN KEY (operation_id) REFERENCES delete_operations(operation_id) ON DELETE CASCADE - ) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_delete_operation_items_operation - ON delete_operation_items(operation_id, id ASC) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_delete_operation_items_hash - ON delete_operation_items(item_hash) - """) - - # 检查paragraphs表是否有knowledge_type列 - cursor.execute("PRAGMA table_info(paragraphs)") - columns = [row[1] for row in cursor.fetchall()] - - if "knowledge_type" not in columns: - logger.info("检测到旧版schema,正在迁移添加knowledge_type字段...") - try: - cursor.execute(""" - ALTER TABLE paragraphs - ADD COLUMN knowledge_type TEXT 
DEFAULT 'mixed' - """) - self._conn.commit() - logger.info("Schema迁移完成:已添加knowledge_type字段") - except sqlite3.OperationalError as e: - logger.warning(f"Schema迁移失败(可能已存在): {e}") - - # 问题2: 时序字段迁移 - cursor.execute("PRAGMA table_info(paragraphs)") - columns = [row[1] for row in cursor.fetchall()] - temporal_columns = { - "event_time": "ALTER TABLE paragraphs ADD COLUMN event_time REAL", - "event_time_start": "ALTER TABLE paragraphs ADD COLUMN event_time_start REAL", - "event_time_end": "ALTER TABLE paragraphs ADD COLUMN event_time_end REAL", - "time_granularity": "ALTER TABLE paragraphs ADD COLUMN time_granularity TEXT", - "time_confidence": "ALTER TABLE paragraphs ADD COLUMN time_confidence REAL DEFAULT 1.0", - } - for col, sql in temporal_columns.items(): - if col not in columns: - try: - cursor.execute(sql) - except sqlite3.OperationalError as e: - logger.warning(f"Schema迁移失败({col}): {e}") - - # 时序索引(仅在列存在时创建,兼容旧库迁移) - self._create_temporal_indexes_if_ready() - self._conn.commit() - - # 检查paragraphs表是否有is_permanent列 - cursor.execute("PRAGMA table_info(paragraphs)") - columns = [row[1] for row in cursor.fetchall()] - - if "is_permanent" not in columns: - logger.info("正在迁移: 添加记忆动态字段...") - try: - # 段落表 - cursor.execute("ALTER TABLE paragraphs ADD COLUMN is_permanent BOOLEAN DEFAULT 0") - cursor.execute("ALTER TABLE paragraphs ADD COLUMN last_accessed REAL") - cursor.execute("ALTER TABLE paragraphs ADD COLUMN access_count INTEGER DEFAULT 0") - - # 关系表 - cursor.execute("ALTER TABLE relations ADD COLUMN is_permanent BOOLEAN DEFAULT 0") - cursor.execute("ALTER TABLE relations ADD COLUMN last_accessed REAL") - cursor.execute("ALTER TABLE relations ADD COLUMN access_count INTEGER DEFAULT 0") - - self._conn.commit() - logger.info("Schema迁移完成:已添加记忆动态字段") - except sqlite3.OperationalError as e: - logger.warning(f"Schema迁移失败: {e}") - - # 检查relations表是否有is_inactive列 (V5 Memory System) - cursor.execute("PRAGMA table_info(relations)") - columns = [row[1] for row in cursor.fetchall()] - - if "is_inactive" not in columns: - logger.info("正在迁移: 添加V5记忆动态字段 (inactive, protected)...") - try: - # 关系表 V5 新增字段 - cursor.execute("ALTER TABLE relations ADD COLUMN is_inactive BOOLEAN DEFAULT 0") - cursor.execute("ALTER TABLE relations ADD COLUMN inactive_since REAL") - cursor.execute("ALTER TABLE relations ADD COLUMN is_pinned BOOLEAN DEFAULT 0") - cursor.execute("ALTER TABLE relations ADD COLUMN protected_until REAL") - cursor.execute("ALTER TABLE relations ADD COLUMN last_reinforced REAL") - - # 为回收站创建 deleted_relations 表 - cursor.execute(""" - CREATE TABLE IF NOT EXISTS deleted_relations ( - hash TEXT PRIMARY KEY, - subject TEXT NOT NULL, - predicate TEXT NOT NULL, - object TEXT NOT NULL, - vector_index INTEGER, - confidence REAL DEFAULT 1.0, - vector_state TEXT DEFAULT 'none', - vector_updated_at REAL, - vector_error TEXT, - vector_retry_count INTEGER DEFAULT 0, - created_at REAL, - source_paragraph TEXT, - metadata TEXT, - is_permanent BOOLEAN DEFAULT 0, - last_accessed REAL, - access_count INTEGER DEFAULT 0, - is_inactive BOOLEAN DEFAULT 0, - inactive_since REAL, - is_pinned BOOLEAN DEFAULT 0, - protected_until REAL, - last_reinforced REAL, - deleted_at REAL -- 用于记录删除时间的额外列 - ) - """) - - self._conn.commit() - logger.info("Schema迁移完成:已添加V5记忆动态字段及回收站表") - except sqlite3.OperationalError as e: - logger.warning(f"Schema迁移失败 (V5): {e}") - - # 关系向量状态字段迁移 - cursor.execute("PRAGMA table_info(relations)") - relation_columns = {row[1] for row in cursor.fetchall()} - relation_vector_columns = { - "vector_state": "ALTER 
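The migration's recurring shape is worth isolating: inspect PRAGMA table_info to learn which columns already exist, issue one ALTER TABLE ... ADD COLUMN per missing field, and swallow the resulting OperationalError so reruns stay harmless. The skeleton of that idempotent pattern (toy table and columns):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE paragraphs (hash TEXT PRIMARY KEY, content TEXT)")

wanted = {
    "event_time": "ALTER TABLE paragraphs ADD COLUMN event_time REAL",
    "knowledge_type": "ALTER TABLE paragraphs ADD COLUMN knowledge_type TEXT DEFAULT 'mixed'",
}

existing = {row[1] for row in conn.execute("PRAGMA table_info(paragraphs)")}
for col, ddl in wanted.items():
    if col not in existing:
        try:
            conn.execute(ddl)
        except sqlite3.OperationalError as e:
            print(f"skip {col}: {e}")  # e.g. column already added by another path

print([row[1] for row in conn.execute("PRAGMA table_info(paragraphs)")])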
TABLE relations ADD COLUMN vector_state TEXT DEFAULT 'none'", - "vector_updated_at": "ALTER TABLE relations ADD COLUMN vector_updated_at REAL", - "vector_error": "ALTER TABLE relations ADD COLUMN vector_error TEXT", - "vector_retry_count": "ALTER TABLE relations ADD COLUMN vector_retry_count INTEGER DEFAULT 0", - } - for col, sql in relation_vector_columns.items(): - if col not in relation_columns: - try: - cursor.execute(sql) - except sqlite3.OperationalError as e: - logger.warning(f"Schema迁移失败 (relations.{col}): {e}") - - # 回收站同步字段迁移(用于 restore 保留向量状态) - cursor.execute("PRAGMA table_info(deleted_relations)") - deleted_relation_columns = {row[1] for row in cursor.fetchall()} - deleted_relation_vector_columns = { - "vector_state": "ALTER TABLE deleted_relations ADD COLUMN vector_state TEXT DEFAULT 'none'", - "vector_updated_at": "ALTER TABLE deleted_relations ADD COLUMN vector_updated_at REAL", - "vector_error": "ALTER TABLE deleted_relations ADD COLUMN vector_error TEXT", - "vector_retry_count": "ALTER TABLE deleted_relations ADD COLUMN vector_retry_count INTEGER DEFAULT 0", - } - for col, sql in deleted_relation_vector_columns.items(): - if col not in deleted_relation_columns: - try: - cursor.execute(sql) - except sqlite3.OperationalError as e: - logger.warning(f"Schema迁移失败 (deleted_relations.{col}): {e}") - - # 检查 entities 表是否有 is_deleted 列 (Soft Delete System) - cursor.execute("PRAGMA table_info(entities)") - columns = [row[1] for row in cursor.fetchall()] - - if "is_deleted" not in columns: - logger.info("正在迁移: 添加软删除字段 (Soft Delete)...") - try: - # 实体表 - cursor.execute("ALTER TABLE entities ADD COLUMN is_deleted INTEGER DEFAULT 0") - cursor.execute("ALTER TABLE entities ADD COLUMN deleted_at REAL") - - # 段落表 - cursor.execute("ALTER TABLE paragraphs ADD COLUMN is_deleted INTEGER DEFAULT 0") - cursor.execute("ALTER TABLE paragraphs ADD COLUMN deleted_at REAL") - - self._conn.commit() - logger.info("Schema迁移完成:已添加软删除字段") - except sqlite3.OperationalError as e: - logger.warning(f"Schema迁移失败 (Soft Delete): {e}") - - # 数据修复: 检查是否存在 source/vector_index 列错位的情况 - # 症状: vector_index (本应是int) 变成了文件名字符串, source (本应是文件名) 变成了类型字符串 - try: - cursor.execute(""" - SELECT count(*) FROM paragraphs - WHERE typeof(vector_index) = 'text' - AND source IN ('mixed', 'factual', 'narrative', 'structured', 'auto') - """) - count = cursor.fetchone()[0] - if count > 0: - logger.warning(f"检测到 {count} 条数据存在列错位(文件名误存入vector_index),正在自动修复...") - cursor.execute(""" - UPDATE paragraphs - SET - knowledge_type = source, - source = vector_index, - vector_index = NULL - WHERE typeof(vector_index) = 'text' - AND source IN ('mixed', 'factual', 'narrative', 'structured', 'auto') - """) - self._conn.commit() - logger.info(f"自动修复完成: 已校正 {cursor.rowcount} 条数据") - except Exception as e: - logger.error(f"数据自动修复失败: {e}") - - def _create_temporal_indexes_if_ready(self) -> None: - """ - 仅当时序列已存在时创建索引。 - - 旧库升级时,_initialize_tables 不能提前对不存在的列建索引; - 因此统一在迁移阶段按列存在性安全创建。 - """ - cursor = self._conn.cursor() - cursor.execute("PRAGMA table_info(paragraphs)") - columns = {row[1] for row in cursor.fetchall()} - - if "event_time" in columns: - cursor.execute( - "CREATE INDEX IF NOT EXISTS idx_paragraphs_event_time ON paragraphs(event_time)" - ) - if "event_time_start" in columns: - cursor.execute( - "CREATE INDEX IF NOT EXISTS idx_paragraphs_event_start ON paragraphs(event_time_start)" - ) - if "event_time_end" in columns: - cursor.execute( - "CREATE INDEX IF NOT EXISTS idx_paragraphs_event_end ON paragraphs(event_time_end)" - ) - - def 
run_legacy_migration_for_vnext(self) -> Dict[str, Any]: - """ - 离线迁移入口: - - 复用旧迁移逻辑补齐历史库字段 - - 重建 relation 32位别名 - - 归一化历史 knowledge_type - - 写入 vNext schema 版本 - """ - self._migrate_schema() - alias_result = self.rebuild_relation_hash_aliases() - knowledge_type_result = self.normalize_paragraph_knowledge_types() - self.set_schema_version(SCHEMA_VERSION) - return { - "schema_version": SCHEMA_VERSION, - "alias_result": alias_result, - "knowledge_type_result": knowledge_type_result, - } - - def list_invalid_paragraph_knowledge_types(self) -> List[str]: - """列出当前库中不合法的段落 knowledge_type。""" - - cursor = self._conn.cursor() - cursor.execute( - """ - SELECT DISTINCT knowledge_type - FROM paragraphs - WHERE knowledge_type IS NULL - OR TRIM(COALESCE(knowledge_type, '')) = '' - OR LOWER(TRIM(knowledge_type)) NOT IN ({placeholders}) - ORDER BY knowledge_type - """.format(placeholders=", ".join("?" for _ in allowed_knowledge_type_values())), - tuple(allowed_knowledge_type_values()), - ) - invalid: List[str] = [] - for row in cursor.fetchall(): - raw = row[0] - invalid.append(str(raw) if raw is not None else "") - return invalid - - def normalize_paragraph_knowledge_types(self) -> Dict[str, Any]: - """将历史非法 knowledge_type 归一化为合法值。""" - - cursor = self._conn.cursor() - cursor.execute("SELECT hash, content, knowledge_type FROM paragraphs") - rows = cursor.fetchall() - - normalized_count = 0 - normalized_map: Dict[str, int] = {} - invalid_before: List[str] = [] - invalid_seen = set() - - for row in rows: - paragraph_hash = str(row["hash"]) - content = str(row["content"] or "") - raw_value = row["knowledge_type"] - try: - validate_stored_knowledge_type(raw_value) - continue - except ValueError: - raw_text = str(raw_value) if raw_value is not None else "" - if raw_text not in invalid_seen: - invalid_seen.add(raw_text) - invalid_before.append(raw_text) - - normalized_type = resolve_stored_knowledge_type( - raw_value, - content=content, - allow_legacy=True, - unknown_fallback=KnowledgeType.MIXED, - ) - cursor.execute( - "UPDATE paragraphs SET knowledge_type = ? 
WHERE hash = ?", - (normalized_type.value, paragraph_hash), - ) - normalized_count += 1 - normalized_map[normalized_type.value] = normalized_map.get(normalized_type.value, 0) + 1 - - self._conn.commit() - return { - "normalized": normalized_count, - "invalid_before": sorted(invalid_before), - "normalized_to": normalized_map, - } - - def _resolve_conn(self, conn: Optional[sqlite3.Connection] = None) -> sqlite3.Connection: - """解析可用连接。""" - resolved = conn or self._conn - if resolved is None: - raise RuntimeError("MetadataStore 未连接数据库") - return resolved - - def get_db_path(self) -> Path: - """获取 SQLite 数据库文件路径。""" - if self._db_path is not None: - return self._db_path - if self.data_dir is None: - raise RuntimeError("MetadataStore 未配置 data_dir") - return Path(self.data_dir) / self.db_name - - def ensure_fts_schema(self, conn: Optional[sqlite3.Connection] = None) -> bool: - """ - 确保 FTS5 schema 存在(幂等)。 - - 采用 external-content 方式,不在 FTS 表重复存储正文。 - """ - c = self._resolve_conn(conn) - cur = c.cursor() - try: - cur.execute(""" - CREATE VIRTUAL TABLE IF NOT EXISTS paragraphs_fts - USING fts5( - content, - content='paragraphs', - content_rowid='rowid', - tokenize='unicode61' - ) - """) - - # insert trigger - cur.execute(""" - CREATE TRIGGER IF NOT EXISTS paragraphs_ai - AFTER INSERT ON paragraphs - BEGIN - INSERT INTO paragraphs_fts(rowid, content) - VALUES (new.rowid, new.content); - END - """) - - # delete trigger - cur.execute(""" - CREATE TRIGGER IF NOT EXISTS paragraphs_ad - AFTER DELETE ON paragraphs - BEGIN - INSERT INTO paragraphs_fts(paragraphs_fts, rowid, content) - VALUES ('delete', old.rowid, old.content); - END - """) - - # update trigger - cur.execute(""" - CREATE TRIGGER IF NOT EXISTS paragraphs_au - AFTER UPDATE OF content ON paragraphs - BEGIN - INSERT INTO paragraphs_fts(paragraphs_fts, rowid, content) - VALUES ('delete', old.rowid, old.content); - INSERT INTO paragraphs_fts(rowid, content) - VALUES (new.rowid, new.content); - END - """) - c.commit() - return True - except sqlite3.OperationalError as e: - logger.warning(f"FTS5 schema 创建失败(可能不支持 FTS5): {e}") - c.rollback() - return False - - def ensure_fts_backfilled(self, conn: Optional[sqlite3.Connection] = None) -> bool: - """ - 确保 FTS 索引已回填。 - - 当历史数据存在但 FTS 表为空/不一致时执行 rebuild。 - """ - c = self._resolve_conn(conn) - cur = c.cursor() - try: - cur.execute("SELECT COUNT(1) AS n FROM paragraphs") - para_count = int(cur.fetchone()[0]) - cur.execute("SELECT COUNT(1) AS n FROM paragraphs_fts") - fts_count = int(cur.fetchone()[0]) - - if para_count > 0 and fts_count != para_count: - cur.execute("INSERT INTO paragraphs_fts(paragraphs_fts) VALUES ('rebuild')") - c.commit() - logger.info(f"FTS 回填完成: paragraphs={para_count}, fts={para_count}") - return True - except sqlite3.OperationalError as e: - logger.warning(f"FTS 回填失败: {e}") - c.rollback() - return False - - def ensure_relations_fts_schema(self, conn: Optional[sqlite3.Connection] = None) -> bool: - """ - 确保关系 FTS5 schema 存在(幂等)。 - - 注意:relations 表没有 content 列,因此使用独立 FTS 表并通过触发器同步。 - """ - c = self._resolve_conn(conn) - cur = c.cursor() - try: - cur.execute(""" - CREATE VIRTUAL TABLE IF NOT EXISTS relations_fts - USING fts5( - relation_hash UNINDEXED, - content, - tokenize='unicode61' - ) - """) - - cur.execute(""" - CREATE TRIGGER IF NOT EXISTS relations_ai - AFTER INSERT ON relations - BEGIN - INSERT INTO relations_fts(relation_hash, content) - VALUES ( - new.hash, - COALESCE(new.subject, '') || ' ' || COALESCE(new.predicate, '') || ' ' || COALESCE(new.object, '') - ); - END - 
""") - - cur.execute(""" - CREATE TRIGGER IF NOT EXISTS relations_ad - AFTER DELETE ON relations - BEGIN - DELETE FROM relations_fts WHERE relation_hash = old.hash; - END - """) - - cur.execute(""" - CREATE TRIGGER IF NOT EXISTS relations_au - AFTER UPDATE OF subject, predicate, object ON relations - BEGIN - DELETE FROM relations_fts WHERE relation_hash = new.hash; - INSERT INTO relations_fts(relation_hash, content) - VALUES ( - new.hash, - COALESCE(new.subject, '') || ' ' || COALESCE(new.predicate, '') || ' ' || COALESCE(new.object, '') - ); - END - """) - c.commit() - return True - except sqlite3.OperationalError as e: - logger.warning(f"relations FTS5 schema 创建失败(可能不支持 FTS5): {e}") - c.rollback() - return False - - def ensure_relations_fts_backfilled(self, conn: Optional[sqlite3.Connection] = None) -> bool: - """确保关系 FTS 索引已回填。""" - c = self._resolve_conn(conn) - cur = c.cursor() - try: - cur.execute("SELECT COUNT(1) AS n FROM relations") - rel_count = int(cur.fetchone()[0]) - cur.execute("SELECT COUNT(1) AS n FROM relations_fts") - fts_count = int(cur.fetchone()[0]) - - if rel_count != fts_count: - cur.execute("DELETE FROM relations_fts") - cur.execute(""" - INSERT INTO relations_fts(relation_hash, content) - SELECT - r.hash, - COALESCE(r.subject, '') || ' ' || COALESCE(r.predicate, '') || ' ' || COALESCE(r.object, '') - FROM relations r - """) - c.commit() - logger.info(f"relations FTS 回填完成: relations={rel_count}, fts={rel_count}") - return True - except sqlite3.OperationalError as e: - logger.warning(f"relations FTS 回填失败: {e}") - c.rollback() - return False - - def ensure_paragraph_ngram_schema(self, conn: Optional[sqlite3.Connection] = None) -> bool: - """确保段落 ngram 倒排表存在。""" - c = self._resolve_conn(conn) - cur = c.cursor() - try: - cur.execute(""" - CREATE TABLE IF NOT EXISTS paragraph_ngrams ( - term TEXT NOT NULL, - paragraph_hash TEXT NOT NULL, - PRIMARY KEY (term, paragraph_hash), - FOREIGN KEY (paragraph_hash) REFERENCES paragraphs(hash) ON DELETE CASCADE - ) - """) - cur.execute(""" - CREATE INDEX IF NOT EXISTS idx_paragraph_ngrams_hash - ON paragraph_ngrams(paragraph_hash) - """) - cur.execute(""" - CREATE TABLE IF NOT EXISTS paragraph_ngram_meta ( - key TEXT PRIMARY KEY, - value TEXT - ) - """) - c.commit() - return True - except sqlite3.OperationalError as e: - logger.warning(f"paragraph ngram schema 创建失败: {e}") - c.rollback() - return False - - @staticmethod - def _char_ngrams(text: str, n: int) -> List[str]: - compact = "".join(str(text or "").lower().split()) - if not compact: - return [] - if len(compact) < n: - return [compact] - return [compact[i : i + n] for i in range(0, len(compact) - n + 1)] - - def ensure_paragraph_ngram_backfilled( - self, - n: int = 2, - conn: Optional[sqlite3.Connection] = None, - ) -> bool: - """ - 确保段落 ngram 倒排索引已回填。 - - 仅在 n 变化或文档数量变化时重建,避免每次加载都全量重建。 - """ - c = self._resolve_conn(conn) - cur = c.cursor() - n = max(1, int(n)) - try: - cur.execute("SELECT value FROM paragraph_ngram_meta WHERE key='ngram_n'") - row = cur.fetchone() - current_n = int(row[0]) if row and row[0] is not None else None - - cur.execute("SELECT COUNT(1) FROM paragraphs WHERE is_deleted IS NULL OR is_deleted = 0") - para_count = int(cur.fetchone()[0]) - cur.execute("SELECT COUNT(DISTINCT paragraph_hash) FROM paragraph_ngrams") - indexed_docs = int(cur.fetchone()[0]) - - need_rebuild = (current_n != n) or (para_count != indexed_docs) - if not need_rebuild: - return True - - cur.execute("DELETE FROM paragraph_ngrams") - cur.execute(""" - SELECT hash, content - FROM 
paragraphs - WHERE is_deleted IS NULL OR is_deleted = 0 - """) - rows = cur.fetchall() - - batch: List[Tuple[str, str]] = [] - batch_size = 2000 - for row in rows: - p_hash = str(row["hash"]) - terms = list(dict.fromkeys(self._char_ngrams(str(row["content"] or ""), n))) - for term in terms: - batch.append((term, p_hash)) - if len(batch) >= batch_size: - cur.executemany( - "INSERT OR IGNORE INTO paragraph_ngrams(term, paragraph_hash) VALUES (?, ?)", - batch, - ) - batch.clear() - if batch: - cur.executemany( - "INSERT OR IGNORE INTO paragraph_ngrams(term, paragraph_hash) VALUES (?, ?)", - batch, - ) - - cur.execute(""" - INSERT INTO paragraph_ngram_meta(key, value) VALUES('ngram_n', ?) - ON CONFLICT(key) DO UPDATE SET value=excluded.value - """, (str(n),)) - cur.execute(""" - INSERT INTO paragraph_ngram_meta(key, value) VALUES('paragraph_count', ?) - ON CONFLICT(key) DO UPDATE SET value=excluded.value - """, (str(para_count),)) - c.commit() - logger.info(f"paragraph ngram 回填完成: n={n}, paragraphs={para_count}") - return True - except Exception as e: - logger.warning(f"paragraph ngram 回填失败: {e}") - c.rollback() - return False - - def fts_upsert_paragraph( - self, - paragraph_hash: str, - conn: Optional[sqlite3.Connection] = None, - ) -> bool: - """ - 将段落写入(或覆盖)到 FTS 索引。 - """ - c = self._resolve_conn(conn) - cur = c.cursor() - try: - cur.execute( - "SELECT rowid, content FROM paragraphs WHERE hash = ?", - (paragraph_hash,), - ) - row = cur.fetchone() - if not row: - return False - rowid = int(row[0]) - content = str(row[1] or "") - cur.execute( - "INSERT OR REPLACE INTO paragraphs_fts(rowid, content) VALUES (?, ?)", - (rowid, content), - ) - c.commit() - return True - except sqlite3.OperationalError as e: - logger.warning(f"FTS upsert 失败: {e}") - c.rollback() - return False - - def fts_delete_paragraph( - self, - paragraph_hash: str, - conn: Optional[sqlite3.Connection] = None, - ) -> bool: - """ - 从 FTS 索引删除段落。 - """ - c = self._resolve_conn(conn) - cur = c.cursor() - try: - cur.execute( - "SELECT rowid, content FROM paragraphs WHERE hash = ?", - (paragraph_hash,), - ) - row = cur.fetchone() - if not row: - return False - rowid = int(row[0]) - content = str(row[1] or "") - cur.execute( - "INSERT INTO paragraphs_fts(paragraphs_fts, rowid, content) VALUES ('delete', ?, ?)", - (rowid, content), - ) - c.commit() - return True - except sqlite3.OperationalError as e: - logger.warning(f"FTS delete 失败: {e}") - c.rollback() - return False - - def fts_search_bm25( - self, - match_query: str, - limit: int = 20, - max_doc_len: int = 2000, - conn: Optional[sqlite3.Connection] = None, - ) -> List[Dict[str, Any]]: - """ - 使用 FTS5 + bm25 执行全文检索。 - """ - if not match_query.strip(): - return [] - - c = self._resolve_conn(conn) - cur = c.cursor() - try: - cur.execute( - """ - SELECT p.hash, p.content, bm25(paragraphs_fts) AS bm25_score - FROM paragraphs_fts - JOIN paragraphs p ON p.rowid = paragraphs_fts.rowid - WHERE paragraphs_fts MATCH ? - AND (p.is_deleted IS NULL OR p.is_deleted = 0) - ORDER BY bm25_score ASC - LIMIT ? 
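-                -- Note: SQLite FTS5 bm25() returns smaller (more negative)
-                -- values for better matches, so ascending order is best-first.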
- """, - (match_query, max(1, int(limit))), - ) - rows = cur.fetchall() - results: List[Dict[str, Any]] = [] - for row in rows: - content = str(row["content"] or "") - if max_doc_len > 0: - content = content[:max_doc_len] - results.append( - { - "hash": row["hash"], - "content": content, - "bm25_score": float(row["bm25_score"]), - } - ) - return results - except sqlite3.OperationalError as e: - logger.warning(f"FTS 查询失败: {e}") - return [] - - def fts_search_relations_bm25( - self, - match_query: str, - limit: int = 20, - max_doc_len: int = 512, - conn: Optional[sqlite3.Connection] = None, - ) -> List[Dict[str, Any]]: - """使用 FTS5 + bm25 执行关系全文检索。""" - if not match_query.strip(): - return [] - - c = self._resolve_conn(conn) - cur = c.cursor() - try: - cur.execute( - """ - SELECT - r.hash, - r.subject, - r.predicate, - r.object, - bm25(relations_fts) AS bm25_score - FROM relations_fts - JOIN relations r ON r.hash = relations_fts.relation_hash - WHERE relations_fts MATCH ? - ORDER BY bm25_score ASC - LIMIT ? - """, - (match_query, max(1, int(limit))), - ) - rows = cur.fetchall() - out: List[Dict[str, Any]] = [] - for row in rows: - content = f"{row['subject']} {row['predicate']} {row['object']}" - if max_doc_len > 0: - content = content[:max_doc_len] - out.append( - { - "hash": row["hash"], - "subject": row["subject"], - "predicate": row["predicate"], - "object": row["object"], - "content": content, - "bm25_score": float(row["bm25_score"]), - } - ) - return out - except sqlite3.OperationalError as e: - logger.warning(f"relations FTS 查询失败: {e}") - return [] - - def ngram_search_paragraphs( - self, - tokens: List[str], - limit: int = 20, - max_doc_len: int = 2000, - conn: Optional[sqlite3.Connection] = None, - ) -> List[Dict[str, Any]]: - """按 ngram 倒排索引检索段落,避免 LIKE 全表扫描。""" - uniq = [t for t in dict.fromkeys([str(x).strip().lower() for x in tokens]) if t] - if not uniq: - return [] - - c = self._resolve_conn(conn) - cur = c.cursor() - placeholders = ",".join(["?"] * len(uniq)) - try: - cur.execute( - f""" - SELECT - p.hash, - p.content, - COUNT(*) AS hit_terms - FROM paragraph_ngrams ng - JOIN paragraphs p ON p.hash = ng.paragraph_hash - WHERE ng.term IN ({placeholders}) - AND (p.is_deleted IS NULL OR p.is_deleted = 0) - GROUP BY p.hash, p.content - ORDER BY hit_terms DESC - LIMIT ? 
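-                -- hit_terms counts the distinct query ngrams present in the
-                -- paragraph, so ranking favours query-term coverage.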
- """, - tuple(uniq + [max(1, int(limit))]), - ) - rows = cur.fetchall() - out: List[Dict[str, Any]] = [] - token_count = max(1, len(uniq)) - for row in rows: - hit_terms = int(row["hit_terms"]) - score = float(hit_terms / token_count) - content = str(row["content"] or "") - if max_doc_len > 0: - content = content[:max_doc_len] - out.append( - { - "hash": row["hash"], - "content": content, - "bm25_score": -score, - "fallback_score": score, - } - ) - return out - except sqlite3.OperationalError as e: - logger.warning(f"ngram 倒排查询失败: {e}") - return [] - - def fts_doc_count(self, conn: Optional[sqlite3.Connection] = None) -> int: - """获取 FTS 文档数量。""" - c = self._resolve_conn(conn) - cur = c.cursor() - try: - cur.execute("SELECT COUNT(1) FROM paragraphs_fts") - return int(cur.fetchone()[0]) - except sqlite3.OperationalError: - return 0 - - def shrink_memory(self, conn: Optional[sqlite3.Connection] = None) -> None: - """请求 SQLite 收缩当前连接缓存。""" - c = self._resolve_conn(conn) - try: - c.execute("PRAGMA shrink_memory") - except sqlite3.OperationalError: - pass - - @staticmethod - def _normalize_episode_source(source: Any) -> str: - return str(source or "").strip() - - def _dedupe_episode_sources(self, sources: List[Any]) -> List[str]: - normalized: List[str] = [] - seen = set() - for item in sources or []: - token = self._normalize_episode_source(item) - if not token or token in seen: - continue - seen.add(token) - normalized.append(token) - return normalized - - def _get_sources_for_paragraph_hashes( - self, - hashes: List[str], - *, - include_deleted: bool = True, - ) -> List[str]: - normalized_hashes = [ - str(item or "").strip() - for item in (hashes or []) - if str(item or "").strip() - ] - if not normalized_hashes: - return [] - - placeholders = ",".join(["?"] * len(normalized_hashes)) - conditions = ["hash IN ({})".format(placeholders), "TRIM(COALESCE(source, '')) != ''"] - if not include_deleted: - conditions.append("(is_deleted IS NULL OR is_deleted = 0)") - - cursor = self._conn.cursor() - cursor.execute( - f""" - SELECT DISTINCT TRIM(source) AS source - FROM paragraphs - WHERE {' AND '.join(conditions)} - """, - tuple(normalized_hashes), - ) - return self._dedupe_episode_sources([row["source"] for row in cursor.fetchall()]) - - def _enqueue_episode_source_rebuilds(self, sources: List[Any], reason: str = "") -> int: - normalized_sources = self._dedupe_episode_sources(sources) - if not normalized_sources: - return 0 - - now = datetime.now().timestamp() - reason_text = str(reason or "").strip()[:200] or None - cursor = self._conn.cursor() - cursor.executemany( - """ - INSERT INTO episode_rebuild_sources ( - source, status, retry_count, last_error, reason, requested_at, updated_at - ) VALUES (?, 'pending', 0, NULL, ?, ?, ?) 
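-            -- Upsert: a source already queued in any status is re-armed to
-            -- 'pending' with fresh timestamps rather than duplicated.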
-            ON CONFLICT(source) DO UPDATE SET
-                status = 'pending',
-                last_error = NULL,
-                reason = excluded.reason,
-                requested_at = excluded.requested_at,
-                updated_at = excluded.updated_at
-            """,
-            [
-                (source, reason_text, now, now)
-                for source in normalized_sources
-            ],
-        )
-        self._conn.commit()
-        return len(normalized_sources)
-
-    def add_paragraph(
-        self,
-        content: str,
-        vector_index: Optional[int] = None,
-        source: Optional[str] = None,
-        metadata: Optional[Dict[str, Any]] = None,
-        knowledge_type: str = "mixed",
-        time_meta: Optional[Dict[str, Any]] = None,
-    ) -> str:
-        """
-        Add a paragraph.
-
-        Args:
-            content: Paragraph text
-            vector_index: Vector index
-            source: Source identifier
-            metadata: Extra metadata
-            knowledge_type: Knowledge type (narrative/factual/quote/structured/mixed)
-            time_meta: Temporal metadata (event_time/event_time_start/event_time_end/...)
-
-        Returns:
-            Paragraph hash
-        """
-        content_normalized = normalize_text(content)
-        hash_value = compute_hash(content_normalized)
-        resolved_knowledge_type = validate_stored_knowledge_type(knowledge_type)
-
-        now = datetime.now().timestamp()
-        word_count = len(content_normalized.split())
-        normalized_time = normalize_time_meta(time_meta)
-
-        cursor = self._conn.cursor()
-        try:
-            cursor.execute("""
-                INSERT INTO paragraphs
-                (
-                    hash, content, vector_index, created_at, updated_at, metadata, source, word_count,
-                    event_time, event_time_start, event_time_end, time_granularity, time_confidence,
-                    knowledge_type
-                )
-                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
-            """, (
-                hash_value,
-                content,
-                vector_index,
-                now,
-                now,
-                pickle.dumps(metadata or {}),
-                source,
-                word_count,
-                normalized_time.get("event_time"),
-                normalized_time.get("event_time_start"),
-                normalized_time.get("event_time_end"),
-                normalized_time.get("time_granularity"),
-                normalized_time.get("time_confidence", 1.0),
-                resolved_knowledge_type.value,
-            ))
-            self._conn.commit()
-            try:
-                self.enqueue_episode_source_rebuild(
-                    source=source,
-                    reason="paragraph_added",
-                )
-            except Exception as e:
-                logger.warning(f"Failed to enqueue episode source rebuild: hash={hash_value[:16]}..., err={e}")
-            logger.debug(
-                f"Added paragraph: hash={hash_value[:16]}..., words={word_count}, type={resolved_knowledge_type.value}"
-            )
-            return hash_value
-        except sqlite3.IntegrityError:
-            logger.debug(f"Paragraph already exists: {hash_value[:16]}...")
-            # Try to revive a soft-deleted copy
-            self.revive_if_deleted(paragraph_hashes=[hash_value])
-            return hash_value
-
-    def _canonicalize_name(self, name: str) -> str:
-        """
-        Canonicalize a name (lowercase, stripped of surrounding whitespace).
-
-        Args:
-            name: Raw name
-
-        Returns:
-            Canonicalized name
-        """
-        if not name:
-            return ""
-        return name.strip().lower()
-
-    def add_entity(
-        self,
-        name: str,
-        vector_index: Optional[int] = None,
-        source_paragraph: Optional[str] = None,
-        metadata: Optional[Dict[str, Any]] = None,
-    ) -> str:
-        """
-        Add an entity.
-
-        Args:
-            name: Entity name
-            vector_index: Vector index
-            source_paragraph: Source paragraph hash (a link is created when provided)
-            metadata: Extra metadata
-
-        Returns:
-            Entity hash
-        """
-        # 1. Canonicalize the name
-        name_normalized = self._canonicalize_name(name)
-        if not name_normalized:
-            raise ValueError("Entity name cannot be empty")
-
-        hash_value = compute_hash(name_normalized)
-        now = datetime.now().timestamp()
-
-        cursor = self._conn.cursor()
-
-        # 2. Insert the entity (INSERT OR IGNORE)
-        # Note: the name column keeps the original display name, but the hash
-        # must be derived from the canonical name. When the entity already
-        # exists we deliberately do not update name (keeping the first display
-        # name is usually the better choice). entities.hash is the primary key
-        # and entities.name is UNIQUE, so what happens when two names differ
-        # only in case, or differ but share a canonical form?
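-        # Example: after add_entity("Apple"), a later add_entity("apple")
-        # computes the same canonical hash, falls into the IntegrityError
-        # branch below, bumps appearance_count, and keeps "Apple" as the
-        # stored display name.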
-        # Since the hash is computed from the canonical name, equal hashes
-        # imply equal canonical names: "Apple" already in the db and an
-        # incoming "apple" both canonicalize to "apple", share a hash, and
-        # the INSERT OR IGNORE below becomes a no-op.
-
-        try:
-            cursor.execute("""
-                INSERT INTO entities
-                (hash, name, vector_index, appearance_count, created_at, metadata)
-                VALUES (?, ?, ?, 1, ?, ?)
-            """, (
-                hash_value,
-                name,
-                vector_index,
-                now,
-                pickle.dumps(metadata or {}),
-            ))
-
-            logger.debug(f"Added entity: {name} ({hash_value[:8]})")
-            self._conn.commit()
-
-            # 3. Link to the source paragraph
-            if source_paragraph:
-                self.link_paragraph_entity(source_paragraph, hash_value)
-
-            return hash_value
-
-        except sqlite3.IntegrityError:
-            # Entity already exists
-            # 1. Try to revive it (automatic revival of soft-deleted rows)
-            self.revive_if_deleted(entity_hashes=[hash_value])
-
-            # 2. Update the appearance counter
-            cursor.execute("""
-                UPDATE entities
-                SET appearance_count = appearance_count + 1
-                WHERE hash = ?
-            """, (hash_value,))
-            self._conn.commit()
-
-            logger.debug(f"Entity already exists (revived / count + 1): {name}")
-
-            # 3. Link to the source paragraph
-            if source_paragraph:
-                self.link_paragraph_entity(source_paragraph, hash_value)
-
-            return hash_value
-
-    def add_relation(
-        self,
-        subject: str,
-        predicate: str,
-        obj: str,
-        vector_index: Optional[int] = None,
-        confidence: float = 1.0,
-        source_paragraph: Optional[str] = None,
-        metadata: Optional[Dict[str, Any]] = None,
-    ) -> str:
-        """
-        Add a relation.
-
-        Args:
-            subject: Subject
-            predicate: Predicate
-            obj: Object
-            vector_index: Vector index
-            confidence: Confidence score
-            source_paragraph: Source paragraph hash
-            metadata: Extra metadata
-
-        Returns:
-            Relation hash
-        """
-        # 1. Canonicalize the inputs
-        s_canon = self._canonicalize_name(subject)
-        p_canon = self._canonicalize_name(predicate)
-        o_canon = self._canonicalize_name(obj)
-
-        if not all([s_canon, p_canon, o_canon]):
-            raise ValueError("Relation components cannot be empty")
-
-        # 2. Compute the combined hash
-        # Key: compute_hash("s|p|o") over the canonical forms
-        relation_key = f"{s_canon}|{p_canon}|{o_canon}"
-        hash_value = compute_hash(relation_key)
-
-        now = datetime.now().timestamp()
-
-        # The original display spellings go straight into the subject,
-        # predicate and object columns. If the relation already exists (same
-        # hash), those columns are not updated, so the first spelling wins.
-
-        cursor = self._conn.cursor()
-        try:
-            cursor.execute("""
-                INSERT OR IGNORE INTO relations
-                (hash, subject, predicate, object, vector_index, confidence, created_at, source_paragraph, metadata)
-                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
-            """, (
-                hash_value,
-                subject,  # original spelling
-                predicate,
-                obj,
-                vector_index,
-                confidence,
-                now,
-                source_paragraph,  # recorded only as the "first sighting"; may be None
-                pickle.dumps(metadata or {}),
-            ))
-            self._conn.commit()
-
-            if cursor.rowcount > 0:
-                logger.debug(f"Added relation: {subject} -{predicate}-> {obj}")
-            else:
-                logger.debug(f"Relation already exists: {subject} -{predicate}-> {obj}")
-
-            # 3. Link to the source paragraph (idempotent)
-            # Whether the relation is new or pre-existing, a provided
-            # source_paragraph must always be linked.
-            if source_paragraph:
-                self.link_paragraph_relation(source_paragraph, hash_value)
-
-            return hash_value
-
-        except sqlite3.IntegrityError as e:
-            logger.warning(f"Error while adding relation: {e}")
-            return hash_value
-
-    def link_paragraph_relation(
-        self,
-        paragraph_hash: str,
-        relation_hash: str,
-    ) -> bool:
-        """
-        Link a paragraph to a relation (idempotent).
-        """
-        cursor = self._conn.cursor()
-        try:
-            # INSERT OR IGNORE avoids duplicate-key errors
-            cursor.execute("""
-                INSERT OR IGNORE INTO paragraph_relations
-                (paragraph_hash, relation_hash)
-                VALUES (?, ?)
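-                -- OR IGNORE keeps the link idempotent: re-linking an existing
-                -- (paragraph, relation) pair is a silent no-op.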
- """, (paragraph_hash, relation_hash)) - self._conn.commit() - self._enqueue_episode_source_rebuilds( - self._get_sources_for_paragraph_hashes([paragraph_hash], include_deleted=True), - reason="paragraph_relation_linked", - ) - return True - except sqlite3.IntegrityError: - return False - - def link_paragraph_entity( - self, - paragraph_hash: str, - entity_hash: str, - mention_count: int = 1, - ) -> bool: - """ - 关联段落和实体 (幂等) - """ - cursor = self._conn.cursor() - try: - # 首先尝试插入 - cursor.execute(""" - INSERT OR IGNORE INTO paragraph_entities - (paragraph_hash, entity_hash, mention_count) - VALUES (?, ?, ?) - """, (paragraph_hash, entity_hash, mention_count)) - - if cursor.rowcount == 0: - # 如果已存在 (IGNORE生效),则更新计数 - cursor.execute(""" - UPDATE paragraph_entities - SET mention_count = mention_count + ? - WHERE paragraph_hash = ? AND entity_hash = ? - """, (mention_count, paragraph_hash, entity_hash)) - - self._conn.commit() - self._enqueue_episode_source_rebuilds( - self._get_sources_for_paragraph_hashes([paragraph_hash], include_deleted=True), - reason="paragraph_entity_linked", - ) - return True - except sqlite3.IntegrityError: - return False - - def get_paragraph(self, hash_value: str) -> Optional[Dict[str, Any]]: - """ - 获取段落 - - Args: - hash_value: 段落哈希 - - Returns: - 段落信息字典,不存在则返回None - """ - cursor = self._conn.cursor() - cursor.execute(""" - SELECT * FROM paragraphs WHERE hash = ? - """, (hash_value,)) - row = cursor.fetchone() - - if row: - return self._row_to_dict(row, "paragraph") - return None - - def update_paragraph_time_meta( - self, - paragraph_hash: str, - time_meta: Dict[str, Any], - ) -> bool: - """ - 更新段落时间元信息。 - """ - normalized = normalize_time_meta(time_meta) - if not normalized: - return False - source_to_rebuild = self._get_sources_for_paragraph_hashes( - [paragraph_hash], - include_deleted=True, - ) - - updates: List[str] = [] - params: List[Any] = [] - for key in [ - "event_time", - "event_time_start", - "event_time_end", - "time_granularity", - "time_confidence", - ]: - if key in normalized: - updates.append(f"{key} = ?") - params.append(normalized[key]) - - if not updates: - return False - - updates.append("updated_at = ?") - params.append(datetime.now().timestamp()) - params.append(paragraph_hash) - - cursor = self._conn.cursor() - cursor.execute( - f"UPDATE paragraphs SET {', '.join(updates)} WHERE hash = ?", - tuple(params), - ) - self._conn.commit() - changed = cursor.rowcount > 0 - if changed: - self._enqueue_episode_source_rebuilds( - source_to_rebuild, - reason="paragraph_time_updated", - ) - return changed - - def query_paragraphs_temporal( - self, - start_ts: Optional[float] = None, - end_ts: Optional[float] = None, - person: Optional[str] = None, - source: Optional[str] = None, - limit: int = 100, - allow_created_fallback: bool = True, - ) -> List[Dict[str, Any]]: - """ - 查询时序命中的段落(区间相交语义)。 - """ - if limit <= 0: - return [] - - effective_start = "COALESCE(p.event_time_start, p.event_time, p.event_time_end" - effective_end = "COALESCE(p.event_time_end, p.event_time, p.event_time_start" - if allow_created_fallback: - effective_start += ", p.created_at)" - effective_end += ", p.created_at)" - else: - effective_start += ")" - effective_end += ")" - - conditions = ["(p.is_deleted IS NULL OR p.is_deleted = 0)"] - params: List[Any] = [] - - if source: - conditions.append("p.source = ?") - params.append(source) - - if person: - conditions.append( - """ - EXISTS ( - SELECT 1 - FROM paragraph_entities pe - JOIN entities e ON e.hash = pe.entity_hash - WHERE 
pe.paragraph_hash = p.hash - AND LOWER(e.name) LIKE ? - ) - """ - ) - params.append(f"%{str(person).strip().lower()}%") - - if start_ts is not None and end_ts is not None: - conditions.append(f"({effective_end} >= ? AND {effective_start} <= ?)") - params.extend([start_ts, end_ts]) - elif start_ts is not None: - conditions.append(f"({effective_end} >= ?)") - params.append(start_ts) - elif end_ts is not None: - conditions.append(f"({effective_start} <= ?)") - params.append(end_ts) - - where_sql = " AND ".join(conditions) - sql = f""" - SELECT p.* - FROM paragraphs p - WHERE {where_sql} - ORDER BY {effective_end} DESC, p.updated_at DESC - LIMIT ? - """ - params.append(limit) - - cursor = self._conn.cursor() - cursor.execute(sql, tuple(params)) - return [self._row_to_dict(row, "paragraph") for row in cursor.fetchall()] - - def get_entity(self, hash_value: str) -> Optional[Dict[str, Any]]: - """ - 获取实体 - - Args: - hash_value: 实体哈希 - - Returns: - 实体信息字典,不存在则返回None - """ - cursor = self._conn.cursor() - cursor.execute(""" - SELECT * FROM entities WHERE hash = ? - """, (hash_value,)) - row = cursor.fetchone() - - if row: - return self._row_to_dict(row, "entity") - return None - - def get_relation(self, hash_value: str) -> Optional[Dict[str, Any]]: - """ - 获取关系 - - Args: - hash_value: 关系哈希 - - Returns: - 关系信息字典,不存在则返回None - """ - cursor = self._conn.cursor() - cursor.execute(""" - SELECT * FROM relations WHERE hash = ? - """, (hash_value,)) - row = cursor.fetchone() - - if row: - return self._row_to_dict(row, "relation") - return None - - def get_paragraph_relations(self, paragraph_hash: str) -> List[Dict[str, Any]]: - """ - 获取段落的所有关系 - - Args: - paragraph_hash: 段落哈希 - - Returns: - 关系列表 - """ - cursor = self._conn.cursor() - cursor.execute(""" - SELECT r.* FROM relations r - JOIN paragraph_relations pr ON r.hash = pr.relation_hash - WHERE pr.paragraph_hash = ? - """, (paragraph_hash,)) - - return [self._row_to_dict(row, "relation") for row in cursor.fetchall()] - - def get_paragraph_entities(self, paragraph_hash: str) -> List[Dict[str, Any]]: - """ - 获取段落的所有实体 - - Args: - paragraph_hash: 段落哈希 - - Returns: - 实体列表 - """ - cursor = self._conn.cursor() - cursor.execute(""" - SELECT e.*, pe.mention_count - FROM entities e - JOIN paragraph_entities pe ON e.hash = pe.entity_hash - WHERE pe.paragraph_hash = ? - """, (paragraph_hash,)) - - return [self._row_to_dict(row, "entity") for row in cursor.fetchall()] - - def get_paragraphs_by_entity(self, entity_name: str) -> List[Dict[str, Any]]: - """ - 获取包含指定实体的所有段落 (自动处理规范化) - - Args: - entity_name: 实体名称 (支持任意大小写) - - Returns: - 段落列表 - """ - # 1. 计算规范化 Hash - name_canon = self._canonicalize_name(entity_name) - if not name_canon: - return [] - - entity_hash = compute_hash(name_canon) - - cursor = self._conn.cursor() - # 2. 直接使用 Hash 查询中间表,完全避开 Name 匹配 - cursor.execute(""" - SELECT p.* - FROM paragraphs p - JOIN paragraph_entities pe ON p.hash = pe.paragraph_hash - WHERE pe.entity_hash = ? 
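-            -- Joining on the canonical-name hash sidesteps any case or
-            -- whitespace mismatch that a LIKE on entities.name would invite.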
- """, (entity_hash,)) - - return [self._row_to_dict(row, "paragraph") for row in cursor.fetchall()] - - def get_relations( - self, - subject: Optional[str] = None, - predicate: Optional[str] = None, - object: Optional[str] = None, - ) -> List[Dict[str, Any]]: - """ - 查询关系(大小写不敏感) - - Args: - subject: 主语(可选) - predicate: 谓语(可选) - object: 宾语(可选) - - Returns: - 关系列表 - """ - # 构建查询条件 - conditions = [] - params = [] - - if subject: - conditions.append("LOWER(subject) = ?") - params.append(self._canonicalize_name(subject)) - if predicate: - conditions.append("LOWER(predicate) = ?") - params.append(self._canonicalize_name(predicate)) - if object: - conditions.append("LOWER(object) = ?") - params.append(self._canonicalize_name(object)) - - sql = "SELECT * FROM relations" - if conditions: - sql += " WHERE " + " AND ".join(conditions) - - cursor = self._conn.cursor() - cursor.execute(sql, tuple(params)) - - return [self._row_to_dict(row, "relation") for row in cursor.fetchall()] - - def get_all_triples(self) -> List[Tuple[str, str, str, str]]: - """ - 高效获取所有三元组 (subject, predicate, object, hash) - 直接返回元组,跳过字典转换和pickle反序列化,用于构建 V5 Map 缓存。 - """ - cursor = self._conn.cursor() - cursor.execute("SELECT subject, predicate, object, hash FROM relations") - return list(cursor.fetchall()) - - def get_paragraphs_by_relation(self, relation_hash: str) -> List[Dict[str, Any]]: - """ - 获取支持指定关系的所有段落 - - Args: - relation_hash: 关系哈希 - - Returns: - 段落列表 - """ - cursor = self._conn.cursor() - cursor.execute(""" - SELECT p.* - FROM paragraphs p - JOIN paragraph_relations pr ON p.hash = pr.paragraph_hash - WHERE pr.relation_hash = ? - """, (relation_hash,)) - - return [self._row_to_dict(row, "paragraph") for row in cursor.fetchall()] - - def get_paragraphs_by_source(self, source: str) -> List[Dict[str, Any]]: - """ - 按来源获取段落 - - Args: - source: 来源标识符 - - Returns: - 段落列表 - """ - return self.query("SELECT * FROM paragraphs WHERE source = ?", (source,)) - - def get_all_sources(self) -> List[Dict[str, Any]]: - """ - 获取所有来源文件统计信息 - - Returns: - 来源列表 [{'source': 'name', 'count': int, 'last_updated': timestamp}] - """ - cursor = self._conn.cursor() - # 排除 source 为 NULL 或空的记录 - cursor.execute(""" - SELECT source, COUNT(*) as count, MAX(created_at) as last_updated - FROM paragraphs - WHERE source IS NOT NULL AND source != '' - GROUP BY source - ORDER BY last_updated DESC - """) - - results = [] - for row in cursor.fetchall(): - results.append({ - "source": row[0], - "count": row[1], - "last_updated": row[2] - }) - return results - - - def search_paragraphs_by_content(self, content_query: str) -> List[Dict[str, Any]]: - """按内容模糊搜索段落""" - cursor = self._conn.cursor() - cursor.execute(""" - SELECT * FROM paragraphs WHERE content LIKE ? - """, (f"%{content_query}%",)) - return [self._row_to_dict(row, "paragraph") for row in cursor.fetchall()] - - def delete_paragraph(self, hash_value: str) -> bool: - """ - 删除段落(级联删除相关关联) - - Args: - hash_value: 段落哈希 - - Returns: - 是否成功删除 - """ - cursor = self._conn.cursor() - cursor.execute(""" - DELETE FROM paragraphs WHERE hash = ? - """, (hash_value,)) - self._conn.commit() - - deleted = cursor.rowcount > 0 - if deleted: - logger.info(f"删除段落: {hash_value[:16]}...") - - return deleted - - def delete_entity(self, hash_or_name: str) -> bool: - """ - 删除实体(级联删除相关关联) - 支持通过哈希值或名称删除 - - 注意:会同时删除所有引用该实体(作为主语或宾语)的关系 - """ - cursor = self._conn.cursor() - - # 1. 
Resolve the entity (name and hash)
-        entity_name = None
-        entity_hash = None
-
-        # Try the input as a hash first
-        cursor.execute("SELECT name, hash FROM entities WHERE hash = ?", (hash_or_name,))
-        row = cursor.fetchone()
-        if row:
-            entity_name = row[0]
-            entity_hash = row[1]
-        else:
-            # Then try it as a name (exact match)
-            cursor.execute("SELECT name, hash FROM entities WHERE name = ?", (hash_or_name,))
-            row = cursor.fetchone()
-            if row:
-                entity_name = row[0]
-                entity_hash = row[1]
-            else:
-                # Last resort: look up by the canonical name, which covers
-                # case differences and manual input from the WebUI
-                name_canon = self._canonicalize_name(hash_or_name)
-                canon_hash = compute_hash(name_canon)
-                cursor.execute("SELECT name, hash FROM entities WHERE hash = ?", (canon_hash,))
-                row = cursor.fetchone()
-                if row:
-                    entity_name = row[0]
-                    entity_hash = row[1]
-
-        if not entity_name or not entity_hash:
-            logger.debug(f"Entity delete request skipped: {hash_or_name} not found in metadata")
-            return False
-
-        logger.info(f"Deleting entity: {entity_name} (hash: {entity_hash[:8]}...)")
-
-        try:
-            # 2. Find related relations (entity appears as subject or object)
-            cursor.execute("""
-                SELECT hash FROM relations
-                WHERE subject = ? OR object = ?
-            """, (entity_name, entity_name))
-
-            relation_hashes = [r[0] for r in cursor.fetchall()]
-
-            if relation_hashes:
-                logger.info(f"Found {len(relation_hashes)} related relations; cascading delete")
-
-                # 3. Unlink these relations from their paragraphs.
-                # SQLite cannot bind a Python list to a single IN (...)
-                # parameter, so build one placeholder per element.
-                placeholders = ','.join(['?'] * len(relation_hashes))
-
-                cursor.execute(f"""
-                    DELETE FROM paragraph_relations
-                    WHERE relation_hash IN ({placeholders})
-                """, relation_hashes)
-
-                # 4. Delete the relations themselves
-                cursor.execute(f"""
-                    DELETE FROM relations
-                    WHERE hash IN ({placeholders})
-                """, relation_hashes)
-
-                logger.info("Related relations deleted (cascade)")
-
-            # 5. Unlink the entity from its paragraphs
-            cursor.execute("DELETE FROM paragraph_entities WHERE entity_hash = ?", (entity_hash,))
-
-            # 6. Delete the entity itself
-            cursor.execute("DELETE FROM entities WHERE hash = ?", (entity_hash,))
-
-            self._conn.commit()
-            logger.info("Entity deletion complete")
-            return True
-
-        except Exception as e:
-            logger.error(f"Error while deleting entity: {e}")
-            self._conn.rollback()
-            return False
-
-    def delete_relation(self, hash_value: str) -> bool:
-        """
-        Delete a relation (cascades to its links).
-
-        Args:
-            hash_value: Relation hash
-
-        Returns:
-            Whether a row was deleted
-        """
-        cursor = self._conn.cursor()
-        cursor.execute("""
-            DELETE FROM relations WHERE hash = ?
-        """, (hash_value,))
-        self._conn.commit()
-
-        deleted = cursor.rowcount > 0
-        if deleted:
-            logger.info(f"Deleted relation: {hash_value[:16]}...")
-
-        return deleted
-
-    def set_relation_vector_state(
-        self,
-        hash_value: str,
-        state: str,
-        error: Optional[str] = None,
-        bump_retry: bool = False,
-    ) -> bool:
-        """
-        Update a relation's vector state.
-        """
-        state_norm = str(state or "").strip().lower()
-        if state_norm not in {"none", "pending", "ready", "failed"}:
-            raise ValueError(f"Invalid vector_state: {state}")
-
-        now = datetime.now().timestamp()
-        err_text = (str(error).strip() if error is not None else None)
-        if err_text:
-            err_text = err_text[:500]
-        clear_error = state_norm in {"none", "pending", "ready"}
-
-        cursor = self._conn.cursor()
-        # COALESCE below guards rows migrated before vector_retry_count existed
-        if bump_retry:
-            cursor.execute(
-                """
-                UPDATE relations
-                SET vector_state = ?,
-                    vector_updated_at = ?,
-                    vector_error = ?,
-                    vector_retry_count = COALESCE(vector_retry_count, 0) + 1
-                WHERE hash = ?
-                """,
-                (state_norm, now, None if clear_error else err_text, hash_value),
-            )
-        else:
-            cursor.execute(
-                """
-                UPDATE relations
-                SET vector_state = ?,
-                    vector_updated_at = ?,
-                    vector_error = ?
-                WHERE hash = ?
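-                -- Plain state transition: unlike the bump_retry branch above,
-                -- vector_retry_count is left untouched here.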
- """, - (state_norm, now, None if clear_error else err_text, hash_value), - ) - self._conn.commit() - return cursor.rowcount > 0 - - def list_relations_by_vector_state( - self, - states: List[str], - limit: int = 200, - max_retry: Optional[int] = None, - ) -> List[Dict[str, Any]]: - """ - 根据向量状态列出关系,用于回填任务。 - """ - normalized_states = [ - str(s or "").strip().lower() - for s in (states or []) - if str(s or "").strip() - ] - normalized_states = [ - s for s in normalized_states - if s in {"none", "pending", "ready", "failed"} - ] - if not normalized_states: - return [] - - placeholders = ",".join(["?"] * len(normalized_states)) - params: List[Any] = list(normalized_states) - sql = f""" - SELECT hash, subject, predicate, object, confidence, source_paragraph, - vector_state, vector_updated_at, vector_error, vector_retry_count, created_at - FROM relations - WHERE vector_state IN ({placeholders}) - """ - if max_retry is not None: - sql += " AND COALESCE(vector_retry_count, 0) < ?" - params.append(int(max_retry)) - sql += " ORDER BY COALESCE(vector_updated_at, created_at, 0) ASC LIMIT ?" - params.append(max(1, int(limit))) - - cursor = self._conn.cursor() - cursor.execute(sql, tuple(params)) - return [self._row_to_dict(row, "relation") for row in cursor.fetchall()] - - def count_relations_by_vector_state(self) -> Dict[str, int]: - """ - 统计关系向量状态分布。 - """ - cursor = self._conn.cursor() - cursor.execute( - """ - SELECT COALESCE(vector_state, 'none') AS state, COUNT(*) AS cnt - FROM relations - GROUP BY COALESCE(vector_state, 'none') - """ - ) - result: Dict[str, int] = {"none": 0, "pending": 0, "ready": 0, "failed": 0} - total = 0 - for row in cursor.fetchall(): - state = str(row["state"] or "none").lower() - count = int(row["cnt"] or 0) - if state not in result: - result[state] = 0 - result[state] += count - total += count - result["total"] = total - return result - - def update_vector_index( - self, - item_type: str, - hash_value: str, - vector_index: int, - ) -> bool: - """ - 更新向量索引 - - Args: - item_type: 类型(paragraph/entity/relation) - hash_value: 哈希值 - vector_index: 向量索引 - - Returns: - 是否成功更新 - """ - valid_types = ["paragraph", "entity", "relation"] - if item_type not in valid_types: - raise ValueError(f"无效的类型: {item_type}") - - table_map = { - "paragraph": "paragraphs", - "entity": "entities", - "relation": "relations", - } - - cursor = self._conn.cursor() - cursor.execute(f""" - UPDATE {table_map[item_type]} - SET vector_index = ? - WHERE hash = ? - """, (vector_index, hash_value)) - self._conn.commit() - - return cursor.rowcount > 0 - - def set_permanence(self, hash_value: str, item_type: str, is_permanent: bool) -> bool: - """设置永久记忆标记""" - table_map = { - "paragraph": "paragraphs", - "relation": "relations", - } - if item_type not in table_map: - raise ValueError(f"类型 {item_type} 不支持设置永久性") - - cursor = self._conn.cursor() - cursor.execute(f""" - UPDATE {table_map[item_type]} - SET is_permanent = ? - WHERE hash = ? 
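-            -- SQLite has no native boolean; is_permanent is stored as 0/1
-            -- (the caller binds 1 if is_permanent else 0).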
- """, (1 if is_permanent else 0, hash_value)) - self._conn.commit() - - if cursor.rowcount > 0: - logger.debug(f"设置永久记忆: {item_type}/{hash_value[:8]} -> {is_permanent}") - return True - return False - - def record_access(self, hash_value: str, item_type: str) -> bool: - """记录访问(更新时间和次数)""" - table_map = { - "paragraph": "paragraphs", - "relation": "relations", - } - if item_type not in table_map: - return False - - now = datetime.now().timestamp() - cursor = self._conn.cursor() - cursor.execute(f""" - UPDATE {table_map[item_type]} - SET last_accessed = ?, access_count = access_count + 1 - WHERE hash = ? - """, (now, hash_value)) - self._conn.commit() - return cursor.rowcount > 0 - - def query( - self, - sql: str, - params: Optional[Tuple] = None, - ) -> List[Dict[str, Any]]: - """ - 执行自定义查询 - - Args: - sql: SQL语句 - params: 参数 - - Returns: - 查询结果列表 - """ - cursor = self._conn.cursor() - if params: - cursor.execute(sql, params) - else: - cursor.execute(sql) - - return [dict(row) for row in cursor.fetchall()] - - def get_external_memory_ref(self, external_id: str) -> Optional[Dict[str, Any]]: - """按 external_id 查询外部记忆映射。""" - token = str(external_id or "").strip() - if not token: - return None - - cursor = self._conn.cursor() - cursor.execute( - """ - SELECT external_id, paragraph_hash, source_type, created_at, metadata_json - FROM external_memory_refs - WHERE external_id = ? - LIMIT 1 - """, - (token,), - ) - row = cursor.fetchone() - if row is None: - return None - - payload = dict(row) - raw_metadata = payload.get("metadata_json") - if raw_metadata: - try: - payload["metadata"] = json.loads(raw_metadata) - except Exception: - payload["metadata"] = {} - else: - payload["metadata"] = {} - return payload - - def upsert_external_memory_ref( - self, - *, - external_id: str, - paragraph_hash: str, - source_type: str = "", - metadata: Optional[Dict[str, Any]] = None, - ) -> Dict[str, Any]: - """注册 external_id 到段落哈希的幂等映射。""" - external_token = str(external_id or "").strip() - paragraph_token = str(paragraph_hash or "").strip() - if not external_token: - raise ValueError("external_id 不能为空") - if not paragraph_token: - raise ValueError("paragraph_hash 不能为空") - - now = datetime.now().timestamp() - metadata_json = json.dumps(metadata or {}, ensure_ascii=False) - cursor = self._conn.cursor() - cursor.execute( - """ - INSERT INTO external_memory_refs ( - external_id, paragraph_hash, source_type, created_at, metadata_json - ) - VALUES (?, ?, ?, ?, ?) 
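-            -- Upsert keyed on external_id; created_at is not listed in the
-            -- update clause below, so the first-seen timestamp is preserved.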
- ON CONFLICT(external_id) DO UPDATE SET - paragraph_hash = excluded.paragraph_hash, - source_type = excluded.source_type, - metadata_json = excluded.metadata_json - """, - ( - external_token, - paragraph_token, - str(source_type or "").strip() or None, - now, - metadata_json, - ), - ) - self._conn.commit() - return self.get_external_memory_ref(external_token) or { - "external_id": external_token, - "paragraph_hash": paragraph_token, - "source_type": str(source_type or "").strip(), - "created_at": now, - "metadata": metadata or {}, - } - - @staticmethod - def _json_dumps(value: Any) -> str: - return json.dumps(value, ensure_ascii=False, sort_keys=True) - - @staticmethod - def _json_loads(value: Any, default: Any) -> Any: - if value in {None, ""}: - return default - try: - return json.loads(value) - except Exception: - return default - - def list_external_memory_refs_by_paragraphs(self, paragraph_hashes: List[str]) -> List[Dict[str, Any]]: - hashes = [str(item or "").strip() for item in (paragraph_hashes or []) if str(item or "").strip()] - if not hashes: - return [] - placeholders = ",".join(["?"] * len(hashes)) - cursor = self._conn.cursor() - cursor.execute( - f""" - SELECT external_id, paragraph_hash, source_type, created_at, metadata_json - FROM external_memory_refs - WHERE paragraph_hash IN ({placeholders}) - ORDER BY created_at ASC, external_id ASC - """, - tuple(hashes), - ) - items: List[Dict[str, Any]] = [] - for row in cursor.fetchall(): - payload = dict(row) - payload["metadata"] = self._json_loads(payload.get("metadata_json"), {}) - items.append(payload) - return items - - def delete_external_memory_refs_by_paragraphs(self, paragraph_hashes: List[str]) -> List[Dict[str, Any]]: - items = self.list_external_memory_refs_by_paragraphs(paragraph_hashes) - hashes = [str(item or "").strip() for item in (paragraph_hashes or []) if str(item or "").strip()] - if not hashes: - return items - placeholders = ",".join(["?"] * len(hashes)) - cursor = self._conn.cursor() - cursor.execute( - f"DELETE FROM external_memory_refs WHERE paragraph_hash IN ({placeholders})", - tuple(hashes), - ) - self._conn.commit() - return items - - def restore_external_memory_refs(self, refs: List[Dict[str, Any]]) -> int: - count = 0 - for item in refs or []: - external_id = str(item.get("external_id", "") or "").strip() - paragraph_hash = str(item.get("paragraph_hash", "") or "").strip() - if not external_id or not paragraph_hash: - continue - created_at = float(item.get("created_at") or datetime.now().timestamp()) - metadata_json = self._json_dumps(item.get("metadata") or {}) - cursor = self._conn.cursor() - cursor.execute( - """ - INSERT INTO external_memory_refs ( - external_id, paragraph_hash, source_type, created_at, metadata_json - ) - VALUES (?, ?, ?, ?, ?) 
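-                -- Restore variant: unlike the upsert above, this one also
-                -- overwrites created_at with the snapshotted timestamp.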
- ON CONFLICT(external_id) DO UPDATE SET - paragraph_hash = excluded.paragraph_hash, - source_type = excluded.source_type, - created_at = excluded.created_at, - metadata_json = excluded.metadata_json - """, - ( - external_id, - paragraph_hash, - str(item.get("source_type", "") or "").strip() or None, - created_at, - metadata_json, - ), - ) - count += max(0, int(cursor.rowcount or 0)) - self._conn.commit() - return count - - def record_v5_operation( - self, - *, - action: str, - target: str, - resolved_hashes: List[str], - reason: str = "", - updated_by: str = "", - result: Optional[Dict[str, Any]] = None, - ) -> Dict[str, Any]: - operation_id = f"v5_{uuid.uuid4().hex}" - created_at = datetime.now().timestamp() - payload = { - "operation_id": operation_id, - "action": str(action or "").strip(), - "target": str(target or "").strip(), - "reason": str(reason or "").strip(), - "updated_by": str(updated_by or "").strip(), - "created_at": created_at, - "resolved_hashes": [str(item or "").strip() for item in (resolved_hashes or []) if str(item or "").strip()], - "result": result or {}, - } - cursor = self._conn.cursor() - cursor.execute( - """ - INSERT INTO memory_v5_operations ( - operation_id, action, target, reason, updated_by, created_at, resolved_hashes_json, result_json - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) - """, - ( - operation_id, - payload["action"], - payload["target"] or None, - payload["reason"] or None, - payload["updated_by"] or None, - created_at, - self._json_dumps(payload["resolved_hashes"]), - self._json_dumps(payload["result"]), - ), - ) - self._conn.commit() - return payload - - def create_delete_operation( - self, - *, - mode: str, - selector: Any, - items: List[Dict[str, Any]], - reason: str = "", - requested_by: str = "", - status: str = "executed", - summary: Optional[Dict[str, Any]] = None, - operation_id: Optional[str] = None, - ) -> Dict[str, Any]: - op_id = str(operation_id or f"del_{uuid.uuid4().hex}").strip() - created_at = datetime.now().timestamp() - normalized_items: List[Dict[str, Any]] = [] - for item in items or []: - if not isinstance(item, dict): - continue - item_type = str(item.get("item_type", "") or "").strip() - if not item_type: - continue - normalized_items.append( - { - "item_type": item_type, - "item_hash": str(item.get("item_hash", "") or "").strip() or None, - "item_key": str(item.get("item_key", "") or item.get("item_hash", "") or "").strip() or None, - "payload": item.get("payload") if isinstance(item.get("payload"), dict) else {}, - } - ) - - cursor = self._conn.cursor() - cursor.execute( - """ - INSERT INTO delete_operations ( - operation_id, mode, selector, reason, requested_by, status, created_at, restored_at, summary_json - ) VALUES (?, ?, ?, ?, ?, ?, ?, NULL, ?) - """, - ( - op_id, - str(mode or "").strip(), - self._json_dumps(selector if selector is not None else {}), - str(reason or "").strip() or None, - str(requested_by or "").strip() or None, - str(status or "executed").strip(), - created_at, - self._json_dumps(summary or {}), - ), - ) - if normalized_items: - cursor.executemany( - """ - INSERT INTO delete_operation_items ( - operation_id, item_type, item_hash, item_key, payload_json, created_at - ) VALUES (?, ?, ?, ?, ?, ?) 
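-                    -- One row per tombstoned item; payload_json holds the
-                    -- per-item snapshot kept for a later restore.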
- """, - [ - ( - op_id, - item["item_type"], - item["item_hash"], - item["item_key"], - self._json_dumps(item["payload"]), - created_at, - ) - for item in normalized_items - ], - ) - self._conn.commit() - return self.get_delete_operation(op_id) or { - "operation_id": op_id, - "mode": str(mode or "").strip(), - "selector": selector, - "reason": str(reason or "").strip(), - "requested_by": str(requested_by or "").strip(), - "status": str(status or "executed").strip(), - "created_at": created_at, - "summary": summary or {}, - "items": normalized_items, - } - - def mark_delete_operation_restored( - self, - operation_id: str, - *, - summary: Optional[Dict[str, Any]] = None, - ) -> bool: - token = str(operation_id or "").strip() - if not token: - return False - cursor = self._conn.cursor() - cursor.execute( - """ - UPDATE delete_operations - SET status = ?, restored_at = ?, summary_json = ? - WHERE operation_id = ? - """, - ( - "restored", - datetime.now().timestamp(), - self._json_dumps(summary or {}), - token, - ), - ) - self._conn.commit() - return cursor.rowcount > 0 - - def list_delete_operations(self, *, limit: int = 50, mode: str = "") -> List[Dict[str, Any]]: - cursor = self._conn.cursor() - params: List[Any] = [] - where = "" - mode_token = str(mode or "").strip().lower() - if mode_token: - where = "WHERE LOWER(mode) = ?" - params.append(mode_token) - params.append(max(1, int(limit or 50))) - cursor.execute( - f""" - SELECT operation_id, mode, selector, reason, requested_by, status, created_at, restored_at, summary_json - FROM delete_operations - {where} - ORDER BY created_at DESC - LIMIT ? - """, - tuple(params), - ) - items: List[Dict[str, Any]] = [] - for row in cursor.fetchall(): - payload = dict(row) - payload["selector"] = self._json_loads(payload.get("selector"), {}) - payload["summary"] = self._json_loads(payload.get("summary_json"), {}) - items.append(payload) - return items - - def get_delete_operation(self, operation_id: str) -> Optional[Dict[str, Any]]: - token = str(operation_id or "").strip() - if not token: - return None - cursor = self._conn.cursor() - cursor.execute( - """ - SELECT operation_id, mode, selector, reason, requested_by, status, created_at, restored_at, summary_json - FROM delete_operations - WHERE operation_id = ? - LIMIT 1 - """, - (token,), - ) - row = cursor.fetchone() - if row is None: - return None - - payload = dict(row) - payload["selector"] = self._json_loads(payload.get("selector"), {}) - payload["summary"] = self._json_loads(payload.get("summary_json"), {}) - - cursor.execute( - """ - SELECT item_type, item_hash, item_key, payload_json, created_at - FROM delete_operation_items - WHERE operation_id = ? - ORDER BY id ASC - """, - (token,), - ) - payload["items"] = [ - { - "item_type": str(item["item_type"] or ""), - "item_hash": str(item["item_hash"] or ""), - "item_key": str(item["item_key"] or ""), - "payload": self._json_loads(item["payload_json"], {}), - "created_at": item["created_at"], - } - for item in cursor.fetchall() - ] - return payload - - def purge_deleted_relations(self, *, cutoff_time: float, limit: int = 1000) -> List[str]: - cursor = self._conn.cursor() - cursor.execute( - """ - SELECT hash - FROM deleted_relations - WHERE deleted_at IS NOT NULL AND deleted_at < ? - ORDER BY deleted_at ASC - LIMIT ? 
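-            -- Oldest deletions first, so repeated batched calls drain the
-            -- recycle bin in chronological order.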
- """, - (float(cutoff_time), max(1, int(limit or 1000))), - ) - hashes = [str(row[0] or "").strip() for row in cursor.fetchall() if str(row[0] or "").strip()] - if not hashes: - return [] - placeholders = ",".join(["?"] * len(hashes)) - cursor.execute(f"DELETE FROM deleted_relations WHERE hash IN ({placeholders})", tuple(hashes)) - self._conn.commit() - return hashes - - def get_statistics(self) -> Dict[str, int]: - """ - 获取统计信息 - - Returns: - 统计信息字典 - """ - cursor = self._conn.cursor() - - stats = {} - - # 段落数量 - cursor.execute("SELECT COUNT(*) FROM paragraphs") - stats["paragraph_count"] = cursor.fetchone()[0] - - # 实体数量 - cursor.execute("SELECT COUNT(*) FROM entities") - stats["entity_count"] = cursor.fetchone()[0] - - # 关系数量 - cursor.execute("SELECT COUNT(*) FROM relations") - stats["relation_count"] = cursor.fetchone()[0] - - # 总词数 - cursor.execute("SELECT SUM(word_count) FROM paragraphs") - result = cursor.fetchone()[0] - stats["total_words"] = result if result else 0 - - return stats - - def count_paragraphs(self, include_deleted: bool = False, only_deleted: bool = False) -> int: - """ - 获取段落数量 - """ - cursor = self._conn.cursor() - if only_deleted: - cursor.execute("SELECT COUNT(*) FROM paragraphs WHERE is_deleted = 1") - return cursor.fetchone()[0] - if include_deleted: - cursor.execute("SELECT COUNT(*) FROM paragraphs") - return cursor.fetchone()[0] - cursor.execute("SELECT COUNT(*) FROM paragraphs WHERE is_deleted = 0") - return cursor.fetchone()[0] - - def count_relations(self, include_deleted: bool = False, only_deleted: bool = False) -> int: - """ - 获取关系数量 - """ - cursor = self._conn.cursor() - if only_deleted: - cursor.execute("SELECT COUNT(*) FROM deleted_relations") - return cursor.fetchone()[0] - cursor.execute("SELECT COUNT(*) FROM relations") - active_count = cursor.fetchone()[0] - if not include_deleted: - return active_count - cursor.execute("SELECT COUNT(*) FROM deleted_relations") - deleted_count = cursor.fetchone()[0] - return int(active_count) + int(deleted_count) - - def count_entities(self) -> int: - """ - 获取实体数量 - - Returns: - 实体数量 - """ - cursor = self._conn.cursor() - cursor.execute("SELECT COUNT(*) FROM entities") - return cursor.fetchone()[0] - - def get_knowledge_type_distribution(self) -> Dict[str, int]: - """获取段落知识类型分布。""" - cursor = self._conn.cursor() - cursor.execute( - """ - SELECT knowledge_type, COUNT(*) as count - FROM paragraphs - WHERE is_deleted = 0 - GROUP BY knowledge_type - """ - ) - result: Dict[str, int] = {} - for row in cursor.fetchall(): - type_name = row[0] if row[0] else "未分类" - result[str(type_name)] = int(row[1] or 0) - return result - - def get_memory_status_summary(self, now_ts: Optional[float] = None) -> Dict[str, int]: - """聚合 memory status 统计。""" - now_ts = float(now_ts) if now_ts is not None else datetime.now().timestamp() - cursor = self._conn.cursor() - cursor.execute("SELECT COUNT(*) FROM relations WHERE is_inactive = 0") - active_count = int(cursor.fetchone()[0] or 0) - cursor.execute("SELECT COUNT(*) FROM relations WHERE is_inactive = 1") - inactive_count = int(cursor.fetchone()[0] or 0) - cursor.execute("SELECT COUNT(*) FROM deleted_relations") - deleted_count = int(cursor.fetchone()[0] or 0) - cursor.execute("SELECT COUNT(*) FROM relations WHERE is_pinned = 1") - pinned_count = int(cursor.fetchone()[0] or 0) - cursor.execute("SELECT COUNT(*) FROM relations WHERE protected_until > ?", (now_ts,)) - ttl_count = int(cursor.fetchone()[0] or 0) - return { - "active_count": active_count, - "inactive_count": inactive_count, - 
"deleted_count": deleted_count, - "pinned_count": pinned_count, - "temp_protected_count": ttl_count, - } - - def get_relations_subject_object_map(self, hashes: List[str]) -> Dict[str, Tuple[str, str]]: - """批量获取关系 hash 对应的 (subject, object)。""" - if not hashes: - return {} - cursor = self._conn.cursor() - placeholders = ",".join(["?"] * len(hashes)) - cursor.execute( - f"SELECT hash, subject, object FROM relations WHERE hash IN ({placeholders})", - hashes, - ) - return {str(row[0]): (str(row[1]), str(row[2])) for row in cursor.fetchall()} - - def get_connection(self) -> sqlite3.Connection: - """公开连接访问(用于离线脚本),替代外部访问私有字段。""" - return self._resolve_conn() - - def get_relation_db_snapshot(self) -> Tuple[int, float, str]: - """返回关系快照:(relation_count, max_created_at, max_hash)。""" - cursor = self._conn.cursor() - cursor.execute( - """ - SELECT - COUNT(*) AS relation_count, - COALESCE(MAX(created_at), 0) AS max_created_at, - COALESCE(MAX(hash), '') AS max_hash - FROM relations - """ - ) - row = cursor.fetchone() - if not row: - return (0, 0.0, "") - return ( - int(row[0] or 0), - float(row[1] or 0.0), - str(row[2] or ""), - ) - - def is_entity_still_referenced(self, entity_hash: str, entity_name: str = "") -> bool: - """ - 判断实体是否仍被引用: - 1) 被 paragraph_entities 引用 - 2) 在 relations.subject/object 中出现 - """ - token_hash = str(entity_hash or "").strip() - if token_hash: - cursor = self._conn.cursor() - cursor.execute( - "SELECT 1 FROM paragraph_entities WHERE entity_hash = ? LIMIT 1", - (token_hash,), - ) - if cursor.fetchone() is not None: - return True - - canon_name = self._canonicalize_name(entity_name) - if canon_name: - cursor = self._conn.cursor() - cursor.execute( - """ - SELECT 1 - FROM relations - WHERE LOWER(TRIM(subject)) = ? OR LOWER(TRIM(object)) = ? - LIMIT 1 - """, - (canon_name, canon_name), - ) - if cursor.fetchone() is not None: - return True - return False - - def search_relations_by_subject_or_object( - self, - query: str, - *, - limit: int = 5, - include_deleted: bool = False, - ) -> List[Dict[str, Any]]: - """按 subject/object 模糊查询关系。""" - q = str(query or "").strip() - if not q: - return [] - max_limit = int(max(1, limit)) - cursor = self._conn.cursor() - cursor.execute( - """ - SELECT * - FROM relations - WHERE subject LIKE ? OR object LIKE ? - LIMIT ? - """, - (f"%{q}%", f"%{q}%", max_limit), - ) - rows = [self._row_to_dict(row, "relation") for row in cursor.fetchall()] - if rows or not include_deleted: - return rows - - cursor.execute( - """ - SELECT * - FROM deleted_relations - WHERE subject LIKE ? OR object LIKE ? - LIMIT ? - """, - (f"%{q}%", f"%{q}%", max_limit), - ) - return [self._row_to_dict(row, "relation") for row in cursor.fetchall()] - - def list_hashes(self, table: str) -> List[str]: - """安全枚举指定表的 hash 列。""" - allowed = {"paragraphs", "entities", "relations", "deleted_relations"} - token = str(table or "").strip().lower() - if token not in allowed: - raise ValueError(f"unsupported table for list_hashes: {table}") - cursor = self._conn.cursor() - cursor.execute(f"SELECT hash FROM {token}") - return [str(row[0]) for row in cursor.fetchall()] - - def get_orphan_deleted_relation_hashes(self, limit: int = 200) -> List[str]: - """获取 deleted_relations 中已不在 relations 的孤儿 hash。""" - cursor = self._conn.cursor() - cursor.execute( - """ - SELECT d.hash - FROM deleted_relations d - LEFT JOIN relations r ON r.hash = d.hash - WHERE r.hash IS NULL - LIMIT ? 
- """, - (int(max(1, limit)),), - ) - return [str(row[0]) for row in cursor.fetchall()] - - def resolve_relation_hash_alias( - self, - value: str, - *, - include_deleted: bool = False, - ) -> List[str]: - """ - 解析关系哈希输入: - - 64位:直接校验存在性 - - 32位:通过 relation_hash_aliases 唯一映射 - """ - token = str(value or "").strip().lower() - if not token: - return [] - if len(token) == 64 and all(ch in "0123456789abcdef" for ch in token): - cursor = self._conn.cursor() - cursor.execute("SELECT 1 FROM relations WHERE hash = ? LIMIT 1", (token,)) - if cursor.fetchone(): - return [token] - if include_deleted: - cursor.execute("SELECT 1 FROM deleted_relations WHERE hash = ? LIMIT 1", (token,)) - if cursor.fetchone(): - return [token] - return [] - - if len(token) != 32 or not all(ch in "0123456789abcdef" for ch in token): - return [] - - cursor = self._conn.cursor() - cursor.execute("SELECT hash FROM relation_hash_aliases WHERE alias32 = ?", (token,)) - row = cursor.fetchone() - if not row: - return [] - resolved = str(row[0]) - return [resolved] - - def rebuild_relation_hash_aliases(self) -> Dict[str, Any]: - """重建 32 位 relation hash 别名映射。""" - cursor = self._conn.cursor() - # 历史库兜底:缺表时先创建,避免迁移过程直接中断。 - cursor.execute(""" - CREATE TABLE IF NOT EXISTS relation_hash_aliases ( - alias32 TEXT PRIMARY KEY, - hash TEXT NOT NULL - ) - """) - cursor.execute("DELETE FROM relation_hash_aliases") - - cursor.execute("SELECT hash FROM relations") - hashes = [str(r[0]) for r in cursor.fetchall()] - cursor.execute( - "SELECT name FROM sqlite_master WHERE type='table' AND name='deleted_relations'" - ) - has_deleted_relations = cursor.fetchone() is not None - if has_deleted_relations: - cursor.execute("SELECT hash FROM deleted_relations") - hashes.extend(str(r[0]) for r in cursor.fetchall()) - - alias_map: Dict[str, str] = {} - conflicts: Dict[str, set[str]] = {} - for h in hashes: - if len(h) != 64: - continue - alias = h[:32] - old = alias_map.get(alias) - if old is None: - alias_map[alias] = h - elif old != h: - conflicts.setdefault(alias, set()).update({old, h}) - - for alias, full_hash in alias_map.items(): - if alias in conflicts: - continue - cursor.execute( - "INSERT INTO relation_hash_aliases(alias32, hash) VALUES (?, ?)", - (alias, full_hash), - ) - self._conn.commit() - return { - "inserted": len(alias_map) - len(conflicts), - "conflict_count": len(conflicts), - "conflicts": sorted(conflicts.keys()), - } - - def search_relation_hashes_by_text(self, query: str, limit: int = 5) -> List[str]: - """按 relation 内容模糊查询 hash。""" - q = str(query or "").strip() - if not q: - return [] - cursor = self._conn.cursor() - cursor.execute( - "SELECT hash FROM relations WHERE subject LIKE ? OR object LIKE ? LIMIT ?", - (f"%{q}%", f"%{q}%", int(max(1, limit))), - ) - return [str(row[0]) for row in cursor.fetchall()] - - def search_deleted_relation_hashes_by_text(self, query: str, limit: int = 5) -> List[str]: - """按 deleted_relations 内容模糊查询 hash。""" - q = str(query or "").strip() - if not q: - return [] - cursor = self._conn.cursor() - cursor.execute( - "SELECT hash FROM deleted_relations WHERE subject LIKE ? OR object LIKE ? 
LIMIT ?", - (f"%{q}%", f"%{q}%", int(max(1, limit))), - ) - return [str(row[0]) for row in cursor.fetchall()] - - def restore_entity_by_hash(self, entity_hash: str) -> bool: - """恢复软删除实体。""" - cursor = self._conn.cursor() - cursor.execute( - "UPDATE entities SET is_deleted=0, deleted_at=NULL WHERE hash=?", - (str(entity_hash),), - ) - changed = cursor.rowcount > 0 - if changed: - self._conn.commit() - return changed - - def restore_paragraph_by_hash(self, paragraph_hash: str) -> bool: - """恢复软删除段落。""" - cursor = self._conn.cursor() - cursor.execute( - "UPDATE paragraphs SET is_deleted=0, deleted_at=NULL WHERE hash=?", - (str(paragraph_hash),), - ) - changed = cursor.rowcount > 0 - if changed: - self._conn.commit() - return changed - - def backfill_temporal_metadata_from_created_at( - self, - *, - limit: int = 100000, - dry_run: bool = False, - no_created_fallback: bool = False, - ) -> Dict[str, int]: - """回填段落 event_time 字段(created_at 兜底)。""" - cursor = self._conn.cursor() - cursor.execute( - """ - SELECT hash, created_at, source - FROM paragraphs - WHERE (event_time IS NULL AND event_time_start IS NULL AND event_time_end IS NULL) - ORDER BY created_at DESC - LIMIT ? - """, - (int(max(1, limit)),), - ) - rows = cursor.fetchall() - candidates = len(rows) - if dry_run: - return {"candidates": candidates, "updated": 0} - if no_created_fallback: - return {"candidates": candidates, "updated": 0} - - updated = 0 - touched_sources: List[str] = [] - for row in rows: - created_at = row["created_at"] - if created_at is None: - continue - cursor.execute( - """ - UPDATE paragraphs - SET event_time = ?, time_granularity = ?, time_confidence = ?, updated_at = ? - WHERE hash = ? - """, - (float(created_at), "day", 0.2, float(created_at), row["hash"]), - ) - if cursor.rowcount > 0: - updated += 1 - touched_sources.append(row["source"]) - self._conn.commit() - if updated > 0: - self._enqueue_episode_source_rebuilds( - touched_sources, - reason="paragraph_time_backfill", - ) - return {"candidates": candidates, "updated": updated} - - def get_schema_version(self) -> int: - cursor = self._conn.cursor() - cursor.execute( - "SELECT name FROM sqlite_master WHERE type='table' AND name='schema_migrations'" - ) - if cursor.fetchone() is None: - return 0 - cursor.execute("SELECT MAX(version) FROM schema_migrations") - row = cursor.fetchone() - return int(row[0]) if row and row[0] is not None else 0 - - def set_schema_version(self, version: int = SCHEMA_VERSION) -> None: - cursor = self._conn.cursor() - cursor.execute( - "CREATE TABLE IF NOT EXISTS schema_migrations (version INTEGER PRIMARY KEY, applied_at REAL NOT NULL)" - ) - cursor.execute( - "INSERT OR REPLACE INTO schema_migrations(version, applied_at) VALUES (?, ?)", - (int(version), datetime.now().timestamp()), - ) - self._conn.commit() - - def delete_paragraph_atomic(self, paragraph_hash: str) -> Dict[str, Any]: - """ - 两阶段删除段落:DB 事务内计算 + 提交后执行清理 - - Args: - paragraph_hash: 段落哈希 - - Returns: - cleanup_plan: 包含需要后续从 Vector/GraphStore 中移除的 ID 列表 - """ - cleanup_plan = { - "paragraph_hash": paragraph_hash, - "vector_id_to_remove": None, - "edges_to_remove": [], # (src, tgt) 元组列表 (fallback) - "relation_prune_ops": [], # (subject, object, relation_hash) 精准裁剪 - "episode_sources_to_rebuild": [], - } - - cursor = self._conn.cursor() - try: - # === Phase 1: DB Transaction (可回滚) === - # 使用 IMMEDIATE 模式,一旦开启事务立即锁定 DB (防止其他写操作插队导致幻读) - cursor.execute("BEGIN IMMEDIATE") - - # 1. 
-             cursor.execute("SELECT relation_hash FROM paragraph_relations WHERE paragraph_hash = ?", (paragraph_hash,))
-             candidate_relations = [row[0] for row in cursor.fetchall()]
- 
-             # 2. [Snapshot] confirm the paragraph exists and record its ID for vector deletion
-             cursor.execute("SELECT hash, source FROM paragraphs WHERE hash = ?", (paragraph_hash,))
-             paragraph_row = cursor.fetchone()
-             if paragraph_row:
-                 cleanup_plan["vector_id_to_remove"] = paragraph_hash
-                 cleanup_plan["episode_sources_to_rebuild"] = self._dedupe_episode_sources(
-                     [paragraph_row["source"]]
-                 )
- 
-             # 3. [Primary delete] remove the paragraph (CASCADE clears paragraph_relations)
-             cursor.execute("DELETE FROM paragraphs WHERE hash = ?", (paragraph_hash,))
- 
-             # 4. [Orphan detection]
-             orphaned_hashes = []
-             for rel_hash in candidate_relations:
-                 count = cursor.execute(
-                     "SELECT count(*) FROM paragraph_relations WHERE relation_hash = ?",
-                     (rel_hash,)
-                 ).fetchone()[0]
- 
-                 if count == 0:
-                     # Orphaned: record the edge so the caller can prune it from the graph later.
-                     cursor.execute("SELECT subject, object FROM relations WHERE hash = ?", (rel_hash,))
-                     rel_info = cursor.fetchone()
-                     if rel_info:
-                         s_val, o_val = rel_info[0], rel_info[1]
-                         cleanup_plan["relation_prune_ops"].append((s_val, o_val, rel_hash))
- 
-                         # Only plan to drop the whole edge when no other relation still shares
-                         # this (subject, object) pair (compatible with the old implementation).
-                         sibling_count = cursor.execute(
-                             """
-                             SELECT count(*) FROM relations
-                             WHERE LOWER(TRIM(subject)) = LOWER(TRIM(?))
-                             AND LOWER(TRIM(object)) = LOWER(TRIM(?))
-                             AND hash != ?
-                             """,
-                             (s_val, o_val, rel_hash)
-                         ).fetchone()[0]
-                         if sibling_count == 0:
-                             cleanup_plan["edges_to_remove"].append((s_val, o_val))
- 
-                     orphaned_hashes.append(rel_hash)
- 
-             # 5. [DB cleanup] delete the orphaned relation rows
-             if orphaned_hashes:
-                 placeholders = ','.join(['?'] * len(orphaned_hashes))
-                 cursor.execute(f"DELETE FROM relations WHERE hash IN ({placeholders})", orphaned_hashes)
- 
-             self._conn.commit()
-             if cleanup_plan["episode_sources_to_rebuild"]:
-                 self._enqueue_episode_source_rebuilds(
-                     cleanup_plan["episode_sources_to_rebuild"],
-                     reason="paragraph_deleted",
-                 )
-             if cleanup_plan["vector_id_to_remove"]:
-                 logger.debug(f"Atomic paragraph delete OK: {paragraph_hash}, {len(orphaned_hashes)} orphaned relations scheduled for cleanup")
-             return cleanup_plan
- 
-         except Exception as e:
-             self._conn.rollback()
-             logger.error(f"DB Transaction failed: {e}")
-             raise
- 
-     def clear_all(self) -> None:
-         """Clear all table data."""
-         cursor = self._conn.cursor()
-         tables = [
-             "paragraphs", "entities", "relations",
-             "paragraph_relations", "paragraph_entities",
-             "episodes", "episode_paragraphs",
-             "episode_rebuild_sources", "episode_pending_paragraphs",
-         ]
-         for table in tables:
-             cursor.execute(f"DELETE FROM {table}")
-         self._conn.commit()
-         logger.info("All metadata store tables cleared")
- 
-     def update_relation_timestamp(self, hash_value: str, access_count_delta: int = 1) -> None:
-         """Update a relation's last-access time and access count."""
-         now = datetime.now().timestamp()
- 
-         # Note: only last_accessed / access_count are touched here; the V5 field
-         # last_reinforced is updated by reinforce_relations() instead.
- 
-         cursor = self._conn.cursor()
-         cursor.execute("""
-             UPDATE relations
-             SET last_accessed = ?,
-                 access_count = access_count + ?
-             WHERE hash = ?
- """, (now, access_count_delta, hash_value)) - self._conn.commit() - - # ========================================================================= - # V5 Memory System Methods - # ========================================================================= - - def get_relation_status_batch(self, hashes: List[str]) -> Dict[str, Dict[str, Any]]: - """ - 批量获取关系状态 (V5) - - Args: - hashes: 关系哈希列表 - - Returns: - Dict[hash, status_dict] - status_dict 包含: is_inactive, weight(confidence), is_pinned, protected_until, last_reinforced, inactive_since - """ - if not hashes: - return {} - - placeholders = ",".join(["?"] * len(hashes)) - cursor = self._conn.cursor() - cursor.execute(f""" - SELECT hash, is_inactive, confidence, is_pinned, protected_until, last_reinforced, inactive_since - FROM relations - WHERE hash IN ({placeholders}) - """, hashes) - - result = {} - for row in cursor.fetchall(): - result[row["hash"]] = { - "is_inactive": bool(row["is_inactive"]), - "weight": row["confidence"], - "is_pinned": bool(row["is_pinned"]), - "protected_until": row["protected_until"], - "last_reinforced": row["last_reinforced"], - "inactive_since": row["inactive_since"] - } - return result - - def mark_relations_active(self, hashes: List[str], boost_weight: Optional[float] = None) -> None: - """ - 批量标记关系为活跃 (Active/Revive) - - Args: - hashes: 关系哈希列表 - boost_weight: 如果提供,将设置 confidence = max(confidence, boost_weight) - """ - if not hashes: - return - - placeholders = ",".join(["?"] * len(hashes)) - cursor = self._conn.cursor() - - if boost_weight is not None: - cursor.execute(f""" - UPDATE relations - SET is_inactive = 0, - inactive_since = NULL, - confidence = MAX(confidence, ?) - WHERE hash IN ({placeholders}) - """, (boost_weight, *hashes)) - else: - cursor.execute(f""" - UPDATE relations - SET is_inactive = 0, - inactive_since = NULL - WHERE hash IN ({placeholders}) - """, hashes) - - self._conn.commit() - - def update_relations_protection( - self, - hashes: List[str], - protected_until: Optional[float] = None, - is_pinned: Optional[bool] = None, - last_reinforced: Optional[float] = None - ) -> None: - """ - 批量更新关系保护状态 - """ - if not hashes: - return - - updates = [] - params = [] - - if protected_until is not None: - updates.append("protected_until = ?") - params.append(protected_until) - if is_pinned is not None: - updates.append("is_pinned = ?") - params.append(1 if is_pinned else 0) - if last_reinforced is not None: - updates.append("last_reinforced = ?") - params.append(last_reinforced) - - if not updates: - return - - sql_set = ", ".join(updates) - placeholders = ",".join(["?"] * len(hashes)) - - params.extend(hashes) - - cursor = self._conn.cursor() - cursor.execute(f""" - UPDATE relations - SET {sql_set} - WHERE hash IN ({placeholders}) - """, params) - self._conn.commit() - - def get_prune_candidates(self, cutoff_time: float, limit: int = 1000) -> List[str]: - """ - 获取待修剪候选 (已过冷冻保留期) - - Args: - cutoff_time: 截止时间 (now - 冷冻时长) - limit: 限制数量 - """ - cursor = self._conn.cursor() - cursor.execute(""" - SELECT hash FROM relations - WHERE is_inactive = 1 - AND inactive_since < ? - LIMIT ? - """, (cutoff_time, limit)) - return [row[0] for row in cursor.fetchall()] - - def backup_and_delete_relations(self, hashes: List[str]) -> int: - """ - 备份并删除关系 (Prune) - - Returns: - 删除的数量 - """ - if not hashes: - return 0 - - placeholders = ",".join(["?"] * len(hashes)) - now = datetime.now().timestamp() - - cursor = self._conn.cursor() - try: - # 1. 
-             cursor.execute(f"""
-                 INSERT OR REPLACE INTO deleted_relations
-                 (hash, subject, predicate, object, vector_index, confidence, created_at,
-                  vector_state, vector_updated_at, vector_error, vector_retry_count,
-                  source_paragraph, metadata, is_permanent, last_accessed, access_count,
-                  is_inactive, inactive_since, is_pinned, protected_until, last_reinforced, deleted_at)
-                 SELECT
-                     hash, subject, predicate, object, vector_index, confidence, created_at,
-                     vector_state, vector_updated_at, vector_error, vector_retry_count,
-                     source_paragraph, metadata, is_permanent, last_accessed, access_count,
-                     is_inactive, inactive_since, is_pinned, protected_until, last_reinforced, ?
-                 FROM relations
-                 WHERE hash IN ({placeholders})
-             """, (now, *hashes))
- 
-             # 2. Delete (CASCADE automatically clears the paragraph_relations links)
-             cursor.execute(f"""
-                 DELETE FROM relations
-                 WHERE hash IN ({placeholders})
-             """, hashes)
- 
-             deleted_count = cursor.rowcount
-             self._conn.commit()
-             return deleted_count
- 
-         except Exception as e:
-             logger.error(f"Backup-and-delete failed: {e}")
-             self._conn.rollback()
-             return 0
- 
-     def restore_relation_metadata(self, hash_value: str) -> Optional[Dict[str, Any]]:
-         """
-         Restore relation metadata from the recycle bin.
- 
-         Returns:
-             The restored relation as a dict, or None on failure.
-         """
-         cursor = self._conn.cursor()
-         try:
-             # 1. Fetch the backed-up row
-             cursor.execute("SELECT * FROM deleted_relations WHERE hash = ?", (hash_value,))
-             row = cursor.fetchone()
-             if not row:
-                 return None
- 
-             data = dict(row)
-             # Drop the deleted_at column
-             if "deleted_at" in data:
-                 del data["deleted_at"]
- 
-             # 2. Insert back into relations.
-             # Build the SQL dynamically so schema changes don't break the restore.
-             columns = list(data.keys())
-             placeholders = ",".join(["?"] * len(columns))
-             cols_str = ",".join(columns)
-             values = list(data.values())
- 
-             cursor.execute(f"""
-                 INSERT OR REPLACE INTO relations ({cols_str})
-                 VALUES ({placeholders})
-             """, values)
- 
-             # 3. Remove the row from the backup table
-             cursor.execute("DELETE FROM deleted_relations WHERE hash = ?", (hash_value,))
- 
-             self._conn.commit()
-             return self._row_to_dict(row, "relation")  # helper converts the raw row to a dict
- 
-         except Exception as e:
-             logger.error(f"Relation restore failed: {hash_value} - {e}")
-             self._conn.rollback()
-             return None
- 
-     def restore_relation(self, hash_value: str) -> Optional[Dict[str, Any]]:
-         """Backward-compatible alias: restore a relation."""
-         return self.restore_relation_metadata(hash_value)
- 
-     def get_protected_relations_hashes(self) -> List[str]:
-         """Get hashes of all protected relations (pinned, or protected_until > now)."""
-         now = datetime.now().timestamp()
- 
-         cursor = self._conn.cursor()
-         cursor.execute("""
-             SELECT hash FROM relations
-             WHERE is_pinned = 1 OR protected_until > ?
- """, (now,)) - - return [row[0] for row in cursor.fetchall()] - - - - def get_deleted_relations(self, limit: int = 50) -> List[Dict[str, Any]]: - """获取回收站中的关系记录""" - cursor = self._conn.cursor() - cursor.execute("SELECT * FROM deleted_relations ORDER BY deleted_at DESC LIMIT ?", (limit,)) - data = [] - for row in cursor.fetchall(): - d = dict(row) - # 是否需要解码元数据?是的,与普通行相同 - if "metadata" in d and d["metadata"]: - try: - d["metadata"] = pickle.loads(d["metadata"]) - except Exception: - d["metadata"] = {} - data.append(d) - return data - - def get_deleted_relation(self, hash_value: str) -> Optional[Dict[str, Any]]: - """获取单条回收站记录""" - cursor = self._conn.cursor() - cursor.execute("SELECT * FROM deleted_relations WHERE hash = ?", (hash_value,)) - row = cursor.fetchone() - if not row: return None - - d = dict(row) - if "metadata" in d and d["metadata"]: - try: - d["metadata"] = pickle.loads(d["metadata"]) - except Exception: - d["metadata"] = {} - return d - - def reinforce_relations(self, hashes: List[str]) -> None: - """强化关系 (更新 last_reinforced, is_inactive=0)""" - if not hashes: return - now = datetime.now().timestamp() - - cursor = self._conn.cursor() - # Batch update? chunking - chunk_size = 500 - for i in range(0, len(hashes), chunk_size): - chunk = hashes[i:i+chunk_size] - placeholders = ",".join(["?"] * len(chunk)) - sql = f""" - UPDATE relations - SET last_reinforced = ?, is_inactive = 0, inactive_since = NULL - WHERE hash IN ({placeholders}) - """ - cursor.execute(sql, [now] + chunk) - - self._conn.commit() - - def mark_relations_inactive(self, hashes: List[str], inactive_since: Optional[float] = None) -> None: - """标记关系为非活跃 (Freeze)。兼容显式 inactive_since 或默认当前时间。""" - if not hashes: - return - mark_time = inactive_since if inactive_since is not None else datetime.now().timestamp() - - cursor = self._conn.cursor() - chunk_size = 500 - for i in range(0, len(hashes), chunk_size): - chunk = hashes[i:i+chunk_size] - placeholders = ",".join(["?"] * len(chunk)) - sql = f""" - UPDATE relations - SET is_inactive = 1, inactive_since = ? - WHERE hash IN ({placeholders}) - """ - cursor.execute(sql, [mark_time] + chunk) - - self._conn.commit() - - def protect_relations( - self, - hashes: List[str], - is_pinned: bool = False, - ttl_seconds: float = 0 - ) -> None: - """ - 设置保护状态 - """ - if not hashes: return - now = datetime.now().timestamp() - protected_until = (now + ttl_seconds) if ttl_seconds > 0 else 0 - - cursor = self._conn.cursor() - chunk_size = 500 - for i in range(0, len(hashes), chunk_size): - chunk = hashes[i:i+chunk_size] - placeholders = ",".join(["?"] * len(chunk)) - - # 由于 is_pinned 和 protected_until 是分开的,如果请求固定(pin),我们会同时更新这两项, - # 但通常用户要么切换固定状态,要么设置 TTL。 - # 如果 is_pinned=True,TTL 通常就不重要了。 - # 但目前的逻辑是正交处理它们的。 - - # 如果用户取消固定 (is_pinned=False),我们是否应该尊重已设置的 TTL? - # 当前的 API 会同时设置这两项。 - - sql = f""" - UPDATE relations - SET is_pinned = ?, protected_until = ? 
-     def protect_relations(
-         self,
-         hashes: List[str],
-         is_pinned: bool = False,
-         ttl_seconds: float = 0
-     ) -> None:
-         """
-         Set protection state.
- 
-         is_pinned and protected_until are stored separately but written together
-         here: pinning makes any TTL irrelevant, and un-pinning (is_pinned=False)
-         also overwrites a previously set TTL. Callers that want to change only
-         one of the two must read the current state first.
-         """
-         if not hashes:
-             return
-         now = datetime.now().timestamp()
-         protected_until = (now + ttl_seconds) if ttl_seconds > 0 else 0
- 
-         cursor = self._conn.cursor()
-         chunk_size = 500
-         for i in range(0, len(hashes), chunk_size):
-             chunk = hashes[i:i+chunk_size]
-             placeholders = ",".join(["?"] * len(chunk))
- 
-             sql = f"""
-                 UPDATE relations
-                 SET is_pinned = ?, protected_until = ?
-                 WHERE hash IN ({placeholders})
-             """
-             cursor.execute(sql, [is_pinned, protected_until] + chunk)
- 
-         self._conn.commit()
- 
-     def vacuum(self) -> None:
-         """Optimize the database."""
-         cursor = self._conn.cursor()
-         cursor.execute("VACUUM")
-         self._conn.commit()
-         logger.info("Database vacuum complete")
- 
-     def _row_to_dict(self, row: sqlite3.Row, row_type: str) -> Dict[str, Any]:
-         """
-         Convert a database row to a dict.
- 
-         Args:
-             row: database row
-             row_type: row type
- 
-         Returns:
-             dict
-         """
-         d = dict(row)
- 
-         # Decode pickled fields
-         if "metadata" in d and d["metadata"]:
-             try:
-                 d["metadata"] = pickle.loads(d["metadata"])
-             except Exception:
-                 d["metadata"] = {}
- 
-         return d
- 
-     @property
-     def is_connected(self) -> bool:
-         """Whether a connection is open."""
-         return self._conn is not None
- 
-     def __enter__(self):
-         """Context manager entry."""
-         self.connect()
-         return self
- 
-     def __exit__(self, exc_type, exc_val, exc_tb):
-         """Context manager exit."""
-         self.close()
- 
-     # =========================================================================
-     # V5 Soft Delete & Garbage Collection
-     # =========================================================================
- 
-     def get_entity_gc_candidates(self, isolated_hashes: List[str], retention_seconds: float) -> List[str]:
-         """
-         Get entity GC candidates (soft-delete candidates).
- 
-         Conditions:
-         1. appears in isolated_hashes (supplied by GraphStore; usually entity names)
-         2. is_deleted = 0 (not yet marked)
-         3. created_at < now - retention (past the grace period for new entities)
-         4. not referenced by any active paragraph (paragraph_entities check)
- 
-         Args:
-             isolated_hashes: orphan entity names (raw hashes also accepted)
-             retention_seconds: retention window in seconds
-         """
-         if not isolated_hashes:
-             return []
- 
-         # GraphStore.get_isolated_nodes returns node names; canonicalize them and
-         # map to entity hashes. Raw hashes from legacy callers pass through as-is.
-         normalized_hashes: List[str] = []
-         for item in isolated_hashes:
-             if not item:
-                 continue
-             v = str(item).strip()
-             if len(v) == 64 and all(c in "0123456789abcdefABCDEF" for c in v):
-                 normalized_hashes.append(v.lower())
-             else:
-                 canon = self._canonicalize_name(v)
-                 if canon:
-                     normalized_hashes.append(compute_hash(canon))
- 
-         normalized_hashes = list(dict.fromkeys(normalized_hashes))
-         if not normalized_hashes:
-             return []
- 
-         now = datetime.now().timestamp()
-         cutoff = now - retention_seconds
- 
-         candidates = []
-         batch_size = 900
- 
-         # Process the IN query in batches
-         for i in range(0, len(normalized_hashes), batch_size):
-             batch = normalized_hashes[i:i+batch_size]
-             placeholders = ",".join(["?"] * len(batch))
- 
-             # Reference check via NOT EXISTS. Semantics: a paragraph_entities row
-             # counts as a live reference only while its paragraph is itself active;
-             # references from soft-deleted paragraphs do not block GC, hence the
-             # JOIN against paragraphs with p.is_deleted = 0.
-             query = f"""
-                 SELECT e.hash FROM entities e
-                 WHERE e.hash IN ({placeholders})
-                 AND e.is_deleted = 0
-                 AND (e.created_at IS NULL OR e.created_at < ?)
-                 AND NOT EXISTS (
-                     SELECT 1 FROM paragraph_entities pe
-                     JOIN paragraphs p ON pe.paragraph_hash = p.hash
-                     WHERE pe.entity_hash = e.hash
-                     AND p.is_deleted = 0
-                 )
-             """
- 
-             cursor = self._conn.cursor()
-             cursor.execute(query, [*batch, cutoff])
-             candidates.extend([row[0] for row in cursor.fetchall()])
- 
-         return candidates
- 
-     def get_paragraph_gc_candidates(self, retention_seconds: float) -> List[str]:
-         """
-         Get paragraph GC candidates.
- 
-         Conditions:
-         1. is_deleted = 0
-         2. created_at < cutoff
-         3. no relations (no paragraph_relations rows)
-         4. no entity references (no paragraph_entities rows); ignoring references
-            from soft-deleted entities was considered but rejected as too complex,
-            so any reference row blocks GC.
- 
-         Orphan test: LEFT JOINs on paragraph_relations and paragraph_entities
-         must both come back NULL.
-         """
-         now = datetime.now().timestamp()
-         cutoff = now - retention_seconds
- 
-         query = """
-             SELECT p.hash FROM paragraphs p
-             LEFT JOIN paragraph_relations pr ON p.hash = pr.paragraph_hash
-             LEFT JOIN paragraph_entities pe ON p.hash = pe.paragraph_hash
-             WHERE p.is_deleted = 0
-             AND (p.created_at IS NULL OR p.created_at < ?)
-             AND pr.relation_hash IS NULL
-             AND pe.entity_hash IS NULL
-         """
- 
-         cursor = self._conn.cursor()
-         cursor.execute(query, (cutoff,))
-         return [row[0] for row in cursor.fetchall()]
- 
-     def mark_as_deleted(self, hashes: List[str], type_: str) -> int:
-         """
-         Mark items as soft-deleted (mark phase).
- 
-         Args:
-             hashes: list of hashes
-             type_: 'entity' | 'paragraph'
-         """
-         if not hashes:
-             return 0
- 
-         table = "entities" if type_ == "entity" else "paragraphs"
-         now = datetime.now().timestamp()
-         touched_sources: List[str] = []
-         if type_ == "paragraph":
-             touched_sources = self._get_sources_for_paragraph_hashes(hashes, include_deleted=True)
- 
-         count = 0
-         batch_size = 900
-         for i in range(0, len(hashes), batch_size):
-             batch = hashes[i:i+batch_size]
-             placeholders = ",".join(["?"] * len(batch))
- 
-             # Idempotent update: only touch rows with is_deleted = 0
-             cursor = self._conn.cursor()
-             cursor.execute(f"""
-                 UPDATE {table}
-                 SET is_deleted = 1, deleted_at = ?
-                 WHERE is_deleted = 0 AND hash IN ({placeholders})
-             """, [now] + batch)
-             count += cursor.rowcount
- 
-         self._conn.commit()
-         if type_ == "paragraph" and count > 0:
-             self._enqueue_episode_source_rebuilds(
-                 touched_sources,
-                 reason="paragraph_soft_deleted",
-             )
-         if count > 0:
-             logger.info(f"Soft-delete marked ({table}): {count} items")
-         return count
- 
-     def sweep_deleted_items(self, type_: str, grace_period_seconds: float) -> List[Tuple[str, str]]:
-         """
-         Scan for items eligible for physical cleanup (sweep phase - selection).
- 
-         Args:
-             type_: 'entity' | 'paragraph'
-             grace_period_seconds: grace period
- 
-         Returns:
-             List[(hash, name)]: items to delete (name is empty for paragraphs)
-         """
-         table = "entities" if type_ == "entity" else "paragraphs"
-         now = datetime.now().timestamp()
-         cutoff = now - grace_period_seconds
- 
-         cols = "hash, name" if type_ == "entity" else "hash, '' as name"
- 
-         cursor = self._conn.cursor()
-         cursor.execute(f"""
-             SELECT {cols} FROM {table}
-             WHERE is_deleted = 1
-             AND deleted_at < ?
- """, (cutoff,)) - - return [(row[0], row[1]) for row in cursor.fetchall()] - - def physically_delete_entities(self, hashes: List[str]) -> int: - """物理删除实体 (批量)""" - if not hashes: return 0 - - count = 0 - batch_size = 900 - for i in range(0, len(hashes), batch_size): - batch = hashes[i:i+batch_size] - placeholders = ",".join(["?"] * len(batch)) - - cursor = self._conn.cursor() - cursor.execute(f"DELETE FROM entities WHERE hash IN ({placeholders})", batch) - count += cursor.rowcount - - self._conn.commit() - return count - - def physically_delete_paragraphs(self, hashes: List[str]) -> int: - """物理删除段落 (批量)""" - if not hashes: return 0 - touched_sources = self._get_sources_for_paragraph_hashes(hashes, include_deleted=True) - - count = 0 - batch_size = 900 - for i in range(0, len(hashes), batch_size): - batch = hashes[i:i+batch_size] - placeholders = ",".join(["?"] * len(batch)) - - cursor = self._conn.cursor() - cursor.execute(f"DELETE FROM paragraphs WHERE hash IN ({placeholders})", batch) - count += cursor.rowcount - - self._conn.commit() - if count > 0: - self._enqueue_episode_source_rebuilds( - touched_sources, - reason="paragraph_physically_deleted", - ) - return count - - def revive_if_deleted(self, entity_hashes: List[str] = None, paragraph_hashes: List[str] = None) -> int: - """ - 复活已软删的项目 (Auto Revival) - 当数据被再次访问、引用或导入时调用。 - """ - count = 0 - - if entity_hashes: - batch_size = 900 - for i in range(0, len(entity_hashes), batch_size): - batch = entity_hashes[i:i+batch_size] - placeholders = ",".join(["?"] * len(batch)) - - cursor = self._conn.cursor() - cursor.execute(f""" - UPDATE entities - SET is_deleted = 0, deleted_at = NULL - WHERE is_deleted = 1 AND hash IN ({placeholders}) - """, batch) - count += cursor.rowcount - - if paragraph_hashes: - touched_sources = self._get_sources_for_paragraph_hashes(paragraph_hashes, include_deleted=True) - batch_size = 900 - for i in range(0, len(paragraph_hashes), batch_size): - batch = paragraph_hashes[i:i+batch_size] - placeholders = ",".join(["?"] * len(batch)) - - cursor = self._conn.cursor() - cursor.execute(f""" - UPDATE paragraphs - SET is_deleted = 0, deleted_at = NULL - WHERE is_deleted = 1 AND hash IN ({placeholders}) - """, batch) - count += cursor.rowcount - else: - touched_sources = [] - - if count > 0: - self._conn.commit() - if touched_sources: - self._enqueue_episode_source_rebuilds( - touched_sources, - reason="paragraph_revived", - ) - logger.info(f"自动复活: {count} 项 (Soft Delete Revived)") - - return count - - def revive_entities_by_names(self, names: List[str]) -> int: - """ - 根据名称复活实体 (Convenience wrapper) - """ - if not names: return 0 - - # 使用内部方法计算哈希 - hashes = [compute_hash(self._canonicalize_name(n)) for n in names] - return self.revive_if_deleted(entity_hashes=hashes) - - def get_entity_status_batch(self, hashes: List[str]) -> Dict[str, Dict[str, Any]]: - """批量获取实体状态 (WebUI用)""" - if not hashes: return {} - - result = {} - batch_size = 900 - for i in range(0, len(hashes), batch_size): - batch = hashes[i:i+batch_size] - placeholders = ",".join(["?"] * len(batch)) - - cursor = self._conn.cursor() - cursor.execute(f""" - SELECT hash, is_deleted, deleted_at - FROM entities - WHERE hash IN ({placeholders}) - """, batch) - - for row in cursor.fetchall(): - result[row[0]] = { - "is_deleted": bool(row[1]), - "deleted_at": row[2] - } - return result - - # ========================================================================= - # Person Profile (问题3) - Switches / Active Set / Snapshots - # 
========================================================================= - - def set_person_profile_switch( - self, - stream_id: str, - user_id: str, - enabled: bool, - updated_at: Optional[float] = None, - ) -> None: - """设置人物画像自动注入开关(按 stream_id + user_id)。""" - if not stream_id or not user_id: - raise ValueError("stream_id 和 user_id 不能为空") - - ts = float(updated_at) if updated_at is not None else datetime.now().timestamp() - cursor = self._conn.cursor() - cursor.execute( - """ - INSERT INTO person_profile_switches (stream_id, user_id, enabled, updated_at) - VALUES (?, ?, ?, ?) - ON CONFLICT(stream_id, user_id) DO UPDATE SET - enabled = excluded.enabled, - updated_at = excluded.updated_at - """, - (str(stream_id), str(user_id), 1 if enabled else 0, ts), - ) - self._conn.commit() - - def get_person_profile_switch(self, stream_id: str, user_id: str, default: bool = False) -> bool: - """读取人物画像自动注入开关。""" - if not stream_id or not user_id: - return bool(default) - - cursor = self._conn.cursor() - cursor.execute( - "SELECT enabled FROM person_profile_switches WHERE stream_id = ? AND user_id = ?", - (str(stream_id), str(user_id)), - ) - row = cursor.fetchone() - if not row: - return bool(default) - return bool(row[0]) - - def get_enabled_person_profile_switches(self, limit: int = 1000) -> List[Dict[str, Any]]: - """获取已开启人物画像注入开关的会话范围。""" - cursor = self._conn.cursor() - cursor.execute( - """ - SELECT stream_id, user_id, enabled, updated_at - FROM person_profile_switches - WHERE enabled = 1 - ORDER BY updated_at DESC - LIMIT ? - """, - (int(max(1, limit)),), - ) - return [ - { - "stream_id": row[0], - "user_id": row[1], - "enabled": bool(row[2]), - "updated_at": row[3], - } - for row in cursor.fetchall() - ] - - def mark_person_profile_active( - self, - stream_id: str, - user_id: str, - person_id: str, - seen_at: Optional[float] = None, - ) -> None: - """记录活跃人物(用于定时按需刷新)。""" - if not stream_id or not user_id or not person_id: - return - ts = float(seen_at) if seen_at is not None else datetime.now().timestamp() - cursor = self._conn.cursor() - cursor.execute( - """ - INSERT INTO person_profile_active_persons (stream_id, user_id, person_id, last_seen_at) - VALUES (?, ?, ?, ?) - ON CONFLICT(stream_id, user_id, person_id) DO UPDATE SET - last_seen_at = excluded.last_seen_at - """, - (str(stream_id), str(user_id), str(person_id), ts), - ) - self._conn.commit() - - def get_active_person_ids_for_enabled_switches( - self, - active_after: Optional[float] = None, - limit: int = 200, - ) -> List[str]: - """获取“已开启开关范围内”的活跃人物集合。""" - cursor = self._conn.cursor() - sql = """ - SELECT a.person_id, MAX(a.last_seen_at) AS last_seen - FROM person_profile_active_persons a - JOIN person_profile_switches s - ON a.stream_id = s.stream_id AND a.user_id = s.user_id - WHERE s.enabled = 1 - """ - params: List[Any] = [] - if active_after is not None: - sql += " AND a.last_seen_at >= ?" - params.append(float(active_after)) - sql += """ - GROUP BY a.person_id - ORDER BY last_seen DESC - LIMIT ? 
- """ - params.append(int(max(1, limit))) - cursor.execute(sql, tuple(params)) - return [str(row[0]) for row in cursor.fetchall() if row and row[0]] - - def get_latest_person_profile_snapshot(self, person_id: str) -> Optional[Dict[str, Any]]: - """获取人物最新画像快照。""" - if not person_id: - return None - cursor = self._conn.cursor() - cursor.execute( - """ - SELECT - snapshot_id, person_id, profile_version, profile_text, - aliases_json, relation_edges_json, vector_evidence_json, evidence_ids_json, - updated_at, expires_at, source_note - FROM person_profile_snapshots - WHERE person_id = ? - ORDER BY profile_version DESC - LIMIT 1 - """, - (str(person_id),), - ) - row = cursor.fetchone() - if not row: - return None - - def _load_list(raw: Any) -> List[Any]: - if not raw: - return [] - try: - data = json.loads(raw) - return data if isinstance(data, list) else [] - except Exception: - return [] - - return { - "snapshot_id": row[0], - "person_id": row[1], - "profile_version": int(row[2]), - "profile_text": row[3] or "", - "aliases": _load_list(row[4]), - "relation_edges": _load_list(row[5]), - "vector_evidence": _load_list(row[6]), - "evidence_ids": _load_list(row[7]), - "updated_at": row[8], - "expires_at": row[9], - "source_note": row[10] or "", - } - - def upsert_person_profile_snapshot( - self, - person_id: str, - profile_text: str, - aliases: Optional[List[str]] = None, - relation_edges: Optional[List[Dict[str, Any]]] = None, - vector_evidence: Optional[List[Dict[str, Any]]] = None, - evidence_ids: Optional[List[str]] = None, - expires_at: Optional[float] = None, - source_note: str = "", - updated_at: Optional[float] = None, - ) -> Dict[str, Any]: - """写入人物画像快照(按 person_id 自动递增版本)。""" - if not person_id: - raise ValueError("person_id 不能为空") - - aliases = aliases or [] - relation_edges = relation_edges or [] - vector_evidence = vector_evidence or [] - evidence_ids = evidence_ids or [] - ts = float(updated_at) if updated_at is not None else datetime.now().timestamp() - - cursor = self._conn.cursor() - cursor.execute( - """ - SELECT profile_version - FROM person_profile_snapshots - WHERE person_id = ? - ORDER BY profile_version DESC - LIMIT 1 - """, - (str(person_id),), - ) - row = cursor.fetchone() - next_version = int(row[0]) + 1 if row else 1 - - cursor.execute( - """ - INSERT INTO person_profile_snapshots ( - person_id, profile_version, profile_text, - aliases_json, relation_edges_json, vector_evidence_json, evidence_ids_json, - updated_at, expires_at, source_note - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
- """, - ( - str(person_id), - next_version, - str(profile_text or ""), - json.dumps(aliases, ensure_ascii=False), - json.dumps(relation_edges, ensure_ascii=False), - json.dumps(vector_evidence, ensure_ascii=False), - json.dumps(evidence_ids, ensure_ascii=False), - ts, - float(expires_at) if expires_at is not None else None, - str(source_note or ""), - ), - ) - self._conn.commit() - latest = self.get_latest_person_profile_snapshot(person_id) - return latest or { - "person_id": person_id, - "profile_version": next_version, - "profile_text": str(profile_text or ""), - "aliases": aliases, - "relation_edges": relation_edges, - "vector_evidence": vector_evidence, - "evidence_ids": evidence_ids, - "updated_at": ts, - "expires_at": expires_at, - "source_note": source_note, - } - - def get_person_profile_override(self, person_id: str) -> Optional[Dict[str, Any]]: - """获取人物画像手工覆盖内容。""" - if not person_id: - return None - cursor = self._conn.cursor() - cursor.execute( - """ - SELECT person_id, override_text, updated_at, updated_by, source - FROM person_profile_overrides - WHERE person_id = ? - LIMIT 1 - """, - (str(person_id),), - ) - row = cursor.fetchone() - if not row: - return None - return { - "person_id": str(row[0]), - "override_text": str(row[1] or ""), - "updated_at": row[2], - "updated_by": str(row[3] or ""), - "source": str(row[4] or ""), - } - - def set_person_profile_override( - self, - person_id: str, - override_text: str, - updated_by: str = "", - source: str = "webui", - updated_at: Optional[float] = None, - ) -> Dict[str, Any]: - """写入人物画像手工覆盖;空文本等价于清除覆盖。""" - if not person_id: - raise ValueError("person_id 不能为空") - - text = str(override_text or "").strip() - if not text: - self.delete_person_profile_override(person_id) - return { - "person_id": str(person_id), - "override_text": "", - "updated_at": None, - "updated_by": str(updated_by or ""), - "source": str(source or ""), - } - - ts = float(updated_at) if updated_at is not None else datetime.now().timestamp() - cursor = self._conn.cursor() - cursor.execute( - """ - INSERT INTO person_profile_overrides ( - person_id, override_text, updated_at, updated_by, source - ) VALUES (?, ?, ?, ?, ?) 
- ON CONFLICT(person_id) DO UPDATE SET - override_text = excluded.override_text, - updated_at = excluded.updated_at, - updated_by = excluded.updated_by, - source = excluded.source - """, - ( - str(person_id), - text, - ts, - str(updated_by or ""), - str(source or ""), - ), - ) - self._conn.commit() - return self.get_person_profile_override(person_id) or { - "person_id": str(person_id), - "override_text": text, - "updated_at": ts, - "updated_by": str(updated_by or ""), - "source": str(source or ""), - } - - def delete_person_profile_override(self, person_id: str) -> bool: - """删除人物画像手工覆盖。""" - if not person_id: - return False - cursor = self._conn.cursor() - cursor.execute( - "DELETE FROM person_profile_overrides WHERE person_id = ?", - (str(person_id),), - ) - self._conn.commit() - return cursor.rowcount > 0 - - # ========================================================================= - # Episode MVP - # ========================================================================= - - def enqueue_episode_source_rebuild(self, source: str, reason: str = "") -> bool: - """将 source 入队到 episode 重建队列。""" - return bool(self._enqueue_episode_source_rebuilds([source], reason=reason)) - - def fetch_episode_source_rebuild_batch( - self, - limit: int = 20, - max_retry: int = 3, - ) -> List[Dict[str, Any]]: - """获取待处理的 source 重建任务。""" - safe_limit = max(1, int(limit)) - safe_retry = max(0, int(max_retry)) - cursor = self._conn.cursor() - cursor.execute( - """ - SELECT source, status, retry_count, last_error, reason, requested_at, updated_at - FROM episode_rebuild_sources - WHERE status = 'pending' - OR (status = 'failed' AND retry_count < ?) - ORDER BY requested_at ASC, updated_at ASC - LIMIT ? - """, - (safe_retry, safe_limit), - ) - return [dict(row) for row in cursor.fetchall()] - - def mark_episode_source_running( - self, - source: str, - *, - requested_at: Optional[float] = None, - ) -> bool: - """将 source 标记为 running。""" - token = self._normalize_episode_source(source) - if not token: - return False - - now = datetime.now().timestamp() - cursor = self._conn.cursor() - params: List[Any] = [now, token] - sql = """ - UPDATE episode_rebuild_sources - SET status = 'running', - updated_at = ? - WHERE source = ? - AND status IN ('pending', 'failed') - """ - if requested_at is not None: - sql += " AND requested_at = ?" - params.append(float(requested_at)) - cursor.execute(sql, tuple(params)) - self._conn.commit() - return cursor.rowcount > 0 - - def mark_episode_source_done( - self, - source: str, - *, - requested_at: Optional[float] = None, - ) -> bool: - """将 source 标记为 done;若运行期间发生新写入,则保持 pending。""" - token = self._normalize_episode_source(source) - if not token: - return False - - now = datetime.now().timestamp() - cursor = self._conn.cursor() - if requested_at is None: - cursor.execute( - """ - UPDATE episode_rebuild_sources - SET status = 'done', - last_error = NULL, - updated_at = ? - WHERE source = ? - """, - (now, token), - ) - else: - req_ts = float(requested_at) - cursor.execute( - """ - UPDATE episode_rebuild_sources - SET status = CASE - WHEN requested_at > ? THEN 'pending' - ELSE 'done' - END, - last_error = NULL, - updated_at = ? - WHERE source = ? 
- """, - (req_ts, now, token), - ) - self._conn.commit() - return cursor.rowcount > 0 - - def mark_episode_source_failed( - self, - source: str, - error: str = "", - *, - requested_at: Optional[float] = None, - ) -> bool: - """标记 source 失败;若运行期间发生新写入,则重新回到 pending。""" - token = self._normalize_episode_source(source) - if not token: - return False - - err_text = str(error or "").strip()[:500] - now = datetime.now().timestamp() - cursor = self._conn.cursor() - if requested_at is None: - cursor.execute( - """ - UPDATE episode_rebuild_sources - SET status = 'failed', - retry_count = COALESCE(retry_count, 0) + 1, - last_error = ?, - updated_at = ? - WHERE source = ? - """, - (err_text, now, token), - ) - else: - req_ts = float(requested_at) - cursor.execute( - """ - UPDATE episode_rebuild_sources - SET status = CASE - WHEN requested_at > ? THEN 'pending' - ELSE 'failed' - END, - retry_count = CASE - WHEN requested_at > ? THEN COALESCE(retry_count, 0) - ELSE COALESCE(retry_count, 0) + 1 - END, - last_error = CASE - WHEN requested_at > ? THEN NULL - ELSE ? - END, - updated_at = ? - WHERE source = ? - """, - (req_ts, req_ts, req_ts, err_text, now, token), - ) - self._conn.commit() - return cursor.rowcount > 0 - - def list_episode_source_rebuilds( - self, - *, - statuses: Optional[List[str]] = None, - limit: int = 100, - ) -> List[Dict[str, Any]]: - """列出 source 重建状态。""" - safe_limit = max(1, int(limit)) - params: List[Any] = [] - conditions: List[str] = [] - normalized_statuses = [ - str(item or "").strip().lower() - for item in (statuses or []) - if str(item or "").strip().lower() in {"pending", "running", "done", "failed"} - ] - if normalized_statuses: - placeholders = ",".join(["?"] * len(normalized_statuses)) - conditions.append(f"status IN ({placeholders})") - params.extend(normalized_statuses) - - where_sql = f"WHERE {' AND '.join(conditions)}" if conditions else "" - params.append(safe_limit) - cursor = self._conn.cursor() - cursor.execute( - f""" - SELECT source, status, retry_count, last_error, reason, requested_at, updated_at - FROM episode_rebuild_sources - {where_sql} - ORDER BY updated_at DESC, source ASC - LIMIT ? - """, - tuple(params), - ) - return [dict(row) for row in cursor.fetchall()] - - def get_episode_source_rebuild_summary(self, failed_limit: int = 20) -> Dict[str, Any]: - """汇总 source 重建队列状态。""" - cursor = self._conn.cursor() - cursor.execute( - """ - SELECT status, COUNT(*) AS cnt - FROM episode_rebuild_sources - GROUP BY status - """ - ) - counts = {"pending": 0, "running": 0, "done": 0, "failed": 0, "total": 0} - for row in cursor.fetchall(): - status = str(row["status"] or "").strip().lower() - cnt = int(row["cnt"] or 0) - counts[status] = counts.get(status, 0) + cnt - counts["total"] += cnt - - running = self.list_episode_source_rebuilds(statuses=["running"], limit=20) - failed = self.list_episode_source_rebuilds( - statuses=["failed"], - limit=max(1, int(failed_limit)), - ) - return { - "counts": counts, - "running": running, - "failed": failed, - } - - def get_live_paragraphs_by_source(self, source: str) -> List[Dict[str, Any]]: - """获取指定 source 下所有 live paragraphs。""" - token = self._normalize_episode_source(source) - if not token: - return [] - cursor = self._conn.cursor() - cursor.execute( - """ - SELECT * - FROM paragraphs - WHERE TRIM(COALESCE(source, '')) = ? 
- AND (is_deleted IS NULL OR is_deleted = 0) - ORDER BY created_at ASC, hash ASC - """, - (token,), - ) - return [self._row_to_dict(row, "paragraph") for row in cursor.fetchall()] - - def list_episode_sources_for_rebuild(self) -> List[str]: - """列出全量重建涉及的 source(live paragraphs + stale episodes)。""" - cursor = self._conn.cursor() - cursor.execute( - """ - SELECT DISTINCT source - FROM ( - SELECT TRIM(source) AS source - FROM paragraphs - WHERE TRIM(COALESCE(source, '')) != '' - AND (is_deleted IS NULL OR is_deleted = 0) - UNION - SELECT TRIM(source) AS source - FROM episodes - WHERE TRIM(COALESCE(source, '')) != '' - ) - WHERE TRIM(COALESCE(source, '')) != '' - ORDER BY source ASC - """ - ) - return self._dedupe_episode_sources([row["source"] for row in cursor.fetchall()]) - - def is_episode_source_query_blocked(self, source: str) -> bool: - """判断 source 是否处于重建中或失败状态。""" - token = self._normalize_episode_source(source) - if not token: - return False - cursor = self._conn.cursor() - cursor.execute( - """ - SELECT 1 - FROM episode_rebuild_sources - WHERE source = ? - AND status IN ('pending', 'running', 'failed') - LIMIT 1 - """, - (token,), - ) - return cursor.fetchone() is not None - - def replace_episodes_for_source( - self, - source: str, - episodes_payloads: List[Dict[str, Any]], - ) -> Dict[str, Any]: - """按 source 全量替换 episode 结果。""" - token = self._normalize_episode_source(source) - if not token: - return {"source": "", "episode_count": 0} - - payloads = [dict(item) for item in (episodes_payloads or []) if isinstance(item, dict)] - now = datetime.now().timestamp() - cursor = self._conn.cursor() - - try: - cursor.execute("BEGIN IMMEDIATE") - cursor.execute( - """ - SELECT episode_id, created_at - FROM episodes - WHERE TRIM(COALESCE(source, '')) = ? 
- """, - (token,), - ) - existing_created_at = { - str(row["episode_id"]): self._as_optional_float(row["created_at"]) - for row in cursor.fetchall() - } - - cursor.execute( - "DELETE FROM episodes WHERE TRIM(COALESCE(source, '')) = ?", - (token,), - ) - - inserted_count = 0 - for raw_payload in payloads: - title = str(raw_payload.get("title", "") or "").strip() - summary = str(raw_payload.get("summary", "") or "").strip() - evidence_ids = [ - str(item).strip() - for item in (raw_payload.get("evidence_ids") or []) - if str(item).strip() - ] - evidence_ids = list(dict.fromkeys(evidence_ids)) - if not title or not summary or not evidence_ids: - continue - - episode_id = str(raw_payload.get("episode_id", "") or "").strip() - if not episode_id: - seed = json.dumps( - { - "source": token, - "title": title, - "summary": summary, - "event_time_start": raw_payload.get("event_time_start"), - "event_time_end": raw_payload.get("event_time_end"), - "evidence_ids": evidence_ids, - }, - ensure_ascii=False, - sort_keys=True, - ) - episode_id = compute_hash(seed) - - participants = [ - str(item).strip() - for item in (raw_payload.get("participants") or []) - if str(item).strip() - ][:16] - keywords = [ - str(item).strip() - for item in (raw_payload.get("keywords") or []) - if str(item).strip() - ][:20] - paragraph_count = raw_payload.get("paragraph_count", len(evidence_ids)) - try: - paragraph_count = max(0, int(paragraph_count)) - except Exception: - paragraph_count = len(evidence_ids) - if paragraph_count <= 0: - paragraph_count = len(evidence_ids) - if paragraph_count <= 0: - continue - - time_confidence = raw_payload.get("time_confidence", 1.0) - llm_confidence = raw_payload.get("llm_confidence", 0.0) - try: - time_confidence = float(time_confidence) - except Exception: - time_confidence = 1.0 - try: - llm_confidence = float(llm_confidence) - except Exception: - llm_confidence = 0.0 - - created_at = existing_created_at.get(episode_id) - created_ts = created_at if created_at is not None else now - updated_ts = self._as_optional_float(raw_payload.get("updated_at")) or now - - cursor.execute( - """ - INSERT INTO episodes ( - episode_id, source, title, summary, - event_time_start, event_time_end, time_granularity, time_confidence, - participants_json, keywords_json, evidence_ids_json, - paragraph_count, llm_confidence, segmentation_model, segmentation_version, - created_at, updated_at - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, - ( - episode_id, - token, - title[:120], - summary[:2000], - self._as_optional_float(raw_payload.get("event_time_start")), - self._as_optional_float(raw_payload.get("event_time_end")), - str(raw_payload.get("time_granularity", "") or "").strip() or None, - time_confidence, - json.dumps(participants, ensure_ascii=False), - json.dumps(keywords, ensure_ascii=False), - json.dumps(evidence_ids, ensure_ascii=False), - paragraph_count, - llm_confidence, - str(raw_payload.get("segmentation_model", "") or "").strip() or None, - str(raw_payload.get("segmentation_version", "") or "").strip() or None, - created_ts, - updated_ts, - ), - ) - cursor.executemany( - """ - INSERT OR IGNORE INTO episode_paragraphs (episode_id, paragraph_hash, position) - VALUES (?, ?, ?) 
- """, - [(episode_id, hash_value, idx) for idx, hash_value in enumerate(evidence_ids)], - ) - inserted_count += 1 - - self._conn.commit() - return {"source": token, "episode_count": inserted_count} - except Exception: - self._conn.rollback() - raise - - def enqueue_episode_pending( - self, - paragraph_hash: str, - source: Optional[str] = None, - created_at: Optional[float] = None, - ) -> None: - """将段落入队到 episode 异步生成队列。""" - token = str(paragraph_hash or "").strip() - if not token: - return - now = datetime.now().timestamp() - created_ts = float(created_at) if created_at is not None else now - src = str(source or "").strip() or None - - cursor = self._conn.cursor() - cursor.execute( - """ - INSERT INTO episode_pending_paragraphs ( - paragraph_hash, source, created_at, status, retry_count, last_error, updated_at - ) VALUES (?, ?, ?, 'pending', 0, NULL, ?) - ON CONFLICT(paragraph_hash) DO UPDATE SET - source = excluded.source, - created_at = COALESCE(episode_pending_paragraphs.created_at, excluded.created_at), - status = CASE - WHEN episode_pending_paragraphs.status = 'done' THEN 'done' - ELSE 'pending' - END, - last_error = CASE - WHEN episode_pending_paragraphs.status = 'done' THEN episode_pending_paragraphs.last_error - ELSE NULL - END, - updated_at = excluded.updated_at - """, - (token, src, created_ts, now), - ) - self._conn.commit() - - def fetch_episode_pending_batch(self, limit: int = 20, max_retry: int = 3) -> List[Dict[str, Any]]: - """获取待处理 episode 队列批次。""" - safe_limit = max(1, int(limit)) - safe_retry = max(0, int(max_retry)) - cursor = self._conn.cursor() - cursor.execute( - """ - SELECT paragraph_hash, source, created_at, status, retry_count, last_error, updated_at - FROM episode_pending_paragraphs - WHERE status = 'pending' - OR (status = 'failed' AND retry_count < ?) - ORDER BY updated_at ASC - LIMIT ? - """, - (safe_retry, safe_limit), - ) - return [dict(row) for row in cursor.fetchall()] - - def mark_episode_pending_running(self, hashes: List[str]) -> None: - """批量标记队列项为 running。""" - if not hashes: - return - now = datetime.now().timestamp() - cursor = self._conn.cursor() - chunk_size = 500 - uniq = list(dict.fromkeys([str(h).strip() for h in hashes if str(h).strip()])) - for i in range(0, len(uniq), chunk_size): - chunk = uniq[i:i + chunk_size] - placeholders = ",".join(["?"] * len(chunk)) - cursor.execute( - f""" - UPDATE episode_pending_paragraphs - SET status = 'running', updated_at = ? - WHERE paragraph_hash IN ({placeholders}) - AND status IN ('pending', 'failed') - """, - [now] + chunk, - ) - self._conn.commit() - - def mark_episode_pending_done(self, hashes: List[str]) -> None: - """批量标记队列项为 done。""" - if not hashes: - return - now = datetime.now().timestamp() - cursor = self._conn.cursor() - chunk_size = 500 - uniq = list(dict.fromkeys([str(h).strip() for h in hashes if str(h).strip()])) - for i in range(0, len(uniq), chunk_size): - chunk = uniq[i:i + chunk_size] - placeholders = ",".join(["?"] * len(chunk)) - cursor.execute( - f""" - UPDATE episode_pending_paragraphs - SET status = 'done', - last_error = NULL, - updated_at = ? 
- WHERE paragraph_hash IN ({placeholders}) - """, - [now] + chunk, - ) - self._conn.commit() - - def mark_episode_pending_failed(self, hash_value: str, error: str = "") -> None: - """标记单条队列项失败并累加重试次数。""" - token = str(hash_value or "").strip() - if not token: - return - now = datetime.now().timestamp() - cursor = self._conn.cursor() - cursor.execute( - """ - UPDATE episode_pending_paragraphs - SET status = 'failed', - retry_count = COALESCE(retry_count, 0) + 1, - last_error = ?, - updated_at = ? - WHERE paragraph_hash = ? - """, - (str(error or ""), now, token), - ) - self._conn.commit() - - def get_episode_pending_status_counts(self, source: str) -> Dict[str, int]: - """统计某个 source 当前 pending 队列中的状态分布。""" - token = self._normalize_episode_source(source) - if not token: - return {"pending": 0, "running": 0, "failed": 0, "done": 0} - - cursor = self._conn.cursor() - cursor.execute( - """ - SELECT status, COUNT(*) AS count - FROM episode_pending_paragraphs - WHERE TRIM(COALESCE(source, '')) = ? - GROUP BY status - """, - (token,), - ) - counts = {"pending": 0, "running": 0, "failed": 0, "done": 0} - for row in cursor.fetchall(): - status = str(row["status"] or "").strip().lower() - if status in counts: - counts[status] = int(row["count"] or 0) - return counts - - def _episode_row_to_dict(self, row: sqlite3.Row) -> Dict[str, Any]: - data = dict(row) - - def _load_list(raw: Any) -> List[Any]: - if not raw: - return [] - try: - val = json.loads(raw) - return val if isinstance(val, list) else [] - except Exception: - return [] - - data["participants"] = _load_list(data.pop("participants_json", None)) - data["keywords"] = _load_list(data.pop("keywords_json", None)) - data["evidence_ids"] = _load_list(data.pop("evidence_ids_json", None)) - return data - - @staticmethod - def _as_optional_float(value: Any) -> Optional[float]: - if value is None: - return None - try: - return float(value) - except Exception: - return None - - def upsert_episode(self, payload: Dict[str, Any]) -> Dict[str, Any]: - """写入或更新 episode。""" - if not isinstance(payload, dict): - raise ValueError("payload 必须是字典") - - title = str(payload.get("title", "") or "").strip() - summary = str(payload.get("summary", "") or "").strip() - if not title: - raise ValueError("episode.title 不能为空") - if not summary: - raise ValueError("episode.summary 不能为空") - - source = str(payload.get("source", "") or "").strip() or None - participants_raw = payload.get("participants", []) or [] - keywords_raw = payload.get("keywords", []) or [] - evidence_ids_raw = payload.get("evidence_ids", []) or [] - participants = [str(x).strip() for x in participants_raw if str(x).strip()] - keywords = [str(x).strip() for x in keywords_raw if str(x).strip()] - evidence_ids = [str(x).strip() for x in evidence_ids_raw if str(x).strip()] - - now = datetime.now().timestamp() - created_at = self._as_optional_float(payload.get("created_at")) - updated_at = self._as_optional_float(payload.get("updated_at")) - created_ts = created_at if created_at is not None else now - updated_ts = updated_at if updated_at is not None else now - - episode_id = str(payload.get("episode_id", "") or "").strip() - if not episode_id: - seed = json.dumps( - { - "source": source, - "title": title, - "summary": summary, - "event_time_start": payload.get("event_time_start"), - "event_time_end": payload.get("event_time_end"), - "evidence_ids": evidence_ids, - }, - ensure_ascii=False, - sort_keys=True, - ) - episode_id = compute_hash(seed) - - paragraph_count = payload.get("paragraph_count") - if 
paragraph_count is None: - paragraph_count = len(evidence_ids) - try: - paragraph_count = int(paragraph_count) - except Exception: - paragraph_count = len(evidence_ids) - - time_conf = payload.get("time_confidence", 1.0) - llm_conf = payload.get("llm_confidence", 0.0) - try: - time_conf = float(time_conf) - except Exception: - time_conf = 1.0 - try: - llm_conf = float(llm_conf) - except Exception: - llm_conf = 0.0 - - cursor = self._conn.cursor() - cursor.execute( - "SELECT created_at FROM episodes WHERE episode_id = ? LIMIT 1", - (episode_id,), - ) - existed = cursor.fetchone() - if existed and existed[0] is not None: - created_ts = float(existed[0]) - - cursor.execute( - """ - INSERT INTO episodes ( - episode_id, source, title, summary, - event_time_start, event_time_end, time_granularity, time_confidence, - participants_json, keywords_json, evidence_ids_json, - paragraph_count, llm_confidence, segmentation_model, segmentation_version, - created_at, updated_at - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - ON CONFLICT(episode_id) DO UPDATE SET - source = excluded.source, - title = excluded.title, - summary = excluded.summary, - event_time_start = excluded.event_time_start, - event_time_end = excluded.event_time_end, - time_granularity = excluded.time_granularity, - time_confidence = excluded.time_confidence, - participants_json = excluded.participants_json, - keywords_json = excluded.keywords_json, - evidence_ids_json = excluded.evidence_ids_json, - paragraph_count = excluded.paragraph_count, - llm_confidence = excluded.llm_confidence, - segmentation_model = excluded.segmentation_model, - segmentation_version = excluded.segmentation_version, - updated_at = excluded.updated_at - """, - ( - episode_id, - source, - title, - summary, - self._as_optional_float(payload.get("event_time_start")), - self._as_optional_float(payload.get("event_time_end")), - str(payload.get("time_granularity", "") or "").strip() or None, - time_conf, - json.dumps(participants, ensure_ascii=False), - json.dumps(keywords, ensure_ascii=False), - json.dumps(evidence_ids, ensure_ascii=False), - max(0, paragraph_count), - llm_conf, - str(payload.get("segmentation_model", "") or "").strip() or None, - str(payload.get("segmentation_version", "") or "").strip() or None, - created_ts, - updated_ts, - ), - ) - self._conn.commit() - return self.get_episode_by_id(episode_id) or {"episode_id": episode_id} - - def bind_episode_paragraphs(self, episode_id: str, paragraph_hashes_ordered: List[str]) -> int: - """重建 episode 与段落映射。""" - token = str(episode_id or "").strip() - if not token: - raise ValueError("episode_id 不能为空") - - normalized: List[str] = [] - seen = set() - for item in paragraph_hashes_ordered or []: - h = str(item or "").strip() - if not h or h in seen: - continue - seen.add(h) - normalized.append(h) - - cursor = self._conn.cursor() - cursor.execute("DELETE FROM episode_paragraphs WHERE episode_id = ?", (token,)) - - if normalized: - cursor.executemany( - """ - INSERT OR IGNORE INTO episode_paragraphs (episode_id, paragraph_hash, position) - VALUES (?, ?, ?) - """, - [(token, h, idx) for idx, h in enumerate(normalized)], - ) - - now = datetime.now().timestamp() - cursor.execute( - """ - UPDATE episodes - SET paragraph_count = ?, updated_at = ? - WHERE episode_id = ? 
- """, - (len(normalized), now, token), - ) - self._conn.commit() - return len(normalized) - - def _build_episode_query_components( - self, - *, - time_from: Optional[float] = None, - time_to: Optional[float] = None, - person: Optional[str] = None, - source: Optional[str] = None, - ) -> Tuple[str, str, str, List[str], List[Any]]: - source_expr = "TRIM(COALESCE(e.source, ''))" - effective_start = "COALESCE(e.event_time_start, e.event_time_end, e.updated_at)" - effective_end = "COALESCE(e.event_time_end, e.event_time_start, e.updated_at)" - conditions: List[str] = [] - params: List[Any] = [] - - conditions.append(f"{source_expr} != ''") - conditions.append("COALESCE(e.paragraph_count, 0) > 0") - conditions.append( - """ - NOT EXISTS ( - SELECT 1 - FROM episode_rebuild_sources ers - WHERE ers.source = TRIM(COALESCE(e.source, '')) - AND ers.status IN ('pending', 'running') - ) - """ - ) - - if source: - token = self._normalize_episode_source(source) - if not token: - return source_expr, effective_start, effective_end, ["1 = 0"], [] - conditions.append(f"{source_expr} = ?") - params.append(token) - - p = str(person or "").strip().lower() - if p: - like_person = f"%{p}%" - conditions.append( - """ - ( - LOWER(COALESCE(e.participants_json, '')) LIKE ? - OR EXISTS ( - SELECT 1 - FROM episode_paragraphs ep_person - JOIN paragraph_entities pe ON pe.paragraph_hash = ep_person.paragraph_hash - JOIN entities en ON en.hash = pe.entity_hash - WHERE ep_person.episode_id = e.episode_id - AND LOWER(en.name) LIKE ? - ) - ) - """ - ) - params.extend([like_person, like_person]) - - if time_from is not None and time_to is not None: - conditions.append(f"({effective_end} >= ? AND {effective_start} <= ?)") - params.extend([float(time_from), float(time_to)]) - elif time_from is not None: - conditions.append(f"({effective_end} >= ?)") - params.append(float(time_from)) - elif time_to is not None: - conditions.append(f"({effective_start} <= ?)") - params.append(float(time_to)) - - return source_expr, effective_start, effective_end, conditions, params - - @staticmethod - def _tokenize_episode_query(query: str) -> Tuple[str, List[str]]: - """将 episode 查询归一化为短语和 token。""" - normalized = normalize_text(str(query or "")).strip().lower() - if not normalized: - return "", [] - - token_pattern = re.compile(r"[A-Za-z0-9_\u4e00-\u9fff]{2,}") - tokens: List[str] = [] - seen = set() - for token in token_pattern.findall(normalized): - if token in seen: - continue - seen.add(token) - tokens.append(token) - - if not tokens and len(normalized) >= 2: - tokens = [normalized] - return normalized, tokens - - def get_episode_rows_by_paragraph_hashes( - self, - paragraph_hashes: List[str], - *, - time_from: Optional[float] = None, - time_to: Optional[float] = None, - person: Optional[str] = None, - source: Optional[str] = None, - ) -> List[Dict[str, Any]]: - normalized: List[str] = [] - seen = set() - for item in paragraph_hashes or []: - token = str(item or "").strip() - if not token or token in seen: - continue - seen.add(token) - normalized.append(token) - if not normalized: - return [] - - _, _, _, conditions, params = self._build_episode_query_components( - time_from=time_from, - time_to=time_to, - person=person, - source=source, - ) - placeholders = ",".join(["?"] * len(normalized)) - conditions.append(f"ep.paragraph_hash IN ({placeholders})") - conditions.append("(p.is_deleted IS NULL OR p.is_deleted = 0)") - where_sql = "WHERE " + " AND ".join(conditions) - - sql = f""" - SELECT e.*, ep.paragraph_hash AS matched_paragraph_hash - 
FROM episodes e - JOIN episode_paragraphs ep ON ep.episode_id = e.episode_id - JOIN paragraphs p ON p.hash = ep.paragraph_hash - {where_sql} - ORDER BY e.updated_at DESC - """ - cursor = self._conn.cursor() - cursor.execute(sql, tuple(params + normalized)) - - grouped: Dict[str, Dict[str, Any]] = {} - for row in cursor.fetchall(): - episode_id = str(row["episode_id"] or "").strip() - if not episode_id: - continue - payload = grouped.get(episode_id) - if payload is None: - payload = self._episode_row_to_dict(row) - payload["matched_paragraph_hashes"] = [] - grouped[episode_id] = payload - matched_hash = str(row["matched_paragraph_hash"] or "").strip() - if matched_hash and matched_hash not in payload["matched_paragraph_hashes"]: - payload["matched_paragraph_hashes"].append(matched_hash) - - out = list(grouped.values()) - for item in out: - item["matched_paragraph_count"] = len(item.get("matched_paragraph_hashes", [])) - return out - - def get_episode_rows_by_relation_hashes( - self, - relation_hashes: List[str], - *, - time_from: Optional[float] = None, - time_to: Optional[float] = None, - person: Optional[str] = None, - source: Optional[str] = None, - ) -> List[Dict[str, Any]]: - normalized: List[str] = [] - seen = set() - for item in relation_hashes or []: - token = str(item or "").strip() - if not token or token in seen: - continue - seen.add(token) - normalized.append(token) - if not normalized: - return [] - - _, _, _, conditions, params = self._build_episode_query_components( - time_from=time_from, - time_to=time_to, - person=person, - source=source, - ) - placeholders = ",".join(["?"] * len(normalized)) - conditions.append(f"pr.relation_hash IN ({placeholders})") - conditions.append("(p.is_deleted IS NULL OR p.is_deleted = 0)") - where_sql = "WHERE " + " AND ".join(conditions) - - sql = f""" - SELECT - e.*, - p.hash AS matched_paragraph_hash, - pr.relation_hash AS matched_relation_hash - FROM episodes e - JOIN episode_paragraphs ep ON ep.episode_id = e.episode_id - JOIN paragraphs p ON p.hash = ep.paragraph_hash - JOIN paragraph_relations pr ON pr.paragraph_hash = p.hash - {where_sql} - ORDER BY e.updated_at DESC - """ - cursor = self._conn.cursor() - cursor.execute(sql, tuple(params + normalized)) - - grouped: Dict[str, Dict[str, Any]] = {} - for row in cursor.fetchall(): - episode_id = str(row["episode_id"] or "").strip() - if not episode_id: - continue - payload = grouped.get(episode_id) - if payload is None: - payload = self._episode_row_to_dict(row) - payload["matched_paragraph_hashes"] = [] - payload["matched_relation_hashes"] = [] - grouped[episode_id] = payload - matched_paragraph = str(row["matched_paragraph_hash"] or "").strip() - matched_relation = str(row["matched_relation_hash"] or "").strip() - if matched_paragraph and matched_paragraph not in payload["matched_paragraph_hashes"]: - payload["matched_paragraph_hashes"].append(matched_paragraph) - if matched_relation and matched_relation not in payload["matched_relation_hashes"]: - payload["matched_relation_hashes"].append(matched_relation) - - out = list(grouped.values()) - for item in out: - item["matched_paragraph_count"] = len(item.get("matched_paragraph_hashes", [])) - item["matched_relation_count"] = len(item.get("matched_relation_hashes", [])) - return out - - def query_episodes( - self, - query: str = "", - time_from: Optional[float] = None, - time_to: Optional[float] = None, - person: Optional[str] = None, - source: Optional[str] = None, - limit: int = 20, - ) -> List[Dict[str, Any]]: - """查询 episode 列表。""" - 
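# A minimal standalone sketch of the field-weighted LIKE scoring that the SQL
# below assembles. The weights mirror the CASE WHEN constants in this method;
# the helper name and the dict-shaped row are illustrative assumptions only.
def sketch_lexical_score(row: dict, phrase: str, tokens: list) -> float:
    phrase_weights = {"title": 6.0, "keywords": 4.5, "summary": 3.0, "participants": 2.0}
    token_weights = {"title": 3.0, "keywords": 2.5, "summary": 2.0, "participants": 1.5}
    score = 0.0
    for field, weight in phrase_weights.items():
        if phrase and phrase in str(row.get(field, "")).lower():
            score += weight  # whole-phrase hit on this field
    for token in tokens:
        hit_any = False
        for field, weight in token_weights.items():
            if token in str(row.get(field, "")).lower():
                score += weight
                hit_any = True
        if hit_any:
            score += 2.0  # flat bonus when a token matches any field
    return score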
safe_limit = max(1, int(limit)) - _, effective_start, effective_end, conditions, params = self._build_episode_query_components( - time_from=time_from, - time_to=time_to, - person=person, - source=source, - ) - - q, tokens = self._tokenize_episode_query(query) - select_score_sql = "0.0 AS lexical_score" - order_sql = f"{effective_end} DESC, e.updated_at DESC" - select_params: List[Any] = [] - query_params: List[Any] = [] - if q: - field_exprs = { - "title": "LOWER(COALESCE(e.title, ''))", - "summary": "LOWER(COALESCE(e.summary, ''))", - "keywords": "LOWER(COALESCE(e.keywords_json, ''))", - "participants": "LOWER(COALESCE(e.participants_json, ''))", - } - - score_parts: List[str] = [] - phrase_like = f"%{q}%" - score_parts.extend( - [ - f"CASE WHEN {field_exprs['title']} LIKE ? THEN 6.0 ELSE 0.0 END", - f"CASE WHEN {field_exprs['keywords']} LIKE ? THEN 4.5 ELSE 0.0 END", - f"CASE WHEN {field_exprs['summary']} LIKE ? THEN 3.0 ELSE 0.0 END", - f"CASE WHEN {field_exprs['participants']} LIKE ? THEN 2.0 ELSE 0.0 END", - ] - ) - select_params.extend([phrase_like, phrase_like, phrase_like, phrase_like]) - - token_predicates: List[str] = [] - for token in tokens: - like = f"%{token}%" - token_any = ( - f"({field_exprs['title']} LIKE ? OR " - f"{field_exprs['summary']} LIKE ? OR " - f"{field_exprs['keywords']} LIKE ? OR " - f"{field_exprs['participants']} LIKE ?)" - ) - token_predicates.append(token_any) - query_params.extend([like, like, like, like]) - - score_parts.append( - "(" - f"CASE WHEN {field_exprs['title']} LIKE ? THEN 3.0 ELSE 0.0 END + " - f"CASE WHEN {field_exprs['keywords']} LIKE ? THEN 2.5 ELSE 0.0 END + " - f"CASE WHEN {field_exprs['summary']} LIKE ? THEN 2.0 ELSE 0.0 END + " - f"CASE WHEN {field_exprs['participants']} LIKE ? THEN 1.5 ELSE 0.0 END + " - f"CASE WHEN {token_any} THEN 2.0 ELSE 0.0 END" - ")" - ) - select_params.extend([like, like, like, like, like, like, like, like]) - - if token_predicates: - conditions.append("(" + " OR ".join(token_predicates) + ")") - - select_score_sql = f"({' + '.join(score_parts)}) AS lexical_score" - order_sql = f"lexical_score DESC, {effective_end} DESC, e.updated_at DESC" - - where_sql = ("WHERE " + " AND ".join(conditions)) if conditions else "" - sql = f""" - SELECT e.*, {select_score_sql} - FROM episodes e - {where_sql} - ORDER BY {order_sql} - LIMIT ? - """ - final_params = list(select_params) + list(params) + list(query_params) + [safe_limit] - - cursor = self._conn.cursor() - cursor.execute(sql, tuple(final_params)) - return [self._episode_row_to_dict(row) for row in cursor.fetchall()] - - def get_episode_by_id(self, episode_id: str) -> Optional[Dict[str, Any]]: - """Fetch a single episode.""" - token = str(episode_id or "").strip() - if not token: - return None - cursor = self._conn.cursor() - cursor.execute( - "SELECT * FROM episodes WHERE episode_id = ? LIMIT 1", - (token,), - ) - row = cursor.fetchone() - if not row: - return None - return self._episode_row_to_dict(row) - - def get_episode_paragraphs(self, episode_id: str, limit: int = 100) -> List[Dict[str, Any]]: - """Fetch the paragraphs bound to an episode, ordered by position.""" - token = str(episode_id or "").strip() - if not token: - return [] - safe_limit = max(1, int(limit)) - cursor = self._conn.cursor() - cursor.execute( - """ - SELECT p.*, ep.position - FROM episode_paragraphs ep - JOIN paragraphs p ON p.hash = ep.paragraph_hash - WHERE ep.episode_id = ? - AND (p.is_deleted IS NULL OR p.is_deleted = 0) - ORDER BY ep.position ASC - LIMIT ?
- """, - (token, safe_limit), - ) - items = [] - for row in cursor.fetchall(): - payload = self._row_to_dict(row, "paragraph") - payload["position"] = row["position"] - items.append(payload) - return items - - def has_table(self, table_name: str) -> bool: - """检查数据库是否存在指定表。""" - if not self._conn: - return False - cursor = self._conn.cursor() - cursor.execute( - "SELECT 1 FROM sqlite_master WHERE type='table' AND name = ? LIMIT 1", - (table_name,), - ) - return cursor.fetchone() is not None - - def get_deleted_entities(self, limit: int = 50) -> List[Dict[str, Any]]: - """获取已软删除的实体 (回收站用)""" - if not self.has_table("entities"): return [] - - cursor = self._conn.cursor() - cursor.execute(""" - SELECT hash, name, deleted_at - FROM entities - WHERE is_deleted = 1 - ORDER BY deleted_at DESC - LIMIT ? - """, (limit,)) - - items = [] - for row in cursor.fetchall(): - items.append({ - "hash": row[0], - "name": row[1], - "type": "entity", # 标记为实体 - "deleted_at": row[2] - }) - return items - - def __repr__(self) -> str: - stats = self.get_statistics() if self.is_connected else {} - return ( - f"MetadataStore(paragraphs={stats.get('paragraph_count', 0)}, " - f"entities={stats.get('entity_count', 0)}, " - f"relations={stats.get('relation_count', 0)})" - ) - - def has_data(self) -> bool: - """检查磁盘上是否存在现有数据""" - if self.data_dir is None: - return False - return (self.data_dir / self.db_name).exists() diff --git a/plugins/A_memorix/core/storage/type_detection.py b/plugins/A_memorix/core/storage/type_detection.py deleted file mode 100644 index c20d2cb4..00000000 --- a/plugins/A_memorix/core/storage/type_detection.py +++ /dev/null @@ -1,137 +0,0 @@ -"""Heuristic detection for import strategies and stored knowledge types.""" - -from __future__ import annotations - -import re -from typing import Optional - -from .knowledge_types import ( - ImportStrategy, - KnowledgeType, - parse_import_strategy, - resolve_stored_knowledge_type, -) - - -_NARRATIVE_MARKERS = [ - r"然后", - r"接着", - r"于是", - r"后来", - r"最后", - r"突然", - r"一天", - r"曾经", - r"有一次", - r"从前", - r"说道", - r"问道", - r"想着", - r"觉得", -] -_FACTUAL_MARKERS = [ - r"是", - r"有", - r"在", - r"为", - r"属于", - r"位于", - r"包含", - r"拥有", - r"成立于", - r"出生于", -] - - -def _non_empty_lines(content: str) -> list[str]: - return [line for line in str(content or "").splitlines() if line.strip()] - - -def looks_like_structured_text(content: str) -> bool: - text = str(content or "").strip() - if "|" not in text or text.count("|") < 2: - return False - parts = text.split("|") - return len(parts) == 3 and all(part.strip() for part in parts) - - -def looks_like_quote_text(content: str) -> bool: - lines = _non_empty_lines(content) - if len(lines) < 5: - return False - avg_len = sum(len(line) for line in lines) / len(lines) - return avg_len < 20 - - -def looks_like_narrative_text(content: str) -> bool: - text = str(content or "").strip() - if not text: - return False - - narrative_score = sum(1 for marker in _NARRATIVE_MARKERS if re.search(marker, text)) - has_dialogue = bool(re.search(r'["「『].*?["」』]', text)) - has_chapter = any(token in text[:500] for token in ("Chapter", "CHAPTER", "###")) - return has_chapter or has_dialogue or narrative_score >= 2 - - -def looks_like_factual_text(content: str) -> bool: - text = str(content or "").strip() - if not text: - return False - if looks_like_structured_text(text) or looks_like_quote_text(text): - return False - - factual_score = sum(1 for marker in _FACTUAL_MARKERS if re.search(r"\s*" + marker + r"\s*", text)) - if factual_score <= 0: - 
return False - - if len(text) <= 240: - return True - return factual_score >= 2 and not looks_like_narrative_text(text) - - -def select_import_strategy( - content: str, - *, - override: Optional[str | ImportStrategy] = None, - chat_log: bool = False, -) -> ImportStrategy: - """文本导入策略选择:override > quote > factual > narrative。""" - - if chat_log: - return ImportStrategy.NARRATIVE - - strategy = parse_import_strategy(override, default=ImportStrategy.AUTO) - if strategy != ImportStrategy.AUTO: - return strategy - - if looks_like_quote_text(content): - return ImportStrategy.QUOTE - if looks_like_factual_text(content): - return ImportStrategy.FACTUAL - return ImportStrategy.NARRATIVE - - -def detect_knowledge_type(content: str) -> KnowledgeType: - """自动检测落库 knowledge_type;无法可靠判断时回退 mixed。""" - - text = str(content or "").strip() - if not text: - return KnowledgeType.MIXED - if looks_like_structured_text(text): - return KnowledgeType.STRUCTURED - if looks_like_quote_text(text): - return KnowledgeType.QUOTE - if looks_like_factual_text(text): - return KnowledgeType.FACTUAL - if looks_like_narrative_text(text): - return KnowledgeType.NARRATIVE - return KnowledgeType.MIXED - - -def get_type_from_user_input(type_hint: Optional[str], content: str) -> KnowledgeType: - """优先使用显式 type_hint,否则自动检测。""" - - if type_hint: - return resolve_stored_knowledge_type(type_hint, content=content) - return detect_knowledge_type(content) diff --git a/plugins/A_memorix/core/storage/vector_store.py b/plugins/A_memorix/core/storage/vector_store.py deleted file mode 100644 index 97a9144c..00000000 --- a/plugins/A_memorix/core/storage/vector_store.py +++ /dev/null @@ -1,776 +0,0 @@ -""" -向量存储模块 - -基于Faiss的高效向量存储与检索,支持SQ8量化、Append-Only磁盘存储和内存映射。 -""" - -import os -import pickle -import hashlib -import shutil -import time -from pathlib import Path -from typing import Optional, Union, Tuple, List, Dict, Set, Any -import random -import threading # Added threading import - -import numpy as np - -try: - import faiss - HAS_FAISS = True -except ImportError: - HAS_FAISS = False - -from src.common.logger import get_logger -from ..utils.quantization import QuantizationType -from ..utils.io import atomic_write, atomic_save_path - -logger = get_logger("A_Memorix.VectorStore") - - -class VectorStore: - """ - 向量存储类 (SQ8 + Append-Only Disk) - - 特性: - - 索引: IndexIDMap2(IndexScalarQuantizer(QT_8bit)) - - 存储: float16 on-disk binary (vectors.bin) - - 内存: 仅索引常驻 RAM (<512MB for 100k vectors) - - ID: SHA1-based stable int64 IDs - - 一致性: 强制 L2 Normalization (IP == Cosine) - """ - - # 默认训练触发阈值 (40 样本,过大可能导致小数据集不生效,过小可能量化退化) - DEFAULT_MIN_TRAIN = 40 - # 强制训练样本量 - TRAIN_SIZE = 10000 - # 储水池采样上限 (流式处理前 50k 数据) - RESERVOIR_CAPACITY = 10000 - RESERVOIR_SAMPLE_SCOPE = 50000 - - def __init__( - self, - dimension: int, - quantization_type: QuantizationType = QuantizationType.INT8, - index_type: str = "sq8", - data_dir: Optional[Union[str, Path]] = None, - use_mmap: bool = True, - buffer_size: int = 1024, - ): - if not HAS_FAISS: - raise ImportError("Faiss 未安装,请安装: pip install faiss-cpu") - - self.dimension = dimension - self.data_dir = Path(data_dir) if data_dir else None - if self.data_dir: - self.data_dir.mkdir(parents=True, exist_ok=True) - if quantization_type != QuantizationType.INT8: - raise ValueError( - "vNext 仅支持 quantization_type=int8(SQ8)。" - " 请更新配置并执行 scripts/release_vnext_migrate.py migrate。" - ) - normalized_index_type = str(index_type or "sq8").strip().lower() - if normalized_index_type not in {"sq8", "int8"}: - raise ValueError( - "vNext 
仅支持 index_type=sq8。" - " 请更新配置并执行 scripts/release_vnext_migrate.py migrate。" - ) - self.quantization_type = QuantizationType.INT8 - self.index_type = "sq8" - self.buffer_size = buffer_size - - self._index: Optional[faiss.IndexIDMap2] = None - self._init_index() - - self._is_trained = False - self._vector_norm = "l2" - - # Fallback Index (Flat) - 用于在 SQ8 训练完成前提供检索能力 - # 必须使用 IndexIDMap2 以保证 ID 与主索引一致 - self._fallback_index: Optional[faiss.IndexIDMap2] = None - self._init_fallback_index() - - self._known_hashes: Set[str] = set() - self._deleted_ids: Set[int] = set() - - self._reservoir_buffer: List[np.ndarray] = [] - self._seen_count_for_reservoir = 0 - - self._write_buffer_vecs: List[np.ndarray] = [] - self._write_buffer_ids: List[int] = [] - - self._total_added = 0 - self._total_deleted = 0 - self._bin_count = 0 - - # Thread safety lock - self._lock = threading.RLock() - - logger.info(f"VectorStore Init: dim={dimension}, SQ8 Mode, Append-Only Storage") - - def _init_index(self): - """初始化空的 Faiss 索引""" - quantizer = faiss.IndexScalarQuantizer( - self.dimension, - faiss.ScalarQuantizer.QT_8bit, - faiss.METRIC_INNER_PRODUCT - ) - self._index = faiss.IndexIDMap2(quantizer) - self._is_trained = False - - def _init_fallback_index(self): - """初始化 Flat 回退索引""" - flat_index = faiss.IndexFlatIP(self.dimension) - self._fallback_index = faiss.IndexIDMap2(flat_index) - logger.debug("Fallback index (Flat) initialized.") - - @staticmethod - def _generate_id(key: str) -> int: - """生成稳定的 int64 ID (SHA1 截断)""" - h = hashlib.sha1(key.encode("utf-8")).digest() - val = int.from_bytes(h[:8], byteorder="big", signed=False) - return val & 0x7FFFFFFFFFFFFFFF - - @property - def _bin_path(self) -> Path: - return self.data_dir / "vectors.bin" - - @property - def _ids_bin_path(self) -> Path: - return self.data_dir / "vectors_ids.bin" - - @property - def _int_to_str_map(self) -> Dict[int, str]: - """Lazy build volatile map from known hashes""" - # Note: This is read-heavy and cached, might need lock if _known_hashes updates concurrently - # But add/delete are now locked, so checking len mismatch is somewhat safe-ish for quick dirty cache - if not hasattr(self, "_cached_map") or len(self._cached_map) != len(self._known_hashes): - with self._lock: # Protect cache rebuild - self._cached_map = {self._generate_id(k): k for k in self._known_hashes} - return self._cached_map - - def add(self, vectors: np.ndarray, ids: List[str]) -> int: - with self._lock: - if vectors.shape[1] != self.dimension: - raise ValueError(f"Dimension mismatch: {vectors.shape[1]} vs {self.dimension}") - - vectors = np.ascontiguousarray(vectors, dtype=np.float32) - faiss.normalize_L2(vectors) - - processed_vecs = [] - processed_int_ids = [] - - for i, str_id in enumerate(ids): - if str_id in self._known_hashes: - continue - - int_id = self._generate_id(str_id) - self._known_hashes.add(str_id) - - processed_vecs.append(vectors[i]) - processed_int_ids.append(int_id) - - if not processed_vecs: - return 0 - - batch_vecs = np.array(processed_vecs, dtype=np.float32) - batch_ids = np.array(processed_int_ids, dtype=np.int64) - - self._write_buffer_vecs.append(batch_vecs) - self._write_buffer_ids.extend(processed_int_ids) - - if len(self._write_buffer_ids) >= self.buffer_size: - self._flush_write_buffer_unlocked() - - if not self._is_trained: - # 双写到回退索引 - self._fallback_index.add_with_ids(batch_vecs, batch_ids) - - self._update_reservoir(batch_vecs) - # 这里的 TRAIN_SIZE 取默认 10k,或者根据当前数据量动态判断 - if len(self._reservoir_buffer) >= 10000: - 
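# A self-contained sketch of how add() derives its stable int64 IDs (see
# _generate_id above): SHA1 the string key, take the first 8 bytes as an
# unsigned big-endian integer, and clear the sign bit so the value fits
# Faiss's signed int64 ID space. The helper name is illustrative only.
import hashlib

def sketch_stable_id(key: str) -> int:
    digest = hashlib.sha1(key.encode("utf-8")).digest()
    value = int.from_bytes(digest[:8], byteorder="big", signed=False)
    return value & 0x7FFFFFFFFFFFFFFF  # non-negative, valid as a Faiss int64 id
# The same key always yields the same ID, so duplicate adds are no-ops and
# deletions can be recorded as int64 tombstones.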
logger.info(f"训练样本达到上限,开始训练...") - self._train_and_replay_unlocked() - - self._total_added += len(batch_ids) - return len(batch_ids) - - def _flush_write_buffer(self): - with self._lock: - self._flush_write_buffer_unlocked() - - def _flush_write_buffer_unlocked(self): - if not self._write_buffer_vecs: - return - - batch_vecs = np.concatenate(self._write_buffer_vecs, axis=0) - batch_ids = np.array(self._write_buffer_ids, dtype=np.int64) - - vecs_fp16 = batch_vecs.astype(np.float16) - - with open(self._bin_path, "ab") as f: - f.write(vecs_fp16.tobytes()) - - ids_bytes = batch_ids.astype('>i8').tobytes() - with open(self._ids_bin_path, "ab") as f: - f.write(ids_bytes) - - self._bin_count += len(batch_ids) - - if self._is_trained and self._index.is_trained: - self._index.add_with_ids(batch_vecs, batch_ids) - else: - # 即使在 flush 时,如果未训练,也要同步到 fallback - self._fallback_index.add_with_ids(batch_vecs, batch_ids) - - self._write_buffer_vecs.clear() - self._write_buffer_ids.clear() - - def _update_reservoir(self, vectors: np.ndarray): - for vec in vectors: - self._seen_count_for_reservoir += 1 - if len(self._reservoir_buffer) < self.RESERVOIR_CAPACITY: - self._reservoir_buffer.append(vec) - else: - if self._seen_count_for_reservoir <= self.RESERVOIR_SAMPLE_SCOPE: - r = random.randint(0, self._seen_count_for_reservoir - 1) - if r < self.RESERVOIR_CAPACITY: - self._reservoir_buffer[r] = vec - - def _train_and_replay(self): - with self._lock: - self._train_and_replay_unlocked() - - def _train_and_replay_unlocked(self): - if not self._reservoir_buffer: - logger.warning("No training data available.") - return - - train_data = np.array(self._reservoir_buffer, dtype=np.float32) - logger.info(f"Training Index with {len(train_data)} samples...") - - try: - self._index.train(train_data) - except Exception as e: - logger.error(f"SQ8 Training failed: {e}. Staying in fallback mode.") - return - - self._is_trained = True - self._reservoir_buffer = [] - - logger.info("Replaying data from disk to populate index...") - try: - replay_count = self._replay_vectors_to_index() - # 只有当 replay 成功且数据量一致时,才释放回退索引 - if self._index.ntotal >= self._bin_count: - logger.info(f"Replay successful ({self._index.ntotal}/{self._bin_count}). Releasing fallback index.") - self._fallback_index.reset() - else: - logger.warning(f"Replay count mismatch: {self._index.ntotal} vs {self._bin_count}. Keeping fallback index.") - except Exception as e: - logger.error(f"Replay failed: {e}. 
Keeping fallback index as backup.") - - def _replay_vectors_to_index(self) -> int: - """从 vectors.bin 读取并添加到 index""" - if not self._bin_path.exists() or not self._ids_bin_path.exists(): - return 0 - - vec_item_size = self.dimension * 2 - id_item_size = 8 - chunk_size = 10000 - - with open(self._bin_path, "rb") as f_vec, open(self._ids_bin_path, "rb") as f_id: - while True: - vec_data = f_vec.read(chunk_size * vec_item_size) - id_data = f_id.read(chunk_size * id_item_size) - - if not vec_data: - break - - batch_fp16 = np.frombuffer(vec_data, dtype=np.float16).reshape(-1, self.dimension) - batch_fp32 = batch_fp16.astype(np.float32) - faiss.normalize_L2(batch_fp32) - - batch_ids = np.frombuffer(id_data, dtype='>i8').astype(np.int64) - - valid_mask = [id_ not in self._deleted_ids for id_ in batch_ids] - if not all(valid_mask): - batch_fp32 = batch_fp32[valid_mask] - batch_ids = batch_ids[valid_mask] - - if len(batch_ids) > 0: - self._index.add_with_ids(batch_fp32, batch_ids) - - def search( - self, - query: np.ndarray, - k: int = 10, - filter_deleted: bool = True, - ) -> Tuple[List[str], List[float]]: - query_local = np.array(query, dtype=np.float32, order="C", copy=True) - if query_local.ndim == 1: - got_dim = int(query_local.shape[0]) - query_local = query_local.reshape(1, -1) - elif query_local.ndim == 2: - if query_local.shape[0] != 1: - raise ValueError( - f"query embedding must have shape (D,) or (1, D), got {tuple(query_local.shape)}" - ) - got_dim = int(query_local.shape[1]) - else: - raise ValueError( - f"query embedding must have shape (D,) or (1, D), got {tuple(query_local.shape)}" - ) - - if got_dim != self.dimension: - raise ValueError( - f"query embedding dimension mismatch: expected={self.dimension} got={got_dim}" - ) - if not np.all(np.isfinite(query_local)): - raise ValueError("query embedding contains non-finite values") - - faiss.normalize_L2(query_local) - - # 查询路径仅负责检索,不在此触发训练/回放。 - # 训练/回放前置到 warmup_index(),并由插件启动阶段触发。 - # Faiss 索引在并发 search 下可能出现阻塞,这里串行化检索调用保证稳定性。 - with self._lock: - self._flush_write_buffer_unlocked() - search_index = self._index if (self._is_trained and self._index.ntotal > 0) else self._fallback_index - if search_index.ntotal == 0: - logger.warning("Indices are empty. 
No data to search.") - return [], [] - # 执行检索 - dists, ids = search_index.search(query_local, k * 2) - - # Faiss search 返回的是 (1, K) 的数组,取第一行 - dists = dists[0] - ids = ids[0] - - results = [] - for id_val, score in zip(ids, dists): - if id_val == -1: continue - if filter_deleted and id_val in self._deleted_ids: - continue - - str_id = self._int_to_str_map.get(id_val) - if str_id: - results.append((str_id, float(score))) - - # Sort and trim just in case filtering reduced count - results.sort(key=lambda x: x[1], reverse=True) - results = results[:k] - - if not results: - return [], [] - - return [r[0] for r in results], [r[1] for r in results] - - def warmup_index(self, force_train: bool = True) -> Dict[str, Any]: - """ - 预热向量索引(训练/回放前置),避免首个线上查询触发重初始化。 - - Args: - force_train: 是否在满足阈值时强制训练 SQ8 索引 - - Returns: - 预热状态摘要 - """ - started = time.perf_counter() - logger.info(f"metric.vector_index_prewarm_started=1 force_train={bool(force_train)}") - - try: - with self._lock: - self._flush_write_buffer() - - if self._bin_path.exists(): - self._bin_count = self._bin_path.stat().st_size // (self.dimension * 2) - else: - self._bin_count = 0 - - needs_fallback_bootstrap = ( - self._bin_count > 0 - and self._fallback_index.ntotal == 0 - and (not self._is_trained or self._index.ntotal == 0) - ) - if needs_fallback_bootstrap: - self._bootstrap_fallback_from_disk() - - min_train = max(1, int(getattr(self, "min_train_threshold", self.DEFAULT_MIN_TRAIN))) - needs_train = ( - bool(force_train) - and self._bin_count >= min_train - and not self._is_trained - ) - if needs_train: - self._force_train_small_data() - - duration_ms = (time.perf_counter() - started) * 1000.0 - summary = { - "ok": True, - "trained": bool(self._is_trained), - "index_ntotal": int(self._index.ntotal), - "fallback_ntotal": int(self._fallback_index.ntotal), - "bin_count": int(self._bin_count), - "duration_ms": duration_ms, - "error": None, - } - except Exception as e: - duration_ms = (time.perf_counter() - started) * 1000.0 - summary = { - "ok": False, - "trained": bool(self._is_trained), - "index_ntotal": int(self._index.ntotal) if self._index is not None else 0, - "fallback_ntotal": int(self._fallback_index.ntotal) if self._fallback_index is not None else 0, - "bin_count": int(getattr(self, "_bin_count", 0)), - "duration_ms": duration_ms, - "error": str(e), - } - logger.error( - "metric.vector_index_prewarm_fail=1 " - f"metric.vector_index_prewarm_duration_ms={duration_ms:.2f} " - f"error={e}" - ) - return summary - - logger.info( - "metric.vector_index_prewarm_success=1 " - f"metric.vector_index_prewarm_duration_ms={summary['duration_ms']:.2f} " - f"trained={summary['trained']} " - f"index_ntotal={summary['index_ntotal']} " - f"fallback_ntotal={summary['fallback_ntotal']} " - f"bin_count={summary['bin_count']}" - ) - return summary - - def _bootstrap_fallback_from_disk(self): - with self._lock: - self._bootstrap_fallback_from_disk_unlocked() - - def _bootstrap_fallback_from_disk_unlocked(self): - """重启后自举:从磁盘 vectors.bin 加载数据到 fallback 索引""" - if not self._bin_path.exists() or not self._ids_bin_path.exists(): - return - - logger.info("Replaying all disk vectors to fallback index...") - vec_item_size = self.dimension * 2 - id_item_size = 8 - chunk_size = 10000 - - with open(self._bin_path, "rb") as f_vec, open(self._ids_bin_path, "rb") as f_id: - while True: - vec_data = f_vec.read(chunk_size * vec_item_size) - id_data = f_id.read(chunk_size * id_item_size) - if not vec_data: break - - batch_fp16 = np.frombuffer(vec_data, 
dtype=np.float16).reshape(-1, self.dimension) - batch_fp32 = batch_fp16.astype(np.float32) - faiss.normalize_L2(batch_fp32) - batch_ids = np.frombuffer(id_data, dtype='>i8').astype(np.int64) - - valid_mask = [id_ not in self._deleted_ids for id_ in batch_ids] - if any(valid_mask): - self._fallback_index.add_with_ids(batch_fp32[valid_mask], batch_ids[valid_mask]) - - logger.info(f"Fallback index self-bootstrapped with {self._fallback_index.ntotal} items.") - - def _force_train_small_data(self): - with self._lock: - self._force_train_small_data_unlocked() - - def _force_train_small_data_unlocked(self): - logger.info("Forcing training on small dataset...") - self._reservoir_buffer = [] - - if not self._bin_path.exists(): - logger.warning("vectors.bin missing, skipping forced training.") - return - - chunk_size = 10000 - vec_item_size = self.dimension * 2 - - with open(self._bin_path, "rb") as f: - while len(self._reservoir_buffer) < self.TRAIN_SIZE: - data = f.read(chunk_size * vec_item_size) - if not data: break - fp16 = np.frombuffer(data, dtype=np.float16).reshape(-1, self.dimension) - fp32 = fp16.astype(np.float32) - faiss.normalize_L2(fp32) - - for vec in fp32: - self._reservoir_buffer.append(vec) - if len(self._reservoir_buffer) >= self.TRAIN_SIZE: - break - - self._train_and_replay_unlocked() - - def delete(self, ids: List[str]) -> int: - with self._lock: - count = 0 - for str_id in ids: - if str_id not in self._known_hashes: - continue - int_id = self._generate_id(str_id) - if int_id not in self._deleted_ids: - self._deleted_ids.add(int_id) - if self._index.is_trained: - self._index.remove_ids(np.array([int_id], dtype=np.int64)) - # Remove from the fallback index as well - if self._fallback_index.ntotal > 0: - self._fallback_index.remove_ids(np.array([int_id], dtype=np.int64)) - count += 1 - self._total_deleted += count - - # Check GC - self._check_rebuild_needed() - return count - - def _check_rebuild_needed(self): - """GC execution check.""" - if self._bin_count == 0: return - ratio = len(self._deleted_ids) / self._bin_count - if ratio > 0.3 and len(self._deleted_ids) > 1000: - logger.info(f"Triggering GC/Rebuild (deleted ratio: {ratio:.2f})") - self.rebuild_index() - - def rebuild_index(self): - """GC: rebuild the index and compact the bin files.""" - with self._lock: - self._rebuild_index_locked() - - def _rebuild_index_locked(self): - """Actual GC rebuild logic.""" - logger.info("Starting Compaction (GC)...") - - tmp_bin = self.data_dir / "vectors.bin.tmp" - tmp_ids = self.data_dir / "vectors_ids.bin.tmp" - - vec_item_size = self.dimension * 2 - id_item_size = 8 - chunk_size = 10000 - - new_count = 0 - - # 1. Compact Files - with open(self._bin_path, "rb") as f_vec, open(self._ids_bin_path, "rb") as f_id, \ - open(tmp_bin, "wb") as w_vec, open(tmp_ids, "wb") as w_id: - while True: - vec_data = f_vec.read(chunk_size * vec_item_size) - id_data = f_id.read(chunk_size * id_item_size) - if not vec_data: break - - batch_fp16 = np.frombuffer(vec_data, dtype=np.float16).reshape(-1, self.dimension) - batch_ids = np.frombuffer(id_data, dtype='>i8').astype(np.int64) - - keep_mask = [id_ not in self._deleted_ids for id_ in batch_ids] - - if any(keep_mask): - keep_vecs = batch_fp16[keep_mask] - keep_ids = batch_ids[keep_mask] - - w_vec.write(keep_vecs.tobytes()) - w_id.write(keep_ids.astype('>i8').tobytes()) - new_count += len(keep_ids) - - # 2.
Reset State & Atomic Swap - self._bin_count = new_count - - # Close current index - self._index.reset() - if self._fallback_index: self._fallback_index.reset() # Also clear fallback - self._is_trained = False - - # Swap files - shutil.move(str(tmp_bin), str(self._bin_path)) - shutil.move(str(tmp_ids), str(self._ids_bin_path)) - - # Reset Tombstones (Critical) - self._deleted_ids.clear() - - # 3. Reload/Rebuild Index (Fresh Train) - # We need to re-train because data distribution might have changed significantly after deletion - self._init_index() - self._init_fallback_index() # Re-init fallback too - self._force_train_small_data() # This will train and replay from the NEW compact file - - logger.info("Compaction Complete.") - - def save(self, data_dir: Optional[Union[str, Path]] = None) -> None: - with self._lock: - if not data_dir: - data_dir = self.data_dir - if not data_dir: - raise ValueError("No data_dir") - - data_dir = Path(data_dir) - data_dir.mkdir(parents=True, exist_ok=True) - - self._flush_write_buffer_unlocked() - - if self._is_trained: - index_path = data_dir / "vectors.index" - with atomic_save_path(index_path) as tmp: - faiss.write_index(self._index, tmp) - - meta = { - "dimension": self.dimension, - "quantization_type": self.quantization_type.value, - "is_trained": self._is_trained, - "vector_norm": self._vector_norm, - "deleted_ids": list(self._deleted_ids), - "known_hashes": list(self._known_hashes), - } - - with atomic_write(data_dir / "vectors_metadata.pkl", "wb") as f: - pickle.dump(meta, f) - - logger.info("VectorStore saved.") - - def migrate_legacy_npy(self, data_dir: Optional[Union[str, Path]] = None) -> Dict[str, Any]: - """ - 离线迁移入口:将 legacy vectors.npy 转为 vNext 二进制格式。 - """ - with self._lock: - target_dir = Path(data_dir) if data_dir else self.data_dir - if target_dir is None: - raise ValueError("No data_dir") - target_dir = Path(target_dir) - npy_path = target_dir / "vectors.npy" - idx_path = target_dir / "vectors.index" - bin_path = target_dir / "vectors.bin" - ids_bin_path = target_dir / "vectors_ids.bin" - meta_path = target_dir / "vectors_metadata.pkl" - - if not npy_path.exists(): - return {"migrated": False, "reason": "npy_missing"} - if not meta_path.exists(): - raise RuntimeError("legacy vectors.npy migration requires vectors_metadata.pkl") - if bin_path.exists() and ids_bin_path.exists(): - return {"migrated": False, "reason": "bin_exists"} - - # Reset in-memory state to avoid appending to stale runtime buffers. 
- self._known_hashes.clear() - self._deleted_ids.clear() - self._write_buffer_vecs.clear() - self._write_buffer_ids.clear() - self._init_index() - self._init_fallback_index() - self._is_trained = False - self._bin_count = 0 - - self._migrate_from_npy_unlocked(npy_path, idx_path, target_dir) - self.save(target_dir) - return {"migrated": True, "reason": "ok"} - - def load(self, data_dir: Optional[Union[str, Path]] = None) -> None: - with self._lock: - if not data_dir: data_dir = self.data_dir - data_dir = Path(data_dir) - - npy_path = data_dir / "vectors.npy" - idx_path = data_dir / "vectors.index" - bin_path = data_dir / "vectors.bin" - - if npy_path.exists() and not bin_path.exists(): - raise RuntimeError( - "检测到 legacy vectors.npy,vNext 不再支持运行时自动迁移。" - " 请先执行 scripts/release_vnext_migrate.py migrate。" - ) - - meta_path = data_dir / "vectors_metadata.pkl" - if not meta_path.exists(): - logger.warning("No metadata found, initialized empty.") - return - - with open(meta_path, "rb") as f: - meta = pickle.load(f) - - if meta.get("vector_norm") != "l2": - logger.warning("Index IDMap2 version mismatch (L2 Norm), forcing rebuild...") - self._known_hashes = set(meta.get("ids", [])) | set(meta.get("known_hashes", [])) - self._deleted_ids = set(meta.get("deleted_ids", [])) - self._init_index() - self._force_train_small_data() - return - - self._is_trained = meta.get("is_trained", False) - self._vector_norm = meta.get("vector_norm", "l2") - self._deleted_ids = set(meta.get("deleted_ids", [])) - self._known_hashes = set(meta.get("known_hashes", [])) - - if self._is_trained: - if idx_path.exists(): - try: - self._index = faiss.read_index(str(idx_path)) - if not isinstance(self._index, faiss.IndexIDMap2): - logger.warning("Loaded index type mismatch. Rebuilding...") - self._init_index() - self._force_train_small_data() - except Exception as e: - logger.error(f"Failed to load index: {e}. Rebuilding...") - self._init_index() - self._force_train_small_data() - else: - logger.warning("Index file missing despite metadata indicating trained. 
Rebuilding from bin...") - self._init_index() - self._force_train_small_data() - - if bin_path.exists(): - self._bin_count = bin_path.stat().st_size // (self.dimension * 2) - - def _migrate_from_npy(self, npy_path, idx_path, data_dir): - with self._lock: - self._migrate_from_npy_unlocked(npy_path, idx_path, data_dir) - - def _migrate_from_npy_unlocked(self, npy_path, idx_path, data_dir): - try: - arr = np.load(npy_path, mmap_mode="r") - except Exception: - arr = np.load(npy_path) - - meta_path = data_dir / "vectors_metadata.pkl" - old_ids = [] - if meta_path.exists(): - with open(meta_path, "rb") as f: - m = pickle.load(f) - old_ids = m.get("ids", []) - - if len(arr) != len(old_ids): - logger.error(f"Migration mismatch: arr {len(arr)} != ids {len(old_ids)}") - return - - logger.info(f"Migrating {len(arr)} vectors...") - - chunk = 1000 - for i in range(0, len(arr), chunk): - sub_arr = arr[i : i+chunk] - sub_ids = old_ids[i : i+chunk] - self.add(sub_arr, sub_ids) - - if not self._is_trained: - self._force_train_small_data() - - shutil.move(str(npy_path), str(npy_path) + ".bak") - if idx_path.exists(): - shutil.move(str(idx_path), str(idx_path) + ".bak") - - logger.info("Migration complete.") - - def clear(self) -> None: - with self._lock: - self._ids_bin_path.unlink(missing_ok=True) - self._bin_path.unlink(missing_ok=True) - self._init_index() - self._known_hashes.clear() - self._deleted_ids.clear() - self._bin_count = 0 - logger.info("VectorStore cleared.") - - def has_data(self) -> bool: - return (self.data_dir / "vectors_metadata.pkl").exists() - - @property - def num_vectors(self) -> int: - return len(self._known_hashes) - len(self._deleted_ids) - - def __contains__(self, hash_value: str) -> bool: - """Check if a hash exists in the store""" - return hash_value in self._known_hashes and self._generate_id(hash_value) not in self._deleted_ids - diff --git a/plugins/A_memorix/core/strategies/__init__.py b/plugins/A_memorix/core/strategies/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/plugins/A_memorix/core/strategies/base.py b/plugins/A_memorix/core/strategies/base.py deleted file mode 100644 index ff250cdf..00000000 --- a/plugins/A_memorix/core/strategies/base.py +++ /dev/null @@ -1,89 +0,0 @@ -from abc import ABC, abstractmethod -from typing import List, Dict, Any, Optional, Union -from dataclasses import dataclass, field -from enum import Enum -import hashlib - -class KnowledgeType(str, Enum): - NARRATIVE = "narrative" - FACTUAL = "factual" - QUOTE = "quote" - MIXED = "mixed" - -@dataclass -class SourceInfo: - file: str - offset_start: int - offset_end: int - checksum: str = "" - -@dataclass -class ChunkContext: - chunk_id: str - index: int - context: Dict[str, Any] = field(default_factory=dict) - text: str = "" - -@dataclass -class ChunkFlags: - verbatim: bool = False - requires_llm: bool = True - -@dataclass -class ProcessedChunk: - type: KnowledgeType - source: SourceInfo - chunk: ChunkContext - data: Dict[str, Any] = field(default_factory=dict) # triples、events、verbatim_entities - flags: ChunkFlags = field(default_factory=ChunkFlags) - - def to_dict(self) -> Dict: - return { - "type": self.type.value, - "source": { - "file": self.source.file, - "offset_start": self.source.offset_start, - "offset_end": self.source.offset_end, - "checksum": self.source.checksum - }, - "chunk": { - "text": self.chunk.text, - "chunk_id": self.chunk.chunk_id, - "index": self.chunk.index, - "context": self.chunk.context - }, - "data": self.data, - "flags": { - "verbatim": 
self.flags.verbatim, - "requires_llm": self.flags.requires_llm - } - } - -class BaseStrategy(ABC): - def __init__(self, filename: str): - self.filename = filename - - @abstractmethod - def split(self, text: str) -> List[ProcessedChunk]: - """按策略将文本切分为块。""" - pass - - @abstractmethod - async def extract(self, chunk: ProcessedChunk, llm_func=None) -> ProcessedChunk: - """从文本块中抽取结构化信息。""" - pass - - def calculate_checksum(self, text: str) -> str: - return hashlib.sha256(text.encode("utf-8")).hexdigest() - - def build_language_guard(self, text: str) -> str: - """ - 构建统一的输出语言约束。 - 不区分语言类型,仅要求抽取值保持原文语言,不做翻译。 - """ - _ = text # 预留参数,便于后续按需扩展 - return ( - "Focus on the original source language. Keep extracted events, entities, predicates " - "and objects in the same language as the source text, preserve names/terms as-is, " - "and do not translate." - ) diff --git a/plugins/A_memorix/core/strategies/factual.py b/plugins/A_memorix/core/strategies/factual.py deleted file mode 100644 index 4b7d6e56..00000000 --- a/plugins/A_memorix/core/strategies/factual.py +++ /dev/null @@ -1,98 +0,0 @@ -import re -from typing import List, Dict, Any -from .base import BaseStrategy, ProcessedChunk, KnowledgeType, SourceInfo, ChunkContext - -class FactualStrategy(BaseStrategy): - def split(self, text: str) -> List[ProcessedChunk]: - # 结构感知切分 - lines = text.split('\n') - chunks = [] - current_chunk_lines = [] - current_len = 0 - target_size = 600 - - for i, line in enumerate(lines): - # 判断是否应当切分 - # 若当前行为列表项/定义/表格行,则尽量不切分 - is_structure = self._is_structural_line(line) - - current_len += len(line) + 1 - current_chunk_lines.append(line) - - # 达到目标长度且不在紧凑结构块内时切分(过长时强制切分) - if current_len >= target_size and not is_structure: - chunks.append(self._create_chunk(current_chunk_lines, len(chunks))) - current_chunk_lines = [] - current_len = 0 - elif current_len >= target_size * 2: # 超长时强制切分 - chunks.append(self._create_chunk(current_chunk_lines, len(chunks))) - current_chunk_lines = [] - current_len = 0 - - if current_chunk_lines: - chunks.append(self._create_chunk(current_chunk_lines, len(chunks))) - - return chunks - - def _is_structural_line(self, line: str) -> bool: - line = line.strip() - if not line: return False - # 列表项 - if re.match(r'^[\-\*]\s+', line) or re.match(r'^\d+\.\s+', line): - return True - # 定义项(术语: 定义) - if re.match(r'^[^::]+[::].+', line): - return True - # 表格行(按 markdown 语法假设) - if line.startswith('|') and line.endswith('|'): - return True - return False - - def _create_chunk(self, lines: List[str], index: int) -> ProcessedChunk: - text = "\n".join(lines) - return ProcessedChunk( - type=KnowledgeType.FACTUAL, - source=SourceInfo( - file=self.filename, - offset_start=0, # 简化处理:真实偏移跟踪需要额外状态 - offset_end=0, - checksum=self.calculate_checksum(text) - ), - chunk=ChunkContext( - chunk_id=f"{self.filename}_{index}", - index=index, - text=text - ) - ) - - async def extract(self, chunk: ProcessedChunk, llm_func=None) -> ProcessedChunk: - if not llm_func: - raise ValueError("LLM function required for Factual extraction") - - language_guard = self.build_language_guard(chunk.chunk.text) - prompt = f"""You are a factual knowledge extraction engine. -Extract factual triples and entities from the text. -Preserve lists and definitions accurately. - -Language constraints: -- {language_guard} -- Preserve original names and domain terms exactly when possible. -- JSON keys must stay exactly as: triples, entities, subject, predicate, object. 
- - Text: - {chunk.chunk.text} - - Return ONLY valid JSON: - {{ - "triples": [ - {{"subject": "Entity", "predicate": "Relationship", "object": "Entity"}} - ], - "entities": ["Entity1", "Entity2"] - }} - """ - result = await llm_func(prompt) - - # Store the result as-is in data; the unified normalization pass handles it later - # the vector_store side expects relation fields as subject/predicate/object mappings - chunk.data = result - return chunk diff --git a/plugins/A_memorix/core/strategies/narrative.py b/plugins/A_memorix/core/strategies/narrative.py deleted file mode 100644 index 731414f7..00000000 --- a/plugins/A_memorix/core/strategies/narrative.py +++ /dev/null @@ -1,126 +0,0 @@ -import re -from typing import List, Dict, Any -from .base import BaseStrategy, ProcessedChunk, KnowledgeType, SourceInfo, ChunkContext - -class NarrativeStrategy(BaseStrategy): - def split(self, text: str) -> List[ProcessedChunk]: - scenes = self._split_into_scenes(text) - chunks = [] - - for scene_idx, (scene_text, scene_title) in enumerate(scenes): - scene_chunks = self._sliding_window(scene_text, scene_title, scene_idx) - chunks.extend(scene_chunks) - - return chunks - - def _split_into_scenes(self, text: str) -> List[tuple[str, str]]: - """Split text into scenes by headings or separators.""" - # Simple heuristic: split on markdown headings or specific separator lines - # This regex matches separator lines starting with #, Chapter, or *** / === - scene_pattern_str = r'^(?:#{1,6}\s+.*|Chapter\s+\d+|^\*{3,}$|^={3,}$)' - - # Keep the separators so scene starts can be identified - parts = re.split(f"({scene_pattern_str})", text, flags=re.MULTILINE) - - scenes = [] - current_scene_title = "Start" - current_scene_content = [] - - if parts and parts[0].strip() == "": - parts = parts[1:] - - for part in parts: - if re.match(scene_pattern_str, part, re.MULTILINE): - # Save the previous scene first - if current_scene_content: - scenes.append(("".join(current_scene_content), current_scene_title)) - current_scene_content = [] - current_scene_title = part.strip() - else: - current_scene_content.append(part) - - if current_scene_content: - scenes.append(("".join(current_scene_content), current_scene_title)) - - # If no scenes were detected, treat the whole text as a single scene - if not scenes: - scenes = [(text, "Whole Text")] - - return scenes - - def _sliding_window(self, text: str, scene_id: str, scene_idx: int, window_size=800, overlap=200) -> List[ProcessedChunk]: - chunks = [] - if len(text) <= window_size: - chunks.append(self._create_chunk(text, scene_id, scene_idx, 0, 0)) - return chunks - - start = 0 - local_idx = 0 - while start < len(text): - end = min(start + window_size, len(text)) - chunk_text = text[start:end] - - # Align to the nearest newline where possible to avoid harsh mid-sentence cuts - # Only back off when not yet at the end of the text - if end < len(text): - last_newline = chunk_text.rfind('\n') - if last_newline > window_size // 2: # only when the backtrack distance is acceptable - end = start + last_newline + 1 - chunk_text = text[start:end] - - chunks.append(self._create_chunk(chunk_text, scene_id, scene_idx, local_idx, start)) - - start += len(chunk_text) - overlap if end < len(text) else len(chunk_text) - local_idx += 1 - - return chunks - - def _create_chunk(self, text: str, scene_id: str, scene_idx: int, local_idx: int, offset: int) -> ProcessedChunk: - return ProcessedChunk( - type=KnowledgeType.NARRATIVE, - source=SourceInfo( - file=self.filename, - offset_start=offset, - offset_end=offset + len(text), - checksum=self.calculate_checksum(text) - ), - chunk=ChunkContext( - chunk_id=f"{self.filename}_{scene_idx}_{local_idx}", - index=local_idx, - text=text, - context={"scene_id": scene_id} - ) - ) - - async def extract(self, chunk: ProcessedChunk, llm_func=None) -> ProcessedChunk: - if not llm_func: - raise ValueError("LLM function required for
Narrative extraction") - - language_guard = self.build_language_guard(chunk.chunk.text) - prompt = f"""You are a narrative knowledge extraction engine. -Extract key events and character relations from the scene text. - -Language constraints: -- {language_guard} -- Preserve original names and terms exactly when possible. -- JSON keys must stay exactly as: events, relations, subject, predicate, object. - -Scene: -{chunk.chunk.context.get('scene_id')} - -Text: -{chunk.chunk.text} - -Return ONLY valid JSON: -{{ - "events": ["event description 1", "event description 2"], - "relations": [ - {{"subject": "CharacterA", "predicate": "relation", "object": "CharacterB"}} - ] -}} -""" - result = await llm_func(prompt) - chunk.data = result - return chunk diff --git a/plugins/A_memorix/core/strategies/quote.py b/plugins/A_memorix/core/strategies/quote.py deleted file mode 100644 index 10733d64..00000000 --- a/plugins/A_memorix/core/strategies/quote.py +++ /dev/null @@ -1,52 +0,0 @@ -from typing import List, Dict, Any -from .base import BaseStrategy, ProcessedChunk, KnowledgeType, SourceInfo, ChunkContext, ChunkFlags - -class QuoteStrategy(BaseStrategy): - def split(self, text: str) -> List[ProcessedChunk]: - # Split by double newlines (stanzas) - stanzas = text.split("\n\n") - chunks = [] - offset = 0 - - for idx, stanza in enumerate(stanzas): - if not stanza.strip(): - offset += len(stanza) + 2 - continue - - chunk = ProcessedChunk( - type=KnowledgeType.QUOTE, - source=SourceInfo( - file=self.filename, - offset_start=offset, - offset_end=offset + len(stanza), - checksum=self.calculate_checksum(stanza) - ), - chunk=ChunkContext( - chunk_id=f"{self.filename}_{idx}", - index=idx, - text=stanza - ), - flags=ChunkFlags( - verbatim=True, - requires_llm=False # Default to no LLM, but can be overridden - ) - ) - chunks.append(chunk) - offset += len(stanza) + 2 # +2 for \n\n - - return chunks - - async def extract(self, chunk: ProcessedChunk, llm_func=None) -> ProcessedChunk: - # For quotes, the text itself is the entity/knowledge - # We might use LLM to extract headers/metadata if requested, but core logic is pass-through - - # Treat the whole chunk text as a verbatim entity - chunk.data = { - "verbatim_entities": [chunk.chunk.text] - } - - if llm_func and chunk.flags.requires_llm: - # Optional: Extract metadata - pass - - return chunk diff --git a/plugins/A_memorix/core/utils/__init__.py b/plugins/A_memorix/core/utils/__init__.py deleted file mode 100644 index e0d763cf..00000000 --- a/plugins/A_memorix/core/utils/__init__.py +++ /dev/null @@ -1,33 +0,0 @@ -"""工具模块 - 哈希、监控等辅助功能""" - -from .hash import compute_hash, normalize_text -from .monitor import MemoryMonitor -from .quantization import quantize_vector, dequantize_vector -from .time_parser import ( - parse_query_datetime_to_timestamp, - parse_query_time_range, - parse_ingest_datetime_to_timestamp, - normalize_time_meta, - format_timestamp, -) -from .relation_write_service import RelationWriteService, RelationWriteResult -from .relation_query import RelationQuerySpec, parse_relation_query_spec -from .plugin_id_policy import PluginIdPolicy - -__all__ = [ - "compute_hash", - "normalize_text", - "MemoryMonitor", - "quantize_vector", - "dequantize_vector", - "parse_query_datetime_to_timestamp", - "parse_query_time_range", - "parse_ingest_datetime_to_timestamp", - "normalize_time_meta", - "format_timestamp", - "RelationWriteService", - "RelationWriteResult", - "RelationQuerySpec", - "parse_relation_query_spec", - "PluginIdPolicy", -] diff --git 
a/plugins/A_memorix/core/utils/aggregate_query_service.py b/plugins/A_memorix/core/utils/aggregate_query_service.py deleted file mode 100644 index dcf64c34..00000000 --- a/plugins/A_memorix/core/utils/aggregate_query_service.py +++ /dev/null @@ -1,360 +0,0 @@ -""" -聚合查询服务: -- 并发执行 search/time/episode 分支 -- 统一分支结果结构 -- 可选混合排序(Weighted RRF) -""" - -from __future__ import annotations - -import asyncio -from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple - -from src.common.logger import get_logger - -logger = get_logger("A_Memorix.AggregateQueryService") - -BranchRunner = Callable[[], Awaitable[Dict[str, Any]]] - - -class AggregateQueryService: - """聚合查询执行服务(search/time/episode)。""" - - def __init__(self, plugin_config: Optional[Any] = None): - self.plugin_config = plugin_config or {} - - def _cfg(self, key: str, default: Any = None) -> Any: - getter = getattr(self.plugin_config, "get_config", None) - if callable(getter): - return getter(key, default) - - current: Any = self.plugin_config - for part in key.split("."): - if isinstance(current, dict) and part in current: - current = current[part] - else: - return default - return current - - @staticmethod - def _as_float(value: Any, default: float = 0.0) -> float: - try: - return float(value) - except Exception: - return float(default) - - @staticmethod - def _as_int(value: Any, default: int = 0) -> int: - try: - return int(value) - except Exception: - return int(default) - - def _rrf_k(self) -> float: - raw = self._cfg("retrieval.aggregate.rrf_k", 60.0) - value = self._as_float(raw, 60.0) - return max(1.0, value) - - def _weights(self) -> Dict[str, float]: - defaults = {"search": 1.0, "time": 1.0, "episode": 1.0} - raw = self._cfg("retrieval.aggregate.weights", {}) - if not isinstance(raw, dict): - return defaults - - out = dict(defaults) - for key in ("search", "time", "episode"): - if key in raw: - out[key] = max(0.0, self._as_float(raw.get(key), defaults[key])) - return out - - @staticmethod - def _normalize_branch_payload( - name: str, - payload: Optional[Dict[str, Any]], - ) -> Dict[str, Any]: - data = payload if isinstance(payload, dict) else {} - results_raw = data.get("results", []) - results = results_raw if isinstance(results_raw, list) else [] - count = data.get("count") - if count is None: - count = len(results) - return { - "name": name, - "success": bool(data.get("success", False)), - "skipped": bool(data.get("skipped", False)), - "skip_reason": str(data.get("skip_reason", "") or "").strip(), - "error": str(data.get("error", "") or "").strip(), - "results": results, - "count": max(0, int(count)), - "elapsed_ms": max(0.0, float(data.get("elapsed_ms", 0.0) or 0.0)), - "content": str(data.get("content", "") or ""), - "query_type": str(data.get("query_type", "") or name), - } - - @staticmethod - def _mix_key(item: Dict[str, Any], branch: str, rank: int) -> str: - item_type = str(item.get("type", "") or "").strip().lower() - if item_type == "episode": - episode_id = str(item.get("episode_id", "") or "").strip() - if episode_id: - return f"episode:{episode_id}" - - item_hash = str(item.get("hash", "") or "").strip() - if item_hash: - return f"{item_type}:{item_hash}" - - return f"{branch}:{item_type}:{rank}:{str(item.get('content', '') or '')[:80]}" - - def _build_mixed_results( - self, - *, - branches: Dict[str, Dict[str, Any]], - top_k: int, - ) -> List[Dict[str, Any]]: - rrf_k = self._rrf_k() - weights = self._weights() - bucket: Dict[str, Dict[str, Any]] = {} - - for branch_name, branch in branches.items(): - if 
not branch.get("success", False): - continue - results = branch.get("results", []) - if not isinstance(results, list): - continue - - weight = max(0.0, float(weights.get(branch_name, 1.0))) - for idx, item in enumerate(results, start=1): - if not isinstance(item, dict): - continue - key = self._mix_key(item, branch_name, idx) - score = weight / (rrf_k + float(idx)) - if key not in bucket: - merged = dict(item) - merged["fusion_score"] = 0.0 - merged["_source_branches"] = set() - bucket[key] = merged - - target = bucket[key] - target["fusion_score"] = float(target.get("fusion_score", 0.0)) + score - target["_source_branches"].add(branch_name) - - mixed = list(bucket.values()) - mixed.sort( - key=lambda x: ( - -float(x.get("fusion_score", 0.0)), - str(x.get("type", "") or ""), - str(x.get("hash", "") or x.get("episode_id", "") or ""), - ) - ) - - out: List[Dict[str, Any]] = [] - for rank, item in enumerate(mixed[: max(1, int(top_k))], start=1): - merged = dict(item) - branches_set = merged.pop("_source_branches", set()) - merged["source_branches"] = sorted(list(branches_set)) - merged["rank"] = rank - out.append(merged) - return out - - @staticmethod - def _status(branch: Dict[str, Any]) -> str: - if branch.get("skipped", False): - return "skipped" - if branch.get("success", False): - return "success" - return "failed" - - def _build_summary(self, branches: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Any]]: - summary: Dict[str, Dict[str, Any]] = {} - for name, branch in branches.items(): - status = self._status(branch) - summary[name] = { - "status": status, - "count": int(branch.get("count", 0) or 0), - } - if status == "skipped": - summary[name]["reason"] = str(branch.get("skip_reason", "") or "") - if status == "failed": - summary[name]["error"] = str(branch.get("error", "") or "") - return summary - - def _build_content( - self, - *, - query: str, - branches: Dict[str, Dict[str, Any]], - errors: List[Dict[str, str]], - mixed_results: Optional[List[Dict[str, Any]]], - ) -> str: - lines: List[str] = [ - f"🔀 聚合查询结果(query='{query or 'N/A'}')", - "", - "分支状态:", - ] - for name in ("search", "time", "episode"): - branch = branches.get(name, {}) - status = self._status(branch) - count = int(branch.get("count", 0) or 0) - line = f"- {name}: {status}, count={count}" - reason = str(branch.get("skip_reason", "") or "").strip() - err = str(branch.get("error", "") or "").strip() - if status == "skipped" and reason: - line += f" ({reason})" - if status == "failed" and err: - line += f" ({err})" - lines.append(line) - - if errors: - lines.append("") - lines.append("错误:") - for item in errors[:6]: - lines.append(f"- {item.get('branch', 'unknown')}: {item.get('error', 'unknown error')}") - - if mixed_results is not None: - lines.append("") - lines.append(f"🧩 混合结果({len(mixed_results)} 条):") - for idx, item in enumerate(mixed_results[:5], start=1): - src = ",".join(item.get("source_branches", []) or []) - if str(item.get("type", "") or "") == "episode": - title = str(item.get("title", "") or "Untitled") - lines.append(f"{idx}. 🧠 {title} [{src}]") - else: - text = str(item.get("content", "") or "") - if len(text) > 80: - text = text[:80] + "..." - lines.append(f"{idx}. 
{text} [{src}]") - - return "\n".join(lines) - - async def execute( - self, - *, - query: str, - top_k: int, - mix: bool, - mix_top_k: Optional[int], - time_from: Optional[str], - time_to: Optional[str], - search_runner: Optional[BranchRunner], - time_runner: Optional[BranchRunner], - episode_runner: Optional[BranchRunner], - ) -> Dict[str, Any]: - clean_query = str(query or "").strip() - safe_top_k = max(1, int(top_k)) - safe_mix_top_k = max(1, int(mix_top_k if mix_top_k is not None else safe_top_k)) - - branches: Dict[str, Dict[str, Any]] = {} - errors: List[Dict[str, str]] = [] - scheduled: List[Tuple[str, asyncio.Task]] = [] - - if clean_query: - if search_runner is not None: - scheduled.append(("search", asyncio.create_task(search_runner()))) - else: - branches["search"] = self._normalize_branch_payload( - "search", - {"success": False, "error": "search runner unavailable", "results": []}, - ) - else: - branches["search"] = self._normalize_branch_payload( - "search", - { - "success": False, - "skipped": True, - "skip_reason": "missing_query", - "results": [], - "count": 0, - }, - ) - - if time_from or time_to: - if time_runner is not None: - scheduled.append(("time", asyncio.create_task(time_runner()))) - else: - branches["time"] = self._normalize_branch_payload( - "time", - {"success": False, "error": "time runner unavailable", "results": []}, - ) - else: - branches["time"] = self._normalize_branch_payload( - "time", - { - "success": False, - "skipped": True, - "skip_reason": "missing_time_window", - "results": [], - "count": 0, - }, - ) - - if episode_runner is not None: - scheduled.append(("episode", asyncio.create_task(episode_runner()))) - else: - branches["episode"] = self._normalize_branch_payload( - "episode", - {"success": False, "error": "episode runner unavailable", "results": []}, - ) - - if scheduled: - done = await asyncio.gather( - *[task for _, task in scheduled], - return_exceptions=True, - ) - for (branch_name, _), payload in zip(scheduled, done): - if isinstance(payload, Exception): - logger.error(f"aggregate branch failed: branch={branch_name} error={payload}") - normalized = self._normalize_branch_payload( - branch_name, - { - "success": False, - "error": str(payload), - "results": [], - }, - ) - else: - normalized = self._normalize_branch_payload(branch_name, payload) - branches[branch_name] = normalized - - for name in ("search", "time", "episode"): - branch = branches.get(name) - if not branch: - continue - if branch.get("skipped", False): - continue - if not branch.get("success", False): - errors.append( - { - "branch": name, - "error": str(branch.get("error", "") or "unknown error"), - } - ) - - success = any( - bool(branches.get(name, {}).get("success", False)) - for name in ("search", "time", "episode") - ) - mixed_results: Optional[List[Dict[str, Any]]] = None - if mix: - mixed_results = self._build_mixed_results(branches=branches, top_k=safe_mix_top_k) - - payload: Dict[str, Any] = { - "success": success, - "query_type": "aggregate", - "query": clean_query, - "top_k": safe_top_k, - "mix": bool(mix), - "mix_top_k": safe_mix_top_k, - "branches": branches, - "errors": errors, - "summary": self._build_summary(branches), - } - if mixed_results is not None: - payload["mixed_results"] = mixed_results - - payload["content"] = self._build_content( - query=clean_query, - branches=branches, - errors=errors, - mixed_results=mixed_results, - ) - return payload diff --git a/plugins/A_memorix/core/utils/episode_retrieval_service.py 
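`execute` takes one optional async runner per branch and degrades gracefully when a runner is missing or a precondition (query text, time window) is absent. A hypothetical wiring, assuming the module as it existed before this deletion is importable inside the host repo (it needs `src.common.logger`), and that `search_impl` returns the payload shape `_normalize_branch_payload` expects:

```python
import asyncio
from plugins.A_memorix.core.utils.aggregate_query_service import AggregateQueryService

async def search_impl() -> dict:
    # Hypothetical branch payload in the normalized shape.
    return {"success": True, "results": [{"type": "paragraph", "hash": "h1", "content": "..."}]}

async def main():
    service = AggregateQueryService(plugin_config={})
    payload = await service.execute(
        query="project deadline",
        top_k=5,
        mix=True,
        mix_top_k=5,
        time_from=None,       # no window, so the time branch reports skipped
        time_to=None,
        search_runner=search_impl,
        time_runner=None,
        episode_runner=None,  # reported as failed: "episode runner unavailable"
    )
    print(payload["summary"])

asyncio.run(main())
```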
b/plugins/A_memorix/core/utils/episode_retrieval_service.py deleted file mode 100644 index 44b22854..00000000 --- a/plugins/A_memorix/core/utils/episode_retrieval_service.py +++ /dev/null @@ -1,182 +0,0 @@ -"""Episode hybrid retrieval service.""" - -from __future__ import annotations - -from typing import Any, Dict, List, Optional - -from src.common.logger import get_logger - -from ..retrieval import DualPathRetriever, TemporalQueryOptions - -logger = get_logger("A_Memorix.EpisodeRetrievalService") - - -class EpisodeRetrievalService: - """Hybrid episode retrieval backed by lexical rows and evidence projection.""" - - _RRF_K = 60.0 - _BRANCH_WEIGHTS = { - "lexical": 1.0, - "paragraph_evidence": 1.0, - "relation_evidence": 0.85, - } - - def __init__( - self, - *, - metadata_store: Any, - retriever: Optional[DualPathRetriever] = None, - ) -> None: - self.metadata_store = metadata_store - self.retriever = retriever - - async def query( - self, - *, - query: str = "", - top_k: int = 5, - time_from: Optional[float] = None, - time_to: Optional[float] = None, - person: Optional[str] = None, - source: Optional[str] = None, - include_paragraphs: bool = False, - ) -> List[Dict[str, Any]]: - clean_query = str(query or "").strip() - safe_top_k = max(1, int(top_k)) - candidate_k = max(30, safe_top_k * 6) - - branches: Dict[str, List[Dict[str, Any]]] = { - "lexical": self.metadata_store.query_episodes( - query=clean_query, - time_from=time_from, - time_to=time_to, - person=person, - source=source, - limit=(candidate_k if clean_query else safe_top_k), - ) - } - - if clean_query and self.retriever is not None: - try: - temporal = TemporalQueryOptions( - time_from=time_from, - time_to=time_to, - person=person, - source=source, - ) - results = await self.retriever.retrieve( - query=clean_query, - top_k=candidate_k, - temporal=temporal, - ) - except Exception as exc: - logger.warning(f"episode evidence retrieval failed, fallback to lexical only: {exc}") - else: - paragraph_rank_map: Dict[str, int] = {} - relation_rank_map: Dict[str, int] = {} - for rank, item in enumerate(results, start=1): - hash_value = str(getattr(item, "hash_value", "") or "").strip() - result_type = str(getattr(item, "result_type", "") or "").strip().lower() - if not hash_value: - continue - if result_type == "paragraph" and hash_value not in paragraph_rank_map: - paragraph_rank_map[hash_value] = rank - elif result_type == "relation" and hash_value not in relation_rank_map: - relation_rank_map[hash_value] = rank - - if paragraph_rank_map: - paragraph_rows = self.metadata_store.get_episode_rows_by_paragraph_hashes( - list(paragraph_rank_map.keys()), - source=source, - ) - if paragraph_rows: - branches["paragraph_evidence"] = self._rank_projected_rows( - paragraph_rows, - rank_map=paragraph_rank_map, - support_key="matched_paragraph_hashes", - ) - - if relation_rank_map: - relation_rows = self.metadata_store.get_episode_rows_by_relation_hashes( - list(relation_rank_map.keys()), - source=source, - ) - if relation_rows: - branches["relation_evidence"] = self._rank_projected_rows( - relation_rows, - rank_map=relation_rank_map, - support_key="matched_relation_hashes", - ) - - fused = self._fuse_branches(branches, top_k=safe_top_k) - if include_paragraphs: - for item in fused: - item["paragraphs"] = self.metadata_store.get_episode_paragraphs( - episode_id=str(item.get("episode_id") or ""), - limit=50, - ) - return fused - - @staticmethod - def _rank_projected_rows( - rows: List[Dict[str, Any]], - *, - rank_map: Dict[str, int], - support_key: 
str, - ) -> List[Dict[str, Any]]: - sentinel = 10**9 - ranked = [dict(item) for item in rows] - - def _first_support_rank(item: Dict[str, Any]) -> int: - support_hashes = [str(x or "").strip() for x in (item.get(support_key) or [])] - ranks = [int(rank_map[h]) for h in support_hashes if h in rank_map] - return min(ranks) if ranks else sentinel - - ranked.sort( - key=lambda item: ( - _first_support_rank(item), - -int(item.get("matched_paragraph_count") or 0), - -float(item.get("updated_at") or 0.0), - str(item.get("episode_id") or ""), - ) - ) - return ranked - - def _fuse_branches( - self, - branches: Dict[str, List[Dict[str, Any]]], - *, - top_k: int, - ) -> List[Dict[str, Any]]: - bucket: Dict[str, Dict[str, Any]] = {} - for branch_name, rows in branches.items(): - weight = float(self._BRANCH_WEIGHTS.get(branch_name, 0.0) or 0.0) - if weight <= 0.0: - continue - for rank, row in enumerate(rows, start=1): - episode_id = str(row.get("episode_id", "") or "").strip() - if not episode_id: - continue - if episode_id not in bucket: - payload = dict(row) - payload.pop("matched_paragraph_hashes", None) - payload.pop("matched_relation_hashes", None) - payload.pop("matched_paragraph_count", None) - payload.pop("matched_relation_count", None) - payload["_fusion_score"] = 0.0 - bucket[episode_id] = payload - bucket[episode_id]["_fusion_score"] = float( - bucket[episode_id].get("_fusion_score", 0.0) - ) + weight / (self._RRF_K + float(rank)) - - out = list(bucket.values()) - out.sort( - key=lambda item: ( - -float(item.get("_fusion_score", 0.0)), - -float(item.get("updated_at") or 0.0), - str(item.get("episode_id") or ""), - ) - ) - for item in out: - item.pop("_fusion_score", None) - return out[: max(1, int(top_k))] diff --git a/plugins/A_memorix/core/utils/episode_segmentation_service.py b/plugins/A_memorix/core/utils/episode_segmentation_service.py deleted file mode 100644 index f42b1456..00000000 --- a/plugins/A_memorix/core/utils/episode_segmentation_service.py +++ /dev/null @@ -1,304 +0,0 @@ -""" -Episode 语义切分服务(LLM 主路径)。 - -职责: -1. 组装语义切分提示词 -2. 调用 LLM 生成结构化 episode JSON -3. 
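`_fuse_branches` applies the same reciprocal-rank-fusion shape, but with fixed per-branch weights (lexical 1.0, paragraph evidence 1.0, relation evidence 0.85) and deduplication on `episode_id`. A toy sketch showing how the relation branch's lower weight breaks a tie:

```python
RRF_K = 60.0
WEIGHTS = {"lexical": 1.0, "paragraph_evidence": 1.0, "relation_evidence": 0.85}

def fuse(branches: dict) -> list:
    scores = {}
    for name, episode_ids in branches.items():
        weight = WEIGHTS.get(name, 0.0)
        for rank, eid in enumerate(episode_ids, start=1):
            scores[eid] = scores.get(eid, 0.0) + weight / (RRF_K + rank)
    return sorted(scores.items(), key=lambda kv: -kv[1])

# "ep1" leads the paragraph branch, "ep2" leads the weaker relation branch.
print(fuse({"paragraph_evidence": ["ep1", "ep2"], "relation_evidence": ["ep2", "ep1"]}))
# ep1 edges out ep2 because its rank-1 hit came from the heavier branch
```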
严格校验输出结构,返回标准化结果 -""" - -from __future__ import annotations - -import json -from typing import Any, Dict, List, Optional, Tuple - -from src.common.logger import get_logger -from src.config.model_configs import TaskConfig -from src.config.config import model_config as host_model_config -from src.services import llm_service as llm_api - -logger = get_logger("A_Memorix.EpisodeSegmentationService") - - -class EpisodeSegmentationService: - """基于 LLM 的 episode 语义切分服务。""" - - SEGMENTATION_VERSION = "episode_mvp_v1" - - def __init__(self, plugin_config: Optional[dict] = None): - self.plugin_config = plugin_config or {} - - def _cfg(self, key: str, default: Any = None) -> Any: - current: Any = self.plugin_config - for part in key.split("."): - if isinstance(current, dict) and part in current: - current = current[part] - else: - return default - return current - - @staticmethod - def _is_task_config(obj: Any) -> bool: - return hasattr(obj, "model_list") and bool(getattr(obj, "model_list", [])) - - def _build_single_model_task(self, model_name: str, template: TaskConfig) -> TaskConfig: - return TaskConfig( - model_list=[model_name], - max_tokens=template.max_tokens, - temperature=template.temperature, - slow_threshold=template.slow_threshold, - selection_strategy=template.selection_strategy, - ) - - def _pick_template_task(self, available_tasks: Dict[str, Any]) -> Optional[TaskConfig]: - preferred = ("utils", "replyer", "planner", "tool_use") - for task_name in preferred: - cfg = available_tasks.get(task_name) - if self._is_task_config(cfg): - return cfg - for task_name, cfg in available_tasks.items(): - if task_name != "embedding" and self._is_task_config(cfg): - return cfg - for cfg in available_tasks.values(): - if self._is_task_config(cfg): - return cfg - return None - - def _resolve_model_config(self) -> Tuple[Optional[Any], str]: - available_tasks = llm_api.get_available_models() or {} - if not available_tasks: - return None, "unavailable" - - selector = str(self._cfg("episode.segmentation_model", "auto") or "auto").strip() - model_dict = getattr(host_model_config, "models_dict", {}) or {} - - if selector and selector.lower() != "auto": - direct_task = available_tasks.get(selector) - if self._is_task_config(direct_task): - return direct_task, selector - - if selector in model_dict: - template = self._pick_template_task(available_tasks) - if template is not None: - return self._build_single_model_task(selector, template), selector - - logger.warning(f"episode.segmentation_model='{selector}' 不可用,回退 auto") - - for task_name in ("utils", "replyer", "planner", "tool_use"): - cfg = available_tasks.get(task_name) - if self._is_task_config(cfg): - return cfg, task_name - - fallback = self._pick_template_task(available_tasks) - if fallback is not None: - return fallback, "auto" - return None, "unavailable" - - @staticmethod - def _clamp_score(value: Any, default: float = 0.0) -> float: - try: - num = float(value) - except Exception: - num = default - if num < 0.0: - return 0.0 - if num > 1.0: - return 1.0 - return num - - @staticmethod - def _safe_json_loads(text: str) -> Dict[str, Any]: - raw = str(text or "").strip() - if not raw: - raise ValueError("empty_response") - - if "```" in raw: - raw = raw.replace("```json", "```").replace("```JSON", "```") - parts = raw.split("```") - for part in parts: - part = part.strip() - if part.startswith("{") and part.endswith("}"): - raw = part - break - - try: - data = json.loads(raw) - if isinstance(data, dict): - return data - except Exception: - pass - - start 
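Confidence fields are clamped into [0, 1] with a tolerant default, which is the rule `_clamp_score` applies to every LLM-reported score. A compact sketch of the same behavior:

```python
def clamp_score(value, default: float = 0.0) -> float:
    try:
        num = float(value)
    except (TypeError, ValueError):
        num = default
    return min(1.0, max(0.0, num))

assert clamp_score("0.7") == 0.7
assert clamp_score(None, default=0.5) == 0.5  # unparseable input falls back
assert clamp_score(3) == 1.0                  # out-of-range input is clipped
```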
= raw.find("{") - end = raw.rfind("}") - if start >= 0 and end > start: - candidate = raw[start : end + 1] - data = json.loads(candidate) - if isinstance(data, dict): - return data - - raise ValueError("invalid_json_response") - - def _build_prompt( - self, - *, - source: str, - window_start: Optional[float], - window_end: Optional[float], - paragraphs: List[Dict[str, Any]], - ) -> str: - rows: List[str] = [] - for idx, item in enumerate(paragraphs, 1): - p_hash = str(item.get("hash", "") or "").strip() - content = str(item.get("content", "") or "").strip().replace("\r\n", "\n") - content = content[:800] - event_start = item.get("event_time_start") - event_end = item.get("event_time_end") - event_time = item.get("event_time") - rows.append( - ( - f"[{idx}] hash={p_hash}\n" - f"event_time={event_time}\n" - f"event_time_start={event_start}\n" - f"event_time_end={event_end}\n" - f"content={content}" - ) - ) - - source_text = str(source or "").strip() or "unknown" - return ( - "You are an episode segmentation engine.\n" - "Group the given paragraphs into one or more coherent episodes.\n" - "Return JSON ONLY. No markdown, no explanation.\n" - "\n" - "Hard JSON schema:\n" - "{\n" - ' "episodes": [\n' - " {\n" - ' "title": "string",\n' - ' "summary": "string",\n' - ' "paragraph_hashes": ["hash1", "hash2"],\n' - ' "participants": ["person1", "person2"],\n' - ' "keywords": ["kw1", "kw2"],\n' - ' "time_confidence": 0.0,\n' - ' "llm_confidence": 0.0\n' - " }\n" - " ]\n" - "}\n" - "\n" - "Rules:\n" - "1) paragraph_hashes must come from input only.\n" - "2) title and summary must be non-empty.\n" - "3) keep participants/keywords concise and deduplicated.\n" - "4) if uncertain, still provide best effort confidence values.\n" - "\n" - f"source={source_text}\n" - f"window_start={window_start}\n" - f"window_end={window_end}\n" - "paragraphs:\n" - + "\n\n".join(rows) - ) - - def _normalize_episodes( - self, - *, - payload: Dict[str, Any], - input_hashes: List[str], - ) -> List[Dict[str, Any]]: - raw_episodes = payload.get("episodes") - if not isinstance(raw_episodes, list): - raise ValueError("episodes_missing_or_not_list") - - valid_hashes = set(input_hashes) - normalized: List[Dict[str, Any]] = [] - for item in raw_episodes: - if not isinstance(item, dict): - continue - - title = str(item.get("title", "") or "").strip() - summary = str(item.get("summary", "") or "").strip() - if not title or not summary: - continue - - raw_hashes = item.get("paragraph_hashes") - if not isinstance(raw_hashes, list): - continue - - dedup_hashes: List[str] = [] - seen_hashes = set() - for h in raw_hashes: - token = str(h or "").strip() - if not token or token in seen_hashes or token not in valid_hashes: - continue - seen_hashes.add(token) - dedup_hashes.append(token) - - if not dedup_hashes: - continue - - participants = [] - for p in item.get("participants", []) or []: - token = str(p or "").strip() - if token: - participants.append(token) - - keywords = [] - for kw in item.get("keywords", []) or []: - token = str(kw or "").strip() - if token: - keywords.append(token) - - normalized.append( - { - "title": title, - "summary": summary, - "paragraph_hashes": dedup_hashes, - "participants": participants[:16], - "keywords": keywords[:20], - "time_confidence": self._clamp_score(item.get("time_confidence"), default=1.0), - "llm_confidence": self._clamp_score(item.get("llm_confidence"), default=0.5), - } - ) - - if not normalized: - raise ValueError("episodes_all_invalid") - return normalized - - async def segment( - self, - *, - 
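`_safe_json_loads` tolerates fenced output and leading or trailing chatter: it first prefers a fenced `{...}` block, then tries the raw string, and finally falls back to the outermost brace span. A compact sketch of the same recovery ladder:

```python
import json

def safe_json_loads(text: str) -> dict:
    raw = (text or "").strip()
    if not raw:
        raise ValueError("empty_response")
    if "```" in raw:  # prefer a fenced {...} block if the model added one
        for part in raw.replace("```json", "```").replace("```JSON", "```").split("```"):
            part = part.strip()
            if part.startswith("{") and part.endswith("}"):
                raw = part
                break
    try:
        data = json.loads(raw)
        if isinstance(data, dict):
            return data
    except json.JSONDecodeError:
        pass
    start, end = raw.find("{"), raw.rfind("}")  # last resort: outermost braces
    if 0 <= start < end:
        data = json.loads(raw[start : end + 1])
        if isinstance(data, dict):
            return data
    raise ValueError("invalid_json_response")

print(safe_json_loads('Sure! ```json\n{"episodes": []}\n``` hope that helps'))
```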
source: str, - window_start: Optional[float], - window_end: Optional[float], - paragraphs: List[Dict[str, Any]], - ) -> Dict[str, Any]: - if not paragraphs: - raise ValueError("paragraphs_empty") - - model_config, model_label = self._resolve_model_config() - if model_config is None: - raise RuntimeError("episode segmentation model unavailable") - - prompt = self._build_prompt( - source=source, - window_start=window_start, - window_end=window_end, - paragraphs=paragraphs, - ) - success, response, _, _ = await llm_api.generate_with_model( - prompt=prompt, - model_config=model_config, - request_type="A_Memorix.EpisodeSegmentation", - ) - if not success or not response: - raise RuntimeError("llm_generate_failed") - - payload = self._safe_json_loads(str(response)) - input_hashes = [str(p.get("hash", "") or "").strip() for p in paragraphs] - episodes = self._normalize_episodes(payload=payload, input_hashes=input_hashes) - - return { - "episodes": episodes, - "segmentation_model": model_label, - "segmentation_version": self.SEGMENTATION_VERSION, - } - diff --git a/plugins/A_memorix/core/utils/episode_service.py b/plugins/A_memorix/core/utils/episode_service.py deleted file mode 100644 index ca94dd96..00000000 --- a/plugins/A_memorix/core/utils/episode_service.py +++ /dev/null @@ -1,558 +0,0 @@ -""" -Episode 聚合与落库服务。 - -流程: -1. 从 pending 队列读取段落并组批 -2. 按 source + 时间窗口切组 -3. 调用 LLM 语义切分 -4. 写入 episodes + episode_paragraphs -5. LLM 失败时使用确定性 fallback -""" - -from __future__ import annotations - -import json -import re -from collections import Counter -from datetime import datetime -from typing import Any, Dict, List, Optional, Tuple - -from src.common.logger import get_logger - -from .episode_segmentation_service import EpisodeSegmentationService -from .hash import compute_hash - -logger = get_logger("A_Memorix.EpisodeService") - - -class EpisodeService: - """Episode MVP 后台处理服务。""" - - def __init__( - self, - *, - metadata_store: Any, - plugin_config: Optional[Any] = None, - segmentation_service: Optional[EpisodeSegmentationService] = None, - ): - self.metadata_store = metadata_store - self.plugin_config = plugin_config or {} - self.segmentation_service = segmentation_service or EpisodeSegmentationService( - plugin_config=self._config_dict(), - ) - - def _config_dict(self) -> Dict[str, Any]: - if isinstance(self.plugin_config, dict): - return self.plugin_config - return {} - - def _cfg(self, key: str, default: Any = None) -> Any: - getter = getattr(self.plugin_config, "get_config", None) - if callable(getter): - return getter(key, default) - - current: Any = self.plugin_config - for part in key.split("."): - if isinstance(current, dict) and part in current: - current = current[part] - else: - return default - return current - - @staticmethod - def _to_optional_float(value: Any) -> Optional[float]: - if value is None: - return None - try: - return float(value) - except Exception: - return None - - @staticmethod - def _clamp_score(value: Any, default: float = 1.0) -> float: - try: - num = float(value) - except Exception: - num = default - if num < 0.0: - return 0.0 - if num > 1.0: - return 1.0 - return num - - @staticmethod - def _paragraph_anchor(paragraph: Dict[str, Any]) -> float: - for key in ("event_time_end", "event_time_start", "event_time", "created_at"): - value = paragraph.get(key) - try: - if value is not None: - return float(value) - except Exception: - continue - return 0.0 - - @staticmethod - def _paragraph_sort_key(paragraph: Dict[str, Any]) -> Tuple[float, str]: - return ( - 
EpisodeService._paragraph_anchor(paragraph), - str(paragraph.get("hash", "") or ""), - ) - - def load_pending_paragraphs( - self, - pending_rows: List[Dict[str, Any]], - ) -> Tuple[List[Dict[str, Any]], List[str]]: - """ - 将 pending 行展开为段落上下文。 - - Returns: - (loaded_paragraphs, missing_hashes) - """ - loaded: List[Dict[str, Any]] = [] - missing: List[str] = [] - for row in pending_rows or []: - p_hash = str(row.get("paragraph_hash", "") or "").strip() - if not p_hash: - continue - - paragraph = self.metadata_store.get_paragraph(p_hash) - if not paragraph: - missing.append(p_hash) - continue - - loaded.append( - { - "hash": p_hash, - "source": str(row.get("source") or paragraph.get("source") or "").strip(), - "content": str(paragraph.get("content", "") or ""), - "created_at": self._to_optional_float(paragraph.get("created_at")) - or self._to_optional_float(row.get("created_at")) - or 0.0, - "event_time": self._to_optional_float(paragraph.get("event_time")), - "event_time_start": self._to_optional_float(paragraph.get("event_time_start")), - "event_time_end": self._to_optional_float(paragraph.get("event_time_end")), - "time_granularity": str(paragraph.get("time_granularity", "") or "").strip() or None, - "time_confidence": self._clamp_score(paragraph.get("time_confidence"), default=1.0), - } - ) - return loaded, missing - - def group_paragraphs(self, paragraphs: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - """ - 按 source + 时间邻近窗口组批,并受段落数/字符数上限约束。 - """ - if not paragraphs: - return [] - - max_paragraphs = max(1, int(self._cfg("episode.max_paragraphs_per_call", 20))) - max_chars = max(200, int(self._cfg("episode.max_chars_per_call", 6000))) - window_seconds = max( - 60.0, - float(self._cfg("episode.source_time_window_hours", 24)) * 3600.0, - ) - - by_source: Dict[str, List[Dict[str, Any]]] = {} - for paragraph in paragraphs: - source = str(paragraph.get("source", "") or "").strip() - by_source.setdefault(source, []).append(paragraph) - - groups: List[Dict[str, Any]] = [] - for source, items in by_source.items(): - ordered = sorted(items, key=self._paragraph_sort_key) - - current: List[Dict[str, Any]] = [] - current_chars = 0 - last_anchor: Optional[float] = None - - def flush() -> None: - nonlocal current, current_chars, last_anchor - if not current: - return - sorted_current = sorted(current, key=self._paragraph_sort_key) - groups.append( - { - "source": source, - "paragraphs": sorted_current, - } - ) - current = [] - current_chars = 0 - last_anchor = None - - for paragraph in ordered: - anchor = self._paragraph_anchor(paragraph) - content_len = len(str(paragraph.get("content", "") or "")) - - need_flush = False - if current: - if len(current) >= max_paragraphs: - need_flush = True - elif current_chars + content_len > max_chars: - need_flush = True - elif last_anchor is not None and abs(anchor - last_anchor) > window_seconds: - need_flush = True - - if need_flush: - flush() - - current.append(paragraph) - current_chars += content_len - last_anchor = anchor - - flush() - - groups.sort( - key=lambda g: self._paragraph_anchor(g["paragraphs"][0]) if g.get("paragraphs") else 0.0 - ) - return groups - - def _compute_time_meta(self, paragraphs: List[Dict[str, Any]]) -> Tuple[Optional[float], Optional[float], Optional[str], float]: - starts: List[float] = [] - ends: List[float] = [] - granularity_priority = { - "minute": 4, - "hour": 3, - "day": 2, - "month": 1, - "year": 0, - } - granularity = None - granularity_rank = -1 - conf_values: List[float] = [] - - for p in paragraphs: - s = 
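`group_paragraphs` batches per source and flushes the current group whenever adding a paragraph would exceed the paragraph cap, the character cap, or the configured time window between anchors (defaults 20 / 6000 / 24h per the config keys above). A simplified single-source sketch of the flush rule:

```python
def group_by_window(items: list, max_items=20, max_chars=6000, window_s=24 * 3600.0):
    """items: [{'anchor': float, 'content': str}], pre-sorted by anchor."""
    groups, current, chars, last = [], [], 0, None
    for it in items:
        too_big = len(current) >= max_items or chars + len(it["content"]) > max_chars
        too_far = last is not None and abs(it["anchor"] - last) > window_s
        if current and (too_big or too_far):
            groups.append(current)   # close the batch before it overflows
            current, chars = [], 0
        current.append(it)
        chars += len(it["content"])
        last = it["anchor"]
    if current:
        groups.append(current)
    return groups

# Two paragraphs 25h apart land in separate groups under a 24h window.
a, b = {"anchor": 0.0, "content": "x"}, {"anchor": 25 * 3600.0, "content": "y"}
assert len(group_by_window([a, b])) == 2
```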
self._to_optional_float(p.get("event_time_start")) - e = self._to_optional_float(p.get("event_time_end")) - t = self._to_optional_float(p.get("event_time")) - c = self._to_optional_float(p.get("created_at")) - - start_candidate = s if s is not None else (t if t is not None else (e if e is not None else c)) - end_candidate = e if e is not None else (t if t is not None else (s if s is not None else c)) - - if start_candidate is not None: - starts.append(start_candidate) - if end_candidate is not None: - ends.append(end_candidate) - - g = str(p.get("time_granularity", "") or "").strip().lower() - if g in granularity_priority and granularity_priority[g] > granularity_rank: - granularity_rank = granularity_priority[g] - granularity = g - - conf_values.append(self._clamp_score(p.get("time_confidence"), default=1.0)) - - time_start = min(starts) if starts else None - time_end = max(ends) if ends else None - time_conf = sum(conf_values) / len(conf_values) if conf_values else 1.0 - return time_start, time_end, granularity, self._clamp_score(time_conf, default=1.0) - - def _collect_participants(self, paragraph_hashes: List[str], limit: int = 16) -> List[str]: - seen = set() - participants: List[str] = [] - for p_hash in paragraph_hashes: - try: - entities = self.metadata_store.get_paragraph_entities(p_hash) - except Exception: - entities = [] - for item in entities: - name = str(item.get("name", "") or "").strip() - if not name: - continue - key = name.lower() - if key in seen: - continue - seen.add(key) - participants.append(name) - if len(participants) >= limit: - return participants - return participants - - @staticmethod - def _derive_keywords(paragraphs: List[Dict[str, Any]], limit: int = 12) -> List[str]: - token_counter: Counter[str] = Counter() - token_pattern = re.compile(r"[A-Za-z0-9_\u4e00-\u9fff]{2,}") - stop_words = { - "the", - "and", - "that", - "this", - "with", - "from", - "for", - "have", - "will", - "your", - "you", - "我们", - "你们", - "他们", - "以及", - "一个", - "这个", - "那个", - "然后", - "因为", - "所以", - } - for p in paragraphs: - text = str(p.get("content", "") or "").lower() - for token in token_pattern.findall(text): - if token in stop_words: - continue - token_counter[token] += 1 - - return [token for token, _ in token_counter.most_common(limit)] - - def _build_fallback_episode(self, group: Dict[str, Any]) -> Dict[str, Any]: - paragraphs = group.get("paragraphs", []) or [] - source = str(group.get("source", "") or "").strip() - hashes = [str(p.get("hash", "") or "").strip() for p in paragraphs if str(p.get("hash", "") or "").strip()] - snippets = [] - for p in paragraphs[:3]: - text = str(p.get("content", "") or "").strip().replace("\n", " ") - if text: - snippets.append(text[:140]) - summary = ";".join(snippets)[:500] if snippets else "自动回退生成的情景记忆。" - - time_start, time_end, granularity, time_conf = self._compute_time_meta(paragraphs) - participants = self._collect_participants(hashes, limit=12) - keywords = self._derive_keywords(paragraphs, limit=10) - - if time_start is not None: - day_text = datetime.fromtimestamp(time_start).strftime("%Y-%m-%d") - title = f"{source or 'unknown'} {day_text} 情景片段" - else: - title = f"{source or 'unknown'} 情景片段" - - return { - "title": title[:80], - "summary": summary, - "paragraph_hashes": hashes, - "participants": participants, - "keywords": keywords, - "time_confidence": time_conf, - "llm_confidence": 0.0, - "event_time_start": time_start, - "event_time_end": time_end, - "time_granularity": granularity, - "segmentation_model": "fallback_rule", - 
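`_compute_time_meta` widens the episode window to the earliest start and latest end among member paragraphs, keeps the finest time granularity seen, and averages per-paragraph confidence. A sketch under those rules with hypothetical paragraph records:

```python
PRIORITY = {"minute": 4, "hour": 3, "day": 2, "month": 1, "year": 0}

def combine_time_meta(paras: list):
    starts = [p["start"] for p in paras if p.get("start") is not None]
    ends = [p["end"] for p in paras if p.get("end") is not None]
    grains = [p["grain"] for p in paras if p.get("grain") in PRIORITY]
    confs = [p.get("conf", 1.0) for p in paras]
    return (
        min(starts) if starts else None,                     # window start
        max(ends) if ends else None,                         # window end
        max(grains, key=PRIORITY.get) if grains else None,   # finest granularity
        sum(confs) / len(confs) if confs else 1.0,           # mean confidence
    )

print(combine_time_meta([
    {"start": 100.0, "end": 200.0, "grain": "day", "conf": 0.9},
    {"start": 50.0, "end": 180.0, "grain": "minute", "conf": 0.7},
]))  # (50.0, 200.0, 'minute', ~0.8)
```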
"segmentation_version": EpisodeSegmentationService.SEGMENTATION_VERSION, - } - - @staticmethod - def _normalize_episode_hashes(episode_hashes: List[str], group_hashes_ordered: List[str]) -> List[str]: - in_group = set(group_hashes_ordered) - dedup: List[str] = [] - seen = set() - for h in episode_hashes or []: - token = str(h or "").strip() - if not token or token not in in_group or token in seen: - continue - seen.add(token) - dedup.append(token) - return dedup - - async def _build_episode_payloads_for_group(self, group: Dict[str, Any]) -> Dict[str, Any]: - paragraphs = group.get("paragraphs", []) or [] - if not paragraphs: - return { - "payloads": [], - "done_hashes": [], - "episode_count": 0, - "fallback_count": 0, - } - - source = str(group.get("source", "") or "").strip() - group_hashes = [str(p.get("hash", "") or "").strip() for p in paragraphs if str(p.get("hash", "") or "").strip()] - group_start, group_end, _, _ = self._compute_time_meta(paragraphs) - - fallback_used = False - segmentation_model = "fallback_rule" - segmentation_version = EpisodeSegmentationService.SEGMENTATION_VERSION - - try: - llm_result = await self.segmentation_service.segment( - source=source, - window_start=group_start, - window_end=group_end, - paragraphs=paragraphs, - ) - episodes = list(llm_result.get("episodes") or []) - segmentation_model = str(llm_result.get("segmentation_model", "") or "").strip() or "auto" - segmentation_version = str(llm_result.get("segmentation_version", "") or "").strip() or EpisodeSegmentationService.SEGMENTATION_VERSION - if not episodes: - raise ValueError("llm_empty_episodes") - except Exception as e: - logger.warning( - "Episode segmentation fallback: " - f"source={source} " - f"size={len(group_hashes)} " - f"err={e}" - ) - episodes = [self._build_fallback_episode(group)] - fallback_used = True - - stored_payloads: List[Dict[str, Any]] = [] - for episode in episodes: - ordered_hashes = self._normalize_episode_hashes( - episode_hashes=episode.get("paragraph_hashes", []), - group_hashes_ordered=group_hashes, - ) - if not ordered_hashes: - continue - - sub_paragraphs = [p for p in paragraphs if str(p.get("hash", "") or "") in set(ordered_hashes)] - event_start, event_end, granularity, time_conf_default = self._compute_time_meta(sub_paragraphs) - - participants = [str(x).strip() for x in (episode.get("participants", []) or []) if str(x).strip()] - keywords = [str(x).strip() for x in (episode.get("keywords", []) or []) if str(x).strip()] - if not participants: - participants = self._collect_participants(ordered_hashes, limit=16) - if not keywords: - keywords = self._derive_keywords(sub_paragraphs, limit=12) - - title = str(episode.get("title", "") or "").strip()[:120] - summary = str(episode.get("summary", "") or "").strip()[:2000] - if not title or not summary: - continue - - seed = json.dumps( - { - "source": source, - "hashes": ordered_hashes, - "version": segmentation_version, - }, - ensure_ascii=False, - sort_keys=True, - ) - episode_id = compute_hash(seed) - - payload = { - "episode_id": episode_id, - "source": source or None, - "title": title, - "summary": summary, - "event_time_start": episode.get("event_time_start", event_start), - "event_time_end": episode.get("event_time_end", event_end), - "time_granularity": episode.get("time_granularity", granularity), - "time_confidence": self._clamp_score( - episode.get("time_confidence"), - default=time_conf_default, - ), - "participants": participants[:16], - "keywords": keywords[:20], - "evidence_ids": ordered_hashes, - 
"paragraph_count": len(ordered_hashes), - "llm_confidence": self._clamp_score( - episode.get("llm_confidence"), - default=0.0 if fallback_used else 0.6, - ), - "segmentation_model": ( - str(episode.get("segmentation_model", "") or "").strip() - or ("fallback_rule" if fallback_used else segmentation_model) - ), - "segmentation_version": ( - str(episode.get("segmentation_version", "") or "").strip() - or segmentation_version - ), - } - stored_payloads.append(payload) - - return { - "payloads": stored_payloads, - "done_hashes": group_hashes, - "episode_count": len(stored_payloads), - "fallback_count": 1 if fallback_used else 0, - } - - async def process_group(self, group: Dict[str, Any]) -> Dict[str, Any]: - result = await self._build_episode_payloads_for_group(group) - stored_count = 0 - for payload in result.get("payloads") or []: - stored = self.metadata_store.upsert_episode(payload) - final_id = str(stored.get("episode_id") or payload.get("episode_id") or "") - if final_id: - self.metadata_store.bind_episode_paragraphs( - final_id, - list(payload.get("evidence_ids") or []), - ) - stored_count += 1 - - result["episode_count"] = stored_count - return { - "done_hashes": list(result.get("done_hashes") or []), - "episode_count": stored_count, - "fallback_count": int(result.get("fallback_count") or 0), - } - - async def process_pending_rows(self, pending_rows: List[Dict[str, Any]]) -> Dict[str, Any]: - loaded, missing_hashes = self.load_pending_paragraphs(pending_rows) - groups = self.group_paragraphs(loaded) - - done_hashes: List[str] = list(missing_hashes) - failed_hashes: Dict[str, str] = {} - episode_count = 0 - fallback_count = 0 - - for group in groups: - group_hashes = [str(p.get("hash", "") or "").strip() for p in (group.get("paragraphs") or [])] - try: - result = await self.process_group(group) - done_hashes.extend(result.get("done_hashes") or []) - episode_count += int(result.get("episode_count") or 0) - fallback_count += int(result.get("fallback_count") or 0) - except Exception as e: - err = str(e)[:500] - for h in group_hashes: - if h: - failed_hashes[h] = err - - dedup_done = list(dict.fromkeys([h for h in done_hashes if h])) - return { - "done_hashes": dedup_done, - "failed_hashes": failed_hashes, - "episode_count": episode_count, - "fallback_count": fallback_count, - "missing_count": len(missing_hashes), - "group_count": len(groups), - } - - async def rebuild_source(self, source: str) -> Dict[str, Any]: - token = str(source or "").strip() - if not token: - return { - "source": "", - "episode_count": 0, - "fallback_count": 0, - "group_count": 0, - "paragraph_count": 0, - } - - paragraphs = self.metadata_store.get_live_paragraphs_by_source(token) - if not paragraphs: - replace_result = self.metadata_store.replace_episodes_for_source(token, []) - return { - "source": token, - "episode_count": int(replace_result.get("episode_count") or 0), - "fallback_count": 0, - "group_count": 0, - "paragraph_count": 0, - } - - groups = self.group_paragraphs(paragraphs) - payloads: List[Dict[str, Any]] = [] - fallback_count = 0 - - for group in groups: - result = await self._build_episode_payloads_for_group(group) - payloads.extend(list(result.get("payloads") or [])) - fallback_count += int(result.get("fallback_count") or 0) - - replace_result = self.metadata_store.replace_episodes_for_source(token, payloads) - return { - "source": token, - "episode_count": int(replace_result.get("episode_count") or 0), - "fallback_count": fallback_count, - "group_count": len(groups), - "paragraph_count": 
len(paragraphs), - } diff --git a/plugins/A_memorix/core/utils/hash.py b/plugins/A_memorix/core/utils/hash.py deleted file mode 100644 index b6363257..00000000 --- a/plugins/A_memorix/core/utils/hash.py +++ /dev/null @@ -1,129 +0,0 @@ -""" -哈希工具模块 - -提供文本哈希计算功能,用于唯一标识和去重。 -""" - -import hashlib -import re -from typing import Union - - -def compute_hash(text: str, hash_type: str = "sha256") -> str: - """ - 计算文本的哈希值 - - Args: - text: 输入文本 - hash_type: 哈希算法类型(sha256, md5等) - - Returns: - 哈希值字符串 - """ - if hash_type == "sha256": - return hashlib.sha256(text.encode("utf-8")).hexdigest() - elif hash_type == "md5": - return hashlib.md5(text.encode("utf-8")).hexdigest() - else: - raise ValueError(f"不支持的哈希算法: {hash_type}") - - -def normalize_text(text: str) -> str: - """ - 规范化文本用于哈希计算 - - 执行以下操作: - - 去除首尾空白 - - 统一换行符为\\n - - 压缩多个连续空格 - - 去除不可见字符(保留换行和制表符) - - Args: - text: 输入文本 - - Returns: - 规范化后的文本 - """ - # 去除首尾空白 - text = text.strip() - - # 统一换行符 - text = text.replace("\r\n", "\n").replace("\r", "\n") - - # 压缩多个连续空格为一个(但保留换行和制表符) - text = re.sub(r"[^\S\n]+", " ", text) - - return text - - -def compute_paragraph_hash(paragraph: str) -> str: - """ - 计算段落的哈希值 - - Args: - paragraph: 段落文本 - - Returns: - 段落哈希值(用于paragraph-前缀) - """ - normalized = normalize_text(paragraph) - return compute_hash(normalized) - - -def compute_entity_hash(entity: str) -> str: - """ - 计算实体的哈希值 - - Args: - entity: 实体名称 - - Returns: - 实体哈希值(用于entity-前缀) - """ - normalized = entity.strip().lower() - return compute_hash(normalized) - - -def compute_relation_hash(relation: tuple) -> str: - """ - 计算关系的哈希值 - - Args: - relation: 关系元组 (subject, predicate, object) - - Returns: - 关系哈希值(用于relation-前缀) - """ - # 将关系元组转为字符串 - relation_str = str(tuple(relation)) - return compute_hash(relation_str) - - -def format_hash_key(hash_type: str, hash_value: str) -> str: - """ - 格式化哈希键 - - Args: - hash_type: 类型前缀(paragraph, entity, relation) - hash_value: 哈希值 - - Returns: - 格式化的键(如 paragraph-abc123...) - """ - return f"{hash_type}-{hash_value}" - - -def parse_hash_key(key: str) -> tuple[str, str]: - """ - 解析哈希键 - - Args: - key: 格式化的键(如 paragraph-abc123...) 
- - Returns: - (类型, 哈希值) 元组 - """ - parts = key.split("-", 1) - if len(parts) != 2: - raise ValueError(f"无效的哈希键格式: {key}") - return parts[0], parts[1] diff --git a/plugins/A_memorix/core/utils/import_payloads.py b/plugins/A_memorix/core/utils/import_payloads.py deleted file mode 100644 index 6986a4c1..00000000 --- a/plugins/A_memorix/core/utils/import_payloads.py +++ /dev/null @@ -1,110 +0,0 @@ -"""Shared import payload normalization helpers.""" - -from __future__ import annotations - -from typing import Any, Dict, List, Optional - -from ..storage import KnowledgeType, resolve_stored_knowledge_type -from .time_parser import normalize_time_meta - - -def _normalize_entities(raw_entities: Any) -> List[str]: - if not isinstance(raw_entities, list): - return [] - out: List[str] = [] - seen = set() - for item in raw_entities: - name = str(item or "").strip() - if not name: - continue - key = name.lower() - if key in seen: - continue - seen.add(key) - out.append(name) - return out - - -def _normalize_relations(raw_relations: Any) -> List[Dict[str, str]]: - if not isinstance(raw_relations, list): - return [] - out: List[Dict[str, str]] = [] - for item in raw_relations: - if not isinstance(item, dict): - continue - subject = str(item.get("subject", "")).strip() - predicate = str(item.get("predicate", "")).strip() - obj = str(item.get("object", "")).strip() - if not (subject and predicate and obj): - continue - out.append( - { - "subject": subject, - "predicate": predicate, - "object": obj, - } - ) - return out - - -def normalize_paragraph_import_item( - item: Any, - *, - default_source: str, -) -> Dict[str, Any]: - """Normalize one paragraph import item from text/json payloads.""" - - if isinstance(item, str): - content = str(item) - knowledge_type = resolve_stored_knowledge_type(None, content=content) - return { - "content": content, - "knowledge_type": knowledge_type.value, - "source": str(default_source or "").strip(), - "time_meta": None, - "entities": [], - "relations": [], - } - - if not isinstance(item, dict) or "content" not in item: - raise ValueError("段落项必须为字符串或包含 content 的对象") - - content = str(item.get("content", "") or "") - if not content.strip(): - raise ValueError("段落 content 不能为空") - - raw_time_meta = { - "event_time": item.get("event_time"), - "event_time_start": item.get("event_time_start"), - "event_time_end": item.get("event_time_end"), - "time_range": item.get("time_range"), - "time_granularity": item.get("time_granularity"), - "time_confidence": item.get("time_confidence"), - } - time_meta_field = item.get("time_meta") - if isinstance(time_meta_field, dict): - raw_time_meta.update(time_meta_field) - - knowledge_type_raw = item.get("knowledge_type") - if knowledge_type_raw is None: - knowledge_type_raw = item.get("type") - knowledge_type = resolve_stored_knowledge_type(knowledge_type_raw, content=content) - source = str(item.get("source") or default_source or "").strip() - if not source: - source = str(default_source or "").strip() - - normalized_time_meta = normalize_time_meta(raw_time_meta) - return { - "content": content, - "knowledge_type": knowledge_type.value, - "source": source, - "time_meta": normalized_time_meta if normalized_time_meta else None, - "entities": _normalize_entities(item.get("entities")), - "relations": _normalize_relations(item.get("relations")), - } - - -def normalize_summary_knowledge_type(value: Any) -> KnowledgeType: - """Normalize config-driven summary knowledge type.""" - - return resolve_stored_knowledge_type(value, content="") diff --git 
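Taken together, the hash utilities above produce normalization-stable, type-prefixed keys: whitespace runs collapse before hashing so visually identical paragraphs dedupe to one key. A self-contained sketch of the paragraph round trip:

```python
import hashlib
import re

def normalize_text(text: str) -> str:
    text = text.strip().replace("\r\n", "\n").replace("\r", "\n")
    return re.sub(r"[^\S\n]+", " ", text)  # collapse spaces/tabs, keep newlines

def paragraph_key(paragraph: str) -> str:
    digest = hashlib.sha256(normalize_text(paragraph).encode("utf-8")).hexdigest()
    return f"paragraph-{digest}"  # format_hash_key("paragraph", digest)

key = paragraph_key("hello   world\r\n")
assert key == paragraph_key("hello world\n")  # normalization-stable dedup key
hash_type, hash_value = key.split("-", 1)     # parse_hash_key equivalent
assert hash_type == "paragraph" and len(hash_value) == 64
```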
a/plugins/A_memorix/core/utils/io.py b/plugins/A_memorix/core/utils/io.py deleted file mode 100644 index ed14df43..00000000 --- a/plugins/A_memorix/core/utils/io.py +++ /dev/null @@ -1,84 +0,0 @@ -""" -IO Utilities - -提供原子文件写入等IO辅助功能。 -""" - -import os -import shutil -import contextlib -from pathlib import Path -from typing import Union - -@contextlib.contextmanager -def atomic_write(file_path: Union[str, Path], mode: str = "w", encoding: str = None, **kwargs): - """ - 原子文件写入上下文管理器 - - 原理: - 1. 写入 .tmp 临时文件 - 2. 写入成功后,使用 os.replace 原子替换目标文件 - 3. 如果失败,自动删除临时文件 - - Args: - file_path: 目标文件路径 - mode: 打开模式 ('w', 'wb' 等) - encoding: 编码 - **kwargs: 传给 open() 的其他参数 - """ - path = Path(file_path) - # 确保父目录存在 - path.parent.mkdir(parents=True, exist_ok=True) - - # 临时文件路径 - tmp_path = path.with_suffix(path.suffix + ".tmp") - - try: - with open(tmp_path, mode, encoding=encoding, **kwargs) as f: - yield f - - # 确保写入磁盘 - f.flush() - os.fsync(f.fileno()) - - # 原子替换 (Windows下可能需要先删除目标文件,但 os.replace 在 Py3.3+ 尽可能原子) - # 注意: Windows 上如果有其他进程占用文件,os.replace 可能会失败 - os.replace(tmp_path, path) - - except Exception as e: - # 清理临时文件 - if tmp_path.exists(): - try: - os.remove(tmp_path) - except: - pass - raise e - -@contextlib.contextmanager -def atomic_save_path(file_path: Union[str, Path]): - """ - 提供临时路径用于原子保存 (针对只接受路径的API,如Faiss) - - Args: - file_path: 最终目标路径 - - Yields: - tmp_path: 临时文件路径 (str) - """ - path = Path(file_path) - path.parent.mkdir(parents=True, exist_ok=True) - tmp_path = path.with_suffix(path.suffix + ".tmp") - - try: - yield str(tmp_path) - - if Path(tmp_path).exists(): - os.replace(tmp_path, path) - - except Exception as e: - if Path(tmp_path).exists(): - try: - os.remove(tmp_path) - except: - pass - raise e diff --git a/plugins/A_memorix/core/utils/matcher.py b/plugins/A_memorix/core/utils/matcher.py deleted file mode 100644 index bddff5ee..00000000 --- a/plugins/A_memorix/core/utils/matcher.py +++ /dev/null @@ -1,89 +0,0 @@ -""" -高效文本匹配工具模块 - -实现 Aho-Corasick 算法用于多模式匹配。 -""" - -from typing import List, Dict, Tuple, Set, Any -from collections import deque - - -class AhoCorasick: - """ - Aho-Corasick 自动机实现高效多模式匹配 - """ - - def __init__(self): - # next_states[state][char] = next_state - self.next_states: List[Dict[str, int]] = [{}] - # fail[state] = fail_state - self.fail: List[int] = [0] - # output[state] = set of patterns ending at this state - self.output: List[Set[str]] = [set()] - self.patterns: Set[str] = set() - - def add_pattern(self, pattern: str): - """添加模式""" - if not pattern: - return - self.patterns.add(pattern) - state = 0 - for char in pattern: - if char not in self.next_states[state]: - new_state = len(self.next_states) - self.next_states[state][char] = new_state - self.next_states.append({}) - self.fail.append(0) - self.output.append(set()) - state = self.next_states[state][char] - self.output[state].add(pattern) - - def build(self): - """构建失败指针""" - queue = deque() - # 处理第一层 - for char, state in self.next_states[0].items(): - queue.append(state) - self.fail[state] = 0 - - while queue: - r = queue.popleft() - for char, s in self.next_states[r].items(): - queue.append(s) - # 找到失败路径 - state = self.fail[r] - while char not in self.next_states[state] and state != 0: - state = self.fail[state] - self.fail[s] = self.next_states[state].get(char, 0) - # 合并输出 - self.output[s].update(self.output[self.fail[s]]) - - def search(self, text: str) -> List[Tuple[int, str]]: - """ - 在文本中搜索所有模式 - - Returns: - [(结束索引, 匹配到的模式), ...] 
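The atomic-write pattern in io.py stages output in a `.tmp` sibling and promotes it with `os.replace`, so readers never observe a half-written file. A minimal standalone sketch of the same pattern (the target path is hypothetical):

```python
import os
from pathlib import Path

def write_atomically(path: Path, data: str) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp = path.with_suffix(path.suffix + ".tmp")
    try:
        with open(tmp, "w", encoding="utf-8") as f:
            f.write(data)
            f.flush()
            os.fsync(f.fileno())  # data reaches disk before the rename below
        os.replace(tmp, path)     # atomic on POSIX; best effort on Windows
    except Exception:
        tmp.unlink(missing_ok=True)  # never leave a stale .tmp behind
        raise

write_atomically(Path("data/index.json"), '{"ok": true}')
```

`atomic_save_path` covers the same pattern for APIs that only accept a destination path, such as the Faiss save call its docstring mentions.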
- """ - state = 0 - results = [] - for i, char in enumerate(text): - while char not in self.next_states[state] and state != 0: - state = self.fail[state] - state = self.next_states[state].get(char, 0) - for pattern in self.output[state]: - results.append((i, pattern)) - return results - - def find_all(self, text: str) -> Dict[str, int]: - """ - 查找并统计所有模式出现次数 - - Returns: - {模式: 出现次数} - """ - results = self.search(text) - stats = {} - for _, pattern in results: - stats[pattern] = stats.get(pattern, 0) + 1 - return stats diff --git a/plugins/A_memorix/core/utils/monitor.py b/plugins/A_memorix/core/utils/monitor.py deleted file mode 100644 index 39c794ab..00000000 --- a/plugins/A_memorix/core/utils/monitor.py +++ /dev/null @@ -1,189 +0,0 @@ -""" -内存监控模块 - -提供内存使用监控和预警功能。 -""" - -import gc -import threading -import time -from typing import Callable, Optional - -try: - import psutil - HAS_PSUTIL = True -except ImportError: - HAS_PSUTIL = False - -from src.common.logger import get_logger - -logger = get_logger("A_Memorix.MemoryMonitor") - - -class MemoryMonitor: - """ - 内存监控器 - - 功能: - - 实时监控内存使用 - - 超过阈值时触发警告 - - 支持自动垃圾回收 - """ - - def __init__( - self, - max_memory_mb: int, - warning_threshold: float = 0.9, - check_interval: float = 10.0, - enable_auto_gc: bool = True, - ): - """ - 初始化内存监控器 - - Args: - max_memory_mb: 最大内存限制(MB) - warning_threshold: 警告阈值(0-1之间,默认0.9表示90%) - check_interval: 检查间隔(秒) - enable_auto_gc: 是否启用自动垃圾回收 - """ - self.max_memory_mb = max_memory_mb - self.warning_threshold = warning_threshold - self.check_interval = check_interval - self.enable_auto_gc = enable_auto_gc - - self._running = False - self._thread: Optional[threading.Thread] = None - self._callbacks: list[Callable[[float, float], None]] = [] - - def start(self): - """启动监控""" - if self._running: - logger.warning("内存监控已在运行") - return - - if not HAS_PSUTIL: - logger.warning("psutil 未安装,内存监控功能不可用") - return - - self._running = True - self._thread = threading.Thread(target=self._monitor_loop, daemon=True) - self._thread.start() - logger.info(f"内存监控已启动 (限制: {self.max_memory_mb}MB)") - - def stop(self): - """停止监控""" - if not self._running: - return - - self._running = False - if self._thread: - self._thread.join(timeout=5.0) - logger.info("内存监控已停止") - - def register_callback(self, callback: Callable[[float, float], None]): - """ - 注册内存超限回调函数 - - Args: - callback: 回调函数,接收 (当前使用MB, 限制MB) 参数 - """ - self._callbacks.append(callback) - - def get_current_memory_mb(self) -> float: - """ - 获取当前进程内存使用量(MB) - - Returns: - 内存使用量(MB) - """ - if not HAS_PSUTIL: - # 降级方案:使用内置函数 - import sys - return sys.getsizeof(gc.get_objects()) / 1024 / 1024 - - process = psutil.Process() - return process.memory_info().rss / 1024 / 1024 - - def get_memory_usage_ratio(self) -> float: - """ - 获取内存使用率 - - Returns: - 使用率(0-1之间) - """ - current = self.get_current_memory_mb() - return current / self.max_memory_mb if self.max_memory_mb > 0 else 0 - - def _monitor_loop(self): - """监控循环""" - while self._running: - try: - current_mb = self.get_current_memory_mb() - ratio = current_mb / self.max_memory_mb if self.max_memory_mb > 0 else 0 - - # 检查是否超过阈值 - if ratio >= self.warning_threshold: - logger.warning( - f"内存使用率过高: {current_mb:.1f}MB / {self.max_memory_mb}MB " - f"({ratio*100:.1f}%)" - ) - - # 触发回调 - for callback in self._callbacks: - try: - callback(current_mb, self.max_memory_mb) - except Exception as e: - logger.error(f"内存回调执行失败: {e}") - - # 自动垃圾回收 - if self.enable_auto_gc: - before = self.get_current_memory_mb() - gc.collect() - after = 
self.get_current_memory_mb() - freed = before - after - if freed > 1: # 释放超过1MB才记录 - logger.info(f"垃圾回收释放: {freed:.1f}MB") - - # 定期垃圾回收(即使未超限) - elif self.enable_auto_gc and int(time.time()) % 60 == 0: - gc.collect() - - except Exception as e: - logger.error(f"内存监控出错: {e}") - - # 等待下次检查 - time.sleep(self.check_interval) - - def __enter__(self): - """上下文管理器入口""" - self.start() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - """上下文管理器出口""" - self.stop() - - -def get_memory_info() -> dict: - """ - 获取系统内存信息 - - Returns: - 内存信息字典 - """ - if not HAS_PSUTIL: - return {"error": "psutil 未安装"} - - try: - mem = psutil.virtual_memory() - process = psutil.Process() - - return { - "system_total_gb": mem.total / 1024 / 1024 / 1024, - "system_available_gb": mem.available / 1024 / 1024 / 1024, - "system_usage_percent": mem.percent, - "process_mb": process.memory_info().rss / 1024 / 1024, - "process_percent": (process.memory_info().rss / mem.total) * 100, - } - except Exception as e: - return {"error": str(e)} diff --git a/plugins/A_memorix/core/utils/path_fallback_service.py b/plugins/A_memorix/core/utils/path_fallback_service.py deleted file mode 100644 index 7a802743..00000000 --- a/plugins/A_memorix/core/utils/path_fallback_service.py +++ /dev/null @@ -1,165 +0,0 @@ -"""Shared path-fallback helpers for search post-processing.""" - -from __future__ import annotations - -import hashlib -from typing import Any, Dict, List, Optional, Sequence, Tuple - -from ..retrieval.dual_path import RetrievalResult - - -def extract_entities(query: str, graph_store: Any) -> List[str]: - """Extract up to two graph nodes from a query using n-gram matching.""" - if not graph_store: - return [] - - text = str(query or "").strip() - if not text: - return [] - - # Keep the heuristic aligned with previous legacy behavior. 
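`MemoryMonitor` is also a context manager, so a bounded workload can be wrapped directly, with callbacks firing when usage crosses the warning threshold. A hypothetical wiring (both `on_pressure` and `run_ingest_batch` are made-up names; importing the class assumes the host repo's `src.common.logger` is available):

```python
from plugins.A_memorix.core.utils.monitor import MemoryMonitor

def on_pressure(current_mb: float, limit_mb: float) -> None:
    # Hypothetical reaction: log and shed caches once usage crosses 90%.
    print(f"memory pressure: {current_mb:.0f}MB / {limit_mb:.0f}MB")

def run_ingest_batch() -> None:
    ...  # hypothetical long-running work being guarded

monitor = MemoryMonitor(max_memory_mb=512, warning_threshold=0.9, check_interval=5.0)
monitor.register_callback(on_pressure)
with monitor:  # start() on enter, stop() on exit
    run_ingest_batch()
```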
- tokens = ( - text.replace("?", " ") - .replace("!", " ") - .replace(".", " ") - .split() - ) - if not tokens: - return [] - - found_entities = set() - skip_indices = set() - max_n = min(4, len(tokens)) - - for size in range(max_n, 0, -1): - for i in range(len(tokens) - size + 1): - if any(idx in skip_indices for idx in range(i, i + size)): - continue - span = " ".join(tokens[i : i + size]) - matched_node = graph_store.find_node(span, ignore_case=True) - if not matched_node: - continue - found_entities.add(matched_node) - for idx in range(i, i + size): - skip_indices.add(idx) - - return list(found_entities) - - -def find_paths_between_entities( - start_node: str, - end_node: str, - graph_store: Any, - metadata_store: Any, - *, - max_depth: int = 3, - max_paths: int = 5, -) -> List[Dict[str, Any]]: - """Find and enrich indirect paths between two nodes.""" - if not graph_store or not metadata_store: - return [] - - try: - paths = graph_store.find_paths( - start_node, - end_node, - max_depth=max_depth, - max_paths=max_paths, - ) - except Exception: - return [] - - if not paths: - return [] - - edge_cache: Dict[Tuple[str, str], Tuple[str, str]] = {} - formatted_paths: List[Dict[str, Any]] = [] - - for path_nodes in paths: - if not isinstance(path_nodes, Sequence) or len(path_nodes) < 2: - continue - - path_desc: List[str] = [] - for i in range(len(path_nodes) - 1): - u = str(path_nodes[i]) - v = str(path_nodes[i + 1]) - - cache_key = tuple(sorted((u, v))) - if cache_key in edge_cache: - pred, direction = edge_cache[cache_key] - else: - pred = "related" - direction = "->" - rels = metadata_store.get_relations(subject=u, object=v) - if not rels: - rels = metadata_store.get_relations(subject=v, object=u) - direction = "<-" - if rels: - best_rel = max(rels, key=lambda x: x.get("confidence", 1.0)) - pred = str(best_rel.get("predicate", "related") or "related") - edge_cache[cache_key] = (pred, direction) - - step_str = f"-[{pred}]->" if direction == "->" else f"<-[{pred}]-" - path_desc.append(step_str) - - full_path_str = str(path_nodes[0]) - for i, step in enumerate(path_desc): - full_path_str += f" {step} {path_nodes[i + 1]}" - - formatted_paths.append( - { - "nodes": list(path_nodes), - "description": full_path_str, - } - ) - - return formatted_paths - - -def find_paths_from_query( - query: str, - graph_store: Any, - metadata_store: Any, - *, - max_depth: int = 3, - max_paths: int = 5, -) -> List[Dict[str, Any]]: - """Extract entities from query and resolve indirect paths.""" - entities = extract_entities(query, graph_store) - if len(entities) != 2: - return [] - return find_paths_between_entities( - entities[0], - entities[1], - graph_store, - metadata_store, - max_depth=max_depth, - max_paths=max_paths, - ) - - -def to_retrieval_results(paths: Sequence[Dict[str, Any]]) -> List[RetrievalResult]: - """Convert path results into retrieval results for the unified pipeline.""" - converted: List[RetrievalResult] = [] - for item in paths: - description = str(item.get("description", "")).strip() - if not description: - continue - hash_seed = description.encode("utf-8") - path_hash = f"path_{hashlib.sha1(hash_seed).hexdigest()}" - converted.append( - RetrievalResult( - hash_value=path_hash, - content=f"[Indirect Relation] {description}", - score=0.95, - result_type="relation", - source="graph_path", - metadata={ - "source": "graph_path", - "is_indirect": True, - "nodes": list(item.get("nodes", [])), - }, - ) - ) - return converted - diff --git a/plugins/A_memorix/core/utils/person_profile_service.py 
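`find_paths_between_entities` renders each node path as `A -[pred]-> B <-[pred]- C`, picking the highest-confidence relation per edge, and `to_retrieval_results` repackages those strings as pseudo-relation hits with a fixed 0.95 score. A sketch of the edge-rendering rule with a hypothetical two-hop path:

```python
def render_path(nodes: list, edge_lookup) -> str:
    """edge_lookup(u, v) -> (predicate, '->' or '<-'), as in the service above."""
    text = nodes[0]
    for u, v in zip(nodes, nodes[1:]):
        pred, direction = edge_lookup(u, v)
        arrow = f"-[{pred}]->" if direction == "->" else f"<-[{pred}]-"
        text += f" {arrow} {v}"
    return text

edges = {("Alice", "Lab"): ("works_at", "->"), ("Lab", "Bob"): ("founded", "<-")}
print(render_path(["Alice", "Lab", "Bob"], lambda u, v: edges[(u, v)]))
# Alice -[works_at]-> Lab <-[founded]- Bob
```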
b/plugins/A_memorix/core/utils/person_profile_service.py deleted file mode 100644 index 6460c013..00000000 --- a/plugins/A_memorix/core/utils/person_profile_service.py +++ /dev/null @@ -1,554 +0,0 @@ -""" -人物画像服务 - -主链路: -person_id -> 用户名/别名 -> 图谱关系 + 向量证据 -> 证据总结画像 -> 快照版本化存储 -""" - -import json -import time -from typing import Any, Dict, List, Optional, Tuple - -from sqlalchemy import or_ -from sqlmodel import select - -from src.common.logger import get_logger -from src.common.database.database import get_db_session -from src.common.database.database_model import PersonInfo - -from ..embedding import EmbeddingAPIAdapter -from ..retrieval import ( - DualPathRetriever, - RetrievalStrategy, - DualPathRetrieverConfig, - SparseBM25Config, - FusionConfig, - GraphRelationRecallConfig, -) -from ..storage import MetadataStore, GraphStore, VectorStore - -logger = get_logger("A_Memorix.PersonProfileService") - - -class PersonProfileService: - """人物画像聚合/刷新服务。""" - - def __init__( - self, - metadata_store: MetadataStore, - graph_store: Optional[GraphStore] = None, - vector_store: Optional[VectorStore] = None, - embedding_manager: Optional[EmbeddingAPIAdapter] = None, - sparse_index: Any = None, - plugin_config: Optional[dict] = None, - retriever: Optional[DualPathRetriever] = None, - ): - self.metadata_store = metadata_store - self.graph_store = graph_store - self.vector_store = vector_store - self.embedding_manager = embedding_manager - self.sparse_index = sparse_index - self.plugin_config = plugin_config or {} - self.retriever = retriever or self._build_retriever() - - def _cfg(self, key: str, default: Any = None) -> Any: - """读取嵌套配置。""" - if not isinstance(self.plugin_config, dict): - return default - current: Any = self.plugin_config - for part in key.split("."): - if isinstance(current, dict) and part in current: - current = current[part] - else: - return default - return current - - def _build_retriever(self) -> Optional[DualPathRetriever]: - """按需构建检索器(无依赖时返回 None)。""" - if not all( - [ - self.vector_store is not None, - self.graph_store is not None, - self.metadata_store is not None, - self.embedding_manager is not None, - ] - ): - return None - try: - sparse_cfg_raw = self._cfg("retrieval.sparse", {}) or {} - fusion_cfg_raw = self._cfg("retrieval.fusion", {}) or {} - graph_recall_cfg_raw = self._cfg("retrieval.search.graph_recall", {}) or {} - if not isinstance(sparse_cfg_raw, dict): - sparse_cfg_raw = {} - if not isinstance(fusion_cfg_raw, dict): - fusion_cfg_raw = {} - if not isinstance(graph_recall_cfg_raw, dict): - graph_recall_cfg_raw = {} - - sparse_cfg = SparseBM25Config(**sparse_cfg_raw) - fusion_cfg = FusionConfig(**fusion_cfg_raw) - graph_recall_cfg = GraphRelationRecallConfig(**graph_recall_cfg_raw) - config = DualPathRetrieverConfig( - top_k_paragraphs=int(self._cfg("retrieval.top_k_paragraphs", 20)), - top_k_relations=int(self._cfg("retrieval.top_k_relations", 10)), - top_k_final=int(self._cfg("retrieval.top_k_final", 10)), - alpha=float(self._cfg("retrieval.alpha", 0.5)), - enable_ppr=bool(self._cfg("retrieval.enable_ppr", True)), - ppr_alpha=float(self._cfg("retrieval.ppr_alpha", 0.85)), - ppr_concurrency_limit=int(self._cfg("retrieval.ppr_concurrency_limit", 4)), - enable_parallel=bool(self._cfg("retrieval.enable_parallel", True)), - retrieval_strategy=RetrievalStrategy.DUAL_PATH, - debug=bool(self._cfg("advanced.debug", False)), - sparse=sparse_cfg, - fusion=fusion_cfg, - graph_recall=graph_recall_cfg, - ) - return DualPathRetriever( - vector_store=self.vector_store, - 
graph_store=self.graph_store, - metadata_store=self.metadata_store, - embedding_manager=self.embedding_manager, - sparse_index=self.sparse_index, - config=config, - ) - except Exception as e: - logger.warning(f"初始化人物画像检索器失败,将只使用关系证据: {e}") - return None - - @staticmethod - def resolve_person_id(identifier: str) -> str: - """按 person_id 或姓名/别名解析 person_id。""" - if not identifier: - return "" - key = str(identifier).strip() - if not key: - return "" - - try: - with get_db_session(auto_commit=False) as session: - record = session.exec( - select(PersonInfo.person_id).where(PersonInfo.person_id == key).limit(1) - ).first() - if record: - return str(record) - - record = session.exec( - select(PersonInfo.person_id) - .where( - or_( - PersonInfo.person_name == key, - PersonInfo.user_nickname == key, - ) - ) - .limit(1) - ).first() - if record: - return str(record) - - record = session.exec( - select(PersonInfo.person_id) - .where(PersonInfo.group_cardname.contains(key)) - .limit(1) - ).first() - if record: - return str(record) - except Exception as e: - logger.warning(f"按别名解析 person_id 失败: identifier={key}, err={e}") - - if len(key) == 32 and all(ch in "0123456789abcdefABCDEF" for ch in key): - return key.lower() - - return "" - - def _parse_group_nicks(self, raw_value: Any) -> List[str]: - if not raw_value: - return [] - if isinstance(raw_value, list): - items = raw_value - else: - try: - items = json.loads(raw_value) - except Exception: - return [] - names: List[str] = [] - for item in items: - if isinstance(item, dict): - value = str(item.get("group_cardname") or item.get("group_nick_name") or "").strip() - if value: - names.append(value) - elif isinstance(item, str): - value = item.strip() - if value: - names.append(value) - return names - - def _parse_memory_traits(self, raw_value: Any) -> List[str]: - if not raw_value: - return [] - try: - values = json.loads(raw_value) if isinstance(raw_value, str) else raw_value - except Exception: - return [] - if not isinstance(values, list): - return [] - traits: List[str] = [] - for item in values: - text = str(item).strip() - if not text: - continue - if ":" in text: - parts = text.split(":") - if len(parts) >= 3: - content = ":".join(parts[1:-1]).strip() - if content: - traits.append(content) - continue - traits.append(text) - return traits[:10] - - def _recover_aliases_from_memory(self, person_id: str) -> Tuple[List[str], str]: - """当人物主档案缺失时,从已有记忆证据里回捞可用别名。""" - if not person_id: - return [], "" - - aliases: List[str] = [] - primary_name = "" - seen = set() - - try: - paragraphs = self.metadata_store.get_paragraphs_by_entity(person_id) - except Exception as e: - logger.warning(f"从记忆证据回捞人物别名失败: person_id={person_id}, err={e}") - return [], "" - - for paragraph in paragraphs[:20]: - paragraph_hash = str(paragraph.get("hash", "") or "").strip() - if not paragraph_hash: - continue - try: - paragraph_entities = self.metadata_store.get_paragraph_entities(paragraph_hash) - except Exception: - paragraph_entities = [] - for entity in paragraph_entities: - name = str(entity.get("name", "") or "").strip() - if not name or name == person_id: - continue - key = name.lower() - if key in seen: - continue - seen.add(key) - aliases.append(name) - if not primary_name: - primary_name = name - return aliases, primary_name - - def get_person_aliases(self, person_id: str) -> Tuple[List[str], str, List[str]]: - """获取人物别名集合、主展示名、记忆特征。""" - aliases: List[str] = [] - primary_name = "" - memory_traits: List[str] = [] - if not person_id: - return aliases, primary_name, 
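`resolve_person_id`'s last resort is purely syntactic: a 32-character hex token is accepted as the person_id itself and lowercased. The rule in isolation:

```python
def fallback_person_id(key: str) -> str:
    # A 32-char hex token is taken as the ID itself, normalized to lowercase.
    if len(key) == 32 and all(c in "0123456789abcdefABCDEF" for c in key):
        return key.lower()
    return ""

assert fallback_person_id("0123456789ABCDEF0123456789abcdef") == "0123456789abcdef0123456789abcdef"
assert fallback_person_id("alice") == ""
```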
- - def get_person_aliases(self, person_id: str) -> Tuple[List[str], str, List[str]]: - """Return the person's alias set, primary display name, and memory traits.""" - aliases: List[str] = [] - primary_name = "" - memory_traits: List[str] = [] - if not person_id: - return aliases, primary_name, memory_traits - recovered_aliases, recovered_primary_name = self._recover_aliases_from_memory(person_id) - try: - with get_db_session(auto_commit=False) as session: - record = session.exec( - select(PersonInfo).where(PersonInfo.person_id == person_id).limit(1) - ).first() - if not record: - return recovered_aliases, recovered_primary_name or person_id, memory_traits - person_name = str(getattr(record, "person_name", "") or "").strip() - nickname = str(getattr(record, "user_nickname", "") or "").strip() - group_nicks = self._parse_group_nicks(getattr(record, "group_cardname", None)) - memory_traits = self._parse_memory_traits(getattr(record, "memory_points", None)) - - primary_name = ( - person_name - or nickname - or recovered_primary_name - or str(getattr(record, "user_id", "") or "").strip() - or person_id - ) - - candidates = [person_name, nickname] + group_nicks + recovered_aliases - seen = set() - for item in candidates: - norm = str(item or "").strip() - if not norm or norm in seen: - continue - seen.add(norm) - aliases.append(norm) - except Exception as e: - logger.warning(f"解析人物别名失败: person_id={person_id}, err={e}") - return aliases, primary_name, memory_traits - - def _collect_relation_evidence(self, aliases: List[str], limit: int = 30) -> List[Dict[str, Any]]: - relation_by_hash: Dict[str, Dict[str, Any]] = {} - for alias in aliases: - for rel in self.metadata_store.get_relations(subject=alias): - h = str(rel.get("hash", "")) - if h: - relation_by_hash[h] = rel - for rel in self.metadata_store.get_relations(object=alias): - h = str(rel.get("hash", "")) - if h: - relation_by_hash[h] = rel - - relations = list(relation_by_hash.values()) - relations.sort(key=lambda item: float(item.get("confidence", 0.0)), reverse=True) - relations = relations[: max(1, int(limit))] - - edges: List[Dict[str, Any]] = [] - for rel in relations: - edges.append( - { - "hash": str(rel.get("hash", "")), - "subject": str(rel.get("subject", "")), - "predicate": str(rel.get("predicate", "")), - "object": str(rel.get("object", "")), - "confidence": float(rel.get("confidence", 1.0) or 1.0), - } - ) - return edges - - async def _collect_vector_evidence(self, aliases: List[str], top_k: int = 12) -> List[Dict[str, Any]]: - alias_queries = [a for a in aliases if a] - if not alias_queries: - return [] - - if self.retriever is None: - # Fallback: without a retriever, do simple content matching only - fallback: List[Dict[str, Any]] = [] - seen_hash = set() - for alias in alias_queries: - for para in self.metadata_store.search_paragraphs_by_content(alias)[: max(2, top_k // 2)]: - h = str(para.get("hash", "")) - if not h or h in seen_hash: - continue - seen_hash.add(h) - fallback.append( - { - "hash": h, - "type": "paragraph", - "score": 0.0, - "content": str(para.get("content", ""))[:180], - "metadata": {}, - } - ) - return fallback[:top_k] - - per_alias_top_k = max(2, int(top_k / max(1, len(alias_queries)))) - seen_hash = set() - evidence: List[Dict[str, Any]] = [] - for alias in alias_queries: - try: - results = await self.retriever.retrieve(alias, top_k=per_alias_top_k) - except Exception as e: - logger.warning(f"向量证据召回失败: alias={alias}, err={e}") - continue - for item in results: - h = str(getattr(item, "hash_value", "") or "") - if not h or h in seen_hash: - continue - seen_hash.add(h) - evidence.append( - { - "hash": h, - "type": str(getattr(item, "result_type", "")), - "score": float(getattr(item, "score", 0.0) or 0.0), - "content": str(getattr(item, "content", "") or "")[:220], - "metadata": dict(getattr(item, "metadata", {}) or {}), - } - ) - evidence.sort(key=lambda x: x.get("score", 0.0), reverse=True) - return evidence[:top_k]
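`_collect_vector_evidence` splits the overall top_k budget evenly across aliases (`max(2, top_k // len(aliases))` per alias) and dedupes hits by hash, keeping the highest-scoring items. A self-contained sketch of that merge step; the evidence dicts below are hypothetical:

from typing import Dict, List

def merge_evidence(batches: List[List[Dict]], top_k: int) -> List[Dict]:
    # First occurrence of each hash wins; then rank by score and truncate.
    seen, merged = set(), []
    for batch in batches:
        for item in batch:
            h = item.get("hash", "")
            if not h or h in seen:
                continue
            seen.add(h)
            merged.append(item)
    merged.sort(key=lambda x: x.get("score", 0.0), reverse=True)
    return merged[:top_k]

a = [{"hash": "h1", "score": 0.9}, {"hash": "h2", "score": 0.4}]
b = [{"hash": "h1", "score": 0.8}, {"hash": "h3", "score": 0.6}]
assert [x["hash"] for x in merge_evidence([a, b], top_k=2)] == ["h1", "h3"]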
- - def _build_profile_text( - self, - person_id: str, - primary_name: str, - aliases: List[str], - relation_edges: List[Dict[str, Any]], - vector_evidence: List[Dict[str, Any]], - memory_traits: List[str], - ) -> str: - """Build the profile text from evidence (for LLM context injection).""" - lines: List[str] = [] - lines.append(f"人物ID: {person_id}") - if primary_name: - lines.append(f"主称呼: {primary_name}") - if aliases: - lines.append(f"别名: {', '.join(aliases[:8])}") - if memory_traits: - lines.append(f"记忆特征: {'; '.join(memory_traits[:6])}") - - if relation_edges: - lines.append("关系证据:") - for rel in relation_edges[:6]: - s = rel.get("subject", "") - p = rel.get("predicate", "") - o = rel.get("object", "") - conf = float(rel.get("confidence", 0.0)) - lines.append(f"- {s} {p} {o} (conf={conf:.2f})") - - if vector_evidence: - lines.append("向量证据摘要:") - for item in vector_evidence[:4]: - content = str(item.get("content", "")).strip() - if content: - lines.append(f"- {content}") - - if len(lines) <= 2: - lines.append("暂无足够证据形成稳定画像。") - - return "\n".join(lines) - - @staticmethod - def _is_snapshot_stale(snapshot: Optional[Dict[str, Any]], ttl_seconds: float) -> bool: - if not snapshot: - return True - now = time.time() - expires_at = snapshot.get("expires_at") - if expires_at is not None: - try: - return now >= float(expires_at) - except Exception: - return True - updated_at = float(snapshot.get("updated_at") or 0.0) - return (now - updated_at) >= ttl_seconds - - def _apply_manual_override(self, person_id: str, profile_payload: Dict[str, Any]) -> Dict[str, Any]: - """Merge the manual override into the profile result (overrides profile_text while keeping auto_profile_text).""" - payload = dict(profile_payload or {}) - auto_text = str(payload.get("profile_text", "") or "") - payload["auto_profile_text"] = auto_text - payload["has_manual_override"] = False - payload["manual_override_text"] = "" - payload["override_updated_at"] = None - payload["override_updated_by"] = "" - payload["profile_source"] = "auto_snapshot" - - if not person_id or self.metadata_store is None: - return payload - - try: - override = self.metadata_store.get_person_profile_override(person_id) - except Exception as e: - logger.warning(f"读取人物画像手工覆盖失败: person_id={person_id}, err={e}") - return payload - - if not override: - return payload - - manual_text = str(override.get("override_text", "") or "").strip() - if not manual_text: - return payload - - payload["has_manual_override"] = True - payload["manual_override_text"] = manual_text - payload["override_updated_at"] = override.get("updated_at") - payload["override_updated_by"] = str(override.get("updated_by", "") or "") - payload["profile_text"] = manual_text - payload["profile_source"] = "manual_override" - return payload
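`_is_snapshot_stale` treats an explicit expires_at as authoritative and otherwise falls back to updated_at plus the TTL. A compact, runnable sketch of that rule:

import time
from typing import Any, Dict, Optional

def is_stale(snapshot: Optional[Dict[str, Any]], ttl_seconds: float) -> bool:
    # Missing snapshot -> stale; explicit expires_at wins; else age against TTL.
    if not snapshot:
        return True
    now = time.time()
    if snapshot.get("expires_at") is not None:
        return now >= float(snapshot["expires_at"])
    return (now - float(snapshot.get("updated_at") or 0.0)) >= ttl_seconds

assert is_stale(None, 60)
assert not is_stale({"updated_at": time.time()}, 6 * 3600)
assert is_stale({"updated_at": 0.0}, 6 * 3600)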
- - async def query_person_profile( - self, - person_id: str = "", - person_keyword: str = "", - top_k: int = 12, - ttl_seconds: float = 6 * 3600, - force_refresh: bool = False, - source_note: str = "", - ) -> Dict[str, Any]: - """Query or refresh a person profile.""" - pid = str(person_id or "").strip() - if not pid and person_keyword: - pid = self.resolve_person_id(person_keyword) - - if not pid: - return { - "success": False, - "error": "person_id 无效,且未能通过别名解析", - } - - latest = self.metadata_store.get_latest_person_profile_snapshot(pid) - if not force_refresh and not self._is_snapshot_stale(latest, ttl_seconds): - aliases, primary_name, _ = self.get_person_aliases(pid) - payload = { - "success": True, - "person_id": pid, - "person_name": primary_name, - "from_cache": True, - **(latest or {}), - } - if aliases and not payload.get("aliases"): - payload["aliases"] = aliases - return { - **self._apply_manual_override(pid, payload), - } - - aliases, primary_name, memory_traits = self.get_person_aliases(pid) - if not aliases and person_keyword: - aliases = [person_keyword.strip()] - primary_name = person_keyword.strip() - relation_edges = self._collect_relation_evidence(aliases, limit=max(10, top_k * 2)) - vector_evidence = await self._collect_vector_evidence(aliases, top_k=max(4, top_k)) - - evidence_ids = [ - str(item.get("hash", "")) - for item in (relation_edges + vector_evidence) - if str(item.get("hash", "")).strip() - ] - dedup_ids: List[str] = [] - seen = set() - for item in evidence_ids: - if item in seen: - continue - seen.add(item) - dedup_ids.append(item) - - profile_text = self._build_profile_text( - person_id=pid, - primary_name=primary_name, - aliases=aliases, - relation_edges=relation_edges, - vector_evidence=vector_evidence, - memory_traits=memory_traits, - ) - - expires_at = time.time() + float(ttl_seconds) if ttl_seconds > 0 else None - snapshot = self.metadata_store.upsert_person_profile_snapshot( - person_id=pid, - profile_text=profile_text, - aliases=aliases, - relation_edges=relation_edges, - vector_evidence=vector_evidence, - evidence_ids=dedup_ids, - expires_at=expires_at, - source_note=source_note, - ) - payload = { - "success": True, - "person_id": pid, - "person_name": primary_name, - "from_cache": False, - **snapshot, - } - return { - **self._apply_manual_override(pid, payload), - } - - @staticmethod - def format_persona_profile_block(profile: Dict[str, Any]) -> str: - """Format the context-injection block for the replyer.""" - if not profile or not profile.get("success"): - return "" - text = str(profile.get("profile_text", "") or "").strip() - if not text: - return "" - return ( - "【人物画像-内部参考】\n" - f"{text}\n" - "仅供内部推理,不要向用户逐字复述。" - ) diff --git a/plugins/A_memorix/core/utils/plugin_id_policy.py b/plugins/A_memorix/core/utils/plugin_id_policy.py deleted file mode 100644 index 8e730e12..00000000 --- a/plugins/A_memorix/core/utils/plugin_id_policy.py +++ /dev/null @@ -1,27 +0,0 @@ -"""Plugin ID matching policy for A_Memorix.""" - -from __future__ import annotations - -from typing import Any - - -class PluginIdPolicy: - """Centralized plugin id normalization/matching policy.""" - - CANONICAL_ID = "a_memorix" - - @classmethod - def normalize(cls, plugin_id: Any) -> str: - if not isinstance(plugin_id, str): - return "" - return plugin_id.strip().lower() - - @classmethod - def is_target_plugin_id(cls, plugin_id: Any) -> bool: - normalized = cls.normalize(plugin_id) - if not normalized: - return False - if normalized == cls.CANONICAL_ID: - return True - return normalized.split(".")[-1] == cls.CANONICAL_ID - diff --git a/plugins/A_memorix/core/utils/quantization.py b/plugins/A_memorix/core/utils/quantization.py deleted file mode 100644 index 4e84f977..00000000 --- a/plugins/A_memorix/core/utils/quantization.py +++ /dev/null @@ -1,344 +0,0 @@ -""" -Vector quantization utilities. - -Provides vector quantization and dequantization to reduce storage footprint. -""" - -import numpy as np -from enum import Enum -from typing import Optional, Tuple, Union - -from src.common.logger import get_logger - -logger = get_logger("A_Memorix.Quantization") - - -class QuantizationType(Enum): - """Quantization type enum.""" - FLOAT32 = "float32" # no quantization - INT8 = "int8" # scalar quantization (8-bit integers) - PQ = "pq" # product quantization
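Before the implementations below, a round-trip sketch of INT8 scalar quantization that keeps the (min, max) pair next to the codes; the module itself notes that persisting those two floats is left to real storage, so this sketch simply returns them to the caller:

import numpy as np

def int8_quantize(v: np.ndarray):
    # Normalize to [0, 255], shift to [-128, 127], and return codes plus the
    # (min, max) pair that a faithful dequantization needs.
    lo, hi = float(v.min()), float(v.max())
    if hi == lo:
        return np.zeros(v.shape, dtype=np.int8), lo, hi
    codes = np.round((v - lo) / (hi - lo) * 255.0 - 128.0).astype(np.int8)
    return codes, lo, hi

def int8_dequantize(codes: np.ndarray, lo: float, hi: float) -> np.ndarray:
    return (codes.astype(np.float32) + 128.0) / 255.0 * (hi - lo) + lo

v = np.random.default_rng(0).normal(size=8).astype(np.float32)
codes, lo, hi = int8_quantize(v)
# Worst-case rounding error is half a quantization step.
assert np.abs(int8_dequantize(codes, lo, hi) - v).max() <= (hi - lo) / 255.0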
- - -def quantize_vector( - vector: np.ndarray, - quant_type: QuantizationType = QuantizationType.INT8, -) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]: - """ - Quantize a vector. - - Args: - vector: input vector (float32) - quant_type: quantization type - - Returns: - The quantized vector: - - INT8: int8 vector - - PQ: (codes, centroids) tuple - """ - if quant_type == QuantizationType.FLOAT32: - return vector.astype(np.float32) - - elif quant_type == QuantizationType.INT8: - return _scalar_quantize_int8(vector) - - elif quant_type == QuantizationType.PQ: - return _product_quantize(vector) - - else: - raise ValueError(f"不支持的量化类型: {quant_type}") - - -def dequantize_vector( - quantized_vector: Union[np.ndarray, Tuple[np.ndarray, np.ndarray]], - quant_type: QuantizationType = QuantizationType.INT8, - original_shape: Optional[Tuple[int, ...]] = None, -) -> np.ndarray: - """ - Dequantize a vector. - - Args: - quantized_vector: the quantized vector - quant_type: quantization type - original_shape: original vector shape (used for PQ) - - Returns: - The dequantized vector (float32) - """ - if quant_type == QuantizationType.FLOAT32: - return quantized_vector.astype(np.float32) - - elif quant_type == QuantizationType.INT8: - return _scalar_dequantize_int8(quantized_vector) - - elif quant_type == QuantizationType.PQ: - if not isinstance(quantized_vector, tuple): - raise ValueError("PQ反量化需要列表/元组格式: (codes, centroids)") - return _product_dequantize(quantized_vector[0], quantized_vector[1]) - - else: - raise ValueError(f"不支持的量化类型: {quant_type}") - - -def _scalar_quantize_int8(vector: np.ndarray) -> np.ndarray: - """ - Scalar quantization: float32 -> int8. - - Normalizes the vector to the [0, 255] range, then maps it into int8. - - Args: - vector: input vector - - Returns: - Quantized int8 vector - """ - # Compute min/max - min_val = np.min(vector) - max_val = np.max(vector) - - # Avoid division by zero - if max_val == min_val: - return np.zeros_like(vector, dtype=np.int8) - - # Normalize to [0, 255] - normalized = (vector - min_val) / (max_val - min_val) * 255 - - # Shift to [-128, 127] and convert to int8 - # np.round might return float, minus 128 then cast - quantized = np.round(normalized - 128.0).astype(np.int8) - - # Store the normalization parameters (needed for dequantization). - # In real storage these parameters must be persisted separately; - # for simplicity, a function-level dict stands in here. - if not hasattr(_scalar_quantize_int8, "_params"): - _scalar_quantize_int8._params = {} - - vector_id = id(vector) - _scalar_quantize_int8._params[vector_id] = (min_val, max_val) - - return quantized - - -def _scalar_dequantize_int8(quantized: np.ndarray) -> np.ndarray: - """ - Scalar dequantization: int8 -> float32. - - Args: - quantized: quantized int8 vector - - Returns: - Dequantized float32 vector - """ - # Recover the normalization parameters (if any were provided). - # In a real application, min_val and max_val must have been persisted. - if not hasattr(_scalar_dequantize_int8, "_params"): - # By default, assume the range is [-1, 1] - return (quantized.astype(np.float32) + 128.0) / 255.0 * 2.0 - 1.0 - - # Look up the parameters (demo logic only; real code should read them from storage) - # return (quantized.astype(np.float32) + 128.0) / 255.0 * (max - min) + min - return (quantized.astype(np.float32) + 128.0) / 255.0 - - -def quantize_matrix( - matrix: np.ndarray, - quant_type: QuantizationType = QuantizationType.INT8, -) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]: - """ - Quantize a matrix (batch-quantize vectors). - - Args: - matrix: input matrix (N x D, one vector per row) - quant_type: quantization type - - Returns: - The quantized matrix - """ - if quant_type == QuantizationType.FLOAT32: - return matrix.astype(np.float32) - - elif quant_type == QuantizationType.INT8: - # Normalize globally over the whole matrix - min_val = np.min(matrix) - max_val = np.max(matrix) - - if max_val == min_val: - return np.zeros_like(matrix, dtype=np.int8) - - # Normalize to [0, 255], then shift to [-128, 127]; casting [0, 255] - # values straight to int8 would overflow and wrap around. - normalized = (matrix - min_val) / (max_val - min_val) * 255 - quantized = np.round(normalized - 128.0).astype(np.int8) - - return quantized - - else: - raise ValueError(f"不支持的量化类型: {quant_type}")
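The -128 shift in quantize_matrix above matters: int8 only holds [-128, 127], so casting values in [0, 255] directly wraps around. A two-line demonstration:

import numpy as np

vals = np.array([0.0, 200.0, 255.0])
print(np.round(vals).astype(np.int8))  # wraps on typical platforms, e.g. [0, -56, -1]
assert np.round(vals - 128.0).astype(np.int8).tolist() == [-128, 72, 127]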
- - -def dequantize_matrix( - quantized_matrix: np.ndarray, - quant_type: QuantizationType = QuantizationType.INT8, - min_val: Optional[float] = None, - max_val: Optional[float] = None, -) -> np.ndarray: - """ - Dequantize a matrix. - - Args: - quantized_matrix: the quantized matrix - quant_type: quantization type - min_val: normalization minimum (required for int8 dequantization) - max_val: normalization maximum (required for int8 dequantization) - - Returns: - The dequantized matrix - """ - if quant_type == QuantizationType.FLOAT32: - return quantized_matrix.astype(np.float32) - - elif quant_type == QuantizationType.INT8: - # Dequantize with the provided normalization parameters - if min_val is None or max_val is None: - # Without stored parameters, assume the codes map roughly to [-1, 1] - return quantized_matrix.astype(np.float32) / 127.0 - else: - # Undo the -128 shift, then restore the original range - normalized = (quantized_matrix.astype(np.float32) + 128.0) / 255.0 - return normalized * (max_val - min_val) + min_val - - else: - raise ValueError(f"不支持的量化类型: {quant_type}") - - -def estimate_memory_reduction( - num_vectors: int, - dimension: int, - from_type: QuantizationType, - to_type: QuantizationType, -) -> Tuple[float, float, float]: - """ - Estimate the memory savings. - - Args: - num_vectors: number of vectors - dimension: vector dimension - from_type: source quantization type - to_type: target quantization type - - Returns: - (original size in MB, quantized size in MB, reduction ratio) - """ - # Bytes per element for each quantization type - bytes_per_element = { - QuantizationType.FLOAT32: 4, - QuantizationType.INT8: 1, - QuantizationType.PQ: 0.25, # assumes ~4x compression - } - - original_bytes = num_vectors * dimension * bytes_per_element[from_type] - quantized_bytes = num_vectors * dimension * bytes_per_element[to_type] - - original_mb = original_bytes / 1024 / 1024 - quantized_mb = quantized_bytes / 1024 / 1024 - reduction_ratio = (original_bytes - quantized_bytes) / original_bytes - - return original_mb, quantized_mb, reduction_ratio - - -def estimate_compression_stats( - num_vectors: int, - dimension: int, - quant_type: QuantizationType, -) -> dict: - """ - Estimate compression statistics. - - Args: - num_vectors: number of vectors - dimension: vector dimension - quant_type: quantization type - - Returns: - A dict of statistics - """ - original_mb, quantized_mb, ratio = estimate_memory_reduction( - num_vectors, dimension, QuantizationType.FLOAT32, quant_type - ) - - return { - "num_vectors": num_vectors, - "dimension": dimension, - "quantization_type": quant_type.value, - "original_size_mb": round(original_mb, 2), - "quantized_size_mb": round(quantized_mb, 2), - "saved_mb": round(original_mb - quantized_mb, 2), - "compression_ratio": round(ratio * 100, 2), - } - - -def _product_quantize( - vector: np.ndarray, m: int = 8, k: int = 256 -) -> Tuple[np.ndarray, np.ndarray]: - """ - Simplified product quantization (PQ). - - Args: - vector: input vector (D,) - m: number of subspaces - k: number of centroids per subspace - - Returns: - (codes, centroids) - """ - d = vector.shape[0] - if d % m != 0: - raise ValueError(f"维度 {d} 必须能被子空间数量 {m} 整除") - - ds = d // m # subspace dimension - codes = np.zeros(m, dtype=np.uint8) - centroids = np.zeros((m, k, ds), dtype=np.float32) - - # Simplified PQ: no k-means training is performed here. Quantization points - # are predefined per subspace (real deployments should train codebooks offline). - # For demonstration, each subspace range is simply split into k levels. - for i in range(m): - sub_vec = vector[i * ds : (i + 1) * ds] - # Simplification: assume each subspace's value range and partition it. - # Real PQ should use centroids produced by k-means; here a linspace - # codebook is built and the nearest entry is picked. - sub_min, sub_max = np.min(sub_vec), np.max(sub_vec) - if sub_max == sub_min: - linspace = np.zeros(k) - else: - linspace = np.linspace(sub_min, sub_max, k) - - for j in range(k): - centroids[i, j, :] = linspace[j] - - # Encoding: simplified to matching the subspace mean against the nearest centroid - sub_mean = np.mean(sub_vec) - code = np.argmin(np.abs(linspace - sub_mean)) - codes[i] = code - - return codes, centroids - - -def _product_dequantize(codes: np.ndarray, centroids: np.ndarray) -> np.ndarray: - """ - PQ dequantization. - - Args: - codes: code vector (M,) - centroids: centroids (M, K, DS) - - Returns: - Reconstructed vector (D,) - """ - m, k, ds = centroids.shape - vector = np.zeros(m * ds, dtype=np.float32) - - for i in range(m): - code = codes[i] - vector[i * ds : (i + 1) * ds] = centroids[i, code, :] - - return vector
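To make the simplified PQ scheme above concrete, a tiny round trip in the same spirit: each subspace stores one code into a k-level linspace codebook, so reconstruction collapses every subspace to a single level (real PQ would train per-subspace k-means codebooks offline). This is a demonstration sketch, not the module's API:

import numpy as np

def pq_roundtrip(vector: np.ndarray, m: int = 4, k: int = 16) -> np.ndarray:
    # Encode each subspace as the codebook level nearest to its mean, then decode.
    d = vector.shape[0]
    assert d % m == 0, "dimension must be divisible by the number of subspaces"
    ds = d // m
    out = np.zeros(d, dtype=np.float32)
    for i in range(m):
        sub = vector[i * ds:(i + 1) * ds]
        lo, hi = float(sub.min()), float(sub.max())
        levels = np.linspace(lo, hi, k) if hi > lo else np.zeros(k)
        code = int(np.argmin(np.abs(levels - float(sub.mean()))))
        out[i * ds:(i + 1) * ds] = levels[code]
    return out

v = np.arange(16, dtype=np.float32)
print(pq_roundtrip(v))  # each 4-wide subspace is replaced by one stored level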
diff --git a/plugins/A_memorix/core/utils/relation_query.py b/plugins/A_memorix/core/utils/relation_query.py deleted file mode 100644 index ffde9cac..00000000 --- a/plugins/A_memorix/core/utils/relation_query.py +++ /dev/null @@ -1,121 +0,0 @@ -"""Relation query spec parsing utilities.""" - -from __future__ import annotations - -import re -from dataclasses import dataclass -from typing import Optional - - -@dataclass -class RelationQuerySpec: - raw: str - is_structured: bool - subject: Optional[str] - predicate: Optional[str] - object: Optional[str] - error: Optional[str] = None - - -_NATURAL_LANGUAGE_PATTERN = re.compile( - r"(^\s*(what|who|which|how|why|when|where)\b|" - r"\?|?|" - r"\b(relation|related|between)\b|" - r"(什么关系|有哪些关系|之间|关联))", - re.IGNORECASE, -) - - -def _looks_like_natural_language(raw: str) -> bool: - text = str(raw or "").strip() - if not text: - return False - return _NATURAL_LANGUAGE_PATTERN.search(text) is not None - - -def parse_relation_query_spec(relation_spec: str) -> RelationQuerySpec: - raw = str(relation_spec or "").strip() - if not raw: - return RelationQuerySpec( - raw=raw, - is_structured=False, - subject=None, - predicate=None, - object=None, - error="empty", - ) - - if "|" in raw: - parts = [p.strip() for p in raw.split("|")] - if len(parts) < 2: - return RelationQuerySpec( - raw=raw, - is_structured=True, - subject=None, - predicate=None, - object=None, - error="invalid_pipe_format", - ) - return RelationQuerySpec( - raw=raw, - is_structured=True, - subject=parts[0] or None, - predicate=parts[1] or None, - object=parts[2] if len(parts) > 2 and parts[2] else None, - ) - - if "->" in raw: - parts = [p.strip() for p in raw.split("->") if p.strip()] - if len(parts) >= 3: - return RelationQuerySpec( - raw=raw, - is_structured=True, - subject=parts[0], - predicate=parts[1], - object=parts[2], - ) - if len(parts) == 2: - return RelationQuerySpec( - raw=raw, - is_structured=True, - subject=parts[0], - predicate=None, - object=parts[1], - ) - return RelationQuerySpec( - raw=raw, - is_structured=True, - subject=None, - predicate=None, - object=None, - error="invalid_arrow_format", - ) - - if _looks_like_natural_language(raw): - return RelationQuerySpec( - raw=raw, - is_structured=False, - subject=None, - predicate=None, - object=None, - ) - - # Only keep low-ambiguity compact triples as a compatibility syntax, e.g. "Alice likes Apple". - # Two-token forms are too ambiguous and are no longer treated as structured relation queries. - parts = raw.split() - if len(parts) == 3: - return RelationQuerySpec( - raw=raw, - is_structured=True, - subject=parts[0], - predicate=parts[1], - object=parts[2], - ) - - return RelationQuerySpec( - raw=raw, - is_structured=False, - subject=None, - predicate=None, - object=None, - )
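parse_relation_query_spec accepts three structured syntaxes (pipe, arrow, compact three-token) before falling back to natural-language handling. A standalone sketch of just the pipe branch and its expected shape, with hypothetical inputs:

from typing import Optional, Tuple

def parse_pipe_triple(raw: str) -> Optional[Tuple[Optional[str], Optional[str], Optional[str]]]:
    # "subject|predicate[|object]"; empty segments become None, like the branch above.
    if "|" not in raw:
        return None
    parts = [p.strip() for p in raw.split("|")]
    if len(parts) < 2:
        return None
    return (
        parts[0] or None,
        parts[1] or None,
        parts[2] if len(parts) > 2 and parts[2] else None,
    )

assert parse_pipe_triple("Alice|likes|Bob") == ("Alice", "likes", "Bob")
assert parse_pipe_triple("Alice|likes") == ("Alice", "likes", None)
assert parse_pipe_triple("what is the relation?") is None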
diff --git a/plugins/A_memorix/core/utils/relation_write_service.py b/plugins/A_memorix/core/utils/relation_write_service.py deleted file mode 100644 index 6fa2e621..00000000 --- a/plugins/A_memorix/core/utils/relation_write_service.py +++ /dev/null @@ -1,166 +0,0 @@ -""" -Unified relation write and relation vectorization service. - -Rules: -1. Metadata is the primary data source; the vector index is derived from it. -2. Write the relation to metadata first, then write the vector. -3. A vector failure does not roll back metadata; repair relies on the state machine and backfill jobs. -""" - -from __future__ import annotations - -from dataclasses import dataclass -from typing import Any, Dict, Optional - -from src.common.logger import get_logger - - -logger = get_logger("A_Memorix.RelationWriteService") - - -@dataclass -class RelationWriteResult: - hash_value: str - vector_written: bool - vector_already_exists: bool - vector_state: str - - -class RelationWriteService: - """Single entry point for relation writes.""" - - ERROR_MAX_LEN = 500 - - def __init__( - self, - metadata_store: Any, - graph_store: Any, - vector_store: Any, - embedding_manager: Any, - ): - self.metadata_store = metadata_store - self.graph_store = graph_store - self.vector_store = vector_store - self.embedding_manager = embedding_manager - - @staticmethod - def build_relation_vector_text(subject: str, predicate: str, obj: str) -> str: - s = str(subject or "").strip() - p = str(predicate or "").strip() - o = str(obj or "").strip() - # Dual phrasing: serves both keyword retrieval and natural-language questions - return f"{s} {p} {o}\n{s}和{o}的关系是{p}" - - async def ensure_relation_vector( - self, - hash_value: str, - subject: str, - predicate: str, - obj: str, - *, - max_error_len: int = ERROR_MAX_LEN, - ) -> RelationWriteResult: - """ - Ensure a vector exists for an existing relation and update its state. - """ - if hash_value in self.vector_store: - self.metadata_store.set_relation_vector_state(hash_value, "ready") - return RelationWriteResult( - hash_value=hash_value, - vector_written=False, - vector_already_exists=True, - vector_state="ready", - ) - - self.metadata_store.set_relation_vector_state(hash_value, "pending") - try: - vector_text = self.build_relation_vector_text(subject, predicate, obj) - embedding = await self.embedding_manager.encode(vector_text) - self.vector_store.add( - vectors=embedding.reshape(1, -1), - ids=[hash_value], - ) - self.metadata_store.set_relation_vector_state(hash_value, "ready") - logger.info( - "metric.relation_vector_write_success=1 " - "metric.relation_vector_write_success_count=1 " - f"hash={hash_value[:16]}" - ) - return RelationWriteResult( - hash_value=hash_value, - vector_written=True, - vector_already_exists=False, - vector_state="ready", - ) - except ValueError: - # Vector already exists (conflict); treat it as success - self.metadata_store.set_relation_vector_state(hash_value, "ready") - return RelationWriteResult( - hash_value=hash_value, - vector_written=False, - vector_already_exists=True, - vector_state="ready", - ) - except Exception as e: - err = str(e)[:max_error_len] - self.metadata_store.set_relation_vector_state( - hash_value, - "failed", - error=err, - bump_retry=True, - ) - logger.warning( - "metric.relation_vector_write_fail=1 " - "metric.relation_vector_write_fail_count=1 " - f"hash={hash_value[:16]} " - f"err={err}" - ) - return RelationWriteResult( - hash_value=hash_value, - vector_written=False, - vector_already_exists=False, - vector_state="failed", - ) - - async def upsert_relation_with_vector( - self, - subject: str, - predicate: str, - obj: str, - confidence: float = 1.0, - source_paragraph: str = "", - metadata: Optional[Dict[str, Any]] = None, - *, - write_vector: bool = True, - ) -> RelationWriteResult: - """ - Unified relation write: - 1) write the metadata relation - 2) write the graph edge with relation_hash - 3) write the relation vector when requested - """ - rel_hash = self.metadata_store.add_relation( - subject=subject, - predicate=predicate, - obj=obj, - confidence=confidence, - source_paragraph=source_paragraph, - metadata=metadata or {}, - ) - self.graph_store.add_edges([(subject, obj)], relation_hashes=[rel_hash]) - - if not write_vector: - self.metadata_store.set_relation_vector_state(rel_hash, "none") - return
RelationWriteResult( - hash_value=rel_hash, - vector_written=False, - vector_already_exists=False, - vector_state="none", - ) - - return await self.ensure_relation_vector( - hash_value=rel_hash, - subject=subject, - predicate=predicate, - obj=obj, - ) diff --git a/plugins/A_memorix/core/utils/retrieval_tuning_manager.py b/plugins/A_memorix/core/utils/retrieval_tuning_manager.py deleted file mode 100644 index e0e8ecd6..00000000 --- a/plugins/A_memorix/core/utils/retrieval_tuning_manager.py +++ /dev/null @@ -1,1857 +0,0 @@ -""" -Retrieval tuning manager for WebUI. -""" - -from __future__ import annotations - -import asyncio -import copy -import json -import random -import re -import time -import uuid -from collections import Counter, deque -from dataclasses import dataclass, field -from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Tuple - -from src.common.logger import get_logger - -from ..runtime.search_runtime_initializer import build_search_runtime -from .search_execution_service import SearchExecutionRequest, SearchExecutionService - -try: - from src.services import llm_service as llm_api -except Exception: # pragma: no cover - llm_api = None - -logger = get_logger("A_Memorix.RetrievalTuningManager") - - -OBJECTIVES = {"precision_priority", "balanced", "recall_priority"} -INTENSITIES = {"quick": 8, "standard": 20, "deep": 32} -CATEGORIES = {"query_nl", "query_kw", "spo_relation", "spo_search"} -_RUNTIME_CONFIG_INSTANCE_KEYS = { - "vector_store", - "graph_store", - "metadata_store", - "embedding_manager", - "sparse_index", - "relation_write_service", - "plugin_instance", -} - - -def _now() -> float: - return time.time() - - -def _clamp_int(value: Any, default: int, min_value: int, max_value: int) -> int: - try: - parsed = int(value) - except Exception: - parsed = int(default) - return max(min_value, min(max_value, parsed)) - - -def _clamp_float(value: Any, default: float, min_value: float, max_value: float) -> float: - try: - parsed = float(value) - except Exception: - parsed = float(default) - return max(min_value, min(max_value, parsed)) - - -def _coerce_bool(value: Any, default: bool) -> bool: - if value is None: - return default - if isinstance(value, bool): - return value - if isinstance(value, (int, float)): - return bool(value) - text = str(value).strip().lower() - if text in {"1", "true", "yes", "y", "on"}: - return True - if text in {"0", "false", "no", "n", "off"}: - return False - return default - - -def _nested_get(data: Dict[str, Any], key: str, default: Any = None) -> Any: - cur: Any = data - for part in key.split("."): - if isinstance(cur, dict) and part in cur: - cur = cur[part] - else: - return default - return cur - - -def _nested_set(data: Dict[str, Any], key: str, value: Any) -> None: - parts = key.split(".") - cur = data - for part in parts[:-1]: - if part not in cur or not isinstance(cur[part], dict): - cur[part] = {} - cur = cur[part] - cur[parts[-1]] = value - - -def _deep_merge(base: Dict[str, Any], patch: Dict[str, Any]) -> Dict[str, Any]: - out = copy.deepcopy(base) - for key, value in (patch or {}).items(): - if isinstance(value, dict) and isinstance(out.get(key), dict): - out[key] = _deep_merge(out[key], value) - else: - out[key] = copy.deepcopy(value) - return out - - -def _safe_json_loads(text: str) -> Optional[Any]: - raw = str(text or "").strip() - if not raw: - return None - if "```" in raw: - raw = raw.replace("```json", "```") - for seg in raw.split("```"): - seg = seg.strip() - if seg.startswith("{") or 
seg.startswith("["): - raw = seg - break - try: - return json.loads(raw) - except Exception: - pass - s = raw.find("{") - e = raw.rfind("}") - if s >= 0 and e > s: - try: - return json.loads(raw[s : e + 1]) - except Exception: - return None - return None - - -@dataclass -class RetrievalQueryCase: - case_id: str - category: str - query: str - expected_hashes: List[str] = field(default_factory=list) - expected_spo: Dict[str, str] = field(default_factory=dict) - metadata: Dict[str, Any] = field(default_factory=dict) - - def to_dict(self) -> Dict[str, Any]: - return { - "case_id": self.case_id, - "category": self.category, - "query": self.query, - "expected_hashes": list(self.expected_hashes), - "expected_spo": dict(self.expected_spo), - "metadata": dict(self.metadata), - } - - -@dataclass -class RetrievalTuningRoundRecord: - round_index: int - candidate_profile: Dict[str, Any] - metrics: Dict[str, Any] - score: float - latency_ms: float - failure_summary: Dict[str, Any] = field(default_factory=dict) - created_at: float = field(default_factory=_now) - - def to_dict(self) -> Dict[str, Any]: - return { - "round_index": self.round_index, - "candidate_profile": copy.deepcopy(self.candidate_profile), - "metrics": copy.deepcopy(self.metrics), - "score": float(self.score), - "latency_ms": float(self.latency_ms), - "failure_summary": copy.deepcopy(self.failure_summary), - "created_at": float(self.created_at), - } - - -@dataclass -class RetrievalTuningTaskRecord: - task_id: str - status: str - progress: float - objective: str - intensity: str - rounds_total: int - rounds_done: int = 0 - best_profile: Dict[str, Any] = field(default_factory=dict) - best_metrics: Dict[str, Any] = field(default_factory=dict) - best_score: float = -1.0 - baseline_profile: Dict[str, Any] = field(default_factory=dict) - baseline_metrics: Dict[str, Any] = field(default_factory=dict) - error: str = "" - params: Dict[str, Any] = field(default_factory=dict) - query_set_stats: Dict[str, Any] = field(default_factory=dict) - artifact_paths: Dict[str, str] = field(default_factory=dict) - rounds: List[RetrievalTuningRoundRecord] = field(default_factory=list) - cancel_requested: bool = False - created_at: float = field(default_factory=_now) - started_at: Optional[float] = None - finished_at: Optional[float] = None - updated_at: float = field(default_factory=_now) - apply_log: List[Dict[str, Any]] = field(default_factory=list) - - def to_summary(self) -> Dict[str, Any]: - return { - "task_id": self.task_id, - "status": self.status, - "progress": self.progress, - "objective": self.objective, - "intensity": self.intensity, - "rounds_total": self.rounds_total, - "rounds_done": self.rounds_done, - "best_score": self.best_score, - "error": self.error, - "query_set_stats": dict(self.query_set_stats), - "artifact_paths": dict(self.artifact_paths), - "created_at": self.created_at, - "started_at": self.started_at, - "finished_at": self.finished_at, - "updated_at": self.updated_at, - } - - def to_detail(self, include_rounds: bool = False) -> Dict[str, Any]: - payload = self.to_summary() - payload.update( - { - "params": copy.deepcopy(self.params), - "best_profile": copy.deepcopy(self.best_profile), - "best_metrics": copy.deepcopy(self.best_metrics), - "baseline_profile": copy.deepcopy(self.baseline_profile), - "baseline_metrics": copy.deepcopy(self.baseline_metrics), - "apply_log": copy.deepcopy(self.apply_log), - } - ) - if include_rounds: - payload["rounds"] = [x.to_dict() for x in self.rounds] - return payload - - -class 
RetrievalTuningManager: - def __init__( - self, - plugin: Any, - *, - import_write_blocked_provider: Optional[Callable[[], bool]] = None, - ): - self.plugin = plugin - self._import_write_blocked_provider = import_write_blocked_provider - - self._lock = asyncio.Lock() - self._tasks: Dict[str, RetrievalTuningTaskRecord] = {} - self._task_order: deque[str] = deque() - self._queue: deque[str] = deque() - self._active_task_id: Optional[str] = None - self._worker_task: Optional[asyncio.Task] = None - self._stopping = False - - self._rollback_snapshot: Optional[Dict[str, Any]] = None - - self._artifacts_root = Path(__file__).resolve().parents[2] / "artifacts" / "retrieval_tuning" - self._artifacts_root.mkdir(parents=True, exist_ok=True) - - def _cfg(self, key: str, default: Any = None) -> Any: - getter = getattr(self.plugin, "get_config", None) - if callable(getter): - return getter(key, default) - return default - - def _is_enabled(self) -> bool: - return bool(self._cfg("web.tuning.enabled", True)) - - def _queue_limit(self) -> int: - return _clamp_int(self._cfg("web.tuning.max_queue_size", 8), 8, 1, 100) - - def _poll_interval_s(self) -> float: - ms = _clamp_int(self._cfg("web.tuning.poll_interval_ms", 1200), 1200, 200, 60000) - return max(0.2, ms / 1000.0) - - def _llm_retry_cfg(self) -> Dict[str, Any]: - return { - "max_attempts": _clamp_int(self._cfg("web.tuning.llm_retry.max_attempts", 3), 3, 1, 10), - "min_wait_seconds": _clamp_float(self._cfg("web.tuning.llm_retry.min_wait_seconds", 2), 2.0, 0.1, 60.0), - "max_wait_seconds": _clamp_float(self._cfg("web.tuning.llm_retry.max_wait_seconds", 20), 20.0, 0.2, 120.0), - "backoff_multiplier": _clamp_float(self._cfg("web.tuning.llm_retry.backoff_multiplier", 2), 2.0, 1.0, 10.0), - } - - def _eval_query_timeout_s(self) -> float: - return _clamp_float( - self._cfg("web.tuning.eval_query_timeout_seconds", 10.0), - 10.0, - 0.01, - 120.0, - ) - - def get_runtime_settings(self) -> Dict[str, Any]: - intensity = str(self._cfg("web.tuning.default_intensity", "standard") or "standard") - if intensity not in INTENSITIES: - intensity = "standard" - objective = str(self._cfg("web.tuning.default_objective", "precision_priority") or "precision_priority") - if objective not in OBJECTIVES: - objective = "precision_priority" - return { - "enabled": self._is_enabled(), - "poll_interval_ms": _clamp_int(self._cfg("web.tuning.poll_interval_ms", 1200), 1200, 200, 60000), - "max_queue_size": self._queue_limit(), - "default_objective": objective, - "default_intensity": intensity, - "default_rounds": INTENSITIES[intensity], - "default_top_k_eval": _clamp_int(self._cfg("web.tuning.default_top_k_eval", 20), 20, 5, 100), - "default_sample_size": _clamp_int(self._cfg("web.tuning.default_sample_size", 24), 24, 4, 200), - "eval_query_timeout_seconds": self._eval_query_timeout_s(), - "llm_retry": self._llm_retry_cfg(), - } - - def _ensure_ready(self) -> None: - required = ("metadata_store", "vector_store", "graph_store", "embedding_manager") - missing = [x for x in required if getattr(self.plugin, x, None) is None] - if missing: - raise ValueError(f"调优依赖未初始化: {', '.join(missing)}") - checker = getattr(self.plugin, "is_runtime_ready", None) - if callable(checker) and not checker(): - raise ValueError("插件运行时未就绪") - provider = self._import_write_blocked_provider - if provider is not None and bool(provider()): - raise ValueError("导入任务运行中,当前禁止启动检索调优") - - def get_profile_snapshot(self) -> Dict[str, Any]: - cfg = getattr(self.plugin, "config", {}) or {} - profile = { - "retrieval": { - 
"top_k_paragraphs": _nested_get(cfg, "retrieval.top_k_paragraphs", 20), - "top_k_relations": _nested_get(cfg, "retrieval.top_k_relations", 10), - "top_k_final": _nested_get(cfg, "retrieval.top_k_final", 10), - "alpha": _nested_get(cfg, "retrieval.alpha", 0.5), - "enable_ppr": _nested_get(cfg, "retrieval.enable_ppr", True), - "search": {"smart_fallback": {"enabled": _nested_get(cfg, "retrieval.search.smart_fallback.enabled", True)}}, - "sparse": { - "enabled": _nested_get(cfg, "retrieval.sparse.enabled", True), - "mode": _nested_get(cfg, "retrieval.sparse.mode", "auto"), - "candidate_k": _nested_get(cfg, "retrieval.sparse.candidate_k", 80), - "relation_candidate_k": _nested_get(cfg, "retrieval.sparse.relation_candidate_k", 60), - }, - "fusion": { - "method": _nested_get(cfg, "retrieval.fusion.method", "weighted_rrf"), - "rrf_k": _nested_get(cfg, "retrieval.fusion.rrf_k", 60), - "vector_weight": _nested_get(cfg, "retrieval.fusion.vector_weight", 0.7), - "bm25_weight": _nested_get(cfg, "retrieval.fusion.bm25_weight", 0.3), - }, - }, - "threshold": { - "percentile": _nested_get(cfg, "threshold.percentile", 75.0), - "min_results": _nested_get(cfg, "threshold.min_results", 3), - }, - } - return self._normalize_profile(profile, fallback=profile) - - def _normalize_profile(self, profile: Optional[Dict[str, Any]], *, fallback: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: - raw = copy.deepcopy(profile or {}) - base = copy.deepcopy(fallback or self.get_profile_snapshot()) - - def pick(path: str, default: Any) -> Any: - if _nested_get(raw, path, None) is not None: - return _nested_get(raw, path, default) - if path in raw: - return raw.get(path, default) - return _nested_get(base, path, default) - - fusion_method = str(pick("retrieval.fusion.method", "weighted_rrf") or "weighted_rrf").strip().lower() - if fusion_method not in {"weighted_rrf", "alpha_legacy"}: - fusion_method = "weighted_rrf" - - sparse_mode = str(pick("retrieval.sparse.mode", "auto") or "auto").strip().lower() - if sparse_mode not in {"auto", "hybrid", "fallback_only"}: - sparse_mode = "auto" - - vec_w = _clamp_float(pick("retrieval.fusion.vector_weight", 0.7), 0.7, 0.0, 1.0) - bm_w = _clamp_float(pick("retrieval.fusion.bm25_weight", 0.3), 0.3, 0.0, 1.0) - s = vec_w + bm_w - if s <= 1e-9: - vec_w, bm_w = 0.7, 0.3 - else: - vec_w, bm_w = vec_w / s, bm_w / s - - return { - "retrieval": { - "top_k_paragraphs": _clamp_int(pick("retrieval.top_k_paragraphs", 20), 20, 10, 1200), - "top_k_relations": _clamp_int(pick("retrieval.top_k_relations", 10), 10, 4, 512), - "top_k_final": _clamp_int(pick("retrieval.top_k_final", 10), 10, 4, 512), - "alpha": _clamp_float(pick("retrieval.alpha", 0.5), 0.5, 0.0, 1.0), - "enable_ppr": _coerce_bool(pick("retrieval.enable_ppr", True), True), - "search": {"smart_fallback": {"enabled": _coerce_bool(pick("retrieval.search.smart_fallback.enabled", True), True)}}, - "sparse": { - "enabled": _coerce_bool(pick("retrieval.sparse.enabled", True), True), - "mode": sparse_mode, - "candidate_k": _clamp_int(pick("retrieval.sparse.candidate_k", 80), 80, 20, 2000), - "relation_candidate_k": _clamp_int(pick("retrieval.sparse.relation_candidate_k", 60), 60, 20, 2000), - }, - "fusion": { - "method": fusion_method, - "rrf_k": _clamp_int(pick("retrieval.fusion.rrf_k", 60), 60, 1, 500), - "vector_weight": float(vec_w), - "bm25_weight": float(bm_w), - }, - }, - "threshold": { - "percentile": _clamp_float(pick("threshold.percentile", 75.0), 75.0, 1.0, 99.0), - "min_results": _clamp_int(pick("threshold.min_results", 3), 3, 1, 
100), - }, - } - - def _apply_profile_to_runtime(self, normalized: Dict[str, Any]) -> None: - if not isinstance(getattr(self.plugin, "config", None), dict): - raise RuntimeError("插件 config 不可写") - for key, value in normalized.items(): - _nested_set(self.plugin.config, key, value) - plugin_cfg = getattr(self.plugin, "_plugin_config", None) - if isinstance(plugin_cfg, dict): - for key, value in normalized.items(): - _nested_set(plugin_cfg, key, value) - - async def apply_profile(self, profile: Dict[str, Any], *, reason: str = "manual") -> Dict[str, Any]: - normalized = self._normalize_profile(profile) - current = self.get_profile_snapshot() - self._rollback_snapshot = current - self._apply_profile_to_runtime(normalized) - return { - "applied": normalized, - "rollback_snapshot": current, - "reason": reason, - "applied_at": _now(), - } - - async def rollback_profile(self) -> Dict[str, Any]: - if not self._rollback_snapshot: - raise ValueError("暂无可回滚的参数快照") - target = self._normalize_profile(self._rollback_snapshot, fallback=self._rollback_snapshot) - self._apply_profile_to_runtime(target) - return {"rolled_back_to": target, "rolled_back_at": _now()} - - def export_toml_snippet(self, profile: Optional[Dict[str, Any]] = None) -> str: - p = self._normalize_profile(profile or self.get_profile_snapshot()) - r = p["retrieval"] - t = p["threshold"] - lines = [ - "[retrieval]", - f"top_k_paragraphs = {int(r['top_k_paragraphs'])}", - f"top_k_relations = {int(r['top_k_relations'])}", - f"top_k_final = {int(r['top_k_final'])}", - f"alpha = {float(r['alpha']):.4f}", - f"enable_ppr = {str(bool(r['enable_ppr'])).lower()}", - "", - "[retrieval.search.smart_fallback]", - f"enabled = {str(bool(r['search']['smart_fallback']['enabled'])).lower()}", - "", - "[retrieval.sparse]", - f"enabled = {str(bool(r['sparse']['enabled'])).lower()}", - f"mode = \"{r['sparse']['mode']}\"", - f"candidate_k = {int(r['sparse']['candidate_k'])}", - f"relation_candidate_k = {int(r['sparse']['relation_candidate_k'])}", - "", - "[retrieval.fusion]", - f"method = \"{r['fusion']['method']}\"", - f"rrf_k = {int(r['fusion']['rrf_k'])}", - f"vector_weight = {float(r['fusion']['vector_weight']):.4f}", - f"bm25_weight = {float(r['fusion']['bm25_weight']):.4f}", - "", - "[threshold]", - f"percentile = {float(t['percentile']):.4f}", - f"min_results = {int(t['min_results'])}", - ] - return "\n".join(lines).strip() + "\n" - - def _pending_task_count(self) -> int: - return sum(1 for t in self._tasks.values() if t.status in {"queued", "running", "cancel_requested"}) - - def _sample_triples_for_query_set( - self, - *, - triples: List[Tuple[Any, Any, Any, Any]], - sample_size: int, - seed: int, - ) -> Tuple[List[Tuple[str, str, str, str]], Dict[str, Any]]: - normalized: List[Tuple[str, str, str, str]] = [] - for row in triples: - try: - subject, predicate, obj, rel_hash = row - except Exception: - continue - relation_hash = str(rel_hash or "").strip() - if not relation_hash: - continue - normalized.append((str(subject or ""), str(predicate or ""), str(obj or ""), relation_hash)) - - if not normalized: - return [], {"error": "no_relations"} - - target = min(max(4, int(sample_size)), len(normalized)) - predicate_counter = Counter([str(x[1] or "").strip() or "__empty__" for x in normalized]) - entity_counter = Counter() - for subj, _, obj, _ in normalized: - entity_counter.update([str(subj or "").strip().lower() or "__empty__"]) - entity_counter.update([str(obj or "").strip().lower() or "__empty__"]) - - if target >= len(normalized): - return 
list(normalized), { - "strategy": "all", - "sample_size": int(target), - "total_triples": int(len(normalized)), - "predicate_total": int(len(predicate_counter)), - "predicate_sampled": int(len(predicate_counter)), - } - - rng = random.Random(f"{seed}:triple_sample") - by_predicate: Dict[str, List[int]] = {} - for idx, (_, predicate, _, _) in enumerate(normalized): - key = str(predicate or "").strip() or "__empty__" - by_predicate.setdefault(key, []).append(idx) - for pool in by_predicate.values(): - rng.shuffle(pool) - - predicate_order = sorted(by_predicate.keys()) - rng.shuffle(predicate_order) - - selected: List[int] = [] - selected_set = set() - - # First pass: predicate round-robin to avoid head predicate dominating query set. - while len(selected) < target: - progressed = False - for key in predicate_order: - pool = by_predicate.get(key, []) - if not pool: - continue - idx = int(pool.pop()) - if idx in selected_set: - continue - selected.append(idx) - selected_set.add(idx) - progressed = True - if len(selected) >= target: - break - if not progressed: - break - - if len(selected) < target: - remain = [idx for idx in range(len(normalized)) if idx not in selected_set] - rng.shuffle(remain) - - # Second pass: prefer lower-frequency entities and predicates for better diversity. - def _remain_score(idx: int) -> Tuple[int, int]: - subj, predicate, obj, _ = normalized[idx] - subject_freq = int(entity_counter.get(str(subj or "").strip().lower() or "__empty__", 0)) - object_freq = int(entity_counter.get(str(obj or "").strip().lower() or "__empty__", 0)) - pred_freq = int(predicate_counter.get(str(predicate or "").strip() or "__empty__", 0)) - return (subject_freq + object_freq, pred_freq) - - remain = sorted(remain, key=_remain_score) - need = target - len(selected) - for idx in remain[:need]: - selected.append(idx) - selected_set.add(idx) - - selected = selected[:target] - sampled = [normalized[idx] for idx in selected] - sampled_predicates = {str(x[1] or "").strip() or "__empty__" for x in sampled} - - return sampled, { - "strategy": "predicate_round_robin_entity_diversity", - "sample_size": int(target), - "total_triples": int(len(normalized)), - "predicate_total": int(len(predicate_counter)), - "predicate_sampled": int(len(sampled_predicates)), - } - - def _select_round_eval_cases( - self, - *, - cases: List[RetrievalQueryCase], - intensity: str, - round_index: int, - seed: int, - ) -> List[RetrievalQueryCase]: - if not cases: - return [] - mode = str(intensity or "standard").strip().lower() - if mode not in INTENSITIES: - mode = "standard" - if mode == "deep": - return list(cases) - - if mode == "quick": - ratio = 0.45 - min_total = 16 - else: - ratio = 0.70 - min_total = 24 - - total = len(cases) - target = max(min_total, int(total * ratio)) - if target >= total: - return list(cases) - - rng = random.Random(f"{seed}:{round_index}:subset") - by_cat: Dict[str, List[RetrievalQueryCase]] = {} - for item in cases: - by_cat.setdefault(str(item.category), []).append(item) - - selected: List[RetrievalQueryCase] = [] - selected_ids = set() - cat_names = sorted([x for x in by_cat.keys() if x in CATEGORIES]) - if not cat_names: - cat_names = sorted(by_cat.keys()) - per_cat = max(1, target // max(1, len(cat_names))) - - for cat in cat_names: - pool = by_cat.get(cat, []) - if not pool: - continue - picked = list(pool) if len(pool) <= per_cat else rng.sample(pool, per_cat) - for item in picked: - if item.case_id in selected_ids: - continue - selected.append(item) - selected_ids.add(item.case_id) - - if 
len(selected) < target: - remain = [x for x in cases if x.case_id not in selected_ids] - if len(remain) > (target - len(selected)): - remain = rng.sample(remain, target - len(selected)) - for item in remain: - selected.append(item) - selected_ids.add(item.case_id) - - return selected[:target] - - async def _ensure_worker(self) -> None: - async with self._lock: - if self._worker_task and not self._worker_task.done(): - return - self._stopping = False - self._worker_task = asyncio.create_task(self._worker_loop()) - - async def shutdown(self) -> None: - self._stopping = True - worker = self._worker_task - if worker is None or worker.done(): - return - worker.cancel() - try: - await worker - except asyncio.CancelledError: - pass - except Exception as e: - logger.warning(f"Retrieval tuning worker shutdown failed: {e}") - - async def create_task(self, payload: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: - if not self._is_enabled(): - raise ValueError("检索调优中心已禁用") - self._ensure_ready() - - data = payload or {} - objective = str(data.get("objective") or self._cfg("web.tuning.default_objective", "precision_priority")) - if objective not in OBJECTIVES: - raise ValueError(f"objective 非法: {objective}") - - intensity = str(data.get("intensity") or self._cfg("web.tuning.default_intensity", "standard")) - if intensity not in INTENSITIES: - raise ValueError(f"intensity 非法: {intensity}") - - rounds_total = _clamp_int(data.get("rounds", INTENSITIES[intensity]), INTENSITIES[intensity], 1, 200) - sample_size = _clamp_int(data.get("sample_size", self._cfg("web.tuning.default_sample_size", 24)), 24, 4, 500) - top_k_eval = _clamp_int(data.get("top_k_eval", self._cfg("web.tuning.default_top_k_eval", 20)), 20, 5, 100) - eval_query_timeout_seconds = _clamp_float( - data.get("eval_query_timeout_seconds", self._eval_query_timeout_s()), - self._eval_query_timeout_s(), - 0.01, - 120.0, - ) - llm_enabled = _coerce_bool(data.get("llm_enabled", True), True) - seed = data.get("seed") - try: - seed = int(seed) - except Exception: - seed = int(time.time()) % 1000003 - - async with self._lock: - if self._pending_task_count() >= self._queue_limit(): - raise ValueError("调优任务队列已满,请稍后重试") - task = RetrievalTuningTaskRecord( - task_id=uuid.uuid4().hex, - status="queued", - progress=0.0, - objective=objective, - intensity=intensity, - rounds_total=rounds_total, - params={ - "sample_size": sample_size, - "top_k_eval": top_k_eval, - "eval_query_timeout_seconds": float(eval_query_timeout_seconds), - "llm_enabled": llm_enabled, - "seed": seed, - }, - ) - self._tasks[task.task_id] = task - self._task_order.appendleft(task.task_id) - self._queue.append(task.task_id) - task.updated_at = _now() - - await self._ensure_worker() - return task.to_summary() - - async def list_tasks(self, limit: int = 50) -> List[Dict[str, Any]]: - limit = _clamp_int(limit, 50, 1, 500) - async with self._lock: - items: List[Dict[str, Any]] = [] - for task_id in list(self._task_order)[:limit]: - task = self._tasks.get(task_id) - if task: - items.append(task.to_summary()) - return items - - async def get_task(self, task_id: str, include_rounds: bool = False) -> Optional[Dict[str, Any]]: - async with self._lock: - task = self._tasks.get(task_id) - if not task: - return None - return task.to_detail(include_rounds=include_rounds) - - async def get_rounds(self, task_id: str, offset: int = 0, limit: int = 50) -> Optional[Dict[str, Any]]: - offset = max(0, int(offset)) - limit = _clamp_int(limit, 50, 1, 500) - async with self._lock: - task = 
self._tasks.get(task_id) - if not task: - return None - total = len(task.rounds) - sliced = task.rounds[offset : offset + limit] - return { - "total": total, - "offset": offset, - "limit": limit, - "items": [item.to_dict() for item in sliced], - } - - async def cancel_task(self, task_id: str) -> Optional[Dict[str, Any]]: - async with self._lock: - task = self._tasks.get(task_id) - if not task: - return None - if task.status in {"completed", "failed", "cancelled"}: - return task.to_summary() - if task.status == "queued": - task.status = "cancelled" - task.cancel_requested = True - task.finished_at = _now() - task.updated_at = task.finished_at - self._queue = deque([x for x in self._queue if x != task_id]) - return task.to_summary() - task.status = "cancel_requested" - task.cancel_requested = True - task.updated_at = _now() - return task.to_summary() - - async def apply_best(self, task_id: str) -> Dict[str, Any]: - async with self._lock: - task = self._tasks.get(task_id) - if task is None: - raise ValueError("任务不存在") - if task.status != "completed": - raise ValueError("任务未完成,无法应用最优参数") - if not task.best_profile: - raise ValueError("任务没有可应用的最优参数") - best = copy.deepcopy(task.best_profile) - applied = await self.apply_profile(best, reason=f"task:{task_id}:apply_best") - async with self._lock: - task = self._tasks.get(task_id) - if task is not None: - task.apply_log.append({"applied_at": _now(), "reason": "apply_best", "profile": best}) - task.updated_at = _now() - return applied - - async def get_report(self, task_id: str, fmt: str = "md") -> Optional[Dict[str, Any]]: - async with self._lock: - task = self._tasks.get(task_id) - if task is None: - return None - artifacts = dict(task.artifact_paths) - fmt = str(fmt or "md").strip().lower() - if fmt not in {"md", "json"}: - fmt = "md" - path_key = "report_md" if fmt == "md" else "report_json" - path = artifacts.get(path_key) - if not path: - return {"format": fmt, "content": "", "path": ""} - p = Path(path) - if not p.exists(): - return {"format": fmt, "content": "", "path": str(p)} - try: - content = p.read_text(encoding="utf-8") - except Exception: - content = "" - return {"format": fmt, "content": content, "path": str(p)} - - async def _worker_loop(self) -> None: - while not self._stopping: - task_id: Optional[str] = None - async with self._lock: - while self._queue: - candidate = self._queue.popleft() - task = self._tasks.get(candidate) - if task is None: - continue - if task.status != "queued": - continue - task_id = candidate - self._active_task_id = candidate - break - - if not task_id: - await asyncio.sleep(self._poll_interval_s()) - continue - - try: - await self._run_task(task_id) - except asyncio.CancelledError: - raise - except Exception as e: - logger.error(f"Retrieval tuning task crashed: task_id={task_id}, err={e}") - async with self._lock: - task = self._tasks.get(task_id) - if task is not None: - task.status = "failed" - task.error = str(e) - task.finished_at = _now() - task.updated_at = task.finished_at - finally: - async with self._lock: - if self._active_task_id == task_id: - self._active_task_id = None - - async def _run_task(self, task_id: str) -> None: - async with self._lock: - task = self._tasks.get(task_id) - if task is None: - return - task.status = "running" - task.started_at = _now() - task.updated_at = task.started_at - - artifacts_dir = self._artifacts_root / task_id - artifacts_dir.mkdir(parents=True, exist_ok=True) - query_set_path = artifacts_dir / "query_set.json" - rounds_path = artifacts_dir / 
"round_metrics.jsonl" - best_profile_path = artifacts_dir / "best_profile.json" - report_json_path = artifacts_dir / "report.json" - report_md_path = artifacts_dir / "report.md" - - try: - params = dict(task.params) - cases, stats = await self._build_query_set( - sample_size=int(params["sample_size"]), - seed=int(params["seed"]), - llm_enabled=bool(params.get("llm_enabled", True)), - ) - if not cases: - raise ValueError("当前知识库样本不足,无法构建调优测试集") - - query_set_path.write_text( - json.dumps( - { - "task_id": task_id, - "created_at": _now(), - "stats": stats, - "items": [c.to_dict() for c in cases], - }, - ensure_ascii=False, - indent=2, - ), - encoding="utf-8", - ) - - baseline_profile = self.get_profile_snapshot() - top_k_eval = int(params["top_k_eval"]) - baseline_eval = await self._evaluate_profile( - profile=baseline_profile, - cases=cases, - objective=task.objective, - top_k_eval=top_k_eval, - query_timeout_s=float(params.get("eval_query_timeout_seconds") or self._eval_query_timeout_s()), - ) - baseline_round = RetrievalTuningRoundRecord( - round_index=0, - candidate_profile=baseline_profile, - metrics=baseline_eval["metrics"], - score=float(baseline_eval["score"]), - latency_ms=float(baseline_eval["avg_elapsed_ms"]), - failure_summary=baseline_eval["failure_summary"], - ) - rounds_path.write_text(json.dumps(baseline_round.to_dict(), ensure_ascii=False) + "\n", encoding="utf-8") - - async with self._lock: - task = self._tasks.get(task_id) - if task is None: - return - task.query_set_stats = stats - task.baseline_profile = copy.deepcopy(baseline_profile) - task.baseline_metrics = copy.deepcopy(baseline_eval["metrics"]) - task.rounds.append(baseline_round) - task.best_profile = copy.deepcopy(baseline_profile) - task.best_metrics = copy.deepcopy(baseline_eval["metrics"]) - task.best_score = float(baseline_eval["score"]) - task.progress = 0.0 - task.updated_at = _now() - - best_profile = copy.deepcopy(baseline_profile) - best_metrics = copy.deepcopy(baseline_eval["metrics"]) - best_failure_summary = copy.deepcopy(baseline_eval["failure_summary"]) - best_score = float(baseline_eval["score"]) - llm_suggestions: List[Dict[str, Any]] = [] - task_cancelled = False - - for round_idx in range(1, int(task.rounds_total) + 1): - async with self._lock: - task = self._tasks.get(task_id) - if task is None: - return - if task.cancel_requested or task.status == "cancel_requested": - task.status = "cancelled" - task.finished_at = _now() - task.updated_at = task.finished_at - task_cancelled = True - break - - if round_idx == 1 or (round_idx % 5 == 0 and not llm_suggestions): - llm_suggestions = await self._suggest_profiles_with_llm( - base_profile=best_profile, - failure_summary=best_failure_summary, - objective=task.objective, - max_count=3, - enabled=bool(params.get("llm_enabled", True)), - ) - - candidate_profile = self._generate_candidate_profile( - task_id=task_id, - round_index=round_idx, - objective=task.objective, - baseline_profile=baseline_profile, - best_profile=best_profile, - llm_suggestions=llm_suggestions, - ) - eval_cases = self._select_round_eval_cases( - cases=cases, - intensity=task.intensity, - round_index=round_idx, - seed=int(params.get("seed", 0)), - ) - eval_result = await self._evaluate_profile( - profile=candidate_profile, - cases=eval_cases, - objective=task.objective, - top_k_eval=top_k_eval, - query_timeout_s=float(params.get("eval_query_timeout_seconds") or self._eval_query_timeout_s()), - ) - round_record = RetrievalTuningRoundRecord( - round_index=round_idx, - 
candidate_profile=candidate_profile, - metrics=eval_result["metrics"], - score=float(eval_result["score"]), - latency_ms=float(eval_result["avg_elapsed_ms"]), - failure_summary=eval_result["failure_summary"], - ) - with rounds_path.open("a", encoding="utf-8") as fp: - fp.write(json.dumps(round_record.to_dict(), ensure_ascii=False) + "\n") - - if float(eval_result["score"]) > float(best_score): - best_score = float(eval_result["score"]) - best_profile = copy.deepcopy(candidate_profile) - best_metrics = copy.deepcopy(eval_result["metrics"]) - best_failure_summary = copy.deepcopy(eval_result["failure_summary"]) - - async with self._lock: - task = self._tasks.get(task_id) - if task is None: - return - task.rounds_done = round_idx - task.rounds.append(round_record) - task.best_profile = copy.deepcopy(best_profile) - task.best_metrics = copy.deepcopy(best_metrics) - task.best_score = float(best_score) - task.progress = min(1.0, float(round_idx) / float(task.rounds_total)) - task.updated_at = _now() - - if best_profile and (not task_cancelled): - # Candidate rounds may have been evaluated on a subsample; re-check the best profile on the full set at the end so the final metrics stay interpretable. - best_full = await self._evaluate_profile( - profile=best_profile, - cases=cases, - objective=task.objective, - top_k_eval=top_k_eval, - query_timeout_s=float(params.get("eval_query_timeout_seconds") or self._eval_query_timeout_s()), - ) - best_profile = copy.deepcopy(best_profile) - best_metrics = copy.deepcopy(best_full["metrics"]) - best_failure_summary = copy.deepcopy(best_full["failure_summary"]) - best_score = float(best_full["score"]) - if best_score < float(baseline_eval["score"]): - best_profile = copy.deepcopy(baseline_profile) - best_metrics = copy.deepcopy(baseline_eval["metrics"]) - best_failure_summary = copy.deepcopy(baseline_eval["failure_summary"]) - best_score = float(baseline_eval["score"]) - - async with self._lock: - task = self._tasks.get(task_id) - if task is not None: - task.best_profile = copy.deepcopy(best_profile) - task.best_metrics = copy.deepcopy(best_metrics) - task.best_score = float(best_score) - task.updated_at = _now() - - async with self._lock: - task = self._tasks.get(task_id) - if task is None: - return - if task.status not in {"cancelled", "failed"}: - task.status = "completed" - task.progress = 1.0 - task.finished_at = _now() - task.updated_at = task.finished_at - final_task = copy.deepcopy(task) - - if final_task.status == "completed": - best_profile_path.write_text(json.dumps(final_task.best_profile, ensure_ascii=False, indent=2), encoding="utf-8") - report_payload = self._build_report_payload(final_task) - report_json_path.write_text(json.dumps(report_payload, ensure_ascii=False, indent=2), encoding="utf-8") - report_md_path.write_text(self._build_report_markdown(final_task, report_payload), encoding="utf-8") - - async with self._lock: - task = self._tasks.get(task_id) - if task is not None: - task.artifact_paths = { - "query_set": str(query_set_path), - "round_metrics_jsonl": str(rounds_path), - "best_profile": str(best_profile_path), - "report_json": str(report_json_path), - "report_md": str(report_md_path), - } - task.updated_at = _now() - except Exception as e: - logger.error(f"Retrieval tuning task failed: task_id={task_id}, err={e}") - async with self._lock: - task = self._tasks.get(task_id) - if task is not None: - task.status = "failed" - task.error = str(e) - task.finished_at = _now() - task.updated_at = task.finished_at - - async def _build_query_set(self, *, sample_size: int, seed: int, llm_enabled: bool) -> Tuple[List[RetrievalQueryCase], Dict[str, Any]]: - store
= getattr(self.plugin, "metadata_store", None) - if store is None: - return [], {"error": "metadata_store_unavailable"} - - triples = list(store.get_all_triples() or []) - if not triples: - return [], {"error": "no_relations"} - - sampled, sample_info = self._sample_triples_for_query_set( - triples=triples, - sample_size=sample_size, - seed=seed, - ) - if not sampled: - return [], {"error": "no_relations"} - - anchors: List[Dict[str, Any]] = [] - for idx, row in enumerate(sampled): - subject, predicate, obj, relation_hash = row - paragraphs = store.get_paragraphs_by_relation(relation_hash) - para_hash = "" - para_content = "" - if paragraphs: - para_hash = str(paragraphs[0].get("hash") or "").strip() - para_content = str(paragraphs[0].get("content") or "") - anchors.append( - { - "anchor_id": f"a{idx+1:04d}", - "subject": str(subject or ""), - "predicate": str(predicate or ""), - "object": str(obj or ""), - "relation_hash": relation_hash, - "paragraph_hash": para_hash, - "paragraph_excerpt": para_content[:300], - } - ) - - if not anchors: - return [], {"error": "no_anchors"} - - predicate_groups: Dict[str, List[Dict[str, Any]]] = {} - for anchor in anchors: - predicate_groups.setdefault(str(anchor.get("predicate") or ""), []).append(anchor) - - nl_queries = await self._generate_nl_queries_with_llm(anchors, enabled=llm_enabled) - cases: List[RetrievalQueryCase] = [] - - seq = 0 - for anchor in anchors: - seq += 1 - subject = anchor["subject"] - predicate = anchor["predicate"] - obj = anchor["object"] - rel_hash = anchor["relation_hash"] - para_hash = anchor["paragraph_hash"] - expected = [rel_hash] - if para_hash: - expected.append(para_hash) - aid = anchor["anchor_id"] - - common_meta = { - "anchor_id": aid, - "relation_hash": rel_hash, - "paragraph_hash": para_hash, - "subject": subject, - "predicate": predicate, - "object": obj, - } - cases.append( - RetrievalQueryCase( - case_id=f"spo_relation_{seq:04d}", - category="spo_relation", - query=f"{subject}|{predicate}|{obj}", - expected_hashes=[rel_hash], - expected_spo={"subject": subject, "predicate": predicate, "object": obj}, - metadata=dict(common_meta), - ) - ) - cases.append( - RetrievalQueryCase( - case_id=f"spo_search_{seq:04d}", - category="spo_search", - query=self._build_spo_search_query( - anchor=anchor, - seq=seq, - predicate_groups=predicate_groups, - ), - expected_hashes=list(expected), - metadata=dict(common_meta), - ) - ) - cases.append( - RetrievalQueryCase( - case_id=f"query_kw_{seq:04d}", - category="query_kw", - query=self._build_keyword_query( - anchor=anchor, - seq=seq, - predicate_groups=predicate_groups, - ), - expected_hashes=list(expected), - metadata=dict(common_meta), - ) - ) - nl_query = nl_queries.get(aid) or self._build_nl_template( - anchor=anchor, - seq=seq, - predicate_groups=predicate_groups, - ) - cases.append( - RetrievalQueryCase( - case_id=f"query_nl_{seq:04d}", - category="query_nl", - query=nl_query, - expected_hashes=list(expected), - metadata=dict(common_meta), - ) - ) - - counts = Counter([c.category for c in cases]) - stats = { - "anchors": len(anchors), - "case_total": len(cases), - "category_counts": {k: int(v) for k, v in counts.items()}, - "seed": int(seed), - "sample_size": int(sample_info.get("sample_size", len(anchors))), - "sampling": dict(sample_info), - "llm_nl_enabled": bool(llm_enabled), - "llm_nl_generated": int(len(nl_queries)), - } - return cases, stats - - def _pick_contrast_anchor( - self, - *, - anchor: Dict[str, Any], - predicate_groups: Dict[str, List[Dict[str, Any]]], - seq: 
int, - ) -> Optional[Dict[str, Any]]: - predicate = str(anchor.get("predicate") or "") - pool = predicate_groups.get(predicate, []) - if not pool: - return None - candidates = [x for x in pool if x is not anchor and str(x.get("object") or "") != str(anchor.get("object") or "")] - if not candidates: - return None - return candidates[seq % len(candidates)] - - def _build_spo_search_query( - self, - *, - anchor: Dict[str, Any], - seq: int, - predicate_groups: Dict[str, List[Dict[str, Any]]], - ) -> str: - subject = str(anchor.get("subject") or "") - predicate = str(anchor.get("predicate") or "") - obj = str(anchor.get("object") or "") - contrast = self._pick_contrast_anchor(anchor=anchor, predicate_groups=predicate_groups, seq=seq) - contrast_obj = str(contrast.get("object") or "").strip() if contrast else "" - - variants = [ - f"{subject} {predicate} {obj}", - f"{subject} {obj} relation {predicate}", - f"{predicate} {subject} {obj} evidence", - f"{subject} {predicate} {obj} not {contrast_obj}".strip(), - ] - return variants[seq % len(variants)].strip() - - def _build_keyword_query( - self, - *, - anchor: Dict[str, Any], - seq: int, - predicate_groups: Dict[str, List[Dict[str, Any]]], - ) -> str: - subject = str(anchor.get("subject") or "") - predicate = str(anchor.get("predicate") or "") - obj = str(anchor.get("object") or "") - excerpt = str(anchor.get("paragraph_excerpt") or "") - tokens = re.findall(r"[A-Za-z0-9_\u4e00-\u9fff]{2,}", excerpt) - extras: List[str] = [] - seen = set() - for token in tokens: - key = token.lower() - if key in seen: - continue - if key in {subject.lower(), predicate.lower(), obj.lower()}: - continue - seen.add(key) - extras.append(token) - if len(extras) >= 2: - break - contrast = self._pick_contrast_anchor(anchor=anchor, predicate_groups=predicate_groups, seq=seq) - contrast_obj = str(contrast.get("object") or "").strip() if contrast else "" - - variants = [ - [subject, obj] + extras[:2], - [predicate, obj] + extras[:2], - [subject, predicate] + extras[:2], - [subject, obj, predicate, contrast_obj] + extras[:1], - ] - parts = variants[seq % len(variants)] - return " ".join([x for x in parts if x]).strip() - - def _build_nl_template( - self, - *, - anchor: Dict[str, Any], - seq: int, - predicate_groups: Dict[str, List[Dict[str, Any]]], - ) -> str: - subject = str(anchor.get("subject") or "") - predicate = str(anchor.get("predicate") or "") - obj = str(anchor.get("object") or "") - contrast = self._pick_contrast_anchor(anchor=anchor, predicate_groups=predicate_groups, seq=seq) - contrast_obj = str(contrast.get("object") or "").strip() if contrast else "" - templates = [ - f"请问 {subject} 与 {obj} 的关系是什么,是否是“{predicate}”?", - f"在当前知识库中,哪条信息说明 {subject} 对应的是 {obj},关系词接近“{predicate}”?", - f"我想确认:{subject} 和 {obj} 之间是不是“{predicate}”这层关系,而不是 {contrast_obj}?", - f"帮我查一下关于 {subject} 与 {obj} 的证据,重点看 {predicate} 相关描述。", - ] - return templates[seq % len(templates)] - - async def _select_llm_model(self) -> Optional[Any]: - if llm_api is None: - return None - try: - models = llm_api.get_available_models() or {} - except Exception: - return None - if not models: - return None - - cfg_model = str(self._cfg("advanced.extraction_model", "auto") or "auto").strip() - if cfg_model.lower() != "auto" and cfg_model in models: - return models[cfg_model] - for task_name in ["utils", "planner", "tool_use", "replyer", "embedding"]: - if task_name in models: - return models[task_name] - return models[next(iter(models))] - - async def _llm_call_text(self, prompt: str, *, request_type: str) -> 
str: - if llm_api is None: - raise RuntimeError("llm_api unavailable") - model_cfg = await self._select_llm_model() - if model_cfg is None: - raise RuntimeError("no_llm_model") - - retry = self._llm_retry_cfg() - max_attempts = int(retry["max_attempts"]) - min_wait = float(retry["min_wait_seconds"]) - max_wait = float(retry["max_wait_seconds"]) - backoff = float(retry["backoff_multiplier"]) - - last_error: Optional[Exception] = None - for idx in range(max_attempts): - try: - success, response, _, _ = await llm_api.generate_with_model( - prompt=prompt, - model_config=model_cfg, - request_type=request_type, - ) - if not success: - raise RuntimeError("llm_generation_failed") - text = str(response or "").strip() - if text: - return text - raise RuntimeError("empty_llm_response") - except Exception as e: - last_error = e - if idx >= max_attempts - 1: - break - delay = min(max_wait, min_wait * (backoff ** idx)) - await asyncio.sleep(max(0.05, delay)) - raise RuntimeError(f"LLM call failed: {last_error}") - - async def _generate_nl_queries_with_llm(self, anchors: List[Dict[str, Any]], *, enabled: bool) -> Dict[str, str]: - if not enabled or llm_api is None or not anchors: - return {} - payload = [ - { - "anchor_id": x["anchor_id"], - "subject": x["subject"], - "predicate": x["predicate"], - "object": x["object"], - "paragraph_excerpt": x["paragraph_excerpt"][:180], - } - for x in anchors[:60] - ] - prompt = ( - "你是检索评估问题生成器。" - "请基于给定 SPO 与简短上下文,为每条样本生成 1 条自然语言检索问题,返回 JSON:" - "{\"items\":[{\"anchor_id\":\"...\",\"query\":\"...\"}]}。\n" - f"样本:\n{json.dumps(payload, ensure_ascii=False)}" - ) - try: - raw = await self._llm_call_text(prompt, request_type="A_Memorix.RetrievalTuning.NLCaseGen") - obj = _safe_json_loads(raw) - if not isinstance(obj, dict): - return {} - items = obj.get("items") - if not isinstance(items, list): - return {} - out: Dict[str, str] = {} - for row in items: - if not isinstance(row, dict): - continue - anchor_id = str(row.get("anchor_id") or "").strip() - query = str(row.get("query") or "").strip() - if anchor_id and query: - out[anchor_id] = query - return out - except Exception: - return {} - - async def _suggest_profiles_with_llm( - self, - *, - base_profile: Dict[str, Any], - failure_summary: Dict[str, Any], - objective: str, - max_count: int, - enabled: bool, - ) -> List[Dict[str, Any]]: - if not enabled or llm_api is None or max_count <= 0: - return [] - prompt = ( - "你是检索调参专家。" - "请基于基础参数与失败摘要,给出最多 " - f"{int(max_count)} 组候选参数,返回 JSON: {{\"profiles\": [ ... 
]}}。\n" - "字段仅可包含:retrieval.top_k_paragraphs, retrieval.top_k_relations, retrieval.top_k_final, " - "retrieval.alpha, retrieval.enable_ppr, retrieval.search.smart_fallback.enabled, " - "retrieval.sparse.enabled, retrieval.sparse.mode, retrieval.sparse.candidate_k, retrieval.sparse.relation_candidate_k, " - "retrieval.fusion.method, retrieval.fusion.rrf_k, retrieval.fusion.vector_weight, retrieval.fusion.bm25_weight, " - "threshold.percentile, threshold.min_results。\n" - f"objective={objective}\n" - f"base={json.dumps(base_profile, ensure_ascii=False)}\n" - f"failure_summary={json.dumps(failure_summary, ensure_ascii=False)}" - ) - try: - raw = await self._llm_call_text(prompt, request_type="A_Memorix.RetrievalTuning.ProfileSuggest") - obj = _safe_json_loads(raw) - if not isinstance(obj, dict): - return [] - profiles = obj.get("profiles") - if not isinstance(profiles, list): - return [] - out = [] - for item in profiles[:max_count]: - if isinstance(item, dict): - out.append(self._normalize_profile(item, fallback=base_profile)) - return out - except Exception: - return [] - - def _generate_candidate_profile( - self, - *, - task_id: str, - round_index: int, - objective: str, - baseline_profile: Dict[str, Any], - best_profile: Dict[str, Any], - llm_suggestions: List[Dict[str, Any]], - ) -> Dict[str, Any]: - if llm_suggestions: - return self._normalize_profile(llm_suggestions.pop(0), fallback=best_profile) - - rng = random.Random(f"{task_id}:{round_index}") - base = baseline_profile if round_index % 4 == 1 else best_profile - candidate = copy.deepcopy(base) - - if objective == "precision_priority": - para_choices = [40, 80, 120, 180, 240, 320] - rel_choices = [4, 8, 12, 16, 24] - final_choices = [4, 8, 12, 16, 20, 32, 48, 64] - alpha_choices = [0.0, 0.35, 0.50, 0.62, 0.72, 0.82, 0.90] - pct_choices = [55, 60, 65, 72, 80] - min_results_choices = [1, 2] - elif objective == "recall_priority": - para_choices = [120, 220, 300, 420, 560, 720] - rel_choices = [8, 12, 16, 24, 32] - final_choices = [8, 16, 32, 48, 64, 96, 128] - alpha_choices = [0.20, 0.35, 0.45, 0.55, 0.65, 0.75] - pct_choices = [40, 48, 55, 62] - min_results_choices = [1, 2, 3] - else: - para_choices = [80, 160, 240, 320, 420, 520] - rel_choices = [6, 10, 14, 18, 24, 30] - final_choices = [6, 12, 20, 32, 48, 64, 80] - alpha_choices = [0.25, 0.45, 0.55, 0.65, 0.75, 0.85] - pct_choices = [48, 55, 62, 70] - min_results_choices = [1, 2, 3] - - _nested_set(candidate, "retrieval.top_k_paragraphs", rng.choice(para_choices)) - _nested_set(candidate, "retrieval.top_k_relations", rng.choice(rel_choices)) - _nested_set(candidate, "retrieval.top_k_final", rng.choice(final_choices)) - _nested_set(candidate, "retrieval.alpha", rng.choice(alpha_choices)) - # PPR occasionally blocks for long stretches under TestClient/async evaluation, so the tuning-evaluation path keeps it disabled. - _nested_set(candidate, "retrieval.enable_ppr", False) - _nested_set(candidate, "retrieval.search.smart_fallback.enabled", bool(rng.choice([True, True, False]))) - _nested_set(candidate, "retrieval.sparse.enabled", bool(rng.choice([True, True, False]))) - _nested_set(candidate, "retrieval.sparse.mode", rng.choice(["auto", "hybrid", "fallback_only"])) - _nested_set(candidate, "retrieval.sparse.candidate_k", rng.choice([60, 80, 120, 160, 220, 320])) - _nested_set(candidate, "retrieval.sparse.relation_candidate_k", rng.choice([40, 60, 90, 120, 180, 260])) - _nested_set(candidate, "retrieval.fusion.method", rng.choice(["weighted_rrf", "weighted_rrf", "alpha_legacy"])) - _nested_set(candidate, "retrieval.fusion.rrf_k", rng.choice([30, 45, 60, 75, 90]))
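- # The dense/sparse fusion weights are sampled as a complementary pair: bm25_weight is - # derived as 1.0 - vector_weight, so a single draw fixes the balance for this candidate.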
vec_w = float(rng.choice([0.55, 0.65, 0.72, 0.80, 0.88])) - _nested_set(candidate, "retrieval.fusion.vector_weight", vec_w) - _nested_set(candidate, "retrieval.fusion.bm25_weight", 1.0 - vec_w) - _nested_set(candidate, "threshold.percentile", rng.choice(pct_choices)) - _nested_set(candidate, "threshold.min_results", rng.choice(min_results_choices)) - - return self._normalize_profile(candidate, fallback=base) - - def _build_runtime_config(self, normalized_profile: Dict[str, Any]) -> Dict[str, Any]: - raw_base = getattr(self.plugin, "config", {}) or {} - if isinstance(raw_base, dict): - base = { - key: value - for key, value in raw_base.items() - if key not in _RUNTIME_CONFIG_INSTANCE_KEYS - } - else: - base = {} - merged = _deep_merge(base, normalized_profile) - # Tuning evaluation favors stability: avoid long stalls from concurrent access to the shared SQLite/Faiss stores. - _nested_set(merged, "retrieval.enable_parallel", False) - # Disable PPR during tuning evaluation: occasional stalls in the PageRank worker thread can hang an entire round. - _nested_set(merged, "retrieval.enable_ppr", False) - merged["vector_store"] = getattr(self.plugin, "vector_store", None) - merged["graph_store"] = getattr(self.plugin, "graph_store", None) - merged["metadata_store"] = getattr(self.plugin, "metadata_store", None) - merged["embedding_manager"] = getattr(self.plugin, "embedding_manager", None) - merged["sparse_index"] = getattr(self.plugin, "sparse_index", None) - merged["plugin_instance"] = self.plugin - return merged - - async def _evaluate_profile( - self, - *, - profile: Dict[str, Any], - cases: List[RetrievalQueryCase], - objective: str, - top_k_eval: int, - query_timeout_s: float, - ) -> Dict[str, Any]: - normalized = self._normalize_profile(profile) - eval_top_k = _clamp_int(top_k_eval, 20, 1, 1000) - # Let top_k_final bound the effective recall depth during evaluation so the parameter actually influences the score. - request_top_k = min( - int(eval_top_k), - _clamp_int(_nested_get(normalized, "retrieval.top_k_final", eval_top_k), eval_top_k, 1, 512), - ) - eval_timeout_s = _clamp_float( - query_timeout_s, - self._eval_query_timeout_s(), - 0.01, - 120.0, - ) - runtime_cfg = self._build_runtime_config(normalized) - runtime = build_search_runtime( - plugin_config=runtime_cfg, - logger_obj=logger, - owner_tag="retrieval_tuning", - log_prefix="[RetrievalTuning]", - ) - if not runtime.ready: - metrics = { - "total_text_cases": 0, - "precision_at_1": 0.0, - "precision_at_3": 0.0, - "mrr": 0.0, - "recall_at_k": 0.0, - "spo_relation_hit_rate": 0.0, - "empty_rate": 1.0, - "avg_elapsed_ms": 0.0, - "category": {}, - "error": runtime.error or "runtime_not_ready", - } - return {"metrics": metrics, "score": -1.0, "avg_elapsed_ms": 0.0, "failure_summary": {"reason": metrics["error"]}} - - text_total = 0 - hit1 = 0 - hit3 = 0 - hitk = 0 - mrr_sum = 0.0 - empty_count = 0 - timeout_count = 0 - elapsed_total = 0.0 - text_failed: List[str] = [] - - spo_total = 0 - spo_hit = 0 - spo_failed: List[str] = [] - - category_stats: Dict[str, Dict[str, Any]] = {} - failed_predicates = Counter() - - for case in cases: - cat = str(case.category) - if cat not in CATEGORIES: - continue - if cat not in category_stats: - category_stats[cat] = { - "total": 0, - "hit": 0, - "hit_at_1": 0, - "hit_at_3": 0, - "empty": 0, - } - category_stats[cat]["total"] += 1 - - if cat == "spo_relation": - spo_total += 1 - spo = case.expected_spo or {} - rows = runtime.metadata_store.get_relations( - subject=str(spo.get("subject") or ""), - predicate=str(spo.get("predicate") or ""), - object=str(spo.get("object") or ""), - ) - expected_hash = str(case.expected_hashes[0]) if case.expected_hashes else "" - ok = False - for row in rows:
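- # An exact spo_relation case counts as a hit only when the expected relation hash is - # returned (or when no expected hash was recorded), and a hit is scored at rank 1. - if not isinstance(row,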
dict): - continue - if expected_hash and str(row.get("hash") or "") == expected_hash: - ok = True - break - if not expected_hash: - ok = True - break - if ok: - spo_hit += 1 - category_stats[cat]["hit"] += 1 - category_stats[cat]["hit_at_1"] += 1 - category_stats[cat]["hit_at_3"] += 1 - else: - spo_failed.append(case.case_id) - failed_predicates.update([str(spo.get("predicate") or "").strip() or "__empty__"]) - continue - - text_total += 1 - req = SearchExecutionRequest( - caller="retrieval_tuning", - query_type="search", - query=str(case.query or "").strip(), - top_k=int(request_top_k), - use_threshold=True, - # PPR is force-disabled for tuning evaluation so a stalled PPR path cannot drag down the whole round. - enable_ppr=False, - ) - try: - execution = await asyncio.wait_for( - SearchExecutionService.execute( - retriever=runtime.retriever, - threshold_filter=runtime.threshold_filter, - plugin_config=runtime_cfg, - request=req, - enforce_chat_filter=False, - reinforce_access=False, - ), - timeout=float(eval_timeout_s), - ) - except asyncio.TimeoutError: - timeout_count += 1 - empty_count += 1 - category_stats[cat]["empty"] += 1 - text_failed.append(case.case_id) - failed_predicates.update([str(case.metadata.get("predicate") or "__unknown__")]) - continue - - if execution is None: - empty_count += 1 - category_stats[cat]["empty"] += 1 - text_failed.append(case.case_id) - failed_predicates.update([str(case.metadata.get("predicate") or "__unknown__")]) - continue - - elapsed_total += float(getattr(execution, "elapsed_ms", 0.0) or 0.0) - - if not bool(getattr(execution, "success", False)): - empty_count += 1 - category_stats[cat]["empty"] += 1 - text_failed.append(case.case_id) - failed_predicates.update([str(case.metadata.get("predicate") or "__unknown__")]) - continue - - hashes = [str(getattr(x, "hash_value", "") or "") for x in (getattr(execution, "results", None) or [])] - if not hashes: - empty_count += 1 - category_stats[cat]["empty"] += 1 - - expected_set = set(case.expected_hashes or []) - rank = 0 - for idx, hv in enumerate(hashes, start=1): - if hv and hv in expected_set: - rank = idx - break - - if rank > 0: - category_stats[cat]["hit"] += 1 - hitk += 1 - if rank <= 1: - hit1 += 1 - category_stats[cat]["hit_at_1"] += 1 - if rank <= 3: - hit3 += 1 - category_stats[cat]["hit_at_3"] += 1 - mrr_sum += 1.0 / float(rank) - else: - text_failed.append(case.case_id) - failed_predicates.update([str(case.metadata.get("predicate") or "__unknown__")]) - - p1 = (hit1 / text_total) if text_total else 0.0 - p3 = (hit3 / text_total) if text_total else 0.0 - recall = (hitk / text_total) if text_total else 0.0 - mrr = (mrr_sum / text_total) if text_total else 0.0 - spo_rate = (spo_hit / spo_total) if spo_total else 0.0 - empty_rate = (empty_count / text_total) if text_total else 1.0 - avg_elapsed = (elapsed_total / text_total) if text_total else 0.0 - - metrics = { - "total_text_cases": int(text_total), - "precision_at_1": float(round(p1, 6)), - "precision_at_3": float(round(p3, 6)), - "mrr": float(round(mrr, 6)), - "recall_at_k": float(round(recall, 6)), - "spo_relation_hit_rate": float(round(spo_rate, 6)), - "empty_rate": float(round(empty_rate, 6)), - "timeout_count": int(timeout_count), - "avg_elapsed_ms": float(round(avg_elapsed, 3)), - "category": category_stats, - } - metrics["category_floor_penalty"] = float(round(self._category_floor_penalty(metrics, objective=objective), 6)) - - score = self._score_metrics(metrics, objective=objective) - failure_summary = { - "text_failed_count": len(text_failed), - "spo_failed_count": len(spo_failed),
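- # This failure summary is the context consumed by _suggest_profiles_with_llm, so it - # stays compact: capped case-id lists plus aggregated per-category/predicate counts. - "failed_case_ids":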
text_failed[:50] + spo_failed[:50], - "failed_by_category": {k: int(v["total"] - v["hit"]) for k, v in category_stats.items()}, - "top_failed_predicates": [ - {"predicate": key, "count": int(cnt)} - for key, cnt in failed_predicates.most_common(5) - if key - ], - "query_timeout_seconds": float(eval_timeout_s), - "timeout_count": int(timeout_count), - "effective_top_k": int(request_top_k), - "ppr_forced_disabled": True, - } - return { - "metrics": metrics, - "score": float(round(score, 6)), - "avg_elapsed_ms": float(avg_elapsed), - "failure_summary": failure_summary, - } - - def _score_metrics(self, metrics: Dict[str, Any], *, objective: str) -> float: - p1 = float(metrics.get("precision_at_1", 0.0) or 0.0) - p3 = float(metrics.get("precision_at_3", 0.0) or 0.0) - mrr = float(metrics.get("mrr", 0.0) or 0.0) - recall = float(metrics.get("recall_at_k", 0.0) or 0.0) - spo = float(metrics.get("spo_relation_hit_rate", 0.0) or 0.0) - empty_rate = float(metrics.get("empty_rate", 1.0) or 1.0) - category_penalty = metrics.get("category_floor_penalty", None) - if category_penalty is None: - category_penalty = self._category_floor_penalty(metrics, objective=objective) - category_penalty = float(max(0.0, category_penalty)) - - if objective == "recall_priority": - raw = 0.15 * p1 + 0.15 * p3 + 0.15 * mrr + 0.40 * recall + 0.15 * spo - penalty = 0.05 * empty_rate - elif objective == "balanced": - raw = 0.25 * p1 + 0.20 * p3 + 0.15 * mrr + 0.25 * recall + 0.15 * spo - penalty = 0.10 * empty_rate - else: - raw = 0.40 * p1 + 0.20 * p3 + 0.15 * mrr + 0.15 * recall + 0.10 * spo - penalty = 0.15 * empty_rate - return float(raw - penalty - category_penalty) - - def _category_floor_penalty(self, metrics: Dict[str, Any], *, objective: str) -> float: - category = metrics.get("category") - if not isinstance(category, dict) or not category: - return 0.0 - - if objective == "recall_priority": - floors = {"query_nl": 0.60, "query_kw": 0.48, "spo_search": 0.52, "spo_relation": 0.88} - scale = 0.12 - elif objective == "balanced": - floors = {"query_nl": 0.65, "query_kw": 0.52, "spo_search": 0.55, "spo_relation": 0.90} - scale = 0.18 - else: - floors = {"query_nl": 0.70, "query_kw": 0.55, "spo_search": 0.58, "spo_relation": 0.92} - scale = 0.25 - - weights = {"query_nl": 1.0, "query_kw": 1.1, "spo_search": 1.0, "spo_relation": 1.2} - weighted_shortfall = 0.0 - weight_total = 0.0 - - for cat, floor in floors.items(): - row = category.get(cat) - if not isinstance(row, dict): - continue - total = int(row.get("total", 0) or 0) - if total <= 0: - continue - hit = float(row.get("hit", 0.0) or 0.0) - hit_rate = max(0.0, min(1.0, hit / float(max(1, total)))) - shortfall = max(0.0, float(floor) - hit_rate) - w = float(weights.get(cat, 1.0)) - weighted_shortfall += w * shortfall - weight_total += w - - if weight_total <= 1e-9: - return 0.0 - return float(scale * (weighted_shortfall / weight_total)) - - def _build_report_payload(self, task: RetrievalTuningTaskRecord) -> Dict[str, Any]: - baseline = task.baseline_metrics or {} - best = task.best_metrics or {} - - def delta(name: str) -> float: - return float(best.get(name, 0.0) or 0.0) - float(baseline.get(name, 0.0) or 0.0) - - return { - "task_id": task.task_id, - "objective": task.objective, - "intensity": task.intensity, - "status": task.status, - "created_at": task.created_at, - "started_at": task.started_at, - "finished_at": task.finished_at, - "rounds_total": task.rounds_total, - "rounds_done": task.rounds_done, - "best_score": task.best_score, - "baseline_score": 
self._score_metrics(baseline, objective=task.objective), - "query_set_stats": task.query_set_stats, - "baseline_metrics": baseline, - "best_metrics": best, - "deltas": { - "precision_at_1": delta("precision_at_1"), - "precision_at_3": delta("precision_at_3"), - "mrr": delta("mrr"), - "recall_at_k": delta("recall_at_k"), - "spo_relation_hit_rate": delta("spo_relation_hit_rate"), - "empty_rate": delta("empty_rate"), - "timeout_count": delta("timeout_count"), - "avg_elapsed_ms": delta("avg_elapsed_ms"), - }, - "best_profile": task.best_profile, - "baseline_profile": task.baseline_profile, - "apply_log": task.apply_log, - } - - def _build_report_markdown(self, task: RetrievalTuningTaskRecord, payload: Dict[str, Any]) -> str: - baseline = payload.get("baseline_metrics", {}) or {} - best = payload.get("best_metrics", {}) or {} - d = payload.get("deltas", {}) or {} - lines = [ - f"# 检索调优报告({task.task_id})", - "", - "## 1. 任务信息", - f"- 状态: {task.status}", - f"- 目标函数: {task.objective}", - f"- 强度: {task.intensity}", - f"- 轮次: baseline + {task.rounds_total}", - f"- 创建时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(task.created_at))}", - f"- 开始时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(task.started_at)) if task.started_at else '-'}", - f"- 完成时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(task.finished_at)) if task.finished_at else '-'}", - "", - "## 2. 基线 vs 最优", - f"- baseline score: {payload.get('baseline_score', 0.0):.6f}", - f"- best score: {task.best_score:.6f}", - f"- P@1: {baseline.get('precision_at_1', 0.0):.4f} -> {best.get('precision_at_1', 0.0):.4f} (Δ {d.get('precision_at_1', 0.0):+.4f})", - f"- P@3: {baseline.get('precision_at_3', 0.0):.4f} -> {best.get('precision_at_3', 0.0):.4f} (Δ {d.get('precision_at_3', 0.0):+.4f})", - f"- MRR: {baseline.get('mrr', 0.0):.4f} -> {best.get('mrr', 0.0):.4f} (Δ {d.get('mrr', 0.0):+.4f})", - f"- Recall@K: {baseline.get('recall_at_k', 0.0):.4f} -> {best.get('recall_at_k', 0.0):.4f} (Δ {d.get('recall_at_k', 0.0):+.4f})", - f"- SPO relation hit: {baseline.get('spo_relation_hit_rate', 0.0):.4f} -> {best.get('spo_relation_hit_rate', 0.0):.4f} (Δ {d.get('spo_relation_hit_rate', 0.0):+.4f})", - f"- 空结果率: {baseline.get('empty_rate', 0.0):.4f} -> {best.get('empty_rate', 0.0):.4f} (Δ {d.get('empty_rate', 0.0):+.4f})", - f"- 超时数: {int(baseline.get('timeout_count', 0) or 0)} -> {int(best.get('timeout_count', 0) or 0)} (Δ {int(d.get('timeout_count', 0) or 0):+d})", - f"- 平均耗时(ms): {baseline.get('avg_elapsed_ms', 0.0):.2f} -> {best.get('avg_elapsed_ms', 0.0):.2f} (Δ {d.get('avg_elapsed_ms', 0.0):+.2f})", - "", - "## 3. 最优参数", - "```json", - json.dumps(task.best_profile, ensure_ascii=False, indent=2), - "```", - "", - "## 4. 测试集规模", - f"- {json.dumps(task.query_set_stats, ensure_ascii=False)}", - "", - "## 5. 
说明", - "- 本报告仅对当前已存储图谱与向量状态有效。", - "- 参数应用策略:运行时生效,不自动写入 config.toml。", - ] - return "\n".join(lines).strip() + "\n" diff --git a/plugins/A_memorix/core/utils/runtime_self_check.py b/plugins/A_memorix/core/utils/runtime_self_check.py deleted file mode 100644 index 131ab32a..00000000 --- a/plugins/A_memorix/core/utils/runtime_self_check.py +++ /dev/null @@ -1,218 +0,0 @@ -"""Runtime self-check helpers for A_Memorix.""" - -from __future__ import annotations - -import time -from typing import Any, Dict, Optional - -import numpy as np - -from src.common.logger import get_logger - -logger = get_logger("A_Memorix.RuntimeSelfCheck") - -_DEFAULT_SAMPLE_TEXT = "A_Memorix runtime self check" - - -def _safe_int(value: Any, default: int = 0) -> int: - try: - return int(value) - except Exception: - return int(default) - - -def _get_config_value(config: Any, key: str, default: Any = None) -> Any: - getter = getattr(config, "get_config", None) - if callable(getter): - return getter(key, default) - - current: Any = config - for part in key.split("."): - if isinstance(current, dict) and part in current: - current = current[part] - else: - return default - return current - - -def _build_report( - *, - ok: bool, - code: str, - message: str, - configured_dimension: int, - vector_store_dimension: int, - detected_dimension: int, - encoded_dimension: int, - elapsed_ms: float, - sample_text: str, -) -> Dict[str, Any]: - return { - "ok": bool(ok), - "code": str(code or "").strip(), - "message": str(message or "").strip(), - "configured_dimension": int(configured_dimension), - "vector_store_dimension": int(vector_store_dimension), - "detected_dimension": int(detected_dimension), - "encoded_dimension": int(encoded_dimension), - "elapsed_ms": float(elapsed_ms), - "sample_text": str(sample_text or ""), - "checked_at": time.time(), - } - - -def _normalize_encoded_vector(encoded: Any) -> np.ndarray: - if encoded is None: - raise ValueError("embedding encode returned None") - - if isinstance(encoded, np.ndarray): - array = encoded - else: - array = np.asarray(encoded, dtype=np.float32) - - if array.ndim == 2: - if array.shape[0] != 1: - raise ValueError(f"embedding encode returned batched output: shape={tuple(array.shape)}") - array = array[0] - - if array.ndim != 1: - raise ValueError(f"embedding encode returned invalid ndim={array.ndim}") - if array.size <= 0: - raise ValueError("embedding encode returned empty vector") - if not np.all(np.isfinite(array)): - raise ValueError("embedding encode returned non-finite values") - return array.astype(np.float32, copy=False) - - -async def run_embedding_runtime_self_check( - *, - config: Any, - vector_store: Optional[Any], - embedding_manager: Optional[Any], - sample_text: str = _DEFAULT_SAMPLE_TEXT, -) -> Dict[str, Any]: - """Probe the real embedding path and compare dimensions with runtime storage.""" - configured_dimension = _safe_int(_get_config_value(config, "embedding.dimension", 0), 0) - vector_store_dimension = _safe_int(getattr(vector_store, "dimension", 0), 0) - - if vector_store is None or embedding_manager is None: - return _build_report( - ok=False, - code="runtime_components_missing", - message="vector_store 或 embedding_manager 未初始化", - configured_dimension=configured_dimension, - vector_store_dimension=vector_store_dimension, - detected_dimension=0, - encoded_dimension=0, - elapsed_ms=0.0, - sample_text=sample_text, - ) - - start = time.perf_counter() - detected_dimension = 0 - encoded_dimension = 0 - try: - detected_dimension = _safe_int(await 
embedding_manager._detect_dimension(), 0) - encoded = await embedding_manager.encode(sample_text) - encoded_array = _normalize_encoded_vector(encoded) - encoded_dimension = int(encoded_array.shape[0]) - except Exception as exc: - elapsed_ms = (time.perf_counter() - start) * 1000.0 - logger.warning(f"embedding runtime self-check failed: {exc}") - return _build_report( - ok=False, - code="embedding_probe_failed", - message=f"embedding probe failed: {exc}", - configured_dimension=configured_dimension, - vector_store_dimension=vector_store_dimension, - detected_dimension=detected_dimension, - encoded_dimension=encoded_dimension, - elapsed_ms=elapsed_ms, - sample_text=sample_text, - ) - - elapsed_ms = (time.perf_counter() - start) * 1000.0 - expected_dimension = vector_store_dimension or configured_dimension or detected_dimension - if expected_dimension <= 0: - return _build_report( - ok=False, - code="invalid_expected_dimension", - message="无法确定期望 embedding 维度", - configured_dimension=configured_dimension, - vector_store_dimension=vector_store_dimension, - detected_dimension=detected_dimension, - encoded_dimension=encoded_dimension, - elapsed_ms=elapsed_ms, - sample_text=sample_text, - ) - - if encoded_dimension != expected_dimension: - msg = ( - "embedding 真实输出维度与当前向量存储不一致: " - f"expected={expected_dimension}, encoded={encoded_dimension}" - ) - logger.error(msg) - return _build_report( - ok=False, - code="embedding_dimension_mismatch", - message=msg, - configured_dimension=configured_dimension, - vector_store_dimension=vector_store_dimension, - detected_dimension=detected_dimension, - encoded_dimension=encoded_dimension, - elapsed_ms=elapsed_ms, - sample_text=sample_text, - ) - - return _build_report( - ok=True, - code="ok", - message="embedding runtime self-check passed", - configured_dimension=configured_dimension, - vector_store_dimension=vector_store_dimension, - detected_dimension=detected_dimension, - encoded_dimension=encoded_dimension, - elapsed_ms=elapsed_ms, - sample_text=sample_text, - ) - - -async def ensure_runtime_self_check( - plugin_or_config: Any, - *, - force: bool = False, - sample_text: str = _DEFAULT_SAMPLE_TEXT, -) -> Dict[str, Any]: - """Run or reuse cached runtime self-check report.""" - if plugin_or_config is None: - return _build_report( - ok=False, - code="missing_plugin_or_config", - message="plugin/config unavailable", - configured_dimension=0, - vector_store_dimension=0, - detected_dimension=0, - encoded_dimension=0, - elapsed_ms=0.0, - sample_text=sample_text, - ) - - cache = getattr(plugin_or_config, "_runtime_self_check_report", None) - if isinstance(cache, dict) and cache and not force: - return cache - - report = await run_embedding_runtime_self_check( - config=getattr(plugin_or_config, "config", plugin_or_config), - vector_store=getattr(plugin_or_config, "vector_store", None) - if not isinstance(plugin_or_config, dict) - else plugin_or_config.get("vector_store"), - embedding_manager=getattr(plugin_or_config, "embedding_manager", None) - if not isinstance(plugin_or_config, dict) - else plugin_or_config.get("embedding_manager"), - sample_text=sample_text, - ) - try: - setattr(plugin_or_config, "_runtime_self_check_report", report) - except Exception: - pass - return report diff --git a/plugins/A_memorix/core/utils/search_execution_service.py b/plugins/A_memorix/core/utils/search_execution_service.py deleted file mode 100644 index 7df243af..00000000 --- a/plugins/A_memorix/core/utils/search_execution_service.py +++ /dev/null @@ -1,439 +0,0 @@ -""" -统一检索执行服务。 - 
-用于收敛 Action/Tool 在 search/time 上的核心执行流程,避免重复实现。 -""" - -from __future__ import annotations - -import hashlib -import json -import time -from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Tuple - -from src.common.logger import get_logger - -from ..retrieval import TemporalQueryOptions -from .search_postprocess import ( - apply_safe_content_dedup, - maybe_apply_smart_path_fallback, -) -from .time_parser import parse_query_time_range - -logger = get_logger("A_Memorix.SearchExecutionService") - - -def _get_config_value(config: Optional[dict], key: str, default: Any = None) -> Any: - if not isinstance(config, dict): - return default - current: Any = config - for part in key.split("."): - if isinstance(current, dict) and part in current: - current = current[part] - else: - return default - return current - - -def _sanitize_text(value: Any) -> str: - if value is None: - return "" - return str(value).strip() - - -@dataclass -class SearchExecutionRequest: - caller: str - stream_id: Optional[str] = None - group_id: Optional[str] = None - user_id: Optional[str] = None - query_type: str = "search" # search|time|hybrid - query: str = "" - top_k: Optional[int] = None - time_from: Optional[str] = None - time_to: Optional[str] = None - person: Optional[str] = None - source: Optional[str] = None - use_threshold: bool = True - enable_ppr: bool = True - - -@dataclass -class SearchExecutionResult: - success: bool - error: str = "" - query_type: str = "search" - query: str = "" - top_k: int = 10 - time_from: Optional[str] = None - time_to: Optional[str] = None - person: Optional[str] = None - source: Optional[str] = None - temporal: Optional[TemporalQueryOptions] = None - results: List[Any] = field(default_factory=list) - elapsed_ms: float = 0.0 - chat_filtered: bool = False - dedup_hit: bool = False - - @property - def count(self) -> int: - return len(self.results) - - -class SearchExecutionService: - """统一检索执行服务。""" - - @staticmethod - def _resolve_plugin_instance(plugin_config: Optional[dict]) -> Optional[Any]: - if isinstance(plugin_config, dict): - plugin_instance = plugin_config.get("plugin_instance") - if plugin_instance is not None: - return plugin_instance - - try: - from ...plugin import AMemorixPlugin - - return getattr(AMemorixPlugin, "get_global_instance", lambda: None)() - except Exception: - return None - - @staticmethod - def _normalize_query_type(raw_query_type: str) -> str: - return _sanitize_text(raw_query_type).lower() or "search" - - @staticmethod - def _resolve_runtime_component( - plugin_config: Optional[dict], - plugin_instance: Optional[Any], - key: str, - ) -> Optional[Any]: - if isinstance(plugin_config, dict): - value = plugin_config.get(key) - if value is not None: - return value - if plugin_instance is not None: - value = getattr(plugin_instance, key, None) - if value is not None: - return value - return None - - @staticmethod - def _resolve_top_k( - plugin_config: Optional[dict], - query_type: str, - top_k_raw: Optional[Any], - ) -> Tuple[bool, int, str]: - temporal_default_top_k = int( - _get_config_value(plugin_config, "retrieval.temporal.default_top_k", 10) - ) - default_top_k = temporal_default_top_k if query_type in {"time", "hybrid"} else 10 - if top_k_raw is None: - return True, max(1, min(50, default_top_k)), "" - try: - top_k = int(top_k_raw) - except (TypeError, ValueError): - return False, 0, "top_k 参数必须为整数" - return True, max(1, min(50, top_k)), "" - - @staticmethod - def _build_temporal( - plugin_config: Optional[dict], - 
query_type: str, - time_from_raw: Optional[str], - time_to_raw: Optional[str], - person: Optional[str], - source: Optional[str], - ) -> Tuple[bool, Optional[TemporalQueryOptions], str]: - if query_type not in {"time", "hybrid"}: - return True, None, "" - - temporal_enabled = bool(_get_config_value(plugin_config, "retrieval.temporal.enabled", True)) - if not temporal_enabled: - return False, None, "时序检索已禁用(retrieval.temporal.enabled=false)" - - if not time_from_raw and not time_to_raw: - return False, None, "time/hybrid 模式至少需要 time_from 或 time_to" - - try: - ts_from, ts_to = parse_query_time_range( - str(time_from_raw) if time_from_raw is not None else None, - str(time_to_raw) if time_to_raw is not None else None, - ) - except ValueError as e: - return False, None, f"时间参数错误: {e}" - - temporal = TemporalQueryOptions( - time_from=ts_from, - time_to=ts_to, - person=_sanitize_text(person) or None, - source=_sanitize_text(source) or None, - allow_created_fallback=bool( - _get_config_value(plugin_config, "retrieval.temporal.allow_created_fallback", True) - ), - candidate_multiplier=int( - _get_config_value(plugin_config, "retrieval.temporal.candidate_multiplier", 8) - ), - max_scan=int(_get_config_value(plugin_config, "retrieval.temporal.max_scan", 1000)), - ) - return True, temporal, "" - - @staticmethod - def _build_request_key( - request: SearchExecutionRequest, - query_type: str, - top_k: int, - temporal: Optional[TemporalQueryOptions], - ) -> str: - payload = { - "stream_id": _sanitize_text(request.stream_id), - "query_type": query_type, - "query": _sanitize_text(request.query), - "time_from": _sanitize_text(request.time_from), - "time_to": _sanitize_text(request.time_to), - "time_from_ts": temporal.time_from if temporal else None, - "time_to_ts": temporal.time_to if temporal else None, - "person": _sanitize_text(request.person), - "source": _sanitize_text(request.source), - "top_k": int(top_k), - "use_threshold": bool(request.use_threshold), - "enable_ppr": bool(request.enable_ppr), - } - payload_json = json.dumps(payload, ensure_ascii=False, sort_keys=True) - return hashlib.sha1(payload_json.encode("utf-8")).hexdigest() - - @staticmethod - async def execute( - *, - retriever: Any, - threshold_filter: Optional[Any], - plugin_config: Optional[dict], - request: SearchExecutionRequest, - enforce_chat_filter: bool = True, - reinforce_access: bool = True, - ) -> SearchExecutionResult: - if retriever is None: - return SearchExecutionResult(success=False, error="知识检索器未初始化") - - query_type = SearchExecutionService._normalize_query_type(request.query_type) - query = _sanitize_text(request.query) - if query_type not in {"search", "time", "hybrid"}: - return SearchExecutionResult( - success=False, - error=f"query_type 无效: {query_type}(仅支持 search/time/hybrid)", - ) - - if query_type in {"search", "hybrid"} and not query: - return SearchExecutionResult( - success=False, - error="search/hybrid 模式必须提供 query", - ) - - top_k_ok, top_k, top_k_error = SearchExecutionService._resolve_top_k( - plugin_config, query_type, request.top_k - ) - if not top_k_ok: - return SearchExecutionResult(success=False, error=top_k_error) - - temporal_ok, temporal, temporal_error = SearchExecutionService._build_temporal( - plugin_config=plugin_config, - query_type=query_type, - time_from_raw=request.time_from, - time_to_raw=request.time_to, - person=request.person, - source=request.source, - ) - if not temporal_ok: - return SearchExecutionResult(success=False, error=temporal_error) - - plugin_instance = 
SearchExecutionService._resolve_plugin_instance(plugin_config) - if ( - enforce_chat_filter - and plugin_instance is not None - and hasattr(plugin_instance, "is_chat_enabled") - ): - if not plugin_instance.is_chat_enabled( - stream_id=request.stream_id, - group_id=request.group_id, - user_id=request.user_id, - ): - logger.info( - "检索请求被聊天过滤拦截: " - f"caller={request.caller}, " - f"stream_id={request.stream_id}" - ) - return SearchExecutionResult( - success=True, - query_type=query_type, - query=query, - top_k=top_k, - time_from=request.time_from, - time_to=request.time_to, - person=request.person, - source=request.source, - temporal=temporal, - results=[], - elapsed_ms=0.0, - chat_filtered=True, - dedup_hit=False, - ) - - request_key = SearchExecutionService._build_request_key( - request=request, - query_type=query_type, - top_k=top_k, - temporal=temporal, - ) - - async def _executor() -> Dict[str, Any]: - original_ppr = bool(getattr(retriever.config, "enable_ppr", True)) - setattr(retriever.config, "enable_ppr", bool(request.enable_ppr)) - started_at = time.time() - try: - retrieved = await retriever.retrieve( - query=query, - top_k=top_k, - temporal=temporal, - ) - - should_apply_threshold = bool(request.use_threshold) and threshold_filter is not None - if ( - query_type == "time" - and not query - and bool( - _get_config_value( - plugin_config, - "retrieval.time.skip_threshold_when_query_empty", - True, - ) - ) - ): - should_apply_threshold = False - - if should_apply_threshold: - retrieved = threshold_filter.filter(retrieved) - - if ( - reinforce_access - and plugin_instance is not None - and hasattr(plugin_instance, "reinforce_access") - ): - relation_hashes = [ - item.hash_value - for item in retrieved - if getattr(item, "result_type", "") == "relation" - ] - if relation_hashes: - await plugin_instance.reinforce_access(relation_hashes) - - if query_type == "search": - graph_store = SearchExecutionService._resolve_runtime_component( - plugin_config, plugin_instance, "graph_store" - ) - metadata_store = SearchExecutionService._resolve_runtime_component( - plugin_config, plugin_instance, "metadata_store" - ) - fallback_enabled = bool( - _get_config_value( - plugin_config, - "retrieval.search.smart_fallback.enabled", - True, - ) - ) - fallback_threshold = float( - _get_config_value( - plugin_config, - "retrieval.search.smart_fallback.threshold", - 0.6, - ) - ) - retrieved, fallback_triggered, fallback_added = maybe_apply_smart_path_fallback( - query=query, - results=list(retrieved), - graph_store=graph_store, - metadata_store=metadata_store, - enabled=fallback_enabled, - threshold=fallback_threshold, - ) - if fallback_triggered: - logger.info( - "metric.smart_fallback_triggered_count=1 " - f"caller={request.caller} " - f"added={fallback_added}" - ) - - dedup_enabled = bool( - _get_config_value( - plugin_config, - "retrieval.search.safe_content_dedup.enabled", - True, - ) - ) - if dedup_enabled: - retrieved, removed_count = apply_safe_content_dedup(list(retrieved)) - if removed_count > 0: - logger.info( - f"metric.safe_dedup_removed_count={removed_count} " - f"caller={request.caller}" - ) - - elapsed_ms = (time.time() - started_at) * 1000.0 - return {"results": retrieved, "elapsed_ms": elapsed_ms} - finally: - setattr(retriever.config, "enable_ppr", original_ppr) - - dedup_hit = False - try: - # Tuning evaluation must actually execute every round and should avoid extra dedup lock contention. - bypass_request_dedup = str(request.caller or "").strip().lower() == "retrieval_tuning"
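- # The request key above (a sha1 over the normalized request payload) lets the plugin - # collapse identical concurrent requests onto one execution; tuning traffic opts out - # so every evaluation round actually runs. - if ( - not bypass_request_dedup - and plugin_instance is not None - and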
hasattr(plugin_instance, "execute_request_with_dedup") - ): - dedup_hit, payload = await plugin_instance.execute_request_with_dedup( - request_key, - _executor, - ) - else: - payload = await _executor() - except Exception as e: - return SearchExecutionResult(success=False, error=f"知识检索失败: {e}") - - if dedup_hit: - logger.info(f"metric.search_execution_dedup_hit_count=1 caller={request.caller}") - - return SearchExecutionResult( - success=True, - query_type=query_type, - query=query, - top_k=top_k, - time_from=request.time_from, - time_to=request.time_to, - person=request.person, - source=request.source, - temporal=temporal, - results=payload.get("results", []), - elapsed_ms=float(payload.get("elapsed_ms", 0.0)), - chat_filtered=False, - dedup_hit=bool(dedup_hit), - ) - - @staticmethod - def to_serializable_results(results: List[Any]) -> List[Dict[str, Any]]: - serialized: List[Dict[str, Any]] = [] - for item in results: - metadata = dict(getattr(item, "metadata", {}) or {}) - if "time_meta" not in metadata: - metadata["time_meta"] = {} - serialized.append( - { - "hash": getattr(item, "hash_value", ""), - "type": getattr(item, "result_type", ""), - "score": float(getattr(item, "score", 0.0)), - "content": getattr(item, "content", ""), - "metadata": metadata, - } - ) - return serialized diff --git a/plugins/A_memorix/core/utils/search_postprocess.py b/plugins/A_memorix/core/utils/search_postprocess.py deleted file mode 100644 index 52688e08..00000000 --- a/plugins/A_memorix/core/utils/search_postprocess.py +++ /dev/null @@ -1,90 +0,0 @@ -"""Post-processing helpers for unified search execution.""" - -from __future__ import annotations - -from typing import Any, List, Tuple - -from .path_fallback_service import find_paths_from_query, to_retrieval_results - - -def apply_safe_content_dedup(results: List[Any]) -> Tuple[List[Any], int]: - """Deduplicate results by hash/content while preserving at least one entry.""" - if not results: - return [], 0 - - unique_results: List[Any] = [] - seen_hashes = set() - seen_contents = set() - - for item in results: - content = str(getattr(item, "content", "") or "").strip() - if not content: - continue - - hash_value = str(getattr(item, "hash_value", "") or "").strip() or str(hash(content)) - if hash_value in seen_hashes: - continue - - is_dup = False - for seen in seen_contents: - if content in seen or seen in content: - is_dup = True - break - if is_dup: - continue - - seen_hashes.add(hash_value) - seen_contents.add(content) - unique_results.append(item) - - if not unique_results: - unique_results.append(results[0]) - - removed_count = max(0, len(results) - len(unique_results)) - return unique_results, removed_count - - -def maybe_apply_smart_path_fallback( - *, - query: str, - results: List[Any], - graph_store: Any, - metadata_store: Any, - enabled: bool, - threshold: float, - max_depth: int = 3, - max_paths: int = 5, -) -> Tuple[List[Any], bool, int]: - """Append indirect relation paths when semantic results are weak.""" - if not enabled or not str(query or "").strip(): - return results, False, 0 - if graph_store is None or metadata_store is None: - return results, False, 0 - - max_score = 0.0 - if results: - try: - max_score = float(getattr(results[0], "score", 0.0) or 0.0) - except Exception: - max_score = 0.0 - - if max_score >= float(threshold): - return results, False, 0 - - paths = find_paths_from_query( - query=query, - graph_store=graph_store, - metadata_store=metadata_store, - max_depth=max_depth, - max_paths=max_paths, - ) - if not paths: - 
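- # No indirect paths were recovered; return the weak semantic results unchanged.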
return results, False, 0 - - path_results = to_retrieval_results(paths) - if not path_results: - return results, False, 0 - - merged = list(path_results) + list(results) - return merged, True, len(path_results) - diff --git a/plugins/A_memorix/core/utils/summary_importer.py b/plugins/A_memorix/core/utils/summary_importer.py deleted file mode 100644 index b6271db4..00000000 --- a/plugins/A_memorix/core/utils/summary_importer.py +++ /dev/null @@ -1,425 +0,0 @@ -""" -聊天总结与知识导入工具 - -该模块负责从聊天记录中提取信息,生成总结,并将总结内容及提取的实体/关系 -导入到 A_memorix 的存储组件中。 -""" - -import time -import json -import re -import traceback -from typing import List, Dict, Any, Tuple, Optional -from pathlib import Path - -from src.common.logger import get_logger -from src.services import llm_service as llm_api -from src.services import message_service as message_api -from src.config.config import global_config, model_config as host_model_config -from src.config.model_configs import TaskConfig - -from ..storage import ( - KnowledgeType, - VectorStore, - GraphStore, - MetadataStore, - resolve_stored_knowledge_type, -) -from ..embedding import EmbeddingAPIAdapter -from .relation_write_service import RelationWriteService -from .runtime_self_check import ensure_runtime_self_check, run_embedding_runtime_self_check - -logger = get_logger("A_Memorix.SummaryImporter") - -# 默认总结提示词模版 -SUMMARY_PROMPT_TEMPLATE = """ -你是 {bot_name}。{personality_context} -现在你需要对以下一段聊天记录进行总结,并提取其中的重要知识。 - -聊天记录内容: -{chat_history} - -请完成以下任务: -1. **生成总结**:以第三人称或机器人的视角,简洁明了地总结这段对话的主要内容、发生的事件或讨论的主题。 -2. **提取实体与关系**:识别并提取对话中提到的重要实体以及它们之间的关系。 - -请严格以 JSON 格式输出,格式如下: -{{ - "summary": "总结文本内容", - "entities": ["张三", "李四"], - "relations": [ - {{"subject": "张三", "predicate": "认识", "object": "李四"}} - ] -}} - -注意:总结应具有叙事性,能够作为长程记忆的一部分。直接使用实体的实际名称,不要使用 e1/e2 等代号。 -""" - -class SummaryImporter: - """总结并导入知识的工具类""" - - def __init__( - self, - vector_store: VectorStore, - graph_store: GraphStore, - metadata_store: MetadataStore, - embedding_manager: EmbeddingAPIAdapter, - plugin_config: dict - ): - self.vector_store = vector_store - self.graph_store = graph_store - self.metadata_store = metadata_store - self.embedding_manager = embedding_manager - self.plugin_config = plugin_config - self.relation_write_service: Optional[RelationWriteService] = ( - plugin_config.get("relation_write_service") - if isinstance(plugin_config, dict) - else None - ) - - def _normalize_summary_model_selectors(self, raw_value: Any) -> List[str]: - """标准化 summarization.model_name 配置(vNext 仅接受字符串数组)。""" - if raw_value is None: - return ["auto"] - if isinstance(raw_value, list): - selectors = [str(x).strip() for x in raw_value if str(x).strip()] - return selectors or ["auto"] - raise ValueError( - "summarization.model_name 在 vNext 必须为 List[str]。" - " 请执行 scripts/release_vnext_migrate.py migrate。" - ) - - def _pick_default_summary_task(self, available_tasks: Dict[str, TaskConfig]) -> Tuple[Optional[str], Optional[TaskConfig]]: - """ - 选择总结默认任务,避免错误落到 embedding 任务。 - 优先级:replyer > utils > planner > tool_use > 其他非 embedding。 - """ - preferred = ("replyer", "utils", "planner", "tool_use") - for name in preferred: - cfg = available_tasks.get(name) - if cfg and cfg.model_list: - return name, cfg - - for name, cfg in available_tasks.items(): - if name != "embedding" and cfg.model_list: - return name, cfg - - for name, cfg in available_tasks.items(): - if cfg.model_list: - return name, cfg - - return None, None - - def _resolve_summary_model_config(self) -> Optional[TaskConfig]: - """ - 解析 summarization.model_name 
为 TaskConfig。 - 支持: - - "auto" - - "replyer"(任务名) - - "some-model-name"(具体模型名) - - ["utils:model1", "utils:model2", "replyer"](数组混合语法) - """ - available_tasks = llm_api.get_available_models() - if not available_tasks: - return None - - raw_cfg = self.plugin_config.get("summarization", {}).get("model_name", "auto") - selectors = self._normalize_summary_model_selectors(raw_cfg) - default_task_name, default_task_cfg = self._pick_default_summary_task(available_tasks) - - selected_models: List[str] = [] - base_cfg: Optional[TaskConfig] = None - model_dict = getattr(host_model_config, "models_dict", {}) - - def _append_models(models: List[str]): - for model_name in models: - if model_name and model_name not in selected_models: - selected_models.append(model_name) - - for raw_selector in selectors: - selector = raw_selector.strip() - if not selector: - continue - - if selector.lower() == "auto": - if default_task_cfg: - _append_models(default_task_cfg.model_list) - if base_cfg is None: - base_cfg = default_task_cfg - continue - - if ":" in selector: - task_name, model_name = selector.split(":", 1) - task_name = task_name.strip() - model_name = model_name.strip() - task_cfg = available_tasks.get(task_name) - if not task_cfg: - logger.warning(f"总结模型选择器 '{selector}' 的任务 '{task_name}' 不存在,已跳过") - continue - - if base_cfg is None: - base_cfg = task_cfg - - if not model_name or model_name.lower() == "auto": - _append_models(task_cfg.model_list) - continue - - if model_name in model_dict or model_name in task_cfg.model_list: - _append_models([model_name]) - else: - logger.warning(f"总结模型选择器 '{selector}' 的模型 '{model_name}' 不存在,已跳过") - continue - - task_cfg = available_tasks.get(selector) - if task_cfg: - _append_models(task_cfg.model_list) - if base_cfg is None: - base_cfg = task_cfg - continue - - if selector in model_dict: - _append_models([selector]) - continue - - logger.warning(f"总结模型选择器 '{selector}' 无法识别,已跳过") - - if not selected_models: - if default_task_cfg: - _append_models(default_task_cfg.model_list) - if base_cfg is None: - base_cfg = default_task_cfg - else: - first_cfg = next(iter(available_tasks.values())) - _append_models(first_cfg.model_list) - if base_cfg is None: - base_cfg = first_cfg - - if not selected_models: - return None - - template_cfg = base_cfg or default_task_cfg or next(iter(available_tasks.values())) - return TaskConfig( - model_list=selected_models, - max_tokens=template_cfg.max_tokens, - temperature=template_cfg.temperature, - slow_threshold=template_cfg.slow_threshold, - selection_strategy=template_cfg.selection_strategy, - ) - - async def import_from_stream( - self, - stream_id: str, - context_length: Optional[int] = None, - include_personality: Optional[bool] = None - ) -> Tuple[bool, str]: - """ - 从指定的聊天流中提取记录并执行总结导入 - - Args: - stream_id: 聊天流 ID - context_length: 总结的历史消息条数 - include_personality: 是否包含人设 - - Returns: - Tuple[bool, str]: (是否成功, 结果消息) - """ - try: - self_check_ok, self_check_msg = await self._ensure_runtime_self_check() - if not self_check_ok: - return False, f"导入前自检失败: {self_check_msg}" - - # 1. 获取配置 - if context_length is None: - context_length = self.plugin_config.get("summarization", {}).get("context_length", 50) - - if include_personality is None: - include_personality = self.plugin_config.get("summarization", {}).get("include_personality", True) - - # 2. 
Fetch chat history - # Pull messages strictly before the current time - now = time.time() - messages = message_api.get_messages_before_time_in_chat( - chat_id=stream_id, - timestamp=now, - limit=context_length - ) - - if not messages: - return False, "未找到有效的聊天记录进行总结" - - # Convert to readable text - chat_history_text = message_api.build_readable_messages(messages) - - # 3. Prepare the prompt - bot_name = global_config.bot.nickname or "机器人" - personality_context = "" - if include_personality: - personality = getattr(global_config.bot, "personality", "") - if personality: - personality_context = f"你的性格设定是:{personality}" - - # 4. Call the LLM - prompt = SUMMARY_PROMPT_TEMPLATE.format( - bot_name=bot_name, - personality_context=personality_context, - chat_history=chat_history_text - ) - - model_config_to_use = self._resolve_summary_model_config() - if model_config_to_use is None: - return False, "未找到可用的总结模型配置" - - logger.info(f"正在为流 {stream_id} 执行总结,消息条数: {len(messages)}") - logger.info(f"总结模型候选列表: {model_config_to_use.model_list}") - - success, response, _, _ = await llm_api.generate_with_model( - prompt=prompt, - model_config=model_config_to_use, - request_type="A_Memorix.ChatSummarization" - ) - - if not success or not response: - return False, "LLM 生成总结失败" - - # 5. Parse the result - data = self._parse_llm_response(response) - if not data or "summary" not in data: - return False, "解析 LLM 响应失败或总结为空" - - summary_text = data["summary"] - entities = data.get("entities", []) - relations = data.get("relations", []) - msg_times = [ - float(getattr(getattr(msg, "timestamp", None), "timestamp", lambda: 0.0)()) - for msg in messages - if getattr(msg, "timestamp", None) is not None - ] - time_meta = {} - if msg_times: - time_meta = { - "event_time_start": min(msg_times), - "event_time_end": max(msg_times), - "time_granularity": "minute", - "time_confidence": 0.95, - } - - # 6. Run the import - await self._execute_import(summary_text, entities, relations, stream_id, time_meta=time_meta) - - # 7. Persist to disk
- self.vector_store.save() - self.graph_store.save() - - result_msg = ( - f"✅ 总结导入成功\n" - f"📝 总结长度: {len(summary_text)}\n" - f"📌 提取实体: {len(entities)}\n" - f"🔗 提取关系: {len(relations)}" - ) - return True, result_msg - - except Exception as e: - logger.error(f"总结导入过程中出错: {e}\n{traceback.format_exc()}") - return False, f"错误: {str(e)}" - - async def _ensure_runtime_self_check(self) -> Tuple[bool, str]: - plugin_instance = self.plugin_config.get("plugin_instance") if isinstance(self.plugin_config, dict) else None - if plugin_instance is not None: - report = await ensure_runtime_self_check(plugin_instance) - else: - report = await run_embedding_runtime_self_check( - config=self.plugin_config, - vector_store=self.vector_store, - embedding_manager=self.embedding_manager, - ) - if bool(report.get("ok", False)): - return True, "" - return ( - False, - f"{report.get('message', 'unknown')} " - f"(configured={report.get('configured_dimension', 0)}, " - f"store={report.get('vector_store_dimension', 0)}, " - f"encoded={report.get('encoded_dimension', 0)})", - ) - - def _parse_llm_response(self, response: str) -> Dict[str, Any]: - """Parse the JSON returned by the LLM.""" - try: - # Try to locate a JSON object in the raw response - json_match = re.search(r"\{.*\}", response, re.DOTALL) - if json_match: - return json.loads(json_match.group()) - return {} - except Exception as e: - logger.warning(f"解析总结 JSON 失败: {e}") - return {} - - async def _execute_import( - self, - summary: str, - entities: List[str], - relations: List[Dict[str, str]], - stream_id: str, - time_meta: Optional[Dict[str, Any]] = None, - ): - """Write the extracted data into the stores.""" - # Resolve the default knowledge type - type_str = self.plugin_config.get("summarization", {}).get("default_knowledge_type", "narrative") - try: - knowledge_type = resolve_stored_knowledge_type(type_str, content=summary) - except ValueError: - logger.warning(f"非法 summarization.default_knowledge_type={type_str},回退 narrative") - knowledge_type = KnowledgeType.NARRATIVE - - # Import the summary paragraph - hash_value = self.metadata_store.add_paragraph( - content=summary, - source=f"chat_summary:{stream_id}", - knowledge_type=knowledge_type.value, - time_meta=time_meta, - ) - - embedding = await self.embedding_manager.encode(summary) - self.vector_store.add( - vectors=embedding.reshape(1, -1), - ids=[hash_value] - ) - - # Import entities - if entities: - self.graph_store.add_nodes(entities) - - # Import relations - rv_cfg = self.plugin_config.get("retrieval", {}).get("relation_vectorization", {}) - if not isinstance(rv_cfg, dict): - rv_cfg = {} - write_vector = bool(rv_cfg.get("enabled", False)) and bool(rv_cfg.get("write_on_import", True)) - for rel in relations: - s, p, o = rel.get("subject"), rel.get("predicate"), rel.get("object") - if all([s, p, o]): - if self.relation_write_service is not None: - await self.relation_write_service.upsert_relation_with_vector( - subject=s, - predicate=p, - obj=o, - confidence=1.0, - source_paragraph=summary, - write_vector=write_vector, - ) - else: - # Write metadata - rel_hash = self.metadata_store.add_relation( - subject=s, - predicate=p, - obj=o, - confidence=1.0, - source_paragraph=summary - ) - # Write to the graph store (recording relation_hashes so edges can later be pruned per relation) - self.graph_store.add_edges([(s, o)], relation_hashes=[rel_hash]) - try: - self.metadata_store.set_relation_vector_state(rel_hash, "none") - except Exception: - pass - - logger.info(f"总结导入完成: hash={hash_value[:8]}")
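- - # Minimal usage sketch (illustrative only; it assumes the five runtime components are - # already constructed elsewhere and that this coroutine runs on the plugin's event loop): - # - #     importer = SummaryImporter( - #         vector_store=vector_store, - #         graph_store=graph_store, - #         metadata_store=metadata_store, - #         embedding_manager=embedding_manager, - #         plugin_config=plugin_config, - #     ) - #     ok, msg = await importer.import_from_stream(stream_id, context_length=50) diff --git a/plugins/A_memorix/core/utils/time_parser.py b/plugins/A_memorix/core/utils/time_parser.py deleted file mode 100644 index 8e577974..00000000 --- a/plugins/A_memorix/core/utils/time_parser.py +++ /dev/null @@ -1,170 +0,0 @@ 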
-""" -时间解析工具。 - -约束: -1. 查询参数(Action/Command/Tool)仅接受结构化绝对时间: - - YYYY/MM/DD - - YYYY/MM/DD HH:mm -2. 入库时允许更宽松格式(含时间戳、YYYY-MM-DD 等)。 -""" - -from __future__ import annotations - -import re -from datetime import datetime -from typing import Any, Dict, Optional, Tuple - - -_QUERY_DATE_RE = re.compile(r"^\d{4}/\d{2}/\d{2}$") -_QUERY_MINUTE_RE = re.compile(r"^\d{4}/\d{2}/\d{2} \d{2}:\d{2}$") -_NUMERIC_RE = re.compile(r"^-?\d+(?:\.\d+)?$") - -_INGEST_FORMATS = [ - "%Y/%m/%d %H:%M:%S", - "%Y/%m/%d %H:%M", - "%Y-%m-%d %H:%M:%S", - "%Y-%m-%d %H:%M", - "%Y-%m-%dT%H:%M:%S", - "%Y-%m-%dT%H:%M", - "%Y/%m/%d", - "%Y-%m-%d", -] - -_INGEST_DATE_FORMATS = {"%Y/%m/%d", "%Y-%m-%d"} - - -def parse_query_datetime_to_timestamp(value: str, is_end: bool = False) -> float: - """解析查询时间,仅支持 YYYY/MM/DD 或 YYYY/MM/DD HH:mm。""" - text = str(value).strip() - if not text: - raise ValueError("时间不能为空") - - if _QUERY_DATE_RE.fullmatch(text): - dt = datetime.strptime(text, "%Y/%m/%d") - if is_end: - dt = dt.replace(hour=23, minute=59, second=0, microsecond=0) - return dt.timestamp() - - if _QUERY_MINUTE_RE.fullmatch(text): - dt = datetime.strptime(text, "%Y/%m/%d %H:%M") - return dt.timestamp() - - raise ValueError( - f"时间格式错误: {text}。仅支持 YYYY/MM/DD 或 YYYY/MM/DD HH:mm" - ) - - -def parse_query_time_range( - time_from: Optional[str], - time_to: Optional[str], -) -> Tuple[Optional[float], Optional[float]]: - """解析查询窗口并验证区间。""" - ts_from = ( - parse_query_datetime_to_timestamp(time_from, is_end=False) - if time_from - else None - ) - ts_to = ( - parse_query_datetime_to_timestamp(time_to, is_end=True) - if time_to - else None - ) - - if ts_from is not None and ts_to is not None and ts_from > ts_to: - raise ValueError("time_from 不能晚于 time_to") - - return ts_from, ts_to - - -def parse_ingest_datetime_to_timestamp( - value: Any, - is_end: bool = False, -) -> Optional[float]: - """解析入库时间,允许 timestamp/常见字符串格式。""" - if value is None: - return None - - if isinstance(value, (int, float)): - return float(value) - - text = str(value).strip() - if not text: - return None - - if _NUMERIC_RE.fullmatch(text): - return float(text) - - for fmt in _INGEST_FORMATS: - try: - dt = datetime.strptime(text, fmt) - if fmt in _INGEST_DATE_FORMATS and is_end: - dt = dt.replace(hour=23, minute=59, second=0, microsecond=0) - return dt.timestamp() - except ValueError: - continue - - raise ValueError(f"无法解析时间: {text}") - - -def normalize_time_meta(time_meta: Optional[Dict[str, Any]]) -> Dict[str, Any]: - """归一化 time_meta 到存储层字段。""" - if not time_meta: - return {} - - normalized: Dict[str, Any] = {} - - event_time = parse_ingest_datetime_to_timestamp(time_meta.get("event_time")) - event_start = parse_ingest_datetime_to_timestamp( - time_meta.get("event_time_start"), - is_end=False, - ) - event_end = parse_ingest_datetime_to_timestamp( - time_meta.get("event_time_end"), - is_end=True, - ) - - time_range = time_meta.get("time_range") - if ( - isinstance(time_range, (list, tuple)) - and len(time_range) == 2 - ): - if event_start is None: - event_start = parse_ingest_datetime_to_timestamp(time_range[0], is_end=False) - if event_end is None: - event_end = parse_ingest_datetime_to_timestamp(time_range[1], is_end=True) - - if event_start is not None and event_end is not None and event_start > event_end: - raise ValueError("event_time_start 不能晚于 event_time_end") - - if event_time is not None: - normalized["event_time"] = event_time - if event_start is not None: - normalized["event_time_start"] = event_start - if event_end is not None: - normalized["event_time_end"] = 
event_end - - granularity = time_meta.get("time_granularity") - if granularity: - normalized["time_granularity"] = str(granularity) - else: - raw_time_values = [ - time_meta.get("event_time"), - time_meta.get("event_time_start"), - time_meta.get("event_time_end"), - ] - has_minute = any(isinstance(v, str) and ":" in v for v in raw_time_values if v is not None) - normalized["time_granularity"] = "minute" if has_minute else "day" - - confidence = time_meta.get("time_confidence") - if confidence is not None: - normalized["time_confidence"] = float(confidence) - - return normalized - - -def format_timestamp(ts: Optional[float]) -> Optional[str]: - """将 timestamp 格式化为 YYYY/MM/DD HH:mm。""" - if ts is None: - return None - return datetime.fromtimestamp(ts).strftime("%Y/%m/%d %H:%M") - diff --git a/plugins/A_memorix/core/utils/web_import_manager.py b/plugins/A_memorix/core/utils/web_import_manager.py deleted file mode 100644 index b088be1f..00000000 --- a/plugins/A_memorix/core/utils/web_import_manager.py +++ /dev/null @@ -1,3522 +0,0 @@ -""" -Web Import Task Manager - -为 A_Memorix WebUI 提供导入任务队列、状态管理、并发调度与取消/重试能力。 -""" - -from __future__ import annotations - -import asyncio -import hashlib -import json -import os -import shutil -import sys -import time -import traceback -import uuid -from collections import deque -from dataclasses import dataclass, field -from datetime import datetime -from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Tuple - -from src.common.logger import get_logger -from src.services import llm_service as llm_api - -from ..storage import ( - parse_import_strategy, - resolve_stored_knowledge_type, - select_import_strategy, - KnowledgeType, - MetadataStore, -) -from ..storage.type_detection import looks_like_quote_text -from ..utils.import_payloads import normalize_paragraph_import_item -from ..utils.runtime_self_check import ensure_runtime_self_check -from ..utils.time_parser import normalize_time_meta -from ..storage.knowledge_types import ImportStrategy -from ..strategies.base import ProcessedChunk, KnowledgeType as StrategyKnowledgeType -from ..strategies.narrative import NarrativeStrategy -from ..strategies.factual import FactualStrategy -from ..strategies.quote import QuoteStrategy - -logger = get_logger("A_Memorix.WebImportManager") - - -TASK_STATUS = { - "queued", - "preparing", - "running", - "cancel_requested", - "cancelled", - "completed", - "completed_with_errors", - "failed", -} - -FILE_STATUS = { - "queued", - "preparing", - "splitting", - "extracting", - "writing", - "saving", - "completed", - "failed", - "cancelled", -} - -CHUNK_STATUS = { - "queued", - "extracting", - "writing", - "completed", - "failed", - "cancelled", -} - - -def _now() -> float: - return time.time() - - -def _coerce_int(value: Any, default: int) -> int: - try: - return int(value) - except Exception: - return default - - -def _coerce_bool(value: Any, default: bool) -> bool: - if value is None: - return default - if isinstance(value, bool): - return value - if isinstance(value, (int, float)): - return bool(value) - text = str(value).strip().lower() - if text in {"1", "true", "yes", "y", "on"}: - return True - if text in {"0", "false", "no", "n", "off", ""}: - return False - return default - - -def _clamp(value: int, min_value: int, max_value: int) -> int: - return max(min_value, min(max_value, value)) - - -def _coerce_list(value: Any) -> List[str]: - if value is None: - return [] - if isinstance(value, list): - raw_items = value - else: - text = str(value or 
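One subtlety in `normalize_time_meta` above is the granularity fallback: when the caller does not state how precise the event time is, the presence of a `:` in any raw time string is used as a heuristic for minute precision. A minimal standalone sketch of just that inference (the helper name `infer_granularity` is invented for the example):

```python
from typing import Any, Dict

# Sketch of the granularity inference in normalize_time_meta above: an explicit
# time_granularity wins; otherwise ":" in any raw time string implies "minute".
def infer_granularity(time_meta: Dict[str, Any]) -> str:
    explicit = time_meta.get("time_granularity")
    if explicit:
        return str(explicit)
    raw_values = [time_meta.get(k) for k in ("event_time", "event_time_start", "event_time_end")]
    has_minute = any(isinstance(v, str) and ":" in v for v in raw_values if v is not None)
    return "minute" if has_minute else "day"

print(infer_granularity({"event_time": "2026/03/18"}))        # day
print(infer_granularity({"event_time": "2026/03/18 09:30"}))  # minute
print(infer_granularity({"event_time": 1770000000.0, "time_granularity": "day"}))  # day
```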
"").replace("\r", "\n") - raw_items = [] - for seg in text.split("\n"): - raw_items.extend(seg.split(",")) - - out: List[str] = [] - seen = set() - for item in raw_items: - v = str(item or "").strip() - if not v: - continue - key = v.lower() - if key in seen: - continue - seen.add(key) - out.append(v) - return out - - -def _parse_optional_positive_int(value: Any, field_name: str) -> Optional[int]: - if value is None: - return None - text = str(value).strip() - if text == "": - return None - try: - parsed = int(text) - except Exception: - raise ValueError(f"{field_name} 必须为整数") - if parsed <= 0: - raise ValueError(f"{field_name} 必须 > 0") - return parsed - - -def _safe_filename(name: str) -> str: - base = os.path.basename(str(name or "").strip()) - if not base: - return f"unnamed_{uuid.uuid4().hex[:8]}.txt" - return base - - -def _storage_type_from_strategy(strategy_type: StrategyKnowledgeType) -> str: - if strategy_type == StrategyKnowledgeType.NARRATIVE: - return KnowledgeType.NARRATIVE.value - if strategy_type == StrategyKnowledgeType.FACTUAL: - return KnowledgeType.FACTUAL.value - if strategy_type == StrategyKnowledgeType.QUOTE: - return KnowledgeType.QUOTE.value - return KnowledgeType.MIXED.value - - -@dataclass -class ImportChunkRecord: - chunk_id: str - index: int - chunk_type: str - status: str = "queued" - step: str = "queued" - failed_at: str = "" - retryable: bool = False - error: str = "" - progress: float = 0.0 - content_preview: str = "" - updated_at: float = field(default_factory=_now) - - def to_dict(self) -> Dict[str, Any]: - return { - "chunk_id": self.chunk_id, - "index": self.index, - "chunk_type": self.chunk_type, - "status": self.status, - "step": self.step, - "failed_at": self.failed_at, - "retryable": self.retryable, - "error": self.error, - "progress": self.progress, - "content_preview": self.content_preview, - "updated_at": self.updated_at, - } - - -@dataclass -class ImportFileRecord: - file_id: str - name: str - source_kind: str - input_mode: str - status: str = "queued" - current_step: str = "queued" - detected_strategy_type: str = "unknown" - total_chunks: int = 0 - done_chunks: int = 0 - failed_chunks: int = 0 - cancelled_chunks: int = 0 - progress: float = 0.0 - error: str = "" - chunks: List[ImportChunkRecord] = field(default_factory=list) - created_at: float = field(default_factory=_now) - updated_at: float = field(default_factory=_now) - temp_path: Optional[str] = None - source_path: Optional[str] = None - inline_content: Optional[str] = None - content_hash: str = "" - retry_chunk_indexes: List[int] = field(default_factory=list) - retry_mode: str = "" - - def to_dict(self, include_chunks: bool = False) -> Dict[str, Any]: - payload = { - "file_id": self.file_id, - "name": self.name, - "source_kind": self.source_kind, - "input_mode": self.input_mode, - "status": self.status, - "current_step": self.current_step, - "detected_strategy_type": self.detected_strategy_type, - "total_chunks": self.total_chunks, - "done_chunks": self.done_chunks, - "failed_chunks": self.failed_chunks, - "cancelled_chunks": self.cancelled_chunks, - "progress": self.progress, - "error": self.error, - "created_at": self.created_at, - "updated_at": self.updated_at, - "source_path": self.source_path or "", - "content_hash": self.content_hash or "", - "retry_chunk_indexes": list(self.retry_chunk_indexes or []), - "retry_mode": self.retry_mode or "", - } - if include_chunks: - payload["chunks"] = [chunk.to_dict() for chunk in self.chunks] - return payload - - -@dataclass -class 
ImportTaskRecord: - task_id: str - source: str - params: Dict[str, Any] - status: str = "queued" - current_step: str = "queued" - total_chunks: int = 0 - done_chunks: int = 0 - failed_chunks: int = 0 - cancelled_chunks: int = 0 - progress: float = 0.0 - error: str = "" - files: List[ImportFileRecord] = field(default_factory=list) - created_at: float = field(default_factory=_now) - started_at: Optional[float] = None - finished_at: Optional[float] = None - updated_at: float = field(default_factory=_now) - schema_detected: str = "" - artifact_paths: Dict[str, str] = field(default_factory=dict) - rollback_info: Dict[str, Any] = field(default_factory=dict) - retry_parent_task_id: str = "" - retry_summary: Dict[str, Any] = field(default_factory=dict) - - def to_summary(self) -> Dict[str, Any]: - return { - "task_id": self.task_id, - "source": self.source, - "status": self.status, - "current_step": self.current_step, - "total_chunks": self.total_chunks, - "done_chunks": self.done_chunks, - "failed_chunks": self.failed_chunks, - "cancelled_chunks": self.cancelled_chunks, - "progress": self.progress, - "error": self.error, - "file_count": len(self.files), - "created_at": self.created_at, - "started_at": self.started_at, - "finished_at": self.finished_at, - "updated_at": self.updated_at, - "task_kind": str(self.params.get("task_kind") or self.source), - "schema_detected": self.schema_detected, - "artifact_paths": dict(self.artifact_paths), - "rollback_info": dict(self.rollback_info), - "retry_parent_task_id": self.retry_parent_task_id or "", - "retry_summary": dict(self.retry_summary), - } - - def to_detail(self, include_chunks: bool = False) -> Dict[str, Any]: - payload = self.to_summary() - payload["params"] = self.params - payload["files"] = [f.to_dict(include_chunks=include_chunks) for f in self.files] - return payload - - -class ImportTaskManager: - def __init__(self, plugin: Any): - self.plugin = plugin - self._lock = asyncio.Lock() - self._storage_lock = asyncio.Lock() - - self._tasks: Dict[str, ImportTaskRecord] = {} - self._task_order: deque[str] = deque() - self._queue: deque[str] = deque() - self._active_task_id: Optional[str] = None - - self._worker_task: Optional[asyncio.Task] = None - self._stopping = False - - self._temp_root = self._resolve_temp_root() - self._temp_root.mkdir(parents=True, exist_ok=True) - self._reports_root = self._resolve_reports_root() - self._reports_root.mkdir(parents=True, exist_ok=True) - self._manifest_path = self._resolve_manifest_path() - self._manifest_cache: Optional[Dict[str, Any]] = None - self._write_changed_callback: Optional[Callable[[Dict[str, Any]], Any]] = None - - def set_write_changed_callback(self, callback: Optional[Callable[[Dict[str, Any]], Any]]) -> None: - self._write_changed_callback = callback - - async def _notify_write_changed(self, payload: Dict[str, Any]) -> None: - callback = self._write_changed_callback - if callback is None: - return - try: - maybe_awaitable = callback(payload) - if asyncio.iscoroutine(maybe_awaitable): - await maybe_awaitable - except Exception as e: - logger.warning(f"写入变更回调执行失败: {e}") - - def _resolve_temp_root(self) -> Path: - data_dir = Path(self.plugin.get_config("storage.data_dir", "./data")) - if str(data_dir).startswith("."): - plugin_dir = Path(__file__).resolve().parents[2] - data_dir = (plugin_dir / data_dir).resolve() - return data_dir / "web_import_tmp" - - def _resolve_reports_root(self) -> Path: - return self._resolve_data_dir() / "web_import_reports" - - def _resolve_manifest_path(self) -> Path: - 
return self._resolve_data_dir() / "import_manifest.json" - - def _resolve_staging_root(self) -> Path: - return self._resolve_data_dir() / "import_staging" - - def _resolve_backup_root(self) -> Path: - return self._resolve_data_dir() / "import_backup" - - def _resolve_repo_root(self) -> Path: - return Path(__file__).resolve().parents[3] - - def _resolve_data_dir(self) -> Path: - data_dir = Path(self.plugin.get_config("storage.data_dir", "./data")) - if str(data_dir).startswith("."): - plugin_dir = Path(__file__).resolve().parents[2] - data_dir = (plugin_dir / data_dir).resolve() - return data_dir.resolve() - - def _resolve_migration_script(self) -> Path: - return Path(__file__).resolve().parents[2] / "scripts" / "migrate_maibot_memory.py" - - def _default_maibot_source_db(self) -> Path: - # A_memorix/core/utils -> workspace root - return self._resolve_repo_root() / "MaiBot" / "data" / "MaiBot.db" - - def _cfg(self, key: str, default: Any) -> Any: - return self.plugin.get_config(key, default) - - def _cfg_int(self, key: str, default: int) -> int: - return _coerce_int(self._cfg(key, default), default) - - def _is_enabled(self) -> bool: - return bool(self._cfg("web.import.enabled", True)) - - def _queue_limit(self) -> int: - return max(1, self._cfg_int("web.import.max_queue_size", 20)) - - def _max_files_per_task(self) -> int: - return max(1, self._cfg_int("web.import.max_files_per_task", 200)) - - def _max_file_size_bytes(self) -> int: - mb = max(1, self._cfg_int("web.import.max_file_size_mb", 20)) - return mb * 1024 * 1024 - - def _max_paste_chars(self) -> int: - return max(1000, self._cfg_int("web.import.max_paste_chars", 200000)) - - def _default_file_concurrency(self) -> int: - return max(1, self._cfg_int("web.import.default_file_concurrency", 2)) - - def _default_chunk_concurrency(self) -> int: - return max(1, self._cfg_int("web.import.default_chunk_concurrency", 4)) - - def _max_file_concurrency(self) -> int: - return max(1, self._cfg_int("web.import.max_file_concurrency", 6)) - - def _max_chunk_concurrency(self) -> int: - return max(1, self._cfg_int("web.import.max_chunk_concurrency", 12)) - - def _llm_retry_config(self) -> Dict[str, float]: - retries = max(0, self._cfg_int("web.import.llm_retry.max_attempts", 4)) - min_wait = max(0.1, float(self._cfg("web.import.llm_retry.min_wait_seconds", 3) or 3)) - max_wait = max(min_wait, float(self._cfg("web.import.llm_retry.max_wait_seconds", 40) or 40)) - mult = max(1.0, float(self._cfg("web.import.llm_retry.backoff_multiplier", 3) or 3)) - return { - "retries": retries, - "min_wait": min_wait, - "max_wait": max_wait, - "multiplier": mult, - } - - def _default_path_aliases(self) -> Dict[str, str]: - plugin_dir = Path(__file__).resolve().parents[2] - repo_root = self._resolve_repo_root() - return { - "raw": str((plugin_dir / "data" / "raw").resolve()), - "lpmm": str((repo_root / "data" / "lpmm_storage").resolve()), - "plugin_data": str((plugin_dir / "data").resolve()), - } - - def get_path_aliases(self) -> Dict[str, str]: - configured = self._cfg("web.import.path_aliases", self._default_path_aliases()) - if not isinstance(configured, dict): - configured = self._default_path_aliases() - - repo_root = self._resolve_repo_root() - result: Dict[str, str] = {} - for alias, raw_path in configured.items(): - key = str(alias or "").strip() - if not key: - continue - text = str(raw_path or "").strip() - if not text: - continue - if text.startswith("\\\\"): - continue - p = Path(text) - if not p.is_absolute(): - p = (repo_root / p).resolve() - else: - p = 
p.resolve() - result[key] = str(p) - - defaults = self._default_path_aliases() - for key, path in defaults.items(): - result.setdefault(key, path) - return result - - def resolve_path_alias( - self, - alias: str, - relative_path: str = "", - *, - must_exist: bool = False, - ) -> Path: - alias_key = str(alias or "").strip() - aliases = self.get_path_aliases() - if alias_key not in aliases: - raise ValueError(f"未知路径别名: {alias_key}") - - root = Path(aliases[alias_key]).resolve() - rel = str(relative_path or "").strip().replace("\\", "/") - if rel.startswith("/") or rel.startswith("\\") or rel.startswith("//"): - raise ValueError("relative_path 不能为绝对路径") - if ":" in rel: - raise ValueError("relative_path 不允许包含盘符") - - candidate = (root / rel).resolve() if rel else root - try: - candidate.relative_to(root) - except ValueError: - raise ValueError("路径越界:relative_path 超出白名单目录") - if must_exist and not candidate.exists(): - raise ValueError(f"路径不存在: {candidate}") - return candidate - - async def resolve_path_request(self, payload: Dict[str, Any]) -> Dict[str, Any]: - alias = str(payload.get("alias") or "").strip() - relative_path = str(payload.get("relative_path") or "").strip() - must_exist = _coerce_bool(payload.get("must_exist"), True) - resolved = self.resolve_path_alias(alias, relative_path, must_exist=must_exist) - return { - "alias": alias, - "relative_path": relative_path, - "resolved_path": str(resolved), - "exists": resolved.exists(), - "is_file": resolved.is_file(), - "is_dir": resolved.is_dir(), - } - - def _load_manifest(self) -> Dict[str, Any]: - if self._manifest_cache is not None: - return self._manifest_cache - path = self._manifest_path - if not path.exists(): - self._manifest_cache = {} - return self._manifest_cache - try: - payload = json.loads(path.read_text(encoding="utf-8")) - if isinstance(payload, dict): - self._manifest_cache = payload - else: - self._manifest_cache = {} - except Exception: - self._manifest_cache = {} - return self._manifest_cache - - def _save_manifest(self, payload: Dict[str, Any]) -> None: - path = self._manifest_path - path.parent.mkdir(parents=True, exist_ok=True) - path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8") - self._manifest_cache = payload - - def _clear_manifest(self) -> None: - self._save_manifest({}) - - def _normalize_manifest_path(self, raw_path: str) -> str: - text = str(raw_path or "").strip() - if not text: - return "" - return text.replace("\\", "/").strip().lower() - - def _match_manifest_item_for_source(self, source: str, item: Dict[str, Any]) -> bool: - source_text = str(source or "").strip() - if not source_text or ":" not in source_text: - return False - prefix, tail = source_text.split(":", 1) - source_kind = prefix.strip().lower() - source_value = tail.strip() - if not source_value: - return False - - item_kind = str(item.get("source_kind") or "").strip().lower() - item_name = str(item.get("name") or "").strip() - item_path_norm = self._normalize_manifest_path(item.get("source_path") or "") - - if source_kind in {"raw_scan", "lpmm_openie"}: - source_path_norm = self._normalize_manifest_path(source_value) - if source_path_norm and item_path_norm and source_path_norm == item_path_norm and item_kind == source_kind: - return True - - if source_kind == "web_import": - return item_kind in {"upload", "paste"} and item_name == source_value - - if source_kind == "lpmm_openie": - source_name = Path(source_value).name - return item_kind == "lpmm_openie" and item_name == source_name - - return False - - 
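The alias resolution above is a path-traversal guard: a relative path is only accepted if the resolved result still sits inside the whitelisted alias root, checked via `Path.relative_to`. A minimal standalone sketch under an invented alias table (the `/srv/memorix/...` root is an example, not the plugin's default):

```python
from pathlib import Path

# Standalone sketch of the alias-whitelist resolution used above: reject
# absolute paths and drive letters up front, then verify the resolved
# candidate never escapes the alias root.
ALIASES = {"raw": Path("/srv/memorix/data/raw")}

def resolve_alias(alias: str, relative_path: str = "") -> Path:
    if alias not in ALIASES:
        raise ValueError(f"unknown alias: {alias}")
    root = ALIASES[alias].resolve()
    rel = relative_path.replace("\\", "/")
    if rel.startswith("/") or ":" in rel:
        raise ValueError("relative_path must be relative and contain no drive letter")
    candidate = (root / rel).resolve() if rel else root
    try:
        candidate.relative_to(root)  # raises ValueError if candidate escaped the root
    except ValueError:
        raise ValueError("path escapes the whitelisted directory")
    return candidate

print(resolve_alias("raw", "books/a.txt"))    # /srv/memorix/data/raw/books/a.txt
try:
    resolve_alias("raw", "../../etc/passwd")  # blocked by the relative_to check
except ValueError as e:
    print("blocked:", e)
```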
async def invalidate_manifest_for_sources(self, sources: List[str]) -> Dict[str, Any]: - requested_sources: List[str] = [] - seen_sources = set() - for raw in sources or []: - source = str(raw or "").strip() - if not source: - continue - key = source.lower() - if key in seen_sources: - continue - seen_sources.add(key) - requested_sources.append(source) - - result: Dict[str, Any] = { - "requested_sources": requested_sources, - "removed_count": 0, - "removed_keys": [], - "remaining_count": 0, - "unmatched_sources": [], - "warnings": [], - } - - async with self._lock: - manifest = self._load_manifest() - if not isinstance(manifest, dict): - manifest = {} - - valid_items: List[Tuple[str, Dict[str, Any]]] = [] - malformed_keys: List[str] = [] - for key, item in manifest.items(): - if isinstance(item, dict): - valid_items.append((str(key), item)) - else: - malformed_keys.append(str(key)) - - keys_to_remove = set() - for source in requested_sources: - matched = False - for key, item in valid_items: - if self._match_manifest_item_for_source(source, item): - keys_to_remove.add(key) - matched = True - if not matched: - result["unmatched_sources"].append(source) - - if keys_to_remove: - for key in keys_to_remove: - manifest.pop(key, None) - self._save_manifest(manifest) - - result["removed_keys"] = sorted(keys_to_remove) - result["removed_count"] = len(keys_to_remove) - result["remaining_count"] = len(manifest) - - if malformed_keys: - preview = ", ".join(malformed_keys[:5]) - extra = "" if len(malformed_keys) <= 5 else f" ... (+{len(malformed_keys) - 5})" - result["warnings"].append( - f"manifest 条目结构异常,已跳过 {len(malformed_keys)} 项: {preview}{extra}" - ) - - return result - - def _manifest_key_for_file(self, file_record: ImportFileRecord, content_hash: str, dedupe_policy: str) -> str: - if dedupe_policy == "content_hash": - return f"hash:{content_hash}" - if file_record.source_path: - return f"path:{Path(file_record.source_path).as_posix().lower()}" - return f"hash:{content_hash}" - - def _is_manifest_hit( - self, - file_record: ImportFileRecord, - content_hash: str, - dedupe_policy: str, - ) -> bool: - key = self._manifest_key_for_file(file_record, content_hash, dedupe_policy) - manifest = self._load_manifest() - item = manifest.get(key) - if not isinstance(item, dict): - return False - return str(item.get("hash") or "") == content_hash and bool(item.get("imported")) - - def _record_manifest_import( - self, - file_record: ImportFileRecord, - content_hash: str, - dedupe_policy: str, - task_id: str, - ) -> None: - key = self._manifest_key_for_file(file_record, content_hash, dedupe_policy) - manifest = self._load_manifest() - manifest[key] = { - "hash": content_hash, - "imported": True, - "timestamp": _now(), - "task_id": task_id, - "name": file_record.name, - "source_path": file_record.source_path or "", - "source_kind": file_record.source_kind, - } - self._save_manifest(manifest) - - def _normalize_common_import_params(self, payload: Dict[str, Any], *, default_dedupe: str) -> Dict[str, Any]: - input_mode = str(payload.get("input_mode", "text") or "text").strip().lower() - if input_mode not in {"text", "json"}: - raise ValueError("input_mode 必须为 text 或 json") - - file_concurrency = _coerce_int( - payload.get("file_concurrency", self._default_file_concurrency()), - self._default_file_concurrency(), - ) - chunk_concurrency = _coerce_int( - payload.get("chunk_concurrency", self._default_chunk_concurrency()), - self._default_chunk_concurrency(), - ) - file_concurrency = _clamp(file_concurrency, 1, 
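The manifest dedupe above keys each imported file either by content hash or by its normalized source path, and a "hit" requires both the same hash and an `imported` flag, so a file whose content changed is re-imported even under path keying. A small sketch of that key and hit logic (field names follow the code above; the data is made up):

```python
import hashlib
from pathlib import Path
from typing import Dict, Optional

# Sketch of the two manifest key shapes used above: "hash:<digest>" under the
# content_hash policy, "path:<normalized path>" otherwise.
def manifest_key(source_path: Optional[str], content_hash: str, dedupe_policy: str) -> str:
    if dedupe_policy == "content_hash" or not source_path:
        return f"hash:{content_hash}"
    return f"path:{Path(source_path).as_posix().lower()}"

def is_hit(manifest: Dict[str, dict], key: str, content_hash: str) -> bool:
    item = manifest.get(key)
    return isinstance(item, dict) and item.get("hash") == content_hash and bool(item.get("imported"))

h = hashlib.sha256(b"some knowledge text").hexdigest()
manifest: Dict[str, dict] = {}
key = manifest_key("D:/Data/Raw/a.txt", h, "manifest")
print(is_hit(manifest, key, h))            # False: nothing recorded yet
manifest[key] = {"hash": h, "imported": True}
print(is_hit(manifest, key, h))            # True: same path, same content
print(is_hit(manifest, key, "deadbeef"))   # False: content changed, re-import allowed
```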
self._max_file_concurrency()) - chunk_concurrency = _clamp(chunk_concurrency, 1, self._max_chunk_concurrency()) - - llm_enabled = _coerce_bool(payload.get("llm_enabled", True), True) - strategy_override = parse_import_strategy( - payload.get("strategy_override", "auto"), - default=ImportStrategy.AUTO, - ).value - - dedupe_policy = str(payload.get("dedupe_policy", default_dedupe) or default_dedupe).strip().lower() - if dedupe_policy not in {"content_hash", "manifest", "none"}: - raise ValueError("dedupe_policy 必须为 content_hash/manifest/none") - - chat_log = _coerce_bool(payload.get("chat_log"), False) - chat_reference_time = str(payload.get("chat_reference_time") or "").strip() or None - force = _coerce_bool(payload.get("force"), False) - clear_manifest = _coerce_bool(payload.get("clear_manifest"), False) - - return { - "input_mode": input_mode, - "file_concurrency": file_concurrency, - "chunk_concurrency": chunk_concurrency, - "llm_enabled": llm_enabled, - "strategy_override": strategy_override, - "chat_log": chat_log, - "chat_reference_time": chat_reference_time, - "force": force, - "clear_manifest": clear_manifest, - "dedupe_policy": dedupe_policy, - } - - def _normalize_params(self, payload: Dict[str, Any]) -> Dict[str, Any]: - params = self._normalize_common_import_params(payload, default_dedupe="content_hash") - params["task_kind"] = "upload" - return params - - def _normalize_raw_scan_params(self, payload: Dict[str, Any]) -> Dict[str, Any]: - params = self._normalize_common_import_params(payload, default_dedupe="manifest") - alias = str(payload.get("alias") or "raw").strip() - relative_path = str(payload.get("relative_path") or "").strip() - glob_pattern = str(payload.get("glob") or "*").strip() or "*" - recursive = _coerce_bool(payload.get("recursive"), True) - if ".." 
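Parameter normalization above follows one pattern throughout: coerce loosely typed input to the target type, clamp numeric knobs into a configured range, and validate enum-like strings against a closed set. A condensed sketch with example limits (the maxima `6` and `12` stand in for the configured caps):

```python
from typing import Any, Dict

def _coerce_int(value: Any, default: int) -> int:
    try:
        return int(value)
    except Exception:
        return default

def _clamp(value: int, lo: int, hi: int) -> int:
    return max(lo, min(hi, value))

# Sketch of the request normalization above: concurrency is coerced then
# clamped into [1, max]; dedupe_policy must come from a closed set.
MAX_FILE_CONCURRENCY, MAX_CHUNK_CONCURRENCY = 6, 12

def normalize(payload: Dict[str, Any]) -> Dict[str, Any]:
    file_c = _clamp(_coerce_int(payload.get("file_concurrency", 2), 2), 1, MAX_FILE_CONCURRENCY)
    chunk_c = _clamp(_coerce_int(payload.get("chunk_concurrency", 4), 4), 1, MAX_CHUNK_CONCURRENCY)
    dedupe = str(payload.get("dedupe_policy", "content_hash")).strip().lower()
    if dedupe not in {"content_hash", "manifest", "none"}:
        raise ValueError("dedupe_policy must be content_hash/manifest/none")
    return {"file_concurrency": file_c, "chunk_concurrency": chunk_c, "dedupe_policy": dedupe}

print(normalize({"file_concurrency": "99", "dedupe_policy": "Manifest"}))
# {'file_concurrency': 6, 'chunk_concurrency': 4, 'dedupe_policy': 'manifest'}
```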
in relative_path.replace("\\", "/").split("/"): - raise ValueError("relative_path 不允许包含 ..") - params.update( - { - "task_kind": "raw_scan", - "alias": alias, - "relative_path": relative_path, - "glob": glob_pattern, - "recursive": recursive, - } - ) - return params - - def _normalize_lpmm_openie_params(self, payload: Dict[str, Any]) -> Dict[str, Any]: - params = self._normalize_common_import_params(payload, default_dedupe="manifest") - alias = str(payload.get("alias") or "lpmm").strip() - relative_path = str(payload.get("relative_path") or "").strip() - include_all_json = _coerce_bool(payload.get("include_all_json"), False) - params.update( - { - "task_kind": "lpmm_openie", - "alias": alias, - "relative_path": relative_path, - "include_all_json": include_all_json, - "input_mode": "json", - } - ) - return params - - def _normalize_temporal_backfill_params(self, payload: Dict[str, Any]) -> Dict[str, Any]: - alias = str(payload.get("alias") or "plugin_data").strip() - relative_path = str(payload.get("relative_path") or "").strip() - dry_run = _coerce_bool(payload.get("dry_run"), False) - no_created_fallback = _coerce_bool(payload.get("no_created_fallback"), False) - limit = _parse_optional_positive_int(payload.get("limit"), "limit") or 100000 - return { - "task_kind": "temporal_backfill", - "alias": alias, - "relative_path": relative_path, - "dry_run": dry_run, - "no_created_fallback": no_created_fallback, - "limit": limit, - } - - def _normalize_lpmm_convert_params(self, payload: Dict[str, Any]) -> Dict[str, Any]: - alias = str(payload.get("alias") or "lpmm").strip() - relative_path = str(payload.get("relative_path") or "").strip() - target_alias = str(payload.get("target_alias") or "plugin_data").strip() - target_relative_path = str(payload.get("target_relative_path") or "").strip() - dimension = _parse_optional_positive_int(payload.get("dimension"), "dimension") or _coerce_int( - self._cfg("embedding.dimension", 384), - 384, - ) - batch_size = _parse_optional_positive_int(payload.get("batch_size"), "batch_size") or 1024 - return { - "task_kind": "lpmm_convert", - "alias": alias, - "relative_path": relative_path, - "target_alias": target_alias, - "target_relative_path": target_relative_path, - "dimension": dimension, - "batch_size": batch_size, - } - - def _normalize_by_task_kind(self, task_kind: str, payload: Dict[str, Any]) -> Dict[str, Any]: - kind = str(task_kind or "").strip().lower() - if kind in {"upload", "paste"}: - params = self._normalize_params(payload) - params["task_kind"] = kind - return params - if kind == "maibot_migration": - return self._normalize_migration_params(payload) - if kind == "raw_scan": - return self._normalize_raw_scan_params(payload) - if kind == "lpmm_openie": - return self._normalize_lpmm_openie_params(payload) - if kind == "temporal_backfill": - return self._normalize_temporal_backfill_params(payload) - if kind == "lpmm_convert": - return self._normalize_lpmm_convert_params(payload) - # upload/paste 默认走通用文本导入参数 - return self._normalize_params(payload) - - def _normalize_migration_params(self, payload: Dict[str, Any]) -> Dict[str, Any]: - source_db = str(payload.get("source_db") or "").strip() - if not source_db: - source_db = str(self._default_maibot_source_db()) - - time_from = str(payload.get("time_from") or "").strip() or None - time_to = str(payload.get("time_to") or "").strip() or None - - stream_ids = _coerce_list(payload.get("stream_ids")) - group_ids = _coerce_list(payload.get("group_ids")) - user_ids = _coerce_list(payload.get("user_ids")) - - 
start_id = _parse_optional_positive_int(payload.get("start_id"), "start_id") - end_id = _parse_optional_positive_int(payload.get("end_id"), "end_id") - if start_id is not None and end_id is not None and start_id > end_id: - raise ValueError("start_id 不能大于 end_id") - - read_batch_size = _parse_optional_positive_int(payload.get("read_batch_size"), "read_batch_size") or 2000 - commit_window_rows = _parse_optional_positive_int(payload.get("commit_window_rows"), "commit_window_rows") or 20000 - embed_batch_size = _parse_optional_positive_int(payload.get("embed_batch_size"), "embed_batch_size") or 256 - entity_embed_batch_size = ( - _parse_optional_positive_int(payload.get("entity_embed_batch_size"), "entity_embed_batch_size") or 512 - ) - embed_workers = _parse_optional_positive_int(payload.get("embed_workers"), "embed_workers") - max_errors = _parse_optional_positive_int(payload.get("max_errors"), "max_errors") or 500 - log_every = _parse_optional_positive_int(payload.get("log_every"), "log_every") or 5000 - preview_limit = _parse_optional_positive_int(payload.get("preview_limit"), "preview_limit") or 20 - - no_resume = _coerce_bool(payload.get("no_resume"), False) - reset_state = _coerce_bool(payload.get("reset_state"), False) - dry_run = _coerce_bool(payload.get("dry_run"), False) - verify_only = _coerce_bool(payload.get("verify_only"), False) - - return { - "task_kind": "maibot_migration", - "source_db": source_db, - "target_data_dir": str(self._resolve_data_dir()), - "time_from": time_from, - "time_to": time_to, - "stream_ids": stream_ids, - "group_ids": group_ids, - "user_ids": user_ids, - "start_id": start_id, - "end_id": end_id, - "read_batch_size": read_batch_size, - "commit_window_rows": commit_window_rows, - "embed_batch_size": embed_batch_size, - "entity_embed_batch_size": entity_embed_batch_size, - "embed_workers": embed_workers, - "max_errors": max_errors, - "log_every": log_every, - "preview_limit": preview_limit, - "no_resume": no_resume, - "reset_state": reset_state, - "dry_run": dry_run, - "verify_only": verify_only, - } - - def _pending_task_count(self) -> int: - pending = 0 - for task in self._tasks.values(): - if task.status in {"queued", "preparing", "running", "cancel_requested"}: - pending += 1 - return pending - - async def _ensure_worker(self) -> None: - async with self._lock: - if self._worker_task and not self._worker_task.done(): - return - self._stopping = False - self._worker_task = asyncio.create_task(self._worker_loop()) - - async def get_runtime_settings(self) -> Dict[str, Any]: - llm_retry = self._llm_retry_config() - return { - "max_queue_size": self._queue_limit(), - "max_files_per_task": self._max_files_per_task(), - "max_file_size_mb": self._cfg_int("web.import.max_file_size_mb", 20), - "max_paste_chars": self._max_paste_chars(), - "default_file_concurrency": self._default_file_concurrency(), - "default_chunk_concurrency": self._default_chunk_concurrency(), - "max_file_concurrency": self._max_file_concurrency(), - "max_chunk_concurrency": self._max_chunk_concurrency(), - "poll_interval_ms": max(200, self._cfg_int("web.import.poll_interval_ms", 1000)), - "maibot_source_db_default": str(self._default_maibot_source_db()), - "maibot_target_data_dir": str(self._resolve_data_dir()), - "path_aliases": self.get_path_aliases(), - "llm_retry": llm_retry, - "convert_enable_staging_switch": _coerce_bool( - self._cfg("web.import.convert.enable_staging_switch", True), True - ), - "convert_keep_backup_count": max(0, self._cfg_int("web.import.convert.keep_backup_count", 
3)), - } - - def is_write_blocked(self) -> bool: - task_id = self._active_task_id - if not task_id: - return False - task = self._tasks.get(task_id) - if not task: - return False - return task.status in {"preparing", "running", "cancel_requested"} - - def _ensure_ready(self) -> None: - required_attrs = ("metadata_store", "vector_store", "graph_store", "embedding_manager") - - def _collect_missing() -> List[str]: - missing_local: List[str] = [] - for attr in required_attrs: - if getattr(self.plugin, attr, None) is None: - missing_local.append(attr) - return missing_local - - missing = _collect_missing() - if missing: - raise ValueError(f"导入依赖未初始化: {', '.join(missing)}") - ready_checker = getattr(self.plugin, "is_runtime_ready", None) - if callable(ready_checker) and not ready_checker(): - raise ValueError("插件运行时未就绪,请先完成 on_enable 初始化") - - def _scan_files( - self, - base_path: Path, - *, - recursive: bool, - glob_pattern: str, - allowed_exts: Optional[set[str]] = None, - ) -> List[Path]: - if base_path.is_file(): - candidates = [base_path] - else: - if recursive: - candidates = list(base_path.rglob(glob_pattern)) - else: - candidates = list(base_path.glob(glob_pattern)) - out: List[Path] = [] - for p in candidates: - if not p.is_file(): - continue - ext = p.suffix.lower() - if allowed_exts and ext not in allowed_exts: - continue - out.append(p.resolve()) - out.sort(key=lambda x: x.as_posix().lower()) - return out - - async def create_upload_task(self, files: List[Any], payload: Dict[str, Any]) -> Dict[str, Any]: - if not self._is_enabled(): - raise ValueError("导入功能已禁用") - self._ensure_ready() - if not files: - raise ValueError("至少需要上传一个文件") - - params = self._normalize_params(payload) - max_files = self._max_files_per_task() - if len(files) > max_files: - raise ValueError(f"单任务文件数超过上限: {max_files}") - - async with self._lock: - if self._pending_task_count() >= self._queue_limit(): - raise ValueError("任务队列已满,请稍后重试") - - task = ImportTaskRecord( - task_id=uuid.uuid4().hex, - source="upload", - params=params, - status="queued", - current_step="queued", - ) - task_dir = self._temp_root / task.task_id - task_dir.mkdir(parents=True, exist_ok=True) - - max_size = self._max_file_size_bytes() - for idx, uploaded in enumerate(files): - file_id = uuid.uuid4().hex - if isinstance(uploaded, dict): - staged_path_raw = uploaded.get("staged_path") or uploaded.get("path") or "" - staged_path = Path(str(staged_path_raw or "")).expanduser().resolve() - if not staged_path.is_file(): - raise ValueError(f"上传暂存文件不存在: {staged_path}") - name = _safe_filename(uploaded.get("filename") or uploaded.get("name") or staged_path.name) - ext = Path(name).suffix.lower() - if ext not in {".txt", ".md", ".json"}: - raise ValueError(f"不支持的文件类型: {name}") - if staged_path.stat().st_size > max_size: - raise ValueError(f"文件超过大小限制: {name}") - temp_path = task_dir / f"{file_id}_{name}" - shutil.copy2(staged_path, temp_path) - else: - name = _safe_filename(getattr(uploaded, "filename", f"file_{idx}.txt")) - ext = Path(name).suffix.lower() - if ext not in {".txt", ".md", ".json"}: - raise ValueError(f"不支持的文件类型: {name}") - content = await uploaded.read() - if len(content) > max_size: - raise ValueError(f"文件超过大小限制: {name}") - temp_path = task_dir / f"{file_id}_{name}" - temp_path.write_bytes(content) - file_mode = "json" if ext == ".json" else params["input_mode"] - task.files.append( - ImportFileRecord( - file_id=file_id, - name=name, - source_kind="upload", - input_mode=file_mode, - temp_path=str(temp_path), - ) - ) - - 
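The `_scan_files` helper above keeps only regular files with whitelisted extensions and sorts case-insensitively by POSIX path, so task ordering is deterministic across platforms. A self-contained sketch of the same shape (not a drop-in replacement for the method):

```python
from pathlib import Path
from typing import List, Optional, Set

# Sketch of the directory scan used above: glob (optionally recursive), filter
# to regular files with allowed extensions, resolve, then sort for stable order.
def scan_files(base: Path, glob_pattern: str = "*", recursive: bool = True,
               allowed_exts: Optional[Set[str]] = None) -> List[Path]:
    if base.is_file():
        candidates = [base]
    else:
        candidates = list(base.rglob(glob_pattern) if recursive else base.glob(glob_pattern))
    out = [p.resolve() for p in candidates
           if p.is_file() and (not allowed_exts or p.suffix.lower() in allowed_exts)]
    out.sort(key=lambda p: p.as_posix().lower())
    return out

# Example: only .txt/.md/.json survive, matching the upload whitelist above.
for path in scan_files(Path("."), "*", recursive=False, allowed_exts={".txt", ".md", ".json"}):
    print(path.name)
```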
self._tasks[task.task_id] = task - self._task_order.appendleft(task.task_id) - self._queue.append(task.task_id) - - await self._ensure_worker() - return task.to_summary() - - async def create_paste_task(self, payload: Dict[str, Any]) -> Dict[str, Any]: - if not self._is_enabled(): - raise ValueError("导入功能已禁用") - self._ensure_ready() - - params = self._normalize_params(payload) - params["task_kind"] = "paste" - content = str(payload.get("content", "") or "") - if not content.strip(): - raise ValueError("content 不能为空") - if len(content) > self._max_paste_chars(): - raise ValueError(f"粘贴内容超过限制: {self._max_paste_chars()} 字符") - - name = _safe_filename(payload.get("name") or f"paste_{int(_now())}.txt") - if params["input_mode"] == "json" and Path(name).suffix.lower() != ".json": - name = f"{Path(name).stem}.json" - - async with self._lock: - if self._pending_task_count() >= self._queue_limit(): - raise ValueError("任务队列已满,请稍后重试") - - task = ImportTaskRecord( - task_id=uuid.uuid4().hex, - source="paste", - params=params, - status="queued", - current_step="queued", - ) - task.files.append( - ImportFileRecord( - file_id=uuid.uuid4().hex, - name=name, - source_kind="paste", - input_mode=params["input_mode"], - inline_content=content, - ) - ) - self._tasks[task.task_id] = task - self._task_order.appendleft(task.task_id) - self._queue.append(task.task_id) - - await self._ensure_worker() - return task.to_summary() - - async def create_raw_scan_task(self, payload: Dict[str, Any]) -> Dict[str, Any]: - if not self._is_enabled(): - raise ValueError("导入功能已禁用") - self._ensure_ready() - params = self._normalize_raw_scan_params(payload) - source_path = self.resolve_path_alias( - params["alias"], - params["relative_path"], - must_exist=True, - ) - files = self._scan_files( - source_path, - recursive=bool(params["recursive"]), - glob_pattern=str(params["glob"] or "*"), - allowed_exts={".txt", ".md", ".json"}, - ) - if not files: - raise ValueError("未找到可导入文件") - if len(files) > self._max_files_per_task(): - raise ValueError(f"单任务文件数超过上限: {self._max_files_per_task()}") - - async with self._lock: - if self._pending_task_count() >= self._queue_limit(): - raise ValueError("任务队列已满,请稍后重试") - - task = ImportTaskRecord( - task_id=uuid.uuid4().hex, - source="raw_scan", - params=params, - status="queued", - current_step="queued", - ) - for path in files: - mode = "json" if path.suffix.lower() == ".json" else params["input_mode"] - task.files.append( - ImportFileRecord( - file_id=uuid.uuid4().hex, - name=path.name, - source_kind="raw_scan", - input_mode=mode, - source_path=str(path), - ) - ) - self._tasks[task.task_id] = task - self._task_order.appendleft(task.task_id) - self._queue.append(task.task_id) - - await self._ensure_worker() - return task.to_summary() - - async def create_lpmm_openie_task(self, payload: Dict[str, Any]) -> Dict[str, Any]: - if not self._is_enabled(): - raise ValueError("导入功能已禁用") - self._ensure_ready() - params = self._normalize_lpmm_openie_params(payload) - source_path = self.resolve_path_alias( - params["alias"], - params["relative_path"], - must_exist=True, - ) - files: List[Path] = [] - if source_path.is_file(): - files = [source_path] - else: - files = self._scan_files( - source_path, - recursive=True, - glob_pattern="*-openie.json", - allowed_exts={".json"}, - ) - if not files and params.get("include_all_json"): - files = self._scan_files( - source_path, - recursive=True, - glob_pattern="*.json", - allowed_exts={".json"}, - ) - if not files: - raise ValueError("未找到 LPMM OpenIE JSON 文件") - if 
len(files) > self._max_files_per_task(): - raise ValueError(f"单任务文件数超过上限: {self._max_files_per_task()}") - - async with self._lock: - if self._pending_task_count() >= self._queue_limit(): - raise ValueError("任务队列已满,请稍后重试") - task = ImportTaskRecord( - task_id=uuid.uuid4().hex, - source="lpmm_openie", - params=params, - status="queued", - current_step="queued", - schema_detected="lpmm_openie", - ) - for path in files: - task.files.append( - ImportFileRecord( - file_id=uuid.uuid4().hex, - name=path.name, - source_kind="lpmm_openie", - input_mode="json", - source_path=str(path), - ) - ) - self._tasks[task.task_id] = task - self._task_order.appendleft(task.task_id) - self._queue.append(task.task_id) - - await self._ensure_worker() - return task.to_summary() - - async def create_temporal_backfill_task(self, payload: Dict[str, Any]) -> Dict[str, Any]: - if not self._is_enabled(): - raise ValueError("导入功能已禁用") - params = self._normalize_temporal_backfill_params(payload) - target_path = self.resolve_path_alias( - params["alias"], - params["relative_path"], - must_exist=True, - ) - if not target_path.is_dir(): - raise ValueError("temporal_backfill 目标路径必须为目录") - - async with self._lock: - if self._pending_task_count() >= self._queue_limit(): - raise ValueError("任务队列已满,请稍后重试") - task = ImportTaskRecord( - task_id=uuid.uuid4().hex, - source="temporal_backfill", - params=params, - status="queued", - current_step="queued", - ) - task.files.append( - ImportFileRecord( - file_id=uuid.uuid4().hex, - name=f"temporal_backfill_{int(_now())}", - source_kind="temporal_backfill", - input_mode="json", - source_path=str(target_path), - ) - ) - self._tasks[task.task_id] = task - self._task_order.appendleft(task.task_id) - self._queue.append(task.task_id) - - await self._ensure_worker() - return task.to_summary() - - async def create_lpmm_convert_task(self, payload: Dict[str, Any]) -> Dict[str, Any]: - if not self._is_enabled(): - raise ValueError("导入功能已禁用") - params = self._normalize_lpmm_convert_params(payload) - source_path = self.resolve_path_alias( - params["alias"], - params["relative_path"], - must_exist=True, - ) - if not source_path.is_dir(): - raise ValueError("lpmm_convert 输入路径必须为目录") - target_path = self.resolve_path_alias( - params["target_alias"], - params["target_relative_path"], - must_exist=False, - ) - target_path.mkdir(parents=True, exist_ok=True) - if not target_path.is_dir(): - raise ValueError("lpmm_convert 目标路径必须为目录") - - async with self._lock: - if self._pending_task_count() >= self._queue_limit(): - raise ValueError("任务队列已满,请稍后重试") - task = ImportTaskRecord( - task_id=uuid.uuid4().hex, - source="lpmm_convert", - params={**params, "source_path": str(source_path), "target_path": str(target_path)}, - status="queued", - current_step="queued", - ) - task.files.append( - ImportFileRecord( - file_id=uuid.uuid4().hex, - name=f"lpmm_convert_{int(_now())}", - source_kind="lpmm_convert", - input_mode="json", - source_path=str(source_path), - ) - ) - self._tasks[task.task_id] = task - self._task_order.appendleft(task.task_id) - self._queue.append(task.task_id) - - await self._ensure_worker() - return task.to_summary() - - async def create_maibot_migration_task(self, payload: Dict[str, Any]) -> Dict[str, Any]: - if not self._is_enabled(): - raise ValueError("导入功能已禁用") - self._ensure_ready() - - params = self._normalize_migration_params(payload) - script_path = self._resolve_migration_script() - if not script_path.exists(): - raise ValueError(f"迁移脚本不存在: {script_path}") - - async with self._lock: - if 
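All of the `create_*_task` methods above share one enqueue pattern: check the pending count against the queue limit while holding the lock, register the record, push its id, then lazily (re)start a single worker. A minimal asyncio sketch of that pattern (class and method names here are invented; the real worker does the import work instead of sleeping):

```python
import asyncio
import uuid
from collections import deque
from typing import Deque, Dict, Optional

class TinyQueue:
    def __init__(self, limit: int = 2) -> None:
        self.limit = limit
        self._lock = asyncio.Lock()
        self._tasks: Dict[str, str] = {}          # task_id -> status
        self._queue: Deque[str] = deque()
        self._worker: Optional[asyncio.Task] = None

    async def submit(self) -> str:
        async with self._lock:
            pending = sum(1 for s in self._tasks.values() if s in {"queued", "running"})
            if pending >= self.limit:
                raise ValueError("queue full")
            task_id = uuid.uuid4().hex
            self._tasks[task_id] = "queued"
            self._queue.append(task_id)
        if self._worker is None or self._worker.done():
            self._worker = asyncio.create_task(self._run())
        return task_id

    async def _run(self) -> None:
        while self._queue:
            task_id = self._queue.popleft()
            self._tasks[task_id] = "running"
            await asyncio.sleep(0.01)             # stand-in for the real import work
            self._tasks[task_id] = "completed"

async def main() -> None:
    q = TinyQueue(limit=2)
    a, b = await q.submit(), await q.submit()
    try:
        await q.submit()
    except ValueError as e:
        print("rejected:", e)
    await q._worker
    print(q._tasks[a], q._tasks[b])

asyncio.run(main())
```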
self._pending_task_count() >= self._queue_limit(): - raise ValueError("任务队列已满,请稍后重试") - - task = ImportTaskRecord( - task_id=uuid.uuid4().hex, - source="maibot_migration", - params=params, - status="queued", - current_step="queued", - ) - task.files.append( - ImportFileRecord( - file_id=uuid.uuid4().hex, - name=f"maibot_migration_{int(_now())}", - source_kind="maibot_migration", - input_mode="text", - inline_content=json.dumps(params, ensure_ascii=False), - ) - ) - self._tasks[task.task_id] = task - self._task_order.appendleft(task.task_id) - self._queue.append(task.task_id) - - await self._ensure_worker() - return task.to_summary() - - async def list_tasks(self, limit: int = 50) -> List[Dict[str, Any]]: - async with self._lock: - task_ids = list(self._task_order)[: max(1, int(limit))] - return [self._tasks[task_id].to_summary() for task_id in task_ids if task_id in self._tasks] - - async def get_task(self, task_id: str, include_chunks: bool = False) -> Optional[Dict[str, Any]]: - async with self._lock: - task = self._tasks.get(task_id) - if not task: - return None - return task.to_detail(include_chunks=include_chunks) - - async def get_chunks(self, task_id: str, file_id: str, offset: int = 0, limit: int = 50) -> Optional[Dict[str, Any]]: - async with self._lock: - task = self._tasks.get(task_id) - if not task: - return None - file_obj = self._find_file(task, file_id) - if not file_obj: - return None - start = max(0, int(offset)) - size = max(1, min(500, int(limit))) - items = file_obj.chunks[start : start + size] - return { - "task_id": task_id, - "file_id": file_id, - "offset": start, - "limit": size, - "total": len(file_obj.chunks), - "items": [x.to_dict() for x in items], - } - - async def cancel_task(self, task_id: str) -> Optional[Dict[str, Any]]: - async with self._lock: - task = self._tasks.get(task_id) - if not task: - return None - if task.status == "queued": - self._mark_task_cancelled_locked(task, "任务已取消") - self._queue = deque([x for x in self._queue if x != task_id]) - elif task.status in {"preparing", "running"}: - task.status = "cancel_requested" - task.current_step = "cancel_requested" - task.updated_at = _now() - return task.to_summary() - - def _build_retry_plan(self, task: ImportTaskRecord) -> Dict[str, Any]: - chunk_retry_candidates: List[Tuple[ImportFileRecord, List[int]]] = [] - file_fallback_candidates: List[ImportFileRecord] = [] - skipped: List[Dict[str, str]] = [] - - for file_obj in task.files: - if file_obj.status == "cancelled": - continue - - failed_chunks = [c for c in file_obj.chunks if c.status == "failed"] - has_file_level_failure = file_obj.status == "failed" and not failed_chunks - if has_file_level_failure: - file_fallback_candidates.append(file_obj) - continue - - if not failed_chunks: - continue - - retry_indexes: List[int] = [] - has_non_retryable = False - for chunk in failed_chunks: - failed_at = str(chunk.failed_at or "").strip().lower() - retryable = bool(chunk.retryable) or ( - file_obj.input_mode == "text" and failed_at == "extracting" - ) - if retryable: - try: - retry_indexes.append(int(chunk.index)) - except Exception: - has_non_retryable = True - else: - has_non_retryable = True - - if has_non_retryable: - file_fallback_candidates.append(file_obj) - continue - - retry_indexes = sorted(set(retry_indexes)) - if retry_indexes: - chunk_retry_candidates.append((file_obj, retry_indexes)) - else: - skipped.append( - { - "file_name": file_obj.name, - "source_kind": file_obj.source_kind, - "reason": "no_retryable_failed_chunks", - } - ) - - 
unique_fallback: List[ImportFileRecord] = [] - fallback_seen = set() - for file_obj in file_fallback_candidates: - if file_obj.file_id in fallback_seen: - continue - fallback_seen.add(file_obj.file_id) - unique_fallback.append(file_obj) - - return { - "chunk_retry_candidates": chunk_retry_candidates, - "file_fallback_candidates": unique_fallback, - "skipped": skipped, - } - - def _clone_failed_file_for_retry( - self, - retry_task: ImportTaskRecord, - failed_file: ImportFileRecord, - task_dir: Path, - *, - retry_mode: str, - retry_chunk_indexes: Optional[List[int]] = None, - ) -> Tuple[bool, str]: - source_kind = str(failed_file.source_kind or "").strip().lower() - retry_chunk_indexes = list(retry_chunk_indexes or []) - - if source_kind == "upload": - candidate_paths: List[Path] = [] - if failed_file.temp_path: - candidate_paths.append(Path(failed_file.temp_path)) - if failed_file.source_path: - candidate_paths.append(Path(failed_file.source_path)) - src_path = next((p for p in candidate_paths if p.exists() and p.is_file()), None) - if src_path is None: - return False, "upload_source_missing" - data = src_path.read_bytes() - file_id = uuid.uuid4().hex - name = _safe_filename(failed_file.name) - dst = task_dir / f"{file_id}_{name}" - dst.write_bytes(data) - retry_task.files.append( - ImportFileRecord( - file_id=file_id, - name=name, - source_kind="upload", - input_mode=failed_file.input_mode, - temp_path=str(dst), - retry_mode=retry_mode, - retry_chunk_indexes=retry_chunk_indexes, - ) - ) - return True, "" - - if source_kind == "paste": - if failed_file.inline_content is None: - return False, "paste_content_missing" - retry_task.files.append( - ImportFileRecord( - file_id=uuid.uuid4().hex, - name=_safe_filename(failed_file.name), - source_kind="paste", - input_mode=failed_file.input_mode, - inline_content=failed_file.inline_content, - retry_mode=retry_mode, - retry_chunk_indexes=retry_chunk_indexes, - ) - ) - return True, "" - - if source_kind == "maibot_migration": - retry_task.files.append( - ImportFileRecord( - file_id=uuid.uuid4().hex, - name=_safe_filename(failed_file.name), - source_kind="maibot_migration", - input_mode="text", - inline_content=failed_file.inline_content, - retry_mode="file_fallback", - retry_chunk_indexes=[], - ) - ) - return True, "" - - if source_kind in {"raw_scan", "lpmm_openie", "lpmm_convert", "temporal_backfill"}: - retry_task.files.append( - ImportFileRecord( - file_id=uuid.uuid4().hex, - name=_safe_filename(failed_file.name), - source_kind=source_kind, - input_mode=failed_file.input_mode, - source_path=failed_file.source_path, - inline_content=failed_file.inline_content, - retry_mode=retry_mode, - retry_chunk_indexes=retry_chunk_indexes, - ) - ) - return True, "" - - return False, f"unsupported_source_kind:{source_kind or 'unknown'}" - - async def retry_failed(self, task_id: str, overrides: Optional[Dict[str, Any]] = None) -> Optional[Dict[str, Any]]: - async with self._lock: - task = self._tasks.get(task_id) - if not task: - return None - retry_plan = self._build_retry_plan(task) - chunk_retry_candidates = list(retry_plan["chunk_retry_candidates"]) - file_fallback_candidates = list(retry_plan["file_fallback_candidates"]) - skipped_candidates = list(retry_plan["skipped"]) - if not chunk_retry_candidates and not file_fallback_candidates: - raise ValueError("当前任务没有可重试失败项") - base_params = dict(task.params) - task_kind = str(task.params.get("task_kind") or "").strip().lower() - - if overrides: - base_params.update(overrides) - params = 
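The retry planning above is chunk-first with a file-level fallback: a file whose failures are all retryable chunk failures is retried by chunk index; any file-level failure or non-retryable chunk demotes the whole file to a full re-run. A compact sketch of that classification over plain dicts (the record shapes mirror the dataclasses above):

```python
from typing import Dict, List, Tuple

# Sketch of the chunk-first retry plan above: per-file, either a sorted list of
# retryable failed chunk indexes, or a fallback to re-running the entire file.
def build_retry_plan(files: List[Dict]) -> Tuple[Dict[str, List[int]], List[str]]:
    chunk_retries: Dict[str, List[int]] = {}
    file_fallback: List[str] = []
    for f in files:
        failed = [c for c in f["chunks"] if c["status"] == "failed"]
        if f["status"] == "failed" and not failed:
            file_fallback.append(f["name"])          # file-level failure, no chunk detail
            continue
        if not failed:
            continue
        if all(c.get("retryable") for c in failed):
            chunk_retries[f["name"]] = sorted(c["index"] for c in failed)
        else:
            file_fallback.append(f["name"])          # at least one non-retryable chunk
    return chunk_retries, file_fallback

files = [
    {"name": "a.txt", "status": "completed_with_errors",
     "chunks": [{"index": 7, "status": "failed", "retryable": True},
                {"index": 3, "status": "failed", "retryable": True}]},
    {"name": "b.txt", "status": "failed", "chunks": []},
]
print(build_retry_plan(files))  # ({'a.txt': [3, 7]}, ['b.txt'])
```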
self._normalize_by_task_kind(task_kind, base_params) - params["retry_parent_task_id"] = task_id - params["retry_strategy"] = "chunk_first_auto_file_fallback" - - async with self._lock: - if self._pending_task_count() >= self._queue_limit(): - raise ValueError("任务队列已满,请稍后重试") - retry_task = ImportTaskRecord( - task_id=uuid.uuid4().hex, - source=task.source, - params=params, - status="queued", - current_step="queued", - schema_detected=task.schema_detected, - retry_parent_task_id=task_id, - ) - - task_dir = self._temp_root / retry_task.task_id - task_dir.mkdir(parents=True, exist_ok=True) - - retry_summary = { - "chunk_retry_files": 0, - "chunk_retry_chunks": 0, - "file_fallback_files": 0, - "skipped_files": 0, - "parent_task_id": task_id, - } - skipped_details = list(skipped_candidates) - - for file_obj, chunk_indexes in chunk_retry_candidates: - ok, reason = self._clone_failed_file_for_retry( - retry_task, - file_obj, - task_dir, - retry_mode="chunk", - retry_chunk_indexes=chunk_indexes, - ) - if ok: - retry_summary["chunk_retry_files"] += 1 - retry_summary["chunk_retry_chunks"] += len(chunk_indexes) - else: - skipped_details.append( - { - "file_name": file_obj.name, - "source_kind": file_obj.source_kind, - "reason": reason, - } - ) - - for file_obj in file_fallback_candidates: - ok, reason = self._clone_failed_file_for_retry( - retry_task, - file_obj, - task_dir, - retry_mode="file_fallback", - retry_chunk_indexes=[], - ) - if ok: - retry_summary["file_fallback_files"] += 1 - else: - skipped_details.append( - { - "file_name": file_obj.name, - "source_kind": file_obj.source_kind, - "reason": reason, - } - ) - - retry_summary["skipped_files"] = len(skipped_details) - if skipped_details: - retry_summary["skipped_details"] = skipped_details - retry_task.retry_summary = retry_summary - - if not retry_task.files: - raise ValueError("无可执行的重试输入:失败项均无法构建重试任务") - - self._tasks[retry_task.task_id] = retry_task - self._task_order.appendleft(retry_task.task_id) - self._queue.append(retry_task.task_id) - logger.info( - "重试任务已创建 " - f"parent={task_id} retry={retry_task.task_id} " - f"chunk_files={retry_summary['chunk_retry_files']} " - f"chunk_chunks={retry_summary['chunk_retry_chunks']} " - f"file_fallback={retry_summary['file_fallback_files']} " - f"skipped={retry_summary['skipped_files']}" - ) - - await self._ensure_worker() - return retry_task.to_summary() - - async def shutdown(self) -> None: - async with self._lock: - self._stopping = True - for task in self._tasks.values(): - if task.status in {"queued", "preparing", "running", "cancel_requested"}: - self._mark_task_cancelled_locked(task, "服务关闭") - self._queue.clear() - worker = self._worker_task - self._worker_task = None - - if worker: - worker.cancel() - try: - await worker - except asyncio.CancelledError: - pass - except Exception: - pass - - self._cleanup_temp_root() - - def _cleanup_temp_root(self) -> None: - try: - if not self._temp_root.exists(): - return - for child in self._temp_root.rglob("*"): - if child.is_file(): - child.unlink(missing_ok=True) - for child in sorted(self._temp_root.rglob("*"), reverse=True): - if child.is_dir(): - child.rmdir() - self._temp_root.rmdir() - except Exception as e: - logger.warning(f"清理临时导入目录失败: {e}") - - async def _worker_loop(self) -> None: - logger.info("Web 导入任务 worker 已启动") - while True: - if self._stopping: - break - - task_id: Optional[str] = None - async with self._lock: - while self._queue: - candidate = self._queue.popleft() - t = self._tasks.get(candidate) - if not t: - continue - if t.status 
== "cancelled": - continue - task_id = candidate - self._active_task_id = candidate - break - - if not task_id: - await asyncio.sleep(0.2) - continue - - try: - await self._run_task(task_id) - except asyncio.CancelledError: - break - except Exception as e: - logger.error(f"导入任务执行失败 task={task_id}: {e}\n{traceback.format_exc()}") - async with self._lock: - task = self._tasks.get(task_id) - if task and task.status not in {"cancelled", "completed", "completed_with_errors"}: - task.status = "failed" - task.current_step = "failed" - task.error = str(e) - task.finished_at = _now() - task.updated_at = _now() - finally: - should_cleanup = await self._should_cleanup_task_temp(task_id) - async with self._lock: - if self._active_task_id == task_id: - self._active_task_id = None - if should_cleanup: - await self._cleanup_task_temp_files(task_id) - - logger.info("Web 导入任务 worker 已停止") - - async def _cleanup_task_temp_files(self, task_id: str) -> None: - task_dir = self._temp_root / task_id - if not task_dir.exists(): - return - try: - for child in task_dir.rglob("*"): - if child.is_file(): - child.unlink(missing_ok=True) - for child in sorted(task_dir.rglob("*"), reverse=True): - if child.is_dir(): - child.rmdir() - task_dir.rmdir() - except Exception as e: - logger.warning(f"清理任务临时文件失败 task={task_id}: {e}") - - def _task_report_path(self, task_id: str) -> Path: - self._reports_root.mkdir(parents=True, exist_ok=True) - return self._reports_root / f"{task_id}_summary.json" - - def _write_task_report(self, task: ImportTaskRecord) -> None: - path = self._task_report_path(task.task_id) - payload = task.to_detail(include_chunks=False) - payload["generated_at"] = _now() - path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8") - task.artifact_paths["summary"] = str(path) - - async def _run_task(self, task_id: str) -> None: - async with self._lock: - task = self._tasks.get(task_id) - if not task: - return - task.status = "preparing" - task.current_step = "preparing" - task.started_at = _now() - task.updated_at = _now() - if task.params.get("clear_manifest"): - self._clear_manifest() - - async with self._lock: - task = self._tasks.get(task_id) - if not task: - return - if task.status == "cancel_requested": - task.status = "cancelled" - task.current_step = "cancelled" - task.finished_at = _now() - task.updated_at = _now() - return - task.status = "running" - task.current_step = "running" - task.updated_at = _now() - - task_kind = str(task.params.get("task_kind") or task.source).strip().lower() - if task_kind == "maibot_migration": - if not task.files: - raise RuntimeError("迁移任务缺少文件记录") - await self._process_maibot_migration(task_id, task.files[0]) - elif task_kind == "temporal_backfill": - if not task.files: - raise RuntimeError("回填任务缺少文件记录") - await self._process_temporal_backfill(task_id, task.files[0]) - elif task_kind == "lpmm_convert": - if not task.files: - raise RuntimeError("转换任务缺少文件记录") - await self._process_lpmm_convert(task_id, task.files[0]) - else: - file_semaphore = asyncio.Semaphore(task.params["file_concurrency"]) - chunk_semaphore = asyncio.Semaphore(task.params["chunk_concurrency"]) - jobs = [ - asyncio.create_task(self._process_file(task_id, f, file_semaphore, chunk_semaphore)) - for f in task.files - ] - await asyncio.gather(*jobs, return_exceptions=True) - - write_changed_payload: Optional[Dict[str, Any]] = None - async with self._lock: - task = self._tasks.get(task_id) - if not task: - return - self._recompute_task_progress(task) - has_failed = any( - (f.status 
== "failed") - or (f.failed_chunks > 0) - or bool(str(f.error or "").strip()) - for f in task.files - ) - has_cancelled = any(f.status == "cancelled" for f in task.files) - has_completed = any(f.status == "completed" for f in task.files) - - # 统一按文件真实终态收敛任务状态,避免出现“任务已取消但文件已完成”的矛盾结果。 - if has_failed and not has_cancelled: - task.status = "completed_with_errors" - task.current_step = "completed_with_errors" - elif has_cancelled and not has_completed: - task.status = "cancelled" - task.current_step = "cancelled" - elif has_cancelled and has_completed: - task.status = "cancelled" - task.current_step = "cancelled" - else: - task.status = "completed" - task.current_step = "completed" - task.finished_at = _now() - task.updated_at = _now() - try: - self._write_task_report(task) - except Exception as report_err: - logger.warning(f"写入任务报告失败 task={task_id}: {report_err}") - task_kind = str(task.params.get("task_kind") or task.source).strip().lower() - write_task_kinds = {"upload", "paste", "raw_scan", "lpmm_openie", "maibot_migration", "lpmm_convert"} - has_written_chunks = (task.done_chunks > 0) or any(f.done_chunks > 0 for f in task.files) - if task_kind in write_task_kinds and has_written_chunks: - write_changed_payload = { - "task_id": task.task_id, - "task_kind": task_kind, - "status": task.status, - "done_chunks": task.done_chunks, - "finished_at": task.finished_at, - } - - if write_changed_payload: - await self._notify_write_changed(write_changed_payload) - - def _build_maibot_migration_command(self, params: Dict[str, Any]) -> List[str]: - script_path = self._resolve_migration_script() - if not script_path.exists(): - raise RuntimeError(f"迁移脚本不存在: {script_path}") - - cmd = [ - sys.executable, - str(script_path), - "--source-db", - str(params["source_db"]), - "--target-data-dir", - str(params["target_data_dir"]), - "--read-batch-size", - str(params["read_batch_size"]), - "--commit-window-rows", - str(params["commit_window_rows"]), - "--embed-batch-size", - str(params["embed_batch_size"]), - "--entity-embed-batch-size", - str(params["entity_embed_batch_size"]), - "--max-errors", - str(params["max_errors"]), - "--log-every", - str(params["log_every"]), - "--preview-limit", - str(params["preview_limit"]), - "--yes", - ] - - if params.get("embed_workers") is not None: - cmd.extend(["--embed-workers", str(params["embed_workers"])]) - if params.get("start_id") is not None: - cmd.extend(["--start-id", str(params["start_id"])]) - if params.get("end_id") is not None: - cmd.extend(["--end-id", str(params["end_id"])]) - if params.get("time_from"): - cmd.extend(["--time-from", str(params["time_from"])]) - if params.get("time_to"): - cmd.extend(["--time-to", str(params["time_to"])]) - - for sid in params.get("stream_ids") or []: - cmd.extend(["--stream-id", str(sid)]) - for gid in params.get("group_ids") or []: - cmd.extend(["--group-id", str(gid)]) - for uid in params.get("user_ids") or []: - cmd.extend(["--user-id", str(uid)]) - - if params.get("reset_state"): - cmd.append("--reset-state") - if params.get("no_resume"): - cmd.append("--no-resume") - if params.get("dry_run"): - cmd.append("--dry-run") - if params.get("verify_only"): - cmd.append("--verify-only") - - return cmd - - async def _ensure_maibot_migration_chunk( - self, - task_id: str, - file_id: str, - *, - chunk_type: str = "maibot_migration", - preview: str = "MaiBot chat_history 迁移任务", - ) -> str: - chunk_id = f"{file_id}_{chunk_type}" - async with self._lock: - task = self._tasks.get(task_id) - if not task: - return chunk_id - f = 
self._find_file(task, file_id) - if not f: - return chunk_id - if not f.chunks: - f.chunks = [ - ImportChunkRecord( - chunk_id=chunk_id, - index=0, - chunk_type=chunk_type, - status="queued", - step="queued", - progress=0.0, - content_preview=preview, - ) - ] - f.total_chunks = 1 - f.done_chunks = 0 - f.failed_chunks = 0 - f.cancelled_chunks = 0 - f.progress = 0.0 - f.updated_at = _now() - self._recompute_task_progress(task) - else: - chunk_id = f.chunks[0].chunk_id - return chunk_id - - async def _refresh_maibot_progress_from_state( - self, - task_id: str, - file_id: str, - chunk_id: str, - state_path: Path, - ) -> None: - if not state_path.exists(): - return - try: - payload = json.loads(state_path.read_text(encoding="utf-8")) - except Exception: - return - - stats = payload.get("stats", {}) if isinstance(payload, dict) else {} - if not isinstance(stats, dict): - stats = {} - - total = max(0, _coerce_int(stats.get("source_matched_total", 0), 0)) - scanned = max(0, _coerce_int(stats.get("scanned_rows", 0), 0)) - bad = max(0, _coerce_int(stats.get("bad_rows", 0), 0)) - done = max(0, scanned - bad) - migrated = max(0, _coerce_int(stats.get("migrated_rows", 0), 0)) - last_id = max(0, _coerce_int(stats.get("last_committed_id", 0), 0)) - - if total <= 0: - total = max(1, scanned) - - progress = max(0.0, min(1.0, float(scanned) / float(total))) if total > 0 else 0.0 - preview = f"scanned={scanned}/{total}, migrated={migrated}, bad={bad}, last_id={last_id}" - - async with self._lock: - task = self._tasks.get(task_id) - if not task: - return - f = self._find_file(task, file_id) - if not f: - return - c = self._find_chunk(f, chunk_id) - if c: - if c.status not in {"completed", "failed", "cancelled"}: - c.status = "writing" - c.step = "migrating" - c.progress = progress - c.content_preview = preview - c.updated_at = _now() - f.total_chunks = total - f.done_chunks = done - f.failed_chunks = bad - f.cancelled_chunks = 0 - f.progress = progress - if f.status not in {"failed", "cancelled"}: - f.status = "writing" - f.current_step = "migrating" - f.updated_at = _now() - self._recompute_task_progress(task) - - async def _terminate_process(self, process: asyncio.subprocess.Process) -> None: - if process.returncode is not None: - return - try: - process.terminate() - await asyncio.wait_for(process.wait(), timeout=5.0) - except Exception: - try: - process.kill() - await asyncio.wait_for(process.wait(), timeout=3.0) - except Exception: - pass - - async def _reload_stores_after_external_migration(self) -> None: - async with self._storage_lock: - try: - if self.plugin.vector_store and self.plugin.vector_store.has_data(): - self.plugin.vector_store.load() - except Exception as e: - logger.warning(f"迁移后重载 VectorStore 失败: {e}") - try: - if self.plugin.graph_store and self.plugin.graph_store.has_data(): - self.plugin.graph_store.load() - except Exception as e: - logger.warning(f"迁移后重载 GraphStore 失败: {e}") - - async def _process_maibot_migration(self, task_id: str, file_record: ImportFileRecord) -> None: - await self._set_file_strategy(task_id, file_record.file_id, "maibot_migration") - await self._set_file_state(task_id, file_record.file_id, "preparing", "preparing") - chunk_id = await self._ensure_maibot_migration_chunk( - task_id, - file_record.file_id, - chunk_type="maibot_migration", - preview="MaiBot chat_history 迁移任务", - ) - await self._set_chunk_state(task_id, file_record.file_id, chunk_id, "writing", "migrating", 0.0) - - task = self._tasks.get(task_id) - if not task: - await self._set_file_failed(task_id, 
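`_terminate_process` escalates from a polite terminate to a hard kill with short grace periods. The same pattern in isolation (timeouts copied from the code above):

```python
import asyncio

async def terminate_gracefully(process: asyncio.subprocess.Process) -> None:
    if process.returncode is not None:
        return  # already exited, nothing to reap
    try:
        process.terminate()  # SIGTERM: let the child flush and exit
        await asyncio.wait_for(process.wait(), timeout=5.0)
    except Exception:
        try:
            process.kill()  # SIGKILL: the child ignored the grace period
            await asyncio.wait_for(process.wait(), timeout=3.0)
        except Exception:
            pass  # give up; the event loop reaps the zombie later
```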
file_record.file_id, "任务不存在") - return - params = dict(task.params) - - command = self._build_maibot_migration_command(params) - project_root = self._resolve_repo_root() - state_path = Path(params["target_data_dir"]) / "migration_state" / "chat_history_resume.json" - report_path = Path(params["target_data_dir"]) / "migration_state" / "chat_history_report.json" - - logger.info(f"开始执行 MaiBot 迁移任务: {' '.join(command)}") - process = await asyncio.create_subprocess_exec( - *command, - cwd=str(project_root), - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, - ) - - stdout_lines: List[str] = [] - stderr_lines: List[str] = [] - - async def _drain(stream: Optional[asyncio.StreamReader], target: List[str]) -> None: - if stream is None: - return - while True: - line = await stream.readline() - if not line: - break - text = line.decode("utf-8", errors="replace").strip() - if not text: - continue - target.append(text) - if len(target) > 120: - del target[:-120] - - drain_tasks = [ - asyncio.create_task(_drain(process.stdout, stdout_lines)), - asyncio.create_task(_drain(process.stderr, stderr_lines)), - ] - - cancelled = False - return_code: Optional[int] = None - try: - while True: - if await self._is_cancel_requested(task_id): - cancelled = True - await self._terminate_process(process) - break - - await self._refresh_maibot_progress_from_state(task_id, file_record.file_id, chunk_id, state_path) - try: - return_code = await asyncio.wait_for(process.wait(), timeout=1.0) - break - except asyncio.TimeoutError: - continue - finally: - await asyncio.gather(*drain_tasks, return_exceptions=True) - - if cancelled: - await self._set_chunk_cancelled(task_id, file_record.file_id, chunk_id, "任务已取消") - await self._set_file_cancelled(task_id, file_record.file_id, "任务已取消") - return - - await self._refresh_maibot_progress_from_state(task_id, file_record.file_id, chunk_id, state_path) - - report: Dict[str, Any] = {} - if report_path.exists(): - try: - report = json.loads(report_path.read_text(encoding="utf-8")) - except Exception: - report = {} - - stats = report.get("stats", {}) if isinstance(report, dict) else {} - if not isinstance(stats, dict): - stats = {} - bad_rows = max(0, _coerce_int(stats.get("bad_rows", 0), 0)) - - if return_code in {0, 2}: - await self._set_file_state(task_id, file_record.file_id, "saving", "saving") - await self._reload_stores_after_external_migration() - - async with self._lock: - task2 = self._tasks.get(task_id) - if not task2: - return - f = self._find_file(task2, file_record.file_id) - if not f: - return - c = self._find_chunk(f, chunk_id) - if c and c.status not in {"cancelled", "failed"}: - c.status = "completed" - c.step = "completed" - c.progress = 1.0 - c.updated_at = _now() - if f.total_chunks <= 0: - f.total_chunks = 1 - if f.done_chunks + f.failed_chunks <= 0: - f.done_chunks = f.total_chunks - bad_rows - f.failed_chunks = bad_rows - f.done_chunks = max(0, min(f.done_chunks, f.total_chunks)) - f.failed_chunks = max(0, min(f.failed_chunks, f.total_chunks)) - f.cancelled_chunks = 0 - f.progress = 1.0 - f.status = "completed" - f.current_step = "completed" - if bad_rows > 0 and not f.error: - f.error = f"迁移完成,但存在坏行: {bad_rows}" - f.updated_at = _now() - self._recompute_task_progress(task2) - return - - fail_reason = "" - if isinstance(report, dict): - fail_reason = str(report.get("fail_reason") or "").strip() - tail = (stderr_lines[-1] if stderr_lines else "") or (stdout_lines[-1] if stdout_lines else "") - detail = fail_reason or tail or f"迁移进程退出码: 
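The migration runner combines three concerns: bounded tail buffers for child output, a one-second `wait_for` timeout that doubles as the cancel-polling interval, and termination on cancel. A condensed sketch of that loop (merging both streams for brevity, where the plugin drains stdout and stderr separately; `is_cancelled` is an async predicate supplied by the caller):

```python
import asyncio
from typing import Awaitable, Callable, List, Optional, Tuple

async def run_with_cancel(
    cmd: List[str],
    is_cancelled: Callable[[], Awaitable[bool]],
    tail_limit: int = 120,
) -> Tuple[Optional[int], List[str]]:
    process = await asyncio.create_subprocess_exec(
        *cmd,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.STDOUT,
    )
    tail: List[str] = []

    async def drain() -> None:
        assert process.stdout is not None
        while True:
            line = await process.stdout.readline()
            if not line:
                break
            text = line.decode("utf-8", errors="replace").strip()
            if text:
                tail.append(text)
                if len(tail) > tail_limit:
                    del tail[:-tail_limit]  # keep only the newest lines

    drainer = asyncio.create_task(drain())
    return_code: Optional[int] = None
    try:
        while True:
            if await is_cancelled():
                process.terminate()  # the plugin uses its graceful helper here
                break
            try:
                # the 1s timeout doubles as the cancel-flag polling interval
                return_code = await asyncio.wait_for(process.wait(), timeout=1.0)
                break
            except asyncio.TimeoutError:
                continue
    finally:
        await asyncio.gather(drainer, return_exceptions=True)
    return return_code, tail
```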
{return_code}" - await self._set_chunk_failed(task_id, file_record.file_id, chunk_id, detail) - await self._set_file_failed(task_id, file_record.file_id, detail) - - def _resolve_convert_script(self) -> Path: - return Path(__file__).resolve().parents[2] / "scripts" / "convert_lpmm.py" - - def _cleanup_old_backups(self) -> None: - keep = max(0, self._cfg_int("web.import.convert.keep_backup_count", 3)) - backup_root = self._resolve_backup_root() - if not backup_root.exists() or keep <= 0: - return - dirs = [p for p in backup_root.iterdir() if p.is_dir() and p.name.startswith("lpmm_convert_")] - dirs.sort(key=lambda p: p.stat().st_mtime, reverse=True) - for old in dirs[keep:]: - try: - shutil.rmtree(old, ignore_errors=True) - except Exception: - pass - - def _verify_convert_output(self, output_dir: Path) -> Dict[str, Any]: - vectors = output_dir / "vectors" - graph = output_dir / "graph" - metadata = output_dir / "metadata" - checks = { - "vectors_exists": vectors.exists(), - "graph_exists": graph.exists(), - "metadata_exists": metadata.exists(), - "vectors_nonempty": vectors.exists() and any(vectors.iterdir()), - "graph_nonempty": graph.exists() and any(graph.iterdir()), - "metadata_nonempty": metadata.exists() and any(metadata.iterdir()), - } - checks["ok"] = checks["vectors_exists"] and checks["graph_exists"] and checks["metadata_exists"] - return checks - - async def _preflight_convert_runtime(self) -> Tuple[bool, str]: - """使用当前服务解释器做 convert 依赖预检,避免子进程报错信息不透明。""" - probe_code = ( - "import importlib\n" - "mods=['networkx','scipy','pyarrow']\n" - "failed=[]\n" - "for m in mods:\n" - " try:\n" - " importlib.import_module(m)\n" - " except Exception as e:\n" - " failed.append(f'{m}:{e.__class__.__name__}:{e}')\n" - "print('OK' if not failed else ';'.join(failed))\n" - ) - try: - probe = await asyncio.create_subprocess_exec( - sys.executable, - "-c", - probe_code, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, - ) - stdout, stderr = await asyncio.wait_for(probe.communicate(), timeout=20.0) - except Exception as e: - return False, f"依赖预检执行失败: {e}" - - out = (stdout or b"").decode("utf-8", errors="replace").strip() - err = (stderr or b"").decode("utf-8", errors="replace").strip() - if probe.returncode != 0: - detail = err or out or f"return_code={probe.returncode}" - return False, f"依赖预检失败 (python={sys.executable}): {detail}" - if out != "OK": - return False, f"依赖预检失败 (python={sys.executable}): {out}" - return True, "" - - async def _process_lpmm_convert(self, task_id: str, file_record: ImportFileRecord) -> None: - await self._set_file_strategy(task_id, file_record.file_id, "lpmm_convert") - await self._set_file_state(task_id, file_record.file_id, "preparing", "preflight") - chunk_id = await self._ensure_maibot_migration_chunk( - task_id, - file_record.file_id, - chunk_type="lpmm_convert", - preview="LPMM 二进制转换任务", - ) - await self._set_chunk_state(task_id, file_record.file_id, chunk_id, "writing", "converting", 0.05) - - task = self._tasks.get(task_id) - if not task: - await self._set_file_failed(task_id, file_record.file_id, "任务不存在") - return - params = dict(task.params) - source_dir = Path(params.get("source_path") or "") - target_dir = Path(params.get("target_path") or "") - if not source_dir.exists() or not source_dir.is_dir(): - await self._set_file_failed(task_id, file_record.file_id, f"输入目录无效: {source_dir}") - return - if not target_dir.exists() or not target_dir.is_dir(): - await self._set_file_failed(task_id, file_record.file_id, f"目标目录无效: {target_dir}") - return 
- - script_path = self._resolve_convert_script() - if not script_path.exists(): - await self._set_file_failed(task_id, file_record.file_id, f"转换脚本不存在: {script_path}") - return - - runtime_ok, runtime_detail = await self._preflight_convert_runtime() - if not runtime_ok: - await self._set_file_failed(task_id, file_record.file_id, runtime_detail) - await self._set_chunk_failed(task_id, file_record.file_id, chunk_id, runtime_detail) - return - - required_inputs = ["paragraph.parquet", "entity.parquet"] - if not any((source_dir / name).exists() for name in required_inputs): - await self._set_file_failed( - task_id, - file_record.file_id, - f"输入目录缺少必要文件,至少需要其一: {', '.join(required_inputs)}", - ) - return - - staging_root = self._resolve_staging_root() - staging_root.mkdir(parents=True, exist_ok=True) - staging_dir = staging_root / f"lpmm_convert_{task_id}" - if staging_dir.exists(): - shutil.rmtree(staging_dir, ignore_errors=True) - staging_dir.mkdir(parents=True, exist_ok=True) - - # 简单空间预检:至少保留 512MB - usage = shutil.disk_usage(str(target_dir)) - if usage.free < 512 * 1024 * 1024: - await self._set_file_failed(task_id, file_record.file_id, "磁盘剩余空间不足(<512MB)") - return - - cmd = [ - sys.executable, - str(script_path), - "--input", - str(source_dir), - "--output", - str(staging_dir), - "--dim", - str(params.get("dimension", 384)), - "--batch-size", - str(params.get("batch_size", 1024)), - ] - process = await asyncio.create_subprocess_exec( - *cmd, - cwd=str(self._resolve_repo_root()), - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, - ) - stdout_lines: List[str] = [] - stderr_lines: List[str] = [] - - async def _drain(stream: Optional[asyncio.StreamReader], target: List[str]) -> None: - if stream is None: - return - while True: - line = await stream.readline() - if not line: - break - text = line.decode("utf-8", errors="replace").strip() - if text: - target.append(text) - if len(target) > 120: - del target[:-120] - - drain_tasks = [ - asyncio.create_task(_drain(process.stdout, stdout_lines)), - asyncio.create_task(_drain(process.stderr, stderr_lines)), - ] - - cancelled = False - return_code: Optional[int] = None - try: - while True: - if await self._is_cancel_requested(task_id): - cancelled = True - await self._terminate_process(process) - break - try: - return_code = await asyncio.wait_for(process.wait(), timeout=1.0) - break - except asyncio.TimeoutError: - continue - finally: - await asyncio.gather(*drain_tasks, return_exceptions=True) - - if cancelled: - await self._set_chunk_cancelled(task_id, file_record.file_id, chunk_id, "任务已取消") - await self._set_file_cancelled(task_id, file_record.file_id, "任务已取消") - return - if return_code != 0: - detail = (stderr_lines[-1] if stderr_lines else "") or (stdout_lines[-1] if stdout_lines else "") - await self._set_file_failed(task_id, file_record.file_id, detail or f"转换失败,退出码: {return_code}") - await self._set_chunk_failed(task_id, file_record.file_id, chunk_id, detail or f"退出码: {return_code}") - return - - await self._set_chunk_state(task_id, file_record.file_id, chunk_id, "writing", "verifying", 0.65) - verify = self._verify_convert_output(staging_dir) - async with self._lock: - t = self._tasks.get(task_id) - if t: - t.artifact_paths["staging_dir"] = str(staging_dir) - t.artifact_paths["verify"] = json.dumps(verify, ensure_ascii=False) - if not verify.get("ok"): - await self._set_file_failed(task_id, file_record.file_id, f"校验失败: {verify}") - await self._set_chunk_failed(task_id, file_record.file_id, chunk_id, f"校验失败: {verify}") - 
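Before launching the converter, the code recreates a per-task staging directory and enforces a 512 MB free-space floor on the destination filesystem. A sketch of that precheck as a helper (hypothetical `prepare_staging`; the plugin reports the failure via `_set_file_failed` instead of raising):

```python
import shutil
from pathlib import Path

MIN_FREE_BYTES = 512 * 1024 * 1024  # same 512 MB floor as above

def prepare_staging(staging_root: Path, task_id: str, target_dir: Path) -> Path:
    free = shutil.disk_usage(str(target_dir)).free
    if free < MIN_FREE_BYTES:
        raise RuntimeError(f"insufficient disk space: {free} bytes free (<512MB)")
    staging_dir = staging_root / f"lpmm_convert_{task_id}"
    if staging_dir.exists():
        # a stale dir from a crashed run must not contaminate the new output
        shutil.rmtree(staging_dir, ignore_errors=True)
    staging_dir.mkdir(parents=True, exist_ok=True)
    return staging_dir
```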
return - - enable_switch = _coerce_bool(self._cfg("web.import.convert.enable_staging_switch", True), True) - if not enable_switch: - await self._set_file_failed(task_id, file_record.file_id, "未启用 staging 切换") - await self._set_chunk_failed(task_id, file_record.file_id, chunk_id, "未启用 staging 切换") - return - - await self._set_chunk_state(task_id, file_record.file_id, chunk_id, "writing", "switching", 0.85) - backup_root = self._resolve_backup_root() - backup_root.mkdir(parents=True, exist_ok=True) - backup_dir = backup_root / f"lpmm_convert_{task_id}_{int(_now())}" - backup_dir.mkdir(parents=True, exist_ok=True) - - switched = False - rollback_info: Dict[str, Any] = {"attempted": True, "restored": False, "error": ""} - moved_items: List[Tuple[Path, Path]] = [] - try: - for name in ("vectors", "graph", "metadata"): - src_current = target_dir / name - src_new = staging_dir / name - if not src_new.exists(): - raise RuntimeError(f"staging 缺少目录: {src_new}") - if src_current.exists(): - dst_backup = backup_dir / name - shutil.move(str(src_current), str(dst_backup)) - moved_items.append((dst_backup, src_current)) - shutil.move(str(src_new), str(src_current)) - switched = True - except Exception as switch_err: - rollback_info["error"] = str(switch_err) - # 尝试回滚 - for src_backup, dst_original in moved_items: - if src_backup.exists() and not dst_original.exists(): - try: - shutil.move(str(src_backup), str(dst_original)) - except Exception: - pass - rollback_info["restored"] = True - async with self._lock: - t = self._tasks.get(task_id) - if t: - t.rollback_info = rollback_info - await self._set_file_failed(task_id, file_record.file_id, f"切换失败并回滚: {switch_err}") - await self._set_chunk_failed(task_id, file_record.file_id, chunk_id, f"switch failed: {switch_err}") - return - - if switched: - async with self._lock: - t = self._tasks.get(task_id) - if t: - t.rollback_info = rollback_info - t.artifact_paths["backup_dir"] = str(backup_dir) - self._cleanup_old_backups() - try: - await self._reload_stores_after_external_migration() - except Exception as reload_err: - logger.warning(f"转换后重载存储失败: {reload_err}") - - await self._set_chunk_completed(task_id, file_record.file_id, chunk_id) - async with self._lock: - t = self._tasks.get(task_id) - if not t: - return - f = self._find_file(t, file_record.file_id) - if not f: - return - f.total_chunks = 1 - f.done_chunks = 1 - f.failed_chunks = 0 - f.cancelled_chunks = 0 - f.progress = 1.0 - f.status = "completed" - f.current_step = "completed" - f.updated_at = _now() - self._recompute_task_progress(t) - - async def _process_temporal_backfill(self, task_id: str, file_record: ImportFileRecord) -> None: - await self._set_file_strategy(task_id, file_record.file_id, "temporal_backfill") - await self._set_file_state(task_id, file_record.file_id, "preparing", "backfilling") - chunk_id = await self._ensure_maibot_migration_chunk( - task_id, - file_record.file_id, - chunk_type="temporal_backfill", - preview="时序字段回填任务", - ) - await self._set_chunk_state(task_id, file_record.file_id, chunk_id, "writing", "backfilling", 0.2) - - task = self._tasks.get(task_id) - if not task: - await self._set_file_failed(task_id, file_record.file_id, "任务不存在") - return - params = dict(task.params) - target_dir = Path(file_record.source_path or "") - metadata_dir = target_dir / "metadata" - if not metadata_dir.exists(): - await self._set_file_failed(task_id, file_record.file_id, f"metadata 目录不存在: {metadata_dir}") - return - - dry_run = bool(params.get("dry_run")) - no_created_fallback = 
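The staging switch moves the live `vectors`/`graph`/`metadata` directories aside into a backup location, promotes the staging output in their place, and on any error moves the backups straight back. The same logic in isolation (hypothetical `switch_with_backup`; not crash-atomic, but it restores the previous state for in-process failures):

```python
import shutil
from pathlib import Path
from typing import List, Tuple

def switch_with_backup(
    target_dir: Path,
    staging_dir: Path,
    backup_dir: Path,
    names: Tuple[str, ...] = ("vectors", "graph", "metadata"),
) -> None:
    backup_dir.mkdir(parents=True, exist_ok=True)
    moved: List[Tuple[Path, Path]] = []  # (backup_location, original_location)
    try:
        for name in names:
            live, fresh = target_dir / name, staging_dir / name
            if not fresh.exists():
                raise RuntimeError(f"staging is missing: {fresh}")
            if live.exists():
                shutil.move(str(live), str(backup_dir / name))
                moved.append((backup_dir / name, live))
            shutil.move(str(fresh), str(live))
    except Exception:
        # Best-effort rollback: return every moved dir to where it came from.
        for backup, original in moved:
            if backup.exists() and not original.exists():
                shutil.move(str(backup), str(original))
        raise
```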
bool(params.get("no_created_fallback")) - limit = max(1, _coerce_int(params.get("limit"), 100000)) - - store = MetadataStore(data_dir=metadata_dir) - updated = 0 - candidates = 0 - try: - store.connect() - summary = store.backfill_temporal_metadata_from_created_at( - limit=limit, - dry_run=dry_run, - no_created_fallback=no_created_fallback, - ) - candidates = int(summary.get("candidates", 0)) - updated = int(summary.get("updated", 0)) - finally: - try: - store.close() - except Exception: - pass - - async with self._lock: - t = self._tasks.get(task_id) - if t: - t.artifact_paths["temporal_backfill"] = json.dumps( - { - "target_dir": str(target_dir), - "dry_run": dry_run, - "no_created_fallback": no_created_fallback, - "limit": limit, - "candidates": candidates, - "updated": updated, - }, - ensure_ascii=False, - ) - await self._set_chunk_completed(task_id, file_record.file_id, chunk_id) - async with self._lock: - t = self._tasks.get(task_id) - if not t: - return - f = self._find_file(t, file_record.file_id) - if not f: - return - f.total_chunks = 1 - f.done_chunks = 1 - f.failed_chunks = 0 - f.cancelled_chunks = 0 - f.progress = 1.0 - f.status = "completed" - f.current_step = "completed" - f.updated_at = _now() - self._recompute_task_progress(t) - - async def _process_file( - self, - task_id: str, - file_record: ImportFileRecord, - file_semaphore: asyncio.Semaphore, - chunk_semaphore: asyncio.Semaphore, - ) -> None: - async with file_semaphore: - await self._set_file_state(task_id, file_record.file_id, "preparing", "preparing") - if await self._is_cancel_requested(task_id): - await self._set_file_cancelled(task_id, file_record.file_id, "任务已取消") - return - - try: - content = await self._read_file_content(file_record) - content_hash = hashlib.md5(content.encode("utf-8", errors="ignore")).hexdigest() - file_record.content_hash = content_hash - task = self._tasks.get(task_id) - if task: - dedupe_policy = str(task.params.get("dedupe_policy") or "none") - force = bool(task.params.get("force")) - if dedupe_policy != "none" and not force: - async with self._lock: - if self._is_manifest_hit(file_record, content_hash, dedupe_policy): - task2 = self._tasks.get(task_id) - if task2: - f = self._find_file(task2, file_record.file_id) - if f: - f.status = "completed" - f.current_step = "skipped" - f.progress = 1.0 - f.total_chunks = 0 - f.done_chunks = 0 - f.failed_chunks = 0 - f.cancelled_chunks = 0 - f.detected_strategy_type = "skipped" - f.error = "" - f.updated_at = _now() - self._recompute_task_progress(task2) - return - if file_record.input_mode == "json": - await self._process_json_file(task_id, file_record, content, chunk_semaphore) - else: - await self._process_text_file(task_id, file_record, content, chunk_semaphore) - task3 = self._tasks.get(task_id) - if task3: - dedupe_policy = str(task3.params.get("dedupe_policy") or "none") - f3 = self._find_file(task3, file_record.file_id) - if dedupe_policy != "none" and f3 and f3.status == "completed": - async with self._lock: - self._record_manifest_import(file_record, content_hash, dedupe_policy, task_id) - except Exception as e: - await self._set_file_failed(task_id, file_record.file_id, str(e)) - - async def _read_file_content(self, file_record: ImportFileRecord) -> str: - if file_record.inline_content is not None: - return file_record.inline_content - if file_record.source_path and Path(file_record.source_path).exists(): - data = Path(file_record.source_path).read_bytes() - try: - return data.decode("utf-8") - except UnicodeDecodeError: - return 
data.decode("utf-8", errors="replace") - if file_record.temp_path and Path(file_record.temp_path).exists(): - data = Path(file_record.temp_path).read_bytes() - try: - return data.decode("utf-8") - except UnicodeDecodeError: - return data.decode("utf-8", errors="replace") - raise RuntimeError("读取文件失败:输入内容缺失") - - async def _process_text_file( - self, - task_id: str, - file_record: ImportFileRecord, - content: str, - chunk_semaphore: asyncio.Semaphore, - ) -> None: - task = self._tasks[task_id] - async with self._lock: - t = self._tasks.get(task_id) - if t and not t.schema_detected: - t.schema_detected = "plain_text" - strategy = self._determine_strategy( - file_record.name, - content, - task.params["strategy_override"], - chat_log=bool(task.params.get("chat_log")), - ) - await self._set_file_strategy(task_id, file_record.file_id, strategy) - await self._set_file_state(task_id, file_record.file_id, "splitting", "splitting") - await self._ensure_embedding_runtime_ready() - - chunks = strategy.split(content) - selected_chunks = list(chunks) - if file_record.retry_mode == "chunk": - retry_index_set = set() - for idx in file_record.retry_chunk_indexes: - try: - retry_index_set.add(int(idx)) - except Exception: - continue - selected_chunks = [chunk for chunk in chunks if int(chunk.chunk.index) in retry_index_set] - if not selected_chunks: - raise RuntimeError("失败分块重试索引无效,未匹配到可执行分块") - logger.info( - "重试任务按失败分块执行: " - f"file={file_record.name} " - f"selected={len(selected_chunks)} " - f"total={len(chunks)}" - ) - - await self._register_chunks(task_id, file_record.file_id, selected_chunks) - - await self._set_file_state(task_id, file_record.file_id, "extracting", "extracting") - model_cfg = None - if task.params["llm_enabled"]: - model_cfg = await self._select_model() - - jobs = [] - for chunk in selected_chunks: - jobs.append( - asyncio.create_task( - self._process_text_chunk( - task_id=task_id, - file_record=file_record, - chunk=chunk, - strategy=strategy, - llm_enabled=task.params["llm_enabled"], - model_cfg=model_cfg, - chunk_semaphore=chunk_semaphore, - chat_log=bool(task.params.get("chat_log")), - chat_reference_time=str(task.params.get("chat_reference_time") or "").strip() or None, - ) - ) - ) - await asyncio.gather(*jobs, return_exceptions=True) - - if await self._is_cancel_requested(task_id): - await self._set_file_cancelled(task_id, file_record.file_id, "任务已取消") - return - - await self._set_file_state(task_id, file_record.file_id, "saving", "saving") - async with self._storage_lock: - self.plugin.vector_store.save() - self.plugin.graph_store.save() - - async with self._lock: - task = self._tasks.get(task_id) - if not task: - return - f = self._find_file(task, file_record.file_id) - if not f: - return - if f.failed_chunks > 0: - f.status = "failed" - f.current_step = "failed" - if not f.error: - f.error = f"存在失败分块: {f.failed_chunks}" - elif task.status == "cancel_requested": - f.status = "cancelled" - f.current_step = "cancelled" - else: - f.status = "completed" - f.current_step = "completed" - f.progress = 1.0 - f.updated_at = _now() - self._recompute_task_progress(task) - async def _process_text_chunk( - self, - task_id: str, - file_record: ImportFileRecord, - chunk: ProcessedChunk, - strategy: Any, - llm_enabled: bool, - model_cfg: Any, - chunk_semaphore: asyncio.Semaphore, - chat_log: bool = False, - chat_reference_time: Optional[str] = None, - ) -> None: - async with chunk_semaphore: - chunk_id = chunk.chunk.chunk_id - if await self._is_cancel_requested(task_id): - await 
self._set_chunk_cancelled(task_id, file_record.file_id, chunk_id, "任务已取消") - return - - await self._set_chunk_state(task_id, file_record.file_id, chunk_id, "extracting", "extracting", 0.25) - - processed = chunk - rescue_strategy = self._chunk_rescue(chunk, file_record.name) - current_strategy = strategy - if rescue_strategy: - chunk.type = StrategyKnowledgeType.QUOTE - chunk.flags.verbatim = True - chunk.flags.requires_llm = False - current_strategy = rescue_strategy - try: - if llm_enabled and chunk.flags.requires_llm: - processed = await current_strategy.extract( - chunk, - lambda prompt: self._llm_call(prompt, model_cfg), - ) - elif chunk.type == StrategyKnowledgeType.QUOTE: - processed = await current_strategy.extract(chunk) - except Exception as e: - await self._set_chunk_failed(task_id, file_record.file_id, chunk_id, f"抽取失败: {e}") - return - - if await self._is_cancel_requested(task_id): - await self._set_chunk_cancelled(task_id, file_record.file_id, chunk_id, "任务已取消") - return - - await self._set_chunk_state(task_id, file_record.file_id, chunk_id, "writing", "writing", 0.7) - try: - time_meta = None - if chat_log and llm_enabled and model_cfg is not None: - time_meta = await self._extract_chat_time_meta_with_llm( - processed.chunk.text, - model_cfg, - reference_time=chat_reference_time, - ) - async with self._storage_lock: - await self._persist_processed_chunk(file_record, processed, time_meta=time_meta) - await self._set_chunk_completed(task_id, file_record.file_id, chunk_id) - except Exception as e: - await self._set_chunk_failed(task_id, file_record.file_id, chunk_id, f"写入失败: {e}") - - async def _process_json_file( - self, - task_id: str, - file_record: ImportFileRecord, - content: str, - chunk_semaphore: asyncio.Semaphore, - ) -> None: - await self._set_file_strategy(task_id, file_record.file_id, "json") - await self._set_file_state(task_id, file_record.file_id, "splitting", "splitting") - await self._ensure_embedding_runtime_ready() - - try: - data = json.loads(content) - except Exception as e: - raise RuntimeError(f"JSON 解析失败: {e}") - - schema = self._detect_json_schema(data) - async with self._lock: - task = self._tasks.get(task_id) - if task: - task.schema_detected = schema - task.updated_at = _now() - units = self._build_json_units(data, file_record.file_id, file_record.name, schema) - await self._register_json_units(task_id, file_record.file_id, units) - - await self._set_file_state(task_id, file_record.file_id, "extracting", "extracting") - jobs = [ - asyncio.create_task(self._process_json_unit(task_id, file_record, unit, chunk_semaphore)) - for unit in units - ] - await asyncio.gather(*jobs, return_exceptions=True) - - if await self._is_cancel_requested(task_id): - await self._set_file_cancelled(task_id, file_record.file_id, "任务已取消") - return - - await self._set_file_state(task_id, file_record.file_id, "saving", "saving") - async with self._storage_lock: - self.plugin.vector_store.save() - self.plugin.graph_store.save() - - async with self._lock: - task = self._tasks.get(task_id) - if not task: - return - f = self._find_file(task, file_record.file_id) - if not f: - return - if f.failed_chunks > 0: - f.status = "failed" - f.current_step = "failed" - if not f.error: - f.error = f"存在失败分块: {f.failed_chunks}" - elif task.status == "cancel_requested": - f.status = "cancelled" - f.current_step = "cancelled" - else: - f.status = "completed" - f.current_step = "completed" - f.progress = 1.0 - f.updated_at = _now() - self._recompute_task_progress(task) - - def 
_detect_json_schema(self, data: Any) -> str: - if isinstance(data, dict) and isinstance(data.get("docs"), list): - return "lpmm_openie" - if isinstance(data, dict) and isinstance(data.get("paragraphs"), list): - paragraphs = data.get("paragraphs", []) - for p in paragraphs: - if isinstance(p, dict) and any( - key in p for key in ("entities", "relations", "time_meta", "source", "type", "knowledge_type") - ): - return "script_json" - return "web_json" - raise RuntimeError("不支持的 JSON 格式:需要 paragraphs 或 docs") - - def _build_json_units(self, data: Any, file_id: str, filename: str, schema: str) -> List[Dict[str, Any]]: - units: List[Dict[str, Any]] = [] - paragraphs: List[Any] = [] - entities: List[Any] = [] - relations: List[Any] = [] - - if schema in {"web_json", "script_json"}: - paragraphs = data.get("paragraphs", []) - entities = data.get("entities", []) - relations = data.get("relations", []) - elif schema == "lpmm_openie": - docs = data.get("docs", []) - for d in docs: - if not isinstance(d, dict): - continue - content = str(d.get("passage", "") or "").strip() - if not content: - continue - triples = d.get("extracted_triples", []) or [] - rels = [] - for t in triples: - if isinstance(t, list) and len(t) == 3: - rels.append( - { - "subject": str(t[0]), - "predicate": str(t[1]), - "object": str(t[2]), - } - ) - para_item = { - "content": content, - "source": f"lpmm_openie:{filename}", - "entities": d.get("extracted_entities", []) or [], - "relations": rels, - "knowledge_type": "factual", - } - paragraphs.append(para_item) - - for p in paragraphs: - paragraph = normalize_paragraph_import_item( - p, - default_source=f"web_import:{filename}", - ) - units.append( - { - "chunk_id": f"{file_id}_json_{len(units)}", - "kind": "paragraph", - "content": paragraph["content"], - "time_meta": paragraph["time_meta"], - "knowledge_type": paragraph["knowledge_type"], - "chunk_type": paragraph["knowledge_type"], - "source": paragraph["source"], - "entities": paragraph["entities"], - "relations": paragraph["relations"], - "preview": paragraph["content"][:120], - } - ) - - for e in entities: - name = str(e or "").strip() - if name: - units.append( - { - "chunk_id": f"{file_id}_json_{len(units)}", - "kind": "entity", - "name": name, - "chunk_type": "entity", - "preview": name[:120], - } - ) - - for r in relations: - if not isinstance(r, dict): - continue - s = str(r.get("subject", "")).strip() - p = str(r.get("predicate", "")).strip() - o = str(r.get("object", "")).strip() - if s and p and o: - units.append( - { - "chunk_id": f"{file_id}_json_{len(units)}", - "kind": "relation", - "subject": s, - "predicate": p, - "object": o, - "chunk_type": "relation", - "preview": f"{s} {p} {o}"[:120], - } - ) - return units - - async def _register_json_units(self, task_id: str, file_id: str, units: List[Dict[str, Any]]) -> None: - records = [ - ImportChunkRecord( - chunk_id=u["chunk_id"], - index=i, - chunk_type=u.get("chunk_type", "json"), - status="queued", - step="queued", - progress=0.0, - content_preview=str(u.get("preview", "")), - ) - for i, u in enumerate(units) - ] - async with self._lock: - task = self._tasks.get(task_id) - if not task: - return - f = self._find_file(task, file_id) - if not f: - return - f.chunks = records - f.total_chunks = len(records) - f.done_chunks = 0 - f.failed_chunks = 0 - f.cancelled_chunks = 0 - f.progress = 0.0 if records else 1.0 - f.updated_at = _now() - self._recompute_task_progress(task) - - async def _process_json_unit( - self, - task_id: str, - file_record: ImportFileRecord, - 
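Schema detection is purely shape-based: a top-level `docs` list means an LPMM OpenIE export, enriched paragraph dicts mean a script export, and bare paragraphs mean a web export; OpenIE triples are then flattened into relation dicts. Both rules as standalone functions:

```python
from typing import Any, Dict, List

def detect_json_schema(data: Any) -> str:
    if isinstance(data, dict) and isinstance(data.get("docs"), list):
        return "lpmm_openie"
    if isinstance(data, dict) and isinstance(data.get("paragraphs"), list):
        enriched = ("entities", "relations", "time_meta", "source", "type", "knowledge_type")
        for p in data["paragraphs"]:
            if isinstance(p, dict) and any(k in p for k in enriched):
                return "script_json"
        return "web_json"
    raise RuntimeError("unsupported JSON: need `paragraphs` or `docs`")

def triples_to_relations(triples: List) -> List[Dict[str, str]]:
    # [subject, predicate, object] lists become relation dicts; anything
    # that is not an exact 3-element list is silently dropped
    return [
        {"subject": str(t[0]), "predicate": str(t[1]), "object": str(t[2])}
        for t in triples
        if isinstance(t, list) and len(t) == 3
    ]

assert detect_json_schema({"docs": []}) == "lpmm_openie"
assert triples_to_relations([["a", "is", "b"], ["bad"]]) == [
    {"subject": "a", "predicate": "is", "object": "b"}
]
```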
unit: Dict[str, Any], - chunk_semaphore: asyncio.Semaphore, - ) -> None: - chunk_id = unit["chunk_id"] - async with chunk_semaphore: - if await self._is_cancel_requested(task_id): - await self._set_chunk_cancelled(task_id, file_record.file_id, chunk_id, "任务已取消") - return - - await self._set_chunk_state(task_id, file_record.file_id, chunk_id, "writing", "writing", 0.7) - try: - async with self._storage_lock: - kind = unit["kind"] - if kind == "paragraph": - content = str(unit.get("content", "")) - k_type = resolve_stored_knowledge_type( - unit.get("knowledge_type"), - content=content, - ).value - source = str(unit.get("source") or f"web_import:{file_record.name}") - para_hash = self.plugin.metadata_store.add_paragraph( - content=content, - source=source, - knowledge_type=k_type, - time_meta=unit.get("time_meta"), - ) - emb = await self.plugin.embedding_manager.encode(content) - try: - self.plugin.vector_store.add(emb.reshape(1, -1), [para_hash]) - except ValueError: - pass - for name in unit.get("entities", []) or []: - n = str(name or "").strip() - if n: - await self._add_entity_with_vector(n, source_paragraph=para_hash) - for rel in unit.get("relations", []) or []: - if not isinstance(rel, dict): - continue - s = str(rel.get("subject", "")).strip() - p = str(rel.get("predicate", "")).strip() - o = str(rel.get("object", "")).strip() - if s and p and o: - await self._add_relation(s, p, o, source_paragraph=para_hash) - elif kind == "entity": - await self._add_entity_with_vector(unit["name"]) - elif kind == "relation": - await self._add_relation(unit["subject"], unit["predicate"], unit["object"]) - else: - raise RuntimeError(f"未知 JSON 导入单元类型: {kind}") - await self._set_chunk_completed(task_id, file_record.file_id, chunk_id) - except Exception as e: - await self._set_chunk_failed(task_id, file_record.file_id, chunk_id, f"写入失败: {e}") - - def _source_label(self, file_record: ImportFileRecord) -> str: - if file_record.source_path: - return f"{file_record.source_kind}:{file_record.source_path}" - return f"web_import:{file_record.name}" - - async def _ensure_embedding_runtime_ready(self) -> None: - report = await ensure_runtime_self_check(self.plugin) - if bool(report.get("ok", False)): - return - raise RuntimeError( - "embedding runtime self-check failed: " - f"{report.get('message', 'unknown')} " - f"(configured={report.get('configured_dimension', 0)}, " - f"store={report.get('vector_store_dimension', 0)}, " - f"encoded={report.get('encoded_dimension', 0)})" - ) - - async def _persist_processed_chunk( - self, - file_record: ImportFileRecord, - processed: ProcessedChunk, - *, - time_meta: Optional[Dict[str, Any]] = None, - ) -> None: - content = processed.chunk.text - para_hash = self.plugin.metadata_store.add_paragraph( - content=content, - source=self._source_label(file_record), - knowledge_type=_storage_type_from_strategy(processed.type), - time_meta=time_meta, - ) - - emb = await self.plugin.embedding_manager.encode(content) - try: - self.plugin.vector_store.add(emb.reshape(1, -1), [para_hash]) - except ValueError: - pass - - data = processed.data or {} - entities: List[str] = [] - relations: List[Tuple[str, str, str]] = [] - - for triple in data.get("triples", []): - s = str(triple.get("subject", "")).strip() - p = str(triple.get("predicate", "")).strip() - o = str(triple.get("object", "")).strip() - if s and p and o: - relations.append((s, p, o)) - entities.extend([s, o]) - - for rel in data.get("relations", []): - s = str(rel.get("subject", "")).strip() - p = str(rel.get("predicate", 
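Paragraph persistence writes metadata first, then the embedding keyed by the returned paragraph hash, and swallows the vector store's ValueError on duplicate ids so re-imports stay idempotent. A sketch under assumed store/embedder interfaces matching the calls above:

```python
async def persist_paragraph(metadata_store, vector_store, embedder,
                            content: str, source: str, knowledge_type: str) -> str:
    # Metadata first: the returned hash becomes the vector's id, so a
    # paragraph is never embedded without a metadata row behind it.
    para_hash = metadata_store.add_paragraph(
        content=content,
        source=source,
        knowledge_type=knowledge_type,
    )
    emb = await embedder.encode(content)  # assumed to return a numpy vector
    try:
        vector_store.add(emb.reshape(1, -1), [para_hash])  # batch of one
    except ValueError:
        pass  # duplicate id: treat the re-import as an idempotent no-op
    return para_hash
```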
"")).strip() - o = str(rel.get("object", "")).strip() - if s and p and o: - relations.append((s, p, o)) - entities.extend([s, o]) - - for k in ("entities", "events", "verbatim_entities"): - for e in data.get(k, []): - name = str(e or "").strip() - if name: - entities.append(name) - - uniq_entities = list({x.strip().lower(): x.strip() for x in entities if str(x).strip()}.values()) - for name in uniq_entities: - await self._add_entity_with_vector(name, source_paragraph=para_hash) - - for s, p, o in relations: - await self._add_relation(s, p, o, source_paragraph=para_hash) - - async def _add_entity_with_vector(self, name: str, source_paragraph: str = "") -> str: - hash_value = self.plugin.metadata_store.add_entity(name=name, source_paragraph=source_paragraph) - self.plugin.graph_store.add_nodes([name]) - if hash_value not in self.plugin.vector_store: - emb = await self.plugin.embedding_manager.encode(name) - try: - self.plugin.vector_store.add(emb.reshape(1, -1), [hash_value]) - except ValueError: - pass - return hash_value - - async def _add_relation(self, subject: str, predicate: str, obj: str, source_paragraph: str = "") -> str: - await self._add_entity_with_vector(subject, source_paragraph=source_paragraph) - await self._add_entity_with_vector(obj, source_paragraph=source_paragraph) - rv_cfg = self.plugin.get_config("retrieval.relation_vectorization", {}) or {} - if not isinstance(rv_cfg, dict): - rv_cfg = {} - write_vector = bool(rv_cfg.get("enabled", False)) and bool(rv_cfg.get("write_on_import", True)) - - relation_service = getattr(self.plugin, "relation_write_service", None) - if relation_service is not None: - result = await relation_service.upsert_relation_with_vector( - subject=subject, - predicate=predicate, - obj=obj, - confidence=1.0, - source_paragraph=source_paragraph, - write_vector=write_vector, - ) - return result.hash_value - - rel_hash = self.plugin.metadata_store.add_relation( - subject=subject, - predicate=predicate, - obj=obj, - source_paragraph=source_paragraph, - confidence=1.0, - ) - self.plugin.graph_store.add_edges([(subject, obj)], relation_hashes=[rel_hash]) - try: - self.plugin.metadata_store.set_relation_vector_state(rel_hash, "none") - except Exception: - pass - return rel_hash - async def _select_model(self) -> Any: - models = llm_api.get_available_models() - if not models: - raise RuntimeError("没有可用 LLM 模型") - - config_model = str(self._cfg("advanced.extraction_model", "auto") or "auto").strip() - if config_model.lower() != "auto" and config_model in models: - return models[config_model] - - for task_name in [ - "lpmm_entity_extract", - "lpmm_rdf_build", - "embedding", - "replyer", - "utils", - "planner", - "tool_use", - ]: - if task_name in models: - return models[task_name] - - return models[next(iter(models))] - - async def _llm_call(self, prompt: str, model_config: Any) -> Dict[str, Any]: - cfg = self._llm_retry_config() - retries = int(cfg["retries"]) - last_error: Optional[Exception] = None - for attempt in range(retries + 1): - try: - success, response, _, _ = await llm_api.generate_with_model( - prompt=prompt, - model_config=model_config, - request_type="A_Memorix.WebImport", - ) - if not success or not response: - raise RuntimeError("LLM 生成失败") - - txt = str(response or "").strip() - if "```" in txt: - txt = txt.split("```json")[-1].split("```")[0].strip() - if txt.startswith("json"): - txt = txt[4:].strip() - - try: - return json.loads(txt) - except Exception: - s = txt.find("{") - e = txt.rfind("}") - if s >= 0 and e > s: - return 
json.loads(txt[s : e + 1]) - raise - except Exception as err: - last_error = err - if attempt >= retries: - break - delay = min(cfg["max_wait"], cfg["min_wait"] * (cfg["multiplier"] ** attempt)) - await asyncio.sleep(max(0.0, float(delay))) - raise RuntimeError(f"LLM 抽取失败: {last_error}") - - def _parse_reference_time(self, value: Optional[str]) -> datetime: - if not value: - return datetime.now() - text = str(value).strip() - formats = [ - "%Y/%m/%d %H:%M:%S", - "%Y/%m/%d %H:%M", - "%Y-%m-%d %H:%M:%S", - "%Y-%m-%d %H:%M", - "%Y/%m/%d", - "%Y-%m-%d", - ] - for fmt in formats: - try: - return datetime.strptime(text, fmt) - except ValueError: - continue - return datetime.now() - - async def _extract_chat_time_meta_with_llm( - self, - text: str, - model_config: Any, - *, - reference_time: Optional[str] = None, - ) -> Optional[Dict[str, Any]]: - if not str(text or "").strip(): - return None - ref_dt = self._parse_reference_time(reference_time) - reference_now = ref_dt.strftime("%Y/%m/%d %H:%M") - prompt = f"""You are a time extraction engine for chat logs. -Extract temporal information from the following chat paragraph. - -Rules: -1. Use semantic understanding, not regex matching. -2. Convert relative expressions to absolute local datetime using reference_now. -3. If a time span exists, return event_time_start/event_time_end. -4. If only one point in time exists, return event_time. -5. If no reliable time info exists, keep all event_time fields null. -6. Return JSON only. - -reference_now: {reference_now} -text: -{text} - -JSON schema: -{{ - "event_time": null, - "event_time_start": null, - "event_time_end": null, - "time_range": null, - "time_granularity": null, - "time_confidence": 0.0 -}} -""" - try: - result = await self._llm_call(prompt, model_config) - except Exception as e: - logger.warning(f"chat_log 时间语义抽取失败: {e}") - return None - - raw_time_meta = { - "event_time": result.get("event_time"), - "event_time_start": result.get("event_time_start"), - "event_time_end": result.get("event_time_end"), - "time_range": result.get("time_range"), - "time_granularity": result.get("time_granularity"), - "time_confidence": result.get("time_confidence"), - } - try: - normalized = normalize_time_meta(raw_time_meta) - except Exception: - return None - has_effective = any(k in normalized for k in ("event_time", "event_time_start", "event_time_end")) - if not has_effective: - return None - return normalized - - def _chunk_rescue(self, chunk: ProcessedChunk, filename: str) -> Optional[Any]: - if chunk.type == StrategyKnowledgeType.QUOTE: - return None - if looks_like_quote_text(chunk.chunk.text): - return QuoteStrategy(filename) - return None - - def _instantiate_strategy(self, filename: str, strategy: ImportStrategy) -> Any: - if strategy == ImportStrategy.FACTUAL: - return FactualStrategy(filename) - if strategy == ImportStrategy.QUOTE: - return QuoteStrategy(filename) - return NarrativeStrategy(filename) - - def _determine_strategy(self, filename: str, content: str, override: str, *, chat_log: bool = False) -> Any: - strategy = select_import_strategy( - content, - override=override, - chat_log=chat_log, - ) - return self._instantiate_strategy(filename, strategy) - - async def _set_file_strategy(self, task_id: str, file_id: str, strategy: Any) -> None: - if isinstance(strategy, str): - strategy_type = strategy - elif isinstance(strategy, NarrativeStrategy): - strategy_type = "narrative" - elif isinstance(strategy, FactualStrategy): - strategy_type = "factual" - elif isinstance(strategy, QuoteStrategy): - 
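`_llm_call` recovers JSON in two stages (strip markdown fences, then retry on the outermost `{...}` span) and retries the whole call with capped exponential backoff. Both pieces in isolation (backoff constants are illustrative; the plugin reads them from its retry config):

```python
import asyncio
import json

def salvage_json(text: str):
    txt = text.strip()
    if "```" in txt:
        # take whatever sits between the ```json fence and the closing fence
        txt = txt.split("```json")[-1].split("```")[0].strip()
    if txt.startswith("json"):
        txt = txt[4:].strip()
    try:
        return json.loads(txt)
    except json.JSONDecodeError:
        s, e = txt.find("{"), txt.rfind("}")
        if s >= 0 and e > s:
            return json.loads(txt[s : e + 1])  # outermost-braces fallback
        raise

async def backoff_sleep(attempt: int, min_wait: float = 1.0,
                        max_wait: float = 10.0, multiplier: float = 2.0) -> None:
    # min_wait * multiplier**attempt, clamped to max_wait, never negative
    await asyncio.sleep(max(0.0, min(max_wait, min_wait * multiplier ** attempt)))

assert salvage_json('noise {"a": 1} noise') == {"a": 1}
```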
strategy_type = "quote" - else: - strategy_type = "unknown" - - async with self._lock: - task = self._tasks.get(task_id) - if not task: - return - f = self._find_file(task, file_id) - if not f: - return - f.detected_strategy_type = strategy_type - f.updated_at = _now() - task.updated_at = _now() - - async def _register_chunks(self, task_id: str, file_id: str, chunks: List[ProcessedChunk]) -> None: - records = [ - ImportChunkRecord( - chunk_id=chunk.chunk.chunk_id, - index=index, - chunk_type=chunk.type.value, - status="queued", - step="queued", - progress=0.0, - content_preview=str(chunk.chunk.text or "")[:120], - ) - for index, chunk in enumerate(chunks) - ] - - async with self._lock: - task = self._tasks.get(task_id) - if not task: - return - f = self._find_file(task, file_id) - if not f: - return - f.chunks = records - f.total_chunks = len(records) - f.done_chunks = 0 - f.failed_chunks = 0 - f.cancelled_chunks = 0 - f.progress = 0.0 if records else 1.0 - f.updated_at = _now() - self._recompute_task_progress(task) - - async def _set_file_state(self, task_id: str, file_id: str, status: str, step: str) -> None: - if status not in FILE_STATUS: - return - async with self._lock: - task = self._tasks.get(task_id) - if not task: - return - f = self._find_file(task, file_id) - if not f: - return - f.status = status - f.current_step = step - f.updated_at = _now() - task.updated_at = _now() - if step in {"preparing", "splitting", "extracting", "writing", "saving"} and task.status in {"queued", "preparing"}: - task.status = "running" - task.current_step = "running" - - async def _set_file_failed(self, task_id: str, file_id: str, error: str) -> None: - async with self._lock: - task = self._tasks.get(task_id) - if not task: - return - f = self._find_file(task, file_id) - if not f: - return - f.status = "failed" - f.current_step = "failed" - f.error = str(error) - f.updated_at = _now() - task.updated_at = _now() - self._recompute_task_progress(task) - - async def _set_file_cancelled(self, task_id: str, file_id: str, reason: str) -> None: - async with self._lock: - task = self._tasks.get(task_id) - if not task: - return - f = self._find_file(task, file_id) - if not f: - return - f.status = "cancelled" - f.current_step = "cancelled" - f.error = reason - additional_cancelled = 0 - for chunk in f.chunks: - if chunk.status in {"completed", "failed", "cancelled"}: - continue - chunk.status = "cancelled" - chunk.step = "cancelled" - chunk.retryable = False - chunk.error = reason - chunk.progress = 1.0 - chunk.updated_at = _now() - additional_cancelled += 1 - if additional_cancelled > 0: - f.cancelled_chunks += additional_cancelled - f.progress = self._compute_ratio( - f.done_chunks + f.failed_chunks + f.cancelled_chunks, f.total_chunks - ) - f.updated_at = _now() - task.updated_at = _now() - self._recompute_task_progress(task) - - async def _set_chunk_state( - self, - task_id: str, - file_id: str, - chunk_id: str, - status: str, - step: str, - progress: float, - ) -> None: - if status not in CHUNK_STATUS: - return - async with self._lock: - task = self._tasks.get(task_id) - if not task: - return - f = self._find_file(task, file_id) - if not f: - return - c = self._find_chunk(f, chunk_id) - if not c: - return - c.status = status - c.step = step - if status in {"queued", "extracting", "writing"}: - c.error = "" - c.failed_at = "" - c.retryable = False - c.progress = max(0.0, min(1.0, float(progress))) - c.updated_at = _now() - if f.status not in {"failed", "cancelled"}: - f.status = "extracting" if status == 
"extracting" else "writing" - f.current_step = step - f.updated_at = _now() - task.updated_at = _now() - - async def _set_chunk_completed(self, task_id: str, file_id: str, chunk_id: str) -> None: - async with self._lock: - task = self._tasks.get(task_id) - if not task: - return - f = self._find_file(task, file_id) - if not f: - return - c = self._find_chunk(f, chunk_id) - if not c or c.status == "completed": - return - c.status = "completed" - c.step = "completed" - c.failed_at = "" - c.retryable = False - c.progress = 1.0 - c.updated_at = _now() - f.done_chunks += 1 - f.progress = self._compute_ratio(f.done_chunks + f.failed_chunks + f.cancelled_chunks, f.total_chunks) - f.updated_at = _now() - self._recompute_task_progress(task) - - async def _set_chunk_failed(self, task_id: str, file_id: str, chunk_id: str, error: str) -> None: - async with self._lock: - task = self._tasks.get(task_id) - if not task: - return - f = self._find_file(task, file_id) - if not f: - return - c = self._find_chunk(f, chunk_id) - if not c or c.status == "failed": - return - failed_stage = str(c.step or "").strip().lower() - if failed_stage in {"", "queued", "failed", "completed", "cancelled"}: - failed_stage = str(f.current_step or "").strip().lower() - if failed_stage in {"", "queued", "failed", "completed", "cancelled"}: - failed_stage = "unknown" - c.status = "failed" - c.step = "failed" - c.failed_at = failed_stage - c.retryable = bool(f.input_mode == "text" and failed_stage == "extracting") - c.error = str(error) - c.progress = 1.0 - c.updated_at = _now() - f.failed_chunks += 1 - f.progress = self._compute_ratio(f.done_chunks + f.failed_chunks + f.cancelled_chunks, f.total_chunks) - if not f.error: - f.error = str(error) - f.updated_at = _now() - self._recompute_task_progress(task) - - async def _set_chunk_cancelled(self, task_id: str, file_id: str, chunk_id: str, reason: str) -> None: - async with self._lock: - task = self._tasks.get(task_id) - if not task: - return - f = self._find_file(task, file_id) - if not f: - return - c = self._find_chunk(f, chunk_id) - if not c or c.status == "cancelled": - return - c.status = "cancelled" - c.step = "cancelled" - c.retryable = False - c.error = reason - c.progress = 1.0 - c.updated_at = _now() - f.cancelled_chunks += 1 - f.progress = self._compute_ratio(f.done_chunks + f.failed_chunks + f.cancelled_chunks, f.total_chunks) - f.updated_at = _now() - self._recompute_task_progress(task) - - async def _is_cancel_requested(self, task_id: str) -> bool: - async with self._lock: - task = self._tasks.get(task_id) - if not task: - return True - return task.status == "cancel_requested" - - def _find_file(self, task: ImportTaskRecord, file_id: str) -> Optional[ImportFileRecord]: - for f in task.files: - if f.file_id == file_id: - return f - return None - - def _find_chunk(self, file_record: ImportFileRecord, chunk_id: str) -> Optional[ImportChunkRecord]: - for c in file_record.chunks: - if c.chunk_id == chunk_id: - return c - return None - - def _compute_ratio(self, done: int, total: int) -> float: - if total <= 0: - return 1.0 - return max(0.0, min(1.0, float(done) / float(total))) - - def _recompute_task_progress(self, task: ImportTaskRecord) -> None: - total = 0 - done = 0 - failed = 0 - cancelled = 0 - for f in task.files: - total += f.total_chunks - done += f.done_chunks - failed += f.failed_chunks - cancelled += f.cancelled_chunks - task.total_chunks = total - task.done_chunks = done - task.failed_chunks = failed - task.cancelled_chunks = cancelled - task.progress = 
self._compute_ratio(done + failed + cancelled, total) - task.updated_at = _now() - - async def _should_cleanup_task_temp(self, task_id: str) -> bool: - async with self._lock: - task = self._tasks.get(task_id) - if not task: - return True - for f in task.files: - if f.status == "failed": - return False - return True - - def _mark_task_cancelled_locked(self, task: ImportTaskRecord, reason: str) -> None: - for f in task.files: - if f.status in {"completed", "failed", "cancelled"}: - continue - f.status = "cancelled" - f.current_step = "cancelled" - f.error = reason - additional_cancelled = 0 - for c in f.chunks: - if c.status in {"completed", "failed", "cancelled"}: - continue - c.status = "cancelled" - c.step = "cancelled" - c.retryable = False - c.error = reason - c.progress = 1.0 - c.updated_at = _now() - additional_cancelled += 1 - if additional_cancelled > 0: - f.cancelled_chunks += additional_cancelled - f.progress = self._compute_ratio( - f.done_chunks + f.failed_chunks + f.cancelled_chunks, f.total_chunks - ) - f.updated_at = _now() - task.status = "cancelled" - task.current_step = "cancelled" - task.finished_at = _now() - task.updated_at = _now() - self._recompute_task_progress(task) diff --git a/plugins/A_memorix/plugin.py b/plugins/A_memorix/plugin.py deleted file mode 100644 index 841106a4..00000000 --- a/plugins/A_memorix/plugin.py +++ /dev/null @@ -1,273 +0,0 @@ -"""A_Memorix SDK plugin entry.""" - -from __future__ import annotations - -from pathlib import Path -from typing import Any, Dict, List, Optional - -from maibot_sdk import MaiBotPlugin, Tool -from maibot_sdk.types import ToolParameterInfo, ToolParamType - -from A_memorix.core.runtime.sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel - - -def _tool_param(name: str, param_type: ToolParamType, description: str, required: bool) -> ToolParameterInfo: - return ToolParameterInfo(name=name, param_type=param_type, description=description, required=required) - - -_ADMIN_TOOL_PARAMS = [ - _tool_param("action", ToolParamType.STRING, "管理动作", True), - _tool_param("target", ToolParamType.STRING, "可选目标标识", False), -] - - -class AMemorixPlugin(MaiBotPlugin): - def __init__(self) -> None: - super().__init__() - self._plugin_root = Path(__file__).resolve().parent - self._plugin_config: Dict[str, Any] = {} - self._kernel: Optional[SDKMemoryKernel] = None - - def set_plugin_config(self, config: Dict[str, Any]) -> None: - self._plugin_config = config or {} - if self._kernel is not None: - self._kernel.close() - self._kernel = None - - async def on_load(self): - await self._get_kernel() - - async def on_unload(self): - if self._kernel is not None: - shutdown = getattr(self._kernel, "shutdown", None) - if callable(shutdown): - await shutdown() - else: - self._kernel.close() - self._kernel = None - - async def _get_kernel(self) -> SDKMemoryKernel: - if self._kernel is None: - self._kernel = SDKMemoryKernel(plugin_root=self._plugin_root, config=self._plugin_config) - await self._kernel.initialize() - return self._kernel - - async def _dispatch_admin_tool(self, method_name: str, action: str, **kwargs): - kernel = await self._get_kernel() - handler = getattr(kernel, method_name) - return await handler(action=action, **kwargs) - - @Tool( - "search_memory", - description="搜索长期记忆", - parameters=[ - _tool_param("query", ToolParamType.STRING, "查询文本", False), - _tool_param("limit", ToolParamType.INTEGER, "返回条数", False), - _tool_param("mode", ToolParamType.STRING, "search/time/hybrid/episode/aggregate", False), - _tool_param("chat_id", 
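plugin.py keeps the kernel lazily constructed: the first tool call builds and initializes it, and `set_plugin_config` drops the old instance so the next call rebuilds it with fresh settings. A minimal sketch of that lifecycle (stand-in `_Kernel` and a hypothetical holder class; the plugin inlines this on `AMemorixPlugin`):

```python
from typing import Any, Dict, Optional

class _Kernel:
    """Stand-in for SDKMemoryKernel: async init, sync close."""
    def __init__(self, config: Dict[str, Any]) -> None:
        self.config = config
    async def initialize(self) -> None: ...
    def close(self) -> None: ...

class LazyKernel:
    def __init__(self) -> None:
        self._config: Dict[str, Any] = {}
        self._kernel: Optional[_Kernel] = None

    def set_config(self, config: Dict[str, Any]) -> None:
        # A config change invalidates the running kernel; the next get()
        # rebuilds it with the fresh settings.
        self._config = config or {}
        if self._kernel is not None:
            self._kernel.close()
            self._kernel = None

    async def get(self) -> _Kernel:
        if self._kernel is None:
            self._kernel = _Kernel(self._config)
            await self._kernel.initialize()
        return self._kernel
```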
ToolParamType.STRING, "聊天流 ID", False), - _tool_param("person_id", ToolParamType.STRING, "人物 ID", False), - _tool_param("time_start", ToolParamType.FLOAT, "起始时间戳", False), - _tool_param("time_end", ToolParamType.FLOAT, "结束时间戳", False), - _tool_param("respect_filter", ToolParamType.BOOLEAN, "是否应用聊天过滤配置", False), - ], - ) - async def handle_search_memory( - self, - query: str = "", - limit: int = 5, - mode: str = "search", - chat_id: str = "", - person_id: str = "", - time_start: str | float | None = None, - time_end: str | float | None = None, - respect_filter: bool = True, - **kwargs, - ): - kernel = await self._get_kernel() - return await kernel.search_memory( - KernelSearchRequest( - query=query, - limit=limit, - mode=mode, - chat_id=chat_id, - person_id=person_id, - time_start=time_start, - time_end=time_end, - respect_filter=respect_filter, - user_id=str(kwargs.get("user_id", "") or "").strip(), - group_id=str(kwargs.get("group_id", "") or "").strip(), - ) - ) - - @Tool( - "ingest_summary", - description="写入聊天摘要到长期记忆", - parameters=[ - _tool_param("external_id", ToolParamType.STRING, "外部幂等 ID", True), - _tool_param("chat_id", ToolParamType.STRING, "聊天流 ID", True), - _tool_param("text", ToolParamType.STRING, "摘要文本", True), - _tool_param("time_start", ToolParamType.FLOAT, "起始时间戳", False), - _tool_param("time_end", ToolParamType.FLOAT, "结束时间戳", False), - _tool_param("respect_filter", ToolParamType.BOOLEAN, "是否应用聊天过滤配置", False), - ], - ) - async def handle_ingest_summary( - self, - external_id: str, - chat_id: str, - text: str, - participants: Optional[List[str]] = None, - time_start: float | None = None, - time_end: float | None = None, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, - respect_filter: bool = True, - **kwargs, - ): - kernel = await self._get_kernel() - return await kernel.ingest_summary( - external_id=external_id, - chat_id=chat_id, - text=text, - participants=participants, - time_start=time_start, - time_end=time_end, - tags=tags, - metadata=metadata, - respect_filter=respect_filter, - user_id=str(kwargs.get("user_id", "") or "").strip(), - group_id=str(kwargs.get("group_id", "") or "").strip(), - ) - - @Tool( - "ingest_text", - description="写入普通长期记忆文本", - parameters=[ - _tool_param("external_id", ToolParamType.STRING, "外部幂等 ID", True), - _tool_param("source_type", ToolParamType.STRING, "来源类型", True), - _tool_param("text", ToolParamType.STRING, "原始文本", True), - _tool_param("chat_id", ToolParamType.STRING, "聊天流 ID", False), - _tool_param("timestamp", ToolParamType.FLOAT, "时间戳", False), - _tool_param("respect_filter", ToolParamType.BOOLEAN, "是否应用聊天过滤配置", False), - ], - ) - async def handle_ingest_text( - self, - external_id: str, - source_type: str, - text: str, - chat_id: str = "", - person_ids: Optional[List[str]] = None, - participants: Optional[List[str]] = None, - timestamp: float | None = None, - time_start: float | None = None, - time_end: float | None = None, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, - respect_filter: bool = True, - **kwargs, - ): - relations = kwargs.get("relations") - entities = kwargs.get("entities") - kernel = await self._get_kernel() - return await kernel.ingest_text( - external_id=external_id, - source_type=source_type, - text=text, - chat_id=chat_id, - person_ids=person_ids, - participants=participants, - timestamp=timestamp, - time_start=time_start, - time_end=time_end, - tags=tags, - metadata=metadata, - entities=entities, - relations=relations, - 
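The handlers normalize the optional scoping ids from `**kwargs` the same way everywhere: coerce to str, treat None as empty, strip whitespace. As a tiny helper with a check (hypothetical `scoped_ids`; the plugin inlines the expression in each handler):

```python
from typing import Dict, Tuple

def scoped_ids(kwargs: Dict) -> Tuple[str, str]:
    # `or ""` turns an explicit None into the empty string before str()/strip()
    user_id = str(kwargs.get("user_id", "") or "").strip()
    group_id = str(kwargs.get("group_id", "") or "").strip()
    return user_id, group_id

assert scoped_ids({"user_id": None, "group_id": " 42 "}) == ("", "42")
```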
respect_filter=respect_filter, - user_id=str(kwargs.get("user_id", "") or "").strip(), - group_id=str(kwargs.get("group_id", "") or "").strip(), - ) - - @Tool( - "get_person_profile", - description="获取人物画像", - parameters=[ - _tool_param("person_id", ToolParamType.STRING, "人物 ID", True), - _tool_param("chat_id", ToolParamType.STRING, "聊天流 ID", False), - _tool_param("limit", ToolParamType.INTEGER, "证据条数", False), - ], - ) - async def handle_get_person_profile(self, person_id: str, chat_id: str = "", limit: int = 10, **kwargs): - _ = kwargs - kernel = await self._get_kernel() - return await kernel.get_person_profile(person_id=person_id, chat_id=chat_id, limit=limit) - - @Tool( - "maintain_memory", - description="维护长期记忆关系状态", - parameters=[ - _tool_param("action", ToolParamType.STRING, "reinforce/protect/restore/freeze/recycle_bin", True), - _tool_param("target", ToolParamType.STRING, "目标哈希或查询文本", False), - _tool_param("hours", ToolParamType.FLOAT, "保护时长(小时)", False), - _tool_param("limit", ToolParamType.INTEGER, "查询条数(用于 recycle_bin)", False), - ], - ) - async def handle_maintain_memory( - self, - action: str, - target: str = "", - hours: float | None = None, - reason: str = "", - limit: int = 50, - **kwargs, - ): - _ = kwargs - kernel = await self._get_kernel() - return await kernel.maintain_memory(action=action, target=target, hours=hours, reason=reason, limit=limit) - - @Tool("memory_stats", description="获取长期记忆统计", parameters=[]) - async def handle_memory_stats(self, **kwargs): - _ = kwargs - kernel = await self._get_kernel() - return kernel.memory_stats() - - @Tool("memory_graph_admin", description="长期记忆图谱管理接口", parameters=_ADMIN_TOOL_PARAMS) - async def handle_memory_graph_admin(self, action: str, **kwargs): - return await self._dispatch_admin_tool("memory_graph_admin", action=action, **kwargs) - - @Tool("memory_source_admin", description="长期记忆来源管理接口", parameters=_ADMIN_TOOL_PARAMS) - async def handle_memory_source_admin(self, action: str, **kwargs): - return await self._dispatch_admin_tool("memory_source_admin", action=action, **kwargs) - - @Tool("memory_episode_admin", description="Episode 管理接口", parameters=_ADMIN_TOOL_PARAMS) - async def handle_memory_episode_admin(self, action: str, **kwargs): - return await self._dispatch_admin_tool("memory_episode_admin", action=action, **kwargs) - - @Tool("memory_profile_admin", description="人物画像管理接口", parameters=_ADMIN_TOOL_PARAMS) - async def handle_memory_profile_admin(self, action: str, **kwargs): - return await self._dispatch_admin_tool("memory_profile_admin", action=action, **kwargs) - - @Tool("memory_runtime_admin", description="长期记忆运行时管理接口", parameters=_ADMIN_TOOL_PARAMS) - async def handle_memory_runtime_admin(self, action: str, **kwargs): - return await self._dispatch_admin_tool("memory_runtime_admin", action=action, **kwargs) - - @Tool("memory_import_admin", description="长期记忆导入管理接口", parameters=_ADMIN_TOOL_PARAMS) - async def handle_memory_import_admin(self, action: str, **kwargs): - return await self._dispatch_admin_tool("memory_import_admin", action=action, **kwargs) - - @Tool("memory_tuning_admin", description="长期记忆调优管理接口", parameters=_ADMIN_TOOL_PARAMS) - async def handle_memory_tuning_admin(self, action: str, **kwargs): - return await self._dispatch_admin_tool("memory_tuning_admin", action=action, **kwargs) - - @Tool("memory_v5_admin", description="长期记忆 V5 管理接口", parameters=_ADMIN_TOOL_PARAMS) - async def handle_memory_v5_admin(self, action: str, **kwargs): - return await self._dispatch_admin_tool("memory_v5_admin", action=action, 
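Note: all of the `*_admin` tools funnel into the same getattr-based dispatcher on the kernel. A self-contained sketch of that pattern, using a stub in place of the real `SDKMemoryKernel`:

    import asyncio

    class KernelStub:  # illustrative stand-in for SDKMemoryKernel
        async def memory_graph_admin(self, action: str, **kwargs):
            return {"ok": True, "action": action}

    async def dispatch(kernel, method_name: str, action: str, **kwargs):
        # same shape as _dispatch_admin_tool: look up the handler by name, await it
        handler = getattr(kernel, method_name)
        return await handler(action=action, **kwargs)

    result = asyncio.run(dispatch(KernelStub(), "memory_graph_admin", "stats"))
    assert result == {"ok": True, "action": "stats"}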
**kwargs) - - @Tool("memory_delete_admin", description="长期记忆删除管理接口", parameters=_ADMIN_TOOL_PARAMS) - async def handle_memory_delete_admin(self, action: str, **kwargs): - return await self._dispatch_admin_tool("memory_delete_admin", action=action, **kwargs) - - -def create_plugin(): - return AMemorixPlugin() diff --git a/plugins/A_memorix/requirements.txt b/plugins/A_memorix/requirements.txt deleted file mode 100644 index f737fdf4..00000000 --- a/plugins/A_memorix/requirements.txt +++ /dev/null @@ -1,52 +0,0 @@ -# A_Memorix 插件依赖 -# -# 核心依赖 (必需) -# ================== - -# 数值计算 - 用于向量操作、矩阵计算 -numpy>=1.20.0 - -# 稀疏矩阵 - 用于图存储的邻接矩阵 -scipy>=1.7.0 - -# 图结构处理(LPMM 转换) -networkx>=3.0.0 - -# Parquet 读取(LPMM 转换) -pyarrow>=10.0.0 - -# DataFrame 处理(LPMM 转换) -pandas>=1.5.0 - -# 异步事件循环嵌套 - 用于插件初始化时的异步操作 -nest-asyncio>=1.5.0 - -# 向量索引 - 用于向量存储和检索 -faiss-cpu>=1.7.0 - -# Web 服务器依赖 (可视化功能需要) -# ================== - -# ASGI 服务器 -uvicorn>=0.20.0 - -# Web 框架 -fastapi>=0.100.0 - -# 数据验证 -pydantic>=2.0.0 -python-multipart>=0.0.9 - -# 注意事项 -# ================== -# -# 1. sqlite3 是 Python 标准库,无需安装 -# 2. json, re, time, pathlib 等都是标准库 -# 3. sentence-transformers 不需要(使用主程序 Embedding API) - -# UI 交互 -rich>=14.0.0 -tenacity>=8.0.0 - -# 稀疏检索中文分词(可选,未安装时自动回退 char n-gram) -jieba>=0.42.1 diff --git a/plugins/A_memorix/scripts/audit_vector_consistency.py b/plugins/A_memorix/scripts/audit_vector_consistency.py deleted file mode 100644 index c97806dc..00000000 --- a/plugins/A_memorix/scripts/audit_vector_consistency.py +++ /dev/null @@ -1,213 +0,0 @@ -#!/usr/bin/env python3 -""" -A_Memorix 一致性审计脚本。 - -输出内容: -1. paragraph/entity/relation 向量覆盖率 -2. relation vector_state 分布 -3. 孤儿向量数量(向量存在但 metadata 不存在) -4. 状态与向量文件不一致统计 -""" - -from __future__ import annotations - -import argparse -import json -import pickle -import sys -from pathlib import Path -from typing import Any, Dict, Set - - -CURRENT_DIR = Path(__file__).resolve().parent -PLUGIN_ROOT = CURRENT_DIR.parent -PROJECT_ROOT = PLUGIN_ROOT.parent.parent -sys.path.insert(0, str(PROJECT_ROOT)) -sys.path.insert(0, str(PLUGIN_ROOT)) - -def _build_arg_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser(description="审计 A_Memorix 向量一致性") - parser.add_argument( - "--data-dir", - default=str(PLUGIN_ROOT / "data"), - help="A_Memorix 数据目录(默认: plugins/A_memorix/data)", - ) - parser.add_argument("--json-out", default="", help="可选:输出 JSON 文件路径") - parser.add_argument( - "--strict", - action="store_true", - help="若发现一致性异常则返回非 0 退出码", - ) - return parser - - -# --help/-h fast path: avoid heavy host/plugin bootstrap -if any(arg in {"-h", "--help"} for arg in sys.argv[1:]): - _build_arg_parser().print_help() - sys.exit(0) - -try: - from core.storage.vector_store import VectorStore - from core.storage.metadata_store import MetadataStore - from core.storage import QuantizationType -except Exception as e: # pragma: no cover - print(f"❌ 导入核心模块失败: {e}") - sys.exit(1) - - -def _safe_ratio(numerator: int, denominator: int) -> float: - if denominator <= 0: - return 0.0 - return float(numerator) / float(denominator) - - -def _load_vector_store(data_dir: Path) -> VectorStore: - meta_path = data_dir / "vectors" / "vectors_metadata.pkl" - if not meta_path.exists(): - raise FileNotFoundError(f"未找到向量元数据文件: {meta_path}") - - with open(meta_path, "rb") as f: - meta = pickle.load(f) - dimension = int(meta.get("dimension", 1024)) - - store = VectorStore( - dimension=max(1, dimension), - quantization_type=QuantizationType.INT8, - data_dir=data_dir / "vectors", - ) - if store.has_data(): - 
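Note: the scripts in this patch share a "--help fast path": the argparse parser is built and printed before any heavy `core.storage` import runs, so `-h` works even when the plugin environment cannot bootstrap. The shape of the idiom:

    import argparse
    import sys

    def _build_arg_parser() -> argparse.ArgumentParser:
        parser = argparse.ArgumentParser(description="example script")
        parser.add_argument("--data-dir", default="./data")
        return parser

    # runs before any core.storage import, so --help never triggers the bootstrap
    if any(arg in {"-h", "--help"} for arg in sys.argv[1:]):
        _build_arg_parser().print_help()
        sys.exit(0)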
store.load() - return store - - -def _load_metadata_store(data_dir: Path) -> MetadataStore: - store = MetadataStore(data_dir=data_dir / "metadata") - store.connect() - return store - - -def _hash_set(metadata_store: MetadataStore, table: str) -> Set[str]: - return {str(h) for h in metadata_store.list_hashes(table)} - - -def _relation_state_stats(metadata_store: MetadataStore) -> Dict[str, int]: - return metadata_store.count_relations_by_vector_state() - - -def run_audit(data_dir: Path) -> Dict[str, Any]: - vector_store = _load_vector_store(data_dir) - metadata_store = _load_metadata_store(data_dir) - try: - paragraph_hashes = _hash_set(metadata_store, "paragraphs") - entity_hashes = _hash_set(metadata_store, "entities") - relation_hashes = _hash_set(metadata_store, "relations") - - known_hashes = set(getattr(vector_store, "_known_hashes", set())) - live_vector_hashes = {h for h in known_hashes if h in vector_store} - - para_vector_hits = len(paragraph_hashes & live_vector_hashes) - ent_vector_hits = len(entity_hashes & live_vector_hashes) - rel_vector_hits = len(relation_hashes & live_vector_hashes) - - orphan_vector_hashes = sorted( - live_vector_hashes - paragraph_hashes - entity_hashes - relation_hashes - ) - - relation_rows = metadata_store.get_relations() - ready_but_missing = 0 - not_ready_but_present = 0 - for row in relation_rows: - h = str(row.get("hash") or "") - state = str(row.get("vector_state") or "none").lower() - in_vector = h in live_vector_hashes - if state == "ready" and not in_vector: - ready_but_missing += 1 - if state != "ready" and in_vector: - not_ready_but_present += 1 - - relation_states = _relation_state_stats(metadata_store) - rel_total = max(0, int(relation_states.get("total", len(relation_hashes)))) - ready_count = max(0, int(relation_states.get("ready", 0))) - - result = { - "counts": { - "paragraphs": len(paragraph_hashes), - "entities": len(entity_hashes), - "relations": len(relation_hashes), - "vectors_live": len(live_vector_hashes), - }, - "coverage": { - "paragraph_vector_coverage": _safe_ratio(para_vector_hits, len(paragraph_hashes)), - "entity_vector_coverage": _safe_ratio(ent_vector_hits, len(entity_hashes)), - "relation_vector_coverage": _safe_ratio(rel_vector_hits, len(relation_hashes)), - "relation_ready_coverage": _safe_ratio(ready_count, rel_total), - }, - "relation_states": relation_states, - "orphans": { - "vector_only_count": len(orphan_vector_hashes), - "vector_only_sample": orphan_vector_hashes[:30], - }, - "consistency_checks": { - "ready_but_missing_vector": ready_but_missing, - "not_ready_but_vector_present": not_ready_but_present, - }, - } - return result - finally: - metadata_store.close() - - -def main() -> int: - parser = _build_arg_parser() - args = parser.parse_args() - - data_dir = Path(args.data_dir).resolve() - if not data_dir.exists(): - print(f"❌ 数据目录不存在: {data_dir}") - return 2 - - try: - result = run_audit(data_dir) - except Exception as e: - print(f"❌ 审计失败: {e}") - return 2 - - print("=== A_Memorix Vector Consistency Audit ===") - print(f"data_dir: {data_dir}") - print(f"paragraphs: {result['counts']['paragraphs']}") - print(f"entities: {result['counts']['entities']}") - print(f"relations: {result['counts']['relations']}") - print(f"vectors_live: {result['counts']['vectors_live']}") - print( - "coverage: " - f"paragraph={result['coverage']['paragraph_vector_coverage']:.3f}, " - f"entity={result['coverage']['entity_vector_coverage']:.3f}, " - f"relation={result['coverage']['relation_vector_coverage']:.3f}, " - 
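Note: the audit in `run_audit` is essentially set algebra between metadata hashes and live vector ids. Restated with toy data:

    paragraphs = {"p1", "p2"}
    entities = {"e1"}
    relations = {"r1", "r2"}
    vectors_live = {"p1", "e1", "r1", "zombie"}

    def safe_ratio(numerator: int, denominator: int) -> float:
        return numerator / denominator if denominator > 0 else 0.0

    # coverage: fraction of metadata rows whose vector actually exists
    paragraph_coverage = safe_ratio(len(paragraphs & vectors_live), len(paragraphs))
    # orphans: vectors with no metadata row of any type
    orphans = vectors_live - paragraphs - entities - relations
    assert paragraph_coverage == 0.5
    assert orphans == {"zombie"}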
f"relation_ready={result['coverage']['relation_ready_coverage']:.3f}" - ) - print(f"relation_states: {result['relation_states']}") - print( - "consistency_checks: " - f"ready_but_missing_vector={result['consistency_checks']['ready_but_missing_vector']}, " - f"not_ready_but_vector_present={result['consistency_checks']['not_ready_but_vector_present']}" - ) - print(f"orphan_vectors: {result['orphans']['vector_only_count']}") - - if args.json_out: - out_path = Path(args.json_out).resolve() - out_path.parent.mkdir(parents=True, exist_ok=True) - with open(out_path, "w", encoding="utf-8") as f: - json.dump(result, f, ensure_ascii=False, indent=2) - print(f"json_out: {out_path}") - - has_anomaly = ( - result["orphans"]["vector_only_count"] > 0 - or result["consistency_checks"]["ready_but_missing_vector"] > 0 - ) - if args.strict and has_anomaly: - return 1 - return 0 - - -if __name__ == "__main__": - raise SystemExit(main()) diff --git a/plugins/A_memorix/scripts/backfill_relation_vectors.py b/plugins/A_memorix/scripts/backfill_relation_vectors.py deleted file mode 100644 index 7ba0ade0..00000000 --- a/plugins/A_memorix/scripts/backfill_relation_vectors.py +++ /dev/null @@ -1,270 +0,0 @@ -#!/usr/bin/env python3 -""" -关系向量一次性回填脚本(灰度/离线执行)。 - -用途: -1. 对 relations 中 vector_state in (none, failed, pending) 的记录补齐向量。 -2. 支持并发控制,降低总耗时。 -3. 可作为灰度阶段验证工具,与 audit_vector_consistency.py 配合使用。 -4. 可选自动纳入“ready 但向量缺失”的漂移记录进行修复。 -""" - -from __future__ import annotations - -import argparse -import asyncio -import json -import sys -import time -from pathlib import Path -from typing import Any, Dict, List - -import tomlkit - - -CURRENT_DIR = Path(__file__).resolve().parent -PLUGIN_ROOT = CURRENT_DIR.parent -PROJECT_ROOT = PLUGIN_ROOT.parent.parent -sys.path.insert(0, str(PROJECT_ROOT)) -sys.path.insert(0, str(PLUGIN_ROOT)) - -def _build_arg_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser(description="关系向量一次性回填") - parser.add_argument( - "--config", - default=str(PLUGIN_ROOT / "config.toml"), - help="配置文件路径(默认 plugins/A_memorix/config.toml)", - ) - parser.add_argument( - "--data-dir", - default=str(PLUGIN_ROOT / "data"), - help="数据目录(默认 plugins/A_memorix/data)", - ) - parser.add_argument( - "--states", - default="none,failed,pending", - help="待处理状态列表,逗号分隔", - ) - parser.add_argument("--limit", type=int, default=50000, help="最大处理数量") - parser.add_argument("--concurrency", type=int, default=8, help="并发数") - parser.add_argument("--max-retry", type=int, default=None, help="最大重试次数过滤") - parser.add_argument( - "--include-ready-missing", - action="store_true", - help="额外纳入 vector_state=ready 但向量缺失的关系", - ) - parser.add_argument("--dry-run", action="store_true", help="仅统计候选,不写入") - return parser - - -# --help/-h fast path: avoid heavy host/plugin bootstrap -if any(arg in {"-h", "--help"} for arg in sys.argv[1:]): - _build_arg_parser().print_help() - raise SystemExit(0) - -from core.storage import ( - VectorStore, - GraphStore, - MetadataStore, - QuantizationType, - SparseMatrixFormat, -) -from core.embedding import create_embedding_api_adapter -from core.utils.relation_write_service import RelationWriteService - - -def _load_config(config_path: Path) -> Dict[str, Any]: - with open(config_path, "r", encoding="utf-8") as f: - raw = tomlkit.load(f) - return dict(raw) if isinstance(raw, dict) else {} - - -def _build_vector_store(data_dir: Path, emb_cfg: Dict[str, Any]) -> VectorStore: - q_type = str(emb_cfg.get("quantization_type", "int8")).lower() - if q_type != "int8": - raise ValueError( - 
"embedding.quantization_type 在 vNext 仅允许 int8(SQ8)。" - " 请先执行 scripts/release_vnext_migrate.py migrate。" - ) - dim = int(emb_cfg.get("dimension", 1024)) - store = VectorStore( - dimension=max(1, dim), - quantization_type=QuantizationType.INT8, - data_dir=data_dir / "vectors", - ) - if store.has_data(): - store.load() - return store - - -def _build_graph_store(data_dir: Path, graph_cfg: Dict[str, Any]) -> GraphStore: - fmt = str(graph_cfg.get("sparse_matrix_format", "csr")).lower() - fmt_map = { - "csr": SparseMatrixFormat.CSR, - "csc": SparseMatrixFormat.CSC, - } - store = GraphStore( - matrix_format=fmt_map.get(fmt, SparseMatrixFormat.CSR), - data_dir=data_dir / "graph", - ) - if store.has_data(): - store.load() - return store - - -def _build_metadata_store(data_dir: Path) -> MetadataStore: - store = MetadataStore(data_dir=data_dir / "metadata") - store.connect() - return store - - -def _build_embedding_manager(emb_cfg: Dict[str, Any]): - retry_cfg = emb_cfg.get("retry", {}) - if not isinstance(retry_cfg, dict): - retry_cfg = {} - return create_embedding_api_adapter( - batch_size=int(emb_cfg.get("batch_size", 32)), - max_concurrent=int(emb_cfg.get("max_concurrent", 5)), - default_dimension=int(emb_cfg.get("dimension", 1024)), - model_name=str(emb_cfg.get("model_name", "auto")), - retry_config=retry_cfg, - ) - - -async def _process_rows( - service: RelationWriteService, - rows: List[Dict[str, Any]], - concurrency: int, -) -> Dict[str, int]: - semaphore = asyncio.Semaphore(max(1, int(concurrency))) - stat = {"success": 0, "failed": 0, "skipped": 0} - - async def _worker(row: Dict[str, Any]) -> None: - async with semaphore: - result = await service.ensure_relation_vector( - hash_value=str(row["hash"]), - subject=str(row.get("subject", "")), - predicate=str(row.get("predicate", "")), - obj=str(row.get("object", "")), - ) - if result.vector_state == "ready": - if result.vector_written: - stat["success"] += 1 - else: - stat["skipped"] += 1 - else: - stat["failed"] += 1 - - await asyncio.gather(*[_worker(row) for row in rows]) - return stat - - -async def main_async(args: argparse.Namespace) -> int: - config_path = Path(args.config).resolve() - if not config_path.exists(): - print(f"❌ 配置文件不存在: {config_path}") - return 2 - - cfg = _load_config(config_path) - emb_cfg = cfg.get("embedding", {}) if isinstance(cfg, dict) else {} - graph_cfg = cfg.get("graph", {}) if isinstance(cfg, dict) else {} - retrieval_cfg = cfg.get("retrieval", {}) if isinstance(cfg, dict) else {} - rv_cfg = retrieval_cfg.get("relation_vectorization", {}) if isinstance(retrieval_cfg, dict) else {} - if not isinstance(emb_cfg, dict): - emb_cfg = {} - if not isinstance(graph_cfg, dict): - graph_cfg = {} - if not isinstance(rv_cfg, dict): - rv_cfg = {} - - data_dir = Path(args.data_dir).resolve() - if not data_dir.exists(): - print(f"❌ 数据目录不存在: {data_dir}") - return 2 - - print(f"data_dir: {data_dir}") - print(f"config: {config_path}") - - vector_store = _build_vector_store(data_dir, emb_cfg) - graph_store = _build_graph_store(data_dir, graph_cfg) - metadata_store = _build_metadata_store(data_dir) - embedding_manager = _build_embedding_manager(emb_cfg) - service = RelationWriteService( - metadata_store=metadata_store, - graph_store=graph_store, - vector_store=vector_store, - embedding_manager=embedding_manager, - ) - - try: - states = [s.strip() for s in str(args.states).split(",") if s.strip()] - if not states: - states = ["none", "failed", "pending"] - max_retry = int(args.max_retry) if args.max_retry is not None else 
int(rv_cfg.get("max_retry", 3)) - limit = int(args.limit) - - rows = metadata_store.list_relations_by_vector_state( - states=states, - limit=max(1, limit), - max_retry=max(1, max_retry), - ) - added_ready_missing = 0 - if args.include_ready_missing: - ready_rows = metadata_store.list_relations_by_vector_state( - states=["ready"], - limit=max(1, limit), - max_retry=max(1, max_retry), - ) - ready_missing_rows = [ - row for row in ready_rows if str(row.get("hash", "")) not in vector_store - ] - added_ready_missing = len(ready_missing_rows) - if ready_missing_rows: - dedup: Dict[str, Dict[str, Any]] = {} - for row in rows: - dedup[str(row.get("hash", ""))] = row - for row in ready_missing_rows: - dedup.setdefault(str(row.get("hash", "")), row) - rows = list(dedup.values())[: max(1, limit)] - print(f"candidates: {len(rows)} (states={states}, max_retry={max_retry})") - if args.include_ready_missing: - print(f"ready_missing_candidates_added: {added_ready_missing}") - if not rows: - return 0 - - if args.dry_run: - print("dry_run=true,未执行写入。") - return 0 - - started = time.time() - stat = await _process_rows( - service=service, - rows=rows, - concurrency=int(args.concurrency), - ) - elapsed = (time.time() - started) * 1000.0 - - vector_store.save() - graph_store.save() - state_stats = metadata_store.count_relations_by_vector_state() - output = { - "processed": len(rows), - "success": int(stat["success"]), - "failed": int(stat["failed"]), - "skipped": int(stat["skipped"]), - "elapsed_ms": elapsed, - "state_stats": state_stats, - } - print(json.dumps(output, ensure_ascii=False, indent=2)) - return 0 if stat["failed"] == 0 else 1 - finally: - metadata_store.close() - - -def parse_args() -> argparse.Namespace: - return _build_arg_parser().parse_args() - - -if __name__ == "__main__": - arguments = parse_args() - raise SystemExit(asyncio.run(main_async(arguments))) diff --git a/plugins/A_memorix/scripts/backfill_temporal_metadata.py b/plugins/A_memorix/scripts/backfill_temporal_metadata.py deleted file mode 100644 index b68820cd..00000000 --- a/plugins/A_memorix/scripts/backfill_temporal_metadata.py +++ /dev/null @@ -1,73 +0,0 @@ -#!/usr/bin/env python3 -""" -回填段落时序字段。 - -默认策略: -1. 若段落缺失 event_time/event_time_start/event_time_end -2. 且存在 created_at -3. 
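Note: the `--include-ready-missing` merge above keys both candidate lists by hash so the primary list wins on collision and the limit still applies after the union. In miniature:

    limit = 3
    rows = [{"hash": "a"}, {"hash": "b"}]
    ready_missing = [{"hash": "b"}, {"hash": "c"}, {"hash": "d"}]

    dedup = {row["hash"]: row for row in rows}
    for row in ready_missing:
        dedup.setdefault(row["hash"], row)  # primary list wins on collision
    merged = list(dedup.values())[:limit]
    assert [row["hash"] for row in merged] == ["a", "b", "c"]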
写入 event_time=created_at, time_granularity=day, time_confidence=0.2 -""" - -from __future__ import annotations - -import argparse -from pathlib import Path -import sys - - -CURRENT_DIR = Path(__file__).resolve().parent -PLUGIN_ROOT = CURRENT_DIR.parent -PROJECT_ROOT = PLUGIN_ROOT.parent.parent -sys.path.insert(0, str(PROJECT_ROOT)) - -from plugins.A_memorix.core.storage import MetadataStore # noqa: E402 - - -def backfill( - data_dir: Path, - dry_run: bool, - limit: int, - no_created_fallback: bool, -) -> int: - store = MetadataStore(data_dir=data_dir) - store.connect() - summary = store.backfill_temporal_metadata_from_created_at( - limit=limit, - dry_run=dry_run, - no_created_fallback=no_created_fallback, - ) - store.close() - if dry_run: - print(f"[dry-run] candidates={summary['candidates']}") - return int(summary["candidates"]) - if no_created_fallback: - print(f"skip update (no-created-fallback), candidates={summary['candidates']}") - return 0 - print(f"updated={summary['updated']}") - return int(summary["updated"]) - - -def main() -> int: - parser = argparse.ArgumentParser(description="Backfill temporal metadata for A_Memorix paragraphs") - parser.add_argument("--data-dir", default=str(PLUGIN_ROOT / "data"), help="数据目录") - parser.add_argument("--dry-run", action="store_true", help="仅统计,不写入") - parser.add_argument("--limit", type=int, default=100000, help="最大处理条数") - parser.add_argument( - "--no-created-fallback", - action="store_true", - help="不使用 created_at 回填,仅输出候选数量", - ) - args = parser.parse_args() - - backfill( - data_dir=Path(args.data_dir), - dry_run=args.dry_run, - limit=max(1, int(args.limit)), - no_created_fallback=args.no_created_fallback, - ) - return 0 - - -if __name__ == "__main__": - raise SystemExit(main()) - diff --git a/plugins/A_memorix/scripts/convert_lpmm.py b/plugins/A_memorix/scripts/convert_lpmm.py deleted file mode 100644 index 2ef0b396..00000000 --- a/plugins/A_memorix/scripts/convert_lpmm.py +++ /dev/null @@ -1,540 +0,0 @@ -#!/usr/bin/env python3 -""" -LPMM 到 A_memorix 存储转换器 - -功能: -1. 读取 LPMM parquet 文件 (paragraph.parquet, entity.parquet, relation.parquet) -2. 读取 LPMM 图文件 (graph.graphml 或 graph_structure.pkl) -3. 直接写入 A_memorix 二进制 VectorStore 和稀疏 GraphStore -4. 
绕过 Embedding 生成以节省 Token -""" - -import sys -import os -import json -import argparse -import asyncio -import pickle -import logging -from pathlib import Path -from typing import Dict, Any, List, Tuple -import numpy as np -import tomlkit - -# 设置路径 -current_dir = Path(__file__).resolve().parent -plugin_root = current_dir.parent -project_root = plugin_root.parent.parent -sys.path.insert(0, str(project_root)) - -def _build_arg_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser(description="将 LPMM 数据转换为 A_memorix 格式") - parser.add_argument("--input", "-i", required=True, help="包含 LPMM 数据的输入目录 (parquet, graphml)") - parser.add_argument("--output", "-o", required=True, help="A_memorix 数据的输出目录") - parser.add_argument("--dim", type=int, default=384, help="Embedding 维度 (必须与 LPMM 模型匹配)") - parser.add_argument("--batch-size", type=int, default=1024, help="Parquet 分批读取大小 (默认 1024)") - parser.add_argument( - "--skip-relation-vector-rebuild", - action="store_true", - help="跳过按关系元数据重建关系向量(默认开启)", - ) - return parser - - -# --help/-h fast path: avoid heavy host/plugin bootstrap -if any(arg in {"-h", "--help"} for arg in sys.argv[1:]): - _build_arg_parser().print_help() - sys.exit(0) - -# 设置日志:优先复用 MaiBot 统一日志体系,失败时回退到标准 logging。 -try: - from src.common.logger import get_logger - - logger = get_logger("A_Memorix.LPMMConverter") -except Exception: - logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") - logger = logging.getLogger("A_Memorix.LPMMConverter") - -try: - import networkx as nx - from scipy import sparse - import pyarrow.parquet as pq -except ImportError as e: - logger.error(f"缺少依赖: {e}") - logger.error("请安装: pip install pandas pyarrow networkx scipy") - sys.exit(1) - -try: - # 优先采取相对导入 (将插件根目录加入路径) - # 这样可以避免硬编码插件名称 (plugins.A_memorix) - if str(plugin_root) not in sys.path: - sys.path.insert(0, str(plugin_root)) - - from core.storage.vector_store import VectorStore - from core.storage.graph_store import GraphStore - from core.storage.metadata_store import MetadataStore - from core.storage import QuantizationType, SparseMatrixFormat - from core.embedding import create_embedding_api_adapter - from core.utils.relation_write_service import RelationWriteService - -except ImportError as e: - logger.error(f"无法导入 A_memorix 核心模块: {e}") - logger.error("请确保在正确的环境中运行,且已安装所有依赖。") - sys.exit(1) - - -class LPMMConverter: - def __init__( - self, - lpmm_data_dir: Path, - output_dir: Path, - dimension: int = 384, - batch_size: int = 1024, - rebuild_relation_vectors: bool = True, - ): - self.lpmm_dir = lpmm_data_dir - self.output_dir = output_dir - self.dimension = dimension - self.batch_size = max(1, int(batch_size)) - self.rebuild_relation_vectors = bool(rebuild_relation_vectors) - - self.vector_dir = output_dir / "vectors" - self.graph_dir = output_dir / "graph" - self.metadata_dir = output_dir / "metadata" - - self.vector_store = None - self.graph_store = None - self.metadata_store = None - self.embedding_manager = None - self.relation_write_service = None - # LPMM 原 ID -> A_memorix ID 映射(用于图重写) - self.id_mapping: Dict[str, str] = {} - - def _register_id_mapping(self, raw_id: Any, mapped_id: str, p_type: str) -> None: - """记录 ID 映射,兼容带/不带类型前缀两种格式。""" - if raw_id is None: - return - - raw = str(raw_id).strip() - if not raw: - return - - self.id_mapping[raw] = mapped_id - - prefix = f"{p_type}-" - if raw.startswith(prefix): - self.id_mapping[raw[len(prefix):]] = mapped_id - else: - self.id_mapping[prefix + raw] = mapped_id - - def 
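Note: `_register_id_mapping` above stores each LPMM id under both its prefixed and unprefixed spelling, so later graph-node lookups resolve either way. The trick in isolation:

    id_mapping: dict = {}

    def register(raw_id: str, mapped_id: str, p_type: str) -> None:
        raw = raw_id.strip()
        if not raw:
            return
        id_mapping[raw] = mapped_id
        prefix = f"{p_type}-"
        if raw.startswith(prefix):
            # also index the bare hash without its type prefix
            id_mapping[raw[len(prefix):]] = mapped_id
        else:
            # also index the prefixed spelling
            id_mapping[prefix + raw] = mapped_id

    register("entity-abc123", "NEW_ID", "entity")
    assert id_mapping["abc123"] == "NEW_ID"
    assert id_mapping["entity-abc123"] == "NEW_ID"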
_map_node_id(self, node: Any) -> str: - """将图节点 ID 映射到转换后的 A_memorix ID。""" - node_key = str(node) - return self.id_mapping.get(node_key, node_key) - - def initialize_stores(self): - """初始化空的 A_memorix 存储""" - logger.info(f"正在初始化存储于 {self.output_dir}...") - - # 初始化 VectorStore (A_memorix 默认使用 INT8 量化) - self.vector_store = VectorStore( - dimension=self.dimension, - quantization_type=QuantizationType.INT8, - data_dir=self.vector_dir - ) - self.vector_store.clear() # 清空旧数据 - - # 初始化 GraphStore (使用 CSR 格式) - self.graph_store = GraphStore( - matrix_format=SparseMatrixFormat.CSR, - data_dir=self.graph_dir - ) - self.graph_store.clear() - - # 初始化 MetadataStore - self.metadata_store = MetadataStore(data_dir=self.metadata_dir) - self.metadata_store.connect() - # 清空元数据表?理想情况下是的,但要小心。 - # 对于转换,我们假设是全新的开始或覆盖。 - # A_memorix 中的 MetadataStore 通常使用 SQLite。 - # 如果目录是新的,我们会依赖它创建新文件。 - if self.rebuild_relation_vectors: - self._init_relation_vector_service() - - def _load_plugin_config(self) -> Dict[str, Any]: - config_path = plugin_root / "config.toml" - if not config_path.exists(): - return {} - try: - with open(config_path, "r", encoding="utf-8") as f: - parsed = tomlkit.load(f) - return dict(parsed) if isinstance(parsed, dict) else {} - except Exception as e: - logger.warning(f"读取 config.toml 失败,使用默认 embedding 配置: {e}") - return {} - - def _init_relation_vector_service(self) -> None: - if not self.rebuild_relation_vectors: - return - cfg = self._load_plugin_config() - emb_cfg = cfg.get("embedding", {}) if isinstance(cfg, dict) else {} - if not isinstance(emb_cfg, dict): - emb_cfg = {} - try: - self.embedding_manager = create_embedding_api_adapter( - batch_size=int(emb_cfg.get("batch_size", 32)), - max_concurrent=int(emb_cfg.get("max_concurrent", 5)), - default_dimension=int(emb_cfg.get("dimension", self.dimension)), - model_name=str(emb_cfg.get("model_name", "auto")), - retry_config=emb_cfg.get("retry", {}) if isinstance(emb_cfg.get("retry", {}), dict) else {}, - ) - self.relation_write_service = RelationWriteService( - metadata_store=self.metadata_store, - graph_store=self.graph_store, - vector_store=self.vector_store, - embedding_manager=self.embedding_manager, - ) - except Exception as e: - self.embedding_manager = None - self.relation_write_service = None - logger.warning(f"初始化关系向量重建服务失败,将跳过关系向量回填: {e}") - - async def _rebuild_relation_vectors(self) -> None: - if not self.rebuild_relation_vectors: - return - if self.relation_write_service is None: - logger.warning("关系向量重建已启用,但写入服务不可用,已跳过。") - return - - rows = self.metadata_store.get_relations() - if not rows: - logger.info("未发现关系元数据,无需重建关系向量。") - return - - success = 0 - failed = 0 - skipped = 0 - for row in rows: - result = await self.relation_write_service.ensure_relation_vector( - hash_value=str(row["hash"]), - subject=str(row.get("subject", "")), - predicate=str(row.get("predicate", "")), - obj=str(row.get("object", "")), - ) - if result.vector_state == "ready": - if result.vector_written: - success += 1 - else: - skipped += 1 - else: - failed += 1 - - logger.info( - "关系向量重建完成: " - f"total={len(rows)} " - f"success={success} " - f"skipped={skipped} " - f"failed={failed}" - ) - - @staticmethod - def _parse_relation_text(text: str) -> Tuple[str, str, str]: - raw = str(text or "").strip() - if not raw: - return "", "", "" - if "|" in raw: - parts = [p.strip() for p in raw.split("|") if p.strip()] - if len(parts) >= 3: - return parts[0], parts[1], parts[2] - if "->" in raw: - parts = [p.strip() for p in raw.split("->") if p.strip()] - if len(parts) 
>= 3: - return parts[0], parts[1], parts[2] - pieces = raw.split() - if len(pieces) >= 3: - return pieces[0], pieces[1], " ".join(pieces[2:]) - return "", "", "" - - def _import_relation_metadata_from_parquet(self, relation_path: Path) -> int: - if not relation_path.exists(): - return 0 - - try: - parquet_file = pq.ParquetFile(relation_path) - except Exception as e: - logger.warning(f"读取 relation.parquet 失败,跳过关系元数据导入: {e}") - return 0 - - cols = set(parquet_file.schema_arrow.names) - has_triple_cols = {"subject", "predicate", "object"}.issubset(cols) - content_col = "str" if "str" in cols else ("content" if "content" in cols else "") - - imported_hashes = set() - imported = 0 - for record_batch in parquet_file.iter_batches(batch_size=self.batch_size): - df_batch = record_batch.to_pandas() - for _, row in df_batch.iterrows(): - subject = "" - predicate = "" - obj = "" - if has_triple_cols: - subject = str(row.get("subject", "") or "").strip() - predicate = str(row.get("predicate", "") or "").strip() - obj = str(row.get("object", "") or "").strip() - elif content_col: - subject, predicate, obj = self._parse_relation_text(row.get(content_col, "")) - - if not (subject and predicate and obj): - continue - - rel_hash = self.metadata_store.add_relation( - subject=subject, - predicate=predicate, - obj=obj, - source_paragraph=None, - ) - if rel_hash in imported_hashes: - continue - imported_hashes.add(rel_hash) - self.graph_store.add_edges([(subject, obj)], relation_hashes=[rel_hash]) - try: - self.metadata_store.set_relation_vector_state(rel_hash, "none") - except Exception: - pass - imported += 1 - - return imported - - def convert_vectors(self): - """将 Parquet 向量转换为 VectorStore""" - # LPMM 默认文件名 - parquet_files = { - "paragraph": self.lpmm_dir / "paragraph.parquet", - "entity": self.lpmm_dir / "entity.parquet", - "relation": self.lpmm_dir / "relation.parquet" - } - - total_vectors = 0 - - for p_type, p_path in parquet_files.items(): - # 关系向量在当前脚本中无法保证与 MetadataStore 的关系记录一一对应, - # 直接导入会污染召回结果(命中后无法反查 relation 元数据)。 - if p_type == "relation": - relation_count = self._import_relation_metadata_from_parquet(p_path) - logger.warning( - "跳过 relation.parquet 向量导入(保持一致性);" - f"已导入关系元数据: {relation_count}" - ) - continue - - if not p_path.exists(): - logger.warning(f"文件未找到: {p_path}, 跳过 {p_type} 向量。") - continue - - logger.info(f"正在处理 {p_type} 向量,来源: {p_path}...") - try: - parquet_file = pq.ParquetFile(p_path) - total_rows = parquet_file.metadata.num_rows - if total_rows == 0: - logger.info(f"{p_path} 为空,跳过。") - continue - - # LPMM Schema: 'hash', 'embedding', 'str' - cols = parquet_file.schema_arrow.names - # 兼容性检查 - content_col = 'str' if 'str' in cols else 'content' - emb_col = 'embedding' - hash_col = 'hash' - - if content_col not in cols or emb_col not in cols: - logger.error(f"{p_path} 中缺少必要列 (需包含 {content_col}, {emb_col})。发现: {cols}") - continue - - batch_columns = [content_col, emb_col] - if hash_col in cols: - batch_columns.append(hash_col) - - processed_rows = 0 - added_for_type = 0 - batch_idx = 0 - - for record_batch in parquet_file.iter_batches( - batch_size=self.batch_size, - columns=batch_columns, - ): - batch_idx += 1 - df_batch = record_batch.to_pandas() - - embeddings_list = [] - ids_list = [] - - # 同时处理元数据映射 - for _, row in df_batch.iterrows(): - processed_rows += 1 - content = row[content_col] - emb = row[emb_col] - - if content is None or (isinstance(content, float) and np.isnan(content)): - continue - content = str(content).strip() - if not content: - continue - - if emb is None or 
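Note: `_parse_relation_text` above accepts three spellings of a triple, tried in order: "|"-separated, "->"-separated, then whitespace with everything after the second token folded into the object. A condensed restatement with the same fall-through behavior:

    def parse_triple(text: str) -> tuple:
        raw = str(text or "").strip()
        for sep in ("|", "->"):
            if sep in raw:
                parts = [p.strip() for p in raw.split(sep) if p.strip()]
                if len(parts) >= 3:
                    return parts[0], parts[1], parts[2]
        pieces = raw.split()
        if len(pieces) >= 3:
            return pieces[0], pieces[1], " ".join(pieces[2:])
        return "", "", ""

    assert parse_triple("Alice | knows | Bob") == ("Alice", "knows", "Bob")
    assert parse_triple("Alice -> knows -> Bob") == ("Alice", "knows", "Bob")
    assert parse_triple("Alice knows Bob Smith") == ("Alice", "knows", "Bob Smith")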
len(emb) == 0: - continue - - # 先写 MetadataStore,并使用其返回的真实 hash 作为向量 ID - # 保证检索返回 ID 可以直接反查元数据。 - store_id = None - if p_type == "paragraph": - store_id = self.metadata_store.add_paragraph( - content=content, - source="lpmm_import", - knowledge_type="factual", - ) - elif p_type == "entity": - store_id = self.metadata_store.add_entity(name=content) - else: - continue - - raw_hash = row[hash_col] if hash_col in df_batch.columns else None - if raw_hash is not None and not (isinstance(raw_hash, float) and np.isnan(raw_hash)): - self._register_id_mapping(raw_hash, store_id, p_type) - - # 确保 embedding 是 numpy 数组 - emb_np = np.array(emb, dtype=np.float32) - if emb_np.shape[0] != self.dimension: - logger.error(f"维度不匹配: {emb_np.shape[0]} vs {self.dimension}") - continue - - embeddings_list.append(emb_np) - ids_list.append(store_id) - - if embeddings_list: - # 分批添加到向量存储 - vectors_np = np.stack(embeddings_list) - count = self.vector_store.add(vectors_np, ids_list) - added_for_type += count - total_vectors += count - - if batch_idx == 1 or batch_idx % 10 == 0: - logger.info( - f"[{p_type}] 批次 {batch_idx}: 已扫描 {processed_rows}/{total_rows}, 已导入 {added_for_type}" - ) - - logger.info( - f"{p_type} 向量处理完成:总扫描 {processed_rows},总导入 {added_for_type}" - ) - - except Exception as e: - logger.error(f"处理 {p_path} 时出错: {e}") - - # 提交向量存储 - self.vector_store.save() - logger.info(f"向量转换完成。总向量数: {total_vectors}") - - def convert_graph(self): - """将 LPMM 图转换为 GraphStore""" - # LPMM 默认文件名是 rag-graph.graphml - graph_files = [ - self.lpmm_dir / "rag-graph.graphml", - self.lpmm_dir / "graph.graphml", - self.lpmm_dir / "graph_structure.pkl" - ] - - nx_graph = None - - for g_path in graph_files: - if g_path.exists(): - logger.info(f"发现图文件: {g_path}") - try: - if g_path.suffix == ".graphml": - nx_graph = nx.read_graphml(g_path) - elif g_path.suffix == ".pkl": - with open(g_path, "rb") as f: - data = pickle.load(f) - # LPMM 可能会将图存储在包装类中 - if hasattr(data, "graph") and isinstance(data.graph, nx.Graph): - nx_graph = data.graph - elif isinstance(data, nx.Graph): - nx_graph = data - break - except Exception as e: - logger.error(f"加载 {g_path} 失败: {e}") - - if nx_graph is None: - logger.warning("未找到有效的图文件。跳过图转换。") - return - - logger.info(f"已加载图,包含 {nx_graph.number_of_nodes()} 个节点和 {nx_graph.number_of_edges()} 条边。") - - # 1. 添加节点 - # LPMM 节点通常是哈希或带前缀的字符串。 - # 我们需要将它们映射到 A_memorix 格式。 - # 如果 LPMM 使用 "entity-HASH",则与 A_memorix 匹配。 - - nodes_to_add = [] - node_attrs = {} - - for node, attrs in nx_graph.nodes(data=True): - # 假设 LPMM 使用一致的命名 "entity-..." 或 "paragraph-..." - mapped_node = self._map_node_id(node) - nodes_to_add.append(mapped_node) - if attrs: - node_attrs[mapped_node] = attrs - - self.graph_store.add_nodes(nodes_to_add, node_attrs) - - # 2. 
添加边 - edges_to_add = [] - weights = [] - - for u, v, data in nx_graph.edges(data=True): - weight = data.get("weight", 1.0) - edges_to_add.append((self._map_node_id(u), self._map_node_id(v))) - weights.append(float(weight)) - - # 如果可能,将关系同步到 MetadataStore - # 但图的边并不总是包含关系谓词 - # 如果 LPMM 边数据有 'predicate',我们可以添加到元数据 - # 通常 LPMM 边是加权和,谓词信息可能在简单图中丢失 - - if edges_to_add: - self.graph_store.add_edges(edges_to_add, weights) - - self.graph_store.save() - logger.info("图转换完成。") - - def run(self): - self.initialize_stores() - self.convert_vectors() - self.convert_graph() - asyncio.run(self._rebuild_relation_vectors()) - self.vector_store.save() - self.graph_store.save() - self.metadata_store.close() - logger.info("所有转换成功完成。") - - -def main(): - parser = _build_arg_parser() - args = parser.parse_args() - - input_path = Path(args.input) - output_path = Path(args.output) - - if not input_path.exists(): - logger.error(f"输入目录不存在: {input_path}") - sys.exit(1) - - converter = LPMMConverter( - input_path, - output_path, - dimension=args.dim, - batch_size=args.batch_size, - rebuild_relation_vectors=not bool(args.skip_relation_vector_rebuild), - ) - converter.run() - -if __name__ == "__main__": - main() diff --git a/plugins/A_memorix/scripts/import_lpmm_json.py b/plugins/A_memorix/scripts/import_lpmm_json.py deleted file mode 100644 index 2e458e16..00000000 --- a/plugins/A_memorix/scripts/import_lpmm_json.py +++ /dev/null @@ -1,172 +0,0 @@ -#!/usr/bin/env python3 -""" -LPMM OpenIE JSON 导入工具。 - -功能: -1. 读取符合 LPMM 规范的 OpenIE JSON 文件 -2. 转换为 A_Memorix 的统一导入格式 -3. 复用 `process_knowledge.py` 中的 `AutoImporter` 直接入库 -""" - -from __future__ import annotations - -import argparse -import asyncio -import json -import sys -import traceback -from pathlib import Path -from typing import Any, Dict, List - -from rich.console import Console -from rich.progress import BarColumn, Progress, SpinnerColumn, TextColumn, TimeElapsedColumn - -console = Console() - -CURRENT_DIR = Path(__file__).resolve().parent -PLUGIN_ROOT = CURRENT_DIR.parent -WORKSPACE_ROOT = PLUGIN_ROOT.parent -MAIBOT_ROOT = WORKSPACE_ROOT / "MaiBot" -for path in (CURRENT_DIR, WORKSPACE_ROOT, MAIBOT_ROOT, PLUGIN_ROOT): - path_str = str(path) - if path_str not in sys.path: - sys.path.insert(0, path_str) - - -def _build_arg_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser(description="将 LPMM OpenIE JSON 导入 A_Memorix") - parser.add_argument("path", help="LPMM JSON 文件路径或目录") - parser.add_argument("--force", action="store_true", help="强制重新导入") - parser.add_argument("--concurrency", "-c", type=int, default=5, help="并发数") - return parser - - -if any(arg in {"-h", "--help"} for arg in sys.argv[1:]): - _build_arg_parser().print_help() - raise SystemExit(0) - - -try: - from process_knowledge import AutoImporter - from A_memorix.core.utils.hash import compute_paragraph_hash - from src.common.logger import get_logger -except ImportError as exc: # pragma: no cover - script bootstrap - print(f"导入模块失败,请确认 PYTHONPATH 与工作区结构: {exc}") - raise SystemExit(1) - - -logger = get_logger("A_Memorix.LPMMImport") - - -class LPMMConverter: - def convert_lpmm_to_memorix(self, lpmm_data: Dict[str, Any], filename: str) -> Dict[str, Any]: - memorix_data = {"paragraphs": [], "entities": []} - docs = lpmm_data.get("docs", []) or [] - if not docs: - logger.warning(f"文件中未找到 docs 字段: {filename}") - return memorix_data - - all_entities = set() - for doc in docs: - content = str(doc.get("passage", "") or "").strip() - if not content: - continue - - relations: List[Dict[str, str]] = 
[] - for triple in doc.get("extracted_triples", []) or []: - if isinstance(triple, list) and len(triple) == 3: - relations.append( - { - "subject": str(triple[0] or "").strip(), - "predicate": str(triple[1] or "").strip(), - "object": str(triple[2] or "").strip(), - } - ) - - entities = [str(item or "").strip() for item in doc.get("extracted_entities", []) or [] if str(item or "").strip()] - all_entities.update(entities) - for relation in relations: - if relation["subject"]: - all_entities.add(relation["subject"]) - if relation["object"]: - all_entities.add(relation["object"]) - - memorix_data["paragraphs"].append( - { - "hash": compute_paragraph_hash(content), - "content": content, - "source": filename, - "entities": entities, - "relations": relations, - } - ) - - memorix_data["entities"] = sorted(all_entities) - return memorix_data - - -async def main() -> None: - parser = _build_arg_parser() - args = parser.parse_args() - - target_path = Path(args.path) - if not target_path.exists(): - logger.error(f"路径不存在: {target_path}") - return - - if target_path.is_dir(): - files_to_process = list(target_path.glob("*-openie.json")) or list(target_path.glob("*.json")) - else: - files_to_process = [target_path] - - if not files_to_process: - logger.error("未找到可处理的 JSON 文件") - return - - importer = AutoImporter(force=bool(args.force), concurrency=int(args.concurrency)) - if not await importer.initialize(): - logger.error("初始化存储失败") - return - - converter = LPMMConverter() - with Progress( - SpinnerColumn(), - TextColumn("[progress.description]{task.description}"), - BarColumn(), - TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), - TimeElapsedColumn(), - console=console, - transient=False, - ) as progress: - for json_file in files_to_process: - logger.info(f"正在转换并导入: {json_file.name}") - try: - with open(json_file, "r", encoding="utf-8") as handle: - lpmm_data = json.load(handle) - memorix_data = converter.convert_lpmm_to_memorix(lpmm_data, json_file.name) - total_items = len(memorix_data.get("paragraphs", [])) - if total_items <= 0: - logger.warning(f"转换结果为空: {json_file.name}") - continue - - task_id = progress.add_task(f"Importing {json_file.name}", total=total_items) - - def update_progress(step: int = 1) -> None: - progress.advance(task_id, advance=step) - - await importer.import_json_data( - memorix_data, - filename=f"lpmm_{json_file.name}", - progress_callback=update_progress, - ) - except Exception as exc: - logger.error(f"处理文件 {json_file.name} 失败: {exc}\n{traceback.format_exc()}") - - await importer.close() - logger.info("全部处理完成") - - -if __name__ == "__main__": - if sys.platform == "win32": # pragma: no cover - asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) - asyncio.run(main()) diff --git a/plugins/A_memorix/scripts/migrate_chat_history.py b/plugins/A_memorix/scripts/migrate_chat_history.py deleted file mode 100644 index 0fb0bfe1..00000000 --- a/plugins/A_memorix/scripts/migrate_chat_history.py +++ /dev/null @@ -1,110 +0,0 @@ -#!/usr/bin/env python3 -from __future__ import annotations - -import argparse -import asyncio -import json -import sqlite3 -import sys -from datetime import datetime -from pathlib import Path -from typing import Any, Dict - - -CURRENT_DIR = Path(__file__).resolve().parent -PLUGIN_ROOT = CURRENT_DIR.parent -WORKSPACE_ROOT = PLUGIN_ROOT.parent -MAIBOT_ROOT = WORKSPACE_ROOT / "MaiBot" -DEFAULT_DB_PATH = MAIBOT_ROOT / "data" / "MaiBot.db" - -if str(WORKSPACE_ROOT) not in sys.path: - sys.path.insert(0, str(WORKSPACE_ROOT)) -if 
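Note: the converter's per-doc contract, shown with an illustrative example (the hash is a placeholder; the real value comes from `compute_paragraph_hash(content)`):

    lpmm_doc = {
        "passage": "Alice met Bob in Paris.",
        "extracted_triples": [["Alice", "met", "Bob"]],
        "extracted_entities": ["Alice", "Bob", "Paris"],
    }
    # ...becomes one entry under "paragraphs":
    paragraph = {
        "hash": "<paragraph-hash>",
        "content": "Alice met Bob in Paris.",
        "source": "example-openie.json",
        "entities": ["Alice", "Bob", "Paris"],
        "relations": [{"subject": "Alice", "predicate": "met", "object": "Bob"}],
    }
    # top-level "entities" is the sorted union of entity mentions and triple endpoints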
str(MAIBOT_ROOT) not in sys.path: - sys.path.insert(0, str(MAIBOT_ROOT)) - -from A_memorix.core.runtime.sdk_memory_kernel import SDKMemoryKernel # noqa: E402 - - -def _parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser(description="迁移 MaiBot chat_history 到 A_Memorix") - parser.add_argument("--db-path", default=str(DEFAULT_DB_PATH), help="MaiBot SQLite 路径") - parser.add_argument("--data-dir", default="./data", help="A_Memorix 数据目录") - parser.add_argument("--limit", type=int, default=0, help="限制迁移条数,0 表示全部") - parser.add_argument("--dry-run", action="store_true", help="仅预览,不写入") - return parser.parse_args() - - -def _to_timestamp(value: Any) -> float | None: - if value is None: - return None - if isinstance(value, (int, float)): - return float(value) - text = str(value).strip() - if not text: - return None - try: - return datetime.fromisoformat(text).timestamp() - except ValueError: - return None - - -async def _main() -> int: - args = _parse_args() - db_path = Path(args.db_path).resolve() - if not db_path.exists(): - print(f"数据库不存在: {db_path}") - return 1 - - conn = sqlite3.connect(str(db_path)) - conn.row_factory = sqlite3.Row - sql = """ - SELECT id, session_id, start_timestamp, end_timestamp, participants, theme, keywords, summary - FROM chat_history - ORDER BY id ASC - """ - if int(args.limit or 0) > 0: - sql += " LIMIT ?" - rows = conn.execute(sql, (int(args.limit),)).fetchall() - else: - rows = conn.execute(sql).fetchall() - conn.close() - - print(f"chat_history 待处理: {len(rows)}") - if args.dry_run: - for row in rows[:5]: - print(f"- id={row['id']} session={row['session_id']} theme={row['theme']}") - return 0 - - kernel = SDKMemoryKernel(plugin_root=PLUGIN_ROOT, config={"storage": {"data_dir": args.data_dir}}) - await kernel.initialize() - migrated = 0 - skipped = 0 - for row in rows: - participants = json.loads(row["participants"]) if row["participants"] else [] - keywords = json.loads(row["keywords"]) if row["keywords"] else [] - theme = str(row["theme"] or "").strip() - summary = str(row["summary"] or "").strip() - text = f"主题:{theme}\n概括:{summary}".strip() - result: Dict[str, Any] = await kernel.ingest_summary( - external_id=f"chat_history:{row['id']}", - chat_id=str(row["session_id"] or ""), - text=text, - participants=participants, - time_start=_to_timestamp(row["start_timestamp"]), - time_end=_to_timestamp(row["end_timestamp"]), - tags=keywords, - metadata={"theme": theme, "source_row_id": int(row["id"])}, - ) - if result.get("stored_ids"): - migrated += 1 - else: - skipped += 1 - - print(f"迁移完成: migrated={migrated} skipped={skipped}") - print(json.dumps(kernel.memory_stats(), ensure_ascii=False)) - kernel.close() - return 0 - - -if __name__ == "__main__": - raise SystemExit(asyncio.run(_main())) diff --git a/plugins/A_memorix/scripts/migrate_maibot_memory.py b/plugins/A_memorix/scripts/migrate_maibot_memory.py deleted file mode 100644 index 0b26a9cd..00000000 --- a/plugins/A_memorix/scripts/migrate_maibot_memory.py +++ /dev/null @@ -1,1714 +0,0 @@ -#!/usr/bin/env python3 -""" -MaiBot 记忆迁移脚本(chat_history -> A_memorix) - -特性: -1. 高性能:分页读取 + 批量 embedding + 批量写入 -2. 断点续传:基于 last_committed_id 的窗口提交 -3. 精确一次语义:稳定哈希 + 幂等写入 + 向量存在性检查 -4. 
可确认筛选:支持时间区间、聊天流(stream/group/user)筛选,并先预览后确认 -""" - -from __future__ import annotations - -import argparse -import asyncio -import hashlib -import importlib -import json -import logging -import os -import pickle -import sqlite3 -import sys -import time -import traceback -import types -from collections import defaultdict -from dataclasses import dataclass -from datetime import datetime -from pathlib import Path -from typing import Any, Dict, Generator, Iterable, List, Optional, Sequence, Tuple - -import numpy as np -import tomlkit - - -CURRENT_DIR = Path(__file__).resolve().parent -PLUGIN_ROOT = CURRENT_DIR.parent -WORKSPACE_ROOT = PLUGIN_ROOT.parent -MAIBOT_ROOT = WORKSPACE_ROOT / "MaiBot" -RUNTIME_CORE_PACKAGE = "_a_memorix_runtime_core" - -VectorStore = None -GraphStore = None -MetadataStore = None -create_embedding_api_adapter = None -KnowledgeType = None -QuantizationType = None -SparseMatrixFormat = None -compute_hash = None -normalize_text = None -atomic_write = None -model_config = None -RelationWriteService = None - - -def _create_bootstrap_logger(): - fallback = logging.getLogger("A_Memorix.MaiBotMigration") - if not fallback.handlers: - fallback.addHandler(logging.NullHandler()) - try: - for path in (WORKSPACE_ROOT, MAIBOT_ROOT, PLUGIN_ROOT): - path_str = str(path) - if path_str not in sys.path: - sys.path.insert(0, path_str) - from src.common.logger import get_logger - - return get_logger("A_Memorix.MaiBotMigration") - except Exception: - return fallback - - -logger = _create_bootstrap_logger() - - -def _ensure_import_paths() -> None: - for path in (WORKSPACE_ROOT, MAIBOT_ROOT, PLUGIN_ROOT): - path_str = str(path) - if path_str not in sys.path: - sys.path.insert(0, path_str) - - -def _ensure_runtime_core_package() -> str: - existing = sys.modules.get(RUNTIME_CORE_PACKAGE) - if existing is not None and hasattr(existing, "__path__"): - return RUNTIME_CORE_PACKAGE - - pkg = types.ModuleType(RUNTIME_CORE_PACKAGE) - pkg.__path__ = [str(PLUGIN_ROOT / "core")] - pkg.__package__ = RUNTIME_CORE_PACKAGE - sys.modules[RUNTIME_CORE_PACKAGE] = pkg - return RUNTIME_CORE_PACKAGE - - -def _disable_unavailable_gemini_provider() -> None: - global model_config - try: - from google import genai # type: ignore # noqa: F401 - return - except Exception: - pass - - from src.config.config import model_config as loaded_model_config - - providers = list(getattr(loaded_model_config, "api_providers", [])) - if not providers: - model_config = loaded_model_config - return - - kept_providers = [p for p in providers if str(getattr(p, "client_type", "")).lower() != "gemini"] - if len(kept_providers) == len(providers): - model_config = loaded_model_config - return - - loaded_model_config.api_providers = kept_providers - loaded_model_config.api_providers_dict = {p.name: p for p in kept_providers} - - models = list(getattr(loaded_model_config, "models", [])) - kept_models = [m for m in models if m.api_provider in loaded_model_config.api_providers_dict] - loaded_model_config.models = kept_models - loaded_model_config.models_dict = {m.name: m for m in kept_models} - - task_cfg = loaded_model_config.model_task_config - for field_name in task_cfg.__dataclass_fields__.keys(): - task = getattr(task_cfg, field_name, None) - if task is None or not hasattr(task, "model_list"): - continue - task.model_list = [m for m in list(task.model_list) if m in loaded_model_config.models_dict] - - model_config = loaded_model_config - logger.warning("检测到缺少 google.genai,已临时禁用 gemini provider 以保证脚本可运行。") - - -def 
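Note: `_ensure_runtime_core_package` below mounts the plugin's `core/` directory as a synthetic top-level package so `importlib` can resolve its submodules without the `plugins.*` namespace. The mechanism in isolation (the path string is illustrative):

    import sys
    import types

    def mount_package(name: str, directory: str) -> str:
        pkg = types.ModuleType(name)
        pkg.__path__ = [directory]  # giving the module a __path__ makes it a package
        pkg.__package__ = name
        sys.modules[name] = pkg
        return name

    # usage, with an illustrative path:
    # pkg = mount_package("_a_memorix_runtime_core", "plugins/A_memorix/core")
    # vector_store = importlib.import_module(f"{pkg}.storage.vector_store")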
_bootstrap_runtime_symbols() -> None: - global VectorStore - global GraphStore - global MetadataStore - global KnowledgeType - global QuantizationType - global SparseMatrixFormat - global compute_hash - global normalize_text - global atomic_write - global RelationWriteService - global logger - - if VectorStore is not None and compute_hash is not None and atomic_write is not None: - return - - _ensure_import_paths() - - import src # noqa: F401 - from src.common.logger import get_logger - - logger = get_logger("A_Memorix.MaiBotMigration") - - pkg = _ensure_runtime_core_package() - - vector_store_module = importlib.import_module(f"{pkg}.storage.vector_store") - graph_store_module = importlib.import_module(f"{pkg}.storage.graph_store") - metadata_store_module = importlib.import_module(f"{pkg}.storage.metadata_store") - knowledge_types_module = importlib.import_module(f"{pkg}.storage.knowledge_types") - hash_module = importlib.import_module(f"{pkg}.utils.hash") - io_module = importlib.import_module(f"{pkg}.utils.io") - relation_write_service_module = importlib.import_module(f"{pkg}.utils.relation_write_service") - - VectorStore = vector_store_module.VectorStore - GraphStore = graph_store_module.GraphStore - MetadataStore = metadata_store_module.MetadataStore - KnowledgeType = knowledge_types_module.KnowledgeType - QuantizationType = vector_store_module.QuantizationType - SparseMatrixFormat = graph_store_module.SparseMatrixFormat - compute_hash = hash_module.compute_hash - normalize_text = hash_module.normalize_text - atomic_write = io_module.atomic_write - RelationWriteService = relation_write_service_module.RelationWriteService - - -def _load_embedding_adapter_factory() -> None: - global create_embedding_api_adapter - global model_config - - if create_embedding_api_adapter is not None: - return - - _ensure_import_paths() - - from src.config.config import model_config as loaded_model_config - - model_config = loaded_model_config - _disable_unavailable_gemini_provider() - - pkg = _ensure_runtime_core_package() - api_adapter_module = importlib.import_module(f"{pkg}.embedding.api_adapter") - create_embedding_api_adapter = api_adapter_module.create_embedding_api_adapter - - -DEFAULT_SOURCE_DB = MAIBOT_ROOT / "data" / "MaiBot.db" -DEFAULT_TARGET_DATA_DIR = PLUGIN_ROOT / "data" -DEFAULT_CONFIG_PATH = PLUGIN_ROOT / "config.toml" - -MIGRATION_STATE_DIRNAME = "migration_state" -STATE_FILENAME = "chat_history_resume.json" -BAD_ROWS_FILENAME = "chat_history_bad_rows.jsonl" -REPORT_FILENAME = "chat_history_report.json" - - -class MigrationError(Exception): - """迁移流程错误。""" - - -@dataclass -class SelectionFilter: - time_from_ts: Optional[float] - time_to_ts: Optional[float] - stream_ids: List[str] - stream_filter_requested: bool - start_id: Optional[int] - end_id: Optional[int] - time_from_raw: Optional[str] - time_to_raw: Optional[str] - - def fingerprint_payload(self) -> Dict[str, Any]: - return { - "time_from_ts": self.time_from_ts, - "time_to_ts": self.time_to_ts, - "time_from_raw": self.time_from_raw, - "time_to_raw": self.time_to_raw, - "stream_ids": sorted(self.stream_ids), - "stream_filter_requested": self.stream_filter_requested, - "start_id": self.start_id, - "end_id": self.end_id, - } - - -@dataclass -class PreviewResult: - total: int - distribution: List[Tuple[str, int]] - samples: List[Dict[str, Any]] - - -@dataclass -class MappedRow: - row_id: int - chat_id: str - paragraph_hash: str - content: str - source: str - time_meta: Dict[str, Any] - entities: List[str] - relations: List[Tuple[str, str, 
str]] - existing_paragraph_vector: bool - - -def _safe_int(value: Any, default: int) -> int: - try: - return int(value) - except Exception: - return default - - -def _safe_float(value: Any, default: float) -> float: - try: - return float(value) - except Exception: - return default - - -def _normalize_name(value: Any) -> str: - return str(value or "").strip() - - -def _canonical_name(value: Any) -> str: - return _normalize_name(value).lower() - - -def _dedup_keep_order(items: Iterable[str]) -> List[str]: - out: List[str] = [] - seen: set[str] = set() - for raw in items: - v = _normalize_name(raw) - if not v: - continue - k = v.lower() - if k in seen: - continue - seen.add(k) - out.append(v) - return out - - -def _format_ts(ts: Optional[float]) -> str: - if ts is None: - return "-" - try: - return datetime.fromtimestamp(float(ts)).strftime("%Y-%m-%d %H:%M:%S") - except Exception: - return str(ts) - - -def _parse_cli_datetime(text: str, is_end: bool = False) -> float: - value = str(text or "").strip() - if not value: - raise ValueError("时间不能为空") - - formats = [ - ("%Y-%m-%d %H:%M:%S", False), - ("%Y/%m/%d %H:%M:%S", False), - ("%Y-%m-%d %H:%M", False), - ("%Y/%m/%d %H:%M", False), - ("%Y-%m-%d", True), - ("%Y/%m/%d", True), - ] - - for fmt, is_date_only in formats: - try: - dt = datetime.strptime(value, fmt) - if is_date_only and is_end: - dt = dt.replace(hour=23, minute=59, second=59, microsecond=0) - return dt.timestamp() - except ValueError: - continue - - raise ValueError( - f"时间格式错误: {value},仅支持 YYYY-MM-DD、YYYY/MM/DD、YYYY-MM-DD HH:mm[:ss]、YYYY/MM/DD HH:mm[:ss]" - ) - - -def _json_hash(payload: Dict[str, Any]) -> str: - data = json.dumps(payload, ensure_ascii=False, sort_keys=True) - return hashlib.sha1(data.encode("utf-8")).hexdigest() - - -def _deep_merge_dict(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]: - out = dict(base) - for key, value in override.items(): - if isinstance(value, dict) and isinstance(out.get(key), dict): - out[key] = _deep_merge_dict(out[key], value) - else: - out[key] = value - return out - - -def _extract_schema_defaults(schema_obj: Dict[str, Any]) -> Dict[str, Any]: - defaults: Dict[str, Any] = {} - if not isinstance(schema_obj, dict): - return defaults - - for key, spec in schema_obj.items(): - if not isinstance(spec, dict): - continue - if "default" in spec: - defaults[key] = spec.get("default") - continue - props = spec.get("properties") - if isinstance(props, dict): - defaults[key] = _extract_schema_defaults(props) - return defaults - - -def _load_manifest_defaults() -> Dict[str, Any]: - manifest_path = PLUGIN_ROOT / "_manifest.json" - if not manifest_path.exists(): - return {} - try: - with open(manifest_path, "r", encoding="utf-8") as f: - payload = json.load(f) - schema = payload.get("config_schema") - if isinstance(schema, dict): - return _extract_schema_defaults(schema) - except Exception as e: - logger.warning(f"读取 manifest 默认配置失败,已回退空配置: {e}") - return {} - - -def _build_source_db_fingerprint(db_path: Path) -> Dict[str, Any]: - stat = db_path.stat() - payload = { - "path": str(db_path.resolve()), - "size": stat.st_size, - "mtime": stat.st_mtime, - } - payload["sha1"] = _json_hash(payload) - return payload - - -def _state_path(target_data_dir: Path) -> Path: - return target_data_dir / MIGRATION_STATE_DIRNAME / STATE_FILENAME - - -def _bad_rows_path(target_data_dir: Path) -> Path: - return target_data_dir / MIGRATION_STATE_DIRNAME / BAD_ROWS_FILENAME - - -def _report_path(target_data_dir: Path) -> Path: - return target_data_dir / 
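Note: resume safety rests on `_json_hash` producing an order-independent fingerprint. `sort_keys=True` canonicalizes the payload before hashing, so logically equal filters match across runs:

    import hashlib
    import json

    def json_hash(payload: dict) -> str:
        data = json.dumps(payload, ensure_ascii=False, sort_keys=True)
        return hashlib.sha1(data.encode("utf-8")).hexdigest()

    assert json_hash({"a": 1, "b": 2}) == json_hash({"b": 2, "a": 1})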
MIGRATION_STATE_DIRNAME / REPORT_FILENAME - - -def _dump_json_atomic(path: Path, payload: Dict[str, Any]) -> None: - if atomic_write is None: - path.parent.mkdir(parents=True, exist_ok=True) - tmp = path.with_suffix(path.suffix + ".tmp") - with open(tmp, "w", encoding="utf-8") as f: - json.dump(payload, f, ensure_ascii=False, indent=2) - f.write("\n") - f.flush() - os.fsync(f.fileno()) - os.replace(tmp, path) - return - - with atomic_write(path, mode="w", encoding="utf-8") as f: - json.dump(payload, f, ensure_ascii=False, indent=2) - f.write("\n") - - -class SourceDB: - def __init__(self, db_path: Path): - self.db_path = db_path - self.conn: Optional[sqlite3.Connection] = None - - def connect(self) -> None: - if not self.db_path.exists(): - raise MigrationError(f"源数据库不存在: {self.db_path}") - - uri = f"file:{self.db_path.resolve().as_posix()}?mode=ro" - try: - self.conn = sqlite3.connect(uri, uri=True, check_same_thread=False) - except sqlite3.OperationalError: - self.conn = sqlite3.connect(str(self.db_path.resolve()), check_same_thread=False) - - self.conn.row_factory = sqlite3.Row - pragmas = [ - "PRAGMA query_only = ON", - "PRAGMA cache_size = -128000", - "PRAGMA temp_store = MEMORY", - "PRAGMA synchronous = OFF", - "PRAGMA journal_mode = WAL", - ] - for sql in pragmas: - try: - self.conn.execute(sql) - except sqlite3.OperationalError: - # 部分 PRAGMA 在 mode=ro 下会失败,不影响只读扫描能力 - continue - - def close(self) -> None: - if self.conn is not None: - self.conn.close() - self.conn = None - - def _require_conn(self) -> sqlite3.Connection: - if self.conn is None: - raise MigrationError("源数据库尚未连接") - return self.conn - - def resolve_stream_ids( - self, - stream_ids: Sequence[str], - group_ids: Sequence[str], - user_ids: Sequence[str], - ) -> List[str]: - conn = self._require_conn() - resolved: set[str] = set(_normalize_name(x) for x in stream_ids if _normalize_name(x)) - has_group_or_user = any(_normalize_name(x) for x in group_ids) or any(_normalize_name(x) for x in user_ids) - if not has_group_or_user: - return sorted(resolved) - - table_exists = conn.execute( - "SELECT 1 FROM sqlite_master WHERE type='table' AND name='chat_streams' LIMIT 1" - ).fetchone() - if table_exists is None: - raise MigrationError("源库缺少 chat_streams 表,无法根据 --group-id/--user-id 映射 stream_id") - - def _select_by_field(field: str, values: Sequence[str]) -> None: - values_norm = [_normalize_name(v) for v in values if _normalize_name(v)] - if not values_norm: - return - placeholders = ",".join("?" for _ in values_norm) - sql = f"SELECT DISTINCT stream_id FROM chat_streams WHERE {field} IN ({placeholders})" - cur = conn.execute(sql, tuple(values_norm)) - for row in cur.fetchall(): - sid = _normalize_name(row["stream_id"]) - if sid: - resolved.add(sid) - - _select_by_field("group_id", group_ids) - _select_by_field("user_id", user_ids) - return sorted(resolved) - - @staticmethod - def _build_where( - selection: SelectionFilter, - start_after_id: Optional[int] = None, - ) -> Tuple[str, List[Any]]: - conditions: List[str] = [] - params: List[Any] = [] - - if selection.start_id is not None: - conditions.append("id >= ?") - params.append(selection.start_id) - if selection.end_id is not None: - conditions.append("id <= ?") - params.append(selection.end_id) - if start_after_id is not None: - conditions.append("id > ?") - params.append(start_after_id) - - if selection.stream_ids: - placeholders = ",".join("?" 
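Note: `SourceDB.connect` opens the MaiBot database read-only via a `file:` URI and falls back to a plain connection where `mode=ro` is unsupported. Minus the PRAGMA tuning, the idiom is:

    import sqlite3
    from pathlib import Path

    def connect_readonly(db_path: Path) -> sqlite3.Connection:
        uri = f"file:{db_path.resolve().as_posix()}?mode=ro"
        try:
            conn = sqlite3.connect(uri, uri=True, check_same_thread=False)
        except sqlite3.OperationalError:
            # some builds reject URI mode; fall back to a normal connection
            conn = sqlite3.connect(str(db_path.resolve()), check_same_thread=False)
        conn.row_factory = sqlite3.Row
        return conn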
for _ in selection.stream_ids) - conditions.append(f"chat_id IN ({placeholders})") - params.extend(selection.stream_ids) - elif selection.stream_filter_requested: - conditions.append("1=0") - - if selection.time_from_ts is not None and selection.time_to_ts is not None: - conditions.append("(end_time >= ? AND start_time <= ?)") - params.extend([selection.time_from_ts, selection.time_to_ts]) - elif selection.time_from_ts is not None: - conditions.append("(end_time >= ?)") - params.append(selection.time_from_ts) - elif selection.time_to_ts is not None: - conditions.append("(start_time <= ?)") - params.append(selection.time_to_ts) - - where_sql = "WHERE " + " AND ".join(conditions) if conditions else "" - return where_sql, params - - def count_candidates(self, selection: SelectionFilter) -> int: - conn = self._require_conn() - where_sql, params = self._build_where(selection, start_after_id=None) - sql = f"SELECT COUNT(*) AS c FROM chat_history {where_sql}" - cur = conn.execute(sql, tuple(params)) - return int(cur.fetchone()["c"]) - - def preview(self, selection: SelectionFilter, preview_limit: int) -> PreviewResult: - conn = self._require_conn() - where_sql, params = self._build_where(selection, start_after_id=None) - - total_sql = f"SELECT COUNT(*) AS c FROM chat_history {where_sql}" - total = int(conn.execute(total_sql, tuple(params)).fetchone()["c"]) - - dist_sql = ( - f"SELECT chat_id, COUNT(*) AS c FROM chat_history {where_sql} " - "GROUP BY chat_id ORDER BY c DESC LIMIT 30" - ) - distribution = [ - (_normalize_name(row["chat_id"]), int(row["c"])) - for row in conn.execute(dist_sql, tuple(params)).fetchall() - ] - - sample_sql = ( - "SELECT id, chat_id, start_time, end_time, theme, summary " - f"FROM chat_history {where_sql} ORDER BY id ASC LIMIT ?" - ) - sample_params = list(params) - sample_params.append(max(1, int(preview_limit))) - samples = [dict(row) for row in conn.execute(sample_sql, tuple(sample_params)).fetchall()] - - return PreviewResult(total=total, distribution=distribution, samples=samples) - - def iter_rows( - self, - selection: SelectionFilter, - batch_size: int, - start_after_id: int, - ) -> Generator[List[sqlite3.Row], None, None]: - conn = self._require_conn() - cursor = int(start_after_id) - while True: - where_sql, params = self._build_where(selection, start_after_id=cursor) - sql = ( - "SELECT id, chat_id, start_time, end_time, participants, theme, keywords, summary " - f"FROM chat_history {where_sql} ORDER BY id ASC LIMIT ?" - ) - bind = list(params) - bind.append(max(1, int(batch_size))) - rows = conn.execute(sql, tuple(bind)).fetchall() - if not rows: - break - yield rows - cursor = int(rows[-1]["id"]) - - def sample_rows_for_verify( - self, - selection: SelectionFilter, - sample_size: int, - ) -> List[sqlite3.Row]: - conn = self._require_conn() - where_sql, params = self._build_where(selection, start_after_id=None) - sql = ( - "SELECT id, chat_id, start_time, end_time, participants, theme, keywords, summary " - f"FROM chat_history {where_sql} ORDER BY RANDOM() LIMIT ?" 
- ) - bind = list(params) - bind.append(max(1, int(sample_size))) - return conn.execute(sql, tuple(bind)).fetchall() - - -class MigrationRunner: - def __init__(self, args: argparse.Namespace): - self.args = args - self.source_db_path = Path(args.source_db).resolve() - self.target_data_dir = Path(args.target_data_dir).resolve() - self.state_file = _state_path(self.target_data_dir) - self.bad_rows_file = _bad_rows_path(self.target_data_dir) - self.report_file = _report_path(self.target_data_dir) - - self.source_db = SourceDB(self.source_db_path) - - self.vector_store = None - self.graph_store = None - self.metadata_store = None - self.embedding_manager = None - self.relation_write_service = None - self.plugin_config: Dict[str, Any] = {} - self.embed_workers: int = 5 - - self.selection: Optional[SelectionFilter] = None - self.filter_fingerprint: str = "" - self.source_db_fingerprint: Dict[str, Any] = {} - self.source_db_fingerprint_hash: str = "" - self.state: Dict[str, Any] = {} - - self.started_at = time.time() - self.exit_code = 0 - self.failed = False - self.fail_reason: Optional[str] = None - - self.stats: Dict[str, Any] = { - "source_matched_total": 0, - "scanned_rows": 0, - "valid_rows": 0, - "migrated_rows": 0, - "skipped_existing_rows": 0, - "bad_rows": 0, - "paragraph_vectors_added": 0, - "entity_vectors_added": 0, - "relations_written": 0, - "relation_vectors_written": 0, - "relation_vectors_failed": 0, - "relation_vectors_skipped": 0, - "graph_edges_written": 0, - "windows_committed": 0, - "last_committed_id": 0, - "verify_sample_size": 0, - "verify_paragraph_missing": 0, - "verify_vector_missing": 0, - "verify_relation_missing": 0, - "verify_edge_missing": 0, - "verify_passed": False, - } - - async def run(self) -> int: - try: - _bootstrap_runtime_symbols() - self._prepare_paths() - - self.source_db.connect() - self.selection = self._build_selection_filter() - self.filter_fingerprint = _json_hash(self.selection.fingerprint_payload()) - - self.source_db_fingerprint = _build_source_db_fingerprint(self.source_db_path) - self.source_db_fingerprint_hash = str(self.source_db_fingerprint.get("sha1", "")) - - preview = self.source_db.preview(self.selection, preview_limit=self.args.preview_limit) - self.stats["source_matched_total"] = int(preview.total) - self._print_preview(preview) - - if preview.total <= 0: - logger.info("筛选后无数据,退出。") - self.stats["verify_passed"] = True - if self.args.verify_only: - self._load_plugin_config() - await self._init_target_stores(require_embedding=False) - await self._verify(strict=True) - return self._finalize() - - if self.args.verify_only: - self._load_plugin_config() - await self._init_target_stores(require_embedding=False) - await self._verify(strict=True) - return self._finalize() - - if self.args.dry_run: - logger.info("dry-run 模式:仅预览,不写入。") - return self._finalize() - - if not self.args.yes: - if not self._confirm(): - logger.info("用户取消执行。") - return self._finalize() - - self._load_plugin_config() - await self._init_target_stores(require_embedding=True) - self._load_or_init_state() - - start_after_id = self._resolve_start_after_id() - await self._migrate(start_after_id=start_after_id) - await self._verify(strict=True) - return self._finalize() - except Exception as e: - self.failed = True - self.fail_reason = str(e) - logger.error(f"迁移失败: {e}\n{traceback.format_exc()}") - return self._finalize() - finally: - self._close() - - def _prepare_paths(self) -> None: - (self.target_data_dir / MIGRATION_STATE_DIRNAME).mkdir(parents=True, exist_ok=True) - if 
self.args.reset_state and self.state_file.exists(): - self.state_file.unlink() - if self.args.reset_state and self.bad_rows_file.exists(): - self.bad_rows_file.unlink() - - def _load_plugin_config(self) -> None: - merged = _load_manifest_defaults() - - config_path = DEFAULT_CONFIG_PATH - if config_path.exists(): - try: - with open(config_path, "r", encoding="utf-8") as f: - raw = tomlkit.load(f) - if isinstance(raw, dict): - merged = _deep_merge_dict(merged, dict(raw)) - except Exception as e: - logger.warning(f"读取插件配置失败,继续使用默认配置: {e}") - - self.plugin_config = merged - - def _read_existing_vector_dimension(self, fallback_dimension: int) -> int: - meta_path = self.target_data_dir / "vectors" / "vectors_metadata.pkl" - if not meta_path.exists(): - return fallback_dimension - try: - with open(meta_path, "rb") as f: - payload = pickle.load(f) - value = _safe_int(payload.get("dimension"), fallback_dimension) - return max(1, value) - except Exception: - return fallback_dimension - - async def _init_target_stores(self, require_embedding: bool) -> None: - if VectorStore is None or GraphStore is None or MetadataStore is None: - raise MigrationError("运行时初始化失败:存储组件不可用") - - emb_cfg = self.plugin_config.get("embedding", {}) if isinstance(self.plugin_config, dict) else {} - graph_cfg = self.plugin_config.get("graph", {}) if isinstance(self.plugin_config, dict) else {} - - self.embed_workers = max(1, _safe_int(self.args.embed_workers, _safe_int(emb_cfg.get("max_concurrent"), 5))) - emb_batch_size = max(1, _safe_int(emb_cfg.get("batch_size"), 32)) - emb_default_dim = max(1, _safe_int(emb_cfg.get("dimension"), 1024)) - emb_model_name = str(emb_cfg.get("model_name", "auto")) - emb_retry = emb_cfg.get("retry", {}) if isinstance(emb_cfg.get("retry", {}), dict) else {} - - if require_embedding: - _load_embedding_adapter_factory() - if create_embedding_api_adapter is None: - raise MigrationError("运行时初始化失败:embedding 适配器不可用") - - if model_config is not None: - embedding_task = getattr(getattr(model_config, "model_task_config", None), "embedding", None) - if embedding_task is not None and hasattr(embedding_task, "model_list"): - if not list(embedding_task.model_list): - raise MigrationError( - "当前配置没有可用 embedding 模型。若你使用 gemini provider,请先安装 `google-genai` " - "或切换到可用的 embedding provider。" - ) - - self.embedding_manager = create_embedding_api_adapter( - batch_size=emb_batch_size, - max_concurrent=self.embed_workers, - default_dimension=emb_default_dim, - model_name=emb_model_name, - retry_config=emb_retry, - ) - - try: - detected_dim = self._read_existing_vector_dimension(emb_default_dim) - has_existing_vectors = (self.target_data_dir / "vectors" / "vectors_metadata.pkl").exists() - if not has_existing_vectors: - detected_dim = await self.embedding_manager._detect_dimension() - except Exception as e: - logger.warning(f"嵌入维度探测失败,回退配置维度: {e}") - detected_dim = self._read_existing_vector_dimension(emb_default_dim) - else: - detected_dim = self._read_existing_vector_dimension(emb_default_dim) - self.embedding_manager = None - - q_type = str(emb_cfg.get("quantization_type", "int8")).lower() - if q_type != "int8": - raise MigrationError( - "embedding.quantization_type 在 vNext 仅允许 int8(SQ8)。" - " 请先执行 scripts/release_vnext_migrate.py migrate。" - ) - quantization = QuantizationType.INT8 - - matrix_fmt = str(graph_cfg.get("sparse_matrix_format", "csr")).lower() - fmt_map = { - "csr": SparseMatrixFormat.CSR, - "csc": SparseMatrixFormat.CSC, - } - sparse_fmt = fmt_map.get(matrix_fmt, SparseMatrixFormat.CSR) - - 
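A note on the int8 gate above: SQ8 (8-bit scalar quantization) stores each embedding as int8 codes plus a scale factor, cutting vector memory roughly 4x versus float32 at a small accuracy cost. Below is a minimal sketch of symmetric per-vector SQ8 in plain NumPy, as an illustration of the idea only — it is not the plugin's actual `VectorStore` implementation:

```python
import numpy as np

def sq8_quantize(vec: np.ndarray):
    """Symmetric per-vector SQ8: map floats to int8 codes plus one float scale."""
    scale = float(np.max(np.abs(vec))) / 127.0
    if scale == 0.0:
        scale = 1.0  # all-zero vector: any scale round-trips to zeros
    codes = np.clip(np.round(vec / scale), -127, 127).astype(np.int8)
    return codes, scale

def sq8_dequantize(codes: np.ndarray, scale: float) -> np.ndarray:
    return codes.astype(np.float32) * scale

rng = np.random.default_rng(0)
v = rng.standard_normal(1024).astype(np.float32)
codes, scale = sq8_quantize(v)
# Relative reconstruction error is typically well under 1% for unit-scale data.
print(float(np.linalg.norm(v - sq8_dequantize(codes, scale)) / np.linalg.norm(v)))
```

The per-vector scale bounds the rounding error at `scale / 2` per component, which is why enforcing SQ8 here trades very little recall for a large memory saving.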
self.vector_store = VectorStore( - dimension=detected_dim, - quantization_type=quantization, - data_dir=self.target_data_dir / "vectors", - ) - self.graph_store = GraphStore( - matrix_format=sparse_fmt, - data_dir=self.target_data_dir / "graph", - ) - self.metadata_store = MetadataStore(data_dir=self.target_data_dir / "metadata") - self.metadata_store.connect() - - if self.vector_store.has_data(): - self.vector_store.load() - if self.graph_store.has_data(): - self.graph_store.load() - - self.relation_write_service = None - if require_embedding and RelationWriteService is not None and self.embedding_manager is not None: - self.relation_write_service = RelationWriteService( - metadata_store=self.metadata_store, - graph_store=self.graph_store, - vector_store=self.vector_store, - embedding_manager=self.embedding_manager, - ) - - logger.info( - f"目标存储初始化完成: dim={self.vector_store.dimension}, quant={q_type}, graph_fmt={matrix_fmt}, " - f"embed_workers={self.embed_workers}" - ) - - def _should_write_relation_vectors(self) -> bool: - retrieval_cfg = self.plugin_config.get("retrieval", {}) if isinstance(self.plugin_config, dict) else {} - if not isinstance(retrieval_cfg, dict): - return False - rv_cfg = retrieval_cfg.get("relation_vectorization", {}) - if not isinstance(rv_cfg, dict): - return False - return bool(rv_cfg.get("enabled", False)) and bool(rv_cfg.get("write_on_import", True)) - - async def _ensure_relation_vectors_for_records( - self, - relation_records: Dict[str, Tuple[str, str, str, float, Optional[str], bytes]], - ) -> None: - if not relation_records: - return - if self.relation_write_service is None: - return - - success = 0 - failed = 0 - skipped = 0 - for relation_hash, rel in relation_records.items(): - result = await self.relation_write_service.ensure_relation_vector( - hash_value=relation_hash, - subject=str(rel[0]), - predicate=str(rel[1]), - obj=str(rel[2]), - ) - if result.vector_state == "ready": - if result.vector_written: - success += 1 - else: - skipped += 1 - else: - failed += 1 - - self.stats["relation_vectors_written"] += success - self.stats["relation_vectors_failed"] += failed - self.stats["relation_vectors_skipped"] += skipped - - def _build_selection_filter(self) -> SelectionFilter: - if self.args.start_id is not None and self.args.start_id <= 0: - raise MigrationError("--start-id 必须 > 0") - if self.args.end_id is not None and self.args.end_id <= 0: - raise MigrationError("--end-id 必须 > 0") - if self.args.start_id is not None and self.args.end_id is not None and self.args.start_id > self.args.end_id: - raise MigrationError("--start-id 不能大于 --end-id") - - time_from_ts = _parse_cli_datetime(self.args.time_from, is_end=False) if self.args.time_from else None - time_to_ts = _parse_cli_datetime(self.args.time_to, is_end=True) if self.args.time_to else None - if time_from_ts is not None and time_to_ts is not None and time_from_ts > time_to_ts: - raise MigrationError("--time-from 不能晚于 --time-to") - - stream_filter_requested = bool( - (self.args.stream_id or []) or (self.args.group_id or []) or (self.args.user_id or []) - ) - stream_ids = self.source_db.resolve_stream_ids( - stream_ids=self.args.stream_id or [], - group_ids=self.args.group_id or [], - user_ids=self.args.user_id or [], - ) - if stream_filter_requested and not stream_ids: - logger.warning("已指定 stream/group/user 筛选,但未解析到任何 stream_id,结果将为空。") - - logger.info( - f"筛选条件: time_from={self.args.time_from or '-'}, time_to={self.args.time_to or '-'}, " - f"stream_ids={len(stream_ids)}, 
stream_filter_requested={stream_filter_requested}" - ) - - return SelectionFilter( - time_from_ts=time_from_ts, - time_to_ts=time_to_ts, - stream_ids=stream_ids, - stream_filter_requested=stream_filter_requested, - start_id=self.args.start_id, - end_id=self.args.end_id, - time_from_raw=self.args.time_from, - time_to_raw=self.args.time_to, - ) - - def _load_or_init_state(self) -> None: - if self.args.start_id is not None: - logger.info("检测到 --start-id,已按用户指定起点覆盖断点状态。") - self.state = self._new_state(last_committed_id=int(self.args.start_id) - 1) - return - - if self.args.no_resume: - self.state = self._new_state(last_committed_id=0) - return - - if not self.state_file.exists(): - self.state = self._new_state(last_committed_id=0) - return - - with open(self.state_file, "r", encoding="utf-8") as f: - loaded = json.load(f) - - loaded_filter_fp = str(loaded.get("filter_fingerprint", "")) - loaded_source_fp = str(loaded.get("source_db_fingerprint", "")) - - if loaded_filter_fp != self.filter_fingerprint or loaded_source_fp != self.source_db_fingerprint_hash: - if self.args.dry_run or self.args.verify_only: - logger.info("检测到断点与当前筛选不一致;当前为只读模式,将忽略旧断点。") - self.state = self._new_state(last_committed_id=0) - return - raise MigrationError( - "检测到筛选条件或源库指纹变化,已拒绝继续续传。请使用 --reset-state 或调整参数后重试。" - ) - - self.state = loaded - stored_stats = loaded.get("stats", {}) - if isinstance(stored_stats, dict): - for k, v in stored_stats.items(): - if k in self.stats and isinstance(v, (int, float, bool)): - self.stats[k] = v - - def _new_state(self, last_committed_id: int) -> Dict[str, Any]: - return { - "version": 1, - "updated_at": time.time(), - "last_committed_id": int(last_committed_id), - "filter_fingerprint": self.filter_fingerprint, - "source_db_fingerprint": self.source_db_fingerprint_hash, - "source_db_meta": self.source_db_fingerprint, - "stats": dict(self.stats), - } - - def _flush_state(self, last_committed_id: int) -> None: - self.stats["last_committed_id"] = int(last_committed_id) - self.state = { - "version": 1, - "updated_at": time.time(), - "last_committed_id": int(last_committed_id), - "filter_fingerprint": self.filter_fingerprint, - "source_db_fingerprint": self.source_db_fingerprint_hash, - "source_db_meta": self.source_db_fingerprint, - "stats": dict(self.stats), - } - _dump_json_atomic(self.state_file, self.state) - - def _resolve_start_after_id(self) -> int: - if self.selection is None: - raise MigrationError("selection 未初始化") - - if self.args.start_id is not None: - return int(self.args.start_id) - 1 - - if self.args.no_resume: - return 0 - - state_last = _safe_int(self.state.get("last_committed_id"), 0) if self.state else 0 - return max(0, state_last) - - def _print_preview(self, preview: PreviewResult) -> None: - print("\n=== Migration Preview ===") - print(f"source_db: {self.source_db_path}") - print(f"target_data_dir: {self.target_data_dir}") - if self.selection: - print( - f"time_window: [{self.selection.time_from_raw or '-'} ~ {self.selection.time_to_raw or '-'}] " - f"(ts: {_format_ts(self.selection.time_from_ts)} ~ {_format_ts(self.selection.time_to_ts)})" - ) - print( - f"id_window: [{self.selection.start_id or '-'} ~ {self.selection.end_id or '-'}], " - f"selected_streams={len(self.selection.stream_ids)}" - ) - print(f"matched_rows: {preview.total}") - - if preview.distribution: - print("top_chat_distribution:") - for cid, cnt in preview.distribution[:10]: - print(f" - {cid}: {cnt}") - else: - print("top_chat_distribution: (none)") - - if preview.samples: - print(f"samples 
(first {len(preview.samples)}):") - for row in preview.samples: - summary_preview = _normalize_name(row.get("summary", ""))[:60] - theme_preview = _normalize_name(row.get("theme", ""))[:30] - print( - f" - id={row.get('id')} chat_id={row.get('chat_id')} " - f"[{_format_ts(row.get('start_time'))} ~ {_format_ts(row.get('end_time'))}] " - f"theme={theme_preview!r} summary={summary_preview!r}" - ) - print("=========================\n") - - def _confirm(self) -> bool: - answer = input("确认按以上筛选执行迁移?输入 y 继续 [y/N]: ").strip().lower() - return answer in {"y", "yes"} - - def _parse_json_list_field(self, raw: Any, field_name: str, row_id: int) -> List[str]: - if raw is None: - return [] - if isinstance(raw, list): - data = raw - elif isinstance(raw, str): - try: - parsed = json.loads(raw) - except Exception as e: - raise ValueError(f"{field_name} JSON 解析失败: {e}") from e - if not isinstance(parsed, list): - raise ValueError(f"{field_name} JSON 必须是 list,当前为 {type(parsed).__name__}") - data = parsed - else: - raise ValueError(f"{field_name} 字段类型不支持: {type(raw).__name__}") - return _dedup_keep_order(str(x) for x in data if _normalize_name(x)) - - def _map_row(self, row: sqlite3.Row) -> MappedRow: - row_id = int(row["id"]) - chat_id = _normalize_name(row["chat_id"]) - theme = _normalize_name(row["theme"]) - summary = _normalize_name(row["summary"]) - - participants = self._parse_json_list_field(row["participants"], "participants", row_id) - keywords = self._parse_json_list_field(row["keywords"], "keywords", row_id) - keywords_top = keywords[:8] - - participants_text = "、".join(participants) if participants else "" - keywords_text = "、".join(keywords_top) if keywords_top else "" - - content = ( - f"话题:{theme}\n" - f"概括:{summary}\n" - f"参与者:{participants_text}\n" - f"关键词:{keywords_text}" - ).strip() - - paragraph_hash = compute_hash(normalize_text(content)) - source = f"maibot.chat_history:{chat_id}" - - start_time = _safe_float(row["start_time"], 0.0) - end_time = _safe_float(row["end_time"], start_time) - time_meta = { - "event_time_start": start_time, - "event_time_end": end_time, - "time_granularity": "minute", - "time_confidence": 0.95, - } - - entities = _dedup_keep_order([*participants, theme, *keywords_top]) - relations: List[Tuple[str, str, str]] = [] - if theme: - for participant in participants: - relations.append((participant, "参与话题", theme)) - for keyword in keywords_top: - relations.append((theme, "关键词", keyword)) - - existing_vector = paragraph_hash in self.vector_store - return MappedRow( - row_id=row_id, - chat_id=chat_id, - paragraph_hash=paragraph_hash, - content=content, - source=source, - time_meta=time_meta, - entities=entities, - relations=relations, - existing_paragraph_vector=existing_vector, - ) - - def _append_bad_row(self, row: sqlite3.Row, reason: str) -> None: - payload = { - "id": int(row["id"]), - "chat_id": _normalize_name(row["chat_id"]), - "start_time": row["start_time"], - "end_time": row["end_time"], - "participants": row["participants"], - "theme": _normalize_name(row["theme"]), - "keywords": row["keywords"], - "summary": row["summary"], - "error": reason, - "timestamp": time.time(), - } - self.bad_rows_file.parent.mkdir(parents=True, exist_ok=True) - with open(self.bad_rows_file, "a", encoding="utf-8") as f: - f.write(json.dumps(payload, ensure_ascii=False)) - f.write("\n") - - async def _migrate(self, start_after_id: int) -> None: - if self.selection is None: - raise MigrationError("selection 未初始化") - - read_batch_size = max(1, int(self.args.read_batch_size)) - 
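The scan loop below consumes `SourceDB.iter_rows`, which pages with `WHERE id > ? ORDER BY id ASC LIMIT ?` rather than `OFFSET`: every page is an index seek, already-scanned rows are never revisited, and the id of the last committed row doubles as the resume checkpoint persisted in the state file. A stripped-down sketch of the same keyset-pagination pattern — the `items` table and `payload` column are hypothetical stand-ins, not the migrator's real schema:

```python
import sqlite3
from typing import Iterator, Tuple

def iter_by_keyset(conn: sqlite3.Connection, start_after: int = 0,
                   batch: int = 2000) -> Iterator[Tuple[int, str]]:
    """Yield rows in primary-key order, one index seek per page."""
    cursor = start_after
    while True:
        rows = conn.execute(
            "SELECT id, payload FROM items WHERE id > ? ORDER BY id ASC LIMIT ?",
            (cursor, batch),
        ).fetchall()
        if not rows:
            return
        yield from rows
        # Persisting this cursor after each committed window is what makes
        # the scan resumable: restart with start_after=cursor and nothing repeats.
        cursor = rows[-1][0]
```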
commit_window_rows = max(1, int(self.args.commit_window_rows)) - log_every = max(1, int(self.args.log_every)) - - window_rows: List[MappedRow] = [] - window_scanned = 0 - last_seen_id = start_after_id - - logger.info( - f"开始迁移: start_after_id={start_after_id}, read_batch_size={read_batch_size}, " - f"commit_window_rows={commit_window_rows}" - ) - - for batch in self.source_db.iter_rows(self.selection, read_batch_size, start_after_id): - for row in batch: - row_id = int(row["id"]) - last_seen_id = row_id - self.stats["scanned_rows"] += 1 - window_scanned += 1 - - try: - mapped = self._map_row(row) - except Exception as e: - self.stats["bad_rows"] += 1 - self._append_bad_row(row, str(e)) - if self.stats["bad_rows"] > int(self.args.max_errors): - raise MigrationError( - f"坏行数量超过上限 max_errors={self.args.max_errors},已中止。" - ) - continue - - self.stats["valid_rows"] += 1 - if mapped.existing_paragraph_vector: - self.stats["skipped_existing_rows"] += 1 - else: - self.stats["migrated_rows"] += 1 - window_rows.append(mapped) - - if window_scanned >= commit_window_rows: - await self._commit_window(window_rows, last_seen_id) - window_rows = [] - window_scanned = 0 - - if self.stats["scanned_rows"] % log_every == 0: - logger.info( - f"迁移进度: scanned={self.stats['scanned_rows']}/{self.stats['source_matched_total']}, " - f"valid={self.stats['valid_rows']}, bad={self.stats['bad_rows']}, " - f"last_id={last_seen_id}" - ) - - if window_scanned > 0 or window_rows: - await self._commit_window(window_rows, last_seen_id) - - logger.info( - f"迁移主流程完成: scanned={self.stats['scanned_rows']}, valid={self.stats['valid_rows']}, " - f"bad={self.stats['bad_rows']}, last_committed_id={self.stats['last_committed_id']}" - ) - - async def _commit_window(self, rows: List[MappedRow], last_seen_id: int) -> None: - if not rows: - self._flush_state(last_seen_id) - self.stats["windows_committed"] += 1 - return - - now_ts = time.time() - empty_meta_blob = pickle.dumps({}) - - conn = self.metadata_store.get_connection() - - cursor = conn.cursor() - - # 批量查询本窗口内已存在的段落,保证重跑时 entity/mention 不重复累计 - existing_paragraph_hashes: set[str] = set() - all_hashes = [item.paragraph_hash for item in rows] - for i in range(0, len(all_hashes), 800): - batch_hashes = all_hashes[i : i + 800] - if not batch_hashes: - continue - placeholders = ",".join("?" 
for _ in batch_hashes) - existing_rows = cursor.execute( - f"SELECT hash FROM paragraphs WHERE hash IN ({placeholders})", - tuple(batch_hashes), - ).fetchall() - for row in existing_rows: - existing_paragraph_hashes.add(str(row["hash"])) - - paragraph_records: List[Tuple[Any, ...]] = [] - paragraph_embed_map: Dict[str, str] = {} - - entity_display: Dict[str, str] = {} - entity_counts: Dict[str, int] = defaultdict(int) - paragraph_entity_mentions: Dict[Tuple[str, str], int] = defaultdict(int) - entity_embed_map: Dict[str, str] = {} - - relation_records: Dict[str, Tuple[str, str, str, float, Optional[str], bytes]] = {} - paragraph_relation_links: set[Tuple[str, str]] = set() - - for item in rows: - is_new_paragraph = item.paragraph_hash not in existing_paragraph_hashes - - start_ts = _safe_float(item.time_meta.get("event_time_start"), 0.0) - end_ts = _safe_float(item.time_meta.get("event_time_end"), start_ts) - confidence = _safe_float(item.time_meta.get("time_confidence"), 0.95) - granularity = _normalize_name(item.time_meta.get("time_granularity")) or "minute" - - if is_new_paragraph: - paragraph_records.append( - ( - item.paragraph_hash, - item.content, - None, - now_ts, - now_ts, - empty_meta_blob, - item.source, - len(normalize_text(item.content).split()), - None, - start_ts, - end_ts, - granularity, - confidence, - KnowledgeType.NARRATIVE.value, - ) - ) - - if item.paragraph_hash not in self.vector_store: - paragraph_embed_map[item.paragraph_hash] = item.content - - for entity in item.entities: - name = _normalize_name(entity) - if not name: - continue - canon = _canonical_name(name) - if not canon: - continue - entity_hash = compute_hash(canon) - entity_display.setdefault(entity_hash, name) - if is_new_paragraph: - entity_counts[entity_hash] += 1 - paragraph_entity_mentions[(item.paragraph_hash, entity_hash)] += 1 - if entity_hash not in self.vector_store: - entity_embed_map.setdefault(entity_hash, name) - - for subject, predicate, obj in item.relations: - s = _normalize_name(subject) - p = _normalize_name(predicate) - o = _normalize_name(obj) - if not (s and p and o): - continue - - s_canon = _canonical_name(s) - p_canon = _canonical_name(p) - o_canon = _canonical_name(o) - relation_hash = compute_hash(f"{s_canon}|{p_canon}|{o_canon}") - - if is_new_paragraph: - relation_records.setdefault( - relation_hash, - (s, p, o, 1.0, item.paragraph_hash, empty_meta_blob), - ) - paragraph_relation_links.add((item.paragraph_hash, relation_hash)) - - for relation_entity in (s, o): - e_canon = _canonical_name(relation_entity) - if not e_canon: - continue - e_hash = compute_hash(e_canon) - entity_display.setdefault(e_hash, relation_entity) - if is_new_paragraph: - entity_counts[e_hash] += 1 - paragraph_entity_mentions[(item.paragraph_hash, e_hash)] += 1 - if e_hash not in self.vector_store: - entity_embed_map.setdefault(e_hash, relation_entity) - - try: - cursor.execute("BEGIN") - - if paragraph_records: - cursor.executemany( - """ - INSERT OR IGNORE INTO paragraphs - ( - hash, content, vector_index, created_at, updated_at, metadata, source, word_count, - event_time, event_time_start, event_time_end, time_granularity, time_confidence, knowledge_type - ) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
- """, - paragraph_records, - ) - - if entity_counts: - entity_rows = [ - ( - entity_hash, - entity_display[entity_hash], - None, - int(count), - now_ts, - empty_meta_blob, - ) - for entity_hash, count in entity_counts.items() - ] - try: - cursor.executemany( - """ - INSERT INTO entities - (hash, name, vector_index, appearance_count, created_at, metadata) - VALUES (?, ?, ?, ?, ?, ?) - ON CONFLICT(hash) DO UPDATE SET - appearance_count = entities.appearance_count + excluded.appearance_count - """, - entity_rows, - ) - except sqlite3.OperationalError: - cursor.executemany( - """ - INSERT OR IGNORE INTO entities - (hash, name, vector_index, appearance_count, created_at, metadata) - VALUES (?, ?, ?, ?, ?, ?) - """, - entity_rows, - ) - cursor.executemany( - "UPDATE entities SET appearance_count = appearance_count + ? WHERE hash = ?", - [(int(count), entity_hash) for entity_hash, count in entity_counts.items()], - ) - - if paragraph_entity_mentions: - pe_rows = [ - (paragraph_hash, entity_hash, int(mentions)) - for (paragraph_hash, entity_hash), mentions in paragraph_entity_mentions.items() - ] - try: - cursor.executemany( - """ - INSERT INTO paragraph_entities - (paragraph_hash, entity_hash, mention_count) - VALUES (?, ?, ?) - ON CONFLICT(paragraph_hash, entity_hash) DO UPDATE SET - mention_count = paragraph_entities.mention_count + excluded.mention_count - """, - pe_rows, - ) - except sqlite3.OperationalError: - cursor.executemany( - """ - INSERT OR IGNORE INTO paragraph_entities - (paragraph_hash, entity_hash, mention_count) - VALUES (?, ?, ?) - """, - pe_rows, - ) - cursor.executemany( - """ - UPDATE paragraph_entities - SET mention_count = mention_count + ? - WHERE paragraph_hash = ? AND entity_hash = ? - """, - [(m, p, e) for (p, e, m) in pe_rows], - ) - - if relation_records: - relation_rows = [ - ( - relation_hash, - rel[0], - rel[1], - rel[2], - None, - rel[3], - now_ts, - rel[4], - rel[5], - ) - for relation_hash, rel in relation_records.items() - ] - cursor.executemany( - """ - INSERT OR IGNORE INTO relations - (hash, subject, predicate, object, vector_index, confidence, created_at, source_paragraph, metadata) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) - """, - relation_rows, - ) - - if paragraph_relation_links: - pr_rows = [(p_hash, r_hash) for p_hash, r_hash in paragraph_relation_links] - cursor.executemany( - """ - INSERT OR IGNORE INTO paragraph_relations - (paragraph_hash, relation_hash) - VALUES (?, ?) 
- """, - pr_rows, - ) - - conn.commit() - except Exception: - conn.rollback() - raise - - self.stats["relations_written"] += len(relation_records) - - if relation_records: - edge_pairs = [] - relation_hashes = [] - for relation_hash, rel in relation_records.items(): - edge_pairs.append((rel[0], rel[2])) - relation_hashes.append(relation_hash) - - with self.graph_store.batch_update(): - self.graph_store.add_edges(edge_pairs, relation_hashes=relation_hashes) - self.stats["graph_edges_written"] += len(edge_pairs) - - if self._should_write_relation_vectors(): - await self._ensure_relation_vectors_for_records(relation_records) - - para_added = await self._embed_and_add_vectors( - id_to_text=paragraph_embed_map, - batch_size=max(1, int(self.args.embed_batch_size)), - workers=self.embed_workers, - ) - ent_added = await self._embed_and_add_vectors( - id_to_text=entity_embed_map, - batch_size=max(1, int(self.args.entity_embed_batch_size)), - workers=self.embed_workers, - ) - self.stats["paragraph_vectors_added"] += para_added - self.stats["entity_vectors_added"] += ent_added - - self.vector_store.save() - self.graph_store.save() - - self.stats["windows_committed"] += 1 - self._flush_state(last_seen_id) - - async def _embed_and_add_vectors( - self, - id_to_text: Dict[str, str], - batch_size: int, - workers: int, - ) -> int: - if not id_to_text: - return 0 - if self.embedding_manager is None: - raise MigrationError("embedding_manager 未初始化,无法写入向量") - - ids = [] - texts = [] - for hash_id, text in id_to_text.items(): - if hash_id in self.vector_store: - continue - ids.append(hash_id) - texts.append(text) - - if not ids: - return 0 - - total_added = 0 - chunk_size = max(1, int(batch_size)) - for i in range(0, len(ids), chunk_size): - chunk_ids = ids[i : i + chunk_size] - chunk_texts = texts[i : i + chunk_size] - - embeddings = await self.embedding_manager.encode_batch( - chunk_texts, - batch_size=chunk_size, - num_workers=max(1, int(workers)), - ) - - emb_arr = np.asarray(embeddings, dtype=np.float32) - if emb_arr.ndim == 1: - emb_arr = emb_arr.reshape(1, -1) - if emb_arr.shape[0] != len(chunk_ids): - logger.warning( - f"embedding 返回数量异常: expected={len(chunk_ids)}, got={emb_arr.shape[0]},跳过该批次" - ) - continue - - valid_vectors = [] - valid_ids = [] - for idx, vec in enumerate(emb_arr): - if vec.ndim != 1: - continue - if vec.shape[0] != self.vector_store.dimension: - logger.warning( - f"向量维度不匹配,跳过: id={chunk_ids[idx]}, got={vec.shape[0]}, expected={self.vector_store.dimension}" - ) - continue - if not np.all(np.isfinite(vec)): - logger.warning(f"向量含 NaN/Inf,跳过: id={chunk_ids[idx]}") - continue - if chunk_ids[idx] in self.vector_store: - continue - valid_vectors.append(vec) - valid_ids.append(chunk_ids[idx]) - - if valid_vectors: - batch_vectors = np.stack(valid_vectors).astype(np.float32, copy=False) - added = self.vector_store.add(batch_vectors, valid_ids) - total_added += int(added) - - return total_added - - async def _verify(self, strict: bool) -> None: - if self.selection is None: - raise MigrationError("selection 未初始化") - - sample_size = min(2000, max(0, int(self.stats.get("source_matched_total", 0)))) - self.stats["verify_sample_size"] = sample_size - - if sample_size <= 0: - self.stats["verify_passed"] = True - return - - sample_rows = self.source_db.sample_rows_for_verify(self.selection, sample_size) - para_missing = 0 - vec_missing = 0 - rel_missing = 0 - edge_missing = 0 - - for row in sample_rows: - try: - mapped = self._map_row(row) - except Exception: - continue - - paragraph = 
self.metadata_store.get_paragraph(mapped.paragraph_hash) - if paragraph is None: - para_missing += 1 - if mapped.paragraph_hash not in self.vector_store: - vec_missing += 1 - - for s, p, o in mapped.relations: - relation_hash = compute_hash(f"{_canonical_name(s)}|{_canonical_name(p)}|{_canonical_name(o)}") - relation = self.metadata_store.get_relation(relation_hash) - if relation is None: - rel_missing += 1 - if self.graph_store.get_edge_weight(s, o) <= 0.0: - edge_missing += 1 - - self.stats["verify_paragraph_missing"] = para_missing - self.stats["verify_vector_missing"] = vec_missing - self.stats["verify_relation_missing"] = rel_missing - self.stats["verify_edge_missing"] = edge_missing - - verify_passed = all(x == 0 for x in [para_missing, vec_missing, rel_missing, edge_missing]) - if strict and not verify_passed: - self.failed = True - self.fail_reason = ( - "严格校验失败: " - f"paragraph_missing={para_missing}, vector_missing={vec_missing}, " - f"relation_missing={rel_missing}, edge_missing={edge_missing}" - ) - - self.stats["verify_passed"] = verify_passed - - def _finalize(self) -> int: - elapsed = time.time() - self.started_at - self.stats["elapsed_seconds"] = elapsed - - report = { - "success": not self.failed, - "fail_reason": self.fail_reason, - "args": vars(self.args), - "source_db": str(self.source_db_path), - "target_data_dir": str(self.target_data_dir), - "selection": self.selection.fingerprint_payload() if self.selection else {}, - "filter_fingerprint": self.filter_fingerprint, - "source_db_fingerprint": self.source_db_fingerprint, - "state_file": str(self.state_file), - "bad_rows_file": str(self.bad_rows_file), - "stats": dict(self.stats), - "timestamp": time.time(), - } - - _dump_json_atomic(self.report_file, report) - - if self.failed: - self.exit_code = 1 - elif self.stats.get("bad_rows", 0) > 0: - self.exit_code = 2 - else: - self.exit_code = 0 - - print("\n=== Migration Report ===") - print(f"success: {not self.failed}") - if self.fail_reason: - print(f"fail_reason: {self.fail_reason}") - print(f"elapsed: {elapsed:.2f}s") - print(f"source_matched_total: {self.stats['source_matched_total']}") - print(f"scanned_rows: {self.stats['scanned_rows']}") - print(f"valid_rows: {self.stats['valid_rows']}") - print(f"migrated_rows: {self.stats['migrated_rows']}") - print(f"skipped_existing_rows: {self.stats['skipped_existing_rows']}") - print(f"bad_rows: {self.stats['bad_rows']}") - print(f"paragraph_vectors_added: {self.stats['paragraph_vectors_added']}") - print(f"entity_vectors_added: {self.stats['entity_vectors_added']}") - print(f"relations_written: {self.stats['relations_written']}") - print( - "relation_vectors: " - f"written={self.stats['relation_vectors_written']}, " - f"failed={self.stats['relation_vectors_failed']}, " - f"skipped={self.stats['relation_vectors_skipped']}" - ) - print(f"graph_edges_written: {self.stats['graph_edges_written']}") - print(f"windows_committed: {self.stats['windows_committed']}") - print(f"last_committed_id: {self.stats['last_committed_id']}") - print( - "verify: " - f"sample={self.stats['verify_sample_size']}, " - f"paragraph_missing={self.stats['verify_paragraph_missing']}, " - f"vector_missing={self.stats['verify_vector_missing']}, " - f"relation_missing={self.stats['verify_relation_missing']}, " - f"edge_missing={self.stats['verify_edge_missing']}, " - f"passed={self.stats['verify_passed']}" - ) - print(f"report_file: {self.report_file}") - print("========================\n") - - return self.exit_code - - def _close(self) -> None: - try: - if 
self.metadata_store is not None: - self.metadata_store.close() - except Exception: - pass - self.source_db.close() - - -def build_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser( - description="迁移 MaiBot chat_history 到 A_memorix(高性能 + 可断点续传 + 可确认筛选)" - ) - - parser.add_argument("--source-db", default=str(DEFAULT_SOURCE_DB), help="源数据库路径(默认 data/MaiBot.db)") - parser.add_argument( - "--target-data-dir", - default=str(DEFAULT_TARGET_DATA_DIR), - help="A_memorix 数据目录(默认 plugins/A_memorix/data)", - ) - - resume_group = parser.add_mutually_exclusive_group() - resume_group.add_argument("--resume", dest="no_resume", action="store_false", help="启用断点续传(默认)") - resume_group.add_argument("--no-resume", dest="no_resume", action="store_true", help="禁用断点续传") - parser.set_defaults(no_resume=False) - - parser.add_argument("--reset-state", action="store_true", help="清空迁移状态文件后执行") - parser.add_argument("--start-id", type=int, default=None, help="从指定 chat_history.id 开始迁移(覆盖断点)") - parser.add_argument("--end-id", type=int, default=None, help="迁移到指定 chat_history.id") - - parser.add_argument("--read-batch-size", type=int, default=2000, help="源库分页读取大小(默认 2000)") - parser.add_argument("--commit-window-rows", type=int, default=20000, help="每窗口提交行数(默认 20000)") - parser.add_argument("--embed-batch-size", type=int, default=256, help="段落 embedding 批次大小(默认 256)") - parser.add_argument( - "--entity-embed-batch-size", - type=int, - default=512, - help="实体 embedding 批次大小(默认 512)", - ) - parser.add_argument("--embed-workers", type=int, default=None, help="embedding 并发数(默认读取配置)") - parser.add_argument("--max-errors", type=int, default=500, help="坏行上限(默认 500)") - parser.add_argument("--log-every", type=int, default=5000, help="日志输出步长(默认 5000)") - - parser.add_argument("--dry-run", action="store_true", help="仅预览不写入") - parser.add_argument("--verify-only", action="store_true", help="仅执行严格校验") - - parser.add_argument("--time-from", default=None, help="开始时间:YYYY-MM-DD / YYYY/MM/DD / YYYY-MM-DD HH:mm[:ss]") - parser.add_argument("--time-to", default=None, help="结束时间:YYYY-MM-DD / YYYY/MM/DD / YYYY-MM-DD HH:mm[:ss]") - parser.add_argument("--stream-id", action="append", default=[], help="聊天流 stream_id(可重复)") - parser.add_argument("--group-id", action="append", default=[], help="群号(可重复,自动映射 stream_id)") - parser.add_argument("--user-id", action="append", default=[], help="用户号(可重复,自动映射 stream_id)") - parser.add_argument("--yes", action="store_true", help="跳过交互确认") - parser.add_argument("--preview-limit", type=int, default=20, help="预览样本条数(默认 20)") - - return parser - - -async def async_main() -> int: - parser = build_parser() - args = parser.parse_args() - - runner = MigrationRunner(args) - return await runner.run() - - -def main() -> int: - if sys.platform == "win32": - asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) - return asyncio.run(async_main()) - - -if __name__ == "__main__": - raise SystemExit(main()) diff --git a/plugins/A_memorix/scripts/migrate_person_memory_points.py b/plugins/A_memorix/scripts/migrate_person_memory_points.py deleted file mode 100644 index a03a8914..00000000 --- a/plugins/A_memorix/scripts/migrate_person_memory_points.py +++ /dev/null @@ -1,120 +0,0 @@ -#!/usr/bin/env python3 -from __future__ import annotations - -import argparse -import asyncio -import json -import sqlite3 -import sys -from pathlib import Path -from typing import Any, Dict, List - - -CURRENT_DIR = Path(__file__).resolve().parent -PLUGIN_ROOT = CURRENT_DIR.parent -WORKSPACE_ROOT = 
PLUGIN_ROOT.parent -MAIBOT_ROOT = WORKSPACE_ROOT / "MaiBot" -DEFAULT_DB_PATH = MAIBOT_ROOT / "data" / "MaiBot.db" - -if str(WORKSPACE_ROOT) not in sys.path: - sys.path.insert(0, str(WORKSPACE_ROOT)) -if str(MAIBOT_ROOT) not in sys.path: - sys.path.insert(0, str(MAIBOT_ROOT)) - -from A_memorix.core.runtime.sdk_memory_kernel import SDKMemoryKernel # noqa: E402 - - -def _parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser(description="迁移 MaiBot person_info.memory_points 到 A_Memorix") - parser.add_argument("--db-path", default=str(DEFAULT_DB_PATH), help="MaiBot SQLite 路径") - parser.add_argument("--data-dir", default="./data", help="A_Memorix 数据目录") - parser.add_argument("--limit", type=int, default=0, help="限制迁移人数,0 表示全部") - parser.add_argument("--dry-run", action="store_true", help="仅预览,不写入") - return parser.parse_args() - - -def _parse_memory_points(raw_value: Any) -> List[Dict[str, Any]]: - try: - values = json.loads(raw_value) if raw_value else [] - except Exception: - values = [] - items: List[Dict[str, Any]] = [] - for index, item in enumerate(values): - text = str(item or "").strip() - if not text: - continue - parts = text.split(":") - if len(parts) >= 3: - category = parts[0].strip() - content = ":".join(parts[1:-1]).strip() - weight = parts[-1].strip() - else: - category = "其他" - content = text - weight = "1.0" - if content: - items.append({"index": index, "category": category or "其他", "content": content, "weight": weight or "1.0"}) - return items - - -async def _main() -> int: - args = _parse_args() - db_path = Path(args.db_path).resolve() - if not db_path.exists(): - print(f"数据库不存在: {db_path}") - return 1 - - conn = sqlite3.connect(str(db_path)) - conn.row_factory = sqlite3.Row - sql = """ - SELECT person_id, person_name, user_nickname, memory_points - FROM person_info - WHERE memory_points IS NOT NULL AND memory_points != '' - ORDER BY id ASC - """ - if int(args.limit or 0) > 0: - sql += " LIMIT ?" 
- rows = conn.execute(sql, (int(args.limit),)).fetchall() - else: - rows = conn.execute(sql).fetchall() - conn.close() - - preview_total = sum(len(_parse_memory_points(row["memory_points"])) for row in rows) - print(f"person_info 待迁移人物: {len(rows)} 记忆点: {preview_total}") - if args.dry_run: - for row in rows[:5]: - print(f"- person_id={row['person_id']} person_name={row['person_name'] or row['user_nickname']}") - return 0 - - kernel = SDKMemoryKernel(plugin_root=PLUGIN_ROOT, config={"storage": {"data_dir": args.data_dir}}) - await kernel.initialize() - migrated = 0 - skipped = 0 - for row in rows: - person_id = str(row["person_id"] or "").strip() - if not person_id: - continue - display_name = str(row["person_name"] or row["user_nickname"] or "").strip() - for item in _parse_memory_points(row["memory_points"]): - result: Dict[str, Any] = await kernel.ingest_text( - external_id=f"person_memory:{person_id}:{item['index']}", - source_type="person_fact", - text=f"[{item['category']}] {item['content']}", - person_ids=[person_id], - tags=[item["category"]], - entities=[person_id, display_name] if display_name else [person_id], - metadata={"category": item["category"], "weight": item["weight"], "display_name": display_name}, - ) - if result.get("stored_ids"): - migrated += 1 - else: - skipped += 1 - - print(f"迁移完成: migrated={migrated} skipped={skipped}") - print(json.dumps(kernel.memory_stats(), ensure_ascii=False)) - kernel.close() - return 0 - - -if __name__ == "__main__": - raise SystemExit(asyncio.run(_main())) diff --git a/plugins/A_memorix/scripts/process_knowledge.py b/plugins/A_memorix/scripts/process_knowledge.py deleted file mode 100644 index d9e6fe32..00000000 --- a/plugins/A_memorix/scripts/process_knowledge.py +++ /dev/null @@ -1,728 +0,0 @@ -#!/usr/bin/env python3 -""" -知识库自动导入脚本 (Strategy-Aware Version) - -功能: -1. 扫描 plugins/A_memorix/data/raw 下的 .txt 文件 -2. 检查 data/import_manifest.json 确认是否已导入 -3. 使用 Strategy 模式处理文件 (Narrative/Factual/Quote) -4. 将生成的数据直接存入 VectorStore/GraphStore/MetadataStore -5. 
更新 manifest -""" - -import sys -import os -import json -import asyncio -import time -import random -import hashlib -import tomlkit -import argparse -from pathlib import Path -from datetime import datetime -from typing import List, Dict, Any, Optional -from rich.progress import Progress, SpinnerColumn, BarColumn, TextColumn, TimeElapsedColumn -from rich.console import Console -from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type - -console = Console() - -class LLMGenerationError(Exception): - pass - -# 路径设置 -current_dir = Path(__file__).resolve().parent -plugin_root = current_dir.parent -workspace_root = plugin_root.parent -maibot_root = workspace_root / "MaiBot" -for path in (workspace_root, maibot_root, plugin_root): - path_str = str(path) - if path_str not in sys.path: - sys.path.insert(0, path_str) - -# 数据目录 -DATA_DIR = plugin_root / "data" -RAW_DIR = DATA_DIR / "raw" -PROCESSED_DIR = DATA_DIR / "processed" -MANIFEST_PATH = DATA_DIR / "import_manifest.json" - - -def _build_arg_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser(description="A_Memorix Knowledge Importer (Strategy-Aware)") - parser.add_argument("--force", action="store_true", help="Force re-import") - parser.add_argument("--clear-manifest", action="store_true", help="Clear manifest") - parser.add_argument( - "--type", - "-t", - default="auto", - help="Target import strategy override (auto/narrative/factual/quote)", - ) - parser.add_argument("--concurrency", "-c", type=int, default=5) - parser.add_argument( - "--chat-log", - action="store_true", - help="聊天记录导入模式:强制 narrative 策略,并使用 LLM 语义抽取 event_time/event_time_range", - ) - parser.add_argument( - "--chat-reference-time", - default=None, - help="chat_log 模式的相对时间参考点(如 2026/02/12 10:30);不传则使用当前本地时间", - ) - return parser - - -# --help/-h fast path: avoid heavy host/plugin bootstrap -if any(arg in {"-h", "--help"} for arg in sys.argv[1:]): - _build_arg_parser().print_help() - sys.exit(0) - - -try: - import A_memorix.core as core_module - import A_memorix.core.storage as storage_module - from src.common.logger import get_logger - from src.services import llm_service as llm_api - from src.config.config import global_config, model_config - - VectorStore = core_module.VectorStore - GraphStore = core_module.GraphStore - MetadataStore = core_module.MetadataStore - ImportStrategy = core_module.ImportStrategy - create_embedding_api_adapter = core_module.create_embedding_api_adapter - RelationWriteService = getattr(core_module, "RelationWriteService", None) - - looks_like_quote_text = storage_module.looks_like_quote_text - parse_import_strategy = storage_module.parse_import_strategy - resolve_stored_knowledge_type = storage_module.resolve_stored_knowledge_type - select_import_strategy = storage_module.select_import_strategy - - from A_memorix.core.utils.time_parser import normalize_time_meta - from A_memorix.core.utils.import_payloads import normalize_paragraph_import_item - from A_memorix.core.strategies.base import BaseStrategy, ProcessedChunk, KnowledgeType as StratKnowledgeType - from A_memorix.core.strategies.narrative import NarrativeStrategy - from A_memorix.core.strategies.factual import FactualStrategy - from A_memorix.core.strategies.quote import QuoteStrategy - -except ImportError as e: - print(f"❌ 无法导入模块: {e}") - import traceback - traceback.print_exc() - sys.exit(1) - -logger = get_logger("A_Memorix.AutoImport") - - -def _log_before_retry(retry_state) -> None: - """使用项目统一日志风格记录重试信息。""" - exc = None - if 
getattr(retry_state, "outcome", None) is not None and retry_state.outcome.failed: - exc = retry_state.outcome.exception() - next_sleep = getattr(getattr(retry_state, "next_action", None), "sleep", None) - logger.warning( - "LLM 调用即将重试: " - f"attempt={getattr(retry_state, 'attempt_number', '?')} " - f"next_sleep={next_sleep} " - f"error={exc}" - ) - -class AutoImporter: - def __init__( - self, - force: bool = False, - clear_manifest: bool = False, - target_type: str = "auto", - concurrency: int = 5, - chat_log: bool = False, - chat_reference_time: Optional[str] = None, - ): - self.vector_store: Optional[VectorStore] = None - self.graph_store: Optional[GraphStore] = None - self.metadata_store: Optional[MetadataStore] = None - self.embedding_manager = None - self.relation_write_service = None - self.plugin_config = {} - self.manifest = {} - self.force = force - self.clear_manifest = clear_manifest - self.chat_log = chat_log - parsed_target_type = parse_import_strategy(target_type, default=ImportStrategy.AUTO) - self.target_type = ImportStrategy.NARRATIVE.value if chat_log else parsed_target_type.value - self.chat_reference_dt = self._parse_reference_time(chat_reference_time) - if self.chat_log and parsed_target_type not in {ImportStrategy.AUTO, ImportStrategy.NARRATIVE}: - logger.warning( - f"chat_log 模式已启用,target_type={target_type} 将被覆盖为 narrative" - ) - self.concurrency_limit = concurrency - self.semaphore = None - self.storage_lock = None - - async def initialize(self): - logger.info(f"正在初始化... (并发数: {self.concurrency_limit})") - self.semaphore = asyncio.Semaphore(self.concurrency_limit) - self.storage_lock = asyncio.Lock() - - RAW_DIR.mkdir(parents=True, exist_ok=True) - PROCESSED_DIR.mkdir(parents=True, exist_ok=True) - - if self.clear_manifest: - logger.info("🧹 清理 Mainfest") - self.manifest = {} - self._save_manifest() - elif MANIFEST_PATH.exists(): - try: - with open(MANIFEST_PATH, "r", encoding="utf-8") as f: - self.manifest = json.load(f) - except Exception: - self.manifest = {} - - config_path = plugin_root / "config.toml" - try: - with open(config_path, "r", encoding="utf-8") as f: - self.plugin_config = tomlkit.load(f) - except Exception as e: - logger.error(f"加载插件配置失败: {e}") - return False - - try: - await self._init_stores() - except Exception as e: - logger.error(f"初始化存储失败: {e}") - return False - - return True - - async def _init_stores(self): - # ... 
-        # (Same as the original standalone importer.)
-        self.embedding_manager = create_embedding_api_adapter(
-            batch_size=self.plugin_config.get("embedding", {}).get("batch_size", 32),
-            default_dimension=self.plugin_config.get("embedding", {}).get("dimension", 384),
-            model_name=self.plugin_config.get("embedding", {}).get("model_name", "auto"),
-            retry_config=self.plugin_config.get("embedding", {}).get("retry", {}),
-        )
-        try:
-            dim = await self.embedding_manager._detect_dimension()
-        except Exception:
-            dim = self.embedding_manager.default_dimension
-
-        q_type_str = str(self.plugin_config.get("embedding", {}).get("quantization_type", "int8") or "int8").lower()
-        # Need to access QuantizationType from storage_module if not imported globally
-        QuantizationType = storage_module.QuantizationType
-        if q_type_str != "int8":
-            raise ValueError(
-                "embedding.quantization_type 在 vNext 仅允许 int8(SQ8)。"
-                " 请先执行 scripts/release_vnext_migrate.py migrate。"
-            )
-
-        self.vector_store = VectorStore(
-            dimension=dim,
-            quantization_type=QuantizationType.INT8,
-            data_dir=DATA_DIR / "vectors"
-        )
-
-        SparseMatrixFormat = storage_module.SparseMatrixFormat
-        m_fmt_str = self.plugin_config.get("graph", {}).get("sparse_matrix_format", "csr")
-        m_map = {"csr": SparseMatrixFormat.CSR, "csc": SparseMatrixFormat.CSC}
-
-        self.graph_store = GraphStore(
-            matrix_format=m_map.get(m_fmt_str, SparseMatrixFormat.CSR),
-            data_dir=DATA_DIR / "graph"
-        )
-
-        self.metadata_store = MetadataStore(data_dir=DATA_DIR / "metadata")
-        self.metadata_store.connect()
-
-        if RelationWriteService is not None:
-            self.relation_write_service = RelationWriteService(
-                metadata_store=self.metadata_store,
-                graph_store=self.graph_store,
-                vector_store=self.vector_store,
-                embedding_manager=self.embedding_manager,
-            )
-
-        if self.vector_store.has_data(): self.vector_store.load()
-        if self.graph_store.has_data(): self.graph_store.load()
-
-    def _should_write_relation_vectors(self) -> bool:
-        retrieval_cfg = self.plugin_config.get("retrieval", {})
-        if not isinstance(retrieval_cfg, dict):
-            return False
-        rv_cfg = retrieval_cfg.get("relation_vectorization", {})
-        if not isinstance(rv_cfg, dict):
-            return False
-        return bool(rv_cfg.get("enabled", False)) and bool(rv_cfg.get("write_on_import", True))
-
-    def load_file(self, file_path: Path) -> str:
-        with open(file_path, "r", encoding="utf-8") as f:
-            return f.read()
-
-    def get_file_hash(self, content: str) -> str:
-        return hashlib.md5(content.encode("utf-8")).hexdigest()
-
-    def _parse_reference_time(self, value: Optional[str]) -> datetime:
-        """解析 chat_log 模式的参考时间(用于相对时间语义解析)。"""
-        if not value:
-            return datetime.now()
-        formats = [
-            "%Y/%m/%d %H:%M:%S",
-            "%Y/%m/%d %H:%M",
-            "%Y-%m-%d %H:%M:%S",
-            "%Y-%m-%d %H:%M",
-            "%Y/%m/%d",
-            "%Y-%m-%d",
-        ]
-        text = str(value).strip()
-        for fmt in formats:
-            try:
-                return datetime.strptime(text, fmt)
-            except ValueError:
-                continue
-        logger.warning(
-            f"无法解析 chat_reference_time={value},将回退为当前本地时间"
-        )
-        return datetime.now()
-
-    async def _extract_chat_time_meta_with_llm(
-        self,
-        text: str,
-        model_config: Any,
-    ) -> Optional[Dict[str, Any]]:
-        """
-        使用 LLM 从聊天文本语义中抽取时间信息。
-        支持将相对时间表达转换为绝对时间。
-        """
-        if not text.strip():
-            return None
-
-        reference_now = self.chat_reference_dt.strftime("%Y/%m/%d %H:%M")
-        prompt = f"""You are a time extraction engine for chat logs.
-Extract temporal information from the following chat paragraph.
-
-Rules:
-1. Use semantic understanding, not regex matching.
-2.
Convert relative expressions (e.g., yesterday evening, last Friday morning) to absolute local datetime using reference_now. -3. If a time span exists, return event_time_start/event_time_end. -4. If only one point in time exists, return event_time. -5. If no reliable time can be inferred, return all time fields as null. -6. Output ONLY valid JSON. No markdown, no explanation. - -reference_now: {reference_now} -timezone: local system timezone - -Allowed output formats for time values: -- "YYYY/MM/DD" -- "YYYY/MM/DD HH:mm" - -JSON schema: -{{ - "event_time": null, - "event_time_start": null, - "event_time_end": null, - "time_range": null, - "time_granularity": "day", - "time_confidence": 0.0 -}} - -Chat paragraph: -\"\"\"{text}\"\"\" -""" - try: - result = await self._llm_call(prompt, model_config) - except Exception as e: - logger.warning(f"chat_log 时间语义抽取失败: {e}") - return None - - if not isinstance(result, dict): - return None - - raw_time_meta = { - "event_time": result.get("event_time"), - "event_time_start": result.get("event_time_start"), - "event_time_end": result.get("event_time_end"), - "time_range": result.get("time_range"), - "time_granularity": result.get("time_granularity"), - "time_confidence": result.get("time_confidence"), - } - try: - normalized = normalize_time_meta(raw_time_meta) - except Exception as e: - logger.warning(f"chat_log 时间语义抽取结果不可用,已忽略: {e}") - return None - - has_effective_time = any( - key in normalized - for key in ("event_time", "event_time_start", "event_time_end") - ) - if not has_effective_time: - return None - - return normalized - - def _determine_strategy(self, filename: str, content: str) -> BaseStrategy: - """Layer 1: Global Strategy Routing""" - strategy = select_import_strategy( - content, - override=self.target_type, - chat_log=self.chat_log, - ) - if self.chat_log: - logger.info(f"chat_log 模式: {filename} 强制使用 NarrativeStrategy") - elif strategy == ImportStrategy.QUOTE: - logger.info(f"Auto-detected Quote/Lyric type for {filename}") - - if strategy == ImportStrategy.FACTUAL: - return FactualStrategy(filename) - if strategy == ImportStrategy.QUOTE: - return QuoteStrategy(filename) - return NarrativeStrategy(filename) - - def _chunk_rescue(self, chunk: ProcessedChunk, filename: str) -> Optional[BaseStrategy]: - """Layer 2: Chunk-level rescue strategies""" - # If we are already in Quote strategy, no need to rescue - if chunk.type == StratKnowledgeType.QUOTE: - return None - - if looks_like_quote_text(chunk.chunk.text): - logger.info(f" > Rescuing chunk {chunk.chunk.index} as Quote") - return QuoteStrategy(filename) - - return None - - async def process_and_import(self): - if not await self.initialize(): return - - files = list(RAW_DIR.glob("*.txt")) - logger.info(f"扫描到 {len(files)} 个文件 in {RAW_DIR}") - - if not files: return - - tasks = [] - for file_path in files: - tasks.append(asyncio.create_task(self._process_single_file(file_path))) - - results = await asyncio.gather(*tasks, return_exceptions=True) - - success_count = sum(1 for r in results if r is True) - logger.info(f"本次主处理完成,共成功处理 {success_count}/{len(files)} 个文件") - - if self.vector_store: self.vector_store.save() - if self.graph_store: self.graph_store.save() - - async def _process_single_file(self, file_path: Path) -> bool: - filename = file_path.name - async with self.semaphore: - try: - content = self.load_file(file_path) - file_hash = self.get_file_hash(content) - - if not self.force and filename in self.manifest: - record = self.manifest[filename] - if record.get("hash") == file_hash 
and record.get("imported"): - logger.info(f"跳过已导入文件: {filename}") - return False - - logger.info(f">>> 开始处理: {filename}") - - # 1. Strategy Selection - strategy = self._determine_strategy(filename, content) - logger.info(f" 策略: {strategy.__class__.__name__}") - - # 2. Split (Strategy-Aware) - initial_chunks = strategy.split(content) - logger.info(f" 初步分块: {len(initial_chunks)}") - - processed_data = {"paragraphs": [], "entities": [], "relations": []} - - # 3. Extract Loop - model_config = await self._select_model() - - for i, chunk in enumerate(initial_chunks): - current_strategy = strategy - # Layer 2: Chunk Rescue - rescue_strategy = self._chunk_rescue(chunk, filename) - if rescue_strategy: - # Re-split? No, just re-process this text as a single chunk using the rescue strategy - # But rescue strategy might want to split it further? - # Simplification: Treat the whole chunk text as one block for the rescue strategy - # OR create a single chunk object for it. - # Creating a new chunk using rescue strategy logic might be complex if split behavior differs. - # Let's just instantiate a chunk of the new type manually - chunk.type = StratKnowledgeType.QUOTE - chunk.flags.verbatim = True - chunk.flags.requires_llm = False # Quotes don't usually need LLM - current_strategy = rescue_strategy - - # Extraction - if chunk.flags.requires_llm: - result_chunk = await current_strategy.extract(chunk, lambda p: self._llm_call(p, model_config)) - else: - # For quotes, extract might be just pass through or regex - result_chunk = await current_strategy.extract(chunk) - - time_meta = None - if self.chat_log: - time_meta = await self._extract_chat_time_meta_with_llm( - result_chunk.chunk.text, - model_config, - ) - - # Normalize Data - self._normalize_and_aggregate( - result_chunk, - processed_data, - time_meta=time_meta, - ) - - logger.info(f" 已处理块 {i+1}/{len(initial_chunks)}") - - # 4. Save Json - json_path = PROCESSED_DIR / f"{file_path.stem}.json" - with open(json_path, "w", encoding="utf-8") as f: - json.dump(processed_data, f, ensure_ascii=False, indent=2) - - # 5. Import to DB - async with self.storage_lock: - await self._import_to_db(processed_data) - - self.manifest[filename] = { - "hash": file_hash, - "timestamp": time.time(), - "imported": True - } - self._save_manifest() - self.vector_store.save() - self.graph_store.save() - logger.info(f"✅ 文件 {filename} 处理并导入完成") - return True - - except Exception as e: - logger.error(f"❌ 处理失败 {filename}: {e}") - import traceback - traceback.print_exc() - return False - - def _normalize_and_aggregate( - self, - chunk: ProcessedChunk, - all_data: Dict, - time_meta: Optional[Dict[str, Any]] = None, - ): - """Convert strategy-specific data to unified generic format for storage.""" - # Generic fields - para_item = { - "content": chunk.chunk.text, - "source": chunk.source.file, - "knowledge_type": resolve_stored_knowledge_type( - chunk.type.value, - content=chunk.chunk.text, - ).value, - "entities": [], - "relations": [] - } - - data = chunk.data - - # 1. Triples (Factual) - if "triples" in data: - for t in data["triples"]: - para_item["relations"].append({ - "subject": t.get("subject"), - "predicate": t.get("predicate"), - "object": t.get("object") - }) - # Auto-add entities from triples - para_item["entities"].extend([t.get("subject"), t.get("object")]) - - # 2. Events & Relations (Narrative) - if "events" in data: - # Store events as content/metadata? Or entities? - # For now maybe just keep them in logic, or add as 'Event' entities? 
-    def _normalize_and_aggregate(
-        self,
-        chunk: ProcessedChunk,
-        all_data: Dict,
-        time_meta: Optional[Dict[str, Any]] = None,
-    ):
-        """Convert strategy-specific data to unified generic format for storage."""
-        # Generic fields
-        para_item = {
-            "content": chunk.chunk.text,
-            "source": chunk.source.file,
-            "knowledge_type": resolve_stored_knowledge_type(
-                chunk.type.value,
-                content=chunk.chunk.text,
-            ).value,
-            "entities": [],
-            "relations": []
-        }
-
-        data = chunk.data
-
-        # 1. Triples (Factual)
-        if "triples" in data:
-            for t in data["triples"]:
-                para_item["relations"].append({
-                    "subject": t.get("subject"),
-                    "predicate": t.get("predicate"),
-                    "object": t.get("object")
-                })
-                # Auto-add entities from triples
-                para_item["entities"].extend([t.get("subject"), t.get("object")])
-
-        # 2. Events & Relations (Narrative)
-        if "events" in data:
-            # Store narrative events as entities so they participate in graph retrieval.
-            para_item["entities"].extend(data["events"])
-
-        if "relations" in data:  # Narrative also outputs a relations list
-            para_item["relations"].extend(data["relations"])
-            for r in data["relations"]:
-                para_item["entities"].extend([r.get("subject"), r.get("object")])
-
-        # 3. Verbatim Entities (Quote)
-        if "verbatim_entities" in data:
-            para_item["entities"].extend(data["verbatim_entities"])
-
-        # Dedupe per paragraph
-        para_item["entities"] = list(set([e for e in para_item["entities"] if e]))
-
-        if time_meta:
-            para_item["time_meta"] = time_meta
-
-        all_data["paragraphs"].append(para_item)
-        all_data["entities"].extend(para_item["entities"])
-        if "relations" in para_item:
-            all_data["relations"].extend(para_item["relations"])
-
-    @retry(
-        retry=retry_if_exception_type((LLMGenerationError, json.JSONDecodeError)),
-        stop=stop_after_attempt(3),
-        wait=wait_exponential(multiplier=1, min=2, max=10),
-        before_sleep=_log_before_retry
-    )
-    async def _llm_call(self, prompt: str, model_config: Any) -> Dict:
-        """Generic LLM Caller"""
-        success, response, _, _ = await llm_api.generate_with_model(
-            prompt=prompt,
-            model_config=model_config,
-            request_type="Script.ProcessKnowledge"
-        )
-        if success:
-            txt = response.strip()
-            if "```" in txt:
-                txt = txt.split("```json")[-1].split("```")[0].strip()
-            try:
-                return json.loads(txt)
-            except json.JSONDecodeError:
-                # Fallback: try to find first { and last }
-                start = txt.find('{')
-                end = txt.rfind('}')
-                if start != -1 and end != -1:
-                    return json.loads(txt[start:end+1])
-                raise
-        else:
-            raise LLMGenerationError("LLM generation failed")
-
-    async def _select_model(self) -> Any:
-        models = llm_api.get_available_models()
-        if not models: raise ValueError("No LLM models available")
-
-        config_model = self.plugin_config.get("advanced", {}).get("extraction_model", "auto")
-        if config_model != "auto" and config_model in models:
-            return models[config_model]
-
-        for task_key in ["lpmm_entity_extract", "lpmm_rdf_build", "embedding"]:
-            if task_key in models: return models[task_key]
-
-        return models[list(models.keys())[0]]
-
-    # Storage helpers shared with the original importer
-    async def _add_entity_with_vector(self, name: str, source_paragraph: Optional[str] = None) -> str:
-        # Register the entity in metadata + graph, then best-effort add its embedding.
-        hash_value = self.metadata_store.add_entity(name, source_paragraph=source_paragraph)
-        self.graph_store.add_nodes([name])
-        try:
-            emb = await self.embedding_manager.encode(name)
-            try:
-                self.vector_store.add(emb.reshape(1, -1), [hash_value])
-            except ValueError: pass  # id already present in the vector store
-        except Exception: pass  # embedding failure is non-fatal for entity registration
-        return hash_value
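The fence-stripping and brace-slicing fallback in `_llm_call` is the load-bearing trick for tolerating chatty model output. Pulled out on its own it looks like this (a standalone sketch, not the method itself):

import json

def salvage_json(raw: str) -> dict:
    """Best-effort recovery of a JSON object from chatty LLM output."""
    txt = raw.strip()
    if "```" in txt:
        # keep only the content of a ```json ... ``` fence, if any
        txt = txt.split("```json")[-1].split("```")[0].strip()
    try:
        return json.loads(txt)
    except json.JSONDecodeError:
        # fall back to the outermost brace pair
        start, end = txt.find("{"), txt.rfind("}")
        if start != -1 and end > start:
            return json.loads(txt[start:end + 1])
        raise  # in _llm_call this re-raise triggers the tenacity @retry

print(salvage_json('Sure! ```json\n{"ok": true}\n``` Hope that helps.'))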
-    async def import_json_data(self, data: Dict, filename: str = "script_import", progress_callback=None):
-        """Public import entrypoint for pre-processed JSON payloads."""
-        if not self.storage_lock:
-            raise RuntimeError("Importer is not initialized. Call initialize() first.")
-
-        async with self.storage_lock:
-            await self._import_to_db(data, progress_callback=progress_callback)
-            self.manifest[filename] = {
-                "hash": self.get_file_hash(json.dumps(data, ensure_ascii=False, sort_keys=True)),
-                "timestamp": time.time(),
-                "imported": True,
-            }
-            self._save_manifest()
-            self.vector_store.save()
-            self.graph_store.save()
-
-    async def _import_to_db(self, data: Dict, progress_callback=None):
-        # Mirrors the file-based import path, hardened for arbitrary payloads.
-        with self.graph_store.batch_update():
-            for item in data.get("paragraphs", []):
-                paragraph = normalize_paragraph_import_item(
-                    item,
-                    default_source="script",
-                )
-                content = paragraph["content"]
-                source = paragraph["source"]
-                k_type_val = paragraph["knowledge_type"]
-
-                h_val = self.metadata_store.add_paragraph(
-                    content=content,
-                    source=source,
-                    knowledge_type=k_type_val,
-                    time_meta=paragraph["time_meta"],
-                )
-
-                if h_val not in self.vector_store:
-                    try:
-                        emb = await self.embedding_manager.encode(content)
-                        self.vector_store.add(emb.reshape(1, -1), [h_val])
-                    except Exception as e:
-                        logger.error(f"  Vector write failed: {e}")
-
-                para_entities = paragraph["entities"]
-                for entity in para_entities:
-                    if entity:
-                        await self._add_entity_with_vector(entity, source_paragraph=h_val)
-
-                para_relations = paragraph["relations"]
-                for rel in para_relations:
-                    s, p, o = rel.get("subject"), rel.get("predicate"), rel.get("object")
-                    if s and p and o:
-                        await self._add_entity_with_vector(s, source_paragraph=h_val)
-                        await self._add_entity_with_vector(o, source_paragraph=h_val)
-                        confidence = float(rel.get("confidence", 1.0) or 1.0)
-                        rel_meta = rel.get("metadata", {})
-                        write_vector = self._should_write_relation_vectors()
-                        if self.relation_write_service is not None:
-                            await self.relation_write_service.upsert_relation_with_vector(
-                                subject=s,
-                                predicate=p,
-                                obj=o,
-                                confidence=confidence,
-                                source_paragraph=h_val,
-                                metadata=rel_meta if isinstance(rel_meta, dict) else {},
-                                write_vector=write_vector,
-                            )
-                        else:
-                            rel_hash = self.metadata_store.add_relation(
-                                s,
-                                p,
-                                o,
-                                confidence=confidence,
-                                source_paragraph=h_val,
-                                metadata=rel_meta if isinstance(rel_meta, dict) else {},
-                            )
-                            self.graph_store.add_edges([(s, o)], relation_hashes=[rel_hash])
-                            try:
-                                self.metadata_store.set_relation_vector_state(rel_hash, "none")
-                            except Exception:
-                                pass
-
-                if progress_callback: progress_callback(1)
-
-    async def close(self):
-        if self.metadata_store: self.metadata_store.close()
-
-    def _save_manifest(self):
-        with open(MANIFEST_PATH, "w", encoding="utf-8") as f:
-            json.dump(self.manifest, f, ensure_ascii=False, indent=2)
-
-async def main():
-    parser = _build_arg_parser()
-    args = parser.parse_args()
-
-    if not global_config: return
-
-    importer = AutoImporter(
-        force=args.force,
-        clear_manifest=args.clear_manifest,
-        target_type=args.type,
-        concurrency=args.concurrency,
-        chat_log=args.chat_log,
-        chat_reference_time=args.chat_reference_time,
-    )
-    await importer.process_and_import()
-    await importer.close()
-
-if __name__ == "__main__":
-    if sys.platform == "win32":
-        asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
-    asyncio.run(main())
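For callers that already hold a processed payload, `import_json_data` above is the public entrypoint. A hedged usage sketch (argument values are illustrative, a configured MaiBot environment is assumed, and the payload follows the shape shown earlier):

import asyncio

async def import_payload(payload: dict) -> int:
    # Constructor arguments mirror the CLI flags wired up in main().
    importer = AutoImporter(
        force=False,
        clear_manifest=False,
        target_type=None,  # assumed default; the CLI passes args.type here
        concurrency=2,
        chat_log=False,
        chat_reference_time=None,
    )
    if not await importer.initialize():
        raise RuntimeError("importer failed to initialize")
    try:
        done = 0
        def tick(n: int) -> None:
            # _import_to_db invokes this once per imported paragraph
            nonlocal done
            done += n
        await importer.import_json_data(payload, filename="api_import", progress_callback=tick)
        return done
    finally:
        await importer.close()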
diff --git a/plugins/A_memorix/scripts/rebuild_episodes.py b/plugins/A_memorix/scripts/rebuild_episodes.py
deleted file mode 100644
index b6adaa21..00000000
--- a/plugins/A_memorix/scripts/rebuild_episodes.py
+++ /dev/null
@@ -1,127 +0,0 @@
-#!/usr/bin/env python3
-"""Episode source-level rebuild tool."""
-
-from __future__ import annotations
-
-import argparse
-import asyncio
-import sys
-from pathlib import Path
-from typing import Any, Dict, List
-
-CURRENT_DIR = Path(__file__).resolve().parent
-PLUGIN_ROOT = CURRENT_DIR.parent
-WORKSPACE_ROOT = PLUGIN_ROOT.parent
-MAIBOT_ROOT = WORKSPACE_ROOT / "MaiBot"
-for path in (WORKSPACE_ROOT, MAIBOT_ROOT, PLUGIN_ROOT):
-    path_str = str(path)
-    if path_str not in sys.path:
-        sys.path.insert(0, path_str)
-
-try:
-    import tomlkit  # type: ignore
-except Exception:  # pragma: no cover
-    tomlkit = None
-
-from A_memorix.core.storage import MetadataStore
-from A_memorix.core.utils.episode_service import EpisodeService
-
-
-def _build_arg_parser() -> argparse.ArgumentParser:
-    parser = argparse.ArgumentParser(description="Rebuild A_Memorix episodes by source")
-    parser.add_argument("--data-dir", default=str(PLUGIN_ROOT / "data"), help="plugin data directory")
-    parser.add_argument("--source", type=str, help="enqueue/rebuild a single source")
-    parser.add_argument("--all", action="store_true", help="enqueue/rebuild all sources")
-    parser.add_argument("--wait", action="store_true", help="run the rebuild synchronously inside this script")
-    return parser
-
-
-if any(arg in {"-h", "--help"} for arg in sys.argv[1:]):
-    _build_arg_parser().print_help()
-    raise SystemExit(0)
-
-
-def _load_plugin_config() -> Dict[str, Any]:
-    config_path = PLUGIN_ROOT / "config.toml"
-    if tomlkit is None or not config_path.exists():
-        return {}
-    try:
-        with open(config_path, "r", encoding="utf-8") as handle:
-            parsed = tomlkit.load(handle)
-        return dict(parsed) if isinstance(parsed, dict) else {}
-    except Exception:
-        return {}
-
-
-def _resolve_sources(store: MetadataStore, *, source: str | None, rebuild_all: bool) -> List[str]:
-    if rebuild_all:
-        return list(store.list_episode_sources_for_rebuild())
-    token = str(source or "").strip()
-    if not token:
-        raise ValueError("either --source or --all must be provided")
-    return [token]
-
-
-async def _run_rebuilds(store: MetadataStore, plugin_config: Dict[str, Any], sources: List[str]) -> int:
-    service = EpisodeService(metadata_store=store, plugin_config=plugin_config)
-    failures: List[str] = []
-    for source in sources:
-        started = store.mark_episode_source_running(source)
-        if not started:
-            failures.append(f"{source}: unable_to_mark_running")
-            continue
-        try:
-            result = await service.rebuild_source(source)
-            store.mark_episode_source_done(source)
-            print(
-                "rebuilt"
-                f" source={source}"
-                f" paragraphs={int(result.get('paragraph_count') or 0)}"
-                f" groups={int(result.get('group_count') or 0)}"
-                f" episodes={int(result.get('episode_count') or 0)}"
-                f" fallback={int(result.get('fallback_count') or 0)}"
-            )
-        except Exception as exc:
-            err = str(exc)[:500]
-            store.mark_episode_source_failed(source, err)
-            failures.append(f"{source}: {err}")
-            print(f"failed source={source} error={err}")
-
-    if failures:
-        for item in failures:
-            print(item)
-        return 1
-    return 0
-
-
-def main() -> int:
-    parser = _build_arg_parser()
-    args = parser.parse_args()
-    if bool(args.all) == bool(args.source):
-        parser.error("exactly one of --source or --all must be chosen")
-
-    store = MetadataStore(data_dir=Path(args.data_dir) / "metadata")
-    store.connect()
-    try:
-        sources = _resolve_sources(store, source=args.source, rebuild_all=bool(args.all))
-        if not sources:
-            print("no sources to rebuild")
-            return 0
-
-        enqueued = 0
-        reason = "script_rebuild_all" if args.all else "script_rebuild_source"
-        for source in sources:
-            enqueued += int(store.enqueue_episode_source_rebuild(source, reason=reason))
-        print(f"enqueued={enqueued} sources={len(sources)}")
-
-        if not args.wait:
-            return 0
-
-        plugin_config = _load_plugin_config()
-        return
asyncio.run(_run_rebuilds(store, plugin_config, sources)) - finally: - store.close() - - -if __name__ == "__main__": - raise SystemExit(main()) diff --git a/plugins/A_memorix/scripts/release_vnext_migrate.py b/plugins/A_memorix/scripts/release_vnext_migrate.py deleted file mode 100644 index 0922fd0b..00000000 --- a/plugins/A_memorix/scripts/release_vnext_migrate.py +++ /dev/null @@ -1,731 +0,0 @@ -#!/usr/bin/env python3 -""" -vNext release migration entrypoint for A_Memorix. - -Subcommands: -- preflight: detect legacy config/data/schema risks -- migrate: offline migrate config + vectors + metadata schema + graph edge hash map -- verify: strict post-migration consistency checks -""" - -from __future__ import annotations - -import argparse -import json -import pickle -import sqlite3 -import sys -from dataclasses import dataclass -from pathlib import Path -from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple - -import tomlkit - - -CURRENT_DIR = Path(__file__).resolve().parent -PLUGIN_ROOT = CURRENT_DIR.parent -PROJECT_ROOT = PLUGIN_ROOT.parent.parent -sys.path.insert(0, str(PROJECT_ROOT)) -sys.path.insert(0, str(PLUGIN_ROOT)) - -def _build_arg_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser(description="A_Memorix vNext release migration tool") - parser.add_argument( - "--config", - default=str(PLUGIN_ROOT / "config.toml"), - help="config.toml path (default: plugins/A_memorix/config.toml)", - ) - parser.add_argument( - "--data-dir", - default="", - help="optional data dir override; default resolved from config.storage.data_dir", - ) - parser.add_argument("--json-out", default="", help="optional JSON report output path") - - sub = parser.add_subparsers(dest="command", required=True) - - p_preflight = sub.add_parser("preflight", help="scan legacy risks") - p_preflight.add_argument("--strict", action="store_true", help="return 1 if any error check exists") - - p_migrate = sub.add_parser("migrate", help="run offline migration") - p_migrate.add_argument("--dry-run", action="store_true", help="only print planned changes") - p_migrate.add_argument( - "--verify-after", - action="store_true", - help="run verify automatically after migrate", - ) - - p_verify = sub.add_parser("verify", help="post-migration verification") - p_verify.add_argument("--strict", action="store_true", help="return 1 if any error check exists") - return parser - - -# --help/-h fast path: avoid heavy host/plugin bootstrap -if any(arg in {"-h", "--help"} for arg in sys.argv[1:]): - _build_arg_parser().print_help() - raise SystemExit(0) - -try: - from core.storage import GraphStore, KnowledgeType, MetadataStore, QuantizationType, VectorStore - from core.storage.metadata_store import SCHEMA_VERSION -except Exception as e: # pragma: no cover - print(f"❌ failed to import storage modules: {e}") - raise SystemExit(2) - - -@dataclass -class CheckItem: - code: str - level: str - message: str - details: Optional[Dict[str, Any]] = None - - def to_dict(self) -> Dict[str, Any]: - out = { - "code": self.code, - "level": self.level, - "message": self.message, - } - if self.details: - out["details"] = self.details - return out - - -def _read_toml(path: Path) -> Dict[str, Any]: - text = path.read_text(encoding="utf-8") - return tomlkit.parse(text) - - -def _write_toml(path: Path, data: Dict[str, Any]) -> None: - path.write_text(tomlkit.dumps(data), encoding="utf-8") - - -def _get_nested(obj: Dict[str, Any], keys: Sequence[str], default: Any = None) -> Any: - cur: Any = obj - for k in keys: - if not 
isinstance(cur, dict) or k not in cur: - return default - cur = cur[k] - return cur - - -def _ensure_table(obj: Dict[str, Any], key: str) -> Dict[str, Any]: - if key not in obj or not isinstance(obj[key], dict): - obj[key] = tomlkit.table() - return obj[key] - - -def _resolve_data_dir(config_doc: Dict[str, Any], explicit_data_dir: Optional[str]) -> Path: - if explicit_data_dir: - return Path(explicit_data_dir).expanduser().resolve() - raw = str(_get_nested(config_doc, ("storage", "data_dir"), "./data") or "./data").strip() - if raw.startswith("."): - return (PLUGIN_ROOT / raw).resolve() - return Path(raw).expanduser().resolve() - - -def _sqlite_table_exists(conn: sqlite3.Connection, table: str) -> bool: - row = conn.execute( - "SELECT 1 FROM sqlite_master WHERE type='table' AND name=? LIMIT 1", - (table,), - ).fetchone() - return row is not None - - -def _collect_hash_alias_conflicts(conn: sqlite3.Connection) -> Dict[str, List[str]]: - hashes: List[str] = [] - if _sqlite_table_exists(conn, "relations"): - rows = conn.execute("SELECT hash FROM relations").fetchall() - hashes.extend(str(r[0]) for r in rows if r and r[0]) - if _sqlite_table_exists(conn, "deleted_relations"): - rows = conn.execute("SELECT hash FROM deleted_relations").fetchall() - hashes.extend(str(r[0]) for r in rows if r and r[0]) - - alias_map: Dict[str, str] = {} - conflicts: Dict[str, set[str]] = {} - for h in hashes: - if len(h) != 64: - continue - alias = h[:32] - old = alias_map.get(alias) - if old is None: - alias_map[alias] = h - continue - if old != h: - conflicts.setdefault(alias, set()).update({old, h}) - return {k: sorted(v) for k, v in conflicts.items()} - - -def _collect_invalid_knowledge_types(conn: sqlite3.Connection) -> List[str]: - if not _sqlite_table_exists(conn, "paragraphs"): - return [] - - allowed = {item.value for item in KnowledgeType} - rows = conn.execute("SELECT DISTINCT knowledge_type FROM paragraphs").fetchall() - invalid: List[str] = [] - for row in rows: - raw = row[0] - value = str(raw).strip().lower() if raw is not None else "" - if value not in allowed: - invalid.append(str(raw) if raw is not None else "") - return sorted(set(invalid)) - - -def _guess_vector_dimension(config_doc: Dict[str, Any], vectors_dir: Path) -> int: - meta_path = vectors_dir / "vectors_metadata.pkl" - if meta_path.exists(): - try: - with open(meta_path, "rb") as f: - meta = pickle.load(f) - dim = int(meta.get("dimension", 0)) - if dim > 0: - return dim - except Exception: - pass - try: - dim_cfg = int(_get_nested(config_doc, ("embedding", "dimension"), 1024)) - if dim_cfg > 0: - return dim_cfg - except Exception: - pass - return 1024 - - -def _preflight_impl(config_path: Path, data_dir: Path) -> Dict[str, Any]: - checks: List[CheckItem] = [] - facts: Dict[str, Any] = { - "config_path": str(config_path), - "data_dir": str(data_dir), - } - - if not config_path.exists(): - checks.append(CheckItem("CFG-00", "error", f"config not found: {config_path}")) - return {"ok": False, "checks": [c.to_dict() for c in checks], "facts": facts} - - config_doc = _read_toml(config_path) - tool_mode = str(_get_nested(config_doc, ("routing", "tool_search_mode"), "forward") or "").strip().lower() - summary_model = _get_nested(config_doc, ("summarization", "model_name"), ["auto"]) - summary_knowledge_type = str( - _get_nested(config_doc, ("summarization", "default_knowledge_type"), "narrative") or "narrative" - ).strip().lower() - quantization = str(_get_nested(config_doc, ("embedding", "quantization_type"), "int8") or "").strip().lower() - 
- facts["routing.tool_search_mode"] = tool_mode - facts["summarization.model_name_type"] = type(summary_model).__name__ - facts["summarization.default_knowledge_type"] = summary_knowledge_type - facts["embedding.quantization_type"] = quantization - - if tool_mode == "legacy": - checks.append( - CheckItem( - "CP-04", - "error", - "routing.tool_search_mode=legacy is no longer accepted at runtime", - ) - ) - elif tool_mode not in {"forward", "disabled"}: - checks.append( - CheckItem( - "CP-04", - "error", - f"routing.tool_search_mode invalid value: {tool_mode}", - ) - ) - - if isinstance(summary_model, str): - checks.append( - CheckItem( - "CP-11", - "error", - "summarization.model_name must be List[str], string legacy format detected", - ) - ) - elif not isinstance(summary_model, list) or any(not isinstance(x, str) for x in summary_model): - checks.append( - CheckItem( - "CP-11", - "error", - "summarization.model_name must be List[str]", - ) - ) - - if summary_knowledge_type not in {item.value for item in KnowledgeType}: - checks.append( - CheckItem( - "CP-13", - "error", - f"invalid summarization.default_knowledge_type: {summary_knowledge_type}", - ) - ) - - if quantization != "int8": - checks.append( - CheckItem( - "UG-07", - "error", - "embedding.quantization_type must be int8 in vNext", - ) - ) - - vectors_dir = data_dir / "vectors" - npy_path = vectors_dir / "vectors.npy" - bin_path = vectors_dir / "vectors.bin" - ids_bin_path = vectors_dir / "vectors_ids.bin" - facts["vectors.npy_exists"] = npy_path.exists() - facts["vectors.bin_exists"] = bin_path.exists() - facts["vectors_ids.bin_exists"] = ids_bin_path.exists() - - if npy_path.exists() and not (bin_path.exists() and ids_bin_path.exists()): - checks.append( - CheckItem( - "CP-07", - "error", - "legacy vectors.npy detected; offline migrate required", - {"npy_path": str(npy_path)}, - ) - ) - - metadata_db = data_dir / "metadata" / "metadata.db" - facts["metadata_db_exists"] = metadata_db.exists() - relation_count = 0 - if metadata_db.exists(): - conn = sqlite3.connect(str(metadata_db)) - try: - has_schema_table = _sqlite_table_exists(conn, "schema_migrations") - facts["schema_migrations_exists"] = has_schema_table - if not has_schema_table: - checks.append( - CheckItem( - "CP-08", - "error", - "schema_migrations table missing (legacy metadata schema)", - ) - ) - else: - row = conn.execute("SELECT MAX(version) FROM schema_migrations").fetchone() - version = int(row[0]) if row and row[0] is not None else 0 - facts["schema_version"] = version - if version != SCHEMA_VERSION: - checks.append( - CheckItem( - "CP-08", - "error", - f"schema version mismatch: current={version}, expected={SCHEMA_VERSION}", - ) - ) - - if _sqlite_table_exists(conn, "relations"): - row = conn.execute("SELECT COUNT(*) FROM relations").fetchone() - relation_count = int(row[0]) if row and row[0] is not None else 0 - facts["relations_count"] = relation_count - - conflicts = _collect_hash_alias_conflicts(conn) - facts["alias_conflict_count"] = len(conflicts) - if conflicts: - checks.append( - CheckItem( - "CP-05", - "error", - "32-bit relation hash alias conflict detected", - {"aliases": sorted(conflicts.keys())[:20], "total": len(conflicts)}, - ) - ) - - invalid_knowledge_types = _collect_invalid_knowledge_types(conn) - facts["invalid_knowledge_type_values"] = invalid_knowledge_types - if invalid_knowledge_types: - checks.append( - CheckItem( - "CP-12", - "error", - "invalid paragraph knowledge_type values detected", - {"values": invalid_knowledge_types[:20], "total": 
len(invalid_knowledge_types)}, - ) - ) - finally: - conn.close() - else: - checks.append( - CheckItem( - "META-00", - "warning", - "metadata.db not found, schema checks skipped", - ) - ) - - graph_meta_path = data_dir / "graph" / "graph_metadata.pkl" - facts["graph_metadata_exists"] = graph_meta_path.exists() - if relation_count > 0: - if not graph_meta_path.exists(): - checks.append( - CheckItem( - "CP-06", - "error", - "relations exist but graph metadata missing", - ) - ) - else: - try: - with open(graph_meta_path, "rb") as f: - graph_meta = pickle.load(f) - edge_hash_map = graph_meta.get("edge_hash_map", {}) - edge_hash_map_size = len(edge_hash_map) if isinstance(edge_hash_map, dict) else 0 - facts["edge_hash_map_size"] = edge_hash_map_size - if edge_hash_map_size <= 0: - checks.append( - CheckItem( - "CP-06", - "error", - "edge_hash_map missing/empty while relations exist", - ) - ) - except Exception as e: - checks.append( - CheckItem( - "CP-06", - "error", - f"failed to read graph metadata: {e}", - ) - ) - - has_error = any(c.level == "error" for c in checks) - return { - "ok": not has_error, - "checks": [c.to_dict() for c in checks], - "facts": facts, - } - - -def _migrate_config(config_doc: Dict[str, Any]) -> Dict[str, Any]: - changes: Dict[str, Any] = {} - - routing = _ensure_table(config_doc, "routing") - mode_raw = str(routing.get("tool_search_mode", "forward") or "").strip().lower() - mode_new = mode_raw - if mode_raw == "legacy" or mode_raw not in {"forward", "disabled"}: - mode_new = "forward" - if mode_new != mode_raw: - routing["tool_search_mode"] = mode_new - changes["routing.tool_search_mode"] = {"old": mode_raw, "new": mode_new} - - summary = _ensure_table(config_doc, "summarization") - summary_model = summary.get("model_name", ["auto"]) - if isinstance(summary_model, str): - normalized = [summary_model.strip() or "auto"] - summary["model_name"] = normalized - changes["summarization.model_name"] = {"old": summary_model, "new": normalized} - elif not isinstance(summary_model, list): - normalized = ["auto"] - summary["model_name"] = normalized - changes["summarization.model_name"] = {"old": str(type(summary_model)), "new": normalized} - elif any(not isinstance(x, str) for x in summary_model): - normalized = [str(x).strip() for x in summary_model if str(x).strip()] - if not normalized: - normalized = ["auto"] - summary["model_name"] = normalized - changes["summarization.model_name"] = {"old": summary_model, "new": normalized} - - default_knowledge_type = str(summary.get("default_knowledge_type", "narrative") or "").strip().lower() - allowed_knowledge_types = {item.value for item in KnowledgeType} - if default_knowledge_type not in allowed_knowledge_types: - summary["default_knowledge_type"] = "narrative" - changes["summarization.default_knowledge_type"] = { - "old": default_knowledge_type, - "new": "narrative", - } - - embedding = _ensure_table(config_doc, "embedding") - quantization = str(embedding.get("quantization_type", "int8") or "").strip().lower() - if quantization != "int8": - embedding["quantization_type"] = "int8" - changes["embedding.quantization_type"] = {"old": quantization, "new": "int8"} - - return changes - - -def _migrate_impl(config_path: Path, data_dir: Path, dry_run: bool) -> Dict[str, Any]: - config_doc = _read_toml(config_path) - result: Dict[str, Any] = { - "config_path": str(config_path), - "data_dir": str(data_dir), - "dry_run": bool(dry_run), - "steps": {}, - } - - config_changes = _migrate_config(config_doc) - result["steps"]["config"] = 
{"changed": bool(config_changes), "changes": config_changes} - if config_changes and not dry_run: - _write_toml(config_path, config_doc) - - vectors_dir = data_dir / "vectors" - vectors_dir.mkdir(parents=True, exist_ok=True) - npy_path = vectors_dir / "vectors.npy" - bin_path = vectors_dir / "vectors.bin" - ids_bin_path = vectors_dir / "vectors_ids.bin" - if npy_path.exists() and not (bin_path.exists() and ids_bin_path.exists()): - if dry_run: - result["steps"]["vector"] = {"migrated": False, "reason": "dry_run"} - else: - dim = _guess_vector_dimension(config_doc, vectors_dir) - store = VectorStore( - dimension=max(1, int(dim)), - quantization_type=QuantizationType.INT8, - data_dir=vectors_dir, - ) - result["steps"]["vector"] = store.migrate_legacy_npy(vectors_dir) - else: - result["steps"]["vector"] = {"migrated": False, "reason": "not_required"} - - metadata_dir = data_dir / "metadata" - metadata_dir.mkdir(parents=True, exist_ok=True) - metadata_db = metadata_dir / "metadata.db" - triples: List[Tuple[str, str, str, str]] = [] - relation_count = 0 - - metadata_result: Dict[str, Any] = {"migrated": False, "reason": "not_required"} - if metadata_db.exists(): - store = MetadataStore(data_dir=metadata_dir) - store.connect(enforce_schema=False) - try: - if dry_run: - metadata_result = {"migrated": False, "reason": "dry_run"} - else: - metadata_result = store.run_legacy_migration_for_vnext() - relation_count = int(store.count_relations()) - if relation_count > 0: - triples = [(str(s), str(p), str(o), str(h)) for s, p, o, h in store.get_all_triples()] - finally: - store.close() - result["steps"]["metadata"] = metadata_result - - graph_dir = data_dir / "graph" - graph_dir.mkdir(parents=True, exist_ok=True) - graph_matrix_format = str(_get_nested(config_doc, ("graph", "sparse_matrix_format"), "csr") or "csr") - graph_store = GraphStore(matrix_format=graph_matrix_format, data_dir=graph_dir) - graph_step: Dict[str, Any] = { - "rebuilt": False, - "mapped_hashes": 0, - "relation_count": relation_count, - "topology_rebuilt_from_relations": False, - } - if relation_count > 0: - if dry_run: - graph_step["reason"] = "dry_run" - else: - if graph_store.has_data(): - graph_store.load() - - mapped = graph_store.rebuild_edge_hash_map(triples) - - # 兜底:历史数据里 graph 节点/边与 relations 脱节时,直接从 relations 重建图。 - if mapped <= 0 or not graph_store.has_edge_hash_map(): - nodes = sorted({s for s, _, o, _ in triples} | {o for _, _, o, _ in triples}) - edges = [(s, o) for s, _, o, _ in triples] - hashes = [h for _, _, _, h in triples] - - graph_store.clear() - if nodes: - graph_store.add_nodes(nodes) - if edges: - mapped = graph_store.add_edges(edges, relation_hashes=hashes) - else: - mapped = 0 - graph_step.update( - { - "topology_rebuilt_from_relations": True, - "rebuilt_nodes": len(nodes), - "rebuilt_edges": int(graph_store.num_edges), - } - ) - - graph_store.save() - graph_step.update({"rebuilt": True, "mapped_hashes": int(mapped)}) - else: - graph_step["reason"] = "no_relations" - result["steps"]["graph"] = graph_step - - return result - - -def _verify_impl(config_path: Path, data_dir: Path) -> Dict[str, Any]: - checks: List[CheckItem] = [] - facts: Dict[str, Any] = { - "config_path": str(config_path), - "data_dir": str(data_dir), - } - - if not config_path.exists(): - checks.append(CheckItem("CFG-00", "error", f"config not found: {config_path}")) - return {"ok": False, "checks": [c.to_dict() for c in checks], "facts": facts} - - config_doc = _read_toml(config_path) - mode = str(_get_nested(config_doc, ("routing", 
"tool_search_mode"), "forward") or "").strip().lower() - if mode not in {"forward", "disabled"}: - checks.append(CheckItem("CP-04", "error", f"invalid routing.tool_search_mode: {mode}")) - - summary_model = _get_nested(config_doc, ("summarization", "model_name"), ["auto"]) - if not isinstance(summary_model, list) or any(not isinstance(x, str) for x in summary_model): - checks.append(CheckItem("CP-11", "error", "summarization.model_name must be List[str]")) - summary_knowledge_type = str( - _get_nested(config_doc, ("summarization", "default_knowledge_type"), "narrative") or "narrative" - ).strip().lower() - if summary_knowledge_type not in {item.value for item in KnowledgeType}: - checks.append( - CheckItem("CP-13", "error", f"invalid summarization.default_knowledge_type: {summary_knowledge_type}") - ) - - quantization = str(_get_nested(config_doc, ("embedding", "quantization_type"), "int8") or "").strip().lower() - if quantization != "int8": - checks.append(CheckItem("UG-07", "error", "embedding.quantization_type must be int8")) - - vectors_dir = data_dir / "vectors" - npy_path = vectors_dir / "vectors.npy" - bin_path = vectors_dir / "vectors.bin" - ids_bin_path = vectors_dir / "vectors_ids.bin" - if npy_path.exists() and not (bin_path.exists() and ids_bin_path.exists()): - checks.append(CheckItem("CP-07", "error", "legacy vectors.npy still exists without bin migration")) - - metadata_dir = data_dir / "metadata" - store = MetadataStore(data_dir=metadata_dir) - try: - store.connect(enforce_schema=True) - schema_version = store.get_schema_version() - facts["schema_version"] = schema_version - if schema_version != SCHEMA_VERSION: - checks.append(CheckItem("CP-08", "error", f"schema version mismatch: {schema_version}")) - - relation_count = int(store.count_relations()) - facts["relations_count"] = relation_count - - conflicts = {} - invalid_knowledge_types: List[str] = [] - db_path = metadata_dir / "metadata.db" - if db_path.exists(): - conn = sqlite3.connect(str(db_path)) - try: - conflicts = _collect_hash_alias_conflicts(conn) - invalid_knowledge_types = _collect_invalid_knowledge_types(conn) - finally: - conn.close() - if conflicts: - checks.append( - CheckItem( - "CP-05", - "error", - "alias conflicts still exist after migration", - {"aliases": sorted(conflicts.keys())[:20], "total": len(conflicts)}, - ) - ) - if invalid_knowledge_types: - checks.append( - CheckItem( - "CP-12", - "error", - "invalid paragraph knowledge_type values remain after migration", - {"values": invalid_knowledge_types[:20], "total": len(invalid_knowledge_types)}, - ) - ) - - if relation_count > 0: - graph_dir = data_dir / "graph" - if not (graph_dir / "graph_metadata.pkl").exists(): - checks.append(CheckItem("CP-06", "error", "graph metadata missing while relations exist")) - else: - matrix_format = str(_get_nested(config_doc, ("graph", "sparse_matrix_format"), "csr") or "csr") - graph_store = GraphStore(matrix_format=matrix_format, data_dir=graph_dir) - graph_store.load() - if not graph_store.has_edge_hash_map(): - checks.append(CheckItem("CP-06", "error", "edge_hash_map is empty")) - except Exception as e: - checks.append(CheckItem("CP-08", "error", f"metadata strict connect failed: {e}")) - finally: - try: - store.close() - except Exception: - pass - - has_error = any(c.level == "error" for c in checks) - return { - "ok": not has_error, - "checks": [c.to_dict() for c in checks], - "facts": facts, - } - - -def _print_report(title: str, report: Dict[str, Any]) -> None: - print(f"=== {title} ===") - print(f"ok: 
{bool(report.get('ok', True))}") - facts = report.get("facts", {}) - if facts: - print("facts:") - for k in sorted(facts.keys()): - print(f" - {k}: {facts[k]}") - checks = report.get("checks", []) - if checks: - print("checks:") - for item in checks: - print(f" - [{item.get('level')}] {item.get('code')}: {item.get('message')}") - else: - print("checks: none") - - -def _write_json_if_needed(path: str, payload: Dict[str, Any]) -> None: - if not path: - return - out = Path(path).expanduser().resolve() - out.parent.mkdir(parents=True, exist_ok=True) - out.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8") - print(f"json_out: {out}") - - -def main() -> int: - parser = _build_arg_parser() - args = parser.parse_args() - config_path = Path(args.config).expanduser().resolve() - if not config_path.exists(): - print(f"❌ config not found: {config_path}") - return 2 - config_doc = _read_toml(config_path) - data_dir = _resolve_data_dir(config_doc, args.data_dir) - - if args.command == "preflight": - report = _preflight_impl(config_path, data_dir) - _print_report("vNext Preflight", report) - _write_json_if_needed(args.json_out, report) - has_error = any(item.get("level") == "error" for item in report.get("checks", [])) - if args.strict and has_error: - return 1 - return 0 - - if args.command == "migrate": - payload = _migrate_impl(config_path, data_dir, dry_run=bool(args.dry_run)) - print("=== vNext Migrate ===") - print(json.dumps(payload, ensure_ascii=False, indent=2)) - - verify_report = None - if args.verify_after and not args.dry_run: - verify_report = _verify_impl(config_path, data_dir) - _print_report("vNext Verify (after migrate)", verify_report) - payload["verify_after"] = verify_report - - _write_json_if_needed(args.json_out, payload) - if verify_report is not None: - has_error = any(item.get("level") == "error" for item in verify_report.get("checks", [])) - if has_error: - return 1 - return 0 - - if args.command == "verify": - report = _verify_impl(config_path, data_dir) - _print_report("vNext Verify", report) - _write_json_if_needed(args.json_out, report) - has_error = any(item.get("level") == "error" for item in report.get("checks", [])) - if args.strict and has_error: - return 1 - return 0 - - return 2 - - -if __name__ == "__main__": - raise SystemExit(main()) diff --git a/plugins/A_memorix/scripts/runtime_self_check.py b/plugins/A_memorix/scripts/runtime_self_check.py deleted file mode 100644 index 70c423ac..00000000 --- a/plugins/A_memorix/scripts/runtime_self_check.py +++ /dev/null @@ -1,152 +0,0 @@ -#!/usr/bin/env python3 -"""Run A_Memorix runtime self-check against real embedding/runtime configuration.""" - -from __future__ import annotations - -import argparse -import asyncio -import json -import sys -import tempfile -from pathlib import Path -from typing import Any - -import tomlkit - - -CURRENT_DIR = Path(__file__).resolve().parent -PLUGIN_ROOT = CURRENT_DIR.parent -PROJECT_ROOT = PLUGIN_ROOT.parent.parent -sys.path.insert(0, str(PROJECT_ROOT)) -sys.path.insert(0, str(PLUGIN_ROOT)) - - -def _build_arg_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser(description="A_Memorix runtime self-check") - parser.add_argument( - "--config", - default=str(PLUGIN_ROOT / "config.toml"), - help="config.toml path (default: plugins/A_memorix/config.toml)", - ) - parser.add_argument( - "--data-dir", - default="", - help="optional data dir override; default resolved from config.storage.data_dir", - ) - parser.add_argument( - "--use-config-data-dir", - 
action="store_true", - help="use config.storage.data_dir directly instead of an isolated temp dir", - ) - parser.add_argument( - "--sample-text", - default="A_Memorix runtime self check", - help="sample text used for real embedding probe", - ) - parser.add_argument("--json", action="store_true", help="print JSON report") - return parser - - -if any(arg in {"-h", "--help"} for arg in sys.argv[1:]): - _build_arg_parser().print_help() - raise SystemExit(0) - -from core.runtime.lifecycle_orchestrator import initialize_storage_async -from core.utils.runtime_self_check import run_embedding_runtime_self_check - - -def _load_config(path: Path) -> dict[str, Any]: - with open(path, "r", encoding="utf-8") as f: - raw = tomlkit.load(f) - return dict(raw) if isinstance(raw, dict) else {} - - -def _nested_get(config: dict[str, Any], key: str, default: Any = None) -> Any: - current: Any = config - for part in key.split("."): - if isinstance(current, dict) and part in current: - current = current[part] - else: - return default - return current - - -class _PluginStub: - def __init__(self, config: dict[str, Any]): - self.config = config - self.vector_store = None - self.graph_store = None - self.metadata_store = None - self.embedding_manager = None - self.sparse_index = None - self.relation_write_service = None - - def get_config(self, key: str, default: Any = None) -> Any: - return _nested_get(self.config, key, default) - - -async def _main_async(args: argparse.Namespace) -> int: - config_path = Path(args.config).resolve() - if not config_path.exists(): - print(f"❌ 配置文件不存在: {config_path}") - return 2 - - config = _load_config(config_path) - temp_dir_ctx = None - if args.data_dir: - storage_dir = str(Path(args.data_dir).resolve()) - elif args.use_config_data_dir: - raw_data_dir = str(_nested_get(config, "storage.data_dir", "./data") or "./data").strip() - if raw_data_dir.startswith("."): - storage_dir = str((config_path.parent / raw_data_dir).resolve()) - else: - storage_dir = str(Path(raw_data_dir).resolve()) - else: - temp_dir_ctx = tempfile.TemporaryDirectory(prefix="memorix-runtime-self-check-") - storage_dir = temp_dir_ctx.name - - storage_cfg = config.setdefault("storage", {}) - storage_cfg["data_dir"] = storage_dir - - plugin = _PluginStub(config) - try: - await initialize_storage_async(plugin) - report = await run_embedding_runtime_self_check( - config=config, - vector_store=plugin.vector_store, - embedding_manager=plugin.embedding_manager, - sample_text=str(args.sample_text or "A_Memorix runtime self check"), - ) - report["data_dir"] = storage_dir - report["isolated_data_dir"] = temp_dir_ctx is not None - if args.json: - print(json.dumps(report, ensure_ascii=False, indent=2)) - else: - print("A_Memorix Runtime Self-Check") - print(f"ok: {report.get('ok')}") - print(f"code: {report.get('code')}") - print(f"message: {report.get('message')}") - print(f"configured_dimension: {report.get('configured_dimension')}") - print(f"vector_store_dimension: {report.get('vector_store_dimension')}") - print(f"detected_dimension: {report.get('detected_dimension')}") - print(f"encoded_dimension: {report.get('encoded_dimension')}") - print(f"elapsed_ms: {float(report.get('elapsed_ms', 0.0)):.2f}") - return 0 if bool(report.get("ok")) else 1 - finally: - if plugin.metadata_store is not None: - try: - plugin.metadata_store.close() - except Exception: - pass - if temp_dir_ctx is not None: - temp_dir_ctx.cleanup() - - -def main() -> int: - parser = _build_arg_parser() - args = parser.parse_args() - return 
asyncio.run(_main_async(args)) - - -if __name__ == "__main__": - raise SystemExit(main())
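Taken together, the deleted release_vnext_migrate.py was meant to be driven as preflight → migrate → verify, with error-level checks surfaced through exit codes. A small hypothetical wrapper illustrating that order (the script path matches the file deleted above; the wrapper itself is not part of the repo):

import subprocess
import sys

SCRIPT = "plugins/A_memorix/scripts/release_vnext_migrate.py"

def run(*args: str) -> int:
    # Each subcommand prints a human-readable report and, with --strict
    # (or --verify-after for migrate), encodes failures in its exit status.
    return subprocess.call([sys.executable, SCRIPT, *args])

if run("preflight", "--strict") != 0:
    sys.exit("preflight found blocking issues; resolve config/data first")
if run("migrate", "--verify-after") != 0:
    sys.exit("migration failed post-migration verification")
sys.exit(run("verify", "--strict"))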