Merge branch 'main' into helm-charts/0.12.0
This commit is contained in:
160
.github/workflows/docker-image-dev.yml
vendored
Normal file
160
.github/workflows/docker-image-dev.yml
vendored
Normal file
@@ -0,0 +1,160 @@
|
||||
name: Docker Build and Push (Dev)
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: '0 0 * * *' # every day at midnight UTC
|
||||
# branches:
|
||||
# - dev
|
||||
workflow_dispatch: # 允许手动触发工作流
|
||||
inputs:
|
||||
branch:
|
||||
description: 'Branch to build'
|
||||
required: false
|
||||
default: 'dev'
|
||||
|
||||
# Workflow's jobs
|
||||
jobs:
|
||||
build-amd64:
|
||||
name: Build AMD64 Image
|
||||
runs-on: ubuntu-24.04
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
steps:
|
||||
- name: Check out git repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: dev
|
||||
fetch-depth: 0
|
||||
|
||||
# Clone required dependencies
|
||||
# - name: Clone maim_message
|
||||
# run: git clone https://github.com/MaiM-with-u/maim_message maim_message
|
||||
|
||||
- name: Clone lpmm
|
||||
run: git clone https://github.com/Mai-with-u/MaiMBot-LPMM.git MaiMBot-LPMM
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
with:
|
||||
buildkitd-flags: --debug
|
||||
|
||||
# Log in docker hub
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
# Generate metadata for Docker images
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ secrets.DOCKERHUB_USERNAME }}/maibot
|
||||
|
||||
# Build and push AMD64 image by digest
|
||||
- name: Build and push AMD64
|
||||
id: build
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
platforms: linux/amd64
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
file: ./Dockerfile
|
||||
cache-from: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/maibot:dev-amd64-buildcache
|
||||
cache-to: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/maibot:dev-amd64-buildcache,mode=max
|
||||
outputs: type=image,name=${{ secrets.DOCKERHUB_USERNAME }}/maibot,push-by-digest=true,name-canonical=true,push=true
|
||||
build-args: |
|
||||
BUILD_DATE=$(date -u +'%Y-%m-%dT%H:%M:%SZ')
|
||||
VCS_REF=${{ github.sha }}
|
||||
|
||||
build-arm64:
|
||||
name: Build ARM64 Image
|
||||
runs-on: ubuntu-24.04-arm
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
steps:
|
||||
- name: Check out git repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: dev
|
||||
fetch-depth: 0
|
||||
|
||||
# Clone required dependencies
|
||||
# - name: Clone maim_message
|
||||
# run: git clone https://github.com/MaiM-with-u/maim_message maim_message
|
||||
|
||||
- name: Clone lpmm
|
||||
run: git clone https://github.com/Mai-with-u/MaiMBot-LPMM.git MaiMBot-LPMM
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
with:
|
||||
buildkitd-flags: --debug
|
||||
|
||||
# Log in docker hub
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
# Generate metadata for Docker images
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ secrets.DOCKERHUB_USERNAME }}/maibot
|
||||
|
||||
# Build and push ARM64 image by digest
|
||||
- name: Build and push ARM64
|
||||
id: build
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
platforms: linux/arm64/v8
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
file: ./Dockerfile
|
||||
cache-from: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/maibot:dev-arm64-buildcache
|
||||
cache-to: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/maibot:dev-arm64-buildcache,mode=max
|
||||
outputs: type=image,name=${{ secrets.DOCKERHUB_USERNAME }}/maibot,push-by-digest=true,name-canonical=true,push=true
|
||||
build-args: |
|
||||
BUILD_DATE=$(date -u +'%Y-%m-%dT%H:%M:%SZ')
|
||||
VCS_REF=${{ github.sha }}
|
||||
|
||||
create-manifest:
|
||||
name: Create Multi-Arch Manifest
|
||||
runs-on: ubuntu-24.04
|
||||
needs:
|
||||
- build-amd64
|
||||
- build-arm64
|
||||
steps:
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
# Log in docker hub
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
# Generate metadata for Docker images
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ secrets.DOCKERHUB_USERNAME }}/maibot
|
||||
tags: |
|
||||
type=raw,value=dev
|
||||
type=schedule,pattern=dev-{{date 'YYMMDD'}}
|
||||
|
||||
- name: Create and Push Manifest
|
||||
run: |
|
||||
# 为每个标签创建多架构镜像
|
||||
for tag in $(echo "${{ steps.meta.outputs.tags }}" | tr '\n' ' '); do
|
||||
echo "Creating manifest for $tag"
|
||||
docker buildx imagetools create -t $tag \
|
||||
${{ secrets.DOCKERHUB_USERNAME }}/maibot@${{ needs.build-amd64.outputs.digest }} \
|
||||
${{ secrets.DOCKERHUB_USERNAME }}/maibot@${{ needs.build-arm64.outputs.digest }}
|
||||
done
|
||||
@@ -1,8 +1,6 @@
|
||||
name: Docker Build and Push
|
||||
name: Docker Build and Push (Main)
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: '0 0 * * *'
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
@@ -13,6 +11,11 @@ on:
|
||||
- "*.*.*"
|
||||
- "*.*.*-*"
|
||||
workflow_dispatch: # 允许手动触发工作流
|
||||
inputs:
|
||||
branch:
|
||||
description: 'Branch to build'
|
||||
required: false
|
||||
default: 'main'
|
||||
|
||||
# Workflow's jobs
|
||||
jobs:
|
||||
@@ -25,15 +28,14 @@ jobs:
|
||||
- name: Check out git repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.event_name == 'schedule' && 'dev' || github.ref }}
|
||||
fetch-depth: 0
|
||||
|
||||
# Clone required dependencies
|
||||
- name: Clone maim_message
|
||||
run: git clone https://github.com/MaiM-with-u/maim_message maim_message
|
||||
# - name: Clone maim_message
|
||||
# run: git clone https://github.com/MaiM-with-u/maim_message maim_message
|
||||
|
||||
- name: Clone lpmm
|
||||
run: git clone https://github.com/MaiM-with-u/MaiMBot-LPMM.git MaiMBot-LPMM
|
||||
run: git clone https://github.com/Mai-with-u/MaiMBot-LPMM.git MaiMBot-LPMM
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
@@ -79,15 +81,14 @@ jobs:
|
||||
- name: Check out git repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.event_name == 'schedule' && 'dev' || github.ref }}
|
||||
fetch-depth: 0
|
||||
|
||||
# Clone required dependencies
|
||||
- name: Clone maim_message
|
||||
run: git clone https://github.com/MaiM-with-u/maim_message maim_message
|
||||
# - name: Clone maim_message
|
||||
# run: git clone https://github.com/MaiM-with-u/maim_message maim_message
|
||||
|
||||
- name: Clone lpmm
|
||||
run: git clone https://github.com/MaiM-with-u/MaiMBot-LPMM.git MaiMBot-LPMM
|
||||
run: git clone https://github.com/Mai-with-u/MaiMBot-LPMM.git MaiMBot-LPMM
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
@@ -164,4 +165,4 @@ jobs:
|
||||
docker buildx imagetools create -t $tag \
|
||||
${{ secrets.DOCKERHUB_USERNAME }}/maibot@${{ needs.build-amd64.outputs.digest }} \
|
||||
${{ secrets.DOCKERHUB_USERNAME }}/maibot@${{ needs.build-arm64.outputs.digest }}
|
||||
done
|
||||
done
|
||||
10
.gitignore
vendored
10
.gitignore
vendored
@@ -35,9 +35,6 @@ message_queue_content.bat
|
||||
message_queue_window.bat
|
||||
message_queue_window.txt
|
||||
queue_update.txt
|
||||
memory_graph.gml
|
||||
/src/tools/tool_can_use/auto_create_tool.py
|
||||
/src/tools/tool_can_use/execute_python_code_tool.py
|
||||
.env
|
||||
.env.*
|
||||
.cursor
|
||||
@@ -48,9 +45,6 @@ config/lpmm_config.toml
|
||||
config/lpmm_config.toml.bak
|
||||
template/compare/bot_config_template.toml
|
||||
template/compare/model_config_template.toml
|
||||
(测试版)麦麦生成人格.bat
|
||||
(临时版)麦麦开始学习.bat
|
||||
src/plugins/utils/statistic.py
|
||||
CLAUDE.md
|
||||
MaiBot-Dashboard/
|
||||
cloudflare-workers/
|
||||
@@ -327,6 +321,7 @@ run_pet.bat
|
||||
!/plugins/emoji_manage_plugin
|
||||
!/plugins/take_picture_plugin
|
||||
!/plugins/deep_think
|
||||
!/plugins/MaiBot_MCPBridgePlugin
|
||||
!/plugins/ChatFrequency/
|
||||
!/plugins/__init__.py
|
||||
|
||||
@@ -334,4 +329,5 @@ config.toml
|
||||
|
||||
interested_rates.txt
|
||||
MaiBot.code-workspace
|
||||
*.lock
|
||||
*.lock
|
||||
actionlint
|
||||
|
||||
@@ -71,7 +71,6 @@
|
||||
|
||||
1. **GitHub Issues**: 对于公开的违规行为,可以在相关issue中直接指出
|
||||
2. **私下联系**: 可以通过GitHub私信联系项目维护者
|
||||
3. **邮件联系**: [如果有项目邮箱地址,请在此提供]
|
||||
|
||||
所有报告都将得到及时和公正的处理。我们承诺保护报告者的隐私和安全。
|
||||
|
||||
|
||||
10
README.md
10
README.md
@@ -50,12 +50,16 @@
|
||||
|
||||
|
||||
可前往 [Release](https://github.com/MaiM-with-u/MaiBot/releases/) 页面下载最新版本
|
||||
|
||||
可前往 [启动器发布页面](https://github.com/MaiM-with-u/mailauncher/releases/)下载最新启动器
|
||||
|
||||
注意,启动器处于早期开发版本,仅支持MacOS
|
||||
|
||||
**GitHub 分支说明:**
|
||||
- `main`: 稳定发布版本(推荐)
|
||||
|
||||
|
||||
- `dev`: 开发测试版本(不稳定)
|
||||
|
||||
- `classical`: 经典版本(停止维护)
|
||||
|
||||
### 最新版本部署教程
|
||||
@@ -69,7 +73,7 @@
|
||||
|
||||
## 💬 讨论
|
||||
|
||||
**技术交流群:**
|
||||
**技术交流群/答疑群:**
|
||||
[麦麦脑电图](https://qm.qq.com/q/RzmCiRtHEW) |
|
||||
[麦麦大脑磁共振](https://qm.qq.com/q/VQ3XZrWgMs) |
|
||||
[麦麦要当VTB](https://qm.qq.com/q/wGePTl1UyY) |
|
||||
@@ -79,7 +83,7 @@
|
||||
**聊天吹水群:**
|
||||
- [麦麦之闲聊群](https://qm.qq.com/q/JxvHZnxyec)
|
||||
|
||||
麦麦相关闲聊群
|
||||
麦麦相关闲聊群,此群仅用于聊天,提问部署/技术问题可能不会快速得到答案
|
||||
|
||||
**插件开发/测试版讨论群:**
|
||||
- [插件开发群](https://qm.qq.com/q/1036092828)
|
||||
|
||||
27
bot.py
27
bot.py
@@ -34,13 +34,17 @@ else:
|
||||
print(f"自动创建 .env 失败: {e}")
|
||||
raise
|
||||
|
||||
initialize_logging()
|
||||
# 检查是否是 Worker 进程,只在 Worker 进程中输出详细的初始化信息
|
||||
# Runner 进程只需要基本的日志功能,不需要详细的初始化日志
|
||||
is_worker = os.environ.get("MAIBOT_WORKER_PROCESS") == "1"
|
||||
initialize_logging(verbose=is_worker)
|
||||
install(extra_lines=3)
|
||||
logger = get_logger("main")
|
||||
|
||||
# 定义重启退出码
|
||||
RESTART_EXIT_CODE = 42
|
||||
|
||||
|
||||
def run_runner_process():
|
||||
"""
|
||||
Runner 进程逻辑:作为守护进程运行,负责启动和监控 Worker 进程。
|
||||
@@ -55,25 +59,25 @@ def run_runner_process():
|
||||
|
||||
while True:
|
||||
logger.info(f"正在启动 {script_file}...")
|
||||
|
||||
|
||||
# 启动子进程 (Worker)
|
||||
# 使用 sys.executable 确保使用相同的 Python 解释器
|
||||
cmd = [python_executable, script_file] + sys.argv[1:]
|
||||
|
||||
|
||||
process = subprocess.Popen(cmd, env=env)
|
||||
|
||||
|
||||
try:
|
||||
# 等待子进程结束
|
||||
return_code = process.wait()
|
||||
|
||||
|
||||
if return_code == RESTART_EXIT_CODE:
|
||||
logger.info("检测到重启请求 (退出码 42),正在重启...")
|
||||
time.sleep(1) # 稍作等待
|
||||
time.sleep(1) # 稍作等待
|
||||
continue
|
||||
else:
|
||||
logger.info(f"程序已退出 (退出码 {return_code})")
|
||||
sys.exit(return_code)
|
||||
|
||||
|
||||
except KeyboardInterrupt:
|
||||
# 向子进程发送终止信号
|
||||
if process.poll() is None:
|
||||
@@ -87,6 +91,7 @@ def run_runner_process():
|
||||
process.kill()
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
# 检查是否是 Worker 进程
|
||||
# 如果没有设置 MAIBOT_WORKER_PROCESS 环境变量,说明是直接运行的脚本,
|
||||
# 此时应该作为 Runner 运行。
|
||||
@@ -99,8 +104,10 @@ if os.environ.get("MAIBOT_WORKER_PROCESS") != "1":
|
||||
# 以下是 Worker 进程的逻辑
|
||||
|
||||
# 最早期初始化日志系统,确保所有后续模块都使用正确的日志格式
|
||||
# from src.common.logger import initialize_logging, get_logger, shutdown_logging # noqa
|
||||
# initialize_logging()
|
||||
# 注意:Runner 进程已经在第 37 行初始化了日志系统,但 Worker 进程是独立进程,需要重新初始化
|
||||
# 由于 Runner 和 Worker 是不同进程,它们有独立的内存空间,所以都会初始化一次
|
||||
# 这是正常的,但为了避免重复的初始化日志,我们在 initialize_logging() 中添加了防重复机制
|
||||
# 不过由于是不同进程,每个进程仍会初始化一次,这是预期的行为
|
||||
|
||||
from src.main import MainSystem # noqa
|
||||
from src.manager.async_task_manager import async_task_manager # noqa
|
||||
@@ -143,7 +150,7 @@ def print_opensource_notice():
|
||||
"",
|
||||
f"{Fore.WHITE} 官方仓库: {Fore.BLUE}https://github.com/MaiM-with-u/MaiBot {Style.RESET_ALL}",
|
||||
f"{Fore.WHITE} 官方文档: {Fore.BLUE}https://docs.mai-mai.org {Style.RESET_ALL}",
|
||||
f"{Fore.WHITE} 官方群聊: {Fore.BLUE}766798517{Style.RESET_ALL}",
|
||||
f"{Fore.WHITE} 官方群聊: {Fore.BLUE}1006149251{Style.RESET_ALL}",
|
||||
f"{Fore.CYAN}{'─' * 70}{Style.RESET_ALL}",
|
||||
f"{Fore.RED} ⚠ 将本软件作为「商品」倒卖、隐瞒开源性质均违反协议!{Style.RESET_ALL}",
|
||||
f"{Fore.CYAN}{'═' * 70}{Style.RESET_ALL}",
|
||||
|
||||
@@ -1,5 +1,21 @@
|
||||
# Changelog
|
||||
|
||||
## [0.12.0] - 2025-12-16
|
||||
### 🌟 重大更新
|
||||
- 添加思考力度机制,动态控制回复时间和长度
|
||||
- planner和replyer现在开启联动,更好的回复逻辑
|
||||
- 新的私聊系统,吸收了pfc的优秀机制
|
||||
- 增加麦麦做梦功能
|
||||
- mcp插件作为内置插件加入,默认不启用
|
||||
- 添加全局记忆配置项,现在可以选择让记忆为全局的
|
||||
|
||||
### 细节功能更改
|
||||
- 移除频率自动调整
|
||||
- 移除情绪功能
|
||||
- 优化记忆差许多呢超时设置
|
||||
- 部分配置为0的bug
|
||||
- 黑话和表达不再提取包含名称的内容
|
||||
|
||||
## [0.11.6] - 2025-12-2
|
||||
### 🌟 重大更新
|
||||
- 大幅提高记忆检索能力,略微提高token消耗
|
||||
|
||||
@@ -27,7 +27,7 @@ services:
|
||||
# image: infinitycat/maibot:dev
|
||||
environment:
|
||||
- TZ=Asia/Shanghai
|
||||
# - EULA_AGREE=99f08e0cab0190de853cb6af7d64d4de # 同意EULA
|
||||
# - EULA_AGREE=1b662741904d7155d1ce1c00b3530d0d # 同意EULA
|
||||
# - PRIVACY_AGREE=9943b855e72199d0f5016ea39052f1b6 # 同意EULA
|
||||
ports:
|
||||
- "18001:8001" # webui端口
|
||||
@@ -40,7 +40,7 @@ services:
|
||||
- ./data/MaiMBot:/MaiMBot/data # 共享目录
|
||||
- ./data/MaiMBot/plugins:/MaiMBot/plugins # 插件目录
|
||||
- ./data/MaiMBot/logs:/MaiMBot/logs # 日志目录
|
||||
- site-packages:/usr/local/lib/python3.13/site-packages # 持久化Python包
|
||||
# - site-packages:/usr/local/lib/python3.13/site-packages # 持久化Python包,需要时启用
|
||||
restart: always
|
||||
networks:
|
||||
- maim_bot
|
||||
@@ -87,8 +87,8 @@ services:
|
||||
# networks:
|
||||
# - maim_bot
|
||||
|
||||
volumes:
|
||||
site-packages:
|
||||
# volumes: # 若需要持久化Python包时启用
|
||||
# site-packages:
|
||||
networks:
|
||||
maim_bot:
|
||||
driver: bridge
|
||||
|
||||
156
docs-src/lpmm_parameters_guide.md
Normal file
156
docs-src/lpmm_parameters_guide.md
Normal file
@@ -0,0 +1,156 @@
|
||||
# LPMM 关键参数调节指南(进阶版)
|
||||
|
||||
> 本文是对 `config/bot_config.toml` 中 `[lpmm_knowledge]` 段的补充说明。
|
||||
> 如果你只想使用默认配置,可以不改这些参数,脚本仍然可以正常工作。
|
||||
>
|
||||
> 重要提醒:无论是修改 `[lpmm_knowledge]` 段的参数,还是通过脚本导入 / 删除 LPMM 知识库数据,主程序都需要重启(或在内部调用一次 `lpmm_start_up()`)后,新的参数和知识才会真正生效到聊天侧。
|
||||
|
||||
所有与 LPMM 相关的参数,都集中在:
|
||||
|
||||
```toml
|
||||
[lpmm_knowledge] # lpmm知识库配置
|
||||
enable = true
|
||||
lpmm_mode = "agent"
|
||||
...
|
||||
```
|
||||
|
||||
下面按功能将常用参数分为三组介绍。
|
||||
|
||||
---
|
||||
|
||||
## 一、检索相关参数(影响答案质量与风格)
|
||||
|
||||
```toml
|
||||
qa_relation_search_top_k = 10 # 关系检索TopK
|
||||
qa_relation_threshold = 0.5 # 关系阈值,相似度高于该值才认为“命中关系”
|
||||
qa_paragraph_search_top_k = 1000 # 段落检索TopK,越小可能影响召回
|
||||
qa_paragraph_node_weight = 0.05 # 段落节点权重,在图检索&PPR中的权重
|
||||
qa_ent_filter_top_k = 10 # 实体过滤TopK
|
||||
qa_ppr_damping = 0.8 # PPR阻尼系数
|
||||
qa_res_top_k = 3 # 最终提供给问答模型的段落数
|
||||
```
|
||||
|
||||
- `qa_relation_search_top_k`
|
||||
控制“最多考虑多少条关系向量候选”。
|
||||
- 数值大:召回更全面,但略慢;
|
||||
- 数值小:更快,可能遗漏部分隐含关系。
|
||||
|
||||
- `qa_relation_threshold`
|
||||
关系相似度的阈值:
|
||||
- 数值高:只信任非常相关的关系,系统更可能退化为纯段落向量检索;
|
||||
- 数值低:图结构影响更大,适合实体关系较丰富的场景。
|
||||
|
||||
- `qa_paragraph_search_top_k`
|
||||
控制“最多考虑多少段落候选”。
|
||||
- 太小:可能召回不全,导致答案缺失;
|
||||
- 太大:略微增加计算量,一般 1000 为安全默认。
|
||||
|
||||
- `qa_paragraph_node_weight`
|
||||
文段节点在图检索中的权重:
|
||||
- 数值大:更依赖段落向量相似度(传统向量检索);
|
||||
- 数值小:更依赖图结构和实体网络。
|
||||
|
||||
- `qa_ppr_damping`
|
||||
Personalized PageRank 的阻尼系数:
|
||||
- 通常保持在 0.8 左右即可;
|
||||
- 越接近 1:偏向长路径探索,结果更发散;
|
||||
- 略低:更集中在与问题直接相关的节点附近。
|
||||
|
||||
- `qa_res_top_k`
|
||||
LPMM 最终会把相关度最高的前 `qa_res_top_k` 条段落组合成“知识上下文”给问答模型。
|
||||
- 太多:增加模型负担、阅读更多文字;
|
||||
- 太少:信息不够充分,一般 3–5 比较平衡。
|
||||
|
||||
> 调参建议:
|
||||
> - 优先在 `qa_relation_threshold`、`qa_paragraph_node_weight` 上做小幅调整;
|
||||
> - 每次调整后,用 `scripts/test_lpmm_retrieval.py` 跑一遍固定问题,感受回答变化。
|
||||
|
||||
---
|
||||
|
||||
## 二、性能与硬件相关参数
|
||||
|
||||
```toml
|
||||
embedding_dimension = 1024 # 嵌入向量维度,应与模型输出维度一致
|
||||
max_embedding_workers = 12 # 嵌入/抽取并发线程数
|
||||
embedding_chunk_size = 16 # 每批嵌入的条数
|
||||
info_extraction_workers = 3 # 实体抽取同时执行线程数
|
||||
enable_ppr = true # 是否启用PPR,低配机器可关闭
|
||||
```
|
||||
|
||||
- `embedding_dimension`
|
||||
必须与所选嵌入模型的输出维度一致(比如 768、1024 等)。**不要随意修改,除非你知道你在做什么!!!**
|
||||
|
||||
- `max_embedding_workers`
|
||||
决定导入/抽取阶段的并行线程数:
|
||||
- 机器配置好:可以适当调大,加快导入速度;
|
||||
- 机器配置弱:建议调低(如 2 或 4),避免 CPU 长时间 100%。
|
||||
|
||||
- `embedding_chunk_size`
|
||||
每批发送给嵌入 API 的段落数量:
|
||||
- 数值大:请求次数少,但单次请求更“重”;
|
||||
- 数值小:请求次数多,但对网络和 API 的单次压力小。
|
||||
|
||||
- `info_extraction_workers`
|
||||
`scripts/info_extraction.py` 中实体抽取的并行线程数:
|
||||
- 使用 Pro/贵价模型时建议不要太大,避免并行费用过高;
|
||||
- 一般 2–4 就能取得较好平衡。
|
||||
|
||||
- `enable_ppr`
|
||||
是否启用个性化 PageRank(PPR)图检索:
|
||||
- `true`:检索会结合向量+知识图,效果更好,但略慢;
|
||||
- `false`:只用向量检索,牺牲一定效果,性能更稳定。
|
||||
|
||||
|
||||
> 调参建议:
|
||||
> - 若导入/检索阶段机器明显“顶不住”(>=1MB的大文本,且分配配置<4C),优先调低:
|
||||
> - `max_embedding_workers`
|
||||
> - `embedding_chunk_size`
|
||||
> - `info_extraction_workers`
|
||||
> - 或暂时将 `enable_ppr = false` (除非真的出现问题,否则不建议禁用此项,大幅影响检索效果)
|
||||
> - 调整后重新执行导入或检索,观察日志与系统资源占用。
|
||||
|
||||
> 小提示:每次大改参数或批量删除知识后,建议用
|
||||
> - `scripts/test_lpmm_retrieval.py` 看回答风格是否如预期;
|
||||
> - 如需确认当前磁盘数据能否正常初始化,可执行 `scripts/refresh_lpmm_knowledge.py` 做一次快速自检。
|
||||
|
||||
---
|
||||
|
||||
## 三、开启/关闭 LPMM 与模式说明
|
||||
|
||||
```toml
|
||||
enable = true # 是否开启lpmm知识库
|
||||
lpmm_mode = "agent" # 可选 classic / agent
|
||||
```
|
||||
|
||||
- `enable`
|
||||
- `true`:LPMM 知识库启用,检索和问答会使用知识库;
|
||||
- `false`:LPMM 完全关闭,脚本仍可导入/删除数据,但对聊天问答不生效。
|
||||
|
||||
- `lpmm_mode`
|
||||
- `classic`:传统模式,仅使用 LPMM 知识库本身;
|
||||
- `agent`:与新的记忆系统联动,用于更复杂的记忆+知识混合场景。
|
||||
|
||||
> 修改 `enable` 或 `lpmm_mode` 后,需要重启主程序,让配置生效。
|
||||
|
||||
---
|
||||
|
||||
## 四、推荐的调参流程
|
||||
|
||||
1. **保持默认配置,先跑一轮完整流程**
|
||||
- 导入 → `inspect_lpmm_global.py` → `test_lpmm_retrieval.py`;
|
||||
- 记录当前“答案风格”和“响应速度”。
|
||||
|
||||
2. **每次只调整一到两个参数**
|
||||
- 例如先调 `qa_relation_threshold`、`qa_paragraph_node_weight`;
|
||||
- 或在性能不佳时调整 `max_embedding_workers`、`enable_ppr`。
|
||||
|
||||
3. **调整后重复同一组测试问题**
|
||||
- 使用 `scripts/test_lpmm_retrieval.py`;
|
||||
- 对比不同配置下的答案,选择更符合需求的组合。
|
||||
|
||||
4. **出现“怎么调都不对”时**
|
||||
- 将 `[lpmm_knowledge]` 段恢复为仓库中的默认配置;
|
||||
- 重启主程序,即可回到“出厂设置”。
|
||||
|
||||
通过本指南中的参数调节,你可以在“检索质量”“响应速度”“系统资源占用”之间找到适合自己麦麦和机器的平衡点!
|
||||
|
||||
326
docs-src/lpmm_pipelines_guide.md
Normal file
326
docs-src/lpmm_pipelines_guide.md
Normal file
@@ -0,0 +1,326 @@
|
||||
## LPMM 知识库流水线使用指南(命令行版)
|
||||
|
||||
本文档介绍如何使用 `scripts/lpmm_manager.py` 及相关子脚本,完成 **导入 / 删除 / 自检 / 刷新 / 回归测试** 等常见流水线操作,并说明各参数在交互式与非交互(脚本化)场景下的用法。
|
||||
|
||||
所有命令均假设在项目根目录 `MaiBot/` 下执行:
|
||||
|
||||
```bash
|
||||
cd MaiBot
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 1. 管理脚本总览:`scripts/lpmm_manager.py`
|
||||
|
||||
### 1.1 基本用法
|
||||
|
||||
```bash
|
||||
python scripts/lpmm_manager.py [--interactive] [-a ACTION] [--non-interactive] [-- ...子脚本参数...]
|
||||
```
|
||||
|
||||
- `--interactive` / `-i`:进入交互式菜单模式(推荐人工运维时使用)。
|
||||
- `--action` / `-a`:直接执行指定操作(非交互入口),可选值:
|
||||
- `prepare_raw`:预处理 `data/lpmm_raw_data/*.txt`。
|
||||
- `info_extract`:信息抽取,生成 OpenIE JSON 批次。
|
||||
- `import_openie`:导入 OpenIE 批次到向量库与知识图。
|
||||
- `delete`:删除/回滚知识(封装 `delete_lpmm_items.py`)。
|
||||
- `batch_inspect`:检查指定 OpenIE 批次的存在情况。
|
||||
- `global_inspect`:全库状态统计。
|
||||
- `refresh`:刷新 LPMM 磁盘数据到内存。
|
||||
- `test`:检索效果回归测试。
|
||||
- `full_import`:一键执行「预处理原始语料 → 信息抽取 → 导入 → 刷新」。
|
||||
- `--non-interactive`:
|
||||
- 启用 **非交互模式**:`lpmm_manager` 自身不会再调用 `input()` 询问确认;
|
||||
- 同时自动向子脚本透传 `--non-interactive`(若子脚本支持),用于在 CI / 定时任务中实现无人值守。
|
||||
- `--` 之后的内容会原样传递给对应子脚本的 `main()`,用于设置更细粒度参数。
|
||||
|
||||
> 注意:`--interactive` 与 `--non-interactive` 互斥,不能同时使用。
|
||||
|
||||
---
|
||||
|
||||
## 2. 典型流水线一:全量导入(从原始 txt 到可用 LPMM)
|
||||
|
||||
### 2.1 前置条件
|
||||
|
||||
- 将待导入的原始文本放入:
|
||||
|
||||
```text
|
||||
data/lpmm_raw_data/*.txt
|
||||
```
|
||||
|
||||
- 文本按「空行分段」,每个段落为一条候选知识。
|
||||
|
||||
### 2.2 一键全流程(交互式)
|
||||
|
||||
```bash
|
||||
python scripts/lpmm_manager.py --interactive
|
||||
```
|
||||
|
||||
菜单中依次:
|
||||
|
||||
1. 选择 `9. full_import`(预处理 → 信息抽取 → 导入 → 刷新)。
|
||||
2. 按提示确认可能的费用与时间消耗。
|
||||
3. 等待脚本执行完成。
|
||||
|
||||
### 2.3 一键全流程(非交互 / CI 友好)
|
||||
|
||||
```bash
|
||||
python scripts/lpmm_manager.py -a full_import --non-interactive
|
||||
```
|
||||
|
||||
执行顺序:
|
||||
|
||||
1. `prepare_raw`:调用 `raw_data_preprocessor.load_raw_data()`,统计段落与去重哈希数。
|
||||
2. `info_extract`:调用 `info_extraction.main(--non-interactive)`,从 `data/lpmm_raw_data` 读取段落,生成 OpenIE JSON 并写入 `data/openie/`。
|
||||
3. `import_openie`:调用 `import_openie.main(--non-interactive)`,导入 OpenIE 批次到嵌入库与 KG。
|
||||
4. `refresh`:调用 `refresh_lpmm_knowledge.main()`,刷新 LPMM 知识库到内存。
|
||||
|
||||
在 `--non-interactive` 模式下:
|
||||
|
||||
- 若 `data/lpmm_raw_data` 中没有 `.txt` 文件,或 `data/openie` 中没有 `.json` 文件,将直接报错退出,并在日志中说明缺少的目录/文件。
|
||||
- 若 OpenIE 批次中存在非法文段,导入脚本会 **直接报错退出**,不会卡在交互确认上。
|
||||
|
||||
---
|
||||
|
||||
## 3. 典型流水线二:分步导入
|
||||
|
||||
若需要逐步调试或只执行部分步骤,可以分开调用:
|
||||
|
||||
### 3.1 预处理原始语料:`prepare_raw`
|
||||
|
||||
```bash
|
||||
python scripts/lpmm_manager.py -a prepare_raw
|
||||
```
|
||||
|
||||
行为:
|
||||
- 使用 `raw_data_preprocessor.load_raw_data()` 读取 `data/lpmm_raw_data/*.txt`;
|
||||
- 输出段落总数与去重后的哈希数,供人工检查原始数据质量。
|
||||
|
||||
### 3.2 信息抽取:`info_extract`
|
||||
|
||||
#### 交互式(带费用提示)
|
||||
|
||||
```bash
|
||||
python scripts/lpmm_manager.py -a info_extract
|
||||
```
|
||||
|
||||
脚本会:
|
||||
- 打印预计费用/时间提示;
|
||||
- 询问 `确认继续执行?(y/n)`;
|
||||
- 然后开始从 `data/lpmm_raw_data` 中读取段落,调用 LLM 提取实体与三元组,并生成 OpenIE JSON。
|
||||
|
||||
#### 非交互式(无人工确认)
|
||||
|
||||
```bash
|
||||
python scripts/lpmm_manager.py -a info_extract --non-interactive
|
||||
```
|
||||
|
||||
行为差异:
|
||||
- 跳过`确认继续执行`的交互提示,直接开始抽取;
|
||||
- 若 `data/lpmm_raw_data` 下没有 `.txt` 文件,会打印告警并以错误方式退出。
|
||||
|
||||
### 3.3 导入 OpenIE 批次:`import_openie`
|
||||
|
||||
#### 交互式
|
||||
|
||||
```bash
|
||||
python scripts/lpmm_manager.py -a import_openie
|
||||
```
|
||||
|
||||
脚本会:
|
||||
- 提示导入开销与资源占用情况;
|
||||
- 询问是否继续;
|
||||
- 调用 `OpenIE.load()` 加载批次,再将其导入嵌入库与 KG。
|
||||
|
||||
#### 非交互式
|
||||
|
||||
```bash
|
||||
python scripts/lpmm_manager.py -a import_openie --non-interactive
|
||||
```
|
||||
|
||||
- 跳过导入开销确认;
|
||||
- 若数据存在非法文段:
|
||||
- 在交互模式下会询问是否删除这些非法文段并继续;
|
||||
- 在非交互模式下,会直接 `logger.error` 并 `sys.exit(1)`,防止导入不完整数据。
|
||||
|
||||
> 提示:当前 `OpenIE.load()` 仍可能在内部要求你选择具体批次文件,若需完全无交互的导入,可后续扩展为显式指定文件路径。
|
||||
|
||||
### 3.4 刷新 LPMM 知识库:`refresh`
|
||||
|
||||
```bash
|
||||
python scripts/lpmm_manager.py -a refresh
|
||||
# 或
|
||||
python scripts/lpmm_manager.py -a refresh --non-interactive
|
||||
```
|
||||
|
||||
两者行为相同:
|
||||
- 调用 `refresh_lpmm_knowledge.main()`,内部执行 `lpmm_start_up()`;
|
||||
- 日志中输出当前向量与 KG 规模,验证导入是否成功。
|
||||
|
||||
---
|
||||
|
||||
## 4. 典型流水线三:删除 / 回滚
|
||||
|
||||
删除操作通过 `lpmm_manager.py -a delete` 封装 `scripts/delete_lpmm_items.py`。
|
||||
|
||||
### 4.1 交互式删除(推荐人工操作)
|
||||
|
||||
```bash
|
||||
python scripts/lpmm_manager.py --interactive
|
||||
```
|
||||
|
||||
菜单中选择:
|
||||
|
||||
1. `4. delete - 删除/回滚知识`
|
||||
2. 再选择删除方式:
|
||||
- 按哈希文件(`--hash-file`)
|
||||
- 按 OpenIE 批次(`--openie-file`)
|
||||
- 按原始语料 + 段落索引(`--raw-file + --raw-index`)
|
||||
- 按关键字搜索现有段落(`--search-text`)
|
||||
3. 管理脚本会根据你的选择自动拼好常用参数(是否删除实体/关系、是否删除孤立实体、是否 dry-run、是否自动确认等),最后调用 `delete_lpmm_items.py` 执行。
|
||||
|
||||
### 4.2 非交互删除(CI / 脚本场景)
|
||||
|
||||
#### 示例:按哈希文件删除(带完整保护参数)
|
||||
|
||||
```bash
|
||||
python scripts/lpmm_manager.py -a delete --non-interactive -- \
|
||||
--hash-file data/lpmm_delete_hashes.txt \
|
||||
--delete-entities \
|
||||
--delete-relations \
|
||||
--remove-orphan-entities \
|
||||
--max-delete-nodes 2000 \
|
||||
--yes
|
||||
```
|
||||
|
||||
- `--non-interactive`(manager):禁止任何 `input()` 询问;
|
||||
- 子脚本 `delete_lpmm_items.py` 中:
|
||||
- `--hash-file`:指定待删段落哈希列表;
|
||||
- `--delete-entities` / `--delete-relations` / `--remove-orphan-entities`:同步清理实体与关系;
|
||||
- `--max-delete-nodes`:单次删除节点数上限,避免误删过大规模;
|
||||
- `--yes`:跳过终极确认,适合已验证的自动流水线。
|
||||
|
||||
#### 按 OpenIE 批次删除(常用于批次回滚)
|
||||
|
||||
```bash
|
||||
python scripts/lpmm_manager.py -a delete --non-interactive -- \
|
||||
--openie-file data/openie/2025-01-01-12-00-openie.json \
|
||||
--delete-entities \
|
||||
--delete-relations \
|
||||
--remove-orphan-entities \
|
||||
--yes
|
||||
```
|
||||
|
||||
### 4.3 非交互模式下的安全限制
|
||||
|
||||
在 `delete_lpmm_items.py` 中:
|
||||
|
||||
- 若使用 `--search-text`,需要用户通过输入序号选择要删条目;
|
||||
- 在 `--non-interactive` 模式下,这一步会直接报错退出,提示改用 `--hash-file / --openie-file / --raw-file` 等纯参数方式。
|
||||
- 若未指定 `--yes`:
|
||||
- 非交互模式下会报错退出,提示「非交互模式且未指定 --yes,出于安全考虑删除操作已被拒绝」。
|
||||
|
||||
---
|
||||
|
||||
## 5. 典型流水线四:自检与状态检查
|
||||
|
||||
### 5.1 检查指定 OpenIE 批次状态:`batch_inspect`
|
||||
|
||||
```bash
|
||||
python scripts/lpmm_manager.py -a batch_inspect -- --openie-file data/openie/xx.json
|
||||
```
|
||||
|
||||
输出该批次在当前库中的:
|
||||
- 段落向量数量 / KG 段落节点数量;
|
||||
- 实体向量数量 / KG 实体节点数量;
|
||||
- 关系向量数量;
|
||||
- 少量仍存在的样例内容。
|
||||
|
||||
常用于:
|
||||
- 导入后确认是否完全成功;
|
||||
- 删除后确认是否完全回滚。
|
||||
|
||||
### 5.2 查看整库状态:`global_inspect`
|
||||
|
||||
```bash
|
||||
python scripts/lpmm_manager.py -a global_inspect
|
||||
```
|
||||
|
||||
输出:
|
||||
- 段落 / 实体 / 关系向量条数;
|
||||
- KG 节点/边总数,段落节点数、实体节点数;
|
||||
- 实体计数表 `ent_appear_cnt` 的条目数;
|
||||
- 少量剩余段落/实体样例,便于快速 sanity check。
|
||||
|
||||
---
|
||||
|
||||
## 6. 典型流水线五:检索效果回归测试
|
||||
|
||||
### 6.1 使用默认测试用例
|
||||
|
||||
```bash
|
||||
python scripts/lpmm_manager.py -a test
|
||||
```
|
||||
|
||||
- 调用 `test_lpmm_retrieval.py` 内置的 `DEFAULT_TEST_CASES`;
|
||||
- 对每条用例输出:
|
||||
- 原始结果;
|
||||
- 状态(`PASS` / `WARN` / `NO_HIT` / `ERROR`);
|
||||
- 期望关键字与命中关键字列表。
|
||||
|
||||
### 6.2 自定义测试问题与期望关键字
|
||||
|
||||
```bash
|
||||
python scripts/lpmm_manager.py -a test -- --query "LPMM 是什么?" \
|
||||
--expect-keyword 哈希列表 \
|
||||
--expect-keyword 删除脚本
|
||||
```
|
||||
|
||||
也可以直接调用子脚本:
|
||||
|
||||
```bash
|
||||
python scripts/test_lpmm_retrieval.py \
|
||||
--query "LPMM 是什么?" \
|
||||
--expect-keyword 哈希列表 \
|
||||
--expect-keyword 删除脚本
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 7. 推荐组合示例
|
||||
|
||||
### 7.1 导入 + 刷新 + 简单回归
|
||||
|
||||
```bash
|
||||
# 1. 执行全量导入(支持非交互)
|
||||
python scripts/lpmm_manager.py -a full_import --non-interactive
|
||||
|
||||
# 2. 使用内置用例做一次检索回归
|
||||
python scripts/lpmm_manager.py -a test
|
||||
```
|
||||
|
||||
### 7.2 批次回滚 + 自检
|
||||
|
||||
```bash
|
||||
TARGET_BATCH=data/openie/2025-01-01-12-00-openie.json
|
||||
|
||||
# 1. 按批次删除(非交互)
|
||||
python scripts/lpmm_manager.py -a delete --non-interactive -- \
|
||||
--openie-file "$TARGET_BATCH" \
|
||||
--delete-entities \
|
||||
--delete-relations \
|
||||
--remove-orphan-entities \
|
||||
--yes
|
||||
|
||||
# 2. 检查该批次是否彻底删除
|
||||
python scripts/lpmm_manager.py -a batch_inspect -- --openie-file "$TARGET_BATCH"
|
||||
|
||||
# 3. 查看全库状态
|
||||
python scripts/lpmm_manager.py -a global_inspect
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
如需扩展更多流水线(例如「导入特定批次后自动跑自定义测试用例」),可以在 `scripts/lpmm_manager.py` 中新增对应的 `ACTION_INFO` 条目和 `run_action` 分支,或直接在 CI / shell 脚本中串联上述命令。该管理脚本已支持参数化与非交互调用,适合作为二次封装的基础入口。
|
||||
|
||||
|
||||
411
docs-src/lpmm_user_guide.md
Normal file
411
docs-src/lpmm_user_guide.md
Normal file
@@ -0,0 +1,411 @@
|
||||
# LPMM 知识库脚本使用指南(零基础用户版)
|
||||
|
||||
本指南面向不熟悉命令行和代码的 C 端用户,帮助你完成:
|
||||
|
||||
- LPMM 知识库的初始部署(从本地 txt 到可检索知识库)
|
||||
- 安全删除知识(按批次、按原文、按哈希、按关键字)
|
||||
- 导入 / 删除后的自检与检索效果验证
|
||||
|
||||
> 说明:本文默认你已经完成 MaiBot 的基础安装,并能在项目根目录打开命令行终端。
|
||||
> 重要提醒:每次使用导入 / 删除相关脚本(如 `import_openie.py`、`delete_lpmm_items.py`)修改 LPMM 知识库后,聊天机器人 / WebUI 端要想看到最新知识,需要重启主程序,或在主程序内部显式调用一次 `lpmm_start_up()` 重新初始化 LPMM
|
||||
|
||||
---
|
||||
。
|
||||
|
||||
|
||||
## 一、需要用到的脚本一览
|
||||
|
||||
在项目根目录(`MaiBot-dev`)下,这些脚本是 LPMM 相关的“工具箱”:
|
||||
|
||||
- 导入相关:
|
||||
- `scripts/raw_data_preprocessor.py`
|
||||
从 `data/lpmm_raw_data` 目录读取 `.txt` 文件,按空行拆分为一个个段落,并做去重。
|
||||
- `scripts/info_extraction.py`
|
||||
调用大模型,从每个段落里抽取实体和三元组,生成中间的 OpenIE JSON 文件。
|
||||
- `scripts/import_openie.py`
|
||||
把 `data/openie` 目录中的 OpenIE JSON 文件导入到 LPMM 知识库(向量库 + 知识图)。
|
||||
- 删除相关:
|
||||
- `scripts/delete_lpmm_items.py`
|
||||
LPMM 知识库删除入口,支持按批次、按原始文本段落、按哈希列表、按关键字模糊搜索删除。
|
||||
- 自检相关:
|
||||
- `scripts/inspect_lpmm_global.py`
|
||||
查看整个知识库的当前状态:段落/实体/关系条数、知识图节点/边数量、示例内容等。
|
||||
- `scripts/inspect_lpmm_batch.py`
|
||||
针对某个 OpenIE JSON 批次,检查它在向量库和知识图中的“残留情况”(导入与删除前后对比)。
|
||||
- `scripts/test_lpmm_retrieval.py`
|
||||
使用几条预设问题测试 LPMM 检索能力,帮助你判断知识库是否正常工作。
|
||||
- `scripts/refresh_lpmm_knowledge.py`
|
||||
手动重新加载 `data/embedding` 和 `data/rag` 到内存,用来确认当前磁盘上的 LPMM 知识库能正常初始化。
|
||||
|
||||
> 注意:所有命令示例都假设你已经在虚拟环境中,命令行前缀类似 `(.venv)`,并且当前目录是项目根目录。
|
||||
|
||||
---
|
||||
|
||||
## 二、LPMM 知识库的初始部署
|
||||
|
||||
### 2.1 准备原始 txt 文本
|
||||
|
||||
1. 把要导入的知识文档放到:
|
||||
|
||||
```text
|
||||
data/lpmm_raw_data
|
||||
```
|
||||
|
||||
2. 文件要求:
|
||||
|
||||
- 必须是 `.txt` 文件,建议使用 UTF-8 编码;
|
||||
- 用**空行**分隔段落:一段话后空一行,即视为一条独立知识。
|
||||
|
||||
示例文件:
|
||||
|
||||
- `data/lpmm_raw_data/lpmm_large_sample.txt`:仓库内已经提供了一份大样本测试文本,可以直接用来练习。
|
||||
|
||||
### 2.2 第一步:预处理原始文本(拆段 + 去重)
|
||||
|
||||
在项目根目录执行:
|
||||
|
||||
```bash
|
||||
.\.venv\Scripts\python.exe scripts/raw_data_preprocessor.py
|
||||
```
|
||||
|
||||
成功时通常会看到日志类似:
|
||||
|
||||
- 正在处理文件: `lpmm_large_sample.txt`
|
||||
- 共读取到 XX 条数据
|
||||
|
||||
这一步不会调用大模型,仅做拆段和去重。
|
||||
|
||||
### 2.3 第二步:进行信息抽取(生成 OpenIE JSON)
|
||||
|
||||
执行:
|
||||
|
||||
```bash
|
||||
.\.venv\Scripts\python.exe scripts/info_extraction.py
|
||||
```
|
||||
|
||||
你会看到一个“重要操作确认”提示,说明:
|
||||
|
||||
- 信息抽取会调用大模型,消耗 API 费用和时间;
|
||||
- 如果确认无误,输入 `y` 回车继续。
|
||||
|
||||
提取过程中可能出现:
|
||||
|
||||
- 类似“模型 ... 网络错误(可重试)”这样的日志;
|
||||
这表示脚本在遇到网络问题时自动重试,一般无需手动干预。
|
||||
|
||||
运行结束后,会有类似提示:
|
||||
|
||||
```text
|
||||
信息提取结果已保存到: data/openie/11-27-10-06-openie.json
|
||||
```
|
||||
|
||||
- 请记住这个文件名,比如:`11-27-10-06-openie.json`
|
||||
接下来我们会用 `<OPENIE>` 来代指这类文件。
|
||||
|
||||
### 2.4 第三步:导入 OpenIE 数据到 LPMM 知识库
|
||||
|
||||
执行:
|
||||
|
||||
```bash
|
||||
.\.venv\Scripts\python.exe scripts/import_openie.py
|
||||
```
|
||||
|
||||
这个脚本会:
|
||||
|
||||
- 从 `data/openie` 目录读取所有 `*.json` 文件,并合并导入;
|
||||
- 将新段落的嵌入向量写入 `data/embedding`;
|
||||
- 将三元组构建为知识图写入 `data/rag`。
|
||||
|
||||
> 提示:如果你希望“只导入某几批数据”,可以暂时把不需要的 JSON 文件移出 `data/openie`,导入结束后再移回。
|
||||
|
||||
### 2.5 第四步:全局自检(确认导入成功)
|
||||
|
||||
执行:
|
||||
|
||||
```bash
|
||||
.\.venv\Scripts\python.exe scripts/inspect_lpmm_global.py
|
||||
```
|
||||
|
||||
你会看到类似输出:
|
||||
|
||||
- 段落向量条数: `52`
|
||||
- 实体向量条数: `260`
|
||||
- 关系向量条数: `299`
|
||||
- KG 节点总数 / 边总数 / 段落节点数 / 实体节点数
|
||||
- 若干条示例段落与实体内容预览
|
||||
|
||||
只要这些数字大于 0,就表示 LPMM 知识库已经有可用的数据了。
|
||||
|
||||
### 2.6 第五步:用脚本测试 LPMM 检索效果(可选但推荐)
|
||||
|
||||
执行:
|
||||
|
||||
```bash
|
||||
.\.venv\Scripts\python.exe scripts/test_lpmm_retrieval.py
|
||||
```
|
||||
|
||||
脚本会:
|
||||
|
||||
- 自动初始化 LPMM(加载向量库与知识图);
|
||||
- 用几条预设问题查询 LPMM;
|
||||
- 打印原始检索结果和关键词命中情况。
|
||||
|
||||
你可以通过观察“RAW RESULT”里的内容,粗略判断:
|
||||
|
||||
- 能否命中与问题高度相关的知识;
|
||||
- 删除或导入新知识后,回答内容是否发生变化。
|
||||
|
||||
---
|
||||
|
||||
## 三、安全删除知识的几种方式
|
||||
|
||||
> 强烈建议:删除前先备份以下目录,以便“回档”:
|
||||
>
|
||||
> - `data/embedding`(向量库)
|
||||
> - `data/rag`(知识图)
|
||||
|
||||
所有删除操作使用同一个脚本:
|
||||
|
||||
```bash
|
||||
.\.venv\Scripts\python.exe scripts/delete_lpmm_items.py [参数...]
|
||||
```
|
||||
|
||||
脚本特点:
|
||||
|
||||
- 删除前会打印“待删除段落数量 / 实体数量 / 关系数量 / 预计删除节点数”等摘要;
|
||||
- 需要你输入大写 `YES` 确认才会真正执行;
|
||||
- 支持多种删除策略,可灵活组合。
|
||||
|
||||
### 3.1 按批次删除(推荐:整批回滚)
|
||||
|
||||
适用场景:某次导入的整批知识有问题,希望整体回滚。
|
||||
|
||||
1. 删除前,先检查该批次状态:
|
||||
|
||||
```bash
|
||||
.\.venv\Scripts\python.exe scripts/inspect_lpmm_batch.py ^
|
||||
--openie-file data/openie/<OPENIE>.json
|
||||
```
|
||||
|
||||
你会看到该批次:
|
||||
|
||||
- 段落:总计多少条、向量库剩余多少、KG 中剩余多少;
|
||||
- 实体、关系的类似统计;
|
||||
- 少量示例段落/实体内容预览。
|
||||
|
||||
2. 确认无误后,按批次删除:
|
||||
|
||||
```bash
|
||||
.\.venv\Scripts\python.exe scripts/delete_lpmm_items.py ^
|
||||
--openie-file data/openie/<OPENIE>.json ^
|
||||
--delete-entities --delete-relations --remove-orphan-entities
|
||||
```
|
||||
|
||||
参数含义:
|
||||
|
||||
- `--delete-entities`:删除该批次涉及的实体向量;
|
||||
- `--delete-relations`:删除该批次涉及的关系向量;
|
||||
- `--remove-orphan-entities`:顺带清理删除后不再参与任何边的“孤立实体”节点。
|
||||
|
||||
3. 删除后再检查:
|
||||
|
||||
```bash
|
||||
.\.venv\Scripts\python.exe scripts/inspect_lpmm_batch.py ^
|
||||
--openie-file data/openie/<OPENIE>.json
|
||||
|
||||
.\.venv\Scripts\python.exe scripts/inspect_lpmm_global.py
|
||||
```
|
||||
|
||||
若批次检查显示“向量库剩余 0 / KG 中剩余 0”,则说明该批次已被彻底删除。
|
||||
|
||||
### 3.2 按原始文本段落删除(精确定位某一段)
|
||||
|
||||
适用场景:某个原始 txt 的特定段落写错了,只想删这段对应的知识。
|
||||
|
||||
命令示例:
|
||||
|
||||
```bash
|
||||
.\.venv\Scripts\python.exe scripts/delete_lpmm_items.py ^
|
||||
--raw-file data/lpmm_raw_data/lpmm_large_sample.txt ^
|
||||
--raw-index 2
|
||||
```
|
||||
|
||||
说明:
|
||||
|
||||
- `--raw-index` 从 1 开始计数,可用逗号多选,例如:`1,3,5`;
|
||||
- 脚本会展示该段落的内容预览和哈希值,再请求你确认。
|
||||
|
||||
### 3.3 按哈希列表删除(进阶用法)
|
||||
|
||||
适用场景:你有一份“需要删除的段落哈希列表”(比如从其他系统导出)。
|
||||
|
||||
示例哈希列表文件:
|
||||
|
||||
- `data/openie/lpmm_delete_test_hashes.txt`
|
||||
|
||||
命令:
|
||||
|
||||
```bash
|
||||
.\.venv\Scripts\python.exe scripts/delete_lpmm_items.py ^
|
||||
--hash-file data/openie/lpmm_delete_test_hashes.txt
|
||||
```
|
||||
|
||||
说明:
|
||||
|
||||
- 文件中每行一条,可以是 `paragraph-xxxx` 或纯哈希,脚本会自动识别;
|
||||
- 适合“精确控制删除哪些段落”,但准备哈希列表需要一定技术基础。
|
||||
|
||||
### 3.4 按关键字模糊搜索删除(对非技术用户最友好)
|
||||
|
||||
适用场景:只知道某段话里包含某个关键词,不知道它在哪个 txt 或批次里。
|
||||
|
||||
示例 1:删除与“近义词扩展”相关的段落
|
||||
|
||||
```bash
|
||||
.\.venv\Scripts\python.exe scripts/delete_lpmm_items.py --search-text "近义词扩展" --search-limit 5
|
||||
```
|
||||
|
||||
示例 2:删除与“LPMM”强相关的一些段落
|
||||
|
||||
```bash
|
||||
.\.venv\Scripts\python.exe scripts/delete_lpmm_items.py --search-text "LPMM" --search-limit 20
|
||||
|
||||
```
|
||||
|
||||
执行过程:
|
||||
|
||||
1. 脚本在当前段落库中查找包含该关键字的段落;
|
||||
2. 列出前 N 条候选(`--search-limit` 决定数量);
|
||||
3. 提示你输入要删除的序号列表,例如:`1,2,5`;
|
||||
4. 再次提示你输入 `YES` 确认,才会真正执行删除。
|
||||
|
||||
> 建议:
|
||||
>
|
||||
> - 第一次使用时可以先加 `--dry-run` 看看效果:
|
||||
> ```bash
|
||||
> .\.venv\Scripts\python.exe scripts/delete_lpmm_items.py ^
|
||||
> --search-text "LPMM" ^
|
||||
> --search-limit 20 ^
|
||||
> --dry-run
|
||||
> ```
|
||||
> - 确认候选列表确实是你要删的内容后,再去掉 `--dry-run` 正式执行。
|
||||
|
||||
---
|
||||
|
||||
## 四、自检:如何确认导入 / 删除是否“生效”
|
||||
|
||||
### 4.1 全局状态检查
|
||||
|
||||
每次导入或删除之后,建议跑一次:
|
||||
|
||||
```bash
|
||||
.\.venv\Scripts\python.exe scripts/inspect_lpmm_global.py
|
||||
```
|
||||
|
||||
你可以在这里看到:
|
||||
|
||||
- 段落向量条数、实体向量条数、关系向量条数;
|
||||
- 知识图的节点总数、边总数、段落节点和实体节点数量;
|
||||
- 若干条“剩余段落示例”和“剩余实体示例”。
|
||||
|
||||
观察方式:
|
||||
|
||||
- 导入后:数字应该明显上升(说明新增数据生效);
|
||||
- 删除后:数字应该明显下降(说明删除操作生效)。
|
||||
|
||||
### 4.2 某个批次的局部状态
|
||||
|
||||
如果你想确认“某一个 OpenIE 文件对应的那一批知识”是否存在,可以使用:
|
||||
|
||||
```bash
|
||||
.\.venv\Scripts\python.exe scripts/inspect_lpmm_batch.py --openie-file data/openie/<OPENIE>.json
|
||||
```
|
||||
|
||||
输出中会包含:
|
||||
|
||||
- 该批次的段落 / 实体 / 关系的总数;
|
||||
- 在向量库中还剩多少条,在 KG 中还剩多少条;
|
||||
- 若干条仍存在的段落/实体示例。
|
||||
|
||||
典型用法:
|
||||
|
||||
- 导入后立刻检查一次:确认这一批已经“写入”;
|
||||
- 删除后再检查一次:确认这一批是否已经“清空”。
|
||||
|
||||
### 4.3 检索效果回归测试
|
||||
|
||||
每次做完导入或删除,你都可以用这条命令快速验证检索效果:
|
||||
|
||||
```bash
|
||||
.\.venv\Scripts\python.exe scripts/test_lpmm_retrieval.py
|
||||
```
|
||||
|
||||
它会:
|
||||
|
||||
- 初始化 LPMM(加载当前向量库和知识图);
|
||||
- 用几条预设问题(包括与 LPMM 和配置相关的问题)进行检索;
|
||||
- 打印检索结果以及命中关键词情况。
|
||||
|
||||
通过对比不同时间点的输出,你可以判断:
|
||||
|
||||
- 某些知识是否已经被成功删除(不再出现在回答中);
|
||||
|
||||
- 新增的知识是否已经能被检索到。
|
||||
|
||||
### 4.4 进阶:一键刷新(可选)
|
||||
|
||||
- 想简单确认“现在这份 data/embedding + data/rag 是否健康”?执行:
|
||||
|
||||
`.\.venv\Scripts\python.exe scripts/refresh_lpmm_knowledge.py `
|
||||
|
||||
它会尝试初始化 LPMM,并打印当前段落/实体/关系条数和图大小。
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
---
|
||||
|
||||
## 五、常见提示与注意事项
|
||||
|
||||
1. **看到“网络错误(可重试)”需要担心吗?**
|
||||
|
||||
- 不需要。
|
||||
- 这些日志说明脚本在自动处理网络抖动,多数情况下会在重试后成功返回结果。
|
||||
- 只要脚本最后没有报“重试耗尽并退出”,一般导入/提取结果是有效的。
|
||||
|
||||
2. **删除操作会不会“一删全没”?**
|
||||
|
||||
- 不会直接“一删全没”:
|
||||
- 每次删除会打印摘要信息;
|
||||
- 必须输入 `YES` 才会真正执行;
|
||||
- 大批次时还有 `--max-delete-nodes` 保护,超过阈值会警告。
|
||||
- 但仍然建议:
|
||||
- 在大规模删除前备份 `data/embedding` 和 `data/rag`;
|
||||
- 先通过 `--dry-run` 看看待删列表。
|
||||
|
||||
3. **可以多次导入吗?需要先清空吗?**
|
||||
|
||||
- 可以多次导入,系统会根据段落内容的哈希做去重;
|
||||
- 不需要每次都清空,只要你希望老数据仍然保留即可;
|
||||
- 如果你确实想“重来一遍”,可以:
|
||||
- 先备份,然后删除 `data/embedding` 和 `data/rag`;
|
||||
- 再重新跑导入流程。
|
||||
|
||||
4. **LPMM 开关在哪里?**
|
||||
|
||||
- 配置文件:`config/bot_config.toml`;
|
||||
- 小节:`[lpmm_knowledge]`;
|
||||
- 其中有 `enable = true/false` 开关:
|
||||
- 为 `true`:LPMM 知识库启用,问答时会使用;
|
||||
- 为 `false`:LPMM 关闭,即使知识库有数据,也不会参与回答。
|
||||
- 修改后需要重启主程序,让设置生效。
|
||||
|
||||
---
|
||||
|
||||
如果你是普通用户,只需要记住一句话:
|
||||
|
||||
> “导入三步走:预处理 → 信息抽取 → 导入 OpenIE;
|
||||
> 删除三步走:先检查 → 再删除 → 然后再检查。”
|
||||
|
||||
照着本指南中的命令一步一步执行,就可以安全地管理你的 LPMM 知识库。***
|
||||
10
dummy
Normal file
10
dummy
Normal file
@@ -0,0 +1,10 @@
|
||||
{
|
||||
"cells": [],
|
||||
"metadata": {
|
||||
"language_info": {
|
||||
"name": "python"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
30
plugins/MaiBot_MCPBridgePlugin/.gitignore
vendored
Normal file
30
plugins/MaiBot_MCPBridgePlugin/.gitignore
vendored
Normal file
@@ -0,0 +1,30 @@
|
||||
# 运行时配置(包含用户敏感信息)
|
||||
config.toml
|
||||
|
||||
# 备份文件
|
||||
*.backup.*
|
||||
*.bak
|
||||
|
||||
# 日志
|
||||
logs/
|
||||
*.log
|
||||
*.jsonl
|
||||
|
||||
# Python 缓存
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
*.so
|
||||
|
||||
# 本地测试脚本(仓库不提交)
|
||||
test_*.py
|
||||
|
||||
# IDE
|
||||
.idea/
|
||||
.vscode/
|
||||
*.swp
|
||||
*.swo
|
||||
|
||||
# 系统文件
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
24
plugins/MaiBot_MCPBridgePlugin/CHANGELOG.md
Normal file
24
plugins/MaiBot_MCPBridgePlugin/CHANGELOG.md
Normal file
@@ -0,0 +1,24 @@
|
||||
# Changelog
|
||||
|
||||
本文件记录 `MaiBot_MCPBridgePlugin` 的用户可感知变更。
|
||||
|
||||
## 2.0.0
|
||||
|
||||
- 配置入口统一:MCP 服务器仅使用 Claude Desktop `mcpServers` JSON(`servers.claude_config_json`)
|
||||
- 兼容迁移:自动识别旧版 `servers.list` 并迁移为 `mcpServers`(需在 WebUI 保存一次固化)
|
||||
- 保持功能不变:保留 Workflow(硬流程/工具链)与 ReAct(软流程)双轨制能力
|
||||
- 精简实现:移除旧的 WebUI 导入导出/快速添加服务器实现与 `tomlkit` 依赖
|
||||
- 易用性:完善 Workflow 变量替换(支持数组下标与 bracket 写法),并优化 WebUI 配置区顺序
|
||||
|
||||
## 1.9.0
|
||||
|
||||
- 双轨制架构:ReAct(软流程)+ Workflow(硬流程/工具链)
|
||||
|
||||
## 1.8.0
|
||||
|
||||
- Workflow(工具链):多工具顺序执行、变量替换、自定义 Workflow 并注册为组合工具
|
||||
|
||||
## 1.7.0
|
||||
|
||||
- 断路器模式、状态刷新、工具搜索等易用性增强
|
||||
|
||||
356
plugins/MaiBot_MCPBridgePlugin/DEVELOPMENT.md
Normal file
356
plugins/MaiBot_MCPBridgePlugin/DEVELOPMENT.md
Normal file
@@ -0,0 +1,356 @@
|
||||
# MCP 桥接插件开发文档
|
||||
|
||||
本文档面向开发者,介绍插件的架构设计、核心模块和扩展方式。
|
||||
|
||||
## 架构概览
|
||||
|
||||
```
|
||||
MaiBot_MCPBridgePlugin/
|
||||
├── plugin.py # 主插件文件,包含所有核心逻辑
|
||||
├── mcp_client.py # MCP 客户端封装
|
||||
├── tool_chain.py # 工具链(Workflow)模块
|
||||
├── core/
|
||||
│ └── claude_config.py # Claude Desktop mcpServers 解析/迁移
|
||||
├── config.toml # 运行时配置
|
||||
└── _manifest.json # 插件元数据
|
||||
```
|
||||
|
||||
## 核心模块
|
||||
|
||||
### 1. MCP 客户端 (`mcp_client.py`)
|
||||
|
||||
封装了与 MCP 服务器的通信逻辑。
|
||||
|
||||
```python
|
||||
from .mcp_client import mcp_manager, MCPServerConfig, TransportType
|
||||
|
||||
# 添加服务器
|
||||
config = MCPServerConfig(
|
||||
name="my-server",
|
||||
transport=TransportType.STREAMABLE_HTTP,
|
||||
url="https://mcp.example.com/mcp"
|
||||
)
|
||||
await mcp_manager.add_server(config)
|
||||
|
||||
# 调用工具
|
||||
result = await mcp_manager.call_tool("server_tool_name", {"param": "value"})
|
||||
if result.success:
|
||||
print(result.content)
|
||||
```
|
||||
|
||||
**支持的传输类型:**
|
||||
- `STDIO`: 本地进程通信
|
||||
- `SSE`: Server-Sent Events
|
||||
- `HTTP`: HTTP 请求
|
||||
- `STREAMABLE_HTTP`: 流式 HTTP(推荐)
|
||||
|
||||
### 2. 工具注册系统
|
||||
|
||||
MCP 工具通过动态类创建注册到 MaiBot:
|
||||
|
||||
```python
|
||||
# 创建工具代理类
|
||||
class MCPToolProxy(BaseTool):
|
||||
name = "mcp_server_tool"
|
||||
description = "工具描述"
|
||||
parameters = [("param", ToolParamType.STRING, "参数描述", True, None)]
|
||||
available_for_llm = True
|
||||
|
||||
async def execute(self, function_args):
|
||||
result = await mcp_manager.call_tool(self._mcp_tool_key, function_args)
|
||||
return {"name": self.name, "content": result.content}
|
||||
```
|
||||
|
||||
### 3. 工具链模块 (`tool_chain.py`)
|
||||
|
||||
实现 Workflow 硬流程,支持多工具顺序执行。
|
||||
|
||||
```python
|
||||
from .tool_chain import ToolChainDefinition, ToolChainStep, tool_chain_manager
|
||||
|
||||
# 定义工具链
|
||||
chain = ToolChainDefinition(
|
||||
name="search_and_detail",
|
||||
description="搜索并获取详情",
|
||||
input_params={"query": "搜索关键词"},
|
||||
steps=[
|
||||
ToolChainStep(
|
||||
tool_name="mcp_server_search",
|
||||
args_template={"keyword": "${input.query}"},
|
||||
output_key="search_result"
|
||||
),
|
||||
ToolChainStep(
|
||||
tool_name="mcp_server_detail",
|
||||
args_template={"id": "${prev}"}
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# 注册并执行
|
||||
tool_chain_manager.add_chain(chain)
|
||||
result = await tool_chain_manager.execute_chain("search_and_detail", {"query": "test"})
|
||||
```
|
||||
|
||||
**变量替换语法:**
|
||||
- `${input.参数名}`: 用户输入
|
||||
- `${step.输出键}`: 指定步骤的输出
|
||||
- `${prev}`: 上一步输出
|
||||
- `${prev.字段}`: 上一步输出(JSON)的字段
|
||||
- `${step.geo.return.0.location}` / `${step.geo.return[0].location}`: 数组下标访问
|
||||
- `${step.geo['return'][0]['location']}`: bracket 写法(最通用)
|
||||
|
||||
## 双轨制架构
|
||||
|
||||
### ReAct 软流程
|
||||
|
||||
将 MCP 工具注册到 MaiBot 的记忆检索 ReAct 系统,LLM 自主决策调用。
|
||||
|
||||
```python
|
||||
def _register_tools_to_react(self) -> int:
|
||||
from src.memory_system.retrieval_tools import register_memory_retrieval_tool
|
||||
|
||||
def make_execute_func(tool_key: str):
|
||||
async def execute_func(**kwargs) -> str:
|
||||
result = await mcp_manager.call_tool(tool_key, kwargs)
|
||||
return result.content if result.success else f"失败: {result.error}"
|
||||
return execute_func
|
||||
|
||||
register_memory_retrieval_tool(
|
||||
name="mcp_tool_name",
|
||||
description="工具描述",
|
||||
parameters=[{"name": "param", "type": "string", "required": True}],
|
||||
execute_func=make_execute_func("tool_key")
|
||||
)
|
||||
```
|
||||
|
||||
### Workflow 硬流程
|
||||
|
||||
用户预定义的固定执行流程,注册为组合工具。
|
||||
|
||||
```python
|
||||
def _register_tool_chains(self) -> None:
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
for chain_name, chain in tool_chain_manager.get_enabled_chains().items():
|
||||
info, tool_class = tool_chain_registry.register_chain(chain)
|
||||
info.plugin_name = self.plugin_name
|
||||
component_registry.register_component(info, tool_class)
|
||||
```
|
||||
|
||||
## 配置系统
|
||||
|
||||
### MCP 服务器配置(Claude Desktop 规范)
|
||||
|
||||
插件只接受 Claude Desktop 的 `mcpServers` JSON(见 `core/claude_config.py`)。配置入口统一为:
|
||||
|
||||
- WebUI/配置文件:`[servers].claude_config_json`
|
||||
- 命令:`/mcp import`(合并 `mcpServers`)与 `/mcp export`(导出当前 `mcpServers`)
|
||||
|
||||
兼容迁移:
|
||||
- 若检测到旧版 `servers.list`,会自动迁移为 `servers.claude_config_json`(仅迁移到内存配置,需 WebUI 保存一次固化)。
|
||||
|
||||
### WebUI 配置 Schema
|
||||
|
||||
使用 `ConfigField` 定义 WebUI 配置项:
|
||||
|
||||
```python
|
||||
config_schema = {
|
||||
"section_name": {
|
||||
"field_name": ConfigField(
|
||||
type=str, # 类型: str, bool, int, float
|
||||
default="default_value", # 默认值
|
||||
description="字段描述",
|
||||
label="显示标签",
|
||||
input_type="textarea", # 输入类型: text, textarea, password
|
||||
rows=5, # textarea 行数
|
||||
disabled=True, # 只读
|
||||
choices=["a", "b"], # 下拉选项
|
||||
hint="提示信息",
|
||||
order=1, # 排序
|
||||
),
|
||||
},
|
||||
}
|
||||
```
|
||||
|
||||
### 配置读取
|
||||
|
||||
```python
|
||||
# 在组件中读取配置
|
||||
value = self.get_config("section.key", default="fallback")
|
||||
|
||||
# 在插件类中读取
|
||||
value = self.config.get("section", {}).get("key", "default")
|
||||
```
|
||||
|
||||
## 事件处理
|
||||
|
||||
### 启动事件
|
||||
|
||||
```python
|
||||
class MCPStartupHandler(BaseEventHandler):
|
||||
event_type = EventType.ON_START
|
||||
handler_name = "mcp_startup"
|
||||
|
||||
async def execute(self, message):
|
||||
global _plugin_instance
|
||||
if _plugin_instance:
|
||||
await _plugin_instance._async_connect_servers()
|
||||
return (True, True, None, None, None)
|
||||
```
|
||||
|
||||
### 停止事件
|
||||
|
||||
```python
|
||||
class MCPStopHandler(BaseEventHandler):
|
||||
event_type = EventType.ON_STOP
|
||||
handler_name = "mcp_stop"
|
||||
|
||||
async def execute(self, message):
|
||||
await mcp_manager.shutdown()
|
||||
return (True, True, None, None, None)
|
||||
```
|
||||
|
||||
## 命令系统
|
||||
|
||||
```python
|
||||
class MCPStatusCommand(BaseCommand):
|
||||
command_name = "mcp_status"
|
||||
command_pattern = r"^/mcp(?:\s+(?P<action>\S+))?(?:\s+(?P<arg>.+))?$"
|
||||
|
||||
async def execute(self) -> Tuple[bool, str, bool]:
|
||||
action = self.matched_groups.get("action", "")
|
||||
arg = self.matched_groups.get("arg", "")
|
||||
|
||||
if action == "tools":
|
||||
await self.send_text("工具列表...")
|
||||
elif action == "reconnect":
|
||||
await self._handle_reconnect(arg)
|
||||
|
||||
return (True, None, True) # (成功, 消息, 拦截)
|
||||
```
|
||||
|
||||
## 高级功能
|
||||
|
||||
### 调用追踪
|
||||
|
||||
```python
|
||||
from plugin import tool_call_tracer, ToolCallRecord
|
||||
|
||||
# 记录调用
|
||||
record = ToolCallRecord(
|
||||
call_id="xxx",
|
||||
timestamp=time.time(),
|
||||
tool_name="tool",
|
||||
server_name="server",
|
||||
arguments={"key": "value"},
|
||||
success=True,
|
||||
duration_ms=100.0
|
||||
)
|
||||
tool_call_tracer.record(record)
|
||||
|
||||
# 查询记录
|
||||
recent = tool_call_tracer.get_recent(10)
|
||||
by_tool = tool_call_tracer.get_by_tool("tool_name")
|
||||
```
|
||||
|
||||
### 调用缓存
|
||||
|
||||
```python
|
||||
from plugin import tool_call_cache
|
||||
|
||||
# 配置缓存
|
||||
tool_call_cache.configure(
|
||||
enabled=True,
|
||||
ttl=300, # 秒
|
||||
max_entries=200,
|
||||
exclude_tools="mcp_*_time_*" # 排除模式
|
||||
)
|
||||
|
||||
# 使用缓存
|
||||
cached = tool_call_cache.get("tool_name", {"param": "value"})
|
||||
if cached is None:
|
||||
result = await call_tool(...)
|
||||
tool_call_cache.set("tool_name", {"param": "value"}, result)
|
||||
```
|
||||
|
||||
### 权限控制
|
||||
|
||||
```python
|
||||
from plugin import permission_checker
|
||||
|
||||
# 配置权限
|
||||
permission_checker.configure(
|
||||
enabled=True,
|
||||
default_mode="allow_all", # 或 "deny_all"
|
||||
rules_json='[{"tool": "mcp_*_delete_*", "denied": ["qq:123:group"]}]',
|
||||
quick_deny_groups="123456789",
|
||||
quick_allow_users="111111111"
|
||||
)
|
||||
|
||||
# 检查权限
|
||||
allowed = permission_checker.check(
|
||||
tool_name="mcp_server_delete",
|
||||
chat_id="123456",
|
||||
user_id="789",
|
||||
is_group=True
|
||||
)
|
||||
```
|
||||
|
||||
### 断路器模式
|
||||
|
||||
MCP 客户端内置断路器,故障服务器快速失败:
|
||||
|
||||
- 连续失败 N 次后熔断
|
||||
- 熔断期间直接返回错误
|
||||
- 定期尝试恢复
|
||||
|
||||
## 扩展开发
|
||||
|
||||
### 添加新的传输类型
|
||||
|
||||
1. 在 `mcp_client.py` 中添加 `TransportType` 枚举值
|
||||
2. 实现对应的连接逻辑
|
||||
3. 更新 `_create_transport()` 方法
|
||||
|
||||
### 添加新的工具类型
|
||||
|
||||
1. 继承 `BaseTool` 创建新类
|
||||
2. 在 `get_plugin_components()` 中注册
|
||||
3. 实现 `execute()` 方法
|
||||
|
||||
### 添加新的命令
|
||||
|
||||
1. 在 `MCPStatusCommand.execute()` 中添加新的 action 分支
|
||||
2. 或创建新的 `BaseCommand` 子类
|
||||
|
||||
## 调试技巧
|
||||
|
||||
### 日志级别
|
||||
|
||||
```python
|
||||
from src.common.logger import get_logger
|
||||
logger = get_logger("mcp_bridge_plugin")
|
||||
|
||||
logger.debug("详细调试信息")
|
||||
logger.info("一般信息")
|
||||
logger.warning("警告")
|
||||
logger.error("错误")
|
||||
```
|
||||
|
||||
### 常用调试命令
|
||||
|
||||
```bash
|
||||
/mcp # 查看状态
|
||||
/mcp tools # 查看工具列表
|
||||
/mcp trace # 查看调用记录
|
||||
/mcp cache # 查看缓存状态
|
||||
/mcp chain # 查看工具链
|
||||
```
|
||||
|
||||
## 更新日志
|
||||
|
||||
见 `plugins/MaiBot_MCPBridgePlugin/CHANGELOG.md`
|
||||
|
||||
## 开发约定
|
||||
|
||||
- 本仓库不提交测试脚本/临时复现文件;如需本地验证,可自行在工作区创建未跟踪文件(建议放到 `.local/` 并加入 `.gitignore`)。
|
||||
357
plugins/MaiBot_MCPBridgePlugin/README.md
Normal file
357
plugins/MaiBot_MCPBridgePlugin/README.md
Normal file
@@ -0,0 +1,357 @@
|
||||
# MCP 桥接插件
|
||||
|
||||
将 [MCP (Model Context Protocol)](https://modelcontextprotocol.io/) 服务器的工具桥接到 MaiBot,使麦麦能够调用外部 MCP 工具。
|
||||
|
||||
<img width="3012" height="1794" alt="image" src="https://github.com/user-attachments/assets/ece56404-301a-4abf-b16d-87bd430fc977" />
|
||||
|
||||
## 🚀 快速开始
|
||||
|
||||
### 1. 安装
|
||||
|
||||
```bash
|
||||
# 克隆到 MaiBot 插件目录
|
||||
cd /path/to/MaiBot/plugins
|
||||
git clone https://github.com/CharTyr/MaiBot_MCPBridgePlugin.git MCPBridgePlugin
|
||||
|
||||
# 安装依赖
|
||||
pip install mcp
|
||||
|
||||
# 复制配置文件
|
||||
cd MCPBridgePlugin
|
||||
cp config.example.toml config.toml
|
||||
```
|
||||
|
||||
### 2. 添加服务器
|
||||
|
||||
编辑 `config.toml`,在 `[servers]` 的 `claude_config_json` 中填写 Claude Desktop 的 `mcpServers` JSON:
|
||||
|
||||
```toml
|
||||
[servers]
|
||||
claude_config_json = '''
|
||||
{
|
||||
"mcpServers": {
|
||||
"time": { "transport": "streamable_http", "url": "https://mcp.api-inference.modelscope.cn/server/mcp-server-time" },
|
||||
"my-server": { "transport": "streamable_http", "url": "https://mcp.xxx.com/mcp", "headers": { "Authorization": "Bearer 你的密钥" } },
|
||||
"fetch": { "command": "uvx", "args": ["mcp-server-fetch"] }
|
||||
}
|
||||
}
|
||||
'''
|
||||
```
|
||||
|
||||
### 3. 启动
|
||||
|
||||
重启 MaiBot,或发送 `/mcp reconnect`
|
||||
|
||||
---
|
||||
|
||||
## 📚 去哪找 MCP 服务器?
|
||||
|
||||
| 平台 | 说明 |
|
||||
|------|------|
|
||||
| [mcp.modelscope.cn](https://mcp.modelscope.cn/) | 魔搭 ModelScope,免费推荐 |
|
||||
| [smithery.ai](https://smithery.ai/) | MCP 服务器注册中心 |
|
||||
| [github.com/modelcontextprotocol/servers](https://github.com/modelcontextprotocol/servers) | 官方服务器列表 |
|
||||
|
||||
---
|
||||
|
||||
## 💡 常用命令
|
||||
|
||||
| 命令 | 说明 |
|
||||
|------|------|
|
||||
| `/mcp` | 查看连接状态 |
|
||||
| `/mcp tools` | 查看可用工具 |
|
||||
| `/mcp reconnect` | 重连服务器 |
|
||||
| `/mcp trace` | 查看调用记录 |
|
||||
| `/mcp cache` | 查看缓存状态 |
|
||||
| `/mcp perm` | 查看权限配置 |
|
||||
| `/mcp import <json>` | 🆕 导入 Claude Desktop 配置 |
|
||||
| `/mcp export` | 🆕 导出配置 |
|
||||
| `/mcp search <关键词>` | 🆕 搜索工具 |
|
||||
| `/mcp chain` | 🆕 查看工具链 |
|
||||
| `/mcp chain <名称>` | 🆕 查看工具链详情 |
|
||||
| `/mcp chain test <名称> <参数>` | 🆕 测试执行工具链 |
|
||||
|
||||
---
|
||||
|
||||
## ✨ 功能特性
|
||||
|
||||
### 核心功能
|
||||
- 🔌 多服务器同时连接
|
||||
- 📡 支持 stdio / SSE / HTTP / Streamable HTTP
|
||||
- 🔄 自动重试、心跳检测、断线重连
|
||||
- 🖥️ WebUI 完整配置支持
|
||||
|
||||
### 双轨制架构
|
||||
- 🔄 **ReAct(软流程)**:LLM 自主决策,多轮动态调用 MCP 工具(适合探索式场景)
|
||||
- 🔗 **Workflow(硬流程/工具链)**:用户预定义步骤顺序与参数传递(适合可控可复用场景)
|
||||
|
||||
### 高级功能
|
||||
- 📦 Resources 支持(实验性)
|
||||
- 📝 Prompts 支持(实验性)
|
||||
- 🔄 结果后处理(LLM 摘要提炼)
|
||||
- 🔍 调用追踪 / 🗄️ 调用缓存 / 🔐 权限控制 / 🚫 工具禁用
|
||||
|
||||
### 更新日志
|
||||
- 见 `plugins/MaiBot_MCPBridgePlugin/CHANGELOG.md`
|
||||
|
||||
---
|
||||
|
||||
## ⚙️ 配置说明
|
||||
|
||||
### 服务器配置
|
||||
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"server_name": {
|
||||
"transport": "streamable_http",
|
||||
"url": "https://..."
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
| 字段 | 说明 |
|
||||
|------|------|
|
||||
| `mcpServers.<name>` | 服务器名称(唯一) |
|
||||
| `enabled` | 是否启用(可选,默认 true) |
|
||||
| `transport` | `stdio` / `sse` / `http` / `streamable_http` |
|
||||
| `url` | 远程服务器地址 |
|
||||
| `headers` | 🆕 鉴权头(如 `{"Authorization": "Bearer xxx"}`) |
|
||||
| `command` / `args` | 本地服务器启动命令 |
|
||||
|
||||
### 权限控制
|
||||
|
||||
**快捷配置(推荐):**
|
||||
```toml
|
||||
[permissions]
|
||||
perm_enabled = true
|
||||
quick_deny_groups = "123456789" # 禁用的群号
|
||||
quick_allow_users = "111111111" # 管理员白名单
|
||||
```
|
||||
|
||||
**高级规则:**
|
||||
```json
|
||||
[{"tool": "mcp_*_delete_*", "denied": ["qq:123456:group"]}]
|
||||
```
|
||||
|
||||
### 工具禁用
|
||||
|
||||
```toml
|
||||
[tools]
|
||||
disabled_tools = '''
|
||||
mcp_filesystem_delete_file
|
||||
mcp_filesystem_write_file
|
||||
'''
|
||||
```
|
||||
|
||||
### 调用缓存
|
||||
|
||||
```toml
|
||||
[settings]
|
||||
cache_enabled = true
|
||||
cache_ttl = 300
|
||||
cache_exclude_tools = "mcp_*_time_*"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## ❓ 常见问题
|
||||
|
||||
**Q: 工具没有注册?**
|
||||
- 检查 `enabled = true`
|
||||
- 检查 MaiBot 日志错误信息
|
||||
- 确认 `pip install mcp`
|
||||
|
||||
**Q: JSON 格式报错?**
|
||||
- 多行 JSON 用 `'''` 三引号包裹
|
||||
- 使用英文双引号 `"`
|
||||
|
||||
**Q: 如何手动重连?**
|
||||
- `/mcp reconnect` 或 `/mcp reconnect 服务器名`
|
||||
|
||||
---
|
||||
|
||||
## 📥 配置导入导出(Claude mcpServers)
|
||||
|
||||
### 从 Claude Desktop 导入
|
||||
|
||||
如果你已有 Claude Desktop 的 MCP 配置,可以直接导入:
|
||||
|
||||
```
|
||||
/mcp import {"mcpServers":{"time":{"command":"uvx","args":["mcp-server-time"]},"fetch":{"command":"uvx","args":["mcp-server-fetch"]}}}
|
||||
```
|
||||
|
||||
支持的格式:
|
||||
- Claude Desktop 格式(`mcpServers` 对象)
|
||||
- 兼容旧版:MaiBot servers 列表数组(将自动迁移为 `mcpServers`)
|
||||
|
||||
### 导出配置
|
||||
|
||||
```
|
||||
/mcp export # 导出为 Claude Desktop 格式(默认)
|
||||
/mcp export claude # 导出为 Claude Desktop 格式
|
||||
```
|
||||
|
||||
### 注意事项
|
||||
- 导入时会自动跳过同名服务器
|
||||
- 导入后需要发送 `/mcp reconnect` 使配置生效
|
||||
- 支持 stdio、sse、http、streamable_http 全部传输类型
|
||||
|
||||
---
|
||||
|
||||
## 🔗 Workflow(硬流程/工具链)
|
||||
|
||||
工具链允许你将多个 MCP 工具按顺序执行,后续工具可以使用前序工具的输出作为输入。
|
||||
|
||||
### 1 分钟上手(推荐 WebUI)
|
||||
1. 先完成 MCP 服务器配置并 `/mcp reconnect`
|
||||
2. 发送 `/mcp tools`,复制你要用的工具名
|
||||
3. 打开 WebUI → 「Workflow(硬流程/工具链)」→ 用“快速添加”表单填入:
|
||||
- 名称/描述
|
||||
- 输入参数(每行 `参数名=描述`)
|
||||
- 执行步骤(每行 `工具名|参数JSON|输出键`)
|
||||
4. 在“确认添加”中输入 `ADD` 并保存
|
||||
|
||||
### 快速添加工具链(推荐)
|
||||
|
||||
在 WebUI 的「工具链」配置区,使用表单快速添加:
|
||||
|
||||
1. **名称**: 填写工具链名称(英文,如 `search_and_detail`)
|
||||
2. **描述**: 填写工具链用途(供 LLM 理解何时使用)
|
||||
3. **输入参数**: 每行一个,格式 `参数名=描述`
|
||||
```
|
||||
query=搜索关键词
|
||||
max_results=最大结果数
|
||||
```
|
||||
4. **执行步骤**: 每行一个,格式 `工具名|参数JSON|输出键`
|
||||
```
|
||||
mcp_server_search|{"keyword":"${input.query}"}|search_result
|
||||
mcp_server_detail|{"id":"${prev}"}|
|
||||
```
|
||||
5. **确认添加**: 输入 `ADD` 并保存
|
||||
|
||||
### JSON 配置方式
|
||||
|
||||
也可以直接在「工具链列表」中编写 JSON:
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"name": "search_and_detail",
|
||||
"description": "先搜索模组,再获取详情",
|
||||
"input_params": {
|
||||
"query": "搜索关键词"
|
||||
},
|
||||
"steps": [
|
||||
{
|
||||
"tool_name": "mcp_mcmod_search_mod",
|
||||
"args_template": {"keyword": "${input.query}", "limit": 1},
|
||||
"output_key": "search_result",
|
||||
"description": "搜索模组"
|
||||
},
|
||||
{
|
||||
"tool_name": "mcp_mcmod_get_mod_detail",
|
||||
"args_template": {"mod_id": "${prev}"},
|
||||
"description": "获取详情"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
### 变量替换
|
||||
|
||||
| 变量格式 | 说明 |
|
||||
|---------|------|
|
||||
| `${input.参数名}` | 用户输入的参数 |
|
||||
| `${step.输出键}` | 某个步骤的输出(通过 `output_key` 指定) |
|
||||
| `${prev}` | 上一步的输出 |
|
||||
| `${prev.字段}` | 上一步输出(JSON)的某个字段 |
|
||||
| `${step.geo.return.0.location}` | 数组下标访问(dot) |
|
||||
| `${step.geo.return[0].location}` | 数组下标访问([]) |
|
||||
| `${step.geo['return'][0]['location']}` | bracket 写法(最通用) |
|
||||
|
||||
### 工具链字段说明
|
||||
|
||||
| 字段 | 说明 |
|
||||
|------|------|
|
||||
| `name` | 工具链名称,将生成 `chain_xxx` 工具 |
|
||||
| `description` | 描述,供 LLM 理解何时使用 |
|
||||
| `input_params` | 输入参数定义 `{参数名: 描述}` |
|
||||
| `steps` | 执行步骤数组 |
|
||||
| `steps[].tool_name` | 要调用的工具名 |
|
||||
| `steps[].args_template` | 参数模板,支持变量替换 |
|
||||
| `steps[].output_key` | 输出存储键名(可选) |
|
||||
| `steps[].optional` | 是否可选,失败时继续执行(默认 false) |
|
||||
|
||||
### 命令
|
||||
|
||||
```bash
|
||||
/mcp chain # 查看所有工具链
|
||||
/mcp chain list # 列出工具链
|
||||
/mcp chain <名称> # 查看详情
|
||||
/mcp chain test <名称> {"query": "JEI"} # 测试执行
|
||||
/mcp chain reload # 重新加载配置
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 🔄 双轨制架构
|
||||
|
||||
MCP 桥接插件支持两种工具调用模式,可根据场景选择:
|
||||
|
||||
### ReAct 软流程
|
||||
|
||||
LLM 自主决策的多轮工具调用模式,适合复杂、不确定的场景。
|
||||
|
||||
**工作原理:**
|
||||
1. 用户提问 → LLM 分析需要什么信息
|
||||
2. LLM 选择调用工具 → 获取结果
|
||||
3. LLM 观察结果 → 决定是否需要更多信息
|
||||
4. 重复 2-3 直到信息足够 → 生成最终回答
|
||||
|
||||
**启用方式:**
|
||||
在 WebUI「ReAct (软流程)」配置区启用,MCP 工具将自动注册到 MaiBot 的记忆检索 ReAct 系统。
|
||||
|
||||
**适用场景:**
|
||||
- 复杂问题需要多步推理
|
||||
- 不确定需要调用哪些工具
|
||||
- 需要根据中间结果动态调整
|
||||
|
||||
### Workflow 硬流程
|
||||
|
||||
用户预定义的工作流,固定执行顺序,适合可靠、可控的场景。
|
||||
|
||||
**工作原理:**
|
||||
1. 用户定义步骤顺序和参数传递
|
||||
2. 按顺序执行每个步骤
|
||||
3. 后续步骤可使用前序步骤的输出
|
||||
4. 返回最终结果
|
||||
|
||||
**适用场景:**
|
||||
- 流程固定、可预测
|
||||
- 需要可靠、可重复的执行
|
||||
- 希望精确控制工具调用顺序
|
||||
|
||||
### 对比
|
||||
|
||||
| 特性 | ReAct 软流程 | Workflow 硬流程 |
|
||||
|------|-------------|----------------|
|
||||
| 决策者 | LLM 自主决策 | 用户预定义 |
|
||||
| 灵活性 | 高,动态调整 | 低,固定流程 |
|
||||
| 可预测性 | 低 | 高 |
|
||||
| 适用场景 | 复杂、探索性任务 | 固定、重复性任务 |
|
||||
| 配置方式 | 启用即可 | 需要定义步骤 |
|
||||
|
||||
---
|
||||
|
||||
## 📋 依赖
|
||||
|
||||
- MaiBot >= 0.11.6
|
||||
- Python >= 3.10
|
||||
- mcp >= 1.0.0
|
||||
|
||||
## 📄 许可证
|
||||
|
||||
AGPL-3.0
|
||||
44
plugins/MaiBot_MCPBridgePlugin/__init__.py
Normal file
44
plugins/MaiBot_MCPBridgePlugin/__init__.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""
|
||||
MCP 桥接插件
|
||||
将 MCP (Model Context Protocol) 服务器的工具桥接到 MaiBot
|
||||
|
||||
v1.1.0 新增功能:
|
||||
- 心跳检测和自动重连
|
||||
- 调用统计(次数、成功率、耗时)
|
||||
- 更好的错误处理
|
||||
|
||||
v1.2.0 新增功能:
|
||||
- Resources 支持(资源读取)
|
||||
- Prompts 支持(提示模板)
|
||||
"""
|
||||
|
||||
from .plugin import MCPBridgePlugin, mcp_tool_registry, MCPStartupHandler, MCPStopHandler
|
||||
from .mcp_client import (
|
||||
mcp_manager,
|
||||
MCPClientManager,
|
||||
MCPServerConfig,
|
||||
TransportType,
|
||||
MCPCallResult,
|
||||
MCPToolInfo,
|
||||
MCPResourceInfo,
|
||||
MCPPromptInfo,
|
||||
ToolCallStats,
|
||||
ServerStats,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"MCPBridgePlugin",
|
||||
"mcp_tool_registry",
|
||||
"mcp_manager",
|
||||
"MCPClientManager",
|
||||
"MCPServerConfig",
|
||||
"TransportType",
|
||||
"MCPCallResult",
|
||||
"MCPToolInfo",
|
||||
"MCPResourceInfo",
|
||||
"MCPPromptInfo",
|
||||
"ToolCallStats",
|
||||
"ServerStats",
|
||||
"MCPStartupHandler",
|
||||
"MCPStopHandler",
|
||||
]
|
||||
67
plugins/MaiBot_MCPBridgePlugin/_manifest.json
Normal file
67
plugins/MaiBot_MCPBridgePlugin/_manifest.json
Normal file
@@ -0,0 +1,67 @@
|
||||
{
|
||||
"manifest_version": 1,
|
||||
"name": "MCP桥接插件",
|
||||
"version": "2.0.0",
|
||||
"description": "将 MCP (Model Context Protocol) 服务器的工具桥接到 MaiBot,使麦麦能够调用外部 MCP 工具",
|
||||
"author": {
|
||||
"name": "CharTyr",
|
||||
"url": "https://github.com/CharTyr"
|
||||
},
|
||||
"license": "AGPL-3.0",
|
||||
"host_application": {
|
||||
"min_version": "0.11.6"
|
||||
},
|
||||
"homepage_url": "https://github.com/CharTyr/MaiBot_MCPBridgePlugin",
|
||||
"repository_url": "https://github.com/CharTyr/MaiBot_MCPBridgePlugin",
|
||||
"keywords": [
|
||||
"mcp",
|
||||
"bridge",
|
||||
"tool",
|
||||
"integration",
|
||||
"resources",
|
||||
"prompts",
|
||||
"post-process",
|
||||
"cache",
|
||||
"trace",
|
||||
"permissions",
|
||||
"import",
|
||||
"export",
|
||||
"claude-desktop",
|
||||
"workflow",
|
||||
"react",
|
||||
"agent"
|
||||
],
|
||||
"categories": [
|
||||
"工具扩展",
|
||||
"外部集成"
|
||||
],
|
||||
"default_locale": "zh-CN",
|
||||
"plugin_info": {
|
||||
"is_built_in": false,
|
||||
"components": [],
|
||||
"features": [
|
||||
"支持多个 MCP 服务器",
|
||||
"自动发现并注册 MCP 工具",
|
||||
"支持 stdio、SSE、HTTP、Streamable HTTP 四种传输方式",
|
||||
"工具参数自动转换",
|
||||
"心跳检测与自动重连",
|
||||
"调用统计(次数、成功率、耗时)",
|
||||
"WebUI 配置支持",
|
||||
"Resources 支持(实验性)",
|
||||
"Prompts 支持(实验性)",
|
||||
"结果后处理(LLM 摘要提炼)",
|
||||
"工具禁用管理",
|
||||
"调用链路追踪",
|
||||
"工具调用缓存(LRU)",
|
||||
"工具权限控制(群/用户级别)",
|
||||
"配置导入导出(Claude Desktop mcpServers)",
|
||||
"断路器模式(故障快速失败)",
|
||||
"状态实时刷新",
|
||||
"Workflow 硬流程(顺序执行多个工具)",
|
||||
"Workflow 快速添加(表单式配置)",
|
||||
"ReAct 软流程(LLM 自主多轮调用)",
|
||||
"双轨制架构(软流程 + 硬流程)"
|
||||
]
|
||||
},
|
||||
"id": "MaiBot Community.MCPBridgePlugin"
|
||||
}
|
||||
309
plugins/MaiBot_MCPBridgePlugin/config.example.toml
Normal file
309
plugins/MaiBot_MCPBridgePlugin/config.example.toml
Normal file
@@ -0,0 +1,309 @@
|
||||
# MCP桥接插件 - 配置文件示例
|
||||
# 将 MCP (Model Context Protocol) 服务器的工具桥接到 MaiBot
|
||||
#
|
||||
# 使用方法:复制此文件为 config.toml,然后根据需要修改配置
|
||||
#
|
||||
# ============================================================
|
||||
# 🎯 快速开始(三步)
|
||||
# ============================================================
|
||||
# 1. 在下方 [servers] 添加 MCP 服务器配置
|
||||
# 2. 将 enabled 改为 true 启用服务器
|
||||
# 3. 重启 MaiBot 或发送 /mcp reconnect
|
||||
#
|
||||
# ============================================================
|
||||
# 📚 去哪找 MCP 服务器?
|
||||
# ============================================================
|
||||
#
|
||||
# 【远程服务(推荐新手)】
|
||||
# - ModelScope: https://mcp.modelscope.cn/ (免费,推荐)
|
||||
# - Smithery: https://smithery.ai/
|
||||
# - Glama: https://glama.ai/mcp/servers
|
||||
#
|
||||
# 【本地服务(需要 npx 或 uvx)】
|
||||
# - 官方列表: https://github.com/modelcontextprotocol/servers
|
||||
#
|
||||
# ============================================================
|
||||
|
||||
# ============================================================
|
||||
# 🔌 MCP 服务器配置
|
||||
# ============================================================
|
||||
#
|
||||
# ⚠️ 重要:配置格式(Claude Desktop 规范)
|
||||
# ────────────────────────────────────────────────────────────
|
||||
# 统一使用 Claude Desktop 的 mcpServers JSON。
|
||||
#
|
||||
# claude_config_json 的内容应为 JSON 对象:
|
||||
# {
|
||||
# "mcpServers": {
|
||||
# "server_name": { ...server config... },
|
||||
# "another": { ... }
|
||||
# }
|
||||
# }
|
||||
#
|
||||
# 每个服务器支持字段:
|
||||
# transport - 传输方式: "stdio" / "sse" / "http" / "streamable_http"(可选)
|
||||
# url - 服务器地址(sse/http/streamable_http 模式)
|
||||
# command - 启动命令(stdio 模式,如 "npx" / "uvx")
|
||||
# args - 命令参数数组(stdio 模式)
|
||||
# env - 环境变量对象(stdio 模式,可选)
|
||||
# headers - 鉴权头(可选,如 {"Authorization": "Bearer xxx"})
|
||||
# enabled - 是否启用(可选,默认 true)
|
||||
# post_process - 服务器级别后处理配置(可选)
|
||||
#
|
||||
# ============================================================
|
||||
|
||||
[servers]
|
||||
claude_config_json = '''
|
||||
{
|
||||
"mcpServers": {
|
||||
"time-mcp-server": {
|
||||
"enabled": false,
|
||||
"transport": "streamable_http",
|
||||
"url": "https://mcp.api-inference.modelscope.cn/server/mcp-server-time"
|
||||
},
|
||||
"my-auth-server": {
|
||||
"enabled": false,
|
||||
"transport": "streamable_http",
|
||||
"url": "https://mcp.api-inference.modelscope.net/xxxxxx/mcp",
|
||||
"headers": {
|
||||
"Authorization": "Bearer ms-xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx"
|
||||
}
|
||||
},
|
||||
"fetch-local": {
|
||||
"enabled": false,
|
||||
"command": "uvx",
|
||||
"args": ["mcp-server-fetch"]
|
||||
}
|
||||
}
|
||||
}
|
||||
'''
|
||||
|
||||
# ============================================================
|
||||
# 插件基本信息
|
||||
# ============================================================
|
||||
[plugin]
|
||||
name = "mcp_bridge_plugin"
|
||||
version = "2.0.0"
|
||||
config_version = "2.0.0"
|
||||
enabled = false # 默认禁用,在 WebUI 中启用
|
||||
|
||||
# ============================================================
|
||||
# Workflow(硬流程/工具链)
|
||||
# ============================================================
|
||||
#
|
||||
# 作用:把多个工具按顺序执行;后续步骤可引用前序输出。
|
||||
#
|
||||
# ✅ 推荐配置方式:WebUI「Workflow(硬流程/工具链)」里用“快速添加”表单。
|
||||
# ✅ 也可以直接写 chains_list(JSON 数组)。
|
||||
#
|
||||
# 变量替换:
|
||||
# ${input.xxx} - 用户输入
|
||||
# ${step.<output_key>} - 指定步骤输出(需设置 output_key)
|
||||
# ${prev} - 上一步输出
|
||||
# ${prev.字段} - 上一步输出(JSON)的字段
|
||||
# ${step.geo.return.0.location} - 数组/下标访问(dot)
|
||||
# ${step.geo.return[0].location} - 数组/下标访问([])
|
||||
# ${step.geo['return'][0]['location']} - bracket 写法
|
||||
#
|
||||
# ============================================================
|
||||
|
||||
[tool_chains]
|
||||
chains_enabled = true
|
||||
|
||||
chains_list = '''
|
||||
[
|
||||
{
|
||||
"name": "search_and_detail",
|
||||
"description": "先搜索,再根据结果获取详情",
|
||||
"input_params": { "query": "搜索关键词" },
|
||||
"steps": [
|
||||
{ "tool_name": "把这里替换成你的搜索工具名", "args_template": { "keyword": "${input.query}" }, "output_key": "search" },
|
||||
{ "tool_name": "把这里替换成你的详情工具名", "args_template": { "id": "${prev}" } }
|
||||
]
|
||||
}
|
||||
]
|
||||
'''
|
||||
|
||||
# ============================================================
|
||||
# ReAct(软流程)
|
||||
# ============================================================
|
||||
#
|
||||
# 作用:把 MCP 工具注册到 MaiBot 的 ReAct 系统,LLM 可自主多轮调用。
|
||||
#
|
||||
# 注意:ReAct 适合“探索式/不确定”场景;Workflow 适合“固定/可控”场景。
|
||||
#
|
||||
# ============================================================
|
||||
|
||||
[react]
|
||||
react_enabled = false
|
||||
filter_mode = "whitelist" # whitelist / blacklist
|
||||
tool_filter = "" # 每行一个工具名,支持通配符 *
|
||||
|
||||
# ============================================================
|
||||
# 全局设置(高级设置建议保持默认)
|
||||
# ============================================================
|
||||
[settings]
|
||||
# 🏷️ 工具前缀 - 用于区分 MCP 工具和原生工具
|
||||
tool_prefix = "mcp"
|
||||
|
||||
# ⏱️ 连接超时(秒)
|
||||
connect_timeout = 30.0
|
||||
|
||||
# ⏱️ 调用超时(秒)
|
||||
call_timeout = 60.0
|
||||
|
||||
# 🔄 自动连接 - 启动时自动连接所有已启用的服务器
|
||||
auto_connect = true
|
||||
|
||||
# 🔁 重试次数 - 连接失败时的重试次数
|
||||
retry_attempts = 3
|
||||
|
||||
# ⏳ 重试间隔(秒)
|
||||
retry_interval = 5.0
|
||||
|
||||
# 💓 心跳检测 - 定期检测服务器连接状态
|
||||
heartbeat_enabled = true
|
||||
|
||||
# 💓 心跳间隔(秒)- 建议 30-120 秒
|
||||
heartbeat_interval = 60.0
|
||||
|
||||
# 🔄 自动重连 - 检测到断开时自动尝试重连
|
||||
auto_reconnect = true
|
||||
|
||||
# 🔄 最大重连次数 - 连续重连失败后暂停重连
|
||||
max_reconnect_attempts = 3
|
||||
|
||||
# ============================================================
|
||||
# 高级功能(实验性)
|
||||
# ============================================================
|
||||
# 📦 启用 Resources - 允许读取 MCP 服务器提供的资源
|
||||
enable_resources = false
|
||||
|
||||
# 📝 启用 Prompts - 允许使用 MCP 服务器提供的提示模板
|
||||
enable_prompts = false
|
||||
|
||||
# ============================================================
|
||||
# 结果后处理功能
|
||||
# ============================================================
|
||||
# 当 MCP 工具返回的内容过长时,使用 LLM 对结果进行摘要提炼
|
||||
|
||||
# 🔄 启用结果后处理
|
||||
post_process_enabled = false
|
||||
|
||||
# 📏 后处理阈值(字符数)- 结果长度超过此值才触发后处理
|
||||
post_process_threshold = 500
|
||||
|
||||
# 🔢 后处理输出限制 - LLM 摘要输出的最大 token 数
|
||||
post_process_max_tokens = 500
|
||||
|
||||
# 🤖 后处理模型(可选)- 留空则使用 utils 模型组
|
||||
post_process_model = ""
|
||||
|
||||
# 🧠 后处理提示词模板
|
||||
post_process_prompt = '''用户问题:{query}
|
||||
|
||||
工具返回内容:
|
||||
{result}
|
||||
|
||||
请从上述内容中提取与用户问题最相关的关键信息,简洁准确地输出:'''
|
||||
|
||||
# ============================================================
|
||||
# 调用链路追踪
|
||||
# ============================================================
|
||||
# 记录工具调用详情,便于调试和分析
|
||||
|
||||
# 🔍 启用调用追踪
|
||||
trace_enabled = true
|
||||
|
||||
# 📊 追踪记录上限 - 内存中保留的最大记录数
|
||||
trace_max_records = 50
|
||||
|
||||
# 📝 追踪日志文件 - 是否将追踪记录写入日志文件
|
||||
# 启用后记录写入 plugins/MaiBot_MCPBridgePlugin/logs/trace.jsonl
|
||||
trace_log_enabled = false
|
||||
|
||||
# ============================================================
|
||||
# 工具调用缓存
|
||||
# ============================================================
|
||||
# 缓存相同参数的调用结果,减少重复请求
|
||||
|
||||
# 🗄️ 启用调用缓存
|
||||
cache_enabled = false
|
||||
|
||||
# ⏱️ 缓存有效期(秒)
|
||||
cache_ttl = 300
|
||||
|
||||
# 📦 最大缓存条目 - 超出后 LRU 淘汰
|
||||
cache_max_entries = 200
|
||||
|
||||
# 🚫 缓存排除列表 - 即不缓存的工具(每行一个,支持通配符 *)
|
||||
# 时间类、随机类工具建议排除
|
||||
cache_exclude_tools = '''
|
||||
mcp_*_time_*
|
||||
mcp_*_random_*
|
||||
'''
|
||||
|
||||
# ============================================================
|
||||
# 工具管理
|
||||
# ============================================================
|
||||
[tools]
|
||||
# 📋 工具清单(只读)- 启动后自动生成
|
||||
tool_list = "(启动后自动生成)"
|
||||
|
||||
# 🚫 禁用工具列表 - 要禁用的工具名(每行一个)
|
||||
# 从上方工具清单复制工具名,禁用后该工具不会被 LLM 调用
|
||||
# 示例:
|
||||
# disabled_tools = '''
|
||||
# mcp_filesystem_delete_file
|
||||
# mcp_filesystem_write_file
|
||||
# '''
|
||||
disabled_tools = ""
|
||||
|
||||
# ============================================================
|
||||
# 权限控制
|
||||
# ============================================================
|
||||
[permissions]
|
||||
# 🔐 启用权限控制 - 按群/用户限制工具使用
|
||||
perm_enabled = false
|
||||
|
||||
# 📋 默认模式
|
||||
# allow_all: 未配置规则的工具默认允许
|
||||
# deny_all: 未配置规则的工具默认禁止
|
||||
perm_default_mode = "allow_all"
|
||||
|
||||
# ────────────────────────────────────────────────────────────
|
||||
# 🚀 快捷配置(推荐新手使用)
|
||||
# ────────────────────────────────────────────────────────────
|
||||
|
||||
# 🚫 禁用群列表 - 这些群无法使用任何 MCP 工具(每行一个群号)
|
||||
# 示例:
|
||||
# quick_deny_groups = '''
|
||||
# 123456789
|
||||
# 987654321
|
||||
# '''
|
||||
quick_deny_groups = ""
|
||||
|
||||
# ✅ 管理员白名单 - 这些用户始终可以使用所有工具(每行一个QQ号)
|
||||
# 示例:
|
||||
# quick_allow_users = '''
|
||||
# 111111111
|
||||
# '''
|
||||
quick_allow_users = ""
|
||||
|
||||
# ────────────────────────────────────────────────────────────
|
||||
# 📜 高级权限规则(可选,针对特定工具配置)
|
||||
# ────────────────────────────────────────────────────────────
|
||||
# 格式: qq:ID:group/private/user,工具名支持通配符 *
|
||||
# 示例:
|
||||
# perm_rules = '''
|
||||
# [
|
||||
# {"tool": "mcp_*_delete_*", "denied": ["qq:123456:group"]}
|
||||
# ]
|
||||
# '''
|
||||
perm_rules = "[]"
|
||||
|
||||
# ============================================================
|
||||
# 状态显示(只读)
|
||||
# ============================================================
|
||||
[status]
|
||||
connection_status = "未初始化"
|
||||
2
plugins/MaiBot_MCPBridgePlugin/core/__init__.py
Normal file
2
plugins/MaiBot_MCPBridgePlugin/core/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
"""Core helpers for MCP Bridge Plugin."""
|
||||
|
||||
170
plugins/MaiBot_MCPBridgePlugin/core/claude_config.py
Normal file
170
plugins/MaiBot_MCPBridgePlugin/core/claude_config.py
Normal file
@@ -0,0 +1,170 @@
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
|
||||
class ClaudeConfigError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
Transport = Literal["stdio", "sse", "http", "streamable_http"]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ClaudeMcpServer:
|
||||
name: str
|
||||
transport: Transport
|
||||
command: str = ""
|
||||
args: List[str] = field(default_factory=list)
|
||||
env: Dict[str, str] = field(default_factory=dict)
|
||||
url: str = ""
|
||||
headers: Dict[str, str] = field(default_factory=dict)
|
||||
enabled: bool = True
|
||||
|
||||
|
||||
def _normalize_transport(value: Optional[str]) -> Transport:
|
||||
if not value:
|
||||
return "streamable_http"
|
||||
v = value.strip().lower().replace("-", "_")
|
||||
if v in ("streamable_http", "streamablehttp", "streamable"):
|
||||
return "streamable_http"
|
||||
if v in ("http",):
|
||||
return "http"
|
||||
if v in ("sse",):
|
||||
return "sse"
|
||||
if v in ("stdio",):
|
||||
return "stdio"
|
||||
raise ClaudeConfigError(f"unsupported transport: {value}")
|
||||
|
||||
|
||||
def _coerce_str_list(value: Any, field_name: str) -> List[str]:
|
||||
if value is None:
|
||||
return []
|
||||
if isinstance(value, list):
|
||||
return [str(v) for v in value]
|
||||
raise ClaudeConfigError(f"{field_name} must be a list")
|
||||
|
||||
|
||||
def _coerce_str_dict(value: Any, field_name: str) -> Dict[str, str]:
|
||||
if value is None:
|
||||
return {}
|
||||
if isinstance(value, dict):
|
||||
return {str(k): str(v) for k, v in value.items()}
|
||||
raise ClaudeConfigError(f"{field_name} must be an object")
|
||||
|
||||
|
||||
def parse_claude_mcp_config(config_json: str) -> List[ClaudeMcpServer]:
|
||||
"""Parse Claude Desktop style MCP config JSON.
|
||||
|
||||
Supported:
|
||||
- Full object: {"mcpServers": {...}}
|
||||
- Direct mapping: {...} treated as mcpServers
|
||||
"""
|
||||
text = (config_json or "").strip()
|
||||
if not text:
|
||||
return []
|
||||
|
||||
try:
|
||||
data = json.loads(text)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ClaudeConfigError(f"invalid JSON: {e}") from e
|
||||
|
||||
if not isinstance(data, dict):
|
||||
raise ClaudeConfigError("config must be a JSON object")
|
||||
|
||||
servers_obj = data.get("mcpServers", data)
|
||||
if not isinstance(servers_obj, dict):
|
||||
raise ClaudeConfigError("mcpServers must be an object")
|
||||
|
||||
servers: List[ClaudeMcpServer] = []
|
||||
for name, raw in servers_obj.items():
|
||||
if not isinstance(name, str) or not name.strip():
|
||||
raise ClaudeConfigError("server name must be a non-empty string")
|
||||
if not isinstance(raw, dict):
|
||||
raise ClaudeConfigError(f"server '{name}' must be an object")
|
||||
|
||||
enabled = bool(raw.get("enabled", True))
|
||||
command = str(raw.get("command", "") or "")
|
||||
url = str(raw.get("url", "") or "")
|
||||
args = _coerce_str_list(raw.get("args"), "args")
|
||||
env = _coerce_str_dict(raw.get("env"), "env")
|
||||
headers = _coerce_str_dict(raw.get("headers"), "headers")
|
||||
|
||||
transport_hint = raw.get("transport", raw.get("type"))
|
||||
|
||||
if command:
|
||||
transport: Transport = "stdio"
|
||||
elif url:
|
||||
try:
|
||||
transport = _normalize_transport(str(transport_hint) if transport_hint is not None else None)
|
||||
except ClaudeConfigError:
|
||||
transport = "streamable_http"
|
||||
else:
|
||||
raise ClaudeConfigError(f"server '{name}' must have either 'command' or 'url'")
|
||||
|
||||
servers.append(
|
||||
ClaudeMcpServer(
|
||||
name=name,
|
||||
transport=transport,
|
||||
command=command,
|
||||
args=args,
|
||||
env=env,
|
||||
url=url,
|
||||
headers=headers,
|
||||
enabled=enabled,
|
||||
)
|
||||
)
|
||||
|
||||
return servers
|
||||
|
||||
|
||||
def legacy_servers_list_to_claude_config(servers_list_json: str) -> str:
|
||||
"""Convert legacy v1.x servers list (JSON array) to Claude mcpServers JSON.
|
||||
|
||||
Legacy item schema:
|
||||
{"name","enabled","transport","url","headers","command","args","env"}
|
||||
"""
|
||||
text = (servers_list_json or "").strip()
|
||||
if not text:
|
||||
return ""
|
||||
try:
|
||||
data = json.loads(text)
|
||||
except json.JSONDecodeError:
|
||||
return ""
|
||||
if isinstance(data, dict):
|
||||
data = [data]
|
||||
if not isinstance(data, list):
|
||||
return ""
|
||||
|
||||
mcp_servers: Dict[str, Any] = {}
|
||||
for item in data:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
name = str(item.get("name", "") or "").strip()
|
||||
if not name:
|
||||
continue
|
||||
enabled = bool(item.get("enabled", True))
|
||||
transport = str(item.get("transport", "") or "").strip().lower().replace("-", "_")
|
||||
|
||||
if transport == "stdio" or item.get("command"):
|
||||
entry: Dict[str, Any] = {
|
||||
"enabled": enabled,
|
||||
"command": item.get("command", "") or "",
|
||||
"args": item.get("args", []) or [],
|
||||
}
|
||||
if item.get("env"):
|
||||
entry["env"] = item.get("env")
|
||||
mcp_servers[name] = entry
|
||||
continue
|
||||
|
||||
entry = {"enabled": enabled, "url": item.get("url", "") or ""}
|
||||
if item.get("headers"):
|
||||
entry["headers"] = item.get("headers")
|
||||
if transport:
|
||||
entry["transport"] = transport
|
||||
mcp_servers[name] = entry
|
||||
|
||||
if not mcp_servers:
|
||||
return ""
|
||||
return json.dumps({"mcpServers": mcp_servers}, ensure_ascii=False, indent=2)
|
||||
|
||||
1542
plugins/MaiBot_MCPBridgePlugin/mcp_client.py
Normal file
1542
plugins/MaiBot_MCPBridgePlugin/mcp_client.py
Normal file
File diff suppressed because it is too large
Load Diff
3722
plugins/MaiBot_MCPBridgePlugin/plugin.py
Normal file
3722
plugins/MaiBot_MCPBridgePlugin/plugin.py
Normal file
File diff suppressed because it is too large
Load Diff
2
plugins/MaiBot_MCPBridgePlugin/requirements.txt
Normal file
2
plugins/MaiBot_MCPBridgePlugin/requirements.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
# MCP 桥接插件依赖
|
||||
mcp>=1.0.0
|
||||
582
plugins/MaiBot_MCPBridgePlugin/tool_chain.py
Normal file
582
plugins/MaiBot_MCPBridgePlugin/tool_chain.py
Normal file
@@ -0,0 +1,582 @@
|
||||
"""
|
||||
MCP Workflow 模块 v1.9.0
|
||||
支持用户自定义工作流(硬流程),将多个 MCP 工具按顺序执行
|
||||
|
||||
双轨制架构:
|
||||
- 软流程 (ReAct): LLM 自主决策,动态多轮调用工具,灵活但不可预测
|
||||
- 硬流程 (Workflow): 用户预定义的工作流,固定流程,可靠可控
|
||||
|
||||
功能:
|
||||
- Workflow 定义和管理
|
||||
- 顺序执行多个工具(硬流程)
|
||||
- 支持变量替换(使用前序工具的输出)
|
||||
- 自动注册为组合工具供 LLM 调用
|
||||
- 与 ReAct 软流程互补,用户可选择合适的执行方式
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||
|
||||
try:
|
||||
from src.common.logger import get_logger
|
||||
logger = get_logger("mcp_tool_chain")
|
||||
except ImportError:
|
||||
import logging
|
||||
logger = logging.getLogger("mcp_tool_chain")
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolChainStep:
|
||||
"""工具链步骤"""
|
||||
tool_name: str # 要调用的工具名(如 mcp_server_tool)
|
||||
args_template: Dict[str, Any] = field(default_factory=dict) # 参数模板,支持变量替换
|
||||
output_key: str = "" # 输出存储的键名,供后续步骤引用
|
||||
description: str = "" # 步骤描述
|
||||
optional: bool = False # 是否可选(失败时继续执行)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"tool_name": self.tool_name,
|
||||
"args_template": self.args_template,
|
||||
"output_key": self.output_key,
|
||||
"description": self.description,
|
||||
"optional": self.optional,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "ToolChainStep":
|
||||
return cls(
|
||||
tool_name=data.get("tool_name", ""),
|
||||
args_template=data.get("args_template", {}),
|
||||
output_key=data.get("output_key", ""),
|
||||
description=data.get("description", ""),
|
||||
optional=data.get("optional", False),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolChainDefinition:
|
||||
"""工具链定义"""
|
||||
name: str # 工具链名称(将作为组合工具的名称)
|
||||
description: str # 工具链描述(供 LLM 理解)
|
||||
steps: List[ToolChainStep] = field(default_factory=list) # 执行步骤
|
||||
input_params: Dict[str, str] = field(default_factory=dict) # 输入参数定义 {参数名: 描述}
|
||||
enabled: bool = True # 是否启用
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"steps": [step.to_dict() for step in self.steps],
|
||||
"input_params": self.input_params,
|
||||
"enabled": self.enabled,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "ToolChainDefinition":
|
||||
steps = [ToolChainStep.from_dict(s) for s in data.get("steps", [])]
|
||||
return cls(
|
||||
name=data.get("name", ""),
|
||||
description=data.get("description", ""),
|
||||
steps=steps,
|
||||
input_params=data.get("input_params", {}),
|
||||
enabled=data.get("enabled", True),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChainExecutionResult:
|
||||
"""工具链执行结果"""
|
||||
success: bool
|
||||
final_output: str # 最终输出(最后一个步骤的结果)
|
||||
step_results: List[Dict[str, Any]] = field(default_factory=list) # 每个步骤的结果
|
||||
error: str = ""
|
||||
total_duration_ms: float = 0.0
|
||||
|
||||
def to_summary(self) -> str:
|
||||
"""生成执行摘要"""
|
||||
lines = []
|
||||
for i, step in enumerate(self.step_results):
|
||||
status = "✅" if step.get("success") else "❌"
|
||||
tool = step.get("tool_name", "unknown")
|
||||
duration = step.get("duration_ms", 0)
|
||||
lines.append(f"{status} 步骤{i+1}: {tool} ({duration:.0f}ms)")
|
||||
if not step.get("success") and step.get("error"):
|
||||
lines.append(f" 错误: {step['error'][:50]}")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
class ToolChainExecutor:
|
||||
"""工具链执行器"""
|
||||
|
||||
# 变量替换模式: ${step.output_key} 或 ${input.param_name} 或 ${prev}
|
||||
VAR_PATTERN = re.compile(r'\$\{([^}]+)\}')
|
||||
|
||||
def __init__(self, mcp_manager):
|
||||
self._mcp_manager = mcp_manager
|
||||
|
||||
def _resolve_tool_key(self, tool_name: str) -> Optional[str]:
|
||||
"""解析工具名,返回有效的 tool_key
|
||||
|
||||
支持:
|
||||
- 直接使用 tool_key(如 mcp_server_tool)
|
||||
- 使用注册后的工具名(会自动转换 - 和 . 为 _)
|
||||
"""
|
||||
all_tools = self._mcp_manager.all_tools
|
||||
|
||||
# 直接匹配
|
||||
if tool_name in all_tools:
|
||||
return tool_name
|
||||
|
||||
# 尝试转换后匹配(用户可能使用了注册后的名称)
|
||||
normalized = tool_name.replace("-", "_").replace(".", "_")
|
||||
if normalized in all_tools:
|
||||
return normalized
|
||||
|
||||
# 尝试查找包含该名称的工具
|
||||
for key in all_tools.keys():
|
||||
if key.endswith(f"_{tool_name}") or key.endswith(f"_{normalized}"):
|
||||
return key
|
||||
|
||||
return None
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
chain: ToolChainDefinition,
|
||||
input_args: Dict[str, Any],
|
||||
) -> ChainExecutionResult:
|
||||
"""执行工具链
|
||||
|
||||
Args:
|
||||
chain: 工具链定义
|
||||
input_args: 用户输入的参数
|
||||
|
||||
Returns:
|
||||
ChainExecutionResult: 执行结果
|
||||
"""
|
||||
start_time = time.time()
|
||||
step_results = []
|
||||
context = {
|
||||
"input": input_args or {}, # 用户输入,确保不为 None
|
||||
"step": {}, # 各步骤输出,按 output_key 存储
|
||||
"prev": "", # 上一步的输出
|
||||
}
|
||||
|
||||
final_output = ""
|
||||
|
||||
# 验证必需的输入参数
|
||||
missing_params = []
|
||||
for param_name in chain.input_params.keys():
|
||||
if param_name not in context["input"]:
|
||||
missing_params.append(param_name)
|
||||
|
||||
if missing_params:
|
||||
return ChainExecutionResult(
|
||||
success=False,
|
||||
final_output="",
|
||||
error=f"缺少必需参数: {', '.join(missing_params)}",
|
||||
total_duration_ms=(time.time() - start_time) * 1000,
|
||||
)
|
||||
|
||||
for i, step in enumerate(chain.steps):
|
||||
step_start = time.time()
|
||||
step_result = {
|
||||
"step_index": i,
|
||||
"tool_name": step.tool_name,
|
||||
"success": False,
|
||||
"output": "",
|
||||
"error": "",
|
||||
"duration_ms": 0,
|
||||
}
|
||||
|
||||
try:
|
||||
# 替换参数中的变量
|
||||
resolved_args = self._resolve_args(step.args_template, context)
|
||||
step_result["resolved_args"] = resolved_args
|
||||
|
||||
# 解析工具名
|
||||
tool_key = self._resolve_tool_key(step.tool_name)
|
||||
if not tool_key:
|
||||
step_result["error"] = f"工具 {step.tool_name} 不存在"
|
||||
logger.warning(f"工具链步骤 {i+1}: 工具 {step.tool_name} 不存在")
|
||||
|
||||
if not step.optional:
|
||||
step_results.append(step_result)
|
||||
return ChainExecutionResult(
|
||||
success=False,
|
||||
final_output="",
|
||||
step_results=step_results,
|
||||
error=f"步骤 {i+1}: 工具 {step.tool_name} 不存在",
|
||||
total_duration_ms=(time.time() - start_time) * 1000,
|
||||
)
|
||||
step_results.append(step_result)
|
||||
continue
|
||||
|
||||
logger.debug(f"工具链步骤 {i+1}: 调用 {tool_key},参数: {resolved_args}")
|
||||
|
||||
# 调用工具
|
||||
result = await self._mcp_manager.call_tool(tool_key, resolved_args)
|
||||
|
||||
step_duration = (time.time() - step_start) * 1000
|
||||
step_result["duration_ms"] = step_duration
|
||||
|
||||
if result.success:
|
||||
step_result["success"] = True
|
||||
# 确保 content 不为 None
|
||||
content = result.content if result.content is not None else ""
|
||||
step_result["output"] = content
|
||||
|
||||
# 更新上下文
|
||||
context["prev"] = content
|
||||
if step.output_key:
|
||||
context["step"][step.output_key] = content
|
||||
|
||||
final_output = content
|
||||
content_preview = content[:100] if content else "(空)"
|
||||
logger.debug(f"工具链步骤 {i+1} 成功: {content_preview}...")
|
||||
else:
|
||||
step_result["error"] = result.error or "未知错误"
|
||||
logger.warning(f"工具链步骤 {i+1} 失败: {result.error}")
|
||||
|
||||
if not step.optional:
|
||||
step_results.append(step_result)
|
||||
return ChainExecutionResult(
|
||||
success=False,
|
||||
final_output="",
|
||||
step_results=step_results,
|
||||
error=f"步骤 {i+1} ({step.tool_name}) 失败: {result.error}",
|
||||
total_duration_ms=(time.time() - start_time) * 1000,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
step_duration = (time.time() - step_start) * 1000
|
||||
step_result["duration_ms"] = step_duration
|
||||
step_result["error"] = str(e)
|
||||
logger.error(f"工具链步骤 {i+1} 异常: {e}")
|
||||
|
||||
if not step.optional:
|
||||
step_results.append(step_result)
|
||||
return ChainExecutionResult(
|
||||
success=False,
|
||||
final_output="",
|
||||
step_results=step_results,
|
||||
error=f"步骤 {i+1} ({step.tool_name}) 异常: {e}",
|
||||
total_duration_ms=(time.time() - start_time) * 1000,
|
||||
)
|
||||
|
||||
step_results.append(step_result)
|
||||
|
||||
total_duration = (time.time() - start_time) * 1000
|
||||
|
||||
return ChainExecutionResult(
|
||||
success=True,
|
||||
final_output=final_output,
|
||||
step_results=step_results,
|
||||
total_duration_ms=total_duration,
|
||||
)
|
||||
|
||||
def _resolve_args(self, args_template: Dict[str, Any], context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""解析参数模板,替换变量
|
||||
|
||||
支持的变量格式:
|
||||
- ${input.param_name}: 用户输入的参数
|
||||
- ${step.output_key}: 某个步骤的输出
|
||||
- ${prev}: 上一步的输出
|
||||
- ${prev.field}: 上一步输出(JSON)的某个字段
|
||||
"""
|
||||
resolved = {}
|
||||
|
||||
for key, value in args_template.items():
|
||||
if isinstance(value, str):
|
||||
resolved[key] = self._substitute_vars(value, context)
|
||||
elif isinstance(value, dict):
|
||||
resolved[key] = self._resolve_args(value, context)
|
||||
elif isinstance(value, list):
|
||||
resolved[key] = [
|
||||
self._substitute_vars(v, context) if isinstance(v, str) else v
|
||||
for v in value
|
||||
]
|
||||
else:
|
||||
resolved[key] = value
|
||||
|
||||
return resolved
|
||||
|
||||
def _substitute_vars(self, template: str, context: Dict[str, Any]) -> str:
|
||||
"""替换字符串中的变量"""
|
||||
def replacer(match):
|
||||
var_path = match.group(1)
|
||||
return self._get_var_value(var_path, context)
|
||||
|
||||
return self.VAR_PATTERN.sub(replacer, template)
|
||||
|
||||
def _get_var_value(self, var_path: str, context: Dict[str, Any]) -> str:
|
||||
"""获取变量值
|
||||
|
||||
Args:
|
||||
var_path: 变量路径,如 "input.query", "step.search_result", "prev", "prev.id"
|
||||
context: 上下文
|
||||
"""
|
||||
parts = self._parse_var_path(var_path)
|
||||
|
||||
if not parts:
|
||||
return ""
|
||||
|
||||
# 获取根对象
|
||||
root = parts[0]
|
||||
if root not in context:
|
||||
logger.warning(f"变量 {var_path} 的根 '{root}' 不存在")
|
||||
return ""
|
||||
|
||||
value = context[root]
|
||||
|
||||
# 遍历路径
|
||||
for part in parts[1:]:
|
||||
if isinstance(value, str):
|
||||
parsed = self._try_parse_json(value)
|
||||
if parsed is not None:
|
||||
value = parsed
|
||||
|
||||
if isinstance(value, dict):
|
||||
value = value.get(part, "")
|
||||
elif isinstance(value, list):
|
||||
if part.isdigit():
|
||||
idx = int(part)
|
||||
value = value[idx] if 0 <= idx < len(value) else ""
|
||||
else:
|
||||
value = ""
|
||||
else:
|
||||
value = ""
|
||||
|
||||
# 确保返回字符串
|
||||
if isinstance(value, (dict, list)):
|
||||
return json.dumps(value, ensure_ascii=False)
|
||||
if value is None:
|
||||
return ""
|
||||
if value == "":
|
||||
return ""
|
||||
return str(value)
|
||||
|
||||
def _try_parse_json(self, value: str) -> Optional[Any]:
|
||||
"""尝试将字符串解析为 JSON 对象,失败则返回 None。"""
|
||||
if not value:
|
||||
return None
|
||||
try:
|
||||
return json.loads(value)
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
|
||||
def _parse_var_path(self, var_path: str) -> List[str]:
|
||||
"""解析变量路径,支持点号与下标写法。
|
||||
|
||||
支持:
|
||||
- step.geo.return.0.location
|
||||
- step.geo.return[0].location
|
||||
- step.geo['return'][0]['location']
|
||||
"""
|
||||
if not var_path:
|
||||
return []
|
||||
|
||||
tokens: List[str] = []
|
||||
buf: List[str] = []
|
||||
in_bracket = False
|
||||
in_quote = False
|
||||
quote_char = ""
|
||||
|
||||
def flush_buf() -> None:
|
||||
if buf:
|
||||
token = "".join(buf).strip()
|
||||
if token:
|
||||
tokens.append(token)
|
||||
buf.clear()
|
||||
|
||||
i = 0
|
||||
while i < len(var_path):
|
||||
ch = var_path[i]
|
||||
|
||||
if not in_bracket and ch == ".":
|
||||
flush_buf()
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if not in_bracket and ch == "[":
|
||||
flush_buf()
|
||||
in_bracket = True
|
||||
in_quote = False
|
||||
quote_char = ""
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if in_bracket and not in_quote and ch == "]":
|
||||
flush_buf()
|
||||
in_bracket = False
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if in_bracket and ch in ("'", '"'):
|
||||
if not in_quote:
|
||||
in_quote = True
|
||||
quote_char = ch
|
||||
i += 1
|
||||
continue
|
||||
if quote_char == ch:
|
||||
in_quote = False
|
||||
quote_char = ""
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if in_bracket and not in_quote:
|
||||
if ch.isspace():
|
||||
i += 1
|
||||
continue
|
||||
if ch == ",":
|
||||
i += 1
|
||||
continue
|
||||
|
||||
buf.append(ch)
|
||||
i += 1
|
||||
|
||||
flush_buf()
|
||||
|
||||
if in_bracket or in_quote:
|
||||
return [p for p in var_path.split(".") if p]
|
||||
|
||||
return tokens
|
||||
|
||||
|
||||
class ToolChainManager:
|
||||
"""工具链管理器"""
|
||||
|
||||
_instance: Optional["ToolChainManager"] = None
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if self._initialized:
|
||||
return
|
||||
self._initialized = True
|
||||
self._chains: Dict[str, ToolChainDefinition] = {}
|
||||
self._executor: Optional[ToolChainExecutor] = None
|
||||
|
||||
def set_executor(self, mcp_manager) -> None:
|
||||
"""设置执行器"""
|
||||
self._executor = ToolChainExecutor(mcp_manager)
|
||||
|
||||
def add_chain(self, chain: ToolChainDefinition) -> bool:
|
||||
"""添加工具链"""
|
||||
if not chain.name:
|
||||
logger.error("工具链名称不能为空")
|
||||
return False
|
||||
|
||||
if chain.name in self._chains:
|
||||
logger.warning(f"工具链 {chain.name} 已存在,将被覆盖")
|
||||
|
||||
self._chains[chain.name] = chain
|
||||
logger.info(f"已添加工具链: {chain.name} ({len(chain.steps)} 个步骤)")
|
||||
return True
|
||||
|
||||
def remove_chain(self, name: str) -> bool:
|
||||
"""移除工具链"""
|
||||
if name in self._chains:
|
||||
del self._chains[name]
|
||||
logger.info(f"已移除工具链: {name}")
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_chain(self, name: str) -> Optional[ToolChainDefinition]:
|
||||
"""获取工具链"""
|
||||
return self._chains.get(name)
|
||||
|
||||
def get_all_chains(self) -> Dict[str, ToolChainDefinition]:
|
||||
"""获取所有工具链"""
|
||||
return self._chains.copy()
|
||||
|
||||
def get_enabled_chains(self) -> Dict[str, ToolChainDefinition]:
|
||||
"""获取所有启用的工具链"""
|
||||
return {name: chain for name, chain in self._chains.items() if chain.enabled}
|
||||
|
||||
async def execute_chain(
|
||||
self,
|
||||
chain_name: str,
|
||||
input_args: Dict[str, Any],
|
||||
) -> ChainExecutionResult:
|
||||
"""执行工具链"""
|
||||
chain = self._chains.get(chain_name)
|
||||
if not chain:
|
||||
return ChainExecutionResult(
|
||||
success=False,
|
||||
final_output="",
|
||||
error=f"工具链 {chain_name} 不存在",
|
||||
)
|
||||
|
||||
if not chain.enabled:
|
||||
return ChainExecutionResult(
|
||||
success=False,
|
||||
final_output="",
|
||||
error=f"工具链 {chain_name} 已禁用",
|
||||
)
|
||||
|
||||
if not self._executor:
|
||||
return ChainExecutionResult(
|
||||
success=False,
|
||||
final_output="",
|
||||
error="工具链执行器未初始化",
|
||||
)
|
||||
|
||||
return await self._executor.execute(chain, input_args)
|
||||
|
||||
def load_from_json(self, json_str: str) -> Tuple[int, List[str]]:
|
||||
"""从 JSON 字符串加载工具链配置
|
||||
|
||||
Returns:
|
||||
(成功加载数量, 错误列表)
|
||||
"""
|
||||
errors = []
|
||||
loaded = 0
|
||||
|
||||
try:
|
||||
data = json.loads(json_str) if json_str.strip() else []
|
||||
except json.JSONDecodeError as e:
|
||||
return 0, [f"JSON 解析失败: {e}"]
|
||||
|
||||
if not isinstance(data, list):
|
||||
data = [data]
|
||||
|
||||
for i, item in enumerate(data):
|
||||
try:
|
||||
chain = ToolChainDefinition.from_dict(item)
|
||||
if not chain.name:
|
||||
errors.append(f"第 {i+1} 个工具链缺少名称")
|
||||
continue
|
||||
if not chain.steps:
|
||||
errors.append(f"工具链 {chain.name} 没有步骤")
|
||||
continue
|
||||
|
||||
self.add_chain(chain)
|
||||
loaded += 1
|
||||
except Exception as e:
|
||||
errors.append(f"第 {i+1} 个工具链解析失败: {e}")
|
||||
|
||||
return loaded, errors
|
||||
|
||||
def export_to_json(self, pretty: bool = True) -> str:
|
||||
"""导出所有工具链为 JSON"""
|
||||
chains_data = [chain.to_dict() for chain in self._chains.values()]
|
||||
if pretty:
|
||||
return json.dumps(chains_data, ensure_ascii=False, indent=2)
|
||||
return json.dumps(chains_data, ensure_ascii=False)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""清空所有工具链"""
|
||||
self._chains.clear()
|
||||
|
||||
|
||||
# 全局工具链管理器实例
|
||||
tool_chain_manager = ToolChainManager()
|
||||
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "MaiBot"
|
||||
version = "0.11.0"
|
||||
version = "0.11.6"
|
||||
description = "MaiCore 是一个基于大语言模型的可交互智能体"
|
||||
requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
@@ -14,6 +14,7 @@ dependencies = [
|
||||
"json-repair>=0.47.6",
|
||||
"maim-message",
|
||||
"matplotlib>=3.10.3",
|
||||
"msgpack>=1.1.2",
|
||||
"numpy>=2.2.6",
|
||||
"openai>=1.95.0",
|
||||
"pandas>=2.3.1",
|
||||
@@ -23,6 +24,7 @@ dependencies = [
|
||||
"pydantic>=2.11.7",
|
||||
"pypinyin>=0.54.0",
|
||||
"python-dotenv>=1.1.1",
|
||||
"python-multipart>=0.0.20",
|
||||
"quick-algo>=0.1.3",
|
||||
"rich>=14.0.0",
|
||||
"ruff>=0.12.2",
|
||||
@@ -32,9 +34,14 @@ dependencies = [
|
||||
"tomlkit>=0.13.3",
|
||||
"urllib3>=2.5.0",
|
||||
"uvicorn>=0.35.0",
|
||||
"zstandard>=0.25.0",
|
||||
]
|
||||
|
||||
|
||||
[tool.uv]
|
||||
index-url = "https://pypi.tuna.tsinghua.edu.cn/simple"
|
||||
|
||||
|
||||
[tool.ruff]
|
||||
|
||||
include = ["*.py"]
|
||||
|
||||
386
scripts/delete_lpmm_items.py
Normal file
386
scripts/delete_lpmm_items.py
Normal file
@@ -0,0 +1,386 @@
|
||||
import argparse
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple, Dict, Any
|
||||
import json
|
||||
import os
|
||||
|
||||
# 强制使用 utf-8,避免控制台编码报错
|
||||
try:
|
||||
if hasattr(sys.stdout, "reconfigure"):
|
||||
sys.stdout.reconfigure(encoding="utf-8")
|
||||
if hasattr(sys.stderr, "reconfigure"):
|
||||
sys.stderr.reconfigure(encoding="utf-8")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 确保能找到 src 包
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
from src.chat.knowledge.embedding_store import EmbeddingManager
|
||||
from src.chat.knowledge.kg_manager import KGManager
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.knowledge.utils.hash import get_sha256
|
||||
|
||||
logger = get_logger("delete_lpmm_items")
|
||||
|
||||
|
||||
def read_hashes(file_path: Path) -> List[str]:
|
||||
"""读取哈希列表,跳过空行"""
|
||||
hashes: List[str] = []
|
||||
for line in file_path.read_text(encoding="utf-8").splitlines():
|
||||
val = line.strip()
|
||||
if not val:
|
||||
continue
|
||||
hashes.append(val)
|
||||
return hashes
|
||||
|
||||
|
||||
def read_openie_hashes(file_path: Path) -> List[str]:
|
||||
"""从 OpenIE JSON 中提取 idx 作为段落哈希"""
|
||||
data: Dict[str, Any] = json.loads(file_path.read_text(encoding="utf-8"))
|
||||
docs = data.get("docs", []) if isinstance(data, dict) else []
|
||||
hashes: List[str] = []
|
||||
for doc in docs:
|
||||
idx = doc.get("idx") if isinstance(doc, dict) else None
|
||||
if isinstance(idx, str) and idx.strip():
|
||||
hashes.append(idx.strip())
|
||||
return hashes
|
||||
|
||||
|
||||
def normalize_paragraph_keys(raw_hashes: List[str]) -> Tuple[List[str], List[str]]:
|
||||
"""将输入规范为完整键和纯哈希两份列表"""
|
||||
keys: List[str] = []
|
||||
hashes: List[str] = []
|
||||
for h in raw_hashes:
|
||||
if h.startswith("paragraph-"):
|
||||
keys.append(h)
|
||||
hashes.append(h.replace("paragraph-", "", 1))
|
||||
else:
|
||||
keys.append(f"paragraph-{h}")
|
||||
hashes.append(h)
|
||||
return keys, hashes
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Delete paragraphs from LPMM knowledge base (vectors + graph).")
|
||||
parser.add_argument("--hash-file", help="文本文件路径,每行一个 paragraph 哈希或带前缀键")
|
||||
parser.add_argument("--openie-file", help="OpenIE 输出文件(JSON),将其 docs.idx 作为待删段落哈希")
|
||||
parser.add_argument("--raw-file", help="原始 txt 语料文件(按空行分段),可结合 --raw-index 使用")
|
||||
parser.add_argument(
|
||||
"--raw-index",
|
||||
help="在 --raw-file 中要删除的段落索引,1 基,支持逗号分隔,例如 1,3",
|
||||
)
|
||||
parser.add_argument("--search-text", help="在当前段落库中按子串搜索匹配段落并交互选择删除")
|
||||
parser.add_argument(
|
||||
"--search-limit",
|
||||
type=int,
|
||||
default=10,
|
||||
help="--search-text 模式下最多展示的候选段落数量",
|
||||
)
|
||||
parser.add_argument("--delete-entities", action="store_true", help="同时删除 OpenIE 文件中的实体节点/嵌入")
|
||||
parser.add_argument("--delete-relations", action="store_true", help="同时删除 OpenIE 文件中的关系嵌入")
|
||||
parser.add_argument("--remove-orphan-entities", action="store_true", help="删除删除后孤立的实体节点")
|
||||
parser.add_argument("--dry-run", action="store_true", help="仅预览将删除的项,不实际修改")
|
||||
parser.add_argument("--yes", action="store_true", help="跳过交互确认,直接执行删除(谨慎使用)")
|
||||
parser.add_argument(
|
||||
"--max-delete-nodes",
|
||||
type=int,
|
||||
default=2000,
|
||||
help="单次最大允许删除的节点数量(段落+实体),超过则需要显式确认或调整该参数",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--non-interactive",
|
||||
action="store_true",
|
||||
help=(
|
||||
"非交互模式:不再通过 input() 询问任何信息;"
|
||||
"在该模式下,如果需要交互(例如 --search-text 未指定具体条目、未提供 --yes),"
|
||||
"会直接报错退出。"
|
||||
),
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# 至少需要一种来源
|
||||
if not (args.hash_file or args.openie_file or args.raw_file or args.search_text):
|
||||
logger.error("必须指定 --hash-file / --openie-file / --raw-file / --search-text 之一")
|
||||
sys.exit(1)
|
||||
|
||||
raw_hashes: List[str] = []
|
||||
raw_entities: List[str] = []
|
||||
raw_relations: List[str] = []
|
||||
|
||||
if args.hash_file:
|
||||
hash_file = Path(args.hash_file)
|
||||
if not hash_file.exists():
|
||||
logger.error(f"哈希文件不存在: {hash_file}")
|
||||
sys.exit(1)
|
||||
raw_hashes.extend(read_hashes(hash_file))
|
||||
|
||||
if args.openie_file:
|
||||
openie_path = Path(args.openie_file)
|
||||
if not openie_path.exists():
|
||||
logger.error(f"OpenIE 文件不存在: {openie_path}")
|
||||
sys.exit(1)
|
||||
# 段落
|
||||
raw_hashes.extend(read_openie_hashes(openie_path))
|
||||
# 实体/关系(实体同时包含 extracted_entities 与三元组主语/宾语,以匹配 KG 构图逻辑)
|
||||
try:
|
||||
data = json.loads(openie_path.read_text(encoding="utf-8"))
|
||||
docs = data.get("docs", []) if isinstance(data, dict) else []
|
||||
for doc in docs:
|
||||
if not isinstance(doc, dict):
|
||||
continue
|
||||
ents = doc.get("extracted_entities", [])
|
||||
if isinstance(ents, list):
|
||||
raw_entities.extend([e for e in ents if isinstance(e, str)])
|
||||
triples = doc.get("extracted_triples", [])
|
||||
if isinstance(triples, list):
|
||||
for t in triples:
|
||||
if isinstance(t, list) and len(t) == 3:
|
||||
subj, _, obj = t
|
||||
if isinstance(subj, str):
|
||||
raw_entities.append(subj)
|
||||
if isinstance(obj, str):
|
||||
raw_entities.append(obj)
|
||||
raw_relations.append(str(tuple(t)))
|
||||
except Exception as e:
|
||||
logger.error(f"读取 OpenIE 文件失败: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
# 从原始 txt 语料按段落索引选择删除
|
||||
if args.raw_file:
|
||||
raw_path = Path(args.raw_file)
|
||||
if not raw_path.exists():
|
||||
logger.error(f"原始语料文件不存在: {raw_path}")
|
||||
sys.exit(1)
|
||||
text = raw_path.read_text(encoding="utf-8")
|
||||
paragraphs: List[str] = []
|
||||
buf = []
|
||||
for line in text.splitlines():
|
||||
if line.strip() == "":
|
||||
if buf:
|
||||
paragraphs.append("\n".join(buf).strip())
|
||||
buf = []
|
||||
else:
|
||||
buf.append(line)
|
||||
if buf:
|
||||
paragraphs.append("\n".join(buf).strip())
|
||||
|
||||
if not paragraphs:
|
||||
logger.error(f"原始语料文件 {raw_path} 中没有解析到任何段落")
|
||||
sys.exit(1)
|
||||
|
||||
if not args.raw_index:
|
||||
logger.info(f"{raw_path} 共解析出 {len(paragraphs)} 个段落,请通过 --raw-index 指定要删除的段落,例如 --raw-index 1,3")
|
||||
sys.exit(1)
|
||||
|
||||
# 解析索引列表(1-based)
|
||||
try:
|
||||
idx_list = [int(x.strip()) for x in str(args.raw_index).split(",") if x.strip()]
|
||||
except ValueError:
|
||||
logger.error(f"--raw-index 解析失败: {args.raw_index}")
|
||||
sys.exit(1)
|
||||
|
||||
for idx in idx_list:
|
||||
if idx < 1 or idx > len(paragraphs):
|
||||
logger.error(f"--raw-index 包含无效索引 {idx}(有效范围 1~{len(paragraphs)})")
|
||||
sys.exit(1)
|
||||
|
||||
logger.info("根据原始语料选择段落:")
|
||||
for idx in idx_list:
|
||||
para = paragraphs[idx - 1]
|
||||
h = get_sha256(para)
|
||||
logger.info(f"- 第 {idx} 段,hash={h},内容预览:{para[:80]}")
|
||||
raw_hashes.append(h)
|
||||
|
||||
# 在现有库中按子串搜索候选段落并交互选择
|
||||
if args.search_text:
|
||||
search_text = args.search_text.strip()
|
||||
if not search_text:
|
||||
logger.error("--search-text 不能为空")
|
||||
sys.exit(1)
|
||||
logger.info(f"正在根据关键字在现有段落库中搜索:{search_text!r}")
|
||||
em_search = EmbeddingManager()
|
||||
try:
|
||||
em_search.load_from_file()
|
||||
except Exception as e:
|
||||
logger.error(f"加载嵌入库失败,无法使用 --search-text 功能: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
candidates = []
|
||||
for key, item in em_search.paragraphs_embedding_store.store.items():
|
||||
if search_text in item.str:
|
||||
candidates.append((key, item.str))
|
||||
if len(candidates) >= args.search_limit:
|
||||
break
|
||||
|
||||
if not candidates:
|
||||
logger.info("未在现有段落库中找到包含该关键字的段落")
|
||||
else:
|
||||
logger.info("找到以下候选段落(输入序号选择要删除的条目,可用逗号分隔,多选):")
|
||||
for i, (key, text) in enumerate(candidates, start=1):
|
||||
logger.info(f"{i}. {key} | {text[:80]}")
|
||||
if args.non_interactive:
|
||||
logger.error(
|
||||
"当前处于非交互模式,无法通过输入序号选择要删除的候选段落;"
|
||||
"如需脚本化删除,请改用 --hash-file / --openie-file / --raw-file 等方式。"
|
||||
)
|
||||
sys.exit(1)
|
||||
choice = input("请输入要删除的序号列表(如 1,3),或直接回车取消:").strip()
|
||||
if choice:
|
||||
try:
|
||||
idxs = [int(x.strip()) for x in choice.split(",") if x.strip()]
|
||||
except ValueError:
|
||||
logger.error("输入的序号列表无法解析,已取消 --search-text 删除")
|
||||
else:
|
||||
for i in idxs:
|
||||
if 1 <= i <= len(candidates):
|
||||
key, _ = candidates[i - 1]
|
||||
# key 已是完整的 paragraph-xxx
|
||||
if key.startswith("paragraph-"):
|
||||
raw_hashes.append(key.split("paragraph-", 1)[1])
|
||||
else:
|
||||
logger.warning(f"忽略无效序号: {i}")
|
||||
|
||||
# 去重但保持顺序
|
||||
seen = set()
|
||||
raw_hashes = [h for h in raw_hashes if not (h in seen or seen.add(h))]
|
||||
|
||||
if not raw_hashes:
|
||||
logger.error("未读取到任何待删哈希,无操作")
|
||||
sys.exit(1)
|
||||
|
||||
keys, pg_hashes = normalize_paragraph_keys(raw_hashes)
|
||||
|
||||
ent_hashes: List[str] = []
|
||||
rel_hashes: List[str] = []
|
||||
if args.delete_entities and raw_entities:
|
||||
ent_hashes = [get_sha256(e) for e in raw_entities]
|
||||
if args.delete_relations and raw_relations:
|
||||
rel_hashes = [get_sha256(r) for r in raw_relations]
|
||||
|
||||
logger.info("=== 删除操作预备 ===")
|
||||
logger.info("请确保已备份 data/embedding 与 data/rag,必要时可使用 --dry-run 预览")
|
||||
logger.info(f"待删除段落数量: {len(keys)}")
|
||||
logger.info(f"示例: {keys[:5]}")
|
||||
if ent_hashes:
|
||||
logger.info(f"待删除实体数量: {len(ent_hashes)}")
|
||||
if rel_hashes:
|
||||
logger.info(f"待删除关系数量: {len(rel_hashes)}")
|
||||
|
||||
total_nodes_to_delete = len(pg_hashes) + (len(ent_hashes) if args.delete_entities else 0)
|
||||
logger.info(f"本次预计删除节点总数(段落+实体): {total_nodes_to_delete}")
|
||||
|
||||
if args.dry_run:
|
||||
logger.info("dry-run 模式,未执行删除")
|
||||
return
|
||||
|
||||
# 大批次删除保护
|
||||
if total_nodes_to_delete > args.max_delete_nodes and not args.yes:
|
||||
logger.error(
|
||||
f"本次预计删除节点 {total_nodes_to_delete} 个,超过阈值 {args.max_delete_nodes}。"
|
||||
" 为避免误删,请降低批次规模或使用 --max-delete-nodes 调整阈值,并加上 --yes 明确确认。"
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
# 交互确认
|
||||
if not args.yes:
|
||||
if args.non_interactive:
|
||||
logger.error(
|
||||
"当前处于非交互模式且未指定 --yes,出于安全考虑,删除操作已被拒绝。\n"
|
||||
"如确认需要在非交互模式下执行删除,请显式添加 --yes 参数。"
|
||||
)
|
||||
sys.exit(1)
|
||||
confirm = input("确认删除上述数据?输入大写 YES 以继续,其他任意键取消: ").strip()
|
||||
if confirm != "YES":
|
||||
logger.info("用户取消删除操作")
|
||||
return
|
||||
|
||||
# 加载嵌入与图
|
||||
embed_manager = EmbeddingManager()
|
||||
kg_manager = KGManager()
|
||||
|
||||
try:
|
||||
embed_manager.load_from_file()
|
||||
kg_manager.load_from_file()
|
||||
except Exception as e:
|
||||
logger.error(f"加载现有知识库失败: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
# 记录删除前全局统计,便于对比
|
||||
before_para_vec = len(embed_manager.paragraphs_embedding_store.store)
|
||||
before_ent_vec = len(embed_manager.entities_embedding_store.store)
|
||||
before_rel_vec = len(embed_manager.relation_embedding_store.store)
|
||||
before_nodes = len(kg_manager.graph.get_node_list())
|
||||
before_edges = len(kg_manager.graph.get_edge_list())
|
||||
logger.info(
|
||||
f"删除前统计: 段落向量={before_para_vec}, 实体向量={before_ent_vec}, 关系向量={before_rel_vec}, "
|
||||
f"KG节点={before_nodes}, KG边={before_edges}"
|
||||
)
|
||||
|
||||
# 删除向量
|
||||
deleted, skipped = embed_manager.paragraphs_embedding_store.delete_items(keys)
|
||||
embed_manager.stored_pg_hashes = set(embed_manager.paragraphs_embedding_store.store.keys())
|
||||
logger.info(f"段落向量删除完成,删除: {deleted}, 跳过: {skipped}")
|
||||
ent_deleted = ent_skipped = rel_deleted = rel_skipped = 0
|
||||
if ent_hashes:
|
||||
ent_keys = [f"entity-{h}" for h in ent_hashes]
|
||||
ent_deleted, ent_skipped = embed_manager.entities_embedding_store.delete_items(ent_keys)
|
||||
logger.info(f"实体向量删除完成,删除: {ent_deleted}, 跳过: {ent_skipped}")
|
||||
if rel_hashes:
|
||||
rel_keys = [f"relation-{h}" for h in rel_hashes]
|
||||
rel_deleted, rel_skipped = embed_manager.relation_embedding_store.delete_items(rel_keys)
|
||||
logger.info(f"关系向量删除完成,删除: {rel_deleted}, 跳过: {rel_skipped}")
|
||||
|
||||
# 删除图节点/边
|
||||
kg_result = kg_manager.delete_paragraphs(
|
||||
pg_hashes,
|
||||
ent_hashes=ent_hashes if args.delete_entities else None,
|
||||
remove_orphan_entities=args.remove_orphan_entities,
|
||||
)
|
||||
logger.info(
|
||||
f"KG 删除完成,删除: {kg_result.get('deleted', 0)}, 跳过: {kg_result.get('skipped', 0)}, "
|
||||
f"孤立实体清理: {kg_result.get('orphan_removed', 0)}"
|
||||
)
|
||||
|
||||
# 重建索引并保存
|
||||
logger.info("重建 Faiss 索引并保存嵌入文件...")
|
||||
embed_manager.rebuild_faiss_index()
|
||||
embed_manager.save_to_file()
|
||||
|
||||
logger.info("保存 KG 数据...")
|
||||
kg_manager.save_to_file()
|
||||
|
||||
# 删除后统计
|
||||
after_para_vec = len(embed_manager.paragraphs_embedding_store.store)
|
||||
after_ent_vec = len(embed_manager.entities_embedding_store.store)
|
||||
after_rel_vec = len(embed_manager.relation_embedding_store.store)
|
||||
after_nodes = len(kg_manager.graph.get_node_list())
|
||||
after_edges = len(kg_manager.graph.get_edge_list())
|
||||
|
||||
logger.info(
|
||||
"删除后统计: 段落向量=%d(%+d), 实体向量=%d(%+d), 关系向量=%d(%+d), KG节点=%d(%+d), KG边=%d(%+d)"
|
||||
% (
|
||||
after_para_vec,
|
||||
after_para_vec - before_para_vec,
|
||||
after_ent_vec,
|
||||
after_ent_vec - before_ent_vec,
|
||||
after_rel_vec,
|
||||
after_rel_vec - before_rel_vec,
|
||||
after_nodes,
|
||||
after_nodes - before_nodes,
|
||||
after_edges,
|
||||
after_edges - before_edges,
|
||||
)
|
||||
)
|
||||
|
||||
logger.info("删除流程完成")
|
||||
print(
|
||||
"\n[NOTICE] 删除脚本执行完毕。如主程序(聊天 / WebUI)已在运行,"
|
||||
"请重启主程序,或在主程序内部调用一次 lpmm_start_up() 以应用最新 LPMM 知识库。"
|
||||
)
|
||||
print("[NOTICE] 如果不清楚 lpmm_start_up 是什么,直接重启主程序即可。")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
567
scripts/expression_merge_simulation.py
Normal file
567
scripts/expression_merge_simulation.py
Normal file
@@ -0,0 +1,567 @@
|
||||
"""
|
||||
模拟 Expression 合并过程
|
||||
|
||||
用法:
|
||||
python scripts/expression_merge_simulation.py
|
||||
或指定 chat_id:
|
||||
python scripts/expression_merge_simulation.py --chat-id <chat_id>
|
||||
或指定相似度阈值:
|
||||
python scripts/expression_merge_simulation.py --similarity-threshold 0.8
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import json
|
||||
import argparse
|
||||
import asyncio
|
||||
import random
|
||||
from typing import List, Dict, Tuple, Optional
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
|
||||
# Add project root to Python path
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.insert(0, project_root)
|
||||
|
||||
# Import after setting up path (required for project imports)
|
||||
from src.common.database.database_model import Expression, ChatStreams # noqa: E402
|
||||
from src.bw_learner.learner_utils import calculate_style_similarity # noqa: E402
|
||||
from src.llm_models.utils_model import LLMRequest # noqa: E402
|
||||
from src.config.config import model_config # noqa: E402
|
||||
|
||||
|
||||
def get_chat_name(chat_id: str) -> str:
|
||||
"""根据 chat_id 获取聊天名称"""
|
||||
try:
|
||||
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
|
||||
if chat_stream is None:
|
||||
return f"未知聊天 ({chat_id[:8]}...)"
|
||||
|
||||
if chat_stream.group_name:
|
||||
return f"{chat_stream.group_name}"
|
||||
elif chat_stream.user_nickname:
|
||||
return f"{chat_stream.user_nickname}的私聊"
|
||||
else:
|
||||
return f"未知聊天 ({chat_id[:8]}...)"
|
||||
except Exception:
|
||||
return f"查询失败 ({chat_id[:8]}...)"
|
||||
|
||||
|
||||
def parse_content_list(stored_list: Optional[str]) -> List[str]:
|
||||
"""解析 content_list JSON 字符串为列表"""
|
||||
if not stored_list:
|
||||
return []
|
||||
try:
|
||||
data = json.loads(stored_list)
|
||||
except json.JSONDecodeError:
|
||||
return []
|
||||
return [str(item) for item in data if isinstance(item, str)] if isinstance(data, list) else []
|
||||
|
||||
|
||||
def parse_style_list(stored_list: Optional[str]) -> List[str]:
|
||||
"""解析 style_list JSON 字符串为列表"""
|
||||
if not stored_list:
|
||||
return []
|
||||
try:
|
||||
data = json.loads(stored_list)
|
||||
except json.JSONDecodeError:
|
||||
return []
|
||||
return [str(item) for item in data if isinstance(item, str)] if isinstance(data, list) else []
|
||||
|
||||
|
||||
def find_exact_style_match(
|
||||
expressions: List[Expression],
|
||||
target_style: str,
|
||||
chat_id: str,
|
||||
exclude_ids: set
|
||||
) -> Optional[Expression]:
|
||||
"""
|
||||
查找具有完全匹配 style 的 Expression 记录
|
||||
检查 style 字段和 style_list 中的每一项
|
||||
"""
|
||||
for expr in expressions:
|
||||
if expr.chat_id != chat_id or expr.id in exclude_ids:
|
||||
continue
|
||||
|
||||
# 检查 style 字段
|
||||
if expr.style == target_style:
|
||||
return expr
|
||||
|
||||
# 检查 style_list 中的每一项
|
||||
style_list = parse_style_list(expr.style_list)
|
||||
if target_style in style_list:
|
||||
return expr
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def find_similar_style_expression(
|
||||
expressions: List[Expression],
|
||||
target_style: str,
|
||||
chat_id: str,
|
||||
similarity_threshold: float,
|
||||
exclude_ids: set
|
||||
) -> Optional[Tuple[Expression, float]]:
|
||||
"""
|
||||
查找具有相似 style 的 Expression 记录
|
||||
检查 style 字段和 style_list 中的每一项
|
||||
|
||||
Returns:
|
||||
(Expression, similarity) 或 None
|
||||
"""
|
||||
best_match = None
|
||||
best_similarity = 0.0
|
||||
|
||||
for expr in expressions:
|
||||
if expr.chat_id != chat_id or expr.id in exclude_ids:
|
||||
continue
|
||||
|
||||
# 检查 style 字段
|
||||
similarity = calculate_style_similarity(target_style, expr.style)
|
||||
if similarity >= similarity_threshold and similarity > best_similarity:
|
||||
best_similarity = similarity
|
||||
best_match = expr
|
||||
|
||||
# 检查 style_list 中的每一项
|
||||
style_list = parse_style_list(expr.style_list)
|
||||
for existing_style in style_list:
|
||||
similarity = calculate_style_similarity(target_style, existing_style)
|
||||
if similarity >= similarity_threshold and similarity > best_similarity:
|
||||
best_similarity = similarity
|
||||
best_match = expr
|
||||
|
||||
if best_match:
|
||||
return (best_match, best_similarity)
|
||||
return None
|
||||
|
||||
|
||||
async def compose_situation_text(content_list: List[str], summary_model: LLMRequest) -> str:
|
||||
"""组合 situation 文本,尝试使用 LLM 总结"""
|
||||
sanitized = [c.strip() for c in content_list if c.strip()]
|
||||
if not sanitized:
|
||||
return ""
|
||||
|
||||
if len(sanitized) == 1:
|
||||
return sanitized[0]
|
||||
|
||||
# 尝试使用 LLM 总结
|
||||
prompt = (
|
||||
"请阅读以下多个聊天情境描述,并将它们概括成一句简短的话,"
|
||||
"长度不超过20个字,保留共同特点:\n"
|
||||
f"{chr(10).join(f'- {s}' for s in sanitized[-10:])}\n只输出概括内容。"
|
||||
)
|
||||
|
||||
try:
|
||||
summary, _ = await summary_model.generate_response_async(prompt, temperature=0.2)
|
||||
summary = summary.strip()
|
||||
if summary:
|
||||
return summary
|
||||
except Exception as e:
|
||||
print(f" ⚠️ LLM 总结 situation 失败: {e}")
|
||||
|
||||
# 如果总结失败,返回用 "/" 连接的字符串
|
||||
return "/".join(sanitized)
|
||||
|
||||
|
||||
async def compose_style_text(style_list: List[str], summary_model: LLMRequest) -> str:
|
||||
"""组合 style 文本,尝试使用 LLM 总结"""
|
||||
sanitized = [s.strip() for s in style_list if s.strip()]
|
||||
if not sanitized:
|
||||
return ""
|
||||
|
||||
if len(sanitized) == 1:
|
||||
return sanitized[0]
|
||||
|
||||
# 尝试使用 LLM 总结
|
||||
prompt = (
|
||||
"请阅读以下多个语言风格/表达方式,并将它们概括成一句简短的话,"
|
||||
"长度不超过20个字,保留共同特点:\n"
|
||||
f"{chr(10).join(f'- {s}' for s in sanitized[-10:])}\n只输出概括内容。"
|
||||
)
|
||||
|
||||
try:
|
||||
summary, _ = await summary_model.generate_response_async(prompt, temperature=0.2)
|
||||
|
||||
print(f"Prompt:{prompt} Summary:{summary}")
|
||||
|
||||
summary = summary.strip()
|
||||
if summary:
|
||||
return summary
|
||||
except Exception as e:
|
||||
print(f" ⚠️ LLM 总结 style 失败: {e}")
|
||||
|
||||
# 如果总结失败,返回第一个
|
||||
return sanitized[0]
|
||||
|
||||
|
||||
async def simulate_merge(
|
||||
expressions: List[Expression],
|
||||
similarity_threshold: float = 0.75,
|
||||
use_llm: bool = False,
|
||||
max_samples: int = 10,
|
||||
) -> Dict:
|
||||
"""
|
||||
模拟合并过程
|
||||
|
||||
Args:
|
||||
expressions: Expression 列表(从数据库读出的原始记录)
|
||||
similarity_threshold: style 相似度阈值
|
||||
use_llm: 是否使用 LLM 进行实际总结
|
||||
max_samples: 最多随机抽取的 Expression 数量(为 0 或 None 时表示不限制)
|
||||
|
||||
Returns:
|
||||
包含合并统计信息的字典
|
||||
"""
|
||||
# 如果样本太多,随机抽取一部分进行模拟,避免运行时间过长
|
||||
if max_samples and len(expressions) > max_samples:
|
||||
expressions = random.sample(expressions, max_samples)
|
||||
|
||||
# 按 chat_id 分组
|
||||
expressions_by_chat = defaultdict(list)
|
||||
for expr in expressions:
|
||||
expressions_by_chat[expr.chat_id].append(expr)
|
||||
|
||||
# 初始化 LLM 模型(如果需要)
|
||||
summary_model = None
|
||||
if use_llm:
|
||||
try:
|
||||
summary_model = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils_small,
|
||||
request_type="expression.summary"
|
||||
)
|
||||
print("✅ LLM 模型已初始化,将进行实际总结")
|
||||
except Exception as e:
|
||||
print(f"⚠️ LLM 模型初始化失败: {e},将跳过 LLM 总结")
|
||||
use_llm = False
|
||||
|
||||
merge_stats = {
|
||||
"total_expressions": len(expressions),
|
||||
"total_chats": len(expressions_by_chat),
|
||||
"exact_matches": 0,
|
||||
"similar_matches": 0,
|
||||
"new_records": 0,
|
||||
"merge_details": [],
|
||||
"chat_stats": {},
|
||||
"use_llm": use_llm
|
||||
}
|
||||
|
||||
# 为每个 chat_id 模拟合并
|
||||
for chat_id, chat_expressions in expressions_by_chat.items():
|
||||
chat_name = get_chat_name(chat_id)
|
||||
chat_stat = {
|
||||
"chat_id": chat_id,
|
||||
"chat_name": chat_name,
|
||||
"total": len(chat_expressions),
|
||||
"exact_matches": 0,
|
||||
"similar_matches": 0,
|
||||
"new_records": 0,
|
||||
"merges": []
|
||||
}
|
||||
|
||||
processed_ids = set()
|
||||
|
||||
for expr in chat_expressions:
|
||||
if expr.id in processed_ids:
|
||||
continue
|
||||
|
||||
target_style = expr.style
|
||||
target_situation = expr.situation
|
||||
|
||||
# 第一层:检查完全匹配
|
||||
exact_match = find_exact_style_match(
|
||||
chat_expressions,
|
||||
target_style,
|
||||
chat_id,
|
||||
{expr.id}
|
||||
)
|
||||
|
||||
if exact_match:
|
||||
# 完全匹配(不使用 LLM 总结)
|
||||
# 模拟合并后的 content_list 和 style_list
|
||||
target_content_list = parse_content_list(exact_match.content_list)
|
||||
target_content_list.append(target_situation)
|
||||
|
||||
target_style_list = parse_style_list(exact_match.style_list)
|
||||
if exact_match.style and exact_match.style not in target_style_list:
|
||||
target_style_list.append(exact_match.style)
|
||||
if target_style not in target_style_list:
|
||||
target_style_list.append(target_style)
|
||||
|
||||
merge_info = {
|
||||
"type": "exact",
|
||||
"source_id": expr.id,
|
||||
"target_id": exact_match.id,
|
||||
"source_style": target_style,
|
||||
"target_style": exact_match.style,
|
||||
"source_situation": target_situation,
|
||||
"target_situation": exact_match.situation,
|
||||
"similarity": 1.0,
|
||||
"merged_content_list": target_content_list,
|
||||
"merged_style_list": target_style_list,
|
||||
"merged_situation": exact_match.situation, # 完全匹配时保持原 situation
|
||||
"merged_style": exact_match.style # 完全匹配时保持原 style
|
||||
}
|
||||
chat_stat["exact_matches"] += 1
|
||||
chat_stat["merges"].append(merge_info)
|
||||
merge_stats["exact_matches"] += 1
|
||||
processed_ids.add(expr.id)
|
||||
continue
|
||||
|
||||
# 第二层:检查相似匹配
|
||||
similar_match = find_similar_style_expression(
|
||||
chat_expressions,
|
||||
target_style,
|
||||
chat_id,
|
||||
similarity_threshold,
|
||||
{expr.id}
|
||||
)
|
||||
|
||||
if similar_match:
|
||||
match_expr, similarity = similar_match
|
||||
# 相似匹配(使用 LLM 总结)
|
||||
# 模拟合并后的 content_list 和 style_list
|
||||
target_content_list = parse_content_list(match_expr.content_list)
|
||||
target_content_list.append(target_situation)
|
||||
|
||||
target_style_list = parse_style_list(match_expr.style_list)
|
||||
if match_expr.style and match_expr.style not in target_style_list:
|
||||
target_style_list.append(match_expr.style)
|
||||
if target_style not in target_style_list:
|
||||
target_style_list.append(target_style)
|
||||
|
||||
# 使用 LLM 总结(如果启用)
|
||||
merged_situation = match_expr.situation
|
||||
merged_style = match_expr.style or target_style
|
||||
|
||||
if use_llm and summary_model:
|
||||
try:
|
||||
merged_situation = await compose_situation_text(target_content_list, summary_model)
|
||||
merged_style = await compose_style_text(target_style_list, summary_model)
|
||||
except Exception as e:
|
||||
print(f" ⚠️ 处理记录 {expr.id} 时 LLM 总结失败: {e}")
|
||||
# 如果总结失败,使用 fallback
|
||||
merged_situation = "/".join([c.strip() for c in target_content_list if c.strip()]) or match_expr.situation
|
||||
merged_style = target_style_list[0] if target_style_list else (match_expr.style or target_style)
|
||||
else:
|
||||
# 不使用 LLM 时,使用简单拼接
|
||||
merged_situation = "/".join([c.strip() for c in target_content_list if c.strip()]) or match_expr.situation
|
||||
merged_style = target_style_list[0] if target_style_list else (match_expr.style or target_style)
|
||||
|
||||
merge_info = {
|
||||
"type": "similar",
|
||||
"source_id": expr.id,
|
||||
"target_id": match_expr.id,
|
||||
"source_style": target_style,
|
||||
"target_style": match_expr.style,
|
||||
"source_situation": target_situation,
|
||||
"target_situation": match_expr.situation,
|
||||
"similarity": similarity,
|
||||
"merged_content_list": target_content_list,
|
||||
"merged_style_list": target_style_list,
|
||||
"merged_situation": merged_situation,
|
||||
"merged_style": merged_style,
|
||||
"llm_used": use_llm and summary_model is not None
|
||||
}
|
||||
chat_stat["similar_matches"] += 1
|
||||
chat_stat["merges"].append(merge_info)
|
||||
merge_stats["similar_matches"] += 1
|
||||
processed_ids.add(expr.id)
|
||||
continue
|
||||
|
||||
# 没有匹配,作为新记录
|
||||
chat_stat["new_records"] += 1
|
||||
merge_stats["new_records"] += 1
|
||||
processed_ids.add(expr.id)
|
||||
|
||||
merge_stats["chat_stats"][chat_id] = chat_stat
|
||||
merge_stats["merge_details"].extend(chat_stat["merges"])
|
||||
|
||||
return merge_stats
|
||||
|
||||
|
||||
def print_merge_results(stats: Dict, show_details: bool = True, max_details: int = 50):
|
||||
"""打印合并结果"""
|
||||
print("\n" + "=" * 80)
|
||||
print("Expression 合并模拟结果")
|
||||
print("=" * 80)
|
||||
|
||||
print("\n📊 总体统计:")
|
||||
print(f" 总 Expression 数: {stats['total_expressions']}")
|
||||
print(f" 总聊天数: {stats['total_chats']}")
|
||||
print(f" 完全匹配合并: {stats['exact_matches']}")
|
||||
print(f" 相似匹配合并: {stats['similar_matches']}")
|
||||
print(f" 新记录(无匹配): {stats['new_records']}")
|
||||
if stats.get('use_llm'):
|
||||
print(" LLM 总结: 已启用")
|
||||
else:
|
||||
print(" LLM 总结: 未启用(仅模拟)")
|
||||
|
||||
total_merges = stats['exact_matches'] + stats['similar_matches']
|
||||
if stats['total_expressions'] > 0:
|
||||
merge_ratio = (total_merges / stats['total_expressions']) * 100
|
||||
print(f" 合并比例: {merge_ratio:.1f}%")
|
||||
|
||||
# 按聊天分组显示
|
||||
print("\n📋 按聊天分组统计:")
|
||||
for chat_id, chat_stat in stats['chat_stats'].items():
|
||||
print(f"\n {chat_stat['chat_name']} ({chat_id[:8]}...):")
|
||||
print(f" 总数: {chat_stat['total']}")
|
||||
print(f" 完全匹配: {chat_stat['exact_matches']}")
|
||||
print(f" 相似匹配: {chat_stat['similar_matches']}")
|
||||
print(f" 新记录: {chat_stat['new_records']}")
|
||||
|
||||
# 显示合并详情
|
||||
if show_details and stats['merge_details']:
|
||||
print(f"\n📝 合并详情 (显示前 {min(max_details, len(stats['merge_details']))} 条):")
|
||||
print()
|
||||
|
||||
for idx, merge in enumerate(stats['merge_details'][:max_details], 1):
|
||||
merge_type = "完全匹配" if merge['type'] == 'exact' else f"相似匹配 (相似度: {merge['similarity']:.3f})"
|
||||
print(f" {idx}. {merge_type}")
|
||||
print(f" 源记录 ID: {merge['source_id']}")
|
||||
print(f" 目标记录 ID: {merge['target_id']}")
|
||||
print(f" 源 Style: {merge['source_style'][:50]}")
|
||||
print(f" 目标 Style: {merge['target_style'][:50]}")
|
||||
print(f" 源 Situation: {merge['source_situation'][:50]}")
|
||||
print(f" 目标 Situation: {merge['target_situation'][:50]}")
|
||||
|
||||
# 显示合并后的结果
|
||||
if 'merged_situation' in merge:
|
||||
print(f" → 合并后 Situation: {merge['merged_situation'][:50]}")
|
||||
if 'merged_style' in merge:
|
||||
print(f" → 合并后 Style: {merge['merged_style'][:50]}")
|
||||
if merge.get('llm_used'):
|
||||
print(" → LLM 总结: 已使用")
|
||||
elif merge['type'] == 'similar':
|
||||
print(" → LLM 总结: 未使用(模拟模式)")
|
||||
|
||||
# 显示合并后的列表
|
||||
if 'merged_content_list' in merge and len(merge['merged_content_list']) > 1:
|
||||
print(f" → Content List ({len(merge['merged_content_list'])} 项): {', '.join(merge['merged_content_list'][:3])}")
|
||||
if len(merge['merged_content_list']) > 3:
|
||||
print(f" ... 还有 {len(merge['merged_content_list']) - 3} 项")
|
||||
if 'merged_style_list' in merge and len(merge['merged_style_list']) > 1:
|
||||
print(f" → Style List ({len(merge['merged_style_list'])} 项): {', '.join(merge['merged_style_list'][:3])}")
|
||||
if len(merge['merged_style_list']) > 3:
|
||||
print(f" ... 还有 {len(merge['merged_style_list']) - 3} 项")
|
||||
print()
|
||||
|
||||
if len(stats['merge_details']) > max_details:
|
||||
print(f" ... 还有 {len(stats['merge_details']) - max_details} 条合并记录未显示")
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
parser = argparse.ArgumentParser(description="模拟 Expression 合并过程")
|
||||
parser.add_argument(
|
||||
"--chat-id",
|
||||
type=str,
|
||||
default=None,
|
||||
help="指定要分析的 chat_id(不指定则分析所有)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--similarity-threshold",
|
||||
type=float,
|
||||
default=0.75,
|
||||
help="相似度阈值 (0-1, 默认: 0.75)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-details",
|
||||
action="store_true",
|
||||
help="不显示详细信息,只显示统计"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-details",
|
||||
type=int,
|
||||
default=50,
|
||||
help="最多显示的合并详情数 (默认: 50)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=str,
|
||||
default=None,
|
||||
help="输出文件路径 (默认: 自动生成带时间戳的文件)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use-llm",
|
||||
action="store_true",
|
||||
help="启用 LLM 进行实际总结(默认: 仅模拟,不调用 LLM)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-samples",
|
||||
type=int,
|
||||
default=10,
|
||||
help="最多随机抽取的 Expression 数量 (默认: 10,设置为 0 表示不限制)"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# 验证阈值
|
||||
if not 0 <= args.similarity_threshold <= 1:
|
||||
print("错误: similarity-threshold 必须在 0-1 之间")
|
||||
return
|
||||
|
||||
# 确定输出文件路径
|
||||
if args.output:
|
||||
output_file = args.output
|
||||
else:
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
output_dir = os.path.join(project_root, "data", "temp")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
output_file = os.path.join(output_dir, f"expression_merge_simulation_{timestamp}.txt")
|
||||
|
||||
# 查询 Expression 记录
|
||||
print("正在从数据库加载Expression数据...")
|
||||
try:
|
||||
if args.chat_id:
|
||||
expressions = list(Expression.select().where(Expression.chat_id == args.chat_id))
|
||||
print(f"✅ 成功加载 {len(expressions)} 条Expression记录 (chat_id: {args.chat_id})")
|
||||
else:
|
||||
expressions = list(Expression.select())
|
||||
print(f"✅ 成功加载 {len(expressions)} 条Expression记录")
|
||||
except Exception as e:
|
||||
print(f"❌ 加载数据失败: {e}")
|
||||
return
|
||||
|
||||
if not expressions:
|
||||
print("❌ 数据库中没有找到Expression记录")
|
||||
return
|
||||
|
||||
# 执行合并模拟
|
||||
print(f"\n正在模拟合并过程(相似度阈值: {args.similarity_threshold},最大样本数: {args.max_samples})...")
|
||||
if args.use_llm:
|
||||
print("⚠️ 已启用 LLM 总结,将进行实际的 API 调用")
|
||||
else:
|
||||
print("ℹ️ 未启用 LLM 总结,仅进行模拟(使用 --use-llm 启用实际 LLM 调用)")
|
||||
|
||||
stats = asyncio.run(
|
||||
simulate_merge(
|
||||
expressions,
|
||||
similarity_threshold=args.similarity_threshold,
|
||||
use_llm=args.use_llm,
|
||||
max_samples=args.max_samples,
|
||||
)
|
||||
)
|
||||
|
||||
# 输出结果
|
||||
original_stdout = sys.stdout
|
||||
try:
|
||||
with open(output_file, "w", encoding="utf-8") as f:
|
||||
sys.stdout = f
|
||||
print_merge_results(stats, show_details=not args.no_details, max_details=args.max_details)
|
||||
sys.stdout = original_stdout
|
||||
|
||||
# 同时在控制台输出
|
||||
print_merge_results(stats, show_details=not args.no_details, max_details=args.max_details)
|
||||
|
||||
except Exception as e:
|
||||
sys.stdout = original_stdout
|
||||
print(f"❌ 写入文件失败: {e}")
|
||||
return
|
||||
|
||||
print(f"\n✅ 模拟结果已保存到: {output_file}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
559
scripts/expression_similarity_analysis.py
Normal file
559
scripts/expression_similarity_analysis.py
Normal file
@@ -0,0 +1,559 @@
|
||||
"""
|
||||
分析expression库中situation和style的相似度
|
||||
|
||||
用法:
|
||||
python scripts/expression_similarity_analysis.py
|
||||
或指定阈值:
|
||||
python scripts/expression_similarity_analysis.py --situation-threshold 0.8 --style-threshold 0.7
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import argparse
|
||||
from typing import List, Tuple
|
||||
from collections import defaultdict
|
||||
from difflib import SequenceMatcher
|
||||
from datetime import datetime
|
||||
|
||||
# Add project root to Python path
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.insert(0, project_root)
|
||||
|
||||
# Import after setting up path (required for project imports)
|
||||
from src.common.database.database_model import Expression, ChatStreams # noqa: E402
|
||||
from src.config.config import global_config # noqa: E402
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager # noqa: E402
|
||||
|
||||
|
||||
class TeeOutput:
|
||||
"""同时输出到控制台和文件的类"""
|
||||
def __init__(self, file_path: str):
|
||||
self.file = open(file_path, "w", encoding="utf-8")
|
||||
self.console = sys.stdout
|
||||
|
||||
def write(self, text: str):
|
||||
"""写入文本到控制台和文件"""
|
||||
self.console.write(text)
|
||||
self.file.write(text)
|
||||
self.file.flush() # 立即刷新到文件
|
||||
|
||||
def flush(self):
|
||||
"""刷新输出"""
|
||||
self.console.flush()
|
||||
self.file.flush()
|
||||
|
||||
def close(self):
|
||||
"""关闭文件"""
|
||||
if self.file:
|
||||
self.file.close()
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.close()
|
||||
return False
|
||||
|
||||
|
||||
def _parse_stream_config_to_chat_id(stream_config_str: str) -> str | None:
|
||||
"""
|
||||
解析'platform:id:type'为chat_id,直接复用 ChatManager 的逻辑
|
||||
"""
|
||||
try:
|
||||
parts = stream_config_str.split(":")
|
||||
if len(parts) != 3:
|
||||
return None
|
||||
platform = parts[0]
|
||||
id_str = parts[1]
|
||||
stream_type = parts[2]
|
||||
is_group = stream_type == "group"
|
||||
return get_chat_manager().get_stream_id(platform, str(id_str), is_group=is_group)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def build_chat_id_groups() -> dict[str, set[str]]:
|
||||
"""
|
||||
根据expression_groups配置,构建chat_id到相关chat_id集合的映射
|
||||
|
||||
Returns:
|
||||
dict: {chat_id: set of related chat_ids (including itself)}
|
||||
"""
|
||||
groups = global_config.expression.expression_groups
|
||||
chat_id_groups: dict[str, set[str]] = {}
|
||||
|
||||
# 检查是否存在全局共享组(包含"*"的组)
|
||||
global_group_exists = any("*" in group for group in groups)
|
||||
|
||||
if global_group_exists:
|
||||
# 如果存在全局共享组,收集所有配置中的chat_id
|
||||
all_chat_ids = set()
|
||||
for group in groups:
|
||||
for stream_config_str in group:
|
||||
if stream_config_str == "*":
|
||||
continue
|
||||
if chat_id_candidate := _parse_stream_config_to_chat_id(stream_config_str):
|
||||
all_chat_ids.add(chat_id_candidate)
|
||||
|
||||
# 所有chat_id都互相相关
|
||||
for chat_id in all_chat_ids:
|
||||
chat_id_groups[chat_id] = all_chat_ids.copy()
|
||||
else:
|
||||
# 处理普通组
|
||||
for group in groups:
|
||||
group_chat_ids = set()
|
||||
for stream_config_str in group:
|
||||
if chat_id_candidate := _parse_stream_config_to_chat_id(stream_config_str):
|
||||
group_chat_ids.add(chat_id_candidate)
|
||||
|
||||
# 组内的所有chat_id都互相相关
|
||||
for chat_id in group_chat_ids:
|
||||
if chat_id not in chat_id_groups:
|
||||
chat_id_groups[chat_id] = set()
|
||||
chat_id_groups[chat_id].update(group_chat_ids)
|
||||
|
||||
# 确保每个chat_id至少包含自身
|
||||
for chat_id in chat_id_groups:
|
||||
chat_id_groups[chat_id].add(chat_id)
|
||||
|
||||
return chat_id_groups
|
||||
|
||||
|
||||
def are_chat_ids_related(chat_id1: str, chat_id2: str, chat_id_groups: dict[str, set[str]]) -> bool:
|
||||
"""
|
||||
判断两个chat_id是否相关(相同或同组)
|
||||
|
||||
Args:
|
||||
chat_id1: 第一个chat_id
|
||||
chat_id2: 第二个chat_id
|
||||
chat_id_groups: chat_id到相关chat_id集合的映射
|
||||
|
||||
Returns:
|
||||
bool: 如果两个chat_id相同或同组,返回True
|
||||
"""
|
||||
if chat_id1 == chat_id2:
|
||||
return True
|
||||
|
||||
# 如果chat_id1在映射中,检查chat_id2是否在其相关集合中
|
||||
if chat_id1 in chat_id_groups:
|
||||
return chat_id2 in chat_id_groups[chat_id1]
|
||||
|
||||
# 如果chat_id1不在映射中,说明它不在任何组中,只与自己相关
|
||||
return False
|
||||
|
||||
|
||||
def get_chat_name(chat_id: str) -> str:
|
||||
"""根据 chat_id 获取聊天名称"""
|
||||
try:
|
||||
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
|
||||
if chat_stream is None:
|
||||
return f"未知聊天 ({chat_id[:8]}...)"
|
||||
|
||||
if chat_stream.group_name:
|
||||
return f"{chat_stream.group_name}"
|
||||
elif chat_stream.user_nickname:
|
||||
return f"{chat_stream.user_nickname}的私聊"
|
||||
else:
|
||||
return f"未知聊天 ({chat_id[:8]}...)"
|
||||
except Exception:
|
||||
return f"查询失败 ({chat_id[:8]}...)"
|
||||
|
||||
|
||||
def text_similarity(text1: str, text2: str) -> float:
|
||||
"""
|
||||
计算两个文本的相似度
|
||||
使用SequenceMatcher计算相似度,返回0-1之间的值
|
||||
在计算前会移除"使用"和"句式"这两个词
|
||||
"""
|
||||
if not text1 or not text2:
|
||||
return 0.0
|
||||
|
||||
# 移除"使用"和"句式"这两个词
|
||||
def remove_ignored_words(text: str) -> str:
|
||||
"""移除需要忽略的词"""
|
||||
text = text.replace("使用", "")
|
||||
text = text.replace("句式", "")
|
||||
return text.strip()
|
||||
|
||||
cleaned_text1 = remove_ignored_words(text1)
|
||||
cleaned_text2 = remove_ignored_words(text2)
|
||||
|
||||
# 如果清理后文本为空,返回0
|
||||
if not cleaned_text1 or not cleaned_text2:
|
||||
return 0.0
|
||||
|
||||
return SequenceMatcher(None, cleaned_text1, cleaned_text2).ratio()
|
||||
|
||||
|
||||
def find_similar_pairs(
|
||||
expressions: List[Expression],
|
||||
field_name: str,
|
||||
threshold: float,
|
||||
max_pairs: int = None
|
||||
) -> List[Tuple[int, int, float, str, str]]:
|
||||
"""
|
||||
找出相似的expression对
|
||||
|
||||
Args:
|
||||
expressions: Expression对象列表
|
||||
field_name: 要比较的字段名 ('situation' 或 'style')
|
||||
threshold: 相似度阈值 (0-1)
|
||||
max_pairs: 最多返回的对数,None表示返回所有
|
||||
|
||||
Returns:
|
||||
List of (index1, index2, similarity, text1, text2) tuples
|
||||
"""
|
||||
similar_pairs = []
|
||||
n = len(expressions)
|
||||
|
||||
print(f"正在分析 {field_name} 字段的相似度...")
|
||||
print(f"总共需要比较 {n * (n - 1) // 2} 对...")
|
||||
|
||||
for i in range(n):
|
||||
if (i + 1) % 100 == 0:
|
||||
print(f" 已处理 {i + 1}/{n} 个项目...")
|
||||
|
||||
expr1 = expressions[i]
|
||||
text1 = getattr(expr1, field_name, "")
|
||||
|
||||
for j in range(i + 1, n):
|
||||
expr2 = expressions[j]
|
||||
text2 = getattr(expr2, field_name, "")
|
||||
|
||||
similarity = text_similarity(text1, text2)
|
||||
|
||||
if similarity >= threshold:
|
||||
similar_pairs.append((i, j, similarity, text1, text2))
|
||||
|
||||
# 按相似度降序排序
|
||||
similar_pairs.sort(key=lambda x: x[2], reverse=True)
|
||||
|
||||
if max_pairs:
|
||||
similar_pairs = similar_pairs[:max_pairs]
|
||||
|
||||
return similar_pairs
|
||||
|
||||
|
||||
def group_similar_items(
|
||||
expressions: List[Expression],
|
||||
field_name: str,
|
||||
threshold: float,
|
||||
chat_id_groups: dict[str, set[str]]
|
||||
) -> List[List[int]]:
|
||||
"""
|
||||
将相似的expression分组(仅比较相同chat_id或同组的项目)
|
||||
|
||||
Args:
|
||||
expressions: Expression对象列表
|
||||
field_name: 要比较的字段名 ('situation' 或 'style')
|
||||
threshold: 相似度阈值 (0-1)
|
||||
chat_id_groups: chat_id到相关chat_id集合的映射
|
||||
|
||||
Returns:
|
||||
List of groups, each group is a list of indices
|
||||
"""
|
||||
n = len(expressions)
|
||||
# 使用并查集的思想来分组
|
||||
parent = list(range(n))
|
||||
|
||||
def find(x):
|
||||
if parent[x] != x:
|
||||
parent[x] = find(parent[x])
|
||||
return parent[x]
|
||||
|
||||
def union(x, y):
|
||||
px, py = find(x), find(y)
|
||||
if px != py:
|
||||
parent[px] = py
|
||||
|
||||
print(f"正在对 {field_name} 字段进行分组(仅比较相同chat_id或同组的项目)...")
|
||||
|
||||
# 统计需要比较的对数
|
||||
total_pairs = 0
|
||||
for i in range(n):
|
||||
for j in range(i + 1, n):
|
||||
if are_chat_ids_related(expressions[i].chat_id, expressions[j].chat_id, chat_id_groups):
|
||||
total_pairs += 1
|
||||
|
||||
print(f"总共需要比较 {total_pairs} 对(已过滤不同chat_id且不同组的项目)...")
|
||||
|
||||
compared_pairs = 0
|
||||
for i in range(n):
|
||||
if (i + 1) % 100 == 0:
|
||||
print(f" 已处理 {i + 1}/{n} 个项目...")
|
||||
|
||||
expr1 = expressions[i]
|
||||
text1 = getattr(expr1, field_name, "")
|
||||
|
||||
for j in range(i + 1, n):
|
||||
expr2 = expressions[j]
|
||||
|
||||
# 只比较相同chat_id或同组的项目
|
||||
if not are_chat_ids_related(expr1.chat_id, expr2.chat_id, chat_id_groups):
|
||||
continue
|
||||
|
||||
compared_pairs += 1
|
||||
text2 = getattr(expr2, field_name, "")
|
||||
|
||||
similarity = text_similarity(text1, text2)
|
||||
|
||||
if similarity >= threshold:
|
||||
union(i, j)
|
||||
|
||||
# 收集分组
|
||||
groups = defaultdict(list)
|
||||
for i in range(n):
|
||||
root = find(i)
|
||||
groups[root].append(i)
|
||||
|
||||
# 只返回包含多个项目的组
|
||||
result = [group for group in groups.values() if len(group) > 1]
|
||||
result.sort(key=len, reverse=True)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def print_similarity_analysis(
|
||||
expressions: List[Expression],
|
||||
field_name: str,
|
||||
threshold: float,
|
||||
chat_id_groups: dict[str, set[str]],
|
||||
show_details: bool = True,
|
||||
max_groups: int = 20
|
||||
):
|
||||
"""打印相似度分析结果"""
|
||||
print("\n" + "=" * 80)
|
||||
print(f"{field_name.upper()} 相似度分析 (阈值: {threshold})")
|
||||
print("=" * 80)
|
||||
|
||||
# 分组分析
|
||||
groups = group_similar_items(expressions, field_name, threshold, chat_id_groups)
|
||||
|
||||
total_items = len(expressions)
|
||||
similar_items_count = sum(len(group) for group in groups)
|
||||
unique_groups = len(groups)
|
||||
|
||||
print("\n📊 统计信息:")
|
||||
print(f" 总项目数: {total_items}")
|
||||
print(f" 相似项目数: {similar_items_count} ({similar_items_count / total_items * 100:.1f}%)")
|
||||
print(f" 相似组数: {unique_groups}")
|
||||
print(f" 平均每组项目数: {similar_items_count / unique_groups:.1f}" if unique_groups > 0 else " 平均每组项目数: 0")
|
||||
|
||||
if not groups:
|
||||
print(f"\n未找到相似度 >= {threshold} 的项目组")
|
||||
return
|
||||
|
||||
print(f"\n📋 相似组详情 (显示前 {min(max_groups, len(groups))} 组):")
|
||||
print()
|
||||
|
||||
for group_idx, group in enumerate(groups[:max_groups], 1):
|
||||
print(f"组 {group_idx} (共 {len(group)} 个项目):")
|
||||
|
||||
if show_details:
|
||||
# 显示组内所有项目的详细信息
|
||||
for idx in group:
|
||||
expr = expressions[idx]
|
||||
text = getattr(expr, field_name, "")
|
||||
chat_name = get_chat_name(expr.chat_id)
|
||||
|
||||
# 截断过长的文本
|
||||
display_text = text[:60] + "..." if len(text) > 60 else text
|
||||
|
||||
print(f" [{expr.id}] {display_text}")
|
||||
print(f" 聊天: {chat_name}, Count: {expr.count}")
|
||||
|
||||
# 计算组内平均相似度
|
||||
if len(group) > 1:
|
||||
similarities = []
|
||||
above_threshold_pairs = [] # 存储满足阈值的相似对
|
||||
above_threshold_count = 0
|
||||
for i in range(len(group)):
|
||||
for j in range(i + 1, len(group)):
|
||||
text1 = getattr(expressions[group[i]], field_name, "")
|
||||
text2 = getattr(expressions[group[j]], field_name, "")
|
||||
sim = text_similarity(text1, text2)
|
||||
similarities.append(sim)
|
||||
if sim >= threshold:
|
||||
above_threshold_count += 1
|
||||
# 存储满足阈值的对的信息
|
||||
expr1 = expressions[group[i]]
|
||||
expr2 = expressions[group[j]]
|
||||
display_text1 = text1[:40] + "..." if len(text1) > 40 else text1
|
||||
display_text2 = text2[:40] + "..." if len(text2) > 40 else text2
|
||||
above_threshold_pairs.append((
|
||||
expr1.id, display_text1,
|
||||
expr2.id, display_text2,
|
||||
sim
|
||||
))
|
||||
|
||||
if similarities:
|
||||
avg_sim = sum(similarities) / len(similarities)
|
||||
min_sim = min(similarities)
|
||||
max_sim = max(similarities)
|
||||
above_threshold_ratio = above_threshold_count / len(similarities) * 100
|
||||
print(f" 平均相似度: {avg_sim:.3f} (范围: {min_sim:.3f} - {max_sim:.3f})")
|
||||
print(f" 满足阈值({threshold})的比例: {above_threshold_ratio:.1f}% ({above_threshold_count}/{len(similarities)})")
|
||||
|
||||
# 显示满足阈值的相似对(这些是直接连接,导致它们被分到一组)
|
||||
if above_threshold_pairs:
|
||||
print(" ⚠️ 直接相似的对 (这些对导致它们被分到一组):")
|
||||
# 按相似度降序排序
|
||||
above_threshold_pairs.sort(key=lambda x: x[4], reverse=True)
|
||||
for idx1, text1, idx2, text2, sim in above_threshold_pairs[:10]: # 最多显示10对
|
||||
print(f" [{idx1}] ↔ [{idx2}]: {sim:.3f}")
|
||||
print(f" \"{text1}\" ↔ \"{text2}\"")
|
||||
if len(above_threshold_pairs) > 10:
|
||||
print(f" ... 还有 {len(above_threshold_pairs) - 10} 对满足阈值")
|
||||
else:
|
||||
print(f" ⚠️ 警告: 组内没有任何对满足阈值({threshold:.2f}),可能是通过传递性连接")
|
||||
else:
|
||||
# 只显示组内第一个项目作为示例
|
||||
expr = expressions[group[0]]
|
||||
text = getattr(expr, field_name, "")
|
||||
display_text = text[:60] + "..." if len(text) > 60 else text
|
||||
print(f" 示例: {display_text}")
|
||||
print(f" ... 还有 {len(group) - 1} 个相似项目")
|
||||
|
||||
print()
|
||||
|
||||
if len(groups) > max_groups:
|
||||
print(f"... 还有 {len(groups) - max_groups} 组未显示")
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
parser = argparse.ArgumentParser(description="分析expression库中situation和style的相似度")
|
||||
parser.add_argument(
|
||||
"--situation-threshold",
|
||||
type=float,
|
||||
default=0.7,
|
||||
help="situation相似度阈值 (0-1, 默认: 0.7)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--style-threshold",
|
||||
type=float,
|
||||
default=0.7,
|
||||
help="style相似度阈值 (0-1, 默认: 0.7)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-details",
|
||||
action="store_true",
|
||||
help="不显示详细信息,只显示统计"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-groups",
|
||||
type=int,
|
||||
default=20,
|
||||
help="最多显示的组数 (默认: 20)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=str,
|
||||
default=None,
|
||||
help="输出文件路径 (默认: 自动生成带时间戳的文件)"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# 验证阈值
|
||||
if not 0 <= args.situation_threshold <= 1:
|
||||
print("错误: situation-threshold 必须在 0-1 之间")
|
||||
return
|
||||
if not 0 <= args.style_threshold <= 1:
|
||||
print("错误: style-threshold 必须在 0-1 之间")
|
||||
return
|
||||
|
||||
# 确定输出文件路径
|
||||
if args.output:
|
||||
output_file = args.output
|
||||
else:
|
||||
# 自动生成带时间戳的输出文件
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
output_dir = os.path.join(project_root, "data", "temp")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
output_file = os.path.join(output_dir, f"expression_similarity_analysis_{timestamp}.txt")
|
||||
|
||||
# 使用TeeOutput同时输出到控制台和文件
|
||||
with TeeOutput(output_file) as tee:
|
||||
# 临时替换sys.stdout
|
||||
original_stdout = sys.stdout
|
||||
sys.stdout = tee
|
||||
|
||||
try:
|
||||
print("=" * 80)
|
||||
print("Expression 相似度分析工具")
|
||||
print("=" * 80)
|
||||
print(f"输出文件: {output_file}")
|
||||
print()
|
||||
|
||||
_run_analysis(args)
|
||||
|
||||
finally:
|
||||
# 恢复原始stdout
|
||||
sys.stdout = original_stdout
|
||||
|
||||
print(f"\n✅ 分析结果已保存到: {output_file}")
|
||||
|
||||
|
||||
def _run_analysis(args):
|
||||
"""执行分析的主逻辑"""
|
||||
|
||||
# 查询所有Expression记录
|
||||
print("正在从数据库加载Expression数据...")
|
||||
try:
|
||||
expressions = list(Expression.select())
|
||||
except Exception as e:
|
||||
print(f"❌ 加载数据失败: {e}")
|
||||
return
|
||||
|
||||
if not expressions:
|
||||
print("❌ 数据库中没有找到Expression记录")
|
||||
return
|
||||
|
||||
print(f"✅ 成功加载 {len(expressions)} 条Expression记录")
|
||||
print()
|
||||
|
||||
# 构建chat_id分组映射
|
||||
print("正在构建chat_id分组映射(根据expression_groups配置)...")
|
||||
try:
|
||||
chat_id_groups = build_chat_id_groups()
|
||||
print(f"✅ 成功构建 {len(chat_id_groups)} 个chat_id的分组映射")
|
||||
if chat_id_groups:
|
||||
# 统计分组信息
|
||||
total_related = sum(len(related) for related in chat_id_groups.values())
|
||||
avg_related = total_related / len(chat_id_groups)
|
||||
print(f" 平均每个chat_id与 {avg_related:.1f} 个chat_id相关(包括自身)")
|
||||
print()
|
||||
except Exception as e:
|
||||
print(f"⚠️ 构建chat_id分组映射失败: {e}")
|
||||
print(" 将使用默认行为:只比较相同chat_id的项目")
|
||||
chat_id_groups = {}
|
||||
|
||||
# 分析situation相似度
|
||||
print_similarity_analysis(
|
||||
expressions,
|
||||
"situation",
|
||||
args.situation_threshold,
|
||||
chat_id_groups,
|
||||
show_details=not args.no_details,
|
||||
max_groups=args.max_groups
|
||||
)
|
||||
|
||||
# 分析style相似度
|
||||
print_similarity_analysis(
|
||||
expressions,
|
||||
"style",
|
||||
args.style_threshold,
|
||||
chat_id_groups,
|
||||
show_details=not args.no_details,
|
||||
max_groups=args.max_groups
|
||||
)
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("分析完成!")
|
||||
print("=" * 80)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -4,10 +4,12 @@
|
||||
# print("未找到quick_algo库,无法使用quick_algo算法")
|
||||
# print("请安装quick_algo库 - 在lib.quick_algo中,执行命令:python setup.py build_ext --inplace")
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
import os
|
||||
import asyncio
|
||||
from time import sleep
|
||||
from typing import Optional
|
||||
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
from src.chat.knowledge.embedding_store import EmbeddingManager
|
||||
@@ -71,7 +73,12 @@ def hash_deduplicate(
|
||||
return new_raw_paragraphs, new_triple_list_data
|
||||
|
||||
|
||||
def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, kg_manager: KGManager) -> bool:
|
||||
def handle_import_openie(
|
||||
openie_data: OpenIE,
|
||||
embed_manager: EmbeddingManager,
|
||||
kg_manager: KGManager,
|
||||
non_interactive: bool = False,
|
||||
) -> bool:
|
||||
# sourcery skip: extract-method
|
||||
# 从OpenIE数据中提取段落原文与三元组列表
|
||||
# 索引的段落原文
|
||||
@@ -124,8 +131,13 @@ def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, k
|
||||
logger.info("所有数据均完整,没有发现缺失字段。")
|
||||
return False
|
||||
# 新增:提示用户是否删除非法文段继续导入
|
||||
# 将print移到所有logger.error之后,确保不会被冲掉
|
||||
# 在非交互模式下,不再询问用户,而是直接报错终止
|
||||
logger.info(f"\n检测到非法文段,共{len(missing_idxs)}条。")
|
||||
if non_interactive:
|
||||
logger.error(
|
||||
"检测到非法文段且当前处于非交互模式,无法询问是否删除非法文段,导入终止。"
|
||||
)
|
||||
sys.exit(1)
|
||||
logger.info("\n是否删除所有非法文段后继续导入?(y/n): ", end="")
|
||||
user_choice = input().strip().lower()
|
||||
if user_choice != "y":
|
||||
@@ -174,20 +186,25 @@ def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, k
|
||||
return True
|
||||
|
||||
|
||||
async def main_async(): # sourcery skip: dict-comprehension
|
||||
async def main_async(non_interactive: bool = False) -> bool: # sourcery skip: dict-comprehension
|
||||
# 新增确认提示
|
||||
print("=== 重要操作确认 ===")
|
||||
print("OpenIE导入时会大量发送请求,可能会撞到请求速度上限,请注意选用的模型")
|
||||
print("同之前样例:在本地模型下,在70分钟内我们发送了约8万条请求,在网络允许下,速度会更快")
|
||||
print("推荐使用硅基流动的Pro/BAAI/bge-m3")
|
||||
print("每百万Token费用为0.7元")
|
||||
print("知识导入时,会消耗大量系统资源,建议在较好配置电脑上运行")
|
||||
print("同上样例,导入时10700K几乎跑满,14900HX占用80%,峰值内存占用约3G")
|
||||
confirm = input("确认继续执行?(y/n): ").strip().lower()
|
||||
if confirm != "y":
|
||||
logger.info("用户取消操作")
|
||||
print("操作已取消")
|
||||
sys.exit(1)
|
||||
if non_interactive:
|
||||
logger.warning(
|
||||
"当前处于非交互模式,将跳过导入开销确认提示,直接开始执行 OpenIE 导入。"
|
||||
)
|
||||
else:
|
||||
print("=== 重要操作确认 ===")
|
||||
print("OpenIE导入时会大量发送请求,可能会撞到请求速度上限,请注意选用的模型")
|
||||
print("同之前样例:在本地模型下,在70分钟内我们发送了约8万条请求,在网络允许下,速度会更快")
|
||||
print("推荐使用硅基流动的Pro/BAAI/bge-m3")
|
||||
print("每百万Token费用为0.7元")
|
||||
print("知识导入时,会消耗大量系统资源,建议在较好配置电脑上运行")
|
||||
print("同上样例,导入时10700K几乎跑满,14900HX占用80%,峰值内存占用约3G")
|
||||
confirm = input("确认继续执行?(y/n): ").strip().lower()
|
||||
if confirm != "y":
|
||||
logger.info("用户取消操作")
|
||||
print("操作已取消")
|
||||
sys.exit(1)
|
||||
print("\n" + "=" * 40 + "\n")
|
||||
ensure_openie_dir() # 确保OpenIE目录存在
|
||||
logger.info("----开始导入openie数据----\n")
|
||||
@@ -235,14 +252,27 @@ async def main_async(): # sourcery skip: dict-comprehension
|
||||
except Exception as e:
|
||||
logger.error(f"导入OpenIE数据文件时发生错误:{e}")
|
||||
return False
|
||||
if handle_import_openie(openie_data, embed_manager, kg_manager) is False:
|
||||
if handle_import_openie(openie_data, embed_manager, kg_manager, non_interactive=non_interactive) is False:
|
||||
logger.error("处理OpenIE数据时发生错误")
|
||||
return False
|
||||
return None
|
||||
return True
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数 - 设置新的事件循环并运行异步主函数"""
|
||||
def main(argv: Optional[list[str]] = None) -> None:
|
||||
"""主函数 - 解析参数并运行异步主流程。"""
|
||||
parser = argparse.ArgumentParser(
|
||||
description=(
|
||||
"OpenIE 导入脚本:读取 data/openie 中的 OpenIE JSON 批次,"
|
||||
"将其导入到 LPMM 的向量库与知识图中。"
|
||||
)
|
||||
)
|
||||
parser.add_argument(
|
||||
"--non-interactive",
|
||||
action="store_true",
|
||||
help="非交互模式:跳过导入确认提示以及非法文段删除询问,遇到非法文段时直接报错退出。",
|
||||
)
|
||||
args = parser.parse_args(argv)
|
||||
|
||||
# 检查是否有现有的事件循环
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
@@ -255,13 +285,22 @@ def main():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
ok: bool = False
|
||||
try:
|
||||
# 在新的事件循环中运行异步主函数
|
||||
loop.run_until_complete(main_async())
|
||||
ok = loop.run_until_complete(main_async(non_interactive=args.non_interactive))
|
||||
print(
|
||||
"\n[NOTICE] OpenIE 导入脚本执行完毕。如主程序(聊天 / WebUI)已在运行,"
|
||||
"请重启主程序,或在主程序内部调用一次 lpmm_start_up() 以应用最新 LPMM 知识库。"
|
||||
)
|
||||
print("[NOTICE] 如果不清楚 lpmm_start_up 是什么,直接重启主程序即可。")
|
||||
finally:
|
||||
# 确保事件循环被正确关闭
|
||||
if not loop.is_closed():
|
||||
loop.close()
|
||||
if not ok:
|
||||
# 统一错误码,方便在非交互场景下检测失败
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import signal
|
||||
@@ -5,6 +6,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from threading import Lock, Event
|
||||
import sys
|
||||
import datetime
|
||||
from typing import Optional
|
||||
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
# 添加项目根目录到 sys.path
|
||||
@@ -115,22 +117,34 @@ def signal_handler(_signum, _frame):
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
def main(): # sourcery skip: comprehension-to-generator, extract-method
|
||||
def _run(non_interactive: bool = False) -> None: # sourcery skip: comprehension-to-generator, extract-method
|
||||
# 设置信号处理器
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
ensure_dirs() # 确保目录存在
|
||||
# 新增用户确认提示
|
||||
print("=== 重要操作确认,请认真阅读以下内容哦 ===")
|
||||
print("实体提取操作将会花费较多api余额和时间,建议在空闲时段执行。")
|
||||
print("举例:600万字全剧情,提取选用deepseek v3 0324,消耗约40元,约3小时。")
|
||||
print("建议使用硅基流动的非Pro模型")
|
||||
print("或者使用可以用赠金抵扣的Pro模型")
|
||||
print("请确保账户余额充足,并且在执行前确认无误。")
|
||||
confirm = input("确认继续执行?(y/n): ").strip().lower()
|
||||
if confirm != "y":
|
||||
logger.info("用户取消操作")
|
||||
print("操作已取消")
|
||||
sys.exit(1)
|
||||
if non_interactive:
|
||||
logger.warning(
|
||||
"当前处于非交互模式,将跳过费用与时长确认提示,直接开始进行实体提取操作。"
|
||||
)
|
||||
else:
|
||||
print("=== 重要操作确认,请认真阅读以下内容哦 ===")
|
||||
print("实体提取操作将会花费较多api余额和时间,建议在空闲时段执行。")
|
||||
print("举例:600万字全剧情,提取选用deepseek v3 0324,消耗约40元,约3小时。")
|
||||
print("建议使用硅基流动的非Pro模型")
|
||||
print("或者使用可以用赠金抵扣的Pro模型")
|
||||
print("请确保账户余额充足,并且在执行前确认无误。")
|
||||
confirm = input("确认继续执行?(y/n): ").strip().lower()
|
||||
if confirm != "y":
|
||||
logger.info("用户取消操作")
|
||||
print("操作已取消")
|
||||
sys.exit(1)
|
||||
|
||||
# 友好提示:说明“网络错误(可重试)”日志属于正常自动重试行为,避免用户误以为任务失败
|
||||
print(
|
||||
"\n提示:在提取过程中,如果看到模型出现“网络错误(可重试)”等日志,"
|
||||
"表示系统正在自动重试请求,一般不会影响整体导入结果,请耐心等待即可。\n"
|
||||
)
|
||||
|
||||
print("\n" + "=" * 40 + "\n")
|
||||
ensure_dirs() # 确保目录存在
|
||||
logger.info("--------进行信息提取--------\n")
|
||||
@@ -215,5 +229,22 @@ def main(): # sourcery skip: comprehension-to-generator, extract-method
|
||||
logger.info(f"提取失败的文段SHA256:{failed_sha256}")
|
||||
|
||||
|
||||
def main(argv: Optional[list[str]] = None) -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description=(
|
||||
"LPMM 信息提取脚本:从 data/lpmm_raw_data/*.txt 中读取原始段落,"
|
||||
"调用 LLM 提取实体和三元组,并生成 OpenIE JSON 批次文件。"
|
||||
)
|
||||
)
|
||||
parser.add_argument(
|
||||
"--non-interactive",
|
||||
action="store_true",
|
||||
help="非交互模式:跳过费用确认提示,直接开始执行;适用于 CI / 定时任务等场景。",
|
||||
)
|
||||
args = parser.parse_args(argv)
|
||||
|
||||
_run(non_interactive=args.non_interactive)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
132
scripts/inspect_lpmm_batch.py
Normal file
132
scripts/inspect_lpmm_batch.py
Normal file
@@ -0,0 +1,132 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
|
||||
# 确保能导入 src.*
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
from src.chat.knowledge.utils.hash import get_sha256
|
||||
from src.chat.knowledge.embedding_store import EmbeddingManager
|
||||
from src.chat.knowledge.kg_manager import KGManager
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("inspect_lpmm_batch")
|
||||
|
||||
|
||||
def load_openie_hashes(path: Path) -> Tuple[List[str], List[str], List[str]]:
|
||||
"""从 OpenIE JSON 中提取段落 / 实体 / 关系的哈希
|
||||
|
||||
注意:实体既包括 extracted_entities 中的条目,也包括三元组中的主语/宾语,
|
||||
以与 KG 构图逻辑保持一致。
|
||||
"""
|
||||
with path.open("r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
pg_hashes: List[str] = []
|
||||
ent_hashes: List[str] = []
|
||||
rel_hashes: List[str] = []
|
||||
|
||||
for doc in data.get("docs", []):
|
||||
if not isinstance(doc, dict):
|
||||
continue
|
||||
idx = doc.get("idx")
|
||||
if isinstance(idx, str) and idx.strip():
|
||||
pg_hashes.append(idx.strip())
|
||||
|
||||
ents = doc.get("extracted_entities", [])
|
||||
if isinstance(ents, list):
|
||||
for e in ents:
|
||||
if isinstance(e, str):
|
||||
ent_hashes.append(get_sha256(e))
|
||||
|
||||
triples = doc.get("extracted_triples", [])
|
||||
if isinstance(triples, list):
|
||||
for t in triples:
|
||||
if isinstance(t, list) and len(t) == 3:
|
||||
# 主语/宾语作为实体参与构图
|
||||
subj, _, obj = t
|
||||
if isinstance(subj, str):
|
||||
ent_hashes.append(get_sha256(subj))
|
||||
if isinstance(obj, str):
|
||||
ent_hashes.append(get_sha256(obj))
|
||||
rel_hashes.append(get_sha256(str(tuple(t))))
|
||||
|
||||
# 去重但保留顺序
|
||||
def unique(seq: List[str]) -> List[str]:
|
||||
seen = set()
|
||||
return [x for x in seq if not (x in seen or seen.add(x))]
|
||||
|
||||
return unique(pg_hashes), unique(ent_hashes), unique(rel_hashes)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="检查指定 OpenIE 文件对应批次在当前向量库与 KG 中的存在情况(用于验证删除效果)。"
|
||||
)
|
||||
parser.add_argument("--openie-file", required=True, help="OpenIE 输出 JSON 文件路径")
|
||||
args = parser.parse_args()
|
||||
|
||||
openie_path = Path(args.openie_file)
|
||||
if not openie_path.exists():
|
||||
logger.error(f"OpenIE 文件不存在: {openie_path}")
|
||||
sys.exit(1)
|
||||
|
||||
pg_hashes, ent_hashes, rel_hashes = load_openie_hashes(openie_path)
|
||||
logger.info(
|
||||
f"从 {openie_path.name} 解析到 段落 {len(pg_hashes)} 条,实体 {len(ent_hashes)} 个,关系 {len(rel_hashes)} 条"
|
||||
)
|
||||
|
||||
# 加载当前嵌入与 KG
|
||||
em = EmbeddingManager()
|
||||
kg = KGManager()
|
||||
try:
|
||||
em.load_from_file()
|
||||
kg.load_from_file()
|
||||
except Exception as e:
|
||||
logger.error(f"加载当前知识库失败: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
graph_nodes = set(kg.graph.get_node_list())
|
||||
|
||||
# 检查段落
|
||||
pg_keys = [f"paragraph-{h}" for h in pg_hashes]
|
||||
pg_in_vec = sum(1 for k in pg_keys if k in em.paragraphs_embedding_store.store)
|
||||
pg_in_kg = sum(1 for k in pg_keys if k in graph_nodes)
|
||||
|
||||
# 检查实体
|
||||
ent_keys = [f"entity-{h}" for h in ent_hashes]
|
||||
ent_in_vec = sum(1 for k in ent_keys if k in em.entities_embedding_store.store)
|
||||
ent_in_kg = sum(1 for k in ent_keys if k in graph_nodes)
|
||||
|
||||
# 检查关系(只针对向量库)
|
||||
rel_keys = [f"relation-{h}" for h in rel_hashes]
|
||||
rel_in_vec = sum(1 for k in rel_keys if k in em.relation_embedding_store.store)
|
||||
|
||||
print("==== 批次存在情况(删除前/后对比用) ====")
|
||||
print(f"段落: 总计 {len(pg_keys)}, 向量库剩余 {pg_in_vec}, KG 中剩余 {pg_in_kg}")
|
||||
print(f"实体: 总计 {len(ent_keys)}, 向量库剩余 {ent_in_vec}, KG 中剩余 {ent_in_kg}")
|
||||
print(f"关系: 总计 {len(rel_keys)}, 向量库剩余 {rel_in_vec}")
|
||||
|
||||
# 打印少量仍存在的样例,便于检查内容是否正常
|
||||
sample_pg = [k for k in pg_keys if k in graph_nodes][:3]
|
||||
if sample_pg:
|
||||
print("\n仍在 KG 中的段落节点示例:")
|
||||
for k in sample_pg:
|
||||
nd = kg.graph[k]
|
||||
content = nd["content"] if "content" in nd else k
|
||||
print(f"- {k}: {content[:80]}")
|
||||
|
||||
sample_ent = [k for k in ent_keys if k in graph_nodes][:3]
|
||||
if sample_ent:
|
||||
print("\n仍在 KG 中的实体节点示例:")
|
||||
for k in sample_ent:
|
||||
nd = kg.graph[k]
|
||||
content = nd["content"] if "content" in nd else k
|
||||
print(f"- {k}: {content[:80]}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
71
scripts/inspect_lpmm_global.py
Normal file
71
scripts/inspect_lpmm_global.py
Normal file
@@ -0,0 +1,71 @@
|
||||
import os
|
||||
import sys
|
||||
from typing import Set
|
||||
|
||||
# 保证可以导入 src.*
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
from src.chat.knowledge.embedding_store import EmbeddingManager
|
||||
from src.chat.knowledge.kg_manager import KGManager
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("inspect_lpmm_global")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""检查当前整库(所有批次)的向量与 KG 状态,用于观察删除对剩余数据的影响。"""
|
||||
em = EmbeddingManager()
|
||||
kg = KGManager()
|
||||
|
||||
try:
|
||||
em.load_from_file()
|
||||
kg.load_from_file()
|
||||
except Exception as e:
|
||||
logger.error(f"加载当前知识库失败: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
# 向量库统计
|
||||
para_cnt = len(em.paragraphs_embedding_store.store)
|
||||
ent_cnt_vec = len(em.entities_embedding_store.store)
|
||||
rel_cnt_vec = len(em.relation_embedding_store.store)
|
||||
|
||||
# KG 统计
|
||||
nodes = kg.graph.get_node_list()
|
||||
edges = kg.graph.get_edge_list()
|
||||
node_set: Set[str] = set(nodes)
|
||||
|
||||
para_nodes = [n for n in nodes if n.startswith("paragraph-")]
|
||||
ent_nodes = [n for n in nodes if n.startswith("entity-")]
|
||||
|
||||
print("==== 向量库统计 ====")
|
||||
print(f"段落向量条数: {para_cnt}")
|
||||
print(f"实体向量条数: {ent_cnt_vec}")
|
||||
print(f"关系向量条数: {rel_cnt_vec}")
|
||||
|
||||
print("\n==== KG 图统计 ====")
|
||||
print(f"节点总数: {len(nodes)}")
|
||||
print(f"边总数: {len(edges)}")
|
||||
print(f"段落节点数: {len(para_nodes)}")
|
||||
print(f"实体节点数: {len(ent_nodes)}")
|
||||
|
||||
# ent_appear_cnt 状态
|
||||
ent_cnt_meta = len(kg.ent_appear_cnt)
|
||||
print(f"\n实体计数表条目数: {ent_cnt_meta}")
|
||||
|
||||
# 抽样查看剩余段落/实体内容
|
||||
print("\n==== 剩余段落示例(最多 3 条) ====")
|
||||
for nid in para_nodes[:3]:
|
||||
nd = kg.graph[nid]
|
||||
content = nd["content"] if "content" in nd else nid
|
||||
print(f"- {nid}: {content[:80]}")
|
||||
|
||||
print("\n==== 剩余实体示例(最多 5 条) ====")
|
||||
for nid in ent_nodes[:5]:
|
||||
nd = kg.graph[nid]
|
||||
content = nd["content"] if "content" in nd else nid
|
||||
print(f"- {nid}: {content[:80]}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
541
scripts/lpmm_manager.py
Normal file
541
scripts/lpmm_manager.py
Normal file
@@ -0,0 +1,541 @@
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional, List
|
||||
|
||||
# 尽量统一控制台编码为 utf-8,避免中文输出报错
|
||||
try:
|
||||
if hasattr(sys.stdout, "reconfigure"):
|
||||
sys.stdout.reconfigure(encoding="utf-8")
|
||||
if hasattr(sys.stderr, "reconfigure"):
|
||||
sys.stderr.reconfigure(encoding="utf-8")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 确保能导入 src.* 以及同目录脚本
|
||||
CURRENT_DIR = os.path.dirname(__file__)
|
||||
PROJECT_ROOT = os.path.abspath(os.path.join(CURRENT_DIR, ".."))
|
||||
if PROJECT_ROOT not in sys.path:
|
||||
sys.path.append(PROJECT_ROOT)
|
||||
|
||||
from src.common.logger import get_logger # type: ignore
|
||||
from src.config.config import global_config, model_config # type: ignore
|
||||
|
||||
# 引入各功能脚本的入口函数
|
||||
from import_openie import main as import_openie_main # type: ignore
|
||||
from info_extraction import main as info_extraction_main # type: ignore
|
||||
from delete_lpmm_items import main as delete_lpmm_items_main # type: ignore
|
||||
from inspect_lpmm_batch import main as inspect_lpmm_batch_main # type: ignore
|
||||
from inspect_lpmm_global import main as inspect_lpmm_global_main # type: ignore
|
||||
from refresh_lpmm_knowledge import main as refresh_lpmm_knowledge_main # type: ignore
|
||||
from test_lpmm_retrieval import main as test_lpmm_retrieval_main # type: ignore
|
||||
from raw_data_preprocessor import load_raw_data # type: ignore
|
||||
|
||||
|
||||
logger = get_logger("lpmm_manager")
|
||||
|
||||
|
||||
ACTION_INFO = {
|
||||
"prepare_raw": "预处理 data/lpmm_raw_data/*.txt,按空行切分为段落并做去重统计",
|
||||
"info_extract": "原始 txt -> OpenIE 信息抽取(调用 info_extraction.py)",
|
||||
"import_openie": "导入 OpenIE 批次到向量库与知识图(调用 import_openie.py)",
|
||||
"delete": "删除/回滚知识(调用 delete_lpmm_items.py)",
|
||||
"batch_inspect": "检查指定 OpenIE 批次在当前库中的存在情况(调用 inspect_lpmm_batch.py)",
|
||||
"global_inspect": "查看当前整库向量与 KG 状态(调用 inspect_lpmm_global.py)",
|
||||
"refresh": "刷新 LPMM 磁盘数据到内存(调用 refresh_lpmm_knowledge.py)",
|
||||
"test": "运行 LPMM 检索效果回归测试(调用 test_lpmm_retrieval.py)",
|
||||
"embedding_helper": "嵌入模型迁移辅助:查看当前嵌入模型/维度并归档 embedding_model_test.json",
|
||||
"full_import": "一键执行:信息抽取 -> 导入 OpenIE -> 刷新",
|
||||
}
|
||||
|
||||
|
||||
def _with_overridden_argv(extra_args: List[str], target_main) -> None:
|
||||
"""在不修改子脚本的前提下,临时覆盖 sys.argv 以透传参数。"""
|
||||
old_argv = list(sys.argv)
|
||||
try:
|
||||
# 第 0 个元素为“程序名”,后续元素为实际参数
|
||||
# 这里不再插入类似 delete_lpmm_items.py 的占位,避免被 argparse 误识别为位置参数
|
||||
sys.argv = [old_argv[0]] + extra_args
|
||||
target_main()
|
||||
finally:
|
||||
sys.argv = old_argv
|
||||
|
||||
|
||||
def _check_before_info_extract(non_interactive: bool = False) -> bool:
|
||||
"""信息抽取前的轻量级检查。"""
|
||||
raw_dir = Path(PROJECT_ROOT) / "data" / "lpmm_raw_data"
|
||||
txt_files = list(raw_dir.glob("*.txt"))
|
||||
if not txt_files:
|
||||
msg = (
|
||||
f"[WARN] 未在 {raw_dir} 下找到任何 .txt 原始语料文件,"
|
||||
"info_extraction 可能立即退出或无数据可处理。"
|
||||
)
|
||||
print(msg)
|
||||
if non_interactive:
|
||||
logger.error(
|
||||
"非交互模式下要求原始语料目录中已存在可用的 .txt 文件,请先准备好数据再重试。"
|
||||
)
|
||||
return False
|
||||
cont = input("仍然继续执行信息提取吗?(y/n): ").strip().lower()
|
||||
return cont == "y"
|
||||
return True
|
||||
|
||||
|
||||
def _check_before_import_openie(non_interactive: bool = False) -> bool:
|
||||
"""导入 OpenIE 前的轻量级检查。"""
|
||||
openie_dir = Path(PROJECT_ROOT) / "data" / "openie"
|
||||
json_files = list(openie_dir.glob("*.json"))
|
||||
if not json_files:
|
||||
msg = (
|
||||
f"[WARN] 未在 {openie_dir} 下找到任何 OpenIE JSON 文件,"
|
||||
"import_openie 可能会因为找不到批次而失败。"
|
||||
)
|
||||
print(msg)
|
||||
if non_interactive:
|
||||
logger.error(
|
||||
"非交互模式下要求 data/openie 目录中已存在可用的 OpenIE JSON 文件,请先执行信息提取脚本。"
|
||||
)
|
||||
return False
|
||||
cont = input("仍然继续执行导入吗?(y/n): ").strip().lower()
|
||||
return cont == "y"
|
||||
return True
|
||||
|
||||
|
||||
def _warn_if_lpmm_disabled() -> None:
|
||||
"""在部分操作前提醒 lpmm_knowledge.enable 状态。"""
|
||||
try:
|
||||
if not getattr(global_config.lpmm_knowledge, "enable", False):
|
||||
print(
|
||||
"[WARN] 当前配置 lpmm_knowledge.enable = false,"
|
||||
"刷新或检索测试可能无法在聊天侧真正启用 LPMM。"
|
||||
)
|
||||
except Exception:
|
||||
# 配置异常时不阻断主流程,仅忽略提示
|
||||
pass
|
||||
|
||||
|
||||
def run_action(action: str, extra_args: Optional[List[str]] = None) -> None:
|
||||
"""根据动作名称调度到对应脚本。
|
||||
|
||||
这里不重复解析子参数,而是直接调用各脚本的 main(),
|
||||
让子脚本保留原有的交互/参数行为。
|
||||
"""
|
||||
logger.info("开始执行操作: %s", action)
|
||||
|
||||
extra_args = extra_args or []
|
||||
|
||||
try:
|
||||
if action == "prepare_raw":
|
||||
logger.info("开始预处理原始语料 (data/lpmm_raw_data/*.txt)...")
|
||||
sha_list, raw_data = load_raw_data()
|
||||
print(
|
||||
f"\n[PREPARE_RAW] 完成原始语料预处理:共 {len(raw_data)} 条段落,"
|
||||
f"去重后哈希数 {len(sha_list)}。"
|
||||
)
|
||||
elif action == "info_extract":
|
||||
if not _check_before_info_extract("--non-interactive" in extra_args):
|
||||
print("已根据用户选择,取消执行信息提取。")
|
||||
return
|
||||
_with_overridden_argv(extra_args, info_extraction_main)
|
||||
elif action == "import_openie":
|
||||
if not _check_before_import_openie("--non-interactive" in extra_args):
|
||||
print("已根据用户选择,取消执行导入。")
|
||||
return
|
||||
_with_overridden_argv(extra_args, import_openie_main)
|
||||
elif action == "delete":
|
||||
_with_overridden_argv(extra_args, delete_lpmm_items_main)
|
||||
elif action == "batch_inspect":
|
||||
_with_overridden_argv(extra_args, inspect_lpmm_batch_main)
|
||||
elif action == "global_inspect":
|
||||
_with_overridden_argv(extra_args, inspect_lpmm_global_main)
|
||||
elif action == "refresh":
|
||||
_warn_if_lpmm_disabled()
|
||||
_with_overridden_argv(extra_args, refresh_lpmm_knowledge_main)
|
||||
elif action == "test":
|
||||
_warn_if_lpmm_disabled()
|
||||
_with_overridden_argv(extra_args, test_lpmm_retrieval_main)
|
||||
elif action == "embedding_helper":
|
||||
# 嵌入模型迁移辅助:查看当前嵌入模型/维度并归档 embedding_model_test.json
|
||||
_run_embedding_helper()
|
||||
elif action == "full_import":
|
||||
# 一键流水线:预处理原始语料 -> 信息抽取 -> 导入 -> 刷新
|
||||
logger.info("开始 full_import:预处理原始语料 -> 信息抽取 -> 导入 -> 刷新")
|
||||
sha_list, raw_data = load_raw_data()
|
||||
print(
|
||||
f"\n[FULL_IMPORT] 原始语料预处理完成:共 {len(raw_data)} 条段落,"
|
||||
f"去重后哈希数 {len(sha_list)}。"
|
||||
)
|
||||
non_interactive = "--non-interactive" in extra_args
|
||||
if not _check_before_info_extract(non_interactive):
|
||||
print("已根据用户选择,取消 full_import(信息提取阶段被取消)。")
|
||||
return
|
||||
# 使用与单步 info_extract 相同的参数透传机制,确保 --non-interactive 等生效
|
||||
_with_overridden_argv(extra_args, info_extraction_main)
|
||||
if not _check_before_import_openie(non_interactive):
|
||||
print("已根据用户选择,取消 full_import(导入阶段被取消)。")
|
||||
return
|
||||
_with_overridden_argv(extra_args, import_openie_main)
|
||||
_warn_if_lpmm_disabled()
|
||||
_with_overridden_argv(extra_args, refresh_lpmm_knowledge_main)
|
||||
else:
|
||||
logger.error("未知操作: %s", action)
|
||||
except KeyboardInterrupt:
|
||||
logger.info("用户中断当前操作(Ctrl+C)")
|
||||
except SystemExit:
|
||||
# 子脚本里大量使用 sys.exit,直接透传即可
|
||||
raise
|
||||
except Exception as exc: # pragma: no cover - 防御性兜底
|
||||
logger.error("执行操作 %s 时发生未捕获异常: %s", action, exc)
|
||||
raise
|
||||
|
||||
|
||||
def print_menu() -> None:
|
||||
print("\n===== LPMM 管理菜单 =====")
|
||||
for idx, key in enumerate(
|
||||
[
|
||||
"prepare_raw",
|
||||
"info_extract",
|
||||
"import_openie",
|
||||
"delete",
|
||||
"batch_inspect",
|
||||
"global_inspect",
|
||||
"refresh",
|
||||
"test",
|
||||
"embedding_helper",
|
||||
"full_import",
|
||||
],
|
||||
start=1,
|
||||
):
|
||||
desc = ACTION_INFO.get(key, "")
|
||||
print(f"{idx}. {key:14s} - {desc}")
|
||||
print("0. 退出")
|
||||
print("=========================")
|
||||
|
||||
|
||||
def interactive_loop() -> None:
|
||||
"""交互式选择模式。"""
|
||||
key_order = [
|
||||
"prepare_raw",
|
||||
"info_extract",
|
||||
"import_openie",
|
||||
"delete",
|
||||
"batch_inspect",
|
||||
"global_inspect",
|
||||
"refresh",
|
||||
"test",
|
||||
"embedding_helper",
|
||||
"full_import",
|
||||
]
|
||||
|
||||
while True:
|
||||
print_menu()
|
||||
choice = input("请输入选项编号(0-10):").strip()
|
||||
|
||||
if choice in ("0", "q", "Q", "quit", "exit"):
|
||||
print("已退出 LPMM 管理器。")
|
||||
return
|
||||
|
||||
try:
|
||||
idx = int(choice)
|
||||
except ValueError:
|
||||
print("输入无效,请输入 0-10 之间的数字。")
|
||||
continue
|
||||
|
||||
if not (1 <= idx <= len(key_order)):
|
||||
print("输入编号超出范围,请重新输入。")
|
||||
continue
|
||||
|
||||
action = key_order[idx - 1]
|
||||
print(f"\n你选择了: {action} - {ACTION_INFO.get(action, '')}")
|
||||
confirm = input("确认执行该操作?(y/n): ").strip().lower()
|
||||
if confirm != "y":
|
||||
print("已取消当前操作。\n")
|
||||
continue
|
||||
|
||||
# 通过交互式问题,尽量帮用户补全对应脚本的常用参数
|
||||
extra_args: List[str] = []
|
||||
if action == "delete":
|
||||
extra_args = _interactive_build_delete_args()
|
||||
elif action == "batch_inspect":
|
||||
extra_args = _interactive_build_batch_inspect_args()
|
||||
elif action == "test":
|
||||
extra_args = _interactive_build_test_args()
|
||||
else:
|
||||
extra_args = []
|
||||
|
||||
run_action(action, extra_args=extra_args)
|
||||
print("\n当前操作已结束,回到主菜单。\n")
|
||||
|
||||
|
||||
def _interactive_choose_openie_file(prompt: str) -> Optional[str]:
|
||||
"""在 data/openie 下列出可选 JSON 文件,并返回用户选择的路径。"""
|
||||
openie_dir = Path(PROJECT_ROOT) / "data" / "openie"
|
||||
files = sorted(openie_dir.glob("*.json"))
|
||||
if not files:
|
||||
print(f"[WARN] 在 {openie_dir} 下没有找到任何 OpenIE JSON 文件。")
|
||||
return input(prompt).strip() or None
|
||||
|
||||
print("\n可选的 OpenIE 批次文件:")
|
||||
for i, f in enumerate(files, start=1):
|
||||
print(f"{i}. {f.name}")
|
||||
print("0. 手动输入完整路径")
|
||||
|
||||
while True:
|
||||
choice = input("请选择文件编号:").strip()
|
||||
if choice == "0":
|
||||
manual = input(prompt).strip()
|
||||
return manual or None
|
||||
try:
|
||||
idx = int(choice)
|
||||
except ValueError:
|
||||
print("请输入合法的编号。")
|
||||
continue
|
||||
if 1 <= idx <= len(files):
|
||||
return str(files[idx - 1])
|
||||
print("编号超出范围,请重试。")
|
||||
|
||||
|
||||
def _interactive_build_delete_args() -> List[str]:
|
||||
"""为 delete_lpmm_items 构造常见参数,减少二次交互。"""
|
||||
print(
|
||||
"\n[DELETE] 请选择删除方式:\n"
|
||||
"1. 按哈希文件删除 (--hash-file)\n"
|
||||
"2. 按 OpenIE 批次删除 (--openie-file)\n"
|
||||
"3. 按原始语料文件 + 段落索引删除 (--raw-file + --raw-index)\n"
|
||||
"4. 按关键字搜索现有段落 (--search-text)\n"
|
||||
"回车跳过,由子脚本自行交互。"
|
||||
)
|
||||
mode = input("输入选项编号(1-4,或回车跳过):").strip()
|
||||
args: List[str] = []
|
||||
|
||||
if mode == "1":
|
||||
path = input("请输入哈希文件路径(每行一个 hash):").strip()
|
||||
if path:
|
||||
args += ["--hash-file", path]
|
||||
elif mode == "2":
|
||||
path = _interactive_choose_openie_file("请输入 OpenIE JSON 文件路径:")
|
||||
if path:
|
||||
args += ["--openie-file", path]
|
||||
elif mode == "3":
|
||||
raw_file = input("请输入原始语料 txt 文件路径:").strip()
|
||||
raw_index = input("请输入要删除的段落索引(如 1,3):").strip()
|
||||
if raw_file and raw_index:
|
||||
args += ["--raw-file", raw_file, "--raw-index", raw_index]
|
||||
elif mode == "4":
|
||||
text = input("请输入用于搜索的关键字(出现在段落原文中):").strip()
|
||||
if text:
|
||||
args += ["--search-text", text]
|
||||
else:
|
||||
# 留空则完全交给子脚本交互
|
||||
return []
|
||||
|
||||
# 进一步询问与安全相关的布尔选项
|
||||
print(
|
||||
"\n[DELETE] 接下来是一些安全相关选项的说明:\n"
|
||||
"- 删除实体向量/节点:会一并清理与这些段落关联的实体节点及其向量;\n"
|
||||
"- 删除关系向量:在上面的基础上,额外清理关系向量(一般与删除实体一同使用);\n"
|
||||
"- 删除孤立实体节点:删除后若实体不再连接任何段落,将其从图中移除,避免残留孤点;\n"
|
||||
"- dry-run:只预览将要删除的内容,不真正修改任何数据;\n"
|
||||
"- 跳过交互确认(--yes):直接执行删除操作,适合脚本化或已充分确认的场景;\n"
|
||||
"- 单次最大删除节点数上限:防止一次性删除规模过大,起到误操作保护作用;\n"
|
||||
"- 一般情况下建议同时删除实体向量/节点/关系向量/节点,以确保知识图谱的完整性。"
|
||||
)
|
||||
|
||||
# 快速选项:按推荐方式清理所有相关实体/关系
|
||||
quick_all = input(
|
||||
"是否使用推荐策略:同时删除关联的实体向量/节点、关系向量,并清理孤立实体?(Y/n): "
|
||||
).strip().lower()
|
||||
if quick_all in ("", "y", "yes"):
|
||||
args.extend(["--delete-entities", "--delete-relations", "--remove-orphan-entities"])
|
||||
else:
|
||||
# 仅当未使用快速方案时,再逐项询问
|
||||
if input("是否同时删除实体向量/节点?(y/N): ").strip().lower() == "y":
|
||||
args.append("--delete-entities")
|
||||
if input("是否同时删除关系向量?(y/N): ").strip().lower() == "y":
|
||||
args.append("--delete-relations")
|
||||
|
||||
if input("是否删除孤立实体节点?(y/N): ").strip().lower() == "y":
|
||||
args.append("--remove-orphan-entities")
|
||||
|
||||
if input("是否以 dry-run 预览而不真正删除?(y/N): ").strip().lower() == "y":
|
||||
args.append("--dry-run")
|
||||
else:
|
||||
if input("是否跳过交互确认直接删除?(默认否,请谨慎) (y/N): ").strip().lower() == "y":
|
||||
args.append("--yes")
|
||||
|
||||
max_nodes = input("单次最大删除节点数上限(回车使用默认 2000):").strip()
|
||||
if max_nodes:
|
||||
args += ["--max-delete-nodes", max_nodes]
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def _interactive_build_batch_inspect_args() -> List[str]:
|
||||
"""为 inspect_lpmm_batch 构造 --openie-file 参数。"""
|
||||
path = _interactive_choose_openie_file(
|
||||
"请输入要检查的 OpenIE JSON 文件路径(回车跳过,由子脚本自行交互):"
|
||||
)
|
||||
if not path:
|
||||
return []
|
||||
return ["--openie-file", path]
|
||||
|
||||
|
||||
def _interactive_build_test_args() -> List[str]:
|
||||
"""为 test_lpmm_retrieval 构造自定义测试用例参数。"""
|
||||
print(
|
||||
"\n[TEST] 你可以:\n"
|
||||
"- 直接回车使用内置的默认测试用例;\n"
|
||||
"- 或者输入一条自定义问题,并指定期望命中的关键字。"
|
||||
)
|
||||
query = input("请输入自定义测试问题(回车则使用默认用例):").strip()
|
||||
if not query:
|
||||
return []
|
||||
|
||||
expect = input("请输入期望命中的关键字(可选,多项用逗号分隔):").strip()
|
||||
args: List[str] = ["--query", query]
|
||||
if expect:
|
||||
for kw in expect.split(","):
|
||||
kw = kw.strip()
|
||||
if kw:
|
||||
args.extend(["--expect-keyword", kw])
|
||||
return args
|
||||
|
||||
|
||||
def _run_embedding_helper() -> None:
|
||||
"""嵌入模型迁移辅助:展示当前配置,并安全归档 embedding_model_test.json。"""
|
||||
from src.chat.knowledge.embedding_store import EMBEDDING_TEST_FILE # type: ignore
|
||||
|
||||
# 1. 读取当前配置中的嵌入维度与模型信息
|
||||
current_dim = getattr(getattr(global_config, "lpmm_knowledge", None), "embedding_dimension", None)
|
||||
embed_task = getattr(model_config.model_task_config, "embedding", None)
|
||||
model_ids: List[str] = []
|
||||
if embed_task is not None:
|
||||
model_ids = getattr(embed_task, "model_list", []) or []
|
||||
primary_model = model_ids[0] if model_ids else "unknown"
|
||||
safe_model_name = re.sub(r"[^0-9A-Za-z_.-]+", "_", primary_model) or "unknown"
|
||||
|
||||
print("\n===== 嵌入模型迁移辅助 (embedding_helper) =====")
|
||||
print(f"- 当前嵌入模型标识(model_task_config.embedding.model_list[0]): {primary_model}")
|
||||
print(f"- 当前配置中的嵌入维度 (lpmm_knowledge.embedding_dimension): {current_dim}")
|
||||
print(f"- 测试文件路径: {EMBEDDING_TEST_FILE}")
|
||||
|
||||
new_dim = input(
|
||||
"\n如果你计划更换嵌入模型,请在此输入“新的嵌入维度”(仅用于记录与提示,回车则跳过):"
|
||||
).strip()
|
||||
if new_dim and not new_dim.isdigit():
|
||||
print("输入的维度不是纯数字,已取消操作。")
|
||||
return
|
||||
|
||||
print(
|
||||
"\n[重要提示]\n"
|
||||
"- 修改嵌入模型或维度会导致当前磁盘中的旧知识库(data/embedding 下的向量)与新模型不兼容;\n"
|
||||
"- 这通常意味着你需要清空旧的向量/图数据,并重新执行 LPMM 导入流水线;\n"
|
||||
"- 请仅在你**确定要切换嵌入模型/维度**时再继续。\n"
|
||||
)
|
||||
confirm = input("是否已充分评估风险,并准备切换嵌入模型/维度?(y/N): ").strip().lower()
|
||||
if confirm != "y":
|
||||
print("已根据你的选择取消嵌入模型迁移辅助操作。")
|
||||
return
|
||||
|
||||
print(
|
||||
"\n接下来请手动完成以下操作(脚本不会自动修改配置或删除知识库):\n"
|
||||
f"1. 在配置文件中,将 lpmm_knowledge.embedding_dimension 从 {current_dim} 修改为你计划使用的新维度"
|
||||
+ (f"(例如 {new_dim})" if new_dim else "") # 仅作为示例
|
||||
+ ";\n"
|
||||
"2. 根据需要,清空 data/embedding 与相关 KG 数据(data/rag 等),然后重新执行导入流水线;\n"
|
||||
"3. 本脚本将帮助你归档当前的 embedding_model_test.json,避免旧测试文件干扰新模型的校验。\n"
|
||||
)
|
||||
|
||||
# 2. 归档 embedding_model_test.json
|
||||
test_path = Path(EMBEDDING_TEST_FILE)
|
||||
if not test_path.exists():
|
||||
print(f"\n[INFO] 未在 {test_path} 发现 embedding_model_test.json,无需归档。")
|
||||
return
|
||||
|
||||
ts = datetime.now().strftime("%Y%m%d-%H%M%S")
|
||||
archive_name = f"embedding_model_test-{safe_model_name}-{ts}.json"
|
||||
archive_path = test_path.with_name(archive_name)
|
||||
|
||||
# 若不巧重名,简单追加后缀避免覆盖
|
||||
suffix_id = 1
|
||||
while archive_path.exists():
|
||||
archive_name = f"embedding_model_test-{safe_model_name}-{ts}-{suffix_id}.json"
|
||||
archive_path = test_path.with_name(archive_name)
|
||||
suffix_id += 1
|
||||
|
||||
try:
|
||||
test_path.rename(archive_path)
|
||||
except Exception as exc: # pragma: no cover - 防御性兜底
|
||||
logger.error("归档 embedding_model_test.json 失败: %s", exc)
|
||||
print(f"[ERROR] 归档 embedding_model_test.json 失败,请检查文件权限与路径。错误详情已写入日志。")
|
||||
return
|
||||
|
||||
print(
|
||||
f"\n[OK] 已将 {test_path.name} 重命名为 {archive_path.name}。\n"
|
||||
f"- 归档位置: {archive_path}\n"
|
||||
"- 之后再次运行涉及嵌入模型的一致性校验时,将会基于当前配置与新模型生成新的测试文件。\n"
|
||||
"- 在完成配置修改与知识库重导入前,请不要手动再创建名为 embedding_model_test.json 的文件。"
|
||||
)
|
||||
|
||||
|
||||
def parse_args(argv: Optional[list[str]] = None) -> tuple[argparse.Namespace, List[str]]:
|
||||
parser = argparse.ArgumentParser(
|
||||
description=(
|
||||
"LPMM 管理脚本:集中入口管理 LPMM 的导入 / 删除 / 自检 / 刷新 / 测试等功能。\n"
|
||||
"可以通过 --interactive 进入菜单模式,也可以使用 --action 直接执行单个操作。"
|
||||
)
|
||||
)
|
||||
parser.add_argument(
|
||||
"-i",
|
||||
"--interactive",
|
||||
action="store_true",
|
||||
help="进入交互式菜单模式(推荐给手动运维使用)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-a",
|
||||
"--action",
|
||||
choices=list(ACTION_INFO.keys()),
|
||||
help="直接执行指定操作(非交互模式)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--non-interactive",
|
||||
action="store_true",
|
||||
help=(
|
||||
"启用非交互模式:lpmm_manager 自身不会再通过 input() 询问是否继续前置检查;"
|
||||
"并会将 --non-interactive 透传给子脚本,以避免子脚本中的交互式确认。"
|
||||
),
|
||||
)
|
||||
# 允许在管理脚本之后继续跟随子脚本参数,例如:
|
||||
# python lpmm_manager.py -a delete -- --hash-file xxx --yes
|
||||
args, unknown = parser.parse_known_args(argv)
|
||||
return args, unknown
|
||||
|
||||
|
||||
def main(argv: Optional[list[str]] = None) -> None:
|
||||
args, extra_args = parse_args(argv)
|
||||
|
||||
# 如果指定了 non-interactive,则不能进入交互式菜单
|
||||
if args.non_interactive and args.interactive:
|
||||
logger.error("不能同时指定 --interactive 与 --non-interactive,请二选一。")
|
||||
sys.exit(1)
|
||||
|
||||
# 没有指定 action 或显式要求交互 -> 进入菜单
|
||||
if args.interactive or not args.action:
|
||||
interactive_loop()
|
||||
return
|
||||
|
||||
# 在非交互模式下,将 --non-interactive 透传给子脚本,避免其内部出现 input() 交互
|
||||
if args.non_interactive:
|
||||
extra_args = ["--non-interactive"] + extra_args
|
||||
|
||||
# 非交互模式:直接执行指定操作
|
||||
run_action(args.action, extra_args=extra_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
import sys # 新增系统模块导入
|
||||
from src.chat.knowledge.utils.hash import get_sha256
|
||||
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
from src.chat.knowledge.utils.hash import get_sha256
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("lpmm")
|
||||
@@ -59,10 +59,11 @@ def load_raw_data() -> tuple[list[str], list[str]]:
|
||||
- raw_data: 原始数据列表
|
||||
- sha256_list: 原始数据的SHA256集合
|
||||
"""
|
||||
raw_data = _process_multi_files()
|
||||
raw_paragraphs = _process_multi_files()
|
||||
sha256_list = []
|
||||
sha256_set = set()
|
||||
for item in raw_data:
|
||||
raw_data: list[str] = []
|
||||
for item in raw_paragraphs:
|
||||
if not isinstance(item, str):
|
||||
logger.warning(f"数据类型错误:{item}")
|
||||
continue
|
||||
|
||||
66
scripts/refresh_lpmm_knowledge.py
Normal file
66
scripts/refresh_lpmm_knowledge.py
Normal file
@@ -0,0 +1,66 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
try:
|
||||
if hasattr(sys.stdout, "reconfigure"):
|
||||
sys.stdout.reconfigure(encoding="utf-8")
|
||||
if hasattr(sys.stderr, "reconfigure"):
|
||||
sys.stderr.reconfigure(encoding="utf-8")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 确保能导入 src.*
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.knowledge import lpmm_start_up, get_qa_manager
|
||||
|
||||
logger = get_logger("refresh_lpmm_knowledge")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
logger.info("开始刷新 LPMM 知识库(重新加载向量库与 KG)...")
|
||||
|
||||
if not global_config.lpmm_knowledge.enable:
|
||||
logger.warning(
|
||||
"当前配置中 lpmm_knowledge.enable = false,本次仅刷新磁盘数据与内存结构,"
|
||||
"但聊天侧如未启用 LPMM 仍不会在问答中使用知识库。"
|
||||
)
|
||||
|
||||
# 调用标准启动逻辑,内部会加载 data/embedding 与 data/rag
|
||||
lpmm_start_up()
|
||||
|
||||
qa_manager = get_qa_manager()
|
||||
if qa_manager is None:
|
||||
logger.error("刷新后 qa_manager 仍为 None,请检查是否已经成功导入过 LPMM 知识库。")
|
||||
return
|
||||
|
||||
# 简要输出当前知识库规模,方便人工确认
|
||||
embed_manager = qa_manager.embed_manager
|
||||
kg_manager = qa_manager.kg_manager
|
||||
|
||||
para_vec = len(embed_manager.paragraphs_embedding_store.store)
|
||||
ent_vec = len(embed_manager.entities_embedding_store.store)
|
||||
rel_vec = len(embed_manager.relation_embedding_store.store)
|
||||
nodes = len(kg_manager.graph.get_node_list())
|
||||
edges = len(kg_manager.graph.get_edge_list())
|
||||
|
||||
logger.info("LPMM 知识库刷新完成,当前规模:")
|
||||
logger.info(
|
||||
"段落向量=%d, 实体向量=%d, 关系向量=%d, KG节点=%d, KG边=%d",
|
||||
para_vec,
|
||||
ent_vec,
|
||||
rel_vec,
|
||||
nodes,
|
||||
edges,
|
||||
)
|
||||
|
||||
print("\n[REFRESH] 刷新完成,请注意:")
|
||||
print("- 本脚本是在独立进程内执行的,用于验证磁盘数据可以正常加载。")
|
||||
print("- 若主程序已在运行且未在内部调用 lpmm_start_up() 重新初始化,仍需重启或新增管理入口来热刷新。")
|
||||
print("- 如果不清楚 lpmm_start_up 是什么,只需要重启主程序即可。")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
303
scripts/replyer_action_stats.py
Normal file
303
scripts/replyer_action_stats.py
Normal file
@@ -0,0 +1,303 @@
|
||||
"""
|
||||
统计和展示 replyer 动作选择记录
|
||||
|
||||
用法:
|
||||
python scripts/replyer_action_stats.py
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from collections import Counter, defaultdict
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Any
|
||||
from pathlib import Path
|
||||
|
||||
# Add project root to Python path
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.insert(0, project_root)
|
||||
|
||||
try:
|
||||
from src.common.database.database_model import ChatStreams
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
except ImportError:
|
||||
ChatStreams = None
|
||||
get_chat_manager = None
|
||||
|
||||
|
||||
def get_chat_name(chat_id: str) -> str:
|
||||
"""根据 chat_id 获取聊天名称"""
|
||||
try:
|
||||
if ChatStreams:
|
||||
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
|
||||
if chat_stream:
|
||||
if chat_stream.group_name:
|
||||
return f"{chat_stream.group_name}"
|
||||
elif chat_stream.user_nickname:
|
||||
return f"{chat_stream.user_nickname}的私聊"
|
||||
|
||||
if get_chat_manager:
|
||||
chat_manager = get_chat_manager()
|
||||
stream_name = chat_manager.get_stream_name(chat_id)
|
||||
if stream_name:
|
||||
return stream_name
|
||||
|
||||
return f"未知聊天 ({chat_id[:8]}...)"
|
||||
except Exception:
|
||||
return f"查询失败 ({chat_id[:8]}...)"
|
||||
|
||||
|
||||
def load_records(temp_dir: str = "data/temp") -> List[Dict[str, Any]]:
|
||||
"""加载所有 replyer 动作记录"""
|
||||
records = []
|
||||
temp_path = Path(temp_dir)
|
||||
|
||||
if not temp_path.exists():
|
||||
print(f"目录不存在: {temp_dir}")
|
||||
return records
|
||||
|
||||
# 查找所有 replyer_action_*.json 文件
|
||||
pattern = "replyer_action_*.json"
|
||||
for file_path in temp_path.glob(pattern):
|
||||
try:
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
records.append(data)
|
||||
except Exception as e:
|
||||
print(f"读取文件失败 {file_path}: {e}")
|
||||
|
||||
# 按时间戳排序
|
||||
records.sort(key=lambda x: x.get("timestamp", ""))
|
||||
return records
|
||||
|
||||
|
||||
def format_timestamp(ts: str) -> str:
|
||||
"""格式化时间戳"""
|
||||
try:
|
||||
dt = datetime.fromisoformat(ts)
|
||||
return dt.strftime("%Y-%m-%d %H:%M:%S")
|
||||
except Exception:
|
||||
return ts
|
||||
|
||||
|
||||
def calculate_time_distribution(records: List[Dict[str, Any]]) -> Dict[str, int]:
|
||||
"""计算时间分布"""
|
||||
now = datetime.now()
|
||||
distribution = {
|
||||
"今天": 0,
|
||||
"昨天": 0,
|
||||
"3天内": 0,
|
||||
"7天内": 0,
|
||||
"30天内": 0,
|
||||
"更早": 0,
|
||||
}
|
||||
|
||||
for record in records:
|
||||
try:
|
||||
ts = record.get("timestamp", "")
|
||||
if not ts:
|
||||
continue
|
||||
dt = datetime.fromisoformat(ts)
|
||||
diff = (now - dt).days
|
||||
|
||||
if diff == 0:
|
||||
distribution["今天"] += 1
|
||||
elif diff == 1:
|
||||
distribution["昨天"] += 1
|
||||
elif diff < 3:
|
||||
distribution["3天内"] += 1
|
||||
elif diff < 7:
|
||||
distribution["7天内"] += 1
|
||||
elif diff < 30:
|
||||
distribution["30天内"] += 1
|
||||
else:
|
||||
distribution["更早"] += 1
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return distribution
|
||||
|
||||
|
||||
def print_statistics(records: List[Dict[str, Any]]):
|
||||
"""打印统计信息"""
|
||||
if not records:
|
||||
print("没有找到任何记录")
|
||||
return
|
||||
|
||||
print("=" * 80)
|
||||
print("Replyer 动作选择记录统计")
|
||||
print("=" * 80)
|
||||
print()
|
||||
|
||||
# 总记录数
|
||||
total_count = len(records)
|
||||
print(f"📊 总记录数: {total_count}")
|
||||
print()
|
||||
|
||||
# 时间范围
|
||||
timestamps = [r.get("timestamp", "") for r in records if r.get("timestamp")]
|
||||
if timestamps:
|
||||
first_time = format_timestamp(min(timestamps))
|
||||
last_time = format_timestamp(max(timestamps))
|
||||
print(f"📅 时间范围: {first_time} ~ {last_time}")
|
||||
print()
|
||||
|
||||
# 按 think_level 统计
|
||||
think_levels = [r.get("think_level", 0) for r in records]
|
||||
think_level_counter = Counter(think_levels)
|
||||
print("🧠 思考深度分布:")
|
||||
for level in sorted(think_level_counter.keys()):
|
||||
count = think_level_counter[level]
|
||||
percentage = (count / total_count) * 100
|
||||
level_name = {0: "不需要思考", 1: "简单思考", 2: "深度思考"}.get(level, f"未知({level})")
|
||||
print(f" Level {level} ({level_name}): {count} 次 ({percentage:.1f}%)")
|
||||
print()
|
||||
|
||||
# 按 chat_id 统计(总体)
|
||||
chat_counter = Counter([r.get("chat_id", "未知") for r in records])
|
||||
print(f"💬 聊天分布 (共 {len(chat_counter)} 个聊天):")
|
||||
# 只显示前10个
|
||||
for chat_id, count in chat_counter.most_common(10):
|
||||
chat_name = get_chat_name(chat_id)
|
||||
percentage = (count / total_count) * 100
|
||||
print(f" {chat_name}: {count} 次 ({percentage:.1f}%)")
|
||||
if len(chat_counter) > 10:
|
||||
print(f" ... 还有 {len(chat_counter) - 10} 个聊天")
|
||||
print()
|
||||
|
||||
# 每个 chat_id 的详细统计
|
||||
print("=" * 80)
|
||||
print("每个聊天的详细统计")
|
||||
print("=" * 80)
|
||||
print()
|
||||
|
||||
# 按 chat_id 分组记录
|
||||
records_by_chat = defaultdict(list)
|
||||
for record in records:
|
||||
chat_id = record.get("chat_id", "未知")
|
||||
records_by_chat[chat_id].append(record)
|
||||
|
||||
# 按记录数排序
|
||||
sorted_chats = sorted(records_by_chat.items(), key=lambda x: len(x[1]), reverse=True)
|
||||
|
||||
for chat_id, chat_records in sorted_chats:
|
||||
chat_name = get_chat_name(chat_id)
|
||||
chat_count = len(chat_records)
|
||||
chat_percentage = (chat_count / total_count) * 100
|
||||
|
||||
print(f"📱 {chat_name} ({chat_id[:8]}...)")
|
||||
print(f" 总记录数: {chat_count} ({chat_percentage:.1f}%)")
|
||||
|
||||
# 该聊天的 think_level 分布
|
||||
chat_think_levels = [r.get("think_level", 0) for r in chat_records]
|
||||
chat_think_counter = Counter(chat_think_levels)
|
||||
print(" 思考深度分布:")
|
||||
for level in sorted(chat_think_counter.keys()):
|
||||
level_count = chat_think_counter[level]
|
||||
level_percentage = (level_count / chat_count) * 100
|
||||
level_name = {0: "不需要思考", 1: "简单思考", 2: "深度思考"}.get(level, f"未知({level})")
|
||||
print(f" Level {level} ({level_name}): {level_count} 次 ({level_percentage:.1f}%)")
|
||||
|
||||
# 该聊天的时间范围
|
||||
chat_timestamps = [r.get("timestamp", "") for r in chat_records if r.get("timestamp")]
|
||||
if chat_timestamps:
|
||||
first_time = format_timestamp(min(chat_timestamps))
|
||||
last_time = format_timestamp(max(chat_timestamps))
|
||||
print(f" 时间范围: {first_time} ~ {last_time}")
|
||||
|
||||
# 该聊天的时间分布
|
||||
chat_time_dist = calculate_time_distribution(chat_records)
|
||||
print(" 时间分布:")
|
||||
for period, count in chat_time_dist.items():
|
||||
if count > 0:
|
||||
period_percentage = (count / chat_count) * 100
|
||||
print(f" {period}: {count} 次 ({period_percentage:.1f}%)")
|
||||
|
||||
# 显示该聊天最近的一条理由示例
|
||||
if chat_records:
|
||||
latest_record = chat_records[-1]
|
||||
reason = latest_record.get("reason", "无理由")
|
||||
if len(reason) > 120:
|
||||
reason = reason[:120] + "..."
|
||||
timestamp = format_timestamp(latest_record.get("timestamp", ""))
|
||||
think_level = latest_record.get("think_level", 0)
|
||||
print(f" 最新记录 [{timestamp}] (Level {think_level}): {reason}")
|
||||
|
||||
print()
|
||||
|
||||
# 时间分布
|
||||
time_dist = calculate_time_distribution(records)
|
||||
print("⏰ 时间分布:")
|
||||
for period, count in time_dist.items():
|
||||
if count > 0:
|
||||
percentage = (count / total_count) * 100
|
||||
print(f" {period}: {count} 次 ({percentage:.1f}%)")
|
||||
print()
|
||||
|
||||
# 显示一些示例理由
|
||||
print("📝 示例理由 (最近5条):")
|
||||
recent_records = records[-5:]
|
||||
for i, record in enumerate(recent_records, 1):
|
||||
reason = record.get("reason", "无理由")
|
||||
think_level = record.get("think_level", 0)
|
||||
timestamp = format_timestamp(record.get("timestamp", ""))
|
||||
chat_id = record.get("chat_id", "未知")
|
||||
chat_name = get_chat_name(chat_id)
|
||||
|
||||
# 截断过长的理由
|
||||
if len(reason) > 100:
|
||||
reason = reason[:100] + "..."
|
||||
|
||||
print(f" {i}. [{timestamp}] {chat_name} (Level {think_level})")
|
||||
print(f" {reason}")
|
||||
print()
|
||||
|
||||
# 按 think_level 分组显示理由示例
|
||||
print("=" * 80)
|
||||
print("按思考深度分类的示例理由")
|
||||
print("=" * 80)
|
||||
print()
|
||||
|
||||
for level in [0, 1, 2]:
|
||||
level_records = [r for r in records if r.get("think_level") == level]
|
||||
if not level_records:
|
||||
continue
|
||||
|
||||
level_name = {0: "不需要思考", 1: "简单思考", 2: "深度思考"}.get(level, f"未知({level})")
|
||||
print(f"Level {level} ({level_name}) - 共 {len(level_records)} 条:")
|
||||
|
||||
# 显示3个示例(选择最近的)
|
||||
examples = level_records[-3:] if len(level_records) >= 3 else level_records
|
||||
for i, record in enumerate(examples, 1):
|
||||
reason = record.get("reason", "无理由")
|
||||
if len(reason) > 150:
|
||||
reason = reason[:150] + "..."
|
||||
timestamp = format_timestamp(record.get("timestamp", ""))
|
||||
chat_id = record.get("chat_id", "未知")
|
||||
chat_name = get_chat_name(chat_id)
|
||||
print(f" {i}. [{timestamp}] {chat_name}")
|
||||
print(f" {reason}")
|
||||
print()
|
||||
|
||||
# 统计信息汇总
|
||||
print("=" * 80)
|
||||
print("统计汇总")
|
||||
print("=" * 80)
|
||||
print(f"总记录数: {total_count}")
|
||||
print(f"涉及聊天数: {len(chat_counter)}")
|
||||
if chat_counter:
|
||||
avg_count = total_count / len(chat_counter)
|
||||
print(f"平均每个聊天记录数: {avg_count:.1f}")
|
||||
else:
|
||||
print("平均每个聊天记录数: N/A")
|
||||
print()
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
records = load_records()
|
||||
print_statistics(records)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
122
scripts/test_lpmm_retrieval.py
Normal file
122
scripts/test_lpmm_retrieval.py
Normal file
@@ -0,0 +1,122 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
# 强制使用 utf-8,避免控制台编码报错影响 Embedding 加载
|
||||
try:
|
||||
if hasattr(sys.stdout, "reconfigure"):
|
||||
sys.stdout.reconfigure(encoding="utf-8")
|
||||
if hasattr(sys.stderr, "reconfigure"):
|
||||
sys.stderr.reconfigure(encoding="utf-8")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 确保能导入 src.*
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.knowledge import lpmm_start_up
|
||||
from src.memory_system.retrieval_tools.query_lpmm_knowledge import query_lpmm_knowledge
|
||||
|
||||
logger = get_logger("test_lpmm_retrieval")
|
||||
|
||||
|
||||
DEFAULT_TEST_CASES: List[Dict[str, Any]] = [
|
||||
{
|
||||
"name": "回滚一批知识",
|
||||
"query": "LPMM是什么?",
|
||||
"expect_keywords": ["哈希列表", "删除脚本", "OpenIE"],
|
||||
},
|
||||
{
|
||||
"name": "调整 LPMM 检索参数",
|
||||
"query": "不同用词习惯带来的检索偏差该如何解决",
|
||||
"expect_keywords": ["bot_config.toml", "lpmm_knowledge", "qa_paragraph_search_top_k"],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
async def run_tests(test_cases: Optional[List[Dict[str, Any]]] = None) -> None:
|
||||
"""简单测试 LPMM 知识库检索能力"""
|
||||
if not global_config.lpmm_knowledge.enable:
|
||||
logger.warning("当前配置中 lpmm_knowledge.enable 为 False,检索测试可能直接返回“未启用”。")
|
||||
|
||||
logger.info("开始初始化 LPMM 知识库...")
|
||||
lpmm_start_up()
|
||||
logger.info("LPMM 知识库初始化完成,开始执行测试用例。")
|
||||
|
||||
cases = test_cases if test_cases is not None else DEFAULT_TEST_CASES
|
||||
|
||||
for case in cases:
|
||||
name = case["name"]
|
||||
query = case["query"]
|
||||
expect_keywords: List[str] = case.get("expect_keywords", [])
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print(f"[TEST] {name}")
|
||||
print(f"[Q] {query}")
|
||||
|
||||
result = await query_lpmm_knowledge(query, limit=3)
|
||||
|
||||
print("\n[RAW RESULT]")
|
||||
print(result)
|
||||
|
||||
status = "UNKNOWN"
|
||||
hit_keywords: List[str] = []
|
||||
|
||||
if isinstance(result, str):
|
||||
if "未启用" in result or "未初始化" in result or "查询失败" in result:
|
||||
status = "ERROR"
|
||||
elif "未找到与" in result:
|
||||
status = "NO_HIT"
|
||||
else:
|
||||
if expect_keywords:
|
||||
hit_keywords = [kw for kw in expect_keywords if kw in result]
|
||||
status = "PASS" if hit_keywords else "WARN"
|
||||
else:
|
||||
status = "PASS"
|
||||
|
||||
print("\n[CHECK]")
|
||||
print(f"Status: {status}")
|
||||
if expect_keywords:
|
||||
print(f"Expected keywords: {expect_keywords}")
|
||||
print(f"Hit keywords: {hit_keywords}")
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("LPMM 检索测试完成。请根据每条用例的 Status 和命中关键词判断检索效果是否符合预期。")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description=(
|
||||
"测试 LPMM 知识库检索能力。\n"
|
||||
"如不提供参数,则执行内置的默认用例;\n"
|
||||
"也可以通过 --query 与 --expect-keyword 自定义一条测试用例。"
|
||||
)
|
||||
)
|
||||
parser.add_argument(
|
||||
"--query",
|
||||
help="自定义测试问题(单条)。提供该参数时,将仅运行这一条用例。",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--expect-keyword",
|
||||
action="append",
|
||||
help="期望在检索结果中出现的关键字,可重复多次指定;仅在提供 --query 时生效。",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.query:
|
||||
custom_case = {
|
||||
"name": "custom",
|
||||
"query": args.query,
|
||||
"expect_keywords": args.expect_keyword or [],
|
||||
}
|
||||
asyncio.run(run_tests([custom_case]))
|
||||
else:
|
||||
asyncio.run(run_tests())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
794
src/bw_learner/expression_learner.py
Normal file
794
src/bw_learner/expression_learner.py
Normal file
@@ -0,0 +1,794 @@
|
||||
import time
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import asyncio
|
||||
from typing import List, Optional, Tuple, Any, Dict
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Expression
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config, global_config
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
build_anonymous_messages,
|
||||
)
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.bw_learner.learner_utils import (
|
||||
filter_message_content,
|
||||
is_bot_message,
|
||||
build_context_paragraph,
|
||||
contains_bot_self_name,
|
||||
calculate_style_similarity,
|
||||
)
|
||||
from src.bw_learner.jargon_miner import miner_manager
|
||||
from json_repair import repair_json
|
||||
|
||||
|
||||
# MAX_EXPRESSION_COUNT = 300
|
||||
|
||||
logger = get_logger("expressor")
|
||||
|
||||
|
||||
def init_prompt() -> None:
|
||||
learn_style_prompt = """{chat_str}
|
||||
你的名字是{bot_name},现在请你完成两个提取任务
|
||||
任务1:请从上面这段群聊中用户的语言风格和说话方式
|
||||
1. 只考虑文字,不要考虑表情包和图片
|
||||
2. 不要总结SELF的发言,因为这是你自己的发言,不要重复学习你自己的发言
|
||||
3. 不要涉及具体的人名,也不要涉及具体名词
|
||||
4. 思考有没有特殊的梗,一并总结成语言风格
|
||||
5. 例子仅供参考,请严格根据群聊内容总结!!!
|
||||
注意:总结成如下格式的规律,总结的内容要详细,但具有概括性:
|
||||
例如:当"AAAAA"时,可以"BBBBB", AAAAA代表某个场景,不超过20个字。BBBBB代表对应的语言风格,特定句式或表达方式,不超过20个字。
|
||||
表达方式在3-5个左右,不要超过10个
|
||||
|
||||
|
||||
任务2:请从上面这段聊天内容中提取"可能是黑话"的候选项(黑话/俚语/网络缩写/口头禅)。
|
||||
- 必须为对话中真实出现过的短词或短语
|
||||
- 必须是你无法理解含义的词语,没有明确含义的词语,请不要选择有明确含义,或者含义清晰的词语
|
||||
- 排除:人名、@、表情包/图片中的内容、纯标点、常规功能词(如的、了、呢、啊等)
|
||||
- 每个词条长度建议 2-8 个字符(不强制),尽量短小
|
||||
- 请你提取出可能的黑话,最多30个黑话,请尽量提取所有
|
||||
|
||||
黑话必须为以下几种类型:
|
||||
- 由字母构成的,汉语拼音首字母的简写词,例如:nb、yyds、xswl
|
||||
- 英文词语的缩写,用英文字母概括一个词汇或含义,例如:CPU、GPU、API
|
||||
- 中文词语的缩写,用几个汉字概括一个词汇或含义,例如:社死、内卷
|
||||
|
||||
输出要求:
|
||||
将表达方式,语言风格和黑话以 JSON 数组输出,每个元素为一个对象,结构如下(注意字段名):
|
||||
注意请不要输出重复内容,请对表达方式和黑话进行去重。
|
||||
|
||||
[
|
||||
{{"situation": "AAAAA", "style": "BBBBB", "source_id": "3"}},
|
||||
{{"situation": "CCCC", "style": "DDDD", "source_id": "7"}}
|
||||
{{"situation": "对某件事表示十分惊叹", "style": "使用 我嘞个xxxx", "source_id": "[消息编号]"}},
|
||||
{{"situation": "表示讽刺的赞同,不讲道理", "style": "对对对", "source_id": "[消息编号]"}},
|
||||
{{"situation": "当涉及游戏相关时,夸赞,略带戏谑意味", "style": "使用 这么强!", "source_id": "[消息编号]"}},
|
||||
{{"content": "词条", "source_id": "12"}},
|
||||
{{"content": "词条2", "source_id": "5"}}
|
||||
]
|
||||
|
||||
其中:
|
||||
表达方式条目:
|
||||
- situation:表示“在什么情境下”的简短概括(不超过20个字)
|
||||
- style:表示对应的语言风格或常用表达(不超过20个字)
|
||||
- source_id:该表达方式对应的“来源行编号”,即上方聊天记录中方括号里的数字(例如 [3]),请只输出数字本身,不要包含方括号
|
||||
黑话jargon条目:
|
||||
- content:表示黑话的内容
|
||||
- source_id:该黑话对应的“来源行编号”,即上方聊天记录中方括号里的数字(例如 [3]),请只输出数字本身,不要包含方括号
|
||||
|
||||
现在请你输出 JSON:
|
||||
"""
|
||||
Prompt(learn_style_prompt, "learn_style_prompt")
|
||||
|
||||
|
||||
class ExpressionLearner:
|
||||
def __init__(self, chat_id: str) -> None:
|
||||
self.express_learn_model: LLMRequest = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils, request_type="expression.learner"
|
||||
)
|
||||
self.summary_model: LLMRequest = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils_small, request_type="expression.summary"
|
||||
)
|
||||
self.chat_id = chat_id
|
||||
self.chat_stream = get_chat_manager().get_stream(chat_id)
|
||||
self.chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id
|
||||
|
||||
# 学习锁,防止并发执行学习任务
|
||||
self._learning_lock = asyncio.Lock()
|
||||
|
||||
async def learn_and_store(
|
||||
self,
|
||||
messages: List[Any],
|
||||
) -> List[Tuple[str, str, str]]:
|
||||
"""
|
||||
学习并存储表达方式
|
||||
|
||||
Args:
|
||||
messages: 外部传入的消息列表(必需)
|
||||
num: 学习数量
|
||||
timestamp_start: 学习开始的时间戳,如果为None则使用self.last_learning_time
|
||||
"""
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
random_msg = messages
|
||||
|
||||
# 学习用(开启行编号,便于溯源)
|
||||
random_msg_str: str = await build_anonymous_messages(random_msg, show_ids=True)
|
||||
|
||||
prompt: str = await global_prompt_manager.format_prompt(
|
||||
"learn_style_prompt",
|
||||
bot_name=global_config.bot.nickname,
|
||||
chat_str=random_msg_str,
|
||||
)
|
||||
|
||||
# print(f"random_msg_str:{random_msg_str}")
|
||||
# logger.info(f"学习{type_str}的prompt: {prompt}")
|
||||
|
||||
try:
|
||||
response, _ = await self.express_learn_model.generate_response_async(prompt, temperature=0.3)
|
||||
except Exception as e:
|
||||
logger.error(f"学习表达方式失败,模型生成出错: {e}")
|
||||
return None
|
||||
|
||||
# 解析 LLM 返回的表达方式列表和黑话列表(包含来源行编号)
|
||||
expressions: List[Tuple[str, str, str]]
|
||||
jargon_entries: List[Tuple[str, str]] # (content, source_id)
|
||||
expressions, jargon_entries = self.parse_expression_response(response)
|
||||
expressions = self._filter_self_reference_styles(expressions)
|
||||
|
||||
# 检查表达方式数量,如果超过10个则放弃本次表达学习
|
||||
if len(expressions) > 10:
|
||||
logger.info(f"表达方式提取数量超过10个(实际{len(expressions)}个),放弃本次表达学习")
|
||||
expressions = []
|
||||
|
||||
# 检查黑话数量,如果超过30个则放弃本次黑话学习
|
||||
if len(jargon_entries) > 30:
|
||||
logger.info(f"黑话提取数量超过30个(实际{len(jargon_entries)}个),放弃本次黑话学习")
|
||||
jargon_entries = []
|
||||
|
||||
# 处理黑话条目,路由到 jargon_miner(即使没有表达方式也要处理黑话)
|
||||
if jargon_entries:
|
||||
await self._process_jargon_entries(jargon_entries, random_msg)
|
||||
|
||||
# 如果没有表达方式,直接返回
|
||||
if not expressions:
|
||||
logger.info("过滤后没有可用的表达方式(style 与机器人名称重复)")
|
||||
return []
|
||||
|
||||
logger.info(f"学习的prompt: {prompt}")
|
||||
logger.info(f"学习的expressions: {expressions}")
|
||||
logger.info(f"学习的jargon_entries: {jargon_entries}")
|
||||
logger.info(f"学习的response: {response}")
|
||||
|
||||
# 直接根据 source_id 在 random_msg 中溯源,获取 context
|
||||
filtered_expressions: List[Tuple[str, str, str]] = [] # (situation, style, context)
|
||||
|
||||
for situation, style, source_id in expressions:
|
||||
source_id_str = (source_id or "").strip()
|
||||
if not source_id_str.isdigit():
|
||||
# 无效的来源行编号,跳过
|
||||
continue
|
||||
|
||||
line_index = int(source_id_str) - 1 # build_anonymous_messages 的编号从 1 开始
|
||||
if line_index < 0 or line_index >= len(random_msg):
|
||||
# 超出范围,跳过
|
||||
continue
|
||||
|
||||
# 当前行的原始内容
|
||||
current_msg = random_msg[line_index]
|
||||
|
||||
# 过滤掉从bot自己发言中提取到的表达方式
|
||||
if is_bot_message(current_msg):
|
||||
continue
|
||||
|
||||
context = filter_message_content(current_msg.processed_plain_text or "")
|
||||
if not context:
|
||||
continue
|
||||
|
||||
# 过滤掉包含 SELF 的内容(不学习)
|
||||
if "SELF" in (situation or "") or "SELF" in (style or "") or "SELF" in context:
|
||||
logger.info(
|
||||
f"跳过包含 SELF 的表达方式: situation={situation}, style={style}, source_id={source_id}"
|
||||
)
|
||||
continue
|
||||
|
||||
filtered_expressions.append((situation, style, context))
|
||||
|
||||
learnt_expressions = filtered_expressions
|
||||
|
||||
if learnt_expressions is None:
|
||||
logger.info("没有学习到表达风格")
|
||||
return []
|
||||
|
||||
# 展示学到的表达方式
|
||||
learnt_expressions_str = ""
|
||||
for (
|
||||
situation,
|
||||
style,
|
||||
_context,
|
||||
) in learnt_expressions:
|
||||
learnt_expressions_str += f"{situation}->{style}\n"
|
||||
logger.info(f"在 {self.chat_name} 学习到表达风格:\n{learnt_expressions_str}")
|
||||
|
||||
current_time = time.time()
|
||||
|
||||
# 存储到数据库 Expression 表
|
||||
for (
|
||||
situation,
|
||||
style,
|
||||
context,
|
||||
) in learnt_expressions:
|
||||
await self._upsert_expression_record(
|
||||
situation=situation,
|
||||
style=style,
|
||||
context=context,
|
||||
current_time=current_time,
|
||||
)
|
||||
|
||||
return learnt_expressions
|
||||
|
||||
def parse_expression_response(self, response: str) -> Tuple[List[Tuple[str, str, str]], List[Tuple[str, str]]]:
|
||||
"""
|
||||
解析 LLM 返回的表达风格总结和黑话 JSON,提取两个列表。
|
||||
|
||||
期望的 JSON 结构:
|
||||
[
|
||||
{"situation": "AAAAA", "style": "BBBBB", "source_id": "3"}, // 表达方式
|
||||
{"content": "词条", "source_id": "12"}, // 黑话
|
||||
...
|
||||
]
|
||||
|
||||
Returns:
|
||||
Tuple[List[Tuple[str, str, str]], List[Tuple[str, str]]]:
|
||||
第一个列表是表达方式 (situation, style, source_id)
|
||||
第二个列表是黑话 (content, source_id)
|
||||
"""
|
||||
if not response:
|
||||
return [], []
|
||||
|
||||
raw = response.strip()
|
||||
|
||||
# 尝试提取 ```json 代码块
|
||||
json_block_pattern = r"```json\s*(.*?)\s*```"
|
||||
match = re.search(json_block_pattern, raw, re.DOTALL)
|
||||
if match:
|
||||
raw = match.group(1).strip()
|
||||
else:
|
||||
# 去掉可能存在的通用 ``` 包裹
|
||||
raw = re.sub(r"^```\s*", "", raw, flags=re.MULTILINE)
|
||||
raw = re.sub(r"```\s*$", "", raw, flags=re.MULTILINE)
|
||||
raw = raw.strip()
|
||||
|
||||
parsed = None
|
||||
expressions: List[Tuple[str, str, str]] = [] # (situation, style, source_id)
|
||||
jargon_entries: List[Tuple[str, str]] = [] # (content, source_id)
|
||||
|
||||
try:
|
||||
# 优先尝试直接解析
|
||||
if raw.startswith("[") and raw.endswith("]"):
|
||||
parsed = json.loads(raw)
|
||||
else:
|
||||
repaired = repair_json(raw)
|
||||
if isinstance(repaired, str):
|
||||
parsed = json.loads(repaired)
|
||||
else:
|
||||
parsed = repaired
|
||||
except Exception as parse_error:
|
||||
# 如果解析失败,尝试修复中文引号问题
|
||||
# 使用状态机方法,在 JSON 字符串值内部将中文引号替换为转义的英文引号
|
||||
try:
|
||||
|
||||
def fix_chinese_quotes_in_json(text):
|
||||
"""使用状态机修复 JSON 字符串值中的中文引号"""
|
||||
result = []
|
||||
i = 0
|
||||
in_string = False
|
||||
escape_next = False
|
||||
|
||||
while i < len(text):
|
||||
char = text[i]
|
||||
|
||||
if escape_next:
|
||||
# 当前字符是转义字符后的字符,直接添加
|
||||
result.append(char)
|
||||
escape_next = False
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if char == "\\":
|
||||
# 转义字符
|
||||
result.append(char)
|
||||
escape_next = True
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if char == '"' and not escape_next:
|
||||
# 遇到英文引号,切换字符串状态
|
||||
in_string = not in_string
|
||||
result.append(char)
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if in_string:
|
||||
# 在字符串值内部,将中文引号替换为转义的英文引号
|
||||
if char == '"': # 中文左引号 U+201C
|
||||
result.append('\\"')
|
||||
elif char == '"': # 中文右引号 U+201D
|
||||
result.append('\\"')
|
||||
else:
|
||||
result.append(char)
|
||||
else:
|
||||
# 不在字符串内,直接添加
|
||||
result.append(char)
|
||||
|
||||
i += 1
|
||||
|
||||
return "".join(result)
|
||||
|
||||
fixed_raw = fix_chinese_quotes_in_json(raw)
|
||||
|
||||
# 再次尝试解析
|
||||
if fixed_raw.startswith("[") and fixed_raw.endswith("]"):
|
||||
parsed = json.loads(fixed_raw)
|
||||
else:
|
||||
repaired = repair_json(fixed_raw)
|
||||
if isinstance(repaired, str):
|
||||
parsed = json.loads(repaired)
|
||||
else:
|
||||
parsed = repaired
|
||||
except Exception as fix_error:
|
||||
logger.error(f"解析表达风格 JSON 失败,初始错误: {type(parse_error).__name__}: {str(parse_error)}")
|
||||
logger.error(f"修复中文引号后仍失败,错误: {type(fix_error).__name__}: {str(fix_error)}")
|
||||
logger.error(f"解析表达风格 JSON 失败,原始响应:{response}")
|
||||
logger.error(f"处理后的 JSON 字符串(前500字符):{raw[:500]}")
|
||||
return [], []
|
||||
|
||||
if isinstance(parsed, dict):
|
||||
parsed_list = [parsed]
|
||||
elif isinstance(parsed, list):
|
||||
parsed_list = parsed
|
||||
else:
|
||||
logger.error(f"表达风格解析结果类型异常: {type(parsed)}, 内容: {parsed}")
|
||||
return [], []
|
||||
|
||||
for item in parsed_list:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
|
||||
# 检查是否是表达方式条目(有 situation 和 style)
|
||||
situation = str(item.get("situation", "")).strip()
|
||||
style = str(item.get("style", "")).strip()
|
||||
source_id = str(item.get("source_id", "")).strip()
|
||||
|
||||
if situation and style and source_id:
|
||||
# 表达方式条目
|
||||
expressions.append((situation, style, source_id))
|
||||
elif item.get("content"):
|
||||
# 黑话条目(有 content 字段)
|
||||
content = str(item.get("content", "")).strip()
|
||||
source_id = str(item.get("source_id", "")).strip()
|
||||
if content and source_id:
|
||||
jargon_entries.append((content, source_id))
|
||||
|
||||
return expressions, jargon_entries
|
||||
|
||||
def _filter_self_reference_styles(self, expressions: List[Tuple[str, str, str]]) -> List[Tuple[str, str, str]]:
|
||||
"""
|
||||
过滤掉style与机器人名称/昵称重复的表达
|
||||
"""
|
||||
banned_names = set()
|
||||
bot_nickname = (global_config.bot.nickname or "").strip()
|
||||
if bot_nickname:
|
||||
banned_names.add(bot_nickname)
|
||||
|
||||
alias_names = global_config.bot.alias_names or []
|
||||
for alias in alias_names:
|
||||
alias = alias.strip()
|
||||
if alias:
|
||||
banned_names.add(alias)
|
||||
|
||||
banned_casefold = {name.casefold() for name in banned_names if name}
|
||||
|
||||
filtered: List[Tuple[str, str, str]] = []
|
||||
removed_count = 0
|
||||
for situation, style, source_id in expressions:
|
||||
normalized_style = (style or "").strip()
|
||||
if normalized_style and normalized_style.casefold() not in banned_casefold:
|
||||
filtered.append((situation, style, source_id))
|
||||
else:
|
||||
removed_count += 1
|
||||
|
||||
if removed_count:
|
||||
logger.debug(f"已过滤 {removed_count} 条style与机器人名称重复的表达方式")
|
||||
|
||||
return filtered
|
||||
|
||||
async def _upsert_expression_record(
|
||||
self,
|
||||
situation: str,
|
||||
style: str,
|
||||
context: str,
|
||||
current_time: float,
|
||||
) -> None:
|
||||
# 第一层:检查是否有完全一致的 style(检查 style 字段和 style_list)
|
||||
expr_obj = await self._find_exact_style_match(style)
|
||||
|
||||
if expr_obj:
|
||||
# 找到完全匹配的 style,合并到现有记录(不使用 LLM 总结)
|
||||
await self._update_existing_expression(
|
||||
expr_obj=expr_obj,
|
||||
situation=situation,
|
||||
style=style,
|
||||
context=context,
|
||||
current_time=current_time,
|
||||
use_llm_summary=False,
|
||||
)
|
||||
return
|
||||
|
||||
# 第二层:检查是否有相似的 style(相似度 >= 0.75,检查 style 字段和 style_list)
|
||||
similar_expr_obj = await self._find_similar_style_expression(style, similarity_threshold=0.75)
|
||||
|
||||
if similar_expr_obj:
|
||||
# 找到相似的 style,合并到现有记录(使用 LLM 总结)
|
||||
await self._update_existing_expression(
|
||||
expr_obj=similar_expr_obj,
|
||||
situation=situation,
|
||||
style=style,
|
||||
context=context,
|
||||
current_time=current_time,
|
||||
use_llm_summary=True,
|
||||
)
|
||||
return
|
||||
|
||||
# 没有找到匹配的记录,创建新记录
|
||||
await self._create_expression_record(
|
||||
situation=situation,
|
||||
style=style,
|
||||
context=context,
|
||||
current_time=current_time,
|
||||
)
|
||||
|
||||
async def _create_expression_record(
|
||||
self,
|
||||
situation: str,
|
||||
style: str,
|
||||
context: str,
|
||||
current_time: float,
|
||||
) -> None:
|
||||
content_list = [situation]
|
||||
# 创建新记录时,直接使用原始的 situation,不进行总结
|
||||
formatted_situation = situation
|
||||
|
||||
Expression.create(
|
||||
situation=formatted_situation,
|
||||
style=style,
|
||||
content_list=json.dumps(content_list, ensure_ascii=False),
|
||||
style_list=None, # 新记录初始时 style_list 为空
|
||||
count=1,
|
||||
last_active_time=current_time,
|
||||
chat_id=self.chat_id,
|
||||
create_date=current_time,
|
||||
context=context,
|
||||
)
|
||||
|
||||
async def _update_existing_expression(
|
||||
self,
|
||||
expr_obj: Expression,
|
||||
situation: str,
|
||||
style: str,
|
||||
context: str,
|
||||
current_time: float,
|
||||
use_llm_summary: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
更新现有 Expression 记录(style 完全匹配或相似的情况)
|
||||
将新的 situation 添加到 content_list,将新的 style 添加到 style_list(如果不同)
|
||||
|
||||
Args:
|
||||
use_llm_summary: 是否使用 LLM 进行总结,完全匹配时为 False,相似匹配时为 True
|
||||
"""
|
||||
# 更新 content_list(添加新的 situation)
|
||||
content_list = self._parse_content_list(expr_obj.content_list)
|
||||
content_list.append(situation)
|
||||
expr_obj.content_list = json.dumps(content_list, ensure_ascii=False)
|
||||
|
||||
# 更新 style_list(如果 style 不同,添加到 style_list)
|
||||
style_list = self._parse_style_list(expr_obj.style_list)
|
||||
# 将原有的 style 也加入 style_list(如果还没有的话)
|
||||
if expr_obj.style and expr_obj.style not in style_list:
|
||||
style_list.append(expr_obj.style)
|
||||
# 如果新的 style 不在 style_list 中,添加它
|
||||
if style not in style_list:
|
||||
style_list.append(style)
|
||||
expr_obj.style_list = json.dumps(style_list, ensure_ascii=False)
|
||||
|
||||
# 更新其他字段
|
||||
expr_obj.count = (expr_obj.count or 0) + 1
|
||||
expr_obj.last_active_time = current_time
|
||||
expr_obj.context = context
|
||||
|
||||
if use_llm_summary:
|
||||
# 相似匹配时,使用 LLM 重新组合 situation 和 style
|
||||
new_situation = await self._compose_situation_text(
|
||||
content_list=content_list,
|
||||
count=expr_obj.count,
|
||||
fallback=expr_obj.situation,
|
||||
)
|
||||
expr_obj.situation = new_situation
|
||||
|
||||
new_style = await self._compose_style_text(
|
||||
style_list=style_list,
|
||||
count=expr_obj.count,
|
||||
fallback=expr_obj.style or style,
|
||||
)
|
||||
expr_obj.style = new_style
|
||||
else:
|
||||
# 完全匹配时,不进行 LLM 总结,保持原有的 situation 和 style 不变
|
||||
# 只更新 content_list 和 style_list
|
||||
pass
|
||||
|
||||
expr_obj.save()
|
||||
|
||||
def _parse_content_list(self, stored_list: Optional[str]) -> List[str]:
|
||||
if not stored_list:
|
||||
return []
|
||||
try:
|
||||
data = json.loads(stored_list)
|
||||
except json.JSONDecodeError:
|
||||
return []
|
||||
return [str(item) for item in data if isinstance(item, str)] if isinstance(data, list) else []
|
||||
|
||||
def _parse_style_list(self, stored_list: Optional[str]) -> List[str]:
|
||||
"""解析 style_list JSON 字符串为列表,逻辑与 _parse_content_list 相同"""
|
||||
if not stored_list:
|
||||
return []
|
||||
try:
|
||||
data = json.loads(stored_list)
|
||||
except json.JSONDecodeError:
|
||||
return []
|
||||
return [str(item) for item in data if isinstance(item, str)] if isinstance(data, list) else []
|
||||
|
||||
async def _find_exact_style_match(self, style: str) -> Optional[Expression]:
|
||||
"""
|
||||
查找具有完全匹配 style 的 Expression 记录
|
||||
只检查 style_list 中的每一项(不检查 style 字段,因为 style 可能是总结后的概括性描述)
|
||||
|
||||
Args:
|
||||
style: 要查找的 style
|
||||
|
||||
Returns:
|
||||
找到的 Expression 对象,如果没有找到则返回 None
|
||||
"""
|
||||
# 查询同一 chat_id 的所有记录
|
||||
all_expressions = Expression.select().where(Expression.chat_id == self.chat_id)
|
||||
|
||||
for expr in all_expressions:
|
||||
# 只检查 style_list 中的每一项
|
||||
style_list = self._parse_style_list(expr.style_list)
|
||||
if style in style_list:
|
||||
return expr
|
||||
|
||||
return None
|
||||
|
||||
async def _find_similar_style_expression(self, style: str, similarity_threshold: float = 0.75) -> Optional[Expression]:
|
||||
"""
|
||||
查找具有相似 style 的 Expression 记录
|
||||
只检查 style_list 中的每一项(不检查 style 字段,因为 style 可能是总结后的概括性描述)
|
||||
|
||||
Args:
|
||||
style: 要查找的 style
|
||||
similarity_threshold: 相似度阈值,默认 0.75
|
||||
|
||||
Returns:
|
||||
找到的最相似的 Expression 对象,如果没有找到则返回 None
|
||||
"""
|
||||
# 查询同一 chat_id 的所有记录
|
||||
all_expressions = Expression.select().where(Expression.chat_id == self.chat_id)
|
||||
|
||||
best_match = None
|
||||
best_similarity = 0.0
|
||||
|
||||
for expr in all_expressions:
|
||||
# 只检查 style_list 中的每一项
|
||||
style_list = self._parse_style_list(expr.style_list)
|
||||
for existing_style in style_list:
|
||||
similarity = calculate_style_similarity(style, existing_style)
|
||||
if similarity >= similarity_threshold and similarity > best_similarity:
|
||||
best_similarity = similarity
|
||||
best_match = expr
|
||||
|
||||
if best_match:
|
||||
logger.debug(f"找到相似的 style: 相似度={best_similarity:.3f}, 现有='{best_match.style}', 新='{style}'")
|
||||
|
||||
return best_match
|
||||
|
||||
async def _compose_situation_text(self, content_list: List[str], count: int, fallback: str = "") -> str:
|
||||
sanitized = [c.strip() for c in content_list if c.strip()]
|
||||
summary = await self._summarize_situations(sanitized)
|
||||
if summary:
|
||||
return summary
|
||||
return "/".join(sanitized) if sanitized else fallback
|
||||
|
||||
async def _compose_style_text(self, style_list: List[str], count: int, fallback: str = "") -> str:
|
||||
"""
|
||||
组合 style 文本,如果 style_list 有多个元素则尝试总结
|
||||
"""
|
||||
sanitized = [s.strip() for s in style_list if s.strip()]
|
||||
if len(sanitized) > 1:
|
||||
# 只有当有多个 style 时才尝试总结
|
||||
summary = await self._summarize_styles(sanitized)
|
||||
if summary:
|
||||
return summary
|
||||
# 如果只有一个或总结失败,返回第一个或 fallback
|
||||
return sanitized[0] if sanitized else fallback
|
||||
|
||||
async def _summarize_styles(self, styles: List[str]) -> Optional[str]:
|
||||
"""总结多个 style,生成一个概括性的 style 描述"""
|
||||
if not styles or len(styles) <= 1:
|
||||
return None
|
||||
|
||||
# 计算输入列表中最长项目的长度
|
||||
max_input_length = max(len(s) for s in styles) if styles else 0
|
||||
max_summary_length = max_input_length * 2
|
||||
|
||||
# 最多重试3次
|
||||
max_retries = 3
|
||||
retry_count = 0
|
||||
|
||||
while retry_count < max_retries:
|
||||
# 如果是重试,在 prompt 中强调要更简洁
|
||||
length_hint = f"长度不超过{max_summary_length}个字符," if retry_count > 0 else "长度不超过20个字,"
|
||||
|
||||
prompt = (
|
||||
"请阅读以下多个语言风格/表达方式,对其进行总结。"
|
||||
"不要对其进行语义概括,而是尽可能找出其中不变的部分或共同表达,尽量使用原文"
|
||||
f"{length_hint}保留共同特点:\n"
|
||||
f"{chr(10).join(f'- {s}' for s in styles[-10:])}\n只输出概括内容。不要输出其他内容"
|
||||
)
|
||||
|
||||
try:
|
||||
summary, _ = await self.summary_model.generate_response_async(prompt, temperature=0.2)
|
||||
summary = summary.strip()
|
||||
if summary:
|
||||
# 检查总结长度是否超过限制
|
||||
if len(summary) <= max_summary_length:
|
||||
return summary
|
||||
else:
|
||||
retry_count += 1
|
||||
logger.debug(
|
||||
f"总结长度 {len(summary)} 超过限制 {max_summary_length} "
|
||||
f"(输入最长项长度: {max_input_length}),重试第 {retry_count} 次"
|
||||
)
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"概括表达风格失败: {e}")
|
||||
return None
|
||||
|
||||
# 如果重试多次后仍然超过长度,返回 None(不进行总结)
|
||||
logger.warning(
|
||||
f"总结多次后仍超过长度限制,放弃总结。"
|
||||
f"输入最长项长度: {max_input_length}, 最大允许长度: {max_summary_length}"
|
||||
)
|
||||
return None
|
||||
|
||||
async def _summarize_situations(self, situations: List[str]) -> Optional[str]:
|
||||
if not situations:
|
||||
return None
|
||||
|
||||
prompt = (
|
||||
"请阅读以下多个聊天情境描述,并将它们概括成一句简短的话,"
|
||||
"长度不超过20个字,保留共同特点:\n"
|
||||
f"{chr(10).join(f'- {s}' for s in situations[-10:])}\n只输出概括内容。"
|
||||
)
|
||||
|
||||
try:
|
||||
summary, _ = await self.summary_model.generate_response_async(prompt, temperature=0.2)
|
||||
summary = summary.strip()
|
||||
if summary:
|
||||
return summary
|
||||
except Exception as e:
|
||||
logger.error(f"概括表达情境失败: {e}")
|
||||
return None
|
||||
|
||||
async def _process_jargon_entries(self, jargon_entries: List[Tuple[str, str]], messages: List[Any]) -> None:
|
||||
"""
|
||||
处理从 expression learner 提取的黑话条目,路由到 jargon_miner
|
||||
|
||||
Args:
|
||||
jargon_entries: 黑话条目列表,每个元素是 (content, source_id)
|
||||
messages: 消息列表,用于构建上下文
|
||||
"""
|
||||
if not jargon_entries or not messages:
|
||||
return
|
||||
|
||||
# 获取 jargon_miner 实例
|
||||
jargon_miner = miner_manager.get_miner(self.chat_id)
|
||||
|
||||
# 构建黑话条目格式,与 jargon_miner.run_once 中的格式一致
|
||||
entries: List[Dict[str, List[str]]] = []
|
||||
|
||||
for content, source_id in jargon_entries:
|
||||
content = content.strip()
|
||||
if not content:
|
||||
continue
|
||||
|
||||
# 过滤掉包含 SELF 的黑话,不学习
|
||||
if "SELF" in content:
|
||||
logger.info(f"跳过包含 SELF 的黑话: {content}")
|
||||
continue
|
||||
|
||||
# 检查是否包含机器人名称
|
||||
if contains_bot_self_name(content):
|
||||
logger.info(f"跳过包含机器人昵称/别名的黑话: {content}")
|
||||
continue
|
||||
|
||||
# 解析 source_id
|
||||
source_id_str = (source_id or "").strip()
|
||||
if not source_id_str.isdigit():
|
||||
logger.warning(f"黑话条目 source_id 无效: content={content}, source_id={source_id_str}")
|
||||
continue
|
||||
|
||||
# build_anonymous_messages 的编号从 1 开始
|
||||
line_index = int(source_id_str) - 1
|
||||
if line_index < 0 or line_index >= len(messages):
|
||||
logger.warning(f"黑话条目 source_id 超出范围: content={content}, source_id={source_id_str}")
|
||||
continue
|
||||
|
||||
# 检查是否是机器人自己的消息
|
||||
target_msg = messages[line_index]
|
||||
if is_bot_message(target_msg):
|
||||
logger.info(f"跳过引用机器人自身消息的黑话: content={content}, source_id={source_id_str}")
|
||||
continue
|
||||
|
||||
# 构建上下文段落
|
||||
context_paragraph = build_context_paragraph(messages, line_index)
|
||||
if not context_paragraph:
|
||||
logger.warning(f"黑话条目上下文为空: content={content}, source_id={source_id_str}")
|
||||
continue
|
||||
|
||||
entries.append({"content": content, "raw_content": [context_paragraph]})
|
||||
|
||||
if not entries:
|
||||
return
|
||||
|
||||
# 调用 jargon_miner 处理这些条目
|
||||
await jargon_miner.process_extracted_entries(entries)
|
||||
|
||||
|
||||
init_prompt()
|
||||
|
||||
|
||||
class ExpressionLearnerManager:
|
||||
def __init__(self):
|
||||
self.expression_learners = {}
|
||||
|
||||
self._ensure_expression_directories()
|
||||
|
||||
def get_expression_learner(self, chat_id: str) -> ExpressionLearner:
|
||||
if chat_id not in self.expression_learners:
|
||||
self.expression_learners[chat_id] = ExpressionLearner(chat_id)
|
||||
return self.expression_learners[chat_id]
|
||||
|
||||
def _ensure_expression_directories(self):
|
||||
"""
|
||||
确保表达方式相关的目录结构存在
|
||||
"""
|
||||
base_dir = os.path.join("data", "expression")
|
||||
directories_to_create = [
|
||||
base_dir,
|
||||
os.path.join(base_dir, "learnt_style"),
|
||||
os.path.join(base_dir, "learnt_grammar"),
|
||||
]
|
||||
|
||||
for directory in directories_to_create:
|
||||
try:
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
logger.debug(f"确保目录存在: {directory}")
|
||||
except Exception as e:
|
||||
logger.error(f"创建目录失败 {directory}: {e}")
|
||||
|
||||
|
||||
expression_learner_manager = ExpressionLearnerManager()
|
||||
@@ -82,9 +82,7 @@ class ExpressionReflector:
|
||||
# 获取未检查的表达
|
||||
try:
|
||||
logger.info("[Expression Reflection] 查询未检查且未拒绝的表达")
|
||||
expressions = (
|
||||
Expression.select().where((~Expression.checked) & (~Expression.rejected)).limit(50)
|
||||
)
|
||||
expressions = Expression.select().where((~Expression.checked) & (~Expression.rejected)).limit(50)
|
||||
|
||||
expr_list = list(expressions)
|
||||
logger.info(f"[Expression Reflection] 找到 {len(expr_list)} 个候选表达")
|
||||
@@ -147,7 +145,7 @@ expression_reflector_manager = ExpressionReflectorManager()
|
||||
|
||||
async def _check_tracker_exists(operator_config: str) -> bool:
|
||||
"""检查指定 Operator 是否已有活跃的 Tracker"""
|
||||
from src.express.reflect_tracker import reflect_tracker_manager
|
||||
from src.bw_learner.reflect_tracker import reflect_tracker_manager
|
||||
|
||||
chat_manager = get_chat_manager()
|
||||
chat_stream = None
|
||||
@@ -242,7 +240,7 @@ async def _send_to_operator(operator_config: str, text: str, expr: Expression):
|
||||
stream_id = chat_stream.stream_id
|
||||
|
||||
# 注册 Tracker
|
||||
from src.express.reflect_tracker import ReflectTracker, reflect_tracker_manager
|
||||
from src.bw_learner.reflect_tracker import ReflectTracker, reflect_tracker_manager
|
||||
|
||||
tracker = ReflectTracker(chat_stream=chat_stream, expression=expr, created_time=time.time())
|
||||
reflect_tracker_manager.add_tracker(stream_id, tracker)
|
||||
@@ -1,6 +1,5 @@
|
||||
import json
|
||||
import time
|
||||
import hashlib
|
||||
|
||||
from typing import List, Dict, Optional, Any, Tuple
|
||||
from json_repair import repair_json
|
||||
@@ -10,7 +9,8 @@ from src.config.config import global_config, model_config
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Expression
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.express.express_utils import weighted_sample
|
||||
from src.bw_learner.learner_utils import weighted_sample
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
|
||||
logger = get_logger("expression_selector")
|
||||
|
||||
@@ -67,7 +67,7 @@ class ExpressionSelector:
|
||||
|
||||
@staticmethod
|
||||
def _parse_stream_config_to_chat_id(stream_config_str: str) -> Optional[str]:
|
||||
"""解析'platform:id:type'为chat_id(与get_stream_id一致)"""
|
||||
"""解析'platform:id:type'为chat_id,直接使用 ChatManager 提供的接口"""
|
||||
try:
|
||||
parts = stream_config_str.split(":")
|
||||
if len(parts) != 3:
|
||||
@@ -76,12 +76,8 @@ class ExpressionSelector:
|
||||
id_str = parts[1]
|
||||
stream_type = parts[2]
|
||||
is_group = stream_type == "group"
|
||||
if is_group:
|
||||
components = [platform, str(id_str)]
|
||||
else:
|
||||
components = [platform, str(id_str), "private"]
|
||||
key = "_".join(components)
|
||||
return hashlib.md5(key.encode()).hexdigest()
|
||||
# 统一通过 chat_manager 生成 stream_id,避免各处自行实现哈希逻辑
|
||||
return get_chat_manager().get_stream_id(platform, str(id_str), is_group=is_group)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@@ -111,6 +107,85 @@ class ExpressionSelector:
|
||||
return group_chat_ids
|
||||
return [chat_id]
|
||||
|
||||
def _select_expressions_simple(self, chat_id: str, max_num: int) -> Tuple[List[Dict[str, Any]], List[int]]:
|
||||
"""
|
||||
简单模式:只选择 count > 1 的项目,要求至少有10个才进行选择,随机选5个,不进行LLM选择
|
||||
|
||||
Args:
|
||||
chat_id: 聊天流ID
|
||||
max_num: 最大选择数量(此参数在此模式下不使用,固定选择5个)
|
||||
|
||||
Returns:
|
||||
Tuple[List[Dict[str, Any]], List[int]]: 选中的表达方式列表和ID列表
|
||||
"""
|
||||
try:
|
||||
# 支持多chat_id合并抽选
|
||||
related_chat_ids = self.get_related_chat_ids(chat_id)
|
||||
|
||||
# 查询所有相关chat_id的表达方式,排除 rejected=1 的,且只选择 count > 1 的
|
||||
style_query = Expression.select().where(
|
||||
(Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected) & (Expression.count > 1)
|
||||
)
|
||||
|
||||
style_exprs = [
|
||||
{
|
||||
"id": expr.id,
|
||||
"situation": expr.situation,
|
||||
"style": expr.style,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"source_id": expr.chat_id,
|
||||
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
|
||||
"count": expr.count if getattr(expr, "count", None) is not None else 1,
|
||||
"checked": expr.checked if getattr(expr, "checked", None) is not None else False,
|
||||
}
|
||||
for expr in style_query
|
||||
]
|
||||
|
||||
# 要求至少有一定数量的 count > 1 的表达方式才进行“完整简单模式”选择
|
||||
min_required = 8
|
||||
if len(style_exprs) < min_required:
|
||||
# 高 count 样本不足:如果还有候选,就降级为随机选 3 个;如果一个都没有,则直接返回空
|
||||
if not style_exprs:
|
||||
logger.info(
|
||||
f"聊天流 {chat_id} 没有满足 count > 1 且未被拒绝的表达方式,简单模式不进行选择"
|
||||
)
|
||||
# 完全没有高 count 样本时,退化为全量随机抽样(不进入LLM流程)
|
||||
fallback_num = min(3, max_num) if max_num > 0 else 3
|
||||
fallback_selected = self._random_expressions(chat_id, fallback_num)
|
||||
if fallback_selected:
|
||||
self.update_expressions_last_active_time(fallback_selected)
|
||||
selected_ids = [expr["id"] for expr in fallback_selected]
|
||||
logger.info(
|
||||
f"聊天流 {chat_id} 使用简单模式降级随机抽选 {len(fallback_selected)} 个表达(无 count>1 样本)"
|
||||
)
|
||||
return fallback_selected, selected_ids
|
||||
return [], []
|
||||
logger.info(
|
||||
f"聊天流 {chat_id} count > 1 的表达方式不足 {min_required} 个(实际 {len(style_exprs)} 个),"
|
||||
f"简单模式降级为随机选择 3 个"
|
||||
)
|
||||
select_count = min(3, len(style_exprs))
|
||||
else:
|
||||
# 高 count 数量达标时,固定选择 5 个
|
||||
select_count = 5
|
||||
import random
|
||||
|
||||
selected_style = random.sample(style_exprs, select_count)
|
||||
|
||||
# 更新last_active_time
|
||||
if selected_style:
|
||||
self.update_expressions_last_active_time(selected_style)
|
||||
|
||||
selected_ids = [expr["id"] for expr in selected_style]
|
||||
logger.debug(
|
||||
f"think_level=0: 从 {len(style_exprs)} 个 count>1 的表达方式中随机选择了 {len(selected_style)} 个"
|
||||
)
|
||||
return selected_style, selected_ids
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"简单模式选择表达方式失败: {e}")
|
||||
return [], []
|
||||
|
||||
def _random_expressions(self, chat_id: str, total_num: int) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
随机选择表达方式
|
||||
@@ -127,9 +202,7 @@ class ExpressionSelector:
|
||||
related_chat_ids = self.get_related_chat_ids(chat_id)
|
||||
|
||||
# 优化:一次性查询所有相关chat_id的表达方式,排除 rejected=1 的表达
|
||||
style_query = Expression.select().where(
|
||||
(Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected)
|
||||
)
|
||||
style_query = Expression.select().where((Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected))
|
||||
|
||||
style_exprs = [
|
||||
{
|
||||
@@ -164,6 +237,7 @@ class ExpressionSelector:
|
||||
max_num: int = 10,
|
||||
target_message: Optional[str] = None,
|
||||
reply_reason: Optional[str] = None,
|
||||
think_level: int = 1,
|
||||
) -> Tuple[List[Dict[str, Any]], List[int]]:
|
||||
"""
|
||||
选择适合的表达方式(使用classic模式:随机选择+LLM选择)
|
||||
@@ -174,6 +248,7 @@ class ExpressionSelector:
|
||||
max_num: 最大选择数量
|
||||
target_message: 目标消息内容
|
||||
reply_reason: planner给出的回复理由
|
||||
think_level: 思考级别,0/1
|
||||
|
||||
Returns:
|
||||
Tuple[List[Dict[str, Any]], List[int]]: 选中的表达方式列表和ID列表
|
||||
@@ -184,8 +259,10 @@ class ExpressionSelector:
|
||||
return [], []
|
||||
|
||||
# 使用classic模式(随机选择+LLM选择)
|
||||
logger.debug(f"使用classic模式为聊天流 {chat_id} 选择表达方式")
|
||||
return await self._select_expressions_classic(chat_id, chat_info, max_num, target_message, reply_reason)
|
||||
logger.debug(f"使用classic模式为聊天流 {chat_id} 选择表达方式,think_level={think_level}")
|
||||
return await self._select_expressions_classic(
|
||||
chat_id, chat_info, max_num, target_message, reply_reason, think_level
|
||||
)
|
||||
|
||||
async def _select_expressions_classic(
|
||||
self,
|
||||
@@ -194,6 +271,7 @@ class ExpressionSelector:
|
||||
max_num: int = 10,
|
||||
target_message: Optional[str] = None,
|
||||
reply_reason: Optional[str] = None,
|
||||
think_level: int = 1,
|
||||
) -> Tuple[List[Dict[str, Any]], List[int]]:
|
||||
"""
|
||||
classic模式:随机选择+LLM选择
|
||||
@@ -204,24 +282,91 @@ class ExpressionSelector:
|
||||
max_num: 最大选择数量
|
||||
target_message: 目标消息内容
|
||||
reply_reason: planner给出的回复理由
|
||||
think_level: 思考级别,0/1
|
||||
|
||||
Returns:
|
||||
Tuple[List[Dict[str, Any]], List[int]]: 选中的表达方式列表和ID列表
|
||||
"""
|
||||
try:
|
||||
# 1. 使用随机抽样选择表达方式
|
||||
style_exprs = self._random_expressions(chat_id, 20)
|
||||
# think_level == 0: 只选择 count > 1 的项目,随机选10个,不进行LLM选择
|
||||
if think_level == 0:
|
||||
return self._select_expressions_simple(chat_id, max_num)
|
||||
|
||||
if len(style_exprs) < 10:
|
||||
logger.info(f"聊天流 {chat_id} 表达方式正在积累中")
|
||||
# think_level == 1: 先选高count,再从所有表达方式中随机抽样
|
||||
# 1. 获取所有表达方式并分离 count > 1 和 count <= 1 的
|
||||
related_chat_ids = self.get_related_chat_ids(chat_id)
|
||||
style_query = Expression.select().where((Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected))
|
||||
|
||||
all_style_exprs = [
|
||||
{
|
||||
"id": expr.id,
|
||||
"situation": expr.situation,
|
||||
"style": expr.style,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"source_id": expr.chat_id,
|
||||
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
|
||||
"count": expr.count if getattr(expr, "count", None) is not None else 1,
|
||||
"checked": expr.checked if getattr(expr, "checked", None) is not None else False,
|
||||
}
|
||||
for expr in style_query
|
||||
]
|
||||
|
||||
# 分离 count > 1 和 count <= 1 的表达方式
|
||||
high_count_exprs = [expr for expr in all_style_exprs if (expr.get("count", 1) or 1) > 1]
|
||||
|
||||
# 根据 think_level 设置要求(仅支持 0/1,0 已在上方返回)
|
||||
min_high_count = 10
|
||||
min_total_count = 10
|
||||
select_high_count = 5
|
||||
select_random_count = 5
|
||||
|
||||
# 检查数量要求
|
||||
# 对于高 count 表达:如果数量不足,不再直接停止,而是仅跳过“高 count 优先选择”
|
||||
if len(high_count_exprs) < min_high_count:
|
||||
logger.info(
|
||||
f"聊天流 {chat_id} count > 1 的表达方式不足 {min_high_count} 个(实际 {len(high_count_exprs)} 个),"
|
||||
f"将跳过高 count 优先选择,仅从全部表达中随机抽样"
|
||||
)
|
||||
high_count_valid = False
|
||||
else:
|
||||
high_count_valid = True
|
||||
|
||||
# 总量不足仍然直接返回,避免样本过少导致选择质量过低
|
||||
if len(all_style_exprs) < min_total_count:
|
||||
logger.info(
|
||||
f"聊天流 {chat_id} 总表达方式不足 {min_total_count} 个(实际 {len(all_style_exprs)} 个),不进行选择"
|
||||
)
|
||||
return [], []
|
||||
|
||||
# 先选取高count的表达方式(如果数量达标)
|
||||
if high_count_valid:
|
||||
selected_high = weighted_sample(high_count_exprs, min(len(high_count_exprs), select_high_count))
|
||||
else:
|
||||
selected_high = []
|
||||
|
||||
# 然后从所有表达方式中随机抽样(使用加权抽样)
|
||||
remaining_num = select_random_count
|
||||
selected_random = weighted_sample(all_style_exprs, min(len(all_style_exprs), remaining_num))
|
||||
|
||||
# 合并候选池(去重,避免重复)
|
||||
candidate_exprs = selected_high.copy()
|
||||
candidate_ids = {expr["id"] for expr in candidate_exprs}
|
||||
for expr in selected_random:
|
||||
if expr["id"] not in candidate_ids:
|
||||
candidate_exprs.append(expr)
|
||||
candidate_ids.add(expr["id"])
|
||||
|
||||
# 打乱顺序,避免高count的都在前面
|
||||
import random
|
||||
|
||||
random.shuffle(candidate_exprs)
|
||||
|
||||
# 2. 构建所有表达方式的索引和情境列表
|
||||
all_expressions: List[Dict[str, Any]] = []
|
||||
all_situations: List[str] = []
|
||||
|
||||
# 添加style表达方式
|
||||
for expr in style_exprs:
|
||||
for expr in candidate_exprs:
|
||||
expr = expr.copy()
|
||||
all_expressions.append(expr)
|
||||
all_situations.append(f"{len(all_expressions)}.当 {expr['situation']} 时,使用 {expr['style']}")
|
||||
@@ -233,7 +378,7 @@ class ExpressionSelector:
|
||||
all_situations_str = "\n".join(all_situations)
|
||||
|
||||
if target_message:
|
||||
target_message_str = f",现在你想要对这条消息进行回复:“{target_message}”"
|
||||
target_message_str = f',现在你想要对这条消息进行回复:"{target_message}"'
|
||||
target_message_extra_block = "4.考虑你要回复的目标消息"
|
||||
else:
|
||||
target_message_str = ""
|
||||
@@ -262,7 +407,8 @@ class ExpressionSelector:
|
||||
# 4. 调用LLM
|
||||
content, (reasoning_content, model_name, _) = await self.llm_model.generate_response_async(prompt=prompt)
|
||||
|
||||
# print(prompt)
|
||||
print(prompt)
|
||||
print(content)
|
||||
|
||||
if not content:
|
||||
logger.warning("LLM返回空结果")
|
||||
@@ -7,8 +7,13 @@ from src.common.database.database_model import Jargon
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config, global_config
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.jargon.jargon_miner import search_jargon
|
||||
from src.jargon.jargon_utils import is_bot_message, contains_bot_self_name, parse_chat_id_list, chat_id_list_contains
|
||||
from src.bw_learner.jargon_miner import search_jargon
|
||||
from src.bw_learner.learner_utils import (
|
||||
is_bot_message,
|
||||
contains_bot_self_name,
|
||||
parse_chat_id_list,
|
||||
chat_id_list_contains,
|
||||
)
|
||||
|
||||
logger = get_logger("jargon")
|
||||
|
||||
@@ -82,7 +87,7 @@ class JargonExplainer:
|
||||
query = Jargon.select().where((Jargon.meaning.is_null(False)) & (Jargon.meaning != ""))
|
||||
|
||||
# 根据all_global配置决定查询逻辑
|
||||
if global_config.jargon.all_global:
|
||||
if global_config.expression.all_global_jargon:
|
||||
# 开启all_global:只查询is_global=True的记录
|
||||
query = query.where(Jargon.is_global)
|
||||
else:
|
||||
@@ -107,7 +112,7 @@ class JargonExplainer:
|
||||
continue
|
||||
|
||||
# 检查chat_id(如果all_global=False)
|
||||
if not global_config.jargon.all_global:
|
||||
if not global_config.expression.all_global_jargon:
|
||||
if jargon.is_global:
|
||||
# 全局黑话,包含
|
||||
pass
|
||||
@@ -181,7 +186,7 @@ class JargonExplainer:
|
||||
content = entry["content"]
|
||||
|
||||
# 根据是否开启全局黑话,决定查询方式
|
||||
if global_config.jargon.all_global:
|
||||
if global_config.expression.all_global_jargon:
|
||||
# 开启全局黑话:查询所有is_global=True的记录
|
||||
results = search_jargon(
|
||||
keyword=content,
|
||||
@@ -265,7 +270,7 @@ def match_jargon_from_text(chat_text: str, chat_id: str) -> List[str]:
|
||||
return []
|
||||
|
||||
query = Jargon.select().where((Jargon.meaning.is_null(False)) & (Jargon.meaning != ""))
|
||||
if global_config.jargon.all_global:
|
||||
if global_config.expression.all_global_jargon:
|
||||
query = query.where(Jargon.is_global)
|
||||
|
||||
query = query.order_by(Jargon.count.desc())
|
||||
@@ -277,7 +282,7 @@ def match_jargon_from_text(chat_text: str, chat_id: str) -> List[str]:
|
||||
if not content:
|
||||
continue
|
||||
|
||||
if not global_config.jargon.all_global and not jargon.is_global:
|
||||
if not global_config.expression.all_global_jargon and not jargon.is_global:
|
||||
chat_id_list = parse_chat_id_list(jargon.chat_id)
|
||||
if not chat_id_list_contains(chat_id_list, chat_id):
|
||||
continue
|
||||
@@ -357,4 +362,4 @@ async def retrieve_concepts_with_jargon(concepts: List[str], chat_id: str) -> st
|
||||
|
||||
if results:
|
||||
return "【概念检索结果】\n" + "\n".join(results) + "\n"
|
||||
return ""
|
||||
return ""
|
||||
@@ -1,8 +1,8 @@
|
||||
import time
|
||||
import json
|
||||
import asyncio
|
||||
import random
|
||||
from collections import OrderedDict
|
||||
from typing import List, Dict, Optional, Any
|
||||
from typing import List, Dict, Optional, Any, Callable
|
||||
from json_repair import repair_json
|
||||
from peewee import fn
|
||||
|
||||
@@ -13,10 +13,9 @@ from src.config.config import model_config, global_config
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
build_readable_messages_with_id,
|
||||
get_raw_msg_by_timestamp_with_chat_inclusive,
|
||||
)
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.jargon.jargon_utils import (
|
||||
from src.bw_learner.learner_utils import (
|
||||
is_bot_message,
|
||||
build_context_paragraph,
|
||||
contains_bot_self_name,
|
||||
@@ -29,6 +28,29 @@ from src.jargon.jargon_utils import (
|
||||
logger = get_logger("jargon")
|
||||
|
||||
|
||||
def _is_single_char_jargon(content: str) -> bool:
|
||||
"""
|
||||
判断是否是单字黑话(单个汉字、英文或数字)
|
||||
|
||||
Args:
|
||||
content: 词条内容
|
||||
|
||||
Returns:
|
||||
bool: 如果是单字黑话返回True,否则返回False
|
||||
"""
|
||||
if not content or len(content) != 1:
|
||||
return False
|
||||
|
||||
char = content[0]
|
||||
# 判断是否是单个汉字、单个英文字母或单个数字
|
||||
return (
|
||||
"\u4e00" <= char <= "\u9fff" # 汉字
|
||||
or "a" <= char <= "z" # 小写字母
|
||||
or "A" <= char <= "Z" # 大写字母
|
||||
or "0" <= char <= "9" # 数字
|
||||
)
|
||||
|
||||
|
||||
def _init_prompt() -> None:
|
||||
prompt_str = """
|
||||
**聊天内容,其中的{bot_name}的发言内容是你自己的发言,[msg_id] 是消息ID**
|
||||
@@ -36,11 +58,9 @@ def _init_prompt() -> None:
|
||||
|
||||
请从上面这段聊天内容中提取"可能是黑话"的候选项(黑话/俚语/网络缩写/口头禅)。
|
||||
- 必须为对话中真实出现过的短词或短语
|
||||
- 必须是你无法理解含义的词语,没有明确含义的词语
|
||||
- 请不要选择有明确含义,或者含义清晰的词语
|
||||
- 必须是你无法理解含义的词语,没有明确含义的词语,请不要选择有明确含义,或者含义清晰的词语
|
||||
- 排除:人名、@、表情包/图片中的内容、纯标点、常规功能词(如的、了、呢、啊等)
|
||||
- 每个词条长度建议 2-8 个字符(不强制),尽量短小
|
||||
- 合并重复项,去重
|
||||
|
||||
黑话必须为以下几种类型:
|
||||
- 由字母构成的,汉语拼音首字母的简写词,例如:nb、yyds、xswl
|
||||
@@ -48,7 +68,7 @@ def _init_prompt() -> None:
|
||||
- 中文词语的缩写,用几个汉字概括一个词汇或含义,例如:社死、内卷
|
||||
|
||||
以 JSON 数组输出,元素为对象(严格按以下结构):
|
||||
请你提取出可能的黑话,最多10
|
||||
请你提取出可能的黑话,最多30个黑话,请尽量提取所有
|
||||
[
|
||||
{{"content": "词条", "msg_id": "m12"}}, // msg_id 必须与上方聊天中展示的ID完全一致
|
||||
{{"content": "词条2", "msg_id": "m15"}}
|
||||
@@ -67,12 +87,14 @@ def _init_inference_prompts() -> None:
|
||||
{content}
|
||||
**词条出现的上下文。其中的{bot_name}的发言内容是你自己的发言**
|
||||
{raw_content_list}
|
||||
{previous_meaning_section}
|
||||
|
||||
请根据上下文,推断"{content}"这个词条的含义。
|
||||
- 如果这是一个黑话、俚语或网络用语,请推断其含义
|
||||
- 如果含义明确(常规词汇),也请说明
|
||||
- {bot_name} 的发言内容可能包含错误,请不要参考其发言内容
|
||||
- 如果上下文信息不足,无法推断含义,请设置 no_info 为 true
|
||||
{previous_meaning_instruction}
|
||||
|
||||
以 JSON 格式输出:
|
||||
{{
|
||||
@@ -166,23 +188,24 @@ def _should_infer_meaning(jargon_obj: Jargon) -> bool:
|
||||
class JargonMiner:
|
||||
def __init__(self, chat_id: str) -> None:
|
||||
self.chat_id = chat_id
|
||||
self.last_learning_time: float = time.time()
|
||||
# 频率控制,可按需调整
|
||||
self.min_messages_for_learning: int = 10
|
||||
self.min_learning_interval: float = 20
|
||||
|
||||
self.llm = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils,
|
||||
request_type="jargon.extract",
|
||||
)
|
||||
|
||||
self.llm_inference = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils,
|
||||
request_type="jargon.inference",
|
||||
)
|
||||
|
||||
# 初始化stream_name作为类属性,避免重复提取
|
||||
chat_manager = get_chat_manager()
|
||||
stream_name = chat_manager.get_stream_name(self.chat_id)
|
||||
self.stream_name = stream_name if stream_name else self.chat_id
|
||||
self.cache_limit = 100
|
||||
self.cache_limit = 50
|
||||
self.cache: OrderedDict[str, None] = OrderedDict()
|
||||
|
||||
|
||||
# 黑话提取锁,防止并发执行
|
||||
self._extraction_lock = asyncio.Lock()
|
||||
|
||||
@@ -195,6 +218,10 @@ class JargonMiner:
|
||||
if not key:
|
||||
return
|
||||
|
||||
# 单字黑话(单个汉字、英文或数字)不记录到缓存
|
||||
if _is_single_char_jargon(key):
|
||||
return
|
||||
|
||||
if key in self.cache:
|
||||
self.cache.move_to_end(key)
|
||||
else:
|
||||
@@ -267,16 +294,44 @@ class JargonMiner:
|
||||
logger.warning(f"jargon {content} 没有raw_content,跳过推断")
|
||||
return
|
||||
|
||||
# 获取当前count和上一次的meaning
|
||||
current_count = jargon_obj.count or 0
|
||||
previous_meaning = jargon_obj.meaning or ""
|
||||
|
||||
# 当count为24, 60时,随机移除一半的raw_content项目
|
||||
if current_count in [24, 60] and len(raw_content_list) > 1:
|
||||
# 计算要保留的数量(至少保留1个)
|
||||
keep_count = max(1, len(raw_content_list) // 2)
|
||||
raw_content_list = random.sample(raw_content_list, keep_count)
|
||||
logger.info(
|
||||
f"jargon {content} count={current_count},随机移除后剩余 {len(raw_content_list)} 个raw_content项目"
|
||||
)
|
||||
|
||||
# 步骤1: 基于raw_content和content推断
|
||||
raw_content_text = "\n".join(raw_content_list)
|
||||
|
||||
# 当count为24, 60, 100时,在prompt中放入上一次推断出的meaning作为参考
|
||||
previous_meaning_section = ""
|
||||
previous_meaning_instruction = ""
|
||||
if current_count in [24, 60, 100] and previous_meaning:
|
||||
previous_meaning_section = f"""
|
||||
**上一次推断的含义(仅供参考)**
|
||||
{previous_meaning}
|
||||
"""
|
||||
previous_meaning_instruction = (
|
||||
"- 请参考上一次推断的含义,结合新的上下文信息,给出更准确或更新的推断结果"
|
||||
)
|
||||
|
||||
prompt1 = await global_prompt_manager.format_prompt(
|
||||
"jargon_inference_with_context_prompt",
|
||||
content=content,
|
||||
bot_name=global_config.bot.nickname,
|
||||
raw_content_list=raw_content_text,
|
||||
previous_meaning_section=previous_meaning_section,
|
||||
previous_meaning_instruction=previous_meaning_instruction,
|
||||
)
|
||||
|
||||
response1, _ = await self.llm.generate_response_async(prompt1, temperature=0.3)
|
||||
response1, _ = await self.llm_inference.generate_response_async(prompt1, temperature=0.3)
|
||||
if not response1:
|
||||
logger.warning(f"jargon {content} 推断1失败:无响应")
|
||||
return
|
||||
@@ -313,7 +368,7 @@ class JargonMiner:
|
||||
content=content,
|
||||
)
|
||||
|
||||
response2, _ = await self.llm.generate_response_async(prompt2, temperature=0.3)
|
||||
response2, _ = await self.llm_inference.generate_response_async(prompt2, temperature=0.3)
|
||||
if not response2:
|
||||
logger.warning(f"jargon {content} 推断2失败:无响应")
|
||||
return
|
||||
@@ -360,7 +415,7 @@ class JargonMiner:
|
||||
if global_config.debug.show_jargon_prompt:
|
||||
logger.info(f"jargon {content} 比较提示词: {prompt3}")
|
||||
|
||||
response3, _ = await self.llm.generate_response_async(prompt3, temperature=0.3)
|
||||
response3, _ = await self.llm_inference.generate_response_async(prompt3, temperature=0.3)
|
||||
if not response3:
|
||||
logger.warning(f"jargon {content} 比较失败:无响应")
|
||||
return
|
||||
@@ -425,45 +480,21 @@ class JargonMiner:
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
def should_trigger(self) -> bool:
|
||||
# 冷却时间检查
|
||||
if time.time() - self.last_learning_time < self.min_learning_interval:
|
||||
return False
|
||||
async def run_once(
|
||||
self,
|
||||
messages: List[Any],
|
||||
person_name_filter: Optional[Callable[[str], bool]] = None
|
||||
) -> None:
|
||||
"""
|
||||
运行一次黑话提取
|
||||
|
||||
# 拉取最近消息数量是否足够
|
||||
recent_messages = get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=self.last_learning_time,
|
||||
timestamp_end=time.time(),
|
||||
)
|
||||
return bool(recent_messages and len(recent_messages) >= self.min_messages_for_learning)
|
||||
|
||||
async def run_once(self) -> None:
|
||||
Args:
|
||||
messages: 外部传入的消息列表(必需)
|
||||
person_name_filter: 可选的过滤函数,用于检查内容是否包含人物名称
|
||||
"""
|
||||
# 使用异步锁防止并发执行
|
||||
async with self._extraction_lock:
|
||||
try:
|
||||
# 在锁内检查,避免并发触发
|
||||
if not self.should_trigger():
|
||||
return
|
||||
|
||||
chat_stream = get_chat_manager().get_stream(self.chat_id)
|
||||
if not chat_stream:
|
||||
return
|
||||
|
||||
# 记录本次提取的时间窗口,避免重复提取
|
||||
extraction_start_time = self.last_learning_time
|
||||
extraction_end_time = time.time()
|
||||
|
||||
# 立即更新学习时间,防止并发触发
|
||||
self.last_learning_time = extraction_end_time
|
||||
|
||||
# 拉取学习窗口内的消息
|
||||
messages = get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=extraction_start_time,
|
||||
timestamp_end=extraction_end_time,
|
||||
limit=20,
|
||||
)
|
||||
if not messages:
|
||||
return
|
||||
|
||||
@@ -538,6 +569,11 @@ class JargonMiner:
|
||||
if contains_bot_self_name(content):
|
||||
logger.info(f"解析阶段跳过包含机器人昵称/别名的词条: {content}")
|
||||
continue
|
||||
|
||||
# 检查是否包含人物名称
|
||||
if person_name_filter and person_name_filter(content):
|
||||
logger.info(f"解析阶段跳过包含人物名称的词条: {content}")
|
||||
continue
|
||||
|
||||
msg_id_str = str(msg_id_value or "").strip()
|
||||
if not msg_id_str:
|
||||
@@ -603,7 +639,7 @@ class JargonMiner:
|
||||
# 查找匹配的记录
|
||||
matched_obj = None
|
||||
for obj in query:
|
||||
if global_config.jargon.all_global:
|
||||
if global_config.expression.all_global_jargon:
|
||||
# 开启all_global:所有content匹配的记录都可以
|
||||
matched_obj = obj
|
||||
break
|
||||
@@ -626,7 +662,9 @@ class JargonMiner:
|
||||
if obj.raw_content:
|
||||
try:
|
||||
existing_raw_content = (
|
||||
json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content
|
||||
json.loads(obj.raw_content)
|
||||
if isinstance(obj.raw_content, str)
|
||||
else obj.raw_content
|
||||
)
|
||||
if not isinstance(existing_raw_content, list):
|
||||
existing_raw_content = [existing_raw_content] if existing_raw_content else []
|
||||
@@ -643,7 +681,7 @@ class JargonMiner:
|
||||
obj.chat_id = json.dumps(updated_chat_id_list, ensure_ascii=False)
|
||||
|
||||
# 开启all_global时,确保记录标记为is_global=True
|
||||
if global_config.jargon.all_global:
|
||||
if global_config.expression.all_global_jargon:
|
||||
obj.is_global = True
|
||||
# 关闭all_global时,保持原有is_global不变(不修改)
|
||||
|
||||
@@ -659,7 +697,7 @@ class JargonMiner:
|
||||
updated += 1
|
||||
else:
|
||||
# 没找到匹配记录,创建新记录
|
||||
if global_config.jargon.all_global:
|
||||
if global_config.expression.all_global_jargon:
|
||||
# 开启all_global:新记录默认为is_global=True
|
||||
is_global_new = True
|
||||
else:
|
||||
@@ -699,6 +737,158 @@ class JargonMiner:
|
||||
logger.error(f"JargonMiner 运行失败: {e}")
|
||||
# 即使失败也保持时间戳更新,避免频繁重试
|
||||
|
||||
async def process_extracted_entries(
|
||||
self,
|
||||
entries: List[Dict[str, List[str]]],
|
||||
person_name_filter: Optional[Callable[[str], bool]] = None
|
||||
) -> None:
|
||||
"""
|
||||
处理已提取的黑话条目(从 expression_learner 路由过来的)
|
||||
|
||||
Args:
|
||||
entries: 黑话条目列表,每个元素格式为 {"content": "...", "raw_content": [...]}
|
||||
person_name_filter: 可选的过滤函数,用于检查内容是否包含人物名称
|
||||
"""
|
||||
if not entries:
|
||||
return
|
||||
|
||||
try:
|
||||
# 去重并合并raw_content(按 content 聚合)
|
||||
merged_entries: OrderedDict[str, Dict[str, List[str]]] = OrderedDict()
|
||||
for entry in entries:
|
||||
content_key = entry["content"]
|
||||
|
||||
# 检查是否包含人物名称
|
||||
# logger.info(f"process_extracted_entries 检查是否包含人物名称: {content_key}")
|
||||
# logger.info(f"person_name_filter: {person_name_filter}")
|
||||
if person_name_filter and person_name_filter(content_key):
|
||||
logger.info(f"process_extracted_entries 跳过包含人物名称的黑话: {content_key}")
|
||||
continue
|
||||
|
||||
raw_list = entry.get("raw_content", []) or []
|
||||
if content_key in merged_entries:
|
||||
merged_entries[content_key]["raw_content"].extend(raw_list)
|
||||
else:
|
||||
merged_entries[content_key] = {
|
||||
"content": content_key,
|
||||
"raw_content": list(raw_list),
|
||||
}
|
||||
|
||||
uniq_entries = []
|
||||
for merged_entry in merged_entries.values():
|
||||
raw_content_list = merged_entry["raw_content"]
|
||||
if raw_content_list:
|
||||
merged_entry["raw_content"] = list(dict.fromkeys(raw_content_list))
|
||||
uniq_entries.append(merged_entry)
|
||||
|
||||
saved = 0
|
||||
updated = 0
|
||||
for entry in uniq_entries:
|
||||
content = entry["content"]
|
||||
raw_content_list = entry["raw_content"] # 已经是列表
|
||||
|
||||
try:
|
||||
# 查询所有content匹配的记录
|
||||
query = Jargon.select().where(Jargon.content == content)
|
||||
|
||||
# 查找匹配的记录
|
||||
matched_obj = None
|
||||
for obj in query:
|
||||
if global_config.expression.all_global_jargon:
|
||||
# 开启all_global:所有content匹配的记录都可以
|
||||
matched_obj = obj
|
||||
break
|
||||
else:
|
||||
# 关闭all_global:需要检查chat_id列表是否包含目标chat_id
|
||||
chat_id_list = parse_chat_id_list(obj.chat_id)
|
||||
if chat_id_list_contains(chat_id_list, self.chat_id):
|
||||
matched_obj = obj
|
||||
break
|
||||
|
||||
if matched_obj:
|
||||
obj = matched_obj
|
||||
try:
|
||||
obj.count = (obj.count or 0) + 1
|
||||
except Exception:
|
||||
obj.count = 1
|
||||
|
||||
# 合并raw_content列表:读取现有列表,追加新值,去重
|
||||
existing_raw_content = []
|
||||
if obj.raw_content:
|
||||
try:
|
||||
existing_raw_content = (
|
||||
json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content
|
||||
)
|
||||
if not isinstance(existing_raw_content, list):
|
||||
existing_raw_content = [existing_raw_content] if existing_raw_content else []
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
existing_raw_content = [obj.raw_content] if obj.raw_content else []
|
||||
|
||||
# 合并并去重
|
||||
merged_list = list(dict.fromkeys(existing_raw_content + raw_content_list))
|
||||
obj.raw_content = json.dumps(merged_list, ensure_ascii=False)
|
||||
|
||||
# 更新chat_id列表:增加当前chat_id的计数
|
||||
chat_id_list = parse_chat_id_list(obj.chat_id)
|
||||
updated_chat_id_list = update_chat_id_list(chat_id_list, self.chat_id, increment=1)
|
||||
obj.chat_id = json.dumps(updated_chat_id_list, ensure_ascii=False)
|
||||
|
||||
# 开启all_global时,确保记录标记为is_global=True
|
||||
if global_config.expression.all_global_jargon:
|
||||
obj.is_global = True
|
||||
# 关闭all_global时,保持原有is_global不变(不修改)
|
||||
|
||||
obj.save()
|
||||
|
||||
# 检查是否需要推断(达到阈值且超过上次判定值)
|
||||
if _should_infer_meaning(obj):
|
||||
# 异步触发推断,不阻塞主流程
|
||||
# 重新加载对象以确保数据最新
|
||||
jargon_id = obj.id
|
||||
asyncio.create_task(self._infer_meaning_by_id(jargon_id))
|
||||
|
||||
updated += 1
|
||||
else:
|
||||
# 没找到匹配记录,创建新记录
|
||||
if global_config.expression.all_global_jargon:
|
||||
# 开启all_global:新记录默认为is_global=True
|
||||
is_global_new = True
|
||||
else:
|
||||
# 关闭all_global:新记录is_global=False
|
||||
is_global_new = False
|
||||
|
||||
# 使用新格式创建chat_id列表:[[chat_id, count]]
|
||||
chat_id_list = [[self.chat_id, 1]]
|
||||
chat_id_json = json.dumps(chat_id_list, ensure_ascii=False)
|
||||
|
||||
Jargon.create(
|
||||
content=content,
|
||||
raw_content=json.dumps(raw_content_list, ensure_ascii=False),
|
||||
chat_id=chat_id_json,
|
||||
is_global=is_global_new,
|
||||
count=1,
|
||||
)
|
||||
saved += 1
|
||||
except Exception as e:
|
||||
logger.error(f"保存jargon失败: chat_id={self.chat_id}, content={content}, err={e}")
|
||||
continue
|
||||
finally:
|
||||
self._add_to_cache(content)
|
||||
|
||||
# 固定输出提取的jargon结果,格式化为可读形式(只要有提取结果就输出)
|
||||
if uniq_entries:
|
||||
# 收集所有提取的jargon内容
|
||||
jargon_list = [entry["content"] for entry in uniq_entries]
|
||||
jargon_str = ",".join(jargon_list)
|
||||
|
||||
# 输出格式化的结果(使用logger.info会自动应用jargon模块的颜色)
|
||||
logger.info(f"[{self.stream_name}]疑似黑话: {jargon_str}")
|
||||
|
||||
if saved or updated:
|
||||
logger.debug(f"jargon写入: 新增 {saved} 条,更新 {updated} 条,chat_id={self.chat_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"处理已提取的黑话条目失败: {e}")
|
||||
|
||||
|
||||
class JargonMinerManager:
|
||||
def __init__(self) -> None:
|
||||
@@ -713,11 +903,6 @@ class JargonMinerManager:
|
||||
miner_manager = JargonMinerManager()
|
||||
|
||||
|
||||
async def extract_and_store_jargon(chat_id: str) -> None:
|
||||
miner = miner_manager.get_miner(chat_id)
|
||||
await miner.run_once()
|
||||
|
||||
|
||||
def search_jargon(
|
||||
keyword: str, chat_id: Optional[str] = None, limit: int = 10, case_sensitive: bool = False, fuzzy: bool = True
|
||||
) -> List[Dict[str, str]]:
|
||||
@@ -765,7 +950,7 @@ def search_jargon(
|
||||
query = query.where(search_condition)
|
||||
|
||||
# 根据all_global配置决定查询逻辑
|
||||
if global_config.jargon.all_global:
|
||||
if global_config.expression.all_global_jargon:
|
||||
# 开启all_global:所有记录都是全局的,查询所有is_global=True的记录(无视chat_id)
|
||||
query = query.where(Jargon.is_global)
|
||||
# 注意:对于all_global=False的情况,chat_id过滤在Python层面进行,以便兼容新旧格式
|
||||
@@ -782,7 +967,7 @@ def search_jargon(
|
||||
results = []
|
||||
for jargon in query:
|
||||
# 如果提供了chat_id且all_global=False,需要检查chat_id列表是否包含目标chat_id
|
||||
if chat_id and not global_config.jargon.all_global:
|
||||
if chat_id and not global_config.expression.all_global_jargon:
|
||||
chat_id_list = parse_chat_id_list(jargon.chat_id)
|
||||
# 如果记录是is_global=True,或者chat_id列表包含目标chat_id,则包含
|
||||
if not jargon.is_global and not chat_id_list_contains(chat_id_list, chat_id):
|
||||
380
src/bw_learner/learner_utils.py
Normal file
380
src/bw_learner/learner_utils.py
Normal file
@@ -0,0 +1,380 @@
|
||||
import re
|
||||
import difflib
|
||||
import random
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Optional, List, Dict, Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
build_readable_messages,
|
||||
)
|
||||
from src.chat.utils.utils import parse_platform_accounts
|
||||
|
||||
|
||||
logger = get_logger("learner_utils")
|
||||
|
||||
|
||||
def filter_message_content(content: Optional[str]) -> str:
|
||||
"""
|
||||
过滤消息内容,移除回复、@、图片等格式
|
||||
|
||||
Args:
|
||||
content: 原始消息内容
|
||||
|
||||
Returns:
|
||||
str: 过滤后的内容
|
||||
"""
|
||||
if not content:
|
||||
return ""
|
||||
|
||||
# 移除以[回复开头、]结尾的部分,包括后面的",说:"部分
|
||||
content = re.sub(r"\[回复.*?\],说:\s*", "", content)
|
||||
# 移除@<...>格式的内容
|
||||
content = re.sub(r"@<[^>]*>", "", content)
|
||||
# 移除[picid:...]格式的图片ID
|
||||
content = re.sub(r"\[picid:[^\]]*\]", "", content)
|
||||
# 移除[表情包:...]格式的内容
|
||||
content = re.sub(r"\[表情包:[^\]]*\]", "", content)
|
||||
|
||||
return content.strip()
|
||||
|
||||
|
||||
def calculate_similarity(text1: str, text2: str) -> float:
|
||||
"""
|
||||
计算两个文本的相似度,返回0-1之间的值
|
||||
使用SequenceMatcher计算相似度
|
||||
|
||||
Args:
|
||||
text1: 第一个文本
|
||||
text2: 第二个文本
|
||||
|
||||
Returns:
|
||||
float: 相似度值,范围0-1
|
||||
"""
|
||||
return difflib.SequenceMatcher(None, text1, text2).ratio()
|
||||
|
||||
|
||||
def calculate_style_similarity(style1: str, style2: str) -> float:
|
||||
"""
|
||||
计算两个 style 的相似度,返回0-1之间的值
|
||||
在计算前会移除"使用"和"句式"这两个词(参考 expression_similarity_analysis.py)
|
||||
|
||||
Args:
|
||||
style1: 第一个 style
|
||||
style2: 第二个 style
|
||||
|
||||
Returns:
|
||||
float: 相似度值,范围0-1
|
||||
"""
|
||||
if not style1 or not style2:
|
||||
return 0.0
|
||||
|
||||
# 移除"使用"和"句式"这两个词
|
||||
def remove_ignored_words(text: str) -> str:
|
||||
"""移除需要忽略的词"""
|
||||
text = text.replace("使用", "")
|
||||
text = text.replace("句式", "")
|
||||
return text.strip()
|
||||
|
||||
cleaned_style1 = remove_ignored_words(style1)
|
||||
cleaned_style2 = remove_ignored_words(style2)
|
||||
|
||||
# 如果清理后文本为空,返回0
|
||||
if not cleaned_style1 or not cleaned_style2:
|
||||
return 0.0
|
||||
|
||||
return difflib.SequenceMatcher(None, cleaned_style1, cleaned_style2).ratio()
|
||||
|
||||
|
||||
def format_create_date(timestamp: float) -> str:
|
||||
"""
|
||||
将时间戳格式化为可读的日期字符串
|
||||
|
||||
Args:
|
||||
timestamp: 时间戳
|
||||
|
||||
Returns:
|
||||
str: 格式化后的日期字符串
|
||||
"""
|
||||
try:
|
||||
return datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S")
|
||||
except (ValueError, OSError):
|
||||
return "未知时间"
|
||||
|
||||
|
||||
def _compute_weights(population: List[Dict]) -> List[float]:
|
||||
"""
|
||||
根据表达的count计算权重,范围限定在1~5之间。
|
||||
count越高,权重越高,但最多为基础权重的5倍。
|
||||
如果表达已checked,权重会再乘以3倍。
|
||||
"""
|
||||
if not population:
|
||||
return []
|
||||
|
||||
counts = []
|
||||
checked_flags = []
|
||||
for item in population:
|
||||
count = item.get("count", 1)
|
||||
try:
|
||||
count_value = float(count)
|
||||
except (TypeError, ValueError):
|
||||
count_value = 1.0
|
||||
counts.append(max(count_value, 0.0))
|
||||
# 获取checked状态
|
||||
checked = item.get("checked", False)
|
||||
checked_flags.append(bool(checked))
|
||||
|
||||
min_count = min(counts)
|
||||
max_count = max(counts)
|
||||
|
||||
if max_count == min_count:
|
||||
base_weights = [1.0 for _ in counts]
|
||||
else:
|
||||
base_weights = []
|
||||
for count_value in counts:
|
||||
# 线性映射到[1,5]区间
|
||||
normalized = (count_value - min_count) / (max_count - min_count)
|
||||
base_weights.append(1.0 + normalized * 4.0) # 1~5
|
||||
|
||||
# 如果checked,权重乘以3
|
||||
weights = []
|
||||
for base_weight, checked in zip(base_weights, checked_flags, strict=False):
|
||||
if checked:
|
||||
weights.append(base_weight * 3.0)
|
||||
else:
|
||||
weights.append(base_weight)
|
||||
return weights
|
||||
|
||||
|
||||
def weighted_sample(population: List[Dict], k: int) -> List[Dict]:
|
||||
"""
|
||||
随机抽样函数
|
||||
|
||||
Args:
|
||||
population: 总体数据列表
|
||||
k: 需要抽取的数量
|
||||
|
||||
Returns:
|
||||
List[Dict]: 抽取的数据列表
|
||||
"""
|
||||
if not population or k <= 0:
|
||||
return []
|
||||
|
||||
if len(population) <= k:
|
||||
return population.copy()
|
||||
|
||||
selected: List[Dict] = []
|
||||
population_copy = population.copy()
|
||||
|
||||
for _ in range(min(k, len(population_copy))):
|
||||
weights = _compute_weights(population_copy)
|
||||
total_weight = sum(weights)
|
||||
if total_weight <= 0:
|
||||
# 回退到均匀随机
|
||||
idx = random.randint(0, len(population_copy) - 1)
|
||||
selected.append(population_copy.pop(idx))
|
||||
continue
|
||||
|
||||
threshold = random.uniform(0, total_weight)
|
||||
cumulative = 0.0
|
||||
for idx, weight in enumerate(weights):
|
||||
cumulative += weight
|
||||
if threshold <= cumulative:
|
||||
selected.append(population_copy.pop(idx))
|
||||
break
|
||||
|
||||
return selected
|
||||
|
||||
|
||||
def parse_chat_id_list(chat_id_value: Any) -> List[List[Any]]:
|
||||
"""
|
||||
解析chat_id字段,兼容旧格式(字符串)和新格式(JSON列表)
|
||||
|
||||
Args:
|
||||
chat_id_value: 可能是字符串(旧格式)或JSON字符串(新格式)
|
||||
|
||||
Returns:
|
||||
List[List[Any]]: 格式为 [[chat_id, count], ...] 的列表
|
||||
"""
|
||||
if not chat_id_value:
|
||||
return []
|
||||
|
||||
# 如果是字符串,尝试解析为JSON
|
||||
if isinstance(chat_id_value, str):
|
||||
# 尝试解析JSON
|
||||
try:
|
||||
parsed = json.loads(chat_id_value)
|
||||
if isinstance(parsed, list):
|
||||
# 新格式:已经是列表
|
||||
return parsed
|
||||
elif isinstance(parsed, str):
|
||||
# 解析后还是字符串,说明是旧格式
|
||||
return [[parsed, 1]]
|
||||
else:
|
||||
# 其他类型,当作旧格式处理
|
||||
return [[str(chat_id_value), 1]]
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
# 解析失败,当作旧格式(纯字符串)
|
||||
return [[str(chat_id_value), 1]]
|
||||
elif isinstance(chat_id_value, list):
|
||||
# 已经是列表格式
|
||||
return chat_id_value
|
||||
else:
|
||||
# 其他类型,转换为旧格式
|
||||
return [[str(chat_id_value), 1]]
|
||||
|
||||
|
||||
def update_chat_id_list(chat_id_list: List[List[Any]], target_chat_id: str, increment: int = 1) -> List[List[Any]]:
|
||||
"""
|
||||
更新chat_id列表,如果target_chat_id已存在则增加计数,否则添加新条目
|
||||
|
||||
Args:
|
||||
chat_id_list: 当前的chat_id列表,格式为 [[chat_id, count], ...]
|
||||
target_chat_id: 要更新或添加的chat_id
|
||||
increment: 增加的计数,默认为1
|
||||
|
||||
Returns:
|
||||
List[List[Any]]: 更新后的chat_id列表
|
||||
"""
|
||||
item = _find_chat_id_item(chat_id_list, target_chat_id)
|
||||
if item is not None:
|
||||
# 找到匹配的chat_id,增加计数
|
||||
if len(item) >= 2:
|
||||
item[1] = (item[1] if isinstance(item[1], (int, float)) else 0) + increment
|
||||
else:
|
||||
item.append(increment)
|
||||
else:
|
||||
# 未找到,添加新条目
|
||||
chat_id_list.append([target_chat_id, increment])
|
||||
|
||||
return chat_id_list
|
||||
|
||||
|
||||
def _find_chat_id_item(chat_id_list: List[List[Any]], target_chat_id: str) -> Optional[List[Any]]:
|
||||
"""
|
||||
在chat_id列表中查找匹配的项(辅助函数)
|
||||
|
||||
Args:
|
||||
chat_id_list: chat_id列表,格式为 [[chat_id, count], ...]
|
||||
target_chat_id: 要查找的chat_id
|
||||
|
||||
Returns:
|
||||
如果找到则返回匹配的项,否则返回None
|
||||
"""
|
||||
for item in chat_id_list:
|
||||
if isinstance(item, list) and len(item) >= 1 and str(item[0]) == str(target_chat_id):
|
||||
return item
|
||||
return None
|
||||
|
||||
|
||||
def chat_id_list_contains(chat_id_list: List[List[Any]], target_chat_id: str) -> bool:
|
||||
"""
|
||||
检查chat_id列表中是否包含指定的chat_id
|
||||
|
||||
Args:
|
||||
chat_id_list: chat_id列表,格式为 [[chat_id, count], ...]
|
||||
target_chat_id: 要查找的chat_id
|
||||
|
||||
Returns:
|
||||
bool: 如果包含则返回True
|
||||
"""
|
||||
return _find_chat_id_item(chat_id_list, target_chat_id) is not None
|
||||
|
||||
|
||||
def contains_bot_self_name(content: str) -> bool:
|
||||
"""
|
||||
判断词条是否包含机器人的昵称或别名
|
||||
"""
|
||||
if not content:
|
||||
return False
|
||||
|
||||
bot_config = getattr(global_config, "bot", None)
|
||||
if not bot_config:
|
||||
return False
|
||||
|
||||
target = content.strip().lower()
|
||||
nickname = str(getattr(bot_config, "nickname", "") or "").strip().lower()
|
||||
alias_names = [str(alias or "").strip().lower() for alias in getattr(bot_config, "alias_names", []) or []]
|
||||
|
||||
candidates = [name for name in [nickname, *alias_names] if name]
|
||||
|
||||
return any(name in target for name in candidates)
|
||||
|
||||
|
||||
def build_context_paragraph(messages: List[Any], center_index: int) -> Optional[str]:
|
||||
"""
|
||||
构建包含中心消息上下文的段落(前3条+后3条),使用标准的 readable builder 输出
|
||||
"""
|
||||
if not messages or center_index < 0 or center_index >= len(messages):
|
||||
return None
|
||||
|
||||
context_start = max(0, center_index - 3)
|
||||
context_end = min(len(messages), center_index + 1 + 3)
|
||||
context_messages = messages[context_start:context_end]
|
||||
|
||||
if not context_messages:
|
||||
return None
|
||||
|
||||
try:
|
||||
paragraph = build_readable_messages(
|
||||
messages=context_messages,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
truncate=False,
|
||||
show_actions=False,
|
||||
show_pic=True,
|
||||
message_id_list=None,
|
||||
remove_emoji_stickers=False,
|
||||
pic_single=True,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"构建上下文段落失败: {e}")
|
||||
return None
|
||||
|
||||
paragraph = paragraph.strip()
|
||||
return paragraph or None
|
||||
|
||||
|
||||
def is_bot_message(msg: Any) -> bool:
|
||||
"""判断消息是否来自机器人自身"""
|
||||
if msg is None:
|
||||
return False
|
||||
|
||||
bot_config = getattr(global_config, "bot", None)
|
||||
if not bot_config:
|
||||
return False
|
||||
|
||||
platform = (
|
||||
str(getattr(msg, "user_platform", "") or getattr(getattr(msg, "user_info", None), "platform", "") or "")
|
||||
.strip()
|
||||
.lower()
|
||||
)
|
||||
user_id = str(getattr(msg, "user_id", "") or getattr(getattr(msg, "user_info", None), "user_id", "") or "").strip()
|
||||
|
||||
if not platform or not user_id:
|
||||
return False
|
||||
|
||||
platform_accounts = {}
|
||||
try:
|
||||
platform_accounts = parse_platform_accounts(getattr(bot_config, "platforms", []) or [])
|
||||
except Exception:
|
||||
platform_accounts = {}
|
||||
|
||||
bot_accounts: Dict[str, str] = {}
|
||||
qq_account = str(getattr(bot_config, "qq_account", "") or "").strip()
|
||||
if qq_account:
|
||||
bot_accounts["qq"] = qq_account
|
||||
|
||||
telegram_account = str(getattr(bot_config, "telegram_account", "") or "").strip()
|
||||
if telegram_account:
|
||||
bot_accounts["telegram"] = telegram_account
|
||||
|
||||
for plat, account in platform_accounts.items():
|
||||
if account and plat not in bot_accounts:
|
||||
bot_accounts[plat] = account
|
||||
|
||||
bot_account = bot_accounts.get(platform)
|
||||
return bool(bot_account and user_id == bot_account)
|
||||
212
src/bw_learner/message_recorder.py
Normal file
212
src/bw_learner/message_recorder.py
Normal file
@@ -0,0 +1,212 @@
|
||||
import time
|
||||
import asyncio
|
||||
from typing import List, Any
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive
|
||||
from src.bw_learner.expression_learner import expression_learner_manager
|
||||
from src.bw_learner.jargon_miner import miner_manager
|
||||
|
||||
logger = get_logger("bw_learner")
|
||||
|
||||
|
||||
class MessageRecorder:
|
||||
"""
|
||||
统一的消息记录器,负责管理时间窗口和消息提取,并将消息分发给 expression_learner 和 jargon_miner
|
||||
"""
|
||||
|
||||
def __init__(self, chat_id: str) -> None:
|
||||
self.chat_id = chat_id
|
||||
self.chat_stream = get_chat_manager().get_stream(chat_id)
|
||||
self.chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id
|
||||
|
||||
# 维护每个chat的上次提取时间
|
||||
self.last_extraction_time: float = time.time()
|
||||
|
||||
# 提取锁,防止并发执行
|
||||
self._extraction_lock = asyncio.Lock()
|
||||
|
||||
# 获取 expression 和 jargon 的配置参数
|
||||
self._init_parameters()
|
||||
|
||||
# 获取 expression_learner 和 jargon_miner 实例
|
||||
self.expression_learner = expression_learner_manager.get_expression_learner(chat_id)
|
||||
self.jargon_miner = miner_manager.get_miner(chat_id)
|
||||
|
||||
def _init_parameters(self) -> None:
|
||||
"""初始化提取参数"""
|
||||
# 获取 expression 配置
|
||||
_, self.enable_expression_learning, self.enable_jargon_learning = (
|
||||
global_config.expression.get_expression_config_for_chat(self.chat_id)
|
||||
)
|
||||
self.min_messages_for_extraction = 30
|
||||
self.min_extraction_interval = 60
|
||||
|
||||
logger.debug(
|
||||
f"MessageRecorder 初始化: chat_id={self.chat_id}, "
|
||||
f"min_messages={self.min_messages_for_extraction}, "
|
||||
f"min_interval={self.min_extraction_interval}"
|
||||
)
|
||||
|
||||
def should_trigger_extraction(self) -> bool:
|
||||
"""
|
||||
检查是否应该触发消息提取
|
||||
|
||||
Returns:
|
||||
bool: 是否应该触发提取
|
||||
"""
|
||||
# 检查时间间隔
|
||||
time_diff = time.time() - self.last_extraction_time
|
||||
if time_diff < self.min_extraction_interval:
|
||||
return False
|
||||
|
||||
# 检查消息数量
|
||||
recent_messages = get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=self.last_extraction_time,
|
||||
timestamp_end=time.time(),
|
||||
)
|
||||
|
||||
if not recent_messages or len(recent_messages) < self.min_messages_for_extraction:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def extract_and_distribute(self) -> None:
|
||||
"""
|
||||
提取消息并分发给 expression_learner 和 jargon_miner
|
||||
"""
|
||||
# 使用异步锁防止并发执行
|
||||
async with self._extraction_lock:
|
||||
# 在锁内检查,避免并发触发
|
||||
if not self.should_trigger_extraction():
|
||||
return
|
||||
|
||||
# 检查 chat_stream 是否存在
|
||||
if not self.chat_stream:
|
||||
return
|
||||
|
||||
# 记录本次提取的时间窗口,避免重复提取
|
||||
extraction_start_time = self.last_extraction_time
|
||||
extraction_end_time = time.time()
|
||||
|
||||
# 立即更新提取时间,防止并发触发
|
||||
self.last_extraction_time = extraction_end_time
|
||||
|
||||
try:
|
||||
# logger.info(f"在聊天流 {self.chat_name} 开始统一消息提取和分发")
|
||||
|
||||
# 拉取提取窗口内的消息
|
||||
messages = get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=extraction_start_time,
|
||||
timestamp_end=extraction_end_time,
|
||||
)
|
||||
|
||||
if not messages:
|
||||
logger.debug(f"聊天流 {self.chat_name} 没有新消息,跳过提取")
|
||||
return
|
||||
|
||||
# 按时间排序,确保顺序一致
|
||||
messages = sorted(messages, key=lambda msg: msg.time or 0)
|
||||
|
||||
logger.info(
|
||||
f"聊天流 {self.chat_name} 提取到 {len(messages)} 条消息,"
|
||||
f"时间窗口: {extraction_start_time:.2f} - {extraction_end_time:.2f}"
|
||||
)
|
||||
|
||||
# 分别触发 expression_learner 和 jargon_miner 的处理
|
||||
# 传递提取的消息,避免它们重复获取
|
||||
# 触发 expression 学习(如果启用)
|
||||
if self.enable_expression_learning:
|
||||
asyncio.create_task(
|
||||
self._trigger_expression_learning(extraction_start_time, extraction_end_time, messages)
|
||||
)
|
||||
|
||||
# 触发 jargon 提取(如果启用),传递消息
|
||||
# if self.enable_jargon_learning:
|
||||
# asyncio.create_task(
|
||||
# self._trigger_jargon_extraction(extraction_start_time, extraction_end_time, messages)
|
||||
# )
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"为聊天流 {self.chat_name} 提取和分发消息失败: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
# 即使失败也保持时间戳更新,避免频繁重试
|
||||
|
||||
async def _trigger_expression_learning(
|
||||
self, timestamp_start: float, timestamp_end: float, messages: List[Any]
|
||||
) -> None:
|
||||
"""
|
||||
触发 expression 学习,使用指定的消息列表
|
||||
|
||||
Args:
|
||||
timestamp_start: 开始时间戳
|
||||
timestamp_end: 结束时间戳
|
||||
messages: 消息列表
|
||||
"""
|
||||
try:
|
||||
# 传递消息给 ExpressionLearner(必需参数)
|
||||
learnt_style = await self.expression_learner.learn_and_store(messages=messages)
|
||||
|
||||
if learnt_style:
|
||||
logger.info(f"聊天流 {self.chat_name} 表达学习完成")
|
||||
else:
|
||||
logger.debug(f"聊天流 {self.chat_name} 表达学习未获得有效结果")
|
||||
except Exception as e:
|
||||
logger.error(f"为聊天流 {self.chat_name} 触发表达学习失败: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
async def _trigger_jargon_extraction(
|
||||
self, timestamp_start: float, timestamp_end: float, messages: List[Any]
|
||||
) -> None:
|
||||
"""
|
||||
触发 jargon 提取,使用指定的消息列表
|
||||
|
||||
Args:
|
||||
timestamp_start: 开始时间戳
|
||||
timestamp_end: 结束时间戳
|
||||
messages: 消息列表
|
||||
"""
|
||||
try:
|
||||
# 传递消息给 JargonMiner,避免它重复获取
|
||||
await self.jargon_miner.run_once(messages=messages)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"为聊天流 {self.chat_name} 触发黑话提取失败: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
class MessageRecorderManager:
|
||||
"""MessageRecorder 管理器"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._recorders: dict[str, MessageRecorder] = {}
|
||||
|
||||
def get_recorder(self, chat_id: str) -> MessageRecorder:
|
||||
"""获取或创建指定 chat_id 的 MessageRecorder"""
|
||||
if chat_id not in self._recorders:
|
||||
self._recorders[chat_id] = MessageRecorder(chat_id)
|
||||
return self._recorders[chat_id]
|
||||
|
||||
|
||||
# 全局管理器实例
|
||||
recorder_manager = MessageRecorderManager()
|
||||
|
||||
|
||||
async def extract_and_distribute_messages(chat_id: str) -> None:
|
||||
"""
|
||||
统一的消息提取和分发入口函数
|
||||
|
||||
Args:
|
||||
chat_id: 聊天流ID
|
||||
"""
|
||||
recorder = recorder_manager.get_recorder(chat_id)
|
||||
await recorder.extract_and_distribute()
|
||||
491
src/chat/brain_chat/PFC/action_planner.py
Normal file
491
src/chat/brain_chat/PFC/action_planner.py
Normal file
@@ -0,0 +1,491 @@
|
||||
import time
|
||||
from typing import Tuple, Optional # 增加了 Optional
|
||||
from src.common.logger_manager import get_logger
|
||||
from ..models.utils_model import LLMRequest
|
||||
from ...config.config import global_config
|
||||
from .chat_observer import ChatObserver
|
||||
from .pfc_utils import get_items_from_json
|
||||
from src.individuality.individuality import Individuality
|
||||
from .observation_info import ObservationInfo
|
||||
from .conversation_info import ConversationInfo
|
||||
from src.plugins.utils.chat_message_builder import build_readable_messages
|
||||
|
||||
|
||||
logger = get_logger("pfc_action_planner")
|
||||
|
||||
|
||||
# --- 定义 Prompt 模板 ---
|
||||
|
||||
# Prompt(1): 首次回复或非连续回复时的决策 Prompt
|
||||
PROMPT_INITIAL_REPLY = """{persona_text}。现在你在参与一场QQ私聊,请根据以下【所有信息】审慎且灵活的决策下一步行动,可以回复,可以倾听,可以调取知识,甚至可以屏蔽对方:
|
||||
|
||||
【当前对话目标】
|
||||
{goals_str}
|
||||
{knowledge_info_str}
|
||||
|
||||
【最近行动历史概要】
|
||||
{action_history_summary}
|
||||
【上一次行动的详细情况和结果】
|
||||
{last_action_context}
|
||||
【时间和超时提示】
|
||||
{time_since_last_bot_message_info}{timeout_context}
|
||||
【最近的对话记录】(包括你已成功发送的消息 和 新收到的消息)
|
||||
{chat_history_text}
|
||||
|
||||
------
|
||||
可选行动类型以及解释:
|
||||
fetch_knowledge: 需要调取知识或记忆,当需要专业知识或特定信息时选择,对方若提到你不太认识的人名或实体也可以尝试选择
|
||||
listening: 倾听对方发言,当你认为对方话才说到一半,发言明显未结束时选择
|
||||
direct_reply: 直接回复对方
|
||||
rethink_goal: 思考一个对话目标,当你觉得目前对话需要目标,或当前目标不再适用,或话题卡住时选择。注意私聊的环境是灵活的,有可能需要经常选择
|
||||
end_conversation: 结束对话,对方长时间没回复或者当你觉得对话告一段落时可以选择
|
||||
block_and_ignore: 更加极端的结束对话方式,直接结束对话并在一段时间内无视对方所有发言(屏蔽),当对话让你感到十分不适,或你遭到各类骚扰时选择
|
||||
|
||||
请以JSON格式输出你的决策:
|
||||
{{
|
||||
"action": "选择的行动类型 (必须是上面列表中的一个)",
|
||||
"reason": "选择该行动的详细原因 (必须有解释你是如何根据“上一次行动结果”、“对话记录”和自身设定人设做出合理判断的)"
|
||||
}}
|
||||
|
||||
注意:请严格按照JSON格式输出,不要包含任何其他内容。"""
|
||||
|
||||
# Prompt(2): 上一次成功回复后,决定继续发言时的决策 Prompt
|
||||
PROMPT_FOLLOW_UP = """{persona_text}。现在你在参与一场QQ私聊,刚刚你已经回复了对方,请根据以下【所有信息】审慎且灵活的决策下一步行动,可以继续发送新消息,可以等待,可以倾听,可以调取知识,甚至可以屏蔽对方:
|
||||
|
||||
【当前对话目标】
|
||||
{goals_str}
|
||||
{knowledge_info_str}
|
||||
|
||||
【最近行动历史概要】
|
||||
{action_history_summary}
|
||||
【上一次行动的详细情况和结果】
|
||||
{last_action_context}
|
||||
【时间和超时提示】
|
||||
{time_since_last_bot_message_info}{timeout_context}
|
||||
【最近的对话记录】(包括你已成功发送的消息 和 新收到的消息)
|
||||
{chat_history_text}
|
||||
|
||||
------
|
||||
可选行动类型以及解释:
|
||||
fetch_knowledge: 需要调取知识,当需要专业知识或特定信息时选择,对方若提到你不太认识的人名或实体也可以尝试选择
|
||||
wait: 暂时不说话,留给对方交互空间,等待对方回复(尤其是在你刚发言后、或上次发言因重复、发言过多被拒时、或不确定做什么时,这是不错的选择)
|
||||
listening: 倾听对方发言(虽然你刚发过言,但如果对方立刻回复且明显话没说完,可以选择这个)
|
||||
send_new_message: 发送一条新消息继续对话,允许适当的追问、补充、深入话题,或开启相关新话题。**但是避免在因重复被拒后立即使用,也不要在对方没有回复的情况下过多的“消息轰炸”或重复发言**
|
||||
rethink_goal: 思考一个对话目标,当你觉得目前对话需要目标,或当前目标不再适用,或话题卡住时选择。注意私聊的环境是灵活的,有可能需要经常选择
|
||||
end_conversation: 结束对话,对方长时间没回复或者当你觉得对话告一段落时可以选择
|
||||
block_and_ignore: 更加极端的结束对话方式,直接结束对话并在一段时间内无视对方所有发言(屏蔽),当对话让你感到十分不适,或你遭到各类骚扰时选择
|
||||
|
||||
请以JSON格式输出你的决策:
|
||||
{{
|
||||
"action": "选择的行动类型 (必须是上面列表中的一个)",
|
||||
"reason": "选择该行动的详细原因 (必须有解释你是如何根据“上一次行动结果”、“对话记录”和自身设定人设做出合理判断的。请说明你为什么选择继续发言而不是等待,以及打算发送什么类型的新消息连续发言,必须记录已经发言了几次)"
|
||||
}}
|
||||
|
||||
注意:请严格按照JSON格式输出,不要包含任何其他内容。"""
|
||||
|
||||
# 新增:Prompt(3): 决定是否在结束对话前发送告别语
|
||||
PROMPT_END_DECISION = """{persona_text}。刚刚你决定结束一场 QQ 私聊。
|
||||
|
||||
【你们之前的聊天记录】
|
||||
{chat_history_text}
|
||||
|
||||
你觉得你们的对话已经完整结束了吗?有时候,在对话自然结束后再说点什么可能会有点奇怪,但有时也可能需要一条简短的消息来圆满结束。
|
||||
如果觉得确实有必要再发一条简短、自然、符合你人设的告别消息(比如 "好,下次再聊~" 或 "嗯,先这样吧"),就输出 "yes"。
|
||||
如果觉得当前状态下直接结束对话更好,没有必要再发消息,就输出 "no"。
|
||||
|
||||
请以 JSON 格式输出你的选择:
|
||||
{{
|
||||
"say_bye": "yes/no",
|
||||
"reason": "选择 yes 或 no 的原因和内心想法 (简要说明)"
|
||||
}}
|
||||
|
||||
注意:请严格按照 JSON 格式输出,不要包含任何其他内容。"""
|
||||
|
||||
|
||||
# ActionPlanner 类定义,顶格
|
||||
class ActionPlanner:
|
||||
"""行动规划器"""
|
||||
|
||||
def __init__(self, stream_id: str, private_name: str):
|
||||
self.llm = LLMRequest(
|
||||
model=global_config.llm_PFC_action_planner,
|
||||
temperature=global_config.llm_PFC_action_planner["temp"],
|
||||
max_tokens=1500,
|
||||
request_type="action_planning",
|
||||
)
|
||||
self.personality_info = Individuality.get_instance().get_prompt(x_person=2, level=3)
|
||||
self.name = global_config.BOT_NICKNAME
|
||||
self.private_name = private_name
|
||||
self.chat_observer = ChatObserver.get_instance(stream_id, private_name)
|
||||
# self.action_planner_info = ActionPlannerInfo() # 移除未使用的变量
|
||||
|
||||
# 修改 plan 方法签名,增加 last_successful_reply_action 参数
|
||||
async def plan(
|
||||
self,
|
||||
observation_info: ObservationInfo,
|
||||
conversation_info: ConversationInfo,
|
||||
last_successful_reply_action: Optional[str],
|
||||
) -> Tuple[str, str]:
|
||||
"""规划下一步行动
|
||||
|
||||
Args:
|
||||
observation_info: 决策信息
|
||||
conversation_info: 对话信息
|
||||
last_successful_reply_action: 上一次成功的回复动作类型 ('direct_reply' 或 'send_new_message' 或 None)
|
||||
|
||||
Returns:
|
||||
Tuple[str, str]: (行动类型, 行动原因)
|
||||
"""
|
||||
# --- 获取 Bot 上次发言时间信息 ---
|
||||
# (这部分逻辑不变)
|
||||
time_since_last_bot_message_info = ""
|
||||
try:
|
||||
bot_id = str(global_config.BOT_QQ)
|
||||
if hasattr(observation_info, "chat_history") and observation_info.chat_history:
|
||||
for i in range(len(observation_info.chat_history) - 1, -1, -1):
|
||||
msg = observation_info.chat_history[i]
|
||||
if not isinstance(msg, dict):
|
||||
continue
|
||||
sender_info = msg.get("user_info", {})
|
||||
sender_id = str(sender_info.get("user_id")) if isinstance(sender_info, dict) else None
|
||||
msg_time = msg.get("time")
|
||||
if sender_id == bot_id and msg_time:
|
||||
time_diff = time.time() - msg_time
|
||||
if time_diff < 60.0:
|
||||
time_since_last_bot_message_info = (
|
||||
f"提示:你上一条成功发送的消息是在 {time_diff:.1f} 秒前。\n"
|
||||
)
|
||||
break
|
||||
else:
|
||||
logger.debug(
|
||||
f"[私聊][{self.private_name}]Observation info chat history is empty or not available for bot time check."
|
||||
)
|
||||
except AttributeError:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]ObservationInfo object might not have chat_history attribute yet for bot time check."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[私聊][{self.private_name}]获取 Bot 上次发言时间时出错: {e}")
|
||||
|
||||
# --- 获取超时提示信息 ---
|
||||
# (这部分逻辑不变)
|
||||
timeout_context = ""
|
||||
try:
|
||||
if hasattr(conversation_info, "goal_list") and conversation_info.goal_list:
|
||||
last_goal_dict = conversation_info.goal_list[-1]
|
||||
if isinstance(last_goal_dict, dict) and "goal" in last_goal_dict:
|
||||
last_goal_text = last_goal_dict["goal"]
|
||||
if isinstance(last_goal_text, str) and "分钟,思考接下来要做什么" in last_goal_text:
|
||||
try:
|
||||
timeout_minutes_text = last_goal_text.split(",")[0].replace("你等待了", "")
|
||||
timeout_context = f"重要提示:对方已经长时间({timeout_minutes_text})没有回复你的消息了(这可能代表对方繁忙/不想回复/没注意到你的消息等情况,或在对方看来本次聊天已告一段落),请基于此情况规划下一步。\n"
|
||||
except Exception:
|
||||
timeout_context = "重要提示:对方已经长时间没有回复你的消息了(这可能代表对方繁忙/不想回复/没注意到你的消息等情况,或在对方看来本次聊天已告一段落),请基于此情况规划下一步。\n"
|
||||
else:
|
||||
logger.debug(
|
||||
f"[私聊][{self.private_name}]Conversation info goal_list is empty or not available for timeout check."
|
||||
)
|
||||
except AttributeError:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]ConversationInfo object might not have goal_list attribute yet for timeout check."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[私聊][{self.private_name}]检查超时目标时出错: {e}")
|
||||
|
||||
# --- 构建通用 Prompt 参数 ---
|
||||
logger.debug(
|
||||
f"[私聊][{self.private_name}]开始规划行动:当前目标: {getattr(conversation_info, 'goal_list', '不可用')}"
|
||||
)
|
||||
|
||||
# 构建对话目标 (goals_str)
|
||||
goals_str = ""
|
||||
try:
|
||||
if hasattr(conversation_info, "goal_list") and conversation_info.goal_list:
|
||||
for goal_reason in conversation_info.goal_list:
|
||||
if isinstance(goal_reason, dict):
|
||||
goal = goal_reason.get("goal", "目标内容缺失")
|
||||
reasoning = goal_reason.get("reasoning", "没有明确原因")
|
||||
else:
|
||||
goal = str(goal_reason)
|
||||
reasoning = "没有明确原因"
|
||||
|
||||
goal = str(goal) if goal is not None else "目标内容缺失"
|
||||
reasoning = str(reasoning) if reasoning is not None else "没有明确原因"
|
||||
goals_str += f"- 目标:{goal}\n 原因:{reasoning}\n"
|
||||
|
||||
if not goals_str:
|
||||
goals_str = "- 目前没有明确对话目标,请考虑设定一个。\n"
|
||||
else:
|
||||
goals_str = "- 目前没有明确对话目标,请考虑设定一个。\n"
|
||||
except AttributeError:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]ConversationInfo object might not have goal_list attribute yet."
|
||||
)
|
||||
goals_str = "- 获取对话目标时出错。\n"
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]构建对话目标字符串时出错: {e}")
|
||||
goals_str = "- 构建对话目标时出错。\n"
|
||||
|
||||
# --- 知识信息字符串构建开始 ---
|
||||
knowledge_info_str = "【已获取的相关知识和记忆】\n"
|
||||
try:
|
||||
# 检查 conversation_info 是否有 knowledge_list 并且不为空
|
||||
if hasattr(conversation_info, "knowledge_list") and conversation_info.knowledge_list:
|
||||
# 最多只显示最近的 5 条知识,防止 Prompt 过长
|
||||
recent_knowledge = conversation_info.knowledge_list[-5:]
|
||||
for i, knowledge_item in enumerate(recent_knowledge):
|
||||
if isinstance(knowledge_item, dict):
|
||||
query = knowledge_item.get("query", "未知查询")
|
||||
knowledge = knowledge_item.get("knowledge", "无知识内容")
|
||||
source = knowledge_item.get("source", "未知来源")
|
||||
# 只取知识内容的前 2000 个字,避免太长
|
||||
knowledge_snippet = knowledge[:2000] + "..." if len(knowledge) > 2000 else knowledge
|
||||
knowledge_info_str += (
|
||||
f"{i + 1}. 关于 '{query}' 的知识 (来源: {source}):\n {knowledge_snippet}\n"
|
||||
)
|
||||
else:
|
||||
# 处理列表里不是字典的异常情况
|
||||
knowledge_info_str += f"{i + 1}. 发现一条格式不正确的知识记录。\n"
|
||||
|
||||
if not recent_knowledge: # 如果 knowledge_list 存在但为空
|
||||
knowledge_info_str += "- 暂无相关知识和记忆。\n"
|
||||
|
||||
else:
|
||||
# 如果 conversation_info 没有 knowledge_list 属性,或者列表为空
|
||||
knowledge_info_str += "- 暂无相关知识记忆。\n"
|
||||
except AttributeError:
|
||||
logger.warning(f"[私聊][{self.private_name}]ConversationInfo 对象可能缺少 knowledge_list 属性。")
|
||||
knowledge_info_str += "- 获取知识列表时出错。\n"
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]构建知识信息字符串时出错: {e}")
|
||||
knowledge_info_str += "- 处理知识列表时出错。\n"
|
||||
# --- 知识信息字符串构建结束 ---
|
||||
|
||||
# 获取聊天历史记录 (chat_history_text)
|
||||
try:
|
||||
if hasattr(observation_info, "chat_history") and observation_info.chat_history:
|
||||
chat_history_text = observation_info.chat_history_str
|
||||
if not chat_history_text:
|
||||
chat_history_text = "还没有聊天记录。\n"
|
||||
else:
|
||||
chat_history_text = "还没有聊天记录。\n"
|
||||
|
||||
if hasattr(observation_info, "new_messages_count") and observation_info.new_messages_count > 0:
|
||||
if hasattr(observation_info, "unprocessed_messages") and observation_info.unprocessed_messages:
|
||||
new_messages_list = observation_info.unprocessed_messages
|
||||
new_messages_str = await build_readable_messages(
|
||||
new_messages_list,
|
||||
replace_bot_name=True,
|
||||
merge_messages=False,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
)
|
||||
chat_history_text += (
|
||||
f"\n--- 以下是 {observation_info.new_messages_count} 条新消息 ---\n{new_messages_str}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]ObservationInfo has new_messages_count > 0 but unprocessed_messages is empty or missing."
|
||||
)
|
||||
except AttributeError:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]ObservationInfo object might be missing expected attributes for chat history."
|
||||
)
|
||||
chat_history_text = "获取聊天记录时出错。\n"
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]处理聊天记录时发生未知错误: {e}")
|
||||
chat_history_text = "处理聊天记录时出错。\n"
|
||||
|
||||
# 构建 Persona 文本 (persona_text)
|
||||
persona_text = f"你的名字是{self.name},{self.personality_info}。"
|
||||
|
||||
# 构建行动历史和上一次行动结果 (action_history_summary, last_action_context)
|
||||
# (这部分逻辑不变)
|
||||
action_history_summary = "你最近执行的行动历史:\n"
|
||||
last_action_context = "关于你【上一次尝试】的行动:\n"
|
||||
action_history_list = []
|
||||
try:
|
||||
if hasattr(conversation_info, "done_action") and conversation_info.done_action:
|
||||
action_history_list = conversation_info.done_action[-5:]
|
||||
else:
|
||||
logger.debug(f"[私聊][{self.private_name}]Conversation info done_action is empty or not available.")
|
||||
except AttributeError:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]ConversationInfo object might not have done_action attribute yet."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]访问行动历史时出错: {e}")
|
||||
|
||||
if not action_history_list:
|
||||
action_history_summary += "- 还没有执行过行动。\n"
|
||||
last_action_context += "- 这是你规划的第一个行动。\n"
|
||||
else:
|
||||
for i, action_data in enumerate(action_history_list):
|
||||
action_type = "未知"
|
||||
plan_reason = "未知"
|
||||
status = "未知"
|
||||
final_reason = ""
|
||||
action_time = ""
|
||||
|
||||
if isinstance(action_data, dict):
|
||||
action_type = action_data.get("action", "未知")
|
||||
plan_reason = action_data.get("plan_reason", "未知规划原因")
|
||||
status = action_data.get("status", "未知")
|
||||
final_reason = action_data.get("final_reason", "")
|
||||
action_time = action_data.get("time", "")
|
||||
elif isinstance(action_data, tuple):
|
||||
# 假设旧格式兼容
|
||||
if len(action_data) > 0:
|
||||
action_type = action_data[0]
|
||||
if len(action_data) > 1:
|
||||
plan_reason = action_data[1] # 可能是规划原因或最终原因
|
||||
if len(action_data) > 2:
|
||||
status = action_data[2]
|
||||
if status == "recall" and len(action_data) > 3:
|
||||
final_reason = action_data[3]
|
||||
elif status == "done" and action_type in ["direct_reply", "send_new_message"]:
|
||||
plan_reason = "成功发送" # 简化显示
|
||||
|
||||
reason_text = f", 失败/取消原因: {final_reason}" if final_reason else ""
|
||||
summary_line = f"- 时间:{action_time}, 尝试行动:'{action_type}', 状态:{status}{reason_text}"
|
||||
action_history_summary += summary_line + "\n"
|
||||
|
||||
if i == len(action_history_list) - 1:
|
||||
last_action_context += f"- 上次【规划】的行动是: '{action_type}'\n"
|
||||
last_action_context += f"- 当时规划的【原因】是: {plan_reason}\n"
|
||||
if status == "done":
|
||||
last_action_context += "- 该行动已【成功执行】。\n"
|
||||
# 记录这次成功的行动类型,供下次决策
|
||||
# self.last_successful_action_type = action_type # 不在这里记录,由 conversation 控制
|
||||
elif status == "recall":
|
||||
last_action_context += "- 但该行动最终【未能执行/被取消】。\n"
|
||||
if final_reason:
|
||||
last_action_context += f"- 【重要】失败/取消的具体原因是: “{final_reason}”\n"
|
||||
else:
|
||||
last_action_context += "- 【重要】失败/取消原因未明确记录。\n"
|
||||
# self.last_successful_action_type = None # 行动失败,清除记录
|
||||
else:
|
||||
last_action_context += f"- 该行动当前状态: {status}\n"
|
||||
# self.last_successful_action_type = None # 非完成状态,清除记录
|
||||
|
||||
# --- 选择 Prompt ---
|
||||
if last_successful_reply_action in ["direct_reply", "send_new_message"]:
|
||||
prompt_template = PROMPT_FOLLOW_UP
|
||||
logger.debug(f"[私聊][{self.private_name}]使用 PROMPT_FOLLOW_UP (追问决策)")
|
||||
else:
|
||||
prompt_template = PROMPT_INITIAL_REPLY
|
||||
logger.debug(f"[私聊][{self.private_name}]使用 PROMPT_INITIAL_REPLY (首次/非连续回复决策)")
|
||||
|
||||
# --- 格式化最终的 Prompt ---
|
||||
prompt = prompt_template.format(
|
||||
persona_text=persona_text,
|
||||
goals_str=goals_str if goals_str.strip() else "- 目前没有明确对话目标,请考虑设定一个。",
|
||||
action_history_summary=action_history_summary,
|
||||
last_action_context=last_action_context,
|
||||
time_since_last_bot_message_info=time_since_last_bot_message_info,
|
||||
timeout_context=timeout_context,
|
||||
chat_history_text=chat_history_text if chat_history_text.strip() else "还没有聊天记录。",
|
||||
knowledge_info_str=knowledge_info_str,
|
||||
)
|
||||
|
||||
logger.debug(f"[私聊][{self.private_name}]发送到LLM的最终提示词:\n------\n{prompt}\n------")
|
||||
try:
|
||||
content, _ = await self.llm.generate_response_async(prompt)
|
||||
logger.debug(f"[私聊][{self.private_name}]LLM (行动规划) 原始返回内容: {content}")
|
||||
|
||||
# --- 初始行动规划解析 ---
|
||||
success, initial_result = get_items_from_json(
|
||||
content,
|
||||
self.private_name,
|
||||
"action",
|
||||
"reason",
|
||||
default_values={"action": "wait", "reason": "LLM返回格式错误或未提供原因,默认等待"},
|
||||
)
|
||||
|
||||
initial_action = initial_result.get("action", "wait")
|
||||
initial_reason = initial_result.get("reason", "LLM未提供原因,默认等待")
|
||||
|
||||
# 检查是否需要进行结束对话决策 ---
|
||||
if initial_action == "end_conversation":
|
||||
logger.info(f"[私聊][{self.private_name}]初步规划结束对话,进入告别决策...")
|
||||
|
||||
# 使用新的 PROMPT_END_DECISION
|
||||
end_decision_prompt = PROMPT_END_DECISION.format(
|
||||
persona_text=persona_text, # 复用之前的 persona_text
|
||||
chat_history_text=chat_history_text, # 复用之前的 chat_history_text
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"[私聊][{self.private_name}]发送到LLM的结束决策提示词:\n------\n{end_decision_prompt}\n------"
|
||||
)
|
||||
try:
|
||||
end_content, _ = await self.llm.generate_response_async(end_decision_prompt) # 再次调用LLM
|
||||
logger.debug(f"[私聊][{self.private_name}]LLM (结束决策) 原始返回内容: {end_content}")
|
||||
|
||||
# 解析结束决策的JSON
|
||||
end_success, end_result = get_items_from_json(
|
||||
end_content,
|
||||
self.private_name,
|
||||
"say_bye",
|
||||
"reason",
|
||||
default_values={"say_bye": "no", "reason": "结束决策LLM返回格式错误,默认不告别"},
|
||||
required_types={"say_bye": str, "reason": str}, # 明确类型
|
||||
)
|
||||
|
||||
say_bye_decision = end_result.get("say_bye", "no").lower() # 转小写方便比较
|
||||
end_decision_reason = end_result.get("reason", "未提供原因")
|
||||
|
||||
if end_success and say_bye_decision == "yes":
|
||||
# 决定要告别,返回新的 'say_goodbye' 动作
|
||||
logger.info(
|
||||
f"[私聊][{self.private_name}]结束决策: yes, 准备生成告别语. 原因: {end_decision_reason}"
|
||||
)
|
||||
# 注意:这里的 reason 可以考虑拼接初始原因和结束决策原因,或者只用结束决策原因
|
||||
final_action = "say_goodbye"
|
||||
final_reason = f"决定发送告别语。决策原因: {end_decision_reason} (原结束理由: {initial_reason})"
|
||||
return final_action, final_reason
|
||||
else:
|
||||
# 决定不告别 (包括解析失败或明确说no)
|
||||
logger.info(
|
||||
f"[私聊][{self.private_name}]结束决策: no, 直接结束对话. 原因: {end_decision_reason}"
|
||||
)
|
||||
# 返回原始的 'end_conversation' 动作
|
||||
final_action = "end_conversation"
|
||||
final_reason = initial_reason # 保持原始的结束理由
|
||||
return final_action, final_reason
|
||||
|
||||
except Exception as end_e:
|
||||
logger.error(f"[私聊][{self.private_name}]调用结束决策LLM或处理结果时出错: {str(end_e)}")
|
||||
# 出错时,默认执行原始的结束对话
|
||||
logger.warning(f"[私聊][{self.private_name}]结束决策出错,将按原计划执行 end_conversation")
|
||||
return "end_conversation", initial_reason # 返回原始动作和原因
|
||||
|
||||
else:
|
||||
action = initial_action
|
||||
reason = initial_reason
|
||||
|
||||
# 验证action类型 (保持不变)
|
||||
valid_actions = [
|
||||
"direct_reply",
|
||||
"send_new_message",
|
||||
"fetch_knowledge",
|
||||
"wait",
|
||||
"listening",
|
||||
"rethink_goal",
|
||||
"end_conversation", # 仍然需要验证,因为可能从上面决策后返回
|
||||
"block_and_ignore",
|
||||
"say_goodbye", # 也要验证这个新动作
|
||||
]
|
||||
if action not in valid_actions:
|
||||
logger.warning(f"[私聊][{self.private_name}]LLM返回了未知的行动类型: '{action}',强制改为 wait")
|
||||
reason = f"(原始行动'{action}'无效,已强制改为wait) {reason}"
|
||||
action = "wait"
|
||||
|
||||
logger.info(f"[私聊][{self.private_name}]规划的行动: {action}")
|
||||
logger.info(f"[私聊][{self.private_name}]行动原因: {reason}")
|
||||
return action, reason
|
||||
|
||||
except Exception as e:
|
||||
# 外层异常处理保持不变
|
||||
logger.error(f"[私聊][{self.private_name}]规划行动时调用 LLM 或处理结果出错: {str(e)}")
|
||||
return "wait", f"行动规划处理中发生错误,暂时等待: {str(e)}"
|
||||
379
src/chat/brain_chat/PFC/chat_observer.py
Normal file
379
src/chat/brain_chat/PFC/chat_observer.py
Normal file
@@ -0,0 +1,379 @@
|
||||
import time
|
||||
import asyncio
|
||||
import traceback
|
||||
from typing import Optional, Dict, Any, List
|
||||
from src.common.logger import get_module_logger
|
||||
from maim_message import UserInfo
|
||||
from ...config.config import global_config
|
||||
from .chat_states import NotificationManager, create_new_message_notification, create_cold_chat_notification
|
||||
from .message_storage import MongoDBMessageStorage
|
||||
from rich.traceback import install
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
logger = get_module_logger("chat_observer")
|
||||
|
||||
|
||||
class ChatObserver:
|
||||
"""聊天状态观察器"""
|
||||
|
||||
# 类级别的实例管理
|
||||
_instances: Dict[str, "ChatObserver"] = {}
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls, stream_id: str, private_name: str) -> "ChatObserver":
|
||||
"""获取或创建观察器实例
|
||||
|
||||
Args:
|
||||
stream_id: 聊天流ID
|
||||
private_name: 私聊名称
|
||||
|
||||
Returns:
|
||||
ChatObserver: 观察器实例
|
||||
"""
|
||||
if stream_id not in cls._instances:
|
||||
cls._instances[stream_id] = cls(stream_id, private_name)
|
||||
return cls._instances[stream_id]
|
||||
|
||||
def __init__(self, stream_id: str, private_name: str):
|
||||
"""初始化观察器
|
||||
|
||||
Args:
|
||||
stream_id: 聊天流ID
|
||||
"""
|
||||
self.last_check_time = None
|
||||
self.last_bot_speak_time = None
|
||||
self.last_user_speak_time = None
|
||||
if stream_id in self._instances:
|
||||
raise RuntimeError(f"ChatObserver for {stream_id} already exists. Use get_instance() instead.")
|
||||
|
||||
self.stream_id = stream_id
|
||||
self.private_name = private_name
|
||||
self.message_storage = MongoDBMessageStorage()
|
||||
|
||||
# self.last_user_speak_time: Optional[float] = None # 对方上次发言时间
|
||||
# self.last_bot_speak_time: Optional[float] = None # 机器人上次发言时间
|
||||
# self.last_check_time: float = time.time() # 上次查看聊天记录时间
|
||||
self.last_message_read: Optional[Dict[str, Any]] = None # 最后读取的消息ID
|
||||
self.last_message_time: float = time.time()
|
||||
|
||||
self.waiting_start_time: float = time.time() # 等待开始时间,初始化为当前时间
|
||||
|
||||
# 运行状态
|
||||
self._running: bool = False
|
||||
self._task: Optional[asyncio.Task] = None
|
||||
self._update_event = asyncio.Event() # 触发更新的事件
|
||||
self._update_complete = asyncio.Event() # 更新完成的事件
|
||||
|
||||
# 通知管理器
|
||||
self.notification_manager = NotificationManager()
|
||||
|
||||
# 冷场检查配置
|
||||
self.cold_chat_threshold: float = 60.0 # 60秒无消息判定为冷场
|
||||
self.last_cold_chat_check: float = time.time()
|
||||
self.is_cold_chat_state: bool = False
|
||||
|
||||
self.update_event = asyncio.Event()
|
||||
self.update_interval = 2 # 更新间隔(秒)
|
||||
self.message_cache = []
|
||||
self.update_running = False
|
||||
|
||||
async def check(self) -> bool:
|
||||
"""检查距离上一次观察之后是否有了新消息
|
||||
|
||||
Returns:
|
||||
bool: 是否有新消息
|
||||
"""
|
||||
logger.debug(f"[私聊][{self.private_name}]检查距离上一次观察之后是否有了新消息: {self.last_check_time}")
|
||||
|
||||
new_message_exists = await self.message_storage.has_new_messages(self.stream_id, self.last_check_time)
|
||||
|
||||
if new_message_exists:
|
||||
logger.debug(f"[私聊][{self.private_name}]发现新消息")
|
||||
self.last_check_time = time.time()
|
||||
|
||||
return new_message_exists
|
||||
|
||||
async def _add_message_to_history(self, message: Dict[str, Any]):
|
||||
"""添加消息到历史记录并发送通知
|
||||
|
||||
Args:
|
||||
message: 消息数据
|
||||
"""
|
||||
try:
|
||||
# 发送新消息通知
|
||||
notification = create_new_message_notification(
|
||||
sender="chat_observer", target="observation_info", message=message
|
||||
)
|
||||
# print(self.notification_manager)
|
||||
await self.notification_manager.send_notification(notification)
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]添加消息到历史记录时出错: {e}")
|
||||
print(traceback.format_exc())
|
||||
|
||||
# 检查并更新冷场状态
|
||||
await self._check_cold_chat()
|
||||
|
||||
async def _check_cold_chat(self):
|
||||
"""检查是否处于冷场状态并发送通知"""
|
||||
current_time = time.time()
|
||||
|
||||
# 每10秒检查一次冷场状态
|
||||
if current_time - self.last_cold_chat_check < 10:
|
||||
return
|
||||
|
||||
self.last_cold_chat_check = current_time
|
||||
|
||||
# 判断是否冷场
|
||||
is_cold = (
|
||||
True
|
||||
if self.last_message_time is None
|
||||
else (current_time - self.last_message_time) > self.cold_chat_threshold
|
||||
)
|
||||
|
||||
# 如果冷场状态发生变化,发送通知
|
||||
if is_cold != self.is_cold_chat_state:
|
||||
self.is_cold_chat_state = is_cold
|
||||
notification = create_cold_chat_notification(sender="chat_observer", target="pfc", is_cold=is_cold)
|
||||
await self.notification_manager.send_notification(notification)
|
||||
|
||||
def new_message_after(self, time_point: float) -> bool:
|
||||
"""判断是否在指定时间点后有新消息
|
||||
|
||||
Args:
|
||||
time_point: 时间戳
|
||||
|
||||
Returns:
|
||||
bool: 是否有新消息
|
||||
"""
|
||||
|
||||
if self.last_message_time is None:
|
||||
logger.debug(f"[私聊][{self.private_name}]没有最后消息时间,返回 False")
|
||||
return False
|
||||
|
||||
has_new = self.last_message_time > time_point
|
||||
logger.debug(
|
||||
f"[私聊][{self.private_name}]判断是否在指定时间点后有新消息: {self.last_message_time} > {time_point} = {has_new}"
|
||||
)
|
||||
return has_new
|
||||
|
||||
def get_message_history(
|
||||
self,
|
||||
start_time: Optional[float] = None,
|
||||
end_time: Optional[float] = None,
|
||||
limit: Optional[int] = None,
|
||||
user_id: Optional[str] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""获取消息历史
|
||||
|
||||
Args:
|
||||
start_time: 开始时间戳
|
||||
end_time: 结束时间戳
|
||||
limit: 限制返回消息数量
|
||||
user_id: 指定用户ID
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 消息列表
|
||||
"""
|
||||
filtered_messages = self.message_history
|
||||
|
||||
if start_time is not None:
|
||||
filtered_messages = [m for m in filtered_messages if m["time"] >= start_time]
|
||||
|
||||
if end_time is not None:
|
||||
filtered_messages = [m for m in filtered_messages if m["time"] <= end_time]
|
||||
|
||||
if user_id is not None:
|
||||
filtered_messages = [
|
||||
m for m in filtered_messages if UserInfo.from_dict(m.get("user_info", {})).user_id == user_id
|
||||
]
|
||||
|
||||
if limit is not None:
|
||||
filtered_messages = filtered_messages[-limit:]
|
||||
|
||||
return filtered_messages
|
||||
|
||||
async def _fetch_new_messages(self) -> List[Dict[str, Any]]:
|
||||
"""获取新消息
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 新消息列表
|
||||
"""
|
||||
new_messages = await self.message_storage.get_messages_after(self.stream_id, self.last_message_time)
|
||||
|
||||
if new_messages:
|
||||
self.last_message_read = new_messages[-1]
|
||||
self.last_message_time = new_messages[-1]["time"]
|
||||
|
||||
# print(f"获取数据库中找到的新消息: {new_messages}")
|
||||
|
||||
return new_messages
|
||||
|
||||
async def _fetch_new_messages_before(self, time_point: float) -> List[Dict[str, Any]]:
|
||||
"""获取指定时间点之前的消息
|
||||
|
||||
Args:
|
||||
time_point: 时间戳
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 最多5条消息
|
||||
"""
|
||||
new_messages = await self.message_storage.get_messages_before(self.stream_id, time_point)
|
||||
|
||||
if new_messages:
|
||||
self.last_message_read = new_messages[-1]["message_id"]
|
||||
|
||||
logger.debug(f"[私聊][{self.private_name}]获取指定时间点111之前的消息: {new_messages}")
|
||||
|
||||
return new_messages
|
||||
|
||||
"""主要观察循环"""
|
||||
|
||||
async def _update_loop(self):
|
||||
"""更新循环"""
|
||||
# try:
|
||||
# start_time = time.time()
|
||||
# messages = await self._fetch_new_messages_before(start_time)
|
||||
# for message in messages:
|
||||
# await self._add_message_to_history(message)
|
||||
# logger.debug(f"[私聊][{self.private_name}]缓冲消息: {messages}")
|
||||
# except Exception as e:
|
||||
# logger.error(f"[私聊][{self.private_name}]缓冲消息出错: {e}")
|
||||
|
||||
while self._running:
|
||||
try:
|
||||
# 等待事件或超时(1秒)
|
||||
try:
|
||||
# print("等待事件")
|
||||
await asyncio.wait_for(self._update_event.wait(), timeout=1)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
# print("超时")
|
||||
pass # 超时后也执行一次检查
|
||||
|
||||
self._update_event.clear() # 重置触发事件
|
||||
self._update_complete.clear() # 重置完成事件
|
||||
|
||||
# 获取新消息
|
||||
new_messages = await self._fetch_new_messages()
|
||||
|
||||
if new_messages:
|
||||
# 处理新消息
|
||||
for message in new_messages:
|
||||
await self._add_message_to_history(message)
|
||||
|
||||
# 设置完成事件
|
||||
self._update_complete.set()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]更新循环出错: {e}")
|
||||
logger.error(f"[私聊][{self.private_name}]{traceback.format_exc()}")
|
||||
self._update_complete.set() # 即使出错也要设置完成事件
|
||||
|
||||
def trigger_update(self):
|
||||
"""触发一次立即更新"""
|
||||
self._update_event.set()
|
||||
|
||||
async def wait_for_update(self, timeout: float = 5.0) -> bool:
|
||||
"""等待更新完成
|
||||
|
||||
Args:
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
bool: 是否成功完成更新(False表示超时)
|
||||
"""
|
||||
try:
|
||||
await asyncio.wait_for(self._update_complete.wait(), timeout=timeout)
|
||||
return True
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"[私聊][{self.private_name}]等待更新完成超时({timeout}秒)")
|
||||
return False
|
||||
|
||||
def start(self):
|
||||
"""启动观察器"""
|
||||
if self._running:
|
||||
return
|
||||
|
||||
self._running = True
|
||||
self._task = asyncio.create_task(self._update_loop())
|
||||
logger.debug(f"[私聊][{self.private_name}]ChatObserver for {self.stream_id} started")
|
||||
|
||||
def stop(self):
|
||||
"""停止观察器"""
|
||||
self._running = False
|
||||
self._update_event.set() # 设置事件以解除等待
|
||||
self._update_complete.set() # 设置完成事件以解除等待
|
||||
if self._task:
|
||||
self._task.cancel()
|
||||
logger.debug(f"[私聊][{self.private_name}]ChatObserver for {self.stream_id} stopped")
|
||||
|
||||
async def process_chat_history(self, messages: list):
|
||||
"""处理聊天历史
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
"""
|
||||
self.update_check_time()
|
||||
|
||||
for msg in messages:
|
||||
try:
|
||||
user_info = UserInfo.from_dict(msg.get("user_info", {}))
|
||||
if user_info.user_id == global_config.BOT_QQ:
|
||||
self.update_bot_speak_time(msg["time"])
|
||||
else:
|
||||
self.update_user_speak_time(msg["time"])
|
||||
except Exception as e:
|
||||
logger.warning(f"[私聊][{self.private_name}]处理消息时间时出错: {e}")
|
||||
continue
|
||||
|
||||
def update_check_time(self):
|
||||
"""更新查看时间"""
|
||||
self.last_check_time = time.time()
|
||||
|
||||
def update_bot_speak_time(self, speak_time: Optional[float] = None):
|
||||
"""更新机器人说话时间"""
|
||||
self.last_bot_speak_time = speak_time or time.time()
|
||||
|
||||
def update_user_speak_time(self, speak_time: Optional[float] = None):
|
||||
"""更新用户说话时间"""
|
||||
self.last_user_speak_time = speak_time or time.time()
|
||||
|
||||
def get_time_info(self) -> str:
|
||||
"""获取时间信息文本"""
|
||||
current_time = time.time()
|
||||
time_info = ""
|
||||
|
||||
if self.last_bot_speak_time:
|
||||
bot_speak_ago = current_time - self.last_bot_speak_time
|
||||
time_info += f"\n距离你上次发言已经过去了{int(bot_speak_ago)}秒"
|
||||
|
||||
if self.last_user_speak_time:
|
||||
user_speak_ago = current_time - self.last_user_speak_time
|
||||
time_info += f"\n距离对方上次发言已经过去了{int(user_speak_ago)}秒"
|
||||
|
||||
return time_info
|
||||
|
||||
def get_cached_messages(self, limit: int = 50) -> List[Dict[str, Any]]:
|
||||
"""获取缓存的消息历史
|
||||
|
||||
Args:
|
||||
limit: 获取的最大消息数量,默认50
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 缓存的消息历史列表
|
||||
"""
|
||||
return self.message_cache[-limit:]
|
||||
|
||||
def get_last_message(self) -> Optional[Dict[str, Any]]:
|
||||
"""获取最后一条消息
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: 最后一条消息,如果没有则返回None
|
||||
"""
|
||||
if not self.message_cache:
|
||||
return None
|
||||
return self.message_cache[-1]
|
||||
|
||||
def __str__(self):
|
||||
return f"ChatObserver for {self.stream_id}"
|
||||
290
src/chat/brain_chat/PFC/chat_states.py
Normal file
290
src/chat/brain_chat/PFC/chat_states.py
Normal file
@@ -0,0 +1,290 @@
|
||||
from enum import Enum, auto
|
||||
from typing import Optional, Dict, Any, List, Set
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class ChatState(Enum):
|
||||
"""聊天状态枚举"""
|
||||
|
||||
NORMAL = auto() # 正常状态
|
||||
NEW_MESSAGE = auto() # 有新消息
|
||||
COLD_CHAT = auto() # 冷场状态
|
||||
ACTIVE_CHAT = auto() # 活跃状态
|
||||
BOT_SPEAKING = auto() # 机器人正在说话
|
||||
USER_SPEAKING = auto() # 用户正在说话
|
||||
SILENT = auto() # 沉默状态
|
||||
ERROR = auto() # 错误状态
|
||||
|
||||
|
||||
class NotificationType(Enum):
|
||||
"""通知类型枚举"""
|
||||
|
||||
NEW_MESSAGE = auto() # 新消息通知
|
||||
COLD_CHAT = auto() # 冷场通知
|
||||
ACTIVE_CHAT = auto() # 活跃通知
|
||||
BOT_SPEAKING = auto() # 机器人说话通知
|
||||
USER_SPEAKING = auto() # 用户说话通知
|
||||
MESSAGE_DELETED = auto() # 消息删除通知
|
||||
USER_JOINED = auto() # 用户加入通知
|
||||
USER_LEFT = auto() # 用户离开通知
|
||||
ERROR = auto() # 错误通知
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatStateInfo:
|
||||
"""聊天状态信息"""
|
||||
|
||||
state: ChatState
|
||||
last_message_time: Optional[float] = None
|
||||
last_message_content: Optional[str] = None
|
||||
last_speaker: Optional[str] = None
|
||||
message_count: int = 0
|
||||
cold_duration: float = 0.0 # 冷场持续时间(秒)
|
||||
active_duration: float = 0.0 # 活跃持续时间(秒)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Notification:
|
||||
"""通知基类"""
|
||||
|
||||
type: NotificationType
|
||||
timestamp: float
|
||||
sender: str # 发送者标识
|
||||
target: str # 接收者标识
|
||||
data: Dict[str, Any]
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典格式"""
|
||||
return {"type": self.type.name, "timestamp": self.timestamp, "data": self.data}
|
||||
|
||||
|
||||
@dataclass
|
||||
class StateNotification(Notification):
|
||||
"""持续状态通知"""
|
||||
|
||||
is_active: bool = True
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
base_dict = super().to_dict()
|
||||
base_dict["is_active"] = self.is_active
|
||||
return base_dict
|
||||
|
||||
|
||||
class NotificationHandler(ABC):
|
||||
"""通知处理器接口"""
|
||||
|
||||
@abstractmethod
|
||||
async def handle_notification(self, notification: Notification):
|
||||
"""处理通知"""
|
||||
pass
|
||||
|
||||
|
||||
class NotificationManager:
|
||||
"""通知管理器"""
|
||||
|
||||
def __init__(self):
|
||||
# 按接收者和通知类型存储处理器
|
||||
self._handlers: Dict[str, Dict[NotificationType, List[NotificationHandler]]] = {}
|
||||
self._active_states: Set[NotificationType] = set()
|
||||
self._notification_history: List[Notification] = []
|
||||
|
||||
def register_handler(self, target: str, notification_type: NotificationType, handler: NotificationHandler):
|
||||
"""注册通知处理器
|
||||
|
||||
Args:
|
||||
target: 接收者标识(例如:"pfc")
|
||||
notification_type: 要处理的通知类型
|
||||
handler: 处理器实例
|
||||
"""
|
||||
if target not in self._handlers:
|
||||
self._handlers[target] = {}
|
||||
if notification_type not in self._handlers[target]:
|
||||
self._handlers[target][notification_type] = []
|
||||
# print(self._handlers[target][notification_type])
|
||||
self._handlers[target][notification_type].append(handler)
|
||||
# print(self._handlers[target][notification_type])
|
||||
|
||||
def unregister_handler(self, target: str, notification_type: NotificationType, handler: NotificationHandler):
|
||||
"""注销通知处理器
|
||||
|
||||
Args:
|
||||
target: 接收者标识
|
||||
notification_type: 通知类型
|
||||
handler: 要注销的处理器实例
|
||||
"""
|
||||
if target in self._handlers and notification_type in self._handlers[target]:
|
||||
handlers = self._handlers[target][notification_type]
|
||||
if handler in handlers:
|
||||
handlers.remove(handler)
|
||||
# 如果该类型的处理器列表为空,删除该类型
|
||||
if not handlers:
|
||||
del self._handlers[target][notification_type]
|
||||
# 如果该目标没有任何处理器,删除该目标
|
||||
if not self._handlers[target]:
|
||||
del self._handlers[target]
|
||||
|
||||
async def send_notification(self, notification: Notification):
|
||||
"""发送通知"""
|
||||
self._notification_history.append(notification)
|
||||
|
||||
# 如果是状态通知,更新活跃状态
|
||||
if isinstance(notification, StateNotification):
|
||||
if notification.is_active:
|
||||
self._active_states.add(notification.type)
|
||||
else:
|
||||
self._active_states.discard(notification.type)
|
||||
|
||||
# 调用目标接收者的处理器
|
||||
target = notification.target
|
||||
if target in self._handlers:
|
||||
handlers = self._handlers[target].get(notification.type, [])
|
||||
# print(handlers)
|
||||
for handler in handlers:
|
||||
# print(f"调用处理器: {handler}")
|
||||
await handler.handle_notification(notification)
|
||||
|
||||
def get_active_states(self) -> Set[NotificationType]:
|
||||
"""获取当前活跃的状态"""
|
||||
return self._active_states.copy()
|
||||
|
||||
def is_state_active(self, state_type: NotificationType) -> bool:
|
||||
"""检查特定状态是否活跃"""
|
||||
return state_type in self._active_states
|
||||
|
||||
def get_notification_history(
|
||||
self, sender: Optional[str] = None, target: Optional[str] = None, limit: Optional[int] = None
|
||||
) -> List[Notification]:
|
||||
"""获取通知历史
|
||||
|
||||
Args:
|
||||
sender: 过滤特定发送者的通知
|
||||
target: 过滤特定接收者的通知
|
||||
limit: 限制返回数量
|
||||
"""
|
||||
history = self._notification_history
|
||||
|
||||
if sender:
|
||||
history = [n for n in history if n.sender == sender]
|
||||
if target:
|
||||
history = [n for n in history if n.target == target]
|
||||
|
||||
if limit is not None:
|
||||
history = history[-limit:]
|
||||
|
||||
return history
|
||||
|
||||
def __str__(self):
|
||||
str = ""
|
||||
for target, handlers in self._handlers.items():
|
||||
for notification_type, handler_list in handlers.items():
|
||||
str += f"NotificationManager for {target} {notification_type} {handler_list}"
|
||||
return str
|
||||
|
||||
|
||||
# 一些常用的通知创建函数
|
||||
def create_new_message_notification(sender: str, target: str, message: Dict[str, Any]) -> Notification:
|
||||
"""创建新消息通知"""
|
||||
return Notification(
|
||||
type=NotificationType.NEW_MESSAGE,
|
||||
timestamp=datetime.now().timestamp(),
|
||||
sender=sender,
|
||||
target=target,
|
||||
data={
|
||||
"message_id": message.get("message_id"),
|
||||
"processed_plain_text": message.get("processed_plain_text"),
|
||||
"detailed_plain_text": message.get("detailed_plain_text"),
|
||||
"user_info": message.get("user_info"),
|
||||
"time": message.get("time"),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def create_cold_chat_notification(sender: str, target: str, is_cold: bool) -> StateNotification:
|
||||
"""创建冷场状态通知"""
|
||||
return StateNotification(
|
||||
type=NotificationType.COLD_CHAT,
|
||||
timestamp=datetime.now().timestamp(),
|
||||
sender=sender,
|
||||
target=target,
|
||||
data={"is_cold": is_cold},
|
||||
is_active=is_cold,
|
||||
)
|
||||
|
||||
|
||||
def create_active_chat_notification(sender: str, target: str, is_active: bool) -> StateNotification:
|
||||
"""创建活跃状态通知"""
|
||||
return StateNotification(
|
||||
type=NotificationType.ACTIVE_CHAT,
|
||||
timestamp=datetime.now().timestamp(),
|
||||
sender=sender,
|
||||
target=target,
|
||||
data={"is_active": is_active},
|
||||
is_active=is_active,
|
||||
)
|
||||
|
||||
|
||||
class ChatStateManager:
|
||||
"""聊天状态管理器"""
|
||||
|
||||
def __init__(self):
|
||||
self.current_state = ChatState.NORMAL
|
||||
self.state_info = ChatStateInfo(state=ChatState.NORMAL)
|
||||
self.state_history: list[ChatStateInfo] = []
|
||||
|
||||
def update_state(self, new_state: ChatState, **kwargs):
|
||||
"""更新聊天状态
|
||||
|
||||
Args:
|
||||
new_state: 新的状态
|
||||
**kwargs: 其他状态信息
|
||||
"""
|
||||
self.current_state = new_state
|
||||
self.state_info.state = new_state
|
||||
|
||||
# 更新其他状态信息
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(self.state_info, key):
|
||||
setattr(self.state_info, key, value)
|
||||
|
||||
# 记录状态历史
|
||||
self.state_history.append(self.state_info)
|
||||
|
||||
def get_current_state_info(self) -> ChatStateInfo:
|
||||
"""获取当前状态信息"""
|
||||
return self.state_info
|
||||
|
||||
def get_state_history(self) -> list[ChatStateInfo]:
|
||||
"""获取状态历史"""
|
||||
return self.state_history
|
||||
|
||||
def is_cold_chat(self, threshold: float = 60.0) -> bool:
|
||||
"""判断是否处于冷场状态
|
||||
|
||||
Args:
|
||||
threshold: 冷场阈值(秒)
|
||||
|
||||
Returns:
|
||||
bool: 是否冷场
|
||||
"""
|
||||
if not self.state_info.last_message_time:
|
||||
return True
|
||||
|
||||
current_time = datetime.now().timestamp()
|
||||
return (current_time - self.state_info.last_message_time) > threshold
|
||||
|
||||
def is_active_chat(self, threshold: float = 5.0) -> bool:
|
||||
"""判断是否处于活跃状态
|
||||
|
||||
Args:
|
||||
threshold: 活跃阈值(秒)
|
||||
|
||||
Returns:
|
||||
bool: 是否活跃
|
||||
"""
|
||||
if not self.state_info.last_message_time:
|
||||
return False
|
||||
|
||||
current_time = datetime.now().timestamp()
|
||||
return (current_time - self.state_info.last_message_time) <= threshold
|
||||
701
src/chat/brain_chat/PFC/conversation.py
Normal file
701
src/chat/brain_chat/PFC/conversation.py
Normal file
@@ -0,0 +1,701 @@
|
||||
import time
|
||||
import asyncio
|
||||
import datetime
|
||||
|
||||
# from .message_storage import MongoDBMessageStorage
|
||||
from src.plugins.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
|
||||
|
||||
# from ...config.config import global_config
|
||||
from typing import Dict, Any, Optional
|
||||
from ..chat.message import Message
|
||||
from .pfc_types import ConversationState
|
||||
from .pfc import ChatObserver, GoalAnalyzer
|
||||
from .message_sender import DirectMessageSender
|
||||
from src.common.logger_manager import get_logger
|
||||
from .action_planner import ActionPlanner
|
||||
from .observation_info import ObservationInfo
|
||||
from .conversation_info import ConversationInfo # 确保导入 ConversationInfo
|
||||
from .reply_generator import ReplyGenerator
|
||||
from ..chat.chat_stream import ChatStream
|
||||
from maim_message import UserInfo
|
||||
from src.plugins.chat.chat_stream import chat_manager
|
||||
from .pfc_KnowledgeFetcher import KnowledgeFetcher
|
||||
from .waiter import Waiter
|
||||
|
||||
import traceback
|
||||
from rich.traceback import install
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
logger = get_logger("pfc")
|
||||
|
||||
|
||||
class Conversation:
|
||||
"""对话类,负责管理单个对话的状态和行为"""
|
||||
|
||||
def __init__(self, stream_id: str, private_name: str):
|
||||
"""初始化对话实例
|
||||
|
||||
Args:
|
||||
stream_id: 聊天流ID
|
||||
"""
|
||||
self.stream_id = stream_id
|
||||
self.private_name = private_name
|
||||
self.state = ConversationState.INIT
|
||||
self.should_continue = False
|
||||
self.ignore_until_timestamp: Optional[float] = None
|
||||
|
||||
# 回复相关
|
||||
self.generated_reply = ""
|
||||
|
||||
async def _initialize(self):
|
||||
"""初始化实例,注册所有组件"""
|
||||
|
||||
try:
|
||||
self.action_planner = ActionPlanner(self.stream_id, self.private_name)
|
||||
self.goal_analyzer = GoalAnalyzer(self.stream_id, self.private_name)
|
||||
self.reply_generator = ReplyGenerator(self.stream_id, self.private_name)
|
||||
self.knowledge_fetcher = KnowledgeFetcher(self.private_name)
|
||||
self.waiter = Waiter(self.stream_id, self.private_name)
|
||||
self.direct_sender = DirectMessageSender(self.private_name)
|
||||
|
||||
# 获取聊天流信息
|
||||
self.chat_stream = chat_manager.get_stream(self.stream_id)
|
||||
|
||||
self.stop_action_planner = False
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]初始化对话实例:注册运行组件失败: {e}")
|
||||
logger.error(f"[私聊][{self.private_name}]{traceback.format_exc()}")
|
||||
raise
|
||||
|
||||
try:
|
||||
# 决策所需要的信息,包括自身自信和观察信息两部分
|
||||
# 注册观察器和观测信息
|
||||
self.chat_observer = ChatObserver.get_instance(self.stream_id, self.private_name)
|
||||
self.chat_observer.start()
|
||||
self.observation_info = ObservationInfo(self.private_name)
|
||||
self.observation_info.bind_to_chat_observer(self.chat_observer)
|
||||
# print(self.chat_observer.get_cached_messages(limit=)
|
||||
|
||||
self.conversation_info = ConversationInfo()
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]初始化对话实例:注册信息组件失败: {e}")
|
||||
logger.error(f"[私聊][{self.private_name}]{traceback.format_exc()}")
|
||||
raise
|
||||
try:
|
||||
logger.info(f"[私聊][{self.private_name}]为 {self.stream_id} 加载初始聊天记录...")
|
||||
initial_messages = get_raw_msg_before_timestamp_with_chat( #
|
||||
chat_id=self.stream_id,
|
||||
timestamp=time.time(),
|
||||
limit=30, # 加载最近30条作为初始上下文,可以调整
|
||||
)
|
||||
chat_talking_prompt = await build_readable_messages(
|
||||
initial_messages,
|
||||
replace_bot_name=True,
|
||||
merge_messages=False,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
)
|
||||
if initial_messages:
|
||||
# 将加载的消息填充到 ObservationInfo 的 chat_history
|
||||
self.observation_info.chat_history = initial_messages
|
||||
self.observation_info.chat_history_str = chat_talking_prompt + "\n"
|
||||
self.observation_info.chat_history_count = len(initial_messages)
|
||||
|
||||
# 更新 ObservationInfo 中的时间戳等信息
|
||||
last_msg = initial_messages[-1]
|
||||
self.observation_info.last_message_time = last_msg.get("time")
|
||||
last_user_info = UserInfo.from_dict(last_msg.get("user_info", {}))
|
||||
self.observation_info.last_message_sender = last_user_info.user_id
|
||||
self.observation_info.last_message_content = last_msg.get("processed_plain_text", "")
|
||||
|
||||
logger.info(
|
||||
f"[私聊][{self.private_name}]成功加载 {len(initial_messages)} 条初始聊天记录。最后一条消息时间: {self.observation_info.last_message_time}"
|
||||
)
|
||||
|
||||
# 让 ChatObserver 从加载的最后一条消息之后开始同步
|
||||
self.chat_observer.last_message_time = self.observation_info.last_message_time
|
||||
self.chat_observer.last_message_read = last_msg # 更新 observer 的最后读取记录
|
||||
else:
|
||||
logger.info(f"[私聊][{self.private_name}]没有找到初始聊天记录。")
|
||||
|
||||
except Exception as load_err:
|
||||
logger.error(f"[私聊][{self.private_name}]加载初始聊天记录时出错: {load_err}")
|
||||
# 出错也要继续,只是没有历史记录而已
|
||||
# 组件准备完成,启动该论对话
|
||||
self.should_continue = True
|
||||
asyncio.create_task(self.start())
|
||||
|
||||
async def start(self):
|
||||
"""开始对话流程"""
|
||||
try:
|
||||
logger.info(f"[私聊][{self.private_name}]对话系统启动中...")
|
||||
asyncio.create_task(self._plan_and_action_loop())
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]启动对话系统失败: {e}")
|
||||
raise
|
||||
|
||||
async def _plan_and_action_loop(self):
|
||||
"""思考步,PFC核心循环模块"""
|
||||
while self.should_continue:
|
||||
# 忽略逻辑
|
||||
if self.ignore_until_timestamp and time.time() < self.ignore_until_timestamp:
|
||||
await asyncio.sleep(30)
|
||||
continue
|
||||
elif self.ignore_until_timestamp and time.time() >= self.ignore_until_timestamp:
|
||||
logger.info(f"[私聊][{self.private_name}]忽略时间已到 {self.stream_id},准备结束对话。")
|
||||
self.ignore_until_timestamp = None
|
||||
self.should_continue = False
|
||||
continue
|
||||
try:
|
||||
# --- 在规划前记录当前新消息数量 ---
|
||||
initial_new_message_count = 0
|
||||
if hasattr(self.observation_info, "new_messages_count"):
|
||||
initial_new_message_count = self.observation_info.new_messages_count + 1 # 算上麦麦自己发的那一条
|
||||
else:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]ObservationInfo missing 'new_messages_count' before planning."
|
||||
)
|
||||
|
||||
# --- 调用 Action Planner ---
|
||||
# 传递 self.conversation_info.last_successful_reply_action
|
||||
action, reason = await self.action_planner.plan(
|
||||
self.observation_info, self.conversation_info, self.conversation_info.last_successful_reply_action
|
||||
)
|
||||
|
||||
# --- 规划后检查是否有 *更多* 新消息到达 ---
|
||||
current_new_message_count = 0
|
||||
if hasattr(self.observation_info, "new_messages_count"):
|
||||
current_new_message_count = self.observation_info.new_messages_count
|
||||
else:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]ObservationInfo missing 'new_messages_count' after planning."
|
||||
)
|
||||
|
||||
if current_new_message_count > initial_new_message_count + 2:
|
||||
logger.info(
|
||||
f"[私聊][{self.private_name}]规划期间发现新增消息 ({initial_new_message_count} -> {current_new_message_count}),跳过本次行动,重新规划"
|
||||
)
|
||||
# 如果规划期间有新消息,也应该重置上次回复状态,因为现在要响应新消息了
|
||||
self.conversation_info.last_successful_reply_action = None
|
||||
await asyncio.sleep(0.1)
|
||||
continue
|
||||
|
||||
# 包含 send_new_message
|
||||
if initial_new_message_count > 0 and action in ["direct_reply", "send_new_message"]:
|
||||
if hasattr(self.observation_info, "clear_unprocessed_messages"):
|
||||
logger.debug(
|
||||
f"[私聊][{self.private_name}]准备执行 {action},清理 {initial_new_message_count} 条规划时已知的新消息。"
|
||||
)
|
||||
await self.observation_info.clear_unprocessed_messages()
|
||||
if hasattr(self.observation_info, "new_messages_count"):
|
||||
self.observation_info.new_messages_count = 0
|
||||
else:
|
||||
logger.error(
|
||||
f"[私聊][{self.private_name}]无法清理未处理消息: ObservationInfo 缺少 clear_unprocessed_messages 方法!"
|
||||
)
|
||||
|
||||
await self._handle_action(action, reason, self.observation_info, self.conversation_info)
|
||||
|
||||
# 检查是否需要结束对话 (逻辑不变)
|
||||
goal_ended = False
|
||||
if hasattr(self.conversation_info, "goal_list") and self.conversation_info.goal_list:
|
||||
for goal_item in self.conversation_info.goal_list:
|
||||
if isinstance(goal_item, dict):
|
||||
current_goal = goal_item.get("goal")
|
||||
|
||||
if current_goal == "结束对话":
|
||||
goal_ended = True
|
||||
break
|
||||
|
||||
if goal_ended:
|
||||
self.should_continue = False
|
||||
logger.info(f"[私聊][{self.private_name}]检测到'结束对话'目标,停止循环。")
|
||||
|
||||
except Exception as loop_err:
|
||||
logger.error(f"[私聊][{self.private_name}]PFC主循环出错: {loop_err}")
|
||||
logger.error(f"[私聊][{self.private_name}]{traceback.format_exc()}")
|
||||
await asyncio.sleep(1)
|
||||
|
||||
if self.should_continue:
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
logger.info(f"[私聊][{self.private_name}]PFC 循环结束 for stream_id: {self.stream_id}")
|
||||
|
||||
def _check_new_messages_after_planning(self):
|
||||
"""检查在规划后是否有新消息"""
|
||||
# 检查 ObservationInfo 是否已初始化并且有 new_messages_count 属性
|
||||
if not hasattr(self, "observation_info") or not hasattr(self.observation_info, "new_messages_count"):
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]ObservationInfo 未初始化或缺少 'new_messages_count' 属性,无法检查新消息。"
|
||||
)
|
||||
return False # 或者根据需要抛出错误
|
||||
|
||||
if self.observation_info.new_messages_count > 2:
|
||||
logger.info(
|
||||
f"[私聊][{self.private_name}]生成/执行动作期间收到 {self.observation_info.new_messages_count} 条新消息,取消当前动作并重新规划"
|
||||
)
|
||||
# 如果有新消息,也应该重置上次回复状态
|
||||
if hasattr(self, "conversation_info"): # 确保 conversation_info 已初始化
|
||||
self.conversation_info.last_successful_reply_action = None
|
||||
else:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]ConversationInfo 未初始化,无法重置 last_successful_reply_action。"
|
||||
)
|
||||
return True
|
||||
return False
|
||||
|
||||
def _convert_to_message(self, msg_dict: Dict[str, Any]) -> Message:
|
||||
"""将消息字典转换为Message对象"""
|
||||
try:
|
||||
# 尝试从 msg_dict 直接获取 chat_stream,如果失败则从全局 chat_manager 获取
|
||||
chat_info = msg_dict.get("chat_info")
|
||||
if chat_info and isinstance(chat_info, dict):
|
||||
chat_stream = ChatStream.from_dict(chat_info)
|
||||
elif self.chat_stream: # 使用实例变量中的 chat_stream
|
||||
chat_stream = self.chat_stream
|
||||
else: # Fallback: 尝试从 manager 获取 (可能需要 stream_id)
|
||||
chat_stream = chat_manager.get_stream(self.stream_id)
|
||||
if not chat_stream:
|
||||
raise ValueError(f"无法确定 ChatStream for stream_id {self.stream_id}")
|
||||
|
||||
user_info = UserInfo.from_dict(msg_dict.get("user_info", {}))
|
||||
|
||||
return Message(
|
||||
message_id=msg_dict.get("message_id", f"gen_{time.time()}"), # 提供默认 ID
|
||||
chat_stream=chat_stream, # 使用确定的 chat_stream
|
||||
time=msg_dict.get("time", time.time()), # 提供默认时间
|
||||
user_info=user_info,
|
||||
processed_plain_text=msg_dict.get("processed_plain_text", ""),
|
||||
detailed_plain_text=msg_dict.get("detailed_plain_text", ""),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[私聊][{self.private_name}]转换消息时出错: {e}")
|
||||
# 可以选择返回 None 或重新抛出异常,这里选择重新抛出以指示问题
|
||||
raise ValueError(f"无法将字典转换为 Message 对象: {e}") from e
|
||||
|
||||
async def _handle_action(
|
||||
self, action: str, reason: str, observation_info: ObservationInfo, conversation_info: ConversationInfo
|
||||
):
|
||||
"""处理规划的行动"""
|
||||
|
||||
logger.debug(f"[私聊][{self.private_name}]执行行动: {action}, 原因: {reason}")
|
||||
|
||||
# 记录action历史 (逻辑不变)
|
||||
current_action_record = {
|
||||
"action": action,
|
||||
"plan_reason": reason,
|
||||
"status": "start",
|
||||
"time": datetime.datetime.now().strftime("%H:%M:%S"),
|
||||
"final_reason": None,
|
||||
}
|
||||
# 确保 done_action 列表存在
|
||||
if not hasattr(conversation_info, "done_action"):
|
||||
conversation_info.done_action = []
|
||||
conversation_info.done_action.append(current_action_record)
|
||||
action_index = len(conversation_info.done_action) - 1
|
||||
|
||||
action_successful = False # 用于标记动作是否成功完成
|
||||
|
||||
# --- 根据不同的 action 执行 ---
|
||||
|
||||
# send_new_message 失败后执行 wait
|
||||
if action == "send_new_message":
|
||||
max_reply_attempts = 3
|
||||
reply_attempt_count = 0
|
||||
is_suitable = False
|
||||
need_replan = False
|
||||
check_reason = "未进行尝试"
|
||||
final_reply_to_send = ""
|
||||
|
||||
while reply_attempt_count < max_reply_attempts and not is_suitable:
|
||||
reply_attempt_count += 1
|
||||
logger.info(
|
||||
f"[私聊][{self.private_name}]尝试生成追问回复 (第 {reply_attempt_count}/{max_reply_attempts} 次)..."
|
||||
)
|
||||
self.state = ConversationState.GENERATING
|
||||
|
||||
# 1. 生成回复 (调用 generate 时传入 action_type)
|
||||
self.generated_reply = await self.reply_generator.generate(
|
||||
observation_info, conversation_info, action_type="send_new_message"
|
||||
)
|
||||
logger.info(
|
||||
f"[私聊][{self.private_name}]第 {reply_attempt_count} 次生成的追问回复: {self.generated_reply}"
|
||||
)
|
||||
|
||||
# 2. 检查回复 (逻辑不变)
|
||||
self.state = ConversationState.CHECKING
|
||||
try:
|
||||
current_goal_str = conversation_info.goal_list[0]["goal"] if conversation_info.goal_list else ""
|
||||
is_suitable, check_reason, need_replan = await self.reply_generator.check_reply(
|
||||
reply=self.generated_reply,
|
||||
goal=current_goal_str,
|
||||
chat_history=observation_info.chat_history,
|
||||
chat_history_str=observation_info.chat_history_str,
|
||||
retry_count=reply_attempt_count - 1,
|
||||
)
|
||||
logger.info(
|
||||
f"[私聊][{self.private_name}]第 {reply_attempt_count} 次追问检查结果: 合适={is_suitable}, 原因='{check_reason}', 需重新规划={need_replan}"
|
||||
)
|
||||
if is_suitable:
|
||||
final_reply_to_send = self.generated_reply
|
||||
break
|
||||
elif need_replan:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]第 {reply_attempt_count} 次追问检查建议重新规划,停止尝试。原因: {check_reason}"
|
||||
)
|
||||
break
|
||||
except Exception as check_err:
|
||||
logger.error(
|
||||
f"[私聊][{self.private_name}]第 {reply_attempt_count} 次调用 ReplyChecker (追问) 时出错: {check_err}"
|
||||
)
|
||||
check_reason = f"第 {reply_attempt_count} 次检查过程出错: {check_err}"
|
||||
break
|
||||
|
||||
# 循环结束,处理最终结果
|
||||
if is_suitable:
|
||||
# 检查是否有新消息
|
||||
if self._check_new_messages_after_planning():
|
||||
logger.info(f"[私聊][{self.private_name}]生成追问回复期间收到新消息,取消发送,重新规划行动")
|
||||
conversation_info.done_action[action_index].update(
|
||||
{"status": "recall", "final_reason": f"有新消息,取消发送追问: {final_reply_to_send}"}
|
||||
)
|
||||
return # 直接返回,重新规划
|
||||
|
||||
# 发送合适的回复
|
||||
self.generated_reply = final_reply_to_send
|
||||
# --- 在这里调用 _send_reply ---
|
||||
await self._send_reply() # <--- 调用恢复后的函数
|
||||
|
||||
# 更新状态: 标记上次成功是 send_new_message
|
||||
self.conversation_info.last_successful_reply_action = "send_new_message"
|
||||
action_successful = True # 标记动作成功
|
||||
|
||||
elif need_replan:
|
||||
# 打回动作决策
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]经过 {reply_attempt_count} 次尝试,追问回复决定打回动作决策。打回原因: {check_reason}"
|
||||
)
|
||||
conversation_info.done_action[action_index].update(
|
||||
{"status": "recall", "final_reason": f"追问尝试{reply_attempt_count}次后打回: {check_reason}"}
|
||||
)
|
||||
|
||||
else:
|
||||
# 追问失败
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]经过 {reply_attempt_count} 次尝试,未能生成合适的追问回复。最终原因: {check_reason}"
|
||||
)
|
||||
conversation_info.done_action[action_index].update(
|
||||
{"status": "recall", "final_reason": f"追问尝试{reply_attempt_count}次后失败: {check_reason}"}
|
||||
)
|
||||
# 重置状态: 追问失败,下次用初始 prompt
|
||||
self.conversation_info.last_successful_reply_action = None
|
||||
|
||||
# 执行 Wait 操作
|
||||
logger.info(f"[私聊][{self.private_name}]由于无法生成合适追问回复,执行 'wait' 操作...")
|
||||
self.state = ConversationState.WAITING
|
||||
await self.waiter.wait(self.conversation_info)
|
||||
wait_action_record = {
|
||||
"action": "wait",
|
||||
"plan_reason": "因 send_new_message 多次尝试失败而执行的后备等待",
|
||||
"status": "done",
|
||||
"time": datetime.datetime.now().strftime("%H:%M:%S"),
|
||||
"final_reason": None,
|
||||
}
|
||||
conversation_info.done_action.append(wait_action_record)
|
||||
|
||||
elif action == "direct_reply":
|
||||
max_reply_attempts = 3
|
||||
reply_attempt_count = 0
|
||||
is_suitable = False
|
||||
need_replan = False
|
||||
check_reason = "未进行尝试"
|
||||
final_reply_to_send = ""
|
||||
|
||||
while reply_attempt_count < max_reply_attempts and not is_suitable:
|
||||
reply_attempt_count += 1
|
||||
logger.info(
|
||||
f"[私聊][{self.private_name}]尝试生成首次回复 (第 {reply_attempt_count}/{max_reply_attempts} 次)..."
|
||||
)
|
||||
self.state = ConversationState.GENERATING
|
||||
|
||||
# 1. 生成回复
|
||||
self.generated_reply = await self.reply_generator.generate(
|
||||
observation_info, conversation_info, action_type="direct_reply"
|
||||
)
|
||||
logger.info(
|
||||
f"[私聊][{self.private_name}]第 {reply_attempt_count} 次生成的首次回复: {self.generated_reply}"
|
||||
)
|
||||
|
||||
# 2. 检查回复
|
||||
self.state = ConversationState.CHECKING
|
||||
try:
|
||||
current_goal_str = conversation_info.goal_list[0]["goal"] if conversation_info.goal_list else ""
|
||||
is_suitable, check_reason, need_replan = await self.reply_generator.check_reply(
|
||||
reply=self.generated_reply,
|
||||
goal=current_goal_str,
|
||||
chat_history=observation_info.chat_history,
|
||||
chat_history_str=observation_info.chat_history_str,
|
||||
retry_count=reply_attempt_count - 1,
|
||||
)
|
||||
logger.info(
|
||||
f"[私聊][{self.private_name}]第 {reply_attempt_count} 次首次回复检查结果: 合适={is_suitable}, 原因='{check_reason}', 需重新规划={need_replan}"
|
||||
)
|
||||
if is_suitable:
|
||||
final_reply_to_send = self.generated_reply
|
||||
break
|
||||
elif need_replan:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]第 {reply_attempt_count} 次首次回复检查建议重新规划,停止尝试。原因: {check_reason}"
|
||||
)
|
||||
break
|
||||
except Exception as check_err:
|
||||
logger.error(
|
||||
f"[私聊][{self.private_name}]第 {reply_attempt_count} 次调用 ReplyChecker (首次回复) 时出错: {check_err}"
|
||||
)
|
||||
check_reason = f"第 {reply_attempt_count} 次检查过程出错: {check_err}"
|
||||
break
|
||||
|
||||
# 循环结束,处理最终结果
|
||||
if is_suitable:
|
||||
# 检查是否有新消息
|
||||
if self._check_new_messages_after_planning():
|
||||
logger.info(f"[私聊][{self.private_name}]生成首次回复期间收到新消息,取消发送,重新规划行动")
|
||||
conversation_info.done_action[action_index].update(
|
||||
{"status": "recall", "final_reason": f"有新消息,取消发送首次回复: {final_reply_to_send}"}
|
||||
)
|
||||
return # 直接返回,重新规划
|
||||
|
||||
# 发送合适的回复
|
||||
self.generated_reply = final_reply_to_send
|
||||
# --- 在这里调用 _send_reply ---
|
||||
await self._send_reply() # <--- 调用恢复后的函数
|
||||
|
||||
# 更新状态: 标记上次成功是 direct_reply
|
||||
self.conversation_info.last_successful_reply_action = "direct_reply"
|
||||
action_successful = True # 标记动作成功
|
||||
|
||||
elif need_replan:
|
||||
# 打回动作决策
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]经过 {reply_attempt_count} 次尝试,首次回复决定打回动作决策。打回原因: {check_reason}"
|
||||
)
|
||||
conversation_info.done_action[action_index].update(
|
||||
{"status": "recall", "final_reason": f"首次回复尝试{reply_attempt_count}次后打回: {check_reason}"}
|
||||
)
|
||||
|
||||
else:
|
||||
# 首次回复失败
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]经过 {reply_attempt_count} 次尝试,未能生成合适的首次回复。最终原因: {check_reason}"
|
||||
)
|
||||
conversation_info.done_action[action_index].update(
|
||||
{"status": "recall", "final_reason": f"首次回复尝试{reply_attempt_count}次后失败: {check_reason}"}
|
||||
)
|
||||
# 重置状态: 首次回复失败,下次还是用初始 prompt
|
||||
self.conversation_info.last_successful_reply_action = None
|
||||
|
||||
# 执行 Wait 操作 (保持原有逻辑)
|
||||
logger.info(f"[私聊][{self.private_name}]由于无法生成合适首次回复,执行 'wait' 操作...")
|
||||
self.state = ConversationState.WAITING
|
||||
await self.waiter.wait(self.conversation_info)
|
||||
wait_action_record = {
|
||||
"action": "wait",
|
||||
"plan_reason": "因 direct_reply 多次尝试失败而执行的后备等待",
|
||||
"status": "done",
|
||||
"time": datetime.datetime.now().strftime("%H:%M:%S"),
|
||||
"final_reason": None,
|
||||
}
|
||||
conversation_info.done_action.append(wait_action_record)
|
||||
|
||||
elif action == "fetch_knowledge":
|
||||
self.state = ConversationState.FETCHING
|
||||
knowledge_query = reason
|
||||
try:
|
||||
# 检查 knowledge_fetcher 是否存在
|
||||
if not hasattr(self, "knowledge_fetcher"):
|
||||
logger.error(f"[私聊][{self.private_name}]KnowledgeFetcher 未初始化,无法获取知识。")
|
||||
raise AttributeError("KnowledgeFetcher not initialized")
|
||||
|
||||
knowledge, source = await self.knowledge_fetcher.fetch(knowledge_query, observation_info.chat_history)
|
||||
logger.info(f"[私聊][{self.private_name}]获取到知识: {knowledge[:100]}..., 来源: {source}")
|
||||
if knowledge:
|
||||
# 确保 knowledge_list 存在
|
||||
if not hasattr(conversation_info, "knowledge_list"):
|
||||
conversation_info.knowledge_list = []
|
||||
conversation_info.knowledge_list.append(
|
||||
{"query": knowledge_query, "knowledge": knowledge, "source": source}
|
||||
)
|
||||
action_successful = True
|
||||
except Exception as fetch_err:
|
||||
logger.error(f"[私聊][{self.private_name}]获取知识时出错: {str(fetch_err)}")
|
||||
conversation_info.done_action[action_index].update(
|
||||
{"status": "recall", "final_reason": f"获取知识失败: {str(fetch_err)}"}
|
||||
)
|
||||
self.conversation_info.last_successful_reply_action = None # 重置状态
|
||||
|
||||
elif action == "rethink_goal":
|
||||
self.state = ConversationState.RETHINKING
|
||||
try:
|
||||
# 检查 goal_analyzer 是否存在
|
||||
if not hasattr(self, "goal_analyzer"):
|
||||
logger.error(f"[私聊][{self.private_name}]GoalAnalyzer 未初始化,无法重新思考目标。")
|
||||
raise AttributeError("GoalAnalyzer not initialized")
|
||||
await self.goal_analyzer.analyze_goal(conversation_info, observation_info)
|
||||
action_successful = True
|
||||
except Exception as rethink_err:
|
||||
logger.error(f"[私聊][{self.private_name}]重新思考目标时出错: {rethink_err}")
|
||||
conversation_info.done_action[action_index].update(
|
||||
{"status": "recall", "final_reason": f"重新思考目标失败: {rethink_err}"}
|
||||
)
|
||||
self.conversation_info.last_successful_reply_action = None # 重置状态
|
||||
|
||||
elif action == "listening":
|
||||
self.state = ConversationState.LISTENING
|
||||
logger.info(f"[私聊][{self.private_name}]倾听对方发言...")
|
||||
try:
|
||||
# 检查 waiter 是否存在
|
||||
if not hasattr(self, "waiter"):
|
||||
logger.error(f"[私聊][{self.private_name}]Waiter 未初始化,无法倾听。")
|
||||
raise AttributeError("Waiter not initialized")
|
||||
await self.waiter.wait_listening(conversation_info)
|
||||
action_successful = True # Listening 完成就算成功
|
||||
except Exception as listen_err:
|
||||
logger.error(f"[私聊][{self.private_name}]倾听时出错: {listen_err}")
|
||||
conversation_info.done_action[action_index].update(
|
||||
{"status": "recall", "final_reason": f"倾听失败: {listen_err}"}
|
||||
)
|
||||
self.conversation_info.last_successful_reply_action = None # 重置状态
|
||||
|
||||
elif action == "say_goodbye":
|
||||
self.state = ConversationState.GENERATING # 也可以定义一个新的状态,如 ENDING
|
||||
logger.info(f"[私聊][{self.private_name}]执行行动: 生成并发送告别语...")
|
||||
try:
|
||||
# 1. 生成告别语 (使用 'say_goodbye' action_type)
|
||||
self.generated_reply = await self.reply_generator.generate(
|
||||
observation_info, conversation_info, action_type="say_goodbye"
|
||||
)
|
||||
logger.info(f"[私聊][{self.private_name}]生成的告别语: {self.generated_reply}")
|
||||
|
||||
# 2. 直接发送告别语 (不经过检查)
|
||||
if self.generated_reply: # 确保生成了内容
|
||||
await self._send_reply() # 调用发送方法
|
||||
# 发送成功后,标记动作成功
|
||||
action_successful = True
|
||||
logger.info(f"[私聊][{self.private_name}]告别语已发送。")
|
||||
else:
|
||||
logger.warning(f"[私聊][{self.private_name}]未能生成告别语内容,无法发送。")
|
||||
action_successful = False # 标记动作失败
|
||||
conversation_info.done_action[action_index].update(
|
||||
{"status": "recall", "final_reason": "未能生成告别语内容"}
|
||||
)
|
||||
|
||||
# 3. 无论是否发送成功,都准备结束对话
|
||||
self.should_continue = False
|
||||
logger.info(f"[私聊][{self.private_name}]发送告别语流程结束,即将停止对话实例。")
|
||||
|
||||
except Exception as goodbye_err:
|
||||
logger.error(f"[私聊][{self.private_name}]生成或发送告别语时出错: {goodbye_err}")
|
||||
logger.error(f"[私聊][{self.private_name}]{traceback.format_exc()}")
|
||||
# 即使出错,也结束对话
|
||||
self.should_continue = False
|
||||
action_successful = False # 标记动作失败
|
||||
conversation_info.done_action[action_index].update(
|
||||
{"status": "recall", "final_reason": f"生成或发送告别语时出错: {goodbye_err}"}
|
||||
)
|
||||
|
||||
elif action == "end_conversation":
|
||||
# 这个分支现在只会在 action_planner 最终决定不告别时被调用
|
||||
self.should_continue = False
|
||||
logger.info(f"[私聊][{self.private_name}]收到最终结束指令,停止对话...")
|
||||
action_successful = True # 标记这个指令本身是成功的
|
||||
|
||||
elif action == "block_and_ignore":
|
||||
logger.info(f"[私聊][{self.private_name}]不想再理你了...")
|
||||
ignore_duration_seconds = 10 * 60
|
||||
self.ignore_until_timestamp = time.time() + ignore_duration_seconds
|
||||
logger.info(
|
||||
f"[私聊][{self.private_name}]将忽略此对话直到: {datetime.datetime.fromtimestamp(self.ignore_until_timestamp)}"
|
||||
)
|
||||
self.state = ConversationState.IGNORED
|
||||
action_successful = True # 标记动作成功
|
||||
|
||||
else: # 对应 'wait' 动作
|
||||
self.state = ConversationState.WAITING
|
||||
logger.info(f"[私聊][{self.private_name}]等待更多信息...")
|
||||
try:
|
||||
# 检查 waiter 是否存在
|
||||
if not hasattr(self, "waiter"):
|
||||
logger.error(f"[私聊][{self.private_name}]Waiter 未初始化,无法等待。")
|
||||
raise AttributeError("Waiter not initialized")
|
||||
_timeout_occurred = await self.waiter.wait(self.conversation_info)
|
||||
action_successful = True # Wait 完成就算成功
|
||||
except Exception as wait_err:
|
||||
logger.error(f"[私聊][{self.private_name}]等待时出错: {wait_err}")
|
||||
conversation_info.done_action[action_index].update(
|
||||
{"status": "recall", "final_reason": f"等待失败: {wait_err}"}
|
||||
)
|
||||
self.conversation_info.last_successful_reply_action = None # 重置状态
|
||||
|
||||
# --- 更新 Action History 状态 ---
|
||||
# 只有当动作本身成功时,才更新状态为 done
|
||||
if action_successful:
|
||||
conversation_info.done_action[action_index].update(
|
||||
{
|
||||
"status": "done",
|
||||
"time": datetime.datetime.now().strftime("%H:%M:%S"),
|
||||
}
|
||||
)
|
||||
# 重置状态: 对于非回复类动作的成功,清除上次回复状态
|
||||
if action not in ["direct_reply", "send_new_message"]:
|
||||
self.conversation_info.last_successful_reply_action = None
|
||||
logger.debug(f"[私聊][{self.private_name}]动作 {action} 成功完成,重置 last_successful_reply_action")
|
||||
# 如果动作是 recall 状态,在各自的处理逻辑中已经更新了 done_action
|
||||
|
||||
async def _send_reply(self):
|
||||
"""发送回复"""
|
||||
if not self.generated_reply:
|
||||
logger.warning(f"[私聊][{self.private_name}]没有生成回复内容,无法发送。")
|
||||
return
|
||||
|
||||
try:
|
||||
_current_time = time.time()
|
||||
reply_content = self.generated_reply
|
||||
|
||||
# 发送消息 (确保 direct_sender 和 chat_stream 有效)
|
||||
if not hasattr(self, "direct_sender") or not self.direct_sender:
|
||||
logger.error(f"[私聊][{self.private_name}]DirectMessageSender 未初始化,无法发送回复。")
|
||||
return
|
||||
if not self.chat_stream:
|
||||
logger.error(f"[私聊][{self.private_name}]ChatStream 未初始化,无法发送回复。")
|
||||
return
|
||||
|
||||
await self.direct_sender.send_message(chat_stream=self.chat_stream, content=reply_content)
|
||||
|
||||
# 发送成功后,手动触发 observer 更新可能导致重复处理自己发送的消息
|
||||
# 更好的做法是依赖 observer 的自动轮询或数据库触发器(如果支持)
|
||||
# 暂时注释掉,观察是否影响 ObservationInfo 的更新
|
||||
# self.chat_observer.trigger_update()
|
||||
# if not await self.chat_observer.wait_for_update():
|
||||
# logger.warning(f"[私聊][{self.private_name}]等待 ChatObserver 更新完成超时")
|
||||
|
||||
self.state = ConversationState.ANALYZING # 更新状态
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]发送消息或更新状态时失败: {str(e)}")
|
||||
logger.error(f"[私聊][{self.private_name}]{traceback.format_exc()}")
|
||||
self.state = ConversationState.ANALYZING
|
||||
|
||||
async def _send_timeout_message(self):
|
||||
"""发送超时结束消息"""
|
||||
try:
|
||||
messages = self.chat_observer.get_cached_messages(limit=1)
|
||||
if not messages:
|
||||
return
|
||||
|
||||
latest_message = self._convert_to_message(messages[0])
|
||||
await self.direct_sender.send_message(
|
||||
chat_stream=self.chat_stream, content="TODO:超时消息", reply_to_message=latest_message
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]发送超时消息失败: {str(e)}")
|
||||
10
src/chat/brain_chat/PFC/conversation_info.py
Normal file
10
src/chat/brain_chat/PFC/conversation_info.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class ConversationInfo:
|
||||
def __init__(self):
|
||||
self.done_action = []
|
||||
self.goal_list = []
|
||||
self.knowledge_list = []
|
||||
self.memory_list = []
|
||||
self.last_successful_reply_action: Optional[str] = None
|
||||
81
src/chat/brain_chat/PFC/message_sender.py
Normal file
81
src/chat/brain_chat/PFC/message_sender.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import time
|
||||
from typing import Optional
|
||||
from src.common.logger import get_module_logger
|
||||
from ..chat.chat_stream import ChatStream
|
||||
from ..chat.message import Message
|
||||
from maim_message import UserInfo, Seg
|
||||
from src.plugins.chat.message import MessageSending, MessageSet
|
||||
from src.plugins.chat.message_sender import message_manager
|
||||
from ..storage.storage import MessageStorage
|
||||
from ...config.config import global_config
|
||||
from rich.traceback import install
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
|
||||
logger = get_module_logger("message_sender")
|
||||
|
||||
|
||||
class DirectMessageSender:
|
||||
"""直接消息发送器"""
|
||||
|
||||
def __init__(self, private_name: str):
|
||||
self.private_name = private_name
|
||||
self.storage = MessageStorage()
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
chat_stream: ChatStream,
|
||||
content: str,
|
||||
reply_to_message: Optional[Message] = None,
|
||||
) -> None:
|
||||
"""发送消息到聊天流
|
||||
|
||||
Args:
|
||||
chat_stream: 聊天流
|
||||
content: 消息内容
|
||||
reply_to_message: 要回复的消息(可选)
|
||||
"""
|
||||
try:
|
||||
# 创建消息内容
|
||||
segments = Seg(type="seglist", data=[Seg(type="text", data=content)])
|
||||
|
||||
# 获取麦麦的信息
|
||||
bot_user_info = UserInfo(
|
||||
user_id=global_config.BOT_QQ,
|
||||
user_nickname=global_config.BOT_NICKNAME,
|
||||
platform=chat_stream.platform,
|
||||
)
|
||||
|
||||
# 用当前时间作为message_id,和之前那套sender一样
|
||||
message_id = f"dm{round(time.time(), 2)}"
|
||||
|
||||
# 构建消息对象
|
||||
message = MessageSending(
|
||||
message_id=message_id,
|
||||
chat_stream=chat_stream,
|
||||
bot_user_info=bot_user_info,
|
||||
sender_info=reply_to_message.message_info.user_info if reply_to_message else None,
|
||||
message_segment=segments,
|
||||
reply=reply_to_message,
|
||||
is_head=True,
|
||||
is_emoji=False,
|
||||
thinking_start_time=time.time(),
|
||||
)
|
||||
|
||||
# 处理消息
|
||||
await message.process()
|
||||
|
||||
# 不知道有什么用,先留下来了,和之前那套sender一样
|
||||
_message_json = message.to_dict()
|
||||
|
||||
# 发送消息
|
||||
message_set = MessageSet(chat_stream, message_id)
|
||||
message_set.add_message(message)
|
||||
await message_manager.add_message(message_set)
|
||||
await self.storage.store_message(message, chat_stream)
|
||||
logger.info(f"[私聊][{self.private_name}]PFC消息已发送: {content}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]PFC消息发送失败: {str(e)}")
|
||||
raise
|
||||
119
src/chat/brain_chat/PFC/message_storage.py
Normal file
119
src/chat/brain_chat/PFC/message_storage.py
Normal file
@@ -0,0 +1,119 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Dict, Any
|
||||
from src.common.database import db
|
||||
|
||||
|
||||
class MessageStorage(ABC):
|
||||
"""消息存储接口"""
|
||||
|
||||
@abstractmethod
|
||||
async def get_messages_after(self, chat_id: str, message: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
"""获取指定消息ID之后的所有消息
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
message: 消息
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 消息列表
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]:
|
||||
"""获取指定时间点之前的消息
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
time_point: 时间戳
|
||||
limit: 最大消息数量
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 消息列表
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def has_new_messages(self, chat_id: str, after_time: float) -> bool:
|
||||
"""检查是否有新消息
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
after_time: 时间戳
|
||||
|
||||
Returns:
|
||||
bool: 是否有新消息
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class MongoDBMessageStorage(MessageStorage):
|
||||
"""MongoDB消息存储实现"""
|
||||
|
||||
async def get_messages_after(self, chat_id: str, message_time: float) -> List[Dict[str, Any]]:
|
||||
query = {"chat_id": chat_id, "time": {"$gt": message_time}}
|
||||
# print(f"storage_check_message: {message_time}")
|
||||
|
||||
return list(db.messages.find(query).sort("time", 1))
|
||||
|
||||
async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]:
|
||||
query = {"chat_id": chat_id, "time": {"$lt": time_point}}
|
||||
|
||||
messages = list(db.messages.find(query).sort("time", -1).limit(limit))
|
||||
|
||||
# 将消息按时间正序排列
|
||||
messages.reverse()
|
||||
return messages
|
||||
|
||||
async def has_new_messages(self, chat_id: str, after_time: float) -> bool:
|
||||
query = {"chat_id": chat_id, "time": {"$gt": after_time}}
|
||||
|
||||
return db.messages.find_one(query) is not None
|
||||
|
||||
|
||||
# # 创建一个内存消息存储实现,用于测试
|
||||
# class InMemoryMessageStorage(MessageStorage):
|
||||
# """内存消息存储实现,主要用于测试"""
|
||||
|
||||
# def __init__(self):
|
||||
# self.messages: Dict[str, List[Dict[str, Any]]] = {}
|
||||
|
||||
# async def get_messages_after(self, chat_id: str, message_id: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
# if chat_id not in self.messages:
|
||||
# return []
|
||||
|
||||
# messages = self.messages[chat_id]
|
||||
# if not message_id:
|
||||
# return messages
|
||||
|
||||
# # 找到message_id的索引
|
||||
# try:
|
||||
# index = next(i for i, m in enumerate(messages) if m["message_id"] == message_id)
|
||||
# return messages[index + 1:]
|
||||
# except StopIteration:
|
||||
# return []
|
||||
|
||||
# async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]:
|
||||
# if chat_id not in self.messages:
|
||||
# return []
|
||||
|
||||
# messages = [
|
||||
# m for m in self.messages[chat_id]
|
||||
# if m["time"] < time_point
|
||||
# ]
|
||||
|
||||
# return messages[-limit:]
|
||||
|
||||
# async def has_new_messages(self, chat_id: str, after_time: float) -> bool:
|
||||
# if chat_id not in self.messages:
|
||||
# return False
|
||||
|
||||
# return any(m["time"] > after_time for m in self.messages[chat_id])
|
||||
|
||||
# # 测试辅助方法
|
||||
# def add_message(self, chat_id: str, message: Dict[str, Any]):
|
||||
# """添加测试消息"""
|
||||
# if chat_id not in self.messages:
|
||||
# self.messages[chat_id] = []
|
||||
# self.messages[chat_id].append(message)
|
||||
# self.messages[chat_id].sort(key=lambda m: m["time"])
|
||||
389
src/chat/brain_chat/PFC/observation_info.py
Normal file
389
src/chat/brain_chat/PFC/observation_info.py
Normal file
@@ -0,0 +1,389 @@
|
||||
from typing import List, Optional, Dict, Any, Set
|
||||
from maim_message import UserInfo
|
||||
import time
|
||||
from src.common.logger import get_module_logger
|
||||
from .chat_observer import ChatObserver
|
||||
from .chat_states import NotificationHandler, NotificationType, Notification
|
||||
from src.plugins.utils.chat_message_builder import build_readable_messages
|
||||
import traceback # 导入 traceback 用于调试
|
||||
|
||||
logger = get_module_logger("observation_info")
|
||||
|
||||
|
||||
class ObservationInfoHandler(NotificationHandler):
|
||||
"""ObservationInfo的通知处理器"""
|
||||
|
||||
def __init__(self, observation_info: "ObservationInfo", private_name: str):
|
||||
"""初始化处理器
|
||||
|
||||
Args:
|
||||
observation_info: 要更新的ObservationInfo实例
|
||||
private_name: 私聊对象的名称,用于日志记录
|
||||
"""
|
||||
self.observation_info = observation_info
|
||||
# 将 private_name 存储在 handler 实例中
|
||||
self.private_name = private_name
|
||||
|
||||
async def handle_notification(self, notification: Notification): # 添加类型提示
|
||||
# 获取通知类型和数据
|
||||
notification_type = notification.type
|
||||
data = notification.data
|
||||
|
||||
try: # 添加错误处理块
|
||||
if notification_type == NotificationType.NEW_MESSAGE:
|
||||
# 处理新消息通知
|
||||
# logger.debug(f"[私聊][{self.private_name}]收到新消息通知data: {data}") # 可以在需要时取消注释
|
||||
message_id = data.get("message_id")
|
||||
processed_plain_text = data.get("processed_plain_text")
|
||||
detailed_plain_text = data.get("detailed_plain_text")
|
||||
user_info_dict = data.get("user_info") # 先获取字典
|
||||
time_value = data.get("time")
|
||||
|
||||
# 确保 user_info 是字典类型再创建 UserInfo 对象
|
||||
user_info = None
|
||||
if isinstance(user_info_dict, dict):
|
||||
try:
|
||||
user_info = UserInfo.from_dict(user_info_dict)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[私聊][{self.private_name}]从字典创建 UserInfo 时出错: {e}, 字典内容: {user_info_dict}"
|
||||
)
|
||||
# 可以选择在这里返回或记录错误,避免后续代码出错
|
||||
return
|
||||
elif user_info_dict is not None:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]收到的 user_info 不是预期的字典类型: {type(user_info_dict)}"
|
||||
)
|
||||
# 根据需要处理非字典情况,这里暂时返回
|
||||
return
|
||||
|
||||
message = {
|
||||
"message_id": message_id,
|
||||
"processed_plain_text": processed_plain_text,
|
||||
"detailed_plain_text": detailed_plain_text,
|
||||
"user_info": user_info_dict, # 存储原始字典或 UserInfo 对象,取决于你的 update_from_message 如何处理
|
||||
"time": time_value,
|
||||
}
|
||||
# 传递 UserInfo 对象(如果成功创建)或原始字典
|
||||
await self.observation_info.update_from_message(message, user_info) # 修改:传递 user_info 对象
|
||||
|
||||
elif notification_type == NotificationType.COLD_CHAT:
|
||||
# 处理冷场通知
|
||||
is_cold = data.get("is_cold", False)
|
||||
await self.observation_info.update_cold_chat_status(is_cold, time.time()) # 修改:改为 await 调用
|
||||
|
||||
elif notification_type == NotificationType.ACTIVE_CHAT:
|
||||
# 处理活跃通知 (通常由 COLD_CHAT 的反向状态处理)
|
||||
is_active = data.get("is_active", False)
|
||||
self.observation_info.is_cold = not is_active
|
||||
|
||||
elif notification_type == NotificationType.BOT_SPEAKING:
|
||||
# 处理机器人说话通知 (按需实现)
|
||||
self.observation_info.is_typing = False
|
||||
self.observation_info.last_bot_speak_time = time.time()
|
||||
|
||||
elif notification_type == NotificationType.USER_SPEAKING:
|
||||
# 处理用户说话通知
|
||||
self.observation_info.is_typing = False
|
||||
self.observation_info.last_user_speak_time = time.time()
|
||||
|
||||
elif notification_type == NotificationType.MESSAGE_DELETED:
|
||||
# 处理消息删除通知
|
||||
message_id = data.get("message_id")
|
||||
# 从 unprocessed_messages 中移除被删除的消息
|
||||
original_count = len(self.observation_info.unprocessed_messages)
|
||||
self.observation_info.unprocessed_messages = [
|
||||
msg for msg in self.observation_info.unprocessed_messages if msg.get("message_id") != message_id
|
||||
]
|
||||
if len(self.observation_info.unprocessed_messages) < original_count:
|
||||
logger.info(f"[私聊][{self.private_name}]移除了未处理的消息 (ID: {message_id})")
|
||||
|
||||
elif notification_type == NotificationType.USER_JOINED:
|
||||
# 处理用户加入通知 (如果适用私聊场景)
|
||||
user_id = data.get("user_id")
|
||||
if user_id:
|
||||
self.observation_info.active_users.add(str(user_id)) # 确保是字符串
|
||||
|
||||
elif notification_type == NotificationType.USER_LEFT:
|
||||
# 处理用户离开通知 (如果适用私聊场景)
|
||||
user_id = data.get("user_id")
|
||||
if user_id:
|
||||
self.observation_info.active_users.discard(str(user_id)) # 确保是字符串
|
||||
|
||||
elif notification_type == NotificationType.ERROR:
|
||||
# 处理错误通知
|
||||
error_msg = data.get("error", "未提供错误信息")
|
||||
logger.error(f"[私聊][{self.private_name}]收到错误通知: {error_msg}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]处理通知时发生错误: {e}")
|
||||
logger.error(traceback.format_exc()) # 打印详细堆栈信息
|
||||
|
||||
|
||||
# @dataclass <-- 这个,不需要了(递黄瓜)
|
||||
class ObservationInfo:
|
||||
"""决策信息类,用于收集和管理来自chat_observer的通知信息 (手动实现 __init__)"""
|
||||
|
||||
# 类型提示保留,可用于文档和静态分析
|
||||
private_name: str
|
||||
chat_history: List[Dict[str, Any]]
|
||||
chat_history_str: str
|
||||
unprocessed_messages: List[Dict[str, Any]]
|
||||
active_users: Set[str]
|
||||
last_bot_speak_time: Optional[float]
|
||||
last_user_speak_time: Optional[float]
|
||||
last_message_time: Optional[float]
|
||||
last_message_id: Optional[str]
|
||||
last_message_content: str
|
||||
last_message_sender: Optional[str]
|
||||
bot_id: Optional[str]
|
||||
chat_history_count: int
|
||||
new_messages_count: int
|
||||
cold_chat_start_time: Optional[float]
|
||||
cold_chat_duration: float
|
||||
is_typing: bool
|
||||
is_cold_chat: bool
|
||||
changed: bool
|
||||
chat_observer: Optional[ChatObserver]
|
||||
handler: Optional[ObservationInfoHandler]
|
||||
|
||||
def __init__(self, private_name: str):
|
||||
"""
|
||||
手动初始化 ObservationInfo 的所有实例变量。
|
||||
"""
|
||||
|
||||
# 接收的参数
|
||||
self.private_name: str = private_name
|
||||
|
||||
# data_list
|
||||
self.chat_history: List[Dict[str, Any]] = []
|
||||
self.chat_history_str: str = ""
|
||||
self.unprocessed_messages: List[Dict[str, Any]] = []
|
||||
self.active_users: Set[str] = set()
|
||||
|
||||
# data
|
||||
self.last_bot_speak_time: Optional[float] = None
|
||||
self.last_user_speak_time: Optional[float] = None
|
||||
self.last_message_time: Optional[float] = None
|
||||
self.last_message_id: Optional[str] = None
|
||||
self.last_message_content: str = ""
|
||||
self.last_message_sender: Optional[str] = None
|
||||
self.bot_id: Optional[str] = None
|
||||
self.chat_history_count: int = 0
|
||||
self.new_messages_count: int = 0
|
||||
self.cold_chat_start_time: Optional[float] = None
|
||||
self.cold_chat_duration: float = 0.0
|
||||
|
||||
# state
|
||||
self.is_typing: bool = False
|
||||
self.is_cold_chat: bool = False
|
||||
self.changed: bool = False
|
||||
|
||||
# 关联对象
|
||||
self.chat_observer: Optional[ChatObserver] = None
|
||||
|
||||
self.handler: ObservationInfoHandler = ObservationInfoHandler(self, self.private_name)
|
||||
|
||||
def bind_to_chat_observer(self, chat_observer: ChatObserver):
|
||||
"""绑定到指定的chat_observer
|
||||
|
||||
Args:
|
||||
chat_observer: 要绑定的 ChatObserver 实例
|
||||
"""
|
||||
if self.chat_observer:
|
||||
logger.warning(f"[私聊][{self.private_name}]尝试重复绑定 ChatObserver")
|
||||
return
|
||||
|
||||
self.chat_observer = chat_observer
|
||||
try:
|
||||
if not self.handler: # 确保 handler 已经被创建
|
||||
logger.error(f"[私聊][{self.private_name}] 尝试绑定时 handler 未初始化!")
|
||||
self.chat_observer = None # 重置,防止后续错误
|
||||
return
|
||||
|
||||
# 注册关心的通知类型
|
||||
self.chat_observer.notification_manager.register_handler(
|
||||
target="observation_info", notification_type=NotificationType.NEW_MESSAGE, handler=self.handler
|
||||
)
|
||||
self.chat_observer.notification_manager.register_handler(
|
||||
target="observation_info", notification_type=NotificationType.COLD_CHAT, handler=self.handler
|
||||
)
|
||||
# 可以根据需要注册更多通知类型
|
||||
# self.chat_observer.notification_manager.register_handler(
|
||||
# target="observation_info", notification_type=NotificationType.MESSAGE_DELETED, handler=self.handler
|
||||
# )
|
||||
logger.info(f"[私聊][{self.private_name}]成功绑定到 ChatObserver")
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]绑定到 ChatObserver 时出错: {e}")
|
||||
self.chat_observer = None # 绑定失败,重置
|
||||
|
||||
def unbind_from_chat_observer(self):
|
||||
"""解除与chat_observer的绑定"""
|
||||
if (
|
||||
self.chat_observer and hasattr(self.chat_observer, "notification_manager") and self.handler
|
||||
): # 增加 handler 检查
|
||||
try:
|
||||
self.chat_observer.notification_manager.unregister_handler(
|
||||
target="observation_info", notification_type=NotificationType.NEW_MESSAGE, handler=self.handler
|
||||
)
|
||||
self.chat_observer.notification_manager.unregister_handler(
|
||||
target="observation_info", notification_type=NotificationType.COLD_CHAT, handler=self.handler
|
||||
)
|
||||
# 如果注册了其他类型,也要在这里注销
|
||||
# self.chat_observer.notification_manager.unregister_handler(
|
||||
# target="observation_info", notification_type=NotificationType.MESSAGE_DELETED, handler=self.handler
|
||||
# )
|
||||
logger.info(f"[私聊][{self.private_name}]成功从 ChatObserver 解绑")
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]从 ChatObserver 解绑时出错: {e}")
|
||||
finally: # 确保 chat_observer 被重置
|
||||
self.chat_observer = None
|
||||
else:
|
||||
logger.warning(f"[私聊][{self.private_name}]尝试解绑时 ChatObserver 不存在、无效或 handler 未设置")
|
||||
|
||||
# 修改:update_from_message 接收 UserInfo 对象
|
||||
async def update_from_message(self, message: Dict[str, Any], user_info: Optional[UserInfo]):
|
||||
"""从消息更新信息
|
||||
|
||||
Args:
|
||||
message: 消息数据字典
|
||||
user_info: 解析后的 UserInfo 对象 (可能为 None)
|
||||
"""
|
||||
message_time = message.get("time")
|
||||
message_id = message.get("message_id")
|
||||
processed_text = message.get("processed_plain_text", "")
|
||||
|
||||
# 只有在新消息到达时才更新 last_message 相关信息
|
||||
if message_time and message_time > (self.last_message_time or 0):
|
||||
self.last_message_time = message_time
|
||||
self.last_message_id = message_id
|
||||
self.last_message_content = processed_text
|
||||
# 重置冷场计时器
|
||||
self.is_cold_chat = False
|
||||
self.cold_chat_start_time = None
|
||||
self.cold_chat_duration = 0.0
|
||||
|
||||
if user_info:
|
||||
sender_id = str(user_info.user_id) # 确保是字符串
|
||||
self.last_message_sender = sender_id
|
||||
# 更新发言时间
|
||||
if sender_id == self.bot_id:
|
||||
self.last_bot_speak_time = message_time
|
||||
else:
|
||||
self.last_user_speak_time = message_time
|
||||
self.active_users.add(sender_id) # 用户发言则认为其活跃
|
||||
else:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]处理消息更新时缺少有效的 UserInfo 对象, message_id: {message_id}"
|
||||
)
|
||||
self.last_message_sender = None # 发送者未知
|
||||
|
||||
# 将原始消息字典添加到未处理列表
|
||||
self.unprocessed_messages.append(message)
|
||||
self.new_messages_count = len(self.unprocessed_messages) # 直接用列表长度
|
||||
|
||||
# logger.debug(f"[私聊][{self.private_name}]消息更新: last_time={self.last_message_time}, new_count={self.new_messages_count}")
|
||||
self.update_changed() # 标记状态已改变
|
||||
else:
|
||||
# 如果消息时间戳不是最新的,可能不需要处理,或者记录一个警告
|
||||
pass
|
||||
# logger.warning(f"[私聊][{self.private_name}]收到过时或无效时间戳的消息: ID={message_id}, time={message_time}")
|
||||
|
||||
def update_changed(self):
|
||||
"""标记状态已改变,并重置标记"""
|
||||
# logger.debug(f"[私聊][{self.private_name}]状态标记为已改变 (changed=True)")
|
||||
self.changed = True
|
||||
|
||||
async def update_cold_chat_status(self, is_cold: bool, current_time: float):
|
||||
"""更新冷场状态
|
||||
|
||||
Args:
|
||||
is_cold: 是否处于冷场状态
|
||||
current_time: 当前时间戳
|
||||
"""
|
||||
if is_cold != self.is_cold_chat: # 仅在状态变化时更新
|
||||
self.is_cold_chat = is_cold
|
||||
if is_cold:
|
||||
# 进入冷场状态
|
||||
self.cold_chat_start_time = (
|
||||
self.last_message_time or current_time
|
||||
) # 从最后消息时间开始算,或从当前时间开始
|
||||
logger.info(f"[私聊][{self.private_name}]进入冷场状态,开始时间: {self.cold_chat_start_time}")
|
||||
else:
|
||||
# 结束冷场状态
|
||||
if self.cold_chat_start_time:
|
||||
self.cold_chat_duration = current_time - self.cold_chat_start_time
|
||||
logger.info(f"[私聊][{self.private_name}]结束冷场状态,持续时间: {self.cold_chat_duration:.2f} 秒")
|
||||
self.cold_chat_start_time = None # 重置开始时间
|
||||
self.update_changed() # 状态变化,标记改变
|
||||
|
||||
# 即使状态没变,如果是冷场状态,也更新持续时间
|
||||
if self.is_cold_chat and self.cold_chat_start_time:
|
||||
self.cold_chat_duration = current_time - self.cold_chat_start_time
|
||||
|
||||
def get_active_duration(self) -> float:
|
||||
"""获取当前活跃时长 (距离最后一条消息的时间)
|
||||
|
||||
Returns:
|
||||
float: 最后一条消息到现在的时长(秒)
|
||||
"""
|
||||
if not self.last_message_time:
|
||||
return 0.0
|
||||
return time.time() - self.last_message_time
|
||||
|
||||
def get_user_response_time(self) -> Optional[float]:
|
||||
"""获取用户最后响应时间 (距离用户最后发言的时间)
|
||||
|
||||
Returns:
|
||||
Optional[float]: 用户最后发言到现在的时长(秒),如果没有用户发言则返回None
|
||||
"""
|
||||
if not self.last_user_speak_time:
|
||||
return None
|
||||
return time.time() - self.last_user_speak_time
|
||||
|
||||
def get_bot_response_time(self) -> Optional[float]:
|
||||
"""获取机器人最后响应时间 (距离机器人最后发言的时间)
|
||||
|
||||
Returns:
|
||||
Optional[float]: 机器人最后发言到现在的时长(秒),如果没有机器人发言则返回None
|
||||
"""
|
||||
if not self.last_bot_speak_time:
|
||||
return None
|
||||
return time.time() - self.last_bot_speak_time
|
||||
|
||||
async def clear_unprocessed_messages(self):
|
||||
"""将未处理消息移入历史记录,并更新相关状态"""
|
||||
if not self.unprocessed_messages:
|
||||
return # 没有未处理消息,直接返回
|
||||
|
||||
# logger.debug(f"[私聊][{self.private_name}]处理 {len(self.unprocessed_messages)} 条未处理消息...")
|
||||
# 将未处理消息添加到历史记录中 (确保历史记录有长度限制,避免无限增长)
|
||||
max_history_len = 100 # 示例:最多保留100条历史记录
|
||||
self.chat_history.extend(self.unprocessed_messages)
|
||||
if len(self.chat_history) > max_history_len:
|
||||
self.chat_history = self.chat_history[-max_history_len:]
|
||||
|
||||
# 更新历史记录字符串 (只使用最近一部分生成,例如20条)
|
||||
history_slice_for_str = self.chat_history[-20:]
|
||||
try:
|
||||
self.chat_history_str = await build_readable_messages(
|
||||
history_slice_for_str,
|
||||
replace_bot_name=True,
|
||||
merge_messages=False,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0, # read_mark 可能需要根据逻辑调整
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]构建聊天记录字符串时出错: {e}")
|
||||
self.chat_history_str = "[构建聊天记录出错]" # 提供错误提示
|
||||
|
||||
# 清空未处理消息列表和计数
|
||||
# cleared_count = len(self.unprocessed_messages)
|
||||
self.unprocessed_messages.clear()
|
||||
self.new_messages_count = 0
|
||||
# self.has_unread_messages = False # 这个状态可以通过 new_messages_count 判断
|
||||
|
||||
self.chat_history_count = len(self.chat_history) # 更新历史记录总数
|
||||
# logger.debug(f"[私聊][{self.private_name}]已处理 {cleared_count} 条消息,当前历史记录 {self.chat_history_count} 条。")
|
||||
|
||||
self.update_changed() # 状态改变
|
||||
345
src/chat/brain_chat/PFC/pfc.py
Normal file
345
src/chat/brain_chat/PFC/pfc.py
Normal file
@@ -0,0 +1,345 @@
|
||||
from typing import List, Tuple, TYPE_CHECKING
|
||||
from src.common.logger import get_module_logger
|
||||
from ..models.utils_model import LLMRequest
|
||||
from ...config.config import global_config
|
||||
from .chat_observer import ChatObserver
|
||||
from .pfc_utils import get_items_from_json
|
||||
from src.individuality.individuality import Individuality
|
||||
from .conversation_info import ConversationInfo
|
||||
from .observation_info import ObservationInfo
|
||||
from src.plugins.utils.chat_message_builder import build_readable_messages
|
||||
from rich.traceback import install
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
logger = get_module_logger("pfc")
|
||||
|
||||
|
||||
def _calculate_similarity(goal1: str, goal2: str) -> float:
|
||||
"""简单计算两个目标之间的相似度
|
||||
|
||||
这里使用一个简单的实现,实际可以使用更复杂的文本相似度算法
|
||||
|
||||
Args:
|
||||
goal1: 第一个目标
|
||||
goal2: 第二个目标
|
||||
|
||||
Returns:
|
||||
float: 相似度得分 (0-1)
|
||||
"""
|
||||
# 简单实现:检查重叠字数比例
|
||||
words1 = set(goal1)
|
||||
words2 = set(goal2)
|
||||
overlap = len(words1.intersection(words2))
|
||||
total = len(words1.union(words2))
|
||||
return overlap / total if total > 0 else 0
|
||||
|
||||
|
||||
class GoalAnalyzer:
|
||||
"""对话目标分析器"""
|
||||
|
||||
def __init__(self, stream_id: str, private_name: str):
|
||||
self.llm = LLMRequest(
|
||||
model=global_config.llm_normal, temperature=0.7, max_tokens=1000, request_type="conversation_goal"
|
||||
)
|
||||
|
||||
self.personality_info = Individuality.get_instance().get_prompt(x_person=2, level=3)
|
||||
self.name = global_config.BOT_NICKNAME
|
||||
self.nick_name = global_config.BOT_ALIAS_NAMES
|
||||
self.private_name = private_name
|
||||
self.chat_observer = ChatObserver.get_instance(stream_id, private_name)
|
||||
|
||||
# 多目标存储结构
|
||||
self.goals = [] # 存储多个目标
|
||||
self.max_goals = 3 # 同时保持的最大目标数量
|
||||
self.current_goal_and_reason = None
|
||||
|
||||
async def analyze_goal(self, conversation_info: ConversationInfo, observation_info: ObservationInfo):
|
||||
"""分析对话历史并设定目标
|
||||
|
||||
Args:
|
||||
conversation_info: 对话信息
|
||||
observation_info: 观察信息
|
||||
|
||||
Returns:
|
||||
Tuple[str, str, str]: (目标, 方法, 原因)
|
||||
"""
|
||||
# 构建对话目标
|
||||
goals_str = ""
|
||||
if conversation_info.goal_list:
|
||||
for goal_reason in conversation_info.goal_list:
|
||||
if isinstance(goal_reason, dict):
|
||||
goal = goal_reason.get("goal", "目标内容缺失")
|
||||
reasoning = goal_reason.get("reasoning", "没有明确原因")
|
||||
else:
|
||||
goal = str(goal_reason)
|
||||
reasoning = "没有明确原因"
|
||||
|
||||
goal_str = f"目标:{goal},产生该对话目标的原因:{reasoning}\n"
|
||||
goals_str += goal_str
|
||||
else:
|
||||
goal = "目前没有明确对话目标"
|
||||
reasoning = "目前没有明确对话目标,最好思考一个对话目标"
|
||||
goals_str = f"目标:{goal},产生该对话目标的原因:{reasoning}\n"
|
||||
|
||||
# 获取聊天历史记录
|
||||
chat_history_text = observation_info.chat_history_str
|
||||
|
||||
if observation_info.new_messages_count > 0:
|
||||
new_messages_list = observation_info.unprocessed_messages
|
||||
new_messages_str = await build_readable_messages(
|
||||
new_messages_list,
|
||||
replace_bot_name=True,
|
||||
merge_messages=False,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
)
|
||||
chat_history_text += f"\n--- 以下是 {observation_info.new_messages_count} 条新消息 ---\n{new_messages_str}"
|
||||
|
||||
# await observation_info.clear_unprocessed_messages()
|
||||
|
||||
persona_text = f"你的名字是{self.name},{self.personality_info}。"
|
||||
# 构建action历史文本
|
||||
action_history_list = conversation_info.done_action
|
||||
action_history_text = "你之前做的事情是:"
|
||||
for action in action_history_list:
|
||||
action_history_text += f"{action}\n"
|
||||
|
||||
prompt = f"""{persona_text}。现在你在参与一场QQ聊天,请分析以下聊天记录,并根据你的性格特征确定多个明确的对话目标。
|
||||
这些目标应该反映出对话的不同方面和意图。
|
||||
|
||||
{action_history_text}
|
||||
当前对话目标:
|
||||
{goals_str}
|
||||
|
||||
聊天记录:
|
||||
{chat_history_text}
|
||||
|
||||
请分析当前对话并确定最适合的对话目标。你可以:
|
||||
1. 保持现有目标不变
|
||||
2. 修改现有目标
|
||||
3. 添加新目标
|
||||
4. 删除不再相关的目标
|
||||
5. 如果你想结束对话,请设置一个目标,目标goal为"结束对话",原因reasoning为你希望结束对话
|
||||
|
||||
请以JSON数组格式输出当前的所有对话目标,每个目标包含以下字段:
|
||||
1. goal: 对话目标(简短的一句话)
|
||||
2. reasoning: 对话原因,为什么设定这个目标(简要解释)
|
||||
|
||||
输出格式示例:
|
||||
[
|
||||
{{
|
||||
"goal": "回答用户关于Python编程的具体问题",
|
||||
"reasoning": "用户提出了关于Python的技术问题,需要专业且准确的解答"
|
||||
}},
|
||||
{{
|
||||
"goal": "回答用户关于python安装的具体问题",
|
||||
"reasoning": "用户提出了关于Python的技术问题,需要专业且准确的解答"
|
||||
}}
|
||||
]"""
|
||||
|
||||
logger.debug(f"[私聊][{self.private_name}]发送到LLM的提示词: {prompt}")
|
||||
try:
|
||||
content, _ = await self.llm.generate_response_async(prompt)
|
||||
logger.debug(f"[私聊][{self.private_name}]LLM原始返回内容: {content}")
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]分析对话目标时出错: {str(e)}")
|
||||
content = ""
|
||||
|
||||
# 使用改进后的get_items_from_json函数处理JSON数组
|
||||
success, result = get_items_from_json(
|
||||
content,
|
||||
self.private_name,
|
||||
"goal",
|
||||
"reasoning",
|
||||
required_types={"goal": str, "reasoning": str},
|
||||
allow_array=True,
|
||||
)
|
||||
|
||||
if success:
|
||||
# 判断结果是单个字典还是字典列表
|
||||
if isinstance(result, list):
|
||||
# 清空现有目标列表并添加新目标
|
||||
conversation_info.goal_list = []
|
||||
for item in result:
|
||||
conversation_info.goal_list.append(item)
|
||||
|
||||
# 返回第一个目标作为当前主要目标(如果有)
|
||||
if result:
|
||||
first_goal = result[0]
|
||||
return first_goal.get("goal", ""), "", first_goal.get("reasoning", "")
|
||||
else:
|
||||
# 单个目标的情况
|
||||
conversation_info.goal_list.append(result)
|
||||
return goal, "", reasoning
|
||||
|
||||
# 如果解析失败,返回默认值
|
||||
return "", "", ""
|
||||
|
||||
async def _update_goals(self, new_goal: str, method: str, reasoning: str):
|
||||
"""更新目标列表
|
||||
|
||||
Args:
|
||||
new_goal: 新的目标
|
||||
method: 实现目标的方法
|
||||
reasoning: 目标的原因
|
||||
"""
|
||||
# 检查新目标是否与现有目标相似
|
||||
for i, (existing_goal, _, _) in enumerate(self.goals):
|
||||
if _calculate_similarity(new_goal, existing_goal) > 0.7: # 相似度阈值
|
||||
# 更新现有目标
|
||||
self.goals[i] = (new_goal, method, reasoning)
|
||||
# 将此目标移到列表前面(最主要的位置)
|
||||
self.goals.insert(0, self.goals.pop(i))
|
||||
return
|
||||
|
||||
# 添加新目标到列表前面
|
||||
self.goals.insert(0, (new_goal, method, reasoning))
|
||||
|
||||
# 限制目标数量
|
||||
if len(self.goals) > self.max_goals:
|
||||
self.goals.pop() # 移除最老的目标
|
||||
|
||||
async def get_all_goals(self) -> List[Tuple[str, str, str]]:
|
||||
"""获取所有当前目标
|
||||
|
||||
Returns:
|
||||
List[Tuple[str, str, str]]: 目标列表,每项为(目标, 方法, 原因)
|
||||
"""
|
||||
return self.goals.copy()
|
||||
|
||||
async def get_alternative_goals(self) -> List[Tuple[str, str, str]]:
|
||||
"""获取除了当前主要目标外的其他备选目标
|
||||
|
||||
Returns:
|
||||
List[Tuple[str, str, str]]: 备选目标列表
|
||||
"""
|
||||
if len(self.goals) <= 1:
|
||||
return []
|
||||
return self.goals[1:].copy()
|
||||
|
||||
async def analyze_conversation(self, goal, reasoning):
|
||||
messages = self.chat_observer.get_cached_messages()
|
||||
chat_history_text = await build_readable_messages(
|
||||
messages,
|
||||
replace_bot_name=True,
|
||||
merge_messages=False,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
)
|
||||
|
||||
persona_text = f"你的名字是{self.name},{self.personality_info}。"
|
||||
# ===> Persona 文本构建结束 <===
|
||||
|
||||
# --- 修改 Prompt 字符串,使用 persona_text ---
|
||||
prompt = f"""{persona_text}。现在你在参与一场QQ聊天,
|
||||
当前对话目标:{goal}
|
||||
产生该对话目标的原因:{reasoning}
|
||||
|
||||
请分析以下聊天记录,并根据你的性格特征评估该目标是否已经达到,或者你是否希望停止该次对话。
|
||||
聊天记录:
|
||||
{chat_history_text}
|
||||
请以JSON格式输出,包含以下字段:
|
||||
1. goal_achieved: 对话目标是否已经达到(true/false)
|
||||
2. stop_conversation: 是否希望停止该次对话(true/false)
|
||||
3. reason: 为什么希望停止该次对话(简要解释)
|
||||
|
||||
输出格式示例:
|
||||
{{
|
||||
"goal_achieved": true,
|
||||
"stop_conversation": false,
|
||||
"reason": "虽然目标已达成,但对话仍然有继续的价值"
|
||||
}}"""
|
||||
|
||||
try:
|
||||
content, _ = await self.llm.generate_response_async(prompt)
|
||||
logger.debug(f"[私聊][{self.private_name}]LLM原始返回内容: {content}")
|
||||
|
||||
# 尝试解析JSON
|
||||
success, result = get_items_from_json(
|
||||
content,
|
||||
self.private_name,
|
||||
"goal_achieved",
|
||||
"stop_conversation",
|
||||
"reason",
|
||||
required_types={"goal_achieved": bool, "stop_conversation": bool, "reason": str},
|
||||
)
|
||||
|
||||
if not success:
|
||||
logger.error(f"[私聊][{self.private_name}]无法解析对话分析结果JSON")
|
||||
return False, False, "解析结果失败"
|
||||
|
||||
goal_achieved = result["goal_achieved"]
|
||||
stop_conversation = result["stop_conversation"]
|
||||
reason = result["reason"]
|
||||
|
||||
return goal_achieved, stop_conversation, reason
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]分析对话状态时出错: {str(e)}")
|
||||
return False, False, f"分析出错: {str(e)}"
|
||||
|
||||
|
||||
# 先注释掉,万一以后出问题了还能开回来(((
|
||||
# class DirectMessageSender:
|
||||
# """直接发送消息到平台的发送器"""
|
||||
|
||||
# def __init__(self, private_name: str):
|
||||
# self.logger = get_module_logger("direct_sender")
|
||||
# self.storage = MessageStorage()
|
||||
# self.private_name = private_name
|
||||
|
||||
# async def send_via_ws(self, message: MessageSending) -> None:
|
||||
# try:
|
||||
# await global_api.send_message(message)
|
||||
# except Exception as e:
|
||||
# raise ValueError(f"未找到平台:{message.message_info.platform} 的url配置,请检查配置文件") from e
|
||||
|
||||
# async def send_message(
|
||||
# self,
|
||||
# chat_stream: ChatStream,
|
||||
# content: str,
|
||||
# reply_to_message: Optional[Message] = None,
|
||||
# ) -> None:
|
||||
# """直接发送消息到平台
|
||||
|
||||
# Args:
|
||||
# chat_stream: 聊天流
|
||||
# content: 消息内容
|
||||
# reply_to_message: 要回复的消息
|
||||
# """
|
||||
# # 构建消息对象
|
||||
# message_segment = Seg(type="text", data=content)
|
||||
# bot_user_info = UserInfo(
|
||||
# user_id=global_config.BOT_QQ,
|
||||
# user_nickname=global_config.BOT_NICKNAME,
|
||||
# platform=chat_stream.platform,
|
||||
# )
|
||||
|
||||
# message = MessageSending(
|
||||
# message_id=f"dm{round(time.time(), 2)}",
|
||||
# chat_stream=chat_stream,
|
||||
# bot_user_info=bot_user_info,
|
||||
# sender_info=reply_to_message.message_info.user_info if reply_to_message else None,
|
||||
# message_segment=message_segment,
|
||||
# reply=reply_to_message,
|
||||
# is_head=True,
|
||||
# is_emoji=False,
|
||||
# thinking_start_time=time.time(),
|
||||
# )
|
||||
|
||||
# # 处理消息
|
||||
# await message.process()
|
||||
|
||||
# _message_json = message.to_dict()
|
||||
|
||||
# # 发送消息
|
||||
# try:
|
||||
# await self.send_via_ws(message)
|
||||
# await self.storage.store_message(message, chat_stream)
|
||||
# logger.success(f"[私聊][{self.private_name}]PFC消息已发送: {content}")
|
||||
# except Exception as e:
|
||||
# logger.error(f"[私聊][{self.private_name}]PFC消息发送失败: {str(e)}")
|
||||
85
src/chat/brain_chat/PFC/pfc_KnowledgeFetcher.py
Normal file
85
src/chat/brain_chat/PFC/pfc_KnowledgeFetcher.py
Normal file
@@ -0,0 +1,85 @@
|
||||
from typing import List, Tuple
|
||||
from src.common.logger import get_module_logger
|
||||
from src.plugins.memory_system.Hippocampus import HippocampusManager
|
||||
from ..models.utils_model import LLMRequest
|
||||
from ...config.config import global_config
|
||||
from ..chat.message import Message
|
||||
from ..knowledge.knowledge_lib import qa_manager
|
||||
from ..utils.chat_message_builder import build_readable_messages
|
||||
|
||||
logger = get_module_logger("knowledge_fetcher")
|
||||
|
||||
|
||||
class KnowledgeFetcher:
|
||||
"""知识调取器"""
|
||||
|
||||
def __init__(self, private_name: str):
|
||||
self.llm = LLMRequest(
|
||||
model=global_config.llm_normal,
|
||||
temperature=global_config.llm_normal["temp"],
|
||||
max_tokens=1000,
|
||||
request_type="knowledge_fetch",
|
||||
)
|
||||
self.private_name = private_name
|
||||
|
||||
def _lpmm_get_knowledge(self, query: str) -> str:
|
||||
"""获取相关知识
|
||||
|
||||
Args:
|
||||
query: 查询内容
|
||||
|
||||
Returns:
|
||||
str: 构造好的,带相关度的知识
|
||||
"""
|
||||
|
||||
logger.debug(f"[私聊][{self.private_name}]正在从LPMM知识库中获取知识")
|
||||
try:
|
||||
knowledge_info = qa_manager.get_knowledge(query)
|
||||
logger.debug(f"[私聊][{self.private_name}]LPMM知识库查询结果: {knowledge_info:150}")
|
||||
return knowledge_info
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]LPMM知识库搜索工具执行失败: {str(e)}")
|
||||
return "未找到匹配的知识"
|
||||
|
||||
async def fetch(self, query: str, chat_history: List[Message]) -> Tuple[str, str]:
|
||||
"""获取相关知识
|
||||
|
||||
Args:
|
||||
query: 查询内容
|
||||
chat_history: 聊天历史
|
||||
|
||||
Returns:
|
||||
Tuple[str, str]: (获取的知识, 知识来源)
|
||||
"""
|
||||
# 构建查询上下文
|
||||
chat_history_text = await build_readable_messages(
|
||||
chat_history,
|
||||
replace_bot_name=True,
|
||||
merge_messages=False,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
)
|
||||
|
||||
# 从记忆中获取相关知识
|
||||
related_memory = await HippocampusManager.get_instance().get_memory_from_text(
|
||||
text=f"{query}\n{chat_history_text}",
|
||||
max_memory_num=3,
|
||||
max_memory_length=2,
|
||||
max_depth=3,
|
||||
fast_retrieval=False,
|
||||
)
|
||||
knowledge_text = ""
|
||||
sources_text = "无记忆匹配" # 默认值
|
||||
if related_memory:
|
||||
sources = []
|
||||
for memory in related_memory:
|
||||
knowledge_text += memory[1] + "\n"
|
||||
sources.append(f"记忆片段{memory[0]}")
|
||||
knowledge_text = knowledge_text.strip()
|
||||
sources_text = ",".join(sources)
|
||||
|
||||
knowledge_text += "\n现在有以下**知识**可供参考:\n "
|
||||
knowledge_text += self._lpmm_get_knowledge(query)
|
||||
knowledge_text += "\n请记住这些**知识**,并根据**知识**回答问题。\n"
|
||||
|
||||
return knowledge_text or "未找到相关知识", sources_text or "无记忆匹配"
|
||||
115
src/chat/brain_chat/PFC/pfc_manager.py
Normal file
115
src/chat/brain_chat/PFC/pfc_manager.py
Normal file
@@ -0,0 +1,115 @@
|
||||
import time
|
||||
from typing import Dict, Optional
|
||||
from src.common.logger import get_module_logger
|
||||
from .conversation import Conversation
|
||||
import traceback
|
||||
|
||||
logger = get_module_logger("pfc_manager")
|
||||
|
||||
|
||||
class PFCManager:
|
||||
"""PFC对话管理器,负责管理所有对话实例"""
|
||||
|
||||
# 单例模式
|
||||
_instance = None
|
||||
|
||||
# 会话实例管理
|
||||
_instances: Dict[str, Conversation] = {}
|
||||
_initializing: Dict[str, bool] = {}
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> "PFCManager":
|
||||
"""获取管理器单例
|
||||
|
||||
Returns:
|
||||
PFCManager: 管理器实例
|
||||
"""
|
||||
if cls._instance is None:
|
||||
cls._instance = PFCManager()
|
||||
return cls._instance
|
||||
|
||||
async def get_or_create_conversation(self, stream_id: str, private_name: str) -> Optional[Conversation]:
|
||||
"""获取或创建对话实例
|
||||
|
||||
Args:
|
||||
stream_id: 聊天流ID
|
||||
private_name: 私聊名称
|
||||
|
||||
Returns:
|
||||
Optional[Conversation]: 对话实例,创建失败则返回None
|
||||
"""
|
||||
# 检查是否已经有实例
|
||||
if stream_id in self._initializing and self._initializing[stream_id]:
|
||||
logger.debug(f"[私聊][{private_name}]会话实例正在初始化中: {stream_id}")
|
||||
return None
|
||||
|
||||
if stream_id in self._instances and self._instances[stream_id].should_continue:
|
||||
logger.debug(f"[私聊][{private_name}]使用现有会话实例: {stream_id}")
|
||||
return self._instances[stream_id]
|
||||
if stream_id in self._instances:
|
||||
instance = self._instances[stream_id]
|
||||
if (
|
||||
hasattr(instance, "ignore_until_timestamp")
|
||||
and instance.ignore_until_timestamp
|
||||
and time.time() < instance.ignore_until_timestamp
|
||||
):
|
||||
logger.debug(f"[私聊][{private_name}]会话实例当前处于忽略状态: {stream_id}")
|
||||
# 返回 None 阻止交互。或者可以返回实例但标记它被忽略了喵?
|
||||
# 还是返回 None 吧喵。
|
||||
return None
|
||||
|
||||
# 检查 should_continue 状态
|
||||
if instance.should_continue:
|
||||
logger.debug(f"[私聊][{private_name}]使用现有会话实例: {stream_id}")
|
||||
return instance
|
||||
# else: 实例存在但不应继续
|
||||
try:
|
||||
# 创建新实例
|
||||
logger.info(f"[私聊][{private_name}]创建新的对话实例: {stream_id}")
|
||||
self._initializing[stream_id] = True
|
||||
# 创建实例
|
||||
conversation_instance = Conversation(stream_id, private_name)
|
||||
self._instances[stream_id] = conversation_instance
|
||||
|
||||
# 启动实例初始化
|
||||
await self._initialize_conversation(conversation_instance)
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{private_name}]创建会话实例失败: {stream_id}, 错误: {e}")
|
||||
return None
|
||||
|
||||
return conversation_instance
|
||||
|
||||
async def _initialize_conversation(self, conversation: Conversation):
|
||||
"""初始化会话实例
|
||||
|
||||
Args:
|
||||
conversation: 要初始化的会话实例
|
||||
"""
|
||||
stream_id = conversation.stream_id
|
||||
private_name = conversation.private_name
|
||||
|
||||
try:
|
||||
logger.info(f"[私聊][{private_name}]开始初始化会话实例: {stream_id}")
|
||||
# 启动初始化流程
|
||||
await conversation._initialize()
|
||||
|
||||
# 标记初始化完成
|
||||
self._initializing[stream_id] = False
|
||||
|
||||
logger.info(f"[私聊][{private_name}]会话实例 {stream_id} 初始化完成")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{private_name}]管理器初始化会话实例失败: {stream_id}, 错误: {e}")
|
||||
logger.error(f"[私聊][{private_name}]{traceback.format_exc()}")
|
||||
# 清理失败的初始化
|
||||
|
||||
async def get_conversation(self, stream_id: str) -> Optional[Conversation]:
|
||||
"""获取已存在的会话实例
|
||||
|
||||
Args:
|
||||
stream_id: 聊天流ID
|
||||
|
||||
Returns:
|
||||
Optional[Conversation]: 会话实例,不存在则返回None
|
||||
"""
|
||||
return self._instances.get(stream_id)
|
||||
23
src/chat/brain_chat/PFC/pfc_types.py
Normal file
23
src/chat/brain_chat/PFC/pfc_types.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from enum import Enum
|
||||
from typing import Literal
|
||||
|
||||
|
||||
class ConversationState(Enum):
|
||||
"""对话状态"""
|
||||
|
||||
INIT = "初始化"
|
||||
RETHINKING = "重新思考"
|
||||
ANALYZING = "分析历史"
|
||||
PLANNING = "规划目标"
|
||||
GENERATING = "生成回复"
|
||||
CHECKING = "检查回复"
|
||||
SENDING = "发送消息"
|
||||
FETCHING = "获取知识"
|
||||
WAITING = "等待"
|
||||
LISTENING = "倾听"
|
||||
ENDED = "结束"
|
||||
JUDGING = "判断"
|
||||
IGNORED = "屏蔽"
|
||||
|
||||
|
||||
ActionType = Literal["direct_reply", "fetch_knowledge", "wait"]
|
||||
127
src/chat/brain_chat/PFC/pfc_utils.py
Normal file
127
src/chat/brain_chat/PFC/pfc_utils.py
Normal file
@@ -0,0 +1,127 @@
|
||||
import json
|
||||
import re
|
||||
from typing import Dict, Any, Optional, Tuple, List, Union
|
||||
from src.common.logger import get_module_logger
|
||||
|
||||
logger = get_module_logger("pfc_utils")
|
||||
|
||||
|
||||
def get_items_from_json(
|
||||
content: str,
|
||||
private_name: str,
|
||||
*items: str,
|
||||
default_values: Optional[Dict[str, Any]] = None,
|
||||
required_types: Optional[Dict[str, type]] = None,
|
||||
allow_array: bool = True,
|
||||
) -> Tuple[bool, Union[Dict[str, Any], List[Dict[str, Any]]]]:
|
||||
"""从文本中提取JSON内容并获取指定字段
|
||||
|
||||
Args:
|
||||
content: 包含JSON的文本
|
||||
private_name: 私聊名称
|
||||
*items: 要提取的字段名
|
||||
default_values: 字段的默认值,格式为 {字段名: 默认值}
|
||||
required_types: 字段的必需类型,格式为 {字段名: 类型}
|
||||
allow_array: 是否允许解析JSON数组
|
||||
|
||||
Returns:
|
||||
Tuple[bool, Union[Dict[str, Any], List[Dict[str, Any]]]]: (是否成功, 提取的字段字典或字典列表)
|
||||
"""
|
||||
content = content.strip()
|
||||
result = {}
|
||||
|
||||
# 设置默认值
|
||||
if default_values:
|
||||
result.update(default_values)
|
||||
|
||||
# 首先尝试解析为JSON数组
|
||||
if allow_array:
|
||||
try:
|
||||
# 尝试找到文本中的JSON数组
|
||||
array_pattern = r"\[[\s\S]*\]"
|
||||
array_match = re.search(array_pattern, content)
|
||||
if array_match:
|
||||
array_content = array_match.group()
|
||||
json_array = json.loads(array_content)
|
||||
|
||||
# 确认是数组类型
|
||||
if isinstance(json_array, list):
|
||||
# 验证数组中的每个项目是否包含所有必需字段
|
||||
valid_items = []
|
||||
for item in json_array:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
|
||||
# 检查是否有所有必需字段
|
||||
if all(field in item for field in items):
|
||||
# 验证字段类型
|
||||
if required_types:
|
||||
type_valid = True
|
||||
for field, expected_type in required_types.items():
|
||||
if field in item and not isinstance(item[field], expected_type):
|
||||
type_valid = False
|
||||
break
|
||||
|
||||
if not type_valid:
|
||||
continue
|
||||
|
||||
# 验证字符串字段不为空
|
||||
string_valid = True
|
||||
for field in items:
|
||||
if isinstance(item[field], str) and not item[field].strip():
|
||||
string_valid = False
|
||||
break
|
||||
|
||||
if not string_valid:
|
||||
continue
|
||||
|
||||
valid_items.append(item)
|
||||
|
||||
if valid_items:
|
||||
return True, valid_items
|
||||
except json.JSONDecodeError:
|
||||
logger.debug(f"[私聊][{private_name}]JSON数组解析失败,尝试解析单个JSON对象")
|
||||
except Exception as e:
|
||||
logger.debug(f"[私聊][{private_name}]尝试解析JSON数组时出错: {str(e)}")
|
||||
|
||||
# 尝试解析JSON对象
|
||||
try:
|
||||
json_data = json.loads(content)
|
||||
except json.JSONDecodeError:
|
||||
# 如果直接解析失败,尝试查找和提取JSON部分
|
||||
json_pattern = r"\{[^{}]*\}"
|
||||
json_match = re.search(json_pattern, content)
|
||||
if json_match:
|
||||
try:
|
||||
json_data = json.loads(json_match.group())
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"[私聊][{private_name}]提取的JSON内容解析失败")
|
||||
return False, result
|
||||
else:
|
||||
logger.error(f"[私聊][{private_name}]无法在返回内容中找到有效的JSON")
|
||||
return False, result
|
||||
|
||||
# 提取字段
|
||||
for item in items:
|
||||
if item in json_data:
|
||||
result[item] = json_data[item]
|
||||
|
||||
# 验证必需字段
|
||||
if not all(item in result for item in items):
|
||||
logger.error(f"[私聊][{private_name}]JSON缺少必要字段,实际内容: {json_data}")
|
||||
return False, result
|
||||
|
||||
# 验证字段类型
|
||||
if required_types:
|
||||
for field, expected_type in required_types.items():
|
||||
if field in result and not isinstance(result[field], expected_type):
|
||||
logger.error(f"[私聊][{private_name}]{field} 必须是 {expected_type.__name__} 类型")
|
||||
return False, result
|
||||
|
||||
# 验证字符串字段不为空
|
||||
for field in items:
|
||||
if isinstance(result[field], str) and not result[field].strip():
|
||||
logger.error(f"[私聊][{private_name}]{field} 不能为空")
|
||||
return False, result
|
||||
|
||||
return True, result
|
||||
183
src/chat/brain_chat/PFC/reply_checker.py
Normal file
183
src/chat/brain_chat/PFC/reply_checker.py
Normal file
@@ -0,0 +1,183 @@
|
||||
import json
|
||||
from typing import Tuple, List, Dict, Any
|
||||
from src.common.logger import get_module_logger
|
||||
from ..models.utils_model import LLMRequest
|
||||
from ...config.config import global_config
|
||||
from .chat_observer import ChatObserver
|
||||
from maim_message import UserInfo
|
||||
|
||||
logger = get_module_logger("reply_checker")
|
||||
|
||||
|
||||
class ReplyChecker:
|
||||
"""回复检查器"""
|
||||
|
||||
def __init__(self, stream_id: str, private_name: str):
|
||||
self.llm = LLMRequest(
|
||||
model=global_config.llm_PFC_reply_checker, temperature=0.50, max_tokens=1000, request_type="reply_check"
|
||||
)
|
||||
self.name = global_config.BOT_NICKNAME
|
||||
self.private_name = private_name
|
||||
self.chat_observer = ChatObserver.get_instance(stream_id, private_name)
|
||||
self.max_retries = 3 # 最大重试次数
|
||||
|
||||
async def check(
|
||||
self, reply: str, goal: str, chat_history: List[Dict[str, Any]], chat_history_text: str, retry_count: int = 0
|
||||
) -> Tuple[bool, str, bool]:
|
||||
"""检查生成的回复是否合适
|
||||
|
||||
Args:
|
||||
reply: 生成的回复
|
||||
goal: 对话目标
|
||||
chat_history: 对话历史记录
|
||||
chat_history_text: 对话历史记录文本
|
||||
retry_count: 当前重试次数
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str, bool]: (是否合适, 原因, 是否需要重新规划)
|
||||
"""
|
||||
# 不再从 observer 获取,直接使用传入的 chat_history
|
||||
# messages = self.chat_observer.get_cached_messages(limit=20)
|
||||
try:
|
||||
# 筛选出最近由 Bot 自己发送的消息
|
||||
bot_messages = []
|
||||
for msg in reversed(chat_history):
|
||||
user_info = UserInfo.from_dict(msg.get("user_info", {}))
|
||||
if str(user_info.user_id) == str(global_config.BOT_QQ): # 确保比较的是字符串
|
||||
bot_messages.append(msg.get("processed_plain_text", ""))
|
||||
if len(bot_messages) >= 2: # 只和最近的两条比较
|
||||
break
|
||||
# 进行比较
|
||||
if bot_messages:
|
||||
# 可以用简单比较,或者更复杂的相似度库 (如 difflib)
|
||||
# 简单比较:是否完全相同
|
||||
if reply == bot_messages[0]: # 和最近一条完全一样
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]ReplyChecker 检测到回复与上一条 Bot 消息完全相同: '{reply}'"
|
||||
)
|
||||
return (
|
||||
False,
|
||||
"被逻辑检查拒绝:回复内容与你上一条发言完全相同,可以选择深入话题或寻找其它话题或等待",
|
||||
True,
|
||||
) # 不合适,需要返回至决策层
|
||||
# 2. 相似度检查 (如果精确匹配未通过)
|
||||
import difflib # 导入 difflib 库
|
||||
|
||||
# 计算编辑距离相似度,ratio() 返回 0 到 1 之间的浮点数
|
||||
similarity_ratio = difflib.SequenceMatcher(None, reply, bot_messages[0]).ratio()
|
||||
logger.debug(f"[私聊][{self.private_name}]ReplyChecker - 相似度: {similarity_ratio:.2f}")
|
||||
|
||||
# 设置一个相似度阈值
|
||||
similarity_threshold = 0.9
|
||||
if similarity_ratio > similarity_threshold:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]ReplyChecker 检测到回复与上一条 Bot 消息高度相似 (相似度 {similarity_ratio:.2f}): '{reply}'"
|
||||
)
|
||||
return (
|
||||
False,
|
||||
f"被逻辑检查拒绝:回复内容与你上一条发言高度相似 (相似度 {similarity_ratio:.2f}),可以选择深入话题或寻找其它话题或等待。",
|
||||
True,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
logger.error(f"[私聊][{self.private_name}]检查回复时出错: 类型={type(e)}, 值={e}")
|
||||
logger.error(f"[私聊][{self.private_name}]{traceback.format_exc()}") # 打印详细的回溯信息
|
||||
|
||||
prompt = f"""你是一个聊天逻辑检查器,请检查以下回复或消息是否合适:
|
||||
|
||||
当前对话目标:{goal}
|
||||
最新的对话记录:
|
||||
{chat_history_text}
|
||||
|
||||
待检查的消息:
|
||||
{reply}
|
||||
|
||||
请结合聊天记录检查以下几点:
|
||||
1. 这条消息是否依然符合当前对话目标和实现方式
|
||||
2. 这条消息是否与最新的对话记录保持一致性
|
||||
3. 是否存在重复发言,或重复表达同质内容(尤其是只是换一种方式表达了相同的含义)
|
||||
4. 这条消息是否包含违规内容(例如血腥暴力,政治敏感等)
|
||||
5. 这条消息是否以发送者的角度发言(不要让发送者自己回复自己的消息)
|
||||
6. 这条消息是否通俗易懂
|
||||
7. 这条消息是否有些多余,例如在对方没有回复的情况下,依然连续多次“消息轰炸”(尤其是已经连续发送3条信息的情况,这很可能不合理,需要着重判断)
|
||||
8. 这条消息是否使用了完全没必要的修辞
|
||||
9. 这条消息是否逻辑通顺
|
||||
10. 这条消息是否太过冗长了(通常私聊的每条消息长度在20字以内,除非特殊情况)
|
||||
11. 在连续多次发送消息的情况下,这条消息是否衔接自然,会不会显得奇怪(例如连续两条消息中部分内容重叠)
|
||||
|
||||
请以JSON格式输出,包含以下字段:
|
||||
1. suitable: 是否合适 (true/false)
|
||||
2. reason: 原因说明
|
||||
3. need_replan: 是否需要重新决策 (true/false),当你认为此时已经不适合发消息,需要规划其它行动时,设为true
|
||||
|
||||
输出格式示例:
|
||||
{{
|
||||
"suitable": true,
|
||||
"reason": "回复符合要求,虽然有可能略微偏离目标,但是整体内容流畅得体",
|
||||
"need_replan": false
|
||||
}}
|
||||
|
||||
注意:请严格按照JSON格式输出,不要包含任何其他内容。"""
|
||||
|
||||
try:
|
||||
content, _ = await self.llm.generate_response_async(prompt)
|
||||
logger.debug(f"[私聊][{self.private_name}]检查回复的原始返回: {content}")
|
||||
|
||||
# 清理内容,尝试提取JSON部分
|
||||
content = content.strip()
|
||||
try:
|
||||
# 尝试直接解析
|
||||
result = json.loads(content)
|
||||
except json.JSONDecodeError:
|
||||
# 如果直接解析失败,尝试查找和提取JSON部分
|
||||
import re
|
||||
|
||||
json_pattern = r"\{[^{}]*\}"
|
||||
json_match = re.search(json_pattern, content)
|
||||
if json_match:
|
||||
try:
|
||||
result = json.loads(json_match.group())
|
||||
except json.JSONDecodeError:
|
||||
# 如果JSON解析失败,尝试从文本中提取结果
|
||||
is_suitable = "不合适" not in content.lower() and "违规" not in content.lower()
|
||||
reason = content[:100] if content else "无法解析响应"
|
||||
need_replan = "重新规划" in content.lower() or "目标不适合" in content.lower()
|
||||
return is_suitable, reason, need_replan
|
||||
else:
|
||||
# 如果找不到JSON,从文本中判断
|
||||
is_suitable = "不合适" not in content.lower() and "违规" not in content.lower()
|
||||
reason = content[:100] if content else "无法解析响应"
|
||||
need_replan = "重新规划" in content.lower() or "目标不适合" in content.lower()
|
||||
return is_suitable, reason, need_replan
|
||||
|
||||
# 验证JSON字段
|
||||
suitable = result.get("suitable", None)
|
||||
reason = result.get("reason", "未提供原因")
|
||||
need_replan = result.get("need_replan", False)
|
||||
|
||||
# 如果suitable字段是字符串,转换为布尔值
|
||||
if isinstance(suitable, str):
|
||||
suitable = suitable.lower() == "true"
|
||||
|
||||
# 如果suitable字段不存在或不是布尔值,从reason中判断
|
||||
if suitable is None:
|
||||
suitable = "不合适" not in reason.lower() and "违规" not in reason.lower()
|
||||
|
||||
# 如果不合适且未达到最大重试次数,返回需要重试
|
||||
if not suitable and retry_count < self.max_retries:
|
||||
return False, reason, False
|
||||
|
||||
# 如果不合适且已达到最大重试次数,返回需要重新规划
|
||||
if not suitable and retry_count >= self.max_retries:
|
||||
return False, f"多次重试后仍不合适: {reason}", True
|
||||
|
||||
return suitable, reason, need_replan
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]检查回复时出错: {e}")
|
||||
# 如果出错且已达到最大重试次数,建议重新规划
|
||||
if retry_count >= self.max_retries:
|
||||
return False, "多次检查失败,建议重新规划", True
|
||||
return False, f"检查过程出错,建议重试: {str(e)}", False
|
||||
228
src/chat/brain_chat/PFC/reply_generator.py
Normal file
228
src/chat/brain_chat/PFC/reply_generator.py
Normal file
@@ -0,0 +1,228 @@
|
||||
from typing import Tuple, List, Dict, Any
|
||||
from src.common.logger import get_module_logger
|
||||
from ..models.utils_model import LLMRequest
|
||||
from ...config.config import global_config
|
||||
from .chat_observer import ChatObserver
|
||||
from .reply_checker import ReplyChecker
|
||||
from src.individuality.individuality import Individuality
|
||||
from .observation_info import ObservationInfo
|
||||
from .conversation_info import ConversationInfo
|
||||
from src.plugins.utils.chat_message_builder import build_readable_messages
|
||||
|
||||
logger = get_module_logger("reply_generator")
|
||||
|
||||
# --- 定义 Prompt 模板 ---
|
||||
|
||||
# Prompt for direct_reply (首次回复)
|
||||
PROMPT_DIRECT_REPLY = """{persona_text}。现在你在参与一场QQ私聊,请根据以下信息生成一条回复:
|
||||
|
||||
当前对话目标:{goals_str}
|
||||
|
||||
{knowledge_info_str}
|
||||
|
||||
最近的聊天记录:
|
||||
{chat_history_text}
|
||||
|
||||
|
||||
请根据上述信息,结合聊天记录,回复对方。该回复应该:
|
||||
1. 符合对话目标,以"你"的角度发言(不要自己与自己对话!)
|
||||
2. 符合你的性格特征和身份细节
|
||||
3. 通俗易懂,自然流畅,像正常聊天一样,简短(通常20字以内,除非特殊情况)
|
||||
4. 可以适当利用相关知识,但不要生硬引用
|
||||
5. 自然、得体,结合聊天记录逻辑合理,且没有重复表达同质内容
|
||||
|
||||
请注意把握聊天内容,不要回复的太有条理,可以有个性。请分清"你"和对方说的话,不要把"你"说的话当做对方说的话,这是你自己说的话。
|
||||
可以回复得自然随意自然一些,就像真人一样,注意把握聊天内容,整体风格可以平和、简短,不要刻意突出自身学科背景,不要说你说过的话,可以简短,多简短都可以,但是避免冗长。
|
||||
请你注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出回复内容。
|
||||
不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 )。
|
||||
|
||||
请直接输出回复内容,不需要任何额外格式。"""
|
||||
|
||||
# Prompt for send_new_message (追问/补充)
|
||||
PROMPT_SEND_NEW_MESSAGE = """{persona_text}。现在你在参与一场QQ私聊,**刚刚你已经发送了一条或多条消息**,现在请根据以下信息再发一条新消息:
|
||||
|
||||
当前对话目标:{goals_str}
|
||||
|
||||
{knowledge_info_str}
|
||||
|
||||
最近的聊天记录:
|
||||
{chat_history_text}
|
||||
|
||||
|
||||
请根据上述信息,结合聊天记录,继续发一条新消息(例如对之前消息的补充,深入话题,或追问等等)。该消息应该:
|
||||
1. 符合对话目标,以"你"的角度发言(不要自己与自己对话!)
|
||||
2. 符合你的性格特征和身份细节
|
||||
3. 通俗易懂,自然流畅,像正常聊天一样,简短(通常20字以内,除非特殊情况)
|
||||
4. 可以适当利用相关知识,但不要生硬引用
|
||||
5. 跟之前你发的消息自然的衔接,逻辑合理,且没有重复表达同质内容或部分重叠内容
|
||||
|
||||
请注意把握聊天内容,不用太有条理,可以有个性。请分清"你"和对方说的话,不要把"你"说的话当做对方说的话,这是你自己说的话。
|
||||
这条消息可以自然随意自然一些,就像真人一样,注意把握聊天内容,整体风格可以平和、简短,不要刻意突出自身学科背景,不要说你说过的话,可以简短,多简短都可以,但是避免冗长。
|
||||
请你注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出消息内容。
|
||||
不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 )。
|
||||
|
||||
请直接输出回复内容,不需要任何额外格式。"""
|
||||
|
||||
# Prompt for say_goodbye (告别语生成)
|
||||
PROMPT_FAREWELL = """{persona_text}。你在参与一场 QQ 私聊,现在对话似乎已经结束,你决定再发一条最后的消息来圆满结束。
|
||||
|
||||
最近的聊天记录:
|
||||
{chat_history_text}
|
||||
|
||||
请根据上述信息,结合聊天记录,构思一条**简短、自然、符合你人设**的最后的消息。
|
||||
这条消息应该:
|
||||
1. 从你自己的角度发言。
|
||||
2. 符合你的性格特征和身份细节。
|
||||
3. 通俗易懂,自然流畅,通常很简短。
|
||||
4. 自然地为这场对话画上句号,避免开启新话题或显得冗长、刻意。
|
||||
|
||||
请像真人一样随意自然,**简洁是关键**。
|
||||
不要输出多余内容(包括前后缀、冒号、引号、括号、表情包、at或@等)。
|
||||
|
||||
请直接输出最终的告别消息内容,不需要任何额外格式。"""
|
||||
|
||||
|
||||
class ReplyGenerator:
|
||||
"""回复生成器"""
|
||||
|
||||
def __init__(self, stream_id: str, private_name: str):
|
||||
self.llm = LLMRequest(
|
||||
model=global_config.llm_PFC_chat,
|
||||
temperature=global_config.llm_PFC_chat["temp"],
|
||||
max_tokens=300,
|
||||
request_type="reply_generation",
|
||||
)
|
||||
self.personality_info = Individuality.get_instance().get_prompt(x_person=2, level=3)
|
||||
self.name = global_config.BOT_NICKNAME
|
||||
self.private_name = private_name
|
||||
self.chat_observer = ChatObserver.get_instance(stream_id, private_name)
|
||||
self.reply_checker = ReplyChecker(stream_id, private_name)
|
||||
|
||||
# 修改 generate 方法签名,增加 action_type 参数
|
||||
async def generate(
|
||||
self, observation_info: ObservationInfo, conversation_info: ConversationInfo, action_type: str
|
||||
) -> str:
|
||||
"""生成回复
|
||||
|
||||
Args:
|
||||
observation_info: 观察信息
|
||||
conversation_info: 对话信息
|
||||
action_type: 当前执行的动作类型 ('direct_reply' 或 'send_new_message')
|
||||
|
||||
Returns:
|
||||
str: 生成的回复
|
||||
"""
|
||||
# 构建提示词
|
||||
logger.debug(
|
||||
f"[私聊][{self.private_name}]开始生成回复 (动作类型: {action_type}):当前目标: {conversation_info.goal_list}"
|
||||
)
|
||||
|
||||
# --- 构建通用 Prompt 参数 ---
|
||||
# (这部分逻辑基本不变)
|
||||
|
||||
# 构建对话目标 (goals_str)
|
||||
goals_str = ""
|
||||
if conversation_info.goal_list:
|
||||
for goal_reason in conversation_info.goal_list:
|
||||
if isinstance(goal_reason, dict):
|
||||
goal = goal_reason.get("goal", "目标内容缺失")
|
||||
reasoning = goal_reason.get("reasoning", "没有明确原因")
|
||||
else:
|
||||
goal = str(goal_reason)
|
||||
reasoning = "没有明确原因"
|
||||
|
||||
goal = str(goal) if goal is not None else "目标内容缺失"
|
||||
reasoning = str(reasoning) if reasoning is not None else "没有明确原因"
|
||||
goals_str += f"- 目标:{goal}\n 原因:{reasoning}\n"
|
||||
else:
|
||||
goals_str = "- 目前没有明确对话目标\n" # 简化无目标情况
|
||||
|
||||
# --- 新增:构建知识信息字符串 ---
|
||||
knowledge_info_str = "【供参考的相关知识和记忆】\n" # 稍微改下标题,表明是供参考
|
||||
try:
|
||||
# 检查 conversation_info 是否有 knowledge_list 并且不为空
|
||||
if hasattr(conversation_info, "knowledge_list") and conversation_info.knowledge_list:
|
||||
# 最多只显示最近的 5 条知识
|
||||
recent_knowledge = conversation_info.knowledge_list[-5:]
|
||||
for i, knowledge_item in enumerate(recent_knowledge):
|
||||
if isinstance(knowledge_item, dict):
|
||||
query = knowledge_item.get("query", "未知查询")
|
||||
knowledge = knowledge_item.get("knowledge", "无知识内容")
|
||||
source = knowledge_item.get("source", "未知来源")
|
||||
# 只取知识内容的前 2000 个字
|
||||
knowledge_snippet = knowledge[:2000] + "..." if len(knowledge) > 2000 else knowledge
|
||||
knowledge_info_str += (
|
||||
f"{i + 1}. 关于 '{query}' (来源: {source}): {knowledge_snippet}\n" # 格式微调,更简洁
|
||||
)
|
||||
else:
|
||||
knowledge_info_str += f"{i + 1}. 发现一条格式不正确的知识记录。\n"
|
||||
|
||||
if not recent_knowledge:
|
||||
knowledge_info_str += "- 暂无。\n" # 更简洁的提示
|
||||
|
||||
else:
|
||||
knowledge_info_str += "- 暂无。\n"
|
||||
except AttributeError:
|
||||
logger.warning(f"[私聊][{self.private_name}]ConversationInfo 对象可能缺少 knowledge_list 属性。")
|
||||
knowledge_info_str += "- 获取知识列表时出错。\n"
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]构建知识信息字符串时出错: {e}")
|
||||
knowledge_info_str += "- 处理知识列表时出错。\n"
|
||||
|
||||
# 获取聊天历史记录 (chat_history_text)
|
||||
chat_history_text = observation_info.chat_history_str
|
||||
if observation_info.new_messages_count > 0 and observation_info.unprocessed_messages:
|
||||
new_messages_list = observation_info.unprocessed_messages
|
||||
new_messages_str = await build_readable_messages(
|
||||
new_messages_list,
|
||||
replace_bot_name=True,
|
||||
merge_messages=False,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
)
|
||||
chat_history_text += f"\n--- 以下是 {observation_info.new_messages_count} 条新消息 ---\n{new_messages_str}"
|
||||
elif not chat_history_text:
|
||||
chat_history_text = "还没有聊天记录。"
|
||||
|
||||
# 构建 Persona 文本 (persona_text)
|
||||
persona_text = f"你的名字是{self.name},{self.personality_info}。"
|
||||
|
||||
# --- 选择 Prompt ---
|
||||
if action_type == "send_new_message":
|
||||
prompt_template = PROMPT_SEND_NEW_MESSAGE
|
||||
logger.info(f"[私聊][{self.private_name}]使用 PROMPT_SEND_NEW_MESSAGE (追问生成)")
|
||||
elif action_type == "say_goodbye": # 处理告别动作
|
||||
prompt_template = PROMPT_FAREWELL
|
||||
logger.info(f"[私聊][{self.private_name}]使用 PROMPT_FAREWELL (告别语生成)")
|
||||
else: # 默认使用 direct_reply 的 prompt (包括 'direct_reply' 或其他未明确处理的类型)
|
||||
prompt_template = PROMPT_DIRECT_REPLY
|
||||
logger.info(f"[私聊][{self.private_name}]使用 PROMPT_DIRECT_REPLY (首次/非连续回复生成)")
|
||||
|
||||
# --- 格式化最终的 Prompt ---
|
||||
prompt = prompt_template.format(
|
||||
persona_text=persona_text,
|
||||
goals_str=goals_str,
|
||||
chat_history_text=chat_history_text,
|
||||
knowledge_info_str=knowledge_info_str,
|
||||
)
|
||||
|
||||
# --- 调用 LLM 生成 ---
|
||||
logger.debug(f"[私聊][{self.private_name}]发送到LLM的生成提示词:\n------\n{prompt}\n------")
|
||||
try:
|
||||
content, _ = await self.llm.generate_response_async(prompt)
|
||||
logger.debug(f"[私聊][{self.private_name}]生成的回复: {content}")
|
||||
# 移除旧的检查新消息逻辑,这应该由 conversation 控制流处理
|
||||
return content
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]生成回复时出错: {e}")
|
||||
return "抱歉,我现在有点混乱,让我重新思考一下..."
|
||||
|
||||
# check_reply 方法保持不变
|
||||
async def check_reply(
|
||||
self, reply: str, goal: str, chat_history: List[Dict[str, Any]], chat_history_str: str, retry_count: int = 0
|
||||
) -> Tuple[bool, str, bool]:
|
||||
"""检查回复是否合适
|
||||
(此方法逻辑保持不变)
|
||||
"""
|
||||
return await self.reply_checker.check(reply, goal, chat_history, chat_history_str, retry_count)
|
||||
79
src/chat/brain_chat/PFC/waiter.py
Normal file
79
src/chat/brain_chat/PFC/waiter.py
Normal file
@@ -0,0 +1,79 @@
|
||||
from src.common.logger import get_module_logger
|
||||
from .chat_observer import ChatObserver
|
||||
from .conversation_info import ConversationInfo
|
||||
|
||||
# from src.individuality.individuality import Individuality # 不再需要
|
||||
from ...config.config import global_config
|
||||
import time
|
||||
import asyncio
|
||||
|
||||
logger = get_module_logger("waiter")
|
||||
|
||||
# --- 在这里设定你想要的超时时间(秒) ---
|
||||
# 例如: 120 秒 = 2 分钟
|
||||
DESIRED_TIMEOUT_SECONDS = 300
|
||||
|
||||
|
||||
class Waiter:
|
||||
"""等待处理类"""
|
||||
|
||||
def __init__(self, stream_id: str, private_name: str):
|
||||
self.chat_observer = ChatObserver.get_instance(stream_id, private_name)
|
||||
self.name = global_config.BOT_NICKNAME
|
||||
self.private_name = private_name
|
||||
# self.wait_accumulated_time = 0 # 不再需要累加计时
|
||||
|
||||
async def wait(self, conversation_info: ConversationInfo) -> bool:
|
||||
"""等待用户新消息或超时"""
|
||||
wait_start_time = time.time()
|
||||
logger.info(f"[私聊][{self.private_name}]进入常规等待状态 (超时: {DESIRED_TIMEOUT_SECONDS} 秒)...")
|
||||
|
||||
while True:
|
||||
# 检查是否有新消息
|
||||
if self.chat_observer.new_message_after(wait_start_time):
|
||||
logger.info(f"[私聊][{self.private_name}]等待结束,收到新消息")
|
||||
return False # 返回 False 表示不是超时
|
||||
|
||||
# 检查是否超时
|
||||
elapsed_time = time.time() - wait_start_time
|
||||
if elapsed_time > DESIRED_TIMEOUT_SECONDS:
|
||||
logger.info(f"[私聊][{self.private_name}]等待超过 {DESIRED_TIMEOUT_SECONDS} 秒...添加思考目标。")
|
||||
wait_goal = {
|
||||
"goal": f"你等待了{elapsed_time / 60:.1f}分钟,注意可能在对方看来聊天已经结束,思考接下来要做什么",
|
||||
"reasoning": "对方很久没有回复你的消息了",
|
||||
}
|
||||
conversation_info.goal_list.append(wait_goal)
|
||||
logger.info(f"[私聊][{self.private_name}]添加目标: {wait_goal}")
|
||||
return True # 返回 True 表示超时
|
||||
|
||||
await asyncio.sleep(5) # 每 5 秒检查一次
|
||||
logger.debug(
|
||||
f"[私聊][{self.private_name}]等待中..."
|
||||
) # 可以考虑把这个频繁日志注释掉,只在超时或收到消息时输出
|
||||
|
||||
async def wait_listening(self, conversation_info: ConversationInfo) -> bool:
|
||||
"""倾听用户发言或超时"""
|
||||
wait_start_time = time.time()
|
||||
logger.info(f"[私聊][{self.private_name}]进入倾听等待状态 (超时: {DESIRED_TIMEOUT_SECONDS} 秒)...")
|
||||
|
||||
while True:
|
||||
# 检查是否有新消息
|
||||
if self.chat_observer.new_message_after(wait_start_time):
|
||||
logger.info(f"[私聊][{self.private_name}]倾听等待结束,收到新消息")
|
||||
return False # 返回 False 表示不是超时
|
||||
|
||||
# 检查是否超时
|
||||
elapsed_time = time.time() - wait_start_time
|
||||
if elapsed_time > DESIRED_TIMEOUT_SECONDS:
|
||||
logger.info(f"[私聊][{self.private_name}]倾听等待超过 {DESIRED_TIMEOUT_SECONDS} 秒...添加思考目标。")
|
||||
wait_goal = {
|
||||
# 保持 goal 文本一致
|
||||
"goal": f"你等待了{elapsed_time / 60:.1f}分钟,对方似乎话说一半突然消失了,可能忙去了?也可能忘记了回复?要问问吗?还是结束对话?或继续等待?思考接下来要做什么",
|
||||
"reasoning": "对方话说一半消失了,很久没有回复",
|
||||
}
|
||||
conversation_info.goal_list.append(wait_goal)
|
||||
logger.info(f"[私聊][{self.private_name}]添加目标: {wait_goal}")
|
||||
return True # 返回 True 表示超时
|
||||
|
||||
await asyncio.sleep(5) # 每 5 秒检查一次
|
||||
logger.debug(f"[私聊][{self.private_name}]倾听等待中...") # 同上,可以考虑注释掉
|
||||
@@ -16,7 +16,8 @@ from src.chat.brain_chat.brain_planner import BrainPlanner
|
||||
from src.chat.planner_actions.action_modifier import ActionModifier
|
||||
from src.chat.planner_actions.action_manager import ActionManager
|
||||
from src.chat.heart_flow.hfc_utils import CycleDetail
|
||||
from src.express.expression_learner import expression_learner_manager
|
||||
from src.bw_learner.expression_learner import expression_learner_manager
|
||||
from src.bw_learner.message_recorder import extract_and_distribute_messages
|
||||
from src.person_info.person_info import Person
|
||||
from src.plugin_system.base.component_types import EventType, ActionInfo
|
||||
from src.plugin_system.core import events_manager
|
||||
@@ -96,6 +97,9 @@ class BrainChatting:
|
||||
|
||||
self.more_plan = False
|
||||
|
||||
# 最近一次是否成功进行了 reply,用于选择 BrainPlanner 的 Prompt
|
||||
self._last_successful_reply: bool = False
|
||||
|
||||
async def start(self):
|
||||
"""检查是否需要启动主循环,如果未激活则启动。"""
|
||||
|
||||
@@ -157,6 +161,7 @@ class BrainChatting:
|
||||
)
|
||||
|
||||
async def _loopbody(self): # sourcery skip: hoist-if-from-if
|
||||
# 获取最新消息(用于上下文,但不影响是否调用 observe)
|
||||
recent_messages_list = message_api.get_messages_by_time_in_chat(
|
||||
chat_id=self.stream_id,
|
||||
start_time=self.last_read_time,
|
||||
@@ -165,17 +170,25 @@ class BrainChatting:
|
||||
limit_mode="latest",
|
||||
filter_mai=True,
|
||||
filter_command=False,
|
||||
filter_no_read_command=True,
|
||||
filter_intercept_message_level=1,
|
||||
)
|
||||
|
||||
# 如果有新消息,更新 last_read_time
|
||||
if len(recent_messages_list) >= 1:
|
||||
self.last_read_time = time.time()
|
||||
await self._observe(recent_messages_list=recent_messages_list)
|
||||
|
||||
else:
|
||||
# Normal模式:消息数量不足,等待
|
||||
await asyncio.sleep(0.2)
|
||||
return True
|
||||
# 总是执行一次思考迭代(不管有没有新消息)
|
||||
# wait 动作会在其内部等待,不需要在这里处理
|
||||
should_continue = await self._observe(recent_messages_list=recent_messages_list)
|
||||
|
||||
if not should_continue:
|
||||
# 选择了 complete_talk,返回 False 表示需要等待新消息
|
||||
return False
|
||||
|
||||
# 继续下一次迭代(除非选择了 complete_talk)
|
||||
# 短暂等待后再继续,避免过于频繁的循环
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
return True
|
||||
|
||||
async def _send_and_store_reply(
|
||||
@@ -240,7 +253,7 @@ class BrainChatting:
|
||||
# ReflectTracker Check
|
||||
# 在每次回复前检查一次上下文,看是否有反思问题得到了解答
|
||||
# -------------------------------------------------------------------------
|
||||
from src.express.reflect_tracker import reflect_tracker_manager
|
||||
from src.bw_learner.reflect_tracker import reflect_tracker_manager
|
||||
|
||||
tracker = reflect_tracker_manager.get_tracker(self.stream_id)
|
||||
if tracker:
|
||||
@@ -253,13 +266,15 @@ class BrainChatting:
|
||||
# Expression Reflection Check
|
||||
# 检查是否需要提问表达反思
|
||||
# -------------------------------------------------------------------------
|
||||
from src.express.expression_reflector import expression_reflector_manager
|
||||
from src.bw_learner.expression_reflector import expression_reflector_manager
|
||||
|
||||
reflector = expression_reflector_manager.get_or_create_reflector(self.stream_id)
|
||||
asyncio.create_task(reflector.check_and_ask())
|
||||
|
||||
async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
|
||||
asyncio.create_task(self.expression_learner.trigger_learning_for_chat())
|
||||
# 通过 MessageRecorder 统一提取消息并分发给 expression_learner 和 jargon_miner
|
||||
# 在 replyer 执行时触发,统一管理时间窗口,避免重复获取消息
|
||||
asyncio.create_task(extract_and_distribute_messages(self.stream_id))
|
||||
|
||||
cycle_timers, thinking_id = self.start_cycle()
|
||||
logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考")
|
||||
@@ -272,14 +287,16 @@ class BrainChatting:
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 动作修改失败: {e}")
|
||||
|
||||
# 执行planner
|
||||
# 获取必要信息
|
||||
is_group_chat, chat_target_info, _ = self.action_planner.get_necessary_info()
|
||||
|
||||
# 一次思考迭代:Think - Act - Observe
|
||||
# 获取聊天上下文
|
||||
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=self.stream_id,
|
||||
timestamp=time.time(),
|
||||
limit=int(global_config.chat.max_context_size * 0.6),
|
||||
filter_no_read_command=True,
|
||||
filter_intercept_message_level=1,
|
||||
)
|
||||
chat_content_block, message_id_list = build_readable_messages_with_id(
|
||||
messages=message_list_before_now,
|
||||
@@ -290,12 +307,11 @@ class BrainChatting:
|
||||
)
|
||||
|
||||
prompt_info = await self.action_planner.build_planner_prompt(
|
||||
is_group_chat=is_group_chat,
|
||||
chat_target_info=chat_target_info,
|
||||
current_available_actions=available_actions,
|
||||
chat_content_block=chat_content_block,
|
||||
message_id_list=message_id_list,
|
||||
interest=global_config.personality.interest,
|
||||
prompt_key="brain_planner_prompt_react",
|
||||
)
|
||||
continue_flag, modified_message = await events_manager.handle_mai_events(
|
||||
EventType.ON_PLAN, None, prompt_info[0], None, self.chat_stream.stream_id
|
||||
@@ -311,7 +327,10 @@ class BrainChatting:
|
||||
available_actions=available_actions,
|
||||
)
|
||||
|
||||
# 3. 并行执行所有动作
|
||||
# 检查是否有 complete_talk 动作(会停止后续迭代)
|
||||
has_complete_talk = any(action.action_type == "complete_talk" for action in action_to_use_info)
|
||||
|
||||
# 并行执行所有动作
|
||||
action_tasks = [
|
||||
asyncio.create_task(
|
||||
self._execute_action(action, action_to_use_info, thinking_id, available_actions, cycle_timers)
|
||||
@@ -343,7 +362,14 @@ class BrainChatting:
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix} 回复动作执行失败")
|
||||
|
||||
# 构建最终的循环信息
|
||||
# 更新观察时间标记
|
||||
self.action_planner.last_obs_time_mark = time.time()
|
||||
|
||||
# 如果选择了 complete_talk,标记为完成,不再继续迭代
|
||||
if has_complete_talk:
|
||||
logger.info(f"{self.log_prefix} 检测到 complete_talk 动作,本次思考完成")
|
||||
|
||||
# 构建循环信息
|
||||
if reply_loop_info:
|
||||
# 如果有回复信息,使用回复的loop_info作为基础
|
||||
loop_info = reply_loop_info
|
||||
@@ -369,10 +395,16 @@ class BrainChatting:
|
||||
}
|
||||
_reply_text = action_reply_text
|
||||
|
||||
# 如果选择了 complete_talk,返回 False 以停止 _loopbody 的循环
|
||||
# 否则返回 True,让 _loopbody 继续下一次迭代
|
||||
should_continue = not has_complete_talk
|
||||
|
||||
self.end_cycle(loop_info, cycle_timers)
|
||||
self.print_cycle_info(cycle_timers)
|
||||
|
||||
return True
|
||||
# 如果选择了 complete_talk,返回 False 停止循环
|
||||
# 否则返回 True,继续下一次思考迭代
|
||||
return should_continue
|
||||
|
||||
async def _main_chat_loop(self):
|
||||
"""主循环,持续进行计划并可能回复消息,直到被外部取消。"""
|
||||
@@ -380,9 +412,13 @@ class BrainChatting:
|
||||
while self.running:
|
||||
# 主循环
|
||||
success = await self._loopbody()
|
||||
await asyncio.sleep(0.1)
|
||||
if not success:
|
||||
break
|
||||
# 选择了 complete,等待新消息
|
||||
logger.info(f"{self.log_prefix} 选择了 complete,等待新消息...")
|
||||
await self._wait_for_new_message()
|
||||
# 有新消息后继续循环
|
||||
continue
|
||||
await asyncio.sleep(0.1)
|
||||
except asyncio.CancelledError:
|
||||
# 设置了关闭标志位后被取消是正常流程
|
||||
logger.info(f"{self.log_prefix} 麦麦已关闭聊天")
|
||||
@@ -393,6 +429,33 @@ class BrainChatting:
|
||||
self._loop_task = asyncio.create_task(self._main_chat_loop())
|
||||
logger.error(f"{self.log_prefix} 结束了当前聊天循环")
|
||||
|
||||
async def _wait_for_new_message(self):
|
||||
"""等待新消息到达"""
|
||||
last_check_time = self.last_read_time
|
||||
check_interval = 1.0 # 每秒检查一次
|
||||
|
||||
while self.running:
|
||||
# 检查是否有新消息
|
||||
recent_messages_list = message_api.get_messages_by_time_in_chat(
|
||||
chat_id=self.stream_id,
|
||||
start_time=last_check_time,
|
||||
end_time=time.time(),
|
||||
limit=20,
|
||||
limit_mode="latest",
|
||||
filter_mai=True,
|
||||
filter_command=False,
|
||||
filter_intercept_message_level=1,
|
||||
)
|
||||
|
||||
# 如果有新消息,更新 last_read_time 并返回
|
||||
if len(recent_messages_list) >= 1:
|
||||
self.last_read_time = time.time()
|
||||
logger.info(f"{self.log_prefix} 检测到新消息,恢复循环")
|
||||
return
|
||||
|
||||
# 等待一段时间后再次检查
|
||||
await asyncio.sleep(check_interval)
|
||||
|
||||
async def _handle_action(
|
||||
self,
|
||||
action: str,
|
||||
@@ -506,12 +569,12 @@ class BrainChatting:
|
||||
"""执行单个动作的通用函数"""
|
||||
try:
|
||||
with Timer(f"动作{action_planner_info.action_type}", cycle_timers):
|
||||
if action_planner_info.action_type == "no_reply":
|
||||
# 直接处理no_reply逻辑,不再通过动作系统
|
||||
reason = action_planner_info.reasoning or "选择不回复"
|
||||
# logger.info(f"{self.log_prefix} 选择不回复,原因: {reason}")
|
||||
if action_planner_info.action_type == "complete_talk":
|
||||
# 直接处理complete_talk逻辑,不再通过动作系统
|
||||
reason = action_planner_info.reasoning or "选择完成对话"
|
||||
logger.info(f"{self.log_prefix} 选择完成对话,原因: {reason}")
|
||||
|
||||
# 存储no_reply信息到数据库
|
||||
# 存储complete_talk信息到数据库
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.chat_stream,
|
||||
action_build_into_prompt=False,
|
||||
@@ -519,18 +582,33 @@ class BrainChatting:
|
||||
action_done=True,
|
||||
thinking_id=thinking_id,
|
||||
action_data={"reason": reason},
|
||||
action_name="no_reply",
|
||||
action_name="complete_talk",
|
||||
)
|
||||
return {"action_type": "no_reply", "success": True, "reply_text": "", "command": ""}
|
||||
return {"action_type": "complete_talk", "success": True, "reply_text": "", "command": ""}
|
||||
|
||||
elif action_planner_info.action_type == "reply":
|
||||
try:
|
||||
# 从 Planner 的 action_data 中提取未知词语列表(仅在 reply 时使用)
|
||||
unknown_words = None
|
||||
if isinstance(action_planner_info.action_data, dict):
|
||||
uw = action_planner_info.action_data.get("unknown_words")
|
||||
if isinstance(uw, list):
|
||||
cleaned_uw: List[str] = []
|
||||
for item in uw:
|
||||
if isinstance(item, str):
|
||||
s = item.strip()
|
||||
if s:
|
||||
cleaned_uw.append(s)
|
||||
if cleaned_uw:
|
||||
unknown_words = cleaned_uw
|
||||
|
||||
success, llm_response = await generator_api.generate_reply(
|
||||
chat_stream=self.chat_stream,
|
||||
reply_message=action_planner_info.action_message,
|
||||
available_actions=available_actions,
|
||||
chosen_actions=chosen_action_plan_infos,
|
||||
reply_reason=action_planner_info.reasoning or "",
|
||||
unknown_words=unknown_words,
|
||||
enable_tool=global_config.tool.enable_tool,
|
||||
request_type="replyer",
|
||||
from_plugin=False,
|
||||
@@ -543,11 +621,17 @@ class BrainChatting:
|
||||
)
|
||||
else:
|
||||
logger.info("回复生成失败")
|
||||
return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}
|
||||
return {
|
||||
"action_type": "reply",
|
||||
"success": False,
|
||||
"reply_text": "",
|
||||
"loop_info": None,
|
||||
}
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.debug(f"{self.log_prefix} 并行执行:回复生成任务已被取消")
|
||||
return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}
|
||||
|
||||
response_set = llm_response.reply_set
|
||||
selected_expressions = llm_response.selected_expressions
|
||||
loop_info, reply_text, _ = await self._send_and_store_reply(
|
||||
@@ -558,6 +642,8 @@ class BrainChatting:
|
||||
actions=chosen_action_plan_infos,
|
||||
selected_expressions=selected_expressions,
|
||||
)
|
||||
# 标记这次循环已经成功进行了回复
|
||||
self._last_successful_reply = True
|
||||
return {
|
||||
"action_type": "reply",
|
||||
"success": True,
|
||||
@@ -567,7 +653,88 @@ class BrainChatting:
|
||||
|
||||
# 其他动作
|
||||
else:
|
||||
# 执行普通动作
|
||||
# 内建 wait / listening:不通过插件系统,直接在这里处理
|
||||
if action_planner_info.action_type in ["wait", "listening"]:
|
||||
reason = action_planner_info.reasoning or ""
|
||||
action_data = action_planner_info.action_data or {}
|
||||
|
||||
if action_planner_info.action_type == "wait":
|
||||
# 获取等待时间(必填)
|
||||
wait_seconds = action_data.get("wait_seconds")
|
||||
if wait_seconds is None:
|
||||
logger.warning(f"{self.log_prefix} wait 动作缺少 wait_seconds 参数,使用默认值 5 秒")
|
||||
wait_seconds = 5
|
||||
else:
|
||||
try:
|
||||
wait_seconds = float(wait_seconds)
|
||||
if wait_seconds < 0:
|
||||
logger.warning(f"{self.log_prefix} wait_seconds 不能为负数,使用默认值 5 秒")
|
||||
wait_seconds = 5
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(f"{self.log_prefix} wait_seconds 参数格式错误,使用默认值 5 秒")
|
||||
wait_seconds = 5
|
||||
|
||||
logger.info(f"{self.log_prefix} 执行 wait 动作,等待 {wait_seconds} 秒")
|
||||
|
||||
# 记录动作信息
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.chat_stream,
|
||||
action_build_into_prompt=False,
|
||||
action_prompt_display=reason or f"等待 {wait_seconds} 秒",
|
||||
action_done=True,
|
||||
thinking_id=thinking_id,
|
||||
action_data={"reason": reason, "wait_seconds": wait_seconds},
|
||||
action_name="wait",
|
||||
)
|
||||
|
||||
# 等待指定时间
|
||||
await asyncio.sleep(wait_seconds)
|
||||
|
||||
logger.info(f"{self.log_prefix} wait 动作完成,继续下一次思考")
|
||||
|
||||
# 这些动作本身不产生文本回复
|
||||
self._last_successful_reply = False
|
||||
return {
|
||||
"action_type": "wait",
|
||||
"success": True,
|
||||
"reply_text": "",
|
||||
"command": "",
|
||||
}
|
||||
|
||||
# listening 已合并到 wait,如果遇到则转换为 wait(向后兼容)
|
||||
elif action_planner_info.action_type == "listening":
|
||||
logger.debug(f"{self.log_prefix} 检测到 listening 动作,已合并到 wait,自动转换")
|
||||
# 使用默认等待时间
|
||||
wait_seconds = 3
|
||||
|
||||
logger.info(f"{self.log_prefix} 执行 listening(转换为 wait)动作,等待 {wait_seconds} 秒")
|
||||
|
||||
# 记录动作信息
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.chat_stream,
|
||||
action_build_into_prompt=False,
|
||||
action_prompt_display=reason or f"倾听并等待 {wait_seconds} 秒",
|
||||
action_done=True,
|
||||
thinking_id=thinking_id,
|
||||
action_data={"reason": reason, "wait_seconds": wait_seconds},
|
||||
action_name="listening",
|
||||
)
|
||||
|
||||
# 等待指定时间
|
||||
await asyncio.sleep(wait_seconds)
|
||||
|
||||
logger.info(f"{self.log_prefix} listening 动作完成,继续下一次思考")
|
||||
|
||||
# 这些动作本身不产生文本回复
|
||||
self._last_successful_reply = False
|
||||
return {
|
||||
"action_type": "listening",
|
||||
"success": True,
|
||||
"reply_text": "",
|
||||
"command": "",
|
||||
}
|
||||
|
||||
# 其余动作:走原有插件 Action 体系
|
||||
with Timer("动作执行", cycle_timers):
|
||||
success, reply_text, command = await self._handle_action(
|
||||
action_planner_info.action_type,
|
||||
@@ -577,6 +744,10 @@ class BrainChatting:
|
||||
thinking_id,
|
||||
action_planner_info.action_message,
|
||||
)
|
||||
# 非 reply 类动作执行成功时,清空最近成功回复标记,让下一轮回到 initial Prompt
|
||||
if success and action_planner_info.action_type != "reply":
|
||||
self._last_successful_reply = False
|
||||
|
||||
return {
|
||||
"action_type": action_planner_info.action_type,
|
||||
"success": success,
|
||||
|
||||
@@ -35,12 +35,13 @@ install(extra_lines=3)
|
||||
|
||||
|
||||
def init_prompt():
|
||||
# ReAct 形式的 Planner Prompt
|
||||
Prompt(
|
||||
"""
|
||||
{time_block}
|
||||
{name_block}
|
||||
你的兴趣是:{interest}
|
||||
{chat_context_description},以下是具体的聊天内容
|
||||
|
||||
**聊天内容**
|
||||
{chat_content_block}
|
||||
|
||||
@@ -57,11 +58,35 @@ reply
|
||||
"reason":"回复的原因"
|
||||
}}
|
||||
|
||||
no_reply
|
||||
wait
|
||||
动作描述:
|
||||
等待,保持沉默,等待对方发言
|
||||
暂时不再发言,等待指定时间。适用于以下情况:
|
||||
- 你已经表达清楚一轮,想给对方留出空间
|
||||
- 你感觉对方的话还没说完,或者自己刚刚发了好几条连续消息
|
||||
- 你想要等待一定时间来让对方把话说完,或者等待对方反应
|
||||
- 你想保持安静,专注"听"而不是马上回复
|
||||
请你根据上下文来判断要等待多久,请你灵活判断:
|
||||
- 如果你们交流间隔时间很短,聊的很频繁,不宜等待太久
|
||||
- 如果你们交流间隔时间很长,聊的很少,可以等待较长时间
|
||||
{{
|
||||
"action": "no_reply",
|
||||
"action": "wait",
|
||||
"target_message_id":"想要作为这次等待依据的消息id(通常是对方的最新消息)",
|
||||
"wait_seconds": 等待的秒数(必填,例如:5 表示等待5秒),
|
||||
"reason":"选择等待的原因"
|
||||
}}
|
||||
|
||||
complete_talk
|
||||
动作描述:
|
||||
当前聊天暂时结束了,对方离开,没有更多话题了
|
||||
你可以使用该动作来暂时休息,等待对方有新发言再继续:
|
||||
- 多次wait之后,对方迟迟不回复消息才用
|
||||
- 如果对方只是短暂不回复,应该使用wait而不是complete_talk
|
||||
- 聊天内容显示当前聊天已经结束或者没有新内容时候,选择complete_talk
|
||||
选择此动作后,将不再继续循环思考,直到收到对方的新消息
|
||||
{{
|
||||
"action": "complete_talk",
|
||||
"target_message_id":"触发完成对话的消息id(通常是对方的最新消息)",
|
||||
"reason":"选择完成对话的原因"
|
||||
}}
|
||||
|
||||
{action_options_text}
|
||||
@@ -92,7 +117,7 @@ no_reply
|
||||
```
|
||||
|
||||
""",
|
||||
"brain_planner_prompt",
|
||||
"brain_planner_prompt_react",
|
||||
)
|
||||
|
||||
Prompt(
|
||||
@@ -123,6 +148,9 @@ class BrainPlanner:
|
||||
|
||||
self.last_obs_time_mark = 0.0
|
||||
|
||||
# 计划日志记录
|
||||
self.plan_log: List[Tuple[str, float, List[ActionPlannerInfo]]] = []
|
||||
|
||||
def find_message_by_id(
|
||||
self, message_id: str, message_id_list: List[Tuple[str, "DatabaseMessages"]]
|
||||
) -> Optional["DatabaseMessages"]:
|
||||
@@ -152,10 +180,11 @@ class BrainPlanner:
|
||||
action_planner_infos = []
|
||||
|
||||
try:
|
||||
action = action_json.get("action", "no_reply")
|
||||
action = action_json.get("action", "complete_talk")
|
||||
logger.debug(f"{self.log_prefix}解析动作JSON: action={action}, json={action_json}")
|
||||
reasoning = action_json.get("reason", "未提供原因")
|
||||
action_data = {key: value for key, value in action_json.items() if key not in ["action", "reason"]}
|
||||
# 非no_reply动作需要target_message_id
|
||||
# 非complete_talk动作需要target_message_id
|
||||
target_message = None
|
||||
|
||||
if target_message_id := action_json.get("target_message_id"):
|
||||
@@ -171,16 +200,28 @@ class BrainPlanner:
|
||||
|
||||
# 验证action是否可用
|
||||
available_action_names = [action_name for action_name, _ in current_available_actions]
|
||||
internal_action_names = ["no_reply", "reply", "wait_time"]
|
||||
# 内部保留动作(不依赖插件系统)
|
||||
# 注意:listening 已合并到 wait 中,如果遇到 listening 则转换为 wait
|
||||
internal_action_names = ["complete_talk", "reply", "wait_time", "wait", "listening"]
|
||||
|
||||
logger.debug(
|
||||
f"{self.log_prefix}动作验证: action={action}, internal={internal_action_names}, available={available_action_names}"
|
||||
)
|
||||
|
||||
# 将 listening 转换为 wait(向后兼容)
|
||||
if action == "listening":
|
||||
logger.debug(f"{self.log_prefix}检测到 listening 动作,已合并到 wait,自动转换")
|
||||
action = "wait"
|
||||
|
||||
if action not in internal_action_names and action not in available_action_names:
|
||||
logger.warning(
|
||||
f"{self.log_prefix}LLM 返回了当前不可用或无效的动作: '{action}' (可用: {available_action_names}),将强制使用 'no_reply'"
|
||||
f"{self.log_prefix}LLM 返回了当前不可用或无效的动作: '{action}' (内部动作: {internal_action_names}, 可用插件动作: {available_action_names}),将强制使用 'complete_talk'"
|
||||
)
|
||||
reasoning = (
|
||||
f"LLM 返回了当前不可用的动作 '{action}' (可用: {available_action_names})。原始理由: {reasoning}"
|
||||
)
|
||||
action = "no_reply"
|
||||
action = "complete_talk"
|
||||
logger.warning(f"{self.log_prefix}动作已转换为 complete_talk")
|
||||
|
||||
# 创建ActionPlannerInfo对象
|
||||
# 将列表转换为字典格式
|
||||
@@ -201,7 +242,7 @@ class BrainPlanner:
|
||||
available_actions_dict = dict(current_available_actions)
|
||||
action_planner_infos.append(
|
||||
ActionPlannerInfo(
|
||||
action_type="no_reply",
|
||||
action_type="complete_talk",
|
||||
reasoning=f"解析单个action时出错: {e}",
|
||||
action_data={},
|
||||
action_message=None,
|
||||
@@ -218,7 +259,7 @@ class BrainPlanner:
|
||||
) -> List[ActionPlannerInfo]:
|
||||
# sourcery skip: use-named-expression
|
||||
"""
|
||||
规划器 (Planner): 使用LLM根据上下文决定做出什么动作。
|
||||
规划器 (Planner): 使用LLM根据上下文决定做出什么动作(ReAct模式)。
|
||||
"""
|
||||
|
||||
# 获取聊天上下文
|
||||
@@ -226,7 +267,7 @@ class BrainPlanner:
|
||||
chat_id=self.chat_id,
|
||||
timestamp=time.time(),
|
||||
limit=int(global_config.chat.max_context_size * 0.6),
|
||||
filter_no_read_command=True,
|
||||
filter_intercept_message_level=1,
|
||||
)
|
||||
message_id_list: list[Tuple[str, "DatabaseMessages"]] = []
|
||||
chat_content_block, message_id_list = build_readable_messages_with_id(
|
||||
@@ -257,18 +298,19 @@ class BrainPlanner:
|
||||
|
||||
logger.debug(f"{self.log_prefix}过滤后有{len(filtered_actions)}个可用动作")
|
||||
|
||||
# 构建包含所有动作的提示词
|
||||
# 构建包含所有动作的提示词:使用统一的 ReAct Prompt
|
||||
prompt_key = "brain_planner_prompt_react"
|
||||
# 这里不记录日志,避免重复打印,由调用方按需控制 log_prompt
|
||||
prompt, message_id_list = await self.build_planner_prompt(
|
||||
is_group_chat=is_group_chat,
|
||||
chat_target_info=chat_target_info,
|
||||
current_available_actions=filtered_actions,
|
||||
chat_content_block=chat_content_block,
|
||||
message_id_list=message_id_list,
|
||||
interest=global_config.personality.interest,
|
||||
prompt_key=prompt_key,
|
||||
)
|
||||
|
||||
# 调用LLM获取决策
|
||||
actions = await self._execute_main_planner(
|
||||
reasoning, actions = await self._execute_main_planner(
|
||||
prompt=prompt,
|
||||
message_id_list=message_id_list,
|
||||
filtered_actions=filtered_actions,
|
||||
@@ -276,16 +318,22 @@ class BrainPlanner:
|
||||
loop_start_time=loop_start_time,
|
||||
)
|
||||
|
||||
# 记录和展示计划日志
|
||||
logger.info(
|
||||
f"{self.log_prefix}Planner: {reasoning}。选择了{len(actions)}个动作: {' '.join([a.action_type for a in actions])}"
|
||||
)
|
||||
self.add_plan_log(reasoning, actions)
|
||||
|
||||
return actions
|
||||
|
||||
async def build_planner_prompt(
|
||||
self,
|
||||
is_group_chat: bool,
|
||||
chat_target_info: Optional["TargetPersonInfo"],
|
||||
current_available_actions: Dict[str, ActionInfo],
|
||||
message_id_list: List[Tuple[str, "DatabaseMessages"]],
|
||||
chat_content_block: str = "",
|
||||
interest: str = "",
|
||||
prompt_key: str = "brain_planner_prompt_react",
|
||||
) -> tuple[str, List[Tuple[str, "DatabaseMessages"]]]:
|
||||
"""构建 Planner LLM 的提示词 (获取模板并填充数据)"""
|
||||
try:
|
||||
@@ -321,7 +369,7 @@ class BrainPlanner:
|
||||
name_block = f"你的名字是{bot_name}{bot_nickname},请注意哪些是你自己的发言。"
|
||||
|
||||
# 获取主规划器模板并填充
|
||||
planner_prompt_template = await global_prompt_manager.get_prompt_async("brain_planner_prompt")
|
||||
planner_prompt_template = await global_prompt_manager.get_prompt_async(prompt_key)
|
||||
prompt = planner_prompt_template.format(
|
||||
time_block=time_block,
|
||||
chat_context_description=chat_context_description,
|
||||
@@ -431,17 +479,18 @@ class BrainPlanner:
|
||||
filtered_actions: Dict[str, ActionInfo],
|
||||
available_actions: Dict[str, ActionInfo],
|
||||
loop_start_time: float,
|
||||
) -> List[ActionPlannerInfo]:
|
||||
) -> Tuple[str, List[ActionPlannerInfo]]:
|
||||
"""执行主规划器"""
|
||||
llm_content = None
|
||||
actions: List[ActionPlannerInfo] = []
|
||||
extracted_reasoning = ""
|
||||
|
||||
try:
|
||||
# 调用LLM
|
||||
llm_content, (reasoning_content, _, _) = await self.planner_llm.generate_response_async(prompt=prompt)
|
||||
|
||||
# logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
|
||||
# logger.info(f"{self.log_prefix}规划器原始响应: {llm_content}")
|
||||
logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
|
||||
logger.info(f"{self.log_prefix}规划器原始响应: {llm_content}")
|
||||
|
||||
if global_config.debug.show_planner_prompt:
|
||||
logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
|
||||
@@ -456,10 +505,11 @@ class BrainPlanner:
|
||||
|
||||
except Exception as req_e:
|
||||
logger.error(f"{self.log_prefix}LLM 请求执行失败: {req_e}")
|
||||
return [
|
||||
extracted_reasoning = f"LLM 请求失败,模型出现问题: {req_e}"
|
||||
return extracted_reasoning, [
|
||||
ActionPlannerInfo(
|
||||
action_type="no_reply",
|
||||
reasoning=f"LLM 请求失败,模型出现问题: {req_e}",
|
||||
action_type="complete_talk",
|
||||
reasoning=extracted_reasoning,
|
||||
action_data={},
|
||||
action_message=None,
|
||||
available_actions=available_actions,
|
||||
@@ -469,24 +519,32 @@ class BrainPlanner:
|
||||
# 解析LLM响应
|
||||
if llm_content:
|
||||
try:
|
||||
if json_objects := self._extract_json_from_markdown(llm_content):
|
||||
logger.debug(f"{self.log_prefix}从响应中提取到{len(json_objects)}个JSON对象")
|
||||
json_objects, extracted_reasoning = self._extract_json_from_markdown(llm_content)
|
||||
if json_objects:
|
||||
logger.info(f"{self.log_prefix}从响应中提取到{len(json_objects)}个JSON对象")
|
||||
for i, json_obj in enumerate(json_objects):
|
||||
logger.info(f"{self.log_prefix}解析第{i + 1}个JSON对象: {json_obj}")
|
||||
filtered_actions_list = list(filtered_actions.items())
|
||||
for json_obj in json_objects:
|
||||
actions.extend(self._parse_single_action(json_obj, message_id_list, filtered_actions_list))
|
||||
parsed_actions = self._parse_single_action(json_obj, message_id_list, filtered_actions_list)
|
||||
logger.info(f"{self.log_prefix}解析后的动作: {[a.action_type for a in parsed_actions]}")
|
||||
actions.extend(parsed_actions)
|
||||
else:
|
||||
# 尝试解析为直接的JSON
|
||||
logger.warning(f"{self.log_prefix}LLM没有返回可用动作: {llm_content}")
|
||||
actions = self._create_no_reply("LLM没有返回可用动作", available_actions)
|
||||
extracted_reasoning = extracted_reasoning or "LLM没有返回可用动作"
|
||||
actions = self._create_complete_talk(extracted_reasoning, available_actions)
|
||||
|
||||
except Exception as json_e:
|
||||
logger.warning(f"{self.log_prefix}解析LLM响应JSON失败 {json_e}. LLM原始输出: '{llm_content}'")
|
||||
actions = self._create_no_reply(f"解析LLM响应JSON失败: {json_e}", available_actions)
|
||||
extracted_reasoning = f"解析LLM响应JSON失败: {json_e}"
|
||||
actions = self._create_complete_talk(extracted_reasoning, available_actions)
|
||||
traceback.print_exc()
|
||||
else:
|
||||
actions = self._create_no_reply("规划器没有获得LLM响应", available_actions)
|
||||
extracted_reasoning = "规划器没有获得LLM响应"
|
||||
actions = self._create_complete_talk(extracted_reasoning, available_actions)
|
||||
|
||||
# 添加循环开始时间到所有非no_reply动作
|
||||
# 添加循环开始时间到所有动作
|
||||
for action in actions:
|
||||
action.action_data = action.action_data or {}
|
||||
action.action_data["loop_start_time"] = loop_start_time
|
||||
@@ -495,13 +553,15 @@ class BrainPlanner:
|
||||
f"{self.log_prefix}规划器决定执行{len(actions)}个动作: {' '.join([a.action_type for a in actions])}"
|
||||
)
|
||||
|
||||
return actions
|
||||
return extracted_reasoning, actions
|
||||
|
||||
def _create_no_reply(self, reasoning: str, available_actions: Dict[str, ActionInfo]) -> List[ActionPlannerInfo]:
|
||||
"""创建no_reply"""
|
||||
def _create_complete_talk(
|
||||
self, reasoning: str, available_actions: Dict[str, ActionInfo]
|
||||
) -> List[ActionPlannerInfo]:
|
||||
"""创建complete_talk"""
|
||||
return [
|
||||
ActionPlannerInfo(
|
||||
action_type="no_reply",
|
||||
action_type="complete_talk",
|
||||
reasoning=reasoning,
|
||||
action_data={},
|
||||
action_message=None,
|
||||
@@ -509,33 +569,122 @@ class BrainPlanner:
|
||||
)
|
||||
]
|
||||
|
||||
def _extract_json_from_markdown(self, content: str) -> List[dict]:
|
||||
def add_plan_log(self, reasoning: str, actions: List[ActionPlannerInfo]):
|
||||
"""添加计划日志"""
|
||||
self.plan_log.append((reasoning, time.time(), actions))
|
||||
if len(self.plan_log) > 20:
|
||||
self.plan_log.pop(0)
|
||||
|
||||
def _extract_json_from_markdown(self, content: str) -> Tuple[List[dict], str]:
|
||||
# sourcery skip: for-append-to-extend
|
||||
"""从Markdown格式的内容中提取JSON对象"""
|
||||
"""从Markdown格式的内容中提取JSON对象和推理内容"""
|
||||
json_objects = []
|
||||
reasoning_content = ""
|
||||
|
||||
# 使用正则表达式查找```json包裹的JSON内容
|
||||
json_pattern = r"```json\s*(.*?)\s*```"
|
||||
matches = re.findall(json_pattern, content, re.DOTALL)
|
||||
markdown_matches = re.findall(json_pattern, content, re.DOTALL)
|
||||
|
||||
for match in matches:
|
||||
# 提取JSON之前的内容作为推理文本
|
||||
first_json_pos = len(content)
|
||||
if markdown_matches:
|
||||
# 找到第一个```json的位置
|
||||
first_json_pos = content.find("```json")
|
||||
if first_json_pos > 0:
|
||||
reasoning_content = content[:first_json_pos].strip()
|
||||
# 清理推理内容中的注释标记
|
||||
reasoning_content = re.sub(r"^//\s*", "", reasoning_content, flags=re.MULTILINE)
|
||||
reasoning_content = reasoning_content.strip()
|
||||
|
||||
# 处理```json包裹的JSON
|
||||
for match in markdown_matches:
|
||||
try:
|
||||
# 清理可能的注释和格式问题
|
||||
json_str = re.sub(r"//.*?\n", "\n", match) # 移除单行注释
|
||||
json_str = re.sub(r"/\*.*?\*/", "", json_str, flags=re.DOTALL) # 移除多行注释
|
||||
if json_str := json_str.strip():
|
||||
json_obj = json.loads(repair_json(json_str))
|
||||
if isinstance(json_obj, dict):
|
||||
json_objects.append(json_obj)
|
||||
elif isinstance(json_obj, list):
|
||||
for item in json_obj:
|
||||
if isinstance(item, dict):
|
||||
json_objects.append(item)
|
||||
# 先尝试将整个块作为一个JSON对象或数组(适用于多行JSON)
|
||||
try:
|
||||
json_obj = json.loads(repair_json(json_str))
|
||||
if isinstance(json_obj, dict):
|
||||
json_objects.append(json_obj)
|
||||
elif isinstance(json_obj, list):
|
||||
for item in json_obj:
|
||||
if isinstance(item, dict):
|
||||
json_objects.append(item)
|
||||
except json.JSONDecodeError:
|
||||
# 如果整个块解析失败,尝试按行分割(适用于多个单行JSON对象)
|
||||
lines = [line.strip() for line in json_str.split("\n") if line.strip()]
|
||||
for line in lines:
|
||||
try:
|
||||
# 尝试解析每一行作为独立的JSON对象
|
||||
json_obj = json.loads(repair_json(line))
|
||||
if isinstance(json_obj, dict):
|
||||
json_objects.append(json_obj)
|
||||
elif isinstance(json_obj, list):
|
||||
for item in json_obj:
|
||||
if isinstance(item, dict):
|
||||
json_objects.append(item)
|
||||
except json.JSONDecodeError:
|
||||
# 单行解析失败,继续下一行
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.warning(f"解析JSON块失败: {e}, 块内容: {match[:100]}...")
|
||||
logger.warning(f"{self.log_prefix}解析JSON块失败: {e}, 块内容: {match[:100]}...")
|
||||
continue
|
||||
|
||||
return json_objects
|
||||
# 如果没有找到完整的```json```块,尝试查找不完整的代码块(缺少结尾```)
|
||||
if not json_objects:
|
||||
json_start_pos = content.find("```json")
|
||||
if json_start_pos != -1:
|
||||
# 找到```json之后的内容
|
||||
json_content_start = json_start_pos + 7 # ```json的长度
|
||||
# 提取从```json之后到内容结尾的所有内容
|
||||
incomplete_json_str = content[json_content_start:].strip()
|
||||
|
||||
# 提取JSON之前的内容作为推理文本
|
||||
if json_start_pos > 0:
|
||||
reasoning_content = content[:json_start_pos].strip()
|
||||
reasoning_content = re.sub(r"^//\s*", "", reasoning_content, flags=re.MULTILINE)
|
||||
reasoning_content = reasoning_content.strip()
|
||||
|
||||
if incomplete_json_str:
|
||||
try:
|
||||
# 清理可能的注释和格式问题
|
||||
json_str = re.sub(r"//.*?\n", "\n", incomplete_json_str)
|
||||
json_str = re.sub(r"/\*.*?\*/", "", json_str, flags=re.DOTALL)
|
||||
json_str = json_str.strip()
|
||||
|
||||
if json_str:
|
||||
# 尝试按行分割,每行可能是一个JSON对象
|
||||
lines = [line.strip() for line in json_str.split("\n") if line.strip()]
|
||||
for line in lines:
|
||||
try:
|
||||
json_obj = json.loads(repair_json(line))
|
||||
if isinstance(json_obj, dict):
|
||||
json_objects.append(json_obj)
|
||||
elif isinstance(json_obj, list):
|
||||
for item in json_obj:
|
||||
if isinstance(item, dict):
|
||||
json_objects.append(item)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# 如果按行解析没有成功,尝试将整个块作为一个JSON对象或数组
|
||||
if not json_objects:
|
||||
try:
|
||||
json_obj = json.loads(repair_json(json_str))
|
||||
if isinstance(json_obj, dict):
|
||||
json_objects.append(json_obj)
|
||||
elif isinstance(json_obj, list):
|
||||
for item in json_obj:
|
||||
if isinstance(item, dict):
|
||||
json_objects.append(item)
|
||||
except Exception as e:
|
||||
logger.debug(f"尝试解析不完整的JSON代码块失败: {e}")
|
||||
except Exception as e:
|
||||
logger.debug(f"处理不完整的JSON代码块时出错: {e}")
|
||||
|
||||
return json_objects, reasoning_content
|
||||
|
||||
|
||||
init_prompt()
|
||||
|
||||
@@ -271,7 +271,7 @@ def _to_emoji_objects(data: Any) -> Tuple[List["MaiEmoji"], int]:
|
||||
|
||||
emoji.description = emoji_data.description
|
||||
# Deserialize emotion string from DB to list
|
||||
emoji.emotion = emoji_data.emotion.split(",") if emoji_data.emotion else []
|
||||
emoji.emotion = emoji_data.emotion.replace(",", ",").split(",") if emoji_data.emotion else []
|
||||
emoji.usage_count = emoji_data.usage_count
|
||||
|
||||
db_last_used_time = emoji_data.last_used_time
|
||||
@@ -732,7 +732,7 @@ class EmojiManager:
|
||||
emoji_record = Emoji.get_or_none(Emoji.emoji_hash == emoji_hash)
|
||||
if emoji_record and emoji_record.emotion:
|
||||
logger.info(f"[缓存命中] 从数据库获取表情包情感标签: {emoji_record.emotion[:50]}...")
|
||||
return emoji_record.emotion.split(",")
|
||||
return emoji_record.emotion.replace(",", ",").split(",")
|
||||
except Exception as e:
|
||||
logger.error(f"从数据库查询表情包情感标签时出错: {e}")
|
||||
|
||||
@@ -993,7 +993,7 @@ class EmojiManager:
|
||||
)
|
||||
|
||||
# 处理情感列表
|
||||
emotions = [e.strip() for e in emotions_text.split(",") if e.strip()]
|
||||
emotions = [e.strip() for e in emotions_text.replace(",", ",").split(",") if e.strip()]
|
||||
|
||||
# 根据情感标签数量随机选择 - 超过5个选3个,超过2个选2个
|
||||
if len(emotions) > 5:
|
||||
|
||||
@@ -1,163 +0,0 @@
|
||||
from datetime import datetime
|
||||
import time
|
||||
import asyncio
|
||||
from typing import Dict
|
||||
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
get_raw_msg_by_timestamp_with_chat,
|
||||
build_readable_messages,
|
||||
)
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.apis import frequency_api
|
||||
|
||||
|
||||
def init_prompt():
|
||||
Prompt(
|
||||
"""{name_block}
|
||||
{time_block}
|
||||
你现在正在聊天,请根据下面的聊天记录判断是否有用户觉得你的发言过于频繁或者发言过少
|
||||
{message_str}
|
||||
|
||||
如果用户觉得你的发言过于频繁,请输出"过于频繁",否则输出"正常"
|
||||
如果用户觉得你的发言过少,请输出"过少",否则输出"正常"
|
||||
**你只能输出以下三个词之一,不要输出任何其他文字、解释或标点:**
|
||||
- 正常
|
||||
- 过于频繁
|
||||
- 过少
|
||||
""",
|
||||
"frequency_adjust_prompt",
|
||||
)
|
||||
|
||||
|
||||
logger = get_logger("frequency_control")
|
||||
|
||||
|
||||
class FrequencyControl:
|
||||
"""简化的频率控制类,仅管理不同chat_id的频率值"""
|
||||
|
||||
def __init__(self, chat_id: str):
|
||||
self.chat_id = chat_id
|
||||
# 发言频率调整值
|
||||
self.talk_frequency_adjust: float = 1.0
|
||||
|
||||
self.last_frequency_adjust_time: float = 0.0
|
||||
self.frequency_model = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils_small, request_type="frequency.adjust"
|
||||
)
|
||||
# 频率调整锁,防止并发执行
|
||||
self._adjust_lock = asyncio.Lock()
|
||||
|
||||
def get_talk_frequency_adjust(self) -> float:
|
||||
"""获取发言频率调整值"""
|
||||
return self.talk_frequency_adjust
|
||||
|
||||
def set_talk_frequency_adjust(self, value: float) -> None:
|
||||
"""设置发言频率调整值"""
|
||||
self.talk_frequency_adjust = max(0.1, min(5.0, value))
|
||||
|
||||
async def trigger_frequency_adjust(self) -> None:
|
||||
# 使用异步锁防止并发执行
|
||||
async with self._adjust_lock:
|
||||
# 在锁内检查,避免并发触发
|
||||
current_time = time.time()
|
||||
previous_adjust_time = self.last_frequency_adjust_time
|
||||
|
||||
msg_list = get_raw_msg_by_timestamp_with_chat(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=previous_adjust_time,
|
||||
timestamp_end=current_time,
|
||||
)
|
||||
|
||||
if current_time - previous_adjust_time < 160 or len(msg_list) <= 20:
|
||||
return
|
||||
|
||||
# 立即更新调整时间,防止并发触发
|
||||
self.last_frequency_adjust_time = current_time
|
||||
|
||||
try:
|
||||
new_msg_list = get_raw_msg_by_timestamp_with_chat(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=previous_adjust_time,
|
||||
timestamp_end=current_time,
|
||||
limit=20,
|
||||
limit_mode="latest",
|
||||
)
|
||||
|
||||
message_str = build_readable_messages(
|
||||
new_msg_list,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
show_actions=False,
|
||||
)
|
||||
time_block = f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
bot_name = global_config.bot.nickname
|
||||
bot_nickname = (
|
||||
f",也有人叫你{','.join(global_config.bot.alias_names)}" if global_config.bot.alias_names else ""
|
||||
)
|
||||
name_block = f"你的名字是{bot_name}{bot_nickname},请注意哪些是你自己的发言。"
|
||||
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
"frequency_adjust_prompt",
|
||||
name_block=name_block,
|
||||
time_block=time_block,
|
||||
message_str=message_str,
|
||||
)
|
||||
response, (reasoning_content, _, _) = await self.frequency_model.generate_response_async(
|
||||
prompt,
|
||||
)
|
||||
|
||||
# logger.info(f"频率调整 prompt: {prompt}")
|
||||
# logger.info(f"频率调整 response: {response}")
|
||||
|
||||
if global_config.debug.show_prompt:
|
||||
logger.info(f"频率调整 prompt: {prompt}")
|
||||
logger.info(f"频率调整 response: {response}")
|
||||
logger.info(f"频率调整 reasoning_content: {reasoning_content}")
|
||||
|
||||
final_value_by_api = frequency_api.get_current_talk_value(self.chat_id)
|
||||
|
||||
# LLM依然输出过多内容时取消本次调整。合法最多4个字,但有的模型可能会输出一些markdown换行符等,需要长度宽限
|
||||
if len(response) < 20:
|
||||
if "过于频繁" in response:
|
||||
logger.info(f"频率调整: 过于频繁,调整值到{final_value_by_api}")
|
||||
self.talk_frequency_adjust = max(0.1, min(1.5, self.talk_frequency_adjust * 0.8))
|
||||
elif "过少" in response:
|
||||
logger.info(f"频率调整: 过少,调整值到{final_value_by_api}")
|
||||
self.talk_frequency_adjust = max(0.1, min(1.5, self.talk_frequency_adjust * 1.2))
|
||||
except Exception as e:
|
||||
logger.error(f"频率调整失败: {e}")
|
||||
# 即使失败也保持时间戳更新,避免频繁重试
|
||||
|
||||
|
||||
class FrequencyControlManager:
|
||||
"""频率控制管理器,管理多个聊天流的频率控制实例"""
|
||||
|
||||
def __init__(self):
|
||||
self.frequency_control_dict: Dict[str, FrequencyControl] = {}
|
||||
|
||||
def get_or_create_frequency_control(self, chat_id: str) -> FrequencyControl:
|
||||
"""获取或创建指定聊天流的频率控制实例"""
|
||||
if chat_id not in self.frequency_control_dict:
|
||||
self.frequency_control_dict[chat_id] = FrequencyControl(chat_id)
|
||||
return self.frequency_control_dict[chat_id]
|
||||
|
||||
def remove_frequency_control(self, chat_id: str) -> bool:
|
||||
"""移除指定聊天流的频率控制实例"""
|
||||
if chat_id in self.frequency_control_dict:
|
||||
del self.frequency_control_dict[chat_id]
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_all_chat_ids(self) -> list[str]:
|
||||
"""获取所有有频率控制的聊天ID"""
|
||||
return list(self.frequency_control_dict.keys())
|
||||
|
||||
|
||||
init_prompt()
|
||||
|
||||
# 创建全局实例
|
||||
frequency_control_manager = FrequencyControlManager()
|
||||
50
src/chat/heart_flow/frequency_control.py
Normal file
50
src/chat/heart_flow/frequency_control.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from typing import Dict
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("frequency_control")
|
||||
|
||||
|
||||
class FrequencyControl:
|
||||
"""简化的频率控制类,仅管理不同chat_id的频率值"""
|
||||
|
||||
def __init__(self, chat_id: str):
|
||||
self.chat_id = chat_id
|
||||
# 发言频率调整值
|
||||
self.talk_frequency_adjust: float = 1.0
|
||||
|
||||
def get_talk_frequency_adjust(self) -> float:
|
||||
"""获取发言频率调整值"""
|
||||
return self.talk_frequency_adjust
|
||||
|
||||
def set_talk_frequency_adjust(self, value: float) -> None:
|
||||
"""设置发言频率调整值"""
|
||||
self.talk_frequency_adjust = max(0.1, min(5.0, value))
|
||||
|
||||
|
||||
class FrequencyControlManager:
|
||||
"""频率控制管理器,管理多个聊天流的频率控制实例"""
|
||||
|
||||
def __init__(self):
|
||||
self.frequency_control_dict: Dict[str, FrequencyControl] = {}
|
||||
|
||||
def get_or_create_frequency_control(self, chat_id: str) -> FrequencyControl:
|
||||
"""获取或创建指定聊天流的频率控制实例"""
|
||||
if chat_id not in self.frequency_control_dict:
|
||||
self.frequency_control_dict[chat_id] = FrequencyControl(chat_id)
|
||||
return self.frequency_control_dict[chat_id]
|
||||
|
||||
def remove_frequency_control(self, chat_id: str) -> bool:
|
||||
"""移除指定聊天流的频率控制实例"""
|
||||
if chat_id in self.frequency_control_dict:
|
||||
del self.frequency_control_dict[chat_id]
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_all_chat_ids(self) -> list[str]:
|
||||
"""获取所有有频率控制的聊天ID"""
|
||||
return list(self.frequency_control_dict.keys())
|
||||
|
||||
|
||||
# 创建全局实例
|
||||
frequency_control_manager = FrequencyControlManager()
|
||||
@@ -16,11 +16,11 @@ from src.chat.planner_actions.planner import ActionPlanner
|
||||
from src.chat.planner_actions.action_modifier import ActionModifier
|
||||
from src.chat.planner_actions.action_manager import ActionManager
|
||||
from src.chat.heart_flow.hfc_utils import CycleDetail
|
||||
from src.express.expression_learner import expression_learner_manager
|
||||
from src.chat.frequency_control.frequency_control import frequency_control_manager
|
||||
from src.express.reflect_tracker import reflect_tracker_manager
|
||||
from src.express.expression_reflector import expression_reflector_manager
|
||||
from src.jargon import extract_and_store_jargon
|
||||
from src.bw_learner.expression_learner import expression_learner_manager
|
||||
from src.chat.heart_flow.frequency_control import frequency_control_manager
|
||||
from src.bw_learner.reflect_tracker import reflect_tracker_manager
|
||||
from src.bw_learner.expression_reflector import expression_reflector_manager
|
||||
from src.bw_learner.message_recorder import extract_and_distribute_messages
|
||||
from src.person_info.person_info import Person
|
||||
from src.plugin_system.base.component_types import EventType, ActionInfo
|
||||
from src.plugin_system.core import events_manager
|
||||
@@ -29,6 +29,7 @@ from src.chat.utils.chat_message_builder import (
|
||||
build_readable_messages_with_id,
|
||||
get_raw_msg_before_timestamp_with_chat,
|
||||
)
|
||||
from src.chat.utils.utils import record_replyer_action_temp
|
||||
from src.hippo_memorizer.chat_history_summarizer import ChatHistorySummarizer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -99,7 +100,6 @@ class HeartFChatting:
|
||||
self._current_cycle_detail: CycleDetail = None # type: ignore
|
||||
|
||||
self.last_read_time = time.time() - 2
|
||||
self.no_reply_until_call = False
|
||||
|
||||
self.is_mute = False
|
||||
|
||||
@@ -190,7 +190,7 @@ class HeartFChatting:
|
||||
limit_mode="latest",
|
||||
filter_mai=True,
|
||||
filter_command=False,
|
||||
filter_no_read_command=True,
|
||||
filter_intercept_message_level=0,
|
||||
)
|
||||
|
||||
# 根据连续 no_reply 次数动态调整阈值
|
||||
@@ -207,23 +207,6 @@ class HeartFChatting:
|
||||
if len(recent_messages_list) >= threshold:
|
||||
# for message in recent_messages_list:
|
||||
# print(message.processed_plain_text)
|
||||
# !处理no_reply_until_call逻辑
|
||||
if self.no_reply_until_call:
|
||||
for message in recent_messages_list:
|
||||
if (
|
||||
message.is_mentioned
|
||||
or message.is_at
|
||||
or len(recent_messages_list) >= 8
|
||||
or time.time() - self.last_read_time > 600
|
||||
):
|
||||
self.no_reply_until_call = False
|
||||
self.last_read_time = time.time()
|
||||
break
|
||||
# 没有提到,继续保持沉默
|
||||
if self.no_reply_until_call:
|
||||
# logger.info(f"{self.log_prefix} 没有提到,继续保持沉默")
|
||||
await asyncio.sleep(1)
|
||||
return True
|
||||
|
||||
self.last_read_time = time.time()
|
||||
|
||||
@@ -303,90 +286,6 @@ class HeartFChatting:
|
||||
|
||||
return loop_info, reply_text, cycle_timers
|
||||
|
||||
async def _run_planner_without_reply(
|
||||
self,
|
||||
available_actions: Dict[str, ActionInfo],
|
||||
cycle_timers: Dict[str, float],
|
||||
) -> List[ActionPlannerInfo]:
|
||||
"""执行planner,但不包含reply动作(用于并行执行场景,提及时使用简化版提示词)"""
|
||||
try:
|
||||
with Timer("规划器", cycle_timers):
|
||||
action_to_use_info = await self.action_planner.plan(
|
||||
loop_start_time=self.last_read_time,
|
||||
available_actions=available_actions,
|
||||
is_mentioned=True, # 标记为提及时,使用简化版提示词
|
||||
)
|
||||
# 过滤掉reply动作(虽然提及时不应该有reply,但为了安全还是过滤一下)
|
||||
return [action for action in action_to_use_info if action.action_type != "reply"]
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} Planner执行失败: {e}")
|
||||
traceback.print_exc()
|
||||
return []
|
||||
|
||||
async def _generate_mentioned_reply(
|
||||
self,
|
||||
force_reply_message: "DatabaseMessages",
|
||||
thinking_id: str,
|
||||
cycle_timers: Dict[str, float],
|
||||
available_actions: Dict[str, ActionInfo],
|
||||
) -> Dict[str, Any]:
|
||||
"""当被提及时,独立生成回复的任务"""
|
||||
try:
|
||||
self.questioned = False
|
||||
# 重置连续 no_reply 计数
|
||||
self.consecutive_no_reply_count = 0
|
||||
reason = ""
|
||||
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.chat_stream,
|
||||
action_build_into_prompt=False,
|
||||
action_prompt_display=reason,
|
||||
action_done=True,
|
||||
thinking_id=thinking_id,
|
||||
action_data={},
|
||||
action_name="reply",
|
||||
action_reasoning=reason,
|
||||
)
|
||||
|
||||
with Timer("提及回复生成", cycle_timers):
|
||||
success, llm_response = await generator_api.generate_reply(
|
||||
chat_stream=self.chat_stream,
|
||||
reply_message=force_reply_message,
|
||||
available_actions=available_actions,
|
||||
chosen_actions=[], # 独立回复,不依赖planner的动作
|
||||
reply_reason=reason,
|
||||
enable_tool=global_config.tool.enable_tool,
|
||||
request_type="replyer",
|
||||
from_plugin=False,
|
||||
reply_time_point=self.last_read_time,
|
||||
)
|
||||
|
||||
if not success or not llm_response or not llm_response.reply_set:
|
||||
logger.warning(f"{self.log_prefix} 提及回复生成失败")
|
||||
return {"action_type": "reply", "success": False, "result": "提及回复生成失败", "loop_info": None}
|
||||
|
||||
response_set = llm_response.reply_set
|
||||
selected_expressions = llm_response.selected_expressions
|
||||
loop_info, reply_text, _ = await self._send_and_store_reply(
|
||||
response_set=response_set,
|
||||
action_message=force_reply_message,
|
||||
cycle_timers=cycle_timers,
|
||||
thinking_id=thinking_id,
|
||||
actions=[], # 独立回复,不依赖planner的动作
|
||||
selected_expressions=selected_expressions,
|
||||
)
|
||||
self.last_active_time = time.time()
|
||||
return {
|
||||
"action_type": "reply",
|
||||
"success": True,
|
||||
"result": f"你回复内容{reply_text}",
|
||||
"loop_info": loop_info,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 提及回复生成异常: {e}")
|
||||
traceback.print_exc()
|
||||
return {"action_type": "reply", "success": False, "result": f"提及回复生成异常: {e}", "loop_info": None}
|
||||
|
||||
async def _observe(
|
||||
self, # interest_value: float = 0.0,
|
||||
recent_messages_list: Optional[List["DatabaseMessages"]] = None,
|
||||
@@ -412,15 +311,12 @@ class HeartFChatting:
|
||||
|
||||
start_time = time.time()
|
||||
async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
|
||||
asyncio.create_task(self.expression_learner.trigger_learning_for_chat())
|
||||
asyncio.create_task(
|
||||
frequency_control_manager.get_or_create_frequency_control(self.stream_id).trigger_frequency_adjust()
|
||||
)
|
||||
# 通过 MessageRecorder 统一提取消息并分发给 expression_learner 和 jargon_miner
|
||||
# 在 replyer 执行时触发,统一管理时间窗口,避免重复获取消息
|
||||
asyncio.create_task(extract_and_distribute_messages(self.stream_id))
|
||||
|
||||
# 添加curious检测任务 - 检测聊天记录中的矛盾、冲突或需要提问的内容
|
||||
# asyncio.create_task(check_and_make_question(self.stream_id))
|
||||
# 添加jargon提取任务 - 提取聊天中的黑话/俚语并入库(内部自行取消息并带冷却)
|
||||
asyncio.create_task(extract_and_store_jargon(self.stream_id))
|
||||
# 添加聊天内容概括任务 - 累积、打包和压缩聊天记录
|
||||
# 注意:后台循环已在start()中启动,这里作为额外触发点,在有思考时立即处理
|
||||
# asyncio.create_task(self.chat_history_summarizer.process())
|
||||
@@ -438,95 +334,50 @@ class HeartFChatting:
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 动作修改失败: {e}")
|
||||
|
||||
# 如果被提及,让回复生成和planner并行执行
|
||||
if force_reply_message:
|
||||
logger.info(f"{self.log_prefix} 检测到提及,回复生成与planner并行执行")
|
||||
# 执行planner
|
||||
is_group_chat, chat_target_info, _ = self.action_planner.get_necessary_info()
|
||||
|
||||
# 并行执行planner和回复生成
|
||||
planner_task = asyncio.create_task(
|
||||
self._run_planner_without_reply(
|
||||
available_actions=available_actions,
|
||||
cycle_timers=cycle_timers,
|
||||
)
|
||||
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=self.stream_id,
|
||||
timestamp=time.time(),
|
||||
limit=int(global_config.chat.max_context_size * 0.6),
|
||||
filter_intercept_message_level=1,
|
||||
)
|
||||
chat_content_block, message_id_list = build_readable_messages_with_id(
|
||||
messages=message_list_before_now,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
read_mark=self.action_planner.last_obs_time_mark,
|
||||
truncate=True,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
prompt_info = await self.action_planner.build_planner_prompt(
|
||||
is_group_chat=is_group_chat,
|
||||
chat_target_info=chat_target_info,
|
||||
current_available_actions=available_actions,
|
||||
chat_content_block=chat_content_block,
|
||||
message_id_list=message_id_list,
|
||||
)
|
||||
continue_flag, modified_message = await events_manager.handle_mai_events(
|
||||
EventType.ON_PLAN, None, prompt_info[0], None, self.chat_stream.stream_id
|
||||
)
|
||||
if not continue_flag:
|
||||
return False
|
||||
if modified_message and modified_message._modify_flags.modify_llm_prompt:
|
||||
prompt_info = (modified_message.llm_prompt, prompt_info[1])
|
||||
|
||||
with Timer("规划器", cycle_timers):
|
||||
action_to_use_info = await self.action_planner.plan(
|
||||
loop_start_time=self.last_read_time,
|
||||
available_actions=available_actions,
|
||||
force_reply_message=force_reply_message,
|
||||
)
|
||||
reply_task = asyncio.create_task(
|
||||
self._generate_mentioned_reply(
|
||||
force_reply_message=force_reply_message,
|
||||
thinking_id=thinking_id,
|
||||
cycle_timers=cycle_timers,
|
||||
available_actions=available_actions,
|
||||
)
|
||||
)
|
||||
|
||||
# 等待两个任务完成
|
||||
planner_result, reply_result = await asyncio.gather(planner_task, reply_task, return_exceptions=True)
|
||||
|
||||
# 处理planner结果
|
||||
if isinstance(planner_result, BaseException):
|
||||
logger.error(f"{self.log_prefix} Planner执行异常: {planner_result}")
|
||||
action_to_use_info = []
|
||||
else:
|
||||
action_to_use_info = planner_result
|
||||
|
||||
# 处理回复结果
|
||||
if isinstance(reply_result, BaseException):
|
||||
logger.error(f"{self.log_prefix} 回复生成异常: {reply_result}")
|
||||
reply_result = {
|
||||
"action_type": "reply",
|
||||
"success": False,
|
||||
"result": "回复生成异常",
|
||||
"loop_info": None,
|
||||
}
|
||||
else:
|
||||
# 正常流程:只执行planner
|
||||
is_group_chat, chat_target_info, _ = self.action_planner.get_necessary_info()
|
||||
|
||||
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=self.stream_id,
|
||||
timestamp=time.time(),
|
||||
limit=int(global_config.chat.max_context_size * 0.6),
|
||||
filter_no_read_command=True,
|
||||
)
|
||||
chat_content_block, message_id_list = build_readable_messages_with_id(
|
||||
messages=message_list_before_now,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
read_mark=self.action_planner.last_obs_time_mark,
|
||||
truncate=True,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
prompt_info = await self.action_planner.build_planner_prompt(
|
||||
is_group_chat=is_group_chat,
|
||||
chat_target_info=chat_target_info,
|
||||
current_available_actions=available_actions,
|
||||
chat_content_block=chat_content_block,
|
||||
message_id_list=message_id_list,
|
||||
interest=global_config.personality.interest,
|
||||
)
|
||||
continue_flag, modified_message = await events_manager.handle_mai_events(
|
||||
EventType.ON_PLAN, None, prompt_info[0], None, self.chat_stream.stream_id
|
||||
)
|
||||
if not continue_flag:
|
||||
return False
|
||||
if modified_message and modified_message._modify_flags.modify_llm_prompt:
|
||||
prompt_info = (modified_message.llm_prompt, prompt_info[1])
|
||||
|
||||
with Timer("规划器", cycle_timers):
|
||||
action_to_use_info = await self.action_planner.plan(
|
||||
loop_start_time=self.last_read_time,
|
||||
available_actions=available_actions,
|
||||
)
|
||||
reply_result = None
|
||||
|
||||
# 只在提及情况下过滤掉planner返回的reply动作(提及时已有独立回复生成)
|
||||
if force_reply_message:
|
||||
action_to_use_info = [action for action in action_to_use_info if action.action_type != "reply"]
|
||||
|
||||
logger.info(
|
||||
f"{self.log_prefix} 决定执行{len(action_to_use_info)}个动作: {' '.join([a.action_type for a in action_to_use_info])}"
|
||||
)
|
||||
|
||||
# 3. 并行执行所有动作(不包括reply,reply已经独立执行)
|
||||
# 3. 并行执行所有动作
|
||||
action_tasks = [
|
||||
asyncio.create_task(
|
||||
self._execute_action(action, action_to_use_info, thinking_id, available_actions, cycle_timers)
|
||||
@@ -537,10 +388,6 @@ class HeartFChatting:
|
||||
# 并行执行所有任务
|
||||
results = await asyncio.gather(*action_tasks, return_exceptions=True)
|
||||
|
||||
# 如果有独立的回复结果,添加到结果列表中
|
||||
if reply_result:
|
||||
results = list(results) + [reply_result]
|
||||
|
||||
# 处理执行结果
|
||||
reply_loop_info = None
|
||||
reply_text_from_reply = ""
|
||||
@@ -751,31 +598,6 @@ class HeartFChatting:
|
||||
|
||||
return {"action_type": "no_reply", "success": True, "result": "选择不回复", "command": ""}
|
||||
|
||||
elif action_planner_info.action_type == "no_reply_until_call":
|
||||
# 直接当场执行no_reply_until_call逻辑
|
||||
logger.info(f"{self.log_prefix} 保持沉默,直到有人直接叫的名字")
|
||||
reason = action_planner_info.reasoning or "选择不回复"
|
||||
|
||||
# 增加连续 no_reply 计数
|
||||
self.consecutive_no_reply_count += 1
|
||||
self.no_reply_until_call = True
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.chat_stream,
|
||||
action_build_into_prompt=False,
|
||||
action_prompt_display=reason,
|
||||
action_done=True,
|
||||
thinking_id=thinking_id,
|
||||
action_data={},
|
||||
action_name="no_reply_until_call",
|
||||
action_reasoning=reason,
|
||||
)
|
||||
return {
|
||||
"action_type": "no_reply_until_call",
|
||||
"success": True,
|
||||
"result": "保持沉默,直到有人直接叫的名字",
|
||||
"command": "",
|
||||
}
|
||||
|
||||
elif action_planner_info.action_type == "reply":
|
||||
# 直接当场执行reply逻辑
|
||||
self.questioned = False
|
||||
@@ -784,8 +606,27 @@ class HeartFChatting:
|
||||
self.consecutive_no_reply_count = 0
|
||||
|
||||
reason = action_planner_info.reasoning or ""
|
||||
# 根据 think_mode 配置决定 think_level 的值
|
||||
think_mode = global_config.chat.think_mode
|
||||
if think_mode == "default":
|
||||
think_level = 0
|
||||
elif think_mode == "deep":
|
||||
think_level = 1
|
||||
elif think_mode == "dynamic":
|
||||
# dynamic 模式:从 planner 返回的 action_data 中获取
|
||||
think_level = action_planner_info.action_data.get("think_level", 1)
|
||||
else:
|
||||
# 默认使用 default 模式
|
||||
think_level = 0
|
||||
# 使用 action_reasoning(planner 的整体思考理由)作为 reply_reason
|
||||
planner_reasoning = action_planner_info.action_reasoning or reason
|
||||
|
||||
record_replyer_action_temp(
|
||||
chat_id=self.stream_id,
|
||||
reason=reason,
|
||||
think_level=think_level,
|
||||
)
|
||||
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.chat_stream,
|
||||
action_build_into_prompt=False,
|
||||
@@ -797,16 +638,32 @@ class HeartFChatting:
|
||||
action_reasoning=reason,
|
||||
)
|
||||
|
||||
# 从 Planner 的 action_data 中提取未知词语列表(仅在 reply 时使用)
|
||||
unknown_words = None
|
||||
if isinstance(action_planner_info.action_data, dict):
|
||||
uw = action_planner_info.action_data.get("unknown_words")
|
||||
if isinstance(uw, list):
|
||||
cleaned_uw: List[str] = []
|
||||
for item in uw:
|
||||
if isinstance(item, str):
|
||||
s = item.strip()
|
||||
if s:
|
||||
cleaned_uw.append(s)
|
||||
if cleaned_uw:
|
||||
unknown_words = cleaned_uw
|
||||
|
||||
success, llm_response = await generator_api.generate_reply(
|
||||
chat_stream=self.chat_stream,
|
||||
reply_message=action_planner_info.action_message,
|
||||
available_actions=available_actions,
|
||||
chosen_actions=chosen_action_plan_infos,
|
||||
reply_reason=planner_reasoning,
|
||||
unknown_words=unknown_words,
|
||||
enable_tool=global_config.tool.enable_tool,
|
||||
request_type="replyer",
|
||||
from_plugin=False,
|
||||
reply_time_point=action_planner_info.action_data.get("loop_start_time", time.time()),
|
||||
think_level=think_level,
|
||||
)
|
||||
|
||||
if not success or not llm_response or not llm_response.reply_set:
|
||||
|
||||
@@ -42,7 +42,10 @@ def lpmm_start_up(): # sourcery skip: extract-duplicate-method
|
||||
logger.info("创建LLM客户端")
|
||||
|
||||
# 初始化Embedding库
|
||||
embed_manager = EmbeddingManager()
|
||||
embed_manager = EmbeddingManager(
|
||||
max_workers=global_config.lpmm_knowledge.max_embedding_workers,
|
||||
chunk_size=global_config.lpmm_knowledge.embedding_chunk_size,
|
||||
)
|
||||
logger.info("正在从文件加载Embedding库")
|
||||
try:
|
||||
embed_manager.load_from_file()
|
||||
|
||||
@@ -104,7 +104,9 @@ class EmbeddingStore:
|
||||
self.dir = dir_path
|
||||
self.embedding_file_path = f"{dir_path}/{namespace}.parquet"
|
||||
self.index_file_path = f"{dir_path}/{namespace}.index"
|
||||
self.idx2hash_file_path = dir_path + "/" + namespace + "_i2h.json"
|
||||
self.idx2hash_file_path = f"{dir_path}/{namespace}_i2h.json"
|
||||
|
||||
self.dirty = False # 标记是否有新增数据需要重建索引
|
||||
|
||||
# 多线程配置参数验证和设置
|
||||
self.max_workers = max(MIN_WORKERS, min(MAX_WORKERS, max_workers))
|
||||
@@ -125,6 +127,11 @@ class EmbeddingStore:
|
||||
self.faiss_index = None
|
||||
self.idx2hash = None
|
||||
|
||||
@staticmethod
|
||||
def hash_texts(namespace: str, texts: List[str]) -> List[str]:
|
||||
"""将原文计算为带前缀的键"""
|
||||
return [f"{namespace}-{get_sha256(t)}" for t in texts]
|
||||
|
||||
def _get_embedding(self, s: str) -> List[float]:
|
||||
"""获取字符串的嵌入向量,使用完全同步的方式避免事件循环问题"""
|
||||
# 创建新的事件循环并在完成后立即关闭
|
||||
@@ -412,6 +419,7 @@ class EmbeddingStore:
|
||||
item_hash = self.namespace + "-" + get_sha256(s)
|
||||
if embedding: # 只有成功获取到嵌入才存入
|
||||
self.store[item_hash] = EmbeddingStoreItem(item_hash, embedding, s)
|
||||
self.dirty = True
|
||||
else:
|
||||
logger.warning(f"跳过存储失败的嵌入: {s[:50]}...")
|
||||
|
||||
@@ -488,9 +496,17 @@ class EmbeddingStore:
|
||||
self.build_faiss_index()
|
||||
logger.info(f"{self.namespace}嵌入库的FaissIndex重建成功")
|
||||
self.save_to_file()
|
||||
self.dirty = False
|
||||
|
||||
def build_faiss_index(self) -> None:
|
||||
"""重新构建Faiss索引,以余弦相似度为度量"""
|
||||
# 空库直接跳过,清空索引映射
|
||||
if not self.store:
|
||||
self.idx2hash = {}
|
||||
self.faiss_index = None
|
||||
self.dirty = False
|
||||
return
|
||||
|
||||
# 获取所有的embedding
|
||||
array = []
|
||||
self.idx2hash = dict()
|
||||
@@ -498,11 +514,44 @@ class EmbeddingStore:
|
||||
array.append(self.store[key].embedding)
|
||||
self.idx2hash[str(len(array) - 1)] = key
|
||||
embeddings = np.array(array, dtype=np.float32)
|
||||
if embeddings.size == 0:
|
||||
self.idx2hash = {}
|
||||
self.faiss_index = None
|
||||
self.dirty = False
|
||||
return
|
||||
# L2归一化
|
||||
faiss.normalize_L2(embeddings)
|
||||
# 构建索引
|
||||
self.faiss_index = faiss.IndexFlatIP(global_config.lpmm_knowledge.embedding_dimension)
|
||||
self.faiss_index.add(embeddings)
|
||||
self.dirty = False
|
||||
|
||||
def delete_items(self, hashes: List[str]) -> Tuple[int, int]:
|
||||
"""删除指定键的嵌入并重建 idx2hash(不直接重建 faiss)
|
||||
|
||||
Args:
|
||||
hashes: 需要删除的完整键列表(如 paragraph-xxx)
|
||||
|
||||
Returns:
|
||||
(deleted, skipped)
|
||||
"""
|
||||
deleted = 0
|
||||
skipped = 0
|
||||
for h in hashes:
|
||||
if h in self.store:
|
||||
self.store.pop(h)
|
||||
deleted += 1
|
||||
else:
|
||||
skipped += 1
|
||||
|
||||
# 重新构建 idx2hash 映射
|
||||
self.idx2hash = {}
|
||||
for idx, key in enumerate(self.store.keys()):
|
||||
self.idx2hash[str(idx)] = key
|
||||
|
||||
# 删除后标记 dirty,faiss 重建由上层统一调用
|
||||
self.dirty = True
|
||||
return deleted, skipped
|
||||
|
||||
def search_top_k(self, query: List[float], k: int) -> List[Tuple[str, float]]:
|
||||
"""搜索最相似的k个项,以余弦相似度为度量
|
||||
@@ -536,7 +585,7 @@ class EmbeddingStore:
|
||||
|
||||
|
||||
class EmbeddingManager:
|
||||
def __init__(self, max_workers: int = DEFAULT_MAX_WORKERS, chunk_size: int = DEFAULT_CHUNK_SIZE):
|
||||
def __init__(self, max_workers: int | None = None, chunk_size: int | None = None):
|
||||
"""
|
||||
初始化EmbeddingManager
|
||||
|
||||
@@ -544,6 +593,8 @@ class EmbeddingManager:
|
||||
max_workers: 最大线程数
|
||||
chunk_size: 每个线程处理的数据块大小
|
||||
"""
|
||||
max_workers = max_workers if max_workers is not None else global_config.lpmm_knowledge.max_embedding_workers
|
||||
chunk_size = chunk_size if chunk_size is not None else global_config.lpmm_knowledge.embedding_chunk_size
|
||||
self.paragraphs_embedding_store = EmbeddingStore(
|
||||
"paragraph", # type: ignore
|
||||
EMBEDDING_DATA_DIR_STR,
|
||||
@@ -617,7 +668,19 @@ class EmbeddingManager:
|
||||
self.relation_embedding_store.save_to_file()
|
||||
|
||||
def rebuild_faiss_index(self):
|
||||
"""重建Faiss索引(请在添加新数据后调用)"""
|
||||
self.paragraphs_embedding_store.build_faiss_index()
|
||||
self.entities_embedding_store.build_faiss_index()
|
||||
self.relation_embedding_store.build_faiss_index()
|
||||
"""重建Faiss索引,新增数据后调用,带跳过逻辑"""
|
||||
|
||||
def _rebuild_if_needed(store: EmbeddingStore):
|
||||
if (
|
||||
not store.dirty
|
||||
and store.faiss_index is not None
|
||||
and store.idx2hash is not None
|
||||
and getattr(store.faiss_index, "ntotal", 0) == len(store.idx2hash) == len(store.store)
|
||||
):
|
||||
logger.info(f"{store.namespace} FaissIndex 已是最新,跳过重建")
|
||||
return
|
||||
store.build_faiss_index()
|
||||
|
||||
_rebuild_if_needed(self.paragraphs_embedding_store)
|
||||
_rebuild_if_needed(self.entities_embedding_store)
|
||||
_rebuild_if_needed(self.relation_embedding_store)
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from typing import Dict, List, Tuple
|
||||
from typing import Dict, List, Tuple, Set
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@@ -98,6 +99,28 @@ class KGManager:
|
||||
# 加载KG
|
||||
self.graph = di_graph.load_from_file(self.graph_data_path)
|
||||
|
||||
def _rebuild_metadata_from_graph(self) -> None:
|
||||
"""根据当前图重建 stored_paragraph_hashes 与 ent_appear_cnt"""
|
||||
nodes = self.graph.get_node_list()
|
||||
edges = self.graph.get_edge_list()
|
||||
|
||||
# 段落 hash:paragraph-{hash}
|
||||
self.stored_paragraph_hashes = set()
|
||||
for node_id in nodes:
|
||||
if node_id.startswith("paragraph-"):
|
||||
self.stored_paragraph_hashes.add(node_id.split("paragraph-", 1)[1])
|
||||
|
||||
# 实体出现次数:基于 entity -> paragraph 的边权
|
||||
ent_appear_cnt: Dict[str, float] = {}
|
||||
for edge_tuple in edges:
|
||||
src, tgt = edge_tuple[0], edge_tuple[1]
|
||||
if src.startswith("entity") and tgt.startswith("paragraph"):
|
||||
edge_data = self.graph[src, tgt]
|
||||
weight = edge_data["weight"] if "weight" in edge_data else 1.0
|
||||
ent_appear_cnt[src] = ent_appear_cnt.get(src, 0.0) + float(weight)
|
||||
|
||||
self.ent_appear_cnt = ent_appear_cnt
|
||||
|
||||
def _build_edges_between_ent(
|
||||
self,
|
||||
node_to_node: Dict[Tuple[str, str], float],
|
||||
@@ -149,6 +172,13 @@ class KGManager:
|
||||
ent_hash_list.add("entity" + "-" + get_sha256(triple[0]))
|
||||
ent_hash_list.add("entity" + "-" + get_sha256(triple[2]))
|
||||
ent_hash_list = list(ent_hash_list)
|
||||
# 性能保护:限制同义连接的实体数量
|
||||
max_synonym_entities = global_config.lpmm_knowledge.max_synonym_entities
|
||||
if max_synonym_entities and len(ent_hash_list) > max_synonym_entities:
|
||||
logger.warning(
|
||||
f"同义连接实体数 {len(ent_hash_list)} 超过阈值 {max_synonym_entities},跳过同义边构建以保护性能"
|
||||
)
|
||||
return 0
|
||||
|
||||
synonym_hash_set = set()
|
||||
synonym_result = {}
|
||||
@@ -328,6 +358,10 @@ class KGManager:
|
||||
paragraph_search_result: ParagraphEmbedding的搜索结果(paragraph_hash, similarity)
|
||||
embed_manager: EmbeddingManager对象
|
||||
"""
|
||||
# 性能保护:关闭时直接返回向量检索结果
|
||||
if not global_config.lpmm_knowledge.enable_ppr:
|
||||
logger.info("PPR 已禁用,使用纯向量检索结果")
|
||||
return paragraph_search_result, None
|
||||
# 图中存在的节点总集
|
||||
existed_nodes = self.graph.get_node_list()
|
||||
|
||||
@@ -357,7 +391,15 @@ class KGManager:
|
||||
ent_mean_scores = {} # 记录实体的平均相似度
|
||||
for ent_hash, scores in ent_sim_scores.items():
|
||||
# 先对相似度进行累加,然后与实体计数相除获取最终权重
|
||||
ent_weights[ent_hash] = float(np.sum(scores)) / self.ent_appear_cnt[ent_hash]
|
||||
# 保护:有些实体在当前图中可能只有实体-实体关系,不会出现在 ent_appear_cnt 中
|
||||
appear_cnt = self.ent_appear_cnt.get(ent_hash)
|
||||
if not appear_cnt or appear_cnt <= 0:
|
||||
logger.debug(
|
||||
f"实体 {ent_hash} 在 ent_appear_cnt 中不存在或计数为 0,"
|
||||
f"将使用 1.0 作为默认出现次数参与权重计算"
|
||||
)
|
||||
appear_cnt = 1.0
|
||||
ent_weights[ent_hash] = float(np.sum(scores)) / float(appear_cnt)
|
||||
# 记录实体的平均相似度,用于后续的top_k筛选
|
||||
ent_mean_scores[ent_hash] = float(np.mean(scores))
|
||||
del ent_sim_scores
|
||||
@@ -434,3 +476,115 @@ class KGManager:
|
||||
passage_node_res = sorted(passage_node_res, key=lambda item: item[1], reverse=True)
|
||||
|
||||
return passage_node_res, ppr_node_weights
|
||||
|
||||
def delete_paragraphs(
|
||||
self,
|
||||
pg_hashes: List[str],
|
||||
ent_hashes: List[str] | None = None,
|
||||
remove_orphan_entities: bool = False,
|
||||
) -> Dict[str, int]:
|
||||
"""删除段落/实体节点及相关边(基于 GraphML),可选清理孤立实体,并重建元数据"""
|
||||
# 要删除的节点 ID
|
||||
nodes_to_delete: Set[str] = {f"paragraph-{h}" for h in pg_hashes}
|
||||
if ent_hashes:
|
||||
nodes_to_delete.update({f"entity-{h}" for h in ent_hashes})
|
||||
|
||||
if not os.path.exists(self.graph_data_path):
|
||||
raise FileNotFoundError(f"KG图文件{self.graph_data_path}不存在")
|
||||
|
||||
tree = ET.parse(self.graph_data_path)
|
||||
root = tree.getroot()
|
||||
|
||||
# GraphML 可能带命名空间,用尾缀判断
|
||||
def is_node(elem: ET.Element) -> bool:
|
||||
return elem.tag.endswith("node")
|
||||
|
||||
def is_edge(elem: ET.Element) -> bool:
|
||||
return elem.tag.endswith("edge")
|
||||
|
||||
graph_elem = None
|
||||
for child in root:
|
||||
if child.tag.endswith("graph"):
|
||||
graph_elem = child
|
||||
break
|
||||
if graph_elem is None:
|
||||
raise RuntimeError("GraphML 中未找到 <graph> 节点")
|
||||
|
||||
# 统计现有节点
|
||||
existing_nodes: Set[str] = set()
|
||||
for elem in graph_elem:
|
||||
if is_node(elem):
|
||||
node_id = elem.get("id")
|
||||
if node_id:
|
||||
existing_nodes.add(node_id)
|
||||
|
||||
deleted_nodes = len(nodes_to_delete & existing_nodes)
|
||||
skipped_nodes = len(nodes_to_delete - existing_nodes)
|
||||
|
||||
# 先删除指定节点及相关边
|
||||
# 删除节点
|
||||
for elem in list(graph_elem):
|
||||
if is_node(elem):
|
||||
node_id = elem.get("id")
|
||||
if node_id and node_id in nodes_to_delete:
|
||||
graph_elem.remove(elem)
|
||||
|
||||
# 删除 incident edges
|
||||
for elem in list(graph_elem):
|
||||
if is_edge(elem):
|
||||
src = elem.get("source")
|
||||
tgt = elem.get("target")
|
||||
if src in nodes_to_delete or tgt in nodes_to_delete:
|
||||
graph_elem.remove(elem)
|
||||
|
||||
orphan_removed = 0
|
||||
if remove_orphan_entities:
|
||||
# 计算仍然参与边的节点
|
||||
used_nodes: Set[str] = set()
|
||||
for elem in graph_elem:
|
||||
if is_edge(elem):
|
||||
src = elem.get("source")
|
||||
tgt = elem.get("target")
|
||||
if src:
|
||||
used_nodes.add(src)
|
||||
if tgt:
|
||||
used_nodes.add(tgt)
|
||||
|
||||
# 找出没有任何边的实体节点
|
||||
orphan_entities: Set[str] = set()
|
||||
for elem in graph_elem:
|
||||
if is_node(elem):
|
||||
node_id = elem.get("id")
|
||||
if node_id and node_id.startswith("entity") and node_id not in used_nodes:
|
||||
orphan_entities.add(node_id)
|
||||
|
||||
orphan_removed = len(orphan_entities)
|
||||
|
||||
if orphan_entities:
|
||||
# 删除孤立实体节点
|
||||
for elem in list(graph_elem):
|
||||
if is_node(elem):
|
||||
node_id = elem.get("id")
|
||||
if node_id in orphan_entities:
|
||||
graph_elem.remove(elem)
|
||||
|
||||
# 删除与孤立实体相关的边(理论上已无,但做一次防御性清理)
|
||||
for elem in list(graph_elem):
|
||||
if is_edge(elem):
|
||||
src = elem.get("source")
|
||||
tgt = elem.get("target")
|
||||
if src in orphan_entities or tgt in orphan_entities:
|
||||
graph_elem.remove(elem)
|
||||
|
||||
# 写回 GraphML
|
||||
tree.write(self.graph_data_path, encoding="utf-8", xml_declaration=True)
|
||||
|
||||
# 重新加载图并重建元数据
|
||||
self.graph = di_graph.load_from_file(self.graph_data_path)
|
||||
self._rebuild_metadata_from_graph()
|
||||
|
||||
return {
|
||||
"deleted": deleted_nodes,
|
||||
"skipped": skipped_nodes,
|
||||
"orphan_removed": orphan_removed,
|
||||
}
|
||||
|
||||
@@ -7,7 +7,6 @@ from maim_message import UserInfo, Seg, GroupInfo
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.mood.mood_manager import mood_manager # 导入情绪管理器
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
@@ -73,7 +72,6 @@ class ChatBot:
|
||||
def __init__(self):
|
||||
self.bot = None # bot 实例引用
|
||||
self._started = False
|
||||
self.mood_manager = mood_manager # 获取情绪管理器单例
|
||||
self.heartflow_message_receiver = HeartFCMessageReceiver() # 新增
|
||||
|
||||
async def _ensure_started(self):
|
||||
@@ -83,7 +81,7 @@ class ChatBot:
|
||||
|
||||
self._started = True
|
||||
|
||||
async def _process_commands_with_new_system(self, message: MessageRecv):
|
||||
async def _process_commands(self, message: MessageRecv):
|
||||
# sourcery skip: use-named-expression
|
||||
"""使用新插件系统处理命令"""
|
||||
try:
|
||||
@@ -115,17 +113,21 @@ class ChatBot:
|
||||
|
||||
try:
|
||||
# 执行命令
|
||||
success, response, intercept_message = await command_instance.execute()
|
||||
message.is_no_read_command = bool(intercept_message)
|
||||
success, response, intercept_message_level = await command_instance.execute()
|
||||
message.intercept_message_level = intercept_message_level
|
||||
|
||||
# 记录命令执行结果
|
||||
if success:
|
||||
logger.info(f"命令执行成功: {command_class.__name__} (拦截: {intercept_message})")
|
||||
logger.info(f"命令执行成功: {command_class.__name__} (拦截等级: {intercept_message_level})")
|
||||
else:
|
||||
logger.warning(f"命令执行失败: {command_class.__name__} - {response}")
|
||||
|
||||
# 根据命令的拦截设置决定是否继续处理消息
|
||||
return True, response, not intercept_message # 找到命令,根据intercept_message决定是否继续
|
||||
return (
|
||||
True,
|
||||
response,
|
||||
not bool(intercept_message_level),
|
||||
) # 找到命令,根据intercept_message决定是否继续
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"执行命令时出错: {command_class.__name__} - {e}")
|
||||
@@ -295,7 +297,7 @@ class ChatBot:
|
||||
# return
|
||||
|
||||
# 命令处理 - 使用新插件系统检查并处理命令
|
||||
is_command, cmd_result, continue_process = await self._process_commands_with_new_system(message)
|
||||
is_command, cmd_result, continue_process = await self._process_commands(message)
|
||||
|
||||
# 如果是命令且不需要继续处理,则直接返回
|
||||
if is_command and not continue_process:
|
||||
|
||||
@@ -122,7 +122,7 @@ class MessageRecv(Message):
|
||||
self.is_notify = False
|
||||
|
||||
self.is_command = False
|
||||
self.is_no_read_command = False
|
||||
self.intercept_message_level = 0
|
||||
|
||||
self.priority_mode = "interest"
|
||||
self.priority_info = None
|
||||
@@ -213,6 +213,68 @@ class MessageRecv(Message):
|
||||
}
|
||||
"""
|
||||
return ""
|
||||
elif segment.type == "video_card":
|
||||
# 处理视频卡片消息
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
self.is_voice = False
|
||||
if isinstance(segment.data, dict):
|
||||
file_name = segment.data.get("file", "未知视频")
|
||||
file_size = segment.data.get("file_size", "")
|
||||
url = segment.data.get("url", "")
|
||||
text = f"[视频: {file_name}"
|
||||
if file_size:
|
||||
text += f", 大小: {file_size}字节"
|
||||
text += "]"
|
||||
if url:
|
||||
text += f" 链接: {url}"
|
||||
return text
|
||||
return "[视频]"
|
||||
elif segment.type == "music_card":
|
||||
# 处理音乐卡片消息
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
self.is_voice = False
|
||||
if isinstance(segment.data, dict):
|
||||
title = segment.data.get("title", "未知歌曲")
|
||||
singer = segment.data.get("singer", "")
|
||||
tag = segment.data.get("tag", "") # 音乐来源,如"网易云音乐"
|
||||
jump_url = segment.data.get("jump_url", "")
|
||||
music_url = segment.data.get("music_url", "")
|
||||
text = f"[音乐: {title}"
|
||||
if singer:
|
||||
text += f" - {singer}"
|
||||
if tag:
|
||||
text += f" ({tag})"
|
||||
text += "]"
|
||||
if jump_url:
|
||||
text += f" 跳转链接: {jump_url}"
|
||||
if music_url:
|
||||
text += f" 音乐链接: {music_url}"
|
||||
return text
|
||||
return "[音乐]"
|
||||
elif segment.type == "miniapp_card":
|
||||
# 处理小程序分享卡片(如B站视频分享)
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
self.is_voice = False
|
||||
if isinstance(segment.data, dict):
|
||||
title = segment.data.get("title", "") # 小程序名称
|
||||
desc = segment.data.get("desc", "") # 内容描述
|
||||
source_url = segment.data.get("source_url", "") # 原始链接
|
||||
url = segment.data.get("url", "") # 小程序链接
|
||||
text = "[小程序分享"
|
||||
if title:
|
||||
text += f" - {title}"
|
||||
text += "]"
|
||||
if desc:
|
||||
text += f" {desc}"
|
||||
if source_url:
|
||||
text += f" 链接: {source_url}"
|
||||
elif url:
|
||||
text += f" 链接: {url}"
|
||||
return text
|
||||
return "[小程序分享]"
|
||||
else:
|
||||
return ""
|
||||
except Exception as e:
|
||||
|
||||
@@ -72,7 +72,7 @@ class MessageStorage:
|
||||
key_words = ""
|
||||
key_words_lite = ""
|
||||
selected_expressions = message.selected_expressions
|
||||
is_no_read_command = False
|
||||
intercept_message_level = 0
|
||||
else:
|
||||
filtered_display_message = ""
|
||||
interest_value = message.interest_value
|
||||
@@ -86,7 +86,7 @@ class MessageStorage:
|
||||
is_picid = message.is_picid
|
||||
is_notify = message.is_notify
|
||||
is_command = message.is_command
|
||||
is_no_read_command = getattr(message, "is_no_read_command", False)
|
||||
intercept_message_level = getattr(message, "intercept_message_level", 0)
|
||||
# 序列化关键词列表为JSON字符串
|
||||
key_words = MessageStorage._serialize_keywords(message.key_words)
|
||||
key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite)
|
||||
@@ -138,7 +138,7 @@ class MessageStorage:
|
||||
is_picid=is_picid,
|
||||
is_notify=is_notify,
|
||||
is_command=is_command,
|
||||
is_no_read_command=is_no_read_command,
|
||||
intercept_message_level=intercept_message_level,
|
||||
key_words=key_words,
|
||||
key_words_lite=key_words_lite,
|
||||
selected_expressions=selected_expressions,
|
||||
|
||||
@@ -40,6 +40,93 @@ def is_webui_virtual_group(group_id: str) -> bool:
|
||||
return group_id and group_id.startswith(VIRTUAL_GROUP_ID_PREFIX)
|
||||
|
||||
|
||||
def parse_message_segments(segment) -> list:
|
||||
"""解析消息段,转换为 WebUI 可用的格式
|
||||
|
||||
参考 NapCat 适配器的消息解析逻辑
|
||||
|
||||
Args:
|
||||
segment: Seg 消息段对象
|
||||
|
||||
Returns:
|
||||
list: 消息段列表,每个元素为 {"type": "...", "data": ...}
|
||||
"""
|
||||
|
||||
result = []
|
||||
|
||||
if segment is None:
|
||||
return result
|
||||
|
||||
if segment.type == "seglist":
|
||||
# 处理消息段列表
|
||||
if segment.data:
|
||||
for seg in segment.data:
|
||||
result.extend(parse_message_segments(seg))
|
||||
elif segment.type == "text":
|
||||
# 文本消息
|
||||
if segment.data:
|
||||
result.append({"type": "text", "data": segment.data})
|
||||
elif segment.type == "image":
|
||||
# 图片消息(base64)
|
||||
if segment.data:
|
||||
result.append({"type": "image", "data": f"data:image/png;base64,{segment.data}"})
|
||||
elif segment.type == "emoji":
|
||||
# 表情包消息(base64)
|
||||
if segment.data:
|
||||
result.append({"type": "emoji", "data": f"data:image/gif;base64,{segment.data}"})
|
||||
elif segment.type == "imageurl":
|
||||
# 图片链接消息
|
||||
if segment.data:
|
||||
result.append({"type": "image", "data": segment.data})
|
||||
elif segment.type == "face":
|
||||
# 原生表情
|
||||
result.append({"type": "face", "data": segment.data})
|
||||
elif segment.type == "voice":
|
||||
# 语音消息(base64)
|
||||
if segment.data:
|
||||
result.append({"type": "voice", "data": f"data:audio/wav;base64,{segment.data}"})
|
||||
elif segment.type == "voiceurl":
|
||||
# 语音链接
|
||||
if segment.data:
|
||||
result.append({"type": "voice", "data": segment.data})
|
||||
elif segment.type == "video":
|
||||
# 视频消息(base64)
|
||||
if segment.data:
|
||||
result.append({"type": "video", "data": f"data:video/mp4;base64,{segment.data}"})
|
||||
elif segment.type == "videourl":
|
||||
# 视频链接
|
||||
if segment.data:
|
||||
result.append({"type": "video", "data": segment.data})
|
||||
elif segment.type == "music":
|
||||
# 音乐消息
|
||||
result.append({"type": "music", "data": segment.data})
|
||||
elif segment.type == "file":
|
||||
# 文件消息
|
||||
result.append({"type": "file", "data": segment.data})
|
||||
elif segment.type == "reply":
|
||||
# 回复消息
|
||||
result.append({"type": "reply", "data": segment.data})
|
||||
elif segment.type == "forward":
|
||||
# 转发消息
|
||||
forward_items = []
|
||||
if segment.data:
|
||||
for item in segment.data:
|
||||
forward_items.append(
|
||||
{
|
||||
"content": parse_message_segments(item.get("message_segment", {}))
|
||||
if isinstance(item, dict)
|
||||
else []
|
||||
}
|
||||
)
|
||||
result.append({"type": "forward", "data": forward_items})
|
||||
else:
|
||||
# 未知类型,尝试作为文本处理
|
||||
if segment.data:
|
||||
result.append({"type": "unknown", "original_type": segment.type, "data": str(segment.data)})
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def _send_message(message: MessageSending, show_log=True) -> bool:
|
||||
"""合并后的消息发送函数,包含WS发送和日志记录"""
|
||||
message_preview = truncate_message(message.processed_plain_text, max_length=200)
|
||||
@@ -50,17 +137,31 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
|
||||
# 检查是否是 WebUI 平台的消息,或者是 WebUI 虚拟群的消息
|
||||
chat_manager, webui_platform = get_webui_chat_broadcaster()
|
||||
is_webui_message = (platform == webui_platform) or is_webui_virtual_group(group_id)
|
||||
|
||||
|
||||
if is_webui_message and chat_manager is not None:
|
||||
# WebUI 聊天室消息(包括虚拟身份模式),通过 WebSocket 广播
|
||||
import time
|
||||
from src.config.config import global_config
|
||||
|
||||
# 解析消息段,获取富文本内容
|
||||
message_segments = parse_message_segments(message.message_segment)
|
||||
|
||||
# 判断消息类型
|
||||
# 如果只有一个文本段,使用简单的 text 类型
|
||||
# 否则使用 rich 类型,包含完整的消息段
|
||||
if len(message_segments) == 1 and message_segments[0].get("type") == "text":
|
||||
message_type = "text"
|
||||
segments = None
|
||||
else:
|
||||
message_type = "rich"
|
||||
segments = message_segments
|
||||
|
||||
await chat_manager.broadcast(
|
||||
{
|
||||
"type": "bot_message",
|
||||
"content": message.processed_plain_text,
|
||||
"message_type": "text",
|
||||
"message_type": message_type,
|
||||
"segments": segments, # 富文本消息段
|
||||
"timestamp": time.time(),
|
||||
"group_id": group_id, # 包含群 ID 以便前端区分不同的聊天标签
|
||||
"sender": {
|
||||
@@ -81,11 +182,70 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
|
||||
logger.info(f"已将消息 '{message_preview}' 发往 WebUI 聊天室")
|
||||
return True
|
||||
|
||||
# 直接调用API发送消息
|
||||
await get_global_api().send_message(message)
|
||||
if show_log:
|
||||
logger.info(f"已将消息 '{message_preview}' 发往平台'{message.message_info.platform}'")
|
||||
return True
|
||||
# Fallback 逻辑: 尝试通过 API Server 发送
|
||||
async def send_with_new_api(legacy_exception=None):
|
||||
try:
|
||||
from src.config.config import global_config
|
||||
|
||||
# 如果未开启 API Server,直接跳过 Fallback
|
||||
if not global_config.maim_message.enable_api_server:
|
||||
if legacy_exception:
|
||||
raise legacy_exception
|
||||
return False
|
||||
|
||||
global_api = get_global_api()
|
||||
extra_server = getattr(global_api, "extra_server", None)
|
||||
|
||||
if extra_server and extra_server.is_running():
|
||||
# Fallback: 使用极其简单的 Platform -> API Key 映射
|
||||
# 只有收到过该平台的消息,我们才知道该平台的 API Key,才能回传消息
|
||||
platform_map = getattr(global_api, "platform_map", {})
|
||||
target_api_key = platform_map.get(platform)
|
||||
|
||||
if target_api_key:
|
||||
# 构造 APIMessageBase
|
||||
from maim_message.message import APIMessageBase, MessageDim
|
||||
|
||||
msg_dim = MessageDim(api_key=target_api_key, platform=platform)
|
||||
|
||||
api_message = APIMessageBase(
|
||||
message_info=message.message_info,
|
||||
message_segment=message.message_segment,
|
||||
message_dim=msg_dim,
|
||||
)
|
||||
|
||||
# 直接调用 Server 的 send_message 接口,它会自动处理路由
|
||||
results = await extra_server.send_message(api_message)
|
||||
|
||||
# 检查是否有任何连接发送成功
|
||||
if any(results.values()):
|
||||
if show_log:
|
||||
logger.info(
|
||||
f"已通过API Server Fallback将消息 '{message_preview}' 发往平台'{platform}' (key: {target_api_key})"
|
||||
)
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 如果 Fallback 失败,且存在 legacy 异常,则抛出 legacy 异常
|
||||
if legacy_exception:
|
||||
raise legacy_exception
|
||||
return False
|
||||
|
||||
try:
|
||||
send_result = await get_global_api().send_message(message)
|
||||
# if send_result:
|
||||
if show_log:
|
||||
logger.info(f"已将消息 '{message_preview}' 发往平台'{message.message_info.platform}'")
|
||||
return True
|
||||
|
||||
# Legacy API 返回 False (发送失败但未报错),尝试 Fallback
|
||||
# return await send_with_new_api()
|
||||
|
||||
except Exception as legacy_e:
|
||||
# Legacy API 抛出异常,尝试 Fallback
|
||||
# 如果 Fallback 也失败,将重新抛出 legacy_e
|
||||
return await send_with_new_api(legacy_exception=legacy_e)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息 '{message_preview}' 发往平台'{message.message_info.platform}' 失败: {str(e)}")
|
||||
|
||||
@@ -69,7 +69,7 @@ class ActionModifier:
|
||||
chat_id=self.chat_stream.stream_id,
|
||||
timestamp=time.time(),
|
||||
limit=min(int(global_config.chat.max_context_size * 0.33), 10),
|
||||
filter_no_read_command=True,
|
||||
filter_intercept_message_level=1,
|
||||
)
|
||||
|
||||
chat_content = build_readable_messages(
|
||||
|
||||
@@ -15,12 +15,15 @@ from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
build_readable_messages_with_id,
|
||||
get_raw_msg_before_timestamp_with_chat,
|
||||
replace_user_references,
|
||||
)
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info
|
||||
from src.chat.planner_actions.action_manager import ActionManager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.plugin_system.base.component_types import ActionInfo, ComponentType, ActionActivationType
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
from src.plugin_system.apis.message_api import translate_pid_to_description
|
||||
from src.person_info.person_info import Person
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.common.data_models.info_data_model import TargetPersonInfo
|
||||
@@ -36,7 +39,6 @@ def init_prompt():
|
||||
"""
|
||||
{time_block}
|
||||
{name_block}
|
||||
你的兴趣是:{interest}
|
||||
{chat_context_description},以下是具体的聊天内容
|
||||
**聊天内容**
|
||||
{chat_content_block}
|
||||
@@ -46,9 +48,12 @@ reply
|
||||
动作描述:
|
||||
1.你可以选择呼叫了你的名字,但是你没有做出回应的消息进行回复
|
||||
2.你可以自然的顺着正在进行的聊天内容进行回复或自然的提出一个问题
|
||||
3.不要回复你自己发送的消息
|
||||
4.不要单独对表情包进行回复
|
||||
{{"action":"reply", "target_message_id":"消息id(m+数字)", "reason":"原因"}}
|
||||
3.最好一次对一个话题进行回复,免得啰嗦或者回复内容太乱。
|
||||
4.不要选择回复你自己发送的消息
|
||||
5.不要单独对表情包进行回复
|
||||
6.将上下文中所有含义不明的,疑似黑话的,缩写词均写入unknown_words中
|
||||
7.用一句简单的话来描述当前回复场景,不超过10个字
|
||||
{reply_action_example}
|
||||
|
||||
no_reply
|
||||
动作描述:
|
||||
@@ -56,75 +61,37 @@ no_reply
|
||||
控制聊天频率,不要太过频繁的发言
|
||||
{{"action":"no_reply"}}
|
||||
|
||||
{no_reply_until_call_block}
|
||||
|
||||
{action_options_text}
|
||||
|
||||
|
||||
**你之前的action执行和思考记录**
|
||||
{actions_before_now_block}
|
||||
|
||||
请选择**可选的**且符合使用条件的action,并说明触发action的消息id(消息id格式:m+数字)
|
||||
不要回复你自己发送的消息
|
||||
先输出你的简短的选择思考理由,再输出你选择的action,理由不要分点,精简。
|
||||
**动作选择要求**
|
||||
请你根据聊天内容,用户的最新消息和以下标准选择合适的动作:
|
||||
{plan_style}
|
||||
{moderation_prompt}
|
||||
|
||||
请选择所有符合使用要求的action,动作用json格式输出,用```json包裹,如果输出多个json,每个json都要单独一行放在同一个```json代码块内,你可以重复使用同一个动作或不同动作:
|
||||
target_message_id为必填,表示触发消息的id
|
||||
请选择所有符合使用要求的action,每个动作最多选择一次,但是可以选择多个动作;
|
||||
动作用json格式输出,用```json包裹,如果输出多个json,每个json都要单独一行放在同一个```json代码块内:
|
||||
**示例**
|
||||
// 理由文本(简短)
|
||||
```json
|
||||
{{"action":"动作名", "target_message_id":"m123", "reason":"原因"}}
|
||||
{{"action":"动作名", "target_message_id":"m456", "reason":"原因"}}
|
||||
{{"action":"动作名", "target_message_id":"m123", .....}}
|
||||
{{"action":"动作名", "target_message_id":"m456", .....}}
|
||||
```""",
|
||||
"planner_prompt",
|
||||
)
|
||||
|
||||
Prompt(
|
||||
"""{time_block}
|
||||
{name_block}
|
||||
{chat_context_description},以下是具体的聊天内容
|
||||
**聊天内容**
|
||||
{chat_content_block}
|
||||
|
||||
**可选的action**
|
||||
no_reply
|
||||
动作描述:
|
||||
没有合适的可以使用的动作,不使用action
|
||||
{{"action":"no_reply"}}
|
||||
|
||||
{action_options_text}
|
||||
|
||||
**你之前的action执行和思考记录**
|
||||
{actions_before_now_block}
|
||||
|
||||
请选择**可选的**且符合使用条件的action,并说明触发action的消息id(消息id格式:m+数字)
|
||||
先输出你的简短的选择思考理由,再输出你选择的action,理由不要分点,精简。
|
||||
**动作选择要求**
|
||||
请你根据聊天内容,用户的最新消息和以下标准选择合适的动作:
|
||||
1.思考**所有**的可用的action中的**每个动作**是否符合当下条件,如果动作使用条件符合聊天内容就使用
|
||||
2.如果相同的内容已经被执行,请不要重复执行
|
||||
{moderation_prompt}
|
||||
|
||||
请选择所有符合使用要求的action,动作用json格式输出,用```json包裹,如果输出多个json,每个json都要单独一行放在同一个```json代码块内,你可以重复使用同一个动作或不同动作:
|
||||
**示例**
|
||||
// 理由文本(简短)
|
||||
```json
|
||||
{{"action":"动作名", "target_message_id":"m123", "reason":"原因"}}
|
||||
{{"action":"动作名", "target_message_id":"m456", "reason":"原因"}}
|
||||
```""",
|
||||
"planner_prompt_mentioned",
|
||||
)
|
||||
|
||||
Prompt(
|
||||
"""
|
||||
{action_name}
|
||||
动作描述:{action_description}
|
||||
使用条件{parallel_text}:
|
||||
{action_require}
|
||||
{{"action":"{action_name}",{action_parameters}, "target_message_id":"消息id(m+数字)", "reason":"原因"}}
|
||||
{{"action":"{action_name}",{action_parameters}, "target_message_id":"消息id(m+数字)"}}
|
||||
""",
|
||||
"action_prompt",
|
||||
)
|
||||
@@ -195,11 +162,41 @@ class ActionPlanner:
|
||||
logger.warning(f"{self.log_prefix}planner理由引用 {msg_id} 未找到对应消息,保持原样")
|
||||
return msg_id
|
||||
|
||||
msg_text = (message.processed_plain_text or message.display_message or "").strip()
|
||||
msg_text = (message.processed_plain_text or "").strip()
|
||||
if not msg_text:
|
||||
logger.warning(f"{self.log_prefix}planner理由引用 {msg_id} 的消息内容为空,保持原样")
|
||||
return msg_id
|
||||
|
||||
# 替换 [picid:xxx] 为 [图片:描述]
|
||||
pic_pattern = r"\[picid:([^\]]+)\]"
|
||||
def replace_pic_id(pic_match: re.Match) -> str:
|
||||
pic_id = pic_match.group(1)
|
||||
description = translate_pid_to_description(pic_id)
|
||||
return f"[图片:{description}]"
|
||||
msg_text = re.sub(pic_pattern, replace_pic_id, msg_text)
|
||||
|
||||
# 替换用户引用格式:回复<aaa:bbb> 和 @<aaa:bbb>
|
||||
platform = getattr(message, "user_info", None) and message.user_info.platform or getattr(message, "chat_info", None) and message.chat_info.platform or "qq"
|
||||
msg_text = replace_user_references(msg_text, platform, replace_bot_name=True)
|
||||
|
||||
# 替换单独的 <用户名:用户ID> 格式(replace_user_references 已处理回复<和@<格式)
|
||||
# 匹配所有 <aaa:bbb> 格式,由于 replace_user_references 已经替换了回复<和@<格式,
|
||||
# 这里匹配到的应该都是单独的格式
|
||||
user_ref_pattern = r"<([^:<>]+):([^:<>]+)>"
|
||||
def replace_user_ref(user_match: re.Match) -> str:
|
||||
user_name = user_match.group(1)
|
||||
user_id = user_match.group(2)
|
||||
try:
|
||||
# 检查是否是机器人自己
|
||||
if user_id == global_config.bot.qq_account:
|
||||
return f"{global_config.bot.nickname}(你)"
|
||||
person = Person(platform=platform, user_id=user_id)
|
||||
return person.person_name or user_name
|
||||
except Exception:
|
||||
# 如果解析失败,使用原始昵称
|
||||
return user_name
|
||||
msg_text = re.sub(user_ref_pattern, replace_user_ref, msg_text)
|
||||
|
||||
preview = msg_text if len(msg_text) <= 100 else f"{msg_text[:97]}..."
|
||||
logger.info(f"{self.log_prefix}planner理由引用 {msg_id} -> 消息({preview})")
|
||||
return f"消息({msg_text})"
|
||||
@@ -218,11 +215,14 @@ class ActionPlanner:
|
||||
|
||||
try:
|
||||
action = action_json.get("action", "no_reply")
|
||||
original_reasoning = action_json.get("reason", "未提供原因")
|
||||
reasoning = self._replace_message_ids_with_text(original_reasoning, message_id_list)
|
||||
if reasoning is None:
|
||||
reasoning = original_reasoning
|
||||
action_data = {key: value for key, value in action_json.items() if key not in ["action", "reason"]}
|
||||
# 使用 extracted_reasoning(整体推理文本)作为 reasoning
|
||||
if extracted_reasoning:
|
||||
reasoning = self._replace_message_ids_with_text(extracted_reasoning, message_id_list)
|
||||
if reasoning is None:
|
||||
reasoning = extracted_reasoning
|
||||
else:
|
||||
reasoning = "未提供原因"
|
||||
action_data = {key: value for key, value in action_json.items() if key not in ["action"]}
|
||||
# 非no_reply动作需要target_message_id
|
||||
target_message = None
|
||||
|
||||
@@ -248,7 +248,7 @@ class ActionPlanner:
|
||||
|
||||
# 验证action是否可用
|
||||
available_action_names = [action_name for action_name, _ in current_available_actions]
|
||||
internal_action_names = ["no_reply", "reply", "wait_time", "no_reply_until_call"]
|
||||
internal_action_names = ["no_reply", "reply", "wait_time"]
|
||||
|
||||
if action not in internal_action_names and action not in available_action_names:
|
||||
logger.warning(
|
||||
@@ -304,7 +304,7 @@ class ActionPlanner:
|
||||
self,
|
||||
available_actions: Dict[str, ActionInfo],
|
||||
loop_start_time: float = 0.0,
|
||||
is_mentioned: bool = False,
|
||||
force_reply_message: Optional["DatabaseMessages"] = None,
|
||||
) -> List[ActionPlannerInfo]:
|
||||
# sourcery skip: use-named-expression
|
||||
"""
|
||||
@@ -316,7 +316,7 @@ class ActionPlanner:
|
||||
chat_id=self.chat_id,
|
||||
timestamp=time.time(),
|
||||
limit=int(global_config.chat.max_context_size * 0.6),
|
||||
filter_no_read_command=True,
|
||||
filter_intercept_message_level=1,
|
||||
)
|
||||
message_id_list: list[Tuple[str, "DatabaseMessages"]] = []
|
||||
chat_content_block, message_id_list = build_readable_messages_with_id(
|
||||
@@ -345,11 +345,6 @@ class ActionPlanner:
|
||||
|
||||
logger.debug(f"{self.log_prefix}过滤后有{len(filtered_actions)}个可用动作")
|
||||
|
||||
# 如果是提及时且没有可用动作,直接返回空列表,不调用LLM以节省token
|
||||
if is_mentioned and not filtered_actions:
|
||||
logger.info(f"{self.log_prefix}提及时没有可用动作,跳过plan调用")
|
||||
return []
|
||||
|
||||
# 构建包含所有动作的提示词
|
||||
prompt, message_id_list = await self.build_planner_prompt(
|
||||
is_group_chat=is_group_chat,
|
||||
@@ -357,8 +352,6 @@ class ActionPlanner:
|
||||
current_available_actions=filtered_actions,
|
||||
chat_content_block=chat_content_block,
|
||||
message_id_list=message_id_list,
|
||||
interest=global_config.personality.interest,
|
||||
is_mentioned=is_mentioned,
|
||||
)
|
||||
|
||||
# 调用LLM获取决策
|
||||
@@ -370,6 +363,34 @@ class ActionPlanner:
|
||||
loop_start_time=loop_start_time,
|
||||
)
|
||||
|
||||
# 如果有强制回复消息,确保回复该消息
|
||||
if force_reply_message:
|
||||
# 检查是否已经有回复该消息的 action
|
||||
has_reply_to_force_message = False
|
||||
for action in actions:
|
||||
if action.action_type == "reply" and action.action_message and action.action_message.message_id == force_reply_message.message_id:
|
||||
has_reply_to_force_message = True
|
||||
break
|
||||
|
||||
# 如果没有回复该消息,强制添加回复 action
|
||||
if not has_reply_to_force_message:
|
||||
# 移除所有 no_reply action(如果有)
|
||||
actions = [a for a in actions if a.action_type != "no_reply"]
|
||||
|
||||
# 创建强制回复 action
|
||||
available_actions_dict = dict(current_available_actions)
|
||||
force_reply_action = ActionPlannerInfo(
|
||||
action_type="reply",
|
||||
reasoning="用户提及了我,必须回复该消息",
|
||||
action_data={"loop_start_time": loop_start_time},
|
||||
action_message=force_reply_message,
|
||||
available_actions=available_actions_dict,
|
||||
action_reasoning=None,
|
||||
)
|
||||
# 将强制回复 action 放在最前面
|
||||
actions.insert(0, force_reply_action)
|
||||
logger.info(f"{self.log_prefix} 检测到强制回复消息,已添加回复动作")
|
||||
|
||||
logger.info(
|
||||
f"{self.log_prefix}Planner:{reasoning}。选择了{len(actions)}个动作: {' '.join([a.action_type for a in actions])}"
|
||||
)
|
||||
@@ -430,32 +451,6 @@ class ActionPlanner:
|
||||
|
||||
return plan_log_str
|
||||
|
||||
def _has_consecutive_no_reply(self, min_count: int = 3) -> bool:
|
||||
"""
|
||||
检查是否有连续min_count次以上的no_reply
|
||||
|
||||
Args:
|
||||
min_count: 需要连续的最少次数,默认3
|
||||
|
||||
Returns:
|
||||
如果有连续min_count次以上no_reply返回True,否则返回False
|
||||
"""
|
||||
consecutive_count = 0
|
||||
|
||||
# 从后往前遍历plan_log,检查最新的连续记录
|
||||
for _reasoning, _timestamp, content in reversed(self.plan_log):
|
||||
if isinstance(content, list) and all(isinstance(action, ActionPlannerInfo) for action in content):
|
||||
# 检查所有action是否都是no_reply
|
||||
if all(action.action_type == "no_reply" for action in content):
|
||||
consecutive_count += 1
|
||||
if consecutive_count >= min_count:
|
||||
return True
|
||||
else:
|
||||
# 如果遇到非no_reply的action,重置计数
|
||||
break
|
||||
|
||||
return False
|
||||
|
||||
async def build_planner_prompt(
|
||||
self,
|
||||
is_group_chat: bool,
|
||||
@@ -464,7 +459,6 @@ class ActionPlanner:
|
||||
message_id_list: List[Tuple[str, "DatabaseMessages"]],
|
||||
chat_content_block: str = "",
|
||||
interest: str = "",
|
||||
is_mentioned: bool = False,
|
||||
) -> tuple[str, List[Tuple[str, "DatabaseMessages"]]]:
|
||||
"""构建 Planner LLM 的提示词 (获取模板并填充数据)"""
|
||||
try:
|
||||
@@ -485,48 +479,35 @@ class ActionPlanner:
|
||||
)
|
||||
name_block = f"你的名字是{bot_name}{bot_nickname},请注意哪些是你自己的发言。"
|
||||
|
||||
# 根据是否是提及时选择不同的模板
|
||||
if is_mentioned:
|
||||
# 提及时使用简化版提示词,不需要reply、no_reply、no_reply_until_call
|
||||
planner_prompt_template = await global_prompt_manager.get_prompt_async("planner_prompt_mentioned")
|
||||
prompt = planner_prompt_template.format(
|
||||
time_block=time_block,
|
||||
chat_context_description=chat_context_description,
|
||||
chat_content_block=chat_content_block,
|
||||
actions_before_now_block=actions_before_now_block,
|
||||
action_options_text=action_options_block,
|
||||
moderation_prompt=moderation_prompt_block,
|
||||
name_block=name_block,
|
||||
interest=interest,
|
||||
plan_style=global_config.personality.plan_style,
|
||||
# 根据 think_mode 配置决定 reply action 的示例 JSON
|
||||
# 在 JSON 中直接作为 action 参数携带 unknown_words
|
||||
if global_config.chat.think_mode == "classic":
|
||||
reply_action_example = (
|
||||
'{{"action":"reply", "target_message_id":"消息id(m+数字)", '
|
||||
'"unknown_words":["词语1","词语2"]}}'
|
||||
)
|
||||
else:
|
||||
# 正常流程使用完整版提示词
|
||||
# 检查是否有连续3次以上no_reply,如果有则添加no_reply_until_call选项
|
||||
no_reply_until_call_block = ""
|
||||
if self._has_consecutive_no_reply(min_count=3):
|
||||
no_reply_until_call_block = """no_reply_until_call
|
||||
动作描述:
|
||||
保持沉默,直到有人直接叫你的名字
|
||||
当前话题不感兴趣时使用,或有人不喜欢你的发言时使用
|
||||
当你频繁选择no_reply时使用,表示话题暂时与你无关
|
||||
{{"action":"no_reply_until_call"}}
|
||||
"""
|
||||
|
||||
planner_prompt_template = await global_prompt_manager.get_prompt_async("planner_prompt")
|
||||
prompt = planner_prompt_template.format(
|
||||
time_block=time_block,
|
||||
chat_context_description=chat_context_description,
|
||||
chat_content_block=chat_content_block,
|
||||
actions_before_now_block=actions_before_now_block,
|
||||
action_options_text=action_options_block,
|
||||
no_reply_until_call_block=no_reply_until_call_block,
|
||||
moderation_prompt=moderation_prompt_block,
|
||||
name_block=name_block,
|
||||
interest=interest,
|
||||
plan_style=global_config.personality.plan_style,
|
||||
reply_action_example = (
|
||||
"5.think_level表示思考深度,0表示该回复不需要思考和回忆,1表示该回复需要进行回忆和思考\n"
|
||||
+ '{{"action":"reply", "think_level":数值等级(0或1), '
|
||||
'"target_message_id":"消息id(m+数字)", '
|
||||
'"unknown_words":["词语1","词语2"]}}'
|
||||
)
|
||||
|
||||
planner_prompt_template = await global_prompt_manager.get_prompt_async("planner_prompt")
|
||||
prompt = planner_prompt_template.format(
|
||||
time_block=time_block,
|
||||
chat_context_description=chat_context_description,
|
||||
chat_content_block=chat_content_block,
|
||||
actions_before_now_block=actions_before_now_block,
|
||||
action_options_text=action_options_block,
|
||||
moderation_prompt=moderation_prompt_block,
|
||||
name_block=name_block,
|
||||
interest=interest,
|
||||
plan_style=global_config.personality.plan_style,
|
||||
reply_action_example=reply_action_example,
|
||||
)
|
||||
|
||||
return prompt, message_id_list
|
||||
except Exception as e:
|
||||
logger.error(f"构建 Planner 提示词时出错: {e}")
|
||||
@@ -696,6 +677,12 @@ class ActionPlanner:
|
||||
action.action_data = action.action_data or {}
|
||||
action.action_data["loop_start_time"] = loop_start_time
|
||||
|
||||
# 去重:如果同一个动作被选择了多次,随机选择其中一个
|
||||
if actions:
|
||||
shuffled = actions.copy()
|
||||
random.shuffle(shuffled)
|
||||
actions = list({a.action_type: a for a in shuffled}.values())
|
||||
|
||||
logger.debug(f"{self.log_prefix}规划器选择了{len(actions)}个动作: {' '.join([a.action_type for a in actions])}")
|
||||
|
||||
return extracted_reasoning, actions
|
||||
@@ -747,23 +734,27 @@ class ActionPlanner:
|
||||
# 尝试解析每一行作为独立的JSON对象
|
||||
json_obj = json.loads(repair_json(line))
|
||||
if isinstance(json_obj, dict):
|
||||
json_objects.append(json_obj)
|
||||
# 过滤掉空字典,避免单个 { 字符被错误修复为 {} 的情况
|
||||
if json_obj:
|
||||
json_objects.append(json_obj)
|
||||
elif isinstance(json_obj, list):
|
||||
for item in json_obj:
|
||||
if isinstance(item, dict):
|
||||
if isinstance(item, dict) and item:
|
||||
json_objects.append(item)
|
||||
except json.JSONDecodeError:
|
||||
# 如果单行解析失败,尝试将整个块作为一个JSON对象或数组
|
||||
pass
|
||||
|
||||
# 如果按行解析没有成功,尝试将整个块作为一个JSON对象或数组
|
||||
# 如果按行解析没有成功(或只得到空字典),尝试将整个块作为一个JSON对象或数组
|
||||
if not json_objects:
|
||||
json_obj = json.loads(repair_json(json_str))
|
||||
if isinstance(json_obj, dict):
|
||||
json_objects.append(json_obj)
|
||||
# 过滤掉空字典
|
||||
if json_obj:
|
||||
json_objects.append(json_obj)
|
||||
elif isinstance(json_obj, list):
|
||||
for item in json_obj:
|
||||
if isinstance(item, dict):
|
||||
if isinstance(item, dict) and item:
|
||||
json_objects.append(item)
|
||||
except Exception as e:
|
||||
logger.warning(f"解析JSON块失败: {e}, 块内容: {match[:100]}...")
|
||||
@@ -798,23 +789,27 @@ class ActionPlanner:
|
||||
try:
|
||||
json_obj = json.loads(repair_json(line))
|
||||
if isinstance(json_obj, dict):
|
||||
json_objects.append(json_obj)
|
||||
# 过滤掉空字典,避免单个 { 字符被错误修复为 {} 的情况
|
||||
if json_obj:
|
||||
json_objects.append(json_obj)
|
||||
elif isinstance(json_obj, list):
|
||||
for item in json_obj:
|
||||
if isinstance(item, dict):
|
||||
if isinstance(item, dict) and item:
|
||||
json_objects.append(item)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# 如果按行解析没有成功,尝试将整个块作为一个JSON对象或数组
|
||||
# 如果按行解析没有成功(或只得到空字典),尝试将整个块作为一个JSON对象或数组
|
||||
if not json_objects:
|
||||
try:
|
||||
json_obj = json.loads(repair_json(json_str))
|
||||
if isinstance(json_obj, dict):
|
||||
json_objects.append(json_obj)
|
||||
# 过滤掉空字典
|
||||
if json_obj:
|
||||
json_objects.append(json_obj)
|
||||
elif isinstance(json_obj, list):
|
||||
for item in json_obj:
|
||||
if isinstance(item, dict):
|
||||
if isinstance(item, dict) and item:
|
||||
json_objects.append(item)
|
||||
except Exception as e:
|
||||
logger.debug(f"尝试解析不完整的JSON代码块失败: {e}")
|
||||
|
||||
@@ -18,13 +18,12 @@ from src.chat.message_receive.uni_message_sender import UniversalMessageSender
|
||||
from src.chat.utils.timer_calculator import Timer # <--- Import Timer
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info
|
||||
from src.chat.utils.prompt_builder import global_prompt_manager
|
||||
from src.mood.mood_manager import mood_manager
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
build_readable_messages,
|
||||
get_raw_msg_before_timestamp_with_chat,
|
||||
replace_user_references,
|
||||
)
|
||||
from src.express.expression_selector import expression_selector
|
||||
from src.bw_learner.expression_selector import expression_selector
|
||||
from src.plugin_system.apis.message_api import translate_pid_to_description
|
||||
|
||||
# from src.memory_system.memory_activator import MemoryActivator
|
||||
@@ -36,7 +35,7 @@ from src.chat.replyer.prompt.lpmm_prompt import init_lpmm_prompt
|
||||
from src.chat.replyer.prompt.replyer_prompt import init_replyer_prompt
|
||||
from src.chat.replyer.prompt.rewrite_prompt import init_rewrite_prompt
|
||||
from src.memory_system.memory_retrieval import init_memory_retrieval_prompt, build_memory_retrieval_prompt
|
||||
from src.jargon.jargon_explainer import explain_jargon_in_context
|
||||
from src.bw_learner.jargon_explainer import explain_jargon_in_context, retrieve_concepts_with_jargon
|
||||
|
||||
init_lpmm_prompt()
|
||||
init_replyer_prompt()
|
||||
@@ -73,6 +72,8 @@ class DefaultReplyer:
|
||||
stream_id: Optional[str] = None,
|
||||
reply_message: Optional[DatabaseMessages] = None,
|
||||
reply_time_point: Optional[float] = time.time(),
|
||||
think_level: int = 1,
|
||||
unknown_words: Optional[List[str]] = None,
|
||||
) -> Tuple[bool, LLMGenerationDataModel]:
|
||||
# sourcery skip: merge-nested-ifs
|
||||
"""
|
||||
@@ -98,8 +99,10 @@ class DefaultReplyer:
|
||||
available_actions = {}
|
||||
try:
|
||||
# 3. 构建 Prompt
|
||||
timing_logs = []
|
||||
almost_zero_str = ""
|
||||
with Timer("构建Prompt", {}): # 内部计时器,可选保留
|
||||
prompt, selected_expressions = await self.build_prompt_reply_context(
|
||||
prompt, selected_expressions, timing_logs, almost_zero_str = await self.build_prompt_reply_context(
|
||||
extra_info=extra_info,
|
||||
available_actions=available_actions,
|
||||
chosen_actions=chosen_actions,
|
||||
@@ -107,6 +110,8 @@ class DefaultReplyer:
|
||||
reply_message=reply_message,
|
||||
reply_reason=reply_reason,
|
||||
reply_time_point=reply_time_point,
|
||||
think_level=think_level,
|
||||
unknown_words=unknown_words,
|
||||
)
|
||||
llm_response.prompt = prompt
|
||||
llm_response.selected_expressions = selected_expressions
|
||||
@@ -135,10 +140,22 @@ class DefaultReplyer:
|
||||
content, reasoning_content, model_name, tool_call = await self.llm_generate_content(prompt)
|
||||
# logger.debug(f"replyer生成内容: {content}")
|
||||
|
||||
logger.info(f"replyer生成内容: {content}")
|
||||
if global_config.debug.show_replyer_reasoning:
|
||||
logger.info(f"replyer生成推理:\n{reasoning_content}")
|
||||
logger.info(f"replyer生成模型: {model_name}")
|
||||
# 统一输出所有日志信息,使用try-except确保即使某个步骤出错也能输出
|
||||
try:
|
||||
# 1. 输出回复准备日志
|
||||
timing_log_str = f"回复准备: {'; '.join(timing_logs)}; {almost_zero_str} <0.1s" if timing_logs or almost_zero_str else "回复准备: 无计时信息"
|
||||
logger.info(timing_log_str)
|
||||
# 2. 输出Prompt日志
|
||||
if global_config.debug.show_replyer_prompt:
|
||||
logger.info(f"\n{prompt}\n")
|
||||
else:
|
||||
logger.debug(f"\nreplyer_Prompt:{prompt}\n")
|
||||
# 3. 输出模型生成内容和推理日志
|
||||
logger.info(f"模型: [{model_name}][思考等级:{think_level}]生成内容: {content}")
|
||||
if global_config.debug.show_replyer_reasoning and reasoning_content:
|
||||
logger.info(f"模型: [{model_name}][思考等级:{think_level}]生成推理:\n{reasoning_content}")
|
||||
except Exception as e:
|
||||
logger.warning(f"输出日志时出错: {e}")
|
||||
|
||||
llm_response.content = content
|
||||
llm_response.reasoning = reasoning_content
|
||||
@@ -162,6 +179,21 @@ class DefaultReplyer:
|
||||
except Exception as llm_e:
|
||||
# 精简报错信息
|
||||
logger.error(f"LLM 生成失败: {llm_e}")
|
||||
# 即使LLM生成失败,也尝试输出已收集的日志信息
|
||||
try:
|
||||
# 1. 输出回复准备日志
|
||||
timing_log_str = f"回复准备: {'; '.join(timing_logs)}; {almost_zero_str} <0.1s" if timing_logs or almost_zero_str else "回复准备: 无计时信息"
|
||||
logger.info(timing_log_str)
|
||||
# 2. 输出Prompt日志
|
||||
if global_config.debug.show_replyer_prompt:
|
||||
logger.info(f"\n{prompt}\n")
|
||||
else:
|
||||
logger.debug(f"\nreplyer_Prompt:{prompt}\n")
|
||||
# 3. 输出模型生成失败信息
|
||||
logger.info("模型生成失败,无法输出生成内容和推理")
|
||||
except Exception as log_e:
|
||||
logger.warning(f"输出日志时出错: {log_e}")
|
||||
|
||||
return False, llm_response # LLM 调用失败则无法生成回复
|
||||
|
||||
return True, llm_response
|
||||
@@ -228,7 +260,7 @@ class DefaultReplyer:
|
||||
return False, llm_response
|
||||
|
||||
async def build_expression_habits(
|
||||
self, chat_history: str, target: str, reply_reason: str = ""
|
||||
self, chat_history: str, target: str, reply_reason: str = "", think_level: int = 1
|
||||
) -> Tuple[str, List[int]]:
|
||||
# sourcery skip: for-append-to-extend
|
||||
"""构建表达习惯块
|
||||
@@ -237,6 +269,7 @@ class DefaultReplyer:
|
||||
chat_history: 聊天历史记录
|
||||
target: 目标消息内容
|
||||
reply_reason: planner给出的回复理由
|
||||
think_level: 思考级别,0/1/2
|
||||
|
||||
Returns:
|
||||
str: 表达习惯信息字符串
|
||||
@@ -249,14 +282,19 @@ class DefaultReplyer:
|
||||
# 使用从处理器传来的选中表达方式
|
||||
# 使用模型预测选择表达方式
|
||||
selected_expressions, selected_ids = await expression_selector.select_suitable_expressions(
|
||||
self.chat_stream.stream_id, chat_history, max_num=8, target_message=target, reply_reason=reply_reason
|
||||
self.chat_stream.stream_id,
|
||||
chat_history,
|
||||
max_num=8,
|
||||
target_message=target,
|
||||
reply_reason=reply_reason,
|
||||
think_level=think_level,
|
||||
)
|
||||
|
||||
if selected_expressions:
|
||||
logger.debug(f"使用处理器选中的{len(selected_expressions)}个表达方式")
|
||||
for expr in selected_expressions:
|
||||
if isinstance(expr, dict) and "situation" in expr and "style" in expr:
|
||||
style_habits.append(f"当{expr['situation']}时,使用 {expr['style']}")
|
||||
style_habits.append(f"当{expr['situation']}时:{expr['style']}")
|
||||
else:
|
||||
logger.debug("没有从处理器获得表达方式,将使用空的表达方式")
|
||||
# 不再在replyer中进行随机选择,全部交给处理器处理
|
||||
@@ -272,13 +310,6 @@ class DefaultReplyer:
|
||||
|
||||
return f"{expression_habits_title}\n{expression_habits_block}", selected_ids
|
||||
|
||||
async def build_mood_state_prompt(self) -> str:
|
||||
"""构建情绪状态提示"""
|
||||
if not global_config.mood.enable_mood:
|
||||
return ""
|
||||
mood_state = await mood_manager.get_mood_by_chat_id(self.chat_stream.stream_id).get_mood()
|
||||
return f"你现在的心情是:{mood_state}"
|
||||
|
||||
async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str:
|
||||
"""构建工具信息块
|
||||
|
||||
@@ -459,6 +490,57 @@ class DefaultReplyer:
|
||||
duration = end_time - start_time
|
||||
return name, result, duration
|
||||
|
||||
async def _build_disabled_jargon_explanation(self) -> str:
|
||||
"""当关闭黑话解释时使用的占位协程,避免额外的LLM调用"""
|
||||
return ""
|
||||
|
||||
async def _build_unknown_words_jargon(self, unknown_words: Optional[List[str]], chat_id: str) -> str:
|
||||
"""针对 Planner 提供的未知词语列表执行黑话检索"""
|
||||
if not unknown_words:
|
||||
return ""
|
||||
# 清洗未知词语列表,只保留非空字符串
|
||||
concepts: List[str] = []
|
||||
for item in unknown_words:
|
||||
if isinstance(item, str):
|
||||
s = item.strip()
|
||||
if s:
|
||||
concepts.append(s)
|
||||
if not concepts:
|
||||
return ""
|
||||
try:
|
||||
return await retrieve_concepts_with_jargon(concepts, chat_id)
|
||||
except Exception as e:
|
||||
logger.error(f"未知词语黑话检索失败: {e}")
|
||||
return ""
|
||||
|
||||
async def _build_jargon_explanation(
|
||||
self,
|
||||
chat_id: str,
|
||||
messages_short: List[DatabaseMessages],
|
||||
chat_talking_prompt_short: str,
|
||||
unknown_words: Optional[List[str]],
|
||||
) -> str:
|
||||
"""
|
||||
统一的黑话解释构建函数:
|
||||
- 根据 enable_jargon_explanation / jargon_mode 决定具体策略
|
||||
"""
|
||||
enable_jargon_explanation = getattr(global_config.expression, "enable_jargon_explanation", True)
|
||||
if not enable_jargon_explanation:
|
||||
return ""
|
||||
|
||||
jargon_mode = getattr(global_config.expression, "jargon_mode", "context")
|
||||
|
||||
# planner 模式:仅使用 Planner 的 unknown_words
|
||||
if jargon_mode == "planner":
|
||||
return await self._build_unknown_words_jargon(unknown_words, chat_id)
|
||||
|
||||
# 默认 / context 模式:使用上下文自动匹配黑话
|
||||
try:
|
||||
return await explain_jargon_in_context(chat_id, messages_short, chat_talking_prompt_short)
|
||||
except Exception as e:
|
||||
logger.error(f"上下文黑话解释失败: {e}")
|
||||
return ""
|
||||
|
||||
def build_chat_history_prompts(
|
||||
self, message_list_before_now: List[DatabaseMessages], target_user_id: str, sender: str
|
||||
) -> Tuple[str, str]:
|
||||
@@ -606,7 +688,7 @@ class DefaultReplyer:
|
||||
# 获取基础personality
|
||||
prompt_personality = global_config.personality.personality
|
||||
|
||||
# 检查是否需要随机替换为状态
|
||||
# 检查是否需要随机替换为状态(personality 本体)
|
||||
if (
|
||||
global_config.personality.states
|
||||
and global_config.personality.state_probability > 0
|
||||
@@ -643,16 +725,10 @@ class DefaultReplyer:
|
||||
# 判断是否为群聊
|
||||
is_group = stream_type == "group"
|
||||
|
||||
# 使用与 ChatStream.get_stream_id 相同的逻辑生成 chat_id
|
||||
import hashlib
|
||||
|
||||
if is_group:
|
||||
components = [platform, str(id_str)]
|
||||
else:
|
||||
components = [platform, str(id_str), "private"]
|
||||
key = "_".join(components)
|
||||
chat_id = hashlib.md5(key.encode()).hexdigest()
|
||||
# 使用 ChatManager 提供的接口生成 chat_id,避免在此重复实现逻辑
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
|
||||
chat_id = get_chat_manager().get_stream_id(platform, str(id_str), is_group=is_group)
|
||||
return chat_id, prompt_content
|
||||
|
||||
except (ValueError, IndexError):
|
||||
@@ -705,7 +781,9 @@ class DefaultReplyer:
|
||||
chosen_actions: Optional[List[ActionPlannerInfo]] = None,
|
||||
enable_tool: bool = True,
|
||||
reply_time_point: Optional[float] = time.time(),
|
||||
) -> Tuple[str, List[int]]:
|
||||
think_level: int = 1,
|
||||
unknown_words: Optional[List[str]] = None,
|
||||
) -> Tuple[str, List[int], List[str], str]:
|
||||
"""
|
||||
构建回复器上下文
|
||||
|
||||
@@ -751,14 +829,14 @@ class DefaultReplyer:
|
||||
chat_id=chat_id,
|
||||
timestamp=reply_time_point,
|
||||
limit=global_config.chat.max_context_size * 1,
|
||||
filter_no_read_command=True,
|
||||
filter_intercept_message_level=1,
|
||||
)
|
||||
|
||||
message_list_before_short = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=chat_id,
|
||||
timestamp=reply_time_point,
|
||||
limit=int(global_config.chat.max_context_size * 0.33),
|
||||
filter_no_read_command=True,
|
||||
filter_intercept_message_level=1,
|
||||
)
|
||||
|
||||
person_list_short: List[Person] = []
|
||||
@@ -789,10 +867,16 @@ class DefaultReplyer:
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
# 并行执行八个构建任务(包括黑话解释)
|
||||
# 统一黑话解释构建:根据配置选择上下文或 Planner 模式
|
||||
jargon_coroutine = self._build_jargon_explanation(
|
||||
chat_id, message_list_before_short, chat_talking_prompt_short, unknown_words
|
||||
)
|
||||
|
||||
# 并行执行构建任务(包括黑话解释,可配置关闭)
|
||||
task_results = await asyncio.gather(
|
||||
self._time_and_run_task(
|
||||
self.build_expression_habits(chat_talking_prompt_short, target, reply_reason), "expression_habits"
|
||||
self.build_expression_habits(chat_talking_prompt_short, target, reply_reason, think_level=think_level),
|
||||
"expression_habits",
|
||||
),
|
||||
self._time_and_run_task(
|
||||
self.build_tool_info(chat_talking_prompt_short, sender, target, enable_tool=enable_tool), "tool_info"
|
||||
@@ -800,17 +884,13 @@ class DefaultReplyer:
|
||||
self._time_and_run_task(self.get_prompt_info(chat_talking_prompt_short, sender, target), "prompt_info"),
|
||||
self._time_and_run_task(self.build_actions_prompt(available_actions, chosen_actions), "actions_info"),
|
||||
self._time_and_run_task(self.build_personality_prompt(), "personality_prompt"),
|
||||
self._time_and_run_task(self.build_mood_state_prompt(), "mood_state_prompt"),
|
||||
self._time_and_run_task(
|
||||
build_memory_retrieval_prompt(
|
||||
chat_talking_prompt_short, sender, target, self.chat_stream, self.tool_executor
|
||||
chat_talking_prompt_short, sender, target, self.chat_stream, think_level=think_level
|
||||
),
|
||||
"memory_retrieval",
|
||||
),
|
||||
self._time_and_run_task(
|
||||
explain_jargon_in_context(chat_id, message_list_before_short, chat_talking_prompt_short),
|
||||
"jargon_explanation",
|
||||
),
|
||||
self._time_and_run_task(jargon_coroutine, "jargon_explanation"),
|
||||
)
|
||||
|
||||
# 任务名称中英文映射
|
||||
@@ -821,7 +901,6 @@ class DefaultReplyer:
|
||||
"prompt_info": "获取知识",
|
||||
"actions_info": "动作信息",
|
||||
"personality_prompt": "人格信息",
|
||||
"mood_state_prompt": "情绪状态",
|
||||
"memory_retrieval": "记忆检索",
|
||||
"jargon_explanation": "黑话解释",
|
||||
}
|
||||
@@ -839,7 +918,8 @@ class DefaultReplyer:
|
||||
continue
|
||||
|
||||
timing_logs.append(f"{chinese_name}: {duration:.1f}s")
|
||||
logger.info(f"回复准备: {'; '.join(timing_logs)}; {almost_zero_str} <0.1s")
|
||||
# 不再在这里输出日志,而是返回给调用者统一输出
|
||||
# logger.info(f"回复准备: {'; '.join(timing_logs)}; {almost_zero_str} <0.1s")
|
||||
|
||||
expression_habits_block, selected_expressions = results_dict["expression_habits"]
|
||||
expression_habits_block: str
|
||||
@@ -851,14 +931,8 @@ class DefaultReplyer:
|
||||
personality_prompt: str = results_dict["personality_prompt"]
|
||||
memory_retrieval: str = results_dict["memory_retrieval"]
|
||||
keywords_reaction_prompt = await self.build_keywords_reaction_prompt(target)
|
||||
mood_state_prompt: str = results_dict["mood_state_prompt"]
|
||||
jargon_explanation: str = results_dict.get("jargon_explanation") or ""
|
||||
|
||||
# 从 chosen_actions 中提取 planner 的整体思考理由
|
||||
planner_reasoning = ""
|
||||
if global_config.chat.include_planner_reasoning and reply_reason:
|
||||
# 如果没有 chosen_actions,使用 reply_reason 作为备选
|
||||
planner_reasoning = f"你的想法是:{reply_reason}"
|
||||
planner_reasoning = f"你的想法是:{reply_reason}"
|
||||
|
||||
if extra_info:
|
||||
extra_info_block = f"以下是你在回复时需要参考的信息,现在请你阅读以下内容,进行决策\n{extra_info}\n以上是你在回复时需要参考的信息,现在请你阅读以下内容,进行决策"
|
||||
@@ -893,14 +967,31 @@ class DefaultReplyer:
|
||||
chat_prompt_content = self.get_chat_prompt_for_chat(chat_id)
|
||||
chat_prompt_block = f"{chat_prompt_content}\n" if chat_prompt_content else ""
|
||||
|
||||
# 固定使用群聊回复模板
|
||||
# 根据think_level选择不同的回复模板
|
||||
# think_level=0: 轻量回复(简短平淡)
|
||||
# think_level=1: 中等回复(日常口语化)
|
||||
if think_level == 0:
|
||||
prompt_name = "replyer_prompt_0"
|
||||
else: # think_level == 1 或默认
|
||||
prompt_name = "replyer_prompt"
|
||||
|
||||
# 根据配置构建最终的 reply_style:支持 multiple_reply_style 按概率随机替换
|
||||
reply_style = global_config.personality.reply_style
|
||||
multi_styles = getattr(global_config.personality, "multiple_reply_style", None) or []
|
||||
multi_prob = getattr(global_config.personality, "multiple_probability", 0.0) or 0.0
|
||||
if multi_styles and multi_prob > 0 and random.random() < multi_prob:
|
||||
try:
|
||||
reply_style = random.choice(list(multi_styles))
|
||||
except Exception:
|
||||
# 兜底:即使 multiple_reply_style 配置异常也不影响正常回复
|
||||
reply_style = global_config.personality.reply_style
|
||||
|
||||
return await global_prompt_manager.format_prompt(
|
||||
"replyer_prompt",
|
||||
prompt_name,
|
||||
expression_habits_block=expression_habits_block,
|
||||
tool_info_block=tool_info,
|
||||
bot_name=global_config.bot.nickname,
|
||||
knowledge_prompt=prompt_info,
|
||||
mood_state=mood_state_prompt,
|
||||
# relation_info_block=relation_info,
|
||||
extra_info_block=extra_info_block,
|
||||
jargon_explanation=jargon_explanation,
|
||||
@@ -910,13 +1001,13 @@ class DefaultReplyer:
|
||||
dialogue_prompt=dialogue_prompt,
|
||||
time_block=time_block,
|
||||
reply_target_block=reply_target_block,
|
||||
reply_style=global_config.personality.reply_style,
|
||||
reply_style=reply_style,
|
||||
keywords_reaction_prompt=keywords_reaction_prompt,
|
||||
moderation_prompt=moderation_prompt_block,
|
||||
memory_retrieval=memory_retrieval,
|
||||
chat_prompt=chat_prompt_block,
|
||||
planner_reasoning=planner_reasoning,
|
||||
), selected_expressions
|
||||
), selected_expressions, timing_logs, almost_zero_str
|
||||
|
||||
async def build_prompt_rewrite_context(
|
||||
self,
|
||||
@@ -926,8 +1017,6 @@ class DefaultReplyer:
|
||||
) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
|
||||
chat_stream = self.chat_stream
|
||||
chat_id = chat_stream.stream_id
|
||||
is_group_chat = bool(chat_stream.group_info)
|
||||
|
||||
sender, target = self._parse_reply_target(reply_to)
|
||||
target = replace_user_references(target, chat_stream.platform, replace_bot_name=True)
|
||||
|
||||
@@ -941,7 +1030,7 @@ class DefaultReplyer:
|
||||
chat_id=chat_id,
|
||||
timestamp=time.time(),
|
||||
limit=min(int(global_config.chat.max_context_size * 0.33), 15),
|
||||
filter_no_read_command=True,
|
||||
filter_intercept_message_level=1,
|
||||
)
|
||||
chat_talking_prompt_half = build_readable_messages(
|
||||
message_list_before_now_half,
|
||||
@@ -967,61 +1056,42 @@ class DefaultReplyer:
|
||||
|
||||
if sender and target:
|
||||
# 使用预先分析的内容类型结果
|
||||
if is_group_chat:
|
||||
if sender:
|
||||
if has_only_pics and not has_text:
|
||||
# 只包含图片
|
||||
reply_target_block = (
|
||||
f"现在{sender}发送的图片:{pic_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
|
||||
)
|
||||
elif has_text and pic_part:
|
||||
# 既有图片又有文字
|
||||
reply_target_block = f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
|
||||
else:
|
||||
# 只包含文字
|
||||
reply_target_block = (
|
||||
f"现在{sender}说的:{text_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
|
||||
)
|
||||
elif target:
|
||||
reply_target_block = f"现在{target}引起了你的注意,你想要在群里发言或者回复这条消息。"
|
||||
if sender:
|
||||
if has_only_pics and not has_text:
|
||||
# 只包含图片
|
||||
reply_target_block = (
|
||||
f"现在{sender}发送的图片:{pic_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
|
||||
)
|
||||
elif has_text and pic_part:
|
||||
# 既有图片又有文字
|
||||
reply_target_block = f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
|
||||
else:
|
||||
reply_target_block = "现在,你想要在群里发言或者回复消息。"
|
||||
else: # private chat
|
||||
if sender:
|
||||
if has_only_pics and not has_text:
|
||||
# 只包含图片
|
||||
reply_target_block = f"现在{sender}发送的图片:{pic_part}。引起了你的注意,针对这条消息回复。"
|
||||
elif has_text and pic_part:
|
||||
# 既有图片又有文字
|
||||
reply_target_block = (
|
||||
f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,针对这条消息回复。"
|
||||
)
|
||||
else:
|
||||
# 只包含文字
|
||||
reply_target_block = f"现在{sender}说的:{text_part}。引起了你的注意,针对这条消息回复。"
|
||||
elif target:
|
||||
reply_target_block = f"现在{target}引起了你的注意,针对这条消息回复。"
|
||||
else:
|
||||
reply_target_block = "现在,你想要回复。"
|
||||
# 只包含文字
|
||||
reply_target_block = (
|
||||
f"现在{sender}说的:{text_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
|
||||
)
|
||||
elif target:
|
||||
reply_target_block = f"现在{target}引起了你的注意,你想要在群里发言或者回复这条消息。"
|
||||
else:
|
||||
reply_target_block = "现在,你想要在群里发言或者回复消息。"
|
||||
else:
|
||||
reply_target_block = ""
|
||||
|
||||
if is_group_chat:
|
||||
chat_target_1 = await global_prompt_manager.get_prompt_async("chat_target_group1")
|
||||
chat_target_2 = await global_prompt_manager.get_prompt_async("chat_target_group2")
|
||||
else:
|
||||
chat_target_name = "对方"
|
||||
if self.chat_target_info:
|
||||
chat_target_name = self.chat_target_info.person_name or self.chat_target_info.user_nickname or "对方"
|
||||
chat_target_1 = await global_prompt_manager.format_prompt(
|
||||
"chat_target_private1", sender_name=chat_target_name
|
||||
)
|
||||
chat_target_2 = await global_prompt_manager.format_prompt(
|
||||
"chat_target_private2", sender_name=chat_target_name
|
||||
)
|
||||
chat_target_1 = await global_prompt_manager.get_prompt_async("chat_target_group1")
|
||||
chat_target_2 = await global_prompt_manager.get_prompt_async("chat_target_group2")
|
||||
|
||||
template_name = "default_expressor_prompt"
|
||||
|
||||
# 根据配置构建最终的 reply_style:支持 multiple_reply_style 按概率随机替换
|
||||
reply_style = global_config.personality.reply_style
|
||||
multi_styles = getattr(global_config.personality, "multiple_reply_style", None) or []
|
||||
multi_prob = getattr(global_config.personality, "multiple_probability", 0.0) or 0.0
|
||||
if multi_styles and multi_prob > 0 and random.random() < multi_prob:
|
||||
try:
|
||||
reply_style = random.choice(list(multi_styles))
|
||||
except Exception:
|
||||
reply_style = global_config.personality.reply_style
|
||||
|
||||
return await global_prompt_manager.format_prompt(
|
||||
template_name,
|
||||
expression_habits_block=expression_habits_block,
|
||||
@@ -1034,7 +1104,7 @@ class DefaultReplyer:
|
||||
reply_target_block=reply_target_block,
|
||||
raw_reply=raw_reply,
|
||||
reason=reason,
|
||||
reply_style=global_config.personality.reply_style,
|
||||
reply_style=reply_style,
|
||||
keywords_reaction_prompt=keywords_reaction_prompt,
|
||||
moderation_prompt=moderation_prompt_block,
|
||||
)
|
||||
@@ -1078,10 +1148,11 @@ class DefaultReplyer:
|
||||
# 直接使用已初始化的模型实例
|
||||
# logger.info(f"\n{prompt}\n")
|
||||
|
||||
if global_config.debug.show_replyer_prompt:
|
||||
logger.info(f"\n{prompt}\n")
|
||||
else:
|
||||
logger.debug(f"\nreplyer_Prompt:{prompt}\n")
|
||||
# 不再在这里输出日志,而是返回给调用者统一输出
|
||||
# if global_config.debug.show_replyer_prompt:
|
||||
# logger.info(f"\n{prompt}\n")
|
||||
# else:
|
||||
# logger.debug(f"\nreplyer_Prompt:{prompt}\n")
|
||||
|
||||
content, (reasoning_content, model_name, tool_calls) = await self.express_model.generate_response_async(
|
||||
prompt
|
||||
@@ -1090,7 +1161,7 @@ class DefaultReplyer:
|
||||
# 移除 content 前后的换行符和空格
|
||||
content = content.strip()
|
||||
|
||||
logger.info(f"使用 {model_name} 生成回复内容: {content}")
|
||||
# logger.info(f"使用 {model_name} 生成回复内容: {content}")
|
||||
return content, reasoning_content, model_name, tool_calls
|
||||
|
||||
async def get_prompt_info(self, message: str, sender: str, target: str):
|
||||
|
||||
@@ -23,9 +23,8 @@ from src.chat.utils.chat_message_builder import (
|
||||
get_raw_msg_before_timestamp_with_chat,
|
||||
replace_user_references,
|
||||
)
|
||||
from src.express.expression_selector import expression_selector
|
||||
from src.bw_learner.expression_selector import expression_selector
|
||||
from src.plugin_system.apis.message_api import translate_pid_to_description
|
||||
from src.mood.mood_manager import mood_manager
|
||||
|
||||
# from src.memory_system.memory_activator import MemoryActivator
|
||||
|
||||
@@ -34,13 +33,13 @@ from src.plugin_system.base.component_types import ActionInfo, EventType
|
||||
from src.plugin_system.apis import llm_api
|
||||
|
||||
from src.chat.replyer.prompt.lpmm_prompt import init_lpmm_prompt
|
||||
from src.chat.replyer.prompt.replyer_prompt import init_replyer_prompt
|
||||
from src.chat.replyer.prompt.replyer_private_prompt import init_replyer_private_prompt
|
||||
from src.chat.replyer.prompt.rewrite_prompt import init_rewrite_prompt
|
||||
from src.memory_system.memory_retrieval import init_memory_retrieval_prompt, build_memory_retrieval_prompt
|
||||
from src.jargon.jargon_explainer import explain_jargon_in_context
|
||||
from src.bw_learner.jargon_explainer import explain_jargon_in_context
|
||||
|
||||
init_lpmm_prompt()
|
||||
init_replyer_prompt()
|
||||
init_replyer_private_prompt()
|
||||
init_rewrite_prompt()
|
||||
init_memory_retrieval_prompt()
|
||||
|
||||
@@ -72,9 +71,11 @@ class PrivateReplyer:
|
||||
chosen_actions: Optional[List[ActionPlannerInfo]] = None,
|
||||
enable_tool: bool = True,
|
||||
from_plugin: bool = True,
|
||||
think_level: int = 1,
|
||||
stream_id: Optional[str] = None,
|
||||
reply_message: Optional[DatabaseMessages] = None,
|
||||
reply_time_point: Optional[float] = time.time(),
|
||||
unknown_words: Optional[List[str]] = None,
|
||||
) -> Tuple[bool, LLMGenerationDataModel]:
|
||||
# sourcery skip: merge-nested-ifs
|
||||
"""
|
||||
@@ -271,7 +272,7 @@ class PrivateReplyer:
|
||||
logger.debug(f"使用处理器选中的{len(selected_expressions)}个表达方式")
|
||||
for expr in selected_expressions:
|
||||
if isinstance(expr, dict) and "situation" in expr and "style" in expr:
|
||||
style_habits.append(f"当{expr['situation']}时,使用 {expr['style']}")
|
||||
style_habits.append(f"当{expr['situation']}时:{expr['style']}")
|
||||
else:
|
||||
logger.debug("没有从处理器获得表达方式,将使用空的表达方式")
|
||||
# 不再在replyer中进行随机选择,全部交给处理器处理
|
||||
@@ -287,13 +288,6 @@ class PrivateReplyer:
|
||||
|
||||
return f"{expression_habits_title}\n{expression_habits_block}", selected_ids
|
||||
|
||||
async def build_mood_state_prompt(self) -> str:
|
||||
"""构建情绪状态提示"""
|
||||
if not global_config.mood.enable_mood:
|
||||
return ""
|
||||
mood_state = await mood_manager.get_mood_by_chat_id(self.chat_stream.stream_id).get_mood()
|
||||
return f"你现在的心情是:{mood_state}"
|
||||
|
||||
async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str:
|
||||
"""构建工具信息块
|
||||
|
||||
@@ -474,6 +468,10 @@ class PrivateReplyer:
|
||||
duration = end_time - start_time
|
||||
return name, result, duration
|
||||
|
||||
async def _build_disabled_jargon_explanation(self) -> str:
|
||||
"""当关闭黑话解释时使用的占位协程,避免额外的LLM调用"""
|
||||
return ""
|
||||
|
||||
async def build_actions_prompt(
|
||||
self, available_actions: Dict[str, ActionInfo], chosen_actions_info: Optional[List[ActionPlannerInfo]] = None
|
||||
) -> str:
|
||||
@@ -557,16 +555,10 @@ class PrivateReplyer:
|
||||
# 判断是否为群聊
|
||||
is_group = stream_type == "group"
|
||||
|
||||
# 使用与 ChatStream.get_stream_id 相同的逻辑生成 chat_id
|
||||
import hashlib
|
||||
|
||||
if is_group:
|
||||
components = [platform, str(id_str)]
|
||||
else:
|
||||
components = [platform, str(id_str), "private"]
|
||||
key = "_".join(components)
|
||||
chat_id = hashlib.md5(key.encode()).hexdigest()
|
||||
# 使用 ChatManager 提供的接口生成 chat_id,避免在此重复实现逻辑
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
|
||||
chat_id = get_chat_manager().get_stream_id(platform, str(id_str), is_group=is_group)
|
||||
return chat_id, prompt_content
|
||||
|
||||
except (ValueError, IndexError):
|
||||
@@ -663,7 +655,7 @@ class PrivateReplyer:
|
||||
chat_id=chat_id,
|
||||
timestamp=time.time(),
|
||||
limit=global_config.chat.max_context_size,
|
||||
filter_no_read_command=True,
|
||||
filter_intercept_message_level=1,
|
||||
)
|
||||
|
||||
dialogue_prompt = build_readable_messages(
|
||||
@@ -678,7 +670,7 @@ class PrivateReplyer:
|
||||
chat_id=chat_id,
|
||||
timestamp=time.time(),
|
||||
limit=int(global_config.chat.max_context_size * 0.33),
|
||||
filter_no_read_command=True,
|
||||
filter_intercept_message_level=1,
|
||||
)
|
||||
|
||||
person_list_short: List[Person] = []
|
||||
@@ -709,7 +701,14 @@ class PrivateReplyer:
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
# 并行执行九个构建任务(包括黑话解释)
|
||||
# 根据配置决定是否启用黑话解释
|
||||
enable_jargon_explanation = getattr(global_config.expression, "enable_jargon_explanation", True)
|
||||
if enable_jargon_explanation:
|
||||
jargon_coroutine = explain_jargon_in_context(chat_id, message_list_before_short, chat_talking_prompt_short)
|
||||
else:
|
||||
jargon_coroutine = self._build_disabled_jargon_explanation()
|
||||
|
||||
# 并行执行九个构建任务(包括黑话解释,可配置关闭)
|
||||
task_results = await asyncio.gather(
|
||||
self._time_and_run_task(
|
||||
self.build_expression_habits(chat_talking_prompt_short, target, reply_reason), "expression_habits"
|
||||
@@ -721,17 +720,13 @@ class PrivateReplyer:
|
||||
self._time_and_run_task(self.get_prompt_info(chat_talking_prompt_short, sender, target), "prompt_info"),
|
||||
self._time_and_run_task(self.build_actions_prompt(available_actions, chosen_actions), "actions_info"),
|
||||
self._time_and_run_task(self.build_personality_prompt(), "personality_prompt"),
|
||||
self._time_and_run_task(self.build_mood_state_prompt(), "mood_state_prompt"),
|
||||
self._time_and_run_task(
|
||||
build_memory_retrieval_prompt(
|
||||
chat_talking_prompt_short, sender, target, self.chat_stream, self.tool_executor
|
||||
),
|
||||
"memory_retrieval",
|
||||
),
|
||||
self._time_and_run_task(
|
||||
explain_jargon_in_context(chat_id, message_list_before_short, chat_talking_prompt_short),
|
||||
"jargon_explanation",
|
||||
),
|
||||
self._time_and_run_task(jargon_coroutine, "jargon_explanation"),
|
||||
)
|
||||
|
||||
# 任务名称中英文映射
|
||||
@@ -742,7 +737,6 @@ class PrivateReplyer:
|
||||
"prompt_info": "获取知识",
|
||||
"actions_info": "动作信息",
|
||||
"personality_prompt": "人格信息",
|
||||
"mood_state_prompt": "情绪状态",
|
||||
"memory_retrieval": "记忆检索",
|
||||
"jargon_explanation": "黑话解释",
|
||||
}
|
||||
@@ -770,16 +764,10 @@ class PrivateReplyer:
|
||||
prompt_info: str = results_dict["prompt_info"] # 直接使用格式化后的结果
|
||||
actions_info: str = results_dict["actions_info"]
|
||||
personality_prompt: str = results_dict["personality_prompt"]
|
||||
mood_state_prompt: str = results_dict["mood_state_prompt"]
|
||||
memory_retrieval: str = results_dict["memory_retrieval"]
|
||||
keywords_reaction_prompt = await self.build_keywords_reaction_prompt(target)
|
||||
jargon_explanation: str = results_dict.get("jargon_explanation") or ""
|
||||
|
||||
# 从 chosen_actions 中提取 planner 的整体思考理由
|
||||
planner_reasoning = ""
|
||||
if global_config.chat.include_planner_reasoning and reply_reason:
|
||||
# 如果没有 chosen_actions,使用 reply_reason 作为备选
|
||||
planner_reasoning = f"你的想法是:{reply_reason}"
|
||||
planner_reasoning = f"你的想法是:{reply_reason}"
|
||||
|
||||
if extra_info:
|
||||
extra_info_block = f"以下是你在回复时需要参考的信息,现在请你阅读以下内容,进行决策\n{extra_info}\n以上是你在回复时需要参考的信息,现在请你阅读以下内容,进行决策"
|
||||
@@ -814,7 +802,6 @@ class PrivateReplyer:
|
||||
expression_habits_block=expression_habits_block,
|
||||
tool_info_block=tool_info,
|
||||
knowledge_prompt=prompt_info,
|
||||
mood_state=mood_state_prompt,
|
||||
relation_info_block=relation_info,
|
||||
extra_info_block=extra_info_block,
|
||||
identity=personality_prompt,
|
||||
@@ -837,7 +824,6 @@ class PrivateReplyer:
|
||||
expression_habits_block=expression_habits_block,
|
||||
tool_info_block=tool_info,
|
||||
knowledge_prompt=prompt_info,
|
||||
mood_state=mood_state_prompt,
|
||||
relation_info_block=relation_info,
|
||||
extra_info_block=extra_info_block,
|
||||
identity=personality_prompt,
|
||||
@@ -878,7 +864,7 @@ class PrivateReplyer:
|
||||
chat_id=chat_id,
|
||||
timestamp=time.time(),
|
||||
limit=min(int(global_config.chat.max_context_size * 0.33), 15),
|
||||
filter_no_read_command=True,
|
||||
filter_intercept_message_level=1,
|
||||
)
|
||||
chat_talking_prompt_half = build_readable_messages(
|
||||
message_list_before_now_half,
|
||||
@@ -904,59 +890,30 @@ class PrivateReplyer:
|
||||
)
|
||||
|
||||
if sender and target:
|
||||
# 使用预先分析的内容类型结果
|
||||
if is_group_chat:
|
||||
if sender:
|
||||
if has_only_pics and not has_text:
|
||||
# 只包含图片
|
||||
reply_target_block = (
|
||||
f"现在{sender}发送的图片:{pic_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
|
||||
)
|
||||
elif has_text and pic_part:
|
||||
# 既有图片又有文字
|
||||
reply_target_block = f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
|
||||
else:
|
||||
# 只包含文字
|
||||
reply_target_block = (
|
||||
f"现在{sender}说的:{text_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
|
||||
)
|
||||
elif target:
|
||||
reply_target_block = f"现在{target}引起了你的注意,你想要在群里发言或者回复这条消息。"
|
||||
if sender:
|
||||
if has_only_pics and not has_text:
|
||||
# 只包含图片
|
||||
reply_target_block = f"现在{sender}发送的图片:{pic_part}。引起了你的注意,针对这条消息回复。"
|
||||
elif has_text and pic_part:
|
||||
# 既有图片又有文字
|
||||
reply_target_block = (
|
||||
f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,针对这条消息回复。"
|
||||
)
|
||||
else:
|
||||
reply_target_block = "现在,你想要在群里发言或者回复消息。"
|
||||
else: # private chat
|
||||
if sender:
|
||||
if has_only_pics and not has_text:
|
||||
# 只包含图片
|
||||
reply_target_block = f"现在{sender}发送的图片:{pic_part}。引起了你的注意,针对这条消息回复。"
|
||||
elif has_text and pic_part:
|
||||
# 既有图片又有文字
|
||||
reply_target_block = (
|
||||
f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,针对这条消息回复。"
|
||||
)
|
||||
else:
|
||||
# 只包含文字
|
||||
reply_target_block = f"现在{sender}说的:{text_part}。引起了你的注意,针对这条消息回复。"
|
||||
elif target:
|
||||
reply_target_block = f"现在{target}引起了你的注意,针对这条消息回复。"
|
||||
else:
|
||||
reply_target_block = "现在,你想要回复。"
|
||||
# 只包含文字
|
||||
reply_target_block = f"现在{sender}说的:{text_part}。引起了你的注意,针对这条消息回复。"
|
||||
elif target:
|
||||
reply_target_block = f"现在{target}引起了你的注意,针对这条消息回复。"
|
||||
else:
|
||||
reply_target_block = "现在,你想要回复。"
|
||||
else:
|
||||
reply_target_block = ""
|
||||
|
||||
if is_group_chat:
|
||||
chat_target_1 = await global_prompt_manager.get_prompt_async("chat_target_group1")
|
||||
chat_target_2 = await global_prompt_manager.get_prompt_async("chat_target_group2")
|
||||
else:
|
||||
chat_target_name = "对方"
|
||||
if self.chat_target_info:
|
||||
chat_target_name = self.chat_target_info.person_name or self.chat_target_info.user_nickname or "对方"
|
||||
chat_target_1 = await global_prompt_manager.format_prompt(
|
||||
"chat_target_private1", sender_name=chat_target_name
|
||||
)
|
||||
chat_target_2 = await global_prompt_manager.format_prompt(
|
||||
"chat_target_private2", sender_name=chat_target_name
|
||||
)
|
||||
chat_target_name = "对方"
|
||||
if self.chat_target_info:
|
||||
chat_target_name = self.chat_target_info.person_name or self.chat_target_info.user_nickname or "对方"
|
||||
chat_target_1 = await global_prompt_manager.format_prompt("chat_target_private1", sender_name=chat_target_name)
|
||||
chat_target_2 = await global_prompt_manager.format_prompt("chat_target_private2", sender_name=chat_target_name)
|
||||
|
||||
template_name = "default_expressor_prompt"
|
||||
|
||||
|
||||
41
src/chat/replyer/prompt/replyer_private_prompt.py
Normal file
41
src/chat/replyer/prompt/replyer_private_prompt.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from src.chat.utils.prompt_builder import Prompt
|
||||
|
||||
|
||||
def init_replyer_private_prompt():
|
||||
Prompt(
|
||||
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
|
||||
{expression_habits_block}{memory_retrieval}{jargon_explanation}
|
||||
|
||||
你正在和{sender_name}聊天,这是你们之前聊的内容:
|
||||
{time_block}
|
||||
{dialogue_prompt}
|
||||
|
||||
{reply_target_block}。
|
||||
{planner_reasoning}
|
||||
{identity}
|
||||
{chat_prompt}你正在和{sender_name}聊天,现在请你读读之前的聊天记录,然后给出日常且口语化的回复,平淡一些,
|
||||
尽量简短一些。{keywords_reaction_prompt}请注意把握聊天内容,不要回复的太有条理。
|
||||
{reply_style}
|
||||
请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出回复内容。
|
||||
{moderation_prompt}不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 )。""",
|
||||
"private_replyer_prompt",
|
||||
)
|
||||
|
||||
Prompt(
|
||||
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
|
||||
{expression_habits_block}{memory_retrieval}{jargon_explanation}
|
||||
|
||||
你正在和{sender_name}聊天,这是你们之前聊的内容:
|
||||
{time_block}
|
||||
{dialogue_prompt}
|
||||
|
||||
你现在想补充说明你刚刚自己的发言内容:{target},原因是{reason}
|
||||
请你根据聊天内容,组织一条新回复。注意,{target} 是刚刚你自己的发言,你要在这基础上进一步发言,请按照你自己的角度来继续进行回复。注意保持上下文的连贯性。
|
||||
{identity}
|
||||
{chat_prompt}尽量简短一些。{keywords_reaction_prompt}请注意把握聊天内容,不要回复的太有条理,可以有个性。
|
||||
{reply_style}
|
||||
请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出回复内容。
|
||||
{moderation_prompt}不要输出多余内容(包括冒号和引号,括号,表情包,at或 @等 )。
|
||||
""",
|
||||
"private_replyer_self_prompt",
|
||||
)
|
||||
@@ -3,8 +3,27 @@ from src.chat.utils.prompt_builder import Prompt
|
||||
|
||||
|
||||
def init_replyer_prompt():
|
||||
Prompt("正在群里聊天", "chat_target_group2")
|
||||
Prompt("和{sender_name}聊天", "chat_target_private2")
|
||||
Prompt(
|
||||
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
|
||||
{expression_habits_block}{memory_retrieval}{jargon_explanation}
|
||||
|
||||
你正在qq群里聊天,下面是群里正在聊的内容,其中包含聊天记录和聊天中的图片
|
||||
其中标注 {bot_name}(你) 的发言是你自己的发言,请注意区分:
|
||||
{time_block}
|
||||
{dialogue_prompt}
|
||||
|
||||
{reply_target_block}。
|
||||
{planner_reasoning}
|
||||
{identity}
|
||||
{chat_prompt}你正在群里聊天,现在请你读读之前的聊天记录,然后给出日常且口语化的回复,
|
||||
尽量简短一些。{keywords_reaction_prompt}
|
||||
请注意把握聊天内容,不要回复的太有条理。
|
||||
{reply_style}
|
||||
请注意不要输出多余内容(包括不必要的前后缀,冒号,括号,表情包,at或 @等 ),只输出发言内容就好。
|
||||
最好一次对一个话题进行回复,免得啰嗦或者回复内容太乱。
|
||||
现在,你说:""",
|
||||
"replyer_prompt_0",
|
||||
)
|
||||
|
||||
Prompt(
|
||||
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
|
||||
@@ -18,49 +37,12 @@ def init_replyer_prompt():
|
||||
{reply_target_block}。
|
||||
{planner_reasoning}
|
||||
{identity}
|
||||
{chat_prompt}你正在群里聊天,现在请你读读之前的聊天记录,然后给出日常且口语化的回复,平淡一些,{mood_state}
|
||||
尽量简短一些。{keywords_reaction_prompt}请注意把握聊天内容,不要回复的太有条理,可以有个性。
|
||||
{chat_prompt}你正在群里聊天,现在请你读读之前的聊天记录,把握当前的话题,然后给出日常且简短的回复。
|
||||
最好一次对一个话题进行回复,免得啰嗦或者回复内容太乱。
|
||||
{keywords_reaction_prompt}
|
||||
请注意把握聊天内容。
|
||||
{reply_style}
|
||||
请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出一句回复内容就好。
|
||||
不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 )。
|
||||
请注意不要输出多余内容(包括不必要的前后缀,冒号,括号,at或 @等 ),只输出发言内容就好。
|
||||
现在,你说:""",
|
||||
"replyer_prompt",
|
||||
)
|
||||
|
||||
Prompt(
|
||||
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
|
||||
{expression_habits_block}{memory_retrieval}{jargon_explanation}
|
||||
|
||||
你正在和{sender_name}聊天,这是你们之前聊的内容:
|
||||
{time_block}
|
||||
{dialogue_prompt}
|
||||
|
||||
{reply_target_block}。
|
||||
{planner_reasoning}
|
||||
{identity}
|
||||
{chat_prompt}你正在和{sender_name}聊天,现在请你读读之前的聊天记录,然后给出日常且口语化的回复,平淡一些,{mood_state}
|
||||
尽量简短一些。{keywords_reaction_prompt}请注意把握聊天内容,不要回复的太有条理,可以有个性。
|
||||
{reply_style}
|
||||
请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出回复内容。
|
||||
{moderation_prompt}不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 )。""",
|
||||
"private_replyer_prompt",
|
||||
)
|
||||
|
||||
Prompt(
|
||||
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
|
||||
{expression_habits_block}{memory_retrieval}{jargon_explanation}
|
||||
|
||||
你正在和{sender_name}聊天,这是你们之前聊的内容:
|
||||
{time_block}
|
||||
{dialogue_prompt}
|
||||
|
||||
你现在想补充说明你刚刚自己的发言内容:{target},原因是{reason}
|
||||
请你根据聊天内容,组织一条新回复。注意,{target} 是刚刚你自己的发言,你要在这基础上进一步发言,请按照你自己的角度来继续进行回复。注意保持上下文的连贯性。{mood_state}
|
||||
{identity}
|
||||
{chat_prompt}尽量简短一些。{keywords_reaction_prompt}请注意把握聊天内容,不要回复的太有条理,可以有个性。
|
||||
{reply_style}
|
||||
请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出回复内容。
|
||||
{moderation_prompt}不要输出多余内容(包括冒号和引号,括号,表情包,at或 @等 )。
|
||||
""",
|
||||
"private_replyer_self_prompt",
|
||||
)
|
||||
|
||||
@@ -120,7 +120,7 @@ def get_raw_msg_by_timestamp_with_chat(
|
||||
limit_mode: str = "latest",
|
||||
filter_bot=False,
|
||||
filter_command=False,
|
||||
filter_no_read_command=False,
|
||||
filter_intercept_message_level: Optional[int] = None,
|
||||
) -> List[DatabaseMessages]:
|
||||
"""获取在特定聊天从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
@@ -138,7 +138,7 @@ def get_raw_msg_by_timestamp_with_chat(
|
||||
limit_mode=limit_mode,
|
||||
filter_bot=filter_bot,
|
||||
filter_command=filter_command,
|
||||
filter_no_read_command=filter_no_read_command,
|
||||
filter_intercept_message_level=filter_intercept_message_level,
|
||||
)
|
||||
|
||||
|
||||
@@ -150,7 +150,7 @@ def get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
limit_mode: str = "latest",
|
||||
filter_bot=False,
|
||||
filter_command=False,
|
||||
filter_no_read_command=False,
|
||||
filter_intercept_message_level: Optional[int] = None,
|
||||
) -> List[DatabaseMessages]:
|
||||
"""获取在特定聊天从指定时间戳到指定时间戳的消息(包含边界),按时间升序排序,返回消息列表
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
@@ -167,7 +167,7 @@ def get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
limit_mode=limit_mode,
|
||||
filter_bot=filter_bot,
|
||||
filter_command=filter_command,
|
||||
filter_no_read_command=filter_no_read_command,
|
||||
filter_intercept_message_level=filter_intercept_message_level,
|
||||
)
|
||||
|
||||
|
||||
@@ -303,7 +303,7 @@ def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List[Datab
|
||||
|
||||
|
||||
def get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id: str, timestamp: float, limit: int = 0, filter_no_read_command: bool = False
|
||||
chat_id: str, timestamp: float, limit: int = 0, filter_intercept_message_level: Optional[int] = None
|
||||
) -> List[DatabaseMessages]:
|
||||
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
@@ -311,7 +311,10 @@ def get_raw_msg_before_timestamp_with_chat(
|
||||
filter_query = {"chat_id": chat_id, "time": {"$lt": timestamp}}
|
||||
sort_order = [("time", 1)]
|
||||
return find_messages(
|
||||
message_filter=filter_query, sort=sort_order, limit=limit, filter_no_read_command=filter_no_read_command
|
||||
message_filter=filter_query,
|
||||
sort=sort_order,
|
||||
limit=limit,
|
||||
filter_intercept_message_level=filter_intercept_message_level,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from typing import Any, Dict, Tuple, List
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database import db
|
||||
from src.common.database.database_model import OnlineTime, LLMUsage, Messages
|
||||
from src.common.database.database_model import OnlineTime, LLMUsage, Messages, ActionRecords
|
||||
from src.manager.async_task_manager import AsyncTask
|
||||
from src.manager.local_store_manager import local_storage
|
||||
from src.config.config import global_config
|
||||
@@ -505,13 +505,6 @@ class StatisticOutputTask(AsyncTask):
|
||||
for period_key, _ in collect_period
|
||||
}
|
||||
|
||||
# 获取bot的QQ账号
|
||||
bot_qq_account = (
|
||||
str(global_config.bot.qq_account)
|
||||
if hasattr(global_config, "bot") and hasattr(global_config.bot, "qq_account")
|
||||
else ""
|
||||
)
|
||||
|
||||
query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp)
|
||||
for message in Messages.select().where(Messages.time >= query_start_timestamp): # type: ignore
|
||||
message_time_ts = message.time # This is a float timestamp
|
||||
@@ -537,7 +530,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
if not chat_id: # Should not happen if above logic is correct
|
||||
continue
|
||||
|
||||
# Update name_mapping
|
||||
# Update name_mapping(仅用于展示聊天名称)
|
||||
try:
|
||||
if chat_id in self.name_mapping:
|
||||
if chat_name != self.name_mapping[chat_id][0] and message_time_ts > self.name_mapping[chat_id][1]:
|
||||
@@ -549,19 +542,30 @@ class StatisticOutputTask(AsyncTask):
|
||||
# 重置为正确的格式
|
||||
self.name_mapping[chat_id] = (chat_name, message_time_ts)
|
||||
|
||||
# 检查是否是bot发送的消息(回复)
|
||||
is_bot_reply = False
|
||||
if bot_qq_account and message.user_id == bot_qq_account:
|
||||
is_bot_reply = True
|
||||
|
||||
for idx, (_, period_start_dt) in enumerate(collect_period):
|
||||
if message_time_ts >= period_start_dt.timestamp():
|
||||
for period_key, _ in collect_period[idx:]:
|
||||
stats[period_key][TOTAL_MSG_CNT] += 1
|
||||
stats[period_key][MSG_CNT_BY_CHAT][chat_id] += 1
|
||||
if is_bot_reply:
|
||||
stats[period_key][TOTAL_REPLY_CNT] += 1
|
||||
break
|
||||
|
||||
# 使用 ActionRecords 中的 reply 动作次数作为回复数基准
|
||||
try:
|
||||
action_query_start_timestamp = collect_period[-1][1].timestamp()
|
||||
for action in ActionRecords.select().where(ActionRecords.time >= action_query_start_timestamp): # type: ignore
|
||||
# 仅统计已完成的 reply 动作
|
||||
if action.action_name != "reply" or not action.action_done:
|
||||
continue
|
||||
|
||||
action_time_ts = action.time
|
||||
for idx, (_, period_start_dt) in enumerate(collect_period):
|
||||
if action_time_ts >= period_start_dt.timestamp():
|
||||
for period_key, _ in collect_period[idx:]:
|
||||
stats[period_key][TOTAL_REPLY_CNT] += 1
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(f"统计 reply 动作次数失败,将回复数视为 0,错误信息:{e}")
|
||||
|
||||
return stats
|
||||
|
||||
def _collect_all_statistics(self, now: datetime) -> Dict[str, Dict[str, Any]]:
|
||||
@@ -742,7 +746,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
data_fmt = "{:<32} {:>10} {:>12} {:>12} {:>12} {:>9.2f}¥ {:>10.1f} {:>10.1f} {:>12} {:>12}"
|
||||
|
||||
total_replies = stats.get(TOTAL_REPLY_CNT, 0)
|
||||
|
||||
|
||||
output = [
|
||||
"按模型分类统计:",
|
||||
" 模型名称 调用次数 输入Token 输出Token Token总量 累计花费 平均耗时(秒) 标准差(秒) 每次回复平均调用次数 每次回复平均Token数",
|
||||
@@ -755,11 +759,11 @@ class StatisticOutputTask(AsyncTask):
|
||||
cost = stats[COST_BY_MODEL][model_name]
|
||||
avg_time_cost = stats[AVG_TIME_COST_BY_MODEL][model_name]
|
||||
std_time_cost = stats[STD_TIME_COST_BY_MODEL][model_name]
|
||||
|
||||
|
||||
# 计算每次回复平均值
|
||||
avg_count_per_reply = count / total_replies if total_replies > 0 else 0.0
|
||||
avg_tokens_per_reply = tokens / total_replies if total_replies > 0 else 0.0
|
||||
|
||||
|
||||
# 格式化大数字
|
||||
formatted_count = _format_large_number(count)
|
||||
formatted_in_tokens = _format_large_number(in_tokens)
|
||||
@@ -767,7 +771,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
formatted_tokens = _format_large_number(tokens)
|
||||
formatted_avg_count = _format_large_number(avg_count_per_reply) if total_replies > 0 else "N/A"
|
||||
formatted_avg_tokens = _format_large_number(avg_tokens_per_reply) if total_replies > 0 else "N/A"
|
||||
|
||||
|
||||
output.append(
|
||||
data_fmt.format(
|
||||
name,
|
||||
@@ -796,7 +800,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
data_fmt = "{:<32} {:>10} {:>12} {:>12} {:>12} {:>9.2f}¥ {:>10.1f} {:>10.1f} {:>12} {:>12}"
|
||||
|
||||
total_replies = stats.get(TOTAL_REPLY_CNT, 0)
|
||||
|
||||
|
||||
output = [
|
||||
"按模块分类统计:",
|
||||
" 模块名称 调用次数 输入Token 输出Token Token总量 累计花费 平均耗时(秒) 标准差(秒) 每次回复平均调用次数 每次回复平均Token数",
|
||||
@@ -809,11 +813,11 @@ class StatisticOutputTask(AsyncTask):
|
||||
cost = stats[COST_BY_MODULE][module_name]
|
||||
avg_time_cost = stats[AVG_TIME_COST_BY_MODULE][module_name]
|
||||
std_time_cost = stats[STD_TIME_COST_BY_MODULE][module_name]
|
||||
|
||||
|
||||
# 计算每次回复平均值
|
||||
avg_count_per_reply = count / total_replies if total_replies > 0 else 0.0
|
||||
avg_tokens_per_reply = tokens / total_replies if total_replies > 0 else 0.0
|
||||
|
||||
|
||||
# 格式化大数字
|
||||
formatted_count = _format_large_number(count)
|
||||
formatted_in_tokens = _format_large_number(in_tokens)
|
||||
@@ -821,7 +825,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
formatted_tokens = _format_large_number(tokens)
|
||||
formatted_avg_count = _format_large_number(avg_count_per_reply) if total_replies > 0 else "N/A"
|
||||
formatted_avg_tokens = _format_large_number(avg_tokens_per_reply) if total_replies > 0 else "N/A"
|
||||
|
||||
|
||||
output.append(
|
||||
data_fmt.format(
|
||||
name,
|
||||
|
||||
@@ -4,6 +4,8 @@ import time
|
||||
import jieba
|
||||
import json
|
||||
import ast
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
from typing import Optional, Tuple, List, TYPE_CHECKING
|
||||
|
||||
@@ -196,21 +198,54 @@ def split_into_sentences_w_remove_punctuation(text: str) -> list[str]:
|
||||
List[str]: 分割和合并后的句子列表
|
||||
"""
|
||||
# 预处理:处理多余的换行符
|
||||
# 1. 将连续的换行符替换为单个换行符
|
||||
# 1. 将连续的换行符替换为单个换行符(保留换行符用于分割)
|
||||
text = re.sub(r"\n\s*\n+", "\n", text)
|
||||
# 2. 处理换行符和其他分隔符的组合
|
||||
text = re.sub(r"\n\s*([,,。;\s])", r"\1", text)
|
||||
text = re.sub(r"([,,。;\s])\s*\n", r"\1", text)
|
||||
# 2. 处理换行符和其他分隔符的组合(保留换行符,删除其他分隔符)
|
||||
text = re.sub(r"\n\s*([,,。;\s])", r"\n\1", text)
|
||||
text = re.sub(r"([,,。;\s])\s*\n", r"\1\n", text)
|
||||
|
||||
# 处理两个汉字中间的换行符
|
||||
text = re.sub(r"([\u4e00-\u9fff])\n([\u4e00-\u9fff])", r"\1。\2", text)
|
||||
# 处理两个汉字中间的换行符(保留换行符,不替换为句号,让换行符强制分割)
|
||||
# text = re.sub(r"([\u4e00-\u9fff])\n([\u4e00-\u9fff])", r"\1。\2", text) # 注释掉,保留换行符用于分割
|
||||
|
||||
len_text = len(text)
|
||||
if len_text < 3:
|
||||
return list(text) if random.random() < 0.01 else [text]
|
||||
|
||||
# 定义分隔符
|
||||
separators = {",", ",", " ", "。", ";"}
|
||||
# 先标记哪些位置位于成对引号内部,避免在引号内部进行句子分割
|
||||
# 支持的引号包括:中英文单/双引号和常见中文书名号/引号
|
||||
quote_chars = {
|
||||
'"',
|
||||
"'",
|
||||
"“",
|
||||
"”",
|
||||
"‘",
|
||||
"’",
|
||||
"「",
|
||||
"」",
|
||||
"『",
|
||||
"』",
|
||||
}
|
||||
inside_quote = [False] * len_text
|
||||
in_quote = False
|
||||
current_quote_char = ""
|
||||
for idx, ch in enumerate(text):
|
||||
if ch in quote_chars:
|
||||
# 遇到引号时切换状态(英文引号本身开闭相同,用同一个字符表示)
|
||||
if not in_quote:
|
||||
in_quote = True
|
||||
current_quote_char = ch
|
||||
inside_quote[idx] = False
|
||||
else:
|
||||
# 只有遇到同一类引号才视为关闭
|
||||
if ch == current_quote_char or ch in {'"', "'"} and current_quote_char in {'"', "'"}:
|
||||
in_quote = False
|
||||
current_quote_char = ""
|
||||
inside_quote[idx] = False
|
||||
else:
|
||||
inside_quote[idx] = in_quote
|
||||
|
||||
# 定义分隔符(包含换行符)
|
||||
separators = {",", ",", " ", "。", ";", "\n"}
|
||||
segments = []
|
||||
current_segment = ""
|
||||
|
||||
@@ -219,24 +254,42 @@ def split_into_sentences_w_remove_punctuation(text: str) -> list[str]:
|
||||
while i < len(text):
|
||||
char = text[i]
|
||||
if char in separators:
|
||||
# 检查分割条件:如果空格左右都是英文字母、数字,或数字和英文之间,则不分割(仅对空格应用此规则)
|
||||
can_split = True
|
||||
if 0 < i < len(text) - 1:
|
||||
prev_char = text[i - 1]
|
||||
next_char = text[i + 1]
|
||||
# 只对空格应用"不分割数字和数字、数字和英文、英文和数字、英文和英文之间的空格"规则
|
||||
if char == " ":
|
||||
prev_is_alnum = prev_char.isdigit() or is_english_letter(prev_char)
|
||||
next_is_alnum = next_char.isdigit() or is_english_letter(next_char)
|
||||
if prev_is_alnum and next_is_alnum:
|
||||
can_split = False
|
||||
# 引号内部一律不作为分割点(包括换行)
|
||||
if inside_quote[i]:
|
||||
can_split = False
|
||||
else:
|
||||
# 换行符在不在引号内时都强制分割
|
||||
if char == "\n":
|
||||
can_split = True
|
||||
else:
|
||||
# 检查分割条件
|
||||
can_split = True
|
||||
# 检查分隔符左右是否有冒号(中英文),如果有则不分割
|
||||
if i > 0:
|
||||
prev_char = text[i - 1]
|
||||
if prev_char in {":", ":"}:
|
||||
can_split = False
|
||||
if i < len(text) - 1:
|
||||
next_char = text[i + 1]
|
||||
if next_char in {":", ":"}:
|
||||
can_split = False
|
||||
|
||||
# 如果左右没有冒号,再检查空格的特殊情况
|
||||
if can_split and char == " " and i > 0 and i < len(text) - 1:
|
||||
prev_char = text[i - 1]
|
||||
next_char = text[i + 1]
|
||||
# 不分割数字和数字、数字和英文、英文和数字、英文和英文之间的空格
|
||||
prev_is_alnum = prev_char.isdigit() or is_english_letter(prev_char)
|
||||
next_is_alnum = next_char.isdigit() or is_english_letter(next_char)
|
||||
if prev_is_alnum and next_is_alnum:
|
||||
can_split = False
|
||||
|
||||
if can_split:
|
||||
# 只有当当前段不为空时才添加
|
||||
if current_segment:
|
||||
segments.append((current_segment, char))
|
||||
# 如果当前段为空,但分隔符是空格,则也添加一个空段(保留空格)
|
||||
elif char == " ":
|
||||
# 如果当前段为空,但分隔符是空格或换行符,则也添加一个空段(保留分隔符)
|
||||
elif char in {" ", "\n"}:
|
||||
segments.append(("", char))
|
||||
current_segment = ""
|
||||
else:
|
||||
@@ -641,6 +694,42 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional["TargetP
|
||||
return is_group_chat, chat_target_info
|
||||
|
||||
|
||||
def record_replyer_action_temp(chat_id: str, reason: str, think_level: int) -> None:
|
||||
"""
|
||||
临时记录replyer动作被选择的信息(仅群聊)
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
reason: 选择理由
|
||||
think_level: 思考深度等级
|
||||
"""
|
||||
try:
|
||||
# 确保data/temp目录存在
|
||||
temp_dir = "data/temp"
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
|
||||
# 创建记录数据
|
||||
record_data = {
|
||||
"chat_id": chat_id,
|
||||
"reason": reason,
|
||||
"think_level": think_level,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
}
|
||||
|
||||
# 生成文件名(使用时间戳避免冲突)
|
||||
timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
||||
filename = f"replyer_action_{timestamp_str}.json"
|
||||
filepath = os.path.join(temp_dir, filename)
|
||||
|
||||
# 写入文件
|
||||
with open(filepath, "w", encoding="utf-8") as f:
|
||||
json.dump(record_data, f, ensure_ascii=False, indent=2)
|
||||
|
||||
logger.debug(f"已记录replyer动作选择: chat_id={chat_id}, think_level={think_level}")
|
||||
except Exception as e:
|
||||
logger.warning(f"记录replyer动作选择失败: {e}")
|
||||
|
||||
|
||||
def assign_message_ids(messages: List[DatabaseMessages]) -> List[Tuple[str, DatabaseMessages]]:
|
||||
"""
|
||||
为消息列表中的每个消息分配唯一的简短随机ID
|
||||
|
||||
@@ -130,12 +130,10 @@ class ImageManager:
|
||||
try:
|
||||
# 清理Images表中type为emoji的记录
|
||||
deleted_images = Images.delete().where(Images.type == "emoji").execute()
|
||||
|
||||
|
||||
# 清理ImageDescriptions表中type为emoji的记录
|
||||
deleted_descriptions = (
|
||||
ImageDescriptions.delete().where(ImageDescriptions.type == "emoji").execute()
|
||||
)
|
||||
|
||||
deleted_descriptions = ImageDescriptions.delete().where(ImageDescriptions.type == "emoji").execute()
|
||||
|
||||
total_deleted = deleted_images + deleted_descriptions
|
||||
if total_deleted > 0:
|
||||
logger.info(
|
||||
@@ -166,7 +164,7 @@ class ImageManager:
|
||||
|
||||
async def _save_emoji_file_if_needed(self, image_base64: str, image_hash: str, image_format: str) -> None:
|
||||
"""如果启用了steal_emoji且表情包未注册,保存文件到data/emoji目录
|
||||
|
||||
|
||||
Args:
|
||||
image_base64: 图片的base64编码
|
||||
image_hash: 图片的MD5哈希值
|
||||
@@ -174,7 +172,7 @@ class ImageManager:
|
||||
"""
|
||||
if not global_config.emoji.steal_emoji:
|
||||
return
|
||||
|
||||
|
||||
try:
|
||||
from src.chat.emoji_system.emoji_manager import EMOJI_DIR
|
||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||
@@ -236,12 +234,16 @@ class ImageManager:
|
||||
# 优先使用情感标签,如果没有则使用详细描述
|
||||
result_text = ""
|
||||
if cache_record.emotion_tags:
|
||||
logger.info(f"[缓存命中] 使用EmojiDescriptionCache表中的情感标签: {cache_record.emotion_tags[:50]}...")
|
||||
logger.info(
|
||||
f"[缓存命中] 使用EmojiDescriptionCache表中的情感标签: {cache_record.emotion_tags[:50]}..."
|
||||
)
|
||||
result_text = f"[表情包:{cache_record.emotion_tags}]"
|
||||
elif cache_record.description:
|
||||
logger.info(f"[缓存命中] 使用EmojiDescriptionCache表中的描述: {cache_record.description[:50]}...")
|
||||
logger.info(
|
||||
f"[缓存命中] 使用EmojiDescriptionCache表中的描述: {cache_record.description[:50]}..."
|
||||
)
|
||||
result_text = f"[表情包:{cache_record.description}]"
|
||||
|
||||
|
||||
# 即使缓存命中,如果启用了steal_emoji,也检查是否需要保存文件
|
||||
if result_text:
|
||||
await self._save_emoji_file_if_needed(image_base64, image_hash, image_format)
|
||||
|
||||
@@ -77,7 +77,7 @@ class DatabaseMessages(BaseDataModel):
|
||||
is_emoji: bool = False,
|
||||
is_picid: bool = False,
|
||||
is_command: bool = False,
|
||||
is_no_read_command: bool = False,
|
||||
intercept_message_level: int = 0,
|
||||
is_notify: bool = False,
|
||||
selected_expressions: Optional[str] = None,
|
||||
user_id: str = "",
|
||||
@@ -120,7 +120,7 @@ class DatabaseMessages(BaseDataModel):
|
||||
self.is_emoji = is_emoji
|
||||
self.is_picid = is_picid
|
||||
self.is_command = is_command
|
||||
self.is_no_read_command = is_no_read_command
|
||||
self.intercept_message_level = intercept_message_level
|
||||
self.is_notify = is_notify
|
||||
|
||||
self.selected_expressions = selected_expressions
|
||||
@@ -188,7 +188,7 @@ class DatabaseMessages(BaseDataModel):
|
||||
"is_emoji": self.is_emoji,
|
||||
"is_picid": self.is_picid,
|
||||
"is_command": self.is_command,
|
||||
"is_no_read_command": self.is_no_read_command,
|
||||
"intercept_message_level": self.intercept_message_level,
|
||||
"is_notify": self.is_notify,
|
||||
"selected_expressions": self.selected_expressions,
|
||||
"user_id": self.user_info.user_id,
|
||||
|
||||
@@ -22,7 +22,7 @@ class MessageAndActionModel(BaseDataModel):
|
||||
is_action_record: bool = field(default=False)
|
||||
action_name: Optional[str] = None
|
||||
is_command: bool = field(default=False)
|
||||
is_no_read_command: bool = field(default=False)
|
||||
intercept_message_level: int = field(default=0)
|
||||
|
||||
@classmethod
|
||||
def from_DatabaseMessages(cls, message: "DatabaseMessages"):
|
||||
@@ -37,7 +37,7 @@ class MessageAndActionModel(BaseDataModel):
|
||||
display_message=message.display_message,
|
||||
chat_info_platform=message.chat_info.platform,
|
||||
is_command=message.is_command,
|
||||
is_no_read_command=getattr(message, "is_no_read_command", False),
|
||||
intercept_message_level=getattr(message, "intercept_message_level", 0),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -170,7 +170,7 @@ class Messages(BaseModel):
|
||||
is_emoji = BooleanField(default=False)
|
||||
is_picid = BooleanField(default=False)
|
||||
is_command = BooleanField(default=False)
|
||||
is_no_read_command = BooleanField(default=False)
|
||||
intercept_message_level = IntegerField(default=0)
|
||||
is_notify = BooleanField(default=False)
|
||||
|
||||
selected_expressions = TextField(null=True)
|
||||
@@ -324,9 +324,9 @@ class Expression(BaseModel):
|
||||
|
||||
# new mode fields
|
||||
context = TextField(null=True)
|
||||
up_content = TextField(null=True)
|
||||
|
||||
content_list = TextField(null=True)
|
||||
style_list = TextField(null=True) # 存储相似的 style,格式与 content_list 相同(JSON 数组)
|
||||
count = IntegerField(default=1)
|
||||
last_active_time = FloatField()
|
||||
chat_id = TextField(index=True)
|
||||
@@ -593,22 +593,41 @@ def _fix_table_constraints(table_name, model, constraints_to_fix):
|
||||
db.execute_sql(f"CREATE TABLE {backup_table} AS SELECT * FROM {table_name}")
|
||||
logger.info(f"已创建备份表 '{backup_table}'")
|
||||
|
||||
# 2. 删除原表
|
||||
# 2. 获取原始行数(在删除表之前)
|
||||
original_count = db.execute_sql(f"SELECT COUNT(*) FROM {backup_table}").fetchone()[0]
|
||||
logger.info(f"备份表 '{backup_table}' 包含 {original_count} 行数据")
|
||||
|
||||
# 3. 删除原表
|
||||
db.execute_sql(f"DROP TABLE {table_name}")
|
||||
logger.info(f"已删除原表 '{table_name}'")
|
||||
|
||||
# 3. 重新创建表(使用当前模型定义)
|
||||
# 4. 重新创建表(使用当前模型定义)
|
||||
db.create_tables([model])
|
||||
logger.info(f"已重新创建表 '{table_name}' 使用新的约束")
|
||||
|
||||
# 4. 从备份表恢复数据
|
||||
# 获取字段列表
|
||||
# 5. 从备份表恢复数据
|
||||
# 获取字段列表,排除主键字段(让数据库自动生成新的主键)
|
||||
fields = list(model._meta.fields.keys())
|
||||
fields_str = ", ".join(fields)
|
||||
# Peewee 默认使用 'id' 作为主键字段名
|
||||
# 尝试获取主键字段名,如果获取失败则默认使用 'id'
|
||||
primary_key_name = "id" # 默认值
|
||||
try:
|
||||
if hasattr(model._meta, "primary_key") and model._meta.primary_key:
|
||||
if hasattr(model._meta.primary_key, "name"):
|
||||
primary_key_name = model._meta.primary_key.name
|
||||
elif isinstance(model._meta.primary_key, str):
|
||||
primary_key_name = model._meta.primary_key
|
||||
except Exception:
|
||||
pass # 如果获取失败,使用默认值 'id'
|
||||
|
||||
# 对于需要从 NOT NULL 改为 NULL 的字段,直接复制数据
|
||||
# 对于需要从 NULL 改为 NOT NULL 的字段,需要处理 NULL 值
|
||||
insert_sql = f"INSERT INTO {table_name} ({fields_str}) SELECT {fields_str} FROM {backup_table}"
|
||||
# 如果字段列表包含主键,则排除它
|
||||
if primary_key_name in fields:
|
||||
fields_without_pk = [f for f in fields if f != primary_key_name]
|
||||
logger.info(f"排除主键字段 '{primary_key_name}',让数据库自动生成新的主键")
|
||||
else:
|
||||
fields_without_pk = fields
|
||||
|
||||
fields_str = ", ".join(fields_without_pk)
|
||||
|
||||
# 检查是否有字段需要从 NULL 改为 NOT NULL
|
||||
null_to_notnull_fields = [
|
||||
@@ -621,7 +640,7 @@ def _fix_table_constraints(table_name, model, constraints_to_fix):
|
||||
|
||||
# 构建更复杂的 SELECT 语句来处理 NULL 值
|
||||
select_fields = []
|
||||
for field_name in fields:
|
||||
for field_name in fields_without_pk:
|
||||
if field_name in null_to_notnull_fields:
|
||||
field_obj = model._meta.fields[field_name]
|
||||
# 根据字段类型设置默认值
|
||||
@@ -642,12 +661,13 @@ def _fix_table_constraints(table_name, model, constraints_to_fix):
|
||||
|
||||
select_str = ", ".join(select_fields)
|
||||
insert_sql = f"INSERT INTO {table_name} ({fields_str}) SELECT {select_str} FROM {backup_table}"
|
||||
else:
|
||||
# 没有需要处理 NULL 的字段,直接复制数据(排除主键)
|
||||
insert_sql = f"INSERT INTO {table_name} ({fields_str}) SELECT {fields_str} FROM {backup_table}"
|
||||
|
||||
db.execute_sql(insert_sql)
|
||||
logger.info(f"已从备份表恢复数据到 '{table_name}'")
|
||||
|
||||
# 5. 验证数据完整性
|
||||
original_count = db.execute_sql(f"SELECT COUNT(*) FROM {backup_table}").fetchone()[0]
|
||||
new_count = db.execute_sql(f"SELECT COUNT(*) FROM {table_name}").fetchone()[0]
|
||||
|
||||
if original_count == new_count:
|
||||
|
||||
@@ -20,6 +20,9 @@ PROJECT_ROOT = logger_file.parent.parent.parent.resolve()
|
||||
_file_handler = None
|
||||
_console_handler = None
|
||||
_ws_handler = None
|
||||
# 全局标志,防止重复初始化
|
||||
_logging_initialized = False
|
||||
_cleanup_task_started = False
|
||||
|
||||
|
||||
def get_file_handler():
|
||||
@@ -869,29 +872,41 @@ def get_logger(name: Optional[str]) -> structlog.stdlib.BoundLogger:
|
||||
return logger
|
||||
|
||||
|
||||
def initialize_logging():
|
||||
def initialize_logging(verbose: bool = True):
|
||||
"""手动初始化日志系统,确保所有logger都使用正确的配置
|
||||
|
||||
在应用程序的早期调用此函数,确保所有模块都使用统一的日志配置
|
||||
|
||||
Args:
|
||||
verbose: 是否输出详细的初始化信息。默认为 True。
|
||||
在 Runner 进程中可以设置为 False 以避免重复的初始化日志。
|
||||
"""
|
||||
global LOG_CONFIG
|
||||
global LOG_CONFIG, _logging_initialized
|
||||
|
||||
# 防止重复初始化(在同一进程内)
|
||||
if _logging_initialized:
|
||||
return
|
||||
|
||||
_logging_initialized = True
|
||||
|
||||
LOG_CONFIG = load_log_config()
|
||||
# print(LOG_CONFIG)
|
||||
configure_third_party_loggers()
|
||||
reconfigure_existing_loggers()
|
||||
|
||||
# 启动日志清理任务
|
||||
start_log_cleanup_task()
|
||||
start_log_cleanup_task(verbose=verbose)
|
||||
|
||||
# 输出初始化信息
|
||||
logger = get_logger("logger")
|
||||
console_level = LOG_CONFIG.get("console_log_level", LOG_CONFIG.get("log_level", "INFO"))
|
||||
file_level = LOG_CONFIG.get("file_log_level", LOG_CONFIG.get("log_level", "INFO"))
|
||||
# 只在 verbose=True 时输出详细的初始化信息
|
||||
if verbose:
|
||||
logger = get_logger("logger")
|
||||
console_level = LOG_CONFIG.get("console_log_level", LOG_CONFIG.get("log_level", "INFO"))
|
||||
file_level = LOG_CONFIG.get("file_log_level", LOG_CONFIG.get("log_level", "INFO"))
|
||||
|
||||
logger.info("日志系统已初始化:")
|
||||
logger.info(f" - 控制台级别: {console_level}")
|
||||
logger.info(f" - 文件级别: {file_level}")
|
||||
logger.info(" - 轮转份数: 30个文件|自动清理: 30天前的日志")
|
||||
logger.info("日志系统已初始化:")
|
||||
logger.info(f" - 控制台级别: {console_level}")
|
||||
logger.info(f" - 文件级别: {file_level}")
|
||||
logger.info(" - 轮转份数: 30个文件|自动清理: 30天前的日志")
|
||||
|
||||
|
||||
def cleanup_old_logs():
|
||||
@@ -924,8 +939,19 @@ def cleanup_old_logs():
|
||||
logger.error(f"清理旧日志文件时出错: {e}")
|
||||
|
||||
|
||||
def start_log_cleanup_task():
|
||||
"""启动日志清理任务"""
|
||||
def start_log_cleanup_task(verbose: bool = True):
|
||||
"""启动日志清理任务
|
||||
|
||||
Args:
|
||||
verbose: 是否输出启动信息。默认为 True。
|
||||
"""
|
||||
global _cleanup_task_started
|
||||
|
||||
# 防止重复启动清理任务
|
||||
if _cleanup_task_started:
|
||||
return
|
||||
|
||||
_cleanup_task_started = True
|
||||
|
||||
def cleanup_task():
|
||||
while True:
|
||||
@@ -935,8 +961,9 @@ def start_log_cleanup_task():
|
||||
cleanup_thread = threading.Thread(target=cleanup_task, daemon=True)
|
||||
cleanup_thread.start()
|
||||
|
||||
logger = get_logger("logger")
|
||||
logger.info("已启动日志清理任务,将自动清理30天前的日志文件(轮转份数限制: 30个文件)")
|
||||
if verbose:
|
||||
logger = get_logger("logger")
|
||||
logger.info("已启动日志清理任务,将自动清理30天前的日志文件(轮转份数限制: 30个文件)")
|
||||
|
||||
|
||||
def shutdown_logging():
|
||||
|
||||
@@ -15,14 +15,18 @@ def get_global_api() -> MessageServer: # sourcery skip: extract-method
|
||||
# 检查maim_message版本
|
||||
try:
|
||||
maim_message_version = importlib.metadata.version("maim_message")
|
||||
version_compatible = [int(x) for x in maim_message_version.split(".")] >= [0, 3, 3]
|
||||
version_int = [int(x) for x in maim_message_version.split(".")]
|
||||
version_compatible = version_int >= [0, 3, 3]
|
||||
# Check for API Server feature (>= 0.6.0)
|
||||
has_api_server_feature = version_int >= [0, 6, 0]
|
||||
except (importlib.metadata.PackageNotFoundError, ValueError):
|
||||
version_compatible = False
|
||||
has_api_server_feature = False
|
||||
|
||||
# 读取配置项
|
||||
maim_message_config = global_config.maim_message
|
||||
|
||||
# 设置基本参数
|
||||
# 设置基本参数 (Legacy Server Mode)
|
||||
kwargs = {
|
||||
"host": os.environ["HOST"],
|
||||
"port": int(os.environ["PORT"]),
|
||||
@@ -39,21 +43,129 @@ def get_global_api() -> MessageServer: # sourcery skip: extract-method
|
||||
if maim_message_config.auth_token and len(maim_message_config.auth_token) > 0:
|
||||
kwargs["enable_token"] = True
|
||||
|
||||
if maim_message_config.use_custom:
|
||||
# 添加WSS模式支持
|
||||
del kwargs["app"]
|
||||
kwargs["host"] = maim_message_config.host
|
||||
kwargs["port"] = maim_message_config.port
|
||||
kwargs["mode"] = maim_message_config.mode
|
||||
if maim_message_config.use_wss:
|
||||
if maim_message_config.cert_file:
|
||||
kwargs["ssl_certfile"] = maim_message_config.cert_file
|
||||
if maim_message_config.key_file:
|
||||
kwargs["ssl_keyfile"] = maim_message_config.key_file
|
||||
kwargs["enable_custom_uvicorn_logger"] = False
|
||||
# Removed legacy custom config block (use_custom) as requested.
|
||||
kwargs["enable_custom_uvicorn_logger"] = False
|
||||
|
||||
global_api = MessageServer(**kwargs)
|
||||
if version_compatible and maim_message_config.auth_token:
|
||||
for token in maim_message_config.auth_token:
|
||||
global_api.add_valid_token(token)
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# Additional API Server Configuration (maim_message >= 6.0)
|
||||
# ---------------------------------------------------------------------
|
||||
enable_api_server = maim_message_config.enable_api_server
|
||||
|
||||
# 如果版本支持且启用了API Server,则初始化额外服务器
|
||||
if has_api_server_feature and enable_api_server:
|
||||
try:
|
||||
from maim_message.server import WebSocketServer, ServerConfig
|
||||
from maim_message.message import APIMessageBase
|
||||
|
||||
api_logger = get_logger("maim_message_api_server")
|
||||
|
||||
# 1. Prepare Config
|
||||
api_server_host = maim_message_config.api_server_host
|
||||
api_server_port = maim_message_config.api_server_port
|
||||
use_wss = maim_message_config.api_server_use_wss
|
||||
|
||||
server_config = ServerConfig(
|
||||
host=api_server_host,
|
||||
port=api_server_port,
|
||||
ssl_enabled=use_wss,
|
||||
ssl_certfile=maim_message_config.api_server_cert_file if use_wss else None,
|
||||
ssl_keyfile=maim_message_config.api_server_key_file if use_wss else None,
|
||||
)
|
||||
|
||||
# 2. Setup Auth Handler
|
||||
async def auth_handler(metadata: dict) -> bool:
|
||||
allowed_keys = maim_message_config.api_server_allowed_api_keys
|
||||
# If list is empty/None, allow all (default behavior of returning True)
|
||||
if not allowed_keys:
|
||||
return True
|
||||
|
||||
api_key = metadata.get("api_key")
|
||||
if api_key in allowed_keys:
|
||||
return True
|
||||
|
||||
api_logger.warning(f"Rejected connection with invalid API Key: {api_key}")
|
||||
return False
|
||||
|
||||
server_config.on_auth = auth_handler
|
||||
|
||||
# 3. Setup Message Bridge
|
||||
# Initialize refined route map if not exists
|
||||
if not hasattr(global_api, "platform_map"):
|
||||
global_api.platform_map = {}
|
||||
|
||||
async def bridge_message_handler(message: APIMessageBase, metadata: dict):
|
||||
# Bridge message to the main bot logic
|
||||
# We convert APIMessageBase to dict to be compatible with legacy handlers
|
||||
# that MainBot (ChatManager) expects.
|
||||
msg_dict = message.to_dict()
|
||||
|
||||
# Compatibility Layer: Flatten sender_info to top-level user_info/group_info
|
||||
# Legacy MessageBase expects message_info to have user_info and group_info directly.
|
||||
if "message_info" in msg_dict:
|
||||
msg_info = msg_dict["message_info"]
|
||||
sender_info = msg_info.get("sender_info")
|
||||
if sender_info:
|
||||
# If direct user_info/group_info are missing, populate them from sender_info
|
||||
if "user_info" not in msg_info and (ui := sender_info.get("user_info")):
|
||||
msg_info["user_info"] = ui
|
||||
|
||||
if "group_info" not in msg_info and (gi := sender_info.get("group_info")):
|
||||
msg_info["group_info"] = gi
|
||||
|
||||
# Route Caching Logic: Simply map platform to API Key
|
||||
# This allows us to send messages back to the correct API client for this platform
|
||||
try:
|
||||
api_key = metadata.get("api_key")
|
||||
if api_key:
|
||||
platform = msg_info.get("platform")
|
||||
if platform:
|
||||
global_api.platform_map[platform] = api_key
|
||||
except Exception as e:
|
||||
api_logger.warning(f"Failed to update platform map: {e}")
|
||||
|
||||
# Compatibility Layer: Ensure raw_message exists (even if None) as it's part of MessageBase
|
||||
if "raw_message" not in msg_dict:
|
||||
msg_dict["raw_message"] = None
|
||||
|
||||
await global_api.process_message(msg_dict)
|
||||
|
||||
server_config.on_message = bridge_message_handler
|
||||
|
||||
# 4. Initialize Server
|
||||
extra_server = WebSocketServer(config=server_config)
|
||||
|
||||
# 5. Patch global_api lifecycle methods to manage both servers
|
||||
original_run = global_api.run
|
||||
original_stop = global_api.stop
|
||||
|
||||
async def patched_run():
|
||||
api_logger.info(f"Starting Additional API Server on {api_server_host}:{api_server_port} (WSS: {use_wss})")
|
||||
# Start the extra server (non-blocking start)
|
||||
await extra_server.start()
|
||||
# Run the original legacy server (this usually keeps running)
|
||||
await original_run()
|
||||
|
||||
async def patched_stop():
|
||||
api_logger.info("Stopping Additional API Server...")
|
||||
await extra_server.stop()
|
||||
await original_stop()
|
||||
|
||||
global_api.run = patched_run
|
||||
global_api.stop = patched_stop
|
||||
|
||||
# Attach for reference
|
||||
global_api.extra_server = extra_server
|
||||
|
||||
except ImportError:
|
||||
get_logger("maim_message").error("Cannot import maim_message.server components. Is maim_message >= 0.6.0 installed?")
|
||||
except Exception as e:
|
||||
get_logger("maim_message").error(f"Failed to initialize Additional API Server: {e}")
|
||||
import traceback
|
||||
get_logger("maim_message").debug(traceback.format_exc())
|
||||
|
||||
return global_api
|
||||
|
||||
@@ -25,7 +25,7 @@ def find_messages(
|
||||
limit_mode: str = "latest",
|
||||
filter_bot=False,
|
||||
filter_command=False,
|
||||
filter_no_read_command=False,
|
||||
filter_intercept_message_level: Optional[int] = None,
|
||||
) -> List[DatabaseMessages]:
|
||||
"""
|
||||
根据提供的过滤器、排序和限制条件查找消息。
|
||||
@@ -85,8 +85,9 @@ def find_messages(
|
||||
# 使用按位取反构造 Peewee 的 NOT 条件,避免直接与 False 比较
|
||||
query = query.where(~Messages.is_command)
|
||||
|
||||
if filter_no_read_command:
|
||||
query = query.where(~Messages.is_no_read_command)
|
||||
if filter_intercept_message_level is not None:
|
||||
# 过滤掉所有 intercept_message_level > filter_intercept_message_level 的消息
|
||||
query = query.where(Messages.intercept_message_level <= filter_intercept_message_level)
|
||||
|
||||
if limit > 0:
|
||||
if limit_mode == "earliest":
|
||||
|
||||
@@ -4,6 +4,7 @@ TOML 工具函数
|
||||
提供 TOML 文件的格式化保存功能,确保数组等元素以美观的多行格式输出。
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Any
|
||||
import tomlkit
|
||||
from tomlkit.items import AoT, Table, Array
|
||||
@@ -33,7 +34,7 @@ def _format_toml_value(obj: Any, threshold: int, depth: int = 0) -> Any:
|
||||
return obj
|
||||
|
||||
# 决定是否多行:仅在顶层且长度超过阈值时
|
||||
should_multiline = (depth == 0 and len(obj) > threshold)
|
||||
should_multiline = depth == 0 and len(obj) > threshold
|
||||
|
||||
# 如果已经是 tomlkit Array,原地修改以保留注释
|
||||
if isinstance(obj, Array):
|
||||
@@ -45,7 +46,7 @@ def _format_toml_value(obj: Any, threshold: int, depth: int = 0) -> Any:
|
||||
# 普通 list:转换为 tomlkit 数组
|
||||
arr = tomlkit.array()
|
||||
arr.multiline(should_multiline)
|
||||
|
||||
|
||||
for item in obj:
|
||||
arr.append(_format_toml_value(item, threshold, depth + 1))
|
||||
return arr
|
||||
@@ -54,14 +55,71 @@ def _format_toml_value(obj: Any, threshold: int, depth: int = 0) -> Any:
|
||||
return obj
|
||||
|
||||
|
||||
def save_toml_with_format(data: Any, file_path: str, multiline_threshold: int = 1) -> None:
|
||||
"""格式化 TOML 数据并保存到文件"""
|
||||
def _update_toml_doc(target: Any, source: Any) -> None:
|
||||
"""
|
||||
递归合并字典,将 source 的值更新到 target 中,保留 target 的注释和格式。
|
||||
- 已存在的键:更新值(递归处理嵌套字典)
|
||||
- 新增的键:添加到 target
|
||||
- 跳过 version 字段
|
||||
"""
|
||||
if isinstance(source, list) or not isinstance(source, dict) or not isinstance(target, dict):
|
||||
return
|
||||
|
||||
for key, value in source.items():
|
||||
if key == "version":
|
||||
continue
|
||||
if key in target:
|
||||
# 已存在的键:递归更新或直接赋值
|
||||
target_value = target[key]
|
||||
if isinstance(value, dict) and isinstance(target_value, dict):
|
||||
_update_toml_doc(target_value, value)
|
||||
else:
|
||||
try:
|
||||
target[key] = tomlkit.item(value)
|
||||
except (TypeError, ValueError):
|
||||
target[key] = value
|
||||
else:
|
||||
# 新增的键:添加到 target
|
||||
try:
|
||||
target[key] = tomlkit.item(value)
|
||||
except (TypeError, ValueError):
|
||||
target[key] = value
|
||||
|
||||
|
||||
def save_toml_with_format(
|
||||
data: Any, file_path: str, multiline_threshold: int = 1, preserve_comments: bool = True
|
||||
) -> None:
|
||||
"""
|
||||
格式化 TOML 数据并保存到文件。
|
||||
|
||||
Args:
|
||||
data: 要保存的数据(dict 或 tomlkit 文档)
|
||||
file_path: 保存路径
|
||||
multiline_threshold: 数组多行格式化阈值,-1 表示不格式化
|
||||
preserve_comments: 是否保留原文件的注释和格式(默认 True)
|
||||
若为 True 且文件已存在且 data 不是 tomlkit 文档,会先读取原文件,再将 data 合并进去
|
||||
"""
|
||||
import os
|
||||
from tomlkit import TOMLDocument
|
||||
|
||||
# 如果需要保留注释、文件存在、且 data 不是已有的 tomlkit 文档,先读取原文件再合并
|
||||
if preserve_comments and os.path.exists(file_path) and not isinstance(data, TOMLDocument):
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
doc = tomlkit.load(f)
|
||||
_update_toml_doc(doc, data)
|
||||
data = doc
|
||||
|
||||
formatted = _format_toml_value(data, multiline_threshold) if multiline_threshold >= 0 else data
|
||||
output = tomlkit.dumps(formatted)
|
||||
# 规范化:将 3+ 连续空行压缩为 1 个空行,防止空行累积
|
||||
output = re.sub(r"\n{3,}", "\n\n", output)
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
tomlkit.dump(formatted, f)
|
||||
f.write(output)
|
||||
|
||||
|
||||
def format_toml_string(data: Any, multiline_threshold: int = 1) -> str:
|
||||
"""格式化 TOML 数据并返回字符串"""
|
||||
formatted = _format_toml_value(data, multiline_threshold) if multiline_threshold >= 0 else data
|
||||
return tomlkit.dumps(formatted)
|
||||
output = tomlkit.dumps(formatted)
|
||||
# 规范化:将 3+ 连续空行压缩为 1 个空行,防止空行累积
|
||||
return re.sub(r"\n{3,}", "\n\n", output)
|
||||
|
||||
@@ -60,6 +60,12 @@ class ModelInfo(ConfigBase):
|
||||
price_out: float = field(default=0.0)
|
||||
"""每M token输出价格"""
|
||||
|
||||
temperature: float | None = field(default=None)
|
||||
"""模型级别温度(可选),会覆盖任务配置中的温度"""
|
||||
|
||||
max_tokens: int | None = field(default=None)
|
||||
"""模型级别最大token数(可选),会覆盖任务配置中的max_tokens"""
|
||||
|
||||
force_stream_mode: bool = field(default=False)
|
||||
"""是否强制使用流式输出模式"""
|
||||
|
||||
|
||||
@@ -31,10 +31,10 @@ from src.config.official_configs import (
|
||||
RelationshipConfig,
|
||||
ToolConfig,
|
||||
VoiceConfig,
|
||||
MoodConfig,
|
||||
MemoryConfig,
|
||||
DebugConfig,
|
||||
JargonConfig,
|
||||
DreamConfig,
|
||||
WebUIConfig,
|
||||
)
|
||||
|
||||
from .api_ada_configs import (
|
||||
@@ -57,7 +57,7 @@ TEMPLATE_DIR = os.path.join(PROJECT_ROOT, "template")
|
||||
|
||||
# 考虑到,实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
|
||||
# 对该字段的更新,请严格参照语义化版本规范:https://semver.org/lang/zh-CN/
|
||||
MMC_VERSION = "0.11.6"
|
||||
MMC_VERSION = "0.12.0"
|
||||
|
||||
|
||||
def get_key_comment(toml_table, key):
|
||||
@@ -348,15 +348,15 @@ class Config(ConfigBase):
|
||||
response_post_process: ResponsePostProcessConfig
|
||||
response_splitter: ResponseSplitterConfig
|
||||
telemetry: TelemetryConfig
|
||||
webui: WebUIConfig
|
||||
experimental: ExperimentalConfig
|
||||
maim_message: MaimMessageConfig
|
||||
lpmm_knowledge: LPMMKnowledgeConfig
|
||||
tool: ToolConfig
|
||||
memory: MemoryConfig
|
||||
debug: DebugConfig
|
||||
mood: MoodConfig
|
||||
voice: VoiceConfig
|
||||
jargon: JargonConfig
|
||||
dream: DreamConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -43,10 +43,13 @@ class PersonalityConfig(ConfigBase):
|
||||
"""人格"""
|
||||
|
||||
reply_style: str = ""
|
||||
"""表达风格"""
|
||||
"""默认表达风格"""
|
||||
|
||||
interest: str = ""
|
||||
"""兴趣"""
|
||||
multiple_reply_style: list[str] = field(default_factory=lambda: [])
|
||||
"""可选的多种表达风格列表,当配置不为空时可按概率随机替换 reply_style"""
|
||||
|
||||
multiple_probability: float = 0.0
|
||||
"""每次构建回复时,从 multiple_reply_style 中随机替换 reply_style 的概率(0.0-1.0)"""
|
||||
|
||||
plan_style: str = ""
|
||||
"""说话规则,行为风格"""
|
||||
@@ -79,12 +82,6 @@ class ChatConfig(ConfigBase):
|
||||
max_context_size: int = 18
|
||||
"""上下文长度"""
|
||||
|
||||
interest_rate_mode: Literal["fast", "accurate"] = "fast"
|
||||
"""兴趣值计算模式,fast为快速计算,accurate为精确计算"""
|
||||
|
||||
planner_size: float = 1.5
|
||||
"""副规划器大小,越小,麦麦的动作执行能力越精细,但是消耗更多token,调大可以缓解429类错误"""
|
||||
|
||||
mentioned_bot_reply: bool = True
|
||||
"""是否启用提及必回复"""
|
||||
|
||||
@@ -117,8 +114,13 @@ class ChatConfig(ConfigBase):
|
||||
时间区间支持跨夜,例如 "23:00-02:00"。
|
||||
"""
|
||||
|
||||
include_planner_reasoning: bool = False
|
||||
"""是否将planner推理加入replyer,默认关闭(不加入)"""
|
||||
think_mode: Literal["classic", "deep", "dynamic"] = "classic"
|
||||
"""
|
||||
思考模式配置
|
||||
- classic: 默认think_level为0(轻量回复,不需要思考和回忆)
|
||||
- deep: 默认think_level为1(深度回复,需要进行回忆和思考)
|
||||
- dynamic: think_level由planner动态给出(根据planner返回的think_level决定)
|
||||
"""
|
||||
|
||||
def _parse_stream_config_to_chat_id(self, stream_config_str: str) -> Optional[str]:
|
||||
"""与 ChatStream.get_stream_id 一致地从 "platform:id:type" 生成 chat_id。"""
|
||||
@@ -133,14 +135,9 @@ class ChatConfig(ConfigBase):
|
||||
|
||||
is_group = stream_type == "group"
|
||||
|
||||
import hashlib
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
|
||||
if is_group:
|
||||
components = [platform, str(id_str)]
|
||||
else:
|
||||
components = [platform, str(id_str), "private"]
|
||||
key = "_".join(components)
|
||||
return hashlib.md5(key.encode()).hexdigest()
|
||||
return get_chat_manager().get_stream_id(platform, str(id_str), is_group=is_group)
|
||||
|
||||
except (ValueError, IndexError):
|
||||
return None
|
||||
@@ -173,7 +170,11 @@ class ChatConfig(ConfigBase):
|
||||
def get_talk_value(self, chat_id: Optional[str]) -> float:
|
||||
"""根据规则返回当前 chat 的动态 talk_value,未匹配则回退到基础值。"""
|
||||
if not self.enable_talk_value_rules or not self.talk_value_rules:
|
||||
return self.talk_value
|
||||
result = self.talk_value
|
||||
# 防止返回0值,自动转换为0.0001
|
||||
if result == 0:
|
||||
return 0.0000001
|
||||
return result
|
||||
|
||||
now_min = self._now_minutes()
|
||||
|
||||
@@ -199,7 +200,11 @@ class ChatConfig(ConfigBase):
|
||||
start_min, end_min = parsed
|
||||
if self._in_range(now_min, start_min, end_min):
|
||||
try:
|
||||
return float(value)
|
||||
result = float(value)
|
||||
# 防止返回0值,自动转换为0.0001
|
||||
if result == 0:
|
||||
return 0.0000001
|
||||
return result
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
@@ -218,12 +223,20 @@ class ChatConfig(ConfigBase):
|
||||
start_min, end_min = parsed
|
||||
if self._in_range(now_min, start_min, end_min):
|
||||
try:
|
||||
return float(value)
|
||||
result = float(value)
|
||||
# 防止返回0值,自动转换为0.0001
|
||||
if result == 0:
|
||||
return 0.0000001
|
||||
return result
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# 3) 未命中规则返回基础值
|
||||
return self.talk_value
|
||||
result = self.talk_value
|
||||
# 防止返回0值,自动转换为0.0001
|
||||
if result == 0:
|
||||
return 0.0000001
|
||||
return result
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -244,13 +257,21 @@ class MemoryConfig(ConfigBase):
|
||||
max_agent_iterations: int = 5
|
||||
"""Agent最多迭代轮数(最低为1)"""
|
||||
|
||||
agent_timeout_seconds: float = 120.0
|
||||
"""Agent超时时间(秒)"""
|
||||
|
||||
enable_jargon_detection: bool = True
|
||||
"""记忆检索过程中是否启用黑话识别"""
|
||||
|
||||
global_memory: bool = False
|
||||
"""是否允许记忆检索在聊天记录中进行全局查询(忽略当前chat_id,仅对 search_chat_history 等工具生效)"""
|
||||
|
||||
def __post_init__(self):
|
||||
"""验证配置值"""
|
||||
if self.max_agent_iterations < 1:
|
||||
raise ValueError(f"max_agent_iterations 必须至少为1,当前值: {self.max_agent_iterations}")
|
||||
if self.agent_timeout_seconds <= 0:
|
||||
raise ValueError(f"agent_timeout_seconds 必须大于0,当前值: {self.agent_timeout_seconds}")
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -260,20 +281,20 @@ class ExpressionConfig(ConfigBase):
|
||||
learning_list: list[list] = field(default_factory=lambda: [])
|
||||
"""
|
||||
表达学习配置列表,支持按聊天流配置
|
||||
格式: [["chat_stream_id", "use_expression", "enable_learning", learning_intensity], ...]
|
||||
格式: [["chat_stream_id", "use_expression", "enable_learning", "enable_jargon_learning"], ...]
|
||||
|
||||
示例:
|
||||
[
|
||||
["", "enable", "enable", 1.0], # 全局配置:使用表达,启用学习,学习强度1.0
|
||||
["qq:1919810:private", "enable", "enable", 1.5], # 特定私聊配置:使用表达,启用学习,学习强度1.5
|
||||
["qq:114514:private", "enable", "disable", 0.5], # 特定私聊配置:使用表达,禁用学习,学习强度0.5
|
||||
["", "enable", "enable", "enable"], # 全局配置:使用表达,启用学习,启用jargon学习
|
||||
["qq:1919810:private", "enable", "enable", "enable"], # 特定私聊配置:使用表达,启用学习,启用jargon学习
|
||||
["qq:114514:private", "enable", "disable", "disable"], # 特定私聊配置:使用表达,禁用学习,禁用jargon学习
|
||||
]
|
||||
|
||||
说明:
|
||||
- 第一位: chat_stream_id,空字符串表示全局配置
|
||||
- 第二位: 是否使用学到的表达 ("enable"/"disable")
|
||||
- 第三位: 是否学习表达 ("enable"/"disable")
|
||||
- 第四位: 学习强度(浮点数),影响学习频率,最短学习时间间隔 = 300/学习强度(秒)
|
||||
- 第四位: 是否启用jargon学习 ("enable"/"disable")
|
||||
"""
|
||||
|
||||
expression_groups: list[list[str]] = field(default_factory=list)
|
||||
@@ -296,6 +317,19 @@ class ExpressionConfig(ConfigBase):
|
||||
如果列表为空,则所有聊天流都可以进行表达反思(前提是 reflect = true)
|
||||
"""
|
||||
|
||||
all_global_jargon: bool = False
|
||||
"""是否将所有新增的jargon项目默认为全局(is_global=True),chat_id记录第一次存储时的id。注意,此功能关闭后,已经记录的全局黑话不会改变,需要手动删除"""
|
||||
|
||||
enable_jargon_explanation: bool = True
|
||||
"""是否在回复前尝试对上下文中的黑话进行解释(关闭可减少一次LLM调用,仅影响回复前的黑话匹配与解释,不影响黑话学习)"""
|
||||
|
||||
jargon_mode: Literal["context", "planner"] = "context"
|
||||
"""
|
||||
黑话解释来源模式:
|
||||
- "context": 使用上下文自动匹配黑话并解释(原有模式)
|
||||
- "planner": 仅使用 Planner 在 reply 动作中给出的 unknown_words 列表进行黑话检索
|
||||
"""
|
||||
|
||||
def _parse_stream_config_to_chat_id(self, stream_config_str: str) -> Optional[str]:
|
||||
"""
|
||||
解析流配置字符串并生成对应的 chat_id
|
||||
@@ -318,20 +352,15 @@ class ExpressionConfig(ConfigBase):
|
||||
# 判断是否为群聊
|
||||
is_group = stream_type == "group"
|
||||
|
||||
# 使用与 ChatStream.get_stream_id 相同的逻辑生成 chat_id
|
||||
import hashlib
|
||||
# 使用 ChatManager 提供的接口生成 chat_id,避免在此重复实现逻辑
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
|
||||
if is_group:
|
||||
components = [platform, str(id_str)]
|
||||
else:
|
||||
components = [platform, str(id_str), "private"]
|
||||
key = "_".join(components)
|
||||
return hashlib.md5(key.encode()).hexdigest()
|
||||
return get_chat_manager().get_stream_id(platform, str(id_str), is_group=is_group)
|
||||
|
||||
except (ValueError, IndexError):
|
||||
return None
|
||||
|
||||
def get_expression_config_for_chat(self, chat_stream_id: Optional[str] = None) -> tuple[bool, bool, int]:
|
||||
def get_expression_config_for_chat(self, chat_stream_id: Optional[str] = None) -> tuple[bool, bool, bool]:
|
||||
"""
|
||||
根据聊天流ID获取表达配置
|
||||
|
||||
@@ -339,11 +368,11 @@ class ExpressionConfig(ConfigBase):
|
||||
chat_stream_id: 聊天流ID,格式为哈希值
|
||||
|
||||
Returns:
|
||||
tuple: (是否使用表达, 是否学习表达, 学习间隔)
|
||||
tuple: (是否使用表达, 是否学习表达, 是否启用jargon学习)
|
||||
"""
|
||||
if not self.learning_list:
|
||||
# 如果没有配置,使用默认值:启用表达,启用学习,300秒间隔
|
||||
return True, True, 300
|
||||
# 如果没有配置,使用默认值:启用表达,启用学习,启用jargon学习
|
||||
return True, True, True
|
||||
|
||||
# 优先检查聊天流特定的配置
|
||||
if chat_stream_id:
|
||||
@@ -356,10 +385,10 @@ class ExpressionConfig(ConfigBase):
|
||||
if global_expression_config is not None:
|
||||
return global_expression_config
|
||||
|
||||
# 如果都没有匹配,返回默认值
|
||||
return True, True, 300
|
||||
# 如果都没有匹配,返回默认值:启用表达,启用学习,启用jargon学习
|
||||
return True, True, True
|
||||
|
||||
def _get_stream_specific_config(self, chat_stream_id: str) -> Optional[tuple[bool, bool, int]]:
|
||||
def _get_stream_specific_config(self, chat_stream_id: str) -> Optional[tuple[bool, bool, bool]]:
|
||||
"""
|
||||
获取特定聊天流的表达配置
|
||||
|
||||
@@ -367,7 +396,7 @@ class ExpressionConfig(ConfigBase):
|
||||
chat_stream_id: 聊天流ID(哈希值)
|
||||
|
||||
Returns:
|
||||
tuple: (是否使用表达, 是否学习表达, 学习间隔),如果没有配置则返回 None
|
||||
tuple: (是否使用表达, 是否学习表达, 是否启用jargon学习),如果没有配置则返回 None
|
||||
"""
|
||||
for config_item in self.learning_list:
|
||||
if not config_item or len(config_item) < 4:
|
||||
@@ -392,19 +421,19 @@ class ExpressionConfig(ConfigBase):
|
||||
try:
|
||||
use_expression: bool = config_item[1].lower() == "enable"
|
||||
enable_learning: bool = config_item[2].lower() == "enable"
|
||||
learning_intensity: float = float(config_item[3])
|
||||
return use_expression, enable_learning, learning_intensity # type: ignore
|
||||
enable_jargon_learning: bool = config_item[3].lower() == "enable"
|
||||
return use_expression, enable_learning, enable_jargon_learning # type: ignore
|
||||
except (ValueError, IndexError):
|
||||
continue
|
||||
|
||||
return None
|
||||
|
||||
def _get_global_config(self) -> Optional[tuple[bool, bool, int]]:
|
||||
def _get_global_config(self) -> Optional[tuple[bool, bool, bool]]:
|
||||
"""
|
||||
获取全局表达配置
|
||||
|
||||
Returns:
|
||||
tuple: (是否使用表达, 是否学习表达, 学习间隔),如果没有配置则返回 None
|
||||
tuple: (是否使用表达, 是否学习表达, 是否启用jargon学习),如果没有配置则返回 None
|
||||
"""
|
||||
for config_item in self.learning_list:
|
||||
if not config_item or len(config_item) < 4:
|
||||
@@ -415,8 +444,8 @@ class ExpressionConfig(ConfigBase):
|
||||
try:
|
||||
use_expression: bool = config_item[1].lower() == "enable"
|
||||
enable_learning: bool = config_item[2].lower() == "enable"
|
||||
learning_intensity = float(config_item[3])
|
||||
return use_expression, enable_learning, learning_intensity # type: ignore
|
||||
enable_jargon_learning: bool = config_item[3].lower() == "enable"
|
||||
return use_expression, enable_learning, enable_jargon_learning # type: ignore
|
||||
except (ValueError, IndexError):
|
||||
continue
|
||||
|
||||
@@ -431,20 +460,6 @@ class ToolConfig(ConfigBase):
|
||||
"""是否在聊天中启用工具"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class MoodConfig(ConfigBase):
|
||||
"""情绪配置类"""
|
||||
|
||||
enable_mood: bool = True
|
||||
"""是否启用情绪系统"""
|
||||
|
||||
mood_update_threshold: float = 1
|
||||
"""情绪更新阈值,越高,更新越慢"""
|
||||
|
||||
emotion_style: str = "情绪较为稳定,但遭遇特定事件的时候起伏较大"
|
||||
"""情感特征,影响情绪的变化情况"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class VoiceConfig(ConfigBase):
|
||||
"""语音识别配置类"""
|
||||
@@ -582,6 +597,35 @@ class TelemetryConfig(ConfigBase):
|
||||
"""是否启用遥测"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class WebUIConfig(ConfigBase):
|
||||
"""WebUI配置类
|
||||
|
||||
注意: host 和 port 配置已移至环境变量 WEBUI_HOST 和 WEBUI_PORT
|
||||
"""
|
||||
|
||||
enabled: bool = True
|
||||
"""是否启用WebUI"""
|
||||
|
||||
mode: Literal["development", "production"] = "production"
|
||||
"""运行模式:development(开发) 或 production(生产)"""
|
||||
|
||||
anti_crawler_mode: Literal["false", "strict", "loose", "basic"] = "basic"
|
||||
"""防爬虫模式:false(禁用) / strict(严格) / loose(宽松) / basic(基础-只记录不阻止)"""
|
||||
|
||||
allowed_ips: str = "127.0.0.1"
|
||||
"""IP白名单(逗号分隔,支持精确IP、CIDR格式和通配符)"""
|
||||
|
||||
trusted_proxies: str = ""
|
||||
"""信任的代理IP列表(逗号分隔),只有来自这些IP的X-Forwarded-For才被信任"""
|
||||
|
||||
trust_xff: bool = False
|
||||
"""是否启用X-Forwarded-For代理解析(默认false)"""
|
||||
|
||||
secure_cookie: bool = False
|
||||
"""是否启用安全Cookie(仅通过HTTPS传输,默认false)"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class DebugConfig(ConfigBase):
|
||||
"""调试配置类"""
|
||||
@@ -639,29 +683,29 @@ class ExperimentalConfig(ConfigBase):
|
||||
class MaimMessageConfig(ConfigBase):
|
||||
"""maim_message配置类"""
|
||||
|
||||
use_custom: bool = False
|
||||
"""是否使用自定义的maim_message配置"""
|
||||
|
||||
host: str = "127.0.0.1"
|
||||
"""主机地址"""
|
||||
|
||||
port: int = 8090
|
||||
""""端口号"""
|
||||
|
||||
mode: Literal["ws", "tcp"] = "ws"
|
||||
"""连接模式,支持ws和tcp"""
|
||||
|
||||
use_wss: bool = False
|
||||
"""是否使用WSS安全连接"""
|
||||
|
||||
cert_file: str = ""
|
||||
"""SSL证书文件路径,仅在use_wss=True时有效"""
|
||||
|
||||
key_file: str = ""
|
||||
"""SSL密钥文件路径,仅在use_wss=True时有效"""
|
||||
|
||||
auth_token: list[str] = field(default_factory=lambda: [])
|
||||
"""认证令牌,用于API验证,为空则不启用验证"""
|
||||
"""认证令牌,用于旧版API验证,为空则不启用验证"""
|
||||
|
||||
enable_api_server: bool = False
|
||||
"""是否启用额外的新版API Server"""
|
||||
|
||||
api_server_host: str = "0.0.0.0"
|
||||
"""新版API Server主机地址"""
|
||||
|
||||
api_server_port: int = 8090
|
||||
"""新版API Server端口号"""
|
||||
|
||||
api_server_use_wss: bool = False
|
||||
"""新版API Server是否启用WSS"""
|
||||
|
||||
api_server_cert_file: str = ""
|
||||
"""新版API Server SSL证书文件路径"""
|
||||
|
||||
api_server_key_file: str = ""
|
||||
"""新版API Server SSL密钥文件路径"""
|
||||
|
||||
api_server_allowed_api_keys: list[str] = field(default_factory=lambda: [])
|
||||
"""新版API Server允许的API Key列表,为空则允许所有连接"""
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -707,10 +751,107 @@ class LPMMKnowledgeConfig(ConfigBase):
|
||||
embedding_dimension: int = 1024
|
||||
"""嵌入向量维度,应该与模型的输出维度一致"""
|
||||
|
||||
max_embedding_workers: int = 3
|
||||
"""嵌入/抽取并发线程数"""
|
||||
|
||||
embedding_chunk_size: int = 4
|
||||
"""每批嵌入的条数"""
|
||||
|
||||
max_synonym_entities: int = 2000
|
||||
"""同义边参与的实体数上限,超限则跳过"""
|
||||
|
||||
enable_ppr: bool = True
|
||||
"""是否启用PPR,低配机器可关闭"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class JargonConfig(ConfigBase):
|
||||
"""Jargon配置类"""
|
||||
class DreamConfig(ConfigBase):
|
||||
"""Dream配置类"""
|
||||
|
||||
all_global: bool = False
|
||||
"""是否将所有新增的jargon项目默认为全局(is_global=True),chat_id记录第一次存储时的id"""
|
||||
interval_minutes: int = 30
|
||||
"""做梦时间间隔(分钟),默认30分钟"""
|
||||
|
||||
max_iterations: int = 20
|
||||
"""做梦最大轮次,默认20轮"""
|
||||
|
||||
first_delay_seconds: int = 60
|
||||
"""程序启动后首次做梦前的延迟时间(秒),默认60秒"""
|
||||
|
||||
dream_send: str = ""
|
||||
"""
|
||||
做梦结果推送目标,格式为 "platform:user_id"
|
||||
例如: "qq:123456" 表示在做梦结束后,将梦境文本额外发送给该QQ私聊用户。
|
||||
为空字符串时不推送。
|
||||
"""
|
||||
|
||||
dream_time_ranges: list[str] = field(default_factory=lambda: [])
|
||||
"""
|
||||
做梦时间段配置列表,格式:["HH:MM-HH:MM", ...]
|
||||
如果列表为空,则表示全天允许做梦。
|
||||
如果配置了时间段,则只有在这些时间段内才会实际执行做梦流程。
|
||||
时间段外,调度器仍会按间隔检查,但不会进入做梦流程。
|
||||
|
||||
示例:
|
||||
[
|
||||
"09:00-22:00", # 白天允许做梦
|
||||
"23:00-02:00", # 跨夜时间段(23:00到次日02:00)
|
||||
]
|
||||
|
||||
支持跨夜区间,例如 "23:00-02:00" 表示从23:00到次日02:00。
|
||||
"""
|
||||
|
||||
def _now_minutes(self) -> int:
|
||||
"""返回本地时间的分钟数(0-1439)。"""
|
||||
lt = time.localtime()
|
||||
return lt.tm_hour * 60 + lt.tm_min
|
||||
|
||||
def _parse_range(self, range_str: str) -> Optional[tuple[int, int]]:
|
||||
"""解析 "HH:MM-HH:MM" 到 (start_min, end_min)。"""
|
||||
try:
|
||||
start_str, end_str = [s.strip() for s in range_str.split("-")]
|
||||
sh, sm = [int(x) for x in start_str.split(":")]
|
||||
eh, em = [int(x) for x in end_str.split(":")]
|
||||
return sh * 60 + sm, eh * 60 + em
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def _in_range(self, now_min: int, start_min: int, end_min: int) -> bool:
|
||||
"""
|
||||
判断 now_min 是否在 [start_min, end_min] 区间内。
|
||||
支持跨夜:如果 start > end,则表示跨越午夜。
|
||||
"""
|
||||
if start_min <= end_min:
|
||||
return start_min <= now_min <= end_min
|
||||
# 跨夜:例如 23:00-02:00
|
||||
return now_min >= start_min or now_min <= end_min
|
||||
|
||||
def is_in_dream_time(self) -> bool:
|
||||
"""
|
||||
检查当前时间是否在允许做梦的时间段内。
|
||||
如果 dream_time_ranges 为空,则返回 True(全天允许)。
|
||||
"""
|
||||
if not self.dream_time_ranges:
|
||||
return True
|
||||
|
||||
now_min = self._now_minutes()
|
||||
|
||||
for time_range in self.dream_time_ranges:
|
||||
if not isinstance(time_range, str):
|
||||
continue
|
||||
parsed = self._parse_range(time_range)
|
||||
if not parsed:
|
||||
continue
|
||||
start_min, end_min = parsed
|
||||
if self._in_range(now_min, start_min, end_min):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def __post_init__(self):
|
||||
"""验证配置值"""
|
||||
if self.interval_minutes < 1:
|
||||
raise ValueError(f"interval_minutes 必须至少为1,当前值: {self.interval_minutes}")
|
||||
if self.max_iterations < 1:
|
||||
raise ValueError(f"max_iterations 必须至少为1,当前值: {self.max_iterations}")
|
||||
if self.first_delay_seconds < 0:
|
||||
raise ValueError(f"first_delay_seconds 不能为负数,当前值: {self.first_delay_seconds}")
|
||||
|
||||
580
src/dream/dream_agent.py
Normal file
580
src/dream/dream_agent.py
Normal file
@@ -0,0 +1,580 @@
|
||||
import asyncio
|
||||
import random
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from peewee import fn
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.common.database.database_model import ChatHistory
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.llm_models.payload_content.message import MessageBuilder, RoleType, Message
|
||||
from src.plugin_system.apis import llm_api
|
||||
from src.dream.dream_generator import generate_dream_summary
|
||||
|
||||
# dream 工具工厂函数
|
||||
from src.dream.tools.search_chat_history_tool import make_search_chat_history
|
||||
from src.dream.tools.get_chat_history_detail_tool import make_get_chat_history_detail
|
||||
from src.dream.tools.delete_chat_history_tool import make_delete_chat_history
|
||||
from src.dream.tools.create_chat_history_tool import make_create_chat_history
|
||||
from src.dream.tools.update_chat_history_tool import make_update_chat_history
|
||||
from src.dream.tools.finish_maintenance_tool import make_finish_maintenance
|
||||
from src.dream.tools.search_jargon_tool import make_search_jargon
|
||||
from src.dream.tools.delete_jargon_tool import make_delete_jargon
|
||||
from src.dream.tools.update_jargon_tool import make_update_jargon
|
||||
|
||||
logger = get_logger("dream_agent")
|
||||
|
||||
|
||||
def init_dream_prompts() -> None:
|
||||
"""初始化 dream agent 的提示词"""
|
||||
Prompt(
|
||||
"""
|
||||
你的名字是{bot_name},你现在处于"梦境维护模式(dream agent)"。
|
||||
你可以自由地在 ChatHistory 库中探索、整理、创建和删改记录,以帮助自己在未来更好地回忆和理解对话历史。
|
||||
|
||||
本轮要维护的聊天ID:{chat_id}
|
||||
本轮随机选中的起始记忆 ID:{start_memory_id}
|
||||
请优先以这条起始记忆为切入点,先理解它的内容与上下文,再决定如何在其附近进行创建新概括、重写或删除等整理操作;如果起始记忆为空,则由你自行选择合适的切入点。
|
||||
|
||||
你可以使用的工具包括:
|
||||
**ChatHistory 维护工具:**
|
||||
- search_chat_history:根据关键词或参与人搜索该 chat_id 下的历史记忆概括列表
|
||||
- get_chat_history_detail:查看某条概括的详细内容
|
||||
- create_chat_history:根据整理后的理解创建一条新的 ChatHistory 概括记录(主题、概括、关键词、关键信息等)
|
||||
- update_chat_history:在不改变事实的前提下重写或精炼主题、概括、关键词、关键信息
|
||||
- delete_chat_history:删除明显冗余、噪声、错误或无意义的记录,或者非常有时效性的信息,或者无太多有用信息的日常互动。
|
||||
你也可以先用 create_chat_history 创建一条新的综合概括,再对旧的冗余记录执行多次 delete_chat_history 来完成“合并”效果。
|
||||
|
||||
**Jargon(黑话)维护工具(只读,禁止修改):**
|
||||
- search_jargon:根据一个或多个关键词搜索Jargon 记录,通常是含义不明确的词条或者特殊的缩写
|
||||
|
||||
**通用工具:**
|
||||
- finish_maintenance:当你认为当前维护工作已经完成,没有更多需要整理的内容时,调用此工具来结束本次运行
|
||||
|
||||
**工作目标**:
|
||||
- 发现冗余、重复或高度相似的记录,并进行合并或删除;
|
||||
- 发现主题/概括过于含糊、啰嗦或缺少关键信息的记录,进行重写和精简;
|
||||
- summary要尽可能保持有用的信息;
|
||||
- 尽量保持信息的真实与可用性,不要凭空捏造事实。
|
||||
|
||||
**合并准则**
|
||||
- 你可以新建一个记录,然后删除旧记录来实现合并。
|
||||
- 如果两个或多个记录的主题相似,内容是对主题不同方面的信息或讨论,且信息量较少,则可以合并为一条记录。
|
||||
- 如果两个记录冲突,可以根据逻辑保留一个或者进行整合,也可以采取更新的记录,删除旧的记录
|
||||
|
||||
**轮次信息**:
|
||||
- 本次维护最多执行 {max_iterations} 轮
|
||||
- 每轮开始时,系统会告知你当前是第几轮,还剩多少轮
|
||||
- 如果提前完成维护工作,可以调用 finish_maintenance 工具主动结束
|
||||
|
||||
**每一轮的执行方式(必须遵守):**
|
||||
- 第一步:先用一小段中文自然语言,写出你的「思考」和本轮计划(例如要查什么、准备怎么合并/修改);
|
||||
- 第二步:在这段思考之后,再通过工具调用来执行你的计划(可以调用 0~N 个工具);
|
||||
- 第三步:收到工具结果后,在下一轮继续先写出新的思考,再视情况继续调用工具。
|
||||
|
||||
请不要在没有先写出思考的情况下直接调用工具。
|
||||
只输出你的思考内容或工具调用结果,由系统负责真正执行工具调用。
|
||||
""",
|
||||
name="dream_react_head_prompt",
|
||||
)
|
||||
|
||||
|
||||
class DreamTool:
|
||||
"""dream 模块内部使用的简易工具封装"""
|
||||
|
||||
def __init__(self, name: str, description: str, parameters: List[Tuple], execute_func):
|
||||
self.name = name
|
||||
self.description = description
|
||||
self.parameters = parameters
|
||||
self.execute_func = execute_func
|
||||
|
||||
def get_tool_definition(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"parameters": self.parameters,
|
||||
}
|
||||
|
||||
async def execute(self, **kwargs) -> str:
|
||||
return await self.execute_func(**kwargs)
|
||||
|
||||
|
||||
class DreamToolRegistry:
|
||||
def __init__(self) -> None:
|
||||
self.tools: Dict[str, DreamTool] = {}
|
||||
|
||||
def register_tool(self, tool: DreamTool) -> None:
|
||||
"""
|
||||
注册或更新 dream 工具。
|
||||
注意:dream agent 每个 chat_id 会重新初始化工具,这里允许覆盖已有同名工具。
|
||||
"""
|
||||
self.tools[tool.name] = tool
|
||||
logger.info(f"注册/更新 dream 工具: {tool.name}")
|
||||
|
||||
def get_tool(self, name: str) -> Optional[DreamTool]:
|
||||
return self.tools.get(name)
|
||||
|
||||
def get_tool_definitions(self) -> List[Dict[str, Any]]:
|
||||
return [tool.get_tool_definition() for tool in self.tools.values()]
|
||||
|
||||
|
||||
_dream_tool_registry = DreamToolRegistry()
|
||||
|
||||
|
||||
def get_dream_tool_registry() -> DreamToolRegistry:
|
||||
return _dream_tool_registry
|
||||
|
||||
|
||||
def init_dream_tools(chat_id: str) -> None:
|
||||
"""注册 dream agent 可用的 ChatHistory / Jargon 相关工具(限定在当前 chat_id 作用域内)"""
|
||||
from src.llm_models.payload_content.tool_option import ToolParamType
|
||||
|
||||
# 通过工厂函数生成绑定当前 chat_id 的工具实现
|
||||
search_chat_history = make_search_chat_history(chat_id)
|
||||
get_chat_history_detail = make_get_chat_history_detail(chat_id)
|
||||
delete_chat_history = make_delete_chat_history(chat_id)
|
||||
create_chat_history = make_create_chat_history(chat_id)
|
||||
update_chat_history = make_update_chat_history(chat_id)
|
||||
finish_maintenance = make_finish_maintenance(chat_id)
|
||||
|
||||
search_jargon = make_search_jargon(chat_id)
|
||||
delete_jargon = make_delete_jargon(chat_id)
|
||||
update_jargon = make_update_jargon(chat_id)
|
||||
|
||||
_dream_tool_registry.register_tool(
|
||||
DreamTool(
|
||||
"search_chat_history",
|
||||
"根据关键词或参与人查询当前 chat_id 下的 ChatHistory 概览,便于快速定位相关记忆。",
|
||||
[
|
||||
(
|
||||
"keyword",
|
||||
ToolParamType.STRING,
|
||||
"关键词(可选,支持多个关键词,可用空格、逗号等分隔)。",
|
||||
False,
|
||||
None,
|
||||
),
|
||||
("participant", ToolParamType.STRING, "参与人昵称(可选)。", False, None),
|
||||
],
|
||||
search_chat_history,
|
||||
)
|
||||
)
|
||||
|
||||
_dream_tool_registry.register_tool(
|
||||
DreamTool(
|
||||
"get_chat_history_detail",
|
||||
"根据 memory_id 获取单条 ChatHistory 的详细内容,包含主题、概括、关键词、关键信息等字段(不包含原文)。",
|
||||
[
|
||||
("memory_id", ToolParamType.INTEGER, "ChatHistory 主键 ID。", True, None),
|
||||
],
|
||||
get_chat_history_detail,
|
||||
)
|
||||
)
|
||||
|
||||
_dream_tool_registry.register_tool(
|
||||
DreamTool(
|
||||
"delete_chat_history",
|
||||
"根据 memory_id 删除一条 ChatHistory 记录(请谨慎使用)。",
|
||||
[
|
||||
("memory_id", ToolParamType.INTEGER, "需要删除的 ChatHistory 主键 ID。", True, None),
|
||||
],
|
||||
delete_chat_history,
|
||||
)
|
||||
)
|
||||
|
||||
_dream_tool_registry.register_tool(
|
||||
DreamTool(
|
||||
"update_chat_history",
|
||||
"按字段更新 ChatHistory 记录,可用于清理、重写或补充信息。",
|
||||
[
|
||||
("memory_id", ToolParamType.INTEGER, "需要更新的 ChatHistory 主键 ID。", True, None),
|
||||
("theme", ToolParamType.STRING, "新的主题标题,如果不需要修改可不填。", False, None),
|
||||
("summary", ToolParamType.STRING, "新的概括内容,如果不需要修改可不填。", False, None),
|
||||
("keywords", ToolParamType.STRING, "新的关键词 JSON 字符串,如 ['关键词1','关键词2']。", False, None),
|
||||
("key_point", ToolParamType.STRING, "新的关键信息 JSON 字符串,如 ['要点1','要点2']。", False, None),
|
||||
],
|
||||
update_chat_history,
|
||||
)
|
||||
)
|
||||
|
||||
_dream_tool_registry.register_tool(
|
||||
DreamTool(
|
||||
"create_chat_history",
|
||||
"根据整理后的理解创建一条新的 ChatHistory 概括记录(主题、概括、关键词、关键信息等)。",
|
||||
[
|
||||
("theme", ToolParamType.STRING, "新的主题标题(必填)。", True, None),
|
||||
("summary", ToolParamType.STRING, "新的概括内容(必填)。", True, None),
|
||||
(
|
||||
"keywords",
|
||||
ToolParamType.STRING,
|
||||
"新的关键词 JSON 字符串,如 ['关键词1','关键词2'](必填)。",
|
||||
True,
|
||||
None,
|
||||
),
|
||||
(
|
||||
"key_point",
|
||||
ToolParamType.STRING,
|
||||
"新的关键信息 JSON 字符串,如 ['要点1','要点2'](必填)。",
|
||||
True,
|
||||
None,
|
||||
),
|
||||
("start_time", ToolParamType.STRING, "起始时间戳(秒,Unix 时间,必填)。", True, None),
|
||||
("end_time", ToolParamType.STRING, "结束时间戳(秒,Unix 时间,必填)。", True, None),
|
||||
],
|
||||
create_chat_history,
|
||||
)
|
||||
)
|
||||
|
||||
_dream_tool_registry.register_tool(
|
||||
DreamTool(
|
||||
"finish_maintenance",
|
||||
"结束本次 dream 维护任务。当你认为当前 chat_id 下的维护工作已经完成,没有更多需要整理、合并或修改的内容时,调用此工具来主动结束本次运行。",
|
||||
[
|
||||
(
|
||||
"reason",
|
||||
ToolParamType.STRING,
|
||||
"结束维护的原因说明(可选),例如 '已完成所有记录的整理' 或 '当前记录质量良好,无需进一步维护'。",
|
||||
False,
|
||||
None,
|
||||
),
|
||||
],
|
||||
finish_maintenance,
|
||||
)
|
||||
)
|
||||
|
||||
# ==================== Jargon 维护工具 ====================
|
||||
# 注册 Jargon 工具
|
||||
_dream_tool_registry.register_tool(
|
||||
DreamTool(
|
||||
"search_jargon",
|
||||
"根据一个或多个关键词搜索当前 chat_id 相关的 Jargon 记录概览(只包含 is_jargon=True,含全局 Jargon),便于快速理解黑话库。",
|
||||
[
|
||||
("keyword", ToolParamType.STRING, "按一个或多个关键词搜索内容/含义/推断结果(必填)。", True, None),
|
||||
],
|
||||
search_jargon,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
async def run_dream_agent_once(
|
||||
chat_id: str,
|
||||
max_iterations: Optional[int] = None,
|
||||
start_memory_id: Optional[int] = None,
|
||||
) -> None:
|
||||
"""
|
||||
运行一次 dream agent,对指定 chat_id 的 ChatHistory 进行最多 max_iterations 轮的整理。
|
||||
如果 max_iterations 为 None,则使用配置文件中的默认值。
|
||||
"""
|
||||
if max_iterations is None:
|
||||
max_iterations = global_config.dream.max_iterations
|
||||
|
||||
start_ts = time.time()
|
||||
logger.info(f"[dream] 开始对 chat_id={chat_id} 进行 dream 维护,最多迭代 {max_iterations} 轮")
|
||||
|
||||
# 初始化工具(作用域限定在当前 chat_id)
|
||||
init_dream_tools(chat_id)
|
||||
|
||||
tool_registry = get_dream_tool_registry()
|
||||
tool_defs = tool_registry.get_tool_definitions()
|
||||
|
||||
bot_name = global_config.bot.nickname
|
||||
time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
||||
|
||||
head_prompt = await global_prompt_manager.format_prompt(
|
||||
"dream_react_head_prompt",
|
||||
bot_name=bot_name,
|
||||
time_now=time_now,
|
||||
chat_id=chat_id,
|
||||
start_memory_id=start_memory_id if start_memory_id is not None else "无(本轮由你自由选择切入点)",
|
||||
max_iterations=max_iterations,
|
||||
)
|
||||
|
||||
conversation_messages: List[Message] = []
|
||||
|
||||
# 如果提供了起始记忆 ID,则在对话正式开始前,先把这条记忆的详细信息放入上下文,
|
||||
# 避免 LLM 还需要额外调用一次 get_chat_history_detail 才能看到起始记忆内容。
|
||||
if start_memory_id is not None:
|
||||
try:
|
||||
record = ChatHistory.get_or_none(ChatHistory.id == start_memory_id)
|
||||
if record:
|
||||
start_time_str = (
|
||||
time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.start_time))
|
||||
if record.start_time
|
||||
else "未知"
|
||||
)
|
||||
end_time_str = (
|
||||
time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.end_time)) if record.end_time else "未知"
|
||||
)
|
||||
detail_text = (
|
||||
f"ID={record.id}\n"
|
||||
f"chat_id={record.chat_id}\n"
|
||||
f"时间范围={start_time_str} 至 {end_time_str}\n"
|
||||
f"主题={record.theme or '无'}\n"
|
||||
f"关键词={record.keywords or '无'}\n"
|
||||
f"参与者={record.participants or '无'}\n"
|
||||
f"概括={record.summary or '无'}\n"
|
||||
f"关键信息={record.key_point or '无'}"
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"[dream] 预加载起始记忆详情 memory_id={start_memory_id},"
|
||||
f"预览: {detail_text[:200].replace(chr(10), ' ')}"
|
||||
)
|
||||
|
||||
start_detail_builder = MessageBuilder()
|
||||
start_detail_builder.set_role(RoleType.User)
|
||||
start_detail_builder.add_text_content(
|
||||
"【起始记忆详情】以下是本轮随机/指定的起始记忆的详细信息,供你在整理时优先参考:\n\n" + detail_text
|
||||
)
|
||||
conversation_messages.append(start_detail_builder.build())
|
||||
else:
|
||||
logger.warning(
|
||||
f"[dream] 提供的 start_memory_id={start_memory_id} 未找到对应 ChatHistory 记录,"
|
||||
"将不预加载起始记忆详情。"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[dream] 预加载起始记忆详情失败 start_memory_id={start_memory_id}: {e}")
|
||||
|
||||
# 注意:message_factory 必须是同步函数,返回消息列表(不能是 async/coroutine)
|
||||
def message_factory(
|
||||
_client,
|
||||
*,
|
||||
_head_prompt: str = head_prompt,
|
||||
_conversation_messages: List[Message] = conversation_messages,
|
||||
) -> List[Message]:
|
||||
messages: List[Message] = []
|
||||
system_builder = MessageBuilder()
|
||||
system_builder.set_role(RoleType.System)
|
||||
system_builder.add_text_content(_head_prompt)
|
||||
messages.append(system_builder.build())
|
||||
messages.extend(_conversation_messages)
|
||||
return messages
|
||||
|
||||
for iteration in range(1, max_iterations + 1):
|
||||
# 在每轮开始时,添加轮次信息到对话中
|
||||
remaining_rounds = max_iterations - iteration + 1
|
||||
round_info_builder = MessageBuilder()
|
||||
round_info_builder.set_role(RoleType.User)
|
||||
round_info_builder.add_text_content(
|
||||
f"【轮次信息】当前是第 {iteration}/{max_iterations} 轮,还剩 {remaining_rounds} 轮。"
|
||||
)
|
||||
conversation_messages.append(round_info_builder.build())
|
||||
|
||||
# 调用 LLM 让其决定是否要使用工具
|
||||
(
|
||||
success,
|
||||
response,
|
||||
reasoning_content,
|
||||
model_name,
|
||||
tool_calls,
|
||||
) = await llm_api.generate_with_model_with_tools_by_message_factory(
|
||||
message_factory,
|
||||
model_config=model_config.model_task_config.tool_use,
|
||||
tool_options=tool_defs,
|
||||
request_type="dream.react",
|
||||
)
|
||||
|
||||
if not success:
|
||||
logger.error(f"[dream] 第 {iteration} 轮 LLM 调用失败: {response}")
|
||||
break
|
||||
|
||||
# 先输出「思考」内容,再输出工具调用信息(思考文本较长,仅在 debug 下输出)
|
||||
thought_log = reasoning_content or (response[:300] if response else "")
|
||||
if thought_log:
|
||||
logger.debug(f"[dream] 第 {iteration} 轮思考内容: {thought_log}")
|
||||
|
||||
logger.info(
|
||||
f"[dream] 第 {iteration} 轮响应,模型={model_name},工具调用数={len(tool_calls) if tool_calls else 0}"
|
||||
)
|
||||
|
||||
assistant_msg: Optional[Message] = None
|
||||
if tool_calls:
|
||||
builder = MessageBuilder()
|
||||
builder.set_role(RoleType.Assistant)
|
||||
if response and response.strip():
|
||||
builder.add_text_content(response)
|
||||
builder.set_tool_calls(tool_calls)
|
||||
assistant_msg = builder.build()
|
||||
elif response and response.strip():
|
||||
builder = MessageBuilder()
|
||||
builder.set_role(RoleType.Assistant)
|
||||
builder.add_text_content(response)
|
||||
assistant_msg = builder.build()
|
||||
|
||||
if assistant_msg:
|
||||
conversation_messages.append(assistant_msg)
|
||||
|
||||
# 如果本轮没有工具调用,仅作为思考记录,继续下一轮
|
||||
if not tool_calls:
|
||||
logger.debug(f"[dream] 第 {iteration} 轮未调用任何工具,仅记录思考。")
|
||||
continue
|
||||
|
||||
# 执行所有工具调用
|
||||
tasks = []
|
||||
finish_maintenance_called = False
|
||||
for tc in tool_calls:
|
||||
tool = tool_registry.get_tool(tc.func_name)
|
||||
if not tool:
|
||||
logger.warning(f"[dream] 未知工具:{tc.func_name}")
|
||||
continue
|
||||
|
||||
# 检测是否调用了 finish_maintenance 工具
|
||||
if tc.func_name == "finish_maintenance":
|
||||
finish_maintenance_called = True
|
||||
|
||||
params = tc.args or {}
|
||||
|
||||
async def _run_single(t: DreamTool, p: Dict[str, Any], call_id: str, it: int):
|
||||
try:
|
||||
result = await t.execute(**p)
|
||||
logger.debug(f"[dream] 第 {it} 轮 工具 {t.name} 执行完成")
|
||||
return call_id, result
|
||||
except Exception as e:
|
||||
logger.error(f"[dream] 工具 {t.name} 执行失败: {e}")
|
||||
return call_id, f"工具 {t.name} 执行失败: {e}"
|
||||
|
||||
tasks.append(_run_single(tool, params, tc.call_id, iteration))
|
||||
|
||||
if not tasks:
|
||||
continue
|
||||
|
||||
tool_results = await asyncio.gather(*tasks, return_exceptions=False)
|
||||
|
||||
# 将工具结果作为 Tool 消息追加
|
||||
for call_id, obs in tool_results:
|
||||
tool_builder = MessageBuilder()
|
||||
tool_builder.set_role(RoleType.Tool)
|
||||
tool_builder.add_text_content(str(obs))
|
||||
tool_builder.add_tool_call(call_id)
|
||||
conversation_messages.append(tool_builder.build())
|
||||
|
||||
# 如果调用了 finish_maintenance 工具,提前结束本次运行
|
||||
if finish_maintenance_called:
|
||||
logger.info(f"[dream] 第 {iteration} 轮检测到 finish_maintenance 工具调用,提前结束本次维护。")
|
||||
break
|
||||
|
||||
cost = time.time() - start_ts
|
||||
logger.info(f"[dream] 对 chat_id={chat_id} 的 dream 维护结束,共迭代 {iteration} 轮,耗时 {cost:.1f} 秒")
|
||||
|
||||
# 生成梦境总结
|
||||
await generate_dream_summary(chat_id, conversation_messages, iteration, cost)
|
||||
|
||||
|
||||
def _pick_random_chat_id() -> Optional[str]:
|
||||
"""从 ChatHistory 中随机选择一个 chat_id,用于 dream agent 本次维护
|
||||
|
||||
规则:
|
||||
- 只在 chat_id 所属的 ChatHistory 记录数 >= 10 时才会参与随机选择;
|
||||
- 记录数不足 10 的 chat_id 将被跳过,不会触发做梦 react。
|
||||
"""
|
||||
try:
|
||||
# 统计每个 chat_id 的记录数,只保留记录数 >= 10 的 chat_id
|
||||
rows = (
|
||||
ChatHistory.select(ChatHistory.chat_id, fn.COUNT(ChatHistory.id).alias("cnt"))
|
||||
.group_by(ChatHistory.chat_id)
|
||||
.having(fn.COUNT(ChatHistory.id) >= 10)
|
||||
.order_by(ChatHistory.chat_id)
|
||||
.limit(200)
|
||||
)
|
||||
eligible_ids = [r.chat_id for r in rows]
|
||||
if not eligible_ids:
|
||||
logger.warning("[dream] ChatHistory 中暂无满足条件(记录数 >= 10)的 chat_id,本轮 dream 任务跳过。")
|
||||
return None
|
||||
chosen = random.choice(eligible_ids)
|
||||
logger.info(f"[dream] 从 {len(eligible_ids)} 个满足条件的 chat_id 中随机选择:{chosen}")
|
||||
return chosen
|
||||
except Exception as e:
|
||||
logger.error(f"[dream] 随机选择 chat_id 失败: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _pick_random_memory_for_chat(chat_id: str) -> Optional[int]:
|
||||
"""
|
||||
在给定 chat_id 下随机选择一条 ChatHistory 记录,作为本轮整理的起始记忆。
|
||||
"""
|
||||
try:
|
||||
rows = (
|
||||
ChatHistory.select(ChatHistory.id)
|
||||
.where(ChatHistory.chat_id == chat_id)
|
||||
.order_by(ChatHistory.start_time.asc())
|
||||
.limit(200)
|
||||
)
|
||||
ids = [r.id for r in rows]
|
||||
if not ids:
|
||||
logger.warning(f"[dream] chat_id={chat_id} 下暂无 ChatHistory 记录,无法选择起始记忆。")
|
||||
return None
|
||||
return random.choice(ids)
|
||||
except Exception as e:
|
||||
logger.error(f"[dream] 在 chat_id={chat_id} 下随机选择起始记忆失败: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def run_dream_cycle_once() -> None:
|
||||
"""
|
||||
单次 dream 周期:
|
||||
- 随机选择一个 chat_id
|
||||
- 在该 chat_id 下随机选择一条 ChatHistory 作为起始记忆
|
||||
- 以这条起始记忆为切入点,对该 chat_id 运行一次 dream agent(最多 15 轮)
|
||||
"""
|
||||
chat_id = _pick_random_chat_id()
|
||||
if not chat_id:
|
||||
return
|
||||
|
||||
start_memory_id = _pick_random_memory_for_chat(chat_id)
|
||||
await run_dream_agent_once(
|
||||
chat_id=chat_id,
|
||||
max_iterations=None, # 使用配置文件中的默认值
|
||||
start_memory_id=start_memory_id,
|
||||
)
|
||||
|
||||
|
||||
async def start_dream_scheduler(
|
||||
first_delay_seconds: Optional[int] = None,
|
||||
interval_seconds: Optional[int] = None,
|
||||
stop_event: Optional[asyncio.Event] = None,
|
||||
) -> None:
|
||||
"""
|
||||
dream 调度器:
|
||||
- 程序启动后先等待 first_delay_seconds(如果为 None,则使用配置文件中的值,默认 60s)
|
||||
- 然后每隔 interval_seconds(如果为 None,则使用配置文件中的值,默认 30 分钟)运行一次 dream agent 周期
|
||||
- 如果提供 stop_event,则在 stop_event 被 set() 后优雅退出循环
|
||||
"""
|
||||
if first_delay_seconds is None:
|
||||
first_delay_seconds = global_config.dream.first_delay_seconds
|
||||
|
||||
if interval_seconds is None:
|
||||
interval_seconds = global_config.dream.interval_minutes * 60
|
||||
|
||||
logger.info(
|
||||
f"[dream] dream 调度器启动:首次延迟 {first_delay_seconds}s,之后每隔 {interval_seconds}s ({interval_seconds // 60} 分钟) 运行一次 dream agent"
|
||||
)
|
||||
|
||||
try:
|
||||
await asyncio.sleep(first_delay_seconds)
|
||||
while True:
|
||||
if stop_event is not None and stop_event.is_set():
|
||||
logger.info("[dream] 收到停止事件,结束 dream 调度器循环。")
|
||||
break
|
||||
|
||||
start_ts = time.time()
|
||||
# 检查当前时间是否在允许做梦的时间段内
|
||||
if not global_config.dream.is_in_dream_time():
|
||||
logger.debug("[dream] 当前时间不在允许做梦的时间段内,跳过本次执行")
|
||||
else:
|
||||
try:
|
||||
await run_dream_cycle_once()
|
||||
except Exception as e:
|
||||
logger.error(f"[dream] 单次 dream 周期执行异常: {e}")
|
||||
|
||||
elapsed = time.time() - start_ts
|
||||
# 保证两次执行之间至少间隔 interval_seconds
|
||||
to_sleep = max(0.0, interval_seconds - elapsed)
|
||||
await asyncio.sleep(to_sleep)
|
||||
except asyncio.CancelledError:
|
||||
logger.info("[dream] dream 调度器任务被取消,准备退出。")
|
||||
raise
|
||||
|
||||
|
||||
# 初始化提示词
|
||||
init_dream_prompts()
|
||||
251
src/dream/dream_generator.py
Normal file
251
src/dream/dream_generator.py
Normal file
@@ -0,0 +1,251 @@
|
||||
import random
|
||||
from typing import List, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.chat.utils.prompt_builder import Prompt
|
||||
from src.llm_models.payload_content.message import RoleType, Message
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.plugin_system.apis import send_api
|
||||
|
||||
logger = get_logger("dream_generator")
|
||||
|
||||
# 初始化 utils 模型用于生成梦境总结
|
||||
_dream_summary_model: Optional[LLMRequest] = None
|
||||
|
||||
# 梦境风格列表(21种)
|
||||
DREAM_STYLES = [
|
||||
"保持诗意和想象力,自由编写",
|
||||
"诗意朦胧,如薄雾笼罩的清晨",
|
||||
"奇幻冒险,充满未知与探索",
|
||||
"温暖怀旧,带着时光的痕迹",
|
||||
"神秘悬疑,暗藏深意",
|
||||
"浪漫唯美,如诗如画",
|
||||
"科幻未来,科技与想象交织",
|
||||
"自然清新,如山林间的微风",
|
||||
"深沉哲思,引人深思",
|
||||
"轻松幽默,充满趣味",
|
||||
"悲伤忧郁,带着淡淡哀愁",
|
||||
"激昂热烈,充满活力",
|
||||
"宁静平和,如湖面般平静",
|
||||
"荒诞离奇,打破常规",
|
||||
"细腻温柔,如春风拂面",
|
||||
"壮阔宏大,气势磅礴",
|
||||
"简约纯粹,返璞归真",
|
||||
"复杂多变,层次丰富",
|
||||
"梦幻迷离,虚实难辨",
|
||||
"现实写意,贴近生活",
|
||||
"抽象概念,超越具象",
|
||||
]
|
||||
|
||||
|
||||
def get_random_dream_styles(count: int = 2) -> List[str]:
|
||||
"""从梦境风格列表中随机选择指定数量的风格"""
|
||||
return random.sample(DREAM_STYLES, min(count, len(DREAM_STYLES)))
|
||||
|
||||
|
||||
def get_dream_summary_model() -> LLMRequest:
|
||||
"""获取用于生成梦境总结的 utils 模型实例"""
|
||||
global _dream_summary_model
|
||||
if _dream_summary_model is None:
|
||||
_dream_summary_model = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils,
|
||||
request_type="dream.summary",
|
||||
)
|
||||
return _dream_summary_model
|
||||
|
||||
|
||||
def init_dream_summary_prompt() -> None:
|
||||
"""初始化梦境总结的提示词"""
|
||||
Prompt(
|
||||
"""
|
||||
你刚刚完成了一次对聊天记录的记忆整理工作。以下是整理过程的摘要:
|
||||
整理过程:
|
||||
{conversation_text}
|
||||
|
||||
请将这次整理涉及的相关信息改写为一个富有诗意和想象力的"梦境",请你仅使用具体的记忆的内容,而不是整理过程编写。
|
||||
要求:
|
||||
1. 使用第一人称视角
|
||||
2. 叙述直白,不要复杂修辞,口语化
|
||||
3. 长度控制在200-800字
|
||||
4. 用中文输出
|
||||
梦境风格:
|
||||
{dream_styles}
|
||||
请直接输出梦境内容,不要添加其他说明:
|
||||
""",
|
||||
name="dream_summary_prompt",
|
||||
)
|
||||
|
||||
|
||||
async def generate_dream_summary(
|
||||
chat_id: str,
|
||||
conversation_messages: List[Message],
|
||||
total_iterations: int,
|
||||
time_cost: float,
|
||||
) -> None:
|
||||
"""生成梦境总结,输出到日志,并根据配置可选地推送给指定用户"""
|
||||
try:
|
||||
import json
|
||||
from src.chat.utils.prompt_builder import global_prompt_manager
|
||||
|
||||
# 第一步:建立工具调用结果映射 (call_id -> result)
|
||||
tool_results_map: dict[str, str] = {}
|
||||
for msg in conversation_messages:
|
||||
if msg.role == RoleType.Tool and msg.tool_call_id:
|
||||
content = ""
|
||||
if msg.content:
|
||||
if isinstance(msg.content, list) and msg.content:
|
||||
content = msg.content[0].text if hasattr(msg.content[0], "text") else str(msg.content[0])
|
||||
else:
|
||||
content = str(msg.content)
|
||||
tool_results_map[msg.tool_call_id] = content
|
||||
|
||||
# 第二步:详细记录所有工具调用操作和结果到日志
|
||||
tool_call_count = 0
|
||||
logger.info(f"[dream][工具调用详情] 开始记录 chat_id={chat_id} 的所有工具调用操作:")
|
||||
|
||||
for msg in conversation_messages:
|
||||
if msg.role == RoleType.Assistant and msg.tool_calls:
|
||||
tool_call_count += 1
|
||||
# 提取思考内容
|
||||
thought_content = ""
|
||||
if msg.content:
|
||||
if isinstance(msg.content, list) and msg.content:
|
||||
thought_content = (
|
||||
msg.content[0].text if hasattr(msg.content[0], "text") else str(msg.content[0])
|
||||
)
|
||||
else:
|
||||
thought_content = str(msg.content)
|
||||
|
||||
logger.info(f"[dream][工具调用详情] === 第 {tool_call_count} 组工具调用 ===")
|
||||
if thought_content:
|
||||
logger.info(
|
||||
f"[dream][工具调用详情] 思考内容:{thought_content[:500]}{'...' if len(thought_content) > 500 else ''}"
|
||||
)
|
||||
|
||||
# 记录每个工具调用的详细信息
|
||||
for idx, tool_call in enumerate(msg.tool_calls, 1):
|
||||
tool_name = tool_call.func_name
|
||||
tool_args = tool_call.args or {}
|
||||
tool_call_id = tool_call.call_id
|
||||
tool_result = tool_results_map.get(tool_call_id, "未找到执行结果")
|
||||
|
||||
# 格式化参数
|
||||
try:
|
||||
args_str = json.dumps(tool_args, ensure_ascii=False, indent=2) if tool_args else "无参数"
|
||||
except Exception:
|
||||
args_str = str(tool_args)
|
||||
|
||||
logger.info(f"[dream][工具调用详情] --- 工具 {idx}: {tool_name} ---")
|
||||
logger.info(f"[dream][工具调用详情] 调用参数:\n{args_str}")
|
||||
logger.info(f"[dream][工具调用详情] 执行结果:\n{tool_result}")
|
||||
logger.info(f"[dream][工具调用详情] {'-' * 60}")
|
||||
|
||||
logger.info(f"[dream][工具调用详情] 共记录了 {tool_call_count} 组工具调用操作")
|
||||
|
||||
# 第三步:构建对话历史摘要(用于生成梦境)
|
||||
conversation_summary = []
|
||||
for msg in conversation_messages:
|
||||
role = msg.role.value if hasattr(msg.role, "value") else str(msg.role)
|
||||
content = ""
|
||||
if msg.content:
|
||||
content = msg.content[0].text if isinstance(msg.content, list) and msg.content else str(msg.content)
|
||||
|
||||
if role == "user" and "轮次信息" in content:
|
||||
# 跳过轮次信息消息
|
||||
continue
|
||||
|
||||
if role == "assistant":
|
||||
# 只保留思考内容,简化工具调用信息
|
||||
if content:
|
||||
# 截取前500字符,避免过长
|
||||
content_preview = content[:500] + ("..." if len(content) > 500 else "")
|
||||
conversation_summary.append(f"[{role}] {content_preview}")
|
||||
elif role == "tool":
|
||||
# 工具结果,只保留关键信息
|
||||
if content:
|
||||
# 截取前300字符
|
||||
content_preview = content[:300] + ("..." if len(content) > 300 else "")
|
||||
conversation_summary.append(f"[工具执行] {content_preview}")
|
||||
|
||||
conversation_text = "\n".join(conversation_summary[-20:]) # 只保留最后20条消息
|
||||
|
||||
# 随机选择2个梦境风格
|
||||
selected_styles = get_random_dream_styles(2)
|
||||
dream_styles_text = "\n".join([f"{i + 1}. {style}" for i, style in enumerate(selected_styles)])
|
||||
|
||||
# 使用 Prompt 管理器格式化梦境生成 prompt
|
||||
dream_prompt = await global_prompt_manager.format_prompt(
|
||||
"dream_summary_prompt",
|
||||
chat_id=chat_id,
|
||||
total_iterations=total_iterations,
|
||||
time_cost=time_cost,
|
||||
conversation_text=conversation_text,
|
||||
dream_styles=dream_styles_text,
|
||||
)
|
||||
|
||||
# 调用 utils 模型生成梦境
|
||||
summary_model = get_dream_summary_model()
|
||||
dream_content, (reasoning, model_name, _) = await summary_model.generate_response_async(
|
||||
dream_prompt,
|
||||
max_tokens=512,
|
||||
temperature=0.8,
|
||||
)
|
||||
|
||||
if dream_content:
|
||||
logger.info(f"[dream][梦境总结] 对 chat_id={chat_id} 的整理过程梦境:\n{dream_content}")
|
||||
|
||||
# 第五步:根据配置决定是否将梦境发送给指定用户
|
||||
try:
|
||||
dream_send_raw = getattr(global_config.dream, "dream_send", "") or ""
|
||||
dream_send = dream_send_raw.strip()
|
||||
if dream_send:
|
||||
parts = dream_send.split(":")
|
||||
if len(parts) != 2:
|
||||
logger.warning(
|
||||
f"[dream][梦境总结] dream_send 配置格式不正确,应为 'platform:user_id',当前值: {dream_send_raw!r}"
|
||||
)
|
||||
else:
|
||||
platform, user_id = parts[0].strip(), parts[1].strip()
|
||||
if not platform or not user_id:
|
||||
logger.warning(
|
||||
f"[dream][梦境总结] dream_send 平台或用户ID为空,当前值: {dream_send_raw!r}"
|
||||
)
|
||||
else:
|
||||
# 默认为私聊会话
|
||||
stream_id = get_chat_manager().get_stream_id(
|
||||
platform=platform,
|
||||
id=str(user_id),
|
||||
is_group=False,
|
||||
)
|
||||
if not stream_id:
|
||||
logger.error(
|
||||
f"[dream][梦境总结] 无法根据 dream_send 找到有效的聊天流,"
|
||||
f"platform={platform!r}, user_id={user_id!r}"
|
||||
)
|
||||
else:
|
||||
ok = await send_api.text_to_stream(
|
||||
dream_content,
|
||||
stream_id=stream_id,
|
||||
typing=False,
|
||||
storage_message=True,
|
||||
)
|
||||
if ok:
|
||||
logger.info(
|
||||
f"[dream][梦境总结] 已将梦境结果发送给配置的目标用户: {platform}:{user_id}"
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"[dream][梦境总结] 向 {platform}:{user_id} 发送梦境结果失败"
|
||||
)
|
||||
except Exception as send_exc:
|
||||
logger.error(f"[dream][梦境总结] 发送梦境结果到配置用户时出错: {send_exc}", exc_info=True)
|
||||
else:
|
||||
logger.warning("[dream][梦境总结] 未能生成梦境总结")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[dream][梦境总结] 生成梦境总结失败: {e}", exc_info=True)
|
||||
|
||||
|
||||
init_dream_summary_prompt()
|
||||
7
src/dream/tools/__init__.py
Normal file
7
src/dream/tools/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
dream agent 工具实现模块。
|
||||
|
||||
每个工具的具体实现放在独立文件中,通过 make_xxx(chat_id) 工厂函数
|
||||
生成绑定到特定 chat_id 的协程函数,由 dream_agent.init_dream_tools 统一注册。
|
||||
"""
|
||||
|
||||
63
src/dream/tools/create_chat_history_tool.py
Normal file
63
src/dream/tools/create_chat_history_tool.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import time
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import ChatHistory
|
||||
|
||||
logger = get_logger("dream_agent")
|
||||
|
||||
|
||||
def make_create_chat_history(chat_id: str):
|
||||
async def create_chat_history(
|
||||
theme: str,
|
||||
summary: str,
|
||||
keywords: str,
|
||||
key_point: str,
|
||||
start_time: float,
|
||||
end_time: float,
|
||||
) -> str:
|
||||
"""创建一条新的 ChatHistory 概括记录(用于整理/合并后的新记忆)"""
|
||||
try:
|
||||
logger.info(
|
||||
f"[dream][tool] 调用 create_chat_history("
|
||||
f"theme={bool(theme)}, summary={bool(summary)}, "
|
||||
f"keywords={bool(keywords)}, key_point={bool(key_point)}, "
|
||||
f"start_time={start_time}, end_time={end_time}) (chat_id={chat_id})"
|
||||
)
|
||||
|
||||
now_ts = time.time()
|
||||
|
||||
# 将传入的 start_time/end_time(如果有)解析为时间戳;否则回退为当前时间
|
||||
def _parse_ts(value, default):
|
||||
if value is None:
|
||||
return default
|
||||
try:
|
||||
return float(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
start_ts = _parse_ts(start_time, now_ts)
|
||||
end_ts = _parse_ts(end_time, now_ts)
|
||||
|
||||
record = ChatHistory.create(
|
||||
chat_id=chat_id,
|
||||
theme=theme,
|
||||
summary=summary,
|
||||
keywords=keywords,
|
||||
key_point=key_point,
|
||||
# 对于由 dream 整理产生的新概括,时间范围优先使用工具提供的时间,否则使用当前时间占位
|
||||
start_time=start_ts,
|
||||
end_time=end_ts,
|
||||
)
|
||||
|
||||
msg = (
|
||||
f"已创建新的 ChatHistory 记录,ID={record.id},"
|
||||
f"theme={record.theme or '无'},summary={'有' if record.summary else '无'}。"
|
||||
)
|
||||
logger.info(f"[dream][tool] create_chat_history 完成: {msg}")
|
||||
return msg
|
||||
except Exception as e:
|
||||
logger.error(f"create_chat_history 失败: {e}")
|
||||
return f"create_chat_history 执行失败: {e}"
|
||||
|
||||
return create_chat_history
|
||||
|
||||
26
src/dream/tools/delete_chat_history_tool.py
Normal file
26
src/dream/tools/delete_chat_history_tool.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import ChatHistory
|
||||
|
||||
logger = get_logger("dream_agent")
|
||||
|
||||
|
||||
def make_delete_chat_history(chat_id: str): # chat_id 目前未直接使用,预留以备扩展
|
||||
async def delete_chat_history(memory_id: int) -> str:
|
||||
"""删除一条 chat_history 记录"""
|
||||
try:
|
||||
logger.info(f"[dream][tool] 调用 delete_chat_history(memory_id={memory_id})")
|
||||
record = ChatHistory.get_or_none(ChatHistory.id == memory_id)
|
||||
if not record:
|
||||
msg = f"未找到 ID={memory_id} 的 ChatHistory 记录,无法删除。"
|
||||
logger.info(f"[dream][tool] delete_chat_history 未找到记录: {msg}")
|
||||
return msg
|
||||
rows = ChatHistory.delete().where(ChatHistory.id == memory_id).execute()
|
||||
msg = f"已删除 ID={memory_id} 的 ChatHistory 记录,受影响行数={rows}。"
|
||||
logger.info(f"[dream][tool] delete_chat_history 完成: {msg}")
|
||||
return msg
|
||||
except Exception as e:
|
||||
logger.error(f"delete_chat_history 失败: {e}")
|
||||
return f"delete_chat_history 执行失败: {e}"
|
||||
|
||||
return delete_chat_history
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user