Merge pull request #1377 from Mai-with-u/dev

Dev
This commit is contained in:
SengokuCola
2025-11-20 21:05:59 +08:00
committed by GitHub
152 changed files with 3123 additions and 2778 deletions

View File

@@ -1,11 +1,12 @@
name: Docker Build and Push
on:
schedule:
- cron: '0 0 * * *'
push:
branches:
- main
- classical
- dev
tags:
- "v*.*.*"
- "v*"
@@ -24,6 +25,7 @@ jobs:
- name: Check out git repository
uses: actions/checkout@v4
with:
ref: ${{ github.event_name == 'schedule' && 'dev' || github.ref }}
fetch-depth: 0
# Clone required dependencies
@@ -77,6 +79,7 @@ jobs:
- name: Check out git repository
uses: actions/checkout@v4
with:
ref: ${{ github.event_name == 'schedule' && 'dev' || github.ref }}
fetch-depth: 0
# Clone required dependencies

3
.gitignore vendored
View File

@@ -11,7 +11,6 @@ run_maibot_core.bat
run_voice.bat
run_napcat_adapter.bat
run_ad.bat
s4u.s4u
llm_tool_benchmark_results.json
MaiBot-Napcat-Adapter-main
MaiBot-Napcat-Adapter
@@ -27,6 +26,7 @@ run.bat
log_debug/
run_amds.bat
run_none.bat
docs-mai/
run.py
message_queue_content.txt
message_queue_content.bat
@@ -51,6 +51,7 @@ template/compare/model_config_template.toml
src/plugins/utils/statistic.py
CLAUDE.md
MaiBot-Dashboard/
cloudflare-workers/
# Byte-compiled / optimized / DLL files
__pycache__/

View File

@@ -25,6 +25,8 @@ WORKDIR /MaiMBot
# 复制依赖列表
COPY requirements.txt .
RUN apt-get update && apt-get install -y git
# 从编译阶段复制 LPMM 编译结果
COPY --from=lpmm-builder /usr/local/lib/python3.13/site-packages/ /usr/local/lib/python3.13/site-packages/

View File

@@ -25,11 +25,12 @@
**🍔MaiCore 是一个基于大语言模型的可交互智能体**
- 💭 **智能对话系统**:基于 LLM 的自然语言交互,聊天时机控制
- 🤔 **实时思维系统**:模拟人类思考过程。
- 🧠 **表达学习功能**:学习群友的说话风格和表达方式
- 💝 **情感表达系统**:情绪系统和表情包系统。
- 🔌 **强大插件系统**提供API和事件系统可编写强大插件。
- 💭 **拟人构建的prompt**使用自然语言风格构建回复器的prompt实现近似人类言语习惯的回复
- 💭 **行为规划**:在合适的时间说话,使用合适的动作
- 🧠 **表达学习**:学习群友的说话风格和表达方式,学会真实人类的说话风格
- 🤔 **黑话学习**:自主的学习没有见过的词语,尝试理解并认知含义
- 🔌 **插件系统**提供API和事件系统可编写丰富插件。
- 💝 **情感表达**:情绪系统和表情包系统。
<div style="text-align: center">
<a href="https://www.bilibili.com/video/BV1amAneGE3P" target="_blank">
@@ -44,7 +45,9 @@
## 🔥 更新和安装
**最新版本: v0.11.4** ([更新日志](changelogs/changelog.md))
**最新版本: v0.11.5** ([更新日志](changelogs/changelog.md))
可前往 [Release](https://github.com/MaiM-with-u/MaiBot/releases/) 页面下载最新版本
可前往 [启动器发布页面](https://github.com/MaiM-with-u/mailauncher/releases/)下载最新启动器
@@ -60,15 +63,10 @@
> [!WARNING]
> - 项目处于活跃开发阶段,功能和 API 可能随时调整。
> - 有问题可以提交 Issue 或者 Discussion
> - 有问题可以提交 Issue 。
> - QQ 机器人存在被限制风险,请自行了解,谨慎使用。
> - 由于程序处于开发中,可能消耗较多 token。
## 麦麦MC项目MaiCraft早期开发
[让麦麦玩MC](https://github.com/MaiM-with-u/Maicraft)
交流群1058573197
## 💬 讨论
**技术交流群:**
@@ -80,7 +78,7 @@
**聊天吹水群:**
- [麦麦之闲聊群](https://qm.qq.com/q/JxvHZnxyec)
**插件开发测试版群:**
**插件开发/测试版讨论群:**
- [插件开发群](https://qm.qq.com/q/1036092828)
## 📚 文档
@@ -89,7 +87,22 @@
- [📚 核心 Wiki 文档](https://docs.mai-mai.org) - 项目最全面的文档中心,你可以了解麦麦有关的一切。
### 设计理念(原始时代的火花)
## 📚 衍生项目
### MaiCraft早期开发
[MaiCraft](https://github.com/MaiM-with-u/Maicraft)
> 让麦麦具有玩MC能力的项目
> 交流群1058573197
### MoFox_Bot
[MoFox - 仓库地址](https://github.com/MoFox-Studio/MoFox-Core)
> MoFox_Bot 是一个基于 MaiCore 0.10.0 snapshot.5 的增强型 fork 项目
> 我们保留了原项目几乎所有核心功能,并在此基础上进行了深度优化与功能扩展,致力于打造一个更稳定、更智能、更具趣味性的 AI 智能体。
## 设计理念(原始时代的火花)
> **千石可乐说:**
> - 这个项目最初只是为了给牛牛 bot 添加一点额外的功能,但是功能越写越多,最后决定重写。其目的是为了创造一个活跃在 QQ 群聊的"生命体"。目的并不是为了写一个功能齐全的机器人,而是一个尽可能让人感知到真实的类人存在。
@@ -101,7 +114,7 @@
## 🙋 贡献和致谢
你可以阅读[开发文档](https://docs.mai-mai.org/develop/)来更好的了解麦麦!
MaiCore 是一个开源项目,我们非常欢迎你的参与。你的贡献,无论是提交 bug 报告、功能需求还是代码 pr都对项目非常宝贵。我们非常感谢你的支持🎉
但无序的讨论会降低沟通效率,进而影响问题的解决速度,因此在提交任何贡献前,请务必先阅读本项目的[贡献指南](docs/CONTRIBUTE.md)。(待补完)
但无序的讨论会降低沟通效率,进而影响问题的解决速度,因此在提交任何贡献前,请务必先阅读本项目的[贡献指南](docs-src/CONTRIBUTE.md)。(待补完)
### 贡献者

17
bot.py
View File

@@ -1,7 +1,6 @@
import asyncio
import hashlib
import os
import sys
import time
import platform
import traceback
@@ -30,7 +29,7 @@ else:
raise
# 最早期初始化日志系统,确保所有后续模块都使用正确的日志格式
from src.common.logger import initialize_logging, get_logger, shutdown_logging #noqa
from src.common.logger import initialize_logging, get_logger, shutdown_logging # noqa
initialize_logging()
@@ -76,6 +75,15 @@ async def graceful_shutdown(): # sourcery skip: use-named-expression
try:
logger.info("正在优雅关闭麦麦...")
# 关闭 WebUI 服务器
try:
from src.webui.webui_server import get_webui_server
webui_server = get_webui_server()
if webui_server and webui_server._server:
await webui_server.shutdown()
except Exception as e:
logger.warning(f"关闭 WebUI 服务器时出错: {e}")
from src.plugin_system.core.events_manager import events_manager
from src.plugin_system.base.component_types import EventType
@@ -212,9 +220,10 @@ if __name__ == "__main__":
# 创建事件循环
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# 初始化 WebSocket 日志推送
from src.common.logger import initialize_ws_handler
initialize_ws_handler(loop)
try:
@@ -251,7 +260,7 @@ if __name__ == "__main__":
print(f"关闭日志系统时出错: {e}")
print("[主程序] 准备退出...")
# 使用 os._exit() 强制退出,避免被阻塞
# 由于已经在 graceful_shutdown() 中完成了所有清理工作,这是安全的
os._exit(exit_code)

View File

@@ -1,6 +1,88 @@
# Changelog
## [0.11.5] - 2025-11-21
### 🌟 重大更新
- WebUI 现支持手动重启麦麦,曲线救国版“热重载”
- 新增麦麦 QQ 适配器可视化编辑 UI独立进程需手动上传/下载并覆盖适配器文件)
- 麦麦主程序配置支持可视化模式与源代码模式双模式编辑,后端执行 TOML 校验
- 优化 planner 与 replyer 协同机制,调试日志更细
## [0.11.3] - 2025-11-17
### 新增
- 表情包管理、人物信息管理、表达方式管理界面手机端适配
- 配置页“重启麦麦”提示
- 详细的debug prompt显示配置
- 麦麦界面操作主题色按钮
- 前端集成 CodeMirrorPython/JSON/TOML 语法高亮)并对 JSON 配置提供自动纠错提示
### 修复
- 表情包缩略图过小
- 添加模型后无法立即显示
- 插件市场分类错误
- 浅色模式下日志查看器背景异常
- 表情包详情窗口无法关闭
- 情绪标签无法正常读取issues/1373
- 模型任务配置界面模型温度不可更改issues/1369
- 模型任务配置界面部分输入框无法删除默认值
- 插件商店默认勾选“仅显示兼容插件”
- 插件商店标签页数量显示错误
- 日志查看器日志换行异常
- 主题色未正确应用
- 侧边栏滚动条问题
- 移除前端 TOML 解析库导致的兼容问题
### 优化
- 首页加载用户体验
- 首页加载速度9s → <1s
- 整个界面的初屏加载速度
### 更新
- 适配器配置支持上传模式与指定路径模式;指定路径模式可免去反复上传/下载配置文件
- 适配器配置界面标签页响应式优化,小屏仅显示简短标签
- 麦麦资源管理所有界面的批量删除能力
- 资源管理所有界面的分页、每页数量选择与页跳转
- 插件市场新增点赞、点踩、评分与下载量统计(基于 Cloudflare保证国内可访问
- 麦麦表情包查看界面的描述部分支持 Markdown 渲染
## [0.11.4] - 2025-11-19
### 🌟 主要更新内容
- **首个官方 Web 管理界面上线**在此版本之前MaiBot 没有 WebUI所有配置需手动编辑 TOML 文件
- **认证系统**Token 安全登录(支持系统生成 64 位随机令牌 / 自定义 Token首次配置向导
- **配置管理(可视化编辑,无需手动改 TOML**
- 麦麦主程序配置:基础设置、人格、表情、黑话、情绪等
- 模型提供商配置OpenAI、Anthropic、DeepSeek、Qwen、Ollama 等
- 模型配置:对话/视觉/嵌入模型分配
- **资源管理**
- 表情包管理:查看、搜索、注册、封禁
- 表达方式管理:查看麦麦的表达记录
- 人物信息管理:查看联系人列表
- **插件系统**
- 插件市场浏览
- 一键安装/卸载/更新
- 版本兼容性检查
- 实时安装进度推送
- **日志查看器**
- WebSocket 实时日志流
- 日志级别过滤DEBUG/INFO/WARNING/ERROR/CRITICAL
- 搜索功能
- **主题定制**
- 浅色/深色/跟随系统
- 12 种主题色6 单色 + 6 渐变色)
- 自定义颜色选择器
- **全局搜索**Cmd/Ctrl + K 快捷键,快速跳转任意页面
### 细节
- **技术栈**
- 前端: React 19 + TypeScript + Vite + TanStack Router + shadcn/ui
- 后端: FastAPI + Uvicorn + WebSocket
- 特点: SPA 单页应用,前后端同端口,静态文件托管
- **使用方式**:参照 template.env 文件更新 .env 文件,添加两个字段:
- `WEBUI_ENABLED=true`
- `WEBUI_MODE=production`
- **WebUI 开源协议**GPLv3
- **WebUI 地址**https://github.com/Mai-with-u/MaiBot-Dashboard
告别手动编辑配置文件,享受现代化图形界面!
## [0.11.3] - 2025-11-18
### 功能更改和修复
- 优化记忆提取策略
- 优化黑话提取

View File

@@ -29,7 +29,8 @@ services:
- TZ=Asia/Shanghai
# - EULA_AGREE=99f08e0cab0190de853cb6af7d64d4de # 同意EULA
# - PRIVACY_AGREE=9943b855e72199d0f5016ea39052f1b6 # 同意EULA
# ports:
ports:
- "18001:8001" # webui端口
# - "8000:8000"
volumes:
- ./docker-config/mmc/.env:/MaiMBot/.env # 持久化env配置文件

View File

Before

Width:  |  Height:  |  Size: 4.1 KiB

After

Width:  |  Height:  |  Size: 4.1 KiB

View File

Before

Width:  |  Height:  |  Size: 11 KiB

After

Width:  |  Height:  |  Size: 11 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 21 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 4.9 KiB

View File

@@ -1,336 +0,0 @@
# 模型配置指南
本文档将指导您如何配置 `model_config.toml` 文件,该文件用于配置 MaiBot 的各种AI模型和API服务提供商。
## 配置文件结构
配置文件主要包含以下几个部分:
- 版本信息
- API服务提供商配置
- 模型配置
- 模型任务配置
## 1. 版本信息
```toml
[inner]
version = "1.1.1"
```
用于标识配置文件的版本,遵循语义化版本规则。
## 2. API服务提供商配置
### 2.1 基本配置
使用 `[[api_providers]]` 数组配置多个API服务提供商
```toml
[[api_providers]]
name = "DeepSeek" # 服务商名称(自定义)
base_url = "https://api.deepseek.com/v1" # API服务的基础URL
api_key = "your-api-key-here" # API密钥
client_type = "openai" # 客户端类型
max_retry = 2 # 最大重试次数
timeout = 30 # 超时时间(秒)
retry_interval = 10 # 重试间隔(秒)
```
### 2.2 配置参数说明
| 参数 | 必填 | 说明 | 默认值 |
|------|------|------|--------|
| `name` | ✅ | 服务商名称,需要在模型配置中引用 | - |
| `base_url` | ✅ | API服务的基础URL | - |
| `api_key` | ✅ | API密钥请替换为实际密钥 | - |
| `client_type` | ❌ | 客户端类型:`openai`OpenAI格式`gemini`Gemini格式 | `openai` |
| `max_retry` | ❌ | API调用失败时的最大重试次数 | 2 |
| `timeout` | ❌ | API请求超时时间 | 30 |
| `retry_interval` | ❌ | 重试间隔时间(秒) | 10 |
**请注意,对于`client_type`为`gemini`的模型,`retry`字段由`gemini`自己决定。**
### 2.3 支持的服务商示例
#### DeepSeek
```toml
[[api_providers]]
name = "DeepSeek"
base_url = "https://api.deepseek.com/v1"
api_key = "your-deepseek-api-key"
client_type = "openai"
```
#### SiliconFlow
```toml
[[api_providers]]
name = "SiliconFlow"
base_url = "https://api.siliconflow.cn/v1"
api_key = "your-siliconflow-api-key"
client_type = "openai"
```
#### Google Gemini
```toml
[[api_providers]]
name = "Google"
base_url = "https://generativelanguage.googleapis.com/v1beta"
api_key = "your-google-api-key"
client_type = "gemini" # 注意Gemini需要使用特殊客户端
```
## 3. 模型配置
### 3.1 基本模型配置
使用 `[[models]]` 数组配置多个模型:
```toml
[[models]]
model_identifier = "deepseek-chat" # 模型在API服务商中的标识符
name = "deepseek-v3" # 自定义模型名称
api_provider = "DeepSeek" # 引用的API服务商名称
price_in = 2.0 # 输入价格(元/M token
price_out = 8.0 # 输出价格(元/M token
```
### 3.2 高级模型配置
#### 强制流式输出
对于不支持非流式输出的模型:
```toml
[[models]]
model_identifier = "some-model"
name = "custom-name"
api_provider = "Provider"
force_stream_mode = true # 启用强制流式输出
```
#### 额外参数配置`extra_params`
```toml
[[models]]
model_identifier = "Qwen/Qwen3-8B"
name = "qwen3-8b"
api_provider = "SiliconFlow"
[models.extra_params]
enable_thinking = false # 禁用思考
```
这里的 `extra_params` 可以包含任何API服务商支持的额外参数配置**配置时应参考相应的API文档**。
比如上面就是参考SiliconFlow的文档配置配置的`Qwen3`禁用思考参数。
![SiliconFlow文档截图](image-1.png)
以豆包文档为另一个例子
![豆包文档截图](image.png)
得到豆包`"doubao-seed-1-6-250615"`的禁用思考配置方法为
```toml
[[models]]
# 你的模型
[models.extra_params]
thinking = {type = "disabled"} # 禁用思考
```
而对于`gemini`需要单独进行配置
```toml
[[models]]
model_identifier = "gemini-2.5-flash"
name = "gemini-2.5-flash"
api_provider = "Google"
[models.extra_params]
thinking_budget = 0 # 禁用思考
# thinking_budget = -1 由模型自己决定
```
请注意,`extra_params` 的配置应该构成一个合法的TOML字典结构具体内容取决于API服务商的要求。
### 3.3 配置参数说明
| 参数 | 必填 | 说明 |
|------|------|------|
| `model_identifier` | ✅ | API服务商提供的模型标识符 |
| `name` | ✅ | 自定义模型名称,用于在任务配置中引用 |
| `api_provider` | ✅ | 对应的API服务商名称 |
| `price_in` | ❌ | 输入价格(元/M token用于成本统计 |
| `price_out` | ❌ | 输出价格(元/M token用于成本统计 |
| `force_stream_mode` | ❌ | 是否强制使用流式输出 |
| `extra_params` | ❌ | 额外的模型参数配置 |
## 4. 模型任务配置
### utils - 工具模型
用于表情包模块、取名模块、关系模块等核心功能:
```toml
[model_task_config.utils]
model_list = ["siliconflow-deepseek-v3"]
temperature = 0.2
max_tokens = 800
```
### utils_small - 小型工具模型
用于高频率调用的场景,建议使用速度快的小模型:
```toml
[model_task_config.utils_small]
model_list = ["qwen3-8b"]
temperature = 0.7
max_tokens = 800
```
### replyer - 主要回复模型
首要回复模型,也用于表达器和表达方式学习:
```toml
[model_task_config.replyer]
model_list = ["siliconflow-deepseek-v3"]
temperature = 0.2
max_tokens = 800
```
### planner - 决策模型
负责决定MaiBot该做什么
```toml
[model_task_config.planner]
model_list = ["siliconflow-deepseek-v3"]
temperature = 0.3
max_tokens = 800
```
### emotion - 情绪模型
负责MaiBot的情绪变化
```toml
[model_task_config.emotion]
model_list = ["siliconflow-deepseek-v3"]
temperature = 0.3
max_tokens = 800
```
### memory - 记忆模型
```toml
[model_task_config.memory]
model_list = ["qwen3-30b"]
temperature = 0.7
max_tokens = 800
```
### vlm - 视觉语言模型
用于图像识别:
```toml
[model_task_config.vlm]
model_list = ["qwen2.5-vl-72b"]
max_tokens = 800
```
### voice - 语音识别模型
```toml
[model_task_config.voice]
model_list = ["sensevoice-small"]
```
### embedding - 嵌入模型
```toml
[model_task_config.embedding]
model_list = ["bge-m3"]
```
### tool_use - 工具调用模型
需要使用支持工具调用的模型:
```toml
[model_task_config.tool_use]
model_list = ["qwen3-14b"]
temperature = 0.7
max_tokens = 800
```
### lpmm_entity_extract - 实体提取模型
```toml
[model_task_config.lpmm_entity_extract]
model_list = ["siliconflow-deepseek-v3"]
temperature = 0.2
max_tokens = 800
```
### lpmm_rdf_build - RDF构建模型
```toml
[model_task_config.lpmm_rdf_build]
model_list = ["siliconflow-deepseek-v3"]
temperature = 0.2
max_tokens = 800
```
### lpmm_qa - 问答模型
```toml
[model_task_config.lpmm_qa]
model_list = ["deepseek-r1-distill-qwen-32b"]
temperature = 0.7
max_tokens = 800
```
## 5. 配置建议
### 5.1 Temperature 参数选择
| 任务类型 | 推荐温度 | 说明 |
|----------|----------|------|
| 精确任务(工具调用、实体提取) | 0.1-0.3 | 需要准确性和一致性 |
| 创意任务(对话、记忆) | 0.5-0.8 | 需要多样性和创造性 |
| 平衡任务(决策、情绪) | 0.3-0.5 | 平衡准确性和灵活性 |
### 5.2 模型选择建议
| 任务类型 | 推荐模型类型 | 示例 |
|----------|--------------|------|
| 高精度任务 | 大模型 | DeepSeek-V3, GPT-4 |
| 高频率任务 | 小模型 | Qwen3-8B |
| 多模态任务 | 专用模型 | Qwen2.5-VL, SenseVoice |
| 工具调用 | 支持Function Call的模型 | Qwen3-14B |
### 5.3 成本优化
1. **分层使用**:核心功能使用高质量模型,辅助功能使用经济模型
2. **合理配置max_tokens**:根据实际需求设置,避免浪费
3. **选择免费模型**对于测试环境优先使用price为0的模型
## 6. 配置验证
### 6.1 必要检查项
1. ✅ API密钥是否正确配置
2. ✅ 模型标识符是否与API服务商提供的一致
3. ✅ 任务配置中引用的模型名称是否在models中定义
4. ✅ 多模态任务是否配置了对应的专用模型
### 6.2 测试配置
建议在正式使用前:
1. 使用少量测试数据验证配置
2. 检查API调用是否正常
3. 确认成本统计功能正常工作
## 7. 故障排除
### 7.1 常见问题
**问题1**: API调用失败
- 检查API密钥是否正确
- 确认base_url是否可访问
- 检查模型标识符是否正确
**问题2**: 模型未找到
- 确认模型名称在任务配置和模型定义中一致
- 检查api_provider名称是否匹配
**问题3**: 响应异常
- 检查温度参数是否合理0-1之间
- 确认max_tokens设置是否合适
- 验证模型是否支持所需功能
### 7.2 日志查看
查看 `logs/` 目录下的日志文件,寻找相关错误信息。
## 8. 更新和维护
1. **定期更新**: 关注API服务商的模型更新及时调整配置
2. **性能监控**: 监控模型调用的成本和性能
3. **备份配置**: 在修改前备份当前配置文件

View File

@@ -16,8 +16,6 @@ if PROJECT_ROOT not in sys.path:
sys.path.insert(0, PROJECT_ROOT)
SECONDS_5_MINUTES = 5 * 60

View File

@@ -12,7 +12,6 @@ project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)
# 设置中文字体
plt.rcParams["font.sans-serif"] = ["SimHei", "Microsoft YaHei", "DejaVu Sans"]
plt.rcParams["axes.unicode_minus"] = False

File diff suppressed because it is too large Load Diff

View File

@@ -230,7 +230,7 @@ class HeartFChatting:
if (message.is_mentioned or message.is_at) and global_config.chat.mentioned_bot_reply:
mentioned_message = message
logger.info(f"{self.log_prefix} 当前talk_value: {global_config.chat.get_talk_value(self.stream_id)}")
# logger.info(f"{self.log_prefix} 当前talk_value: {global_config.chat.get_talk_value(self.stream_id)}")
# *控制频率用
if mentioned_message:
@@ -333,7 +333,6 @@ class HeartFChatting:
# 重置连续 no_reply 计数
self.consecutive_no_reply_count = 0
reason = ""
await database_api.store_action_info(
chat_stream=self.chat_stream,
@@ -411,7 +410,7 @@ class HeartFChatting:
# asyncio.create_task(self.chat_history_summarizer.process())
cycle_timers, thinking_id = self.start_cycle()
logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考")
logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考(频率: {global_config.chat.get_talk_value(self.stream_id)})")
# 第一步:动作检查
available_actions: Dict[str, ActionInfo] = {}

View File

@@ -30,9 +30,11 @@ DATA_PATH = os.path.join(ROOT_PATH, "data")
qa_manager = None
inspire_manager = None
def get_qa_manager():
return qa_manager
def lpmm_start_up(): # sourcery skip: extract-duplicate-method
# 检查LPMM知识库是否启用
if global_config.lpmm_knowledge.enable:

View File

@@ -92,9 +92,10 @@ class QAManager:
# 过滤阈值
result = dyn_select_top_k(result, 0.5, 1.0)
for res in result:
raw_paragraph = self.embed_manager.paragraphs_embedding_store.store[res[0]].str
logger.info(f"找到相关文段,相关系数:{res[1]:.8f}\n{raw_paragraph}\n\n")
if global_config.debug.show_lpmm_paragraph:
for res in result:
raw_paragraph = self.embed_manager.paragraphs_embedding_store.store[res[0]].str
logger.info(f"找到相关文段,相关系数:{res[1]:.8f}\n{raw_paragraph}\n\n")
return result, ppr_node_weights
@@ -128,11 +129,10 @@ class QAManager:
selected_knowledge = knowledge[:limit]
formatted_knowledge = [
f"{i + 1}条知识:{k[0]}\n 该条知识对于问题的相关性:{k[1]}"
for i, k in enumerate(selected_knowledge)
f"{i + 1}条知识:{k[0]}\n 该条知识对于问题的相关性:{k[1]}" for i, k in enumerate(selected_knowledge)
]
# if max_score is not None:
# formatted_knowledge.insert(0, f"最高相关系数:{max_score}")
# formatted_knowledge.insert(0, f"最高相关系数:{max_score}")
found_knowledge = "\n".join(formatted_knowledge)
if len(found_knowledge) > MAX_KNOWLEDGE_LENGTH:

View File

@@ -7,7 +7,6 @@ from typing import Dict, Optional, Tuple, List, TYPE_CHECKING, Union
from rich.traceback import install
from datetime import datetime
from json_repair import repair_json
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.logger import get_logger
@@ -164,6 +163,45 @@ class ActionPlanner:
return item[1]
return None
def _replace_message_ids_with_text(
self, text: Optional[str], message_id_list: List[Tuple[str, "DatabaseMessages"]]
) -> Optional[str]:
"""将文本中的 m+数字 消息ID替换为原消息内容并添加双引号"""
if not text:
return text
id_to_message = {msg_id: msg for msg_id, msg in message_id_list}
# 匹配m后带2-4位数字前后不是字母数字下划线
pattern = r"(?<![A-Za-z0-9_])m\d{2,4}(?![A-Za-z0-9_])"
matches = re.findall(pattern, text)
if matches:
available_ids = set(id_to_message.keys())
found_ids = set(matches)
missing_ids = found_ids - available_ids
if missing_ids:
logger.info(f"{self.log_prefix}planner理由中引用的消息ID不在当前上下文中: {missing_ids}, 可用ID: {list(available_ids)[:10]}...")
logger.info(f"{self.log_prefix}planner理由替换: 找到{len(matches)}个消息ID引用其中{len(found_ids & available_ids)}个在上下文中")
def _replace(match: re.Match[str]) -> str:
msg_id = match.group(0)
message = id_to_message.get(msg_id)
if not message:
logger.warning(f"{self.log_prefix}planner理由引用 {msg_id} 未找到对应消息,保持原样")
return msg_id
msg_text = (message.processed_plain_text or message.display_message or "").strip()
if not msg_text:
logger.warning(f"{self.log_prefix}planner理由引用 {msg_id} 的消息内容为空,保持原样")
return msg_id
preview = msg_text if len(msg_text) <= 100 else f"{msg_text[:97]}..."
logger.info(f"{self.log_prefix}planner理由引用 {msg_id} -> 消息({preview}")
return f"消息({msg_text}"
return re.sub(pattern, _replace, text)
def _parse_single_action(
self,
action_json: dict,
@@ -176,7 +214,10 @@ class ActionPlanner:
try:
action = action_json.get("action", "no_reply")
reasoning = action_json.get("reason", "未提供原因")
original_reasoning = action_json.get("reason", "未提供原因")
reasoning = self._replace_message_ids_with_text(original_reasoning, message_id_list)
if reasoning is None:
reasoning = original_reasoning
action_data = {key: value for key, value in action_json.items() if key not in ["action", "reason"]}
# 非no_reply动作需要target_message_id
target_message = None
@@ -573,9 +614,6 @@ class ActionPlanner:
# 调用LLM
llm_content, (reasoning_content, _, _) = await self.planner_llm.generate_response_async(prompt=prompt)
# logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
# logger.info(f"{self.log_prefix}规划器原始响应: {llm_content}")
if global_config.debug.show_planner_prompt:
logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
logger.info(f"{self.log_prefix}规划器原始响应: {llm_content}")
@@ -604,6 +642,7 @@ class ActionPlanner:
if llm_content:
try:
json_objects, extracted_reasoning = self._extract_json_from_markdown(llm_content)
extracted_reasoning = self._replace_message_ids_with_text(extracted_reasoning, message_id_list) or ""
if json_objects:
logger.debug(f"{self.log_prefix}从响应中提取到{len(json_objects)}个JSON对象")
filtered_actions_list = list(filtered_actions.items())

View File

@@ -226,7 +226,9 @@ class DefaultReplyer:
traceback.print_exc()
return False, llm_response
async def build_expression_habits(self, chat_history: str, target: str, reply_reason: str = "") -> Tuple[str, List[int]]:
async def build_expression_habits(
self, chat_history: str, target: str, reply_reason: str = ""
) -> Tuple[str, List[int]]:
# sourcery skip: for-append-to-extend
"""构建表达习惯块
@@ -1094,10 +1096,10 @@ class DefaultReplyer:
if not global_config.lpmm_knowledge.enable:
logger.debug("LPMM知识库未启用跳过获取知识库内容")
return ""
if global_config.lpmm_knowledge.lpmm_mode == "agent":
return ""
time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
bot_name = global_config.bot.nickname
@@ -1115,10 +1117,10 @@ class DefaultReplyer:
model_config=model_config.model_task_config.tool_use,
tool_options=[SearchKnowledgeFromLPMMTool.get_tool_definition()],
)
# logger.info(f"工具调用提示词: {prompt}")
# logger.info(f"工具调用: {tool_calls}")
if tool_calls:
result = await self.tool_executor.execute_tool_call(tool_calls[0], SearchKnowledgeFromLPMMTool())
end_time = time.time()

View File

@@ -241,7 +241,9 @@ class PrivateReplyer:
return f"{sender_relation}"
async def build_expression_habits(self, chat_history: str, target: str, reply_reason: str = "") -> Tuple[str, List[int]]:
async def build_expression_habits(
self, chat_history: str, target: str, reply_reason: str = ""
) -> Tuple[str, List[int]]:
# sourcery skip: for-append-to-extend
"""构建表达习惯块
@@ -1032,10 +1034,10 @@ class PrivateReplyer:
if not global_config.lpmm_knowledge.enable:
logger.debug("LPMM知识库未启用跳过获取知识库内容")
return ""
if global_config.lpmm_knowledge.lpmm_mode == "agent":
return ""
time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
bot_name = global_config.bot.nickname

View File

@@ -106,8 +106,8 @@ class ChatHistorySummarizer:
await self._check_and_package(current_time)
self.last_check_time = current_time
return
logger.info(
logger.debug(
f"{self.log_prefix} 开始处理聊天概括,时间窗口: {self.last_check_time:.2f} -> {current_time:.2f}"
)
@@ -119,7 +119,7 @@ class ChatHistorySummarizer:
before_count = len(self.current_batch.messages)
self.current_batch.messages.extend(new_messages)
self.current_batch.end_time = current_time
logger.info(f"{self.log_prefix} 批次更新: {before_count} -> {len(self.current_batch.messages)} 条消息")
logger.info(f"{self.log_prefix} 更新聊天话题: {before_count} -> {len(self.current_batch.messages)} 条消息")
else:
# 创建新批次
self.current_batch = MessageBatch(
@@ -127,7 +127,7 @@ class ChatHistorySummarizer:
start_time=new_messages[0].time if new_messages else current_time,
end_time=current_time,
)
logger.info(f"{self.log_prefix} 新建批次: {len(new_messages)} 条消息")
logger.info(f"{self.log_prefix} 新建聊天话题: {len(new_messages)} 条消息")
# 检查是否需要打包
await self._check_and_package(current_time)

View File

@@ -72,16 +72,16 @@ def get_ws_handler():
def initialize_ws_handler(loop):
"""初始化 WebSocket handler 的事件循环
Args:
loop: asyncio 事件循环
"""
handler = get_ws_handler()
handler.set_loop(loop)
# 为 WebSocket handler 设置 JSON 格式化器(与文件格式相同)
handler.setFormatter(file_formatter)
# 添加到根日志记录器
root_logger = logging.getLogger()
if handler not in root_logger.handlers:
@@ -177,43 +177,44 @@ class TimestampedFileHandler(logging.Handler):
class WebSocketLogHandler(logging.Handler):
"""WebSocket 日志处理器 - 将日志实时推送到前端"""
_log_counter = 0 # 类级别计数器,确保 ID 唯一性
def __init__(self, loop=None):
super().__init__()
self.loop = loop
self._initialized = False
def set_loop(self, loop):
"""设置事件循环"""
self.loop = loop
self._initialized = True
def emit(self, record):
"""发送日志到 WebSocket 客户端"""
if not self._initialized or self.loop is None:
return
try:
# 获取格式化后的消息
# 对于 structlog,formatted message 包含完整的日志信息
formatted_msg = self.format(record) if self.formatter else record.getMessage()
# 如果是 JSON 格式(文件格式化器),解析它
message = formatted_msg
try:
import json
log_dict = json.loads(formatted_msg)
message = log_dict.get('event', formatted_msg)
message = log_dict.get("event", formatted_msg)
except (json.JSONDecodeError, ValueError):
# 不是 JSON,直接使用消息
message = formatted_msg
# 生成唯一 ID: 时间戳毫秒 + 自增计数器
WebSocketLogHandler._log_counter += 1
log_id = f"{int(record.created * 1000)}_{WebSocketLogHandler._log_counter}"
# 格式化日志数据
log_data = {
"id": log_id,
@@ -222,20 +223,17 @@ class WebSocketLogHandler(logging.Handler):
"module": record.name,
"message": message,
}
# 异步广播日志(不阻塞日志记录)
try:
import asyncio
from src.webui.logs_ws import broadcast_log
asyncio.run_coroutine_threadsafe(
broadcast_log(log_data),
self.loop
)
asyncio.run_coroutine_threadsafe(broadcast_log(log_data), self.loop)
except Exception:
# WebSocket 推送失败不影响日志记录
pass
except Exception:
# 不要让 WebSocket 错误影响日志系统
self.handleError(record)
@@ -255,7 +253,7 @@ def close_handlers():
if _console_handler:
_console_handler.close()
_console_handler = None
if _ws_handler:
_ws_handler.close()
_ws_handler = None

View File

@@ -1,5 +1,4 @@
from fastapi import FastAPI, APIRouter
from fastapi.middleware.cors import CORSMiddleware # 新增导入
from typing import Optional
from uvicorn import Config, Server as UvicornServer
import asyncio
@@ -17,21 +16,6 @@ class Server:
self._server: Optional[UvicornServer] = None
self.set_address(host, port)
# 配置 CORS
origins = [
"http://localhost:7999", # 允许的前端源
"http://127.0.0.1:7999",
# 在生产环境中,您应该添加实际的前端域名
]
self.app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True, # 是否支持 cookie
allow_methods=["*"], # 允许所有 HTTP 方法
allow_headers=["*"], # 允许所有 HTTP 请求头
)
def register_router(self, router: APIRouter, prefix: str = ""):
"""注册路由

View File

@@ -56,7 +56,7 @@ TEMPLATE_DIR = os.path.join(PROJECT_ROOT, "template")
# 考虑到实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
# 对该字段的更新请严格参照语义化版本规范https://semver.org/lang/zh-CN/
MMC_VERSION = "0.11.4"
MMC_VERSION = "0.11.5"
def get_key_comment(toml_table, key):

View File

@@ -581,9 +581,15 @@ class DebugConfig(ConfigBase):
show_jargon_prompt: bool = False
"""是否显示jargon相关提示词"""
show_memory_prompt: bool = False
"""是否显示记忆检索相关prompt"""
show_planner_prompt: bool = False
"""是否显示planner相关提示词"""
show_lpmm_paragraph: bool = False
"""是否显示lpmm找到的相关文段日志"""
@dataclass
class ExperimentalConfig(ConfigBase):
@@ -647,7 +653,7 @@ class LPMMKnowledgeConfig(ConfigBase):
enable: bool = True
"""是否启用LPMM知识库"""
lpmm_mode: Literal["classic", "agent"] = "classic"
"""LPMM知识库模式可选classic经典模式agent 模式,结合最新的记忆一同使用"""
@@ -690,4 +696,4 @@ class JargonConfig(ConfigBase):
"""Jargon配置类"""
all_global: bool = False
"""是否将所有新增的jargon项目默认为全局is_global=Truechat_id记录第一次存储时的id"""
"""是否将所有新增的jargon项目默认为全局is_global=Truechat_id记录第一次存储时的id"""

View File

@@ -467,11 +467,7 @@ class ExpressionLearner:
up_content: str,
current_time: float,
) -> None:
expr_obj = (
Expression.select()
.where((Expression.chat_id == self.chat_id) & (Expression.style == style))
.first()
)
expr_obj = Expression.select().where((Expression.chat_id == self.chat_id) & (Expression.style == style)).first()
if expr_obj:
await self._update_existing_expression(

View File

@@ -42,8 +42,6 @@ def init_prompt():
Prompt(expression_evaluation_prompt, "expression_evaluation_prompt")
class ExpressionSelector:
def __init__(self):
self.llm_model = LLMRequest(
@@ -238,9 +236,9 @@ class ExpressionSelector:
else:
target_message_str = ""
target_message_extra_block = ""
chat_context = f"以下是正在进行的聊天内容:{chat_info}"
# 构建reply_reason块
if reply_reason:
reply_reason_block = f"你的回复理由是:{reply_reason}"
@@ -262,9 +260,8 @@ class ExpressionSelector:
# 4. 调用LLM
content, (reasoning_content, model_name, _) = await self.llm_model.generate_response_async(prompt=prompt)
# print(prompt)
if not content:
logger.warning("LLM返回空结果")
return [], []

View File

@@ -36,10 +36,7 @@ def _contains_bot_self_name(content: str) -> bool:
target = content.strip().lower()
nickname = str(getattr(bot_config, "nickname", "") or "").strip().lower()
alias_names = [
str(alias or "").strip().lower()
for alias in getattr(bot_config, "alias_names", []) or []
]
alias_names = [str(alias or "").strip().lower() for alias in getattr(bot_config, "alias_names", []) or []]
candidates = [name for name in [nickname, *alias_names] if name]
@@ -149,7 +146,7 @@ async def _enrich_raw_content_if_needed(
) -> List[str]:
"""
检查raw_content是否只包含黑话本身如果是则获取该消息的前三条消息作为原始内容
Args:
content: 黑话内容
raw_content_list: 原始raw_content列表
@@ -157,22 +154,22 @@ async def _enrich_raw_content_if_needed(
messages: 当前时间窗口内的消息列表
extraction_start_time: 提取开始时间
extraction_end_time: 提取结束时间
Returns:
处理后的raw_content列表
"""
enriched_list = []
for raw_content in raw_content_list:
# 检查raw_content是否只包含黑话本身去除空白字符后比较
raw_content_clean = raw_content.strip()
content_clean = content.strip()
# 如果raw_content只包含黑话本身可能有一些标点或空白则尝试获取上下文
# 去除所有空白字符后比较,确保只包含黑话本身
raw_content_normalized = raw_content_clean.replace(" ", "").replace("\n", "").replace("\t", "")
content_normalized = content_clean.replace(" ", "").replace("\n", "").replace("\t", "")
if raw_content_normalized == content_normalized:
# 在消息列表中查找只包含该黑话的消息(去除空白后比较)
target_message = None
@@ -183,22 +180,20 @@ async def _enrich_raw_content_if_needed(
if msg_content_normalized == content_normalized:
target_message = msg
break
if target_message and target_message.time:
# 获取该消息的前三条消息
try:
previous_messages = get_raw_msg_before_timestamp_with_chat(
chat_id=chat_id,
timestamp=target_message.time,
limit=3
chat_id=chat_id, timestamp=target_message.time, limit=3
)
if previous_messages:
# 将前三条消息和当前消息一起格式化
context_messages = previous_messages + [target_message]
# 按时间排序
context_messages.sort(key=lambda x: x.time or 0)
# 格式化为可读消息
formatted_context, _ = await build_readable_messages_with_list(
context_messages,
@@ -206,7 +201,7 @@ async def _enrich_raw_content_if_needed(
timestamp_mode="relative",
truncate=False,
)
if formatted_context.strip():
enriched_list.append(formatted_context.strip())
logger.warning(f"为黑话 {content} 补充了上下文消息")
@@ -226,7 +221,7 @@ async def _enrich_raw_content_if_needed(
else:
# raw_content包含更多内容直接使用
enriched_list.append(raw_content)
return enriched_list
@@ -240,31 +235,31 @@ def _should_infer_meaning(jargon_obj: Jargon) -> bool:
# 如果已完成所有推断,不再推断
if jargon_obj.is_complete:
return False
count = jargon_obj.count or 0
last_inference = jargon_obj.last_inference_count or 0
# 阈值列表3,6, 10, 20, 40, 60, 100
thresholds = [3,6, 10, 20, 40, 60, 100]
thresholds = [3, 6, 10, 20, 40, 60, 100]
if count < thresholds[0]:
return False
# 如果count没有超过上次判定值不需要判定
if count <= last_inference:
return False
# 找到第一个大于last_inference的阈值
next_threshold = None
for threshold in thresholds:
if threshold > last_inference:
next_threshold = threshold
break
# 如果没有找到下一个阈值说明已经超过100不应该再推断
if next_threshold is None:
return False
# 检查count是否达到或超过这个阈值
return count >= next_threshold
@@ -275,13 +270,13 @@ class JargonMiner:
self.last_learning_time: float = time.time()
# 频率控制,可按需调整
self.min_messages_for_learning: int = 10
self.min_learning_interval: float = 20
self.min_learning_interval: float = 20
self.llm = LLMRequest(
model_set=model_config.model_task_config.utils,
request_type="jargon.extract",
)
# 初始化stream_name作为类属性避免重复提取
chat_manager = get_chat_manager()
stream_name = chat_manager.get_stream_name(self.chat_id)
@@ -306,17 +301,19 @@ class JargonMiner:
try:
content = jargon_obj.content
raw_content_str = jargon_obj.raw_content or ""
# 解析raw_content列表
raw_content_list = []
if raw_content_str:
try:
raw_content_list = json.loads(raw_content_str) if isinstance(raw_content_str, str) else raw_content_str
raw_content_list = (
json.loads(raw_content_str) if isinstance(raw_content_str, str) else raw_content_str
)
if not isinstance(raw_content_list, list):
raw_content_list = [raw_content_list] if raw_content_list else []
except (json.JSONDecodeError, TypeError):
raw_content_list = [raw_content_str] if raw_content_str else []
if not raw_content_list:
logger.warning(f"jargon {content} 没有raw_content跳过推断")
return
@@ -328,12 +325,12 @@ class JargonMiner:
content=content,
raw_content_list=raw_content_text,
)
response1, _ = await self.llm.generate_response_async(prompt1, temperature=0.3)
if not response1:
logger.warning(f"jargon {content} 推断1失败无响应")
return
# 解析推断1结果
inference1 = None
try:
@@ -349,7 +346,7 @@ class JargonMiner:
except Exception as e:
logger.error(f"jargon {content} 推断1解析失败: {e}")
return
# 检查推断1是否表示信息不足无法推断
no_info = inference1.get("no_info", False)
meaning1 = inference1.get("meaning", "").strip()
@@ -360,18 +357,17 @@ class JargonMiner:
jargon_obj.save()
return
# 步骤2: 仅基于content推断
prompt2 = await global_prompt_manager.format_prompt(
"jargon_inference_content_only_prompt",
content=content,
)
response2, _ = await self.llm.generate_response_async(prompt2, temperature=0.3)
if not response2:
logger.warning(f"jargon {content} 推断2失败无响应")
return
# 解析推断2结果
inference2 = None
try:
@@ -387,13 +383,12 @@ class JargonMiner:
except Exception as e:
logger.error(f"jargon {content} 推断2解析失败: {e}")
return
logger.info(f"jargon {content} 推断2提示词: {prompt2}")
logger.info(f"jargon {content} 推断2结果: {response2}")
logger.info(f"jargon {content} 推断1提示词: {prompt1}")
logger.info(f"jargon {content} 推断1结果: {response1}")
# logger.info(f"jargon {content} 推断2提示词: {prompt2}")
# logger.info(f"jargon {content} 推断2结果: {response2}")
# logger.info(f"jargon {content} 推断1提示词: {prompt1}")
# logger.info(f"jargon {content} 推断1结果: {response1}")
if global_config.debug.show_jargon_prompt:
logger.info(f"jargon {content} 推断2提示词: {prompt2}")
logger.info(f"jargon {content} 推断2结果: {response2}")
@@ -404,22 +399,22 @@ class JargonMiner:
logger.debug(f"jargon {content} 推断2结果: {response2}")
logger.debug(f"jargon {content} 推断1提示词: {prompt1}")
logger.debug(f"jargon {content} 推断1结果: {response1}")
# 步骤3: 比较两个推断结果
prompt3 = await global_prompt_manager.format_prompt(
"jargon_compare_inference_prompt",
inference1=json.dumps(inference1, ensure_ascii=False),
inference2=json.dumps(inference2, ensure_ascii=False),
)
if global_config.debug.show_jargon_prompt:
logger.info(f"jargon {content} 比较提示词: {prompt3}")
response3, _ = await self.llm.generate_response_async(prompt3, temperature=0.3)
if not response3:
logger.warning(f"jargon {content} 比较失败:无响应")
return
# 解析比较结果
comparison = None
try:
@@ -439,7 +434,7 @@ class JargonMiner:
# 判断是否为黑话
is_similar = comparison.get("is_similar", False)
is_jargon = not is_similar # 如果相似,说明不是黑话;如果有差异,说明是黑话
# 更新数据库记录
jargon_obj.is_jargon = is_jargon
if is_jargon:
@@ -448,17 +443,19 @@ class JargonMiner:
else:
# 不是黑话也记录含义使用推断2的结果因为含义明确
jargon_obj.meaning = inference2.get("meaning", "")
# 更新最后一次判定的count值避免重启后重复判定
jargon_obj.last_inference_count = jargon_obj.count or 0
# 如果count>=100标记为完成不再进行推断
if (jargon_obj.count or 0) >= 100:
jargon_obj.is_complete = True
jargon_obj.save()
logger.debug(f"jargon {content} 推断完成: is_jargon={is_jargon}, meaning={jargon_obj.meaning}, last_inference_count={jargon_obj.last_inference_count}, is_complete={jargon_obj.is_complete}")
logger.debug(
f"jargon {content} 推断完成: is_jargon={is_jargon}, meaning={jargon_obj.meaning}, last_inference_count={jargon_obj.last_inference_count}, is_complete={jargon_obj.is_complete}"
)
# 固定输出推断结果,格式化为可读形式
if is_jargon:
# 是黑话,输出格式:[聊天名]xxx的含义是 xxxxxxxxxxx
@@ -471,10 +468,11 @@ class JargonMiner:
else:
# 不是黑话,输出格式:[聊天名]xxx 不是黑话
logger.info(f"[{self.stream_name}]{content} 不是黑话")
except Exception as e:
logger.error(f"jargon推断失败: {e}")
import traceback
traceback.print_exc()
def should_trigger(self) -> bool:
@@ -502,7 +500,7 @@ class JargonMiner:
# 记录本次提取的时间窗口,避免重复提取
extraction_start_time = self.last_learning_time
extraction_end_time = time.time()
# 拉取学习窗口内的消息
messages = get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=self.chat_id,
@@ -525,7 +523,7 @@ class JargonMiner:
response, _ = await self.llm.generate_response_async(prompt, temperature=0.2)
if not response:
return
if global_config.debug.show_jargon_prompt:
logger.info(f"jargon提取提示词: {prompt}")
logger.info(f"jargon提取结果: {response}")
@@ -555,7 +553,7 @@ class JargonMiner:
continue
content = str(item.get("content", "")).strip()
raw_content_value = item.get("raw_content", "")
# 处理raw_content可能是字符串或列表
raw_content_list = []
if isinstance(raw_content_value, list):
@@ -566,15 +564,12 @@ class JargonMiner:
raw_content_str = raw_content_value.strip()
if raw_content_str:
raw_content_list = [raw_content_str]
if content and raw_content_list:
if _contains_bot_self_name(content):
logger.debug(f"解析阶段跳过包含机器人昵称/别名的词条: {content}")
continue
entries.append({
"content": content,
"raw_content": raw_content_list
})
entries.append({"content": content, "raw_content": raw_content_list})
except Exception as e:
logger.error(f"解析jargon JSON失败: {e}; 原始: {response}")
return
@@ -591,13 +586,13 @@ class JargonMiner:
if content_key not in seen:
seen.add(content_key)
uniq_entries.append(entry)
saved = 0
updated = 0
for entry in uniq_entries:
content = entry["content"]
raw_content_list = entry["raw_content"] # 已经是列表
# 检查并补充raw_content如果只包含黑话本身则获取前三条消息作为上下文
raw_content_list = await _enrich_raw_content_if_needed(
content=content,
@@ -607,60 +602,53 @@ class JargonMiner:
extraction_start_time=extraction_start_time,
extraction_end_time=extraction_end_time,
)
try:
# 根据all_global配置决定查询逻辑
if global_config.jargon.all_global:
# 开启all_global无视chat_id查询所有content匹配的记录所有记录都是全局的
query = (
Jargon.select()
.where(Jargon.content == content)
)
query = Jargon.select().where(Jargon.content == content)
else:
# 关闭all_global只查询chat_id匹配的记录不考虑is_global
query = (
Jargon.select()
.where(
(Jargon.chat_id == self.chat_id) &
(Jargon.content == content)
)
)
query = Jargon.select().where((Jargon.chat_id == self.chat_id) & (Jargon.content == content))
if query.exists():
obj = query.get()
try:
obj.count = (obj.count or 0) + 1
except Exception:
obj.count = 1
# 合并raw_content列表读取现有列表追加新值去重
existing_raw_content = []
if obj.raw_content:
try:
existing_raw_content = json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content
existing_raw_content = (
json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content
)
if not isinstance(existing_raw_content, list):
existing_raw_content = [existing_raw_content] if existing_raw_content else []
except (json.JSONDecodeError, TypeError):
existing_raw_content = [obj.raw_content] if obj.raw_content else []
# 合并并去重
merged_list = list(dict.fromkeys(existing_raw_content + raw_content_list))
obj.raw_content = json.dumps(merged_list, ensure_ascii=False)
# 开启all_global时确保记录标记为is_global=True
if global_config.jargon.all_global:
obj.is_global = True
# 关闭all_global时保持原有is_global不变不修改
obj.save()
# 检查是否需要推断(达到阈值且超过上次判定值)
if _should_infer_meaning(obj):
# 异步触发推断,不阻塞主流程
# 重新加载对象以确保数据最新
jargon_id = obj.id
asyncio.create_task(self._infer_meaning_by_id(jargon_id))
updated += 1
else:
# 没找到匹配记录,创建新记录
@@ -670,13 +658,13 @@ class JargonMiner:
else:
# 关闭all_global新记录is_global=False
is_global_new = False
Jargon.create(
content=content,
raw_content=json.dumps(raw_content_list, ensure_ascii=False),
chat_id=self.chat_id,
is_global=is_global_new,
count=1
count=1,
)
saved += 1
except Exception as e:
@@ -688,13 +676,13 @@ class JargonMiner:
# 收集所有提取的jargon内容
jargon_list = [entry["content"] for entry in uniq_entries]
jargon_str = ",".join(jargon_list)
# 输出格式化的结果使用logger.info会自动应用jargon模块的颜色
logger.info(f"[{self.stream_name}]疑似黑话: {jargon_str}")
# 更新为本次提取的结束时间,确保不会重复提取相同的消息窗口
self.last_learning_time = extraction_end_time
if saved or updated:
logger.info(f"jargon写入: 新增 {saved} 条,更新 {updated}chat_id={self.chat_id}")
except Exception as e:
@@ -720,15 +708,11 @@ async def extract_and_store_jargon(chat_id: str) -> None:
def search_jargon(
keyword: str,
chat_id: Optional[str] = None,
limit: int = 10,
case_sensitive: bool = False,
fuzzy: bool = True
keyword: str, chat_id: Optional[str] = None, limit: int = 10, case_sensitive: bool = False, fuzzy: bool = True
) -> List[Dict[str, str]]:
"""
搜索jargon支持大小写不敏感和模糊搜索
Args:
keyword: 搜索关键词
chat_id: 可选的聊天ID
@@ -737,21 +721,18 @@ def search_jargon(
limit: 返回结果数量限制默认10
case_sensitive: 是否大小写敏感默认False不敏感
fuzzy: 是否模糊搜索默认True使用LIKE匹配
Returns:
List[Dict[str, str]]: 包含content, meaning的字典列表
"""
if not keyword or not keyword.strip():
return []
keyword = keyword.strip()
# 构建查询
query = Jargon.select(
Jargon.content,
Jargon.meaning
)
query = Jargon.select(Jargon.content, Jargon.meaning)
# 构建搜索条件
if case_sensitive:
# 大小写敏感
@@ -760,7 +741,7 @@ def search_jargon(
search_condition = Jargon.content.contains(keyword)
else:
# 精确匹配
search_condition = (Jargon.content == keyword)
search_condition = Jargon.content == keyword
else:
# 大小写不敏感
if fuzzy:
@@ -768,10 +749,10 @@ def search_jargon(
search_condition = fn.LOWER(Jargon.content).contains(keyword.lower())
else:
# 精确匹配使用LOWER函数
search_condition = (fn.LOWER(Jargon.content) == keyword.lower())
search_condition = fn.LOWER(Jargon.content) == keyword.lower()
query = query.where(search_condition)
# 根据all_global配置决定查询逻辑
if global_config.jargon.all_global:
# 开启all_global所有记录都是全局的查询所有is_global=True的记录无视chat_id
@@ -779,35 +760,28 @@ def search_jargon(
else:
# 关闭all_global如果提供了chat_id优先搜索该聊天或global的jargon
if chat_id:
query = query.where(
(Jargon.chat_id == chat_id) | Jargon.is_global
)
query = query.where((Jargon.chat_id == chat_id) | Jargon.is_global)
# 只返回有meaning的记录
query = query.where(
(Jargon.meaning.is_null(False)) & (Jargon.meaning != "")
)
query = query.where((Jargon.meaning.is_null(False)) & (Jargon.meaning != ""))
# 按count降序排序优先返回出现频率高的
query = query.order_by(Jargon.count.desc())
# 限制结果数量
query = query.limit(limit)
# 执行查询并返回结果
results = []
for jargon in query:
results.append({
"content": jargon.content or "",
"meaning": jargon.meaning or ""
})
results.append({"content": jargon.content or "", "meaning": jargon.meaning or ""})
return results
async def store_jargon_from_answer(jargon_keyword: str, answer: str, chat_id: str) -> None:
"""将黑话存入jargon系统
Args:
jargon_keyword: 黑话关键词
answer: 答案内容将概括为raw_content
@@ -820,53 +794,52 @@ async def store_jargon_from_answer(jargon_keyword: str, answer: str, chat_id: st
答案:{answer}
只输出概括后的内容,不要输出其他内容:"""
success, summary, _, _ = await llm_api.generate_with_model(
summary_prompt,
model_config=model_config.model_task_config.utils_small,
request_type="memory.summarize_jargon",
)
logger.info(f"概括答案提示: {summary_prompt}")
logger.info(f"概括答案: {summary}")
if not success:
logger.warning(f"概括答案失败,使用原始答案: {summary}")
summary = answer[:100] # 截取前100字符作为备用
raw_content = summary.strip()[:200] # 限制长度
# 检查是否已存在
if global_config.jargon.all_global:
query = Jargon.select().where(Jargon.content == jargon_keyword)
else:
query = Jargon.select().where(
(Jargon.chat_id == chat_id) &
(Jargon.content == jargon_keyword)
)
query = Jargon.select().where((Jargon.chat_id == chat_id) & (Jargon.content == jargon_keyword))
if query.exists():
# 更新现有记录
obj = query.get()
obj.count = (obj.count or 0) + 1
# 合并raw_content列表
existing_raw_content = []
if obj.raw_content:
try:
existing_raw_content = json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content
existing_raw_content = (
json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content
)
if not isinstance(existing_raw_content, list):
existing_raw_content = [existing_raw_content] if existing_raw_content else []
except (json.JSONDecodeError, TypeError):
existing_raw_content = [obj.raw_content] if obj.raw_content else []
# 合并并去重
merged_list = list(dict.fromkeys(existing_raw_content + [raw_content]))
obj.raw_content = json.dumps(merged_list, ensure_ascii=False)
if global_config.jargon.all_global:
obj.is_global = True
obj.save()
logger.info(f"更新jargon记录: {jargon_keyword}")
else:
@@ -877,11 +850,9 @@ async def store_jargon_from_answer(jargon_keyword: str, answer: str, chat_id: st
raw_content=json.dumps([raw_content], ensure_ascii=False),
chat_id=chat_id,
is_global=is_global_new,
count=1
count=1,
)
logger.info(f"创建新jargon记录: {jargon_keyword}")
except Exception as e:
logger.error(f"存储jargon失败: {e}")

View File

@@ -147,7 +147,7 @@ def _convert_tool_options(tool_options: list[ToolOption]) -> list[FunctionDeclar
param_type_value = tool_option_param.param_type.value
if param_type_value == "bool":
param_type_value = "boolean"
return_dict: dict[str, Any] = {
"type": param_type_value,
"description": tool_option_param.description,

View File

@@ -122,7 +122,7 @@ def _convert_tool_options(tool_options: list[ToolOption]) -> list[dict[str, Any]
param_type_value = tool_option_param.param_type.value
if param_type_value == "bool":
param_type_value = "boolean"
return_dict: dict[str, Any] = {
"type": param_type_value,
"description": tool_option_param.description,

View File

@@ -116,9 +116,7 @@ class MessageBuilder:
构建消息对象
:return: Message对象
"""
if len(self.__content) == 0 and not (
self.__role == RoleType.Assistant and self.__tool_calls
):
if len(self.__content) == 0 and not (self.__role == RoleType.Assistant and self.__tool_calls):
raise ValueError("内容不能为空")
if self.__role == RoleType.Tool and self.__tool_call_id is None:
raise ValueError("Tool角色的工具调用ID不能为空")

View File

@@ -166,7 +166,7 @@ class LLMRequest:
time_cost=time.time() - start_time,
)
return content or "", (reasoning_content, model_info.name, tool_calls)
async def generate_response_with_message_async(
self,
message_factory: Callable[[BaseClient], List[Message]],

View File

@@ -36,37 +36,39 @@ class MainSystem:
# 使用消息API替代直接的FastAPI实例
self.app: MessageServer = get_global_api()
self.server: Server = get_global_server()
# 注册 WebUI API 路由
self._register_webui_routes()
# 设置 WebUI开发/生产模式)
self._setup_webui()
self.webui_server = None # 独立的 WebUI 服务器
def _register_webui_routes(self):
"""注册 WebUI API 路由"""
try:
from src.webui.routes import router as webui_router
self.server.register_router(webui_router)
logger.info("WebUI API 路由已注册")
except Exception as e:
logger.warning(f"注册 WebUI API 路由失败: {e}")
# 设置独立的 WebUI 服务器
self._setup_webui_server()
def _setup_webui(self):
"""设置 WebUI(根据环境变量决定模式)"""
def _setup_webui_server(self):
"""设置独立的 WebUI 服务器"""
import os
webui_enabled = os.getenv("WEBUI_ENABLED", "false").lower() == "true"
if not webui_enabled:
logger.info("WebUI 已禁用")
return
webui_mode = os.getenv("WEBUI_MODE", "production").lower()
try:
from src.webui.manager import setup_webui
setup_webui(mode=webui_mode)
from src.webui.webui_server import get_webui_server
self.webui_server = get_webui_server()
if webui_mode == "development":
logger.info("📝 WebUI 开发模式已启用")
logger.info("🌐 后端 API 将运行在 http://0.0.0.0:8001")
logger.info("💡 请手动启动前端开发服务器: cd MaiBot-Dashboard && bun dev")
logger.info("💡 前端将运行在 http://localhost:7999")
else:
logger.info("✅ WebUI 生产模式已启用")
logger.info(f"🌐 WebUI 将运行在 http://0.0.0.0:8001")
logger.info("💡 请确保已构建前端: cd MaiBot-Dashboard && bun run build")
except Exception as e:
logger.error(f"设置 WebUI 失败: {e}")
logger.error(f"❌ 初始化 WebUI 服务器失败: {e}")
async def initialize(self):
"""初始化系统组件"""
@@ -161,6 +163,10 @@ class MainSystem:
self.server.run(),
]
# 如果 WebUI 服务器已初始化,添加到任务列表
if self.webui_server:
tasks.append(self.webui_server.start())
await asyncio.gather(*tasks)
# async def forget_memory_task(self):

File diff suppressed because it is too large Load Diff

View File

@@ -3,6 +3,7 @@
记忆系统工具函数
包含模糊查找、相似度计算等工具函数
"""
import json
import re
from datetime import datetime
@@ -14,6 +15,7 @@ from src.common.logger import get_logger
logger = get_logger("memory_utils")
def parse_md_json(json_text: str) -> list[str]:
"""从Markdown格式的内容中提取JSON对象和推理内容"""
json_objects = []
@@ -52,14 +54,15 @@ def parse_md_json(json_text: str) -> list[str]:
return json_objects, reasoning_content
def calculate_similarity(text1: str, text2: str) -> float:
"""
计算两个文本的相似度
Args:
text1: 第一个文本
text2: 第二个文本
Returns:
float: 相似度分数 (0-1)
"""
@@ -67,16 +70,16 @@ def calculate_similarity(text1: str, text2: str) -> float:
# 预处理文本
text1 = preprocess_text(text1)
text2 = preprocess_text(text2)
# 使用SequenceMatcher计算相似度
similarity = SequenceMatcher(None, text1, text2).ratio()
# 如果其中一个文本包含另一个,提高相似度
if text1 in text2 or text2 in text1:
similarity = max(similarity, 0.8)
return similarity
except Exception as e:
logger.error(f"计算相似度时出错: {e}")
return 0.0
@@ -85,31 +88,30 @@ def calculate_similarity(text1: str, text2: str) -> float:
def preprocess_text(text: str) -> str:
"""
预处理文本,提高匹配准确性
Args:
text: 原始文本
Returns:
str: 预处理后的文本
"""
try:
# 转换为小写
text = text.lower()
# 移除标点符号和特殊字符
text = re.sub(r'[^\w\s]', '', text)
text = re.sub(r"[^\w\s]", "", text)
# 移除多余空格
text = re.sub(r'\s+', ' ', text).strip()
text = re.sub(r"\s+", " ", text).strip()
return text
except Exception as e:
logger.error(f"预处理文本时出错: {e}")
return text
def parse_datetime_to_timestamp(value: str) -> float:
"""
接受多种常见格式并转换为时间戳(秒)
@@ -143,25 +145,24 @@ def parse_datetime_to_timestamp(value: str) -> float:
def parse_time_range(time_range: str) -> Tuple[float, float]:
"""
解析时间范围字符串,返回开始和结束时间戳
Args:
time_range: 时间范围字符串,格式:"YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS"
Returns:
Tuple[float, float]: (开始时间戳, 结束时间戳)
"""
if " - " not in time_range:
raise ValueError(f"时间范围格式错误,应为 '开始时间 - 结束时间': {time_range}")
parts = time_range.split(" - ", 1)
if len(parts) != 2:
raise ValueError(f"时间范围格式错误: {time_range}")
start_str = parts[0].strip()
end_str = parts[1].strip()
start_timestamp = parse_datetime_to_timestamp(start_str)
end_timestamp = parse_datetime_to_timestamp(end_str)
return start_timestamp, end_timestamp
return start_timestamp, end_timestamp

View File

@@ -17,6 +17,7 @@ from .query_lpmm_knowledge import register_tool as register_lpmm_knowledge
from .query_person_info import register_tool as register_query_person_info
from src.config.config import global_config
def init_all_tools():
"""初始化并注册所有记忆检索工具"""
register_query_jargon()

View File

@@ -15,13 +15,10 @@ logger = get_logger("memory_retrieval_tools")
async def query_chat_history(
chat_id: str,
keyword: Optional[str] = None,
time_range: Optional[str] = None,
fuzzy: bool = True
chat_id: str, keyword: Optional[str] = None, time_range: Optional[str] = None, fuzzy: bool = True
) -> str:
"""根据时间或关键词在chat_history表中查询聊天记录概述
Args:
chat_id: 聊天ID
keyword: 关键词(可选,支持多个关键词,可用空格、逗号等分隔)
@@ -31,7 +28,7 @@ async def query_chat_history(
fuzzy: 是否使用模糊匹配模式默认True
- True: 模糊匹配只要包含任意一个关键词即匹配OR关系
- False: 全匹配必须包含所有关键词才匹配AND关系
Returns:
str: 查询结果
"""
@@ -39,10 +36,10 @@ async def query_chat_history(
# 检查参数
if not keyword and not time_range:
return "未指定查询参数需要提供keyword或time_range之一"
# 构建查询条件
query = ChatHistory.select().where(ChatHistory.chat_id == chat_id)
# 时间过滤条件
if time_range:
# 判断是时间点还是时间范围
@@ -50,79 +47,79 @@ async def query_chat_history(
# 时间范围:查询与时间范围有交集的记录
start_timestamp, end_timestamp = parse_time_range(time_range)
# 交集条件start_time < end_timestamp AND end_time > start_timestamp
time_filter = (
(ChatHistory.start_time < end_timestamp) &
(ChatHistory.end_time > start_timestamp)
)
time_filter = (ChatHistory.start_time < end_timestamp) & (ChatHistory.end_time > start_timestamp)
else:
# 时间点查询包含该时间点的记录start_time <= time_point <= end_time
target_timestamp = parse_datetime_to_timestamp(time_range)
time_filter = (
(ChatHistory.start_time <= target_timestamp) &
(ChatHistory.end_time >= target_timestamp)
)
time_filter = (ChatHistory.start_time <= target_timestamp) & (ChatHistory.end_time >= target_timestamp)
query = query.where(time_filter)
# 执行查询
records = list(query.order_by(ChatHistory.start_time.desc()).limit(50))
# 如果有关键词,进一步过滤
if keyword:
# 解析多个关键词(支持空格、逗号等分隔符)
keywords_list = parse_keywords_string(keyword)
if not keywords_list:
keywords_list = [keyword.strip()] if keyword.strip() else []
# 转换为小写以便匹配
keywords_lower = [kw.lower() for kw in keywords_list if kw.strip()]
if not keywords_lower:
return "关键词为空"
filtered_records = []
for record in records:
# 在theme、keywords、summary、original_text中搜索
theme = (record.theme or "").lower()
summary = (record.summary or "").lower()
original_text = (record.original_text or "").lower()
# 解析record中的keywords JSON
record_keywords_list = []
if record.keywords:
try:
keywords_data = json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
keywords_data = (
json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
)
if isinstance(keywords_data, list):
record_keywords_list = [str(k).lower() for k in keywords_data]
except (json.JSONDecodeError, TypeError, ValueError):
pass
# 根据匹配模式检查关键词
matched = False
if fuzzy:
# 模糊匹配只要包含任意一个关键词即匹配OR关系
for kw in keywords_lower:
if (kw in theme or
kw in summary or
kw in original_text or
any(kw in k for k in record_keywords_list)):
if (
kw in theme
or kw in summary
or kw in original_text
or any(kw in k for k in record_keywords_list)
):
matched = True
break
else:
# 全匹配必须包含所有关键词才匹配AND关系
matched = True
for kw in keywords_lower:
kw_matched = (kw in theme or
kw in summary or
kw in original_text or
any(kw in k for k in record_keywords_list))
kw_matched = (
kw in theme
or kw in summary
or kw in original_text
or any(kw in k for k in record_keywords_list)
)
if not kw_matched:
matched = False
break
if matched:
filtered_records.append(record)
if not filtered_records:
keywords_str = "".join(keywords_list)
match_mode = "包含任意一个关键词" if fuzzy else "包含所有关键词"
@@ -130,9 +127,9 @@ async def query_chat_history(
return f"未找到{match_mode}'{keywords_str}'且在指定时间范围内的聊天记录概述"
else:
return f"未找到{match_mode}'{keywords_str}'的聊天记录概述"
records = filtered_records
# 如果没有记录(可能是时间范围查询但没有匹配的记录)
if not records:
if time_range:
@@ -148,22 +145,23 @@ async def query_chat_history(
record.count = (record.count or 0) + 1
except Exception as update_error:
logger.error(f"更新聊天记录概述计数失败: {update_error}")
# 构建结果文本
results = []
for record in records_to_use: # 最多返回3条记录
result_parts = []
# 添加主题
if record.theme:
result_parts.append(f"主题:{record.theme}")
# 添加时间范围
from datetime import datetime
start_str = datetime.fromtimestamp(record.start_time).strftime("%Y-%m-%d %H:%M:%S")
end_str = datetime.fromtimestamp(record.end_time).strftime("%Y-%m-%d %H:%M:%S")
result_parts.append(f"时间:{start_str} - {end_str}")
# 添加概括优先使用summary如果没有则使用original_text的前200字符
if record.summary:
result_parts.append(f"概括:{record.summary}")
@@ -172,18 +170,18 @@ async def query_chat_history(
if len(record.original_text) > 200:
text_preview += "..."
result_parts.append(f"内容:{text_preview}")
results.append("\n".join(result_parts))
if not results:
return "未找到相关聊天记录概述"
response_text = "\n\n---\n\n".join(results)
if len(records) > len(records_to_use):
omitted_count = len(records) - len(records_to_use)
response_text += f"\n\n(还有{omitted_count}条历史记录已省略)"
return response_text
except Exception as e:
logger.error(f"查询聊天历史概述失败: {e}")
return f"查询失败: {str(e)}"
@@ -199,20 +197,20 @@ def register_tool():
"name": "keyword",
"type": "string",
"description": "关键词(可选,支持多个关键词,可用空格、逗号、斜杠等分隔,如:'麦麦 百度网盘''麦麦,百度网盘'。用于在主题、关键词、概括、原文中搜索)",
"required": False
"required": False,
},
{
"name": "time_range",
"type": "string",
"description": "时间范围或时间点(可选)。格式:'YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS'(时间范围,查询与时间范围有交集的记录)或 'YYYY-MM-DD HH:MM:SS'(时间点,查询包含该时间点的记录)",
"required": False
"required": False,
},
{
"name": "fuzzy",
"type": "boolean",
"description": "是否使用模糊匹配模式默认True。True表示模糊匹配只要包含任意一个关键词即匹配OR关系False表示全匹配必须包含所有关键词才匹配AND关系",
"required": False
}
"required": False,
},
],
execute_func=query_chat_history
execute_func=query_chat_history,
)

View File

@@ -73,5 +73,3 @@ def register_tool():
],
execute_func=query_lpmm_knowledge,
)

View File

@@ -14,23 +14,25 @@ logger = get_logger("memory_retrieval_tools")
def _format_group_nick_names(group_nick_name_field) -> str:
"""格式化群昵称信息
Args:
group_nick_name_field: 群昵称字段可能是字符串JSON或None
Returns:
str: 格式化后的群昵称信息字符串
"""
if not group_nick_name_field:
return ""
try:
# 解析JSON格式的群昵称列表
group_nick_names_data = json.loads(group_nick_name_field) if isinstance(group_nick_name_field, str) else group_nick_name_field
group_nick_names_data = (
json.loads(group_nick_name_field) if isinstance(group_nick_name_field, str) else group_nick_name_field
)
if not isinstance(group_nick_names_data, list) or not group_nick_names_data:
return ""
# 格式化群昵称列表
group_nick_list = []
for item in group_nick_names_data:
@@ -41,7 +43,7 @@ def _format_group_nick_names(group_nick_name_field) -> str:
elif isinstance(item, str):
# 兼容旧格式(如果存在)
group_nick_list.append(f" - {item}")
if group_nick_list:
return "群昵称:\n" + "\n".join(group_nick_list)
return ""
@@ -58,10 +60,10 @@ def _format_group_nick_names(group_nick_name_field) -> str:
async def query_person_info(person_name: str) -> str:
"""根据person_name查询用户信息使用模糊查询
Args:
person_name: 用户名称person_name字段
Returns:
str: 查询结果,包含用户的所有信息
"""
@@ -69,37 +71,35 @@ async def query_person_info(person_name: str) -> str:
person_name = str(person_name).strip()
if not person_name:
return "用户名称为空"
# 构建查询条件(使用模糊查询)
query = PersonInfo.select().where(
PersonInfo.person_name.contains(person_name)
)
query = PersonInfo.select().where(PersonInfo.person_name.contains(person_name))
# 执行查询
records = list(query.limit(20)) # 最多返回20条记录
if not records:
return f"未找到模糊匹配'{person_name}'的用户信息"
# 区分精确匹配和模糊匹配的结果
exact_matches = []
fuzzy_matches = []
for record in records:
# 检查是否是精确匹配
if record.person_name and record.person_name.strip() == person_name:
exact_matches.append(record)
else:
fuzzy_matches.append(record)
# 构建结果文本
results = []
# 先处理精确匹配的结果
for record in exact_matches:
result_parts = []
result_parts.append("【精确匹配】") # 标注为精确匹配
# 基本信息
if record.person_name:
result_parts.append(f"用户名称:{record.person_name}")
@@ -111,19 +111,19 @@ async def query_person_info(person_name: str) -> str:
result_parts.append(f"平台:{record.platform}")
if record.user_id:
result_parts.append(f"平台用户ID{record.user_id}")
# 群昵称信息
group_nick_name_str = _format_group_nick_names(getattr(record, "group_nick_name", None))
if group_nick_name_str:
result_parts.append(group_nick_name_str)
# 名称设定原因
if record.name_reason:
result_parts.append(f"名称设定原因:{record.name_reason}")
# 认识状态
result_parts.append(f"是否已认识:{'' if record.is_known else ''}")
# 时间信息
if record.know_since:
know_since_str = datetime.fromtimestamp(record.know_since).strftime("%Y-%m-%d %H:%M:%S")
@@ -133,11 +133,15 @@ async def query_person_info(person_name: str) -> str:
result_parts.append(f"最后认识时间:{last_know_str}")
if record.know_times:
result_parts.append(f"认识次数:{int(record.know_times)}")
# 记忆点memory_points
if record.memory_points:
try:
memory_points_data = json.loads(record.memory_points) if isinstance(record.memory_points, str) else record.memory_points
memory_points_data = (
json.loads(record.memory_points)
if isinstance(record.memory_points, str)
else record.memory_points
)
if isinstance(memory_points_data, list) and memory_points_data:
# 解析记忆点格式category:content:weight
memory_list = []
@@ -151,7 +155,7 @@ async def query_person_info(person_name: str) -> str:
memory_list.append(f" - [{category}] {content} (权重: {weight})")
else:
memory_list.append(f" - {memory_point}")
if memory_list:
result_parts.append("记忆点:\n" + "\n".join(memory_list))
except (json.JSONDecodeError, TypeError, ValueError) as e:
@@ -161,14 +165,14 @@ async def query_person_info(person_name: str) -> str:
if len(str(record.memory_points)) > 200:
memory_preview += "..."
result_parts.append(f"记忆点(原始数据):{memory_preview}")
results.append("\n".join(result_parts))
# 再处理模糊匹配的结果
for record in fuzzy_matches:
result_parts = []
result_parts.append("【模糊匹配】") # 标注为模糊匹配
# 基本信息
if record.person_name:
result_parts.append(f"用户名称:{record.person_name}")
@@ -180,19 +184,19 @@ async def query_person_info(person_name: str) -> str:
result_parts.append(f"平台:{record.platform}")
if record.user_id:
result_parts.append(f"平台用户ID{record.user_id}")
# 群昵称信息
group_nick_name_str = _format_group_nick_names(getattr(record, "group_nick_name", None))
if group_nick_name_str:
result_parts.append(group_nick_name_str)
# 名称设定原因
if record.name_reason:
result_parts.append(f"名称设定原因:{record.name_reason}")
# 认识状态
result_parts.append(f"是否已认识:{'' if record.is_known else ''}")
# 时间信息
if record.know_since:
know_since_str = datetime.fromtimestamp(record.know_since).strftime("%Y-%m-%d %H:%M:%S")
@@ -202,11 +206,15 @@ async def query_person_info(person_name: str) -> str:
result_parts.append(f"最后认识时间:{last_know_str}")
if record.know_times:
result_parts.append(f"认识次数:{int(record.know_times)}")
# 记忆点memory_points
if record.memory_points:
try:
memory_points_data = json.loads(record.memory_points) if isinstance(record.memory_points, str) else record.memory_points
memory_points_data = (
json.loads(record.memory_points)
if isinstance(record.memory_points, str)
else record.memory_points
)
if isinstance(memory_points_data, list) and memory_points_data:
# 解析记忆点格式category:content:weight
memory_list = []
@@ -220,7 +228,7 @@ async def query_person_info(person_name: str) -> str:
memory_list.append(f" - [{category}] {content} (权重: {weight})")
else:
memory_list.append(f" - {memory_point}")
if memory_list:
result_parts.append("记忆点:\n" + "\n".join(memory_list))
except (json.JSONDecodeError, TypeError, ValueError) as e:
@@ -230,20 +238,20 @@ async def query_person_info(person_name: str) -> str:
if len(str(record.memory_points)) > 200:
memory_preview += "..."
result_parts.append(f"记忆点(原始数据):{memory_preview}")
results.append("\n".join(result_parts))
# 组合所有结果
if not results:
return f"未找到匹配'{person_name}'的用户信息"
response_text = "\n\n---\n\n".join(results)
# 添加统计信息
total_count = len(records)
exact_count = len(exact_matches)
fuzzy_count = len(fuzzy_matches)
# 显示精确匹配和模糊匹配的统计
if exact_count > 0 or fuzzy_count > 0:
stats_parts = []
@@ -257,13 +265,13 @@ async def query_person_info(person_name: str) -> str:
response_text = f"找到 {total_count} 条匹配的用户信息:\n\n{response_text}"
else:
response_text = f"找到用户信息:\n\n{response_text}"
# 如果结果数量达到限制,添加提示
if total_count >= 20:
response_text += "\n\n(已显示前20条结果可能还有更多匹配记录)"
return response_text
except Exception as e:
logger.error(f"查询用户信息失败: {e}")
return f"查询失败: {str(e)}"
@@ -275,13 +283,7 @@ def register_tool():
name="query_person_info",
description="根据查询某个用户的所有信息。名称、昵称、平台、用户ID、qq号、群昵称等",
parameters=[
{
"name": "person_name",
"type": "string",
"description": "用户名称,用于查询用户信息",
"required": True
}
{"name": "person_name", "type": "string", "description": "用户名称,用于查询用户信息", "required": True}
],
execute_func=query_person_info
execute_func=query_person_info,
)

View File

@@ -47,10 +47,10 @@ class MemoryRetrievalTool:
async def execute(self, **kwargs) -> str:
"""执行工具"""
return await self.execute_func(**kwargs)
def get_tool_definition(self) -> Dict[str, Any]:
"""获取工具定义用于LLM function calling
Returns:
Dict[str, Any]: 工具定义字典格式与BaseTool一致
格式: {"name": str, "description": str, "parameters": List[Tuple]}
@@ -58,14 +58,14 @@ class MemoryRetrievalTool:
# 转换参数格式为元组列表格式与BaseTool一致
# 格式: [("param_name", ToolParamType, "description", required, enum_values)]
param_tuples = []
for param in self.parameters:
param_name = param.get("name", "")
param_type_str = param.get("type", "string").lower()
param_desc = param.get("description", "")
is_required = param.get("required", False)
enum_values = param.get("enum", None)
# 转换类型字符串到ToolParamType
type_mapping = {
"string": ToolParamType.STRING,
@@ -76,18 +76,14 @@ class MemoryRetrievalTool:
"bool": ToolParamType.BOOLEAN,
}
param_type = type_mapping.get(param_type_str, ToolParamType.STRING)
# 构建参数元组
param_tuple = (param_name, param_type, param_desc, is_required, enum_values)
param_tuples.append(param_tuple)
# 构建工具定义格式与BaseTool.get_tool_definition()一致
tool_def = {
"name": self.name,
"description": self.description,
"parameters": param_tuples
}
tool_def = {"name": self.name, "description": self.description, "parameters": param_tuples}
return tool_def
@@ -126,10 +122,10 @@ class MemoryRetrievalToolRegistry:
action_types.append("final_answer")
action_types.append("no_answer")
return "".join([f'"{at}"' for at in action_types])
def get_tool_definitions(self) -> List[Dict[str, Any]]:
"""获取所有工具的定义列表用于LLM function calling
Returns:
List[Dict[str, Any]]: 工具定义列表,每个元素是一个工具定义字典
"""

View File

@@ -162,7 +162,12 @@ def levenshtein_distance(s1: str, s2: str) -> int:
class Person:
@classmethod
def register_person(
cls, platform: str, user_id: str, nickname: str, group_id: Optional[str] = None, group_nick_name: Optional[str] = None
cls,
platform: str,
user_id: str,
nickname: str,
group_id: Optional[str] = None,
group_nick_name: Optional[str] = None,
):
"""
注册新用户的类方法
@@ -727,7 +732,7 @@ person_info_manager = PersonInfoManager()
async def store_person_memory_from_answer(person_name: str, memory_content: str, chat_id: str) -> None:
"""将人物信息存入person_info的memory_points
Args:
person_name: 人物名称
memory_content: 记忆内容
@@ -739,13 +744,13 @@ async def store_person_memory_from_answer(person_name: str, memory_content: str,
if not chat_stream:
logger.warning(f"无法获取chat_stream for chat_id: {chat_id}")
return
platform = chat_stream.platform
# 尝试从person_name查找person_id
# 首先尝试通过person_name查找
person_id = get_person_id_by_person_name(person_name)
if not person_id:
# 如果通过person_name找不到尝试从chat_stream获取user_info
if chat_stream.user_info:
@@ -754,25 +759,25 @@ async def store_person_memory_from_answer(person_name: str, memory_content: str,
else:
logger.warning(f"无法确定person_id for person_name: {person_name}, chat_id: {chat_id}")
return
# 创建或获取Person对象
person = Person(person_id=person_id)
if not person.is_known:
logger.warning(f"用户 {person_name} (person_id: {person_id}) 尚未认识,无法存储记忆")
return
# 确定记忆分类可以根据memory_content判断这里使用通用分类
category = "其他" # 默认分类,可以根据需要调整
# 记忆点格式category:content:weight
weight = "1.0" # 默认权重
memory_point = f"{category}:{memory_content}:{weight}"
# 添加到memory_points
if not person.memory_points:
person.memory_points = []
# 检查是否已存在相似的记忆点(避免重复)
is_duplicate = False
for existing_point in person.memory_points:
@@ -781,16 +786,20 @@ async def store_person_memory_from_answer(person_name: str, memory_content: str,
if len(parts) >= 2:
existing_content = parts[1].strip()
# 简单相似度检查(如果内容相同或非常相似,则跳过)
if existing_content == memory_content or memory_content in existing_content or existing_content in memory_content:
if (
existing_content == memory_content
or memory_content in existing_content
or existing_content in memory_content
):
is_duplicate = True
break
if not is_duplicate:
person.memory_points.append(memory_point)
person.sync_to_database()
logger.info(f"成功添加记忆点到 {person_name} (person_id: {person_id}): {memory_point}")
else:
logger.debug(f"记忆点已存在,跳过: {memory_point}")
except Exception as e:
logger.error(f"存储人物记忆失败: {e}")

View File

@@ -124,7 +124,6 @@ class ToolExecutor:
response, (reasoning_content, model_name, tool_calls) = await self.llm_model.generate_response_async(
prompt=prompt, tools=tools, raise_when_empty=False
)
# 执行工具调用
tool_results, used_tools = await self.execute_tool_calls(tool_calls)

View File

@@ -102,13 +102,13 @@ class EmojiAction(BaseAction):
# 5. 调用LLM
models = llm_api.get_available_models()
chat_model_config = models.get("replyer") # 使用字典访问方式
chat_model_config = models.get("utils") # 使用字典访问方式
if not chat_model_config:
logger.error(f"{self.log_prefix} 未找到'replyer'模型配置无法调用LLM")
return False, "未找到'replyer'模型配置"
logger.error(f"{self.log_prefix} 未找到'utils'模型配置无法调用LLM")
return False, "未找到'utils'模型配置"
success, chosen_emotion, _, _ = await llm_api.generate_with_model(
prompt, model_config=chat_model_config, request_type="emoji"
prompt, model_config=chat_model_config, request_type="emoji.select"
)
if not success:

View File

@@ -51,7 +51,7 @@ def _update_dict_preserve_comments(target: Any, source: Any) -> None:
"""
递归合并字典,保留 target 中的注释和格式
将 source 的值更新到 target 中(仅更新已存在的键)
Args:
target: 目标字典tomlkit 对象,包含注释)
source: 源字典(普通 dict 或 list
@@ -59,7 +59,7 @@ def _update_dict_preserve_comments(target: Any, source: Any) -> None:
# 如果 source 是列表,直接替换(数组表没有注释保留的意义)
if isinstance(source, list):
return # 调用者需要直接赋值
# 如果都是字典,递归合并
if isinstance(source, dict) and isinstance(target, dict):
for key, value in source.items():
@@ -319,6 +319,58 @@ async def update_bot_config_section(section_name: str, section_data: Any = Body(
raise HTTPException(status_code=500, detail=f"更新配置节失败: {str(e)}")
# ===== 原始 TOML 文件操作接口 =====
@router.get("/bot/raw")
async def get_bot_config_raw():
"""获取麦麦主程序配置的原始 TOML 内容"""
try:
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
if not os.path.exists(config_path):
raise HTTPException(status_code=404, detail="配置文件不存在")
with open(config_path, "r", encoding="utf-8") as f:
raw_content = f.read()
return {"success": True, "content": raw_content}
except HTTPException:
raise
except Exception as e:
logger.error(f"读取配置文件失败: {e}")
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}")
@router.post("/bot/raw")
async def update_bot_config_raw(raw_content: str = Body(..., embed=True)):
"""更新麦麦主程序配置(直接保存原始 TOML 内容,会先验证格式)"""
try:
# 验证 TOML 格式
try:
config_data = tomlkit.loads(raw_content)
except Exception as e:
raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}")
# 验证配置数据结构
try:
Config.from_dict(config_data)
except Exception as e:
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}")
# 保存配置文件
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
with open(config_path, "w", encoding="utf-8") as f:
f.write(raw_content)
logger.info("麦麦主程序配置已更新(原始模式)")
return {"success": True, "message": "配置已保存"}
except HTTPException:
raise
except Exception as e:
logger.error(f"保存配置文件失败: {e}")
raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}")
@router.post("/model/section/{section_name}")
async def update_model_config_section(section_name: str, section_data: Any = Body(...)):
"""更新模型配置的指定节(保留注释和格式)"""
@@ -364,3 +416,144 @@ async def update_model_config_section(section_name: str, section_data: Any = Bod
except Exception as e:
logger.error(f"更新配置节失败: {e}")
raise HTTPException(status_code=500, detail=f"更新配置节失败: {str(e)}")
# ===== 适配器配置管理接口 =====
@router.get("/adapter-config/path")
async def get_adapter_config_path():
"""获取保存的适配器配置文件路径"""
try:
# 从 data/webui.json 读取路径偏好
webui_data_path = os.path.join("data", "webui.json")
if not os.path.exists(webui_data_path):
return {"success": True, "path": None}
import json
with open(webui_data_path, "r", encoding="utf-8") as f:
webui_data = json.load(f)
adapter_config_path = webui_data.get("adapter_config_path")
if not adapter_config_path:
return {"success": True, "path": None}
# 检查文件是否存在并返回最后修改时间
if os.path.exists(adapter_config_path):
import datetime
mtime = os.path.getmtime(adapter_config_path)
last_modified = datetime.datetime.fromtimestamp(mtime).isoformat()
return {"success": True, "path": adapter_config_path, "lastModified": last_modified}
else:
return {"success": True, "path": adapter_config_path, "lastModified": None}
except Exception as e:
logger.error(f"获取适配器配置路径失败: {e}")
raise HTTPException(status_code=500, detail=f"获取配置路径失败: {str(e)}")
@router.post("/adapter-config/path")
async def save_adapter_config_path(data: dict[str, str] = Body(...)):
"""保存适配器配置文件路径偏好"""
try:
path = data.get("path")
if not path:
raise HTTPException(status_code=400, detail="路径不能为空")
# 保存到 data/webui.json
webui_data_path = os.path.join("data", "webui.json")
import json
# 读取现有数据
if os.path.exists(webui_data_path):
with open(webui_data_path, "r", encoding="utf-8") as f:
webui_data = json.load(f)
else:
webui_data = {}
# 更新路径
webui_data["adapter_config_path"] = path
# 保存
os.makedirs("data", exist_ok=True)
with open(webui_data_path, "w", encoding="utf-8") as f:
json.dump(webui_data, f, ensure_ascii=False, indent=2)
logger.info(f"适配器配置路径已保存: {path}")
return {"success": True, "message": "路径已保存"}
except HTTPException:
raise
except Exception as e:
logger.error(f"保存适配器配置路径失败: {e}")
raise HTTPException(status_code=500, detail=f"保存路径失败: {str(e)}")
@router.get("/adapter-config")
async def get_adapter_config(path: str):
"""从指定路径读取适配器配置文件"""
try:
if not path:
raise HTTPException(status_code=400, detail="路径参数不能为空")
# 检查文件是否存在
if not os.path.exists(path):
raise HTTPException(status_code=404, detail=f"配置文件不存在: {path}")
# 检查文件扩展名
if not path.endswith(".toml"):
raise HTTPException(status_code=400, detail="只支持 .toml 格式的配置文件")
# 读取文件内容
with open(path, "r", encoding="utf-8") as f:
content = f.read()
logger.info(f"已读取适配器配置: {path}")
return {"success": True, "content": content}
except HTTPException:
raise
except Exception as e:
logger.error(f"读取适配器配置失败: {e}")
raise HTTPException(status_code=500, detail=f"读取配置失败: {str(e)}")
@router.post("/adapter-config")
async def save_adapter_config(data: dict[str, str] = Body(...)):
"""保存适配器配置到指定路径"""
try:
path = data.get("path")
content = data.get("content")
if not path:
raise HTTPException(status_code=400, detail="路径不能为空")
if content is None:
raise HTTPException(status_code=400, detail="配置内容不能为空")
# 检查文件扩展名
if not path.endswith(".toml"):
raise HTTPException(status_code=400, detail="只支持 .toml 格式的配置文件")
# 验证 TOML 格式
try:
import toml
toml.loads(content)
except Exception as e:
raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}")
# 确保目录存在
os.makedirs(os.path.dirname(path) if os.path.dirname(path) else ".", exist_ok=True)
# 保存文件
with open(path, "w", encoding="utf-8") as f:
f.write(content)
logger.info(f"适配器配置已保存: {path}")
return {"success": True, "message": "配置已保存"}
except HTTPException:
raise
except Exception as e:
logger.error(f"保存适配器配置失败: {e}")
raise HTTPException(status_code=500, detail=f"保存配置失败: {str(e)}")

View File

@@ -1,4 +1,5 @@
"""表情包管理 API 路由"""
from fastapi import APIRouter, HTTPException, Header, Query
from fastapi.responses import FileResponse
from pydantic import BaseModel
@@ -18,6 +19,7 @@ router = APIRouter(prefix="/emoji", tags=["Emoji"])
class EmojiResponse(BaseModel):
"""表情包响应"""
id: int
full_path: str
format: str
@@ -26,7 +28,7 @@ class EmojiResponse(BaseModel):
query_count: int
is_registered: bool
is_banned: bool
emotion: Optional[List[str]] # 解析后的 JSON
emotion: Optional[str] # 直接返回字符串
record_time: float
register_time: Optional[float]
usage_count: int
@@ -35,6 +37,7 @@ class EmojiResponse(BaseModel):
class EmojiListResponse(BaseModel):
"""表情包列表响应"""
success: bool
total: int
page: int
@@ -44,20 +47,23 @@ class EmojiListResponse(BaseModel):
class EmojiDetailResponse(BaseModel):
"""表情包详情响应"""
success: bool
data: EmojiResponse
class EmojiUpdateRequest(BaseModel):
"""表情包更新请求"""
description: Optional[str] = None
is_registered: Optional[bool] = None
is_banned: Optional[bool] = None
emotion: Optional[List[str]] = None
emotion: Optional[str] = None
class EmojiUpdateResponse(BaseModel):
"""表情包更新响应"""
success: bool
message: str
data: Optional[EmojiResponse] = None
@@ -65,34 +71,41 @@ class EmojiUpdateResponse(BaseModel):
class EmojiDeleteResponse(BaseModel):
"""表情包删除响应"""
success: bool
message: str
class BatchDeleteRequest(BaseModel):
"""批量删除请求"""
emoji_ids: List[int]
class BatchDeleteResponse(BaseModel):
"""批量删除响应"""
success: bool
message: str
deleted_count: int
failed_count: int
failed_ids: List[int] = []
def verify_auth_token(authorization: Optional[str]) -> bool:
"""验证认证 Token"""
if not authorization or not authorization.startswith("Bearer "):
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
token = authorization.replace("Bearer ", "")
token_manager = get_token_manager()
if not token_manager.verify_token(token):
raise HTTPException(status_code=401, detail="Token 无效或已过期")
return True
def parse_emotion(emotion_str: Optional[str]) -> Optional[List[str]]:
"""解析情感标签 JSON 字符串"""
if not emotion_str:
return None
try:
return json.loads(emotion_str)
except (json.JSONDecodeError, TypeError):
return None
def emoji_to_response(emoji: Emoji) -> EmojiResponse:
"""将 Emoji 模型转换为响应对象"""
return EmojiResponse(
@@ -104,7 +117,7 @@ def emoji_to_response(emoji: Emoji) -> EmojiResponse:
query_count=emoji.query_count,
is_registered=emoji.is_registered,
is_banned=emoji.is_banned,
emotion=parse_emotion(emoji.emotion),
emotion=str(emoji.emotion) if emoji.emotion is not None else None,
record_time=emoji.record_time,
register_time=emoji.register_time,
usage_count=emoji.usage_count,
@@ -120,11 +133,11 @@ async def get_emoji_list(
is_registered: Optional[bool] = Query(None, description="是否已注册筛选"),
is_banned: Optional[bool] = Query(None, description="是否被禁用筛选"),
format: Optional[str] = Query(None, description="格式筛选"),
authorization: Optional[str] = Header(None)
authorization: Optional[str] = Header(None),
):
"""
获取表情包列表
Args:
page: 页码 (从 1 开始)
page_size: 每页数量 (1-100)
@@ -133,61 +146,51 @@ async def get_emoji_list(
is_banned: 是否被禁用筛选
format: 格式筛选
authorization: Authorization header
Returns:
表情包列表
"""
try:
verify_auth_token(authorization)
# 构建查询
query = Emoji.select()
# 搜索过滤
if search:
query = query.where(
(Emoji.description.contains(search)) |
(Emoji.emoji_hash.contains(search))
)
query = query.where((Emoji.description.contains(search)) | (Emoji.emoji_hash.contains(search)))
# 注册状态过滤
if is_registered is not None:
query = query.where(Emoji.is_registered == is_registered)
# 禁用状态过滤
if is_banned is not None:
query = query.where(Emoji.is_banned == is_banned)
# 格式过滤
if format:
query = query.where(Emoji.format == format)
# 排序:使用次数倒序,然后按记录时间倒序
from peewee import Case
query = query.order_by(
Emoji.usage_count.desc(),
Case(None, [(Emoji.record_time.is_null(), 1)], 0),
Emoji.record_time.desc()
Emoji.usage_count.desc(), Case(None, [(Emoji.record_time.is_null(), 1)], 0), Emoji.record_time.desc()
)
# 获取总数
total = query.count()
# 分页
offset = (page - 1) * page_size
emojis = query.offset(offset).limit(page_size)
# 转换为响应对象
data = [emoji_to_response(emoji) for emoji in emojis]
return EmojiListResponse(
success=True,
total=total,
page=page,
page_size=page_size,
data=data
)
return EmojiListResponse(success=True, total=total, page=page, page_size=page_size, data=data)
except HTTPException:
raise
except Exception as e:
@@ -196,33 +199,27 @@ async def get_emoji_list(
@router.get("/{emoji_id}", response_model=EmojiDetailResponse)
async def get_emoji_detail(
emoji_id: int,
authorization: Optional[str] = Header(None)
):
async def get_emoji_detail(emoji_id: int, authorization: Optional[str] = Header(None)):
"""
获取表情包详细信息
Args:
emoji_id: 表情包ID
authorization: Authorization header
Returns:
表情包详细信息
"""
try:
verify_auth_token(authorization)
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
if not emoji:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
return EmojiDetailResponse(
success=True,
data=emoji_to_response(emoji)
)
return EmojiDetailResponse(success=True, data=emoji_to_response(emoji))
except HTTPException:
raise
except Exception as e:
@@ -231,61 +228,50 @@ async def get_emoji_detail(
@router.patch("/{emoji_id}", response_model=EmojiUpdateResponse)
async def update_emoji(
emoji_id: int,
request: EmojiUpdateRequest,
authorization: Optional[str] = Header(None)
):
async def update_emoji(emoji_id: int, request: EmojiUpdateRequest, authorization: Optional[str] = Header(None)):
"""
增量更新表情包(只更新提供的字段)
Args:
emoji_id: 表情包ID
request: 更新请求(只包含需要更新的字段)
authorization: Authorization header
Returns:
更新结果
"""
try:
verify_auth_token(authorization)
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
if not emoji:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
# 只更新提供的字段
update_data = request.model_dump(exclude_unset=True)
if not update_data:
raise HTTPException(status_code=400, detail="未提供任何需要更新的字段")
# 处理情感标签(转换为 JSON
if 'emotion' in update_data:
if update_data['emotion'] is None:
update_data['emotion'] = None
else:
update_data['emotion'] = json.dumps(update_data['emotion'], ensure_ascii=False)
# emotion 字段直接使用字符串,无需转换
# 如果注册状态从 False 变为 True记录注册时间
if 'is_registered' in update_data and update_data['is_registered'] and not emoji.is_registered:
update_data['register_time'] = time.time()
if "is_registered" in update_data and update_data["is_registered"] and not emoji.is_registered:
update_data["register_time"] = time.time()
# 执行更新
for field, value in update_data.items():
setattr(emoji, field, value)
emoji.save()
logger.info(f"表情包已更新: ID={emoji_id}, 字段: {list(update_data.keys())}")
return EmojiUpdateResponse(
success=True,
message=f"成功更新 {len(update_data)} 个字段",
data=emoji_to_response(emoji)
success=True, message=f"成功更新 {len(update_data)} 个字段", data=emoji_to_response(emoji)
)
except HTTPException:
raise
except Exception as e:
@@ -294,41 +280,35 @@ async def update_emoji(
@router.delete("/{emoji_id}", response_model=EmojiDeleteResponse)
async def delete_emoji(
emoji_id: int,
authorization: Optional[str] = Header(None)
):
async def delete_emoji(emoji_id: int, authorization: Optional[str] = Header(None)):
"""
删除表情包
Args:
emoji_id: 表情包ID
authorization: Authorization header
Returns:
删除结果
"""
try:
verify_auth_token(authorization)
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
if not emoji:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
# 记录删除信息
emoji_hash = emoji.emoji_hash
# 执行删除
emoji.delete_instance()
logger.info(f"表情包已删除: ID={emoji_id}, hash={emoji_hash}")
return EmojiDeleteResponse(
success=True,
message=f"成功删除表情包: {emoji_hash}"
)
return EmojiDeleteResponse(success=True, message=f"成功删除表情包: {emoji_hash}")
except HTTPException:
raise
except Exception as e:
@@ -337,31 +317,29 @@ async def delete_emoji(
@router.get("/stats/summary")
async def get_emoji_stats(
authorization: Optional[str] = Header(None)
):
async def get_emoji_stats(authorization: Optional[str] = Header(None)):
"""
获取表情包统计数据
Args:
authorization: Authorization header
Returns:
统计数据
"""
try:
verify_auth_token(authorization)
total = Emoji.select().count()
registered = Emoji.select().where(Emoji.is_registered).count()
banned = Emoji.select().where(Emoji.is_banned).count()
# 按格式统计
formats = {}
for emoji in Emoji.select(Emoji.format):
fmt = emoji.format
formats[fmt] = formats.get(fmt, 0) + 1
# 获取最常用的表情包前10
top_used = Emoji.select().order_by(Emoji.usage_count.desc()).limit(10)
top_used_list = [
@@ -369,11 +347,11 @@ async def get_emoji_stats(
"id": emoji.id,
"emoji_hash": emoji.emoji_hash,
"description": emoji.description,
"usage_count": emoji.usage_count
"usage_count": emoji.usage_count,
}
for emoji in top_used
]
return {
"success": True,
"data": {
@@ -382,10 +360,10 @@ async def get_emoji_stats(
"banned": banned,
"unregistered": total - registered,
"formats": formats,
"top_used": top_used_list
}
"top_used": top_used_list,
},
}
except HTTPException:
raise
except Exception as e:
@@ -394,47 +372,40 @@ async def get_emoji_stats(
@router.post("/{emoji_id}/register", response_model=EmojiUpdateResponse)
async def register_emoji(
emoji_id: int,
authorization: Optional[str] = Header(None)
):
async def register_emoji(emoji_id: int, authorization: Optional[str] = Header(None)):
"""
注册表情包(快捷操作)
Args:
emoji_id: 表情包ID
authorization: Authorization header
Returns:
更新结果
"""
try:
verify_auth_token(authorization)
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
if not emoji:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
if emoji.is_registered:
raise HTTPException(status_code=400, detail="该表情包已经注册")
if emoji.is_banned:
raise HTTPException(status_code=400, detail="该表情包已被禁用,无法注册")
# 注册表情包
emoji.is_registered = True
emoji.register_time = time.time()
emoji.save()
logger.info(f"表情包已注册: ID={emoji_id}")
return EmojiUpdateResponse(
success=True,
message="表情包注册成功",
data=emoji_to_response(emoji)
)
return EmojiUpdateResponse(success=True, message="表情包注册成功", data=emoji_to_response(emoji))
except HTTPException:
raise
except Exception as e:
@@ -443,41 +414,34 @@ async def register_emoji(
@router.post("/{emoji_id}/ban", response_model=EmojiUpdateResponse)
async def ban_emoji(
emoji_id: int,
authorization: Optional[str] = Header(None)
):
async def ban_emoji(emoji_id: int, authorization: Optional[str] = Header(None)):
"""
禁用表情包(快捷操作)
Args:
emoji_id: 表情包ID
authorization: Authorization header
Returns:
更新结果
"""
try:
verify_auth_token(authorization)
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
if not emoji:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
# 禁用表情包(同时取消注册)
emoji.is_banned = True
emoji.is_registered = False
emoji.save()
logger.info(f"表情包已禁用: ID={emoji_id}")
return EmojiUpdateResponse(
success=True,
message="表情包禁用成功",
data=emoji_to_response(emoji)
)
return EmojiUpdateResponse(success=True, message="表情包禁用成功", data=emoji_to_response(emoji))
except HTTPException:
raise
except Exception as e:
@@ -489,16 +453,16 @@ async def ban_emoji(
async def get_emoji_thumbnail(
emoji_id: int,
token: Optional[str] = Query(None, description="访问令牌"),
authorization: Optional[str] = Header(None)
authorization: Optional[str] = Header(None),
):
"""
获取表情包缩略图
Args:
emoji_id: 表情包ID
token: 访问令牌(通过 query parameter
authorization: Authorization header
Returns:
表情包图片文件
"""
@@ -511,37 +475,88 @@ async def get_emoji_thumbnail(
else:
# 如果没有 query token则验证 Authorization header
verify_auth_token(authorization)
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
if not emoji:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
# 检查文件是否存在
if not os.path.exists(emoji.full_path):
raise HTTPException(status_code=404, detail="表情包文件不存在")
# 根据格式设置 MIME 类型
mime_types = {
'png': 'image/png',
'jpg': 'image/jpeg',
'jpeg': 'image/jpeg',
'gif': 'image/gif',
'webp': 'image/webp',
'bmp': 'image/bmp'
"png": "image/png",
"jpg": "image/jpeg",
"jpeg": "image/jpeg",
"gif": "image/gif",
"webp": "image/webp",
"bmp": "image/bmp",
}
media_type = mime_types.get(emoji.format.lower(), 'application/octet-stream')
return FileResponse(
path=emoji.full_path,
media_type=media_type,
filename=f"{emoji.emoji_hash}.{emoji.format}"
)
media_type = mime_types.get(emoji.format.lower(), "application/octet-stream")
return FileResponse(path=emoji.full_path, media_type=media_type, filename=f"{emoji.emoji_hash}.{emoji.format}")
except HTTPException:
raise
except Exception as e:
logger.exception(f"获取表情包缩略图失败: {e}")
raise HTTPException(status_code=500, detail=f"获取表情包缩略图失败: {str(e)}") from e
@router.post("/batch/delete", response_model=BatchDeleteResponse)
async def batch_delete_emojis(request: BatchDeleteRequest, authorization: Optional[str] = Header(None)):
"""
批量删除表情包
Args:
request: 包含emoji_ids列表的请求
authorization: Authorization header
Returns:
批量删除结果
"""
try:
verify_auth_token(authorization)
if not request.emoji_ids:
raise HTTPException(status_code=400, detail="未提供要删除的表情包ID")
deleted_count = 0
failed_count = 0
failed_ids = []
for emoji_id in request.emoji_ids:
try:
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
if emoji:
emoji.delete_instance()
deleted_count += 1
logger.info(f"批量删除表情包: {emoji_id}")
else:
failed_count += 1
failed_ids.append(emoji_id)
except Exception as e:
logger.error(f"删除表情包 {emoji_id} 失败: {e}")
failed_count += 1
failed_ids.append(emoji_id)
message = f"成功删除 {deleted_count} 个表情包"
if failed_count > 0:
message += f"{failed_count} 个失败"
return BatchDeleteResponse(
success=True,
message=message,
deleted_count=deleted_count,
failed_count=failed_count,
failed_ids=failed_ids,
)
except HTTPException:
raise
except Exception as e:
logger.exception(f"批量删除表情包失败: {e}")
raise HTTPException(status_code=500, detail=f"批量删除失败: {str(e)}") from e

View File

@@ -1,4 +1,5 @@
"""表达方式管理 API 路由"""
from fastapi import APIRouter, HTTPException, Header, Query
from pydantic import BaseModel
from typing import Optional, List
@@ -15,6 +16,7 @@ router = APIRouter(prefix="/expression", tags=["Expression"])
class ExpressionResponse(BaseModel):
"""表达方式响应"""
id: int
situation: str
style: str
@@ -27,6 +29,7 @@ class ExpressionResponse(BaseModel):
class ExpressionListResponse(BaseModel):
"""表达方式列表响应"""
success: bool
total: int
page: int
@@ -36,12 +39,14 @@ class ExpressionListResponse(BaseModel):
class ExpressionDetailResponse(BaseModel):
"""表达方式详情响应"""
success: bool
data: ExpressionResponse
class ExpressionCreateRequest(BaseModel):
"""表达方式创建请求"""
situation: str
style: str
context: Optional[str] = None
@@ -51,6 +56,7 @@ class ExpressionCreateRequest(BaseModel):
class ExpressionUpdateRequest(BaseModel):
"""表达方式更新请求"""
situation: Optional[str] = None
style: Optional[str] = None
context: Optional[str] = None
@@ -60,6 +66,7 @@ class ExpressionUpdateRequest(BaseModel):
class ExpressionUpdateResponse(BaseModel):
"""表达方式更新响应"""
success: bool
message: str
data: Optional[ExpressionResponse] = None
@@ -67,12 +74,14 @@ class ExpressionUpdateResponse(BaseModel):
class ExpressionDeleteResponse(BaseModel):
"""表达方式删除响应"""
success: bool
message: str
class ExpressionCreateResponse(BaseModel):
"""表达方式创建响应"""
success: bool
message: str
data: ExpressionResponse
@@ -82,13 +91,13 @@ def verify_auth_token(authorization: Optional[str]) -> bool:
"""验证认证 Token"""
if not authorization or not authorization.startswith("Bearer "):
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
token = authorization.replace("Bearer ", "")
token_manager = get_token_manager()
if not token_manager.verify_token(token):
raise HTTPException(status_code=401, detail="Token 无效或已过期")
return True
@@ -112,64 +121,58 @@ async def get_expression_list(
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
search: Optional[str] = Query(None, description="搜索关键词"),
chat_id: Optional[str] = Query(None, description="聊天ID筛选"),
authorization: Optional[str] = Header(None)
authorization: Optional[str] = Header(None),
):
"""
获取表达方式列表
Args:
page: 页码 (从 1 开始)
page_size: 每页数量 (1-100)
search: 搜索关键词 (匹配 situation, style, context)
chat_id: 聊天ID筛选
authorization: Authorization header
Returns:
表达方式列表
"""
try:
verify_auth_token(authorization)
# 构建查询
query = Expression.select()
# 搜索过滤
if search:
query = query.where(
(Expression.situation.contains(search)) |
(Expression.style.contains(search)) |
(Expression.context.contains(search))
(Expression.situation.contains(search))
| (Expression.style.contains(search))
| (Expression.context.contains(search))
)
# 聊天ID过滤
if chat_id:
query = query.where(Expression.chat_id == chat_id)
# 排序最后活跃时间倒序NULL 值放在最后)
from peewee import Case
query = query.order_by(
Case(None, [(Expression.last_active_time.is_null(), 1)], 0),
Expression.last_active_time.desc()
Case(None, [(Expression.last_active_time.is_null(), 1)], 0), Expression.last_active_time.desc()
)
# 获取总数
total = query.count()
# 分页
offset = (page - 1) * page_size
expressions = query.offset(offset).limit(page_size)
# 转换为响应对象
data = [expression_to_response(expr) for expr in expressions]
return ExpressionListResponse(
success=True,
total=total,
page=page,
page_size=page_size,
data=data
)
return ExpressionListResponse(success=True, total=total, page=page, page_size=page_size, data=data)
except HTTPException:
raise
except Exception as e:
@@ -178,33 +181,27 @@ async def get_expression_list(
@router.get("/{expression_id}", response_model=ExpressionDetailResponse)
async def get_expression_detail(
expression_id: int,
authorization: Optional[str] = Header(None)
):
async def get_expression_detail(expression_id: int, authorization: Optional[str] = Header(None)):
"""
获取表达方式详细信息
Args:
expression_id: 表达方式ID
authorization: Authorization header
Returns:
表达方式详细信息
"""
try:
verify_auth_token(authorization)
expression = Expression.get_or_none(Expression.id == expression_id)
if not expression:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
return ExpressionDetailResponse(
success=True,
data=expression_to_response(expression)
)
return ExpressionDetailResponse(success=True, data=expression_to_response(expression))
except HTTPException:
raise
except Exception as e:
@@ -213,25 +210,22 @@ async def get_expression_detail(
@router.post("/", response_model=ExpressionCreateResponse)
async def create_expression(
request: ExpressionCreateRequest,
authorization: Optional[str] = Header(None)
):
async def create_expression(request: ExpressionCreateRequest, authorization: Optional[str] = Header(None)):
"""
创建新的表达方式
Args:
request: 创建请求
authorization: Authorization header
Returns:
创建结果
"""
try:
verify_auth_token(authorization)
current_time = time.time()
# 创建表达方式
expression = Expression.create(
situation=request.situation,
@@ -242,15 +236,13 @@ async def create_expression(
last_active_time=current_time,
create_date=current_time,
)
logger.info(f"表达方式已创建: ID={expression.id}, situation={request.situation}")
return ExpressionCreateResponse(
success=True,
message="表达方式创建成功",
data=expression_to_response(expression)
success=True, message="表达方式创建成功", data=expression_to_response(expression)
)
except HTTPException:
raise
except Exception as e:
@@ -260,52 +252,48 @@ async def create_expression(
@router.patch("/{expression_id}", response_model=ExpressionUpdateResponse)
async def update_expression(
expression_id: int,
request: ExpressionUpdateRequest,
authorization: Optional[str] = Header(None)
expression_id: int, request: ExpressionUpdateRequest, authorization: Optional[str] = Header(None)
):
"""
增量更新表达方式(只更新提供的字段)
Args:
expression_id: 表达方式ID
request: 更新请求(只包含需要更新的字段)
authorization: Authorization header
Returns:
更新结果
"""
try:
verify_auth_token(authorization)
expression = Expression.get_or_none(Expression.id == expression_id)
if not expression:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
# 只更新提供的字段
update_data = request.model_dump(exclude_unset=True)
if not update_data:
raise HTTPException(status_code=400, detail="未提供任何需要更新的字段")
# 更新最后活跃时间
update_data['last_active_time'] = time.time()
update_data["last_active_time"] = time.time()
# 执行更新
for field, value in update_data.items():
setattr(expression, field, value)
expression.save()
logger.info(f"表达方式已更新: ID={expression_id}, 字段: {list(update_data.keys())}")
return ExpressionUpdateResponse(
success=True,
message=f"成功更新 {len(update_data)} 个字段",
data=expression_to_response(expression)
success=True, message=f"成功更新 {len(update_data)} 个字段", data=expression_to_response(expression)
)
except HTTPException:
raise
except Exception as e:
@@ -314,41 +302,35 @@ async def update_expression(
@router.delete("/{expression_id}", response_model=ExpressionDeleteResponse)
async def delete_expression(
expression_id: int,
authorization: Optional[str] = Header(None)
):
async def delete_expression(expression_id: int, authorization: Optional[str] = Header(None)):
"""
删除表达方式
Args:
expression_id: 表达方式ID
authorization: Authorization header
Returns:
删除结果
"""
try:
verify_auth_token(authorization)
expression = Expression.get_or_none(Expression.id == expression_id)
if not expression:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
# 记录删除信息
situation = expression.situation
# 执行删除
expression.delete_instance()
logger.info(f"表达方式已删除: ID={expression_id}, situation={situation}")
return ExpressionDeleteResponse(
success=True,
message=f"成功删除表达方式: {situation}"
)
return ExpressionDeleteResponse(success=True, message=f"成功删除表达方式: {situation}")
except HTTPException:
raise
except Exception as e:
@@ -356,47 +338,93 @@ async def delete_expression(
raise HTTPException(status_code=500, detail=f"删除表达方式失败: {str(e)}") from e
class BatchDeleteRequest(BaseModel):
"""批量删除请求"""
ids: List[int]
@router.post("/batch/delete", response_model=ExpressionDeleteResponse)
async def batch_delete_expressions(request: BatchDeleteRequest, authorization: Optional[str] = Header(None)):
"""
批量删除表达方式
Args:
request: 包含要删除的ID列表的请求
authorization: Authorization header
Returns:
删除结果
"""
try:
verify_auth_token(authorization)
if not request.ids:
raise HTTPException(status_code=400, detail="未提供要删除的表达方式ID")
# 查找所有要删除的表达方式
expressions = Expression.select().where(Expression.id.in_(request.ids))
found_ids = [expr.id for expr in expressions]
# 检查是否有未找到的ID
not_found_ids = set(request.ids) - set(found_ids)
if not_found_ids:
logger.warning(f"部分表达方式未找到: {not_found_ids}")
# 执行批量删除
deleted_count = Expression.delete().where(Expression.id.in_(found_ids)).execute()
logger.info(f"批量删除了 {deleted_count} 个表达方式")
return ExpressionDeleteResponse(success=True, message=f"成功删除 {deleted_count} 个表达方式")
except HTTPException:
raise
except Exception as e:
logger.exception(f"批量删除表达方式失败: {e}")
raise HTTPException(status_code=500, detail=f"批量删除表达方式失败: {str(e)}") from e
@router.get("/stats/summary")
async def get_expression_stats(
authorization: Optional[str] = Header(None)
):
async def get_expression_stats(authorization: Optional[str] = Header(None)):
"""
获取表达方式统计数据
Args:
authorization: Authorization header
Returns:
统计数据
"""
try:
verify_auth_token(authorization)
total = Expression.select().count()
# 按 chat_id 统计
chat_stats = {}
for expr in Expression.select(Expression.chat_id):
chat_id = expr.chat_id
chat_stats[chat_id] = chat_stats.get(chat_id, 0) + 1
# 获取最近创建的记录数7天内
seven_days_ago = time.time() - (7 * 24 * 60 * 60)
recent = Expression.select().where(
(Expression.create_date.is_null(False)) &
(Expression.create_date >= seven_days_ago)
).count()
recent = (
Expression.select()
.where((Expression.create_date.is_null(False)) & (Expression.create_date >= seven_days_ago))
.count()
)
return {
"success": True,
"data": {
"total": total,
"recent_7days": recent,
"chat_count": len(chat_stats),
"top_chats": dict(sorted(chat_stats.items(), key=lambda x: x[1], reverse=True)[:10])
}
"top_chats": dict(sorted(chat_stats.items(), key=lambda x: x[1], reverse=True)[:10]),
},
}
except HTTPException:
raise
except Exception as e:

View File

@@ -1,4 +1,5 @@
"""Git 镜像源服务 - 支持多镜像源、错误重试、Git 克隆和 Raw 文件获取"""
from typing import Optional, List, Dict, Any
from enum import Enum
import httpx
@@ -15,6 +16,7 @@ logger = get_logger("webui.git_mirror")
# 导入进度更新函数(避免循环导入)
_update_progress = None
def set_update_progress_callback(callback):
"""设置进度更新回调函数"""
global _update_progress
@@ -23,6 +25,7 @@ def set_update_progress_callback(callback):
class MirrorType(str, Enum):
"""镜像源类型"""
GH_PROXY = "gh-proxy" # gh-proxy 主节点
HK_GH_PROXY = "hk-gh-proxy" # gh-proxy 香港节点
CDN_GH_PROXY = "cdn-gh-proxy" # gh-proxy CDN 节点
@@ -34,10 +37,10 @@ class MirrorType(str, Enum):
class GitMirrorConfig:
"""Git 镜像源配置管理"""
# 配置文件路径
CONFIG_FILE = Path("data/webui.json")
# 默认镜像源配置
DEFAULT_MIRRORS = [
{
@@ -47,7 +50,7 @@ class GitMirrorConfig:
"clone_prefix": "https://gh-proxy.org/https://github.com",
"enabled": True,
"priority": 1,
"created_at": None
"created_at": None,
},
{
"id": "hk-gh-proxy",
@@ -56,7 +59,7 @@ class GitMirrorConfig:
"clone_prefix": "https://hk.gh-proxy.org/https://github.com",
"enabled": True,
"priority": 2,
"created_at": None
"created_at": None,
},
{
"id": "cdn-gh-proxy",
@@ -65,7 +68,7 @@ class GitMirrorConfig:
"clone_prefix": "https://cdn.gh-proxy.org/https://github.com",
"enabled": True,
"priority": 3,
"created_at": None
"created_at": None,
},
{
"id": "edgeone-gh-proxy",
@@ -74,7 +77,7 @@ class GitMirrorConfig:
"clone_prefix": "https://edgeone.gh-proxy.org/https://github.com",
"enabled": True,
"priority": 4,
"created_at": None
"created_at": None,
},
{
"id": "meyzh-github",
@@ -83,7 +86,7 @@ class GitMirrorConfig:
"clone_prefix": "https://meyzh.github.io/https://github.com",
"enabled": True,
"priority": 5,
"created_at": None
"created_at": None,
},
{
"id": "github",
@@ -92,23 +95,23 @@ class GitMirrorConfig:
"clone_prefix": "https://github.com",
"enabled": True,
"priority": 999,
"created_at": None
}
"created_at": None,
},
]
def __init__(self):
"""初始化配置管理器"""
self.config_file = self.CONFIG_FILE
self.mirrors: List[Dict[str, Any]] = []
self._load_config()
def _load_config(self) -> None:
"""加载配置文件"""
try:
if self.config_file.exists():
with open(self.config_file, 'r', encoding='utf-8') as f:
with open(self.config_file, "r", encoding="utf-8") as f:
data = json.load(f)
# 检查是否有镜像源配置
if "git_mirrors" not in data or not data["git_mirrors"]:
logger.info("配置文件中未找到镜像源配置,使用默认配置")
@@ -122,59 +125,59 @@ class GitMirrorConfig:
except Exception as e:
logger.error(f"加载配置文件失败: {e}")
self._init_default_mirrors()
def _init_default_mirrors(self) -> None:
"""初始化默认镜像源"""
current_time = datetime.now().isoformat()
self.mirrors = []
for mirror in self.DEFAULT_MIRRORS:
mirror_copy = mirror.copy()
mirror_copy["created_at"] = current_time
self.mirrors.append(mirror_copy)
self._save_config()
logger.info(f"已初始化 {len(self.mirrors)} 个默认镜像源")
def _save_config(self) -> None:
"""保存配置到文件"""
try:
# 确保目录存在
self.config_file.parent.mkdir(parents=True, exist_ok=True)
# 读取现有配置
existing_data = {}
if self.config_file.exists():
with open(self.config_file, 'r', encoding='utf-8') as f:
with open(self.config_file, "r", encoding="utf-8") as f:
existing_data = json.load(f)
# 更新镜像源配置
existing_data["git_mirrors"] = self.mirrors
# 写入文件
with open(self.config_file, 'w', encoding='utf-8') as f:
with open(self.config_file, "w", encoding="utf-8") as f:
json.dump(existing_data, f, indent=2, ensure_ascii=False)
logger.debug(f"配置已保存到 {self.config_file}")
except Exception as e:
logger.error(f"保存配置文件失败: {e}")
def get_all_mirrors(self) -> List[Dict[str, Any]]:
"""获取所有镜像源"""
return self.mirrors.copy()
def get_enabled_mirrors(self) -> List[Dict[str, Any]]:
"""获取所有启用的镜像源,按优先级排序"""
enabled = [m for m in self.mirrors if m.get("enabled", False)]
return sorted(enabled, key=lambda x: x.get("priority", 999))
def get_mirror_by_id(self, mirror_id: str) -> Optional[Dict[str, Any]]:
"""根据 ID 获取镜像源"""
for mirror in self.mirrors:
if mirror.get("id") == mirror_id:
return mirror.copy()
return None
def add_mirror(
self,
mirror_id: str,
@@ -182,26 +185,26 @@ class GitMirrorConfig:
raw_prefix: str,
clone_prefix: str,
enabled: bool = True,
priority: Optional[int] = None
priority: Optional[int] = None,
) -> Dict[str, Any]:
"""
添加新的镜像源
Returns:
添加的镜像源配置
Raises:
ValueError: 如果镜像源 ID 已存在
"""
# 检查 ID 是否已存在
if self.get_mirror_by_id(mirror_id):
raise ValueError(f"镜像源 ID 已存在: {mirror_id}")
# 如果未指定优先级,使用最大优先级 + 1
if priority is None:
max_priority = max((m.get("priority", 0) for m in self.mirrors), default=0)
priority = max_priority + 1
new_mirror = {
"id": mirror_id,
"name": name,
@@ -209,15 +212,15 @@ class GitMirrorConfig:
"clone_prefix": clone_prefix,
"enabled": enabled,
"priority": priority,
"created_at": datetime.now().isoformat()
"created_at": datetime.now().isoformat(),
}
self.mirrors.append(new_mirror)
self._save_config()
logger.info(f"已添加镜像源: {mirror_id} - {name}")
return new_mirror.copy()
def update_mirror(
self,
mirror_id: str,
@@ -225,11 +228,11 @@ class GitMirrorConfig:
raw_prefix: Optional[str] = None,
clone_prefix: Optional[str] = None,
enabled: Optional[bool] = None,
priority: Optional[int] = None
priority: Optional[int] = None,
) -> Optional[Dict[str, Any]]:
"""
更新镜像源配置
Returns:
更新后的镜像源配置,如果不存在则返回 None
"""
@@ -245,19 +248,19 @@ class GitMirrorConfig:
mirror["enabled"] = enabled
if priority is not None:
mirror["priority"] = priority
mirror["updated_at"] = datetime.now().isoformat()
self._save_config()
logger.info(f"已更新镜像源: {mirror_id}")
return mirror.copy()
return None
def delete_mirror(self, mirror_id: str) -> bool:
"""
删除镜像源
Returns:
True 如果删除成功False 如果镜像源不存在
"""
@@ -267,9 +270,9 @@ class GitMirrorConfig:
self._save_config()
logger.info(f"已删除镜像源: {mirror_id}")
return True
return False
def get_default_priority_list(self) -> List[str]:
"""获取默认优先级列表(仅启用的镜像源 ID"""
enabled = self.get_enabled_mirrors()
@@ -278,16 +281,11 @@ class GitMirrorConfig:
class GitMirrorService:
"""Git 镜像源服务"""
def __init__(
self,
max_retries: int = 3,
timeout: int = 30,
config: Optional[GitMirrorConfig] = None
):
def __init__(self, max_retries: int = 3, timeout: int = 30, config: Optional[GitMirrorConfig] = None):
"""
初始化 Git 镜像源服务
Args:
max_retries: 最大重试次数
timeout: 请求超时时间(秒)
@@ -297,16 +295,16 @@ class GitMirrorService:
self.timeout = timeout
self.config = config or GitMirrorConfig()
logger.info(f"Git镜像源服务初始化完成已加载 {len(self.config.get_enabled_mirrors())} 个启用的镜像源")
def get_mirror_config(self) -> GitMirrorConfig:
"""获取镜像源配置管理器"""
return self.config
@staticmethod
def check_git_installed() -> Dict[str, Any]:
"""
检查本机是否安装了 Git
Returns:
Dict 包含:
- installed: bool - 是否已安装 Git
@@ -316,54 +314,33 @@ class GitMirrorService:
"""
import subprocess
import shutil
try:
# 查找 git 可执行文件路径
git_path = shutil.which("git")
if not git_path:
logger.warning("未找到 Git 可执行文件")
return {
"installed": False,
"error": "系统中未找到 Git请先安装 Git"
}
return {"installed": False, "error": "系统中未找到 Git请先安装 Git"}
# 获取 Git 版本
result = subprocess.run(
["git", "--version"],
capture_output=True,
text=True,
timeout=5
)
result = subprocess.run(["git", "--version"], capture_output=True, text=True, timeout=5)
if result.returncode == 0:
version = result.stdout.strip()
logger.info(f"检测到 Git: {version} at {git_path}")
return {
"installed": True,
"version": version,
"path": git_path
}
return {"installed": True, "version": version, "path": git_path}
else:
logger.warning(f"Git 命令执行失败: {result.stderr}")
return {
"installed": False,
"error": f"Git 命令执行失败: {result.stderr}"
}
return {"installed": False, "error": f"Git 命令执行失败: {result.stderr}"}
except subprocess.TimeoutExpired:
logger.error("Git 版本检测超时")
return {
"installed": False,
"error": "Git 版本检测超时"
}
return {"installed": False, "error": "Git 版本检测超时"}
except Exception as e:
logger.error(f"检测 Git 时发生错误: {e}")
return {
"installed": False,
"error": f"检测 Git 时发生错误: {str(e)}"
}
return {"installed": False, "error": f"检测 Git 时发生错误: {str(e)}"}
async def fetch_raw_file(
self,
owner: str,
@@ -371,11 +348,11 @@ class GitMirrorService:
branch: str,
file_path: str,
mirror_id: Optional[str] = None,
custom_url: Optional[str] = None
custom_url: Optional[str] = None,
) -> Dict[str, Any]:
"""
获取 GitHub 仓库的 Raw 文件内容
Args:
owner: 仓库所有者
repo: 仓库名称
@@ -383,7 +360,7 @@ class GitMirrorService:
file_path: 文件路径
mirror_id: 指定的镜像源 ID
custom_url: 自定义完整 URL如果提供将忽略其他参数
Returns:
Dict 包含:
- success: bool - 是否成功
@@ -393,29 +370,24 @@ class GitMirrorService:
- attempts: int - 尝试次数
"""
logger.info(f"开始获取 Raw 文件: {owner}/{repo}/{branch}/{file_path}")
if custom_url:
# 使用自定义 URL
return await self._fetch_with_url(custom_url, "custom")
# 确定要使用的镜像源列表
if mirror_id:
# 使用指定的镜像源
mirror = self.config.get_mirror_by_id(mirror_id)
if not mirror:
return {
"success": False,
"error": f"未找到镜像源: {mirror_id}",
"mirror_used": None,
"attempts": 0
}
return {"success": False, "error": f"未找到镜像源: {mirror_id}", "mirror_used": None, "attempts": 0}
mirrors_to_try = [mirror]
else:
# 使用所有启用的镜像源
mirrors_to_try = self.config.get_enabled_mirrors()
total_mirrors = len(mirrors_to_try)
# 依次尝试每个镜像源
for index, mirror in enumerate(mirrors_to_try, 1):
# 推送进度:正在尝试第 N 个镜像源
@@ -427,15 +399,13 @@ class GitMirrorService:
progress=progress,
message=f"正在尝试镜像源 {index}/{total_mirrors}: {mirror['name']}",
total_plugins=0,
loaded_plugins=0
loaded_plugins=0,
)
except Exception as e:
logger.warning(f"推送进度失败: {e}")
result = await self._fetch_raw_from_mirror(
owner, repo, branch, file_path, mirror
)
result = await self._fetch_raw_from_mirror(owner, repo, branch, file_path, mirror)
if result["success"]:
# 成功,推送进度
if _update_progress:
@@ -445,15 +415,15 @@ class GitMirrorService:
progress=70,
message=f"成功从 {mirror['name']} 获取数据",
total_plugins=0,
loaded_plugins=0
loaded_plugins=0,
)
except Exception as e:
logger.warning(f"推送进度失败: {e}")
return result
# 失败,记录日志并推送失败信息
logger.warning(f"镜像源 {mirror['id']} 失败: {result.get('error')}")
if _update_progress and index < total_mirrors:
try:
await _update_progress(
@@ -461,39 +431,29 @@ class GitMirrorService:
progress=30 + int(index / total_mirrors * 40),
message=f"镜像源 {mirror['name']} 失败,尝试下一个...",
total_plugins=0,
loaded_plugins=0
loaded_plugins=0,
)
except Exception as e:
logger.warning(f"推送进度失败: {e}")
# 所有镜像源都失败
return {
"success": False,
"error": "所有镜像源均失败",
"mirror_used": None,
"attempts": len(mirrors_to_try)
}
return {"success": False, "error": "所有镜像源均失败", "mirror_used": None, "attempts": len(mirrors_to_try)}
async def _fetch_raw_from_mirror(
self,
owner: str,
repo: str,
branch: str,
file_path: str,
mirror: Dict[str, Any]
self, owner: str, repo: str, branch: str, file_path: str, mirror: Dict[str, Any]
) -> Dict[str, Any]:
"""从指定镜像源获取文件"""
# 构建 URL
raw_prefix = mirror["raw_prefix"]
url = f"{raw_prefix}/{owner}/{repo}/{branch}/{file_path}"
return await self._fetch_with_url(url, mirror["id"])
async def _fetch_with_url(self, url: str, mirror_type: str) -> Dict[str, Any]:
"""使用指定 URL 获取文件,支持重试"""
attempts = 0
last_error = None
for attempt in range(self.max_retries):
attempts += 1
try:
@@ -501,14 +461,14 @@ class GitMirrorService:
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.get(url)
response.raise_for_status()
logger.info(f"成功获取文件: {url}")
return {
"success": True,
"data": response.text,
"mirror_used": mirror_type,
"attempts": attempts,
"url": url
"url": url,
}
except httpx.HTTPStatusError as e:
last_error = f"HTTP {e.response.status_code}: {e}"
@@ -519,15 +479,9 @@ class GitMirrorService:
except Exception as e:
last_error = f"未知错误: {e}"
logger.error(f"错误 (尝试 {attempt + 1}/{self.max_retries}): {last_error}")
return {
"success": False,
"error": last_error,
"mirror_used": mirror_type,
"attempts": attempts,
"url": url
}
return {"success": False, "error": last_error, "mirror_used": mirror_type, "attempts": attempts, "url": url}
async def clone_repository(
self,
owner: str,
@@ -536,11 +490,11 @@ class GitMirrorService:
branch: Optional[str] = None,
mirror_id: Optional[str] = None,
custom_url: Optional[str] = None,
depth: Optional[int] = None
depth: Optional[int] = None,
) -> Dict[str, Any]:
"""
克隆 GitHub 仓库
Args:
owner: 仓库所有者
repo: 仓库名称
@@ -549,7 +503,7 @@ class GitMirrorService:
mirror_id: 指定的镜像源 ID
custom_url: 自定义克隆 URL
depth: 克隆深度(浅克隆)
Returns:
Dict 包含:
- success: bool - 是否成功
@@ -559,44 +513,32 @@ class GitMirrorService:
- attempts: int - 尝试次数
"""
logger.info(f"开始克隆仓库: {owner}/{repo}{target_path}")
if custom_url:
# 使用自定义 URL
return await self._clone_with_url(custom_url, target_path, branch, depth, "custom")
# 确定要使用的镜像源列表
if mirror_id:
# 使用指定的镜像源
mirror = self.config.get_mirror_by_id(mirror_id)
if not mirror:
return {
"success": False,
"error": f"未找到镜像源: {mirror_id}",
"mirror_used": None,
"attempts": 0
}
return {"success": False, "error": f"未找到镜像源: {mirror_id}", "mirror_used": None, "attempts": 0}
mirrors_to_try = [mirror]
else:
# 使用所有启用的镜像源
mirrors_to_try = self.config.get_enabled_mirrors()
# 依次尝试每个镜像源
for mirror in mirrors_to_try:
result = await self._clone_from_mirror(
owner, repo, target_path, branch, depth, mirror
)
result = await self._clone_from_mirror(owner, repo, target_path, branch, depth, mirror)
if result["success"]:
return result
logger.warning(f"镜像源 {mirror['id']} 克隆失败: {result.get('error')}")
# 所有镜像源都失败
return {
"success": False,
"error": "所有镜像源克隆均失败",
"mirror_used": None,
"attempts": len(mirrors_to_try)
}
return {"success": False, "error": "所有镜像源克隆均失败", "mirror_used": None, "attempts": len(mirrors_to_try)}
async def _clone_from_mirror(
self,
owner: str,
@@ -604,52 +546,47 @@ class GitMirrorService:
target_path: Path,
branch: Optional[str],
depth: Optional[int],
mirror: Dict[str, Any]
mirror: Dict[str, Any],
) -> Dict[str, Any]:
"""从指定镜像源克隆仓库"""
# 构建克隆 URL
clone_prefix = mirror["clone_prefix"]
url = f"{clone_prefix}/{owner}/{repo}.git"
return await self._clone_with_url(url, target_path, branch, depth, mirror["id"])
async def _clone_with_url(
self,
url: str,
target_path: Path,
branch: Optional[str],
depth: Optional[int],
mirror_type: str
self, url: str, target_path: Path, branch: Optional[str], depth: Optional[int], mirror_type: str
) -> Dict[str, Any]:
"""使用指定 URL 克隆仓库,支持重试"""
attempts = 0
last_error = None
for attempt in range(self.max_retries):
attempts += 1
try:
# 确保目标路径不存在
if target_path.exists():
logger.warning(f"目标路径已存在,删除: {target_path}")
shutil.rmtree(target_path, ignore_errors=True)
# 构建 git clone 命令
cmd = ["git", "clone"]
# 添加分支参数
if branch:
cmd.extend(["-b", branch])
# 添加深度参数(浅克隆)
if depth:
cmd.extend(["--depth", str(depth)])
# 添加 URL 和目标路径
cmd.extend([url, str(target_path)])
logger.info(f"尝试克隆 #{attempt + 1}: {' '.join(cmd)}")
# 推送进度
if _update_progress:
try:
@@ -657,24 +594,24 @@ class GitMirrorService:
stage="loading",
progress=20 + attempt * 10,
message=f"正在克隆仓库 (尝试 {attempt + 1}/{self.max_retries})...",
operation="install"
operation="install",
)
except Exception as e:
logger.warning(f"推送进度失败: {e}")
# 执行 git clone在线程池中运行以避免阻塞
loop = asyncio.get_event_loop()
def run_git_clone():
return subprocess.run(
cmd,
capture_output=True,
text=True,
timeout=300 # 5分钟超时
timeout=300, # 5分钟超时
)
process = await loop.run_in_executor(None, run_git_clone)
if process.returncode == 0:
logger.info(f"成功克隆仓库: {url} -> {target_path}")
return {
@@ -683,40 +620,34 @@ class GitMirrorService:
"mirror_used": mirror_type,
"attempts": attempts,
"url": url,
"branch": branch or "default"
"branch": branch or "default",
}
else:
last_error = f"Git 克隆失败: {process.stderr}"
logger.warning(f"克隆失败 (尝试 {attempt + 1}/{self.max_retries}): {last_error}")
except subprocess.TimeoutExpired:
last_error = "克隆超时(超过 5 分钟)"
logger.warning(f"克隆超时 (尝试 {attempt + 1}/{self.max_retries})")
# 清理可能的部分克隆
if target_path.exists():
shutil.rmtree(target_path, ignore_errors=True)
except FileNotFoundError:
last_error = "Git 未安装或不在 PATH 中"
logger.error(f"Git 未找到: {last_error}")
break # Git 不存在,不需要重试
except Exception as e:
last_error = f"未知错误: {e}"
logger.error(f"克隆错误 (尝试 {attempt + 1}/{self.max_retries}): {last_error}")
# 清理可能的部分克隆
if target_path.exists():
shutil.rmtree(target_path, ignore_errors=True)
return {
"success": False,
"error": last_error,
"mirror_used": mirror_type,
"attempts": attempts,
"url": url
}
return {"success": False, "error": last_error, "mirror_used": mirror_type, "attempts": attempts, "url": url}
# 全局服务实例

View File

@@ -1,4 +1,5 @@
"""WebSocket 日志推送模块"""
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from typing import Set
import json
@@ -14,30 +15,30 @@ active_connections: Set[WebSocket] = set()
def load_recent_logs(limit: int = 100) -> list[dict]:
"""从日志文件中加载最近的日志
Args:
limit: 返回的最大日志条数
Returns:
日志列表
"""
logs = []
log_dir = Path("logs")
if not log_dir.exists():
return logs
# 获取所有日志文件,按修改时间排序
log_files = sorted(log_dir.glob("app_*.log.jsonl"), key=lambda f: f.stat().st_mtime, reverse=True)
# 用于生成唯一 ID 的计数器
log_counter = 0
# 从最新的文件开始读取
for log_file in log_files:
if len(logs) >= limit:
break
try:
with open(log_file, "r", encoding="utf-8") as f:
lines = f.readlines()
@@ -49,7 +50,9 @@ def load_recent_logs(limit: int = 100) -> list[dict]:
log_entry = json.loads(line.strip())
# 转换为前端期望的格式
# 使用时间戳 + 计数器生成唯一 ID
timestamp_id = log_entry.get("timestamp", "0").replace("-", "").replace(" ", "").replace(":", "")
timestamp_id = (
log_entry.get("timestamp", "0").replace("-", "").replace(" ", "").replace(":", "")
)
formatted_log = {
"id": f"{timestamp_id}_{log_counter}",
"timestamp": log_entry.get("timestamp", ""),
@@ -64,7 +67,7 @@ def load_recent_logs(limit: int = 100) -> list[dict]:
except Exception as e:
logger.error(f"读取日志文件失败 {log_file}: {e}")
continue
# 反转列表,使其按时间顺序排列(旧到新)
return list(reversed(logs))
@@ -72,35 +75,35 @@ def load_recent_logs(limit: int = 100) -> list[dict]:
@router.websocket("/ws/logs")
async def websocket_logs(websocket: WebSocket):
"""WebSocket 日志推送端点
客户端连接后会持续接收服务器端的日志消息
"""
await websocket.accept()
active_connections.add(websocket)
logger.info(f"📡 WebSocket 客户端已连接,当前连接数: {len(active_connections)}")
# 连接建立后,立即发送历史日志
try:
recent_logs = load_recent_logs(limit=100)
logger.info(f"发送 {len(recent_logs)} 条历史日志到客户端")
for log_entry in recent_logs:
await websocket.send_text(json.dumps(log_entry, ensure_ascii=False))
except Exception as e:
logger.error(f"发送历史日志失败: {e}")
try:
# 保持连接,等待客户端消息或断开
while True:
# 接收客户端消息(用于心跳或控制指令)
data = await websocket.receive_text()
# 可以处理客户端的控制消息,例如:
# - "ping" -> 心跳检测
# - {"filter": "ERROR"} -> 设置日志级别过滤
if data == "ping":
await websocket.send_text("pong")
except WebSocketDisconnect:
active_connections.discard(websocket)
logger.info(f"📡 WebSocket 客户端已断开,当前连接数: {len(active_connections)}")
@@ -111,19 +114,19 @@ async def websocket_logs(websocket: WebSocket):
async def broadcast_log(log_data: dict):
"""广播日志到所有连接的 WebSocket 客户端
Args:
log_data: 日志数据字典
"""
if not active_connections:
return
# 格式化为 JSON
message = json.dumps(log_data, ensure_ascii=False)
# 记录需要断开的连接
disconnected = set()
# 广播到所有客户端
for connection in active_connections:
try:
@@ -131,7 +134,7 @@ async def broadcast_log(log_data: dict):
except Exception:
# 发送失败,标记为断开
disconnected.add(connection)
# 清理断开的连接
if disconnected:
active_connections.difference_update(disconnected)

View File

@@ -1,108 +0,0 @@
"""WebUI 管理器 - 处理开发/生产环境的 WebUI 启动"""
import os
from pathlib import Path
from src.common.logger import get_logger
from .token_manager import get_token_manager
logger = get_logger("webui")
def setup_webui(mode: str = "production") -> bool:
"""
设置 WebUI
Args:
mode: 运行模式,"development""production"
Returns:
bool: 是否成功设置
"""
# 初始化 Token 管理器(确保 token 文件存在)
token_manager = get_token_manager()
current_token = token_manager.get_token()
logger.info(f"🔑 WebUI Access Token: {current_token}")
logger.info("💡 请使用此 Token 登录 WebUI")
if mode == "development":
return setup_dev_mode()
else:
return setup_production_mode()
def setup_dev_mode() -> bool:
"""设置开发模式 - 仅启用 CORS前端自行启动"""
from src.common.server import get_global_server
from .logs_ws import router as logs_router
# 注册 WebSocket 日志路由(开发模式也需要)
server = get_global_server()
server.register_router(logs_router)
logger.info("✅ WebSocket 日志推送路由已注册")
logger.info("📝 WebUI 开发模式已启用")
logger.info("🌐 请手动启动前端开发服务器: cd webui && npm run dev")
logger.info("💡 前端将运行在 http://localhost:7999")
return True
def setup_production_mode() -> bool:
"""设置生产模式 - 挂载静态文件"""
try:
from src.common.server import get_global_server
from starlette.responses import FileResponse
from .logs_ws import router as logs_router
import mimetypes
# 确保正确的 MIME 类型映射
mimetypes.init()
mimetypes.add_type('application/javascript', '.js')
mimetypes.add_type('application/javascript', '.mjs')
mimetypes.add_type('text/css', '.css')
mimetypes.add_type('application/json', '.json')
server = get_global_server()
# 注册 WebSocket 日志路由
server.register_router(logs_router)
logger.info("✅ WebSocket 日志推送路由已注册")
base_dir = Path(__file__).parent.parent.parent
static_path = base_dir / "webui" / "dist"
if not static_path.exists():
logger.warning(f"❌ WebUI 静态文件目录不存在: {static_path}")
logger.warning("💡 请先构建前端: cd webui && npm run build")
return False
if not (static_path / "index.html").exists():
logger.warning(f"❌ 未找到 index.html: {static_path / 'index.html'}")
logger.warning("💡 请确认前端已正确构建")
return False
# 处理 SPA 路由
@server.app.get("/{full_path:path}")
async def serve_spa(full_path: str):
"""服务单页应用"""
# API 路由不处理
if full_path.startswith("api/"):
return None
# 检查文件是否存在
file_path = static_path / full_path
if file_path.is_file():
# 自动检测 MIME 类型
media_type = mimetypes.guess_type(str(file_path))[0]
return FileResponse(file_path, media_type=media_type)
# 返回 index.htmlSPA 路由)
return FileResponse(static_path / "index.html", media_type="text/html")
host = os.getenv("HOST", "127.0.0.1")
port = os.getenv("PORT", "8000")
logger.info("✅ WebUI 生产模式已挂载")
logger.info(f"🌐 访问 http://{host}:{port} 查看 WebUI")
return True
except Exception as e:
logger.error(f"挂载 WebUI 静态文件失败: {e}")
return False

View File

@@ -1,4 +1,5 @@
"""人物信息管理 API 路由"""
from fastapi import APIRouter, HTTPException, Header, Query
from pydantic import BaseModel
from typing import Optional, List, Dict
@@ -16,6 +17,7 @@ router = APIRouter(prefix="/person", tags=["Person"])
class PersonInfoResponse(BaseModel):
"""人物信息响应"""
id: int
is_known: bool
person_id: str
@@ -33,6 +35,7 @@ class PersonInfoResponse(BaseModel):
class PersonListResponse(BaseModel):
"""人物列表响应"""
success: bool
total: int
page: int
@@ -42,12 +45,14 @@ class PersonListResponse(BaseModel):
class PersonDetailResponse(BaseModel):
"""人物详情响应"""
success: bool
data: PersonInfoResponse
class PersonUpdateRequest(BaseModel):
"""人物信息更新请求"""
person_name: Optional[str] = None
name_reason: Optional[str] = None
nickname: Optional[str] = None
@@ -57,6 +62,7 @@ class PersonUpdateRequest(BaseModel):
class PersonUpdateResponse(BaseModel):
"""人物信息更新响应"""
success: bool
message: str
data: Optional[PersonInfoResponse] = None
@@ -64,21 +70,38 @@ class PersonUpdateResponse(BaseModel):
class PersonDeleteResponse(BaseModel):
"""人物删除响应"""
success: bool
message: str
class BatchDeleteRequest(BaseModel):
"""批量删除请求"""
person_ids: List[str]
class BatchDeleteResponse(BaseModel):
"""批量删除响应"""
success: bool
message: str
deleted_count: int
failed_count: int
failed_ids: List[str] = []
def verify_auth_token(authorization: Optional[str]) -> bool:
"""验证认证 Token"""
if not authorization or not authorization.startswith("Bearer "):
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
token = authorization.replace("Bearer ", "")
token_manager = get_token_manager()
if not token_manager.verify_token(token):
raise HTTPException(status_code=401, detail="Token 无效或已过期")
return True
@@ -118,11 +141,11 @@ async def get_person_list(
search: Optional[str] = Query(None, description="搜索关键词"),
is_known: Optional[bool] = Query(None, description="是否已认识筛选"),
platform: Optional[str] = Query(None, description="平台筛选"),
authorization: Optional[str] = Header(None)
authorization: Optional[str] = Header(None),
):
"""
获取人物信息列表
Args:
page: 页码 (从 1 开始)
page_size: 每页数量 (1-100)
@@ -130,58 +153,50 @@ async def get_person_list(
is_known: 是否已认识筛选
platform: 平台筛选
authorization: Authorization header
Returns:
人物信息列表
"""
try:
verify_auth_token(authorization)
# 构建查询
query = PersonInfo.select()
# 搜索过滤
if search:
query = query.where(
(PersonInfo.person_name.contains(search)) |
(PersonInfo.nickname.contains(search)) |
(PersonInfo.user_id.contains(search))
(PersonInfo.person_name.contains(search))
| (PersonInfo.nickname.contains(search))
| (PersonInfo.user_id.contains(search))
)
# 已认识状态过滤
if is_known is not None:
query = query.where(PersonInfo.is_known == is_known)
# 平台过滤
if platform:
query = query.where(PersonInfo.platform == platform)
# 排序最后更新时间倒序NULL 值放在最后)
# Peewee 不支持 nulls_last使用 CASE WHEN 来实现
from peewee import Case
query = query.order_by(
Case(None, [(PersonInfo.last_know.is_null(), 1)], 0),
PersonInfo.last_know.desc()
)
query = query.order_by(Case(None, [(PersonInfo.last_know.is_null(), 1)], 0), PersonInfo.last_know.desc())
# 获取总数
total = query.count()
# 分页
offset = (page - 1) * page_size
persons = query.offset(offset).limit(page_size)
# 转换为响应对象
data = [person_to_response(person) for person in persons]
return PersonListResponse(
success=True,
total=total,
page=page,
page_size=page_size,
data=data
)
return PersonListResponse(success=True, total=total, page=page, page_size=page_size, data=data)
except HTTPException:
raise
except Exception as e:
@@ -190,33 +205,27 @@ async def get_person_list(
@router.get("/{person_id}", response_model=PersonDetailResponse)
async def get_person_detail(
person_id: str,
authorization: Optional[str] = Header(None)
):
async def get_person_detail(person_id: str, authorization: Optional[str] = Header(None)):
"""
获取人物详细信息
Args:
person_id: 人物唯一 ID
authorization: Authorization header
Returns:
人物详细信息
"""
try:
verify_auth_token(authorization)
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
if not person:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {person_id} 的人物信息")
return PersonDetailResponse(
success=True,
data=person_to_response(person)
)
return PersonDetailResponse(success=True, data=person_to_response(person))
except HTTPException:
raise
except Exception as e:
@@ -225,53 +234,47 @@ async def get_person_detail(
@router.patch("/{person_id}", response_model=PersonUpdateResponse)
async def update_person(
person_id: str,
request: PersonUpdateRequest,
authorization: Optional[str] = Header(None)
):
async def update_person(person_id: str, request: PersonUpdateRequest, authorization: Optional[str] = Header(None)):
"""
增量更新人物信息(只更新提供的字段)
Args:
person_id: 人物唯一 ID
request: 更新请求(只包含需要更新的字段)
authorization: Authorization header
Returns:
更新结果
"""
try:
verify_auth_token(authorization)
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
if not person:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {person_id} 的人物信息")
# 只更新提供的字段
update_data = request.model_dump(exclude_unset=True)
if not update_data:
raise HTTPException(status_code=400, detail="未提供任何需要更新的字段")
# 更新最后修改时间
update_data['last_know'] = time.time()
update_data["last_know"] = time.time()
# 执行更新
for field, value in update_data.items():
setattr(person, field, value)
person.save()
logger.info(f"人物信息已更新: {person_id}, 字段: {list(update_data.keys())}")
return PersonUpdateResponse(
success=True,
message=f"成功更新 {len(update_data)} 个字段",
data=person_to_response(person)
success=True, message=f"成功更新 {len(update_data)} 个字段", data=person_to_response(person)
)
except HTTPException:
raise
except Exception as e:
@@ -280,41 +283,35 @@ async def update_person(
@router.delete("/{person_id}", response_model=PersonDeleteResponse)
async def delete_person(
person_id: str,
authorization: Optional[str] = Header(None)
):
async def delete_person(person_id: str, authorization: Optional[str] = Header(None)):
"""
删除人物信息
Args:
person_id: 人物唯一 ID
authorization: Authorization header
Returns:
删除结果
"""
try:
verify_auth_token(authorization)
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
if not person:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {person_id} 的人物信息")
# 记录删除信息
person_name = person.person_name or person.nickname or person.user_id
# 执行删除
person.delete_instance()
logger.info(f"人物信息已删除: {person_id} ({person_name})")
return PersonDeleteResponse(
success=True,
message=f"成功删除人物信息: {person_name}"
)
return PersonDeleteResponse(success=True, message=f"成功删除人物信息: {person_name}")
except HTTPException:
raise
except Exception as e:
@@ -323,43 +320,89 @@ async def delete_person(
@router.get("/stats/summary")
async def get_person_stats(
authorization: Optional[str] = Header(None)
):
async def get_person_stats(authorization: Optional[str] = Header(None)):
"""
获取人物信息统计数据
Args:
authorization: Authorization header
Returns:
统计数据
"""
try:
verify_auth_token(authorization)
total = PersonInfo.select().count()
known = PersonInfo.select().where(PersonInfo.is_known).count()
unknown = total - known
# 按平台统计
platforms = {}
for person in PersonInfo.select(PersonInfo.platform):
platform = person.platform
platforms[platform] = platforms.get(platform, 0) + 1
return {
"success": True,
"data": {
"total": total,
"known": known,
"unknown": unknown,
"platforms": platforms
}
}
return {"success": True, "data": {"total": total, "known": known, "unknown": unknown, "platforms": platforms}}
except HTTPException:
raise
except Exception as e:
logger.exception(f"获取统计数据失败: {e}")
raise HTTPException(status_code=500, detail=f"获取统计数据失败: {str(e)}") from e
@router.post("/batch/delete", response_model=BatchDeleteResponse)
async def batch_delete_persons(request: BatchDeleteRequest, authorization: Optional[str] = Header(None)):
"""
批量删除人物信息
Args:
request: 包含person_ids列表的请求
authorization: Authorization header
Returns:
批量删除结果
"""
try:
verify_auth_token(authorization)
if not request.person_ids:
raise HTTPException(status_code=400, detail="未提供要删除的人物ID")
deleted_count = 0
failed_count = 0
failed_ids = []
for person_id in request.person_ids:
try:
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
if person:
person.delete_instance()
deleted_count += 1
logger.info(f"批量删除: {person_id}")
else:
failed_count += 1
failed_ids.append(person_id)
except Exception as e:
logger.error(f"删除 {person_id} 失败: {e}")
failed_count += 1
failed_ids.append(person_id)
message = f"成功删除 {deleted_count} 个人物"
if failed_count > 0:
message += f"{failed_count} 个失败"
return BatchDeleteResponse(
success=True,
message=message,
deleted_count=deleted_count,
failed_count=failed_count,
failed_ids=failed_ids,
)
except HTTPException:
raise
except Exception as e:
logger.exception(f"批量删除人物信息失败: {e}")
raise HTTPException(status_code=500, detail=f"批量删除失败: {str(e)}") from e

View File

@@ -1,4 +1,5 @@
"""WebSocket 插件加载进度推送模块"""
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from typing import Set, Dict, Any
import json
@@ -22,7 +23,7 @@ current_progress: Dict[str, Any] = {
"error": None,
"plugin_id": None, # 当前操作的插件 ID
"total_plugins": 0,
"loaded_plugins": 0
"loaded_plugins": 0,
}
@@ -30,20 +31,20 @@ async def broadcast_progress(progress_data: Dict[str, Any]):
"""广播进度更新到所有连接的客户端"""
global current_progress
current_progress = progress_data.copy()
if not active_connections:
return
message = json.dumps(progress_data, ensure_ascii=False)
disconnected = set()
for websocket in active_connections:
try:
await websocket.send_text(message)
except Exception as e:
logger.error(f"发送进度更新失败: {e}")
disconnected.add(websocket)
# 移除断开的连接
for websocket in disconnected:
active_connections.discard(websocket)
@@ -57,10 +58,10 @@ async def update_progress(
error: str = None,
plugin_id: str = None,
total_plugins: int = 0,
loaded_plugins: int = 0
loaded_plugins: int = 0,
):
"""更新并广播进度
Args:
stage: 阶段 (idle, loading, success, error)
progress: 进度百分比 (0-100)
@@ -80,9 +81,9 @@ async def update_progress(
"plugin_id": plugin_id,
"total_plugins": total_plugins,
"loaded_plugins": loaded_plugins,
"timestamp": asyncio.get_event_loop().time()
"timestamp": asyncio.get_event_loop().time(),
}
await broadcast_progress(progress_data)
logger.debug(f"进度更新: [{operation}] {stage} - {progress}% - {message}")
@@ -90,30 +91,30 @@ async def update_progress(
@router.websocket("/ws/plugin-progress")
async def websocket_plugin_progress(websocket: WebSocket):
"""WebSocket 插件加载进度推送端点
客户端连接后会立即收到当前进度状态
"""
await websocket.accept()
active_connections.add(websocket)
logger.info(f"📡 插件进度 WebSocket 客户端已连接,当前连接数: {len(active_connections)}")
try:
# 发送当前进度状态
await websocket.send_text(json.dumps(current_progress, ensure_ascii=False))
# 保持连接并处理客户端消息
while True:
try:
data = await websocket.receive_text()
# 处理客户端心跳
if data == "ping":
await websocket.send_text("pong")
except Exception as e:
logger.error(f"处理客户端消息时出错: {e}")
break
except WebSocketDisconnect:
active_connections.discard(websocket)
logger.info(f"📡 插件进度 WebSocket 客户端已断开,当前连接数: {len(active_connections)}")

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,97 @@
"""
系统控制路由
提供系统重启、状态查询等功能
"""
import os
import sys
import time
from datetime import datetime
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from src.config.config import MMC_VERSION
router = APIRouter(prefix="/system", tags=["system"])
# 记录启动时间
_start_time = time.time()
class RestartResponse(BaseModel):
"""重启响应"""
success: bool
message: str
class StatusResponse(BaseModel):
"""状态响应"""
running: bool
uptime: float
version: str
start_time: str
@router.post("/restart", response_model=RestartResponse)
async def restart_maibot():
"""
重启麦麦主程序
使用 os.execv 重启当前进程,配置更改将在重启后生效。
注意:此操作会使麦麦暂时离线。
"""
try:
# 记录重启操作
print(f"[{datetime.now()}] WebUI 触发重启操作")
# 使用 os.execv 重启当前进程
# 这会替换当前进程,保持相同的 PID
python = sys.executable
args = [python] + sys.argv
# 返回成功响应(实际上这个响应可能不会发送,因为进程会立即重启)
# 但我们仍然返回它以保持 API 一致性
os.execv(python, args)
return RestartResponse(success=True, message="麦麦正在重启中...")
except Exception as e:
raise HTTPException(status_code=500, detail=f"重启失败: {str(e)}") from e
@router.get("/status", response_model=StatusResponse)
async def get_maibot_status():
"""
获取麦麦运行状态
返回麦麦的运行状态、运行时长和版本信息。
"""
try:
uptime = time.time() - _start_time
# 尝试获取版本信息(需要根据实际情况调整)
version = MMC_VERSION # 可以从配置或常量中读取
return StatusResponse(
running=True, uptime=uptime, version=version, start_time=datetime.fromtimestamp(_start_time).isoformat()
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"获取状态失败: {str(e)}") from e
# 可选:添加更多系统控制功能
@router.post("/reload-config")
async def reload_config():
"""
热重载配置(不重启进程)
仅重新加载配置文件,某些配置可能需要重启才能生效。
此功能需要在主程序中实现配置热重载逻辑。
"""
# 这里需要调用主程序的配置重载函数
# 示例await app_instance.reload_config()
return {"success": True, "message": "配置重载功能待实现"}

View File

@@ -1,4 +1,5 @@
"""WebUI API 路由"""
from fastapi import APIRouter, HTTPException, Header
from pydantic import BaseModel, Field
from typing import Optional
@@ -11,6 +12,7 @@ from .expression_routes import router as expression_router
from .emoji_routes import router as emoji_router
from .plugin_routes import router as plugin_router
from .plugin_progress_ws import get_progress_router
from .routers.system import router as system_router
logger = get_logger("webui.api")
@@ -31,32 +33,39 @@ router.include_router(emoji_router)
router.include_router(plugin_router)
# 注册插件进度 WebSocket 路由
router.include_router(get_progress_router())
# 注册系统控制路由
router.include_router(system_router)
class TokenVerifyRequest(BaseModel):
"""Token 验证请求"""
token: str = Field(..., description="访问令牌")
class TokenVerifyResponse(BaseModel):
"""Token 验证响应"""
valid: bool = Field(..., description="Token 是否有效")
message: str = Field(..., description="验证结果消息")
class TokenUpdateRequest(BaseModel):
"""Token 更新请求"""
new_token: str = Field(..., description="新的访问令牌", min_length=10)
class TokenUpdateResponse(BaseModel):
"""Token 更新响应"""
success: bool = Field(..., description="是否更新成功")
message: str = Field(..., description="更新结果消息")
class TokenRegenerateResponse(BaseModel):
"""Token 重新生成响应"""
success: bool = Field(..., description="是否生成成功")
token: str = Field(..., description="新生成的令牌")
message: str = Field(..., description="生成结果消息")
@@ -64,18 +73,21 @@ class TokenRegenerateResponse(BaseModel):
class FirstSetupStatusResponse(BaseModel):
"""首次配置状态响应"""
is_first_setup: bool = Field(..., description="是否为首次配置")
message: str = Field(..., description="状态消息")
class CompleteSetupResponse(BaseModel):
"""完成配置响应"""
success: bool = Field(..., description="是否成功")
message: str = Field(..., description="结果消息")
class ResetSetupResponse(BaseModel):
"""重置配置响应"""
success: bool = Field(..., description="是否成功")
message: str = Field(..., description="结果消息")
@@ -90,44 +102,35 @@ async def health_check():
async def verify_token(request: TokenVerifyRequest):
"""
验证访问令牌
Args:
request: 包含 token 的验证请求
Returns:
验证结果
"""
try:
token_manager = get_token_manager()
is_valid = token_manager.verify_token(request.token)
if is_valid:
return TokenVerifyResponse(
valid=True,
message="Token 验证成功"
)
return TokenVerifyResponse(valid=True, message="Token 验证成功")
else:
return TokenVerifyResponse(
valid=False,
message="Token 无效或已过期"
)
return TokenVerifyResponse(valid=False, message="Token 无效或已过期")
except Exception as e:
logger.error(f"Token 验证失败: {e}")
raise HTTPException(status_code=500, detail="Token 验证失败") from e
@router.post("/auth/update", response_model=TokenUpdateResponse)
async def update_token(
request: TokenUpdateRequest,
authorization: Optional[str] = Header(None)
):
async def update_token(request: TokenUpdateRequest, authorization: Optional[str] = Header(None)):
"""
更新访问令牌(需要当前有效的 token
Args:
request: 包含新 token 的更新请求
authorization: Authorization header (Bearer token)
Returns:
更新结果
"""
@@ -135,20 +138,17 @@ async def update_token(
# 验证当前 token
if not authorization or not authorization.startswith("Bearer "):
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
current_token = authorization.replace("Bearer ", "")
token_manager = get_token_manager()
if not token_manager.verify_token(current_token):
raise HTTPException(status_code=401, detail="当前 Token 无效")
# 更新 token
success, message = token_manager.update_token(request.new_token)
return TokenUpdateResponse(
success=success,
message=message
)
return TokenUpdateResponse(success=success, message=message)
except HTTPException:
raise
except Exception as e:
@@ -160,10 +160,10 @@ async def update_token(
async def regenerate_token(authorization: Optional[str] = Header(None)):
"""
重新生成访问令牌(需要当前有效的 token
Args:
authorization: Authorization header (Bearer token)
Returns:
新生成的 token
"""
@@ -171,21 +171,17 @@ async def regenerate_token(authorization: Optional[str] = Header(None)):
# 验证当前 token
if not authorization or not authorization.startswith("Bearer "):
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
current_token = authorization.replace("Bearer ", "")
token_manager = get_token_manager()
if not token_manager.verify_token(current_token):
raise HTTPException(status_code=401, detail="当前 Token 无效")
# 重新生成 token
new_token = token_manager.regenerate_token()
return TokenRegenerateResponse(
success=True,
token=new_token,
message="Token 已重新生成"
)
return TokenRegenerateResponse(success=True, token=new_token, message="Token 已重新生成")
except HTTPException:
raise
except Exception as e:
@@ -197,10 +193,10 @@ async def regenerate_token(authorization: Optional[str] = Header(None)):
async def get_setup_status(authorization: Optional[str] = Header(None)):
"""
获取首次配置状态
Args:
authorization: Authorization header (Bearer token)
Returns:
首次配置状态
"""
@@ -208,20 +204,17 @@ async def get_setup_status(authorization: Optional[str] = Header(None)):
# 验证 token
if not authorization or not authorization.startswith("Bearer "):
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
current_token = authorization.replace("Bearer ", "")
token_manager = get_token_manager()
if not token_manager.verify_token(current_token):
raise HTTPException(status_code=401, detail="Token 无效")
# 检查是否为首次配置
is_first = token_manager.is_first_setup()
return FirstSetupStatusResponse(
is_first_setup=is_first,
message="首次配置" if is_first else "已完成配置"
)
return FirstSetupStatusResponse(is_first_setup=is_first, message="首次配置" if is_first else "已完成配置")
except HTTPException:
raise
except Exception as e:
@@ -233,10 +226,10 @@ async def get_setup_status(authorization: Optional[str] = Header(None)):
async def complete_setup(authorization: Optional[str] = Header(None)):
"""
标记首次配置完成
Args:
authorization: Authorization header (Bearer token)
Returns:
完成结果
"""
@@ -244,20 +237,17 @@ async def complete_setup(authorization: Optional[str] = Header(None)):
# 验证 token
if not authorization or not authorization.startswith("Bearer "):
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
current_token = authorization.replace("Bearer ", "")
token_manager = get_token_manager()
if not token_manager.verify_token(current_token):
raise HTTPException(status_code=401, detail="Token 无效")
# 标记配置完成
success = token_manager.mark_setup_completed()
return CompleteSetupResponse(
success=success,
message="配置已完成" if success else "标记失败"
)
return CompleteSetupResponse(success=success, message="配置已完成" if success else "标记失败")
except HTTPException:
raise
except Exception as e:
@@ -269,10 +259,10 @@ async def complete_setup(authorization: Optional[str] = Header(None)):
async def reset_setup(authorization: Optional[str] = Header(None)):
"""
重置首次配置状态,允许重新进入配置向导
Args:
authorization: Authorization header (Bearer token)
Returns:
重置结果
"""
@@ -280,20 +270,17 @@ async def reset_setup(authorization: Optional[str] = Header(None)):
# 验证 token
if not authorization or not authorization.startswith("Bearer "):
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
current_token = authorization.replace("Bearer ", "")
token_manager = get_token_manager()
if not token_manager.verify_token(current_token):
raise HTTPException(status_code=401, detail="Token 无效")
# 重置配置状态
success = token_manager.reset_setup_status()
return ResetSetupResponse(
success=success,
message="配置状态已重置" if success else "重置失败"
)
return ResetSetupResponse(success=success, message="配置状态已重置" if success else "重置失败")
except HTTPException:
raise
except Exception as e:

View File

@@ -1,9 +1,10 @@
"""统计数据 API 路由"""
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field
from typing import Dict, Any, List
from datetime import datetime, timedelta
from collections import defaultdict
from peewee import fn
from src.common.logger import get_logger
from src.common.database.database_model import LLMUsage, OnlineTime, Messages
@@ -15,6 +16,7 @@ router = APIRouter(prefix="/statistics", tags=["statistics"])
class StatisticsSummary(BaseModel):
"""统计数据摘要"""
total_requests: int = Field(0, description="总请求数")
total_cost: float = Field(0.0, description="总花费")
total_tokens: int = Field(0, description="总token数")
@@ -28,6 +30,7 @@ class StatisticsSummary(BaseModel):
class ModelStatistics(BaseModel):
"""模型统计"""
model_name: str
request_count: int
total_cost: float
@@ -37,6 +40,7 @@ class ModelStatistics(BaseModel):
class TimeSeriesData(BaseModel):
"""时间序列数据"""
timestamp: str
requests: int = 0
cost: float = 0.0
@@ -45,6 +49,7 @@ class TimeSeriesData(BaseModel):
class DashboardData(BaseModel):
"""仪表盘数据"""
summary: StatisticsSummary
model_stats: List[ModelStatistics]
hourly_data: List[TimeSeriesData]
@@ -56,39 +61,39 @@ class DashboardData(BaseModel):
async def get_dashboard_data(hours: int = 24):
"""
获取仪表盘统计数据
Args:
hours: 统计时间范围小时默认24小时
Returns:
仪表盘数据
"""
try:
now = datetime.now()
start_time = now - timedelta(hours=hours)
# 获取摘要数据
summary = await _get_summary_statistics(start_time, now)
# 获取模型统计
model_stats = await _get_model_statistics(start_time)
# 获取小时级时间序列数据
hourly_data = await _get_hourly_statistics(start_time, now)
# 获取日级时间序列数据最近7天
daily_start = now - timedelta(days=7)
daily_data = await _get_daily_statistics(daily_start, now)
# 获取最近活动
recent_activity = await _get_recent_activity(limit=10)
return DashboardData(
summary=summary,
model_stats=model_stats,
hourly_data=hourly_data,
daily_data=daily_data,
recent_activity=recent_activity
recent_activity=recent_activity,
)
except Exception as e:
logger.error(f"获取仪表盘数据失败: {e}")
@@ -96,200 +101,176 @@ async def get_dashboard_data(hours: int = 24):
async def _get_summary_statistics(start_time: datetime, end_time: datetime) -> StatisticsSummary:
"""获取摘要统计数据"""
"""获取摘要统计数据(优化:使用数据库聚合)"""
summary = StatisticsSummary()
# 查询 LLM 使用记录
llm_records = list(
LLMUsage.select()
.where(LLMUsage.timestamp >= start_time)
.where(LLMUsage.timestamp <= end_time)
)
total_time_cost = 0.0
time_cost_count = 0
for record in llm_records:
summary.total_requests += 1
summary.total_cost += record.cost or 0.0
summary.total_tokens += (record.prompt_tokens or 0) + (record.completion_tokens or 0)
if record.time_cost and record.time_cost > 0:
total_time_cost += record.time_cost
time_cost_count += 1
# 计算平均响应时间
if time_cost_count > 0:
summary.avg_response_time = total_time_cost / time_cost_count
# 查询在线时间
# 使用聚合查询替代全量加载
query = LLMUsage.select(
fn.COUNT(LLMUsage.id).alias("total_requests"),
fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("total_cost"),
fn.COALESCE(fn.SUM(LLMUsage.prompt_tokens + LLMUsage.completion_tokens), 0).alias("total_tokens"),
fn.COALESCE(fn.AVG(LLMUsage.time_cost), 0).alias("avg_response_time"),
).where((LLMUsage.timestamp >= start_time) & (LLMUsage.timestamp <= end_time))
result = query.dicts().get()
summary.total_requests = result["total_requests"]
summary.total_cost = result["total_cost"]
summary.total_tokens = result["total_tokens"]
summary.avg_response_time = result["avg_response_time"] or 0.0
# 查询在线时间 - 这个数据量通常不大,保留原逻辑
online_records = list(
OnlineTime.select()
.where(
(OnlineTime.start_timestamp >= start_time) |
(OnlineTime.end_timestamp >= start_time)
)
OnlineTime.select().where((OnlineTime.start_timestamp >= start_time) | (OnlineTime.end_timestamp >= start_time))
)
for record in online_records:
start = max(record.start_timestamp, start_time)
end = min(record.end_timestamp, end_time)
if end > start:
summary.online_time += (end - start).total_seconds()
# 查询消息数量
messages = list(
Messages.select()
.where(Messages.time >= start_time.timestamp())
.where(Messages.time <= end_time.timestamp())
# 查询消息数量 - 使用聚合优化
messages_query = Messages.select(fn.COUNT(Messages.id).alias("total")).where(
(Messages.time >= start_time.timestamp()) & (Messages.time <= end_time.timestamp())
)
summary.total_messages = len(messages)
# 简单统计:如果 reply_to 不为空,则认为是回复
summary.total_replies = len([m for m in messages if m.reply_to])
summary.total_messages = messages_query.scalar() or 0
# 统计回复数量
replies_query = Messages.select(fn.COUNT(Messages.id).alias("total")).where(
(Messages.time >= start_time.timestamp())
& (Messages.time <= end_time.timestamp())
& (Messages.reply_to.is_null(False))
)
summary.total_replies = replies_query.scalar() or 0
# 计算派生指标
if summary.online_time > 0:
online_hours = summary.online_time / 3600.0
summary.cost_per_hour = summary.total_cost / online_hours
summary.tokens_per_hour = summary.total_tokens / online_hours
return summary
async def _get_model_statistics(start_time: datetime) -> List[ModelStatistics]:
"""获取模型统计数据"""
model_data = defaultdict(lambda: {
'request_count': 0,
'total_cost': 0.0,
'total_tokens': 0,
'time_costs': []
})
records = list(
LLMUsage.select()
"""获取模型统计数据(优化:使用数据库聚合和分组)"""
# 使用GROUP BY聚合避免全量加载
query = (
LLMUsage.select(
fn.COALESCE(LLMUsage.model_assign_name, LLMUsage.model_name, "unknown").alias("model_name"),
fn.COUNT(LLMUsage.id).alias("request_count"),
fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("total_cost"),
fn.COALESCE(fn.SUM(LLMUsage.prompt_tokens + LLMUsage.completion_tokens), 0).alias("total_tokens"),
fn.COALESCE(fn.AVG(LLMUsage.time_cost), 0).alias("avg_response_time"),
)
.where(LLMUsage.timestamp >= start_time)
.group_by(fn.COALESCE(LLMUsage.model_assign_name, LLMUsage.model_name, "unknown"))
.order_by(fn.COUNT(LLMUsage.id).desc())
.limit(10) # 只取前10个
)
for record in records:
model_name = record.model_assign_name or record.model_name or "unknown"
model_data[model_name]['request_count'] += 1
model_data[model_name]['total_cost'] += record.cost or 0.0
model_data[model_name]['total_tokens'] += (record.prompt_tokens or 0) + (record.completion_tokens or 0)
if record.time_cost and record.time_cost > 0:
model_data[model_name]['time_costs'].append(record.time_cost)
# 转换为列表并排序
result = []
for model_name, data in model_data.items():
avg_time = sum(data['time_costs']) / len(data['time_costs']) if data['time_costs'] else 0.0
result.append(ModelStatistics(
model_name=model_name,
request_count=data['request_count'],
total_cost=data['total_cost'],
total_tokens=data['total_tokens'],
avg_response_time=avg_time
))
# 按请求数排序
result.sort(key=lambda x: x.request_count, reverse=True)
return result[:10] # 返回前10个
for row in query.dicts():
result.append(
ModelStatistics(
model_name=row["model_name"],
request_count=row["request_count"],
total_cost=row["total_cost"],
total_tokens=row["total_tokens"],
avg_response_time=row["avg_response_time"] or 0.0,
)
)
return result
async def _get_hourly_statistics(start_time: datetime, end_time: datetime) -> List[TimeSeriesData]:
"""获取小时级统计数据"""
# 创建小时桶
hourly_buckets = defaultdict(lambda: {'requests': 0, 'cost': 0.0, 'tokens': 0})
records = list(
LLMUsage.select()
.where(LLMUsage.timestamp >= start_time)
.where(LLMUsage.timestamp <= end_time)
"""获取小时级统计数据(优化:使用数据库聚合)"""
# SQLite的日期时间函数进行小时分组
# 使用strftime将timestamp格式化为小时级别
query = (
LLMUsage.select(
fn.strftime("%Y-%m-%dT%H:00:00", LLMUsage.timestamp).alias("hour"),
fn.COUNT(LLMUsage.id).alias("requests"),
fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("cost"),
fn.COALESCE(fn.SUM(LLMUsage.prompt_tokens + LLMUsage.completion_tokens), 0).alias("tokens"),
)
.where((LLMUsage.timestamp >= start_time) & (LLMUsage.timestamp <= end_time))
.group_by(fn.strftime("%Y-%m-%dT%H:00:00", LLMUsage.timestamp))
)
for record in records:
# 获取小时键(去掉分钟和秒)
hour_key = record.timestamp.replace(minute=0, second=0, microsecond=0)
hour_str = hour_key.isoformat()
hourly_buckets[hour_str]['requests'] += 1
hourly_buckets[hour_str]['cost'] += record.cost or 0.0
hourly_buckets[hour_str]['tokens'] += (record.prompt_tokens or 0) + (record.completion_tokens or 0)
# 转换为字典以快速查找
data_dict = {row["hour"]: row for row in query.dicts()}
# 填充所有小时(包括没有数据的)
result = []
current = start_time.replace(minute=0, second=0, microsecond=0)
while current <= end_time:
hour_str = current.isoformat()
data = hourly_buckets.get(hour_str, {'requests': 0, 'cost': 0.0, 'tokens': 0})
result.append(TimeSeriesData(
timestamp=hour_str,
requests=data['requests'],
cost=data['cost'],
tokens=data['tokens']
))
hour_str = current.strftime("%Y-%m-%dT%H:00:00")
if hour_str in data_dict:
row = data_dict[hour_str]
result.append(
TimeSeriesData(timestamp=hour_str, requests=row["requests"], cost=row["cost"], tokens=row["tokens"])
)
else:
result.append(TimeSeriesData(timestamp=hour_str, requests=0, cost=0.0, tokens=0))
current += timedelta(hours=1)
return result
async def _get_daily_statistics(start_time: datetime, end_time: datetime) -> List[TimeSeriesData]:
"""获取日级统计数据"""
daily_buckets = defaultdict(lambda: {'requests': 0, 'cost': 0.0, 'tokens': 0})
records = list(
LLMUsage.select()
.where(LLMUsage.timestamp >= start_time)
.where(LLMUsage.timestamp <= end_time)
"""获取日级统计数据(优化:使用数据库聚合)"""
# 使用strftime按日期分组
query = (
LLMUsage.select(
fn.strftime("%Y-%m-%dT00:00:00", LLMUsage.timestamp).alias("day"),
fn.COUNT(LLMUsage.id).alias("requests"),
fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("cost"),
fn.COALESCE(fn.SUM(LLMUsage.prompt_tokens + LLMUsage.completion_tokens), 0).alias("tokens"),
)
.where((LLMUsage.timestamp >= start_time) & (LLMUsage.timestamp <= end_time))
.group_by(fn.strftime("%Y-%m-%dT00:00:00", LLMUsage.timestamp))
)
for record in records:
# 获取日期键
day_key = record.timestamp.replace(hour=0, minute=0, second=0, microsecond=0)
day_str = day_key.isoformat()
daily_buckets[day_str]['requests'] += 1
daily_buckets[day_str]['cost'] += record.cost or 0.0
daily_buckets[day_str]['tokens'] += (record.prompt_tokens or 0) + (record.completion_tokens or 0)
# 转换为字典
data_dict = {row["day"]: row for row in query.dicts()}
# 填充所有天
result = []
current = start_time.replace(hour=0, minute=0, second=0, microsecond=0)
while current <= end_time:
day_str = current.isoformat()
data = daily_buckets.get(day_str, {'requests': 0, 'cost': 0.0, 'tokens': 0})
result.append(TimeSeriesData(
timestamp=day_str,
requests=data['requests'],
cost=data['cost'],
tokens=data['tokens']
))
day_str = current.strftime("%Y-%m-%dT00:00:00")
if day_str in data_dict:
row = data_dict[day_str]
result.append(
TimeSeriesData(timestamp=day_str, requests=row["requests"], cost=row["cost"], tokens=row["tokens"])
)
else:
result.append(TimeSeriesData(timestamp=day_str, requests=0, cost=0.0, tokens=0))
current += timedelta(days=1)
return result
async def _get_recent_activity(limit: int = 10) -> List[Dict[str, Any]]:
"""获取最近活动"""
records = list(
LLMUsage.select()
.order_by(LLMUsage.timestamp.desc())
.limit(limit)
)
records = list(LLMUsage.select().order_by(LLMUsage.timestamp.desc()).limit(limit))
activities = []
for record in records:
activities.append({
'timestamp': record.timestamp.isoformat(),
'model': record.model_assign_name or record.model_name,
'request_type': record.request_type,
'tokens': (record.prompt_tokens or 0) + (record.completion_tokens or 0),
'cost': record.cost or 0.0,
'time_cost': record.time_cost or 0.0,
'status': record.status
})
activities.append(
{
"timestamp": record.timestamp.isoformat(),
"model": record.model_assign_name or record.model_name,
"request_type": record.request_type,
"tokens": (record.prompt_tokens or 0) + (record.completion_tokens or 0),
"cost": record.cost or 0.0,
"time_cost": record.time_cost or 0.0,
"status": record.status,
}
)
return activities
@@ -297,7 +278,7 @@ async def _get_recent_activity(limit: int = 10) -> List[Dict[str, Any]]:
async def get_summary(hours: int = 24):
"""
获取统计摘要
Args:
hours: 统计时间范围(小时)
"""
@@ -315,7 +296,7 @@ async def get_summary(hours: int = 24):
async def get_model_stats(hours: int = 24):
"""
获取模型统计
Args:
hours: 统计时间范围(小时)
"""

View File

@@ -19,7 +19,7 @@ class TokenManager:
def __init__(self, config_path: Optional[Path] = None):
"""
初始化 Token 管理器
Args:
config_path: 配置文件路径,默认为项目根目录的 data/webui.json
"""
@@ -27,10 +27,10 @@ class TokenManager:
# 获取项目根目录 (src/webui -> src -> 根目录)
project_root = Path(__file__).parent.parent.parent
config_path = project_root / "data" / "webui.json"
self.config_path = config_path
self.config_path.parent.mkdir(parents=True, exist_ok=True)
# 确保配置文件存在并包含有效的 token
self._ensure_config()
@@ -75,22 +75,23 @@ class TokenManager:
"""生成新的 64 位随机 token"""
# 生成 64 位十六进制字符串 (32 字节 = 64 hex 字符)
token = secrets.token_hex(32)
config = {
"access_token": token,
"created_at": self._get_current_timestamp(),
"updated_at": self._get_current_timestamp(),
"first_setup_completed": False # 标记首次配置未完成
"first_setup_completed": False, # 标记首次配置未完成
}
self._save_config(config)
logger.info(f"新的 WebUI Token 已生成: {token[:8]}...")
return token
def _get_current_timestamp(self) -> str:
"""获取当前时间戳字符串"""
from datetime import datetime
return datetime.now().isoformat()
def get_token(self) -> str:
@@ -101,38 +102,38 @@ class TokenManager:
def verify_token(self, token: str) -> bool:
"""
验证 token 是否有效
Args:
token: 待验证的 token
Returns:
bool: token 是否有效
"""
if not token:
return False
current_token = self.get_token()
if not current_token:
logger.error("系统中没有有效的 token")
return False
# 使用 secrets.compare_digest 防止时序攻击
is_valid = secrets.compare_digest(token, current_token)
if is_valid:
logger.debug("Token 验证成功")
else:
logger.warning("Token 验证失败")
return is_valid
def update_token(self, new_token: str) -> tuple[bool, str]:
"""
更新 token
Args:
new_token: 新的 token (最少 10 位,必须包含大小写字母和特殊符号)
Returns:
tuple[bool, str]: (是否更新成功, 错误消息)
"""
@@ -141,17 +142,17 @@ class TokenManager:
if not is_valid:
logger.error(f"Token 格式无效: {error_msg}")
return False, error_msg
try:
config = self._load_config()
old_token = config.get("access_token", "")[:8]
config["access_token"] = new_token
config["updated_at"] = self._get_current_timestamp()
self._save_config(config)
logger.info(f"Token 已更新: {old_token}... -> {new_token[:8]}...")
return True, "Token 更新成功"
except Exception as e:
logger.error(f"更新 Token 失败: {e}")
@@ -160,7 +161,7 @@ class TokenManager:
def regenerate_token(self) -> str:
"""
重新生成 token
Returns:
str: 新生成的 token
"""
@@ -170,20 +171,20 @@ class TokenManager:
def _validate_token_format(self, token: str) -> bool:
"""
验证 token 格式是否正确(旧的 64 位十六进制验证,保留用于系统生成的 token
Args:
token: 待验证的 token
Returns:
bool: 格式是否正确
"""
if not token or not isinstance(token, str):
return False
# 必须是 64 位十六进制字符串
if len(token) != 64:
return False
# 验证是否为有效的十六进制字符串
try:
int(token, 16)
@@ -194,48 +195,48 @@ class TokenManager:
def _validate_custom_token(self, token: str) -> tuple[bool, str]:
"""
验证自定义 token 格式
要求:
- 最少 10 位
- 包含大写字母
- 包含小写字母
- 包含特殊符号
Args:
token: 待验证的 token
Returns:
tuple[bool, str]: (是否有效, 错误消息)
"""
if not token or not isinstance(token, str):
return False, "Token 不能为空"
# 检查长度
if len(token) < 10:
return False, "Token 长度至少为 10 位"
# 检查是否包含大写字母
has_upper = any(c.isupper() for c in token)
if not has_upper:
return False, "Token 必须包含大写字母"
# 检查是否包含小写字母
has_lower = any(c.islower() for c in token)
if not has_lower:
return False, "Token 必须包含小写字母"
# 检查是否包含特殊符号
special_chars = "!@#$%^&*()_+-=[]{}|;:,.<>?/"
has_special = any(c in special_chars for c in token)
if not has_special:
return False, f"Token 必须包含特殊符号 ({special_chars})"
return True, "Token 格式正确"
def is_first_setup(self) -> bool:
"""
检查是否为首次配置
Returns:
bool: 是否为首次配置
"""
@@ -245,7 +246,7 @@ class TokenManager:
def mark_setup_completed(self) -> bool:
"""
标记首次配置已完成
Returns:
bool: 是否标记成功
"""
@@ -263,7 +264,7 @@ class TokenManager:
def reset_setup_status(self) -> bool:
"""
重置首次配置状态,允许重新进入配置向导
Returns:
bool: 是否重置成功
"""

148
src/webui/webui_server.py Normal file
View File

@@ -0,0 +1,148 @@
"""独立的 WebUI 服务器 - 运行在 0.0.0.0:8001"""
import os
import asyncio
import mimetypes
from pathlib import Path
from fastapi import FastAPI
from fastapi.responses import FileResponse
from uvicorn import Config, Server as UvicornServer
from src.common.logger import get_logger
logger = get_logger("webui_server")
class WebUIServer:
"""独立的 WebUI 服务器"""
def __init__(self, host: str = "0.0.0.0", port: int = 8001):
self.host = host
self.port = port
self.app = FastAPI(title="MaiBot WebUI")
self._server = None
# 显示 Access Token
self._show_access_token()
# 重要:先注册 API 路由,再设置静态文件
self._register_api_routes()
self._setup_static_files()
def _show_access_token(self):
"""显示 WebUI Access Token"""
try:
from src.webui.token_manager import get_token_manager
token_manager = get_token_manager()
current_token = token_manager.get_token()
logger.info(f"🔑 WebUI Access Token: {current_token}")
logger.info("💡 请使用此 Token 登录 WebUI")
except Exception as e:
logger.error(f"❌ 获取 Access Token 失败: {e}")
def _setup_static_files(self):
"""设置静态文件服务"""
# 确保正确的 MIME 类型映射
mimetypes.init()
mimetypes.add_type("application/javascript", ".js")
mimetypes.add_type("application/javascript", ".mjs")
mimetypes.add_type("text/css", ".css")
mimetypes.add_type("application/json", ".json")
base_dir = Path(__file__).parent.parent.parent
static_path = base_dir / "webui" / "dist"
if not static_path.exists():
logger.warning(f"❌ WebUI 静态文件目录不存在: {static_path}")
logger.warning("💡 请先构建前端: cd webui && npm run build")
return
if not (static_path / "index.html").exists():
logger.warning(f"❌ 未找到 index.html: {static_path / 'index.html'}")
logger.warning("💡 请确认前端已正确构建")
return
# 处理 SPA 路由 - 注意:这个路由优先级最低
@self.app.get("/{full_path:path}", include_in_schema=False)
async def serve_spa(full_path: str):
"""服务单页应用 - 只处理非 API 请求"""
# 如果是根路径,直接返回 index.html
if not full_path or full_path == "/":
return FileResponse(static_path / "index.html", media_type="text/html")
# 检查是否是静态文件
file_path = static_path / full_path
if file_path.is_file() and file_path.exists():
# 自动检测 MIME 类型
media_type = mimetypes.guess_type(str(file_path))[0]
return FileResponse(file_path, media_type=media_type)
# 其他路径返回 index.htmlSPA 路由)
return FileResponse(static_path / "index.html", media_type="text/html")
logger.info(f"✅ WebUI 静态文件服务已配置: {static_path}")
def _register_api_routes(self):
"""注册所有 WebUI API 路由"""
try:
# 导入所有 WebUI 路由
from src.webui.routes import router as webui_router
from src.webui.logs_ws import router as logs_router
# 注册路由
self.app.include_router(webui_router)
self.app.include_router(logs_router)
logger.info("✅ WebUI API 路由已注册")
except Exception as e:
logger.error(f"❌ 注册 WebUI API 路由失败: {e}")
async def start(self):
"""启动服务器"""
config = Config(
app=self.app,
host=self.host,
port=self.port,
log_config=None,
access_log=False,
)
self._server = UvicornServer(config=config)
logger.info("🌐 WebUI 服务器启动中...")
logger.info(f"🌐 访问地址: http://{self.host}:{self.port}")
try:
await self._server.serve()
except Exception as e:
logger.error(f"❌ WebUI 服务器运行错误: {e}")
raise
async def shutdown(self):
"""关闭服务器"""
if self._server:
logger.info("正在关闭 WebUI 服务器...")
self._server.should_exit = True
try:
await asyncio.wait_for(self._server.shutdown(), timeout=3.0)
logger.info("✅ WebUI 服务器已关闭")
except asyncio.TimeoutError:
logger.warning("⚠️ WebUI 服务器关闭超时")
except Exception as e:
logger.error(f"❌ WebUI 服务器关闭失败: {e}")
finally:
self._server = None
# 全局 WebUI 服务器实例
_webui_server = None
def get_webui_server() -> WebUIServer:
"""获取全局 WebUI 服务器实例"""
global _webui_server
if _webui_server is None:
# 从环境变量读取配置
host = os.getenv("WEBUI_HOST", "0.0.0.0")
port = int(os.getenv("WEBUI_PORT", "8001"))
_webui_server = WebUIServer(host=host, port=port)
return _webui_server

View File

@@ -1,5 +1,5 @@
[inner]
version = "6.21.6"
version = "6.21.8"
#----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读----
#如果你想要修改配置文件请递增version的值
@@ -211,6 +211,9 @@ show_prompt = false # 是否显示prompt
show_replyer_prompt = false # 是否显示回复器prompt
show_replyer_reasoning = false # 是否显示回复器推理
show_jargon_prompt = false # 是否显示jargon相关提示词
show_memory_prompt = false # 是否显示记忆检索相关提示词
show_planner_prompt = false # 是否显示planner的prompt和原始返回结果
show_lpmm_paragraph = false # 是否显示lpmm找到的相关文段日志
[maim_message]
auth_token = [] # 认证令牌用于API验证为空则不启用验证

View File

@@ -1,6 +1,9 @@
# 麦麦主程序配置
HOST=127.0.0.1
PORT=8000
# WebUI 配置
# WebUI 独立服务器配置
WEBUI_ENABLED=true
WEBUI_MODE=production # 生产模式
WEBUI_MODE=production # 模式: development(开发) 或 production(生产)
WEBUI_HOST=0.0.0.0 # WebUI 服务器监听地址
WEBUI_PORT=8001 # WebUI 服务器端口

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Some files were not shown because too many files have changed in this diff Show More