Merge pull request #1380 from zhangxinhui02/helm-chart
Helm chart `0.11.3-beta` and `0.11.5-beta` Release
.github/workflows/docker-image.yml (vendored, 5 changes)
@@ -1,11 +1,12 @@
name: Docker Build and Push

on:
  schedule:
    - cron: '0 0 * * *'
  push:
    branches:
      - main
      - classical
      - dev
    tags:
      - "v*.*.*"
      - "v*"
@@ -24,6 +25,7 @@ jobs:
      - name: Check out git repository
        uses: actions/checkout@v4
        with:
          ref: ${{ github.event_name == 'schedule' && 'dev' || github.ref }}
          fetch-depth: 0

      # Clone required dependencies
@@ -77,6 +79,7 @@ jobs:
      - name: Check out git repository
        uses: actions/checkout@v4
        with:
          ref: ${{ github.event_name == 'schedule' && 'dev' || github.ref }}
          fetch-depth: 0

      # Clone required dependencies
.gitignore (vendored, 4 changes)
@@ -11,7 +11,6 @@ run_maibot_core.bat
run_voice.bat
run_napcat_adapter.bat
run_ad.bat
s4u.s4u
llm_tool_benchmark_results.json
MaiBot-Napcat-Adapter-main
MaiBot-Napcat-Adapter
@@ -27,6 +26,7 @@ run.bat
log_debug/
run_amds.bat
run_none.bat
docs-mai/
run.py
message_queue_content.txt
message_queue_content.bat
@@ -51,6 +51,7 @@ template/compare/model_config_template.toml
src/plugins/utils/statistic.py
CLAUDE.md
MaiBot-Dashboard/
cloudflare-workers/

# Byte-compiled / optimized / DLL files
__pycache__/
@@ -69,7 +70,6 @@ elua.confirmed
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
@@ -25,6 +25,8 @@ WORKDIR /MaiMBot
# Copy the dependency list
COPY requirements.txt .

RUN apt-get update && apt-get install -y git

# Copy the LPMM build artifacts from the build stage
COPY --from=lpmm-builder /usr/local/lib/python3.13/site-packages/ /usr/local/lib/python3.13/site-packages/
README.md (47 changes)
@@ -25,11 +25,12 @@

**🍔 MaiCore is an interactive agent built on large language models**

- 💭 **Intelligent dialogue system**: LLM-based natural-language interaction with chat-timing control.
- 🤔 **Real-time thinking system**: simulates a human thought process.
- 🧠 **Expression learning**: learns group members' speaking styles and expressions.
- 💝 **Emotional expression system**: mood system and sticker system.
- 🔌 **Powerful plugin system**: provides an API and event system for writing powerful plugins.
- 💭 **Human-like prompts**: the replyer's prompt is built in a natural-language style, producing replies close to human speech habits.
- 💭 **Behavior planning**: speaks at the right moment, using the right actions.
- 🧠 **Expression learning**: learns group members' speaking styles and expressions, picking up how real humans talk.
- 🤔 **Slang learning**: autonomously learns unseen words and tries to understand their meaning.
- 🔌 **Plugin system**: provides an API and event system for writing rich plugins.
- 💝 **Emotional expression**: mood system and sticker system.

<div style="text-align: center">
  <a href="https://www.bilibili.com/video/BV1amAneGE3P" target="_blank">
@@ -44,29 +45,28 @@

## 🔥 Updates and installation

**Latest version: v0.11.0** ([changelog](changelogs/changelog.md))

**Latest version: v0.11.5** ([changelog](changelogs/changelog.md))


The latest version can be downloaded from the [Release](https://github.com/MaiM-with-u/MaiBot/releases/) page.
The latest launcher can be downloaded from the [launcher release page](https://github.com/MaiM-with-u/mailauncher/releases/).
**GitHub branches:**
- `main`: stable release (recommended)


- `dev`: development build (unstable)
- `classical`: legacy version (no longer maintained)
- `classical`: classic version (no longer maintained)

### Deployment guide for the latest version
- [🚀 Deployment guide for the latest version](https://docs.mai-mai.org/manual/deployment/mmc_deploy_windows.html) - the MaiCore-based deployment method (incompatible with older versions)

> [!WARNING]
> - The project is under active development; features and APIs may change at any time.
> - If you run into problems, open an Issue or a Discussion.
> - If you run into problems, open an Issue.
> - QQ bots run the risk of being restricted; please inform yourself and use with caution.
> - As the program is still in development, it may consume a fair amount of tokens.

## MaiCraft, the MaiBot Minecraft project (early development)
[Let MaiBot play Minecraft](https://github.com/MaiM-with-u/Maicraft)

Chat group: 1058573197

## 💬 Discussion

**Technical discussion groups:**
@@ -78,7 +78,7 @@
**Casual chat group:**
- [MaiBot casual chat group](https://qm.qq.com/q/JxvHZnxyec)

**Plugin development & beta group:**
**Plugin development / beta discussion group:**
- [Plugin development group](https://qm.qq.com/q/1036092828)

## 📚 Documentation
@@ -87,7 +87,22 @@

- [📚 Core wiki documentation](https://docs.mai-mai.org) - the project's most complete documentation hub, covering everything about MaiBot.

### Design philosophy (sparks from the primitive age)

## 📚 Derived projects

### MaiCraft (early development)
[MaiCraft](https://github.com/MaiM-with-u/Maicraft)
> A project that gives MaiBot the ability to play Minecraft
> Chat group: 1058573197

### MoFox_Bot
[MoFox - repository](https://github.com/MoFox-Studio/MoFox-Core)
> MoFox_Bot is an enhanced fork based on MaiCore 0.10.0 snapshot.5
> We keep almost all of the original project's core features and build on them with deep optimizations and feature extensions, aiming for a more stable, smarter, and more entertaining AI agent.



## Design philosophy (sparks from the primitive age)

> **千石可乐 says:**
> - This project started out as a few extra features for the NiuNiu bot, but the feature list kept growing until a rewrite became the only option. The goal is to create a "life form" that lives in QQ group chats - not a feature-complete bot, but a presence that feels as genuinely human as possible.
@@ -99,7 +114,7 @@

## 🙋 Contributing and acknowledgements
You can read the [development docs](https://docs.mai-mai.org/develop/) to get to know MaiBot better!
MaiCore is an open-source project and we warmly welcome your participation. Your contributions - bug reports, feature requests, or code PRs - are all valuable to the project. Thank you for your support! 🎉
Unstructured discussion lowers communication efficiency and slows down problem solving, so before submitting any contribution, please read the project's [contribution guide](docs/CONTRIBUTE.md). (to be completed)
Unstructured discussion lowers communication efficiency and slows down problem solving, so before submitting any contribution, please read the project's [contribution guide](docs-src/CONTRIBUTE.md). (to be completed)

### Contributors
bot.py (30 changes)
@@ -1,7 +1,6 @@
import asyncio
import hashlib
import os
import sys
import time
import platform
import traceback
@@ -30,7 +29,7 @@ else:
    raise

# Initialize the logging system as early as possible so all later modules use the correct log format
from src.common.logger import initialize_logging, get_logger, shutdown_logging #noqa
from src.common.logger import initialize_logging, get_logger, shutdown_logging  # noqa

initialize_logging()

@@ -76,6 +75,15 @@ async def graceful_shutdown():  # sourcery skip: use-named-expression
    try:
        logger.info("Gracefully shutting down MaiBot...")

        # Shut down the WebUI server
        try:
            from src.webui.webui_server import get_webui_server
            webui_server = get_webui_server()
            if webui_server and webui_server._server:
                await webui_server.shutdown()
        except Exception as e:
            logger.warning(f"Error while shutting down the WebUI server: {e}")

        from src.plugin_system.core.events_manager import events_manager
        from src.plugin_system.base.component_types import EventType

@@ -107,9 +115,6 @@ async def graceful_shutdown():  # sourcery skip: use-named-expression

        logger.info("MaiBot graceful shutdown complete")

        # Shut down the logging system and release file handles
        shutdown_logging()

    except Exception as e:
        logger.error(f"MaiBot shutdown failed: {e}", exc_info=True)

@@ -216,6 +221,11 @@ if __name__ == "__main__":
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)

    # Initialize WebSocket log streaming
    from src.common.logger import initialize_ws_handler

    initialize_ws_handler(loop)

    try:
        # Run initialization and task scheduling
        loop.run_until_complete(main_system.initialize())
@@ -241,7 +251,7 @@ if __name__ == "__main__":
    # Make sure the loop is closed in every case (if it exists and is not already closed)
    if "loop" in locals() and loop and not loop.is_closed():
        loop.close()
        logger.info("Event loop closed")
        print("[Main] Event loop closed")

    # Shut down the logging system and release file handles
    try:
@@ -249,6 +259,8 @@ if __name__ == "__main__":
    except Exception as e:
        print(f"Error while shutting down the logging system: {e}")

    # Pause before exiting so the output stays visible
    # input("Press Enter to exit...")  # <--- add this line
    sys.exit(exit_code)  # <--- use the recorded exit code
    print("[Main] Preparing to exit...")

    # Force exit with os._exit() to avoid being blocked
    # All cleanup was already done in graceful_shutdown(), so this is safe
    os._exit(exit_code)
@@ -1,6 +1,99 @@
# Changelog
## [0.11.5] - 2025-11-21
### 🌟 Major updates
- The WebUI now supports manually restarting MaiBot - a roundabout "hot reload"
- New visual editing UI for the MaiBot QQ adapter (separate process; adapter files must be uploaded/downloaded and overwritten manually)
- The main MaiBot config now supports both a visual editing mode and a source-code mode, with TOML validation on the backend
- Improved coordination between the planner and the replyer, with more detailed debug logs

## [0.11.2] - 2025-11-15
### Added
- Mobile-friendly layouts for the sticker, persona, and expression management pages
- A "restart MaiBot" hint on the config page
- A config option for detailed debug prompt display
- Theme-colored action buttons in the MaiBot UI
- CodeMirror integrated on the frontend (Python/JSON/TOML syntax highlighting) with auto-correction hints for JSON configs

### Fixed
- Sticker thumbnails were too small
- Newly added models did not show up immediately
- Wrong categories in the plugin marketplace
- Broken log-viewer background in light mode
- Sticker detail window could not be closed
- Mood tags were not read correctly (issues/1373)
- Model temperature could not be changed on the model task config page (issues/1369)
- Default values in some inputs on the model task config page could not be deleted
- "Show compatible plugins only" was checked by default in the plugin store
- Wrong tab counts in the plugin store
- Broken line wrapping in the log viewer
- Theme color was not applied correctly
- Sidebar scrollbar issues
- Removed a frontend TOML parsing library that caused compatibility problems

### Improved
- Home page loading experience
- Home page load time (9 s → <1 s)
- First-paint load time across the whole UI

### Updated
- Adapter config now supports an upload mode and a path mode; the path mode avoids repeatedly uploading/downloading the config file
- Responsive tabs on the adapter config page; small screens show short labels only
- Bulk delete across all MaiBot resource-management pages
- Pagination, page-size selection, and page jumping across all resource-management pages
- Likes, dislikes, ratings, and download counts in the plugin marketplace (built on Cloudflare, so it stays reachable from mainland China)
- Markdown rendering for the description section of the sticker viewer
## [0.11.4] - 2025-11-19
### 🌟 Highlights
- **First official web management UI**: before this release MaiBot had no WebUI, and all configuration was done by hand-editing TOML files
- **Authentication**: secure token login (system-generated 64-character random tokens or custom tokens), first-run setup wizard
- **Config management (visual editing, no manual TOML edits)**:
  - Main MaiBot config: basic settings, persona, stickers, slang, mood, etc.
  - Model provider config: OpenAI, Anthropic, DeepSeek, Qwen, Ollama, etc.
  - Model config: chat/vision/embedding model assignment
- **Resource management**:
  - Sticker management: view, search, register, ban
  - Expression management: browse MaiBot's expression records
  - Persona information management: browse the contact list
- **Plugin system**:
  - Browse the plugin marketplace
  - One-click install/uninstall/update
  - Version compatibility checks
  - Real-time install progress updates
- **Log viewer**:
  - Real-time log streaming over WebSocket
  - Log-level filtering (DEBUG/INFO/WARNING/ERROR/CRITICAL)
  - Search
- **Theming**:
  - Light/dark/follow system
  - 12 theme colors (6 solid + 6 gradient)
  - Custom color picker
- **Global search**: Cmd/Ctrl + K shortcut to jump to any page

### Details
- **Tech stack**:
  - Frontend: React 19 + TypeScript + Vite + TanStack Router + shadcn/ui
  - Backend: FastAPI + Uvicorn + WebSocket
  - Notes: single-page app, frontend and backend on the same port, static file hosting
- **How to enable**: update your .env file following template.env and add two fields:
  - `WEBUI_ENABLED=true`
  - `WEBUI_MODE=production`
- **WebUI license**: GPLv3
- **WebUI repository**: https://github.com/Mai-with-u/MaiBot-Dashboard

Say goodbye to hand-editing config files and enjoy a modern graphical interface!

## [0.11.3] - 2025-11-18
### Changes and fixes
- Improved memory extraction strategy
- Improved slang extraction
- Improved expression learning
- Updated the README
- Added the beta WebUI

Tip: clearing old memory data and expressions improves behavior.
How: delete all contents of expression, jargon, and thinking_back in the database.

## [0.11.2] - 2025-11-16
### 🌟 Major changes
- The "hippocampus agent" memory system is live - the newest and best memory system, wired into LPMM by default
- Added the jargon (slang) learning system
@@ -29,7 +29,8 @@ services:
      - TZ=Asia/Shanghai
      # - EULA_AGREE=99f08e0cab0190de853cb6af7d64d4de  # Agree to the EULA
      # - PRIVACY_AGREE=9943b855e72199d0f5016ea39052f1b6  # Agree to the privacy policy
    # ports:
    ports:
      - "18001:8001"  # WebUI port
      # - "8000:8000"
    volumes:
      - ./docker-config/mmc/.env:/MaiMBot/.env  # Persist the .env config file
Two images changed (sizes unchanged: 4.1 KiB and 11 KiB).
docs/image-1.png: binary file removed (was 21 KiB).
docs/image.png: binary file removed (was 4.9 KiB).
@@ -1,336 +0,0 @@
|
||||
# 模型配置指南
|
||||
|
||||
本文档将指导您如何配置 `model_config.toml` 文件,该文件用于配置 MaiBot 的各种AI模型和API服务提供商。
|
||||
|
||||
## 配置文件结构
|
||||
|
||||
配置文件主要包含以下几个部分:
|
||||
- 版本信息
|
||||
- API服务提供商配置
|
||||
- 模型配置
|
||||
- 模型任务配置
|
||||
|
||||
## 1. Version information

```toml
[inner]
version = "1.1.1"
```

Identifies the configuration file version, following semantic versioning.

## 2. API service provider configuration

### 2.1 Basic configuration

Use the `[[api_providers]]` array to configure multiple API service providers:

```toml
[[api_providers]]
name = "DeepSeek"                          # Provider name (your choice)
base_url = "https://api.deepseek.com/v1"   # Base URL of the API service
api_key = "your-api-key-here"              # API key
client_type = "openai"                     # Client type
max_retry = 2                              # Maximum number of retries
timeout = 30                               # Timeout (seconds)
retry_interval = 10                        # Retry interval (seconds)
```

### 2.2 Parameter reference

| Parameter | Required | Description | Default |
|------|------|------|--------|
| `name` | ✅ | Provider name, referenced from the model configuration | - |
| `base_url` | ✅ | Base URL of the API service | - |
| `api_key` | ✅ | API key; replace with your real key | - |
| `client_type` | ❌ | Client type: `openai` (OpenAI format) or `gemini` (Gemini format) | `openai` |
| `max_retry` | ❌ | Maximum number of retries after a failed API call | 2 |
| `timeout` | ❌ | API request timeout (seconds) | 30 |
| `retry_interval` | ❌ | Interval between retries (seconds) | 10 |

**Note: for models whose `client_type` is `gemini`, the `retry` field is handled by `gemini` itself.**
### 2.3 Supported provider examples

#### DeepSeek
```toml
[[api_providers]]
name = "DeepSeek"
base_url = "https://api.deepseek.com/v1"
api_key = "your-deepseek-api-key"
client_type = "openai"
```

#### SiliconFlow
```toml
[[api_providers]]
name = "SiliconFlow"
base_url = "https://api.siliconflow.cn/v1"
api_key = "your-siliconflow-api-key"
client_type = "openai"
```

#### Google Gemini
```toml
[[api_providers]]
name = "Google"
base_url = "https://generativelanguage.googleapis.com/v1beta"
api_key = "your-google-api-key"
client_type = "gemini"  # Note: Gemini requires its own client
```
## 3. Model configuration

### 3.1 Basic model configuration

Use the `[[models]]` array to configure multiple models:

```toml
[[models]]
model_identifier = "deepseek-chat"   # The model's identifier at the API provider
name = "deepseek-v3"                 # Custom model name
api_provider = "DeepSeek"            # Name of the API provider to use
price_in = 2.0                       # Input price (CNY per million tokens)
price_out = 8.0                      # Output price (CNY per million tokens)
```

### 3.2 Advanced model configuration

#### Forced streaming output
For models that do not support non-streaming output:
```toml
[[models]]
model_identifier = "some-model"
name = "custom-name"
api_provider = "Provider"
force_stream_mode = true  # Enable forced streaming output
```

#### Extra parameters via `extra_params`
```toml
[[models]]
model_identifier = "Qwen/Qwen3-8B"
name = "qwen3-8b"
api_provider = "SiliconFlow"
[models.extra_params]
enable_thinking = false  # Disable thinking
```
`extra_params` may contain any extra parameters supported by the API provider; **consult the provider's API documentation when configuring it**.

The example above follows SiliconFlow's documentation to disable thinking for `Qwen3`.

![alt text](image.png)

Taking the Doubao documentation as another example:

![alt text](image-1.png)

the configuration to disable thinking for Doubao's `"doubao-seed-1-6-250615"` is:
```toml
[[models]]
# your model
[models.extra_params]
thinking = {type = "disabled"}  # Disable thinking
```

`gemini` needs its own configuration:
```toml
[[models]]
model_identifier = "gemini-2.5-flash"
name = "gemini-2.5-flash"
api_provider = "Google"
[models.extra_params]
thinking_budget = 0  # Disable thinking
# thinking_budget = -1 lets the model decide
```

Note that `extra_params` must form a valid TOML table; its contents depend on the API provider's requirements.
### 3.3 Parameter reference

| Parameter | Required | Description |
|------|------|------|
| `model_identifier` | ✅ | The model identifier provided by the API provider |
| `name` | ✅ | Custom model name, referenced from the task configuration |
| `api_provider` | ✅ | Name of the corresponding API provider |
| `price_in` | ❌ | Input price (CNY per million tokens), used for cost statistics |
| `price_out` | ❌ | Output price (CNY per million tokens), used for cost statistics |
| `force_stream_mode` | ❌ | Whether to force streaming output |
| `extra_params` | ❌ | Extra model parameter configuration |
## 4. Model task configuration

### utils - utility model
Used by core features such as the sticker module, naming module, and relationship module:
```toml
[model_task_config.utils]
model_list = ["siliconflow-deepseek-v3"]
temperature = 0.2
max_tokens = 800
```

### utils_small - small utility model
For high-frequency calls; a fast, small model is recommended:
```toml
[model_task_config.utils_small]
model_list = ["qwen3-8b"]
temperature = 0.7
max_tokens = 800
```

### replyer - primary reply model
The primary reply model, also used by the expresser and expression learning:
```toml
[model_task_config.replyer]
model_list = ["siliconflow-deepseek-v3"]
temperature = 0.2
max_tokens = 800
```

### planner - decision model
Decides what MaiBot should do:
```toml
[model_task_config.planner]
model_list = ["siliconflow-deepseek-v3"]
temperature = 0.3
max_tokens = 800
```

### emotion - mood model
Drives MaiBot's mood changes:
```toml
[model_task_config.emotion]
model_list = ["siliconflow-deepseek-v3"]
temperature = 0.3
max_tokens = 800
```

### memory - memory model
```toml
[model_task_config.memory]
model_list = ["qwen3-30b"]
temperature = 0.7
max_tokens = 800
```

### vlm - vision-language model
Used for image recognition:
```toml
[model_task_config.vlm]
model_list = ["qwen2.5-vl-72b"]
max_tokens = 800
```

### voice - speech recognition model
```toml
[model_task_config.voice]
model_list = ["sensevoice-small"]
```

### embedding - embedding model
```toml
[model_task_config.embedding]
model_list = ["bge-m3"]
```

### tool_use - tool-calling model
Requires a model that supports tool calling:
```toml
[model_task_config.tool_use]
model_list = ["qwen3-14b"]
temperature = 0.7
max_tokens = 800
```

### lpmm_entity_extract - entity extraction model
```toml
[model_task_config.lpmm_entity_extract]
model_list = ["siliconflow-deepseek-v3"]
temperature = 0.2
max_tokens = 800
```

### lpmm_rdf_build - RDF construction model
```toml
[model_task_config.lpmm_rdf_build]
model_list = ["siliconflow-deepseek-v3"]
temperature = 0.2
max_tokens = 800
```

### lpmm_qa - question-answering model
```toml
[model_task_config.lpmm_qa]
model_list = ["deepseek-r1-distill-qwen-32b"]
temperature = 0.7
max_tokens = 800
```
## 5. Configuration tips

### 5.1 Choosing the temperature

| Task type | Recommended temperature | Notes |
|----------|----------|------|
| Precision tasks (tool calling, entity extraction) | 0.1-0.3 | Need accuracy and consistency |
| Creative tasks (dialogue, memory) | 0.5-0.8 | Need diversity and creativity |
| Balanced tasks (planning, mood) | 0.3-0.5 | Balance accuracy and flexibility |

### 5.2 Choosing a model

| Task type | Recommended model class | Examples |
|----------|--------------|------|
| High-precision tasks | Large models | DeepSeek-V3, GPT-4 |
| High-frequency tasks | Small models | Qwen3-8B |
| Multimodal tasks | Specialized models | Qwen2.5-VL, SenseVoice |
| Tool calling | Models with function-call support | Qwen3-14B |

### 5.3 Cost optimization

1. **Tier your models**: use high-quality models for core features and economical models for auxiliary ones
2. **Set max_tokens sensibly**: size it to actual needs to avoid waste
3. **Prefer free models**: in test environments, prefer models whose price is 0
## 6. Configuration validation

### 6.1 Checklist

1. ✅ Are the API keys configured correctly?
2. ✅ Do the model identifiers match what the API provider publishes?
3. ✅ Is every model name referenced in the task configuration defined under `[[models]]`?
4. ✅ Are specialized models configured for multimodal tasks?
### 6.2 Testing the configuration

Before going live:
1. Validate the configuration with a small amount of test data
2. Check that API calls succeed
3. Confirm that cost statistics work

## 7. Troubleshooting

### 7.1 Common problems

**Problem 1**: API calls fail
- Check that the API key is correct
- Confirm that base_url is reachable
- Check that the model identifier is correct

**Problem 2**: Model not found
- Confirm that the model name is identical in the task configuration and the model definition
- Check that the api_provider name matches

**Problem 3**: Abnormal responses
- Check that the temperature is reasonable (between 0 and 1)
- Confirm that max_tokens is appropriate
- Verify that the model supports the required features

### 7.2 Checking logs

Look for errors in the log files under the `logs/` directory.

## 8. Updates and maintenance

1. **Update regularly**: follow the API providers' model updates and adjust the configuration in time
2. **Monitor performance**: track the cost and performance of model calls
3. **Back up the configuration**: back up the current configuration file before changing it
@@ -1,7 +1,6 @@
stages:
  # - initialize-maibot-git-repo
  - build
  - package
  - build-image
  - package-helm-chart

# Run only on the helm-chart branch
workflow:
@@ -9,49 +8,15 @@ workflow:
    - if: '$CI_COMMIT_BRANCH == "helm-chart"'
    - when: never

## Check out the MaiBot repository and reset the working tree to the last tag
#initialize-maibot-git-repo:
#  stage: initialize-maibot-git-repo
#  image: reg.mikumikumi.xyz/base/git:latest
#  cache:
#    key: git-repo
#    policy: push
#    paths:
#      - target-repo/
#  script:
#    - git clone https://github.com/Mai-with-u/MaiBot.git target-repo/
#    - cd target-repo/
#    - export MAIBOT_VERSION=$(git describe --tags --abbrev=0)
#    - echo "Current version is ${MAIBOT_VERSION}"
#    - git reset --hard ${MAIBOT_VERSION}
#    - echo ${MAIBOT_VERSION} > MAIBOT_VERSION
#    - git clone https://github.com/MaiM-with-u/maim_message maim_message
#    - git clone https://github.com/MaiM-with-u/MaiMBot-LPMM.git MaiMBot-LPMM
#    - ls -al
#
## Build the image of the MaiBot core at the last tag
#build-core:
#  stage: build
#  image: reg.mikumikumi.xyz/base/kaniko-builder:latest
#  cache:
#    key: git-repo
#    policy: pull
#    paths:
#      - target-repo/
#  script:
#    - cd target-repo/
#    - export BUILD_CONTEXT=$(pwd)
#    - ls -al
#    - export BUILD_DESTINATION="reg.mikumikumi.xyz/maibot/maibot:tag-$(cat MAIBOT_VERSION)"
#    - build
# Build and push images tagged with the Helm chart version
# Build and push the adapter-cm-generator image
build-adapter-cm-generator:
  stage: build
  stage: build-image
  image: reg.mikumikumi.xyz/base/kaniko-builder:latest
  # rules:
  #   - changes:
  #       - helm-chart/adapter-cm-generator/**
  variables:
    BUILD_NO_CACHE: true
  rules:
    - changes:
        - helm-chart/adapter-cm-generator/**
  script:
    - export BUILD_CONTEXT=helm-chart/adapter-cm-generator
    - export TMP_DST=reg.mikumikumi.xyz/maibot/adapter-cm-generator
@@ -60,16 +25,36 @@ build-adapter-cm-generator:
    - export BUILD_ARGS="--destination ${TMP_DST}:latest"
    - build

# Build and push the core-webui-cm-sync image
build-core-webui-cm-sync:
  stage: build-image
  image: reg.mikumikumi.xyz/base/kaniko-builder:latest
  variables:
    BUILD_NO_CACHE: true
  rules:
    - changes:
        - helm-chart/core-webui-cm-sync/**
  script:
    - export BUILD_CONTEXT=helm-chart/core-webui-cm-sync
    - export TMP_DST=reg.mikumikumi.xyz/maibot/core-webui-cm-sync
    - export CHART_VERSION=$(cat helm-chart/Chart.yaml | grep '^version:' | cut -d' ' -f2)
    - export BUILD_DESTINATION="${TMP_DST}:${CHART_VERSION}"
    - export BUILD_ARGS="--destination ${TMP_DST}:latest"
    - build

# Package and push the Helm chart
package-helm-chart:
  stage: package
  stage: package-helm-chart
  image: reg.mikumikumi.xyz/mirror/helm:latest
  # rules:
  #   - changes:
  #       - helm-chart/files/**
  #       - helm-chart/templates/**
  #       - helm-chart/Chart.yaml
  #       - helm-chart/values.yaml
  rules:
    - changes:
        - helm-chart/files/**
        - helm-chart/templates/**
        - helm-chart/.gitignore
        - helm-chart/.helmignore
        - helm-chart/Chart.yaml
        - helm-chart/README.md
        - helm-chart/values.yaml
  script:
    - export CHART_VERSION=$(cat helm-chart/Chart.yaml | grep '^version:' | cut -d' ' -f2)
    - helm registry login reg.mikumikumi.xyz --username ${CI_REGISTRY_USER} --password ${CI_REGISTRY_PASSWORD}
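The `CHART_VERSION` extraction used twice in the CI script above can be tried standalone; the Chart.yaml content written here is a placeholder for demonstration, not the real chart:

```shell
# Write a placeholder Chart.yaml and extract its version the same way the CI script does.
printf 'apiVersion: v2\nname: maibot\nversion: 0.11.5-beta\n' > /tmp/Chart.yaml
CHART_VERSION=$(grep '^version:' /tmp/Chart.yaml | cut -d' ' -f2)
echo "$CHART_VERSION"   # 0.11.5-beta
```

The `^version:` anchor matters: without it, `appVersion: 0.11.5-beta` in the real Chart.yaml would also match and the pipeline could pick the wrong line.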
@@ -1,2 +1,3 @@
adapter-cm-generator
core-webui-cm-sync
.gitlab-ci.yml
@@ -2,5 +2,5 @@ apiVersion: v2
name: maibot
description: "Maimai Bot, a cyber friend dedicated to group chats"
type: application
version: 0.11.2-beta
appVersion: 0.11.2-beta
version: 0.11.5-beta
appVersion: 0.11.5-beta
@@ -10,6 +10,8 @@

| Helm Chart version | Corresponding MaiBot version | Commit SHA |
|----------------|--------------|------------------------------------------|
| 0.11.5-beta | 0.11.5-beta | ad2df627001f18996802f23c405b263e78af0d0f |
| 0.11.3-beta | 0.11.3-beta | cd6dc18f546f81e08803d3b8dba48e504dad9295 |
| 0.11.2-beta | 0.11.2-beta | d3c8cea00dbb97f545350f2c3d5bcaf252443df2 |
| 0.11.1-beta | 0.11.1-beta | 94e079a340a43dff8a2bc178706932937fc10b11 |
| 0.11.0-beta | 0.11.0-beta | 16059532d8ef87ac28e2be0838ff8b3a34a91d0f |
@@ -5,6 +5,6 @@ WORKDIR /app

COPY . /app

RUN pip3 install --no-cache-dir -i https://mirrors.ustc.edu.cn/pypi/simple -r requirements.txt
RUN pip3 install --no-cache-dir -r requirements.txt

ENTRYPOINT ["python3", "adapter-cm-generator.py"]
@@ -29,7 +29,7 @@ data['maibot_server']['host'] = f'{release_name}-maibot-core'  # based on the release
data['maibot_server']['port'] = 8000

# Create/update the ConfigMap
cm_name = f'{release_name}-maibot-adapter'
cm_name = f'{release_name}-maibot-adapter-config'
cm = client.V1ConfigMap(
    metadata=client.V1ObjectMeta(name=cm_name),
    data={'config.toml': toml.dumps(data)}
helm-chart/core-webui-cm-sync/Dockerfile (new file, 10 lines)
@@ -0,0 +1,10 @@
# This image assists MaiBot's WebUI in updating config files; it runs alongside the core container
FROM python:3.13-slim

WORKDIR /MaiMBot

COPY . /MaiMBot

RUN pip3 install --no-cache-dir -r requirements.txt

ENTRYPOINT ["python3", "core-webui-cm-sync.py"]
helm-chart/core-webui-cm-sync/core-webui-cm-sync.py (new file, 92 lines)
@@ -0,0 +1,92 @@
#!/bin/python3
# This program assists MaiBot's WebUI in updating config files; it runs alongside the core container.
# MaiBot's config files are stored in a ConfigMap; once mounted into the core container they are read-only and cannot be modified directly.
# This program replaces the config files inside the core container with a writable intermediate layer of temp files.
# On startup it writes out the actual config files, then keeps watching them for changes and syncs them to the k8s apiserver in real time, patching the ConfigMap in the other direction.
# Working directory: /MaiMBot/webui-cm-sync

import os
import time
from datetime import datetime
from kubernetes import client, config
from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler

work_dir = '/MaiMBot/webui-cm-sync'
os.chdir(work_dir)

config.load_incluster_config()
core_api = client.CoreV1Api()
with open("/var/run/secrets/kubernetes.io/serviceaccount/namespace", "r") as f:
    namespace = f.read().strip()
release_name = os.getenv("RELEASE_NAME")
model_configmap_name = f'{release_name}-maibot-core-model-config'
bot_configmap_name = f'{release_name}-maibot-core-bot-config'

# Filter list: only these files are watched
target_files = {
    os.path.abspath("model_config.toml"): (model_configmap_name, "model_config.toml"),
    os.path.abspath("bot_config.toml"): (bot_configmap_name, "bot_config.toml")
}


def get_configmap(configmap_name: str):
    """Read the core ConfigMap's contents"""
    cm = core_api.read_namespaced_config_map(name=configmap_name, namespace=namespace)
    return cm.data


def set_configmap(configmap_name: str, configmap_data: dict[str, str]):
    """Write the core ConfigMap's contents"""
    core_api.patch_namespaced_config_map(configmap_name, namespace, {'data': configmap_data})


class ConfigObserverHandler(FileSystemEventHandler):
    """Event handler for config file changes"""
    def on_modified(self, event):
        if os.path.abspath(event.src_path) in target_files:
            print(f'[{datetime.now().strftime("%Y-%m-%d %H:%M:%S")}] File `{event.src_path}` was modified. '
                  f'Start to sync...')
            with open(event.src_path, "r", encoding="utf-8") as _f:
                current_data = _f.read()
            _path = str(os.path.abspath(event.src_path))
            new_cm = {
                target_files[_path][1]: current_data
            }
            try:
                set_configmap(target_files[_path][0], new_cm)
                print('\tSync done.')
            except client.exceptions.ApiException as _e:
                print(f'\tError while setting configmap:\n'
                      f'\t\tStatus Code: {_e.status}\n'
                      f'\t\tReason: {_e.reason}')


if __name__ == '__main__':
    # Initialize the config files
    print(f'[{datetime.now().strftime("%Y-%m-%d %H:%M:%S")}] Initializing config files...')
    try:
        __initial_model_config = get_configmap(model_configmap_name)['model_config.toml']
        __initial_bot_config = get_configmap(bot_configmap_name)['bot_config.toml']
    except client.exceptions.ApiException as e:
        print(f'\tError while getting configmap:\n'
              f'\t\tStatus Code: {e.status}\n'
              f'\t\tReason: {e.reason}')
        exit(1)
    with open('model_config.toml', 'w') as f:
        f.write(__initial_model_config)
    with open('bot_config.toml', 'w') as f:
        f.write(__initial_bot_config)
    with open('ready', 'w') as f:
        f.write('true')
    print(f'[{datetime.now().strftime("%Y-%m-%d %H:%M:%S")}] Initializing done. Ready to sync.')

    # Keep watching for changes and sync
    observer = Observer()
    observer.schedule(ConfigObserverHandler(), work_dir, recursive=False)
    observer.start()
    try:
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        observer.stop()
    observer.join()
helm-chart/core-webui-cm-sync/requirements.txt (new file, 3 lines)
@@ -0,0 +1,3 @@
toml~=0.10.2
kubernetes~=34.1.0
watchdog~=6.0.0
helm-chart/files/k8s-init.sh (new file, 56 lines)
@@ -0,0 +1,56 @@
#!/bin/sh
# This script overrides the core container's default start command to do some initialization.
# 1
# Because k8s and docker-compose mount volumes differently, this script pre-creates symlinks for some files and directories.
# /MaiMBot/data is the actual mount path for MaiBot's data
# /MaiMBot/statistics is the actual mount path for statistics data
# 2
# This script waits for the helper container webui-cm-sync to become ready before starting MaiBot,
# detected via the /MaiMBot/webui-cm-sync/ready file.

set -e
echo "[K8s Init] Preparing volume..."

# On first start, check for and create the key files and directories in the volume
mkdir -p /MaiMBot/data/plugins
mkdir -p /MaiMBot/data/logs
if [ ! -d "/MaiMBot/statistics" ]
then
  echo "[K8s Init] Statistics volume is disabled."
else
  touch /MaiMBot/statistics/index.html
fi

# Remove the default plugin directory to make room for the user plugin symlink
rm -rf /MaiMBot/plugins

# Create symlinks from the volume to the actual locations
ln -s /MaiMBot/data/plugins /MaiMBot/plugins
ln -s /MaiMBot/data/logs /MaiMBot/logs
if [ -f "/MaiMBot/statistics/index.html" ]
then
  ln -s /MaiMBot/statistics/index.html /MaiMBot/maibot_statistics.html
fi

echo "[K8s Init] Volume ready."

# If the WebUI is enabled, wait for the webui-cm-sync helper container to be ready, then symlink the intermediate config files
if [ "$MAIBOT_WEBUI_ENABLED" = "true" ]
then
  echo "[K8s Init] WebUI enabled. Waiting for container 'webui-cm-sync' ready..."
  while [ ! -f /MaiMBot/webui-cm-sync/ready ]; do
    sleep 1
  done
  echo "[K8s Init] Container 'webui-cm-sync' ready."
  mkdir -p /MaiMBot/config
  ln -s /MaiMBot/webui-cm-sync/model_config.toml /MaiMBot/config/model_config.toml
  ln -s /MaiMBot/webui-cm-sync/bot_config.toml /MaiMBot/config/bot_config.toml
  echo "[K8s Init] Config files middle layer for WebUI created."
else
  echo "[K8s Init] WebUI is disabled."
fi

# Start MaiBot
echo "[K8s Init] Waking up MaiBot..."
echo
exec python bot.py
@@ -1,33 +0,0 @@
|
||||
#!/bin/sh
|
||||
# 此脚本用于覆盖core容器的默认启动命令
|
||||
# 由于k8s与docker-compose的卷挂载方式有所不同,需要利用此脚本为一些文件和目录提前创建好软链接
|
||||
# /MaiMBot/data是麦麦数据的实际挂载路径
|
||||
# /MaiMBot/statistics是统计数据的实际挂载路径
|
||||
|
||||
set -e
|
||||
echo "[VolumeLinker] Preparing volume..."
|
||||
|
||||
# 初次启动,在存储卷中检查并创建关键文件和目录
|
||||
mkdir -p /MaiMBot/data/plugins
|
||||
mkdir -p /MaiMBot/data/logs
|
||||
if [ ! -d "/MaiMBot/statistics" ]
|
||||
then
|
||||
echo "[VolumeLinker] Statistics volume disabled."
|
||||
else
|
||||
touch /MaiMBot/statistics/index.html
|
||||
fi
|
||||
|
||||
# 删除空的插件目录,准备创建软链接
|
||||
rm -rf /MaiMBot/plugins
|
||||
|
||||
# 创建软链接,从存储卷链接到实际位置
|
||||
ln -s /MaiMBot/data/plugins /MaiMBot/plugins
|
||||
ln -s /MaiMBot/data/logs /MaiMBot/logs
|
||||
if [ -f "/MaiMBot/statistics/index.html" ]
|
||||
then
|
||||
ln -s /MaiMBot/statistics/index.html /MaiMBot/maibot_statistics.html
|
||||
fi
|
||||
|
||||
# 启动麦麦
|
||||
echo "[VolumeLinker] Starting MaiBot..."
|
||||
exec python bot.py
|
||||
@@ -58,5 +58,5 @@ spec:
            items:
              - key: config.toml
                path: config.toml
-           name: {{ .Release.Name }}-maibot-adapter
+           name: {{ .Release.Name }}-maibot-adapter-config
          name: config
14
helm-chart/templates/core/configmap-bot-config.yaml
Normal file
@@ -0,0 +1,14 @@
{{- if or .Release.IsInstall (and .Values.core.webui.enabled .Values.config.enable_config_override_with_webui) (and (not .Values.core.webui.enabled) .Values.config.enable_config_override_without_webui) }}
# Render rule:
# first install, or the matching override rule is enabled
apiVersion: v1
kind: ConfigMap
metadata:
  name: {{ .Release.Name }}-maibot-core-bot-config
  namespace: {{ .Release.Namespace }}
  annotations:
    "helm.sh/resource-policy": keep
data:
  bot_config.toml: |
{{ .Values.config.core_bot_config | nindent 4 }}
{{- end }}
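The template guard above renders the ConfigMap only on first install, or when the override flag matching the current WebUI mode is set. The same boolean rule, as a Python sketch (parameter names mirror the values keys; the helper itself is hypothetical):

```python
def should_render(is_install: bool, webui_enabled: bool,
                  override_with_webui: bool, override_without_webui: bool) -> bool:
    # Mirrors: or .Release.IsInstall
    #             (and webui.enabled enable_config_override_with_webui)
    #             (and (not webui.enabled) enable_config_override_without_webui)
    return (is_install
            or (webui_enabled and override_with_webui)
            or (not webui_enabled and override_without_webui))
```

Combined with the `"helm.sh/resource-policy": keep` annotation, a skipped render leaves the previously created ConfigMap untouched instead of deleting it.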
13
helm-chart/templates/core/configmap-env-config.yaml
Normal file
@@ -0,0 +1,13 @@
apiVersion: v1
kind: ConfigMap
metadata:
  name: {{ .Release.Name }}-maibot-core-env-config
  namespace: {{ .Release.Namespace }}
data:
  .env: |
    HOST=0.0.0.0
    PORT=8000
    WEBUI_ENABLED={{ if .Values.core.webui.enabled }}true{{ else }}false{{ end }}
    WEBUI_MODE=production
    WEBUI_HOST=0.0.0.0
    WEBUI_PORT=8001
14
helm-chart/templates/core/configmap-model-config.yaml
Normal file
@@ -0,0 +1,14 @@
{{- if or .Release.IsInstall (and .Values.core.webui.enabled .Values.config.enable_config_override_with_webui) (and (not .Values.core.webui.enabled) .Values.config.enable_config_override_without_webui) }}
# Render rule:
# first install, or the matching override rule is enabled
apiVersion: v1
kind: ConfigMap
metadata:
  name: {{ .Release.Name }}-maibot-core-model-config
  namespace: {{ .Release.Namespace }}
  annotations:
    "helm.sh/resource-policy": keep
data:
  model_config.toml: |
{{ .Values.config.core_model_config | nindent 4 }}
{{- end }}
@@ -1,13 +0,0 @@
apiVersion: v1
kind: ConfigMap
metadata:
  name: {{ .Release.Name }}-maibot-core
  namespace: {{ .Release.Namespace }}
data:
  .env: |
    HOST=0.0.0.0
    PORT=8000
  model_config.toml: |
{{ .Values.config.core_model_config | nindent 4 }}
  bot_config.toml: |
{{ .Values.config.core_bot_config | nindent 4 }}
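Both the old and new ConfigMaps embed multi-line TOML via `nindent 4`. As a reminder of what that pipeline does, here is a Python sketch of Sprig-style `nindent` semantics (a simplified stand-in, not Helm's actual implementation):

```python
def nindent(width: int, text: str) -> str:
    # Sprig-style nindent: prepend a newline, then indent every line by `width` spaces.
    # The leading newline lets the template call sit at the end of a line
    # (e.g. after "bot_config.toml: |") while the payload starts cleanly indented.
    pad = " " * width
    return "\n" + "\n".join(pad + line for line in text.splitlines())
```

This is why the `{{ .Values.config.core_bot_config | nindent 4 }}` call can be flush-left in the template: the function supplies both the newline and the four-space indent that YAML block scalars require.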
26
helm-chart/templates/core/ingress.yaml
Normal file
@@ -0,0 +1,26 @@
{{- if and .Values.core.webui.enabled .Values.core.webui.ingress.enabled }}
apiVersion: networking.k8s.io/v1
kind: Ingress
metadata:
  name: {{ .Release.Name }}-maibot-webui
  namespace: {{ .Release.Namespace }}
  {{- if .Values.core.webui.ingress.annotations }}
  annotations:
    {{ toYaml .Values.core.webui.ingress.annotations | nindent 4 }}
  {{- end }}
  labels:
    app: {{ .Release.Name }}-maibot-core
spec:
  ingressClassName: {{ .Values.core.webui.ingress.className }}
  rules:
    - host: {{ .Values.core.webui.ingress.host }}
      http:
        paths:
          - backend:
              service:
                name: {{ .Release.Name }}-maibot-core
                port:
                  number: {{ .Values.core.webui.service.port }}
            path: {{ .Values.core.webui.ingress.path }}
            pathType: {{ .Values.core.webui.ingress.pathType }}
{{- end }}
@@ -11,6 +11,15 @@ spec:
      port: 8000
      protocol: TCP
      targetPort: 8000
+   {{- if .Values.core.webui.enabled }}
+   - name: webui
+     port: {{ .Values.core.webui.service.port }}
+     protocol: TCP
+     targetPort: 8001
+     {{- if eq .Values.core.webui.service.type "NodePort" }}
+     nodePort: {{ .Values.core.webui.service.nodePort | default nil }}
+     {{- end }}
+   {{- end }}
  selector:
    app: {{ .Release.Name }}-maibot-core
  type: ClusterIP
@@ -18,23 +18,32 @@ spec:
    spec:
      containers:
        - name: core
-         command: # the start command is replaced with a dedicated script to initialize the storage volume in k8s
+         command: # the start command is replaced with a dedicated script to perform k8s initialization
            - sh
          args:
-           - /MaiMBot/volume-linker.sh
+           - /MaiMBot/k8s-init.sh
          env:
            - name: TZ
-             value: Asia/Shanghai
+             value: "Asia/Shanghai"
            - name: EULA_AGREE
-             value: 99f08e0cab0190de853cb6af7d64d4de
+             value: "99f08e0cab0190de853cb6af7d64d4de"
            - name: PRIVACY_AGREE
-             value: 9943b855e72199d0f5016ea39052f1b6
+             value: "9943b855e72199d0f5016ea39052f1b6"
+           {{- if .Values.core.webui.enabled }}
+           - name: MAIBOT_WEBUI_ENABLED
+             value: "true"
+           {{- end}}
-         image: {{ .Values.core.image.repository | default "sengokucola/maibot" }}:{{ .Values.core.image.tag | default "0.11.2-beta" }}
+         image: {{ .Values.core.image.repository | default "sengokucola/maibot" }}:{{ .Values.core.image.tag | default "0.11.5-beta" }}
          imagePullPolicy: {{ .Values.core.image.pullPolicy }}
          ports:
            - containerPort: 8000
              name: adapter-ws
              protocol: TCP
+           {{- if .Values.core.webui.enabled }}
+           - containerPort: 8001
+             name: webui
+             protocol: TCP
+           {{- end }}
          {{- if .Values.core.resources }}
          resources:
            {{ toYaml .Values.core.resources | nindent 12 }}
@@ -42,26 +51,45 @@ spec:
          volumeMounts:
            - mountPath: /MaiMBot/data
              name: data
-           - mountPath: /MaiMBot/volume-linker.sh
+           - mountPath: /MaiMBot/k8s-init.sh
              name: scripts
              readOnly: true
-             subPath: volume-linker.sh
+             subPath: k8s-init.sh
            - mountPath: /MaiMBot/.env
-             name: config
+             name: env-config
              readOnly: true
              subPath: .env
+           {{- if not .Values.core.webui.enabled }}
            - mountPath: /MaiMBot/config/model_config.toml
-             name: config
+             name: model-config
              readOnly: true
              subPath: model_config.toml
            - mountPath: /MaiMBot/config/bot_config.toml
-             name: config
+             name: bot-config
              readOnly: true
              subPath: bot_config.toml
+           {{- end }}
            {{- if .Values.statistics_dashboard.enabled }}
            - mountPath: /MaiMBot/statistics
              name: statistics
            {{- end }}
+           {{- if .Values.core.webui.enabled }}
+           - mountPath: /MaiMBot/webui-cm-sync
+             name: webui-cm-sync
+           {{- end }}
+       {{- if .Values.core.webui.enabled }}
+       - name: webui-cm-sync
+         image: {{ .Values.core.webui.cm_sync.image.repository | default "reg.mikumikumi.xyz/maibot/core-webui-cm-sync" }}:{{ .Values.core.webui.cm_sync.image.tag | default "0.11.5-beta" }}
+         imagePullPolicy: {{ .Values.core.webui.cm_sync.image.pullPolicy }}
+         env:
+           - name: PYTHONUNBUFFERED
+             value: "1"
+           - name: RELEASE_NAME
+             value: {{ .Release.Name }}
+         volumeMounts:
+           - mountPath: /MaiMBot/webui-cm-sync
+             name: webui-cm-sync
+       {{- end }}
      {{- if .Values.core.setup_default_plugins }}
      initContainers: # the user plugin directory lives on the storage volume and shadows the container's default plugin directory at startup; this init container installs the default plugins into the volume for the user after a default-plugin update or on MaiBot's first start
        - args:
@@ -69,7 +97,7 @@ spec:
          command:
            - python3
          workingDir: /MaiMBot
-         image: {{ .Values.core.image.repository | default "sengokucola/maibot" }}:{{ .Values.core.image.tag | default "0.11.2-beta" }}
+         image: {{ .Values.core.image.repository | default "sengokucola/maibot" }}:{{ .Values.core.image.tag | default "0.11.5-beta" }}
          imagePullPolicy: {{ .Values.core.image.pullPolicy }}
          name: setup-plugins
          resources: { }
@@ -81,6 +109,7 @@ spec:
              readOnly: true
              subPath: setup-plugins.py
      {{- end }}
+     serviceAccountName: {{ .Release.Name }}-maibot-sa
      {{- if .Values.core.image.pullSecrets }}
      imagePullSecrets:
        {{ toYaml .Values.core.image.pullSecrets | nindent 8 }}
@@ -99,8 +128,8 @@ spec:
            claimName: {{ .Release.Name }}-maibot-core
        - configMap:
            items:
-             - key: volume-linker.sh
-               path: volume-linker.sh
+             - key: k8s-init.sh
+               path: k8s-init.sh
              {{- if .Values.core.setup_default_plugins }}
              - key: setup-plugins.py
                path: setup-plugins.py
@@ -111,14 +140,28 @@ spec:
            items:
              - key: .env
                path: .env
-           name: {{ .Release.Name }}-maibot-core
-         name: config
+           name: {{ .Release.Name }}-maibot-core-env-config
+         name: env-config
+       {{- if not .Values.core.webui.enabled }}
+       - configMap:
+           items:
+             - key: model_config.toml
+               path: model_config.toml
+           name: {{ .Release.Name }}-maibot-core-model-config
+         name: model-config
+       - configMap:
+           items:
+             - key: bot_config.toml
+               path: bot_config.toml
+           name: {{ .Release.Name }}-maibot-core-bot-config
+         name: bot-config
+       {{- end }}
        {{- if .Values.statistics_dashboard.enabled }}
        - name: statistics
          persistentVolumeClaim:
            claimName: {{ .Release.Name }}-maibot-statistics-dashboard
        {{- end }}
+       {{- if .Values.core.webui.enabled }}
+       - emptyDir: {}
+         name: webui-cm-sync
+       {{- end }}
@@ -26,7 +26,7 @@ spec:
          value: "{{ .Values.napcat.permission.uid }}"
        - name: TZ
          value: Asia/Shanghai
-       image: {{ .Values.napcat.image.repository | default "mlikiowa/napcat-docker" }}:{{ .Values.napcat.image.tag | default "v4.9.70" }}
+       image: {{ .Values.napcat.image.repository | default "mlikiowa/napcat-docker" }}:{{ .Values.napcat.image.tag | default "v4.9.73" }}
        imagePullPolicy: {{ .Values.napcat.image.pullPolicy }}
        livenessProbe:
          failureThreshold: 3
@@ -5,8 +5,8 @@ metadata:
  namespace: {{ .Release.Namespace }}
data:
  # core
- volume-linker.sh: |
-   {{ .Files.Get "files/volume-linker.sh" | nindent 4 }}
+ k8s-init.sh: |
+   {{ .Files.Get "files/k8s-init.sh" | nindent 4 }}
  # core's init container
  {{- if .Values.core.setup_default_plugins }}
  setup-plugins.py: |
@@ -11,11 +11,11 @@ spec:
  backoffLimit: 2
  template:
    spec:
-     serviceAccountName: {{ .Release.Name }}-maibot-adapter-cm-generator
+     serviceAccountName: {{ .Release.Name }}-maibot-sa
      restartPolicy: Never
      containers:
        - name: adapter-cm-generator
-         image: {{ .Values.adapter.cm_generator.image.repository | default "reg.mikumikumi.xyz/maibot/adapter-cm-generator" }}:{{ .Values.adapter.cm_generator.image.tag | default "0.11.2-beta" }}
+         image: {{ .Values.adapter.cm_generator.image.repository | default "reg.mikumikumi.xyz/maibot/adapter-cm-generator" }}:{{ .Values.adapter.cm_generator.image.tag | default "0.11.5-beta" }}
          workingDir: /app
          env:
            - name: PYTHONUNBUFFERED
@@ -1,14 +1,14 @@
-# RBAC authorization required for dynamically generating the adapter-config ConfigMap
+# RBAC authorization required for initialization and for writing changes back to ConfigMaps
apiVersion: v1
kind: ServiceAccount
metadata:
- name: {{ .Release.Name }}-maibot-adapter-cm-generator
+ name: {{ .Release.Name }}-maibot-sa
  namespace: {{ .Release.Namespace }}
---
apiVersion: rbac.authorization.k8s.io/v1
kind: Role
metadata:
- name: {{ .Release.Name }}-maibot-adapter-cm-gen-role
+ name: {{ .Release.Name }}-maibot-role
  namespace: {{ .Release.Namespace }}
rules:
  - apiGroups: [""]
@@ -21,13 +21,13 @@ rules:
apiVersion: rbac.authorization.k8s.io/v1
kind: RoleBinding
metadata:
- name: {{ .Release.Name }}-maibot-adapter-cm-gen-role-binding
+ name: {{ .Release.Name }}-maibot-rolebinding
  namespace: {{ .Release.Namespace }}
subjects:
  - kind: ServiceAccount
-   name: {{ .Release.Name }}-maibot-adapter-cm-generator
+   name: {{ .Release.Name }}-maibot-sa
    namespace: {{ .Release.Namespace }}
roleRef:
  kind: Role
- name: {{ .Release.Name }}-maibot-adapter-cm-gen-role
+ name: {{ .Release.Name }}-maibot-role
  apiGroup: rbac.authorization.k8s.io
@@ -38,7 +38,7 @@ adapter:
  cm_generator:
    image:
      repository: # default: reg.mikumikumi.xyz/maibot/adapter-cm-generator
-     tag: # default: 0.11.2-beta
+     tag: # default: 0.11.5-beta
      pullPolicy: IfNotPresent
      pullSecrets: [ ]

@@ -48,7 +48,7 @@ core:

  image:
    repository: # default: sengokucola/maibot
-   tag: # default: 0.11.2-beta
+   tag: # default: 0.11.5-beta
    pullPolicy: IfNotPresent
    pullSecrets: [ ]
@@ -65,6 +65,28 @@ core:

  setup_default_plugins: true # enable an init container that automatically installs the default plugins into the storage volume for the user

+ webui: # WebUI settings
+   # Important note for users upgrading from helm chart versions older than 0.11.5-beta:
+   # older charts set no retention policy on the ConfigMaps, so upgrading with WebUI enabled loses the ConfigMaps and the bot cannot start.
+   # If that happens, upgrade once with WebUI disabled, then once more with WebUI enabled.
+   enabled: true # enabled by default
+   cm_sync: # WebUI helper container settings
+     image:
+       repository: # default: reg.mikumikumi.xyz/maibot/core-webui-cm-sync
+       tag: # default: 0.11.5-beta
+       pullPolicy: IfNotPresent
+   service:
+     type: ClusterIP # ClusterIP / NodePort; NodePort maps the in-cluster service port onto a physical node port
+     port: 8001 # service port
+     nodePort: # only used with the NodePort type; a random port is assigned if unset
+   ingress:
+     enabled: false
+     className: nginx
+     annotations: { }
+     host: maim.example.com # domain name for reaching the MaiBot WebUI
+     path: /
+     pathType: Prefix

# MaiBot runtime-statistics dashboard settings
# MaiBot periodically writes an HTML runtime-statistics report, which can be served as a static page
# Disabled by default. Enable it only if the report may be public (it contains contact/group names, model token spend, etc.)
@@ -117,7 +139,7 @@ napcat:

  image:
    repository: # default: mlikiowa/napcat-docker
-   tag: # default: v4.9.70
+   tag: # default: v4.9.73
    pullPolicy: IfNotPresent
    pullSecrets: [ ]
@@ -189,9 +211,16 @@ sqlite_web:
    path: /
    pathType: Prefix

-# Runtime config files for each MaiBot component
+# Manually set the runtime config files for each MaiBot component
config:

+ # With WebUI enabled, config edits can be made in the WebUI; the live config files then diverge from the values below.
+ # The live files also diverge if a user edits the ConfigMaps in k8s by hand.
+ # In either case, to keep a helm upgrade from overwriting the live config with the values below and losing changes, config overriding can be disabled for the deployment here.
+ # Note: the adapter config cannot be edited via the WebUI, so adapter_config below still overwrites the live file.
+ enable_config_override_without_webui: true # with WebUI disabled, override by default
+ enable_config_override_with_webui: false # with WebUI enabled, do not override by default

  # adapter's config.toml
  adapter_config: |
    [inner]
@@ -227,7 +256,7 @@ config:
  # core's model_config.toml
  core_model_config: |
    [inner]
-   version = "1.7.7"
+   version = "1.7.8"

    # the config-file version numbering follows the same rules as bot_config.toml

@@ -388,7 +417,7 @@ config:
  # core's bot_config.toml
  core_bot_config: |
    [inner]
-   version = "6.21.4"
+   version = "6.21.8"

    # ---- The section below is for developers; if you only deploy MaiBot, you can skip it ----
    # If you modify the config file, increment the version value
@@ -493,7 +522,7 @@ config:
    include_planner_reasoning = false # whether to feed the planner's reasoning to the replyer; disabled (not included) by default

    [memory]
-   max_agent_iterations = 5 # memory reasoning depth (minimum 1 = no deep reasoning)
+   max_agent_iterations = 3 # memory reasoning depth (minimum 1 = no deep reasoning)

    [jargon]
    all_global = true # whether to enable global jargon mode; note: after disabling, already-recorded global jargon is not changed and must be deleted manually

@@ -600,6 +629,9 @@ config:
    show_replyer_prompt = false # whether to show the replyer prompt
    show_replyer_reasoning = false # whether to show the replyer's reasoning
    show_jargon_prompt = false # whether to show jargon-related prompts
    show_memory_prompt = false # whether to show memory-retrieval prompts
    show_planner_prompt = false # whether to show the planner prompt and its raw response
    show_lpmm_paragraph = false # whether to log the related paragraphs found by lpmm

    [maim_message]
    auth_token = [] # auth tokens for API verification; empty disables verification
@@ -16,8 +16,6 @@ if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)


SECONDS_5_MINUTES = 5 * 60


@@ -12,7 +12,6 @@ project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)


# Configure a font with Chinese-character support
plt.rcParams["font.sans-serif"] = ["SimHei", "Microsoft YaHei", "DejaVu Sans"]
plt.rcParams["axes.unicode_minus"] = False
File diff suppressed because it is too large
@@ -230,7 +230,7 @@ class HeartFChatting:
        if (message.is_mentioned or message.is_at) and global_config.chat.mentioned_bot_reply:
            mentioned_message = message

-       logger.info(f"{self.log_prefix} 当前talk_value: {global_config.chat.get_talk_value(self.stream_id)}")
+       # logger.info(f"{self.log_prefix} 当前talk_value: {global_config.chat.get_talk_value(self.stream_id)}")

        # *for rate control
        if mentioned_message:
@@ -333,7 +333,6 @@ class HeartFChatting:
        # reset the consecutive no_reply counter
        self.consecutive_no_reply_count = 0
-       reason = ""

        await database_api.store_action_info(
            chat_stream=self.chat_stream,
@@ -411,7 +410,7 @@ class HeartFChatting:
        # asyncio.create_task(self.chat_history_summarizer.process())

        cycle_timers, thinking_id = self.start_cycle()
-       logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考")
+       logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考(频率: {global_config.chat.get_talk_value(self.stream_id)})")

        # Step 1: action check
        available_actions: Dict[str, ActionInfo] = {}
@@ -30,9 +30,11 @@ DATA_PATH = os.path.join(ROOT_PATH, "data")
qa_manager = None
inspire_manager = None


+def get_qa_manager():
+    return qa_manager
+
+
def lpmm_start_up():  # sourcery skip: extract-duplicate-method
    # check whether the LPMM knowledge base is enabled
    if global_config.lpmm_knowledge.enable:
@@ -92,14 +92,20 @@ class QAManager:
        # filter by threshold
        result = dyn_select_top_k(result, 0.5, 1.0)

-       for res in result:
-           raw_paragraph = self.embed_manager.paragraphs_embedding_store.store[res[0]].str
-           logger.info(f"找到相关文段,相关系数:{res[1]:.8f}\n{raw_paragraph}\n\n")
+       if global_config.debug.show_lpmm_paragraph:
+           for res in result:
+               raw_paragraph = self.embed_manager.paragraphs_embedding_store.store[res[0]].str
+               logger.info(f"找到相关文段,相关系数:{res[1]:.8f}\n{raw_paragraph}\n\n")

        return result, ppr_node_weights

-   async def get_knowledge(self, question: str) -> Optional[str]:
-       """Fetch knowledge"""
+   async def get_knowledge(self, question: str, limit: int = 5) -> Optional[str]:
+       """Fetch knowledge
+
+       Args:
+           question: the query question
+           limit: number of related knowledge entries to return
+       """
        # process the query
        processed_result = await self.process_query(question)
        if processed_result is not None:
@@ -109,6 +115,8 @@ class QAManager:
            logger.debug("知识库查询结果为空,可能是知识库中没有相关内容")
            return None

+       limit = max(1, limit) if isinstance(limit, int) else 5
+
        knowledge = [
            (
                self.embed_manager.paragraphs_embedding_store.store[res[0]].str,
@@ -116,9 +124,17 @@ class QAManager:
            )
            for res in query_res
        ]
-       found_knowledge = "\n".join(
-           [f"第{i + 1}条知识:{k[0]}\n 该条知识对于问题的相关性:{k[1]}" for i, k in enumerate(knowledge)]
-       )
+       # max_score = max([k[1] for k in knowledge]) if knowledge else None
+       selected_knowledge = knowledge[:limit]
+
+       formatted_knowledge = [
+           f"第{i + 1}条知识:{k[0]}\n 该条知识对于问题的相关性:{k[1]}" for i, k in enumerate(selected_knowledge)
+       ]
+       # if max_score is not None:
+       #     formatted_knowledge.insert(0, f"最高相关系数:{max_score}")
+
+       found_knowledge = "\n".join(formatted_knowledge)
        if len(found_knowledge) > MAX_KNOWLEDGE_LENGTH:
            found_knowledge = found_knowledge[:MAX_KNOWLEDGE_LENGTH] + "\n"
        return found_knowledge
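The hunk above adds three safeguards to `get_knowledge`: the `limit` argument is clamped to at least 1 (and falls back to 5 for non-int input), only the first `limit` entries are formatted, and the joined text is capped at `MAX_KNOWLEDGE_LENGTH`. A standalone sketch of that selection-and-formatting logic (names and the length cap are illustrative, and the output strings are simplified to English):

```python
MAX_KNOWLEDGE_LENGTH = 2000  # assumed cap; the real constant lives in the module above

def format_knowledge(entries, limit=5, max_length=MAX_KNOWLEDGE_LENGTH):
    """Clamp `limit`, keep the top entries, and truncate the joined text."""
    limit = max(1, limit) if isinstance(limit, int) else 5
    selected = entries[:limit]
    lines = [
        f"Knowledge {i + 1}: {text}\n  relevance to the question: {score}"
        for i, (text, score) in enumerate(selected)
    ]
    found = "\n".join(lines)
    if len(found) > max_length:
        found = found[:max_length] + "\n"
    return found
```

Note the fallback direction: a non-integer `limit` silently reverts to the default rather than raising, which matches the defensive style of the original method.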
@@ -7,7 +7,6 @@ from typing import Dict, Optional, Tuple, List, TYPE_CHECKING, Union
from rich.traceback import install
from datetime import datetime
from json_repair import repair_json

from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.logger import get_logger
@@ -164,6 +163,45 @@ class ActionPlanner:
                return item[1]
        return None

+    def _replace_message_ids_with_text(
+        self, text: Optional[str], message_id_list: List[Tuple[str, "DatabaseMessages"]]
+    ) -> Optional[str]:
+        """Replace m<digits> message IDs in the text with the original message content, wrapped in quotes"""
+        if not text:
+            return text
+
+        id_to_message = {msg_id: msg for msg_id, msg in message_id_list}
+
+        # match "m" followed by 2-4 digits, not adjacent to word characters
+        pattern = r"(?<![A-Za-z0-9_])m\d{2,4}(?![A-Za-z0-9_])"
+
+        matches = re.findall(pattern, text)
+        if matches:
+            available_ids = set(id_to_message.keys())
+            found_ids = set(matches)
+            missing_ids = found_ids - available_ids
+            if missing_ids:
+                logger.info(f"{self.log_prefix}planner理由中引用的消息ID不在当前上下文中: {missing_ids}, 可用ID: {list(available_ids)[:10]}...")
+            logger.info(f"{self.log_prefix}planner理由替换: 找到{len(matches)}个消息ID引用,其中{len(found_ids & available_ids)}个在上下文中")
+
+        def _replace(match: re.Match[str]) -> str:
+            msg_id = match.group(0)
+            message = id_to_message.get(msg_id)
+            if not message:
+                logger.warning(f"{self.log_prefix}planner理由引用 {msg_id} 未找到对应消息,保持原样")
+                return msg_id
+
+            msg_text = (message.processed_plain_text or message.display_message or "").strip()
+            if not msg_text:
+                logger.warning(f"{self.log_prefix}planner理由引用 {msg_id} 的消息内容为空,保持原样")
+                return msg_id
+
+            preview = msg_text if len(msg_text) <= 100 else f"{msg_text[:97]}..."
+            logger.info(f"{self.log_prefix}planner理由引用 {msg_id} -> 消息({preview})")
+            return f"消息({msg_text})"
+
+        return re.sub(pattern, _replace, text)
+
    def _parse_single_action(
        self,
        action_json: dict,
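The core of the method above is the regex substitution: lookbehind and lookahead keep `m123` from matching inside longer tokens like `am12` or `m99999`. A self-contained sketch of just that substitution (simplified, with a plain string map instead of `DatabaseMessages`):

```python
import re

# Same pattern as above: "m" plus 2-4 digits, not embedded in a longer token
PATTERN = r"(?<![A-Za-z0-9_])m\d{2,4}(?![A-Za-z0-9_])"

def replace_ids(text: str, id_to_message: dict) -> str:
    """Swap m<digits> references for the quoted message text; unknown IDs stay as-is."""
    def _replace(match: re.Match) -> str:
        msg_id = match.group(0)
        msg_text = id_to_message.get(msg_id, "").strip()
        # unknown or empty messages are left untouched, as in the original method
        return f"message({msg_text})" if msg_text else msg_id
    return re.sub(PATTERN, _replace, text)
```

Because `m99999` has five digits, the engine cannot find any 2-4 digit window whose neighbours satisfy both lookarounds, so such IDs fall outside the addressable range by design.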
@@ -176,7 +214,10 @@
        try:
            action = action_json.get("action", "no_reply")
-           reasoning = action_json.get("reason", "未提供原因")
+           original_reasoning = action_json.get("reason", "未提供原因")
+           reasoning = self._replace_message_ids_with_text(original_reasoning, message_id_list)
+           if reasoning is None:
+               reasoning = original_reasoning
            action_data = {key: value for key, value in action_json.items() if key not in ["action", "reason"]}
            # non-no_reply actions need a target_message_id
            target_message = None
@@ -573,9 +614,6 @@
        # call the LLM
        llm_content, (reasoning_content, _, _) = await self.planner_llm.generate_response_async(prompt=prompt)

-       # logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
-       # logger.info(f"{self.log_prefix}规划器原始响应: {llm_content}")

        if global_config.debug.show_planner_prompt:
            logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
            logger.info(f"{self.log_prefix}规划器原始响应: {llm_content}")
@@ -604,6 +642,7 @@
        if llm_content:
            try:
                json_objects, extracted_reasoning = self._extract_json_from_markdown(llm_content)
+               extracted_reasoning = self._replace_message_ids_with_text(extracted_reasoning, message_id_list) or ""
                if json_objects:
                    logger.debug(f"{self.log_prefix}从响应中提取到{len(json_objects)}个JSON对象")
                    filtered_actions_list = list(filtered_actions.items())
@@ -226,7 +226,9 @@ class DefaultReplyer:
            traceback.print_exc()
            return False, llm_response

-   async def build_expression_habits(self, chat_history: str, target: str, reply_reason: str = "") -> Tuple[str, List[int]]:
+   async def build_expression_habits(
+       self, chat_history: str, target: str, reply_reason: str = ""
+   ) -> Tuple[str, List[int]]:
        # sourcery skip: for-append-to-extend
        """Build the expression-habits block

@@ -1094,10 +1096,10 @@
        if not global_config.lpmm_knowledge.enable:
            logger.debug("LPMM知识库未启用,跳过获取知识库内容")
            return ""

        if global_config.lpmm_knowledge.lpmm_mode == "agent":
            return ""

        time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
        bot_name = global_config.bot.nickname
@@ -1115,10 +1117,10 @@
            model_config=model_config.model_task_config.tool_use,
            tool_options=[SearchKnowledgeFromLPMMTool.get_tool_definition()],
        )

        # logger.info(f"工具调用提示词: {prompt}")
        # logger.info(f"工具调用: {tool_calls}")

        if tool_calls:
            result = await self.tool_executor.execute_tool_call(tool_calls[0], SearchKnowledgeFromLPMMTool())
            end_time = time.time()
@@ -241,7 +241,9 @@ class PrivateReplyer:

        return f"{sender_relation}"

-   async def build_expression_habits(self, chat_history: str, target: str, reply_reason: str = "") -> Tuple[str, List[int]]:
+   async def build_expression_habits(
+       self, chat_history: str, target: str, reply_reason: str = ""
+   ) -> Tuple[str, List[int]]:
        # sourcery skip: for-append-to-extend
        """Build the expression-habits block

@@ -1032,10 +1034,10 @@
        if not global_config.lpmm_knowledge.enable:
            logger.debug("LPMM知识库未启用,跳过获取知识库内容")
            return ""

        if global_config.lpmm_knowledge.lpmm_mode == "agent":
            return ""

        time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
        bot_name = global_config.bot.nickname
@@ -106,8 +106,8 @@ class ChatHistorySummarizer:
            await self._check_and_package(current_time)
            self.last_check_time = current_time
            return

-       logger.info(
+       logger.debug(
            f"{self.log_prefix} 开始处理聊天概括,时间窗口: {self.last_check_time:.2f} -> {current_time:.2f}"
        )
@@ -119,7 +119,7 @@
            before_count = len(self.current_batch.messages)
            self.current_batch.messages.extend(new_messages)
            self.current_batch.end_time = current_time
-           logger.info(f"{self.log_prefix} 批次更新: {before_count} -> {len(self.current_batch.messages)} 条消息")
+           logger.info(f"{self.log_prefix} 更新聊天话题: {before_count} -> {len(self.current_batch.messages)} 条消息")
        else:
            # create a new batch
            self.current_batch = MessageBatch(
@@ -127,7 +127,7 @@
                start_time=new_messages[0].time if new_messages else current_time,
                end_time=current_time,
            )
-           logger.info(f"{self.log_prefix} 新建批次: {len(new_messages)} 条消息")
+           logger.info(f"{self.log_prefix} 新建聊天话题: {len(new_messages)} 条消息")

        # check whether packaging is needed
        await self._check_and_package(current_time)
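The summarizer logic above either extends the open batch (and pushes its `end_time` forward) or opens a new one. A minimal sketch of that merge-or-create step, using a simplified stand-in for the real `MessageBatch` model (field types and the helper name are illustrative):

```python
from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class MessageBatch:
    # simplified stand-in for the real MessageBatch
    messages: List[str] = field(default_factory=list)
    start_time: float = 0.0
    end_time: float = 0.0

def merge_or_create(current: Optional[MessageBatch],
                    new_messages: List[str], now: float) -> MessageBatch:
    """Extend the open batch if one exists, otherwise open a new one."""
    if current is not None:
        current.messages.extend(new_messages)
        current.end_time = now
        return current
    return MessageBatch(messages=list(new_messages), start_time=now, end_time=now)
```

In the real code `start_time` comes from the first new message's timestamp rather than the check time; this sketch collapses that detail.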
@@ -311,6 +311,8 @@ class Expression(BaseModel):
    context = TextField(null=True)
    up_content = TextField(null=True)

+   content_list = TextField(null=True)
    count = IntegerField(default=1)
    last_active_time = FloatField()
    chat_id = TextField(index=True)
    create_date = FloatField(null=True)  # creation date; nullable for compatibility with old data
@@ -19,6 +19,7 @@ PROJECT_ROOT = logger_file.parent.parent.parent.resolve()
# global handler singletons, to avoid recreating them
_file_handler = None
_console_handler = None
+_ws_handler = None


def get_file_handler():
@@ -59,6 +60,35 @@ def get_console_handler():
    return _console_handler


+def get_ws_handler():
+    """Get the WebSocket handler singleton"""
+    global _ws_handler
+    if _ws_handler is None:
+        _ws_handler = WebSocketLogHandler()
+        # the WebSocket handler pushes logs of every level
+        _ws_handler.setLevel(logging.DEBUG)
+    return _ws_handler
+
+
+def initialize_ws_handler(loop):
+    """Bind the WebSocket handler to an event loop
+
+    Args:
+        loop: the asyncio event loop
+    """
+    handler = get_ws_handler()
+    handler.set_loop(loop)
+
+    # give the WebSocket handler the JSON formatter (same format as the file handler)
+    handler.setFormatter(file_formatter)
+
+    # attach to the root logger
+    root_logger = logging.getLogger()
+    if handler not in root_logger.handlers:
+        root_logger.addHandler(handler)
+    print("[日志系统] ✅ WebSocket 日志推送已启用")
+
+
class TimestampedFileHandler(logging.Handler):
    """Timestamp-based file handler with a simple cap on the number of rotated files"""
@@ -145,12 +175,76 @@ class TimestampedFileHandler(logging.Handler):
        super().close()


+class WebSocketLogHandler(logging.Handler):
+    """WebSocket log handler - pushes logs to the frontend in real time"""
+
+    _log_counter = 0  # class-level counter to guarantee unique IDs
+
+    def __init__(self, loop=None):
+        super().__init__()
+        self.loop = loop
+        self._initialized = False
+
+    def set_loop(self, loop):
+        """Set the event loop"""
+        self.loop = loop
+        self._initialized = True
+
+    def emit(self, record):
+        """Send the log record to WebSocket clients"""
+        if not self._initialized or self.loop is None:
+            return
+
+        try:
+            # get the formatted message
+            # with structlog, the formatted message carries the full log information
+            formatted_msg = self.format(record) if self.formatter else record.getMessage()
+
+            # if it is JSON (the file formatter), parse it
+            message = formatted_msg
+            try:
+                import json
+
+                log_dict = json.loads(formatted_msg)
+                message = log_dict.get("event", formatted_msg)
+            except (json.JSONDecodeError, ValueError):
+                # not JSON; use the message as-is
+                message = formatted_msg
+
+            # build a unique ID: millisecond timestamp + incrementing counter
+            WebSocketLogHandler._log_counter += 1
+            log_id = f"{int(record.created * 1000)}_{WebSocketLogHandler._log_counter}"
+
+            # shape the log payload
+            log_data = {
+                "id": log_id,
+                "timestamp": datetime.fromtimestamp(record.created).strftime("%Y-%m-%d %H:%M:%S"),
+                "level": record.levelname,
+                "module": record.name,
+                "message": message,
+            }
+
+            # broadcast asynchronously (never blocks logging)
+            try:
+                import asyncio
+                from src.webui.logs_ws import broadcast_log
+
+                asyncio.run_coroutine_threadsafe(broadcast_log(log_data), self.loop)
+            except Exception:
+                # a failed WebSocket push must not affect logging
+                pass
+
+        except Exception:
+            # never let WebSocket errors break the logging system
+            self.handleError(record)
+
+
# the old rotating file handler has been removed; timestamp-based handlers are used now
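The handler above formats each record into a dict with a unique `"<ms-timestamp>_<counter>"` id before broadcasting it. A runnable, transport-free sketch of the same shaping logic (the class name is illustrative; records are collected in a list instead of being pushed over a WebSocket):

```python
import logging
import datetime as dt

class CollectingLogHandler(logging.Handler):
    """Stand-in for the WebSocket handler above: shapes records into dicts
    with a unique '<ms-timestamp>_<counter>' id and collects them."""
    _log_counter = 0  # class-level counter guarantees unique IDs

    def __init__(self):
        super().__init__()
        self.records = []

    def emit(self, record):
        CollectingLogHandler._log_counter += 1
        self.records.append({
            "id": f"{int(record.created * 1000)}_{CollectingLogHandler._log_counter}",
            "timestamp": dt.datetime.fromtimestamp(record.created).strftime("%Y-%m-%d %H:%M:%S"),
            "level": record.levelname,
            "module": record.name,
            "message": record.getMessage(),
        })
```

The millisecond timestamp alone is not unique when two records land in the same millisecond; the appended counter is what makes the id collision-free, which the frontend presumably relies on as a React-style list key.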
def close_handlers():
    """Safely close all handlers"""
-   global _file_handler, _console_handler
+   global _file_handler, _console_handler, _ws_handler

    if _file_handler:
        _file_handler.close()
@@ -160,6 +254,10 @@ def close_handlers():
        _console_handler.close()
        _console_handler = None

+   if _ws_handler:
+       _ws_handler.close()
+       _ws_handler = None


def remove_duplicate_handlers():  # sourcery skip: for-append-to-extend, list-comprehension
    """Remove duplicate handlers, especially file handlers"""
@@ -843,8 +941,8 @@ def start_log_cleanup_task():

def shutdown_logging():
    """Gracefully shut down the logging system and release all file handles"""
-   logger = get_logger("logger")
-   logger.info("正在关闭日志系统...")
+   # print to the console first, since the logger cannot output once the log system is closed
+   print("[logger] 正在关闭日志系统...")

    # close all handlers
    root_logger = logging.getLogger()
@@ -865,4 +963,5 @@ def shutdown_logging():
            handler.close()
            logger_obj.removeHandler(handler)

-   logger.info("日志系统已关闭")
+   # use print instead of logger, because the logger is already closed
+   print("[logger] 日志系统已关闭")
@@ -1,7 +1,7 @@
from fastapi import FastAPI, APIRouter
from fastapi.middleware.cors import CORSMiddleware  # new import
from typing import Optional
from uvicorn import Config, Server as UvicornServer
import asyncio
import os
from rich.traceback import install

@@ -16,21 +16,6 @@ class Server:
        self._server: Optional[UvicornServer] = None
        self.set_address(host, port)

        # Configure CORS
        origins = [
            "http://localhost:7999",  # allowed frontend origin
            "http://127.0.0.1:7999",
            # In production, add the real frontend domain here
        ]

        self.app.add_middleware(
            CORSMiddleware,
            allow_origins=origins,
            allow_credentials=True,  # whether cookies are supported
            allow_methods=["*"],  # allow all HTTP methods
            allow_headers=["*"],  # allow all HTTP request headers
        )

    def register_router(self, router: APIRouter, prefix: str = ""):
        """Register a router

@@ -82,8 +67,17 @@ class Server:
        """Shut the server down safely"""
        if self._server:
            self._server.should_exit = True
            await self._server.shutdown()
            self._server = None
            try:
                # Add a 3-second timeout so shutdown cannot hang forever
                await asyncio.wait_for(self._server.shutdown(), timeout=3.0)
            except asyncio.TimeoutError:
                # On timeout, force-mark as None and let garbage collection handle it
                pass
            except Exception:
                # Ignore other exceptions
                pass
            finally:
                self._server = None
    def get_app(self) -> FastAPI:
        """Get the FastAPI instance"""
@@ -56,7 +56,7 @@ TEMPLATE_DIR = os.path.join(PROJECT_ROOT, "template")

# Since mai_version in the config file is not updated automatically, it is hard-coded here
# When updating this field, follow Semantic Versioning strictly: https://semver.org/lang/zh-CN/
MMC_VERSION = "0.11.2"
MMC_VERSION = "0.11.5"


def get_key_comment(toml_table, key):

@@ -581,9 +581,15 @@ class DebugConfig(ConfigBase):
    show_jargon_prompt: bool = False
    """Whether to show jargon-related prompts"""

    show_memory_prompt: bool = False
    """Whether to show memory-retrieval-related prompts"""

    show_planner_prompt: bool = False
    """Whether to show planner-related prompts"""

    show_lpmm_paragraph: bool = False
    """Whether to log the relevant passages found by LPMM"""


@dataclass
class ExperimentalConfig(ConfigBase):
@@ -647,7 +653,7 @@ class LPMMKnowledgeConfig(ConfigBase):

    enable: bool = True
    """Whether to enable the LPMM knowledge base"""

    lpmm_mode: Literal["classic", "agent"] = "classic"
    """LPMM knowledge-base mode: "classic" for classic mode, or "agent" mode, which also uses the latest memories"""

@@ -690,4 +696,4 @@ class JargonConfig(ConfigBase):
    """Jargon configuration class"""

    all_global: bool = False
    """Whether all newly added jargon entries default to global (is_global=True); chat_id records the id at first storage"""
    """Whether all newly added jargon entries default to global (is_global=True); chat_id records the id at first storage"""
@@ -61,6 +61,37 @@ def format_create_date(timestamp: float) -> str:
        return "未知时间"


def _compute_weights(population: List[Dict]) -> List[float]:
    """
    Compute weights from each expression's count, bounded to the range 1-3.
    A higher count means a higher weight, capped at 3x the base weight.
    """
    if not population:
        return []

    counts = []
    for item in population:
        count = item.get("count", 1)
        try:
            count_value = float(count)
        except (TypeError, ValueError):
            count_value = 1.0
        counts.append(max(count_value, 0.0))

    min_count = min(counts)
    max_count = max(counts)

    if max_count == min_count:
        return [1.0 for _ in counts]

    weights = []
    for count_value in counts:
        # Linear mapping into the [1, 3] interval
        normalized = (count_value - min_count) / (max_count - min_count)
        weights.append(1.0 + normalized * 2.0)  # 1-3
    return weights
def weighted_sample(population: List[Dict], k: int) -> List[Dict]:
    """
    Random sampling function
@@ -78,15 +109,24 @@ def weighted_sample(population: List[Dict], k: int) -> List[Dict]:
    if len(population) <= k:
        return population.copy()

    # Use random sampling
    selected = []
    selected: List[Dict] = []
    population_copy = population.copy()

    for _ in range(k):
        if not population_copy:
            break
        # Pick a random element
        idx = random.randint(0, len(population_copy) - 1)
        selected.append(population_copy.pop(idx))
    for _ in range(min(k, len(population_copy))):
        weights = _compute_weights(population_copy)
        total_weight = sum(weights)
        if total_weight <= 0:
            # Fall back to uniform random selection
            idx = random.randint(0, len(population_copy) - 1)
            selected.append(population_copy.pop(idx))
            continue

        threshold = random.uniform(0, total_weight)
        cumulative = 0.0
        for idx, weight in enumerate(weights):
            cumulative += weight
            if threshold <= cumulative:
                selected.append(population_copy.pop(idx))
                break

    return selected
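The diff above replaces uniform sampling with count-weighted sampling: counts are mapped linearly into [1, 3], then items are drawn without replacement by cumulative-threshold selection. A minimal, self-contained sketch of the same idea (plain dicts stand in for Expression rows; `compute_weights` and `weighted_sample` here are illustrative names, and a for-else guard is added against float rounding):

```python
import random
from typing import Dict, List


def compute_weights(population: List[Dict]) -> List[float]:
    """Map each item's count linearly into [1, 3]; equal counts all get 1.0."""
    counts = [max(float(item.get("count", 1)), 0.0) for item in population]
    lo, hi = min(counts), max(counts)
    if hi == lo:
        return [1.0] * len(counts)
    return [1.0 + 2.0 * (c - lo) / (hi - lo) for c in counts]


def weighted_sample(population: List[Dict], k: int) -> List[Dict]:
    """Sample k items without replacement, biased toward higher counts."""
    if len(population) <= k:
        return population.copy()
    pool = population.copy()
    selected: List[Dict] = []
    for _ in range(k):
        weights = compute_weights(pool)  # recomputed as the pool shrinks
        threshold = random.uniform(0, sum(weights))
        cumulative = 0.0
        for idx, w in enumerate(weights):
            cumulative += w
            if threshold <= cumulative:
                selected.append(pool.pop(idx))
                break
        else:
            # Float rounding left the threshold past every cumulative sum:
            # take the last item so each round still selects exactly one.
            selected.append(pool.pop())
    return selected
```

Recomputing the weights each round keeps them normalized to the remaining pool, which is what makes the draw "without replacement" rather than a one-shot weighted choice.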
@@ -77,6 +77,9 @@ class ExpressionLearner:
        self.express_learn_model: LLMRequest = LLMRequest(
            model_set=model_config.model_task_config.utils, request_type="expression.learner"
        )
        self.summary_model: LLMRequest = LLMRequest(
            model_set=model_config.model_task_config.utils_small, request_type="expression.summary"
        )
        self.embedding_model: LLMRequest = LLMRequest(
            model_set=model_config.model_task_config.embedding, request_type="expression.embedding"
        )
@@ -91,8 +94,8 @@ class ExpressionLearner:
        _, self.enable_learning, self.learning_intensity = global_config.expression.get_expression_config_for_chat(
            self.chat_id
        )
        self.min_messages_for_learning = 30 / self.learning_intensity  # minimum messages needed to trigger learning
        self.min_learning_interval = 300 / self.learning_intensity
        self.min_messages_for_learning = 15 / self.learning_intensity  # minimum messages needed to trigger learning
        self.min_learning_interval = 120 / self.learning_intensity

    def should_trigger_learning(self) -> bool:
        """
@@ -186,25 +189,13 @@ class ExpressionLearner:
            context,
            up_content,
        ) in learnt_expressions:
            # Check whether a similar expression already exists
            query = Expression.select().where(
                (Expression.chat_id == self.chat_id) & (Expression.situation == situation) & (Expression.style == style)
            await self._upsert_expression_record(
                situation=situation,
                style=style,
                context=context,
                up_content=up_content,
                current_time=current_time,
            )
            if query.exists():
                # The expression is identical; only refresh the timestamp
                expr_obj = query.get()
                expr_obj.last_active_time = current_time
                expr_obj.save()
            else:
                Expression.create(
                    situation=situation,
                    style=style,
                    last_active_time=current_time,
                    chat_id=self.chat_id,
                    create_date=current_time,  # set the creation date manually
                    context=context,
                    up_content=up_content,
                )

        return learnt_expressions
@@ -362,6 +353,10 @@ class ExpressionLearner:
            logger.error(f"学习表达方式失败,模型生成出错: {e}")
            return None
        expressions: List[Tuple[str, str]] = self.parse_expression_response(response)
        expressions = self._filter_self_reference_styles(expressions)
        if not expressions:
            logger.info("过滤后没有可用的表达方式(style 与机器人名称重复)")
            return None
        # logger.debug(f"学习{type_str}的response: {response}")

        # Trace the expressions back to their sources
@@ -433,6 +428,149 @@ class ExpressionLearner:
        expressions.append((situation, style))
        return expressions
    def _filter_self_reference_styles(self, expressions: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
        """
        Filter out expressions whose style duplicates the bot's name or nicknames
        """
        banned_names = set()
        bot_nickname = (global_config.bot.nickname or "").strip()
        if bot_nickname:
            banned_names.add(bot_nickname)

        alias_names = global_config.bot.alias_names or []
        for alias in alias_names:
            alias = alias.strip()
            if alias:
                banned_names.add(alias)

        banned_casefold = {name.casefold() for name in banned_names if name}

        filtered: List[Tuple[str, str]] = []
        removed_count = 0
        for situation, style in expressions:
            normalized_style = (style or "").strip()
            if normalized_style and normalized_style.casefold() not in banned_casefold:
                filtered.append((situation, style))
            else:
                removed_count += 1

        if removed_count:
            logger.debug(f"已过滤 {removed_count} 条style与机器人名称重复的表达方式")

        return filtered
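The method above drops learned styles that merely repeat the bot's own name, comparing with `str.casefold()` so the match is case-insensitive. A self-contained sketch with the config lookup factored out into an explicit parameter (`filter_self_reference_styles` and `banned_names` are illustrative names):

```python
from typing import List, Tuple


def filter_self_reference_styles(
    expressions: List[Tuple[str, str]], banned_names: List[str]
) -> List[Tuple[str, str]]:
    """Drop (situation, style) pairs whose style equals a banned name, case-insensitively."""
    banned = {name.strip().casefold() for name in banned_names if name and name.strip()}
    kept: List[Tuple[str, str]] = []
    for situation, style in expressions:
        normalized = (style or "").strip()
        # Empty styles are dropped too, matching the "else: removed" branch above.
        if normalized and normalized.casefold() not in banned:
            kept.append((situation, style))
    return kept
```

`casefold()` is preferred over `lower()` here because it also handles aggressive Unicode foldings (e.g. German "ß").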
    async def _upsert_expression_record(
        self,
        situation: str,
        style: str,
        context: str,
        up_content: str,
        current_time: float,
    ) -> None:
        expr_obj = Expression.select().where((Expression.chat_id == self.chat_id) & (Expression.style == style)).first()

        if expr_obj:
            await self._update_existing_expression(
                expr_obj=expr_obj,
                situation=situation,
                context=context,
                up_content=up_content,
                current_time=current_time,
            )
            return

        await self._create_expression_record(
            situation=situation,
            style=style,
            context=context,
            up_content=up_content,
            current_time=current_time,
        )

    async def _create_expression_record(
        self,
        situation: str,
        style: str,
        context: str,
        up_content: str,
        current_time: float,
    ) -> None:
        content_list = [situation]
        formatted_situation = await self._compose_situation_text(content_list, 1, situation)

        Expression.create(
            situation=formatted_situation,
            style=style,
            content_list=json.dumps(content_list, ensure_ascii=False),
            count=1,
            last_active_time=current_time,
            chat_id=self.chat_id,
            create_date=current_time,
            context=context,
            up_content=up_content,
        )

    async def _update_existing_expression(
        self,
        expr_obj: Expression,
        situation: str,
        context: str,
        up_content: str,
        current_time: float,
    ) -> None:
        content_list = self._parse_content_list(expr_obj.content_list)
        content_list.append(situation)

        expr_obj.content_list = json.dumps(content_list, ensure_ascii=False)
        expr_obj.count = (expr_obj.count or 0) + 1
        expr_obj.last_active_time = current_time
        expr_obj.context = context
        expr_obj.up_content = up_content

        new_situation = await self._compose_situation_text(
            content_list=content_list,
            count=expr_obj.count,
            fallback=expr_obj.situation,
        )
        expr_obj.situation = new_situation

        expr_obj.save()
    def _parse_content_list(self, stored_list: Optional[str]) -> List[str]:
        if not stored_list:
            return []
        try:
            data = json.loads(stored_list)
        except json.JSONDecodeError:
            return []
        return [str(item) for item in data if isinstance(item, str)] if isinstance(data, list) else []

    async def _compose_situation_text(self, content_list: List[str], count: int, fallback: str = "") -> str:
        sanitized = [c.strip() for c in content_list if c.strip()]
        summary = await self._summarize_situations(sanitized)
        if summary:
            return summary
        return "/".join(sanitized) if sanitized else fallback
    async def _summarize_situations(self, situations: List[str]) -> Optional[str]:
        if not situations:
            return None

        prompt = (
            "请阅读以下多个聊天情境描述,并将它们概括成一句简短的话,"
            "长度不超过20个字,保留共同特点:\n"
            f"{chr(10).join(f'- {s}' for s in situations[-10:])}\n只输出概括内容。"
        )

        try:
            summary, _ = await self.summary_model.generate_response_async(prompt, temperature=0.2)
            summary = summary.strip()
            if summary:
                return summary
        except Exception as e:
            logger.error(f"概括表达情境失败: {e}")
        return None

    def _build_bare_lines(self, messages: List) -> List[Tuple[int, str]]:
        """
        Build a condensed text line for each message, keeping a mapping back to the original message index
@@ -42,8 +42,6 @@ def init_prompt():
    Prompt(expression_evaluation_prompt, "expression_evaluation_prompt")


class ExpressionSelector:
    def __init__(self):
        self.llm_model = LLMRequest(
@@ -139,6 +137,7 @@ class ExpressionSelector:
                "last_active_time": expr.last_active_time,
                "source_id": expr.chat_id,
                "create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
                "count": expr.count if getattr(expr, "count", None) is not None else 1,
            }
            for expr in style_query
        ]
@@ -237,9 +236,9 @@ class ExpressionSelector:
        else:
            target_message_str = ""
            target_message_extra_block = ""

        chat_context = f"以下是正在进行的聊天内容:{chat_info}"

        # Build the reply_reason block
        if reply_reason:
            reply_reason_block = f"你的回复理由是:{reply_reason}"
@@ -261,9 +260,8 @@ class ExpressionSelector:
        # 4. Call the LLM
        content, (reasoning_content, model_name, _) = await self.llm_model.generate_response_async(prompt=prompt)

        # print(prompt)

        if not content:
            logger.warning("LLM返回空结果")
            return [], []
@@ -23,6 +23,26 @@ from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
logger = get_logger("jargon")


def _contains_bot_self_name(content: str) -> bool:
    """
    Check whether an entry contains the bot's nickname or an alias
    """
    if not content:
        return False

    bot_config = getattr(global_config, "bot", None)
    if not bot_config:
        return False

    target = content.strip().lower()
    nickname = str(getattr(bot_config, "nickname", "") or "").strip().lower()
    alias_names = [str(alias or "").strip().lower() for alias in getattr(bot_config, "alias_names", []) or []]

    candidates = [name for name in [nickname, *alias_names] if name]

    return any(name in target for name in candidates if target)
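The check above is a lowercase substring match of the bot's nickname and aliases against the candidate term. A self-contained sketch with the config lookup replaced by explicit parameters (`contains_bot_self_name`, `nickname`, and `alias_names` here are plain arguments, not the global config):

```python
from typing import List


def contains_bot_self_name(content: str, nickname: str, alias_names: List[str]) -> bool:
    """True when the candidate term contains the nickname or any alias,
    compared case-insensitively as a substring."""
    target = (content or "").strip().lower()
    if not target:
        return False
    candidates = [n.strip().lower() for n in [nickname, *alias_names] if n and n.strip()]
    return any(name in target for name in candidates)
```

Substring matching (rather than equality) is deliberate: it also rejects compound terms that embed the bot's name.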
def _init_prompt() -> None:
    prompt_str = """
**聊天内容,其中的SELF是你自己的发言**
@@ -126,7 +146,7 @@ async def _enrich_raw_content_if_needed(
) -> List[str]:
    """
    Check whether raw_content contains only the jargon term itself; if so, fetch the three messages before it as the raw content

    Args:
        content: the jargon term
        raw_content_list: the original raw_content list
@@ -134,22 +154,22 @@ async def _enrich_raw_content_if_needed(
        messages: messages within the current time window
        extraction_start_time: extraction start time
        extraction_end_time: extraction end time

    Returns:
        The processed raw_content list
    """
    enriched_list = []

    for raw_content in raw_content_list:
        # Check whether raw_content contains only the jargon term (compare after stripping whitespace)
        raw_content_clean = raw_content.strip()
        content_clean = content.strip()

        # If raw_content contains only the jargon term (possibly with some punctuation or whitespace), try to fetch context
        # Compare after removing all whitespace to make sure only the jargon term remains
        raw_content_normalized = raw_content_clean.replace(" ", "").replace("\n", "").replace("\t", "")
        content_normalized = content_clean.replace(" ", "").replace("\n", "").replace("\t", "")

        if raw_content_normalized == content_normalized:
            # Look in the message list for a message that contains only this jargon term (compare after stripping whitespace)
            target_message = None
@@ -160,22 +180,20 @@ async def _enrich_raw_content_if_needed(
                if msg_content_normalized == content_normalized:
                    target_message = msg
                    break

            if target_message and target_message.time:
                # Fetch the three messages before this one
                try:
                    previous_messages = get_raw_msg_before_timestamp_with_chat(
                        chat_id=chat_id,
                        timestamp=target_message.time,
                        limit=3
                        chat_id=chat_id, timestamp=target_message.time, limit=3
                    )

                    if previous_messages:
                        # Format the three previous messages together with the current one
                        context_messages = previous_messages + [target_message]
                        # Sort by time
                        context_messages.sort(key=lambda x: x.time or 0)

                        # Format into readable messages
                        formatted_context, _ = await build_readable_messages_with_list(
                            context_messages,
@@ -183,7 +201,7 @@ async def _enrich_raw_content_if_needed(
                            timestamp_mode="relative",
                            truncate=False,
                        )

                        if formatted_context.strip():
                            enriched_list.append(formatted_context.strip())
                            logger.warning(f"为黑话 {content} 补充了上下文消息")
@@ -203,7 +221,7 @@ async def _enrich_raw_content_if_needed(
        else:
            # raw_content has more content; use it directly
            enriched_list.append(raw_content)

    return enriched_list
@@ -217,31 +235,31 @@ def _should_infer_meaning(jargon_obj: Jargon) -> bool:
    # If all inference rounds are done, stop inferring
    if jargon_obj.is_complete:
        return False

    count = jargon_obj.count or 0
    last_inference = jargon_obj.last_inference_count or 0

    # Threshold list: 3, 6, 10, 20, 40, 60, 100
    thresholds = [3,6, 10, 20, 40, 60, 100]
    thresholds = [3, 6, 10, 20, 40, 60, 100]

    if count < thresholds[0]:
        return False

    # If count has not passed the last inference value, no new inference is needed
    if count <= last_inference:
        return False

    # Find the first threshold greater than last_inference
    next_threshold = None
    for threshold in thresholds:
        if threshold > last_inference:
            next_threshold = threshold
            break

    # If no next threshold exists, count is already past 100; do not infer again
    if next_threshold is None:
        return False

    # Check whether count has reached or passed this threshold
    return count >= next_threshold
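The trigger above re-runs inference only when a jargon term's count first crosses the next threshold beyond the count it had at the previous inference. The schedule can be sketched as a pure function (`should_infer` and its parameters are illustrative stand-ins for the Jargon model fields):

```python
from typing import Optional

# Re-infer at these cumulative counts; past 100 the entry is considered settled.
THRESHOLDS = [3, 6, 10, 20, 40, 60, 100]


def should_infer(count: int, last_inference: int, is_complete: bool = False) -> bool:
    """Infer only when count first reaches the next threshold above last_inference."""
    if is_complete or count < THRESHOLDS[0] or count <= last_inference:
        return False
    next_threshold: Optional[int] = next((t for t in THRESHOLDS if t > last_inference), None)
    if next_threshold is None:
        # last_inference is already >= 100: never infer again
        return False
    return count >= next_threshold
```

Recording `last_inference` in the database (as the diff does) is what makes this schedule restart-safe: after a restart, counts below the next threshold do not re-trigger.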
@@ -251,14 +269,14 @@ class JargonMiner:
        self.chat_id = chat_id
        self.last_learning_time: float = time.time()
        # Frequency control; adjust as needed
        self.min_messages_for_learning: int = 15
        self.min_learning_interval: float = 20
        self.min_messages_for_learning: int = 10
        self.min_learning_interval: float = 20

        self.llm = LLMRequest(
            model_set=model_config.model_task_config.utils,
            request_type="jargon.extract",
        )

        # Initialize stream_name as an instance attribute to avoid repeated extraction
        chat_manager = get_chat_manager()
        stream_name = chat_manager.get_stream_name(self.chat_id)
@@ -283,17 +301,19 @@ class JargonMiner:
        try:
            content = jargon_obj.content
            raw_content_str = jargon_obj.raw_content or ""

            # Parse the raw_content list
            raw_content_list = []
            if raw_content_str:
                try:
                    raw_content_list = json.loads(raw_content_str) if isinstance(raw_content_str, str) else raw_content_str
                    raw_content_list = (
                        json.loads(raw_content_str) if isinstance(raw_content_str, str) else raw_content_str
                    )
                    if not isinstance(raw_content_list, list):
                        raw_content_list = [raw_content_list] if raw_content_list else []
                except (json.JSONDecodeError, TypeError):
                    raw_content_list = [raw_content_str] if raw_content_str else []

            if not raw_content_list:
                logger.warning(f"jargon {content} 没有raw_content,跳过推断")
                return
@@ -305,12 +325,12 @@ class JargonMiner:
                content=content,
                raw_content_list=raw_content_text,
            )

            response1, _ = await self.llm.generate_response_async(prompt1, temperature=0.3)
            if not response1:
                logger.warning(f"jargon {content} 推断1失败:无响应")
                return

            # Parse the result of inference 1
            inference1 = None
            try:
@@ -326,7 +346,7 @@ class JargonMiner:
            except Exception as e:
                logger.error(f"jargon {content} 推断1解析失败: {e}")
                return

            # Check whether inference 1 reports insufficient information
            no_info = inference1.get("no_info", False)
            meaning1 = inference1.get("meaning", "").strip()
@@ -337,18 +357,17 @@ class JargonMiner:
                jargon_obj.save()
                return

            # Step 2: infer based on content alone
            prompt2 = await global_prompt_manager.format_prompt(
                "jargon_inference_content_only_prompt",
                content=content,
            )

            response2, _ = await self.llm.generate_response_async(prompt2, temperature=0.3)
            if not response2:
                logger.warning(f"jargon {content} 推断2失败:无响应")
                return

            # Parse the result of inference 2
            inference2 = None
            try:
@@ -364,13 +383,12 @@ class JargonMiner:
            except Exception as e:
                logger.error(f"jargon {content} 推断2解析失败: {e}")
                return

            logger.info(f"jargon {content} 推断2提示词: {prompt2}")
            logger.info(f"jargon {content} 推断2结果: {response2}")
            logger.info(f"jargon {content} 推断1提示词: {prompt1}")
            logger.info(f"jargon {content} 推断1结果: {response1}")

            # logger.info(f"jargon {content} 推断2提示词: {prompt2}")
            # logger.info(f"jargon {content} 推断2结果: {response2}")
            # logger.info(f"jargon {content} 推断1提示词: {prompt1}")
            # logger.info(f"jargon {content} 推断1结果: {response1}")

            if global_config.debug.show_jargon_prompt:
                logger.info(f"jargon {content} 推断2提示词: {prompt2}")
                logger.info(f"jargon {content} 推断2结果: {response2}")
@@ -381,22 +399,22 @@ class JargonMiner:
                logger.debug(f"jargon {content} 推断2结果: {response2}")
                logger.debug(f"jargon {content} 推断1提示词: {prompt1}")
                logger.debug(f"jargon {content} 推断1结果: {response1}")

            # Step 3: compare the two inference results
            prompt3 = await global_prompt_manager.format_prompt(
                "jargon_compare_inference_prompt",
                inference1=json.dumps(inference1, ensure_ascii=False),
                inference2=json.dumps(inference2, ensure_ascii=False),
            )

            if global_config.debug.show_jargon_prompt:
                logger.info(f"jargon {content} 比较提示词: {prompt3}")

            response3, _ = await self.llm.generate_response_async(prompt3, temperature=0.3)
            if not response3:
                logger.warning(f"jargon {content} 比较失败:无响应")
                return

            # Parse the comparison result
            comparison = None
            try:
@@ -416,7 +434,7 @@ class JargonMiner:
            # Decide whether it is jargon
            is_similar = comparison.get("is_similar", False)
            is_jargon = not is_similar  # similar means it is not jargon; a difference means it is

            # Update the database record
            jargon_obj.is_jargon = is_jargon
            if is_jargon:
@@ -425,33 +443,36 @@ class JargonMiner:
            else:
                # Not jargon; record the meaning anyway (from inference 2, since its meaning is explicit)
                jargon_obj.meaning = inference2.get("meaning", "")

            # Record the count at the last inference to avoid re-inferring after a restart
            jargon_obj.last_inference_count = jargon_obj.count or 0

            # If count >= 100, mark as complete; no further inference
            if (jargon_obj.count or 0) >= 100:
                jargon_obj.is_complete = True

            jargon_obj.save()
            logger.info(f"jargon {content} 推断完成: is_jargon={is_jargon}, meaning={jargon_obj.meaning}, last_inference_count={jargon_obj.last_inference_count}, is_complete={jargon_obj.is_complete}")

            logger.debug(
                f"jargon {content} 推断完成: is_jargon={is_jargon}, meaning={jargon_obj.meaning}, last_inference_count={jargon_obj.last_inference_count}, is_complete={jargon_obj.is_complete}"
            )

            # Always log the inference result in a readable format
            if is_jargon:
                # Jargon: log as "[chat name]xxx means xxx"
                meaning = jargon_obj.meaning or "无详细说明"
                is_global = jargon_obj.is_global
                if is_global:
                    logger.info(f"[通用黑话]{content}的含义是 {meaning}")
                    logger.info(f"[黑话]{content}的含义是 {meaning}")
                else:
                    logger.info(f"[{self.stream_name}]{content}的含义是 {meaning}")
            else:
                # Not jargon: log as "[chat name]xxx is not jargon"
                logger.info(f"[{self.stream_name}]{content} 不是黑话")

        except Exception as e:
            logger.error(f"jargon推断失败: {e}")
            import traceback

            traceback.print_exc()

    def should_trigger(self) -> bool:
@@ -479,7 +500,7 @@ class JargonMiner:
        # Record this extraction's time window to avoid re-extracting
        extraction_start_time = self.last_learning_time
        extraction_end_time = time.time()

        # Pull the messages within the learning window
        messages = get_raw_msg_by_timestamp_with_chat_inclusive(
            chat_id=self.chat_id,
@@ -502,7 +523,7 @@ class JargonMiner:
        response, _ = await self.llm.generate_response_async(prompt, temperature=0.2)
        if not response:
            return

        if global_config.debug.show_jargon_prompt:
            logger.info(f"jargon提取提示词: {prompt}")
            logger.info(f"jargon提取结果: {response}")
@@ -532,7 +553,7 @@ class JargonMiner:
                    continue
                content = str(item.get("content", "")).strip()
                raw_content_value = item.get("raw_content", "")

                # Handle raw_content: it may be a string or a list
                raw_content_list = []
                if isinstance(raw_content_value, list):
@@ -543,12 +564,12 @@ class JargonMiner:
                    raw_content_str = raw_content_value.strip()
                    if raw_content_str:
                        raw_content_list = [raw_content_str]

                if content and raw_content_list:
                    entries.append({
                        "content": content,
                        "raw_content": raw_content_list
                    })
                    if _contains_bot_self_name(content):
                        logger.debug(f"解析阶段跳过包含机器人昵称/别名的词条: {content}")
                        continue
                    entries.append({"content": content, "raw_content": raw_content_list})
        except Exception as e:
            logger.error(f"解析jargon JSON失败: {e}; 原始: {response}")
            return
@@ -565,13 +586,13 @@ class JargonMiner:
            if content_key not in seen:
                seen.add(content_key)
                uniq_entries.append(entry)

        saved = 0
        updated = 0
        for entry in uniq_entries:
            content = entry["content"]
            raw_content_list = entry["raw_content"]  # already a list

            # Check and enrich raw_content: if it contains only the jargon term, fetch the previous three messages as context
            raw_content_list = await _enrich_raw_content_if_needed(
                content=content,
@@ -581,60 +602,53 @@ class JargonMiner:
                extraction_start_time=extraction_start_time,
                extraction_end_time=extraction_end_time,
            )

            try:
                # Decide the query logic based on the all_global setting
                if global_config.jargon.all_global:
                    # all_global on: ignore chat_id and query every record whose content matches (all records are global)
                    query = (
                        Jargon.select()
                        .where(Jargon.content == content)
                    )
                    query = Jargon.select().where(Jargon.content == content)
                else:
                    # all_global off: query only records whose chat_id matches (is_global is ignored)
                    query = (
                        Jargon.select()
                        .where(
                            (Jargon.chat_id == self.chat_id) &
                            (Jargon.content == content)
                        )
                    )

                    query = Jargon.select().where((Jargon.chat_id == self.chat_id) & (Jargon.content == content))

                if query.exists():
                    obj = query.get()
                    try:
                        obj.count = (obj.count or 0) + 1
                    except Exception:
                        obj.count = 1

                    # Merge the raw_content lists: read the existing list, append the new values, deduplicate
                    existing_raw_content = []
                    if obj.raw_content:
                        try:
                            existing_raw_content = json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content
                            existing_raw_content = (
                                json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content
                            )
                            if not isinstance(existing_raw_content, list):
                                existing_raw_content = [existing_raw_content] if existing_raw_content else []
                        except (json.JSONDecodeError, TypeError):
                            existing_raw_content = [obj.raw_content] if obj.raw_content else []

                    # Merge and deduplicate
                    merged_list = list(dict.fromkeys(existing_raw_content + raw_content_list))
                    obj.raw_content = json.dumps(merged_list, ensure_ascii=False)

                    # With all_global on, make sure the record is marked is_global=True
                    if global_config.jargon.all_global:
                        obj.is_global = True
                    # With all_global off, keep the existing is_global unchanged

                    obj.save()

                    # Check whether inference is needed (threshold reached and past the last inference value)
                    if _should_infer_meaning(obj):
                        # Trigger inference asynchronously without blocking the main flow
                        # Reload the object to make sure the data is fresh
                        jargon_id = obj.id
                        asyncio.create_task(self._infer_meaning_by_id(jargon_id))

                    updated += 1
                else:
                    # No matching record found; create a new one
@@ -644,13 +658,13 @@ class JargonMiner:
                else:
                    # all_global off: new records get is_global=False
                    is_global_new = False

                Jargon.create(
                    content=content,
                    raw_content=json.dumps(raw_content_list, ensure_ascii=False),
                    chat_id=self.chat_id,
                    is_global=is_global_new,
                    count=1
                    count=1,
                )
                saved += 1
            except Exception as e:
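The update branch above merges a stored JSON list with freshly extracted values and dedupes via `dict.fromkeys`, which preserves insertion order. That parse-merge-dedupe step can be sketched independently of the ORM (`merge_raw_content` is an illustrative name; the tolerant parsing mirrors the try/except in the diff):

```python
import json
from typing import List, Optional, Union


def merge_raw_content(stored: Optional[Union[str, List[str]]], new_items: List[str]) -> List[str]:
    """Parse the stored JSON list (tolerating bad data), append new items,
    and deduplicate while preserving first-seen order."""
    existing: List[str] = []
    if stored:
        try:
            parsed = json.loads(stored) if isinstance(stored, str) else stored
            # A scalar that parsed fine is wrapped into a one-element list.
            existing = parsed if isinstance(parsed, list) else [parsed]
        except (json.JSONDecodeError, TypeError):
            existing = [stored] if isinstance(stored, str) else []
    return list(dict.fromkeys(existing + new_items))
```

`dict.fromkeys` is the idiomatic ordered-dedupe in Python 3.7+, where dicts preserve insertion order.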
@@ -662,13 +676,13 @@ class JargonMiner:
            # Collect all extracted jargon terms
            jargon_list = [entry["content"] for entry in uniq_entries]
            jargon_str = ",".join(jargon_list)

            # Log the formatted result (logger.info automatically applies the jargon module's colors)
            logger.info(f"[{self.stream_name}]疑似黑话: {jargon_str}")

            # Advance to this extraction's end time so the same message window is never re-extracted
            self.last_learning_time = extraction_end_time

            if saved or updated:
                logger.info(f"jargon写入: 新增 {saved} 条,更新 {updated} 条,chat_id={self.chat_id}")
        except Exception as e:
@@ -694,15 +708,11 @@ async def extract_and_store_jargon(chat_id: str) -> None:


def search_jargon(
    keyword: str,
    chat_id: Optional[str] = None,
    limit: int = 10,
    case_sensitive: bool = False,
    fuzzy: bool = True
    keyword: str, chat_id: Optional[str] = None, limit: int = 10, case_sensitive: bool = False, fuzzy: bool = True
) -> List[Dict[str, str]]:
    """
    Search jargon, with support for case-insensitive and fuzzy search

    Args:
        keyword: the search keyword
        chat_id: optional chat ID
@@ -711,21 +721,18 @@ def search_jargon(
        limit: maximum number of results, default 10
        case_sensitive: whether the search is case-sensitive, default False (insensitive)
        fuzzy: whether to use fuzzy search, default True (LIKE matching)

    Returns:
        List[Dict[str, str]]: a list of dicts containing content and meaning
    """
    if not keyword or not keyword.strip():
        return []

    keyword = keyword.strip()

    # Build the query
    query = Jargon.select(
        Jargon.content,
        Jargon.meaning
    )

    query = Jargon.select(Jargon.content, Jargon.meaning)

    # Build the search condition
    if case_sensitive:
        # Case-sensitive
@@ -734,7 +741,7 @@ def search_jargon(
            search_condition = Jargon.content.contains(keyword)
        else:
            # Exact match
            search_condition = (Jargon.content == keyword)
            search_condition = Jargon.content == keyword
    else:
        # Case-insensitive
        if fuzzy:
@@ -742,10 +749,10 @@ def search_jargon(
            search_condition = fn.LOWER(Jargon.content).contains(keyword.lower())
        else:
            # Exact match (using the LOWER function)
            search_condition = (fn.LOWER(Jargon.content) == keyword.lower())

            search_condition = fn.LOWER(Jargon.content) == keyword.lower()

    query = query.where(search_condition)

    # Decide the query logic based on the all_global setting
    if global_config.jargon.all_global:
        # all_global on: all records are global; query every record with is_global=True (ignoring chat_id)
@@ -753,35 +760,28 @@ def search_jargon(
|
||||
else:
|
||||
# 关闭all_global:如果提供了chat_id,优先搜索该聊天或global的jargon
|
||||
if chat_id:
|
||||
query = query.where(
|
||||
(Jargon.chat_id == chat_id) | Jargon.is_global
|
||||
)
|
||||
|
||||
query = query.where((Jargon.chat_id == chat_id) | Jargon.is_global)
|
||||
|
||||
# 只返回有meaning的记录
|
||||
query = query.where(
|
||||
(Jargon.meaning.is_null(False)) & (Jargon.meaning != "")
|
||||
)
|
||||
|
||||
query = query.where((Jargon.meaning.is_null(False)) & (Jargon.meaning != ""))
|
||||
|
||||
# 按count降序排序,优先返回出现频率高的
|
||||
query = query.order_by(Jargon.count.desc())
|
||||
|
||||
|
||||
# 限制结果数量
|
||||
query = query.limit(limit)
|
||||
|
||||
|
||||
# 执行查询并返回结果
|
||||
results = []
|
||||
for jargon in query:
|
||||
results.append({
|
||||
"content": jargon.content or "",
|
||||
"meaning": jargon.meaning or ""
|
||||
})
|
||||
|
||||
results.append({"content": jargon.content or "", "meaning": jargon.meaning or ""})
|
||||
|
||||
return results
|
||||
|
||||
|
||||
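The condition-building above (fuzzy vs. exact, case-sensitive vs. not, meaning required, sorted by count) can be mirrored over a plain in-memory list. This is a hypothetical sketch with no database or Peewee involved; the entry dicts and function name are illustrative only:

```python
def search_jargon_local(entries, keyword, fuzzy=True, case_sensitive=False, limit=10):
    """Mirror of the search_jargon condition logic over an in-memory list of dicts."""
    keyword = (keyword or "").strip()
    if not keyword:
        return []

    def norm(s):
        return s if case_sensitive else s.lower()

    kw = norm(keyword)
    hits = []
    for entry in entries:
        if not entry.get("meaning"):  # only records that have a meaning
            continue
        content = norm(entry.get("content") or "")
        matched = kw in content if fuzzy else content == kw
        if matched:
            hits.append(entry)
    hits.sort(key=lambda e: e.get("count", 0), reverse=True)  # most frequent first
    return [{"content": e["content"], "meaning": e["meaning"]} for e in hits[:limit]]
```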
async def store_jargon_from_answer(jargon_keyword: str, answer: str, chat_id: str) -> None:
    """Store jargon into the jargon system

    Args:
        jargon_keyword: the jargon keyword
        answer: the answer content (to be summarized into raw_content)
@@ -794,53 +794,52 @@ async def store_jargon_from_answer(jargon_keyword: str, answer: str, chat_id: st
答案:{answer}

只输出概括后的内容,不要输出其他内容:"""

        success, summary, _, _ = await llm_api.generate_with_model(
            summary_prompt,
            model_config=model_config.model_task_config.utils_small,
            request_type="memory.summarize_jargon",
        )

        logger.info(f"概括答案提示: {summary_prompt}")
        logger.info(f"概括答案: {summary}")

        if not success:
            logger.warning(f"概括答案失败,使用原始答案: {summary}")
            summary = answer[:100]  # Fall back to the first 100 characters

        raw_content = summary.strip()[:200]  # Limit the length

        # Check whether the record already exists
        if global_config.jargon.all_global:
            query = Jargon.select().where(Jargon.content == jargon_keyword)
        else:
            query = Jargon.select().where(
                (Jargon.chat_id == chat_id) &
                (Jargon.content == jargon_keyword)
            )
            query = Jargon.select().where((Jargon.chat_id == chat_id) & (Jargon.content == jargon_keyword))

        if query.exists():
            # Update the existing record
            obj = query.get()
            obj.count = (obj.count or 0) + 1

            # Merge the raw_content lists
            existing_raw_content = []
            if obj.raw_content:
                try:
                    existing_raw_content = json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content
                    existing_raw_content = (
                        json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content
                    )
                    if not isinstance(existing_raw_content, list):
                        existing_raw_content = [existing_raw_content] if existing_raw_content else []
                except (json.JSONDecodeError, TypeError):
                    existing_raw_content = [obj.raw_content] if obj.raw_content else []

            # Merge and deduplicate
            merged_list = list(dict.fromkeys(existing_raw_content + [raw_content]))
            obj.raw_content = json.dumps(merged_list, ensure_ascii=False)

            if global_config.jargon.all_global:
                obj.is_global = True

            obj.save()
            logger.info(f"更新jargon记录: {jargon_keyword}")
        else:
@@ -851,11 +850,9 @@ async def store_jargon_from_answer(jargon_keyword: str, answer: str, chat_id: st
                raw_content=json.dumps([raw_content], ensure_ascii=False),
                chat_id=chat_id,
                is_global=is_global_new,
                count=1
                count=1,
            )
            logger.info(f"创建新jargon记录: {jargon_keyword}")

    except Exception as e:
        logger.error(f"存储jargon失败: {e}")
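The merge step above leans on `dict.fromkeys` to deduplicate while preserving insertion order (dicts keep insertion order in Python 3.7+). A minimal standalone sketch of that merge, with the same defensive JSON handling (function name is illustrative):

```python
import json

def merge_raw_content(existing_json, new_item):
    """Merge a new item into a JSON-encoded list, deduplicating while keeping order."""
    try:
        existing = json.loads(existing_json) if isinstance(existing_json, str) else existing_json
        if not isinstance(existing, list):
            existing = [existing] if existing else []
    except (json.JSONDecodeError, TypeError):
        existing = [existing_json] if existing_json else []
    # dict.fromkeys keeps the first occurrence of each item, in order
    merged = list(dict.fromkeys(existing + [new_item]))
    return json.dumps(merged, ensure_ascii=False)
```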
@@ -147,7 +147,7 @@ def _convert_tool_options(tool_options: list[ToolOption]) -> list[FunctionDeclar
        param_type_value = tool_option_param.param_type.value
        if param_type_value == "bool":
            param_type_value = "boolean"

        return_dict: dict[str, Any] = {
            "type": param_type_value,
            "description": tool_option_param.description,

@@ -122,7 +122,7 @@ def _convert_tool_options(tool_options: list[ToolOption]) -> list[dict[str, Any]
        param_type_value = tool_option_param.param_type.value
        if param_type_value == "bool":
            param_type_value = "boolean"

        return_dict: dict[str, Any] = {
            "type": param_type_value,
            "description": tool_option_param.description,

@@ -116,9 +116,7 @@ class MessageBuilder:
        Build the message object
        :return: the Message object
        """
        if len(self.__content) == 0 and not (
            self.__role == RoleType.Assistant and self.__tool_calls
        ):
        if len(self.__content) == 0 and not (self.__role == RoleType.Assistant and self.__tool_calls):
            raise ValueError("内容不能为空")
        if self.__role == RoleType.Tool and self.__tool_call_id is None:
            raise ValueError("Tool角色的工具调用ID不能为空")

@@ -166,7 +166,7 @@ class LLMRequest:
            time_cost=time.time() - start_time,
        )
        return content or "", (reasoning_content, model_info.name, tool_calls)

    async def generate_response_with_message_async(
        self,
        message_factory: Callable[[BaseClient], List[Message]],
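The `_convert_tool_options` hunks rewrite `"bool"` to `"boolean"` because JSON Schema (the format function-calling APIs expect for parameter types) has no `"bool"` type. The normalization step can be isolated as a one-liner; the function name here is illustrative:

```python
def normalize_param_type(param_type: str) -> str:
    """JSON Schema has no "bool" type, so map it to "boolean"; pass other types through."""
    return "boolean" if param_type == "bool" else param_type
```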
48 src/main.py
@@ -36,37 +36,39 @@ class MainSystem:
        # Use the message API instead of a direct FastAPI instance
        self.app: MessageServer = get_global_api()
        self.server: Server = get_global_server()

        # Register the WebUI API routes
        self._register_webui_routes()

        # Set up the WebUI (development/production mode)
        self._setup_webui()
        self.webui_server = None  # Standalone WebUI server

    def _register_webui_routes(self):
        """Register the WebUI API routes"""
        try:
            from src.webui.routes import router as webui_router
            self.server.register_router(webui_router)
            logger.info("WebUI API 路由已注册")
        except Exception as e:
            logger.warning(f"注册 WebUI API 路由失败: {e}")
        # Set up the standalone WebUI server
        self._setup_webui_server()

    def _setup_webui(self):
        """Set up the WebUI (mode chosen via environment variables)"""
    def _setup_webui_server(self):
        """Set up the standalone WebUI server"""
        import os

        webui_enabled = os.getenv("WEBUI_ENABLED", "false").lower() == "true"
        if not webui_enabled:
            logger.info("WebUI 已禁用")
            return

        webui_mode = os.getenv("WEBUI_MODE", "production").lower()

        try:
            from src.webui.manager import setup_webui
            setup_webui(mode=webui_mode)
            from src.webui.webui_server import get_webui_server

            self.webui_server = get_webui_server()

            if webui_mode == "development":
                logger.info("📝 WebUI 开发模式已启用")
                logger.info("🌐 后端 API 将运行在 http://0.0.0.0:8001")
                logger.info("💡 请手动启动前端开发服务器: cd MaiBot-Dashboard && bun dev")
                logger.info("💡 前端将运行在 http://localhost:7999")
            else:
                logger.info("✅ WebUI 生产模式已启用")
                logger.info("🌐 WebUI 将运行在 http://0.0.0.0:8001")
                logger.info("💡 请确保已构建前端: cd MaiBot-Dashboard && bun run build")

        except Exception as e:
            logger.error(f"设置 WebUI 失败: {e}")
            logger.error(f"❌ 初始化 WebUI 服务器失败: {e}")

    async def initialize(self):
        """Initialize system components"""
@@ -161,6 +163,10 @@ class MainSystem:
            self.server.run(),
        ]

        # If the WebUI server has been initialized, add it to the task list
        if self.webui_server:
            tasks.append(self.webui_server.start())

        await asyncio.gather(*tasks)

    # async def forget_memory_task(self):
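The run loop above conditionally appends the WebUI server coroutine before awaiting everything with `asyncio.gather`, which runs the tasks concurrently and returns results in submission order. A minimal self-contained sketch of that pattern (the component names and delays are stand-ins, not the project's real coroutines):

```python
import asyncio

async def run_component(name: str, delay: float) -> str:
    # Stand-in for a long-running server coroutine
    await asyncio.sleep(delay)
    return name

async def main(webui_enabled: bool):
    tasks = [run_component("server", 0.01)]
    if webui_enabled:  # mirrors the `if self.webui_server:` check
        tasks.append(run_component("webui", 0.01))
    # gather preserves the order tasks were passed in
    return await asyncio.gather(*tasks)

results = asyncio.run(main(webui_enabled=True))
```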
File diff suppressed because it is too large
@@ -3,6 +3,7 @@
Memory system utility functions
Includes fuzzy lookup, similarity computation, and other helpers
"""

import json
import re
from datetime import datetime

@@ -14,6 +15,7 @@ from src.common.logger import get_logger

logger = get_logger("memory_utils")


def parse_md_json(json_text: str) -> list[str]:
    """Extract JSON objects and reasoning content from Markdown-formatted text"""
    json_objects = []
@@ -52,14 +54,15 @@ def parse_md_json(json_text: str) -> list[str]:

    return json_objects, reasoning_content


def calculate_similarity(text1: str, text2: str) -> float:
    """
    Compute the similarity of two texts

    Args:
        text1: the first text
        text2: the second text

    Returns:
        float: similarity score (0-1)
    """
@@ -67,16 +70,16 @@ def calculate_similarity(text1: str, text2: str) -> float:
        # Preprocess the texts
        text1 = preprocess_text(text1)
        text2 = preprocess_text(text2)

        # Compute similarity with SequenceMatcher
        similarity = SequenceMatcher(None, text1, text2).ratio()

        # If one text contains the other, boost the similarity
        if text1 in text2 or text2 in text1:
            similarity = max(similarity, 0.8)

        return similarity

    except Exception as e:
        logger.error(f"计算相似度时出错: {e}")
        return 0.0
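The similarity logic combines `difflib.SequenceMatcher`'s ratio (2×matches / total length) with a containment floor of 0.8. The core of it, stripped of preprocessing and error handling, looks like this:

```python
from difflib import SequenceMatcher

def similarity(text1: str, text2: str) -> float:
    """Ratio-based similarity with a containment boost, as in calculate_similarity."""
    score = SequenceMatcher(None, text1, text2).ratio()
    # If one text contains the other, guarantee at least 0.8
    if text1 in text2 or text2 in text1:
        score = max(score, 0.8)
    return score
```

Note the floor matters: "apple" vs. "apple pie" has a raw ratio of about 0.71, but the containment check lifts it to 0.8.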
@@ -85,31 +88,30 @@ def calculate_similarity(text1: str, text2: str) -> float:
def preprocess_text(text: str) -> str:
    """
    Preprocess text to improve match accuracy

    Args:
        text: the raw text

    Returns:
        str: the preprocessed text
    """
    try:
        # Convert to lowercase
        text = text.lower()

        # Remove punctuation and special characters
        text = re.sub(r'[^\w\s]', '', text)
        text = re.sub(r"[^\w\s]", "", text)

        # Collapse extra whitespace
        text = re.sub(r'\s+', ' ', text).strip()
        text = re.sub(r"\s+", " ", text).strip()

        return text

    except Exception as e:
        logger.error(f"预处理文本时出错: {e}")
        return text
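The two regex passes above normalize text before matching: `[^\w\s]` drops anything that is neither a word character nor whitespace, and `\s+` collapses whitespace runs. A compact standalone version:

```python
import re

def preprocess(text: str) -> str:
    """Lowercase, strip punctuation, and collapse whitespace, as in preprocess_text."""
    text = text.lower()
    text = re.sub(r"[^\w\s]", "", text)       # remove punctuation and special characters
    return re.sub(r"\s+", " ", text).strip()  # collapse runs of whitespace
```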
def parse_datetime_to_timestamp(value: str) -> float:
    """
    Accept several common formats and convert them to a timestamp (seconds)
@@ -143,25 +145,24 @@ def parse_datetime_to_timestamp(value: str) -> float:
def parse_time_range(time_range: str) -> Tuple[float, float]:
    """
    Parse a time-range string and return the start and end timestamps

    Args:
        time_range: time-range string in the format "YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS"

    Returns:
        Tuple[float, float]: (start timestamp, end timestamp)
    """
    if " - " not in time_range:
        raise ValueError(f"时间范围格式错误,应为 '开始时间 - 结束时间': {time_range}")

    parts = time_range.split(" - ", 1)
    if len(parts) != 2:
        raise ValueError(f"时间范围格式错误: {time_range}")

    start_str = parts[0].strip()
    end_str = parts[1].strip()

    start_timestamp = parse_datetime_to_timestamp(start_str)
    end_timestamp = parse_datetime_to_timestamp(end_str)

    return start_timestamp, end_timestamp
    return start_timestamp, end_timestamp
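The range parser splits once on `" - "` (with surrounding spaces, so the hyphens inside each date are untouched) and converts each side. A self-contained sketch that substitutes `datetime.strptime` for the project's multi-format `parse_datetime_to_timestamp` (which accepts more formats than shown here):

```python
from datetime import datetime
from typing import Tuple

def parse_range(time_range: str) -> Tuple[float, float]:
    """Split "start - end" and convert each side; strptime stands in for parse_datetime_to_timestamp."""
    if " - " not in time_range:
        raise ValueError(f"bad time range: {time_range}")
    # maxsplit=1 so a stray " - " in the end part cannot produce extra pieces
    start_str, end_str = (part.strip() for part in time_range.split(" - ", 1))
    fmt = "%Y-%m-%d %H:%M:%S"
    return (
        datetime.strptime(start_str, fmt).timestamp(),
        datetime.strptime(end_str, fmt).timestamp(),
    )
```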
@@ -17,6 +17,7 @@ from .query_lpmm_knowledge import register_tool as register_lpmm_knowledge
from .query_person_info import register_tool as register_query_person_info
from src.config.config import global_config


def init_all_tools():
    """Initialize and register all memory-retrieval tools"""
    register_query_jargon()
@@ -15,13 +15,10 @@ logger = get_logger("memory_retrieval_tools")


async def query_chat_history(
    chat_id: str,
    keyword: Optional[str] = None,
    time_range: Optional[str] = None,
    fuzzy: bool = True
    chat_id: str, keyword: Optional[str] = None, time_range: Optional[str] = None, fuzzy: bool = True
) -> str:
    """Query chat-record summaries in the chat_history table by time or keyword

    Args:
        chat_id: the chat ID
        keyword: keyword (optional; multiple keywords may be separated by spaces, commas, etc.)
@@ -31,7 +28,7 @@ async def query_chat_history(
        fuzzy: whether to use fuzzy matching (default True)
            - True: fuzzy match; containing any one keyword counts as a match (OR)
            - False: full match; every keyword must be present (AND)

    Returns:
        str: the query result
    """
@@ -39,10 +36,10 @@ async def query_chat_history(
        # Validate the parameters
        if not keyword and not time_range:
            return "未指定查询参数(需要提供keyword或time_range之一)"

        # Build the query conditions
        query = ChatHistory.select().where(ChatHistory.chat_id == chat_id)

        # Time filter
        if time_range:
            # Determine whether this is a time point or a time range
@@ -50,79 +47,79 @@ async def query_chat_history(
                # Time range: query records that intersect the range
                start_timestamp, end_timestamp = parse_time_range(time_range)
                # Intersection condition: start_time < end_timestamp AND end_time > start_timestamp
                time_filter = (
                    (ChatHistory.start_time < end_timestamp) &
                    (ChatHistory.end_time > start_timestamp)
                )
                time_filter = (ChatHistory.start_time < end_timestamp) & (ChatHistory.end_time > start_timestamp)
            else:
                # Time point: query records containing the point (start_time <= time_point <= end_time)
                target_timestamp = parse_datetime_to_timestamp(time_range)
                time_filter = (
                    (ChatHistory.start_time <= target_timestamp) &
                    (ChatHistory.end_time >= target_timestamp)
                )
                time_filter = (ChatHistory.start_time <= target_timestamp) & (ChatHistory.end_time >= target_timestamp)
            query = query.where(time_filter)
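The two time filters above are the standard interval tests: two intervals intersect when each starts before the other ends, and a point lies in an interval when it falls between the endpoints. As pure predicates (function names are illustrative):

```python
def intersects(start_time: float, end_time: float, range_start: float, range_end: float) -> bool:
    """Overlap test matching the time_filter: start_time < range_end AND end_time > range_start."""
    return start_time < range_end and end_time > range_start

def contains_point(start_time: float, end_time: float, point: float) -> bool:
    """Time-point test: start_time <= point <= end_time."""
    return start_time <= point <= end_time
```

Because the overlap test uses strict inequalities, records that merely touch the range boundary (end exactly equal to the range start) are excluded.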
        # Execute the query
        records = list(query.order_by(ChatHistory.start_time.desc()).limit(50))

        # If keywords were given, filter further
        if keyword:
            # Parse multiple keywords (spaces, commas, etc. as separators)
            keywords_list = parse_keywords_string(keyword)
            if not keywords_list:
                keywords_list = [keyword.strip()] if keyword.strip() else []

            # Lowercase everything for matching
            keywords_lower = [kw.lower() for kw in keywords_list if kw.strip()]

            if not keywords_lower:
                return "关键词为空"

            filtered_records = []

            for record in records:
                # Search in theme, keywords, summary, and original_text
                theme = (record.theme or "").lower()
                summary = (record.summary or "").lower()
                original_text = (record.original_text or "").lower()

                # Parse the record's keywords JSON
                record_keywords_list = []
                if record.keywords:
                    try:
                        keywords_data = json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
                        keywords_data = (
                            json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
                        )
                        if isinstance(keywords_data, list):
                            record_keywords_list = [str(k).lower() for k in keywords_data]
                    except (json.JSONDecodeError, TypeError, ValueError):
                        pass

                # Check the keywords according to the match mode
                matched = False
                if fuzzy:
                    # Fuzzy match: any single keyword present counts as a match (OR)
                    for kw in keywords_lower:
                        if (kw in theme or
                            kw in summary or
                            kw in original_text or
                            any(kw in k for k in record_keywords_list)):
                        if (
                            kw in theme
                            or kw in summary
                            or kw in original_text
                            or any(kw in k for k in record_keywords_list)
                        ):
                            matched = True
                            break
                else:
                    # Full match: every keyword must be present (AND)
                    matched = True
                    for kw in keywords_lower:
                        kw_matched = (kw in theme or
                                      kw in summary or
                                      kw in original_text or
                                      any(kw in k for k in record_keywords_list))
                        kw_matched = (
                            kw in theme
                            or kw in summary
                            or kw in original_text
                            or any(kw in k for k in record_keywords_list)
                        )
                        if not kw_matched:
                            matched = False
                            break

                if matched:
                    filtered_records.append(record)
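The OR/AND branches above reduce to `any`/`all` over a per-keyword hit test. A compact standalone version of the filter predicate (field names and the function itself are illustrative, not project API):

```python
def record_matches(fields, record_keywords, keywords, fuzzy=True):
    """OR/AND keyword matching as in query_chat_history's filter loop."""
    def hit(kw):
        # A keyword hits if it appears in any searched field or any stored keyword
        return any(kw in field for field in fields) or any(kw in k for k in record_keywords)

    if fuzzy:
        return any(hit(kw) for kw in keywords)  # OR: any one keyword suffices
    return all(hit(kw) for kw in keywords)      # AND: every keyword is required
```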
            if not filtered_records:
                keywords_str = "、".join(keywords_list)
                match_mode = "包含任意一个关键词" if fuzzy else "包含所有关键词"
@@ -130,9 +127,9 @@ async def query_chat_history(
                    return f"未找到{match_mode}'{keywords_str}'且在指定时间范围内的聊天记录概述"
                else:
                    return f"未找到{match_mode}'{keywords_str}'的聊天记录概述"

            records = filtered_records

        # No records (possibly a time-range query with no matches)
        if not records:
            if time_range:
@@ -148,22 +145,23 @@ async def query_chat_history(
                record.count = (record.count or 0) + 1
            except Exception as update_error:
                logger.error(f"更新聊天记录概述计数失败: {update_error}")

        # Build the result text
        results = []
        for record in records_to_use:  # Return at most 3 records
            result_parts = []

            # Add the theme
            if record.theme:
                result_parts.append(f"主题:{record.theme}")

            # Add the time range
            from datetime import datetime

            start_str = datetime.fromtimestamp(record.start_time).strftime("%Y-%m-%d %H:%M:%S")
            end_str = datetime.fromtimestamp(record.end_time).strftime("%Y-%m-%d %H:%M:%S")
            result_parts.append(f"时间:{start_str} - {end_str}")

            # Add the summary (prefer summary; otherwise the first 200 characters of original_text)
            if record.summary:
                result_parts.append(f"概括:{record.summary}")
@@ -172,18 +170,18 @@ async def query_chat_history(
                if len(record.original_text) > 200:
                    text_preview += "..."
                result_parts.append(f"内容:{text_preview}")

            results.append("\n".join(result_parts))

        if not results:
            return "未找到相关聊天记录概述"

        response_text = "\n\n---\n\n".join(results)
        if len(records) > len(records_to_use):
            omitted_count = len(records) - len(records_to_use)
            response_text += f"\n\n(还有{omitted_count}条历史记录已省略)"
        return response_text

    except Exception as e:
        logger.error(f"查询聊天历史概述失败: {e}")
        return f"查询失败: {str(e)}"
@@ -193,26 +191,26 @@ def register_tool():
    """Register the tool"""
    register_memory_retrieval_tool(
        name="query_chat_history",
        description="根据时间或关键词在chat_history表的聊天记录概述库中查询。可以查询某个时间点发生了什么、某个时间范围内的事件,或根据关键词搜索消息概述。支持两种匹配模式:模糊匹配(默认,只要包含任意一个关键词即匹配)和全匹配(必须包含所有关键词才匹配)",
        description="根据时间或关键词在聊天记录中查询。可以查询某个时间点发生了什么、某个时间范围内的事件,或根据关键词搜索消息概述。支持两种匹配模式:模糊匹配(默认,只要包含任意一个关键词即匹配)和全匹配(必须包含所有关键词才匹配)",
        parameters=[
            {
                "name": "keyword",
                "type": "string",
                "description": "关键词(可选,支持多个关键词,可用空格、逗号、斜杠等分隔,如:'麦麦 百度网盘' 或 '麦麦,百度网盘'。用于在主题、关键词、概括、原文中搜索)",
                "required": False
                "required": False,
            },
            {
                "name": "time_range",
                "type": "string",
                "description": "时间范围或时间点(可选)。格式:'YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS'(时间范围,查询与时间范围有交集的记录)或 'YYYY-MM-DD HH:MM:SS'(时间点,查询包含该时间点的记录)",
                "required": False
                "required": False,
            },
            {
                "name": "fuzzy",
                "type": "boolean",
                "description": "是否使用模糊匹配模式(默认True)。True表示模糊匹配(只要包含任意一个关键词即匹配,OR关系),False表示全匹配(必须包含所有关键词才匹配,AND关系)",
                "required": False
            }
                "required": False,
            },
        ],
        execute_func=query_chat_history
        execute_func=query_chat_history,
    )
@@ -10,7 +10,7 @@ from .tool_registry import register_memory_retrieval_tool
logger = get_logger("memory_retrieval_tools")


async def query_lpmm_knowledge(query: str) -> str:
async def query_lpmm_knowledge(query: str, limit: int = 5) -> str:
    """Query related information in the LPMM knowledge base

    Args:
@@ -24,6 +24,12 @@ async def query_lpmm_knowledge(query: str) -> str:
    if not content:
        return "查询关键词为空"

    try:
        limit_value = int(limit)
    except (TypeError, ValueError):
        limit_value = 5
    limit_value = max(1, limit_value)

    if not global_config.lpmm_knowledge.enable:
        logger.debug("LPMM知识库未启用")
        return "LPMM知识库未启用"
@@ -33,7 +39,7 @@ async def query_lpmm_knowledge(query: str) -> str:
        logger.debug("LPMM知识库未初始化,跳过查询")
        return "LPMM知识库未初始化"

    knowledge_info = await qa_manager.get_knowledge(content)
    knowledge_info = await qa_manager.get_knowledge(content, limit=limit_value)
    logger.debug(f"LPMM知识库查询结果: {knowledge_info}")

    if knowledge_info:
@@ -57,9 +63,13 @@ def register_tool():
                "type": "string",
                "description": "需要查询的关键词或问题",
                "required": True,
            }
            },
            {
                "name": "limit",
                "type": "integer",
                "description": "希望返回的相关知识条数,默认为5",
                "required": False,
            },
        ],
        execute_func=query_lpmm_knowledge,
    )
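The new `limit` handling above defends against LLM-supplied arguments that may be strings, `None`, or out of range: coerce to `int`, fall back to the default on failure, then clamp to at least 1. As a reusable helper (the name is illustrative):

```python
def clamp_limit(limit, default=5):
    """Coerce a user-supplied limit to an int >= 1, as query_lpmm_knowledge does."""
    try:
        value = int(limit)
    except (TypeError, ValueError):
        value = default  # non-numeric or None falls back to the default
    return max(1, value)
```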
@@ -14,23 +14,25 @@ logger = get_logger("memory_retrieval_tools")

def _format_group_nick_names(group_nick_name_field) -> str:
    """Format group-nickname information

    Args:
        group_nick_name_field: the group-nickname field (may be a JSON string or None)

    Returns:
        str: the formatted group-nickname string
    """
    if not group_nick_name_field:
        return ""

    try:
        # Parse the JSON-formatted list of group nicknames
        group_nick_names_data = json.loads(group_nick_name_field) if isinstance(group_nick_name_field, str) else group_nick_name_field

        group_nick_names_data = (
            json.loads(group_nick_name_field) if isinstance(group_nick_name_field, str) else group_nick_name_field
        )

        if not isinstance(group_nick_names_data, list) or not group_nick_names_data:
            return ""

        # Format the nickname list
        group_nick_list = []
        for item in group_nick_names_data:
@@ -41,7 +43,7 @@ def _format_group_nick_names(group_nick_name_field) -> str:
            elif isinstance(item, str):
                # Backward compatible with the old format (if present)
                group_nick_list.append(f" - {item}")

        if group_nick_list:
            return "群昵称:\n" + "\n".join(group_nick_list)
        return ""
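The "JSON string or already-decoded value" pattern recurs throughout these files (`raw_content`, `keywords`, `group_nick_name`, `memory_points`). It can be factored into one normalizer; this helper is a sketch, not something the codebase defines:

```python
import json

def as_list(field):
    """Normalize a field that may be a JSON string, a list, another value, or None into a list."""
    if not field:
        return []
    try:
        data = json.loads(field) if isinstance(field, str) else field
    except (json.JSONDecodeError, TypeError):
        return [field]  # unparseable string: keep it as a single raw item
    if not isinstance(data, list):
        return [data] if data else []
    return data
```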
@@ -58,10 +60,10 @@ def _format_group_nick_names(group_nick_name_field) -> str:

async def query_person_info(person_name: str) -> str:
    """Query user information by person_name, using fuzzy matching

    Args:
        person_name: the user name (the person_name field)

    Returns:
        str: the query result, containing all of the user's information
    """
@@ -69,37 +71,35 @@ async def query_person_info(person_name: str) -> str:
        person_name = str(person_name).strip()
        if not person_name:
            return "用户名称为空"

        # Build the query condition (fuzzy match)
        query = PersonInfo.select().where(
            PersonInfo.person_name.contains(person_name)
        )
        query = PersonInfo.select().where(PersonInfo.person_name.contains(person_name))

        # Execute the query
        records = list(query.limit(20))  # Return at most 20 records

        if not records:
            return f"未找到模糊匹配'{person_name}'的用户信息"

        # Separate exact matches from fuzzy matches
        exact_matches = []
        fuzzy_matches = []

        for record in records:
            # Check for an exact match
            if record.person_name and record.person_name.strip() == person_name:
                exact_matches.append(record)
            else:
                fuzzy_matches.append(record)
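The partition step above runs one broad substring query, then splits the hits into exact and fuzzy groups so exact matches can be rendered first. On plain strings (a sketch; the real code partitions ORM records):

```python
def partition_matches(names, query):
    """Split substring matches into exact and fuzzy groups, as query_person_info does."""
    exact, fuzzy = [], []
    for name in names:
        if query not in name:
            continue  # the .contains(query) database filter already excluded these
        (exact if name.strip() == query else fuzzy).append(name)
    return exact, fuzzy
```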
        # Build the result text
        results = []

        # Handle exact matches first
        for record in exact_matches:
            result_parts = []
            result_parts.append("【精确匹配】")  # Mark as an exact match

            # Basic information
            if record.person_name:
                result_parts.append(f"用户名称:{record.person_name}")
@@ -111,19 +111,19 @@ async def query_person_info(person_name: str) -> str:
                result_parts.append(f"平台:{record.platform}")
            if record.user_id:
                result_parts.append(f"平台用户ID:{record.user_id}")

            # Group-nickname information
            group_nick_name_str = _format_group_nick_names(getattr(record, "group_nick_name", None))
            if group_nick_name_str:
                result_parts.append(group_nick_name_str)

            # Reason the name was set
            if record.name_reason:
                result_parts.append(f"名称设定原因:{record.name_reason}")

            # Acquaintance status
            result_parts.append(f"是否已认识:{'是' if record.is_known else '否'}")

            # Time information
            if record.know_since:
                know_since_str = datetime.fromtimestamp(record.know_since).strftime("%Y-%m-%d %H:%M:%S")
@@ -133,11 +133,15 @@ async def query_person_info(person_name: str) -> str:
                result_parts.append(f"最后认识时间:{last_know_str}")
            if record.know_times:
                result_parts.append(f"认识次数:{int(record.know_times)}")

            # Memory points (memory_points)
            if record.memory_points:
                try:
                    memory_points_data = json.loads(record.memory_points) if isinstance(record.memory_points, str) else record.memory_points
                    memory_points_data = (
                        json.loads(record.memory_points)
                        if isinstance(record.memory_points, str)
                        else record.memory_points
                    )
                    if isinstance(memory_points_data, list) and memory_points_data:
                        # Parse the memory-point format: category:content:weight
                        memory_list = []
@@ -151,7 +155,7 @@ async def query_person_info(person_name: str) -> str:
                            memory_list.append(f" - [{category}] {content} (权重: {weight})")
                        else:
                            memory_list.append(f" - {memory_point}")

                    if memory_list:
                        result_parts.append("记忆点:\n" + "\n".join(memory_list))
                except (json.JSONDecodeError, TypeError, ValueError) as e:
@@ -161,14 +165,14 @@ async def query_person_info(person_name: str) -> str:
                    if len(str(record.memory_points)) > 200:
                        memory_preview += "..."
                    result_parts.append(f"记忆点(原始数据):{memory_preview}")

            results.append("\n".join(result_parts))

        # Then handle fuzzy matches
        for record in fuzzy_matches:
            result_parts = []
            result_parts.append("【模糊匹配】")  # Mark as a fuzzy match

            # Basic information
            if record.person_name:
                result_parts.append(f"用户名称:{record.person_name}")
@@ -180,19 +184,19 @@ async def query_person_info(person_name: str) -> str:
                result_parts.append(f"平台:{record.platform}")
            if record.user_id:
                result_parts.append(f"平台用户ID:{record.user_id}")

            # Group-nickname information
            group_nick_name_str = _format_group_nick_names(getattr(record, "group_nick_name", None))
            if group_nick_name_str:
                result_parts.append(group_nick_name_str)

            # Reason the name was set
            if record.name_reason:
                result_parts.append(f"名称设定原因:{record.name_reason}")

            # Acquaintance status
            result_parts.append(f"是否已认识:{'是' if record.is_known else '否'}")

            # Time information
            if record.know_since:
                know_since_str = datetime.fromtimestamp(record.know_since).strftime("%Y-%m-%d %H:%M:%S")
@@ -202,11 +206,15 @@ async def query_person_info(person_name: str) -> str:
                result_parts.append(f"最后认识时间:{last_know_str}")
            if record.know_times:
                result_parts.append(f"认识次数:{int(record.know_times)}")

            # Memory points (memory_points)
            if record.memory_points:
                try:
                    memory_points_data = json.loads(record.memory_points) if isinstance(record.memory_points, str) else record.memory_points
                    memory_points_data = (
                        json.loads(record.memory_points)
                        if isinstance(record.memory_points, str)
                        else record.memory_points
                    )
                    if isinstance(memory_points_data, list) and memory_points_data:
                        # Parse the memory-point format: category:content:weight
                        memory_list = []
@@ -220,7 +228,7 @@ async def query_person_info(person_name: str) -> str:
                            memory_list.append(f" - [{category}] {content} (权重: {weight})")
                        else:
                            memory_list.append(f" - {memory_point}")

                    if memory_list:
                        result_parts.append("记忆点:\n" + "\n".join(memory_list))
                except (json.JSONDecodeError, TypeError, ValueError) as e:
@@ -230,20 +238,20 @@ async def query_person_info(person_name: str) -> str:
                    if len(str(record.memory_points)) > 200:
                        memory_preview += "..."
                    result_parts.append(f"记忆点(原始数据):{memory_preview}")

            results.append("\n".join(result_parts))

        # Combine all results
        if not results:
            return f"未找到匹配'{person_name}'的用户信息"

        response_text = "\n\n---\n\n".join(results)

        # Add statistics
        total_count = len(records)
        exact_count = len(exact_matches)
        fuzzy_count = len(fuzzy_matches)

        # Show exact-match and fuzzy-match statistics
        if exact_count > 0 or fuzzy_count > 0:
            stats_parts = []
@@ -257,13 +265,13 @@ async def query_person_info(person_name: str) -> str:
            response_text = f"找到 {total_count} 条匹配的用户信息:\n\n{response_text}"
        else:
            response_text = f"找到用户信息:\n\n{response_text}"

        # If the result count hit the limit, add a note
        if total_count >= 20:
            response_text += "\n\n(已显示前20条结果,可能还有更多匹配记录)"

        return response_text

    except Exception as e:
        logger.error(f"查询用户信息失败: {e}")
        return f"查询失败: {str(e)}"
@@ -275,13 +283,7 @@ def register_tool():
        name="query_person_info",
        description="根据查询某个用户的所有信息。名称、昵称、平台、用户ID、qq号、群昵称等",
        parameters=[
            {
                "name": "person_name",
                "type": "string",
                "description": "用户名称,用于查询用户信息",
                "required": True
            }
            {"name": "person_name", "type": "string", "description": "用户名称,用于查询用户信息", "required": True}
        ],
        execute_func=query_person_info
        execute_func=query_person_info,
    )
@@ -47,10 +47,10 @@ class MemoryRetrievalTool:
    async def execute(self, **kwargs) -> str:
        """Execute the tool"""
        return await self.execute_func(**kwargs)

    def get_tool_definition(self) -> Dict[str, Any]:
        """Get the tool definition for LLM function calling

        Returns:
            Dict[str, Any]: the tool-definition dict, in the same format as BaseTool
            Format: {"name": str, "description": str, "parameters": List[Tuple]}
@@ -58,14 +58,14 @@ class MemoryRetrievalTool:
        # Convert the parameters into a list of tuples, matching BaseTool
        # Format: [("param_name", ToolParamType, "description", required, enum_values)]
        param_tuples = []

        for param in self.parameters:
            param_name = param.get("name", "")
            param_type_str = param.get("type", "string").lower()
            param_desc = param.get("description", "")
            is_required = param.get("required", False)
            enum_values = param.get("enum", None)

            # Map the type string to ToolParamType
            type_mapping = {
                "string": ToolParamType.STRING,
@@ -76,18 +76,14 @@ class MemoryRetrievalTool:
                "bool": ToolParamType.BOOLEAN,
            }
            param_type = type_mapping.get(param_type_str, ToolParamType.STRING)

            # Build the parameter tuple
            param_tuple = (param_name, param_type, param_desc, is_required, enum_values)
            param_tuples.append(param_tuple)

        # Build the tool definition, matching BaseTool.get_tool_definition()
        tool_def = {
            "name": self.name,
            "description": self.description,
            "parameters": param_tuples
        }

        tool_def = {"name": self.name, "description": self.description, "parameters": param_tuples}

        return tool_def
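The dict-to-tuple conversion above can be sketched in isolation. The `ToolParamType` enum here is a stand-in (the real enum lives in the project and has more members); the mapping falls back to `STRING` for unknown type strings, as the original does:

```python
from enum import Enum

class ToolParamType(Enum):  # stand-in for the project's enum
    STRING = "string"
    INTEGER = "integer"
    BOOLEAN = "boolean"

TYPE_MAPPING = {
    "string": ToolParamType.STRING,
    "integer": ToolParamType.INTEGER,
    "boolean": ToolParamType.BOOLEAN,
    "bool": ToolParamType.BOOLEAN,
}

def to_param_tuple(param):
    """Convert one parameter dict into the BaseTool-style tuple."""
    return (
        param.get("name", ""),
        TYPE_MAPPING.get(param.get("type", "string").lower(), ToolParamType.STRING),
        param.get("description", ""),
        param.get("required", False),
        param.get("enum", None),
    )
```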
@@ -126,10 +122,10 @@ class MemoryRetrievalToolRegistry:
        action_types.append("final_answer")
        action_types.append("no_answer")
        return " 或 ".join([f'"{at}"' for at in action_types])

    def get_tool_definitions(self) -> List[Dict[str, Any]]:
        """Get the definitions of all tools for LLM function calling

        Returns:
            List[Dict[str, Any]]: a list of tool definitions, one dict per tool
        """
@@ -162,7 +162,12 @@ def levenshtein_distance(s1: str, s2: str) -> int:
class Person:
    @classmethod
    def register_person(
        cls, platform: str, user_id: str, nickname: str, group_id: Optional[str] = None, group_nick_name: Optional[str] = None
        cls,
        platform: str,
        user_id: str,
        nickname: str,
        group_id: Optional[str] = None,
        group_nick_name: Optional[str] = None,
    ):
        """
        Class method for registering a new user
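The hunk header above references a `levenshtein_distance(s1, s2)` helper whose body is outside this diff. The standard dynamic-programming edit distance it presumably computes (insert, delete, substitute, each at cost 1) looks like this; treat it as a generic sketch, not the project's exact implementation:

```python
def levenshtein(s1: str, s2: str) -> int:
    """Classic dynamic-programming edit distance (insert/delete/substitute, each cost 1)."""
    if len(s1) < len(s2):
        s1, s2 = s2, s1  # keep the inner row as short as possible
    previous = list(range(len(s2) + 1))
    for i, c1 in enumerate(s1, start=1):
        current = [i]
        for j, c2 in enumerate(s2, start=1):
            cost = 0 if c1 == c2 else 1
            current.append(min(
                previous[j] + 1,        # deletion
                current[j - 1] + 1,     # insertion
                previous[j - 1] + cost, # substitution
            ))
        previous = current
    return previous[-1]
```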
@@ -727,7 +732,7 @@ person_info_manager = PersonInfoManager()
|
||||
|
||||
async def store_person_memory_from_answer(person_name: str, memory_content: str, chat_id: str) -> None:
|
||||
"""将人物信息存入person_info的memory_points
|
||||
|
||||
|
||||
Args:
|
||||
person_name: 人物名称
|
||||
memory_content: 记忆内容
|
||||
@@ -739,13 +744,13 @@ async def store_person_memory_from_answer(person_name: str, memory_content: str,
|
||||
if not chat_stream:
|
||||
logger.warning(f"无法获取chat_stream for chat_id: {chat_id}")
|
||||
return
|
||||
|
||||
|
||||
platform = chat_stream.platform
|
||||
|
||||
|
||||
# 尝试从person_name查找person_id
|
||||
# 首先尝试通过person_name查找
|
||||
person_id = get_person_id_by_person_name(person_name)
|
||||
|
||||
|
||||
if not person_id:
|
||||
# 如果通过person_name找不到,尝试从chat_stream获取user_info
|
||||
if chat_stream.user_info:
|
||||
@@ -754,25 +759,25 @@ async def store_person_memory_from_answer(person_name: str, memory_content: str,
         else:
             logger.warning(f"无法确定person_id for person_name: {person_name}, chat_id: {chat_id}")
             return

         # 创建或获取Person对象
         person = Person(person_id=person_id)

         if not person.is_known:
             logger.warning(f"用户 {person_name} (person_id: {person_id}) 尚未认识,无法存储记忆")
             return

         # 确定记忆分类(可以根据memory_content判断,这里使用通用分类)
         category = "其他"  # 默认分类,可以根据需要调整

         # 记忆点格式:category:content:weight
         weight = "1.0"  # 默认权重
         memory_point = f"{category}:{memory_content}:{weight}"

         # 添加到memory_points
         if not person.memory_points:
             person.memory_points = []

         # 检查是否已存在相似的记忆点(避免重复)
         is_duplicate = False
         for existing_point in person.memory_points:
@@ -781,16 +786,20 @@ async def store_person_memory_from_answer(person_name: str, memory_content: str,
             if len(parts) >= 2:
                 existing_content = parts[1].strip()
                 # 简单相似度检查(如果内容相同或非常相似,则跳过)
-                if existing_content == memory_content or memory_content in existing_content or existing_content in memory_content:
+                if (
+                    existing_content == memory_content
+                    or memory_content in existing_content
+                    or existing_content in memory_content
+                ):
                     is_duplicate = True
                     break

         if not is_duplicate:
             person.memory_points.append(memory_point)
             person.sync_to_database()
             logger.info(f"成功添加记忆点到 {person_name} (person_id: {person_id}): {memory_point}")
         else:
             logger.debug(f"记忆点已存在,跳过: {memory_point}")

     except Exception as e:
         logger.error(f"存储人物记忆失败: {e}")
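The hunk above serializes each memory point as a `category:content:weight` string and deduplicates by substring comparison. Splitting such a string back apart is slightly subtle because the content itself may contain colons; a minimal sketch, where `parse_memory_point` is a hypothetical helper for illustration and not part of this diff:

```python
def parse_memory_point(point: str) -> tuple[str, str, float]:
    # Hypothetical helper: take the category from the left, the weight from
    # the right, so colons inside the content don't break parsing.
    category, rest = point.split(":", 1)
    content, _, weight = rest.rpartition(":")
    try:
        w = float(weight)
    except ValueError:
        # No numeric trailing weight; treat the whole rest as content.
        content, w = rest, 1.0
    return category, content, w
```

Splitting from both ends mirrors how the format is written (`f"{category}:{memory_content}:{weight}"`), so only a colon inside the category itself would be ambiguous.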
@@ -124,7 +124,6 @@ class ToolExecutor:
         response, (reasoning_content, model_name, tool_calls) = await self.llm_model.generate_response_async(
             prompt=prompt, tools=tools, raise_when_empty=False
         )

         # 执行工具调用
         tool_results, used_tools = await self.execute_tool_calls(tool_calls)
@@ -102,13 +102,13 @@ class EmojiAction(BaseAction):

         # 5. 调用LLM
         models = llm_api.get_available_models()
-        chat_model_config = models.get("replyer")  # 使用字典访问方式
+        chat_model_config = models.get("utils")  # 使用字典访问方式
         if not chat_model_config:
-            logger.error(f"{self.log_prefix} 未找到'replyer'模型配置,无法调用LLM")
-            return False, "未找到'replyer'模型配置"
+            logger.error(f"{self.log_prefix} 未找到'utils'模型配置,无法调用LLM")
+            return False, "未找到'utils'模型配置"

         success, chosen_emotion, _, _ = await llm_api.generate_with_model(
-            prompt, model_config=chat_model_config, request_type="emoji"
+            prompt, model_config=chat_model_config, request_type="emoji.select"
         )

         if not success:
@@ -15,6 +15,7 @@ class SearchKnowledgeFromLPMMTool(BaseTool):
     description = "从知识库中搜索相关信息,如果你需要知识,就使用这个工具"
     parameters = [
         ("query", ToolParamType.STRING, "搜索查询关键词", True, None),
+        ("limit", ToolParamType.INTEGER, "希望返回的相关知识条数,默认5", False, 5),
     ]
     available_for_llm = global_config.lpmm_knowledge.enable
@@ -29,6 +30,12 @@ class SearchKnowledgeFromLPMMTool(BaseTool):
         """
         try:
             query: str = function_args.get("query")  # type: ignore
+            limit = function_args.get("limit", 5)
+            try:
+                limit_value = int(limit)
+            except (TypeError, ValueError):
+                limit_value = 5
+            limit_value = max(1, limit_value)
             # threshold = function_args.get("threshold", 0.4)

             # 检查LPMM知识库是否启用
@@ -38,7 +45,7 @@ class SearchKnowledgeFromLPMMTool(BaseTool):

             # 调用知识库搜索
-            knowledge_info = await qa_manager.get_knowledge(query)
+            knowledge_info = await qa_manager.get_knowledge(query, limit=limit_value)

             logger.debug(f"知识库查询结果: {knowledge_info}")
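The `limit` handling added in these hunks is a common defensive pattern for integer parameters coming from LLM tool calls, which may arrive as an int, a numeric string, a float, or be missing entirely. A standalone sketch of the same coerce-then-clamp logic:

```python
def coerce_limit(value, default: int = 5, minimum: int = 1) -> int:
    # Mirror the hunk's behavior: fall back to the default on anything
    # non-numeric (including None), then clamp to a sane lower bound.
    try:
        limit = int(value)
    except (TypeError, ValueError):
        limit = default
    return max(minimum, limit)
```

Catching `TypeError` as well as `ValueError` matters: `int(None)` raises the former, `int("abc")` the latter.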
559
src/webui/config_routes.py
Normal file
@@ -0,0 +1,559 @@
"""
配置管理API路由
"""

import os
import tomlkit
from fastapi import APIRouter, HTTPException, Body
from typing import Any

from src.common.logger import get_logger
from src.config.config import Config, APIAdapterConfig, CONFIG_DIR
from src.config.official_configs import (
    BotConfig,
    PersonalityConfig,
    RelationshipConfig,
    ChatConfig,
    MessageReceiveConfig,
    EmojiConfig,
    ExpressionConfig,
    KeywordReactionConfig,
    ChineseTypoConfig,
    ResponsePostProcessConfig,
    ResponseSplitterConfig,
    TelemetryConfig,
    ExperimentalConfig,
    MaimMessageConfig,
    LPMMKnowledgeConfig,
    ToolConfig,
    MemoryConfig,
    DebugConfig,
    MoodConfig,
    VoiceConfig,
    JargonConfig,
)
from src.config.api_ada_configs import (
    ModelTaskConfig,
    ModelInfo,
    APIProvider,
)
from src.webui.config_schema import ConfigSchemaGenerator

logger = get_logger("webui")

router = APIRouter(prefix="/config", tags=["config"])


# ===== 辅助函数 =====


def _update_dict_preserve_comments(target: Any, source: Any) -> None:
    """
    递归合并字典,保留 target 中的注释和格式
    将 source 的值更新到 target 中(仅更新已存在的键)

    Args:
        target: 目标字典(tomlkit 对象,包含注释)
        source: 源字典(普通 dict 或 list)
    """
    # 如果 source 是列表,直接替换(数组表没有注释保留的意义)
    if isinstance(source, list):
        return  # 调用者需要直接赋值

    # 如果都是字典,递归合并
    if isinstance(source, dict) and isinstance(target, dict):
        for key, value in source.items():
            if key == "version":
                continue  # 跳过版本号
            if key in target:
                target_value = target[key]
                # 递归处理嵌套字典
                if isinstance(value, dict) and isinstance(target_value, dict):
                    _update_dict_preserve_comments(target_value, value)
                else:
                    # 使用 tomlkit.item 保持类型
                    try:
                        target[key] = tomlkit.item(value)
                    except (TypeError, ValueError):
                        target[key] = value
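The merge rules implemented by `_update_dict_preserve_comments` can be exercised with plain dicts, since tomlkit containers behave like dicts. A minimal sketch of the same rules — skip `version`, update only keys already present in the target, recurse into nested tables (the `tomlkit.item` wrapping is omitted here as it only matters for TOML formatting):

```python
def merge_existing_keys(target: dict, source: dict) -> None:
    # Same merge policy as _update_dict_preserve_comments, on plain dicts:
    # "version" is skipped, unknown keys are never introduced, and nested
    # dicts are merged recursively instead of being replaced wholesale.
    for key, value in source.items():
        if key == "version":
            continue
        if key in target:
            if isinstance(value, dict) and isinstance(target[key], dict):
                merge_existing_keys(target[key], value)
            else:
                target[key] = value

target = {"version": 1, "bot": {"name": "mai", "qq": 123}, "extra": True}
source = {"version": 2, "bot": {"name": "bot2", "unknown": "x"}, "new_key": 1}
merge_existing_keys(target, source)
# target keeps version=1 and "extra"; bot.name is updated; "unknown" and
# "new_key" are not introduced.
```

Updating only existing keys is what lets tomlkit keep every comment and table in place: nothing in the on-disk document is ever created or deleted, only overwritten.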
# ===== 架构获取接口 =====


@router.get("/schema/bot")
async def get_bot_config_schema():
    """获取麦麦主程序配置架构"""
    try:
        # Config 类包含所有子配置
        schema = ConfigSchemaGenerator.generate_config_schema(Config)
        return {"success": True, "schema": schema}
    except Exception as e:
        logger.error(f"获取配置架构失败: {e}")
        raise HTTPException(status_code=500, detail=f"获取配置架构失败: {str(e)}")


@router.get("/schema/model")
async def get_model_config_schema():
    """获取模型配置架构(包含提供商和模型任务配置)"""
    try:
        schema = ConfigSchemaGenerator.generate_config_schema(APIAdapterConfig)
        return {"success": True, "schema": schema}
    except Exception as e:
        logger.error(f"获取模型配置架构失败: {e}")
        raise HTTPException(status_code=500, detail=f"获取模型配置架构失败: {str(e)}")


# ===== 子配置架构获取接口 =====


@router.get("/schema/section/{section_name}")
async def get_config_section_schema(section_name: str):
    """
    获取指定配置节的架构

    支持的section_name:
    - bot: BotConfig
    - personality: PersonalityConfig
    - relationship: RelationshipConfig
    - chat: ChatConfig
    - message_receive: MessageReceiveConfig
    - emoji: EmojiConfig
    - expression: ExpressionConfig
    - keyword_reaction: KeywordReactionConfig
    - chinese_typo: ChineseTypoConfig
    - response_post_process: ResponsePostProcessConfig
    - response_splitter: ResponseSplitterConfig
    - telemetry: TelemetryConfig
    - experimental: ExperimentalConfig
    - maim_message: MaimMessageConfig
    - lpmm_knowledge: LPMMKnowledgeConfig
    - tool: ToolConfig
    - memory: MemoryConfig
    - debug: DebugConfig
    - mood: MoodConfig
    - voice: VoiceConfig
    - jargon: JargonConfig
    - model_task_config: ModelTaskConfig
    - api_provider: APIProvider
    - model_info: ModelInfo
    """
    section_map = {
        "bot": BotConfig,
        "personality": PersonalityConfig,
        "relationship": RelationshipConfig,
        "chat": ChatConfig,
        "message_receive": MessageReceiveConfig,
        "emoji": EmojiConfig,
        "expression": ExpressionConfig,
        "keyword_reaction": KeywordReactionConfig,
        "chinese_typo": ChineseTypoConfig,
        "response_post_process": ResponsePostProcessConfig,
        "response_splitter": ResponseSplitterConfig,
        "telemetry": TelemetryConfig,
        "experimental": ExperimentalConfig,
        "maim_message": MaimMessageConfig,
        "lpmm_knowledge": LPMMKnowledgeConfig,
        "tool": ToolConfig,
        "memory": MemoryConfig,
        "debug": DebugConfig,
        "mood": MoodConfig,
        "voice": VoiceConfig,
        "jargon": JargonConfig,
        "model_task_config": ModelTaskConfig,
        "api_provider": APIProvider,
        "model_info": ModelInfo,
    }

    if section_name not in section_map:
        raise HTTPException(status_code=404, detail=f"配置节 '{section_name}' 不存在")

    try:
        config_class = section_map[section_name]
        schema = ConfigSchemaGenerator.generate_schema(config_class, include_nested=False)
        return {"success": True, "schema": schema}
    except Exception as e:
        logger.error(f"获取配置节架构失败: {e}")
        raise HTTPException(status_code=500, detail=f"获取配置节架构失败: {str(e)}")

# ===== 配置读取接口 =====


@router.get("/bot")
async def get_bot_config():
    """获取麦麦主程序配置"""
    try:
        config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
        if not os.path.exists(config_path):
            raise HTTPException(status_code=404, detail="配置文件不存在")

        with open(config_path, "r", encoding="utf-8") as f:
            config_data = tomlkit.load(f)

        return {"success": True, "config": config_data}
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"读取配置文件失败: {e}")
        raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}")


@router.get("/model")
async def get_model_config():
    """获取模型配置(包含提供商和模型任务配置)"""
    try:
        config_path = os.path.join(CONFIG_DIR, "model_config.toml")
        if not os.path.exists(config_path):
            raise HTTPException(status_code=404, detail="配置文件不存在")

        with open(config_path, "r", encoding="utf-8") as f:
            config_data = tomlkit.load(f)

        return {"success": True, "config": config_data}
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"读取配置文件失败: {e}")
        raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}")


# ===== 配置更新接口 =====


@router.post("/bot")
async def update_bot_config(config_data: dict[str, Any] = Body(...)):
    """更新麦麦主程序配置"""
    try:
        # 验证配置数据
        try:
            Config.from_dict(config_data)
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}")

        # 保存配置文件
        config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
        with open(config_path, "w", encoding="utf-8") as f:
            tomlkit.dump(config_data, f)

        logger.info("麦麦主程序配置已更新")
        return {"success": True, "message": "配置已保存"}
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"保存配置文件失败: {e}")
        raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}")


@router.post("/model")
async def update_model_config(config_data: dict[str, Any] = Body(...)):
    """更新模型配置"""
    try:
        # 验证配置数据
        try:
            APIAdapterConfig.from_dict(config_data)
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}")

        # 保存配置文件
        config_path = os.path.join(CONFIG_DIR, "model_config.toml")
        with open(config_path, "w", encoding="utf-8") as f:
            tomlkit.dump(config_data, f)

        logger.info("模型配置已更新")
        return {"success": True, "message": "配置已保存"}
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"保存配置文件失败: {e}")
        raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}")


# ===== 配置节更新接口 =====


@router.post("/bot/section/{section_name}")
async def update_bot_config_section(section_name: str, section_data: Any = Body(...)):
    """更新麦麦主程序配置的指定节(保留注释和格式)"""
    try:
        # 读取现有配置
        config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
        if not os.path.exists(config_path):
            raise HTTPException(status_code=404, detail="配置文件不存在")

        with open(config_path, "r", encoding="utf-8") as f:
            config_data = tomlkit.load(f)

        # 更新指定节
        if section_name not in config_data:
            raise HTTPException(status_code=404, detail=f"配置节 '{section_name}' 不存在")

        # 使用递归合并保留注释(对于字典类型)
        # 对于数组类型(如 platforms, aliases),直接替换
        if isinstance(section_data, list):
            # 列表直接替换
            config_data[section_name] = section_data
        elif isinstance(section_data, dict) and isinstance(config_data[section_name], dict):
            # 字典递归合并
            _update_dict_preserve_comments(config_data[section_name], section_data)
        else:
            # 其他类型直接替换
            config_data[section_name] = section_data

        # 验证完整配置
        try:
            Config.from_dict(config_data)
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}")

        # 保存配置(tomlkit.dump 会保留注释)
        with open(config_path, "w", encoding="utf-8") as f:
            tomlkit.dump(config_data, f)

        logger.info(f"配置节 '{section_name}' 已更新(保留注释)")
        return {"success": True, "message": f"配置节 '{section_name}' 已保存"}
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"更新配置节失败: {e}")
        raise HTTPException(status_code=500, detail=f"更新配置节失败: {str(e)}")


# ===== 原始 TOML 文件操作接口 =====


@router.get("/bot/raw")
async def get_bot_config_raw():
    """获取麦麦主程序配置的原始 TOML 内容"""
    try:
        config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
        if not os.path.exists(config_path):
            raise HTTPException(status_code=404, detail="配置文件不存在")

        with open(config_path, "r", encoding="utf-8") as f:
            raw_content = f.read()

        return {"success": True, "content": raw_content}
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"读取配置文件失败: {e}")
        raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}")


@router.post("/bot/raw")
async def update_bot_config_raw(raw_content: str = Body(..., embed=True)):
    """更新麦麦主程序配置(直接保存原始 TOML 内容,会先验证格式)"""
    try:
        # 验证 TOML 格式
        try:
            config_data = tomlkit.loads(raw_content)
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}")

        # 验证配置数据结构
        try:
            Config.from_dict(config_data)
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}")

        # 保存配置文件
        config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
        with open(config_path, "w", encoding="utf-8") as f:
            f.write(raw_content)

        logger.info("麦麦主程序配置已更新(原始模式)")
        return {"success": True, "message": "配置已保存"}
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"保存配置文件失败: {e}")
        raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}")


@router.post("/model/section/{section_name}")
async def update_model_config_section(section_name: str, section_data: Any = Body(...)):
    """更新模型配置的指定节(保留注释和格式)"""
    try:
        # 读取现有配置
        config_path = os.path.join(CONFIG_DIR, "model_config.toml")
        if not os.path.exists(config_path):
            raise HTTPException(status_code=404, detail="配置文件不存在")

        with open(config_path, "r", encoding="utf-8") as f:
            config_data = tomlkit.load(f)

        # 更新指定节
        if section_name not in config_data:
            raise HTTPException(status_code=404, detail=f"配置节 '{section_name}' 不存在")

        # 使用递归合并保留注释(对于字典类型)
        # 对于数组表(如 [[models]], [[api_providers]]),直接替换
        if isinstance(section_data, list):
            # 列表直接替换
            config_data[section_name] = section_data
        elif isinstance(section_data, dict) and isinstance(config_data[section_name], dict):
            # 字典递归合并
            _update_dict_preserve_comments(config_data[section_name], section_data)
        else:
            # 其他类型直接替换
            config_data[section_name] = section_data

        # 验证完整配置
        try:
            APIAdapterConfig.from_dict(config_data)
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}")

        # 保存配置(tomlkit.dump 会保留注释)
        with open(config_path, "w", encoding="utf-8") as f:
            tomlkit.dump(config_data, f)

        logger.info(f"配置节 '{section_name}' 已更新(保留注释)")
        return {"success": True, "message": f"配置节 '{section_name}' 已保存"}
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"更新配置节失败: {e}")
        raise HTTPException(status_code=500, detail=f"更新配置节失败: {str(e)}")


# ===== 适配器配置管理接口 =====


@router.get("/adapter-config/path")
async def get_adapter_config_path():
    """获取保存的适配器配置文件路径"""
    try:
        # 从 data/webui.json 读取路径偏好
        webui_data_path = os.path.join("data", "webui.json")
        if not os.path.exists(webui_data_path):
            return {"success": True, "path": None}

        import json

        with open(webui_data_path, "r", encoding="utf-8") as f:
            webui_data = json.load(f)

        adapter_config_path = webui_data.get("adapter_config_path")
        if not adapter_config_path:
            return {"success": True, "path": None}

        # 检查文件是否存在并返回最后修改时间
        if os.path.exists(adapter_config_path):
            import datetime

            mtime = os.path.getmtime(adapter_config_path)
            last_modified = datetime.datetime.fromtimestamp(mtime).isoformat()
            return {"success": True, "path": adapter_config_path, "lastModified": last_modified}
        else:
            return {"success": True, "path": adapter_config_path, "lastModified": None}

    except Exception as e:
        logger.error(f"获取适配器配置路径失败: {e}")
        raise HTTPException(status_code=500, detail=f"获取配置路径失败: {str(e)}")


@router.post("/adapter-config/path")
async def save_adapter_config_path(data: dict[str, str] = Body(...)):
    """保存适配器配置文件路径偏好"""
    try:
        path = data.get("path")
        if not path:
            raise HTTPException(status_code=400, detail="路径不能为空")

        # 保存到 data/webui.json
        webui_data_path = os.path.join("data", "webui.json")
        import json

        # 读取现有数据
        if os.path.exists(webui_data_path):
            with open(webui_data_path, "r", encoding="utf-8") as f:
                webui_data = json.load(f)
        else:
            webui_data = {}

        # 更新路径
        webui_data["adapter_config_path"] = path

        # 保存
        os.makedirs("data", exist_ok=True)
        with open(webui_data_path, "w", encoding="utf-8") as f:
            json.dump(webui_data, f, ensure_ascii=False, indent=2)

        logger.info(f"适配器配置路径已保存: {path}")
        return {"success": True, "message": "路径已保存"}

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"保存适配器配置路径失败: {e}")
        raise HTTPException(status_code=500, detail=f"保存路径失败: {str(e)}")


@router.get("/adapter-config")
async def get_adapter_config(path: str):
    """从指定路径读取适配器配置文件"""
    try:
        if not path:
            raise HTTPException(status_code=400, detail="路径参数不能为空")

        # 检查文件是否存在
        if not os.path.exists(path):
            raise HTTPException(status_code=404, detail=f"配置文件不存在: {path}")

        # 检查文件扩展名
        if not path.endswith(".toml"):
            raise HTTPException(status_code=400, detail="只支持 .toml 格式的配置文件")

        # 读取文件内容
        with open(path, "r", encoding="utf-8") as f:
            content = f.read()

        logger.info(f"已读取适配器配置: {path}")
        return {"success": True, "content": content}

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"读取适配器配置失败: {e}")
        raise HTTPException(status_code=500, detail=f"读取配置失败: {str(e)}")


@router.post("/adapter-config")
async def save_adapter_config(data: dict[str, str] = Body(...)):
    """保存适配器配置到指定路径"""
    try:
        path = data.get("path")
        content = data.get("content")

        if not path:
            raise HTTPException(status_code=400, detail="路径不能为空")
        if content is None:
            raise HTTPException(status_code=400, detail="配置内容不能为空")

        # 检查文件扩展名
        if not path.endswith(".toml"):
            raise HTTPException(status_code=400, detail="只支持 .toml 格式的配置文件")

        # 验证 TOML 格式
        try:
            import toml

            toml.loads(content)
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}")

        # 确保目录存在
        os.makedirs(os.path.dirname(path) if os.path.dirname(path) else ".", exist_ok=True)

        # 保存文件
        with open(path, "w", encoding="utf-8") as f:
            f.write(content)

        logger.info(f"适配器配置已保存: {path}")
        return {"success": True, "message": "配置已保存"}

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"保存适配器配置失败: {e}")
        raise HTTPException(status_code=500, detail=f"保存配置失败: {str(e)}")
336
src/webui/config_schema.py
Normal file
@@ -0,0 +1,336 @@
"""
配置架构生成器 - 自动从配置类生成前端表单架构
"""

import inspect
from dataclasses import fields, MISSING
from typing import Any, get_origin, get_args, Literal, Optional
from enum import Enum

from src.config.config_base import ConfigBase


class FieldType(str, Enum):
    """字段类型枚举"""

    STRING = "string"
    NUMBER = "number"
    INTEGER = "integer"
    BOOLEAN = "boolean"
    SELECT = "select"
    ARRAY = "array"
    OBJECT = "object"
    TEXTAREA = "textarea"


class FieldSchema:
    """字段架构"""

    def __init__(
        self,
        name: str,
        type: FieldType,
        label: str,
        description: str = "",
        default: Any = None,
        required: bool = True,
        options: Optional[list[str]] = None,
        min_value: Optional[float] = None,
        max_value: Optional[float] = None,
        items: Optional[dict] = None,
        properties: Optional[dict] = None,
    ):
        self.name = name
        self.type = type
        self.label = label
        self.description = description
        self.default = default
        self.required = required
        self.options = options
        self.min_value = min_value
        self.max_value = max_value
        self.items = items
        self.properties = properties

    def to_dict(self) -> dict:
        """转换为字典"""
        result = {
            "name": self.name,
            "type": self.type.value,
            "label": self.label,
            "description": self.description,
            "required": self.required,
        }

        if self.default is not None:
            result["default"] = self.default

        if self.options is not None:
            result["options"] = self.options

        if self.min_value is not None:
            result["minValue"] = self.min_value

        if self.max_value is not None:
            result["maxValue"] = self.max_value

        if self.items is not None:
            result["items"] = self.items

        if self.properties is not None:
            result["properties"] = self.properties

        return result


class ConfigSchemaGenerator:
    """配置架构生成器"""

    @staticmethod
    def _extract_field_description(config_class: type, field_name: str) -> str:
        """
        从类定义中提取字段的文档字符串描述

        Args:
            config_class: 配置类
            field_name: 字段名

        Returns:
            str: 字段描述
        """
        try:
            # 获取源代码
            source = inspect.getsource(config_class)
            lines = source.split("\n")

            # 查找字段定义
            field_found = False
            description_lines = []

            for i, line in enumerate(lines):
                # 匹配字段定义行,例如: platform: str
                if f"{field_name}:" in line and "=" in line:
                    field_found = True
                    # 查找下一行的文档字符串
                    if i + 1 < len(lines):
                        next_line = lines[i + 1].strip()
                        if next_line.startswith('"""') or next_line.startswith("'''"):
                            # 单行文档字符串
                            if next_line.count('"""') == 2 or next_line.count("'''") == 2:
                                description_lines.append(next_line.strip('"""').strip("'''").strip())
                            else:
                                # 多行文档字符串
                                quote = '"""' if next_line.startswith('"""') else "'''"
                                description_lines.append(next_line.strip(quote).strip())
                                for j in range(i + 2, len(lines)):
                                    if quote in lines[j]:
                                        description_lines.append(lines[j].split(quote)[0].strip())
                                        break
                                    description_lines.append(lines[j].strip())
                    break
                elif f"{field_name}:" in line and "=" not in line:
                    # 没有默认值的字段
                    field_found = True
                    if i + 1 < len(lines):
                        next_line = lines[i + 1].strip()
                        if next_line.startswith('"""') or next_line.startswith("'''"):
                            if next_line.count('"""') == 2 or next_line.count("'''") == 2:
                                description_lines.append(next_line.strip('"""').strip("'''").strip())
                            else:
                                quote = '"""' if next_line.startswith('"""') else "'''"
                                description_lines.append(next_line.strip(quote).strip())
                                for j in range(i + 2, len(lines)):
                                    if quote in lines[j]:
                                        description_lines.append(lines[j].split(quote)[0].strip())
                                        break
                                    description_lines.append(lines[j].strip())
                    break

            if field_found and description_lines:
                return " ".join(description_lines)

        except Exception:
            pass

        return ""

    @staticmethod
    def _get_field_type_and_options(field_type: type) -> tuple[FieldType, Optional[list[str]], Optional[dict]]:
        """
        获取字段类型和选项

        Args:
            field_type: 字段类型

        Returns:
            tuple: (FieldType, options, items)
        """
        origin = get_origin(field_type)
        args = get_args(field_type)

        # 处理 Literal 类型(枚举选项)
        if origin is Literal:
            return FieldType.SELECT, [str(arg) for arg in args], None

        # 处理 list 类型
        if origin is list:
            item_type = args[0] if args else str
            if item_type is str:
                items = {"type": "string"}
            elif item_type is int:
                items = {"type": "integer"}
            elif item_type is float:
                items = {"type": "number"}
            elif item_type is bool:
                items = {"type": "boolean"}
            elif item_type is dict:
                items = {"type": "object"}
            else:
                items = {"type": "string"}
            return FieldType.ARRAY, None, items

        # 处理 set 类型(与 list 类似)
        if origin is set:
            items = {"type": "string"}
            return FieldType.ARRAY, None, items

        # 处理基本类型
        if field_type is bool:
            return FieldType.BOOLEAN, None, None
        elif field_type is int:
            return FieldType.INTEGER, None, None
        elif field_type is float:
            return FieldType.NUMBER, None, None
        elif field_type is str:
            return FieldType.STRING, None, None
        elif field_type is dict or origin is dict:
            return FieldType.OBJECT, None, None

        # 默认为字符串
        return FieldType.STRING, None, None
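The dispatch in `_get_field_type_and_options` leans on `typing.get_origin` and `typing.get_args` to classify annotations. A small self-contained demonstration of the same introspection order, with `classify` as an illustrative stand-in for the method above:

```python
from typing import Literal, get_args, get_origin

def classify(tp) -> str:
    # Same order as _get_field_type_and_options: Literal becomes a select
    # with its options, list/set become arrays, then plain builtins.
    origin = get_origin(tp)
    if origin is Literal:
        return "select:" + ",".join(str(a) for a in get_args(tp))
    if origin in (list, set):
        return "array"
    if tp is bool:
        return "boolean"
    if tp is int:
        return "integer"
    if tp is float:
        return "number"
    return "string"
```

Note that `get_origin(list[str])` returns the bare `list`, which is why the origin check must come before the builtin checks: the subscripted generic itself is never identical to `list`.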
    @staticmethod
    def _format_field_name(name: str) -> str:
        """
        格式化字段名为可读的标签

        Args:
            name: 原始字段名

        Returns:
            str: 格式化后的标签
        """
        # 将下划线替换为空格,并首字母大写
        return " ".join(word.capitalize() for word in name.split("_"))

    @staticmethod
    def generate_schema(config_class: type[ConfigBase], include_nested: bool = True) -> dict:
        """
        从配置类生成前端表单架构

        Args:
            config_class: 配置类(必须继承自 ConfigBase)
            include_nested: 是否包含嵌套的配置对象

        Returns:
            dict: 前端表单架构
        """
        if not issubclass(config_class, ConfigBase):
            raise ValueError(f"{config_class.__name__} 必须继承自 ConfigBase")

        schema_fields = []
        nested_schemas = {}

        for field in fields(config_class):
            # 跳过私有字段和内部字段
            if field.name.startswith("_") or field.name in ["MMC_VERSION"]:
                continue

            # 提取字段描述
            description = ConfigSchemaGenerator._extract_field_description(config_class, field.name)

            # 判断是否必填
            required = field.default is MISSING and field.default_factory is MISSING

            # 获取默认值
            default_value = None
            if field.default is not MISSING:
                default_value = field.default
            elif field.default_factory is not MISSING:
                try:
                    default_value = field.default_factory()
                except Exception:
                    default_value = None

            # 检查是否为嵌套的 ConfigBase
            if isinstance(field.type, type) and issubclass(field.type, ConfigBase):
                if include_nested:
                    # 递归生成嵌套配置的架构
                    nested_schema = ConfigSchemaGenerator.generate_schema(field.type, include_nested=True)
                    nested_schemas[field.name] = nested_schema

                    field_schema = FieldSchema(
                        name=field.name,
                        type=FieldType.OBJECT,
                        label=ConfigSchemaGenerator._format_field_name(field.name),
                        description=description or field.type.__doc__ or "",
                        default=default_value,
                        required=required,
                        properties=nested_schema,
                    )
                else:
                    continue
            else:
                # 获取字段类型和选项
                field_type, options, items = ConfigSchemaGenerator._get_field_type_and_options(field.type)

                # 特殊处理:长文本使用 textarea
                if field_type == FieldType.STRING and field.name in [
                    "personality",
                    "reply_style",
                    "interest",
                    "plan_style",
                    "visual_style",
                    "private_plan_style",
                    "emotion_style",
                    "reaction",
                    "filtration_prompt",
                ]:
                    field_type = FieldType.TEXTAREA

                field_schema = FieldSchema(
                    name=field.name,
                    type=field_type,
                    label=ConfigSchemaGenerator._format_field_name(field.name),
                    description=description,
                    default=default_value,
                    required=required,
                    options=options,
                    items=items,
                )

            schema_fields.append(field_schema.to_dict())

        return {
            "className": config_class.__name__,
            "classDoc": config_class.__doc__ or "",
            "fields": schema_fields,
            "nested": nested_schemas if nested_schemas else None,
        }

    @staticmethod
    def generate_config_schema(config_class: type[ConfigBase]) -> dict:
        """
        生成完整的配置架构(包含所有嵌套的子配置)

        Args:
            config_class: 配置类

        Returns:
            dict: 完整的配置架构
        """
        return ConfigSchemaGenerator.generate_schema(config_class, include_nested=True)
562
src/webui/emoji_routes.py
Normal file
@@ -0,0 +1,562 @@
"""Emoji management API routes"""

from fastapi import APIRouter, HTTPException, Header, Query
from fastapi.responses import FileResponse
from pydantic import BaseModel
from typing import Optional, List
from src.common.logger import get_logger
from src.common.database.database_model import Emoji
from .token_manager import get_token_manager
import json
import time
import os

logger = get_logger("webui.emoji")

# Create the router
router = APIRouter(prefix="/emoji", tags=["Emoji"])


class EmojiResponse(BaseModel):
    """Emoji response"""

    id: int
    full_path: str
    format: str
    emoji_hash: str
    description: str
    query_count: int
    is_registered: bool
    is_banned: bool
    emotion: Optional[str]  # returned directly as a string
    record_time: float
    register_time: Optional[float]
    usage_count: int
    last_used_time: Optional[float]


class EmojiListResponse(BaseModel):
    """Emoji list response"""

    success: bool
    total: int
    page: int
    page_size: int
    data: List[EmojiResponse]


class EmojiDetailResponse(BaseModel):
    """Emoji detail response"""

    success: bool
    data: EmojiResponse


class EmojiUpdateRequest(BaseModel):
    """Emoji update request"""

    description: Optional[str] = None
    is_registered: Optional[bool] = None
    is_banned: Optional[bool] = None
    emotion: Optional[str] = None


class EmojiUpdateResponse(BaseModel):
    """Emoji update response"""

    success: bool
    message: str
    data: Optional[EmojiResponse] = None


class EmojiDeleteResponse(BaseModel):
    """Emoji delete response"""

    success: bool
    message: str


class BatchDeleteRequest(BaseModel):
    """Batch delete request"""

    emoji_ids: List[int]


class BatchDeleteResponse(BaseModel):
    """Batch delete response"""

    success: bool
    message: str
    deleted_count: int
    failed_count: int
    failed_ids: List[int] = []


def verify_auth_token(authorization: Optional[str]) -> bool:
    """Verify the authentication token"""
    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="No valid authentication credentials provided")

    token = authorization.replace("Bearer ", "")
    token_manager = get_token_manager()

    if not token_manager.verify_token(token):
        raise HTTPException(status_code=401, detail="Token is invalid or expired")

    return True

def emoji_to_response(emoji: Emoji) -> EmojiResponse:
    """Convert an Emoji model into a response object"""
    return EmojiResponse(
        id=emoji.id,
        full_path=emoji.full_path,
        format=emoji.format,
        emoji_hash=emoji.emoji_hash,
        description=emoji.description,
        query_count=emoji.query_count,
        is_registered=emoji.is_registered,
        is_banned=emoji.is_banned,
        emotion=str(emoji.emotion) if emoji.emotion is not None else None,
        record_time=emoji.record_time,
        register_time=emoji.register_time,
        usage_count=emoji.usage_count,
        last_used_time=emoji.last_used_time,
    )


@router.get("/list", response_model=EmojiListResponse)
async def get_emoji_list(
    page: int = Query(1, ge=1, description="Page number"),
    page_size: int = Query(20, ge=1, le=100, description="Items per page"),
    search: Optional[str] = Query(None, description="Search keyword"),
    is_registered: Optional[bool] = Query(None, description="Filter by registered state"),
    is_banned: Optional[bool] = Query(None, description="Filter by banned state"),
    format: Optional[str] = Query(None, description="Filter by format"),
    authorization: Optional[str] = Header(None),
):
    """
    Get the emoji list

    Args:
        page: page number (starting from 1)
        page_size: items per page (1-100)
        search: search keyword (matches description, emoji_hash)
        is_registered: filter by registered state
        is_banned: filter by banned state
        format: filter by format
        authorization: Authorization header

    Returns:
        The emoji list
    """
    try:
        verify_auth_token(authorization)

        # Build the query
        query = Emoji.select()

        # Search filter
        if search:
            query = query.where((Emoji.description.contains(search)) | (Emoji.emoji_hash.contains(search)))

        # Registered-state filter
        if is_registered is not None:
            query = query.where(Emoji.is_registered == is_registered)

        # Banned-state filter
        if is_banned is not None:
            query = query.where(Emoji.is_banned == is_banned)

        # Format filter
        if format:
            query = query.where(Emoji.format == format)

        # Order by usage count descending, then record time descending
        from peewee import Case

        query = query.order_by(
            Emoji.usage_count.desc(), Case(None, [(Emoji.record_time.is_null(), 1)], 0), Emoji.record_time.desc()
        )

        # Total count
        total = query.count()

        # Pagination
        offset = (page - 1) * page_size
        emojis = query.offset(offset).limit(page_size)

        # Convert to response objects
        data = [emoji_to_response(emoji) for emoji in emojis]

        return EmojiListResponse(success=True, total=total, page=page, page_size=page_size, data=data)

    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Failed to get emoji list: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to get emoji list: {str(e)}") from e
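Reviewer note: the list endpoint converts a 1-based page number into a row offset before applying `offset`/`limit`. The same arithmetic in isolation, applied to a plain list instead of a peewee query:

```python
def page_slice(items: list, page: int, page_size: int) -> list:
    """Return the items for a 1-based page, mirroring offset/limit."""
    offset = (page - 1) * page_size  # page 1 starts at offset 0
    return items[offset:offset + page_size]


rows = list(range(45))  # 45 fake records, page_size 20 -> 3 pages
```

With 45 rows and a page size of 20, page 3 holds the trailing 5 rows and page 4 is empty, which matches how the database query behaves past the last page.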

@router.get("/{emoji_id}", response_model=EmojiDetailResponse)
async def get_emoji_detail(emoji_id: int, authorization: Optional[str] = Header(None)):
    """
    Get detailed information about an emoji

    Args:
        emoji_id: emoji ID
        authorization: Authorization header

    Returns:
        Emoji details
    """
    try:
        verify_auth_token(authorization)

        emoji = Emoji.get_or_none(Emoji.id == emoji_id)

        if not emoji:
            raise HTTPException(status_code=404, detail=f"No emoji found with ID {emoji_id}")

        return EmojiDetailResponse(success=True, data=emoji_to_response(emoji))

    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Failed to get emoji details: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to get emoji details: {str(e)}") from e


@router.patch("/{emoji_id}", response_model=EmojiUpdateResponse)
async def update_emoji(emoji_id: int, request: EmojiUpdateRequest, authorization: Optional[str] = Header(None)):
    """
    Incrementally update an emoji (only the provided fields)

    Args:
        emoji_id: emoji ID
        request: update request (containing only the fields to update)
        authorization: Authorization header

    Returns:
        Update result
    """
    try:
        verify_auth_token(authorization)

        emoji = Emoji.get_or_none(Emoji.id == emoji_id)

        if not emoji:
            raise HTTPException(status_code=404, detail=f"No emoji found with ID {emoji_id}")

        # Only update the provided fields
        update_data = request.model_dump(exclude_unset=True)

        if not update_data:
            raise HTTPException(status_code=400, detail="No fields to update were provided")

        # The emotion field is used as a plain string; no conversion needed

        # If the registered state flips from False to True, record the registration time
        if "is_registered" in update_data and update_data["is_registered"] and not emoji.is_registered:
            update_data["register_time"] = time.time()

        # Apply the update
        for field, value in update_data.items():
            setattr(emoji, field, value)

        emoji.save()

        logger.info(f"Emoji updated: ID={emoji_id}, fields: {list(update_data.keys())}")

        return EmojiUpdateResponse(
            success=True, message=f"Successfully updated {len(update_data)} field(s)", data=emoji_to_response(emoji)
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Failed to update emoji: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to update emoji: {str(e)}") from e


@router.delete("/{emoji_id}", response_model=EmojiDeleteResponse)
async def delete_emoji(emoji_id: int, authorization: Optional[str] = Header(None)):
    """
    Delete an emoji

    Args:
        emoji_id: emoji ID
        authorization: Authorization header

    Returns:
        Deletion result
    """
    try:
        verify_auth_token(authorization)

        emoji = Emoji.get_or_none(Emoji.id == emoji_id)

        if not emoji:
            raise HTTPException(status_code=404, detail=f"No emoji found with ID {emoji_id}")

        # Record deletion info
        emoji_hash = emoji.emoji_hash

        # Perform the deletion
        emoji.delete_instance()

        logger.info(f"Emoji deleted: ID={emoji_id}, hash={emoji_hash}")

        return EmojiDeleteResponse(success=True, message=f"Successfully deleted emoji: {emoji_hash}")

    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Failed to delete emoji: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to delete emoji: {str(e)}") from e


@router.get("/stats/summary")
async def get_emoji_stats(authorization: Optional[str] = Header(None)):
    """
    Get emoji statistics

    Args:
        authorization: Authorization header

    Returns:
        Statistics
    """
    try:
        verify_auth_token(authorization)

        total = Emoji.select().count()
        registered = Emoji.select().where(Emoji.is_registered).count()
        banned = Emoji.select().where(Emoji.is_banned).count()

        # Count by format
        formats = {}
        for emoji in Emoji.select(Emoji.format):
            fmt = emoji.format
            formats[fmt] = formats.get(fmt, 0) + 1

        # Most-used emojis (top 10)
        top_used = Emoji.select().order_by(Emoji.usage_count.desc()).limit(10)
        top_used_list = [
            {
                "id": emoji.id,
                "emoji_hash": emoji.emoji_hash,
                "description": emoji.description,
                "usage_count": emoji.usage_count,
            }
            for emoji in top_used
        ]

        return {
            "success": True,
            "data": {
                "total": total,
                "registered": registered,
                "banned": banned,
                "unregistered": total - registered,
                "formats": formats,
                "top_used": top_used_list,
            },
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Failed to get statistics: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to get statistics: {str(e)}") from e


@router.post("/{emoji_id}/register", response_model=EmojiUpdateResponse)
async def register_emoji(emoji_id: int, authorization: Optional[str] = Header(None)):
    """
    Register an emoji (shortcut operation)

    Args:
        emoji_id: emoji ID
        authorization: Authorization header

    Returns:
        Update result
    """
    try:
        verify_auth_token(authorization)

        emoji = Emoji.get_or_none(Emoji.id == emoji_id)

        if not emoji:
            raise HTTPException(status_code=404, detail=f"No emoji found with ID {emoji_id}")

        if emoji.is_registered:
            raise HTTPException(status_code=400, detail="This emoji is already registered")

        if emoji.is_banned:
            raise HTTPException(status_code=400, detail="This emoji is banned and cannot be registered")

        # Register the emoji
        emoji.is_registered = True
        emoji.register_time = time.time()
        emoji.save()

        logger.info(f"Emoji registered: ID={emoji_id}")

        return EmojiUpdateResponse(success=True, message="Emoji registered successfully", data=emoji_to_response(emoji))

    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Failed to register emoji: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to register emoji: {str(e)}") from e


@router.post("/{emoji_id}/ban", response_model=EmojiUpdateResponse)
async def ban_emoji(emoji_id: int, authorization: Optional[str] = Header(None)):
    """
    Ban an emoji (shortcut operation)

    Args:
        emoji_id: emoji ID
        authorization: Authorization header

    Returns:
        Update result
    """
    try:
        verify_auth_token(authorization)

        emoji = Emoji.get_or_none(Emoji.id == emoji_id)

        if not emoji:
            raise HTTPException(status_code=404, detail=f"No emoji found with ID {emoji_id}")

        # Ban the emoji (and unregister it at the same time)
        emoji.is_banned = True
        emoji.is_registered = False
        emoji.save()

        logger.info(f"Emoji banned: ID={emoji_id}")

        return EmojiUpdateResponse(success=True, message="Emoji banned successfully", data=emoji_to_response(emoji))

    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Failed to ban emoji: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to ban emoji: {str(e)}") from e


@router.get("/{emoji_id}/thumbnail")
async def get_emoji_thumbnail(
    emoji_id: int,
    token: Optional[str] = Query(None, description="Access token"),
    authorization: Optional[str] = Header(None),
):
    """
    Get an emoji thumbnail

    Args:
        emoji_id: emoji ID
        token: access token (via query parameter)
        authorization: Authorization header

    Returns:
        The emoji image file
    """
    try:
        # Prefer the token from the query parameter (for use in img tags)
        if token:
            token_manager = get_token_manager()
            if not token_manager.verify_token(token):
                raise HTTPException(status_code=401, detail="Token is invalid or expired")
        else:
            # Without a query token, validate the Authorization header
            verify_auth_token(authorization)

        emoji = Emoji.get_or_none(Emoji.id == emoji_id)

        if not emoji:
            raise HTTPException(status_code=404, detail=f"No emoji found with ID {emoji_id}")

        # Check that the file exists
        if not os.path.exists(emoji.full_path):
            raise HTTPException(status_code=404, detail="Emoji file does not exist")

        # Choose the MIME type based on the format
        mime_types = {
            "png": "image/png",
            "jpg": "image/jpeg",
            "jpeg": "image/jpeg",
            "gif": "image/gif",
            "webp": "image/webp",
            "bmp": "image/bmp",
        }

        media_type = mime_types.get(emoji.format.lower(), "application/octet-stream")

        return FileResponse(path=emoji.full_path, media_type=media_type, filename=f"{emoji.emoji_hash}.{emoji.format}")

    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Failed to get emoji thumbnail: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to get emoji thumbnail: {str(e)}") from e
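Reviewer note: the thumbnail route maps the stored format to a MIME type, falling back to `application/octet-stream` for anything it does not recognize. The same lookup in isolation:

```python
MIME_TYPES = {
    "png": "image/png",
    "jpg": "image/jpeg",
    "jpeg": "image/jpeg",
    "gif": "image/gif",
    "webp": "image/webp",
    "bmp": "image/bmp",
}


def media_type_for(fmt: str) -> str:
    """Map a stored format string to a MIME type, case-insensitively."""
    return MIME_TYPES.get(fmt.lower(), "application/octet-stream")
```

The lowercase normalization matters because the `format` column may hold values like `PNG` depending on how the file was recorded.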

@router.post("/batch/delete", response_model=BatchDeleteResponse)
async def batch_delete_emojis(request: BatchDeleteRequest, authorization: Optional[str] = Header(None)):
    """
    Batch-delete emojis

    Args:
        request: request containing the emoji_ids list
        authorization: Authorization header

    Returns:
        Batch deletion result
    """
    try:
        verify_auth_token(authorization)

        if not request.emoji_ids:
            raise HTTPException(status_code=400, detail="No emoji IDs to delete were provided")

        deleted_count = 0
        failed_count = 0
        failed_ids = []

        for emoji_id in request.emoji_ids:
            try:
                emoji = Emoji.get_or_none(Emoji.id == emoji_id)
                if emoji:
                    emoji.delete_instance()
                    deleted_count += 1
                    logger.info(f"Batch-deleted emoji: {emoji_id}")
                else:
                    failed_count += 1
                    failed_ids.append(emoji_id)
            except Exception as e:
                logger.error(f"Failed to delete emoji {emoji_id}: {e}")
                failed_count += 1
                failed_ids.append(emoji_id)

        message = f"Successfully deleted {deleted_count} emoji(s)"
        if failed_count > 0:
            message += f", {failed_count} failed"

        return BatchDeleteResponse(
            success=True,
            message=message,
            deleted_count=deleted_count,
            failed_count=failed_count,
            failed_ids=failed_ids,
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Failed to batch-delete emojis: {e}")
        raise HTTPException(status_code=500, detail=f"Batch deletion failed: {str(e)}") from e
432
src/webui/expression_routes.py
Normal file
@@ -0,0 +1,432 @@
"""Expression management API routes"""

from fastapi import APIRouter, HTTPException, Header, Query
from pydantic import BaseModel
from typing import Optional, List
from src.common.logger import get_logger
from src.common.database.database_model import Expression
from .token_manager import get_token_manager
import time

logger = get_logger("webui.expression")

# Create the router
router = APIRouter(prefix="/expression", tags=["Expression"])


class ExpressionResponse(BaseModel):
    """Expression response"""

    id: int
    situation: str
    style: str
    context: Optional[str]
    up_content: Optional[str]
    last_active_time: float
    chat_id: str
    create_date: Optional[float]


class ExpressionListResponse(BaseModel):
    """Expression list response"""

    success: bool
    total: int
    page: int
    page_size: int
    data: List[ExpressionResponse]


class ExpressionDetailResponse(BaseModel):
    """Expression detail response"""

    success: bool
    data: ExpressionResponse


class ExpressionCreateRequest(BaseModel):
    """Expression create request"""

    situation: str
    style: str
    context: Optional[str] = None
    up_content: Optional[str] = None
    chat_id: str


class ExpressionUpdateRequest(BaseModel):
    """Expression update request"""

    situation: Optional[str] = None
    style: Optional[str] = None
    context: Optional[str] = None
    up_content: Optional[str] = None
    chat_id: Optional[str] = None


class ExpressionUpdateResponse(BaseModel):
    """Expression update response"""

    success: bool
    message: str
    data: Optional[ExpressionResponse] = None


class ExpressionDeleteResponse(BaseModel):
    """Expression delete response"""

    success: bool
    message: str


class ExpressionCreateResponse(BaseModel):
    """Expression create response"""

    success: bool
    message: str
    data: ExpressionResponse


def verify_auth_token(authorization: Optional[str]) -> bool:
    """Verify the authentication token"""
    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="No valid authentication credentials provided")

    token = authorization.replace("Bearer ", "")
    token_manager = get_token_manager()

    if not token_manager.verify_token(token):
        raise HTTPException(status_code=401, detail="Token is invalid or expired")

    return True


def expression_to_response(expression: Expression) -> ExpressionResponse:
    """Convert an Expression model into a response object"""
    return ExpressionResponse(
        id=expression.id,
        situation=expression.situation,
        style=expression.style,
        context=expression.context,
        up_content=expression.up_content,
        last_active_time=expression.last_active_time,
        chat_id=expression.chat_id,
        create_date=expression.create_date,
    )


@router.get("/list", response_model=ExpressionListResponse)
async def get_expression_list(
    page: int = Query(1, ge=1, description="Page number"),
    page_size: int = Query(20, ge=1, le=100, description="Items per page"),
    search: Optional[str] = Query(None, description="Search keyword"),
    chat_id: Optional[str] = Query(None, description="Filter by chat ID"),
    authorization: Optional[str] = Header(None),
):
    """
    Get the expression list

    Args:
        page: page number (starting from 1)
        page_size: items per page (1-100)
        search: search keyword (matches situation, style, context)
        chat_id: filter by chat ID
        authorization: Authorization header

    Returns:
        The expression list
    """
    try:
        verify_auth_token(authorization)

        # Build the query
        query = Expression.select()

        # Search filter
        if search:
            query = query.where(
                (Expression.situation.contains(search))
                | (Expression.style.contains(search))
                | (Expression.context.contains(search))
            )

        # Chat-ID filter
        if chat_id:
            query = query.where(Expression.chat_id == chat_id)

        # Order by last active time descending (NULL values last)
        from peewee import Case

        query = query.order_by(
            Case(None, [(Expression.last_active_time.is_null(), 1)], 0), Expression.last_active_time.desc()
        )

        # Total count
        total = query.count()

        # Pagination
        offset = (page - 1) * page_size
        expressions = query.offset(offset).limit(page_size)

        # Convert to response objects
        data = [expression_to_response(expr) for expr in expressions]

        return ExpressionListResponse(success=True, total=total, page=page, page_size=page_size, data=data)

    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Failed to get expression list: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to get expression list: {str(e)}") from e


@router.get("/{expression_id}", response_model=ExpressionDetailResponse)
async def get_expression_detail(expression_id: int, authorization: Optional[str] = Header(None)):
    """
    Get detailed information about an expression

    Args:
        expression_id: expression ID
        authorization: Authorization header

    Returns:
        Expression details
    """
    try:
        verify_auth_token(authorization)

        expression = Expression.get_or_none(Expression.id == expression_id)

        if not expression:
            raise HTTPException(status_code=404, detail=f"No expression found with ID {expression_id}")

        return ExpressionDetailResponse(success=True, data=expression_to_response(expression))

    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Failed to get expression details: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to get expression details: {str(e)}") from e


@router.post("/", response_model=ExpressionCreateResponse)
async def create_expression(request: ExpressionCreateRequest, authorization: Optional[str] = Header(None)):
    """
    Create a new expression

    Args:
        request: create request
        authorization: Authorization header

    Returns:
        Creation result
    """
    try:
        verify_auth_token(authorization)

        current_time = time.time()

        # Create the expression
        expression = Expression.create(
            situation=request.situation,
            style=request.style,
            context=request.context,
            up_content=request.up_content,
            chat_id=request.chat_id,
            last_active_time=current_time,
            create_date=current_time,
        )

        logger.info(f"Expression created: ID={expression.id}, situation={request.situation}")

        return ExpressionCreateResponse(
            success=True, message="Expression created successfully", data=expression_to_response(expression)
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Failed to create expression: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to create expression: {str(e)}") from e


@router.patch("/{expression_id}", response_model=ExpressionUpdateResponse)
async def update_expression(
    expression_id: int, request: ExpressionUpdateRequest, authorization: Optional[str] = Header(None)
):
    """
    Incrementally update an expression (only the provided fields)

    Args:
        expression_id: expression ID
        request: update request (containing only the fields to update)
        authorization: Authorization header

    Returns:
        Update result
    """
    try:
        verify_auth_token(authorization)

        expression = Expression.get_or_none(Expression.id == expression_id)

        if not expression:
            raise HTTPException(status_code=404, detail=f"No expression found with ID {expression_id}")

        # Only update the provided fields
        update_data = request.model_dump(exclude_unset=True)

        if not update_data:
            raise HTTPException(status_code=400, detail="No fields to update were provided")

        # Refresh the last active time
        update_data["last_active_time"] = time.time()

        # Apply the update
        for field, value in update_data.items():
            setattr(expression, field, value)

        expression.save()

        logger.info(f"Expression updated: ID={expression_id}, fields: {list(update_data.keys())}")

        return ExpressionUpdateResponse(
            success=True, message=f"Successfully updated {len(update_data)} field(s)", data=expression_to_response(expression)
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Failed to update expression: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to update expression: {str(e)}") from e
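Reviewer note: both PATCH handlers rely on pydantic's `model_dump(exclude_unset=True)` so that only fields the client actually sent are applied. The same partial-update semantics sketched with plain dicts (the fixed timestamp stands in for `time.time()`):

```python
def apply_partial_update(record: dict, update_data: dict) -> dict:
    """Apply only the provided fields, refreshing the activity timestamp."""
    if not update_data:
        raise ValueError("no fields to update")  # mirrors the 400 response
    updated = dict(record)  # leave the original record untouched
    updated.update(update_data)
    updated["last_active_time"] = 12345.0  # stand-in for time.time()
    return updated


expr = {"situation": "greeting", "style": "casual", "last_active_time": 0.0}
patched = apply_partial_update(expr, {"style": "formal"})
```

The key property is that an omitted field (here `situation`) is preserved rather than overwritten with `None`, which is exactly what `exclude_unset=True` buys over a plain dump.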

@router.delete("/{expression_id}", response_model=ExpressionDeleteResponse)
async def delete_expression(expression_id: int, authorization: Optional[str] = Header(None)):
    """
    Delete an expression

    Args:
        expression_id: expression ID
        authorization: Authorization header

    Returns:
        Deletion result
    """
    try:
        verify_auth_token(authorization)

        expression = Expression.get_or_none(Expression.id == expression_id)

        if not expression:
            raise HTTPException(status_code=404, detail=f"No expression found with ID {expression_id}")

        # Record deletion info
        situation = expression.situation

        # Perform the deletion
        expression.delete_instance()

        logger.info(f"Expression deleted: ID={expression_id}, situation={situation}")

        return ExpressionDeleteResponse(success=True, message=f"Successfully deleted expression: {situation}")

    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Failed to delete expression: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to delete expression: {str(e)}") from e


class BatchDeleteRequest(BaseModel):
    """Batch delete request"""

    ids: List[int]


@router.post("/batch/delete", response_model=ExpressionDeleteResponse)
async def batch_delete_expressions(request: BatchDeleteRequest, authorization: Optional[str] = Header(None)):
    """
    Batch-delete expressions

    Args:
        request: request containing the list of IDs to delete
        authorization: Authorization header

    Returns:
        Deletion result
    """
    try:
        verify_auth_token(authorization)

        if not request.ids:
            raise HTTPException(status_code=400, detail="No expression IDs to delete were provided")

        # Find all expressions to delete
        expressions = Expression.select().where(Expression.id.in_(request.ids))
        found_ids = [expr.id for expr in expressions]

        # Check for IDs that were not found
        not_found_ids = set(request.ids) - set(found_ids)
        if not_found_ids:
            logger.warning(f"Some expressions were not found: {not_found_ids}")

        # Perform the batch deletion
        deleted_count = Expression.delete().where(Expression.id.in_(found_ids)).execute()

        logger.info(f"Batch-deleted {deleted_count} expression(s)")

        return ExpressionDeleteResponse(success=True, message=f"Successfully deleted {deleted_count} expression(s)")

    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Failed to batch-delete expressions: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to batch-delete expressions: {str(e)}") from e


@router.get("/stats/summary")
async def get_expression_stats(authorization: Optional[str] = Header(None)):
    """
    Get expression statistics

    Args:
        authorization: Authorization header

    Returns:
        Statistics
    """
    try:
        verify_auth_token(authorization)

        total = Expression.select().count()

        # Count by chat_id
        chat_stats = {}
        for expr in Expression.select(Expression.chat_id):
            chat_id = expr.chat_id
            chat_stats[chat_id] = chat_stats.get(chat_id, 0) + 1

        # Records created within the last 7 days
        seven_days_ago = time.time() - (7 * 24 * 60 * 60)
        recent = (
            Expression.select()
            .where((Expression.create_date.is_null(False)) & (Expression.create_date >= seven_days_ago))
            .count()
        )

        return {
            "success": True,
            "data": {
                "total": total,
                "recent_7days": recent,
                "chat_count": len(chat_stats),
                "top_chats": dict(sorted(chat_stats.items(), key=lambda x: x[1], reverse=True)[:10]),
            },
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Failed to get statistics: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to get statistics: {str(e)}") from e
662
src/webui/git_mirror_service.py
Normal file
@@ -0,0 +1,662 @@
"""Git mirror service - supports multiple mirrors, retry on error, Git clone, and raw file fetching"""

from typing import Optional, List, Dict, Any
from enum import Enum
import httpx
import json
import asyncio
import subprocess
import shutil
from pathlib import Path
from datetime import datetime
from src.common.logger import get_logger

logger = get_logger("webui.git_mirror")

# Progress-update callback, injected to avoid a circular import
_update_progress = None


def set_update_progress_callback(callback):
    """Set the progress-update callback"""
    global _update_progress
    _update_progress = callback


class MirrorType(str, Enum):
    """Mirror type"""

    GH_PROXY = "gh-proxy"  # gh-proxy main node
    HK_GH_PROXY = "hk-gh-proxy"  # gh-proxy Hong Kong node
    CDN_GH_PROXY = "cdn-gh-proxy"  # gh-proxy CDN node
    EDGEONE_GH_PROXY = "edgeone-gh-proxy"  # gh-proxy EdgeOne node
    MEYZH_GITHUB = "meyzh-github"  # Meyzh GitHub mirror
    GITHUB = "github"  # official GitHub source (fallback)
    CUSTOM = "custom"  # custom mirror


class GitMirrorConfig:
    """Git mirror configuration manager"""

    # Config file path
    CONFIG_FILE = Path("data/webui.json")

    # Default mirror configuration
    DEFAULT_MIRRORS = [
        {
            "id": "gh-proxy",
            "name": "gh-proxy mirror",
            "raw_prefix": "https://gh-proxy.org/https://raw.githubusercontent.com",
            "clone_prefix": "https://gh-proxy.org/https://github.com",
            "enabled": True,
            "priority": 1,
            "created_at": None,
        },
        {
            "id": "hk-gh-proxy",
            "name": "gh-proxy Hong Kong node",
            "raw_prefix": "https://hk.gh-proxy.org/https://raw.githubusercontent.com",
            "clone_prefix": "https://hk.gh-proxy.org/https://github.com",
            "enabled": True,
            "priority": 2,
            "created_at": None,
        },
        {
            "id": "cdn-gh-proxy",
            "name": "gh-proxy CDN node",
            "raw_prefix": "https://cdn.gh-proxy.org/https://raw.githubusercontent.com",
            "clone_prefix": "https://cdn.gh-proxy.org/https://github.com",
            "enabled": True,
            "priority": 3,
            "created_at": None,
        },
        {
            "id": "edgeone-gh-proxy",
            "name": "gh-proxy EdgeOne node",
            "raw_prefix": "https://edgeone.gh-proxy.org/https://raw.githubusercontent.com",
            "clone_prefix": "https://edgeone.gh-proxy.org/https://github.com",
            "enabled": True,
            "priority": 4,
            "created_at": None,
        },
        {
            "id": "meyzh-github",
            "name": "Meyzh GitHub mirror",
            "raw_prefix": "https://meyzh.github.io/https://raw.githubusercontent.com",
            "clone_prefix": "https://meyzh.github.io/https://github.com",
            "enabled": True,
            "priority": 5,
            "created_at": None,
        },
        {
            "id": "github",
            "name": "Official GitHub source (fallback)",
            "raw_prefix": "https://raw.githubusercontent.com",
            "clone_prefix": "https://github.com",
            "enabled": True,
            "priority": 999,
            "created_at": None,
        },
    ]
||||
def __init__(self):
|
||||
"""初始化配置管理器"""
|
||||
self.config_file = self.CONFIG_FILE
|
||||
self.mirrors: List[Dict[str, Any]] = []
|
||||
self._load_config()
|
||||
|
||||
def _load_config(self) -> None:
|
||||
"""加载配置文件"""
|
||||
try:
|
||||
if self.config_file.exists():
|
||||
with open(self.config_file, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
# 检查是否有镜像源配置
|
||||
if "git_mirrors" not in data or not data["git_mirrors"]:
|
||||
logger.info("配置文件中未找到镜像源配置,使用默认配置")
|
||||
self._init_default_mirrors()
|
||||
else:
|
||||
self.mirrors = data["git_mirrors"]
|
||||
logger.info(f"已加载 {len(self.mirrors)} 个镜像源配置")
|
||||
else:
|
||||
logger.info("配置文件不存在,创建默认配置")
|
||||
self._init_default_mirrors()
|
||||
except Exception as e:
|
||||
logger.error(f"加载配置文件失败: {e}")
|
||||
self._init_default_mirrors()
|
||||
|
||||
def _init_default_mirrors(self) -> None:
|
||||
"""初始化默认镜像源"""
|
||||
current_time = datetime.now().isoformat()
|
||||
self.mirrors = []
|
||||
|
||||
for mirror in self.DEFAULT_MIRRORS:
|
||||
mirror_copy = mirror.copy()
|
||||
mirror_copy["created_at"] = current_time
|
||||
self.mirrors.append(mirror_copy)
|
||||
|
||||
self._save_config()
|
||||
logger.info(f"已初始化 {len(self.mirrors)} 个默认镜像源")
|
||||
|
||||
def _save_config(self) -> None:
|
||||
"""保存配置到文件"""
|
||||
try:
|
||||
# 确保目录存在
|
||||
self.config_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 读取现有配置
|
||||
existing_data = {}
|
||||
if self.config_file.exists():
|
||||
with open(self.config_file, "r", encoding="utf-8") as f:
|
||||
existing_data = json.load(f)
|
||||
|
||||
# 更新镜像源配置
|
||||
existing_data["git_mirrors"] = self.mirrors
|
||||
|
||||
# 写入文件
|
||||
with open(self.config_file, "w", encoding="utf-8") as f:
|
||||
json.dump(existing_data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
logger.debug(f"配置已保存到 {self.config_file}")
|
||||
except Exception as e:
|
||||
logger.error(f"保存配置文件失败: {e}")
|
||||
|
||||
def get_all_mirrors(self) -> List[Dict[str, Any]]:
|
||||
"""获取所有镜像源"""
|
||||
return self.mirrors.copy()
|
||||
|
||||
def get_enabled_mirrors(self) -> List[Dict[str, Any]]:
|
||||
"""获取所有启用的镜像源,按优先级排序"""
|
||||
enabled = [m for m in self.mirrors if m.get("enabled", False)]
|
||||
return sorted(enabled, key=lambda x: x.get("priority", 999))
|
||||
|
||||
def get_mirror_by_id(self, mirror_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""根据 ID 获取镜像源"""
|
||||
for mirror in self.mirrors:
|
||||
if mirror.get("id") == mirror_id:
|
||||
return mirror.copy()
|
||||
return None
|
||||
|
||||
def add_mirror(
|
||||
self,
|
||||
mirror_id: str,
|
||||
name: str,
|
||||
raw_prefix: str,
|
||||
clone_prefix: str,
|
||||
enabled: bool = True,
|
||||
priority: Optional[int] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
添加新的镜像源
|
||||
|
||||
Returns:
|
||||
添加的镜像源配置
|
||||
|
||||
Raises:
|
||||
ValueError: 如果镜像源 ID 已存在
|
||||
"""
|
||||
# 检查 ID 是否已存在
|
||||
if self.get_mirror_by_id(mirror_id):
|
||||
raise ValueError(f"镜像源 ID 已存在: {mirror_id}")
|
||||
|
||||
# 如果未指定优先级,使用最大优先级 + 1
|
||||
if priority is None:
|
||||
max_priority = max((m.get("priority", 0) for m in self.mirrors), default=0)
|
||||
priority = max_priority + 1
|
||||
|
||||
new_mirror = {
|
||||
"id": mirror_id,
|
||||
"name": name,
|
||||
"raw_prefix": raw_prefix,
|
||||
"clone_prefix": clone_prefix,
|
||||
"enabled": enabled,
|
||||
"priority": priority,
|
||||
"created_at": datetime.now().isoformat(),
|
||||
}
|
||||
|
||||
self.mirrors.append(new_mirror)
|
||||
self._save_config()
|
||||
|
||||
logger.info(f"已添加镜像源: {mirror_id} - {name}")
|
||||
return new_mirror.copy()
|
||||
|
||||
def update_mirror(
|
||||
self,
|
||||
mirror_id: str,
|
||||
name: Optional[str] = None,
|
||||
raw_prefix: Optional[str] = None,
|
||||
clone_prefix: Optional[str] = None,
|
||||
enabled: Optional[bool] = None,
|
||||
priority: Optional[int] = None,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
更新镜像源配置
|
||||
|
||||
Returns:
|
||||
更新后的镜像源配置,如果不存在则返回 None
|
||||
"""
|
||||
for mirror in self.mirrors:
|
||||
if mirror.get("id") == mirror_id:
|
||||
if name is not None:
|
||||
mirror["name"] = name
|
||||
if raw_prefix is not None:
|
||||
mirror["raw_prefix"] = raw_prefix
|
||||
if clone_prefix is not None:
|
||||
mirror["clone_prefix"] = clone_prefix
|
||||
if enabled is not None:
|
||||
mirror["enabled"] = enabled
|
||||
if priority is not None:
|
||||
mirror["priority"] = priority
|
||||
|
||||
mirror["updated_at"] = datetime.now().isoformat()
|
||||
self._save_config()
|
||||
|
||||
logger.info(f"已更新镜像源: {mirror_id}")
|
||||
return mirror.copy()
|
||||
|
||||
return None
|
||||
|
||||
def delete_mirror(self, mirror_id: str) -> bool:
|
||||
"""
|
||||
删除镜像源
|
||||
|
||||
Returns:
|
||||
True 如果删除成功,False 如果镜像源不存在
|
||||
"""
|
||||
for i, mirror in enumerate(self.mirrors):
|
||||
if mirror.get("id") == mirror_id:
|
||||
self.mirrors.pop(i)
|
||||
self._save_config()
|
||||
logger.info(f"已删除镜像源: {mirror_id}")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def get_default_priority_list(self) -> List[str]:
|
||||
"""获取默认优先级列表(仅启用的镜像源 ID)"""
|
||||
enabled = self.get_enabled_mirrors()
|
||||
return [m["id"] for m in enabled]
|
||||
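The enabled/priority selection in `get_enabled_mirrors` is the core of the fallback order, so here is a minimal standalone sketch of that logic with hypothetical mirror data (no file I/O, no class):

```python
# Minimal sketch of the priority-based selection used by
# GitMirrorConfig.get_enabled_mirrors(). The mirror entries below are
# hypothetical stand-ins for the real configuration.
mirrors = [
    {"id": "github", "enabled": True, "priority": 999},
    {"id": "gh-proxy", "enabled": True, "priority": 1},
    {"id": "disabled-one", "enabled": False, "priority": 0},
]

def get_enabled_mirrors(mirrors):
    # Keep only enabled entries, then sort by ascending priority;
    # a missing priority sorts last (default 999), matching the fallback entry.
    enabled = [m for m in mirrors if m.get("enabled", False)]
    return sorted(enabled, key=lambda x: x.get("priority", 999))

print([m["id"] for m in get_enabled_mirrors(mirrors)])
# -> ['gh-proxy', 'github']
```

Disabled entries never appear in the try order, which is why the official GitHub entry is kept enabled with priority 999: it always sorts last and acts as the fallback.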


class GitMirrorService:
    """Git mirror source service."""

    def __init__(self, max_retries: int = 3, timeout: int = 30, config: Optional[GitMirrorConfig] = None):
        """
        Initialize the Git mirror source service.

        Args:
            max_retries: Maximum number of retries.
            timeout: Request timeout in seconds.
            config: Mirror configuration manager (optional; a new instance is created by default).
        """
        self.max_retries = max_retries
        self.timeout = timeout
        self.config = config or GitMirrorConfig()
        logger.info(f"Git mirror service initialized; loaded {len(self.config.get_enabled_mirrors())} enabled mirror sources")

    def get_mirror_config(self) -> GitMirrorConfig:
        """Return the mirror configuration manager."""
        return self.config

    @staticmethod
    def check_git_installed() -> Dict[str, Any]:
        """
        Check whether Git is installed on this machine.

        Returns:
            Dict containing:
            - installed: bool - whether Git is installed
            - version: str - Git version (if installed)
            - path: str - path to the Git executable (if installed)
            - error: str - error message (if not installed or detection failed)
        """
        import subprocess
        import shutil

        try:
            # Locate the git executable
            git_path = shutil.which("git")

            if not git_path:
                logger.warning("Git executable not found")
                return {"installed": False, "error": "Git not found on this system; please install Git first"}

            # Query the Git version
            result = subprocess.run(["git", "--version"], capture_output=True, text=True, timeout=5)

            if result.returncode == 0:
                version = result.stdout.strip()
                logger.info(f"Detected Git: {version} at {git_path}")
                return {"installed": True, "version": version, "path": git_path}
            else:
                logger.warning(f"Git command failed: {result.stderr}")
                return {"installed": False, "error": f"Git command failed: {result.stderr}"}

        except subprocess.TimeoutExpired:
            logger.error("Git version check timed out")
            return {"installed": False, "error": "Git version check timed out"}
        except Exception as e:
            logger.error(f"Error while detecting Git: {e}")
            return {"installed": False, "error": f"Error while detecting Git: {str(e)}"}
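The detection flow above (locate the executable first, only then run it) can be exercised on its own; this is a minimal sketch of that same two-step check, and its result naturally depends on whether Git is present on the machine running it:

```python
import shutil
import subprocess

# Sketch of the check_git_installed flow: shutil.which avoids a
# FileNotFoundError from subprocess when git is absent, so the error
# branch is taken cleanly instead of raising.
git_path = shutil.which("git")
if git_path is None:
    info = {"installed": False, "error": "git not on PATH"}
else:
    out = subprocess.run(["git", "--version"], capture_output=True, text=True, timeout=5)
    info = {"installed": out.returncode == 0, "version": out.stdout.strip(), "path": git_path}

print(info)
```

Running the version probe with a short timeout keeps a hung `git` binary from blocking the caller.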

    async def fetch_raw_file(
        self,
        owner: str,
        repo: str,
        branch: str,
        file_path: str,
        mirror_id: Optional[str] = None,
        custom_url: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Fetch the raw content of a file from a GitHub repository.

        Args:
            owner: Repository owner.
            repo: Repository name.
            branch: Branch name.
            file_path: Path to the file.
            mirror_id: ID of a specific mirror source to use.
            custom_url: Custom full URL (if given, the other parameters are ignored).

        Returns:
            Dict containing:
            - success: bool - whether the fetch succeeded
            - data: str - file content (on success)
            - error: str - error message (on failure)
            - mirror_used: str - the mirror source that was used
            - attempts: int - number of attempts
        """
        logger.info(f"Fetching raw file: {owner}/{repo}/{branch}/{file_path}")

        if custom_url:
            # Use the custom URL
            return await self._fetch_with_url(custom_url, "custom")

        # Determine which mirror sources to try
        if mirror_id:
            # Use the specified mirror source
            mirror = self.config.get_mirror_by_id(mirror_id)
            if not mirror:
                return {"success": False, "error": f"Mirror source not found: {mirror_id}", "mirror_used": None, "attempts": 0}
            mirrors_to_try = [mirror]
        else:
            # Use all enabled mirror sources
            mirrors_to_try = self.config.get_enabled_mirrors()

        total_mirrors = len(mirrors_to_try)

        # Try each mirror source in turn
        for index, mirror in enumerate(mirrors_to_try, 1):
            # Push progress: trying mirror N of M
            if _update_progress:
                try:
                    progress = 30 + int((index - 1) / total_mirrors * 40)  # 30% - 70%
                    await _update_progress(
                        stage="loading",
                        progress=progress,
                        message=f"Trying mirror source {index}/{total_mirrors}: {mirror['name']}",
                        total_plugins=0,
                        loaded_plugins=0,
                    )
                except Exception as e:
                    logger.warning(f"Failed to push progress: {e}")

            result = await self._fetch_raw_from_mirror(owner, repo, branch, file_path, mirror)

            if result["success"]:
                # Success: push progress
                if _update_progress:
                    try:
                        await _update_progress(
                            stage="loading",
                            progress=70,
                            message=f"Fetched data from {mirror['name']}",
                            total_plugins=0,
                            loaded_plugins=0,
                        )
                    except Exception as e:
                        logger.warning(f"Failed to push progress: {e}")
                return result

            # Failure: log it and push the failure status
            logger.warning(f"Mirror source {mirror['id']} failed: {result.get('error')}")

            if _update_progress and index < total_mirrors:
                try:
                    await _update_progress(
                        stage="loading",
                        progress=30 + int(index / total_mirrors * 40),
                        message=f"Mirror source {mirror['name']} failed; trying the next one...",
                        total_plugins=0,
                        loaded_plugins=0,
                    )
                except Exception as e:
                    logger.warning(f"Failed to push progress: {e}")

        # All mirror sources failed
        return {"success": False, "error": "All mirror sources failed", "mirror_used": None, "attempts": len(mirrors_to_try)}

    async def _fetch_raw_from_mirror(
        self, owner: str, repo: str, branch: str, file_path: str, mirror: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Fetch a file from the given mirror source."""
        # Build the URL
        raw_prefix = mirror["raw_prefix"]
        url = f"{raw_prefix}/{owner}/{repo}/{branch}/{file_path}"

        return await self._fetch_with_url(url, mirror["id"])

    async def _fetch_with_url(self, url: str, mirror_type: str) -> Dict[str, Any]:
        """Fetch a file from the given URL, with retries."""
        attempts = 0
        last_error = None

        for attempt in range(self.max_retries):
            attempts += 1
            try:
                logger.debug(f"Attempt #{attempt + 1}: {url}")
                async with httpx.AsyncClient(timeout=self.timeout) as client:
                    response = await client.get(url)
                    response.raise_for_status()

                    logger.info(f"Fetched file: {url}")
                    return {
                        "success": True,
                        "data": response.text,
                        "mirror_used": mirror_type,
                        "attempts": attempts,
                        "url": url,
                    }
            except httpx.HTTPStatusError as e:
                last_error = f"HTTP {e.response.status_code}: {e}"
                logger.warning(f"HTTP error (attempt {attempt + 1}/{self.max_retries}): {last_error}")
            except httpx.TimeoutException as e:
                last_error = f"Request timed out: {e}"
                logger.warning(f"Timeout (attempt {attempt + 1}/{self.max_retries}): {last_error}")
            except Exception as e:
                last_error = f"Unknown error: {e}"
                logger.error(f"Error (attempt {attempt + 1}/{self.max_retries}): {last_error}")

        return {"success": False, "error": last_error, "mirror_used": mirror_type, "attempts": attempts, "url": url}
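The retry shape used by `_fetch_with_url` (try up to `max_retries` times, remember the last error, report the attempt count either way) is worth seeing without the HTTP machinery; this is a minimal synchronous sketch where `fetch` is a hypothetical stand-in for the network call:

```python
# Standalone sketch of the retry pattern in _fetch_with_url. `fetch` is a
# hypothetical callable standing in for the real HTTP request.
def fetch_with_retries(fetch, max_retries=3):
    attempts = 0
    last_error = None
    for _ in range(max_retries):
        attempts += 1
        try:
            # Success short-circuits the loop, carrying the attempt count
            return {"success": True, "data": fetch(), "attempts": attempts}
        except Exception as e:
            # Remember only the most recent failure
            last_error = str(e)
    return {"success": False, "error": last_error, "attempts": attempts}

calls = {"n": 0}

def flaky():
    # Fails twice, then succeeds, to exercise the retry path
    calls["n"] += 1
    if calls["n"] < 3:
        raise RuntimeError("timeout")
    return "ok"

result = fetch_with_retries(flaky)
print(result)
# -> {'success': True, 'data': 'ok', 'attempts': 3}
```

Note that only the last error is reported; intermediate failures survive only in the logs, which is the same trade-off the original makes.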

    async def clone_repository(
        self,
        owner: str,
        repo: str,
        target_path: Path,
        branch: Optional[str] = None,
        mirror_id: Optional[str] = None,
        custom_url: Optional[str] = None,
        depth: Optional[int] = None,
    ) -> Dict[str, Any]:
        """
        Clone a GitHub repository.

        Args:
            owner: Repository owner.
            repo: Repository name.
            target_path: Target path.
            branch: Branch name (optional).
            mirror_id: ID of a specific mirror source to use.
            custom_url: Custom clone URL.
            depth: Clone depth (shallow clone).

        Returns:
            Dict containing:
            - success: bool - whether the clone succeeded
            - path: str - clone path (on success)
            - error: str - error message (on failure)
            - mirror_used: str - the mirror source that was used
            - attempts: int - number of attempts
        """
        logger.info(f"Cloning repository: {owner}/{repo} to {target_path}")

        if custom_url:
            # Use the custom URL
            return await self._clone_with_url(custom_url, target_path, branch, depth, "custom")

        # Determine which mirror sources to try
        if mirror_id:
            # Use the specified mirror source
            mirror = self.config.get_mirror_by_id(mirror_id)
            if not mirror:
                return {"success": False, "error": f"Mirror source not found: {mirror_id}", "mirror_used": None, "attempts": 0}
            mirrors_to_try = [mirror]
        else:
            # Use all enabled mirror sources
            mirrors_to_try = self.config.get_enabled_mirrors()

        # Try each mirror source in turn
        for mirror in mirrors_to_try:
            result = await self._clone_from_mirror(owner, repo, target_path, branch, depth, mirror)
            if result["success"]:
                return result
            logger.warning(f"Mirror source {mirror['id']} clone failed: {result.get('error')}")

        # All mirror sources failed
        return {"success": False, "error": "All mirror sources failed to clone", "mirror_used": None, "attempts": len(mirrors_to_try)}

    async def _clone_from_mirror(
        self,
        owner: str,
        repo: str,
        target_path: Path,
        branch: Optional[str],
        depth: Optional[int],
        mirror: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Clone a repository from the given mirror source."""
        # Build the clone URL
        clone_prefix = mirror["clone_prefix"]
        url = f"{clone_prefix}/{owner}/{repo}.git"

        return await self._clone_with_url(url, target_path, branch, depth, mirror["id"])

    async def _clone_with_url(
        self, url: str, target_path: Path, branch: Optional[str], depth: Optional[int], mirror_type: str
    ) -> Dict[str, Any]:
        """Clone a repository from the given URL, with retries."""
        attempts = 0
        last_error = None

        for attempt in range(self.max_retries):
            attempts += 1

            try:
                # Make sure the target path does not exist
                if target_path.exists():
                    logger.warning(f"Target path already exists; removing: {target_path}")
                    shutil.rmtree(target_path, ignore_errors=True)

                # Build the git clone command
                cmd = ["git", "clone"]

                # Add the branch argument
                if branch:
                    cmd.extend(["-b", branch])

                # Add the depth argument (shallow clone)
                if depth:
                    cmd.extend(["--depth", str(depth)])

                # Add the URL and target path
                cmd.extend([url, str(target_path)])

                logger.info(f"Clone attempt #{attempt + 1}: {' '.join(cmd)}")

                # Push progress
                if _update_progress:
                    try:
                        await _update_progress(
                            stage="loading",
                            progress=20 + attempt * 10,
                            message=f"Cloning repository (attempt {attempt + 1}/{self.max_retries})...",
                            operation="install",
                        )
                    except Exception as e:
                        logger.warning(f"Failed to push progress: {e}")

                # Run git clone (in a thread pool to avoid blocking the event loop)
                loop = asyncio.get_event_loop()

                def run_git_clone():
                    return subprocess.run(
                        cmd,
                        capture_output=True,
                        text=True,
                        timeout=300,  # 5-minute timeout
                    )

                process = await loop.run_in_executor(None, run_git_clone)

                if process.returncode == 0:
                    logger.info(f"Cloned repository: {url} -> {target_path}")
                    return {
                        "success": True,
                        "path": str(target_path),
                        "mirror_used": mirror_type,
                        "attempts": attempts,
                        "url": url,
                        "branch": branch or "default",
                    }
                else:
                    last_error = f"Git clone failed: {process.stderr}"
                    logger.warning(f"Clone failed (attempt {attempt + 1}/{self.max_retries}): {last_error}")

            except subprocess.TimeoutExpired:
                last_error = "Clone timed out (exceeded 5 minutes)"
                logger.warning(f"Clone timed out (attempt {attempt + 1}/{self.max_retries})")

                # Clean up any partial clone
                if target_path.exists():
                    shutil.rmtree(target_path, ignore_errors=True)

            except FileNotFoundError:
                last_error = "Git is not installed or not on PATH"
                logger.error(f"Git not found: {last_error}")
                break  # Git does not exist; retrying is pointless

            except Exception as e:
                last_error = f"Unknown error: {e}"
                logger.error(f"Clone error (attempt {attempt + 1}/{self.max_retries}): {last_error}")

                # Clean up any partial clone
                if target_path.exists():
                    shutil.rmtree(target_path, ignore_errors=True)

        return {"success": False, "error": last_error, "mirror_used": mirror_type, "attempts": attempts, "url": url}
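The argument assembly inside `_clone_with_url` is easy to check in isolation; this sketch extracts just that step into a hypothetical helper (the URL and path below are made up, and no subprocess is run):

```python
# Sketch of the git-clone argument assembly from _clone_with_url.
# build_clone_cmd is a hypothetical helper; only the list construction
# mirrors the original.
def build_clone_cmd(url, target, branch=None, depth=None):
    cmd = ["git", "clone"]
    if branch:
        cmd.extend(["-b", branch])       # checkout this branch after clone
    if depth:
        cmd.extend(["--depth", str(depth)])  # shallow clone
    cmd.extend([url, target])
    return cmd

print(build_clone_cmd(
    "https://gh-proxy.org/https://github.com/owner/repo.git",
    "/tmp/repo", branch="main", depth=1,
))
# -> ['git', 'clone', '-b', 'main', '--depth', '1',
#     'https://gh-proxy.org/https://github.com/owner/repo.git', '/tmp/repo']
```

Building the command as a list (rather than a shell string) is what lets the original pass it straight to `subprocess.run` without shell quoting concerns.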


# Global service instance
_git_mirror_service: Optional[GitMirrorService] = None


def get_git_mirror_service() -> GitMirrorService:
    """Return the Git mirror service instance (singleton)."""
    global _git_mirror_service
    if _git_mirror_service is None:
        _git_mirror_service = GitMirrorService()
    return _git_mirror_service
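The module-level singleton used by `get_git_mirror_service` boils down to a nullable global plus lazy construction; a minimal sketch with a stand-in `Service` class:

```python
# Sketch of the lazy module-level singleton pattern behind
# get_git_mirror_service(). Service is a hypothetical stand-in.
class Service:
    pass

_instance = None

def get_service():
    global _instance
    if _instance is None:
        _instance = Service()  # constructed once, on first use
    return _instance

# Every caller sees the same object
print(get_service() is get_service())
# -> True
```

One consequence worth knowing: because construction happens on first call, any configuration the constructor reads is fixed at that moment; later config edits only take effect through the shared instance's own methods.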
0
src/webui/log_broadcaster.py
Normal file