feat: improve the rating web page; the status dashboard can now be closed
@@ -1,5 +1,120 @@

# Changelog

## [1.0.0-pre.1] - 2026-4-19

### Core feature updates

### MaiSaka system

- Native support for multimodal models
- Native tool calling, with multi-turn calls and MCP
- Upgraded replyer, also with multimodal support
- Unified reply pipeline for group chats and private chats

### Memory system overhaul

- Introduced the A_Memorix long-term memory system, replacing the old memory pipeline
- Supports memory retrieval, write-back, migration, feedback correction and a management UI

### Brand-new plugin system

- Standalone plugin development SDK
- Plugin system rebuilt as plugin_runtime, providing RPC, hooks, capability registration, runtime isolation, config validation, batch reload and migration of legacy capabilities

### Sweeping refactors and fixes

- Added the platform_io message-platform abstraction and a message middleware layer, unifying message routing, outbound tracking and legacy-driver compatibility.
- Added a unified services layer that centralizes LLM, generator, sending, database, memory, embedding and HTML rendering capabilities.
- Introduced MCP and a unified tool system: plugin tools and MCP tools are scheduled together, with improved tool display, indexing, retries and failure archiving.
- The WebUI backend was modularized, adding unified WebSocket, plugin management, memory management, knowledge base, configuration and monitoring APIs.
- Upgraded the configuration system: automatic migration of old configs, type-safe field validation, multimodal model configuration and finer-grained tool/reply parameters.
- Improved the emoji, image, expression and jargon learning systems, with more stable recognition, caching, sending, learning and invocation.
- Removed the old plugin system, old memory system, old reply pipeline, old tool system, old WebUI build artifacts and several deprecated built-in plugins.

!! The WebUI is temporarily unavailable in this pre-release.

Full update list

Core architecture

- Large-scale refactor of the core runtime structure, adding the src/services service layer with LLM, generator, sending, message, database, memory, HTML rendering and embedding services.
- Added the unified platform_io message-platform abstraction, providing drivers, routing, deduplication, outbound tracking, plugin drivers and legacy-driver compatibility.
- Introduced a new message middleware layer and gateway design, establishing a unified basis for message flow between plugins, adapters and the main program.
- Refactored the data models, adding models for chat targets, planned actions, reply generation results, LLM service requests and more.
- Added a database migration manager with migration progress records, table-level/record-level tracking and legacy data compatibility.
- Unified bot identification logic across multiple platforms, including the WebUI.

MaiSaka / reply system

- Added and continue to refine the maisaka main reply pipeline, which progressively takes over group-chat and private-chat replies.
- Added planner / replyer / timing / subagent runtime structures, with wait interruption, debouncing, retries and status monitoring.
- Added Maisaka real-time chat stream monitoring, a stage status panel, console tool-call display and prompt log HTML preview.
- The replyer behaves uniformly for multimodal and non-multimodal models; a new model `visual` parameter prevents images from being sent to non-multimodal models by mistake.
- Supports complex messages, forwarded messages, raw image data parsing, URL image browsing and emoji-style message tagging.
- Improved context compression, with real-time context usage display and compression of early assistant messages.
- Added chat-specific extra prompts, multilingual prompts, standalone prompt file management, and user-defined prompts with overrides.
- Added tool index expansion, compressed tool descriptions, higher tool-call success rates, and fixes for parameterless tools, orphan tools, Gemini tools and more.
- Added a post-reply scoring tracker for recording and analyzing reply quality.
- Improved reply frequency control, quoted-reply probability, typing time, repeated thinking, wait behavior and empty replyer output handling.

Memory system / A_Memorix

- Added and mainlined the A_Memorix long-term memory system, including the runtime, retrieval, storage, management UI and migration scripts.
- Added memory tests, retrieval tools, the memory service, memory automation hooks and the write-back pipeline.
- Supports migrating old LPMM/legacy memory data into the new long-term memory system.
- Improved memory retrieval speed, token usage, time information, context retrieval and person-fact extraction.
- Added regression tests for memory feedback correction, knowledge-base feedback details, graph-store persistence, summary import, embedding dimension control and more.
- Removed a large amount of retrieval tooling and chat summarization logic from the old memory_system, now handled by the new service layer and A_Memorix.

Plugin system / Runtime

- Large-scale replacement of the old plugin_system with the new plugin_runtime.
- Added plugin capability registration, component registration, event dispatch, hook dispatch, API registration, Supervisor, Runner and RPC Server/Client.
- Supports plugin manifest validation, package-style plugin imports, temporary sys.path management, import protection and module access control.
- Added plugin config version management, config normalization, runtime config validation and batch plugin reloads.
- Added a plugin dependency pipeline, an HTML rendering service and plugin SDK integration improvements.
- Added a peewee compatibility layer for the old database and an initial refactor of the plugin database API.
- Added plugin-side message gateway capabilities, outbound tracking, session ID computation and adapter receipt message ID updates.
- Fixed plugin runtime issues on Windows: signal handling, DLL import isolation, package-style imports, the reload mechanism and more.
- Constrained the maibot-plugin-sdk version range and adapted to the 2.3.0 release.

MCP / tool system

- Added a standalone mcp_module with connection, management, providers, host LLM bridge, hooks and data models.
- Introduced the unified plugin + MCP tool system and removed the old tool system and the tool_use model.
- Tools support indexed retrieval, lazy expansion, unified console display, failed-request archiving and retry analysis.
- Added the host LLM bridge, unifying the call paths of MCP tools and the host model.

WebUI / API

- Overall refactor of the WebUI backend, split into app, dependencies, middleware, routers, schemas, services, utils and so on.
- Added unified WebSocket connection management and routing.
- Refactored the chat, configuration, emoji, expression, jargon, plugin, memory, knowledge base, statistics and system routes.
- Added planner and replyer monitoring APIs, log search, a log line count setting and prompt log preview.
- Added an API for reading the README of locally installed plugins, plus plugin install/config/runtime management APIs.
- Added static asset package hints and error handling; a later fix restricts the WebUI to the packaged static assets only.
- Fixed the knowledgebase feedback detail type issue, the WebUI memory route, the config schema tests and other problems.
- Note: the history contains many dashboard frontend commits and WebUI dist migrations/deletions, but the dashboard itself was not modified this time.

Configuration / models / dependencies

- The configuration system gained ConfigBase tests and stricter validation, with automatic detection and upgrading of old configs.
- Supports Union / Optional field conversion and forbids unsafe multi-type Unions.
- Bumped the config version to 8.4.0, adding tool filtering, replyer, multimodal, Maim Message and log color settings.
- Removed problematic Planner config options, unused options, old path display options, template config files and other cruft.
- Model config removed unused models, utils_small, the deprecated LLM_judge type and the tool_use model.
- Added a random model selection strategy, the model `visual` parameter and OpenAI compatibility improvements.
- Fixed Qwen 3.5 empty replies, Gemini thinking-signature requests, models that do not support gif, OpenAI client tool requests and more.
- Removed uv.lock and updated pyproject.toml / requirements.txt dependencies; the final HEAD removed some dependencies again.

Emoji / images

- Rebuilt the emoji system, covering registration, recognition, caching, sending, selection and database migration.
- Emoji selection now picks everything in one pass, is configurable, and is wired into the subagent.
- Removed the old built-in emoji plugin, replaced by Maisaka built-in actions or new system capabilities.
- Fixed unrecorded emoji sends, recognition failures, caching problems, image storage problems, automatic retries for oversized images and more.
- Added asynchronous background image/emoji processing, improved image display modes and complex message viewing.

Expressions / jargon / learning

- Added automatic expression optimization, an expression check script and a last-modified-source field for expressions.
- Fixed random private-chat expression styles, expression learning and usage, and global expression sharing.
- Added a planner jargon cache and restored expression learning, jargon learning, jargon usage and expression usage.
- Fixed the jargon extraction learning cache and jargon extraction issues.
- Added a fast expression variant and improved expression extraction and LLM judgment tagging.

Documentation / i18n / engineering conventions

- Updated the README, badges, quick navigation, version info and the main repository address.
- Added/updated the changelog, design docs, todo, the memory contract docs, and the Caddy reverse proxy and TLS/SSL docs.
- Added AGENTS.md and updated the code conventions, import order, comment conventions and language conventions.
- Added the Crowdin configuration and multilingual resources, covering Chinese, English, Japanese, Korean and other locales.
- Added the CodeRabbit configuration, a PR template, a test plan and several debugging/migration scripts.
- Added the agentlite subproject/module, including agents, tools, providers, skills, MCP, file/web/shell tools and extensive tests, examples and docs.

Testing and quality

## [0.12.2] - 2026-1-11

### Feature changes

- Improved the private-chat wait logic
- Force a quoted reply on timeout

@@ -1,69 +0,0 @@

# Plugin API and convention changes

1. The `__init__.py` of `plugin_system` now imports all plugin APIs, so you can simply use `from src.plugin_system import *` to import everything.

2. The `register_plugin` function has moved to the `plugin_system.apis.plugin_register_api` module; import it with `from src.plugin_system.apis.plugin_register_api import register_plugin`.
   - As a side note, per item 1 you can also write:

   ```python
   from src.plugin_system import register_plugin
   ```

3. The following properties are now mandatory, i.e. you must override:
   - `plugin_name`: the plugin name; must be unique. (Same as the folder name)
   - `enable_plugin`: whether the plugin is enabled; defaults to `True`.
   - `dependencies`: a list of other plugins this plugin depends on; defaults to empty. **Currently not checked (probably)**
   - `python_dependencies`: a list of Python packages this plugin depends on; defaults to empty. **Currently not checked**
   - `config_file_name`: the plugin config file name; defaults to `config.toml`.
   - `config_schema`: the schema of the plugin config file, used to auto-generate the config file.
4. Some API parameter types and return values were adjusted:
   - In `chat_api.py`, the stream-fetching parameter accepts a special enum value to obtain the ChatStreams of all platforms.
   - `get_global_config` and `get_plugin_config` in `config_api.py` now support nested config key names.
   - `db_query` in `database_api.py` reordered its parameters to tighten constraints while keeping the typing correct; `db_get` gained a `single_result` parameter, consistent with `db_query`.
5. Added `logging_api`; use `get_logger` to obtain a logger.
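As a self-contained illustration of the nested config key names from item 4, the helper below mimics the dotted-key lookup that `get_global_config` / `get_plugin_config` are described as supporting. The function body, key name and default handling are assumptions for illustration, not the real implementation:

```python
from typing import Any


def get_nested_config(config: dict, key: str, default: Any = None) -> Any:
    """Look up a dotted key like "chat.reply.max_length" in nested dicts."""
    node: Any = config
    for part in key.split("."):
        if not isinstance(node, dict) or part not in node:
            return default
        node = node[part]
    return node


config = {"chat": {"reply": {"max_length": 256}}}
print(get_nested_config(config, "chat.reply.max_length"))  # → 256
```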
6. Added plugin and component management APIs.
7. The `execute` method of `BaseCommand` now returns a three-tuple: whether execution succeeded, an optional reply message, and whether to intercept the message.
   - This means you can finally control dynamically whether subsequent message handling continues.
8. Removed dependency_manager but kept the `python_dependencies` property, pending a later refactor.
   - The manager-related documentation was removed along with it.
9. Added tool-related APIs.
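The new three-tuple contract from item 7 can be sketched as follows; the `BaseCommand` stand-in and `PingCommand` are illustrative, not the project's real classes:

```python
import asyncio
from typing import Optional, Tuple


class BaseCommand:
    """Stand-in base class; the real one lives in the plugin system."""

    async def execute(self) -> Tuple[bool, Optional[str], bool]:
        raise NotImplementedError


class PingCommand(BaseCommand):
    async def execute(self) -> Tuple[bool, Optional[str], bool]:
        # (succeeded, optional reply message, intercept further handling?)
        return True, "pong", True


ok, reply, intercept = asyncio.run(PingCommand().execute())
print(ok, reply, intercept)  # → True pong True
```

Returning `False` as the third element would let later handlers keep processing the message; `True` stops them.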
# Plugin system changes

1. All match patterns are now enum classes instead of keywords. **(some may have been missed)**
2. Fixed plugin info not being displayed, and trimmed the displayed content.
3. Fixed the plugin system mixing up `plugin_name` and `display_name`. All plugin info is now displayed with `display_name`, while `plugin_name` remains the internal identifier.
4. Added parameter type checks and improved the corresponding comments.
5. Plugins now share a common base class, `PluginBase`.
   - <del>Plugins based on `Action` and `Command` now use the base class `BasePlugin`.</del>
   - <del>Plugins based on `Event` now use the base class `BaseEventPlugin`.</del>
   - Plugins based on `Action`, `Command` and `Event` now use the base class `BasePlugin`; all plugins should inherit from it.
   - `BasePlugin` inherits from `PluginBase`.
   - All plugin classes are registered via the `register_plugin` decorator.
6. Plugins can finally have custom names!
   - The plugin's `plugin_name` is now truly **not restricted by the folder name**. (Aside: 可乐, one tiny detail of yours took me forever to figure out...)
   - Specify the plugin's internal identifier via the `plugin_name` attribute on the plugin class.
   - Because of this change, a single file can now contain multiple plugin classes, but each class must have a **unique** `plugin_name`.
   - When some plugins fail to load, the package name is now shown instead of the internal identifier.
     - For example, `MaiMBot.plugins.example_plugin` instead of `example_plugin`.
     - This only happens when the plugin fails at import time; plugins that fail during normal registration still show the internal identifier. (This is by design, but the situation is basically impossible.)
7. Single-file plugins are no longer supported; their loading path has been removed entirely.
8. Merged `BaseEventPlugin` into `BasePlugin`; all plugins should inherit from `BasePlugin`.
9. `BaseEventHandler` now has a `get_config` method.
10. Fixed incorrect output in `main.py`.
11. Fixed incorrect output when registering the `Pattern` compiled from a `command`.
12. `events_manager` now has task-related logic.
13. Plugins can now be unloaded and reloaded, i.e. hot-swapped.
14. Implemented global enabling and disabling of components.
    - Use the `enable_component` and `disable_component` methods.
    - Note that this is not persisted to the config file~
15. Implemented per-chat disabling of components, i.e. disabling for one specific chat.
    - Use `disable_specific_chat_action`, `enable_specific_chat_action`, `disable_specific_chat_command`, `enable_specific_chat_command`, `disable_specific_chat_event_handler` and `enable_specific_chat_event_handler`.
    - Likewise not persisted to the config file~
16. `BaseTool` was also merged into the plugin system.

# Official plugin changes

1. The `HelloWorld` plugin now ships a sample `EventHandler`.
2. The built-in plugins gained a `Command` for managing plugins, invoked with the `/pm` command. (Needs to be enabled manually.)
3. The `HelloWorld` plugin now ships a sample `CompareNumbersTool`.

### BGM while writing

塞壬唱片 (Siren Records)!
@@ -1,328 +0,0 @@

"""
Evaluation result statistics script.

Features:
1. Scan all JSON files under the temp directory
2. Analyze statistics for each file
3. Print a detailed statistics report
"""

import json
import os
import sys
import glob
from collections import Counter
from datetime import datetime
from typing import Dict, List, Set, Tuple

# Add the project root to sys.path
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, project_root)

from src.common.logger import get_logger  # noqa: E402

logger = get_logger("evaluation_stats_analyzer")

# Path to the evaluation result files
TEMP_DIR = os.path.join(os.path.dirname(__file__), "temp")


def parse_datetime(dt_str: str) -> datetime | None:
    """Parse an ISO-format datetime string."""
    try:
        return datetime.fromisoformat(dt_str)
    except Exception:
        return None


def analyze_single_file(file_path: str) -> Dict:
    """
    Analyze statistics for a single JSON file.

    Args:
        file_path: path to the JSON file

    Returns:
        A dictionary of statistics.
    """
    file_name = os.path.basename(file_path)
    stats = {
        "file_name": file_name,
        "file_path": file_path,
        "file_size": os.path.getsize(file_path),
        "error": None,
        "last_updated": None,
        "total_count": 0,
        "actual_count": 0,
        "suitable_count": 0,
        "unsuitable_count": 0,
        "suitable_rate": 0.0,
        "unique_pairs": 0,
        "evaluators": Counter(),
        "evaluation_dates": [],
        "date_range": None,
        "has_expression_id": False,
        "has_reason": False,
        "reason_count": 0,
    }

    try:
        with open(file_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        # Basic info
        stats["last_updated"] = data.get("last_updated")
        stats["total_count"] = data.get("total_count", 0)

        results = data.get("manual_results", [])
        stats["actual_count"] = len(results)

        if not results:
            return stats

        # Count passed / failed
        suitable_count = sum(1 for r in results if r.get("suitable") is True)
        unsuitable_count = sum(1 for r in results if r.get("suitable") is False)
        stats["suitable_count"] = suitable_count
        stats["unsuitable_count"] = unsuitable_count
        stats["suitable_rate"] = (suitable_count / len(results) * 100) if results else 0.0

        # Count unique (situation, style) pairs
        pairs: Set[Tuple[str, str]] = set()
        for r in results:
            if "situation" in r and "style" in r:
                pairs.add((r["situation"], r["style"]))
        stats["unique_pairs"] = len(pairs)

        # Count evaluators
        for r in results:
            evaluator = r.get("evaluator", "unknown")
            stats["evaluators"][evaluator] += 1

        # Collect evaluation timestamps
        evaluation_dates = []
        for r in results:
            evaluated_at = r.get("evaluated_at")
            if evaluated_at:
                dt = parse_datetime(evaluated_at)
                if dt:
                    evaluation_dates.append(dt)

        stats["evaluation_dates"] = evaluation_dates
        if evaluation_dates:
            min_date = min(evaluation_dates)
            max_date = max(evaluation_dates)
            stats["date_range"] = {
                "start": min_date.isoformat(),
                "end": max_date.isoformat(),
                "duration_days": (max_date - min_date).days + 1,
            }

        # Check field presence
        stats["has_expression_id"] = any("expression_id" in r for r in results)
        stats["has_reason"] = any(r.get("reason") for r in results)
        stats["reason_count"] = sum(1 for r in results if r.get("reason"))

    except Exception as e:
        stats["error"] = str(e)
        logger.error(f"Error while analyzing file {file_name}: {e}")

    return stats


def print_file_stats(stats: Dict, index: int | None = None):
    """Print statistics for a single file."""
    prefix = f"[{index}] " if index is not None else ""
    print(f"\n{'=' * 80}")
    print(f"{prefix}File: {stats['file_name']}")
    print(f"{'=' * 80}")

    if stats["error"]:
        print(f"✗ Error: {stats['error']}")
        return

    print(f"File path: {stats['file_path']}")
    print(f"File size: {stats['file_size']:,} bytes ({stats['file_size'] / 1024:.2f} KB)")

    if stats["last_updated"]:
        print(f"Last updated: {stats['last_updated']}")

    print("\n[Record stats]")
    print(f"  total_count in file: {stats['total_count']}")
    print(f"  Actual record count: {stats['actual_count']}")

    if stats["total_count"] != stats["actual_count"]:
        diff = stats["total_count"] - stats["actual_count"]
        print(f"  ⚠️ Counts disagree, difference: {diff:+d}")

    print("\n[Evaluation result stats]")
    print(f"  Passed (suitable=True): {stats['suitable_count']} ({stats['suitable_rate']:.2f}%)")
    print(f"  Failed (suitable=False): {stats['unsuitable_count']} ({100 - stats['suitable_rate']:.2f}%)")

    print("\n[Uniqueness stats]")
    print(f"  Unique (situation, style) pairs: {stats['unique_pairs']}")
    if stats["actual_count"] > 0:
        duplicate_count = stats["actual_count"] - stats["unique_pairs"]
        duplicate_rate = duplicate_count / stats["actual_count"] * 100
        print(f"  Duplicate records: {duplicate_count} ({duplicate_rate:.2f}%)")

    print("\n[Evaluator stats]")
    if stats["evaluators"]:
        for evaluator, count in stats["evaluators"].most_common():
            rate = (count / stats["actual_count"] * 100) if stats["actual_count"] > 0 else 0
            print(f"  {evaluator}: {count} ({rate:.2f}%)")
    else:
        print("  No evaluator information")

    print("\n[Time stats]")
    if stats["date_range"]:
        print(f"  Earliest evaluation: {stats['date_range']['start']}")
        print(f"  Latest evaluation: {stats['date_range']['end']}")
        print(f"  Evaluation time span: {stats['date_range']['duration_days']} days")
    else:
        print("  No time information")

    print("\n[Field stats]")
    print(f"  Has expression_id: {'yes' if stats['has_expression_id'] else 'no'}")
    print(f"  Has reason: {'yes' if stats['has_reason'] else 'no'}")
    if stats["has_reason"]:
        rate = (stats["reason_count"] / stats["actual_count"] * 100) if stats["actual_count"] > 0 else 0
        print(f"  Records with a reason: {stats['reason_count']} ({rate:.2f}%)")


def print_summary(all_stats: List[Dict]):
    """Print the aggregated statistics."""
    print(f"\n{'=' * 80}")
    print("Summary")
    print(f"{'=' * 80}")

    total_files = len(all_stats)
    valid_files = [s for s in all_stats if not s.get("error")]
    error_files = [s for s in all_stats if s.get("error")]

    print("\n[File stats]")
    print(f"  Total files: {total_files}")
    print(f"  Parsed successfully: {len(valid_files)}")
    print(f"  Failed to parse: {len(error_files)}")

    if error_files:
        print("\n  Failed files:")
        for stats in error_files:
            print(f"    - {stats['file_name']}: {stats['error']}")

    if not valid_files:
        print("\nNo files were parsed successfully")
        return

    # Aggregate record counts
    total_records = sum(s["actual_count"] for s in valid_files)
    total_suitable = sum(s["suitable_count"] for s in valid_files)
    total_unsuitable = sum(s["unsuitable_count"] for s in valid_files)
    total_unique_pairs = set()

    # Collect all unique (situation, style) pairs
    for stats in valid_files:
        try:
            with open(stats["file_path"], "r", encoding="utf-8") as f:
                data = json.load(f)
            results = data.get("manual_results", [])
            for r in results:
                if "situation" in r and "style" in r:
                    total_unique_pairs.add((r["situation"], r["style"]))
        except Exception:
            pass

    print("\n[Record summary]")
    print(f"  Total records: {total_records:,}")
    print(
        f"  Passed: {total_suitable:,} ({total_suitable / total_records * 100:.2f}%)"
        if total_records > 0
        else "  Passed: 0"
    )
    print(
        f"  Failed: {total_unsuitable:,} ({total_unsuitable / total_records * 100:.2f}%)"
        if total_records > 0
        else "  Failed: 0"
    )
    print(f"  Unique (situation, style) pairs: {len(total_unique_pairs):,}")

    if total_records > 0:
        duplicate_count = total_records - len(total_unique_pairs)
        duplicate_rate = duplicate_count / total_records * 100
        print(f"  Duplicate records: {duplicate_count:,} ({duplicate_rate:.2f}%)")

    # Aggregate evaluator stats
    all_evaluators = Counter()
    for stats in valid_files:
        all_evaluators.update(stats["evaluators"])

    print("\n[Evaluator summary]")
    if all_evaluators:
        for evaluator, count in all_evaluators.most_common():
            rate = (count / total_records * 100) if total_records > 0 else 0
            print(f"  {evaluator}: {count:,} ({rate:.2f}%)")
    else:
        print("  No evaluator information")

    # Aggregate time range
    all_dates = []
    for stats in valid_files:
        all_dates.extend(stats["evaluation_dates"])

    if all_dates:
        min_date = min(all_dates)
        max_date = max(all_dates)
        print("\n[Time summary]")
        print(f"  Earliest evaluation: {min_date.isoformat()}")
        print(f"  Latest evaluation: {max_date.isoformat()}")
        print(f"  Total time span: {(max_date - min_date).days + 1} days")

    # Aggregate file sizes
    total_size = sum(s["file_size"] for s in valid_files)
    avg_size = total_size / len(valid_files) if valid_files else 0
    print("\n[File size summary]")
    print(f"  Total size: {total_size:,} bytes ({total_size / 1024 / 1024:.2f} MB)")
    print(f"  Average size: {avg_size:,.0f} bytes ({avg_size / 1024:.2f} KB)")


def main():
    """Entry point."""
    logger.info("=" * 80)
    logger.info("Starting evaluation result statistics analysis")
    logger.info("=" * 80)

    if not os.path.exists(TEMP_DIR):
        print(f"\n✗ Error: temp directory not found: {TEMP_DIR}")
        logger.error(f"temp directory not found: {TEMP_DIR}")
        return

    # Find all JSON files
    json_files = glob.glob(os.path.join(TEMP_DIR, "*.json"))

    if not json_files:
        print(f"\n✗ Error: no JSON files found under the temp directory: {TEMP_DIR}")
        logger.error(f"No JSON files found under the temp directory: {TEMP_DIR}")
        return

    json_files.sort()  # sort by file name

    print(f"\nFound {len(json_files)} JSON files")
    print("=" * 80)

    # Analyze each file
    all_stats = []
    for i, json_file in enumerate(json_files, 1):
        stats = analyze_single_file(json_file)
        all_stats.append(stats)
        print_file_stats(stats, index=i)

    # Print the summary
    print_summary(all_stats)

    print(f"\n{'=' * 80}")
    print("Analysis complete")
    print(f"{'=' * 80}")


if __name__ == "__main__":
    main()
@@ -1,388 +0,0 @@

import argparse
import sys
from pathlib import Path
from typing import List, Tuple, Dict, Any
import json
import os

# Force utf-8 to avoid console encoding errors
try:
    if hasattr(sys.stdout, "reconfigure"):
        sys.stdout.reconfigure(encoding="utf-8")
    if hasattr(sys.stderr, "reconfigure"):
        sys.stderr.reconfigure(encoding="utf-8")
except Exception:
    pass

# Make sure the src package can be found
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from src.chat.knowledge.embedding_store import EmbeddingManager
from src.chat.knowledge.kg_manager import KGManager
from src.common.logger import get_logger
from src.chat.knowledge.utils.hash import get_sha256

logger = get_logger("delete_lpmm_items")


def read_hashes(file_path: Path) -> List[str]:
    """Read a list of hashes, skipping empty lines."""
    hashes: List[str] = []
    for line in file_path.read_text(encoding="utf-8").splitlines():
        val = line.strip()
        if not val:
            continue
        hashes.append(val)
    return hashes


def read_openie_hashes(file_path: Path) -> List[str]:
    """Extract idx values from an OpenIE JSON file as paragraph hashes."""
    data: Dict[str, Any] = json.loads(file_path.read_text(encoding="utf-8"))
    docs = data.get("docs", []) if isinstance(data, dict) else []
    hashes: List[str] = []
    for doc in docs:
        idx = doc.get("idx") if isinstance(doc, dict) else None
        if isinstance(idx, str) and idx.strip():
            hashes.append(idx.strip())
    return hashes


def normalize_paragraph_keys(raw_hashes: List[str]) -> Tuple[List[str], List[str]]:
    """Normalize the input into two lists: full keys and bare hashes."""
    keys: List[str] = []
    hashes: List[str] = []
    for h in raw_hashes:
        if h.startswith("paragraph-"):
            keys.append(h)
            hashes.append(h.replace("paragraph-", "", 1))
        else:
            keys.append(f"paragraph-{h}")
            hashes.append(h)
    return keys, hashes


def main():
    parser = argparse.ArgumentParser(description="Delete paragraphs from LPMM knowledge base (vectors + graph).")
    parser.add_argument("--hash-file", help="Path to a text file with one paragraph hash or prefixed key per line")
    parser.add_argument("--openie-file", help="OpenIE output file (JSON); its docs.idx values are the paragraph hashes to delete")
    parser.add_argument("--raw-file", help="Raw txt corpus file (paragraphs separated by blank lines); combine with --raw-index")
    parser.add_argument(
        "--raw-index",
        help="1-based indices of the paragraphs to delete in --raw-file; comma-separated, e.g. 1,3",
    )
    parser.add_argument("--search-text", help="Substring-search the current paragraph store and interactively pick paragraphs to delete")
    parser.add_argument(
        "--search-limit",
        type=int,
        default=10,
        help="Maximum number of candidate paragraphs shown in --search-text mode",
    )
    parser.add_argument("--delete-entities", action="store_true", help="Also delete the entity nodes/embeddings from the OpenIE file")
    parser.add_argument("--delete-relations", action="store_true", help="Also delete the relation embeddings from the OpenIE file")
    parser.add_argument("--remove-orphan-entities", action="store_true", help="Remove entity nodes orphaned by the deletion")
    parser.add_argument("--dry-run", action="store_true", help="Only preview what would be deleted; make no changes")
    parser.add_argument("--yes", action="store_true", help="Skip interactive confirmation and delete directly (use with care)")
    parser.add_argument(
        "--max-delete-nodes",
        type=int,
        default=2000,
        help="Maximum number of nodes (paragraphs + entities) allowed per run; beyond this, explicit confirmation or a higher limit is required",
    )
    parser.add_argument(
        "--non-interactive",
        action="store_true",
        help=(
            "Non-interactive mode: never prompt via input(); "
            "in this mode, anything that would require interaction (e.g. --search-text without "
            "explicit picks, or a missing --yes) exits with an error."
        ),
    )
    args = parser.parse_args()

    # At least one source is required
    if not (args.hash_file or args.openie_file or args.raw_file or args.search_text):
        logger.error("One of --hash-file / --openie-file / --raw-file / --search-text must be given")
        sys.exit(1)

    raw_hashes: List[str] = []
    raw_entities: List[str] = []
    raw_relations: List[str] = []

    if args.hash_file:
        hash_file = Path(args.hash_file)
        if not hash_file.exists():
            logger.error(f"Hash file does not exist: {hash_file}")
            sys.exit(1)
        raw_hashes.extend(read_hashes(hash_file))

    if args.openie_file:
        openie_path = Path(args.openie_file)
        if not openie_path.exists():
            logger.error(f"OpenIE file does not exist: {openie_path}")
            sys.exit(1)
        # Paragraphs
        raw_hashes.extend(read_openie_hashes(openie_path))
        # Entities/relations (entities include both extracted_entities and triple
        # subjects/objects, matching the KG construction logic)
        try:
            data = json.loads(openie_path.read_text(encoding="utf-8"))
            docs = data.get("docs", []) if isinstance(data, dict) else []
            for doc in docs:
                if not isinstance(doc, dict):
                    continue
                ents = doc.get("extracted_entities", [])
                if isinstance(ents, list):
                    raw_entities.extend([e for e in ents if isinstance(e, str)])
                triples = doc.get("extracted_triples", [])
                if isinstance(triples, list):
                    for t in triples:
                        if isinstance(t, list) and len(t) == 3:
                            subj, _, obj = t
                            if isinstance(subj, str):
                                raw_entities.append(subj)
                            if isinstance(obj, str):
                                raw_entities.append(obj)
                            raw_relations.append(str(tuple(t)))
        except Exception as e:
            logger.error(f"Failed to read OpenIE file: {e}")
            sys.exit(1)

    # Select paragraphs to delete by index from the raw txt corpus
    if args.raw_file:
        raw_path = Path(args.raw_file)
        if not raw_path.exists():
            logger.error(f"Raw corpus file does not exist: {raw_path}")
            sys.exit(1)
        text = raw_path.read_text(encoding="utf-8")
        paragraphs: List[str] = []
        buf = []
        for line in text.splitlines():
            if line.strip() == "":
                if buf:
                    paragraphs.append("\n".join(buf).strip())
                    buf = []
            else:
                buf.append(line)
        if buf:
            paragraphs.append("\n".join(buf).strip())

        if not paragraphs:
            logger.error(f"No paragraphs could be parsed from raw corpus file {raw_path}")
            sys.exit(1)

        if not args.raw_index:
            logger.info(
                f"{raw_path} contains {len(paragraphs)} paragraphs; specify the ones to delete with --raw-index, e.g. --raw-index 1,3"
            )
            sys.exit(1)

        # Parse the index list (1-based)
        try:
            idx_list = [int(x.strip()) for x in str(args.raw_index).split(",") if x.strip()]
        except ValueError:
            logger.error(f"Failed to parse --raw-index: {args.raw_index}")
            sys.exit(1)

        for idx in idx_list:
            if idx < 1 or idx > len(paragraphs):
                logger.error(f"--raw-index contains invalid index {idx} (valid range 1~{len(paragraphs)})")
                sys.exit(1)

        logger.info("Paragraphs selected from the raw corpus:")
        for idx in idx_list:
            para = paragraphs[idx - 1]
            h = get_sha256(para)
            logger.info(f"- paragraph {idx}, hash={h}, preview: {para[:80]}")
            raw_hashes.append(h)

    # Substring-search the existing store for candidate paragraphs and pick interactively
    if args.search_text:
        search_text = args.search_text.strip()
        if not search_text:
            logger.error("--search-text must not be empty")
            sys.exit(1)
        logger.info(f"Searching the existing paragraph store for: {search_text!r}")
        em_search = EmbeddingManager()
        try:
            em_search.load_from_file()
        except Exception as e:
            logger.error(f"Failed to load the embedding store; --search-text is unavailable: {e}")
            sys.exit(1)

        candidates = []
        for key, item in em_search.paragraphs_embedding_store.store.items():
            if search_text in item.str:
                candidates.append((key, item.str))
            if len(candidates) >= args.search_limit:
                break

        if not candidates:
            logger.info("No paragraphs containing the keyword were found in the existing store")
        else:
            logger.info("Candidate paragraphs found (enter the numbers of the items to delete, comma-separated for multiple):")
            for i, (key, text) in enumerate(candidates, start=1):
                logger.info(f"{i}. {key} | {text[:80]}")
            if args.non_interactive:
                logger.error(
                    "Non-interactive mode: candidate paragraphs cannot be picked by number; "
                    "for scripted deletion use --hash-file / --openie-file / --raw-file instead."
                )
                sys.exit(1)
            choice = input("Enter the numbers to delete (e.g. 1,3), or press Enter to cancel: ").strip()
            if choice:
                try:
                    idxs = [int(x.strip()) for x in choice.split(",") if x.strip()]
                except ValueError:
                    logger.error("Could not parse the number list; --search-text deletion cancelled")
                else:
                    for i in idxs:
                        if 1 <= i <= len(candidates):
                            key, _ = candidates[i - 1]
                            # key is already the full paragraph-xxx
                            if key.startswith("paragraph-"):
                                raw_hashes.append(key.split("paragraph-", 1)[1])
                        else:
                            logger.warning(f"Ignoring invalid number: {i}")

    # Deduplicate while preserving order
    seen = set()
    raw_hashes = [h for h in raw_hashes if not (h in seen or seen.add(h))]

    if not raw_hashes:
        logger.error("No hashes to delete were read; nothing to do")
        sys.exit(1)

    keys, pg_hashes = normalize_paragraph_keys(raw_hashes)

    ent_hashes: List[str] = []
    rel_hashes: List[str] = []
    if args.delete_entities and raw_entities:
        ent_hashes = [get_sha256(e) for e in raw_entities]
    if args.delete_relations and raw_relations:
        rel_hashes = [get_sha256(r) for r in raw_relations]

    logger.info("=== Deletion preparation ===")
    logger.info("Make sure data/embedding and data/rag are backed up; use --dry-run to preview if needed")
    logger.info(f"Paragraphs to delete: {len(keys)}")
    logger.info(f"Examples: {keys[:5]}")
    if ent_hashes:
        logger.info(f"Entities to delete: {len(ent_hashes)}")
    if rel_hashes:
        logger.info(f"Relations to delete: {len(rel_hashes)}")

    total_nodes_to_delete = len(pg_hashes) + (len(ent_hashes) if args.delete_entities else 0)
    logger.info(f"Total nodes expected to be deleted (paragraphs + entities): {total_nodes_to_delete}")

    if args.dry_run:
        logger.info("dry-run mode; nothing was deleted")
        return

    # Protection against large batch deletions
    if total_nodes_to_delete > args.max_delete_nodes and not args.yes:
        logger.error(
            f"This run would delete {total_nodes_to_delete} nodes, exceeding the threshold of {args.max_delete_nodes}."
            " To avoid accidental deletion, reduce the batch size or raise --max-delete-nodes, and confirm explicitly with --yes."
        )
        sys.exit(1)

    # Interactive confirmation
    if not args.yes:
        if args.non_interactive:
            logger.error(
                "Non-interactive mode without --yes: the deletion was refused for safety.\n"
                "If you really need to delete in non-interactive mode, add --yes explicitly."
            )
            sys.exit(1)
        confirm = input("Delete the data listed above? Type uppercase YES to continue, anything else to cancel: ").strip()
        if confirm != "YES":
            logger.info("Deletion cancelled by the user")
            return

    # Load the embeddings and the graph
    embed_manager = EmbeddingManager()
    kg_manager = KGManager()

    try:
        embed_manager.load_from_file()
        kg_manager.load_from_file()
    except Exception as e:
        logger.error(f"Failed to load the existing knowledge base: {e}")
        sys.exit(1)

    # Record pre-deletion global stats for comparison
    before_para_vec = len(embed_manager.paragraphs_embedding_store.store)
    before_ent_vec = len(embed_manager.entities_embedding_store.store)
    before_rel_vec = len(embed_manager.relation_embedding_store.store)
    before_nodes = len(kg_manager.graph.get_node_list())
    before_edges = len(kg_manager.graph.get_edge_list())
    logger.info(
        f"Before deletion: paragraph vectors={before_para_vec}, entity vectors={before_ent_vec}, relation vectors={before_rel_vec}, "
        f"KG nodes={before_nodes}, KG edges={before_edges}"
    )

    # Delete vectors
    deleted, skipped = embed_manager.paragraphs_embedding_store.delete_items(keys)
    embed_manager.stored_pg_hashes = set(embed_manager.paragraphs_embedding_store.store.keys())
    logger.info(f"Paragraph vector deletion done, deleted: {deleted}, skipped: {skipped}")
    ent_deleted = ent_skipped = rel_deleted = rel_skipped = 0
    if ent_hashes:
        ent_keys = [f"entity-{h}" for h in ent_hashes]
        ent_deleted, ent_skipped = embed_manager.entities_embedding_store.delete_items(ent_keys)
        logger.info(f"Entity vector deletion done, deleted: {ent_deleted}, skipped: {ent_skipped}")
    if rel_hashes:
        rel_keys = [f"relation-{h}" for h in rel_hashes]
        rel_deleted, rel_skipped = embed_manager.relation_embedding_store.delete_items(rel_keys)
        logger.info(f"Relation vector deletion done, deleted: {rel_deleted}, skipped: {rel_skipped}")

    # Delete graph nodes/edges
    kg_result = kg_manager.delete_paragraphs(
        pg_hashes,
        ent_hashes=ent_hashes if args.delete_entities else None,
        remove_orphan_entities=args.remove_orphan_entities,
    )
    logger.info(
        f"KG deletion done, deleted: {kg_result.get('deleted', 0)}, skipped: {kg_result.get('skipped', 0)}, "
        f"orphan entities removed: {kg_result.get('orphan_removed', 0)}"
    )

    # Rebuild the index and save
    logger.info("Rebuilding the Faiss index and saving the embedding files...")
    embed_manager.rebuild_faiss_index()
    embed_manager.save_to_file()

    logger.info("Saving KG data...")
    kg_manager.save_to_file()

    # Post-deletion stats
    after_para_vec = len(embed_manager.paragraphs_embedding_store.store)
    after_ent_vec = len(embed_manager.entities_embedding_store.store)
    after_rel_vec = len(embed_manager.relation_embedding_store.store)
    after_nodes = len(kg_manager.graph.get_node_list())
    after_edges = len(kg_manager.graph.get_edge_list())

    logger.info(
        "After deletion: paragraph vectors=%d(%+d), entity vectors=%d(%+d), relation vectors=%d(%+d), KG nodes=%d(%+d), KG edges=%d(%+d)"
        % (
            after_para_vec,
            after_para_vec - before_para_vec,
            after_ent_vec,
            after_ent_vec - before_ent_vec,
            after_rel_vec,
            after_rel_vec - before_rel_vec,
            after_nodes,
            after_nodes - before_nodes,
            after_edges,
            after_edges - before_edges,
        )
    )

    logger.info("Deletion flow complete")
    print(
        "\n[NOTICE] The deletion script has finished. If the main program (chat / WebUI) is already running,"
        " restart it, or call lpmm_start_up() once inside the main program to apply the latest LPMM knowledge base."
    )
    print("[NOTICE] If you are unsure what lpmm_start_up is, just restart the main program.")


if __name__ == "__main__":
    main()
@@ -1,301 +0,0 @@
# try:
#     import src.plugins.knowledge.lib.quick_algo
# except ImportError:
#     print("未找到quick_algo库,无法使用quick_algo算法")
#     print("请安装quick_algo库 - 在lib.quick_algo中,执行命令:python setup.py build_ext --inplace")

import argparse
import sys
import os
import asyncio
from time import sleep
from typing import Optional

# 添加项目根目录到 sys.path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from src.chat.knowledge.embedding_store import EmbeddingManager
from src.chat.knowledge.open_ie import OpenIE
from src.chat.knowledge.kg_manager import KGManager
from src.common.logger import get_logger
from src.chat.knowledge.utils.hash import get_sha256


ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
OPENIE_DIR = os.path.join(ROOT_PATH, "data", "openie")

logger = get_logger("OpenIE导入")


def ensure_openie_dir():
    """确保OpenIE数据目录存在"""
    if not os.path.exists(OPENIE_DIR):
        os.makedirs(OPENIE_DIR)
        logger.info(f"创建OpenIE数据目录:{OPENIE_DIR}")
    else:
        logger.info(f"OpenIE数据目录已存在:{OPENIE_DIR}")


def hash_deduplicate(
    raw_paragraphs: dict[str, str],
    triple_list_data: dict[str, list[list[str]]],
    stored_pg_hashes: set,
    stored_paragraph_hashes: set,
):
    """Hash去重

    Args:
        raw_paragraphs: 索引的段落原文
        triple_list_data: 索引的三元组列表
        stored_pg_hashes: Embedding库中已存储的段落key集合
        stored_paragraph_hashes: KG中已存储的段落hash集合

    Returns:
        new_raw_paragraphs: 去重后的段落
        new_triple_list_data: 去重后的三元组
    """
    # 保存去重后的段落
    new_raw_paragraphs = {}
    # 保存去重后的三元组
    new_triple_list_data = {}

    for raw_paragraph, triple_list in zip(raw_paragraphs.values(), triple_list_data.values(), strict=False):
        # 段落hash
        paragraph_hash = get_sha256(raw_paragraph)
        # 使用与EmbeddingStore中一致的命名空间格式:namespace-hash
        paragraph_key = f"paragraph-{paragraph_hash}"
        if paragraph_key in stored_pg_hashes and paragraph_hash in stored_paragraph_hashes:
            continue
        new_raw_paragraphs[paragraph_hash] = raw_paragraph
        new_triple_list_data[paragraph_hash] = triple_list

    return new_raw_paragraphs, new_triple_list_data


def handle_import_openie(
    openie_data: OpenIE,
    embed_manager: EmbeddingManager,
    kg_manager: KGManager,
    non_interactive: bool = False,
) -> bool:
    # sourcery skip: extract-method
    # 从OpenIE数据中提取段落原文与三元组列表
    # 索引的段落原文
    raw_paragraphs = openie_data.extract_raw_paragraph_dict()
    # 索引的实体列表
    entity_list_data = openie_data.extract_entity_dict()
    # 索引的三元组列表
    triple_list_data = openie_data.extract_triple_dict()
    if len(raw_paragraphs) != len(entity_list_data) or len(raw_paragraphs) != len(triple_list_data):
        logger.error("OpenIE数据存在异常")
        logger.error(f"原始段落数量:{len(raw_paragraphs)}")
        logger.error(f"实体列表数量:{len(entity_list_data)}")
        logger.error(f"三元组列表数量:{len(triple_list_data)}")
        logger.error("OpenIE数据段落数量与实体列表数量或三元组列表数量不一致")
        logger.error("请保证你的原始数据分段良好,不要有类似于 “.....” 单独成一段的情况")
        logger.error("或者一段中只有符号的情况")
        # 检查docs中每条数据的完整性
        logger.error("系统将于2秒后开始检查数据完整性")
        sleep(2)
        found_missing = False
        missing_idxs = []
        for doc in getattr(openie_data, "docs", []):
            idx = doc.get("idx", "<无idx>")
            passage = doc.get("passage", "<无passage>")
            missing = []
            # 检查字段是否存在且非空
            if "passage" not in doc or not doc.get("passage"):
                missing.append("passage")
            if "extracted_entities" not in doc or not isinstance(doc.get("extracted_entities"), list):
                missing.append("名词列表缺失")
            elif len(doc.get("extracted_entities", [])) == 0:
                missing.append("名词列表为空")
            if "extracted_triples" not in doc or not isinstance(doc.get("extracted_triples"), list):
                missing.append("主谓宾三元组缺失")
            elif len(doc.get("extracted_triples", [])) == 0:
                missing.append("主谓宾三元组为空")
            if missing:
                found_missing = True
                missing_idxs.append(idx)
                logger.error("数据缺失:")
                logger.error(f"对应哈希值:{idx}")
                logger.error(f"对应文段内容:{passage}")
                logger.error(f"非法原因:{', '.join(missing)}")
        # 确保提示在所有非法数据输出后再输出
        if not found_missing:
            logger.info("所有数据均完整,没有发现缺失字段。")
            return False
        # 提示用户是否删除非法文段后继续导入
        # 在非交互模式下,不再询问用户,而是直接报错终止
        logger.info(f"\n检测到非法文段,共{len(missing_idxs)}条。")
        if non_interactive:
            logger.error("检测到非法文段且当前处于非交互模式,无法询问是否删除非法文段,导入终止。")
            sys.exit(1)
        # 注意:logger.info 不支持 end 参数,此处需用 print 保持提示与输入同行
        print("\n是否删除所有非法文段后继续导入?(y/n): ", end="")
        user_choice = input().strip().lower()
        if user_choice != "y":
            logger.info("用户选择不删除非法文段,程序终止。")
            sys.exit(1)
        # 删除非法文段
        logger.info("正在删除非法文段并继续导入...")
        # 过滤掉非法文段
        openie_data.docs = [
            doc for doc in getattr(openie_data, "docs", []) if doc.get("idx", "<无idx>") not in missing_idxs
        ]
        # 重新提取数据
        raw_paragraphs = openie_data.extract_raw_paragraph_dict()
        entity_list_data = openie_data.extract_entity_dict()
        triple_list_data = openie_data.extract_triple_dict()
        # 再次校验
        if len(raw_paragraphs) != len(entity_list_data) or len(raw_paragraphs) != len(triple_list_data):
            logger.error("删除非法文段后,数据仍不一致,程序终止。")
            sys.exit(1)
    # 将索引换为对应段落的hash值
    logger.info("正在进行段落去重与重索引")
    raw_paragraphs, triple_list_data = hash_deduplicate(
        raw_paragraphs,
        triple_list_data,
        embed_manager.stored_pg_hashes,
        kg_manager.stored_paragraph_hashes,
    )
    if len(raw_paragraphs) != 0:
        # 获取嵌入并保存
        logger.info(f"段落去重完成,剩余待处理的段落数量:{len(raw_paragraphs)}")
        logger.info("开始Embedding")
        embed_manager.store_new_data_set(raw_paragraphs, triple_list_data)
        # Embedding-Faiss重索引
        logger.info("正在重新构建向量索引")
        embed_manager.rebuild_faiss_index()
        logger.info("向量索引构建完成")
        embed_manager.save_to_file()
        logger.info("Embedding完成")
        # 构建新段落的RAG
        logger.info("开始构建RAG")
        kg_manager.build_kg(triple_list_data, embed_manager)
        kg_manager.save_to_file()
        logger.info("RAG构建完成")
    else:
        logger.info("无新段落需要处理")
    return True


async def main_async(non_interactive: bool = False) -> bool:  # sourcery skip: dict-comprehension
    # 确认提示
    if non_interactive:
        logger.warning("当前处于非交互模式,将跳过导入开销确认提示,直接开始执行 OpenIE 导入。")
    else:
        print("=== 重要操作确认 ===")
        print("OpenIE导入时会大量发送请求,可能会撞到请求速度上限,请注意选用的模型")
        print("同之前样例:在本地模型下,在70分钟内我们发送了约8万条请求,在网络允许下,速度会更快")
        print("推荐使用硅基流动的Pro/BAAI/bge-m3")
        print("每百万Token费用为0.7元")
        print("知识导入时,会消耗大量系统资源,建议在较好配置电脑上运行")
        print("同上样例,导入时10700K几乎跑满,14900HX占用80%,峰值内存占用约3G")
        confirm = input("确认继续执行?(y/n): ").strip().lower()
        if confirm != "y":
            logger.info("用户取消操作")
            print("操作已取消")
            sys.exit(1)
        print("\n" + "=" * 40 + "\n")
    ensure_openie_dir()  # 确保OpenIE目录存在
    logger.info("----开始导入openie数据----\n")

    logger.info("创建LLM客户端")

    # 初始化Embedding库
    embed_manager = EmbeddingManager()
    logger.info("正在从文件加载Embedding库")
    try:
        embed_manager.load_from_file()
    except Exception as e:
        logger.error(f"从文件加载Embedding库时发生错误:{e}")
        if "嵌入模型与本地存储不一致" in str(e):
            logger.error("检测到嵌入模型与本地存储不一致,已终止导入。请检查模型设置或清空嵌入库后重试。")
            logger.error("请保证你的嵌入模型从未更改,并且在导入时使用相同的模型")
            sys.exit(1)
        if "不存在" in str(e):
            logger.error("如果你是第一次导入知识,请忽略此错误")
    logger.info("Embedding库加载完成")
    # 初始化KG
    kg_manager = KGManager()
    logger.info("正在从文件加载KG")
    try:
        kg_manager.load_from_file()
    except Exception as e:
        logger.error(f"从文件加载KG时发生错误:{e}")
        logger.error("如果你是第一次导入知识,请忽略此错误")
    logger.info("KG加载完成")

    logger.info(f"KG节点数量:{len(kg_manager.graph.get_node_list())}")
    logger.info(f"KG边数量:{len(kg_manager.graph.get_edge_list())}")

    # 数据比对:Embedding库与KG的段落hash集合
    for pg_hash in kg_manager.stored_paragraph_hashes:
        # 使用与EmbeddingStore中一致的命名空间格式:namespace-hash
        key = f"paragraph-{pg_hash}"
        if key not in embed_manager.stored_pg_hashes:
            logger.warning(f"KG中存在Embedding库中不存在的段落:{key}")

    logger.info("正在导入OpenIE数据文件")
    try:
        openie_data = OpenIE.load()
    except Exception as e:
        logger.error(f"导入OpenIE数据文件时发生错误:{e}")
        return False
    if handle_import_openie(openie_data, embed_manager, kg_manager, non_interactive=non_interactive) is False:
        logger.error("处理OpenIE数据时发生错误")
        return False
    return True


def main(argv: Optional[list[str]] = None) -> None:
    """主函数 - 解析参数并运行异步主流程。"""
    parser = argparse.ArgumentParser(
        description="OpenIE 导入脚本:读取 data/openie 中的 OpenIE JSON 批次,将其导入到 LPMM 的向量库与知识图中。"
    )
    parser.add_argument(
        "--non-interactive",
        action="store_true",
        help="非交互模式:跳过导入确认提示以及非法文段删除询问,遇到非法文段时直接报错退出。",
    )
    args = parser.parse_args(argv)

    # 检查是否有现有的事件循环
    try:
        loop = asyncio.get_running_loop()
        if loop.is_closed():
            # 如果事件循环已关闭,创建新的
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
    except RuntimeError:
        # 没有运行的事件循环,创建新的
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

    ok: bool = False
    try:
        # 在新的事件循环中运行异步主函数
        ok = loop.run_until_complete(main_async(non_interactive=args.non_interactive))
        print(
            "\n[NOTICE] OpenIE 导入脚本执行完毕。如主程序(聊天 / WebUI)已在运行,"
            "请重启主程序,或在主程序内部调用一次 lpmm_start_up() 以应用最新 LPMM 知识库。"
        )
        print("[NOTICE] 如果不清楚 lpmm_start_up 是什么,直接重启主程序即可。")
    finally:
        # 确保事件循环被正确关闭
        if not loop.is_closed():
            loop.close()
    if not ok:
        # 统一错误码,方便在非交互场景下检测失败
        sys.exit(1)


if __name__ == "__main__":
    main()
@@ -1,248 +0,0 @@
import argparse
import json
import os
import signal
from concurrent.futures import ThreadPoolExecutor, as_completed
from threading import Lock, Event
import sys
import datetime
from typing import Optional

# 添加项目根目录到 sys.path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from rich.progress import (  # rich 进度条
    Progress,
    BarColumn,
    TimeElapsedColumn,
    TimeRemainingColumn,
    TaskProgressColumn,
    MofNCompleteColumn,
    SpinnerColumn,
    TextColumn,
)

from src.common.logger import get_logger
from src.chat.knowledge.ie_process import info_extract_from_str
from src.chat.knowledge.open_ie import OpenIE
from raw_data_preprocessor import RAW_DATA_PATH, load_raw_data
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest

logger = get_logger("LPMM知识库-信息提取")


ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
TEMP_DIR = os.path.join(ROOT_PATH, "temp")
OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie")


def ensure_dirs():
    """确保临时目录和输出目录存在"""
    if not os.path.exists(TEMP_DIR):
        os.makedirs(TEMP_DIR)
        logger.info(f"已创建临时目录: {TEMP_DIR}")
    if not os.path.exists(OPENIE_OUTPUT_DIR):
        os.makedirs(OPENIE_OUTPUT_DIR)
        logger.info(f"已创建输出目录: {OPENIE_OUTPUT_DIR}")
    if not os.path.exists(RAW_DATA_PATH):
        os.makedirs(RAW_DATA_PATH)
        logger.info(f"已创建原始数据目录: {RAW_DATA_PATH}")


# 创建线程安全的锁,用于保护文件操作和共享数据
file_lock = Lock()
open_ie_doc_lock = Lock()

# 创建事件标志,用于控制程序终止
shutdown_event = Event()

lpmm_entity_extract_llm = LLMRequest(
    model_set=model_config.model_task_config.lpmm_entity_extract, request_type="lpmm.entity_extract"
)
lpmm_rdf_build_llm = LLMRequest(model_set=model_config.model_task_config.lpmm_rdf_build, request_type="lpmm.rdf_build")


def process_single_text(pg_hash, raw_data):
    """处理单个文本的函数,用于线程池"""
    temp_file_path = f"{TEMP_DIR}/{pg_hash}.json"

    # 使用文件锁检查和读取缓存文件
    with file_lock:
        if os.path.exists(temp_file_path):
            try:
                # 存在对应的提取结果
                logger.info(f"找到缓存的提取结果:{pg_hash}")
                with open(temp_file_path, "r", encoding="utf-8") as f:
                    return json.load(f), None
            except json.JSONDecodeError:
                # 如果JSON文件损坏,删除它并重新处理
                logger.warning(f"缓存文件损坏,重新处理:{pg_hash}")
                os.remove(temp_file_path)

    entity_list, rdf_triple_list = info_extract_from_str(
        lpmm_entity_extract_llm,
        lpmm_rdf_build_llm,
        raw_data,
    )
    if entity_list is None or rdf_triple_list is None:
        return None, pg_hash
    doc_item = {
        "idx": pg_hash,
        "passage": raw_data,
        "extracted_entities": entity_list,
        "extracted_triples": rdf_triple_list,
    }
    # 保存临时提取结果
    with file_lock:
        try:
            with open(temp_file_path, "w", encoding="utf-8") as f:
                json.dump(doc_item, f, ensure_ascii=False, indent=4)
        except Exception as e:
            logger.error(f"保存缓存文件失败:{pg_hash}, 错误:{e}")
            # 如果保存失败,确保不会留下损坏的文件
            if os.path.exists(temp_file_path):
                os.remove(temp_file_path)
            return None, pg_hash
    return doc_item, None


def signal_handler(_signum, _frame):
    """处理Ctrl+C信号"""
    logger.info("\n接收到中断信号,正在优雅地关闭程序...")
    sys.exit(0)


def _run(non_interactive: bool = False) -> None:  # sourcery skip: comprehension-to-generator, extract-method
    # 设置信号处理器
    signal.signal(signal.SIGINT, signal_handler)
    ensure_dirs()  # 确保目录存在
    # 用户确认提示
    if non_interactive:
        logger.warning("当前处于非交互模式,将跳过费用与时长确认提示,直接开始进行实体提取操作。")
    else:
        print("=== 重要操作确认,请认真阅读以下内容 ===")
        print("实体提取操作将会花费较多api余额和时间,建议在空闲时段执行。")
        print("举例:600万字全剧情,提取选用deepseek v3 0324,消耗约40元,约3小时。")
        print("建议使用硅基流动的非Pro模型")
        print("或者使用可以用赠金抵扣的Pro模型")
        print("请确保账户余额充足,并且在执行前确认无误。")
        confirm = input("确认继续执行?(y/n): ").strip().lower()
        if confirm != "y":
            logger.info("用户取消操作")
            print("操作已取消")
            sys.exit(1)

    # 友好提示:说明“网络错误(可重试)”日志属于正常自动重试行为,避免用户误以为任务失败
    print(
        "\n提示:在提取过程中,如果看到模型出现“网络错误(可重试)”等日志,"
        "表示系统正在自动重试请求,一般不会影响整体导入结果,请耐心等待即可。\n"
    )

    print("\n" + "=" * 40 + "\n")
    logger.info("--------进行信息提取--------\n")

    # 加载原始数据
    logger.info("正在加载原始数据")
    all_sha256_list, all_raw_datas = load_raw_data()

    failed_sha256 = []
    open_ie_doc = []

    workers = global_config.lpmm_knowledge.info_extraction_workers
    with ThreadPoolExecutor(max_workers=workers) as executor:
        future_to_hash = {
            executor.submit(process_single_text, pg_hash, raw_data): pg_hash
            for pg_hash, raw_data in zip(all_sha256_list, all_raw_datas, strict=False)
        }

        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TaskProgressColumn(),
            MofNCompleteColumn(),
            "•",
            TimeElapsedColumn(),
            "<",
            TimeRemainingColumn(),
            transient=False,
        ) as progress:
            task = progress.add_task("正在进行提取:", total=len(future_to_hash))
            try:
                for future in as_completed(future_to_hash):
                    if shutdown_event.is_set():
                        for f in future_to_hash:
                            if not f.done():
                                f.cancel()
                        break

                    doc_item, failed_hash = future.result()
                    if failed_hash:
                        failed_sha256.append(failed_hash)
                        logger.error(f"提取失败:{failed_hash}")
                    elif doc_item:
                        with open_ie_doc_lock:
                            open_ie_doc.append(doc_item)
                        logger.info(f'已处理"{doc_item.get("passage", "")}"')
                    progress.update(task, advance=1)
            except KeyboardInterrupt:
                logger.info("\n接收到中断信号,正在优雅地关闭程序...")
                shutdown_event.set()
                for f in future_to_hash:
                    if not f.done():
                        f.cancel()

    # 合并所有文件的提取结果并保存
    if open_ie_doc:
        sum_phrase_chars = sum([len(e) for chunk in open_ie_doc for e in chunk["extracted_entities"]])
        sum_phrase_words = sum([len(e.split()) for chunk in open_ie_doc for e in chunk["extracted_entities"]])
        num_phrases = sum([len(chunk["extracted_entities"]) for chunk in open_ie_doc])
        openie_obj = OpenIE(
            open_ie_doc,
            round(sum_phrase_chars / num_phrases, 4) if num_phrases else 0,
            round(sum_phrase_words / num_phrases, 4) if num_phrases else 0,
        )
        # 输出文件名格式:MM-DD-HH-SS-openie.json
        now = datetime.datetime.now()
        filename = now.strftime("%m-%d-%H-%S-openie.json")
        output_path = os.path.join(OPENIE_OUTPUT_DIR, filename)
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(
                openie_obj.to_dict() if hasattr(openie_obj, "to_dict") else openie_obj.__dict__,
                f,
                ensure_ascii=False,
                indent=4,
            )
        logger.info(f"信息提取结果已保存到: {output_path}")
    else:
        logger.warning("没有可保存的信息提取结果")

    logger.info("--------信息提取完成--------")
    logger.info(f"提取失败的文段SHA256:{failed_sha256}")


def main(argv: Optional[list[str]] = None) -> None:
    parser = argparse.ArgumentParser(
        description=(
            "LPMM 信息提取脚本:从 data/lpmm_raw_data/*.txt 中读取原始段落,"
            "调用 LLM 提取实体和三元组,并生成 OpenIE JSON 批次文件。"
        )
    )
    parser.add_argument(
        "--non-interactive",
        action="store_true",
        help="非交互模式:跳过费用确认提示,直接开始执行;适用于 CI / 定时任务等场景。",
    )
    args = parser.parse_args(argv)

    _run(non_interactive=args.non_interactive)


if __name__ == "__main__":
    main()
@@ -1,132 +0,0 @@
import argparse
import json
import os
import sys
from pathlib import Path
from typing import List, Tuple

# 确保能导入 src.*
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from src.chat.knowledge.utils.hash import get_sha256
from src.chat.knowledge.embedding_store import EmbeddingManager
from src.chat.knowledge.kg_manager import KGManager
from src.common.logger import get_logger

logger = get_logger("inspect_lpmm_batch")


def load_openie_hashes(path: Path) -> Tuple[List[str], List[str], List[str]]:
    """从 OpenIE JSON 中提取段落 / 实体 / 关系的哈希

    注意:实体既包括 extracted_entities 中的条目,也包括三元组中的主语/宾语,
    以与 KG 构图逻辑保持一致。
    """
    with path.open("r", encoding="utf-8") as f:
        data = json.load(f)

    pg_hashes: List[str] = []
    ent_hashes: List[str] = []
    rel_hashes: List[str] = []

    for doc in data.get("docs", []):
        if not isinstance(doc, dict):
            continue
        idx = doc.get("idx")
        if isinstance(idx, str) and idx.strip():
            pg_hashes.append(idx.strip())

        ents = doc.get("extracted_entities", [])
        if isinstance(ents, list):
            for e in ents:
                if isinstance(e, str):
                    ent_hashes.append(get_sha256(e))

        triples = doc.get("extracted_triples", [])
        if isinstance(triples, list):
            for t in triples:
                if isinstance(t, list) and len(t) == 3:
                    # 主语/宾语作为实体参与构图
                    subj, _, obj = t
                    if isinstance(subj, str):
                        ent_hashes.append(get_sha256(subj))
                    if isinstance(obj, str):
                        ent_hashes.append(get_sha256(obj))
                    rel_hashes.append(get_sha256(str(tuple(t))))

    # 去重但保留顺序
    def unique(seq: List[str]) -> List[str]:
        seen = set()
        return [x for x in seq if not (x in seen or seen.add(x))]

    return unique(pg_hashes), unique(ent_hashes), unique(rel_hashes)


def main() -> None:
    parser = argparse.ArgumentParser(
        description="检查指定 OpenIE 文件对应批次在当前向量库与 KG 中的存在情况(用于验证删除效果)。"
    )
    parser.add_argument("--openie-file", required=True, help="OpenIE 输出 JSON 文件路径")
    args = parser.parse_args()

    openie_path = Path(args.openie_file)
    if not openie_path.exists():
        logger.error(f"OpenIE 文件不存在: {openie_path}")
        sys.exit(1)

    pg_hashes, ent_hashes, rel_hashes = load_openie_hashes(openie_path)
    logger.info(
        f"从 {openie_path.name} 解析到 段落 {len(pg_hashes)} 条,实体 {len(ent_hashes)} 个,关系 {len(rel_hashes)} 条"
    )

    # 加载当前嵌入与 KG
    em = EmbeddingManager()
    kg = KGManager()
    try:
        em.load_from_file()
        kg.load_from_file()
    except Exception as e:
        logger.error(f"加载当前知识库失败: {e}")
        sys.exit(1)

    graph_nodes = set(kg.graph.get_node_list())

    # 检查段落
    pg_keys = [f"paragraph-{h}" for h in pg_hashes]
    pg_in_vec = sum(1 for k in pg_keys if k in em.paragraphs_embedding_store.store)
    pg_in_kg = sum(1 for k in pg_keys if k in graph_nodes)

    # 检查实体
    ent_keys = [f"entity-{h}" for h in ent_hashes]
    ent_in_vec = sum(1 for k in ent_keys if k in em.entities_embedding_store.store)
    ent_in_kg = sum(1 for k in ent_keys if k in graph_nodes)

    # 检查关系(只针对向量库)
    rel_keys = [f"relation-{h}" for h in rel_hashes]
    rel_in_vec = sum(1 for k in rel_keys if k in em.relation_embedding_store.store)

    print("==== 批次存在情况(删除前/后对比用) ====")
    print(f"段落: 总计 {len(pg_keys)}, 向量库剩余 {pg_in_vec}, KG 中剩余 {pg_in_kg}")
    print(f"实体: 总计 {len(ent_keys)}, 向量库剩余 {ent_in_vec}, KG 中剩余 {ent_in_kg}")
    print(f"关系: 总计 {len(rel_keys)}, 向量库剩余 {rel_in_vec}")

    # 打印少量仍存在的样例,便于检查内容是否正常
    sample_pg = [k for k in pg_keys if k in graph_nodes][:3]
    if sample_pg:
        print("\n仍在 KG 中的段落节点示例:")
        for k in sample_pg:
            nd = kg.graph[k]
            content = nd["content"] if "content" in nd else k
            print(f"- {k}: {content[:80]}")

    sample_ent = [k for k in ent_keys if k in graph_nodes][:3]
    if sample_ent:
        print("\n仍在 KG 中的实体节点示例:")
        for k in sample_ent:
            nd = kg.graph[k]
            content = nd["content"] if "content" in nd else k
            print(f"- {k}: {content[:80]}")


if __name__ == "__main__":
    main()
@@ -1,68 +0,0 @@
import os
import sys

# 保证可以导入 src.*
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from src.chat.knowledge.embedding_store import EmbeddingManager
from src.chat.knowledge.kg_manager import KGManager
from src.common.logger import get_logger

logger = get_logger("inspect_lpmm_global")


def main() -> None:
    """检查当前整库(所有批次)的向量与 KG 状态,用于观察删除对剩余数据的影响。"""
    em = EmbeddingManager()
    kg = KGManager()

    try:
        em.load_from_file()
        kg.load_from_file()
    except Exception as e:
        logger.error(f"加载当前知识库失败: {e}")
        sys.exit(1)

    # 向量库统计
    para_cnt = len(em.paragraphs_embedding_store.store)
    ent_cnt_vec = len(em.entities_embedding_store.store)
    rel_cnt_vec = len(em.relation_embedding_store.store)

    # KG 统计
    nodes = kg.graph.get_node_list()
    edges = kg.graph.get_edge_list()

    para_nodes = [n for n in nodes if n.startswith("paragraph-")]
    ent_nodes = [n for n in nodes if n.startswith("entity-")]

    print("==== 向量库统计 ====")
    print(f"段落向量条数: {para_cnt}")
    print(f"实体向量条数: {ent_cnt_vec}")
    print(f"关系向量条数: {rel_cnt_vec}")

    print("\n==== KG 图统计 ====")
    print(f"节点总数: {len(nodes)}")
    print(f"边总数: {len(edges)}")
    print(f"段落节点数: {len(para_nodes)}")
    print(f"实体节点数: {len(ent_nodes)}")

    # ent_appear_cnt 状态
    ent_cnt_meta = len(kg.ent_appear_cnt)
    print(f"\n实体计数表条目数: {ent_cnt_meta}")

    # 抽样查看剩余段落/实体内容
    print("\n==== 剩余段落示例(最多 3 条) ====")
    for nid in para_nodes[:3]:
        nd = kg.graph[nid]
        content = nd["content"] if "content" in nd else nid
        print(f"- {nid}: {content[:80]}")

    print("\n==== 剩余实体示例(最多 5 条) ====")
    for nid in ent_nodes[:5]:
        nd = kg.graph[nid]
        content = nd["content"] if "content" in nd else nid
        print(f"- {nid}: {content[:80]}")


if __name__ == "__main__":
    main()
@@ -1,278 +0,0 @@
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
|
||||
# 尽量统一控制台编码为 utf-8,避免中文输出报错
|
||||
try:
|
||||
if hasattr(sys.stdout, "reconfigure"):
|
||||
sys.stdout.reconfigure(encoding="utf-8")
|
||||
if hasattr(sys.stderr, "reconfigure"):
|
||||
sys.stderr.reconfigure(encoding="utf-8")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 确保项目根目录在 sys.path 中,以便导入 src.*
|
||||
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
if PROJECT_ROOT not in sys.path:
|
||||
sys.path.append(PROJECT_ROOT)
|
||||
|
||||
try:
|
||||
# 显式从 src.chat.knowledge.lpmm_ops 导入单例对象
|
||||
from src.chat.knowledge.lpmm_ops import lpmm_ops
|
||||
from src.common.logger import get_logger
|
||||
from src.memory_system.retrieval_tools.query_lpmm_knowledge import query_lpmm_knowledge
|
||||
from src.chat.knowledge import lpmm_start_up
|
||||
from src.config.config import global_config
|
||||
except ImportError as e:
|
||||
print(f"导入失败,请确保在项目根目录下运行脚本: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
logger = get_logger("lpmm_interactive_manager")
|
||||
|
||||
|
||||
async def interactive_add():
|
||||
"""交互式导入知识"""
|
||||
print("\n" + "=" * 40)
|
||||
print(" --- 📥 导入知识 (Add) ---")
|
||||
print("=" * 40)
|
||||
print("说明:请输入要导入的文本内容。")
|
||||
print(" - 支持多段落,段落间请保留空行。")
|
||||
print(" - 输入完成后,在新起的一行输入 'EOF' 并回车结束输入。")
|
||||
print("-" * 40)
|
||||
|
||||
lines = []
|
||||
while True:
|
||||
try:
|
||||
line = input()
|
||||
if line.strip().upper() == "EOF":
|
||||
break
|
||||
lines.append(line)
|
||||
except EOFError:
|
||||
break
|
||||
|
||||
text = "\n".join(lines).strip()
|
||||
if not text:
|
||||
print("\n[!] 内容为空,操作已取消。")
|
||||
return
|
||||
|
||||
print("\n[进度] 正在调用 LPMM 接口进行信息抽取与向量化,请稍候...")
|
||||
try:
|
||||
# 使用 lpmm_ops.py 中的接口
|
||||
result = await lpmm_ops.add_content(text)
|
||||
|
||||
if result["status"] == "success":
|
||||
print(f"\n[√] 成功:{result['message']}")
|
||||
print(f" 实际新增段落数: {result.get('count', 0)}")
|
||||
else:
|
||||
print(f"\n[×] 失败:{result['message']}")
|
||||
except Exception as e:
|
||||
print(f"\n[×] 发生异常: {e}")
|
||||
logger.error(f"add_content 异常: {e}", exc_info=True)
|
||||
|
||||
|
||||
async def interactive_delete():
|
||||
"""交互式删除知识"""
|
||||
print("\n" + "=" * 40)
|
||||
print(" --- 🗑️ 删除知识 (Delete) ---")
|
||||
print("=" * 40)
|
||||
print("删除模式:")
|
||||
print(" 1. 关键词模糊匹配(删除包含关键词的所有段落)")
|
||||
print(" 2. 完整文段匹配(删除完全匹配的段落)")
|
||||
print("-" * 40)
|
||||
|
||||
mode = input("请选择删除模式 (1/2): ").strip()
|
||||
exact_match = False
|
||||
|
||||
if mode == "2":
|
||||
exact_match = True
|
||||
print("\n[完整文段匹配模式]")
|
||||
print("说明:请输入要删除的完整文段内容(必须完全一致)。")
|
||||
print(" - 支持多行输入,输入完成后在新起的一行输入 'EOF' 并回车。")
|
||||
print("-" * 40)
|
||||
lines = []
|
||||
while True:
|
||||
try:
|
||||
line = input()
|
||||
if line.strip().upper() == "EOF":
|
||||
break
|
||||
lines.append(line)
|
||||
except EOFError:
|
||||
break
|
||||
keyword = "\n".join(lines).strip()
|
||||
else:
|
||||
if mode != "1":
|
||||
print("\n[!] 无效选择,默认使用关键词模糊匹配模式。")
|
||||
print("\n[关键词模糊匹配模式]")
|
||||
keyword = input("请输入匹配关键词: ").strip()
|
||||
|
||||
if not keyword:
|
||||
print("\n[!] 输入为空,操作已取消。")
|
||||
return
|
||||
|
||||
print("-" * 40)
|
||||
confirm = (
|
||||
input(f"危险确认:确定要删除所有匹配 '{keyword[:50]}{'...' if len(keyword) > 50 else ''}' 的知识吗?(y/N): ")
|
||||
.strip()
|
||||
.lower()
|
||||
)
|
||||
if confirm != "y":
|
||||
print("\n[!] 已取消删除操作。")
|
||||
return
|
||||
|
||||
print("\n[进度] 正在执行删除并更新索引...")
|
||||
try:
|
||||
# 使用 lpmm_ops.py 中的接口
|
||||
result = await lpmm_ops.delete(keyword, exact_match=exact_match)
|
||||
|
||||
if result["status"] == "success":
|
||||
print(f"\n[√] 成功:{result['message']}")
|
||||
print(f" 删除条数: {result.get('deleted_count', 0)}")
|
||||
elif result["status"] == "info":
|
||||
print(f"\n[i] 提示:{result['message']}")
|
||||
else:
|
||||
print(f"\n[×] 失败:{result['message']}")
|
||||
except Exception as e:
|
||||
print(f"\n[×] 发生异常: {e}")
|
||||
logger.error(f"delete 异常: {e}", exc_info=True)
|
||||
|
||||
|
||||
async def interactive_clear():
|
||||
"""交互式清空知识库"""
|
||||
print("\n" + "=" * 40)
|
||||
print(" --- ⚠️ 清空知识库 (Clear All) ---")
|
||||
print("=" * 40)
|
||||
print("警告:此操作将删除LPMM知识库中的所有内容!")
|
||||
print(" - 所有段落向量")
|
||||
print(" - 所有实体向量")
|
||||
print(" - 所有关系向量")
|
||||
print(" - 整个知识图谱")
|
||||
print(" - 此操作不可恢复!")
|
||||
print("-" * 40)
|
||||
|
||||
# 双重确认
|
||||
confirm1 = input("⚠️ 第一次确认:确定要清空整个知识库吗?(输入 'YES' 继续): ").strip()
|
||||
if confirm1 != "YES":
|
||||
print("\n[!] 已取消清空操作。")
|
||||
return
|
||||
|
||||
print("\n" + "=" * 40)
|
||||
confirm2 = input("⚠️ 第二次确认:此操作不可恢复,请再次输入 'CLEAR' 确认: ").strip()
|
||||
if confirm2 != "CLEAR":
|
||||
print("\n[!] 已取消清空操作。")
|
||||
return
|
||||
|
||||
print("\n[进度] 正在清空知识库...")
|
||||
try:
|
||||
# 使用 lpmm_ops.py 中的接口
|
||||
result = await lpmm_ops.clear_all()
|
||||
|
||||
if result["status"] == "success":
|
||||
print(f"\n[√] 成功:{result['message']}")
|
||||
stats = result.get("stats", {})
|
||||
before = stats.get("before", {})
|
||||
after = stats.get("after", {})
|
||||
print("\n[统计信息]")
|
||||
print(
|
||||
f" 清空前: 段落={before.get('paragraphs', 0)}, 实体={before.get('entities', 0)}, "
|
||||
f"关系={before.get('relations', 0)}, KG节点={before.get('kg_nodes', 0)}, KG边={before.get('kg_edges', 0)}"
|
||||
)
|
||||
print(
|
||||
f" 清空后: 段落={after.get('paragraphs', 0)}, 实体={after.get('entities', 0)}, "
|
||||
f"关系={after.get('relations', 0)}, KG节点={after.get('kg_nodes', 0)}, KG边={after.get('kg_edges', 0)}"
|
||||
)
|
||||
else:
|
||||
print(f"\n[×] 失败:{result['message']}")
|
||||
except Exception as e:
|
||||
print(f"\n[×] 发生异常: {e}")
|
||||
logger.error(f"clear_all 异常: {e}", exc_info=True)
|
||||
|
||||
|
||||
async def interactive_search():
    """Interactively query the knowledge base."""
    print("\n" + "=" * 40)
    print(" --- 🔍 查询知识 (Search) ---")
    print("=" * 40)
    print("说明:输入查询问题或关键词,系统会返回相关的知识段落。")
    print("-" * 40)

    # Make sure LPMM is enabled and initialized
    if not global_config.lpmm_knowledge.enable:
        print("\n[!] 警告:LPMM 知识库在配置中未启用。")
        return

    try:
        lpmm_start_up()
    except Exception as e:
        print(f"\n[!] LPMM 初始化失败: {e}")
        logger.error(f"LPMM 初始化失败: {e}", exc_info=True)
        return

    query = input("请输入查询问题或关键词: ").strip()

    if not query:
        print("\n[!] 查询内容为空,操作已取消。")
        return

    # Ask how many results to return
    print("-" * 40)
    limit_str = input("希望返回的相关知识条数(默认3,直接回车使用默认值): ").strip()
    try:
        limit = int(limit_str) if limit_str else 3
        limit = max(1, min(limit, 20))  # clamp to the range 1-20
    except ValueError:
        limit = 3
        print("[!] 输入无效,使用默认值 3。")

    print("\n[进度] 正在查询知识库...")
    try:
        result = await query_lpmm_knowledge(query, limit=limit)

        print("\n" + "=" * 60)
        print("[查询结果]")
        print("=" * 60)
        print(result)
        print("=" * 60)
    except Exception as e:
        print(f"\n[×] 查询失败: {e}")
        logger.error(f"查询异常: {e}", exc_info=True)


async def main():
    """Main interactive loop."""
    while True:
        print("\n" + "╔" + "═" * 38 + "╗")
        print("║ LPMM 知识库交互管理工具 ║")
        print("╠" + "═" * 38 + "╣")
        print("║ 1. 导入知识 (Add Content) ║")
        print("║ 2. 删除知识 (Delete Content) ║")
        print("║ 3. 查询知识 (Search Content) ║")
        print("║ 4. 清空知识库 (Clear All) ⚠️ ║")
        print("║ 0. 退出 (Exit) ║")
        print("╚" + "═" * 38 + "╝")

        choice = input("请选择操作编号: ").strip()

        if choice == "1":
            await interactive_add()
        elif choice == "2":
            await interactive_delete()
        elif choice == "3":
            await interactive_search()
        elif choice == "4":
            await interactive_clear()
        elif choice in ("0", "q", "Q", "quit", "exit"):
            print("\n已退出工具。")
            break
        else:
            print("\n[!] 无效的选择,请输入 0, 1, 2, 3 或 4。")


if __name__ == "__main__":
    try:
        # Run the main loop
        asyncio.run(main())
    except KeyboardInterrupt:
        print("\n\n[!] 用户中断程序 (Ctrl+C)。")
    except Exception as e:
        print(f"\n[!] 程序运行出错: {e}")
        logger.error(f"Main loop 异常: {e}", exc_info=True)
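The clear flow above gates the irreversible wipe behind two exact-match confirmations ("YES", then "CLEAR"). A minimal, testable sketch of that double-confirmation pattern — the helper name and the injectable `ask` parameter are illustrative, not part of this codebase:

```python
def confirm_destructive(steps, ask=input):
    """Return True only when every prompt is answered with its exact token.

    steps: list of (prompt, required_token) pairs, checked in order.
    ask: input-like callable, injectable so the flow can be unit-tested.
    """
    for prompt, token in steps:
        if ask(prompt).strip() != token:
            # Any mismatch aborts immediately; later prompts are never shown
            return False
    return True


# Example: simulate the clear-all flow without touching the real input()
answers = iter(["YES", "CLEAR"])
confirmed = confirm_destructive(
    [("第一次确认 (YES): ", "YES"), ("第二次确认 (CLEAR): ", "CLEAR")],
    ask=lambda _prompt: next(answers),
)
```

Requiring two different tokens (rather than "y" twice) makes it much harder to confirm by reflex.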
@@ -1,512 +0,0 @@
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional, List
|
||||
|
||||
# 尽量统一控制台编码为 utf-8,避免中文输出报错
|
||||
try:
|
||||
if hasattr(sys.stdout, "reconfigure"):
|
||||
sys.stdout.reconfigure(encoding="utf-8")
|
||||
if hasattr(sys.stderr, "reconfigure"):
|
||||
sys.stderr.reconfigure(encoding="utf-8")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 确保能导入 src.* 以及同目录脚本
|
||||
CURRENT_DIR = os.path.dirname(__file__)
|
||||
PROJECT_ROOT = os.path.abspath(os.path.join(CURRENT_DIR, ".."))
|
||||
if PROJECT_ROOT not in sys.path:
|
||||
sys.path.append(PROJECT_ROOT)
|
||||
|
||||
from src.common.logger import get_logger # type: ignore # noqa: E402
|
||||
from src.config.config import global_config, model_config # type: ignore # noqa: E402
|
||||
|
||||
# 引入各功能脚本的入口函数
|
||||
from import_openie import main as import_openie_main # type: ignore # noqa: E402
|
||||
from info_extraction import main as info_extraction_main # type: ignore # noqa: E402
|
||||
from delete_lpmm_items import main as delete_lpmm_items_main # type: ignore # noqa: E402
|
||||
from inspect_lpmm_batch import main as inspect_lpmm_batch_main # type: ignore # noqa: E402
|
||||
from inspect_lpmm_global import main as inspect_lpmm_global_main # type: ignore # noqa: E402
|
||||
from refresh_lpmm_knowledge import main as refresh_lpmm_knowledge_main # type: ignore # noqa: E402
|
||||
from test_lpmm_retrieval import main as test_lpmm_retrieval_main # type: ignore # noqa: E402
|
||||
from raw_data_preprocessor import load_raw_data # type: ignore # noqa: E402
|
||||
|
||||
|
||||
logger = get_logger("lpmm_manager")
|
||||
|
||||
|
||||
ACTION_INFO = {
    "prepare_raw": "预处理 data/lpmm_raw_data/*.txt,按空行切分为段落并做去重统计",
    "info_extract": "原始 txt -> OpenIE 信息抽取(调用 info_extraction.py)",
    "import_openie": "导入 OpenIE 批次到向量库与知识图(调用 import_openie.py)",
    "delete": "删除/回滚知识(调用 delete_lpmm_items.py)",
    "batch_inspect": "检查指定 OpenIE 批次在当前库中的存在情况(调用 inspect_lpmm_batch.py)",
    "global_inspect": "查看当前整库向量与 KG 状态(调用 inspect_lpmm_global.py)",
    "refresh": "刷新 LPMM 磁盘数据到内存(调用 refresh_lpmm_knowledge.py)",
    "test": "运行 LPMM 检索效果回归测试(调用 test_lpmm_retrieval.py)",
    "embedding_helper": "嵌入模型迁移辅助:查看当前嵌入模型/维度并归档 embedding_model_test.json",
    "full_import": "一键执行:信息抽取 -> 导入 OpenIE -> 刷新",
}


def _with_overridden_argv(extra_args: List[str], target_main) -> None:
    """Temporarily override sys.argv to pass arguments through without modifying the sub-script."""
    old_argv = list(sys.argv)
    try:
        # Element 0 is the "program name"; the rest are the actual arguments.
        # No placeholder such as delete_lpmm_items.py is inserted here, so that
        # argparse does not mistake it for a positional argument.
        sys.argv = [old_argv[0]] + extra_args
        target_main()
    finally:
        sys.argv = old_argv


def _check_before_info_extract(non_interactive: bool = False) -> bool:
    """Lightweight pre-check before information extraction."""
    raw_dir = Path(PROJECT_ROOT) / "data" / "lpmm_raw_data"
    txt_files = list(raw_dir.glob("*.txt"))
    if not txt_files:
        msg = f"[WARN] 未在 {raw_dir} 下找到任何 .txt 原始语料文件,info_extraction 可能立即退出或无数据可处理。"
        print(msg)
        if non_interactive:
            logger.error("非交互模式下要求原始语料目录中已存在可用的 .txt 文件,请先准备好数据再重试。")
            return False
        cont = input("仍然继续执行信息提取吗?(y/n): ").strip().lower()
        return cont == "y"
    return True


def _check_before_import_openie(non_interactive: bool = False) -> bool:
    """Lightweight pre-check before importing OpenIE batches."""
    openie_dir = Path(PROJECT_ROOT) / "data" / "openie"
    json_files = list(openie_dir.glob("*.json"))
    if not json_files:
        msg = f"[WARN] 未在 {openie_dir} 下找到任何 OpenIE JSON 文件,import_openie 可能会因为找不到批次而失败。"
        print(msg)
        if non_interactive:
            logger.error("非交互模式下要求 data/openie 目录中已存在可用的 OpenIE JSON 文件,请先执行信息提取脚本。")
            return False
        cont = input("仍然继续执行导入吗?(y/n): ").strip().lower()
        return cont == "y"
    return True


def _warn_if_lpmm_disabled() -> None:
    """Remind the user about the lpmm_knowledge.enable setting before certain operations."""
    try:
        if not getattr(global_config.lpmm_knowledge, "enable", False):
            print("[WARN] 当前配置 lpmm_knowledge.enable = false,刷新或检索测试可能无法在聊天侧真正启用 LPMM。")
    except Exception:
        # If the config is broken, skip the hint instead of blocking the main flow
        pass


def run_action(action: str, extra_args: Optional[List[str]] = None) -> None:
    """Dispatch to the matching script by action name.

    Sub-arguments are not re-parsed here; each script's main() is called
    directly so that it keeps its own interactive/argument behavior.
    """
    logger.info(f"开始执行操作: {action}")

    extra_args = extra_args or []

    try:
        if action == "prepare_raw":
            logger.info("开始预处理原始语料 (data/lpmm_raw_data/*.txt)...")
            sha_list, raw_data = load_raw_data()
            print(f"\n[PREPARE_RAW] 完成原始语料预处理:共 {len(raw_data)} 条段落,去重后哈希数 {len(sha_list)}。")
        elif action == "info_extract":
            if not _check_before_info_extract("--non-interactive" in extra_args):
                print("已根据用户选择,取消执行信息提取。")
                return
            _with_overridden_argv(extra_args, info_extraction_main)
        elif action == "import_openie":
            if not _check_before_import_openie("--non-interactive" in extra_args):
                print("已根据用户选择,取消执行导入。")
                return
            _with_overridden_argv(extra_args, import_openie_main)
        elif action == "delete":
            _with_overridden_argv(extra_args, delete_lpmm_items_main)
        elif action == "batch_inspect":
            _with_overridden_argv(extra_args, inspect_lpmm_batch_main)
        elif action == "global_inspect":
            _with_overridden_argv(extra_args, inspect_lpmm_global_main)
        elif action == "refresh":
            _warn_if_lpmm_disabled()
            _with_overridden_argv(extra_args, refresh_lpmm_knowledge_main)
        elif action == "test":
            _warn_if_lpmm_disabled()
            _with_overridden_argv(extra_args, test_lpmm_retrieval_main)
        elif action == "embedding_helper":
            # Embedding-model migration helper: show the current embedding
            # model/dimension and archive embedding_model_test.json
            _run_embedding_helper()
        elif action == "full_import":
            # One-shot pipeline: preprocess raw corpus -> extract -> import -> refresh
            logger.info("开始 full_import:预处理原始语料 -> 信息抽取 -> 导入 -> 刷新")
            sha_list, raw_data = load_raw_data()
            print(f"\n[FULL_IMPORT] 原始语料预处理完成:共 {len(raw_data)} 条段落,去重后哈希数 {len(sha_list)}。")
            non_interactive = "--non-interactive" in extra_args
            if not _check_before_info_extract(non_interactive):
                print("已根据用户选择,取消 full_import(信息提取阶段被取消)。")
                return
            # Reuse the same argument pass-through as the standalone info_extract
            # step so that flags such as --non-interactive take effect
            _with_overridden_argv(extra_args, info_extraction_main)
            if not _check_before_import_openie(non_interactive):
                print("已根据用户选择,取消 full_import(导入阶段被取消)。")
                return
            _with_overridden_argv(extra_args, import_openie_main)
            _warn_if_lpmm_disabled()
            _with_overridden_argv(extra_args, refresh_lpmm_knowledge_main)
        else:
            logger.error(f"未知操作: {action}")
    except KeyboardInterrupt:
        logger.info("用户中断当前操作(Ctrl+C)")
    except SystemExit:
        # The sub-scripts call sys.exit liberally; just let it propagate
        raise
    except Exception as exc:  # pragma: no cover - defensive fallback
        logger.error(f"执行操作 {action} 时发生未捕获异常: {exc}")
        raise


def print_menu() -> None:
    print("\n===== LPMM 管理菜单 =====")
    for idx, key in enumerate(
        [
            "prepare_raw",
            "info_extract",
            "import_openie",
            "delete",
            "batch_inspect",
            "global_inspect",
            "refresh",
            "test",
            "embedding_helper",
            "full_import",
        ],
        start=1,
    ):
        desc = ACTION_INFO.get(key, "")
        print(f"{idx}. {key:14s} - {desc}")
    print("0. 退出")
    print("=========================")


def interactive_loop() -> None:
    """Interactive menu mode."""
    key_order = [
        "prepare_raw",
        "info_extract",
        "import_openie",
        "delete",
        "batch_inspect",
        "global_inspect",
        "refresh",
        "test",
        "embedding_helper",
        "full_import",
    ]

    while True:
        print_menu()
        choice = input("请输入选项编号(0-10):").strip()

        if choice in ("0", "q", "Q", "quit", "exit"):
            print("已退出 LPMM 管理器。")
            return

        try:
            idx = int(choice)
        except ValueError:
            print("输入无效,请输入 0-10 之间的数字。")
            continue

        if not (1 <= idx <= len(key_order)):
            print("输入编号超出范围,请重新输入。")
            continue

        action = key_order[idx - 1]
        print(f"\n你选择了: {action} - {ACTION_INFO.get(action, '')}")
        confirm = input("确认执行该操作?(y/n): ").strip().lower()
        if confirm != "y":
            print("已取消当前操作。\n")
            continue

        # Ask a few questions up front to pre-fill the common arguments of each script
        extra_args: List[str] = []
        if action == "delete":
            extra_args = _interactive_build_delete_args()
        elif action == "batch_inspect":
            extra_args = _interactive_build_batch_inspect_args()
        elif action == "test":
            extra_args = _interactive_build_test_args()
        else:
            extra_args = []

        run_action(action, extra_args=extra_args)
        print("\n当前操作已结束,回到主菜单。\n")


def _interactive_choose_openie_file(prompt: str) -> Optional[str]:
    """List the JSON files under data/openie and return the path the user chooses."""
    openie_dir = Path(PROJECT_ROOT) / "data" / "openie"
    files = sorted(openie_dir.glob("*.json"))
    if not files:
        print(f"[WARN] 在 {openie_dir} 下没有找到任何 OpenIE JSON 文件。")
        return input(prompt).strip() or None

    print("\n可选的 OpenIE 批次文件:")
    for i, f in enumerate(files, start=1):
        print(f"{i}. {f.name}")
    print("0. 手动输入完整路径")

    while True:
        choice = input("请选择文件编号:").strip()
        if choice == "0":
            manual = input(prompt).strip()
            return manual or None
        try:
            idx = int(choice)
        except ValueError:
            print("请输入合法的编号。")
            continue
        if 1 <= idx <= len(files):
            return str(files[idx - 1])
        print("编号超出范围,请重试。")


def _interactive_build_delete_args() -> List[str]:
    """Pre-build the common arguments for delete_lpmm_items to reduce follow-up prompts."""
    print(
        "\n[DELETE] 请选择删除方式:\n"
        "1. 按哈希文件删除 (--hash-file)\n"
        "2. 按 OpenIE 批次删除 (--openie-file)\n"
        "3. 按原始语料文件 + 段落索引删除 (--raw-file + --raw-index)\n"
        "4. 按关键字搜索现有段落 (--search-text)\n"
        "回车跳过,由子脚本自行交互。"
    )
    mode = input("输入选项编号(1-4,或回车跳过):").strip()
    args: List[str] = []

    if mode == "1":
        path = input("请输入哈希文件路径(每行一个 hash):").strip()
        if path:
            args += ["--hash-file", path]
    elif mode == "2":
        path = _interactive_choose_openie_file("请输入 OpenIE JSON 文件路径:")
        if path:
            args += ["--openie-file", path]
    elif mode == "3":
        raw_file = input("请输入原始语料 txt 文件路径:").strip()
        raw_index = input("请输入要删除的段落索引(如 1,3):").strip()
        if raw_file and raw_index:
            args += ["--raw-file", raw_file, "--raw-index", raw_index]
    elif mode == "4":
        text = input("请输入用于搜索的关键字(出现在段落原文中):").strip()
        if text:
            args += ["--search-text", text]
    else:
        # Nothing chosen: leave everything to the sub-script's own interaction
        return []

    # Then ask about the safety-related boolean options
    print(
        "\n[DELETE] 接下来是一些安全相关选项的说明:\n"
        "- 删除实体向量/节点:会一并清理与这些段落关联的实体节点及其向量;\n"
        "- 删除关系向量:在上面的基础上,额外清理关系向量(一般与删除实体一同使用);\n"
        "- 删除孤立实体节点:删除后若实体不再连接任何段落,将其从图中移除,避免残留孤点;\n"
        "- dry-run:只预览将要删除的内容,不真正修改任何数据;\n"
        "- 跳过交互确认(--yes):直接执行删除操作,适合脚本化或已充分确认的场景;\n"
        "- 单次最大删除节点数上限:防止一次性删除规模过大,起到误操作保护作用;\n"
        "- 一般情况下建议同时删除实体向量/节点/关系向量/节点,以确保知识图谱的完整性。"
    )

    # Quick option: clean up all related entities/relations the recommended way
    quick_all = (
        input("是否使用推荐策略:同时删除关联的实体向量/节点、关系向量,并清理孤立实体?(Y/n): ").strip().lower()
    )
    if quick_all in ("", "y", "yes"):
        args.extend(["--delete-entities", "--delete-relations", "--remove-orphan-entities"])
    else:
        # Only ask item by item when the quick option was declined
        if input("是否同时删除实体向量/节点?(y/N): ").strip().lower() == "y":
            args.append("--delete-entities")
        if input("是否同时删除关系向量?(y/N): ").strip().lower() == "y":
            args.append("--delete-relations")

        if input("是否删除孤立实体节点?(y/N): ").strip().lower() == "y":
            args.append("--remove-orphan-entities")

    if input("是否以 dry-run 预览而不真正删除?(y/N): ").strip().lower() == "y":
        args.append("--dry-run")
    else:
        if input("是否跳过交互确认直接删除?(默认否,请谨慎) (y/N): ").strip().lower() == "y":
            args.append("--yes")

    max_nodes = input("单次最大删除节点数上限(回车使用默认 2000):").strip()
    if max_nodes:
        args += ["--max-delete-nodes", max_nodes]

    return args


def _interactive_build_batch_inspect_args() -> List[str]:
    """Build the --openie-file argument for inspect_lpmm_batch."""
    path = _interactive_choose_openie_file("请输入要检查的 OpenIE JSON 文件路径(回车跳过,由子脚本自行交互):")
    if not path:
        return []
    return ["--openie-file", path]


def _interactive_build_test_args() -> List[str]:
    """Build custom test-case arguments for test_lpmm_retrieval."""
    print("\n[TEST] 你可以:\n- 直接回车使用内置的默认测试用例;\n- 或者输入一条自定义问题,并指定期望命中的关键字。")
    query = input("请输入自定义测试问题(回车则使用默认用例):").strip()
    if not query:
        return []

    expect = input("请输入期望命中的关键字(可选,多项用逗号分隔):").strip()
    args: List[str] = ["--query", query]
    if expect:
        for kw in expect.split(","):
            kw = kw.strip()
            if kw:
                args.extend(["--expect-keyword", kw])
    return args


def _run_embedding_helper() -> None:
    """Embedding-model migration helper: show the current config and safely archive embedding_model_test.json."""
    from src.chat.knowledge.embedding_store import EMBEDDING_TEST_FILE  # type: ignore

    # 1. Read the embedding dimension and model info from the current config
    current_dim = getattr(getattr(global_config, "lpmm_knowledge", None), "embedding_dimension", None)
    embed_task = getattr(model_config.model_task_config, "embedding", None)
    model_ids: List[str] = []
    if embed_task is not None:
        model_ids = getattr(embed_task, "model_list", []) or []
    primary_model = model_ids[0] if model_ids else "unknown"
    safe_model_name = re.sub(r"[^0-9A-Za-z_.-]+", "_", primary_model) or "unknown"

    print("\n===== 嵌入模型迁移辅助 (embedding_helper) =====")
    print(f"- 当前嵌入模型标识(model_task_config.embedding.model_list[0]): {primary_model}")
    print(f"- 当前配置中的嵌入维度 (lpmm_knowledge.embedding_dimension): {current_dim}")
    print(f"- 测试文件路径: {EMBEDDING_TEST_FILE}")

    new_dim = input("\n如果你计划更换嵌入模型,请在此输入“新的嵌入维度”(仅用于记录与提示,回车则跳过):").strip()
    if new_dim and not new_dim.isdigit():
        print("输入的维度不是纯数字,已取消操作。")
        return

    print(
        "\n[重要提示]\n"
        "- 修改嵌入模型或维度会导致当前磁盘中的旧知识库(data/embedding 下的向量)与新模型不兼容;\n"
        "- 这通常意味着你需要清空旧的向量/图数据,并重新执行 LPMM 导入流水线;\n"
        "- 请仅在你**确定要切换嵌入模型/维度**时再继续。\n"
    )
    confirm = input("是否已充分评估风险,并准备切换嵌入模型/维度?(y/N): ").strip().lower()
    if confirm != "y":
        print("已根据你的选择取消嵌入模型迁移辅助操作。")
        return

    print(
        "\n接下来请手动完成以下操作(脚本不会自动修改配置或删除知识库):\n"
        f"1. 在配置文件中,将 lpmm_knowledge.embedding_dimension 从 {current_dim} 修改为你计划使用的新维度"
        + (f"(例如 {new_dim})" if new_dim else "")  # illustrative only
        + ";\n"
        "2. 根据需要,清空 data/embedding 与相关 KG 数据(data/rag 等),然后重新执行导入流水线;\n"
        "3. 本脚本将帮助你归档当前的 embedding_model_test.json,避免旧测试文件干扰新模型的校验。\n"
    )

    # 2. Archive embedding_model_test.json
    test_path = Path(EMBEDDING_TEST_FILE)
    if not test_path.exists():
        print(f"\n[INFO] 未在 {test_path} 发现 embedding_model_test.json,无需归档。")
        return

    ts = datetime.now().strftime("%Y%m%d-%H%M%S")
    archive_name = f"embedding_model_test-{safe_model_name}-{ts}.json"
    archive_path = test_path.with_name(archive_name)

    # On the off chance the name collides, append a numeric suffix to avoid overwriting
    suffix_id = 1
    while archive_path.exists():
        archive_name = f"embedding_model_test-{safe_model_name}-{ts}-{suffix_id}.json"
        archive_path = test_path.with_name(archive_name)
        suffix_id += 1

    try:
        test_path.rename(archive_path)
    except Exception as exc:  # pragma: no cover - defensive fallback
        logger.error(f"归档 embedding_model_test.json 失败: {exc}")
        print("[ERROR] 归档 embedding_model_test.json 失败,请检查文件权限与路径。错误详情已写入日志。")
        return

    print(
        f"\n[OK] 已将 {test_path.name} 重命名为 {archive_path.name}。\n"
        f"- 归档位置: {archive_path}\n"
        "- 之后再次运行涉及嵌入模型的一致性校验时,将会基于当前配置与新模型生成新的测试文件。\n"
        "- 在完成配置修改与知识库重导入前,请不要手动再创建名为 embedding_model_test.json 的文件。"
    )


def parse_args(argv: Optional[list[str]] = None) -> tuple[argparse.Namespace, List[str]]:
    parser = argparse.ArgumentParser(
        description=(
            "LPMM 管理脚本:集中入口管理 LPMM 的导入 / 删除 / 自检 / 刷新 / 测试等功能。\n"
            "可以通过 --interactive 进入菜单模式,也可以使用 --action 直接执行单个操作。"
        )
    )
    parser.add_argument(
        "-i",
        "--interactive",
        action="store_true",
        help="进入交互式菜单模式(推荐给手动运维使用)",
    )
    parser.add_argument(
        "-a",
        "--action",
        choices=list(ACTION_INFO.keys()),
        help="直接执行指定操作(非交互模式)",
    )
    parser.add_argument(
        "--non-interactive",
        action="store_true",
        help=(
            "启用非交互模式:lpmm_manager 自身不会再通过 input() 询问是否继续前置检查;"
            "并会将 --non-interactive 透传给子脚本,以避免子脚本中的交互式确认。"
        ),
    )
    # Sub-script arguments may follow the manager's own arguments, e.g.:
    # python lpmm_manager.py -a delete -- --hash-file xxx --yes
    args, unknown = parser.parse_known_args(argv)
    return args, unknown


def main(argv: Optional[list[str]] = None) -> None:
    args, extra_args = parse_args(argv)

    # --non-interactive rules out the interactive menu
    if args.non_interactive and args.interactive:
        logger.error("不能同时指定 --interactive 与 --non-interactive,请二选一。")
        sys.exit(1)

    # No action given, or interactive explicitly requested -> show the menu
    if args.interactive or not args.action:
        interactive_loop()
        return

    # In non-interactive mode, pass --non-interactive through to the sub-scripts
    # so they do not prompt via input() either
    if args.non_interactive:
        extra_args = ["--non-interactive"] + extra_args

    # Non-interactive mode: run the requested action directly
    run_action(args.action, extra_args=extra_args)


if __name__ == "__main__":
    main()
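The manager above forwards flags to its sub-scripts by temporarily swapping `sys.argv` inside a try/finally, so argv is restored even if the sub-script raises. A self-contained sketch of that pattern (the `demo_main` entry point is illustrative only, not one of the real sub-scripts):

```python
import sys


def with_overridden_argv(extra_args, target_main):
    """Run target_main with sys.argv temporarily replaced by extra_args."""
    old_argv = list(sys.argv)
    try:
        # Keep element 0 (the program name); append only the real arguments
        sys.argv = [old_argv[0]] + extra_args
        target_main()
    finally:
        # Always restore the original argv, even if target_main raises
        sys.argv = old_argv


seen = []


def demo_main():
    # Hypothetical sub-script entry point: records the argv it observed
    seen.append(sys.argv[1:])


with_overridden_argv(["--non-interactive", "--dry-run"], demo_main)
```

Because the sub-script's own argparse reads `sys.argv`, this lets the manager pass flags through without changing a single line of the sub-script.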
File diff suppressed because it is too large
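The scoring page below parses "planner-formatted" visible text (`[时间]…[用户名]…[发言内容]…`) with a named-group regex before rendering message cards. A rough Python equivalent of that parsing rule, for reference (the field labels come from the JS pattern; the helper name is illustrative):

```python
import re

# Mirrors the plannerPattern named groups from the page's parseVisibleText
PLANNER_PATTERN = re.compile(
    r"^\s*\[时间\](?P<time>[^\n]*)\n"
    r"\[用户名\](?P<name>[^\n]*)\n"
    r"(?:\[用户群昵称\][^\n]*\n)?"          # optional group-nickname line
    r"(?:\[msg_id\](?P<message_id>[^\n]*)\n)?"  # optional message id line
    r"\[发言内容\](?P<content>[\s\S]*)$"
)


def parse_planner_text(value: str) -> dict:
    """Extract time/name/message id/content from a planner-formatted message."""
    match = PLANNER_PATTERN.match(value or "")
    if not match:
        # Fall back to treating the whole value as plain content
        return {"content": value}
    return {key: (val or "").strip() for key, val in match.groupdict().items()}


sample = "[时间]12:03:45\n[用户名]Alice\n[msg_id]abc123\n[发言内容]你好"
parsed = parse_planner_text(sample)
```

Making the nickname and msg_id lines optional keeps one pattern working for both old and new planner prompt formats.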
@@ -748,6 +748,12 @@ INDEX_HTML_V2 = r"""<!doctype html>
        flex-wrap: wrap;
      }
      .header-tools { margin-bottom: 0; justify-content: flex-end; }
      .global-evaluator {
        width: 150px;
      }
      .global-evaluator {
        width: 150px;
      }
      input, select, textarea, button {
        font: inherit;
        border: 1px solid var(--line);
@@ -868,6 +874,8 @@ INDEX_HTML_V2 = r"""<!doctype html>
  <header>
    <h1>Maisaka 回复效果评分预览</h1>
    <div class="toolbar header-tools">
      <span class="meta">评价人</span>
      <input id="globalEvaluator" class="global-evaluator" placeholder="manual" oninput="saveGlobalEvaluator()" />
      <button id="browseTab" class="tab-button active" onclick="setMode('browse')">浏览</button>
      <button id="rateTab" class="tab-button" onclick="setMode('rate')">逐条评分</button>
      <button class="secondary" onclick="reloadAll()">刷新</button>
@@ -1527,6 +1535,33 @@ INDEX_HTML_V3 = r"""<!doctype html>
      }
      .message-name { font-weight: 650; color: var(--text); }
      .message-text { white-space: pre-wrap; word-break: break-word; line-height: 1.45; }
      .quote-card {
        border-left: 3px solid var(--accent);
        background: var(--accent-soft);
        border-radius: 6px;
        padding: 6px 8px;
        margin: 0 0 6px;
        font-size: 12px;
        color: var(--muted);
      }
      .quote-card.missing {
        border-left-color: var(--warn);
        background: #fff7ed;
      }
      .quote-title {
        display: flex;
        justify-content: space-between;
        gap: 8px;
        margin-bottom: 3px;
        font-weight: 650;
        color: var(--text);
      }
      .quote-text {
        white-space: nowrap;
        overflow: hidden;
        text-overflow: ellipsis;
        line-height: 1.35;
      }
      .message-attachments {
        display: flex;
        gap: 6px;
@@ -1578,6 +1613,8 @@ INDEX_HTML_V3 = r"""<!doctype html>
  <header>
    <h1>Maisaka 回复效果评分预览</h1>
    <div class="toolbar header-tools">
      <span class="meta">评价人</span>
      <input id="globalEvaluator" class="global-evaluator" placeholder="manual" oninput="saveGlobalEvaluator()" />
      <button id="browseTab" class="tab-button active" onclick="setMode('browse')">浏览</button>
      <button id="rateTab" class="tab-button" onclick="setMode('rate')">逐条评分</button>
      <button class="secondary" onclick="reloadAll()">刷新</button>
@@ -1632,6 +1669,8 @@ INDEX_HTML_V3 = r"""<!doctype html>
      let selectedEffect = "";
      let activeMode = "browse";
      let selectedFivePointScore = 0;
      let currentTargetMessageId = "";
      let currentMessageIndex = new Map();

      async function api(path, options) {
        const res = await fetch(path, options);
@@ -1763,6 +1802,11 @@ INDEX_HTML_V3 = r"""<!doctype html>
        const reply = record.reply || {};
        const manual = record._manual || {};
        const followups = record.followup_messages || [];
        currentTargetMessageId = String(reply.target_message_id || "");
        const context = normalizeContextMessages(record.context_snapshot || []);
        const normalizedFollowups = normalizeFollowupMessages(followups);
        const botReply = normalizeBotReply(reply);
        buildCurrentMessageIndex(context, botReply, normalizedFollowups);
        selectedFivePointScore = Number(manual.manual_score_5 || score100ToFive(manual.manual_score) || 0);
        document.getElementById("detailPane").innerHTML = `
          <div class="block">
@@ -1784,10 +1828,6 @@ INDEX_HTML_V3 = r"""<!doctype html>
              id="scoreButton${score}" onclick="selectFivePointScore(${score})">${score}</button>
            `).join("")}
          </div>
          <label>评价人</label>
          <input id="evaluator" value="${escapeAttr(manual.evaluator || "manual")}" />
          <label>备注</label>
          <textarea id="manualNotes">${escapeHtml(manual.notes || "")}</textarea>
          <div class="toolbar">
            <button onclick="saveFivePointManual('${escapeAttr(record.session.platform_type_id)}','${escapeAttr(record.effect_id)}', false)">
              保存人工评分
@@ -1796,11 +1836,11 @@ INDEX_HTML_V3 = r"""<!doctype html>
          </div>
          <div class="block">
            <h2>回复内容</h2>
            ${renderBotReplyCard(reply.reply_text || "")}
            ${renderChatMessageCard(botReply)}
          </div>
          <div class="block">
            <h2>后续消息</h2>
            ${renderFollowupCards(followups)}
            ${renderMessageCards(normalizedFollowups, "暂无")}
          </div>
          <div class="block">
            <h2>完整 JSON</h2>
@@ -1877,6 +1917,9 @@ INDEX_HTML_V3 = r"""<!doctype html>
        const followups = record.followup_messages || [];
        currentTargetMessageId = String(reply.target_message_id || "");
        const context = normalizeContextMessages(record.context_snapshot || []);
        const normalizedFollowups = normalizeFollowupMessages(followups);
        const botReply = normalizeBotReply(reply);
        buildCurrentMessageIndex(context, botReply, normalizedFollowups);
        document.getElementById("detailPane").innerHTML = `
          <div class="rate-top">
            <div class="toolbar">
@@ -1904,11 +1947,11 @@ INDEX_HTML_V3 = r"""<!doctype html>
          </div>
          <div class="block">
            <h2>Bot 回复</h2>
            ${renderBotReplyCard(reply.reply_text || "")}
            ${renderChatMessageCard(botReply)}
          </div>
          <div class="block">
            <h2>后续消息</h2>
            ${renderFollowupCards(followups)}
            ${renderMessageCards(normalizedFollowups, "暂无")}
          </div>
          <div class="block">
            <h2>人工五点评分</h2>
@@ -1918,10 +1961,6 @@ INDEX_HTML_V3 = r"""<!doctype html>
              <button class="score-button" id="scoreButton${score}" onclick="selectFivePointScore(${score})">${score}</button>
            `).join("")}
          </div>
          <label>评价人</label>
          <input id="ratingEvaluator" value="manual" />
          <label>备注</label>
          <textarea id="ratingNotes"></textarea>
          <div class="toolbar">
            <button onclick="saveFivePointManual('${escapeAttr(record.session.platform_type_id)}','${escapeAttr(record.effect_id)}', true)">
              保存并下一条
@@ -1950,8 +1989,8 @@ INDEX_HTML_V3 = r"""<!doctype html>
          effect_id: effectId,
          manual_score_5: selectedFivePointScore,
          manual_label: "",
          evaluator: valueOf("ratingEvaluator") || valueOf("evaluator") || "manual",
          notes: valueOf("ratingNotes") || valueOf("manualNotes"),
          evaluator: currentEvaluator(),
          notes: "",
        };
        try {
          await api("/api/annotations", {
@@ -1974,16 +2013,20 @@ INDEX_HTML_V3 = r"""<!doctype html>
        const items = Array.isArray(context) ? context : [];
        return items.filter(item => !isToolContextMessage(item)).map((item, index) => {
          const parsed = parseVisibleText(item.text || "");
          const rawText = parsed.content || item.text || "";
          const messageId = item.message_id || parsed.messageId || "";
          const attachments = Array.isArray(item.attachments) ? item.attachments : [];
          return {
            index,
            role: item.role || parsed.role || "message",
            source: item.source || "",
            timestamp: item.timestamp || parsed.time || "",
            name: parsed.name || roleName(item.role, item.source),
            messageId: parsed.messageId || "",
            text: cleanMessageText(parsed.content || item.text || ""),
            attachments: Array.isArray(item.attachments) ? item.attachments : [],
            isTarget: parsed.messageId && String(parsed.messageId) === String(selectedTargetMessageId()),
            messageId,
            text: cleanMessageText(rawText, attachments),
            quoteTargetIds: quoteTargetIdsFromMessage(item, rawText),
            attachments,
            isTarget: messageId && String(messageId) === String(selectedTargetMessageId()),
          };
        });
      }
@@ -2000,9 +2043,23 @@ INDEX_HTML_V3 = r"""<!doctype html>

      function parseVisibleText(text) {
        const value = String(text || "");
        const pattern = /^(?<time>\d{1,2}:\d{2}:\d{2})?(?:\[msg_id:(?<messageId>[^\]]+)\])?(?:\[(?<name>[^\]]+)\])?(?<content>[\s\S]*)$/;
        const plannerPattern = /^\s*\[时间\](?<time>[^\n]*)\n\[用户名\](?<name>[^\n]*)\n(?:\[用户群昵称\][^\n]*\n)?(?:\[msg_id\](?<plannerMessageId>[^\n]*)\n)?\[发言内容\](?<plannerContent>[\s\S]*)$/;
        const plannerMatch = value.match(plannerPattern);
        if (plannerMatch && plannerMatch.groups) {
          return {
            time: (plannerMatch.groups.time || "").trim(),
            messageId: (plannerMatch.groups.plannerMessageId || "").trim(),
            name: (plannerMatch.groups.name || "").trim(),
            content: (plannerMatch.groups.plannerContent || "").trim(),
          };
        }

        const pattern = /^\s*(?<time>\d{1,2}:\d{2}:\d{2})?(?:\[msg_id(?::|\])(?<messageId>[^\]\n]+)\]?)?(?:\[(?<name>[^\]]+)\])?(?<content>[\s\S]*)$/;
        const match = value.match(pattern);
        if (!match || !match.groups) return { content: value };
        if (!match.groups.time && !match.groups.messageId && match.groups.name && !(match.groups.content || "").trim()) {
          return { content: value };
        }
        return {
          time: match.groups.time || "",
          messageId: match.groups.messageId || "",
@@ -2015,8 +2072,71 @@ INDEX_HTML_V3 = r"""<!doctype html>
|
||||
return currentTargetMessageId;
|
||||
}
|
||||
|
||||
function renderMessageCards(messages) {
|
||||
if (!messages.length) return `<div class="empty">暂无上下文</div>`;
|
||||
function normalizeBotReply(reply) {
|
||||
const metadata = reply.reply_metadata || {};
|
||||
const sentIds = Array.isArray(metadata.sent_message_ids) ? metadata.sent_message_ids : [];
|
||||
return {
|
||||
side: "bot",
|
||||
role: "assistant",
|
||||
source: "guided_reply",
|
||||
name: "Bot",
|
||||
timestamp: "本次回复",
|
||||
messageId: sentIds[0] || "",
|
||||
messageIds: sentIds,
|
||||
text: cleanMessageText(reply.reply_text || "", []),
|
||||
quoteTargetIds: [],
|
||||
attachments: [],
|
||||
isTarget: false,
|
||||
};
|
||||
}
|
||||
|
||||
function normalizeFollowupMessages(followups) {
|
||||
if (!followups || !followups.length) return [];
|
||||
return followups.map(message => {
|
||||
const rawText = message.visible_text || message.plain_text || "";
|
||||
const parsed = parseVisibleText(rawText);
|
||||
const attachments = Array.isArray(message.attachments) ? message.attachments : [];
|
||||
const messageId = message.message_id || parsed.messageId || "";
|
||||
return {
|
||||
side: "user",
|
||||
role: "user",
|
||||
source: "followup",
|
||||
name: `${userName(message) || parsed.name}${message.is_target_user ? " · 目标用户" : ""}`,
|
||||
timestamp: message.timestamp || "",
|
||||
messageId,
|
||||
text: cleanMessageText(parsed.content || rawText, attachments),
|
||||
quoteTargetIds: quoteTargetIdsFromMessage(message, rawText),
|
||||
attachments,
|
||||
isTarget: message.is_target_user,
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
function buildCurrentMessageIndex(contextMessages, botReply, followupMessages) {
  currentMessageIndex = new Map();
  [...contextMessages, botReply, ...followupMessages].forEach(message => {
    const ids = [message.messageId, ...(Array.isArray(message.messageIds) ? message.messageIds : [])]
      .map(id => String(id || "").trim())
      .filter(Boolean);
    ids.forEach(id => {
      if (!currentMessageIndex.has(id)) currentMessageIndex.set(id, message);
    });
  });
}
|
||||
|
||||
function quoteTargetIdsFromMessage(message, text) {
  const structuredIds = Array.isArray(message.quote_target_ids) ? message.quote_target_ids : [];
  const textIds = [];
  String(text || "").replace(/\[引用回复\]\(([^)]+)\)/g, (_, id) => {
    const normalizedId = String(id || "").trim();
    if (normalizedId) textIds.push(normalizedId);
    return "";
  });
  return [...new Set([...structuredIds, ...textIds].map(id => String(id || "").trim()).filter(Boolean))];
}
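The merge rule above (structured ids first, then ids parsed from `[引用回复](...)` markers, all trimmed and deduplicated while preserving order) can be sketched server-side as well. This is a minimal standalone sketch, not project code; the function name `quote_target_ids` is an assumption:

```python
import re

def quote_target_ids(structured_ids, text):
    """Merge structured quote ids with ids parsed from [引用回复](...) markers,
    trimming each id and deduplicating while keeping first-seen order."""
    text_ids = re.findall(r"\[引用回复\]\(([^)]+)\)", text or "")
    merged = [str(i or "").strip() for i in [*structured_ids, *text_ids]]
    seen, out = set(), []
    for i in merged:
        if i and i not in seen:
            seen.add(i)
            out.append(i)
    return out
```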
|
||||
|
||||
function renderMessageCards(messages, emptyText = "暂无上下文") {
|
||||
if (!messages.length) return `<div class="empty">${escapeHtml(emptyText)}</div>`;
|
||||
return messages.map(message => {
|
||||
const side = isBotContextMessage(message) ? "bot" : "user";
|
||||
return renderChatMessageCard({
|
||||
@@ -2027,43 +2147,13 @@ INDEX_HTML_V3 = r"""<!doctype html>
|
||||
timestamp: message.timestamp,
|
||||
messageId: message.messageId,
|
||||
text: message.text,
|
||||
quoteTargetIds: message.quoteTargetIds || [],
|
||||
attachments: message.attachments,
|
||||
isTarget: message.isTarget,
|
||||
});
|
||||
}).join("");
|
||||
}
|
||||
|
||||
function renderBotReplyCard(text) {
|
||||
return renderChatMessageCard({
|
||||
side: "bot",
|
||||
role: "assistant",
|
||||
source: "guided_reply",
|
||||
name: "Bot",
|
||||
timestamp: "本次回复",
|
||||
messageId: "",
|
||||
text,
|
||||
attachments: [],
|
||||
isTarget: false,
|
||||
});
|
||||
}
|
||||
|
||||
function renderFollowupCards(followups) {
|
||||
if (!followups || !followups.length) return `<div class="empty">暂无</div>`;
|
||||
return followups.map(message => `
|
||||
${renderChatMessageCard({
|
||||
side: "user",
|
||||
role: "user",
|
||||
source: "followup",
|
||||
name: `${userName(message)}${message.is_target_user ? " · 目标用户" : ""}`,
|
||||
timestamp: message.timestamp || "",
|
||||
messageId: message.message_id || "",
|
||||
text: cleanMessageText(message.visible_text || message.plain_text || ""),
|
||||
attachments: Array.isArray(message.attachments) ? message.attachments : [],
|
||||
isTarget: message.is_target_user,
|
||||
})}
|
||||
`).join("");
|
||||
}
|
||||
|
||||
function renderChatMessageCard(message) {
|
||||
const messageIdText = message.messageId ? ` · ${escapeHtml(message.messageId)}` : "";
|
||||
const textHtml = message.text ? `<div class="message-text">${escapeHtml(message.text)}</div>` : "";
|
||||
@@ -2074,6 +2164,7 @@ INDEX_HTML_V3 = r"""<!doctype html>
|
||||
<span class="message-name">${escapeHtml(message.name || "消息")}</span>
|
||||
<span>${escapeHtml(message.timestamp || "")}${messageIdText}</span>
|
||||
</div>
|
||||
${renderQuoteCards(message.quoteTargetIds || [])}
|
||||
${textHtml}
|
||||
${renderAttachments(message.attachments || [])}
|
||||
</div>
|
||||
@@ -2081,6 +2172,35 @@ INDEX_HTML_V3 = r"""<!doctype html>
|
||||
`;
|
||||
}
|
||||
|
||||
function renderQuoteCards(quoteTargetIds) {
|
||||
if (!quoteTargetIds || !quoteTargetIds.length) return "";
|
||||
return quoteTargetIds.map(targetId => {
|
||||
const quoted = currentMessageIndex.get(String(targetId || ""));
|
||||
if (!quoted) {
|
||||
return `
|
||||
<div class="quote-card missing">
|
||||
<div class="quote-title">
|
||||
<span>引用回复</span>
|
||||
<span>${escapeHtml(targetId)}</span>
|
||||
</div>
|
||||
<div class="quote-text">未在本记录的上下文或后续消息中找到这条消息</div>
|
||||
</div>
|
||||
`;
|
||||
}
|
||||
const quotedName = quoted.name || roleName(quoted.role, quoted.source);
|
||||
const quotedText = quoted.text || attachmentSummary(quoted.attachments || []) || "无文本内容";
|
||||
return `
|
||||
<div class="quote-card">
|
||||
<div class="quote-title">
|
||||
<span>引用 ${escapeHtml(quotedName)}</span>
|
||||
<span>${escapeHtml(targetId)}</span>
|
||||
</div>
|
||||
<div class="quote-text">${escapeHtml(quotedText)}</div>
|
||||
</div>
|
||||
`;
|
||||
}).join("");
|
||||
}
|
||||
|
||||
function renderAttachments(attachments) {
|
||||
const shown = (attachments || []).filter(item => attachmentUrl(item));
|
||||
if (!shown.length) return "";
|
||||
@@ -2089,7 +2209,6 @@ INDEX_HTML_V3 = r"""<!doctype html>
|
||||
${shown.map(item => `
|
||||
<div>
|
||||
<img class="message-image" src="${escapeAttr(attachmentUrl(item))}" alt="${escapeAttr(item.content || item.kind || "图片")}" loading="lazy" />
|
||||
${item.content ? `<div class="message-image-caption">${escapeHtml(item.content)}</div>` : ""}
|
||||
</div>
|
||||
`).join("")}
|
||||
</div>
|
||||
@@ -2104,11 +2223,32 @@ INDEX_HTML_V3 = r"""<!doctype html>
|
||||
return "";
|
||||
}
|
||||
|
||||
function cleanMessageText(text) {
|
||||
return String(text || "")
|
||||
.replace(/\[图片\]/g, "")
|
||||
.replace(/\[表情包?\]/g, "")
|
||||
.trim();
|
||||
function cleanMessageText(text, attachments = []) {
  let normalized = stripVisibleMessagePrefix(String(text || "")).replace(/\[引用回复\]\([^)]+\)/g, "");
  const shownAttachments = (attachments || []).filter(item => attachmentUrl(item));
  if (shownAttachments.length) {
    normalized = normalized
      .replace(/\[图片\]/g, "")
      .replace(/\[表情包?\]/g, "");
    for (const attachment of shownAttachments) {
      const content = String(attachment.content || "").trim();
      if (!content) continue;
      normalized = normalized.split(content).join("");
    }
  }
  return normalized.trim();
}
|
||||
|
||||
function stripVisibleMessagePrefix(text) {
  const parsed = parseVisibleText(text);
  if (parsed.content && parsed.content !== text) return parsed.content;
  return String(text || "");
}

function attachmentSummary(attachments) {
  const count = Array.isArray(attachments) ? attachments.length : 0;
  if (!count) return "";
  return count === 1 ? "[图片]" : `[${count} 张图片]`;
}
|
||||
|
||||
function isBotContextMessage(message) {
|
||||
@@ -2152,6 +2292,28 @@ INDEX_HTML_V3 = r"""<!doctype html>
|
||||
return element ? element.value : "";
|
||||
}
|
||||
|
||||
function currentEvaluator() {
|
||||
return valueOf("globalEvaluator").trim() || "manual";
|
||||
}
|
||||
|
||||
function saveGlobalEvaluator() {
|
||||
try {
|
||||
localStorage.setItem("replyEffectEvaluator", currentEvaluator());
|
||||
} catch (_err) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
function restoreGlobalEvaluator() {
|
||||
const input = document.getElementById("globalEvaluator");
|
||||
if (!input) return;
|
||||
try {
|
||||
input.value = localStorage.getItem("replyEffectEvaluator") || "manual";
|
||||
} catch (_err) {
|
||||
input.value = "manual";
|
||||
}
|
||||
}
|
||||
|
||||
function scoreText(v) {
|
||||
return v === null || v === undefined || v === "" ? "N/A" : Number(v).toFixed(1);
|
||||
}
|
||||
@@ -2174,6 +2336,7 @@ INDEX_HTML_V3 = r"""<!doctype html>
|
||||
return escapeHtml(value).replace(/`/g, "&#96;");
|
||||
}
|
||||
|
||||
restoreGlobalEvaluator();
|
||||
reloadAll();
|
||||
</script>
|
||||
</body>
|
||||
|
||||
@@ -1,79 +0,0 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
import sys # 新增系统模块导入
|
||||
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
from src.chat.knowledge.utils.hash import get_sha256
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("lpmm")
|
||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
RAW_DATA_PATH = os.path.join(ROOT_PATH, "data/lpmm_raw_data")
|
||||
# IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data/imported_lpmm_data")
|
||||
|
||||
|
||||
def _process_text_file(file_path):
|
||||
"""处理单个文本文件,返回段落列表"""
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
raw = f.read()
|
||||
|
||||
paragraphs = []
|
||||
paragraph = ""
|
||||
for line in raw.split("\n"):
|
||||
if line.strip() == "":
|
||||
if paragraph != "":
|
||||
paragraphs.append(paragraph.strip())
|
||||
paragraph = ""
|
||||
else:
|
||||
paragraph += line + "\n"
|
||||
|
||||
if paragraph != "":
|
||||
paragraphs.append(paragraph.strip())
|
||||
|
||||
return paragraphs
|
||||
|
||||
|
||||
def _process_multi_files() -> list:
|
||||
raw_files = list(Path(RAW_DATA_PATH).glob("*.txt"))
|
||||
if not raw_files:
|
||||
logger.warning("警告: data/lpmm_raw_data 中没有找到任何 .txt 文件")
|
||||
sys.exit(1)
|
||||
# 处理所有文件
|
||||
all_paragraphs = []
|
||||
for file in raw_files:
|
||||
logger.info(f"正在处理文件: {file.name}")
|
||||
paragraphs = _process_text_file(file)
|
||||
all_paragraphs.extend(paragraphs)
|
||||
return all_paragraphs
|
||||
|
||||
|
||||
def load_raw_data() -> tuple[list[str], list[str]]:
|
||||
"""加载原始数据文件
|
||||
|
||||
读取原始数据文件,将原始数据加载到内存中
|
||||
|
||||
|
||||
Returns:
|
||||
- raw_data: 原始数据列表
|
||||
- sha256_list: 原始数据的SHA256集合
|
||||
"""
|
||||
raw_paragraphs = _process_multi_files()
|
||||
sha256_list = []
|
||||
sha256_set = set()
|
||||
raw_data: list[str] = []
|
||||
for item in raw_paragraphs:
|
||||
if not isinstance(item, str):
|
||||
logger.warning(f"数据类型错误:{item}")
|
||||
continue
|
||||
pg_hash = get_sha256(item)
|
||||
if pg_hash in sha256_set:
|
||||
logger.warning(f"重复数据:{item}")
|
||||
continue
|
||||
sha256_set.add(pg_hash)
|
||||
sha256_list.append(pg_hash)
|
||||
raw_data.append(item)
|
||||
logger.info(f"共读取到{len(raw_data)}条数据")
|
||||
|
||||
return sha256_list, raw_data
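The deleted loader above splits raw text into blank-line-separated paragraphs and drops duplicates by content hash. A minimal self-contained sketch of those two steps, without the project's logger and path setup:

```python
import hashlib

def split_paragraphs(raw: str) -> list[str]:
    """Split text into paragraphs on blank lines, mirroring _process_text_file."""
    paragraphs, paragraph = [], ""
    for line in raw.split("\n"):
        if line.strip() == "":
            if paragraph:
                paragraphs.append(paragraph.strip())
                paragraph = ""
        else:
            paragraph += line + "\n"
    if paragraph:
        paragraphs.append(paragraph.strip())
    return paragraphs

def dedup_by_sha256(paragraphs: list[str]) -> tuple[list[str], list[str]]:
    """Drop duplicate paragraphs by SHA256, as load_raw_data does."""
    sha256_list, raw_data, seen = [], [], set()
    for item in paragraphs:
        digest = hashlib.sha256(item.encode("utf-8")).hexdigest()
        if digest in seen:
            continue
        seen.add(digest)
        sha256_list.append(digest)
        raw_data.append(item)
    return sha256_list, raw_data
```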
|
||||
@@ -1,66 +0,0 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
try:
|
||||
if hasattr(sys.stdout, "reconfigure"):
|
||||
sys.stdout.reconfigure(encoding="utf-8")
|
||||
if hasattr(sys.stderr, "reconfigure"):
|
||||
sys.stderr.reconfigure(encoding="utf-8")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 确保能导入 src.*
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.knowledge import lpmm_start_up, get_qa_manager
|
||||
|
||||
logger = get_logger("refresh_lpmm_knowledge")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
logger.info("开始刷新 LPMM 知识库(重新加载向量库与 KG)...")
|
||||
|
||||
if not global_config.lpmm_knowledge.enable:
|
||||
logger.warning(
|
||||
"当前配置中 lpmm_knowledge.enable = false,本次仅刷新磁盘数据与内存结构,"
|
||||
"但聊天侧如未启用 LPMM 仍不会在问答中使用知识库。"
|
||||
)
|
||||
|
||||
# 调用标准启动逻辑,内部会加载 data/embedding 与 data/rag
|
||||
lpmm_start_up()
|
||||
|
||||
qa_manager = get_qa_manager()
|
||||
if qa_manager is None:
|
||||
logger.error("刷新后 qa_manager 仍为 None,请检查是否已经成功导入过 LPMM 知识库。")
|
||||
return
|
||||
|
||||
# 简要输出当前知识库规模,方便人工确认
|
||||
embed_manager = qa_manager.embed_manager
|
||||
kg_manager = qa_manager.kg_manager
|
||||
|
||||
para_vec = len(embed_manager.paragraphs_embedding_store.store)
|
||||
ent_vec = len(embed_manager.entities_embedding_store.store)
|
||||
rel_vec = len(embed_manager.relation_embedding_store.store)
|
||||
nodes = len(kg_manager.graph.get_node_list())
|
||||
edges = len(kg_manager.graph.get_edge_list())
|
||||
|
||||
logger.info("LPMM 知识库刷新完成,当前规模:")
|
||||
logger.info(
|
||||
"段落向量=%d, 实体向量=%d, 关系向量=%d, KG节点=%d, KG边=%d",
|
||||
para_vec,
|
||||
ent_vec,
|
||||
rel_vec,
|
||||
nodes,
|
||||
edges,
|
||||
)
|
||||
|
||||
print("\n[REFRESH] 刷新完成,请注意:")
|
||||
print("- 本脚本是在独立进程内执行的,用于验证磁盘数据可以正常加载。")
|
||||
print("- 若主程序已在运行且未在内部调用 lpmm_start_up() 重新初始化,仍需重启或新增管理入口来热刷新。")
|
||||
print("- 如果不清楚 lpmm_start_up 是什么,只需要重启主程序即可。")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,146 +0,0 @@
|
||||
# ruff: noqa: E402
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
||||
SRC_ROOT = PROJECT_ROOT / "src"
|
||||
if str(SRC_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(SRC_ROOT))
|
||||
if str(PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(1, str(PROJECT_ROOT))
|
||||
|
||||
from src.config.config import config_manager
|
||||
from src.llm_models.model_client.base_client import AudioTranscriptionRequest, ResponseRequest, client_registry
|
||||
from src.llm_models.model_client.base_client import EmbeddingRequest
|
||||
from src.llm_models.request_snapshot import (
|
||||
deserialize_messages_snapshot,
|
||||
deserialize_model_info_snapshot,
|
||||
deserialize_response_format_snapshot,
|
||||
deserialize_tool_options_snapshot,
|
||||
)
|
||||
|
||||
|
||||
def _load_snapshot(snapshot_path: Path) -> dict[str, Any]:
|
||||
"""加载请求快照。"""
|
||||
return json.loads(snapshot_path.read_text(encoding="utf-8"))
|
||||
|
||||
|
||||
def _resolve_api_provider(provider_name: str):
|
||||
"""根据名称解析当前配置中的 API Provider。"""
|
||||
model_config = config_manager.get_model_config()
|
||||
for api_provider in model_config.api_providers:
|
||||
if api_provider.name == provider_name:
|
||||
return api_provider
|
||||
raise ValueError(f"当前配置中不存在名为 {provider_name!r} 的 API Provider")
|
||||
|
||||
|
||||
def _build_response_request(snapshot: dict[str, Any]) -> ResponseRequest:
|
||||
"""从快照构建响应请求对象。"""
|
||||
return ResponseRequest(
|
||||
extra_params=dict(snapshot.get("extra_params") or {}),
|
||||
max_tokens=snapshot.get("max_tokens"),
|
||||
message_list=deserialize_messages_snapshot(snapshot.get("message_list") or []),
|
||||
model_info=deserialize_model_info_snapshot(snapshot.get("model_info") or {}),
|
||||
response_format=deserialize_response_format_snapshot(snapshot.get("response_format")),
|
||||
temperature=snapshot.get("temperature"),
|
||||
tool_options=deserialize_tool_options_snapshot(snapshot.get("tool_options")),
|
||||
)
|
||||
|
||||
|
||||
def _build_embedding_request(snapshot: dict[str, Any]) -> EmbeddingRequest:
|
||||
"""从快照构建嵌入请求对象。"""
|
||||
return EmbeddingRequest(
|
||||
embedding_input=str(snapshot.get("embedding_input") or ""),
|
||||
extra_params=dict(snapshot.get("extra_params") or {}),
|
||||
model_info=deserialize_model_info_snapshot(snapshot.get("model_info") or {}),
|
||||
)
|
||||
|
||||
|
||||
def _build_audio_request(snapshot: dict[str, Any]) -> AudioTranscriptionRequest:
|
||||
"""从快照构建音频转写请求对象。"""
|
||||
return AudioTranscriptionRequest(
|
||||
audio_base64=str(snapshot.get("audio_base64") or ""),
|
||||
extra_params=dict(snapshot.get("extra_params") or {}),
|
||||
max_tokens=snapshot.get("max_tokens"),
|
||||
model_info=deserialize_model_info_snapshot(snapshot.get("model_info") or {}),
|
||||
)
|
||||
|
||||
|
||||
async def _replay(snapshot_path: Path) -> int:
|
||||
"""回放一条失败请求快照。"""
|
||||
config_manager.initialize()
|
||||
snapshot = _load_snapshot(snapshot_path)
|
||||
|
||||
internal_request = snapshot.get("internal_request")
|
||||
if not isinstance(internal_request, dict):
|
||||
raise ValueError("快照缺少 internal_request 字段")
|
||||
|
||||
provider_snapshot = snapshot.get("api_provider")
|
||||
if not isinstance(provider_snapshot, dict):
|
||||
raise ValueError("快照缺少 api_provider 字段")
|
||||
|
||||
provider_name = str(provider_snapshot.get("name") or "")
|
||||
if not provider_name:
|
||||
raise ValueError("快照中的 api_provider.name 不能为空")
|
||||
|
||||
api_provider = _resolve_api_provider(provider_name)
|
||||
client = client_registry.get_client_class_instance(api_provider, force_new=True)
|
||||
|
||||
request_kind = str(internal_request.get("request_kind") or "").strip()
|
||||
if request_kind == "response":
|
||||
response = await client.get_response(_build_response_request(internal_request))
|
||||
elif request_kind == "embedding":
|
||||
response = await client.get_embedding(_build_embedding_request(internal_request))
|
||||
elif request_kind == "audio_transcription":
|
||||
response = await client.get_audio_transcriptions(_build_audio_request(internal_request))
|
||||
else:
|
||||
raise ValueError(f"不支持的 request_kind: {request_kind!r}")
|
||||
|
||||
output_payload = {
|
||||
"content": response.content,
|
||||
"embedding_length": len(response.embedding or []),
|
||||
"has_embedding": response.embedding is not None,
|
||||
"model_name": response.usage.model_name if response.usage is not None else None,
|
||||
"provider_name": response.usage.provider_name if response.usage is not None else None,
|
||||
"raw_data_type": type(response.raw_data).__name__ if response.raw_data is not None else None,
|
||||
"reasoning_content": response.reasoning_content,
|
||||
"tool_calls": [
|
||||
{
|
||||
"args": tool_call.args,
|
||||
"call_id": tool_call.call_id,
|
||||
"func_name": tool_call.func_name,
|
||||
}
|
||||
for tool_call in (response.tool_calls or [])
|
||||
],
|
||||
"usage": {
|
||||
"completion_tokens": response.usage.completion_tokens,
|
||||
"prompt_tokens": response.usage.prompt_tokens,
|
||||
"total_tokens": response.usage.total_tokens,
|
||||
}
|
||||
if response.usage is not None
|
||||
else None,
|
||||
}
|
||||
print(json.dumps(output_payload, ensure_ascii=False, indent=2))
|
||||
return 0
|
||||
|
||||
|
||||
def main() -> int:
|
||||
"""脚本入口。"""
|
||||
parser = argparse.ArgumentParser(description="回放失败的 LLM 请求快照。")
|
||||
parser.add_argument("snapshot_path", help="请求快照 JSON 文件路径")
|
||||
args = parser.parse_args()
|
||||
|
||||
snapshot_path = Path(args.snapshot_path).expanduser().resolve()
|
||||
if not snapshot_path.exists():
|
||||
raise FileNotFoundError(f"快照文件不存在: {snapshot_path}")
|
||||
|
||||
return asyncio.run(_replay(snapshot_path))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
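The replay script validates the snapshot's required fields, then dispatches on `request_kind`. A minimal sketch of that validate-then-dispatch pattern with a plain handler table — the function and handler names here are illustrative, not the project's API:

```python
from typing import Any, Callable

def replay_snapshot(snapshot: dict[str, Any],
                    handlers: dict[str, Callable[[dict[str, Any]], str]]) -> str:
    """Validate required snapshot fields, then dispatch on request_kind."""
    internal = snapshot.get("internal_request")
    if not isinstance(internal, dict):
        raise ValueError("snapshot is missing the internal_request field")
    kind = str(internal.get("request_kind") or "").strip()
    handler = handlers.get(kind)
    if handler is None:
        raise ValueError(f"unsupported request_kind: {kind!r}")
    return handler(internal)

# Hypothetical handlers standing in for get_response / get_embedding.
handlers = {
    "response": lambda req: "replayed response",
    "embedding": lambda req: "replayed embedding",
}
```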
|
||||
@@ -1,303 +0,0 @@
|
||||
"""
|
||||
统计和展示 replyer 动作选择记录
|
||||
|
||||
用法:
|
||||
python scripts/replyer_action_stats.py
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from collections import Counter, defaultdict
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Any
|
||||
from pathlib import Path
|
||||
|
||||
# Add project root to Python path
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.insert(0, project_root)
|
||||
|
||||
try:
|
||||
from src.common.database.database_model import ChatStreams
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _script_chat_manager
|
||||
except ImportError:
|
||||
ChatStreams = None
|
||||
_script_chat_manager = None
|
||||
|
||||
|
||||
def get_chat_name(chat_id: str) -> str:
|
||||
"""根据 chat_id 获取聊天名称"""
|
||||
try:
|
||||
if ChatStreams:
|
||||
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
|
||||
if chat_stream:
|
||||
if chat_stream.group_name:
|
||||
return f"{chat_stream.group_name}"
|
||||
elif chat_stream.user_nickname:
|
||||
return f"{chat_stream.user_nickname}的私聊"
|
||||
|
||||
if _script_chat_manager:
|
||||
chat_manager = _script_chat_manager
|
||||
stream_name = chat_manager.get_stream_name(chat_id)
|
||||
if stream_name:
|
||||
return stream_name
|
||||
|
||||
return f"未知聊天 ({chat_id[:8]}...)"
|
||||
except Exception:
|
||||
return f"查询失败 ({chat_id[:8]}...)"
|
||||
|
||||
|
||||
def load_records(temp_dir: str = "data/temp") -> List[Dict[str, Any]]:
|
||||
"""加载所有 replyer 动作记录"""
|
||||
records = []
|
||||
temp_path = Path(temp_dir)
|
||||
|
||||
if not temp_path.exists():
|
||||
print(f"目录不存在: {temp_dir}")
|
||||
return records
|
||||
|
||||
# 查找所有 replyer_action_*.json 文件
|
||||
pattern = "replyer_action_*.json"
|
||||
for file_path in temp_path.glob(pattern):
|
||||
try:
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
records.append(data)
|
||||
except Exception as e:
|
||||
print(f"读取文件失败 {file_path}: {e}")
|
||||
|
||||
# 按时间戳排序
|
||||
records.sort(key=lambda x: x.get("timestamp", ""))
|
||||
return records
|
||||
|
||||
|
||||
def format_timestamp(ts: str) -> str:
|
||||
"""格式化时间戳"""
|
||||
try:
|
||||
dt = datetime.fromisoformat(ts)
|
||||
return dt.strftime("%Y-%m-%d %H:%M:%S")
|
||||
except Exception:
|
||||
return ts
|
||||
|
||||
|
||||
def calculate_time_distribution(records: List[Dict[str, Any]]) -> Dict[str, int]:
|
||||
"""计算时间分布"""
|
||||
now = datetime.now()
|
||||
distribution = {
|
||||
"今天": 0,
|
||||
"昨天": 0,
|
||||
"3天内": 0,
|
||||
"7天内": 0,
|
||||
"30天内": 0,
|
||||
"更早": 0,
|
||||
}
|
||||
|
||||
for record in records:
|
||||
try:
|
||||
ts = record.get("timestamp", "")
|
||||
if not ts:
|
||||
continue
|
||||
dt = datetime.fromisoformat(ts)
|
||||
diff = (now - dt).days
|
||||
|
||||
if diff == 0:
|
||||
distribution["今天"] += 1
|
||||
elif diff == 1:
|
||||
distribution["昨天"] += 1
|
||||
elif diff < 3:
|
||||
distribution["3天内"] += 1
|
||||
elif diff < 7:
|
||||
distribution["7天内"] += 1
|
||||
elif diff < 30:
|
||||
distribution["30天内"] += 1
|
||||
else:
|
||||
distribution["更早"] += 1
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return distribution
|
||||
|
||||
|
||||
def print_statistics(records: List[Dict[str, Any]]):
|
||||
"""打印统计信息"""
|
||||
if not records:
|
||||
print("没有找到任何记录")
|
||||
return
|
||||
|
||||
print("=" * 80)
|
||||
print("Replyer 动作选择记录统计")
|
||||
print("=" * 80)
|
||||
print()
|
||||
|
||||
# 总记录数
|
||||
total_count = len(records)
|
||||
print(f"📊 总记录数: {total_count}")
|
||||
print()
|
||||
|
||||
# 时间范围
|
||||
timestamps = [r.get("timestamp", "") for r in records if r.get("timestamp")]
|
||||
if timestamps:
|
||||
first_time = format_timestamp(min(timestamps))
|
||||
last_time = format_timestamp(max(timestamps))
|
||||
print(f"📅 时间范围: {first_time} ~ {last_time}")
|
||||
print()
|
||||
|
||||
# 按 think_level 统计
|
||||
think_levels = [r.get("think_level", 0) for r in records]
|
||||
think_level_counter = Counter(think_levels)
|
||||
print("🧠 思考深度分布:")
|
||||
for level in sorted(think_level_counter.keys()):
|
||||
count = think_level_counter[level]
|
||||
percentage = (count / total_count) * 100
|
||||
level_name = {0: "不需要思考", 1: "简单思考", 2: "深度思考"}.get(level, f"未知({level})")
|
||||
print(f" Level {level} ({level_name}): {count} 次 ({percentage:.1f}%)")
|
||||
print()
|
||||
|
||||
# 按 chat_id 统计(总体)
|
||||
chat_counter = Counter([r.get("chat_id", "未知") for r in records])
|
||||
print(f"💬 聊天分布 (共 {len(chat_counter)} 个聊天):")
|
||||
# 只显示前10个
|
||||
for chat_id, count in chat_counter.most_common(10):
|
||||
chat_name = get_chat_name(chat_id)
|
||||
percentage = (count / total_count) * 100
|
||||
print(f" {chat_name}: {count} 次 ({percentage:.1f}%)")
|
||||
if len(chat_counter) > 10:
|
||||
print(f" ... 还有 {len(chat_counter) - 10} 个聊天")
|
||||
print()
|
||||
|
||||
# 每个 chat_id 的详细统计
|
||||
print("=" * 80)
|
||||
print("每个聊天的详细统计")
|
||||
print("=" * 80)
|
||||
print()
|
||||
|
||||
# 按 chat_id 分组记录
|
||||
records_by_chat = defaultdict(list)
|
||||
for record in records:
|
||||
chat_id = record.get("chat_id", "未知")
|
||||
records_by_chat[chat_id].append(record)
|
||||
|
||||
# 按记录数排序
|
||||
sorted_chats = sorted(records_by_chat.items(), key=lambda x: len(x[1]), reverse=True)
|
||||
|
||||
for chat_id, chat_records in sorted_chats:
|
||||
chat_name = get_chat_name(chat_id)
|
||||
chat_count = len(chat_records)
|
||||
chat_percentage = (chat_count / total_count) * 100
|
||||
|
||||
print(f"📱 {chat_name} ({chat_id[:8]}...)")
|
||||
print(f" 总记录数: {chat_count} ({chat_percentage:.1f}%)")
|
||||
|
||||
# 该聊天的 think_level 分布
|
||||
chat_think_levels = [r.get("think_level", 0) for r in chat_records]
|
||||
chat_think_counter = Counter(chat_think_levels)
|
||||
print(" 思考深度分布:")
|
||||
for level in sorted(chat_think_counter.keys()):
|
||||
level_count = chat_think_counter[level]
|
||||
level_percentage = (level_count / chat_count) * 100
|
||||
level_name = {0: "不需要思考", 1: "简单思考", 2: "深度思考"}.get(level, f"未知({level})")
|
||||
print(f" Level {level} ({level_name}): {level_count} 次 ({level_percentage:.1f}%)")
|
||||
|
||||
# 该聊天的时间范围
|
||||
chat_timestamps = [r.get("timestamp", "") for r in chat_records if r.get("timestamp")]
|
||||
if chat_timestamps:
|
||||
first_time = format_timestamp(min(chat_timestamps))
|
||||
last_time = format_timestamp(max(chat_timestamps))
|
||||
print(f" 时间范围: {first_time} ~ {last_time}")
|
||||
|
||||
# 该聊天的时间分布
|
||||
chat_time_dist = calculate_time_distribution(chat_records)
|
||||
print(" 时间分布:")
|
||||
for period, count in chat_time_dist.items():
|
||||
if count > 0:
|
||||
period_percentage = (count / chat_count) * 100
|
||||
print(f" {period}: {count} 次 ({period_percentage:.1f}%)")
|
||||
|
||||
# 显示该聊天最近的一条理由示例
|
||||
if chat_records:
|
||||
latest_record = chat_records[-1]
|
||||
reason = latest_record.get("reason", "无理由")
|
||||
if len(reason) > 120:
|
||||
reason = reason[:120] + "..."
|
||||
timestamp = format_timestamp(latest_record.get("timestamp", ""))
|
||||
think_level = latest_record.get("think_level", 0)
|
||||
print(f" 最新记录 [{timestamp}] (Level {think_level}): {reason}")
|
||||
|
||||
print()
|
||||
|
||||
# 时间分布
|
||||
time_dist = calculate_time_distribution(records)
|
||||
print("⏰ 时间分布:")
|
||||
for period, count in time_dist.items():
|
||||
if count > 0:
|
||||
percentage = (count / total_count) * 100
|
||||
print(f" {period}: {count} 次 ({percentage:.1f}%)")
|
||||
print()
|
||||
|
||||
# 显示一些示例理由
|
||||
print("📝 示例理由 (最近5条):")
|
||||
recent_records = records[-5:]
|
||||
for i, record in enumerate(recent_records, 1):
|
||||
reason = record.get("reason", "无理由")
|
||||
think_level = record.get("think_level", 0)
|
||||
timestamp = format_timestamp(record.get("timestamp", ""))
|
||||
chat_id = record.get("chat_id", "未知")
|
||||
chat_name = get_chat_name(chat_id)
|
||||
|
||||
# 截断过长的理由
|
||||
if len(reason) > 100:
|
||||
reason = reason[:100] + "..."
|
||||
|
||||
print(f" {i}. [{timestamp}] {chat_name} (Level {think_level})")
|
||||
print(f" {reason}")
|
||||
print()
|
||||
|
||||
# 按 think_level 分组显示理由示例
|
||||
print("=" * 80)
|
||||
print("按思考深度分类的示例理由")
|
||||
print("=" * 80)
|
||||
print()
|
||||
|
||||
for level in [0, 1, 2]:
|
||||
level_records = [r for r in records if r.get("think_level") == level]
|
||||
if not level_records:
|
||||
continue
|
||||
|
||||
level_name = {0: "不需要思考", 1: "简单思考", 2: "深度思考"}.get(level, f"未知({level})")
|
||||
print(f"Level {level} ({level_name}) - 共 {len(level_records)} 条:")
|
||||
|
||||
# 显示3个示例(选择最近的)
|
||||
examples = level_records[-3:] if len(level_records) >= 3 else level_records
|
||||
for i, record in enumerate(examples, 1):
|
||||
reason = record.get("reason", "无理由")
|
||||
if len(reason) > 150:
|
||||
reason = reason[:150] + "..."
|
||||
timestamp = format_timestamp(record.get("timestamp", ""))
|
||||
chat_id = record.get("chat_id", "未知")
|
||||
chat_name = get_chat_name(chat_id)
|
||||
print(f" {i}. [{timestamp}] {chat_name}")
|
||||
print(f" {reason}")
|
||||
print()
|
||||
|
||||
# 统计信息汇总
|
||||
print("=" * 80)
|
||||
print("统计汇总")
|
||||
print("=" * 80)
|
||||
print(f"总记录数: {total_count}")
|
||||
print(f"涉及聊天数: {len(chat_counter)}")
|
||||
if chat_counter:
|
||||
avg_count = total_count / len(chat_counter)
|
||||
print(f"平均每个聊天记录数: {avg_count:.1f}")
|
||||
else:
|
||||
print("平均每个聊天记录数: N/A")
|
||||
print()
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
records = load_records()
|
||||
print_statistics(records)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
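The age buckets used by `calculate_time_distribution` can be exercised in isolation. A minimal sketch with the same cutoff rules, taking `now` as a parameter so it is testable:

```python
from datetime import datetime

def bucket_by_age(timestamps: list[str], now: datetime) -> dict[str, int]:
    """Bucket ISO timestamps into the same age ranges as calculate_time_distribution."""
    dist = {"今天": 0, "昨天": 0, "3天内": 0, "7天内": 0, "30天内": 0, "更早": 0}
    for ts in timestamps:
        days = (now - datetime.fromisoformat(ts)).days
        if days == 0:
            dist["今天"] += 1
        elif days == 1:
            dist["昨天"] += 1
        elif days < 3:
            dist["3天内"] += 1
        elif days < 7:
            dist["7天内"] += 1
        elif days < 30:
            dist["30天内"] += 1
        else:
            dist["更早"] += 1
    return dist
```

Note that `timedelta.days` truncates toward zero, so "昨天" here means between 24 and 48 hours ago, matching the original script's behavior.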
|
||||
@@ -1,122 +0,0 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
# 强制使用 utf-8,避免控制台编码报错影响 Embedding 加载
|
||||
try:
|
||||
if hasattr(sys.stdout, "reconfigure"):
|
||||
sys.stdout.reconfigure(encoding="utf-8")
|
||||
if hasattr(sys.stderr, "reconfigure"):
|
||||
sys.stderr.reconfigure(encoding="utf-8")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 确保能导入 src.*
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.knowledge import lpmm_start_up
|
||||
from src.memory_system.retrieval_tools.query_lpmm_knowledge import query_lpmm_knowledge
|
||||
|
||||
logger = get_logger("test_lpmm_retrieval")
|
||||
|
||||
|
||||
DEFAULT_TEST_CASES: List[Dict[str, Any]] = [
|
||||
{
|
||||
"name": "回滚一批知识",
|
||||
"query": "LPMM是什么?",
|
||||
"expect_keywords": ["哈希列表", "删除脚本", "OpenIE"],
|
||||
},
|
||||
{
|
||||
"name": "调整 LPMM 检索参数",
|
||||
"query": "不同用词习惯带来的检索偏差该如何解决",
|
||||
"expect_keywords": ["bot_config.toml", "lpmm_knowledge", "qa_paragraph_search_top_k"],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
async def run_tests(test_cases: Optional[List[Dict[str, Any]]] = None) -> None:
|
||||
"""简单测试 LPMM 知识库检索能力"""
|
||||
if not global_config.lpmm_knowledge.enable:
|
||||
logger.warning("当前配置中 lpmm_knowledge.enable 为 False,检索测试可能直接返回“未启用”。")
|
||||
|
||||
logger.info("开始初始化 LPMM 知识库...")
|
||||
lpmm_start_up()
|
||||
logger.info("LPMM 知识库初始化完成,开始执行测试用例。")
|
||||
|
||||
cases = test_cases if test_cases is not None else DEFAULT_TEST_CASES
|
||||
|
||||
for case in cases:
|
||||
name = case["name"]
|
||||
query = case["query"]
|
||||
expect_keywords: List[str] = case.get("expect_keywords", [])
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print(f"[TEST] {name}")
|
||||
print(f"[Q] {query}")
|
||||
|
||||
result = await query_lpmm_knowledge(query, limit=3)
|
||||
|
||||
print("\n[RAW RESULT]")
|
||||
print(result)
|
||||
|
||||
status = "UNKNOWN"
|
||||
hit_keywords: List[str] = []
|
||||
|
||||
if isinstance(result, str):
|
||||
if "未启用" in result or "未初始化" in result or "查询失败" in result:
|
||||
status = "ERROR"
|
||||
elif "未找到与" in result:
|
||||
status = "NO_HIT"
|
||||
else:
|
||||
if expect_keywords:
|
||||
hit_keywords = [kw for kw in expect_keywords if kw in result]
|
||||
status = "PASS" if hit_keywords else "WARN"
|
||||
else:
|
||||
status = "PASS"
|
||||
|
||||
print("\n[CHECK]")
|
||||
print(f"Status: {status}")
|
||||
if expect_keywords:
|
||||
print(f"Expected keywords: {expect_keywords}")
|
||||
print(f"Hit keywords: {hit_keywords}")
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("LPMM 检索测试完成。请根据每条用例的 Status 和命中关键词判断检索效果是否符合预期。")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description=(
|
||||
"测试 LPMM 知识库检索能力。\n"
|
||||
"如不提供参数,则执行内置的默认用例;\n"
|
||||
"也可以通过 --query 与 --expect-keyword 自定义一条测试用例。"
|
||||
)
|
||||
)
|
||||
parser.add_argument(
|
||||
"--query",
|
||||
help="自定义测试问题(单条)。提供该参数时,将仅运行这一条用例。",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--expect-keyword",
|
||||
action="append",
|
||||
help="期望在检索结果中出现的关键字,可重复多次指定;仅在提供 --query 时生效。",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.query:
|
||||
custom_case = {
|
||||
"name": "custom",
|
||||
"query": args.query,
|
||||
"expect_keywords": args.expect_keyword or [],
|
||||
}
|
||||
asyncio.run(run_tests([custom_case]))
|
||||
else:
|
||||
asyncio.run(run_tests())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
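The PASS/WARN/NO_HIT/ERROR classification inside `run_tests` is a small pure function once extracted. A minimal sketch of the same rules (the function name is an assumption):

```python
def classify_result(result: str, expect_keywords: list[str]) -> tuple[str, list[str]]:
    """Mirror the status rules in run_tests: ERROR / NO_HIT / PASS / WARN."""
    if "未启用" in result or "未初始化" in result or "查询失败" in result:
        return "ERROR", []
    if "未找到与" in result:
        return "NO_HIT", []
    hits = [kw for kw in expect_keywords if kw in result]
    if expect_keywords:
        return ("PASS" if hits else "WARN"), hits
    return "PASS", []
```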
|
||||
@@ -55,7 +55,7 @@ BOT_CONFIG_PATH: Path = (CONFIG_DIR / "bot_config.toml").resolve().absolute()
 MODEL_CONFIG_PATH: Path = (CONFIG_DIR / "model_config.toml").resolve().absolute()
 LEGACY_ENV_PATH: Path = (PROJECT_ROOT / ".env").resolve().absolute()
 MMC_VERSION: str = "1.0.0"
-CONFIG_VERSION: str = "8.9.3"
+CONFIG_VERSION: str = "8.9.4"
 MODEL_CONFIG_VERSION: str = "1.14.1"

 logger = get_logger("config")

@@ -1090,6 +1090,15 @@ class DebugConfig(ConfigBase):
     __ui_label__ = "其他"
     __ui_icon__ = "more-horizontal"

+    enable_maisaka_stage_board: bool = Field(
+        default=True,
+        json_schema_extra={
+            "x-widget": "switch",
+            "x-icon": "layout-dashboard",
+        },
+    )
+    """是否启用 Maisaka 阶段看板"""
+
     show_prompt: bool = Field(
         default=False,
         json_schema_extra={

@@ -66,7 +66,8 @@ class MainSystem:

     async def initialize(self) -> None:
         """初始化系统组件"""
-        enable_stage_status_board()
+        if global_config.debug.enable_maisaka_stage_board:
+            enable_stage_status_board()
         logger.info(t("startup.waking_up", nickname=global_config.bot.nickname))

         # 其他初始化任务

@@ -160,6 +160,7 @@ async def handle_tool(
 
     reply_segments = tool_ctx.post_process_reply_text(reply_text)
     combined_reply_text = "".join(reply_segments)
+    sent_message_ids: list[str] = []
     try:
         sent = False
         if tool_ctx.runtime.chat_stream.platform == CLI_PLATFORM_NAME:
@@ -181,6 +182,9 @@ async def handle_tool(
             sent = sent_message is not None
             if not sent:
                 break
+            sent_message_id = str(getattr(sent_message, "message_id", "") or "").strip()
+            if sent_message_id:
+                sent_message_ids.append(sent_message_id)
     except Exception:
         logger.exception(
             f"{tool_ctx.runtime.log_prefix} exception while sending text message, target message id={target_message_id}"
@@ -209,6 +213,7 @@ async def handle_tool(
     if tool_ctx.runtime.chat_stream.platform == CLI_PLATFORM_NAME:
         tool_ctx.append_guided_reply_to_chat_history(combined_reply_text)
     tool_ctx.runtime._record_reply_sent()
+    reply_metadata["sent_message_ids"] = sent_message_ids
     await tool_ctx.runtime.track_reply_effect(
         tool_call_id=invocation.call_id,
         target_message=target_message,
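The three `handle_tool` hunks above thread a `sent_message_ids` list through the send loop. The ID filtering they add can be isolated as a stand-alone sketch (`SimpleNamespace` stands in for real sent-message objects, which are assumptions here):

```python
from types import SimpleNamespace


def collect_sent_message_ids(sent_messages) -> list[str]:
    """Mirror of the filtering in the diff: keep only non-empty, stripped IDs."""
    sent_message_ids: list[str] = []
    for sent_message in sent_messages:
        # getattr guards a missing attribute; `or ""` guards a None value.
        sent_message_id = str(getattr(sent_message, "message_id", "") or "").strip()
        if sent_message_id:
            sent_message_ids.append(sent_message_id)
    return sent_message_ids


messages = [
    SimpleNamespace(message_id=" 42 "),  # surrounding whitespace is stripped
    SimpleNamespace(message_id=None),    # None -> "" -> dropped
    SimpleNamespace(),                   # attribute missing -> dropped
]
print(collect_sent_message_ids(messages))  # ['42']
```

This defensive `str(getattr(...) or "").strip()` pattern keeps the tracker's metadata free of blank or placeholder IDs.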
@@ -66,6 +66,7 @@ class FollowupMessageSnapshot:
     plain_text: str
     latency_seconds: float
     is_target_user: bool
+    quote_target_ids: List[str] = field(default_factory=list)
     attachments: List[Dict[str, Any]] = field(default_factory=list)
 
 
src/maisaka/reply_effect/quote_utils.py (new file, 32 lines)
@@ -0,0 +1,32 @@
+"""Quoted-message helpers for reply-effect tracking."""
+
+from typing import Any
+
+from src.common.data_models.message_component_data_model import MessageSequence, ReplyComponent
+
+
+def extract_quote_target_ids(message_sequence: MessageSequence | None) -> list[str]:
+    """Extract the target message IDs of quote replies from the message components."""
+
+    if message_sequence is None:
+        return []
+
+    target_ids: list[str] = []
+    for component in getattr(message_sequence, "components", []):
+        if not isinstance(component, ReplyComponent):
+            continue
+        target_message_id = str(component.target_message_id or "").strip()
+        if target_message_id:
+            target_ids.append(target_message_id)
+    return target_ids
+
+
+def message_id_from_context_message(message: Any) -> str:
+    """Best-effort lookup of the real message ID from a Maisaka context message."""
+
+    message_id = str(getattr(message, "message_id", "") or "").strip()
+    if message_id:
+        return message_id
+
+    original_message = getattr(message, "original_message", None)
+    return str(getattr(original_message, "message_id", "") or "").strip()
@@ -23,6 +23,7 @@ from .models import (
     UserSnapshot,
     now_iso,
 )
+from .quote_utils import extract_quote_target_ids
 from .path_utils import build_reply_effect_chat_dir_name
 from .scoring import (
     has_explicit_negative_feedback,
@@ -190,6 +191,7 @@ class ReplyEffectTracker:
                 plain_text=plain_text,
                 latency_seconds=round(latency_seconds, 3),
                 is_target_user=bool(record.target_user.user_id and user_id == record.target_user.user_id),
+                quote_target_ids=extract_quote_target_ids(message.raw_message),
                 attachments=extract_visual_attachments_from_sequence(message.raw_message),
             )
 
@@ -41,6 +41,7 @@ from .display.stage_status_board import remove_stage_status, update_stage_status
 from .reasoning_engine import MaisakaReasoningEngine
 from .reply_effect import ReplyEffectTracker
 from .reply_effect.image_utils import extract_visual_attachments_from_sequence
+from .reply_effect.quote_utils import extract_quote_target_ids, message_id_from_context_message
 from .tool_provider import MaisakaBuiltinToolProvider
 
 logger = get_logger("maisaka_runtime")
@@ -349,10 +350,12 @@ class MaisakaHeartFlowChatting:
                 continue
             snapshot.append(
                 {
+                    "message_id": message_id_from_context_message(message),
                     "source": message.source,
                     "role": message.role,
                     "timestamp": message.timestamp.isoformat(timespec="seconds"),
                     "text": text,
+                    "quote_target_ids": extract_quote_target_ids(getattr(message, "raw_message", None)),
                     "attachments": extract_visual_attachments_from_sequence(getattr(message, "raw_message", None)),
                 }
             )
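The snapshot hunk above resolves each context message's ID via `message_id_from_context_message` (falling back to `original_message`) and formats timestamps with `timespec="seconds"`. A minimal sketch of those two details, with a hypothetical `Msg` class standing in for the real context-message type:

```python
from datetime import datetime, timezone


class Msg:
    """Hypothetical context message: empty own ID, real ID on original_message."""

    def __init__(self):
        self.message_id = ""  # blank, so the helper must fall back
        self.original_message = type("Original", (), {"message_id": "msg-7"})()
        self.timestamp = datetime(2026, 4, 19, 12, 30, 45, 999999, tzinfo=timezone.utc)


def message_id_from_context_message(message) -> str:
    # Same fallback chain as the quote_utils helper.
    message_id = str(getattr(message, "message_id", "") or "").strip()
    if message_id:
        return message_id
    original_message = getattr(message, "original_message", None)
    return str(getattr(original_message, "message_id", "") or "").strip()


m = Msg()
entry = {
    "message_id": message_id_from_context_message(m),          # -> "msg-7" via original_message
    "timestamp": m.timestamp.isoformat(timespec="seconds"),    # microseconds dropped
}
print(entry)
```

`timespec="seconds"` truncates the fractional part, so the snapshot stays compact and stable across repeated serializations.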