Merge branch 'main' into helm-chart
This commit is contained in:
5
.github/workflows/docker-image.yml
vendored
5
.github/workflows/docker-image.yml
vendored
@@ -1,11 +1,12 @@
|
||||
name: Docker Build and Push
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: '0 0 * * *'
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- classical
|
||||
- dev
|
||||
tags:
|
||||
- "v*.*.*"
|
||||
- "v*"
|
||||
@@ -24,6 +25,7 @@ jobs:
|
||||
- name: Check out git repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.event_name == 'schedule' && 'dev' || github.ref }}
|
||||
fetch-depth: 0
|
||||
|
||||
# Clone required dependencies
|
||||
@@ -77,6 +79,7 @@ jobs:
|
||||
- name: Check out git repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.event_name == 'schedule' && 'dev' || github.ref }}
|
||||
fetch-depth: 0
|
||||
|
||||
# Clone required dependencies
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -11,7 +11,6 @@ run_maibot_core.bat
|
||||
run_voice.bat
|
||||
run_napcat_adapter.bat
|
||||
run_ad.bat
|
||||
s4u.s4u
|
||||
llm_tool_benchmark_results.json
|
||||
MaiBot-Napcat-Adapter-main
|
||||
MaiBot-Napcat-Adapter
|
||||
@@ -27,6 +26,7 @@ run.bat
|
||||
log_debug/
|
||||
run_amds.bat
|
||||
run_none.bat
|
||||
docs-mai/
|
||||
run.py
|
||||
message_queue_content.txt
|
||||
message_queue_content.bat
|
||||
@@ -51,6 +51,7 @@ template/compare/model_config_template.toml
|
||||
src/plugins/utils/statistic.py
|
||||
CLAUDE.md
|
||||
MaiBot-Dashboard/
|
||||
cloudflare-workers/
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
|
||||
@@ -25,6 +25,8 @@ WORKDIR /MaiMBot
|
||||
# 复制依赖列表
|
||||
COPY requirements.txt .
|
||||
|
||||
RUN apt-get update && apt-get install -y git
|
||||
|
||||
# 从编译阶段复制 LPMM 编译结果
|
||||
COPY --from=lpmm-builder /usr/local/lib/python3.13/site-packages/ /usr/local/lib/python3.13/site-packages/
|
||||
|
||||
|
||||
43
README.md
43
README.md
@@ -25,11 +25,12 @@
|
||||
|
||||
**🍔MaiCore 是一个基于大语言模型的可交互智能体**
|
||||
|
||||
- 💭 **智能对话系统**:基于 LLM 的自然语言交互,聊天时机控制。
|
||||
- 🤔 **实时思维系统**:模拟人类思考过程。
|
||||
- 🧠 **表达学习功能**:学习群友的说话风格和表达方式
|
||||
- 💝 **情感表达系统**:情绪系统和表情包系统。
|
||||
- 🔌 **强大插件系统**:提供API和事件系统,可编写强大插件。
|
||||
- 💭 **拟人构建的prompt**:使用自然语言风格构建回复器的prompt,实现近似人类言语习惯的回复。
|
||||
- 💭 **行为规划**:在合适的时间说话,使用合适的动作
|
||||
- 🧠 **表达学习**:学习群友的说话风格和表达方式,学会真实人类的说话风格
|
||||
- 🤔 **黑话学习**:自主的学习没有见过的词语,尝试理解并认知含义
|
||||
- 🔌 **插件系统**:提供API和事件系统,可编写丰富插件。
|
||||
- 💝 **情感表达**:情绪系统和表情包系统。
|
||||
|
||||
<div style="text-align: center">
|
||||
<a href="https://www.bilibili.com/video/BV1amAneGE3P" target="_blank">
|
||||
@@ -44,7 +45,9 @@
|
||||
|
||||
## 🔥 更新和安装
|
||||
|
||||
**最新版本: v0.11.3** ([更新日志](changelogs/changelog.md))
|
||||
|
||||
**最新版本: v0.11.5** ([更新日志](changelogs/changelog.md))
|
||||
|
||||
|
||||
可前往 [Release](https://github.com/MaiM-with-u/MaiBot/releases/) 页面下载最新版本
|
||||
可前往 [启动器发布页面](https://github.com/MaiM-with-u/mailauncher/releases/)下载最新启动器
|
||||
@@ -60,15 +63,10 @@
|
||||
|
||||
> [!WARNING]
|
||||
> - 项目处于活跃开发阶段,功能和 API 可能随时调整。
|
||||
> - 有问题可以提交 Issue 或者 Discussion。
|
||||
> - 有问题可以提交 Issue 。
|
||||
> - QQ 机器人存在被限制风险,请自行了解,谨慎使用。
|
||||
> - 由于程序处于开发中,可能消耗较多 token。
|
||||
|
||||
## 麦麦MC项目MaiCraft(早期开发)
|
||||
[让麦麦玩MC](https://github.com/MaiM-with-u/Maicraft)
|
||||
|
||||
交流群:1058573197
|
||||
|
||||
## 💬 讨论
|
||||
|
||||
**技术交流群:**
|
||||
@@ -80,7 +78,7 @@
|
||||
**聊天吹水群:**
|
||||
- [麦麦之闲聊群](https://qm.qq.com/q/JxvHZnxyec)
|
||||
|
||||
**插件开发测试版群:**
|
||||
**插件开发/测试版讨论群:**
|
||||
- [插件开发群](https://qm.qq.com/q/1036092828)
|
||||
|
||||
## 📚 文档
|
||||
@@ -89,7 +87,22 @@
|
||||
|
||||
- [📚 核心 Wiki 文档](https://docs.mai-mai.org) - 项目最全面的文档中心,你可以了解麦麦有关的一切。
|
||||
|
||||
### 设计理念(原始时代的火花)
|
||||
|
||||
## 📚 衍生项目
|
||||
|
||||
### MaiCraft(早期开发)
|
||||
[MaiCraft](https://github.com/MaiM-with-u/Maicraft)
|
||||
> 让麦麦具有玩MC能力的项目
|
||||
> 交流群:1058573197
|
||||
|
||||
### MoFox_Bot
|
||||
[MoFox - 仓库地址](https://github.com/MoFox-Studio/MoFox-Core)
|
||||
> MoFox_Bot 是一个基于 MaiCore 0.10.0 snapshot.5 的增强型 fork 项目
|
||||
> 我们保留了原项目几乎所有核心功能,并在此基础上进行了深度优化与功能扩展,致力于打造一个更稳定、更智能、更具趣味性的 AI 智能体。
|
||||
|
||||
|
||||
|
||||
## 设计理念(原始时代的火花)
|
||||
|
||||
> **千石可乐说:**
|
||||
> - 这个项目最初只是为了给牛牛 bot 添加一点额外的功能,但是功能越写越多,最后决定重写。其目的是为了创造一个活跃在 QQ 群聊的"生命体"。目的并不是为了写一个功能齐全的机器人,而是一个尽可能让人感知到真实的类人存在。
|
||||
@@ -101,7 +114,7 @@
|
||||
## 🙋 贡献和致谢
|
||||
你可以阅读[开发文档](https://docs.mai-mai.org/develop/)来更好的了解麦麦!
|
||||
MaiCore 是一个开源项目,我们非常欢迎你的参与。你的贡献,无论是提交 bug 报告、功能需求还是代码 pr,都对项目非常宝贵。我们非常感谢你的支持!🎉
|
||||
但无序的讨论会降低沟通效率,进而影响问题的解决速度,因此在提交任何贡献前,请务必先阅读本项目的[贡献指南](docs/CONTRIBUTE.md)。(待补完)
|
||||
但无序的讨论会降低沟通效率,进而影响问题的解决速度,因此在提交任何贡献前,请务必先阅读本项目的[贡献指南](docs-src/CONTRIBUTE.md)。(待补完)
|
||||
|
||||
### 贡献者
|
||||
|
||||
|
||||
17
bot.py
17
bot.py
@@ -1,7 +1,6 @@
|
||||
import asyncio
|
||||
import hashlib
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import platform
|
||||
import traceback
|
||||
@@ -30,7 +29,7 @@ else:
|
||||
raise
|
||||
|
||||
# 最早期初始化日志系统,确保所有后续模块都使用正确的日志格式
|
||||
from src.common.logger import initialize_logging, get_logger, shutdown_logging #noqa
|
||||
from src.common.logger import initialize_logging, get_logger, shutdown_logging # noqa
|
||||
|
||||
initialize_logging()
|
||||
|
||||
@@ -76,6 +75,15 @@ async def graceful_shutdown(): # sourcery skip: use-named-expression
|
||||
try:
|
||||
logger.info("正在优雅关闭麦麦...")
|
||||
|
||||
# 关闭 WebUI 服务器
|
||||
try:
|
||||
from src.webui.webui_server import get_webui_server
|
||||
webui_server = get_webui_server()
|
||||
if webui_server and webui_server._server:
|
||||
await webui_server.shutdown()
|
||||
except Exception as e:
|
||||
logger.warning(f"关闭 WebUI 服务器时出错: {e}")
|
||||
|
||||
from src.plugin_system.core.events_manager import events_manager
|
||||
from src.plugin_system.base.component_types import EventType
|
||||
|
||||
@@ -212,9 +220,10 @@ if __name__ == "__main__":
|
||||
# 创建事件循环
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
|
||||
# 初始化 WebSocket 日志推送
|
||||
from src.common.logger import initialize_ws_handler
|
||||
|
||||
initialize_ws_handler(loop)
|
||||
|
||||
try:
|
||||
@@ -251,7 +260,7 @@ if __name__ == "__main__":
|
||||
print(f"关闭日志系统时出错: {e}")
|
||||
|
||||
print("[主程序] 准备退出...")
|
||||
|
||||
|
||||
# 使用 os._exit() 强制退出,避免被阻塞
|
||||
# 由于已经在 graceful_shutdown() 中完成了所有清理工作,这是安全的
|
||||
os._exit(exit_code)
|
||||
|
||||
@@ -1,6 +1,88 @@
|
||||
# Changelog
|
||||
## [0.11.5] - 2025-11-21
|
||||
### 🌟 重大更新
|
||||
- WebUI 现支持手动重启麦麦,曲线救国版“热重载”
|
||||
- 新增麦麦 QQ 适配器可视化编辑 UI(独立进程,需手动上传/下载并覆盖适配器文件)
|
||||
- 麦麦主程序配置支持可视化模式与源代码模式双模式编辑,后端执行 TOML 校验
|
||||
- 优化 planner 与 replyer 协同机制,调试日志更细
|
||||
|
||||
## [0.11.3] - 2025-11-17
|
||||
### 新增
|
||||
- 表情包管理、人物信息管理、表达方式管理界面手机端适配
|
||||
- 配置页“重启麦麦”提示
|
||||
- 详细的debug prompt显示配置
|
||||
- 麦麦界面操作主题色按钮
|
||||
- 前端集成 CodeMirror(Python/JSON/TOML 语法高亮)并对 JSON 配置提供自动纠错提示
|
||||
|
||||
### 修复
|
||||
- 表情包缩略图过小
|
||||
- 添加模型后无法立即显示
|
||||
- 插件市场分类错误
|
||||
- 浅色模式下日志查看器背景异常
|
||||
- 表情包详情窗口无法关闭
|
||||
- 情绪标签无法正常读取(issues/1373)
|
||||
- 模型任务配置界面模型温度不可更改(issues/1369)
|
||||
- 模型任务配置界面部分输入框无法删除默认值
|
||||
- 插件商店默认勾选“仅显示兼容插件”
|
||||
- 插件商店标签页数量显示错误
|
||||
- 日志查看器日志换行异常
|
||||
- 主题色未正确应用
|
||||
- 侧边栏滚动条问题
|
||||
- 移除前端 TOML 解析库导致的兼容问题
|
||||
|
||||
### 优化
|
||||
- 首页加载用户体验
|
||||
- 首页加载速度(9s → <1s)
|
||||
- 整个界面的初屏加载速度
|
||||
|
||||
### 更新
|
||||
- 适配器配置支持上传模式与指定路径模式;指定路径模式可免去反复上传/下载配置文件
|
||||
- 适配器配置界面标签页响应式优化,小屏仅显示简短标签
|
||||
- 麦麦资源管理所有界面的批量删除能力
|
||||
- 资源管理所有界面的分页、每页数量选择与页跳转
|
||||
- 插件市场新增点赞、点踩、评分与下载量统计(基于 Cloudflare,保证国内可访问)
|
||||
- 麦麦表情包查看界面的描述部分支持 Markdown 渲染
|
||||
|
||||
## [0.11.4] - 2025-11-19
|
||||
### 🌟 主要更新内容
|
||||
- **首个官方 Web 管理界面上线**:在此版本之前,MaiBot 没有 WebUI,所有配置需手动编辑 TOML 文件
|
||||
- **认证系统**:Token 安全登录(支持系统生成 64 位随机令牌 / 自定义 Token),首次配置向导
|
||||
- **配置管理(可视化编辑,无需手动改 TOML)**:
|
||||
- 麦麦主程序配置:基础设置、人格、表情、黑话、情绪等
|
||||
- 模型提供商配置:OpenAI、Anthropic、DeepSeek、Qwen、Ollama 等
|
||||
- 模型配置:对话/视觉/嵌入模型分配
|
||||
- **资源管理**:
|
||||
- 表情包管理:查看、搜索、注册、封禁
|
||||
- 表达方式管理:查看麦麦的表达记录
|
||||
- 人物信息管理:查看联系人列表
|
||||
- **插件系统**:
|
||||
- 插件市场浏览
|
||||
- 一键安装/卸载/更新
|
||||
- 版本兼容性检查
|
||||
- 实时安装进度推送
|
||||
- **日志查看器**:
|
||||
- WebSocket 实时日志流
|
||||
- 日志级别过滤(DEBUG/INFO/WARNING/ERROR/CRITICAL)
|
||||
- 搜索功能
|
||||
- **主题定制**:
|
||||
- 浅色/深色/跟随系统
|
||||
- 12 种主题色(6 单色 + 6 渐变色)
|
||||
- 自定义颜色选择器
|
||||
- **全局搜索**:Cmd/Ctrl + K 快捷键,快速跳转任意页面
|
||||
|
||||
### 细节
|
||||
- **技术栈**:
|
||||
- 前端: React 19 + TypeScript + Vite + TanStack Router + shadcn/ui
|
||||
- 后端: FastAPI + Uvicorn + WebSocket
|
||||
- 特点: SPA 单页应用,前后端同端口,静态文件托管
|
||||
- **使用方式**:参照 template.env 文件更新 .env 文件,添加两个字段:
|
||||
- `WEBUI_ENABLED=true`
|
||||
- `WEBUI_MODE=production`
|
||||
- **WebUI 开源协议**:GPLv3
|
||||
- **WebUI 地址**:https://github.com/Mai-with-u/MaiBot-Dashboard
|
||||
|
||||
告别手动编辑配置文件,享受现代化图形界面!
|
||||
|
||||
## [0.11.3] - 2025-11-18
|
||||
### 功能更改和修复
|
||||
- 优化记忆提取策略
|
||||
- 优化黑话提取
|
||||
|
||||
@@ -29,7 +29,8 @@ services:
|
||||
- TZ=Asia/Shanghai
|
||||
# - EULA_AGREE=99f08e0cab0190de853cb6af7d64d4de # 同意EULA
|
||||
# - PRIVACY_AGREE=9943b855e72199d0f5016ea39052f1b6 # 同意EULA
|
||||
# ports:
|
||||
ports:
|
||||
- "18001:8001" # webui端口
|
||||
# - "8000:8000"
|
||||
volumes:
|
||||
- ./docker-config/mmc/.env:/MaiMBot/.env # 持久化env配置文件
|
||||
|
||||
|
Before Width: | Height: | Size: 4.1 KiB After Width: | Height: | Size: 4.1 KiB |
|
Before Width: | Height: | Size: 11 KiB After Width: | Height: | Size: 11 KiB |
BIN
docs/image-1.png
BIN
docs/image-1.png
Binary file not shown.
|
Before Width: | Height: | Size: 21 KiB |
BIN
docs/image.png
BIN
docs/image.png
Binary file not shown.
|
Before Width: | Height: | Size: 4.9 KiB |
@@ -1,336 +0,0 @@
|
||||
# 模型配置指南
|
||||
|
||||
本文档将指导您如何配置 `model_config.toml` 文件,该文件用于配置 MaiBot 的各种AI模型和API服务提供商。
|
||||
|
||||
## 配置文件结构
|
||||
|
||||
配置文件主要包含以下几个部分:
|
||||
- 版本信息
|
||||
- API服务提供商配置
|
||||
- 模型配置
|
||||
- 模型任务配置
|
||||
|
||||
## 1. 版本信息
|
||||
|
||||
```toml
|
||||
[inner]
|
||||
version = "1.1.1"
|
||||
```
|
||||
|
||||
用于标识配置文件的版本,遵循语义化版本规则。
|
||||
|
||||
## 2. API服务提供商配置
|
||||
|
||||
### 2.1 基本配置
|
||||
|
||||
使用 `[[api_providers]]` 数组配置多个API服务提供商:
|
||||
|
||||
```toml
|
||||
[[api_providers]]
|
||||
name = "DeepSeek" # 服务商名称(自定义)
|
||||
base_url = "https://api.deepseek.com/v1" # API服务的基础URL
|
||||
api_key = "your-api-key-here" # API密钥
|
||||
client_type = "openai" # 客户端类型
|
||||
max_retry = 2 # 最大重试次数
|
||||
timeout = 30 # 超时时间(秒)
|
||||
retry_interval = 10 # 重试间隔(秒)
|
||||
```
|
||||
|
||||
### 2.2 配置参数说明
|
||||
|
||||
| 参数 | 必填 | 说明 | 默认值 |
|
||||
|------|------|------|--------|
|
||||
| `name` | ✅ | 服务商名称,需要在模型配置中引用 | - |
|
||||
| `base_url` | ✅ | API服务的基础URL | - |
|
||||
| `api_key` | ✅ | API密钥,请替换为实际密钥 | - |
|
||||
| `client_type` | ❌ | 客户端类型:`openai`(OpenAI格式)或 `gemini`(Gemini格式) | `openai` |
|
||||
| `max_retry` | ❌ | API调用失败时的最大重试次数 | 2 |
|
||||
| `timeout` | ❌ | API请求超时时间(秒) | 30 |
|
||||
| `retry_interval` | ❌ | 重试间隔时间(秒) | 10 |
|
||||
|
||||
**请注意,对于`client_type`为`gemini`的模型,`retry`字段由`gemini`自己决定。**
|
||||
### 2.3 支持的服务商示例
|
||||
|
||||
#### DeepSeek
|
||||
```toml
|
||||
[[api_providers]]
|
||||
name = "DeepSeek"
|
||||
base_url = "https://api.deepseek.com/v1"
|
||||
api_key = "your-deepseek-api-key"
|
||||
client_type = "openai"
|
||||
```
|
||||
|
||||
#### SiliconFlow
|
||||
```toml
|
||||
[[api_providers]]
|
||||
name = "SiliconFlow"
|
||||
base_url = "https://api.siliconflow.cn/v1"
|
||||
api_key = "your-siliconflow-api-key"
|
||||
client_type = "openai"
|
||||
```
|
||||
|
||||
#### Google Gemini
|
||||
```toml
|
||||
[[api_providers]]
|
||||
name = "Google"
|
||||
base_url = "https://generativelanguage.googleapis.com/v1beta"
|
||||
api_key = "your-google-api-key"
|
||||
client_type = "gemini" # 注意:Gemini需要使用特殊客户端
|
||||
```
|
||||
|
||||
## 3. 模型配置
|
||||
|
||||
### 3.1 基本模型配置
|
||||
|
||||
使用 `[[models]]` 数组配置多个模型:
|
||||
|
||||
```toml
|
||||
[[models]]
|
||||
model_identifier = "deepseek-chat" # 模型在API服务商中的标识符
|
||||
name = "deepseek-v3" # 自定义模型名称
|
||||
api_provider = "DeepSeek" # 引用的API服务商名称
|
||||
price_in = 2.0 # 输入价格(元/M token)
|
||||
price_out = 8.0 # 输出价格(元/M token)
|
||||
```
|
||||
|
||||
### 3.2 高级模型配置
|
||||
|
||||
#### 强制流式输出
|
||||
对于不支持非流式输出的模型:
|
||||
```toml
|
||||
[[models]]
|
||||
model_identifier = "some-model"
|
||||
name = "custom-name"
|
||||
api_provider = "Provider"
|
||||
force_stream_mode = true # 启用强制流式输出
|
||||
```
|
||||
|
||||
#### 额外参数配置`extra_params`
|
||||
```toml
|
||||
[[models]]
|
||||
model_identifier = "Qwen/Qwen3-8B"
|
||||
name = "qwen3-8b"
|
||||
api_provider = "SiliconFlow"
|
||||
[models.extra_params]
|
||||
enable_thinking = false # 禁用思考
|
||||
```
|
||||
这里的 `extra_params` 可以包含任何API服务商支持的额外参数配置,**配置时应参考相应的API文档**。
|
||||
|
||||
比如上面就是参考SiliconFlow的文档配置配置的`Qwen3`禁用思考参数。
|
||||
|
||||

|
||||
|
||||
以豆包文档为另一个例子
|
||||
|
||||

|
||||
|
||||
得到豆包`"doubao-seed-1-6-250615"`的禁用思考配置方法为
|
||||
```toml
|
||||
[[models]]
|
||||
# 你的模型
|
||||
[models.extra_params]
|
||||
thinking = {type = "disabled"} # 禁用思考
|
||||
```
|
||||
|
||||
而对于`gemini`需要单独进行配置
|
||||
```toml
|
||||
[[models]]
|
||||
model_identifier = "gemini-2.5-flash"
|
||||
name = "gemini-2.5-flash"
|
||||
api_provider = "Google"
|
||||
[models.extra_params]
|
||||
thinking_budget = 0 # 禁用思考
|
||||
# thinking_budget = -1 由模型自己决定
|
||||
```
|
||||
|
||||
请注意,`extra_params` 的配置应该构成一个合法的TOML字典结构,具体内容取决于API服务商的要求。
|
||||
|
||||
### 3.3 配置参数说明
|
||||
|
||||
| 参数 | 必填 | 说明 |
|
||||
|------|------|------|
|
||||
| `model_identifier` | ✅ | API服务商提供的模型标识符 |
|
||||
| `name` | ✅ | 自定义模型名称,用于在任务配置中引用 |
|
||||
| `api_provider` | ✅ | 对应的API服务商名称 |
|
||||
| `price_in` | ❌ | 输入价格(元/M token),用于成本统计 |
|
||||
| `price_out` | ❌ | 输出价格(元/M token),用于成本统计 |
|
||||
| `force_stream_mode` | ❌ | 是否强制使用流式输出 |
|
||||
| `extra_params` | ❌ | 额外的模型参数配置 |
|
||||
|
||||
## 4. 模型任务配置
|
||||
|
||||
### utils - 工具模型
|
||||
用于表情包模块、取名模块、关系模块等核心功能:
|
||||
```toml
|
||||
[model_task_config.utils]
|
||||
model_list = ["siliconflow-deepseek-v3"]
|
||||
temperature = 0.2
|
||||
max_tokens = 800
|
||||
```
|
||||
|
||||
### utils_small - 小型工具模型
|
||||
用于高频率调用的场景,建议使用速度快的小模型:
|
||||
```toml
|
||||
[model_task_config.utils_small]
|
||||
model_list = ["qwen3-8b"]
|
||||
temperature = 0.7
|
||||
max_tokens = 800
|
||||
```
|
||||
|
||||
### replyer - 主要回复模型
|
||||
首要回复模型,也用于表达器和表达方式学习:
|
||||
```toml
|
||||
[model_task_config.replyer]
|
||||
model_list = ["siliconflow-deepseek-v3"]
|
||||
temperature = 0.2
|
||||
max_tokens = 800
|
||||
```
|
||||
|
||||
### planner - 决策模型
|
||||
负责决定MaiBot该做什么:
|
||||
```toml
|
||||
[model_task_config.planner]
|
||||
model_list = ["siliconflow-deepseek-v3"]
|
||||
temperature = 0.3
|
||||
max_tokens = 800
|
||||
```
|
||||
|
||||
### emotion - 情绪模型
|
||||
负责MaiBot的情绪变化:
|
||||
```toml
|
||||
[model_task_config.emotion]
|
||||
model_list = ["siliconflow-deepseek-v3"]
|
||||
temperature = 0.3
|
||||
max_tokens = 800
|
||||
```
|
||||
|
||||
### memory - 记忆模型
|
||||
```toml
|
||||
[model_task_config.memory]
|
||||
model_list = ["qwen3-30b"]
|
||||
temperature = 0.7
|
||||
max_tokens = 800
|
||||
```
|
||||
|
||||
### vlm - 视觉语言模型
|
||||
用于图像识别:
|
||||
```toml
|
||||
[model_task_config.vlm]
|
||||
model_list = ["qwen2.5-vl-72b"]
|
||||
max_tokens = 800
|
||||
```
|
||||
|
||||
### voice - 语音识别模型
|
||||
```toml
|
||||
[model_task_config.voice]
|
||||
model_list = ["sensevoice-small"]
|
||||
```
|
||||
|
||||
### embedding - 嵌入模型
|
||||
```toml
|
||||
[model_task_config.embedding]
|
||||
model_list = ["bge-m3"]
|
||||
```
|
||||
|
||||
### tool_use - 工具调用模型
|
||||
需要使用支持工具调用的模型:
|
||||
```toml
|
||||
[model_task_config.tool_use]
|
||||
model_list = ["qwen3-14b"]
|
||||
temperature = 0.7
|
||||
max_tokens = 800
|
||||
```
|
||||
|
||||
### lpmm_entity_extract - 实体提取模型
|
||||
```toml
|
||||
[model_task_config.lpmm_entity_extract]
|
||||
model_list = ["siliconflow-deepseek-v3"]
|
||||
temperature = 0.2
|
||||
max_tokens = 800
|
||||
```
|
||||
|
||||
### lpmm_rdf_build - RDF构建模型
|
||||
```toml
|
||||
[model_task_config.lpmm_rdf_build]
|
||||
model_list = ["siliconflow-deepseek-v3"]
|
||||
temperature = 0.2
|
||||
max_tokens = 800
|
||||
```
|
||||
|
||||
### lpmm_qa - 问答模型
|
||||
```toml
|
||||
[model_task_config.lpmm_qa]
|
||||
model_list = ["deepseek-r1-distill-qwen-32b"]
|
||||
temperature = 0.7
|
||||
max_tokens = 800
|
||||
```
|
||||
|
||||
## 5. 配置建议
|
||||
|
||||
### 5.1 Temperature 参数选择
|
||||
|
||||
| 任务类型 | 推荐温度 | 说明 |
|
||||
|----------|----------|------|
|
||||
| 精确任务(工具调用、实体提取) | 0.1-0.3 | 需要准确性和一致性 |
|
||||
| 创意任务(对话、记忆) | 0.5-0.8 | 需要多样性和创造性 |
|
||||
| 平衡任务(决策、情绪) | 0.3-0.5 | 平衡准确性和灵活性 |
|
||||
|
||||
### 5.2 模型选择建议
|
||||
|
||||
| 任务类型 | 推荐模型类型 | 示例 |
|
||||
|----------|--------------|------|
|
||||
| 高精度任务 | 大模型 | DeepSeek-V3, GPT-4 |
|
||||
| 高频率任务 | 小模型 | Qwen3-8B |
|
||||
| 多模态任务 | 专用模型 | Qwen2.5-VL, SenseVoice |
|
||||
| 工具调用 | 支持Function Call的模型 | Qwen3-14B |
|
||||
|
||||
### 5.3 成本优化
|
||||
|
||||
1. **分层使用**:核心功能使用高质量模型,辅助功能使用经济模型
|
||||
2. **合理配置max_tokens**:根据实际需求设置,避免浪费
|
||||
3. **选择免费模型**:对于测试环境,优先使用price为0的模型
|
||||
|
||||
## 6. 配置验证
|
||||
|
||||
### 6.1 必要检查项
|
||||
|
||||
1. ✅ API密钥是否正确配置
|
||||
2. ✅ 模型标识符是否与API服务商提供的一致
|
||||
3. ✅ 任务配置中引用的模型名称是否在models中定义
|
||||
4. ✅ 多模态任务是否配置了对应的专用模型
|
||||
|
||||
### 6.2 测试配置
|
||||
|
||||
建议在正式使用前:
|
||||
1. 使用少量测试数据验证配置
|
||||
2. 检查API调用是否正常
|
||||
3. 确认成本统计功能正常工作
|
||||
|
||||
## 7. 故障排除
|
||||
|
||||
### 7.1 常见问题
|
||||
|
||||
**问题1**: API调用失败
|
||||
- 检查API密钥是否正确
|
||||
- 确认base_url是否可访问
|
||||
- 检查模型标识符是否正确
|
||||
|
||||
**问题2**: 模型未找到
|
||||
- 确认模型名称在任务配置和模型定义中一致
|
||||
- 检查api_provider名称是否匹配
|
||||
|
||||
**问题3**: 响应异常
|
||||
- 检查温度参数是否合理(0-1之间)
|
||||
- 确认max_tokens设置是否合适
|
||||
- 验证模型是否支持所需功能
|
||||
|
||||
### 7.2 日志查看
|
||||
|
||||
查看 `logs/` 目录下的日志文件,寻找相关错误信息。
|
||||
|
||||
## 8. 更新和维护
|
||||
|
||||
1. **定期更新**: 关注API服务商的模型更新,及时调整配置
|
||||
2. **性能监控**: 监控模型调用的成本和性能
|
||||
3. **备份配置**: 在修改前备份当前配置文件
|
||||
|
||||
@@ -16,8 +16,6 @@ if PROJECT_ROOT not in sys.path:
|
||||
sys.path.insert(0, PROJECT_ROOT)
|
||||
|
||||
|
||||
|
||||
|
||||
SECONDS_5_MINUTES = 5 * 60
|
||||
|
||||
|
||||
|
||||
@@ -12,7 +12,6 @@ project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.insert(0, project_root)
|
||||
|
||||
|
||||
|
||||
# 设置中文字体
|
||||
plt.rcParams["font.sans-serif"] = ["SimHei", "Microsoft YaHei", "DejaVu Sans"]
|
||||
plt.rcParams["axes.unicode_minus"] = False
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -230,7 +230,7 @@ class HeartFChatting:
|
||||
if (message.is_mentioned or message.is_at) and global_config.chat.mentioned_bot_reply:
|
||||
mentioned_message = message
|
||||
|
||||
logger.info(f"{self.log_prefix} 当前talk_value: {global_config.chat.get_talk_value(self.stream_id)}")
|
||||
# logger.info(f"{self.log_prefix} 当前talk_value: {global_config.chat.get_talk_value(self.stream_id)}")
|
||||
|
||||
# *控制频率用
|
||||
if mentioned_message:
|
||||
@@ -333,7 +333,6 @@ class HeartFChatting:
|
||||
# 重置连续 no_reply 计数
|
||||
self.consecutive_no_reply_count = 0
|
||||
reason = ""
|
||||
|
||||
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.chat_stream,
|
||||
@@ -411,7 +410,7 @@ class HeartFChatting:
|
||||
# asyncio.create_task(self.chat_history_summarizer.process())
|
||||
|
||||
cycle_timers, thinking_id = self.start_cycle()
|
||||
logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考")
|
||||
logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考(频率: {global_config.chat.get_talk_value(self.stream_id)})")
|
||||
|
||||
# 第一步:动作检查
|
||||
available_actions: Dict[str, ActionInfo] = {}
|
||||
|
||||
@@ -30,9 +30,11 @@ DATA_PATH = os.path.join(ROOT_PATH, "data")
|
||||
qa_manager = None
|
||||
inspire_manager = None
|
||||
|
||||
|
||||
def get_qa_manager():
|
||||
return qa_manager
|
||||
|
||||
|
||||
def lpmm_start_up(): # sourcery skip: extract-duplicate-method
|
||||
# 检查LPMM知识库是否启用
|
||||
if global_config.lpmm_knowledge.enable:
|
||||
|
||||
@@ -92,9 +92,10 @@ class QAManager:
|
||||
# 过滤阈值
|
||||
result = dyn_select_top_k(result, 0.5, 1.0)
|
||||
|
||||
for res in result:
|
||||
raw_paragraph = self.embed_manager.paragraphs_embedding_store.store[res[0]].str
|
||||
logger.info(f"找到相关文段,相关系数:{res[1]:.8f}\n{raw_paragraph}\n\n")
|
||||
if global_config.debug.show_lpmm_paragraph:
|
||||
for res in result:
|
||||
raw_paragraph = self.embed_manager.paragraphs_embedding_store.store[res[0]].str
|
||||
logger.info(f"找到相关文段,相关系数:{res[1]:.8f}\n{raw_paragraph}\n\n")
|
||||
|
||||
return result, ppr_node_weights
|
||||
|
||||
@@ -128,11 +129,10 @@ class QAManager:
|
||||
selected_knowledge = knowledge[:limit]
|
||||
|
||||
formatted_knowledge = [
|
||||
f"第{i + 1}条知识:{k[0]}\n 该条知识对于问题的相关性:{k[1]}"
|
||||
for i, k in enumerate(selected_knowledge)
|
||||
f"第{i + 1}条知识:{k[0]}\n 该条知识对于问题的相关性:{k[1]}" for i, k in enumerate(selected_knowledge)
|
||||
]
|
||||
# if max_score is not None:
|
||||
# formatted_knowledge.insert(0, f"最高相关系数:{max_score}")
|
||||
# formatted_knowledge.insert(0, f"最高相关系数:{max_score}")
|
||||
|
||||
found_knowledge = "\n".join(formatted_knowledge)
|
||||
if len(found_knowledge) > MAX_KNOWLEDGE_LENGTH:
|
||||
|
||||
@@ -7,7 +7,6 @@ from typing import Dict, Optional, Tuple, List, TYPE_CHECKING, Union
|
||||
from rich.traceback import install
|
||||
from datetime import datetime
|
||||
from json_repair import repair_json
|
||||
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
from src.common.logger import get_logger
|
||||
@@ -164,6 +163,45 @@ class ActionPlanner:
|
||||
return item[1]
|
||||
return None
|
||||
|
||||
def _replace_message_ids_with_text(
|
||||
self, text: Optional[str], message_id_list: List[Tuple[str, "DatabaseMessages"]]
|
||||
) -> Optional[str]:
|
||||
"""将文本中的 m+数字 消息ID替换为原消息内容,并添加双引号"""
|
||||
if not text:
|
||||
return text
|
||||
|
||||
id_to_message = {msg_id: msg for msg_id, msg in message_id_list}
|
||||
|
||||
# 匹配m后带2-4位数字,前后不是字母数字下划线
|
||||
pattern = r"(?<![A-Za-z0-9_])m\d{2,4}(?![A-Za-z0-9_])"
|
||||
|
||||
matches = re.findall(pattern, text)
|
||||
if matches:
|
||||
available_ids = set(id_to_message.keys())
|
||||
found_ids = set(matches)
|
||||
missing_ids = found_ids - available_ids
|
||||
if missing_ids:
|
||||
logger.info(f"{self.log_prefix}planner理由中引用的消息ID不在当前上下文中: {missing_ids}, 可用ID: {list(available_ids)[:10]}...")
|
||||
logger.info(f"{self.log_prefix}planner理由替换: 找到{len(matches)}个消息ID引用,其中{len(found_ids & available_ids)}个在上下文中")
|
||||
|
||||
def _replace(match: re.Match[str]) -> str:
|
||||
msg_id = match.group(0)
|
||||
message = id_to_message.get(msg_id)
|
||||
if not message:
|
||||
logger.warning(f"{self.log_prefix}planner理由引用 {msg_id} 未找到对应消息,保持原样")
|
||||
return msg_id
|
||||
|
||||
msg_text = (message.processed_plain_text or message.display_message or "").strip()
|
||||
if not msg_text:
|
||||
logger.warning(f"{self.log_prefix}planner理由引用 {msg_id} 的消息内容为空,保持原样")
|
||||
return msg_id
|
||||
|
||||
preview = msg_text if len(msg_text) <= 100 else f"{msg_text[:97]}..."
|
||||
logger.info(f"{self.log_prefix}planner理由引用 {msg_id} -> 消息({preview})")
|
||||
return f"消息({msg_text})"
|
||||
|
||||
return re.sub(pattern, _replace, text)
|
||||
|
||||
def _parse_single_action(
|
||||
self,
|
||||
action_json: dict,
|
||||
@@ -176,7 +214,10 @@ class ActionPlanner:
|
||||
|
||||
try:
|
||||
action = action_json.get("action", "no_reply")
|
||||
reasoning = action_json.get("reason", "未提供原因")
|
||||
original_reasoning = action_json.get("reason", "未提供原因")
|
||||
reasoning = self._replace_message_ids_with_text(original_reasoning, message_id_list)
|
||||
if reasoning is None:
|
||||
reasoning = original_reasoning
|
||||
action_data = {key: value for key, value in action_json.items() if key not in ["action", "reason"]}
|
||||
# 非no_reply动作需要target_message_id
|
||||
target_message = None
|
||||
@@ -573,9 +614,6 @@ class ActionPlanner:
|
||||
# 调用LLM
|
||||
llm_content, (reasoning_content, _, _) = await self.planner_llm.generate_response_async(prompt=prompt)
|
||||
|
||||
# logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
|
||||
# logger.info(f"{self.log_prefix}规划器原始响应: {llm_content}")
|
||||
|
||||
if global_config.debug.show_planner_prompt:
|
||||
logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
|
||||
logger.info(f"{self.log_prefix}规划器原始响应: {llm_content}")
|
||||
@@ -604,6 +642,7 @@ class ActionPlanner:
|
||||
if llm_content:
|
||||
try:
|
||||
json_objects, extracted_reasoning = self._extract_json_from_markdown(llm_content)
|
||||
extracted_reasoning = self._replace_message_ids_with_text(extracted_reasoning, message_id_list) or ""
|
||||
if json_objects:
|
||||
logger.debug(f"{self.log_prefix}从响应中提取到{len(json_objects)}个JSON对象")
|
||||
filtered_actions_list = list(filtered_actions.items())
|
||||
|
||||
@@ -226,7 +226,9 @@ class DefaultReplyer:
|
||||
traceback.print_exc()
|
||||
return False, llm_response
|
||||
|
||||
async def build_expression_habits(self, chat_history: str, target: str, reply_reason: str = "") -> Tuple[str, List[int]]:
|
||||
async def build_expression_habits(
|
||||
self, chat_history: str, target: str, reply_reason: str = ""
|
||||
) -> Tuple[str, List[int]]:
|
||||
# sourcery skip: for-append-to-extend
|
||||
"""构建表达习惯块
|
||||
|
||||
@@ -1094,10 +1096,10 @@ class DefaultReplyer:
|
||||
if not global_config.lpmm_knowledge.enable:
|
||||
logger.debug("LPMM知识库未启用,跳过获取知识库内容")
|
||||
return ""
|
||||
|
||||
|
||||
if global_config.lpmm_knowledge.lpmm_mode == "agent":
|
||||
return ""
|
||||
|
||||
|
||||
time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
||||
|
||||
bot_name = global_config.bot.nickname
|
||||
@@ -1115,10 +1117,10 @@ class DefaultReplyer:
|
||||
model_config=model_config.model_task_config.tool_use,
|
||||
tool_options=[SearchKnowledgeFromLPMMTool.get_tool_definition()],
|
||||
)
|
||||
|
||||
|
||||
# logger.info(f"工具调用提示词: {prompt}")
|
||||
# logger.info(f"工具调用: {tool_calls}")
|
||||
|
||||
|
||||
if tool_calls:
|
||||
result = await self.tool_executor.execute_tool_call(tool_calls[0], SearchKnowledgeFromLPMMTool())
|
||||
end_time = time.time()
|
||||
|
||||
@@ -241,7 +241,9 @@ class PrivateReplyer:
|
||||
|
||||
return f"{sender_relation}"
|
||||
|
||||
async def build_expression_habits(self, chat_history: str, target: str, reply_reason: str = "") -> Tuple[str, List[int]]:
|
||||
async def build_expression_habits(
|
||||
self, chat_history: str, target: str, reply_reason: str = ""
|
||||
) -> Tuple[str, List[int]]:
|
||||
# sourcery skip: for-append-to-extend
|
||||
"""构建表达习惯块
|
||||
|
||||
@@ -1032,10 +1034,10 @@ class PrivateReplyer:
|
||||
if not global_config.lpmm_knowledge.enable:
|
||||
logger.debug("LPMM知识库未启用,跳过获取知识库内容")
|
||||
return ""
|
||||
|
||||
|
||||
if global_config.lpmm_knowledge.lpmm_mode == "agent":
|
||||
return ""
|
||||
|
||||
|
||||
time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
||||
|
||||
bot_name = global_config.bot.nickname
|
||||
|
||||
@@ -106,8 +106,8 @@ class ChatHistorySummarizer:
|
||||
await self._check_and_package(current_time)
|
||||
self.last_check_time = current_time
|
||||
return
|
||||
|
||||
logger.info(
|
||||
|
||||
logger.debug(
|
||||
f"{self.log_prefix} 开始处理聊天概括,时间窗口: {self.last_check_time:.2f} -> {current_time:.2f}"
|
||||
)
|
||||
|
||||
@@ -119,7 +119,7 @@ class ChatHistorySummarizer:
|
||||
before_count = len(self.current_batch.messages)
|
||||
self.current_batch.messages.extend(new_messages)
|
||||
self.current_batch.end_time = current_time
|
||||
logger.info(f"{self.log_prefix} 批次更新: {before_count} -> {len(self.current_batch.messages)} 条消息")
|
||||
logger.info(f"{self.log_prefix} 更新聊天话题: {before_count} -> {len(self.current_batch.messages)} 条消息")
|
||||
else:
|
||||
# 创建新批次
|
||||
self.current_batch = MessageBatch(
|
||||
@@ -127,7 +127,7 @@ class ChatHistorySummarizer:
|
||||
start_time=new_messages[0].time if new_messages else current_time,
|
||||
end_time=current_time,
|
||||
)
|
||||
logger.info(f"{self.log_prefix} 新建批次: {len(new_messages)} 条消息")
|
||||
logger.info(f"{self.log_prefix} 新建聊天话题: {len(new_messages)} 条消息")
|
||||
|
||||
# 检查是否需要打包
|
||||
await self._check_and_package(current_time)
|
||||
|
||||
@@ -72,16 +72,16 @@ def get_ws_handler():
|
||||
|
||||
def initialize_ws_handler(loop):
|
||||
"""初始化 WebSocket handler 的事件循环
|
||||
|
||||
|
||||
Args:
|
||||
loop: asyncio 事件循环
|
||||
"""
|
||||
handler = get_ws_handler()
|
||||
handler.set_loop(loop)
|
||||
|
||||
|
||||
# 为 WebSocket handler 设置 JSON 格式化器(与文件格式相同)
|
||||
handler.setFormatter(file_formatter)
|
||||
|
||||
|
||||
# 添加到根日志记录器
|
||||
root_logger = logging.getLogger()
|
||||
if handler not in root_logger.handlers:
|
||||
@@ -177,43 +177,44 @@ class TimestampedFileHandler(logging.Handler):
|
||||
|
||||
class WebSocketLogHandler(logging.Handler):
|
||||
"""WebSocket 日志处理器 - 将日志实时推送到前端"""
|
||||
|
||||
|
||||
_log_counter = 0 # 类级别计数器,确保 ID 唯一性
|
||||
|
||||
|
||||
def __init__(self, loop=None):
|
||||
super().__init__()
|
||||
self.loop = loop
|
||||
self._initialized = False
|
||||
|
||||
|
||||
def set_loop(self, loop):
|
||||
"""设置事件循环"""
|
||||
self.loop = loop
|
||||
self._initialized = True
|
||||
|
||||
|
||||
def emit(self, record):
|
||||
"""发送日志到 WebSocket 客户端"""
|
||||
if not self._initialized or self.loop is None:
|
||||
return
|
||||
|
||||
|
||||
try:
|
||||
# 获取格式化后的消息
|
||||
# 对于 structlog,formatted message 包含完整的日志信息
|
||||
formatted_msg = self.format(record) if self.formatter else record.getMessage()
|
||||
|
||||
|
||||
# 如果是 JSON 格式(文件格式化器),解析它
|
||||
message = formatted_msg
|
||||
try:
|
||||
import json
|
||||
|
||||
log_dict = json.loads(formatted_msg)
|
||||
message = log_dict.get('event', formatted_msg)
|
||||
message = log_dict.get("event", formatted_msg)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
# 不是 JSON,直接使用消息
|
||||
message = formatted_msg
|
||||
|
||||
|
||||
# 生成唯一 ID: 时间戳毫秒 + 自增计数器
|
||||
WebSocketLogHandler._log_counter += 1
|
||||
log_id = f"{int(record.created * 1000)}_{WebSocketLogHandler._log_counter}"
|
||||
|
||||
|
||||
# 格式化日志数据
|
||||
log_data = {
|
||||
"id": log_id,
|
||||
@@ -222,20 +223,17 @@ class WebSocketLogHandler(logging.Handler):
|
||||
"module": record.name,
|
||||
"message": message,
|
||||
}
|
||||
|
||||
|
||||
# 异步广播日志(不阻塞日志记录)
|
||||
try:
|
||||
import asyncio
|
||||
from src.webui.logs_ws import broadcast_log
|
||||
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
broadcast_log(log_data),
|
||||
self.loop
|
||||
)
|
||||
|
||||
asyncio.run_coroutine_threadsafe(broadcast_log(log_data), self.loop)
|
||||
except Exception:
|
||||
# WebSocket 推送失败不影响日志记录
|
||||
pass
|
||||
|
||||
|
||||
except Exception:
|
||||
# 不要让 WebSocket 错误影响日志系统
|
||||
self.handleError(record)
|
||||
@@ -255,7 +253,7 @@ def close_handlers():
|
||||
if _console_handler:
|
||||
_console_handler.close()
|
||||
_console_handler = None
|
||||
|
||||
|
||||
if _ws_handler:
|
||||
_ws_handler.close()
|
||||
_ws_handler = None
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from fastapi import FastAPI, APIRouter
|
||||
from fastapi.middleware.cors import CORSMiddleware # 新增导入
|
||||
from typing import Optional
|
||||
from uvicorn import Config, Server as UvicornServer
|
||||
import asyncio
|
||||
@@ -17,21 +16,6 @@ class Server:
|
||||
self._server: Optional[UvicornServer] = None
|
||||
self.set_address(host, port)
|
||||
|
||||
# 配置 CORS
|
||||
origins = [
|
||||
"http://localhost:7999", # 允许的前端源
|
||||
"http://127.0.0.1:7999",
|
||||
# 在生产环境中,您应该添加实际的前端域名
|
||||
]
|
||||
|
||||
self.app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=origins,
|
||||
allow_credentials=True, # 是否支持 cookie
|
||||
allow_methods=["*"], # 允许所有 HTTP 方法
|
||||
allow_headers=["*"], # 允许所有 HTTP 请求头
|
||||
)
|
||||
|
||||
def register_router(self, router: APIRouter, prefix: str = ""):
|
||||
"""注册路由
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ TEMPLATE_DIR = os.path.join(PROJECT_ROOT, "template")
|
||||
|
||||
# 考虑到,实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
|
||||
# 对该字段的更新,请严格参照语义化版本规范:https://semver.org/lang/zh-CN/
|
||||
MMC_VERSION = "0.11.3"
|
||||
MMC_VERSION = "0.11.5"
|
||||
|
||||
|
||||
def get_key_comment(toml_table, key):
|
||||
|
||||
@@ -581,9 +581,15 @@ class DebugConfig(ConfigBase):
|
||||
show_jargon_prompt: bool = False
|
||||
"""是否显示jargon相关提示词"""
|
||||
|
||||
show_memory_prompt: bool = False
|
||||
"""是否显示记忆检索相关prompt"""
|
||||
|
||||
show_planner_prompt: bool = False
|
||||
"""是否显示planner相关提示词"""
|
||||
|
||||
show_lpmm_paragraph: bool = False
|
||||
"""是否显示lpmm找到的相关文段日志"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExperimentalConfig(ConfigBase):
|
||||
@@ -647,7 +653,7 @@ class LPMMKnowledgeConfig(ConfigBase):
|
||||
|
||||
enable: bool = True
|
||||
"""是否启用LPMM知识库"""
|
||||
|
||||
|
||||
lpmm_mode: Literal["classic", "agent"] = "classic"
|
||||
"""LPMM知识库模式,可选:classic经典模式,agent 模式,结合最新的记忆一同使用"""
|
||||
|
||||
@@ -690,4 +696,4 @@ class JargonConfig(ConfigBase):
|
||||
"""Jargon配置类"""
|
||||
|
||||
all_global: bool = False
|
||||
"""是否将所有新增的jargon项目默认为全局(is_global=True),chat_id记录第一次存储时的id"""
|
||||
"""是否将所有新增的jargon项目默认为全局(is_global=True),chat_id记录第一次存储时的id"""
|
||||
|
||||
@@ -467,11 +467,7 @@ class ExpressionLearner:
|
||||
up_content: str,
|
||||
current_time: float,
|
||||
) -> None:
|
||||
expr_obj = (
|
||||
Expression.select()
|
||||
.where((Expression.chat_id == self.chat_id) & (Expression.style == style))
|
||||
.first()
|
||||
)
|
||||
expr_obj = Expression.select().where((Expression.chat_id == self.chat_id) & (Expression.style == style)).first()
|
||||
|
||||
if expr_obj:
|
||||
await self._update_existing_expression(
|
||||
|
||||
@@ -42,8 +42,6 @@ def init_prompt():
|
||||
Prompt(expression_evaluation_prompt, "expression_evaluation_prompt")
|
||||
|
||||
|
||||
|
||||
|
||||
class ExpressionSelector:
|
||||
def __init__(self):
|
||||
self.llm_model = LLMRequest(
|
||||
@@ -238,9 +236,9 @@ class ExpressionSelector:
|
||||
else:
|
||||
target_message_str = ""
|
||||
target_message_extra_block = ""
|
||||
|
||||
|
||||
chat_context = f"以下是正在进行的聊天内容:{chat_info}"
|
||||
|
||||
|
||||
# 构建reply_reason块
|
||||
if reply_reason:
|
||||
reply_reason_block = f"你的回复理由是:{reply_reason}"
|
||||
@@ -262,9 +260,8 @@ class ExpressionSelector:
|
||||
# 4. 调用LLM
|
||||
content, (reasoning_content, model_name, _) = await self.llm_model.generate_response_async(prompt=prompt)
|
||||
|
||||
|
||||
# print(prompt)
|
||||
|
||||
|
||||
if not content:
|
||||
logger.warning("LLM返回空结果")
|
||||
return [], []
|
||||
|
||||
@@ -36,10 +36,7 @@ def _contains_bot_self_name(content: str) -> bool:
|
||||
|
||||
target = content.strip().lower()
|
||||
nickname = str(getattr(bot_config, "nickname", "") or "").strip().lower()
|
||||
alias_names = [
|
||||
str(alias or "").strip().lower()
|
||||
for alias in getattr(bot_config, "alias_names", []) or []
|
||||
]
|
||||
alias_names = [str(alias or "").strip().lower() for alias in getattr(bot_config, "alias_names", []) or []]
|
||||
|
||||
candidates = [name for name in [nickname, *alias_names] if name]
|
||||
|
||||
@@ -149,7 +146,7 @@ async def _enrich_raw_content_if_needed(
|
||||
) -> List[str]:
|
||||
"""
|
||||
检查raw_content是否只包含黑话本身,如果是,则获取该消息的前三条消息作为原始内容
|
||||
|
||||
|
||||
Args:
|
||||
content: 黑话内容
|
||||
raw_content_list: 原始raw_content列表
|
||||
@@ -157,22 +154,22 @@ async def _enrich_raw_content_if_needed(
|
||||
messages: 当前时间窗口内的消息列表
|
||||
extraction_start_time: 提取开始时间
|
||||
extraction_end_time: 提取结束时间
|
||||
|
||||
|
||||
Returns:
|
||||
处理后的raw_content列表
|
||||
"""
|
||||
enriched_list = []
|
||||
|
||||
|
||||
for raw_content in raw_content_list:
|
||||
# 检查raw_content是否只包含黑话本身(去除空白字符后比较)
|
||||
raw_content_clean = raw_content.strip()
|
||||
content_clean = content.strip()
|
||||
|
||||
|
||||
# 如果raw_content只包含黑话本身(可能有一些标点或空白),则尝试获取上下文
|
||||
# 去除所有空白字符后比较,确保只包含黑话本身
|
||||
raw_content_normalized = raw_content_clean.replace(" ", "").replace("\n", "").replace("\t", "")
|
||||
content_normalized = content_clean.replace(" ", "").replace("\n", "").replace("\t", "")
|
||||
|
||||
|
||||
if raw_content_normalized == content_normalized:
|
||||
# 在消息列表中查找只包含该黑话的消息(去除空白后比较)
|
||||
target_message = None
|
||||
@@ -183,22 +180,20 @@ async def _enrich_raw_content_if_needed(
|
||||
if msg_content_normalized == content_normalized:
|
||||
target_message = msg
|
||||
break
|
||||
|
||||
|
||||
if target_message and target_message.time:
|
||||
# 获取该消息的前三条消息
|
||||
try:
|
||||
previous_messages = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=chat_id,
|
||||
timestamp=target_message.time,
|
||||
limit=3
|
||||
chat_id=chat_id, timestamp=target_message.time, limit=3
|
||||
)
|
||||
|
||||
|
||||
if previous_messages:
|
||||
# 将前三条消息和当前消息一起格式化
|
||||
context_messages = previous_messages + [target_message]
|
||||
# 按时间排序
|
||||
context_messages.sort(key=lambda x: x.time or 0)
|
||||
|
||||
|
||||
# 格式化为可读消息
|
||||
formatted_context, _ = await build_readable_messages_with_list(
|
||||
context_messages,
|
||||
@@ -206,7 +201,7 @@ async def _enrich_raw_content_if_needed(
|
||||
timestamp_mode="relative",
|
||||
truncate=False,
|
||||
)
|
||||
|
||||
|
||||
if formatted_context.strip():
|
||||
enriched_list.append(formatted_context.strip())
|
||||
logger.warning(f"为黑话 {content} 补充了上下文消息")
|
||||
@@ -226,7 +221,7 @@ async def _enrich_raw_content_if_needed(
|
||||
else:
|
||||
# raw_content包含更多内容,直接使用
|
||||
enriched_list.append(raw_content)
|
||||
|
||||
|
||||
return enriched_list
|
||||
|
||||
|
||||
@@ -240,31 +235,31 @@ def _should_infer_meaning(jargon_obj: Jargon) -> bool:
|
||||
# 如果已完成所有推断,不再推断
|
||||
if jargon_obj.is_complete:
|
||||
return False
|
||||
|
||||
|
||||
count = jargon_obj.count or 0
|
||||
last_inference = jargon_obj.last_inference_count or 0
|
||||
|
||||
|
||||
# 阈值列表:3,6, 10, 20, 40, 60, 100
|
||||
thresholds = [3,6, 10, 20, 40, 60, 100]
|
||||
|
||||
thresholds = [3, 6, 10, 20, 40, 60, 100]
|
||||
|
||||
if count < thresholds[0]:
|
||||
return False
|
||||
|
||||
|
||||
# 如果count没有超过上次判定值,不需要判定
|
||||
if count <= last_inference:
|
||||
return False
|
||||
|
||||
|
||||
# 找到第一个大于last_inference的阈值
|
||||
next_threshold = None
|
||||
for threshold in thresholds:
|
||||
if threshold > last_inference:
|
||||
next_threshold = threshold
|
||||
break
|
||||
|
||||
|
||||
# 如果没有找到下一个阈值,说明已经超过100,不应该再推断
|
||||
if next_threshold is None:
|
||||
return False
|
||||
|
||||
|
||||
# 检查count是否达到或超过这个阈值
|
||||
return count >= next_threshold
|
||||
|
||||
@@ -275,13 +270,13 @@ class JargonMiner:
|
||||
self.last_learning_time: float = time.time()
|
||||
# 频率控制,可按需调整
|
||||
self.min_messages_for_learning: int = 10
|
||||
self.min_learning_interval: float = 20
|
||||
self.min_learning_interval: float = 20
|
||||
|
||||
self.llm = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils,
|
||||
request_type="jargon.extract",
|
||||
)
|
||||
|
||||
|
||||
# 初始化stream_name作为类属性,避免重复提取
|
||||
chat_manager = get_chat_manager()
|
||||
stream_name = chat_manager.get_stream_name(self.chat_id)
|
||||
@@ -306,17 +301,19 @@ class JargonMiner:
|
||||
try:
|
||||
content = jargon_obj.content
|
||||
raw_content_str = jargon_obj.raw_content or ""
|
||||
|
||||
|
||||
# 解析raw_content列表
|
||||
raw_content_list = []
|
||||
if raw_content_str:
|
||||
try:
|
||||
raw_content_list = json.loads(raw_content_str) if isinstance(raw_content_str, str) else raw_content_str
|
||||
raw_content_list = (
|
||||
json.loads(raw_content_str) if isinstance(raw_content_str, str) else raw_content_str
|
||||
)
|
||||
if not isinstance(raw_content_list, list):
|
||||
raw_content_list = [raw_content_list] if raw_content_list else []
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
raw_content_list = [raw_content_str] if raw_content_str else []
|
||||
|
||||
|
||||
if not raw_content_list:
|
||||
logger.warning(f"jargon {content} 没有raw_content,跳过推断")
|
||||
return
|
||||
@@ -328,12 +325,12 @@ class JargonMiner:
|
||||
content=content,
|
||||
raw_content_list=raw_content_text,
|
||||
)
|
||||
|
||||
|
||||
response1, _ = await self.llm.generate_response_async(prompt1, temperature=0.3)
|
||||
if not response1:
|
||||
logger.warning(f"jargon {content} 推断1失败:无响应")
|
||||
return
|
||||
|
||||
|
||||
# 解析推断1结果
|
||||
inference1 = None
|
||||
try:
|
||||
@@ -349,7 +346,7 @@ class JargonMiner:
|
||||
except Exception as e:
|
||||
logger.error(f"jargon {content} 推断1解析失败: {e}")
|
||||
return
|
||||
|
||||
|
||||
# 检查推断1是否表示信息不足无法推断
|
||||
no_info = inference1.get("no_info", False)
|
||||
meaning1 = inference1.get("meaning", "").strip()
|
||||
@@ -360,18 +357,17 @@ class JargonMiner:
|
||||
jargon_obj.save()
|
||||
return
|
||||
|
||||
|
||||
# 步骤2: 仅基于content推断
|
||||
prompt2 = await global_prompt_manager.format_prompt(
|
||||
"jargon_inference_content_only_prompt",
|
||||
content=content,
|
||||
)
|
||||
|
||||
|
||||
response2, _ = await self.llm.generate_response_async(prompt2, temperature=0.3)
|
||||
if not response2:
|
||||
logger.warning(f"jargon {content} 推断2失败:无响应")
|
||||
return
|
||||
|
||||
|
||||
# 解析推断2结果
|
||||
inference2 = None
|
||||
try:
|
||||
@@ -387,13 +383,12 @@ class JargonMiner:
|
||||
except Exception as e:
|
||||
logger.error(f"jargon {content} 推断2解析失败: {e}")
|
||||
return
|
||||
|
||||
|
||||
logger.info(f"jargon {content} 推断2提示词: {prompt2}")
|
||||
logger.info(f"jargon {content} 推断2结果: {response2}")
|
||||
logger.info(f"jargon {content} 推断1提示词: {prompt1}")
|
||||
logger.info(f"jargon {content} 推断1结果: {response1}")
|
||||
|
||||
# logger.info(f"jargon {content} 推断2提示词: {prompt2}")
|
||||
# logger.info(f"jargon {content} 推断2结果: {response2}")
|
||||
# logger.info(f"jargon {content} 推断1提示词: {prompt1}")
|
||||
# logger.info(f"jargon {content} 推断1结果: {response1}")
|
||||
|
||||
if global_config.debug.show_jargon_prompt:
|
||||
logger.info(f"jargon {content} 推断2提示词: {prompt2}")
|
||||
logger.info(f"jargon {content} 推断2结果: {response2}")
|
||||
@@ -404,22 +399,22 @@ class JargonMiner:
|
||||
logger.debug(f"jargon {content} 推断2结果: {response2}")
|
||||
logger.debug(f"jargon {content} 推断1提示词: {prompt1}")
|
||||
logger.debug(f"jargon {content} 推断1结果: {response1}")
|
||||
|
||||
|
||||
# 步骤3: 比较两个推断结果
|
||||
prompt3 = await global_prompt_manager.format_prompt(
|
||||
"jargon_compare_inference_prompt",
|
||||
inference1=json.dumps(inference1, ensure_ascii=False),
|
||||
inference2=json.dumps(inference2, ensure_ascii=False),
|
||||
)
|
||||
|
||||
|
||||
if global_config.debug.show_jargon_prompt:
|
||||
logger.info(f"jargon {content} 比较提示词: {prompt3}")
|
||||
|
||||
|
||||
response3, _ = await self.llm.generate_response_async(prompt3, temperature=0.3)
|
||||
if not response3:
|
||||
logger.warning(f"jargon {content} 比较失败:无响应")
|
||||
return
|
||||
|
||||
|
||||
# 解析比较结果
|
||||
comparison = None
|
||||
try:
|
||||
@@ -439,7 +434,7 @@ class JargonMiner:
|
||||
# 判断是否为黑话
|
||||
is_similar = comparison.get("is_similar", False)
|
||||
is_jargon = not is_similar # 如果相似,说明不是黑话;如果有差异,说明是黑话
|
||||
|
||||
|
||||
# 更新数据库记录
|
||||
jargon_obj.is_jargon = is_jargon
|
||||
if is_jargon:
|
||||
@@ -448,17 +443,19 @@ class JargonMiner:
|
||||
else:
|
||||
# 不是黑话,也记录含义(使用推断2的结果,因为含义明确)
|
||||
jargon_obj.meaning = inference2.get("meaning", "")
|
||||
|
||||
|
||||
# 更新最后一次判定的count值,避免重启后重复判定
|
||||
jargon_obj.last_inference_count = jargon_obj.count or 0
|
||||
|
||||
|
||||
# 如果count>=100,标记为完成,不再进行推断
|
||||
if (jargon_obj.count or 0) >= 100:
|
||||
jargon_obj.is_complete = True
|
||||
|
||||
|
||||
jargon_obj.save()
|
||||
logger.debug(f"jargon {content} 推断完成: is_jargon={is_jargon}, meaning={jargon_obj.meaning}, last_inference_count={jargon_obj.last_inference_count}, is_complete={jargon_obj.is_complete}")
|
||||
|
||||
logger.debug(
|
||||
f"jargon {content} 推断完成: is_jargon={is_jargon}, meaning={jargon_obj.meaning}, last_inference_count={jargon_obj.last_inference_count}, is_complete={jargon_obj.is_complete}"
|
||||
)
|
||||
|
||||
# 固定输出推断结果,格式化为可读形式
|
||||
if is_jargon:
|
||||
# 是黑话,输出格式:[聊天名]xxx的含义是 xxxxxxxxxxx
|
||||
@@ -471,10 +468,11 @@ class JargonMiner:
|
||||
else:
|
||||
# 不是黑话,输出格式:[聊天名]xxx 不是黑话
|
||||
logger.info(f"[{self.stream_name}]{content} 不是黑话")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"jargon推断失败: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
def should_trigger(self) -> bool:
|
||||
@@ -502,7 +500,7 @@ class JargonMiner:
|
||||
# 记录本次提取的时间窗口,避免重复提取
|
||||
extraction_start_time = self.last_learning_time
|
||||
extraction_end_time = time.time()
|
||||
|
||||
|
||||
# 拉取学习窗口内的消息
|
||||
messages = get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id=self.chat_id,
|
||||
@@ -525,7 +523,7 @@ class JargonMiner:
|
||||
response, _ = await self.llm.generate_response_async(prompt, temperature=0.2)
|
||||
if not response:
|
||||
return
|
||||
|
||||
|
||||
if global_config.debug.show_jargon_prompt:
|
||||
logger.info(f"jargon提取提示词: {prompt}")
|
||||
logger.info(f"jargon提取结果: {response}")
|
||||
@@ -555,7 +553,7 @@ class JargonMiner:
|
||||
continue
|
||||
content = str(item.get("content", "")).strip()
|
||||
raw_content_value = item.get("raw_content", "")
|
||||
|
||||
|
||||
# 处理raw_content:可能是字符串或列表
|
||||
raw_content_list = []
|
||||
if isinstance(raw_content_value, list):
|
||||
@@ -566,15 +564,12 @@ class JargonMiner:
|
||||
raw_content_str = raw_content_value.strip()
|
||||
if raw_content_str:
|
||||
raw_content_list = [raw_content_str]
|
||||
|
||||
|
||||
if content and raw_content_list:
|
||||
if _contains_bot_self_name(content):
|
||||
logger.debug(f"解析阶段跳过包含机器人昵称/别名的词条: {content}")
|
||||
continue
|
||||
entries.append({
|
||||
"content": content,
|
||||
"raw_content": raw_content_list
|
||||
})
|
||||
entries.append({"content": content, "raw_content": raw_content_list})
|
||||
except Exception as e:
|
||||
logger.error(f"解析jargon JSON失败: {e}; 原始: {response}")
|
||||
return
|
||||
@@ -591,13 +586,13 @@ class JargonMiner:
|
||||
if content_key not in seen:
|
||||
seen.add(content_key)
|
||||
uniq_entries.append(entry)
|
||||
|
||||
|
||||
saved = 0
|
||||
updated = 0
|
||||
for entry in uniq_entries:
|
||||
content = entry["content"]
|
||||
raw_content_list = entry["raw_content"] # 已经是列表
|
||||
|
||||
|
||||
# 检查并补充raw_content:如果只包含黑话本身,则获取前三条消息作为上下文
|
||||
raw_content_list = await _enrich_raw_content_if_needed(
|
||||
content=content,
|
||||
@@ -607,60 +602,53 @@ class JargonMiner:
|
||||
extraction_start_time=extraction_start_time,
|
||||
extraction_end_time=extraction_end_time,
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
# 根据all_global配置决定查询逻辑
|
||||
if global_config.jargon.all_global:
|
||||
# 开启all_global:无视chat_id,查询所有content匹配的记录(所有记录都是全局的)
|
||||
query = (
|
||||
Jargon.select()
|
||||
.where(Jargon.content == content)
|
||||
)
|
||||
query = Jargon.select().where(Jargon.content == content)
|
||||
else:
|
||||
# 关闭all_global:只查询chat_id匹配的记录(不考虑is_global)
|
||||
query = (
|
||||
Jargon.select()
|
||||
.where(
|
||||
(Jargon.chat_id == self.chat_id) &
|
||||
(Jargon.content == content)
|
||||
)
|
||||
)
|
||||
|
||||
query = Jargon.select().where((Jargon.chat_id == self.chat_id) & (Jargon.content == content))
|
||||
|
||||
if query.exists():
|
||||
obj = query.get()
|
||||
try:
|
||||
obj.count = (obj.count or 0) + 1
|
||||
except Exception:
|
||||
obj.count = 1
|
||||
|
||||
|
||||
# 合并raw_content列表:读取现有列表,追加新值,去重
|
||||
existing_raw_content = []
|
||||
if obj.raw_content:
|
||||
try:
|
||||
existing_raw_content = json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content
|
||||
existing_raw_content = (
|
||||
json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content
|
||||
)
|
||||
if not isinstance(existing_raw_content, list):
|
||||
existing_raw_content = [existing_raw_content] if existing_raw_content else []
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
existing_raw_content = [obj.raw_content] if obj.raw_content else []
|
||||
|
||||
|
||||
# 合并并去重
|
||||
merged_list = list(dict.fromkeys(existing_raw_content + raw_content_list))
|
||||
obj.raw_content = json.dumps(merged_list, ensure_ascii=False)
|
||||
|
||||
|
||||
# 开启all_global时,确保记录标记为is_global=True
|
||||
if global_config.jargon.all_global:
|
||||
obj.is_global = True
|
||||
# 关闭all_global时,保持原有is_global不变(不修改)
|
||||
|
||||
|
||||
obj.save()
|
||||
|
||||
|
||||
# 检查是否需要推断(达到阈值且超过上次判定值)
|
||||
if _should_infer_meaning(obj):
|
||||
# 异步触发推断,不阻塞主流程
|
||||
# 重新加载对象以确保数据最新
|
||||
jargon_id = obj.id
|
||||
asyncio.create_task(self._infer_meaning_by_id(jargon_id))
|
||||
|
||||
|
||||
updated += 1
|
||||
else:
|
||||
# 没找到匹配记录,创建新记录
|
||||
@@ -670,13 +658,13 @@ class JargonMiner:
|
||||
else:
|
||||
# 关闭all_global:新记录is_global=False
|
||||
is_global_new = False
|
||||
|
||||
|
||||
Jargon.create(
|
||||
content=content,
|
||||
raw_content=json.dumps(raw_content_list, ensure_ascii=False),
|
||||
chat_id=self.chat_id,
|
||||
is_global=is_global_new,
|
||||
count=1
|
||||
count=1,
|
||||
)
|
||||
saved += 1
|
||||
except Exception as e:
|
||||
@@ -688,13 +676,13 @@ class JargonMiner:
|
||||
# 收集所有提取的jargon内容
|
||||
jargon_list = [entry["content"] for entry in uniq_entries]
|
||||
jargon_str = ",".join(jargon_list)
|
||||
|
||||
|
||||
# 输出格式化的结果(使用logger.info会自动应用jargon模块的颜色)
|
||||
logger.info(f"[{self.stream_name}]疑似黑话: {jargon_str}")
|
||||
|
||||
|
||||
# 更新为本次提取的结束时间,确保不会重复提取相同的消息窗口
|
||||
self.last_learning_time = extraction_end_time
|
||||
|
||||
|
||||
if saved or updated:
|
||||
logger.info(f"jargon写入: 新增 {saved} 条,更新 {updated} 条,chat_id={self.chat_id}")
|
||||
except Exception as e:
|
||||
@@ -720,15 +708,11 @@ async def extract_and_store_jargon(chat_id: str) -> None:
|
||||
|
||||
|
||||
def search_jargon(
|
||||
keyword: str,
|
||||
chat_id: Optional[str] = None,
|
||||
limit: int = 10,
|
||||
case_sensitive: bool = False,
|
||||
fuzzy: bool = True
|
||||
keyword: str, chat_id: Optional[str] = None, limit: int = 10, case_sensitive: bool = False, fuzzy: bool = True
|
||||
) -> List[Dict[str, str]]:
|
||||
"""
|
||||
搜索jargon,支持大小写不敏感和模糊搜索
|
||||
|
||||
|
||||
Args:
|
||||
keyword: 搜索关键词
|
||||
chat_id: 可选的聊天ID
|
||||
@@ -737,21 +721,18 @@ def search_jargon(
|
||||
limit: 返回结果数量限制,默认10
|
||||
case_sensitive: 是否大小写敏感,默认False(不敏感)
|
||||
fuzzy: 是否模糊搜索,默认True(使用LIKE匹配)
|
||||
|
||||
|
||||
Returns:
|
||||
List[Dict[str, str]]: 包含content, meaning的字典列表
|
||||
"""
|
||||
if not keyword or not keyword.strip():
|
||||
return []
|
||||
|
||||
|
||||
keyword = keyword.strip()
|
||||
|
||||
|
||||
# 构建查询
|
||||
query = Jargon.select(
|
||||
Jargon.content,
|
||||
Jargon.meaning
|
||||
)
|
||||
|
||||
query = Jargon.select(Jargon.content, Jargon.meaning)
|
||||
|
||||
# 构建搜索条件
|
||||
if case_sensitive:
|
||||
# 大小写敏感
|
||||
@@ -760,7 +741,7 @@ def search_jargon(
|
||||
search_condition = Jargon.content.contains(keyword)
|
||||
else:
|
||||
# 精确匹配
|
||||
search_condition = (Jargon.content == keyword)
|
||||
search_condition = Jargon.content == keyword
|
||||
else:
|
||||
# 大小写不敏感
|
||||
if fuzzy:
|
||||
@@ -768,10 +749,10 @@ def search_jargon(
|
||||
search_condition = fn.LOWER(Jargon.content).contains(keyword.lower())
|
||||
else:
|
||||
# 精确匹配(使用LOWER函数)
|
||||
search_condition = (fn.LOWER(Jargon.content) == keyword.lower())
|
||||
|
||||
search_condition = fn.LOWER(Jargon.content) == keyword.lower()
|
||||
|
||||
query = query.where(search_condition)
|
||||
|
||||
|
||||
# 根据all_global配置决定查询逻辑
|
||||
if global_config.jargon.all_global:
|
||||
# 开启all_global:所有记录都是全局的,查询所有is_global=True的记录(无视chat_id)
|
||||
@@ -779,35 +760,28 @@ def search_jargon(
|
||||
else:
|
||||
# 关闭all_global:如果提供了chat_id,优先搜索该聊天或global的jargon
|
||||
if chat_id:
|
||||
query = query.where(
|
||||
(Jargon.chat_id == chat_id) | Jargon.is_global
|
||||
)
|
||||
|
||||
query = query.where((Jargon.chat_id == chat_id) | Jargon.is_global)
|
||||
|
||||
# 只返回有meaning的记录
|
||||
query = query.where(
|
||||
(Jargon.meaning.is_null(False)) & (Jargon.meaning != "")
|
||||
)
|
||||
|
||||
query = query.where((Jargon.meaning.is_null(False)) & (Jargon.meaning != ""))
|
||||
|
||||
# 按count降序排序,优先返回出现频率高的
|
||||
query = query.order_by(Jargon.count.desc())
|
||||
|
||||
|
||||
# 限制结果数量
|
||||
query = query.limit(limit)
|
||||
|
||||
|
||||
# 执行查询并返回结果
|
||||
results = []
|
||||
for jargon in query:
|
||||
results.append({
|
||||
"content": jargon.content or "",
|
||||
"meaning": jargon.meaning or ""
|
||||
})
|
||||
|
||||
results.append({"content": jargon.content or "", "meaning": jargon.meaning or ""})
|
||||
|
||||
return results
|
||||
|
||||
|
||||
async def store_jargon_from_answer(jargon_keyword: str, answer: str, chat_id: str) -> None:
|
||||
"""将黑话存入jargon系统
|
||||
|
||||
|
||||
Args:
|
||||
jargon_keyword: 黑话关键词
|
||||
answer: 答案内容(将概括为raw_content)
|
||||
@@ -820,53 +794,52 @@ async def store_jargon_from_answer(jargon_keyword: str, answer: str, chat_id: st
|
||||
答案:{answer}
|
||||
|
||||
只输出概括后的内容,不要输出其他内容:"""
|
||||
|
||||
|
||||
success, summary, _, _ = await llm_api.generate_with_model(
|
||||
summary_prompt,
|
||||
model_config=model_config.model_task_config.utils_small,
|
||||
request_type="memory.summarize_jargon",
|
||||
)
|
||||
|
||||
|
||||
logger.info(f"概括答案提示: {summary_prompt}")
|
||||
logger.info(f"概括答案: {summary}")
|
||||
|
||||
|
||||
if not success:
|
||||
logger.warning(f"概括答案失败,使用原始答案: {summary}")
|
||||
summary = answer[:100] # 截取前100字符作为备用
|
||||
|
||||
|
||||
raw_content = summary.strip()[:200] # 限制长度
|
||||
|
||||
|
||||
# 检查是否已存在
|
||||
if global_config.jargon.all_global:
|
||||
query = Jargon.select().where(Jargon.content == jargon_keyword)
|
||||
else:
|
||||
query = Jargon.select().where(
|
||||
(Jargon.chat_id == chat_id) &
|
||||
(Jargon.content == jargon_keyword)
|
||||
)
|
||||
|
||||
query = Jargon.select().where((Jargon.chat_id == chat_id) & (Jargon.content == jargon_keyword))
|
||||
|
||||
if query.exists():
|
||||
# 更新现有记录
|
||||
obj = query.get()
|
||||
obj.count = (obj.count or 0) + 1
|
||||
|
||||
|
||||
# 合并raw_content列表
|
||||
existing_raw_content = []
|
||||
if obj.raw_content:
|
||||
try:
|
||||
existing_raw_content = json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content
|
||||
existing_raw_content = (
|
||||
json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content
|
||||
)
|
||||
if not isinstance(existing_raw_content, list):
|
||||
existing_raw_content = [existing_raw_content] if existing_raw_content else []
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
existing_raw_content = [obj.raw_content] if obj.raw_content else []
|
||||
|
||||
|
||||
# 合并并去重
|
||||
merged_list = list(dict.fromkeys(existing_raw_content + [raw_content]))
|
||||
obj.raw_content = json.dumps(merged_list, ensure_ascii=False)
|
||||
|
||||
|
||||
if global_config.jargon.all_global:
|
||||
obj.is_global = True
|
||||
|
||||
|
||||
obj.save()
|
||||
logger.info(f"更新jargon记录: {jargon_keyword}")
|
||||
else:
|
||||
@@ -877,11 +850,9 @@ async def store_jargon_from_answer(jargon_keyword: str, answer: str, chat_id: st
|
||||
raw_content=json.dumps([raw_content], ensure_ascii=False),
|
||||
chat_id=chat_id,
|
||||
is_global=is_global_new,
|
||||
count=1
|
||||
count=1,
|
||||
)
|
||||
logger.info(f"创建新jargon记录: {jargon_keyword}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"存储jargon失败: {e}")
|
||||
|
||||
|
||||
|
||||
@@ -147,7 +147,7 @@ def _convert_tool_options(tool_options: list[ToolOption]) -> list[FunctionDeclar
|
||||
param_type_value = tool_option_param.param_type.value
|
||||
if param_type_value == "bool":
|
||||
param_type_value = "boolean"
|
||||
|
||||
|
||||
return_dict: dict[str, Any] = {
|
||||
"type": param_type_value,
|
||||
"description": tool_option_param.description,
|
||||
|
||||
@@ -122,7 +122,7 @@ def _convert_tool_options(tool_options: list[ToolOption]) -> list[dict[str, Any]
|
||||
param_type_value = tool_option_param.param_type.value
|
||||
if param_type_value == "bool":
|
||||
param_type_value = "boolean"
|
||||
|
||||
|
||||
return_dict: dict[str, Any] = {
|
||||
"type": param_type_value,
|
||||
"description": tool_option_param.description,
|
||||
|
||||
@@ -116,9 +116,7 @@ class MessageBuilder:
|
||||
构建消息对象
|
||||
:return: Message对象
|
||||
"""
|
||||
if len(self.__content) == 0 and not (
|
||||
self.__role == RoleType.Assistant and self.__tool_calls
|
||||
):
|
||||
if len(self.__content) == 0 and not (self.__role == RoleType.Assistant and self.__tool_calls):
|
||||
raise ValueError("内容不能为空")
|
||||
if self.__role == RoleType.Tool and self.__tool_call_id is None:
|
||||
raise ValueError("Tool角色的工具调用ID不能为空")
|
||||
|
||||
@@ -166,7 +166,7 @@ class LLMRequest:
|
||||
time_cost=time.time() - start_time,
|
||||
)
|
||||
return content or "", (reasoning_content, model_info.name, tool_calls)
|
||||
|
||||
|
||||
async def generate_response_with_message_async(
|
||||
self,
|
||||
message_factory: Callable[[BaseClient], List[Message]],
|
||||
|
||||
48
src/main.py
48
src/main.py
@@ -36,37 +36,39 @@ class MainSystem:
|
||||
# 使用消息API替代直接的FastAPI实例
|
||||
self.app: MessageServer = get_global_api()
|
||||
self.server: Server = get_global_server()
|
||||
|
||||
# 注册 WebUI API 路由
|
||||
self._register_webui_routes()
|
||||
|
||||
# 设置 WebUI(开发/生产模式)
|
||||
self._setup_webui()
|
||||
self.webui_server = None # 独立的 WebUI 服务器
|
||||
|
||||
def _register_webui_routes(self):
|
||||
"""注册 WebUI API 路由"""
|
||||
try:
|
||||
from src.webui.routes import router as webui_router
|
||||
self.server.register_router(webui_router)
|
||||
logger.info("WebUI API 路由已注册")
|
||||
except Exception as e:
|
||||
logger.warning(f"注册 WebUI API 路由失败: {e}")
|
||||
# 设置独立的 WebUI 服务器
|
||||
self._setup_webui_server()
|
||||
|
||||
def _setup_webui(self):
|
||||
"""设置 WebUI(根据环境变量决定模式)"""
|
||||
def _setup_webui_server(self):
|
||||
"""设置独立的 WebUI 服务器"""
|
||||
import os
|
||||
|
||||
webui_enabled = os.getenv("WEBUI_ENABLED", "false").lower() == "true"
|
||||
if not webui_enabled:
|
||||
logger.info("WebUI 已禁用")
|
||||
return
|
||||
|
||||
|
||||
webui_mode = os.getenv("WEBUI_MODE", "production").lower()
|
||||
|
||||
|
||||
try:
|
||||
from src.webui.manager import setup_webui
|
||||
setup_webui(mode=webui_mode)
|
||||
from src.webui.webui_server import get_webui_server
|
||||
|
||||
self.webui_server = get_webui_server()
|
||||
|
||||
if webui_mode == "development":
|
||||
logger.info("📝 WebUI 开发模式已启用")
|
||||
logger.info("🌐 后端 API 将运行在 http://0.0.0.0:8001")
|
||||
logger.info("💡 请手动启动前端开发服务器: cd MaiBot-Dashboard && bun dev")
|
||||
logger.info("💡 前端将运行在 http://localhost:7999")
|
||||
else:
|
||||
logger.info("✅ WebUI 生产模式已启用")
|
||||
logger.info(f"🌐 WebUI 将运行在 http://0.0.0.0:8001")
|
||||
logger.info("💡 请确保已构建前端: cd MaiBot-Dashboard && bun run build")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"设置 WebUI 失败: {e}")
|
||||
logger.error(f"❌ 初始化 WebUI 服务器失败: {e}")
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化系统组件"""
|
||||
@@ -161,6 +163,10 @@ class MainSystem:
|
||||
self.server.run(),
|
||||
]
|
||||
|
||||
# 如果 WebUI 服务器已初始化,添加到任务列表
|
||||
if self.webui_server:
|
||||
tasks.append(self.webui_server.start())
|
||||
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
# async def forget_memory_task(self):
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -3,6 +3,7 @@
|
||||
记忆系统工具函数
|
||||
包含模糊查找、相似度计算等工具函数
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from datetime import datetime
|
||||
@@ -14,6 +15,7 @@ from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("memory_utils")
|
||||
|
||||
|
||||
def parse_md_json(json_text: str) -> list[str]:
|
||||
"""从Markdown格式的内容中提取JSON对象和推理内容"""
|
||||
json_objects = []
|
||||
@@ -52,14 +54,15 @@ def parse_md_json(json_text: str) -> list[str]:
|
||||
|
||||
return json_objects, reasoning_content
|
||||
|
||||
|
||||
def calculate_similarity(text1: str, text2: str) -> float:
|
||||
"""
|
||||
计算两个文本的相似度
|
||||
|
||||
|
||||
Args:
|
||||
text1: 第一个文本
|
||||
text2: 第二个文本
|
||||
|
||||
|
||||
Returns:
|
||||
float: 相似度分数 (0-1)
|
||||
"""
|
||||
@@ -67,16 +70,16 @@ def calculate_similarity(text1: str, text2: str) -> float:
|
||||
# 预处理文本
|
||||
text1 = preprocess_text(text1)
|
||||
text2 = preprocess_text(text2)
|
||||
|
||||
|
||||
# 使用SequenceMatcher计算相似度
|
||||
similarity = SequenceMatcher(None, text1, text2).ratio()
|
||||
|
||||
|
||||
# 如果其中一个文本包含另一个,提高相似度
|
||||
if text1 in text2 or text2 in text1:
|
||||
similarity = max(similarity, 0.8)
|
||||
|
||||
|
||||
return similarity
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"计算相似度时出错: {e}")
|
||||
return 0.0
|
||||
@@ -85,31 +88,30 @@ def calculate_similarity(text1: str, text2: str) -> float:
|
||||
def preprocess_text(text: str) -> str:
|
||||
"""
|
||||
预处理文本,提高匹配准确性
|
||||
|
||||
|
||||
Args:
|
||||
text: 原始文本
|
||||
|
||||
|
||||
Returns:
|
||||
str: 预处理后的文本
|
||||
"""
|
||||
try:
|
||||
# 转换为小写
|
||||
text = text.lower()
|
||||
|
||||
|
||||
# 移除标点符号和特殊字符
|
||||
text = re.sub(r'[^\w\s]', '', text)
|
||||
|
||||
text = re.sub(r"[^\w\s]", "", text)
|
||||
|
||||
# 移除多余空格
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
|
||||
text = re.sub(r"\s+", " ", text).strip()
|
||||
|
||||
return text
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"预处理文本时出错: {e}")
|
||||
return text
|
||||
|
||||
|
||||
|
||||
def parse_datetime_to_timestamp(value: str) -> float:
|
||||
"""
|
||||
接受多种常见格式并转换为时间戳(秒)
|
||||
@@ -143,25 +145,24 @@ def parse_datetime_to_timestamp(value: str) -> float:
|
||||
def parse_time_range(time_range: str) -> Tuple[float, float]:
|
||||
"""
|
||||
解析时间范围字符串,返回开始和结束时间戳
|
||||
|
||||
|
||||
Args:
|
||||
time_range: 时间范围字符串,格式:"YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS"
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple[float, float]: (开始时间戳, 结束时间戳)
|
||||
"""
|
||||
if " - " not in time_range:
|
||||
raise ValueError(f"时间范围格式错误,应为 '开始时间 - 结束时间': {time_range}")
|
||||
|
||||
|
||||
parts = time_range.split(" - ", 1)
|
||||
if len(parts) != 2:
|
||||
raise ValueError(f"时间范围格式错误: {time_range}")
|
||||
|
||||
|
||||
start_str = parts[0].strip()
|
||||
end_str = parts[1].strip()
|
||||
|
||||
|
||||
start_timestamp = parse_datetime_to_timestamp(start_str)
|
||||
end_timestamp = parse_datetime_to_timestamp(end_str)
|
||||
|
||||
return start_timestamp, end_timestamp
|
||||
|
||||
return start_timestamp, end_timestamp
|
||||
|
||||
@@ -17,6 +17,7 @@ from .query_lpmm_knowledge import register_tool as register_lpmm_knowledge
|
||||
from .query_person_info import register_tool as register_query_person_info
|
||||
from src.config.config import global_config
|
||||
|
||||
|
||||
def init_all_tools():
|
||||
"""初始化并注册所有记忆检索工具"""
|
||||
register_query_jargon()
|
||||
|
||||
@@ -15,13 +15,10 @@ logger = get_logger("memory_retrieval_tools")
|
||||
|
||||
|
||||
async def query_chat_history(
|
||||
chat_id: str,
|
||||
keyword: Optional[str] = None,
|
||||
time_range: Optional[str] = None,
|
||||
fuzzy: bool = True
|
||||
chat_id: str, keyword: Optional[str] = None, time_range: Optional[str] = None, fuzzy: bool = True
|
||||
) -> str:
|
||||
"""根据时间或关键词在chat_history表中查询聊天记录概述
|
||||
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
keyword: 关键词(可选,支持多个关键词,可用空格、逗号等分隔)
|
||||
@@ -31,7 +28,7 @@ async def query_chat_history(
|
||||
fuzzy: 是否使用模糊匹配模式(默认True)
|
||||
- True: 模糊匹配,只要包含任意一个关键词即匹配(OR关系)
|
||||
- False: 全匹配,必须包含所有关键词才匹配(AND关系)
|
||||
|
||||
|
||||
Returns:
|
||||
str: 查询结果
|
||||
"""
|
||||
@@ -39,10 +36,10 @@ async def query_chat_history(
|
||||
# 检查参数
|
||||
if not keyword and not time_range:
|
||||
return "未指定查询参数(需要提供keyword或time_range之一)"
|
||||
|
||||
|
||||
# 构建查询条件
|
||||
query = ChatHistory.select().where(ChatHistory.chat_id == chat_id)
|
||||
|
||||
|
||||
# 时间过滤条件
|
||||
if time_range:
|
||||
# 判断是时间点还是时间范围
|
||||
@@ -50,79 +47,79 @@ async def query_chat_history(
|
||||
# 时间范围:查询与时间范围有交集的记录
|
||||
start_timestamp, end_timestamp = parse_time_range(time_range)
|
||||
# 交集条件:start_time < end_timestamp AND end_time > start_timestamp
|
||||
time_filter = (
|
||||
(ChatHistory.start_time < end_timestamp) &
|
||||
(ChatHistory.end_time > start_timestamp)
|
||||
)
|
||||
time_filter = (ChatHistory.start_time < end_timestamp) & (ChatHistory.end_time > start_timestamp)
|
||||
else:
|
||||
# 时间点:查询包含该时间点的记录(start_time <= time_point <= end_time)
|
||||
target_timestamp = parse_datetime_to_timestamp(time_range)
|
||||
time_filter = (
|
||||
(ChatHistory.start_time <= target_timestamp) &
|
||||
(ChatHistory.end_time >= target_timestamp)
|
||||
)
|
||||
time_filter = (ChatHistory.start_time <= target_timestamp) & (ChatHistory.end_time >= target_timestamp)
|
||||
query = query.where(time_filter)
|
||||
|
||||
|
||||
# 执行查询
|
||||
records = list(query.order_by(ChatHistory.start_time.desc()).limit(50))
|
||||
|
||||
|
||||
# 如果有关键词,进一步过滤
|
||||
if keyword:
|
||||
# 解析多个关键词(支持空格、逗号等分隔符)
|
||||
keywords_list = parse_keywords_string(keyword)
|
||||
if not keywords_list:
|
||||
keywords_list = [keyword.strip()] if keyword.strip() else []
|
||||
|
||||
|
||||
# 转换为小写以便匹配
|
||||
keywords_lower = [kw.lower() for kw in keywords_list if kw.strip()]
|
||||
|
||||
|
||||
if not keywords_lower:
|
||||
return "关键词为空"
|
||||
|
||||
|
||||
filtered_records = []
|
||||
|
||||
|
||||
for record in records:
|
||||
# 在theme、keywords、summary、original_text中搜索
|
||||
theme = (record.theme or "").lower()
|
||||
summary = (record.summary or "").lower()
|
||||
original_text = (record.original_text or "").lower()
|
||||
|
||||
|
||||
# 解析record中的keywords JSON
|
||||
record_keywords_list = []
|
||||
if record.keywords:
|
||||
try:
|
||||
keywords_data = json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
|
||||
keywords_data = (
|
||||
json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
|
||||
)
|
||||
if isinstance(keywords_data, list):
|
||||
record_keywords_list = [str(k).lower() for k in keywords_data]
|
||||
except (json.JSONDecodeError, TypeError, ValueError):
|
||||
pass
|
||||
|
||||
|
||||
# 根据匹配模式检查关键词
|
||||
matched = False
|
||||
if fuzzy:
|
||||
# 模糊匹配:只要包含任意一个关键词即匹配(OR关系)
|
||||
for kw in keywords_lower:
|
||||
if (kw in theme or
|
||||
kw in summary or
|
||||
kw in original_text or
|
||||
any(kw in k for k in record_keywords_list)):
|
||||
if (
|
||||
kw in theme
|
||||
or kw in summary
|
||||
or kw in original_text
|
||||
or any(kw in k for k in record_keywords_list)
|
||||
):
|
||||
matched = True
|
||||
break
|
||||
else:
|
||||
# 全匹配:必须包含所有关键词才匹配(AND关系)
|
||||
matched = True
|
||||
for kw in keywords_lower:
|
||||
kw_matched = (kw in theme or
|
||||
kw in summary or
|
||||
kw in original_text or
|
||||
any(kw in k for k in record_keywords_list))
|
||||
kw_matched = (
|
||||
kw in theme
|
||||
or kw in summary
|
||||
or kw in original_text
|
||||
or any(kw in k for k in record_keywords_list)
|
||||
)
|
||||
if not kw_matched:
|
||||
matched = False
|
||||
break
|
||||
|
||||
|
||||
if matched:
|
||||
filtered_records.append(record)
|
||||
|
||||
|
||||
if not filtered_records:
|
||||
keywords_str = "、".join(keywords_list)
|
||||
match_mode = "包含任意一个关键词" if fuzzy else "包含所有关键词"
|
||||
@@ -130,9 +127,9 @@ async def query_chat_history(
|
||||
return f"未找到{match_mode}'{keywords_str}'且在指定时间范围内的聊天记录概述"
|
||||
else:
|
||||
return f"未找到{match_mode}'{keywords_str}'的聊天记录概述"
|
||||
|
||||
|
||||
records = filtered_records
|
||||
|
||||
|
||||
# 如果没有记录(可能是时间范围查询但没有匹配的记录)
|
||||
if not records:
|
||||
if time_range:
|
||||
@@ -148,22 +145,23 @@ async def query_chat_history(
|
||||
record.count = (record.count or 0) + 1
|
||||
except Exception as update_error:
|
||||
logger.error(f"更新聊天记录概述计数失败: {update_error}")
|
||||
|
||||
|
||||
# 构建结果文本
|
||||
results = []
|
||||
for record in records_to_use: # 最多返回3条记录
|
||||
result_parts = []
|
||||
|
||||
|
||||
# 添加主题
|
||||
if record.theme:
|
||||
result_parts.append(f"主题:{record.theme}")
|
||||
|
||||
|
||||
# 添加时间范围
|
||||
from datetime import datetime
|
||||
|
||||
start_str = datetime.fromtimestamp(record.start_time).strftime("%Y-%m-%d %H:%M:%S")
|
||||
end_str = datetime.fromtimestamp(record.end_time).strftime("%Y-%m-%d %H:%M:%S")
|
||||
result_parts.append(f"时间:{start_str} - {end_str}")
|
||||
|
||||
|
||||
# 添加概括(优先使用summary,如果没有则使用original_text的前200字符)
|
||||
if record.summary:
|
||||
result_parts.append(f"概括:{record.summary}")
|
||||
@@ -172,18 +170,18 @@ async def query_chat_history(
|
||||
if len(record.original_text) > 200:
|
||||
text_preview += "..."
|
||||
result_parts.append(f"内容:{text_preview}")
|
||||
|
||||
|
||||
results.append("\n".join(result_parts))
|
||||
|
||||
|
||||
if not results:
|
||||
return "未找到相关聊天记录概述"
|
||||
|
||||
|
||||
response_text = "\n\n---\n\n".join(results)
|
||||
if len(records) > len(records_to_use):
|
||||
omitted_count = len(records) - len(records_to_use)
|
||||
response_text += f"\n\n(还有{omitted_count}条历史记录已省略)"
|
||||
return response_text
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"查询聊天历史概述失败: {e}")
|
||||
return f"查询失败: {str(e)}"
|
||||
@@ -199,20 +197,20 @@ def register_tool():
|
||||
"name": "keyword",
|
||||
"type": "string",
|
||||
"description": "关键词(可选,支持多个关键词,可用空格、逗号、斜杠等分隔,如:'麦麦 百度网盘' 或 '麦麦,百度网盘'。用于在主题、关键词、概括、原文中搜索)",
|
||||
"required": False
|
||||
"required": False,
|
||||
},
|
||||
{
|
||||
"name": "time_range",
|
||||
"type": "string",
|
||||
"description": "时间范围或时间点(可选)。格式:'YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS'(时间范围,查询与时间范围有交集的记录)或 'YYYY-MM-DD HH:MM:SS'(时间点,查询包含该时间点的记录)",
|
||||
"required": False
|
||||
"required": False,
|
||||
},
|
||||
{
|
||||
"name": "fuzzy",
|
||||
"type": "boolean",
|
||||
"description": "是否使用模糊匹配模式(默认True)。True表示模糊匹配(只要包含任意一个关键词即匹配,OR关系),False表示全匹配(必须包含所有关键词才匹配,AND关系)",
|
||||
"required": False
|
||||
}
|
||||
"required": False,
|
||||
},
|
||||
],
|
||||
execute_func=query_chat_history
|
||||
execute_func=query_chat_history,
|
||||
)
|
||||
|
||||
@@ -73,5 +73,3 @@ def register_tool():
|
||||
],
|
||||
execute_func=query_lpmm_knowledge,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -14,23 +14,25 @@ logger = get_logger("memory_retrieval_tools")
|
||||
|
||||
def _format_group_nick_names(group_nick_name_field) -> str:
|
||||
"""格式化群昵称信息
|
||||
|
||||
|
||||
Args:
|
||||
group_nick_name_field: 群昵称字段(可能是字符串JSON或None)
|
||||
|
||||
|
||||
Returns:
|
||||
str: 格式化后的群昵称信息字符串
|
||||
"""
|
||||
if not group_nick_name_field:
|
||||
return ""
|
||||
|
||||
|
||||
try:
|
||||
# 解析JSON格式的群昵称列表
|
||||
group_nick_names_data = json.loads(group_nick_name_field) if isinstance(group_nick_name_field, str) else group_nick_name_field
|
||||
|
||||
group_nick_names_data = (
|
||||
json.loads(group_nick_name_field) if isinstance(group_nick_name_field, str) else group_nick_name_field
|
||||
)
|
||||
|
||||
if not isinstance(group_nick_names_data, list) or not group_nick_names_data:
|
||||
return ""
|
||||
|
||||
|
||||
# 格式化群昵称列表
|
||||
group_nick_list = []
|
||||
for item in group_nick_names_data:
|
||||
@@ -41,7 +43,7 @@ def _format_group_nick_names(group_nick_name_field) -> str:
|
||||
elif isinstance(item, str):
|
||||
# 兼容旧格式(如果存在)
|
||||
group_nick_list.append(f" - {item}")
|
||||
|
||||
|
||||
if group_nick_list:
|
||||
return "群昵称:\n" + "\n".join(group_nick_list)
|
||||
return ""
|
||||
@@ -58,10 +60,10 @@ def _format_group_nick_names(group_nick_name_field) -> str:
|
||||
|
||||
async def query_person_info(person_name: str) -> str:
|
||||
"""根据person_name查询用户信息,使用模糊查询
|
||||
|
||||
|
||||
Args:
|
||||
person_name: 用户名称(person_name字段)
|
||||
|
||||
|
||||
Returns:
|
||||
str: 查询结果,包含用户的所有信息
|
||||
"""
|
||||
@@ -69,37 +71,35 @@ async def query_person_info(person_name: str) -> str:
|
||||
person_name = str(person_name).strip()
|
||||
if not person_name:
|
||||
return "用户名称为空"
|
||||
|
||||
|
||||
# 构建查询条件(使用模糊查询)
|
||||
query = PersonInfo.select().where(
|
||||
PersonInfo.person_name.contains(person_name)
|
||||
)
|
||||
|
||||
query = PersonInfo.select().where(PersonInfo.person_name.contains(person_name))
|
||||
|
||||
# 执行查询
|
||||
records = list(query.limit(20)) # 最多返回20条记录
|
||||
|
||||
|
||||
if not records:
|
||||
return f"未找到模糊匹配'{person_name}'的用户信息"
|
||||
|
||||
|
||||
# 区分精确匹配和模糊匹配的结果
|
||||
exact_matches = []
|
||||
fuzzy_matches = []
|
||||
|
||||
|
||||
for record in records:
|
||||
# 检查是否是精确匹配
|
||||
if record.person_name and record.person_name.strip() == person_name:
|
||||
exact_matches.append(record)
|
||||
else:
|
||||
fuzzy_matches.append(record)
|
||||
|
||||
|
||||
# 构建结果文本
|
||||
results = []
|
||||
|
||||
|
||||
# 先处理精确匹配的结果
|
||||
for record in exact_matches:
|
||||
result_parts = []
|
||||
result_parts.append("【精确匹配】") # 标注为精确匹配
|
||||
|
||||
|
||||
# 基本信息
|
||||
if record.person_name:
|
||||
result_parts.append(f"用户名称:{record.person_name}")
|
||||
@@ -111,19 +111,19 @@ async def query_person_info(person_name: str) -> str:
|
||||
result_parts.append(f"平台:{record.platform}")
|
||||
if record.user_id:
|
||||
result_parts.append(f"平台用户ID:{record.user_id}")
|
||||
|
||||
|
||||
# 群昵称信息
|
||||
group_nick_name_str = _format_group_nick_names(getattr(record, "group_nick_name", None))
|
||||
if group_nick_name_str:
|
||||
result_parts.append(group_nick_name_str)
|
||||
|
||||
|
||||
# 名称设定原因
|
||||
if record.name_reason:
|
||||
result_parts.append(f"名称设定原因:{record.name_reason}")
|
||||
|
||||
|
||||
# 认识状态
|
||||
result_parts.append(f"是否已认识:{'是' if record.is_known else '否'}")
|
||||
|
||||
|
||||
# 时间信息
|
||||
if record.know_since:
|
||||
know_since_str = datetime.fromtimestamp(record.know_since).strftime("%Y-%m-%d %H:%M:%S")
|
||||
@@ -133,11 +133,15 @@ async def query_person_info(person_name: str) -> str:
|
||||
result_parts.append(f"最后认识时间:{last_know_str}")
|
||||
if record.know_times:
|
||||
result_parts.append(f"认识次数:{int(record.know_times)}")
|
||||
|
||||
|
||||
# 记忆点(memory_points)
|
||||
if record.memory_points:
|
||||
try:
|
||||
memory_points_data = json.loads(record.memory_points) if isinstance(record.memory_points, str) else record.memory_points
|
||||
memory_points_data = (
|
||||
json.loads(record.memory_points)
|
||||
if isinstance(record.memory_points, str)
|
||||
else record.memory_points
|
||||
)
|
||||
if isinstance(memory_points_data, list) and memory_points_data:
|
||||
# 解析记忆点格式:category:content:weight
|
||||
memory_list = []
|
||||
@@ -151,7 +155,7 @@ async def query_person_info(person_name: str) -> str:
|
||||
memory_list.append(f" - [{category}] {content} (权重: {weight})")
|
||||
else:
|
||||
memory_list.append(f" - {memory_point}")
|
||||
|
||||
|
||||
if memory_list:
|
||||
result_parts.append("记忆点:\n" + "\n".join(memory_list))
|
||||
except (json.JSONDecodeError, TypeError, ValueError) as e:
|
||||
@@ -161,14 +165,14 @@ async def query_person_info(person_name: str) -> str:
|
||||
if len(str(record.memory_points)) > 200:
|
||||
memory_preview += "..."
|
||||
result_parts.append(f"记忆点(原始数据):{memory_preview}")
|
||||
|
||||
|
||||
results.append("\n".join(result_parts))
|
||||
|
||||
|
||||
# 再处理模糊匹配的结果
|
||||
for record in fuzzy_matches:
|
||||
result_parts = []
|
||||
result_parts.append("【模糊匹配】") # 标注为模糊匹配
|
||||
|
||||
|
||||
# 基本信息
|
||||
if record.person_name:
|
||||
result_parts.append(f"用户名称:{record.person_name}")
|
||||
@@ -180,19 +184,19 @@ async def query_person_info(person_name: str) -> str:
|
||||
result_parts.append(f"平台:{record.platform}")
|
||||
if record.user_id:
|
||||
result_parts.append(f"平台用户ID:{record.user_id}")
|
||||
|
||||
|
||||
# 群昵称信息
|
||||
group_nick_name_str = _format_group_nick_names(getattr(record, "group_nick_name", None))
|
||||
if group_nick_name_str:
|
||||
result_parts.append(group_nick_name_str)
|
||||
|
||||
|
||||
# 名称设定原因
|
||||
if record.name_reason:
|
||||
result_parts.append(f"名称设定原因:{record.name_reason}")
|
||||
|
||||
|
||||
# 认识状态
|
||||
result_parts.append(f"是否已认识:{'是' if record.is_known else '否'}")
|
||||
|
||||
|
||||
# 时间信息
|
||||
if record.know_since:
|
||||
know_since_str = datetime.fromtimestamp(record.know_since).strftime("%Y-%m-%d %H:%M:%S")
|
||||
@@ -202,11 +206,15 @@ async def query_person_info(person_name: str) -> str:
|
||||
result_parts.append(f"最后认识时间:{last_know_str}")
|
||||
if record.know_times:
|
||||
result_parts.append(f"认识次数:{int(record.know_times)}")
|
||||
|
||||
|
||||
# 记忆点(memory_points)
|
||||
if record.memory_points:
|
||||
try:
|
||||
memory_points_data = json.loads(record.memory_points) if isinstance(record.memory_points, str) else record.memory_points
|
||||
memory_points_data = (
|
||||
json.loads(record.memory_points)
|
||||
if isinstance(record.memory_points, str)
|
||||
else record.memory_points
|
||||
)
|
||||
if isinstance(memory_points_data, list) and memory_points_data:
|
||||
# 解析记忆点格式:category:content:weight
|
||||
memory_list = []
|
||||
@@ -220,7 +228,7 @@ async def query_person_info(person_name: str) -> str:
|
||||
memory_list.append(f" - [{category}] {content} (权重: {weight})")
|
||||
else:
|
||||
memory_list.append(f" - {memory_point}")
|
||||
|
||||
|
||||
if memory_list:
|
||||
result_parts.append("记忆点:\n" + "\n".join(memory_list))
|
||||
except (json.JSONDecodeError, TypeError, ValueError) as e:
|
||||
@@ -230,20 +238,20 @@ async def query_person_info(person_name: str) -> str:
|
||||
if len(str(record.memory_points)) > 200:
|
||||
memory_preview += "..."
|
||||
result_parts.append(f"记忆点(原始数据):{memory_preview}")
|
||||
|
||||
|
||||
results.append("\n".join(result_parts))
|
||||
|
||||
|
||||
# 组合所有结果
|
||||
if not results:
|
||||
return f"未找到匹配'{person_name}'的用户信息"
|
||||
|
||||
|
||||
response_text = "\n\n---\n\n".join(results)
|
||||
|
||||
|
||||
# 添加统计信息
|
||||
total_count = len(records)
|
||||
exact_count = len(exact_matches)
|
||||
fuzzy_count = len(fuzzy_matches)
|
||||
|
||||
|
||||
# 显示精确匹配和模糊匹配的统计
|
||||
if exact_count > 0 or fuzzy_count > 0:
|
||||
stats_parts = []
|
||||
@@ -257,13 +265,13 @@ async def query_person_info(person_name: str) -> str:
|
||||
response_text = f"找到 {total_count} 条匹配的用户信息:\n\n{response_text}"
|
||||
else:
|
||||
response_text = f"找到用户信息:\n\n{response_text}"
|
||||
|
||||
|
||||
# 如果结果数量达到限制,添加提示
|
||||
if total_count >= 20:
|
||||
response_text += "\n\n(已显示前20条结果,可能还有更多匹配记录)"
|
||||
|
||||
|
||||
return response_text
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"查询用户信息失败: {e}")
|
||||
return f"查询失败: {str(e)}"
|
||||
@@ -275,13 +283,7 @@ def register_tool():
|
||||
name="query_person_info",
|
||||
description="根据查询某个用户的所有信息。名称、昵称、平台、用户ID、qq号、群昵称等",
|
||||
parameters=[
|
||||
{
|
||||
"name": "person_name",
|
||||
"type": "string",
|
||||
"description": "用户名称,用于查询用户信息",
|
||||
"required": True
|
||||
}
|
||||
{"name": "person_name", "type": "string", "description": "用户名称,用于查询用户信息", "required": True}
|
||||
],
|
||||
execute_func=query_person_info
|
||||
execute_func=query_person_info,
|
||||
)
|
||||
|
||||
|
||||
@@ -47,10 +47,10 @@ class MemoryRetrievalTool:
|
||||
async def execute(self, **kwargs) -> str:
|
||||
"""执行工具"""
|
||||
return await self.execute_func(**kwargs)
|
||||
|
||||
|
||||
def get_tool_definition(self) -> Dict[str, Any]:
|
||||
"""获取工具定义,用于LLM function calling
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 工具定义字典,格式与BaseTool一致
|
||||
格式: {"name": str, "description": str, "parameters": List[Tuple]}
|
||||
@@ -58,14 +58,14 @@ class MemoryRetrievalTool:
|
||||
# 转换参数格式为元组列表,格式与BaseTool一致
|
||||
# 格式: [("param_name", ToolParamType, "description", required, enum_values)]
|
||||
param_tuples = []
|
||||
|
||||
|
||||
for param in self.parameters:
|
||||
param_name = param.get("name", "")
|
||||
param_type_str = param.get("type", "string").lower()
|
||||
param_desc = param.get("description", "")
|
||||
is_required = param.get("required", False)
|
||||
enum_values = param.get("enum", None)
|
||||
|
||||
|
||||
# 转换类型字符串到ToolParamType
|
||||
type_mapping = {
|
||||
"string": ToolParamType.STRING,
|
||||
@@ -76,18 +76,14 @@ class MemoryRetrievalTool:
|
||||
"bool": ToolParamType.BOOLEAN,
|
||||
}
|
||||
param_type = type_mapping.get(param_type_str, ToolParamType.STRING)
|
||||
|
||||
|
||||
# 构建参数元组
|
||||
param_tuple = (param_name, param_type, param_desc, is_required, enum_values)
|
||||
param_tuples.append(param_tuple)
|
||||
|
||||
|
||||
# 构建工具定义,格式与BaseTool.get_tool_definition()一致
|
||||
tool_def = {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"parameters": param_tuples
|
||||
}
|
||||
|
||||
tool_def = {"name": self.name, "description": self.description, "parameters": param_tuples}
|
||||
|
||||
return tool_def
|
||||
|
||||
|
||||
@@ -126,10 +122,10 @@ class MemoryRetrievalToolRegistry:
|
||||
action_types.append("final_answer")
|
||||
action_types.append("no_answer")
|
||||
return " 或 ".join([f'"{at}"' for at in action_types])
|
||||
|
||||
|
||||
def get_tool_definitions(self) -> List[Dict[str, Any]]:
|
||||
"""获取所有工具的定义列表,用于LLM function calling
|
||||
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 工具定义列表,每个元素是一个工具定义字典
|
||||
"""
|
||||
|
||||
@@ -162,7 +162,12 @@ def levenshtein_distance(s1: str, s2: str) -> int:
|
||||
class Person:
|
||||
@classmethod
|
||||
def register_person(
|
||||
cls, platform: str, user_id: str, nickname: str, group_id: Optional[str] = None, group_nick_name: Optional[str] = None
|
||||
cls,
|
||||
platform: str,
|
||||
user_id: str,
|
||||
nickname: str,
|
||||
group_id: Optional[str] = None,
|
||||
group_nick_name: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
注册新用户的类方法
|
||||
@@ -727,7 +732,7 @@ person_info_manager = PersonInfoManager()
|
||||
|
||||
async def store_person_memory_from_answer(person_name: str, memory_content: str, chat_id: str) -> None:
|
||||
"""将人物信息存入person_info的memory_points
|
||||
|
||||
|
||||
Args:
|
||||
person_name: 人物名称
|
||||
memory_content: 记忆内容
|
||||
@@ -739,13 +744,13 @@ async def store_person_memory_from_answer(person_name: str, memory_content: str,
|
||||
if not chat_stream:
|
||||
logger.warning(f"无法获取chat_stream for chat_id: {chat_id}")
|
||||
return
|
||||
|
||||
|
||||
platform = chat_stream.platform
|
||||
|
||||
|
||||
# 尝试从person_name查找person_id
|
||||
# 首先尝试通过person_name查找
|
||||
person_id = get_person_id_by_person_name(person_name)
|
||||
|
||||
|
||||
if not person_id:
|
||||
# 如果通过person_name找不到,尝试从chat_stream获取user_info
|
||||
if chat_stream.user_info:
|
||||
@@ -754,25 +759,25 @@ async def store_person_memory_from_answer(person_name: str, memory_content: str,
|
||||
else:
|
||||
logger.warning(f"无法确定person_id for person_name: {person_name}, chat_id: {chat_id}")
|
||||
return
|
||||
|
||||
|
||||
# 创建或获取Person对象
|
||||
person = Person(person_id=person_id)
|
||||
|
||||
|
||||
if not person.is_known:
|
||||
logger.warning(f"用户 {person_name} (person_id: {person_id}) 尚未认识,无法存储记忆")
|
||||
return
|
||||
|
||||
|
||||
# 确定记忆分类(可以根据memory_content判断,这里使用通用分类)
|
||||
category = "其他" # 默认分类,可以根据需要调整
|
||||
|
||||
|
||||
# 记忆点格式:category:content:weight
|
||||
weight = "1.0" # 默认权重
|
||||
memory_point = f"{category}:{memory_content}:{weight}"
|
||||
|
||||
|
||||
# 添加到memory_points
|
||||
if not person.memory_points:
|
||||
person.memory_points = []
|
||||
|
||||
|
||||
# 检查是否已存在相似的记忆点(避免重复)
|
||||
is_duplicate = False
|
||||
for existing_point in person.memory_points:
|
||||
@@ -781,16 +786,20 @@ async def store_person_memory_from_answer(person_name: str, memory_content: str,
|
||||
if len(parts) >= 2:
|
||||
existing_content = parts[1].strip()
|
||||
# 简单相似度检查(如果内容相同或非常相似,则跳过)
|
||||
if existing_content == memory_content or memory_content in existing_content or existing_content in memory_content:
|
||||
if (
|
||||
existing_content == memory_content
|
||||
or memory_content in existing_content
|
||||
or existing_content in memory_content
|
||||
):
|
||||
is_duplicate = True
|
||||
break
|
||||
|
||||
|
||||
if not is_duplicate:
|
||||
person.memory_points.append(memory_point)
|
||||
person.sync_to_database()
|
||||
logger.info(f"成功添加记忆点到 {person_name} (person_id: {person_id}): {memory_point}")
|
||||
else:
|
||||
logger.debug(f"记忆点已存在,跳过: {memory_point}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"存储人物记忆失败: {e}")
|
||||
|
||||
@@ -124,7 +124,6 @@ class ToolExecutor:
|
||||
response, (reasoning_content, model_name, tool_calls) = await self.llm_model.generate_response_async(
|
||||
prompt=prompt, tools=tools, raise_when_empty=False
|
||||
)
|
||||
|
||||
|
||||
# 执行工具调用
|
||||
tool_results, used_tools = await self.execute_tool_calls(tool_calls)
|
||||
|
||||
@@ -102,13 +102,13 @@ class EmojiAction(BaseAction):
|
||||
|
||||
# 5. 调用LLM
|
||||
models = llm_api.get_available_models()
|
||||
chat_model_config = models.get("replyer") # 使用字典访问方式
|
||||
chat_model_config = models.get("utils") # 使用字典访问方式
|
||||
if not chat_model_config:
|
||||
logger.error(f"{self.log_prefix} 未找到'replyer'模型配置,无法调用LLM")
|
||||
return False, "未找到'replyer'模型配置"
|
||||
logger.error(f"{self.log_prefix} 未找到'utils'模型配置,无法调用LLM")
|
||||
return False, "未找到'utils'模型配置"
|
||||
|
||||
success, chosen_emotion, _, _ = await llm_api.generate_with_model(
|
||||
prompt, model_config=chat_model_config, request_type="emoji"
|
||||
prompt, model_config=chat_model_config, request_type="emoji.select"
|
||||
)
|
||||
|
||||
if not success:
|
||||
|
||||
@@ -51,7 +51,7 @@ def _update_dict_preserve_comments(target: Any, source: Any) -> None:
|
||||
"""
|
||||
递归合并字典,保留 target 中的注释和格式
|
||||
将 source 的值更新到 target 中(仅更新已存在的键)
|
||||
|
||||
|
||||
Args:
|
||||
target: 目标字典(tomlkit 对象,包含注释)
|
||||
source: 源字典(普通 dict 或 list)
|
||||
@@ -59,7 +59,7 @@ def _update_dict_preserve_comments(target: Any, source: Any) -> None:
|
||||
# 如果 source 是列表,直接替换(数组表没有注释保留的意义)
|
||||
if isinstance(source, list):
|
||||
return # 调用者需要直接赋值
|
||||
|
||||
|
||||
# 如果都是字典,递归合并
|
||||
if isinstance(source, dict) and isinstance(target, dict):
|
||||
for key, value in source.items():
|
||||
@@ -319,6 +319,58 @@ async def update_bot_config_section(section_name: str, section_data: Any = Body(
|
||||
raise HTTPException(status_code=500, detail=f"更新配置节失败: {str(e)}")
|
||||
|
||||
|
||||
# ===== 原始 TOML 文件操作接口 =====
|
||||
|
||||
|
||||
@router.get("/bot/raw")
|
||||
async def get_bot_config_raw():
|
||||
"""获取麦麦主程序配置的原始 TOML 内容"""
|
||||
try:
|
||||
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
|
||||
if not os.path.exists(config_path):
|
||||
raise HTTPException(status_code=404, detail="配置文件不存在")
|
||||
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
raw_content = f.read()
|
||||
|
||||
return {"success": True, "content": raw_content}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"读取配置文件失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/bot/raw")
|
||||
async def update_bot_config_raw(raw_content: str = Body(..., embed=True)):
|
||||
"""更新麦麦主程序配置(直接保存原始 TOML 内容,会先验证格式)"""
|
||||
try:
|
||||
# 验证 TOML 格式
|
||||
try:
|
||||
config_data = tomlkit.loads(raw_content)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}")
|
||||
|
||||
# 验证配置数据结构
|
||||
try:
|
||||
Config.from_dict(config_data)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}")
|
||||
|
||||
# 保存配置文件
|
||||
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
f.write(raw_content)
|
||||
|
||||
logger.info("麦麦主程序配置已更新(原始模式)")
|
||||
return {"success": True, "message": "配置已保存"}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"保存配置文件失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/model/section/{section_name}")
|
||||
async def update_model_config_section(section_name: str, section_data: Any = Body(...)):
|
||||
"""更新模型配置的指定节(保留注释和格式)"""
|
||||
@@ -364,3 +416,144 @@ async def update_model_config_section(section_name: str, section_data: Any = Bod
|
||||
except Exception as e:
|
||||
logger.error(f"更新配置节失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"更新配置节失败: {str(e)}")
|
||||
|
||||
|
||||
# ===== 适配器配置管理接口 =====
|
||||
|
||||
|
||||
@router.get("/adapter-config/path")
|
||||
async def get_adapter_config_path():
|
||||
"""获取保存的适配器配置文件路径"""
|
||||
try:
|
||||
# 从 data/webui.json 读取路径偏好
|
||||
webui_data_path = os.path.join("data", "webui.json")
|
||||
if not os.path.exists(webui_data_path):
|
||||
return {"success": True, "path": None}
|
||||
|
||||
import json
|
||||
with open(webui_data_path, "r", encoding="utf-8") as f:
|
||||
webui_data = json.load(f)
|
||||
|
||||
adapter_config_path = webui_data.get("adapter_config_path")
|
||||
if not adapter_config_path:
|
||||
return {"success": True, "path": None}
|
||||
|
||||
# 检查文件是否存在并返回最后修改时间
|
||||
if os.path.exists(adapter_config_path):
|
||||
import datetime
|
||||
mtime = os.path.getmtime(adapter_config_path)
|
||||
last_modified = datetime.datetime.fromtimestamp(mtime).isoformat()
|
||||
return {"success": True, "path": adapter_config_path, "lastModified": last_modified}
|
||||
else:
|
||||
return {"success": True, "path": adapter_config_path, "lastModified": None}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取适配器配置路径失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取配置路径失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/adapter-config/path")
|
||||
async def save_adapter_config_path(data: dict[str, str] = Body(...)):
|
||||
"""保存适配器配置文件路径偏好"""
|
||||
try:
|
||||
path = data.get("path")
|
||||
if not path:
|
||||
raise HTTPException(status_code=400, detail="路径不能为空")
|
||||
|
||||
# 保存到 data/webui.json
|
||||
webui_data_path = os.path.join("data", "webui.json")
|
||||
import json
|
||||
|
||||
# 读取现有数据
|
||||
if os.path.exists(webui_data_path):
|
||||
with open(webui_data_path, "r", encoding="utf-8") as f:
|
||||
webui_data = json.load(f)
|
||||
else:
|
||||
webui_data = {}
|
||||
|
||||
# 更新路径
|
||||
webui_data["adapter_config_path"] = path
|
||||
|
||||
# 保存
|
||||
os.makedirs("data", exist_ok=True)
|
||||
with open(webui_data_path, "w", encoding="utf-8") as f:
|
||||
json.dump(webui_data, f, ensure_ascii=False, indent=2)
|
||||
|
||||
logger.info(f"适配器配置路径已保存: {path}")
|
||||
return {"success": True, "message": "路径已保存"}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"保存适配器配置路径失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"保存路径失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/adapter-config")
|
||||
async def get_adapter_config(path: str):
|
||||
"""从指定路径读取适配器配置文件"""
|
||||
try:
|
||||
if not path:
|
||||
raise HTTPException(status_code=400, detail="路径参数不能为空")
|
||||
|
||||
# 检查文件是否存在
|
||||
if not os.path.exists(path):
|
||||
raise HTTPException(status_code=404, detail=f"配置文件不存在: {path}")
|
||||
|
||||
# 检查文件扩展名
|
||||
if not path.endswith(".toml"):
|
||||
raise HTTPException(status_code=400, detail="只支持 .toml 格式的配置文件")
|
||||
|
||||
# 读取文件内容
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
|
||||
logger.info(f"已读取适配器配置: {path}")
|
||||
return {"success": True, "content": content}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"读取适配器配置失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"读取配置失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/adapter-config")
|
||||
async def save_adapter_config(data: dict[str, str] = Body(...)):
|
||||
"""保存适配器配置到指定路径"""
|
||||
try:
|
||||
path = data.get("path")
|
||||
content = data.get("content")
|
||||
|
||||
if not path:
|
||||
raise HTTPException(status_code=400, detail="路径不能为空")
|
||||
if content is None:
|
||||
raise HTTPException(status_code=400, detail="配置内容不能为空")
|
||||
|
||||
# 检查文件扩展名
|
||||
if not path.endswith(".toml"):
|
||||
raise HTTPException(status_code=400, detail="只支持 .toml 格式的配置文件")
|
||||
|
||||
# 验证 TOML 格式
|
||||
try:
|
||||
import toml
|
||||
toml.loads(content)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}")
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(os.path.dirname(path) if os.path.dirname(path) else ".", exist_ok=True)
|
||||
|
||||
# 保存文件
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
f.write(content)
|
||||
|
||||
logger.info(f"适配器配置已保存: {path}")
|
||||
return {"success": True, "message": "配置已保存"}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"保存适配器配置失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"保存配置失败: {str(e)}")
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
"""表情包管理 API 路由"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Header, Query
|
||||
from fastapi.responses import FileResponse
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, List
|
||||
from src.common.logger import get_logger
|
||||
@@ -7,6 +9,7 @@ from src.common.database.database_model import Emoji
|
||||
from .token_manager import get_token_manager
|
||||
import json
|
||||
import time
|
||||
import os
|
||||
|
||||
logger = get_logger("webui.emoji")
|
||||
|
||||
@@ -16,6 +19,7 @@ router = APIRouter(prefix="/emoji", tags=["Emoji"])
|
||||
|
||||
class EmojiResponse(BaseModel):
|
||||
"""表情包响应"""
|
||||
|
||||
id: int
|
||||
full_path: str
|
||||
format: str
|
||||
@@ -24,7 +28,7 @@ class EmojiResponse(BaseModel):
|
||||
query_count: int
|
||||
is_registered: bool
|
||||
is_banned: bool
|
||||
emotion: Optional[List[str]] # 解析后的 JSON
|
||||
emotion: Optional[str] # 直接返回字符串
|
||||
record_time: float
|
||||
register_time: Optional[float]
|
||||
usage_count: int
|
||||
@@ -33,6 +37,7 @@ class EmojiResponse(BaseModel):
|
||||
|
||||
class EmojiListResponse(BaseModel):
|
||||
"""表情包列表响应"""
|
||||
|
||||
success: bool
|
||||
total: int
|
||||
page: int
|
||||
@@ -42,20 +47,23 @@ class EmojiListResponse(BaseModel):
|
||||
|
||||
class EmojiDetailResponse(BaseModel):
|
||||
"""表情包详情响应"""
|
||||
|
||||
success: bool
|
||||
data: EmojiResponse
|
||||
|
||||
|
||||
class EmojiUpdateRequest(BaseModel):
|
||||
"""表情包更新请求"""
|
||||
|
||||
description: Optional[str] = None
|
||||
is_registered: Optional[bool] = None
|
||||
is_banned: Optional[bool] = None
|
||||
emotion: Optional[List[str]] = None
|
||||
emotion: Optional[str] = None
|
||||
|
||||
|
||||
class EmojiUpdateResponse(BaseModel):
|
||||
"""表情包更新响应"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
data: Optional[EmojiResponse] = None
|
||||
@@ -63,34 +71,41 @@ class EmojiUpdateResponse(BaseModel):
|
||||
|
||||
class EmojiDeleteResponse(BaseModel):
|
||||
"""表情包删除响应"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
|
||||
|
||||
class BatchDeleteRequest(BaseModel):
|
||||
"""批量删除请求"""
|
||||
|
||||
emoji_ids: List[int]
|
||||
|
||||
|
||||
class BatchDeleteResponse(BaseModel):
|
||||
"""批量删除响应"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
deleted_count: int
|
||||
failed_count: int
|
||||
failed_ids: List[int] = []
|
||||
|
||||
|
||||
def verify_auth_token(authorization: Optional[str]) -> bool:
|
||||
"""验证认证 Token"""
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
|
||||
token = authorization.replace("Bearer ", "")
|
||||
token_manager = get_token_manager()
|
||||
|
||||
|
||||
if not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="Token 无效或已过期")
|
||||
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def parse_emotion(emotion_str: Optional[str]) -> Optional[List[str]]:
|
||||
"""解析情感标签 JSON 字符串"""
|
||||
if not emotion_str:
|
||||
return None
|
||||
try:
|
||||
return json.loads(emotion_str)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return None
|
||||
|
||||
|
||||
def emoji_to_response(emoji: Emoji) -> EmojiResponse:
|
||||
"""将 Emoji 模型转换为响应对象"""
|
||||
return EmojiResponse(
|
||||
@@ -102,7 +117,7 @@ def emoji_to_response(emoji: Emoji) -> EmojiResponse:
|
||||
query_count=emoji.query_count,
|
||||
is_registered=emoji.is_registered,
|
||||
is_banned=emoji.is_banned,
|
||||
emotion=parse_emotion(emoji.emotion),
|
||||
emotion=str(emoji.emotion) if emoji.emotion is not None else None,
|
||||
record_time=emoji.record_time,
|
||||
register_time=emoji.register_time,
|
||||
usage_count=emoji.usage_count,
|
||||
@@ -118,11 +133,11 @@ async def get_emoji_list(
|
||||
is_registered: Optional[bool] = Query(None, description="是否已注册筛选"),
|
||||
is_banned: Optional[bool] = Query(None, description="是否被禁用筛选"),
|
||||
format: Optional[str] = Query(None, description="格式筛选"),
|
||||
authorization: Optional[str] = Header(None)
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
获取表情包列表
|
||||
|
||||
|
||||
Args:
|
||||
page: 页码 (从 1 开始)
|
||||
page_size: 每页数量 (1-100)
|
||||
@@ -131,61 +146,51 @@ async def get_emoji_list(
|
||||
is_banned: 是否被禁用筛选
|
||||
format: 格式筛选
|
||||
authorization: Authorization header
|
||||
|
||||
|
||||
Returns:
|
||||
表情包列表
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
|
||||
|
||||
# 构建查询
|
||||
query = Emoji.select()
|
||||
|
||||
|
||||
# 搜索过滤
|
||||
if search:
|
||||
query = query.where(
|
||||
(Emoji.description.contains(search)) |
|
||||
(Emoji.emoji_hash.contains(search))
|
||||
)
|
||||
|
||||
query = query.where((Emoji.description.contains(search)) | (Emoji.emoji_hash.contains(search)))
|
||||
|
||||
# 注册状态过滤
|
||||
if is_registered is not None:
|
||||
query = query.where(Emoji.is_registered == is_registered)
|
||||
|
||||
|
||||
# 禁用状态过滤
|
||||
if is_banned is not None:
|
||||
query = query.where(Emoji.is_banned == is_banned)
|
||||
|
||||
|
||||
# 格式过滤
|
||||
if format:
|
||||
query = query.where(Emoji.format == format)
|
||||
|
||||
|
||||
# 排序:使用次数倒序,然后按记录时间倒序
|
||||
from peewee import Case
|
||||
|
||||
query = query.order_by(
|
||||
Emoji.usage_count.desc(),
|
||||
Case(None, [(Emoji.record_time.is_null(), 1)], 0),
|
||||
Emoji.record_time.desc()
|
||||
Emoji.usage_count.desc(), Case(None, [(Emoji.record_time.is_null(), 1)], 0), Emoji.record_time.desc()
|
||||
)
|
||||
|
||||
|
||||
# 获取总数
|
||||
total = query.count()
|
||||
|
||||
|
||||
# 分页
|
||||
offset = (page - 1) * page_size
|
||||
emojis = query.offset(offset).limit(page_size)
|
||||
|
||||
|
||||
# 转换为响应对象
|
||||
data = [emoji_to_response(emoji) for emoji in emojis]
|
||||
|
||||
return EmojiListResponse(
|
||||
success=True,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
data=data
|
||||
)
|
||||
|
||||
|
||||
return EmojiListResponse(success=True, total=total, page=page, page_size=page_size, data=data)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -194,33 +199,27 @@ async def get_emoji_list(
|
||||
|
||||
|
||||
@router.get("/{emoji_id}", response_model=EmojiDetailResponse)
|
||||
async def get_emoji_detail(
|
||||
emoji_id: int,
|
||||
authorization: Optional[str] = Header(None)
|
||||
):
|
||||
async def get_emoji_detail(emoji_id: int, authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
获取表情包详细信息
|
||||
|
||||
|
||||
Args:
|
||||
emoji_id: 表情包ID
|
||||
authorization: Authorization header
|
||||
|
||||
|
||||
Returns:
|
||||
表情包详细信息
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
|
||||
|
||||
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
|
||||
|
||||
|
||||
if not emoji:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
|
||||
|
||||
return EmojiDetailResponse(
|
||||
success=True,
|
||||
data=emoji_to_response(emoji)
|
||||
)
|
||||
|
||||
|
||||
return EmojiDetailResponse(success=True, data=emoji_to_response(emoji))
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -229,61 +228,50 @@ async def get_emoji_detail(
|
||||
|
||||
|
||||
@router.patch("/{emoji_id}", response_model=EmojiUpdateResponse)
|
||||
async def update_emoji(
|
||||
emoji_id: int,
|
||||
request: EmojiUpdateRequest,
|
||||
authorization: Optional[str] = Header(None)
|
||||
):
|
||||
async def update_emoji(emoji_id: int, request: EmojiUpdateRequest, authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
增量更新表情包(只更新提供的字段)
|
||||
|
||||
|
||||
Args:
|
||||
emoji_id: 表情包ID
|
||||
request: 更新请求(只包含需要更新的字段)
|
||||
authorization: Authorization header
|
||||
|
||||
|
||||
Returns:
|
||||
更新结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
|
||||
|
||||
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
|
||||
|
||||
|
||||
if not emoji:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
|
||||
|
||||
|
||||
# 只更新提供的字段
|
||||
update_data = request.model_dump(exclude_unset=True)
|
||||
|
||||
|
||||
if not update_data:
|
||||
raise HTTPException(status_code=400, detail="未提供任何需要更新的字段")
|
||||
|
||||
# 处理情感标签(转换为 JSON)
|
||||
if 'emotion' in update_data:
|
||||
if update_data['emotion'] is None:
|
||||
update_data['emotion'] = None
|
||||
else:
|
||||
update_data['emotion'] = json.dumps(update_data['emotion'], ensure_ascii=False)
|
||||
|
||||
|
||||
# emotion 字段直接使用字符串,无需转换
|
||||
|
||||
# 如果注册状态从 False 变为 True,记录注册时间
|
||||
if 'is_registered' in update_data and update_data['is_registered'] and not emoji.is_registered:
|
||||
update_data['register_time'] = time.time()
|
||||
|
||||
if "is_registered" in update_data and update_data["is_registered"] and not emoji.is_registered:
|
||||
update_data["register_time"] = time.time()
|
||||
|
||||
# 执行更新
|
||||
for field, value in update_data.items():
|
||||
setattr(emoji, field, value)
|
||||
|
||||
|
||||
emoji.save()
|
||||
|
||||
|
||||
logger.info(f"表情包已更新: ID={emoji_id}, 字段: {list(update_data.keys())}")
|
||||
|
||||
|
||||
return EmojiUpdateResponse(
|
||||
success=True,
|
||||
message=f"成功更新 {len(update_data)} 个字段",
|
||||
data=emoji_to_response(emoji)
|
||||
success=True, message=f"成功更新 {len(update_data)} 个字段", data=emoji_to_response(emoji)
|
||||
)
|
||||
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -292,41 +280,35 @@ async def update_emoji(
|
||||
|
||||
|
||||
@router.delete("/{emoji_id}", response_model=EmojiDeleteResponse)
|
||||
async def delete_emoji(
|
||||
emoji_id: int,
|
||||
authorization: Optional[str] = Header(None)
|
||||
):
|
||||
async def delete_emoji(emoji_id: int, authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
删除表情包
|
||||
|
||||
|
||||
Args:
|
||||
emoji_id: 表情包ID
|
||||
authorization: Authorization header
|
||||
|
||||
|
||||
Returns:
|
||||
删除结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
|
||||
|
||||
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
|
||||
|
||||
|
||||
if not emoji:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
|
||||
|
||||
|
||||
# 记录删除信息
|
||||
emoji_hash = emoji.emoji_hash
|
||||
|
||||
|
||||
# 执行删除
|
||||
emoji.delete_instance()
|
||||
|
||||
|
||||
logger.info(f"表情包已删除: ID={emoji_id}, hash={emoji_hash}")
|
||||
|
||||
return EmojiDeleteResponse(
|
||||
success=True,
|
||||
message=f"成功删除表情包: {emoji_hash}"
|
||||
)
|
||||
|
||||
|
||||
return EmojiDeleteResponse(success=True, message=f"成功删除表情包: {emoji_hash}")
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -335,31 +317,29 @@ async def delete_emoji(
|
||||
|
||||
|
||||
@router.get("/stats/summary")
|
||||
async def get_emoji_stats(
|
||||
authorization: Optional[str] = Header(None)
|
||||
):
|
||||
async def get_emoji_stats(authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
获取表情包统计数据
|
||||
|
||||
|
||||
Args:
|
||||
authorization: Authorization header
|
||||
|
||||
|
||||
Returns:
|
||||
统计数据
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
|
||||
|
||||
total = Emoji.select().count()
|
||||
registered = Emoji.select().where(Emoji.is_registered).count()
|
||||
banned = Emoji.select().where(Emoji.is_banned).count()
|
||||
|
||||
|
||||
# 按格式统计
|
||||
formats = {}
|
||||
for emoji in Emoji.select(Emoji.format):
|
||||
fmt = emoji.format
|
||||
formats[fmt] = formats.get(fmt, 0) + 1
|
||||
|
||||
|
||||
# 获取最常用的表情包(前10)
|
||||
top_used = Emoji.select().order_by(Emoji.usage_count.desc()).limit(10)
|
||||
top_used_list = [
|
||||
@@ -367,11 +347,11 @@ async def get_emoji_stats(
|
||||
"id": emoji.id,
|
||||
"emoji_hash": emoji.emoji_hash,
|
||||
"description": emoji.description,
|
||||
"usage_count": emoji.usage_count
|
||||
"usage_count": emoji.usage_count,
|
||||
}
|
||||
for emoji in top_used
|
||||
]
|
||||
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"data": {
|
||||
@@ -380,10 +360,10 @@ async def get_emoji_stats(
|
||||
"banned": banned,
|
||||
"unregistered": total - registered,
|
||||
"formats": formats,
|
||||
"top_used": top_used_list
|
||||
}
|
||||
"top_used": top_used_list,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -392,47 +372,40 @@ async def get_emoji_stats(
|
||||
|
||||
|
||||
@router.post("/{emoji_id}/register", response_model=EmojiUpdateResponse)
|
||||
async def register_emoji(
|
||||
emoji_id: int,
|
||||
authorization: Optional[str] = Header(None)
|
||||
):
|
||||
async def register_emoji(emoji_id: int, authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
注册表情包(快捷操作)
|
||||
|
||||
|
||||
Args:
|
||||
emoji_id: 表情包ID
|
||||
authorization: Authorization header
|
||||
|
||||
|
||||
Returns:
|
||||
更新结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
|
||||
|
||||
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
|
||||
|
||||
|
||||
if not emoji:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
|
||||
|
||||
|
||||
if emoji.is_registered:
|
||||
raise HTTPException(status_code=400, detail="该表情包已经注册")
|
||||
|
||||
|
||||
if emoji.is_banned:
|
||||
raise HTTPException(status_code=400, detail="该表情包已被禁用,无法注册")
|
||||
|
||||
|
||||
# 注册表情包
|
||||
emoji.is_registered = True
|
||||
emoji.register_time = time.time()
|
||||
emoji.save()
|
||||
|
||||
|
||||
logger.info(f"表情包已注册: ID={emoji_id}")
|
||||
|
||||
return EmojiUpdateResponse(
|
||||
success=True,
|
||||
message="表情包注册成功",
|
||||
data=emoji_to_response(emoji)
|
||||
)
|
||||
|
||||
|
||||
return EmojiUpdateResponse(success=True, message="表情包注册成功", data=emoji_to_response(emoji))
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -441,43 +414,149 @@ async def register_emoji(
|
||||
|
||||
|
||||
@router.post("/{emoji_id}/ban", response_model=EmojiUpdateResponse)
|
||||
async def ban_emoji(
|
||||
emoji_id: int,
|
||||
authorization: Optional[str] = Header(None)
|
||||
):
|
||||
async def ban_emoji(emoji_id: int, authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
禁用表情包(快捷操作)
|
||||
|
||||
|
||||
Args:
|
||||
emoji_id: 表情包ID
|
||||
authorization: Authorization header
|
||||
|
||||
|
||||
Returns:
|
||||
更新结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
|
||||
|
||||
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
|
||||
|
||||
|
||||
if not emoji:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
|
||||
|
||||
|
||||
# 禁用表情包(同时取消注册)
|
||||
emoji.is_banned = True
|
||||
emoji.is_registered = False
|
||||
emoji.save()
|
||||
|
||||
|
||||
logger.info(f"表情包已禁用: ID={emoji_id}")
|
||||
|
||||
return EmojiUpdateResponse(
|
||||
success=True,
|
||||
message="表情包禁用成功",
|
||||
data=emoji_to_response(emoji)
|
||||
)
|
||||
|
||||
|
||||
return EmojiUpdateResponse(success=True, message="表情包禁用成功", data=emoji_to_response(emoji))
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"禁用表情包失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"禁用表情包失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/{emoji_id}/thumbnail")
|
||||
async def get_emoji_thumbnail(
|
||||
emoji_id: int,
|
||||
token: Optional[str] = Query(None, description="访问令牌"),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
获取表情包缩略图
|
||||
|
||||
Args:
|
||||
emoji_id: 表情包ID
|
||||
token: 访问令牌(通过 query parameter)
|
||||
authorization: Authorization header
|
||||
|
||||
Returns:
|
||||
表情包图片文件
|
||||
"""
|
||||
try:
|
||||
# 优先使用 query parameter 中的 token(用于 img 标签)
|
||||
if token:
|
||||
token_manager = get_token_manager()
|
||||
if not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="Token 无效或已过期")
|
||||
else:
|
||||
# 如果没有 query token,则验证 Authorization header
|
||||
verify_auth_token(authorization)
|
||||
|
||||
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
|
||||
|
||||
if not emoji:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
|
||||
|
||||
# 检查文件是否存在
|
||||
if not os.path.exists(emoji.full_path):
|
||||
raise HTTPException(status_code=404, detail="表情包文件不存在")
|
||||
|
||||
# 根据格式设置 MIME 类型
|
||||
mime_types = {
|
||||
"png": "image/png",
|
||||
"jpg": "image/jpeg",
|
||||
"jpeg": "image/jpeg",
|
||||
"gif": "image/gif",
|
||||
"webp": "image/webp",
|
||||
"bmp": "image/bmp",
|
||||
}
|
||||
|
||||
media_type = mime_types.get(emoji.format.lower(), "application/octet-stream")
|
||||
|
||||
return FileResponse(path=emoji.full_path, media_type=media_type, filename=f"{emoji.emoji_hash}.{emoji.format}")
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"获取表情包缩略图失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取表情包缩略图失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post("/batch/delete", response_model=BatchDeleteResponse)
|
||||
async def batch_delete_emojis(request: BatchDeleteRequest, authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
批量删除表情包
|
||||
|
||||
Args:
|
||||
request: 包含emoji_ids列表的请求
|
||||
authorization: Authorization header
|
||||
|
||||
Returns:
|
||||
批量删除结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
|
||||
if not request.emoji_ids:
|
||||
raise HTTPException(status_code=400, detail="未提供要删除的表情包ID")
|
||||
|
||||
deleted_count = 0
|
||||
failed_count = 0
|
||||
failed_ids = []
|
||||
|
||||
for emoji_id in request.emoji_ids:
|
||||
try:
|
||||
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
|
||||
if emoji:
|
||||
emoji.delete_instance()
|
||||
deleted_count += 1
|
||||
logger.info(f"批量删除表情包: {emoji_id}")
|
||||
else:
|
||||
failed_count += 1
|
||||
failed_ids.append(emoji_id)
|
||||
except Exception as e:
|
||||
logger.error(f"删除表情包 {emoji_id} 失败: {e}")
|
||||
failed_count += 1
|
||||
failed_ids.append(emoji_id)
|
||||
|
||||
message = f"成功删除 {deleted_count} 个表情包"
|
||||
if failed_count > 0:
|
||||
message += f",{failed_count} 个失败"
|
||||
|
||||
return BatchDeleteResponse(
|
||||
success=True,
|
||||
message=message,
|
||||
deleted_count=deleted_count,
|
||||
failed_count=failed_count,
|
||||
failed_ids=failed_ids,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"批量删除表情包失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"批量删除失败: {str(e)}") from e
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""表达方式管理 API 路由"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Header, Query
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, List
|
||||
@@ -15,6 +16,7 @@ router = APIRouter(prefix="/expression", tags=["Expression"])
|
||||
|
||||
class ExpressionResponse(BaseModel):
|
||||
"""表达方式响应"""
|
||||
|
||||
id: int
|
||||
situation: str
|
||||
style: str
|
||||
@@ -27,6 +29,7 @@ class ExpressionResponse(BaseModel):
|
||||
|
||||
class ExpressionListResponse(BaseModel):
|
||||
"""表达方式列表响应"""
|
||||
|
||||
success: bool
|
||||
total: int
|
||||
page: int
|
||||
@@ -36,12 +39,14 @@ class ExpressionListResponse(BaseModel):
|
||||
|
||||
class ExpressionDetailResponse(BaseModel):
|
||||
"""表达方式详情响应"""
|
||||
|
||||
success: bool
|
||||
data: ExpressionResponse
|
||||
|
||||
|
||||
class ExpressionCreateRequest(BaseModel):
|
||||
"""表达方式创建请求"""
|
||||
|
||||
situation: str
|
||||
style: str
|
||||
context: Optional[str] = None
|
||||
@@ -51,6 +56,7 @@ class ExpressionCreateRequest(BaseModel):
|
||||
|
||||
class ExpressionUpdateRequest(BaseModel):
|
||||
"""表达方式更新请求"""
|
||||
|
||||
situation: Optional[str] = None
|
||||
style: Optional[str] = None
|
||||
context: Optional[str] = None
|
||||
@@ -60,6 +66,7 @@ class ExpressionUpdateRequest(BaseModel):
|
||||
|
||||
class ExpressionUpdateResponse(BaseModel):
|
||||
"""表达方式更新响应"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
data: Optional[ExpressionResponse] = None
|
||||
@@ -67,12 +74,14 @@ class ExpressionUpdateResponse(BaseModel):
|
||||
|
||||
class ExpressionDeleteResponse(BaseModel):
|
||||
"""表达方式删除响应"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
|
||||
|
||||
class ExpressionCreateResponse(BaseModel):
|
||||
"""表达方式创建响应"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
data: ExpressionResponse
|
||||
@@ -82,13 +91,13 @@ def verify_auth_token(authorization: Optional[str]) -> bool:
|
||||
"""验证认证 Token"""
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
|
||||
token = authorization.replace("Bearer ", "")
|
||||
token_manager = get_token_manager()
|
||||
|
||||
|
||||
if not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="Token 无效或已过期")
|
||||
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@@ -112,64 +121,58 @@ async def get_expression_list(
|
||||
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
chat_id: Optional[str] = Query(None, description="聊天ID筛选"),
|
||||
authorization: Optional[str] = Header(None)
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
获取表达方式列表
|
||||
|
||||
|
||||
Args:
|
||||
page: 页码 (从 1 开始)
|
||||
page_size: 每页数量 (1-100)
|
||||
search: 搜索关键词 (匹配 situation, style, context)
|
||||
chat_id: 聊天ID筛选
|
||||
authorization: Authorization header
|
||||
|
||||
|
||||
Returns:
|
||||
表达方式列表
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
|
||||
|
||||
# 构建查询
|
||||
query = Expression.select()
|
||||
|
||||
|
||||
# 搜索过滤
|
||||
if search:
|
||||
query = query.where(
|
||||
(Expression.situation.contains(search)) |
|
||||
(Expression.style.contains(search)) |
|
||||
(Expression.context.contains(search))
|
||||
(Expression.situation.contains(search))
|
||||
| (Expression.style.contains(search))
|
||||
| (Expression.context.contains(search))
|
||||
)
|
||||
|
||||
|
||||
# 聊天ID过滤
|
||||
if chat_id:
|
||||
query = query.where(Expression.chat_id == chat_id)
|
||||
|
||||
|
||||
# 排序:最后活跃时间倒序(NULL 值放在最后)
|
||||
from peewee import Case
|
||||
|
||||
query = query.order_by(
|
||||
Case(None, [(Expression.last_active_time.is_null(), 1)], 0),
|
||||
Expression.last_active_time.desc()
|
||||
Case(None, [(Expression.last_active_time.is_null(), 1)], 0), Expression.last_active_time.desc()
|
||||
)
|
||||
|
||||
|
||||
# 获取总数
|
||||
total = query.count()
|
||||
|
||||
|
||||
# 分页
|
||||
offset = (page - 1) * page_size
|
||||
expressions = query.offset(offset).limit(page_size)
|
||||
|
||||
|
||||
# 转换为响应对象
|
||||
data = [expression_to_response(expr) for expr in expressions]
|
||||
|
||||
return ExpressionListResponse(
|
||||
success=True,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
data=data
|
||||
)
|
||||
|
||||
|
||||
return ExpressionListResponse(success=True, total=total, page=page, page_size=page_size, data=data)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -178,33 +181,27 @@ async def get_expression_list(
|
||||
|
||||
|
||||
@router.get("/{expression_id}", response_model=ExpressionDetailResponse)
|
||||
async def get_expression_detail(
|
||||
expression_id: int,
|
||||
authorization: Optional[str] = Header(None)
|
||||
):
|
||||
async def get_expression_detail(expression_id: int, authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
获取表达方式详细信息
|
||||
|
||||
|
||||
Args:
|
||||
expression_id: 表达方式ID
|
||||
authorization: Authorization header
|
||||
|
||||
|
||||
Returns:
|
||||
表达方式详细信息
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
|
||||
|
||||
expression = Expression.get_or_none(Expression.id == expression_id)
|
||||
|
||||
|
||||
if not expression:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
|
||||
|
||||
return ExpressionDetailResponse(
|
||||
success=True,
|
||||
data=expression_to_response(expression)
|
||||
)
|
||||
|
||||
|
||||
return ExpressionDetailResponse(success=True, data=expression_to_response(expression))
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -213,25 +210,22 @@ async def get_expression_detail(
|
||||
|
||||
|
||||
@router.post("/", response_model=ExpressionCreateResponse)
|
||||
async def create_expression(
|
||||
request: ExpressionCreateRequest,
|
||||
authorization: Optional[str] = Header(None)
|
||||
):
|
||||
async def create_expression(request: ExpressionCreateRequest, authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
创建新的表达方式
|
||||
|
||||
|
||||
Args:
|
||||
request: 创建请求
|
||||
authorization: Authorization header
|
||||
|
||||
|
||||
Returns:
|
||||
创建结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
|
||||
|
||||
current_time = time.time()
|
||||
|
||||
|
||||
# 创建表达方式
|
||||
expression = Expression.create(
|
||||
situation=request.situation,
|
||||
@@ -242,15 +236,13 @@ async def create_expression(
|
||||
last_active_time=current_time,
|
||||
create_date=current_time,
|
||||
)
|
||||
|
||||
|
||||
logger.info(f"表达方式已创建: ID={expression.id}, situation={request.situation}")
|
||||
|
||||
|
||||
return ExpressionCreateResponse(
|
||||
success=True,
|
||||
message="表达方式创建成功",
|
||||
data=expression_to_response(expression)
|
||||
success=True, message="表达方式创建成功", data=expression_to_response(expression)
|
||||
)
|
||||
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -260,52 +252,48 @@ async def create_expression(
|
||||
|
||||
@router.patch("/{expression_id}", response_model=ExpressionUpdateResponse)
|
||||
async def update_expression(
|
||||
expression_id: int,
|
||||
request: ExpressionUpdateRequest,
|
||||
authorization: Optional[str] = Header(None)
|
||||
expression_id: int, request: ExpressionUpdateRequest, authorization: Optional[str] = Header(None)
|
||||
):
|
||||
"""
|
||||
增量更新表达方式(只更新提供的字段)
|
||||
|
||||
|
||||
Args:
|
||||
expression_id: 表达方式ID
|
||||
request: 更新请求(只包含需要更新的字段)
|
||||
authorization: Authorization header
|
||||
|
||||
|
||||
Returns:
|
||||
更新结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
|
||||
|
||||
expression = Expression.get_or_none(Expression.id == expression_id)
|
||||
|
||||
|
||||
if not expression:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
|
||||
|
||||
|
||||
# 只更新提供的字段
|
||||
update_data = request.model_dump(exclude_unset=True)
|
||||
|
||||
|
||||
if not update_data:
|
||||
raise HTTPException(status_code=400, detail="未提供任何需要更新的字段")
|
||||
|
||||
|
||||
# 更新最后活跃时间
|
||||
update_data['last_active_time'] = time.time()
|
||||
|
||||
update_data["last_active_time"] = time.time()
|
||||
|
||||
# 执行更新
|
||||
for field, value in update_data.items():
|
||||
setattr(expression, field, value)
|
||||
|
||||
|
||||
expression.save()
|
||||
|
||||
|
||||
logger.info(f"表达方式已更新: ID={expression_id}, 字段: {list(update_data.keys())}")
|
||||
|
||||
|
||||
return ExpressionUpdateResponse(
|
||||
success=True,
|
||||
message=f"成功更新 {len(update_data)} 个字段",
|
||||
data=expression_to_response(expression)
|
||||
success=True, message=f"成功更新 {len(update_data)} 个字段", data=expression_to_response(expression)
|
||||
)
|
||||
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -314,41 +302,35 @@ async def update_expression(
|
||||
|
||||
|
||||
@router.delete("/{expression_id}", response_model=ExpressionDeleteResponse)
|
||||
async def delete_expression(
|
||||
expression_id: int,
|
||||
authorization: Optional[str] = Header(None)
|
||||
):
|
||||
async def delete_expression(expression_id: int, authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
删除表达方式
|
||||
|
||||
|
||||
Args:
|
||||
expression_id: 表达方式ID
|
||||
authorization: Authorization header
|
||||
|
||||
|
||||
Returns:
|
||||
删除结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
|
||||
|
||||
expression = Expression.get_or_none(Expression.id == expression_id)
|
||||
|
||||
|
||||
if not expression:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
|
||||
|
||||
|
||||
# 记录删除信息
|
||||
situation = expression.situation
|
||||
|
||||
|
||||
# 执行删除
|
||||
expression.delete_instance()
|
||||
|
||||
|
||||
logger.info(f"表达方式已删除: ID={expression_id}, situation={situation}")
|
||||
|
||||
return ExpressionDeleteResponse(
|
||||
success=True,
|
||||
message=f"成功删除表达方式: {situation}"
|
||||
)
|
||||
|
||||
|
||||
return ExpressionDeleteResponse(success=True, message=f"成功删除表达方式: {situation}")
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -356,47 +338,93 @@ async def delete_expression(
|
||||
raise HTTPException(status_code=500, detail=f"删除表达方式失败: {str(e)}") from e
|
||||
|
||||
|
||||
class BatchDeleteRequest(BaseModel):
|
||||
"""批量删除请求"""
|
||||
|
||||
ids: List[int]
|
||||
|
||||
|
||||
@router.post("/batch/delete", response_model=ExpressionDeleteResponse)
|
||||
async def batch_delete_expressions(request: BatchDeleteRequest, authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
批量删除表达方式
|
||||
|
||||
Args:
|
||||
request: 包含要删除的ID列表的请求
|
||||
authorization: Authorization header
|
||||
|
||||
Returns:
|
||||
删除结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
|
||||
if not request.ids:
|
||||
raise HTTPException(status_code=400, detail="未提供要删除的表达方式ID")
|
||||
|
||||
# 查找所有要删除的表达方式
|
||||
expressions = Expression.select().where(Expression.id.in_(request.ids))
|
||||
found_ids = [expr.id for expr in expressions]
|
||||
|
||||
# 检查是否有未找到的ID
|
||||
not_found_ids = set(request.ids) - set(found_ids)
|
||||
if not_found_ids:
|
||||
logger.warning(f"部分表达方式未找到: {not_found_ids}")
|
||||
|
||||
# 执行批量删除
|
||||
deleted_count = Expression.delete().where(Expression.id.in_(found_ids)).execute()
|
||||
|
||||
logger.info(f"批量删除了 {deleted_count} 个表达方式")
|
||||
|
||||
return ExpressionDeleteResponse(success=True, message=f"成功删除 {deleted_count} 个表达方式")
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"批量删除表达方式失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"批量删除表达方式失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/stats/summary")
|
||||
async def get_expression_stats(
|
||||
authorization: Optional[str] = Header(None)
|
||||
):
|
||||
async def get_expression_stats(authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
获取表达方式统计数据
|
||||
|
||||
|
||||
Args:
|
||||
authorization: Authorization header
|
||||
|
||||
|
||||
Returns:
|
||||
统计数据
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
|
||||
|
||||
total = Expression.select().count()
|
||||
|
||||
|
||||
# 按 chat_id 统计
|
||||
chat_stats = {}
|
||||
for expr in Expression.select(Expression.chat_id):
|
||||
chat_id = expr.chat_id
|
||||
chat_stats[chat_id] = chat_stats.get(chat_id, 0) + 1
|
||||
|
||||
|
||||
# 获取最近创建的记录数(7天内)
|
||||
seven_days_ago = time.time() - (7 * 24 * 60 * 60)
|
||||
recent = Expression.select().where(
|
||||
(Expression.create_date.is_null(False)) &
|
||||
(Expression.create_date >= seven_days_ago)
|
||||
).count()
|
||||
|
||||
recent = (
|
||||
Expression.select()
|
||||
.where((Expression.create_date.is_null(False)) & (Expression.create_date >= seven_days_ago))
|
||||
.count()
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"data": {
|
||||
"total": total,
|
||||
"recent_7days": recent,
|
||||
"chat_count": len(chat_stats),
|
||||
"top_chats": dict(sorted(chat_stats.items(), key=lambda x: x[1], reverse=True)[:10])
|
||||
}
|
||||
"top_chats": dict(sorted(chat_stats.items(), key=lambda x: x[1], reverse=True)[:10]),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Git 镜像源服务 - 支持多镜像源、错误重试、Git 克隆和 Raw 文件获取"""
|
||||
|
||||
from typing import Optional, List, Dict, Any
|
||||
from enum import Enum
|
||||
import httpx
|
||||
@@ -15,6 +16,7 @@ logger = get_logger("webui.git_mirror")
|
||||
# 导入进度更新函数(避免循环导入)
|
||||
_update_progress = None
|
||||
|
||||
|
||||
def set_update_progress_callback(callback):
|
||||
"""设置进度更新回调函数"""
|
||||
global _update_progress
|
||||
@@ -23,6 +25,7 @@ def set_update_progress_callback(callback):
|
||||
|
||||
class MirrorType(str, Enum):
|
||||
"""镜像源类型"""
|
||||
|
||||
GH_PROXY = "gh-proxy" # gh-proxy 主节点
|
||||
HK_GH_PROXY = "hk-gh-proxy" # gh-proxy 香港节点
|
||||
CDN_GH_PROXY = "cdn-gh-proxy" # gh-proxy CDN 节点
|
||||
@@ -34,10 +37,10 @@ class MirrorType(str, Enum):
|
||||
|
||||
class GitMirrorConfig:
|
||||
"""Git 镜像源配置管理"""
|
||||
|
||||
|
||||
# 配置文件路径
|
||||
CONFIG_FILE = Path("data/webui.json")
|
||||
|
||||
|
||||
# 默认镜像源配置
|
||||
DEFAULT_MIRRORS = [
|
||||
{
|
||||
@@ -47,7 +50,7 @@ class GitMirrorConfig:
|
||||
"clone_prefix": "https://gh-proxy.org/https://github.com",
|
||||
"enabled": True,
|
||||
"priority": 1,
|
||||
"created_at": None
|
||||
"created_at": None,
|
||||
},
|
||||
{
|
||||
"id": "hk-gh-proxy",
|
||||
@@ -56,7 +59,7 @@ class GitMirrorConfig:
|
||||
"clone_prefix": "https://hk.gh-proxy.org/https://github.com",
|
||||
"enabled": True,
|
||||
"priority": 2,
|
||||
"created_at": None
|
||||
"created_at": None,
|
||||
},
|
||||
{
|
||||
"id": "cdn-gh-proxy",
|
||||
@@ -65,7 +68,7 @@ class GitMirrorConfig:
|
||||
"clone_prefix": "https://cdn.gh-proxy.org/https://github.com",
|
||||
"enabled": True,
|
||||
"priority": 3,
|
||||
"created_at": None
|
||||
"created_at": None,
|
||||
},
|
||||
{
|
||||
"id": "edgeone-gh-proxy",
|
||||
@@ -74,7 +77,7 @@ class GitMirrorConfig:
|
||||
"clone_prefix": "https://edgeone.gh-proxy.org/https://github.com",
|
||||
"enabled": True,
|
||||
"priority": 4,
|
||||
"created_at": None
|
||||
"created_at": None,
|
||||
},
|
||||
{
|
||||
"id": "meyzh-github",
|
||||
@@ -83,7 +86,7 @@ class GitMirrorConfig:
|
||||
"clone_prefix": "https://meyzh.github.io/https://github.com",
|
||||
"enabled": True,
|
||||
"priority": 5,
|
||||
"created_at": None
|
||||
"created_at": None,
|
||||
},
|
||||
{
|
||||
"id": "github",
|
||||
@@ -92,23 +95,23 @@ class GitMirrorConfig:
|
||||
"clone_prefix": "https://github.com",
|
||||
"enabled": True,
|
||||
"priority": 999,
|
||||
"created_at": None
|
||||
}
|
||||
"created_at": None,
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def __init__(self):
|
||||
"""初始化配置管理器"""
|
||||
self.config_file = self.CONFIG_FILE
|
||||
self.mirrors: List[Dict[str, Any]] = []
|
||||
self._load_config()
|
||||
|
||||
|
||||
def _load_config(self) -> None:
|
||||
"""加载配置文件"""
|
||||
try:
|
||||
if self.config_file.exists():
|
||||
with open(self.config_file, 'r', encoding='utf-8') as f:
|
||||
with open(self.config_file, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
|
||||
# 检查是否有镜像源配置
|
||||
if "git_mirrors" not in data or not data["git_mirrors"]:
|
||||
logger.info("配置文件中未找到镜像源配置,使用默认配置")
|
||||
@@ -122,59 +125,59 @@ class GitMirrorConfig:
|
||||
except Exception as e:
|
||||
logger.error(f"加载配置文件失败: {e}")
|
||||
self._init_default_mirrors()
|
||||
|
||||
|
||||
def _init_default_mirrors(self) -> None:
|
||||
"""初始化默认镜像源"""
|
||||
current_time = datetime.now().isoformat()
|
||||
self.mirrors = []
|
||||
|
||||
|
||||
for mirror in self.DEFAULT_MIRRORS:
|
||||
mirror_copy = mirror.copy()
|
||||
mirror_copy["created_at"] = current_time
|
||||
self.mirrors.append(mirror_copy)
|
||||
|
||||
|
||||
self._save_config()
|
||||
logger.info(f"已初始化 {len(self.mirrors)} 个默认镜像源")
|
||||
|
||||
|
||||
def _save_config(self) -> None:
|
||||
"""保存配置到文件"""
|
||||
try:
|
||||
# 确保目录存在
|
||||
self.config_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
# 读取现有配置
|
||||
existing_data = {}
|
||||
if self.config_file.exists():
|
||||
with open(self.config_file, 'r', encoding='utf-8') as f:
|
||||
with open(self.config_file, "r", encoding="utf-8") as f:
|
||||
existing_data = json.load(f)
|
||||
|
||||
|
||||
# 更新镜像源配置
|
||||
existing_data["git_mirrors"] = self.mirrors
|
||||
|
||||
|
||||
# 写入文件
|
||||
with open(self.config_file, 'w', encoding='utf-8') as f:
|
||||
with open(self.config_file, "w", encoding="utf-8") as f:
|
||||
json.dump(existing_data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
|
||||
logger.debug(f"配置已保存到 {self.config_file}")
|
||||
except Exception as e:
|
||||
logger.error(f"保存配置文件失败: {e}")
|
||||
|
||||
|
||||
def get_all_mirrors(self) -> List[Dict[str, Any]]:
|
||||
"""获取所有镜像源"""
|
||||
return self.mirrors.copy()
|
||||
|
||||
|
||||
def get_enabled_mirrors(self) -> List[Dict[str, Any]]:
|
||||
"""获取所有启用的镜像源,按优先级排序"""
|
||||
enabled = [m for m in self.mirrors if m.get("enabled", False)]
|
||||
return sorted(enabled, key=lambda x: x.get("priority", 999))
|
||||
|
||||
|
||||
def get_mirror_by_id(self, mirror_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""根据 ID 获取镜像源"""
|
||||
for mirror in self.mirrors:
|
||||
if mirror.get("id") == mirror_id:
|
||||
return mirror.copy()
|
||||
return None
|
||||
|
||||
|
||||
def add_mirror(
|
||||
self,
|
||||
mirror_id: str,
|
||||
@@ -182,26 +185,26 @@ class GitMirrorConfig:
|
||||
raw_prefix: str,
|
||||
clone_prefix: str,
|
||||
enabled: bool = True,
|
||||
priority: Optional[int] = None
|
||||
priority: Optional[int] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
添加新的镜像源
|
||||
|
||||
|
||||
Returns:
|
||||
添加的镜像源配置
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: 如果镜像源 ID 已存在
|
||||
"""
|
||||
# 检查 ID 是否已存在
|
||||
if self.get_mirror_by_id(mirror_id):
|
||||
raise ValueError(f"镜像源 ID 已存在: {mirror_id}")
|
||||
|
||||
|
||||
# 如果未指定优先级,使用最大优先级 + 1
|
||||
if priority is None:
|
||||
max_priority = max((m.get("priority", 0) for m in self.mirrors), default=0)
|
||||
priority = max_priority + 1
|
||||
|
||||
|
||||
new_mirror = {
|
||||
"id": mirror_id,
|
||||
"name": name,
|
||||
@@ -209,15 +212,15 @@ class GitMirrorConfig:
|
||||
"clone_prefix": clone_prefix,
|
||||
"enabled": enabled,
|
||||
"priority": priority,
|
||||
"created_at": datetime.now().isoformat()
|
||||
"created_at": datetime.now().isoformat(),
|
||||
}
|
||||
|
||||
|
||||
self.mirrors.append(new_mirror)
|
||||
self._save_config()
|
||||
|
||||
|
||||
logger.info(f"已添加镜像源: {mirror_id} - {name}")
|
||||
return new_mirror.copy()
|
||||
|
||||
|
||||
def update_mirror(
|
||||
self,
|
||||
mirror_id: str,
|
||||
@@ -225,11 +228,11 @@ class GitMirrorConfig:
|
||||
raw_prefix: Optional[str] = None,
|
||||
clone_prefix: Optional[str] = None,
|
||||
enabled: Optional[bool] = None,
|
||||
priority: Optional[int] = None
|
||||
priority: Optional[int] = None,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
更新镜像源配置
|
||||
|
||||
|
||||
Returns:
|
||||
更新后的镜像源配置,如果不存在则返回 None
|
||||
"""
|
||||
@@ -245,19 +248,19 @@ class GitMirrorConfig:
|
||||
mirror["enabled"] = enabled
|
||||
if priority is not None:
|
||||
mirror["priority"] = priority
|
||||
|
||||
|
||||
mirror["updated_at"] = datetime.now().isoformat()
|
||||
self._save_config()
|
||||
|
||||
|
||||
logger.info(f"已更新镜像源: {mirror_id}")
|
||||
return mirror.copy()
|
||||
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def delete_mirror(self, mirror_id: str) -> bool:
|
||||
"""
|
||||
删除镜像源
|
||||
|
||||
|
||||
Returns:
|
||||
True 如果删除成功,False 如果镜像源不存在
|
||||
"""
|
||||
@@ -267,9 +270,9 @@ class GitMirrorConfig:
|
||||
self._save_config()
|
||||
logger.info(f"已删除镜像源: {mirror_id}")
|
||||
return True
|
||||
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def get_default_priority_list(self) -> List[str]:
|
||||
"""获取默认优先级列表(仅启用的镜像源 ID)"""
|
||||
enabled = self.get_enabled_mirrors()
|
||||
@@ -278,16 +281,11 @@ class GitMirrorConfig:
|
||||
|
||||
class GitMirrorService:
|
||||
"""Git 镜像源服务"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_retries: int = 3,
|
||||
timeout: int = 30,
|
||||
config: Optional[GitMirrorConfig] = None
|
||||
):
|
||||
|
||||
def __init__(self, max_retries: int = 3, timeout: int = 30, config: Optional[GitMirrorConfig] = None):
|
||||
"""
|
||||
初始化 Git 镜像源服务
|
||||
|
||||
|
||||
Args:
|
||||
max_retries: 最大重试次数
|
||||
timeout: 请求超时时间(秒)
|
||||
@@ -297,16 +295,16 @@ class GitMirrorService:
|
||||
self.timeout = timeout
|
||||
self.config = config or GitMirrorConfig()
|
||||
logger.info(f"Git镜像源服务初始化完成,已加载 {len(self.config.get_enabled_mirrors())} 个启用的镜像源")
|
||||
|
||||
|
||||
def get_mirror_config(self) -> GitMirrorConfig:
|
||||
"""获取镜像源配置管理器"""
|
||||
return self.config
|
||||
|
||||
|
||||
@staticmethod
|
||||
def check_git_installed() -> Dict[str, Any]:
|
||||
"""
|
||||
检查本机是否安装了 Git
|
||||
|
||||
|
||||
Returns:
|
||||
Dict 包含:
|
||||
- installed: bool - 是否已安装 Git
|
||||
@@ -316,54 +314,33 @@ class GitMirrorService:
|
||||
"""
|
||||
import subprocess
|
||||
import shutil
|
||||
|
||||
|
||||
try:
|
||||
# 查找 git 可执行文件路径
|
||||
git_path = shutil.which("git")
|
||||
|
||||
|
||||
if not git_path:
|
||||
logger.warning("未找到 Git 可执行文件")
|
||||
return {
|
||||
"installed": False,
|
||||
"error": "系统中未找到 Git,请先安装 Git"
|
||||
}
|
||||
|
||||
return {"installed": False, "error": "系统中未找到 Git,请先安装 Git"}
|
||||
|
||||
# 获取 Git 版本
|
||||
result = subprocess.run(
|
||||
["git", "--version"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5
|
||||
)
|
||||
|
||||
result = subprocess.run(["git", "--version"], capture_output=True, text=True, timeout=5)
|
||||
|
||||
if result.returncode == 0:
|
||||
version = result.stdout.strip()
|
||||
logger.info(f"检测到 Git: {version} at {git_path}")
|
||||
return {
|
||||
"installed": True,
|
||||
"version": version,
|
||||
"path": git_path
|
||||
}
|
||||
return {"installed": True, "version": version, "path": git_path}
|
||||
else:
|
||||
logger.warning(f"Git 命令执行失败: {result.stderr}")
|
||||
return {
|
||||
"installed": False,
|
||||
"error": f"Git 命令执行失败: {result.stderr}"
|
||||
}
|
||||
|
||||
return {"installed": False, "error": f"Git 命令执行失败: {result.stderr}"}
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.error("Git 版本检测超时")
|
||||
return {
|
||||
"installed": False,
|
||||
"error": "Git 版本检测超时"
|
||||
}
|
||||
return {"installed": False, "error": "Git 版本检测超时"}
|
||||
except Exception as e:
|
||||
logger.error(f"检测 Git 时发生错误: {e}")
|
||||
return {
|
||||
"installed": False,
|
||||
"error": f"检测 Git 时发生错误: {str(e)}"
|
||||
}
|
||||
|
||||
return {"installed": False, "error": f"检测 Git 时发生错误: {str(e)}"}
|
||||
|
||||
async def fetch_raw_file(
|
||||
self,
|
||||
owner: str,
|
||||
@@ -371,11 +348,11 @@ class GitMirrorService:
|
||||
branch: str,
|
||||
file_path: str,
|
||||
mirror_id: Optional[str] = None,
|
||||
custom_url: Optional[str] = None
|
||||
custom_url: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取 GitHub 仓库的 Raw 文件内容
|
||||
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名称
|
||||
@@ -383,7 +360,7 @@ class GitMirrorService:
|
||||
file_path: 文件路径
|
||||
mirror_id: 指定的镜像源 ID
|
||||
custom_url: 自定义完整 URL(如果提供,将忽略其他参数)
|
||||
|
||||
|
||||
Returns:
|
||||
Dict 包含:
|
||||
- success: bool - 是否成功
|
||||
@@ -393,29 +370,24 @@ class GitMirrorService:
|
||||
- attempts: int - 尝试次数
|
||||
"""
|
||||
logger.info(f"开始获取 Raw 文件: {owner}/{repo}/{branch}/{file_path}")
|
||||
|
||||
|
||||
if custom_url:
|
||||
# 使用自定义 URL
|
||||
return await self._fetch_with_url(custom_url, "custom")
|
||||
|
||||
|
||||
# 确定要使用的镜像源列表
|
||||
if mirror_id:
|
||||
# 使用指定的镜像源
|
||||
mirror = self.config.get_mirror_by_id(mirror_id)
|
||||
if not mirror:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"未找到镜像源: {mirror_id}",
|
||||
"mirror_used": None,
|
||||
"attempts": 0
|
||||
}
|
||||
return {"success": False, "error": f"未找到镜像源: {mirror_id}", "mirror_used": None, "attempts": 0}
|
||||
mirrors_to_try = [mirror]
|
||||
else:
|
||||
# 使用所有启用的镜像源
|
||||
mirrors_to_try = self.config.get_enabled_mirrors()
|
||||
|
||||
|
||||
total_mirrors = len(mirrors_to_try)
|
||||
|
||||
|
||||
# 依次尝试每个镜像源
|
||||
for index, mirror in enumerate(mirrors_to_try, 1):
|
||||
# 推送进度:正在尝试第 N 个镜像源
|
||||
@@ -427,15 +399,13 @@ class GitMirrorService:
|
||||
progress=progress,
|
||||
message=f"正在尝试镜像源 {index}/{total_mirrors}: {mirror['name']}",
|
||||
total_plugins=0,
|
||||
loaded_plugins=0
|
||||
loaded_plugins=0,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"推送进度失败: {e}")
|
||||
|
||||
result = await self._fetch_raw_from_mirror(
|
||||
owner, repo, branch, file_path, mirror
|
||||
)
|
||||
|
||||
|
||||
result = await self._fetch_raw_from_mirror(owner, repo, branch, file_path, mirror)
|
||||
|
||||
if result["success"]:
|
||||
# 成功,推送进度
|
||||
if _update_progress:
|
||||
@@ -445,15 +415,15 @@ class GitMirrorService:
|
||||
progress=70,
|
||||
message=f"成功从 {mirror['name']} 获取数据",
|
||||
total_plugins=0,
|
||||
loaded_plugins=0
|
||||
loaded_plugins=0,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"推送进度失败: {e}")
|
||||
return result
|
||||
|
||||
|
||||
# 失败,记录日志并推送失败信息
|
||||
logger.warning(f"镜像源 {mirror['id']} 失败: {result.get('error')}")
|
||||
|
||||
|
||||
if _update_progress and index < total_mirrors:
|
||||
try:
|
||||
await _update_progress(
|
||||
@@ -461,39 +431,29 @@ class GitMirrorService:
|
||||
progress=30 + int(index / total_mirrors * 40),
|
||||
message=f"镜像源 {mirror['name']} 失败,尝试下一个...",
|
||||
total_plugins=0,
|
||||
loaded_plugins=0
|
||||
loaded_plugins=0,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"推送进度失败: {e}")
|
||||
|
||||
|
||||
# 所有镜像源都失败
|
||||
return {
|
||||
"success": False,
|
||||
"error": "所有镜像源均失败",
|
||||
"mirror_used": None,
|
||||
"attempts": len(mirrors_to_try)
|
||||
}
|
||||
|
||||
return {"success": False, "error": "所有镜像源均失败", "mirror_used": None, "attempts": len(mirrors_to_try)}
|
||||
|
||||
async def _fetch_raw_from_mirror(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
branch: str,
|
||||
file_path: str,
|
||||
mirror: Dict[str, Any]
|
||||
self, owner: str, repo: str, branch: str, file_path: str, mirror: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""从指定镜像源获取文件"""
|
||||
# 构建 URL
|
||||
raw_prefix = mirror["raw_prefix"]
|
||||
url = f"{raw_prefix}/{owner}/{repo}/{branch}/{file_path}"
|
||||
|
||||
|
||||
return await self._fetch_with_url(url, mirror["id"])
|
||||
|
||||
|
||||
async def _fetch_with_url(self, url: str, mirror_type: str) -> Dict[str, Any]:
|
||||
"""使用指定 URL 获取文件,支持重试"""
|
||||
attempts = 0
|
||||
last_error = None
|
||||
|
||||
|
||||
for attempt in range(self.max_retries):
|
||||
attempts += 1
|
||||
try:
|
||||
@@ -501,14 +461,14 @@ class GitMirrorService:
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
response = await client.get(url)
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
logger.info(f"成功获取文件: {url}")
|
||||
return {
|
||||
"success": True,
|
||||
"data": response.text,
|
||||
"mirror_used": mirror_type,
|
||||
"attempts": attempts,
|
||||
"url": url
|
||||
"url": url,
|
||||
}
|
||||
except httpx.HTTPStatusError as e:
|
||||
last_error = f"HTTP {e.response.status_code}: {e}"
|
||||
@@ -519,15 +479,9 @@ class GitMirrorService:
|
||||
except Exception as e:
|
||||
last_error = f"未知错误: {e}"
|
||||
logger.error(f"错误 (尝试 {attempt + 1}/{self.max_retries}): {last_error}")
|
||||
|
||||
return {
|
||||
"success": False,
|
||||
"error": last_error,
|
||||
"mirror_used": mirror_type,
|
||||
"attempts": attempts,
|
||||
"url": url
|
||||
}
|
||||
|
||||
|
||||
return {"success": False, "error": last_error, "mirror_used": mirror_type, "attempts": attempts, "url": url}
|
||||
|
||||
async def clone_repository(
|
||||
self,
|
||||
owner: str,
|
||||
@@ -536,11 +490,11 @@ class GitMirrorService:
|
||||
branch: Optional[str] = None,
|
||||
mirror_id: Optional[str] = None,
|
||||
custom_url: Optional[str] = None,
|
||||
depth: Optional[int] = None
|
||||
depth: Optional[int] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
克隆 GitHub 仓库
|
||||
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名称
|
||||
@@ -549,7 +503,7 @@ class GitMirrorService:
|
||||
mirror_id: 指定的镜像源 ID
|
||||
custom_url: 自定义克隆 URL
|
||||
depth: 克隆深度(浅克隆)
|
||||
|
||||
|
||||
Returns:
|
||||
Dict 包含:
|
||||
- success: bool - 是否成功
|
||||
@@ -559,44 +513,32 @@ class GitMirrorService:
|
||||
- attempts: int - 尝试次数
|
||||
"""
|
||||
logger.info(f"开始克隆仓库: {owner}/{repo} 到 {target_path}")
|
||||
|
||||
|
||||
if custom_url:
|
||||
# 使用自定义 URL
|
||||
return await self._clone_with_url(custom_url, target_path, branch, depth, "custom")
|
||||
|
||||
|
||||
# 确定要使用的镜像源列表
|
||||
if mirror_id:
|
||||
# 使用指定的镜像源
|
||||
mirror = self.config.get_mirror_by_id(mirror_id)
|
||||
if not mirror:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"未找到镜像源: {mirror_id}",
|
||||
"mirror_used": None,
|
||||
"attempts": 0
|
||||
}
|
||||
return {"success": False, "error": f"未找到镜像源: {mirror_id}", "mirror_used": None, "attempts": 0}
|
||||
mirrors_to_try = [mirror]
|
||||
else:
|
||||
# 使用所有启用的镜像源
|
||||
mirrors_to_try = self.config.get_enabled_mirrors()
|
||||
|
||||
|
||||
# 依次尝试每个镜像源
|
||||
for mirror in mirrors_to_try:
|
||||
result = await self._clone_from_mirror(
|
||||
owner, repo, target_path, branch, depth, mirror
|
||||
)
|
||||
result = await self._clone_from_mirror(owner, repo, target_path, branch, depth, mirror)
|
||||
if result["success"]:
|
||||
return result
|
||||
logger.warning(f"镜像源 {mirror['id']} 克隆失败: {result.get('error')}")
|
||||
|
||||
|
||||
# 所有镜像源都失败
|
||||
return {
|
||||
"success": False,
|
||||
"error": "所有镜像源克隆均失败",
|
||||
"mirror_used": None,
|
||||
"attempts": len(mirrors_to_try)
|
||||
}
|
||||
|
||||
return {"success": False, "error": "所有镜像源克隆均失败", "mirror_used": None, "attempts": len(mirrors_to_try)}
|
||||
|
||||
async def _clone_from_mirror(
|
||||
self,
|
||||
owner: str,
|
||||
@@ -604,52 +546,47 @@ class GitMirrorService:
|
||||
target_path: Path,
|
||||
branch: Optional[str],
|
||||
depth: Optional[int],
|
||||
mirror: Dict[str, Any]
|
||||
mirror: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
"""从指定镜像源克隆仓库"""
|
||||
# 构建克隆 URL
|
||||
clone_prefix = mirror["clone_prefix"]
|
||||
url = f"{clone_prefix}/{owner}/{repo}.git"
|
||||
|
||||
|
||||
return await self._clone_with_url(url, target_path, branch, depth, mirror["id"])
|
||||
|
||||
|
||||
async def _clone_with_url(
|
||||
self,
|
||||
url: str,
|
||||
target_path: Path,
|
||||
branch: Optional[str],
|
||||
depth: Optional[int],
|
||||
mirror_type: str
|
||||
self, url: str, target_path: Path, branch: Optional[str], depth: Optional[int], mirror_type: str
|
||||
) -> Dict[str, Any]:
|
||||
"""使用指定 URL 克隆仓库,支持重试"""
|
||||
attempts = 0
|
||||
last_error = None
|
||||
|
||||
|
||||
for attempt in range(self.max_retries):
|
||||
attempts += 1
|
||||
|
||||
|
||||
try:
|
||||
# 确保目标路径不存在
|
||||
if target_path.exists():
|
||||
logger.warning(f"目标路径已存在,删除: {target_path}")
|
||||
shutil.rmtree(target_path, ignore_errors=True)
|
||||
|
||||
|
||||
# 构建 git clone 命令
|
||||
cmd = ["git", "clone"]
|
||||
|
||||
|
||||
# 添加分支参数
|
||||
if branch:
|
||||
cmd.extend(["-b", branch])
|
||||
|
||||
|
||||
# 添加深度参数(浅克隆)
|
||||
if depth:
|
||||
cmd.extend(["--depth", str(depth)])
|
||||
|
||||
|
||||
# 添加 URL 和目标路径
|
||||
cmd.extend([url, str(target_path)])
|
||||
|
||||
|
||||
logger.info(f"尝试克隆 #{attempt + 1}: {' '.join(cmd)}")
|
||||
|
||||
|
||||
# 推送进度
|
||||
if _update_progress:
|
||||
try:
|
||||
@@ -657,24 +594,24 @@ class GitMirrorService:
|
||||
stage="loading",
|
||||
progress=20 + attempt * 10,
|
||||
message=f"正在克隆仓库 (尝试 {attempt + 1}/{self.max_retries})...",
|
||||
operation="install"
|
||||
operation="install",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"推送进度失败: {e}")
|
||||
|
||||
|
||||
# 执行 git clone(在线程池中运行以避免阻塞)
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
|
||||
def run_git_clone():
|
||||
return subprocess.run(
|
||||
cmd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=300 # 5分钟超时
|
||||
timeout=300, # 5分钟超时
|
||||
)
|
||||
|
||||
|
||||
process = await loop.run_in_executor(None, run_git_clone)
|
||||
|
||||
|
||||
if process.returncode == 0:
|
||||
logger.info(f"成功克隆仓库: {url} -> {target_path}")
|
||||
return {
|
||||
@@ -683,40 +620,34 @@ class GitMirrorService:
|
||||
"mirror_used": mirror_type,
|
||||
"attempts": attempts,
|
||||
"url": url,
|
||||
"branch": branch or "default"
|
||||
"branch": branch or "default",
|
||||
}
|
||||
else:
|
||||
last_error = f"Git 克隆失败: {process.stderr}"
|
||||
logger.warning(f"克隆失败 (尝试 {attempt + 1}/{self.max_retries}): {last_error}")
|
||||
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
last_error = "克隆超时(超过 5 分钟)"
|
||||
logger.warning(f"克隆超时 (尝试 {attempt + 1}/{self.max_retries})")
|
||||
|
||||
|
||||
# 清理可能的部分克隆
|
||||
if target_path.exists():
|
||||
shutil.rmtree(target_path, ignore_errors=True)
|
||||
|
||||
|
||||
except FileNotFoundError:
|
||||
last_error = "Git 未安装或不在 PATH 中"
|
||||
logger.error(f"Git 未找到: {last_error}")
|
||||
break # Git 不存在,不需要重试
|
||||
|
||||
|
||||
except Exception as e:
|
||||
last_error = f"未知错误: {e}"
|
||||
logger.error(f"克隆错误 (尝试 {attempt + 1}/{self.max_retries}): {last_error}")
|
||||
|
||||
|
||||
# 清理可能的部分克隆
|
||||
if target_path.exists():
|
||||
shutil.rmtree(target_path, ignore_errors=True)
|
||||
|
||||
return {
|
||||
"success": False,
|
||||
"error": last_error,
|
||||
"mirror_used": mirror_type,
|
||||
"attempts": attempts,
|
||||
"url": url
|
||||
}
|
||||
|
||||
return {"success": False, "error": last_error, "mirror_used": mirror_type, "attempts": attempts, "url": url}
|
||||
|
||||
|
||||
# 全局服务实例
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""WebSocket 日志推送模块"""
|
||||
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
||||
from typing import Set
|
||||
import json
|
||||
@@ -14,30 +15,30 @@ active_connections: Set[WebSocket] = set()
|
||||
|
||||
def load_recent_logs(limit: int = 100) -> list[dict]:
|
||||
"""从日志文件中加载最近的日志
|
||||
|
||||
|
||||
Args:
|
||||
limit: 返回的最大日志条数
|
||||
|
||||
|
||||
Returns:
|
||||
日志列表
|
||||
"""
|
||||
logs = []
|
||||
log_dir = Path("logs")
|
||||
|
||||
|
||||
if not log_dir.exists():
|
||||
return logs
|
||||
|
||||
|
||||
# 获取所有日志文件,按修改时间排序
|
||||
log_files = sorted(log_dir.glob("app_*.log.jsonl"), key=lambda f: f.stat().st_mtime, reverse=True)
|
||||
|
||||
|
||||
# 用于生成唯一 ID 的计数器
|
||||
log_counter = 0
|
||||
|
||||
|
||||
# 从最新的文件开始读取
|
||||
for log_file in log_files:
|
||||
if len(logs) >= limit:
|
||||
break
|
||||
|
||||
|
||||
try:
|
||||
with open(log_file, "r", encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
@@ -49,7 +50,9 @@ def load_recent_logs(limit: int = 100) -> list[dict]:
|
||||
log_entry = json.loads(line.strip())
|
||||
# 转换为前端期望的格式
|
||||
# 使用时间戳 + 计数器生成唯一 ID
|
||||
timestamp_id = log_entry.get("timestamp", "0").replace("-", "").replace(" ", "").replace(":", "")
|
||||
timestamp_id = (
|
||||
log_entry.get("timestamp", "0").replace("-", "").replace(" ", "").replace(":", "")
|
||||
)
|
||||
formatted_log = {
|
||||
"id": f"{timestamp_id}_{log_counter}",
|
||||
"timestamp": log_entry.get("timestamp", ""),
|
||||
@@ -64,7 +67,7 @@ def load_recent_logs(limit: int = 100) -> list[dict]:
|
||||
except Exception as e:
|
||||
logger.error(f"读取日志文件失败 {log_file}: {e}")
|
||||
continue
|
||||
|
||||
|
||||
# 反转列表,使其按时间顺序排列(旧到新)
|
||||
return list(reversed(logs))
|
||||
|
||||
@@ -72,35 +75,35 @@ def load_recent_logs(limit: int = 100) -> list[dict]:
|
||||
@router.websocket("/ws/logs")
|
||||
async def websocket_logs(websocket: WebSocket):
|
||||
"""WebSocket 日志推送端点
|
||||
|
||||
|
||||
客户端连接后会持续接收服务器端的日志消息
|
||||
"""
|
||||
await websocket.accept()
|
||||
active_connections.add(websocket)
|
||||
logger.info(f"📡 WebSocket 客户端已连接,当前连接数: {len(active_connections)}")
|
||||
|
||||
|
||||
# 连接建立后,立即发送历史日志
|
||||
try:
|
||||
recent_logs = load_recent_logs(limit=100)
|
||||
logger.info(f"发送 {len(recent_logs)} 条历史日志到客户端")
|
||||
|
||||
|
||||
for log_entry in recent_logs:
|
||||
await websocket.send_text(json.dumps(log_entry, ensure_ascii=False))
|
||||
except Exception as e:
|
||||
logger.error(f"发送历史日志失败: {e}")
|
||||
|
||||
|
||||
try:
|
||||
# 保持连接,等待客户端消息或断开
|
||||
while True:
|
||||
# 接收客户端消息(用于心跳或控制指令)
|
||||
data = await websocket.receive_text()
|
||||
|
||||
|
||||
# 可以处理客户端的控制消息,例如:
|
||||
# - "ping" -> 心跳检测
|
||||
# - {"filter": "ERROR"} -> 设置日志级别过滤
|
||||
if data == "ping":
|
||||
await websocket.send_text("pong")
|
||||
|
||||
|
||||
except WebSocketDisconnect:
|
||||
active_connections.discard(websocket)
|
||||
logger.info(f"📡 WebSocket 客户端已断开,当前连接数: {len(active_connections)}")
|
||||
@@ -111,19 +114,19 @@ async def websocket_logs(websocket: WebSocket):
|
||||
|
||||
async def broadcast_log(log_data: dict):
|
||||
"""广播日志到所有连接的 WebSocket 客户端
|
||||
|
||||
|
||||
Args:
|
||||
log_data: 日志数据字典
|
||||
"""
|
||||
if not active_connections:
|
||||
return
|
||||
|
||||
|
||||
# 格式化为 JSON
|
||||
message = json.dumps(log_data, ensure_ascii=False)
|
||||
|
||||
|
||||
# 记录需要断开的连接
|
||||
disconnected = set()
|
||||
|
||||
|
||||
# 广播到所有客户端
|
||||
for connection in active_connections:
|
||||
try:
|
||||
@@ -131,7 +134,7 @@ async def broadcast_log(log_data: dict):
|
||||
except Exception:
|
||||
# 发送失败,标记为断开
|
||||
disconnected.add(connection)
|
||||
|
||||
|
||||
# 清理断开的连接
|
||||
if disconnected:
|
||||
active_connections.difference_update(disconnected)
|
||||
|
||||
@@ -1,108 +0,0 @@
|
||||
"""WebUI 管理器 - 处理开发/生产环境的 WebUI 启动"""
|
||||
import os
|
||||
from pathlib import Path
|
||||
from src.common.logger import get_logger
|
||||
from .token_manager import get_token_manager
|
||||
|
||||
logger = get_logger("webui")
|
||||
|
||||
|
||||
def setup_webui(mode: str = "production") -> bool:
|
||||
"""
|
||||
设置 WebUI
|
||||
|
||||
Args:
|
||||
mode: 运行模式,"development" 或 "production"
|
||||
|
||||
Returns:
|
||||
bool: 是否成功设置
|
||||
"""
|
||||
# 初始化 Token 管理器(确保 token 文件存在)
|
||||
token_manager = get_token_manager()
|
||||
current_token = token_manager.get_token()
|
||||
logger.info(f"🔑 WebUI Access Token: {current_token}")
|
||||
logger.info("💡 请使用此 Token 登录 WebUI")
|
||||
|
||||
if mode == "development":
|
||||
return setup_dev_mode()
|
||||
else:
|
||||
return setup_production_mode()
|
||||
|
||||
|
||||
def setup_dev_mode() -> bool:
|
||||
"""设置开发模式 - 仅启用 CORS,前端自行启动"""
|
||||
from src.common.server import get_global_server
|
||||
from .logs_ws import router as logs_router
|
||||
|
||||
# 注册 WebSocket 日志路由(开发模式也需要)
|
||||
server = get_global_server()
|
||||
server.register_router(logs_router)
|
||||
logger.info("✅ WebSocket 日志推送路由已注册")
|
||||
|
||||
logger.info("📝 WebUI 开发模式已启用")
|
||||
logger.info("🌐 请手动启动前端开发服务器: cd webui && npm run dev")
|
||||
logger.info("💡 前端将运行在 http://localhost:7999")
|
||||
return True
|
||||
|
||||
|
||||
def setup_production_mode() -> bool:
|
||||
"""设置生产模式 - 挂载静态文件"""
|
||||
try:
|
||||
from src.common.server import get_global_server
|
||||
from starlette.responses import FileResponse
|
||||
from .logs_ws import router as logs_router
|
||||
import mimetypes
|
||||
|
||||
# 确保正确的 MIME 类型映射
|
||||
mimetypes.init()
|
||||
mimetypes.add_type('application/javascript', '.js')
|
||||
mimetypes.add_type('application/javascript', '.mjs')
|
||||
mimetypes.add_type('text/css', '.css')
|
||||
mimetypes.add_type('application/json', '.json')
|
||||
|
||||
server = get_global_server()
|
||||
|
||||
# 注册 WebSocket 日志路由
|
||||
server.register_router(logs_router)
|
||||
logger.info("✅ WebSocket 日志推送路由已注册")
|
||||
|
||||
base_dir = Path(__file__).parent.parent.parent
|
||||
static_path = base_dir / "webui" / "dist"
|
||||
|
||||
if not static_path.exists():
|
||||
logger.warning(f"❌ WebUI 静态文件目录不存在: {static_path}")
|
||||
logger.warning("💡 请先构建前端: cd webui && npm run build")
|
||||
return False
|
||||
|
||||
if not (static_path / "index.html").exists():
|
||||
logger.warning(f"❌ 未找到 index.html: {static_path / 'index.html'}")
|
||||
logger.warning("💡 请确认前端已正确构建")
|
||||
return False
|
||||
|
||||
# 处理 SPA 路由
|
||||
@server.app.get("/{full_path:path}")
|
||||
async def serve_spa(full_path: str):
|
||||
"""服务单页应用"""
|
||||
# API 路由不处理
|
||||
if full_path.startswith("api/"):
|
||||
return None
|
||||
|
||||
# 检查文件是否存在
|
||||
file_path = static_path / full_path
|
||||
if file_path.is_file():
|
||||
# 自动检测 MIME 类型
|
||||
media_type = mimetypes.guess_type(str(file_path))[0]
|
||||
return FileResponse(file_path, media_type=media_type)
|
||||
|
||||
# 返回 index.html(SPA 路由)
|
||||
return FileResponse(static_path / "index.html", media_type="text/html")
|
||||
|
||||
host = os.getenv("HOST", "127.0.0.1")
|
||||
port = os.getenv("PORT", "8000")
|
||||
logger.info("✅ WebUI 生产模式已挂载")
|
||||
logger.info(f"🌐 访问 http://{host}:{port} 查看 WebUI")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"挂载 WebUI 静态文件失败: {e}")
|
||||
return False
|
||||
@@ -1,4 +1,5 @@
|
||||
"""人物信息管理 API 路由"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Header, Query
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, List, Dict
|
||||
@@ -16,6 +17,7 @@ router = APIRouter(prefix="/person", tags=["Person"])
|
||||
|
||||
class PersonInfoResponse(BaseModel):
|
||||
"""人物信息响应"""
|
||||
|
||||
id: int
|
||||
is_known: bool
|
||||
person_id: str
|
||||
@@ -33,6 +35,7 @@ class PersonInfoResponse(BaseModel):
|
||||
|
||||
class PersonListResponse(BaseModel):
|
||||
"""人物列表响应"""
|
||||
|
||||
success: bool
|
||||
total: int
|
||||
page: int
|
||||
@@ -42,12 +45,14 @@ class PersonListResponse(BaseModel):
|
||||
|
||||
class PersonDetailResponse(BaseModel):
|
||||
"""人物详情响应"""
|
||||
|
||||
success: bool
|
||||
data: PersonInfoResponse
|
||||
|
||||
|
||||
class PersonUpdateRequest(BaseModel):
|
||||
"""人物信息更新请求"""
|
||||
|
||||
person_name: Optional[str] = None
|
||||
name_reason: Optional[str] = None
|
||||
nickname: Optional[str] = None
|
||||
@@ -57,6 +62,7 @@ class PersonUpdateRequest(BaseModel):
|
||||
|
||||
class PersonUpdateResponse(BaseModel):
|
||||
"""人物信息更新响应"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
data: Optional[PersonInfoResponse] = None
|
||||
@@ -64,21 +70,38 @@ class PersonUpdateResponse(BaseModel):
|
||||
|
||||
class PersonDeleteResponse(BaseModel):
|
||||
"""人物删除响应"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
|
||||
|
||||
class BatchDeleteRequest(BaseModel):
|
||||
"""批量删除请求"""
|
||||
|
||||
person_ids: List[str]
|
||||
|
||||
|
||||
class BatchDeleteResponse(BaseModel):
|
||||
"""批量删除响应"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
deleted_count: int
|
||||
failed_count: int
|
||||
failed_ids: List[str] = []
|
||||
|
||||
|
||||
def verify_auth_token(authorization: Optional[str]) -> bool:
|
||||
"""验证认证 Token"""
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
|
||||
token = authorization.replace("Bearer ", "")
|
||||
token_manager = get_token_manager()
|
||||
|
||||
|
||||
if not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="Token 无效或已过期")
|
||||
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@@ -118,11 +141,11 @@ async def get_person_list(
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
is_known: Optional[bool] = Query(None, description="是否已认识筛选"),
|
||||
platform: Optional[str] = Query(None, description="平台筛选"),
|
||||
authorization: Optional[str] = Header(None)
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
获取人物信息列表
|
||||
|
||||
|
||||
Args:
|
||||
page: 页码 (从 1 开始)
|
||||
page_size: 每页数量 (1-100)
|
||||
@@ -130,58 +153,50 @@ async def get_person_list(
|
||||
is_known: 是否已认识筛选
|
||||
platform: 平台筛选
|
||||
authorization: Authorization header
|
||||
|
||||
|
||||
Returns:
|
||||
人物信息列表
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
|
||||
|
||||
# 构建查询
|
||||
query = PersonInfo.select()
|
||||
|
||||
|
||||
# 搜索过滤
|
||||
if search:
|
||||
query = query.where(
|
||||
(PersonInfo.person_name.contains(search)) |
|
||||
(PersonInfo.nickname.contains(search)) |
|
||||
(PersonInfo.user_id.contains(search))
|
||||
(PersonInfo.person_name.contains(search))
|
||||
| (PersonInfo.nickname.contains(search))
|
||||
| (PersonInfo.user_id.contains(search))
|
||||
)
|
||||
|
||||
|
||||
# 已认识状态过滤
|
||||
if is_known is not None:
|
||||
query = query.where(PersonInfo.is_known == is_known)
|
||||
|
||||
|
||||
# 平台过滤
|
||||
if platform:
|
||||
query = query.where(PersonInfo.platform == platform)
|
||||
|
||||
|
||||
# 排序:最后更新时间倒序(NULL 值放在最后)
|
||||
# Peewee 不支持 nulls_last,使用 CASE WHEN 来实现
|
||||
from peewee import Case
|
||||
query = query.order_by(
|
||||
Case(None, [(PersonInfo.last_know.is_null(), 1)], 0),
|
||||
PersonInfo.last_know.desc()
|
||||
)
|
||||
|
||||
|
||||
query = query.order_by(Case(None, [(PersonInfo.last_know.is_null(), 1)], 0), PersonInfo.last_know.desc())
|
||||
|
||||
# 获取总数
|
||||
total = query.count()
|
||||
|
||||
|
||||
# 分页
|
||||
offset = (page - 1) * page_size
|
||||
persons = query.offset(offset).limit(page_size)
|
||||
|
||||
|
||||
# 转换为响应对象
|
||||
data = [person_to_response(person) for person in persons]
|
||||
|
||||
return PersonListResponse(
|
||||
success=True,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
data=data
|
||||
)
|
||||
|
||||
|
||||
return PersonListResponse(success=True, total=total, page=page, page_size=page_size, data=data)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -190,33 +205,27 @@ async def get_person_list(
|
||||
|
||||
|
||||
@router.get("/{person_id}", response_model=PersonDetailResponse)
|
||||
async def get_person_detail(
|
||||
person_id: str,
|
||||
authorization: Optional[str] = Header(None)
|
||||
):
|
||||
async def get_person_detail(person_id: str, authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
获取人物详细信息
|
||||
|
||||
|
||||
Args:
|
||||
person_id: 人物唯一 ID
|
||||
authorization: Authorization header
|
||||
|
||||
|
||||
Returns:
|
||||
人物详细信息
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
|
||||
|
||||
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
|
||||
|
||||
|
||||
if not person:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {person_id} 的人物信息")
|
||||
|
||||
return PersonDetailResponse(
|
||||
success=True,
|
||||
data=person_to_response(person)
|
||||
)
|
||||
|
||||
|
||||
return PersonDetailResponse(success=True, data=person_to_response(person))
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -225,53 +234,47 @@ async def get_person_detail(
|
||||
|
||||
|
||||
@router.patch("/{person_id}", response_model=PersonUpdateResponse)
|
||||
async def update_person(
|
||||
person_id: str,
|
||||
request: PersonUpdateRequest,
|
||||
authorization: Optional[str] = Header(None)
|
||||
):
|
||||
async def update_person(person_id: str, request: PersonUpdateRequest, authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
增量更新人物信息(只更新提供的字段)
|
||||
|
||||
|
||||
Args:
|
||||
person_id: 人物唯一 ID
|
||||
request: 更新请求(只包含需要更新的字段)
|
||||
authorization: Authorization header
|
||||
|
||||
|
||||
Returns:
|
||||
更新结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
|
||||
|
||||
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
|
||||
|
||||
|
||||
if not person:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {person_id} 的人物信息")
|
||||
|
||||
|
||||
# 只更新提供的字段
|
||||
update_data = request.model_dump(exclude_unset=True)
|
||||
|
||||
|
||||
if not update_data:
|
||||
raise HTTPException(status_code=400, detail="未提供任何需要更新的字段")
|
||||
|
||||
|
||||
# 更新最后修改时间
|
||||
update_data['last_know'] = time.time()
|
||||
|
||||
update_data["last_know"] = time.time()
|
||||
|
||||
# 执行更新
|
||||
for field, value in update_data.items():
|
||||
setattr(person, field, value)
|
||||
|
||||
|
||||
person.save()
|
||||
|
||||
|
||||
logger.info(f"人物信息已更新: {person_id}, 字段: {list(update_data.keys())}")
|
||||
|
||||
|
||||
return PersonUpdateResponse(
|
||||
success=True,
|
||||
message=f"成功更新 {len(update_data)} 个字段",
|
||||
data=person_to_response(person)
|
||||
success=True, message=f"成功更新 {len(update_data)} 个字段", data=person_to_response(person)
|
||||
)
|
||||
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -280,41 +283,35 @@ async def update_person(
|
||||
|
||||
|
||||
@router.delete("/{person_id}", response_model=PersonDeleteResponse)
|
||||
async def delete_person(
|
||||
person_id: str,
|
||||
authorization: Optional[str] = Header(None)
|
||||
):
|
||||
async def delete_person(person_id: str, authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
删除人物信息
|
||||
|
||||
|
||||
Args:
|
||||
person_id: 人物唯一 ID
|
||||
authorization: Authorization header
|
||||
|
||||
|
||||
Returns:
|
||||
删除结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
|
||||
|
||||
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
|
||||
|
||||
|
||||
if not person:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {person_id} 的人物信息")
|
||||
|
||||
|
||||
# 记录删除信息
|
||||
person_name = person.person_name or person.nickname or person.user_id
|
||||
|
||||
|
||||
# 执行删除
|
||||
person.delete_instance()
|
||||
|
||||
|
||||
logger.info(f"人物信息已删除: {person_id} ({person_name})")
|
||||
|
||||
return PersonDeleteResponse(
|
||||
success=True,
|
||||
message=f"成功删除人物信息: {person_name}"
|
||||
)
|
||||
|
||||
|
||||
return PersonDeleteResponse(success=True, message=f"成功删除人物信息: {person_name}")
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -323,43 +320,89 @@ async def delete_person(
|
||||
|
||||
|
||||
@router.get("/stats/summary")
|
||||
async def get_person_stats(
|
||||
authorization: Optional[str] = Header(None)
|
||||
):
|
||||
async def get_person_stats(authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
获取人物信息统计数据
|
||||
|
||||
|
||||
Args:
|
||||
authorization: Authorization header
|
||||
|
||||
|
||||
Returns:
|
||||
统计数据
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
|
||||
|
||||
total = PersonInfo.select().count()
|
||||
known = PersonInfo.select().where(PersonInfo.is_known).count()
|
||||
unknown = total - known
|
||||
|
||||
|
||||
# 按平台统计
|
||||
platforms = {}
|
||||
for person in PersonInfo.select(PersonInfo.platform):
|
||||
platform = person.platform
|
||||
platforms[platform] = platforms.get(platform, 0) + 1
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"data": {
|
||||
"total": total,
|
||||
"known": known,
|
||||
"unknown": unknown,
|
||||
"platforms": platforms
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
return {"success": True, "data": {"total": total, "known": known, "unknown": unknown, "platforms": platforms}}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"获取统计数据失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取统计数据失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post("/batch/delete", response_model=BatchDeleteResponse)
|
||||
async def batch_delete_persons(request: BatchDeleteRequest, authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
批量删除人物信息
|
||||
|
||||
Args:
|
||||
request: 包含person_ids列表的请求
|
||||
authorization: Authorization header
|
||||
|
||||
Returns:
|
||||
批量删除结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
|
||||
if not request.person_ids:
|
||||
raise HTTPException(status_code=400, detail="未提供要删除的人物ID")
|
||||
|
||||
deleted_count = 0
|
||||
failed_count = 0
|
||||
failed_ids = []
|
||||
|
||||
for person_id in request.person_ids:
|
||||
try:
|
||||
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
|
||||
if person:
|
||||
person.delete_instance()
|
||||
deleted_count += 1
|
||||
logger.info(f"批量删除: {person_id}")
|
||||
else:
|
||||
failed_count += 1
|
||||
failed_ids.append(person_id)
|
||||
except Exception as e:
|
||||
logger.error(f"删除 {person_id} 失败: {e}")
|
||||
failed_count += 1
|
||||
failed_ids.append(person_id)
|
||||
|
||||
message = f"成功删除 {deleted_count} 个人物"
|
||||
if failed_count > 0:
|
||||
message += f",{failed_count} 个失败"
|
||||
|
||||
return BatchDeleteResponse(
|
||||
success=True,
|
||||
message=message,
|
||||
deleted_count=deleted_count,
|
||||
failed_count=failed_count,
|
||||
failed_ids=failed_ids,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"批量删除人物信息失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"批量删除失败: {str(e)}") from e
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""WebSocket 插件加载进度推送模块"""
|
||||
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
||||
from typing import Set, Dict, Any
|
||||
import json
|
||||
@@ -22,7 +23,7 @@ current_progress: Dict[str, Any] = {
|
||||
"error": None,
|
||||
"plugin_id": None, # 当前操作的插件 ID
|
||||
"total_plugins": 0,
|
||||
"loaded_plugins": 0
|
||||
"loaded_plugins": 0,
|
||||
}
|
||||
|
||||
|
||||
@@ -30,20 +31,20 @@ async def broadcast_progress(progress_data: Dict[str, Any]):
|
||||
"""广播进度更新到所有连接的客户端"""
|
||||
global current_progress
|
||||
current_progress = progress_data.copy()
|
||||
|
||||
|
||||
if not active_connections:
|
||||
return
|
||||
|
||||
|
||||
message = json.dumps(progress_data, ensure_ascii=False)
|
||||
disconnected = set()
|
||||
|
||||
|
||||
for websocket in active_connections:
|
||||
try:
|
||||
await websocket.send_text(message)
|
||||
except Exception as e:
|
||||
logger.error(f"发送进度更新失败: {e}")
|
||||
disconnected.add(websocket)
|
||||
|
||||
|
||||
# 移除断开的连接
|
||||
for websocket in disconnected:
|
||||
active_connections.discard(websocket)
|
||||
@@ -57,10 +58,10 @@ async def update_progress(
|
||||
error: str = None,
|
||||
plugin_id: str = None,
|
||||
total_plugins: int = 0,
|
||||
loaded_plugins: int = 0
|
||||
loaded_plugins: int = 0,
|
||||
):
|
||||
"""更新并广播进度
|
||||
|
||||
|
||||
Args:
|
||||
stage: 阶段 (idle, loading, success, error)
|
||||
progress: 进度百分比 (0-100)
|
||||
@@ -80,9 +81,9 @@ async def update_progress(
|
||||
"plugin_id": plugin_id,
|
||||
"total_plugins": total_plugins,
|
||||
"loaded_plugins": loaded_plugins,
|
||||
"timestamp": asyncio.get_event_loop().time()
|
||||
"timestamp": asyncio.get_event_loop().time(),
|
||||
}
|
||||
|
||||
|
||||
await broadcast_progress(progress_data)
|
||||
logger.debug(f"进度更新: [{operation}] {stage} - {progress}% - {message}")
|
||||
|
||||
@@ -90,30 +91,30 @@ async def update_progress(
|
||||
@router.websocket("/ws/plugin-progress")
|
||||
async def websocket_plugin_progress(websocket: WebSocket):
|
||||
"""WebSocket 插件加载进度推送端点
|
||||
|
||||
|
||||
客户端连接后会立即收到当前进度状态
|
||||
"""
|
||||
await websocket.accept()
|
||||
active_connections.add(websocket)
|
||||
logger.info(f"📡 插件进度 WebSocket 客户端已连接,当前连接数: {len(active_connections)}")
|
||||
|
||||
|
||||
try:
|
||||
# 发送当前进度状态
|
||||
await websocket.send_text(json.dumps(current_progress, ensure_ascii=False))
|
||||
|
||||
|
||||
# 保持连接并处理客户端消息
|
||||
while True:
|
||||
try:
|
||||
data = await websocket.receive_text()
|
||||
|
||||
|
||||
# 处理客户端心跳
|
||||
if data == "ping":
|
||||
await websocket.send_text("pong")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理客户端消息时出错: {e}")
|
||||
break
|
||||
|
||||
|
||||
except WebSocketDisconnect:
|
||||
active_connections.discard(websocket)
|
||||
logger.info(f"📡 插件进度 WebSocket 客户端已断开,当前连接数: {len(active_connections)}")
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
97
src/webui/routers/system.py
Normal file
97
src/webui/routers/system.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""
|
||||
系统控制路由
|
||||
|
||||
提供系统重启、状态查询等功能
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from src.config.config import MMC_VERSION
|
||||
|
||||
router = APIRouter(prefix="/system", tags=["system"])
|
||||
|
||||
# 记录启动时间
|
||||
_start_time = time.time()
|
||||
|
||||
|
||||
class RestartResponse(BaseModel):
|
||||
"""重启响应"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
|
||||
|
||||
class StatusResponse(BaseModel):
|
||||
"""状态响应"""
|
||||
|
||||
running: bool
|
||||
uptime: float
|
||||
version: str
|
||||
start_time: str
|
||||
|
||||
|
||||
@router.post("/restart", response_model=RestartResponse)
|
||||
async def restart_maibot():
|
||||
"""
|
||||
重启麦麦主程序
|
||||
|
||||
使用 os.execv 重启当前进程,配置更改将在重启后生效。
|
||||
注意:此操作会使麦麦暂时离线。
|
||||
"""
|
||||
try:
|
||||
# 记录重启操作
|
||||
print(f"[{datetime.now()}] WebUI 触发重启操作")
|
||||
|
||||
# 使用 os.execv 重启当前进程
|
||||
# 这会替换当前进程,保持相同的 PID
|
||||
python = sys.executable
|
||||
args = [python] + sys.argv
|
||||
|
||||
# 返回成功响应(实际上这个响应可能不会发送,因为进程会立即重启)
|
||||
# 但我们仍然返回它以保持 API 一致性
|
||||
os.execv(python, args)
|
||||
|
||||
return RestartResponse(success=True, message="麦麦正在重启中...")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"重启失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/status", response_model=StatusResponse)
|
||||
async def get_maibot_status():
|
||||
"""
|
||||
获取麦麦运行状态
|
||||
|
||||
返回麦麦的运行状态、运行时长和版本信息。
|
||||
"""
|
||||
try:
|
||||
uptime = time.time() - _start_time
|
||||
|
||||
# 尝试获取版本信息(需要根据实际情况调整)
|
||||
version = MMC_VERSION # 可以从配置或常量中读取
|
||||
|
||||
return StatusResponse(
|
||||
running=True, uptime=uptime, version=version, start_time=datetime.fromtimestamp(_start_time).isoformat()
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"获取状态失败: {str(e)}") from e
|
||||
|
||||
|
||||
# 可选:添加更多系统控制功能
|
||||
|
||||
|
||||
@router.post("/reload-config")
|
||||
async def reload_config():
|
||||
"""
|
||||
热重载配置(不重启进程)
|
||||
|
||||
仅重新加载配置文件,某些配置可能需要重启才能生效。
|
||||
此功能需要在主程序中实现配置热重载逻辑。
|
||||
"""
|
||||
# 这里需要调用主程序的配置重载函数
|
||||
# 示例:await app_instance.reload_config()
|
||||
|
||||
return {"success": True, "message": "配置重载功能待实现"}
|
||||
@@ -1,4 +1,5 @@
|
||||
"""WebUI API 路由"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Header
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional
|
||||
@@ -11,6 +12,7 @@ from .expression_routes import router as expression_router
|
||||
from .emoji_routes import router as emoji_router
|
||||
from .plugin_routes import router as plugin_router
|
||||
from .plugin_progress_ws import get_progress_router
|
||||
from .routers.system import router as system_router
|
||||
|
||||
logger = get_logger("webui.api")
|
||||
|
||||
@@ -31,37 +33,65 @@ router.include_router(emoji_router)
|
||||
router.include_router(plugin_router)
|
||||
# 注册插件进度 WebSocket 路由
|
||||
router.include_router(get_progress_router())
|
||||
# 注册系统控制路由
|
||||
router.include_router(system_router)
|
||||
|
||||
|
||||
class TokenVerifyRequest(BaseModel):
|
||||
"""Token 验证请求"""
|
||||
|
||||
token: str = Field(..., description="访问令牌")
|
||||
|
||||
|
||||
class TokenVerifyResponse(BaseModel):
|
||||
"""Token 验证响应"""
|
||||
|
||||
valid: bool = Field(..., description="Token 是否有效")
|
||||
message: str = Field(..., description="验证结果消息")
|
||||
|
||||
|
||||
class TokenUpdateRequest(BaseModel):
|
||||
"""Token 更新请求"""
|
||||
|
||||
new_token: str = Field(..., description="新的访问令牌", min_length=10)
|
||||
|
||||
|
||||
class TokenUpdateResponse(BaseModel):
|
||||
"""Token 更新响应"""
|
||||
|
||||
success: bool = Field(..., description="是否更新成功")
|
||||
message: str = Field(..., description="更新结果消息")
|
||||
|
||||
|
||||
class TokenRegenerateResponse(BaseModel):
|
||||
"""Token 重新生成响应"""
|
||||
|
||||
success: bool = Field(..., description="是否生成成功")
|
||||
token: str = Field(..., description="新生成的令牌")
|
||||
message: str = Field(..., description="生成结果消息")
|
||||
|
||||
|
||||
class FirstSetupStatusResponse(BaseModel):
|
||||
"""首次配置状态响应"""
|
||||
|
||||
is_first_setup: bool = Field(..., description="是否为首次配置")
|
||||
message: str = Field(..., description="状态消息")
|
||||
|
||||
|
||||
class CompleteSetupResponse(BaseModel):
|
||||
"""完成配置响应"""
|
||||
|
||||
success: bool = Field(..., description="是否成功")
|
||||
message: str = Field(..., description="结果消息")
|
||||
|
||||
|
||||
class ResetSetupResponse(BaseModel):
|
||||
"""重置配置响应"""
|
||||
|
||||
success: bool = Field(..., description="是否成功")
|
||||
message: str = Field(..., description="结果消息")
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def health_check():
|
||||
"""健康检查"""
|
||||
@@ -72,44 +102,35 @@ async def health_check():
|
||||
async def verify_token(request: TokenVerifyRequest):
|
||||
"""
|
||||
验证访问令牌
|
||||
|
||||
|
||||
Args:
|
||||
request: 包含 token 的验证请求
|
||||
|
||||
|
||||
Returns:
|
||||
验证结果
|
||||
"""
|
||||
try:
|
||||
token_manager = get_token_manager()
|
||||
is_valid = token_manager.verify_token(request.token)
|
||||
|
||||
|
||||
if is_valid:
|
||||
return TokenVerifyResponse(
|
||||
valid=True,
|
||||
message="Token 验证成功"
|
||||
)
|
||||
return TokenVerifyResponse(valid=True, message="Token 验证成功")
|
||||
else:
|
||||
return TokenVerifyResponse(
|
||||
valid=False,
|
||||
message="Token 无效或已过期"
|
||||
)
|
||||
return TokenVerifyResponse(valid=False, message="Token 无效或已过期")
|
||||
except Exception as e:
|
||||
logger.error(f"Token 验证失败: {e}")
|
||||
raise HTTPException(status_code=500, detail="Token 验证失败") from e
|
||||
|
||||
|
||||
@router.post("/auth/update", response_model=TokenUpdateResponse)
|
||||
async def update_token(
|
||||
request: TokenUpdateRequest,
|
||||
authorization: Optional[str] = Header(None)
|
||||
):
|
||||
async def update_token(request: TokenUpdateRequest, authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
更新访问令牌(需要当前有效的 token)
|
||||
|
||||
|
||||
Args:
|
||||
request: 包含新 token 的更新请求
|
||||
authorization: Authorization header (Bearer token)
|
||||
|
||||
|
||||
Returns:
|
||||
更新结果
|
||||
"""
|
||||
@@ -117,20 +138,17 @@ async def update_token(
|
||||
# 验证当前 token
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
|
||||
current_token = authorization.replace("Bearer ", "")
|
||||
token_manager = get_token_manager()
|
||||
|
||||
|
||||
if not token_manager.verify_token(current_token):
|
||||
raise HTTPException(status_code=401, detail="当前 Token 无效")
|
||||
|
||||
|
||||
# 更新 token
|
||||
success, message = token_manager.update_token(request.new_token)
|
||||
|
||||
return TokenUpdateResponse(
|
||||
success=success,
|
||||
message=message
|
||||
)
|
||||
|
||||
return TokenUpdateResponse(success=success, message=message)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -142,10 +160,10 @@ async def update_token(
|
||||
async def regenerate_token(authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
重新生成访问令牌(需要当前有效的 token)
|
||||
|
||||
|
||||
Args:
|
||||
authorization: Authorization header (Bearer token)
|
||||
|
||||
|
||||
Returns:
|
||||
新生成的 token
|
||||
"""
|
||||
@@ -153,24 +171,118 @@ async def regenerate_token(authorization: Optional[str] = Header(None)):
|
||||
# 验证当前 token
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
|
||||
current_token = authorization.replace("Bearer ", "")
|
||||
token_manager = get_token_manager()
|
||||
|
||||
|
||||
if not token_manager.verify_token(current_token):
|
||||
raise HTTPException(status_code=401, detail="当前 Token 无效")
|
||||
|
||||
|
||||
# 重新生成 token
|
||||
new_token = token_manager.regenerate_token()
|
||||
|
||||
return TokenRegenerateResponse(
|
||||
success=True,
|
||||
token=new_token,
|
||||
message="Token 已重新生成"
|
||||
)
|
||||
|
||||
return TokenRegenerateResponse(success=True, token=new_token, message="Token 已重新生成")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Token 重新生成失败: {e}")
|
||||
raise HTTPException(status_code=500, detail="Token 重新生成失败") from e
|
||||
|
||||
|
||||
@router.get("/setup/status", response_model=FirstSetupStatusResponse)
|
||||
async def get_setup_status(authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
获取首次配置状态
|
||||
|
||||
Args:
|
||||
authorization: Authorization header (Bearer token)
|
||||
|
||||
Returns:
|
||||
首次配置状态
|
||||
"""
|
||||
try:
|
||||
# 验证 token
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
current_token = authorization.replace("Bearer ", "")
|
||||
token_manager = get_token_manager()
|
||||
|
||||
if not token_manager.verify_token(current_token):
|
||||
raise HTTPException(status_code=401, detail="Token 无效")
|
||||
|
||||
# 检查是否为首次配置
|
||||
is_first = token_manager.is_first_setup()
|
||||
|
||||
return FirstSetupStatusResponse(is_first_setup=is_first, message="首次配置" if is_first else "已完成配置")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"获取配置状态失败: {e}")
|
||||
raise HTTPException(status_code=500, detail="获取配置状态失败") from e
|
||||
|
||||
|
||||
@router.post("/setup/complete", response_model=CompleteSetupResponse)
|
||||
async def complete_setup(authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
标记首次配置完成
|
||||
|
||||
Args:
|
||||
authorization: Authorization header (Bearer token)
|
||||
|
||||
Returns:
|
||||
完成结果
|
||||
"""
|
||||
try:
|
||||
# 验证 token
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
current_token = authorization.replace("Bearer ", "")
|
||||
token_manager = get_token_manager()
|
||||
|
||||
if not token_manager.verify_token(current_token):
|
||||
raise HTTPException(status_code=401, detail="Token 无效")
|
||||
|
||||
# 标记配置完成
|
||||
success = token_manager.mark_setup_completed()
|
||||
|
||||
return CompleteSetupResponse(success=success, message="配置已完成" if success else "标记失败")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"标记配置完成失败: {e}")
|
||||
raise HTTPException(status_code=500, detail="标记配置完成失败") from e
|
||||
|
||||
|
||||
@router.post("/setup/reset", response_model=ResetSetupResponse)
|
||||
async def reset_setup(authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
重置首次配置状态,允许重新进入配置向导
|
||||
|
||||
Args:
|
||||
authorization: Authorization header (Bearer token)
|
||||
|
||||
Returns:
|
||||
重置结果
|
||||
"""
|
||||
try:
|
||||
# 验证 token
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
current_token = authorization.replace("Bearer ", "")
|
||||
token_manager = get_token_manager()
|
||||
|
||||
if not token_manager.verify_token(current_token):
|
||||
raise HTTPException(status_code=401, detail="Token 无效")
|
||||
|
||||
# 重置配置状态
|
||||
success = token_manager.reset_setup_status()
|
||||
|
||||
return ResetSetupResponse(success=success, message="配置状态已重置" if success else "重置失败")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"重置配置状态失败: {e}")
|
||||
raise HTTPException(status_code=500, detail="重置配置状态失败") from e
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
"""统计数据 API 路由"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Dict, Any, List
|
||||
from datetime import datetime, timedelta
|
||||
from collections import defaultdict
|
||||
from peewee import fn
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import LLMUsage, OnlineTime, Messages
|
||||
@@ -15,6 +16,7 @@ router = APIRouter(prefix="/statistics", tags=["statistics"])
|
||||
|
||||
class StatisticsSummary(BaseModel):
|
||||
"""统计数据摘要"""
|
||||
|
||||
total_requests: int = Field(0, description="总请求数")
|
||||
total_cost: float = Field(0.0, description="总花费")
|
||||
total_tokens: int = Field(0, description="总token数")
|
||||
@@ -28,6 +30,7 @@ class StatisticsSummary(BaseModel):
|
||||
|
||||
class ModelStatistics(BaseModel):
|
||||
"""模型统计"""
|
||||
|
||||
model_name: str
|
||||
request_count: int
|
||||
total_cost: float
|
||||
@@ -37,6 +40,7 @@ class ModelStatistics(BaseModel):
|
||||
|
||||
class TimeSeriesData(BaseModel):
|
||||
"""时间序列数据"""
|
||||
|
||||
timestamp: str
|
||||
requests: int = 0
|
||||
cost: float = 0.0
|
||||
@@ -45,6 +49,7 @@ class TimeSeriesData(BaseModel):
|
||||
|
||||
class DashboardData(BaseModel):
|
||||
"""仪表盘数据"""
|
||||
|
||||
summary: StatisticsSummary
|
||||
model_stats: List[ModelStatistics]
|
||||
hourly_data: List[TimeSeriesData]
|
||||
@@ -56,39 +61,39 @@ class DashboardData(BaseModel):
|
||||
async def get_dashboard_data(hours: int = 24):
|
||||
"""
|
||||
获取仪表盘统计数据
|
||||
|
||||
|
||||
Args:
|
||||
hours: 统计时间范围(小时),默认24小时
|
||||
|
||||
|
||||
Returns:
|
||||
仪表盘数据
|
||||
"""
|
||||
try:
|
||||
now = datetime.now()
|
||||
start_time = now - timedelta(hours=hours)
|
||||
|
||||
|
||||
# 获取摘要数据
|
||||
summary = await _get_summary_statistics(start_time, now)
|
||||
|
||||
|
||||
# 获取模型统计
|
||||
model_stats = await _get_model_statistics(start_time)
|
||||
|
||||
|
||||
# 获取小时级时间序列数据
|
||||
hourly_data = await _get_hourly_statistics(start_time, now)
|
||||
|
||||
|
||||
# 获取日级时间序列数据(最近7天)
|
||||
daily_start = now - timedelta(days=7)
|
||||
daily_data = await _get_daily_statistics(daily_start, now)
|
||||
|
||||
|
||||
# 获取最近活动
|
||||
recent_activity = await _get_recent_activity(limit=10)
|
||||
|
||||
|
||||
return DashboardData(
|
||||
summary=summary,
|
||||
model_stats=model_stats,
|
||||
hourly_data=hourly_data,
|
||||
daily_data=daily_data,
|
||||
recent_activity=recent_activity
|
||||
recent_activity=recent_activity,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"获取仪表盘数据失败: {e}")
|
||||
@@ -96,200 +101,176 @@ async def get_dashboard_data(hours: int = 24):
|
||||
|
||||
|
||||
async def _get_summary_statistics(start_time: datetime, end_time: datetime) -> StatisticsSummary:
|
||||
"""获取摘要统计数据"""
|
||||
"""获取摘要统计数据(优化:使用数据库聚合)"""
|
||||
summary = StatisticsSummary()
|
||||
|
||||
# 查询 LLM 使用记录
|
||||
llm_records = list(
|
||||
LLMUsage.select()
|
||||
.where(LLMUsage.timestamp >= start_time)
|
||||
.where(LLMUsage.timestamp <= end_time)
|
||||
)
|
||||
|
||||
total_time_cost = 0.0
|
||||
time_cost_count = 0
|
||||
|
||||
for record in llm_records:
|
||||
summary.total_requests += 1
|
||||
summary.total_cost += record.cost or 0.0
|
||||
summary.total_tokens += (record.prompt_tokens or 0) + (record.completion_tokens or 0)
|
||||
|
||||
if record.time_cost and record.time_cost > 0:
|
||||
total_time_cost += record.time_cost
|
||||
time_cost_count += 1
|
||||
|
||||
# 计算平均响应时间
|
||||
if time_cost_count > 0:
|
||||
summary.avg_response_time = total_time_cost / time_cost_count
|
||||
|
||||
# 查询在线时间
|
||||
|
||||
# 使用聚合查询替代全量加载
|
||||
query = LLMUsage.select(
|
||||
fn.COUNT(LLMUsage.id).alias("total_requests"),
|
||||
fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("total_cost"),
|
||||
fn.COALESCE(fn.SUM(LLMUsage.prompt_tokens + LLMUsage.completion_tokens), 0).alias("total_tokens"),
|
||||
fn.COALESCE(fn.AVG(LLMUsage.time_cost), 0).alias("avg_response_time"),
|
||||
).where((LLMUsage.timestamp >= start_time) & (LLMUsage.timestamp <= end_time))
|
||||
|
||||
result = query.dicts().get()
|
||||
summary.total_requests = result["total_requests"]
|
||||
summary.total_cost = result["total_cost"]
|
||||
summary.total_tokens = result["total_tokens"]
|
||||
summary.avg_response_time = result["avg_response_time"] or 0.0
|
||||
|
||||
# 查询在线时间 - 这个数据量通常不大,保留原逻辑
|
||||
online_records = list(
|
||||
OnlineTime.select()
|
||||
.where(
|
||||
(OnlineTime.start_timestamp >= start_time) |
|
||||
(OnlineTime.end_timestamp >= start_time)
|
||||
)
|
||||
OnlineTime.select().where((OnlineTime.start_timestamp >= start_time) | (OnlineTime.end_timestamp >= start_time))
|
||||
)
|
||||
|
||||
|
||||
for record in online_records:
|
||||
start = max(record.start_timestamp, start_time)
|
||||
end = min(record.end_timestamp, end_time)
|
||||
if end > start:
|
||||
summary.online_time += (end - start).total_seconds()
|
||||
|
||||
# 查询消息数量
|
||||
messages = list(
|
||||
Messages.select()
|
||||
.where(Messages.time >= start_time.timestamp())
|
||||
.where(Messages.time <= end_time.timestamp())
|
||||
|
||||
# 查询消息数量 - 使用聚合优化
|
||||
messages_query = Messages.select(fn.COUNT(Messages.id).alias("total")).where(
|
||||
(Messages.time >= start_time.timestamp()) & (Messages.time <= end_time.timestamp())
|
||||
)
|
||||
|
||||
summary.total_messages = len(messages)
|
||||
# 简单统计:如果 reply_to 不为空,则认为是回复
|
||||
summary.total_replies = len([m for m in messages if m.reply_to])
|
||||
|
||||
summary.total_messages = messages_query.scalar() or 0
|
||||
|
||||
# 统计回复数量
|
||||
replies_query = Messages.select(fn.COUNT(Messages.id).alias("total")).where(
|
||||
(Messages.time >= start_time.timestamp())
|
||||
& (Messages.time <= end_time.timestamp())
|
||||
& (Messages.reply_to.is_null(False))
|
||||
)
|
||||
summary.total_replies = replies_query.scalar() or 0
|
||||
|
||||
# 计算派生指标
|
||||
if summary.online_time > 0:
|
||||
online_hours = summary.online_time / 3600.0
|
||||
summary.cost_per_hour = summary.total_cost / online_hours
|
||||
summary.tokens_per_hour = summary.total_tokens / online_hours
|
||||
|
||||
|
||||
return summary
|
||||
|
||||
|
||||
async def _get_model_statistics(start_time: datetime) -> List[ModelStatistics]:
|
||||
"""获取模型统计数据"""
|
||||
model_data = defaultdict(lambda: {
|
||||
'request_count': 0,
|
||||
'total_cost': 0.0,
|
||||
'total_tokens': 0,
|
||||
'time_costs': []
|
||||
})
|
||||
|
||||
records = list(
|
||||
LLMUsage.select()
|
||||
"""获取模型统计数据(优化:使用数据库聚合和分组)"""
|
||||
# 使用GROUP BY聚合,避免全量加载
|
||||
query = (
|
||||
LLMUsage.select(
|
||||
fn.COALESCE(LLMUsage.model_assign_name, LLMUsage.model_name, "unknown").alias("model_name"),
|
||||
fn.COUNT(LLMUsage.id).alias("request_count"),
|
||||
fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("total_cost"),
|
||||
fn.COALESCE(fn.SUM(LLMUsage.prompt_tokens + LLMUsage.completion_tokens), 0).alias("total_tokens"),
|
||||
fn.COALESCE(fn.AVG(LLMUsage.time_cost), 0).alias("avg_response_time"),
|
||||
)
|
||||
.where(LLMUsage.timestamp >= start_time)
|
||||
.group_by(fn.COALESCE(LLMUsage.model_assign_name, LLMUsage.model_name, "unknown"))
|
||||
.order_by(fn.COUNT(LLMUsage.id).desc())
|
||||
.limit(10) # 只取前10个
|
||||
)
|
||||
|
||||
for record in records:
|
||||
model_name = record.model_assign_name or record.model_name or "unknown"
|
||||
model_data[model_name]['request_count'] += 1
|
||||
model_data[model_name]['total_cost'] += record.cost or 0.0
|
||||
model_data[model_name]['total_tokens'] += (record.prompt_tokens or 0) + (record.completion_tokens or 0)
|
||||
|
||||
if record.time_cost and record.time_cost > 0:
|
||||
model_data[model_name]['time_costs'].append(record.time_cost)
|
||||
|
||||
# 转换为列表并排序
|
||||
|
||||
result = []
|
||||
for model_name, data in model_data.items():
|
||||
avg_time = sum(data['time_costs']) / len(data['time_costs']) if data['time_costs'] else 0.0
|
||||
result.append(ModelStatistics(
|
||||
model_name=model_name,
|
||||
request_count=data['request_count'],
|
||||
total_cost=data['total_cost'],
|
||||
total_tokens=data['total_tokens'],
|
||||
avg_response_time=avg_time
|
||||
))
|
||||
|
||||
# 按请求数排序
|
||||
result.sort(key=lambda x: x.request_count, reverse=True)
|
||||
return result[:10] # 返回前10个
|
||||
for row in query.dicts():
|
||||
result.append(
|
||||
ModelStatistics(
|
||||
model_name=row["model_name"],
|
||||
request_count=row["request_count"],
|
||||
total_cost=row["total_cost"],
|
||||
total_tokens=row["total_tokens"],
|
||||
avg_response_time=row["avg_response_time"] or 0.0,
|
||||
)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def _get_hourly_statistics(start_time: datetime, end_time: datetime) -> List[TimeSeriesData]:
|
||||
"""获取小时级统计数据"""
|
||||
# 创建小时桶
|
||||
hourly_buckets = defaultdict(lambda: {'requests': 0, 'cost': 0.0, 'tokens': 0})
|
||||
|
||||
records = list(
|
||||
LLMUsage.select()
|
||||
.where(LLMUsage.timestamp >= start_time)
|
||||
.where(LLMUsage.timestamp <= end_time)
|
||||
"""获取小时级统计数据(优化:使用数据库聚合)"""
|
||||
# SQLite的日期时间函数进行小时分组
|
||||
# 使用strftime将timestamp格式化为小时级别
|
||||
query = (
|
||||
LLMUsage.select(
|
||||
fn.strftime("%Y-%m-%dT%H:00:00", LLMUsage.timestamp).alias("hour"),
|
||||
fn.COUNT(LLMUsage.id).alias("requests"),
|
||||
fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("cost"),
|
||||
fn.COALESCE(fn.SUM(LLMUsage.prompt_tokens + LLMUsage.completion_tokens), 0).alias("tokens"),
|
||||
)
|
||||
.where((LLMUsage.timestamp >= start_time) & (LLMUsage.timestamp <= end_time))
|
||||
.group_by(fn.strftime("%Y-%m-%dT%H:00:00", LLMUsage.timestamp))
|
||||
)
|
||||
|
||||
for record in records:
|
||||
# 获取小时键(去掉分钟和秒)
|
||||
hour_key = record.timestamp.replace(minute=0, second=0, microsecond=0)
|
||||
hour_str = hour_key.isoformat()
|
||||
|
||||
hourly_buckets[hour_str]['requests'] += 1
|
||||
hourly_buckets[hour_str]['cost'] += record.cost or 0.0
|
||||
hourly_buckets[hour_str]['tokens'] += (record.prompt_tokens or 0) + (record.completion_tokens or 0)
|
||||
|
||||
|
||||
# 转换为字典以快速查找
|
||||
data_dict = {row["hour"]: row for row in query.dicts()}
|
||||
|
||||
# 填充所有小时(包括没有数据的)
|
||||
result = []
|
||||
current = start_time.replace(minute=0, second=0, microsecond=0)
|
||||
while current <= end_time:
|
||||
hour_str = current.isoformat()
|
||||
data = hourly_buckets.get(hour_str, {'requests': 0, 'cost': 0.0, 'tokens': 0})
|
||||
result.append(TimeSeriesData(
|
||||
timestamp=hour_str,
|
||||
requests=data['requests'],
|
||||
cost=data['cost'],
|
||||
tokens=data['tokens']
|
||||
))
|
||||
hour_str = current.strftime("%Y-%m-%dT%H:00:00")
|
||||
if hour_str in data_dict:
|
||||
row = data_dict[hour_str]
|
||||
result.append(
|
||||
TimeSeriesData(timestamp=hour_str, requests=row["requests"], cost=row["cost"], tokens=row["tokens"])
|
||||
)
|
||||
else:
|
||||
result.append(TimeSeriesData(timestamp=hour_str, requests=0, cost=0.0, tokens=0))
|
||||
current += timedelta(hours=1)
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def _get_daily_statistics(start_time: datetime, end_time: datetime) -> List[TimeSeriesData]:
|
||||
"""获取日级统计数据"""
|
||||
daily_buckets = defaultdict(lambda: {'requests': 0, 'cost': 0.0, 'tokens': 0})
|
||||
|
||||
records = list(
|
||||
LLMUsage.select()
|
||||
.where(LLMUsage.timestamp >= start_time)
|
||||
.where(LLMUsage.timestamp <= end_time)
|
||||
"""获取日级统计数据(优化:使用数据库聚合)"""
|
||||
# 使用strftime按日期分组
|
||||
query = (
|
||||
LLMUsage.select(
|
||||
fn.strftime("%Y-%m-%dT00:00:00", LLMUsage.timestamp).alias("day"),
|
||||
fn.COUNT(LLMUsage.id).alias("requests"),
|
||||
fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("cost"),
|
||||
fn.COALESCE(fn.SUM(LLMUsage.prompt_tokens + LLMUsage.completion_tokens), 0).alias("tokens"),
|
||||
)
|
||||
.where((LLMUsage.timestamp >= start_time) & (LLMUsage.timestamp <= end_time))
|
||||
.group_by(fn.strftime("%Y-%m-%dT00:00:00", LLMUsage.timestamp))
|
||||
)
|
||||
|
||||
for record in records:
|
||||
# 获取日期键
|
||||
day_key = record.timestamp.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
day_str = day_key.isoformat()
|
||||
|
||||
daily_buckets[day_str]['requests'] += 1
|
||||
daily_buckets[day_str]['cost'] += record.cost or 0.0
|
||||
daily_buckets[day_str]['tokens'] += (record.prompt_tokens or 0) + (record.completion_tokens or 0)
|
||||
|
||||
|
||||
# 转换为字典
|
||||
data_dict = {row["day"]: row for row in query.dicts()}
|
||||
|
||||
# 填充所有天
|
||||
result = []
|
||||
current = start_time.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
while current <= end_time:
|
||||
day_str = current.isoformat()
|
||||
data = daily_buckets.get(day_str, {'requests': 0, 'cost': 0.0, 'tokens': 0})
|
||||
result.append(TimeSeriesData(
|
||||
timestamp=day_str,
|
||||
requests=data['requests'],
|
||||
cost=data['cost'],
|
||||
tokens=data['tokens']
|
||||
))
|
||||
day_str = current.strftime("%Y-%m-%dT00:00:00")
|
||||
if day_str in data_dict:
|
||||
row = data_dict[day_str]
|
||||
result.append(
|
||||
TimeSeriesData(timestamp=day_str, requests=row["requests"], cost=row["cost"], tokens=row["tokens"])
|
||||
)
|
||||
else:
|
||||
result.append(TimeSeriesData(timestamp=day_str, requests=0, cost=0.0, tokens=0))
|
||||
current += timedelta(days=1)
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def _get_recent_activity(limit: int = 10) -> List[Dict[str, Any]]:
|
||||
"""获取最近活动"""
|
||||
records = list(
|
||||
LLMUsage.select()
|
||||
.order_by(LLMUsage.timestamp.desc())
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
records = list(LLMUsage.select().order_by(LLMUsage.timestamp.desc()).limit(limit))
|
||||
|
||||
activities = []
|
||||
for record in records:
|
||||
activities.append({
|
||||
'timestamp': record.timestamp.isoformat(),
|
||||
'model': record.model_assign_name or record.model_name,
|
||||
'request_type': record.request_type,
|
||||
'tokens': (record.prompt_tokens or 0) + (record.completion_tokens or 0),
|
||||
'cost': record.cost or 0.0,
|
||||
'time_cost': record.time_cost or 0.0,
|
||||
'status': record.status
|
||||
})
|
||||
|
||||
activities.append(
|
||||
{
|
||||
"timestamp": record.timestamp.isoformat(),
|
||||
"model": record.model_assign_name or record.model_name,
|
||||
"request_type": record.request_type,
|
||||
"tokens": (record.prompt_tokens or 0) + (record.completion_tokens or 0),
|
||||
"cost": record.cost or 0.0,
|
||||
"time_cost": record.time_cost or 0.0,
|
||||
"status": record.status,
|
||||
}
|
||||
)
|
||||
|
||||
return activities
|
||||
|
||||
|
||||
@@ -297,7 +278,7 @@ async def _get_recent_activity(limit: int = 10) -> List[Dict[str, Any]]:
|
||||
async def get_summary(hours: int = 24):
|
||||
"""
|
||||
获取统计摘要
|
||||
|
||||
|
||||
Args:
|
||||
hours: 统计时间范围(小时)
|
||||
"""
|
||||
@@ -315,7 +296,7 @@ async def get_summary(hours: int = 24):
|
||||
async def get_model_stats(hours: int = 24):
|
||||
"""
|
||||
获取模型统计
|
||||
|
||||
|
||||
Args:
|
||||
hours: 统计时间范围(小时)
|
||||
"""
|
||||
|
||||
@@ -19,7 +19,7 @@ class TokenManager:
|
||||
def __init__(self, config_path: Optional[Path] = None):
|
||||
"""
|
||||
初始化 Token 管理器
|
||||
|
||||
|
||||
Args:
|
||||
config_path: 配置文件路径,默认为项目根目录的 data/webui.json
|
||||
"""
|
||||
@@ -27,10 +27,10 @@ class TokenManager:
|
||||
# 获取项目根目录 (src/webui -> src -> 根目录)
|
||||
project_root = Path(__file__).parent.parent.parent
|
||||
config_path = project_root / "data" / "webui.json"
|
||||
|
||||
|
||||
self.config_path = config_path
|
||||
self.config_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
# 确保配置文件存在并包含有效的 token
|
||||
self._ensure_config()
|
||||
|
||||
@@ -75,21 +75,23 @@ class TokenManager:
|
||||
"""生成新的 64 位随机 token"""
|
||||
# 生成 64 位十六进制字符串 (32 字节 = 64 hex 字符)
|
||||
token = secrets.token_hex(32)
|
||||
|
||||
|
||||
config = {
|
||||
"access_token": token,
|
||||
"created_at": self._get_current_timestamp(),
|
||||
"updated_at": self._get_current_timestamp()
|
||||
"updated_at": self._get_current_timestamp(),
|
||||
"first_setup_completed": False, # 标记首次配置未完成
|
||||
}
|
||||
|
||||
|
||||
self._save_config(config)
|
||||
logger.info(f"新的 WebUI Token 已生成: {token[:8]}...")
|
||||
|
||||
|
||||
return token
|
||||
|
||||
def _get_current_timestamp(self) -> str:
|
||||
"""获取当前时间戳字符串"""
|
||||
from datetime import datetime
|
||||
|
||||
return datetime.now().isoformat()
|
||||
|
||||
def get_token(self) -> str:
|
||||
@@ -100,38 +102,38 @@ class TokenManager:
|
||||
def verify_token(self, token: str) -> bool:
|
||||
"""
|
||||
验证 token 是否有效
|
||||
|
||||
|
||||
Args:
|
||||
token: 待验证的 token
|
||||
|
||||
|
||||
Returns:
|
||||
bool: token 是否有效
|
||||
"""
|
||||
if not token:
|
||||
return False
|
||||
|
||||
|
||||
current_token = self.get_token()
|
||||
if not current_token:
|
||||
logger.error("系统中没有有效的 token")
|
||||
return False
|
||||
|
||||
|
||||
# 使用 secrets.compare_digest 防止时序攻击
|
||||
is_valid = secrets.compare_digest(token, current_token)
|
||||
|
||||
|
||||
if is_valid:
|
||||
logger.debug("Token 验证成功")
|
||||
else:
|
||||
logger.warning("Token 验证失败")
|
||||
|
||||
|
||||
return is_valid
|
||||
|
||||
def update_token(self, new_token: str) -> tuple[bool, str]:
|
||||
"""
|
||||
更新 token
|
||||
|
||||
|
||||
Args:
|
||||
new_token: 新的 token (最少 10 位,必须包含大小写字母和特殊符号)
|
||||
|
||||
|
||||
Returns:
|
||||
tuple[bool, str]: (是否更新成功, 错误消息)
|
||||
"""
|
||||
@@ -140,17 +142,17 @@ class TokenManager:
|
||||
if not is_valid:
|
||||
logger.error(f"Token 格式无效: {error_msg}")
|
||||
return False, error_msg
|
||||
|
||||
|
||||
try:
|
||||
config = self._load_config()
|
||||
old_token = config.get("access_token", "")[:8]
|
||||
|
||||
|
||||
config["access_token"] = new_token
|
||||
config["updated_at"] = self._get_current_timestamp()
|
||||
|
||||
|
||||
self._save_config(config)
|
||||
logger.info(f"Token 已更新: {old_token}... -> {new_token[:8]}...")
|
||||
|
||||
|
||||
return True, "Token 更新成功"
|
||||
except Exception as e:
|
||||
logger.error(f"更新 Token 失败: {e}")
|
||||
@@ -159,7 +161,7 @@ class TokenManager:
|
||||
def regenerate_token(self) -> str:
|
||||
"""
|
||||
重新生成 token
|
||||
|
||||
|
||||
Returns:
|
||||
str: 新生成的 token
|
||||
"""
|
||||
@@ -169,20 +171,20 @@ class TokenManager:
|
||||
def _validate_token_format(self, token: str) -> bool:
|
||||
"""
|
||||
验证 token 格式是否正确(旧的 64 位十六进制验证,保留用于系统生成的 token)
|
||||
|
||||
|
||||
Args:
|
||||
token: 待验证的 token
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 格式是否正确
|
||||
"""
|
||||
if not token or not isinstance(token, str):
|
||||
return False
|
||||
|
||||
|
||||
# 必须是 64 位十六进制字符串
|
||||
if len(token) != 64:
|
||||
return False
|
||||
|
||||
|
||||
# 验证是否为有效的十六进制字符串
|
||||
try:
|
||||
int(token, 16)
|
||||
@@ -193,44 +195,91 @@ class TokenManager:
|
||||
def _validate_custom_token(self, token: str) -> tuple[bool, str]:
|
||||
"""
|
||||
验证自定义 token 格式
|
||||
|
||||
|
||||
要求:
|
||||
- 最少 10 位
|
||||
- 包含大写字母
|
||||
- 包含小写字母
|
||||
- 包含特殊符号
|
||||
|
||||
|
||||
Args:
|
||||
token: 待验证的 token
|
||||
|
||||
|
||||
Returns:
|
||||
tuple[bool, str]: (是否有效, 错误消息)
|
||||
"""
|
||||
if not token or not isinstance(token, str):
|
||||
return False, "Token 不能为空"
|
||||
|
||||
|
||||
# 检查长度
|
||||
if len(token) < 10:
|
||||
return False, "Token 长度至少为 10 位"
|
||||
|
||||
|
||||
# 检查是否包含大写字母
|
||||
has_upper = any(c.isupper() for c in token)
|
||||
if not has_upper:
|
||||
return False, "Token 必须包含大写字母"
|
||||
|
||||
|
||||
# 检查是否包含小写字母
|
||||
has_lower = any(c.islower() for c in token)
|
||||
if not has_lower:
|
||||
return False, "Token 必须包含小写字母"
|
||||
|
||||
|
||||
# 检查是否包含特殊符号
|
||||
special_chars = "!@#$%^&*()_+-=[]{}|;:,.<>?/"
|
||||
has_special = any(c in special_chars for c in token)
|
||||
if not has_special:
|
||||
return False, f"Token 必须包含特殊符号 ({special_chars})"
|
||||
|
||||
|
||||
return True, "Token 格式正确"
|
||||
|
||||
def is_first_setup(self) -> bool:
|
||||
"""
|
||||
检查是否为首次配置
|
||||
|
||||
Returns:
|
||||
bool: 是否为首次配置
|
||||
"""
|
||||
config = self._load_config()
|
||||
return not config.get("first_setup_completed", False)
|
||||
|
||||
def mark_setup_completed(self) -> bool:
|
||||
"""
|
||||
标记首次配置已完成
|
||||
|
||||
Returns:
|
||||
bool: 是否标记成功
|
||||
"""
|
||||
try:
|
||||
config = self._load_config()
|
||||
config["first_setup_completed"] = True
|
||||
config["setup_completed_at"] = self._get_current_timestamp()
|
||||
self._save_config(config)
|
||||
logger.info("首次配置已标记为完成")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"标记首次配置完成失败: {e}")
|
||||
return False
|
||||
|
||||
def reset_setup_status(self) -> bool:
|
||||
"""
|
||||
重置首次配置状态,允许重新进入配置向导
|
||||
|
||||
Returns:
|
||||
bool: 是否重置成功
|
||||
"""
|
||||
try:
|
||||
config = self._load_config()
|
||||
config["first_setup_completed"] = False
|
||||
if "setup_completed_at" in config:
|
||||
del config["setup_completed_at"]
|
||||
self._save_config(config)
|
||||
logger.info("首次配置状态已重置")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"重置首次配置状态失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
# 全局单例
|
||||
_token_manager_instance: Optional[TokenManager] = None
|
||||
|
||||
148
src/webui/webui_server.py
Normal file
148
src/webui/webui_server.py
Normal file
@@ -0,0 +1,148 @@
|
||||
"""独立的 WebUI 服务器 - 运行在 0.0.0.0:8001"""
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
import mimetypes
|
||||
from pathlib import Path
|
||||
from fastapi import FastAPI
|
||||
from fastapi.responses import FileResponse
|
||||
from uvicorn import Config, Server as UvicornServer
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("webui_server")
|
||||
|
||||
|
||||
class WebUIServer:
|
||||
"""独立的 WebUI 服务器"""
|
||||
|
||||
def __init__(self, host: str = "0.0.0.0", port: int = 8001):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.app = FastAPI(title="MaiBot WebUI")
|
||||
self._server = None
|
||||
|
||||
# 显示 Access Token
|
||||
self._show_access_token()
|
||||
|
||||
# 重要:先注册 API 路由,再设置静态文件
|
||||
self._register_api_routes()
|
||||
self._setup_static_files()
|
||||
|
||||
def _show_access_token(self):
|
||||
"""显示 WebUI Access Token"""
|
||||
try:
|
||||
from src.webui.token_manager import get_token_manager
|
||||
|
||||
token_manager = get_token_manager()
|
||||
current_token = token_manager.get_token()
|
||||
logger.info(f"🔑 WebUI Access Token: {current_token}")
|
||||
logger.info("💡 请使用此 Token 登录 WebUI")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 获取 Access Token 失败: {e}")
|
||||
|
||||
def _setup_static_files(self):
|
||||
"""设置静态文件服务"""
|
||||
# 确保正确的 MIME 类型映射
|
||||
mimetypes.init()
|
||||
mimetypes.add_type("application/javascript", ".js")
|
||||
mimetypes.add_type("application/javascript", ".mjs")
|
||||
mimetypes.add_type("text/css", ".css")
|
||||
mimetypes.add_type("application/json", ".json")
|
||||
|
||||
base_dir = Path(__file__).parent.parent.parent
|
||||
static_path = base_dir / "webui" / "dist"
|
||||
|
||||
if not static_path.exists():
|
||||
logger.warning(f"❌ WebUI 静态文件目录不存在: {static_path}")
|
||||
logger.warning("💡 请先构建前端: cd webui && npm run build")
|
||||
return
|
||||
|
||||
if not (static_path / "index.html").exists():
|
||||
logger.warning(f"❌ 未找到 index.html: {static_path / 'index.html'}")
|
||||
logger.warning("💡 请确认前端已正确构建")
|
||||
return
|
||||
|
||||
# 处理 SPA 路由 - 注意:这个路由优先级最低
|
||||
@self.app.get("/{full_path:path}", include_in_schema=False)
|
||||
async def serve_spa(full_path: str):
|
||||
"""服务单页应用 - 只处理非 API 请求"""
|
||||
# 如果是根路径,直接返回 index.html
|
||||
if not full_path or full_path == "/":
|
||||
return FileResponse(static_path / "index.html", media_type="text/html")
|
||||
|
||||
# 检查是否是静态文件
|
||||
file_path = static_path / full_path
|
||||
if file_path.is_file() and file_path.exists():
|
||||
# 自动检测 MIME 类型
|
||||
media_type = mimetypes.guess_type(str(file_path))[0]
|
||||
return FileResponse(file_path, media_type=media_type)
|
||||
|
||||
# 其他路径返回 index.html(SPA 路由)
|
||||
return FileResponse(static_path / "index.html", media_type="text/html")
|
||||
|
||||
logger.info(f"✅ WebUI 静态文件服务已配置: {static_path}")
|
||||
|
||||
def _register_api_routes(self):
|
||||
"""注册所有 WebUI API 路由"""
|
||||
try:
|
||||
# 导入所有 WebUI 路由
|
||||
from src.webui.routes import router as webui_router
|
||||
from src.webui.logs_ws import router as logs_router
|
||||
|
||||
# 注册路由
|
||||
self.app.include_router(webui_router)
|
||||
self.app.include_router(logs_router)
|
||||
|
||||
logger.info("✅ WebUI API 路由已注册")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 注册 WebUI API 路由失败: {e}")
|
||||
|
||||
async def start(self):
|
||||
"""启动服务器"""
|
||||
config = Config(
|
||||
app=self.app,
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
log_config=None,
|
||||
access_log=False,
|
||||
)
|
||||
self._server = UvicornServer(config=config)
|
||||
|
||||
logger.info("🌐 WebUI 服务器启动中...")
|
||||
logger.info(f"🌐 访问地址: http://{self.host}:{self.port}")
|
||||
|
||||
try:
|
||||
await self._server.serve()
|
||||
except Exception as e:
|
||||
logger.error(f"❌ WebUI 服务器运行错误: {e}")
|
||||
raise
|
||||
|
||||
async def shutdown(self):
|
||||
"""关闭服务器"""
|
||||
if self._server:
|
||||
logger.info("正在关闭 WebUI 服务器...")
|
||||
self._server.should_exit = True
|
||||
try:
|
||||
await asyncio.wait_for(self._server.shutdown(), timeout=3.0)
|
||||
logger.info("✅ WebUI 服务器已关闭")
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("⚠️ WebUI 服务器关闭超时")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ WebUI 服务器关闭失败: {e}")
|
||||
finally:
|
||||
self._server = None
|
||||
|
||||
|
||||
# 全局 WebUI 服务器实例
|
||||
_webui_server = None
|
||||
|
||||
|
||||
def get_webui_server() -> WebUIServer:
|
||||
"""获取全局 WebUI 服务器实例"""
|
||||
global _webui_server
|
||||
if _webui_server is None:
|
||||
# 从环境变量读取配置
|
||||
host = os.getenv("WEBUI_HOST", "0.0.0.0")
|
||||
port = int(os.getenv("WEBUI_PORT", "8001"))
|
||||
_webui_server = WebUIServer(host=host, port=port)
|
||||
return _webui_server
|
||||
@@ -1,5 +1,5 @@
|
||||
[inner]
|
||||
version = "6.21.6"
|
||||
version = "6.21.8"
|
||||
|
||||
#----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读----
|
||||
#如果你想要修改配置文件,请递增version的值
|
||||
@@ -211,6 +211,9 @@ show_prompt = false # 是否显示prompt
|
||||
show_replyer_prompt = false # 是否显示回复器prompt
|
||||
show_replyer_reasoning = false # 是否显示回复器推理
|
||||
show_jargon_prompt = false # 是否显示jargon相关提示词
|
||||
show_memory_prompt = false # 是否显示记忆检索相关提示词
|
||||
show_planner_prompt = false # 是否显示planner的prompt和原始返回结果
|
||||
show_lpmm_paragraph = false # 是否显示lpmm找到的相关文段日志
|
||||
|
||||
[maim_message]
|
||||
auth_token = [] # 认证令牌,用于API验证,为空则不启用验证
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
# 麦麦主程序配置
|
||||
HOST=127.0.0.1
|
||||
PORT=8000
|
||||
|
||||
# WebUI 配置
|
||||
# WEBUI_ENABLED=true
|
||||
# WEBUI_MODE=development # 开发模式(需手动启动前端: cd webui && npm run dev,端口 7999)
|
||||
# WEBUI_MODE=production # 生产模式(需先构建前端: cd webui && npm run build)
|
||||
# WebUI 独立服务器配置
|
||||
WEBUI_ENABLED=true
|
||||
WEBUI_MODE=production # 模式: development(开发) 或 production(生产)
|
||||
WEBUI_HOST=0.0.0.0 # WebUI 服务器监听地址
|
||||
WEBUI_PORT=8001 # WebUI 服务器端口
|
||||
BIN
webui/dist/assets/KaTeX_AMS-Regular-BQhdFMY1.woff2
vendored
Normal file
BIN
webui/dist/assets/KaTeX_AMS-Regular-BQhdFMY1.woff2
vendored
Normal file
Binary file not shown.
BIN
webui/dist/assets/KaTeX_AMS-Regular-DMm9YOAa.woff
vendored
Normal file
BIN
webui/dist/assets/KaTeX_AMS-Regular-DMm9YOAa.woff
vendored
Normal file
Binary file not shown.
BIN
webui/dist/assets/KaTeX_AMS-Regular-DRggAlZN.ttf
vendored
Normal file
BIN
webui/dist/assets/KaTeX_AMS-Regular-DRggAlZN.ttf
vendored
Normal file
Binary file not shown.
BIN
webui/dist/assets/KaTeX_Caligraphic-Bold-ATXxdsX0.ttf
vendored
Normal file
BIN
webui/dist/assets/KaTeX_Caligraphic-Bold-ATXxdsX0.ttf
vendored
Normal file
Binary file not shown.
BIN
webui/dist/assets/KaTeX_Caligraphic-Bold-BEiXGLvX.woff
vendored
Normal file
BIN
webui/dist/assets/KaTeX_Caligraphic-Bold-BEiXGLvX.woff
vendored
Normal file
Binary file not shown.
BIN
webui/dist/assets/KaTeX_Caligraphic-Bold-Dq_IR9rO.woff2
vendored
Normal file
BIN
webui/dist/assets/KaTeX_Caligraphic-Bold-Dq_IR9rO.woff2
vendored
Normal file
Binary file not shown.
BIN
webui/dist/assets/KaTeX_Caligraphic-Regular-CTRA-rTL.woff
vendored
Normal file
BIN
webui/dist/assets/KaTeX_Caligraphic-Regular-CTRA-rTL.woff
vendored
Normal file
Binary file not shown.
BIN
webui/dist/assets/KaTeX_Caligraphic-Regular-Di6jR-x-.woff2
vendored
Normal file
BIN
webui/dist/assets/KaTeX_Caligraphic-Regular-Di6jR-x-.woff2
vendored
Normal file
Binary file not shown.
BIN
webui/dist/assets/KaTeX_Caligraphic-Regular-wX97UBjC.ttf
vendored
Normal file
BIN
webui/dist/assets/KaTeX_Caligraphic-Regular-wX97UBjC.ttf
vendored
Normal file
Binary file not shown.
BIN
webui/dist/assets/KaTeX_Fraktur-Bold-BdnERNNW.ttf
vendored
Normal file
BIN
webui/dist/assets/KaTeX_Fraktur-Bold-BdnERNNW.ttf
vendored
Normal file
Binary file not shown.
BIN
webui/dist/assets/KaTeX_Fraktur-Bold-BsDP51OF.woff
vendored
Normal file
BIN
webui/dist/assets/KaTeX_Fraktur-Bold-BsDP51OF.woff
vendored
Normal file
Binary file not shown.
BIN
webui/dist/assets/KaTeX_Fraktur-Bold-CL6g_b3V.woff2
vendored
Normal file
BIN
webui/dist/assets/KaTeX_Fraktur-Bold-CL6g_b3V.woff2
vendored
Normal file
Binary file not shown.
BIN
webui/dist/assets/KaTeX_Fraktur-Regular-CB_wures.ttf
vendored
Normal file
BIN
webui/dist/assets/KaTeX_Fraktur-Regular-CB_wures.ttf
vendored
Normal file
Binary file not shown.
BIN
webui/dist/assets/KaTeX_Fraktur-Regular-CTYiF6lA.woff2
vendored
Normal file
BIN
webui/dist/assets/KaTeX_Fraktur-Regular-CTYiF6lA.woff2
vendored
Normal file
Binary file not shown.
BIN
webui/dist/assets/KaTeX_Fraktur-Regular-Dxdc4cR9.woff
vendored
Normal file
BIN
webui/dist/assets/KaTeX_Fraktur-Regular-Dxdc4cR9.woff
vendored
Normal file
Binary file not shown.
BIN
webui/dist/assets/KaTeX_Main-Bold-Cx986IdX.woff2
vendored
Normal file
BIN
webui/dist/assets/KaTeX_Main-Bold-Cx986IdX.woff2
vendored
Normal file
Binary file not shown.
BIN
webui/dist/assets/KaTeX_Main-Bold-Jm3AIy58.woff
vendored
Normal file
BIN
webui/dist/assets/KaTeX_Main-Bold-Jm3AIy58.woff
vendored
Normal file
Binary file not shown.
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user