Merge pull request #1380 from zhangxinhui02/helm-chart

Helm chart `0.11.3-beta` and `0.11.5-beta` Release
This commit is contained in:
SengokuCola
2025-11-24 01:42:26 +08:00
committed by GitHub
183 changed files with 7572 additions and 1768 deletions

View File

@@ -1,11 +1,12 @@
name: Docker Build and Push name: Docker Build and Push
on: on:
schedule:
- cron: '0 0 * * *'
push: push:
branches: branches:
- main - main
- classical - classical
- dev
tags: tags:
- "v*.*.*" - "v*.*.*"
- "v*" - "v*"
@@ -24,6 +25,7 @@ jobs:
- name: Check out git repository - name: Check out git repository
uses: actions/checkout@v4 uses: actions/checkout@v4
with: with:
ref: ${{ github.event_name == 'schedule' && 'dev' || github.ref }}
fetch-depth: 0 fetch-depth: 0
# Clone required dependencies # Clone required dependencies
@@ -77,6 +79,7 @@ jobs:
- name: Check out git repository - name: Check out git repository
uses: actions/checkout@v4 uses: actions/checkout@v4
with: with:
ref: ${{ github.event_name == 'schedule' && 'dev' || github.ref }}
fetch-depth: 0 fetch-depth: 0
# Clone required dependencies # Clone required dependencies

4
.gitignore vendored
View File

@@ -11,7 +11,6 @@ run_maibot_core.bat
run_voice.bat run_voice.bat
run_napcat_adapter.bat run_napcat_adapter.bat
run_ad.bat run_ad.bat
s4u.s4u
llm_tool_benchmark_results.json llm_tool_benchmark_results.json
MaiBot-Napcat-Adapter-main MaiBot-Napcat-Adapter-main
MaiBot-Napcat-Adapter MaiBot-Napcat-Adapter
@@ -27,6 +26,7 @@ run.bat
log_debug/ log_debug/
run_amds.bat run_amds.bat
run_none.bat run_none.bat
docs-mai/
run.py run.py
message_queue_content.txt message_queue_content.txt
message_queue_content.bat message_queue_content.bat
@@ -51,6 +51,7 @@ template/compare/model_config_template.toml
src/plugins/utils/statistic.py src/plugins/utils/statistic.py
CLAUDE.md CLAUDE.md
MaiBot-Dashboard/ MaiBot-Dashboard/
cloudflare-workers/
# Byte-compiled / optimized / DLL files # Byte-compiled / optimized / DLL files
__pycache__/ __pycache__/
@@ -69,7 +70,6 @@ elua.confirmed
.Python .Python
build/ build/
develop-eggs/ develop-eggs/
dist/
downloads/ downloads/
eggs/ eggs/
.eggs/ .eggs/

View File

@@ -25,6 +25,8 @@ WORKDIR /MaiMBot
# 复制依赖列表 # 复制依赖列表
COPY requirements.txt . COPY requirements.txt .
RUN apt-get update && apt-get install -y git
# 从编译阶段复制 LPMM 编译结果 # 从编译阶段复制 LPMM 编译结果
COPY --from=lpmm-builder /usr/local/lib/python3.13/site-packages/ /usr/local/lib/python3.13/site-packages/ COPY --from=lpmm-builder /usr/local/lib/python3.13/site-packages/ /usr/local/lib/python3.13/site-packages/

View File

@@ -25,11 +25,12 @@
**🍔MaiCore 是一个基于大语言模型的可交互智能体** **🍔MaiCore 是一个基于大语言模型的可交互智能体**
- 💭 **智能对话系统**:基于 LLM 的自然语言交互,聊天时机控制 - 💭 **拟人构建的prompt**使用自然语言风格构建回复器的prompt实现近似人类言语习惯的回复
- 🤔 **实时思维系统**:模拟人类思考过程。 - 💭 **行为规划**:在合适的时间说话,使用合适的动作
- 🧠 **表达学习功能**:学习群友的说话风格和表达方式 - 🧠 **表达学习**:学习群友的说话风格和表达方式,学会真实人类的说话风格
- 💝 **情感表达系统**:情绪系统和表情包系统。 - 🤔 **黑话学习**:自主的学习没有见过的词语,尝试理解并认知含义
- 🔌 **强大插件系统**提供API和事件系统可编写强大插件。 - 🔌 **插件系统**提供API和事件系统可编写丰富插件。
- 💝 **情感表达**:情绪系统和表情包系统。
<div style="text-align: center"> <div style="text-align: center">
<a href="https://www.bilibili.com/video/BV1amAneGE3P" target="_blank"> <a href="https://www.bilibili.com/video/BV1amAneGE3P" target="_blank">
@@ -44,29 +45,28 @@
## 🔥 更新和安装 ## 🔥 更新和安装
**最新版本: v0.11.0** ([更新日志](changelogs/changelog.md))
**最新版本: v0.11.5** ([更新日志](changelogs/changelog.md))
可前往 [Release](https://github.com/MaiM-with-u/MaiBot/releases/) 页面下载最新版本 可前往 [Release](https://github.com/MaiM-with-u/MaiBot/releases/) 页面下载最新版本
可前往 [启动器发布页面](https://github.com/MaiM-with-u/mailauncher/releases/)下载最新启动器 可前往 [启动器发布页面](https://github.com/MaiM-with-u/mailauncher/releases/)下载最新启动器
**GitHub 分支说明:** **GitHub 分支说明:**
- `main`: 稳定发布版本(推荐) - `main`: 稳定发布版本(推荐)
- `dev`: 开发测试版本(不稳定) - `dev`: 开发测试版本(不稳定)
- `classical`: 版本(停止维护) - `classical`: 经典版本(停止维护)
### 最新版本部署教程 ### 最新版本部署教程
- [🚀 最新版本部署教程](https://docs.mai-mai.org/manual/deployment/mmc_deploy_windows.html) - 基于 MaiCore 的新版本部署方式(与旧版本不兼容) - [🚀 最新版本部署教程](https://docs.mai-mai.org/manual/deployment/mmc_deploy_windows.html) - 基于 MaiCore 的新版本部署方式(与旧版本不兼容)
> [!WARNING] > [!WARNING]
> - 项目处于活跃开发阶段,功能和 API 可能随时调整。 > - 项目处于活跃开发阶段,功能和 API 可能随时调整。
> - 有问题可以提交 Issue 或者 Discussion > - 有问题可以提交 Issue 。
> - QQ 机器人存在被限制风险,请自行了解,谨慎使用。 > - QQ 机器人存在被限制风险,请自行了解,谨慎使用。
> - 由于程序处于开发中,可能消耗较多 token。 > - 由于程序处于开发中,可能消耗较多 token。
## 麦麦MC项目MaiCraft早期开发
[让麦麦玩MC](https://github.com/MaiM-with-u/Maicraft)
交流群1058573197
## 💬 讨论 ## 💬 讨论
**技术交流群:** **技术交流群:**
@@ -78,7 +78,7 @@
**聊天吹水群:** **聊天吹水群:**
- [麦麦之闲聊群](https://qm.qq.com/q/JxvHZnxyec) - [麦麦之闲聊群](https://qm.qq.com/q/JxvHZnxyec)
**插件开发测试版群:** **插件开发/测试版讨论群:**
- [插件开发群](https://qm.qq.com/q/1036092828) - [插件开发群](https://qm.qq.com/q/1036092828)
## 📚 文档 ## 📚 文档
@@ -87,7 +87,22 @@
- [📚 核心 Wiki 文档](https://docs.mai-mai.org) - 项目最全面的文档中心,你可以了解麦麦有关的一切。 - [📚 核心 Wiki 文档](https://docs.mai-mai.org) - 项目最全面的文档中心,你可以了解麦麦有关的一切。
### 设计理念(原始时代的火花)
## 📚 衍生项目
### MaiCraft早期开发
[MaiCraft](https://github.com/MaiM-with-u/Maicraft)
> 让麦麦具有玩MC能力的项目
> 交流群1058573197
### MoFox_Bot
[MoFox - 仓库地址](https://github.com/MoFox-Studio/MoFox-Core)
> MoFox_Bot 是一个基于 MaiCore 0.10.0 snapshot.5 的增强型 fork 项目
> 我们保留了原项目几乎所有核心功能,并在此基础上进行了深度优化与功能扩展,致力于打造一个更稳定、更智能、更具趣味性的 AI 智能体。
## 设计理念(原始时代的火花)
> **千石可乐说:** > **千石可乐说:**
> - 这个项目最初只是为了给牛牛 bot 添加一点额外的功能,但是功能越写越多,最后决定重写。其目的是为了创造一个活跃在 QQ 群聊的"生命体"。目的并不是为了写一个功能齐全的机器人,而是一个尽可能让人感知到真实的类人存在。 > - 这个项目最初只是为了给牛牛 bot 添加一点额外的功能,但是功能越写越多,最后决定重写。其目的是为了创造一个活跃在 QQ 群聊的"生命体"。目的并不是为了写一个功能齐全的机器人,而是一个尽可能让人感知到真实的类人存在。
@@ -99,7 +114,7 @@
## 🙋 贡献和致谢 ## 🙋 贡献和致谢
你可以阅读[开发文档](https://docs.mai-mai.org/develop/)来更好的了解麦麦! 你可以阅读[开发文档](https://docs.mai-mai.org/develop/)来更好的了解麦麦!
MaiCore 是一个开源项目,我们非常欢迎你的参与。你的贡献,无论是提交 bug 报告、功能需求还是代码 pr都对项目非常宝贵。我们非常感谢你的支持🎉 MaiCore 是一个开源项目,我们非常欢迎你的参与。你的贡献,无论是提交 bug 报告、功能需求还是代码 pr都对项目非常宝贵。我们非常感谢你的支持🎉
但无序的讨论会降低沟通效率,进而影响问题的解决速度,因此在提交任何贡献前,请务必先阅读本项目的[贡献指南](docs/CONTRIBUTE.md)。(待补完) 但无序的讨论会降低沟通效率,进而影响问题的解决速度,因此在提交任何贡献前,请务必先阅读本项目的[贡献指南](docs-src/CONTRIBUTE.md)。(待补完)
### 贡献者 ### 贡献者

30
bot.py
View File

@@ -1,7 +1,6 @@
import asyncio import asyncio
import hashlib import hashlib
import os import os
import sys
import time import time
import platform import platform
import traceback import traceback
@@ -30,7 +29,7 @@ else:
raise raise
# 最早期初始化日志系统,确保所有后续模块都使用正确的日志格式 # 最早期初始化日志系统,确保所有后续模块都使用正确的日志格式
from src.common.logger import initialize_logging, get_logger, shutdown_logging #noqa from src.common.logger import initialize_logging, get_logger, shutdown_logging # noqa
initialize_logging() initialize_logging()
@@ -76,6 +75,15 @@ async def graceful_shutdown(): # sourcery skip: use-named-expression
try: try:
logger.info("正在优雅关闭麦麦...") logger.info("正在优雅关闭麦麦...")
# 关闭 WebUI 服务器
try:
from src.webui.webui_server import get_webui_server
webui_server = get_webui_server()
if webui_server and webui_server._server:
await webui_server.shutdown()
except Exception as e:
logger.warning(f"关闭 WebUI 服务器时出错: {e}")
from src.plugin_system.core.events_manager import events_manager from src.plugin_system.core.events_manager import events_manager
from src.plugin_system.base.component_types import EventType from src.plugin_system.base.component_types import EventType
@@ -107,9 +115,6 @@ async def graceful_shutdown(): # sourcery skip: use-named-expression
logger.info("麦麦优雅关闭完成") logger.info("麦麦优雅关闭完成")
# 关闭日志系统,释放文件句柄
shutdown_logging()
except Exception as e: except Exception as e:
logger.error(f"麦麦关闭失败: {e}", exc_info=True) logger.error(f"麦麦关闭失败: {e}", exc_info=True)
@@ -216,6 +221,11 @@ if __name__ == "__main__":
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop) asyncio.set_event_loop(loop)
# 初始化 WebSocket 日志推送
from src.common.logger import initialize_ws_handler
initialize_ws_handler(loop)
try: try:
# 执行初始化和任务调度 # 执行初始化和任务调度
loop.run_until_complete(main_system.initialize()) loop.run_until_complete(main_system.initialize())
@@ -241,7 +251,7 @@ if __name__ == "__main__":
# 确保 loop 在任何情况下都尝试关闭(如果存在且未关闭) # 确保 loop 在任何情况下都尝试关闭(如果存在且未关闭)
if "loop" in locals() and loop and not loop.is_closed(): if "loop" in locals() and loop and not loop.is_closed():
loop.close() loop.close()
logger.info("事件循环已关闭") print("[主程序] 事件循环已关闭")
# 关闭日志系统,释放文件句柄 # 关闭日志系统,释放文件句柄
try: try:
@@ -249,6 +259,8 @@ if __name__ == "__main__":
except Exception as e: except Exception as e:
print(f"关闭日志系统时出错: {e}") print(f"关闭日志系统时出错: {e}")
# 在程序退出前暂停,让你有机会看到输出 print("[主程序] 准备退出...")
# input("按 Enter 键退出...") # <--- 添加这行
sys.exit(exit_code) # <--- 使用记录的退出码 # 使用 os._exit() 强制退出,避免被阻塞
# 由于已经在 graceful_shutdown() 中完成了所有清理工作,这是安全的
os._exit(exit_code)

View File

@@ -1,6 +1,99 @@
# Changelog # Changelog
## [0.11.5] - 2025-11-21
### 🌟 重大更新
- WebUI 现支持手动重启麦麦,曲线救国版“热重载”
- 新增麦麦 QQ 适配器可视化编辑 UI独立进程需手动上传/下载并覆盖适配器文件)
- 麦麦主程序配置支持可视化模式与源代码模式双模式编辑,后端执行 TOML 校验
- 优化 planner 与 replyer 协同机制,调试日志更细
## [0.11.2] - 2025-11-15 ### 新增
- 表情包管理、人物信息管理、表达方式管理界面手机端适配
- 配置页“重启麦麦”提示
- 详细的debug prompt显示配置
- 麦麦界面操作主题色按钮
- 前端集成 CodeMirrorPython/JSON/TOML 语法高亮)并对 JSON 配置提供自动纠错提示
### 修复
- 表情包缩略图过小
- 添加模型后无法立即显示
- 插件市场分类错误
- 浅色模式下日志查看器背景异常
- 表情包详情窗口无法关闭
- 情绪标签无法正常读取issues/1373
- 模型任务配置界面模型温度不可更改issues/1369
- 模型任务配置界面部分输入框无法删除默认值
- 插件商店默认勾选“仅显示兼容插件”
- 插件商店标签页数量显示错误
- 日志查看器日志换行异常
- 主题色未正确应用
- 侧边栏滚动条问题
- 移除前端 TOML 解析库导致的兼容问题
### 优化
- 首页加载用户体验
- 首页加载速度9s → <1s
- 整个界面的初屏加载速度
### 更新
- 适配器配置支持上传模式与指定路径模式;指定路径模式可免去反复上传/下载配置文件
- 适配器配置界面标签页响应式优化,小屏仅显示简短标签
- 麦麦资源管理所有界面的批量删除能力
- 资源管理所有界面的分页、每页数量选择与页跳转
- 插件市场新增点赞、点踩、评分与下载量统计(基于 Cloudflare保证国内可访问
- 麦麦表情包查看界面的描述部分支持 Markdown 渲染
## [0.11.4] - 2025-11-19
### 🌟 主要更新内容
- **首个官方 Web 管理界面上线**在此版本之前MaiBot 没有 WebUI所有配置需手动编辑 TOML 文件
- **认证系统**Token 安全登录(支持系统生成 64 位随机令牌 / 自定义 Token首次配置向导
- **配置管理(可视化编辑,无需手动改 TOML**
- 麦麦主程序配置:基础设置、人格、表情、黑话、情绪等
- 模型提供商配置OpenAI、Anthropic、DeepSeek、Qwen、Ollama 等
- 模型配置:对话/视觉/嵌入模型分配
- **资源管理**
- 表情包管理:查看、搜索、注册、封禁
- 表达方式管理:查看麦麦的表达记录
- 人物信息管理:查看联系人列表
- **插件系统**
- 插件市场浏览
- 一键安装/卸载/更新
- 版本兼容性检查
- 实时安装进度推送
- **日志查看器**
- WebSocket 实时日志流
- 日志级别过滤DEBUG/INFO/WARNING/ERROR/CRITICAL
- 搜索功能
- **主题定制**
- 浅色/深色/跟随系统
- 12 种主题色6 单色 + 6 渐变色)
- 自定义颜色选择器
- **全局搜索**Cmd/Ctrl + K 快捷键,快速跳转任意页面
### 细节
- **技术栈**
- 前端: React 19 + TypeScript + Vite + TanStack Router + shadcn/ui
- 后端: FastAPI + Uvicorn + WebSocket
- 特点: SPA 单页应用,前后端同端口,静态文件托管
- **使用方式**:参照 template.env 文件更新 .env 文件,添加两个字段:
- `WEBUI_ENABLED=true`
- `WEBUI_MODE=production`
- **WebUI 开源协议**GPLv3
- **WebUI 地址**https://github.com/Mai-with-u/MaiBot-Dashboard
告别手动编辑配置文件,享受现代化图形界面!
## [0.11.3] - 2025-11-18
### 功能更改和修复
- 优化记忆提取策略
- 优化黑话提取
- 优化表达方式学习
- 修改readme
- 加入测试版webui
提示:清理旧的记忆数据和表达方式,表现更好
方法:删除数据库中 expression jargon 和 thinking_back 的全部内容
## [0.11.2] - 2025-11-16
### 🌟 主要功能更改 ### 🌟 主要功能更改
- "海马体Agent"记忆系统上线最新最好的记忆系统默认已接入lpmm - "海马体Agent"记忆系统上线最新最好的记忆系统默认已接入lpmm
- 添加黑话jargon学习系统 - 添加黑话jargon学习系统

View File

@@ -29,7 +29,8 @@ services:
- TZ=Asia/Shanghai - TZ=Asia/Shanghai
# - EULA_AGREE=99f08e0cab0190de853cb6af7d64d4de # 同意EULA # - EULA_AGREE=99f08e0cab0190de853cb6af7d64d4de # 同意EULA
# - PRIVACY_AGREE=9943b855e72199d0f5016ea39052f1b6 # 同意EULA # - PRIVACY_AGREE=9943b855e72199d0f5016ea39052f1b6 # 同意EULA
# ports: ports:
- "18001:8001" # webui端口
# - "8000:8000" # - "8000:8000"
volumes: volumes:
- ./docker-config/mmc/.env:/MaiMBot/.env # 持久化env配置文件 - ./docker-config/mmc/.env:/MaiMBot/.env # 持久化env配置文件

View File

Before

Width:  |  Height:  |  Size: 4.1 KiB

After

Width:  |  Height:  |  Size: 4.1 KiB

View File

Before

Width:  |  Height:  |  Size: 11 KiB

After

Width:  |  Height:  |  Size: 11 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 21 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 4.9 KiB

View File

@@ -1,336 +0,0 @@
# 模型配置指南
本文档将指导您如何配置 `model_config.toml` 文件,该文件用于配置 MaiBot 的各种AI模型和API服务提供商。
## 配置文件结构
配置文件主要包含以下几个部分:
- 版本信息
- API服务提供商配置
- 模型配置
- 模型任务配置
## 1. 版本信息
```toml
[inner]
version = "1.1.1"
```
用于标识配置文件的版本,遵循语义化版本规则。
## 2. API服务提供商配置
### 2.1 基本配置
使用 `[[api_providers]]` 数组配置多个API服务提供商
```toml
[[api_providers]]
name = "DeepSeek" # 服务商名称(自定义)
base_url = "https://api.deepseek.com/v1" # API服务的基础URL
api_key = "your-api-key-here" # API密钥
client_type = "openai" # 客户端类型
max_retry = 2 # 最大重试次数
timeout = 30 # 超时时间(秒)
retry_interval = 10 # 重试间隔(秒)
```
### 2.2 配置参数说明
| 参数 | 必填 | 说明 | 默认值 |
|------|------|------|--------|
| `name` | ✅ | 服务商名称,需要在模型配置中引用 | - |
| `base_url` | ✅ | API服务的基础URL | - |
| `api_key` | ✅ | API密钥请替换为实际密钥 | - |
| `client_type` | ❌ | 客户端类型:`openai`OpenAI格式`gemini`Gemini格式 | `openai` |
| `max_retry` | ❌ | API调用失败时的最大重试次数 | 2 |
| `timeout` | ❌ | API请求超时时间 | 30 |
| `retry_interval` | ❌ | 重试间隔时间(秒) | 10 |
**请注意,对于`client_type`为`gemini`的模型,`retry`字段由`gemini`自己决定。**
### 2.3 支持的服务商示例
#### DeepSeek
```toml
[[api_providers]]
name = "DeepSeek"
base_url = "https://api.deepseek.com/v1"
api_key = "your-deepseek-api-key"
client_type = "openai"
```
#### SiliconFlow
```toml
[[api_providers]]
name = "SiliconFlow"
base_url = "https://api.siliconflow.cn/v1"
api_key = "your-siliconflow-api-key"
client_type = "openai"
```
#### Google Gemini
```toml
[[api_providers]]
name = "Google"
base_url = "https://generativelanguage.googleapis.com/v1beta"
api_key = "your-google-api-key"
client_type = "gemini" # 注意Gemini需要使用特殊客户端
```
## 3. 模型配置
### 3.1 基本模型配置
使用 `[[models]]` 数组配置多个模型:
```toml
[[models]]
model_identifier = "deepseek-chat" # 模型在API服务商中的标识符
name = "deepseek-v3" # 自定义模型名称
api_provider = "DeepSeek" # 引用的API服务商名称
price_in = 2.0 # 输入价格(元/M token
price_out = 8.0 # 输出价格(元/M token
```
### 3.2 高级模型配置
#### 强制流式输出
对于不支持非流式输出的模型:
```toml
[[models]]
model_identifier = "some-model"
name = "custom-name"
api_provider = "Provider"
force_stream_mode = true # 启用强制流式输出
```
#### 额外参数配置`extra_params`
```toml
[[models]]
model_identifier = "Qwen/Qwen3-8B"
name = "qwen3-8b"
api_provider = "SiliconFlow"
[models.extra_params]
enable_thinking = false # 禁用思考
```
这里的 `extra_params` 可以包含任何API服务商支持的额外参数配置**配置时应参考相应的API文档**。
比如上面就是参考SiliconFlow的文档配置配置的`Qwen3`禁用思考参数。
![SiliconFlow文档截图](image-1.png)
以豆包文档为另一个例子
![豆包文档截图](image.png)
得到豆包`"doubao-seed-1-6-250615"`的禁用思考配置方法为
```toml
[[models]]
# 你的模型
[models.extra_params]
thinking = {type = "disabled"} # 禁用思考
```
而对于`gemini`需要单独进行配置
```toml
[[models]]
model_identifier = "gemini-2.5-flash"
name = "gemini-2.5-flash"
api_provider = "Google"
[models.extra_params]
thinking_budget = 0 # 禁用思考
# thinking_budget = -1 由模型自己决定
```
请注意,`extra_params` 的配置应该构成一个合法的TOML字典结构具体内容取决于API服务商的要求。
### 3.3 配置参数说明
| 参数 | 必填 | 说明 |
|------|------|------|
| `model_identifier` | ✅ | API服务商提供的模型标识符 |
| `name` | ✅ | 自定义模型名称,用于在任务配置中引用 |
| `api_provider` | ✅ | 对应的API服务商名称 |
| `price_in` | ❌ | 输入价格(元/M token用于成本统计 |
| `price_out` | ❌ | 输出价格(元/M token用于成本统计 |
| `force_stream_mode` | ❌ | 是否强制使用流式输出 |
| `extra_params` | ❌ | 额外的模型参数配置 |
## 4. 模型任务配置
### utils - 工具模型
用于表情包模块、取名模块、关系模块等核心功能:
```toml
[model_task_config.utils]
model_list = ["siliconflow-deepseek-v3"]
temperature = 0.2
max_tokens = 800
```
### utils_small - 小型工具模型
用于高频率调用的场景,建议使用速度快的小模型:
```toml
[model_task_config.utils_small]
model_list = ["qwen3-8b"]
temperature = 0.7
max_tokens = 800
```
### replyer - 主要回复模型
首要回复模型,也用于表达器和表达方式学习:
```toml
[model_task_config.replyer]
model_list = ["siliconflow-deepseek-v3"]
temperature = 0.2
max_tokens = 800
```
### planner - 决策模型
负责决定MaiBot该做什么
```toml
[model_task_config.planner]
model_list = ["siliconflow-deepseek-v3"]
temperature = 0.3
max_tokens = 800
```
### emotion - 情绪模型
负责MaiBot的情绪变化
```toml
[model_task_config.emotion]
model_list = ["siliconflow-deepseek-v3"]
temperature = 0.3
max_tokens = 800
```
### memory - 记忆模型
```toml
[model_task_config.memory]
model_list = ["qwen3-30b"]
temperature = 0.7
max_tokens = 800
```
### vlm - 视觉语言模型
用于图像识别:
```toml
[model_task_config.vlm]
model_list = ["qwen2.5-vl-72b"]
max_tokens = 800
```
### voice - 语音识别模型
```toml
[model_task_config.voice]
model_list = ["sensevoice-small"]
```
### embedding - 嵌入模型
```toml
[model_task_config.embedding]
model_list = ["bge-m3"]
```
### tool_use - 工具调用模型
需要使用支持工具调用的模型:
```toml
[model_task_config.tool_use]
model_list = ["qwen3-14b"]
temperature = 0.7
max_tokens = 800
```
### lpmm_entity_extract - 实体提取模型
```toml
[model_task_config.lpmm_entity_extract]
model_list = ["siliconflow-deepseek-v3"]
temperature = 0.2
max_tokens = 800
```
### lpmm_rdf_build - RDF构建模型
```toml
[model_task_config.lpmm_rdf_build]
model_list = ["siliconflow-deepseek-v3"]
temperature = 0.2
max_tokens = 800
```
### lpmm_qa - 问答模型
```toml
[model_task_config.lpmm_qa]
model_list = ["deepseek-r1-distill-qwen-32b"]
temperature = 0.7
max_tokens = 800
```
## 5. 配置建议
### 5.1 Temperature 参数选择
| 任务类型 | 推荐温度 | 说明 |
|----------|----------|------|
| 精确任务(工具调用、实体提取) | 0.1-0.3 | 需要准确性和一致性 |
| 创意任务(对话、记忆) | 0.5-0.8 | 需要多样性和创造性 |
| 平衡任务(决策、情绪) | 0.3-0.5 | 平衡准确性和灵活性 |
### 5.2 模型选择建议
| 任务类型 | 推荐模型类型 | 示例 |
|----------|--------------|------|
| 高精度任务 | 大模型 | DeepSeek-V3, GPT-4 |
| 高频率任务 | 小模型 | Qwen3-8B |
| 多模态任务 | 专用模型 | Qwen2.5-VL, SenseVoice |
| 工具调用 | 支持Function Call的模型 | Qwen3-14B |
### 5.3 成本优化
1. **分层使用**:核心功能使用高质量模型,辅助功能使用经济模型
2. **合理配置max_tokens**:根据实际需求设置,避免浪费
3. **选择免费模型**对于测试环境优先使用price为0的模型
## 6. 配置验证
### 6.1 必要检查项
1. ✅ API密钥是否正确配置
2. ✅ 模型标识符是否与API服务商提供的一致
3. ✅ 任务配置中引用的模型名称是否在models中定义
4. ✅ 多模态任务是否配置了对应的专用模型
### 6.2 测试配置
建议在正式使用前:
1. 使用少量测试数据验证配置
2. 检查API调用是否正常
3. 确认成本统计功能正常工作
## 7. 故障排除
### 7.1 常见问题
**问题1**: API调用失败
- 检查API密钥是否正确
- 确认base_url是否可访问
- 检查模型标识符是否正确
**问题2**: 模型未找到
- 确认模型名称在任务配置和模型定义中一致
- 检查api_provider名称是否匹配
**问题3**: 响应异常
- 检查温度参数是否合理0-1之间
- 确认max_tokens设置是否合适
- 验证模型是否支持所需功能
### 7.2 日志查看
查看 `logs/` 目录下的日志文件,寻找相关错误信息。
## 8. 更新和维护
1. **定期更新**: 关注API服务商的模型更新及时调整配置
2. **性能监控**: 监控模型调用的成本和性能
3. **备份配置**: 在修改前备份当前配置文件

View File

@@ -1,7 +1,6 @@
stages: stages:
# - initialize-maibot-git-repo - build-image
- build - package-helm-chart
- package
# 仅在helm-chart分支运行 # 仅在helm-chart分支运行
workflow: workflow:
@@ -9,49 +8,15 @@ workflow:
- if: '$CI_COMMIT_BRANCH == "helm-chart"' - if: '$CI_COMMIT_BRANCH == "helm-chart"'
- when: never - when: never
## 查询并将麦麦仓库的工作区置为最后一个tag的版本 # 构建并推送adapter-cm-generator镜像
#initialize-maibot-git-repo:
# stage: initialize-maibot-git-repo
# image: reg.mikumikumi.xyz/base/git:latest
# cache:
# key: git-repo
# policy: push
# paths:
# - target-repo/
# script:
# - git clone https://github.com/Mai-with-u/MaiBot.git target-repo/
# - cd target-repo/
# - export MAIBOT_VERSION=$(git describe --tags --abbrev=0)
# - echo "Current version is ${MAIBOT_VERSION}"
# - git reset --hard ${MAIBOT_VERSION}
# - echo ${MAIBOT_VERSION} > MAIBOT_VERSION
# - git clone https://github.com/MaiM-with-u/maim_message maim_message
# - git clone https://github.com/MaiM-with-u/MaiMBot-LPMM.git MaiMBot-LPMM
# - ls -al
#
## 构建最后一个tag的麦麦本体的镜像
#build-core:
# stage: build
# image: reg.mikumikumi.xyz/base/kaniko-builder:latest
# cache:
# key: git-repo
# policy: pull
# paths:
# - target-repo/
# script:
# - cd target-repo/
# - export BUILD_CONTEXT=$(pwd)
# - ls -al
# - export BUILD_DESTINATION="reg.mikumikumi.xyz/maibot/maibot:tag-$(cat MAIBOT_VERSION)"
# - build
# 将Helm Chart版本作为tag构建并推送镜像
build-adapter-cm-generator: build-adapter-cm-generator:
stage: build stage: build-image
image: reg.mikumikumi.xyz/base/kaniko-builder:latest image: reg.mikumikumi.xyz/base/kaniko-builder:latest
# rules: variables:
# - changes: BUILD_NO_CACHE: true
# - helm-chart/adapter-cm-generator/** rules:
- changes:
- helm-chart/adapter-cm-generator/**
script: script:
- export BUILD_CONTEXT=helm-chart/adapter-cm-generator - export BUILD_CONTEXT=helm-chart/adapter-cm-generator
- export TMP_DST=reg.mikumikumi.xyz/maibot/adapter-cm-generator - export TMP_DST=reg.mikumikumi.xyz/maibot/adapter-cm-generator
@@ -60,16 +25,36 @@ build-adapter-cm-generator:
- export BUILD_ARGS="--destination ${TMP_DST}:latest" - export BUILD_ARGS="--destination ${TMP_DST}:latest"
- build - build
# 构建并推送core-webui-cm-sync镜像
build-core-webui-cm-sync:
stage: build-image
image: reg.mikumikumi.xyz/base/kaniko-builder:latest
variables:
BUILD_NO_CACHE: true
rules:
- changes:
- helm-chart/core-webui-cm-sync/**
script:
- export BUILD_CONTEXT=helm-chart/core-webui-cm-sync
- export TMP_DST=reg.mikumikumi.xyz/maibot/core-webui-cm-sync
- export CHART_VERSION=$(cat helm-chart/Chart.yaml | grep '^version:' | cut -d' ' -f2)
- export BUILD_DESTINATION="${TMP_DST}:${CHART_VERSION}"
- export BUILD_ARGS="--destination ${TMP_DST}:latest"
- build
# 打包并推送helm chart # 打包并推送helm chart
package-helm-chart: package-helm-chart:
stage: package stage: package-helm-chart
image: reg.mikumikumi.xyz/mirror/helm:latest image: reg.mikumikumi.xyz/mirror/helm:latest
# rules: rules:
# - changes: - changes:
# - helm-chart/files/** - helm-chart/files/**
# - helm-chart/templates/** - helm-chart/templates/**
# - helm-chart/Chart.yaml - helm-chart/.gitignore
# - helm-chart/values.yaml - helm-chart/.helmignore
- helm-chart/Chart.yaml
- helm-chart/README.md
- helm-chart/values.yaml
script: script:
- export CHART_VERSION=$(cat helm-chart/Chart.yaml | grep '^version:' | cut -d' ' -f2) - export CHART_VERSION=$(cat helm-chart/Chart.yaml | grep '^version:' | cut -d' ' -f2)
- helm registry login reg.mikumikumi.xyz --username ${CI_REGISTRY_USER} --password ${CI_REGISTRY_PASSWORD} - helm registry login reg.mikumikumi.xyz --username ${CI_REGISTRY_USER} --password ${CI_REGISTRY_PASSWORD}

View File

@@ -1,2 +1,3 @@
adapter-cm-generator adapter-cm-generator
core-webui-cm-sync
.gitlab-ci.yml .gitlab-ci.yml

View File

@@ -2,5 +2,5 @@ apiVersion: v2
name: maibot name: maibot
description: "Maimai Bot, a cyber friend dedicated to group chats" description: "Maimai Bot, a cyber friend dedicated to group chats"
type: application type: application
version: 0.11.2-beta version: 0.11.5-beta
appVersion: 0.11.2-beta appVersion: 0.11.5-beta

View File

@@ -10,6 +10,8 @@
| Helm Chart版本 | 对应的MaiBot版本 | Commit SHA | | Helm Chart版本 | 对应的MaiBot版本 | Commit SHA |
|----------------|--------------|------------------------------------------| |----------------|--------------|------------------------------------------|
| 0.11.5-beta | 0.11.5-beta | ad2df627001f18996802f23c405b263e78af0d0f |
| 0.11.3-beta | 0.11.3-beta | cd6dc18f546f81e08803d3b8dba48e504dad9295 |
| 0.11.2-beta | 0.11.2-beta | d3c8cea00dbb97f545350f2c3d5bcaf252443df2 | | 0.11.2-beta | 0.11.2-beta | d3c8cea00dbb97f545350f2c3d5bcaf252443df2 |
| 0.11.1-beta | 0.11.1-beta | 94e079a340a43dff8a2bc178706932937fc10b11 | | 0.11.1-beta | 0.11.1-beta | 94e079a340a43dff8a2bc178706932937fc10b11 |
| 0.11.0-beta | 0.11.0-beta | 16059532d8ef87ac28e2be0838ff8b3a34a91d0f | | 0.11.0-beta | 0.11.0-beta | 16059532d8ef87ac28e2be0838ff8b3a34a91d0f |

View File

@@ -5,6 +5,6 @@ WORKDIR /app
COPY . /app COPY . /app
RUN pip3 install --no-cache-dir -i https://mirrors.ustc.edu.cn/pypi/simple -r requirements.txt RUN pip3 install --no-cache-dir -r requirements.txt
ENTRYPOINT ["python3", "adapter-cm-generator.py"] ENTRYPOINT ["python3", "adapter-cm-generator.py"]

View File

@@ -29,7 +29,7 @@ data['maibot_server']['host'] = f'{release_name}-maibot-core' # 根据release
data['maibot_server']['port'] = 8000 data['maibot_server']['port'] = 8000
# 创建/修改configmap # 创建/修改configmap
cm_name = f'{release_name}-maibot-adapter' cm_name = f'{release_name}-maibot-adapter-config'
cm = client.V1ConfigMap( cm = client.V1ConfigMap(
metadata=client.V1ObjectMeta(name=cm_name), metadata=client.V1ObjectMeta(name=cm_name),
data={'config.toml': toml.dumps(data)} data={'config.toml': toml.dumps(data)}

View File

@@ -0,0 +1,10 @@
# 此镜像用于辅助麦麦的WebUI更新配置文件随core容器持续运行
FROM python:3.13-slim
WORKDIR /MaiMBot
COPY . /MaiMBot
RUN pip3 install --no-cache-dir -r requirements.txt
ENTRYPOINT ["python3", "core-webui-cm-sync.py"]

View File

@@ -0,0 +1,92 @@
#!/bin/python3
# 这个程序的作用是辅助麦麦的WebUI更新配置文件随core容器持续运行。
# 麦麦的配置文件存储于ConfigMap中挂载进core容器后属于只读文件无法直接修改。
# 此程序将core容器内的配置文件替换为可读写的中间层临时文件。启动时将实际配置文件写入并在后台持续检测文件变化实时同步到k8s apiServer反向修改ConfigMap。
# 工作目录:/MaiMBot/webui-cm-sync
import os
import time
from datetime import datetime
from kubernetes import client, config
from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler
work_dir = '/MaiMBot/webui-cm-sync'
os.chdir(work_dir)
config.load_incluster_config()
core_api = client.CoreV1Api()
with open("/var/run/secrets/kubernetes.io/serviceaccount/namespace", "r") as f:
namespace = f.read().strip()
release_name = os.getenv("RELEASE_NAME")
model_configmap_name = f'{release_name}-maibot-core-model-config'
bot_configmap_name = f'{release_name}-maibot-core-bot-config'
# 过滤列表,只监控指定文件
target_files = {
os.path.abspath("model_config.toml"): (model_configmap_name, "model_config.toml"),
os.path.abspath("bot_config.toml"): (bot_configmap_name, "bot_config.toml")
}
def get_configmap(configmap_name: str):
"""获取core的ConfigMap内容"""
cm = core_api.read_namespaced_config_map(name=configmap_name, namespace=namespace)
return cm.data
def set_configmap(configmap_name: str, configmap_data: dict[str, str]):
"""设置core的ConfigMap内容"""
core_api.patch_namespaced_config_map(configmap_name, namespace, {'data': configmap_data})
class ConfigObserverHandler(FileSystemEventHandler):
"""配置文件变化的事件处理器"""
def on_modified(self, event):
if os.path.abspath(event.src_path) in target_files:
print(f'[{datetime.now().strftime("%Y-%m-%d %H:%M:%S")}] File `{event.src_path}` was modified. '
f'Start to sync...')
with open(event.src_path, "r", encoding="utf-8") as _f:
current_data = _f.read()
_path = str(os.path.abspath(event.src_path))
new_cm = {
target_files[_path][1]: current_data
}
try:
set_configmap(target_files[_path][0], new_cm)
print(f'\tSync done.')
except client.exceptions.ApiException as _e:
print(f'\tError while setting configmap:\n'
f'\t\tStatus Code: {_e.status}\n'
f'\t\tReason: {_e.reason}')
if __name__ == '__main__':
# 初始化配置文件
print(f'[{datetime.now().strftime("%Y-%m-%d %H:%M:%S")}] Initializing config files...')
try:
__initial_model_config = get_configmap(model_configmap_name)['model_config.toml']
__initial_bot_config = get_configmap(bot_configmap_name)['bot_config.toml']
except client.exceptions.ApiException as e:
print(f'\tError while getting configmap:\n'
f'\t\tStatus Code: {e.status}\n'
f'\t\tReason: {e.reason}')
exit(1)
with open('model_config.toml', 'w') as f:
f.write(__initial_model_config)
with open('bot_config.toml', 'w') as f:
f.write(__initial_bot_config)
with open('ready', 'w') as f:
f.write('true')
print(f'[{datetime.now().strftime("%Y-%m-%d %H:%M:%S")}] Initializing done. Ready to sync.')
# 持续检测变化并同步
observer = Observer()
observer.schedule(ConfigObserverHandler(), work_dir, recursive=False)
observer.start()
try:
while True:
time.sleep(1)
except KeyboardInterrupt:
observer.stop()
observer.join()

View File

@@ -0,0 +1,3 @@
toml~=0.10.2
kubernetes~=34.1.0
watchdog~=6.0.0

View File

@@ -0,0 +1,56 @@
#!/bin/sh
# 此脚本用于覆盖core容器的默认启动命令进行一些初始化
# 1
# 由于k8s与docker-compose的卷挂载方式有所不同需要利用此脚本为一些文件和目录提前创建好软链接
# /MaiMBot/data是麦麦数据的实际挂载路径
# /MaiMBot/statistics是统计数据的实际挂载路径
# 2
# 此脚本等待辅助容器webui-cm-sync就绪后再启动麦麦
# 通过检测/MaiMBot/webui-cm-sync/ready文件来判断
set -e
echo "[K8s Init] Preparing volume..."
# 初次启动,在存储卷中检查并创建关键文件和目录
mkdir -p /MaiMBot/data/plugins
mkdir -p /MaiMBot/data/logs
if [ ! -d "/MaiMBot/statistics" ]
then
echo "[K8s Init] Statistics volume is disabled."
else
touch /MaiMBot/statistics/index.html
fi
# 删除默认插件目录,准备创建用户插件目录软链接
rm -rf /MaiMBot/plugins
# 创建软链接,从存储卷链接到实际位置
ln -s /MaiMBot/data/plugins /MaiMBot/plugins
ln -s /MaiMBot/data/logs /MaiMBot/logs
if [ -f "/MaiMBot/statistics/index.html" ]
then
ln -s /MaiMBot/statistics/index.html /MaiMBot/maibot_statistics.html
fi
echo "[K8s Init] Volume ready."
# 如果启用了WebUI则等待辅助容器webui-cm-sync就绪然后创建中间层配置文件软链接
if [ "$MAIBOT_WEBUI_ENABLED" = "true" ]
then
echo "[K8s Init] WebUI enabled. Waiting for container 'webui-cm-sync' ready..."
while [ ! -f /MaiMBot/webui-cm-sync/ready ]; do
sleep 1
done
echo "[K8s Init] Container 'webui-cm-sync' ready."
mkdir -p /MaiMBot/config
ln -s /MaiMBot/webui-cm-sync/model_config.toml /MaiMBot/config/model_config.toml
ln -s /MaiMBot/webui-cm-sync/bot_config.toml /MaiMBot/config/bot_config.toml
echo "[K8s Init] Config files middle layer for WebUI created."
else
echo "[K8s Init] WebUI is disabled."
fi
# 启动麦麦
echo "[K8s Init] Waking up MaiBot..."
echo
exec python bot.py

View File

@@ -1,33 +0,0 @@
#!/bin/sh
# 此脚本用于覆盖core容器的默认启动命令
# 由于k8s与docker-compose的卷挂载方式有所不同需要利用此脚本为一些文件和目录提前创建好软链接
# /MaiMBot/data是麦麦数据的实际挂载路径
# /MaiMBot/statistics是统计数据的实际挂载路径
set -e
echo "[VolumeLinker] Preparing volume..."
# 初次启动,在存储卷中检查并创建关键文件和目录
mkdir -p /MaiMBot/data/plugins
mkdir -p /MaiMBot/data/logs
if [ ! -d "/MaiMBot/statistics" ]
then
echo "[VolumeLinker] Statistics volume disabled."
else
touch /MaiMBot/statistics/index.html
fi
# 删除空的插件目录,准备创建软链接
rm -rf /MaiMBot/plugins
# 创建软链接,从存储卷链接到实际位置
ln -s /MaiMBot/data/plugins /MaiMBot/plugins
ln -s /MaiMBot/data/logs /MaiMBot/logs
if [ -f "/MaiMBot/statistics/index.html" ]
then
ln -s /MaiMBot/statistics/index.html /MaiMBot/maibot_statistics.html
fi
# 启动麦麦
echo "[VolumeLinker] Starting MaiBot..."
exec python bot.py

View File

@@ -58,5 +58,5 @@ spec:
items: items:
- key: config.toml - key: config.toml
path: config.toml path: config.toml
name: {{ .Release.Name }}-maibot-adapter name: {{ .Release.Name }}-maibot-adapter-config
name: config name: config

View File

@@ -0,0 +1,14 @@
{{- if or .Release.IsInstall (and .Values.core.webui.enabled .Values.config.enable_config_override_with_webui) (and (not .Values.core.webui.enabled) .Values.config.enable_config_override_without_webui) }}
# 渲染规则:
# 初次安装,或配置了覆盖规则
apiVersion: v1
kind: ConfigMap
metadata:
name: {{ .Release.Name }}-maibot-core-bot-config
namespace: {{ .Release.Namespace }}
annotations:
"helm.sh/resource-policy": keep
data:
bot_config.toml: |
{{ .Values.config.core_bot_config | nindent 4 }}
{{- end }}

View File

@@ -0,0 +1,13 @@
apiVersion: v1
kind: ConfigMap
metadata:
name: {{ .Release.Name }}-maibot-core-env-config
namespace: {{ .Release.Namespace }}
data:
.env: |
HOST=0.0.0.0
PORT=8000
WEBUI_ENABLED={{ if .Values.core.webui.enabled }}true{{ else }}false{{ end }}
WEBUI_MODE=production
WEBUI_HOST=0.0.0.0
WEBUI_PORT=8001

View File

@@ -0,0 +1,14 @@
{{- if or .Release.IsInstall (and .Values.core.webui.enabled .Values.config.enable_config_override_with_webui) (and (not .Values.core.webui.enabled) .Values.config.enable_config_override_without_webui) }}
# 渲染规则:
# 初次安装,或配置了覆盖规则
apiVersion: v1
kind: ConfigMap
metadata:
name: {{ .Release.Name }}-maibot-core-model-config
namespace: {{ .Release.Namespace }}
annotations:
"helm.sh/resource-policy": keep
data:
model_config.toml: |
{{ .Values.config.core_model_config | nindent 4 }}
{{- end }}

View File

@@ -1,13 +0,0 @@
apiVersion: v1
kind: ConfigMap
metadata:
name: {{ .Release.Name }}-maibot-core
namespace: {{ .Release.Namespace }}
data:
.env: |
HOST=0.0.0.0
PORT=8000
model_config.toml: |
{{ .Values.config.core_model_config | nindent 4 }}
bot_config.toml: |
{{ .Values.config.core_bot_config | nindent 4 }}

View File

@@ -0,0 +1,26 @@
{{- if and .Values.core.webui.enabled .Values.core.webui.ingress.enabled }}
apiVersion: networking.k8s.io/v1
kind: Ingress
metadata:
name: {{ .Release.Name }}-maibot-webui
namespace: {{ .Release.Namespace }}
{{- if .Values.core.webui.ingress.annotations }}
annotations:
{{ toYaml .Values.core.webui.ingress.annotations | nindent 4 }}
{{- end }}
labels:
app: {{ .Release.Name }}-maibot-core
spec:
ingressClassName: {{ .Values.core.webui.ingress.className }}
rules:
- host: {{ .Values.core.webui.ingress.host }}
http:
paths:
- backend:
service:
name: {{ .Release.Name }}-maibot-core
port:
number: {{ .Values.core.webui.service.port }}
path: {{ .Values.core.webui.ingress.path }}
pathType: {{ .Values.core.webui.ingress.pathType }}
{{- end }}

View File

@@ -11,6 +11,15 @@ spec:
port: 8000 port: 8000
protocol: TCP protocol: TCP
targetPort: 8000 targetPort: 8000
{{- if .Values.core.webui.enabled }}
- name: webui
port: {{ .Values.core.webui.service.port }}
protocol: TCP
targetPort: 8001
{{- if eq .Values.core.webui.service.type "NodePort" }}
nodePort: {{ .Values.core.webui.service.nodePort | default nil }}
{{- end }}
{{- end }}
selector: selector:
app: {{ .Release.Name }}-maibot-core app: {{ .Release.Name }}-maibot-core
type: ClusterIP type: ClusterIP

View File

@@ -18,23 +18,32 @@ spec:
spec: spec:
containers: containers:
- name: core - name: core
command: # 为了在k8s中初始化存储卷,这里替换启动命令为指定脚本 command: # 为了在k8s中初始化这里替换启动命令为指定脚本
- sh - sh
args: args:
- /MaiMBot/volume-linker.sh - /MaiMBot/k8s-init.sh
env: env:
- name: TZ - name: TZ
value: Asia/Shanghai value: "Asia/Shanghai"
- name: EULA_AGREE - name: EULA_AGREE
value: 99f08e0cab0190de853cb6af7d64d4de value: "99f08e0cab0190de853cb6af7d64d4de"
- name: PRIVACY_AGREE - name: PRIVACY_AGREE
value: 9943b855e72199d0f5016ea39052f1b6 value: "9943b855e72199d0f5016ea39052f1b6"
image: {{ .Values.core.image.repository | default "sengokucola/maibot" }}:{{ .Values.core.image.tag | default "0.11.2-beta" }} {{- if .Values.core.webui.enabled }}
- name: MAIBOT_WEBUI_ENABLED
value: "true"
{{- end}}
image: {{ .Values.core.image.repository | default "sengokucola/maibot" }}:{{ .Values.core.image.tag | default "0.11.5-beta" }}
imagePullPolicy: {{ .Values.core.image.pullPolicy }} imagePullPolicy: {{ .Values.core.image.pullPolicy }}
ports: ports:
- containerPort: 8000 - containerPort: 8000
name: adapter-ws name: adapter-ws
protocol: TCP protocol: TCP
{{- if .Values.core.webui.enabled }}
- containerPort: 8001
name: webui
protocol: TCP
{{- end }}
{{- if .Values.core.resources }} {{- if .Values.core.resources }}
resources: resources:
{{ toYaml .Values.core.resources | nindent 12 }} {{ toYaml .Values.core.resources | nindent 12 }}
@@ -42,26 +51,45 @@ spec:
volumeMounts: volumeMounts:
- mountPath: /MaiMBot/data - mountPath: /MaiMBot/data
name: data name: data
- mountPath: /MaiMBot/volume-linker.sh - mountPath: /MaiMBot/k8s-init.sh
name: scripts name: scripts
readOnly: true readOnly: true
subPath: volume-linker.sh subPath: k8s-init.sh
- mountPath: /MaiMBot/.env - mountPath: /MaiMBot/.env
name: config name: env-config
readOnly: true readOnly: true
subPath: .env subPath: .env
{{- if not .Values.core.webui.enabled }}
- mountPath: /MaiMBot/config/model_config.toml - mountPath: /MaiMBot/config/model_config.toml
name: config name: model-config
readOnly: true readOnly: true
subPath: model_config.toml subPath: model_config.toml
- mountPath: /MaiMBot/config/bot_config.toml - mountPath: /MaiMBot/config/bot_config.toml
name: config name: bot-config
readOnly: true readOnly: true
subPath: bot_config.toml subPath: bot_config.toml
{{- end }}
{{- if .Values.statistics_dashboard.enabled }} {{- if .Values.statistics_dashboard.enabled }}
- mountPath: /MaiMBot/statistics - mountPath: /MaiMBot/statistics
name: statistics name: statistics
{{- end }} {{- end }}
{{- if .Values.core.webui.enabled }}
- mountPath: /MaiMBot/webui-cm-sync
name: webui-cm-sync
{{- end }}
{{- if .Values.core.webui.enabled }}
- name: webui-cm-sync
image: {{ .Values.core.webui.cm_sync.image.repository | default "reg.mikumikumi.xyz/maibot/core-webui-cm-sync" }}:{{ .Values.core.webui.cm_sync.image.tag | default "0.11.5-beta" }}
imagePullPolicy: {{ .Values.core.webui.cm_sync.image.pullPolicy }}
env:
- name: PYTHONUNBUFFERED
value: "1"
- name: RELEASE_NAME
value: {{ .Release.Name }}
volumeMounts:
- mountPath: /MaiMBot/webui-cm-sync
name: webui-cm-sync
{{- end }}
{{- if .Values.core.setup_default_plugins }} {{- if .Values.core.setup_default_plugins }}
initContainers: # 用户插件目录存储在存储卷中,会在启动时覆盖掉容器的默认插件目录。此初始化容器用于默认插件更新后或麦麦首次启动时为用户自动安装默认插件到存储卷中 initContainers: # 用户插件目录存储在存储卷中,会在启动时覆盖掉容器的默认插件目录。此初始化容器用于默认插件更新后或麦麦首次启动时为用户自动安装默认插件到存储卷中
- args: - args:
@@ -69,7 +97,7 @@ spec:
command: command:
- python3 - python3
workingDir: /MaiMBot workingDir: /MaiMBot
image: {{ .Values.core.image.repository | default "sengokucola/maibot" }}:{{ .Values.core.image.tag | default "0.11.2-beta" }} image: {{ .Values.core.image.repository | default "sengokucola/maibot" }}:{{ .Values.core.image.tag | default "0.11.5-beta" }}
imagePullPolicy: {{ .Values.core.image.pullPolicy }} imagePullPolicy: {{ .Values.core.image.pullPolicy }}
name: setup-plugins name: setup-plugins
resources: { } resources: { }
@@ -81,6 +109,7 @@ spec:
readOnly: true readOnly: true
subPath: setup-plugins.py subPath: setup-plugins.py
{{- end }} {{- end }}
serviceAccountName: {{ .Release.Name }}-maibot-sa
{{- if .Values.core.image.pullSecrets }} {{- if .Values.core.image.pullSecrets }}
imagePullSecrets: imagePullSecrets:
{{ toYaml .Values.core.image.pullSecrets | nindent 8 }} {{ toYaml .Values.core.image.pullSecrets | nindent 8 }}
@@ -99,8 +128,8 @@ spec:
claimName: {{ .Release.Name }}-maibot-core claimName: {{ .Release.Name }}-maibot-core
- configMap: - configMap:
items: items:
- key: volume-linker.sh - key: k8s-init.sh
path: volume-linker.sh path: k8s-init.sh
{{- if .Values.core.setup_default_plugins }} {{- if .Values.core.setup_default_plugins }}
- key: setup-plugins.py - key: setup-plugins.py
path: setup-plugins.py path: setup-plugins.py
@@ -111,14 +140,28 @@ spec:
items: items:
- key: .env - key: .env
path: .env path: .env
name: {{ .Release.Name }}-maibot-core-env-config
name: env-config
{{- if not .Values.core.webui.enabled }}
- configMap:
items:
- key: model_config.toml - key: model_config.toml
path: model_config.toml path: model_config.toml
name: {{ .Release.Name }}-maibot-core-model-config
name: model-config
- configMap:
items:
- key: bot_config.toml - key: bot_config.toml
path: bot_config.toml path: bot_config.toml
name: {{ .Release.Name }}-maibot-core name: {{ .Release.Name }}-maibot-core-bot-config
name: config name: bot-config
{{- end }}
{{- if .Values.statistics_dashboard.enabled }} {{- if .Values.statistics_dashboard.enabled }}
- name: statistics - name: statistics
persistentVolumeClaim: persistentVolumeClaim:
claimName: {{ .Release.Name }}-maibot-statistics-dashboard claimName: {{ .Release.Name }}-maibot-statistics-dashboard
{{- end }} {{- end }}
{{- if .Values.core.webui.enabled }}
- emptyDir: {}
name: webui-cm-sync
{{- end }}

View File

@@ -26,7 +26,7 @@ spec:
value: "{{ .Values.napcat.permission.uid }}" value: "{{ .Values.napcat.permission.uid }}"
- name: TZ - name: TZ
value: Asia/Shanghai value: Asia/Shanghai
image: {{ .Values.napcat.image.repository | default "mlikiowa/napcat-docker" }}:{{ .Values.napcat.image.tag | default "v4.9.70" }} image: {{ .Values.napcat.image.repository | default "mlikiowa/napcat-docker" }}:{{ .Values.napcat.image.tag | default "v4.9.73" }}
imagePullPolicy: {{ .Values.napcat.image.pullPolicy }} imagePullPolicy: {{ .Values.napcat.image.pullPolicy }}
livenessProbe: livenessProbe:
failureThreshold: 3 failureThreshold: 3

View File

@@ -5,8 +5,8 @@ metadata:
namespace: {{ .Release.Namespace }} namespace: {{ .Release.Namespace }}
data: data:
# core # core
volume-linker.sh: | k8s-init.sh: |
{{ .Files.Get "files/volume-linker.sh" | nindent 4 }} {{ .Files.Get "files/k8s-init.sh" | nindent 4 }}
# core的初始化容器 # core的初始化容器
{{- if .Values.core.setup_default_plugins }} {{- if .Values.core.setup_default_plugins }}
setup-plugins.py: | setup-plugins.py: |

View File

@@ -11,11 +11,11 @@ spec:
backoffLimit: 2 backoffLimit: 2
template: template:
spec: spec:
serviceAccountName: {{ .Release.Name }}-maibot-adapter-cm-generator serviceAccountName: {{ .Release.Name }}-maibot-sa
restartPolicy: Never restartPolicy: Never
containers: containers:
- name: adapter-cm-generator - name: adapter-cm-generator
image: {{ .Values.adapter.cm_generator.image.repository | default "reg.mikumikumi.xyz/maibot/adapter-cm-generator" }}:{{ .Values.adapter.cm_generator.image.tag | default "0.11.2-beta" }} image: {{ .Values.adapter.cm_generator.image.repository | default "reg.mikumikumi.xyz/maibot/adapter-cm-generator" }}:{{ .Values.adapter.cm_generator.image.tag | default "0.11.5-beta" }}
workingDir: /app workingDir: /app
env: env:
- name: PYTHONUNBUFFERED - name: PYTHONUNBUFFERED

View File

@@ -1,14 +1,14 @@
# 动态生成adapter配置文件的configmap所需要的rbac授权 # 初始化及反向修改ConfigMap所需要的rbac授权
apiVersion: v1 apiVersion: v1
kind: ServiceAccount kind: ServiceAccount
metadata: metadata:
name: {{ .Release.Name }}-maibot-adapter-cm-generator name: {{ .Release.Name }}-maibot-sa
namespace: {{ .Release.Namespace }} namespace: {{ .Release.Namespace }}
--- ---
apiVersion: rbac.authorization.k8s.io/v1 apiVersion: rbac.authorization.k8s.io/v1
kind: Role kind: Role
metadata: metadata:
name: {{ .Release.Name }}-maibot-adapter-cm-gen-role name: {{ .Release.Name }}-maibot-role
namespace: {{ .Release.Namespace }} namespace: {{ .Release.Namespace }}
rules: rules:
- apiGroups: [""] - apiGroups: [""]
@@ -21,13 +21,13 @@ rules:
apiVersion: rbac.authorization.k8s.io/v1 apiVersion: rbac.authorization.k8s.io/v1
kind: RoleBinding kind: RoleBinding
metadata: metadata:
name: {{ .Release.Name }}-maibot-adapter-cm-gen-role-binding name: {{ .Release.Name }}-maibot-rolebinding
namespace: {{ .Release.Namespace }} namespace: {{ .Release.Namespace }}
subjects: subjects:
- kind: ServiceAccount - kind: ServiceAccount
name: {{ .Release.Name }}-maibot-adapter-cm-generator name: {{ .Release.Name }}-maibot-sa
namespace: {{ .Release.Namespace }} namespace: {{ .Release.Namespace }}
roleRef: roleRef:
kind: Role kind: Role
name: {{ .Release.Name }}-maibot-adapter-cm-gen-role name: {{ .Release.Name }}-maibot-role
apiGroup: rbac.authorization.k8s.io apiGroup: rbac.authorization.k8s.io

View File

@@ -38,7 +38,7 @@ adapter:
cm_generator: cm_generator:
image: image:
repository: # 默认 reg.mikumikumi.xyz/maibot/adapter-cm-generator repository: # 默认 reg.mikumikumi.xyz/maibot/adapter-cm-generator
tag: # 默认 0.11.2-beta tag: # 默认 0.11.5-beta
pullPolicy: IfNotPresent pullPolicy: IfNotPresent
pullSecrets: [ ] pullSecrets: [ ]
@@ -48,7 +48,7 @@ core:
image: image:
repository: # 默认 sengokucola/maibot repository: # 默认 sengokucola/maibot
tag: # 默认 0.11.2-beta tag: # 默认 0.11.5-beta
pullPolicy: IfNotPresent pullPolicy: IfNotPresent
pullSecrets: [ ] pullSecrets: [ ]
@@ -65,6 +65,28 @@ core:
setup_default_plugins: true # 启用一个初始化容器,用于为用户自动安装默认插件到存储卷中 setup_default_plugins: true # 启用一个初始化容器,用于为用户自动安装默认插件到存储卷中
webui: # WebUI相关配置
# 对从helm chart 0.11.5-beta 之前的旧版升级的用户的重要提示:
# 旧版本helm chart没有配置configmap的保留策略直接升级同时开启WebUI后会导致configmap丢失从而无法启动。
# 这种情况可以先禁用WebUI更新一次再启用WebUI更新一次即可解决问题。
enabled: true # 默认启用
cm_sync: # WebUI的辅助容器配置
image:
repository: # 默认 reg.mikumikumi.xyz/maibot/core-webui-cm-sync
tag: # 默认 0.11.5-beta
pullPolicy: IfNotPresent
service:
type: ClusterIP # ClusterIP / NodePort 指定NodePort可以将内网的服务端口映射到物理节点的端口
port: 8001 # 服务端口
nodePort: # 仅在设置NodePort类型时有效不指定则会随机分配端口
ingress:
enabled: false
className: nginx
annotations: { }
host: maim.example.com # 访问麦麦WebUI的域名
path: /
pathType: Prefix
# 麦麦的运行统计看板配置 # 麦麦的运行统计看板配置
# 麦麦每隔一段时间会自动输出html格式的运行统计报告此统计报告可以作为静态网页访问 # 麦麦每隔一段时间会自动输出html格式的运行统计报告此统计报告可以作为静态网页访问
# 此功能默认禁用。如果你认为报告可以被公开访问(报告包含联系人/群组名称、模型token花费信息等则可以启用此功能 # 此功能默认禁用。如果你认为报告可以被公开访问(报告包含联系人/群组名称、模型token花费信息等则可以启用此功能
@@ -117,7 +139,7 @@ napcat:
image: image:
repository: # 默认 mlikiowa/napcat-docker repository: # 默认 mlikiowa/napcat-docker
tag: # 默认 v4.9.70 tag: # 默认 v4.9.73
pullPolicy: IfNotPresent pullPolicy: IfNotPresent
pullSecrets: [ ] pullSecrets: [ ]
@@ -189,9 +211,16 @@ sqlite_web:
path: / path: /
pathType: Prefix pathType: Prefix
# 麦麦各部分组件的运行配置文件 # 手动设置麦麦各部分组件的运行配置文件
config: config:
# 启用WebUI后配置文件的修改即可在WebUI进行。如果通过WebUI修改了配置则实际的配置文件将与values中的配置存在差异。
# 如果用户在k8s的ConfigMap中手动修改了配置文件则实际的配置文件也会与values中的配置存在差异。
# 出现上述两种情况时为了避免helm升级麦麦时下面values中的配置覆盖掉已有的配置文件而导致配置丢失可以在这里禁止本次部署时的配置覆盖。
# 注由于adapter的配置无法通过WebUI修改因此下面的adapter_config配置仍然会覆盖已有配置文件。
enable_config_override_without_webui: true # 未启用WebUI时默认覆盖
enable_config_override_with_webui: false # 启用WebUI时默认不覆盖
# adapter的config.toml # adapter的config.toml
adapter_config: | adapter_config: |
[inner] [inner]
@@ -227,7 +256,7 @@ config:
# core的model_config.toml # core的model_config.toml
core_model_config: | core_model_config: |
[inner] [inner]
version = "1.7.7" version = "1.7.8"
# 配置文件版本号迭代规则同bot_config.toml # 配置文件版本号迭代规则同bot_config.toml
@@ -388,7 +417,7 @@ config:
# core的bot_config.toml # core的bot_config.toml
core_bot_config: | core_bot_config: |
[inner] [inner]
version = "6.21.4" version = "6.21.8"
#----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读---- #----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读----
#如果你想要修改配置文件请递增version的值 #如果你想要修改配置文件请递增version的值
@@ -493,7 +522,7 @@ config:
include_planner_reasoning = false # 是否将planner推理加入replyer默认关闭不加入 include_planner_reasoning = false # 是否将planner推理加入replyer默认关闭不加入
[memory] [memory]
max_agent_iterations = 5 # 记忆思考深度最低为1不深入思考 max_agent_iterations = 3 # 记忆思考深度最低为1不深入思考
[jargon] [jargon]
all_global = true # 是否开启全局黑话模式,注意,此功能关闭后,已经记录的全局黑话不会改变,需要手动删除 all_global = true # 是否开启全局黑话模式,注意,此功能关闭后,已经记录的全局黑话不会改变,需要手动删除
@@ -600,6 +629,9 @@ config:
show_replyer_prompt = false # 是否显示回复器prompt show_replyer_prompt = false # 是否显示回复器prompt
show_replyer_reasoning = false # 是否显示回复器推理 show_replyer_reasoning = false # 是否显示回复器推理
show_jargon_prompt = false # 是否显示jargon相关提示词 show_jargon_prompt = false # 是否显示jargon相关提示词
show_memory_prompt = false # 是否显示记忆检索相关提示词
show_planner_prompt = false # 是否显示planner的prompt和原始返回结果
show_lpmm_paragraph = false # 是否显示lpmm找到的相关文段日志
[maim_message] [maim_message]
auth_token = [] # 认证令牌用于API验证为空则不启用验证 auth_token = [] # 认证令牌用于API验证为空则不启用验证

View File

@@ -16,8 +16,6 @@ if PROJECT_ROOT not in sys.path:
sys.path.insert(0, PROJECT_ROOT) sys.path.insert(0, PROJECT_ROOT)
SECONDS_5_MINUTES = 5 * 60 SECONDS_5_MINUTES = 5 * 60

View File

@@ -12,7 +12,6 @@ project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root) sys.path.insert(0, project_root)
# 设置中文字体 # 设置中文字体
plt.rcParams["font.sans-serif"] = ["SimHei", "Microsoft YaHei", "DejaVu Sans"] plt.rcParams["font.sans-serif"] = ["SimHei", "Microsoft YaHei", "DejaVu Sans"]
plt.rcParams["axes.unicode_minus"] = False plt.rcParams["axes.unicode_minus"] = False

View File

@@ -57,8 +57,8 @@ from src.common.database.database import db
from src.common.database.database_model import Emoji from src.common.database.database_model import Emoji
# 常量定义 # 常量定义
MAGIC = b'MMIP' MAGIC = b"MMIP"
FOOTER_MAGIC = b'MMFF' FOOTER_MAGIC = b"MMFF"
VERSION = 1 VERSION = 1
FOOTER_VERSION = 1 FOOTER_VERSION = 1
@@ -67,7 +67,7 @@ MAX_MANIFEST_SIZE = 200 * 1024 * 1024 # 200 MB
MAX_PAYLOAD_SIZE = 10 * 1024 * 1024 * 1024 # 10 GB MAX_PAYLOAD_SIZE = 10 * 1024 * 1024 * 1024 # 10 GB
# 支持的图片格式 # 支持的图片格式
SUPPORTED_FORMATS = {'.jpg', '.jpeg', '.png', '.gif', '.webp', '.avif', '.bmp'} SUPPORTED_FORMATS = {".jpg", ".jpeg", ".png", ".gif", ".webp", ".avif", ".bmp"}
# 创建控制台对象 # 创建控制台对象
console = Console() console = Console()
@@ -75,6 +75,7 @@ console = Console()
class MMIPKGError(Exception): class MMIPKGError(Exception):
"""MMIPKG 相关错误""" """MMIPKG 相关错误"""
pass pass
@@ -97,55 +98,55 @@ def get_image_info(file_path: str) -> Tuple[int, int, str]:
try: try:
with Image.open(file_path) as img: with Image.open(file_path) as img:
width, height = img.size width, height = img.size
format_lower = img.format.lower() if img.format else 'unknown' format_lower = img.format.lower() if img.format else "unknown"
mime_map = { mime_map = {
'jpeg': 'image/jpeg', "jpeg": "image/jpeg",
'jpg': 'image/jpeg', "jpg": "image/jpeg",
'png': 'image/png', "png": "image/png",
'gif': 'image/gif', "gif": "image/gif",
'webp': 'image/webp', "webp": "image/webp",
'avif': 'image/avif', "avif": "image/avif",
'bmp': 'image/bmp' "bmp": "image/bmp",
} }
mime_type = mime_map.get(format_lower, f'image/{format_lower}') mime_type = mime_map.get(format_lower, f"image/{format_lower}")
return width, height, mime_type return width, height, mime_type
except Exception as e: except Exception as e:
print(f"警告: 无法读取图片信息 {file_path}: {e}") print(f"警告: 无法读取图片信息 {file_path}: {e}")
return 0, 0, 'image/unknown' return 0, 0, "image/unknown"
def reencode_image(file_path: str, output_format: str = 'webp', quality: int = 80) -> bytes: def reencode_image(file_path: str, output_format: str = "webp", quality: int = 80) -> bytes:
"""重新编码图片""" """重新编码图片"""
try: try:
with Image.open(file_path) as img: with Image.open(file_path) as img:
# 转换为 RGB如果需要 # 转换为 RGB如果需要
if img.mode in ('RGBA', 'LA', 'P'): if img.mode in ("RGBA", "LA", "P"):
if output_format.lower() == 'jpeg': if output_format.lower() == "jpeg":
# JPEG 不支持透明度,转为白色背景 # JPEG 不支持透明度,转为白色背景
background = Image.new('RGB', img.size, (255, 255, 255)) background = Image.new("RGB", img.size, (255, 255, 255))
if img.mode == 'P': if img.mode == "P":
img = img.convert('RGBA') img = img.convert("RGBA")
background.paste(img, mask=img.split()[-1] if img.mode == 'RGBA' else None) background.paste(img, mask=img.split()[-1] if img.mode == "RGBA" else None)
img = background img = background
elif output_format.lower() == 'webp': elif output_format.lower() == "webp":
# WebP 支持透明度 # WebP 支持透明度
if img.mode == 'P': if img.mode == "P":
img = img.convert('RGBA') img = img.convert("RGBA")
elif img.mode not in ('RGB', 'RGBA'): elif img.mode not in ("RGB", "RGBA"):
img = img.convert('RGB') img = img.convert("RGB")
# 编码图片 # 编码图片
output = io.BytesIO() output = io.BytesIO()
save_kwargs = {'format': output_format.upper()} save_kwargs = {"format": output_format.upper()}
if output_format.lower() in {'jpeg', 'jpg'}: if output_format.lower() in {"jpeg", "jpg"}:
save_kwargs['quality'] = quality save_kwargs["quality"] = quality
save_kwargs['optimize'] = True save_kwargs["optimize"] = True
elif output_format.lower() == 'webp': elif output_format.lower() == "webp":
save_kwargs['quality'] = quality save_kwargs["quality"] = quality
save_kwargs['method'] = 6 # 更好的压缩 save_kwargs["method"] = 6 # 更好的压缩
elif output_format.lower() == 'png': elif output_format.lower() == "png":
save_kwargs['optimize'] = True save_kwargs["optimize"] = True
img.save(output, **save_kwargs) img.save(output, **save_kwargs)
return output.getvalue() return output.getvalue()
@@ -156,11 +157,13 @@ def reencode_image(file_path: str, output_format: str = 'webp', quality: int = 8
class MMIPKGPacker: class MMIPKGPacker:
"""MMIPKG 打包器""" """MMIPKG 打包器"""
def __init__(self, def __init__(
use_compression: bool = True, self,
zstd_level: int = 3, use_compression: bool = True,
reencode: Optional[str] = None, zstd_level: int = 3,
reencode_quality: int = 80): reencode: Optional[str] = None,
reencode_quality: int = 80,
):
self.use_compression = use_compression and zstd is not None self.use_compression = use_compression and zstd is not None
self.zstd_level = zstd_level self.zstd_level = zstd_level
self.reencode = reencode self.reencode = reencode
@@ -170,8 +173,9 @@ class MMIPKGPacker:
print("警告: zstandard 未安装,将不使用压缩") print("警告: zstandard 未安装,将不使用压缩")
self.use_compression = False self.use_compression = False
def pack_from_db(self, output_path: str, pack_name: Optional[str] = None, def pack_from_db(
custom_manifest: Optional[Dict] = None) -> bool: self, output_path: str, pack_name: Optional[str] = None, custom_manifest: Optional[Dict] = None
) -> bool:
"""从数据库导出已注册的表情包 """从数据库导出已注册的表情包
Args: Args:
@@ -205,12 +209,14 @@ class MMIPKGPacker:
BarColumn(), BarColumn(),
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
TimeElapsedColumn(), TimeElapsedColumn(),
console=console console=console,
) as progress: ) as progress:
task = progress.add_task("[cyan]扫描表情包...", total=emoji_count) task = progress.add_task("[cyan]扫描表情包...", total=emoji_count)
for idx, emoji in enumerate(emojis, 1): for idx, emoji in enumerate(emojis, 1):
progress.update(task, description=f"[cyan]处理 {idx}/{emoji_count}: {os.path.basename(emoji.full_path)}") progress.update(
task, description=f"[cyan]处理 {idx}/{emoji_count}: {os.path.basename(emoji.full_path)}"
)
# 检查文件是否存在 # 检查文件是否存在
if not os.path.exists(emoji.full_path): if not os.path.exists(emoji.full_path):
@@ -224,10 +230,10 @@ class MMIPKGPacker:
img_bytes = reencode_image(emoji.full_path, self.reencode, self.reencode_quality) img_bytes = reencode_image(emoji.full_path, self.reencode, self.reencode_quality)
except Exception as e: except Exception as e:
console.print(f" [yellow]警告: 重新编码失败,使用原始文件: {e}[/yellow]") console.print(f" [yellow]警告: 重新编码失败,使用原始文件: {e}[/yellow]")
with open(emoji.full_path, 'rb') as f: with open(emoji.full_path, "rb") as f:
img_bytes = f.read() img_bytes = f.read()
else: else:
with open(emoji.full_path, 'rb') as f: with open(emoji.full_path, "rb") as f:
img_bytes = f.read() img_bytes = f.read()
# 计算 SHA256 # 计算 SHA256
@@ -259,7 +265,7 @@ class MMIPKGPacker:
"emoji_hash": emoji.emoji_hash or "", "emoji_hash": emoji.emoji_hash or "",
"is_registered": True, "is_registered": True,
"is_banned": emoji.is_banned or False, "is_banned": emoji.is_banned or False,
} },
} }
items.append(item) items.append(item)
@@ -281,7 +287,7 @@ class MMIPKGPacker:
"p": pack_id, # pack_id "p": pack_id, # pack_id
"n": pack_name, # pack_name "n": pack_name, # pack_name
"t": datetime.now().isoformat(), # created_at "t": datetime.now().isoformat(), # created_at
"a": items # items array "a": items, # items array
} }
# 添加自定义字段 # 添加自定义字段
@@ -308,26 +314,28 @@ class MMIPKGPacker:
except Exception as e: except Exception as e:
print(f"打包失败: {e}") print(f"打包失败: {e}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()
return False return False
finally: finally:
if not db.is_closed(): if not db.is_closed():
db.close() db.close()
def _write_package(self, output_path: str, manifest_bytes: bytes, def _write_package(
image_data_list: List[bytes], payload_size: int) -> bool: self, output_path: str, manifest_bytes: bytes, image_data_list: List[bytes], payload_size: int
) -> bool:
"""写入打包文件""" """写入打包文件"""
try: try:
with open(output_path, 'wb') as f: with open(output_path, "wb") as f:
# 写入 Header (32 bytes) # 写入 Header (32 bytes)
flags = 0x01 if self.use_compression else 0x00 flags = 0x01 if self.use_compression else 0x00
header = MAGIC # 4 bytes header = MAGIC # 4 bytes
header += struct.pack('B', VERSION) # 1 byte header += struct.pack("B", VERSION) # 1 byte
header += struct.pack('B', flags) # 1 byte header += struct.pack("B", flags) # 1 byte
header += b'\x00\x00' # 2 bytes reserved header += b"\x00\x00" # 2 bytes reserved
header += struct.pack('>Q', payload_size) # 8 bytes header += struct.pack(">Q", payload_size) # 8 bytes
header += struct.pack('>Q', len(manifest_bytes)) # 8 bytes header += struct.pack(">Q", len(manifest_bytes)) # 8 bytes
header += b'\x00' * 8 # 8 bytes reserved header += b"\x00" * 8 # 8 bytes reserved
assert len(header) == 32, f"Header size mismatch: {len(header)}" assert len(header) == 32, f"Header size mismatch: {len(header)}"
f.write(header) f.write(header)
@@ -342,7 +350,7 @@ class MMIPKGPacker:
with compressor.stream_writer(f, closefd=False) as writer: with compressor.stream_writer(f, closefd=False) as writer:
# 写入 manifest # 写入 manifest
manifest_len_bytes = struct.pack('>I', len(manifest_bytes)) manifest_len_bytes = struct.pack(">I", len(manifest_bytes))
writer.write(manifest_len_bytes) writer.write(manifest_len_bytes)
writer.write(manifest_bytes) writer.write(manifest_bytes)
payload_sha.update(manifest_len_bytes) payload_sha.update(manifest_len_bytes)
@@ -355,13 +363,13 @@ class MMIPKGPacker:
BarColumn(), BarColumn(),
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
TimeRemainingColumn(), TimeRemainingColumn(),
console=console console=console,
) as progress: ) as progress:
task = progress.add_task("[green]压缩写入图片...", total=len(image_data_list)) task = progress.add_task("[green]压缩写入图片...", total=len(image_data_list))
for idx, img_bytes in enumerate(image_data_list, 1): for idx, img_bytes in enumerate(image_data_list, 1):
progress.update(task, description=f"[green]压缩写入 {idx}/{len(image_data_list)}") progress.update(task, description=f"[green]压缩写入 {idx}/{len(image_data_list)}")
img_len_bytes = struct.pack('>I', len(img_bytes)) img_len_bytes = struct.pack(">I", len(img_bytes))
writer.write(img_len_bytes) writer.write(img_len_bytes)
writer.write(img_bytes) writer.write(img_bytes)
payload_sha.update(img_len_bytes) payload_sha.update(img_len_bytes)
@@ -370,7 +378,7 @@ class MMIPKGPacker:
else: else:
# 不压缩,直接写入 # 不压缩,直接写入
# 写入 manifest # 写入 manifest
manifest_len_bytes = struct.pack('>I', len(manifest_bytes)) manifest_len_bytes = struct.pack(">I", len(manifest_bytes))
f.write(manifest_len_bytes) f.write(manifest_len_bytes)
f.write(manifest_bytes) f.write(manifest_bytes)
payload_sha.update(manifest_len_bytes) payload_sha.update(manifest_len_bytes)
@@ -383,13 +391,13 @@ class MMIPKGPacker:
BarColumn(), BarColumn(),
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
TimeRemainingColumn(), TimeRemainingColumn(),
console=console console=console,
) as progress: ) as progress:
task = progress.add_task("[green]写入图片...", total=len(image_data_list)) task = progress.add_task("[green]写入图片...", total=len(image_data_list))
for idx, img_bytes in enumerate(image_data_list, 1): for idx, img_bytes in enumerate(image_data_list, 1):
progress.update(task, description=f"[green]写入 {idx}/{len(image_data_list)}") progress.update(task, description=f"[green]写入 {idx}/{len(image_data_list)}")
img_len_bytes = struct.pack('>I', len(img_bytes)) img_len_bytes = struct.pack(">I", len(img_bytes))
f.write(img_len_bytes) f.write(img_len_bytes)
f.write(img_bytes) f.write(img_bytes)
payload_sha.update(img_len_bytes) payload_sha.update(img_len_bytes)
@@ -400,8 +408,8 @@ class MMIPKGPacker:
file_sha256 = payload_sha.digest() file_sha256 = payload_sha.digest()
footer = FOOTER_MAGIC # 4 bytes footer = FOOTER_MAGIC # 4 bytes
footer += file_sha256 # 32 bytes footer += file_sha256 # 32 bytes
footer += struct.pack('B', FOOTER_VERSION) # 1 byte footer += struct.pack("B", FOOTER_VERSION) # 1 byte
footer += b'\x00' * 3 # 3 bytes reserved footer += b"\x00" * 3 # 3 bytes reserved
assert len(footer) == 40, f"Footer size mismatch: {len(footer)}" assert len(footer) == 40, f"Footer size mismatch: {len(footer)}"
f.write(footer) f.write(footer)
@@ -419,6 +427,7 @@ class MMIPKGPacker:
except Exception as e: except Exception as e:
print(f"写入文件失败: {e}") print(f"写入文件失败: {e}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()
return False return False
@@ -429,10 +438,9 @@ class MMIPKGUnpacker:
def __init__(self, verify_sha: bool = True): def __init__(self, verify_sha: bool = True):
self.verify_sha = verify_sha self.verify_sha = verify_sha
def import_to_db(self, package_path: str, def import_to_db(
output_dir: Optional[str] = None, self, package_path: str, output_dir: Optional[str] = None, replace_existing: bool = False, batch_size: int = 500
replace_existing: bool = False, ) -> bool:
batch_size: int = 500) -> bool:
"""导入到数据库""" """导入到数据库"""
try: try:
if not os.path.exists(package_path): if not os.path.exists(package_path):
@@ -451,7 +459,7 @@ class MMIPKGUnpacker:
print(f"正在读取包: {package_path}") print(f"正在读取包: {package_path}")
with open(package_path, 'rb') as f: with open(package_path, "rb") as f:
# 读取 Header # 读取 Header
header = f.read(32) header = f.read(32)
if len(header) != 32: if len(header) != 32:
@@ -461,15 +469,15 @@ class MMIPKGUnpacker:
if magic != MAGIC: if magic != MAGIC:
raise MMIPKGError(f"无效的 MAGIC: {magic}") raise MMIPKGError(f"无效的 MAGIC: {magic}")
version = struct.unpack('B', header[4:5])[0] version = struct.unpack("B", header[4:5])[0]
if version != VERSION: if version != VERSION:
print(f"警告: 包版本 {version} 与当前版本 {VERSION} 不匹配") print(f"警告: 包版本 {version} 与当前版本 {VERSION} 不匹配")
flags = struct.unpack('B', header[5:6])[0] flags = struct.unpack("B", header[5:6])[0]
is_compressed = bool(flags & 0x01) is_compressed = bool(flags & 0x01)
payload_uncompressed_len = struct.unpack('>Q', header[8:16])[0] payload_uncompressed_len = struct.unpack(">Q", header[8:16])[0]
manifest_uncompressed_len = struct.unpack('>Q', header[16:24])[0] manifest_uncompressed_len = struct.unpack(">Q", header[16:24])[0]
# 安全检查 # 安全检查
if manifest_uncompressed_len > MAX_MANIFEST_SIZE: if manifest_uncompressed_len > MAX_MANIFEST_SIZE:
@@ -519,7 +527,9 @@ class MMIPKGUnpacker:
# 方法2如果流式失败尝试直接解压兼容旧格式 # 方法2如果流式失败尝试直接解压兼容旧格式
print(f" 流式解压失败,尝试直接解压: {e}") print(f" 流式解压失败,尝试直接解压: {e}")
try: try:
payload_data = decompressor.decompress(compressed_data, max_output_size=payload_uncompressed_len) payload_data = decompressor.decompress(
compressed_data, max_output_size=payload_uncompressed_len
)
except Exception as e2: except Exception as e2:
raise MMIPKGError(f"解压失败: {e2}") from e2 raise MMIPKGError(f"解压失败: {e2}") from e2
else: else:
@@ -537,7 +547,7 @@ class MMIPKGUnpacker:
# 读取 manifest # 读取 manifest
manifest_len_bytes = payload_stream.read(4) manifest_len_bytes = payload_stream.read(4)
manifest_len = struct.unpack('>I', manifest_len_bytes)[0] manifest_len = struct.unpack(">I", manifest_len_bytes)[0]
manifest_bytes = payload_stream.read(manifest_len) manifest_bytes = payload_stream.read(manifest_len)
manifest = msgpack.unpackb(manifest_bytes, raw=False) manifest = msgpack.unpackb(manifest_bytes, raw=False)
@@ -553,20 +563,21 @@ class MMIPKGUnpacker:
print(f" 表情包数量: {len(items)}") print(f" 表情包数量: {len(items)}")
# 导入表情包 # 导入表情包
return self._import_items(payload_stream, items, output_dir, return self._import_items(payload_stream, items, output_dir, replace_existing, batch_size)
replace_existing, batch_size)
except Exception as e: except Exception as e:
print(f"导入失败: {e}") print(f"导入失败: {e}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()
return False return False
finally: finally:
if not db.is_closed(): if not db.is_closed():
db.close() db.close()
def _import_items(self, payload_stream: BinaryIO, items: List[Dict], def _import_items(
output_dir: str, replace_existing: bool, batch_size: int) -> bool: self, payload_stream: BinaryIO, items: List[Dict], output_dir: str, replace_existing: bool, batch_size: int
) -> bool:
"""导入 items 到数据库""" """导入 items 到数据库"""
try: try:
imported_count = 0 imported_count = 0
@@ -581,7 +592,7 @@ class MMIPKGUnpacker:
BarColumn(), BarColumn(),
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
TimeRemainingColumn(), TimeRemainingColumn(),
console=console console=console,
) as progress: ) as progress:
task = progress.add_task("[cyan]导入表情包...", total=len(items)) task = progress.add_task("[cyan]导入表情包...", total=len(items))
@@ -597,7 +608,7 @@ class MMIPKGUnpacker:
progress.advance(task) progress.advance(task)
continue continue
img_len = struct.unpack('>I', img_len_bytes)[0] img_len = struct.unpack(">I", img_len_bytes)[0]
img_bytes = payload_stream.read(img_len) img_bytes = payload_stream.read(img_len)
if len(img_bytes) != img_len: if len(img_bytes) != img_len:
@@ -641,7 +652,7 @@ class MMIPKGUnpacker:
file_path = os.path.join(output_dir, filename) file_path = os.path.join(output_dir, filename)
counter += 1 counter += 1
with open(file_path, 'wb') as img_file: with open(file_path, "wb") as img_file:
img_file.write(img_bytes) img_file.write(img_bytes)
# 准备数据库记录 # 准备数据库记录
@@ -700,6 +711,7 @@ class MMIPKGUnpacker:
except Exception as e: except Exception as e:
console.print(f"[red]导入 items 失败: {e}[/red]") console.print(f"[red]导入 items 失败: {e}[/red]")
import traceback import traceback
traceback.print_exc() traceback.print_exc()
return False return False
@@ -719,8 +731,9 @@ def print_menu():
console.print(" [2] [bold]导入表情包[/bold] (从 .mmipkg 文件导入到数据库)") console.print(" [2] [bold]导入表情包[/bold] (从 .mmipkg 文件导入到数据库)")
console.print(" [0] [bold]退出[/bold]") console.print(" [0] [bold]退出[/bold]")
console.print() console.print()
def get_input(prompt: str, default: Optional[str] = None,
choices: Optional[List[str]] = None) -> str:
def get_input(prompt: str, default: Optional[str] = None, choices: Optional[List[str]] = None) -> str:
"""获取用户输入""" """获取用户输入"""
if default: if default:
prompt = f"{prompt} (默认: {default})" prompt = f"{prompt} (默认: {default})"
@@ -760,9 +773,9 @@ def get_yes_no(prompt: str, default: bool = False) -> bool:
if not value: if not value:
return default return default
if value in ('y', 'yes', ''): if value in ("y", "yes", ""):
return True return True
elif value in ('n', 'no', ''): elif value in ("n", "no", ""):
return False return False
else: else:
console.print(" [yellow]⚠ 请输入 y/yes/是 或 n/no/否[/yellow]") console.print(" [yellow]⚠ 请输入 y/yes/是 或 n/no/否[/yellow]")
@@ -843,8 +856,8 @@ def interactive_export():
output_path = get_input(" 输出文件路径", default_filename) output_path = get_input(" 输出文件路径", default_filename)
# 确保有 .mmipkg 扩展名 # 确保有 .mmipkg 扩展名
if not output_path.endswith('.mmipkg'): if not output_path.endswith(".mmipkg"):
output_path += '.mmipkg' output_path += ".mmipkg"
# 获取包名称 # 获取包名称
default_pack_name = f"MaiBot表情包_{datetime.now().strftime('%Y%m%d')}" default_pack_name = f"MaiBot表情包_{datetime.now().strftime('%Y%m%d')}"
@@ -853,9 +866,7 @@ def interactive_export():
# 自定义 manifest # 自定义 manifest
console.print("\n[yellow]2. 包信息设置(可选)[/yellow]") console.print("\n[yellow]2. 包信息设置(可选)[/yellow]")
if get_yes_no(" 是否添加包的作者和介绍信息", False): if get_yes_no(" 是否添加包的作者和介绍信息", False):
custom_manifest = { custom_manifest = {"author": author} if (author := input(" 作者名称(可选): ").strip()) else {}
"author": author
} if (author := input(" 作者名称(可选): ").strip()) else {}
# 介绍信息 # 介绍信息
console.print(" 包介绍(限制 100 字以内):") console.print(" 包介绍(限制 100 字以内):")
@@ -888,9 +899,9 @@ def interactive_export():
console.print(" webp: 推荐,体积小且支持透明度") console.print(" webp: 推荐,体积小且支持透明度")
console.print(" jpeg: 最小体积,但不支持透明度") console.print(" jpeg: 最小体积,但不支持透明度")
console.print(" png: 无损,文件较大") console.print(" png: 无损,文件较大")
reencode = get_input(" 选择格式", "webp", ['webp', 'jpeg', 'png']) reencode = get_input(" 选择格式", "webp", ["webp", "jpeg", "png"])
quality = get_int(" 编码质量", 80, 1, 100) if reencode in ('webp', 'jpeg') else 80 quality = get_int(" 编码质量", 80, 1, 100) if reencode in ("webp", "jpeg") else 80
else: else:
reencode = None reencode = None
quality = 80 quality = 80
@@ -920,10 +931,7 @@ def interactive_export():
# 开始导出 # 开始导出
console.print("\n[cyan]开始导出...[/cyan]") console.print("\n[cyan]开始导出...[/cyan]")
packer = MMIPKGPacker( packer = MMIPKGPacker(
use_compression=use_compression, use_compression=use_compression, zstd_level=zstd_level, reencode=reencode, reencode_quality=quality
zstd_level=zstd_level,
reencode=reencode,
reencode_quality=quality
) )
success = packer.pack_from_db(output_path, pack_name, custom_manifest) success = packer.pack_from_db(output_path, pack_name, custom_manifest)
@@ -944,11 +952,11 @@ def interactive_import():
# 选择导入模式 # 选择导入模式
print_import_mode_selection() print_import_mode_selection()
import_mode = get_input("请选择", "1", ['1', '2']) import_mode = get_input("请选择", "1", ["1", "2"])
input_files = [] input_files = []
if import_mode == '1': if import_mode == "1":
# 自动扫描模式 # 自动扫描模式
import_dir = os.path.join(PROJECT_ROOT, "data", "import_emoji") import_dir = os.path.join(PROJECT_ROOT, "data", "import_emoji")
os.makedirs(import_dir, exist_ok=True) os.makedirs(import_dir, exist_ok=True)
@@ -957,7 +965,7 @@ def interactive_import():
# 查找所有 .mmipkg 文件 # 查找所有 .mmipkg 文件
for file in os.listdir(import_dir): for file in os.listdir(import_dir):
if file.endswith('.mmipkg'): if file.endswith(".mmipkg"):
file_path = os.path.join(import_dir, file) file_path = os.path.join(import_dir, file)
if os.path.isfile(file_path): if os.path.isfile(file_path):
input_files.append(file_path) input_files.append(file_path)
@@ -1032,7 +1040,7 @@ def interactive_import():
BarColumn(), BarColumn(),
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
TimeElapsedColumn(), TimeElapsedColumn(),
console=console console=console,
) as progress: ) as progress:
task = progress.add_task("[cyan]导入文件...", total=len(input_files)) task = progress.add_task("[cyan]导入文件...", total=len(input_files))
@@ -1044,10 +1052,7 @@ def interactive_import():
console.print(f"[bold]{'=' * 70}[/bold]") console.print(f"[bold]{'=' * 70}[/bold]")
success = unpacker.import_to_db( success = unpacker.import_to_db(
input_path, input_path, output_dir=output_dir, replace_existing=replace_existing, batch_size=batch_size
output_dir=output_dir,
replace_existing=replace_existing,
batch_size=batch_size
) )
if success: if success:
@@ -1076,16 +1081,16 @@ def main():
while True: while True:
print_menu() print_menu()
try: try:
choice = get_input("请选择", "1", ['0', '1', '2']) choice = get_input("请选择", "1", ["0", "1", "2"])
except KeyboardInterrupt: except KeyboardInterrupt:
console.print("\n[green]再见![/green]") console.print("\n[green]再见![/green]")
return 0 return 0
if choice == '0': if choice == "0":
console.print("\n[green]再见![/green]") console.print("\n[green]再见![/green]")
return 0 return 0
elif choice == '1': elif choice == "1":
try: try:
interactive_export() interactive_export()
except KeyboardInterrupt: except KeyboardInterrupt:
@@ -1093,6 +1098,7 @@ def main():
except Exception as e: except Exception as e:
console.print(f"\n[red]✗ 发生错误: {e}[/red]") console.print(f"\n[red]✗ 发生错误: {e}[/red]")
import traceback import traceback
traceback.print_exc() traceback.print_exc()
try: try:
@@ -1100,7 +1106,7 @@ def main():
except (KeyboardInterrupt, EOFError): except (KeyboardInterrupt, EOFError):
pass pass
elif choice == '2': elif choice == "2":
try: try:
interactive_import() interactive_import()
except KeyboardInterrupt: except KeyboardInterrupt:
@@ -1108,6 +1114,7 @@ def main():
except Exception as e: except Exception as e:
console.print(f"\n[red]✗ 发生错误: {e}[/red]") console.print(f"\n[red]✗ 发生错误: {e}[/red]")
import traceback import traceback
traceback.print_exc() traceback.print_exc()
try: try:
@@ -1121,5 +1128,5 @@ def main():
return 0 return 0
if __name__ == '__main__': if __name__ == "__main__":
sys.exit(main()) sys.exit(main())

View File

@@ -230,7 +230,7 @@ class HeartFChatting:
if (message.is_mentioned or message.is_at) and global_config.chat.mentioned_bot_reply: if (message.is_mentioned or message.is_at) and global_config.chat.mentioned_bot_reply:
mentioned_message = message mentioned_message = message
logger.info(f"{self.log_prefix} 当前talk_value: {global_config.chat.get_talk_value(self.stream_id)}") # logger.info(f"{self.log_prefix} 当前talk_value: {global_config.chat.get_talk_value(self.stream_id)}")
# *控制频率用 # *控制频率用
if mentioned_message: if mentioned_message:
@@ -334,7 +334,6 @@ class HeartFChatting:
self.consecutive_no_reply_count = 0 self.consecutive_no_reply_count = 0
reason = "" reason = ""
await database_api.store_action_info( await database_api.store_action_info(
chat_stream=self.chat_stream, chat_stream=self.chat_stream,
action_build_into_prompt=False, action_build_into_prompt=False,
@@ -411,7 +410,7 @@ class HeartFChatting:
# asyncio.create_task(self.chat_history_summarizer.process()) # asyncio.create_task(self.chat_history_summarizer.process())
cycle_timers, thinking_id = self.start_cycle() cycle_timers, thinking_id = self.start_cycle()
logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考") logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考(频率: {global_config.chat.get_talk_value(self.stream_id)})")
# 第一步:动作检查 # 第一步:动作检查
available_actions: Dict[str, ActionInfo] = {} available_actions: Dict[str, ActionInfo] = {}

View File

@@ -30,9 +30,11 @@ DATA_PATH = os.path.join(ROOT_PATH, "data")
qa_manager = None qa_manager = None
inspire_manager = None inspire_manager = None
def get_qa_manager(): def get_qa_manager():
return qa_manager return qa_manager
def lpmm_start_up(): # sourcery skip: extract-duplicate-method def lpmm_start_up(): # sourcery skip: extract-duplicate-method
# 检查LPMM知识库是否启用 # 检查LPMM知识库是否启用
if global_config.lpmm_knowledge.enable: if global_config.lpmm_knowledge.enable:

View File

@@ -92,14 +92,20 @@ class QAManager:
# 过滤阈值 # 过滤阈值
result = dyn_select_top_k(result, 0.5, 1.0) result = dyn_select_top_k(result, 0.5, 1.0)
for res in result: if global_config.debug.show_lpmm_paragraph:
raw_paragraph = self.embed_manager.paragraphs_embedding_store.store[res[0]].str for res in result:
logger.info(f"找到相关文段,相关系数:{res[1]:.8f}\n{raw_paragraph}\n\n") raw_paragraph = self.embed_manager.paragraphs_embedding_store.store[res[0]].str
logger.info(f"找到相关文段,相关系数:{res[1]:.8f}\n{raw_paragraph}\n\n")
return result, ppr_node_weights return result, ppr_node_weights
async def get_knowledge(self, question: str) -> Optional[str]: async def get_knowledge(self, question: str, limit: int = 5) -> Optional[str]:
"""获取知识""" """获取知识
Args:
question: 查询问题
limit: 返回的相关知识条数
"""
# 处理查询 # 处理查询
processed_result = await self.process_query(question) processed_result = await self.process_query(question)
if processed_result is not None: if processed_result is not None:
@@ -109,6 +115,8 @@ class QAManager:
logger.debug("知识库查询结果为空,可能是知识库中没有相关内容") logger.debug("知识库查询结果为空,可能是知识库中没有相关内容")
return None return None
limit = max(1, limit) if isinstance(limit, int) else 5
knowledge = [ knowledge = [
( (
self.embed_manager.paragraphs_embedding_store.store[res[0]].str, self.embed_manager.paragraphs_embedding_store.store[res[0]].str,
@@ -116,9 +124,17 @@ class QAManager:
) )
for res in query_res for res in query_res
] ]
found_knowledge = "\n".join(
[f"{i + 1}条知识:{k[0]}\n 该条知识对于问题的相关性:{k[1]}" for i, k in enumerate(knowledge)] # max_score = max([k[1] for k in knowledge]) if knowledge else None
) selected_knowledge = knowledge[:limit]
formatted_knowledge = [
f"{i + 1}条知识:{k[0]}\n 该条知识对于问题的相关性:{k[1]}" for i, k in enumerate(selected_knowledge)
]
# if max_score is not None:
# formatted_knowledge.insert(0, f"最高相关系数:{max_score}")
found_knowledge = "\n".join(formatted_knowledge)
if len(found_knowledge) > MAX_KNOWLEDGE_LENGTH: if len(found_knowledge) > MAX_KNOWLEDGE_LENGTH:
found_knowledge = found_knowledge[:MAX_KNOWLEDGE_LENGTH] + "\n" found_knowledge = found_knowledge[:MAX_KNOWLEDGE_LENGTH] + "\n"
return found_knowledge return found_knowledge

View File

@@ -7,7 +7,6 @@ from typing import Dict, Optional, Tuple, List, TYPE_CHECKING, Union
from rich.traceback import install from rich.traceback import install
from datetime import datetime from datetime import datetime
from json_repair import repair_json from json_repair import repair_json
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config from src.config.config import global_config, model_config
from src.common.logger import get_logger from src.common.logger import get_logger
@@ -164,6 +163,45 @@ class ActionPlanner:
return item[1] return item[1]
return None return None
def _replace_message_ids_with_text(
self, text: Optional[str], message_id_list: List[Tuple[str, "DatabaseMessages"]]
) -> Optional[str]:
"""将文本中的 m+数字 消息ID替换为原消息内容并添加双引号"""
if not text:
return text
id_to_message = {msg_id: msg for msg_id, msg in message_id_list}
# 匹配m后带2-4位数字前后不是字母数字下划线
pattern = r"(?<![A-Za-z0-9_])m\d{2,4}(?![A-Za-z0-9_])"
matches = re.findall(pattern, text)
if matches:
available_ids = set(id_to_message.keys())
found_ids = set(matches)
missing_ids = found_ids - available_ids
if missing_ids:
logger.info(f"{self.log_prefix}planner理由中引用的消息ID不在当前上下文中: {missing_ids}, 可用ID: {list(available_ids)[:10]}...")
logger.info(f"{self.log_prefix}planner理由替换: 找到{len(matches)}个消息ID引用其中{len(found_ids & available_ids)}个在上下文中")
def _replace(match: re.Match[str]) -> str:
msg_id = match.group(0)
message = id_to_message.get(msg_id)
if not message:
logger.warning(f"{self.log_prefix}planner理由引用 {msg_id} 未找到对应消息,保持原样")
return msg_id
msg_text = (message.processed_plain_text or message.display_message or "").strip()
if not msg_text:
logger.warning(f"{self.log_prefix}planner理由引用 {msg_id} 的消息内容为空,保持原样")
return msg_id
preview = msg_text if len(msg_text) <= 100 else f"{msg_text[:97]}..."
logger.info(f"{self.log_prefix}planner理由引用 {msg_id} -> 消息({preview}")
return f"消息({msg_text}"
return re.sub(pattern, _replace, text)
def _parse_single_action( def _parse_single_action(
self, self,
action_json: dict, action_json: dict,
@@ -176,7 +214,10 @@ class ActionPlanner:
try: try:
action = action_json.get("action", "no_reply") action = action_json.get("action", "no_reply")
reasoning = action_json.get("reason", "未提供原因") original_reasoning = action_json.get("reason", "未提供原因")
reasoning = self._replace_message_ids_with_text(original_reasoning, message_id_list)
if reasoning is None:
reasoning = original_reasoning
action_data = {key: value for key, value in action_json.items() if key not in ["action", "reason"]} action_data = {key: value for key, value in action_json.items() if key not in ["action", "reason"]}
# 非no_reply动作需要target_message_id # 非no_reply动作需要target_message_id
target_message = None target_message = None
@@ -573,9 +614,6 @@ class ActionPlanner:
# 调用LLM # 调用LLM
llm_content, (reasoning_content, _, _) = await self.planner_llm.generate_response_async(prompt=prompt) llm_content, (reasoning_content, _, _) = await self.planner_llm.generate_response_async(prompt=prompt)
# logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
# logger.info(f"{self.log_prefix}规划器原始响应: {llm_content}")
if global_config.debug.show_planner_prompt: if global_config.debug.show_planner_prompt:
logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}") logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
logger.info(f"{self.log_prefix}规划器原始响应: {llm_content}") logger.info(f"{self.log_prefix}规划器原始响应: {llm_content}")
@@ -604,6 +642,7 @@ class ActionPlanner:
if llm_content: if llm_content:
try: try:
json_objects, extracted_reasoning = self._extract_json_from_markdown(llm_content) json_objects, extracted_reasoning = self._extract_json_from_markdown(llm_content)
extracted_reasoning = self._replace_message_ids_with_text(extracted_reasoning, message_id_list) or ""
if json_objects: if json_objects:
logger.debug(f"{self.log_prefix}从响应中提取到{len(json_objects)}个JSON对象") logger.debug(f"{self.log_prefix}从响应中提取到{len(json_objects)}个JSON对象")
filtered_actions_list = list(filtered_actions.items()) filtered_actions_list = list(filtered_actions.items())

View File

@@ -226,7 +226,9 @@ class DefaultReplyer:
traceback.print_exc() traceback.print_exc()
return False, llm_response return False, llm_response
async def build_expression_habits(self, chat_history: str, target: str, reply_reason: str = "") -> Tuple[str, List[int]]: async def build_expression_habits(
self, chat_history: str, target: str, reply_reason: str = ""
) -> Tuple[str, List[int]]:
# sourcery skip: for-append-to-extend # sourcery skip: for-append-to-extend
"""构建表达习惯块 """构建表达习惯块

View File

@@ -241,7 +241,9 @@ class PrivateReplyer:
return f"{sender_relation}" return f"{sender_relation}"
async def build_expression_habits(self, chat_history: str, target: str, reply_reason: str = "") -> Tuple[str, List[int]]: async def build_expression_habits(
self, chat_history: str, target: str, reply_reason: str = ""
) -> Tuple[str, List[int]]:
# sourcery skip: for-append-to-extend # sourcery skip: for-append-to-extend
"""构建表达习惯块 """构建表达习惯块

View File

@@ -107,7 +107,7 @@ class ChatHistorySummarizer:
self.last_check_time = current_time self.last_check_time = current_time
return return
logger.info( logger.debug(
f"{self.log_prefix} 开始处理聊天概括,时间窗口: {self.last_check_time:.2f} -> {current_time:.2f}" f"{self.log_prefix} 开始处理聊天概括,时间窗口: {self.last_check_time:.2f} -> {current_time:.2f}"
) )
@@ -119,7 +119,7 @@ class ChatHistorySummarizer:
before_count = len(self.current_batch.messages) before_count = len(self.current_batch.messages)
self.current_batch.messages.extend(new_messages) self.current_batch.messages.extend(new_messages)
self.current_batch.end_time = current_time self.current_batch.end_time = current_time
logger.info(f"{self.log_prefix} 批次更新: {before_count} -> {len(self.current_batch.messages)} 条消息") logger.info(f"{self.log_prefix} 更新聊天话题: {before_count} -> {len(self.current_batch.messages)} 条消息")
else: else:
# 创建新批次 # 创建新批次
self.current_batch = MessageBatch( self.current_batch = MessageBatch(
@@ -127,7 +127,7 @@ class ChatHistorySummarizer:
start_time=new_messages[0].time if new_messages else current_time, start_time=new_messages[0].time if new_messages else current_time,
end_time=current_time, end_time=current_time,
) )
logger.info(f"{self.log_prefix} 新建批次: {len(new_messages)} 条消息") logger.info(f"{self.log_prefix} 新建聊天话题: {len(new_messages)} 条消息")
# 检查是否需要打包 # 检查是否需要打包
await self._check_and_package(current_time) await self._check_and_package(current_time)

View File

@@ -311,6 +311,8 @@ class Expression(BaseModel):
context = TextField(null=True) context = TextField(null=True)
up_content = TextField(null=True) up_content = TextField(null=True)
content_list = TextField(null=True)
count = IntegerField(default=1)
last_active_time = FloatField() last_active_time = FloatField()
chat_id = TextField(index=True) chat_id = TextField(index=True)
create_date = FloatField(null=True) # 创建日期,允许为空以兼容老数据 create_date = FloatField(null=True) # 创建日期,允许为空以兼容老数据

View File

@@ -19,6 +19,7 @@ PROJECT_ROOT = logger_file.parent.parent.parent.resolve()
# 全局handler实例避免重复创建 # 全局handler实例避免重复创建
_file_handler = None _file_handler = None
_console_handler = None _console_handler = None
_ws_handler = None
def get_file_handler(): def get_file_handler():
@@ -59,6 +60,35 @@ def get_console_handler():
return _console_handler return _console_handler
def get_ws_handler():
"""获取 WebSocket handler 单例"""
global _ws_handler
if _ws_handler is None:
_ws_handler = WebSocketLogHandler()
# WebSocket handler 推送所有级别的日志
_ws_handler.setLevel(logging.DEBUG)
return _ws_handler
def initialize_ws_handler(loop):
"""初始化 WebSocket handler 的事件循环
Args:
loop: asyncio 事件循环
"""
handler = get_ws_handler()
handler.set_loop(loop)
# 为 WebSocket handler 设置 JSON 格式化器(与文件格式相同)
handler.setFormatter(file_formatter)
# 添加到根日志记录器
root_logger = logging.getLogger()
if handler not in root_logger.handlers:
root_logger.addHandler(handler)
print("[日志系统] ✅ WebSocket 日志推送已启用")
class TimestampedFileHandler(logging.Handler): class TimestampedFileHandler(logging.Handler):
"""基于时间戳的文件处理器,简单的轮转份数限制""" """基于时间戳的文件处理器,简单的轮转份数限制"""
@@ -145,12 +175,76 @@ class TimestampedFileHandler(logging.Handler):
super().close() super().close()
class WebSocketLogHandler(logging.Handler):
"""WebSocket 日志处理器 - 将日志实时推送到前端"""
_log_counter = 0 # 类级别计数器,确保 ID 唯一性
def __init__(self, loop=None):
super().__init__()
self.loop = loop
self._initialized = False
def set_loop(self, loop):
"""设置事件循环"""
self.loop = loop
self._initialized = True
def emit(self, record):
"""发送日志到 WebSocket 客户端"""
if not self._initialized or self.loop is None:
return
try:
# 获取格式化后的消息
# 对于 structlog,formatted message 包含完整的日志信息
formatted_msg = self.format(record) if self.formatter else record.getMessage()
# 如果是 JSON 格式(文件格式化器),解析它
message = formatted_msg
try:
import json
log_dict = json.loads(formatted_msg)
message = log_dict.get("event", formatted_msg)
except (json.JSONDecodeError, ValueError):
# 不是 JSON,直接使用消息
message = formatted_msg
# 生成唯一 ID: 时间戳毫秒 + 自增计数器
WebSocketLogHandler._log_counter += 1
log_id = f"{int(record.created * 1000)}_{WebSocketLogHandler._log_counter}"
# 格式化日志数据
log_data = {
"id": log_id,
"timestamp": datetime.fromtimestamp(record.created).strftime("%Y-%m-%d %H:%M:%S"),
"level": record.levelname,
"module": record.name,
"message": message,
}
# 异步广播日志(不阻塞日志记录)
try:
import asyncio
from src.webui.logs_ws import broadcast_log
asyncio.run_coroutine_threadsafe(broadcast_log(log_data), self.loop)
except Exception:
# WebSocket 推送失败不影响日志记录
pass
except Exception:
# 不要让 WebSocket 错误影响日志系统
self.handleError(record)
# 旧的轮转文件处理器已移除,现在使用基于时间戳的处理器 # 旧的轮转文件处理器已移除,现在使用基于时间戳的处理器
def close_handlers(): def close_handlers():
"""安全关闭所有handler""" """安全关闭所有handler"""
global _file_handler, _console_handler global _file_handler, _console_handler, _ws_handler
if _file_handler: if _file_handler:
_file_handler.close() _file_handler.close()
@@ -160,6 +254,10 @@ def close_handlers():
_console_handler.close() _console_handler.close()
_console_handler = None _console_handler = None
if _ws_handler:
_ws_handler.close()
_ws_handler = None
def remove_duplicate_handlers(): # sourcery skip: for-append-to-extend, list-comprehension def remove_duplicate_handlers(): # sourcery skip: for-append-to-extend, list-comprehension
"""移除重复的handler特别是文件handler""" """移除重复的handler特别是文件handler"""
@@ -843,8 +941,8 @@ def start_log_cleanup_task():
def shutdown_logging(): def shutdown_logging():
"""优雅关闭日志系统,释放所有文件句柄""" """优雅关闭日志系统,释放所有文件句柄"""
logger = get_logger("logger") # 先输出到控制台,避免日志系统关闭后无法输出
logger.info("正在关闭日志系统...") print("[logger] 正在关闭日志系统...")
# 关闭所有handler # 关闭所有handler
root_logger = logging.getLogger() root_logger = logging.getLogger()
@@ -865,4 +963,5 @@ def shutdown_logging():
handler.close() handler.close()
logger_obj.removeHandler(handler) logger_obj.removeHandler(handler)
logger.info("日志系统已关闭") # 使用 print 而不是 logger因为 logger 已经关闭
print("[logger] 日志系统已关闭")

View File

@@ -1,7 +1,7 @@
from fastapi import FastAPI, APIRouter from fastapi import FastAPI, APIRouter
from fastapi.middleware.cors import CORSMiddleware # 新增导入
from typing import Optional from typing import Optional
from uvicorn import Config, Server as UvicornServer from uvicorn import Config, Server as UvicornServer
import asyncio
import os import os
from rich.traceback import install from rich.traceback import install
@@ -16,21 +16,6 @@ class Server:
self._server: Optional[UvicornServer] = None self._server: Optional[UvicornServer] = None
self.set_address(host, port) self.set_address(host, port)
# 配置 CORS
origins = [
"http://localhost:7999", # 允许的前端源
"http://127.0.0.1:7999",
# 在生产环境中,您应该添加实际的前端域名
]
self.app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True, # 是否支持 cookie
allow_methods=["*"], # 允许所有 HTTP 方法
allow_headers=["*"], # 允许所有 HTTP 请求头
)
def register_router(self, router: APIRouter, prefix: str = ""): def register_router(self, router: APIRouter, prefix: str = ""):
"""注册路由 """注册路由
@@ -82,8 +67,17 @@ class Server:
"""安全关闭服务器""" """安全关闭服务器"""
if self._server: if self._server:
self._server.should_exit = True self._server.should_exit = True
await self._server.shutdown() try:
self._server = None # 添加 3 秒超时,避免 shutdown 永久挂起
await asyncio.wait_for(self._server.shutdown(), timeout=3.0)
except asyncio.TimeoutError:
# 超时就强制标记为 None让垃圾回收处理
pass
except Exception:
# 忽略其他异常
pass
finally:
self._server = None
def get_app(self) -> FastAPI: def get_app(self) -> FastAPI:
"""获取 FastAPI 实例""" """获取 FastAPI 实例"""

View File

@@ -56,7 +56,7 @@ TEMPLATE_DIR = os.path.join(PROJECT_ROOT, "template")
# 考虑到实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码 # 考虑到实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
# 对该字段的更新请严格参照语义化版本规范https://semver.org/lang/zh-CN/ # 对该字段的更新请严格参照语义化版本规范https://semver.org/lang/zh-CN/
MMC_VERSION = "0.11.2" MMC_VERSION = "0.11.5"
def get_key_comment(toml_table, key): def get_key_comment(toml_table, key):

View File

@@ -581,9 +581,15 @@ class DebugConfig(ConfigBase):
show_jargon_prompt: bool = False show_jargon_prompt: bool = False
"""是否显示jargon相关提示词""" """是否显示jargon相关提示词"""
show_memory_prompt: bool = False
"""是否显示记忆检索相关prompt"""
show_planner_prompt: bool = False show_planner_prompt: bool = False
"""是否显示planner相关提示词""" """是否显示planner相关提示词"""
show_lpmm_paragraph: bool = False
"""是否显示lpmm找到的相关文段日志"""
@dataclass @dataclass
class ExperimentalConfig(ConfigBase): class ExperimentalConfig(ConfigBase):

View File

@@ -61,6 +61,37 @@ def format_create_date(timestamp: float) -> str:
return "未知时间" return "未知时间"
def _compute_weights(population: List[Dict]) -> List[float]:
"""
根据表达的count计算权重范围限定在1~3之间。
count越高权重越高但最多为基础权重的3倍。
"""
if not population:
return []
counts = []
for item in population:
count = item.get("count", 1)
try:
count_value = float(count)
except (TypeError, ValueError):
count_value = 1.0
counts.append(max(count_value, 0.0))
min_count = min(counts)
max_count = max(counts)
if max_count == min_count:
return [1.0 for _ in counts]
weights = []
for count_value in counts:
# 线性映射到[1,3]区间
normalized = (count_value - min_count) / (max_count - min_count)
weights.append(1.0 + normalized * 2.0) # 1~3
return weights
def weighted_sample(population: List[Dict], k: int) -> List[Dict]: def weighted_sample(population: List[Dict], k: int) -> List[Dict]:
""" """
随机抽样函数 随机抽样函数
@@ -78,15 +109,24 @@ def weighted_sample(population: List[Dict], k: int) -> List[Dict]:
if len(population) <= k: if len(population) <= k:
return population.copy() return population.copy()
# 使用随机抽样 selected: List[Dict] = []
selected = []
population_copy = population.copy() population_copy = population.copy()
for _ in range(k): for _ in range(min(k, len(population_copy))):
if not population_copy: weights = _compute_weights(population_copy)
break total_weight = sum(weights)
# 随机选择一个元素 if total_weight <= 0:
idx = random.randint(0, len(population_copy) - 1) # 回退到均匀随机
selected.append(population_copy.pop(idx)) idx = random.randint(0, len(population_copy) - 1)
selected.append(population_copy.pop(idx))
continue
threshold = random.uniform(0, total_weight)
cumulative = 0.0
for idx, weight in enumerate(weights):
cumulative += weight
if threshold <= cumulative:
selected.append(population_copy.pop(idx))
break
return selected return selected

View File

@@ -77,6 +77,9 @@ class ExpressionLearner:
self.express_learn_model: LLMRequest = LLMRequest( self.express_learn_model: LLMRequest = LLMRequest(
model_set=model_config.model_task_config.utils, request_type="expression.learner" model_set=model_config.model_task_config.utils, request_type="expression.learner"
) )
self.summary_model: LLMRequest = LLMRequest(
model_set=model_config.model_task_config.utils_small, request_type="expression.summary"
)
self.embedding_model: LLMRequest = LLMRequest( self.embedding_model: LLMRequest = LLMRequest(
model_set=model_config.model_task_config.embedding, request_type="expression.embedding" model_set=model_config.model_task_config.embedding, request_type="expression.embedding"
) )
@@ -91,8 +94,8 @@ class ExpressionLearner:
_, self.enable_learning, self.learning_intensity = global_config.expression.get_expression_config_for_chat( _, self.enable_learning, self.learning_intensity = global_config.expression.get_expression_config_for_chat(
self.chat_id self.chat_id
) )
self.min_messages_for_learning = 30 / self.learning_intensity # 触发学习所需的最少消息数 self.min_messages_for_learning = 15 / self.learning_intensity # 触发学习所需的最少消息数
self.min_learning_interval = 300 / self.learning_intensity self.min_learning_interval = 120 / self.learning_intensity
def should_trigger_learning(self) -> bool: def should_trigger_learning(self) -> bool:
""" """
@@ -186,25 +189,13 @@ class ExpressionLearner:
context, context,
up_content, up_content,
) in learnt_expressions: ) in learnt_expressions:
# 查找是否已存在相似表达方式 await self._upsert_expression_record(
query = Expression.select().where( situation=situation,
(Expression.chat_id == self.chat_id) & (Expression.situation == situation) & (Expression.style == style) style=style,
context=context,
up_content=up_content,
current_time=current_time,
) )
if query.exists():
# 表达方式完全相同,只更新时间戳
expr_obj = query.get()
expr_obj.last_active_time = current_time
expr_obj.save()
else:
Expression.create(
situation=situation,
style=style,
last_active_time=current_time,
chat_id=self.chat_id,
create_date=current_time, # 手动设置创建日期
context=context,
up_content=up_content,
)
return learnt_expressions return learnt_expressions
@@ -362,6 +353,10 @@ class ExpressionLearner:
logger.error(f"学习表达方式失败,模型生成出错: {e}") logger.error(f"学习表达方式失败,模型生成出错: {e}")
return None return None
expressions: List[Tuple[str, str]] = self.parse_expression_response(response) expressions: List[Tuple[str, str]] = self.parse_expression_response(response)
expressions = self._filter_self_reference_styles(expressions)
if not expressions:
logger.info("过滤后没有可用的表达方式style 与机器人名称重复)")
return None
# logger.debug(f"学习{type_str}的response: {response}") # logger.debug(f"学习{type_str}的response: {response}")
# 对表达方式溯源 # 对表达方式溯源
@@ -433,6 +428,149 @@ class ExpressionLearner:
expressions.append((situation, style)) expressions.append((situation, style))
return expressions return expressions
def _filter_self_reference_styles(self, expressions: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
"""
过滤掉style与机器人名称/昵称重复的表达
"""
banned_names = set()
bot_nickname = (global_config.bot.nickname or "").strip()
if bot_nickname:
banned_names.add(bot_nickname)
alias_names = global_config.bot.alias_names or []
for alias in alias_names:
alias = alias.strip()
if alias:
banned_names.add(alias)
banned_casefold = {name.casefold() for name in banned_names if name}
filtered: List[Tuple[str, str]] = []
removed_count = 0
for situation, style in expressions:
normalized_style = (style or "").strip()
if normalized_style and normalized_style.casefold() not in banned_casefold:
filtered.append((situation, style))
else:
removed_count += 1
if removed_count:
logger.debug(f"已过滤 {removed_count} 条style与机器人名称重复的表达方式")
return filtered
async def _upsert_expression_record(
self,
situation: str,
style: str,
context: str,
up_content: str,
current_time: float,
) -> None:
expr_obj = Expression.select().where((Expression.chat_id == self.chat_id) & (Expression.style == style)).first()
if expr_obj:
await self._update_existing_expression(
expr_obj=expr_obj,
situation=situation,
context=context,
up_content=up_content,
current_time=current_time,
)
return
await self._create_expression_record(
situation=situation,
style=style,
context=context,
up_content=up_content,
current_time=current_time,
)
async def _create_expression_record(
self,
situation: str,
style: str,
context: str,
up_content: str,
current_time: float,
) -> None:
content_list = [situation]
formatted_situation = await self._compose_situation_text(content_list, 1, situation)
Expression.create(
situation=formatted_situation,
style=style,
content_list=json.dumps(content_list, ensure_ascii=False),
count=1,
last_active_time=current_time,
chat_id=self.chat_id,
create_date=current_time,
context=context,
up_content=up_content,
)
async def _update_existing_expression(
self,
expr_obj: Expression,
situation: str,
context: str,
up_content: str,
current_time: float,
) -> None:
content_list = self._parse_content_list(expr_obj.content_list)
content_list.append(situation)
expr_obj.content_list = json.dumps(content_list, ensure_ascii=False)
expr_obj.count = (expr_obj.count or 0) + 1
expr_obj.last_active_time = current_time
expr_obj.context = context
expr_obj.up_content = up_content
new_situation = await self._compose_situation_text(
content_list=content_list,
count=expr_obj.count,
fallback=expr_obj.situation,
)
expr_obj.situation = new_situation
expr_obj.save()
def _parse_content_list(self, stored_list: Optional[str]) -> List[str]:
if not stored_list:
return []
try:
data = json.loads(stored_list)
except json.JSONDecodeError:
return []
return [str(item) for item in data if isinstance(item, str)] if isinstance(data, list) else []
async def _compose_situation_text(self, content_list: List[str], count: int, fallback: str = "") -> str:
sanitized = [c.strip() for c in content_list if c.strip()]
summary = await self._summarize_situations(sanitized)
if summary:
return summary
return "/".join(sanitized) if sanitized else fallback
async def _summarize_situations(self, situations: List[str]) -> Optional[str]:
if not situations:
return None
prompt = (
"请阅读以下多个聊天情境描述,并将它们概括成一句简短的话,"
"长度不超过20个字保留共同特点\n"
f"{chr(10).join(f'- {s}' for s in situations[-10:])}\n只输出概括内容。"
)
try:
summary, _ = await self.summary_model.generate_response_async(prompt, temperature=0.2)
summary = summary.strip()
if summary:
return summary
except Exception as e:
logger.error(f"概括表达情境失败: {e}")
return None
def _build_bare_lines(self, messages: List) -> List[Tuple[int, str]]: def _build_bare_lines(self, messages: List) -> List[Tuple[int, str]]:
""" """
为每条消息构建精简文本列表,保留到原消息索引的映射 为每条消息构建精简文本列表,保留到原消息索引的映射

View File

@@ -42,8 +42,6 @@ def init_prompt():
Prompt(expression_evaluation_prompt, "expression_evaluation_prompt") Prompt(expression_evaluation_prompt, "expression_evaluation_prompt")
class ExpressionSelector: class ExpressionSelector:
def __init__(self): def __init__(self):
self.llm_model = LLMRequest( self.llm_model = LLMRequest(
@@ -139,6 +137,7 @@ class ExpressionSelector:
"last_active_time": expr.last_active_time, "last_active_time": expr.last_active_time,
"source_id": expr.chat_id, "source_id": expr.chat_id,
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time, "create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
"count": expr.count if getattr(expr, "count", None) is not None else 1,
} }
for expr in style_query for expr in style_query
] ]
@@ -261,7 +260,6 @@ class ExpressionSelector:
# 4. 调用LLM # 4. 调用LLM
content, (reasoning_content, model_name, _) = await self.llm_model.generate_response_async(prompt=prompt) content, (reasoning_content, model_name, _) = await self.llm_model.generate_response_async(prompt=prompt)
# print(prompt) # print(prompt)
if not content: if not content:

View File

@@ -23,6 +23,26 @@ from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
logger = get_logger("jargon") logger = get_logger("jargon")
def _contains_bot_self_name(content: str) -> bool:
"""
判断词条是否包含机器人的昵称或别名
"""
if not content:
return False
bot_config = getattr(global_config, "bot", None)
if not bot_config:
return False
target = content.strip().lower()
nickname = str(getattr(bot_config, "nickname", "") or "").strip().lower()
alias_names = [str(alias or "").strip().lower() for alias in getattr(bot_config, "alias_names", []) or []]
candidates = [name for name in [nickname, *alias_names] if name]
return any(name in target for name in candidates if target)
def _init_prompt() -> None: def _init_prompt() -> None:
prompt_str = """ prompt_str = """
**聊天内容其中的SELF是你自己的发言** **聊天内容其中的SELF是你自己的发言**
@@ -165,9 +185,7 @@ async def _enrich_raw_content_if_needed(
# 获取该消息的前三条消息 # 获取该消息的前三条消息
try: try:
previous_messages = get_raw_msg_before_timestamp_with_chat( previous_messages = get_raw_msg_before_timestamp_with_chat(
chat_id=chat_id, chat_id=chat_id, timestamp=target_message.time, limit=3
timestamp=target_message.time,
limit=3
) )
if previous_messages: if previous_messages:
@@ -222,7 +240,7 @@ def _should_infer_meaning(jargon_obj: Jargon) -> bool:
last_inference = jargon_obj.last_inference_count or 0 last_inference = jargon_obj.last_inference_count or 0
# 阈值列表3,6, 10, 20, 40, 60, 100 # 阈值列表3,6, 10, 20, 40, 60, 100
thresholds = [3,6, 10, 20, 40, 60, 100] thresholds = [3, 6, 10, 20, 40, 60, 100]
if count < thresholds[0]: if count < thresholds[0]:
return False return False
@@ -251,7 +269,7 @@ class JargonMiner:
self.chat_id = chat_id self.chat_id = chat_id
self.last_learning_time: float = time.time() self.last_learning_time: float = time.time()
# 频率控制,可按需调整 # 频率控制,可按需调整
self.min_messages_for_learning: int = 15 self.min_messages_for_learning: int = 10
self.min_learning_interval: float = 20 self.min_learning_interval: float = 20
self.llm = LLMRequest( self.llm = LLMRequest(
@@ -288,7 +306,9 @@ class JargonMiner:
raw_content_list = [] raw_content_list = []
if raw_content_str: if raw_content_str:
try: try:
raw_content_list = json.loads(raw_content_str) if isinstance(raw_content_str, str) else raw_content_str raw_content_list = (
json.loads(raw_content_str) if isinstance(raw_content_str, str) else raw_content_str
)
if not isinstance(raw_content_list, list): if not isinstance(raw_content_list, list):
raw_content_list = [raw_content_list] if raw_content_list else [] raw_content_list = [raw_content_list] if raw_content_list else []
except (json.JSONDecodeError, TypeError): except (json.JSONDecodeError, TypeError):
@@ -337,7 +357,6 @@ class JargonMiner:
jargon_obj.save() jargon_obj.save()
return return
# 步骤2: 仅基于content推断 # 步骤2: 仅基于content推断
prompt2 = await global_prompt_manager.format_prompt( prompt2 = await global_prompt_manager.format_prompt(
"jargon_inference_content_only_prompt", "jargon_inference_content_only_prompt",
@@ -365,11 +384,10 @@ class JargonMiner:
logger.error(f"jargon {content} 推断2解析失败: {e}") logger.error(f"jargon {content} 推断2解析失败: {e}")
return return
# logger.info(f"jargon {content} 推断2提示词: {prompt2}")
logger.info(f"jargon {content} 推断2提示词: {prompt2}") # logger.info(f"jargon {content} 推断2结果: {response2}")
logger.info(f"jargon {content} 推断2结果: {response2}") # logger.info(f"jargon {content} 推断1提示词: {prompt1}")
logger.info(f"jargon {content} 推断1提示词: {prompt1}") # logger.info(f"jargon {content} 推断1结果: {response1}")
logger.info(f"jargon {content} 推断1结果: {response1}")
if global_config.debug.show_jargon_prompt: if global_config.debug.show_jargon_prompt:
logger.info(f"jargon {content} 推断2提示词: {prompt2}") logger.info(f"jargon {content} 推断2提示词: {prompt2}")
@@ -434,7 +452,9 @@ class JargonMiner:
jargon_obj.is_complete = True jargon_obj.is_complete = True
jargon_obj.save() jargon_obj.save()
logger.info(f"jargon {content} 推断完成: is_jargon={is_jargon}, meaning={jargon_obj.meaning}, last_inference_count={jargon_obj.last_inference_count}, is_complete={jargon_obj.is_complete}") logger.debug(
f"jargon {content} 推断完成: is_jargon={is_jargon}, meaning={jargon_obj.meaning}, last_inference_count={jargon_obj.last_inference_count}, is_complete={jargon_obj.is_complete}"
)
# 固定输出推断结果,格式化为可读形式 # 固定输出推断结果,格式化为可读形式
if is_jargon: if is_jargon:
@@ -442,7 +462,7 @@ class JargonMiner:
meaning = jargon_obj.meaning or "无详细说明" meaning = jargon_obj.meaning or "无详细说明"
is_global = jargon_obj.is_global is_global = jargon_obj.is_global
if is_global: if is_global:
logger.info(f"[通用黑话]{content}的含义是 {meaning}") logger.info(f"[黑话]{content}的含义是 {meaning}")
else: else:
logger.info(f"[{self.stream_name}]{content}的含义是 {meaning}") logger.info(f"[{self.stream_name}]{content}的含义是 {meaning}")
else: else:
@@ -452,6 +472,7 @@ class JargonMiner:
except Exception as e: except Exception as e:
logger.error(f"jargon推断失败: {e}") logger.error(f"jargon推断失败: {e}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()
def should_trigger(self) -> bool: def should_trigger(self) -> bool:
@@ -545,10 +566,10 @@ class JargonMiner:
raw_content_list = [raw_content_str] raw_content_list = [raw_content_str]
if content and raw_content_list: if content and raw_content_list:
entries.append({ if _contains_bot_self_name(content):
"content": content, logger.debug(f"解析阶段跳过包含机器人昵称/别名的词条: {content}")
"raw_content": raw_content_list continue
}) entries.append({"content": content, "raw_content": raw_content_list})
except Exception as e: except Exception as e:
logger.error(f"解析jargon JSON失败: {e}; 原始: {response}") logger.error(f"解析jargon JSON失败: {e}; 原始: {response}")
return return
@@ -586,19 +607,10 @@ class JargonMiner:
# 根据all_global配置决定查询逻辑 # 根据all_global配置决定查询逻辑
if global_config.jargon.all_global: if global_config.jargon.all_global:
# 开启all_global无视chat_id查询所有content匹配的记录所有记录都是全局的 # 开启all_global无视chat_id查询所有content匹配的记录所有记录都是全局的
query = ( query = Jargon.select().where(Jargon.content == content)
Jargon.select()
.where(Jargon.content == content)
)
else: else:
# 关闭all_global只查询chat_id匹配的记录不考虑is_global # 关闭all_global只查询chat_id匹配的记录不考虑is_global
query = ( query = Jargon.select().where((Jargon.chat_id == self.chat_id) & (Jargon.content == content))
Jargon.select()
.where(
(Jargon.chat_id == self.chat_id) &
(Jargon.content == content)
)
)
if query.exists(): if query.exists():
obj = query.get() obj = query.get()
@@ -611,7 +623,9 @@ class JargonMiner:
existing_raw_content = [] existing_raw_content = []
if obj.raw_content: if obj.raw_content:
try: try:
existing_raw_content = json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content existing_raw_content = (
json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content
)
if not isinstance(existing_raw_content, list): if not isinstance(existing_raw_content, list):
existing_raw_content = [existing_raw_content] if existing_raw_content else [] existing_raw_content = [existing_raw_content] if existing_raw_content else []
except (json.JSONDecodeError, TypeError): except (json.JSONDecodeError, TypeError):
@@ -650,7 +664,7 @@ class JargonMiner:
raw_content=json.dumps(raw_content_list, ensure_ascii=False), raw_content=json.dumps(raw_content_list, ensure_ascii=False),
chat_id=self.chat_id, chat_id=self.chat_id,
is_global=is_global_new, is_global=is_global_new,
count=1 count=1,
) )
saved += 1 saved += 1
except Exception as e: except Exception as e:
@@ -694,11 +708,7 @@ async def extract_and_store_jargon(chat_id: str) -> None:
def search_jargon( def search_jargon(
keyword: str, keyword: str, chat_id: Optional[str] = None, limit: int = 10, case_sensitive: bool = False, fuzzy: bool = True
chat_id: Optional[str] = None,
limit: int = 10,
case_sensitive: bool = False,
fuzzy: bool = True
) -> List[Dict[str, str]]: ) -> List[Dict[str, str]]:
""" """
搜索jargon支持大小写不敏感和模糊搜索 搜索jargon支持大小写不敏感和模糊搜索
@@ -721,10 +731,7 @@ def search_jargon(
keyword = keyword.strip() keyword = keyword.strip()
# 构建查询 # 构建查询
query = Jargon.select( query = Jargon.select(Jargon.content, Jargon.meaning)
Jargon.content,
Jargon.meaning
)
# 构建搜索条件 # 构建搜索条件
if case_sensitive: if case_sensitive:
@@ -734,7 +741,7 @@ def search_jargon(
search_condition = Jargon.content.contains(keyword) search_condition = Jargon.content.contains(keyword)
else: else:
# 精确匹配 # 精确匹配
search_condition = (Jargon.content == keyword) search_condition = Jargon.content == keyword
else: else:
# 大小写不敏感 # 大小写不敏感
if fuzzy: if fuzzy:
@@ -742,7 +749,7 @@ def search_jargon(
search_condition = fn.LOWER(Jargon.content).contains(keyword.lower()) search_condition = fn.LOWER(Jargon.content).contains(keyword.lower())
else: else:
# 精确匹配使用LOWER函数 # 精确匹配使用LOWER函数
search_condition = (fn.LOWER(Jargon.content) == keyword.lower()) search_condition = fn.LOWER(Jargon.content) == keyword.lower()
query = query.where(search_condition) query = query.where(search_condition)
@@ -753,14 +760,10 @@ def search_jargon(
else: else:
# 关闭all_global如果提供了chat_id优先搜索该聊天或global的jargon # 关闭all_global如果提供了chat_id优先搜索该聊天或global的jargon
if chat_id: if chat_id:
query = query.where( query = query.where((Jargon.chat_id == chat_id) | Jargon.is_global)
(Jargon.chat_id == chat_id) | Jargon.is_global
)
# 只返回有meaning的记录 # 只返回有meaning的记录
query = query.where( query = query.where((Jargon.meaning.is_null(False)) & (Jargon.meaning != ""))
(Jargon.meaning.is_null(False)) & (Jargon.meaning != "")
)
# 按count降序排序优先返回出现频率高的 # 按count降序排序优先返回出现频率高的
query = query.order_by(Jargon.count.desc()) query = query.order_by(Jargon.count.desc())
@@ -771,10 +774,7 @@ def search_jargon(
# 执行查询并返回结果 # 执行查询并返回结果
results = [] results = []
for jargon in query: for jargon in query:
results.append({ results.append({"content": jargon.content or "", "meaning": jargon.meaning or ""})
"content": jargon.content or "",
"meaning": jargon.meaning or ""
})
return results return results
@@ -814,10 +814,7 @@ async def store_jargon_from_answer(jargon_keyword: str, answer: str, chat_id: st
if global_config.jargon.all_global: if global_config.jargon.all_global:
query = Jargon.select().where(Jargon.content == jargon_keyword) query = Jargon.select().where(Jargon.content == jargon_keyword)
else: else:
query = Jargon.select().where( query = Jargon.select().where((Jargon.chat_id == chat_id) & (Jargon.content == jargon_keyword))
(Jargon.chat_id == chat_id) &
(Jargon.content == jargon_keyword)
)
if query.exists(): if query.exists():
# 更新现有记录 # 更新现有记录
@@ -828,7 +825,9 @@ async def store_jargon_from_answer(jargon_keyword: str, answer: str, chat_id: st
existing_raw_content = [] existing_raw_content = []
if obj.raw_content: if obj.raw_content:
try: try:
existing_raw_content = json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content existing_raw_content = (
json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content
)
if not isinstance(existing_raw_content, list): if not isinstance(existing_raw_content, list):
existing_raw_content = [existing_raw_content] if existing_raw_content else [] existing_raw_content = [existing_raw_content] if existing_raw_content else []
except (json.JSONDecodeError, TypeError): except (json.JSONDecodeError, TypeError):
@@ -851,11 +850,9 @@ async def store_jargon_from_answer(jargon_keyword: str, answer: str, chat_id: st
raw_content=json.dumps([raw_content], ensure_ascii=False), raw_content=json.dumps([raw_content], ensure_ascii=False),
chat_id=chat_id, chat_id=chat_id,
is_global=is_global_new, is_global=is_global_new,
count=1 count=1,
) )
logger.info(f"创建新jargon记录: {jargon_keyword}") logger.info(f"创建新jargon记录: {jargon_keyword}")
except Exception as e: except Exception as e:
logger.error(f"存储jargon失败: {e}") logger.error(f"存储jargon失败: {e}")

View File

@@ -116,9 +116,7 @@ class MessageBuilder:
构建消息对象 构建消息对象
:return: Message对象 :return: Message对象
""" """
if len(self.__content) == 0 and not ( if len(self.__content) == 0 and not (self.__role == RoleType.Assistant and self.__tool_calls):
self.__role == RoleType.Assistant and self.__tool_calls
):
raise ValueError("内容不能为空") raise ValueError("内容不能为空")
if self.__role == RoleType.Tool and self.__tool_call_id is None: if self.__role == RoleType.Tool and self.__tool_call_id is None:
raise ValueError("Tool角色的工具调用ID不能为空") raise ValueError("Tool角色的工具调用ID不能为空")

View File

@@ -36,25 +36,15 @@ class MainSystem:
# 使用消息API替代直接的FastAPI实例 # 使用消息API替代直接的FastAPI实例
self.app: MessageServer = get_global_api() self.app: MessageServer = get_global_api()
self.server: Server = get_global_server() self.server: Server = get_global_server()
self.webui_server = None # 独立的 WebUI 服务器
# 注册 WebUI API 路由 # 设置独立的 WebUI 服务器
self._register_webui_routes() self._setup_webui_server()
# 设置 WebUI开发/生产模式) def _setup_webui_server(self):
self._setup_webui() """设置独立的 WebUI 服务器"""
def _register_webui_routes(self):
"""注册 WebUI API 路由"""
try:
from src.webui.routes import router as webui_router
self.server.register_router(webui_router)
logger.info("WebUI API 路由已注册")
except Exception as e:
logger.warning(f"注册 WebUI API 路由失败: {e}")
def _setup_webui(self):
"""设置 WebUI根据环境变量决定模式"""
import os import os
webui_enabled = os.getenv("WEBUI_ENABLED", "false").lower() == "true" webui_enabled = os.getenv("WEBUI_ENABLED", "false").lower() == "true"
if not webui_enabled: if not webui_enabled:
logger.info("WebUI 已禁用") logger.info("WebUI 已禁用")
@@ -63,10 +53,22 @@ class MainSystem:
webui_mode = os.getenv("WEBUI_MODE", "production").lower() webui_mode = os.getenv("WEBUI_MODE", "production").lower()
try: try:
from src.webui.manager import setup_webui from src.webui.webui_server import get_webui_server
setup_webui(mode=webui_mode)
self.webui_server = get_webui_server()
if webui_mode == "development":
logger.info("📝 WebUI 开发模式已启用")
logger.info("🌐 后端 API 将运行在 http://0.0.0.0:8001")
logger.info("💡 请手动启动前端开发服务器: cd MaiBot-Dashboard && bun dev")
logger.info("💡 前端将运行在 http://localhost:7999")
else:
logger.info("✅ WebUI 生产模式已启用")
logger.info(f"🌐 WebUI 将运行在 http://0.0.0.0:8001")
logger.info("💡 请确保已构建前端: cd MaiBot-Dashboard && bun run build")
except Exception as e: except Exception as e:
logger.error(f"设置 WebUI 失败: {e}") logger.error(f"❌ 初始化 WebUI 服务器失败: {e}")
async def initialize(self): async def initialize(self):
"""初始化系统组件""" """初始化系统组件"""
@@ -161,6 +163,10 @@ class MainSystem:
self.server.run(), self.server.run(),
] ]
# 如果 WebUI 服务器已初始化,添加到任务列表
if self.webui_server:
tasks.append(self.webui_server.start())
await asyncio.gather(*tasks) await asyncio.gather(*tasks)
# async def forget_memory_task(self): # async def forget_memory_task(self):

View File

@@ -11,10 +11,38 @@ from src.plugin_system.apis import llm_api
from src.common.database.database_model import ThinkingBack from src.common.database.database_model import ThinkingBack
from json_repair import repair_json from json_repair import repair_json
from src.memory_system.retrieval_tools import get_tool_registry, init_all_tools from src.memory_system.retrieval_tools import get_tool_registry, init_all_tools
from src.memory_system.retrieval_tools.query_lpmm_knowledge import query_lpmm_knowledge
from src.llm_models.payload_content.message import MessageBuilder, RoleType, Message from src.llm_models.payload_content.message import MessageBuilder, RoleType, Message
logger = get_logger("memory_retrieval") logger = get_logger("memory_retrieval")
THINKING_BACK_NOT_FOUND_RETENTION_SECONDS = 36000 # 未找到答案记录保留时长
THINKING_BACK_CLEANUP_INTERVAL_SECONDS = 3000 # 清理频率
_last_not_found_cleanup_ts: float = 0.0
def _cleanup_stale_not_found_thinking_back() -> None:
"""定期清理过期的未找到答案记录"""
global _last_not_found_cleanup_ts
now = time.time()
if now - _last_not_found_cleanup_ts < THINKING_BACK_CLEANUP_INTERVAL_SECONDS:
return
threshold_time = now - THINKING_BACK_NOT_FOUND_RETENTION_SECONDS
try:
deleted_rows = (
ThinkingBack.delete()
.where((ThinkingBack.found_answer == 0) & (ThinkingBack.update_time < threshold_time))
.execute()
)
if deleted_rows:
logger.info(f"清理过期的未找到答案thinking_back记录 {deleted_rows}")
_last_not_found_cleanup_ts = now
except Exception as e:
logger.error(f"清理未找到答案的thinking_back记录失败: {e}")
def init_memory_retrieval_prompt(): def init_memory_retrieval_prompt():
"""初始化记忆检索相关的 prompt 模板和工具""" """初始化记忆检索相关的 prompt 模板和工具"""
# 首先注册所有工具 # 首先注册所有工具
@@ -34,20 +62,17 @@ def init_memory_retrieval_prompt():
1. 对话中是否提到了过去发生的事情、人物、事件或信息 1. 对话中是否提到了过去发生的事情、人物、事件或信息
2. 是否有需要回忆的内容(比如"之前说过""上次""以前"等) 2. 是否有需要回忆的内容(比如"之前说过""上次""以前"等)
3. 是否有需要查找历史信息的问题 3. 是否有需要查找历史信息的问题
4. 是否需要查找某人的信息person: 如果对话中提到人名、昵称、用户ID等需要查询该人物的详细信息 4. 是否有问题可以搜集信息帮助你聊天
5. 是否有问题可以搜集信息帮助你聊天 5. 对话中是否包含黑话、俚语、缩写等可能需要查询的概念
6. 对话中是否包含黑话、俚语、缩写等可能需要查询的概念
重要提示: 重要提示:
- **每次只能提出一个问题**,选择最需要查询的关键问题 - **每次只能提出一个问题**,选择最需要查询的关键问题
- 如果"最近已查询的问题和结果"中已经包含了类似的问题,请避免重复生成相同或相似的问题 - 如果"最近已查询的问题和结果"中已经包含了类似的问题并得到了答案,请避免重复生成相同或相似的问题,不需要重复查询
- 如果之前已经查询过某个问题但未找到答案,可以尝试用不同的方式提问或更具体的问题 - 如果之前已经查询过某个问题但未找到答案,可以尝试用不同的方式提问或更具体的问题
- 如果之前已经查询过某个问题并找到了答案,可以直接参考已有结果,不需要重复查询
如果你认为需要从记忆中检索信息来回答,请: 如果你认为需要从记忆中检索信息来回答,请:
1. 识别对话中可能需要查询的概念(黑话/俚语/缩写/专有名词等关键词),放入"concepts"字段 1. 识别对话中可能需要查询的概念(黑话/俚语/缩写/专有名词等关键词),放入"concepts"字段
2. 识别对话中提到的人物名称(人名、昵称等),放入"person"字段 2. 根据上下文提出**一个**最关键的问题来帮助你回复目标消息,放入"questions"字段
3. 然后根据上下文提出**一个**最关键的问题来帮助你回复目标消息,放入"questions"字段
问题格式示例: 问题格式示例:
- "xxx在前几天干了什么" - "xxx在前几天干了什么"
@@ -55,17 +80,11 @@ def init_memory_retrieval_prompt():
- "xxxx和xxx的关系是什么" - "xxxx和xxx的关系是什么"
- "xxx在某个时间点发生了什么" - "xxx在某个时间点发生了什么"
请输出JSON格式包含三个字段
- "concepts": 需要检索的概念列表(字符串数组),如果不需要检索概念则输出空数组[]
- "person": 需要查询的人物名称列表(字符串数组),如果不需要查询人物信息则输出空数组[]
- "questions": 问题数组(字符串数组),如果不需要检索记忆则输出空数组[],如果需要检索则只输出包含一个问题的数组
输出格式示例(需要检索时): 输出格式示例(需要检索时):
```json ```json
{{ {{
"concepts": ["AAA", "BBB", "CCC"], "concepts": ["AAA", "BBB", "CCC"], #需要检索的概念列表(字符串数组),如果不需要检索概念则输出空数组[]
"person": ["张三", "李四"], "questions": ["张三在前几天干了什么"] #问题数组(字符串数组),如果不需要检索记忆则输出空数组[],如果需要检索则只输出包含一个问题的数组
"questions": ["张三在前几天干了什么"]
}} }}
``` ```
@@ -73,7 +92,6 @@ def init_memory_retrieval_prompt():
```json ```json
{{ {{
"concepts": [], "concepts": [],
"person": [],
"questions": [] "questions": []
}} }}
``` ```
@@ -85,10 +103,8 @@ def init_memory_retrieval_prompt():
# 第二步ReAct Agent prompt使用function calling要求先思考再行动 # 第二步ReAct Agent prompt使用function calling要求先思考再行动
Prompt( Prompt(
""" """你的名字是{bot_name}。现在是{time_now}
你的名字是{bot_name}。现在是{time_now}
你正在参与聊天,你需要搜集信息来回答问题,帮助你参与聊天。 你正在参与聊天,你需要搜集信息来回答问题,帮助你参与聊天。
你需要通过思考(Think)、行动(Action)、观察(Observation)的循环来回答问题。
**重要限制:** **重要限制:**
- 最大查询轮数:{max_iterations}轮(当前第{current_iteration}轮,剩余{remaining_iterations}轮) - 最大查询轮数:{max_iterations}轮(当前第{current_iteration}轮,剩余{remaining_iterations}轮)
@@ -101,76 +117,32 @@ def init_memory_retrieval_prompt():
{collected_info} {collected_info}
**执行步骤:** **执行步骤:**
**第一步思考Think** **第一步思考Think**
在思考中分析: 在思考中分析:
- 当前信息是否足够回答问题? - 当前信息是否足够回答问题?
- **如果信息足够且能找到明确答案**在思考中直接给出答案格式为found_answer(answer="你的答案内容") - **如果信息足够且能找到明确答案**在思考中直接给出答案格式为found_answer(answer="你的答案内容")
- **如果信息不足或无法找到答案**在思考中给出not_enough_info(reason="信息不足或无法找到答案的原因") - **如果需要尝试搜集更多信息,进一步调用工具,进入第二步行动环节
- 如果还需要继续查询,说明最需要查询什么,并输出为纯文本说明 - **如果已有信息不足或无法找到答案**在思考中给出not_enough_info(reason="信息不足或无法找到答案的原因")
**第二步行动Action** **第二步行动Action**
根据思考结果立即行动: - 如果涉及过往事件,可以使用聊天记录查询工具查询过往事件
- 如果思考中已给出found_answer → 无需调用工具,直接结束 - 如果涉及概念可以用jargon查询或根据关键词检索聊天记录
- 如果思考中已给出not_enough_info → 无需调用工具,直接结束 - 如果涉及人物,可以使用人物信息查询工具查询人物信息
- 如果信息不足且需要继续查询 → 调用相应工具查询(可并行调用多个工具) - 如果不确定查询类别也可以使用lpmm知识库查询
- 如果信息不足且需要继续查询,说明最需要查询什么,并输出为纯文本说明,然后调用相应工具查询(可并行调用多个工具)
**重要规则:** **重要规则:**
- **只有在检索到明确、有关的信息并得出答案时才使用found_answer** - **只有在检索到明确、有关的信息并得出答案时才使用found_answer**
- **如果信息不足、无法确定、找不到相关信息必须使用not_enough_info不要使用found_answer** - **如果信息不足、无法确定、找不到相关信息必须使用not_enough_info不要使用found_answer**
- 答案必须在思考中给出,格式为 found_answer(answer="...") 或 not_enough_info(reason="..."),不要调用工具。 - 答案必须在思考中给出,格式为 found_answer(answer="...") 或 not_enough_info(reason="...")
""",
name="memory_retrieval_react_prompt",
)
# 第二步ReAct Agent prompt使用function calling要求先思考再行动
Prompt(
"""
你的名字是{bot_name}。现在是{time_now}
你正在参与聊天,你需要搜集信息来回答问题,帮助你参与聊天。
你需要通过思考(Think)、行动(Action)、观察(Observation)的循环来回答问题。
**重要限制:**
- 最大查询轮数:{max_iterations}轮(当前第{current_iteration}轮,剩余{remaining_iterations}轮)
- 必须尽快得出答案,避免不必要的查询
- 思考要简短,直接切入要点
- 必须严格使用检索到的信息回答问题,不要编造信息
当前问题:{question}
**执行步骤:**
**第一步思考Think**
在思考中分析:
- 当前信息是否足够回答问题?
- **如果信息足够且能找到明确答案**在思考中直接给出答案格式为found_answer(answer="你的答案内容")
- **如果信息不足或无法找到答案**在思考中给出not_enough_info(reason="信息不足或无法找到答案的原因")
- 如果还需要继续查询,说明最需要查询什么,并输出为纯文本说明
**第二步行动Action**
根据思考结果立即行动:
- 如果思考中已给出found_answer → 无需调用工具,直接结束
- 如果思考中已给出not_enough_info → 无需调用工具,直接结束
- 如果信息不足且需要继续查询 → 调用相应工具查询(可并行调用多个工具)
**重要规则:**
- **只有在检索到明确、具体的答案时才使用found_answer**
- **如果信息不足、无法确定、找不到相关信息必须使用not_enough_info不要使用found_answer**
- 答案必须在思考中给出,格式为 found_answer(answer="...") 或 not_enough_info(reason="..."),不要调用工具。
""", """,
name="memory_retrieval_react_prompt_head", name="memory_retrieval_react_prompt_head",
) )
# 额外如果最后一轮迭代ReAct Agent prompt使用function calling要求先思考再行动 # 额外如果最后一轮迭代ReAct Agent prompt使用function calling要求先思考再行动
Prompt( Prompt(
""" """你的名字是{bot_name}。现在是{time_now}
的名字是{bot_name}。现在是{time_now} 正在参与聊天,你需要根据搜集到的信息判断问题是否可以回答问题
你正在参与聊天,你需要搜集信息来回答问题,帮助你参与聊天。
**重要限制:**
- 你已经经过几轮查询,尝试了信息搜集,现在你需要总结信息,选择回答问题或判断问题无法回答
- 思考要简短,直接切入要点
- 必须严格使用检索到的信息回答问题,不要编造信息
当前问题:{question} 当前问题:{question}
已收集的信息: 已收集的信息:
@@ -183,6 +155,9 @@ def init_memory_retrieval_prompt():
- **如果信息不足或无法找到答案**在思考中给出not_enough_info(reason="信息不足或无法找到答案的原因") - **如果信息不足或无法找到答案**在思考中给出not_enough_info(reason="信息不足或无法找到答案的原因")
**重要规则:** **重要规则:**
- 你已经经过几轮查询,尝试了信息搜集,现在你需要总结信息,选择回答问题或判断问题无法回答
- 必须严格使用检索到的信息回答问题,不要编造信息
- 答案必须精简,不要过多解释
- **只有在检索到明确、具体的答案时才使用found_answer** - **只有在检索到明确、具体的答案时才使用found_answer**
- **如果信息不足、无法确定、找不到相关信息必须使用not_enough_info不要使用found_answer** - **如果信息不足、无法确定、找不到相关信息必须使用not_enough_info不要使用found_answer**
- 答案必须给出,格式为 found_answer(answer="...") 或 not_enough_info(reason="...")。 - 答案必须给出,格式为 found_answer(answer="...") 或 not_enough_info(reason="...")。
@@ -244,10 +219,7 @@ def _parse_react_response(response: str) -> Optional[Dict[str, Any]]:
return None return None
async def _retrieve_concepts_with_jargon( async def _retrieve_concepts_with_jargon(concepts: List[str], chat_id: str) -> str:
concepts: List[str],
chat_id: str
) -> str:
"""对概念列表进行jargon检索 """对概念列表进行jargon检索
Args: Args:
@@ -269,25 +241,13 @@ async def _retrieve_concepts_with_jargon(
continue continue
# 先尝试精确匹配 # 先尝试精确匹配
jargon_results = search_jargon( jargon_results = search_jargon(keyword=concept, chat_id=chat_id, limit=10, case_sensitive=False, fuzzy=False)
keyword=concept,
chat_id=chat_id,
limit=10,
case_sensitive=False,
fuzzy=False
)
is_fuzzy_match = False is_fuzzy_match = False
# 如果精确匹配未找到,尝试模糊搜索 # 如果精确匹配未找到,尝试模糊搜索
if not jargon_results: if not jargon_results:
jargon_results = search_jargon( jargon_results = search_jargon(keyword=concept, chat_id=chat_id, limit=10, case_sensitive=False, fuzzy=True)
keyword=concept,
chat_id=chat_id,
limit=10,
case_sensitive=False,
fuzzy=True
)
is_fuzzy_match = True is_fuzzy_match = True
if jargon_results: if jargon_results:
@@ -312,8 +272,7 @@ async def _retrieve_concepts_with_jargon(
results.append("".join(output_parts) if len(output_parts) > 1 else output_parts[0]) results.append("".join(output_parts) if len(output_parts) > 1 else output_parts[0])
logger.info(f"在jargon库中找到匹配精确匹配: {concept},找到{len(jargon_results)}条结果") logger.info(f"在jargon库中找到匹配精确匹配: {concept},找到{len(jargon_results)}条结果")
else: else:
# 未找到 # 未找到,不返回占位信息,只记录日志
results.append(f"未在jargon库中找到'{concept}'的解释")
logger.info(f"在jargon库中未找到匹配: {concept}") logger.info(f"在jargon库中未找到匹配: {concept}")
if results: if results:
@@ -321,53 +280,8 @@ async def _retrieve_concepts_with_jargon(
return "" return ""
async def _retrieve_persons_info(
persons: List[str],
chat_id: str
) -> str:
"""对人物列表进行信息检索
Args:
persons: 人物名称列表
chat_id: 聊天ID
Returns:
str: 检索结果字符串
"""
if not persons:
return ""
from src.memory_system.retrieval_tools.query_person_info import query_person_info
results = []
for person in persons:
person = person.strip()
if not person:
continue
try:
person_info = await query_person_info(person)
if person_info and "未找到" not in person_info:
results.append(f"{person}\n{person_info}")
logger.info(f"查询到人物信息: {person}")
else:
results.append(f"未找到人物'{person}'的信息")
logger.info(f"未找到人物信息: {person}")
except Exception as e:
logger.error(f"查询人物信息失败: {person}, 错误: {e}")
results.append(f"查询人物'{person}'信息时发生错误: {str(e)}")
if results:
return "【人物信息检索结果】\n" + "\n\n".join(results) + "\n"
return ""
async def _react_agent_solve_question( async def _react_agent_solve_question(
question: str, question: str, chat_id: str, max_iterations: int = 5, timeout: float = 30.0, initial_info: str = ""
chat_id: str,
max_iterations: int = 5,
timeout: float = 30.0,
initial_info: str = ""
) -> Tuple[bool, str, List[Dict[str, Any]], bool]: ) -> Tuple[bool, str, List[Dict[str, Any]], bool]:
"""使用ReAct架构的Agent来解决问题 """使用ReAct架构的Agent来解决问题
@@ -408,36 +322,45 @@ async def _react_agent_solve_question(
remaining_iterations = max_iterations - current_iteration remaining_iterations = max_iterations - current_iteration
is_final_iteration = current_iteration >= max_iterations is_final_iteration = current_iteration >= max_iterations
# 构建prompt不再需要工具文本描述
prompt_type = "memory_retrieval_react_prompt"
if is_final_iteration: if is_final_iteration:
prompt_type = "memory_retrieval_react_final_prompt" # 最后一次迭代,使用最终prompt
tool_definitions = [] tool_definitions = []
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代,问题: {question}|可用工具数量: 0最后一次迭代不提供工具调用") logger.info(
f"ReAct Agent 第 {iteration + 1} 次迭代,问题: {question}|可用工具数量: 0最后一次迭代不提供工具调用"
)
prompt = await global_prompt_manager.format_prompt(
"memory_retrieval_react_final_prompt",
bot_name=bot_name,
time_now=time_now,
question=question,
collected_info=collected_info if collected_info else "暂无信息",
current_iteration=current_iteration,
remaining_iterations=remaining_iterations,
max_iterations=max_iterations,
)
if global_config.debug.show_memory_prompt:
logger.info(f"ReAct Agent 第 {iteration + 1} 次Prompt: {prompt}")
success, response, reasoning_content, model_name, tool_calls = await llm_api.generate_with_model_with_tools(
prompt,
model_config=model_config.model_task_config.tool_use,
tool_options=tool_definitions,
request_type="memory.react",
)
else: else:
# 非最终迭代使用head_prompt
tool_definitions = tool_registry.get_tool_definitions() tool_definitions = tool_registry.get_tool_definitions()
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代,问题: {question}|可用工具数量: {len(tool_definitions)}") logger.info(
f"ReAct Agent 第 {iteration + 1} 次迭代,问题: {question}|可用工具数量: {len(tool_definitions)}"
)
prompt = await global_prompt_manager.format_prompt(
prompt_type,
bot_name=bot_name,
time_now=time_now,
question=question,
collected_info=collected_info if collected_info else "暂无信息",
current_iteration=current_iteration,
remaining_iterations=remaining_iterations,
max_iterations=max_iterations,
)
if not is_final_iteration:
head_prompt = await global_prompt_manager.format_prompt( head_prompt = await global_prompt_manager.format_prompt(
"memory_retrieval_react_prompt_head", "memory_retrieval_react_prompt_head",
bot_name=bot_name, bot_name=bot_name,
time_now=time_now, time_now=time_now,
question=question, question=question,
collected_info=collected_info if collected_info else "",
current_iteration=current_iteration, current_iteration=current_iteration,
remaining_iterations=remaining_iterations, remaining_iterations=remaining_iterations,
max_iterations=max_iterations, max_iterations=max_iterations,
@@ -447,7 +370,6 @@ async def _react_agent_solve_question(
_client, _client,
*, *,
_head_prompt: str = head_prompt, _head_prompt: str = head_prompt,
_prompt: str = prompt,
_conversation_messages: List[Message] = conversation_messages, _conversation_messages: List[Message] = conversation_messages,
) -> List[Message]: ) -> List[Message]:
messages: List[Message] = [] messages: List[Message] = []
@@ -455,33 +377,66 @@ async def _react_agent_solve_question(
system_builder = MessageBuilder() system_builder = MessageBuilder()
system_builder.set_role(RoleType.System) system_builder.set_role(RoleType.System)
system_builder.add_text_content(_head_prompt) system_builder.add_text_content(_head_prompt)
if _prompt.strip():
system_builder.add_text_content(f"\n{_prompt}")
messages.append(system_builder.build()) messages.append(system_builder.build())
messages.extend(_conversation_messages) messages.extend(_conversation_messages)
# for msg in messages: if global_config.debug.show_memory_prompt:
# print(msg) # 优化日志展示 - 合并所有消息到一条日志
log_lines = []
for idx, msg in enumerate(messages, 1):
role_name = msg.role.value if hasattr(msg.role, "value") else str(msg.role)
# 处理内容 - 显示完整内容,不截断
if isinstance(msg.content, str):
full_content = msg.content
content_type = "文本"
elif isinstance(msg.content, list):
text_parts = [item for item in msg.content if isinstance(item, str)]
image_count = len([item for item in msg.content if isinstance(item, tuple)])
full_content = "".join(text_parts) if text_parts else ""
content_type = f"混合({len(text_parts)}段文本, {image_count}张图片)"
else:
full_content = str(msg.content)
content_type = "未知"
# 构建单条消息的日志信息
msg_info = f"\n[消息 {idx}] 角色: {role_name} 内容类型: {content_type}\n========================================"
if full_content:
msg_info += f"\n{full_content}"
if msg.tool_calls:
msg_info += f"\n 工具调用: {len(msg.tool_calls)}"
for tool_call in msg.tool_calls:
msg_info += f"\n - {tool_call}"
if msg.tool_call_id:
msg_info += f"\n 工具调用ID: {msg.tool_call_id}"
log_lines.append(msg_info)
# 合并所有消息为一条日志输出
logger.info(f"消息列表 (共{len(messages)}条):{''.join(log_lines)}")
return messages return messages
success, response, reasoning_content, model_name, tool_calls = await llm_api.generate_with_model_with_tools_by_message_factory( (
success,
response,
reasoning_content,
model_name,
tool_calls,
) = await llm_api.generate_with_model_with_tools_by_message_factory(
message_factory, message_factory,
model_config=model_config.model_task_config.tool_use, model_config=model_config.model_task_config.tool_use,
tool_options=tool_definitions, tool_options=tool_definitions,
request_type="memory.react", request_type="memory.react",
) )
else:
logger.info(f"ReAct Agent 第 {iteration + 1} 次Prompt: {prompt}")
success, response, reasoning_content, model_name, tool_calls = await llm_api.generate_with_model_with_tools(
prompt,
model_config=model_config.model_task_config.tool_use,
tool_options=tool_definitions,
request_type="memory.react",
)
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 模型: {model_name} ,调用工具数量: {len(tool_calls) if tool_calls else 0} ,调用工具响应: {response}") logger.info(
f"ReAct Agent 第 {iteration + 1} 次迭代 模型: {model_name} ,调用工具数量: {len(tool_calls) if tool_calls else 0} ,调用工具响应: {response}"
)
if not success: if not success:
logger.error(f"ReAct Agent LLM调用失败: {response}") logger.error(f"ReAct Agent LLM调用失败: {response}")
@@ -502,12 +457,7 @@ async def _react_agent_solve_question(
assistant_message = assistant_builder.build() assistant_message = assistant_builder.build()
# 记录思考步骤 # 记录思考步骤
step = { step = {"iteration": iteration + 1, "thought": response, "actions": [], "observations": []}
"iteration": iteration + 1,
"thought": response,
"actions": [],
"observations": []
}
# 优先从思考内容中提取found_answer或not_enough_info # 优先从思考内容中提取found_answer或not_enough_info
def extract_quoted_content(text, func_name, param_name): def extract_quoted_content(text, func_name, param_name):
@@ -532,14 +482,14 @@ async def _react_agent_solve_question(
return None return None
# 查找参数名和等号 # 查找参数名和等号
param_pattern = f'{param_name}=' param_pattern = f"{param_name}="
param_pos = text_lower.find(param_pattern, func_pos) param_pos = text_lower.find(param_pattern, func_pos)
if param_pos == -1: if param_pos == -1:
return None return None
# 跳过参数名、等号和空白 # 跳过参数名、等号和空白
start_pos = param_pos + len(param_pattern) start_pos = param_pos + len(param_pattern)
while start_pos < len(text) and text[start_pos] in ' \t\n': while start_pos < len(text) and text[start_pos] in " \t\n":
start_pos += 1 start_pos += 1
if start_pos >= len(text): if start_pos >= len(text):
@@ -555,13 +505,13 @@ async def _react_agent_solve_question(
while end_pos < len(text): while end_pos < len(text):
if text[end_pos] == quote_char: if text[end_pos] == quote_char:
# 检查是否是转义的引号 # 检查是否是转义的引号
if end_pos > start_pos + 1 and text[end_pos - 1] == '\\': if end_pos > start_pos + 1 and text[end_pos - 1] == "\\":
end_pos += 1 end_pos += 1
continue continue
# 找到匹配的引号 # 找到匹配的引号
content = text[start_pos + 1:end_pos] content = text[start_pos + 1 : end_pos]
# 处理转义字符 # 处理转义字符
content = content.replace('\\"', '"').replace("\\'", "'").replace('\\\\', '\\') content = content.replace('\\"', '"').replace("\\'", "'").replace("\\\\", "\\")
return content return content
end_pos += 1 end_pos += 1
@@ -573,27 +523,35 @@ async def _react_agent_solve_question(
# 只检查responseLLM的直接输出内容不检查reasoning_content # 只检查responseLLM的直接输出内容不检查reasoning_content
if response: if response:
found_answer_content = extract_quoted_content(response, 'found_answer', 'answer') found_answer_content = extract_quoted_content(response, "found_answer", "answer")
if not found_answer_content: if not found_answer_content:
not_enough_info_reason = extract_quoted_content(response, 'not_enough_info', 'reason') not_enough_info_reason = extract_quoted_content(response, "not_enough_info", "reason")
# 如果从输出内容中找到了答案,直接返回 # 如果从输出内容中找到了答案,直接返回
if found_answer_content: if found_answer_content:
step["actions"].append({"action_type": "found_answer", "action_params": {"answer": found_answer_content}}) step["actions"].append({"action_type": "found_answer", "action_params": {"answer": found_answer_content}})
step["observations"] = ["从LLM输出内容中检测到found_answer"] step["observations"] = ["从LLM输出内容中检测到found_answer"]
thinking_steps.append(step) thinking_steps.append(step)
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 从LLM输出内容中检测到found_answer: {found_answer_content[:100]}...") logger.info(
f"ReAct Agent 第 {iteration + 1} 次迭代 从LLM输出内容中检测到found_answer: {found_answer_content[:100]}..."
)
return True, found_answer_content, thinking_steps, False return True, found_answer_content, thinking_steps, False
if not_enough_info_reason: if not_enough_info_reason:
step["actions"].append({"action_type": "not_enough_info", "action_params": {"reason": not_enough_info_reason}}) step["actions"].append(
{"action_type": "not_enough_info", "action_params": {"reason": not_enough_info_reason}}
)
step["observations"] = ["从LLM输出内容中检测到not_enough_info"] step["observations"] = ["从LLM输出内容中检测到not_enough_info"]
thinking_steps.append(step) thinking_steps.append(step)
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 从LLM输出内容中检测到not_enough_info: {not_enough_info_reason[:100]}...") logger.info(
f"ReAct Agent 第 {iteration + 1} 次迭代 从LLM输出内容中检测到not_enough_info: {not_enough_info_reason[:100]}..."
)
return False, not_enough_info_reason, thinking_steps, False return False, not_enough_info_reason, thinking_steps, False
if is_final_iteration: if is_final_iteration:
step["actions"].append({"action_type": "not_enough_info", "action_params": {"reason": "已到达最后一次迭代,无法找到答案"}}) step["actions"].append(
{"action_type": "not_enough_info", "action_params": {"reason": "已到达最后一次迭代,无法找到答案"}}
)
step["observations"] = ["已到达最后一次迭代,无法找到答案"] step["observations"] = ["已到达最后一次迭代,无法找到答案"]
thinking_steps.append(step) thinking_steps.append(step)
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 已到达最后一次迭代,无法找到答案") logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 已到达最后一次迭代,无法找到答案")
@@ -633,7 +591,9 @@ async def _react_agent_solve_question(
tool_name = tool_call.func_name tool_name = tool_call.func_name
tool_args = tool_call.args or {} tool_args = tool_call.args or {}
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 工具调用 {i+1}/{len(tool_calls)}: {tool_name}({tool_args})") logger.info(
f"ReAct Agent 第 {iteration + 1} 次迭代 工具调用 {i + 1}/{len(tool_calls)}: {tool_name}({tool_args})"
)
# 普通工具调用 # 普通工具调用
tool = tool_registry.get_tool(tool_name) tool = tool_registry.get_tool(tool_name)
@@ -643,6 +603,7 @@ async def _react_agent_solve_question(
# 如果工具函数签名需要chat_id添加它 # 如果工具函数签名需要chat_id添加它
import inspect import inspect
sig = inspect.signature(tool.execute_func) sig = inspect.signature(tool.execute_func)
if "chat_id" in sig.parameters: if "chat_id" in sig.parameters:
tool_params["chat_id"] = chat_id tool_params["chat_id"] = chat_id
@@ -662,7 +623,7 @@ async def _react_agent_solve_question(
step["actions"].append({"action_type": tool_name, "action_params": tool_args}) step["actions"].append({"action_type": tool_name, "action_params": tool_args})
else: else:
error_msg = f"未知的工具类型: {tool_name}" error_msg = f"未知的工具类型: {tool_name}"
logger.warning(f"ReAct Agent 第 {iteration + 1} 次迭代 工具 {i+1}/{len(tool_calls)} {error_msg}") logger.warning(f"ReAct Agent 第 {iteration + 1} 次迭代 工具 {i + 1}/{len(tool_calls)} {error_msg}")
tool_tasks.append(asyncio.create_task(asyncio.sleep(0, result=f"查询{tool_name}失败: {error_msg}"))) tool_tasks.append(asyncio.create_task(asyncio.sleep(0, result=f"查询{tool_name}失败: {error_msg}")))
# 并行执行所有工具 # 并行执行所有工具
@@ -673,7 +634,7 @@ async def _react_agent_solve_question(
for i, (tool_call_item, observation) in enumerate(zip(tool_calls, observations, strict=False)): for i, (tool_call_item, observation) in enumerate(zip(tool_calls, observations, strict=False)):
if isinstance(observation, Exception): if isinstance(observation, Exception):
observation = f"工具执行异常: {str(observation)}" observation = f"工具执行异常: {str(observation)}"
logger.error(f"ReAct Agent 第 {iteration + 1} 次迭代 工具 {i+1} 执行异常: {observation}") logger.error(f"ReAct Agent 第 {iteration + 1} 次迭代 工具 {i + 1} 执行异常: {observation}")
observation_text = observation if isinstance(observation, str) else str(observation) observation_text = observation if isinstance(observation, str) else str(observation)
step["observations"].append(observation_text) step["observations"].append(observation_text)
@@ -692,7 +653,9 @@ async def _react_agent_solve_question(
# 迭代超时应该直接视为not_enough_info而不是使用已有信息 # 迭代超时应该直接视为not_enough_info而不是使用已有信息
# 只有Agent明确返回found_answer时才认为找到了答案 # 只有Agent明确返回found_answer时才认为找到了答案
if collected_info: if collected_info:
logger.warning(f"ReAct Agent达到最大迭代次数或超时但未明确返回found_answer。已收集信息: {collected_info[:100]}...") logger.warning(
f"ReAct Agent达到最大迭代次数或超时但未明确返回found_answer。已收集信息: {collected_info[:100]}..."
)
if is_timeout: if is_timeout:
logger.warning("ReAct Agent超时直接视为not_enough_info") logger.warning("ReAct Agent超时直接视为not_enough_info")
else: else:
@@ -717,10 +680,7 @@ def _get_recent_query_history(chat_id: str, time_window_seconds: float = 300.0)
# 查询最近时间窗口内的记录,按更新时间倒序 # 查询最近时间窗口内的记录,按更新时间倒序
records = ( records = (
ThinkingBack.select() ThinkingBack.select()
.where( .where((ThinkingBack.chat_id == chat_id) & (ThinkingBack.update_time >= start_time))
(ThinkingBack.chat_id == chat_id) &
(ThinkingBack.update_time >= start_time)
)
.order_by(ThinkingBack.update_time.desc()) .order_by(ThinkingBack.update_time.desc())
.limit(5) # 最多返回5条最近的记录 .limit(5) # 最多返回5条最近的记录
) )
@@ -772,9 +732,9 @@ def _get_cached_memories(chat_id: str, time_window_seconds: float = 300.0) -> Li
records = ( records = (
ThinkingBack.select() ThinkingBack.select()
.where( .where(
(ThinkingBack.chat_id == chat_id) & (ThinkingBack.chat_id == chat_id)
(ThinkingBack.update_time >= start_time) & & (ThinkingBack.update_time >= start_time)
(ThinkingBack.found_answer == 1) & (ThinkingBack.found_answer == 1)
) )
.order_by(ThinkingBack.update_time.desc()) .order_by(ThinkingBack.update_time.desc())
.limit(5) # 最多返回5条最近的记录 .limit(5) # 最多返回5条最近的记录
@@ -812,10 +772,7 @@ def _query_thinking_back(chat_id: str, question: str) -> Optional[Tuple[bool, st
# 按更新时间倒序,获取最新的记录 # 按更新时间倒序,获取最新的记录
records = ( records = (
ThinkingBack.select() ThinkingBack.select()
.where( .where((ThinkingBack.chat_id == chat_id) & (ThinkingBack.question == question))
(ThinkingBack.chat_id == chat_id) &
(ThinkingBack.question == question)
)
.order_by(ThinkingBack.update_time.desc()) .order_by(ThinkingBack.update_time.desc())
.limit(1) .limit(1)
) )
@@ -894,6 +851,7 @@ async def _analyze_question_answer(question: str, answer: str, chat_id: str) ->
jargon_keyword = analysis_result.get("jargon_keyword", "").strip() jargon_keyword = analysis_result.get("jargon_keyword", "").strip()
if jargon_keyword: if jargon_keyword:
from src.jargon.jargon_miner import store_jargon_from_answer from src.jargon.jargon_miner import store_jargon_from_answer
await store_jargon_from_answer(jargon_keyword, answer, chat_id) await store_jargon_from_answer(jargon_keyword, answer, chat_id)
else: else:
logger.warning(f"分析为黑话但未提取到关键词,问题: {question[:50]}...") logger.warning(f"分析为黑话但未提取到关键词,问题: {question[:50]}...")
@@ -919,14 +877,8 @@ async def _analyze_question_answer(question: str, answer: str, chat_id: str) ->
logger.error(f"分析问题和答案时发生异常: {e}") logger.error(f"分析问题和答案时发生异常: {e}")
def _store_thinking_back( def _store_thinking_back(
chat_id: str, chat_id: str, question: str, context: str, found_answer: bool, answer: str, thinking_steps: List[Dict[str, Any]]
question: str,
context: str,
found_answer: bool,
answer: str,
thinking_steps: List[Dict[str, Any]]
) -> None: ) -> None:
"""存储或更新思考过程到数据库(如果已存在则更新,否则创建) """存储或更新思考过程到数据库(如果已存在则更新,否则创建)
@@ -944,10 +896,7 @@ def _store_thinking_back(
# 先查询是否已存在相同chat_id和问题的记录 # 先查询是否已存在相同chat_id和问题的记录
existing = ( existing = (
ThinkingBack.select() ThinkingBack.select()
.where( .where((ThinkingBack.chat_id == chat_id) & (ThinkingBack.question == question))
(ThinkingBack.chat_id == chat_id) &
(ThinkingBack.question == question)
)
.order_by(ThinkingBack.update_time.desc()) .order_by(ThinkingBack.update_time.desc())
.limit(1) .limit(1)
) )
@@ -972,19 +921,14 @@ def _store_thinking_back(
answer=answer, answer=answer,
thinking_steps=json.dumps(thinking_steps, ensure_ascii=False), thinking_steps=json.dumps(thinking_steps, ensure_ascii=False),
create_time=now, create_time=now,
update_time=now update_time=now,
) )
logger.info(f"已创建思考过程到数据库,问题: {question[:50]}...") logger.info(f"已创建思考过程到数据库,问题: {question[:50]}...")
except Exception as e: except Exception as e:
logger.error(f"存储思考过程失败: {e}") logger.error(f"存储思考过程失败: {e}")
async def _process_single_question( async def _process_single_question(question: str, chat_id: str, context: str, initial_info: str = "") -> Optional[str]:
question: str,
chat_id: str,
context: str,
initial_info: str = ""
) -> Optional[str]:
"""处理单个问题的查询(包含缓存检查逻辑) """处理单个问题的查询(包含缓存检查逻辑)
Args: Args:
@@ -998,6 +942,24 @@ async def _process_single_question(
""" """
logger.info(f"开始处理问题: {question}") logger.info(f"开始处理问题: {question}")
_cleanup_stale_not_found_thinking_back()
question_initial_info = initial_info or ""
# 预先进行一次LPMM知识库查询作为后续ReAct Agent的辅助信息
if global_config.lpmm_knowledge.enable:
try:
lpmm_result = await query_lpmm_knowledge(question, limit=2)
if lpmm_result and lpmm_result.startswith("你从LPMM知识库中找到"):
if question_initial_info:
question_initial_info += "\n"
question_initial_info += f"【LPMM知识库预查询】\n{lpmm_result}"
logger.info(f"LPMM预查询命中问题: {question[:50]}...")
else:
logger.info(f"LPMM预查询未命中或未找到信息问题: {question[:50]}...")
except Exception as e:
logger.error(f"LPMM预查询失败问题: {question[:50]}... 错误: {e}")
# 先检查thinking_back数据库中是否有现成答案 # 先检查thinking_back数据库中是否有现成答案
cached_result = _query_thinking_back(chat_id, question) cached_result = _query_thinking_back(chat_id, question)
should_requery = False should_requery = False
@@ -1005,26 +967,22 @@ async def _process_single_question(
if cached_result: if cached_result:
cached_found_answer, cached_answer = cached_result cached_found_answer, cached_answer = cached_result
# 根据found_answer的值决定是否重新查询
if cached_found_answer: # found_answer == 1 (True) if cached_found_answer: # found_answer == 1 (True)
# found_answer == 120%概率重新查询 # found_answer == 120%概率重新查询
if random.random() < 0.2: if random.random() < 0.5:
should_requery = True should_requery = True
logger.info(f"found_answer=1触发20%概率重新查询,问题: {question[:50]}...") logger.info(f"found_answer=1触发20%概率重新查询,问题: {question[:50]}...")
else: # found_answer == 0 (False)
# found_answer == 040%概率重新查询
if random.random() < 0.4:
should_requery = True
logger.info(f"found_answer=0触发40%概率重新查询,问题: {question[:50]}...")
# 如果不需要重新查询,使用缓存答案 if not should_requery and cached_answer:
if not should_requery:
if cached_answer:
logger.info(f"从thinking_back缓存中获取答案问题: {question[:50]}...") logger.info(f"从thinking_back缓存中获取答案问题: {question[:50]}...")
return f"问题:{question}\n答案:{cached_answer}" return f"问题:{question}\n答案:{cached_answer}"
else: elif not cached_answer:
# 缓存中没有答案,需要查询
should_requery = True should_requery = True
logger.info(f"found_answer=1 但缓存答案为空,重新查询,问题: {question[:50]}...")
else:
# found_answer == 0不使用缓存直接重新查询
should_requery = True
logger.info(f"thinking_back存在但未找到答案忽略缓存重新查询问题: {question[:50]}...")
# 如果没有缓存答案或需要重新查询使用ReAct Agent查询 # 如果没有缓存答案或需要重新查询使用ReAct Agent查询
if not cached_result or should_requery: if not cached_result or should_requery:
@@ -1038,7 +996,7 @@ async def _process_single_question(
chat_id=chat_id, chat_id=chat_id,
max_iterations=global_config.memory.max_agent_iterations, max_iterations=global_config.memory.max_agent_iterations,
timeout=120.0, timeout=120.0,
initial_info=initial_info initial_info=question_initial_info,
) )
# 存储到数据库(超时时不存储) # 存储到数据库(超时时不存储)
@@ -1049,7 +1007,7 @@ async def _process_single_question(
context=context, context=context,
found_answer=found_answer, found_answer=found_answer,
answer=answer, answer=answer,
thinking_steps=thinking_steps thinking_steps=thinking_steps,
) )
else: else:
logger.info(f"ReAct Agent超时不存储到数据库问题: {question[:50]}...") logger.info(f"ReAct Agent超时不存储到数据库问题: {question[:50]}...")
@@ -1112,17 +1070,17 @@ async def build_memory_retrieval_prompt(
request_type="memory.question", request_type="memory.question",
) )
logger.info(f"记忆检索问题生成提示词: {question_prompt}") if global_config.debug.show_memory_prompt:
logger.info(f"记忆检索问题生成提示词: {question_prompt}")
logger.info(f"记忆检索问题生成响应: {response}") logger.info(f"记忆检索问题生成响应: {response}")
if not success: if not success:
logger.error(f"LLM生成问题失败: {response}") logger.error(f"LLM生成问题失败: {response}")
return "" return ""
# 解析概念列表、人物列表和问题列表 # 解析概念列表和问题列表
concepts, persons, questions = _parse_questions_json(response) concepts, questions = _parse_questions_json(response)
logger.info(f"解析到 {len(concepts)} 个概念: {concepts}") logger.info(f"解析到 {len(concepts)} 个概念: {concepts}")
logger.info(f"解析到 {len(persons)} 个人物: {persons}")
logger.info(f"解析到 {len(questions)} 个问题: {questions}") logger.info(f"解析到 {len(questions)} 个问题: {questions}")
# 对概念进行jargon检索作为初始信息 # 对概念进行jargon检索作为初始信息
@@ -1136,22 +1094,12 @@ async def build_memory_retrieval_prompt(
else: else:
logger.info("概念检索未找到任何结果") logger.info("概念检索未找到任何结果")
# 对人物进行信息检索,添加到初始信息
if persons:
logger.info(f"开始对 {len(persons)} 个人物进行信息检索")
person_info = await _retrieve_persons_info(persons, chat_id)
if person_info:
initial_info += person_info
logger.info(f"人物信息检索完成,结果: {person_info[:200]}...")
else:
logger.info("人物信息检索未找到任何结果")
# 获取缓存的记忆与question时使用相同的时间窗口和数量限制 # 获取缓存的记忆与question时使用相同的时间窗口和数量限制
cached_memories = _get_cached_memories(chat_id, time_window_seconds=300.0) cached_memories = _get_cached_memories(chat_id, time_window_seconds=300.0)
if not questions: if not questions:
logger.debug("模型认为不需要检索记忆或解析失败") logger.debug("模型认为不需要检索记忆或解析失败")
# 即使没有当次查询,也返回缓存的记忆概念检索结果和人物信息检索结果 # 即使没有当次查询,也返回缓存的记忆概念检索结果
all_results = [] all_results = []
if initial_info: if initial_info:
all_results.append(initial_info.strip()) all_results.append(initial_info.strip())
@@ -1161,7 +1109,7 @@ async def build_memory_retrieval_prompt(
if all_results: if all_results:
retrieved_memory = "\n\n".join(all_results) retrieved_memory = "\n\n".join(all_results)
end_time = time.time() end_time = time.time()
logger.info(f"无当次查询,返回缓存记忆概念检索和人物信息检索结果,耗时: {(end_time - start_time):.3f}") logger.info(f"无当次查询,返回缓存记忆概念检索结果,耗时: {(end_time - start_time):.3f}")
return f"你回忆起了以下信息:\n{retrieved_memory}\n如果与回复内容相关,可以参考这些回忆的信息。\n" return f"你回忆起了以下信息:\n{retrieved_memory}\n如果与回复内容相关,可以参考这些回忆的信息。\n"
else: else:
return "" return ""
@@ -1174,12 +1122,7 @@ async def build_memory_retrieval_prompt(
# 并行处理所有问题,将概念检索结果作为初始信息传递 # 并行处理所有问题,将概念检索结果作为初始信息传递
question_tasks = [ question_tasks = [
_process_single_question( _process_single_question(question=question, chat_id=chat_id, context=message, initial_info=initial_info)
question=question,
chat_id=chat_id,
context=message,
initial_info=initial_info
)
for question in questions for question in questions
] ]
@@ -1212,7 +1155,9 @@ async def build_memory_retrieval_prompt(
if all_results: if all_results:
retrieved_memory = "\n\n".join(all_results) retrieved_memory = "\n\n".join(all_results)
logger.info(f"记忆检索成功,耗时: {(end_time - start_time):.3f}秒,包含 {len(all_results)} 条记忆(含缓存)") logger.info(
f"记忆检索成功,耗时: {(end_time - start_time):.3f}秒,包含 {len(all_results)} 条记忆(含缓存)"
)
return f"你回忆起了以下信息:\n{retrieved_memory}\n如果与回复内容相关,可以参考这些回忆的信息。\n" return f"你回忆起了以下信息:\n{retrieved_memory}\n如果与回复内容相关,可以参考这些回忆的信息。\n"
else: else:
logger.debug("所有问题均未找到答案,且无缓存记忆") logger.debug("所有问题均未找到答案,且无缓存记忆")
@@ -1223,14 +1168,14 @@ async def build_memory_retrieval_prompt(
return "" return ""
def _parse_questions_json(response: str) -> Tuple[List[str], List[str], List[str]]: def _parse_questions_json(response: str) -> Tuple[List[str], List[str]]:
"""解析问题JSON返回概念列表、人物列表和问题列表 """解析问题JSON返回概念列表和问题列表
Args: Args:
response: LLM返回的响应 response: LLM返回的响应
Returns: Returns:
Tuple[List[str], List[str], List[str]]: (概念列表, 人物列表, 问题列表) Tuple[List[str], List[str]]: (概念列表, 问题列表)
""" """
try: try:
# 尝试提取JSON可能包含在```json代码块中 # 尝试提取JSON可能包含在```json代码块中
@@ -1249,30 +1194,26 @@ def _parse_questions_json(response: str) -> Tuple[List[str], List[str], List[str
# 解析JSON # 解析JSON
parsed = json.loads(repaired_json) parsed = json.loads(repaired_json)
# 只支持新格式包含concepts、person和questions的对象 # 只支持新格式包含concepts和questions的对象
if not isinstance(parsed, dict): if not isinstance(parsed, dict):
logger.warning(f"解析的JSON不是对象格式: {parsed}") logger.warning(f"解析的JSON不是对象格式: {parsed}")
return [], [], [] return [], []
concepts_raw = parsed.get("concepts", []) concepts_raw = parsed.get("concepts", [])
persons_raw = parsed.get("person", [])
questions_raw = parsed.get("questions", []) questions_raw = parsed.get("questions", [])
# 确保是列表 # 确保是列表
if not isinstance(concepts_raw, list): if not isinstance(concepts_raw, list):
concepts_raw = [] concepts_raw = []
if not isinstance(persons_raw, list):
persons_raw = []
if not isinstance(questions_raw, list): if not isinstance(questions_raw, list):
questions_raw = [] questions_raw = []
# 确保所有元素都是字符串 # 确保所有元素都是字符串
concepts = [c for c in concepts_raw if isinstance(c, str) and c.strip()] concepts = [c for c in concepts_raw if isinstance(c, str) and c.strip()]
persons = [p for p in persons_raw if isinstance(p, str) and p.strip()]
questions = [q for q in questions_raw if isinstance(q, str) and q.strip()] questions = [q for q in questions_raw if isinstance(q, str) and q.strip()]
return concepts, persons, questions return concepts, questions
except Exception as e: except Exception as e:
logger.error(f"解析问题JSON失败: {e}, 响应内容: {response[:200]}...") logger.error(f"解析问题JSON失败: {e}, 响应内容: {response[:200]}...")
return [], [], [] return [], []

View File

@@ -3,6 +3,7 @@
记忆系统工具函数 记忆系统工具函数
包含模糊查找、相似度计算等工具函数 包含模糊查找、相似度计算等工具函数
""" """
import json import json
import re import re
from datetime import datetime from datetime import datetime
@@ -14,6 +15,7 @@ from src.common.logger import get_logger
logger = get_logger("memory_utils") logger = get_logger("memory_utils")
def parse_md_json(json_text: str) -> list[str]: def parse_md_json(json_text: str) -> list[str]:
"""从Markdown格式的内容中提取JSON对象和推理内容""" """从Markdown格式的内容中提取JSON对象和推理内容"""
json_objects = [] json_objects = []
@@ -52,6 +54,7 @@ def parse_md_json(json_text: str) -> list[str]:
return json_objects, reasoning_content return json_objects, reasoning_content
def calculate_similarity(text1: str, text2: str) -> float: def calculate_similarity(text1: str, text2: str) -> float:
""" """
计算两个文本的相似度 计算两个文本的相似度
@@ -97,10 +100,10 @@ def preprocess_text(text: str) -> str:
text = text.lower() text = text.lower()
# 移除标点符号和特殊字符 # 移除标点符号和特殊字符
text = re.sub(r'[^\w\s]', '', text) text = re.sub(r"[^\w\s]", "", text)
# 移除多余空格 # 移除多余空格
text = re.sub(r'\s+', ' ', text).strip() text = re.sub(r"\s+", " ", text).strip()
return text return text
@@ -109,7 +112,6 @@ def preprocess_text(text: str) -> str:
return text return text
def parse_datetime_to_timestamp(value: str) -> float: def parse_datetime_to_timestamp(value: str) -> float:
""" """
接受多种常见格式并转换为时间戳(秒) 接受多种常见格式并转换为时间戳(秒)
@@ -164,4 +166,3 @@ def parse_time_range(time_range: str) -> Tuple[float, float]:
end_timestamp = parse_datetime_to_timestamp(end_str) end_timestamp = parse_datetime_to_timestamp(end_str)
return start_timestamp, end_timestamp return start_timestamp, end_timestamp

View File

@@ -17,6 +17,7 @@ from .query_lpmm_knowledge import register_tool as register_lpmm_knowledge
from .query_person_info import register_tool as register_query_person_info from .query_person_info import register_tool as register_query_person_info
from src.config.config import global_config from src.config.config import global_config
def init_all_tools(): def init_all_tools():
"""初始化并注册所有记忆检索工具""" """初始化并注册所有记忆检索工具"""
register_query_jargon() register_query_jargon()

View File

@@ -15,10 +15,7 @@ logger = get_logger("memory_retrieval_tools")
async def query_chat_history( async def query_chat_history(
chat_id: str, chat_id: str, keyword: Optional[str] = None, time_range: Optional[str] = None, fuzzy: bool = True
keyword: Optional[str] = None,
time_range: Optional[str] = None,
fuzzy: bool = True
) -> str: ) -> str:
"""根据时间或关键词在chat_history表中查询聊天记录概述 """根据时间或关键词在chat_history表中查询聊天记录概述
@@ -50,17 +47,11 @@ async def query_chat_history(
# 时间范围:查询与时间范围有交集的记录 # 时间范围:查询与时间范围有交集的记录
start_timestamp, end_timestamp = parse_time_range(time_range) start_timestamp, end_timestamp = parse_time_range(time_range)
# 交集条件start_time < end_timestamp AND end_time > start_timestamp # 交集条件start_time < end_timestamp AND end_time > start_timestamp
time_filter = ( time_filter = (ChatHistory.start_time < end_timestamp) & (ChatHistory.end_time > start_timestamp)
(ChatHistory.start_time < end_timestamp) &
(ChatHistory.end_time > start_timestamp)
)
else: else:
# 时间点查询包含该时间点的记录start_time <= time_point <= end_time # 时间点查询包含该时间点的记录start_time <= time_point <= end_time
target_timestamp = parse_datetime_to_timestamp(time_range) target_timestamp = parse_datetime_to_timestamp(time_range)
time_filter = ( time_filter = (ChatHistory.start_time <= target_timestamp) & (ChatHistory.end_time >= target_timestamp)
(ChatHistory.start_time <= target_timestamp) &
(ChatHistory.end_time >= target_timestamp)
)
query = query.where(time_filter) query = query.where(time_filter)
# 执行查询 # 执行查询
@@ -91,7 +82,9 @@ async def query_chat_history(
record_keywords_list = [] record_keywords_list = []
if record.keywords: if record.keywords:
try: try:
keywords_data = json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords keywords_data = (
json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
)
if isinstance(keywords_data, list): if isinstance(keywords_data, list):
record_keywords_list = [str(k).lower() for k in keywords_data] record_keywords_list = [str(k).lower() for k in keywords_data]
except (json.JSONDecodeError, TypeError, ValueError): except (json.JSONDecodeError, TypeError, ValueError):
@@ -102,20 +95,24 @@ async def query_chat_history(
if fuzzy: if fuzzy:
# 模糊匹配只要包含任意一个关键词即匹配OR关系 # 模糊匹配只要包含任意一个关键词即匹配OR关系
for kw in keywords_lower: for kw in keywords_lower:
if (kw in theme or if (
kw in summary or kw in theme
kw in original_text or or kw in summary
any(kw in k for k in record_keywords_list)): or kw in original_text
or any(kw in k for k in record_keywords_list)
):
matched = True matched = True
break break
else: else:
# 全匹配必须包含所有关键词才匹配AND关系 # 全匹配必须包含所有关键词才匹配AND关系
matched = True matched = True
for kw in keywords_lower: for kw in keywords_lower:
kw_matched = (kw in theme or kw_matched = (
kw in summary or kw in theme
kw in original_text or or kw in summary
any(kw in k for k in record_keywords_list)) or kw in original_text
or any(kw in k for k in record_keywords_list)
)
if not kw_matched: if not kw_matched:
matched = False matched = False
break break
@@ -160,6 +157,7 @@ async def query_chat_history(
# 添加时间范围 # 添加时间范围
from datetime import datetime from datetime import datetime
start_str = datetime.fromtimestamp(record.start_time).strftime("%Y-%m-%d %H:%M:%S") start_str = datetime.fromtimestamp(record.start_time).strftime("%Y-%m-%d %H:%M:%S")
end_str = datetime.fromtimestamp(record.end_time).strftime("%Y-%m-%d %H:%M:%S") end_str = datetime.fromtimestamp(record.end_time).strftime("%Y-%m-%d %H:%M:%S")
result_parts.append(f"时间:{start_str} - {end_str}") result_parts.append(f"时间:{start_str} - {end_str}")
@@ -193,26 +191,26 @@ def register_tool():
"""注册工具""" """注册工具"""
register_memory_retrieval_tool( register_memory_retrieval_tool(
name="query_chat_history", name="query_chat_history",
description="根据时间或关键词在chat_history表的聊天记录概述库中查询。可以查询某个时间点发生了什么、某个时间范围内的事件,或根据关键词搜索消息概述。支持两种匹配模式:模糊匹配(默认,只要包含任意一个关键词即匹配)和全匹配(必须包含所有关键词才匹配)", description="根据时间或关键词在聊天记录中查询。可以查询某个时间点发生了什么、某个时间范围内的事件,或根据关键词搜索消息概述。支持两种匹配模式:模糊匹配(默认,只要包含任意一个关键词即匹配)和全匹配(必须包含所有关键词才匹配)",
parameters=[ parameters=[
{ {
"name": "keyword", "name": "keyword",
"type": "string", "type": "string",
"description": "关键词(可选,支持多个关键词,可用空格、逗号、斜杠等分隔,如:'麦麦 百度网盘''麦麦,百度网盘'。用于在主题、关键词、概括、原文中搜索)", "description": "关键词(可选,支持多个关键词,可用空格、逗号、斜杠等分隔,如:'麦麦 百度网盘''麦麦,百度网盘'。用于在主题、关键词、概括、原文中搜索)",
"required": False "required": False,
}, },
{ {
"name": "time_range", "name": "time_range",
"type": "string", "type": "string",
"description": "时间范围或时间点(可选)。格式:'YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS'(时间范围,查询与时间范围有交集的记录)或 'YYYY-MM-DD HH:MM:SS'(时间点,查询包含该时间点的记录)", "description": "时间范围或时间点(可选)。格式:'YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS'(时间范围,查询与时间范围有交集的记录)或 'YYYY-MM-DD HH:MM:SS'(时间点,查询包含该时间点的记录)",
"required": False "required": False,
}, },
{ {
"name": "fuzzy", "name": "fuzzy",
"type": "boolean", "type": "boolean",
"description": "是否使用模糊匹配模式默认True。True表示模糊匹配只要包含任意一个关键词即匹配OR关系False表示全匹配必须包含所有关键词才匹配AND关系", "description": "是否使用模糊匹配模式默认True。True表示模糊匹配只要包含任意一个关键词即匹配OR关系False表示全匹配必须包含所有关键词才匹配AND关系",
"required": False "required": False,
} },
], ],
execute_func=query_chat_history execute_func=query_chat_history,
) )

View File

@@ -10,7 +10,7 @@ from .tool_registry import register_memory_retrieval_tool
logger = get_logger("memory_retrieval_tools") logger = get_logger("memory_retrieval_tools")
async def query_lpmm_knowledge(query: str) -> str: async def query_lpmm_knowledge(query: str, limit: int = 5) -> str:
"""在LPMM知识库中查询相关信息 """在LPMM知识库中查询相关信息
Args: Args:
@@ -24,6 +24,12 @@ async def query_lpmm_knowledge(query: str) -> str:
if not content: if not content:
return "查询关键词为空" return "查询关键词为空"
try:
limit_value = int(limit)
except (TypeError, ValueError):
limit_value = 5
limit_value = max(1, limit_value)
if not global_config.lpmm_knowledge.enable: if not global_config.lpmm_knowledge.enable:
logger.debug("LPMM知识库未启用") logger.debug("LPMM知识库未启用")
return "LPMM知识库未启用" return "LPMM知识库未启用"
@@ -33,7 +39,7 @@ async def query_lpmm_knowledge(query: str) -> str:
logger.debug("LPMM知识库未初始化跳过查询") logger.debug("LPMM知识库未初始化跳过查询")
return "LPMM知识库未初始化" return "LPMM知识库未初始化"
knowledge_info = await qa_manager.get_knowledge(content) knowledge_info = await qa_manager.get_knowledge(content, limit=limit_value)
logger.debug(f"LPMM知识库查询结果: {knowledge_info}") logger.debug(f"LPMM知识库查询结果: {knowledge_info}")
if knowledge_info: if knowledge_info:
@@ -57,9 +63,13 @@ def register_tool():
"type": "string", "type": "string",
"description": "需要查询的关键词或问题", "description": "需要查询的关键词或问题",
"required": True, "required": True,
} },
{
"name": "limit",
"type": "integer",
"description": "希望返回的相关知识条数默认为5",
"required": False,
},
], ],
execute_func=query_lpmm_knowledge, execute_func=query_lpmm_knowledge,
) )

View File

@@ -26,7 +26,9 @@ def _format_group_nick_names(group_nick_name_field) -> str:
try: try:
# 解析JSON格式的群昵称列表 # 解析JSON格式的群昵称列表
group_nick_names_data = json.loads(group_nick_name_field) if isinstance(group_nick_name_field, str) else group_nick_name_field group_nick_names_data = (
json.loads(group_nick_name_field) if isinstance(group_nick_name_field, str) else group_nick_name_field
)
if not isinstance(group_nick_names_data, list) or not group_nick_names_data: if not isinstance(group_nick_names_data, list) or not group_nick_names_data:
return "" return ""
@@ -71,9 +73,7 @@ async def query_person_info(person_name: str) -> str:
return "用户名称为空" return "用户名称为空"
# 构建查询条件(使用模糊查询) # 构建查询条件(使用模糊查询)
query = PersonInfo.select().where( query = PersonInfo.select().where(PersonInfo.person_name.contains(person_name))
PersonInfo.person_name.contains(person_name)
)
# 执行查询 # 执行查询
records = list(query.limit(20)) # 最多返回20条记录 records = list(query.limit(20)) # 最多返回20条记录
@@ -137,7 +137,11 @@ async def query_person_info(person_name: str) -> str:
# 记忆点memory_points # 记忆点memory_points
if record.memory_points: if record.memory_points:
try: try:
memory_points_data = json.loads(record.memory_points) if isinstance(record.memory_points, str) else record.memory_points memory_points_data = (
json.loads(record.memory_points)
if isinstance(record.memory_points, str)
else record.memory_points
)
if isinstance(memory_points_data, list) and memory_points_data: if isinstance(memory_points_data, list) and memory_points_data:
# 解析记忆点格式category:content:weight # 解析记忆点格式category:content:weight
memory_list = [] memory_list = []
@@ -206,7 +210,11 @@ async def query_person_info(person_name: str) -> str:
# 记忆点memory_points # 记忆点memory_points
if record.memory_points: if record.memory_points:
try: try:
memory_points_data = json.loads(record.memory_points) if isinstance(record.memory_points, str) else record.memory_points memory_points_data = (
json.loads(record.memory_points)
if isinstance(record.memory_points, str)
else record.memory_points
)
if isinstance(memory_points_data, list) and memory_points_data: if isinstance(memory_points_data, list) and memory_points_data:
# 解析记忆点格式category:content:weight # 解析记忆点格式category:content:weight
memory_list = [] memory_list = []
@@ -275,13 +283,7 @@ def register_tool():
name="query_person_info", name="query_person_info",
description="根据查询某个用户的所有信息。名称、昵称、平台、用户ID、qq号、群昵称等", description="根据查询某个用户的所有信息。名称、昵称、平台、用户ID、qq号、群昵称等",
parameters=[ parameters=[
{ {"name": "person_name", "type": "string", "description": "用户名称,用于查询用户信息", "required": True}
"name": "person_name",
"type": "string",
"description": "用户名称,用于查询用户信息",
"required": True
}
], ],
execute_func=query_person_info execute_func=query_person_info,
) )

View File

@@ -82,11 +82,7 @@ class MemoryRetrievalTool:
param_tuples.append(param_tuple) param_tuples.append(param_tuple)
# 构建工具定义格式与BaseTool.get_tool_definition()一致 # 构建工具定义格式与BaseTool.get_tool_definition()一致
tool_def = { tool_def = {"name": self.name, "description": self.description, "parameters": param_tuples}
"name": self.name,
"description": self.description,
"parameters": param_tuples
}
return tool_def return tool_def

View File

@@ -162,7 +162,12 @@ def levenshtein_distance(s1: str, s2: str) -> int:
class Person: class Person:
@classmethod @classmethod
def register_person( def register_person(
cls, platform: str, user_id: str, nickname: str, group_id: Optional[str] = None, group_nick_name: Optional[str] = None cls,
platform: str,
user_id: str,
nickname: str,
group_id: Optional[str] = None,
group_nick_name: Optional[str] = None,
): ):
""" """
注册新用户的类方法 注册新用户的类方法
@@ -781,7 +786,11 @@ async def store_person_memory_from_answer(person_name: str, memory_content: str,
if len(parts) >= 2: if len(parts) >= 2:
existing_content = parts[1].strip() existing_content = parts[1].strip()
# 简单相似度检查(如果内容相同或非常相似,则跳过) # 简单相似度检查(如果内容相同或非常相似,则跳过)
if existing_content == memory_content or memory_content in existing_content or existing_content in memory_content: if (
existing_content == memory_content
or memory_content in existing_content
or existing_content in memory_content
):
is_duplicate = True is_duplicate = True
break break

View File

@@ -125,7 +125,6 @@ class ToolExecutor:
prompt=prompt, tools=tools, raise_when_empty=False prompt=prompt, tools=tools, raise_when_empty=False
) )
# 执行工具调用 # 执行工具调用
tool_results, used_tools = await self.execute_tool_calls(tool_calls) tool_results, used_tools = await self.execute_tool_calls(tool_calls)

View File

@@ -102,13 +102,13 @@ class EmojiAction(BaseAction):
# 5. 调用LLM # 5. 调用LLM
models = llm_api.get_available_models() models = llm_api.get_available_models()
chat_model_config = models.get("replyer") # 使用字典访问方式 chat_model_config = models.get("utils") # 使用字典访问方式
if not chat_model_config: if not chat_model_config:
logger.error(f"{self.log_prefix} 未找到'replyer'模型配置无法调用LLM") logger.error(f"{self.log_prefix} 未找到'utils'模型配置无法调用LLM")
return False, "未找到'replyer'模型配置" return False, "未找到'utils'模型配置"
success, chosen_emotion, _, _ = await llm_api.generate_with_model( success, chosen_emotion, _, _ = await llm_api.generate_with_model(
prompt, model_config=chat_model_config, request_type="emoji" prompt, model_config=chat_model_config, request_type="emoji.select"
) )
if not success: if not success:

View File

@@ -15,6 +15,7 @@ class SearchKnowledgeFromLPMMTool(BaseTool):
description = "从知识库中搜索相关信息,如果你需要知识,就使用这个工具" description = "从知识库中搜索相关信息,如果你需要知识,就使用这个工具"
parameters = [ parameters = [
("query", ToolParamType.STRING, "搜索查询关键词", True, None), ("query", ToolParamType.STRING, "搜索查询关键词", True, None),
("limit", ToolParamType.INTEGER, "希望返回的相关知识条数默认5", False, 5),
] ]
available_for_llm = global_config.lpmm_knowledge.enable available_for_llm = global_config.lpmm_knowledge.enable
@@ -29,6 +30,12 @@ class SearchKnowledgeFromLPMMTool(BaseTool):
""" """
try: try:
query: str = function_args.get("query") # type: ignore query: str = function_args.get("query") # type: ignore
limit = function_args.get("limit", 5)
try:
limit_value = int(limit)
except (TypeError, ValueError):
limit_value = 5
limit_value = max(1, limit_value)
# threshold = function_args.get("threshold", 0.4) # threshold = function_args.get("threshold", 0.4)
# 检查LPMM知识库是否启用 # 检查LPMM知识库是否启用
@@ -38,7 +45,7 @@ class SearchKnowledgeFromLPMMTool(BaseTool):
# 调用知识库搜索 # 调用知识库搜索
knowledge_info = await qa_manager.get_knowledge(query) knowledge_info = await qa_manager.get_knowledge(query, limit=limit_value)
logger.debug(f"知识库查询结果: {knowledge_info}") logger.debug(f"知识库查询结果: {knowledge_info}")

559
src/webui/config_routes.py Normal file
View File

@@ -0,0 +1,559 @@
"""
配置管理API路由
"""
import os
import tomlkit
from fastapi import APIRouter, HTTPException, Body
from typing import Any
from src.common.logger import get_logger
from src.config.config import Config, APIAdapterConfig, CONFIG_DIR
from src.config.official_configs import (
BotConfig,
PersonalityConfig,
RelationshipConfig,
ChatConfig,
MessageReceiveConfig,
EmojiConfig,
ExpressionConfig,
KeywordReactionConfig,
ChineseTypoConfig,
ResponsePostProcessConfig,
ResponseSplitterConfig,
TelemetryConfig,
ExperimentalConfig,
MaimMessageConfig,
LPMMKnowledgeConfig,
ToolConfig,
MemoryConfig,
DebugConfig,
MoodConfig,
VoiceConfig,
JargonConfig,
)
from src.config.api_ada_configs import (
ModelTaskConfig,
ModelInfo,
APIProvider,
)
from src.webui.config_schema import ConfigSchemaGenerator
logger = get_logger("webui")
router = APIRouter(prefix="/config", tags=["config"])
# ===== 辅助函数 =====
def _update_dict_preserve_comments(target: Any, source: Any) -> None:
"""
递归合并字典,保留 target 中的注释和格式
将 source 的值更新到 target 中(仅更新已存在的键)
Args:
target: 目标字典tomlkit 对象,包含注释)
source: 源字典(普通 dict 或 list
"""
# 如果 source 是列表,直接替换(数组表没有注释保留的意义)
if isinstance(source, list):
return # 调用者需要直接赋值
# 如果都是字典,递归合并
if isinstance(source, dict) and isinstance(target, dict):
for key, value in source.items():
if key == "version":
continue # 跳过版本号
if key in target:
target_value = target[key]
# 递归处理嵌套字典
if isinstance(value, dict) and isinstance(target_value, dict):
_update_dict_preserve_comments(target_value, value)
else:
# 使用 tomlkit.item 保持类型
try:
target[key] = tomlkit.item(value)
except (TypeError, ValueError):
target[key] = value
# ===== 架构获取接口 =====
@router.get("/schema/bot")
async def get_bot_config_schema():
"""获取麦麦主程序配置架构"""
try:
# Config 类包含所有子配置
schema = ConfigSchemaGenerator.generate_config_schema(Config)
return {"success": True, "schema": schema}
except Exception as e:
logger.error(f"获取配置架构失败: {e}")
raise HTTPException(status_code=500, detail=f"获取配置架构失败: {str(e)}")
@router.get("/schema/model")
async def get_model_config_schema():
"""获取模型配置架构(包含提供商和模型任务配置)"""
try:
schema = ConfigSchemaGenerator.generate_config_schema(APIAdapterConfig)
return {"success": True, "schema": schema}
except Exception as e:
logger.error(f"获取模型配置架构失败: {e}")
raise HTTPException(status_code=500, detail=f"获取模型配置架构失败: {str(e)}")
# ===== 子配置架构获取接口 =====
@router.get("/schema/section/{section_name}")
async def get_config_section_schema(section_name: str):
"""
获取指定配置节的架构
支持的section_name:
- bot: BotConfig
- personality: PersonalityConfig
- relationship: RelationshipConfig
- chat: ChatConfig
- message_receive: MessageReceiveConfig
- emoji: EmojiConfig
- expression: ExpressionConfig
- keyword_reaction: KeywordReactionConfig
- chinese_typo: ChineseTypoConfig
- response_post_process: ResponsePostProcessConfig
- response_splitter: ResponseSplitterConfig
- telemetry: TelemetryConfig
- experimental: ExperimentalConfig
- maim_message: MaimMessageConfig
- lpmm_knowledge: LPMMKnowledgeConfig
- tool: ToolConfig
- memory: MemoryConfig
- debug: DebugConfig
- mood: MoodConfig
- voice: VoiceConfig
- jargon: JargonConfig
- model_task_config: ModelTaskConfig
- api_provider: APIProvider
- model_info: ModelInfo
"""
section_map = {
"bot": BotConfig,
"personality": PersonalityConfig,
"relationship": RelationshipConfig,
"chat": ChatConfig,
"message_receive": MessageReceiveConfig,
"emoji": EmojiConfig,
"expression": ExpressionConfig,
"keyword_reaction": KeywordReactionConfig,
"chinese_typo": ChineseTypoConfig,
"response_post_process": ResponsePostProcessConfig,
"response_splitter": ResponseSplitterConfig,
"telemetry": TelemetryConfig,
"experimental": ExperimentalConfig,
"maim_message": MaimMessageConfig,
"lpmm_knowledge": LPMMKnowledgeConfig,
"tool": ToolConfig,
"memory": MemoryConfig,
"debug": DebugConfig,
"mood": MoodConfig,
"voice": VoiceConfig,
"jargon": JargonConfig,
"model_task_config": ModelTaskConfig,
"api_provider": APIProvider,
"model_info": ModelInfo,
}
if section_name not in section_map:
raise HTTPException(status_code=404, detail=f"配置节 '{section_name}' 不存在")
try:
config_class = section_map[section_name]
schema = ConfigSchemaGenerator.generate_schema(config_class, include_nested=False)
return {"success": True, "schema": schema}
except Exception as e:
logger.error(f"获取配置节架构失败: {e}")
raise HTTPException(status_code=500, detail=f"获取配置节架构失败: {str(e)}")
# ===== 配置读取接口 =====
@router.get("/bot")
async def get_bot_config():
"""获取麦麦主程序配置"""
try:
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
if not os.path.exists(config_path):
raise HTTPException(status_code=404, detail="配置文件不存在")
with open(config_path, "r", encoding="utf-8") as f:
config_data = tomlkit.load(f)
return {"success": True, "config": config_data}
except HTTPException:
raise
except Exception as e:
logger.error(f"读取配置文件失败: {e}")
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}")
@router.get("/model")
async def get_model_config():
"""获取模型配置(包含提供商和模型任务配置)"""
try:
config_path = os.path.join(CONFIG_DIR, "model_config.toml")
if not os.path.exists(config_path):
raise HTTPException(status_code=404, detail="配置文件不存在")
with open(config_path, "r", encoding="utf-8") as f:
config_data = tomlkit.load(f)
return {"success": True, "config": config_data}
except HTTPException:
raise
except Exception as e:
logger.error(f"读取配置文件失败: {e}")
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}")
# ===== 配置更新接口 =====
@router.post("/bot")
async def update_bot_config(config_data: dict[str, Any] = Body(...)):
"""更新麦麦主程序配置"""
try:
# 验证配置数据
try:
Config.from_dict(config_data)
except Exception as e:
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}")
# 保存配置文件
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
with open(config_path, "w", encoding="utf-8") as f:
tomlkit.dump(config_data, f)
logger.info("麦麦主程序配置已更新")
return {"success": True, "message": "配置已保存"}
except HTTPException:
raise
except Exception as e:
logger.error(f"保存配置文件失败: {e}")
raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}")
@router.post("/model")
async def update_model_config(config_data: dict[str, Any] = Body(...)):
"""更新模型配置"""
try:
# 验证配置数据
try:
APIAdapterConfig.from_dict(config_data)
except Exception as e:
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}")
# 保存配置文件
config_path = os.path.join(CONFIG_DIR, "model_config.toml")
with open(config_path, "w", encoding="utf-8") as f:
tomlkit.dump(config_data, f)
logger.info("模型配置已更新")
return {"success": True, "message": "配置已保存"}
except HTTPException:
raise
except Exception as e:
logger.error(f"保存配置文件失败: {e}")
raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}")
# ===== 配置节更新接口 =====
@router.post("/bot/section/{section_name}")
async def update_bot_config_section(section_name: str, section_data: Any = Body(...)):
"""更新麦麦主程序配置的指定节(保留注释和格式)"""
try:
# 读取现有配置
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
if not os.path.exists(config_path):
raise HTTPException(status_code=404, detail="配置文件不存在")
with open(config_path, "r", encoding="utf-8") as f:
config_data = tomlkit.load(f)
# 更新指定节
if section_name not in config_data:
raise HTTPException(status_code=404, detail=f"配置节 '{section_name}' 不存在")
# 使用递归合并保留注释(对于字典类型)
# 对于数组类型(如 platforms, aliases直接替换
if isinstance(section_data, list):
# 列表直接替换
config_data[section_name] = section_data
elif isinstance(section_data, dict) and isinstance(config_data[section_name], dict):
# 字典递归合并
_update_dict_preserve_comments(config_data[section_name], section_data)
else:
# 其他类型直接替换
config_data[section_name] = section_data
# 验证完整配置
try:
Config.from_dict(config_data)
except Exception as e:
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}")
# 保存配置tomlkit.dump 会保留注释)
with open(config_path, "w", encoding="utf-8") as f:
tomlkit.dump(config_data, f)
logger.info(f"配置节 '{section_name}' 已更新(保留注释)")
return {"success": True, "message": f"配置节 '{section_name}' 已保存"}
except HTTPException:
raise
except Exception as e:
logger.error(f"更新配置节失败: {e}")
raise HTTPException(status_code=500, detail=f"更新配置节失败: {str(e)}")
# ===== 原始 TOML 文件操作接口 =====
@router.get("/bot/raw")
async def get_bot_config_raw():
"""获取麦麦主程序配置的原始 TOML 内容"""
try:
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
if not os.path.exists(config_path):
raise HTTPException(status_code=404, detail="配置文件不存在")
with open(config_path, "r", encoding="utf-8") as f:
raw_content = f.read()
return {"success": True, "content": raw_content}
except HTTPException:
raise
except Exception as e:
logger.error(f"读取配置文件失败: {e}")
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}")
@router.post("/bot/raw")
async def update_bot_config_raw(raw_content: str = Body(..., embed=True)):
"""更新麦麦主程序配置(直接保存原始 TOML 内容,会先验证格式)"""
try:
# 验证 TOML 格式
try:
config_data = tomlkit.loads(raw_content)
except Exception as e:
raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}")
# 验证配置数据结构
try:
Config.from_dict(config_data)
except Exception as e:
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}")
# 保存配置文件
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
with open(config_path, "w", encoding="utf-8") as f:
f.write(raw_content)
logger.info("麦麦主程序配置已更新(原始模式)")
return {"success": True, "message": "配置已保存"}
except HTTPException:
raise
except Exception as e:
logger.error(f"保存配置文件失败: {e}")
raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}")
@router.post("/model/section/{section_name}")
async def update_model_config_section(section_name: str, section_data: Any = Body(...)):
"""更新模型配置的指定节(保留注释和格式)"""
try:
# 读取现有配置
config_path = os.path.join(CONFIG_DIR, "model_config.toml")
if not os.path.exists(config_path):
raise HTTPException(status_code=404, detail="配置文件不存在")
with open(config_path, "r", encoding="utf-8") as f:
config_data = tomlkit.load(f)
# 更新指定节
if section_name not in config_data:
raise HTTPException(status_code=404, detail=f"配置节 '{section_name}' 不存在")
# 使用递归合并保留注释(对于字典类型)
# 对于数组表(如 [[models]], [[api_providers]]),直接替换
if isinstance(section_data, list):
# 列表直接替换
config_data[section_name] = section_data
elif isinstance(section_data, dict) and isinstance(config_data[section_name], dict):
# 字典递归合并
_update_dict_preserve_comments(config_data[section_name], section_data)
else:
# 其他类型直接替换
config_data[section_name] = section_data
# 验证完整配置
try:
APIAdapterConfig.from_dict(config_data)
except Exception as e:
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}")
# 保存配置tomlkit.dump 会保留注释)
with open(config_path, "w", encoding="utf-8") as f:
tomlkit.dump(config_data, f)
logger.info(f"配置节 '{section_name}' 已更新(保留注释)")
return {"success": True, "message": f"配置节 '{section_name}' 已保存"}
except HTTPException:
raise
except Exception as e:
logger.error(f"更新配置节失败: {e}")
raise HTTPException(status_code=500, detail=f"更新配置节失败: {str(e)}")
# ===== 适配器配置管理接口 =====
@router.get("/adapter-config/path")
async def get_adapter_config_path():
"""获取保存的适配器配置文件路径"""
try:
# 从 data/webui.json 读取路径偏好
webui_data_path = os.path.join("data", "webui.json")
if not os.path.exists(webui_data_path):
return {"success": True, "path": None}
import json
with open(webui_data_path, "r", encoding="utf-8") as f:
webui_data = json.load(f)
adapter_config_path = webui_data.get("adapter_config_path")
if not adapter_config_path:
return {"success": True, "path": None}
# 检查文件是否存在并返回最后修改时间
if os.path.exists(adapter_config_path):
import datetime
mtime = os.path.getmtime(adapter_config_path)
last_modified = datetime.datetime.fromtimestamp(mtime).isoformat()
return {"success": True, "path": adapter_config_path, "lastModified": last_modified}
else:
return {"success": True, "path": adapter_config_path, "lastModified": None}
except Exception as e:
logger.error(f"获取适配器配置路径失败: {e}")
raise HTTPException(status_code=500, detail=f"获取配置路径失败: {str(e)}")
@router.post("/adapter-config/path")
async def save_adapter_config_path(data: dict[str, str] = Body(...)):
"""保存适配器配置文件路径偏好"""
try:
path = data.get("path")
if not path:
raise HTTPException(status_code=400, detail="路径不能为空")
# 保存到 data/webui.json
webui_data_path = os.path.join("data", "webui.json")
import json
# 读取现有数据
if os.path.exists(webui_data_path):
with open(webui_data_path, "r", encoding="utf-8") as f:
webui_data = json.load(f)
else:
webui_data = {}
# 更新路径
webui_data["adapter_config_path"] = path
# 保存
os.makedirs("data", exist_ok=True)
with open(webui_data_path, "w", encoding="utf-8") as f:
json.dump(webui_data, f, ensure_ascii=False, indent=2)
logger.info(f"适配器配置路径已保存: {path}")
return {"success": True, "message": "路径已保存"}
except HTTPException:
raise
except Exception as e:
logger.error(f"保存适配器配置路径失败: {e}")
raise HTTPException(status_code=500, detail=f"保存路径失败: {str(e)}")
@router.get("/adapter-config")
async def get_adapter_config(path: str):
"""从指定路径读取适配器配置文件"""
try:
if not path:
raise HTTPException(status_code=400, detail="路径参数不能为空")
# 检查文件是否存在
if not os.path.exists(path):
raise HTTPException(status_code=404, detail=f"配置文件不存在: {path}")
# 检查文件扩展名
if not path.endswith(".toml"):
raise HTTPException(status_code=400, detail="只支持 .toml 格式的配置文件")
# 读取文件内容
with open(path, "r", encoding="utf-8") as f:
content = f.read()
logger.info(f"已读取适配器配置: {path}")
return {"success": True, "content": content}
except HTTPException:
raise
except Exception as e:
logger.error(f"读取适配器配置失败: {e}")
raise HTTPException(status_code=500, detail=f"读取配置失败: {str(e)}")
@router.post("/adapter-config")
async def save_adapter_config(data: dict[str, str] = Body(...)):
"""保存适配器配置到指定路径"""
try:
path = data.get("path")
content = data.get("content")
if not path:
raise HTTPException(status_code=400, detail="路径不能为空")
if content is None:
raise HTTPException(status_code=400, detail="配置内容不能为空")
# 检查文件扩展名
if not path.endswith(".toml"):
raise HTTPException(status_code=400, detail="只支持 .toml 格式的配置文件")
# 验证 TOML 格式
try:
import toml
toml.loads(content)
except Exception as e:
raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}")
# 确保目录存在
os.makedirs(os.path.dirname(path) if os.path.dirname(path) else ".", exist_ok=True)
# 保存文件
with open(path, "w", encoding="utf-8") as f:
f.write(content)
logger.info(f"适配器配置已保存: {path}")
return {"success": True, "message": "配置已保存"}
except HTTPException:
raise
except Exception as e:
logger.error(f"保存适配器配置失败: {e}")
raise HTTPException(status_code=500, detail=f"保存配置失败: {str(e)}")

336
src/webui/config_schema.py Normal file
View File

@@ -0,0 +1,336 @@
"""
配置架构生成器 - 自动从配置类生成前端表单架构
"""
import inspect
from dataclasses import fields, MISSING
from typing import Any, get_origin, get_args, Literal, Optional
from enum import Enum
from src.config.config_base import ConfigBase
class FieldType(str, Enum):
"""字段类型枚举"""
STRING = "string"
NUMBER = "number"
INTEGER = "integer"
BOOLEAN = "boolean"
SELECT = "select"
ARRAY = "array"
OBJECT = "object"
TEXTAREA = "textarea"
class FieldSchema:
"""字段架构"""
def __init__(
self,
name: str,
type: FieldType,
label: str,
description: str = "",
default: Any = None,
required: bool = True,
options: Optional[list[str]] = None,
min_value: Optional[float] = None,
max_value: Optional[float] = None,
items: Optional[dict] = None,
properties: Optional[dict] = None,
):
self.name = name
self.type = type
self.label = label
self.description = description
self.default = default
self.required = required
self.options = options
self.min_value = min_value
self.max_value = max_value
self.items = items
self.properties = properties
def to_dict(self) -> dict:
"""转换为字典"""
result = {
"name": self.name,
"type": self.type.value,
"label": self.label,
"description": self.description,
"required": self.required,
}
if self.default is not None:
result["default"] = self.default
if self.options is not None:
result["options"] = self.options
if self.min_value is not None:
result["minValue"] = self.min_value
if self.max_value is not None:
result["maxValue"] = self.max_value
if self.items is not None:
result["items"] = self.items
if self.properties is not None:
result["properties"] = self.properties
return result
class ConfigSchemaGenerator:
"""配置架构生成器"""
@staticmethod
def _extract_field_description(config_class: type, field_name: str) -> str:
"""
从类定义中提取字段的文档字符串描述
Args:
config_class: 配置类
field_name: 字段名
Returns:
str: 字段描述
"""
try:
# 获取源代码
source = inspect.getsource(config_class)
lines = source.split("\n")
# 查找字段定义
field_found = False
description_lines = []
for i, line in enumerate(lines):
# 匹配字段定义行,例如: platform: str
if f"{field_name}:" in line and "=" in line:
field_found = True
# 查找下一行的文档字符串
if i + 1 < len(lines):
next_line = lines[i + 1].strip()
if next_line.startswith('"""') or next_line.startswith("'''"):
# 单行文档字符串
if next_line.count('"""') == 2 or next_line.count("'''") == 2:
description_lines.append(next_line.strip('"""').strip("'''").strip())
else:
# 多行文档字符串
quote = '"""' if next_line.startswith('"""') else "'''"
description_lines.append(next_line.strip(quote).strip())
for j in range(i + 2, len(lines)):
if quote in lines[j]:
description_lines.append(lines[j].split(quote)[0].strip())
break
description_lines.append(lines[j].strip())
break
elif f"{field_name}:" in line and "=" not in line:
# 没有默认值的字段
field_found = True
if i + 1 < len(lines):
next_line = lines[i + 1].strip()
if next_line.startswith('"""') or next_line.startswith("'''"):
if next_line.count('"""') == 2 or next_line.count("'''") == 2:
description_lines.append(next_line.strip('"""').strip("'''").strip())
else:
quote = '"""' if next_line.startswith('"""') else "'''"
description_lines.append(next_line.strip(quote).strip())
for j in range(i + 2, len(lines)):
if quote in lines[j]:
description_lines.append(lines[j].split(quote)[0].strip())
break
description_lines.append(lines[j].strip())
break
if field_found and description_lines:
return " ".join(description_lines)
except Exception:
pass
return ""
@staticmethod
def _get_field_type_and_options(field_type: type) -> tuple[FieldType, Optional[list[str]], Optional[dict]]:
"""
获取字段类型和选项
Args:
field_type: 字段类型
Returns:
tuple: (FieldType, options, items)
"""
origin = get_origin(field_type)
args = get_args(field_type)
# 处理 Literal 类型(枚举选项)
if origin is Literal:
return FieldType.SELECT, [str(arg) for arg in args], None
# 处理 list 类型
if origin is list:
item_type = args[0] if args else str
if item_type is str:
items = {"type": "string"}
elif item_type is int:
items = {"type": "integer"}
elif item_type is float:
items = {"type": "number"}
elif item_type is bool:
items = {"type": "boolean"}
elif item_type is dict:
items = {"type": "object"}
else:
items = {"type": "string"}
return FieldType.ARRAY, None, items
# 处理 set 类型(与 list 类似)
if origin is set:
item_type = args[0] if args else str
if item_type is str:
items = {"type": "string"}
else:
items = {"type": "string"}
return FieldType.ARRAY, None, items
# 处理基本类型
if field_type is bool or field_type == bool:
return FieldType.BOOLEAN, None, None
elif field_type is int or field_type == int:
return FieldType.INTEGER, None, None
elif field_type is float or field_type == float:
return FieldType.NUMBER, None, None
elif field_type is str or field_type == str:
return FieldType.STRING, None, None
elif field_type is dict or origin is dict:
return FieldType.OBJECT, None, None
# 默认为字符串
return FieldType.STRING, None, None
@staticmethod
def _format_field_name(name: str) -> str:
"""
格式化字段名为可读的标签
Args:
name: 原始字段名
Returns:
str: 格式化后的标签
"""
# 将下划线替换为空格,并首字母大写
return " ".join(word.capitalize() for word in name.split("_"))
@staticmethod
def generate_schema(config_class: type[ConfigBase], include_nested: bool = True) -> dict:
"""
从配置类生成前端表单架构
Args:
config_class: 配置类(必须继承自 ConfigBase
include_nested: 是否包含嵌套的配置对象
Returns:
dict: 前端表单架构
"""
if not issubclass(config_class, ConfigBase):
raise ValueError(f"{config_class.__name__} 必须继承自 ConfigBase")
schema_fields = []
nested_schemas = {}
for field in fields(config_class):
# 跳过私有字段和内部字段
if field.name.startswith("_") or field.name in ["MMC_VERSION"]:
continue
# 提取字段描述
description = ConfigSchemaGenerator._extract_field_description(config_class, field.name)
# 判断是否必填
required = field.default is MISSING and field.default_factory is MISSING
# 获取默认值
default_value = None
if field.default is not MISSING:
default_value = field.default
elif field.default_factory is not MISSING:
try:
default_value = field.default_factory()
except Exception:
default_value = None
# 检查是否为嵌套的 ConfigBase
if isinstance(field.type, type) and issubclass(field.type, ConfigBase):
if include_nested:
# 递归生成嵌套配置的架构
nested_schema = ConfigSchemaGenerator.generate_schema(field.type, include_nested=True)
nested_schemas[field.name] = nested_schema
field_schema = FieldSchema(
name=field.name,
type=FieldType.OBJECT,
label=ConfigSchemaGenerator._format_field_name(field.name),
description=description or field.type.__doc__ or "",
default=default_value,
required=required,
properties=nested_schema,
)
else:
continue
else:
# 获取字段类型和选项
field_type, options, items = ConfigSchemaGenerator._get_field_type_and_options(field.type)
# 特殊处理:长文本使用 textarea
if field_type == FieldType.STRING and field.name in [
"personality",
"reply_style",
"interest",
"plan_style",
"visual_style",
"private_plan_style",
"emotion_style",
"reaction",
"filtration_prompt",
]:
field_type = FieldType.TEXTAREA
field_schema = FieldSchema(
name=field.name,
type=field_type,
label=ConfigSchemaGenerator._format_field_name(field.name),
description=description,
default=default_value,
required=required,
options=options,
items=items,
)
schema_fields.append(field_schema.to_dict())
return {
"className": config_class.__name__,
"classDoc": config_class.__doc__ or "",
"fields": schema_fields,
"nested": nested_schemas if nested_schemas else None,
}
@staticmethod
def generate_config_schema(config_class: type[ConfigBase]) -> dict:
"""
生成完整的配置架构(包含所有嵌套的子配置)
Args:
config_class: 配置类
Returns:
dict: 完整的配置架构
"""
return ConfigSchemaGenerator.generate_schema(config_class, include_nested=True)

562
src/webui/emoji_routes.py Normal file
View File

@@ -0,0 +1,562 @@
"""表情包管理 API 路由"""
from fastapi import APIRouter, HTTPException, Header, Query
from fastapi.responses import FileResponse
from pydantic import BaseModel
from typing import Optional, List
from src.common.logger import get_logger
from src.common.database.database_model import Emoji
from .token_manager import get_token_manager
import json
import time
import os
logger = get_logger("webui.emoji")
# 创建路由器
router = APIRouter(prefix="/emoji", tags=["Emoji"])
class EmojiResponse(BaseModel):
"""表情包响应"""
id: int
full_path: str
format: str
emoji_hash: str
description: str
query_count: int
is_registered: bool
is_banned: bool
emotion: Optional[str] # 直接返回字符串
record_time: float
register_time: Optional[float]
usage_count: int
last_used_time: Optional[float]
class EmojiListResponse(BaseModel):
"""表情包列表响应"""
success: bool
total: int
page: int
page_size: int
data: List[EmojiResponse]
class EmojiDetailResponse(BaseModel):
"""表情包详情响应"""
success: bool
data: EmojiResponse
class EmojiUpdateRequest(BaseModel):
"""表情包更新请求"""
description: Optional[str] = None
is_registered: Optional[bool] = None
is_banned: Optional[bool] = None
emotion: Optional[str] = None
class EmojiUpdateResponse(BaseModel):
"""表情包更新响应"""
success: bool
message: str
data: Optional[EmojiResponse] = None
class EmojiDeleteResponse(BaseModel):
"""表情包删除响应"""
success: bool
message: str
class BatchDeleteRequest(BaseModel):
"""批量删除请求"""
emoji_ids: List[int]
class BatchDeleteResponse(BaseModel):
"""批量删除响应"""
success: bool
message: str
deleted_count: int
failed_count: int
failed_ids: List[int] = []
def verify_auth_token(authorization: Optional[str]) -> bool:
"""验证认证 Token"""
if not authorization or not authorization.startswith("Bearer "):
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
token = authorization.replace("Bearer ", "")
token_manager = get_token_manager()
if not token_manager.verify_token(token):
raise HTTPException(status_code=401, detail="Token 无效或已过期")
return True
def emoji_to_response(emoji: Emoji) -> EmojiResponse:
"""将 Emoji 模型转换为响应对象"""
return EmojiResponse(
id=emoji.id,
full_path=emoji.full_path,
format=emoji.format,
emoji_hash=emoji.emoji_hash,
description=emoji.description,
query_count=emoji.query_count,
is_registered=emoji.is_registered,
is_banned=emoji.is_banned,
emotion=str(emoji.emotion) if emoji.emotion is not None else None,
record_time=emoji.record_time,
register_time=emoji.register_time,
usage_count=emoji.usage_count,
last_used_time=emoji.last_used_time,
)
@router.get("/list", response_model=EmojiListResponse)
async def get_emoji_list(
page: int = Query(1, ge=1, description="页码"),
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
search: Optional[str] = Query(None, description="搜索关键词"),
is_registered: Optional[bool] = Query(None, description="是否已注册筛选"),
is_banned: Optional[bool] = Query(None, description="是否被禁用筛选"),
format: Optional[str] = Query(None, description="格式筛选"),
authorization: Optional[str] = Header(None),
):
"""
获取表情包列表
Args:
page: 页码 (从 1 开始)
page_size: 每页数量 (1-100)
search: 搜索关键词 (匹配 description, emoji_hash)
is_registered: 是否已注册筛选
is_banned: 是否被禁用筛选
format: 格式筛选
authorization: Authorization header
Returns:
表情包列表
"""
try:
verify_auth_token(authorization)
# 构建查询
query = Emoji.select()
# 搜索过滤
if search:
query = query.where((Emoji.description.contains(search)) | (Emoji.emoji_hash.contains(search)))
# 注册状态过滤
if is_registered is not None:
query = query.where(Emoji.is_registered == is_registered)
# 禁用状态过滤
if is_banned is not None:
query = query.where(Emoji.is_banned == is_banned)
# 格式过滤
if format:
query = query.where(Emoji.format == format)
# 排序:使用次数倒序,然后按记录时间倒序
from peewee import Case
query = query.order_by(
Emoji.usage_count.desc(), Case(None, [(Emoji.record_time.is_null(), 1)], 0), Emoji.record_time.desc()
)
# 获取总数
total = query.count()
# 分页
offset = (page - 1) * page_size
emojis = query.offset(offset).limit(page_size)
# 转换为响应对象
data = [emoji_to_response(emoji) for emoji in emojis]
return EmojiListResponse(success=True, total=total, page=page, page_size=page_size, data=data)
except HTTPException:
raise
except Exception as e:
logger.exception(f"获取表情包列表失败: {e}")
raise HTTPException(status_code=500, detail=f"获取表情包列表失败: {str(e)}") from e
@router.get("/{emoji_id}", response_model=EmojiDetailResponse)
async def get_emoji_detail(emoji_id: int, authorization: Optional[str] = Header(None)):
"""
获取表情包详细信息
Args:
emoji_id: 表情包ID
authorization: Authorization header
Returns:
表情包详细信息
"""
try:
verify_auth_token(authorization)
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
if not emoji:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
return EmojiDetailResponse(success=True, data=emoji_to_response(emoji))
except HTTPException:
raise
except Exception as e:
logger.exception(f"获取表情包详情失败: {e}")
raise HTTPException(status_code=500, detail=f"获取表情包详情失败: {str(e)}") from e
@router.patch("/{emoji_id}", response_model=EmojiUpdateResponse)
async def update_emoji(emoji_id: int, request: EmojiUpdateRequest, authorization: Optional[str] = Header(None)):
"""
增量更新表情包(只更新提供的字段)
Args:
emoji_id: 表情包ID
request: 更新请求(只包含需要更新的字段)
authorization: Authorization header
Returns:
更新结果
"""
try:
verify_auth_token(authorization)
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
if not emoji:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
# 只更新提供的字段
update_data = request.model_dump(exclude_unset=True)
if not update_data:
raise HTTPException(status_code=400, detail="未提供任何需要更新的字段")
# emotion 字段直接使用字符串,无需转换
# 如果注册状态从 False 变为 True记录注册时间
if "is_registered" in update_data and update_data["is_registered"] and not emoji.is_registered:
update_data["register_time"] = time.time()
# 执行更新
for field, value in update_data.items():
setattr(emoji, field, value)
emoji.save()
logger.info(f"表情包已更新: ID={emoji_id}, 字段: {list(update_data.keys())}")
return EmojiUpdateResponse(
success=True, message=f"成功更新 {len(update_data)} 个字段", data=emoji_to_response(emoji)
)
except HTTPException:
raise
except Exception as e:
logger.exception(f"更新表情包失败: {e}")
raise HTTPException(status_code=500, detail=f"更新表情包失败: {str(e)}") from e
@router.delete("/{emoji_id}", response_model=EmojiDeleteResponse)
async def delete_emoji(emoji_id: int, authorization: Optional[str] = Header(None)):
"""
删除表情包
Args:
emoji_id: 表情包ID
authorization: Authorization header
Returns:
删除结果
"""
try:
verify_auth_token(authorization)
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
if not emoji:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
# 记录删除信息
emoji_hash = emoji.emoji_hash
# 执行删除
emoji.delete_instance()
logger.info(f"表情包已删除: ID={emoji_id}, hash={emoji_hash}")
return EmojiDeleteResponse(success=True, message=f"成功删除表情包: {emoji_hash}")
except HTTPException:
raise
except Exception as e:
logger.exception(f"删除表情包失败: {e}")
raise HTTPException(status_code=500, detail=f"删除表情包失败: {str(e)}") from e
@router.get("/stats/summary")
async def get_emoji_stats(authorization: Optional[str] = Header(None)):
"""
获取表情包统计数据
Args:
authorization: Authorization header
Returns:
统计数据
"""
try:
verify_auth_token(authorization)
total = Emoji.select().count()
registered = Emoji.select().where(Emoji.is_registered).count()
banned = Emoji.select().where(Emoji.is_banned).count()
# 按格式统计
formats = {}
for emoji in Emoji.select(Emoji.format):
fmt = emoji.format
formats[fmt] = formats.get(fmt, 0) + 1
# 获取最常用的表情包前10
top_used = Emoji.select().order_by(Emoji.usage_count.desc()).limit(10)
top_used_list = [
{
"id": emoji.id,
"emoji_hash": emoji.emoji_hash,
"description": emoji.description,
"usage_count": emoji.usage_count,
}
for emoji in top_used
]
return {
"success": True,
"data": {
"total": total,
"registered": registered,
"banned": banned,
"unregistered": total - registered,
"formats": formats,
"top_used": top_used_list,
},
}
except HTTPException:
raise
except Exception as e:
logger.exception(f"获取统计数据失败: {e}")
raise HTTPException(status_code=500, detail=f"获取统计数据失败: {str(e)}") from e
@router.post("/{emoji_id}/register", response_model=EmojiUpdateResponse)
async def register_emoji(emoji_id: int, authorization: Optional[str] = Header(None)):
"""
注册表情包(快捷操作)
Args:
emoji_id: 表情包ID
authorization: Authorization header
Returns:
更新结果
"""
try:
verify_auth_token(authorization)
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
if not emoji:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
if emoji.is_registered:
raise HTTPException(status_code=400, detail="该表情包已经注册")
if emoji.is_banned:
raise HTTPException(status_code=400, detail="该表情包已被禁用,无法注册")
# 注册表情包
emoji.is_registered = True
emoji.register_time = time.time()
emoji.save()
logger.info(f"表情包已注册: ID={emoji_id}")
return EmojiUpdateResponse(success=True, message="表情包注册成功", data=emoji_to_response(emoji))
except HTTPException:
raise
except Exception as e:
logger.exception(f"注册表情包失败: {e}")
raise HTTPException(status_code=500, detail=f"注册表情包失败: {str(e)}") from e
@router.post("/{emoji_id}/ban", response_model=EmojiUpdateResponse)
async def ban_emoji(emoji_id: int, authorization: Optional[str] = Header(None)):
"""
禁用表情包(快捷操作)
Args:
emoji_id: 表情包ID
authorization: Authorization header
Returns:
更新结果
"""
try:
verify_auth_token(authorization)
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
if not emoji:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
# 禁用表情包(同时取消注册)
emoji.is_banned = True
emoji.is_registered = False
emoji.save()
logger.info(f"表情包已禁用: ID={emoji_id}")
return EmojiUpdateResponse(success=True, message="表情包禁用成功", data=emoji_to_response(emoji))
except HTTPException:
raise
except Exception as e:
logger.exception(f"禁用表情包失败: {e}")
raise HTTPException(status_code=500, detail=f"禁用表情包失败: {str(e)}") from e
@router.get("/{emoji_id}/thumbnail")
async def get_emoji_thumbnail(
emoji_id: int,
token: Optional[str] = Query(None, description="访问令牌"),
authorization: Optional[str] = Header(None),
):
"""
获取表情包缩略图
Args:
emoji_id: 表情包ID
token: 访问令牌(通过 query parameter
authorization: Authorization header
Returns:
表情包图片文件
"""
try:
# 优先使用 query parameter 中的 token用于 img 标签)
if token:
token_manager = get_token_manager()
if not token_manager.verify_token(token):
raise HTTPException(status_code=401, detail="Token 无效或已过期")
else:
# 如果没有 query token则验证 Authorization header
verify_auth_token(authorization)
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
if not emoji:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
# 检查文件是否存在
if not os.path.exists(emoji.full_path):
raise HTTPException(status_code=404, detail="表情包文件不存在")
# 根据格式设置 MIME 类型
mime_types = {
"png": "image/png",
"jpg": "image/jpeg",
"jpeg": "image/jpeg",
"gif": "image/gif",
"webp": "image/webp",
"bmp": "image/bmp",
}
media_type = mime_types.get(emoji.format.lower(), "application/octet-stream")
return FileResponse(path=emoji.full_path, media_type=media_type, filename=f"{emoji.emoji_hash}.{emoji.format}")
except HTTPException:
raise
except Exception as e:
logger.exception(f"获取表情包缩略图失败: {e}")
raise HTTPException(status_code=500, detail=f"获取表情包缩略图失败: {str(e)}") from e
@router.post("/batch/delete", response_model=BatchDeleteResponse)
async def batch_delete_emojis(request: BatchDeleteRequest, authorization: Optional[str] = Header(None)):
"""
批量删除表情包
Args:
request: 包含emoji_ids列表的请求
authorization: Authorization header
Returns:
批量删除结果
"""
try:
verify_auth_token(authorization)
if not request.emoji_ids:
raise HTTPException(status_code=400, detail="未提供要删除的表情包ID")
deleted_count = 0
failed_count = 0
failed_ids = []
for emoji_id in request.emoji_ids:
try:
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
if emoji:
emoji.delete_instance()
deleted_count += 1
logger.info(f"批量删除表情包: {emoji_id}")
else:
failed_count += 1
failed_ids.append(emoji_id)
except Exception as e:
logger.error(f"删除表情包 {emoji_id} 失败: {e}")
failed_count += 1
failed_ids.append(emoji_id)
message = f"成功删除 {deleted_count} 个表情包"
if failed_count > 0:
message += f"{failed_count} 个失败"
return BatchDeleteResponse(
success=True,
message=message,
deleted_count=deleted_count,
failed_count=failed_count,
failed_ids=failed_ids,
)
except HTTPException:
raise
except Exception as e:
logger.exception(f"批量删除表情包失败: {e}")
raise HTTPException(status_code=500, detail=f"批量删除失败: {str(e)}") from e

View File

@@ -0,0 +1,432 @@
"""表达方式管理 API 路由"""
from fastapi import APIRouter, HTTPException, Header, Query
from pydantic import BaseModel
from typing import Optional, List
from src.common.logger import get_logger
from src.common.database.database_model import Expression
from .token_manager import get_token_manager
import time
logger = get_logger("webui.expression")
# 创建路由器
router = APIRouter(prefix="/expression", tags=["Expression"])
class ExpressionResponse(BaseModel):
"""表达方式响应"""
id: int
situation: str
style: str
context: Optional[str]
up_content: Optional[str]
last_active_time: float
chat_id: str
create_date: Optional[float]
class ExpressionListResponse(BaseModel):
"""表达方式列表响应"""
success: bool
total: int
page: int
page_size: int
data: List[ExpressionResponse]
class ExpressionDetailResponse(BaseModel):
"""表达方式详情响应"""
success: bool
data: ExpressionResponse
class ExpressionCreateRequest(BaseModel):
"""表达方式创建请求"""
situation: str
style: str
context: Optional[str] = None
up_content: Optional[str] = None
chat_id: str
class ExpressionUpdateRequest(BaseModel):
"""表达方式更新请求"""
situation: Optional[str] = None
style: Optional[str] = None
context: Optional[str] = None
up_content: Optional[str] = None
chat_id: Optional[str] = None
class ExpressionUpdateResponse(BaseModel):
"""表达方式更新响应"""
success: bool
message: str
data: Optional[ExpressionResponse] = None
class ExpressionDeleteResponse(BaseModel):
"""表达方式删除响应"""
success: bool
message: str
class ExpressionCreateResponse(BaseModel):
"""表达方式创建响应"""
success: bool
message: str
data: ExpressionResponse
def verify_auth_token(authorization: Optional[str]) -> bool:
"""验证认证 Token"""
if not authorization or not authorization.startswith("Bearer "):
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
token = authorization.replace("Bearer ", "")
token_manager = get_token_manager()
if not token_manager.verify_token(token):
raise HTTPException(status_code=401, detail="Token 无效或已过期")
return True
def expression_to_response(expression: Expression) -> ExpressionResponse:
"""将 Expression 模型转换为响应对象"""
return ExpressionResponse(
id=expression.id,
situation=expression.situation,
style=expression.style,
context=expression.context,
up_content=expression.up_content,
last_active_time=expression.last_active_time,
chat_id=expression.chat_id,
create_date=expression.create_date,
)
@router.get("/list", response_model=ExpressionListResponse)
async def get_expression_list(
page: int = Query(1, ge=1, description="页码"),
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
search: Optional[str] = Query(None, description="搜索关键词"),
chat_id: Optional[str] = Query(None, description="聊天ID筛选"),
authorization: Optional[str] = Header(None),
):
"""
获取表达方式列表
Args:
page: 页码 (从 1 开始)
page_size: 每页数量 (1-100)
search: 搜索关键词 (匹配 situation, style, context)
chat_id: 聊天ID筛选
authorization: Authorization header
Returns:
表达方式列表
"""
try:
verify_auth_token(authorization)
# 构建查询
query = Expression.select()
# 搜索过滤
if search:
query = query.where(
(Expression.situation.contains(search))
| (Expression.style.contains(search))
| (Expression.context.contains(search))
)
# 聊天ID过滤
if chat_id:
query = query.where(Expression.chat_id == chat_id)
# 排序最后活跃时间倒序NULL 值放在最后)
from peewee import Case
query = query.order_by(
Case(None, [(Expression.last_active_time.is_null(), 1)], 0), Expression.last_active_time.desc()
)
# 获取总数
total = query.count()
# 分页
offset = (page - 1) * page_size
expressions = query.offset(offset).limit(page_size)
# 转换为响应对象
data = [expression_to_response(expr) for expr in expressions]
return ExpressionListResponse(success=True, total=total, page=page, page_size=page_size, data=data)
except HTTPException:
raise
except Exception as e:
logger.exception(f"获取表达方式列表失败: {e}")
raise HTTPException(status_code=500, detail=f"获取表达方式列表失败: {str(e)}") from e
@router.get("/{expression_id}", response_model=ExpressionDetailResponse)
async def get_expression_detail(expression_id: int, authorization: Optional[str] = Header(None)):
"""
获取表达方式详细信息
Args:
expression_id: 表达方式ID
authorization: Authorization header
Returns:
表达方式详细信息
"""
try:
verify_auth_token(authorization)
expression = Expression.get_or_none(Expression.id == expression_id)
if not expression:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
return ExpressionDetailResponse(success=True, data=expression_to_response(expression))
except HTTPException:
raise
except Exception as e:
logger.exception(f"获取表达方式详情失败: {e}")
raise HTTPException(status_code=500, detail=f"获取表达方式详情失败: {str(e)}") from e
@router.post("/", response_model=ExpressionCreateResponse)
async def create_expression(request: ExpressionCreateRequest, authorization: Optional[str] = Header(None)):
"""
创建新的表达方式
Args:
request: 创建请求
authorization: Authorization header
Returns:
创建结果
"""
try:
verify_auth_token(authorization)
current_time = time.time()
# 创建表达方式
expression = Expression.create(
situation=request.situation,
style=request.style,
context=request.context,
up_content=request.up_content,
chat_id=request.chat_id,
last_active_time=current_time,
create_date=current_time,
)
logger.info(f"表达方式已创建: ID={expression.id}, situation={request.situation}")
return ExpressionCreateResponse(
success=True, message="表达方式创建成功", data=expression_to_response(expression)
)
except HTTPException:
raise
except Exception as e:
logger.exception(f"创建表达方式失败: {e}")
raise HTTPException(status_code=500, detail=f"创建表达方式失败: {str(e)}") from e
@router.patch("/{expression_id}", response_model=ExpressionUpdateResponse)
async def update_expression(
expression_id: int, request: ExpressionUpdateRequest, authorization: Optional[str] = Header(None)
):
"""
增量更新表达方式(只更新提供的字段)
Args:
expression_id: 表达方式ID
request: 更新请求(只包含需要更新的字段)
authorization: Authorization header
Returns:
更新结果
"""
try:
verify_auth_token(authorization)
expression = Expression.get_or_none(Expression.id == expression_id)
if not expression:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
# 只更新提供的字段
update_data = request.model_dump(exclude_unset=True)
if not update_data:
raise HTTPException(status_code=400, detail="未提供任何需要更新的字段")
# 更新最后活跃时间
update_data["last_active_time"] = time.time()
# 执行更新
for field, value in update_data.items():
setattr(expression, field, value)
expression.save()
logger.info(f"表达方式已更新: ID={expression_id}, 字段: {list(update_data.keys())}")
return ExpressionUpdateResponse(
success=True, message=f"成功更新 {len(update_data)} 个字段", data=expression_to_response(expression)
)
except HTTPException:
raise
except Exception as e:
logger.exception(f"更新表达方式失败: {e}")
raise HTTPException(status_code=500, detail=f"更新表达方式失败: {str(e)}") from e
@router.delete("/{expression_id}", response_model=ExpressionDeleteResponse)
async def delete_expression(expression_id: int, authorization: Optional[str] = Header(None)):
"""
删除表达方式
Args:
expression_id: 表达方式ID
authorization: Authorization header
Returns:
删除结果
"""
try:
verify_auth_token(authorization)
expression = Expression.get_or_none(Expression.id == expression_id)
if not expression:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
# 记录删除信息
situation = expression.situation
# 执行删除
expression.delete_instance()
logger.info(f"表达方式已删除: ID={expression_id}, situation={situation}")
return ExpressionDeleteResponse(success=True, message=f"成功删除表达方式: {situation}")
except HTTPException:
raise
except Exception as e:
logger.exception(f"删除表达方式失败: {e}")
raise HTTPException(status_code=500, detail=f"删除表达方式失败: {str(e)}") from e
class BatchDeleteRequest(BaseModel):
"""批量删除请求"""
ids: List[int]
@router.post("/batch/delete", response_model=ExpressionDeleteResponse)
async def batch_delete_expressions(request: BatchDeleteRequest, authorization: Optional[str] = Header(None)):
"""
批量删除表达方式
Args:
request: 包含要删除的ID列表的请求
authorization: Authorization header
Returns:
删除结果
"""
try:
verify_auth_token(authorization)
if not request.ids:
raise HTTPException(status_code=400, detail="未提供要删除的表达方式ID")
# 查找所有要删除的表达方式
expressions = Expression.select().where(Expression.id.in_(request.ids))
found_ids = [expr.id for expr in expressions]
# 检查是否有未找到的ID
not_found_ids = set(request.ids) - set(found_ids)
if not_found_ids:
logger.warning(f"部分表达方式未找到: {not_found_ids}")
# 执行批量删除
deleted_count = Expression.delete().where(Expression.id.in_(found_ids)).execute()
logger.info(f"批量删除了 {deleted_count} 个表达方式")
return ExpressionDeleteResponse(success=True, message=f"成功删除 {deleted_count} 个表达方式")
except HTTPException:
raise
except Exception as e:
logger.exception(f"批量删除表达方式失败: {e}")
raise HTTPException(status_code=500, detail=f"批量删除表达方式失败: {str(e)}") from e
@router.get("/stats/summary")
async def get_expression_stats(authorization: Optional[str] = Header(None)):
"""
获取表达方式统计数据
Args:
authorization: Authorization header
Returns:
统计数据
"""
try:
verify_auth_token(authorization)
total = Expression.select().count()
# 按 chat_id 统计
chat_stats = {}
for expr in Expression.select(Expression.chat_id):
chat_id = expr.chat_id
chat_stats[chat_id] = chat_stats.get(chat_id, 0) + 1
# 获取最近创建的记录数7天内
seven_days_ago = time.time() - (7 * 24 * 60 * 60)
recent = (
Expression.select()
.where((Expression.create_date.is_null(False)) & (Expression.create_date >= seven_days_ago))
.count()
)
return {
"success": True,
"data": {
"total": total,
"recent_7days": recent,
"chat_count": len(chat_stats),
"top_chats": dict(sorted(chat_stats.items(), key=lambda x: x[1], reverse=True)[:10]),
},
}
except HTTPException:
raise
except Exception as e:
logger.exception(f"获取统计数据失败: {e}")
raise HTTPException(status_code=500, detail=f"获取统计数据失败: {str(e)}") from e

View File

@@ -0,0 +1,662 @@
"""Git 镜像源服务 - 支持多镜像源、错误重试、Git 克隆和 Raw 文件获取"""
from typing import Optional, List, Dict, Any
from enum import Enum
import httpx
import json
import asyncio
import subprocess
import shutil
from pathlib import Path
from datetime import datetime
from src.common.logger import get_logger
logger = get_logger("webui.git_mirror")
# 导入进度更新函数(避免循环导入)
_update_progress = None
def set_update_progress_callback(callback):
"""设置进度更新回调函数"""
global _update_progress
_update_progress = callback
class MirrorType(str, Enum):
"""镜像源类型"""
GH_PROXY = "gh-proxy" # gh-proxy 主节点
HK_GH_PROXY = "hk-gh-proxy" # gh-proxy 香港节点
CDN_GH_PROXY = "cdn-gh-proxy" # gh-proxy CDN 节点
EDGEONE_GH_PROXY = "edgeone-gh-proxy" # gh-proxy EdgeOne 节点
MEYZH_GITHUB = "meyzh-github" # Meyzh GitHub 镜像
GITHUB = "github" # GitHub 官方源(兜底)
CUSTOM = "custom" # 自定义镜像源
class GitMirrorConfig:
"""Git 镜像源配置管理"""
# 配置文件路径
CONFIG_FILE = Path("data/webui.json")
# 默认镜像源配置
DEFAULT_MIRRORS = [
{
"id": "gh-proxy",
"name": "gh-proxy 镜像",
"raw_prefix": "https://gh-proxy.org/https://raw.githubusercontent.com",
"clone_prefix": "https://gh-proxy.org/https://github.com",
"enabled": True,
"priority": 1,
"created_at": None,
},
{
"id": "hk-gh-proxy",
"name": "gh-proxy 香港节点",
"raw_prefix": "https://hk.gh-proxy.org/https://raw.githubusercontent.com",
"clone_prefix": "https://hk.gh-proxy.org/https://github.com",
"enabled": True,
"priority": 2,
"created_at": None,
},
{
"id": "cdn-gh-proxy",
"name": "gh-proxy CDN 节点",
"raw_prefix": "https://cdn.gh-proxy.org/https://raw.githubusercontent.com",
"clone_prefix": "https://cdn.gh-proxy.org/https://github.com",
"enabled": True,
"priority": 3,
"created_at": None,
},
{
"id": "edgeone-gh-proxy",
"name": "gh-proxy EdgeOne 节点",
"raw_prefix": "https://edgeone.gh-proxy.org/https://raw.githubusercontent.com",
"clone_prefix": "https://edgeone.gh-proxy.org/https://github.com",
"enabled": True,
"priority": 4,
"created_at": None,
},
{
"id": "meyzh-github",
"name": "Meyzh GitHub 镜像",
"raw_prefix": "https://meyzh.github.io/https://raw.githubusercontent.com",
"clone_prefix": "https://meyzh.github.io/https://github.com",
"enabled": True,
"priority": 5,
"created_at": None,
},
{
"id": "github",
"name": "GitHub 官方源(兜底)",
"raw_prefix": "https://raw.githubusercontent.com",
"clone_prefix": "https://github.com",
"enabled": True,
"priority": 999,
"created_at": None,
},
]
def __init__(self):
"""初始化配置管理器"""
self.config_file = self.CONFIG_FILE
self.mirrors: List[Dict[str, Any]] = []
self._load_config()
def _load_config(self) -> None:
"""加载配置文件"""
try:
if self.config_file.exists():
with open(self.config_file, "r", encoding="utf-8") as f:
data = json.load(f)
# 检查是否有镜像源配置
if "git_mirrors" not in data or not data["git_mirrors"]:
logger.info("配置文件中未找到镜像源配置,使用默认配置")
self._init_default_mirrors()
else:
self.mirrors = data["git_mirrors"]
logger.info(f"已加载 {len(self.mirrors)} 个镜像源配置")
else:
logger.info("配置文件不存在,创建默认配置")
self._init_default_mirrors()
except Exception as e:
logger.error(f"加载配置文件失败: {e}")
self._init_default_mirrors()
def _init_default_mirrors(self) -> None:
"""初始化默认镜像源"""
current_time = datetime.now().isoformat()
self.mirrors = []
for mirror in self.DEFAULT_MIRRORS:
mirror_copy = mirror.copy()
mirror_copy["created_at"] = current_time
self.mirrors.append(mirror_copy)
self._save_config()
logger.info(f"已初始化 {len(self.mirrors)} 个默认镜像源")
def _save_config(self) -> None:
"""保存配置到文件"""
try:
# 确保目录存在
self.config_file.parent.mkdir(parents=True, exist_ok=True)
# 读取现有配置
existing_data = {}
if self.config_file.exists():
with open(self.config_file, "r", encoding="utf-8") as f:
existing_data = json.load(f)
# 更新镜像源配置
existing_data["git_mirrors"] = self.mirrors
# 写入文件
with open(self.config_file, "w", encoding="utf-8") as f:
json.dump(existing_data, f, indent=2, ensure_ascii=False)
logger.debug(f"配置已保存到 {self.config_file}")
except Exception as e:
logger.error(f"保存配置文件失败: {e}")
def get_all_mirrors(self) -> List[Dict[str, Any]]:
"""获取所有镜像源"""
return self.mirrors.copy()
def get_enabled_mirrors(self) -> List[Dict[str, Any]]:
"""获取所有启用的镜像源,按优先级排序"""
enabled = [m for m in self.mirrors if m.get("enabled", False)]
return sorted(enabled, key=lambda x: x.get("priority", 999))
def get_mirror_by_id(self, mirror_id: str) -> Optional[Dict[str, Any]]:
"""根据 ID 获取镜像源"""
for mirror in self.mirrors:
if mirror.get("id") == mirror_id:
return mirror.copy()
return None
def add_mirror(
self,
mirror_id: str,
name: str,
raw_prefix: str,
clone_prefix: str,
enabled: bool = True,
priority: Optional[int] = None,
) -> Dict[str, Any]:
"""
添加新的镜像源
Returns:
添加的镜像源配置
Raises:
ValueError: 如果镜像源 ID 已存在
"""
# 检查 ID 是否已存在
if self.get_mirror_by_id(mirror_id):
raise ValueError(f"镜像源 ID 已存在: {mirror_id}")
# 如果未指定优先级,使用最大优先级 + 1
if priority is None:
max_priority = max((m.get("priority", 0) for m in self.mirrors), default=0)
priority = max_priority + 1
new_mirror = {
"id": mirror_id,
"name": name,
"raw_prefix": raw_prefix,
"clone_prefix": clone_prefix,
"enabled": enabled,
"priority": priority,
"created_at": datetime.now().isoformat(),
}
self.mirrors.append(new_mirror)
self._save_config()
logger.info(f"已添加镜像源: {mirror_id} - {name}")
return new_mirror.copy()
def update_mirror(
self,
mirror_id: str,
name: Optional[str] = None,
raw_prefix: Optional[str] = None,
clone_prefix: Optional[str] = None,
enabled: Optional[bool] = None,
priority: Optional[int] = None,
) -> Optional[Dict[str, Any]]:
"""
更新镜像源配置
Returns:
更新后的镜像源配置,如果不存在则返回 None
"""
for mirror in self.mirrors:
if mirror.get("id") == mirror_id:
if name is not None:
mirror["name"] = name
if raw_prefix is not None:
mirror["raw_prefix"] = raw_prefix
if clone_prefix is not None:
mirror["clone_prefix"] = clone_prefix
if enabled is not None:
mirror["enabled"] = enabled
if priority is not None:
mirror["priority"] = priority
mirror["updated_at"] = datetime.now().isoformat()
self._save_config()
logger.info(f"已更新镜像源: {mirror_id}")
return mirror.copy()
return None
def delete_mirror(self, mirror_id: str) -> bool:
"""
删除镜像源
Returns:
True 如果删除成功False 如果镜像源不存在
"""
for i, mirror in enumerate(self.mirrors):
if mirror.get("id") == mirror_id:
self.mirrors.pop(i)
self._save_config()
logger.info(f"已删除镜像源: {mirror_id}")
return True
return False
def get_default_priority_list(self) -> List[str]:
"""获取默认优先级列表(仅启用的镜像源 ID"""
enabled = self.get_enabled_mirrors()
return [m["id"] for m in enabled]
class GitMirrorService:
"""Git 镜像源服务"""
def __init__(self, max_retries: int = 3, timeout: int = 30, config: Optional[GitMirrorConfig] = None):
"""
初始化 Git 镜像源服务
Args:
max_retries: 最大重试次数
timeout: 请求超时时间(秒)
config: 镜像源配置管理器(可选,默认创建新实例)
"""
self.max_retries = max_retries
self.timeout = timeout
self.config = config or GitMirrorConfig()
logger.info(f"Git镜像源服务初始化完成已加载 {len(self.config.get_enabled_mirrors())} 个启用的镜像源")
def get_mirror_config(self) -> GitMirrorConfig:
"""获取镜像源配置管理器"""
return self.config
@staticmethod
def check_git_installed() -> Dict[str, Any]:
"""
检查本机是否安装了 Git
Returns:
Dict 包含:
- installed: bool - 是否已安装 Git
- version: str - Git 版本号(如果已安装)
- path: str - Git 可执行文件路径(如果已安装)
- error: str - 错误信息(如果未安装或检测失败)
"""
import subprocess
import shutil
try:
# 查找 git 可执行文件路径
git_path = shutil.which("git")
if not git_path:
logger.warning("未找到 Git 可执行文件")
return {"installed": False, "error": "系统中未找到 Git请先安装 Git"}
# 获取 Git 版本
result = subprocess.run(["git", "--version"], capture_output=True, text=True, timeout=5)
if result.returncode == 0:
version = result.stdout.strip()
logger.info(f"检测到 Git: {version} at {git_path}")
return {"installed": True, "version": version, "path": git_path}
else:
logger.warning(f"Git 命令执行失败: {result.stderr}")
return {"installed": False, "error": f"Git 命令执行失败: {result.stderr}"}
except subprocess.TimeoutExpired:
logger.error("Git 版本检测超时")
return {"installed": False, "error": "Git 版本检测超时"}
except Exception as e:
logger.error(f"检测 Git 时发生错误: {e}")
return {"installed": False, "error": f"检测 Git 时发生错误: {str(e)}"}
async def fetch_raw_file(
self,
owner: str,
repo: str,
branch: str,
file_path: str,
mirror_id: Optional[str] = None,
custom_url: Optional[str] = None,
) -> Dict[str, Any]:
"""
获取 GitHub 仓库的 Raw 文件内容
Args:
owner: 仓库所有者
repo: 仓库名称
branch: 分支名称
file_path: 文件路径
mirror_id: 指定的镜像源 ID
custom_url: 自定义完整 URL如果提供将忽略其他参数
Returns:
Dict 包含:
- success: bool - 是否成功
- data: str - 文件内容(成功时)
- error: str - 错误信息(失败时)
- mirror_used: str - 使用的镜像源
- attempts: int - 尝试次数
"""
logger.info(f"开始获取 Raw 文件: {owner}/{repo}/{branch}/{file_path}")
if custom_url:
# 使用自定义 URL
return await self._fetch_with_url(custom_url, "custom")
# 确定要使用的镜像源列表
if mirror_id:
# 使用指定的镜像源
mirror = self.config.get_mirror_by_id(mirror_id)
if not mirror:
return {"success": False, "error": f"未找到镜像源: {mirror_id}", "mirror_used": None, "attempts": 0}
mirrors_to_try = [mirror]
else:
# 使用所有启用的镜像源
mirrors_to_try = self.config.get_enabled_mirrors()
total_mirrors = len(mirrors_to_try)
# 依次尝试每个镜像源
for index, mirror in enumerate(mirrors_to_try, 1):
# 推送进度:正在尝试第 N 个镜像源
if _update_progress:
try:
progress = 30 + int((index - 1) / total_mirrors * 40) # 30% - 70%
await _update_progress(
stage="loading",
progress=progress,
message=f"正在尝试镜像源 {index}/{total_mirrors}: {mirror['name']}",
total_plugins=0,
loaded_plugins=0,
)
except Exception as e:
logger.warning(f"推送进度失败: {e}")
result = await self._fetch_raw_from_mirror(owner, repo, branch, file_path, mirror)
if result["success"]:
# 成功,推送进度
if _update_progress:
try:
await _update_progress(
stage="loading",
progress=70,
message=f"成功从 {mirror['name']} 获取数据",
total_plugins=0,
loaded_plugins=0,
)
except Exception as e:
logger.warning(f"推送进度失败: {e}")
return result
# 失败,记录日志并推送失败信息
logger.warning(f"镜像源 {mirror['id']} 失败: {result.get('error')}")
if _update_progress and index < total_mirrors:
try:
await _update_progress(
stage="loading",
progress=30 + int(index / total_mirrors * 40),
message=f"镜像源 {mirror['name']} 失败,尝试下一个...",
total_plugins=0,
loaded_plugins=0,
)
except Exception as e:
logger.warning(f"推送进度失败: {e}")
# 所有镜像源都失败
return {"success": False, "error": "所有镜像源均失败", "mirror_used": None, "attempts": len(mirrors_to_try)}
async def _fetch_raw_from_mirror(
self, owner: str, repo: str, branch: str, file_path: str, mirror: Dict[str, Any]
) -> Dict[str, Any]:
"""从指定镜像源获取文件"""
# 构建 URL
raw_prefix = mirror["raw_prefix"]
url = f"{raw_prefix}/{owner}/{repo}/{branch}/{file_path}"
return await self._fetch_with_url(url, mirror["id"])
async def _fetch_with_url(self, url: str, mirror_type: str) -> Dict[str, Any]:
"""使用指定 URL 获取文件,支持重试"""
attempts = 0
last_error = None
for attempt in range(self.max_retries):
attempts += 1
try:
logger.debug(f"尝试 #{attempt + 1}: {url}")
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.get(url)
response.raise_for_status()
logger.info(f"成功获取文件: {url}")
return {
"success": True,
"data": response.text,
"mirror_used": mirror_type,
"attempts": attempts,
"url": url,
}
except httpx.HTTPStatusError as e:
last_error = f"HTTP {e.response.status_code}: {e}"
logger.warning(f"HTTP 错误 (尝试 {attempt + 1}/{self.max_retries}): {last_error}")
except httpx.TimeoutException as e:
last_error = f"请求超时: {e}"
logger.warning(f"超时 (尝试 {attempt + 1}/{self.max_retries}): {last_error}")
except Exception as e:
last_error = f"未知错误: {e}"
logger.error(f"错误 (尝试 {attempt + 1}/{self.max_retries}): {last_error}")
return {"success": False, "error": last_error, "mirror_used": mirror_type, "attempts": attempts, "url": url}
async def clone_repository(
self,
owner: str,
repo: str,
target_path: Path,
branch: Optional[str] = None,
mirror_id: Optional[str] = None,
custom_url: Optional[str] = None,
depth: Optional[int] = None,
) -> Dict[str, Any]:
"""
克隆 GitHub 仓库
Args:
owner: 仓库所有者
repo: 仓库名称
target_path: 目标路径
branch: 分支名称(可选)
mirror_id: 指定的镜像源 ID
custom_url: 自定义克隆 URL
depth: 克隆深度(浅克隆)
Returns:
Dict 包含:
- success: bool - 是否成功
- path: str - 克隆路径(成功时)
- error: str - 错误信息(失败时)
- mirror_used: str - 使用的镜像源
- attempts: int - 尝试次数
"""
logger.info(f"开始克隆仓库: {owner}/{repo}{target_path}")
if custom_url:
# 使用自定义 URL
return await self._clone_with_url(custom_url, target_path, branch, depth, "custom")
# 确定要使用的镜像源列表
if mirror_id:
# 使用指定的镜像源
mirror = self.config.get_mirror_by_id(mirror_id)
if not mirror:
return {"success": False, "error": f"未找到镜像源: {mirror_id}", "mirror_used": None, "attempts": 0}
mirrors_to_try = [mirror]
else:
# 使用所有启用的镜像源
mirrors_to_try = self.config.get_enabled_mirrors()
# 依次尝试每个镜像源
for mirror in mirrors_to_try:
result = await self._clone_from_mirror(owner, repo, target_path, branch, depth, mirror)
if result["success"]:
return result
logger.warning(f"镜像源 {mirror['id']} 克隆失败: {result.get('error')}")
# 所有镜像源都失败
return {"success": False, "error": "所有镜像源克隆均失败", "mirror_used": None, "attempts": len(mirrors_to_try)}
async def _clone_from_mirror(
self,
owner: str,
repo: str,
target_path: Path,
branch: Optional[str],
depth: Optional[int],
mirror: Dict[str, Any],
) -> Dict[str, Any]:
"""从指定镜像源克隆仓库"""
# 构建克隆 URL
clone_prefix = mirror["clone_prefix"]
url = f"{clone_prefix}/{owner}/{repo}.git"
return await self._clone_with_url(url, target_path, branch, depth, mirror["id"])
async def _clone_with_url(
self, url: str, target_path: Path, branch: Optional[str], depth: Optional[int], mirror_type: str
) -> Dict[str, Any]:
"""使用指定 URL 克隆仓库,支持重试"""
attempts = 0
last_error = None
for attempt in range(self.max_retries):
attempts += 1
try:
# 确保目标路径不存在
if target_path.exists():
logger.warning(f"目标路径已存在,删除: {target_path}")
shutil.rmtree(target_path, ignore_errors=True)
# 构建 git clone 命令
cmd = ["git", "clone"]
# 添加分支参数
if branch:
cmd.extend(["-b", branch])
# 添加深度参数(浅克隆)
if depth:
cmd.extend(["--depth", str(depth)])
# 添加 URL 和目标路径
cmd.extend([url, str(target_path)])
logger.info(f"尝试克隆 #{attempt + 1}: {' '.join(cmd)}")
# 推送进度
if _update_progress:
try:
await _update_progress(
stage="loading",
progress=20 + attempt * 10,
message=f"正在克隆仓库 (尝试 {attempt + 1}/{self.max_retries})...",
operation="install",
)
except Exception as e:
logger.warning(f"推送进度失败: {e}")
# 执行 git clone在线程池中运行以避免阻塞
loop = asyncio.get_event_loop()
def run_git_clone():
return subprocess.run(
cmd,
capture_output=True,
text=True,
timeout=300, # 5分钟超时
)
process = await loop.run_in_executor(None, run_git_clone)
if process.returncode == 0:
logger.info(f"成功克隆仓库: {url} -> {target_path}")
return {
"success": True,
"path": str(target_path),
"mirror_used": mirror_type,
"attempts": attempts,
"url": url,
"branch": branch or "default",
}
else:
last_error = f"Git 克隆失败: {process.stderr}"
logger.warning(f"克隆失败 (尝试 {attempt + 1}/{self.max_retries}): {last_error}")
except subprocess.TimeoutExpired:
last_error = "克隆超时(超过 5 分钟)"
logger.warning(f"克隆超时 (尝试 {attempt + 1}/{self.max_retries})")
# 清理可能的部分克隆
if target_path.exists():
shutil.rmtree(target_path, ignore_errors=True)
except FileNotFoundError:
last_error = "Git 未安装或不在 PATH 中"
logger.error(f"Git 未找到: {last_error}")
break # Git 不存在,不需要重试
except Exception as e:
last_error = f"未知错误: {e}"
logger.error(f"克隆错误 (尝试 {attempt + 1}/{self.max_retries}): {last_error}")
# 清理可能的部分克隆
if target_path.exists():
shutil.rmtree(target_path, ignore_errors=True)
return {"success": False, "error": last_error, "mirror_used": mirror_type, "attempts": attempts, "url": url}
# 全局服务实例
_git_mirror_service: Optional[GitMirrorService] = None
def get_git_mirror_service() -> GitMirrorService:
"""获取 Git 镜像源服务实例(单例)"""
global _git_mirror_service
if _git_mirror_service is None:
_git_mirror_service = GitMirrorService()
return _git_mirror_service

View File

0
src/webui/logs_routes.py Normal file
View File

141
src/webui/logs_ws.py Normal file
View File

@@ -0,0 +1,141 @@
"""WebSocket 日志推送模块"""
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from typing import Set
import json
from pathlib import Path
from src.common.logger import get_logger
logger = get_logger("webui.logs_ws")
router = APIRouter()
# 全局 WebSocket 连接池
active_connections: Set[WebSocket] = set()
def load_recent_logs(limit: int = 100) -> list[dict]:
"""从日志文件中加载最近的日志
Args:
limit: 返回的最大日志条数
Returns:
日志列表
"""
logs = []
log_dir = Path("logs")
if not log_dir.exists():
return logs
# 获取所有日志文件,按修改时间排序
log_files = sorted(log_dir.glob("app_*.log.jsonl"), key=lambda f: f.stat().st_mtime, reverse=True)
# 用于生成唯一 ID 的计数器
log_counter = 0
# 从最新的文件开始读取
for log_file in log_files:
if len(logs) >= limit:
break
try:
with open(log_file, "r", encoding="utf-8") as f:
lines = f.readlines()
# 从文件末尾开始读取
for line in reversed(lines):
if len(logs) >= limit:
break
try:
log_entry = json.loads(line.strip())
# 转换为前端期望的格式
# 使用时间戳 + 计数器生成唯一 ID
timestamp_id = (
log_entry.get("timestamp", "0").replace("-", "").replace(" ", "").replace(":", "")
)
formatted_log = {
"id": f"{timestamp_id}_{log_counter}",
"timestamp": log_entry.get("timestamp", ""),
"level": log_entry.get("level", "INFO").upper(),
"module": log_entry.get("logger_name", ""),
"message": log_entry.get("event", ""),
}
logs.append(formatted_log)
log_counter += 1
except (json.JSONDecodeError, KeyError):
continue
except Exception as e:
logger.error(f"读取日志文件失败 {log_file}: {e}")
continue
# 反转列表,使其按时间顺序排列(旧到新)
return list(reversed(logs))
@router.websocket("/ws/logs")
async def websocket_logs(websocket: WebSocket):
"""WebSocket 日志推送端点
客户端连接后会持续接收服务器端的日志消息
"""
await websocket.accept()
active_connections.add(websocket)
logger.info(f"📡 WebSocket 客户端已连接,当前连接数: {len(active_connections)}")
# 连接建立后,立即发送历史日志
try:
recent_logs = load_recent_logs(limit=100)
logger.info(f"发送 {len(recent_logs)} 条历史日志到客户端")
for log_entry in recent_logs:
await websocket.send_text(json.dumps(log_entry, ensure_ascii=False))
except Exception as e:
logger.error(f"发送历史日志失败: {e}")
try:
# 保持连接,等待客户端消息或断开
while True:
# 接收客户端消息(用于心跳或控制指令)
data = await websocket.receive_text()
# 可以处理客户端的控制消息,例如:
# - "ping" -> 心跳检测
# - {"filter": "ERROR"} -> 设置日志级别过滤
if data == "ping":
await websocket.send_text("pong")
except WebSocketDisconnect:
active_connections.discard(websocket)
logger.info(f"📡 WebSocket 客户端已断开,当前连接数: {len(active_connections)}")
except Exception as e:
logger.error(f"❌ WebSocket 错误: {e}")
active_connections.discard(websocket)
async def broadcast_log(log_data: dict):
"""广播日志到所有连接的 WebSocket 客户端
Args:
log_data: 日志数据字典
"""
if not active_connections:
return
# 格式化为 JSON
message = json.dumps(log_data, ensure_ascii=False)
# 记录需要断开的连接
disconnected = set()
# 广播到所有客户端
for connection in active_connections:
try:
await connection.send_text(message)
except Exception:
# 发送失败,标记为断开
disconnected.add(connection)
# 清理断开的连接
if disconnected:
active_connections.difference_update(disconnected)
logger.debug(f"清理了 {len(disconnected)} 个断开的 WebSocket 连接")

View File

@@ -1,93 +0,0 @@
"""WebUI 管理器 - 处理开发/生产环境的 WebUI 启动"""
import os
from pathlib import Path
from src.common.logger import get_logger
from .token_manager import get_token_manager
logger = get_logger("webui")
def setup_webui(mode: str = "production") -> bool:
"""
设置 WebUI
Args:
mode: 运行模式,"development""production"
Returns:
bool: 是否成功设置
"""
# 初始化 Token 管理器(确保 token 文件存在)
token_manager = get_token_manager()
current_token = token_manager.get_token()
logger.info(f"🔑 WebUI Access Token: {current_token}")
logger.info("💡 请使用此 Token 登录 WebUI")
if mode == "development":
return setup_dev_mode()
else:
return setup_production_mode()
def setup_dev_mode() -> bool:
"""设置开发模式 - 仅启用 CORS前端自行启动"""
logger.info("📝 WebUI 开发模式已启用")
logger.info("🌐 请手动启动前端开发服务器: cd webui && npm run dev")
logger.info("💡 前端将运行在 http://localhost:7999")
return True
def setup_production_mode() -> bool:
"""设置生产模式 - 挂载静态文件"""
try:
from src.common.server import get_global_server
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
server = get_global_server()
base_dir = Path(__file__).parent.parent.parent
static_path = base_dir / "webui" / "dist"
if not static_path.exists():
logger.warning(f"❌ WebUI 静态文件目录不存在: {static_path}")
logger.warning("💡 请先构建前端: cd webui && npm run build")
return False
if not (static_path / "index.html").exists():
logger.warning(f"❌ 未找到 index.html: {static_path / 'index.html'}")
logger.warning("💡 请确认前端已正确构建")
return False
# 挂载静态资源
if (static_path / "assets").exists():
server.app.mount(
"/assets",
StaticFiles(directory=str(static_path / "assets")),
name="assets"
)
# 处理 SPA 路由
@server.app.get("/{full_path:path}")
async def serve_spa(full_path: str):
"""服务单页应用"""
# API 路由不处理
if full_path.startswith("api/"):
return None
# 检查文件是否存在
file_path = static_path / full_path
if file_path.is_file():
return FileResponse(file_path)
# 返回 index.htmlSPA 路由)
return FileResponse(static_path / "index.html")
host = os.getenv("HOST", "127.0.0.1")
port = os.getenv("PORT", "8000")
logger.info("✅ WebUI 生产模式已挂载")
logger.info(f"🌐 访问 http://{host}:{port} 查看 WebUI")
return True
except Exception as e:
logger.error(f"挂载 WebUI 静态文件失败: {e}")
return False

Some files were not shown because too many files have changed in this diff Show More