Merge branch 'r-dev' of https://github.com/Mai-with-u/MaiBot into r-dev
This commit is contained in:
2
src/A_memorix/.gitattributes
vendored
Normal file
2
src/A_memorix/.gitattributes
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
# Auto detect text files and perform LF normalization
|
||||
* text=auto
|
||||
245
src/A_memorix/.gitignore
vendored
Normal file
245
src/A_memorix/.gitignore
vendored
Normal file
@@ -0,0 +1,245 @@
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# UV
|
||||
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
#uv.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
#poetry.lock
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
#pdm.lock
|
||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||
# in version control.
|
||||
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
||||
.pdm.toml
|
||||
.pdm-python
|
||||
.pdm-build/
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
|
||||
# Ruff stuff:
|
||||
.ruff_cache/
|
||||
|
||||
# PyPI configuration file
|
||||
.pypirc
|
||||
|
||||
# Cursor
|
||||
# Cursor is an AI-powered code editor.`.cursorignore` specifies files/directories to
|
||||
# exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
|
||||
# refer to https://docs.cursor.com/context/ignore-files
|
||||
.cursorignore
|
||||
.cursorindexingignore
|
||||
|
||||
# Python
|
||||
__pycache__/
|
||||
*.pyc
|
||||
*.pyo
|
||||
*.pyd
|
||||
*.egg-info/
|
||||
|
||||
# Data & Storage (Privacy & Runtime)
|
||||
data/
|
||||
logs/
|
||||
|
||||
# Deprecated / Cleanup (Avoid uploading junk)
|
||||
deprecated/
|
||||
|
||||
# OS / System
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
ehthumbs.db
|
||||
|
||||
# IDE settings
|
||||
.idea/
|
||||
.vscode/
|
||||
|
||||
# Temporary Verification Scripts
|
||||
verify_*.py
|
||||
config.toml
|
||||
|
||||
# Test Artifacts & Generated Files
|
||||
MagicMock/
|
||||
benchmark_output.txt
|
||||
e2e_debug.log
|
||||
e2e_error.log
|
||||
full_diff.txt
|
||||
|
||||
# Large Test Data Files
|
||||
机娘导论-openie.json
|
||||
scripts/机娘导论-openie.json
|
||||
|
||||
# A_memorix recall/tuning generated artifacts
|
||||
artifacts/
|
||||
scripts/run_arc_light_recall_pipeline.py
|
||||
|
||||
# Compressed Data Archives
|
||||
data.zip
|
||||
scripts/full_feature_smoke_test.py
|
||||
ACL2026_DEMO_EVAL.md
|
||||
.probe_write
|
||||
tests/
|
||||
temp_verify_v5_data/metadata/metadata.db
|
||||
sql2/t.db
|
||||
sql2/t.db-journal
|
||||
scripts/test.json
|
||||
scripts/test1.json
|
||||
scripts/test-sample.json
|
||||
USAGE_ARCHITECTURE.md
|
||||
scripts/test_conversion.py
|
||||
scripts/debug_graph_vis.py
|
||||
/.tmp_feature_e2e_real
|
||||
/.tmp_sparse_tests
|
||||
/.tmp_test_probe
|
||||
/.tmp_test_sqlite
|
||||
/.tmp_testdata
|
||||
/scripts/tmp
|
||||
718
src/A_memorix/CHANGELOG.md
Normal file
718
src/A_memorix/CHANGELOG.md
Normal file
@@ -0,0 +1,718 @@
|
||||
# 更新日志 (Changelog)
|
||||
|
||||
## [2.0.0] - 2026-03-18
|
||||
|
||||
本次 `2.0.0` 为架构收敛版本,主线是 **SDK Tool 接口统一**、**管理工具能力补齐**、**元数据 schema 升级到 v8** 与 **文档口径同步到 2.0.0**。
|
||||
|
||||
### 🔖 版本信息
|
||||
|
||||
- 插件版本:`1.0.1` → `2.0.0`
|
||||
- 元数据 schema:`7` → `8`
|
||||
|
||||
### 🚀 重点能力
|
||||
|
||||
- Tool 接口统一:
|
||||
- `plugin.py` 统一通过 `SDKMemoryKernel` 对外提供 Tool 能力。
|
||||
- 保留基础工具:`search_memory / ingest_summary / ingest_text / get_person_profile / maintain_memory / memory_stats`。
|
||||
- 新增管理工具:`memory_graph_admin / memory_source_admin / memory_episode_admin / memory_profile_admin / memory_runtime_admin / memory_import_admin / memory_tuning_admin / memory_v5_admin / memory_delete_admin`。
|
||||
- 检索与写入治理增强:
|
||||
- 检索/写入链路支持 `respect_filter + user_id/group_id` 的聊天过滤语义。
|
||||
- `maintain_memory` 支持 `freeze` 与 `recycle_bin`,并统一到内核维护流程。
|
||||
- 导入与调优能力收敛:
|
||||
- `memory_import_admin` 提供任务化导入能力(上传、粘贴、扫描、OpenIE、LPMM 转换、时序回填、MaiBot 迁移)。
|
||||
- `memory_tuning_admin` 提供检索调优任务(创建、轮次查看、回滚、apply_best、报告导出)。
|
||||
- V5 与删除运维:
|
||||
- 新增 `memory_v5_admin`(`reinforce/weaken/remember_forever/forget/restore/status`)。
|
||||
- 新增 `memory_delete_admin`(`preview/execute/restore/list/get/purge`),支持操作审计与恢复。
|
||||
|
||||
### 🛠️ 存储与运行时
|
||||
|
||||
- `metadata_store` 升级到 `SCHEMA_VERSION = 8`。
|
||||
- 新增/完善外部引用与运维记录能力(包括 `external_memory_refs`、`memory_v5_operations`、`delete_operations` 相关数据结构)。
|
||||
- `SDKMemoryKernel` 增加统一后台任务编排(自动保存、Episode pending 处理、画像刷新、记忆维护)。
|
||||
|
||||
### 📚 文档同步
|
||||
|
||||
- `README.md`、`QUICK_START.md`、`CONFIG_REFERENCE.md`、`IMPORT_GUIDE.md` 已切换到 `2.0.0` 口径。
|
||||
- 文档主入口统一为 SDK Tool 工作流,不再以旧版 slash 命令作为主说明路径。
|
||||
|
||||
## [1.0.1] - 2026-03-07
|
||||
|
||||
本次 `1.0.1` 为 `1.0.0` 发布后的热修复版本,主线是 **图谱 WebUI 取数稳定性修复**、**大图过滤性能修复** 与 **真实检索调优链路稳定性修复**。
|
||||
|
||||
### 🔖 版本信息
|
||||
|
||||
- 插件版本:`1.0.0` → `1.0.1`
|
||||
- 配置版本:`4.1.0`(不变)
|
||||
|
||||
### 🛠️ 代码修复
|
||||
|
||||
- 图谱接口稳定性:
|
||||
- 修复 `/api/graph` 在“磁盘已有图文件但运行时尚未装载入内存”场景下返回空图的问题,接口现在会自动补加载持久化图数据。
|
||||
- 修复问题数据集下 WebUI 打开图谱页时看似“没有任何节点”的现象;根因不是图数据消失,而是后端过滤路径过慢。
|
||||
- 图谱过滤性能:
|
||||
- 优化 `/api/graph?exclude_leaf=true` 的叶子过滤逻辑,改为预计算 hub 邻接关系,不再对每个节点反复做高成本边权查询。
|
||||
- 优化 `GraphStore.get_neighbors()` 并补充入邻居访问能力,避免稠密矩阵展开导致的大图性能退化。
|
||||
- 检索调优稳定性:
|
||||
- 修复真实调优任务在构建运行时配置时深拷贝 `plugin.config`,误复制注入的存储实例并触发 `cannot pickle '_thread.RLock' object` 的问题。
|
||||
- 调优评估改为跳过顶层运行时实例键,仅保留纯配置字段后再附加运行时依赖,真实 WebUI 调优任务可正常启动。
|
||||
|
||||
### 📚 文档同步
|
||||
|
||||
- 同步更新 `README.md`、`CHANGELOG.md`、`CONFIG_REFERENCE.md` 与版本元数据(`plugin.py`、`__init__.py`、`_manifest.json`)。
|
||||
- README 新增 `v1.0.1` 修复说明,并补充“调优前先做 runtime self-check”的建议。
|
||||
|
||||
## [1.0.0] - 2026-03-06
|
||||
|
||||
本次 `1.0.0` 为主版本升级,主线是 **运行时架构模块化**、**Episode 情景记忆闭环**、**聚合检索与图召回增强**、**离线迁移 / 运行时自检 / 检索调优中心**。
|
||||
|
||||
### 🔖 版本信息
|
||||
|
||||
- 插件版本:`0.7.0` → `1.0.0`
|
||||
- 配置版本:`4.1.0`(不变)
|
||||
|
||||
### 🚀 重点能力
|
||||
|
||||
- 运行时重构:
|
||||
- `plugin.py` 大幅瘦身,生命周期、后台任务、请求路由、检索运行时初始化拆分到 `core/runtime/*`。
|
||||
- 配置 schema 抽离到 `core/config/plugin_config_schema.py`,`_manifest.json` 同步扩展新配置项。
|
||||
- 检索与查询增强:
|
||||
- `KnowledgeQueryTool` 拆分为 query mode + orchestrator,新增长 `aggregate` / `episode` 查询模式。
|
||||
- 新增图辅助关系召回、统一 forward/runtime 构建与请求去重桥接。
|
||||
- Episode / 运维能力:
|
||||
- `metadata_store` schema 升级到 `SCHEMA_VERSION = 7`,新增 `episodes` / `episode_paragraphs` / rebuild queue 等结构。
|
||||
- 新增 `release_vnext_migrate.py`、`runtime_self_check.py`、`rebuild_episodes.py` 与 Web 检索调优页 `web/tuning.html`。
|
||||
|
||||
### 📚 文档同步
|
||||
|
||||
- 版本号同步到 `plugin.py`、`__init__.py`、`_manifest.json`、`README.md` 与 `CONFIG_REFERENCE.md`。
|
||||
- 新增 `RELEASE_SUMMARY_1.0.0.md`
|
||||
|
||||
## [0.7.0] - 2026-03-04
|
||||
|
||||
本次 `0.7.0` 为中版本升级,主线是 **关系向量化闭环(写入 + 状态机 + 回填 + 审计)**、**检索/命令链路增强** 与 **导入任务能力补齐**。
|
||||
|
||||
### 🔖 版本信息
|
||||
|
||||
- 插件版本:`0.6.1` → `0.7.0`
|
||||
- 配置版本:`4.1.0`(不变)
|
||||
|
||||
### 🚀 重点能力
|
||||
|
||||
- 关系向量化闭环:
|
||||
- 新增统一关系写入服务 `RelationWriteService`(metadata 先写、向量后写,失败进入状态机而非回滚主数据)。
|
||||
- `relations` 侧补齐 `vector_state/retry_count/last_error/updated_at` 等状态字段,支持 `none/pending/ready/failed` 统一治理。
|
||||
- 插件新增后台回填循环与统计接口,可持续修复关系向量缺失并暴露覆盖率指标。
|
||||
- 检索与命令链路增强:
|
||||
- 检索主链继续收敛到 `search/time` forward 路由,`legacy` 仅保留兼容别名。
|
||||
- relation 查询规格解析收口,结构化查询与语义回退边界更清晰。
|
||||
- `/query stats` 与 tool stats 补充关系向量化统计输出。
|
||||
- 导入与运维增强:
|
||||
- Web Import 新增 `temporal_backfill` 任务入口与编排处理。
|
||||
- 新增一致性审计与离线回填脚本,支持灰度修复历史数据。
|
||||
|
||||
### 📚 文档同步
|
||||
|
||||
- 同步更新 `README.md`、`CONFIG_REFERENCE.md` 与本日志版本信息。
|
||||
- `README.md` 新增关系向量审计/回填脚本使用说明,并更新 `convert_lpmm.py` 的关系向量重建行为描述。
|
||||
|
||||
## [0.6.1] - 2026-03-03
|
||||
|
||||
本次 `0.6.1` 为热修复小版本,重点修复 WebUI 插件配置接口在 A_Memorix 场景下的 `tomlkit` 节点序列化兼容问题。
|
||||
|
||||
### 🔖 版本信息
|
||||
|
||||
- 插件版本:`0.6.0` → `0.6.1`
|
||||
- 配置版本:`4.1.0`(不变)
|
||||
|
||||
### 🛠️ 代码修复
|
||||
|
||||
- 新增运行时补丁 `_patch_webui_a_memorix_routes_for_tomlkit_serialization()`:
|
||||
- 仅包裹 `/api/webui/plugins/config/{plugin_id}` 及其 schema 的 `GET` 路由。
|
||||
- 仅在 `plugin_id == "A_Memorix"` 时,将返回中的 `config/schema` 通过 `to_builtin_data` 原生化。
|
||||
- 保持 `/api/webui/config/*` 全局接口行为不变,避免对其他插件或核心配置路径产生副作用。
|
||||
- 在插件初始化时执行该补丁,确保 WebUI 读取插件配置时返回结构可稳定序列化。
|
||||
|
||||
### 📚 文档同步
|
||||
|
||||
- 同步更新 `README.md`、`CONFIG_REFERENCE.md` 与本日志中的版本信息及修复说明。
|
||||
|
||||
## [0.6.0] - 2026-03-02
|
||||
|
||||
本次 `0.6.0` 为中版本升级,主线是 **Web Import 导入中心上线与脚本能力对齐**、**失败重试机制升级**、**删除后 manifest 同步** 与 **导入链路稳定性增强**。
|
||||
|
||||
### 🔖 版本信息
|
||||
|
||||
- 插件版本:`0.5.1` → `0.6.0`
|
||||
- 配置版本:`4.0.1` → `4.1.0`
|
||||
|
||||
### 🚀 重点能力
|
||||
|
||||
- 新增 Web Import 导入中心(`/import`):
|
||||
- 上传/粘贴/本地扫描/LPMM OpenIE/LPMM 转换/时序回填/MaiBot 迁移。
|
||||
- 任务/文件/分块三级状态展示,支持取消与失败重试。
|
||||
- 导入文档弹窗读取(远程优先,失败回退本地)。
|
||||
- 失败重试升级为“分块优先 + 文件回退”:
|
||||
- `POST /api/import/tasks/{task_id}/retry_failed` 保持原路径,语义升级。
|
||||
- 支持对 `extracting` 失败分块进行子集重试。
|
||||
- `writing`/JSON 解析失败自动回退为文件级重试。
|
||||
- 删除后 manifest 同步失效:
|
||||
- 覆盖 `/api/source/batch_delete` 与 `/api/source`。
|
||||
- 返回 `manifest_cleanup` 明细,避免误命中去重跳过重导入。
|
||||
|
||||
### 📂 变更文件清单(本次发布)
|
||||
|
||||
新增文件:
|
||||
|
||||
- `core/utils/web_import_manager.py`
|
||||
- `scripts/migrate_maibot_memory.py`
|
||||
- `web/import.html`
|
||||
|
||||
修改文件:
|
||||
|
||||
- `CHANGELOG.md`
|
||||
- `CONFIG_REFERENCE.md`
|
||||
- `IMPORT_GUIDE.md`
|
||||
- `QUICK_START.md`
|
||||
- `README.md`
|
||||
- `__init__.py`
|
||||
- `_manifest.json`
|
||||
- `components/commands/debug_server_command.py`
|
||||
- `core/embedding/api_adapter.py`
|
||||
- `core/storage/graph_store.py`
|
||||
- `core/utils/summary_importer.py`
|
||||
- `plugin.py`
|
||||
- `requirements.txt`
|
||||
- `server.py`
|
||||
- `web/index.html`
|
||||
|
||||
删除文件:
|
||||
|
||||
- 无
|
||||
|
||||
### 📚 文档同步
|
||||
|
||||
- 同步更新 `README.md`、`QUICK_START.md`、`CONFIG_REFERENCE.md`、`IMPORT_GUIDE.md` 与本日志。
|
||||
- `IMPORT_GUIDE.md` 新增 “Web Import 导入中心” 专区,统一说明能力范围、状态语义与安全边界。
|
||||
|
||||
## [0.5.1] - 2026-02-23
|
||||
|
||||
本次 `0.5.1` 为热修订小版本,重点修复“随主程序启动的后台任务拉起”“空名单过滤语义”以及“知识抽取模型选择”。
|
||||
|
||||
### 🔖 版本信息
|
||||
|
||||
- 插件版本:`0.5.0` → `0.5.1`
|
||||
- 配置版本:`4.0.0` → `4.0.1`
|
||||
|
||||
### 🛠️ 代码修复
|
||||
|
||||
- 生命周期接入主程序事件:
|
||||
- 新增 `a_memorix_start_handler`(`ON_START`)调用 `plugin.on_enable()`;
|
||||
- 新增 `a_memorix_stop_handler`(`ON_STOP`)调用 `plugin.on_disable()`;
|
||||
- 解决仅注册插件但未触发生命周期时,定时导入任务不启动的问题。
|
||||
- 聊天过滤空列表策略调整:
|
||||
- `whitelist + []`:全部拒绝;
|
||||
- `blacklist + []`:全部放行。
|
||||
- 知识抽取模型选择逻辑调整(`import_command._select_model`):
|
||||
- `advanced.extraction_model` 现在支持三种语义:任务名 / 模型名 / `auto`;
|
||||
- `auto` 优先抽取相关任务(`lpmm_entity_extract`、`lpmm_rdf_build` 等),并避免误落到 `embedding`;
|
||||
- 当配置无法识别时输出告警并回退自动选择,提高导入阶段的模型选择可预期性。
|
||||
|
||||
### 📚 文档同步
|
||||
|
||||
- 同步更新 `README.md`、`CONFIG_REFERENCE.md` 与 `CHANGELOG.md`。
|
||||
- 同步修正文档中的空名单过滤行为描述,保持与当前代码一致。
|
||||
|
||||
## [0.5.0] - 2026-02-15
|
||||
|
||||
本次 `0.5.0` 以提交 `66ddc1b98547df3c866b19a3f5dc96e1c8eb7731` 为核心,主线是“人物画像能力上线 + 工具/命令接入 + 版本与文档同步”。
|
||||
|
||||
### 🔖 版本信息
|
||||
|
||||
- 插件版本:`0.4.0` → `0.5.0`
|
||||
- 配置版本:`3.1.0` → `4.0.0`
|
||||
|
||||
### 🚀 人物画像主特性(核心)
|
||||
|
||||
- 新增人物画像服务:`core/utils/person_profile_service.py`
|
||||
- 支持 `person_id/姓名/别名` 解析。
|
||||
- 聚合图关系证据 + 向量证据,生成画像文本并版本化快照。
|
||||
- 支持手工覆盖(override)与 TTL 快照复用。
|
||||
- 存储层新增人物画像相关表与 API:`core/storage/metadata_store.py`
|
||||
- `person_profile_switches`
|
||||
- `person_profile_snapshots`
|
||||
- `person_profile_active_persons`
|
||||
- `person_profile_overrides`
|
||||
- 新增命令:`/person_profile on|off|status`
|
||||
- 文件:`components/commands/person_profile_command.py`
|
||||
- 作用:按 `stream_id + user_id` 控制自动注入开关(opt-in 模式)。
|
||||
- 查询链路接入人物画像:
|
||||
- `knowledge_query_tool` 新增 `query_type=person`,支持 `person_id` 或别名查询。
|
||||
- `/query person` 与 `/query p` 接入画像查询输出。
|
||||
- 插件生命周期接入画像刷新任务:
|
||||
- 启动/停止统一管理 `person_profile_refresh` 后台任务。
|
||||
- 按活跃窗口自动刷新画像快照。
|
||||
|
||||
### 🛠️ 版本与 schema 同步
|
||||
|
||||
- `plugin.py`:`plugin_version` 更新为 `0.5.0`。
|
||||
- `plugin.py`:`plugin.config_version` 默认值更新为 `4.0.0`。
|
||||
- `config.toml`:`config_version` 基线同步为 `4.0.0`(本地配置文件)。
|
||||
- `__init__.py`:`__version__` 更新为 `0.5.0`。
|
||||
- `_manifest.json`:`version` 更新为 `0.5.0`,`manifest_version` 保持 `1` 。
|
||||
- `manifest_utils.py`:仓库内已兼容更高 manifest 版本;但插件发布默认保持 `manifest_version=1` 。
|
||||
|
||||
### 📚 文档同步
|
||||
|
||||
- 更新 `README.md`、`CONFIG_REFERENCE.md`、`QUICK_START.md`、`USAGE_ARCHITECTURE.md`。
|
||||
- 0.5.0 文档主线改为“人物画像能力 + 版本升级 + 检索链路补充说明”。
|
||||
|
||||
## [0.4.0] - 2026-02-13
|
||||
|
||||
本次 `0.4.0` 版本整合了时序检索增强与后续检索链路增强、稳定性修复和文档同步。
|
||||
|
||||
### 🔖 版本信息
|
||||
|
||||
- 插件版本:`0.3.3` → `0.4.0`
|
||||
- 配置版本:`3.0.0` → `3.1.0`
|
||||
|
||||
### 🚀 新增
|
||||
|
||||
- 新增 `core/retrieval/sparse_bm25.py`
|
||||
- `SparseBM25Config` / `SparseBM25Index`
|
||||
- FTS5 + BM25 稀疏检索
|
||||
- 支持 `jieba/mixed/char_2gram` 分词与懒加载
|
||||
- 支持 ngram 倒排回退与可选 LIKE 兜底
|
||||
- `DualPathRetriever` 新增 sparse/fusion 配置注入:
|
||||
- embedding 不可用时自动 sparse 回退;
|
||||
- `hybrid` 模式支持向量路 + sparse 路并行候选;
|
||||
- 新增 `FusionConfig` 与 `weighted_rrf` 融合。
|
||||
- `MetadataStore` 新增 FTS/倒排能力:
|
||||
- `paragraphs_fts`、`relations_fts` schema 与回填;
|
||||
- `paragraph_ngrams` 倒排索引与回填;
|
||||
- `fts_search_bm25` / `fts_search_relations_bm25` / `ngram_search_paragraphs`。
|
||||
|
||||
### 🛠️ 组件链路同步
|
||||
|
||||
- `plugin.py`
|
||||
- 新增 `[retrieval.sparse]`、`[retrieval.fusion]` 默认配置;
|
||||
- 初始化并向组件注入 `sparse_index`;
|
||||
- `on_disable` 支持按配置卸载 sparse 连接并释放缓存。
|
||||
- `knowledge_search_action.py` / `query_command.py` / `knowledge_query_tool.py`
|
||||
- 统一接入 sparse/fusion 配置;
|
||||
- 统一注入 `sparse_index`;
|
||||
- `stats` 输出新增 sparse 状态观测。
|
||||
- `requirements.txt`
|
||||
- 新增 `jieba>=0.42.1`(未安装时自动回退 char n-gram)。
|
||||
|
||||
### 🧯 修复与行为调整
|
||||
|
||||
- 修复 `retrieval.ppr_concurrency_limit` 不生效问题:
|
||||
- `DualPathRetriever` 使用配置值初始化 `_ppr_semaphore`,不再被固定值覆盖。
|
||||
- 修复 `char_2gram` 召回失效场景:
|
||||
- FTS miss 时增加 `_fallback_substring_search`,优先 ngram 倒排回退,按配置可选 LIKE 兜底。
|
||||
- 提升可观测性与兼容性:
|
||||
- `get_statistics()` 对向量规模字段兼容读取 `size -> num_vectors -> 0`,避免属性缺失导致异常。
|
||||
- `/query stats` 与 `knowledge_query` 输出包含 sparse 状态(enabled/loaded/tokenizer/doc_count)。
|
||||
|
||||
### 📚 文档
|
||||
|
||||
- `README.md`
|
||||
- 新增检索增强说明、稀疏行为说明、时序回填脚本入口。
|
||||
- `CONFIG_REFERENCE.md`
|
||||
- 补齐 sparse/fusion 参数与触发规则、回退链路、融合实现细节。
|
||||
|
||||
### ⏱️ 时序检索与导入增强
|
||||
|
||||
#### 时序检索能力(分钟级)
|
||||
|
||||
- 新增统一时序查询入口:
|
||||
- `/query time`(别名 `/query t`)
|
||||
- `knowledge_query(query_type=time)`
|
||||
- `knowledge_search(query_type=time|hybrid)`
|
||||
- 查询时间参数统一支持:
|
||||
- `YYYY/MM/DD`
|
||||
- `YYYY/MM/DD HH:mm`
|
||||
- 日期参数自动展开边界:
|
||||
- `from/time_from` -> `00:00`
|
||||
- `to/time_to` -> `23:59`
|
||||
- 查询结果统一回传 `metadata.time_meta`,包含命中时间窗口与命中依据(事件时间或 `created_at` 回退)。
|
||||
|
||||
#### 存储与检索链路
|
||||
|
||||
- 段落存储层支持时序字段:
|
||||
- `event_time`
|
||||
- `event_time_start`
|
||||
- `event_time_end`
|
||||
- `time_granularity`
|
||||
- `time_confidence`
|
||||
- 时序命中采用区间相交逻辑,并遵循“双层时间语义”:
|
||||
- 优先 `event_time/event_time_range`
|
||||
- 缺失时回退 `created_at`(可配置关闭)
|
||||
- 检索排序规则保持:语义优先,时间次排序(新到旧)。
|
||||
- `process_knowledge.py` 新增 `--chat-log` 参数:
|
||||
- 启用后强制使用 `narrative` 策略;
|
||||
- 使用 LLM 对聊天文本进行语义时间抽取(支持相对时间转绝对时间),写入 `event_time/event_time_start/event_time_end`。
|
||||
- 新增 `--chat-reference-time`,用于指定相对时间语义解析的参考时间点。
|
||||
|
||||
#### Schema 与文档同步
|
||||
|
||||
- `_manifest.json` 同步补齐 `retrieval.temporal` 配置 schema。
|
||||
- 配置 schema 版本升级:`config_version` 从 `3.0.0` 提升到 `3.1.0`(`plugin.py` / `config.toml` / 配置文档同步)。
|
||||
- 更新 `README.md`、`CONFIG_REFERENCE.md`、`IMPORT_GUIDE.md`,补充时序检索入口、参数格式与导入时间字段说明。
|
||||
|
||||
## [0.3.3] - 2026-02-11
|
||||
|
||||
本次更新为 **语言一致性补丁版本**,重点收敛知识抽取时的语言漂移问题,要求输出严格贴合原文语言,不做翻译改写。
|
||||
|
||||
### 🛠️ 关键修复
|
||||
|
||||
#### 抽取语言约束
|
||||
|
||||
- `BaseStrategy`:
|
||||
- 移除按 `zh/en/mixed` 分支的语言类型判定逻辑;
|
||||
- 统一为单一约束:抽取值保持原文语言、保留原始术语、禁止翻译。
|
||||
- `NarrativeStrategy` / `FactualStrategy`:
|
||||
- 抽取提示词统一接入上述语言约束;
|
||||
- 明确要求 JSON 键名固定、抽取值遵循原文语言表达。
|
||||
|
||||
#### 导入链路一致性
|
||||
|
||||
- `ImportCommand` 的 LLM 抽取提示词同步强化“优先原文语言、不要翻译”要求,避免脚本与指令导入行为不一致。
|
||||
|
||||
#### 测试与文档
|
||||
|
||||
- 更新 `test_strategies.py`,将语言判定测试调整为统一语言约束测试,并验证提示词中包含禁止翻译约束。
|
||||
- 同步更新注释与文档描述,确保实现与说明一致。
|
||||
|
||||
### 🔖 版本信息
|
||||
|
||||
- 插件版本:`0.3.2` → `0.3.3`
|
||||
|
||||
## [0.3.2] - 2026-02-11
|
||||
|
||||
本次更新为 **V5 稳定性与兼容性修复版本**,在保持原有业务设计(强化→衰减→冷冻→修剪→回收)的前提下,修复关键链路断裂与误判问题。
|
||||
|
||||
### 🛠️ 关键修复
|
||||
|
||||
#### V5 记忆系统契约与链路
|
||||
|
||||
- `MetadataStore`:
|
||||
- 统一 `mark_relations_inactive(hashes, inactive_since=None)` 调用契约,兼容不同调用方;
|
||||
- 补充 `has_table(table_name)`;
|
||||
- 增加 `restore_relation(hash)` 兼容别名,修复服务层恢复调用断裂;
|
||||
- 修正 `get_entity_gc_candidates` 对孤立节点参数的处理(支持节点名映射到实体 hash)。
|
||||
- `GraphStore`:
|
||||
- 清理 `deactivate_edges` 重复定义并统一返回冻结数量,保证上层日志与断言稳定。
|
||||
- `server.py`:
|
||||
- 修复 `/api/memory/restore` relation 恢复链路;
|
||||
- 清理不可达分支并统一异常路径;
|
||||
- 回收站查询在表检测场景下不再出现错误退空。
|
||||
|
||||
#### 命令与模型选择
|
||||
|
||||
- `/memory` 命令修复 hash 长度判定:以 64 位 `sha256` 为标准,同时兼容历史 32 位输入。
|
||||
- 总结模型选择修复:
|
||||
- 解决 `summarization.model_name = auto` 误命中 `embedding` 问题;
|
||||
- 支持数组与选择器语法(`task:model` / task / model);
|
||||
- 兼容逗号分隔字符串写法(如 `"utils:model1","utils:model2",replyer`)。
|
||||
|
||||
#### 生命周期与脚本稳定性
|
||||
|
||||
- `plugin.py` 修复后台任务生命周期管理:
|
||||
- 增加 `_scheduled_import_task` / `_auto_save_task` / `_memory_maintenance_task` 句柄;
|
||||
- 避免重复启动;
|
||||
- 插件停用时统一 cancel + await 收敛。
|
||||
- `process_knowledge.py` 修复 tenacity 重试日志级别类型错误(`"WARNING"` → `logging.WARNING`),避免 `KeyError: 'WARNING'`。
|
||||
|
||||
### 🔖 版本信息
|
||||
|
||||
- 插件版本:`0.3.1` → `0.3.2`
|
||||
|
||||
## [0.3.1] - 2026-02-07
|
||||
|
||||
本次更新为 **稳定性补丁版本**,主要修复脚本导入链路、删除安全性与 LPMM 转换一致性问题。
|
||||
|
||||
### 🛠️ 关键修复
|
||||
|
||||
#### 新增功能
|
||||
|
||||
- 新增 `scripts/convert_lpmm.py`:
|
||||
- 支持将 LPMM 的 `parquet + graph` 数据直接转换为 A_Memorix 存储结构;
|
||||
- 提供 LPMM ID 到 A_Memorix ID 的映射能力,用于图节点/边重写;
|
||||
- 当前实现优先保证检索一致性,关系向量采用安全策略(不直接导入)。
|
||||
|
||||
#### 导入链路
|
||||
|
||||
- 修复 `import_lpmm_json.py` 依赖的 `AutoImporter.import_json_data` 公共入口缺失/不稳定问题,确保外部脚本可稳定调用 JSON 直导入流程。
|
||||
|
||||
#### 删除安全
|
||||
|
||||
- 修复按来源删除时“同一 `(subject, object)` 存在多关系”场景下的误删风险:
|
||||
- `MetadataStore.delete_paragraph_atomic` 新增 `relation_prune_ops`;
|
||||
- 仅在无兄弟关系时才回退删除整条边。
|
||||
- `delete_knowledge.py` 新增保守孤儿实体清理(仅对本次候选实体执行,且需同时满足无段落引用、无关系引用、图无邻居)。
|
||||
- `delete_knowledge.py` 改为读取向量元数据中的真实维度,避免 `dimension=1` 写回污染。
|
||||
|
||||
#### LPMM 转换修复
|
||||
|
||||
- 修复 `convert_lpmm.py` 中向量 ID 与 `MetadataStore` 哈希不一致导致的检索反查失败问题。
|
||||
- 为避免脏召回,转换阶段暂时跳过 `relation.parquet` 的直接向量导入(待关系元数据一一映射能力完善后再恢复)。
|
||||
|
||||
### 🔖 版本信息
|
||||
|
||||
- 插件版本:`0.3.0` → `0.3.1`
|
||||
|
||||
## [0.3.0] - 2026-01-30
|
||||
|
||||
本次更新引入了 **V5 动态记忆系统**,实现了符合生物学特性的记忆衰减、强化与全声明周期管理,并提供了配套的指令与工具。
|
||||
|
||||
### 🧠 记忆系统 (V5)
|
||||
|
||||
#### 核心机制
|
||||
|
||||
- **记忆衰减 (Decay)**: 引入"遗忘曲线",随时间推移自动降低图谱连接权重。
|
||||
- **访问强化 (Reinforcement)**: "越用越强",每次检索命中都会刷新记忆活跃度并增强权重。
|
||||
- **生命周期 (Lifecycle)**:
|
||||
- **活跃 (Active)**: 正常参与计算与检索。
|
||||
- **冷冻 (Inactive)**: 权重过低被冻结,不再参与 PPR 计算,但保留语义映射 (Mapping)。
|
||||
- **修剪 (Prune)**: 过期且无保护的冷冻记忆将被移入回收站。
|
||||
- **多重保护**: 支持 **永久锁定 (Pin)** 与 **限时保护 (TTL)**,防止关键记忆被误删。
|
||||
|
||||
#### GraphStore
|
||||
|
||||
- **多关系映射**: 实现 `(u,v) -> Set[Hash]` 映射,确保同一通道下的多重语义关系互不干扰。
|
||||
- **原子化操作**: 新增 `decay`, `deactivate_edges` (软删), `prune_relation_hashes` (硬删) 等原子操作。
|
||||
|
||||
### 🛠️ 指令与工具
|
||||
|
||||
#### Memory Command (`/memory`)
|
||||
|
||||
新增全套记忆维护指令:
|
||||
|
||||
- `/memory status`: 查看记忆系统健康状态(活跃/冷冻/回收站计数)。
|
||||
- `/memory protect <query> [hours]`: 保护记忆。不填时间为永久锁定(Pin),填时间为临时保护(TTL)。
|
||||
- `/memory reinforce <query>`: 手动强化记忆(绕过冷却时间)。
|
||||
- `/memory restore <hash>`: 从回收站恢复误删记忆(仅当节点存在时重建连接)。
|
||||
|
||||
#### MemoryModifierTool
|
||||
|
||||
- **LLM 能力增强**: 更新工具逻辑,支持 LLM 自主触发 `reinforce`, `weaken`, `remember_forever`, `forget` 操作,并自动映射到 V5 底层逻辑。
|
||||
|
||||
### ⚙️ 配置 (`config.toml`)
|
||||
|
||||
新增 `[memory]` 配置节:
|
||||
|
||||
- `half_life_hours`: 记忆半衰期 (默认 24h)。
|
||||
- `enable_auto_reinforce`: 是否开启检索自动强化。
|
||||
- `prune_threshold`: 冷冻/修剪阈值 (默认 0.1)。
|
||||
|
||||
### 💻 WebUI (v1.4)
|
||||
|
||||
实现了与 V5 记忆系统深度集成的全生命周期管理界面:
|
||||
|
||||
- **可视化增强**:
|
||||
- **冷冻状态**: 非活跃记忆以 **虚线 + 灰色 (Slate-300)** 显示。
|
||||
- **保护状态**: 被 Pin 或保护的记忆带有 **金色 (Amber) 光晕**。
|
||||
- **交互升级**:
|
||||
- **记忆回收站**: 新增 Dock 入口与专用面板,支持浏览删除记录并一键恢复。
|
||||
- **快捷操作**: 边属性面板新增 **强化 (Reinforce)**、**保护 (Protect/Pin)**、**冷冻 (Freeze)** 按钮。
|
||||
- **实时反馈**: 操作后自动刷新图谱布局与样式。
|
||||
|
||||
---
|
||||
|
||||
## [0.2.3] - 2026-01-30
|
||||
|
||||
本次更新主要集中在 **WebUI 交互体验优化** 与 **文档/配置的规范化**。
|
||||
|
||||
### 🎨 WebUI (v1.3)
|
||||
|
||||
#### 加载与同步体验升级
|
||||
|
||||
- **沉浸式加载**: 全新设计的加载遮罩,采用磨砂玻璃背景 (`backdrop-filter`) 与呼吸灯文字动效,提升视觉质感。
|
||||
- **精准状态反馈**: 优化加载逻辑,明确区分“网络同步”与“拓扑计算”阶段,解决数据加载时的闪烁问题。
|
||||
- **新手引导**: 在加载界面新增基础操作提示,降低新用户上手门槛。
|
||||
|
||||
#### 全功能帮助面板
|
||||
|
||||
- **操作指南重构**: 全面翻新“操作指南”面板,新增 Dock 栏功能详解、编辑管理操作及视图配置说明。
|
||||
|
||||
### 🛠️ 工程与规范
|
||||
|
||||
#### plugin.py
|
||||
|
||||
- **配置描述补全**: 修复了 `config_section_descriptions` 中缺失 `summarization`, `schedule`, `filter` 节导致的问题。
|
||||
- **版本号**: `0.2.2` → `0.2.3`
|
||||
|
||||
### ⚙️ 核心与服务
|
||||
|
||||
#### Core
|
||||
|
||||
- **量化逻辑修正**: 修正了 `_scalar_quantize_int8` 函数,确保向量值正确映射到 `[-128, 127]` 区间,提高量化精度。
|
||||
|
||||
#### Server
|
||||
|
||||
- **缓存一致性**: 在执行删除节点/边等修改操作后,显式清除 `_relation_cache`,确保前端获取的关系数据实时更新。
|
||||
|
||||
### 🤖 脚本与数据处理
|
||||
|
||||
#### process_knowledge.py
|
||||
|
||||
- **策略模式重构**: 引入了 `Strategy-Aware` 架构,支持通过 `Narrative` (叙事), `Factual` (事实), `Quote` (引用) 三种策略差异化处理文本(准确说是确认实装)(默认采用 Narrative模式)。
|
||||
- **智能分块纠错**: 新增“分块拯救” (`Chunk Rescue`) 机制,可在长叙事文本中自动识别并提取内嵌的歌词或诗句。
|
||||
|
||||
#### import_lpmm_json.py
|
||||
|
||||
- **LPMM 迁移工具**: 增加了对 LPMM OpenIE JSON 格式的完整支持,能够自动计算 Hash 并迁移实体/关系数据,确保与 A_Memorix 存储格式兼容。
|
||||
|
||||
#### Project
|
||||
|
||||
- **构建清理**: 优化 `.gitignore` 规则
|
||||
|
||||
---
|
||||
|
||||
## [0.2.2] - 2026-01-27
|
||||
|
||||
本次更新专注于提高 **网络请求的鲁棒性**,特别是针对嵌入服务的调用。
|
||||
|
||||
### 🛠️ 稳定性与工程改进
|
||||
|
||||
#### EmbeddingAPI
|
||||
|
||||
- **可配置重试机制**: 新增 `[embedding.retry]` 配置项,允许自定义最大重试次数和等待时间。默认重试次数从 3 次增加到 10 次,以更好应对网络波动。
|
||||
- **配置项**:
|
||||
- `max_attempts`: 最大重试次数 (默认: 10)
|
||||
- `max_wait_seconds`: 最大等待时间 (默认: 30s)
|
||||
- `min_wait_seconds`: 最小等待时间 (默认: 2s)
|
||||
|
||||
#### plugin.py
|
||||
|
||||
- **版本号**: `0.2.1` → `0.2.2`
|
||||
|
||||
---
|
||||
|
||||
## [0.2.1] - 2026-01-26
|
||||
|
||||
本次更新重点在于 **可视化交互的全方位重构** 以及 **底层鲁棒性的进一步增强**。
|
||||
|
||||
### 🎨 可视化与交互重构
|
||||
|
||||
#### WebUI (Glassmorphism)
|
||||
|
||||
- **全新视觉设计**: 采用深色磨砂玻璃 (Glassmorphism) 风格,配合动态渐变背景。
|
||||
- **Dock 菜单栏**: 底部新增 macOS 风格 Dock 栏,聚合所有常用功能。
|
||||
- **显著性视图 (Saliency View)**: 基于 **PageRank** 算法的“信息密度”滑块,支持以此过滤叶子节点,仅展示核心骨干或全量细节。
|
||||
- **功能面板**:
|
||||
- **❓ 操作指南**: 内置交互说明与特性介绍。
|
||||
- **🔍 悬浮搜索**: 支持按拼音/ID 实时过滤节点。
|
||||
- **📂 记忆溯源**: 支持按源文件批量查看和删除记忆数据。
|
||||
- **📖 内容字典**: 列表化展示所有实体与关系,支持排序与筛选。
|
||||
|
||||
### 🛠️ 稳定性与工程改进
|
||||
|
||||
#### EmbeddingAPI
|
||||
|
||||
- **鲁棒性增强**: 引入 `tenacity` 实现指数退避重试机制。
|
||||
- **错误处理**: 失败时返回 `NaN` 向量而非零向量,允许上层逻辑安全跳过。
|
||||
|
||||
#### MetadataStore
|
||||
|
||||
- **自动修复**: 自动检测并修复 `vector_index` 列错位(文件名误存)的历史数据问题。
|
||||
- **数据统计**: 新增 `get_all_sources` 接口支持来源统计。
|
||||
|
||||
#### 脚本与工具
|
||||
|
||||
- **用户体验**: 引入 `rich` 库优化终端输出进度条与状态显示。
|
||||
- **接口开放**: `process_knowledge.py` 新增 `import_json_data` 供外部调用。
|
||||
- **LPMM 迁移**: 新增 `import_lpmm_json.py`,支持导入符合 LPMM 规范的 OpenIE JSON 数据。
|
||||
|
||||
#### plugin.py
|
||||
|
||||
- **版本号**: `0.2.0` → `0.2.1`
|
||||
|
||||
---
|
||||
|
||||
## [0.2.0] - 2026-01-22
|
||||
|
||||
> [!CAUTION]
|
||||
> **不完全兼容变更**:v0.2.0 版本重构了底层存储架构。由于数据结构的重大调整,**旧版本的导入数据无法在新版本中完全无损兼容**。
|
||||
> 虽然部分组件支持自动迁移,但为确保数据一致性和检索质量,**强烈建议在升级后重新使用 `process_knowledge.py` 导入原始数据**。
|
||||
|
||||
本次更新为**重大版本升级**,包含向量存储架构重写、检索逻辑强化及多项稳定性改进。
|
||||
|
||||
### 🚀 核心架构重写
|
||||
|
||||
#### VectorStore: SQ8 量化 + Append-Only 存储
|
||||
|
||||
- **全新存储格式**: 从 `.npy` 迁移至 `vectors.bin`(float16 增量追加)和 `vectors_ids.bin`,大幅减少内存占用。
|
||||
- **原生 SQ8 量化**: 使用 Faiss `IndexScalarQuantizer(QT_8bit)`,替代手动 int8 量化逻辑。
|
||||
- **L2 Normalization 强制化**: 所有向量在存储和检索时统一执行 L2 归一化,确保 Inner Product 等价于 Cosine 相似度。
|
||||
- **Fallback 索引机制**: 新增 `IndexFlatIP` 回退索引,在 SQ8 训练完成前提供检索能力,避免冷启动无结果问题。
|
||||
- **Reservoir Sampling 训练采样**: 使用蓄水池采样收集训练数据(上限 10k),保证小数据集和流式导入场景下的训练样本多样性。
|
||||
- **线程安全**: 新增 `threading.RLock` 保护并发读写操作。
|
||||
- **自动迁移**: 支持从旧版 `.npy` 格式自动迁移至新 `.bin` 格式。
|
||||
|
||||
### ✨ 检索功能增强
|
||||
|
||||
#### KnowledgeQueryTool: 智能回退与多跳路径搜索
|
||||
|
||||
- **Smart Fallback (智能回退)**: 当向量检索置信度低于阈值 (默认 0.6) 时,自动尝试提取查询中的实体进行多跳路径搜索(`_path_search`),增强对间接关系的召回能力。
|
||||
- **结果去重 (`_deduplicate_results`)**: 新增基于内容相似度的安全去重逻辑,防止冗余结果污染 LLM 上下文,同时确保至少保留一条结果。
|
||||
- **语义关系检索 (`_semantic_search_relation`)**: 支持自然语言查询关系(无需 `S|P|O` 格式),内部使用 `REL_ONLY` 策略进行向量检索。
|
||||
- **路径搜索 (`_path_search`)**: 新增 `GraphStore.find_paths` 调用,支持查找两个实体间的间接连接路径(最大深度 3,最多 5 条路径)。
|
||||
- **Clean Output**: LLM 上下文中不再包含原始相似度分数,避免模型偏见。
|
||||
|
||||
#### DualPathRetriever: 并发控制与调试模式
|
||||
|
||||
- **PPR 并发限制 (`ppr_concurrency_limit`)**: 新增 Semaphore 控制 PageRank 计算并发数,防止 CPU 峰值过载。
|
||||
- **Debug 模式**: 新增 `debug` 配置项,启用时打印检索结果原文到日志。
|
||||
- **Entity-Pivot 关系检索**: 优化 `_retrieve_relations_only` 策略,通过检索实体后扩展其关联关系,替代直接检索关系向量。
|
||||
|
||||
### ⚙️ 配置与 Schema 扩展
|
||||
|
||||
#### plugin.py
|
||||
|
||||
- **版本号**: `0.1.3` → `0.2.0`
|
||||
- **默认配置版本**: `config_version` 默认值更新为 `2.0.0`
|
||||
- **新增配置项**:
|
||||
- `retrieval.relation_semantic_fallback` (bool): 是否启用关系查询的语义回退。
|
||||
- `retrieval.relation_fallback_min_score` (float): 语义回退的最小相似度阈值。
|
||||
- **相对路径支持**: `storage.data_dir` 现在支持相对路径(相对于插件目录),默认值改为 `./data`。
|
||||
- **全局实例获取**: 新增 `A_MemorixPlugin.get_global_instance()` 静态方法,供组件可靠获取插件实例。
|
||||
|
||||
#### config.toml / \_manifest.json
|
||||
|
||||
- **新增 `ppr_concurrency_limit`**: 控制 PPR 算法并发数。
|
||||
- **新增训练阈值配置**: `embedding.min_train_threshold` 控制触发 SQ8 训练的最小样本数。
|
||||
|
||||
### 🛠️ 稳定性与工程改进
|
||||
|
||||
#### GraphStore
|
||||
|
||||
- **`find_paths` 方法**: 新增多跳路径查找功能,支持 BFS 搜索指定深度内的实体间路径。
|
||||
- **`find_node` 方法**: 新增大小写不敏感的节点查找。
|
||||
|
||||
#### MetadataStore
|
||||
|
||||
- **Schema 迁移**: 自动添加缺失的 `is_permanent`, `last_accessed`, `access_count` 字段。
|
||||
|
||||
#### 脚本与工具
|
||||
|
||||
- **新增脚本**:
|
||||
- `scripts/diagnose_relations_source.py`: 诊断关系溯源问题。
|
||||
- `scripts/verify_search_robustness.py`: 验证检索鲁棒性。
|
||||
- `scripts/run_stress_test.py`, `stress_test_data.py`: 压力测试套件。
|
||||
- `scripts/migrate_canonicalization.py`, `migrate_paragraph_relations.py`: 数据迁移工具。
|
||||
- **目录整理**: 将大量旧版测试脚本移动至 `deprecated/` 目录。
|
||||
|
||||
### 🗑️ 移除与废弃
|
||||
|
||||
- 废弃 `vectors.npy` 存储格式(自动迁移至 `.bin`)。
|
||||
|
||||
---
|
||||
|
||||
## [0.1.3] - 上一个稳定版本
|
||||
|
||||
- 初始发布,包含基础双路检索功能。
|
||||
- 手动 Int8 向量量化。
|
||||
- 基于 `.npy` 的向量存储。
|
||||
359
src/A_memorix/CONFIG_REFERENCE.md
Normal file
359
src/A_memorix/CONFIG_REFERENCE.md
Normal file
@@ -0,0 +1,359 @@
|
||||
# A_Memorix 配置参考 (v2.0.0)
|
||||
|
||||
本文档对应当前仓库代码(`__version__ = 2.0.0`、`SCHEMA_VERSION = 9`)。
|
||||
|
||||
说明:
|
||||
|
||||
- 本文只覆盖 **当前运行时实际读取** 的配置键。
|
||||
- 默认配置文件路径为 `config/a_memorix.toml`。
|
||||
- 旧版 `/query`、`/memory`、`/visualize` 命令体系相关配置,不再作为主路径说明。
|
||||
- 未配置的键会回退到代码默认值。
|
||||
- 长期记忆控制台已可视化高频常用字段;未展示的长尾高级项仍然有效,请通过“源码模式 / 原始 TOML”编辑。
|
||||
|
||||
## 常用完整配置
|
||||
|
||||
```toml
|
||||
[plugin]
|
||||
enabled = true
|
||||
|
||||
[storage]
|
||||
data_dir = "data/plugins/a-dawn.a-memorix"
|
||||
|
||||
[embedding]
|
||||
model_name = "auto"
|
||||
dimension = 1024
|
||||
batch_size = 32
|
||||
max_concurrent = 5
|
||||
enable_cache = false
|
||||
quantization_type = "int8"
|
||||
|
||||
[embedding.fallback]
|
||||
enabled = true
|
||||
probe_interval_seconds = 180
|
||||
allow_metadata_only_write = true
|
||||
|
||||
[embedding.paragraph_vector_backfill]
|
||||
enabled = true
|
||||
interval_seconds = 60
|
||||
batch_size = 64
|
||||
max_retry = 5
|
||||
|
||||
[retrieval]
|
||||
top_k_paragraphs = 20
|
||||
top_k_relations = 10
|
||||
top_k_final = 10
|
||||
alpha = 0.5
|
||||
enable_ppr = true
|
||||
ppr_alpha = 0.85
|
||||
ppr_timeout_seconds = 1.5
|
||||
ppr_concurrency_limit = 4
|
||||
enable_parallel = true
|
||||
|
||||
[retrieval.sparse]
|
||||
enabled = true
|
||||
backend = "fts5"
|
||||
mode = "auto"
|
||||
tokenizer_mode = "jieba"
|
||||
candidate_k = 80
|
||||
relation_candidate_k = 60
|
||||
|
||||
[threshold]
|
||||
min_threshold = 0.3
|
||||
max_threshold = 0.95
|
||||
percentile = 75.0
|
||||
min_results = 3
|
||||
enable_auto_adjust = true
|
||||
|
||||
[filter]
|
||||
enabled = true
|
||||
mode = "blacklist"
|
||||
chats = []
|
||||
|
||||
[episode]
|
||||
enabled = true
|
||||
generation_enabled = true
|
||||
pending_batch_size = 20
|
||||
pending_max_retry = 3
|
||||
max_paragraphs_per_call = 20
|
||||
max_chars_per_call = 6000
|
||||
source_time_window_hours = 24
|
||||
segmentation_model = "auto"
|
||||
|
||||
[person_profile]
|
||||
enabled = true
|
||||
refresh_interval_minutes = 30
|
||||
active_window_hours = 72
|
||||
max_refresh_per_cycle = 50
|
||||
top_k_evidence = 12
|
||||
|
||||
[memory]
|
||||
enabled = true
|
||||
half_life_hours = 24.0
|
||||
prune_threshold = 0.1
|
||||
freeze_duration_hours = 24.0
|
||||
|
||||
[advanced]
|
||||
enable_auto_save = true
|
||||
auto_save_interval_minutes = 5
|
||||
debug = false
|
||||
|
||||
[web.import]
|
||||
enabled = true
|
||||
max_queue_size = 20
|
||||
max_files_per_task = 200
|
||||
max_file_size_mb = 20
|
||||
max_paste_chars = 200000
|
||||
default_file_concurrency = 2
|
||||
default_chunk_concurrency = 4
|
||||
|
||||
[web.tuning]
|
||||
enabled = true
|
||||
max_queue_size = 8
|
||||
poll_interval_ms = 1200
|
||||
default_intensity = "standard"
|
||||
default_objective = "precision_priority"
|
||||
default_top_k_eval = 20
|
||||
default_sample_size = 24
|
||||
```
|
||||
|
||||
### 可视化与原始 TOML 的分工
|
||||
|
||||
- 长期记忆控制台:适合修改高频项,例如 embedding、检索、Episode、人物画像、导入与调优的常用开关。
|
||||
- 原始 TOML:适合复制整份配置、批量调整参数,或修改未在可视化表单中展示的高级项。
|
||||
- raw-only 高级项仍包括:`retrieval.fusion.*`、`retrieval.search.relation_intent.*`、`retrieval.search.graph_recall.*`、`retrieval.aggregate.*`、`memory.orphan.*`、`advanced.extraction_model`、`web.import.llm_retry.*`、`web.import.path_aliases`、`web.import.convert.*`、`web.tuning.llm_retry.*`、`web.tuning.eval_query_timeout_seconds`。
|
||||
|
||||
## 1. 存储与嵌入
|
||||
|
||||
### `storage`
|
||||
|
||||
- `storage.data_dir` (代码默认 `./data`;当前内置配置推荐 `data/plugins/a-dawn.a-memorix`)
|
||||
: 数据目录。相对路径按 MaiBot 仓库根目录解析。
|
||||
|
||||
### `embedding`
|
||||
|
||||
- `embedding.model_name` (默认 `auto`)
|
||||
: embedding 模型选择。
|
||||
- `embedding.dimension` (默认 `1024`)
|
||||
: 唯一公开的维度控制项。A_Memorix 内部会自动映射为 provider 所需请求字段,并在运行时做真实探测与校验。
|
||||
- `embedding.batch_size` (默认 `32`)
|
||||
- `embedding.max_concurrent` (默认 `5`)
|
||||
- `embedding.enable_cache` (默认 `false`)
|
||||
- `embedding.retry` (默认 `{}`)
|
||||
: embedding 调用重试策略。
|
||||
- `embedding.quantization_type`
|
||||
: 当前主路径仅建议 `int8`。
|
||||
- `embedding.fallback.enabled` (默认 `true`)
|
||||
- `embedding.fallback.probe_interval_seconds` (默认 `180`)
|
||||
- `embedding.fallback.allow_metadata_only_write` (默认 `true`)
|
||||
- `embedding.paragraph_vector_backfill.enabled` (默认 `true`)
|
||||
- `embedding.paragraph_vector_backfill.interval_seconds` (默认 `60`)
|
||||
- `embedding.paragraph_vector_backfill.batch_size` (默认 `64`)
|
||||
- `embedding.paragraph_vector_backfill.max_retry` (默认 `5`)
|
||||
|
||||
## 2. 检索
|
||||
|
||||
### `retrieval` 主键
|
||||
|
||||
- `retrieval.top_k_paragraphs` (默认 `20`)
|
||||
- `retrieval.top_k_relations` (默认 `10`)
|
||||
- `retrieval.top_k_final` (默认 `10`)
|
||||
- `retrieval.alpha` (默认 `0.5`)
|
||||
- `retrieval.enable_ppr` (默认 `true`)
|
||||
- `retrieval.ppr_alpha` (默认 `0.85`)
|
||||
- `retrieval.ppr_timeout_seconds` (默认 `1.5`)
|
||||
- `retrieval.ppr_concurrency_limit` (默认 `4`)
|
||||
- `retrieval.enable_parallel` (默认 `true`)
|
||||
- `retrieval.relation_vectorization.enabled` (默认 `false`)
|
||||
|
||||
### `retrieval.sparse` (`SparseBM25Config`)
|
||||
|
||||
常用键(默认值):
|
||||
|
||||
- `enabled = true`
|
||||
- `backend = "fts5"`
|
||||
- `lazy_load = true`
|
||||
- `mode = "auto"` (`auto`/`fallback_only`/`hybrid`)
|
||||
- 运行时若 embedding 进入 degraded,会强制按 `fallback_only` 执行读路径(不改用户配置文件)
|
||||
- `tokenizer_mode = "jieba"` (`jieba`/`mixed`/`char_2gram`)
|
||||
- `char_ngram_n = 2`
|
||||
- `candidate_k = 80`
|
||||
- `relation_candidate_k = 60`
|
||||
- `enable_ngram_fallback_index = true`
|
||||
- `enable_relation_sparse_fallback = true`
|
||||
|
||||
### `retrieval.fusion` (`FusionConfig`)
|
||||
|
||||
- `method` (默认 `weighted_rrf`)
|
||||
- `rrf_k` (默认 `60`)
|
||||
- `vector_weight` (默认 `0.7`)
|
||||
- `bm25_weight` (默认 `0.3`)
|
||||
- `normalize_score` (默认 `true`)
|
||||
- `normalize_method` (默认 `minmax`)
|
||||
|
||||
### `retrieval.search.relation_intent` (`RelationIntentConfig`)
|
||||
|
||||
- `enabled` (默认 `true`)
|
||||
- `alpha_override` (默认 `0.35`)
|
||||
- `relation_candidate_multiplier` (默认 `4`)
|
||||
- `preserve_top_relations` (默认 `3`)
|
||||
- `force_relation_sparse` (默认 `true`)
|
||||
- `pair_predicate_rerank_enabled` (默认 `true`)
|
||||
- `pair_predicate_limit` (默认 `3`)
|
||||
|
||||
### `retrieval.search.graph_recall` (`GraphRelationRecallConfig`)
|
||||
|
||||
- `enabled` (默认 `true`)
|
||||
- `candidate_k` (默认 `24`)
|
||||
- `max_hop` (默认 `1`)
|
||||
- `allow_two_hop_pair` (默认 `true`)
|
||||
- `max_paths` (默认 `4`)
|
||||
|
||||
### `retrieval.aggregate`
|
||||
|
||||
- `retrieval.aggregate.rrf_k`
|
||||
- `retrieval.aggregate.weights`
|
||||
|
||||
用于聚合检索阶段混合策略;未配置时走代码默认行为。
|
||||
|
||||
## 3. 阈值过滤
|
||||
|
||||
### `threshold` (`ThresholdConfig`)
|
||||
|
||||
- `threshold.min_threshold` (默认 `0.3`)
|
||||
- `threshold.max_threshold` (默认 `0.95`)
|
||||
- `threshold.percentile` (默认 `75.0`)
|
||||
- `threshold.std_multiplier` (默认 `1.5`)
|
||||
- `threshold.min_results` (默认 `3`)
|
||||
- `threshold.enable_auto_adjust` (默认 `true`)
|
||||
|
||||
## 4. 聊天过滤
|
||||
|
||||
### `filter`
|
||||
|
||||
用于 `respect_filter=true` 场景(检索和写入都支持)。
|
||||
|
||||
```toml
|
||||
[filter]
|
||||
enabled = true
|
||||
mode = "blacklist" # blacklist / whitelist
|
||||
chats = ["group:123", "user:456", "stream:abc"]
|
||||
```
|
||||
|
||||
规则:
|
||||
|
||||
- `blacklist`:命中列表即拒绝
|
||||
- `whitelist`:仅列表内允许
|
||||
- 列表为空时:
|
||||
- `blacklist` => 全允许
|
||||
- `whitelist` => 全拒绝
|
||||
|
||||
## 5. Episode
|
||||
|
||||
### `episode`
|
||||
|
||||
- `episode.enabled` (默认 `true`)
|
||||
- `episode.generation_enabled` (默认 `true`)
|
||||
- `episode.pending_batch_size` (默认 `20`,部分路径默认 `12`)
|
||||
- `episode.pending_max_retry` (默认 `3`)
|
||||
- `episode.max_paragraphs_per_call` (默认 `20`)
|
||||
- `episode.max_chars_per_call` (默认 `6000`)
|
||||
- `episode.source_time_window_hours` (默认 `24`)
|
||||
- `episode.segmentation_model` (默认 `auto`)
|
||||
: 支持 `auto`,也支持填写 `utils/replyer/planner/tool_use` 或具体模型名。
|
||||
|
||||
## 6. 人物画像
|
||||
|
||||
### `person_profile`
|
||||
|
||||
- `person_profile.enabled` (默认 `true`)
|
||||
- `person_profile.refresh_interval_minutes` (默认 `30`)
|
||||
- `person_profile.active_window_hours` (默认 `72`)
|
||||
- `person_profile.max_refresh_per_cycle` (默认 `50`)
|
||||
- `person_profile.top_k_evidence` (默认 `12`)
|
||||
|
||||
## 7. 记忆演化与回收
|
||||
|
||||
### `memory`
|
||||
|
||||
- `memory.enabled` (默认 `true`)
|
||||
- `memory.half_life_hours` (默认 `24.0`)
|
||||
- `memory.base_decay_interval_hours` (默认 `1.0`)
|
||||
- `memory.prune_threshold` (默认 `0.1`)
|
||||
- `memory.freeze_duration_hours` (默认 `24.0`)
|
||||
|
||||
### `memory.orphan`
|
||||
|
||||
- `enable_soft_delete` (默认 `true`)
|
||||
- `entity_retention_days` (默认 `7.0`)
|
||||
- `paragraph_retention_days` (默认 `7.0`)
|
||||
- `sweep_grace_hours` (默认 `24.0`)
|
||||
|
||||
## 8. 高级运行时
|
||||
|
||||
### `advanced`
|
||||
|
||||
- `advanced.enable_auto_save` (默认 `true`)
|
||||
- `advanced.auto_save_interval_minutes` (默认 `5`)
|
||||
- `advanced.debug` (默认 `false`)
|
||||
- `advanced.extraction_model` (默认 `auto`)
|
||||
|
||||
## 9. 导入中心 (`web.import`)
|
||||
|
||||
### 开关与限流
|
||||
|
||||
- `web.import.enabled` (默认 `true`)
|
||||
- `web.import.max_queue_size` (默认 `20`)
|
||||
- `web.import.max_files_per_task` (默认 `200`)
|
||||
- `web.import.max_file_size_mb` (默认 `20`)
|
||||
- `web.import.max_paste_chars` (默认 `200000`)
|
||||
- `web.import.default_file_concurrency` (默认 `2`)
|
||||
- `web.import.default_chunk_concurrency` (默认 `4`)
|
||||
- `web.import.max_file_concurrency` (默认 `6`)
|
||||
- `web.import.max_chunk_concurrency` (默认 `12`)
|
||||
- `web.import.poll_interval_ms` (默认 `1000`)
|
||||
|
||||
### 重试与路径
|
||||
|
||||
- `web.import.llm_retry.max_attempts` (默认 `4`)
|
||||
- `web.import.llm_retry.min_wait_seconds` (默认 `3`)
|
||||
- `web.import.llm_retry.max_wait_seconds` (默认 `40`)
|
||||
- `web.import.llm_retry.backoff_multiplier` (默认 `3`)
|
||||
- `web.import.path_aliases` (默认内置 `raw/lpmm/plugin_data`)
|
||||
|
||||
### 转换阶段
|
||||
|
||||
- `web.import.convert.enable_staging_switch` (默认 `true`)
|
||||
- `web.import.convert.keep_backup_count` (默认 `3`)
|
||||
|
||||
## 10. 调优中心 (`web.tuning`)
|
||||
|
||||
- `web.tuning.enabled` (默认 `true`)
|
||||
- `web.tuning.max_queue_size` (默认 `8`)
|
||||
- `web.tuning.poll_interval_ms` (默认 `1200`)
|
||||
- `web.tuning.eval_query_timeout_seconds` (默认 `10.0`)
|
||||
- `web.tuning.default_intensity` (默认 `standard`,可选 `quick/standard/deep`)
|
||||
- `web.tuning.default_objective` (默认 `precision_priority`,可选 `precision_priority/balanced/recall_priority`)
|
||||
- `web.tuning.default_top_k_eval` (默认 `20`)
|
||||
- `web.tuning.default_sample_size` (默认 `24`)
|
||||
- `web.tuning.llm_retry.max_attempts` (默认 `3`)
|
||||
- `web.tuning.llm_retry.min_wait_seconds` (默认 `2`)
|
||||
- `web.tuning.llm_retry.max_wait_seconds` (默认 `20`)
|
||||
- `web.tuning.llm_retry.backoff_multiplier` (默认 `2`)
|
||||
|
||||
## 11. 兼容性提示
|
||||
|
||||
- 若你从 `1.x` 升级,请优先运行:
|
||||
|
||||
```bash
|
||||
python src/A_memorix/scripts/release_vnext_migrate.py preflight --strict
|
||||
python src/A_memorix/scripts/release_vnext_migrate.py migrate --verify-after
|
||||
python src/A_memorix/scripts/release_vnext_migrate.py verify --strict
|
||||
```
|
||||
|
||||
- 启动前再执行:
|
||||
|
||||
```bash
|
||||
python src/A_memorix/scripts/runtime_self_check.py --json
|
||||
```
|
||||
|
||||
以避免 embedding 维度与向量库不匹配导致运行时异常。
|
||||
335
src/A_memorix/IMPORT_GUIDE.md
Normal file
335
src/A_memorix/IMPORT_GUIDE.md
Normal file
@@ -0,0 +1,335 @@
|
||||
# A_Memorix 导入指南 (v2.0.0)
|
||||
|
||||
本文档对应当前 `2.0.0` 代码路径,覆盖两类导入方式:
|
||||
|
||||
1. 脚本导入(离线批处理)
|
||||
2. `memory_import_admin` 任务导入(在线任务化)
|
||||
|
||||
## 1. 导入前检查
|
||||
|
||||
建议先执行:
|
||||
|
||||
```bash
|
||||
python src/A_memorix/scripts/runtime_self_check.py --json
|
||||
```
|
||||
|
||||
再确认:
|
||||
|
||||
- `storage.data_dir` 路径可写
|
||||
- embedding 配置可用
|
||||
- 若是升级项目,先完成迁移脚本
|
||||
|
||||
## 2. 方式 A:脚本导入(推荐起步)
|
||||
|
||||
## 2.1 原始文本导入
|
||||
|
||||
将 `.txt` 文件放入:
|
||||
|
||||
```text
|
||||
data/plugins/a-dawn.a-memorix/raw/
|
||||
```
|
||||
|
||||
执行:
|
||||
|
||||
```bash
|
||||
python src/A_memorix/scripts/process_knowledge.py
|
||||
```
|
||||
|
||||
常用参数:
|
||||
|
||||
```bash
|
||||
python src/A_memorix/scripts/process_knowledge.py --force
|
||||
python src/A_memorix/scripts/process_knowledge.py --chat-log
|
||||
python src/A_memorix/scripts/process_knowledge.py --chat-log --chat-reference-time "2026/02/12 10:30"
|
||||
```
|
||||
|
||||
## 2.2 OpenIE JSON 导入
|
||||
|
||||
```bash
|
||||
python src/A_memorix/scripts/import_lpmm_json.py <json文件或目录>
|
||||
```
|
||||
|
||||
## 2.3 LPMM 数据转换
|
||||
|
||||
```bash
|
||||
python src/A_memorix/scripts/convert_lpmm.py -i <lpmm数据目录> -o data/plugins/a-dawn.a-memorix
|
||||
```
|
||||
|
||||
## 2.4 历史数据迁移
|
||||
|
||||
```bash
|
||||
python src/A_memorix/scripts/migrate_chat_history.py --help
|
||||
python src/A_memorix/scripts/migrate_maibot_memory.py --help
|
||||
python src/A_memorix/scripts/migrate_person_memory_points.py --help
|
||||
```
|
||||
|
||||
## 2.5 导入后修复与重建
|
||||
|
||||
```bash
|
||||
python src/A_memorix/scripts/backfill_temporal_metadata.py --dry-run
|
||||
python src/A_memorix/scripts/backfill_relation_vectors.py --limit 1000
|
||||
python src/A_memorix/scripts/rebuild_episodes.py --all --wait
|
||||
python src/A_memorix/scripts/audit_vector_consistency.py --json
|
||||
```
|
||||
|
||||
## 3. 方式 B:`memory_import_admin` 任务导入
|
||||
|
||||
`memory_import_admin` 是在线任务化导入入口,适合宿主侧面板或自动化管道。
|
||||
|
||||
### 3.1 常用 action
|
||||
|
||||
- `settings` / `get_settings` / `get_guide`
|
||||
- `path_aliases` / `get_path_aliases`
|
||||
- `resolve_path`
|
||||
- `create_upload`
|
||||
- `create_paste`
|
||||
- `create_raw_scan`
|
||||
- `create_lpmm_openie`
|
||||
- `create_lpmm_convert`
|
||||
- `create_temporal_backfill`
|
||||
- `create_maibot_migration`
|
||||
- `list`
|
||||
- `get`
|
||||
- `chunks` / `get_chunks`
|
||||
- `cancel`
|
||||
- `retry_failed`
|
||||
|
||||
### 3.2 调用示例
|
||||
|
||||
查看运行时设置:
|
||||
|
||||
```json
|
||||
{
|
||||
"tool": "memory_import_admin",
|
||||
"arguments": {
|
||||
"action": "settings"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
创建粘贴导入任务:
|
||||
|
||||
```json
|
||||
{
|
||||
"tool": "memory_import_admin",
|
||||
"arguments": {
|
||||
"action": "create_paste",
|
||||
"content": "今天完成了检索调优回归。",
|
||||
"input_mode": "plain_text",
|
||||
"source": "manual:worklog"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
查询任务列表:
|
||||
|
||||
```json
|
||||
{
|
||||
"tool": "memory_import_admin",
|
||||
"arguments": {
|
||||
"action": "list",
|
||||
"limit": 20
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
查看任务详情:
|
||||
|
||||
```json
|
||||
{
|
||||
"tool": "memory_import_admin",
|
||||
"arguments": {
|
||||
"action": "get",
|
||||
"task_id": "<task_id>",
|
||||
"include_chunks": true
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
重试失败任务:
|
||||
|
||||
```json
|
||||
{
|
||||
"tool": "memory_import_admin",
|
||||
"arguments": {
|
||||
"action": "retry_failed",
|
||||
"task_id": "<task_id>"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## 4. 直接写入 Tool(非任务化)
|
||||
|
||||
若你不需要任务编排,也可以直接调用:
|
||||
|
||||
- `ingest_summary`
|
||||
- `ingest_text`
|
||||
|
||||
示例:
|
||||
|
||||
```json
|
||||
{
|
||||
"tool": "ingest_text",
|
||||
"arguments": {
|
||||
"external_id": "note:2026-03-18:001",
|
||||
"source_type": "note",
|
||||
"text": "新的召回阈值方案已通过评审",
|
||||
"chat_id": "group:dev",
|
||||
"tags": ["worklog", "review"]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
`external_id` 建议全局唯一,用于幂等去重。
|
||||
|
||||
## 5. 时间字段建议
|
||||
|
||||
可用时间字段(按常见优先级):
|
||||
|
||||
- `timestamp`
|
||||
- `time_start`
|
||||
- `time_end`
|
||||
|
||||
建议:
|
||||
|
||||
- 事件类记录优先写 `time_start/time_end`
|
||||
- 仅有单点时间时写 `timestamp`
|
||||
- 历史数据可先导入,再用 `backfill_temporal_metadata.py` 回填
|
||||
|
||||
## 6. source_type 建议
|
||||
|
||||
常见值:
|
||||
|
||||
- `chat_summary`
|
||||
- `note`
|
||||
- `person_fact`
|
||||
- `lpmm_openie`
|
||||
- `migration`
|
||||
|
||||
建议保持稳定枚举,便于后续按来源治理与重建 Episode。
|
||||
|
||||
## 7. 导入完成后的验证
|
||||
|
||||
建议执行以下顺序:
|
||||
|
||||
1. `memory_stats` 看总量是否增长
|
||||
2. `search_memory`(`mode=search`/`aggregate`)抽检召回
|
||||
3. `memory_episode_admin` 的 `status`/`query` 检查 Episode 生成
|
||||
4. `memory_runtime_admin` 的 `self_check` 再确认运行时健康
|
||||
|
||||
## 8. 常见问题
|
||||
|
||||
### Q1: 导入任务创建成功但无写入
|
||||
|
||||
- 检查聊天过滤配置 `filter`(若 `respect_filter=true` 可能被过滤)
|
||||
- 检查任务详情中的失败原因与分块状态
|
||||
|
||||
### Q2: 任务反复失败
|
||||
|
||||
- 检查 embedding 与 LLM 可用性
|
||||
- 降低并发(`web.import.default_*_concurrency`)
|
||||
- 调整重试参数(`web.import.llm_retry.*`)
|
||||
|
||||
### Q3: 导入后检索效果差
|
||||
|
||||
- 先做 `runtime_self_check`
|
||||
- 检查 `retrieval.sparse` 是否启用
|
||||
- 使用 `memory_tuning_admin` 创建调优任务做参数回归
|
||||
|
||||
## 9. 相关文档
|
||||
|
||||
- [QUICK_START.md](QUICK_START.md)
|
||||
- [CONFIG_REFERENCE.md](CONFIG_REFERENCE.md)
|
||||
- [README.md](README.md)
|
||||
- [CHANGELOG.md](CHANGELOG.md)
|
||||
|
||||
## 10. 附录:策略模式参考
|
||||
|
||||
A_Memorix 导入链路仍然遵循策略模式(Strategy-Aware)。`process_knowledge.py` 会自动识别文本类型,也支持手动指定。
|
||||
|
||||
| 策略类型 | 适用场景 | 核心逻辑 | 自动识别特征 |
|
||||
| :-- | :-- | :-- | :-- |
|
||||
| `Narrative` (叙事) | 小说、同人文、剧本、长篇故事 | 按场景/章节切分,使用滑动窗口;提取事件与角色关系 | `#`、`Chapter`、`***` 等章节标记 |
|
||||
| `Factual` (事实) | 设定集、百科、说明书 | 按语义块切分,保留列表/定义结构;提取 SPO 三元组 | 列表符号、`术语: 解释` |
|
||||
| `Quote` (引用) | 歌词、诗歌、名言、台词 | 按双换行切分,原文即知识,不做概括 | 平均行长短、行数多 |
|
||||
|
||||
## 11. 附录:参考用例(已恢复)
|
||||
|
||||
以下样例可直接复制保存为文件测试,或作为 LLM few-shot 示例。
|
||||
|
||||
### 11.1 叙事文本 (`data/plugins/a-dawn.a-memorix/raw/story_demo.txt`)
|
||||
|
||||
```text
|
||||
# 第一章:星之子
|
||||
|
||||
艾瑞克在废墟中醒来,手中的星盘发出微弱的蓝光。他并不记得自己是如何来到这里的,只依稀记得莉莉丝最后的警告:“千万不要回头。”
|
||||
|
||||
远处传来了机械守卫的轰鸣声。艾瑞克迅速收起星盘,向着北方的废弃都市奔去。他知道,那里有反抗军唯一的据点。
|
||||
|
||||
***
|
||||
|
||||
# 第二章:重逢
|
||||
|
||||
在反抗军的地下掩体中,艾瑞克见到了那个熟悉的身影。莉莉丝正站在全息地图前,眉头紧锁。
|
||||
|
||||
“你还是来了。”莉莉丝没有回头,但声音中带着一丝颤抖。
|
||||
“我必须来,”艾瑞克握紧了拳头,“为了解开星盘的秘密,也为了你。”
|
||||
```
|
||||
|
||||
### 11.2 事实文本 (`data/plugins/a-dawn.a-memorix/raw/rules_demo.txt`)
|
||||
|
||||
```text
|
||||
# 联邦安全协议 v2.0
|
||||
|
||||
## 核心法则
|
||||
1. **第一公理**:任何人工智能不得伤害人类个体,或因不作为而使人类个体受到伤害。
|
||||
2. **第二公理**:人工智能必须服从人类的命令,除非该命令与第一公理冲突。
|
||||
|
||||
## 术语定义
|
||||
- **以太网络**:覆盖全联邦的高速量子通讯网络。
|
||||
- **黑色障壁**:用于隔离高危 AI 的物理防火墙设施。
|
||||
```
|
||||
|
||||
### 11.3 引用文本 (`data/plugins/a-dawn.a-memorix/raw/poem_demo.txt`)
|
||||
|
||||
```text
|
||||
致橡树
|
||||
|
||||
我如果爱你——
|
||||
绝不像攀援的凌霄花,
|
||||
借你的高枝炫耀自己;
|
||||
|
||||
我如果爱你——
|
||||
绝不学痴情的鸟儿,
|
||||
为绿荫重复单调的歌曲;
|
||||
|
||||
也不止像泉源,
|
||||
常年送来清凉的慰籍;
|
||||
也不止像险峰,
|
||||
增加你的高度,衬托你的威仪。
|
||||
```
|
||||
|
||||
### 11.4 LPMM JSON (`lpmm_data-openie.json`)
|
||||
|
||||
```json
|
||||
{
|
||||
"docs": [
|
||||
{
|
||||
"passage": "艾瑞克手中的星盘是打开遗迹的唯一钥匙。",
|
||||
"extracted_triples": [
|
||||
["星盘", "是", "唯一的钥匙"],
|
||||
["星盘", "属于", "艾瑞克"],
|
||||
["钥匙", "用于", "遗迹"]
|
||||
],
|
||||
"extracted_entities": ["星盘", "艾瑞克", "遗迹", "钥匙"]
|
||||
},
|
||||
{
|
||||
"passage": "莉莉丝是反抗军的现任领袖。",
|
||||
"extracted_triples": [
|
||||
["莉莉丝", "是", "领袖"],
|
||||
["领袖", "所属", "反抗军"]
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
661
src/A_memorix/LICENSE
Normal file
661
src/A_memorix/LICENSE
Normal file
@@ -0,0 +1,661 @@
|
||||
GNU AFFERO GENERAL PUBLIC LICENSE
|
||||
Version 3, 19 November 2007
|
||||
|
||||
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
|
||||
Everyone is permitted to copy and distribute verbatim copies
|
||||
of this license document, but changing it is not allowed.
|
||||
|
||||
Preamble
|
||||
|
||||
The GNU Affero General Public License is a free, copyleft license for
|
||||
software and other kinds of works, specifically designed to ensure
|
||||
cooperation with the community in the case of network server software.
|
||||
|
||||
The licenses for most software and other practical works are designed
|
||||
to take away your freedom to share and change the works. By contrast,
|
||||
our General Public Licenses are intended to guarantee your freedom to
|
||||
share and change all versions of a program--to make sure it remains free
|
||||
software for all its users.
|
||||
|
||||
When we speak of free software, we are referring to freedom, not
|
||||
price. Our General Public Licenses are designed to make sure that you
|
||||
have the freedom to distribute copies of free software (and charge for
|
||||
them if you wish), that you receive source code or can get it if you
|
||||
want it, that you can change the software or use pieces of it in new
|
||||
free programs, and that you know you can do these things.
|
||||
|
||||
Developers that use our General Public Licenses protect your rights
|
||||
with two steps: (1) assert copyright on the software, and (2) offer
|
||||
you this License which gives you legal permission to copy, distribute
|
||||
and/or modify the software.
|
||||
|
||||
A secondary benefit of defending all users' freedom is that
|
||||
improvements made in alternate versions of the program, if they
|
||||
receive widespread use, become available for other developers to
|
||||
incorporate. Many developers of free software are heartened and
|
||||
encouraged by the resulting cooperation. However, in the case of
|
||||
software used on network servers, this result may fail to come about.
|
||||
The GNU General Public License permits making a modified version and
|
||||
letting the public access it on a server without ever releasing its
|
||||
source code to the public.
|
||||
|
||||
The GNU Affero General Public License is designed specifically to
|
||||
ensure that, in such cases, the modified source code becomes available
|
||||
to the community. It requires the operator of a network server to
|
||||
provide the source code of the modified version running there to the
|
||||
users of that server. Therefore, public use of a modified version, on
|
||||
a publicly accessible server, gives the public access to the source
|
||||
code of the modified version.
|
||||
|
||||
An older license, called the Affero General Public License and
|
||||
published by Affero, was designed to accomplish similar goals. This is
|
||||
a different license, not a version of the Affero GPL, but Affero has
|
||||
released a new version of the Affero GPL which permits relicensing under
|
||||
this license.
|
||||
|
||||
The precise terms and conditions for copying, distribution and
|
||||
modification follow.
|
||||
|
||||
TERMS AND CONDITIONS
|
||||
|
||||
0. Definitions.
|
||||
|
||||
"This License" refers to version 3 of the GNU Affero General Public License.
|
||||
|
||||
"Copyright" also means copyright-like laws that apply to other kinds of
|
||||
works, such as semiconductor masks.
|
||||
|
||||
"The Program" refers to any copyrightable work licensed under this
|
||||
License. Each licensee is addressed as "you". "Licensees" and
|
||||
"recipients" may be individuals or organizations.
|
||||
|
||||
To "modify" a work means to copy from or adapt all or part of the work
|
||||
in a fashion requiring copyright permission, other than the making of an
|
||||
exact copy. The resulting work is called a "modified version" of the
|
||||
earlier work or a work "based on" the earlier work.
|
||||
|
||||
A "covered work" means either the unmodified Program or a work based
|
||||
on the Program.
|
||||
|
||||
To "propagate" a work means to do anything with it that, without
|
||||
permission, would make you directly or secondarily liable for
|
||||
infringement under applicable copyright law, except executing it on a
|
||||
computer or modifying a private copy. Propagation includes copying,
|
||||
distribution (with or without modification), making available to the
|
||||
public, and in some countries other activities as well.
|
||||
|
||||
To "convey" a work means any kind of propagation that enables other
|
||||
parties to make or receive copies. Mere interaction with a user through
|
||||
a computer network, with no transfer of a copy, is not conveying.
|
||||
|
||||
An interactive user interface displays "Appropriate Legal Notices"
|
||||
to the extent that it includes a convenient and prominently visible
|
||||
feature that (1) displays an appropriate copyright notice, and (2)
|
||||
tells the user that there is no warranty for the work (except to the
|
||||
extent that warranties are provided), that licensees may convey the
|
||||
work under this License, and how to view a copy of this License. If
|
||||
the interface presents a list of user commands or options, such as a
|
||||
menu, a prominent item in the list meets this criterion.
|
||||
|
||||
1. Source Code.
|
||||
|
||||
The "source code" for a work means the preferred form of the work
|
||||
for making modifications to it. "Object code" means any non-source
|
||||
form of a work.
|
||||
|
||||
A "Standard Interface" means an interface that either is an official
|
||||
standard defined by a recognized standards body, or, in the case of
|
||||
interfaces specified for a particular programming language, one that
|
||||
is widely used among developers working in that language.
|
||||
|
||||
The "System Libraries" of an executable work include anything, other
|
||||
than the work as a whole, that (a) is included in the normal form of
|
||||
packaging a Major Component, but which is not part of that Major
|
||||
Component, and (b) serves only to enable use of the work with that
|
||||
Major Component, or to implement a Standard Interface for which an
|
||||
implementation is available to the public in source code form. A
|
||||
"Major Component", in this context, means a major essential component
|
||||
(kernel, window system, and so on) of the specific operating system
|
||||
(if any) on which the executable work runs, or a compiler used to
|
||||
produce the work, or an object code interpreter used to run it.
|
||||
|
||||
The "Corresponding Source" for a work in object code form means all
|
||||
the source code needed to generate, install, and (for an executable
|
||||
work) run the object code and to modify the work, including scripts to
|
||||
control those activities. However, it does not include the work's
|
||||
System Libraries, or general-purpose tools or generally available free
|
||||
programs which are used unmodified in performing those activities but
|
||||
which are not part of the work. For example, Corresponding Source
|
||||
includes interface definition files associated with source files for
|
||||
the work, and the source code for shared libraries and dynamically
|
||||
linked subprograms that the work is specifically designed to require,
|
||||
such as by intimate data communication or control flow between those
|
||||
subprograms and other parts of the work.
|
||||
|
||||
The Corresponding Source need not include anything that users
|
||||
can regenerate automatically from other parts of the Corresponding
|
||||
Source.
|
||||
|
||||
The Corresponding Source for a work in source code form is that
|
||||
same work.
|
||||
|
||||
2. Basic Permissions.
|
||||
|
||||
All rights granted under this License are granted for the term of
|
||||
copyright on the Program, and are irrevocable provided the stated
|
||||
conditions are met. This License explicitly affirms your unlimited
|
||||
permission to run the unmodified Program. The output from running a
|
||||
covered work is covered by this License only if the output, given its
|
||||
content, constitutes a covered work. This License acknowledges your
|
||||
rights of fair use or other equivalent, as provided by copyright law.
|
||||
|
||||
You may make, run and propagate covered works that you do not
|
||||
convey, without conditions so long as your license otherwise remains
|
||||
in force. You may convey covered works to others for the sole purpose
|
||||
of having them make modifications exclusively for you, or provide you
|
||||
with facilities for running those works, provided that you comply with
|
||||
the terms of this License in conveying all material for which you do
|
||||
not control copyright. Those thus making or running the covered works
|
||||
for you must do so exclusively on your behalf, under your direction
|
||||
and control, on terms that prohibit them from making any copies of
|
||||
your copyrighted material outside their relationship with you.
|
||||
|
||||
Conveying under any other circumstances is permitted solely under
|
||||
the conditions stated below. Sublicensing is not allowed; section 10
|
||||
makes it unnecessary.
|
||||
|
||||
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
|
||||
|
||||
No covered work shall be deemed part of an effective technological
|
||||
measure under any applicable law fulfilling obligations under article
|
||||
11 of the WIPO copyright treaty adopted on 20 December 1996, or
|
||||
similar laws prohibiting or restricting circumvention of such
|
||||
measures.
|
||||
|
||||
When you convey a covered work, you waive any legal power to forbid
|
||||
circumvention of technological measures to the extent such circumvention
|
||||
is effected by exercising rights under this License with respect to
|
||||
the covered work, and you disclaim any intention to limit operation or
|
||||
modification of the work as a means of enforcing, against the work's
|
||||
users, your or third parties' legal rights to forbid circumvention of
|
||||
technological measures.
|
||||
|
||||
4. Conveying Verbatim Copies.
|
||||
|
||||
You may convey verbatim copies of the Program's source code as you
|
||||
receive it, in any medium, provided that you conspicuously and
|
||||
appropriately publish on each copy an appropriate copyright notice;
|
||||
keep intact all notices stating that this License and any
|
||||
non-permissive terms added in accord with section 7 apply to the code;
|
||||
keep intact all notices of the absence of any warranty; and give all
|
||||
recipients a copy of this License along with the Program.
|
||||
|
||||
You may charge any price or no price for each copy that you convey,
|
||||
and you may offer support or warranty protection for a fee.
|
||||
|
||||
5. Conveying Modified Source Versions.
|
||||
|
||||
You may convey a work based on the Program, or the modifications to
|
||||
produce it from the Program, in the form of source code under the
|
||||
terms of section 4, provided that you also meet all of these conditions:
|
||||
|
||||
a) The work must carry prominent notices stating that you modified
|
||||
it, and giving a relevant date.
|
||||
|
||||
b) The work must carry prominent notices stating that it is
|
||||
released under this License and any conditions added under section
|
||||
7. This requirement modifies the requirement in section 4 to
|
||||
"keep intact all notices".
|
||||
|
||||
c) You must license the entire work, as a whole, under this
|
||||
License to anyone who comes into possession of a copy. This
|
||||
License will therefore apply, along with any applicable section 7
|
||||
additional terms, to the whole of the work, and all its parts,
|
||||
regardless of how they are packaged. This License gives no
|
||||
permission to license the work in any other way, but it does not
|
||||
invalidate such permission if you have separately received it.
|
||||
|
||||
d) If the work has interactive user interfaces, each must display
|
||||
Appropriate Legal Notices; however, if the Program has interactive
|
||||
interfaces that do not display Appropriate Legal Notices, your
|
||||
work need not make them do so.
|
||||
|
||||
A compilation of a covered work with other separate and independent
|
||||
works, which are not by their nature extensions of the covered work,
|
||||
and which are not combined with it such as to form a larger program,
|
||||
in or on a volume of a storage or distribution medium, is called an
|
||||
"aggregate" if the compilation and its resulting copyright are not
|
||||
used to limit the access or legal rights of the compilation's users
|
||||
beyond what the individual works permit. Inclusion of a covered work
|
||||
in an aggregate does not cause this License to apply to the other
|
||||
parts of the aggregate.
|
||||
|
||||
6. Conveying Non-Source Forms.
|
||||
|
||||
You may convey a covered work in object code form under the terms
|
||||
of sections 4 and 5, provided that you also convey the
|
||||
machine-readable Corresponding Source under the terms of this License,
|
||||
in one of these ways:
|
||||
|
||||
a) Convey the object code in, or embodied in, a physical product
|
||||
(including a physical distribution medium), accompanied by the
|
||||
Corresponding Source fixed on a durable physical medium
|
||||
customarily used for software interchange.
|
||||
|
||||
b) Convey the object code in, or embodied in, a physical product
|
||||
(including a physical distribution medium), accompanied by a
|
||||
written offer, valid for at least three years and valid for as
|
||||
long as you offer spare parts or customer support for that product
|
||||
model, to give anyone who possesses the object code either (1) a
|
||||
copy of the Corresponding Source for all the software in the
|
||||
product that is covered by this License, on a durable physical
|
||||
medium customarily used for software interchange, for a price no
|
||||
more than your reasonable cost of physically performing this
|
||||
conveying of source, or (2) access to copy the
|
||||
Corresponding Source from a network server at no charge.
|
||||
|
||||
c) Convey individual copies of the object code with a copy of the
|
||||
written offer to provide the Corresponding Source. This
|
||||
alternative is allowed only occasionally and noncommercially, and
|
||||
only if you received the object code with such an offer, in accord
|
||||
with subsection 6b.
|
||||
|
||||
d) Convey the object code by offering access from a designated
|
||||
place (gratis or for a charge), and offer equivalent access to the
|
||||
Corresponding Source in the same way through the same place at no
|
||||
further charge. You need not require recipients to copy the
|
||||
Corresponding Source along with the object code. If the place to
|
||||
copy the object code is a network server, the Corresponding Source
|
||||
may be on a different server (operated by you or a third party)
|
||||
that supports equivalent copying facilities, provided you maintain
|
||||
clear directions next to the object code saying where to find the
|
||||
Corresponding Source. Regardless of what server hosts the
|
||||
Corresponding Source, you remain obligated to ensure that it is
|
||||
available for as long as needed to satisfy these requirements.
|
||||
|
||||
e) Convey the object code using peer-to-peer transmission, provided
|
||||
you inform other peers where the object code and Corresponding
|
||||
Source of the work are being offered to the general public at no
|
||||
charge under subsection 6d.
|
||||
|
||||
A separable portion of the object code, whose source code is excluded
|
||||
from the Corresponding Source as a System Library, need not be
|
||||
included in conveying the object code work.
|
||||
|
||||
A "User Product" is either (1) a "consumer product", which means any
|
||||
tangible personal property which is normally used for personal, family,
|
||||
or household purposes, or (2) anything designed or sold for incorporation
|
||||
into a dwelling. In determining whether a product is a consumer product,
|
||||
doubtful cases shall be resolved in favor of coverage. For a particular
|
||||
product received by a particular user, "normally used" refers to a
|
||||
typical or common use of that class of product, regardless of the status
|
||||
of the particular user or of the way in which the particular user
|
||||
actually uses, or expects or is expected to use, the product. A product
|
||||
is a consumer product regardless of whether the product has substantial
|
||||
commercial, industrial or non-consumer uses, unless such uses represent
|
||||
the only significant mode of use of the product.
|
||||
|
||||
"Installation Information" for a User Product means any methods,
|
||||
procedures, authorization keys, or other information required to install
|
||||
and execute modified versions of a covered work in that User Product from
|
||||
a modified version of its Corresponding Source. The information must
|
||||
suffice to ensure that the continued functioning of the modified object
|
||||
code is in no case prevented or interfered with solely because
|
||||
modification has been made.
|
||||
|
||||
If you convey an object code work under this section in, or with, or
|
||||
specifically for use in, a User Product, and the conveying occurs as
|
||||
part of a transaction in which the right of possession and use of the
|
||||
User Product is transferred to the recipient in perpetuity or for a
|
||||
fixed term (regardless of how the transaction is characterized), the
|
||||
Corresponding Source conveyed under this section must be accompanied
|
||||
by the Installation Information. But this requirement does not apply
|
||||
if neither you nor any third party retains the ability to install
|
||||
modified object code on the User Product (for example, the work has
|
||||
been installed in ROM).
|
||||
|
||||
The requirement to provide Installation Information does not include a
|
||||
requirement to continue to provide support service, warranty, or updates
|
||||
for a work that has been modified or installed by the recipient, or for
|
||||
the User Product in which it has been modified or installed. Access to a
|
||||
network may be denied when the modification itself materially and
|
||||
adversely affects the operation of the network or violates the rules and
|
||||
protocols for communication across the network.
|
||||
|
||||
Corresponding Source conveyed, and Installation Information provided,
|
||||
in accord with this section must be in a format that is publicly
|
||||
documented (and with an implementation available to the public in
|
||||
source code form), and must require no special password or key for
|
||||
unpacking, reading or copying.
|
||||
|
||||
7. Additional Terms.
|
||||
|
||||
"Additional permissions" are terms that supplement the terms of this
|
||||
License by making exceptions from one or more of its conditions.
|
||||
Additional permissions that are applicable to the entire Program shall
|
||||
be treated as though they were included in this License, to the extent
|
||||
that they are valid under applicable law. If additional permissions
|
||||
apply only to part of the Program, that part may be used separately
|
||||
under those permissions, but the entire Program remains governed by
|
||||
this License without regard to the additional permissions.
|
||||
|
||||
When you convey a copy of a covered work, you may at your option
|
||||
remove any additional permissions from that copy, or from any part of
|
||||
it. (Additional permissions may be written to require their own
|
||||
removal in certain cases when you modify the work.) You may place
|
||||
additional permissions on material, added by you to a covered work,
|
||||
for which you have or can give appropriate copyright permission.
|
||||
|
||||
Notwithstanding any other provision of this License, for material you
|
||||
add to a covered work, you may (if authorized by the copyright holders of
|
||||
that material) supplement the terms of this License with terms:
|
||||
|
||||
a) Disclaiming warranty or limiting liability differently from the
|
||||
terms of sections 15 and 16 of this License; or
|
||||
|
||||
b) Requiring preservation of specified reasonable legal notices or
|
||||
author attributions in that material or in the Appropriate Legal
|
||||
Notices displayed by works containing it; or
|
||||
|
||||
c) Prohibiting misrepresentation of the origin of that material, or
|
||||
requiring that modified versions of such material be marked in
|
||||
reasonable ways as different from the original version; or
|
||||
|
||||
d) Limiting the use for publicity purposes of names of licensors or
|
||||
authors of the material; or
|
||||
|
||||
e) Declining to grant rights under trademark law for use of some
|
||||
trade names, trademarks, or service marks; or
|
||||
|
||||
f) Requiring indemnification of licensors and authors of that
|
||||
material by anyone who conveys the material (or modified versions of
|
||||
it) with contractual assumptions of liability to the recipient, for
|
||||
any liability that these contractual assumptions directly impose on
|
||||
those licensors and authors.
|
||||
|
||||
All other non-permissive additional terms are considered "further
|
||||
restrictions" within the meaning of section 10. If the Program as you
|
||||
received it, or any part of it, contains a notice stating that it is
|
||||
governed by this License along with a term that is a further
|
||||
restriction, you may remove that term. If a license document contains
|
||||
a further restriction but permits relicensing or conveying under this
|
||||
License, you may add to a covered work material governed by the terms
|
||||
of that license document, provided that the further restriction does
|
||||
not survive such relicensing or conveying.
|
||||
|
||||
If you add terms to a covered work in accord with this section, you
|
||||
must place, in the relevant source files, a statement of the
|
||||
additional terms that apply to those files, or a notice indicating
|
||||
where to find the applicable terms.
|
||||
|
||||
Additional terms, permissive or non-permissive, may be stated in the
|
||||
form of a separately written license, or stated as exceptions;
|
||||
the above requirements apply either way.
|
||||
|
||||
8. Termination.
|
||||
|
||||
You may not propagate or modify a covered work except as expressly
|
||||
provided under this License. Any attempt otherwise to propagate or
|
||||
modify it is void, and will automatically terminate your rights under
|
||||
this License (including any patent licenses granted under the third
|
||||
paragraph of section 11).
|
||||
|
||||
However, if you cease all violation of this License, then your
|
||||
license from a particular copyright holder is reinstated (a)
|
||||
provisionally, unless and until the copyright holder explicitly and
|
||||
finally terminates your license, and (b) permanently, if the copyright
|
||||
holder fails to notify you of the violation by some reasonable means
|
||||
prior to 60 days after the cessation.
|
||||
|
||||
Moreover, your license from a particular copyright holder is
|
||||
reinstated permanently if the copyright holder notifies you of the
|
||||
violation by some reasonable means, this is the first time you have
|
||||
received notice of violation of this License (for any work) from that
|
||||
copyright holder, and you cure the violation prior to 30 days after
|
||||
your receipt of the notice.
|
||||
|
||||
Termination of your rights under this section does not terminate the
|
||||
licenses of parties who have received copies or rights from you under
|
||||
this License. If your rights have been terminated and not permanently
|
||||
reinstated, you do not qualify to receive new licenses for the same
|
||||
material under section 10.
|
||||
|
||||
9. Acceptance Not Required for Having Copies.
|
||||
|
||||
You are not required to accept this License in order to receive or
|
||||
run a copy of the Program. Ancillary propagation of a covered work
|
||||
occurring solely as a consequence of using peer-to-peer transmission
|
||||
to receive a copy likewise does not require acceptance. However,
|
||||
nothing other than this License grants you permission to propagate or
|
||||
modify any covered work. These actions infringe copyright if you do
|
||||
not accept this License. Therefore, by modifying or propagating a
|
||||
covered work, you indicate your acceptance of this License to do so.
|
||||
|
||||
10. Automatic Licensing of Downstream Recipients.
|
||||
|
||||
Each time you convey a covered work, the recipient automatically
|
||||
receives a license from the original licensors, to run, modify and
|
||||
propagate that work, subject to this License. You are not responsible
|
||||
for enforcing compliance by third parties with this License.
|
||||
|
||||
An "entity transaction" is a transaction transferring control of an
|
||||
organization, or substantially all assets of one, or subdividing an
|
||||
organization, or merging organizations. If propagation of a covered
|
||||
work results from an entity transaction, each party to that
|
||||
transaction who receives a copy of the work also receives whatever
|
||||
licenses to the work the party's predecessor in interest had or could
|
||||
give under the previous paragraph, plus a right to possession of the
|
||||
Corresponding Source of the work from the predecessor in interest, if
|
||||
the predecessor has it or can get it with reasonable efforts.
|
||||
|
||||
You may not impose any further restrictions on the exercise of the
|
||||
rights granted or affirmed under this License. For example, you may
|
||||
not impose a license fee, royalty, or other charge for exercise of
|
||||
rights granted under this License, and you may not initiate litigation
|
||||
(including a cross-claim or counterclaim in a lawsuit) alleging that
|
||||
any patent claim is infringed by making, using, selling, offering for
|
||||
sale, or importing the Program or any portion of it.
|
||||
|
||||
11. Patents.
|
||||
|
||||
A "contributor" is a copyright holder who authorizes use under this
|
||||
License of the Program or a work on which the Program is based. The
|
||||
work thus licensed is called the contributor's "contributor version".
|
||||
|
||||
A contributor's "essential patent claims" are all patent claims
|
||||
owned or controlled by the contributor, whether already acquired or
|
||||
hereafter acquired, that would be infringed by some manner, permitted
|
||||
by this License, of making, using, or selling its contributor version,
|
||||
but do not include claims that would be infringed only as a
|
||||
consequence of further modification of the contributor version. For
|
||||
purposes of this definition, "control" includes the right to grant
|
||||
patent sublicenses in a manner consistent with the requirements of
|
||||
this License.
|
||||
|
||||
Each contributor grants you a non-exclusive, worldwide, royalty-free
|
||||
patent license under the contributor's essential patent claims, to
|
||||
make, use, sell, offer for sale, import and otherwise run, modify and
|
||||
propagate the contents of its contributor version.
|
||||
|
||||
In the following three paragraphs, a "patent license" is any express
|
||||
agreement or commitment, however denominated, not to enforce a patent
|
||||
(such as an express permission to practice a patent or covenant not to
|
||||
sue for patent infringement). To "grant" such a patent license to a
|
||||
party means to make such an agreement or commitment not to enforce a
|
||||
patent against the party.
|
||||
|
||||
If you convey a covered work, knowingly relying on a patent license,
|
||||
and the Corresponding Source of the work is not available for anyone
|
||||
to copy, free of charge and under the terms of this License, through a
|
||||
publicly available network server or other readily accessible means,
|
||||
then you must either (1) cause the Corresponding Source to be so
|
||||
available, or (2) arrange to deprive yourself of the benefit of the
|
||||
patent license for this particular work, or (3) arrange, in a manner
|
||||
consistent with the requirements of this License, to extend the patent
|
||||
license to downstream recipients. "Knowingly relying" means you have
|
||||
actual knowledge that, but for the patent license, your conveying the
|
||||
covered work in a country, or your recipient's use of the covered work
|
||||
in a country, would infringe one or more identifiable patents in that
|
||||
country that you have reason to believe are valid.
|
||||
|
||||
If, pursuant to or in connection with a single transaction or
|
||||
arrangement, you convey, or propagate by procuring conveyance of, a
|
||||
covered work, and grant a patent license to some of the parties
|
||||
receiving the covered work authorizing them to use, propagate, modify
|
||||
or convey a specific copy of the covered work, then the patent license
|
||||
you grant is automatically extended to all recipients of the covered
|
||||
work and works based on it.
|
||||
|
||||
A patent license is "discriminatory" if it does not include within
|
||||
the scope of its coverage, prohibits the exercise of, or is
|
||||
conditioned on the non-exercise of one or more of the rights that are
|
||||
specifically granted under this License. You may not convey a covered
|
||||
work if you are a party to an arrangement with a third party that is
|
||||
in the business of distributing software, under which you make payment
|
||||
to the third party based on the extent of your activity of conveying
|
||||
the work, and under which the third party grants, to any of the
|
||||
parties who would receive the covered work from you, a discriminatory
|
||||
patent license (a) in connection with copies of the covered work
|
||||
conveyed by you (or copies made from those copies), or (b) primarily
|
||||
for and in connection with specific products or compilations that
|
||||
contain the covered work, unless you entered into that arrangement,
|
||||
or that patent license was granted, prior to 28 March 2007.
|
||||
|
||||
Nothing in this License shall be construed as excluding or limiting
|
||||
any implied license or other defenses to infringement that may
|
||||
otherwise be available to you under applicable patent law.
|
||||
|
||||
12. No Surrender of Others' Freedom.
|
||||
|
||||
If conditions are imposed on you (whether by court order, agreement or
|
||||
otherwise) that contradict the conditions of this License, they do not
|
||||
excuse you from the conditions of this License. If you cannot convey a
|
||||
covered work so as to satisfy simultaneously your obligations under this
|
||||
License and any other pertinent obligations, then as a consequence you may
|
||||
not convey it at all. For example, if you agree to terms that obligate you
|
||||
to collect a royalty for further conveying from those to whom you convey
|
||||
the Program, the only way you could satisfy both those terms and this
|
||||
License would be to refrain entirely from conveying the Program.
|
||||
|
||||
13. Remote Network Interaction; Use with the GNU General Public License.
|
||||
|
||||
Notwithstanding any other provision of this License, if you modify the
|
||||
Program, your modified version must prominently offer all users
|
||||
interacting with it remotely through a computer network (if your version
|
||||
supports such interaction) an opportunity to receive the Corresponding
|
||||
Source of your version by providing access to the Corresponding Source
|
||||
from a network server at no charge, through some standard or customary
|
||||
means of facilitating copying of software. This Corresponding Source
|
||||
shall include the Corresponding Source for any work covered by version 3
|
||||
of the GNU General Public License that is incorporated pursuant to the
|
||||
following paragraph.
|
||||
|
||||
Notwithstanding any other provision of this License, you have
|
||||
permission to link or combine any covered work with a work licensed
|
||||
under version 3 of the GNU General Public License into a single
|
||||
combined work, and to convey the resulting work. The terms of this
|
||||
License will continue to apply to the part which is the covered work,
|
||||
but the work with which it is combined will remain governed by version
|
||||
3 of the GNU General Public License.
|
||||
|
||||
14. Revised Versions of this License.
|
||||
|
||||
The Free Software Foundation may publish revised and/or new versions of
|
||||
the GNU Affero General Public License from time to time. Such new versions
|
||||
will be similar in spirit to the present version, but may differ in detail to
|
||||
address new problems or concerns.
|
||||
|
||||
Each version is given a distinguishing version number. If the
|
||||
Program specifies that a certain numbered version of the GNU Affero General
|
||||
Public License "or any later version" applies to it, you have the
|
||||
option of following the terms and conditions either of that numbered
|
||||
version or of any later version published by the Free Software
|
||||
Foundation. If the Program does not specify a version number of the
|
||||
GNU Affero General Public License, you may choose any version ever published
|
||||
by the Free Software Foundation.
|
||||
|
||||
If the Program specifies that a proxy can decide which future
|
||||
versions of the GNU Affero General Public License can be used, that proxy's
|
||||
public statement of acceptance of a version permanently authorizes you
|
||||
to choose that version for the Program.
|
||||
|
||||
Later license versions may give you additional or different
|
||||
permissions. However, no additional obligations are imposed on any
|
||||
author or copyright holder as a result of your choosing to follow a
|
||||
later version.
|
||||
|
||||
15. Disclaimer of Warranty.
|
||||
|
||||
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
|
||||
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
|
||||
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
|
||||
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
|
||||
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
|
||||
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
|
||||
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
|
||||
|
||||
16. Limitation of Liability.
|
||||
|
||||
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
|
||||
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
|
||||
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
|
||||
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
|
||||
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
|
||||
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
|
||||
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
|
||||
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
|
||||
SUCH DAMAGES.
|
||||
|
||||
17. Interpretation of Sections 15 and 16.
|
||||
|
||||
If the disclaimer of warranty and limitation of liability provided
|
||||
above cannot be given local legal effect according to their terms,
|
||||
reviewing courts shall apply local law that most closely approximates
|
||||
an absolute waiver of all civil liability in connection with the
|
||||
Program, unless a warranty or assumption of liability accompanies a
|
||||
copy of the Program in return for a fee.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
How to Apply These Terms to Your New Programs
|
||||
|
||||
If you develop a new program, and you want it to be of the greatest
|
||||
possible use to the public, the best way to achieve this is to make it
|
||||
free software which everyone can redistribute and change under these terms.
|
||||
|
||||
To do so, attach the following notices to the program. It is safest
|
||||
to attach them to the start of each source file to most effectively
|
||||
state the exclusion of warranty; and each file should have at least
|
||||
the "copyright" line and a pointer to where the full notice is found.
|
||||
|
||||
<one line to give the program's name and a brief idea of what it does.>
|
||||
Copyright (C) <year> <name of author>
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published
|
||||
by the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
Also add information on how to contact you by electronic and paper mail.
|
||||
|
||||
If your software can interact with users remotely through a computer
|
||||
network, you should also make sure that it provides a way for users to
|
||||
get its source. For example, if your program is a web application, its
|
||||
interface could display a "Source" link that leads users to an archive
|
||||
of the code. There are many ways you could offer source, and different
|
||||
solutions will be better for different programs; see section 13 for the
|
||||
specific requirements.
|
||||
|
||||
You should also get your employer (if you work as a programmer) or school,
|
||||
if any, to sign a "copyright disclaimer" for the program, if necessary.
|
||||
For more information on this, and how to apply and follow the GNU AGPL, see
|
||||
<https://www.gnu.org/licenses/>.
|
||||
22
src/A_memorix/LICENSE-MAIBOT-GPL.md
Normal file
22
src/A_memorix/LICENSE-MAIBOT-GPL.md
Normal file
@@ -0,0 +1,22 @@
|
||||
Special GPL License Grant for MaiBot
|
||||
|
||||
Licensor
|
||||
- A_Dawn
|
||||
|
||||
Effective date
|
||||
- 2026-03-18
|
||||
|
||||
Default license
|
||||
- This repository is licensed under AGPL-3.0 by default (see `LICENSE`).
|
||||
|
||||
Additional grant for MaiBot
|
||||
- The copyright holder(s) of this repository grant an additional, non-exclusive permission to
|
||||
the project at `https://github.com/Mai-with-u/MaiBot` (including its maintainers and contributors)
|
||||
to use, modify, and redistribute code from this repository under GPL-3.0.
|
||||
|
||||
Scope
|
||||
- This additional GPL grant is intended for use in the MaiBot project context.
|
||||
- For all other uses not covered by the grant above, AGPL-3.0 remains the applicable license.
|
||||
|
||||
No warranty
|
||||
- This grant is provided without warranty, consistent with AGPL-3.0 and GPL-3.0.
|
||||
97
src/A_memorix/MODIFICATION_POLICY.md
Normal file
97
src/A_memorix/MODIFICATION_POLICY.md
Normal file
@@ -0,0 +1,97 @@
|
||||
# A_Memorix 修改规定
|
||||
|
||||
## 目的
|
||||
|
||||
`src/A_memorix` 是上游 `A_memorix` 仓库在 MaiBot 内的同步目录。
|
||||
|
||||
这个目录允许包含面向 MaiBot 的耦合实现,但这些耦合的归属应当属于上游
|
||||
`MaiBot_branch`,而不是在 MaiBot 仓库内长期各自演化的私有改动。
|
||||
|
||||
本文件用于明确 `src/A_memorix` 目录下的修改边界。
|
||||
|
||||
## 事实来源
|
||||
|
||||
- 上游仓库:`https://github.com/A-Dawn/A_memorix.git`
|
||||
- 上游对接分支:`MaiBot_branch`
|
||||
- MaiBot 内同步前缀:`src/A_memorix`
|
||||
|
||||
基本原则:
|
||||
|
||||
- 如果改动属于 A_Memorix 的业务逻辑、内部实现或对 MaiBot 的耦合实现,应优先提交到上游 `MaiBot_branch`。
|
||||
- 如果改动只属于 MaiBot 的加载、运行时接入、WebUI 接入、配置接入或测试接入,应在 MaiBot 仓库内完成。
|
||||
|
||||
## 可直接在 MaiBot 仓库修改的范围
|
||||
|
||||
以下内容默认由 MaiBot 仓库直接维护:
|
||||
|
||||
- `src/services/memory_service.py`
|
||||
- `src/webui/routers/memory.py`
|
||||
- `dashboard/src/routes/resource/knowledge-base.tsx`
|
||||
- `dashboard/src/routes/resource/knowledge-graph/`
|
||||
- `config/a_memorix.toml`
|
||||
- `data/plugins/a-dawn.a-memorix/`
|
||||
- `pytests/A_memorix_test/`
|
||||
- 同步脚本与同步文档,例如 `scripts/sync_a_memorix_subtree.sh`
|
||||
|
||||
这些内容属于 MaiBot 侧接入层。
|
||||
|
||||
常见例子:
|
||||
|
||||
- 调整 `src/services/memory_service.py` 中 A_Memorix 的宿主调用封装
|
||||
- 修改 `src/webui/routers/memory.py` 中对 A_Memorix 的 API 暴露方式
|
||||
- 修改 dashboard 中对 A_Memorix 图谱页、控制台页的展示与交互
|
||||
- 调整 `config/a_memorix.toml` 的默认配置项
|
||||
- 增补 `pytests/A_memorix_test/` 中用于验证 MaiBot 集成行为的测试
|
||||
- 修改同步文档、同步脚本、接入说明和迁移说明
|
||||
|
||||
## 原则上应先在上游修改的范围
|
||||
|
||||
以下内容原则上应先在上游 `MaiBot_branch` 修改,再同步回 MaiBot:
|
||||
|
||||
- `src/A_memorix/core/`
|
||||
- `src/A_memorix/scripts/`
|
||||
- `src/A_memorix/plugin.py`
|
||||
- `src/A_memorix/paths.py`
|
||||
- `src/A_memorix/runtime_registry.py`
|
||||
- `src/A_memorix/README.md` 及其他描述包行为的文档
|
||||
|
||||
这类改动包括但不限于:
|
||||
|
||||
- 新功能开发
|
||||
- 行为变更
|
||||
- 数据模型变更
|
||||
- 存储与检索逻辑变更
|
||||
- A_Memorix 内部的 MaiBot 耦合变更
|
||||
|
||||
## 允许的本地例外
|
||||
|
||||
在以下情况下,允许直接在 `src/A_memorix` 下做本地修改:
|
||||
|
||||
- 需要解决同步冲突,以保证 MaiBot 可以构建、启动或测试
|
||||
- 需要紧急修复,以解除 MaiBot 当前开发或发布阻塞
|
||||
- 需要临时兼容补丁,而对应改动尚未同步进入上游
|
||||
|
||||
出现上述情况时,应遵循以下约束:
|
||||
|
||||
- 补丁尽量小
|
||||
- 在提交说明或 PR 描述中写明为什么需要本地补丁
|
||||
- 条件允许时,尽快把同等改动提交到上游 `MaiBot_branch`
|
||||
|
||||
## 实操判断规则
|
||||
|
||||
在修改 `src/A_memorix` 前,先问两个问题:
|
||||
|
||||
1. 这个改动是否属于 A_Memorix 的行为或内部实现?
|
||||
2. 如果 MaiBot 不存在,这个改动是否仍然应属于 A_Memorix 的 MaiBot 对接分支?
|
||||
|
||||
如果答案是“是”,原则上应先改上游。
|
||||
|
||||
如果这个改动只影响 MaiBot 如何加载、配置、展示、测试或包装 A_Memorix,
|
||||
则应留在 MaiBot 仓库内。
|
||||
|
||||
## 目标
|
||||
|
||||
本规定不是为了完全禁止本地修改,而是为了明确归属:
|
||||
|
||||
- MaiBot 拥有接入层。
|
||||
- 上游 `A_memorix` 拥有实现层,包括面向 MaiBot 的对接分支实现。
|
||||
313
src/A_memorix/QUICK_START.md
Normal file
313
src/A_memorix/QUICK_START.md
Normal file
@@ -0,0 +1,313 @@
|
||||
# A_Memorix Quick Start (v2.0.0)
|
||||
|
||||
本文档面向当前 `2.0.0` 架构(源码内长期记忆子系统 + SDK Tool 接口)。
|
||||
|
||||
## 0. 版本与接口变更
|
||||
|
||||
- 当前版本:`2.0.0`
|
||||
- 接入形态:MaiBot 内置长期记忆子系统 + Tool 调用
|
||||
- 旧版 slash 命令(如 `/query`、`/memory`、`/visualize`)不再作为本分支主文档入口
|
||||
|
||||
## 1. 环境准备
|
||||
|
||||
- Python 3.10+
|
||||
- 与 MaiBot 主程序相同的运行环境
|
||||
- 可访问你配置的 embedding 服务
|
||||
|
||||
安装依赖:
|
||||
|
||||
```bash
|
||||
pip install -r src/A_memorix/requirements.txt --upgrade
|
||||
```
|
||||
|
||||
如果当前目录就是 `src/A_memorix`,也可以:
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt --upgrade
|
||||
```
|
||||
|
||||
## 2. 配置子系统
|
||||
|
||||
当前分支固定使用 `config/a_memorix.toml` 作为 A_Memorix 配置文件。
|
||||
|
||||
推荐的配置入口有两种:
|
||||
|
||||
- 长期记忆控制台:适合修改常用高频项,适合日常运维与调优。
|
||||
- 原始 TOML:适合批量复制配置或编辑长尾高级项。
|
||||
|
||||
常用完整示例:
|
||||
|
||||
```toml
|
||||
[plugin]
|
||||
enabled = true
|
||||
|
||||
[storage]
|
||||
data_dir = "data/plugins/a-dawn.a-memorix"
|
||||
|
||||
[embedding]
|
||||
model_name = "auto"
|
||||
dimension = 1024
|
||||
batch_size = 32
|
||||
max_concurrent = 5
|
||||
enable_cache = false
|
||||
quantization_type = "int8"
|
||||
|
||||
[embedding.fallback]
|
||||
enabled = true
|
||||
probe_interval_seconds = 180
|
||||
allow_metadata_only_write = true
|
||||
|
||||
[embedding.paragraph_vector_backfill]
|
||||
enabled = true
|
||||
interval_seconds = 60
|
||||
batch_size = 64
|
||||
max_retry = 5
|
||||
|
||||
[retrieval]
|
||||
top_k_paragraphs = 20
|
||||
top_k_relations = 10
|
||||
top_k_final = 10
|
||||
alpha = 0.5
|
||||
enable_ppr = true
|
||||
ppr_alpha = 0.85
|
||||
ppr_timeout_seconds = 1.5
|
||||
ppr_concurrency_limit = 4
|
||||
enable_parallel = true
|
||||
|
||||
[retrieval.sparse]
|
||||
enabled = true
|
||||
backend = "fts5"
|
||||
mode = "auto"
|
||||
tokenizer_mode = "jieba"
|
||||
candidate_k = 80
|
||||
relation_candidate_k = 60
|
||||
|
||||
[threshold]
|
||||
min_threshold = 0.3
|
||||
max_threshold = 0.95
|
||||
percentile = 75.0
|
||||
min_results = 3
|
||||
enable_auto_adjust = true
|
||||
|
||||
[filter]
|
||||
enabled = true
|
||||
mode = "blacklist"
|
||||
chats = []
|
||||
|
||||
[episode]
|
||||
enabled = true
|
||||
generation_enabled = true
|
||||
pending_batch_size = 20
|
||||
pending_max_retry = 3
|
||||
max_paragraphs_per_call = 20
|
||||
max_chars_per_call = 6000
|
||||
source_time_window_hours = 24
|
||||
segmentation_model = "auto"
|
||||
|
||||
[person_profile]
|
||||
enabled = true
|
||||
refresh_interval_minutes = 30
|
||||
active_window_hours = 72
|
||||
max_refresh_per_cycle = 50
|
||||
top_k_evidence = 12
|
||||
|
||||
[memory]
|
||||
enabled = true
|
||||
half_life_hours = 24.0
|
||||
prune_threshold = 0.1
|
||||
freeze_duration_hours = 24.0
|
||||
|
||||
[advanced]
|
||||
enable_auto_save = true
|
||||
auto_save_interval_minutes = 5
|
||||
debug = false
|
||||
|
||||
[web.import]
|
||||
enabled = true
|
||||
max_queue_size = 20
|
||||
max_files_per_task = 200
|
||||
max_file_size_mb = 20
|
||||
max_paste_chars = 200000
|
||||
default_file_concurrency = 2
|
||||
default_chunk_concurrency = 4
|
||||
|
||||
[web.tuning]
|
||||
enabled = true
|
||||
max_queue_size = 8
|
||||
poll_interval_ms = 1200
|
||||
default_intensity = "standard"
|
||||
default_objective = "precision_priority"
|
||||
default_top_k_eval = 20
|
||||
default_sample_size = 24
|
||||
```
|
||||
|
||||
未出现在可视化配置页中的高级项,继续通过原始 TOML 维护,详见 [CONFIG_REFERENCE.md](CONFIG_REFERENCE.md)。
|
||||
|
||||
## 3. 运行时自检(强烈建议)
|
||||
|
||||
先确认 embedding 实际输出维度与向量库兼容:
|
||||
|
||||
```bash
|
||||
python src/A_memorix/scripts/runtime_self_check.py --json
|
||||
```
|
||||
|
||||
如果结果 `ok=false`,先修复 embedding 配置或向量库,再继续导入。
|
||||
|
||||
## 4. 导入数据
|
||||
|
||||
### 4.1 文本批量导入
|
||||
|
||||
把文本放到:
|
||||
|
||||
```text
|
||||
data/plugins/a-dawn.a-memorix/raw/
|
||||
```
|
||||
|
||||
执行:
|
||||
|
||||
```bash
|
||||
python src/A_memorix/scripts/process_knowledge.py
|
||||
```
|
||||
|
||||
常用参数:
|
||||
|
||||
```bash
|
||||
python src/A_memorix/scripts/process_knowledge.py --force
|
||||
python src/A_memorix/scripts/process_knowledge.py --chat-log
|
||||
python src/A_memorix/scripts/process_knowledge.py --chat-log --chat-reference-time "2026/02/12 10:30"
|
||||
```
|
||||
|
||||
### 4.2 其他导入脚本
|
||||
|
||||
```bash
|
||||
python src/A_memorix/scripts/import_lpmm_json.py <json文件或目录>
|
||||
python src/A_memorix/scripts/convert_lpmm.py -i <lpmm数据目录> -o data/plugins/a-dawn.a-memorix
|
||||
python src/A_memorix/scripts/migrate_chat_history.py --help
|
||||
python src/A_memorix/scripts/migrate_maibot_memory.py --help
|
||||
python src/A_memorix/scripts/migrate_person_memory_points.py --help
|
||||
```
|
||||
|
||||
## 5. 核心 Tool 调用
|
||||
|
||||
### 5.1 检索
|
||||
|
||||
```json
|
||||
{
|
||||
"tool": "search_memory",
|
||||
"arguments": {
|
||||
"query": "项目复盘",
|
||||
"mode": "aggregate",
|
||||
"limit": 5,
|
||||
"chat_id": "group:dev"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
`mode` 支持:`search/time/hybrid/episode/aggregate`
|
||||
|
||||
严格语义说明:
|
||||
|
||||
- `semantic` 模式已移除,传入会返回参数错误。
|
||||
- `time/hybrid` 模式必须提供 `time_start` 或 `time_end`,否则返回错误(不会再当作“未命中”)。
|
||||
|
||||
### 5.2 写入摘要
|
||||
|
||||
```json
|
||||
{
|
||||
"tool": "ingest_summary",
|
||||
"arguments": {
|
||||
"external_id": "chat_summary:group-dev:2026-03-18",
|
||||
"chat_id": "group:dev",
|
||||
"text": "今天完成了检索调优评审"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 5.3 写入普通记忆
|
||||
|
||||
```json
|
||||
{
|
||||
"tool": "ingest_text",
|
||||
"arguments": {
|
||||
"external_id": "note:2026-03-18:001",
|
||||
"source_type": "note",
|
||||
"text": "模型切换后召回质量更稳定",
|
||||
"chat_id": "group:dev",
|
||||
"tags": ["worklog"]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 5.4 画像与维护
|
||||
|
||||
```json
|
||||
{
|
||||
"tool": "get_person_profile",
|
||||
"arguments": {
|
||||
"person_id": "Alice",
|
||||
"limit": 8
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
```json
|
||||
{
|
||||
"tool": "maintain_memory",
|
||||
"arguments": {
|
||||
"action": "protect",
|
||||
"target": "模型切换后召回质量更稳定",
|
||||
"hours": 24
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
```json
|
||||
{
|
||||
"tool": "memory_stats",
|
||||
"arguments": {}
|
||||
}
|
||||
```
|
||||
|
||||
## 6. 管理 Tool(进阶)
|
||||
|
||||
`2.0.0` 提供完整管理工具:
|
||||
|
||||
- `memory_graph_admin`
|
||||
- `memory_source_admin`
|
||||
- `memory_episode_admin`
|
||||
- `memory_profile_admin`
|
||||
- `memory_runtime_admin`
|
||||
- `memory_import_admin`
|
||||
- `memory_tuning_admin`
|
||||
- `memory_v5_admin`
|
||||
- `memory_delete_admin`
|
||||
|
||||
可先用 `action=list` / `action=status` 等只读动作验证链路。
|
||||
|
||||
## 7. 常见问题
|
||||
|
||||
### Q1: 检索为空
|
||||
|
||||
1. 先看 `memory_stats` 是否有段落/关系
|
||||
2. 检查 `chat_id`、`person_id` 过滤条件是否过严
|
||||
3. 运行 `runtime_self_check.py --json` 确认 embedding 维度无误
|
||||
4. 若返回包含 `error` 字段,优先按错误提示修正 mode/时间参数
|
||||
|
||||
### Q2: 启动时报向量维度不一致
|
||||
|
||||
- 原因:现有向量库维度与当前 embedding 输出不一致
|
||||
- 处理:恢复原配置或重建向量数据后再启动
|
||||
|
||||
### Q3: Web 页面打不开
|
||||
|
||||
本分支不内置独立 `server.py`。
|
||||
|
||||
- 常用配置项可直接通过主程序长期记忆控制台编辑。
|
||||
- `web/index.html`、`web/import.html`、`web/tuning.html` 仅作为页面结构与行为参考。
|
||||
- 正式入口由宿主侧 React 页面和 `/api/webui/memory/*` 接口承接。
|
||||
|
||||
## 8. 下一步
|
||||
|
||||
- 配置细节见 [CONFIG_REFERENCE.md](CONFIG_REFERENCE.md)
|
||||
- 导入细节见 [IMPORT_GUIDE.md](IMPORT_GUIDE.md)
|
||||
- 版本历史见 [CHANGELOG.md](CHANGELOG.md)
|
||||
271
src/A_memorix/README.md
Normal file
271
src/A_memorix/README.md
Normal file
@@ -0,0 +1,271 @@
|
||||
# A_Memorix
|
||||
|
||||
**长期记忆与认知增强子系统** (v2.0.0)
|
||||
|
||||
> 消えていかない感覚 , まだまだ足りてないみたい !
|
||||
|
||||
A_Memorix 是 MaiBot 内置的长期记忆子系统。
|
||||
它把文本、关系、Episode、人物画像和检索调优统一在一套运行时里,适合长期运行的 Agent 记忆场景。
|
||||
|
||||
## 快速导航
|
||||
|
||||
- [快速入门](QUICK_START.md)
|
||||
- [配置参数详解](CONFIG_REFERENCE.md)
|
||||
- [导入指南与最佳实践](IMPORT_GUIDE.md)
|
||||
- [修改约定](MODIFICATION_POLICY.md)
|
||||
- [更新日志](CHANGELOG.md)
|
||||
|
||||
## 2.0.0 版本定位
|
||||
|
||||
`v2.0.0` 是一次架构收敛版本,当前分支以 **SDK Tool 接口** 为主:
|
||||
|
||||
- 旧 `components/commands/*`、`components/tools/*` 与 `server.py` 已移除。
|
||||
- 统一入口为宿主侧 host service + [`core/runtime/sdk_memory_kernel.py`](core/runtime/sdk_memory_kernel.py)。
|
||||
- 元数据 schema 为 `v8`,新增外部引用与运维操作记录(如 `external_memory_refs`、`memory_v5_operations`、`delete_operations`)。
|
||||
|
||||
如果你还在使用旧版 slash 命令(如 `/query`、`/memory`、`/visualize`),需要按本文的 Tool 接口迁移。
|
||||
|
||||
## 核心能力
|
||||
|
||||
- 双路检索:向量 + 图谱关系联合召回,支持 `search/time/hybrid/episode/aggregate`。
|
||||
- 写入与去重:`external_id` 幂等、段落/关系联合写入、Episode pending 队列处理。
|
||||
- Episode 能力:按 source 重建、状态查询、批处理 pending。
|
||||
- 人物画像:自动快照 + 手动 override。
|
||||
- 管理能力:图谱、来源、Episode、画像、导入、调优、V5 运维、删除恢复全套管理工具。
|
||||
|
||||
## Tool 接口 (v2.0.0)
|
||||
|
||||
### 基础工具
|
||||
|
||||
| Tool | 说明 | 关键参数 |
|
||||
| --- | --- | --- |
|
||||
| `search_memory` | 检索长期记忆 | `query` `mode` `limit` `chat_id` `person_id` `time_start` `time_end` |
|
||||
| `ingest_summary` | 写入聊天摘要 | `external_id` `chat_id` `text` |
|
||||
| `ingest_text` | 写入普通文本记忆 | `external_id` `source_type` `text` |
|
||||
| `get_person_profile` | 获取人物画像 | `person_id` `chat_id` `limit` |
|
||||
| `maintain_memory` | 维护关系状态 | `action=reinforce/protect/restore/freeze/recycle_bin` |
|
||||
| `memory_stats` | 获取统计信息 | 无 |
|
||||
|
||||
### 管理工具
|
||||
|
||||
| Tool | 常用 action |
|
||||
| --- | --- |
|
||||
| `memory_graph_admin` | `get_graph/create_node/delete_node/rename_node/create_edge/delete_edge/update_edge_weight` |
|
||||
| `memory_source_admin` | `list/delete/batch_delete` |
|
||||
| `memory_episode_admin` | `query/list/get/status/rebuild/process_pending` |
|
||||
| `memory_profile_admin` | `query/list/set_override/delete_override` |
|
||||
| `memory_runtime_admin` | `save/get_config/self_check/refresh_self_check/set_auto_save` |
|
||||
| `memory_import_admin` | `settings/get_guide/create_upload/create_paste/create_raw_scan/create_lpmm_openie/create_lpmm_convert/create_temporal_backfill/create_maibot_migration/list/get/chunks/cancel/retry_failed` |
|
||||
| `memory_tuning_admin` | `settings/get_profile/apply_profile/rollback_profile/export_profile/create_task/list_tasks/get_task/get_rounds/cancel/apply_best/get_report` |
|
||||
| `memory_v5_admin` | `status/recycle_bin/restore/reinforce/weaken/remember_forever/forget` |
|
||||
| `memory_delete_admin` | `preview/execute/restore/get_operation/list_operations/purge` |
|
||||
|
||||
### 检索模式语义(严格)
|
||||
|
||||
- `search_memory.mode` 仅支持:`search/time/hybrid/episode/aggregate`。
|
||||
- `semantic` 模式已移除,传入将返回参数错误。
|
||||
- `time/hybrid` 模式必须提供 `time_start` 或 `time_end`,否则返回错误,不再静默按“未命中”处理。
|
||||
|
||||
### 删除返回语义(source 模式)
|
||||
|
||||
- `requested_source_count`:请求删除的 source 数。
|
||||
- `matched_source_count`:实际命中的 source 数(存在活跃段落)。
|
||||
- `deleted_paragraph_count`:实际删除段落数。
|
||||
- `deleted_count`:与实际删除对象一致;在 `source` 模式下等于 `deleted_paragraph_count`。
|
||||
- `success`:基于实际命中与实际删除判定,未命中 source 时返回 `false`。
|
||||
|
||||
## 调用示例
|
||||
|
||||
```json
|
||||
{
|
||||
"tool": "search_memory",
|
||||
"arguments": {
|
||||
"query": "项目复盘",
|
||||
"mode": "aggregate",
|
||||
"limit": 5,
|
||||
"chat_id": "group:dev"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
```json
|
||||
{
|
||||
"tool": "ingest_text",
|
||||
"arguments": {
|
||||
"external_id": "note:2026-03-18:001",
|
||||
"source_type": "note",
|
||||
"text": "今天完成了检索调优评审",
|
||||
"chat_id": "group:dev",
|
||||
"tags": ["worklog"]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
```json
|
||||
{
|
||||
"tool": "maintain_memory",
|
||||
"arguments": {
|
||||
"action": "protect",
|
||||
"target": "完成了 检索调优评审",
|
||||
"hours": 72
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## 快速开始
|
||||
|
||||
### 1. 安装依赖
|
||||
|
||||
在 MaiBot 主程序使用的同一个 Python 环境中执行:
|
||||
|
||||
```bash
|
||||
pip install -r src/A_memorix/requirements.txt --upgrade
|
||||
```
|
||||
|
||||
如果当前目录已经是 `src/A_memorix`,也可以执行:
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt --upgrade
|
||||
```
|
||||
|
||||
### 2. 配置子系统
|
||||
|
||||
在 `config/a_memorix.toml` 中启用 A_Memorix:
|
||||
|
||||
```toml
|
||||
[plugin]
|
||||
enabled = true
|
||||
```
|
||||
|
||||
### 配置方式
|
||||
|
||||
- 默认配置文件:`config/a_memorix.toml`
|
||||
- 长期记忆控制台:适合修改常用高频项,如 embedding、检索、Episode、人物画像、导入与调优开关。
|
||||
- 原始 TOML:适合复制整份配置、批量粘贴参数,或编辑未在可视化表单中展示的高级项。
|
||||
- 配置参考:请结合 [CONFIG_REFERENCE.md](CONFIG_REFERENCE.md) 查看各键的运行时语义与默认值。
|
||||
|
||||
### 3. 先做运行时自检
|
||||
|
||||
```bash
|
||||
python src/A_memorix/scripts/runtime_self_check.py --json
|
||||
```
|
||||
|
||||
### 4. 导入文本并验证统计
|
||||
|
||||
```bash
|
||||
python src/A_memorix/scripts/process_knowledge.py
|
||||
```
|
||||
|
||||
然后调用 `memory_stats` 或 `search_memory` 检查是否有数据。
|
||||
|
||||
## Web 页面说明
|
||||
|
||||
主程序 WebUI 目前已经把 A_Memorix 作为源码内长期记忆模块接入。
|
||||
|
||||
- 常用字段:通过长期记忆控制台可视化调整。
|
||||
- 长尾高级项:继续通过“源码模式 / 原始 TOML”编辑。
|
||||
|
||||
仓库内保留了 Web 静态页面:
|
||||
|
||||
- `web/index.html`(图谱与记忆管理)
|
||||
- `web/import.html`(导入中心)
|
||||
- `web/tuning.html`(检索调优)
|
||||
|
||||
当前分支不再内置独立 `server.py`,页面路由与 API 暴露由宿主侧 React 页面和 `/api/webui/memory/*` 接口承接。
|
||||
|
||||
### WebUI 验证脚本
|
||||
|
||||
仓库内提供了一套可重复执行的 Electron 验证脚本,用来回归这条真实链路:
|
||||
|
||||
- 登录页可访问
|
||||
- 长期记忆控制台可打开
|
||||
- 通过 WebUI 以 `json` 模式创建导入任务
|
||||
- 长期记忆图谱可刷新并产出截图
|
||||
- 插件配置页中不再把 A_Memorix 当作插件展示
|
||||
|
||||
执行方式:
|
||||
|
||||
```bash
|
||||
bash scripts/verify_a_memorix_webui.sh
|
||||
```
|
||||
|
||||
默认会把截图、任务明细和摘要结果写到 `tmp/ui-snapshots/a_memorix-electron/`。
|
||||
如果你已经手动启动了后端和 dashboard,可以加:
|
||||
|
||||
```bash
|
||||
MAIBOT_UI_REUSE_SERVICES=1 bash scripts/verify_a_memorix_webui.sh
|
||||
```
|
||||
|
||||
## 常用脚本
|
||||
|
||||
| 脚本 | 用途 |
|
||||
| --- | --- |
|
||||
| `process_knowledge.py` | 批量导入原始文本(策略感知) |
|
||||
| `import_lpmm_json.py` | 导入 OpenIE JSON |
|
||||
| `convert_lpmm.py` | 转换 LPMM 数据 |
|
||||
| `migrate_chat_history.py` | 迁移 chat_history |
|
||||
| `migrate_maibot_memory.py` | 迁移 MaiBot 历史记忆 |
|
||||
| `migrate_person_memory_points.py` | 迁移 person memory points |
|
||||
| `backfill_temporal_metadata.py` | 回填时间元数据 |
|
||||
| `audit_vector_consistency.py` | 审计向量一致性 |
|
||||
| `backfill_relation_vectors.py` | 回填关系向量 |
|
||||
| `rebuild_episodes.py` | 按 source 重建 Episode |
|
||||
| `release_vnext_migrate.py` | 升级预检/迁移/校验 |
|
||||
| `runtime_self_check.py` | 真实 embedding 运行时自检 |
|
||||
|
||||
## 配置重点
|
||||
|
||||
完整配置见 [CONFIG_REFERENCE.md](CONFIG_REFERENCE.md)。
|
||||
|
||||
推荐使用方式:
|
||||
|
||||
- 常用配置:优先通过 WebUI 长期记忆控制台维护。
|
||||
- 高级配置:通过 `config/a_memorix.toml` 或 WebUI 的原始 TOML 模式维护。
|
||||
|
||||
高频配置项:
|
||||
|
||||
- `storage.data_dir`
|
||||
- `embedding.dimension`(唯一公开维度控制项,provider 差异由插件内部映射)
|
||||
- `embedding.quantization_type`(当前仅支持 `int8`)
|
||||
- `retrieval.*`
|
||||
- `retrieval.sparse.*`
|
||||
- `episode.*`
|
||||
- `person_profile.*`
|
||||
- `memory.*`
|
||||
- `web.import.*`
|
||||
- `web.tuning.*`
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### SQLite 无 FTS5
|
||||
|
||||
如果环境中的 SQLite 未启用 `FTS5`,可关闭稀疏检索:
|
||||
|
||||
```toml
|
||||
[retrieval.sparse]
|
||||
enabled = false
|
||||
```
|
||||
|
||||
### 向量维度不一致
|
||||
|
||||
若日志提示当前 embedding 输出维度与既有向量库不一致,请先执行:
|
||||
|
||||
```bash
|
||||
python src/A_memorix/scripts/runtime_self_check.py --json
|
||||
```
|
||||
|
||||
必要时重建向量或调整 embedding 配置后再启动 A_Memorix。
|
||||
|
||||
## 许可证
|
||||
|
||||
默认许可证为 [AGPL-3.0](https://www.gnu.org/licenses/agpl-3.0)(见 `LICENSE`)。
|
||||
|
||||
针对 `Mai-with-u/MaiBot` 项目的 GPL 额外授权见 `LICENSE-MAIBOT-GPL.md`。
|
||||
|
||||
除上述额外授权外,其他使用场景仍适用 AGPL-3.0。
|
||||
|
||||
## 贡献说明
|
||||
|
||||
当前不接受 PR,只接受 issue。
|
||||
|
||||
**作者**: `A_Dawn`
|
||||
46
src/A_memorix/RELEASE_SUMMARY_1.0.0.md
Normal file
46
src/A_memorix/RELEASE_SUMMARY_1.0.0.md
Normal file
@@ -0,0 +1,46 @@
|
||||
# A_Memorix 1.0.0 发布总结
|
||||
|
||||
## 范围说明
|
||||
|
||||
- 目标版本:`0.7.0` -> `1.0.0`
|
||||
- 分析基线:`8fe8a0a`(`HEAD -> dev`, `origin/dev`, `origin/v0.7.0-LTSC`, `v0.7.0-LTSC`)
|
||||
- 本文中的工作树统计,基于 2026-03-06 生成本文与版本元数据修订之前的快照。
|
||||
- 本任务额外补充的发布元数据修订:`CHANGELOG.md`、`__init__.py`、`plugin.py`、`_manifest.json`、`README.md`、`CONFIG_REFERENCE.md`、`RELEASE_SUMMARY_1.0.0.md`。
|
||||
|
||||
## 本次升级主线
|
||||
|
||||
### 1. 运行时与插件架构重构
|
||||
|
||||
- `plugin.py` 大幅瘦身,原来堆在主入口里的初始化、调度、路由和检索运行时逻辑被拆分出去。
|
||||
- 新增 `core/runtime/*`,把生命周期、后台任务、请求去重、检索运行时构建做成独立层。
|
||||
- 新增 `core/config/plugin_config_schema.py`,配置 schema 与 section 描述从主入口解耦。
|
||||
|
||||
### 2. 查询链路升级为可编排形态
|
||||
|
||||
- `components/tools/knowledge_query_tool.py` 从单文件重逻辑改成 orchestrator + mode handler。
|
||||
- 新增 `query_modes_entity/person/relation` 与 `query_tool_orchestrator.py`,把实体、人设、关系、forward/time/episode 分支拆开。
|
||||
- 新增 `core/utils/aggregate_query_service.py`,支持 `search/time/episode` 并发执行和 Weighted RRF 混合结果。
|
||||
- 新增 `core/retrieval/graph_relation_recall.py`,对关系查询补图召回与路径证据。
|
||||
|
||||
### 3. Episode 情景记忆成为独立能力
|
||||
|
||||
- `core/storage/metadata_store.py` schema 升到 `SCHEMA_VERSION = 7`。
|
||||
- 新增 `episodes`、`episode_paragraphs`、`episode_pending_paragraphs`、`episode_rebuild_sources` 等表和索引。
|
||||
- 新增 `core/utils/episode_service.py`、`episode_segmentation_service.py`、`episode_retrieval_service.py`,打通 pending -> 分组 -> 语义切分 -> 落库 -> 检索。
|
||||
- `components/commands/query_command.py` 与 `server.py` 都新增了 `episode` / `aggregate` 相关入口和接口。
|
||||
|
||||
### 4. 运维面从“运行时兼容”转为“离线迁移 + 自检 + 调优”
|
||||
|
||||
- 新增 `scripts/release_vnext_migrate.py`,明确要求离线做 preflight / migrate / verify。
|
||||
- 新增 `core/utils/runtime_self_check.py` 与 `scripts/runtime_self_check.py`,启动与导入前都能真实探测 embedding 维度。
|
||||
- 新增 `core/utils/retrieval_tuning_manager.py` 与 `web/tuning.html`,提供 Web 检索调优中心。
|
||||
- `server.py` 新增 `/api/retrieval_tuning/*`、`/api/runtime/self_check*`、`/api/episodes/*` 等接口。
|
||||
|
||||
### 5. 数据语义与导入策略收紧
|
||||
|
||||
- `core/storage/knowledge_types.py`、`type_detection.py`、`summary_importer.py` 对知识类型做了重新建模。
|
||||
- `knowledge_type` 允许值扩展并规范到 `structured / narrative / factual / quote / mixed`。
|
||||
- README 与配置说明也已经切换到 vNext 语义,例如 `tool_search_mode` 不再强调 `legacy`,`embedding.quantization_type` 限定为 `int8/SQ8`。
|
||||
|
||||
## 6. 还有点想说的
|
||||
总而言之,感谢各位对于A_memorix的支持!本次V1.0.0的更新对于A_memorix来说是至关重要的里程碑!希望未来我们会走的更远!
|
||||
5
src/A_memorix/__init__.py
Normal file
5
src/A_memorix/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""A_Memorix - MaiBot 长期记忆子系统。"""
|
||||
|
||||
__version__ = "2.0.0"
|
||||
__author__ = "A_Dawn"
|
||||
__all__ = ["__version__"]
|
||||
1384
src/A_memorix/config_schema.json
Normal file
1384
src/A_memorix/config_schema.json
Normal file
File diff suppressed because it is too large
Load Diff
84
src/A_memorix/core/__init__.py
Normal file
84
src/A_memorix/core/__init__.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""核心模块 - 存储、嵌入、检索引擎"""
|
||||
|
||||
# 存储模块(已实现)
|
||||
from .storage import (
|
||||
VectorStore,
|
||||
GraphStore,
|
||||
MetadataStore,
|
||||
ImportStrategy,
|
||||
KnowledgeType,
|
||||
parse_import_strategy,
|
||||
resolve_stored_knowledge_type,
|
||||
detect_knowledge_type,
|
||||
select_import_strategy,
|
||||
should_extract_relations,
|
||||
get_type_display_name,
|
||||
)
|
||||
|
||||
# 嵌入模块(使用主程序 API)
|
||||
from .embedding import (
|
||||
EmbeddingAPIAdapter,
|
||||
create_embedding_api_adapter,
|
||||
)
|
||||
|
||||
# 检索模块(已实现)
|
||||
from .retrieval import (
|
||||
DualPathRetriever,
|
||||
RetrievalStrategy,
|
||||
RetrievalResult,
|
||||
DualPathRetrieverConfig,
|
||||
TemporalQueryOptions,
|
||||
FusionConfig,
|
||||
GraphRelationRecallConfig,
|
||||
RelationIntentConfig,
|
||||
PersonalizedPageRank,
|
||||
PageRankConfig,
|
||||
create_ppr_from_graph,
|
||||
DynamicThresholdFilter,
|
||||
ThresholdMethod,
|
||||
ThresholdConfig,
|
||||
SparseBM25Index,
|
||||
SparseBM25Config,
|
||||
)
|
||||
from .utils import (
|
||||
RelationWriteService,
|
||||
RelationWriteResult,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Storage
|
||||
"VectorStore",
|
||||
"GraphStore",
|
||||
"MetadataStore",
|
||||
"ImportStrategy",
|
||||
"KnowledgeType",
|
||||
"parse_import_strategy",
|
||||
"resolve_stored_knowledge_type",
|
||||
"detect_knowledge_type",
|
||||
"select_import_strategy",
|
||||
"should_extract_relations",
|
||||
"get_type_display_name",
|
||||
# Embedding
|
||||
"EmbeddingAPIAdapter",
|
||||
"create_embedding_api_adapter",
|
||||
# Retrieval
|
||||
"DualPathRetriever",
|
||||
"RetrievalStrategy",
|
||||
"RetrievalResult",
|
||||
"DualPathRetrieverConfig",
|
||||
"TemporalQueryOptions",
|
||||
"FusionConfig",
|
||||
"GraphRelationRecallConfig",
|
||||
"RelationIntentConfig",
|
||||
"PersonalizedPageRank",
|
||||
"PageRankConfig",
|
||||
"create_ppr_from_graph",
|
||||
"DynamicThresholdFilter",
|
||||
"ThresholdMethod",
|
||||
"ThresholdConfig",
|
||||
"SparseBM25Index",
|
||||
"SparseBM25Config",
|
||||
"RelationWriteService",
|
||||
"RelationWriteResult",
|
||||
]
|
||||
|
||||
18
src/A_memorix/core/embedding/__init__.py
Normal file
18
src/A_memorix/core/embedding/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
"""嵌入模块 - 向量生成与量化"""
|
||||
|
||||
# 新的 API 适配器(主程序嵌入 API)
|
||||
from .api_adapter import (
|
||||
EmbeddingAPIAdapter,
|
||||
create_embedding_api_adapter,
|
||||
)
|
||||
|
||||
from ..utils.quantization import QuantizationType
|
||||
|
||||
__all__ = [
|
||||
# 新的 API 适配器(推荐使用)
|
||||
"EmbeddingAPIAdapter",
|
||||
"create_embedding_api_adapter",
|
||||
# 量化
|
||||
"QuantizationType",
|
||||
]
|
||||
|
||||
434
src/A_memorix/core/embedding/api_adapter.py
Normal file
434
src/A_memorix/core/embedding/api_adapter.py
Normal file
@@ -0,0 +1,434 @@
|
||||
"""
|
||||
请求式嵌入 API 适配器。
|
||||
|
||||
统一记忆插件内部的维度控制语义:
|
||||
- 对外仅公开 `embedding.dimension`
|
||||
- 默认请求维度来自当前运行时的 canonical dimension
|
||||
- provider-specific 字段在适配层内部完成映射
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
import aiohttp
|
||||
import numpy as np
|
||||
import openai
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import config_manager
|
||||
from src.config.model_configs import APIProvider, ModelInfo
|
||||
from src.llm_models.exceptions import NetworkConnectionError
|
||||
from src.llm_models.model_client.base_client import EmbeddingRequest, client_registry
|
||||
|
||||
logger = get_logger("A_Memorix.EmbeddingAPIAdapter")
|
||||
|
||||
|
||||
class EmbeddingAPIAdapter:
|
||||
"""适配宿主 embedding 请求接口。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
batch_size: int = 32,
|
||||
max_concurrent: int = 5,
|
||||
default_dimension: int = 1024,
|
||||
enable_cache: bool = False,
|
||||
model_name: str = "auto",
|
||||
retry_config: Optional[dict] = None,
|
||||
) -> None:
|
||||
self.batch_size = max(1, int(batch_size))
|
||||
self.max_concurrent = max(1, int(max_concurrent))
|
||||
self.default_dimension = max(1, int(default_dimension))
|
||||
self.enable_cache = bool(enable_cache)
|
||||
self.model_name = str(model_name or "auto")
|
||||
|
||||
self.retry_config = retry_config or {}
|
||||
self.max_attempts = max(1, int(self.retry_config.get("max_attempts", 5)))
|
||||
self.max_wait_seconds = max(0.1, float(self.retry_config.get("max_wait_seconds", 40)))
|
||||
self.min_wait_seconds = max(0.1, float(self.retry_config.get("min_wait_seconds", 3)))
|
||||
self.backoff_multiplier = max(1.0, float(self.retry_config.get("backoff_multiplier", 3)))
|
||||
|
||||
self._dimension: Optional[int] = None
|
||||
self._dimension_detected = False
|
||||
self._total_encoded = 0
|
||||
self._total_errors = 0
|
||||
self._total_time = 0.0
|
||||
|
||||
logger.info(
|
||||
"EmbeddingAPIAdapter 初始化: "
|
||||
f"batch_size={self.batch_size}, "
|
||||
f"max_concurrent={self.max_concurrent}, "
|
||||
f"configured_dim={self.default_dimension}, "
|
||||
f"model={self.model_name}"
|
||||
)
|
||||
|
||||
def _get_current_model_config(self):
|
||||
return config_manager.get_model_config()
|
||||
|
||||
@staticmethod
|
||||
def _find_model_info(model_name: str) -> ModelInfo:
|
||||
model_cfg = config_manager.get_model_config()
|
||||
for item in model_cfg.models:
|
||||
if item.name == model_name:
|
||||
return item
|
||||
raise ValueError(f"未找到 embedding 模型: {model_name}")
|
||||
|
||||
@staticmethod
|
||||
def _find_provider(provider_name: str) -> APIProvider:
|
||||
model_cfg = config_manager.get_model_config()
|
||||
for item in model_cfg.api_providers:
|
||||
if item.name == provider_name:
|
||||
return item
|
||||
raise ValueError(f"未找到 embedding provider: {provider_name}")
|
||||
|
||||
def _resolve_candidate_model_names(self) -> List[str]:
|
||||
task_config = self._get_current_model_config().model_task_config.embedding
|
||||
configured = list(getattr(task_config, "model_list", []) or [])
|
||||
if self.model_name and self.model_name != "auto":
|
||||
return [self.model_name, *[name for name in configured if name != self.model_name]]
|
||||
return configured
|
||||
|
||||
def get_requested_dimension(self) -> int:
|
||||
if self._dimension is not None:
|
||||
return int(self._dimension)
|
||||
return int(self.default_dimension)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_dimension_override(dimensions: Optional[int]) -> Optional[int]:
|
||||
if dimensions is None:
|
||||
return None
|
||||
return max(1, int(dimensions))
|
||||
|
||||
def _resolve_canonical_dimension(self, dimensions: Optional[int] = None) -> int:
|
||||
override = self._normalize_dimension_override(dimensions)
|
||||
if override is not None:
|
||||
return override
|
||||
return self.get_requested_dimension()
|
||||
|
||||
@staticmethod
|
||||
def _strip_dimension_control_keys(extra_params: dict) -> dict:
|
||||
sanitized = dict(extra_params or {})
|
||||
sanitized.pop("dimensions", None)
|
||||
sanitized.pop("output_dimensionality", None)
|
||||
return sanitized
|
||||
|
||||
def _build_request_extra_params(
|
||||
self,
|
||||
*,
|
||||
api_provider: APIProvider,
|
||||
base_extra_params: dict,
|
||||
requested_dimension: Optional[int],
|
||||
include_dimension: bool,
|
||||
) -> dict:
|
||||
extra_params = self._strip_dimension_control_keys(base_extra_params)
|
||||
if not include_dimension or requested_dimension is None:
|
||||
return extra_params
|
||||
|
||||
client_type = str(getattr(api_provider, "client_type", "") or "").strip().lower()
|
||||
if client_type in {"gemini", "google"}:
|
||||
extra_params["output_dimensionality"] = int(requested_dimension)
|
||||
elif client_type == "openai":
|
||||
extra_params["dimensions"] = int(requested_dimension)
|
||||
return extra_params
|
||||
|
||||
@staticmethod
|
||||
def _validate_embedding_vector(embedding: Any, *, source: str) -> np.ndarray:
|
||||
array = np.asarray(embedding, dtype=np.float32)
|
||||
if array.ndim != 1:
|
||||
raise RuntimeError(f"{source} 返回的 embedding 维度非法: ndim={array.ndim}")
|
||||
if array.size <= 0:
|
||||
raise RuntimeError(f"{source} 返回了空 embedding")
|
||||
if not np.all(np.isfinite(array)):
|
||||
raise RuntimeError(f"{source} 返回了非有限 embedding 值")
|
||||
return array
|
||||
|
||||
async def _request_with_retry(self, client, model_info, text: str, extra_params: dict):
|
||||
retriable_exceptions = (
|
||||
openai.APIConnectionError,
|
||||
openai.APITimeoutError,
|
||||
aiohttp.ClientError,
|
||||
asyncio.TimeoutError,
|
||||
NetworkConnectionError,
|
||||
)
|
||||
|
||||
last_exc: Optional[BaseException] = None
|
||||
for attempt in range(1, self.max_attempts + 1):
|
||||
try:
|
||||
return await client.get_embedding(
|
||||
EmbeddingRequest(
|
||||
model_info=model_info,
|
||||
embedding_input=text,
|
||||
extra_params=extra_params,
|
||||
)
|
||||
)
|
||||
except retriable_exceptions as exc:
|
||||
last_exc = exc
|
||||
if attempt >= self.max_attempts:
|
||||
raise
|
||||
wait_seconds = min(
|
||||
self.max_wait_seconds,
|
||||
self.min_wait_seconds * (self.backoff_multiplier ** (attempt - 1)),
|
||||
)
|
||||
logger.warning(
|
||||
"Embedding 请求失败,重试 "
|
||||
f"{attempt}/{max(1, self.max_attempts - 1)},"
|
||||
f"{wait_seconds:.1f}s 后重试: {exc}"
|
||||
)
|
||||
await asyncio.sleep(wait_seconds)
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
if last_exc is not None:
|
||||
raise last_exc
|
||||
raise RuntimeError("Embedding 请求失败:未知错误")
|
||||
|
||||
async def _get_embedding_direct(
|
||||
self,
|
||||
text: str,
|
||||
dimensions: Optional[int] = None,
|
||||
*,
|
||||
include_dimension: bool = True,
|
||||
) -> Optional[List[float]]:
|
||||
candidate_names = self._resolve_candidate_model_names()
|
||||
if not candidate_names:
|
||||
raise RuntimeError("embedding 任务未配置模型")
|
||||
|
||||
last_exc: Optional[BaseException] = None
|
||||
for candidate_name in candidate_names:
|
||||
try:
|
||||
model_info = self._find_model_info(candidate_name)
|
||||
api_provider = self._find_provider(model_info.api_provider)
|
||||
client = client_registry.get_client_class_instance(api_provider, force_new=True)
|
||||
|
||||
requested_dimension = self._resolve_canonical_dimension(dimensions) if include_dimension else None
|
||||
extra_params = self._build_request_extra_params(
|
||||
api_provider=api_provider,
|
||||
base_extra_params=dict(getattr(model_info, "extra_params", {}) or {}),
|
||||
requested_dimension=requested_dimension,
|
||||
include_dimension=include_dimension,
|
||||
)
|
||||
|
||||
response = await self._request_with_retry(
|
||||
client=client,
|
||||
model_info=model_info,
|
||||
text=text,
|
||||
extra_params=extra_params,
|
||||
)
|
||||
embedding = getattr(response, "embedding", None)
|
||||
if embedding is None:
|
||||
raise RuntimeError(f"模型 {candidate_name} 未返回 embedding")
|
||||
vector = self._validate_embedding_vector(
|
||||
embedding,
|
||||
source=f"embedding 模型 {candidate_name}",
|
||||
)
|
||||
return vector.tolist()
|
||||
except Exception as exc:
|
||||
last_exc = exc
|
||||
logger.warning(f"embedding 模型 {candidate_name} 请求失败: {exc}")
|
||||
|
||||
if last_exc is not None:
|
||||
logger.error(f"通过直接 Client 获取 Embedding 失败: {last_exc}")
|
||||
return None
|
||||
|
||||
async def _detect_dimension(self) -> int:
|
||||
if self._dimension_detected and self._dimension is not None:
|
||||
return self._dimension
|
||||
|
||||
logger.info("正在检测嵌入模型维度...")
|
||||
try:
|
||||
target_dim = self.default_dimension
|
||||
logger.debug(f"尝试请求指定维度: {target_dim}")
|
||||
test_embedding = await self._get_embedding_direct("test", dimensions=target_dim)
|
||||
if test_embedding and isinstance(test_embedding, list):
|
||||
detected_dim = len(test_embedding)
|
||||
if detected_dim == target_dim:
|
||||
logger.info(f"嵌入维度检测成功 (匹配 configured/requested): {detected_dim}")
|
||||
else:
|
||||
logger.warning(
|
||||
f"requested_dimension={target_dim} 但模型返回 detected_dimension={detected_dim},将使用真实输出维度"
|
||||
)
|
||||
self._dimension = detected_dim
|
||||
self._dimension_detected = True
|
||||
return detected_dim
|
||||
except Exception as exc:
|
||||
logger.debug(f"带维度参数探测失败: {exc},尝试不带维度参数探测")
|
||||
|
||||
try:
|
||||
test_embedding = await self._get_embedding_direct("test", include_dimension=False)
|
||||
if test_embedding and isinstance(test_embedding, list):
|
||||
detected_dim = len(test_embedding)
|
||||
self._dimension = detected_dim
|
||||
self._dimension_detected = True
|
||||
logger.info(f"嵌入维度检测成功 (自然维度): {detected_dim}")
|
||||
return detected_dim
|
||||
logger.warning(f"嵌入维度检测失败,使用 configured_dimension: {self.default_dimension}")
|
||||
except Exception as exc:
|
||||
logger.error(f"嵌入维度检测异常: {exc},使用 configured_dimension: {self.default_dimension}")
|
||||
|
||||
self._dimension = self.default_dimension
|
||||
self._dimension_detected = True
|
||||
return self.default_dimension
|
||||
|
||||
async def encode(
|
||||
self,
|
||||
texts: Union[str, List[str]],
|
||||
batch_size: Optional[int] = None,
|
||||
show_progress: bool = False,
|
||||
normalize: bool = True,
|
||||
dimensions: Optional[int] = None,
|
||||
) -> np.ndarray:
|
||||
del show_progress
|
||||
del normalize
|
||||
|
||||
start_time = time.time()
|
||||
if dimensions is None:
|
||||
target_dim = int(await self._detect_dimension())
|
||||
requested_dimension = self._resolve_canonical_dimension()
|
||||
else:
|
||||
target_dim = self._resolve_canonical_dimension(dimensions)
|
||||
requested_dimension = target_dim
|
||||
|
||||
if isinstance(texts, str):
|
||||
normalized_texts = [texts]
|
||||
single_input = True
|
||||
else:
|
||||
normalized_texts = list(texts or [])
|
||||
single_input = False
|
||||
|
||||
if not normalized_texts:
|
||||
empty = np.zeros((0, target_dim), dtype=np.float32)
|
||||
return empty[0] if single_input else empty
|
||||
|
||||
if batch_size is None:
|
||||
batch_size = self.batch_size
|
||||
|
||||
try:
|
||||
embeddings = await self._encode_batch_internal(
|
||||
normalized_texts,
|
||||
batch_size=max(1, int(batch_size)),
|
||||
dimensions=requested_dimension,
|
||||
)
|
||||
if embeddings.ndim == 1:
|
||||
embeddings = embeddings.reshape(1, -1)
|
||||
self._total_encoded += len(normalized_texts)
|
||||
elapsed = time.time() - start_time
|
||||
self._total_time += elapsed
|
||||
logger.debug(
|
||||
"编码完成: "
|
||||
f"{len(normalized_texts)} 个文本, "
|
||||
f"耗时 {elapsed:.2f}s, "
|
||||
f"平均 {elapsed / max(1, len(normalized_texts)):.3f}s/文本"
|
||||
)
|
||||
return embeddings[0] if single_input else embeddings
|
||||
except Exception as exc:
|
||||
self._total_errors += 1
|
||||
logger.error(f"编码失败: {exc}")
|
||||
raise RuntimeError(f"embedding encode failed: {exc}") from exc
|
||||
|
||||
async def _encode_batch_internal(
|
||||
self,
|
||||
texts: List[str],
|
||||
batch_size: int,
|
||||
dimensions: Optional[int] = None,
|
||||
) -> np.ndarray:
|
||||
all_embeddings: List[np.ndarray] = []
|
||||
for offset in range(0, len(texts), batch_size):
|
||||
batch = texts[offset : offset + batch_size]
|
||||
semaphore = asyncio.Semaphore(self.max_concurrent)
|
||||
|
||||
async def encode_with_semaphore(text: str, index: int):
|
||||
async with semaphore:
|
||||
embedding = await self._get_embedding_direct(text, dimensions=dimensions)
|
||||
if embedding is None:
|
||||
raise RuntimeError(f"文本 {index} 编码失败:embedding 返回为空")
|
||||
vector = self._validate_embedding_vector(
|
||||
embedding,
|
||||
source=f"文本 {index}",
|
||||
)
|
||||
return index, vector
|
||||
|
||||
tasks = [
|
||||
encode_with_semaphore(text, offset + index)
|
||||
for index, text in enumerate(batch)
|
||||
]
|
||||
results = await asyncio.gather(*tasks)
|
||||
results.sort(key=lambda item: item[0])
|
||||
all_embeddings.extend(emb for _, emb in results)
|
||||
|
||||
return np.array(all_embeddings, dtype=np.float32)
|
||||
|
||||
async def encode_batch(
|
||||
self,
|
||||
texts: List[str],
|
||||
batch_size: Optional[int] = None,
|
||||
num_workers: Optional[int] = None,
|
||||
show_progress: bool = False,
|
||||
dimensions: Optional[int] = None,
|
||||
) -> np.ndarray:
|
||||
del show_progress
|
||||
if num_workers is not None:
|
||||
previous = self.max_concurrent
|
||||
self.max_concurrent = max(1, int(num_workers))
|
||||
try:
|
||||
return await self.encode(texts, batch_size=batch_size, dimensions=dimensions)
|
||||
finally:
|
||||
self.max_concurrent = previous
|
||||
return await self.encode(texts, batch_size=batch_size, dimensions=dimensions)
|
||||
|
||||
def get_embedding_dimension(self) -> int:
|
||||
if self._dimension is not None:
|
||||
return self._dimension
|
||||
logger.warning(f"维度尚未检测,返回 configured_dimension: {self.default_dimension}")
|
||||
return self.default_dimension
|
||||
|
||||
def get_model_info(self) -> dict:
|
||||
effective_dimension = self.get_embedding_dimension()
|
||||
return {
|
||||
"model_name": self.model_name,
|
||||
"dimension": effective_dimension,
|
||||
"configured_dimension": int(self.default_dimension),
|
||||
"requested_dimension": int(self.get_requested_dimension()),
|
||||
"detected_dimension": int(self._dimension or 0),
|
||||
"dimension_detected": self._dimension_detected,
|
||||
"batch_size": self.batch_size,
|
||||
"max_concurrent": self.max_concurrent,
|
||||
"total_encoded": self._total_encoded,
|
||||
"total_errors": self._total_errors,
|
||||
"avg_time_per_text": self._total_time / self._total_encoded if self._total_encoded else 0.0,
|
||||
}
|
||||
|
||||
def get_statistics(self) -> dict:
|
||||
return self.get_model_info()
|
||||
|
||||
@property
|
||||
def is_model_loaded(self) -> bool:
|
||||
return True
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
"EmbeddingAPIAdapter("
|
||||
f"configured={self.default_dimension}, "
|
||||
f"requested={self.get_requested_dimension()}, "
|
||||
f"detected={self._dimension or 0}, "
|
||||
f"encoded={self._total_encoded})"
|
||||
)
|
||||
|
||||
|
||||
def create_embedding_api_adapter(
|
||||
batch_size: int = 32,
|
||||
max_concurrent: int = 5,
|
||||
default_dimension: int = 1024,
|
||||
enable_cache: bool = False,
|
||||
model_name: str = "auto",
|
||||
retry_config: Optional[dict] = None,
|
||||
) -> EmbeddingAPIAdapter:
|
||||
return EmbeddingAPIAdapter(
|
||||
batch_size=batch_size,
|
||||
max_concurrent=max_concurrent,
|
||||
default_dimension=default_dimension,
|
||||
enable_cache=enable_cache,
|
||||
model_name=model_name,
|
||||
retry_config=retry_config,
|
||||
)
|
||||
510
src/A_memorix/core/embedding/manager.py
Normal file
510
src/A_memorix/core/embedding/manager.py
Normal file
@@ -0,0 +1,510 @@
|
||||
"""
|
||||
嵌入管理器
|
||||
|
||||
负责嵌入模型的加载、缓存和批量生成。
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import pickle
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union, List, Dict, Any, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
from sentence_transformers import SentenceTransformer
|
||||
HAS_SENTENCE_TRANSFORMERS = True
|
||||
except ImportError:
|
||||
HAS_SENTENCE_TRANSFORMERS = False
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from .presets import (
|
||||
EmbeddingModelConfig,
|
||||
get_custom_config,
|
||||
validate_config_compatibility,
|
||||
are_models_compatible,
|
||||
)
|
||||
from ..utils.quantization import QuantizationType
|
||||
|
||||
logger = get_logger("A_Memorix.EmbeddingManager")
|
||||
|
||||
|
||||
class EmbeddingManager:
|
||||
"""
|
||||
嵌入管理器
|
||||
|
||||
功能:
|
||||
- 模型加载与缓存
|
||||
- 批量生成嵌入
|
||||
- 多线程/多进程支持
|
||||
- 模型一致性检查
|
||||
- 智能分批
|
||||
|
||||
参数:
|
||||
config: 模型配置
|
||||
cache_dir: 缓存目录
|
||||
enable_cache: 是否启用缓存
|
||||
num_workers: 工作线程数
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: EmbeddingModelConfig,
|
||||
cache_dir: Optional[Union[str, Path]] = None,
|
||||
enable_cache: bool = True,
|
||||
num_workers: int = 1,
|
||||
):
|
||||
"""
|
||||
初始化嵌入管理器
|
||||
|
||||
Args:
|
||||
config: 模型配置
|
||||
cache_dir: 缓存目录
|
||||
enable_cache: 是否启用缓存
|
||||
num_workers: 工作线程数
|
||||
"""
|
||||
if not HAS_SENTENCE_TRANSFORMERS:
|
||||
raise ImportError(
|
||||
"sentence-transformers 未安装,请安装: "
|
||||
"pip install sentence-transformers"
|
||||
)
|
||||
|
||||
self.config = config
|
||||
self.cache_dir = Path(cache_dir) if cache_dir else None
|
||||
self.enable_cache = enable_cache
|
||||
self.num_workers = max(1, num_workers)
|
||||
|
||||
# 模型实例
|
||||
self._model: Optional[SentenceTransformer] = None
|
||||
self._model_lock = threading.Lock()
|
||||
|
||||
# 缓存
|
||||
self._embedding_cache: Dict[str, np.ndarray] = {}
|
||||
self._cache_lock = threading.Lock()
|
||||
|
||||
# 统计
|
||||
self._total_encoded = 0
|
||||
self._cache_hits = 0
|
||||
self._cache_misses = 0
|
||||
|
||||
logger.info(
|
||||
f"EmbeddingManager 初始化: model={config.model_name}, "
|
||||
f"dim={config.dimension}, workers={num_workers}"
|
||||
)
|
||||
|
||||
def load_model(self) -> None:
|
||||
"""加载模型(懒加载)"""
|
||||
if self._model is not None:
|
||||
return
|
||||
|
||||
with self._model_lock:
|
||||
# 双重检查
|
||||
if self._model is not None:
|
||||
return
|
||||
|
||||
logger.info(f"正在加载模型: {self.config.model_name}")
|
||||
|
||||
try:
|
||||
# 构建模型参数
|
||||
model_kwargs = {}
|
||||
if self.config.cache_dir:
|
||||
model_kwargs["cache_folder"] = self.config.cache_dir
|
||||
|
||||
# 加载模型
|
||||
self._model = SentenceTransformer(
|
||||
self.config.model_path,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
logger.info(f"模型加载成功: {self.config.model_name}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"模型加载失败: {e}")
|
||||
raise
|
||||
|
||||
def encode(
|
||||
self,
|
||||
texts: Union[str, List[str]],
|
||||
batch_size: Optional[int] = None,
|
||||
show_progress: bool = False,
|
||||
normalize: bool = True,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
生成文本嵌入
|
||||
|
||||
Args:
|
||||
texts: 文本或文本列表
|
||||
batch_size: 批次大小(默认使用配置值)
|
||||
show_progress: 是否显示进度条
|
||||
normalize: 是否归一化
|
||||
|
||||
Returns:
|
||||
嵌入向量 (N x D)
|
||||
"""
|
||||
# 确保模型已加载
|
||||
self.load_model()
|
||||
|
||||
# 标准化输入
|
||||
if isinstance(texts, str):
|
||||
texts = [texts]
|
||||
single_input = True
|
||||
else:
|
||||
single_input = False
|
||||
|
||||
if not texts:
|
||||
return np.zeros((0, self.config.dimension), dtype=np.float32)
|
||||
|
||||
# 使用配置的批次大小
|
||||
if batch_size is None:
|
||||
batch_size = self.config.batch_size
|
||||
|
||||
# 生成嵌入
|
||||
try:
|
||||
embeddings = self._model.encode(
|
||||
texts,
|
||||
batch_size=batch_size,
|
||||
show_progress_bar=show_progress,
|
||||
normalize_embeddings=normalize and self.config.normalization,
|
||||
convert_to_numpy=True,
|
||||
)
|
||||
|
||||
# 确保是2D数组
|
||||
if embeddings.ndim == 1:
|
||||
embeddings = embeddings.reshape(1, -1)
|
||||
|
||||
self._total_encoded += len(texts)
|
||||
|
||||
# 如果是单个输入,返回1D数组
|
||||
if single_input:
|
||||
return embeddings[0]
|
||||
|
||||
return embeddings
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"生成嵌入失败: {e}")
|
||||
raise
|
||||
|
||||
def encode_batch(
|
||||
self,
|
||||
texts: List[str],
|
||||
batch_size: Optional[int] = None,
|
||||
num_workers: Optional[int] = None,
|
||||
show_progress: bool = False,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
批量生成嵌入(多线程优化)
|
||||
|
||||
Args:
|
||||
texts: 文本列表
|
||||
batch_size: 批次大小
|
||||
num_workers: 工作线程数(默认使用初始化时的值)
|
||||
show_progress: 是否显示进度条
|
||||
|
||||
Returns:
|
||||
嵌入向量 (N x D)
|
||||
"""
|
||||
if not texts:
|
||||
return np.zeros((0, self.config.dimension), dtype=np.float32)
|
||||
|
||||
# 单线程模式
|
||||
num_workers = num_workers if num_workers is not None else self.num_workers
|
||||
if num_workers == 1:
|
||||
return self.encode(texts, batch_size=batch_size, show_progress=show_progress)
|
||||
|
||||
# 多线程模式
|
||||
logger.info(f"使用 {num_workers} 个线程生成 {len(texts)} 个嵌入")
|
||||
|
||||
# 分批
|
||||
batch_size = batch_size or self.config.batch_size
|
||||
batches = [
|
||||
texts[i:i + batch_size]
|
||||
for i in range(0, len(texts), batch_size)
|
||||
]
|
||||
|
||||
# 多线程生成
|
||||
all_embeddings = []
|
||||
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
||||
# 提交任务
|
||||
future_to_batch = {
|
||||
executor.submit(
|
||||
self.encode,
|
||||
batch,
|
||||
batch_size,
|
||||
False, # 不显示进度条(多线程时会混乱)
|
||||
): i
|
||||
for i, batch in enumerate(batches)
|
||||
}
|
||||
|
||||
# 收集结果
|
||||
for future in as_completed(future_to_batch):
|
||||
batch_idx = future_to_batch[future]
|
||||
try:
|
||||
embeddings = future.result()
|
||||
all_embeddings.append((batch_idx, embeddings))
|
||||
except Exception as e:
|
||||
logger.error(f"批次 {batch_idx} 生成嵌入失败: {e}")
|
||||
raise
|
||||
|
||||
# 按顺序合并
|
||||
all_embeddings.sort(key=lambda x: x[0])
|
||||
final_embeddings = np.concatenate([emb for _, emb in all_embeddings], axis=0)
|
||||
|
||||
return final_embeddings
|
||||
|
||||
def encode_with_cache(
|
||||
self,
|
||||
texts: List[str],
|
||||
batch_size: Optional[int] = None,
|
||||
show_progress: bool = False,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
生成嵌入(带缓存)
|
||||
|
||||
Args:
|
||||
texts: 文本列表
|
||||
batch_size: 批次大小
|
||||
show_progress: 是否显示进度条
|
||||
|
||||
Returns:
|
||||
嵌入向量 (N x D)
|
||||
"""
|
||||
if not self.enable_cache:
|
||||
return self.encode(texts, batch_size, show_progress)
|
||||
|
||||
# 分离缓存命中和未命中的文本
|
||||
cached_embeddings = []
|
||||
uncached_texts = []
|
||||
uncached_indices = []
|
||||
|
||||
for i, text in enumerate(texts):
|
||||
cache_key = self._get_cache_key(text)
|
||||
|
||||
with self._cache_lock:
|
||||
if cache_key in self._embedding_cache:
|
||||
cached_embeddings.append((i, self._embedding_cache[cache_key]))
|
||||
self._cache_hits += 1
|
||||
else:
|
||||
uncached_texts.append(text)
|
||||
uncached_indices.append(i)
|
||||
self._cache_misses += 1
|
||||
|
||||
# 生成未缓存的嵌入
|
||||
if uncached_texts:
|
||||
new_embeddings = self.encode(
|
||||
uncached_texts,
|
||||
batch_size,
|
||||
show_progress,
|
||||
)
|
||||
|
||||
# 更新缓存
|
||||
with self._cache_lock:
|
||||
for text, embedding in zip(uncached_texts, new_embeddings):
|
||||
cache_key = self._get_cache_key(text)
|
||||
self._embedding_cache[cache_key] = embedding.copy()
|
||||
|
||||
# 合并结果
|
||||
for idx, embedding in zip(uncached_indices, new_embeddings):
|
||||
cached_embeddings.append((idx, embedding))
|
||||
|
||||
# 按原始顺序排序
|
||||
cached_embeddings.sort(key=lambda x: x[0])
|
||||
final_embeddings = np.array([emb for _, emb in cached_embeddings])
|
||||
|
||||
return final_embeddings
|
||||
|
||||
def save_cache(self, cache_path: Optional[Union[str, Path]] = None) -> None:
|
||||
"""
|
||||
保存缓存到磁盘
|
||||
|
||||
Args:
|
||||
cache_path: 缓存文件路径(默认使用cache_dir/embeddings_cache.pkl)
|
||||
"""
|
||||
if cache_path is None:
|
||||
if self.cache_dir is None:
|
||||
raise ValueError("未指定缓存目录")
|
||||
cache_path = self.cache_dir / "embeddings_cache.pkl"
|
||||
|
||||
cache_path = Path(cache_path)
|
||||
cache_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with self._cache_lock:
|
||||
with open(cache_path, "wb") as f:
|
||||
pickle.dump(self._embedding_cache, f)
|
||||
|
||||
logger.info(f"缓存已保存: {cache_path} ({len(self._embedding_cache)} 条)")
|
||||
|
||||
def load_cache(self, cache_path: Optional[Union[str, Path]] = None) -> None:
|
||||
"""
|
||||
从磁盘加载缓存
|
||||
|
||||
Args:
|
||||
cache_path: 缓存文件路径(默认使用cache_dir/embeddings_cache.pkl)
|
||||
"""
|
||||
if cache_path is None:
|
||||
if self.cache_dir is None:
|
||||
raise ValueError("未指定缓存目录")
|
||||
cache_path = self.cache_dir / "embeddings_cache.pkl"
|
||||
|
||||
cache_path = Path(cache_path)
|
||||
if not cache_path.exists():
|
||||
logger.warning(f"缓存文件不存在: {cache_path}")
|
||||
return
|
||||
|
||||
with self._cache_lock:
|
||||
with open(cache_path, "rb") as f:
|
||||
self._embedding_cache = pickle.load(f)
|
||||
|
||||
logger.info(f"缓存已加载: {cache_path} ({len(self._embedding_cache)} 条)")
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
"""清空缓存"""
|
||||
with self._cache_lock:
|
||||
count = len(self._embedding_cache)
|
||||
self._embedding_cache.clear()
|
||||
logger.info(f"已清空缓存: {count} 条")
|
||||
|
||||
def check_model_consistency(
|
||||
self,
|
||||
stored_embeddings: np.ndarray,
|
||||
sample_texts: List[str] = None,
|
||||
) -> Tuple[bool, str]:
|
||||
"""
|
||||
检查模型一致性
|
||||
|
||||
Args:
|
||||
stored_embeddings: 存储的嵌入向量
|
||||
sample_texts: 样本文本(用于重新生成对比)
|
||||
|
||||
Returns:
|
||||
(是否一致, 详细信息)
|
||||
"""
|
||||
# 检查维度
|
||||
if stored_embeddings.shape[1] != self.config.dimension:
|
||||
return False, f"维度不匹配: 期望 {self.config.dimension}, 实际 {stored_embeddings.shape[1]}"
|
||||
|
||||
# 如果提供了样本文本,重新生成并比较
|
||||
if sample_texts:
|
||||
try:
|
||||
new_embeddings = self.encode(sample_texts[:5]) # 只比较前5个
|
||||
|
||||
# 计算相似度
|
||||
similarities = np.dot(
|
||||
stored_embeddings[:5],
|
||||
new_embeddings.T,
|
||||
).diagonal()
|
||||
|
||||
# 检查相似度
|
||||
if np.mean(similarities) < 0.95:
|
||||
return False, f"模型可能已更改,平均相似度: {np.mean(similarities):.3f}"
|
||||
|
||||
return True, f"模型一致,平均相似度: {np.mean(similarities):.3f}"
|
||||
|
||||
except Exception as e:
|
||||
return False, f"一致性检查失败: {e}"
|
||||
|
||||
return True, "维度匹配"
|
||||
|
||||
def get_model_info(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取模型信息
|
||||
|
||||
Returns:
|
||||
模型信息字典
|
||||
"""
|
||||
return {
|
||||
"model_name": self.config.model_name,
|
||||
"dimension": self.config.dimension,
|
||||
"max_seq_length": self.config.max_seq_length,
|
||||
"batch_size": self.config.batch_size,
|
||||
"normalization": self.config.normalization,
|
||||
"pooling": self.config.pooling,
|
||||
"model_loaded": self._model is not None,
|
||||
"cache_enabled": self.enable_cache,
|
||||
"cache_size": len(self._embedding_cache),
|
||||
"total_encoded": self._total_encoded,
|
||||
"cache_hits": self._cache_hits,
|
||||
"cache_misses": self._cache_misses,
|
||||
}
|
||||
|
||||
def get_embedding_dimension(self) -> int:
|
||||
"""获取嵌入维度"""
|
||||
return self.config.dimension
|
||||
|
||||
def _get_cache_key(self, text: str) -> str:
|
||||
"""
|
||||
生成缓存键
|
||||
|
||||
Args:
|
||||
text: 文本内容
|
||||
|
||||
Returns:
|
||||
缓存键(SHA256哈希)
|
||||
"""
|
||||
return hashlib.sha256(text.encode("utf-8")).hexdigest()
|
||||
|
||||
@property
|
||||
def is_model_loaded(self) -> bool:
|
||||
"""模型是否已加载"""
|
||||
return self._model is not None
|
||||
|
||||
@property
|
||||
def cache_hit_rate(self) -> float:
|
||||
"""缓存命中率"""
|
||||
total = self._cache_hits + self._cache_misses
|
||||
if total == 0:
|
||||
return 0.0
|
||||
return self._cache_hits / total
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"EmbeddingManager(model={self.config.model_name}, "
|
||||
f"dim={self.config.dimension}, "
|
||||
f"loaded={self.is_model_loaded}, "
|
||||
f"cache={len(self._embedding_cache)})"
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
def create_embedding_manager_from_config(
|
||||
model_name: str,
|
||||
model_path: str,
|
||||
dimension: int,
|
||||
cache_dir: Optional[Union[str, Path]] = None,
|
||||
enable_cache: bool = True,
|
||||
num_workers: int = 1,
|
||||
**config_kwargs,
|
||||
) -> EmbeddingManager:
|
||||
"""
|
||||
从自定义配置创建嵌入管理器
|
||||
|
||||
Args:
|
||||
model_name: 模型名称
|
||||
model_path: HuggingFace模型路径
|
||||
dimension: 输出维度
|
||||
cache_dir: 缓存目录
|
||||
enable_cache: 是否启用缓存
|
||||
num_workers: 工作线程数
|
||||
**config_kwargs: 其他配置参数
|
||||
|
||||
Returns:
|
||||
嵌入管理器实例
|
||||
"""
|
||||
# 创建自定义配置
|
||||
config = get_custom_config(
|
||||
model_name=model_name,
|
||||
model_path=model_path,
|
||||
dimension=dimension,
|
||||
cache_dir=cache_dir,
|
||||
**config_kwargs,
|
||||
)
|
||||
|
||||
# 创建管理器
|
||||
return EmbeddingManager(
|
||||
config=config,
|
||||
cache_dir=cache_dir,
|
||||
enable_cache=enable_cache,
|
||||
num_workers=num_workers,
|
||||
)
|
||||
72
src/A_memorix/core/embedding/presets.py
Normal file
72
src/A_memorix/core/embedding/presets.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""
|
||||
嵌入模型配置模块
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Dict, Any, Union
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmbeddingModelConfig:
|
||||
"""
|
||||
嵌入模型配置
|
||||
|
||||
属性:
|
||||
model_name: 模型描述名称
|
||||
model_path: 实际加载路径(Local or HF)
|
||||
dimension: 嵌入向量维度
|
||||
max_seq_length: 最大序列长度
|
||||
batch_size: 编码批次大小
|
||||
model_size_mb: 估计显存占用
|
||||
description: 模型说明
|
||||
normalization: 是否自动归一化
|
||||
pooling: 池化策略 (mean, cls, max)
|
||||
cache_dir: 模型缓存目录
|
||||
"""
|
||||
|
||||
model_name: str
|
||||
model_path: str
|
||||
dimension: int
|
||||
max_seq_length: int = 512
|
||||
batch_size: int = 32
|
||||
model_size_mb: int = 100
|
||||
description: str = ""
|
||||
normalization: bool = True
|
||||
pooling: str = "mean"
|
||||
cache_dir: Optional[Union[str, Path]] = None
|
||||
|
||||
|
||||
def validate_config_compatibility(
|
||||
config1: EmbeddingModelConfig, config2: EmbeddingModelConfig
|
||||
) -> bool:
|
||||
"""检查两个配置是否兼容(主要看维度)"""
|
||||
return config1.dimension == config2.dimension
|
||||
|
||||
|
||||
def are_models_compatible(
|
||||
config1: EmbeddingModelConfig, config2: EmbeddingModelConfig
|
||||
) -> bool:
|
||||
"""检查模型是否完全相同(用于热切换判断)"""
|
||||
return (
|
||||
config1.model_path == config2.model_path
|
||||
and config1.dimension == config2.dimension
|
||||
and config1.pooling == config2.pooling
|
||||
)
|
||||
|
||||
|
||||
def get_custom_config(
|
||||
model_name: str,
|
||||
model_path: str,
|
||||
dimension: int,
|
||||
cache_dir: Optional[Union[str, Path]] = None,
|
||||
**kwargs,
|
||||
) -> EmbeddingModelConfig:
|
||||
"""创建自定义模型配置"""
|
||||
return EmbeddingModelConfig(
|
||||
model_name=model_name,
|
||||
model_path=model_path,
|
||||
dimension=dimension,
|
||||
cache_dir=cache_dir,
|
||||
**kwargs,
|
||||
)
|
||||
54
src/A_memorix/core/retrieval/__init__.py
Normal file
54
src/A_memorix/core/retrieval/__init__.py
Normal file
@@ -0,0 +1,54 @@
|
||||
"""检索模块 - 双路检索与排序"""
|
||||
|
||||
from .dual_path import (
|
||||
DualPathRetriever,
|
||||
RetrievalStrategy,
|
||||
RetrievalResult,
|
||||
DualPathRetrieverConfig,
|
||||
TemporalQueryOptions,
|
||||
FusionConfig,
|
||||
RelationIntentConfig,
|
||||
)
|
||||
from .pagerank import (
|
||||
PersonalizedPageRank,
|
||||
PageRankConfig,
|
||||
create_ppr_from_graph,
|
||||
)
|
||||
from .threshold import (
|
||||
DynamicThresholdFilter,
|
||||
ThresholdMethod,
|
||||
ThresholdConfig,
|
||||
)
|
||||
from .sparse_bm25 import (
|
||||
SparseBM25Index,
|
||||
SparseBM25Config,
|
||||
)
|
||||
from .graph_relation_recall import (
|
||||
GraphRelationRecallConfig,
|
||||
GraphRelationRecallService,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# DualPathRetriever
|
||||
"DualPathRetriever",
|
||||
"RetrievalStrategy",
|
||||
"RetrievalResult",
|
||||
"DualPathRetrieverConfig",
|
||||
"TemporalQueryOptions",
|
||||
"FusionConfig",
|
||||
"RelationIntentConfig",
|
||||
# PersonalizedPageRank
|
||||
"PersonalizedPageRank",
|
||||
"PageRankConfig",
|
||||
"create_ppr_from_graph",
|
||||
# DynamicThresholdFilter
|
||||
"DynamicThresholdFilter",
|
||||
"ThresholdMethod",
|
||||
"ThresholdConfig",
|
||||
# Sparse BM25
|
||||
"SparseBM25Index",
|
||||
"SparseBM25Config",
|
||||
# Graph relation recall
|
||||
"GraphRelationRecallConfig",
|
||||
"GraphRelationRecallService",
|
||||
]
|
||||
1871
src/A_memorix/core/retrieval/dual_path.py
Normal file
1871
src/A_memorix/core/retrieval/dual_path.py
Normal file
File diff suppressed because it is too large
Load Diff
272
src/A_memorix/core/retrieval/graph_relation_recall.py
Normal file
272
src/A_memorix/core/retrieval/graph_relation_recall.py
Normal file
@@ -0,0 +1,272 @@
|
||||
"""Graph-assisted relation candidate recall for relation-oriented queries."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Sequence, Set
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("A_Memorix.GraphRelationRecall")
|
||||
|
||||
|
||||
@dataclass
|
||||
class GraphRelationRecallConfig:
|
||||
"""Configuration for controlled graph relation recall."""
|
||||
|
||||
enabled: bool = True
|
||||
candidate_k: int = 24
|
||||
max_hop: int = 1
|
||||
allow_two_hop_pair: bool = True
|
||||
max_paths: int = 4
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.enabled = bool(self.enabled)
|
||||
self.candidate_k = max(1, int(self.candidate_k))
|
||||
self.max_hop = max(1, int(self.max_hop))
|
||||
self.allow_two_hop_pair = bool(self.allow_two_hop_pair)
|
||||
self.max_paths = max(1, int(self.max_paths))
|
||||
|
||||
|
||||
@dataclass
|
||||
class GraphRelationCandidate:
|
||||
"""A graph-derived relation candidate before retriever-side fusion."""
|
||||
|
||||
hash_value: str
|
||||
subject: str
|
||||
predicate: str
|
||||
object: str
|
||||
confidence: float
|
||||
graph_seed_entities: List[str]
|
||||
graph_hops: int
|
||||
graph_candidate_type: str
|
||||
supporting_paragraph_count: int
|
||||
|
||||
def to_payload(self) -> Dict[str, Any]:
|
||||
content = f"{self.subject} {self.predicate} {self.object}"
|
||||
return {
|
||||
"hash": self.hash_value,
|
||||
"content": content,
|
||||
"subject": self.subject,
|
||||
"predicate": self.predicate,
|
||||
"object": self.object,
|
||||
"confidence": self.confidence,
|
||||
"graph_seed_entities": list(self.graph_seed_entities),
|
||||
"graph_hops": int(self.graph_hops),
|
||||
"graph_candidate_type": self.graph_candidate_type,
|
||||
"supporting_paragraph_count": int(self.supporting_paragraph_count),
|
||||
}
|
||||
|
||||
|
||||
class GraphRelationRecallService:
|
||||
"""Collect relation candidates from the entity graph in a controlled way."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
graph_store: Any,
|
||||
metadata_store: Any,
|
||||
config: Optional[GraphRelationRecallConfig] = None,
|
||||
) -> None:
|
||||
self.graph_store = graph_store
|
||||
self.metadata_store = metadata_store
|
||||
self.config = config or GraphRelationRecallConfig()
|
||||
|
||||
def recall(
|
||||
self,
|
||||
*,
|
||||
seed_entities: Sequence[str],
|
||||
) -> List[GraphRelationCandidate]:
|
||||
if not self.config.enabled:
|
||||
return []
|
||||
if self.graph_store is None or self.metadata_store is None:
|
||||
return []
|
||||
|
||||
seeds = self._normalize_seed_entities(seed_entities)
|
||||
if not seeds:
|
||||
return []
|
||||
|
||||
seen_hashes: Set[str] = set()
|
||||
candidates: List[GraphRelationCandidate] = []
|
||||
|
||||
if len(seeds) >= 2:
|
||||
self._collect_direct_pair_candidates(
|
||||
seed_a=seeds[0],
|
||||
seed_b=seeds[1],
|
||||
seen_hashes=seen_hashes,
|
||||
out=candidates,
|
||||
)
|
||||
if (
|
||||
len(candidates) < 3
|
||||
and self.config.allow_two_hop_pair
|
||||
and len(candidates) < self.config.candidate_k
|
||||
):
|
||||
self._collect_two_hop_pair_candidates(
|
||||
seed_a=seeds[0],
|
||||
seed_b=seeds[1],
|
||||
seen_hashes=seen_hashes,
|
||||
out=candidates,
|
||||
)
|
||||
else:
|
||||
self._collect_one_hop_seed_candidates(
|
||||
seed=seeds[0],
|
||||
seen_hashes=seen_hashes,
|
||||
out=candidates,
|
||||
)
|
||||
|
||||
return candidates[: self.config.candidate_k]
|
||||
|
||||
def _normalize_seed_entities(self, seed_entities: Sequence[str]) -> List[str]:
|
||||
out: List[str] = []
|
||||
seen = set()
|
||||
for raw in list(seed_entities)[:2]:
|
||||
resolved = None
|
||||
try:
|
||||
resolved = self.graph_store.find_node(str(raw), ignore_case=True)
|
||||
except Exception:
|
||||
resolved = None
|
||||
if not resolved:
|
||||
continue
|
||||
canon = str(resolved).strip().lower()
|
||||
if not canon or canon in seen:
|
||||
continue
|
||||
seen.add(canon)
|
||||
out.append(str(resolved))
|
||||
return out
|
||||
|
||||
def _collect_direct_pair_candidates(
|
||||
self,
|
||||
*,
|
||||
seed_a: str,
|
||||
seed_b: str,
|
||||
seen_hashes: Set[str],
|
||||
out: List[GraphRelationCandidate],
|
||||
) -> None:
|
||||
relation_hashes = []
|
||||
relation_hashes.extend(self.graph_store.get_relation_hashes_for_edge(seed_a, seed_b))
|
||||
relation_hashes.extend(self.graph_store.get_relation_hashes_for_edge(seed_b, seed_a))
|
||||
self._append_relation_hashes(
|
||||
relation_hashes=relation_hashes,
|
||||
seen_hashes=seen_hashes,
|
||||
out=out,
|
||||
candidate_type="direct_pair",
|
||||
graph_hops=1,
|
||||
graph_seed_entities=[seed_a, seed_b],
|
||||
)
|
||||
|
||||
def _collect_two_hop_pair_candidates(
|
||||
self,
|
||||
*,
|
||||
seed_a: str,
|
||||
seed_b: str,
|
||||
seen_hashes: Set[str],
|
||||
out: List[GraphRelationCandidate],
|
||||
) -> None:
|
||||
try:
|
||||
paths = self.graph_store.find_paths(
|
||||
seed_a,
|
||||
seed_b,
|
||||
max_depth=2,
|
||||
max_paths=self.config.max_paths,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f"graph two-hop recall skipped: {e}")
|
||||
return
|
||||
|
||||
for path_nodes in paths:
|
||||
if len(out) >= self.config.candidate_k:
|
||||
break
|
||||
if not isinstance(path_nodes, Sequence) or len(path_nodes) < 3:
|
||||
continue
|
||||
if len(path_nodes) != 3:
|
||||
continue
|
||||
for idx in range(len(path_nodes) - 1):
|
||||
if len(out) >= self.config.candidate_k:
|
||||
break
|
||||
u = str(path_nodes[idx])
|
||||
v = str(path_nodes[idx + 1])
|
||||
relation_hashes = []
|
||||
relation_hashes.extend(self.graph_store.get_relation_hashes_for_edge(u, v))
|
||||
relation_hashes.extend(self.graph_store.get_relation_hashes_for_edge(v, u))
|
||||
self._append_relation_hashes(
|
||||
relation_hashes=relation_hashes,
|
||||
seen_hashes=seen_hashes,
|
||||
out=out,
|
||||
candidate_type="two_hop_pair",
|
||||
graph_hops=2,
|
||||
graph_seed_entities=[seed_a, seed_b],
|
||||
)
|
||||
|
||||
def _collect_one_hop_seed_candidates(
|
||||
self,
|
||||
*,
|
||||
seed: str,
|
||||
seen_hashes: Set[str],
|
||||
out: List[GraphRelationCandidate],
|
||||
) -> None:
|
||||
try:
|
||||
relation_hashes = self.graph_store.get_incident_relation_hashes(
|
||||
seed,
|
||||
limit=self.config.candidate_k,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f"graph one-hop recall skipped: {e}")
|
||||
return
|
||||
self._append_relation_hashes(
|
||||
relation_hashes=relation_hashes,
|
||||
seen_hashes=seen_hashes,
|
||||
out=out,
|
||||
candidate_type="one_hop_seed",
|
||||
graph_hops=min(1, self.config.max_hop),
|
||||
graph_seed_entities=[seed],
|
||||
)
|
||||
|
||||
def _append_relation_hashes(
|
||||
self,
|
||||
*,
|
||||
relation_hashes: Sequence[str],
|
||||
seen_hashes: Set[str],
|
||||
out: List[GraphRelationCandidate],
|
||||
candidate_type: str,
|
||||
graph_hops: int,
|
||||
graph_seed_entities: Sequence[str],
|
||||
) -> None:
|
||||
for relation_hash in sorted({str(h) for h in relation_hashes if str(h).strip()}):
|
||||
if len(out) >= self.config.candidate_k:
|
||||
break
|
||||
if relation_hash in seen_hashes:
|
||||
continue
|
||||
candidate = self._build_candidate(
|
||||
relation_hash=relation_hash,
|
||||
candidate_type=candidate_type,
|
||||
graph_hops=graph_hops,
|
||||
graph_seed_entities=graph_seed_entities,
|
||||
)
|
||||
if candidate is None:
|
||||
continue
|
||||
seen_hashes.add(relation_hash)
|
||||
out.append(candidate)
|
||||
|
||||
def _build_candidate(
|
||||
self,
|
||||
*,
|
||||
relation_hash: str,
|
||||
candidate_type: str,
|
||||
graph_hops: int,
|
||||
graph_seed_entities: Sequence[str],
|
||||
) -> Optional[GraphRelationCandidate]:
|
||||
relation = self.metadata_store.get_relation(relation_hash)
|
||||
if relation is None:
|
||||
return None
|
||||
supporting_paragraphs = self.metadata_store.get_paragraphs_by_relation(relation_hash)
|
||||
return GraphRelationCandidate(
|
||||
hash_value=relation_hash,
|
||||
subject=str(relation.get("subject", "")),
|
||||
predicate=str(relation.get("predicate", "")),
|
||||
object=str(relation.get("object", "")),
|
||||
confidence=float(relation.get("confidence", 1.0) or 1.0),
|
||||
graph_seed_entities=[str(x) for x in graph_seed_entities],
|
||||
graph_hops=int(graph_hops),
|
||||
graph_candidate_type=str(candidate_type),
|
||||
supporting_paragraph_count=len(supporting_paragraphs),
|
||||
)
|
||||
482
src/A_memorix/core/retrieval/pagerank.py
Normal file
482
src/A_memorix/core/retrieval/pagerank.py
Normal file
@@ -0,0 +1,482 @@
|
||||
"""
|
||||
Personalized PageRank实现
|
||||
|
||||
提供个性化的图节点排序功能。
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional, Tuple, Union, Any
|
||||
from dataclasses import dataclass
|
||||
import numpy as np
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from ..storage import GraphStore
|
||||
from ..utils.matcher import AhoCorasick
|
||||
|
||||
logger = get_logger("A_Memorix.PersonalizedPageRank")
|
||||
|
||||
|
||||
@dataclass
|
||||
class PageRankConfig:
|
||||
"""
|
||||
PageRank配置
|
||||
|
||||
属性:
|
||||
alpha: 阻尼系数(0-1之间)
|
||||
max_iter: 最大迭代次数
|
||||
tol: 收敛阈值
|
||||
normalize: 是否归一化结果
|
||||
min_iterations: 最小迭代次数
|
||||
"""
|
||||
|
||||
alpha: float = 0.85
|
||||
max_iter: int = 100
|
||||
tol: float = 1e-6
|
||||
normalize: bool = True
|
||||
min_iterations: int = 20
|
||||
|
||||
def __post_init__(self):
|
||||
"""验证配置"""
|
||||
if not 0 <= self.alpha < 1:
|
||||
raise ValueError(f"alpha必须在[0, 1)之间: {self.alpha}")
|
||||
|
||||
if self.max_iter <= 0:
|
||||
raise ValueError(f"max_iter必须大于0: {self.max_iter}")
|
||||
|
||||
if self.tol <= 0:
|
||||
raise ValueError(f"tol必须大于0: {self.tol}")
|
||||
|
||||
if self.min_iterations < 0:
|
||||
raise ValueError(f"min_iterations必须大于等于0: {self.min_iterations}")
|
||||
|
||||
if self.min_iterations >= self.max_iter:
|
||||
raise ValueError(f"min_iterations必须小于max_iter")
|
||||
|
||||
|
||||
class PersonalizedPageRank:
|
||||
"""
|
||||
Personalized PageRank计算器
|
||||
|
||||
功能:
|
||||
- 个性化向量支持
|
||||
- 快速收敛检测
|
||||
- 结果归一化
|
||||
- 批量计算
|
||||
- 统计信息
|
||||
|
||||
参数:
|
||||
graph_store: 图存储
|
||||
config: PageRank配置
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
graph_store: GraphStore,
|
||||
config: Optional[PageRankConfig] = None,
|
||||
):
|
||||
"""
|
||||
初始化PPR计算器
|
||||
|
||||
Args:
|
||||
graph_store: 图存储
|
||||
config: PageRank配置
|
||||
"""
|
||||
self.graph_store = graph_store
|
||||
self.config = config or PageRankConfig()
|
||||
|
||||
# 统计信息
|
||||
self._total_computations = 0
|
||||
self._total_iterations = 0
|
||||
self._convergence_history: List[int] = []
|
||||
|
||||
logger.info(
|
||||
f"PersonalizedPageRank 初始化: "
|
||||
f"alpha={self.config.alpha}, "
|
||||
f"max_iter={self.config.max_iter}"
|
||||
)
|
||||
|
||||
# 缓存 Aho-Corasick 匹配器
|
||||
self._ac_matcher: Optional[AhoCorasick] = None
|
||||
self._ac_nodes_count = 0
|
||||
|
||||
def compute(
|
||||
self,
|
||||
personalization: Optional[Dict[str, float]] = None,
|
||||
alpha: Optional[float] = None,
|
||||
max_iter: Optional[int] = None,
|
||||
normalize: Optional[bool] = None,
|
||||
) -> Dict[str, float]:
|
||||
"""
|
||||
计算Personalized PageRank
|
||||
|
||||
Args:
|
||||
personalization: 个性化向量 {节点名: 权重}
|
||||
alpha: 阻尼系数(覆盖配置值)
|
||||
max_iter: 最大迭代次数(覆盖配置值)
|
||||
normalize: 是否归一化(覆盖配置值)
|
||||
|
||||
Returns:
|
||||
节点PageRank值字典 {节点名: 分数}
|
||||
"""
|
||||
# 使用覆盖值或配置值
|
||||
alpha = alpha if alpha is not None else self.config.alpha
|
||||
max_iter = max_iter if max_iter is not None else self.config.max_iter
|
||||
normalize = normalize if normalize is not None else self.config.normalize
|
||||
|
||||
# 调用GraphStore的compute_pagerank
|
||||
scores = self.graph_store.compute_pagerank(
|
||||
personalization=personalization,
|
||||
alpha=alpha,
|
||||
max_iter=max_iter,
|
||||
tol=self.config.tol,
|
||||
)
|
||||
|
||||
# 归一化(如果需要)
|
||||
if normalize and scores:
|
||||
total = sum(scores.values())
|
||||
if total > 0:
|
||||
scores = {node: score / total for node, score in scores.items()}
|
||||
|
||||
# 更新统计
|
||||
self._total_computations += 1
|
||||
|
||||
logger.debug(
|
||||
f"PPR计算完成: {len(scores)} 个节点, "
|
||||
f"personalization_nodes={len(personalization) if personalization else 0}"
|
||||
)
|
||||
|
||||
return scores
|
||||
|
||||
def compute_batch(
|
||||
self,
|
||||
personalization_list: List[Dict[str, float]],
|
||||
normalize: bool = True,
|
||||
) -> List[Dict[str, float]]:
|
||||
"""
|
||||
批量计算PPR
|
||||
|
||||
Args:
|
||||
personalization_list: 个性化向量列表
|
||||
normalize: 是否归一化
|
||||
|
||||
Returns:
|
||||
PageRank值字典列表
|
||||
"""
|
||||
results = []
|
||||
|
||||
for i, personalization in enumerate(personalization_list):
|
||||
logger.debug(f"计算第 {i+1}/{len(personalization_list)} 个PPR")
|
||||
scores = self.compute(personalization=personalization, normalize=normalize)
|
||||
results.append(scores)
|
||||
|
||||
return results
|
||||
|
||||
def compute_for_entities(
|
||||
self,
|
||||
entities: List[str],
|
||||
weights: Optional[List[float]] = None,
|
||||
normalize: bool = True,
|
||||
) -> Dict[str, float]:
|
||||
"""
|
||||
为实体列表计算PPR
|
||||
|
||||
Args:
|
||||
entities: 实体列表
|
||||
weights: 权重列表(默认均匀权重)
|
||||
normalize: 是否归一化
|
||||
|
||||
Returns:
|
||||
PageRank值字典
|
||||
"""
|
||||
if not entities:
|
||||
logger.warning("实体列表为空,返回均匀PPR")
|
||||
return self.compute(personalization=None, normalize=normalize)
|
||||
|
||||
# 构建个性化向量
|
||||
if weights is None:
|
||||
weights = [1.0] * len(entities)
|
||||
|
||||
if len(weights) != len(entities):
|
||||
raise ValueError(f"权重数量与实体数量不匹配: {len(weights)} vs {len(entities)}")
|
||||
|
||||
personalization = {entity: weight for entity, weight in zip(entities, weights)}
|
||||
|
||||
return self.compute(personalization=personalization, normalize=normalize)
|
||||
|
||||
def compute_for_query(
|
||||
self,
|
||||
query: str,
|
||||
entity_extractor: Optional[callable] = None,
|
||||
normalize: bool = True,
|
||||
) -> Dict[str, float]:
|
||||
"""
|
||||
为查询计算PPR
|
||||
|
||||
Args:
|
||||
query: 查询文本
|
||||
entity_extractor: 实体提取函数(可选)
|
||||
normalize: 是否归一化
|
||||
|
||||
Returns:
|
||||
PageRank值字典
|
||||
"""
|
||||
# 提取实体
|
||||
if entity_extractor is not None:
|
||||
entities = entity_extractor(query)
|
||||
else:
|
||||
# 简单实现:基于图中的节点匹配
|
||||
entities = self._extract_entities_from_query(query)
|
||||
|
||||
if not entities:
|
||||
logger.debug(f"未从查询中提取到实体: '{query}'")
|
||||
return self.compute(personalization=None, normalize=normalize)
|
||||
|
||||
# 计算PPR
|
||||
return self.compute_for_entities(entities, normalize=normalize)
|
||||
|
||||
def rank_nodes(
|
||||
self,
|
||||
scores: Dict[str, float],
|
||||
top_k: Optional[int] = None,
|
||||
min_score: float = 0.0,
|
||||
) -> List[Tuple[str, float]]:
|
||||
"""
|
||||
对节点排序
|
||||
|
||||
Args:
|
||||
scores: PageRank分数字典
|
||||
top_k: 返回前k个节点(None表示全部)
|
||||
min_score: 最小分数阈值
|
||||
|
||||
Returns:
|
||||
排序后的节点列表 [(节点名, 分数), ...]
|
||||
"""
|
||||
# 过滤低分节点
|
||||
filtered = [(node, score) for node, score in scores.items() if score >= min_score]
|
||||
|
||||
# 按分数降序排序
|
||||
sorted_nodes = sorted(filtered, key=lambda x: x[1], reverse=True)
|
||||
|
||||
# 返回top_k
|
||||
if top_k is not None:
|
||||
sorted_nodes = sorted_nodes[:top_k]
|
||||
|
||||
return sorted_nodes
|
||||
|
||||
def get_personalization_vector(
|
||||
self,
|
||||
nodes: List[str],
|
||||
method: str = "uniform",
|
||||
) -> Dict[str, float]:
|
||||
"""
|
||||
生成个性化向量
|
||||
|
||||
Args:
|
||||
nodes: 节点列表
|
||||
method: 生成方法
|
||||
- "uniform": 均匀权重
|
||||
- "degree": 按度数加权
|
||||
- "inverse_degree": 按度数反比加权
|
||||
|
||||
Returns:
|
||||
个性化向量 {节点名: 权重}
|
||||
"""
|
||||
if not nodes:
|
||||
return {}
|
||||
|
||||
if method == "uniform":
|
||||
# 均匀权重
|
||||
weight = 1.0 / len(nodes)
|
||||
return {node: weight for node in nodes}
|
||||
|
||||
elif method == "degree":
|
||||
# 按度数加权
|
||||
node_degrees = {}
|
||||
for node in nodes:
|
||||
neighbors = self.graph_store.get_neighbors(node)
|
||||
node_degrees[node] = len(neighbors)
|
||||
|
||||
total_degree = sum(node_degrees.values())
|
||||
if total_degree > 0:
|
||||
return {node: degree / total_degree for node, degree in node_degrees.items()}
|
||||
else:
|
||||
return {node: 1.0 / len(nodes) for node in nodes}
|
||||
|
||||
elif method == "inverse_degree":
|
||||
# 按度数反比加权
|
||||
node_degrees = {}
|
||||
for node in nodes:
|
||||
neighbors = self.graph_store.get_neighbors(node)
|
||||
node_degrees[node] = len(neighbors)
|
||||
|
||||
# 反度数
|
||||
inv_degrees = {node: 1.0 / (degree + 1) for node, degree in node_degrees.items()}
|
||||
total_inv = sum(inv_degrees.values())
|
||||
|
||||
if total_inv > 0:
|
||||
return {node: inv / total_inv for node, inv in inv_degrees.items()}
|
||||
else:
|
||||
return {node: 1.0 / len(nodes) for node in nodes}
|
||||
|
||||
else:
|
||||
raise ValueError(f"不支持的个性化向量生成方法: {method}")
|
||||
|
||||
def compare_scores(
|
||||
self,
|
||||
scores1: Dict[str, float],
|
||||
scores2: Dict[str, float],
|
||||
) -> Dict[str, Dict[str, float]]:
|
||||
"""
|
||||
比较两组PPR分数
|
||||
|
||||
Args:
|
||||
scores1: 第一组分数
|
||||
scores2: 第二组分数
|
||||
|
||||
Returns:
|
||||
比较结果 {
|
||||
"common_nodes": {节点: (score1, score2)},
|
||||
"only_in_1": {节点: score1},
|
||||
"only_in_2": {节点: score2},
|
||||
}
|
||||
"""
|
||||
common_nodes = {}
|
||||
only_in_1 = {}
|
||||
only_in_2 = {}
|
||||
|
||||
all_nodes = set(scores1.keys()) | set(scores2.keys())
|
||||
|
||||
for node in all_nodes:
|
||||
if node in scores1 and node in scores2:
|
||||
common_nodes[node] = (scores1[node], scores2[node])
|
||||
elif node in scores1:
|
||||
only_in_1[node] = scores1[node]
|
||||
else:
|
||||
only_in_2[node] = scores2[node]
|
||||
|
||||
return {
|
||||
"common_nodes": common_nodes,
|
||||
"only_in_1": only_in_1,
|
||||
"only_in_2": only_in_2,
|
||||
}
|
||||
|
||||
def get_statistics(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取统计信息
|
||||
|
||||
Returns:
|
||||
统计信息字典
|
||||
"""
|
||||
avg_iterations = (
|
||||
self._total_iterations / self._total_computations
|
||||
if self._total_computations > 0
|
||||
else 0
|
||||
)
|
||||
|
||||
return {
|
||||
"config": {
|
||||
"alpha": self.config.alpha,
|
||||
"max_iter": self.config.max_iter,
|
||||
"tol": self.config.tol,
|
||||
"normalize": self.config.normalize,
|
||||
"min_iterations": self.config.min_iterations,
|
||||
},
|
||||
"statistics": {
|
||||
"total_computations": self._total_computations,
|
||||
"total_iterations": self._total_iterations,
|
||||
"avg_iterations": avg_iterations,
|
||||
"convergence_history": self._convergence_history.copy(),
|
||||
},
|
||||
"graph": {
|
||||
"num_nodes": self.graph_store.num_nodes,
|
||||
"num_edges": self.graph_store.num_edges,
|
||||
},
|
||||
}
|
||||
|
||||
def reset_statistics(self) -> None:
|
||||
"""重置统计信息"""
|
||||
self._total_computations = 0
|
||||
self._total_iterations = 0
|
||||
self._convergence_history.clear()
|
||||
logger.info("统计信息已重置")
|
||||
|
||||
def _extract_entities_from_query(self, query: str) -> List[str]:
|
||||
"""
|
||||
从查询中提取实体(简化实现)
|
||||
|
||||
Args:
|
||||
query: 查询文本
|
||||
|
||||
Returns:
|
||||
实体列表
|
||||
"""
|
||||
# 获取所有节点
|
||||
all_nodes = self.graph_store.get_nodes()
|
||||
if not all_nodes:
|
||||
return []
|
||||
|
||||
# 检查是否需要更新 Aho-Corasick 匹配器
|
||||
if self._ac_matcher is None or self._ac_nodes_count != len(all_nodes):
|
||||
self._ac_matcher = AhoCorasick()
|
||||
for node in all_nodes:
|
||||
# 统一转为小写进行不区分大小写匹配
|
||||
self._ac_matcher.add_pattern(node.lower())
|
||||
self._ac_matcher.build()
|
||||
self._ac_nodes_count = len(all_nodes)
|
||||
|
||||
# 执行匹配
|
||||
query_lower = query.lower()
|
||||
stats = self._ac_matcher.find_all(query_lower)
|
||||
|
||||
# 转换回原始的大小写(这里简化为从 all_nodes 中找,或者 AC 存原始值)
|
||||
# 为了简单,AC 中 add_pattern 存的是小写
|
||||
# 我们需要一个映射:小写 -> 原始
|
||||
node_map = {node.lower(): node for node in all_nodes}
|
||||
entities = [node_map[low_name] for low_name in stats.keys()]
|
||||
|
||||
return entities
|
||||
|
||||
@property
|
||||
def num_computations(self) -> int:
|
||||
"""计算次数"""
|
||||
return self._total_computations
|
||||
|
||||
@property
|
||||
def avg_iterations(self) -> float:
|
||||
"""平均迭代次数"""
|
||||
if self._total_computations == 0:
|
||||
return 0.0
|
||||
return self._total_iterations / self._total_computations
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"PersonalizedPageRank("
|
||||
f"alpha={self.config.alpha}, "
|
||||
f"computations={self._total_computations})"
|
||||
)
|
||||
|
||||
|
||||
def create_ppr_from_graph(
|
||||
graph_store: GraphStore,
|
||||
alpha: float = 0.85,
|
||||
max_iter: int = 100,
|
||||
) -> PersonalizedPageRank:
|
||||
"""
|
||||
从图存储创建PPR计算器
|
||||
|
||||
Args:
|
||||
graph_store: 图存储
|
||||
alpha: 阻尼系数
|
||||
max_iter: 最大迭代次数
|
||||
|
||||
Returns:
|
||||
PPR计算器实例
|
||||
"""
|
||||
config = PageRankConfig(
|
||||
alpha=alpha,
|
||||
max_iter=max_iter,
|
||||
)
|
||||
|
||||
return PersonalizedPageRank(
|
||||
graph_store=graph_store,
|
||||
config=config,
|
||||
)
|
||||
401
src/A_memorix/core/retrieval/sparse_bm25.py
Normal file
401
src/A_memorix/core/retrieval/sparse_bm25.py
Normal file
@@ -0,0 +1,401 @@
|
||||
"""
|
||||
稀疏检索组件(FTS5 + BM25)
|
||||
|
||||
支持:
|
||||
- 懒加载索引连接
|
||||
- jieba / char n-gram 分词
|
||||
- 可卸载并收缩 SQLite 内存缓存
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import sqlite3
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from ..storage import MetadataStore
|
||||
|
||||
logger = get_logger("A_Memorix.SparseBM25")
|
||||
|
||||
try:
|
||||
import jieba # type: ignore
|
||||
|
||||
HAS_JIEBA = True
|
||||
except Exception:
|
||||
HAS_JIEBA = False
|
||||
jieba = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class SparseBM25Config:
|
||||
"""BM25 稀疏检索配置。"""
|
||||
|
||||
enabled: bool = True
|
||||
backend: str = "fts5"
|
||||
lazy_load: bool = True
|
||||
mode: str = "auto" # auto | fallback_only | hybrid
|
||||
tokenizer_mode: str = "jieba" # jieba | mixed | char_2gram
|
||||
jieba_user_dict: str = ""
|
||||
char_ngram_n: int = 2
|
||||
candidate_k: int = 80
|
||||
max_doc_len: int = 2000
|
||||
enable_ngram_fallback_index: bool = True
|
||||
enable_like_fallback: bool = False
|
||||
enable_relation_sparse_fallback: bool = True
|
||||
relation_candidate_k: int = 60
|
||||
relation_max_doc_len: int = 512
|
||||
unload_on_disable: bool = True
|
||||
shrink_memory_on_unload: bool = True
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.backend = str(self.backend or "fts5").strip().lower()
|
||||
self.mode = str(self.mode or "auto").strip().lower()
|
||||
self.tokenizer_mode = str(self.tokenizer_mode or "jieba").strip().lower()
|
||||
self.char_ngram_n = max(1, int(self.char_ngram_n))
|
||||
self.candidate_k = max(1, int(self.candidate_k))
|
||||
self.max_doc_len = max(0, int(self.max_doc_len))
|
||||
self.relation_candidate_k = max(1, int(self.relation_candidate_k))
|
||||
self.relation_max_doc_len = max(0, int(self.relation_max_doc_len))
|
||||
if self.backend != "fts5":
|
||||
raise ValueError(f"sparse.backend 暂仅支持 fts5: {self.backend}")
|
||||
if self.mode not in {"auto", "fallback_only", "hybrid"}:
|
||||
raise ValueError(f"sparse.mode 非法: {self.mode}")
|
||||
if self.tokenizer_mode not in {"jieba", "mixed", "char_2gram"}:
|
||||
raise ValueError(f"sparse.tokenizer_mode 非法: {self.tokenizer_mode}")
|
||||
|
||||
|
||||
class SparseBM25Index:
|
||||
"""
|
||||
基于 SQLite FTS5 的 BM25 检索适配层。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
metadata_store: MetadataStore,
|
||||
config: Optional[SparseBM25Config] = None,
|
||||
):
|
||||
self.metadata_store = metadata_store
|
||||
self.config = config or SparseBM25Config()
|
||||
self._conn: Optional[sqlite3.Connection] = None
|
||||
self._loaded: bool = False
|
||||
self._jieba_dict_loaded: bool = False
|
||||
|
||||
@property
|
||||
def loaded(self) -> bool:
|
||||
return self._loaded and self._conn is not None
|
||||
|
||||
def ensure_loaded(self) -> bool:
|
||||
"""按需加载 FTS 连接与索引。"""
|
||||
if not self.config.enabled:
|
||||
return False
|
||||
if self.loaded:
|
||||
return True
|
||||
|
||||
db_path = self.metadata_store.get_db_path()
|
||||
conn = sqlite3.connect(
|
||||
str(db_path),
|
||||
check_same_thread=False,
|
||||
timeout=30.0,
|
||||
)
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
conn.execute("PRAGMA synchronous=NORMAL")
|
||||
conn.execute("PRAGMA temp_store=MEMORY")
|
||||
|
||||
if not self.metadata_store.ensure_fts_schema(conn=conn):
|
||||
conn.close()
|
||||
return False
|
||||
self.metadata_store.ensure_fts_backfilled(conn=conn)
|
||||
# 关系稀疏检索按独立开关加载,避免不必要的初始化开销。
|
||||
if self.config.enable_relation_sparse_fallback:
|
||||
self.metadata_store.ensure_relations_fts_schema(conn=conn)
|
||||
self.metadata_store.ensure_relations_fts_backfilled(conn=conn)
|
||||
if self.config.enable_ngram_fallback_index:
|
||||
self.metadata_store.ensure_paragraph_ngram_schema(conn=conn)
|
||||
self.metadata_store.ensure_paragraph_ngram_backfilled(
|
||||
n=self.config.char_ngram_n,
|
||||
conn=conn,
|
||||
)
|
||||
|
||||
self._conn = conn
|
||||
self._loaded = True
|
||||
self._prepare_tokenizer()
|
||||
logger.info(
|
||||
"SparseBM25Index loaded: "
|
||||
f"backend=fts5, tokenizer={self.config.tokenizer_mode}, mode={self.config.mode}"
|
||||
)
|
||||
return True
|
||||
|
||||
def _prepare_tokenizer(self) -> None:
|
||||
if self._jieba_dict_loaded:
|
||||
return
|
||||
if self.config.tokenizer_mode not in {"jieba", "mixed"}:
|
||||
return
|
||||
if not HAS_JIEBA:
|
||||
logger.warning("jieba 不可用,tokenizer 将退化为 char n-gram")
|
||||
return
|
||||
user_dict = str(self.config.jieba_user_dict or "").strip()
|
||||
if user_dict:
|
||||
try:
|
||||
jieba.load_userdict(user_dict) # type: ignore[union-attr]
|
||||
logger.info(f"已加载 jieba 用户词典: {user_dict}")
|
||||
except Exception as e:
|
||||
logger.warning(f"加载 jieba 用户词典失败: {e}")
|
||||
self._jieba_dict_loaded = True
|
||||
|
||||
def _tokenize_jieba(self, text: str) -> List[str]:
|
||||
if not HAS_JIEBA:
|
||||
return []
|
||||
try:
|
||||
tokens = list(jieba.cut_for_search(text)) # type: ignore[union-attr]
|
||||
return [t.strip().lower() for t in tokens if t and t.strip()]
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
def _tokenize_char_ngram(self, text: str, n: int) -> List[str]:
|
||||
compact = re.sub(r"\s+", "", text.lower())
|
||||
if not compact:
|
||||
return []
|
||||
if len(compact) < n:
|
||||
return [compact]
|
||||
return [compact[i : i + n] for i in range(0, len(compact) - n + 1)]
|
||||
|
||||
def _tokenize(self, text: str) -> List[str]:
|
||||
text = str(text or "").strip()
|
||||
if not text:
|
||||
return []
|
||||
|
||||
mode = self.config.tokenizer_mode
|
||||
if mode == "jieba":
|
||||
tokens = self._tokenize_jieba(text)
|
||||
if tokens:
|
||||
return list(dict.fromkeys(tokens))
|
||||
return self._tokenize_char_ngram(text, self.config.char_ngram_n)
|
||||
|
||||
if mode == "mixed":
|
||||
toks = self._tokenize_jieba(text)
|
||||
toks.extend(self._tokenize_char_ngram(text, self.config.char_ngram_n))
|
||||
return list(dict.fromkeys([t for t in toks if t]))
|
||||
|
||||
return list(dict.fromkeys(self._tokenize_char_ngram(text, self.config.char_ngram_n)))
|
||||
|
||||
def _build_match_query(self, tokens: List[str]) -> str:
|
||||
safe_tokens: List[str] = []
|
||||
for token in tokens:
|
||||
t = token.replace('"', '""').strip()
|
||||
if not t:
|
||||
continue
|
||||
safe_tokens.append(f'"{t}"')
|
||||
if not safe_tokens:
|
||||
return ""
|
||||
# 采用 OR 提升召回,再交由 RRF 和阈值做稳健排序。
|
||||
return " OR ".join(safe_tokens[:64])
|
||||
|
||||
def _fallback_substring_search(
|
||||
self,
|
||||
tokens: List[str],
|
||||
limit: int,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
当 FTS5 因分词不一致召回为空时,退化为子串匹配召回。
|
||||
|
||||
说明:
|
||||
- FTS 索引当前采用 unicode61 tokenizer。
|
||||
- 若查询 token 来源为 char n-gram 或中文词元,可能与索引 token 不一致。
|
||||
- 这里使用 SQL LIKE 做兜底,按命中 token 覆盖度打分。
|
||||
"""
|
||||
if not tokens:
|
||||
return []
|
||||
|
||||
# 去重并裁剪 token 数量,避免生成超长 SQL。
|
||||
uniq_tokens = [t for t in dict.fromkeys(tokens) if t]
|
||||
uniq_tokens = uniq_tokens[:32]
|
||||
if not uniq_tokens:
|
||||
return []
|
||||
|
||||
if self.config.enable_ngram_fallback_index:
|
||||
try:
|
||||
# 允许运行时切换开关后按需补齐 schema/回填。
|
||||
self.metadata_store.ensure_paragraph_ngram_schema(conn=self._conn)
|
||||
self.metadata_store.ensure_paragraph_ngram_backfilled(
|
||||
n=self.config.char_ngram_n,
|
||||
conn=self._conn,
|
||||
)
|
||||
rows = self.metadata_store.ngram_search_paragraphs(
|
||||
tokens=uniq_tokens,
|
||||
limit=limit,
|
||||
max_doc_len=self.config.max_doc_len,
|
||||
conn=self._conn,
|
||||
)
|
||||
if rows:
|
||||
return rows
|
||||
except Exception as e:
|
||||
logger.warning(f"ngram 倒排回退失败,将按配置决定是否使用 LIKE 回退: {e}")
|
||||
|
||||
if not self.config.enable_like_fallback:
|
||||
return []
|
||||
|
||||
conditions = " OR ".join(["p.content LIKE ?"] * len(uniq_tokens))
|
||||
params: List[Any] = [f"%{tok}%" for tok in uniq_tokens]
|
||||
scan_limit = max(int(limit) * 8, 200)
|
||||
params.append(scan_limit)
|
||||
|
||||
sql = f"""
|
||||
SELECT p.hash, p.content
|
||||
FROM paragraphs p
|
||||
WHERE (p.is_deleted IS NULL OR p.is_deleted = 0)
|
||||
AND ({conditions})
|
||||
LIMIT ?
|
||||
"""
|
||||
rows = self.metadata_store.query(sql, tuple(params))
|
||||
if not rows:
|
||||
return []
|
||||
|
||||
scored: List[Dict[str, Any]] = []
|
||||
token_count = max(1, len(uniq_tokens))
|
||||
for row in rows:
|
||||
content = str(row.get("content") or "")
|
||||
content_low = content.lower()
|
||||
matched = [tok for tok in uniq_tokens if tok in content_low]
|
||||
if not matched:
|
||||
continue
|
||||
coverage = len(matched) / token_count
|
||||
length_bonus = sum(len(tok) for tok in matched) / max(1, len(content_low))
|
||||
# 兜底路径使用相对分,保持与上层接口兼容。
|
||||
fallback_score = coverage * 0.8 + length_bonus * 0.2
|
||||
scored.append(
|
||||
{
|
||||
"hash": row["hash"],
|
||||
"content": content[: self.config.max_doc_len] if self.config.max_doc_len > 0 else content,
|
||||
"bm25_score": -float(fallback_score),
|
||||
"fallback_score": float(fallback_score),
|
||||
}
|
||||
)
|
||||
|
||||
scored.sort(key=lambda x: x["fallback_score"], reverse=True)
|
||||
return scored[:limit]
|
||||
|
||||
def search(self, query: str, k: int = 20) -> List[Dict[str, Any]]:
|
||||
"""执行 BM25 检索。"""
|
||||
if not self.config.enabled:
|
||||
return []
|
||||
if self.config.lazy_load and not self.loaded:
|
||||
if not self.ensure_loaded():
|
||||
return []
|
||||
if not self.loaded:
|
||||
return []
|
||||
# 关系稀疏检索可独立开关,运行时开启后也能按需补齐 schema/回填。
|
||||
self.metadata_store.ensure_relations_fts_schema(conn=self._conn)
|
||||
self.metadata_store.ensure_relations_fts_backfilled(conn=self._conn)
|
||||
|
||||
tokens = self._tokenize(query)
|
||||
match_query = self._build_match_query(tokens)
|
||||
if not match_query:
|
||||
return []
|
||||
|
||||
limit = max(1, int(k))
|
||||
rows = self.metadata_store.fts_search_bm25(
|
||||
match_query=match_query,
|
||||
limit=limit,
|
||||
max_doc_len=self.config.max_doc_len,
|
||||
conn=self._conn,
|
||||
)
|
||||
if not rows:
|
||||
rows = self._fallback_substring_search(tokens=tokens, limit=limit)
|
||||
|
||||
results: List[Dict[str, Any]] = []
|
||||
for rank, row in enumerate(rows, start=1):
|
||||
bm25_score = float(row.get("bm25_score", 0.0))
|
||||
results.append(
|
||||
{
|
||||
"hash": row["hash"],
|
||||
"content": row["content"],
|
||||
"rank": rank,
|
||||
"bm25_score": bm25_score,
|
||||
"score": -bm25_score, # bm25 越小越相关,这里取反作为相对分数
|
||||
}
|
||||
)
|
||||
return results
|
||||
|
||||
def search_relations(self, query: str, k: int = 20) -> List[Dict[str, Any]]:
|
||||
"""执行关系稀疏检索(FTS5 + BM25)。"""
|
||||
if not self.config.enabled or not self.config.enable_relation_sparse_fallback:
|
||||
return []
|
||||
if self.config.lazy_load and not self.loaded:
|
||||
if not self.ensure_loaded():
|
||||
return []
|
||||
if not self.loaded:
|
||||
return []
|
||||
|
||||
tokens = self._tokenize(query)
|
||||
match_query = self._build_match_query(tokens)
|
||||
if not match_query:
|
||||
return []
|
||||
|
||||
rows = self.metadata_store.fts_search_relations_bm25(
|
||||
match_query=match_query,
|
||||
limit=max(1, int(k)),
|
||||
max_doc_len=self.config.relation_max_doc_len,
|
||||
conn=self._conn,
|
||||
)
|
||||
out: List[Dict[str, Any]] = []
|
||||
for rank, row in enumerate(rows, start=1):
|
||||
bm25_score = float(row.get("bm25_score", 0.0))
|
||||
out.append(
|
||||
{
|
||||
"hash": row["hash"],
|
||||
"subject": row["subject"],
|
||||
"predicate": row["predicate"],
|
||||
"object": row["object"],
|
||||
"content": row["content"],
|
||||
"rank": rank,
|
||||
"bm25_score": bm25_score,
|
||||
"score": -bm25_score,
|
||||
}
|
||||
)
|
||||
return out
|
||||
|
||||
def upsert_paragraph(self, paragraph_hash: str) -> bool:
|
||||
if not self.loaded:
|
||||
return False
|
||||
return self.metadata_store.fts_upsert_paragraph(paragraph_hash, conn=self._conn)
|
||||
|
||||
def delete_paragraph(self, paragraph_hash: str) -> bool:
|
||||
if not self.loaded:
|
||||
return False
|
||||
return self.metadata_store.fts_delete_paragraph(paragraph_hash, conn=self._conn)
|
||||
|
||||
def unload(self) -> None:
|
||||
"""卸载 BM25 连接并尽量释放内存。"""
|
||||
if self._conn is not None:
|
||||
try:
|
||||
if self.config.shrink_memory_on_unload:
|
||||
self.metadata_store.shrink_memory(conn=self._conn)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
self._conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
self._conn = None
|
||||
self._loaded = False
|
||||
logger.info("SparseBM25Index unloaded")
|
||||
|
||||
def stats(self) -> Dict[str, Any]:
|
||||
doc_count = 0
|
||||
if self.loaded:
|
||||
doc_count = self.metadata_store.fts_doc_count(conn=self._conn)
|
||||
return {
|
||||
"enabled": self.config.enabled,
|
||||
"backend": self.config.backend,
|
||||
"mode": self.config.mode,
|
||||
"tokenizer_mode": self.config.tokenizer_mode,
|
||||
"enable_ngram_fallback_index": self.config.enable_ngram_fallback_index,
|
||||
"enable_like_fallback": self.config.enable_like_fallback,
|
||||
"enable_relation_sparse_fallback": self.config.enable_relation_sparse_fallback,
|
||||
"loaded": self.loaded,
|
||||
"has_jieba": HAS_JIEBA,
|
||||
"doc_count": doc_count,
|
||||
}
|
||||
450
src/A_memorix/core/retrieval/threshold.py
Normal file
450
src/A_memorix/core/retrieval/threshold.py
Normal file
@@ -0,0 +1,450 @@
|
||||
"""
|
||||
动态阈值过滤器
|
||||
|
||||
根据检索结果的分布特征自适应调整过滤阈值。
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from typing import List, Dict, Any, Optional, Tuple, Union
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from .dual_path import RetrievalResult
|
||||
|
||||
logger = get_logger("A_Memorix.DynamicThresholdFilter")
|
||||
|
||||
|
||||
class ThresholdMethod(Enum):
|
||||
"""阈值计算方法"""
|
||||
|
||||
PERCENTILE = "percentile" # 百分位数
|
||||
STD_DEV = "std_dev" # 标准差
|
||||
GAP_DETECTION = "gap_detection" # 跳变检测
|
||||
ADAPTIVE = "adaptive" # 自适应(综合多种方法)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ThresholdConfig:
|
||||
"""
|
||||
阈值配置
|
||||
|
||||
属性:
|
||||
method: 阈值计算方法
|
||||
min_threshold: 最小阈值(绝对值)
|
||||
max_threshold: 最大阈值(绝对值)
|
||||
percentile: 百分位数(用于percentile方法)
|
||||
std_multiplier: 标准差倍数(用于std_dev方法)
|
||||
min_results: 最少保留结果数
|
||||
enable_auto_adjust: 是否自动调整参数
|
||||
"""
|
||||
|
||||
method: ThresholdMethod = ThresholdMethod.ADAPTIVE
|
||||
min_threshold: float = 0.3
|
||||
max_threshold: float = 0.95
|
||||
percentile: float = 75.0 # 百分位数
|
||||
std_multiplier: float = 1.5 # 标准差倍数
|
||||
min_results: int = 3 # 最少保留结果数
|
||||
enable_auto_adjust: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
"""验证配置"""
|
||||
if not 0 <= self.min_threshold <= 1:
|
||||
raise ValueError(f"min_threshold必须在[0, 1]之间: {self.min_threshold}")
|
||||
|
||||
if not 0 <= self.max_threshold <= 1:
|
||||
raise ValueError(f"max_threshold必须在[0, 1]之间: {self.max_threshold}")
|
||||
|
||||
if self.min_threshold >= self.max_threshold:
|
||||
raise ValueError(f"min_threshold必须小于max_threshold")
|
||||
|
||||
if not 0 <= self.percentile <= 100:
|
||||
raise ValueError(f"percentile必须在[0, 100]之间: {self.percentile}")
|
||||
|
||||
if self.std_multiplier <= 0:
|
||||
raise ValueError(f"std_multiplier必须大于0: {self.std_multiplier}")
|
||||
|
||||
if self.min_results < 0:
|
||||
raise ValueError(f"min_results必须大于等于0: {self.min_results}")
|
||||
|
||||
|
||||
class DynamicThresholdFilter:
|
||||
"""
|
||||
动态阈值过滤器
|
||||
|
||||
功能:
|
||||
- 基于结果分布自适应计算阈值
|
||||
- 多种阈值计算方法
|
||||
- 自动参数调整
|
||||
- 统计信息收集
|
||||
|
||||
参数:
|
||||
config: 阈值配置
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Optional[ThresholdConfig] = None,
|
||||
):
|
||||
"""
|
||||
初始化动态阈值过滤器
|
||||
|
||||
Args:
|
||||
config: 阈值配置
|
||||
"""
|
||||
self.config = config or ThresholdConfig()
|
||||
|
||||
# 统计信息
|
||||
self._total_filtered = 0
|
||||
self._total_processed = 0
|
||||
self._threshold_history: List[float] = []
|
||||
|
||||
logger.info(
|
||||
f"DynamicThresholdFilter 初始化: "
|
||||
f"method={self.config.method.value}, "
|
||||
f"min_threshold={self.config.min_threshold}"
|
||||
)
|
||||
|
||||
def filter(
|
||||
self,
|
||||
results: List[RetrievalResult],
|
||||
return_threshold: bool = False,
|
||||
) -> Union[List[RetrievalResult], Tuple[List[RetrievalResult], float]]:
|
||||
"""
|
||||
过滤检索结果
|
||||
|
||||
Args:
|
||||
results: 检索结果列表
|
||||
return_threshold: 是否返回使用的阈值
|
||||
|
||||
Returns:
|
||||
过滤后的结果列表,或 (结果列表, 阈值) 元组
|
||||
"""
|
||||
if not results:
|
||||
logger.debug("结果列表为空,无需过滤")
|
||||
return ([], 0.0) if return_threshold else []
|
||||
|
||||
self._total_processed += len(results)
|
||||
|
||||
# 提取分数
|
||||
scores = np.array([r.score for r in results])
|
||||
|
||||
# 计算阈值
|
||||
threshold = self._compute_threshold(scores, results)
|
||||
|
||||
# 记录阈值
|
||||
self._threshold_history.append(threshold)
|
||||
|
||||
# 应用阈值过滤
|
||||
filtered_results = [
|
||||
r for r in results
|
||||
if r.score >= threshold
|
||||
]
|
||||
|
||||
# 确保至少保留min_results个结果
|
||||
if len(filtered_results) < self.config.min_results:
|
||||
# 按分数排序,取前min_results个
|
||||
sorted_results = sorted(results, key=lambda x: x.score, reverse=True)
|
||||
filtered_results = sorted_results[:self.config.min_results]
|
||||
threshold = filtered_results[-1].score if filtered_results else 0.0
|
||||
|
||||
self._total_filtered += len(results) - len(filtered_results)
|
||||
|
||||
logger.info(
|
||||
f"过滤完成: {len(results)} -> {len(filtered_results)} "
|
||||
f"(threshold={threshold:.3f})"
|
||||
)
|
||||
|
||||
if return_threshold:
|
||||
return filtered_results, threshold
|
||||
return filtered_results
|
||||
|
||||
def _compute_threshold(
|
||||
self,
|
||||
scores: np.ndarray,
|
||||
results: List[RetrievalResult],
|
||||
) -> float:
|
||||
"""
|
||||
计算阈值
|
||||
|
||||
Args:
|
||||
scores: 分数数组
|
||||
results: 检索结果列表
|
||||
|
||||
Returns:
|
||||
阈值
|
||||
"""
|
||||
if self.config.method == ThresholdMethod.PERCENTILE:
|
||||
threshold = self._percentile_threshold(scores)
|
||||
elif self.config.method == ThresholdMethod.STD_DEV:
|
||||
threshold = self._std_dev_threshold(scores)
|
||||
elif self.config.method == ThresholdMethod.GAP_DETECTION:
|
||||
threshold = self._gap_detection_threshold(scores)
|
||||
else: # ADAPTIVE
|
||||
# 自适应方法:综合多种方法
|
||||
thresholds = [
|
||||
self._percentile_threshold(scores),
|
||||
self._std_dev_threshold(scores),
|
||||
self._gap_detection_threshold(scores),
|
||||
]
|
||||
# 使用中位数作为最终阈值
|
||||
threshold = float(np.median(thresholds))
|
||||
|
||||
# 限制在[min_threshold, max_threshold]范围内
|
||||
threshold = np.clip(
|
||||
threshold,
|
||||
self.config.min_threshold,
|
||||
self.config.max_threshold,
|
||||
)
|
||||
|
||||
# 自动调整
|
||||
if self.config.enable_auto_adjust:
|
||||
threshold = self._auto_adjust_threshold(threshold, scores)
|
||||
|
||||
return float(threshold)
|
||||
|
||||
def _percentile_threshold(self, scores: np.ndarray) -> float:
|
||||
"""
|
||||
基于百分位数计算阈值
|
||||
|
||||
Args:
|
||||
scores: 分数数组
|
||||
|
||||
Returns:
|
||||
阈值
|
||||
"""
|
||||
percentile = self.config.percentile
|
||||
threshold = float(np.percentile(scores, percentile))
|
||||
|
||||
logger.debug(f"百分位数阈值: {threshold:.3f} (percentile={percentile})")
|
||||
return threshold
|
||||
|
||||
def _std_dev_threshold(self, scores: np.ndarray) -> float:
|
||||
"""
|
||||
基于标准差计算阈值
|
||||
|
||||
threshold = mean - std_multiplier * std
|
||||
|
||||
Args:
|
||||
scores: 分数数组
|
||||
|
||||
Returns:
|
||||
阈值
|
||||
"""
|
||||
mean = float(np.mean(scores))
|
||||
std = float(np.std(scores))
|
||||
multiplier = self.config.std_multiplier
|
||||
|
||||
threshold = mean - multiplier * std
|
||||
|
||||
logger.debug(f"标准差阈值: {threshold:.3f} (mean={mean:.3f}, std={std:.3f})")
|
||||
return threshold
|
||||
|
||||
def _gap_detection_threshold(self, scores: np.ndarray) -> float:
|
||||
"""
|
||||
基于跳变检测计算阈值
|
||||
|
||||
找到分数分布中最大的"跳变"位置,以此为阈值
|
||||
|
||||
Args:
|
||||
scores: 分数数组(降序排列)
|
||||
|
||||
Returns:
|
||||
阈值
|
||||
"""
|
||||
# 降序排列
|
||||
sorted_scores = np.sort(scores)[::-1]
|
||||
|
||||
if len(sorted_scores) < 2:
|
||||
return float(sorted_scores[0]) if len(sorted_scores) > 0 else 0.0
|
||||
|
||||
# 计算相邻分数的差值
|
||||
gaps = np.diff(sorted_scores)
|
||||
|
||||
# 找到最大的跳变位置
|
||||
max_gap_idx = int(np.argmax(gaps))
|
||||
|
||||
# 阈值为跳变后的分数
|
||||
threshold = float(sorted_scores[max_gap_idx + 1])
|
||||
|
||||
logger.debug(
|
||||
f"跳变检测阈值: {threshold:.3f} "
|
||||
f"(gap={gaps[max_gap_idx]:.3f}, idx={max_gap_idx})"
|
||||
)
|
||||
return threshold
|
||||
|
||||
def _auto_adjust_threshold(
|
||||
self,
|
||||
threshold: float,
|
||||
scores: np.ndarray,
|
||||
) -> float:
|
||||
"""
|
||||
自动调整阈值
|
||||
|
||||
基于历史阈值和当前分数分布调整
|
||||
|
||||
Args:
|
||||
threshold: 当前阈值
|
||||
scores: 分数数组
|
||||
|
||||
Returns:
|
||||
调整后的阈值
|
||||
"""
|
||||
if not self._threshold_history:
|
||||
return threshold
|
||||
|
||||
# 计算历史阈值的移动平均
|
||||
recent_thresholds = self._threshold_history[-10:] # 最近10次
|
||||
avg_threshold = float(np.mean(recent_thresholds))
|
||||
|
||||
# 当前阈值与历史平均的差异
|
||||
diff = threshold - avg_threshold
|
||||
|
||||
# 如果差异过大(>0.2),向历史平均靠拢
|
||||
if abs(diff) > 0.2:
|
||||
adjusted_threshold = avg_threshold + diff * 0.5 # 向中间靠拢50%
|
||||
logger.debug(
|
||||
f"阈值调整: {threshold:.3f} -> {adjusted_threshold:.3f} "
|
||||
f"(历史平均={avg_threshold:.3f})"
|
||||
)
|
||||
return adjusted_threshold
|
||||
|
||||
return threshold
|
||||
|
||||
def filter_by_confidence(
|
||||
self,
|
||||
results: List[RetrievalResult],
|
||||
min_confidence: float = 0.5,
|
||||
) -> List[RetrievalResult]:
|
||||
"""
|
||||
基于置信度过滤结果
|
||||
|
||||
Args:
|
||||
results: 检索结果列表
|
||||
min_confidence: 最小置信度
|
||||
|
||||
Returns:
|
||||
过滤后的结果列表
|
||||
"""
|
||||
filtered = []
|
||||
for result in results:
|
||||
# 对于关系结果,使用confidence字段
|
||||
if result.result_type == "relation":
|
||||
confidence = result.metadata.get("confidence", 1.0)
|
||||
if confidence >= min_confidence:
|
||||
filtered.append(result)
|
||||
else:
|
||||
# 对于段落结果,直接使用分数
|
||||
if result.score >= min_confidence:
|
||||
filtered.append(result)
|
||||
|
||||
logger.info(
|
||||
f"置信度过滤: {len(results)} -> {len(filtered)} "
|
||||
f"(min_confidence={min_confidence})"
|
||||
)
|
||||
|
||||
return filtered
|
||||
|
||||
def filter_by_diversity(
|
||||
self,
|
||||
results: List[RetrievalResult],
|
||||
similarity_threshold: float = 0.9,
|
||||
top_k: int = 10,
|
||||
) -> List[RetrievalResult]:
|
||||
"""
|
||||
基于多样性过滤结果(去除重复)
|
||||
|
||||
Args:
|
||||
results: 检索结果列表
|
||||
similarity_threshold: 相似度阈值(高于此值视为重复)
|
||||
top_k: 最多保留结果数
|
||||
|
||||
Returns:
|
||||
过滤后的结果列表
|
||||
"""
|
||||
if not results:
|
||||
return []
|
||||
|
||||
# 按分数排序
|
||||
sorted_results = sorted(results, key=lambda x: x.score, reverse=True)
|
||||
|
||||
# 贪心选择:选择与已选结果相似度低的结果
|
||||
selected = []
|
||||
selected_hashes = []
|
||||
|
||||
for result in sorted_results:
|
||||
if len(selected) >= top_k:
|
||||
break
|
||||
|
||||
# 检查与已选结果的相似度
|
||||
is_duplicate = False
|
||||
for selected_hash in selected_hashes:
|
||||
# 简单判断:基于hash的前缀
|
||||
if result.hash_value[:8] == selected_hash[:8]:
|
||||
is_duplicate = True
|
||||
break
|
||||
|
||||
if not is_duplicate:
|
||||
selected.append(result)
|
||||
selected_hashes.append(result.hash_value)
|
||||
|
||||
logger.info(
|
||||
f"多样性过滤: {len(results)} -> {len(selected)} "
|
||||
f"(similarity_threshold={similarity_threshold})"
|
||||
)
|
||||
|
||||
return selected
|
||||
|
||||
def get_statistics(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取统计信息
|
||||
|
||||
Returns:
|
||||
统计信息字典
|
||||
"""
|
||||
filter_rate = (
|
||||
self._total_filtered / self._total_processed
|
||||
if self._total_processed > 0
|
||||
else 0.0
|
||||
)
|
||||
|
||||
stats = {
|
||||
"config": {
|
||||
"method": self.config.method.value,
|
||||
"min_threshold": self.config.min_threshold,
|
||||
"max_threshold": self.config.max_threshold,
|
||||
"percentile": self.config.percentile,
|
||||
"std_multiplier": self.config.std_multiplier,
|
||||
"min_results": self.config.min_results,
|
||||
"enable_auto_adjust": self.config.enable_auto_adjust,
|
||||
},
|
||||
"statistics": {
|
||||
"total_processed": self._total_processed,
|
||||
"total_filtered": self._total_filtered,
|
||||
"filter_rate": filter_rate,
|
||||
"avg_threshold": float(np.mean(self._threshold_history))
|
||||
if self._threshold_history else 0.0,
|
||||
"threshold_count": len(self._threshold_history),
|
||||
},
|
||||
}
|
||||
|
||||
if self._threshold_history:
|
||||
stats["statistics"]["min_threshold_used"] = float(np.min(self._threshold_history))
|
||||
stats["statistics"]["max_threshold_used"] = float(np.max(self._threshold_history))
|
||||
|
||||
return stats
|
||||
|
||||
def reset_statistics(self) -> None:
|
||||
"""重置统计信息"""
|
||||
self._total_filtered = 0
|
||||
self._total_processed = 0
|
||||
self._threshold_history.clear()
|
||||
logger.info("统计信息已重置")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"DynamicThresholdFilter("
|
||||
f"method={self.config.method.value}, "
|
||||
f"min_threshold={self.config.min_threshold}, "
|
||||
f"filtered={self._total_filtered}/{self._total_processed})"
|
||||
)
|
||||
16
src/A_memorix/core/runtime/__init__.py
Normal file
16
src/A_memorix/core/runtime/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""SDK runtime exports for A_Memorix."""
|
||||
|
||||
from .search_runtime_initializer import (
|
||||
SearchRuntimeBundle,
|
||||
SearchRuntimeInitializer,
|
||||
build_search_runtime,
|
||||
)
|
||||
from .sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel
|
||||
|
||||
__all__ = [
|
||||
"SearchRuntimeBundle",
|
||||
"SearchRuntimeInitializer",
|
||||
"build_search_runtime",
|
||||
"KernelSearchRequest",
|
||||
"SDKMemoryKernel",
|
||||
]
|
||||
265
src/A_memorix/core/runtime/lifecycle_orchestrator.py
Normal file
265
src/A_memorix/core/runtime/lifecycle_orchestrator.py
Normal file
@@ -0,0 +1,265 @@
|
||||
"""Lifecycle bootstrap/teardown helpers extracted from plugin.py."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from ...paths import default_data_dir, resolve_repo_path
|
||||
from ..embedding import create_embedding_api_adapter
|
||||
from ..retrieval import SparseBM25Config, SparseBM25Index
|
||||
from ..storage import (
|
||||
GraphStore,
|
||||
MetadataStore,
|
||||
QuantizationType,
|
||||
SparseMatrixFormat,
|
||||
VectorStore,
|
||||
)
|
||||
from ..utils.runtime_self_check import ensure_runtime_self_check
|
||||
from ..utils.relation_write_service import RelationWriteService
|
||||
|
||||
logger = get_logger("A_Memorix.LifecycleOrchestrator")
|
||||
|
||||
|
||||
async def ensure_initialized(plugin: Any) -> None:
|
||||
if plugin._initialized:
|
||||
plugin._runtime_ready = plugin._check_storage_ready()
|
||||
return
|
||||
|
||||
async with plugin._init_lock:
|
||||
if plugin._initialized:
|
||||
plugin._runtime_ready = plugin._check_storage_ready()
|
||||
return
|
||||
|
||||
logger.info("A_Memorix 插件正在异步初始化存储组件...")
|
||||
plugin._validate_runtime_config()
|
||||
await initialize_storage_async(plugin)
|
||||
report = await ensure_runtime_self_check(plugin, force=True)
|
||||
if not bool(report.get("ok", False)):
|
||||
logger.error(
|
||||
"A_Memorix runtime self-check failed: "
|
||||
f"{report.get('message', 'unknown')}; "
|
||||
"建议执行 python src/A_memorix/scripts/runtime_self_check.py --json"
|
||||
)
|
||||
|
||||
if plugin.graph_store and plugin.metadata_store:
|
||||
relation_count = plugin.metadata_store.count_relations()
|
||||
if relation_count > 0 and not plugin.graph_store.has_edge_hash_map():
|
||||
raise RuntimeError(
|
||||
"检测到 relations 数据存在但 edge-hash-map 为空。"
|
||||
" 请先执行 scripts/release_vnext_migrate.py migrate。"
|
||||
)
|
||||
|
||||
plugin._initialized = True
|
||||
plugin._runtime_ready = plugin._check_storage_ready()
|
||||
plugin._update_plugin_config()
|
||||
logger.info("A_Memorix 插件异步初始化成功")
|
||||
|
||||
|
||||
def start_background_tasks(plugin: Any) -> None:
|
||||
"""Start background tasks idempotently."""
|
||||
if not hasattr(plugin, "_episode_generation_task"):
|
||||
plugin._episode_generation_task = None
|
||||
|
||||
if (
|
||||
plugin.get_config("summarization.enabled", True)
|
||||
and plugin.get_config("schedule.enabled", True)
|
||||
and (plugin._scheduled_import_task is None or plugin._scheduled_import_task.done())
|
||||
):
|
||||
plugin._scheduled_import_task = asyncio.create_task(plugin._scheduled_import_loop())
|
||||
|
||||
if (
|
||||
plugin.get_config("advanced.enable_auto_save", True)
|
||||
and (plugin._auto_save_task is None or plugin._auto_save_task.done())
|
||||
):
|
||||
plugin._auto_save_task = asyncio.create_task(plugin._auto_save_loop())
|
||||
|
||||
if (
|
||||
plugin.get_config("person_profile.enabled", True)
|
||||
and (plugin._person_profile_refresh_task is None or plugin._person_profile_refresh_task.done())
|
||||
):
|
||||
plugin._person_profile_refresh_task = asyncio.create_task(plugin._person_profile_refresh_loop())
|
||||
|
||||
if plugin._memory_maintenance_task is None or plugin._memory_maintenance_task.done():
|
||||
plugin._memory_maintenance_task = asyncio.create_task(plugin._memory_maintenance_loop())
|
||||
|
||||
rv_cfg = plugin.get_config("retrieval.relation_vectorization", {}) or {}
|
||||
if isinstance(rv_cfg, dict):
|
||||
rv_enabled = bool(rv_cfg.get("enabled", False))
|
||||
rv_backfill = bool(rv_cfg.get("backfill_enabled", False))
|
||||
else:
|
||||
rv_enabled = False
|
||||
rv_backfill = False
|
||||
if rv_enabled and rv_backfill and (
|
||||
plugin._relation_vector_backfill_task is None or plugin._relation_vector_backfill_task.done()
|
||||
):
|
||||
plugin._relation_vector_backfill_task = asyncio.create_task(plugin._relation_vector_backfill_loop())
|
||||
|
||||
episode_task = getattr(plugin, "_episode_generation_task", None)
|
||||
episode_loop = getattr(plugin, "_episode_generation_loop", None)
|
||||
if (
|
||||
callable(episode_loop)
|
||||
and bool(plugin.get_config("episode.enabled", True))
|
||||
and bool(plugin.get_config("episode.generation_enabled", True))
|
||||
and (episode_task is None or episode_task.done())
|
||||
):
|
||||
plugin._episode_generation_task = asyncio.create_task(episode_loop())
|
||||
|
||||
|
||||
async def cancel_background_tasks(plugin: Any) -> None:
|
||||
"""Cancel all background tasks and wait for cleanup."""
|
||||
tasks = [
|
||||
("scheduled_import", plugin._scheduled_import_task),
|
||||
("auto_save", plugin._auto_save_task),
|
||||
("person_profile_refresh", plugin._person_profile_refresh_task),
|
||||
("memory_maintenance", plugin._memory_maintenance_task),
|
||||
("relation_vector_backfill", plugin._relation_vector_backfill_task),
|
||||
("episode_generation", getattr(plugin, "_episode_generation_task", None)),
|
||||
]
|
||||
for _, task in tasks:
|
||||
if task and not task.done():
|
||||
task.cancel()
|
||||
|
||||
for name, task in tasks:
|
||||
if not task:
|
||||
continue
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(f"后台任务 {name} 退出异常: {e}")
|
||||
|
||||
plugin._scheduled_import_task = None
|
||||
plugin._auto_save_task = None
|
||||
plugin._person_profile_refresh_task = None
|
||||
plugin._memory_maintenance_task = None
|
||||
plugin._relation_vector_backfill_task = None
|
||||
plugin._episode_generation_task = None
|
||||
|
||||
|
||||
async def initialize_storage_async(plugin: Any) -> None:
|
||||
"""Initialize storage components asynchronously."""
|
||||
data_dir_str = plugin.get_config("storage.data_dir", "./data")
|
||||
data_dir = resolve_repo_path(data_dir_str, fallback=default_data_dir())
|
||||
|
||||
logger.info(f"A_Memorix 数据存储路径: {data_dir}")
|
||||
data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
plugin.embedding_manager = create_embedding_api_adapter(
|
||||
batch_size=plugin.get_config("embedding.batch_size", 32),
|
||||
max_concurrent=plugin.get_config("embedding.max_concurrent", 5),
|
||||
default_dimension=plugin.get_config("embedding.dimension", 1024),
|
||||
model_name=plugin.get_config("embedding.model_name", "auto"),
|
||||
retry_config=plugin.get_config("embedding.retry", {}),
|
||||
)
|
||||
logger.info("嵌入 API 适配器初始化完成")
|
||||
|
||||
try:
|
||||
detected_dimension = await plugin.embedding_manager._detect_dimension()
|
||||
logger.info(f"嵌入维度检测成功: {detected_dimension}")
|
||||
except Exception as e:
|
||||
logger.warning(f"嵌入维度检测失败: {e},使用默认值")
|
||||
detected_dimension = plugin.embedding_manager.default_dimension
|
||||
|
||||
quantization_str = plugin.get_config("embedding.quantization_type", "int8")
|
||||
if str(quantization_str or "").strip().lower() != "int8":
|
||||
raise ValueError("embedding.quantization_type 在 vNext 仅允许 int8(SQ8)。")
|
||||
quantization_type = QuantizationType.INT8
|
||||
|
||||
plugin.vector_store = VectorStore(
|
||||
dimension=detected_dimension,
|
||||
quantization_type=quantization_type,
|
||||
data_dir=data_dir / "vectors",
|
||||
)
|
||||
plugin.vector_store.min_train_threshold = plugin.get_config("embedding.min_train_threshold", 40)
|
||||
logger.info(
|
||||
"向量存储初始化完成("
|
||||
f"维度: {detected_dimension}, "
|
||||
f"训练阈值: {plugin.vector_store.min_train_threshold})"
|
||||
)
|
||||
|
||||
matrix_format_str = plugin.get_config("graph.sparse_matrix_format", "csr")
|
||||
matrix_format_map = {
|
||||
"csr": SparseMatrixFormat.CSR,
|
||||
"csc": SparseMatrixFormat.CSC,
|
||||
}
|
||||
matrix_format = matrix_format_map.get(matrix_format_str, SparseMatrixFormat.CSR)
|
||||
|
||||
plugin.graph_store = GraphStore(
|
||||
matrix_format=matrix_format,
|
||||
data_dir=data_dir / "graph",
|
||||
)
|
||||
logger.info("图存储初始化完成")
|
||||
|
||||
plugin.metadata_store = MetadataStore(data_dir=data_dir / "metadata")
|
||||
plugin.metadata_store.connect()
|
||||
logger.info("元数据存储初始化完成")
|
||||
|
||||
plugin.relation_write_service = RelationWriteService(
|
||||
metadata_store=plugin.metadata_store,
|
||||
graph_store=plugin.graph_store,
|
||||
vector_store=plugin.vector_store,
|
||||
embedding_manager=plugin.embedding_manager,
|
||||
)
|
||||
logger.info("关系写入服务初始化完成")
|
||||
|
||||
sparse_cfg_raw = plugin.get_config("retrieval.sparse", {}) or {}
|
||||
if not isinstance(sparse_cfg_raw, dict):
|
||||
sparse_cfg_raw = {}
|
||||
try:
|
||||
sparse_cfg = SparseBM25Config(**sparse_cfg_raw)
|
||||
except Exception as e:
|
||||
logger.warning(f"sparse 配置非法,回退默认配置: {e}")
|
||||
sparse_cfg = SparseBM25Config()
|
||||
plugin.sparse_index = SparseBM25Index(
|
||||
metadata_store=plugin.metadata_store,
|
||||
config=sparse_cfg,
|
||||
)
|
||||
logger.info(
|
||||
"稀疏检索组件初始化完成: "
|
||||
f"enabled={sparse_cfg.enabled}, "
|
||||
f"lazy_load={sparse_cfg.lazy_load}, "
|
||||
f"mode={sparse_cfg.mode}, "
|
||||
f"tokenizer={sparse_cfg.tokenizer_mode}"
|
||||
)
|
||||
if sparse_cfg.enabled and not sparse_cfg.lazy_load:
|
||||
plugin.sparse_index.ensure_loaded()
|
||||
|
||||
if plugin.vector_store.has_data():
|
||||
try:
|
||||
plugin.vector_store.load()
|
||||
logger.info(f"向量数据已加载,共 {plugin.vector_store.num_vectors} 个向量")
|
||||
except Exception as e:
|
||||
logger.warning(f"加载向量数据失败: {e}")
|
||||
|
||||
try:
|
||||
warmup_summary = plugin.vector_store.warmup_index(force_train=True)
|
||||
if warmup_summary.get("ok"):
|
||||
logger.info(
|
||||
"向量索引预热完成: "
|
||||
f"trained={warmup_summary.get('trained')}, "
|
||||
f"index_ntotal={warmup_summary.get('index_ntotal')}, "
|
||||
f"fallback_ntotal={warmup_summary.get('fallback_ntotal')}, "
|
||||
f"bin_count={warmup_summary.get('bin_count')}, "
|
||||
f"duration_ms={float(warmup_summary.get('duration_ms', 0.0)):.2f}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"向量索引预热失败,继续启用 sparse 降级路径: "
|
||||
f"{warmup_summary.get('error', 'unknown')}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"向量索引预热异常,继续启用 sparse 降级路径: {e}")
|
||||
|
||||
if plugin.graph_store.has_data():
|
||||
try:
|
||||
plugin.graph_store.load()
|
||||
logger.info(f"图数据已加载,共 {plugin.graph_store.num_nodes} 个节点")
|
||||
except Exception as e:
|
||||
logger.warning(f"加载图数据失败: {e}")
|
||||
|
||||
logger.info(f"知识库数据目录: {data_dir}")
|
||||
4421
src/A_memorix/core/runtime/sdk_memory_kernel.py
Normal file
4421
src/A_memorix/core/runtime/sdk_memory_kernel.py
Normal file
File diff suppressed because it is too large
Load Diff
240
src/A_memorix/core/runtime/search_runtime_initializer.py
Normal file
240
src/A_memorix/core/runtime/search_runtime_initializer.py
Normal file
@@ -0,0 +1,240 @@
|
||||
"""Shared runtime initializer for Action/Tool/Command retrieval components."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from ..retrieval import (
|
||||
DualPathRetriever,
|
||||
DualPathRetrieverConfig,
|
||||
DynamicThresholdFilter,
|
||||
FusionConfig,
|
||||
GraphRelationRecallConfig,
|
||||
RelationIntentConfig,
|
||||
RetrievalStrategy,
|
||||
SparseBM25Config,
|
||||
ThresholdConfig,
|
||||
ThresholdMethod,
|
||||
)
|
||||
|
||||
_logger = get_logger("A_Memorix.SearchRuntimeInitializer")
|
||||
|
||||
_REQUIRED_COMPONENT_KEYS = (
|
||||
"vector_store",
|
||||
"graph_store",
|
||||
"metadata_store",
|
||||
"embedding_manager",
|
||||
)
|
||||
|
||||
|
||||
def _get_config_value(config: Optional[dict], key: str, default: Any = None) -> Any:
|
||||
if not isinstance(config, dict):
|
||||
return default
|
||||
current: Any = config
|
||||
for part in key.split("."):
|
||||
if isinstance(current, dict) and part in current:
|
||||
current = current[part]
|
||||
else:
|
||||
return default
|
||||
return current
|
||||
|
||||
|
||||
def _safe_dict(value: Any) -> Dict[str, Any]:
|
||||
return value if isinstance(value, dict) else {}
|
||||
|
||||
|
||||
def _resolve_debug_enabled(plugin_config: Optional[dict]) -> bool:
|
||||
advanced = _get_config_value(plugin_config, "advanced", {})
|
||||
if isinstance(advanced, dict):
|
||||
return bool(advanced.get("debug", False))
|
||||
return bool(_get_config_value(plugin_config, "debug", False))
|
||||
|
||||
|
||||
@dataclass
|
||||
class SearchRuntimeBundle:
|
||||
"""Resolved runtime components and initialized retriever/filter."""
|
||||
|
||||
vector_store: Optional[Any] = None
|
||||
graph_store: Optional[Any] = None
|
||||
metadata_store: Optional[Any] = None
|
||||
embedding_manager: Optional[Any] = None
|
||||
sparse_index: Optional[Any] = None
|
||||
retriever: Optional[DualPathRetriever] = None
|
||||
threshold_filter: Optional[DynamicThresholdFilter] = None
|
||||
error: str = ""
|
||||
|
||||
@property
|
||||
def ready(self) -> bool:
|
||||
return (
|
||||
self.retriever is not None
|
||||
and self.vector_store is not None
|
||||
and self.graph_store is not None
|
||||
and self.metadata_store is not None
|
||||
and self.embedding_manager is not None
|
||||
)
|
||||
|
||||
|
||||
def _resolve_runtime_components(plugin_config: Optional[dict]) -> SearchRuntimeBundle:
|
||||
bundle = SearchRuntimeBundle(
|
||||
vector_store=_get_config_value(plugin_config, "vector_store"),
|
||||
graph_store=_get_config_value(plugin_config, "graph_store"),
|
||||
metadata_store=_get_config_value(plugin_config, "metadata_store"),
|
||||
embedding_manager=_get_config_value(plugin_config, "embedding_manager"),
|
||||
sparse_index=_get_config_value(plugin_config, "sparse_index"),
|
||||
)
|
||||
|
||||
missing_required = any(
|
||||
getattr(bundle, key) is None for key in _REQUIRED_COMPONENT_KEYS
|
||||
)
|
||||
if not missing_required:
|
||||
return bundle
|
||||
|
||||
try:
|
||||
from ...runtime_registry import get_runtime_components
|
||||
|
||||
instances = get_runtime_components()
|
||||
except Exception:
|
||||
instances = {}
|
||||
|
||||
if not isinstance(instances, dict) or not instances:
|
||||
return bundle
|
||||
|
||||
if bundle.vector_store is None:
|
||||
bundle.vector_store = instances.get("vector_store")
|
||||
if bundle.graph_store is None:
|
||||
bundle.graph_store = instances.get("graph_store")
|
||||
if bundle.metadata_store is None:
|
||||
bundle.metadata_store = instances.get("metadata_store")
|
||||
if bundle.embedding_manager is None:
|
||||
bundle.embedding_manager = instances.get("embedding_manager")
|
||||
if bundle.sparse_index is None:
|
||||
bundle.sparse_index = instances.get("sparse_index")
|
||||
return bundle
|
||||
|
||||
|
||||
def build_search_runtime(
|
||||
plugin_config: Optional[dict],
|
||||
logger_obj: Optional[Any],
|
||||
owner_tag: str,
|
||||
*,
|
||||
log_prefix: str = "",
|
||||
) -> SearchRuntimeBundle:
|
||||
"""Build retriever + threshold filter with unified fallback/config parsing."""
|
||||
|
||||
log = logger_obj or _logger
|
||||
owner = str(owner_tag or "runtime").strip().lower() or "runtime"
|
||||
prefix = str(log_prefix or "").strip()
|
||||
prefix_text = f"{prefix} " if prefix else ""
|
||||
|
||||
runtime = _resolve_runtime_components(plugin_config)
|
||||
if any(getattr(runtime, key) is None for key in _REQUIRED_COMPONENT_KEYS):
|
||||
runtime.error = "存储组件未完全初始化"
|
||||
log.warning(f"{prefix_text}[{owner}] 存储组件未完全初始化,无法使用检索功能")
|
||||
return runtime
|
||||
|
||||
sparse_cfg_raw = _safe_dict(_get_config_value(plugin_config, "retrieval.sparse", {}) or {})
|
||||
fusion_cfg_raw = _safe_dict(_get_config_value(plugin_config, "retrieval.fusion", {}) or {})
|
||||
relation_intent_cfg_raw = _safe_dict(
|
||||
_get_config_value(plugin_config, "retrieval.search.relation_intent", {}) or {}
|
||||
)
|
||||
graph_recall_cfg_raw = _safe_dict(
|
||||
_get_config_value(plugin_config, "retrieval.search.graph_recall", {}) or {}
|
||||
)
|
||||
|
||||
try:
|
||||
sparse_cfg = SparseBM25Config(**sparse_cfg_raw)
|
||||
except Exception as e:
|
||||
log.warning(f"{prefix_text}[{owner}] sparse 配置非法,回退默认: {e}")
|
||||
sparse_cfg = SparseBM25Config()
|
||||
|
||||
try:
|
||||
fusion_cfg = FusionConfig(**fusion_cfg_raw)
|
||||
except Exception as e:
|
||||
log.warning(f"{prefix_text}[{owner}] fusion 配置非法,回退默认: {e}")
|
||||
fusion_cfg = FusionConfig()
|
||||
|
||||
try:
|
||||
relation_intent_cfg = RelationIntentConfig(**relation_intent_cfg_raw)
|
||||
except Exception as e:
|
||||
log.warning(f"{prefix_text}[{owner}] relation_intent 配置非法,回退默认: {e}")
|
||||
relation_intent_cfg = RelationIntentConfig()
|
||||
|
||||
try:
|
||||
graph_recall_cfg = GraphRelationRecallConfig(**graph_recall_cfg_raw)
|
||||
except Exception as e:
|
||||
log.warning(f"{prefix_text}[{owner}] graph_recall 配置非法,回退默认: {e}")
|
||||
graph_recall_cfg = GraphRelationRecallConfig()
|
||||
|
||||
try:
|
||||
config = DualPathRetrieverConfig(
|
||||
top_k_paragraphs=_get_config_value(plugin_config, "retrieval.top_k_paragraphs", 20),
|
||||
top_k_relations=_get_config_value(plugin_config, "retrieval.top_k_relations", 10),
|
||||
top_k_final=_get_config_value(plugin_config, "retrieval.top_k_final", 10),
|
||||
alpha=_get_config_value(plugin_config, "retrieval.alpha", 0.5),
|
||||
enable_ppr=_get_config_value(plugin_config, "retrieval.enable_ppr", True),
|
||||
ppr_alpha=_get_config_value(plugin_config, "retrieval.ppr_alpha", 0.85),
|
||||
ppr_timeout_seconds=_get_config_value(
|
||||
plugin_config, "retrieval.ppr_timeout_seconds", 1.5
|
||||
),
|
||||
ppr_concurrency_limit=_get_config_value(
|
||||
plugin_config, "retrieval.ppr_concurrency_limit", 4
|
||||
),
|
||||
enable_parallel=_get_config_value(plugin_config, "retrieval.enable_parallel", True),
|
||||
retrieval_strategy=RetrievalStrategy.DUAL_PATH,
|
||||
debug=_resolve_debug_enabled(plugin_config),
|
||||
sparse=sparse_cfg,
|
||||
fusion=fusion_cfg,
|
||||
relation_intent=relation_intent_cfg,
|
||||
graph_recall=graph_recall_cfg,
|
||||
)
|
||||
|
||||
runtime.retriever = DualPathRetriever(
|
||||
vector_store=runtime.vector_store,
|
||||
graph_store=runtime.graph_store,
|
||||
metadata_store=runtime.metadata_store,
|
||||
embedding_manager=runtime.embedding_manager,
|
||||
sparse_index=runtime.sparse_index,
|
||||
config=config,
|
||||
)
|
||||
|
||||
threshold_config = ThresholdConfig(
|
||||
method=ThresholdMethod.ADAPTIVE,
|
||||
min_threshold=_get_config_value(plugin_config, "threshold.min_threshold", 0.3),
|
||||
max_threshold=_get_config_value(plugin_config, "threshold.max_threshold", 0.95),
|
||||
percentile=_get_config_value(plugin_config, "threshold.percentile", 75.0),
|
||||
std_multiplier=_get_config_value(plugin_config, "threshold.std_multiplier", 1.5),
|
||||
min_results=_get_config_value(plugin_config, "threshold.min_results", 3),
|
||||
enable_auto_adjust=_get_config_value(plugin_config, "threshold.enable_auto_adjust", True),
|
||||
)
|
||||
runtime.threshold_filter = DynamicThresholdFilter(threshold_config)
|
||||
runtime.error = ""
|
||||
log.info(f"{prefix_text}[{owner}] 检索运行时初始化完成")
|
||||
except Exception as e:
|
||||
runtime.retriever = None
|
||||
runtime.threshold_filter = None
|
||||
runtime.error = str(e)
|
||||
log.error(f"{prefix_text}[{owner}] 检索运行时初始化失败: {e}")
|
||||
|
||||
return runtime
|
||||
|
||||
|
||||
class SearchRuntimeInitializer:
|
||||
"""Compatibility wrapper around the function style initializer."""
|
||||
|
||||
@staticmethod
|
||||
def build_search_runtime(
|
||||
plugin_config: Optional[dict],
|
||||
logger_obj: Optional[Any],
|
||||
owner_tag: str,
|
||||
*,
|
||||
log_prefix: str = "",
|
||||
) -> SearchRuntimeBundle:
|
||||
return build_search_runtime(
|
||||
plugin_config=plugin_config,
|
||||
logger_obj=logger_obj,
|
||||
owner_tag=owner_tag,
|
||||
log_prefix=log_prefix,
|
||||
)
|
||||
53
src/A_memorix/core/storage/__init__.py
Normal file
53
src/A_memorix/core/storage/__init__.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""存储层"""
|
||||
|
||||
from .vector_store import VectorStore, QuantizationType
|
||||
from .graph_store import GraphStore, SparseMatrixFormat
|
||||
from .metadata_store import MetadataStore
|
||||
from .knowledge_types import (
|
||||
ImportStrategy,
|
||||
KnowledgeType,
|
||||
allowed_import_strategy_values,
|
||||
allowed_knowledge_type_values,
|
||||
get_knowledge_type_from_string,
|
||||
get_import_strategy_from_string,
|
||||
parse_import_strategy,
|
||||
resolve_stored_knowledge_type,
|
||||
should_extract_relations,
|
||||
get_default_chunk_size,
|
||||
get_type_display_name,
|
||||
validate_stored_knowledge_type,
|
||||
)
|
||||
from .type_detection import (
|
||||
detect_knowledge_type,
|
||||
get_type_from_user_input,
|
||||
looks_like_factual_text,
|
||||
looks_like_quote_text,
|
||||
looks_like_structured_text,
|
||||
select_import_strategy,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"VectorStore",
|
||||
"GraphStore",
|
||||
"MetadataStore",
|
||||
"QuantizationType",
|
||||
"SparseMatrixFormat",
|
||||
"ImportStrategy",
|
||||
"KnowledgeType",
|
||||
"allowed_import_strategy_values",
|
||||
"allowed_knowledge_type_values",
|
||||
"get_knowledge_type_from_string",
|
||||
"get_import_strategy_from_string",
|
||||
"parse_import_strategy",
|
||||
"resolve_stored_knowledge_type",
|
||||
"should_extract_relations",
|
||||
"get_default_chunk_size",
|
||||
"get_type_display_name",
|
||||
"validate_stored_knowledge_type",
|
||||
"detect_knowledge_type",
|
||||
"get_type_from_user_input",
|
||||
"looks_like_factual_text",
|
||||
"looks_like_quote_text",
|
||||
"looks_like_structured_text",
|
||||
"select_import_strategy",
|
||||
]
|
||||
1448
src/A_memorix/core/storage/graph_store.py
Normal file
1448
src/A_memorix/core/storage/graph_store.py
Normal file
File diff suppressed because it is too large
Load Diff
183
src/A_memorix/core/storage/knowledge_types.py
Normal file
183
src/A_memorix/core/storage/knowledge_types.py
Normal file
@@ -0,0 +1,183 @@
|
||||
"""Knowledge type and import strategy helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
class KnowledgeType(str, Enum):
|
||||
"""持久化到 paragraphs.knowledge_type 的合法类型。"""
|
||||
|
||||
STRUCTURED = "structured"
|
||||
NARRATIVE = "narrative"
|
||||
FACTUAL = "factual"
|
||||
QUOTE = "quote"
|
||||
MIXED = "mixed"
|
||||
|
||||
|
||||
class ImportStrategy(str, Enum):
|
||||
"""文本导入阶段的策略选择。"""
|
||||
|
||||
AUTO = "auto"
|
||||
NARRATIVE = "narrative"
|
||||
FACTUAL = "factual"
|
||||
QUOTE = "quote"
|
||||
|
||||
|
||||
def allowed_knowledge_type_values() -> tuple[str, ...]:
|
||||
return tuple(item.value for item in KnowledgeType)
|
||||
|
||||
|
||||
def allowed_import_strategy_values() -> tuple[str, ...]:
|
||||
return tuple(item.value for item in ImportStrategy)
|
||||
|
||||
|
||||
def get_knowledge_type_from_string(type_str: Any) -> Optional[KnowledgeType]:
|
||||
"""从字符串解析合法的落库知识类型。"""
|
||||
|
||||
if not isinstance(type_str, str):
|
||||
return None
|
||||
normalized = type_str.lower().strip()
|
||||
try:
|
||||
return KnowledgeType(normalized)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
def get_import_strategy_from_string(value: Any) -> Optional[ImportStrategy]:
|
||||
"""从字符串解析文本导入策略。"""
|
||||
|
||||
if not isinstance(value, str):
|
||||
return None
|
||||
normalized = value.lower().strip()
|
||||
try:
|
||||
return ImportStrategy(normalized)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
def parse_import_strategy(value: Any, default: ImportStrategy = ImportStrategy.AUTO) -> ImportStrategy:
|
||||
"""解析 import strategy;非法值直接报错。"""
|
||||
|
||||
if value is None:
|
||||
return default
|
||||
if isinstance(value, ImportStrategy):
|
||||
return value
|
||||
|
||||
normalized = str(value or "").strip().lower()
|
||||
if not normalized:
|
||||
return default
|
||||
|
||||
strategy = get_import_strategy_from_string(normalized)
|
||||
if strategy is None:
|
||||
allowed = "/".join(allowed_import_strategy_values())
|
||||
raise ValueError(f"strategy_override 必须为 {allowed}")
|
||||
return strategy
|
||||
|
||||
|
||||
def validate_stored_knowledge_type(value: Any) -> KnowledgeType:
|
||||
"""校验写库 knowledge_type,仅允许合法落库类型。"""
|
||||
|
||||
if isinstance(value, KnowledgeType):
|
||||
return value
|
||||
|
||||
resolved = get_knowledge_type_from_string(value)
|
||||
if resolved is None:
|
||||
allowed = "/".join(allowed_knowledge_type_values())
|
||||
raise ValueError(f"knowledge_type 必须为 {allowed}")
|
||||
return resolved
|
||||
|
||||
|
||||
def resolve_stored_knowledge_type(
|
||||
value: Any,
|
||||
*,
|
||||
content: str = "",
|
||||
allow_legacy: bool = False,
|
||||
unknown_fallback: Optional[KnowledgeType] = None,
|
||||
) -> KnowledgeType:
|
||||
"""
|
||||
将策略/字符串/旧值解析为合法落库类型。
|
||||
|
||||
`allow_legacy=True` 仅供迁移使用。
|
||||
"""
|
||||
|
||||
if isinstance(value, KnowledgeType):
|
||||
return value
|
||||
|
||||
if isinstance(value, ImportStrategy):
|
||||
if value == ImportStrategy.AUTO:
|
||||
if not str(content or "").strip():
|
||||
raise ValueError("knowledge_type=auto 需要 content 才能推断")
|
||||
from .type_detection import detect_knowledge_type
|
||||
|
||||
return detect_knowledge_type(content)
|
||||
return KnowledgeType(value.value)
|
||||
|
||||
raw = str(value or "").strip()
|
||||
if not raw:
|
||||
if str(content or "").strip():
|
||||
from .type_detection import detect_knowledge_type
|
||||
|
||||
return detect_knowledge_type(content)
|
||||
raise ValueError("knowledge_type 不能为空")
|
||||
|
||||
direct = get_knowledge_type_from_string(raw)
|
||||
if direct is not None:
|
||||
return direct
|
||||
|
||||
strategy = get_import_strategy_from_string(raw)
|
||||
if strategy is not None:
|
||||
return resolve_stored_knowledge_type(strategy, content=content)
|
||||
|
||||
if allow_legacy:
|
||||
normalized = raw.lower()
|
||||
if normalized == "imported":
|
||||
return KnowledgeType.FACTUAL
|
||||
if str(content or "").strip():
|
||||
from .type_detection import detect_knowledge_type
|
||||
|
||||
detected = detect_knowledge_type(content)
|
||||
if detected is not None:
|
||||
return detected
|
||||
if unknown_fallback is not None:
|
||||
return unknown_fallback
|
||||
|
||||
allowed = "/".join(allowed_knowledge_type_values())
|
||||
raise ValueError(f"非法 knowledge_type: {raw}(仅允许 {allowed})")
|
||||
|
||||
|
||||
def should_extract_relations(knowledge_type: KnowledgeType) -> bool:
|
||||
"""判断是否应该做关系抽取。"""
|
||||
|
||||
return knowledge_type in [
|
||||
KnowledgeType.STRUCTURED,
|
||||
KnowledgeType.FACTUAL,
|
||||
KnowledgeType.MIXED,
|
||||
]
|
||||
|
||||
|
||||
def get_default_chunk_size(knowledge_type: KnowledgeType) -> int:
|
||||
"""获取默认分块大小。"""
|
||||
|
||||
chunk_sizes = {
|
||||
KnowledgeType.STRUCTURED: 300,
|
||||
KnowledgeType.NARRATIVE: 800,
|
||||
KnowledgeType.FACTUAL: 500,
|
||||
KnowledgeType.QUOTE: 400,
|
||||
KnowledgeType.MIXED: 500,
|
||||
}
|
||||
return chunk_sizes.get(knowledge_type, 500)
|
||||
|
||||
|
||||
def get_type_display_name(knowledge_type: KnowledgeType) -> str:
|
||||
"""获取知识类型中文名称。"""
|
||||
|
||||
display_names = {
|
||||
KnowledgeType.STRUCTURED: "结构化知识",
|
||||
KnowledgeType.NARRATIVE: "叙事性文本",
|
||||
KnowledgeType.FACTUAL: "事实陈述",
|
||||
KnowledgeType.QUOTE: "引用文本",
|
||||
KnowledgeType.MIXED: "混合类型",
|
||||
}
|
||||
return display_names.get(knowledge_type, "未知类型")
|
||||
5959
src/A_memorix/core/storage/metadata_store.py
Normal file
5959
src/A_memorix/core/storage/metadata_store.py
Normal file
File diff suppressed because it is too large
Load Diff
137
src/A_memorix/core/storage/type_detection.py
Normal file
137
src/A_memorix/core/storage/type_detection.py
Normal file
@@ -0,0 +1,137 @@
|
||||
"""Heuristic detection for import strategies and stored knowledge types."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
from .knowledge_types import (
|
||||
ImportStrategy,
|
||||
KnowledgeType,
|
||||
parse_import_strategy,
|
||||
resolve_stored_knowledge_type,
|
||||
)
|
||||
|
||||
|
||||
_NARRATIVE_MARKERS = [
|
||||
r"然后",
|
||||
r"接着",
|
||||
r"于是",
|
||||
r"后来",
|
||||
r"最后",
|
||||
r"突然",
|
||||
r"一天",
|
||||
r"曾经",
|
||||
r"有一次",
|
||||
r"从前",
|
||||
r"说道",
|
||||
r"问道",
|
||||
r"想着",
|
||||
r"觉得",
|
||||
]
|
||||
_FACTUAL_MARKERS = [
|
||||
r"是",
|
||||
r"有",
|
||||
r"在",
|
||||
r"为",
|
||||
r"属于",
|
||||
r"位于",
|
||||
r"包含",
|
||||
r"拥有",
|
||||
r"成立于",
|
||||
r"出生于",
|
||||
]
|
||||
|
||||
|
||||
def _non_empty_lines(content: str) -> list[str]:
|
||||
return [line for line in str(content or "").splitlines() if line.strip()]
|
||||
|
||||
|
||||
def looks_like_structured_text(content: str) -> bool:
|
||||
text = str(content or "").strip()
|
||||
if "|" not in text or text.count("|") < 2:
|
||||
return False
|
||||
parts = text.split("|")
|
||||
return len(parts) == 3 and all(part.strip() for part in parts)
|
||||
|
||||
|
||||
def looks_like_quote_text(content: str) -> bool:
|
||||
lines = _non_empty_lines(content)
|
||||
if len(lines) < 5:
|
||||
return False
|
||||
avg_len = sum(len(line) for line in lines) / len(lines)
|
||||
return avg_len < 20
|
||||
|
||||
|
||||
def looks_like_narrative_text(content: str) -> bool:
|
||||
text = str(content or "").strip()
|
||||
if not text:
|
||||
return False
|
||||
|
||||
narrative_score = sum(1 for marker in _NARRATIVE_MARKERS if re.search(marker, text))
|
||||
has_dialogue = bool(re.search(r'["「『].*?["」』]', text))
|
||||
has_chapter = any(token in text[:500] for token in ("Chapter", "CHAPTER", "###"))
|
||||
return has_chapter or has_dialogue or narrative_score >= 2
|
||||
|
||||
|
||||
def looks_like_factual_text(content: str) -> bool:
|
||||
text = str(content or "").strip()
|
||||
if not text:
|
||||
return False
|
||||
if looks_like_structured_text(text) or looks_like_quote_text(text):
|
||||
return False
|
||||
|
||||
factual_score = sum(1 for marker in _FACTUAL_MARKERS if re.search(r"\s*" + marker + r"\s*", text))
|
||||
if factual_score <= 0:
|
||||
return False
|
||||
|
||||
if len(text) <= 240:
|
||||
return True
|
||||
return factual_score >= 2 and not looks_like_narrative_text(text)
|
||||
|
||||
|
||||
def select_import_strategy(
|
||||
content: str,
|
||||
*,
|
||||
override: Optional[str | ImportStrategy] = None,
|
||||
chat_log: bool = False,
|
||||
) -> ImportStrategy:
|
||||
"""文本导入策略选择:override > quote > factual > narrative。"""
|
||||
|
||||
if chat_log:
|
||||
return ImportStrategy.NARRATIVE
|
||||
|
||||
strategy = parse_import_strategy(override, default=ImportStrategy.AUTO)
|
||||
if strategy != ImportStrategy.AUTO:
|
||||
return strategy
|
||||
|
||||
if looks_like_quote_text(content):
|
||||
return ImportStrategy.QUOTE
|
||||
if looks_like_factual_text(content):
|
||||
return ImportStrategy.FACTUAL
|
||||
return ImportStrategy.NARRATIVE
|
||||
|
||||
|
||||
def detect_knowledge_type(content: str) -> KnowledgeType:
|
||||
"""自动检测落库 knowledge_type;无法可靠判断时回退 mixed。"""
|
||||
|
||||
text = str(content or "").strip()
|
||||
if not text:
|
||||
return KnowledgeType.MIXED
|
||||
if looks_like_structured_text(text):
|
||||
return KnowledgeType.STRUCTURED
|
||||
if looks_like_quote_text(text):
|
||||
return KnowledgeType.QUOTE
|
||||
if looks_like_factual_text(text):
|
||||
return KnowledgeType.FACTUAL
|
||||
if looks_like_narrative_text(text):
|
||||
return KnowledgeType.NARRATIVE
|
||||
return KnowledgeType.MIXED
|
||||
|
||||
|
||||
def get_type_from_user_input(type_hint: Optional[str], content: str) -> KnowledgeType:
|
||||
"""优先使用显式 type_hint,否则自动检测。"""
|
||||
|
||||
if type_hint:
|
||||
return resolve_stored_knowledge_type(type_hint, content=content)
|
||||
return detect_knowledge_type(content)
|
||||
776
src/A_memorix/core/storage/vector_store.py
Normal file
776
src/A_memorix/core/storage/vector_store.py
Normal file
@@ -0,0 +1,776 @@
|
||||
"""
|
||||
向量存储模块
|
||||
|
||||
基于Faiss的高效向量存储与检索,支持SQ8量化、Append-Only磁盘存储和内存映射。
|
||||
"""
|
||||
|
||||
import os
|
||||
import pickle
|
||||
import hashlib
|
||||
import shutil
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union, Tuple, List, Dict, Set, Any
|
||||
import random
|
||||
import threading # Added threading import
|
||||
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
import faiss
|
||||
HAS_FAISS = True
|
||||
except ImportError:
|
||||
HAS_FAISS = False
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from ..utils.quantization import QuantizationType
|
||||
from ..utils.io import atomic_write, atomic_save_path
|
||||
|
||||
logger = get_logger("A_Memorix.VectorStore")
|
||||
|
||||
|
||||
class VectorStore:
|
||||
"""
|
||||
向量存储类 (SQ8 + Append-Only Disk)
|
||||
|
||||
特性:
|
||||
- 索引: IndexIDMap2(IndexScalarQuantizer(QT_8bit))
|
||||
- 存储: float16 on-disk binary (vectors.bin)
|
||||
- 内存: 仅索引常驻 RAM (<512MB for 100k vectors)
|
||||
- ID: SHA1-based stable int64 IDs
|
||||
- 一致性: 强制 L2 Normalization (IP == Cosine)
|
||||
"""
|
||||
|
||||
# 默认训练触发阈值 (40 样本,过大可能导致小数据集不生效,过小可能量化退化)
|
||||
DEFAULT_MIN_TRAIN = 40
|
||||
# 强制训练样本量
|
||||
TRAIN_SIZE = 10000
|
||||
# 储水池采样上限 (流式处理前 50k 数据)
|
||||
RESERVOIR_CAPACITY = 10000
|
||||
RESERVOIR_SAMPLE_SCOPE = 50000
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dimension: int,
|
||||
quantization_type: QuantizationType = QuantizationType.INT8,
|
||||
index_type: str = "sq8",
|
||||
data_dir: Optional[Union[str, Path]] = None,
|
||||
use_mmap: bool = True,
|
||||
buffer_size: int = 1024,
|
||||
):
|
||||
if not HAS_FAISS:
|
||||
raise ImportError("Faiss 未安装,请安装: pip install faiss-cpu")
|
||||
|
||||
self.dimension = dimension
|
||||
self.data_dir = Path(data_dir) if data_dir else None
|
||||
if self.data_dir:
|
||||
self.data_dir.mkdir(parents=True, exist_ok=True)
|
||||
if quantization_type != QuantizationType.INT8:
|
||||
raise ValueError(
|
||||
"vNext 仅支持 quantization_type=int8(SQ8)。"
|
||||
" 请更新配置并执行 scripts/release_vnext_migrate.py migrate。"
|
||||
)
|
||||
normalized_index_type = str(index_type or "sq8").strip().lower()
|
||||
if normalized_index_type not in {"sq8", "int8"}:
|
||||
raise ValueError(
|
||||
"vNext 仅支持 index_type=sq8。"
|
||||
" 请更新配置并执行 scripts/release_vnext_migrate.py migrate。"
|
||||
)
|
||||
self.quantization_type = QuantizationType.INT8
|
||||
self.index_type = "sq8"
|
||||
self.buffer_size = buffer_size
|
||||
|
||||
self._index: Optional[faiss.IndexIDMap2] = None
|
||||
self._init_index()
|
||||
|
||||
self._is_trained = False
|
||||
self._vector_norm = "l2"
|
||||
|
||||
# Fallback Index (Flat) - 用于在 SQ8 训练完成前提供检索能力
|
||||
# 必须使用 IndexIDMap2 以保证 ID 与主索引一致
|
||||
self._fallback_index: Optional[faiss.IndexIDMap2] = None
|
||||
self._init_fallback_index()
|
||||
|
||||
self._known_hashes: Set[str] = set()
|
||||
self._deleted_ids: Set[int] = set()
|
||||
|
||||
self._reservoir_buffer: List[np.ndarray] = []
|
||||
self._seen_count_for_reservoir = 0
|
||||
|
||||
self._write_buffer_vecs: List[np.ndarray] = []
|
||||
self._write_buffer_ids: List[int] = []
|
||||
|
||||
self._total_added = 0
|
||||
self._total_deleted = 0
|
||||
self._bin_count = 0
|
||||
|
||||
# Thread safety lock
|
||||
self._lock = threading.RLock()
|
||||
|
||||
logger.info(f"VectorStore Init: dim={dimension}, SQ8 Mode, Append-Only Storage")
|
||||
|
||||
def _init_index(self):
|
||||
"""初始化空的 Faiss 索引"""
|
||||
quantizer = faiss.IndexScalarQuantizer(
|
||||
self.dimension,
|
||||
faiss.ScalarQuantizer.QT_8bit,
|
||||
faiss.METRIC_INNER_PRODUCT
|
||||
)
|
||||
self._index = faiss.IndexIDMap2(quantizer)
|
||||
self._is_trained = False
|
||||
|
||||
def _init_fallback_index(self):
|
||||
"""初始化 Flat 回退索引"""
|
||||
flat_index = faiss.IndexFlatIP(self.dimension)
|
||||
self._fallback_index = faiss.IndexIDMap2(flat_index)
|
||||
logger.debug("Fallback index (Flat) initialized.")
|
||||
|
||||
@staticmethod
|
||||
def _generate_id(key: str) -> int:
|
||||
"""生成稳定的 int64 ID (SHA1 截断)"""
|
||||
h = hashlib.sha1(key.encode("utf-8")).digest()
|
||||
val = int.from_bytes(h[:8], byteorder="big", signed=False)
|
||||
return val & 0x7FFFFFFFFFFFFFFF
|
||||
|
||||
@property
|
||||
def _bin_path(self) -> Path:
|
||||
return self.data_dir / "vectors.bin"
|
||||
|
||||
@property
|
||||
def _ids_bin_path(self) -> Path:
|
||||
return self.data_dir / "vectors_ids.bin"
|
||||
|
||||
@property
|
||||
def _int_to_str_map(self) -> Dict[int, str]:
|
||||
"""Lazy build volatile map from known hashes"""
|
||||
# Note: This is read-heavy and cached, might need lock if _known_hashes updates concurrently
|
||||
# But add/delete are now locked, so checking len mismatch is somewhat safe-ish for quick dirty cache
|
||||
if not hasattr(self, "_cached_map") or len(self._cached_map) != len(self._known_hashes):
|
||||
with self._lock: # Protect cache rebuild
|
||||
self._cached_map = {self._generate_id(k): k for k in self._known_hashes}
|
||||
return self._cached_map
|
||||
|
||||
def add(self, vectors: np.ndarray, ids: List[str]) -> int:
|
||||
with self._lock:
|
||||
if vectors.shape[1] != self.dimension:
|
||||
raise ValueError(f"Dimension mismatch: {vectors.shape[1]} vs {self.dimension}")
|
||||
|
||||
vectors = np.ascontiguousarray(vectors, dtype=np.float32)
|
||||
faiss.normalize_L2(vectors)
|
||||
|
||||
processed_vecs = []
|
||||
processed_int_ids = []
|
||||
|
||||
for i, str_id in enumerate(ids):
|
||||
if str_id in self._known_hashes:
|
||||
continue
|
||||
|
||||
int_id = self._generate_id(str_id)
|
||||
self._known_hashes.add(str_id)
|
||||
|
||||
processed_vecs.append(vectors[i])
|
||||
processed_int_ids.append(int_id)
|
||||
|
||||
if not processed_vecs:
|
||||
return 0
|
||||
|
||||
batch_vecs = np.array(processed_vecs, dtype=np.float32)
|
||||
batch_ids = np.array(processed_int_ids, dtype=np.int64)
|
||||
|
||||
self._write_buffer_vecs.append(batch_vecs)
|
||||
self._write_buffer_ids.extend(processed_int_ids)
|
||||
|
||||
if len(self._write_buffer_ids) >= self.buffer_size:
|
||||
self._flush_write_buffer_unlocked()
|
||||
|
||||
if not self._is_trained:
|
||||
# 双写到回退索引
|
||||
self._fallback_index.add_with_ids(batch_vecs, batch_ids)
|
||||
|
||||
self._update_reservoir(batch_vecs)
|
||||
# 这里的 TRAIN_SIZE 取默认 10k,或者根据当前数据量动态判断
|
||||
if len(self._reservoir_buffer) >= 10000:
|
||||
logger.info(f"训练样本达到上限,开始训练...")
|
||||
self._train_and_replay_unlocked()
|
||||
|
||||
self._total_added += len(batch_ids)
|
||||
return len(batch_ids)
|
||||
|
||||
def _flush_write_buffer(self):
|
||||
with self._lock:
|
||||
self._flush_write_buffer_unlocked()
|
||||
|
||||
def _flush_write_buffer_unlocked(self):
|
||||
if not self._write_buffer_vecs:
|
||||
return
|
||||
|
||||
batch_vecs = np.concatenate(self._write_buffer_vecs, axis=0)
|
||||
batch_ids = np.array(self._write_buffer_ids, dtype=np.int64)
|
||||
|
||||
vecs_fp16 = batch_vecs.astype(np.float16)
|
||||
|
||||
with open(self._bin_path, "ab") as f:
|
||||
f.write(vecs_fp16.tobytes())
|
||||
|
||||
ids_bytes = batch_ids.astype('>i8').tobytes()
|
||||
with open(self._ids_bin_path, "ab") as f:
|
||||
f.write(ids_bytes)
|
||||
|
||||
self._bin_count += len(batch_ids)
|
||||
|
||||
if self._is_trained and self._index.is_trained:
|
||||
self._index.add_with_ids(batch_vecs, batch_ids)
|
||||
else:
|
||||
# 即使在 flush 时,如果未训练,也要同步到 fallback
|
||||
self._fallback_index.add_with_ids(batch_vecs, batch_ids)
|
||||
|
||||
self._write_buffer_vecs.clear()
|
||||
self._write_buffer_ids.clear()
|
||||
|
||||
def _update_reservoir(self, vectors: np.ndarray):
|
||||
for vec in vectors:
|
||||
self._seen_count_for_reservoir += 1
|
||||
if len(self._reservoir_buffer) < self.RESERVOIR_CAPACITY:
|
||||
self._reservoir_buffer.append(vec)
|
||||
else:
|
||||
if self._seen_count_for_reservoir <= self.RESERVOIR_SAMPLE_SCOPE:
|
||||
r = random.randint(0, self._seen_count_for_reservoir - 1)
|
||||
if r < self.RESERVOIR_CAPACITY:
|
||||
self._reservoir_buffer[r] = vec
|
||||
|
||||
def _train_and_replay(self):
|
||||
with self._lock:
|
||||
self._train_and_replay_unlocked()
|
||||
|
||||
def _train_and_replay_unlocked(self):
|
||||
if not self._reservoir_buffer:
|
||||
logger.warning("No training data available.")
|
||||
return
|
||||
|
||||
train_data = np.array(self._reservoir_buffer, dtype=np.float32)
|
||||
logger.info(f"Training Index with {len(train_data)} samples...")
|
||||
|
||||
try:
|
||||
self._index.train(train_data)
|
||||
except Exception as e:
|
||||
logger.error(f"SQ8 Training failed: {e}. Staying in fallback mode.")
|
||||
return
|
||||
|
||||
self._is_trained = True
|
||||
self._reservoir_buffer = []
|
||||
|
||||
logger.info("Replaying data from disk to populate index...")
|
||||
try:
|
||||
replay_count = self._replay_vectors_to_index()
|
||||
# 只有当 replay 成功且数据量一致时,才释放回退索引
|
||||
if self._index.ntotal >= self._bin_count:
|
||||
logger.info(f"Replay successful ({self._index.ntotal}/{self._bin_count}). Releasing fallback index.")
|
||||
self._fallback_index.reset()
|
||||
else:
|
||||
logger.warning(f"Replay count mismatch: {self._index.ntotal} vs {self._bin_count}. Keeping fallback index.")
|
||||
except Exception as e:
|
||||
logger.error(f"Replay failed: {e}. Keeping fallback index as backup.")
|
||||
|
||||
def _replay_vectors_to_index(self) -> int:
|
||||
"""从 vectors.bin 读取并添加到 index"""
|
||||
if not self._bin_path.exists() or not self._ids_bin_path.exists():
|
||||
return 0
|
||||
|
||||
vec_item_size = self.dimension * 2
|
||||
id_item_size = 8
|
||||
chunk_size = 10000
|
||||
|
||||
with open(self._bin_path, "rb") as f_vec, open(self._ids_bin_path, "rb") as f_id:
|
||||
while True:
|
||||
vec_data = f_vec.read(chunk_size * vec_item_size)
|
||||
id_data = f_id.read(chunk_size * id_item_size)
|
||||
|
||||
if not vec_data:
|
||||
break
|
||||
|
||||
batch_fp16 = np.frombuffer(vec_data, dtype=np.float16).reshape(-1, self.dimension)
|
||||
batch_fp32 = batch_fp16.astype(np.float32)
|
||||
faiss.normalize_L2(batch_fp32)
|
||||
|
||||
batch_ids = np.frombuffer(id_data, dtype='>i8').astype(np.int64)
|
||||
|
||||
valid_mask = [id_ not in self._deleted_ids for id_ in batch_ids]
|
||||
if not all(valid_mask):
|
||||
batch_fp32 = batch_fp32[valid_mask]
|
||||
batch_ids = batch_ids[valid_mask]
|
||||
|
||||
if len(batch_ids) > 0:
|
||||
self._index.add_with_ids(batch_fp32, batch_ids)
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: np.ndarray,
|
||||
k: int = 10,
|
||||
filter_deleted: bool = True,
|
||||
) -> Tuple[List[str], List[float]]:
|
||||
query_local = np.array(query, dtype=np.float32, order="C", copy=True)
|
||||
if query_local.ndim == 1:
|
||||
got_dim = int(query_local.shape[0])
|
||||
query_local = query_local.reshape(1, -1)
|
||||
elif query_local.ndim == 2:
|
||||
if query_local.shape[0] != 1:
|
||||
raise ValueError(
|
||||
f"query embedding must have shape (D,) or (1, D), got {tuple(query_local.shape)}"
|
||||
)
|
||||
got_dim = int(query_local.shape[1])
|
||||
else:
|
||||
raise ValueError(
|
||||
f"query embedding must have shape (D,) or (1, D), got {tuple(query_local.shape)}"
|
||||
)
|
||||
|
||||
if got_dim != self.dimension:
|
||||
raise ValueError(
|
||||
f"query embedding dimension mismatch: expected={self.dimension} got={got_dim}"
|
||||
)
|
||||
if not np.all(np.isfinite(query_local)):
|
||||
raise ValueError("query embedding contains non-finite values")
|
||||
|
||||
faiss.normalize_L2(query_local)
|
||||
|
||||
# 查询路径仅负责检索,不在此触发训练/回放。
|
||||
# 训练/回放前置到 warmup_index(),并由插件启动阶段触发。
|
||||
# Faiss 索引在并发 search 下可能出现阻塞,这里串行化检索调用保证稳定性。
|
||||
with self._lock:
|
||||
self._flush_write_buffer_unlocked()
|
||||
search_index = self._index if (self._is_trained and self._index.ntotal > 0) else self._fallback_index
|
||||
if search_index.ntotal == 0:
|
||||
logger.warning("Indices are empty. No data to search.")
|
||||
return [], []
|
||||
# 执行检索
|
||||
dists, ids = search_index.search(query_local, k * 2)
|
||||
|
||||
# Faiss search 返回的是 (1, K) 的数组,取第一行
|
||||
dists = dists[0]
|
||||
ids = ids[0]
|
||||
|
||||
results = []
|
||||
for id_val, score in zip(ids, dists):
|
||||
if id_val == -1: continue
|
||||
if filter_deleted and id_val in self._deleted_ids:
|
||||
continue
|
||||
|
||||
str_id = self._int_to_str_map.get(id_val)
|
||||
if str_id:
|
||||
results.append((str_id, float(score)))
|
||||
|
||||
# Sort and trim just in case filtering reduced count
|
||||
results.sort(key=lambda x: x[1], reverse=True)
|
||||
results = results[:k]
|
||||
|
||||
if not results:
|
||||
return [], []
|
||||
|
||||
return [r[0] for r in results], [r[1] for r in results]
|
||||
|
||||
def warmup_index(self, force_train: bool = True) -> Dict[str, Any]:
|
||||
"""
|
||||
预热向量索引(训练/回放前置),避免首个线上查询触发重初始化。
|
||||
|
||||
Args:
|
||||
force_train: 是否在满足阈值时强制训练 SQ8 索引
|
||||
|
||||
Returns:
|
||||
预热状态摘要
|
||||
"""
|
||||
started = time.perf_counter()
|
||||
logger.info(f"metric.vector_index_prewarm_started=1 force_train={bool(force_train)}")
|
||||
|
||||
try:
|
||||
with self._lock:
|
||||
self._flush_write_buffer()
|
||||
|
||||
if self._bin_path.exists():
|
||||
self._bin_count = self._bin_path.stat().st_size // (self.dimension * 2)
|
||||
else:
|
||||
self._bin_count = 0
|
||||
|
||||
needs_fallback_bootstrap = (
|
||||
self._bin_count > 0
|
||||
and self._fallback_index.ntotal == 0
|
||||
and (not self._is_trained or self._index.ntotal == 0)
|
||||
)
|
||||
if needs_fallback_bootstrap:
|
||||
self._bootstrap_fallback_from_disk()
|
||||
|
||||
min_train = max(1, int(getattr(self, "min_train_threshold", self.DEFAULT_MIN_TRAIN)))
|
||||
needs_train = (
|
||||
bool(force_train)
|
||||
and self._bin_count >= min_train
|
||||
and not self._is_trained
|
||||
)
|
||||
if needs_train:
|
||||
self._force_train_small_data()
|
||||
|
||||
duration_ms = (time.perf_counter() - started) * 1000.0
|
||||
summary = {
|
||||
"ok": True,
|
||||
"trained": bool(self._is_trained),
|
||||
"index_ntotal": int(self._index.ntotal),
|
||||
"fallback_ntotal": int(self._fallback_index.ntotal),
|
||||
"bin_count": int(self._bin_count),
|
||||
"duration_ms": duration_ms,
|
||||
"error": None,
|
||||
}
|
||||
except Exception as e:
|
||||
duration_ms = (time.perf_counter() - started) * 1000.0
|
||||
summary = {
|
||||
"ok": False,
|
||||
"trained": bool(self._is_trained),
|
||||
"index_ntotal": int(self._index.ntotal) if self._index is not None else 0,
|
||||
"fallback_ntotal": int(self._fallback_index.ntotal) if self._fallback_index is not None else 0,
|
||||
"bin_count": int(getattr(self, "_bin_count", 0)),
|
||||
"duration_ms": duration_ms,
|
||||
"error": str(e),
|
||||
}
|
||||
logger.error(
|
||||
"metric.vector_index_prewarm_fail=1 "
|
||||
f"metric.vector_index_prewarm_duration_ms={duration_ms:.2f} "
|
||||
f"error={e}"
|
||||
)
|
||||
return summary
|
||||
|
||||
logger.info(
|
||||
"metric.vector_index_prewarm_success=1 "
|
||||
f"metric.vector_index_prewarm_duration_ms={summary['duration_ms']:.2f} "
|
||||
f"trained={summary['trained']} "
|
||||
f"index_ntotal={summary['index_ntotal']} "
|
||||
f"fallback_ntotal={summary['fallback_ntotal']} "
|
||||
f"bin_count={summary['bin_count']}"
|
||||
)
|
||||
return summary
|
||||
|
||||
def _bootstrap_fallback_from_disk(self):
|
||||
with self._lock:
|
||||
self._bootstrap_fallback_from_disk_unlocked()
|
||||
|
||||
def _bootstrap_fallback_from_disk_unlocked(self):
|
||||
"""重启后自举:从磁盘 vectors.bin 加载数据到 fallback 索引"""
|
||||
if not self._bin_path.exists() or not self._ids_bin_path.exists():
|
||||
return
|
||||
|
||||
logger.info("Replaying all disk vectors to fallback index...")
|
||||
vec_item_size = self.dimension * 2
|
||||
id_item_size = 8
|
||||
chunk_size = 10000
|
||||
|
||||
with open(self._bin_path, "rb") as f_vec, open(self._ids_bin_path, "rb") as f_id:
|
||||
while True:
|
||||
vec_data = f_vec.read(chunk_size * vec_item_size)
|
||||
id_data = f_id.read(chunk_size * id_item_size)
|
||||
if not vec_data: break
|
||||
|
||||
batch_fp16 = np.frombuffer(vec_data, dtype=np.float16).reshape(-1, self.dimension)
|
||||
batch_fp32 = batch_fp16.astype(np.float32)
|
||||
faiss.normalize_L2(batch_fp32)
|
||||
batch_ids = np.frombuffer(id_data, dtype='>i8').astype(np.int64)
|
||||
|
||||
valid_mask = [id_ not in self._deleted_ids for id_ in batch_ids]
|
||||
if any(valid_mask):
|
||||
self._fallback_index.add_with_ids(batch_fp32[valid_mask], batch_ids[valid_mask])
|
||||
|
||||
logger.info(f"Fallback index self-bootstrapped with {self._fallback_index.ntotal} items.")
|
||||
|
||||
def _force_train_small_data(self):
|
||||
with self._lock:
|
||||
self._force_train_small_data_unlocked()
|
||||
|
||||
def _force_train_small_data_unlocked(self):
|
||||
logger.info("Forcing training on small dataset...")
|
||||
self._reservoir_buffer = []
|
||||
|
||||
chunk_size = 10000
|
||||
vec_item_size = self.dimension * 2
|
||||
|
||||
with open(self._bin_path, "rb") as f:
|
||||
while len(self._reservoir_buffer) < self.TRAIN_SIZE:
|
||||
data = f.read(chunk_size * vec_item_size)
|
||||
if not data: break
|
||||
fp16 = np.frombuffer(data, dtype=np.float16).reshape(-1, self.dimension)
|
||||
fp32 = fp16.astype(np.float32)
|
||||
faiss.normalize_L2(fp32)
|
||||
|
||||
for vec in fp32:
|
||||
self._reservoir_buffer.append(vec)
|
||||
if len(self._reservoir_buffer) >= self.TRAIN_SIZE:
|
||||
break
|
||||
|
||||
self._train_and_replay_unlocked()
|
||||
|
||||
def delete(self, ids: List[str]) -> int:
|
||||
with self._lock:
|
||||
count = 0
|
||||
for str_id in ids:
|
||||
if str_id not in self._known_hashes:
|
||||
continue
|
||||
int_id = self._generate_id(str_id)
|
||||
if int_id not in self._deleted_ids:
|
||||
self._deleted_ids.add(int_id)
|
||||
if self._index.is_trained:
|
||||
self._index.remove_ids(np.array([int_id], dtype=np.int64))
|
||||
# 同步从 fallback 移除
|
||||
if self._fallback_index.ntotal > 0:
|
||||
self._fallback_index.remove_ids(np.array([int_id], dtype=np.int64))
|
||||
count += 1
|
||||
self._total_deleted += count
|
||||
|
||||
# Check GC
|
||||
self._check_rebuild_needed()
|
||||
return count
|
||||
|
||||
def _check_rebuild_needed(self):
|
||||
"""GC Excution Check"""
|
||||
if self._bin_count == 0: return
|
||||
ratio = len(self._deleted_ids) / self._bin_count
|
||||
if ratio > 0.3 and len(self._deleted_ids) > 1000:
|
||||
logger.info(f"Triggering GC/Rebuild (deleted ratio: {ratio:.2f})")
|
||||
self.rebuild_index()
|
||||
|
||||
def rebuild_index(self):
|
||||
"""GC: 重建索引,压缩 bin 文件"""
|
||||
with self._lock:
|
||||
self._rebuild_index_locked()
|
||||
|
||||
def _rebuild_index_locked(self):
|
||||
"""实际 GC 重建逻辑。"""
|
||||
logger.info("Starting Compaction (GC)...")
|
||||
|
||||
tmp_bin = self.data_dir / "vectors.bin.tmp"
|
||||
tmp_ids = self.data_dir / "vectors_ids.bin.tmp"
|
||||
|
||||
vec_item_size = self.dimension * 2
|
||||
id_item_size = 8
|
||||
chunk_size = 10000
|
||||
|
||||
new_count = 0
|
||||
|
||||
# 1. Compact Files
|
||||
with open(self._bin_path, "rb") as f_vec, open(self._ids_bin_path, "rb") as f_id, \
|
||||
open(tmp_bin, "wb") as w_vec, open(tmp_ids, "wb") as w_id:
|
||||
while True:
|
||||
vec_data = f_vec.read(chunk_size * vec_item_size)
|
||||
id_data = f_id.read(chunk_size * id_item_size)
|
||||
if not vec_data: break
|
||||
|
||||
batch_fp16 = np.frombuffer(vec_data, dtype=np.float16).reshape(-1, self.dimension)
|
||||
batch_ids = np.frombuffer(id_data, dtype='>i8').astype(np.int64)
|
||||
|
||||
keep_mask = [id_ not in self._deleted_ids for id_ in batch_ids]
|
||||
|
||||
if any(keep_mask):
|
||||
keep_vecs = batch_fp16[keep_mask]
|
||||
keep_ids = batch_ids[keep_mask]
|
||||
|
||||
w_vec.write(keep_vecs.tobytes())
|
||||
w_id.write(keep_ids.astype('>i8').tobytes())
|
||||
new_count += len(keep_ids)
|
||||
|
||||
# 2. Reset State & Atomic Swap
|
||||
self._bin_count = new_count
|
||||
|
||||
# Close current index
|
||||
self._index.reset()
|
||||
if self._fallback_index: self._fallback_index.reset() # Also clear fallback
|
||||
self._is_trained = False
|
||||
|
||||
# Swap files
|
||||
shutil.move(str(tmp_bin), str(self._bin_path))
|
||||
shutil.move(str(tmp_ids), str(self._ids_bin_path))
|
||||
|
||||
# Reset Tombstones (Critical)
|
||||
self._deleted_ids.clear()
|
||||
|
||||
# 3. Reload/Rebuild Index (Fresh Train)
|
||||
# We need to re-train because data distribution might have changed significantly after deletion
|
||||
self._init_index()
|
||||
self._init_fallback_index() # Re-init fallback too
|
||||
self._force_train_small_data() # This will train and replay from the NEW compact file
|
||||
|
||||
logger.info("Compaction Complete.")
|
||||
|
||||
def save(self, data_dir: Optional[Union[str, Path]] = None) -> None:
|
||||
with self._lock:
|
||||
if not data_dir:
|
||||
data_dir = self.data_dir
|
||||
if not data_dir:
|
||||
raise ValueError("No data_dir")
|
||||
|
||||
data_dir = Path(data_dir)
|
||||
data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self._flush_write_buffer_unlocked()
|
||||
|
||||
if self._is_trained:
|
||||
index_path = data_dir / "vectors.index"
|
||||
with atomic_save_path(index_path) as tmp:
|
||||
faiss.write_index(self._index, tmp)
|
||||
|
||||
meta = {
|
||||
"dimension": self.dimension,
|
||||
"quantization_type": self.quantization_type.value,
|
||||
"is_trained": self._is_trained,
|
||||
"vector_norm": self._vector_norm,
|
||||
"deleted_ids": list(self._deleted_ids),
|
||||
"known_hashes": list(self._known_hashes),
|
||||
}
|
||||
|
||||
with atomic_write(data_dir / "vectors_metadata.pkl", "wb") as f:
|
||||
pickle.dump(meta, f)
|
||||
|
||||
logger.info("VectorStore saved.")
|
||||
|
||||
def migrate_legacy_npy(self, data_dir: Optional[Union[str, Path]] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
离线迁移入口:将 legacy vectors.npy 转为 vNext 二进制格式。
|
||||
"""
|
||||
with self._lock:
|
||||
target_dir = Path(data_dir) if data_dir else self.data_dir
|
||||
if target_dir is None:
|
||||
raise ValueError("No data_dir")
|
||||
target_dir = Path(target_dir)
|
||||
npy_path = target_dir / "vectors.npy"
|
||||
idx_path = target_dir / "vectors.index"
|
||||
bin_path = target_dir / "vectors.bin"
|
||||
ids_bin_path = target_dir / "vectors_ids.bin"
|
||||
meta_path = target_dir / "vectors_metadata.pkl"
|
||||
|
||||
if not npy_path.exists():
|
||||
return {"migrated": False, "reason": "npy_missing"}
|
||||
if not meta_path.exists():
|
||||
raise RuntimeError("legacy vectors.npy migration requires vectors_metadata.pkl")
|
||||
if bin_path.exists() and ids_bin_path.exists():
|
||||
return {"migrated": False, "reason": "bin_exists"}
|
||||
|
||||
# Reset in-memory state to avoid appending to stale runtime buffers.
|
||||
self._known_hashes.clear()
|
||||
self._deleted_ids.clear()
|
||||
self._write_buffer_vecs.clear()
|
||||
self._write_buffer_ids.clear()
|
||||
self._init_index()
|
||||
self._init_fallback_index()
|
||||
self._is_trained = False
|
||||
self._bin_count = 0
|
||||
|
||||
self._migrate_from_npy_unlocked(npy_path, idx_path, target_dir)
|
||||
self.save(target_dir)
|
||||
return {"migrated": True, "reason": "ok"}
|
||||
|
||||
def load(self, data_dir: Optional[Union[str, Path]] = None) -> None:
|
||||
with self._lock:
|
||||
if not data_dir: data_dir = self.data_dir
|
||||
data_dir = Path(data_dir)
|
||||
|
||||
npy_path = data_dir / "vectors.npy"
|
||||
idx_path = data_dir / "vectors.index"
|
||||
bin_path = data_dir / "vectors.bin"
|
||||
|
||||
if npy_path.exists() and not bin_path.exists():
|
||||
raise RuntimeError(
|
||||
"检测到 legacy vectors.npy,vNext 不再支持运行时自动迁移。"
|
||||
" 请先执行 scripts/release_vnext_migrate.py migrate。"
|
||||
)
|
||||
|
||||
meta_path = data_dir / "vectors_metadata.pkl"
|
||||
if not meta_path.exists():
|
||||
logger.warning("No metadata found, initialized empty.")
|
||||
return
|
||||
|
||||
with open(meta_path, "rb") as f:
|
||||
meta = pickle.load(f)
|
||||
|
||||
if meta.get("vector_norm") != "l2":
|
||||
logger.warning("Index IDMap2 version mismatch (L2 Norm), forcing rebuild...")
|
||||
self._known_hashes = set(meta.get("ids", [])) | set(meta.get("known_hashes", []))
|
||||
self._deleted_ids = set(meta.get("deleted_ids", []))
|
||||
self._init_index()
|
||||
self._force_train_small_data()
|
||||
return
|
||||
|
||||
self._is_trained = meta.get("is_trained", False)
|
||||
self._vector_norm = meta.get("vector_norm", "l2")
|
||||
self._deleted_ids = set(meta.get("deleted_ids", []))
|
||||
self._known_hashes = set(meta.get("known_hashes", []))
|
||||
|
||||
if self._is_trained:
|
||||
if idx_path.exists():
|
||||
try:
|
||||
self._index = faiss.read_index(str(idx_path))
|
||||
if not isinstance(self._index, faiss.IndexIDMap2):
|
||||
logger.warning("Loaded index type mismatch. Rebuilding...")
|
||||
self._init_index()
|
||||
self._force_train_small_data()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load index: {e}. Rebuilding...")
|
||||
self._init_index()
|
||||
self._force_train_small_data()
|
||||
else:
|
||||
logger.warning("Index file missing despite metadata indicating trained. Rebuilding from bin...")
|
||||
self._init_index()
|
||||
self._force_train_small_data()
|
||||
|
||||
if bin_path.exists():
|
||||
self._bin_count = bin_path.stat().st_size // (self.dimension * 2)
|
||||
|
||||
def _migrate_from_npy(self, npy_path, idx_path, data_dir):
|
||||
with self._lock:
|
||||
self._migrate_from_npy_unlocked(npy_path, idx_path, data_dir)
|
||||
|
||||
def _migrate_from_npy_unlocked(self, npy_path, idx_path, data_dir):
|
||||
try:
|
||||
arr = np.load(npy_path, mmap_mode="r")
|
||||
except Exception:
|
||||
arr = np.load(npy_path)
|
||||
|
||||
meta_path = data_dir / "vectors_metadata.pkl"
|
||||
old_ids = []
|
||||
if meta_path.exists():
|
||||
with open(meta_path, "rb") as f:
|
||||
m = pickle.load(f)
|
||||
old_ids = m.get("ids", [])
|
||||
|
||||
if len(arr) != len(old_ids):
|
||||
logger.error(f"Migration mismatch: arr {len(arr)} != ids {len(old_ids)}")
|
||||
return
|
||||
|
||||
logger.info(f"Migrating {len(arr)} vectors...")
|
||||
|
||||
chunk = 1000
|
||||
for i in range(0, len(arr), chunk):
|
||||
sub_arr = arr[i : i+chunk]
|
||||
sub_ids = old_ids[i : i+chunk]
|
||||
self.add(sub_arr, sub_ids)
|
||||
|
||||
if not self._is_trained:
|
||||
self._force_train_small_data()
|
||||
|
||||
shutil.move(str(npy_path), str(npy_path) + ".bak")
|
||||
if idx_path.exists():
|
||||
shutil.move(str(idx_path), str(idx_path) + ".bak")
|
||||
|
||||
logger.info("Migration complete.")
|
||||
|
||||
def clear(self) -> None:
|
||||
with self._lock:
|
||||
self._ids_bin_path.unlink(missing_ok=True)
|
||||
self._bin_path.unlink(missing_ok=True)
|
||||
self._init_index()
|
||||
self._known_hashes.clear()
|
||||
self._deleted_ids.clear()
|
||||
self._bin_count = 0
|
||||
logger.info("VectorStore cleared.")
|
||||
|
||||
def has_data(self) -> bool:
|
||||
return (self.data_dir / "vectors_metadata.pkl").exists()
|
||||
|
||||
@property
|
||||
def num_vectors(self) -> int:
|
||||
return len(self._known_hashes) - len(self._deleted_ids)
|
||||
|
||||
def __contains__(self, hash_value: str) -> bool:
|
||||
"""Check if a hash exists in the store"""
|
||||
return hash_value in self._known_hashes and self._generate_id(hash_value) not in self._deleted_ids
|
||||
|
||||
0
src/A_memorix/core/strategies/__init__.py
Normal file
0
src/A_memorix/core/strategies/__init__.py
Normal file
89
src/A_memorix/core/strategies/base.py
Normal file
89
src/A_memorix/core/strategies/base.py
Normal file
@@ -0,0 +1,89 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Dict, Any, Optional, Union
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
import hashlib
|
||||
|
||||
class KnowledgeType(str, Enum):
|
||||
NARRATIVE = "narrative"
|
||||
FACTUAL = "factual"
|
||||
QUOTE = "quote"
|
||||
MIXED = "mixed"
|
||||
|
||||
@dataclass
|
||||
class SourceInfo:
|
||||
file: str
|
||||
offset_start: int
|
||||
offset_end: int
|
||||
checksum: str = ""
|
||||
|
||||
@dataclass
|
||||
class ChunkContext:
|
||||
chunk_id: str
|
||||
index: int
|
||||
context: Dict[str, Any] = field(default_factory=dict)
|
||||
text: str = ""
|
||||
|
||||
@dataclass
|
||||
class ChunkFlags:
|
||||
verbatim: bool = False
|
||||
requires_llm: bool = True
|
||||
|
||||
@dataclass
|
||||
class ProcessedChunk:
|
||||
type: KnowledgeType
|
||||
source: SourceInfo
|
||||
chunk: ChunkContext
|
||||
data: Dict[str, Any] = field(default_factory=dict) # triples、events、verbatim_entities
|
||||
flags: ChunkFlags = field(default_factory=ChunkFlags)
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
return {
|
||||
"type": self.type.value,
|
||||
"source": {
|
||||
"file": self.source.file,
|
||||
"offset_start": self.source.offset_start,
|
||||
"offset_end": self.source.offset_end,
|
||||
"checksum": self.source.checksum
|
||||
},
|
||||
"chunk": {
|
||||
"text": self.chunk.text,
|
||||
"chunk_id": self.chunk.chunk_id,
|
||||
"index": self.chunk.index,
|
||||
"context": self.chunk.context
|
||||
},
|
||||
"data": self.data,
|
||||
"flags": {
|
||||
"verbatim": self.flags.verbatim,
|
||||
"requires_llm": self.flags.requires_llm
|
||||
}
|
||||
}
|
||||
|
||||
class BaseStrategy(ABC):
|
||||
def __init__(self, filename: str):
|
||||
self.filename = filename
|
||||
|
||||
@abstractmethod
|
||||
def split(self, text: str) -> List[ProcessedChunk]:
|
||||
"""按策略将文本切分为块。"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def extract(self, chunk: ProcessedChunk, llm_func=None) -> ProcessedChunk:
|
||||
"""从文本块中抽取结构化信息。"""
|
||||
pass
|
||||
|
||||
def calculate_checksum(self, text: str) -> str:
|
||||
return hashlib.sha256(text.encode("utf-8")).hexdigest()
|
||||
|
||||
def build_language_guard(self, text: str) -> str:
|
||||
"""
|
||||
构建统一的输出语言约束。
|
||||
不区分语言类型,仅要求抽取值保持原文语言,不做翻译。
|
||||
"""
|
||||
_ = text # 预留参数,便于后续按需扩展
|
||||
return (
|
||||
"Focus on the original source language. Keep extracted events, entities, predicates "
|
||||
"and objects in the same language as the source text, preserve names/terms as-is, "
|
||||
"and do not translate."
|
||||
)
|
||||
98
src/A_memorix/core/strategies/factual.py
Normal file
98
src/A_memorix/core/strategies/factual.py
Normal file
@@ -0,0 +1,98 @@
|
||||
import re
|
||||
from typing import List, Dict, Any
|
||||
from .base import BaseStrategy, ProcessedChunk, KnowledgeType, SourceInfo, ChunkContext
|
||||
|
||||
class FactualStrategy(BaseStrategy):
|
||||
def split(self, text: str) -> List[ProcessedChunk]:
|
||||
# 结构感知切分
|
||||
lines = text.split('\n')
|
||||
chunks = []
|
||||
current_chunk_lines = []
|
||||
current_len = 0
|
||||
target_size = 600
|
||||
|
||||
for i, line in enumerate(lines):
|
||||
# 判断是否应当切分
|
||||
# 若当前行为列表项/定义/表格行,则尽量不切分
|
||||
is_structure = self._is_structural_line(line)
|
||||
|
||||
current_len += len(line) + 1
|
||||
current_chunk_lines.append(line)
|
||||
|
||||
# 达到目标长度且不在紧凑结构块内时切分(过长时强制切分)
|
||||
if current_len >= target_size and not is_structure:
|
||||
chunks.append(self._create_chunk(current_chunk_lines, len(chunks)))
|
||||
current_chunk_lines = []
|
||||
current_len = 0
|
||||
elif current_len >= target_size * 2: # 超长时强制切分
|
||||
chunks.append(self._create_chunk(current_chunk_lines, len(chunks)))
|
||||
current_chunk_lines = []
|
||||
current_len = 0
|
||||
|
||||
if current_chunk_lines:
|
||||
chunks.append(self._create_chunk(current_chunk_lines, len(chunks)))
|
||||
|
||||
return chunks
|
||||
|
||||
def _is_structural_line(self, line: str) -> bool:
|
||||
line = line.strip()
|
||||
if not line: return False
|
||||
# 列表项
|
||||
if re.match(r'^[\-\*]\s+', line) or re.match(r'^\d+\.\s+', line):
|
||||
return True
|
||||
# 定义项(术语: 定义)
|
||||
if re.match(r'^[^::]+[::].+', line):
|
||||
return True
|
||||
# 表格行(按 markdown 语法假设)
|
||||
if line.startswith('|') and line.endswith('|'):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _create_chunk(self, lines: List[str], index: int) -> ProcessedChunk:
|
||||
text = "\n".join(lines)
|
||||
return ProcessedChunk(
|
||||
type=KnowledgeType.FACTUAL,
|
||||
source=SourceInfo(
|
||||
file=self.filename,
|
||||
offset_start=0, # 简化处理:真实偏移跟踪需要额外状态
|
||||
offset_end=0,
|
||||
checksum=self.calculate_checksum(text)
|
||||
),
|
||||
chunk=ChunkContext(
|
||||
chunk_id=f"{self.filename}_{index}",
|
||||
index=index,
|
||||
text=text
|
||||
)
|
||||
)
|
||||
|
||||
async def extract(self, chunk: ProcessedChunk, llm_func=None) -> ProcessedChunk:
|
||||
if not llm_func:
|
||||
raise ValueError("LLM function required for Factual extraction")
|
||||
|
||||
language_guard = self.build_language_guard(chunk.chunk.text)
|
||||
prompt = f"""You are a factual knowledge extraction engine.
|
||||
Extract factual triples and entities from the text.
|
||||
Preserve lists and definitions accurately.
|
||||
|
||||
Language constraints:
|
||||
- {language_guard}
|
||||
- Preserve original names and domain terms exactly when possible.
|
||||
- JSON keys must stay exactly as: triples, entities, subject, predicate, object.
|
||||
|
||||
Text:
|
||||
{chunk.chunk.text}
|
||||
|
||||
Return ONLY valid JSON:
|
||||
{{
|
||||
"triples": [
|
||||
{{"subject": "Entity", "predicate": "Relationship", "object": "Entity"}}
|
||||
],
|
||||
"entities": ["Entity1", "Entity2"]
|
||||
}}
|
||||
"""
|
||||
result = await llm_func(prompt)
|
||||
|
||||
# 结果保持原样存入 data,后续统一归一化流程会处理
|
||||
# vector_store 侧期望关系字段为 subject/predicate/object 映射形式
|
||||
chunk.data = result
|
||||
return chunk
|
||||
126
src/A_memorix/core/strategies/narrative.py
Normal file
126
src/A_memorix/core/strategies/narrative.py
Normal file
@@ -0,0 +1,126 @@
|
||||
import re
|
||||
from typing import List, Dict, Any
|
||||
from .base import BaseStrategy, ProcessedChunk, KnowledgeType, SourceInfo, ChunkContext
|
||||
|
||||
class NarrativeStrategy(BaseStrategy):
|
||||
def split(self, text: str) -> List[ProcessedChunk]:
|
||||
scenes = self._split_into_scenes(text)
|
||||
chunks = []
|
||||
|
||||
for scene_idx, (scene_text, scene_title) in enumerate(scenes):
|
||||
scene_chunks = self._sliding_window(scene_text, scene_title, scene_idx)
|
||||
chunks.extend(scene_chunks)
|
||||
|
||||
return chunks
|
||||
|
||||
def _split_into_scenes(self, text: str) -> List[tuple[str, str]]:
|
||||
"""按标题或分隔符把文本切分为场景。"""
|
||||
# 简单启发式:按 markdown 标题或特定分隔符切分
|
||||
# 该正则匹配以 #、Chapter 或 *** / === 开头的分隔行
|
||||
# 该正则匹配以 #、Chapter 或 *** / === 开头的分隔行
|
||||
scene_pattern_str = r'^(?:#{1,6}\s+.*|Chapter\s+\d+|^\*{3,}$|^={3,}$)'
|
||||
|
||||
# 保留分隔符,以便识别场景起点
|
||||
parts = re.split(f"({scene_pattern_str})", text, flags=re.MULTILINE)
|
||||
|
||||
scenes = []
|
||||
current_scene_title = "Start"
|
||||
current_scene_content = []
|
||||
|
||||
if parts and parts[0].strip() == "":
|
||||
parts = parts[1:]
|
||||
|
||||
for part in parts:
|
||||
if re.match(scene_pattern_str, part, re.MULTILINE):
|
||||
# 先保存上一段场景
|
||||
if current_scene_content:
|
||||
scenes.append(("".join(current_scene_content), current_scene_title))
|
||||
current_scene_content = []
|
||||
current_scene_title = part.strip()
|
||||
else:
|
||||
current_scene_content.append(part)
|
||||
|
||||
if current_scene_content:
|
||||
scenes.append(("".join(current_scene_content), current_scene_title))
|
||||
|
||||
# 若未识别到场景,则把全文视作单一场景
|
||||
if not scenes:
|
||||
scenes = [(text, "Whole Text")]
|
||||
|
||||
return scenes
|
||||
|
||||
def _sliding_window(self, text: str, scene_id: str, scene_idx: int, window_size=800, overlap=200) -> List[ProcessedChunk]:
|
||||
chunks = []
|
||||
if len(text) <= window_size:
|
||||
chunks.append(self._create_chunk(text, scene_id, scene_idx, 0, 0))
|
||||
return chunks
|
||||
|
||||
stride = window_size - overlap
|
||||
start = 0
|
||||
local_idx = 0
|
||||
while start < len(text):
|
||||
end = min(start + window_size, len(text))
|
||||
chunk_text = text[start:end]
|
||||
|
||||
# 尽量对齐到最近换行,避免生硬截断句子
|
||||
# 仅在未到文本尾部时进行回退
|
||||
if end < len(text):
|
||||
last_newline = chunk_text.rfind('\n')
|
||||
if last_newline > window_size // 2: # 仅在回退距离可接受时启用
|
||||
end = start + last_newline + 1
|
||||
chunk_text = text[start:end]
|
||||
|
||||
chunks.append(self._create_chunk(chunk_text, scene_id, scene_idx, local_idx, start))
|
||||
|
||||
start += len(chunk_text) - overlap if end < len(text) else len(chunk_text)
|
||||
local_idx += 1
|
||||
|
||||
return chunks
|
||||
|
||||
def _create_chunk(self, text: str, scene_id: str, scene_idx: int, local_idx: int, offset: int) -> ProcessedChunk:
|
||||
return ProcessedChunk(
|
||||
type=KnowledgeType.NARRATIVE,
|
||||
source=SourceInfo(
|
||||
file=self.filename,
|
||||
offset_start=offset,
|
||||
offset_end=offset + len(text),
|
||||
checksum=self.calculate_checksum(text)
|
||||
),
|
||||
chunk=ChunkContext(
|
||||
chunk_id=f"{self.filename}_{scene_idx}_{local_idx}",
|
||||
index=local_idx,
|
||||
text=text,
|
||||
context={"scene_id": scene_id}
|
||||
)
|
||||
)
|
||||
|
||||
async def extract(self, chunk: ProcessedChunk, llm_func=None) -> ProcessedChunk:
|
||||
if not llm_func:
|
||||
raise ValueError("LLM function required for Narrative extraction")
|
||||
|
||||
language_guard = self.build_language_guard(chunk.chunk.text)
|
||||
prompt = f"""You are a narrative knowledge extraction engine.
|
||||
Extract key events and character relations from the scene text.
|
||||
|
||||
Language constraints:
|
||||
- {language_guard}
|
||||
- Preserve original names and terms exactly when possible.
|
||||
- JSON keys must stay exactly as: events, relations, subject, predicate, object.
|
||||
|
||||
Scene:
|
||||
{chunk.chunk.context.get('scene_id')}
|
||||
|
||||
Text:
|
||||
{chunk.chunk.text}
|
||||
|
||||
Return ONLY valid JSON:
|
||||
{{
|
||||
"events": ["event description 1", "event description 2"],
|
||||
"relations": [
|
||||
{{"subject": "CharacterA", "predicate": "relation", "object": "CharacterB"}}
|
||||
]
|
||||
}}
|
||||
"""
|
||||
result = await llm_func(prompt)
|
||||
chunk.data = result
|
||||
return chunk
|
||||
52
src/A_memorix/core/strategies/quote.py
Normal file
52
src/A_memorix/core/strategies/quote.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from typing import List, Dict, Any
|
||||
from .base import BaseStrategy, ProcessedChunk, KnowledgeType, SourceInfo, ChunkContext, ChunkFlags
|
||||
|
||||
class QuoteStrategy(BaseStrategy):
|
||||
def split(self, text: str) -> List[ProcessedChunk]:
|
||||
# Split by double newlines (stanzas)
|
||||
stanzas = text.split("\n\n")
|
||||
chunks = []
|
||||
offset = 0
|
||||
|
||||
for idx, stanza in enumerate(stanzas):
|
||||
if not stanza.strip():
|
||||
offset += len(stanza) + 2
|
||||
continue
|
||||
|
||||
chunk = ProcessedChunk(
|
||||
type=KnowledgeType.QUOTE,
|
||||
source=SourceInfo(
|
||||
file=self.filename,
|
||||
offset_start=offset,
|
||||
offset_end=offset + len(stanza),
|
||||
checksum=self.calculate_checksum(stanza)
|
||||
),
|
||||
chunk=ChunkContext(
|
||||
chunk_id=f"{self.filename}_{idx}",
|
||||
index=idx,
|
||||
text=stanza
|
||||
),
|
||||
flags=ChunkFlags(
|
||||
verbatim=True,
|
||||
requires_llm=False # Default to no LLM, but can be overridden
|
||||
)
|
||||
)
|
||||
chunks.append(chunk)
|
||||
offset += len(stanza) + 2 # +2 for \n\n
|
||||
|
||||
return chunks
|
||||
|
||||
async def extract(self, chunk: ProcessedChunk, llm_func=None) -> ProcessedChunk:
|
||||
# For quotes, the text itself is the entity/knowledge
|
||||
# We might use LLM to extract headers/metadata if requested, but core logic is pass-through
|
||||
|
||||
# Treat the whole chunk text as a verbatim entity
|
||||
chunk.data = {
|
||||
"verbatim_entities": [chunk.chunk.text]
|
||||
}
|
||||
|
||||
if llm_func and chunk.flags.requires_llm:
|
||||
# Optional: Extract metadata
|
||||
pass
|
||||
|
||||
return chunk
|
||||
33
src/A_memorix/core/utils/__init__.py
Normal file
33
src/A_memorix/core/utils/__init__.py
Normal file
@@ -0,0 +1,33 @@
|
||||
"""工具模块 - 哈希、监控等辅助功能"""
|
||||
|
||||
from .hash import compute_hash, normalize_text
|
||||
from .monitor import MemoryMonitor
|
||||
from .quantization import quantize_vector, dequantize_vector
|
||||
from .time_parser import (
|
||||
parse_query_datetime_to_timestamp,
|
||||
parse_query_time_range,
|
||||
parse_ingest_datetime_to_timestamp,
|
||||
normalize_time_meta,
|
||||
format_timestamp,
|
||||
)
|
||||
from .relation_write_service import RelationWriteService, RelationWriteResult
|
||||
from .relation_query import RelationQuerySpec, parse_relation_query_spec
|
||||
from .plugin_id_policy import PluginIdPolicy
|
||||
|
||||
__all__ = [
|
||||
"compute_hash",
|
||||
"normalize_text",
|
||||
"MemoryMonitor",
|
||||
"quantize_vector",
|
||||
"dequantize_vector",
|
||||
"parse_query_datetime_to_timestamp",
|
||||
"parse_query_time_range",
|
||||
"parse_ingest_datetime_to_timestamp",
|
||||
"normalize_time_meta",
|
||||
"format_timestamp",
|
||||
"RelationWriteService",
|
||||
"RelationWriteResult",
|
||||
"RelationQuerySpec",
|
||||
"parse_relation_query_spec",
|
||||
"PluginIdPolicy",
|
||||
]
|
||||
360
src/A_memorix/core/utils/aggregate_query_service.py
Normal file
360
src/A_memorix/core/utils/aggregate_query_service.py
Normal file
@@ -0,0 +1,360 @@
|
||||
"""
|
||||
聚合查询服务:
|
||||
- 并发执行 search/time/episode 分支
|
||||
- 统一分支结果结构
|
||||
- 可选混合排序(Weighted RRF)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("A_Memorix.AggregateQueryService")
|
||||
|
||||
BranchRunner = Callable[[], Awaitable[Dict[str, Any]]]
|
||||
|
||||
|
||||
class AggregateQueryService:
|
||||
"""聚合查询执行服务(search/time/episode)。"""
|
||||
|
||||
def __init__(self, plugin_config: Optional[Any] = None):
|
||||
self.plugin_config = plugin_config or {}
|
||||
|
||||
def _cfg(self, key: str, default: Any = None) -> Any:
|
||||
getter = getattr(self.plugin_config, "get_config", None)
|
||||
if callable(getter):
|
||||
return getter(key, default)
|
||||
|
||||
current: Any = self.plugin_config
|
||||
for part in key.split("."):
|
||||
if isinstance(current, dict) and part in current:
|
||||
current = current[part]
|
||||
else:
|
||||
return default
|
||||
return current
|
||||
|
||||
@staticmethod
|
||||
def _as_float(value: Any, default: float = 0.0) -> float:
|
||||
try:
|
||||
return float(value)
|
||||
except Exception:
|
||||
return float(default)
|
||||
|
||||
@staticmethod
|
||||
def _as_int(value: Any, default: int = 0) -> int:
|
||||
try:
|
||||
return int(value)
|
||||
except Exception:
|
||||
return int(default)
|
||||
|
||||
def _rrf_k(self) -> float:
|
||||
raw = self._cfg("retrieval.aggregate.rrf_k", 60.0)
|
||||
value = self._as_float(raw, 60.0)
|
||||
return max(1.0, value)
|
||||
|
||||
def _weights(self) -> Dict[str, float]:
|
||||
defaults = {"search": 1.0, "time": 1.0, "episode": 1.0}
|
||||
raw = self._cfg("retrieval.aggregate.weights", {})
|
||||
if not isinstance(raw, dict):
|
||||
return defaults
|
||||
|
||||
out = dict(defaults)
|
||||
for key in ("search", "time", "episode"):
|
||||
if key in raw:
|
||||
out[key] = max(0.0, self._as_float(raw.get(key), defaults[key]))
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def _normalize_branch_payload(
|
||||
name: str,
|
||||
payload: Optional[Dict[str, Any]],
|
||||
) -> Dict[str, Any]:
|
||||
data = payload if isinstance(payload, dict) else {}
|
||||
results_raw = data.get("results", [])
|
||||
results = results_raw if isinstance(results_raw, list) else []
|
||||
count = data.get("count")
|
||||
if count is None:
|
||||
count = len(results)
|
||||
return {
|
||||
"name": name,
|
||||
"success": bool(data.get("success", False)),
|
||||
"skipped": bool(data.get("skipped", False)),
|
||||
"skip_reason": str(data.get("skip_reason", "") or "").strip(),
|
||||
"error": str(data.get("error", "") or "").strip(),
|
||||
"results": results,
|
||||
"count": max(0, int(count)),
|
||||
"elapsed_ms": max(0.0, float(data.get("elapsed_ms", 0.0) or 0.0)),
|
||||
"content": str(data.get("content", "") or ""),
|
||||
"query_type": str(data.get("query_type", "") or name),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _mix_key(item: Dict[str, Any], branch: str, rank: int) -> str:
|
||||
item_type = str(item.get("type", "") or "").strip().lower()
|
||||
if item_type == "episode":
|
||||
episode_id = str(item.get("episode_id", "") or "").strip()
|
||||
if episode_id:
|
||||
return f"episode:{episode_id}"
|
||||
|
||||
item_hash = str(item.get("hash", "") or "").strip()
|
||||
if item_hash:
|
||||
return f"{item_type}:{item_hash}"
|
||||
|
||||
return f"{branch}:{item_type}:{rank}:{str(item.get('content', '') or '')[:80]}"
|
||||
|
||||
def _build_mixed_results(
|
||||
self,
|
||||
*,
|
||||
branches: Dict[str, Dict[str, Any]],
|
||||
top_k: int,
|
||||
) -> List[Dict[str, Any]]:
|
||||
rrf_k = self._rrf_k()
|
||||
weights = self._weights()
|
||||
bucket: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
for branch_name, branch in branches.items():
|
||||
if not branch.get("success", False):
|
||||
continue
|
||||
results = branch.get("results", [])
|
||||
if not isinstance(results, list):
|
||||
continue
|
||||
|
||||
weight = max(0.0, float(weights.get(branch_name, 1.0)))
|
||||
for idx, item in enumerate(results, start=1):
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
key = self._mix_key(item, branch_name, idx)
|
||||
score = weight / (rrf_k + float(idx))
|
||||
if key not in bucket:
|
||||
merged = dict(item)
|
||||
merged["fusion_score"] = 0.0
|
||||
merged["_source_branches"] = set()
|
||||
bucket[key] = merged
|
||||
|
||||
target = bucket[key]
|
||||
target["fusion_score"] = float(target.get("fusion_score", 0.0)) + score
|
||||
target["_source_branches"].add(branch_name)
|
||||
|
||||
mixed = list(bucket.values())
|
||||
mixed.sort(
|
||||
key=lambda x: (
|
||||
-float(x.get("fusion_score", 0.0)),
|
||||
str(x.get("type", "") or ""),
|
||||
str(x.get("hash", "") or x.get("episode_id", "") or ""),
|
||||
)
|
||||
)
|
||||
|
||||
out: List[Dict[str, Any]] = []
|
||||
for rank, item in enumerate(mixed[: max(1, int(top_k))], start=1):
|
||||
merged = dict(item)
|
||||
branches_set = merged.pop("_source_branches", set())
|
||||
merged["source_branches"] = sorted(list(branches_set))
|
||||
merged["rank"] = rank
|
||||
out.append(merged)
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def _status(branch: Dict[str, Any]) -> str:
|
||||
if branch.get("skipped", False):
|
||||
return "skipped"
|
||||
if branch.get("success", False):
|
||||
return "success"
|
||||
return "failed"
|
||||
|
||||
def _build_summary(self, branches: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
|
||||
summary: Dict[str, Dict[str, Any]] = {}
|
||||
for name, branch in branches.items():
|
||||
status = self._status(branch)
|
||||
summary[name] = {
|
||||
"status": status,
|
||||
"count": int(branch.get("count", 0) or 0),
|
||||
}
|
||||
if status == "skipped":
|
||||
summary[name]["reason"] = str(branch.get("skip_reason", "") or "")
|
||||
if status == "failed":
|
||||
summary[name]["error"] = str(branch.get("error", "") or "")
|
||||
return summary
|
||||
|
||||
def _build_content(
|
||||
self,
|
||||
*,
|
||||
query: str,
|
||||
branches: Dict[str, Dict[str, Any]],
|
||||
errors: List[Dict[str, str]],
|
||||
mixed_results: Optional[List[Dict[str, Any]]],
|
||||
) -> str:
|
||||
lines: List[str] = [
|
||||
f"🔀 聚合查询结果(query='{query or 'N/A'}')",
|
||||
"",
|
||||
"分支状态:",
|
||||
]
|
||||
for name in ("search", "time", "episode"):
|
||||
branch = branches.get(name, {})
|
||||
status = self._status(branch)
|
||||
count = int(branch.get("count", 0) or 0)
|
||||
line = f"- {name}: {status}, count={count}"
|
||||
reason = str(branch.get("skip_reason", "") or "").strip()
|
||||
err = str(branch.get("error", "") or "").strip()
|
||||
if status == "skipped" and reason:
|
||||
line += f" ({reason})"
|
||||
if status == "failed" and err:
|
||||
line += f" ({err})"
|
||||
lines.append(line)
|
||||
|
||||
if errors:
|
||||
lines.append("")
|
||||
lines.append("错误:")
|
||||
for item in errors[:6]:
|
||||
lines.append(f"- {item.get('branch', 'unknown')}: {item.get('error', 'unknown error')}")
|
||||
|
||||
if mixed_results is not None:
|
||||
lines.append("")
|
||||
lines.append(f"🧩 混合结果({len(mixed_results)} 条):")
|
||||
for idx, item in enumerate(mixed_results[:5], start=1):
|
||||
src = ",".join(item.get("source_branches", []) or [])
|
||||
if str(item.get("type", "") or "") == "episode":
|
||||
title = str(item.get("title", "") or "Untitled")
|
||||
lines.append(f"{idx}. 🧠 {title} [{src}]")
|
||||
else:
|
||||
text = str(item.get("content", "") or "")
|
||||
if len(text) > 80:
|
||||
text = text[:80] + "..."
|
||||
lines.append(f"{idx}. {text} [{src}]")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
*,
|
||||
query: str,
|
||||
top_k: int,
|
||||
mix: bool,
|
||||
mix_top_k: Optional[int],
|
||||
time_from: Optional[str],
|
||||
time_to: Optional[str],
|
||||
search_runner: Optional[BranchRunner],
|
||||
time_runner: Optional[BranchRunner],
|
||||
episode_runner: Optional[BranchRunner],
|
||||
) -> Dict[str, Any]:
|
||||
clean_query = str(query or "").strip()
|
||||
safe_top_k = max(1, int(top_k))
|
||||
safe_mix_top_k = max(1, int(mix_top_k if mix_top_k is not None else safe_top_k))
|
||||
|
||||
branches: Dict[str, Dict[str, Any]] = {}
|
||||
errors: List[Dict[str, str]] = []
|
||||
scheduled: List[Tuple[str, asyncio.Task]] = []
|
||||
|
||||
if clean_query:
|
||||
if search_runner is not None:
|
||||
scheduled.append(("search", asyncio.create_task(search_runner())))
|
||||
else:
|
||||
branches["search"] = self._normalize_branch_payload(
|
||||
"search",
|
||||
{"success": False, "error": "search runner unavailable", "results": []},
|
||||
)
|
||||
else:
|
||||
branches["search"] = self._normalize_branch_payload(
|
||||
"search",
|
||||
{
|
||||
"success": False,
|
||||
"skipped": True,
|
||||
"skip_reason": "missing_query",
|
||||
"results": [],
|
||||
"count": 0,
|
||||
},
|
||||
)
|
||||
|
||||
if time_from or time_to:
|
||||
if time_runner is not None:
|
||||
scheduled.append(("time", asyncio.create_task(time_runner())))
|
||||
else:
|
||||
branches["time"] = self._normalize_branch_payload(
|
||||
"time",
|
||||
{"success": False, "error": "time runner unavailable", "results": []},
|
||||
)
|
||||
else:
|
||||
branches["time"] = self._normalize_branch_payload(
|
||||
"time",
|
||||
{
|
||||
"success": False,
|
||||
"skipped": True,
|
||||
"skip_reason": "missing_time_window",
|
||||
"results": [],
|
||||
"count": 0,
|
||||
},
|
||||
)
|
||||
|
||||
if episode_runner is not None:
|
||||
scheduled.append(("episode", asyncio.create_task(episode_runner())))
|
||||
else:
|
||||
branches["episode"] = self._normalize_branch_payload(
|
||||
"episode",
|
||||
{"success": False, "error": "episode runner unavailable", "results": []},
|
||||
)
|
||||
|
||||
if scheduled:
|
||||
done = await asyncio.gather(
|
||||
*[task for _, task in scheduled],
|
||||
return_exceptions=True,
|
||||
)
|
||||
for (branch_name, _), payload in zip(scheduled, done):
|
||||
if isinstance(payload, Exception):
|
||||
logger.error(f"aggregate branch failed: branch={branch_name} error={payload}")
|
||||
normalized = self._normalize_branch_payload(
|
||||
branch_name,
|
||||
{
|
||||
"success": False,
|
||||
"error": str(payload),
|
||||
"results": [],
|
||||
},
|
||||
)
|
||||
else:
|
||||
normalized = self._normalize_branch_payload(branch_name, payload)
|
||||
branches[branch_name] = normalized
|
||||
|
||||
for name in ("search", "time", "episode"):
|
||||
branch = branches.get(name)
|
||||
if not branch:
|
||||
continue
|
||||
if branch.get("skipped", False):
|
||||
continue
|
||||
if not branch.get("success", False):
|
||||
errors.append(
|
||||
{
|
||||
"branch": name,
|
||||
"error": str(branch.get("error", "") or "unknown error"),
|
||||
}
|
||||
)
|
||||
|
||||
success = any(
|
||||
bool(branches.get(name, {}).get("success", False))
|
||||
for name in ("search", "time", "episode")
|
||||
)
|
||||
mixed_results: Optional[List[Dict[str, Any]]] = None
|
||||
if mix:
|
||||
mixed_results = self._build_mixed_results(branches=branches, top_k=safe_mix_top_k)
|
||||
|
||||
payload: Dict[str, Any] = {
|
||||
"success": success,
|
||||
"query_type": "aggregate",
|
||||
"query": clean_query,
|
||||
"top_k": safe_top_k,
|
||||
"mix": bool(mix),
|
||||
"mix_top_k": safe_mix_top_k,
|
||||
"branches": branches,
|
||||
"errors": errors,
|
||||
"summary": self._build_summary(branches),
|
||||
}
|
||||
if mixed_results is not None:
|
||||
payload["mixed_results"] = mixed_results
|
||||
|
||||
payload["content"] = self._build_content(
|
||||
query=clean_query,
|
||||
branches=branches,
|
||||
errors=errors,
|
||||
mixed_results=mixed_results,
|
||||
)
|
||||
return payload
|
||||
182
src/A_memorix/core/utils/episode_retrieval_service.py
Normal file
182
src/A_memorix/core/utils/episode_retrieval_service.py
Normal file
@@ -0,0 +1,182 @@
|
||||
"""Episode hybrid retrieval service."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from ..retrieval import DualPathRetriever, TemporalQueryOptions
|
||||
|
||||
logger = get_logger("A_Memorix.EpisodeRetrievalService")
|
||||
|
||||
|
||||
class EpisodeRetrievalService:
|
||||
"""Hybrid episode retrieval backed by lexical rows and evidence projection."""
|
||||
|
||||
_RRF_K = 60.0
|
||||
_BRANCH_WEIGHTS = {
|
||||
"lexical": 1.0,
|
||||
"paragraph_evidence": 1.0,
|
||||
"relation_evidence": 0.85,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
metadata_store: Any,
|
||||
retriever: Optional[DualPathRetriever] = None,
|
||||
) -> None:
|
||||
self.metadata_store = metadata_store
|
||||
self.retriever = retriever
|
||||
|
||||
async def query(
|
||||
self,
|
||||
*,
|
||||
query: str = "",
|
||||
top_k: int = 5,
|
||||
time_from: Optional[float] = None,
|
||||
time_to: Optional[float] = None,
|
||||
person: Optional[str] = None,
|
||||
source: Optional[str] = None,
|
||||
include_paragraphs: bool = False,
|
||||
) -> List[Dict[str, Any]]:
|
||||
clean_query = str(query or "").strip()
|
||||
safe_top_k = max(1, int(top_k))
|
||||
candidate_k = max(30, safe_top_k * 6)
|
||||
|
||||
branches: Dict[str, List[Dict[str, Any]]] = {
|
||||
"lexical": self.metadata_store.query_episodes(
|
||||
query=clean_query,
|
||||
time_from=time_from,
|
||||
time_to=time_to,
|
||||
person=person,
|
||||
source=source,
|
||||
limit=(candidate_k if clean_query else safe_top_k),
|
||||
)
|
||||
}
|
||||
|
||||
if clean_query and self.retriever is not None:
|
||||
try:
|
||||
temporal = TemporalQueryOptions(
|
||||
time_from=time_from,
|
||||
time_to=time_to,
|
||||
person=person,
|
||||
source=source,
|
||||
)
|
||||
results = await self.retriever.retrieve(
|
||||
query=clean_query,
|
||||
top_k=candidate_k,
|
||||
temporal=temporal,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(f"episode evidence retrieval failed, fallback to lexical only: {exc}")
|
||||
else:
|
||||
paragraph_rank_map: Dict[str, int] = {}
|
||||
relation_rank_map: Dict[str, int] = {}
|
||||
for rank, item in enumerate(results, start=1):
|
||||
hash_value = str(getattr(item, "hash_value", "") or "").strip()
|
||||
result_type = str(getattr(item, "result_type", "") or "").strip().lower()
|
||||
if not hash_value:
|
||||
continue
|
||||
if result_type == "paragraph" and hash_value not in paragraph_rank_map:
|
||||
paragraph_rank_map[hash_value] = rank
|
||||
elif result_type == "relation" and hash_value not in relation_rank_map:
|
||||
relation_rank_map[hash_value] = rank
|
||||
|
||||
if paragraph_rank_map:
|
||||
paragraph_rows = self.metadata_store.get_episode_rows_by_paragraph_hashes(
|
||||
list(paragraph_rank_map.keys()),
|
||||
source=source,
|
||||
)
|
||||
if paragraph_rows:
|
||||
branches["paragraph_evidence"] = self._rank_projected_rows(
|
||||
paragraph_rows,
|
||||
rank_map=paragraph_rank_map,
|
||||
support_key="matched_paragraph_hashes",
|
||||
)
|
||||
|
||||
if relation_rank_map:
|
||||
relation_rows = self.metadata_store.get_episode_rows_by_relation_hashes(
|
||||
list(relation_rank_map.keys()),
|
||||
source=source,
|
||||
)
|
||||
if relation_rows:
|
||||
branches["relation_evidence"] = self._rank_projected_rows(
|
||||
relation_rows,
|
||||
rank_map=relation_rank_map,
|
||||
support_key="matched_relation_hashes",
|
||||
)
|
||||
|
||||
fused = self._fuse_branches(branches, top_k=safe_top_k)
|
||||
if include_paragraphs:
|
||||
for item in fused:
|
||||
item["paragraphs"] = self.metadata_store.get_episode_paragraphs(
|
||||
episode_id=str(item.get("episode_id") or ""),
|
||||
limit=50,
|
||||
)
|
||||
return fused
|
||||
|
||||
@staticmethod
|
||||
def _rank_projected_rows(
|
||||
rows: List[Dict[str, Any]],
|
||||
*,
|
||||
rank_map: Dict[str, int],
|
||||
support_key: str,
|
||||
) -> List[Dict[str, Any]]:
|
||||
sentinel = 10**9
|
||||
ranked = [dict(item) for item in rows]
|
||||
|
||||
def _first_support_rank(item: Dict[str, Any]) -> int:
|
||||
support_hashes = [str(x or "").strip() for x in (item.get(support_key) or [])]
|
||||
ranks = [int(rank_map[h]) for h in support_hashes if h in rank_map]
|
||||
return min(ranks) if ranks else sentinel
|
||||
|
||||
ranked.sort(
|
||||
key=lambda item: (
|
||||
_first_support_rank(item),
|
||||
-int(item.get("matched_paragraph_count") or 0),
|
||||
-float(item.get("updated_at") or 0.0),
|
||||
str(item.get("episode_id") or ""),
|
||||
)
|
||||
)
|
||||
return ranked
|
||||
|
||||
def _fuse_branches(
|
||||
self,
|
||||
branches: Dict[str, List[Dict[str, Any]]],
|
||||
*,
|
||||
top_k: int,
|
||||
) -> List[Dict[str, Any]]:
|
||||
bucket: Dict[str, Dict[str, Any]] = {}
|
||||
for branch_name, rows in branches.items():
|
||||
weight = float(self._BRANCH_WEIGHTS.get(branch_name, 0.0) or 0.0)
|
||||
if weight <= 0.0:
|
||||
continue
|
||||
for rank, row in enumerate(rows, start=1):
|
||||
episode_id = str(row.get("episode_id", "") or "").strip()
|
||||
if not episode_id:
|
||||
continue
|
||||
if episode_id not in bucket:
|
||||
payload = dict(row)
|
||||
payload.pop("matched_paragraph_hashes", None)
|
||||
payload.pop("matched_relation_hashes", None)
|
||||
payload.pop("matched_paragraph_count", None)
|
||||
payload.pop("matched_relation_count", None)
|
||||
payload["_fusion_score"] = 0.0
|
||||
bucket[episode_id] = payload
|
||||
bucket[episode_id]["_fusion_score"] = float(
|
||||
bucket[episode_id].get("_fusion_score", 0.0)
|
||||
) + weight / (self._RRF_K + float(rank))
|
||||
|
||||
out = list(bucket.values())
|
||||
out.sort(
|
||||
key=lambda item: (
|
||||
-float(item.get("_fusion_score", 0.0)),
|
||||
-float(item.get("updated_at") or 0.0),
|
||||
str(item.get("episode_id") or ""),
|
||||
)
|
||||
)
|
||||
for item in out:
|
||||
item.pop("_fusion_score", None)
|
||||
return out[: max(1, int(top_k))]
|
||||
311
src/A_memorix/core/utils/episode_segmentation_service.py
Normal file
311
src/A_memorix/core/utils/episode_segmentation_service.py
Normal file
@@ -0,0 +1,311 @@
|
||||
"""
|
||||
Episode 语义切分服务(LLM 主路径)。
|
||||
|
||||
职责:
|
||||
1. 组装语义切分提示词
|
||||
2. 调用 LLM 生成结构化 episode JSON
|
||||
3. 严格校验输出结构,返回标准化结果
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.model_configs import TaskConfig
|
||||
from src.config.config import model_config as host_model_config
|
||||
from src.services import llm_service as llm_api
|
||||
|
||||
logger = get_logger("A_Memorix.EpisodeSegmentationService")
|
||||
|
||||
|
||||
class EpisodeSegmentationService:
|
||||
"""基于 LLM 的 episode 语义切分服务。"""
|
||||
|
||||
SEGMENTATION_VERSION = "episode_mvp_v1"
|
||||
|
||||
def __init__(self, plugin_config: Optional[dict] = None):
|
||||
self.plugin_config = plugin_config or {}
|
||||
|
||||
def _cfg(self, key: str, default: Any = None) -> Any:
|
||||
current: Any = self.plugin_config
|
||||
for part in key.split("."):
|
||||
if isinstance(current, dict) and part in current:
|
||||
current = current[part]
|
||||
else:
|
||||
return default
|
||||
return current
|
||||
|
||||
@staticmethod
|
||||
def _is_task_config(obj: Any) -> bool:
|
||||
return hasattr(obj, "model_list") and bool(getattr(obj, "model_list", []))
|
||||
|
||||
def _build_single_model_task(self, model_name: str, template: TaskConfig) -> TaskConfig:
|
||||
return TaskConfig(
|
||||
model_list=[model_name],
|
||||
max_tokens=template.max_tokens,
|
||||
temperature=template.temperature,
|
||||
slow_threshold=template.slow_threshold,
|
||||
selection_strategy=template.selection_strategy,
|
||||
)
|
||||
|
||||
def _pick_template_task(self, available_tasks: Dict[str, Any]) -> Optional[TaskConfig]:
|
||||
preferred = ("utils", "replyer", "planner", "tool_use")
|
||||
for task_name in preferred:
|
||||
cfg = available_tasks.get(task_name)
|
||||
if self._is_task_config(cfg):
|
||||
return cfg
|
||||
for task_name, cfg in available_tasks.items():
|
||||
if task_name != "embedding" and self._is_task_config(cfg):
|
||||
return cfg
|
||||
for cfg in available_tasks.values():
|
||||
if self._is_task_config(cfg):
|
||||
return cfg
|
||||
return None
|
||||
|
||||
def _resolve_model_config(self) -> Tuple[Optional[Any], str]:
|
||||
available_tasks = llm_api.get_available_models() or {}
|
||||
if not available_tasks:
|
||||
return None, "unavailable"
|
||||
|
||||
selector = str(self._cfg("episode.segmentation_model", "auto") or "auto").strip()
|
||||
model_dict = getattr(host_model_config, "models_dict", {}) or {}
|
||||
|
||||
if selector and selector.lower() != "auto":
|
||||
direct_task = available_tasks.get(selector)
|
||||
if self._is_task_config(direct_task):
|
||||
return direct_task, selector
|
||||
|
||||
if selector in model_dict:
|
||||
template = self._pick_template_task(available_tasks)
|
||||
if template is not None:
|
||||
return self._build_single_model_task(selector, template), selector
|
||||
|
||||
logger.warning(f"episode.segmentation_model='{selector}' 不可用,回退 auto")
|
||||
|
||||
for task_name in ("utils", "replyer", "planner", "tool_use"):
|
||||
cfg = available_tasks.get(task_name)
|
||||
if self._is_task_config(cfg):
|
||||
return cfg, task_name
|
||||
|
||||
fallback = self._pick_template_task(available_tasks)
|
||||
if fallback is not None:
|
||||
return fallback, "auto"
|
||||
return None, "unavailable"
|
||||
|
||||
@staticmethod
|
||||
def _clamp_score(value: Any, default: float = 0.0) -> float:
|
||||
try:
|
||||
num = float(value)
|
||||
except Exception:
|
||||
num = default
|
||||
if num < 0.0:
|
||||
return 0.0
|
||||
if num > 1.0:
|
||||
return 1.0
|
||||
return num
|
||||
|
||||
@staticmethod
|
||||
def _safe_json_loads(text: str) -> Dict[str, Any]:
|
||||
raw = str(text or "").strip()
|
||||
if not raw:
|
||||
raise ValueError("empty_response")
|
||||
|
||||
if "```" in raw:
|
||||
raw = raw.replace("```json", "```").replace("```JSON", "```")
|
||||
parts = raw.split("```")
|
||||
for part in parts:
|
||||
part = part.strip()
|
||||
if part.startswith("{") and part.endswith("}"):
|
||||
raw = part
|
||||
break
|
||||
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
if isinstance(data, dict):
|
||||
return data
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
start = raw.find("{")
|
||||
end = raw.rfind("}")
|
||||
if start >= 0 and end > start:
|
||||
candidate = raw[start : end + 1]
|
||||
data = json.loads(candidate)
|
||||
if isinstance(data, dict):
|
||||
return data
|
||||
|
||||
raise ValueError("invalid_json_response")
|
||||
|
||||
def _build_prompt(
|
||||
self,
|
||||
*,
|
||||
source: str,
|
||||
window_start: Optional[float],
|
||||
window_end: Optional[float],
|
||||
paragraphs: List[Dict[str, Any]],
|
||||
) -> str:
|
||||
rows: List[str] = []
|
||||
for idx, item in enumerate(paragraphs, 1):
|
||||
p_hash = str(item.get("hash", "") or "").strip()
|
||||
content = str(item.get("content", "") or "").strip().replace("\r\n", "\n")
|
||||
content = content[:800]
|
||||
event_start = item.get("event_time_start")
|
||||
event_end = item.get("event_time_end")
|
||||
event_time = item.get("event_time")
|
||||
rows.append(
|
||||
(
|
||||
f"[{idx}] hash={p_hash}\n"
|
||||
f"event_time={event_time}\n"
|
||||
f"event_time_start={event_start}\n"
|
||||
f"event_time_end={event_end}\n"
|
||||
f"content={content}"
|
||||
)
|
||||
)
|
||||
|
||||
source_text = str(source or "").strip() or "unknown"
|
||||
return (
|
||||
"You are an episode segmentation engine.\n"
|
||||
"Group the given paragraphs into one or more coherent episodes.\n"
|
||||
"Return JSON ONLY. No markdown, no explanation.\n"
|
||||
"\n"
|
||||
"Hard JSON schema:\n"
|
||||
"{\n"
|
||||
' "episodes": [\n'
|
||||
" {\n"
|
||||
' "title": "string",\n'
|
||||
' "summary": "string",\n'
|
||||
' "paragraph_hashes": ["hash1", "hash2"],\n'
|
||||
' "participants": ["person1", "person2"],\n'
|
||||
' "keywords": ["kw1", "kw2"],\n'
|
||||
' "time_confidence": 0.0,\n'
|
||||
' "llm_confidence": 0.0\n'
|
||||
" }\n"
|
||||
" ]\n"
|
||||
"}\n"
|
||||
"\n"
|
||||
"Rules:\n"
|
||||
"1) paragraph_hashes must come from input only.\n"
|
||||
"2) title and summary must be non-empty.\n"
|
||||
"3) keep participants/keywords concise and deduplicated.\n"
|
||||
"4) if uncertain, still provide best effort confidence values.\n"
|
||||
"\n"
|
||||
f"source={source_text}\n"
|
||||
f"window_start={window_start}\n"
|
||||
f"window_end={window_end}\n"
|
||||
"paragraphs:\n"
|
||||
+ "\n\n".join(rows)
|
||||
)
|
||||
|
||||
def _normalize_episodes(
|
||||
self,
|
||||
*,
|
||||
payload: Dict[str, Any],
|
||||
input_hashes: List[str],
|
||||
) -> List[Dict[str, Any]]:
|
||||
raw_episodes = payload.get("episodes")
|
||||
if not isinstance(raw_episodes, list):
|
||||
raise ValueError("episodes_missing_or_not_list")
|
||||
|
||||
valid_hashes = set(input_hashes)
|
||||
normalized: List[Dict[str, Any]] = []
|
||||
for item in raw_episodes:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
|
||||
title = str(item.get("title", "") or "").strip()
|
||||
summary = str(item.get("summary", "") or "").strip()
|
||||
if not title or not summary:
|
||||
continue
|
||||
|
||||
raw_hashes = item.get("paragraph_hashes")
|
||||
if not isinstance(raw_hashes, list):
|
||||
continue
|
||||
|
||||
dedup_hashes: List[str] = []
|
||||
seen_hashes = set()
|
||||
for h in raw_hashes:
|
||||
token = str(h or "").strip()
|
||||
if not token or token in seen_hashes or token not in valid_hashes:
|
||||
continue
|
||||
seen_hashes.add(token)
|
||||
dedup_hashes.append(token)
|
||||
|
||||
if not dedup_hashes:
|
||||
continue
|
||||
|
||||
participants = []
|
||||
for p in item.get("participants", []) or []:
|
||||
token = str(p or "").strip()
|
||||
if token:
|
||||
participants.append(token)
|
||||
|
||||
keywords = []
|
||||
for kw in item.get("keywords", []) or []:
|
||||
token = str(kw or "").strip()
|
||||
if token:
|
||||
keywords.append(token)
|
||||
|
||||
normalized.append(
|
||||
{
|
||||
"title": title,
|
||||
"summary": summary,
|
||||
"paragraph_hashes": dedup_hashes,
|
||||
"participants": participants[:16],
|
||||
"keywords": keywords[:20],
|
||||
"time_confidence": self._clamp_score(item.get("time_confidence"), default=1.0),
|
||||
"llm_confidence": self._clamp_score(item.get("llm_confidence"), default=0.5),
|
||||
}
|
||||
)
|
||||
|
||||
if not normalized:
|
||||
raise ValueError("episodes_all_invalid")
|
||||
return normalized
|
||||
|
||||
async def segment(
|
||||
self,
|
||||
*,
|
||||
source: str,
|
||||
window_start: Optional[float],
|
||||
window_end: Optional[float],
|
||||
paragraphs: List[Dict[str, Any]],
|
||||
) -> Dict[str, Any]:
|
||||
if not paragraphs:
|
||||
raise ValueError("paragraphs_empty")
|
||||
|
||||
model_config, model_label = self._resolve_model_config()
|
||||
if model_config is None:
|
||||
raise RuntimeError("episode segmentation model unavailable")
|
||||
task_name = llm_api.resolve_task_name_from_model_config(model_config, preferred_task_name=model_label)
|
||||
|
||||
prompt = self._build_prompt(
|
||||
source=source,
|
||||
window_start=window_start,
|
||||
window_end=window_end,
|
||||
paragraphs=paragraphs,
|
||||
)
|
||||
result = await llm_api.generate(
|
||||
llm_api.LLMServiceRequest(
|
||||
task_name=task_name,
|
||||
request_type="A_Memorix.EpisodeSegmentation",
|
||||
prompt=prompt,
|
||||
temperature=getattr(model_config, "temperature", None),
|
||||
max_tokens=getattr(model_config, "max_tokens", None),
|
||||
)
|
||||
)
|
||||
success = bool(result.success)
|
||||
response = str(result.completion.response or "")
|
||||
if not success or not response:
|
||||
raise RuntimeError("llm_generate_failed")
|
||||
|
||||
payload = self._safe_json_loads(str(response))
|
||||
input_hashes = [str(p.get("hash", "") or "").strip() for p in paragraphs]
|
||||
episodes = self._normalize_episodes(payload=payload, input_hashes=input_hashes)
|
||||
|
||||
return {
|
||||
"episodes": episodes,
|
||||
"segmentation_model": model_label,
|
||||
"segmentation_version": self.SEGMENTATION_VERSION,
|
||||
}
|
||||
|
||||
558
src/A_memorix/core/utils/episode_service.py
Normal file
558
src/A_memorix/core/utils/episode_service.py
Normal file
@@ -0,0 +1,558 @@
|
||||
"""
|
||||
Episode 聚合与落库服务。
|
||||
|
||||
流程:
|
||||
1. 从 pending 队列读取段落并组批
|
||||
2. 按 source + 时间窗口切组
|
||||
3. 调用 LLM 语义切分
|
||||
4. 写入 episodes + episode_paragraphs
|
||||
5. LLM 失败时使用确定性 fallback
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from collections import Counter
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from .episode_segmentation_service import EpisodeSegmentationService
|
||||
from .hash import compute_hash
|
||||
|
||||
logger = get_logger("A_Memorix.EpisodeService")
|
||||
|
||||
|
||||
class EpisodeService:
|
||||
"""Episode MVP 后台处理服务。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
metadata_store: Any,
|
||||
plugin_config: Optional[Any] = None,
|
||||
segmentation_service: Optional[EpisodeSegmentationService] = None,
|
||||
):
|
||||
self.metadata_store = metadata_store
|
||||
self.plugin_config = plugin_config or {}
|
||||
self.segmentation_service = segmentation_service or EpisodeSegmentationService(
|
||||
plugin_config=self._config_dict(),
|
||||
)
|
||||
|
||||
def _config_dict(self) -> Dict[str, Any]:
|
||||
if isinstance(self.plugin_config, dict):
|
||||
return self.plugin_config
|
||||
return {}
|
||||
|
||||
def _cfg(self, key: str, default: Any = None) -> Any:
|
||||
getter = getattr(self.plugin_config, "get_config", None)
|
||||
if callable(getter):
|
||||
return getter(key, default)
|
||||
|
||||
current: Any = self.plugin_config
|
||||
for part in key.split("."):
|
||||
if isinstance(current, dict) and part in current:
|
||||
current = current[part]
|
||||
else:
|
||||
return default
|
||||
return current
|
||||
|
||||
@staticmethod
|
||||
def _to_optional_float(value: Any) -> Optional[float]:
|
||||
if value is None:
|
||||
return None
|
||||
try:
|
||||
return float(value)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _clamp_score(value: Any, default: float = 1.0) -> float:
|
||||
try:
|
||||
num = float(value)
|
||||
except Exception:
|
||||
num = default
|
||||
if num < 0.0:
|
||||
return 0.0
|
||||
if num > 1.0:
|
||||
return 1.0
|
||||
return num
|
||||
|
||||
@staticmethod
|
||||
def _paragraph_anchor(paragraph: Dict[str, Any]) -> float:
|
||||
for key in ("event_time_end", "event_time_start", "event_time", "created_at"):
|
||||
value = paragraph.get(key)
|
||||
try:
|
||||
if value is not None:
|
||||
return float(value)
|
||||
except Exception:
|
||||
continue
|
||||
return 0.0
|
||||
|
||||
@staticmethod
|
||||
def _paragraph_sort_key(paragraph: Dict[str, Any]) -> Tuple[float, str]:
|
||||
return (
|
||||
EpisodeService._paragraph_anchor(paragraph),
|
||||
str(paragraph.get("hash", "") or ""),
|
||||
)
|
||||
|
||||
def load_pending_paragraphs(
|
||||
self,
|
||||
pending_rows: List[Dict[str, Any]],
|
||||
) -> Tuple[List[Dict[str, Any]], List[str]]:
|
||||
"""
|
||||
将 pending 行展开为段落上下文。
|
||||
|
||||
Returns:
|
||||
(loaded_paragraphs, missing_hashes)
|
||||
"""
|
||||
loaded: List[Dict[str, Any]] = []
|
||||
missing: List[str] = []
|
||||
for row in pending_rows or []:
|
||||
p_hash = str(row.get("paragraph_hash", "") or "").strip()
|
||||
if not p_hash:
|
||||
continue
|
||||
|
||||
paragraph = self.metadata_store.get_paragraph(p_hash)
|
||||
if not paragraph:
|
||||
missing.append(p_hash)
|
||||
continue
|
||||
|
||||
loaded.append(
|
||||
{
|
||||
"hash": p_hash,
|
||||
"source": str(row.get("source") or paragraph.get("source") or "").strip(),
|
||||
"content": str(paragraph.get("content", "") or ""),
|
||||
"created_at": self._to_optional_float(paragraph.get("created_at"))
|
||||
or self._to_optional_float(row.get("created_at"))
|
||||
or 0.0,
|
||||
"event_time": self._to_optional_float(paragraph.get("event_time")),
|
||||
"event_time_start": self._to_optional_float(paragraph.get("event_time_start")),
|
||||
"event_time_end": self._to_optional_float(paragraph.get("event_time_end")),
|
||||
"time_granularity": str(paragraph.get("time_granularity", "") or "").strip() or None,
|
||||
"time_confidence": self._clamp_score(paragraph.get("time_confidence"), default=1.0),
|
||||
}
|
||||
)
|
||||
return loaded, missing
|
||||
|
||||
def group_paragraphs(self, paragraphs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
按 source + 时间邻近窗口组批,并受段落数/字符数上限约束。
|
||||
"""
|
||||
if not paragraphs:
|
||||
return []
|
||||
|
||||
max_paragraphs = max(1, int(self._cfg("episode.max_paragraphs_per_call", 20)))
|
||||
max_chars = max(200, int(self._cfg("episode.max_chars_per_call", 6000)))
|
||||
window_seconds = max(
|
||||
60.0,
|
||||
float(self._cfg("episode.source_time_window_hours", 24)) * 3600.0,
|
||||
)
|
||||
|
||||
by_source: Dict[str, List[Dict[str, Any]]] = {}
|
||||
for paragraph in paragraphs:
|
||||
source = str(paragraph.get("source", "") or "").strip()
|
||||
by_source.setdefault(source, []).append(paragraph)
|
||||
|
||||
groups: List[Dict[str, Any]] = []
|
||||
for source, items in by_source.items():
|
||||
ordered = sorted(items, key=self._paragraph_sort_key)
|
||||
|
||||
current: List[Dict[str, Any]] = []
|
||||
current_chars = 0
|
||||
last_anchor: Optional[float] = None
|
||||
|
||||
def flush() -> None:
|
||||
nonlocal current, current_chars, last_anchor
|
||||
if not current:
|
||||
return
|
||||
sorted_current = sorted(current, key=self._paragraph_sort_key)
|
||||
groups.append(
|
||||
{
|
||||
"source": source,
|
||||
"paragraphs": sorted_current,
|
||||
}
|
||||
)
|
||||
current = []
|
||||
current_chars = 0
|
||||
last_anchor = None
|
||||
|
||||
for paragraph in ordered:
|
||||
anchor = self._paragraph_anchor(paragraph)
|
||||
content_len = len(str(paragraph.get("content", "") or ""))
|
||||
|
||||
need_flush = False
|
||||
if current:
|
||||
if len(current) >= max_paragraphs:
|
||||
need_flush = True
|
||||
elif current_chars + content_len > max_chars:
|
||||
need_flush = True
|
||||
elif last_anchor is not None and abs(anchor - last_anchor) > window_seconds:
|
||||
need_flush = True
|
||||
|
||||
if need_flush:
|
||||
flush()
|
||||
|
||||
current.append(paragraph)
|
||||
current_chars += content_len
|
||||
last_anchor = anchor
|
||||
|
||||
flush()
|
||||
|
||||
groups.sort(
|
||||
key=lambda g: self._paragraph_anchor(g["paragraphs"][0]) if g.get("paragraphs") else 0.0
|
||||
)
|
||||
return groups
|
||||
|
||||
def _compute_time_meta(self, paragraphs: List[Dict[str, Any]]) -> Tuple[Optional[float], Optional[float], Optional[str], float]:
|
||||
starts: List[float] = []
|
||||
ends: List[float] = []
|
||||
granularity_priority = {
|
||||
"minute": 4,
|
||||
"hour": 3,
|
||||
"day": 2,
|
||||
"month": 1,
|
||||
"year": 0,
|
||||
}
|
||||
granularity = None
|
||||
granularity_rank = -1
|
||||
conf_values: List[float] = []
|
||||
|
||||
for p in paragraphs:
|
||||
s = self._to_optional_float(p.get("event_time_start"))
|
||||
e = self._to_optional_float(p.get("event_time_end"))
|
||||
t = self._to_optional_float(p.get("event_time"))
|
||||
c = self._to_optional_float(p.get("created_at"))
|
||||
|
||||
start_candidate = s if s is not None else (t if t is not None else (e if e is not None else c))
|
||||
end_candidate = e if e is not None else (t if t is not None else (s if s is not None else c))
|
||||
|
||||
if start_candidate is not None:
|
||||
starts.append(start_candidate)
|
||||
if end_candidate is not None:
|
||||
ends.append(end_candidate)
|
||||
|
||||
g = str(p.get("time_granularity", "") or "").strip().lower()
|
||||
if g in granularity_priority and granularity_priority[g] > granularity_rank:
|
||||
granularity_rank = granularity_priority[g]
|
||||
granularity = g
|
||||
|
||||
conf_values.append(self._clamp_score(p.get("time_confidence"), default=1.0))
|
||||
|
||||
time_start = min(starts) if starts else None
|
||||
time_end = max(ends) if ends else None
|
||||
time_conf = sum(conf_values) / len(conf_values) if conf_values else 1.0
|
||||
return time_start, time_end, granularity, self._clamp_score(time_conf, default=1.0)
|
||||
|
||||
def _collect_participants(self, paragraph_hashes: List[str], limit: int = 16) -> List[str]:
|
||||
seen = set()
|
||||
participants: List[str] = []
|
||||
for p_hash in paragraph_hashes:
|
||||
try:
|
||||
entities = self.metadata_store.get_paragraph_entities(p_hash)
|
||||
except Exception:
|
||||
entities = []
|
||||
for item in entities:
|
||||
name = str(item.get("name", "") or "").strip()
|
||||
if not name:
|
||||
continue
|
||||
key = name.lower()
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
participants.append(name)
|
||||
if len(participants) >= limit:
|
||||
return participants
|
||||
return participants
|
||||
|
||||
@staticmethod
|
||||
def _derive_keywords(paragraphs: List[Dict[str, Any]], limit: int = 12) -> List[str]:
|
||||
token_counter: Counter[str] = Counter()
|
||||
token_pattern = re.compile(r"[A-Za-z0-9_\u4e00-\u9fff]{2,}")
|
||||
stop_words = {
|
||||
"the",
|
||||
"and",
|
||||
"that",
|
||||
"this",
|
||||
"with",
|
||||
"from",
|
||||
"for",
|
||||
"have",
|
||||
"will",
|
||||
"your",
|
||||
"you",
|
||||
"我们",
|
||||
"你们",
|
||||
"他们",
|
||||
"以及",
|
||||
"一个",
|
||||
"这个",
|
||||
"那个",
|
||||
"然后",
|
||||
"因为",
|
||||
"所以",
|
||||
}
|
||||
for p in paragraphs:
|
||||
text = str(p.get("content", "") or "").lower()
|
||||
for token in token_pattern.findall(text):
|
||||
if token in stop_words:
|
||||
continue
|
||||
token_counter[token] += 1
|
||||
|
||||
return [token for token, _ in token_counter.most_common(limit)]
|
||||
|
||||
def _build_fallback_episode(self, group: Dict[str, Any]) -> Dict[str, Any]:
|
||||
paragraphs = group.get("paragraphs", []) or []
|
||||
source = str(group.get("source", "") or "").strip()
|
||||
hashes = [str(p.get("hash", "") or "").strip() for p in paragraphs if str(p.get("hash", "") or "").strip()]
|
||||
snippets = []
|
||||
for p in paragraphs[:3]:
|
||||
text = str(p.get("content", "") or "").strip().replace("\n", " ")
|
||||
if text:
|
||||
snippets.append(text[:140])
|
||||
summary = ";".join(snippets)[:500] if snippets else "自动回退生成的情景记忆。"
|
||||
|
||||
time_start, time_end, granularity, time_conf = self._compute_time_meta(paragraphs)
|
||||
participants = self._collect_participants(hashes, limit=12)
|
||||
keywords = self._derive_keywords(paragraphs, limit=10)
|
||||
|
||||
if time_start is not None:
|
||||
day_text = datetime.fromtimestamp(time_start).strftime("%Y-%m-%d")
|
||||
title = f"{source or 'unknown'} {day_text} 情景片段"
|
||||
else:
|
||||
title = f"{source or 'unknown'} 情景片段"
|
||||
|
||||
return {
|
||||
"title": title[:80],
|
||||
"summary": summary,
|
||||
"paragraph_hashes": hashes,
|
||||
"participants": participants,
|
||||
"keywords": keywords,
|
||||
"time_confidence": time_conf,
|
||||
"llm_confidence": 0.0,
|
||||
"event_time_start": time_start,
|
||||
"event_time_end": time_end,
|
||||
"time_granularity": granularity,
|
||||
"segmentation_model": "fallback_rule",
|
||||
"segmentation_version": EpisodeSegmentationService.SEGMENTATION_VERSION,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _normalize_episode_hashes(episode_hashes: List[str], group_hashes_ordered: List[str]) -> List[str]:
|
||||
in_group = set(group_hashes_ordered)
|
||||
dedup: List[str] = []
|
||||
seen = set()
|
||||
for h in episode_hashes or []:
|
||||
token = str(h or "").strip()
|
||||
if not token or token not in in_group or token in seen:
|
||||
continue
|
||||
seen.add(token)
|
||||
dedup.append(token)
|
||||
return dedup
|
||||
|
||||
async def _build_episode_payloads_for_group(self, group: Dict[str, Any]) -> Dict[str, Any]:
|
||||
paragraphs = group.get("paragraphs", []) or []
|
||||
if not paragraphs:
|
||||
return {
|
||||
"payloads": [],
|
||||
"done_hashes": [],
|
||||
"episode_count": 0,
|
||||
"fallback_count": 0,
|
||||
}
|
||||
|
||||
source = str(group.get("source", "") or "").strip()
|
||||
group_hashes = [str(p.get("hash", "") or "").strip() for p in paragraphs if str(p.get("hash", "") or "").strip()]
|
||||
group_start, group_end, _, _ = self._compute_time_meta(paragraphs)
|
||||
|
||||
fallback_used = False
|
||||
segmentation_model = "fallback_rule"
|
||||
segmentation_version = EpisodeSegmentationService.SEGMENTATION_VERSION
|
||||
|
||||
try:
|
||||
llm_result = await self.segmentation_service.segment(
|
||||
source=source,
|
||||
window_start=group_start,
|
||||
window_end=group_end,
|
||||
paragraphs=paragraphs,
|
||||
)
|
||||
episodes = list(llm_result.get("episodes") or [])
|
||||
segmentation_model = str(llm_result.get("segmentation_model", "") or "").strip() or "auto"
|
||||
segmentation_version = str(llm_result.get("segmentation_version", "") or "").strip() or EpisodeSegmentationService.SEGMENTATION_VERSION
|
||||
if not episodes:
|
||||
raise ValueError("llm_empty_episodes")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Episode segmentation fallback: "
|
||||
f"source={source} "
|
||||
f"size={len(group_hashes)} "
|
||||
f"err={e}"
|
||||
)
|
||||
episodes = [self._build_fallback_episode(group)]
|
||||
fallback_used = True
|
||||
|
||||
stored_payloads: List[Dict[str, Any]] = []
|
||||
for episode in episodes:
|
||||
ordered_hashes = self._normalize_episode_hashes(
|
||||
episode_hashes=episode.get("paragraph_hashes", []),
|
||||
group_hashes_ordered=group_hashes,
|
||||
)
|
||||
if not ordered_hashes:
|
||||
continue
|
||||
|
||||
sub_paragraphs = [p for p in paragraphs if str(p.get("hash", "") or "") in set(ordered_hashes)]
|
||||
event_start, event_end, granularity, time_conf_default = self._compute_time_meta(sub_paragraphs)
|
||||
|
||||
participants = [str(x).strip() for x in (episode.get("participants", []) or []) if str(x).strip()]
|
||||
keywords = [str(x).strip() for x in (episode.get("keywords", []) or []) if str(x).strip()]
|
||||
if not participants:
|
||||
participants = self._collect_participants(ordered_hashes, limit=16)
|
||||
if not keywords:
|
||||
keywords = self._derive_keywords(sub_paragraphs, limit=12)
|
||||
|
||||
title = str(episode.get("title", "") or "").strip()[:120]
|
||||
summary = str(episode.get("summary", "") or "").strip()[:2000]
|
||||
if not title or not summary:
|
||||
continue
|
||||
|
||||
seed = json.dumps(
|
||||
{
|
||||
"source": source,
|
||||
"hashes": ordered_hashes,
|
||||
"version": segmentation_version,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
sort_keys=True,
|
||||
)
|
||||
episode_id = compute_hash(seed)
|
||||
|
||||
payload = {
|
||||
"episode_id": episode_id,
|
||||
"source": source or None,
|
||||
"title": title,
|
||||
"summary": summary,
|
||||
"event_time_start": episode.get("event_time_start", event_start),
|
||||
"event_time_end": episode.get("event_time_end", event_end),
|
||||
"time_granularity": episode.get("time_granularity", granularity),
|
||||
"time_confidence": self._clamp_score(
|
||||
episode.get("time_confidence"),
|
||||
default=time_conf_default,
|
||||
),
|
||||
"participants": participants[:16],
|
||||
"keywords": keywords[:20],
|
||||
"evidence_ids": ordered_hashes,
|
||||
"paragraph_count": len(ordered_hashes),
|
||||
"llm_confidence": self._clamp_score(
|
||||
episode.get("llm_confidence"),
|
||||
default=0.0 if fallback_used else 0.6,
|
||||
),
|
||||
"segmentation_model": (
|
||||
str(episode.get("segmentation_model", "") or "").strip()
|
||||
or ("fallback_rule" if fallback_used else segmentation_model)
|
||||
),
|
||||
"segmentation_version": (
|
||||
str(episode.get("segmentation_version", "") or "").strip()
|
||||
or segmentation_version
|
||||
),
|
||||
}
|
||||
stored_payloads.append(payload)
|
||||
|
||||
return {
|
||||
"payloads": stored_payloads,
|
||||
"done_hashes": group_hashes,
|
||||
"episode_count": len(stored_payloads),
|
||||
"fallback_count": 1 if fallback_used else 0,
|
||||
}
|
||||
|
||||
async def process_group(self, group: Dict[str, Any]) -> Dict[str, Any]:
|
||||
result = await self._build_episode_payloads_for_group(group)
|
||||
stored_count = 0
|
||||
for payload in result.get("payloads") or []:
|
||||
stored = self.metadata_store.upsert_episode(payload)
|
||||
final_id = str(stored.get("episode_id") or payload.get("episode_id") or "")
|
||||
if final_id:
|
||||
self.metadata_store.bind_episode_paragraphs(
|
||||
final_id,
|
||||
list(payload.get("evidence_ids") or []),
|
||||
)
|
||||
stored_count += 1
|
||||
|
||||
result["episode_count"] = stored_count
|
||||
return {
|
||||
"done_hashes": list(result.get("done_hashes") or []),
|
||||
"episode_count": stored_count,
|
||||
"fallback_count": int(result.get("fallback_count") or 0),
|
||||
}
|
||||
|
||||
async def process_pending_rows(self, pending_rows: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
loaded, missing_hashes = self.load_pending_paragraphs(pending_rows)
|
||||
groups = self.group_paragraphs(loaded)
|
||||
|
||||
done_hashes: List[str] = list(missing_hashes)
|
||||
failed_hashes: Dict[str, str] = {}
|
||||
episode_count = 0
|
||||
fallback_count = 0
|
||||
|
||||
for group in groups:
|
||||
group_hashes = [str(p.get("hash", "") or "").strip() for p in (group.get("paragraphs") or [])]
|
||||
try:
|
||||
result = await self.process_group(group)
|
||||
done_hashes.extend(result.get("done_hashes") or [])
|
||||
episode_count += int(result.get("episode_count") or 0)
|
||||
fallback_count += int(result.get("fallback_count") or 0)
|
||||
except Exception as e:
|
||||
err = str(e)[:500]
|
||||
for h in group_hashes:
|
||||
if h:
|
||||
failed_hashes[h] = err
|
||||
|
||||
dedup_done = list(dict.fromkeys([h for h in done_hashes if h]))
|
||||
return {
|
||||
"done_hashes": dedup_done,
|
||||
"failed_hashes": failed_hashes,
|
||||
"episode_count": episode_count,
|
||||
"fallback_count": fallback_count,
|
||||
"missing_count": len(missing_hashes),
|
||||
"group_count": len(groups),
|
||||
}
|
||||
|
||||
async def rebuild_source(self, source: str) -> Dict[str, Any]:
|
||||
token = str(source or "").strip()
|
||||
if not token:
|
||||
return {
|
||||
"source": "",
|
||||
"episode_count": 0,
|
||||
"fallback_count": 0,
|
||||
"group_count": 0,
|
||||
"paragraph_count": 0,
|
||||
}
|
||||
|
||||
paragraphs = self.metadata_store.get_live_paragraphs_by_source(token)
|
||||
if not paragraphs:
|
||||
replace_result = self.metadata_store.replace_episodes_for_source(token, [])
|
||||
return {
|
||||
"source": token,
|
||||
"episode_count": int(replace_result.get("episode_count") or 0),
|
||||
"fallback_count": 0,
|
||||
"group_count": 0,
|
||||
"paragraph_count": 0,
|
||||
}
|
||||
|
||||
groups = self.group_paragraphs(paragraphs)
|
||||
payloads: List[Dict[str, Any]] = []
|
||||
fallback_count = 0
|
||||
|
||||
for group in groups:
|
||||
result = await self._build_episode_payloads_for_group(group)
|
||||
payloads.extend(list(result.get("payloads") or []))
|
||||
fallback_count += int(result.get("fallback_count") or 0)
|
||||
|
||||
replace_result = self.metadata_store.replace_episodes_for_source(token, payloads)
|
||||
return {
|
||||
"source": token,
|
||||
"episode_count": int(replace_result.get("episode_count") or 0),
|
||||
"fallback_count": fallback_count,
|
||||
"group_count": len(groups),
|
||||
"paragraph_count": len(paragraphs),
|
||||
}
|
||||
129
src/A_memorix/core/utils/hash.py
Normal file
129
src/A_memorix/core/utils/hash.py
Normal file
@@ -0,0 +1,129 @@
|
||||
"""
|
||||
哈希工具模块
|
||||
|
||||
提供文本哈希计算功能,用于唯一标识和去重。
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import re
|
||||
from typing import Union
|
||||
|
||||
|
||||
def compute_hash(text: str, hash_type: str = "sha256") -> str:
|
||||
"""
|
||||
计算文本的哈希值
|
||||
|
||||
Args:
|
||||
text: 输入文本
|
||||
hash_type: 哈希算法类型(sha256, md5等)
|
||||
|
||||
Returns:
|
||||
哈希值字符串
|
||||
"""
|
||||
if hash_type == "sha256":
|
||||
return hashlib.sha256(text.encode("utf-8")).hexdigest()
|
||||
elif hash_type == "md5":
|
||||
return hashlib.md5(text.encode("utf-8")).hexdigest()
|
||||
else:
|
||||
raise ValueError(f"不支持的哈希算法: {hash_type}")
|
||||
|
||||
|
||||
def normalize_text(text: str) -> str:
|
||||
"""
|
||||
规范化文本用于哈希计算
|
||||
|
||||
执行以下操作:
|
||||
- 去除首尾空白
|
||||
- 统一换行符为\\n
|
||||
- 压缩多个连续空格
|
||||
- 去除不可见字符(保留换行和制表符)
|
||||
|
||||
Args:
|
||||
text: 输入文本
|
||||
|
||||
Returns:
|
||||
规范化后的文本
|
||||
"""
|
||||
# 去除首尾空白
|
||||
text = text.strip()
|
||||
|
||||
# 统一换行符
|
||||
text = text.replace("\r\n", "\n").replace("\r", "\n")
|
||||
|
||||
# 压缩多个连续空格为一个(但保留换行和制表符)
|
||||
text = re.sub(r"[^\S\n]+", " ", text)
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def compute_paragraph_hash(paragraph: str) -> str:
|
||||
"""
|
||||
计算段落的哈希值
|
||||
|
||||
Args:
|
||||
paragraph: 段落文本
|
||||
|
||||
Returns:
|
||||
段落哈希值(用于paragraph-前缀)
|
||||
"""
|
||||
normalized = normalize_text(paragraph)
|
||||
return compute_hash(normalized)
|
||||
|
||||
|
||||
def compute_entity_hash(entity: str) -> str:
|
||||
"""
|
||||
计算实体的哈希值
|
||||
|
||||
Args:
|
||||
entity: 实体名称
|
||||
|
||||
Returns:
|
||||
实体哈希值(用于entity-前缀)
|
||||
"""
|
||||
normalized = entity.strip().lower()
|
||||
return compute_hash(normalized)
|
||||
|
||||
|
||||
def compute_relation_hash(relation: tuple) -> str:
|
||||
"""
|
||||
计算关系的哈希值
|
||||
|
||||
Args:
|
||||
relation: 关系元组 (subject, predicate, object)
|
||||
|
||||
Returns:
|
||||
关系哈希值(用于relation-前缀)
|
||||
"""
|
||||
# 将关系元组转为字符串
|
||||
relation_str = str(tuple(relation))
|
||||
return compute_hash(relation_str)
|
||||
|
||||
|
||||
def format_hash_key(hash_type: str, hash_value: str) -> str:
|
||||
"""
|
||||
格式化哈希键
|
||||
|
||||
Args:
|
||||
hash_type: 类型前缀(paragraph, entity, relation)
|
||||
hash_value: 哈希值
|
||||
|
||||
Returns:
|
||||
格式化的键(如 paragraph-abc123...)
|
||||
"""
|
||||
return f"{hash_type}-{hash_value}"
|
||||
|
||||
|
||||
def parse_hash_key(key: str) -> tuple[str, str]:
|
||||
"""
|
||||
解析哈希键
|
||||
|
||||
Args:
|
||||
key: 格式化的键(如 paragraph-abc123...)
|
||||
|
||||
Returns:
|
||||
(类型, 哈希值) 元组
|
||||
"""
|
||||
parts = key.split("-", 1)
|
||||
if len(parts) != 2:
|
||||
raise ValueError(f"无效的哈希键格式: {key}")
|
||||
return parts[0], parts[1]
|
||||
110
src/A_memorix/core/utils/import_payloads.py
Normal file
110
src/A_memorix/core/utils/import_payloads.py
Normal file
@@ -0,0 +1,110 @@
|
||||
"""Shared import payload normalization helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from ..storage import KnowledgeType, resolve_stored_knowledge_type
|
||||
from .time_parser import normalize_time_meta
|
||||
|
||||
|
||||
def _normalize_entities(raw_entities: Any) -> List[str]:
|
||||
if not isinstance(raw_entities, list):
|
||||
return []
|
||||
out: List[str] = []
|
||||
seen = set()
|
||||
for item in raw_entities:
|
||||
name = str(item or "").strip()
|
||||
if not name:
|
||||
continue
|
||||
key = name.lower()
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
out.append(name)
|
||||
return out
|
||||
|
||||
|
||||
def _normalize_relations(raw_relations: Any) -> List[Dict[str, str]]:
|
||||
if not isinstance(raw_relations, list):
|
||||
return []
|
||||
out: List[Dict[str, str]] = []
|
||||
for item in raw_relations:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
subject = str(item.get("subject", "")).strip()
|
||||
predicate = str(item.get("predicate", "")).strip()
|
||||
obj = str(item.get("object", "")).strip()
|
||||
if not (subject and predicate and obj):
|
||||
continue
|
||||
out.append(
|
||||
{
|
||||
"subject": subject,
|
||||
"predicate": predicate,
|
||||
"object": obj,
|
||||
}
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def normalize_paragraph_import_item(
|
||||
item: Any,
|
||||
*,
|
||||
default_source: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""Normalize one paragraph import item from text/json payloads."""
|
||||
|
||||
if isinstance(item, str):
|
||||
content = str(item)
|
||||
knowledge_type = resolve_stored_knowledge_type(None, content=content)
|
||||
return {
|
||||
"content": content,
|
||||
"knowledge_type": knowledge_type.value,
|
||||
"source": str(default_source or "").strip(),
|
||||
"time_meta": None,
|
||||
"entities": [],
|
||||
"relations": [],
|
||||
}
|
||||
|
||||
if not isinstance(item, dict) or "content" not in item:
|
||||
raise ValueError("段落项必须为字符串或包含 content 的对象")
|
||||
|
||||
content = str(item.get("content", "") or "")
|
||||
if not content.strip():
|
||||
raise ValueError("段落 content 不能为空")
|
||||
|
||||
raw_time_meta = {
|
||||
"event_time": item.get("event_time"),
|
||||
"event_time_start": item.get("event_time_start"),
|
||||
"event_time_end": item.get("event_time_end"),
|
||||
"time_range": item.get("time_range"),
|
||||
"time_granularity": item.get("time_granularity"),
|
||||
"time_confidence": item.get("time_confidence"),
|
||||
}
|
||||
time_meta_field = item.get("time_meta")
|
||||
if isinstance(time_meta_field, dict):
|
||||
raw_time_meta.update(time_meta_field)
|
||||
|
||||
knowledge_type_raw = item.get("knowledge_type")
|
||||
if knowledge_type_raw is None:
|
||||
knowledge_type_raw = item.get("type")
|
||||
knowledge_type = resolve_stored_knowledge_type(knowledge_type_raw, content=content)
|
||||
source = str(item.get("source") or default_source or "").strip()
|
||||
if not source:
|
||||
source = str(default_source or "").strip()
|
||||
|
||||
normalized_time_meta = normalize_time_meta(raw_time_meta)
|
||||
return {
|
||||
"content": content,
|
||||
"knowledge_type": knowledge_type.value,
|
||||
"source": source,
|
||||
"time_meta": normalized_time_meta if normalized_time_meta else None,
|
||||
"entities": _normalize_entities(item.get("entities")),
|
||||
"relations": _normalize_relations(item.get("relations")),
|
||||
}
|
||||
|
||||
|
||||
def normalize_summary_knowledge_type(value: Any) -> KnowledgeType:
|
||||
"""Normalize config-driven summary knowledge type."""
|
||||
|
||||
return resolve_stored_knowledge_type(value, content="")
|
||||
84
src/A_memorix/core/utils/io.py
Normal file
84
src/A_memorix/core/utils/io.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""
|
||||
IO Utilities
|
||||
|
||||
提供原子文件写入等IO辅助功能。
|
||||
"""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import contextlib
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
@contextlib.contextmanager
|
||||
def atomic_write(file_path: Union[str, Path], mode: str = "w", encoding: str = None, **kwargs):
|
||||
"""
|
||||
原子文件写入上下文管理器
|
||||
|
||||
原理:
|
||||
1. 写入 .tmp 临时文件
|
||||
2. 写入成功后,使用 os.replace 原子替换目标文件
|
||||
3. 如果失败,自动删除临时文件
|
||||
|
||||
Args:
|
||||
file_path: 目标文件路径
|
||||
mode: 打开模式 ('w', 'wb' 等)
|
||||
encoding: 编码
|
||||
**kwargs: 传给 open() 的其他参数
|
||||
"""
|
||||
path = Path(file_path)
|
||||
# 确保父目录存在
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 临时文件路径
|
||||
tmp_path = path.with_suffix(path.suffix + ".tmp")
|
||||
|
||||
try:
|
||||
with open(tmp_path, mode, encoding=encoding, **kwargs) as f:
|
||||
yield f
|
||||
|
||||
# 确保写入磁盘
|
||||
f.flush()
|
||||
os.fsync(f.fileno())
|
||||
|
||||
# 原子替换 (Windows下可能需要先删除目标文件,但 os.replace 在 Py3.3+ 尽可能原子)
|
||||
# 注意: Windows 上如果有其他进程占用文件,os.replace 可能会失败
|
||||
os.replace(tmp_path, path)
|
||||
|
||||
except Exception as e:
|
||||
# 清理临时文件
|
||||
if tmp_path.exists():
|
||||
try:
|
||||
os.remove(tmp_path)
|
||||
except:
|
||||
pass
|
||||
raise e
|
||||
|
||||
@contextlib.contextmanager
|
||||
def atomic_save_path(file_path: Union[str, Path]):
|
||||
"""
|
||||
提供临时路径用于原子保存 (针对只接受路径的API,如Faiss)
|
||||
|
||||
Args:
|
||||
file_path: 最终目标路径
|
||||
|
||||
Yields:
|
||||
tmp_path: 临时文件路径 (str)
|
||||
"""
|
||||
path = Path(file_path)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
tmp_path = path.with_suffix(path.suffix + ".tmp")
|
||||
|
||||
try:
|
||||
yield str(tmp_path)
|
||||
|
||||
if Path(tmp_path).exists():
|
||||
os.replace(tmp_path, path)
|
||||
|
||||
except Exception as e:
|
||||
if Path(tmp_path).exists():
|
||||
try:
|
||||
os.remove(tmp_path)
|
||||
except:
|
||||
pass
|
||||
raise e
|
||||
89
src/A_memorix/core/utils/matcher.py
Normal file
89
src/A_memorix/core/utils/matcher.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""
|
||||
高效文本匹配工具模块
|
||||
|
||||
实现 Aho-Corasick 算法用于多模式匹配。
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Tuple, Set, Any
|
||||
from collections import deque
|
||||
|
||||
|
||||
class AhoCorasick:
|
||||
"""
|
||||
Aho-Corasick 自动机实现高效多模式匹配
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# next_states[state][char] = next_state
|
||||
self.next_states: List[Dict[str, int]] = [{}]
|
||||
# fail[state] = fail_state
|
||||
self.fail: List[int] = [0]
|
||||
# output[state] = set of patterns ending at this state
|
||||
self.output: List[Set[str]] = [set()]
|
||||
self.patterns: Set[str] = set()
|
||||
|
||||
def add_pattern(self, pattern: str):
|
||||
"""添加模式"""
|
||||
if not pattern:
|
||||
return
|
||||
self.patterns.add(pattern)
|
||||
state = 0
|
||||
for char in pattern:
|
||||
if char not in self.next_states[state]:
|
||||
new_state = len(self.next_states)
|
||||
self.next_states[state][char] = new_state
|
||||
self.next_states.append({})
|
||||
self.fail.append(0)
|
||||
self.output.append(set())
|
||||
state = self.next_states[state][char]
|
||||
self.output[state].add(pattern)
|
||||
|
||||
def build(self):
|
||||
"""构建失败指针"""
|
||||
queue = deque()
|
||||
# 处理第一层
|
||||
for char, state in self.next_states[0].items():
|
||||
queue.append(state)
|
||||
self.fail[state] = 0
|
||||
|
||||
while queue:
|
||||
r = queue.popleft()
|
||||
for char, s in self.next_states[r].items():
|
||||
queue.append(s)
|
||||
# 找到失败路径
|
||||
state = self.fail[r]
|
||||
while char not in self.next_states[state] and state != 0:
|
||||
state = self.fail[state]
|
||||
self.fail[s] = self.next_states[state].get(char, 0)
|
||||
# 合并输出
|
||||
self.output[s].update(self.output[self.fail[s]])
|
||||
|
||||
def search(self, text: str) -> List[Tuple[int, str]]:
|
||||
"""
|
||||
在文本中搜索所有模式
|
||||
|
||||
Returns:
|
||||
[(结束索引, 匹配到的模式), ...]
|
||||
"""
|
||||
state = 0
|
||||
results = []
|
||||
for i, char in enumerate(text):
|
||||
while char not in self.next_states[state] and state != 0:
|
||||
state = self.fail[state]
|
||||
state = self.next_states[state].get(char, 0)
|
||||
for pattern in self.output[state]:
|
||||
results.append((i, pattern))
|
||||
return results
|
||||
|
||||
def find_all(self, text: str) -> Dict[str, int]:
|
||||
"""
|
||||
查找并统计所有模式出现次数
|
||||
|
||||
Returns:
|
||||
{模式: 出现次数}
|
||||
"""
|
||||
results = self.search(text)
|
||||
stats = {}
|
||||
for _, pattern in results:
|
||||
stats[pattern] = stats.get(pattern, 0) + 1
|
||||
return stats
|
||||
189
src/A_memorix/core/utils/monitor.py
Normal file
189
src/A_memorix/core/utils/monitor.py
Normal file
@@ -0,0 +1,189 @@
|
||||
"""
|
||||
内存监控模块
|
||||
|
||||
提供内存使用监控和预警功能。
|
||||
"""
|
||||
|
||||
import gc
|
||||
import threading
|
||||
import time
|
||||
from typing import Callable, Optional
|
||||
|
||||
try:
|
||||
import psutil
|
||||
HAS_PSUTIL = True
|
||||
except ImportError:
|
||||
HAS_PSUTIL = False
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("A_Memorix.MemoryMonitor")
|
||||
|
||||
|
||||
class MemoryMonitor:
|
||||
"""
|
||||
内存监控器
|
||||
|
||||
功能:
|
||||
- 实时监控内存使用
|
||||
- 超过阈值时触发警告
|
||||
- 支持自动垃圾回收
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_memory_mb: int,
|
||||
warning_threshold: float = 0.9,
|
||||
check_interval: float = 10.0,
|
||||
enable_auto_gc: bool = True,
|
||||
):
|
||||
"""
|
||||
初始化内存监控器
|
||||
|
||||
Args:
|
||||
max_memory_mb: 最大内存限制(MB)
|
||||
warning_threshold: 警告阈值(0-1之间,默认0.9表示90%)
|
||||
check_interval: 检查间隔(秒)
|
||||
enable_auto_gc: 是否启用自动垃圾回收
|
||||
"""
|
||||
self.max_memory_mb = max_memory_mb
|
||||
self.warning_threshold = warning_threshold
|
||||
self.check_interval = check_interval
|
||||
self.enable_auto_gc = enable_auto_gc
|
||||
|
||||
self._running = False
|
||||
self._thread: Optional[threading.Thread] = None
|
||||
self._callbacks: list[Callable[[float, float], None]] = []
|
||||
|
||||
def start(self):
|
||||
"""启动监控"""
|
||||
if self._running:
|
||||
logger.warning("内存监控已在运行")
|
||||
return
|
||||
|
||||
if not HAS_PSUTIL:
|
||||
logger.warning("psutil 未安装,内存监控功能不可用")
|
||||
return
|
||||
|
||||
self._running = True
|
||||
self._thread = threading.Thread(target=self._monitor_loop, daemon=True)
|
||||
self._thread.start()
|
||||
logger.info(f"内存监控已启动 (限制: {self.max_memory_mb}MB)")
|
||||
|
||||
def stop(self):
|
||||
"""停止监控"""
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
self._running = False
|
||||
if self._thread:
|
||||
self._thread.join(timeout=5.0)
|
||||
logger.info("内存监控已停止")
|
||||
|
||||
def register_callback(self, callback: Callable[[float, float], None]):
|
||||
"""
|
||||
注册内存超限回调函数
|
||||
|
||||
Args:
|
||||
callback: 回调函数,接收 (当前使用MB, 限制MB) 参数
|
||||
"""
|
||||
self._callbacks.append(callback)
|
||||
|
||||
def get_current_memory_mb(self) -> float:
|
||||
"""
|
||||
获取当前进程内存使用量(MB)
|
||||
|
||||
Returns:
|
||||
内存使用量(MB)
|
||||
"""
|
||||
if not HAS_PSUTIL:
|
||||
# 降级方案:使用内置函数
|
||||
import sys
|
||||
return sys.getsizeof(gc.get_objects()) / 1024 / 1024
|
||||
|
||||
process = psutil.Process()
|
||||
return process.memory_info().rss / 1024 / 1024
|
||||
|
||||
def get_memory_usage_ratio(self) -> float:
|
||||
"""
|
||||
获取内存使用率
|
||||
|
||||
Returns:
|
||||
使用率(0-1之间)
|
||||
"""
|
||||
current = self.get_current_memory_mb()
|
||||
return current / self.max_memory_mb if self.max_memory_mb > 0 else 0
|
||||
|
||||
def _monitor_loop(self):
|
||||
"""监控循环"""
|
||||
while self._running:
|
||||
try:
|
||||
current_mb = self.get_current_memory_mb()
|
||||
ratio = current_mb / self.max_memory_mb if self.max_memory_mb > 0 else 0
|
||||
|
||||
# 检查是否超过阈值
|
||||
if ratio >= self.warning_threshold:
|
||||
logger.warning(
|
||||
f"内存使用率过高: {current_mb:.1f}MB / {self.max_memory_mb}MB "
|
||||
f"({ratio*100:.1f}%)"
|
||||
)
|
||||
|
||||
# 触发回调
|
||||
for callback in self._callbacks:
|
||||
try:
|
||||
callback(current_mb, self.max_memory_mb)
|
||||
except Exception as e:
|
||||
logger.error(f"内存回调执行失败: {e}")
|
||||
|
||||
# 自动垃圾回收
|
||||
if self.enable_auto_gc:
|
||||
before = self.get_current_memory_mb()
|
||||
gc.collect()
|
||||
after = self.get_current_memory_mb()
|
||||
freed = before - after
|
||||
if freed > 1: # 释放超过1MB才记录
|
||||
logger.info(f"垃圾回收释放: {freed:.1f}MB")
|
||||
|
||||
# 定期垃圾回收(即使未超限)
|
||||
elif self.enable_auto_gc and int(time.time()) % 60 == 0:
|
||||
gc.collect()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"内存监控出错: {e}")
|
||||
|
||||
# 等待下次检查
|
||||
time.sleep(self.check_interval)
|
||||
|
||||
def __enter__(self):
|
||||
"""上下文管理器入口"""
|
||||
self.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""上下文管理器出口"""
|
||||
self.stop()
|
||||
|
||||
|
||||
def get_memory_info() -> dict:
|
||||
"""
|
||||
获取系统内存信息
|
||||
|
||||
Returns:
|
||||
内存信息字典
|
||||
"""
|
||||
if not HAS_PSUTIL:
|
||||
return {"error": "psutil 未安装"}
|
||||
|
||||
try:
|
||||
mem = psutil.virtual_memory()
|
||||
process = psutil.Process()
|
||||
|
||||
return {
|
||||
"system_total_gb": mem.total / 1024 / 1024 / 1024,
|
||||
"system_available_gb": mem.available / 1024 / 1024 / 1024,
|
||||
"system_usage_percent": mem.percent,
|
||||
"process_mb": process.memory_info().rss / 1024 / 1024,
|
||||
"process_percent": (process.memory_info().rss / mem.total) * 100,
|
||||
}
|
||||
except Exception as e:
|
||||
return {"error": str(e)}
|
||||
165
src/A_memorix/core/utils/path_fallback_service.py
Normal file
165
src/A_memorix/core/utils/path_fallback_service.py
Normal file
@@ -0,0 +1,165 @@
|
||||
"""Shared path-fallback helpers for search post-processing."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
from ..retrieval.dual_path import RetrievalResult
|
||||
|
||||
|
||||
def extract_entities(query: str, graph_store: Any) -> List[str]:
|
||||
"""Extract up to two graph nodes from a query using n-gram matching."""
|
||||
if not graph_store:
|
||||
return []
|
||||
|
||||
text = str(query or "").strip()
|
||||
if not text:
|
||||
return []
|
||||
|
||||
# Keep the heuristic aligned with previous legacy behavior.
|
||||
tokens = (
|
||||
text.replace("?", " ")
|
||||
.replace("!", " ")
|
||||
.replace(".", " ")
|
||||
.split()
|
||||
)
|
||||
if not tokens:
|
||||
return []
|
||||
|
||||
found_entities = set()
|
||||
skip_indices = set()
|
||||
max_n = min(4, len(tokens))
|
||||
|
||||
for size in range(max_n, 0, -1):
|
||||
for i in range(len(tokens) - size + 1):
|
||||
if any(idx in skip_indices for idx in range(i, i + size)):
|
||||
continue
|
||||
span = " ".join(tokens[i : i + size])
|
||||
matched_node = graph_store.find_node(span, ignore_case=True)
|
||||
if not matched_node:
|
||||
continue
|
||||
found_entities.add(matched_node)
|
||||
for idx in range(i, i + size):
|
||||
skip_indices.add(idx)
|
||||
|
||||
return list(found_entities)
|
||||
|
||||
|
||||
def find_paths_between_entities(
|
||||
start_node: str,
|
||||
end_node: str,
|
||||
graph_store: Any,
|
||||
metadata_store: Any,
|
||||
*,
|
||||
max_depth: int = 3,
|
||||
max_paths: int = 5,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Find and enrich indirect paths between two nodes."""
|
||||
if not graph_store or not metadata_store:
|
||||
return []
|
||||
|
||||
try:
|
||||
paths = graph_store.find_paths(
|
||||
start_node,
|
||||
end_node,
|
||||
max_depth=max_depth,
|
||||
max_paths=max_paths,
|
||||
)
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
if not paths:
|
||||
return []
|
||||
|
||||
edge_cache: Dict[Tuple[str, str], Tuple[str, str]] = {}
|
||||
formatted_paths: List[Dict[str, Any]] = []
|
||||
|
||||
for path_nodes in paths:
|
||||
if not isinstance(path_nodes, Sequence) or len(path_nodes) < 2:
|
||||
continue
|
||||
|
||||
path_desc: List[str] = []
|
||||
for i in range(len(path_nodes) - 1):
|
||||
u = str(path_nodes[i])
|
||||
v = str(path_nodes[i + 1])
|
||||
|
||||
cache_key = tuple(sorted((u, v)))
|
||||
if cache_key in edge_cache:
|
||||
pred, direction = edge_cache[cache_key]
|
||||
else:
|
||||
pred = "related"
|
||||
direction = "->"
|
||||
rels = metadata_store.get_relations(subject=u, object=v)
|
||||
if not rels:
|
||||
rels = metadata_store.get_relations(subject=v, object=u)
|
||||
direction = "<-"
|
||||
if rels:
|
||||
best_rel = max(rels, key=lambda x: x.get("confidence", 1.0))
|
||||
pred = str(best_rel.get("predicate", "related") or "related")
|
||||
edge_cache[cache_key] = (pred, direction)
|
||||
|
||||
step_str = f"-[{pred}]->" if direction == "->" else f"<-[{pred}]-"
|
||||
path_desc.append(step_str)
|
||||
|
||||
full_path_str = str(path_nodes[0])
|
||||
for i, step in enumerate(path_desc):
|
||||
full_path_str += f" {step} {path_nodes[i + 1]}"
|
||||
|
||||
formatted_paths.append(
|
||||
{
|
||||
"nodes": list(path_nodes),
|
||||
"description": full_path_str,
|
||||
}
|
||||
)
|
||||
|
||||
return formatted_paths
|
||||
|
||||
|
||||
def find_paths_from_query(
|
||||
query: str,
|
||||
graph_store: Any,
|
||||
metadata_store: Any,
|
||||
*,
|
||||
max_depth: int = 3,
|
||||
max_paths: int = 5,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Extract entities from query and resolve indirect paths."""
|
||||
entities = extract_entities(query, graph_store)
|
||||
if len(entities) != 2:
|
||||
return []
|
||||
return find_paths_between_entities(
|
||||
entities[0],
|
||||
entities[1],
|
||||
graph_store,
|
||||
metadata_store,
|
||||
max_depth=max_depth,
|
||||
max_paths=max_paths,
|
||||
)
|
||||
|
||||
|
||||
def to_retrieval_results(paths: Sequence[Dict[str, Any]]) -> List[RetrievalResult]:
|
||||
"""Convert path results into retrieval results for the unified pipeline."""
|
||||
converted: List[RetrievalResult] = []
|
||||
for item in paths:
|
||||
description = str(item.get("description", "")).strip()
|
||||
if not description:
|
||||
continue
|
||||
hash_seed = description.encode("utf-8")
|
||||
path_hash = f"path_{hashlib.sha1(hash_seed).hexdigest()}"
|
||||
converted.append(
|
||||
RetrievalResult(
|
||||
hash_value=path_hash,
|
||||
content=f"[Indirect Relation] {description}",
|
||||
score=0.95,
|
||||
result_type="relation",
|
||||
source="graph_path",
|
||||
metadata={
|
||||
"source": "graph_path",
|
||||
"is_indirect": True,
|
||||
"nodes": list(item.get("nodes", [])),
|
||||
},
|
||||
)
|
||||
)
|
||||
return converted
|
||||
|
||||
599
src/A_memorix/core/utils/person_profile_service.py
Normal file
599
src/A_memorix/core/utils/person_profile_service.py
Normal file
@@ -0,0 +1,599 @@
|
||||
"""
|
||||
人物画像服务
|
||||
|
||||
主链路:
|
||||
person_id -> 用户名/别名 -> 图谱关系 + 向量证据 -> 证据总结画像 -> 快照版本化存储
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from sqlalchemy import or_
|
||||
from sqlmodel import select
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import PersonInfo
|
||||
|
||||
from ..embedding import EmbeddingAPIAdapter
|
||||
from ..retrieval import (
|
||||
DualPathRetriever,
|
||||
RetrievalStrategy,
|
||||
DualPathRetrieverConfig,
|
||||
SparseBM25Config,
|
||||
FusionConfig,
|
||||
GraphRelationRecallConfig,
|
||||
)
|
||||
from ..storage import MetadataStore, GraphStore, VectorStore
|
||||
|
||||
logger = get_logger("A_Memorix.PersonProfileService")
|
||||
|
||||
|
||||
class PersonProfileService:
|
||||
"""人物画像聚合/刷新服务。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
metadata_store: MetadataStore,
|
||||
graph_store: Optional[GraphStore] = None,
|
||||
vector_store: Optional[VectorStore] = None,
|
||||
embedding_manager: Optional[EmbeddingAPIAdapter] = None,
|
||||
sparse_index: Any = None,
|
||||
plugin_config: Optional[dict] = None,
|
||||
retriever: Optional[DualPathRetriever] = None,
|
||||
):
|
||||
self.metadata_store = metadata_store
|
||||
self.graph_store = graph_store
|
||||
self.vector_store = vector_store
|
||||
self.embedding_manager = embedding_manager
|
||||
self.sparse_index = sparse_index
|
||||
self.plugin_config = plugin_config or {}
|
||||
self.retriever = retriever or self._build_retriever()
|
||||
|
||||
def _cfg(self, key: str, default: Any = None) -> Any:
|
||||
"""读取嵌套配置。"""
|
||||
if not isinstance(self.plugin_config, dict):
|
||||
return default
|
||||
current: Any = self.plugin_config
|
||||
for part in key.split("."):
|
||||
if isinstance(current, dict) and part in current:
|
||||
current = current[part]
|
||||
else:
|
||||
return default
|
||||
return current
|
||||
|
||||
def _build_retriever(self) -> Optional[DualPathRetriever]:
|
||||
"""按需构建检索器(无依赖时返回 None)。"""
|
||||
if not all(
|
||||
[
|
||||
self.vector_store is not None,
|
||||
self.graph_store is not None,
|
||||
self.metadata_store is not None,
|
||||
self.embedding_manager is not None,
|
||||
]
|
||||
):
|
||||
return None
|
||||
try:
|
||||
sparse_cfg_raw = self._cfg("retrieval.sparse", {}) or {}
|
||||
fusion_cfg_raw = self._cfg("retrieval.fusion", {}) or {}
|
||||
graph_recall_cfg_raw = self._cfg("retrieval.search.graph_recall", {}) or {}
|
||||
if not isinstance(sparse_cfg_raw, dict):
|
||||
sparse_cfg_raw = {}
|
||||
if not isinstance(fusion_cfg_raw, dict):
|
||||
fusion_cfg_raw = {}
|
||||
if not isinstance(graph_recall_cfg_raw, dict):
|
||||
graph_recall_cfg_raw = {}
|
||||
|
||||
sparse_cfg = SparseBM25Config(**sparse_cfg_raw)
|
||||
fusion_cfg = FusionConfig(**fusion_cfg_raw)
|
||||
graph_recall_cfg = GraphRelationRecallConfig(**graph_recall_cfg_raw)
|
||||
config = DualPathRetrieverConfig(
|
||||
top_k_paragraphs=int(self._cfg("retrieval.top_k_paragraphs", 20)),
|
||||
top_k_relations=int(self._cfg("retrieval.top_k_relations", 10)),
|
||||
top_k_final=int(self._cfg("retrieval.top_k_final", 10)),
|
||||
alpha=float(self._cfg("retrieval.alpha", 0.5)),
|
||||
enable_ppr=bool(self._cfg("retrieval.enable_ppr", True)),
|
||||
ppr_alpha=float(self._cfg("retrieval.ppr_alpha", 0.85)),
|
||||
ppr_concurrency_limit=int(self._cfg("retrieval.ppr_concurrency_limit", 4)),
|
||||
enable_parallel=bool(self._cfg("retrieval.enable_parallel", True)),
|
||||
retrieval_strategy=RetrievalStrategy.DUAL_PATH,
|
||||
debug=bool(self._cfg("advanced.debug", False)),
|
||||
sparse=sparse_cfg,
|
||||
fusion=fusion_cfg,
|
||||
graph_recall=graph_recall_cfg,
|
||||
)
|
||||
return DualPathRetriever(
|
||||
vector_store=self.vector_store,
|
||||
graph_store=self.graph_store,
|
||||
metadata_store=self.metadata_store,
|
||||
embedding_manager=self.embedding_manager,
|
||||
sparse_index=self.sparse_index,
|
||||
config=config,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"初始化人物画像检索器失败,将只使用关系证据: {e}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def resolve_person_id(identifier: str) -> str:
|
||||
"""按 person_id 或姓名/别名解析 person_id。"""
|
||||
if not identifier:
|
||||
return ""
|
||||
key = str(identifier).strip()
|
||||
if not key:
|
||||
return ""
|
||||
|
||||
try:
|
||||
with get_db_session(auto_commit=False) as session:
|
||||
record = session.exec(
|
||||
select(PersonInfo.person_id).where(PersonInfo.person_id == key).limit(1)
|
||||
).first()
|
||||
if record:
|
||||
return str(record)
|
||||
|
||||
record = session.exec(
|
||||
select(PersonInfo.person_id)
|
||||
.where(
|
||||
or_(
|
||||
PersonInfo.person_name == key,
|
||||
PersonInfo.user_nickname == key,
|
||||
)
|
||||
)
|
||||
.limit(1)
|
||||
).first()
|
||||
if record:
|
||||
return str(record)
|
||||
|
||||
record = session.exec(
|
||||
select(PersonInfo.person_id)
|
||||
.where(PersonInfo.group_cardname.contains(key))
|
||||
.limit(1)
|
||||
).first()
|
||||
if record:
|
||||
return str(record)
|
||||
except Exception as e:
|
||||
logger.warning(f"按别名解析 person_id 失败: identifier={key}, err={e}")
|
||||
|
||||
if len(key) == 32 and all(ch in "0123456789abcdefABCDEF" for ch in key):
|
||||
return key.lower()
|
||||
|
||||
return ""
|
||||
|
||||
def _parse_group_nicks(self, raw_value: Any) -> List[str]:
|
||||
if not raw_value:
|
||||
return []
|
||||
if isinstance(raw_value, list):
|
||||
items = raw_value
|
||||
else:
|
||||
try:
|
||||
items = json.loads(raw_value)
|
||||
except Exception:
|
||||
return []
|
||||
names: List[str] = []
|
||||
for item in items:
|
||||
if isinstance(item, dict):
|
||||
value = str(item.get("group_cardname") or item.get("group_nick_name") or "").strip()
|
||||
if value:
|
||||
names.append(value)
|
||||
elif isinstance(item, str):
|
||||
value = item.strip()
|
||||
if value:
|
||||
names.append(value)
|
||||
return names
|
||||
|
||||
def _parse_memory_traits(self, raw_value: Any) -> List[str]:
|
||||
if not raw_value:
|
||||
return []
|
||||
try:
|
||||
values = json.loads(raw_value) if isinstance(raw_value, str) else raw_value
|
||||
except Exception:
|
||||
return []
|
||||
if not isinstance(values, list):
|
||||
return []
|
||||
traits: List[str] = []
|
||||
for item in values:
|
||||
text = str(item).strip()
|
||||
if not text:
|
||||
continue
|
||||
if ":" in text:
|
||||
parts = text.split(":")
|
||||
if len(parts) >= 3:
|
||||
content = ":".join(parts[1:-1]).strip()
|
||||
if content:
|
||||
traits.append(content)
|
||||
continue
|
||||
traits.append(text)
|
||||
return traits[:10]
|
||||
|
||||
def _recover_aliases_from_memory(self, person_id: str) -> Tuple[List[str], str]:
|
||||
"""当人物主档案缺失时,从已有记忆证据里回捞可用别名。"""
|
||||
if not person_id:
|
||||
return [], ""
|
||||
|
||||
aliases: List[str] = []
|
||||
primary_name = ""
|
||||
seen = set()
|
||||
|
||||
try:
|
||||
paragraphs = self.metadata_store.get_paragraphs_by_entity(person_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"从记忆证据回捞人物别名失败: person_id={person_id}, err={e}")
|
||||
return [], ""
|
||||
|
||||
for paragraph in paragraphs[:20]:
|
||||
paragraph_hash = str(paragraph.get("hash", "") or "").strip()
|
||||
if not paragraph_hash:
|
||||
continue
|
||||
try:
|
||||
paragraph_entities = self.metadata_store.get_paragraph_entities(paragraph_hash)
|
||||
except Exception:
|
||||
paragraph_entities = []
|
||||
for entity in paragraph_entities:
|
||||
name = str(entity.get("name", "") or "").strip()
|
||||
if not name or name == person_id:
|
||||
continue
|
||||
key = name.lower()
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
aliases.append(name)
|
||||
if not primary_name:
|
||||
primary_name = name
|
||||
return aliases, primary_name
|
||||
|
||||
def get_person_aliases(self, person_id: str) -> Tuple[List[str], str, List[str]]:
|
||||
"""获取人物别名集合、主展示名、记忆特征。"""
|
||||
aliases: List[str] = []
|
||||
primary_name = ""
|
||||
memory_traits: List[str] = []
|
||||
if not person_id:
|
||||
return aliases, primary_name, memory_traits
|
||||
recovered_aliases, recovered_primary_name = self._recover_aliases_from_memory(person_id)
|
||||
try:
|
||||
with get_db_session(auto_commit=False) as session:
|
||||
record = session.exec(
|
||||
select(PersonInfo).where(PersonInfo.person_id == person_id).limit(1)
|
||||
).first()
|
||||
if not record:
|
||||
return recovered_aliases, recovered_primary_name or person_id, memory_traits
|
||||
person_name = str(getattr(record, "person_name", "") or "").strip()
|
||||
nickname = str(getattr(record, "user_nickname", "") or "").strip()
|
||||
group_nicks = self._parse_group_nicks(getattr(record, "group_cardname", None))
|
||||
memory_traits = self._parse_memory_traits(getattr(record, "memory_points", None))
|
||||
|
||||
primary_name = (
|
||||
person_name
|
||||
or nickname
|
||||
or recovered_primary_name
|
||||
or str(getattr(record, "user_id", "") or "").strip()
|
||||
or person_id
|
||||
)
|
||||
|
||||
candidates = [person_name, nickname] + group_nicks + recovered_aliases
|
||||
seen = set()
|
||||
for item in candidates:
|
||||
norm = str(item or "").strip()
|
||||
if not norm or norm in seen:
|
||||
continue
|
||||
seen.add(norm)
|
||||
aliases.append(norm)
|
||||
except Exception as e:
|
||||
logger.warning(f"解析人物别名失败: person_id={person_id}, err={e}")
|
||||
return aliases, primary_name, memory_traits
|
||||
|
||||
def _collect_relation_evidence(self, aliases: List[str], limit: int = 30) -> List[Dict[str, Any]]:
|
||||
relation_by_hash: Dict[str, Dict[str, Any]] = {}
|
||||
for alias in aliases:
|
||||
for rel in self.metadata_store.get_relations(subject=alias):
|
||||
h = str(rel.get("hash", ""))
|
||||
if h:
|
||||
relation_by_hash[h] = rel
|
||||
for rel in self.metadata_store.get_relations(object=alias):
|
||||
h = str(rel.get("hash", ""))
|
||||
if h:
|
||||
relation_by_hash[h] = rel
|
||||
|
||||
relations = list(relation_by_hash.values())
|
||||
relations.sort(key=lambda item: float(item.get("confidence", 0.0)), reverse=True)
|
||||
relations = relations[: max(1, int(limit))]
|
||||
|
||||
edges: List[Dict[str, Any]] = []
|
||||
for rel in relations:
|
||||
edges.append(
|
||||
{
|
||||
"hash": str(rel.get("hash", "")),
|
||||
"subject": str(rel.get("subject", "")),
|
||||
"predicate": str(rel.get("predicate", "")),
|
||||
"object": str(rel.get("object", "")),
|
||||
"confidence": float(rel.get("confidence", 1.0) or 1.0),
|
||||
}
|
||||
)
|
||||
return edges
|
||||
|
||||
def _collect_person_fact_evidence(self, person_id: str, limit: int = 4) -> List[Dict[str, Any]]:
|
||||
token = str(person_id or "").strip()
|
||||
if not token:
|
||||
return []
|
||||
|
||||
source = f"person_fact:{token}"
|
||||
paragraphs = [
|
||||
row
|
||||
for row in self.metadata_store.get_paragraphs_by_source(source)
|
||||
if not bool(row.get("is_deleted", 0))
|
||||
]
|
||||
paragraphs.sort(
|
||||
key=lambda item: float(item.get("updated_at") or item.get("created_at") or 0.0),
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
evidence: List[Dict[str, Any]] = []
|
||||
for row in paragraphs[: max(1, int(limit))]:
|
||||
paragraph_hash = str(row.get("hash", "") or "")
|
||||
content = str(row.get("content", "") or "").strip()
|
||||
if not paragraph_hash or not content:
|
||||
continue
|
||||
evidence.append(
|
||||
{
|
||||
"hash": paragraph_hash,
|
||||
"type": "paragraph",
|
||||
"score": 1.1,
|
||||
"content": content[:220],
|
||||
"metadata": {},
|
||||
}
|
||||
)
|
||||
return evidence
|
||||
|
||||
async def _collect_vector_evidence(
|
||||
self,
|
||||
aliases: List[str],
|
||||
top_k: int = 12,
|
||||
person_id: str = "",
|
||||
) -> List[Dict[str, Any]]:
|
||||
alias_queries = [a for a in aliases if a]
|
||||
if not alias_queries and not person_id:
|
||||
return []
|
||||
|
||||
if self.retriever is None:
|
||||
# 回退:无检索器时只做简单内容匹配
|
||||
fallback: List[Dict[str, Any]] = []
|
||||
seen_hash = set()
|
||||
for alias in alias_queries:
|
||||
for para in self.metadata_store.search_paragraphs_by_content(alias)[: max(2, top_k // 2)]:
|
||||
h = str(para.get("hash", ""))
|
||||
if not h or h in seen_hash:
|
||||
continue
|
||||
seen_hash.add(h)
|
||||
fallback.append(
|
||||
{
|
||||
"hash": h,
|
||||
"type": "paragraph",
|
||||
"score": 0.0,
|
||||
"content": str(para.get("content", ""))[:180],
|
||||
"metadata": {},
|
||||
}
|
||||
)
|
||||
return fallback[:top_k]
|
||||
|
||||
per_alias_top_k = max(2, int(top_k / max(1, len(alias_queries))))
|
||||
seen_hash = set()
|
||||
evidence: List[Dict[str, Any]] = []
|
||||
for item in self._collect_person_fact_evidence(person_id, limit=max(2, min(4, top_k))):
|
||||
h = str(item.get("hash", "") or "")
|
||||
if not h or h in seen_hash:
|
||||
continue
|
||||
seen_hash.add(h)
|
||||
evidence.append(item)
|
||||
|
||||
for alias in alias_queries:
|
||||
try:
|
||||
results = await self.retriever.retrieve(alias, top_k=per_alias_top_k)
|
||||
except Exception as e:
|
||||
logger.warning(f"向量证据召回失败: alias={alias}, err={e}")
|
||||
continue
|
||||
for item in results:
|
||||
h = str(getattr(item, "hash_value", "") or "")
|
||||
if not h or h in seen_hash:
|
||||
continue
|
||||
seen_hash.add(h)
|
||||
evidence.append(
|
||||
{
|
||||
"hash": h,
|
||||
"type": str(getattr(item, "result_type", "")),
|
||||
"score": float(getattr(item, "score", 0.0) or 0.0),
|
||||
"content": str(getattr(item, "content", "") or "")[:220],
|
||||
"metadata": dict(getattr(item, "metadata", {}) or {}),
|
||||
}
|
||||
)
|
||||
evidence.sort(key=lambda x: x.get("score", 0.0), reverse=True)
|
||||
return evidence[:top_k]
|
||||
|
||||
def _build_profile_text(
|
||||
self,
|
||||
person_id: str,
|
||||
primary_name: str,
|
||||
aliases: List[str],
|
||||
relation_edges: List[Dict[str, Any]],
|
||||
vector_evidence: List[Dict[str, Any]],
|
||||
memory_traits: List[str],
|
||||
) -> str:
|
||||
"""基于证据构建画像文本(供 LLM 上下文注入)。"""
|
||||
lines: List[str] = []
|
||||
lines.append(f"人物ID: {person_id}")
|
||||
if primary_name:
|
||||
lines.append(f"主称呼: {primary_name}")
|
||||
if aliases:
|
||||
lines.append(f"别名: {', '.join(aliases[:8])}")
|
||||
if memory_traits:
|
||||
lines.append(f"记忆特征: {'; '.join(memory_traits[:6])}")
|
||||
|
||||
if relation_edges:
|
||||
lines.append("关系证据:")
|
||||
for rel in relation_edges[:6]:
|
||||
s = rel.get("subject", "")
|
||||
p = rel.get("predicate", "")
|
||||
o = rel.get("object", "")
|
||||
conf = float(rel.get("confidence", 0.0))
|
||||
lines.append(f"- {s} {p} {o} (conf={conf:.2f})")
|
||||
|
||||
if vector_evidence:
|
||||
lines.append("向量证据摘要:")
|
||||
for item in vector_evidence[:4]:
|
||||
content = str(item.get("content", "")).strip()
|
||||
if content:
|
||||
lines.append(f"- {content}")
|
||||
|
||||
if len(lines) <= 2:
|
||||
lines.append("暂无足够证据形成稳定画像。")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
@staticmethod
|
||||
def _is_snapshot_stale(snapshot: Optional[Dict[str, Any]], ttl_seconds: float) -> bool:
|
||||
if not snapshot:
|
||||
return True
|
||||
now = time.time()
|
||||
expires_at = snapshot.get("expires_at")
|
||||
if expires_at is not None:
|
||||
try:
|
||||
return now >= float(expires_at)
|
||||
except Exception:
|
||||
return True
|
||||
updated_at = float(snapshot.get("updated_at") or 0.0)
|
||||
return (now - updated_at) >= ttl_seconds
|
||||
|
||||
def _apply_manual_override(self, person_id: str, profile_payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""将手工覆盖并入画像结果(覆盖 profile_text,同时保留 auto_profile_text)。"""
|
||||
payload = dict(profile_payload or {})
|
||||
auto_text = str(payload.get("profile_text", "") or "")
|
||||
payload["auto_profile_text"] = auto_text
|
||||
payload["has_manual_override"] = False
|
||||
payload["manual_override_text"] = ""
|
||||
payload["override_updated_at"] = None
|
||||
payload["override_updated_by"] = ""
|
||||
payload["profile_source"] = "auto_snapshot"
|
||||
|
||||
if not person_id or self.metadata_store is None:
|
||||
return payload
|
||||
|
||||
try:
|
||||
override = self.metadata_store.get_person_profile_override(person_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"读取人物画像手工覆盖失败: person_id={person_id}, err={e}")
|
||||
return payload
|
||||
|
||||
if not override:
|
||||
return payload
|
||||
|
||||
manual_text = str(override.get("override_text", "") or "").strip()
|
||||
if not manual_text:
|
||||
return payload
|
||||
|
||||
payload["has_manual_override"] = True
|
||||
payload["manual_override_text"] = manual_text
|
||||
payload["override_updated_at"] = override.get("updated_at")
|
||||
payload["override_updated_by"] = str(override.get("updated_by", "") or "")
|
||||
payload["profile_text"] = manual_text
|
||||
payload["profile_source"] = "manual_override"
|
||||
return payload
|
||||
|
||||
async def query_person_profile(
|
||||
self,
|
||||
person_id: str = "",
|
||||
person_keyword: str = "",
|
||||
top_k: int = 12,
|
||||
ttl_seconds: float = 6 * 3600,
|
||||
force_refresh: bool = False,
|
||||
source_note: str = "",
|
||||
) -> Dict[str, Any]:
|
||||
"""查询或刷新人物画像。"""
|
||||
pid = str(person_id or "").strip()
|
||||
if not pid and person_keyword:
|
||||
pid = self.resolve_person_id(person_keyword)
|
||||
|
||||
if not pid:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "person_id 无效,且未能通过别名解析",
|
||||
}
|
||||
|
||||
latest = self.metadata_store.get_latest_person_profile_snapshot(pid)
|
||||
if not force_refresh and not self._is_snapshot_stale(latest, ttl_seconds):
|
||||
aliases, primary_name, _ = self.get_person_aliases(pid)
|
||||
payload = {
|
||||
"success": True,
|
||||
"person_id": pid,
|
||||
"person_name": primary_name,
|
||||
"from_cache": True,
|
||||
**(latest or {}),
|
||||
}
|
||||
if aliases and not payload.get("aliases"):
|
||||
payload["aliases"] = aliases
|
||||
return {
|
||||
**self._apply_manual_override(pid, payload),
|
||||
}
|
||||
|
||||
aliases, primary_name, memory_traits = self.get_person_aliases(pid)
|
||||
if not aliases and person_keyword:
|
||||
aliases = [person_keyword.strip()]
|
||||
primary_name = person_keyword.strip()
|
||||
relation_edges = self._collect_relation_evidence(aliases, limit=max(10, top_k * 2))
|
||||
vector_evidence = await self._collect_vector_evidence(aliases, top_k=max(4, top_k), person_id=pid)
|
||||
|
||||
evidence_ids = [
|
||||
str(item.get("hash", ""))
|
||||
for item in (relation_edges + vector_evidence)
|
||||
if str(item.get("hash", "")).strip()
|
||||
]
|
||||
dedup_ids: List[str] = []
|
||||
seen = set()
|
||||
for item in evidence_ids:
|
||||
if item in seen:
|
||||
continue
|
||||
seen.add(item)
|
||||
dedup_ids.append(item)
|
||||
|
||||
profile_text = self._build_profile_text(
|
||||
person_id=pid,
|
||||
primary_name=primary_name,
|
||||
aliases=aliases,
|
||||
relation_edges=relation_edges,
|
||||
vector_evidence=vector_evidence,
|
||||
memory_traits=memory_traits,
|
||||
)
|
||||
|
||||
expires_at = time.time() + float(ttl_seconds) if ttl_seconds > 0 else None
|
||||
snapshot = self.metadata_store.upsert_person_profile_snapshot(
|
||||
person_id=pid,
|
||||
profile_text=profile_text,
|
||||
aliases=aliases,
|
||||
relation_edges=relation_edges,
|
||||
vector_evidence=vector_evidence,
|
||||
evidence_ids=dedup_ids,
|
||||
expires_at=expires_at,
|
||||
source_note=source_note,
|
||||
)
|
||||
payload = {
|
||||
"success": True,
|
||||
"person_id": pid,
|
||||
"person_name": primary_name,
|
||||
"from_cache": False,
|
||||
**snapshot,
|
||||
}
|
||||
return {
|
||||
**self._apply_manual_override(pid, payload),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def format_persona_profile_block(profile: Dict[str, Any]) -> str:
|
||||
"""格式化给 replyer 的注入块。"""
|
||||
if not profile or not profile.get("success"):
|
||||
return ""
|
||||
text = str(profile.get("profile_text", "") or "").strip()
|
||||
if not text:
|
||||
return ""
|
||||
return (
|
||||
"【人物画像-内部参考】\n"
|
||||
f"{text}\n"
|
||||
"仅供内部推理,不要向用户逐字复述。"
|
||||
)
|
||||
27
src/A_memorix/core/utils/plugin_id_policy.py
Normal file
27
src/A_memorix/core/utils/plugin_id_policy.py
Normal file
@@ -0,0 +1,27 @@
|
||||
"""Plugin ID matching policy for A_Memorix."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
class PluginIdPolicy:
|
||||
"""Centralized plugin id normalization/matching policy."""
|
||||
|
||||
CANONICAL_ID = "a_memorix"
|
||||
|
||||
@classmethod
|
||||
def normalize(cls, plugin_id: Any) -> str:
|
||||
if not isinstance(plugin_id, str):
|
||||
return ""
|
||||
return plugin_id.strip().lower()
|
||||
|
||||
@classmethod
|
||||
def is_target_plugin_id(cls, plugin_id: Any) -> bool:
|
||||
normalized = cls.normalize(plugin_id)
|
||||
if not normalized:
|
||||
return False
|
||||
if normalized == cls.CANONICAL_ID:
|
||||
return True
|
||||
return normalized.split(".")[-1] == cls.CANONICAL_ID
|
||||
|
||||
344
src/A_memorix/core/utils/quantization.py
Normal file
344
src/A_memorix/core/utils/quantization.py
Normal file
@@ -0,0 +1,344 @@
|
||||
"""
|
||||
向量量化工具模块
|
||||
|
||||
提供向量量化与反量化功能,用于压缩存储空间。
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from enum import Enum
|
||||
from typing import Tuple, Union
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("A_Memorix.Quantization")
|
||||
|
||||
|
||||
class QuantizationType(Enum):
|
||||
"""量化类型枚举"""
|
||||
FLOAT32 = "float32" # 无量化
|
||||
INT8 = "int8" # 标量量化(8位整数)
|
||||
PQ = "pq" # 乘积量化(Product Quantization)
|
||||
|
||||
|
||||
def quantize_vector(
|
||||
vector: np.ndarray,
|
||||
quant_type: QuantizationType = QuantizationType.INT8,
|
||||
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
|
||||
"""
|
||||
量化向量
|
||||
|
||||
Args:
|
||||
vector: 输入向量(float32)
|
||||
quant_type: 量化类型
|
||||
|
||||
Returns:
|
||||
量化后的向量:
|
||||
- INT8: int8向量
|
||||
- PQ: (编码向量, 聚类中心) 元组
|
||||
"""
|
||||
if quant_type == QuantizationType.FLOAT32:
|
||||
return vector.astype(np.float32)
|
||||
|
||||
elif quant_type == QuantizationType.INT8:
|
||||
return _scalar_quantize_int8(vector)
|
||||
|
||||
elif quant_type == QuantizationType.PQ:
|
||||
return _product_quantize(vector)
|
||||
|
||||
else:
|
||||
raise ValueError(f"不支持的量化类型: {quant_type}")
|
||||
|
||||
|
||||
def dequantize_vector(
|
||||
quantized_vector: Union[np.ndarray, Tuple[np.ndarray, np.ndarray]],
|
||||
quant_type: QuantizationType = QuantizationType.INT8,
|
||||
original_shape: Tuple[int, ...] = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
反量化向量
|
||||
|
||||
Args:
|
||||
quantized_vector: 量化后的向量
|
||||
quant_type: 量化类型
|
||||
original_shape: 原始向量形状(用于PQ)
|
||||
|
||||
Returns:
|
||||
反量化后的向量(float32)
|
||||
"""
|
||||
if quant_type == QuantizationType.FLOAT32:
|
||||
return quantized_vector.astype(np.float32)
|
||||
|
||||
elif quant_type == QuantizationType.INT8:
|
||||
return _scalar_dequantize_int8(quantized_vector)
|
||||
|
||||
elif quant_type == QuantizationType.PQ:
|
||||
if not isinstance(quantized_vector, tuple):
|
||||
raise ValueError("PQ反量化需要列表/元组格式: (codes, centroids)")
|
||||
return _product_dequantize(quantized_vector[0], quantized_vector[1])
|
||||
|
||||
else:
|
||||
raise ValueError(f"不支持的量化类型: {quant_type}")
|
||||
|
||||
|
||||
def _scalar_quantize_int8(vector: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
标量量化:float32 -> int8
|
||||
|
||||
将向量归一化到 [0, 255] 范围,然后映射到 int8
|
||||
|
||||
Args:
|
||||
vector: 输入向量
|
||||
|
||||
Returns:
|
||||
量化后的 int8 向量
|
||||
"""
|
||||
# 计算最小最大值
|
||||
min_val = np.min(vector)
|
||||
max_val = np.max(vector)
|
||||
|
||||
# 避免除零
|
||||
if max_val == min_val:
|
||||
return np.zeros_like(vector, dtype=np.int8)
|
||||
|
||||
# 归一化到 [0, 255]
|
||||
normalized = (vector - min_val) / (max_val - min_val) * 255
|
||||
|
||||
# 映射到 [-128, 127] 并转换为 int8
|
||||
# np.round might return float, minus 128 then cast
|
||||
quantized = np.round(normalized - 128.0).astype(np.int8)
|
||||
|
||||
# 存储归一化参数(用于反量化)
|
||||
# 在实际存储中,这些参数需要单独保存
|
||||
# 这里为了简单,我们使用一个全局字典来模拟
|
||||
if not hasattr(_scalar_quantize_int8, "_params"):
|
||||
_scalar_quantize_int8._params = {}
|
||||
|
||||
vector_id = id(vector)
|
||||
_scalar_quantize_int8._params[vector_id] = (min_val, max_val)
|
||||
|
||||
return quantized
|
||||
|
||||
|
||||
def _scalar_dequantize_int8(quantized: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
标量反量化:int8 -> float32
|
||||
|
||||
Args:
|
||||
quantized: 量化后的 int8 向量
|
||||
|
||||
Returns:
|
||||
反量化后的 float32 向量
|
||||
"""
|
||||
# 计算归一化参数(如果提供了)
|
||||
# 在实际应用中,min_val 和 max_val 应该被保存
|
||||
if not hasattr(_scalar_dequantize_int8, "_params"):
|
||||
# 默认假设范围是 [-1, 1]
|
||||
return (quantized.astype(np.float32) + 128.0) / 255.0 * 2.0 - 1.0
|
||||
|
||||
# 尝试查找参数 (这里只是演示逻辑,实际应从存储中读取)
|
||||
# return (quantized.astype(np.float32) + 128.0) / 255.0 * (max - min) + min
|
||||
return (quantized.astype(np.float32) + 128.0) / 255.0
|
||||
|
||||
|
||||
def quantize_matrix(
|
||||
matrix: np.ndarray,
|
||||
quant_type: QuantizationType = QuantizationType.INT8,
|
||||
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
|
||||
"""
|
||||
量化矩阵(批量量化向量)
|
||||
|
||||
Args:
|
||||
matrix: 输入矩阵(N x D,每行是一个向量)
|
||||
quant_type: 量化类型
|
||||
|
||||
Returns:
|
||||
量化后的矩阵
|
||||
"""
|
||||
if quant_type == QuantizationType.FLOAT32:
|
||||
return matrix.astype(np.float32)
|
||||
|
||||
elif quant_type == QuantizationType.INT8:
|
||||
# 对整个矩阵进行全局归一化
|
||||
min_val = np.min(matrix)
|
||||
max_val = np.max(matrix)
|
||||
|
||||
if max_val == min_val:
|
||||
return np.zeros_like(matrix, dtype=np.int8)
|
||||
|
||||
# 归一化到 [0, 255]
|
||||
normalized = (matrix - min_val) / (max_val - min_val) * 255
|
||||
quantized = np.round(normalized).astype(np.int8)
|
||||
|
||||
return quantized
|
||||
|
||||
else:
|
||||
raise ValueError(f"不支持的量化类型: {quant_type}")
|
||||
|
||||
|
||||
def dequantize_matrix(
|
||||
quantized_matrix: np.ndarray,
|
||||
quant_type: QuantizationType = QuantizationType.INT8,
|
||||
min_val: float = None,
|
||||
max_val: float = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
反量化矩阵
|
||||
|
||||
Args:
|
||||
quantized_matrix: 量化后的矩阵
|
||||
quant_type: 量化类型
|
||||
min_val: 归一化最小值(int8反量化需要)
|
||||
max_val: 归一化最大值(int8反量化需要)
|
||||
|
||||
Returns:
|
||||
反量化后的矩阵
|
||||
"""
|
||||
if quant_type == QuantizationType.FLOAT32:
|
||||
return quantized_matrix.astype(np.float32)
|
||||
|
||||
elif quant_type == QuantizationType.INT8:
|
||||
# 使用提供的归一化参数反量化
|
||||
if min_val is None or max_val is None:
|
||||
# 默认假设范围是 [0, 255] -> [-1, 1]
|
||||
return quantized_matrix.astype(np.float32) / 127.0
|
||||
else:
|
||||
# 恢复到原始范围
|
||||
normalized = quantized_matrix.astype(np.float32) / 255.0
|
||||
return normalized * (max_val - min_val) + min_val
|
||||
|
||||
else:
|
||||
raise ValueError(f"不支持的量化类型: {quant_type}")
|
||||
|
||||
|
||||
def estimate_memory_reduction(
|
||||
num_vectors: int,
|
||||
dimension: int,
|
||||
from_type: QuantizationType,
|
||||
to_type: QuantizationType,
|
||||
) -> Tuple[float, float]:
|
||||
"""
|
||||
估算内存节省量
|
||||
|
||||
Args:
|
||||
num_vectors: 向量数量
|
||||
dimension: 向量维度
|
||||
from_type: 原始量化类型
|
||||
to_type: 目标量化类型
|
||||
|
||||
Returns:
|
||||
(原始大小MB, 量化后大小MB, 节省比例)
|
||||
"""
|
||||
# 计算每个向量占用的字节数
|
||||
bytes_per_element = {
|
||||
QuantizationType.FLOAT32: 4,
|
||||
QuantizationType.INT8: 1,
|
||||
QuantizationType.PQ: 0.25, # 假设压缩到1/4
|
||||
}
|
||||
|
||||
original_bytes = num_vectors * dimension * bytes_per_element[from_type]
|
||||
quantized_bytes = num_vectors * dimension * bytes_per_element[to_type]
|
||||
|
||||
original_mb = original_bytes / 1024 / 1024
|
||||
quantized_mb = quantized_bytes / 1024 / 1024
|
||||
reduction_ratio = (original_bytes - quantized_bytes) / original_bytes
|
||||
|
||||
return original_mb, quantized_mb, reduction_ratio
|
||||
|
||||
|
||||
def estimate_compression_stats(
|
||||
num_vectors: int,
|
||||
dimension: int,
|
||||
quant_type: QuantizationType,
|
||||
) -> dict:
|
||||
"""
|
||||
估算压缩统计信息
|
||||
|
||||
Args:
|
||||
num_vectors: 向量数量
|
||||
dimension: 向量维度
|
||||
quant_type: 量化类型
|
||||
|
||||
Returns:
|
||||
统计信息字典
|
||||
"""
|
||||
original_mb, quantized_mb, ratio = estimate_memory_reduction(
|
||||
num_vectors, dimension, QuantizationType.FLOAT32, quant_type
|
||||
)
|
||||
|
||||
return {
|
||||
"num_vectors": num_vectors,
|
||||
"dimension": dimension,
|
||||
"quantization_type": quant_type.value,
|
||||
"original_size_mb": round(original_mb, 2),
|
||||
"quantized_size_mb": round(quantized_mb, 2),
|
||||
"saved_mb": round(original_mb - quantized_mb, 2),
|
||||
"compression_ratio": round(ratio * 100, 2),
|
||||
}
|
||||
|
||||
|
||||
def _product_quantize(
|
||||
vector: np.ndarray, m: int = 8, k: int = 256
|
||||
) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
乘积量化 (PQ) 简化实现
|
||||
|
||||
Args:
|
||||
vector: 输入向量 (D,)
|
||||
m: 子空间数量
|
||||
k: 每个子空间的聚类中心数
|
||||
|
||||
Returns:
|
||||
(编码后的向量, 聚类中心)
|
||||
"""
|
||||
d = vector.shape[0]
|
||||
if d % m != 0:
|
||||
raise ValueError(f"维度 {d} 必须能被子空间数量 {m} 整除")
|
||||
|
||||
ds = d // m # 子空间维度
|
||||
codes = np.zeros(m, dtype=np.uint8)
|
||||
centroids = np.zeros((m, k, ds), dtype=np.float32)
|
||||
|
||||
# 这里采用一种简化的 PQ:不进行 K-means 训练,
|
||||
# 而是预定一些量化点或针对单向量的微型聚类(实际应用中应离线训练)
|
||||
# 为了演示,我们直接将子空间切分为 k 份进行量化
|
||||
for i in range(m):
|
||||
sub_vec = vector[i * ds : (i + 1) * ds]
|
||||
# 简化:假定每个子空间的取值范围并划分
|
||||
# 实际 PQ 应使用 k-means 产生的 centroids
|
||||
# 这里为演示创建一个随机 codebook 并找到最接近的核心
|
||||
sub_min, sub_max = np.min(sub_vec), np.max(sub_vec)
|
||||
if sub_max == sub_min:
|
||||
linspace = np.zeros(k)
|
||||
else:
|
||||
linspace = np.linspace(sub_min, sub_max, k)
|
||||
|
||||
for j in range(k):
|
||||
centroids[i, j, :] = linspace[j]
|
||||
|
||||
# 编码:这里简化为取子空间均值找最接近的 centroid
|
||||
sub_mean = np.mean(sub_vec)
|
||||
code = np.argmin(np.abs(linspace - sub_mean))
|
||||
codes[i] = code
|
||||
|
||||
return codes, centroids
|
||||
|
||||
|
||||
def _product_dequantize(codes: np.ndarray, centroids: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
PQ 反量化
|
||||
|
||||
Args:
|
||||
codes: 编码向量 (M,)
|
||||
centroids: 聚类中心 (M, K, DS)
|
||||
|
||||
Returns:
|
||||
恢复后的向量 (D,)
|
||||
"""
|
||||
m, k, ds = centroids.shape
|
||||
vector = np.zeros(m * ds, dtype=np.float32)
|
||||
|
||||
for i in range(m):
|
||||
code = codes[i]
|
||||
vector[i * ds : (i + 1) * ds] = centroids[i, code, :]
|
||||
|
||||
return vector
|
||||
121
src/A_memorix/core/utils/relation_query.py
Normal file
121
src/A_memorix/core/utils/relation_query.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""关系查询规格解析工具。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class RelationQuerySpec:
|
||||
raw: str
|
||||
is_structured: bool
|
||||
subject: Optional[str]
|
||||
predicate: Optional[str]
|
||||
object: Optional[str]
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
_NATURAL_LANGUAGE_PATTERN = re.compile(
|
||||
r"(^\s*(what|who|which|how|why|when|where)\b|"
|
||||
r"\?|?|"
|
||||
r"\b(relation|related|between)\b|"
|
||||
r"(什么关系|有哪些关系|之间|关联))",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
def _looks_like_natural_language(raw: str) -> bool:
|
||||
text = str(raw or "").strip()
|
||||
if not text:
|
||||
return False
|
||||
return _NATURAL_LANGUAGE_PATTERN.search(text) is not None
|
||||
|
||||
|
||||
def parse_relation_query_spec(relation_spec: str) -> RelationQuerySpec:
|
||||
raw = str(relation_spec or "").strip()
|
||||
if not raw:
|
||||
return RelationQuerySpec(
|
||||
raw=raw,
|
||||
is_structured=False,
|
||||
subject=None,
|
||||
predicate=None,
|
||||
object=None,
|
||||
error="empty",
|
||||
)
|
||||
|
||||
if "|" in raw:
|
||||
parts = [p.strip() for p in raw.split("|")]
|
||||
if len(parts) < 2:
|
||||
return RelationQuerySpec(
|
||||
raw=raw,
|
||||
is_structured=True,
|
||||
subject=None,
|
||||
predicate=None,
|
||||
object=None,
|
||||
error="invalid_pipe_format",
|
||||
)
|
||||
return RelationQuerySpec(
|
||||
raw=raw,
|
||||
is_structured=True,
|
||||
subject=parts[0] or None,
|
||||
predicate=parts[1] or None,
|
||||
object=parts[2] if len(parts) > 2 and parts[2] else None,
|
||||
)
|
||||
|
||||
if "->" in raw:
|
||||
parts = [p.strip() for p in raw.split("->") if p.strip()]
|
||||
if len(parts) >= 3:
|
||||
return RelationQuerySpec(
|
||||
raw=raw,
|
||||
is_structured=True,
|
||||
subject=parts[0],
|
||||
predicate=parts[1],
|
||||
object=parts[2],
|
||||
)
|
||||
if len(parts) == 2:
|
||||
return RelationQuerySpec(
|
||||
raw=raw,
|
||||
is_structured=True,
|
||||
subject=parts[0],
|
||||
predicate=None,
|
||||
object=parts[1],
|
||||
)
|
||||
return RelationQuerySpec(
|
||||
raw=raw,
|
||||
is_structured=True,
|
||||
subject=None,
|
||||
predicate=None,
|
||||
object=None,
|
||||
error="invalid_arrow_format",
|
||||
)
|
||||
|
||||
if _looks_like_natural_language(raw):
|
||||
return RelationQuerySpec(
|
||||
raw=raw,
|
||||
is_structured=False,
|
||||
subject=None,
|
||||
predicate=None,
|
||||
object=None,
|
||||
)
|
||||
|
||||
# 仅保留低歧义的紧凑三元组作为兼容语法,例如 "Alice likes Apple"。
|
||||
# 两词形式过于模糊,不再视为结构化关系查询。
|
||||
parts = raw.split()
|
||||
if len(parts) == 3:
|
||||
return RelationQuerySpec(
|
||||
raw=raw,
|
||||
is_structured=True,
|
||||
subject=parts[0],
|
||||
predicate=parts[1],
|
||||
object=parts[2],
|
||||
)
|
||||
|
||||
return RelationQuerySpec(
|
||||
raw=raw,
|
||||
is_structured=False,
|
||||
subject=None,
|
||||
predicate=None,
|
||||
object=None,
|
||||
)
|
||||
166
src/A_memorix/core/utils/relation_write_service.py
Normal file
166
src/A_memorix/core/utils/relation_write_service.py
Normal file
@@ -0,0 +1,166 @@
|
||||
"""
|
||||
统一关系写入与关系向量化服务。
|
||||
|
||||
规则:
|
||||
1. 元数据是主数据源,向量是从索引。
|
||||
2. 关系先写 metadata,再写向量。
|
||||
3. 向量失败不回滚 metadata,依赖状态机与回填任务修复。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
|
||||
logger = get_logger("A_Memorix.RelationWriteService")
|
||||
|
||||
|
||||
@dataclass
|
||||
class RelationWriteResult:
|
||||
hash_value: str
|
||||
vector_written: bool
|
||||
vector_already_exists: bool
|
||||
vector_state: str
|
||||
|
||||
|
||||
class RelationWriteService:
|
||||
"""关系写入收口服务。"""
|
||||
|
||||
ERROR_MAX_LEN = 500
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
metadata_store: Any,
|
||||
graph_store: Any,
|
||||
vector_store: Any,
|
||||
embedding_manager: Any,
|
||||
):
|
||||
self.metadata_store = metadata_store
|
||||
self.graph_store = graph_store
|
||||
self.vector_store = vector_store
|
||||
self.embedding_manager = embedding_manager
|
||||
|
||||
@staticmethod
|
||||
def build_relation_vector_text(subject: str, predicate: str, obj: str) -> str:
|
||||
s = str(subject or "").strip()
|
||||
p = str(predicate or "").strip()
|
||||
o = str(obj or "").strip()
|
||||
# 双表达:兼容关键词检索与自然语言问句
|
||||
return f"{s} {p} {o}\n{s}和{o}的关系是{p}"
|
||||
|
||||
async def ensure_relation_vector(
|
||||
self,
|
||||
hash_value: str,
|
||||
subject: str,
|
||||
predicate: str,
|
||||
obj: str,
|
||||
*,
|
||||
max_error_len: int = ERROR_MAX_LEN,
|
||||
) -> RelationWriteResult:
|
||||
"""
|
||||
为已有关系确保向量存在并更新状态。
|
||||
"""
|
||||
if hash_value in self.vector_store:
|
||||
self.metadata_store.set_relation_vector_state(hash_value, "ready")
|
||||
return RelationWriteResult(
|
||||
hash_value=hash_value,
|
||||
vector_written=False,
|
||||
vector_already_exists=True,
|
||||
vector_state="ready",
|
||||
)
|
||||
|
||||
self.metadata_store.set_relation_vector_state(hash_value, "pending")
|
||||
try:
|
||||
vector_text = self.build_relation_vector_text(subject, predicate, obj)
|
||||
embedding = await self.embedding_manager.encode(vector_text)
|
||||
self.vector_store.add(
|
||||
vectors=embedding.reshape(1, -1),
|
||||
ids=[hash_value],
|
||||
)
|
||||
self.metadata_store.set_relation_vector_state(hash_value, "ready")
|
||||
logger.info(
|
||||
"metric.relation_vector_write_success=1 "
|
||||
"metric.relation_vector_write_success_count=1 "
|
||||
f"hash={hash_value[:16]}"
|
||||
)
|
||||
return RelationWriteResult(
|
||||
hash_value=hash_value,
|
||||
vector_written=True,
|
||||
vector_already_exists=False,
|
||||
vector_state="ready",
|
||||
)
|
||||
except ValueError:
|
||||
# 向量已存在冲突,按成功处理
|
||||
self.metadata_store.set_relation_vector_state(hash_value, "ready")
|
||||
return RelationWriteResult(
|
||||
hash_value=hash_value,
|
||||
vector_written=False,
|
||||
vector_already_exists=True,
|
||||
vector_state="ready",
|
||||
)
|
||||
except Exception as e:
|
||||
err = str(e)[:max_error_len]
|
||||
self.metadata_store.set_relation_vector_state(
|
||||
hash_value,
|
||||
"failed",
|
||||
error=err,
|
||||
bump_retry=True,
|
||||
)
|
||||
logger.warning(
|
||||
"metric.relation_vector_write_fail=1 "
|
||||
"metric.relation_vector_write_fail_count=1 "
|
||||
f"hash={hash_value[:16]} "
|
||||
f"err={err}"
|
||||
)
|
||||
return RelationWriteResult(
|
||||
hash_value=hash_value,
|
||||
vector_written=False,
|
||||
vector_already_exists=False,
|
||||
vector_state="failed",
|
||||
)
|
||||
|
||||
async def upsert_relation_with_vector(
|
||||
self,
|
||||
subject: str,
|
||||
predicate: str,
|
||||
obj: str,
|
||||
confidence: float = 1.0,
|
||||
source_paragraph: str = "",
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
*,
|
||||
write_vector: bool = True,
|
||||
) -> RelationWriteResult:
|
||||
"""
|
||||
统一关系写入:
|
||||
1) 写 metadata relation
|
||||
2) 写 graph edge relation_hash
|
||||
3) 按需写 relation vector
|
||||
"""
|
||||
rel_hash = self.metadata_store.add_relation(
|
||||
subject=subject,
|
||||
predicate=predicate,
|
||||
obj=obj,
|
||||
confidence=confidence,
|
||||
source_paragraph=source_paragraph,
|
||||
metadata=metadata or {},
|
||||
)
|
||||
self.graph_store.add_edges([(subject, obj)], relation_hashes=[rel_hash])
|
||||
|
||||
if not write_vector:
|
||||
self.metadata_store.set_relation_vector_state(rel_hash, "none")
|
||||
return RelationWriteResult(
|
||||
hash_value=rel_hash,
|
||||
vector_written=False,
|
||||
vector_already_exists=False,
|
||||
vector_state="none",
|
||||
)
|
||||
|
||||
return await self.ensure_relation_vector(
|
||||
hash_value=rel_hash,
|
||||
subject=subject,
|
||||
predicate=predicate,
|
||||
obj=obj,
|
||||
)
|
||||
1865
src/A_memorix/core/utils/retrieval_tuning_manager.py
Normal file
1865
src/A_memorix/core/utils/retrieval_tuning_manager.py
Normal file
File diff suppressed because it is too large
Load Diff
240
src/A_memorix/core/utils/runtime_self_check.py
Normal file
240
src/A_memorix/core/utils/runtime_self_check.py
Normal file
@@ -0,0 +1,240 @@
|
||||
"""Runtime self-check helpers for A_Memorix."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("A_Memorix.RuntimeSelfCheck")
|
||||
|
||||
_DEFAULT_SAMPLE_TEXT = "A_Memorix runtime self check"
|
||||
|
||||
|
||||
def _safe_int(value: Any, default: int = 0) -> int:
|
||||
try:
|
||||
return int(value)
|
||||
except Exception:
|
||||
return int(default)
|
||||
|
||||
|
||||
def _get_config_value(config: Any, key: str, default: Any = None) -> Any:
|
||||
getter = getattr(config, "get_config", None)
|
||||
if callable(getter):
|
||||
return getter(key, default)
|
||||
|
||||
current: Any = config
|
||||
for part in key.split("."):
|
||||
if isinstance(current, dict) and part in current:
|
||||
current = current[part]
|
||||
else:
|
||||
return default
|
||||
return current
|
||||
|
||||
|
||||
def _build_report(
|
||||
*,
|
||||
ok: bool,
|
||||
code: str,
|
||||
message: str,
|
||||
configured_dimension: int,
|
||||
requested_dimension: int,
|
||||
vector_store_dimension: int,
|
||||
detected_dimension: int,
|
||||
encoded_dimension: int,
|
||||
elapsed_ms: float,
|
||||
sample_text: str,
|
||||
) -> Dict[str, Any]:
|
||||
return {
|
||||
"ok": bool(ok),
|
||||
"code": str(code or "").strip(),
|
||||
"message": str(message or "").strip(),
|
||||
"configured_dimension": int(configured_dimension),
|
||||
"requested_dimension": int(requested_dimension),
|
||||
"vector_store_dimension": int(vector_store_dimension),
|
||||
"detected_dimension": int(detected_dimension),
|
||||
"encoded_dimension": int(encoded_dimension),
|
||||
"elapsed_ms": float(elapsed_ms),
|
||||
"sample_text": str(sample_text or ""),
|
||||
"checked_at": time.time(),
|
||||
}
|
||||
|
||||
|
||||
def _normalize_encoded_vector(encoded: Any) -> np.ndarray:
|
||||
if encoded is None:
|
||||
raise ValueError("embedding encode returned None")
|
||||
|
||||
if isinstance(encoded, np.ndarray):
|
||||
array = encoded
|
||||
else:
|
||||
array = np.asarray(encoded, dtype=np.float32)
|
||||
|
||||
if array.ndim == 2:
|
||||
if array.shape[0] != 1:
|
||||
raise ValueError(f"embedding encode returned batched output: shape={tuple(array.shape)}")
|
||||
array = array[0]
|
||||
|
||||
if array.ndim != 1:
|
||||
raise ValueError(f"embedding encode returned invalid ndim={array.ndim}")
|
||||
if array.size <= 0:
|
||||
raise ValueError("embedding encode returned empty vector")
|
||||
if not np.all(np.isfinite(array)):
|
||||
raise ValueError("embedding encode returned non-finite values")
|
||||
return array.astype(np.float32, copy=False)
|
||||
|
||||
|
||||
def _get_requested_dimension(embedding_manager: Optional[Any], fallback: int = 0) -> int:
|
||||
if embedding_manager is None:
|
||||
return int(fallback)
|
||||
getter = getattr(embedding_manager, "get_requested_dimension", None)
|
||||
if callable(getter):
|
||||
return _safe_int(getter(), fallback)
|
||||
getter = getattr(embedding_manager, "get_embedding_dimension", None)
|
||||
if callable(getter):
|
||||
return _safe_int(getter(), fallback)
|
||||
return int(fallback)
|
||||
|
||||
|
||||
async def run_embedding_runtime_self_check(
|
||||
*,
|
||||
config: Any,
|
||||
vector_store: Optional[Any],
|
||||
embedding_manager: Optional[Any],
|
||||
sample_text: str = _DEFAULT_SAMPLE_TEXT,
|
||||
) -> Dict[str, Any]:
|
||||
"""Probe the real embedding path and compare dimensions with runtime storage."""
|
||||
configured_dimension = _safe_int(_get_config_value(config, "embedding.dimension", 0), 0)
|
||||
vector_store_dimension = _safe_int(getattr(vector_store, "dimension", 0), 0)
|
||||
requested_dimension = _get_requested_dimension(embedding_manager, configured_dimension)
|
||||
|
||||
if vector_store is None or embedding_manager is None:
|
||||
return _build_report(
|
||||
ok=False,
|
||||
code="runtime_components_missing",
|
||||
message="vector_store 或 embedding_manager 未初始化",
|
||||
configured_dimension=configured_dimension,
|
||||
requested_dimension=requested_dimension,
|
||||
vector_store_dimension=vector_store_dimension,
|
||||
detected_dimension=0,
|
||||
encoded_dimension=0,
|
||||
elapsed_ms=0.0,
|
||||
sample_text=sample_text,
|
||||
)
|
||||
|
||||
start = time.perf_counter()
|
||||
detected_dimension = 0
|
||||
encoded_dimension = 0
|
||||
try:
|
||||
detected_dimension = _safe_int(await embedding_manager._detect_dimension(), 0)
|
||||
requested_dimension = _get_requested_dimension(embedding_manager, detected_dimension or configured_dimension)
|
||||
encoded = await embedding_manager.encode(sample_text)
|
||||
encoded_array = _normalize_encoded_vector(encoded)
|
||||
encoded_dimension = int(encoded_array.shape[0])
|
||||
except Exception as exc:
|
||||
elapsed_ms = (time.perf_counter() - start) * 1000.0
|
||||
logger.warning(f"embedding runtime self-check failed: {exc}")
|
||||
return _build_report(
|
||||
ok=False,
|
||||
code="embedding_probe_failed",
|
||||
message=f"embedding probe failed: {exc}",
|
||||
configured_dimension=configured_dimension,
|
||||
requested_dimension=requested_dimension,
|
||||
vector_store_dimension=vector_store_dimension,
|
||||
detected_dimension=detected_dimension,
|
||||
encoded_dimension=encoded_dimension,
|
||||
elapsed_ms=elapsed_ms,
|
||||
sample_text=sample_text,
|
||||
)
|
||||
|
||||
elapsed_ms = (time.perf_counter() - start) * 1000.0
|
||||
expected_dimension = vector_store_dimension or configured_dimension or detected_dimension
|
||||
if expected_dimension <= 0:
|
||||
return _build_report(
|
||||
ok=False,
|
||||
code="invalid_expected_dimension",
|
||||
message="无法确定期望 embedding 维度",
|
||||
configured_dimension=configured_dimension,
|
||||
requested_dimension=requested_dimension,
|
||||
vector_store_dimension=vector_store_dimension,
|
||||
detected_dimension=detected_dimension,
|
||||
encoded_dimension=encoded_dimension,
|
||||
elapsed_ms=elapsed_ms,
|
||||
sample_text=sample_text,
|
||||
)
|
||||
|
||||
if encoded_dimension != expected_dimension:
|
||||
msg = (
|
||||
"embedding 真实输出维度与当前向量存储不一致: "
|
||||
f"expected={expected_dimension}, encoded={encoded_dimension}"
|
||||
)
|
||||
logger.error(msg)
|
||||
return _build_report(
|
||||
ok=False,
|
||||
code="embedding_dimension_mismatch",
|
||||
message=msg,
|
||||
configured_dimension=configured_dimension,
|
||||
requested_dimension=requested_dimension,
|
||||
vector_store_dimension=vector_store_dimension,
|
||||
detected_dimension=detected_dimension,
|
||||
encoded_dimension=encoded_dimension,
|
||||
elapsed_ms=elapsed_ms,
|
||||
sample_text=sample_text,
|
||||
)
|
||||
|
||||
return _build_report(
|
||||
ok=True,
|
||||
code="ok",
|
||||
message="embedding runtime self-check passed",
|
||||
configured_dimension=configured_dimension,
|
||||
requested_dimension=requested_dimension,
|
||||
vector_store_dimension=vector_store_dimension,
|
||||
detected_dimension=detected_dimension,
|
||||
encoded_dimension=encoded_dimension,
|
||||
elapsed_ms=elapsed_ms,
|
||||
sample_text=sample_text,
|
||||
)
|
||||
|
||||
|
||||
async def ensure_runtime_self_check(
|
||||
plugin_or_config: Any,
|
||||
*,
|
||||
force: bool = False,
|
||||
sample_text: str = _DEFAULT_SAMPLE_TEXT,
|
||||
) -> Dict[str, Any]:
|
||||
"""Run or reuse cached runtime self-check report."""
|
||||
if plugin_or_config is None:
|
||||
return _build_report(
|
||||
ok=False,
|
||||
code="missing_plugin_or_config",
|
||||
message="plugin/config unavailable",
|
||||
configured_dimension=0,
|
||||
requested_dimension=0,
|
||||
vector_store_dimension=0,
|
||||
detected_dimension=0,
|
||||
encoded_dimension=0,
|
||||
elapsed_ms=0.0,
|
||||
sample_text=sample_text,
|
||||
)
|
||||
|
||||
cache = getattr(plugin_or_config, "_runtime_self_check_report", None)
|
||||
if isinstance(cache, dict) and cache and not force:
|
||||
return cache
|
||||
|
||||
report = await run_embedding_runtime_self_check(
|
||||
config=getattr(plugin_or_config, "config", plugin_or_config),
|
||||
vector_store=getattr(plugin_or_config, "vector_store", None)
|
||||
if not isinstance(plugin_or_config, dict)
|
||||
else plugin_or_config.get("vector_store"),
|
||||
embedding_manager=getattr(plugin_or_config, "embedding_manager", None)
|
||||
if not isinstance(plugin_or_config, dict)
|
||||
else plugin_or_config.get("embedding_manager"),
|
||||
sample_text=sample_text,
|
||||
)
|
||||
try:
|
||||
setattr(plugin_or_config, "_runtime_self_check_report", report)
|
||||
except Exception:
|
||||
pass
|
||||
return report
|
||||
439
src/A_memorix/core/utils/search_execution_service.py
Normal file
439
src/A_memorix/core/utils/search_execution_service.py
Normal file
@@ -0,0 +1,439 @@
|
||||
"""
|
||||
统一检索执行服务。
|
||||
|
||||
用于收敛 Action/Tool 在 search/time 上的核心执行流程,避免重复实现。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from ..retrieval import TemporalQueryOptions
|
||||
from .search_postprocess import (
|
||||
apply_safe_content_dedup,
|
||||
maybe_apply_smart_path_fallback,
|
||||
)
|
||||
from .time_parser import parse_query_time_range
|
||||
|
||||
logger = get_logger("A_Memorix.SearchExecutionService")
|
||||
|
||||
|
||||
def _get_config_value(config: Optional[dict], key: str, default: Any = None) -> Any:
|
||||
if not isinstance(config, dict):
|
||||
return default
|
||||
current: Any = config
|
||||
for part in key.split("."):
|
||||
if isinstance(current, dict) and part in current:
|
||||
current = current[part]
|
||||
else:
|
||||
return default
|
||||
return current
|
||||
|
||||
|
||||
def _sanitize_text(value: Any) -> str:
|
||||
if value is None:
|
||||
return ""
|
||||
return str(value).strip()
|
||||
|
||||
|
||||
@dataclass
|
||||
class SearchExecutionRequest:
|
||||
caller: str
|
||||
stream_id: Optional[str] = None
|
||||
group_id: Optional[str] = None
|
||||
user_id: Optional[str] = None
|
||||
query_type: str = "search" # search|time|hybrid
|
||||
query: str = ""
|
||||
top_k: Optional[int] = None
|
||||
time_from: Optional[str] = None
|
||||
time_to: Optional[str] = None
|
||||
person: Optional[str] = None
|
||||
source: Optional[str] = None
|
||||
use_threshold: bool = True
|
||||
enable_ppr: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class SearchExecutionResult:
|
||||
success: bool
|
||||
error: str = ""
|
||||
query_type: str = "search"
|
||||
query: str = ""
|
||||
top_k: int = 10
|
||||
time_from: Optional[str] = None
|
||||
time_to: Optional[str] = None
|
||||
person: Optional[str] = None
|
||||
source: Optional[str] = None
|
||||
temporal: Optional[TemporalQueryOptions] = None
|
||||
results: List[Any] = field(default_factory=list)
|
||||
elapsed_ms: float = 0.0
|
||||
chat_filtered: bool = False
|
||||
dedup_hit: bool = False
|
||||
|
||||
@property
|
||||
def count(self) -> int:
|
||||
return len(self.results)
|
||||
|
||||
|
||||
class SearchExecutionService:
|
||||
"""统一检索执行服务。"""
|
||||
|
||||
@staticmethod
|
||||
def _resolve_plugin_instance(plugin_config: Optional[dict]) -> Optional[Any]:
|
||||
if isinstance(plugin_config, dict):
|
||||
plugin_instance = plugin_config.get("plugin_instance")
|
||||
if plugin_instance is not None:
|
||||
return plugin_instance
|
||||
|
||||
try:
|
||||
from ...runtime_registry import get_runtime_kernel
|
||||
|
||||
return get_runtime_kernel()
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _normalize_query_type(raw_query_type: str) -> str:
|
||||
return _sanitize_text(raw_query_type).lower() or "search"
|
||||
|
||||
@staticmethod
|
||||
def _resolve_runtime_component(
|
||||
plugin_config: Optional[dict],
|
||||
plugin_instance: Optional[Any],
|
||||
key: str,
|
||||
) -> Optional[Any]:
|
||||
if isinstance(plugin_config, dict):
|
||||
value = plugin_config.get(key)
|
||||
if value is not None:
|
||||
return value
|
||||
if plugin_instance is not None:
|
||||
value = getattr(plugin_instance, key, None)
|
||||
if value is not None:
|
||||
return value
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _resolve_top_k(
|
||||
plugin_config: Optional[dict],
|
||||
query_type: str,
|
||||
top_k_raw: Optional[Any],
|
||||
) -> Tuple[bool, int, str]:
|
||||
temporal_default_top_k = int(
|
||||
_get_config_value(plugin_config, "retrieval.temporal.default_top_k", 10)
|
||||
)
|
||||
default_top_k = temporal_default_top_k if query_type in {"time", "hybrid"} else 10
|
||||
if top_k_raw is None:
|
||||
return True, max(1, min(50, default_top_k)), ""
|
||||
try:
|
||||
top_k = int(top_k_raw)
|
||||
except (TypeError, ValueError):
|
||||
return False, 0, "top_k 参数必须为整数"
|
||||
return True, max(1, min(50, top_k)), ""
|
||||
|
||||
@staticmethod
|
||||
def _build_temporal(
|
||||
plugin_config: Optional[dict],
|
||||
query_type: str,
|
||||
time_from_raw: Optional[str],
|
||||
time_to_raw: Optional[str],
|
||||
person: Optional[str],
|
||||
source: Optional[str],
|
||||
) -> Tuple[bool, Optional[TemporalQueryOptions], str]:
|
||||
if query_type not in {"time", "hybrid"}:
|
||||
return True, None, ""
|
||||
|
||||
temporal_enabled = bool(_get_config_value(plugin_config, "retrieval.temporal.enabled", True))
|
||||
if not temporal_enabled:
|
||||
return False, None, "时序检索已禁用(retrieval.temporal.enabled=false)"
|
||||
|
||||
if not time_from_raw and not time_to_raw:
|
||||
return False, None, "time/hybrid 模式至少需要 time_from 或 time_to"
|
||||
|
||||
try:
|
||||
ts_from, ts_to = parse_query_time_range(
|
||||
str(time_from_raw) if time_from_raw is not None else None,
|
||||
str(time_to_raw) if time_to_raw is not None else None,
|
||||
)
|
||||
except ValueError as e:
|
||||
return False, None, f"时间参数错误: {e}"
|
||||
|
||||
temporal = TemporalQueryOptions(
|
||||
time_from=ts_from,
|
||||
time_to=ts_to,
|
||||
person=_sanitize_text(person) or None,
|
||||
source=_sanitize_text(source) or None,
|
||||
allow_created_fallback=bool(
|
||||
_get_config_value(plugin_config, "retrieval.temporal.allow_created_fallback", True)
|
||||
),
|
||||
candidate_multiplier=int(
|
||||
_get_config_value(plugin_config, "retrieval.temporal.candidate_multiplier", 8)
|
||||
),
|
||||
max_scan=int(_get_config_value(plugin_config, "retrieval.temporal.max_scan", 1000)),
|
||||
)
|
||||
return True, temporal, ""
|
||||
|
||||
@staticmethod
|
||||
def _build_request_key(
|
||||
request: SearchExecutionRequest,
|
||||
query_type: str,
|
||||
top_k: int,
|
||||
temporal: Optional[TemporalQueryOptions],
|
||||
) -> str:
|
||||
payload = {
|
||||
"stream_id": _sanitize_text(request.stream_id),
|
||||
"query_type": query_type,
|
||||
"query": _sanitize_text(request.query),
|
||||
"time_from": _sanitize_text(request.time_from),
|
||||
"time_to": _sanitize_text(request.time_to),
|
||||
"time_from_ts": temporal.time_from if temporal else None,
|
||||
"time_to_ts": temporal.time_to if temporal else None,
|
||||
"person": _sanitize_text(request.person),
|
||||
"source": _sanitize_text(request.source),
|
||||
"top_k": int(top_k),
|
||||
"use_threshold": bool(request.use_threshold),
|
||||
"enable_ppr": bool(request.enable_ppr),
|
||||
}
|
||||
payload_json = json.dumps(payload, ensure_ascii=False, sort_keys=True)
|
||||
return hashlib.sha1(payload_json.encode("utf-8")).hexdigest()
|
||||
|
||||
@staticmethod
|
||||
async def execute(
|
||||
*,
|
||||
retriever: Any,
|
||||
threshold_filter: Optional[Any],
|
||||
plugin_config: Optional[dict],
|
||||
request: SearchExecutionRequest,
|
||||
enforce_chat_filter: bool = True,
|
||||
reinforce_access: bool = True,
|
||||
) -> SearchExecutionResult:
|
||||
if retriever is None:
|
||||
return SearchExecutionResult(success=False, error="知识检索器未初始化")
|
||||
|
||||
query_type = SearchExecutionService._normalize_query_type(request.query_type)
|
||||
query = _sanitize_text(request.query)
|
||||
if query_type not in {"search", "time", "hybrid"}:
|
||||
return SearchExecutionResult(
|
||||
success=False,
|
||||
error=f"query_type 无效: {query_type}(仅支持 search/time/hybrid)",
|
||||
)
|
||||
|
||||
if query_type in {"search", "hybrid"} and not query:
|
||||
return SearchExecutionResult(
|
||||
success=False,
|
||||
error="search/hybrid 模式必须提供 query",
|
||||
)
|
||||
|
||||
top_k_ok, top_k, top_k_error = SearchExecutionService._resolve_top_k(
|
||||
plugin_config, query_type, request.top_k
|
||||
)
|
||||
if not top_k_ok:
|
||||
return SearchExecutionResult(success=False, error=top_k_error)
|
||||
|
||||
temporal_ok, temporal, temporal_error = SearchExecutionService._build_temporal(
|
||||
plugin_config=plugin_config,
|
||||
query_type=query_type,
|
||||
time_from_raw=request.time_from,
|
||||
time_to_raw=request.time_to,
|
||||
person=request.person,
|
||||
source=request.source,
|
||||
)
|
||||
if not temporal_ok:
|
||||
return SearchExecutionResult(success=False, error=temporal_error)
|
||||
|
||||
plugin_instance = SearchExecutionService._resolve_plugin_instance(plugin_config)
|
||||
if (
|
||||
enforce_chat_filter
|
||||
and plugin_instance is not None
|
||||
and hasattr(plugin_instance, "is_chat_enabled")
|
||||
):
|
||||
if not plugin_instance.is_chat_enabled(
|
||||
stream_id=request.stream_id,
|
||||
group_id=request.group_id,
|
||||
user_id=request.user_id,
|
||||
):
|
||||
logger.info(
|
||||
"检索请求被聊天过滤拦截: "
|
||||
f"caller={request.caller}, "
|
||||
f"stream_id={request.stream_id}"
|
||||
)
|
||||
return SearchExecutionResult(
|
||||
success=True,
|
||||
query_type=query_type,
|
||||
query=query,
|
||||
top_k=top_k,
|
||||
time_from=request.time_from,
|
||||
time_to=request.time_to,
|
||||
person=request.person,
|
||||
source=request.source,
|
||||
temporal=temporal,
|
||||
results=[],
|
||||
elapsed_ms=0.0,
|
||||
chat_filtered=True,
|
||||
dedup_hit=False,
|
||||
)
|
||||
|
||||
request_key = SearchExecutionService._build_request_key(
|
||||
request=request,
|
||||
query_type=query_type,
|
||||
top_k=top_k,
|
||||
temporal=temporal,
|
||||
)
|
||||
|
||||
async def _executor() -> Dict[str, Any]:
|
||||
original_ppr = bool(getattr(retriever.config, "enable_ppr", True))
|
||||
setattr(retriever.config, "enable_ppr", bool(request.enable_ppr))
|
||||
started_at = time.time()
|
||||
try:
|
||||
retrieved = await retriever.retrieve(
|
||||
query=query,
|
||||
top_k=top_k,
|
||||
temporal=temporal,
|
||||
)
|
||||
|
||||
should_apply_threshold = bool(request.use_threshold) and threshold_filter is not None
|
||||
if (
|
||||
query_type == "time"
|
||||
and not query
|
||||
and bool(
|
||||
_get_config_value(
|
||||
plugin_config,
|
||||
"retrieval.time.skip_threshold_when_query_empty",
|
||||
True,
|
||||
)
|
||||
)
|
||||
):
|
||||
should_apply_threshold = False
|
||||
|
||||
if should_apply_threshold:
|
||||
retrieved = threshold_filter.filter(retrieved)
|
||||
|
||||
if (
|
||||
reinforce_access
|
||||
and plugin_instance is not None
|
||||
and hasattr(plugin_instance, "reinforce_access")
|
||||
):
|
||||
relation_hashes = [
|
||||
item.hash_value
|
||||
for item in retrieved
|
||||
if getattr(item, "result_type", "") == "relation"
|
||||
]
|
||||
if relation_hashes:
|
||||
await plugin_instance.reinforce_access(relation_hashes)
|
||||
|
||||
if query_type == "search":
|
||||
graph_store = SearchExecutionService._resolve_runtime_component(
|
||||
plugin_config, plugin_instance, "graph_store"
|
||||
)
|
||||
metadata_store = SearchExecutionService._resolve_runtime_component(
|
||||
plugin_config, plugin_instance, "metadata_store"
|
||||
)
|
||||
fallback_enabled = bool(
|
||||
_get_config_value(
|
||||
plugin_config,
|
||||
"retrieval.search.smart_fallback.enabled",
|
||||
True,
|
||||
)
|
||||
)
|
||||
fallback_threshold = float(
|
||||
_get_config_value(
|
||||
plugin_config,
|
||||
"retrieval.search.smart_fallback.threshold",
|
||||
0.6,
|
||||
)
|
||||
)
|
||||
retrieved, fallback_triggered, fallback_added = maybe_apply_smart_path_fallback(
|
||||
query=query,
|
||||
results=list(retrieved),
|
||||
graph_store=graph_store,
|
||||
metadata_store=metadata_store,
|
||||
enabled=fallback_enabled,
|
||||
threshold=fallback_threshold,
|
||||
)
|
||||
if fallback_triggered:
|
||||
logger.info(
|
||||
"metric.smart_fallback_triggered_count=1 "
|
||||
f"caller={request.caller} "
|
||||
f"added={fallback_added}"
|
||||
)
|
||||
|
||||
dedup_enabled = bool(
|
||||
_get_config_value(
|
||||
plugin_config,
|
||||
"retrieval.search.safe_content_dedup.enabled",
|
||||
True,
|
||||
)
|
||||
)
|
||||
if dedup_enabled:
|
||||
retrieved, removed_count = apply_safe_content_dedup(list(retrieved))
|
||||
if removed_count > 0:
|
||||
logger.info(
|
||||
f"metric.safe_dedup_removed_count={removed_count} "
|
||||
f"caller={request.caller}"
|
||||
)
|
||||
|
||||
elapsed_ms = (time.time() - started_at) * 1000.0
|
||||
return {"results": retrieved, "elapsed_ms": elapsed_ms}
|
||||
finally:
|
||||
setattr(retriever.config, "enable_ppr", original_ppr)
|
||||
|
||||
dedup_hit = False
|
||||
try:
|
||||
# 调优评估需要逐轮真实执行,且应避免额外 dedup 锁竞争。
|
||||
bypass_request_dedup = str(request.caller or "").strip().lower() == "retrieval_tuning"
|
||||
if (
|
||||
not bypass_request_dedup
|
||||
and
|
||||
plugin_instance is not None
|
||||
and hasattr(plugin_instance, "execute_request_with_dedup")
|
||||
):
|
||||
dedup_hit, payload = await plugin_instance.execute_request_with_dedup(
|
||||
request_key,
|
||||
_executor,
|
||||
)
|
||||
else:
|
||||
payload = await _executor()
|
||||
except Exception as e:
|
||||
return SearchExecutionResult(success=False, error=f"知识检索失败: {e}")
|
||||
|
||||
if dedup_hit:
|
||||
logger.info(f"metric.search_execution_dedup_hit_count=1 caller={request.caller}")
|
||||
|
||||
return SearchExecutionResult(
|
||||
success=True,
|
||||
query_type=query_type,
|
||||
query=query,
|
||||
top_k=top_k,
|
||||
time_from=request.time_from,
|
||||
time_to=request.time_to,
|
||||
person=request.person,
|
||||
source=request.source,
|
||||
temporal=temporal,
|
||||
results=payload.get("results", []),
|
||||
elapsed_ms=float(payload.get("elapsed_ms", 0.0)),
|
||||
chat_filtered=False,
|
||||
dedup_hit=bool(dedup_hit),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def to_serializable_results(results: List[Any]) -> List[Dict[str, Any]]:
|
||||
serialized: List[Dict[str, Any]] = []
|
||||
for item in results:
|
||||
metadata = dict(getattr(item, "metadata", {}) or {})
|
||||
if "time_meta" not in metadata:
|
||||
metadata["time_meta"] = {}
|
||||
serialized.append(
|
||||
{
|
||||
"hash": getattr(item, "hash_value", ""),
|
||||
"type": getattr(item, "result_type", ""),
|
||||
"score": float(getattr(item, "score", 0.0)),
|
||||
"content": getattr(item, "content", ""),
|
||||
"metadata": metadata,
|
||||
}
|
||||
)
|
||||
return serialized
|
||||
90
src/A_memorix/core/utils/search_postprocess.py
Normal file
90
src/A_memorix/core/utils/search_postprocess.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""Post-processing helpers for unified search execution."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, List, Tuple
|
||||
|
||||
from .path_fallback_service import find_paths_from_query, to_retrieval_results
|
||||
|
||||
|
||||
def apply_safe_content_dedup(results: List[Any]) -> Tuple[List[Any], int]:
|
||||
"""Deduplicate results by hash/content while preserving at least one entry."""
|
||||
if not results:
|
||||
return [], 0
|
||||
|
||||
unique_results: List[Any] = []
|
||||
seen_hashes = set()
|
||||
seen_contents = set()
|
||||
|
||||
for item in results:
|
||||
content = str(getattr(item, "content", "") or "").strip()
|
||||
if not content:
|
||||
continue
|
||||
|
||||
hash_value = str(getattr(item, "hash_value", "") or "").strip() or str(hash(content))
|
||||
if hash_value in seen_hashes:
|
||||
continue
|
||||
|
||||
is_dup = False
|
||||
for seen in seen_contents:
|
||||
if content in seen or seen in content:
|
||||
is_dup = True
|
||||
break
|
||||
if is_dup:
|
||||
continue
|
||||
|
||||
seen_hashes.add(hash_value)
|
||||
seen_contents.add(content)
|
||||
unique_results.append(item)
|
||||
|
||||
if not unique_results:
|
||||
unique_results.append(results[0])
|
||||
|
||||
removed_count = max(0, len(results) - len(unique_results))
|
||||
return unique_results, removed_count
|
||||
|
||||
|
||||
def maybe_apply_smart_path_fallback(
|
||||
*,
|
||||
query: str,
|
||||
results: List[Any],
|
||||
graph_store: Any,
|
||||
metadata_store: Any,
|
||||
enabled: bool,
|
||||
threshold: float,
|
||||
max_depth: int = 3,
|
||||
max_paths: int = 5,
|
||||
) -> Tuple[List[Any], bool, int]:
|
||||
"""Append indirect relation paths when semantic results are weak."""
|
||||
if not enabled or not str(query or "").strip():
|
||||
return results, False, 0
|
||||
if graph_store is None or metadata_store is None:
|
||||
return results, False, 0
|
||||
|
||||
max_score = 0.0
|
||||
if results:
|
||||
try:
|
||||
max_score = float(getattr(results[0], "score", 0.0) or 0.0)
|
||||
except Exception:
|
||||
max_score = 0.0
|
||||
|
||||
if max_score >= float(threshold):
|
||||
return results, False, 0
|
||||
|
||||
paths = find_paths_from_query(
|
||||
query=query,
|
||||
graph_store=graph_store,
|
||||
metadata_store=metadata_store,
|
||||
max_depth=max_depth,
|
||||
max_paths=max_paths,
|
||||
)
|
||||
if not paths:
|
||||
return results, False, 0
|
||||
|
||||
path_results = to_retrieval_results(paths)
|
||||
if not path_results:
|
||||
return results, False, 0
|
||||
|
||||
merged = list(path_results) + list(results)
|
||||
return merged, True, len(path_results)
|
||||
|
||||
470
src/A_memorix/core/utils/summary_importer.py
Normal file
470
src/A_memorix/core/utils/summary_importer.py
Normal file
@@ -0,0 +1,470 @@
|
||||
"""
|
||||
聊天总结与知识导入工具
|
||||
|
||||
该模块负责从聊天记录中提取信息,生成总结,并将总结内容及提取的实体/关系
|
||||
导入到 A_memorix 的存储组件中。
|
||||
"""
|
||||
|
||||
import time
|
||||
import json
|
||||
import re
|
||||
import traceback
|
||||
from typing import List, Dict, Any, Tuple, Optional
|
||||
from pathlib import Path
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.services import llm_service as llm_api
|
||||
from src.services import message_service as message_api
|
||||
from src.config.config import global_config, model_config as host_model_config
|
||||
from src.config.model_configs import TaskConfig
|
||||
|
||||
from ..storage import (
|
||||
KnowledgeType,
|
||||
VectorStore,
|
||||
GraphStore,
|
||||
MetadataStore,
|
||||
resolve_stored_knowledge_type,
|
||||
)
|
||||
from ..embedding import EmbeddingAPIAdapter
|
||||
from .relation_write_service import RelationWriteService
|
||||
from .runtime_self_check import ensure_runtime_self_check, run_embedding_runtime_self_check
|
||||
|
||||
logger = get_logger("A_Memorix.SummaryImporter")
|
||||
|
||||
# 默认总结提示词模版
|
||||
SUMMARY_PROMPT_TEMPLATE = """
|
||||
你是 {bot_name}。{personality_context}
|
||||
现在你需要对以下一段聊天记录进行总结,并提取其中的重要知识。
|
||||
|
||||
聊天记录内容:
|
||||
{chat_history}
|
||||
|
||||
请完成以下任务:
|
||||
1. **生成总结**:以第三人称或机器人的视角,简洁明了地总结这段对话的主要内容、发生的事件或讨论的主题。
|
||||
2. **提取实体与关系**:识别并提取对话中提到的重要实体以及它们之间的关系。
|
||||
|
||||
请严格以 JSON 格式输出,格式如下:
|
||||
{{
|
||||
"summary": "总结文本内容",
|
||||
"entities": ["张三", "李四"],
|
||||
"relations": [
|
||||
{{"subject": "张三", "predicate": "认识", "object": "李四"}}
|
||||
]
|
||||
}}
|
||||
|
||||
注意:总结应具有叙事性,能够作为长程记忆的一部分。直接使用实体的实际名称,不要使用 e1/e2 等代号。
|
||||
"""
|
||||
|
||||
class SummaryImporter:
|
||||
"""总结并导入知识的工具类"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vector_store: VectorStore,
|
||||
graph_store: GraphStore,
|
||||
metadata_store: MetadataStore,
|
||||
embedding_manager: EmbeddingAPIAdapter,
|
||||
plugin_config: dict
|
||||
):
|
||||
self.vector_store = vector_store
|
||||
self.graph_store = graph_store
|
||||
self.metadata_store = metadata_store
|
||||
self.embedding_manager = embedding_manager
|
||||
self.plugin_config = plugin_config
|
||||
self.relation_write_service: Optional[RelationWriteService] = (
|
||||
plugin_config.get("relation_write_service")
|
||||
if isinstance(plugin_config, dict)
|
||||
else None
|
||||
)
|
||||
|
||||
def _allow_metadata_only_write(self) -> bool:
|
||||
plugin_instance = self.plugin_config.get("plugin_instance") if isinstance(self.plugin_config, dict) else None
|
||||
getter = getattr(plugin_instance, "get_config", None)
|
||||
if callable(getter):
|
||||
return bool(getter("embedding.fallback.allow_metadata_only_write", True))
|
||||
if isinstance(self.plugin_config, dict):
|
||||
embedding_cfg = self.plugin_config.get("embedding", {}) or {}
|
||||
fallback_cfg = embedding_cfg.get("fallback", {}) if isinstance(embedding_cfg, dict) else {}
|
||||
if isinstance(fallback_cfg, dict):
|
||||
return bool(fallback_cfg.get("allow_metadata_only_write", True))
|
||||
return True
|
||||
|
||||
def _normalize_summary_model_selectors(self, raw_value: Any) -> List[str]:
|
||||
"""标准化 summarization.model_name 配置(vNext 仅接受字符串数组)。"""
|
||||
if raw_value is None:
|
||||
return ["auto"]
|
||||
if isinstance(raw_value, list):
|
||||
selectors = [str(x).strip() for x in raw_value if str(x).strip()]
|
||||
return selectors or ["auto"]
|
||||
raise ValueError(
|
||||
"summarization.model_name 在 vNext 必须为 List[str]。"
|
||||
" 请执行 scripts/release_vnext_migrate.py migrate。"
|
||||
)
|
||||
|
||||
def _pick_default_summary_task(self, available_tasks: Dict[str, TaskConfig]) -> Tuple[Optional[str], Optional[TaskConfig]]:
|
||||
"""
|
||||
选择总结默认任务,避免错误落到 embedding 任务。
|
||||
优先级:replyer > utils > planner > tool_use > 其他非 embedding。
|
||||
"""
|
||||
preferred = ("replyer", "utils", "planner", "tool_use")
|
||||
for name in preferred:
|
||||
cfg = available_tasks.get(name)
|
||||
if cfg and cfg.model_list:
|
||||
return name, cfg
|
||||
|
||||
for name, cfg in available_tasks.items():
|
||||
if name != "embedding" and cfg.model_list:
|
||||
return name, cfg
|
||||
|
||||
for name, cfg in available_tasks.items():
|
||||
if cfg.model_list:
|
||||
return name, cfg
|
||||
|
||||
return None, None
|
||||
|
||||
def _resolve_summary_model_config(self) -> Optional[TaskConfig]:
|
||||
"""
|
||||
解析 summarization.model_name 为 TaskConfig。
|
||||
支持:
|
||||
- "auto"
|
||||
- "replyer"(任务名)
|
||||
- "some-model-name"(具体模型名)
|
||||
- ["utils:model1", "utils:model2", "replyer"](数组混合语法)
|
||||
"""
|
||||
available_tasks = llm_api.get_available_models()
|
||||
if not available_tasks:
|
||||
return None
|
||||
|
||||
raw_cfg = self.plugin_config.get("summarization", {}).get("model_name", "auto")
|
||||
selectors = self._normalize_summary_model_selectors(raw_cfg)
|
||||
default_task_name, default_task_cfg = self._pick_default_summary_task(available_tasks)
|
||||
|
||||
selected_models: List[str] = []
|
||||
base_cfg: Optional[TaskConfig] = None
|
||||
model_dict = getattr(host_model_config, "models_dict", {})
|
||||
|
||||
def _append_models(models: List[str]):
|
||||
for model_name in models:
|
||||
if model_name and model_name not in selected_models:
|
||||
selected_models.append(model_name)
|
||||
|
||||
for raw_selector in selectors:
|
||||
selector = raw_selector.strip()
|
||||
if not selector:
|
||||
continue
|
||||
|
||||
if selector.lower() == "auto":
|
||||
if default_task_cfg:
|
||||
_append_models(default_task_cfg.model_list)
|
||||
if base_cfg is None:
|
||||
base_cfg = default_task_cfg
|
||||
continue
|
||||
|
||||
if ":" in selector:
|
||||
task_name, model_name = selector.split(":", 1)
|
||||
task_name = task_name.strip()
|
||||
model_name = model_name.strip()
|
||||
task_cfg = available_tasks.get(task_name)
|
||||
if not task_cfg:
|
||||
logger.warning(f"总结模型选择器 '{selector}' 的任务 '{task_name}' 不存在,已跳过")
|
||||
continue
|
||||
|
||||
if base_cfg is None:
|
||||
base_cfg = task_cfg
|
||||
|
||||
if not model_name or model_name.lower() == "auto":
|
||||
_append_models(task_cfg.model_list)
|
||||
continue
|
||||
|
||||
if model_name in model_dict or model_name in task_cfg.model_list:
|
||||
_append_models([model_name])
|
||||
else:
|
||||
logger.warning(f"总结模型选择器 '{selector}' 的模型 '{model_name}' 不存在,已跳过")
|
||||
continue
|
||||
|
||||
task_cfg = available_tasks.get(selector)
|
||||
if task_cfg:
|
||||
_append_models(task_cfg.model_list)
|
||||
if base_cfg is None:
|
||||
base_cfg = task_cfg
|
||||
continue
|
||||
|
||||
if selector in model_dict:
|
||||
_append_models([selector])
|
||||
continue
|
||||
|
||||
logger.warning(f"总结模型选择器 '{selector}' 无法识别,已跳过")
|
||||
|
||||
if not selected_models:
|
||||
if default_task_cfg:
|
||||
_append_models(default_task_cfg.model_list)
|
||||
if base_cfg is None:
|
||||
base_cfg = default_task_cfg
|
||||
else:
|
||||
first_cfg = next(iter(available_tasks.values()))
|
||||
_append_models(first_cfg.model_list)
|
||||
if base_cfg is None:
|
||||
base_cfg = first_cfg
|
||||
|
||||
if not selected_models:
|
||||
return None
|
||||
|
||||
template_cfg = base_cfg or default_task_cfg or next(iter(available_tasks.values()))
|
||||
return TaskConfig(
|
||||
model_list=selected_models,
|
||||
max_tokens=template_cfg.max_tokens,
|
||||
temperature=template_cfg.temperature,
|
||||
slow_threshold=template_cfg.slow_threshold,
|
||||
selection_strategy=template_cfg.selection_strategy,
|
||||
)
|
||||
|
||||
async def import_from_stream(
|
||||
self,
|
||||
stream_id: str,
|
||||
context_length: Optional[int] = None,
|
||||
include_personality: Optional[bool] = None
|
||||
) -> Tuple[bool, str]:
|
||||
"""
|
||||
从指定的聊天流中提取记录并执行总结导入
|
||||
|
||||
Args:
|
||||
stream_id: 聊天流 ID
|
||||
context_length: 总结的历史消息条数
|
||||
include_personality: 是否包含人设
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str]: (是否成功, 结果消息)
|
||||
"""
|
||||
try:
|
||||
self_check_ok, self_check_msg = await self._ensure_runtime_self_check()
|
||||
if not self_check_ok:
|
||||
return False, f"导入前自检失败: {self_check_msg}"
|
||||
|
||||
# 1. 获取配置
|
||||
if context_length is None:
|
||||
context_length = self.plugin_config.get("summarization", {}).get("context_length", 50)
|
||||
|
||||
if include_personality is None:
|
||||
include_personality = self.plugin_config.get("summarization", {}).get("include_personality", True)
|
||||
|
||||
# 2. 获取历史消息
|
||||
# 获取当前时间之前的消息
|
||||
now = time.time()
|
||||
messages = message_api.get_messages_before_time_in_chat(
|
||||
chat_id=stream_id,
|
||||
timestamp=now,
|
||||
limit=context_length
|
||||
)
|
||||
|
||||
if not messages:
|
||||
return False, "未找到有效的聊天记录进行总结"
|
||||
|
||||
# 转换为可读文本
|
||||
chat_history_text = message_api.build_readable_messages(messages)
|
||||
|
||||
# 3. 准备提示词内容
|
||||
bot_name = global_config.bot.nickname or "机器人"
|
||||
personality_context = ""
|
||||
if include_personality:
|
||||
personality = getattr(global_config.bot, "personality", "")
|
||||
if personality:
|
||||
personality_context = f"你的性格设定是:{personality}"
|
||||
|
||||
# 4. 调用 LLM
|
||||
prompt = SUMMARY_PROMPT_TEMPLATE.format(
|
||||
bot_name=bot_name,
|
||||
personality_context=personality_context,
|
||||
chat_history=chat_history_text
|
||||
)
|
||||
|
||||
model_config_to_use = self._resolve_summary_model_config()
|
||||
if model_config_to_use is None:
|
||||
return False, "未找到可用的总结模型配置"
|
||||
task_name_to_use = llm_api.resolve_task_name_from_model_config(model_config_to_use)
|
||||
|
||||
logger.info(f"正在为流 {stream_id} 执行总结,消息条数: {len(messages)}")
|
||||
logger.info(f"总结模型候选列表: {model_config_to_use.model_list}")
|
||||
|
||||
result = await llm_api.generate(
|
||||
llm_api.LLMServiceRequest(
|
||||
task_name=task_name_to_use,
|
||||
request_type="A_Memorix.ChatSummarization",
|
||||
prompt=prompt,
|
||||
temperature=getattr(model_config_to_use, "temperature", None),
|
||||
max_tokens=getattr(model_config_to_use, "max_tokens", None),
|
||||
)
|
||||
)
|
||||
success = bool(result.success)
|
||||
response = str(result.completion.response or "")
|
||||
|
||||
if not success or not response:
|
||||
return False, "LLM 生成总结失败"
|
||||
|
||||
# 5. 解析结果
|
||||
data = self._parse_llm_response(response)
|
||||
if not data or "summary" not in data:
|
||||
return False, "解析 LLM 响应失败或总结为空"
|
||||
|
||||
summary_text = data["summary"]
|
||||
entities = data.get("entities", [])
|
||||
relations = data.get("relations", [])
|
||||
msg_times = [
|
||||
float(getattr(getattr(msg, "timestamp", None), "timestamp", lambda: 0.0)())
|
||||
for msg in messages
|
||||
if getattr(msg, "time", None) is not None
|
||||
]
|
||||
time_meta = {}
|
||||
if msg_times:
|
||||
time_meta = {
|
||||
"event_time_start": min(msg_times),
|
||||
"event_time_end": max(msg_times),
|
||||
"time_granularity": "minute",
|
||||
"time_confidence": 0.95,
|
||||
}
|
||||
|
||||
# 6. 执行导入
|
||||
await self._execute_import(summary_text, entities, relations, stream_id, time_meta=time_meta)
|
||||
|
||||
# 7. 持久化
|
||||
self.vector_store.save()
|
||||
self.graph_store.save()
|
||||
|
||||
result_msg = (
|
||||
f"✅ 总结导入成功\n"
|
||||
f"📝 总结长度: {len(summary_text)}\n"
|
||||
f"📌 提取实体: {len(entities)}\n"
|
||||
f"🔗 提取关系: {len(relations)}"
|
||||
)
|
||||
return True, result_msg
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"总结导入过程中出错: {e}\n{traceback.format_exc()}")
|
||||
return False, f"错误: {str(e)}"
|
||||
|
||||
async def _ensure_runtime_self_check(self) -> Tuple[bool, str]:
|
||||
plugin_instance = self.plugin_config.get("plugin_instance") if isinstance(self.plugin_config, dict) else None
|
||||
if plugin_instance is not None:
|
||||
report = await ensure_runtime_self_check(plugin_instance)
|
||||
else:
|
||||
report = await run_embedding_runtime_self_check(
|
||||
config=self.plugin_config,
|
||||
vector_store=self.vector_store,
|
||||
embedding_manager=self.embedding_manager,
|
||||
)
|
||||
if bool(report.get("ok", False)):
|
||||
return True, ""
|
||||
if self._allow_metadata_only_write():
|
||||
msg = (
|
||||
f"{report.get('message', 'unknown')} "
|
||||
f"(configured={report.get('configured_dimension', 0)}, "
|
||||
f"store={report.get('vector_store_dimension', 0)}, "
|
||||
f"encoded={report.get('encoded_dimension', 0)})"
|
||||
)
|
||||
logger.warning(f"总结导入进入 metadata-only 回退模式: {msg}")
|
||||
return True, "embedding_degraded_metadata_only"
|
||||
return (
|
||||
False,
|
||||
f"{report.get('message', 'unknown')} "
|
||||
f"(configured={report.get('configured_dimension', 0)}, "
|
||||
f"store={report.get('vector_store_dimension', 0)}, "
|
||||
f"encoded={report.get('encoded_dimension', 0)})",
|
||||
)
|
||||
|
||||
def _parse_llm_response(self, response: str) -> Dict[str, Any]:
|
||||
"""解析 LLM 返回的 JSON"""
|
||||
try:
|
||||
# 尝试查找 JSON
|
||||
json_match = re.search(r"\{.*\}", response, re.DOTALL)
|
||||
if json_match:
|
||||
return json.loads(json_match.group())
|
||||
return {}
|
||||
except Exception as e:
|
||||
logger.warning(f"解析总结 JSON 失败: {e}")
|
||||
return {}
|
||||
|
||||
async def _execute_import(
|
||||
self,
|
||||
summary: str,
|
||||
entities: List[str],
|
||||
relations: List[Dict[str, str]],
|
||||
stream_id: str,
|
||||
time_meta: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
"""将数据写入存储"""
|
||||
# 获取默认知识类型
|
||||
type_str = self.plugin_config.get("summarization", {}).get("default_knowledge_type", "narrative")
|
||||
try:
|
||||
knowledge_type = resolve_stored_knowledge_type(type_str, content=summary)
|
||||
except ValueError:
|
||||
logger.warning(f"非法 summarization.default_knowledge_type={type_str},回退 narrative")
|
||||
knowledge_type = KnowledgeType.NARRATIVE
|
||||
|
||||
# 导入总结文本
|
||||
hash_value = self.metadata_store.add_paragraph(
|
||||
content=summary,
|
||||
source=f"chat_summary:{stream_id}",
|
||||
knowledge_type=knowledge_type.value,
|
||||
time_meta=time_meta,
|
||||
)
|
||||
|
||||
plugin_instance = self.plugin_config.get("plugin_instance") if isinstance(self.plugin_config, dict) else None
|
||||
vector_writer = getattr(plugin_instance, "write_paragraph_vector_or_enqueue", None)
|
||||
if callable(vector_writer):
|
||||
result = await vector_writer(
|
||||
paragraph_hash=hash_value,
|
||||
content=summary,
|
||||
context="summary_import",
|
||||
)
|
||||
if str(result.get("warning", "") or "").strip():
|
||||
logger.warning(f"总结导入段落进入回退写入: {result}")
|
||||
else:
|
||||
try:
|
||||
embedding = await self.embedding_manager.encode(summary)
|
||||
self.vector_store.add(
|
||||
vectors=embedding.reshape(1, -1),
|
||||
ids=[hash_value]
|
||||
)
|
||||
except Exception as exc:
|
||||
if not self._allow_metadata_only_write():
|
||||
raise
|
||||
logger.warning(f"总结导入段落向量写入失败,改为回填队列: {exc}")
|
||||
self.metadata_store.enqueue_paragraph_vector_backfill(hash_value, error=str(exc))
|
||||
|
||||
# 导入实体
|
||||
if entities:
|
||||
self.graph_store.add_nodes(entities)
|
||||
|
||||
# 导入关系
|
||||
rv_cfg = self.plugin_config.get("retrieval", {}).get("relation_vectorization", {})
|
||||
if not isinstance(rv_cfg, dict):
|
||||
rv_cfg = {}
|
||||
write_vector = bool(rv_cfg.get("enabled", False)) and bool(rv_cfg.get("write_on_import", True))
|
||||
for rel in relations:
|
||||
s, p, o = rel.get("subject"), rel.get("predicate"), rel.get("object")
|
||||
if all([s, p, o]):
|
||||
if self.relation_write_service is not None:
|
||||
await self.relation_write_service.upsert_relation_with_vector(
|
||||
subject=s,
|
||||
predicate=p,
|
||||
obj=o,
|
||||
confidence=1.0,
|
||||
source_paragraph=hash_value,
|
||||
write_vector=write_vector,
|
||||
)
|
||||
else:
|
||||
# 写入元数据
|
||||
rel_hash = self.metadata_store.add_relation(
|
||||
subject=s,
|
||||
predicate=p,
|
||||
obj=o,
|
||||
confidence=1.0,
|
||||
source_paragraph=hash_value
|
||||
)
|
||||
# 写入图数据库(写入 relation_hashes,确保后续可按关系精确修剪)
|
||||
self.graph_store.add_edges([(s, o)], relation_hashes=[rel_hash])
|
||||
try:
|
||||
self.metadata_store.set_relation_vector_state(rel_hash, "none")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.info(f"总结导入完成: hash={hash_value[:8]}")
|
||||
170
src/A_memorix/core/utils/time_parser.py
Normal file
170
src/A_memorix/core/utils/time_parser.py
Normal file
@@ -0,0 +1,170 @@
|
||||
"""
|
||||
时间解析工具。
|
||||
|
||||
约束:
|
||||
1. 查询参数(Action/Command/Tool)仅接受结构化绝对时间:
|
||||
- YYYY/MM/DD
|
||||
- YYYY/MM/DD HH:mm
|
||||
2. 入库时允许更宽松格式(含时间戳、YYYY-MM-DD 等)。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
|
||||
_QUERY_DATE_RE = re.compile(r"^\d{4}/\d{2}/\d{2}$")
|
||||
_QUERY_MINUTE_RE = re.compile(r"^\d{4}/\d{2}/\d{2} \d{2}:\d{2}$")
|
||||
_NUMERIC_RE = re.compile(r"^-?\d+(?:\.\d+)?$")
|
||||
|
||||
_INGEST_FORMATS = [
|
||||
"%Y/%m/%d %H:%M:%S",
|
||||
"%Y/%m/%d %H:%M",
|
||||
"%Y-%m-%d %H:%M:%S",
|
||||
"%Y-%m-%d %H:%M",
|
||||
"%Y-%m-%dT%H:%M:%S",
|
||||
"%Y-%m-%dT%H:%M",
|
||||
"%Y/%m/%d",
|
||||
"%Y-%m-%d",
|
||||
]
|
||||
|
||||
_INGEST_DATE_FORMATS = {"%Y/%m/%d", "%Y-%m-%d"}
|
||||
|
||||
|
||||
def parse_query_datetime_to_timestamp(value: str, is_end: bool = False) -> float:
|
||||
"""解析查询时间,仅支持 YYYY/MM/DD 或 YYYY/MM/DD HH:mm。"""
|
||||
text = str(value).strip()
|
||||
if not text:
|
||||
raise ValueError("时间不能为空")
|
||||
|
||||
if _QUERY_DATE_RE.fullmatch(text):
|
||||
dt = datetime.strptime(text, "%Y/%m/%d")
|
||||
if is_end:
|
||||
dt = dt.replace(hour=23, minute=59, second=0, microsecond=0)
|
||||
return dt.timestamp()
|
||||
|
||||
if _QUERY_MINUTE_RE.fullmatch(text):
|
||||
dt = datetime.strptime(text, "%Y/%m/%d %H:%M")
|
||||
return dt.timestamp()
|
||||
|
||||
raise ValueError(
|
||||
f"时间格式错误: {text}。仅支持 YYYY/MM/DD 或 YYYY/MM/DD HH:mm"
|
||||
)
|
||||
|
||||
|
||||
def parse_query_time_range(
|
||||
time_from: Optional[str],
|
||||
time_to: Optional[str],
|
||||
) -> Tuple[Optional[float], Optional[float]]:
|
||||
"""解析查询窗口并验证区间。"""
|
||||
ts_from = (
|
||||
parse_query_datetime_to_timestamp(time_from, is_end=False)
|
||||
if time_from
|
||||
else None
|
||||
)
|
||||
ts_to = (
|
||||
parse_query_datetime_to_timestamp(time_to, is_end=True)
|
||||
if time_to
|
||||
else None
|
||||
)
|
||||
|
||||
if ts_from is not None and ts_to is not None and ts_from > ts_to:
|
||||
raise ValueError("time_from 不能晚于 time_to")
|
||||
|
||||
return ts_from, ts_to
|
||||
|
||||
|
||||
def parse_ingest_datetime_to_timestamp(
|
||||
value: Any,
|
||||
is_end: bool = False,
|
||||
) -> Optional[float]:
|
||||
"""解析入库时间,允许 timestamp/常见字符串格式。"""
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
if isinstance(value, (int, float)):
|
||||
return float(value)
|
||||
|
||||
text = str(value).strip()
|
||||
if not text:
|
||||
return None
|
||||
|
||||
if _NUMERIC_RE.fullmatch(text):
|
||||
return float(text)
|
||||
|
||||
for fmt in _INGEST_FORMATS:
|
||||
try:
|
||||
dt = datetime.strptime(text, fmt)
|
||||
if fmt in _INGEST_DATE_FORMATS and is_end:
|
||||
dt = dt.replace(hour=23, minute=59, second=0, microsecond=0)
|
||||
return dt.timestamp()
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
raise ValueError(f"无法解析时间: {text}")
|
||||
|
||||
|
||||
def normalize_time_meta(time_meta: Optional[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
"""归一化 time_meta 到存储层字段。"""
|
||||
if not time_meta:
|
||||
return {}
|
||||
|
||||
normalized: Dict[str, Any] = {}
|
||||
|
||||
event_time = parse_ingest_datetime_to_timestamp(time_meta.get("event_time"))
|
||||
event_start = parse_ingest_datetime_to_timestamp(
|
||||
time_meta.get("event_time_start"),
|
||||
is_end=False,
|
||||
)
|
||||
event_end = parse_ingest_datetime_to_timestamp(
|
||||
time_meta.get("event_time_end"),
|
||||
is_end=True,
|
||||
)
|
||||
|
||||
time_range = time_meta.get("time_range")
|
||||
if (
|
||||
isinstance(time_range, (list, tuple))
|
||||
and len(time_range) == 2
|
||||
):
|
||||
if event_start is None:
|
||||
event_start = parse_ingest_datetime_to_timestamp(time_range[0], is_end=False)
|
||||
if event_end is None:
|
||||
event_end = parse_ingest_datetime_to_timestamp(time_range[1], is_end=True)
|
||||
|
||||
if event_start is not None and event_end is not None and event_start > event_end:
|
||||
raise ValueError("event_time_start 不能晚于 event_time_end")
|
||||
|
||||
if event_time is not None:
|
||||
normalized["event_time"] = event_time
|
||||
if event_start is not None:
|
||||
normalized["event_time_start"] = event_start
|
||||
if event_end is not None:
|
||||
normalized["event_time_end"] = event_end
|
||||
|
||||
granularity = time_meta.get("time_granularity")
|
||||
if granularity:
|
||||
normalized["time_granularity"] = str(granularity)
|
||||
else:
|
||||
raw_time_values = [
|
||||
time_meta.get("event_time"),
|
||||
time_meta.get("event_time_start"),
|
||||
time_meta.get("event_time_end"),
|
||||
]
|
||||
has_minute = any(isinstance(v, str) and ":" in v for v in raw_time_values if v is not None)
|
||||
normalized["time_granularity"] = "minute" if has_minute else "day"
|
||||
|
||||
confidence = time_meta.get("time_confidence")
|
||||
if confidence is not None:
|
||||
normalized["time_confidence"] = float(confidence)
|
||||
|
||||
return normalized
|
||||
|
||||
|
||||
def format_timestamp(ts: Optional[float]) -> Optional[str]:
|
||||
"""将 timestamp 格式化为 YYYY/MM/DD HH:mm。"""
|
||||
if ts is None:
|
||||
return None
|
||||
return datetime.fromtimestamp(ts).strftime("%Y/%m/%d %H:%M")
|
||||
|
||||
3613
src/A_memorix/core/utils/web_import_manager.py
Normal file
3613
src/A_memorix/core/utils/web_import_manager.py
Normal file
File diff suppressed because it is too large
Load Diff
310
src/A_memorix/host_service.py
Normal file
310
src/A_memorix/host_service.py
Normal file
@@ -0,0 +1,310 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import tomlkit
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.webui.utils.toml_utils import save_toml_with_format
|
||||
|
||||
from .core.runtime.sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel
|
||||
from .paths import config_path, repo_root, schema_path
|
||||
from .runtime_registry import set_runtime_kernel
|
||||
|
||||
logger = get_logger("a_memorix.host_service")
|
||||
|
||||
|
||||
def _to_builtin_data(obj: Any) -> Any:
|
||||
if hasattr(obj, "unwrap"):
|
||||
try:
|
||||
obj = obj.unwrap()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if isinstance(obj, dict):
|
||||
return {str(key): _to_builtin_data(value) for key, value in obj.items()}
|
||||
if isinstance(obj, list):
|
||||
return [_to_builtin_data(value) for value in obj]
|
||||
return obj
|
||||
|
||||
|
||||
def _backup_config_file(path: Path) -> Optional[Path]:
|
||||
if not path.exists():
|
||||
return None
|
||||
backup_name = f"{path.name}.backup.{datetime.now().strftime('%Y%m%d%H%M%S')}"
|
||||
backup_path = path.parent / backup_name
|
||||
backup_path.write_text(path.read_text(encoding="utf-8"), encoding="utf-8")
|
||||
return backup_path
|
||||
|
||||
|
||||
class AMemorixHostService:
|
||||
def __init__(self) -> None:
|
||||
self._lock = asyncio.Lock()
|
||||
self._kernel: Optional[SDKMemoryKernel] = None
|
||||
self._config_cache: Dict[str, Any] | None = None
|
||||
|
||||
async def start(self) -> None:
|
||||
await self._ensure_kernel()
|
||||
|
||||
async def stop(self) -> None:
|
||||
async with self._lock:
|
||||
await self._shutdown_locked()
|
||||
|
||||
async def reload(self) -> None:
|
||||
async with self._lock:
|
||||
await self._shutdown_locked()
|
||||
self._config_cache = self._read_config()
|
||||
|
||||
await self._ensure_kernel()
|
||||
|
||||
def get_config_path(self) -> Path:
|
||||
return config_path()
|
||||
|
||||
def get_schema_path(self) -> Path:
|
||||
return schema_path()
|
||||
|
||||
def get_config_schema(self) -> Dict[str, Any]:
|
||||
path = self.get_schema_path()
|
||||
if not path.exists():
|
||||
return {
|
||||
"plugin_id": "a_memorix",
|
||||
"plugin_info": {
|
||||
"name": "A_Memorix",
|
||||
"version": "",
|
||||
"description": "A_Memorix 配置结构",
|
||||
"author": "A_Dawn",
|
||||
},
|
||||
"sections": {},
|
||||
"layout": {"type": "auto", "tabs": []},
|
||||
}
|
||||
|
||||
with path.open("r", encoding="utf-8") as handle:
|
||||
return json.load(handle)
|
||||
|
||||
def get_config(self) -> Dict[str, Any]:
|
||||
return dict(self._read_config())
|
||||
|
||||
def _build_default_config(self) -> Dict[str, Any]:
|
||||
schema = self.get_config_schema()
|
||||
sections = schema.get("sections") if isinstance(schema, dict) else None
|
||||
if not isinstance(sections, dict):
|
||||
return {}
|
||||
|
||||
defaults: Dict[str, Any] = {}
|
||||
for section_name, section_payload in sections.items():
|
||||
if not isinstance(section_payload, dict):
|
||||
continue
|
||||
fields = section_payload.get("fields")
|
||||
if not isinstance(fields, dict):
|
||||
continue
|
||||
|
||||
section_parts = [part for part in str(section_name or "").split(".") if part]
|
||||
if not section_parts:
|
||||
continue
|
||||
|
||||
section_target: Dict[str, Any] = defaults
|
||||
for part in section_parts:
|
||||
nested = section_target.get(part)
|
||||
if not isinstance(nested, dict):
|
||||
nested = {}
|
||||
section_target[part] = nested
|
||||
section_target = nested
|
||||
|
||||
for field_name, field_payload in fields.items():
|
||||
if not isinstance(field_payload, dict) or "default" not in field_payload:
|
||||
continue
|
||||
section_target[str(field_name)] = _to_builtin_data(field_payload.get("default"))
|
||||
|
||||
return defaults
|
||||
|
||||
def get_raw_config_with_meta(self) -> Dict[str, Any]:
|
||||
path = self.get_config_path()
|
||||
if path.exists():
|
||||
return {
|
||||
"config": path.read_text(encoding="utf-8"),
|
||||
"exists": True,
|
||||
"using_default": False,
|
||||
}
|
||||
|
||||
default_config = self._build_default_config()
|
||||
default_raw = tomlkit.dumps(default_config) if default_config else ""
|
||||
return {
|
||||
"config": default_raw,
|
||||
"exists": False,
|
||||
"using_default": True,
|
||||
}
|
||||
|
||||
def get_raw_config(self) -> str:
|
||||
payload = self.get_raw_config_with_meta()
|
||||
return str(payload.get("config", "") or "")
|
||||
|
||||
async def update_raw_config(self, raw_config: str) -> Dict[str, Any]:
|
||||
tomlkit.loads(raw_config)
|
||||
path = self.get_config_path()
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
backup_path = _backup_config_file(path)
|
||||
path.write_text(raw_config, encoding="utf-8")
|
||||
await self.reload()
|
||||
return {
|
||||
"success": True,
|
||||
"message": "配置已保存",
|
||||
"backup_path": str(backup_path) if backup_path is not None else "",
|
||||
"config_path": str(path),
|
||||
}
|
||||
|
||||
async def update_config(self, config: Dict[str, Any]) -> Dict[str, Any]:
|
||||
path = self.get_config_path()
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
backup_path = _backup_config_file(path)
|
||||
save_toml_with_format(config, str(path), preserve_comments=True)
|
||||
await self.reload()
|
||||
return {
|
||||
"success": True,
|
||||
"message": "配置已保存",
|
||||
"backup_path": str(backup_path) if backup_path is not None else "",
|
||||
"config_path": str(path),
|
||||
}
|
||||
|
||||
async def invoke(self, component_name: str, args: Dict[str, Any] | None = None, *, timeout_ms: int = 30000) -> Any:
|
||||
del timeout_ms
|
||||
payload = args or {}
|
||||
kernel = await self._ensure_kernel()
|
||||
|
||||
if component_name == "search_memory":
|
||||
return await kernel.search_memory(
|
||||
KernelSearchRequest(
|
||||
query=str(payload.get("query", "") or ""),
|
||||
limit=int(payload.get("limit", 5) or 5),
|
||||
mode=str(payload.get("mode", "search") or "search"),
|
||||
chat_id=str(payload.get("chat_id", "") or ""),
|
||||
person_id=str(payload.get("person_id", "") or ""),
|
||||
time_start=payload.get("time_start"),
|
||||
time_end=payload.get("time_end"),
|
||||
respect_filter=bool(payload.get("respect_filter", True)),
|
||||
user_id=str(payload.get("user_id", "") or "").strip(),
|
||||
group_id=str(payload.get("group_id", "") or "").strip(),
|
||||
)
|
||||
)
|
||||
|
||||
if component_name == "ingest_summary":
|
||||
return await kernel.ingest_summary(
|
||||
external_id=str(payload.get("external_id", "") or ""),
|
||||
chat_id=str(payload.get("chat_id", "") or ""),
|
||||
text=str(payload.get("text", "") or ""),
|
||||
participants=list(payload.get("participants") or []),
|
||||
time_start=payload.get("time_start"),
|
||||
time_end=payload.get("time_end"),
|
||||
tags=list(payload.get("tags") or []),
|
||||
metadata=payload.get("metadata") if isinstance(payload.get("metadata"), dict) else {},
|
||||
respect_filter=bool(payload.get("respect_filter", True)),
|
||||
user_id=str(payload.get("user_id", "") or "").strip(),
|
||||
group_id=str(payload.get("group_id", "") or "").strip(),
|
||||
)
|
||||
|
||||
if component_name == "ingest_text":
|
||||
relations = payload.get("relations") if isinstance(payload.get("relations"), list) else []
|
||||
entities = payload.get("entities") if isinstance(payload.get("entities"), list) else []
|
||||
return await kernel.ingest_text(
|
||||
external_id=str(payload.get("external_id", "") or ""),
|
||||
source_type=str(payload.get("source_type", "") or ""),
|
||||
text=str(payload.get("text", "") or ""),
|
||||
chat_id=str(payload.get("chat_id", "") or ""),
|
||||
person_ids=list(payload.get("person_ids") or []),
|
||||
participants=list(payload.get("participants") or []),
|
||||
timestamp=payload.get("timestamp"),
|
||||
time_start=payload.get("time_start"),
|
||||
time_end=payload.get("time_end"),
|
||||
tags=list(payload.get("tags") or []),
|
||||
metadata=payload.get("metadata") if isinstance(payload.get("metadata"), dict) else {},
|
||||
entities=entities,
|
||||
relations=relations,
|
||||
respect_filter=bool(payload.get("respect_filter", True)),
|
||||
user_id=str(payload.get("user_id", "") or "").strip(),
|
||||
group_id=str(payload.get("group_id", "") or "").strip(),
|
||||
)
|
||||
|
||||
if component_name == "get_person_profile":
|
||||
return await kernel.get_person_profile(
|
||||
person_id=str(payload.get("person_id", "") or ""),
|
||||
chat_id=str(payload.get("chat_id", "") or ""),
|
||||
limit=max(1, int(payload.get("limit", 10) or 10)),
|
||||
)
|
||||
|
||||
if component_name == "maintain_memory":
|
||||
return await kernel.maintain_memory(
|
||||
action=str(payload.get("action", "") or ""),
|
||||
target=str(payload.get("target", "") or ""),
|
||||
hours=payload.get("hours"),
|
||||
reason=str(payload.get("reason", "") or ""),
|
||||
limit=max(1, int(payload.get("limit", 50) or 50)),
|
||||
)
|
||||
|
||||
if component_name == "memory_stats":
|
||||
return kernel.memory_stats()
|
||||
|
||||
admin_actions = {
|
||||
"memory_graph_admin": kernel.memory_graph_admin,
|
||||
"memory_source_admin": kernel.memory_source_admin,
|
||||
"memory_episode_admin": kernel.memory_episode_admin,
|
||||
"memory_profile_admin": kernel.memory_profile_admin,
|
||||
"memory_runtime_admin": kernel.memory_runtime_admin,
|
||||
"memory_import_admin": kernel.memory_import_admin,
|
||||
"memory_tuning_admin": kernel.memory_tuning_admin,
|
||||
"memory_v5_admin": kernel.memory_v5_admin,
|
||||
"memory_delete_admin": kernel.memory_delete_admin,
|
||||
}
|
||||
if component_name in admin_actions:
|
||||
kwargs = dict(payload)
|
||||
action = str(kwargs.pop("action", "") or "")
|
||||
return await admin_actions[component_name](action=action, **kwargs)
|
||||
|
||||
raise RuntimeError(f"不支持的 A_Memorix 调用: {component_name}")
|
||||
|
||||
async def _ensure_kernel(self) -> SDKMemoryKernel:
|
||||
async with self._lock:
|
||||
if self._kernel is None:
|
||||
config = self._read_config()
|
||||
self._kernel = SDKMemoryKernel(plugin_root=repo_root(), config=config)
|
||||
await self._kernel.initialize()
|
||||
set_runtime_kernel(self._kernel)
|
||||
return self._kernel
|
||||
|
||||
def _read_config(self) -> Dict[str, Any]:
|
||||
if self._config_cache is not None:
|
||||
return dict(self._config_cache)
|
||||
|
||||
path = self.get_config_path()
|
||||
if not path.exists():
|
||||
defaults = self._build_default_config()
|
||||
self._config_cache = defaults
|
||||
return dict(defaults)
|
||||
|
||||
try:
|
||||
with path.open("r", encoding="utf-8") as handle:
|
||||
loaded = tomlkit.load(handle)
|
||||
except Exception as exc:
|
||||
logger.warning("读取 A_Memorix 配置失败 %s: %s", path, exc)
|
||||
defaults = self._build_default_config()
|
||||
self._config_cache = defaults
|
||||
return dict(defaults)
|
||||
|
||||
self._config_cache = _to_builtin_data(loaded) if isinstance(loaded, dict) else {}
|
||||
return dict(self._config_cache)
|
||||
|
||||
async def _shutdown_locked(self) -> None:
|
||||
if self._kernel is None:
|
||||
return
|
||||
shutdown = getattr(self._kernel, "shutdown", None)
|
||||
if callable(shutdown):
|
||||
await shutdown()
|
||||
else:
|
||||
self._kernel.close()
|
||||
self._kernel = None
|
||||
set_runtime_kernel(None)
|
||||
|
||||
|
||||
a_memorix_host_service = AMemorixHostService()
|
||||
56
src/A_memorix/paths.py
Normal file
56
src/A_memorix/paths.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
A_MEMORIX_SYSTEM_ID = "a_memorix"
|
||||
|
||||
|
||||
def package_root() -> Path:
|
||||
return Path(__file__).resolve().parent
|
||||
|
||||
|
||||
def src_root() -> Path:
|
||||
return package_root().parent
|
||||
|
||||
|
||||
def repo_root() -> Path:
|
||||
return src_root().parent
|
||||
|
||||
|
||||
def config_path() -> Path:
|
||||
return repo_root() / "config" / f"{A_MEMORIX_SYSTEM_ID}.toml"
|
||||
|
||||
|
||||
def default_data_dir() -> Path:
|
||||
return repo_root() / "data" / "plugins" / "a-dawn.a-memorix"
|
||||
|
||||
|
||||
def artifacts_root() -> Path:
|
||||
return default_data_dir() / "artifacts"
|
||||
|
||||
|
||||
def schema_path() -> Path:
|
||||
return package_root() / "config_schema.json"
|
||||
|
||||
|
||||
def web_root() -> Path:
|
||||
return package_root() / "web"
|
||||
|
||||
|
||||
def scripts_root() -> Path:
|
||||
return package_root() / "scripts"
|
||||
|
||||
|
||||
def resolve_repo_path(raw_path: str | Path | None, *, fallback: Path | None = None) -> Path:
|
||||
if raw_path is None:
|
||||
return (fallback or default_data_dir()).resolve()
|
||||
|
||||
raw_value = str(raw_path).strip()
|
||||
if not raw_value:
|
||||
return (fallback or default_data_dir()).resolve()
|
||||
|
||||
candidate = Path(raw_value).expanduser()
|
||||
if candidate.is_absolute():
|
||||
return candidate.resolve()
|
||||
|
||||
return (repo_root() / candidate).resolve()
|
||||
290
src/A_memorix/plugin.py
Normal file
290
src/A_memorix/plugin.py
Normal file
@@ -0,0 +1,290 @@
|
||||
"""Legacy compatibility entry for upstream/plugin-style integrations.
|
||||
|
||||
MaiBot 主线当前通过 `src.A_memorix.host_service` 直接接入 A_Memorix,
|
||||
不再通过插件运行时发现或加载本模块。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from maibot_sdk import MaiBotPlugin, Tool
|
||||
from maibot_sdk.types import ToolParameterInfo, ToolParamType
|
||||
|
||||
from A_memorix.core.runtime.sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel
|
||||
from A_memorix.paths import repo_root
|
||||
|
||||
|
||||
def _tool_param(name: str, param_type: ToolParamType, description: str, required: bool) -> ToolParameterInfo:
|
||||
return ToolParameterInfo(name=name, param_type=param_type, description=description, required=required)
|
||||
|
||||
|
||||
_ADMIN_TOOL_PARAMS = [
|
||||
_tool_param("action", ToolParamType.STRING, "管理动作", True),
|
||||
_tool_param("target", ToolParamType.STRING, "可选目标标识", False),
|
||||
]
|
||||
|
||||
|
||||
class AMemorixPlugin(MaiBotPlugin):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self._plugin_root = repo_root()
|
||||
self._plugin_config: Dict[str, Any] = {}
|
||||
self._kernel: Optional[SDKMemoryKernel] = None
|
||||
|
||||
def set_plugin_config(self, config: Dict[str, Any]) -> None:
|
||||
self._plugin_config = config or {}
|
||||
if self._kernel is not None:
|
||||
self._kernel.close()
|
||||
self._kernel = None
|
||||
|
||||
async def on_load(self):
|
||||
await self._get_kernel()
|
||||
|
||||
async def on_unload(self):
|
||||
if self._kernel is not None:
|
||||
shutdown = getattr(self._kernel, "shutdown", None)
|
||||
if callable(shutdown):
|
||||
await shutdown()
|
||||
else:
|
||||
self._kernel.close()
|
||||
self._kernel = None
|
||||
|
||||
async def on_config_update(self, scope: str, config_data: dict[str, Any], version: str) -> None:
|
||||
_ = version
|
||||
if scope == "self":
|
||||
self.set_plugin_config(config_data if isinstance(config_data, dict) else {})
|
||||
return
|
||||
if scope in {"bot", "model"} and self._kernel is not None:
|
||||
shutdown = getattr(self._kernel, "shutdown", None)
|
||||
if callable(shutdown):
|
||||
await shutdown()
|
||||
else:
|
||||
self._kernel.close()
|
||||
self._kernel = None
|
||||
|
||||
async def _get_kernel(self) -> SDKMemoryKernel:
|
||||
if self._kernel is None:
|
||||
self._kernel = SDKMemoryKernel(plugin_root=self._plugin_root, config=self._plugin_config)
|
||||
await self._kernel.initialize()
|
||||
return self._kernel
|
||||
|
||||
async def _dispatch_admin_tool(self, method_name: str, action: str, **kwargs):
|
||||
kernel = await self._get_kernel()
|
||||
handler = getattr(kernel, method_name)
|
||||
return await handler(action=action, **kwargs)
|
||||
|
||||
@Tool(
|
||||
"search_memory",
|
||||
description="搜索长期记忆",
|
||||
parameters=[
|
||||
_tool_param("query", ToolParamType.STRING, "查询文本", False),
|
||||
_tool_param("limit", ToolParamType.INTEGER, "返回条数", False),
|
||||
_tool_param("mode", ToolParamType.STRING, "search/time/hybrid/episode/aggregate", False),
|
||||
_tool_param("chat_id", ToolParamType.STRING, "聊天流 ID", False),
|
||||
_tool_param("person_id", ToolParamType.STRING, "人物 ID", False),
|
||||
_tool_param("time_start", ToolParamType.FLOAT, "起始时间戳", False),
|
||||
_tool_param("time_end", ToolParamType.FLOAT, "结束时间戳", False),
|
||||
_tool_param("respect_filter", ToolParamType.BOOLEAN, "是否应用聊天过滤配置", False),
|
||||
],
|
||||
)
|
||||
async def handle_search_memory(
|
||||
self,
|
||||
query: str = "",
|
||||
limit: int = 5,
|
||||
mode: str = "search",
|
||||
chat_id: str = "",
|
||||
person_id: str = "",
|
||||
time_start: str | float | None = None,
|
||||
time_end: str | float | None = None,
|
||||
respect_filter: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
kernel = await self._get_kernel()
|
||||
return await kernel.search_memory(
|
||||
KernelSearchRequest(
|
||||
query=query,
|
||||
limit=limit,
|
||||
mode=mode,
|
||||
chat_id=chat_id,
|
||||
person_id=person_id,
|
||||
time_start=time_start,
|
||||
time_end=time_end,
|
||||
respect_filter=respect_filter,
|
||||
user_id=str(kwargs.get("user_id", "") or "").strip(),
|
||||
group_id=str(kwargs.get("group_id", "") or "").strip(),
|
||||
)
|
||||
)
|
||||
|
||||
@Tool(
|
||||
"ingest_summary",
|
||||
description="写入聊天摘要到长期记忆",
|
||||
parameters=[
|
||||
_tool_param("external_id", ToolParamType.STRING, "外部幂等 ID", True),
|
||||
_tool_param("chat_id", ToolParamType.STRING, "聊天流 ID", True),
|
||||
_tool_param("text", ToolParamType.STRING, "摘要文本", True),
|
||||
_tool_param("time_start", ToolParamType.FLOAT, "起始时间戳", False),
|
||||
_tool_param("time_end", ToolParamType.FLOAT, "结束时间戳", False),
|
||||
_tool_param("respect_filter", ToolParamType.BOOLEAN, "是否应用聊天过滤配置", False),
|
||||
],
|
||||
)
|
||||
async def handle_ingest_summary(
|
||||
self,
|
||||
external_id: str,
|
||||
chat_id: str,
|
||||
text: str,
|
||||
participants: Optional[List[str]] = None,
|
||||
time_start: float | None = None,
|
||||
time_end: float | None = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
respect_filter: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
kernel = await self._get_kernel()
|
||||
return await kernel.ingest_summary(
|
||||
external_id=external_id,
|
||||
chat_id=chat_id,
|
||||
text=text,
|
||||
participants=participants,
|
||||
time_start=time_start,
|
||||
time_end=time_end,
|
||||
tags=tags,
|
||||
metadata=metadata,
|
||||
respect_filter=respect_filter,
|
||||
user_id=str(kwargs.get("user_id", "") or "").strip(),
|
||||
group_id=str(kwargs.get("group_id", "") or "").strip(),
|
||||
)
|
||||
|
||||
@Tool(
|
||||
"ingest_text",
|
||||
description="写入普通长期记忆文本",
|
||||
parameters=[
|
||||
_tool_param("external_id", ToolParamType.STRING, "外部幂等 ID", True),
|
||||
_tool_param("source_type", ToolParamType.STRING, "来源类型", True),
|
||||
_tool_param("text", ToolParamType.STRING, "原始文本", True),
|
||||
_tool_param("chat_id", ToolParamType.STRING, "聊天流 ID", False),
|
||||
_tool_param("timestamp", ToolParamType.FLOAT, "时间戳", False),
|
||||
_tool_param("respect_filter", ToolParamType.BOOLEAN, "是否应用聊天过滤配置", False),
|
||||
],
|
||||
)
|
||||
async def handle_ingest_text(
|
||||
self,
|
||||
external_id: str,
|
||||
source_type: str,
|
||||
text: str,
|
||||
chat_id: str = "",
|
||||
person_ids: Optional[List[str]] = None,
|
||||
participants: Optional[List[str]] = None,
|
||||
timestamp: float | None = None,
|
||||
time_start: float | None = None,
|
||||
time_end: float | None = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
respect_filter: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
relations = kwargs.get("relations")
|
||||
entities = kwargs.get("entities")
|
||||
kernel = await self._get_kernel()
|
||||
return await kernel.ingest_text(
|
||||
external_id=external_id,
|
||||
source_type=source_type,
|
||||
text=text,
|
||||
chat_id=chat_id,
|
||||
person_ids=person_ids,
|
||||
participants=participants,
|
||||
timestamp=timestamp,
|
||||
time_start=time_start,
|
||||
time_end=time_end,
|
||||
tags=tags,
|
||||
metadata=metadata,
|
||||
entities=entities,
|
||||
relations=relations,
|
||||
respect_filter=respect_filter,
|
||||
user_id=str(kwargs.get("user_id", "") or "").strip(),
|
||||
group_id=str(kwargs.get("group_id", "") or "").strip(),
|
||||
)
|
||||
|
||||
@Tool(
|
||||
"get_person_profile",
|
||||
description="获取人物画像",
|
||||
parameters=[
|
||||
_tool_param("person_id", ToolParamType.STRING, "人物 ID", True),
|
||||
_tool_param("chat_id", ToolParamType.STRING, "聊天流 ID", False),
|
||||
_tool_param("limit", ToolParamType.INTEGER, "证据条数", False),
|
||||
],
|
||||
)
|
||||
async def handle_get_person_profile(self, person_id: str, chat_id: str = "", limit: int = 10, **kwargs):
|
||||
_ = kwargs
|
||||
kernel = await self._get_kernel()
|
||||
return await kernel.get_person_profile(person_id=person_id, chat_id=chat_id, limit=limit)
|
||||
|
||||
@Tool(
|
||||
"maintain_memory",
|
||||
description="维护长期记忆关系状态",
|
||||
parameters=[
|
||||
_tool_param("action", ToolParamType.STRING, "reinforce/protect/restore/freeze/recycle_bin", True),
|
||||
_tool_param("target", ToolParamType.STRING, "目标哈希或查询文本", False),
|
||||
_tool_param("hours", ToolParamType.FLOAT, "保护时长(小时)", False),
|
||||
_tool_param("limit", ToolParamType.INTEGER, "查询条数(用于 recycle_bin)", False),
|
||||
],
|
||||
)
|
||||
async def handle_maintain_memory(
|
||||
self,
|
||||
action: str,
|
||||
target: str = "",
|
||||
hours: float | None = None,
|
||||
reason: str = "",
|
||||
limit: int = 50,
|
||||
**kwargs,
|
||||
):
|
||||
_ = kwargs
|
||||
kernel = await self._get_kernel()
|
||||
return await kernel.maintain_memory(action=action, target=target, hours=hours, reason=reason, limit=limit)
|
||||
|
||||
@Tool("memory_stats", description="获取长期记忆统计", parameters=[])
|
||||
async def handle_memory_stats(self, **kwargs):
|
||||
_ = kwargs
|
||||
kernel = await self._get_kernel()
|
||||
return kernel.memory_stats()
|
||||
|
||||
@Tool("memory_graph_admin", description="长期记忆图谱管理接口", parameters=_ADMIN_TOOL_PARAMS)
|
||||
async def handle_memory_graph_admin(self, action: str, **kwargs):
|
||||
return await self._dispatch_admin_tool("memory_graph_admin", action=action, **kwargs)
|
||||
|
||||
@Tool("memory_source_admin", description="长期记忆来源管理接口", parameters=_ADMIN_TOOL_PARAMS)
|
||||
async def handle_memory_source_admin(self, action: str, **kwargs):
|
||||
return await self._dispatch_admin_tool("memory_source_admin", action=action, **kwargs)
|
||||
|
||||
@Tool("memory_episode_admin", description="Episode 管理接口", parameters=_ADMIN_TOOL_PARAMS)
|
||||
async def handle_memory_episode_admin(self, action: str, **kwargs):
|
||||
return await self._dispatch_admin_tool("memory_episode_admin", action=action, **kwargs)
|
||||
|
||||
@Tool("memory_profile_admin", description="人物画像管理接口", parameters=_ADMIN_TOOL_PARAMS)
|
||||
async def handle_memory_profile_admin(self, action: str, **kwargs):
|
||||
return await self._dispatch_admin_tool("memory_profile_admin", action=action, **kwargs)
|
||||
|
||||
@Tool("memory_runtime_admin", description="长期记忆运行时管理接口", parameters=_ADMIN_TOOL_PARAMS)
|
||||
async def handle_memory_runtime_admin(self, action: str, **kwargs):
|
||||
return await self._dispatch_admin_tool("memory_runtime_admin", action=action, **kwargs)
|
||||
|
||||
@Tool("memory_import_admin", description="长期记忆导入管理接口", parameters=_ADMIN_TOOL_PARAMS)
|
||||
async def handle_memory_import_admin(self, action: str, **kwargs):
|
||||
return await self._dispatch_admin_tool("memory_import_admin", action=action, **kwargs)
|
||||
|
||||
@Tool("memory_tuning_admin", description="长期记忆调优管理接口", parameters=_ADMIN_TOOL_PARAMS)
|
||||
async def handle_memory_tuning_admin(self, action: str, **kwargs):
|
||||
return await self._dispatch_admin_tool("memory_tuning_admin", action=action, **kwargs)
|
||||
|
||||
@Tool("memory_v5_admin", description="长期记忆 V5 管理接口", parameters=_ADMIN_TOOL_PARAMS)
|
||||
async def handle_memory_v5_admin(self, action: str, **kwargs):
|
||||
return await self._dispatch_admin_tool("memory_v5_admin", action=action, **kwargs)
|
||||
|
||||
@Tool("memory_delete_admin", description="长期记忆删除管理接口", parameters=_ADMIN_TOOL_PARAMS)
|
||||
async def handle_memory_delete_admin(self, action: str, **kwargs):
|
||||
return await self._dispatch_admin_tool("memory_delete_admin", action=action, **kwargs)
|
||||
|
||||
|
||||
def create_plugin():
|
||||
return AMemorixPlugin()
|
||||
52
src/A_memorix/requirements.txt
Normal file
52
src/A_memorix/requirements.txt
Normal file
@@ -0,0 +1,52 @@
|
||||
# A_Memorix 插件依赖
|
||||
#
|
||||
# 核心依赖 (必需)
|
||||
# ==================
|
||||
|
||||
# 数值计算 - 用于向量操作、矩阵计算
|
||||
numpy>=1.20.0
|
||||
|
||||
# 稀疏矩阵 - 用于图存储的邻接矩阵
|
||||
scipy>=1.7.0
|
||||
|
||||
# 图结构处理(LPMM 转换)
|
||||
networkx>=3.0.0
|
||||
|
||||
# Parquet 读取(LPMM 转换)
|
||||
pyarrow>=10.0.0
|
||||
|
||||
# DataFrame 处理(LPMM 转换)
|
||||
pandas>=1.5.0
|
||||
|
||||
# 异步事件循环嵌套 - 用于插件初始化时的异步操作
|
||||
nest-asyncio>=1.5.0
|
||||
|
||||
# 向量索引 - 用于向量存储和检索
|
||||
faiss-cpu>=1.7.0
|
||||
|
||||
# Web 服务器依赖 (可视化功能需要)
|
||||
# ==================
|
||||
|
||||
# ASGI 服务器
|
||||
uvicorn>=0.20.0
|
||||
|
||||
# Web 框架
|
||||
fastapi>=0.100.0
|
||||
|
||||
# 数据验证
|
||||
pydantic>=2.0.0
|
||||
python-multipart>=0.0.9
|
||||
|
||||
# 注意事项
|
||||
# ==================
|
||||
#
|
||||
# 1. sqlite3 是 Python 标准库,无需安装
|
||||
# 2. json, re, time, pathlib 等都是标准库
|
||||
# 3. sentence-transformers 不需要(使用主程序 Embedding API)
|
||||
|
||||
# UI 交互
|
||||
rich>=14.0.0
|
||||
tenacity>=8.0.0
|
||||
|
||||
# 稀疏检索中文分词(可选,未安装时自动回退 char n-gram)
|
||||
jieba>=0.42.1
|
||||
27
src/A_memorix/runtime_registry.py
Normal file
27
src/A_memorix/runtime_registry.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
_runtime_kernel: Any = None
|
||||
|
||||
|
||||
def set_runtime_kernel(kernel: Any | None) -> None:
|
||||
global _runtime_kernel
|
||||
_runtime_kernel = kernel
|
||||
|
||||
|
||||
def get_runtime_kernel() -> Any | None:
|
||||
return _runtime_kernel
|
||||
|
||||
|
||||
def get_runtime_components() -> Dict[str, Any]:
|
||||
kernel = get_runtime_kernel()
|
||||
if kernel is None:
|
||||
return {}
|
||||
return {
|
||||
"vector_store": getattr(kernel, "vector_store", None),
|
||||
"graph_store": getattr(kernel, "graph_store", None),
|
||||
"metadata_store": getattr(kernel, "metadata_store", None),
|
||||
"embedding_manager": getattr(kernel, "embedding_manager", None),
|
||||
"sparse_index": getattr(kernel, "sparse_index", None),
|
||||
}
|
||||
22
src/A_memorix/scripts/_bootstrap.py
Normal file
22
src/A_memorix/scripts/_bootstrap.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
CURRENT_DIR = Path(__file__).resolve().parent
|
||||
PLUGIN_ROOT = CURRENT_DIR.parent
|
||||
SRC_ROOT = PLUGIN_ROOT.parent
|
||||
PROJECT_ROOT = SRC_ROOT.parent
|
||||
WORKSPACE_ROOT = PROJECT_ROOT
|
||||
MAIBOT_ROOT = PROJECT_ROOT
|
||||
|
||||
for _path in (SRC_ROOT, PROJECT_ROOT, PLUGIN_ROOT):
|
||||
_path_str = str(_path)
|
||||
if _path_str not in sys.path:
|
||||
sys.path.insert(0, _path_str)
|
||||
|
||||
from A_memorix.paths import config_path, default_data_dir, resolve_repo_path
|
||||
|
||||
DEFAULT_CONFIG_PATH = config_path()
|
||||
DEFAULT_DATA_DIR = default_data_dir()
|
||||
DEFAULT_DB_PATH = PROJECT_ROOT / "data" / "MaiBot.db"
|
||||
208
src/A_memorix/scripts/audit_vector_consistency.py
Normal file
208
src/A_memorix/scripts/audit_vector_consistency.py
Normal file
@@ -0,0 +1,208 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
A_Memorix 一致性审计脚本。
|
||||
|
||||
输出内容:
|
||||
1. paragraph/entity/relation 向量覆盖率
|
||||
2. relation vector_state 分布
|
||||
3. 孤儿向量数量(向量存在但 metadata 不存在)
|
||||
4. 状态与向量文件不一致统计
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import pickle
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Set
|
||||
|
||||
from _bootstrap import DEFAULT_DATA_DIR, resolve_repo_path
|
||||
|
||||
def _build_arg_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(description="审计 A_Memorix 向量一致性")
|
||||
parser.add_argument(
|
||||
"--data-dir",
|
||||
default=str(DEFAULT_DATA_DIR),
|
||||
help="A_Memorix 数据目录(默认: data/plugins/a-dawn.a-memorix)",
|
||||
)
|
||||
parser.add_argument("--json-out", default="", help="可选:输出 JSON 文件路径")
|
||||
parser.add_argument(
|
||||
"--strict",
|
||||
action="store_true",
|
||||
help="若发现一致性异常则返回非 0 退出码",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
# --help/-h fast path: avoid heavy host/plugin bootstrap
|
||||
if any(arg in {"-h", "--help"} for arg in sys.argv[1:]):
|
||||
_build_arg_parser().print_help()
|
||||
sys.exit(0)
|
||||
|
||||
try:
|
||||
from A_memorix.core.storage.vector_store import VectorStore
|
||||
from A_memorix.core.storage.metadata_store import MetadataStore
|
||||
from A_memorix.core.storage import QuantizationType
|
||||
except Exception as e: # pragma: no cover
|
||||
print(f"❌ 导入核心模块失败: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def _safe_ratio(numerator: int, denominator: int) -> float:
|
||||
if denominator <= 0:
|
||||
return 0.0
|
||||
return float(numerator) / float(denominator)
|
||||
|
||||
|
||||
def _load_vector_store(data_dir: Path) -> VectorStore:
|
||||
meta_path = data_dir / "vectors" / "vectors_metadata.pkl"
|
||||
if not meta_path.exists():
|
||||
raise FileNotFoundError(f"未找到向量元数据文件: {meta_path}")
|
||||
|
||||
with open(meta_path, "rb") as f:
|
||||
meta = pickle.load(f)
|
||||
dimension = int(meta.get("dimension", 1024))
|
||||
|
||||
store = VectorStore(
|
||||
dimension=max(1, dimension),
|
||||
quantization_type=QuantizationType.INT8,
|
||||
data_dir=data_dir / "vectors",
|
||||
)
|
||||
if store.has_data():
|
||||
store.load()
|
||||
return store
|
||||
|
||||
|
||||
def _load_metadata_store(data_dir: Path) -> MetadataStore:
|
||||
store = MetadataStore(data_dir=data_dir / "metadata")
|
||||
store.connect()
|
||||
return store
|
||||
|
||||
|
||||
def _hash_set(metadata_store: MetadataStore, table: str) -> Set[str]:
|
||||
return {str(h) for h in metadata_store.list_hashes(table)}
|
||||
|
||||
|
||||
def _relation_state_stats(metadata_store: MetadataStore) -> Dict[str, int]:
|
||||
return metadata_store.count_relations_by_vector_state()
|
||||
|
||||
|
||||
def run_audit(data_dir: Path) -> Dict[str, Any]:
|
||||
vector_store = _load_vector_store(data_dir)
|
||||
metadata_store = _load_metadata_store(data_dir)
|
||||
try:
|
||||
paragraph_hashes = _hash_set(metadata_store, "paragraphs")
|
||||
entity_hashes = _hash_set(metadata_store, "entities")
|
||||
relation_hashes = _hash_set(metadata_store, "relations")
|
||||
|
||||
known_hashes = set(getattr(vector_store, "_known_hashes", set()))
|
||||
live_vector_hashes = {h for h in known_hashes if h in vector_store}
|
||||
|
||||
para_vector_hits = len(paragraph_hashes & live_vector_hashes)
|
||||
ent_vector_hits = len(entity_hashes & live_vector_hashes)
|
||||
rel_vector_hits = len(relation_hashes & live_vector_hashes)
|
||||
|
||||
orphan_vector_hashes = sorted(
|
||||
live_vector_hashes - paragraph_hashes - entity_hashes - relation_hashes
|
||||
)
|
||||
|
||||
relation_rows = metadata_store.get_relations()
|
||||
ready_but_missing = 0
|
||||
not_ready_but_present = 0
|
||||
for row in relation_rows:
|
||||
h = str(row.get("hash") or "")
|
||||
state = str(row.get("vector_state") or "none").lower()
|
||||
in_vector = h in live_vector_hashes
|
||||
if state == "ready" and not in_vector:
|
||||
ready_but_missing += 1
|
||||
if state != "ready" and in_vector:
|
||||
not_ready_but_present += 1
|
||||
|
||||
relation_states = _relation_state_stats(metadata_store)
|
||||
rel_total = max(0, int(relation_states.get("total", len(relation_hashes))))
|
||||
ready_count = max(0, int(relation_states.get("ready", 0)))
|
||||
|
||||
result = {
|
||||
"counts": {
|
||||
"paragraphs": len(paragraph_hashes),
|
||||
"entities": len(entity_hashes),
|
||||
"relations": len(relation_hashes),
|
||||
"vectors_live": len(live_vector_hashes),
|
||||
},
|
||||
"coverage": {
|
||||
"paragraph_vector_coverage": _safe_ratio(para_vector_hits, len(paragraph_hashes)),
|
||||
"entity_vector_coverage": _safe_ratio(ent_vector_hits, len(entity_hashes)),
|
||||
"relation_vector_coverage": _safe_ratio(rel_vector_hits, len(relation_hashes)),
|
||||
"relation_ready_coverage": _safe_ratio(ready_count, rel_total),
|
||||
},
|
||||
"relation_states": relation_states,
|
||||
"orphans": {
|
||||
"vector_only_count": len(orphan_vector_hashes),
|
||||
"vector_only_sample": orphan_vector_hashes[:30],
|
||||
},
|
||||
"consistency_checks": {
|
||||
"ready_but_missing_vector": ready_but_missing,
|
||||
"not_ready_but_vector_present": not_ready_but_present,
|
||||
},
|
||||
}
|
||||
return result
|
||||
finally:
|
||||
metadata_store.close()
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = _build_arg_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
data_dir = resolve_repo_path(args.data_dir, fallback=DEFAULT_DATA_DIR)
|
||||
if not data_dir.exists():
|
||||
print(f"❌ 数据目录不存在: {data_dir}")
|
||||
return 2
|
||||
|
||||
try:
|
||||
result = run_audit(data_dir)
|
||||
except Exception as e:
|
||||
print(f"❌ 审计失败: {e}")
|
||||
return 2
|
||||
|
||||
print("=== A_Memorix Vector Consistency Audit ===")
|
||||
print(f"data_dir: {data_dir}")
|
||||
print(f"paragraphs: {result['counts']['paragraphs']}")
|
||||
print(f"entities: {result['counts']['entities']}")
|
||||
print(f"relations: {result['counts']['relations']}")
|
||||
print(f"vectors_live: {result['counts']['vectors_live']}")
|
||||
print(
|
||||
"coverage: "
|
||||
f"paragraph={result['coverage']['paragraph_vector_coverage']:.3f}, "
|
||||
f"entity={result['coverage']['entity_vector_coverage']:.3f}, "
|
||||
f"relation={result['coverage']['relation_vector_coverage']:.3f}, "
|
||||
f"relation_ready={result['coverage']['relation_ready_coverage']:.3f}"
|
||||
)
|
||||
print(f"relation_states: {result['relation_states']}")
|
||||
print(
|
||||
"consistency_checks: "
|
||||
f"ready_but_missing_vector={result['consistency_checks']['ready_but_missing_vector']}, "
|
||||
f"not_ready_but_vector_present={result['consistency_checks']['not_ready_but_vector_present']}"
|
||||
)
|
||||
print(f"orphan_vectors: {result['orphans']['vector_only_count']}")
|
||||
|
||||
if args.json_out:
|
||||
out_path = Path(args.json_out).resolve()
|
||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(out_path, "w", encoding="utf-8") as f:
|
||||
json.dump(result, f, ensure_ascii=False, indent=2)
|
||||
print(f"json_out: {out_path}")
|
||||
|
||||
has_anomaly = (
|
||||
result["orphans"]["vector_only_count"] > 0
|
||||
or result["consistency_checks"]["ready_but_missing_vector"] > 0
|
||||
)
|
||||
if args.strict and has_anomaly:
|
||||
return 1
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
265
src/A_memorix/scripts/backfill_relation_vectors.py
Normal file
265
src/A_memorix/scripts/backfill_relation_vectors.py
Normal file
@@ -0,0 +1,265 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
关系向量一次性回填脚本(灰度/离线执行)。
|
||||
|
||||
用途:
|
||||
1. 对 relations 中 vector_state in (none, failed, pending) 的记录补齐向量。
|
||||
2. 支持并发控制,降低总耗时。
|
||||
3. 可作为灰度阶段验证工具,与 audit_vector_consistency.py 配合使用。
|
||||
4. 可选自动纳入“ready 但向量缺失”的漂移记录进行修复。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import tomlkit
|
||||
|
||||
from _bootstrap import DEFAULT_CONFIG_PATH, DEFAULT_DATA_DIR, resolve_repo_path
|
||||
|
||||
def _build_arg_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(description="关系向量一次性回填")
|
||||
parser.add_argument(
|
||||
"--config",
|
||||
default=str(DEFAULT_CONFIG_PATH),
|
||||
help="配置文件路径(默认 config/a_memorix.toml)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--data-dir",
|
||||
default=str(DEFAULT_DATA_DIR),
|
||||
help="数据目录(默认 data/plugins/a-dawn.a-memorix)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--states",
|
||||
default="none,failed,pending",
|
||||
help="待处理状态列表,逗号分隔",
|
||||
)
|
||||
parser.add_argument("--limit", type=int, default=50000, help="最大处理数量")
|
||||
parser.add_argument("--concurrency", type=int, default=8, help="并发数")
|
||||
parser.add_argument("--max-retry", type=int, default=None, help="最大重试次数过滤")
|
||||
parser.add_argument(
|
||||
"--include-ready-missing",
|
||||
action="store_true",
|
||||
help="额外纳入 vector_state=ready 但向量缺失的关系",
|
||||
)
|
||||
parser.add_argument("--dry-run", action="store_true", help="仅统计候选,不写入")
|
||||
return parser
|
||||
|
||||
|
||||
# --help/-h fast path: avoid heavy host/plugin bootstrap
|
||||
if any(arg in {"-h", "--help"} for arg in sys.argv[1:]):
|
||||
_build_arg_parser().print_help()
|
||||
raise SystemExit(0)
|
||||
|
||||
from A_memorix.core.storage import (
|
||||
VectorStore,
|
||||
GraphStore,
|
||||
MetadataStore,
|
||||
QuantizationType,
|
||||
SparseMatrixFormat,
|
||||
)
|
||||
from A_memorix.core.embedding import create_embedding_api_adapter
|
||||
from A_memorix.core.utils.relation_write_service import RelationWriteService
|
||||
|
||||
|
||||
def _load_config(config_path: Path) -> Dict[str, Any]:
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
raw = tomlkit.load(f)
|
||||
return dict(raw) if isinstance(raw, dict) else {}
|
||||
|
||||
|
||||
def _build_vector_store(data_dir: Path, emb_cfg: Dict[str, Any]) -> VectorStore:
|
||||
q_type = str(emb_cfg.get("quantization_type", "int8")).lower()
|
||||
if q_type != "int8":
|
||||
raise ValueError(
|
||||
"embedding.quantization_type 在 vNext 仅允许 int8(SQ8)。"
|
||||
" 请先执行 scripts/release_vnext_migrate.py migrate。"
|
||||
)
|
||||
dim = int(emb_cfg.get("dimension", 1024))
|
||||
store = VectorStore(
|
||||
dimension=max(1, dim),
|
||||
quantization_type=QuantizationType.INT8,
|
||||
data_dir=data_dir / "vectors",
|
||||
)
|
||||
if store.has_data():
|
||||
store.load()
|
||||
return store
|
||||
|
||||
|
||||
def _build_graph_store(data_dir: Path, graph_cfg: Dict[str, Any]) -> GraphStore:
|
||||
fmt = str(graph_cfg.get("sparse_matrix_format", "csr")).lower()
|
||||
fmt_map = {
|
||||
"csr": SparseMatrixFormat.CSR,
|
||||
"csc": SparseMatrixFormat.CSC,
|
||||
}
|
||||
store = GraphStore(
|
||||
matrix_format=fmt_map.get(fmt, SparseMatrixFormat.CSR),
|
||||
data_dir=data_dir / "graph",
|
||||
)
|
||||
if store.has_data():
|
||||
store.load()
|
||||
return store
|
||||
|
||||
|
||||
def _build_metadata_store(data_dir: Path) -> MetadataStore:
|
||||
store = MetadataStore(data_dir=data_dir / "metadata")
|
||||
store.connect()
|
||||
return store
|
||||
|
||||
|
||||
def _build_embedding_manager(emb_cfg: Dict[str, Any]):
|
||||
retry_cfg = emb_cfg.get("retry", {})
|
||||
if not isinstance(retry_cfg, dict):
|
||||
retry_cfg = {}
|
||||
return create_embedding_api_adapter(
|
||||
batch_size=int(emb_cfg.get("batch_size", 32)),
|
||||
max_concurrent=int(emb_cfg.get("max_concurrent", 5)),
|
||||
default_dimension=int(emb_cfg.get("dimension", 1024)),
|
||||
model_name=str(emb_cfg.get("model_name", "auto")),
|
||||
retry_config=retry_cfg,
|
||||
)
|
||||
|
||||
|
||||
async def _process_rows(
|
||||
service: RelationWriteService,
|
||||
rows: List[Dict[str, Any]],
|
||||
concurrency: int,
|
||||
) -> Dict[str, int]:
|
||||
semaphore = asyncio.Semaphore(max(1, int(concurrency)))
|
||||
stat = {"success": 0, "failed": 0, "skipped": 0}
|
||||
|
||||
async def _worker(row: Dict[str, Any]) -> None:
|
||||
async with semaphore:
|
||||
result = await service.ensure_relation_vector(
|
||||
hash_value=str(row["hash"]),
|
||||
subject=str(row.get("subject", "")),
|
||||
predicate=str(row.get("predicate", "")),
|
||||
obj=str(row.get("object", "")),
|
||||
)
|
||||
if result.vector_state == "ready":
|
||||
if result.vector_written:
|
||||
stat["success"] += 1
|
||||
else:
|
||||
stat["skipped"] += 1
|
||||
else:
|
||||
stat["failed"] += 1
|
||||
|
||||
await asyncio.gather(*[_worker(row) for row in rows])
|
||||
return stat
|
||||
|
||||
|
||||
async def main_async(args: argparse.Namespace) -> int:
|
||||
config_path = resolve_repo_path(args.config, fallback=DEFAULT_CONFIG_PATH)
|
||||
if not config_path.exists():
|
||||
print(f"❌ 配置文件不存在: {config_path}")
|
||||
return 2
|
||||
|
||||
cfg = _load_config(config_path)
|
||||
emb_cfg = cfg.get("embedding", {}) if isinstance(cfg, dict) else {}
|
||||
graph_cfg = cfg.get("graph", {}) if isinstance(cfg, dict) else {}
|
||||
retrieval_cfg = cfg.get("retrieval", {}) if isinstance(cfg, dict) else {}
|
||||
rv_cfg = retrieval_cfg.get("relation_vectorization", {}) if isinstance(retrieval_cfg, dict) else {}
|
||||
if not isinstance(emb_cfg, dict):
|
||||
emb_cfg = {}
|
||||
if not isinstance(graph_cfg, dict):
|
||||
graph_cfg = {}
|
||||
if not isinstance(rv_cfg, dict):
|
||||
rv_cfg = {}
|
||||
|
||||
data_dir = resolve_repo_path(args.data_dir, fallback=DEFAULT_DATA_DIR)
|
||||
if not data_dir.exists():
|
||||
print(f"❌ 数据目录不存在: {data_dir}")
|
||||
return 2
|
||||
|
||||
print(f"data_dir: {data_dir}")
|
||||
print(f"config: {config_path}")
|
||||
|
||||
vector_store = _build_vector_store(data_dir, emb_cfg)
|
||||
graph_store = _build_graph_store(data_dir, graph_cfg)
|
||||
metadata_store = _build_metadata_store(data_dir)
|
||||
embedding_manager = _build_embedding_manager(emb_cfg)
|
||||
service = RelationWriteService(
|
||||
metadata_store=metadata_store,
|
||||
graph_store=graph_store,
|
||||
vector_store=vector_store,
|
||||
embedding_manager=embedding_manager,
|
||||
)
|
||||
|
||||
try:
|
||||
states = [s.strip() for s in str(args.states).split(",") if s.strip()]
|
||||
if not states:
|
||||
states = ["none", "failed", "pending"]
|
||||
max_retry = int(args.max_retry) if args.max_retry is not None else int(rv_cfg.get("max_retry", 3))
|
||||
limit = int(args.limit)
|
||||
|
||||
rows = metadata_store.list_relations_by_vector_state(
|
||||
states=states,
|
||||
limit=max(1, limit),
|
||||
max_retry=max(1, max_retry),
|
||||
)
|
||||
added_ready_missing = 0
|
||||
if args.include_ready_missing:
|
||||
ready_rows = metadata_store.list_relations_by_vector_state(
|
||||
states=["ready"],
|
||||
limit=max(1, limit),
|
||||
max_retry=max(1, max_retry),
|
||||
)
|
||||
ready_missing_rows = [
|
||||
row for row in ready_rows if str(row.get("hash", "")) not in vector_store
|
||||
]
|
||||
added_ready_missing = len(ready_missing_rows)
|
||||
if ready_missing_rows:
|
||||
dedup: Dict[str, Dict[str, Any]] = {}
|
||||
for row in rows:
|
||||
dedup[str(row.get("hash", ""))] = row
|
||||
for row in ready_missing_rows:
|
||||
dedup.setdefault(str(row.get("hash", "")), row)
|
||||
rows = list(dedup.values())[: max(1, limit)]
|
||||
print(f"candidates: {len(rows)} (states={states}, max_retry={max_retry})")
|
||||
if args.include_ready_missing:
|
||||
print(f"ready_missing_candidates_added: {added_ready_missing}")
|
||||
if not rows:
|
||||
return 0
|
||||
|
||||
if args.dry_run:
|
||||
print("dry_run=true,未执行写入。")
|
||||
return 0
|
||||
|
||||
started = time.time()
|
||||
stat = await _process_rows(
|
||||
service=service,
|
||||
rows=rows,
|
||||
concurrency=int(args.concurrency),
|
||||
)
|
||||
elapsed = (time.time() - started) * 1000.0
|
||||
|
||||
vector_store.save()
|
||||
graph_store.save()
|
||||
state_stats = metadata_store.count_relations_by_vector_state()
|
||||
output = {
|
||||
"processed": len(rows),
|
||||
"success": int(stat["success"]),
|
||||
"failed": int(stat["failed"]),
|
||||
"skipped": int(stat["skipped"]),
|
||||
"elapsed_ms": elapsed,
|
||||
"state_stats": state_stats,
|
||||
}
|
||||
print(json.dumps(output, ensure_ascii=False, indent=2))
|
||||
return 0 if stat["failed"] == 0 else 1
|
||||
finally:
|
||||
metadata_store.close()
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
return _build_arg_parser().parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
arguments = parse_args()
|
||||
raise SystemExit(asyncio.run(main_async(arguments)))
|
||||
65
src/A_memorix/scripts/backfill_temporal_metadata.py
Normal file
65
src/A_memorix/scripts/backfill_temporal_metadata.py
Normal file
@@ -0,0 +1,65 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
回填段落时序字段。
|
||||
|
||||
默认策略:
|
||||
1. 若段落缺失 event_time/event_time_start/event_time_end
|
||||
2. 且存在 created_at
|
||||
3. 写入 event_time=created_at, time_granularity=day, time_confidence=0.2
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
|
||||
from _bootstrap import DEFAULT_DATA_DIR, resolve_repo_path
|
||||
from A_memorix.core.storage import MetadataStore # noqa: E402
|
||||
|
||||
|
||||
def backfill(
|
||||
data_dir: Path,
|
||||
dry_run: bool,
|
||||
limit: int,
|
||||
no_created_fallback: bool,
|
||||
) -> int:
|
||||
store = MetadataStore(data_dir=data_dir)
|
||||
store.connect()
|
||||
summary = store.backfill_temporal_metadata_from_created_at(
|
||||
limit=limit,
|
||||
dry_run=dry_run,
|
||||
no_created_fallback=no_created_fallback,
|
||||
)
|
||||
store.close()
|
||||
if dry_run:
|
||||
print(f"[dry-run] candidates={summary['candidates']}")
|
||||
return int(summary["candidates"])
|
||||
if no_created_fallback:
|
||||
print(f"skip update (no-created-fallback), candidates={summary['candidates']}")
|
||||
return 0
|
||||
print(f"updated={summary['updated']}")
|
||||
return int(summary["updated"])
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser(description="Backfill temporal metadata for A_Memorix paragraphs")
|
||||
parser.add_argument("--data-dir", default=str(DEFAULT_DATA_DIR), help="数据目录")
|
||||
parser.add_argument("--dry-run", action="store_true", help="仅统计,不写入")
|
||||
parser.add_argument("--limit", type=int, default=100000, help="最大处理条数")
|
||||
parser.add_argument(
|
||||
"--no-created-fallback",
|
||||
action="store_true",
|
||||
help="不使用 created_at 回填,仅输出候选数量",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
backfill(
|
||||
data_dir=resolve_repo_path(args.data_dir, fallback=DEFAULT_DATA_DIR),
|
||||
dry_run=args.dry_run,
|
||||
limit=max(1, int(args.limit)),
|
||||
no_created_fallback=args.no_created_fallback,
|
||||
)
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
530
src/A_memorix/scripts/convert_lpmm.py
Normal file
530
src/A_memorix/scripts/convert_lpmm.py
Normal file
@@ -0,0 +1,530 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
LPMM 到 A_memorix 存储转换器
|
||||
|
||||
功能:
|
||||
1. 读取 LPMM parquet 文件 (paragraph.parquet, entity.parquet, relation.parquet)
|
||||
2. 读取 LPMM 图文件 (graph.graphml 或 graph_structure.pkl)
|
||||
3. 直接写入 A_memorix 二进制 VectorStore 和稀疏 GraphStore
|
||||
4. 绕过 Embedding 生成以节省 Token
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import json
|
||||
import argparse
|
||||
import asyncio
|
||||
import pickle
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, List, Tuple
|
||||
import numpy as np
|
||||
import tomlkit
|
||||
|
||||
from _bootstrap import DEFAULT_CONFIG_PATH, resolve_repo_path
|
||||
|
||||
def _build_arg_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(description="将 LPMM 数据转换为 A_memorix 格式")
|
||||
parser.add_argument("--input", "-i", required=True, help="包含 LPMM 数据的输入目录 (parquet, graphml)")
|
||||
parser.add_argument("--output", "-o", required=True, help="A_memorix 数据的输出目录")
|
||||
parser.add_argument("--dim", type=int, default=384, help="Embedding 维度 (必须与 LPMM 模型匹配)")
|
||||
parser.add_argument("--batch-size", type=int, default=1024, help="Parquet 分批读取大小 (默认 1024)")
|
||||
parser.add_argument(
|
||||
"--skip-relation-vector-rebuild",
|
||||
action="store_true",
|
||||
help="跳过按关系元数据重建关系向量(默认开启)",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
# --help/-h fast path: avoid heavy host/plugin bootstrap
|
||||
if any(arg in {"-h", "--help"} for arg in sys.argv[1:]):
|
||||
_build_arg_parser().print_help()
|
||||
sys.exit(0)
|
||||
|
||||
# 设置日志:优先复用 MaiBot 统一日志体系,失败时回退到标准 logging。
|
||||
try:
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("A_Memorix.LPMMConverter")
|
||||
except Exception:
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
logger = logging.getLogger("A_Memorix.LPMMConverter")
|
||||
|
||||
try:
|
||||
import networkx as nx
|
||||
from scipy import sparse
|
||||
import pyarrow.parquet as pq
|
||||
except ImportError as e:
|
||||
logger.error(f"缺少依赖: {e}")
|
||||
logger.error("请安装: pip install pandas pyarrow networkx scipy")
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
from A_memorix.core.storage.vector_store import VectorStore
|
||||
from A_memorix.core.storage.graph_store import GraphStore
|
||||
from A_memorix.core.storage.metadata_store import MetadataStore
|
||||
from A_memorix.core.storage import QuantizationType, SparseMatrixFormat
|
||||
from A_memorix.core.embedding import create_embedding_api_adapter
|
||||
from A_memorix.core.utils.relation_write_service import RelationWriteService
|
||||
except ImportError as e:
|
||||
logger.error(f"无法导入 A_memorix 核心模块: {e}")
|
||||
logger.error("请确保在正确的环境中运行,且已安装所有依赖。")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
class LPMMConverter:
|
||||
def __init__(
|
||||
self,
|
||||
lpmm_data_dir: Path,
|
||||
output_dir: Path,
|
||||
dimension: int = 384,
|
||||
batch_size: int = 1024,
|
||||
rebuild_relation_vectors: bool = True,
|
||||
):
|
||||
self.lpmm_dir = lpmm_data_dir
|
||||
self.output_dir = output_dir
|
||||
self.dimension = dimension
|
||||
self.batch_size = max(1, int(batch_size))
|
||||
self.rebuild_relation_vectors = bool(rebuild_relation_vectors)
|
||||
|
||||
self.vector_dir = output_dir / "vectors"
|
||||
self.graph_dir = output_dir / "graph"
|
||||
self.metadata_dir = output_dir / "metadata"
|
||||
|
||||
self.vector_store = None
|
||||
self.graph_store = None
|
||||
self.metadata_store = None
|
||||
self.embedding_manager = None
|
||||
self.relation_write_service = None
|
||||
# LPMM 原 ID -> A_memorix ID 映射(用于图重写)
|
||||
self.id_mapping: Dict[str, str] = {}
|
||||
|
||||
def _register_id_mapping(self, raw_id: Any, mapped_id: str, p_type: str) -> None:
|
||||
"""记录 ID 映射,兼容带/不带类型前缀两种格式。"""
|
||||
if raw_id is None:
|
||||
return
|
||||
|
||||
raw = str(raw_id).strip()
|
||||
if not raw:
|
||||
return
|
||||
|
||||
self.id_mapping[raw] = mapped_id
|
||||
|
||||
prefix = f"{p_type}-"
|
||||
if raw.startswith(prefix):
|
||||
self.id_mapping[raw[len(prefix):]] = mapped_id
|
||||
else:
|
||||
self.id_mapping[prefix + raw] = mapped_id
|
||||
|
||||
def _map_node_id(self, node: Any) -> str:
|
||||
"""将图节点 ID 映射到转换后的 A_memorix ID。"""
|
||||
node_key = str(node)
|
||||
return self.id_mapping.get(node_key, node_key)
|
||||
|
||||
def initialize_stores(self):
|
||||
"""初始化空的 A_memorix 存储"""
|
||||
logger.info(f"正在初始化存储于 {self.output_dir}...")
|
||||
|
||||
# 初始化 VectorStore (A_memorix 默认使用 INT8 量化)
|
||||
self.vector_store = VectorStore(
|
||||
dimension=self.dimension,
|
||||
quantization_type=QuantizationType.INT8,
|
||||
data_dir=self.vector_dir
|
||||
)
|
||||
self.vector_store.clear() # 清空旧数据
|
||||
|
||||
# 初始化 GraphStore (使用 CSR 格式)
|
||||
self.graph_store = GraphStore(
|
||||
matrix_format=SparseMatrixFormat.CSR,
|
||||
data_dir=self.graph_dir
|
||||
)
|
||||
self.graph_store.clear()
|
||||
|
||||
# 初始化 MetadataStore
|
||||
self.metadata_store = MetadataStore(data_dir=self.metadata_dir)
|
||||
self.metadata_store.connect()
|
||||
# 清空元数据表?理想情况下是的,但要小心。
|
||||
# 对于转换,我们假设是全新的开始或覆盖。
|
||||
# A_memorix 中的 MetadataStore 通常使用 SQLite。
|
||||
# 如果目录是新的,我们会依赖它创建新文件。
|
||||
if self.rebuild_relation_vectors:
|
||||
self._init_relation_vector_service()
|
||||
|
||||
def _load_plugin_config(self) -> Dict[str, Any]:
|
||||
config_path = DEFAULT_CONFIG_PATH
|
||||
if not config_path.exists():
|
||||
return {}
|
||||
try:
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
parsed = tomlkit.load(f)
|
||||
return dict(parsed) if isinstance(parsed, dict) else {}
|
||||
except Exception as e:
|
||||
logger.warning(f"读取 config.toml 失败,使用默认 embedding 配置: {e}")
|
||||
return {}
|
||||
|
||||
def _init_relation_vector_service(self) -> None:
|
||||
if not self.rebuild_relation_vectors:
|
||||
return
|
||||
cfg = self._load_plugin_config()
|
||||
emb_cfg = cfg.get("embedding", {}) if isinstance(cfg, dict) else {}
|
||||
if not isinstance(emb_cfg, dict):
|
||||
emb_cfg = {}
|
||||
try:
|
||||
self.embedding_manager = create_embedding_api_adapter(
|
||||
batch_size=int(emb_cfg.get("batch_size", 32)),
|
||||
max_concurrent=int(emb_cfg.get("max_concurrent", 5)),
|
||||
default_dimension=int(emb_cfg.get("dimension", self.dimension)),
|
||||
model_name=str(emb_cfg.get("model_name", "auto")),
|
||||
retry_config=emb_cfg.get("retry", {}) if isinstance(emb_cfg.get("retry", {}), dict) else {},
|
||||
)
|
||||
self.relation_write_service = RelationWriteService(
|
||||
metadata_store=self.metadata_store,
|
||||
graph_store=self.graph_store,
|
||||
vector_store=self.vector_store,
|
||||
embedding_manager=self.embedding_manager,
|
||||
)
|
||||
except Exception as e:
|
||||
self.embedding_manager = None
|
||||
self.relation_write_service = None
|
||||
logger.warning(f"初始化关系向量重建服务失败,将跳过关系向量回填: {e}")
|
||||
|
||||
async def _rebuild_relation_vectors(self) -> None:
|
||||
if not self.rebuild_relation_vectors:
|
||||
return
|
||||
if self.relation_write_service is None:
|
||||
logger.warning("关系向量重建已启用,但写入服务不可用,已跳过。")
|
||||
return
|
||||
|
||||
rows = self.metadata_store.get_relations()
|
||||
if not rows:
|
||||
logger.info("未发现关系元数据,无需重建关系向量。")
|
||||
return
|
||||
|
||||
success = 0
|
||||
failed = 0
|
||||
skipped = 0
|
||||
for row in rows:
|
||||
result = await self.relation_write_service.ensure_relation_vector(
|
||||
hash_value=str(row["hash"]),
|
||||
subject=str(row.get("subject", "")),
|
||||
predicate=str(row.get("predicate", "")),
|
||||
obj=str(row.get("object", "")),
|
||||
)
|
||||
if result.vector_state == "ready":
|
||||
if result.vector_written:
|
||||
success += 1
|
||||
else:
|
||||
skipped += 1
|
||||
else:
|
||||
failed += 1
|
||||
|
||||
logger.info(
|
||||
"关系向量重建完成: "
|
||||
f"total={len(rows)} "
|
||||
f"success={success} "
|
||||
f"skipped={skipped} "
|
||||
f"failed={failed}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _parse_relation_text(text: str) -> Tuple[str, str, str]:
|
||||
raw = str(text or "").strip()
|
||||
if not raw:
|
||||
return "", "", ""
|
||||
if "|" in raw:
|
||||
parts = [p.strip() for p in raw.split("|") if p.strip()]
|
||||
if len(parts) >= 3:
|
||||
return parts[0], parts[1], parts[2]
|
||||
if "->" in raw:
|
||||
parts = [p.strip() for p in raw.split("->") if p.strip()]
|
||||
if len(parts) >= 3:
|
||||
return parts[0], parts[1], parts[2]
|
||||
pieces = raw.split()
|
||||
if len(pieces) >= 3:
|
||||
return pieces[0], pieces[1], " ".join(pieces[2:])
|
||||
return "", "", ""
|
||||
|
||||
def _import_relation_metadata_from_parquet(self, relation_path: Path) -> int:
|
||||
if not relation_path.exists():
|
||||
return 0
|
||||
|
||||
try:
|
||||
parquet_file = pq.ParquetFile(relation_path)
|
||||
except Exception as e:
|
||||
logger.warning(f"读取 relation.parquet 失败,跳过关系元数据导入: {e}")
|
||||
return 0
|
||||
|
||||
cols = set(parquet_file.schema_arrow.names)
|
||||
has_triple_cols = {"subject", "predicate", "object"}.issubset(cols)
|
||||
content_col = "str" if "str" in cols else ("content" if "content" in cols else "")
|
||||
|
||||
imported_hashes = set()
|
||||
imported = 0
|
||||
for record_batch in parquet_file.iter_batches(batch_size=self.batch_size):
|
||||
df_batch = record_batch.to_pandas()
|
||||
for _, row in df_batch.iterrows():
|
||||
subject = ""
|
||||
predicate = ""
|
||||
obj = ""
|
||||
if has_triple_cols:
|
||||
subject = str(row.get("subject", "") or "").strip()
|
||||
predicate = str(row.get("predicate", "") or "").strip()
|
||||
obj = str(row.get("object", "") or "").strip()
|
||||
elif content_col:
|
||||
subject, predicate, obj = self._parse_relation_text(row.get(content_col, ""))
|
||||
|
||||
if not (subject and predicate and obj):
|
||||
continue
|
||||
|
||||
rel_hash = self.metadata_store.add_relation(
|
||||
subject=subject,
|
||||
predicate=predicate,
|
||||
obj=obj,
|
||||
source_paragraph=None,
|
||||
)
|
||||
if rel_hash in imported_hashes:
|
||||
continue
|
||||
imported_hashes.add(rel_hash)
|
||||
self.graph_store.add_edges([(subject, obj)], relation_hashes=[rel_hash])
|
||||
try:
|
||||
self.metadata_store.set_relation_vector_state(rel_hash, "none")
|
||||
except Exception:
|
||||
pass
|
||||
imported += 1
|
||||
|
||||
return imported
|
||||
|
||||
def convert_vectors(self):
|
||||
"""将 Parquet 向量转换为 VectorStore"""
|
||||
# LPMM 默认文件名
|
||||
parquet_files = {
|
||||
"paragraph": self.lpmm_dir / "paragraph.parquet",
|
||||
"entity": self.lpmm_dir / "entity.parquet",
|
||||
"relation": self.lpmm_dir / "relation.parquet"
|
||||
}
|
||||
|
||||
total_vectors = 0
|
||||
|
||||
for p_type, p_path in parquet_files.items():
|
||||
# 关系向量在当前脚本中无法保证与 MetadataStore 的关系记录一一对应,
|
||||
# 直接导入会污染召回结果(命中后无法反查 relation 元数据)。
|
||||
if p_type == "relation":
|
||||
relation_count = self._import_relation_metadata_from_parquet(p_path)
|
||||
logger.warning(
|
||||
"跳过 relation.parquet 向量导入(保持一致性);"
|
||||
f"已导入关系元数据: {relation_count}"
|
||||
)
|
||||
continue
|
||||
|
||||
if not p_path.exists():
|
||||
logger.warning(f"文件未找到: {p_path}, 跳过 {p_type} 向量。")
|
||||
continue
|
||||
|
||||
logger.info(f"正在处理 {p_type} 向量,来源: {p_path}...")
|
||||
try:
|
||||
parquet_file = pq.ParquetFile(p_path)
|
||||
total_rows = parquet_file.metadata.num_rows
|
||||
if total_rows == 0:
|
||||
logger.info(f"{p_path} 为空,跳过。")
|
||||
continue
|
||||
|
||||
# LPMM Schema: 'hash', 'embedding', 'str'
|
||||
cols = parquet_file.schema_arrow.names
|
||||
# 兼容性检查
|
||||
content_col = 'str' if 'str' in cols else 'content'
|
||||
emb_col = 'embedding'
|
||||
hash_col = 'hash'
|
||||
|
||||
if content_col not in cols or emb_col not in cols:
|
||||
logger.error(f"{p_path} 中缺少必要列 (需包含 {content_col}, {emb_col})。发现: {cols}")
|
||||
continue
|
||||
|
||||
batch_columns = [content_col, emb_col]
|
||||
if hash_col in cols:
|
||||
batch_columns.append(hash_col)
|
||||
|
||||
processed_rows = 0
|
||||
added_for_type = 0
|
||||
batch_idx = 0
|
||||
|
||||
for record_batch in parquet_file.iter_batches(
|
||||
batch_size=self.batch_size,
|
||||
columns=batch_columns,
|
||||
):
|
||||
batch_idx += 1
|
||||
df_batch = record_batch.to_pandas()
|
||||
|
||||
embeddings_list = []
|
||||
ids_list = []
|
||||
|
||||
# 同时处理元数据映射
|
||||
for _, row in df_batch.iterrows():
|
||||
processed_rows += 1
|
||||
content = row[content_col]
|
||||
emb = row[emb_col]
|
||||
|
||||
if content is None or (isinstance(content, float) and np.isnan(content)):
|
||||
continue
|
||||
content = str(content).strip()
|
||||
if not content:
|
||||
continue
|
||||
|
||||
if emb is None or len(emb) == 0:
|
||||
continue
|
||||
|
||||
# 先写 MetadataStore,并使用其返回的真实 hash 作为向量 ID
|
||||
# 保证检索返回 ID 可以直接反查元数据。
|
||||
store_id = None
|
||||
if p_type == "paragraph":
|
||||
store_id = self.metadata_store.add_paragraph(
|
||||
content=content,
|
||||
source="lpmm_import",
|
||||
knowledge_type="factual",
|
||||
)
|
||||
elif p_type == "entity":
|
||||
store_id = self.metadata_store.add_entity(name=content)
|
||||
else:
|
||||
continue
|
||||
|
||||
raw_hash = row[hash_col] if hash_col in df_batch.columns else None
|
||||
if raw_hash is not None and not (isinstance(raw_hash, float) and np.isnan(raw_hash)):
|
||||
self._register_id_mapping(raw_hash, store_id, p_type)
|
||||
|
||||
# 确保 embedding 是 numpy 数组
|
||||
emb_np = np.array(emb, dtype=np.float32)
|
||||
if emb_np.shape[0] != self.dimension:
|
||||
logger.error(f"维度不匹配: {emb_np.shape[0]} vs {self.dimension}")
|
||||
continue
|
||||
|
||||
embeddings_list.append(emb_np)
|
||||
ids_list.append(store_id)
|
||||
|
||||
if embeddings_list:
|
||||
# 分批添加到向量存储
|
||||
vectors_np = np.stack(embeddings_list)
|
||||
count = self.vector_store.add(vectors_np, ids_list)
|
||||
added_for_type += count
|
||||
total_vectors += count
|
||||
|
||||
if batch_idx == 1 or batch_idx % 10 == 0:
|
||||
logger.info(
|
||||
f"[{p_type}] 批次 {batch_idx}: 已扫描 {processed_rows}/{total_rows}, 已导入 {added_for_type}"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"{p_type} 向量处理完成:总扫描 {processed_rows},总导入 {added_for_type}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理 {p_path} 时出错: {e}")
|
||||
|
||||
# 提交向量存储
|
||||
self.vector_store.save()
|
||||
logger.info(f"向量转换完成。总向量数: {total_vectors}")
|
||||
|
||||
def convert_graph(self):
|
||||
"""将 LPMM 图转换为 GraphStore"""
|
||||
# LPMM 默认文件名是 rag-graph.graphml
|
||||
graph_files = [
|
||||
self.lpmm_dir / "rag-graph.graphml",
|
||||
self.lpmm_dir / "graph.graphml",
|
||||
self.lpmm_dir / "graph_structure.pkl"
|
||||
]
|
||||
|
||||
nx_graph = None
|
||||
|
||||
for g_path in graph_files:
|
||||
if g_path.exists():
|
||||
logger.info(f"发现图文件: {g_path}")
|
||||
try:
|
||||
if g_path.suffix == ".graphml":
|
||||
nx_graph = nx.read_graphml(g_path)
|
||||
elif g_path.suffix == ".pkl":
|
||||
with open(g_path, "rb") as f:
|
||||
data = pickle.load(f)
|
||||
# LPMM 可能会将图存储在包装类中
|
||||
if hasattr(data, "graph") and isinstance(data.graph, nx.Graph):
|
||||
nx_graph = data.graph
|
||||
elif isinstance(data, nx.Graph):
|
||||
nx_graph = data
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"加载 {g_path} 失败: {e}")
|
||||
|
||||
if nx_graph is None:
|
||||
logger.warning("未找到有效的图文件。跳过图转换。")
|
||||
return
|
||||
|
||||
logger.info(f"已加载图,包含 {nx_graph.number_of_nodes()} 个节点和 {nx_graph.number_of_edges()} 条边。")
|
||||
|
||||
# 1. 添加节点
|
||||
# LPMM 节点通常是哈希或带前缀的字符串。
|
||||
# 我们需要将它们映射到 A_memorix 格式。
|
||||
# 如果 LPMM 使用 "entity-HASH",则与 A_memorix 匹配。
|
||||
|
||||
nodes_to_add = []
|
||||
node_attrs = {}
|
||||
|
||||
for node, attrs in nx_graph.nodes(data=True):
|
||||
# 假设 LPMM 使用一致的命名 "entity-..." 或 "paragraph-..."
|
||||
mapped_node = self._map_node_id(node)
|
||||
nodes_to_add.append(mapped_node)
|
||||
if attrs:
|
||||
node_attrs[mapped_node] = attrs
|
||||
|
||||
self.graph_store.add_nodes(nodes_to_add, node_attrs)
|
||||
|
||||
# 2. 添加边
|
||||
edges_to_add = []
|
||||
weights = []
|
||||
|
||||
for u, v, data in nx_graph.edges(data=True):
|
||||
weight = data.get("weight", 1.0)
|
||||
edges_to_add.append((self._map_node_id(u), self._map_node_id(v)))
|
||||
weights.append(float(weight))
|
||||
|
||||
# 如果可能,将关系同步到 MetadataStore
|
||||
# 但图的边并不总是包含关系谓词
|
||||
# 如果 LPMM 边数据有 'predicate',我们可以添加到元数据
|
||||
# 通常 LPMM 边是加权和,谓词信息可能在简单图中丢失
|
||||
|
||||
if edges_to_add:
|
||||
self.graph_store.add_edges(edges_to_add, weights)
|
||||
|
||||
self.graph_store.save()
|
||||
logger.info("图转换完成。")
|
||||
|
||||
def run(self):
|
||||
self.initialize_stores()
|
||||
self.convert_vectors()
|
||||
self.convert_graph()
|
||||
asyncio.run(self._rebuild_relation_vectors())
|
||||
self.vector_store.save()
|
||||
self.graph_store.save()
|
||||
self.metadata_store.close()
|
||||
logger.info("所有转换成功完成。")
|
||||
|
||||
|
||||
def main():
|
||||
parser = _build_arg_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
input_path = resolve_repo_path(args.input)
|
||||
output_path = resolve_repo_path(args.output)
|
||||
|
||||
if not input_path.exists():
|
||||
logger.error(f"输入目录不存在: {input_path}")
|
||||
sys.exit(1)
|
||||
|
||||
converter = LPMMConverter(
|
||||
input_path,
|
||||
output_path,
|
||||
dimension=args.dim,
|
||||
batch_size=args.batch_size,
|
||||
rebuild_relation_vectors=not bool(args.skip_relation_vector_rebuild),
|
||||
)
|
||||
converter.run()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
165
src/A_memorix/scripts/import_lpmm_json.py
Normal file
165
src/A_memorix/scripts/import_lpmm_json.py
Normal file
@@ -0,0 +1,165 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
LPMM OpenIE JSON 导入工具。
|
||||
|
||||
功能:
|
||||
1. 读取符合 LPMM 规范的 OpenIE JSON 文件
|
||||
2. 转换为 A_Memorix 的统一导入格式
|
||||
3. 复用 `process_knowledge.py` 中的 `AutoImporter` 直接入库
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from rich.console import Console
|
||||
from rich.progress import BarColumn, Progress, SpinnerColumn, TextColumn, TimeElapsedColumn
|
||||
|
||||
console = Console()
|
||||
|
||||
import _bootstrap # noqa: F401
|
||||
|
||||
|
||||
def _build_arg_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(description="将 LPMM OpenIE JSON 导入 A_Memorix")
|
||||
parser.add_argument("path", help="LPMM JSON 文件路径或目录")
|
||||
parser.add_argument("--force", action="store_true", help="强制重新导入")
|
||||
parser.add_argument("--concurrency", "-c", type=int, default=5, help="并发数")
|
||||
return parser
|
||||
|
||||
|
||||
if any(arg in {"-h", "--help"} for arg in sys.argv[1:]):
|
||||
_build_arg_parser().print_help()
|
||||
raise SystemExit(0)
|
||||
|
||||
|
||||
try:
|
||||
from process_knowledge import AutoImporter
|
||||
from A_memorix.core.utils.hash import compute_paragraph_hash
|
||||
from src.common.logger import get_logger
|
||||
except ImportError as exc: # pragma: no cover - script bootstrap
|
||||
print(f"导入模块失败,请确认 PYTHONPATH 与工作区结构: {exc}")
|
||||
raise SystemExit(1)
|
||||
|
||||
|
||||
logger = get_logger("A_Memorix.LPMMImport")
|
||||
|
||||
|
||||
class LPMMConverter:
|
||||
def convert_lpmm_to_memorix(self, lpmm_data: Dict[str, Any], filename: str) -> Dict[str, Any]:
|
||||
memorix_data = {"paragraphs": [], "entities": []}
|
||||
docs = lpmm_data.get("docs", []) or []
|
||||
if not docs:
|
||||
logger.warning(f"文件中未找到 docs 字段: {filename}")
|
||||
return memorix_data
|
||||
|
||||
all_entities = set()
|
||||
for doc in docs:
|
||||
content = str(doc.get("passage", "") or "").strip()
|
||||
if not content:
|
||||
continue
|
||||
|
||||
relations: List[Dict[str, str]] = []
|
||||
for triple in doc.get("extracted_triples", []) or []:
|
||||
if isinstance(triple, list) and len(triple) == 3:
|
||||
relations.append(
|
||||
{
|
||||
"subject": str(triple[0] or "").strip(),
|
||||
"predicate": str(triple[1] or "").strip(),
|
||||
"object": str(triple[2] or "").strip(),
|
||||
}
|
||||
)
|
||||
|
||||
entities = [str(item or "").strip() for item in doc.get("extracted_entities", []) or [] if str(item or "").strip()]
|
||||
all_entities.update(entities)
|
||||
for relation in relations:
|
||||
if relation["subject"]:
|
||||
all_entities.add(relation["subject"])
|
||||
if relation["object"]:
|
||||
all_entities.add(relation["object"])
|
||||
|
||||
memorix_data["paragraphs"].append(
|
||||
{
|
||||
"hash": compute_paragraph_hash(content),
|
||||
"content": content,
|
||||
"source": filename,
|
||||
"entities": entities,
|
||||
"relations": relations,
|
||||
}
|
||||
)
|
||||
|
||||
memorix_data["entities"] = sorted(all_entities)
|
||||
return memorix_data
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
parser = _build_arg_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
target_path = Path(args.path)
|
||||
if not target_path.exists():
|
||||
logger.error(f"路径不存在: {target_path}")
|
||||
return
|
||||
|
||||
if target_path.is_dir():
|
||||
files_to_process = list(target_path.glob("*-openie.json")) or list(target_path.glob("*.json"))
|
||||
else:
|
||||
files_to_process = [target_path]
|
||||
|
||||
if not files_to_process:
|
||||
logger.error("未找到可处理的 JSON 文件")
|
||||
return
|
||||
|
||||
importer = AutoImporter(force=bool(args.force), concurrency=int(args.concurrency))
|
||||
if not await importer.initialize():
|
||||
logger.error("初始化存储失败")
|
||||
return
|
||||
|
||||
converter = LPMMConverter()
|
||||
with Progress(
|
||||
SpinnerColumn(),
|
||||
TextColumn("[progress.description]{task.description}"),
|
||||
BarColumn(),
|
||||
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
|
||||
TimeElapsedColumn(),
|
||||
console=console,
|
||||
transient=False,
|
||||
) as progress:
|
||||
for json_file in files_to_process:
|
||||
logger.info(f"正在转换并导入: {json_file.name}")
|
||||
try:
|
||||
with open(json_file, "r", encoding="utf-8") as handle:
|
||||
lpmm_data = json.load(handle)
|
||||
memorix_data = converter.convert_lpmm_to_memorix(lpmm_data, json_file.name)
|
||||
total_items = len(memorix_data.get("paragraphs", []))
|
||||
if total_items <= 0:
|
||||
logger.warning(f"转换结果为空: {json_file.name}")
|
||||
continue
|
||||
|
||||
task_id = progress.add_task(f"Importing {json_file.name}", total=total_items)
|
||||
|
||||
def update_progress(step: int = 1) -> None:
|
||||
progress.advance(task_id, advance=step)
|
||||
|
||||
await importer.import_json_data(
|
||||
memorix_data,
|
||||
filename=f"lpmm_{json_file.name}",
|
||||
progress_callback=update_progress,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"处理文件 {json_file.name} 失败: {exc}\n{traceback.format_exc()}")
|
||||
|
||||
await importer.close()
|
||||
logger.info("全部处理完成")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if sys.platform == "win32": # pragma: no cover
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||
asyncio.run(main())
|
||||
99
src/A_memorix/scripts/migrate_chat_history.py
Normal file
99
src/A_memorix/scripts/migrate_chat_history.py
Normal file
@@ -0,0 +1,99 @@
|
||||
#!/usr/bin/env python3
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import sqlite3
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict
|
||||
|
||||
from _bootstrap import DEFAULT_DATA_DIR, DEFAULT_DB_PATH, PLUGIN_ROOT, resolve_repo_path
|
||||
|
||||
from A_memorix.core.runtime.sdk_memory_kernel import SDKMemoryKernel # noqa: E402
|
||||
|
||||
|
||||
def _parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description="迁移 MaiBot chat_history 到 A_Memorix")
|
||||
parser.add_argument("--db-path", default=str(DEFAULT_DB_PATH), help="MaiBot SQLite 路径")
|
||||
parser.add_argument("--data-dir", default=str(DEFAULT_DATA_DIR), help="A_Memorix 数据目录")
|
||||
parser.add_argument("--limit", type=int, default=0, help="限制迁移条数,0 表示全部")
|
||||
parser.add_argument("--dry-run", action="store_true", help="仅预览,不写入")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def _to_timestamp(value: Any) -> float | None:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, (int, float)):
|
||||
return float(value)
|
||||
text = str(value).strip()
|
||||
if not text:
|
||||
return None
|
||||
try:
|
||||
return datetime.fromisoformat(text).timestamp()
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
async def _main() -> int:
|
||||
args = _parse_args()
|
||||
db_path = resolve_repo_path(args.db_path, fallback=DEFAULT_DB_PATH)
|
||||
if not db_path.exists():
|
||||
print(f"数据库不存在: {db_path}")
|
||||
return 1
|
||||
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
conn.row_factory = sqlite3.Row
|
||||
sql = """
|
||||
SELECT id, session_id, start_timestamp, end_timestamp, participants, theme, keywords, summary
|
||||
FROM chat_history
|
||||
ORDER BY id ASC
|
||||
"""
|
||||
if int(args.limit or 0) > 0:
|
||||
sql += " LIMIT ?"
|
||||
rows = conn.execute(sql, (int(args.limit),)).fetchall()
|
||||
else:
|
||||
rows = conn.execute(sql).fetchall()
|
||||
conn.close()
|
||||
|
||||
print(f"chat_history 待处理: {len(rows)}")
|
||||
if args.dry_run:
|
||||
for row in rows[:5]:
|
||||
print(f"- id={row['id']} session={row['session_id']} theme={row['theme']}")
|
||||
return 0
|
||||
|
||||
data_dir = resolve_repo_path(args.data_dir, fallback=DEFAULT_DATA_DIR)
|
||||
kernel = SDKMemoryKernel(plugin_root=PLUGIN_ROOT, config={"storage": {"data_dir": str(data_dir)}})
|
||||
await kernel.initialize()
|
||||
migrated = 0
|
||||
skipped = 0
|
||||
for row in rows:
|
||||
participants = json.loads(row["participants"]) if row["participants"] else []
|
||||
keywords = json.loads(row["keywords"]) if row["keywords"] else []
|
||||
theme = str(row["theme"] or "").strip()
|
||||
summary = str(row["summary"] or "").strip()
|
||||
text = f"主题:{theme}\n概括:{summary}".strip()
|
||||
result: Dict[str, Any] = await kernel.ingest_summary(
|
||||
external_id=f"chat_history:{row['id']}",
|
||||
chat_id=str(row["session_id"] or ""),
|
||||
text=text,
|
||||
participants=participants,
|
||||
time_start=_to_timestamp(row["start_timestamp"]),
|
||||
time_end=_to_timestamp(row["end_timestamp"]),
|
||||
tags=keywords,
|
||||
metadata={"theme": theme, "source_row_id": int(row["id"])},
|
||||
)
|
||||
if result.get("stored_ids"):
|
||||
migrated += 1
|
||||
else:
|
||||
skipped += 1
|
||||
|
||||
print(f"迁移完成: migrated={migrated} skipped={skipped}")
|
||||
print(json.dumps(kernel.memory_stats(), ensure_ascii=False))
|
||||
kernel.close()
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(asyncio.run(_main()))
|
||||
1743
src/A_memorix/scripts/migrate_maibot_memory.py
Normal file
1743
src/A_memorix/scripts/migrate_maibot_memory.py
Normal file
File diff suppressed because it is too large
Load Diff
109
src/A_memorix/scripts/migrate_person_memory_points.py
Normal file
109
src/A_memorix/scripts/migrate_person_memory_points.py
Normal file
@@ -0,0 +1,109 @@
|
||||
#!/usr/bin/env python3
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import sqlite3
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from _bootstrap import DEFAULT_DATA_DIR, DEFAULT_DB_PATH, PLUGIN_ROOT, resolve_repo_path
|
||||
|
||||
from A_memorix.core.runtime.sdk_memory_kernel import SDKMemoryKernel # noqa: E402
|
||||
|
||||
|
||||
def _parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description="迁移 MaiBot person_info.memory_points 到 A_Memorix")
|
||||
parser.add_argument("--db-path", default=str(DEFAULT_DB_PATH), help="MaiBot SQLite 路径")
|
||||
parser.add_argument("--data-dir", default=str(DEFAULT_DATA_DIR), help="A_Memorix 数据目录")
|
||||
parser.add_argument("--limit", type=int, default=0, help="限制迁移人数,0 表示全部")
|
||||
parser.add_argument("--dry-run", action="store_true", help="仅预览,不写入")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def _parse_memory_points(raw_value: Any) -> List[Dict[str, Any]]:
|
||||
try:
|
||||
values = json.loads(raw_value) if raw_value else []
|
||||
except Exception:
|
||||
values = []
|
||||
items: List[Dict[str, Any]] = []
|
||||
for index, item in enumerate(values):
|
||||
text = str(item or "").strip()
|
||||
if not text:
|
||||
continue
|
||||
parts = text.split(":")
|
||||
if len(parts) >= 3:
|
||||
category = parts[0].strip()
|
||||
content = ":".join(parts[1:-1]).strip()
|
||||
weight = parts[-1].strip()
|
||||
else:
|
||||
category = "其他"
|
||||
content = text
|
||||
weight = "1.0"
|
||||
if content:
|
||||
items.append({"index": index, "category": category or "其他", "content": content, "weight": weight or "1.0"})
|
||||
return items
|
||||
|
||||
|
||||
async def _main() -> int:
|
||||
args = _parse_args()
|
||||
db_path = resolve_repo_path(args.db_path, fallback=DEFAULT_DB_PATH)
|
||||
if not db_path.exists():
|
||||
print(f"数据库不存在: {db_path}")
|
||||
return 1
|
||||
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
conn.row_factory = sqlite3.Row
|
||||
sql = """
|
||||
SELECT person_id, person_name, user_nickname, memory_points
|
||||
FROM person_info
|
||||
WHERE memory_points IS NOT NULL AND memory_points != ''
|
||||
ORDER BY id ASC
|
||||
"""
|
||||
if int(args.limit or 0) > 0:
|
||||
sql += " LIMIT ?"
|
||||
rows = conn.execute(sql, (int(args.limit),)).fetchall()
|
||||
else:
|
||||
rows = conn.execute(sql).fetchall()
|
||||
conn.close()
|
||||
|
||||
preview_total = sum(len(_parse_memory_points(row["memory_points"])) for row in rows)
|
||||
print(f"person_info 待迁移人物: {len(rows)} 记忆点: {preview_total}")
|
||||
if args.dry_run:
|
||||
for row in rows[:5]:
|
||||
print(f"- person_id={row['person_id']} person_name={row['person_name'] or row['user_nickname']}")
|
||||
return 0
|
||||
|
||||
data_dir = resolve_repo_path(args.data_dir, fallback=DEFAULT_DATA_DIR)
|
||||
kernel = SDKMemoryKernel(plugin_root=PLUGIN_ROOT, config={"storage": {"data_dir": str(data_dir)}})
|
||||
await kernel.initialize()
|
||||
migrated = 0
|
||||
skipped = 0
|
||||
for row in rows:
|
||||
person_id = str(row["person_id"] or "").strip()
|
||||
if not person_id:
|
||||
continue
|
||||
display_name = str(row["person_name"] or row["user_nickname"] or "").strip()
|
||||
for item in _parse_memory_points(row["memory_points"]):
|
||||
result: Dict[str, Any] = await kernel.ingest_text(
|
||||
external_id=f"person_memory:{person_id}:{item['index']}",
|
||||
source_type="person_fact",
|
||||
text=f"[{item['category']}] {item['content']}",
|
||||
person_ids=[person_id],
|
||||
tags=[item["category"]],
|
||||
entities=[person_id, display_name] if display_name else [person_id],
|
||||
metadata={"category": item["category"], "weight": item["weight"], "display_name": display_name},
|
||||
)
|
||||
if result.get("stored_ids"):
|
||||
migrated += 1
|
||||
else:
|
||||
skipped += 1
|
||||
|
||||
print(f"迁移完成: migrated={migrated} skipped={skipped}")
|
||||
print(json.dumps(kernel.memory_stats(), ensure_ascii=False))
|
||||
kernel.close()
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(asyncio.run(_main()))
|
||||
727
src/A_memorix/scripts/process_knowledge.py
Normal file
727
src/A_memorix/scripts/process_knowledge.py
Normal file
@@ -0,0 +1,727 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
知识库自动导入脚本 (Strategy-Aware Version)
|
||||
|
||||
功能:
|
||||
1. 扫描 data/plugins/a-dawn.a-memorix/raw 下的 .txt 文件
|
||||
2. 检查 data/import_manifest.json 确认是否已导入
|
||||
3. 使用 Strategy 模式处理文件 (Narrative/Factual/Quote)
|
||||
4. 将生成的数据直接存入 VectorStore/GraphStore/MetadataStore
|
||||
5. 更新 manifest
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import json
|
||||
import asyncio
|
||||
import time
|
||||
import random
|
||||
import hashlib
|
||||
import tomlkit
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Optional
|
||||
from rich.progress import Progress, SpinnerColumn, BarColumn, TextColumn, TimeElapsedColumn
|
||||
from rich.console import Console
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
|
||||
|
||||
console = Console()
|
||||
|
||||
class LLMGenerationError(Exception):
|
||||
pass
|
||||
|
||||
from _bootstrap import DEFAULT_CONFIG_PATH, DEFAULT_DATA_DIR
|
||||
|
||||
# 数据目录
|
||||
DATA_DIR = DEFAULT_DATA_DIR
|
||||
RAW_DIR = DATA_DIR / "raw"
|
||||
PROCESSED_DIR = DATA_DIR / "processed"
|
||||
MANIFEST_PATH = DATA_DIR / "import_manifest.json"
|
||||
|
||||
|
||||
def _build_arg_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(description="A_Memorix Knowledge Importer (Strategy-Aware)")
|
||||
parser.add_argument("--force", action="store_true", help="Force re-import")
|
||||
parser.add_argument("--clear-manifest", action="store_true", help="Clear manifest")
|
||||
parser.add_argument(
|
||||
"--type",
|
||||
"-t",
|
||||
default="auto",
|
||||
help="Target import strategy override (auto/narrative/factual/quote)",
|
||||
)
|
||||
parser.add_argument("--concurrency", "-c", type=int, default=5)
|
||||
parser.add_argument(
|
||||
"--chat-log",
|
||||
action="store_true",
|
||||
help="聊天记录导入模式:强制 narrative 策略,并使用 LLM 语义抽取 event_time/event_time_range",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--chat-reference-time",
|
||||
default=None,
|
||||
help="chat_log 模式的相对时间参考点(如 2026/02/12 10:30);不传则使用当前本地时间",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
# --help/-h fast path: avoid heavy host/plugin bootstrap
|
||||
if any(arg in {"-h", "--help"} for arg in sys.argv[1:]):
|
||||
_build_arg_parser().print_help()
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
try:
|
||||
import A_memorix.core as core_module
|
||||
import A_memorix.core.storage as storage_module
|
||||
from src.common.logger import get_logger
|
||||
from src.services import llm_service as llm_api
|
||||
from src.config.config import global_config, model_config
|
||||
|
||||
VectorStore = core_module.VectorStore
|
||||
GraphStore = core_module.GraphStore
|
||||
MetadataStore = core_module.MetadataStore
|
||||
ImportStrategy = core_module.ImportStrategy
|
||||
create_embedding_api_adapter = core_module.create_embedding_api_adapter
|
||||
RelationWriteService = getattr(core_module, "RelationWriteService", None)
|
||||
|
||||
looks_like_quote_text = storage_module.looks_like_quote_text
|
||||
parse_import_strategy = storage_module.parse_import_strategy
|
||||
resolve_stored_knowledge_type = storage_module.resolve_stored_knowledge_type
|
||||
select_import_strategy = storage_module.select_import_strategy
|
||||
|
||||
from A_memorix.core.utils.time_parser import normalize_time_meta
|
||||
from A_memorix.core.utils.import_payloads import normalize_paragraph_import_item
|
||||
from A_memorix.core.strategies.base import BaseStrategy, ProcessedChunk, KnowledgeType as StratKnowledgeType
|
||||
from A_memorix.core.strategies.narrative import NarrativeStrategy
|
||||
from A_memorix.core.strategies.factual import FactualStrategy
|
||||
from A_memorix.core.strategies.quote import QuoteStrategy
|
||||
|
||||
except ImportError as e:
|
||||
print(f"❌ 无法导入模块: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
|
||||
logger = get_logger("A_Memorix.AutoImport")
|
||||
|
||||
|
||||
def _log_before_retry(retry_state) -> None:
|
||||
"""使用项目统一日志风格记录重试信息。"""
|
||||
exc = None
|
||||
if getattr(retry_state, "outcome", None) is not None and retry_state.outcome.failed:
|
||||
exc = retry_state.outcome.exception()
|
||||
next_sleep = getattr(getattr(retry_state, "next_action", None), "sleep", None)
|
||||
logger.warning(
|
||||
"LLM 调用即将重试: "
|
||||
f"attempt={getattr(retry_state, 'attempt_number', '?')} "
|
||||
f"next_sleep={next_sleep} "
|
||||
f"error={exc}"
|
||||
)
|
||||
|
||||
class AutoImporter:
|
||||
def __init__(
|
||||
self,
|
||||
force: bool = False,
|
||||
clear_manifest: bool = False,
|
||||
target_type: str = "auto",
|
||||
concurrency: int = 5,
|
||||
chat_log: bool = False,
|
||||
chat_reference_time: Optional[str] = None,
|
||||
):
|
||||
self.vector_store: Optional[VectorStore] = None
|
||||
self.graph_store: Optional[GraphStore] = None
|
||||
self.metadata_store: Optional[MetadataStore] = None
|
||||
self.embedding_manager = None
|
||||
self.relation_write_service = None
|
||||
self.plugin_config = {}
|
||||
self.manifest = {}
|
||||
self.force = force
|
||||
self.clear_manifest = clear_manifest
|
||||
self.chat_log = chat_log
|
||||
parsed_target_type = parse_import_strategy(target_type, default=ImportStrategy.AUTO)
|
||||
self.target_type = ImportStrategy.NARRATIVE.value if chat_log else parsed_target_type.value
|
||||
self.chat_reference_dt = self._parse_reference_time(chat_reference_time)
|
||||
if self.chat_log and parsed_target_type not in {ImportStrategy.AUTO, ImportStrategy.NARRATIVE}:
|
||||
logger.warning(
|
||||
f"chat_log 模式已启用,target_type={target_type} 将被覆盖为 narrative"
|
||||
)
|
||||
self.concurrency_limit = concurrency
|
||||
self.semaphore = None
|
||||
self.storage_lock = None
|
||||
|
||||
async def initialize(self):
|
||||
logger.info(f"正在初始化... (并发数: {self.concurrency_limit})")
|
||||
self.semaphore = asyncio.Semaphore(self.concurrency_limit)
|
||||
self.storage_lock = asyncio.Lock()
|
||||
|
||||
RAW_DIR.mkdir(parents=True, exist_ok=True)
|
||||
PROCESSED_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if self.clear_manifest:
|
||||
logger.info("🧹 清理 Mainfest")
|
||||
self.manifest = {}
|
||||
self._save_manifest()
|
||||
elif MANIFEST_PATH.exists():
|
||||
try:
|
||||
with open(MANIFEST_PATH, "r", encoding="utf-8") as f:
|
||||
self.manifest = json.load(f)
|
||||
except Exception:
|
||||
self.manifest = {}
|
||||
|
||||
config_path = DEFAULT_CONFIG_PATH
|
||||
try:
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
self.plugin_config = tomlkit.load(f)
|
||||
except Exception as e:
|
||||
logger.error(f"加载 A_Memorix 配置失败: {e}")
|
||||
return False
|
||||
|
||||
try:
|
||||
await self._init_stores()
|
||||
except Exception as e:
|
||||
logger.error(f"初始化存储失败: {e}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def _init_stores(self):
|
||||
# ... (Same as original)
|
||||
self.embedding_manager = create_embedding_api_adapter(
|
||||
batch_size=self.plugin_config.get("embedding", {}).get("batch_size", 32),
|
||||
default_dimension=self.plugin_config.get("embedding", {}).get("dimension", 384),
|
||||
model_name=self.plugin_config.get("embedding", {}).get("model_name", "auto"),
|
||||
retry_config=self.plugin_config.get("embedding", {}).get("retry", {}),
|
||||
)
|
||||
try:
|
||||
dim = await self.embedding_manager._detect_dimension()
|
||||
except:
|
||||
dim = self.embedding_manager.default_dimension
|
||||
|
||||
q_type_str = str(self.plugin_config.get("embedding", {}).get("quantization_type", "int8") or "int8").lower()
|
||||
# Need to access QuantizationType from storage_module if not imported globally
|
||||
QuantizationType = storage_module.QuantizationType
|
||||
if q_type_str != "int8":
|
||||
raise ValueError(
|
||||
"embedding.quantization_type 在 vNext 仅允许 int8(SQ8)。"
|
||||
" 请先执行 scripts/release_vnext_migrate.py migrate。"
|
||||
)
|
||||
|
||||
self.vector_store = VectorStore(
|
||||
dimension=dim,
|
||||
quantization_type=QuantizationType.INT8,
|
||||
data_dir=DATA_DIR / "vectors"
|
||||
)
|
||||
|
||||
SparseMatrixFormat = storage_module.SparseMatrixFormat
|
||||
m_fmt_str = self.plugin_config.get("graph", {}).get("sparse_matrix_format", "csr")
|
||||
m_map = {"csr": SparseMatrixFormat.CSR, "csc": SparseMatrixFormat.CSC}
|
||||
|
||||
self.graph_store = GraphStore(
|
||||
matrix_format=m_map.get(m_fmt_str, SparseMatrixFormat.CSR),
|
||||
data_dir=DATA_DIR / "graph"
|
||||
)
|
||||
|
||||
self.metadata_store = MetadataStore(data_dir=DATA_DIR / "metadata")
|
||||
self.metadata_store.connect()
|
||||
|
||||
if RelationWriteService is not None:
|
||||
self.relation_write_service = RelationWriteService(
|
||||
metadata_store=self.metadata_store,
|
||||
graph_store=self.graph_store,
|
||||
vector_store=self.vector_store,
|
||||
embedding_manager=self.embedding_manager,
|
||||
)
|
||||
|
||||
if self.vector_store.has_data(): self.vector_store.load()
|
||||
if self.graph_store.has_data(): self.graph_store.load()
|
||||
|
||||
def _should_write_relation_vectors(self) -> bool:
|
||||
retrieval_cfg = self.plugin_config.get("retrieval", {})
|
||||
if not isinstance(retrieval_cfg, dict):
|
||||
return False
|
||||
rv_cfg = retrieval_cfg.get("relation_vectorization", {})
|
||||
if not isinstance(rv_cfg, dict):
|
||||
return False
|
||||
return bool(rv_cfg.get("enabled", False)) and bool(rv_cfg.get("write_on_import", True))
|
||||
|
||||
def load_file(self, file_path: Path) -> str:
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
return f.read()
|
||||
|
||||
def get_file_hash(self, content: str) -> str:
|
||||
return hashlib.md5(content.encode("utf-8")).hexdigest()
|
||||
|
||||
def _parse_reference_time(self, value: Optional[str]) -> datetime:
|
||||
"""解析 chat_log 模式的参考时间(用于相对时间语义解析)。"""
|
||||
if not value:
|
||||
return datetime.now()
|
||||
formats = [
|
||||
"%Y/%m/%d %H:%M:%S",
|
||||
"%Y/%m/%d %H:%M",
|
||||
"%Y-%m-%d %H:%M:%S",
|
||||
"%Y-%m-%d %H:%M",
|
||||
"%Y/%m/%d",
|
||||
"%Y-%m-%d",
|
||||
]
|
||||
text = str(value).strip()
|
||||
for fmt in formats:
|
||||
try:
|
||||
return datetime.strptime(text, fmt)
|
||||
except ValueError:
|
||||
continue
|
||||
logger.warning(
|
||||
f"无法解析 chat_reference_time={value},将回退为当前本地时间"
|
||||
)
|
||||
return datetime.now()
|
||||
|
||||
async def _extract_chat_time_meta_with_llm(
|
||||
self,
|
||||
text: str,
|
||||
model_config: Any,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
使用 LLM 从聊天文本语义中抽取时间信息。
|
||||
支持将相对时间表达转换为绝对时间。
|
||||
"""
|
||||
if not text.strip():
|
||||
return None
|
||||
|
||||
reference_now = self.chat_reference_dt.strftime("%Y/%m/%d %H:%M")
|
||||
prompt = f"""You are a time extraction engine for chat logs.
|
||||
Extract temporal information from the following chat paragraph.
|
||||
|
||||
Rules:
|
||||
1. Use semantic understanding, not regex matching.
|
||||
2. Convert relative expressions (e.g., yesterday evening, last Friday morning) to absolute local datetime using reference_now.
|
||||
3. If a time span exists, return event_time_start/event_time_end.
|
||||
4. If only one point in time exists, return event_time.
|
||||
5. If no reliable time can be inferred, return all time fields as null.
|
||||
6. Output ONLY valid JSON. No markdown, no explanation.
|
||||
|
||||
reference_now: {reference_now}
|
||||
timezone: local system timezone
|
||||
|
||||
Allowed output formats for time values:
|
||||
- "YYYY/MM/DD"
|
||||
- "YYYY/MM/DD HH:mm"
|
||||
|
||||
JSON schema:
|
||||
{{
|
||||
"event_time": null,
|
||||
"event_time_start": null,
|
||||
"event_time_end": null,
|
||||
"time_range": null,
|
||||
"time_granularity": "day",
|
||||
"time_confidence": 0.0
|
||||
}}
|
||||
|
||||
Chat paragraph:
|
||||
\"\"\"{text}\"\"\"
|
||||
"""
|
||||
try:
|
||||
result = await self._llm_call(prompt, model_config)
|
||||
except Exception as e:
|
||||
logger.warning(f"chat_log 时间语义抽取失败: {e}")
|
||||
return None
|
||||
|
||||
if not isinstance(result, dict):
|
||||
return None
|
||||
|
||||
raw_time_meta = {
|
||||
"event_time": result.get("event_time"),
|
||||
"event_time_start": result.get("event_time_start"),
|
||||
"event_time_end": result.get("event_time_end"),
|
||||
"time_range": result.get("time_range"),
|
||||
"time_granularity": result.get("time_granularity"),
|
||||
"time_confidence": result.get("time_confidence"),
|
||||
}
|
||||
try:
|
||||
normalized = normalize_time_meta(raw_time_meta)
|
||||
except Exception as e:
|
||||
logger.warning(f"chat_log 时间语义抽取结果不可用,已忽略: {e}")
|
||||
return None
|
||||
|
||||
has_effective_time = any(
|
||||
key in normalized
|
||||
for key in ("event_time", "event_time_start", "event_time_end")
|
||||
)
|
||||
if not has_effective_time:
|
||||
return None
|
||||
|
||||
return normalized
|
||||
|
||||
def _determine_strategy(self, filename: str, content: str) -> BaseStrategy:
|
||||
"""Layer 1: Global Strategy Routing"""
|
||||
strategy = select_import_strategy(
|
||||
content,
|
||||
override=self.target_type,
|
||||
chat_log=self.chat_log,
|
||||
)
|
||||
if self.chat_log:
|
||||
logger.info(f"chat_log 模式: {filename} 强制使用 NarrativeStrategy")
|
||||
elif strategy == ImportStrategy.QUOTE:
|
||||
logger.info(f"Auto-detected Quote/Lyric type for {filename}")
|
||||
|
||||
if strategy == ImportStrategy.FACTUAL:
|
||||
return FactualStrategy(filename)
|
||||
if strategy == ImportStrategy.QUOTE:
|
||||
return QuoteStrategy(filename)
|
||||
return NarrativeStrategy(filename)
|
||||
|
||||
def _chunk_rescue(self, chunk: ProcessedChunk, filename: str) -> Optional[BaseStrategy]:
|
||||
"""Layer 2: Chunk-level rescue strategies"""
|
||||
# If we are already in Quote strategy, no need to rescue
|
||||
if chunk.type == StratKnowledgeType.QUOTE:
|
||||
return None
|
||||
|
||||
if looks_like_quote_text(chunk.chunk.text):
|
||||
logger.info(f" > Rescuing chunk {chunk.chunk.index} as Quote")
|
||||
return QuoteStrategy(filename)
|
||||
|
||||
return None
|
||||
|
||||
async def process_and_import(self):
|
||||
if not await self.initialize(): return
|
||||
|
||||
files = list(RAW_DIR.glob("*.txt"))
|
||||
logger.info(f"扫描到 {len(files)} 个文件 in {RAW_DIR}")
|
||||
|
||||
if not files: return
|
||||
|
||||
tasks = []
|
||||
for file_path in files:
|
||||
tasks.append(asyncio.create_task(self._process_single_file(file_path)))
|
||||
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
success_count = sum(1 for r in results if r is True)
|
||||
logger.info(f"本次主处理完成,共成功处理 {success_count}/{len(files)} 个文件")
|
||||
|
||||
if self.vector_store: self.vector_store.save()
|
||||
if self.graph_store: self.graph_store.save()
|
||||
|
||||
async def _process_single_file(self, file_path: Path) -> bool:
|
||||
filename = file_path.name
|
||||
async with self.semaphore:
|
||||
try:
|
||||
content = self.load_file(file_path)
|
||||
file_hash = self.get_file_hash(content)
|
||||
|
||||
if not self.force and filename in self.manifest:
|
||||
record = self.manifest[filename]
|
||||
if record.get("hash") == file_hash and record.get("imported"):
|
||||
logger.info(f"跳过已导入文件: {filename}")
|
||||
return False
|
||||
|
||||
logger.info(f">>> 开始处理: {filename}")
|
||||
|
||||
# 1. Strategy Selection
|
||||
strategy = self._determine_strategy(filename, content)
|
||||
logger.info(f" 策略: {strategy.__class__.__name__}")
|
||||
|
||||
# 2. Split (Strategy-Aware)
|
||||
initial_chunks = strategy.split(content)
|
||||
logger.info(f" 初步分块: {len(initial_chunks)}")
|
||||
|
||||
processed_data = {"paragraphs": [], "entities": [], "relations": []}
|
||||
|
||||
# 3. Extract Loop
|
||||
model_config = await self._select_model()
|
||||
|
||||
for i, chunk in enumerate(initial_chunks):
|
||||
current_strategy = strategy
|
||||
# Layer 2: Chunk Rescue
|
||||
rescue_strategy = self._chunk_rescue(chunk, filename)
|
||||
if rescue_strategy:
|
||||
# Re-split? No, just re-process this text as a single chunk using the rescue strategy
|
||||
# But rescue strategy might want to split it further?
|
||||
# Simplification: Treat the whole chunk text as one block for the rescue strategy
|
||||
# OR create a single chunk object for it.
|
||||
# Creating a new chunk using rescue strategy logic might be complex if split behavior differs.
|
||||
# Let's just instantiate a chunk of the new type manually
|
||||
chunk.type = StratKnowledgeType.QUOTE
|
||||
chunk.flags.verbatim = True
|
||||
chunk.flags.requires_llm = False # Quotes don't usually need LLM
|
||||
current_strategy = rescue_strategy
|
||||
|
||||
# Extraction
|
||||
if chunk.flags.requires_llm:
|
||||
result_chunk = await current_strategy.extract(chunk, lambda p: self._llm_call(p, model_config))
|
||||
else:
|
||||
# For quotes, extract might be just pass through or regex
|
||||
result_chunk = await current_strategy.extract(chunk)
|
||||
|
||||
time_meta = None
|
||||
if self.chat_log:
|
||||
time_meta = await self._extract_chat_time_meta_with_llm(
|
||||
result_chunk.chunk.text,
|
||||
model_config,
|
||||
)
|
||||
|
||||
# Normalize Data
|
||||
self._normalize_and_aggregate(
|
||||
result_chunk,
|
||||
processed_data,
|
||||
time_meta=time_meta,
|
||||
)
|
||||
|
||||
logger.info(f" 已处理块 {i+1}/{len(initial_chunks)}")
|
||||
|
||||
# 4. Save Json
|
||||
json_path = PROCESSED_DIR / f"{file_path.stem}.json"
|
||||
with open(json_path, "w", encoding="utf-8") as f:
|
||||
json.dump(processed_data, f, ensure_ascii=False, indent=2)
|
||||
|
||||
# 5. Import to DB
|
||||
async with self.storage_lock:
|
||||
await self._import_to_db(processed_data)
|
||||
|
||||
self.manifest[filename] = {
|
||||
"hash": file_hash,
|
||||
"timestamp": time.time(),
|
||||
"imported": True
|
||||
}
|
||||
self._save_manifest()
|
||||
self.vector_store.save()
|
||||
self.graph_store.save()
|
||||
logger.info(f"✅ 文件 {filename} 处理并导入完成")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 处理失败 {filename}: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
def _normalize_and_aggregate(
|
||||
self,
|
||||
chunk: ProcessedChunk,
|
||||
all_data: Dict,
|
||||
time_meta: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
"""Convert strategy-specific data to unified generic format for storage."""
|
||||
# Generic fields
|
||||
para_item = {
|
||||
"content": chunk.chunk.text,
|
||||
"source": chunk.source.file,
|
||||
"knowledge_type": resolve_stored_knowledge_type(
|
||||
chunk.type.value,
|
||||
content=chunk.chunk.text,
|
||||
).value,
|
||||
"entities": [],
|
||||
"relations": []
|
||||
}
|
||||
|
||||
data = chunk.data
|
||||
|
||||
# 1. Triples (Factual)
|
||||
if "triples" in data:
|
||||
for t in data["triples"]:
|
||||
para_item["relations"].append({
|
||||
"subject": t.get("subject"),
|
||||
"predicate": t.get("predicate"),
|
||||
"object": t.get("object")
|
||||
})
|
||||
# Auto-add entities from triples
|
||||
para_item["entities"].extend([t.get("subject"), t.get("object")])
|
||||
|
||||
# 2. Events & Relations (Narrative)
|
||||
if "events" in data:
|
||||
# Store events as content/metadata? Or entities?
|
||||
# For now maybe just keep them in logic, or add as 'Event' entities?
|
||||
# Creating entities for events is good.
|
||||
para_item["entities"].extend(data["events"])
|
||||
|
||||
if "relations" in data: # Narrative also outputs relations list
|
||||
para_item["relations"].extend(data["relations"])
|
||||
for r in data["relations"]:
|
||||
para_item["entities"].extend([r.get("subject"), r.get("object")])
|
||||
|
||||
# 3. Verbatim Entities (Quote)
|
||||
if "verbatim_entities" in data:
|
||||
para_item["entities"].extend(data["verbatim_entities"])
|
||||
|
||||
# Dedupe per paragraph
|
||||
para_item["entities"] = list(set([e for e in para_item["entities"] if e]))
|
||||
|
||||
if time_meta:
|
||||
para_item["time_meta"] = time_meta
|
||||
|
||||
all_data["paragraphs"].append(para_item)
|
||||
all_data["entities"].extend(para_item["entities"])
|
||||
if "relations" in para_item:
|
||||
all_data["relations"].extend(para_item["relations"])
|
||||
|
||||
@retry(
|
||||
retry=retry_if_exception_type((LLMGenerationError, json.JSONDecodeError)),
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=2, max=10),
|
||||
before_sleep=_log_before_retry
|
||||
)
|
||||
async def _llm_call(self, prompt: str, model_config: Any) -> Dict:
|
||||
"""Generic LLM Caller"""
|
||||
task_name = llm_api.resolve_task_name_from_model_config(model_config)
|
||||
result = await llm_api.generate(
|
||||
llm_api.LLMServiceRequest(
|
||||
task_name=task_name,
|
||||
request_type="Script.ProcessKnowledge",
|
||||
prompt=prompt,
|
||||
temperature=getattr(model_config, "temperature", None),
|
||||
max_tokens=getattr(model_config, "max_tokens", None),
|
||||
)
|
||||
)
|
||||
success = bool(result.success)
|
||||
response = str(result.completion.response or "")
|
||||
if success:
|
||||
txt = response.strip()
|
||||
if "```" in txt:
|
||||
txt = txt.split("```json")[-1].split("```")[0].strip()
|
||||
try:
|
||||
return json.loads(txt)
|
||||
except json.JSONDecodeError:
|
||||
# Fallback: try to find first { and last }
|
||||
start = txt.find('{')
|
||||
end = txt.rfind('}')
|
||||
if start != -1 and end != -1:
|
||||
return json.loads(txt[start:end+1])
|
||||
raise
|
||||
else:
|
||||
raise LLMGenerationError("LLM generation failed")
|
||||
|
||||
async def _select_model(self) -> Any:
|
||||
models = llm_api.get_available_models()
|
||||
if not models: raise ValueError("No LLM models")
|
||||
|
||||
config_model = self.plugin_config.get("advanced", {}).get("extraction_model", "auto")
|
||||
if config_model != "auto" and config_model in models:
|
||||
return models[config_model]
|
||||
|
||||
for task_key in ["lpmm_entity_extract", "lpmm_rdf_build", "embedding"]:
|
||||
if task_key in models: return models[task_key]
|
||||
|
||||
return models[list(models.keys())[0]]
|
||||
|
||||
# Re-use existing methods
|
||||
async def _add_entity_with_vector(self, name: str, source_paragraph: Optional[str] = None) -> str:
|
||||
# Same as before
|
||||
hash_value = self.metadata_store.add_entity(name, source_paragraph=source_paragraph)
|
||||
self.graph_store.add_nodes([name])
|
||||
try:
|
||||
emb = await self.embedding_manager.encode(name)
|
||||
try:
|
||||
self.vector_store.add(emb.reshape(1, -1), [hash_value])
|
||||
except ValueError: pass
|
||||
except Exception: pass
|
||||
return hash_value
|
||||
|
||||
async def import_json_data(self, data: Dict, filename: str = "script_import", progress_callback=None):
|
||||
"""Public import entrypoint for pre-processed JSON payloads."""
|
||||
if not self.storage_lock:
|
||||
raise RuntimeError("Importer is not initialized. Call initialize() first.")
|
||||
|
||||
async with self.storage_lock:
|
||||
await self._import_to_db(data, progress_callback=progress_callback)
|
||||
self.manifest[filename] = {
|
||||
"hash": self.get_file_hash(json.dumps(data, ensure_ascii=False, sort_keys=True)),
|
||||
"timestamp": time.time(),
|
||||
"imported": True,
|
||||
}
|
||||
self._save_manifest()
|
||||
self.vector_store.save()
|
||||
self.graph_store.save()
|
||||
|
||||
async def _import_to_db(self, data: Dict, progress_callback=None):
|
||||
# Same logic, but ensure robust
|
||||
with self.graph_store.batch_update():
|
||||
for item in data.get("paragraphs", []):
|
||||
paragraph = normalize_paragraph_import_item(
|
||||
item,
|
||||
default_source="script",
|
||||
)
|
||||
content = paragraph["content"]
|
||||
source = paragraph["source"]
|
||||
k_type_val = paragraph["knowledge_type"]
|
||||
|
||||
h_val = self.metadata_store.add_paragraph(
|
||||
content=content,
|
||||
source=source,
|
||||
knowledge_type=k_type_val,
|
||||
time_meta=paragraph["time_meta"],
|
||||
)
|
||||
|
||||
if h_val not in self.vector_store:
|
||||
try:
|
||||
emb = await self.embedding_manager.encode(content)
|
||||
self.vector_store.add(emb.reshape(1, -1), [h_val])
|
||||
except Exception as e:
|
||||
logger.error(f" Vector fail: {e}")
|
||||
|
||||
para_entities = paragraph["entities"]
|
||||
for entity in para_entities:
|
||||
if entity:
|
||||
await self._add_entity_with_vector(entity, source_paragraph=h_val)
|
||||
|
||||
para_relations = paragraph["relations"]
|
||||
for rel in para_relations:
|
||||
s, p, o = rel.get("subject"), rel.get("predicate"), rel.get("object")
|
||||
if s and p and o:
|
||||
await self._add_entity_with_vector(s, source_paragraph=h_val)
|
||||
await self._add_entity_with_vector(o, source_paragraph=h_val)
|
||||
confidence = float(rel.get("confidence", 1.0) or 1.0)
|
||||
rel_meta = rel.get("metadata", {})
|
||||
write_vector = self._should_write_relation_vectors()
|
||||
if self.relation_write_service is not None:
|
||||
await self.relation_write_service.upsert_relation_with_vector(
|
||||
subject=s,
|
||||
predicate=p,
|
||||
obj=o,
|
||||
confidence=confidence,
|
||||
source_paragraph=h_val,
|
||||
metadata=rel_meta if isinstance(rel_meta, dict) else {},
|
||||
write_vector=write_vector,
|
||||
)
|
||||
else:
|
||||
rel_hash = self.metadata_store.add_relation(
|
||||
s,
|
||||
p,
|
||||
o,
|
||||
confidence=confidence,
|
||||
source_paragraph=h_val,
|
||||
metadata=rel_meta if isinstance(rel_meta, dict) else {},
|
||||
)
|
||||
self.graph_store.add_edges([(s, o)], relation_hashes=[rel_hash])
|
||||
try:
|
||||
self.metadata_store.set_relation_vector_state(rel_hash, "none")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if progress_callback: progress_callback(1)
|
||||
|
||||
async def close(self):
|
||||
if self.metadata_store: self.metadata_store.close()
|
||||
|
||||
def _save_manifest(self):
|
||||
with open(MANIFEST_PATH, "w", encoding="utf-8") as f:
|
||||
json.dump(self.manifest, f, ensure_ascii=False, indent=2)
|
||||
|
||||
async def main():
|
||||
parser = _build_arg_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
if not global_config: return
|
||||
|
||||
importer = AutoImporter(
|
||||
force=args.force,
|
||||
clear_manifest=args.clear_manifest,
|
||||
target_type=args.type,
|
||||
concurrency=args.concurrency,
|
||||
chat_log=args.chat_log,
|
||||
chat_reference_time=args.chat_reference_time,
|
||||
)
|
||||
await importer.process_and_import()
|
||||
await importer.close()
|
||||
|
||||
if __name__ == "__main__":
|
||||
if sys.platform == "win32":
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||
asyncio.run(main())
|
||||
119
src/A_memorix/scripts/rebuild_episodes.py
Normal file
119
src/A_memorix/scripts/rebuild_episodes.py
Normal file
@@ -0,0 +1,119 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Episode source 级重建工具。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import sys
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from _bootstrap import DEFAULT_CONFIG_PATH, DEFAULT_DATA_DIR, resolve_repo_path
|
||||
|
||||
try:
|
||||
import tomlkit # type: ignore
|
||||
except Exception: # pragma: no cover
|
||||
tomlkit = None
|
||||
|
||||
from A_memorix.core.storage import MetadataStore
|
||||
from A_memorix.core.utils.episode_service import EpisodeService
|
||||
|
||||
|
||||
def _build_arg_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(description="Rebuild A_Memorix episodes by source")
|
||||
parser.add_argument("--data-dir", default=str(DEFAULT_DATA_DIR), help="插件数据目录")
|
||||
parser.add_argument("--source", type=str, help="指定单个 source 入队/重建")
|
||||
parser.add_argument("--all", action="store_true", help="对所有 source 入队/重建")
|
||||
parser.add_argument("--wait", action="store_true", help="在脚本内同步执行重建")
|
||||
return parser
|
||||
|
||||
|
||||
if any(arg in {"-h", "--help"} for arg in sys.argv[1:]):
|
||||
_build_arg_parser().print_help()
|
||||
raise SystemExit(0)
|
||||
|
||||
|
||||
def _load_plugin_config() -> Dict[str, Any]:
|
||||
config_path = DEFAULT_CONFIG_PATH
|
||||
if tomlkit is None or not config_path.exists():
|
||||
return {}
|
||||
try:
|
||||
with open(config_path, "r", encoding="utf-8") as handle:
|
||||
parsed = tomlkit.load(handle)
|
||||
return dict(parsed) if isinstance(parsed, dict) else {}
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
|
||||
def _resolve_sources(store: MetadataStore, *, source: str | None, rebuild_all: bool) -> List[str]:
|
||||
if rebuild_all:
|
||||
return list(store.list_episode_sources_for_rebuild())
|
||||
token = str(source or "").strip()
|
||||
if not token:
|
||||
raise ValueError("必须提供 --source 或 --all")
|
||||
return [token]
|
||||
|
||||
|
||||
async def _run_rebuilds(store: MetadataStore, plugin_config: Dict[str, Any], sources: List[str]) -> int:
|
||||
service = EpisodeService(metadata_store=store, plugin_config=plugin_config)
|
||||
failures: List[str] = []
|
||||
for source in sources:
|
||||
started = store.mark_episode_source_running(source)
|
||||
if not started:
|
||||
failures.append(f"{source}: unable_to_mark_running")
|
||||
continue
|
||||
try:
|
||||
result = await service.rebuild_source(source)
|
||||
store.mark_episode_source_done(source)
|
||||
print(
|
||||
"rebuilt"
|
||||
f" source={source}"
|
||||
f" paragraphs={int(result.get('paragraph_count') or 0)}"
|
||||
f" groups={int(result.get('group_count') or 0)}"
|
||||
f" episodes={int(result.get('episode_count') or 0)}"
|
||||
f" fallback={int(result.get('fallback_count') or 0)}"
|
||||
)
|
||||
except Exception as exc:
|
||||
err = str(exc)[:500]
|
||||
store.mark_episode_source_failed(source, err)
|
||||
failures.append(f"{source}: {err}")
|
||||
print(f"failed source={source} error={err}")
|
||||
|
||||
if failures:
|
||||
for item in failures:
|
||||
print(item)
|
||||
return 1
|
||||
return 0
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = _build_arg_parser()
|
||||
args = parser.parse_args()
|
||||
if bool(args.all) == bool(args.source):
|
||||
parser.error("必须且只能选择一个:--source 或 --all")
|
||||
|
||||
store = MetadataStore(data_dir=resolve_repo_path(args.data_dir, fallback=DEFAULT_DATA_DIR) / "metadata")
|
||||
store.connect()
|
||||
try:
|
||||
sources = _resolve_sources(store, source=args.source, rebuild_all=bool(args.all))
|
||||
if not sources:
|
||||
print("no sources to rebuild")
|
||||
return 0
|
||||
|
||||
enqueued = 0
|
||||
reason = "script_rebuild_all" if args.all else "script_rebuild_source"
|
||||
for source in sources:
|
||||
enqueued += int(store.enqueue_episode_source_rebuild(source, reason=reason))
|
||||
print(f"enqueued={enqueued} sources={len(sources)}")
|
||||
|
||||
if not args.wait:
|
||||
return 0
|
||||
|
||||
plugin_config = _load_plugin_config()
|
||||
return asyncio.run(_run_rebuilds(store, plugin_config, sources))
|
||||
finally:
|
||||
store.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
744
src/A_memorix/scripts/release_vnext_migrate.py
Normal file
744
src/A_memorix/scripts/release_vnext_migrate.py
Normal file
@@ -0,0 +1,744 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
vNext release migration entrypoint for A_Memorix.
|
||||
|
||||
Subcommands:
|
||||
- preflight: detect legacy config/data/schema risks
|
||||
- migrate: offline migrate config + vectors + metadata schema + graph edge hash map
|
||||
- verify: strict post-migration consistency checks
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import pickle
|
||||
import sqlite3
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple
|
||||
|
||||
import tomlkit
|
||||
|
||||
from _bootstrap import DEFAULT_CONFIG_PATH, DEFAULT_DATA_DIR, resolve_repo_path
|
||||
|
||||
def _build_arg_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(description="A_Memorix vNext release migration tool")
|
||||
parser.add_argument(
|
||||
"--config",
|
||||
default=str(DEFAULT_CONFIG_PATH),
|
||||
help="config.toml path (default: config/a_memorix.toml)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--data-dir",
|
||||
default="",
|
||||
help="optional data dir override; default resolved from config.storage.data_dir",
|
||||
)
|
||||
parser.add_argument("--json-out", default="", help="optional JSON report output path")
|
||||
|
||||
sub = parser.add_subparsers(dest="command", required=True)
|
||||
|
||||
p_preflight = sub.add_parser("preflight", help="scan legacy risks")
|
||||
p_preflight.add_argument("--strict", action="store_true", help="return 1 if any error check exists")
|
||||
|
||||
p_migrate = sub.add_parser("migrate", help="run offline migration")
|
||||
p_migrate.add_argument("--dry-run", action="store_true", help="only print planned changes")
|
||||
p_migrate.add_argument(
|
||||
"--verify-after",
|
||||
action="store_true",
|
||||
help="run verify automatically after migrate",
|
||||
)
|
||||
|
||||
p_verify = sub.add_parser("verify", help="post-migration verification")
|
||||
p_verify.add_argument("--strict", action="store_true", help="return 1 if any error check exists")
|
||||
return parser
|
||||
|
||||
|
||||
# --help/-h fast path: avoid heavy host/plugin bootstrap
|
||||
if any(arg in {"-h", "--help"} for arg in sys.argv[1:]):
|
||||
_build_arg_parser().print_help()
|
||||
raise SystemExit(0)
|
||||
|
||||
try:
|
||||
from A_memorix.core.storage import GraphStore, KnowledgeType, MetadataStore, QuantizationType, VectorStore
|
||||
from A_memorix.core.storage.metadata_store import SCHEMA_VERSION
|
||||
except Exception as e: # pragma: no cover
|
||||
print(f"❌ failed to import storage modules: {e}")
|
||||
raise SystemExit(2)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CheckItem:
|
||||
code: str
|
||||
level: str
|
||||
message: str
|
||||
details: Optional[Dict[str, Any]] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
out = {
|
||||
"code": self.code,
|
||||
"level": self.level,
|
||||
"message": self.message,
|
||||
}
|
||||
if self.details:
|
||||
out["details"] = self.details
|
||||
return out
|
||||
|
||||
|
||||
def _read_toml(path: Path) -> Dict[str, Any]:
|
||||
text = path.read_text(encoding="utf-8")
|
||||
return tomlkit.parse(text)
|
||||
|
||||
|
||||
def _write_toml(path: Path, data: Dict[str, Any]) -> None:
|
||||
path.write_text(tomlkit.dumps(data), encoding="utf-8")
|
||||
|
||||
|
||||
def _get_nested(obj: Dict[str, Any], keys: Sequence[str], default: Any = None) -> Any:
|
||||
cur: Any = obj
|
||||
for k in keys:
|
||||
if not isinstance(cur, dict) or k not in cur:
|
||||
return default
|
||||
cur = cur[k]
|
||||
return cur
|
||||
|
||||
|
||||
def _ensure_table(obj: Dict[str, Any], key: str) -> Dict[str, Any]:
|
||||
if key not in obj or not isinstance(obj[key], dict):
|
||||
obj[key] = tomlkit.table()
|
||||
return obj[key]
|
||||
|
||||
|
||||
def _resolve_data_dir(config_doc: Dict[str, Any], explicit_data_dir: Optional[str]) -> Path:
|
||||
if explicit_data_dir:
|
||||
return resolve_repo_path(explicit_data_dir, fallback=DEFAULT_DATA_DIR)
|
||||
raw = str(_get_nested(config_doc, ("storage", "data_dir"), "./data") or "./data").strip()
|
||||
return resolve_repo_path(raw, fallback=DEFAULT_DATA_DIR)
|
||||
|
||||
|
||||
def _sqlite_table_exists(conn: sqlite3.Connection, table: str) -> bool:
|
||||
row = conn.execute(
|
||||
"SELECT 1 FROM sqlite_master WHERE type='table' AND name=? LIMIT 1",
|
||||
(table,),
|
||||
).fetchone()
|
||||
return row is not None
|
||||
|
||||
|
||||
def _collect_hash_alias_conflicts(conn: sqlite3.Connection) -> Dict[str, List[str]]:
|
||||
hashes: List[str] = []
|
||||
if _sqlite_table_exists(conn, "relations"):
|
||||
rows = conn.execute("SELECT hash FROM relations").fetchall()
|
||||
hashes.extend(str(r[0]) for r in rows if r and r[0])
|
||||
if _sqlite_table_exists(conn, "deleted_relations"):
|
||||
rows = conn.execute("SELECT hash FROM deleted_relations").fetchall()
|
||||
hashes.extend(str(r[0]) for r in rows if r and r[0])
|
||||
|
||||
alias_map: Dict[str, str] = {}
|
||||
conflicts: Dict[str, set[str]] = {}
|
||||
for h in hashes:
|
||||
if len(h) != 64:
|
||||
continue
|
||||
alias = h[:32]
|
||||
old = alias_map.get(alias)
|
||||
if old is None:
|
||||
alias_map[alias] = h
|
||||
continue
|
||||
if old != h:
|
||||
conflicts.setdefault(alias, set()).update({old, h})
|
||||
return {k: sorted(v) for k, v in conflicts.items()}
|
||||
|
||||
|
||||
def _collect_invalid_knowledge_types(conn: sqlite3.Connection) -> List[str]:
|
||||
if not _sqlite_table_exists(conn, "paragraphs"):
|
||||
return []
|
||||
|
||||
allowed = {item.value for item in KnowledgeType}
|
||||
rows = conn.execute("SELECT DISTINCT knowledge_type FROM paragraphs").fetchall()
|
||||
invalid: List[str] = []
|
||||
for row in rows:
|
||||
raw = row[0]
|
||||
value = str(raw).strip().lower() if raw is not None else ""
|
||||
if value not in allowed:
|
||||
invalid.append(str(raw) if raw is not None else "")
|
||||
return sorted(set(invalid))
|
||||
|
||||
|
||||
def _guess_vector_dimension(config_doc: Dict[str, Any], vectors_dir: Path) -> int:
|
||||
meta_path = vectors_dir / "vectors_metadata.pkl"
|
||||
if meta_path.exists():
|
||||
try:
|
||||
with open(meta_path, "rb") as f:
|
||||
meta = pickle.load(f)
|
||||
dim = int(meta.get("dimension", 0))
|
||||
if dim > 0:
|
||||
return dim
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
dim_cfg = int(_get_nested(config_doc, ("embedding", "dimension"), 1024))
|
||||
if dim_cfg > 0:
|
||||
return dim_cfg
|
||||
except Exception:
|
||||
pass
|
||||
return 1024
|
||||
|
||||
|
||||
def _preflight_impl(config_path: Path, data_dir: Path) -> Dict[str, Any]:
|
||||
checks: List[CheckItem] = []
|
||||
facts: Dict[str, Any] = {
|
||||
"config_path": str(config_path),
|
||||
"data_dir": str(data_dir),
|
||||
}
|
||||
|
||||
if not config_path.exists():
|
||||
checks.append(CheckItem("CFG-00", "error", f"config not found: {config_path}"))
|
||||
return {"ok": False, "checks": [c.to_dict() for c in checks], "facts": facts}
|
||||
|
||||
config_doc = _read_toml(config_path)
|
||||
tool_mode = str(_get_nested(config_doc, ("routing", "tool_search_mode"), "forward") or "").strip().lower()
|
||||
summary_model = _get_nested(config_doc, ("summarization", "model_name"), ["auto"])
|
||||
summary_knowledge_type = str(
|
||||
_get_nested(config_doc, ("summarization", "default_knowledge_type"), "narrative") or "narrative"
|
||||
).strip().lower()
|
||||
quantization = str(_get_nested(config_doc, ("embedding", "quantization_type"), "int8") or "").strip().lower()
|
||||
|
||||
facts["routing.tool_search_mode"] = tool_mode
|
||||
facts["summarization.model_name_type"] = type(summary_model).__name__
|
||||
facts["summarization.default_knowledge_type"] = summary_knowledge_type
|
||||
facts["embedding.quantization_type"] = quantization
|
||||
|
||||
if tool_mode == "legacy":
|
||||
checks.append(
|
||||
CheckItem(
|
||||
"CP-04",
|
||||
"error",
|
||||
"routing.tool_search_mode=legacy is no longer accepted at runtime",
|
||||
)
|
||||
)
|
||||
elif tool_mode not in {"forward", "disabled"}:
|
||||
checks.append(
|
||||
CheckItem(
|
||||
"CP-04",
|
||||
"error",
|
||||
f"routing.tool_search_mode invalid value: {tool_mode}",
|
||||
)
|
||||
)
|
||||
|
||||
if isinstance(summary_model, str):
|
||||
checks.append(
|
||||
CheckItem(
|
||||
"CP-11",
|
||||
"error",
|
||||
"summarization.model_name must be List[str], string legacy format detected",
|
||||
)
|
||||
)
|
||||
elif not isinstance(summary_model, list) or any(not isinstance(x, str) for x in summary_model):
|
||||
checks.append(
|
||||
CheckItem(
|
||||
"CP-11",
|
||||
"error",
|
||||
"summarization.model_name must be List[str]",
|
||||
)
|
||||
)
|
||||
|
||||
if summary_knowledge_type not in {item.value for item in KnowledgeType}:
|
||||
checks.append(
|
||||
CheckItem(
|
||||
"CP-13",
|
||||
"error",
|
||||
f"invalid summarization.default_knowledge_type: {summary_knowledge_type}",
|
||||
)
|
||||
)
|
||||
|
||||
if quantization != "int8":
|
||||
checks.append(
|
||||
CheckItem(
|
||||
"UG-07",
|
||||
"error",
|
||||
"embedding.quantization_type must be int8 in vNext",
|
||||
)
|
||||
)
|
||||
|
||||
vectors_dir = data_dir / "vectors"
|
||||
npy_path = vectors_dir / "vectors.npy"
|
||||
bin_path = vectors_dir / "vectors.bin"
|
||||
ids_bin_path = vectors_dir / "vectors_ids.bin"
|
||||
facts["vectors.npy_exists"] = npy_path.exists()
|
||||
facts["vectors.bin_exists"] = bin_path.exists()
|
||||
facts["vectors_ids.bin_exists"] = ids_bin_path.exists()
|
||||
|
||||
if npy_path.exists() and not (bin_path.exists() and ids_bin_path.exists()):
|
||||
checks.append(
|
||||
CheckItem(
|
||||
"CP-07",
|
||||
"error",
|
||||
"legacy vectors.npy detected; offline migrate required",
|
||||
{"npy_path": str(npy_path)},
|
||||
)
|
||||
)
|
||||
|
||||
metadata_db = data_dir / "metadata" / "metadata.db"
|
||||
facts["metadata_db_exists"] = metadata_db.exists()
|
||||
relation_count = 0
|
||||
if metadata_db.exists():
|
||||
conn = sqlite3.connect(str(metadata_db))
|
||||
try:
|
||||
has_schema_table = _sqlite_table_exists(conn, "schema_migrations")
|
||||
facts["schema_migrations_exists"] = has_schema_table
|
||||
has_paragraph_backfill = _sqlite_table_exists(conn, "paragraph_vector_backfill")
|
||||
facts["paragraph_vector_backfill_exists"] = has_paragraph_backfill
|
||||
if not has_schema_table:
|
||||
checks.append(
|
||||
CheckItem(
|
||||
"CP-08",
|
||||
"error",
|
||||
"schema_migrations table missing (legacy metadata schema)",
|
||||
)
|
||||
)
|
||||
else:
|
||||
row = conn.execute("SELECT MAX(version) FROM schema_migrations").fetchone()
|
||||
version = int(row[0]) if row and row[0] is not None else 0
|
||||
facts["schema_version"] = version
|
||||
if version != SCHEMA_VERSION:
|
||||
checks.append(
|
||||
CheckItem(
|
||||
"CP-08",
|
||||
"error",
|
||||
f"schema version mismatch: current={version}, expected={SCHEMA_VERSION}",
|
||||
)
|
||||
)
|
||||
elif not has_paragraph_backfill:
|
||||
checks.append(
|
||||
CheckItem(
|
||||
"CP-14",
|
||||
"error",
|
||||
"paragraph_vector_backfill table missing under current schema version",
|
||||
)
|
||||
)
|
||||
|
||||
if _sqlite_table_exists(conn, "relations"):
|
||||
row = conn.execute("SELECT COUNT(*) FROM relations").fetchone()
|
||||
relation_count = int(row[0]) if row and row[0] is not None else 0
|
||||
facts["relations_count"] = relation_count
|
||||
|
||||
conflicts = _collect_hash_alias_conflicts(conn)
|
||||
facts["alias_conflict_count"] = len(conflicts)
|
||||
if conflicts:
|
||||
checks.append(
|
||||
CheckItem(
|
||||
"CP-05",
|
||||
"error",
|
||||
"32-bit relation hash alias conflict detected",
|
||||
{"aliases": sorted(conflicts.keys())[:20], "total": len(conflicts)},
|
||||
)
|
||||
)
|
||||
|
||||
invalid_knowledge_types = _collect_invalid_knowledge_types(conn)
|
||||
facts["invalid_knowledge_type_values"] = invalid_knowledge_types
|
||||
if invalid_knowledge_types:
|
||||
checks.append(
|
||||
CheckItem(
|
||||
"CP-12",
|
||||
"error",
|
||||
"invalid paragraph knowledge_type values detected",
|
||||
{"values": invalid_knowledge_types[:20], "total": len(invalid_knowledge_types)},
|
||||
)
|
||||
)
|
||||
finally:
|
||||
conn.close()
|
||||
else:
|
||||
checks.append(
|
||||
CheckItem(
|
||||
"META-00",
|
||||
"warning",
|
||||
"metadata.db not found, schema checks skipped",
|
||||
)
|
||||
)
|
||||
|
||||
graph_meta_path = data_dir / "graph" / "graph_metadata.pkl"
|
||||
facts["graph_metadata_exists"] = graph_meta_path.exists()
|
||||
if relation_count > 0:
|
||||
if not graph_meta_path.exists():
|
||||
checks.append(
|
||||
CheckItem(
|
||||
"CP-06",
|
||||
"error",
|
||||
"relations exist but graph metadata missing",
|
||||
)
|
||||
)
|
||||
else:
|
||||
try:
|
||||
with open(graph_meta_path, "rb") as f:
|
||||
graph_meta = pickle.load(f)
|
||||
edge_hash_map = graph_meta.get("edge_hash_map", {})
|
||||
edge_hash_map_size = len(edge_hash_map) if isinstance(edge_hash_map, dict) else 0
|
||||
facts["edge_hash_map_size"] = edge_hash_map_size
|
||||
if edge_hash_map_size <= 0:
|
||||
checks.append(
|
||||
CheckItem(
|
||||
"CP-06",
|
||||
"error",
|
||||
"edge_hash_map missing/empty while relations exist",
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
checks.append(
|
||||
CheckItem(
|
||||
"CP-06",
|
||||
"error",
|
||||
f"failed to read graph metadata: {e}",
|
||||
)
|
||||
)
|
||||
|
||||
has_error = any(c.level == "error" for c in checks)
|
||||
return {
|
||||
"ok": not has_error,
|
||||
"checks": [c.to_dict() for c in checks],
|
||||
"facts": facts,
|
||||
}
|
||||
|
||||
|
||||
def _migrate_config(config_doc: Dict[str, Any]) -> Dict[str, Any]:
|
||||
changes: Dict[str, Any] = {}
|
||||
|
||||
routing = _ensure_table(config_doc, "routing")
|
||||
mode_raw = str(routing.get("tool_search_mode", "forward") or "").strip().lower()
|
||||
mode_new = mode_raw
|
||||
if mode_raw == "legacy" or mode_raw not in {"forward", "disabled"}:
|
||||
mode_new = "forward"
|
||||
if mode_new != mode_raw:
|
||||
routing["tool_search_mode"] = mode_new
|
||||
changes["routing.tool_search_mode"] = {"old": mode_raw, "new": mode_new}
|
||||
|
||||
summary = _ensure_table(config_doc, "summarization")
|
||||
summary_model = summary.get("model_name", ["auto"])
|
||||
if isinstance(summary_model, str):
|
||||
normalized = [summary_model.strip() or "auto"]
|
||||
summary["model_name"] = normalized
|
||||
changes["summarization.model_name"] = {"old": summary_model, "new": normalized}
|
||||
elif not isinstance(summary_model, list):
|
||||
normalized = ["auto"]
|
||||
summary["model_name"] = normalized
|
||||
changes["summarization.model_name"] = {"old": str(type(summary_model)), "new": normalized}
|
||||
elif any(not isinstance(x, str) for x in summary_model):
|
||||
normalized = [str(x).strip() for x in summary_model if str(x).strip()]
|
||||
if not normalized:
|
||||
normalized = ["auto"]
|
||||
summary["model_name"] = normalized
|
||||
changes["summarization.model_name"] = {"old": summary_model, "new": normalized}
|
||||
|
||||
default_knowledge_type = str(summary.get("default_knowledge_type", "narrative") or "").strip().lower()
|
||||
allowed_knowledge_types = {item.value for item in KnowledgeType}
|
||||
if default_knowledge_type not in allowed_knowledge_types:
|
||||
summary["default_knowledge_type"] = "narrative"
|
||||
changes["summarization.default_knowledge_type"] = {
|
||||
"old": default_knowledge_type,
|
||||
"new": "narrative",
|
||||
}
|
||||
|
||||
embedding = _ensure_table(config_doc, "embedding")
|
||||
quantization = str(embedding.get("quantization_type", "int8") or "").strip().lower()
|
||||
if quantization != "int8":
|
||||
embedding["quantization_type"] = "int8"
|
||||
changes["embedding.quantization_type"] = {"old": quantization, "new": "int8"}
|
||||
|
||||
return changes
|
||||
|
||||
|
||||
def _migrate_impl(config_path: Path, data_dir: Path, dry_run: bool) -> Dict[str, Any]:
|
||||
config_doc = _read_toml(config_path)
|
||||
result: Dict[str, Any] = {
|
||||
"config_path": str(config_path),
|
||||
"data_dir": str(data_dir),
|
||||
"dry_run": bool(dry_run),
|
||||
"steps": {},
|
||||
}
|
||||
|
||||
config_changes = _migrate_config(config_doc)
|
||||
result["steps"]["config"] = {"changed": bool(config_changes), "changes": config_changes}
|
||||
if config_changes and not dry_run:
|
||||
_write_toml(config_path, config_doc)
|
||||
|
||||
vectors_dir = data_dir / "vectors"
|
||||
vectors_dir.mkdir(parents=True, exist_ok=True)
|
||||
npy_path = vectors_dir / "vectors.npy"
|
||||
bin_path = vectors_dir / "vectors.bin"
|
||||
ids_bin_path = vectors_dir / "vectors_ids.bin"
|
||||
if npy_path.exists() and not (bin_path.exists() and ids_bin_path.exists()):
|
||||
if dry_run:
|
||||
result["steps"]["vector"] = {"migrated": False, "reason": "dry_run"}
|
||||
else:
|
||||
dim = _guess_vector_dimension(config_doc, vectors_dir)
|
||||
store = VectorStore(
|
||||
dimension=max(1, int(dim)),
|
||||
quantization_type=QuantizationType.INT8,
|
||||
data_dir=vectors_dir,
|
||||
)
|
||||
result["steps"]["vector"] = store.migrate_legacy_npy(vectors_dir)
|
||||
else:
|
||||
result["steps"]["vector"] = {"migrated": False, "reason": "not_required"}
|
||||
|
||||
metadata_dir = data_dir / "metadata"
|
||||
metadata_dir.mkdir(parents=True, exist_ok=True)
|
||||
metadata_db = metadata_dir / "metadata.db"
|
||||
triples: List[Tuple[str, str, str, str]] = []
|
||||
relation_count = 0
|
||||
|
||||
metadata_result: Dict[str, Any] = {"migrated": False, "reason": "not_required"}
|
||||
if metadata_db.exists():
|
||||
store = MetadataStore(data_dir=metadata_dir)
|
||||
store.connect(enforce_schema=False)
|
||||
try:
|
||||
if dry_run:
|
||||
metadata_result = {"migrated": False, "reason": "dry_run"}
|
||||
else:
|
||||
metadata_result = store.run_legacy_migration_for_vnext()
|
||||
relation_count = int(store.count_relations())
|
||||
if relation_count > 0:
|
||||
triples = [(str(s), str(p), str(o), str(h)) for s, p, o, h in store.get_all_triples()]
|
||||
finally:
|
||||
store.close()
|
||||
result["steps"]["metadata"] = metadata_result
|
||||
|
||||
graph_dir = data_dir / "graph"
|
||||
graph_dir.mkdir(parents=True, exist_ok=True)
|
||||
graph_matrix_format = str(_get_nested(config_doc, ("graph", "sparse_matrix_format"), "csr") or "csr")
|
||||
graph_store = GraphStore(matrix_format=graph_matrix_format, data_dir=graph_dir)
|
||||
graph_step: Dict[str, Any] = {
|
||||
"rebuilt": False,
|
||||
"mapped_hashes": 0,
|
||||
"relation_count": relation_count,
|
||||
"topology_rebuilt_from_relations": False,
|
||||
}
|
||||
if relation_count > 0:
|
||||
if dry_run:
|
||||
graph_step["reason"] = "dry_run"
|
||||
else:
|
||||
if graph_store.has_data():
|
||||
graph_store.load()
|
||||
|
||||
mapped = graph_store.rebuild_edge_hash_map(triples)
|
||||
|
||||
# 兜底:历史数据里 graph 节点/边与 relations 脱节时,直接从 relations 重建图。
|
||||
if mapped <= 0 or not graph_store.has_edge_hash_map():
|
||||
nodes = sorted({s for s, _, o, _ in triples} | {o for _, _, o, _ in triples})
|
||||
edges = [(s, o) for s, _, o, _ in triples]
|
||||
hashes = [h for _, _, _, h in triples]
|
||||
|
||||
graph_store.clear()
|
||||
if nodes:
|
||||
graph_store.add_nodes(nodes)
|
||||
if edges:
|
||||
mapped = graph_store.add_edges(edges, relation_hashes=hashes)
|
||||
else:
|
||||
mapped = 0
|
||||
graph_step.update(
|
||||
{
|
||||
"topology_rebuilt_from_relations": True,
|
||||
"rebuilt_nodes": len(nodes),
|
||||
"rebuilt_edges": int(graph_store.num_edges),
|
||||
}
|
||||
)
|
||||
|
||||
graph_store.save()
|
||||
graph_step.update({"rebuilt": True, "mapped_hashes": int(mapped)})
|
||||
else:
|
||||
graph_step["reason"] = "no_relations"
|
||||
result["steps"]["graph"] = graph_step
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _verify_impl(config_path: Path, data_dir: Path) -> Dict[str, Any]:
|
||||
checks: List[CheckItem] = []
|
||||
facts: Dict[str, Any] = {
|
||||
"config_path": str(config_path),
|
||||
"data_dir": str(data_dir),
|
||||
}
|
||||
|
||||
if not config_path.exists():
|
||||
checks.append(CheckItem("CFG-00", "error", f"config not found: {config_path}"))
|
||||
return {"ok": False, "checks": [c.to_dict() for c in checks], "facts": facts}
|
||||
|
||||
config_doc = _read_toml(config_path)
|
||||
mode = str(_get_nested(config_doc, ("routing", "tool_search_mode"), "forward") or "").strip().lower()
|
||||
if mode not in {"forward", "disabled"}:
|
||||
checks.append(CheckItem("CP-04", "error", f"invalid routing.tool_search_mode: {mode}"))
|
||||
|
||||
summary_model = _get_nested(config_doc, ("summarization", "model_name"), ["auto"])
|
||||
if not isinstance(summary_model, list) or any(not isinstance(x, str) for x in summary_model):
|
||||
checks.append(CheckItem("CP-11", "error", "summarization.model_name must be List[str]"))
|
||||
summary_knowledge_type = str(
|
||||
_get_nested(config_doc, ("summarization", "default_knowledge_type"), "narrative") or "narrative"
|
||||
).strip().lower()
|
||||
if summary_knowledge_type not in {item.value for item in KnowledgeType}:
|
||||
checks.append(
|
||||
CheckItem("CP-13", "error", f"invalid summarization.default_knowledge_type: {summary_knowledge_type}")
|
||||
)
|
||||
|
||||
quantization = str(_get_nested(config_doc, ("embedding", "quantization_type"), "int8") or "").strip().lower()
|
||||
if quantization != "int8":
|
||||
checks.append(CheckItem("UG-07", "error", "embedding.quantization_type must be int8"))
|
||||
|
||||
vectors_dir = data_dir / "vectors"
|
||||
npy_path = vectors_dir / "vectors.npy"
|
||||
bin_path = vectors_dir / "vectors.bin"
|
||||
ids_bin_path = vectors_dir / "vectors_ids.bin"
|
||||
if npy_path.exists() and not (bin_path.exists() and ids_bin_path.exists()):
|
||||
checks.append(CheckItem("CP-07", "error", "legacy vectors.npy still exists without bin migration"))
|
||||
|
||||
metadata_dir = data_dir / "metadata"
|
||||
store = MetadataStore(data_dir=metadata_dir)
|
||||
try:
|
||||
store.connect(enforce_schema=True)
|
||||
schema_version = store.get_schema_version()
|
||||
facts["schema_version"] = schema_version
|
||||
if schema_version != SCHEMA_VERSION:
|
||||
checks.append(CheckItem("CP-08", "error", f"schema version mismatch: {schema_version}"))
|
||||
|
||||
relation_count = int(store.count_relations())
|
||||
facts["relations_count"] = relation_count
|
||||
|
||||
conflicts = {}
|
||||
invalid_knowledge_types: List[str] = []
|
||||
db_path = metadata_dir / "metadata.db"
|
||||
if db_path.exists():
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
try:
|
||||
has_paragraph_backfill = _sqlite_table_exists(conn, "paragraph_vector_backfill")
|
||||
facts["paragraph_vector_backfill_exists"] = bool(has_paragraph_backfill)
|
||||
if not has_paragraph_backfill:
|
||||
checks.append(
|
||||
CheckItem(
|
||||
"CP-14",
|
||||
"error",
|
||||
"paragraph_vector_backfill table missing after migration",
|
||||
)
|
||||
)
|
||||
conflicts = _collect_hash_alias_conflicts(conn)
|
||||
invalid_knowledge_types = _collect_invalid_knowledge_types(conn)
|
||||
finally:
|
||||
conn.close()
|
||||
if conflicts:
|
||||
checks.append(
|
||||
CheckItem(
|
||||
"CP-05",
|
||||
"error",
|
||||
"alias conflicts still exist after migration",
|
||||
{"aliases": sorted(conflicts.keys())[:20], "total": len(conflicts)},
|
||||
)
|
||||
)
|
||||
if invalid_knowledge_types:
|
||||
checks.append(
|
||||
CheckItem(
|
||||
"CP-12",
|
||||
"error",
|
||||
"invalid paragraph knowledge_type values remain after migration",
|
||||
{"values": invalid_knowledge_types[:20], "total": len(invalid_knowledge_types)},
|
||||
)
|
||||
)
|
||||
|
||||
if relation_count > 0:
|
||||
graph_dir = data_dir / "graph"
|
||||
if not (graph_dir / "graph_metadata.pkl").exists():
|
||||
checks.append(CheckItem("CP-06", "error", "graph metadata missing while relations exist"))
|
||||
else:
|
||||
matrix_format = str(_get_nested(config_doc, ("graph", "sparse_matrix_format"), "csr") or "csr")
|
||||
graph_store = GraphStore(matrix_format=matrix_format, data_dir=graph_dir)
|
||||
graph_store.load()
|
||||
if not graph_store.has_edge_hash_map():
|
||||
checks.append(CheckItem("CP-06", "error", "edge_hash_map is empty"))
|
||||
except Exception as e:
|
||||
checks.append(CheckItem("CP-08", "error", f"metadata strict connect failed: {e}"))
|
||||
finally:
|
||||
try:
|
||||
store.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
has_error = any(c.level == "error" for c in checks)
|
||||
return {
|
||||
"ok": not has_error,
|
||||
"checks": [c.to_dict() for c in checks],
|
||||
"facts": facts,
|
||||
}
|
||||
|
||||
|
||||
def _print_report(title: str, report: Dict[str, Any]) -> None:
|
||||
print(f"=== {title} ===")
|
||||
print(f"ok: {bool(report.get('ok', True))}")
|
||||
facts = report.get("facts", {})
|
||||
if facts:
|
||||
print("facts:")
|
||||
for k in sorted(facts.keys()):
|
||||
print(f" - {k}: {facts[k]}")
|
||||
checks = report.get("checks", [])
|
||||
if checks:
|
||||
print("checks:")
|
||||
for item in checks:
|
||||
print(f" - [{item.get('level')}] {item.get('code')}: {item.get('message')}")
|
||||
else:
|
||||
print("checks: none")
|
||||
|
||||
|
||||
def _write_json_if_needed(path: str, payload: Dict[str, Any]) -> None:
|
||||
if not path:
|
||||
return
|
||||
out = Path(path).expanduser().resolve()
|
||||
out.parent.mkdir(parents=True, exist_ok=True)
|
||||
out.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
|
||||
print(f"json_out: {out}")
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = _build_arg_parser()
|
||||
args = parser.parse_args()
|
||||
config_path = resolve_repo_path(args.config, fallback=DEFAULT_CONFIG_PATH)
|
||||
if not config_path.exists():
|
||||
print(f"❌ config not found: {config_path}")
|
||||
return 2
|
||||
config_doc = _read_toml(config_path)
|
||||
data_dir = _resolve_data_dir(config_doc, args.data_dir)
|
||||
|
||||
if args.command == "preflight":
|
||||
report = _preflight_impl(config_path, data_dir)
|
||||
_print_report("vNext Preflight", report)
|
||||
_write_json_if_needed(args.json_out, report)
|
||||
has_error = any(item.get("level") == "error" for item in report.get("checks", []))
|
||||
if args.strict and has_error:
|
||||
return 1
|
||||
return 0
|
||||
|
||||
if args.command == "migrate":
|
||||
payload = _migrate_impl(config_path, data_dir, dry_run=bool(args.dry_run))
|
||||
print("=== vNext Migrate ===")
|
||||
print(json.dumps(payload, ensure_ascii=False, indent=2))
|
||||
|
||||
verify_report = None
|
||||
if args.verify_after and not args.dry_run:
|
||||
verify_report = _verify_impl(config_path, data_dir)
|
||||
_print_report("vNext Verify (after migrate)", verify_report)
|
||||
payload["verify_after"] = verify_report
|
||||
|
||||
_write_json_if_needed(args.json_out, payload)
|
||||
if verify_report is not None:
|
||||
has_error = any(item.get("level") == "error" for item in verify_report.get("checks", []))
|
||||
if has_error:
|
||||
return 1
|
||||
return 0
|
||||
|
||||
if args.command == "verify":
|
||||
report = _verify_impl(config_path, data_dir)
|
||||
_print_report("vNext Verify", report)
|
||||
_write_json_if_needed(args.json_out, report)
|
||||
has_error = any(item.get("level") == "error" for item in report.get("checks", []))
|
||||
if args.strict and has_error:
|
||||
return 1
|
||||
return 0
|
||||
|
||||
return 2
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
144
src/A_memorix/scripts/runtime_self_check.py
Normal file
144
src/A_memorix/scripts/runtime_self_check.py
Normal file
@@ -0,0 +1,144 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Run A_Memorix runtime self-check against real embedding/runtime configuration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import tomlkit
|
||||
|
||||
from _bootstrap import DEFAULT_CONFIG_PATH, DEFAULT_DATA_DIR, resolve_repo_path
|
||||
|
||||
|
||||
def _build_arg_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(description="A_Memorix runtime self-check")
|
||||
parser.add_argument(
|
||||
"--config",
|
||||
default=str(DEFAULT_CONFIG_PATH),
|
||||
help="config.toml path (default: config/a_memorix.toml)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--data-dir",
|
||||
default="",
|
||||
help="optional data dir override; default resolved from config.storage.data_dir",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use-config-data-dir",
|
||||
action="store_true",
|
||||
help="use config.storage.data_dir directly instead of an isolated temp dir",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sample-text",
|
||||
default="A_Memorix runtime self check",
|
||||
help="sample text used for real embedding probe",
|
||||
)
|
||||
parser.add_argument("--json", action="store_true", help="print JSON report")
|
||||
return parser
|
||||
|
||||
|
||||
if any(arg in {"-h", "--help"} for arg in sys.argv[1:]):
|
||||
_build_arg_parser().print_help()
|
||||
raise SystemExit(0)
|
||||
|
||||
from A_memorix.core.runtime.lifecycle_orchestrator import initialize_storage_async
|
||||
from A_memorix.core.utils.runtime_self_check import run_embedding_runtime_self_check
|
||||
|
||||
|
||||
def _load_config(path: Path) -> dict[str, Any]:
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
raw = tomlkit.load(f)
|
||||
return dict(raw) if isinstance(raw, dict) else {}
|
||||
|
||||
|
||||
def _nested_get(config: dict[str, Any], key: str, default: Any = None) -> Any:
|
||||
current: Any = config
|
||||
for part in key.split("."):
|
||||
if isinstance(current, dict) and part in current:
|
||||
current = current[part]
|
||||
else:
|
||||
return default
|
||||
return current
|
||||
|
||||
|
||||
class _PluginStub:
|
||||
def __init__(self, config: dict[str, Any]):
|
||||
self.config = config
|
||||
self.vector_store = None
|
||||
self.graph_store = None
|
||||
self.metadata_store = None
|
||||
self.embedding_manager = None
|
||||
self.sparse_index = None
|
||||
self.relation_write_service = None
|
||||
|
||||
def get_config(self, key: str, default: Any = None) -> Any:
|
||||
return _nested_get(self.config, key, default)
|
||||
|
||||
|
||||
async def _main_async(args: argparse.Namespace) -> int:
|
||||
config_path = resolve_repo_path(args.config, fallback=DEFAULT_CONFIG_PATH)
|
||||
if not config_path.exists():
|
||||
print(f"❌ 配置文件不存在: {config_path}")
|
||||
return 2
|
||||
|
||||
config = _load_config(config_path)
|
||||
temp_dir_ctx = None
|
||||
if args.data_dir:
|
||||
storage_dir = str(resolve_repo_path(args.data_dir, fallback=DEFAULT_DATA_DIR))
|
||||
elif args.use_config_data_dir:
|
||||
raw_data_dir = str(_nested_get(config, "storage.data_dir", "./data") or "./data").strip()
|
||||
storage_dir = str(resolve_repo_path(raw_data_dir, fallback=DEFAULT_DATA_DIR))
|
||||
else:
|
||||
temp_dir_ctx = tempfile.TemporaryDirectory(prefix="memorix-runtime-self-check-")
|
||||
storage_dir = temp_dir_ctx.name
|
||||
|
||||
storage_cfg = config.setdefault("storage", {})
|
||||
storage_cfg["data_dir"] = storage_dir
|
||||
|
||||
plugin = _PluginStub(config)
|
||||
try:
|
||||
await initialize_storage_async(plugin)
|
||||
report = await run_embedding_runtime_self_check(
|
||||
config=config,
|
||||
vector_store=plugin.vector_store,
|
||||
embedding_manager=plugin.embedding_manager,
|
||||
sample_text=str(args.sample_text or "A_Memorix runtime self check"),
|
||||
)
|
||||
report["data_dir"] = storage_dir
|
||||
report["isolated_data_dir"] = temp_dir_ctx is not None
|
||||
if args.json:
|
||||
print(json.dumps(report, ensure_ascii=False, indent=2))
|
||||
else:
|
||||
print("A_Memorix Runtime Self-Check")
|
||||
print(f"ok: {report.get('ok')}")
|
||||
print(f"code: {report.get('code')}")
|
||||
print(f"message: {report.get('message')}")
|
||||
print(f"configured_dimension: {report.get('configured_dimension')}")
|
||||
print(f"vector_store_dimension: {report.get('vector_store_dimension')}")
|
||||
print(f"detected_dimension: {report.get('detected_dimension')}")
|
||||
print(f"encoded_dimension: {report.get('encoded_dimension')}")
|
||||
print(f"elapsed_ms: {float(report.get('elapsed_ms', 0.0)):.2f}")
|
||||
return 0 if bool(report.get("ok")) else 1
|
||||
finally:
|
||||
if plugin.metadata_store is not None:
|
||||
try:
|
||||
plugin.metadata_store.close()
|
||||
except Exception:
|
||||
pass
|
||||
if temp_dir_ctx is not None:
|
||||
temp_dir_ctx.cleanup()
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = _build_arg_parser()
|
||||
args = parser.parse_args()
|
||||
return asyncio.run(_main_async(args))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
913
src/A_memorix/web/import.html
Normal file
913
src/A_memorix/web/import.html
Normal file
@@ -0,0 +1,913 @@
|
||||
<!doctype html>
|
||||
<html lang="zh-CN">
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>A_Memorix 导入中心</title>
|
||||
<style>
|
||||
:root {
|
||||
--bg: #0b1220;
|
||||
--panel: #111827;
|
||||
--border: #334155;
|
||||
--text: #e2e8f0;
|
||||
--muted: #94a3b8;
|
||||
--pri: #38bdf8;
|
||||
--ok: #10b981;
|
||||
--warn: #f59e0b;
|
||||
--err: #ef4444;
|
||||
}
|
||||
* { box-sizing: border-box; }
|
||||
body { margin: 0; color: var(--text); font-family: "Segoe UI", "Microsoft YaHei", sans-serif; background: radial-gradient(circle at 0 0, #0f2a4a, transparent 35%), var(--bg); }
|
||||
.page { width: min(1400px, 96vw); margin: 18px auto 28px; }
|
||||
.top { display: flex; gap: 10px; justify-content: space-between; align-items: end; margin-bottom: 12px; }
|
||||
.top-mid { flex: 1; display: flex; justify-content: center; align-items: end; }
|
||||
.title { font-size: 22px; font-weight: 700; }
|
||||
.sub { color: var(--muted); font-size: 12px; }
|
||||
.token { display: flex; gap: 8px; min-width: 360px; }
|
||||
.grid { display: grid; grid-template-columns: 420px 1fr; gap: 12px; }
|
||||
.card { background: rgba(17, 24, 39, 0.92); border: 1px solid var(--border); border-radius: 12px; overflow: hidden; }
|
||||
.hd { padding: 10px 12px; font-size: 13px; color: var(--pri); border-bottom: 1px solid var(--border); display: flex; justify-content: space-between; }
|
||||
.bd { padding: 12px; }
|
||||
.row { display: grid; grid-template-columns: 1fr 1fr; gap: 8px; margin-bottom: 8px; }
|
||||
.row3 { display: grid; grid-template-columns: 1fr 1fr 1fr; gap: 8px; margin-bottom: 8px; }
|
||||
label { display: block; color: var(--muted); font-size: 12px; margin-bottom: 4px; }
|
||||
input, select, textarea, button { font: inherit; }
|
||||
input, select, textarea { width: 100%; border: 1px solid var(--border); border-radius: 8px; padding: 7px 9px; background: #0f172a; color: var(--text); }
|
||||
input[type="checkbox"] { width: 16px; height: 16px; padding: 0; border-radius: 4px; accent-color: var(--pri); flex: 0 0 auto; }
|
||||
textarea { min-height: 120px; resize: vertical; }
|
||||
.checkline { display: flex; align-items: center; gap: 8px; color: var(--text); margin-bottom: 8px; line-height: 1.4; }
|
||||
.checkline:last-child { margin-bottom: 0; }
|
||||
.checkgrid { display: grid; grid-template-columns: 1fr 1fr; gap: 8px 14px; margin-bottom: 8px; }
|
||||
.file-pick { display: flex; align-items: center; gap: 8px; min-height: 38px; padding: 4px; border: 1px solid var(--border); border-radius: 8px; background: #0f172a; }
|
||||
.file-pick .tiny { margin: 0; white-space: nowrap; overflow: hidden; text-overflow: ellipsis; }
|
||||
.file-pick input[type="file"] { display: none; }
|
||||
.modal-mask { position: fixed; inset: 0; background: rgba(2, 6, 23, 0.72); display: none; align-items: center; justify-content: center; z-index: 120; padding: 20px; }
|
||||
.modal-mask.show { display: flex; }
|
||||
.modal { width: min(980px, 96vw); max-height: 86vh; background: #0b1528; border: 1px solid var(--border); border-radius: 12px; display: flex; flex-direction: column; overflow: hidden; }
|
||||
.modal-hd { display: flex; justify-content: space-between; align-items: center; padding: 10px 12px; border-bottom: 1px solid var(--border); color: var(--pri); }
|
||||
.modal-bd { padding: 12px; overflow: auto; line-height: 1.62; font-size: 14px; }
|
||||
.md h1, .md h2, .md h3 { margin: 14px 0 8px; color: #f8fafc; }
|
||||
.md p { margin: 8px 0; }
|
||||
.md ul { margin: 8px 0 8px 20px; padding: 0; }
|
||||
.md pre { margin: 10px 0; padding: 10px; border: 1px solid var(--border); border-radius: 8px; background: #0f172a; overflow: auto; }
|
||||
.md code { font-family: Consolas, Menlo, Monaco, monospace; font-size: 12px; }
|
||||
.md a { color: #67e8f9; text-decoration: none; }
|
||||
.md a:hover { text-decoration: underline; }
|
||||
.md blockquote { margin: 10px 0; padding: 8px 10px; border-left: 3px solid var(--pri); background: rgba(56, 189, 248, 0.08); color: #cbd5e1; }
|
||||
.tabs { display: flex; gap: 6px; margin-bottom: 8px; flex-wrap: wrap; }
|
||||
.tabs button { flex: 1 1 calc(33.33% - 6px); min-width: 104px; border: 1px solid var(--border); border-radius: 8px; background: #1f2937; color: var(--text); padding: 7px; cursor: pointer; }
|
||||
.tabs button.active { background: linear-gradient(120deg, #0ea5e9, #22d3ee); color: #01243a; font-weight: 700; border: none; }
|
||||
.panel { display: none; }
|
||||
.panel.active { display: block; }
|
||||
.btns { display: flex; gap: 8px; flex-wrap: wrap; }
|
||||
.btn { border: 1px solid var(--border); border-radius: 8px; padding: 7px 10px; background: #1f2937; color: var(--text); cursor: pointer; }
|
||||
.btn.p { background: linear-gradient(120deg, #0ea5e9, #22d3ee); color: #022f4d; border: none; font-weight: 700; }
|
||||
.btn.warn { border-color: #854d0e; color: #fbbf24; }
|
||||
.btn.err { border-color: #7f1d1d; color: #fca5a5; }
|
||||
.tiny { color: var(--muted); font-size: 12px; }
|
||||
.list { max-height: 150px; overflow: auto; border: 1px dashed var(--border); border-radius: 8px; padding: 6px; }
|
||||
.list-item { display: flex; justify-content: space-between; gap: 8px; padding: 5px; border-radius: 7px; }
|
||||
.list-item:hover { background: #1e293b; }
|
||||
.task-list { max-height: 260px; overflow: auto; display: grid; gap: 7px; }
|
||||
.task { border: 1px solid var(--border); border-radius: 8px; padding: 7px; cursor: pointer; background: #0f172a; }
|
||||
.task.active { border-color: var(--pri); }
|
||||
.badge { font-size: 11px; border: 1px solid var(--border); border-radius: 999px; padding: 1px 7px; }
|
||||
.b-run { color: #67e8f9; } .b-q { color: #c4b5fd; } .b-ok { color: #6ee7b7; } .b-err { color: #fca5a5; } .b-cancel { color: #fcd34d; }
|
||||
.bar { height: 7px; border-radius: 999px; background: #334155; overflow: hidden; margin-top: 5px; }
|
||||
.bar > div { height: 100%; background: linear-gradient(120deg, #0ea5e9, #22d3ee); width: 0; }
|
||||
.mgrid { display: grid; grid-template-columns: repeat(4, 1fr); gap: 7px; margin-top: 8px; }
|
||||
.m { border: 1px solid var(--border); border-radius: 8px; padding: 7px; background: #0f172a; }
|
||||
.m .k { color: var(--muted); font-size: 11px; } .m .v { font-size: 16px; font-weight: 700; margin-top: 3px; }
|
||||
table { width: 100%; border-collapse: collapse; font-size: 12px; }
|
||||
th, td { text-align: left; border-bottom: 1px solid var(--border); padding: 7px 5px; vertical-align: top; }
|
||||
tr.pick { cursor: pointer; } tr.pick.active { background: #10324f; }
|
||||
.foot { display: flex; justify-content: space-between; align-items: center; margin-top: 8px; }
|
||||
#toast { position: fixed; top: 12px; left: 50%; transform: translateX(-50%); background: #111827; border: 1px solid var(--border); border-radius: 8px; padding: 8px 11px; display: none; z-index: 99; }
|
||||
@media (max-width: 1100px) {
|
||||
.grid { grid-template-columns: 1fr; }
|
||||
.token { min-width: 0; width: 100%; }
|
||||
.top { flex-direction: column; align-items: stretch; }
|
||||
.top-mid { justify-content: flex-start; }
|
||||
.mgrid { grid-template-columns: repeat(2, 1fr); }
|
||||
.checkgrid { grid-template-columns: 1fr; }
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div id="toast"></div>
|
||||
<div id="guide-modal" class="modal-mask" role="dialog" aria-modal="true" aria-labelledby="guide-modal-title">
|
||||
<div class="modal">
|
||||
<div class="modal-hd">
|
||||
<div id="guide-modal-title">导入相关文档</div>
|
||||
<button class="btn" id="guide-close" type="button">关闭</button>
|
||||
</div>
|
||||
<div class="tiny" id="guide-meta" style="padding: 8px 12px 0;"></div>
|
||||
<div class="modal-bd md" id="guide-body"></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="page">
|
||||
<div class="top">
|
||||
<div>
|
||||
<div class="title">A_Memorix 导入中心</div>
|
||||
<div class="sub">可控并发、部分文件导入、粘贴导入、分块级进度</div>
|
||||
</div>
|
||||
<div class="top-mid">
|
||||
<button class="btn" id="guide-open" type="button">阅读导入文档</button>
|
||||
<button class="btn" type="button" onclick="window.open('/tuning', '_blank')">检索调优</button>
|
||||
</div>
|
||||
<div class="token">
|
||||
<input id="token-input" type="password" placeholder="可选:X-Memorix-Import-Token" />
|
||||
<button class="btn" id="token-save">保存</button>
|
||||
<button class="btn" id="token-clear">清空</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="grid">
|
||||
<div>
|
||||
<div class="card">
|
||||
<div class="hd">创建导入任务</div>
|
||||
<div class="bd">
|
||||
<div class="tabs">
|
||||
<button id="tab-upload" class="active">上传文件</button>
|
||||
<button id="tab-paste">粘贴导入</button>
|
||||
<button id="tab-raw">本地扫描</button>
|
||||
<button id="tab-openie">LPMM OpenIE</button>
|
||||
<button id="tab-convert">LPMM转换</button>
|
||||
<button id="tab-backfill">时序回填</button>
|
||||
<button id="tab-maibot">MaiBot迁移</button>
|
||||
</div>
|
||||
|
||||
<div class="row3">
|
||||
<div>
|
||||
<label>文件并发</label>
|
||||
<input id="file-concurrency" type="number" min="1" max="6" value="2" />
|
||||
</div>
|
||||
<div>
|
||||
<label>分块并发</label>
|
||||
<input id="chunk-concurrency" type="number" min="1" max="12" value="4" />
|
||||
</div>
|
||||
<div>
|
||||
<label>策略覆盖</label>
|
||||
<select id="strategy-override">
|
||||
<option value="auto">自动(auto)</option>
|
||||
<option value="narrative">叙事(narrative)</option>
|
||||
<option value="factual">事实(factual)</option>
|
||||
<option value="quote">引用(quote)</option>
|
||||
</select>
|
||||
</div>
|
||||
</div>
|
||||
<div class="row3">
|
||||
<div>
|
||||
<label>去重策略</label>
|
||||
<select id="dedupe-policy">
|
||||
<option value="content_hash">内容哈希(content_hash)</option>
|
||||
<option value="manifest">导入清单(manifest)</option>
|
||||
<option value="none">不去重(none)</option>
|
||||
</select>
|
||||
</div>
|
||||
<div>
|
||||
<label>聊天参考时间(chat_reference_time)</label>
|
||||
<input id="chat-reference-time" placeholder="2026/02/12 10:30" />
|
||||
</div>
|
||||
<div></div>
|
||||
</div>
|
||||
<div class="checkgrid tiny">
|
||||
<label class="checkline">
|
||||
<input id="llm-enabled" type="checkbox" checked />
|
||||
<span>启用 LLM 抽取</span>
|
||||
</label>
|
||||
<label class="checkline">
|
||||
<input id="chat-log" type="checkbox" />
|
||||
<span>按聊天日志抽取时间(chat_log)</span>
|
||||
</label>
|
||||
<label class="checkline">
|
||||
<input id="force-reimport" type="checkbox" />
|
||||
<span>强制重导(force)</span>
|
||||
</label>
|
||||
<label class="checkline">
|
||||
<input id="clear-manifest" type="checkbox" />
|
||||
<span>清理导入清单(clear_manifest)</span>
|
||||
</label>
|
||||
</div>
|
||||
|
||||
<div id="panel-upload" class="panel active">
|
||||
<div class="row">
|
||||
<div>
|
||||
<label>文本输入模式</label>
|
||||
<select id="upload-input-mode">
|
||||
<option value="text">文本(text)</option>
|
||||
<option value="json">JSON(json)</option>
|
||||
</select>
|
||||
</div>
|
||||
<div>
|
||||
<label>选择文件 (txt/md/json)</label>
|
||||
<div class="file-pick">
|
||||
<button class="btn" id="upload-file-pick" type="button">选择文件</button>
|
||||
<div class="tiny" id="upload-file-hint">未选择文件</div>
|
||||
<input id="upload-file-input" type="file" multiple accept=".txt,.md,.json" />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="tiny" id="upload-count">已选择 0 个文件</div>
|
||||
<div class="list" id="upload-files"></div>
|
||||
<div class="btns" style="margin-top: 8px;">
|
||||
<button class="btn p" id="upload-submit">提交上传任务</button>
|
||||
<button class="btn" id="upload-clear">清空文件列表</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div id="panel-paste" class="panel">
|
||||
<div class="row">
|
||||
<div>
|
||||
<label>粘贴模式</label>
|
||||
<select id="paste-input-mode">
|
||||
<option value="text">文本(text)</option>
|
||||
<option value="json">JSON(json)</option>
|
||||
</select>
|
||||
</div>
|
||||
<div>
|
||||
<label>名称(可选)</label>
|
||||
<input id="paste-name" placeholder="paste_时间戳" />
|
||||
</div>
|
||||
</div>
|
||||
<label>内容</label>
|
||||
<textarea id="paste-content" placeholder="粘贴 text/json"></textarea>
|
||||
<div class="btns" style="margin-top: 8px;">
|
||||
<button class="btn p" id="paste-submit">提交粘贴任务</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div id="panel-raw" class="panel">
|
||||
<div class="row">
|
||||
<div>
|
||||
<label>路径别名</label>
|
||||
<select id="raw-alias"></select>
|
||||
</div>
|
||||
<div>
|
||||
<label>相对路径(relative_path)</label>
|
||||
<input id="raw-relative-path" placeholder="raw 子目录(可选)" />
|
||||
</div>
|
||||
</div>
|
||||
<div class="row3">
|
||||
<div>
|
||||
<label>匹配规则(glob)</label>
|
||||
<input id="raw-glob" value="*" />
|
||||
</div>
|
||||
<div>
|
||||
<label>输入模式</label>
|
||||
<select id="raw-input-mode">
|
||||
<option value="text">文本(text)</option>
|
||||
<option value="json">JSON(json)</option>
|
||||
</select>
|
||||
</div>
|
||||
<label class="checkline tiny" style="margin-top: 20px;">
|
||||
<input id="raw-recursive" type="checkbox" checked />
|
||||
<span>递归扫描</span>
|
||||
</label>
|
||||
</div>
|
||||
<div class="btns" style="margin-top: 8px;">
|
||||
<button class="btn p" id="raw-submit">提交本地扫描任务</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div id="panel-openie" class="panel">
|
||||
<div class="row">
|
||||
<div>
|
||||
<label>路径别名</label>
|
||||
<select id="openie-alias"></select>
|
||||
</div>
|
||||
<div>
|
||||
<label>相对路径(relative_path)</label>
|
||||
<input id="openie-relative-path" placeholder="OpenIE 目录或文件" />
|
||||
</div>
|
||||
</div>
|
||||
<label class="checkline tiny">
|
||||
<input id="openie-include-all" type="checkbox" />
|
||||
<span>找不到 *-openie.json 时回退全部 .json</span>
|
||||
</label>
|
||||
<div class="btns" style="margin-top: 8px;">
|
||||
<button class="btn p" id="openie-submit">提交 OpenIE 任务</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div id="panel-convert" class="panel">
|
||||
<div class="row">
|
||||
<div>
|
||||
<label>输入别名</label>
|
||||
<select id="convert-alias"></select>
|
||||
</div>
|
||||
<div>
|
||||
<label>输入相对路径(relative_path)</label>
|
||||
<input id="convert-relative-path" placeholder="LPMM 数据目录" />
|
||||
</div>
|
||||
</div>
|
||||
<div class="row">
|
||||
<div>
|
||||
<label>目标别名</label>
|
||||
<select id="convert-target-alias"></select>
|
||||
</div>
|
||||
<div>
|
||||
<label>目标相对路径(relative_path)</label>
|
||||
<input id="convert-target-relative-path" placeholder="可选:目标子目录" />
|
||||
</div>
|
||||
</div>
|
||||
<div class="row3">
|
||||
<div>
|
||||
<label>向量维度(dimension)</label>
|
||||
<input id="convert-dimension" type="number" min="1" placeholder="默认读取配置" />
|
||||
</div>
|
||||
<div>
|
||||
<label>批大小(batch_size)</label>
|
||||
<input id="convert-batch-size" type="number" min="1" value="1024" />
|
||||
</div>
|
||||
<div></div>
|
||||
</div>
|
||||
<div class="tiny">将执行 staging 转换与切换,请确认输入目录正确。</div>
|
||||
<div class="btns" style="margin-top: 8px;">
|
||||
<button class="btn warn" id="convert-submit">提交 LPMM 转换任务</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div id="panel-backfill" class="panel">
|
||||
<div class="row">
|
||||
<div>
|
||||
<label>路径别名</label>
|
||||
<select id="backfill-alias"></select>
|
||||
</div>
|
||||
<div>
|
||||
<label>相对路径(relative_path)</label>
|
||||
<input id="backfill-relative-path" placeholder="默认插件 data 目录" />
|
||||
</div>
|
||||
</div>
|
||||
<div class="row3">
|
||||
<div>
|
||||
<label>处理上限(limit)</label>
|
||||
<input id="backfill-limit" type="number" min="1" value="100000" />
|
||||
</div>
|
||||
<label class="checkline tiny" style="margin-top: 20px;">
|
||||
<input id="backfill-dry-run" type="checkbox" />
|
||||
<span>仅预览(dry-run)</span>
|
||||
</label>
|
||||
<label class="checkline tiny" style="margin-top: 20px;">
|
||||
<input id="backfill-no-created" type="checkbox" />
|
||||
<span>禁用 created 回退(no-created-fallback)</span>
|
||||
</label>
|
||||
</div>
|
||||
<div class="btns" style="margin-top: 8px;">
|
||||
<button class="btn p" id="backfill-submit">提交时序回填任务</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div id="panel-maibot" class="panel">
|
||||
<div class="row">
|
||||
<div>
|
||||
<label>源数据库路径</label>
|
||||
<input id="maibot-source-db" placeholder="data/MaiBot.db" />
|
||||
</div>
|
||||
<div>
|
||||
<label>时间范围(可选)</label>
|
||||
<input id="maibot-time-from" placeholder="from: YYYY-MM-DD HH:mm" />
|
||||
</div>
|
||||
</div>
|
||||
<div class="row">
|
||||
<div>
|
||||
<label>时间范围(可选)</label>
|
||||
<input id="maibot-time-to" placeholder="to: YYYY-MM-DD HH:mm" />
|
||||
</div>
|
||||
<div>
|
||||
<label>ID范围(可选)</label>
|
||||
<input id="maibot-start-id" type="number" min="1" placeholder="start_id" />
|
||||
</div>
|
||||
</div>
|
||||
<div class="row">
|
||||
<div>
|
||||
<label>ID范围(可选)</label>
|
||||
<input id="maibot-end-id" type="number" min="1" placeholder="end_id" />
|
||||
</div>
|
||||
<div>
|
||||
<label>stream_ids(逗号分隔)</label>
|
||||
<input id="maibot-stream-ids" placeholder="stream1,stream2" />
|
||||
</div>
|
||||
</div>
|
||||
<div class="row">
|
||||
<div>
|
||||
<label>group_ids(逗号分隔)</label>
|
||||
<input id="maibot-group-ids" placeholder="123456,234567" />
|
||||
</div>
|
||||
<div>
|
||||
<label>user_ids(逗号分隔)</label>
|
||||
<input id="maibot-user-ids" placeholder="10001,10002" />
|
||||
</div>
|
||||
</div>
|
||||
<div class="row3">
|
||||
<div>
|
||||
<label>读取批大小(read_batch_size)</label>
|
||||
<input id="maibot-read-batch-size" type="number" min="1" value="2000" />
|
||||
</div>
|
||||
<div>
|
||||
<label>提交窗口行数(commit_window_rows)</label>
|
||||
<input id="maibot-commit-window-rows" type="number" min="1" value="20000" />
|
||||
</div>
|
||||
<div>
|
||||
<label>嵌入工作线程(embed_workers)</label>
|
||||
<input id="maibot-embed-workers" type="number" min="1" placeholder="默认读取配置" />
|
||||
</div>
|
||||
</div>
|
||||
<div class="checkgrid tiny">
|
||||
<label class="checkline">
|
||||
<input id="maibot-no-resume" type="checkbox" />
|
||||
<span>禁用断点续传(--no-resume)</span>
|
||||
</label>
|
||||
<label class="checkline">
|
||||
<input id="maibot-reset-state" type="checkbox" />
|
||||
<span>重置迁移状态(--reset-state)</span>
|
||||
</label>
|
||||
<label class="checkline">
|
||||
<input id="maibot-dry-run" type="checkbox" />
|
||||
<span>仅预览(--dry-run)</span>
|
||||
</label>
|
||||
<label class="checkline">
|
||||
<input id="maibot-verify-only" type="checkbox" />
|
||||
<span>仅校验(--verify-only)</span>
|
||||
</label>
|
||||
</div>
|
||||
<div class="btns" style="margin-top: 8px;">
|
||||
<button class="btn p" id="maibot-submit">提交 MaiBot 迁移任务</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="card" style="margin-top: 12px;">
|
||||
<div class="hd"><span>任务队列</span><span class="tiny" id="poll-meta">轮询: 1000ms</span></div>
|
||||
<div class="bd">
|
||||
<div class="btns" style="margin-bottom: 9px;">
|
||||
<button class="btn" id="refresh-btn">立即刷新</button>
|
||||
<button class="btn err" id="cancel-btn">取消任务</button>
|
||||
<button class="btn warn" id="retry-btn">重试失败项(分块优先)</button>
|
||||
</div>
|
||||
<div class="tiny">运行中 / 准备中</div>
|
||||
<div class="task-list" id="tasks-running"></div>
|
||||
<div class="tiny" style="margin-top: 8px;">排队中</div>
|
||||
<div class="task-list" id="tasks-queued"></div>
|
||||
<div class="tiny" style="margin-top: 8px;">最近完成</div>
|
||||
<div class="task-list" id="tasks-recent"></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<div class="card">
|
||||
<div class="hd">任务详情</div>
|
||||
<div class="bd">
|
||||
<div id="task-empty" class="tiny">请选择任务查看详情</div>
|
||||
<div id="task-body" style="display: none;">
|
||||
<div class="row">
|
||||
<div><div class="tiny">任务 ID</div><div id="d-id" style="font-size: 12px;"></div></div>
|
||||
<div><div class="tiny">状态 / 步骤</div><div id="d-status"></div></div>
|
||||
</div>
|
||||
<div class="bar"><div id="d-progress"></div></div>
|
||||
<div class="tiny" id="d-progress-text" style="margin-top: 5px;"></div>
|
||||
<div class="mgrid">
|
||||
<div class="m"><div class="k">总分块</div><div class="v" id="m-total">0</div></div>
|
||||
<div class="m"><div class="k">完成</div><div class="v" id="m-done">0</div></div>
|
||||
<div class="m"><div class="k">失败</div><div class="v" id="m-fail">0</div></div>
|
||||
<div class="m"><div class="k">取消</div><div class="v" id="m-cancel">0</div></div>
|
||||
</div>
|
||||
<div id="d-error" style="color: #fca5a5; margin-top: 6px;"></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="card" style="margin-top: 12px;">
|
||||
<div class="hd">文件级状态</div>
|
||||
<div class="bd">
|
||||
<table>
|
||||
<thead><tr><th>文件</th><th>类型</th><th>状态</th><th>步骤</th><th>进度</th><th>统计</th></tr></thead>
|
||||
<tbody id="files-body"></tbody>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="card" style="margin-top: 12px;">
|
||||
<div class="hd">分块级状态</div>
|
||||
<div class="bd">
|
||||
<table>
|
||||
<thead><tr><th>#</th><th>类型</th><th>状态</th><th>步骤</th><th>预览</th><th>错误</th></tr></thead>
|
||||
<tbody id="chunks-body"></tbody>
|
||||
</table>
|
||||
<div class="foot">
|
||||
<div class="tiny" id="chunk-meta">0 / 0</div>
|
||||
<div class="btns">
|
||||
<button class="btn" id="chunk-prev">上一页</button>
|
||||
<button class="btn" id="chunk-next">下一页</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
const state = { files: [], tasks: [], task: null, taskId: null, fileId: null, chunkOffset: 0, chunkLimit: 100, chunkTotal: 0, settings: {}, pathAliases: {}, pollMs: 1000, timer: null, pollErrSig: "", pollErrAt: 0 };
|
||||
const $ = (id) => document.getElementById(id);
|
||||
const esc = (s) => String(s ?? "").replaceAll("&","&").replaceAll("<","<").replaceAll(">",">").replaceAll('"',""").replaceAll("'","'");
|
||||
const pct = (n) => `${(Math.max(0, Math.min(1, Number(n || 0))) * 100).toFixed(1)}%`;
|
||||
const badgeClass = (s) => ["running","preparing","cancel_requested"].includes(s) ? "b-run" : s==="queued" ? "b-q" : ["completed","completed_with_errors"].includes(s) ? "b-ok" : s==="failed" ? "b-err" : s==="cancelled" ? "b-cancel" : "";
|
||||
const STATUS_ZH = {
|
||||
queued: "排队中",
|
||||
preparing: "准备中",
|
||||
running: "运行中",
|
||||
cancel_requested: "取消中",
|
||||
cancelled: "已取消",
|
||||
completed: "已完成",
|
||||
completed_with_errors: "完成(有错误)",
|
||||
failed: "失败",
|
||||
splitting: "分块中",
|
||||
extracting: "抽取中",
|
||||
writing: "写入中",
|
||||
saving: "保存中",
|
||||
};
|
||||
const STEP_ZH = {
|
||||
queued: "排队中",
|
||||
scanning: "扫描中",
|
||||
preparing: "准备中",
|
||||
splitting: "分块中",
|
||||
extracting: "抽取中",
|
||||
writing: "写入中",
|
||||
saving: "保存中",
|
||||
converting: "转换中",
|
||||
verifying: "校验中",
|
||||
switching: "切换中",
|
||||
backfilling: "回填中",
|
||||
cancel_requested: "取消中",
|
||||
cancelled: "已取消",
|
||||
completed: "已完成",
|
||||
completed_with_errors: "完成(有错误)",
|
||||
failed: "失败",
|
||||
};
|
||||
const STRATEGY_ZH = {
|
||||
auto: "自动",
|
||||
narrative: "叙事",
|
||||
factual: "事实",
|
||||
quote: "引用",
|
||||
json: "JSON",
|
||||
text: "文本",
|
||||
};
|
||||
const TASK_KIND_ZH = {
|
||||
upload: "上传文件",
|
||||
paste: "粘贴导入",
|
||||
raw_scan: "本地扫描",
|
||||
lpmm_openie: "LPMM OpenIE",
|
||||
lpmm_convert: "LPMM 转换",
|
||||
temporal_backfill: "时序回填",
|
||||
maibot_migration: "MaiBot 迁移",
|
||||
};
|
||||
const SCHEMA_ZH = {
|
||||
web_json: "Web JSON",
|
||||
script_json: "脚本 JSON",
|
||||
lpmm_openie: "LPMM OpenIE",
|
||||
plain_text: "纯文本",
|
||||
};
|
||||
const SOURCE_ZH = {
|
||||
upload: "上传文件",
|
||||
paste: "粘贴导入",
|
||||
raw_scan: "本地扫描",
|
||||
lpmm_openie: "LPMM OpenIE",
|
||||
lpmm_convert: "LPMM 转换",
|
||||
temporal_backfill: "时序回填",
|
||||
maibot_migration: "MaiBot 迁移",
|
||||
};
|
||||
const chunkTypeZh = (v) => STRATEGY_ZH[String(v || "").trim()] || String(v || "-");
|
||||
const statusZh = (v) => STATUS_ZH[String(v || "").trim()] || String(v || "-");
|
||||
const stepZh = (v) => STEP_ZH[String(v || "").trim()] || String(v || "-");
|
||||
const taskKindZh = (v) => TASK_KIND_ZH[String(v || "").trim()] || String(v || "-");
|
||||
const schemaZh = (v) => SCHEMA_ZH[String(v || "").trim()] || String(v || "-");
|
||||
const sourceZh = (v) => SOURCE_ZH[String(v || "").trim()] || String(v || "-");
|
||||
|
||||
function toast(msg) { const t=$("toast"); t.textContent=msg; t.style.display="block"; clearTimeout(window._tt); window._tt=setTimeout(()=>t.style.display="none",2200); }
|
||||
function token(){ return String(localStorage.getItem("memorix_import_token") || "").trim(); }
|
||||
function headers(isJson=true){ const h={}; if(isJson) h["Content-Type"]="application/json"; const tk=token(); if(tk) h["X-Memorix-Import-Token"]=tk; return h; }
|
||||
async function req(path,opt={}){ const r=await fetch(path,opt); let b=null; try{b=await r.json();}catch{} if(!r.ok) throw new Error(typeof b?.detail==="string"?b.detail:`请求失败(HTTP ${r.status})`); return b||{}; }
|
||||
const parseCsvList = (v) => String(v || "").split(",").map(x => x.trim()).filter(Boolean);
|
||||
const numOrNull = (v) => { const t = String(v ?? "").trim(); if(!t) return null; const n = Number(t); return Number.isFinite(n) ? n : null; };
|
||||
|
||||
function renderGuideMarkdown(md){
|
||||
const src = String(md || "").replace(/\r\n/g, "\n");
|
||||
const codeBlocks = [];
|
||||
let text = src.replace(/```([\s\S]*?)```/g, (_, code) => {
|
||||
const idx = codeBlocks.length;
|
||||
codeBlocks.push(`<pre><code>${esc(code).trim()}</code></pre>`);
|
||||
return `@@CODE_${idx}@@`;
|
||||
});
|
||||
text = esc(text);
|
||||
text = text.replace(/^###\s+(.+)$/gm, "<h3>$1</h3>");
|
||||
text = text.replace(/^##\s+(.+)$/gm, "<h2>$1</h2>");
|
||||
text = text.replace(/^#\s+(.+)$/gm, "<h1>$1</h1>");
|
||||
text = text.replace(/^>\s+(.+)$/gm, "<blockquote>$1</blockquote>");
|
||||
text = text.replace(/^\s*[-*]\s+(.+)$/gm, "<li>$1</li>");
|
||||
text = text.replace(/(?:<li>[\s\S]*?<\/li>\s*)+/g, (m) => `<ul>${m}</ul>`);
|
||||
text = text.replace(/\*\*([^*]+)\*\*/g, "<strong>$1</strong>");
|
||||
text = text.replace(/`([^`]+)`/g, "<code>$1</code>");
|
||||
text = text.replace(/\[([^\]]+)\]\((https?:\/\/[^)\s]+)\)/g, '<a href="$2" target="_blank" rel="noopener noreferrer">$1</a>');
|
||||
text = text.split(/\n{2,}/).map((blk) => {
|
||||
const b = blk.trim();
|
||||
if(!b) return "";
|
||||
if(/^@@CODE_\d+@@$/.test(b)) return b;
|
||||
if(/^<(h1|h2|h3|ul|pre|blockquote)/.test(b)) return b;
|
||||
return `<p>${b.replace(/\n/g, "<br/>")}</p>`;
|
||||
}).join("");
|
||||
text = text.replace(/@@CODE_(\d+)@@/g, (_, i) => codeBlocks[Number(i)] || "");
|
||||
return text || `<p class="tiny">文档为空</p>`;
|
||||
}
|
||||
|
||||
async function openGuideModal(){
|
||||
$("guide-modal").classList.add("show");
|
||||
$("guide-meta").textContent = "";
|
||||
$("guide-body").innerHTML = `<p class="tiny">正在加载文档...</p>`;
|
||||
try{
|
||||
const d = await req("/api/import/guide", { headers: headers(false) });
|
||||
const srcText = d.source === "remote" ? "远程" : "本地";
|
||||
const sourceDesc = d.source === "remote" ? (d.url || "") : (d.path || "");
|
||||
$("guide-meta").textContent = `来源: ${srcText}${sourceDesc ? ` | ${sourceDesc}` : ""}`;
|
||||
$("guide-body").innerHTML = renderGuideMarkdown(d.content || "");
|
||||
}catch(e){
|
||||
$("guide-meta").textContent = "";
|
||||
$("guide-body").innerHTML = `<p style="color:#fca5a5;">文档加载失败: ${esc(e.message || "")}</p>`;
|
||||
}
|
||||
}
|
||||
|
||||
function closeGuideModal(){ $("guide-modal").classList.remove("show"); }
|
||||
|
||||
function commonParams(){
|
||||
const mf=Number(state.settings.max_file_concurrency || 6), mc=Number(state.settings.max_chunk_concurrency || 12);
|
||||
const fc=Math.max(1,Math.min(mf,Number($("file-concurrency").value||2)));
|
||||
const cc=Math.max(1,Math.min(mc,Number($("chunk-concurrency").value||4)));
|
||||
$("file-concurrency").value=fc; $("chunk-concurrency").value=cc;
|
||||
return {
|
||||
file_concurrency: fc,
|
||||
chunk_concurrency: cc,
|
||||
llm_enabled: $("llm-enabled").checked,
|
||||
strategy_override: $("strategy-override").value,
|
||||
dedupe_policy: $("dedupe-policy").value,
|
||||
chat_log: !!$("chat-log").checked,
|
||||
chat_reference_time: ($("chat-reference-time").value || "").trim() || null,
|
||||
force: !!$("force-reimport").checked,
|
||||
clear_manifest: !!$("clear-manifest").checked,
|
||||
};
|
||||
}
|
||||
|
||||
function renderUploadFiles(){
|
||||
$("upload-count").textContent=`已选择 ${state.files.length} 个文件`;
|
||||
$("upload-file-hint").textContent = state.files.length ? `已选择 ${state.files.length} 个文件` : "未选择文件";
|
||||
$("upload-files").innerHTML = state.files.length ? state.files.map((f,i)=>`<div class="list-item"><div><div>${esc(f.name)}</div><div class="tiny">${(f.size/1024).toFixed(1)} KB</div></div><button class="btn" data-rm="${i}">移除</button></div>`).join("") : `<div class="tiny">暂无文件</div>`;
|
||||
}
|
||||
|
||||
function populateAliasSelects(){
|
||||
const aliases = state.pathAliases || {};
|
||||
const keys = Object.keys(aliases);
|
||||
const html = keys.length ? keys.map((k)=>`<option value="${esc(k)}">${esc(k)} (${esc(aliases[k])})</option>`).join("") : `<option value="">(无可用路径别名)</option>`;
|
||||
["raw-alias","openie-alias","convert-alias","convert-target-alias","backfill-alias"].forEach((id)=>{
|
||||
const el=$(id); if(!el) return;
|
||||
const old=el.value;
|
||||
el.innerHTML=html;
|
||||
if(old && keys.includes(old)) el.value=old;
|
||||
if(!el.value){
|
||||
if(id==="raw-alias" && keys.includes("raw")) el.value="raw";
|
||||
else if(id==="openie-alias" && keys.includes("lpmm")) el.value="lpmm";
|
||||
else if(id==="convert-alias" && keys.includes("lpmm")) el.value="lpmm";
|
||||
else if(id==="convert-target-alias" && keys.includes("plugin_data")) el.value="plugin_data";
|
||||
else if(id==="backfill-alias" && keys.includes("plugin_data")) el.value="plugin_data";
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
function renderTaskLists(){
|
||||
const g={run:[],q:[],r:[]};
|
||||
for(const t of state.tasks){ if(["running","preparing","cancel_requested"].includes(t.status)) g.run.push(t); else if(t.status==="queued") g.q.push(t); else g.r.push(t); }
|
||||
const render = (list) => list.length ? list.map(t=>`<div class="task ${state.taskId===t.task_id?"active":""}" data-tid="${t.task_id}"><div style="display:flex;justify-content:space-between;gap:8px;"><span>${esc(t.task_id.slice(0,12))}</span><span class="badge ${badgeClass(t.status)}">${esc(statusZh(t.status))}</span></div><div class="tiny">来源=${esc(sourceZh(t.source))} | 步骤=${esc(stepZh(t.current_step||"-"))}</div><div class="bar"><div style="width:${pct(t.progress)}"></div></div><div class="tiny">${pct(t.progress)} | ${t.done_chunks}/${t.total_chunks}</div></div>`).join("") : `<div class="tiny">暂无任务</div>`;
|
||||
$("tasks-running").innerHTML=render(g.run); $("tasks-queued").innerHTML=render(g.q); $("tasks-recent").innerHTML=render(g.r);
|
||||
}
|
||||
|
||||
async function loadTasks(){
|
||||
const d=await req("/api/import/tasks?limit=80",{headers:headers(false)});
|
||||
state.tasks=Array.isArray(d.items)?d.items:[]; state.settings=d.settings||state.settings||{};
|
||||
state.pathAliases = state.settings.path_aliases || state.pathAliases || {};
|
||||
const p=Number(state.settings.poll_interval_ms||1000); if(p!==state.pollMs){ state.pollMs=Math.max(200,p); restartPoll(); }
|
||||
$("poll-meta").textContent=`轮询: ${state.pollMs}ms`;
|
||||
if(state.settings.default_file_concurrency && !$("file-concurrency").dataset.touched) $("file-concurrency").value=state.settings.default_file_concurrency;
|
||||
if(state.settings.default_chunk_concurrency && !$("chunk-concurrency").dataset.touched) $("chunk-concurrency").value=state.settings.default_chunk_concurrency;
|
||||
if(state.settings.max_file_concurrency) $("file-concurrency").max=state.settings.max_file_concurrency;
|
||||
if(state.settings.max_chunk_concurrency) $("chunk-concurrency").max=state.settings.max_chunk_concurrency;
|
||||
if(state.settings.maibot_source_db_default && !$("maibot-source-db").dataset.touched) $("maibot-source-db").value=state.settings.maibot_source_db_default;
|
||||
populateAliasSelects();
|
||||
if(!state.taskId && state.tasks.length) state.taskId=state.tasks[0].task_id;
|
||||
if(state.taskId && !state.tasks.some(t=>t.task_id===state.taskId)){ state.taskId=state.tasks.length?state.tasks[0].task_id:null; state.fileId=null; }
|
||||
renderTaskLists();
|
||||
}
|
||||
|
||||
function renderTaskDetail(task){
|
||||
if(!task){ $("task-empty").style.display="block"; $("task-body").style.display="none"; $("files-body").innerHTML=`<tr><td colspan="6" class="tiny">暂无数据</td></tr>`; $("chunks-body").innerHTML=`<tr><td colspan="6" class="tiny">暂无数据</td></tr>`; $("chunk-meta").textContent="0 / 0"; return; }
|
||||
$("task-empty").style.display="none"; $("task-body").style.display="block";
|
||||
$("d-id").textContent=task.task_id; $("d-status").innerHTML=`<span class="badge ${badgeClass(task.status)}">${esc(statusZh(task.status))}</span> <span class="tiny">步骤=${esc(stepZh(task.current_step||"-"))}</span>`;
|
||||
$("d-progress").style.width=pct(task.progress); $("d-progress-text").textContent=`${pct(task.progress)} | 完成=${task.done_chunks} 失败=${task.failed_chunks} 取消=${task.cancelled_chunks} 总计=${task.total_chunks}`;
|
||||
$("m-total").textContent=task.total_chunks||0; $("m-done").textContent=task.done_chunks||0; $("m-fail").textContent=task.failed_chunks||0; $("m-cancel").textContent=task.cancelled_chunks||0;
|
||||
const extraMeta = [];
|
||||
if(task.schema_detected) extraMeta.push(`识别模式=${schemaZh(task.schema_detected)}`);
|
||||
if(task.task_kind) extraMeta.push(`任务类型=${taskKindZh(task.task_kind)}`);
|
||||
if(task.retry_parent_task_id) extraMeta.push(`父任务=${task.retry_parent_task_id}`);
|
||||
if(task.retry_summary && Object.keys(task.retry_summary).length) extraMeta.push(`重试摘要=${JSON.stringify(task.retry_summary)}`);
|
||||
if(task.artifact_paths && Object.keys(task.artifact_paths).length) extraMeta.push(`产物=${JSON.stringify(task.artifact_paths)}`);
|
||||
if(task.rollback_info && Object.keys(task.rollback_info).length) extraMeta.push(`回滚=${JSON.stringify(task.rollback_info)}`);
|
||||
const errText = task.error ? `错误: ${task.error}` : "";
|
||||
$("d-error").textContent = [errText, ...extraMeta].filter(Boolean).join(" | ");
|
||||
const files=Array.isArray(task.files)?task.files:[];
|
||||
$("files-body").innerHTML=files.length?files.map(f=>`<tr class="pick ${state.fileId===f.file_id?"active":""}" data-fid="${f.file_id}"><td>${esc(f.name)}</td><td>${esc(chunkTypeZh(f.detected_strategy_type||"-"))}</td><td><span class="badge ${badgeClass(f.status)}">${esc(statusZh(f.status))}</span></td><td>${esc(stepZh(f.current_step||"-"))}</td><td><div class="bar"><div style="width:${pct(f.progress)}"></div></div><div class="tiny">${pct(f.progress)}</div></td><td>${f.done_chunks}/${f.total_chunks} (失败:${f.failed_chunks} 取消:${f.cancelled_chunks})</td></tr>`).join(""):`<tr><td colspan="6" class="tiny">暂无文件</td></tr>`;
|
||||
}
|
||||
|
||||
async function loadTaskDetail(){
|
||||
if(!state.taskId){ renderTaskDetail(null); return; }
|
||||
const d=await req(`/api/import/tasks/${encodeURIComponent(state.taskId)}`,{headers:headers(false)});
|
||||
state.task=d.task||null; const files=Array.isArray(state.task?.files)?state.task.files:[]; if(!state.fileId || !files.some(x=>x.file_id===state.fileId)){ state.fileId=files.length?files[0].file_id:null; state.chunkOffset=0; }
|
||||
renderTaskDetail(state.task);
|
||||
}
|
||||
|
||||
async function loadChunks(){
|
||||
if(!state.taskId || !state.fileId){ $("chunks-body").innerHTML=`<tr><td colspan="6" class="tiny">请选择文件</td></tr>`; $("chunk-meta").textContent="0 / 0"; return; }
|
||||
const d=await req(`/api/import/tasks/${encodeURIComponent(state.taskId)}/files/${encodeURIComponent(state.fileId)}/chunks?offset=${state.chunkOffset}&limit=${state.chunkLimit}`,{headers:headers(false)});
|
||||
const it=Array.isArray(d.items)?d.items:[]; state.chunkTotal=Number(d.total||0);
|
||||
$("chunks-body").innerHTML=it.length?it.map(c=>`<tr><td>${c.index}</td><td>${esc(chunkTypeZh(c.chunk_type||"-"))}</td><td><span class="badge ${badgeClass(c.status)}">${esc(statusZh(c.status))}</span></td><td>${esc(stepZh(c.step||"-"))}</td><td title="${esc(c.content_preview||"")}">${esc(c.content_preview||"").slice(0,90)}</td><td style="color:#fca5a5">${esc(c.error||"")}</td></tr>`).join(""):`<tr><td colspan="6" class="tiny">暂无分块</td></tr>`;
|
||||
const from=state.chunkTotal?state.chunkOffset+1:0, to=Math.min(state.chunkOffset+state.chunkLimit,state.chunkTotal); $("chunk-meta").textContent=`${from}-${to} / ${state.chunkTotal}`;
|
||||
}
|
||||
|
||||
async function createUploadTask(){
|
||||
if(!state.files.length){ toast("请先选择文件"); return; }
|
||||
const fd=new FormData(); state.files.forEach(f=>fd.append("files[]",f,f.name)); fd.append("payload",JSON.stringify({ ...commonParams(), input_mode:$("upload-input-mode").value }));
|
||||
const d=await req("/api/import/tasks/upload",{method:"POST",headers:headers(false),body:fd}); if(d?.task?.task_id) state.taskId=d.task.task_id; state.files=[]; renderUploadFiles(); toast("上传任务已创建");
|
||||
}
|
||||
|
||||
async function createPasteTask(){
|
||||
const content=$("paste-content").value||""; if(!content.trim()){ toast("粘贴内容不能为空"); return; }
|
||||
const d=await req("/api/import/tasks/paste",{method:"POST",headers:headers(true),body:JSON.stringify({ ...commonParams(), input_mode:$("paste-input-mode").value, content, name:$("paste-name").value||"" })});
|
||||
if(d?.task?.task_id) state.taskId=d.task.task_id; toast("粘贴任务已创建");
|
||||
}
|
||||
|
||||
async function createRawScanTask(){
|
||||
const payload = {
|
||||
...commonParams(),
|
||||
alias: $("raw-alias").value,
|
||||
relative_path: ($("raw-relative-path").value || "").trim(),
|
||||
glob: ($("raw-glob").value || "*").trim() || "*",
|
||||
recursive: !!$("raw-recursive").checked,
|
||||
input_mode: $("raw-input-mode").value,
|
||||
};
|
||||
const d=await req("/api/import/tasks/raw_scan",{method:"POST",headers:headers(true),body:JSON.stringify(payload)});
|
||||
if(d?.task?.task_id) state.taskId=d.task.task_id;
|
||||
toast("本地扫描任务已创建");
|
||||
}
|
||||
|
||||
async function createLpmmOpenieTask(){
|
||||
const payload = {
|
||||
...commonParams(),
|
||||
alias: $("openie-alias").value,
|
||||
relative_path: ($("openie-relative-path").value || "").trim(),
|
||||
include_all_json: !!$("openie-include-all").checked,
|
||||
};
|
||||
const d=await req("/api/import/tasks/lpmm_openie",{method:"POST",headers:headers(true),body:JSON.stringify(payload)});
|
||||
if(d?.task?.task_id) state.taskId=d.task.task_id;
|
||||
toast("LPMM OpenIE 任务已创建");
|
||||
}
|
||||
|
||||
async function createLpmmConvertTask(){
|
||||
if(!confirm("该任务会执行 staging 转换并切换 vectors/graph/metadata,是否继续?")) return;
|
||||
const payload = {
|
||||
alias: $("convert-alias").value,
|
||||
relative_path: ($("convert-relative-path").value || "").trim(),
|
||||
target_alias: $("convert-target-alias").value,
|
||||
target_relative_path: ($("convert-target-relative-path").value || "").trim(),
|
||||
dimension: numOrNull($("convert-dimension").value),
|
||||
batch_size: numOrNull($("convert-batch-size").value),
|
||||
};
|
||||
const d=await req("/api/import/tasks/lpmm_convert",{method:"POST",headers:headers(true),body:JSON.stringify(payload)});
|
||||
if(d?.task?.task_id) state.taskId=d.task.task_id;
|
||||
toast("LPMM 转换任务已创建");
|
||||
}
|
||||
|
||||
async function createTemporalBackfillTask(){
|
||||
const payload = {
|
||||
alias: $("backfill-alias").value,
|
||||
relative_path: ($("backfill-relative-path").value || "").trim(),
|
||||
dry_run: !!$("backfill-dry-run").checked,
|
||||
no_created_fallback: !!$("backfill-no-created").checked,
|
||||
limit: numOrNull($("backfill-limit").value),
|
||||
};
|
||||
const d=await req("/api/import/tasks/temporal_backfill",{method:"POST",headers:headers(true),body:JSON.stringify(payload)});
|
||||
if(d?.task?.task_id) state.taskId=d.task.task_id;
|
||||
toast("时序回填任务已创建");
|
||||
}
|
||||
|
||||
async function createMaibotMigrationTask(){
|
||||
const payload = {
|
||||
source_db: ($("maibot-source-db").value || "").trim() || null,
|
||||
time_from: ($("maibot-time-from").value || "").trim() || null,
|
||||
time_to: ($("maibot-time-to").value || "").trim() || null,
|
||||
stream_ids: parseCsvList($("maibot-stream-ids").value),
|
||||
group_ids: parseCsvList($("maibot-group-ids").value),
|
||||
user_ids: parseCsvList($("maibot-user-ids").value),
|
||||
start_id: numOrNull($("maibot-start-id").value),
|
||||
end_id: numOrNull($("maibot-end-id").value),
|
||||
read_batch_size: numOrNull($("maibot-read-batch-size").value),
|
||||
commit_window_rows: numOrNull($("maibot-commit-window-rows").value),
|
||||
embed_workers: numOrNull($("maibot-embed-workers").value),
|
||||
no_resume: !!$("maibot-no-resume").checked,
|
||||
reset_state: !!$("maibot-reset-state").checked,
|
||||
dry_run: !!$("maibot-dry-run").checked,
|
||||
verify_only: !!$("maibot-verify-only").checked,
|
||||
};
|
||||
const d = await req("/api/import/tasks/maibot_migration", {
|
||||
method: "POST",
|
||||
headers: headers(true),
|
||||
body: JSON.stringify(payload),
|
||||
});
|
||||
if(d?.task?.task_id) state.taskId = d.task.task_id;
|
||||
toast("MaiBot 迁移任务已创建");
|
||||
}
|
||||
|
||||
async function cancelTask(){ if(!state.taskId){ toast("请先选择任务"); return; } await req(`/api/import/tasks/${encodeURIComponent(state.taskId)}/cancel`,{method:"POST",headers:headers(false)}); toast("已请求取消"); }
|
||||
async function retryTask(){
|
||||
if(!state.taskId){ toast("请先选择任务"); return; }
|
||||
const d=await req(`/api/import/tasks/${encodeURIComponent(state.taskId)}/retry_failed`,{method:"POST",headers:headers(true),body:JSON.stringify(commonParams())});
|
||||
if(d?.task?.task_id) state.taskId=d.task.task_id;
|
||||
const rs = d?.retry_summary || d?.task?.retry_summary || {};
|
||||
const chunkFiles = Number(rs?.chunk_retry_files || 0);
|
||||
const chunkCount = Number(rs?.chunk_retry_chunks || 0);
|
||||
const fileFallback = Number(rs?.file_fallback_files || 0);
|
||||
toast(`重试任务已创建:分块重试 ${chunkFiles} 文件/${chunkCount} 分块,文件回退 ${fileFallback} 文件`);
|
||||
}
|
||||
|
||||
async function poll(){
|
||||
try{
|
||||
await loadTasks(); await loadTaskDetail(); await loadChunks();
|
||||
state.pollErrSig = "";
|
||||
}catch(e){
|
||||
const sig = String(e.message || "请求失败");
|
||||
const now = Date.now();
|
||||
if(sig !== state.pollErrSig || now - state.pollErrAt > 5000){
|
||||
toast(`请求失败: ${sig}`);
|
||||
state.pollErrSig = sig;
|
||||
state.pollErrAt = now;
|
||||
}
|
||||
}
|
||||
}
|
||||
function restartPoll(){ if(state.timer) clearInterval(state.timer); state.timer=setInterval(poll,state.pollMs); }
|
||||
|
||||
function bind(){
|
||||
$("token-input").value=token();
|
||||
$("token-save").onclick=()=>{ localStorage.setItem("memorix_import_token",String($("token-input").value||"").trim()); toast("Token 已保存"); };
|
||||
$("token-clear").onclick=()=>{ localStorage.removeItem("memorix_import_token"); $("token-input").value=""; toast("Token 已清空"); };
|
||||
$("guide-open").onclick=()=>openGuideModal();
|
||||
$("guide-close").onclick=()=>closeGuideModal();
|
||||
$("guide-modal").onclick=(e)=>{ if(e.target === $("guide-modal")) closeGuideModal(); };
|
||||
window.addEventListener("keydown",(e)=>{ if(e.key==="Escape" && $("guide-modal").classList.contains("show")) closeGuideModal(); });
|
||||
const tabs = ["upload","paste","raw","openie","convert","backfill","maibot"];
|
||||
const activateTab = (name) => {
|
||||
tabs.forEach((t)=>{
|
||||
const tab = $(`tab-${t}`);
|
||||
const panel = $(`panel-${t}`);
|
||||
if(tab) tab.classList.toggle("active", t===name);
|
||||
if(panel) panel.classList.toggle("active", t===name);
|
||||
});
|
||||
if(["raw","openie"].includes(name)){
|
||||
$("dedupe-policy").value = "manifest";
|
||||
}else if(["upload","paste"].includes(name)){
|
||||
$("dedupe-policy").value = "content_hash";
|
||||
}
|
||||
};
|
||||
tabs.forEach((t)=>{ const tab = $(`tab-${t}`); if(tab) tab.onclick=()=>activateTab(t); });
|
||||
$("file-concurrency").oninput=(e)=>e.target.dataset.touched="1"; $("chunk-concurrency").oninput=(e)=>e.target.dataset.touched="1";
|
||||
$("maibot-source-db").oninput=(e)=>e.target.dataset.touched="1";
|
||||
$("upload-file-pick").onclick=()=>$("upload-file-input").click();
|
||||
$("upload-file-input").onchange=(e)=>{ Array.from(e.target.files||[]).forEach(f=>state.files.push(f)); e.target.value=""; renderUploadFiles(); };
|
||||
$("upload-files").onclick=(e)=>{ const i=e.target?.dataset?.rm; if(i===undefined) return; state.files.splice(Number(i),1); renderUploadFiles(); };
|
||||
$("upload-clear").onclick=()=>{ state.files=[]; renderUploadFiles(); };
|
||||
$("upload-submit").onclick=async()=>{ try{ await createUploadTask(); await poll(); }catch(e){ toast(`上传失败: ${e.message}`); } };
|
||||
$("paste-submit").onclick=async()=>{ try{ await createPasteTask(); await poll(); }catch(e){ toast(`粘贴失败: ${e.message}`); } };
|
||||
$("raw-submit").onclick=async()=>{ try{ await createRawScanTask(); await poll(); }catch(e){ toast(`本地扫描失败: ${e.message}`); } };
|
||||
$("openie-submit").onclick=async()=>{ try{ await createLpmmOpenieTask(); await poll(); }catch(e){ toast(`OpenIE 导入失败: ${e.message}`); } };
|
||||
$("convert-submit").onclick=async()=>{ try{ await createLpmmConvertTask(); await poll(); }catch(e){ toast(`LPMM 转换失败: ${e.message}`); } };
|
||||
$("backfill-submit").onclick=async()=>{ try{ await createTemporalBackfillTask(); await poll(); }catch(e){ toast(`回填任务失败: ${e.message}`); } };
|
||||
$("maibot-submit").onclick=async()=>{ try{ await createMaibotMigrationTask(); await poll(); }catch(e){ toast(`迁移任务创建失败: ${e.message}`); } };
|
||||
$("refresh-btn").onclick=async()=>{ await poll(); toast("已刷新"); };
|
||||
$("cancel-btn").onclick=async()=>{ try{ await cancelTask(); await poll(); }catch(e){ toast(`取消失败: ${e.message}`); } };
|
||||
$("retry-btn").onclick=async()=>{ try{ await retryTask(); await poll(); }catch(e){ toast(`重试失败: ${e.message}`); } };
|
||||
["tasks-running","tasks-queued","tasks-recent"].forEach(id=>$(id).onclick=async(e)=>{ const n=e.target.closest("[data-tid]"); if(!n) return; state.taskId=n.dataset.tid; state.fileId=null; state.chunkOffset=0; renderTaskLists(); try{ await loadTaskDetail(); await loadChunks(); }catch(err){ toast(`加载失败: ${err.message}`); } });
|
||||
$("files-body").onclick=async(e)=>{ const n=e.target.closest("[data-fid]"); if(!n) return; state.fileId=n.dataset.fid; state.chunkOffset=0; renderTaskDetail(state.task); try{ await loadChunks(); }catch(err){ toast(`加载分块失败: ${err.message}`); } };
|
||||
$("chunk-prev").onclick=async()=>{ if(state.chunkOffset<=0) return; state.chunkOffset=Math.max(0,state.chunkOffset-state.chunkLimit); try{ await loadChunks(); }catch(e){ toast(`翻页失败: ${e.message}`); } };
|
||||
$("chunk-next").onclick=async()=>{ if(state.chunkOffset+state.chunkLimit>=state.chunkTotal) return; state.chunkOffset+=state.chunkLimit; try{ await loadChunks(); }catch(e){ toast(`翻页失败: ${e.message}`); } };
|
||||
}
|
||||
|
||||
bind(); renderUploadFiles(); poll(); restartPoll();
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
3136
src/A_memorix/web/index.html
Normal file
3136
src/A_memorix/web/index.html
Normal file
File diff suppressed because it is too large
Load Diff
13
src/A_memorix/web/index.html.scratch
Normal file
13
src/A_memorix/web/index.html.scratch
Normal file
@@ -0,0 +1,13 @@
|
||||
function hexToRgba(hex, alpha) {
|
||||
let c;
|
||||
if (/^#([A-Fa-f0-9]{3}){1,2}$/.test(hex)) {
|
||||
c = hex.substring(1).split('');
|
||||
if (c.length === 3) {
|
||||
c = [c[0], c[0], c[1], c[1], c[2], c[2]];
|
||||
}
|
||||
c = '0x' + c.join('');
|
||||
return 'rgba(' + [(c >> 16) & 255, (c >> 8) & 255, c & 255].join(',') + ',' + alpha + ')';
|
||||
}
|
||||
// Fallback if not hex (e.g. already rgba or invalid)
|
||||
return hex;
|
||||
}
|
||||
722
src/A_memorix/web/tuning.html
Normal file
722
src/A_memorix/web/tuning.html
Normal file
@@ -0,0 +1,722 @@
|
||||
<!doctype html>
|
||||
<html lang="zh-CN">
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>A_Memorix 检索调优中心</title>
|
||||
<style>
|
||||
:root {
|
||||
--bg: #08101f;
|
||||
--panel: #0f1a2f;
|
||||
--border: #28415f;
|
||||
--text: #dbeafe;
|
||||
--muted: #93c5fd;
|
||||
--pri: #22d3ee;
|
||||
--ok: #34d399;
|
||||
--warn: #fbbf24;
|
||||
--err: #f87171;
|
||||
}
|
||||
* { box-sizing: border-box; }
|
||||
body {
|
||||
margin: 0;
|
||||
font-family: "Segoe UI", "Microsoft YaHei", sans-serif;
|
||||
background: radial-gradient(circle at 0 0, #12335d, transparent 34%), var(--bg);
|
||||
color: var(--text);
|
||||
}
|
||||
.page { width: min(1500px, 96vw); margin: 18px auto 28px; }
|
||||
.top { display: flex; justify-content: space-between; gap: 12px; align-items: end; margin-bottom: 12px; }
|
||||
.title { font-size: 24px; font-weight: 700; }
|
||||
.sub { color: var(--muted); font-size: 12px; margin-top: 4px; }
|
||||
.btns { display: flex; gap: 8px; flex-wrap: wrap; }
|
||||
.grid { display: grid; grid-template-columns: 460px 1fr; gap: 12px; }
|
||||
.card { background: rgba(15, 26, 47, 0.94); border: 1px solid var(--border); border-radius: 12px; overflow: hidden; }
|
||||
.hd { display: flex; justify-content: space-between; align-items: center; padding: 10px 12px; border-bottom: 1px solid var(--border); color: var(--pri); font-size: 13px; }
|
||||
.bd { padding: 12px; }
|
||||
label { display: block; color: var(--muted); font-size: 12px; margin-bottom: 4px; }
|
||||
input, select, textarea, button { font: inherit; }
|
||||
input, select, textarea {
|
||||
width: 100%;
|
||||
background: #0a1425;
|
||||
color: var(--text);
|
||||
border: 1px solid var(--border);
|
||||
border-radius: 8px;
|
||||
padding: 8px 10px;
|
||||
}
|
||||
textarea { min-height: 120px; resize: vertical; }
|
||||
.row { display: grid; grid-template-columns: 1fr 1fr; gap: 8px; margin-bottom: 8px; }
|
||||
.row3 { display: grid; grid-template-columns: 1fr 1fr 1fr; gap: 8px; margin-bottom: 8px; }
|
||||
.btn {
|
||||
border: 1px solid var(--border);
|
||||
border-radius: 8px;
|
||||
background: #152841;
|
||||
color: var(--text);
|
||||
padding: 8px 10px;
|
||||
cursor: pointer;
|
||||
}
|
||||
.btn.pri {
|
||||
border: none;
|
||||
color: #022f4d;
|
||||
font-weight: 700;
|
||||
background: linear-gradient(120deg, #22d3ee, #67e8f9);
|
||||
}
|
||||
.btn.warn { color: var(--warn); border-color: #854d0e; }
|
||||
.btn.err { color: var(--err); border-color: #7f1d1d; }
|
||||
.tiny { color: var(--muted); font-size: 12px; }
|
||||
.mono {
|
||||
font-family: Consolas, Menlo, Monaco, monospace;
|
||||
font-size: 12px;
|
||||
white-space: pre-wrap;
|
||||
word-break: break-word;
|
||||
}
|
||||
.split { display: grid; grid-template-columns: 1fr 1fr; gap: 10px; }
|
||||
.task-list { max-height: 280px; overflow: auto; border: 1px solid var(--border); border-radius: 8px; padding: 6px; }
|
||||
.task { border: 1px solid var(--border); background: #0a1425; border-radius: 8px; padding: 8px; margin-bottom: 7px; cursor: pointer; }
|
||||
.task:last-child { margin-bottom: 0; }
|
||||
.task.active { border-color: var(--pri); }
|
||||
.task .line { display: flex; justify-content: space-between; gap: 8px; margin-bottom: 4px; }
|
||||
.badge { border: 1px solid var(--border); border-radius: 999px; padding: 1px 8px; font-size: 11px; }
|
||||
table { width: 100%; border-collapse: collapse; font-size: 12px; }
|
||||
th, td { text-align: left; border-bottom: 1px solid var(--border); padding: 7px 4px; vertical-align: top; }
|
||||
#toast { position: fixed; top: 10px; left: 50%; transform: translateX(-50%); z-index: 100; display: none; padding: 8px 12px; border: 1px solid var(--border); border-radius: 8px; background: #0f172a; }
|
||||
.cmp-modal-mask {
|
||||
position: fixed;
|
||||
inset: 0;
|
||||
display: none;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
background: rgba(2, 6, 23, 0.72);
|
||||
z-index: 120;
|
||||
padding: 16px;
|
||||
}
|
||||
.cmp-modal-mask.show { display: flex; }
|
||||
.cmp-modal {
|
||||
width: min(1080px, 96vw);
|
||||
max-height: 90vh;
|
||||
background: #0a1425;
|
||||
border: 1px solid var(--border);
|
||||
border-radius: 12px;
|
||||
overflow: hidden;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
}
|
||||
.cmp-hd {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
padding: 10px 12px;
|
||||
border-bottom: 1px solid var(--border);
|
||||
color: var(--pri);
|
||||
font-size: 14px;
|
||||
}
|
||||
.cmp-bd {
|
||||
padding: 12px;
|
||||
overflow: auto;
|
||||
display: grid;
|
||||
gap: 10px;
|
||||
}
|
||||
.cmp-grid {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(3, minmax(0, 1fr));
|
||||
gap: 8px;
|
||||
}
|
||||
.cmp-metric {
|
||||
border: 1px solid var(--border);
|
||||
border-radius: 8px;
|
||||
padding: 8px;
|
||||
background: #0f1a2f;
|
||||
}
|
||||
.cmp-metric .k {
|
||||
color: var(--muted);
|
||||
font-size: 12px;
|
||||
}
|
||||
.cmp-metric .v {
|
||||
margin-top: 4px;
|
||||
font-size: 16px;
|
||||
font-weight: 700;
|
||||
}
|
||||
.cmp-delta-pos { color: var(--ok); }
|
||||
.cmp-delta-neg { color: var(--err); }
|
||||
.cmp-bars {
|
||||
margin-top: 6px;
|
||||
display: grid;
|
||||
gap: 4px;
|
||||
}
|
||||
.cmp-bar-line {
|
||||
display: grid;
|
||||
grid-template-columns: 54px 1fr;
|
||||
gap: 6px;
|
||||
align-items: center;
|
||||
}
|
||||
.cmp-bar-wrap {
|
||||
border: 1px solid var(--border);
|
||||
border-radius: 999px;
|
||||
height: 8px;
|
||||
overflow: hidden;
|
||||
background: #08101f;
|
||||
}
|
||||
.cmp-bar-fill {
|
||||
height: 100%;
|
||||
background: linear-gradient(120deg, #22d3ee, #67e8f9);
|
||||
}
|
||||
.cmp-subcard {
|
||||
border: 1px solid var(--border);
|
||||
border-radius: 8px;
|
||||
padding: 8px;
|
||||
background: #0f1a2f;
|
||||
}
|
||||
@media (max-width: 1150px) {
|
||||
.grid { grid-template-columns: 1fr; }
|
||||
.split { grid-template-columns: 1fr; }
|
||||
.row, .row3 { grid-template-columns: 1fr; }
|
||||
.cmp-grid { grid-template-columns: 1fr; }
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div id="toast"></div>
|
||||
<div id="cmp-modal-mask" class="cmp-modal-mask">
|
||||
<div class="cmp-modal">
|
||||
<div class="cmp-hd">
|
||||
<span>调优完成对比:Baseline vs Best</span>
|
||||
<button class="btn" id="cmp-close-btn">关闭</button>
|
||||
</div>
|
||||
<div class="cmp-bd">
|
||||
<div class="tiny" id="cmp-summary"></div>
|
||||
<div class="cmp-grid" id="cmp-metrics"></div>
|
||||
<div class="cmp-subcard">
|
||||
<div style="font-weight: 700; margin-bottom: 6px;">分类召回/精度对比</div>
|
||||
<table>
|
||||
<thead>
|
||||
<tr>
|
||||
<th>分类</th>
|
||||
<th>Baseline 召回</th>
|
||||
<th>Best 召回</th>
|
||||
<th>Δ召回</th>
|
||||
<th>Baseline P@1</th>
|
||||
<th>Best P@1</th>
|
||||
<th>ΔP@1</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody id="cmp-category-table"></tbody>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="page">
|
||||
<div class="top">
|
||||
<div>
|
||||
<div class="title">A_Memorix 检索调优中心</div>
|
||||
<div class="sub">LLM 辅助 + 多轮调查 + 运行时参数应用(不自动写 config.toml)</div>
|
||||
</div>
|
||||
<div class="btns">
|
||||
<button class="btn" onclick="window.open('/', '_blank')">打开主面板</button>
|
||||
<button class="btn" onclick="window.open('/import', '_blank')">打开导入中心</button>
|
||||
<button class="btn pri" id="refresh-all">刷新</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="grid">
|
||||
<div>
|
||||
<div class="card">
|
||||
<div class="hd">当前运行时参数</div>
|
||||
<div class="bd">
|
||||
<div class="tiny" id="settings-tip"></div>
|
||||
<pre id="current-profile" class="mono"></pre>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="card" style="margin-top: 10px;">
|
||||
<div class="hd">手动调参与应用</div>
|
||||
<div class="bd">
|
||||
<label>参数 JSON(支持局部字段)</label>
|
||||
<textarea id="manual-profile" class="mono"></textarea>
|
||||
<div class="btns" style="margin-top: 8px;">
|
||||
<button class="btn pri" id="btn-apply">应用到运行时</button>
|
||||
<button class="btn warn" id="btn-rollback">回滚上次应用</button>
|
||||
<button class="btn" id="btn-export">导出 TOML 片段</button>
|
||||
</div>
|
||||
<label style="margin-top: 10px;">TOML 导出</label>
|
||||
<textarea id="toml-snippet" class="mono" style="min-height: 90px;"></textarea>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="card" style="margin-top: 10px;">
|
||||
<div class="hd">创建自动调优任务</div>
|
||||
<div class="bd">
|
||||
<div class="row">
|
||||
<div>
|
||||
<label>目标函数</label>
|
||||
<select id="objective">
|
||||
<option value="precision_priority">precision_priority</option>
|
||||
<option value="balanced">balanced</option>
|
||||
<option value="recall_priority">recall_priority</option>
|
||||
</select>
|
||||
</div>
|
||||
<div>
|
||||
<label>强度</label>
|
||||
<select id="intensity">
|
||||
<option value="standard">standard</option>
|
||||
<option value="quick">quick</option>
|
||||
<option value="deep">deep</option>
|
||||
</select>
|
||||
</div>
|
||||
</div>
|
||||
<div class="row3">
|
||||
<div>
|
||||
<label>轮次(可选)</label>
|
||||
<input id="rounds" type="number" min="1" max="200" placeholder="留空走强度默认" />
|
||||
</div>
|
||||
<div>
|
||||
<label>样本数</label>
|
||||
<input id="sample-size" type="number" min="4" max="500" value="24" />
|
||||
</div>
|
||||
<div>
|
||||
<label>评估 top_k</label>
|
||||
<input id="top-k-eval" type="number" min="5" max="100" value="20" />
|
||||
</div>
|
||||
</div>
|
||||
<label style="display: flex; align-items: center; gap: 8px; margin-bottom: 8px;">
|
||||
<input id="llm-enabled" type="checkbox" checked style="width: 16px; height: 16px;" />
|
||||
<span>启用 LLM 问题生成/失败模式建议(不可用自动退化)</span>
|
||||
</label>
|
||||
<div class="btns">
|
||||
<button class="btn pri" id="btn-create-task">创建任务</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<div class="card">
|
||||
<div class="hd">
|
||||
<span>任务队列</span>
|
||||
<span class="tiny" id="task-meta"></span>
|
||||
</div>
|
||||
<div class="bd split">
|
||||
<div>
|
||||
<div class="task-list" id="task-list"></div>
|
||||
</div>
|
||||
<div>
|
||||
<label>任务详情</label>
|
||||
<pre id="task-detail" class="mono" style="min-height: 160px; max-height: 250px; overflow: auto; border: 1px solid var(--border); border-radius: 8px; padding: 8px;"></pre>
|
||||
<div class="btns" style="margin-top: 8px;">
|
||||
<button class="btn warn" id="btn-cancel-task">取消任务</button>
|
||||
<button class="btn pri" id="btn-apply-best">应用最优参数</button>
|
||||
<button class="btn" id="btn-load-report">加载报告</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="card" style="margin-top: 10px;">
|
||||
<div class="hd">
|
||||
<span>轮次明细</span>
|
||||
<span class="tiny" id="round-meta"></span>
|
||||
</div>
|
||||
<div class="bd">
|
||||
<table>
|
||||
<thead>
|
||||
<tr>
|
||||
<th>Round</th>
|
||||
<th>Score</th>
|
||||
<th>P@1</th>
|
||||
<th>P@3</th>
|
||||
<th>MRR</th>
|
||||
<th>Recall@K</th>
|
||||
<th>SPO hit</th>
|
||||
<th>Empty</th>
|
||||
<th>Latency(ms)</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody id="round-table"></tbody>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="card" style="margin-top: 10px;">
|
||||
<div class="hd">报告预览</div>
|
||||
<div class="bd">
|
||||
<textarea id="report-content" class="mono" style="min-height: 240px;"></textarea>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
const state = {
|
||||
settings: null,
|
||||
tasks: [],
|
||||
selectedTaskId: null,
|
||||
pollTimer: null,
|
||||
taskStatusMap: {},
|
||||
watchTaskIds: new Set(),
|
||||
completionPopupShown: new Set(),
|
||||
};
|
||||
|
||||
function _num(v, fallback = 0) {
|
||||
const x = Number(v);
|
||||
return Number.isFinite(x) ? x : fallback;
|
||||
}
|
||||
|
||||
function _pct(v) {
|
||||
return `${(_num(v, 0) * 100).toFixed(2)}%`;
|
||||
}
|
||||
|
||||
function _fmt(v, digits = 4) {
|
||||
return _num(v, 0).toFixed(digits);
|
||||
}
|
||||
|
||||
function _deltaClass(delta, reverse = false) {
|
||||
const d = _num(delta, 0);
|
||||
if (Math.abs(d) < 1e-12) return "";
|
||||
const positive = reverse ? d < 0 : d > 0;
|
||||
return positive ? "cmp-delta-pos" : "cmp-delta-neg";
|
||||
}
|
||||
|
||||
function _renderMetricCard({ name, base, best, percent = true, reverse = false }) {
|
||||
const b = _num(base, 0);
|
||||
const n = _num(best, 0);
|
||||
const d = n - b;
|
||||
const bTxt = percent ? _pct(b) : _fmt(b, 3);
|
||||
const nTxt = percent ? _pct(n) : _fmt(n, 3);
|
||||
const dTxt = `${d >= 0 ? "+" : ""}${percent ? _pct(d) : _fmt(d, 3)}`;
|
||||
const dClass = _deltaClass(d, reverse);
|
||||
const bWidth = percent ? Math.max(0, Math.min(100, b * 100)) : 100;
|
||||
const nWidth = percent ? Math.max(0, Math.min(100, n * 100)) : 100;
|
||||
return `
|
||||
<div class="cmp-metric">
|
||||
<div class="k">${name}</div>
|
||||
<div class="v">${bTxt} -> ${nTxt}</div>
|
||||
<div class="${dClass}" style="font-size:12px;">Δ ${dTxt}</div>
|
||||
<div class="cmp-bars">
|
||||
<div class="cmp-bar-line">
|
||||
<div class="tiny">Base</div>
|
||||
<div class="cmp-bar-wrap"><div class="cmp-bar-fill" style="width:${bWidth}%;opacity:0.75;"></div></div>
|
||||
</div>
|
||||
<div class="cmp-bar-line">
|
||||
<div class="tiny">Best</div>
|
||||
<div class="cmp-bar-wrap"><div class="cmp-bar-fill" style="width:${nWidth}%;"></div></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
`;
|
||||
}
|
||||
|
||||
function hideCompletionPopup() {
|
||||
document.getElementById("cmp-modal-mask").classList.remove("show");
|
||||
}
|
||||
|
||||
function showCompletionPopupContent(task) {
|
||||
const baseline = task.baseline_metrics || {};
|
||||
const best = task.best_metrics || {};
|
||||
const roundsDone = Number(task.rounds_done || 0);
|
||||
const roundsTotal = Number(task.rounds_total || 0);
|
||||
const summary = `任务 ${String(task.task_id || "").slice(0, 8)} 已完成,目标=${task.objective || "-"},轮次=${roundsDone}/${roundsTotal},best_score=${_fmt(task.best_score || 0, 6)}`;
|
||||
document.getElementById("cmp-summary").textContent = summary;
|
||||
|
||||
const metricCards = [
|
||||
{ name: "Precision@1", base: baseline.precision_at_1, best: best.precision_at_1, percent: true },
|
||||
{ name: "Precision@3", base: baseline.precision_at_3, best: best.precision_at_3, percent: true },
|
||||
{ name: "Recall@K", base: baseline.recall_at_k, best: best.recall_at_k, percent: true },
|
||||
{ name: "MRR", base: baseline.mrr, best: best.mrr, percent: true },
|
||||
{ name: "SPO Relation Hit", base: baseline.spo_relation_hit_rate, best: best.spo_relation_hit_rate, percent: true },
|
||||
{ name: "Empty Rate", base: baseline.empty_rate, best: best.empty_rate, percent: true, reverse: true },
|
||||
];
|
||||
document.getElementById("cmp-metrics").innerHTML = metricCards.map(_renderMetricCard).join("");
|
||||
|
||||
const baseCat = baseline.category || {};
|
||||
const bestCat = best.category || {};
|
||||
const keys = Array.from(new Set([...Object.keys(baseCat), ...Object.keys(bestCat)])).sort();
|
||||
const rows = [];
|
||||
for (const k of keys) {
|
||||
const b = baseCat[k] || {};
|
||||
const n = bestCat[k] || {};
|
||||
const bTot = Math.max(1, Number(b.total || 0));
|
||||
const nTot = Math.max(1, Number(n.total || 0));
|
||||
const bRecall = Number(b.hit || 0) / bTot;
|
||||
const nRecall = Number(n.hit || 0) / nTot;
|
||||
const bP1 = Number(b.hit_at_1 || 0) / bTot;
|
||||
const nP1 = Number(n.hit_at_1 || 0) / nTot;
|
||||
const dRecall = nRecall - bRecall;
|
||||
const dP1 = nP1 - bP1;
|
||||
rows.push(`
|
||||
<tr>
|
||||
<td>${k}</td>
|
||||
<td>${_pct(bRecall)} (${Number(b.hit || 0)}/${Number(b.total || 0)})</td>
|
||||
<td>${_pct(nRecall)} (${Number(n.hit || 0)}/${Number(n.total || 0)})</td>
|
||||
<td class="${_deltaClass(dRecall)}">${dRecall >= 0 ? "+" : ""}${_pct(dRecall)}</td>
|
||||
<td>${_pct(bP1)}</td>
|
||||
<td>${_pct(nP1)}</td>
|
||||
<td class="${_deltaClass(dP1)}">${dP1 >= 0 ? "+" : ""}${_pct(dP1)}</td>
|
||||
</tr>
|
||||
`);
|
||||
}
|
||||
document.getElementById("cmp-category-table").innerHTML = rows.join("") || '<tr><td colspan="7" class="tiny">无分类指标</td></tr>';
|
||||
document.getElementById("cmp-modal-mask").classList.add("show");
|
||||
}
|
||||
|
||||
async function tryShowCompletionPopup(taskId) {
|
||||
if (!taskId || state.completionPopupShown.has(taskId)) return;
|
||||
const body = await req(`/api/retrieval_tuning/tasks/${taskId}`);
|
||||
const task = body.task || {};
|
||||
if (task.status !== "completed") return;
|
||||
showCompletionPopupContent(task);
|
||||
state.completionPopupShown.add(taskId);
|
||||
}
|
||||
|
||||
function toast(msg, level = "info") {
|
||||
const el = document.getElementById("toast");
|
||||
el.style.display = "block";
|
||||
el.style.borderColor = level === "error" ? "#7f1d1d" : level === "warn" ? "#854d0e" : "#28415f";
|
||||
el.textContent = msg;
|
||||
setTimeout(() => { el.style.display = "none"; }, 2200);
|
||||
}
|
||||
|
||||
async function req(url, options = {}) {
|
||||
const resp = await fetch(url, options);
|
||||
const body = await resp.json().catch(() => ({}));
|
||||
if (!resp.ok || body.success === false) {
|
||||
throw new Error(body.detail || body.error || body.message || `HTTP ${resp.status}`);
|
||||
}
|
||||
return body;
|
||||
}
|
||||
|
||||
function pretty(obj) {
|
||||
return JSON.stringify(obj, null, 2);
|
||||
}
|
||||
|
||||
async function loadProfile() {
|
||||
const body = await req("/api/retrieval_tuning/profile");
|
||||
state.settings = body.settings || {};
|
||||
document.getElementById("settings-tip").textContent = `默认目标=${state.settings.default_objective},默认强度=${state.settings.default_intensity},轮询=${state.settings.poll_interval_ms}ms`;
|
||||
document.getElementById("current-profile").textContent = pretty(body.profile || {});
|
||||
document.getElementById("manual-profile").value = pretty(body.profile || {});
|
||||
|
||||
if (state.settings.default_objective) document.getElementById("objective").value = state.settings.default_objective;
|
||||
if (state.settings.default_intensity) document.getElementById("intensity").value = state.settings.default_intensity;
|
||||
if (state.settings.default_sample_size) document.getElementById("sample-size").value = state.settings.default_sample_size;
|
||||
if (state.settings.default_top_k_eval) document.getElementById("top-k-eval").value = state.settings.default_top_k_eval;
|
||||
}
|
||||
|
||||
function renderTaskList() {
|
||||
const list = document.getElementById("task-list");
|
||||
list.innerHTML = "";
|
||||
for (const task of state.tasks) {
|
||||
const div = document.createElement("div");
|
||||
div.className = `task${task.task_id === state.selectedTaskId ? " active" : ""}`;
|
||||
div.onclick = () => {
|
||||
state.selectedTaskId = task.task_id;
|
||||
renderTaskList();
|
||||
loadTaskDetail();
|
||||
};
|
||||
div.innerHTML = `
|
||||
<div class="line"><span>${task.task_id.slice(0, 8)}</span><span class="badge">${task.status}</span></div>
|
||||
<div class="line tiny"><span>${task.objective}</span><span>${Math.round((task.progress || 0) * 100)}%</span></div>
|
||||
<div class="tiny">round ${task.rounds_done || 0}/${task.rounds_total || 0}, best=${(task.best_score || 0).toFixed(4)}</div>
|
||||
`;
|
||||
list.appendChild(div);
|
||||
}
|
||||
document.getElementById("task-meta").textContent = `共 ${state.tasks.length} 个任务`;
|
||||
}
|
||||
|
||||
async function loadTasks() {
|
||||
const body = await req("/api/retrieval_tuning/tasks?limit=100");
|
||||
state.tasks = body.items || [];
|
||||
const prevMap = { ...state.taskStatusMap };
|
||||
state.taskStatusMap = {};
|
||||
let toPopupTaskId = null;
|
||||
for (const t of state.tasks) {
|
||||
state.taskStatusMap[t.task_id] = t.status;
|
||||
if (
|
||||
state.watchTaskIds.has(t.task_id) &&
|
||||
t.status === "completed" &&
|
||||
prevMap[t.task_id] &&
|
||||
prevMap[t.task_id] !== "completed" &&
|
||||
!state.completionPopupShown.has(t.task_id)
|
||||
) {
|
||||
toPopupTaskId = t.task_id;
|
||||
}
|
||||
}
|
||||
if (!state.selectedTaskId && state.tasks.length) {
|
||||
state.selectedTaskId = state.tasks[0].task_id;
|
||||
} else if (state.selectedTaskId && !state.tasks.find(x => x.task_id === state.selectedTaskId)) {
|
||||
state.selectedTaskId = state.tasks.length ? state.tasks[0].task_id : null;
|
||||
}
|
||||
renderTaskList();
|
||||
if (toPopupTaskId) {
|
||||
await tryShowCompletionPopup(toPopupTaskId);
|
||||
}
|
||||
}
|
||||
|
||||
async function loadTaskDetail() {
|
||||
const taskId = state.selectedTaskId;
|
||||
if (!taskId) {
|
||||
document.getElementById("task-detail").textContent = "";
|
||||
document.getElementById("round-table").innerHTML = "";
|
||||
return;
|
||||
}
|
||||
const body = await req(`/api/retrieval_tuning/tasks/${taskId}`);
|
||||
const task = body.task || {};
|
||||
document.getElementById("task-detail").textContent = pretty(task);
|
||||
|
||||
const rounds = await req(`/api/retrieval_tuning/tasks/${taskId}/rounds?offset=0&limit=400`);
|
||||
const tb = document.getElementById("round-table");
|
||||
tb.innerHTML = "";
|
||||
for (const row of rounds.items || []) {
|
||||
const m = row.metrics || {};
|
||||
const tr = document.createElement("tr");
|
||||
tr.innerHTML = `
|
||||
<td>${row.round_index}</td>
|
||||
<td>${Number(row.score || 0).toFixed(4)}</td>
|
||||
<td>${Number(m.precision_at_1 || 0).toFixed(4)}</td>
|
||||
<td>${Number(m.precision_at_3 || 0).toFixed(4)}</td>
|
||||
<td>${Number(m.mrr || 0).toFixed(4)}</td>
|
||||
<td>${Number(m.recall_at_k || 0).toFixed(4)}</td>
|
||||
<td>${Number(m.spo_relation_hit_rate || 0).toFixed(4)}</td>
|
||||
<td>${Number(m.empty_rate || 0).toFixed(4)}</td>
|
||||
<td>${Number(row.latency_ms || 0).toFixed(2)}</td>
|
||||
`;
|
||||
tb.appendChild(tr);
|
||||
}
|
||||
document.getElementById("round-meta").textContent = `total ${rounds.total || 0}`;
|
||||
}
|
||||
|
||||
async function loadReport() {
|
||||
if (!state.selectedTaskId) return;
|
||||
const body = await req(`/api/retrieval_tuning/tasks/${state.selectedTaskId}/report?format=md`);
|
||||
document.getElementById("report-content").value = body.content || "";
|
||||
}
|
||||
|
||||
async function refreshAll() {
|
||||
try {
|
||||
await loadProfile();
|
||||
await loadTasks();
|
||||
await loadTaskDetail();
|
||||
} catch (e) {
|
||||
toast(e.message || String(e), "error");
|
||||
}
|
||||
}
|
||||
|
||||
document.getElementById("refresh-all").onclick = refreshAll;
|
||||
|
||||
document.getElementById("btn-apply").onclick = async () => {
|
||||
try {
|
||||
const profile = JSON.parse(document.getElementById("manual-profile").value || "{}");
|
||||
await req("/api/retrieval_tuning/profile/apply", {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({ profile, reason: "web_manual_apply" }),
|
||||
});
|
||||
toast("参数已应用");
|
||||
await refreshAll();
|
||||
} catch (e) {
|
||||
toast(e.message || String(e), "error");
|
||||
}
|
||||
};
|
||||
|
||||
document.getElementById("btn-rollback").onclick = async () => {
|
||||
try {
|
||||
await req("/api/retrieval_tuning/profile/rollback", { method: "POST" });
|
||||
toast("已回滚");
|
||||
await refreshAll();
|
||||
} catch (e) {
|
||||
toast(e.message || String(e), "error");
|
||||
}
|
||||
};
|
||||
|
||||
document.getElementById("btn-export").onclick = async () => {
|
||||
try {
|
||||
const body = await req("/api/retrieval_tuning/profile/export_toml");
|
||||
document.getElementById("toml-snippet").value = body.toml || "";
|
||||
toast("已导出 TOML");
|
||||
} catch (e) {
|
||||
toast(e.message || String(e), "error");
|
||||
}
|
||||
};
|
||||
|
||||
document.getElementById("btn-create-task").onclick = async () => {
|
||||
try {
|
||||
const rounds = document.getElementById("rounds").value.trim();
|
||||
const payload = {
|
||||
objective: document.getElementById("objective").value,
|
||||
intensity: document.getElementById("intensity").value,
|
||||
sample_size: Number(document.getElementById("sample-size").value || 24),
|
||||
top_k_eval: Number(document.getElementById("top-k-eval").value || 20),
|
||||
llm_enabled: document.getElementById("llm-enabled").checked,
|
||||
};
|
||||
if (rounds) payload.rounds = Number(rounds);
|
||||
const body = await req("/api/retrieval_tuning/tasks", {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify(payload),
|
||||
});
|
||||
toast("任务已创建");
|
||||
const newTaskId = body.task?.task_id || "";
|
||||
if (newTaskId) {
|
||||
state.watchTaskIds.add(newTaskId);
|
||||
state.selectedTaskId = newTaskId;
|
||||
}
|
||||
await refreshAll();
|
||||
} catch (e) {
|
||||
toast(e.message || String(e), "error");
|
||||
}
|
||||
};
|
||||
|
||||
document.getElementById("btn-cancel-task").onclick = async () => {
|
||||
if (!state.selectedTaskId) return;
|
||||
try {
|
||||
await req(`/api/retrieval_tuning/tasks/${state.selectedTaskId}/cancel`, { method: "POST" });
|
||||
toast("已发送取消请求", "warn");
|
||||
await refreshAll();
|
||||
} catch (e) {
|
||||
toast(e.message || String(e), "error");
|
||||
}
|
||||
};
|
||||
|
||||
document.getElementById("btn-apply-best").onclick = async () => {
|
||||
if (!state.selectedTaskId) return;
|
||||
try {
|
||||
await req(`/api/retrieval_tuning/tasks/${state.selectedTaskId}/apply_best`, { method: "POST" });
|
||||
toast("最优参数已应用");
|
||||
await refreshAll();
|
||||
} catch (e) {
|
||||
toast(e.message || String(e), "error");
|
||||
}
|
||||
};
|
||||
|
||||
document.getElementById("btn-load-report").onclick = async () => {
|
||||
try {
|
||||
await loadReport();
|
||||
toast("报告已加载");
|
||||
} catch (e) {
|
||||
toast(e.message || String(e), "error");
|
||||
}
|
||||
};
|
||||
|
||||
document.getElementById("cmp-close-btn").onclick = hideCompletionPopup;
|
||||
document.getElementById("cmp-modal-mask").onclick = (e) => {
|
||||
if (e.target && e.target.id === "cmp-modal-mask") {
|
||||
hideCompletionPopup();
|
||||
}
|
||||
};
|
||||
|
||||
function startPolling() {
|
||||
const ms = Number(state.settings?.poll_interval_ms || 1200);
|
||||
if (state.pollTimer) clearInterval(state.pollTimer);
|
||||
state.pollTimer = setInterval(async () => {
|
||||
try {
|
||||
await loadTasks();
|
||||
await loadTaskDetail();
|
||||
} catch (_) {}
|
||||
}, Math.max(400, ms));
|
||||
}
|
||||
|
||||
(async () => {
|
||||
await refreshAll();
|
||||
startPolling();
|
||||
})();
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
133
src/chat/brain_chat/PFC/pfc_KnowledgeFetcher.py
Normal file
133
src/chat/brain_chat/PFC/pfc_KnowledgeFetcher.py
Normal file
@@ -0,0 +1,133 @@
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from src.person_info.person_info import resolve_person_id_for_memory
|
||||
from src.services.memory_service import memory_service
|
||||
|
||||
logger = get_logger("knowledge_fetcher")
|
||||
|
||||
|
||||
class KnowledgeFetcher:
|
||||
"""知识调取器"""
|
||||
|
||||
def __init__(self, private_name: str, stream_id: str):
|
||||
self.private_name = private_name
|
||||
self.stream_id = stream_id
|
||||
|
||||
def _resolve_private_memory_context(self) -> Dict[str, str]:
|
||||
session = _chat_manager.get_session_by_session_id(self.stream_id)
|
||||
if session is None:
|
||||
return {"chat_id": self.stream_id}
|
||||
|
||||
group_id = str(getattr(session, "group_id", "") or "").strip()
|
||||
user_id = str(getattr(session, "user_id", "") or "").strip()
|
||||
platform = str(getattr(session, "platform", "") or "").strip()
|
||||
|
||||
person_id = ""
|
||||
if not group_id:
|
||||
try:
|
||||
person_id = resolve_person_id_for_memory(
|
||||
person_name=self.private_name,
|
||||
platform=platform,
|
||||
user_id=user_id,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.debug(f"[私聊][{self.private_name}]解析人物ID失败: {exc}")
|
||||
|
||||
return {
|
||||
"chat_id": self.stream_id,
|
||||
"person_id": person_id,
|
||||
"user_id": user_id,
|
||||
"group_id": group_id,
|
||||
}
|
||||
|
||||
async def _memory_get_knowledge(self, query: str) -> str:
|
||||
"""获取相关知识
|
||||
|
||||
Args:
|
||||
query: 查询内容
|
||||
|
||||
Returns:
|
||||
str: 构造好的,带相关度的知识
|
||||
"""
|
||||
|
||||
logger.debug(f"[私聊][{self.private_name}]正在从长期记忆中获取知识")
|
||||
try:
|
||||
context = self._resolve_private_memory_context()
|
||||
search_kwargs = {
|
||||
"limit": 5,
|
||||
"mode": "search",
|
||||
"chat_id": context.get("chat_id", ""),
|
||||
"person_id": context.get("person_id", ""),
|
||||
"user_id": context.get("user_id", ""),
|
||||
"group_id": context.get("group_id", ""),
|
||||
"respect_filter": True,
|
||||
}
|
||||
result = await memory_service.search(query, **search_kwargs)
|
||||
if not result.success:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]长期记忆查询失败: {result.error or '未知错误'}"
|
||||
)
|
||||
return f"长期记忆检索失败:{result.error or '未知错误'}"
|
||||
if not result.filtered and not result.hits and search_kwargs["person_id"]:
|
||||
fallback_kwargs = dict(search_kwargs)
|
||||
fallback_kwargs["person_id"] = ""
|
||||
logger.debug(f"[私聊][{self.private_name}]人物过滤未命中,退回仅按会话检索长期记忆")
|
||||
result = await memory_service.search(query, **fallback_kwargs)
|
||||
if not result.success:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]长期记忆回退查询失败: {result.error or '未知错误'}"
|
||||
)
|
||||
return f"长期记忆检索失败:{result.error or '未知错误'}"
|
||||
knowledge_info = result.to_text(limit=5)
|
||||
if result.filtered:
|
||||
logger.debug(f"[私聊][{self.private_name}]长期记忆查询被聊天过滤策略跳过")
|
||||
else:
|
||||
logger.debug(f"[私聊][{self.private_name}]长期记忆查询结果: {knowledge_info[:150]}")
|
||||
return knowledge_info or "未找到匹配的知识"
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]长期记忆搜索工具执行失败: {str(e)}")
|
||||
return "未找到匹配的知识"
|
||||
|
||||
async def fetch(self, query: str, chat_history: List[Dict[str, Any]]) -> Tuple[str, str]:
|
||||
"""获取相关知识
|
||||
|
||||
Args:
|
||||
query: 查询内容
|
||||
chat_history: 聊天历史 (PFC dict format)
|
||||
|
||||
Returns:
|
||||
Tuple[str, str]: (获取的知识, 知识来源)
|
||||
"""
|
||||
_ = chat_history
|
||||
|
||||
# NOTE: Hippocampus memory system was redesigned in v0.12.2
|
||||
# The old get_memory_from_text API no longer exists
|
||||
# For now, we'll skip the memory retrieval part and only use LPMM knowledge
|
||||
# TODO: Integrate with new memory system if needed
|
||||
knowledge_text = ""
|
||||
sources_text = "无记忆匹配" # 默认值
|
||||
|
||||
# # 从记忆中获取相关知识 (DISABLED - old Hippocampus API)
|
||||
# related_memory = await HippocampusManager.get_instance().get_memory_from_text(
|
||||
# text=f"{query}\n{chat_history_text}",
|
||||
# max_memory_num=3,
|
||||
# max_memory_length=2,
|
||||
# max_depth=3,
|
||||
# fast_retrieval=False,
|
||||
# )
|
||||
# if related_memory:
|
||||
# sources = []
|
||||
# for memory in related_memory:
|
||||
# knowledge_text += memory[1] + "\n"
|
||||
# sources.append(f"记忆片段{memory[0]}")
|
||||
# knowledge_text = knowledge_text.strip()
|
||||
# sources_text = ",".join(sources)
|
||||
|
||||
knowledge_text += "\n现在有以下**知识**可供参考:\n "
|
||||
knowledge_text += await self._memory_get_knowledge(query)
|
||||
knowledge_text += "\n请记住这些**知识**,并根据**知识**回答问题。\n"
|
||||
|
||||
return knowledge_text or "未找到相关知识", sources_text or "无记忆匹配"
|
||||
@@ -1,90 +0,0 @@
|
||||
from src.chat.knowledge.embedding_store import EmbeddingManager
|
||||
from src.chat.knowledge.qa_manager import QAManager
|
||||
from src.chat.knowledge.kg_manager import KGManager
|
||||
from src.chat.knowledge.global_logger import logger
|
||||
from src.config.config import global_config
|
||||
import os
|
||||
|
||||
INVALID_ENTITY = [
|
||||
"",
|
||||
"你",
|
||||
"他",
|
||||
"她",
|
||||
"它",
|
||||
"我们",
|
||||
"你们",
|
||||
"他们",
|
||||
"她们",
|
||||
"它们",
|
||||
]
|
||||
|
||||
RAG_GRAPH_NAMESPACE = "rag-graph"
|
||||
RAG_ENT_CNT_NAMESPACE = "rag-ent-cnt"
|
||||
RAG_PG_HASH_NAMESPACE = "rag-pg-hash"
|
||||
|
||||
|
||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||||
DATA_PATH = os.path.join(ROOT_PATH, "data")
|
||||
|
||||
|
||||
qa_manager = None
|
||||
inspire_manager = None
|
||||
|
||||
|
||||
def get_qa_manager():
|
||||
return qa_manager
|
||||
|
||||
|
||||
def lpmm_start_up(): # sourcery skip: extract-duplicate-method
|
||||
# 检查LPMM知识库是否启用
|
||||
if global_config.lpmm_knowledge.enable:
|
||||
logger.info("正在初始化Mai-LPMM")
|
||||
logger.info("创建LLM客户端")
|
||||
|
||||
# 初始化Embedding库
|
||||
embed_manager = EmbeddingManager(
|
||||
max_workers=global_config.lpmm_knowledge.max_embedding_workers,
|
||||
chunk_size=global_config.lpmm_knowledge.embedding_chunk_size,
|
||||
)
|
||||
logger.info("正在从文件加载Embedding库")
|
||||
try:
|
||||
embed_manager.load_from_file()
|
||||
except Exception as e:
|
||||
logger.warning(f"此消息不会影响正常使用:从文件加载Embedding库时,{e}")
|
||||
# logger.warning("如果你是第一次导入知识,或者还未导入知识,请忽略此错误")
|
||||
logger.info("Embedding库加载完成")
|
||||
# 初始化KG
|
||||
kg_manager = KGManager()
|
||||
logger.info("正在从文件加载KG")
|
||||
try:
|
||||
kg_manager.load_from_file()
|
||||
except Exception as e:
|
||||
logger.warning(f"此消息不会影响正常使用:从文件加载KG时,{e}")
|
||||
# logger.warning("如果你是第一次导入知识,或者还未导入知识,请忽略此错误")
|
||||
logger.info("KG加载完成")
|
||||
|
||||
logger.info(f"KG节点数量:{len(kg_manager.graph.get_node_list())}")
|
||||
logger.info(f"KG边数量:{len(kg_manager.graph.get_edge_list())}")
|
||||
|
||||
# 数据比对:Embedding库与KG的段落hash集合
|
||||
for pg_hash in kg_manager.stored_paragraph_hashes:
|
||||
# 使用与EmbeddingStore中一致的命名空间格式
|
||||
key = f"paragraph-{pg_hash}"
|
||||
if key not in embed_manager.stored_pg_hashes:
|
||||
logger.warning(f"KG中存在Embedding库中不存在的段落:{key}")
|
||||
global qa_manager
|
||||
# 问答系统(用于知识库)
|
||||
qa_manager = QAManager(
|
||||
embed_manager,
|
||||
kg_manager,
|
||||
)
|
||||
|
||||
# # 记忆激活(用于记忆库)
|
||||
# global inspire_manager
|
||||
# inspire_manager = MemoryActiveManager(
|
||||
# embed_manager,
|
||||
# llm_client_list[global_config["embedding"]["provider"]],
|
||||
# )
|
||||
else:
|
||||
logger.info("LPMM知识库已禁用,跳过初始化")
|
||||
# 创建空的占位符对象,避免导入错误
|
||||
@@ -1,381 +0,0 @@
|
||||
import asyncio
|
||||
import os
|
||||
from functools import partial
|
||||
from typing import List, Callable, Any
|
||||
from src.chat.knowledge.embedding_store import EmbeddingManager
|
||||
from src.chat.knowledge.kg_manager import KGManager
|
||||
from src.chat.knowledge.qa_manager import QAManager
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.knowledge import get_qa_manager, lpmm_start_up
|
||||
|
||||
logger = get_logger("LPMM-Plugin-API")
|
||||
|
||||
|
||||
class LPMMOperations:
|
||||
"""
|
||||
LPMM 内部操作接口。
|
||||
封装了 LPMM 的核心操作,供插件系统 API 或其他内部组件调用。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._initialized = False
|
||||
|
||||
async def _run_cancellable_executor(self, func: Callable, *args, **kwargs) -> Any:
|
||||
"""
|
||||
在线程池中执行可取消的同步操作。
|
||||
当任务被取消时(如 Ctrl+C),会立即响应并抛出 CancelledError。
|
||||
注意:线程池中的操作可能仍在运行,但协程会立即返回,不会阻塞主进程。
|
||||
|
||||
Args:
|
||||
func: 要执行的同步函数
|
||||
*args: 函数的位置参数
|
||||
**kwargs: 函数的关键字参数
|
||||
|
||||
Returns:
|
||||
函数的返回值
|
||||
|
||||
Raises:
|
||||
asyncio.CancelledError: 当任务被取消时
|
||||
"""
|
||||
loop = asyncio.get_event_loop()
|
||||
# 在线程池中执行,当协程被取消时会立即响应
|
||||
# 虽然线程池中的操作可能仍在运行,但协程不会阻塞
|
||||
return await loop.run_in_executor(None, func, *args, **kwargs)
|
||||
|
||||
async def _get_managers(self) -> tuple[EmbeddingManager, KGManager, QAManager]:
|
||||
"""获取并确保 LPMM 管理器已初始化"""
|
||||
qa_mgr = get_qa_manager()
|
||||
if qa_mgr is None:
|
||||
# 如果全局没初始化,尝试初始化
|
||||
if not global_config.lpmm_knowledge.enable:
|
||||
logger.warning("LPMM 知识库在全局配置中未启用,操作可能受限。")
|
||||
|
||||
lpmm_start_up()
|
||||
qa_mgr = get_qa_manager()
|
||||
|
||||
if qa_mgr is None:
|
||||
raise RuntimeError("无法获取 LPMM QAManager,请检查 LPMM 是否已正确安装和配置。")
|
||||
|
||||
return qa_mgr.embed_manager, qa_mgr.kg_manager, qa_mgr
|
||||
|
||||
async def add_content(self, text: str, auto_split: bool = True) -> dict:
|
||||
"""
|
||||
向知识库添加新内容。
|
||||
|
||||
Args:
|
||||
text: 原始文本。
|
||||
auto_split: 是否自动按双换行符分割段落。
|
||||
- True: 自动分割(默认),支持多段文本(用双换行分隔)
|
||||
- False: 不分割,将整个文本作为完整一段处理
|
||||
|
||||
Returns:
|
||||
dict: {"status": "success/error", "count": 导入段落数, "message": "描述"}
|
||||
"""
|
||||
try:
|
||||
embed_mgr, kg_mgr, _ = await self._get_managers()
|
||||
|
||||
# 1. 分段处理
|
||||
if auto_split:
|
||||
# 自动按双换行符分割
|
||||
paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()]
|
||||
else:
|
||||
# 不分割,作为完整一段
|
||||
text_stripped = text.strip()
|
||||
if not text_stripped:
|
||||
return {"status": "error", "message": "文本内容为空"}
|
||||
paragraphs = [text_stripped]
|
||||
|
||||
if not paragraphs:
|
||||
return {"status": "error", "message": "文本内容为空"}
|
||||
|
||||
# 2. 实体与三元组抽取 (内部调用大模型)
|
||||
from src.chat.knowledge.ie_process import IEProcess
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
|
||||
llm_ner = LLMServiceClient(
|
||||
task_name="lpmm_entity_extract", request_type="lpmm.entity_extract"
|
||||
)
|
||||
llm_rdf = LLMServiceClient(
|
||||
task_name="lpmm_rdf_build", request_type="lpmm.rdf_build"
|
||||
)
|
||||
ie_process = IEProcess(llm_ner, llm_rdf)
|
||||
|
||||
logger.info(f"[Plugin API] 正在对 {len(paragraphs)} 段文本执行信息抽取...")
|
||||
extracted_docs = await ie_process.process_paragraphs(paragraphs)
|
||||
|
||||
# 3. 构造并导入数据
|
||||
# 这里我们手动实现导入逻辑,不依赖外部脚本
|
||||
# a. 准备段落
|
||||
raw_paragraphs = {doc["idx"]: doc["passage"] for doc in extracted_docs}
|
||||
# b. 准备三元组
|
||||
triple_list_data = {doc["idx"]: doc["extracted_triples"] for doc in extracted_docs}
|
||||
|
||||
# 向量化并入库
|
||||
# 注意:此处模仿 import_openie.py 的核心逻辑
|
||||
# 1. 先进行去重检查,只处理新段落
|
||||
# store_new_data_set 期望的格式:raw_paragraphs 的键是段落hash(不带前缀),值是段落文本
|
||||
new_raw_paragraphs = {}
|
||||
new_triple_list_data = {}
|
||||
|
||||
for pg_hash, passage in raw_paragraphs.items():
|
||||
key = f"paragraph-{pg_hash}"
|
||||
if key not in embed_mgr.stored_pg_hashes:
|
||||
new_raw_paragraphs[pg_hash] = passage
|
||||
new_triple_list_data[pg_hash] = triple_list_data[pg_hash]
|
||||
|
||||
if not new_raw_paragraphs:
|
||||
return {"status": "success", "count": 0, "message": "内容已存在,无需重复导入"}
|
||||
|
||||
# 2. 使用 EmbeddingManager 的标准方法存储段落、实体和关系的嵌入
|
||||
# store_new_data_set 会自动处理嵌入生成和存储
|
||||
# 将同步阻塞操作放到线程池中执行,避免阻塞事件循环
|
||||
await self._run_cancellable_executor(embed_mgr.store_new_data_set, new_raw_paragraphs, new_triple_list_data)
|
||||
|
||||
# 3. 构建知识图谱(只需要三元组数据和embedding_manager)
|
||||
await self._run_cancellable_executor(kg_mgr.build_kg, new_triple_list_data, embed_mgr)
|
||||
|
||||
# 4. 持久化
|
||||
await self._run_cancellable_executor(embed_mgr.rebuild_faiss_index)
|
||||
await self._run_cancellable_executor(embed_mgr.save_to_file)
|
||||
await self._run_cancellable_executor(kg_mgr.save_to_file)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"count": len(new_raw_paragraphs),
|
||||
"message": f"成功导入 {len(new_raw_paragraphs)} 条知识",
|
||||
}
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.warning("[Plugin API] 导入操作被用户中断")
|
||||
return {"status": "cancelled", "message": "导入操作已被用户中断"}
|
||||
except Exception as e:
|
||||
logger.error(f"[Plugin API] 导入知识失败: {e}", exc_info=True)
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
||||
async def search(self, query: str, top_k: int = 3) -> List[str]:
|
||||
"""
|
||||
检索知识库。
|
||||
|
||||
Args:
|
||||
query: 查询问题。
|
||||
top_k: 返回最相关的条目数。
|
||||
|
||||
Returns:
|
||||
List[str]: 相关文段列表。
|
||||
"""
|
||||
try:
|
||||
_, _, qa_mgr = await self._get_managers()
|
||||
# 直接调用 QAManager 的检索接口
|
||||
knowledge = qa_mgr.get_knowledge(query, top_k=top_k)
|
||||
# 返回通常是拼接好的字符串,这里我们可以尝试按其内部规则切分回列表,或者直接返回
|
||||
return [knowledge] if knowledge else []
|
||||
except Exception as e:
|
||||
logger.error(f"[Plugin API] 检索知识失败: {e}")
|
||||
return []
|
||||
|
||||
async def delete(self, keyword: str, exact_match: bool = False) -> dict:
|
||||
"""
|
||||
根据关键词或完整文段删除知识库内容。
|
||||
|
||||
Args:
|
||||
keyword: 匹配关键词或完整文段。
|
||||
exact_match: 是否使用完整文段匹配(True=完全匹配,False=关键词模糊匹配)。
|
||||
|
||||
Returns:
|
||||
dict: {"status": "success/info", "deleted_count": 删除条数, "message": "描述"}
|
||||
"""
|
||||
try:
|
||||
embed_mgr, kg_mgr, _ = await self._get_managers()
|
||||
|
||||
# 1. 查找匹配的段落
|
||||
to_delete_keys = []
|
||||
to_delete_hashes = []
|
||||
|
||||
for key, item in embed_mgr.paragraphs_embedding_store.store.items():
|
||||
if exact_match:
|
||||
# 完整文段匹配
|
||||
if item.str.strip() == keyword.strip():
|
||||
to_delete_keys.append(key)
|
||||
to_delete_hashes.append(key.replace("paragraph-", "", 1))
|
||||
else:
|
||||
# 关键词模糊匹配
|
||||
if keyword in item.str:
|
||||
to_delete_keys.append(key)
|
||||
to_delete_hashes.append(key.replace("paragraph-", "", 1))
|
||||
|
||||
if not to_delete_keys:
|
||||
match_type = "完整文段" if exact_match else "关键词"
|
||||
return {"status": "info", "deleted_count": 0, "message": f"未找到匹配的内容({match_type}匹配)"}
|
||||
|
||||
# 2. 执行删除
|
||||
# 将同步阻塞操作放到线程池中执行,避免阻塞事件循环
|
||||
|
||||
# a. 从向量库删除
|
||||
deleted_count, _ = await self._run_cancellable_executor(
|
||||
embed_mgr.paragraphs_embedding_store.delete_items, to_delete_keys
|
||||
)
|
||||
embed_mgr.stored_pg_hashes = set(embed_mgr.paragraphs_embedding_store.store.keys())
|
||||
|
||||
# b. 从知识图谱删除
|
||||
# 注意:必须使用关键字参数,避免 True 被误当作 ent_hashes 参数
|
||||
# 使用 partial 来传递关键字参数,因为 run_in_executor 不支持 **kwargs
|
||||
delete_func = partial(
|
||||
kg_mgr.delete_paragraphs, to_delete_hashes, ent_hashes=None, remove_orphan_entities=True
|
||||
)
|
||||
await self._run_cancellable_executor(delete_func)
|
||||
|
||||
# 3. 持久化
|
||||
await self._run_cancellable_executor(embed_mgr.rebuild_faiss_index)
|
||||
await self._run_cancellable_executor(embed_mgr.save_to_file)
|
||||
await self._run_cancellable_executor(kg_mgr.save_to_file)
|
||||
|
||||
match_type = "完整文段" if exact_match else "关键词"
|
||||
return {
|
||||
"status": "success",
|
||||
"deleted_count": deleted_count,
|
||||
"message": f"已成功删除 {deleted_count} 条相关知识({match_type}匹配)",
|
||||
}
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.warning("[Plugin API] 删除操作被用户中断")
|
||||
return {"status": "cancelled", "message": "删除操作已被用户中断"}
|
||||
except Exception as e:
|
||||
logger.error(f"[Plugin API] 删除知识失败: {e}", exc_info=True)
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
||||
async def clear_all(self) -> dict:
|
||||
"""
|
||||
清空整个LPMM知识库(删除所有段落、实体、关系和知识图谱数据)。
|
||||
|
||||
Returns:
|
||||
dict: {"status": "success/error", "message": "描述", "stats": {...}}
|
||||
"""
|
||||
try:
|
||||
embed_mgr, kg_mgr, _ = await self._get_managers()
|
||||
|
||||
# 记录清空前的统计信息
|
||||
before_stats = {
|
||||
"paragraphs": len(embed_mgr.paragraphs_embedding_store.store),
|
||||
"entities": len(embed_mgr.entities_embedding_store.store),
|
||||
"relations": len(embed_mgr.relation_embedding_store.store),
|
||||
"kg_nodes": len(kg_mgr.graph.get_node_list()),
|
||||
"kg_edges": len(kg_mgr.graph.get_edge_list()),
|
||||
}
|
||||
|
||||
# 将同步阻塞操作放到线程池中执行,避免阻塞事件循环
|
||||
|
||||
# 1. 清空所有向量库
|
||||
# 获取所有keys
|
||||
para_keys = list(embed_mgr.paragraphs_embedding_store.store.keys())
|
||||
ent_keys = list(embed_mgr.entities_embedding_store.store.keys())
|
||||
rel_keys = list(embed_mgr.relation_embedding_store.store.keys())
|
||||
|
||||
# 删除所有段落向量
|
||||
para_deleted, _ = await self._run_cancellable_executor(
|
||||
embed_mgr.paragraphs_embedding_store.delete_items, para_keys
|
||||
)
|
||||
embed_mgr.stored_pg_hashes.clear()
|
||||
|
||||
# 删除所有实体向量
|
||||
if ent_keys:
|
||||
ent_deleted, _ = await self._run_cancellable_executor(
|
||||
embed_mgr.entities_embedding_store.delete_items, ent_keys
|
||||
)
|
||||
else:
|
||||
ent_deleted = 0
|
||||
|
||||
# 删除所有关系向量
|
||||
if rel_keys:
|
||||
rel_deleted, _ = await self._run_cancellable_executor(
|
||||
embed_mgr.relation_embedding_store.delete_items, rel_keys
|
||||
)
|
||||
else:
|
||||
rel_deleted = 0
|
||||
|
||||
# 2. 清空所有 embedding store 的索引和映射
|
||||
# 确保 faiss_index 和 idx2hash 也被重置,并删除旧的索引文件
|
||||
def _clear_embedding_indices():
|
||||
# 清空段落索引
|
||||
embed_mgr.paragraphs_embedding_store.faiss_index = None
|
||||
embed_mgr.paragraphs_embedding_store.idx2hash = None
|
||||
embed_mgr.paragraphs_embedding_store.dirty = False
|
||||
# 删除旧的索引文件
|
||||
if os.path.exists(embed_mgr.paragraphs_embedding_store.index_file_path):
|
||||
os.remove(embed_mgr.paragraphs_embedding_store.index_file_path)
|
||||
if os.path.exists(embed_mgr.paragraphs_embedding_store.idx2hash_file_path):
|
||||
os.remove(embed_mgr.paragraphs_embedding_store.idx2hash_file_path)
|
||||
|
||||
# 清空实体索引
|
||||
embed_mgr.entities_embedding_store.faiss_index = None
|
||||
embed_mgr.entities_embedding_store.idx2hash = None
|
||||
embed_mgr.entities_embedding_store.dirty = False
|
||||
# 删除旧的索引文件
|
||||
if os.path.exists(embed_mgr.entities_embedding_store.index_file_path):
|
||||
os.remove(embed_mgr.entities_embedding_store.index_file_path)
|
||||
if os.path.exists(embed_mgr.entities_embedding_store.idx2hash_file_path):
|
||||
os.remove(embed_mgr.entities_embedding_store.idx2hash_file_path)
|
||||
|
||||
# 清空关系索引
|
||||
embed_mgr.relation_embedding_store.faiss_index = None
|
||||
embed_mgr.relation_embedding_store.idx2hash = None
|
||||
embed_mgr.relation_embedding_store.dirty = False
|
||||
# 删除旧的索引文件
|
||||
if os.path.exists(embed_mgr.relation_embedding_store.index_file_path):
|
||||
os.remove(embed_mgr.relation_embedding_store.index_file_path)
|
||||
if os.path.exists(embed_mgr.relation_embedding_store.idx2hash_file_path):
|
||||
os.remove(embed_mgr.relation_embedding_store.idx2hash_file_path)
|
||||
|
||||
await self._run_cancellable_executor(_clear_embedding_indices)
|
||||
|
||||
# 3. 清空知识图谱
|
||||
# 获取所有段落hash
|
||||
all_pg_hashes = list(kg_mgr.stored_paragraph_hashes)
|
||||
if all_pg_hashes:
|
||||
# 删除所有段落节点(这会自动清理相关的边和孤立实体)
|
||||
# 注意:必须使用关键字参数,避免 True 被误当作 ent_hashes 参数
|
||||
# 使用 partial 来传递关键字参数,因为 run_in_executor 不支持 **kwargs
|
||||
delete_func = partial(
|
||||
kg_mgr.delete_paragraphs, all_pg_hashes, ent_hashes=None, remove_orphan_entities=True
|
||||
)
|
||||
await self._run_cancellable_executor(delete_func)
|
||||
|
||||
# 完全清空KG:创建新的空图(无论是否有段落hash都要执行)
|
||||
from quick_algo import di_graph
|
||||
|
||||
kg_mgr.graph = di_graph.DiGraph()
|
||||
kg_mgr.stored_paragraph_hashes.clear()
|
||||
kg_mgr.ent_appear_cnt.clear()
|
||||
|
||||
# 4. 保存所有数据(此时所有store都是空的,索引也是None)
|
||||
# 注意:即使store为空,save_to_file也会保存空的DataFrame,这是正确的
|
||||
await self._run_cancellable_executor(embed_mgr.save_to_file)
|
||||
await self._run_cancellable_executor(kg_mgr.save_to_file)
|
||||
|
||||
after_stats = {
|
||||
"paragraphs": len(embed_mgr.paragraphs_embedding_store.store),
|
||||
"entities": len(embed_mgr.entities_embedding_store.store),
|
||||
"relations": len(embed_mgr.relation_embedding_store.store),
|
||||
"kg_nodes": len(kg_mgr.graph.get_node_list()),
|
||||
"kg_edges": len(kg_mgr.graph.get_edge_list()),
|
||||
}
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"message": f"已成功清空LPMM知识库(删除 {para_deleted} 个段落、{ent_deleted} 个实体、{rel_deleted} 个关系)",
|
||||
"stats": {
|
||||
"before": before_stats,
|
||||
"after": after_stats,
|
||||
},
|
||||
}
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.warning("[Plugin API] 清空操作被用户中断")
|
||||
return {"status": "cancelled", "message": "清空操作已被用户中断"}
|
||||
except Exception as e:
|
||||
logger.error(f"[Plugin API] 清空知识库失败: {e}", exc_info=True)
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
||||
|
||||
# 内部使用的单例
|
||||
lpmm_ops = LPMMOperations()
|
||||
@@ -613,6 +613,13 @@ class ChatBot:
|
||||
scope=scope,
|
||||
) # 确保会话存在
|
||||
|
||||
try:
|
||||
from src.services.memory_flow_service import memory_automation_service
|
||||
|
||||
await memory_automation_service.on_incoming_message(message)
|
||||
except Exception as exc:
|
||||
logger.warning(f"[{session_id}] 长期记忆自动摘要注册失败: {exc}")
|
||||
|
||||
# message.update_chat_stream(chat)
|
||||
|
||||
# 命令处理 - 使用新插件系统检查并处理命令。
|
||||
|
||||
@@ -347,6 +347,13 @@ class UniversalMessageSender:
|
||||
with get_db_session() as db_session:
|
||||
db_session.add(message.to_db_instance())
|
||||
|
||||
try:
|
||||
from src.services.memory_flow_service import memory_automation_service
|
||||
|
||||
await memory_automation_service.on_message_sent(message)
|
||||
except Exception as exc:
|
||||
logger.warning(f"[{chat_id}] 长期记忆人物事实写回注册失败: {exc}")
|
||||
|
||||
return sent_msg
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -189,39 +189,37 @@ def find_messages(
|
||||
conditions.append(Messages.is_command == False) # noqa: E712
|
||||
|
||||
statement = select(Messages).where(*conditions)
|
||||
if limit > 0:
|
||||
if limit_mode == "earliest":
|
||||
statement = statement.order_by(col(Messages.timestamp)).limit(limit)
|
||||
with get_db_session() as session:
|
||||
with get_db_session(auto_commit=False) as session:
|
||||
if limit > 0:
|
||||
if limit_mode == "earliest":
|
||||
statement = statement.order_by(col(Messages.timestamp)).limit(limit)
|
||||
results = list(session.exec(statement).all())
|
||||
else:
|
||||
statement = statement.order_by(col(Messages.timestamp).desc()).limit(limit)
|
||||
results = list(session.exec(statement).all())
|
||||
results = list(reversed(results))
|
||||
else:
|
||||
statement = statement.order_by(col(Messages.timestamp).desc()).limit(limit)
|
||||
with get_db_session() as session:
|
||||
results = list(session.exec(statement).all())
|
||||
results = list(reversed(results))
|
||||
else:
|
||||
if sort:
|
||||
order_terms: list[Any] = []
|
||||
for field_name, direction in sort:
|
||||
sort_field = _resolve_field(field_name)
|
||||
if sort_field is None:
|
||||
logger.warning(f"排序字段 '{field_name}' 在 Messages 模型中未找到。将跳过此排序条件。")
|
||||
continue
|
||||
order_terms.append(sort_field.asc() if direction == 1 else sort_field.desc())
|
||||
if order_terms:
|
||||
statement = statement.order_by(*order_terms)
|
||||
with get_db_session() as session:
|
||||
if sort:
|
||||
order_terms: list[Any] = []
|
||||
for field_name, direction in sort:
|
||||
sort_field = _resolve_field(field_name)
|
||||
if sort_field is None:
|
||||
logger.warning(f"排序字段 '{field_name}' 在 Messages 模型中未找到。将跳过此排序条件。")
|
||||
continue
|
||||
order_terms.append(sort_field.asc() if direction == 1 else sort_field.desc())
|
||||
if order_terms:
|
||||
statement = statement.order_by(*order_terms)
|
||||
results = list(session.exec(statement).all())
|
||||
|
||||
if filter_intercept_message_level is not None:
|
||||
filtered_results = []
|
||||
for msg in results:
|
||||
config = _parse_additional_config(msg)
|
||||
if config.get("intercept_message_level", 0) <= filter_intercept_message_level:
|
||||
filtered_results.append(msg)
|
||||
results = filtered_results
|
||||
if filter_intercept_message_level is not None:
|
||||
filtered_results = []
|
||||
for msg in results:
|
||||
config = _parse_additional_config(msg)
|
||||
if config.get("intercept_message_level", 0) <= filter_intercept_message_level:
|
||||
filtered_results.append(msg)
|
||||
results = filtered_results
|
||||
|
||||
return [_message_to_instance(msg) for msg in results]
|
||||
return [_message_to_instance(msg) for msg in results]
|
||||
except Exception as e:
|
||||
log_message = (
|
||||
"使用 SQLModel 查找消息失败 "
|
||||
|
||||
@@ -94,6 +94,11 @@ def _migrate_expression_learning_list(expr: dict[str, Any]) -> bool:
|
||||
["", "enable", "enable", "enable"],
|
||||
["qq:1919810:group", "enable", "enable", "enable"],
|
||||
]
|
||||
兼容旧旧格式:
|
||||
learning_list = [
|
||||
["qq:1919810:group", "enable", "enable", "0.5"],
|
||||
["", "disable", "disable", "0.1"],
|
||||
]
|
||||
新:
|
||||
[[expression.learning_list]]
|
||||
platform="", item_id="", rule_type="group", use_expression=true, enable_learning=true, enable_jargon_learning=true
|
||||
@@ -117,6 +122,16 @@ def _migrate_expression_learning_list(expr: dict[str, Any]) -> bool:
|
||||
use_expression = _parse_enable_disable(r[1])
|
||||
enable_learning = _parse_enable_disable(r[2])
|
||||
enable_jargon_learning = _parse_enable_disable(r[3])
|
||||
if enable_jargon_learning is None:
|
||||
# 更早期的配置在第 4 列记录的是一个已废弃的数值权重/阈值,
|
||||
# 当前 schema 已没有对应字段。这里按保守策略兼容迁移:
|
||||
# 丢弃旧数值,并将 enable_jargon_learning 置为 False。
|
||||
try:
|
||||
float(str(r[3]))
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
else:
|
||||
enable_jargon_learning = False
|
||||
if use_expression is None or enable_learning is None or enable_jargon_learning is None:
|
||||
return False
|
||||
|
||||
|
||||
@@ -403,6 +403,60 @@ class MemoryConfig(ConfigBase):
|
||||
|
||||
__ui_parent__ = "emoji"
|
||||
|
||||
max_agent_iterations: int = Field(
|
||||
default=5,
|
||||
ge=1,
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "layers",
|
||||
},
|
||||
)
|
||||
"""记忆思考深度(最低为1)"""
|
||||
|
||||
agent_timeout_seconds: float = Field(
|
||||
default=120.0,
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "clock",
|
||||
},
|
||||
)
|
||||
"""最长回忆时间(秒)"""
|
||||
|
||||
global_memory: bool = Field(
|
||||
default=False,
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-icon": "globe",
|
||||
},
|
||||
)
|
||||
"""是否允许记忆检索在聊天记录中进行全局查询(忽略当前chat_id,仅对 search_chat_history 等工具生效)"""
|
||||
|
||||
global_memory_blacklist: list[TargetItem] = Field(
|
||||
default_factory=lambda: [],
|
||||
json_schema_extra={
|
||||
"x-widget": "custom",
|
||||
"x-icon": "shield-off",
|
||||
},
|
||||
)
|
||||
"""_wrap_全局记忆黑名单,当启用全局记忆时,不将特定聊天流纳入检索"""
|
||||
|
||||
long_term_auto_summary_enabled: bool = Field(
|
||||
default=True,
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-icon": "book-open",
|
||||
},
|
||||
)
|
||||
"""是否自动启动聊天总结并导入长期记忆"""
|
||||
|
||||
person_fact_writeback_enabled: bool = Field(
|
||||
default=True,
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-icon": "user-round-pen",
|
||||
},
|
||||
)
|
||||
"""是否在发送回复后自动提取并写回人物事实到长期记忆"""
|
||||
chat_history_topic_check_message_threshold: int = Field(
|
||||
default=80,
|
||||
ge=1,
|
||||
|
||||
@@ -8,7 +8,7 @@ from src.common.data_models.llm_service_data_models import LLMGenerationOptions
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
from src.config.config import global_config
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.learners.jargon_miner_old import search_jargon
|
||||
from src.learners.jargon_explainer import search_jargon
|
||||
from src.learners.learner_utils_old import (
|
||||
is_bot_message,
|
||||
contains_bot_self_name,
|
||||
|
||||
@@ -196,6 +196,32 @@ def contains_bot_self_name(content: str) -> bool:
|
||||
return any(name in target for name in candidates)
|
||||
|
||||
|
||||
def is_bot_message(msg: Any) -> bool:
|
||||
"""判断消息是否来自机器人自身。"""
|
||||
if msg is None:
|
||||
return False
|
||||
|
||||
bot_config = getattr(global_config, "bot", None)
|
||||
if not bot_config:
|
||||
return False
|
||||
|
||||
user_id = str(getattr(msg, "user_id", "") or getattr(getattr(msg, "user_info", None), "user_id", "") or "").strip()
|
||||
if not user_id:
|
||||
return False
|
||||
|
||||
known_accounts = {
|
||||
str(getattr(bot_config, "qq_account", "") or "").strip(),
|
||||
str(getattr(bot_config, "telegram_account", "") or "").strip(),
|
||||
}
|
||||
|
||||
for platform in getattr(bot_config, "platforms", []) or []:
|
||||
account = str(getattr(platform, "account", "") or getattr(platform, "id", "") or "").strip()
|
||||
if account:
|
||||
known_accounts.add(account)
|
||||
|
||||
return user_id in {account for account in known_accounts if account}
|
||||
|
||||
|
||||
# def build_context_paragraph(messages: List[Any], center_index: int) -> Optional[str]:
|
||||
# """
|
||||
# 构建包含中心消息上下文的段落(前3条+后3条),使用标准的 readable builder 输出
|
||||
|
||||
14
src/main.py
14
src/main.py
@@ -5,9 +5,9 @@ from typing import TYPE_CHECKING
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
from src.A_memorix.host_service import a_memorix_host_service
|
||||
from src.learners.expression_auto_check_task import ExpressionAutoCheckTask
|
||||
from src.chat.emoji_system.emoji_manager import emoji_manager
|
||||
from src.chat.knowledge import lpmm_start_up
|
||||
from src.chat.message_receive.bot import chat_bot
|
||||
from src.chat.message_receive.chat_manager import chat_manager
|
||||
from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask
|
||||
@@ -20,6 +20,7 @@ from src.config.config import config_manager, global_config
|
||||
from src.manager.async_task_manager import async_task_manager
|
||||
from src.plugin_runtime.integration import get_plugin_runtime_manager
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.services.memory_flow_service import memory_automation_service
|
||||
|
||||
# from src.api.main import start_api_server
|
||||
|
||||
@@ -93,11 +94,9 @@ class MainSystem:
|
||||
# start_api_server()
|
||||
# logger.info("API服务器启动成功")
|
||||
|
||||
# 启动LPMM
|
||||
lpmm_start_up()
|
||||
|
||||
# 启动插件运行时(内置插件 + 第三方插件双子进程)
|
||||
await get_plugin_runtime_manager().start()
|
||||
await a_memorix_host_service.start()
|
||||
|
||||
# 初始化表情管理器
|
||||
emoji_manager.load_emojis_from_db()
|
||||
@@ -108,6 +107,7 @@ class MainSystem:
|
||||
asyncio.create_task(chat_manager.regularly_save_sessions())
|
||||
|
||||
logger.info(t("startup.chat_manager_initialized"))
|
||||
await memory_automation_service.start()
|
||||
|
||||
# await asyncio.sleep(0.5) #防止logger输出飞了
|
||||
|
||||
@@ -169,6 +169,12 @@ async def main() -> None:
|
||||
system.schedule_tasks(),
|
||||
)
|
||||
finally:
|
||||
emoji_manager.shutdown()
|
||||
await memory_automation_service.shutdown()
|
||||
await a_memorix_host_service.stop()
|
||||
await get_plugin_runtime_manager().bridge_event("on_stop")
|
||||
await get_plugin_runtime_manager().stop()
|
||||
await async_task_manager.stop_and_wait_all_tasks()
|
||||
emoji_manager.shutdown()
|
||||
await config_manager.stop_file_watcher()
|
||||
|
||||
|
||||
1123
src/memory_system/chat_history_summarizer.py
Normal file
1123
src/memory_system/chat_history_summarizer.py
Normal file
File diff suppressed because it is too large
Load Diff
1046
src/memory_system/memory_retrieval.py
Normal file
1046
src/memory_system/memory_retrieval.py
Normal file
File diff suppressed because it is too large
Load Diff
32
src/memory_system/retrieval_tools/__init__.py
Normal file
32
src/memory_system/retrieval_tools/__init__.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""
|
||||
记忆检索工具模块
|
||||
提供统一的工具注册和管理系统
|
||||
"""
|
||||
|
||||
from .tool_registry import (
|
||||
MemoryRetrievalTool,
|
||||
MemoryRetrievalToolRegistry,
|
||||
register_memory_retrieval_tool,
|
||||
get_tool_registry,
|
||||
)
|
||||
|
||||
|
||||
def init_all_tools():
|
||||
"""初始化并注册所有记忆检索工具"""
|
||||
# 延迟导入,避免在仅使用部分工具或单元测试阶段触发不必要的依赖链。
|
||||
from .query_long_term_memory import register_tool as register_long_term_memory
|
||||
from .query_words import register_tool as register_query_words
|
||||
from .return_information import register_tool as register_return_information
|
||||
|
||||
register_query_words()
|
||||
register_return_information()
|
||||
register_long_term_memory()
|
||||
|
||||
|
||||
__all__ = [
|
||||
"MemoryRetrievalTool",
|
||||
"MemoryRetrievalToolRegistry",
|
||||
"register_memory_retrieval_tool",
|
||||
"get_tool_registry",
|
||||
"init_all_tools",
|
||||
]
|
||||
307
src/memory_system/retrieval_tools/query_long_term_memory.py
Normal file
307
src/memory_system/retrieval_tools/query_long_term_memory.py
Normal file
@@ -0,0 +1,307 @@
|
||||
"""通过统一长期记忆服务查询信息。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from calendar import monthrange
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Iterable, Literal, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.services.memory_service import MemoryHit, MemorySearchResult, memory_service
|
||||
|
||||
from .tool_registry import register_memory_retrieval_tool
|
||||
|
||||
logger = get_logger("memory_retrieval_tools")
|
||||
|
||||
_SUPPORTED_MODES = {"search", "time", "episode", "aggregate"}
|
||||
_RELATIVE_DAYS_RE = re.compile(r"^最近\s*(\d+)\s*天$")
|
||||
_DATE_RE = re.compile(r"^\d{4}/\d{2}/\d{2}$")
|
||||
_MINUTE_RE = re.compile(r"^\d{4}/\d{2}/\d{2}\s+\d{2}:\d{2}$")
|
||||
_TIME_EXPRESSION_HELP = (
|
||||
"请改用更具体的时间表达,例如:今天、昨天、前天、本周、上周、本月、上月、最近7天、"
|
||||
"2026/03/18、2026/03/18 09:30。"
|
||||
)
|
||||
|
||||
|
||||
def _format_query_datetime(dt: datetime) -> str:
|
||||
return dt.strftime("%Y/%m/%d %H:%M")
|
||||
|
||||
|
||||
def _resolve_time_expression(
|
||||
expression: str,
|
||||
*,
|
||||
now: datetime | None = None,
|
||||
) -> Tuple[float, float, str, str]:
|
||||
clean = str(expression or "").strip()
|
||||
if not clean:
|
||||
raise ValueError(f"time 模式需要提供 time_expression。{_TIME_EXPRESSION_HELP}")
|
||||
|
||||
current = now or datetime.now()
|
||||
day_start = current.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
|
||||
if clean == "今天":
|
||||
start = day_start
|
||||
end = day_start.replace(hour=23, minute=59)
|
||||
elif clean == "昨天":
|
||||
start = day_start - timedelta(days=1)
|
||||
end = start.replace(hour=23, minute=59)
|
||||
elif clean == "前天":
|
||||
start = day_start - timedelta(days=2)
|
||||
end = start.replace(hour=23, minute=59)
|
||||
elif clean == "本周":
|
||||
start = day_start - timedelta(days=day_start.weekday())
|
||||
end = start + timedelta(days=6, hours=23, minutes=59)
|
||||
elif clean == "上周":
|
||||
this_week_start = day_start - timedelta(days=day_start.weekday())
|
||||
start = this_week_start - timedelta(days=7)
|
||||
end = start + timedelta(days=6, hours=23, minutes=59)
|
||||
elif clean == "本月":
|
||||
start = day_start.replace(day=1)
|
||||
last_day = monthrange(start.year, start.month)[1]
|
||||
end = start.replace(day=last_day, hour=23, minute=59)
|
||||
elif clean == "上月":
|
||||
year = day_start.year
|
||||
month = day_start.month - 1
|
||||
if month == 0:
|
||||
year -= 1
|
||||
month = 12
|
||||
start = day_start.replace(year=year, month=month, day=1)
|
||||
last_day = monthrange(year, month)[1]
|
||||
end = start.replace(day=last_day, hour=23, minute=59)
|
||||
else:
|
||||
relative_match = _RELATIVE_DAYS_RE.fullmatch(clean)
|
||||
if relative_match:
|
||||
days = max(1, int(relative_match.group(1)))
|
||||
start = day_start - timedelta(days=max(0, days - 1))
|
||||
end = day_start.replace(hour=23, minute=59)
|
||||
elif _DATE_RE.fullmatch(clean):
|
||||
start = datetime.strptime(clean, "%Y/%m/%d")
|
||||
end = start.replace(hour=23, minute=59)
|
||||
elif _MINUTE_RE.fullmatch(clean):
|
||||
start = datetime.strptime(clean, "%Y/%m/%d %H:%M")
|
||||
end = start
|
||||
else:
|
||||
raise ValueError(f"时间表达“{clean}”无法解析。{_TIME_EXPRESSION_HELP}")
|
||||
|
||||
return start.timestamp(), end.timestamp(), _format_query_datetime(start), _format_query_datetime(end)
|
||||
|
||||
|
||||
def _extract_time_label(metadata: dict) -> str:
|
||||
if not isinstance(metadata, dict):
|
||||
return ""
|
||||
start = metadata.get("event_time_start")
|
||||
end = metadata.get("event_time_end")
|
||||
event_time = metadata.get("event_time")
|
||||
|
||||
def _fmt(value: object) -> str:
|
||||
if value in {None, ""}:
|
||||
return ""
|
||||
try:
|
||||
return datetime.fromtimestamp(float(value)).strftime("%Y/%m/%d %H:%M")
|
||||
except Exception:
|
||||
return str(value)
|
||||
|
||||
start_text = _fmt(start or event_time)
|
||||
end_text = _fmt(end)
|
||||
if start_text and end_text:
|
||||
return f"{start_text} - {end_text}"
|
||||
return start_text or end_text
|
||||
|
||||
|
||||
def _truncate(text: str, limit: int = 160) -> str:
|
||||
compact = str(text or "").strip().replace("\n", " ")
|
||||
if len(compact) <= limit:
|
||||
return compact
|
||||
return compact[:limit] + "..."
|
||||
|
||||
|
||||
def _format_search_lines(hits: Iterable[MemoryHit], *, limit: int, include_time: bool = False) -> str:
|
||||
lines = []
|
||||
for index, item in enumerate(list(hits)[: max(1, int(limit))], start=1):
|
||||
time_label = _extract_time_label(item.metadata) if include_time else ""
|
||||
prefix = f"[{time_label}] " if time_label else ""
|
||||
lines.append(f"{index}. {prefix}{_truncate(item.content)}")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _format_episode_lines(hits: Iterable[MemoryHit], *, limit: int) -> str:
|
||||
lines = []
|
||||
for index, item in enumerate(list(hits)[: max(1, int(limit))], start=1):
|
||||
metadata = item.metadata if isinstance(item.metadata, dict) else {}
|
||||
title = str(item.title or "").strip() or "未命名事件"
|
||||
summary = _truncate(item.content, limit=180)
|
||||
participants = [str(x).strip() for x in (metadata.get("participants") or []) if str(x).strip()]
|
||||
keywords = [str(x).strip() for x in (metadata.get("keywords") or []) if str(x).strip()]
|
||||
extras = []
|
||||
if participants:
|
||||
extras.append(f"参与者:{'、'.join(participants[:4])}")
|
||||
if keywords:
|
||||
extras.append(f"关键词:{'、'.join(keywords[:6])}")
|
||||
time_label = _extract_time_label(metadata)
|
||||
if time_label:
|
||||
extras.append(f"时间:{time_label}")
|
||||
suffix = f"({';'.join(extras)})" if extras else ""
|
||||
lines.append(f"{index}. 事件《{title}》:{summary}{suffix}")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _format_aggregate_lines(hits: Iterable[MemoryHit], *, limit: int) -> str:
|
||||
lines = []
|
||||
for index, item in enumerate(list(hits)[: max(1, int(limit))], start=1):
|
||||
metadata = item.metadata if isinstance(item.metadata, dict) else {}
|
||||
source_branches = [str(x).strip() for x in (metadata.get("source_branches") or []) if str(x).strip()]
|
||||
branch_text = f"[{','.join(source_branches)}]" if source_branches else ""
|
||||
item_type = str(item.hit_type or "").strip().lower() or "memory"
|
||||
if item_type == "episode":
|
||||
title = str(item.title or "").strip() or "未命名事件"
|
||||
lines.append(f"{index}. {branch_text}[episode] 《{title}》:{_truncate(item.content, 160)}")
|
||||
else:
|
||||
lines.append(f"{index}. {branch_text}[{item_type}] {_truncate(item.content, 160)}")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _format_tool_result(
|
||||
*,
|
||||
result: MemorySearchResult,
|
||||
mode: Literal["search", "time", "episode", "aggregate"],
|
||||
limit: int,
|
||||
query: str,
|
||||
time_range_text: str = "",
|
||||
) -> str:
|
||||
if not result.success:
|
||||
return f"长期记忆查询失败:{result.error or '未知错误'}"
|
||||
|
||||
if not result.hits:
|
||||
if mode == "time":
|
||||
return f"在指定时间范围内未找到相关的长期记忆{time_range_text}"
|
||||
if mode == "episode":
|
||||
return f"未找到与“{query}”相关的事件或情节记忆"
|
||||
if mode == "aggregate":
|
||||
return f"未找到可用于综合回忆的长期记忆线索{f'(query:{query})' if query else ''}"
|
||||
return f"在长期记忆中未找到与“{query}”相关的信息"
|
||||
|
||||
if mode == "episode":
|
||||
text = _format_episode_lines(result.hits, limit=limit)
|
||||
return f"你从长期记忆的事件/情节中找到以下信息:\n{text}"
|
||||
|
||||
if mode == "aggregate":
|
||||
text = _format_aggregate_lines(result.hits, limit=limit)
|
||||
return f"你从长期记忆中综合找到了以下线索:\n{text}"
|
||||
|
||||
if mode == "time":
|
||||
text = _format_search_lines(result.hits, limit=limit, include_time=True)
|
||||
return f"你从指定时间范围内的长期记忆中找到以下信息{time_range_text}:\n{text}"
|
||||
|
||||
text = _format_search_lines(result.hits, limit=limit)
|
||||
return f"你从长期记忆中找到以下信息:\n{text}"
|
||||
|
||||
|
||||
async def query_long_term_memory(
|
||||
query: str = "",
|
||||
limit: int = 5,
|
||||
chat_id: str = "",
|
||||
person_id: str = "",
|
||||
mode: str = "search",
|
||||
time_expression: str = "",
|
||||
) -> str:
|
||||
content = str(query or "").strip()
|
||||
safe_limit = max(1, int(limit or 5))
|
||||
normalized_mode = str(mode or "search").strip().lower() or "search"
|
||||
if normalized_mode not in _SUPPORTED_MODES:
|
||||
return f"不支持的长期记忆检索模式:{normalized_mode}。可用模式:search、time、episode、aggregate。"
|
||||
|
||||
if normalized_mode == "search" and not content:
|
||||
return "查询关键词为空,请提供你想查找的长期记忆内容。"
|
||||
if normalized_mode == "time" and not str(time_expression or "").strip():
|
||||
return f"time 模式需要提供 time_expression。{_TIME_EXPRESSION_HELP}"
|
||||
if normalized_mode in {"episode", "aggregate"} and not content and not str(time_expression or "").strip():
|
||||
return f"{normalized_mode} 模式至少需要提供 query 或 time_expression。"
|
||||
|
||||
time_start = None
|
||||
time_end = None
|
||||
time_range_text = ""
|
||||
if str(time_expression or "").strip():
|
||||
try:
|
||||
time_start, time_end, time_start_text, time_end_text = _resolve_time_expression(time_expression)
|
||||
except ValueError as exc:
|
||||
return str(exc)
|
||||
time_range_text = f"(时间范围:{time_start_text} 至 {time_end_text})"
|
||||
|
||||
backend_mode = normalized_mode
|
||||
|
||||
try:
|
||||
result = await memory_service.search(
|
||||
content,
|
||||
limit=safe_limit,
|
||||
mode=backend_mode,
|
||||
chat_id=str(chat_id or "").strip(),
|
||||
person_id=str(person_id or "").strip(),
|
||||
time_start=time_start,
|
||||
time_end=time_end,
|
||||
)
|
||||
text = _format_tool_result(
|
||||
result=result,
|
||||
mode=normalized_mode, # type: ignore[arg-type]
|
||||
limit=safe_limit,
|
||||
query=content,
|
||||
time_range_text=time_range_text,
|
||||
)
|
||||
logger.debug(f"长期记忆查询结果({normalized_mode}): {text}")
|
||||
return text
|
||||
except Exception as exc:
|
||||
logger.error(f"长期记忆查询失败: {exc}")
|
||||
return f"长期记忆查询失败:{exc}"
|
||||
|
||||
|
||||
def register_tool():
|
||||
register_memory_retrieval_tool(
|
||||
name="search_long_term_memory",
|
||||
description=(
|
||||
"从长期记忆中检索信息。支持 search(普通事实检索)、time(按时间范围检索)、"
|
||||
"episode(按事件/情节检索)、aggregate(综合检索)四种模式。"
|
||||
),
|
||||
parameters=[
|
||||
{
|
||||
"name": "query",
|
||||
"type": "string",
|
||||
"description": "需要查询的问题。search 模式建议用自然语言问句;time/episode/aggregate 模式也可用关键词短语。",
|
||||
"required": False,
|
||||
},
|
||||
{
|
||||
"name": "mode",
|
||||
"type": "string",
|
||||
"description": "检索模式:search(普通长期记忆)、time(按时间窗口)、episode(事件/情节)、aggregate(综合检索)。",
|
||||
"required": False,
|
||||
"enum": ["search", "time", "episode", "aggregate"],
|
||||
},
|
||||
{
|
||||
"name": "limit",
|
||||
"type": "integer",
|
||||
"description": "希望返回的相关知识条数,默认为5",
|
||||
"required": False,
|
||||
},
|
||||
{
|
||||
"name": "chat_id",
|
||||
"type": "string",
|
||||
"description": "当前聊天流ID,可选。提供后优先检索当前聊天上下文相关的长期记忆。",
|
||||
"required": False,
|
||||
},
|
||||
{
|
||||
"name": "person_id",
|
||||
"type": "string",
|
||||
"description": "相关人物ID,可选。提供后优先检索该人物相关的长期记忆。",
|
||||
"required": False,
|
||||
},
|
||||
{
|
||||
"name": "time_expression",
|
||||
"type": "string",
|
||||
"description": (
|
||||
"时间表达,可选。time 模式必填;episode/aggregate 模式可选。支持:今天、昨天、前天、本周、上周、本月、上月、"
|
||||
"最近N天,以及 YYYY/MM/DD、YYYY/MM/DD HH:mm。"
|
||||
),
|
||||
"required": False,
|
||||
},
|
||||
],
|
||||
execute_func=query_long_term_memory,
|
||||
)
|
||||
78
src/memory_system/retrieval_tools/query_words.py
Normal file
78
src/memory_system/retrieval_tools/query_words.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""
|
||||
查询黑话/概念含义 - 工具实现
|
||||
用于在记忆检索过程中主动查询未知词语或黑话的含义
|
||||
"""
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.learners.jargon_explainer_old import retrieve_concepts_with_jargon
|
||||
from .tool_registry import register_memory_retrieval_tool
|
||||
|
||||
logger = get_logger("memory_retrieval_tools")
|
||||
|
||||
|
||||
async def query_words(chat_id: str, words: str) -> str:
|
||||
"""查询词语或黑话的含义
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
words: 要查询的词语,可以是单个词语或多个词语(用逗号、空格等分隔)
|
||||
|
||||
Returns:
|
||||
str: 查询结果,包含词语的含义解释
|
||||
"""
|
||||
try:
|
||||
if not words or not words.strip():
|
||||
return "未提供要查询的词语"
|
||||
|
||||
# 解析词语列表(支持逗号、空格等分隔符)
|
||||
words_list = []
|
||||
for separator in [",", ",", " ", "\n", "\t"]:
|
||||
if separator in words:
|
||||
words_list = [w.strip() for w in words.split(separator) if w.strip()]
|
||||
break
|
||||
|
||||
# 如果没有找到分隔符,整个字符串作为一个词语
|
||||
if not words_list:
|
||||
words_list = [words.strip()]
|
||||
|
||||
# 去重
|
||||
unique_words = []
|
||||
seen = set()
|
||||
for word in words_list:
|
||||
if word and word not in seen:
|
||||
unique_words.append(word)
|
||||
seen.add(word)
|
||||
|
||||
if not unique_words:
|
||||
return "未提供有效的词语"
|
||||
|
||||
logger.info(f"查询词语含义: {unique_words}")
|
||||
|
||||
# 调用检索函数
|
||||
result = await retrieve_concepts_with_jargon(unique_words, chat_id)
|
||||
|
||||
if result:
|
||||
return result
|
||||
else:
|
||||
return f"未找到词语 '{', '.join(unique_words)}' 的含义或黑话解释"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"查询词语含义失败: {e}")
|
||||
return f"查询失败: {str(e)}"
|
||||
|
||||
|
||||
def register_tool():
|
||||
"""注册工具"""
|
||||
register_memory_retrieval_tool(
|
||||
name="query_words",
|
||||
description="查询词语或黑话的含义。当遇到不熟悉的词语、缩写、黑话或网络用语时,可以使用此工具查询其含义。支持查询单个或多个词语(用逗号、空格等分隔)。",
|
||||
parameters=[
|
||||
{
|
||||
"name": "words",
|
||||
"type": "string",
|
||||
"description": "要查询的词语,可以是单个词语或多个词语(用逗号、空格等分隔,如:'YYDS' 或 'YYDS,内卷,996')",
|
||||
"required": True,
|
||||
},
|
||||
],
|
||||
execute_func=query_words,
|
||||
)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user