Merge branch 'main' into helm-chart
This commit is contained in:
5
.github/workflows/docker-image.yml
vendored
5
.github/workflows/docker-image.yml
vendored
@@ -1,11 +1,12 @@
|
|||||||
name: Docker Build and Push
|
name: Docker Build and Push
|
||||||
|
|
||||||
on:
|
on:
|
||||||
|
schedule:
|
||||||
|
- cron: '0 0 * * *'
|
||||||
push:
|
push:
|
||||||
branches:
|
branches:
|
||||||
- main
|
- main
|
||||||
- classical
|
- classical
|
||||||
- dev
|
|
||||||
tags:
|
tags:
|
||||||
- "v*.*.*"
|
- "v*.*.*"
|
||||||
- "v*"
|
- "v*"
|
||||||
@@ -24,6 +25,7 @@ jobs:
|
|||||||
- name: Check out git repository
|
- name: Check out git repository
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
|
ref: ${{ github.event_name == 'schedule' && 'dev' || github.ref }}
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
|
|
||||||
# Clone required dependencies
|
# Clone required dependencies
|
||||||
@@ -77,6 +79,7 @@ jobs:
|
|||||||
- name: Check out git repository
|
- name: Check out git repository
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
|
ref: ${{ github.event_name == 'schedule' && 'dev' || github.ref }}
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
|
|
||||||
# Clone required dependencies
|
# Clone required dependencies
|
||||||
|
|||||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -11,7 +11,6 @@ run_maibot_core.bat
|
|||||||
run_voice.bat
|
run_voice.bat
|
||||||
run_napcat_adapter.bat
|
run_napcat_adapter.bat
|
||||||
run_ad.bat
|
run_ad.bat
|
||||||
s4u.s4u
|
|
||||||
llm_tool_benchmark_results.json
|
llm_tool_benchmark_results.json
|
||||||
MaiBot-Napcat-Adapter-main
|
MaiBot-Napcat-Adapter-main
|
||||||
MaiBot-Napcat-Adapter
|
MaiBot-Napcat-Adapter
|
||||||
@@ -27,6 +26,7 @@ run.bat
|
|||||||
log_debug/
|
log_debug/
|
||||||
run_amds.bat
|
run_amds.bat
|
||||||
run_none.bat
|
run_none.bat
|
||||||
|
docs-mai/
|
||||||
run.py
|
run.py
|
||||||
message_queue_content.txt
|
message_queue_content.txt
|
||||||
message_queue_content.bat
|
message_queue_content.bat
|
||||||
@@ -51,6 +51,7 @@ template/compare/model_config_template.toml
|
|||||||
src/plugins/utils/statistic.py
|
src/plugins/utils/statistic.py
|
||||||
CLAUDE.md
|
CLAUDE.md
|
||||||
MaiBot-Dashboard/
|
MaiBot-Dashboard/
|
||||||
|
cloudflare-workers/
|
||||||
|
|
||||||
# Byte-compiled / optimized / DLL files
|
# Byte-compiled / optimized / DLL files
|
||||||
__pycache__/
|
__pycache__/
|
||||||
|
|||||||
@@ -25,6 +25,8 @@ WORKDIR /MaiMBot
|
|||||||
# 复制依赖列表
|
# 复制依赖列表
|
||||||
COPY requirements.txt .
|
COPY requirements.txt .
|
||||||
|
|
||||||
|
RUN apt-get update && apt-get install -y git
|
||||||
|
|
||||||
# 从编译阶段复制 LPMM 编译结果
|
# 从编译阶段复制 LPMM 编译结果
|
||||||
COPY --from=lpmm-builder /usr/local/lib/python3.13/site-packages/ /usr/local/lib/python3.13/site-packages/
|
COPY --from=lpmm-builder /usr/local/lib/python3.13/site-packages/ /usr/local/lib/python3.13/site-packages/
|
||||||
|
|
||||||
|
|||||||
43
README.md
43
README.md
@@ -25,11 +25,12 @@
|
|||||||
|
|
||||||
**🍔MaiCore 是一个基于大语言模型的可交互智能体**
|
**🍔MaiCore 是一个基于大语言模型的可交互智能体**
|
||||||
|
|
||||||
- 💭 **智能对话系统**:基于 LLM 的自然语言交互,聊天时机控制。
|
- 💭 **拟人构建的prompt**:使用自然语言风格构建回复器的prompt,实现近似人类言语习惯的回复。
|
||||||
- 🤔 **实时思维系统**:模拟人类思考过程。
|
- 💭 **行为规划**:在合适的时间说话,使用合适的动作
|
||||||
- 🧠 **表达学习功能**:学习群友的说话风格和表达方式
|
- 🧠 **表达学习**:学习群友的说话风格和表达方式,学会真实人类的说话风格
|
||||||
- 💝 **情感表达系统**:情绪系统和表情包系统。
|
- 🤔 **黑话学习**:自主的学习没有见过的词语,尝试理解并认知含义
|
||||||
- 🔌 **强大插件系统**:提供API和事件系统,可编写强大插件。
|
- 🔌 **插件系统**:提供API和事件系统,可编写丰富插件。
|
||||||
|
- 💝 **情感表达**:情绪系统和表情包系统。
|
||||||
|
|
||||||
<div style="text-align: center">
|
<div style="text-align: center">
|
||||||
<a href="https://www.bilibili.com/video/BV1amAneGE3P" target="_blank">
|
<a href="https://www.bilibili.com/video/BV1amAneGE3P" target="_blank">
|
||||||
@@ -44,7 +45,9 @@
|
|||||||
|
|
||||||
## 🔥 更新和安装
|
## 🔥 更新和安装
|
||||||
|
|
||||||
**最新版本: v0.11.3** ([更新日志](changelogs/changelog.md))
|
|
||||||
|
**最新版本: v0.11.5** ([更新日志](changelogs/changelog.md))
|
||||||
|
|
||||||
|
|
||||||
可前往 [Release](https://github.com/MaiM-with-u/MaiBot/releases/) 页面下载最新版本
|
可前往 [Release](https://github.com/MaiM-with-u/MaiBot/releases/) 页面下载最新版本
|
||||||
可前往 [启动器发布页面](https://github.com/MaiM-with-u/mailauncher/releases/)下载最新启动器
|
可前往 [启动器发布页面](https://github.com/MaiM-with-u/mailauncher/releases/)下载最新启动器
|
||||||
@@ -60,15 +63,10 @@
|
|||||||
|
|
||||||
> [!WARNING]
|
> [!WARNING]
|
||||||
> - 项目处于活跃开发阶段,功能和 API 可能随时调整。
|
> - 项目处于活跃开发阶段,功能和 API 可能随时调整。
|
||||||
> - 有问题可以提交 Issue 或者 Discussion。
|
> - 有问题可以提交 Issue 。
|
||||||
> - QQ 机器人存在被限制风险,请自行了解,谨慎使用。
|
> - QQ 机器人存在被限制风险,请自行了解,谨慎使用。
|
||||||
> - 由于程序处于开发中,可能消耗较多 token。
|
> - 由于程序处于开发中,可能消耗较多 token。
|
||||||
|
|
||||||
## 麦麦MC项目MaiCraft(早期开发)
|
|
||||||
[让麦麦玩MC](https://github.com/MaiM-with-u/Maicraft)
|
|
||||||
|
|
||||||
交流群:1058573197
|
|
||||||
|
|
||||||
## 💬 讨论
|
## 💬 讨论
|
||||||
|
|
||||||
**技术交流群:**
|
**技术交流群:**
|
||||||
@@ -80,7 +78,7 @@
|
|||||||
**聊天吹水群:**
|
**聊天吹水群:**
|
||||||
- [麦麦之闲聊群](https://qm.qq.com/q/JxvHZnxyec)
|
- [麦麦之闲聊群](https://qm.qq.com/q/JxvHZnxyec)
|
||||||
|
|
||||||
**插件开发测试版群:**
|
**插件开发/测试版讨论群:**
|
||||||
- [插件开发群](https://qm.qq.com/q/1036092828)
|
- [插件开发群](https://qm.qq.com/q/1036092828)
|
||||||
|
|
||||||
## 📚 文档
|
## 📚 文档
|
||||||
@@ -89,7 +87,22 @@
|
|||||||
|
|
||||||
- [📚 核心 Wiki 文档](https://docs.mai-mai.org) - 项目最全面的文档中心,你可以了解麦麦有关的一切。
|
- [📚 核心 Wiki 文档](https://docs.mai-mai.org) - 项目最全面的文档中心,你可以了解麦麦有关的一切。
|
||||||
|
|
||||||
### 设计理念(原始时代的火花)
|
|
||||||
|
## 📚 衍生项目
|
||||||
|
|
||||||
|
### MaiCraft(早期开发)
|
||||||
|
[MaiCraft](https://github.com/MaiM-with-u/Maicraft)
|
||||||
|
> 让麦麦具有玩MC能力的项目
|
||||||
|
> 交流群:1058573197
|
||||||
|
|
||||||
|
### MoFox_Bot
|
||||||
|
[MoFox - 仓库地址](https://github.com/MoFox-Studio/MoFox-Core)
|
||||||
|
> MoFox_Bot 是一个基于 MaiCore 0.10.0 snapshot.5 的增强型 fork 项目
|
||||||
|
> 我们保留了原项目几乎所有核心功能,并在此基础上进行了深度优化与功能扩展,致力于打造一个更稳定、更智能、更具趣味性的 AI 智能体。
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## 设计理念(原始时代的火花)
|
||||||
|
|
||||||
> **千石可乐说:**
|
> **千石可乐说:**
|
||||||
> - 这个项目最初只是为了给牛牛 bot 添加一点额外的功能,但是功能越写越多,最后决定重写。其目的是为了创造一个活跃在 QQ 群聊的"生命体"。目的并不是为了写一个功能齐全的机器人,而是一个尽可能让人感知到真实的类人存在。
|
> - 这个项目最初只是为了给牛牛 bot 添加一点额外的功能,但是功能越写越多,最后决定重写。其目的是为了创造一个活跃在 QQ 群聊的"生命体"。目的并不是为了写一个功能齐全的机器人,而是一个尽可能让人感知到真实的类人存在。
|
||||||
@@ -101,7 +114,7 @@
|
|||||||
## 🙋 贡献和致谢
|
## 🙋 贡献和致谢
|
||||||
你可以阅读[开发文档](https://docs.mai-mai.org/develop/)来更好的了解麦麦!
|
你可以阅读[开发文档](https://docs.mai-mai.org/develop/)来更好的了解麦麦!
|
||||||
MaiCore 是一个开源项目,我们非常欢迎你的参与。你的贡献,无论是提交 bug 报告、功能需求还是代码 pr,都对项目非常宝贵。我们非常感谢你的支持!🎉
|
MaiCore 是一个开源项目,我们非常欢迎你的参与。你的贡献,无论是提交 bug 报告、功能需求还是代码 pr,都对项目非常宝贵。我们非常感谢你的支持!🎉
|
||||||
但无序的讨论会降低沟通效率,进而影响问题的解决速度,因此在提交任何贡献前,请务必先阅读本项目的[贡献指南](docs/CONTRIBUTE.md)。(待补完)
|
但无序的讨论会降低沟通效率,进而影响问题的解决速度,因此在提交任何贡献前,请务必先阅读本项目的[贡献指南](docs-src/CONTRIBUTE.md)。(待补完)
|
||||||
|
|
||||||
### 贡献者
|
### 贡献者
|
||||||
|
|
||||||
|
|||||||
13
bot.py
13
bot.py
@@ -1,7 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import hashlib
|
import hashlib
|
||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
import time
|
import time
|
||||||
import platform
|
import platform
|
||||||
import traceback
|
import traceback
|
||||||
@@ -30,7 +29,7 @@ else:
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
# 最早期初始化日志系统,确保所有后续模块都使用正确的日志格式
|
# 最早期初始化日志系统,确保所有后续模块都使用正确的日志格式
|
||||||
from src.common.logger import initialize_logging, get_logger, shutdown_logging #noqa
|
from src.common.logger import initialize_logging, get_logger, shutdown_logging # noqa
|
||||||
|
|
||||||
initialize_logging()
|
initialize_logging()
|
||||||
|
|
||||||
@@ -76,6 +75,15 @@ async def graceful_shutdown(): # sourcery skip: use-named-expression
|
|||||||
try:
|
try:
|
||||||
logger.info("正在优雅关闭麦麦...")
|
logger.info("正在优雅关闭麦麦...")
|
||||||
|
|
||||||
|
# 关闭 WebUI 服务器
|
||||||
|
try:
|
||||||
|
from src.webui.webui_server import get_webui_server
|
||||||
|
webui_server = get_webui_server()
|
||||||
|
if webui_server and webui_server._server:
|
||||||
|
await webui_server.shutdown()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"关闭 WebUI 服务器时出错: {e}")
|
||||||
|
|
||||||
from src.plugin_system.core.events_manager import events_manager
|
from src.plugin_system.core.events_manager import events_manager
|
||||||
from src.plugin_system.base.component_types import EventType
|
from src.plugin_system.base.component_types import EventType
|
||||||
|
|
||||||
@@ -215,6 +223,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# 初始化 WebSocket 日志推送
|
# 初始化 WebSocket 日志推送
|
||||||
from src.common.logger import initialize_ws_handler
|
from src.common.logger import initialize_ws_handler
|
||||||
|
|
||||||
initialize_ws_handler(loop)
|
initialize_ws_handler(loop)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -1,6 +1,88 @@
|
|||||||
# Changelog
|
# Changelog
|
||||||
|
## [0.11.5] - 2025-11-21
|
||||||
|
### 🌟 重大更新
|
||||||
|
- WebUI 现支持手动重启麦麦,曲线救国版“热重载”
|
||||||
|
- 新增麦麦 QQ 适配器可视化编辑 UI(独立进程,需手动上传/下载并覆盖适配器文件)
|
||||||
|
- 麦麦主程序配置支持可视化模式与源代码模式双模式编辑,后端执行 TOML 校验
|
||||||
|
- 优化 planner 与 replyer 协同机制,调试日志更细
|
||||||
|
|
||||||
## [0.11.3] - 2025-11-17
|
### 新增
|
||||||
|
- 表情包管理、人物信息管理、表达方式管理界面手机端适配
|
||||||
|
- 配置页“重启麦麦”提示
|
||||||
|
- 详细的debug prompt显示配置
|
||||||
|
- 麦麦界面操作主题色按钮
|
||||||
|
- 前端集成 CodeMirror(Python/JSON/TOML 语法高亮)并对 JSON 配置提供自动纠错提示
|
||||||
|
|
||||||
|
### 修复
|
||||||
|
- 表情包缩略图过小
|
||||||
|
- 添加模型后无法立即显示
|
||||||
|
- 插件市场分类错误
|
||||||
|
- 浅色模式下日志查看器背景异常
|
||||||
|
- 表情包详情窗口无法关闭
|
||||||
|
- 情绪标签无法正常读取(issues/1373)
|
||||||
|
- 模型任务配置界面模型温度不可更改(issues/1369)
|
||||||
|
- 模型任务配置界面部分输入框无法删除默认值
|
||||||
|
- 插件商店默认勾选“仅显示兼容插件”
|
||||||
|
- 插件商店标签页数量显示错误
|
||||||
|
- 日志查看器日志换行异常
|
||||||
|
- 主题色未正确应用
|
||||||
|
- 侧边栏滚动条问题
|
||||||
|
- 移除前端 TOML 解析库导致的兼容问题
|
||||||
|
|
||||||
|
### 优化
|
||||||
|
- 首页加载用户体验
|
||||||
|
- 首页加载速度(9s → <1s)
|
||||||
|
- 整个界面的初屏加载速度
|
||||||
|
|
||||||
|
### 更新
|
||||||
|
- 适配器配置支持上传模式与指定路径模式;指定路径模式可免去反复上传/下载配置文件
|
||||||
|
- 适配器配置界面标签页响应式优化,小屏仅显示简短标签
|
||||||
|
- 麦麦资源管理所有界面的批量删除能力
|
||||||
|
- 资源管理所有界面的分页、每页数量选择与页跳转
|
||||||
|
- 插件市场新增点赞、点踩、评分与下载量统计(基于 Cloudflare,保证国内可访问)
|
||||||
|
- 麦麦表情包查看界面的描述部分支持 Markdown 渲染
|
||||||
|
|
||||||
|
## [0.11.4] - 2025-11-19
|
||||||
|
### 🌟 主要更新内容
|
||||||
|
- **首个官方 Web 管理界面上线**:在此版本之前,MaiBot 没有 WebUI,所有配置需手动编辑 TOML 文件
|
||||||
|
- **认证系统**:Token 安全登录(支持系统生成 64 位随机令牌 / 自定义 Token),首次配置向导
|
||||||
|
- **配置管理(可视化编辑,无需手动改 TOML)**:
|
||||||
|
- 麦麦主程序配置:基础设置、人格、表情、黑话、情绪等
|
||||||
|
- 模型提供商配置:OpenAI、Anthropic、DeepSeek、Qwen、Ollama 等
|
||||||
|
- 模型配置:对话/视觉/嵌入模型分配
|
||||||
|
- **资源管理**:
|
||||||
|
- 表情包管理:查看、搜索、注册、封禁
|
||||||
|
- 表达方式管理:查看麦麦的表达记录
|
||||||
|
- 人物信息管理:查看联系人列表
|
||||||
|
- **插件系统**:
|
||||||
|
- 插件市场浏览
|
||||||
|
- 一键安装/卸载/更新
|
||||||
|
- 版本兼容性检查
|
||||||
|
- 实时安装进度推送
|
||||||
|
- **日志查看器**:
|
||||||
|
- WebSocket 实时日志流
|
||||||
|
- 日志级别过滤(DEBUG/INFO/WARNING/ERROR/CRITICAL)
|
||||||
|
- 搜索功能
|
||||||
|
- **主题定制**:
|
||||||
|
- 浅色/深色/跟随系统
|
||||||
|
- 12 种主题色(6 单色 + 6 渐变色)
|
||||||
|
- 自定义颜色选择器
|
||||||
|
- **全局搜索**:Cmd/Ctrl + K 快捷键,快速跳转任意页面
|
||||||
|
|
||||||
|
### 细节
|
||||||
|
- **技术栈**:
|
||||||
|
- 前端: React 19 + TypeScript + Vite + TanStack Router + shadcn/ui
|
||||||
|
- 后端: FastAPI + Uvicorn + WebSocket
|
||||||
|
- 特点: SPA 单页应用,前后端同端口,静态文件托管
|
||||||
|
- **使用方式**:参照 template.env 文件更新 .env 文件,添加两个字段:
|
||||||
|
- `WEBUI_ENABLED=true`
|
||||||
|
- `WEBUI_MODE=production`
|
||||||
|
- **WebUI 开源协议**:GPLv3
|
||||||
|
- **WebUI 地址**:https://github.com/Mai-with-u/MaiBot-Dashboard
|
||||||
|
|
||||||
|
告别手动编辑配置文件,享受现代化图形界面!
|
||||||
|
|
||||||
|
## [0.11.3] - 2025-11-18
|
||||||
### 功能更改和修复
|
### 功能更改和修复
|
||||||
- 优化记忆提取策略
|
- 优化记忆提取策略
|
||||||
- 优化黑话提取
|
- 优化黑话提取
|
||||||
|
|||||||
@@ -29,7 +29,8 @@ services:
|
|||||||
- TZ=Asia/Shanghai
|
- TZ=Asia/Shanghai
|
||||||
# - EULA_AGREE=99f08e0cab0190de853cb6af7d64d4de # 同意EULA
|
# - EULA_AGREE=99f08e0cab0190de853cb6af7d64d4de # 同意EULA
|
||||||
# - PRIVACY_AGREE=9943b855e72199d0f5016ea39052f1b6 # 同意EULA
|
# - PRIVACY_AGREE=9943b855e72199d0f5016ea39052f1b6 # 同意EULA
|
||||||
# ports:
|
ports:
|
||||||
|
- "18001:8001" # webui端口
|
||||||
# - "8000:8000"
|
# - "8000:8000"
|
||||||
volumes:
|
volumes:
|
||||||
- ./docker-config/mmc/.env:/MaiMBot/.env # 持久化env配置文件
|
- ./docker-config/mmc/.env:/MaiMBot/.env # 持久化env配置文件
|
||||||
|
|||||||
|
Before Width: | Height: | Size: 4.1 KiB After Width: | Height: | Size: 4.1 KiB |
|
Before Width: | Height: | Size: 11 KiB After Width: | Height: | Size: 11 KiB |
BIN
docs/image-1.png
BIN
docs/image-1.png
Binary file not shown.
|
Before Width: | Height: | Size: 21 KiB |
BIN
docs/image.png
BIN
docs/image.png
Binary file not shown.
|
Before Width: | Height: | Size: 4.9 KiB |
@@ -1,336 +0,0 @@
|
|||||||
# 模型配置指南
|
|
||||||
|
|
||||||
本文档将指导您如何配置 `model_config.toml` 文件,该文件用于配置 MaiBot 的各种AI模型和API服务提供商。
|
|
||||||
|
|
||||||
## 配置文件结构
|
|
||||||
|
|
||||||
配置文件主要包含以下几个部分:
|
|
||||||
- 版本信息
|
|
||||||
- API服务提供商配置
|
|
||||||
- 模型配置
|
|
||||||
- 模型任务配置
|
|
||||||
|
|
||||||
## 1. 版本信息
|
|
||||||
|
|
||||||
```toml
|
|
||||||
[inner]
|
|
||||||
version = "1.1.1"
|
|
||||||
```
|
|
||||||
|
|
||||||
用于标识配置文件的版本,遵循语义化版本规则。
|
|
||||||
|
|
||||||
## 2. API服务提供商配置
|
|
||||||
|
|
||||||
### 2.1 基本配置
|
|
||||||
|
|
||||||
使用 `[[api_providers]]` 数组配置多个API服务提供商:
|
|
||||||
|
|
||||||
```toml
|
|
||||||
[[api_providers]]
|
|
||||||
name = "DeepSeek" # 服务商名称(自定义)
|
|
||||||
base_url = "https://api.deepseek.com/v1" # API服务的基础URL
|
|
||||||
api_key = "your-api-key-here" # API密钥
|
|
||||||
client_type = "openai" # 客户端类型
|
|
||||||
max_retry = 2 # 最大重试次数
|
|
||||||
timeout = 30 # 超时时间(秒)
|
|
||||||
retry_interval = 10 # 重试间隔(秒)
|
|
||||||
```
|
|
||||||
|
|
||||||
### 2.2 配置参数说明
|
|
||||||
|
|
||||||
| 参数 | 必填 | 说明 | 默认值 |
|
|
||||||
|------|------|------|--------|
|
|
||||||
| `name` | ✅ | 服务商名称,需要在模型配置中引用 | - |
|
|
||||||
| `base_url` | ✅ | API服务的基础URL | - |
|
|
||||||
| `api_key` | ✅ | API密钥,请替换为实际密钥 | - |
|
|
||||||
| `client_type` | ❌ | 客户端类型:`openai`(OpenAI格式)或 `gemini`(Gemini格式) | `openai` |
|
|
||||||
| `max_retry` | ❌ | API调用失败时的最大重试次数 | 2 |
|
|
||||||
| `timeout` | ❌ | API请求超时时间(秒) | 30 |
|
|
||||||
| `retry_interval` | ❌ | 重试间隔时间(秒) | 10 |
|
|
||||||
|
|
||||||
**请注意,对于`client_type`为`gemini`的模型,`retry`字段由`gemini`自己决定。**
|
|
||||||
### 2.3 支持的服务商示例
|
|
||||||
|
|
||||||
#### DeepSeek
|
|
||||||
```toml
|
|
||||||
[[api_providers]]
|
|
||||||
name = "DeepSeek"
|
|
||||||
base_url = "https://api.deepseek.com/v1"
|
|
||||||
api_key = "your-deepseek-api-key"
|
|
||||||
client_type = "openai"
|
|
||||||
```
|
|
||||||
|
|
||||||
#### SiliconFlow
|
|
||||||
```toml
|
|
||||||
[[api_providers]]
|
|
||||||
name = "SiliconFlow"
|
|
||||||
base_url = "https://api.siliconflow.cn/v1"
|
|
||||||
api_key = "your-siliconflow-api-key"
|
|
||||||
client_type = "openai"
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Google Gemini
|
|
||||||
```toml
|
|
||||||
[[api_providers]]
|
|
||||||
name = "Google"
|
|
||||||
base_url = "https://generativelanguage.googleapis.com/v1beta"
|
|
||||||
api_key = "your-google-api-key"
|
|
||||||
client_type = "gemini" # 注意:Gemini需要使用特殊客户端
|
|
||||||
```
|
|
||||||
|
|
||||||
## 3. 模型配置
|
|
||||||
|
|
||||||
### 3.1 基本模型配置
|
|
||||||
|
|
||||||
使用 `[[models]]` 数组配置多个模型:
|
|
||||||
|
|
||||||
```toml
|
|
||||||
[[models]]
|
|
||||||
model_identifier = "deepseek-chat" # 模型在API服务商中的标识符
|
|
||||||
name = "deepseek-v3" # 自定义模型名称
|
|
||||||
api_provider = "DeepSeek" # 引用的API服务商名称
|
|
||||||
price_in = 2.0 # 输入价格(元/M token)
|
|
||||||
price_out = 8.0 # 输出价格(元/M token)
|
|
||||||
```
|
|
||||||
|
|
||||||
### 3.2 高级模型配置
|
|
||||||
|
|
||||||
#### 强制流式输出
|
|
||||||
对于不支持非流式输出的模型:
|
|
||||||
```toml
|
|
||||||
[[models]]
|
|
||||||
model_identifier = "some-model"
|
|
||||||
name = "custom-name"
|
|
||||||
api_provider = "Provider"
|
|
||||||
force_stream_mode = true # 启用强制流式输出
|
|
||||||
```
|
|
||||||
|
|
||||||
#### 额外参数配置`extra_params`
|
|
||||||
```toml
|
|
||||||
[[models]]
|
|
||||||
model_identifier = "Qwen/Qwen3-8B"
|
|
||||||
name = "qwen3-8b"
|
|
||||||
api_provider = "SiliconFlow"
|
|
||||||
[models.extra_params]
|
|
||||||
enable_thinking = false # 禁用思考
|
|
||||||
```
|
|
||||||
这里的 `extra_params` 可以包含任何API服务商支持的额外参数配置,**配置时应参考相应的API文档**。
|
|
||||||
|
|
||||||
比如上面就是参考SiliconFlow的文档配置配置的`Qwen3`禁用思考参数。
|
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||
以豆包文档为另一个例子
|
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||
得到豆包`"doubao-seed-1-6-250615"`的禁用思考配置方法为
|
|
||||||
```toml
|
|
||||||
[[models]]
|
|
||||||
# 你的模型
|
|
||||||
[models.extra_params]
|
|
||||||
thinking = {type = "disabled"} # 禁用思考
|
|
||||||
```
|
|
||||||
|
|
||||||
而对于`gemini`需要单独进行配置
|
|
||||||
```toml
|
|
||||||
[[models]]
|
|
||||||
model_identifier = "gemini-2.5-flash"
|
|
||||||
name = "gemini-2.5-flash"
|
|
||||||
api_provider = "Google"
|
|
||||||
[models.extra_params]
|
|
||||||
thinking_budget = 0 # 禁用思考
|
|
||||||
# thinking_budget = -1 由模型自己决定
|
|
||||||
```
|
|
||||||
|
|
||||||
请注意,`extra_params` 的配置应该构成一个合法的TOML字典结构,具体内容取决于API服务商的要求。
|
|
||||||
|
|
||||||
### 3.3 配置参数说明
|
|
||||||
|
|
||||||
| 参数 | 必填 | 说明 |
|
|
||||||
|------|------|------|
|
|
||||||
| `model_identifier` | ✅ | API服务商提供的模型标识符 |
|
|
||||||
| `name` | ✅ | 自定义模型名称,用于在任务配置中引用 |
|
|
||||||
| `api_provider` | ✅ | 对应的API服务商名称 |
|
|
||||||
| `price_in` | ❌ | 输入价格(元/M token),用于成本统计 |
|
|
||||||
| `price_out` | ❌ | 输出价格(元/M token),用于成本统计 |
|
|
||||||
| `force_stream_mode` | ❌ | 是否强制使用流式输出 |
|
|
||||||
| `extra_params` | ❌ | 额外的模型参数配置 |
|
|
||||||
|
|
||||||
## 4. 模型任务配置
|
|
||||||
|
|
||||||
### utils - 工具模型
|
|
||||||
用于表情包模块、取名模块、关系模块等核心功能:
|
|
||||||
```toml
|
|
||||||
[model_task_config.utils]
|
|
||||||
model_list = ["siliconflow-deepseek-v3"]
|
|
||||||
temperature = 0.2
|
|
||||||
max_tokens = 800
|
|
||||||
```
|
|
||||||
|
|
||||||
### utils_small - 小型工具模型
|
|
||||||
用于高频率调用的场景,建议使用速度快的小模型:
|
|
||||||
```toml
|
|
||||||
[model_task_config.utils_small]
|
|
||||||
model_list = ["qwen3-8b"]
|
|
||||||
temperature = 0.7
|
|
||||||
max_tokens = 800
|
|
||||||
```
|
|
||||||
|
|
||||||
### replyer - 主要回复模型
|
|
||||||
首要回复模型,也用于表达器和表达方式学习:
|
|
||||||
```toml
|
|
||||||
[model_task_config.replyer]
|
|
||||||
model_list = ["siliconflow-deepseek-v3"]
|
|
||||||
temperature = 0.2
|
|
||||||
max_tokens = 800
|
|
||||||
```
|
|
||||||
|
|
||||||
### planner - 决策模型
|
|
||||||
负责决定MaiBot该做什么:
|
|
||||||
```toml
|
|
||||||
[model_task_config.planner]
|
|
||||||
model_list = ["siliconflow-deepseek-v3"]
|
|
||||||
temperature = 0.3
|
|
||||||
max_tokens = 800
|
|
||||||
```
|
|
||||||
|
|
||||||
### emotion - 情绪模型
|
|
||||||
负责MaiBot的情绪变化:
|
|
||||||
```toml
|
|
||||||
[model_task_config.emotion]
|
|
||||||
model_list = ["siliconflow-deepseek-v3"]
|
|
||||||
temperature = 0.3
|
|
||||||
max_tokens = 800
|
|
||||||
```
|
|
||||||
|
|
||||||
### memory - 记忆模型
|
|
||||||
```toml
|
|
||||||
[model_task_config.memory]
|
|
||||||
model_list = ["qwen3-30b"]
|
|
||||||
temperature = 0.7
|
|
||||||
max_tokens = 800
|
|
||||||
```
|
|
||||||
|
|
||||||
### vlm - 视觉语言模型
|
|
||||||
用于图像识别:
|
|
||||||
```toml
|
|
||||||
[model_task_config.vlm]
|
|
||||||
model_list = ["qwen2.5-vl-72b"]
|
|
||||||
max_tokens = 800
|
|
||||||
```
|
|
||||||
|
|
||||||
### voice - 语音识别模型
|
|
||||||
```toml
|
|
||||||
[model_task_config.voice]
|
|
||||||
model_list = ["sensevoice-small"]
|
|
||||||
```
|
|
||||||
|
|
||||||
### embedding - 嵌入模型
|
|
||||||
```toml
|
|
||||||
[model_task_config.embedding]
|
|
||||||
model_list = ["bge-m3"]
|
|
||||||
```
|
|
||||||
|
|
||||||
### tool_use - 工具调用模型
|
|
||||||
需要使用支持工具调用的模型:
|
|
||||||
```toml
|
|
||||||
[model_task_config.tool_use]
|
|
||||||
model_list = ["qwen3-14b"]
|
|
||||||
temperature = 0.7
|
|
||||||
max_tokens = 800
|
|
||||||
```
|
|
||||||
|
|
||||||
### lpmm_entity_extract - 实体提取模型
|
|
||||||
```toml
|
|
||||||
[model_task_config.lpmm_entity_extract]
|
|
||||||
model_list = ["siliconflow-deepseek-v3"]
|
|
||||||
temperature = 0.2
|
|
||||||
max_tokens = 800
|
|
||||||
```
|
|
||||||
|
|
||||||
### lpmm_rdf_build - RDF构建模型
|
|
||||||
```toml
|
|
||||||
[model_task_config.lpmm_rdf_build]
|
|
||||||
model_list = ["siliconflow-deepseek-v3"]
|
|
||||||
temperature = 0.2
|
|
||||||
max_tokens = 800
|
|
||||||
```
|
|
||||||
|
|
||||||
### lpmm_qa - 问答模型
|
|
||||||
```toml
|
|
||||||
[model_task_config.lpmm_qa]
|
|
||||||
model_list = ["deepseek-r1-distill-qwen-32b"]
|
|
||||||
temperature = 0.7
|
|
||||||
max_tokens = 800
|
|
||||||
```
|
|
||||||
|
|
||||||
## 5. 配置建议
|
|
||||||
|
|
||||||
### 5.1 Temperature 参数选择
|
|
||||||
|
|
||||||
| 任务类型 | 推荐温度 | 说明 |
|
|
||||||
|----------|----------|------|
|
|
||||||
| 精确任务(工具调用、实体提取) | 0.1-0.3 | 需要准确性和一致性 |
|
|
||||||
| 创意任务(对话、记忆) | 0.5-0.8 | 需要多样性和创造性 |
|
|
||||||
| 平衡任务(决策、情绪) | 0.3-0.5 | 平衡准确性和灵活性 |
|
|
||||||
|
|
||||||
### 5.2 模型选择建议
|
|
||||||
|
|
||||||
| 任务类型 | 推荐模型类型 | 示例 |
|
|
||||||
|----------|--------------|------|
|
|
||||||
| 高精度任务 | 大模型 | DeepSeek-V3, GPT-4 |
|
|
||||||
| 高频率任务 | 小模型 | Qwen3-8B |
|
|
||||||
| 多模态任务 | 专用模型 | Qwen2.5-VL, SenseVoice |
|
|
||||||
| 工具调用 | 支持Function Call的模型 | Qwen3-14B |
|
|
||||||
|
|
||||||
### 5.3 成本优化
|
|
||||||
|
|
||||||
1. **分层使用**:核心功能使用高质量模型,辅助功能使用经济模型
|
|
||||||
2. **合理配置max_tokens**:根据实际需求设置,避免浪费
|
|
||||||
3. **选择免费模型**:对于测试环境,优先使用price为0的模型
|
|
||||||
|
|
||||||
## 6. 配置验证
|
|
||||||
|
|
||||||
### 6.1 必要检查项
|
|
||||||
|
|
||||||
1. ✅ API密钥是否正确配置
|
|
||||||
2. ✅ 模型标识符是否与API服务商提供的一致
|
|
||||||
3. ✅ 任务配置中引用的模型名称是否在models中定义
|
|
||||||
4. ✅ 多模态任务是否配置了对应的专用模型
|
|
||||||
|
|
||||||
### 6.2 测试配置
|
|
||||||
|
|
||||||
建议在正式使用前:
|
|
||||||
1. 使用少量测试数据验证配置
|
|
||||||
2. 检查API调用是否正常
|
|
||||||
3. 确认成本统计功能正常工作
|
|
||||||
|
|
||||||
## 7. 故障排除
|
|
||||||
|
|
||||||
### 7.1 常见问题
|
|
||||||
|
|
||||||
**问题1**: API调用失败
|
|
||||||
- 检查API密钥是否正确
|
|
||||||
- 确认base_url是否可访问
|
|
||||||
- 检查模型标识符是否正确
|
|
||||||
|
|
||||||
**问题2**: 模型未找到
|
|
||||||
- 确认模型名称在任务配置和模型定义中一致
|
|
||||||
- 检查api_provider名称是否匹配
|
|
||||||
|
|
||||||
**问题3**: 响应异常
|
|
||||||
- 检查温度参数是否合理(0-1之间)
|
|
||||||
- 确认max_tokens设置是否合适
|
|
||||||
- 验证模型是否支持所需功能
|
|
||||||
|
|
||||||
### 7.2 日志查看
|
|
||||||
|
|
||||||
查看 `logs/` 目录下的日志文件,寻找相关错误信息。
|
|
||||||
|
|
||||||
## 8. 更新和维护
|
|
||||||
|
|
||||||
1. **定期更新**: 关注API服务商的模型更新,及时调整配置
|
|
||||||
2. **性能监控**: 监控模型调用的成本和性能
|
|
||||||
3. **备份配置**: 在修改前备份当前配置文件
|
|
||||||
|
|
||||||
@@ -16,8 +16,6 @@ if PROJECT_ROOT not in sys.path:
|
|||||||
sys.path.insert(0, PROJECT_ROOT)
|
sys.path.insert(0, PROJECT_ROOT)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
SECONDS_5_MINUTES = 5 * 60
|
SECONDS_5_MINUTES = 5 * 60
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
|||||||
sys.path.insert(0, project_root)
|
sys.path.insert(0, project_root)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# 设置中文字体
|
# 设置中文字体
|
||||||
plt.rcParams["font.sans-serif"] = ["SimHei", "Microsoft YaHei", "DejaVu Sans"]
|
plt.rcParams["font.sans-serif"] = ["SimHei", "Microsoft YaHei", "DejaVu Sans"]
|
||||||
plt.rcParams["axes.unicode_minus"] = False
|
plt.rcParams["axes.unicode_minus"] = False
|
||||||
|
|||||||
@@ -57,8 +57,8 @@ from src.common.database.database import db
|
|||||||
from src.common.database.database_model import Emoji
|
from src.common.database.database_model import Emoji
|
||||||
|
|
||||||
# 常量定义
|
# 常量定义
|
||||||
MAGIC = b'MMIP'
|
MAGIC = b"MMIP"
|
||||||
FOOTER_MAGIC = b'MMFF'
|
FOOTER_MAGIC = b"MMFF"
|
||||||
VERSION = 1
|
VERSION = 1
|
||||||
FOOTER_VERSION = 1
|
FOOTER_VERSION = 1
|
||||||
|
|
||||||
@@ -67,7 +67,7 @@ MAX_MANIFEST_SIZE = 200 * 1024 * 1024 # 200 MB
|
|||||||
MAX_PAYLOAD_SIZE = 10 * 1024 * 1024 * 1024 # 10 GB
|
MAX_PAYLOAD_SIZE = 10 * 1024 * 1024 * 1024 # 10 GB
|
||||||
|
|
||||||
# 支持的图片格式
|
# 支持的图片格式
|
||||||
SUPPORTED_FORMATS = {'.jpg', '.jpeg', '.png', '.gif', '.webp', '.avif', '.bmp'}
|
SUPPORTED_FORMATS = {".jpg", ".jpeg", ".png", ".gif", ".webp", ".avif", ".bmp"}
|
||||||
|
|
||||||
# 创建控制台对象
|
# 创建控制台对象
|
||||||
console = Console()
|
console = Console()
|
||||||
@@ -75,6 +75,7 @@ console = Console()
|
|||||||
|
|
||||||
class MMIPKGError(Exception):
|
class MMIPKGError(Exception):
|
||||||
"""MMIPKG 相关错误"""
|
"""MMIPKG 相关错误"""
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@@ -97,55 +98,55 @@ def get_image_info(file_path: str) -> Tuple[int, int, str]:
|
|||||||
try:
|
try:
|
||||||
with Image.open(file_path) as img:
|
with Image.open(file_path) as img:
|
||||||
width, height = img.size
|
width, height = img.size
|
||||||
format_lower = img.format.lower() if img.format else 'unknown'
|
format_lower = img.format.lower() if img.format else "unknown"
|
||||||
mime_map = {
|
mime_map = {
|
||||||
'jpeg': 'image/jpeg',
|
"jpeg": "image/jpeg",
|
||||||
'jpg': 'image/jpeg',
|
"jpg": "image/jpeg",
|
||||||
'png': 'image/png',
|
"png": "image/png",
|
||||||
'gif': 'image/gif',
|
"gif": "image/gif",
|
||||||
'webp': 'image/webp',
|
"webp": "image/webp",
|
||||||
'avif': 'image/avif',
|
"avif": "image/avif",
|
||||||
'bmp': 'image/bmp'
|
"bmp": "image/bmp",
|
||||||
}
|
}
|
||||||
mime_type = mime_map.get(format_lower, f'image/{format_lower}')
|
mime_type = mime_map.get(format_lower, f"image/{format_lower}")
|
||||||
return width, height, mime_type
|
return width, height, mime_type
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"警告: 无法读取图片信息 {file_path}: {e}")
|
print(f"警告: 无法读取图片信息 {file_path}: {e}")
|
||||||
return 0, 0, 'image/unknown'
|
return 0, 0, "image/unknown"
|
||||||
|
|
||||||
|
|
||||||
def reencode_image(file_path: str, output_format: str = 'webp', quality: int = 80) -> bytes:
|
def reencode_image(file_path: str, output_format: str = "webp", quality: int = 80) -> bytes:
|
||||||
"""重新编码图片"""
|
"""重新编码图片"""
|
||||||
try:
|
try:
|
||||||
with Image.open(file_path) as img:
|
with Image.open(file_path) as img:
|
||||||
# 转换为 RGB(如果需要)
|
# 转换为 RGB(如果需要)
|
||||||
if img.mode in ('RGBA', 'LA', 'P'):
|
if img.mode in ("RGBA", "LA", "P"):
|
||||||
if output_format.lower() == 'jpeg':
|
if output_format.lower() == "jpeg":
|
||||||
# JPEG 不支持透明度,转为白色背景
|
# JPEG 不支持透明度,转为白色背景
|
||||||
background = Image.new('RGB', img.size, (255, 255, 255))
|
background = Image.new("RGB", img.size, (255, 255, 255))
|
||||||
if img.mode == 'P':
|
if img.mode == "P":
|
||||||
img = img.convert('RGBA')
|
img = img.convert("RGBA")
|
||||||
background.paste(img, mask=img.split()[-1] if img.mode == 'RGBA' else None)
|
background.paste(img, mask=img.split()[-1] if img.mode == "RGBA" else None)
|
||||||
img = background
|
img = background
|
||||||
elif output_format.lower() == 'webp':
|
elif output_format.lower() == "webp":
|
||||||
# WebP 支持透明度
|
# WebP 支持透明度
|
||||||
if img.mode == 'P':
|
if img.mode == "P":
|
||||||
img = img.convert('RGBA')
|
img = img.convert("RGBA")
|
||||||
elif img.mode not in ('RGB', 'RGBA'):
|
elif img.mode not in ("RGB", "RGBA"):
|
||||||
img = img.convert('RGB')
|
img = img.convert("RGB")
|
||||||
|
|
||||||
# 编码图片
|
# 编码图片
|
||||||
output = io.BytesIO()
|
output = io.BytesIO()
|
||||||
save_kwargs = {'format': output_format.upper()}
|
save_kwargs = {"format": output_format.upper()}
|
||||||
|
|
||||||
if output_format.lower() in {'jpeg', 'jpg'}:
|
if output_format.lower() in {"jpeg", "jpg"}:
|
||||||
save_kwargs['quality'] = quality
|
save_kwargs["quality"] = quality
|
||||||
save_kwargs['optimize'] = True
|
save_kwargs["optimize"] = True
|
||||||
elif output_format.lower() == 'webp':
|
elif output_format.lower() == "webp":
|
||||||
save_kwargs['quality'] = quality
|
save_kwargs["quality"] = quality
|
||||||
save_kwargs['method'] = 6 # 更好的压缩
|
save_kwargs["method"] = 6 # 更好的压缩
|
||||||
elif output_format.lower() == 'png':
|
elif output_format.lower() == "png":
|
||||||
save_kwargs['optimize'] = True
|
save_kwargs["optimize"] = True
|
||||||
|
|
||||||
img.save(output, **save_kwargs)
|
img.save(output, **save_kwargs)
|
||||||
return output.getvalue()
|
return output.getvalue()
|
||||||
@@ -156,11 +157,13 @@ def reencode_image(file_path: str, output_format: str = 'webp', quality: int = 8
|
|||||||
class MMIPKGPacker:
|
class MMIPKGPacker:
|
||||||
"""MMIPKG 打包器"""
|
"""MMIPKG 打包器"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(
|
||||||
use_compression: bool = True,
|
self,
|
||||||
zstd_level: int = 3,
|
use_compression: bool = True,
|
||||||
reencode: Optional[str] = None,
|
zstd_level: int = 3,
|
||||||
reencode_quality: int = 80):
|
reencode: Optional[str] = None,
|
||||||
|
reencode_quality: int = 80,
|
||||||
|
):
|
||||||
self.use_compression = use_compression and zstd is not None
|
self.use_compression = use_compression and zstd is not None
|
||||||
self.zstd_level = zstd_level
|
self.zstd_level = zstd_level
|
||||||
self.reencode = reencode
|
self.reencode = reencode
|
||||||
@@ -170,8 +173,9 @@ class MMIPKGPacker:
|
|||||||
print("警告: zstandard 未安装,将不使用压缩")
|
print("警告: zstandard 未安装,将不使用压缩")
|
||||||
self.use_compression = False
|
self.use_compression = False
|
||||||
|
|
||||||
def pack_from_db(self, output_path: str, pack_name: Optional[str] = None,
|
def pack_from_db(
|
||||||
custom_manifest: Optional[Dict] = None) -> bool:
|
self, output_path: str, pack_name: Optional[str] = None, custom_manifest: Optional[Dict] = None
|
||||||
|
) -> bool:
|
||||||
"""从数据库导出已注册的表情包
|
"""从数据库导出已注册的表情包
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -205,12 +209,14 @@ class MMIPKGPacker:
|
|||||||
BarColumn(),
|
BarColumn(),
|
||||||
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
|
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
|
||||||
TimeElapsedColumn(),
|
TimeElapsedColumn(),
|
||||||
console=console
|
console=console,
|
||||||
) as progress:
|
) as progress:
|
||||||
task = progress.add_task("[cyan]扫描表情包...", total=emoji_count)
|
task = progress.add_task("[cyan]扫描表情包...", total=emoji_count)
|
||||||
|
|
||||||
for idx, emoji in enumerate(emojis, 1):
|
for idx, emoji in enumerate(emojis, 1):
|
||||||
progress.update(task, description=f"[cyan]处理 {idx}/{emoji_count}: {os.path.basename(emoji.full_path)}")
|
progress.update(
|
||||||
|
task, description=f"[cyan]处理 {idx}/{emoji_count}: {os.path.basename(emoji.full_path)}"
|
||||||
|
)
|
||||||
|
|
||||||
# 检查文件是否存在
|
# 检查文件是否存在
|
||||||
if not os.path.exists(emoji.full_path):
|
if not os.path.exists(emoji.full_path):
|
||||||
@@ -224,10 +230,10 @@ class MMIPKGPacker:
|
|||||||
img_bytes = reencode_image(emoji.full_path, self.reencode, self.reencode_quality)
|
img_bytes = reencode_image(emoji.full_path, self.reencode, self.reencode_quality)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
console.print(f" [yellow]警告: 重新编码失败,使用原始文件: {e}[/yellow]")
|
console.print(f" [yellow]警告: 重新编码失败,使用原始文件: {e}[/yellow]")
|
||||||
with open(emoji.full_path, 'rb') as f:
|
with open(emoji.full_path, "rb") as f:
|
||||||
img_bytes = f.read()
|
img_bytes = f.read()
|
||||||
else:
|
else:
|
||||||
with open(emoji.full_path, 'rb') as f:
|
with open(emoji.full_path, "rb") as f:
|
||||||
img_bytes = f.read()
|
img_bytes = f.read()
|
||||||
|
|
||||||
# 计算 SHA256
|
# 计算 SHA256
|
||||||
@@ -259,7 +265,7 @@ class MMIPKGPacker:
|
|||||||
"emoji_hash": emoji.emoji_hash or "",
|
"emoji_hash": emoji.emoji_hash or "",
|
||||||
"is_registered": True,
|
"is_registered": True,
|
||||||
"is_banned": emoji.is_banned or False,
|
"is_banned": emoji.is_banned or False,
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
items.append(item)
|
items.append(item)
|
||||||
@@ -281,7 +287,7 @@ class MMIPKGPacker:
|
|||||||
"p": pack_id, # pack_id
|
"p": pack_id, # pack_id
|
||||||
"n": pack_name, # pack_name
|
"n": pack_name, # pack_name
|
||||||
"t": datetime.now().isoformat(), # created_at
|
"t": datetime.now().isoformat(), # created_at
|
||||||
"a": items # items array
|
"a": items, # items array
|
||||||
}
|
}
|
||||||
|
|
||||||
# 添加自定义字段
|
# 添加自定义字段
|
||||||
@@ -308,26 +314,28 @@ class MMIPKGPacker:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"打包失败: {e}")
|
print(f"打包失败: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
return False
|
return False
|
||||||
finally:
|
finally:
|
||||||
if not db.is_closed():
|
if not db.is_closed():
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
def _write_package(self, output_path: str, manifest_bytes: bytes,
|
def _write_package(
|
||||||
image_data_list: List[bytes], payload_size: int) -> bool:
|
self, output_path: str, manifest_bytes: bytes, image_data_list: List[bytes], payload_size: int
|
||||||
|
) -> bool:
|
||||||
"""写入打包文件"""
|
"""写入打包文件"""
|
||||||
try:
|
try:
|
||||||
with open(output_path, 'wb') as f:
|
with open(output_path, "wb") as f:
|
||||||
# 写入 Header (32 bytes)
|
# 写入 Header (32 bytes)
|
||||||
flags = 0x01 if self.use_compression else 0x00
|
flags = 0x01 if self.use_compression else 0x00
|
||||||
header = MAGIC # 4 bytes
|
header = MAGIC # 4 bytes
|
||||||
header += struct.pack('B', VERSION) # 1 byte
|
header += struct.pack("B", VERSION) # 1 byte
|
||||||
header += struct.pack('B', flags) # 1 byte
|
header += struct.pack("B", flags) # 1 byte
|
||||||
header += b'\x00\x00' # 2 bytes reserved
|
header += b"\x00\x00" # 2 bytes reserved
|
||||||
header += struct.pack('>Q', payload_size) # 8 bytes
|
header += struct.pack(">Q", payload_size) # 8 bytes
|
||||||
header += struct.pack('>Q', len(manifest_bytes)) # 8 bytes
|
header += struct.pack(">Q", len(manifest_bytes)) # 8 bytes
|
||||||
header += b'\x00' * 8 # 8 bytes reserved
|
header += b"\x00" * 8 # 8 bytes reserved
|
||||||
|
|
||||||
assert len(header) == 32, f"Header size mismatch: {len(header)}"
|
assert len(header) == 32, f"Header size mismatch: {len(header)}"
|
||||||
f.write(header)
|
f.write(header)
|
||||||
@@ -342,7 +350,7 @@ class MMIPKGPacker:
|
|||||||
|
|
||||||
with compressor.stream_writer(f, closefd=False) as writer:
|
with compressor.stream_writer(f, closefd=False) as writer:
|
||||||
# 写入 manifest
|
# 写入 manifest
|
||||||
manifest_len_bytes = struct.pack('>I', len(manifest_bytes))
|
manifest_len_bytes = struct.pack(">I", len(manifest_bytes))
|
||||||
writer.write(manifest_len_bytes)
|
writer.write(manifest_len_bytes)
|
||||||
writer.write(manifest_bytes)
|
writer.write(manifest_bytes)
|
||||||
payload_sha.update(manifest_len_bytes)
|
payload_sha.update(manifest_len_bytes)
|
||||||
@@ -355,13 +363,13 @@ class MMIPKGPacker:
|
|||||||
BarColumn(),
|
BarColumn(),
|
||||||
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
|
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
|
||||||
TimeRemainingColumn(),
|
TimeRemainingColumn(),
|
||||||
console=console
|
console=console,
|
||||||
) as progress:
|
) as progress:
|
||||||
task = progress.add_task("[green]压缩写入图片...", total=len(image_data_list))
|
task = progress.add_task("[green]压缩写入图片...", total=len(image_data_list))
|
||||||
|
|
||||||
for idx, img_bytes in enumerate(image_data_list, 1):
|
for idx, img_bytes in enumerate(image_data_list, 1):
|
||||||
progress.update(task, description=f"[green]压缩写入 {idx}/{len(image_data_list)}")
|
progress.update(task, description=f"[green]压缩写入 {idx}/{len(image_data_list)}")
|
||||||
img_len_bytes = struct.pack('>I', len(img_bytes))
|
img_len_bytes = struct.pack(">I", len(img_bytes))
|
||||||
writer.write(img_len_bytes)
|
writer.write(img_len_bytes)
|
||||||
writer.write(img_bytes)
|
writer.write(img_bytes)
|
||||||
payload_sha.update(img_len_bytes)
|
payload_sha.update(img_len_bytes)
|
||||||
@@ -370,7 +378,7 @@ class MMIPKGPacker:
|
|||||||
else:
|
else:
|
||||||
# 不压缩,直接写入
|
# 不压缩,直接写入
|
||||||
# 写入 manifest
|
# 写入 manifest
|
||||||
manifest_len_bytes = struct.pack('>I', len(manifest_bytes))
|
manifest_len_bytes = struct.pack(">I", len(manifest_bytes))
|
||||||
f.write(manifest_len_bytes)
|
f.write(manifest_len_bytes)
|
||||||
f.write(manifest_bytes)
|
f.write(manifest_bytes)
|
||||||
payload_sha.update(manifest_len_bytes)
|
payload_sha.update(manifest_len_bytes)
|
||||||
@@ -383,13 +391,13 @@ class MMIPKGPacker:
|
|||||||
BarColumn(),
|
BarColumn(),
|
||||||
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
|
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
|
||||||
TimeRemainingColumn(),
|
TimeRemainingColumn(),
|
||||||
console=console
|
console=console,
|
||||||
) as progress:
|
) as progress:
|
||||||
task = progress.add_task("[green]写入图片...", total=len(image_data_list))
|
task = progress.add_task("[green]写入图片...", total=len(image_data_list))
|
||||||
|
|
||||||
for idx, img_bytes in enumerate(image_data_list, 1):
|
for idx, img_bytes in enumerate(image_data_list, 1):
|
||||||
progress.update(task, description=f"[green]写入 {idx}/{len(image_data_list)}")
|
progress.update(task, description=f"[green]写入 {idx}/{len(image_data_list)}")
|
||||||
img_len_bytes = struct.pack('>I', len(img_bytes))
|
img_len_bytes = struct.pack(">I", len(img_bytes))
|
||||||
f.write(img_len_bytes)
|
f.write(img_len_bytes)
|
||||||
f.write(img_bytes)
|
f.write(img_bytes)
|
||||||
payload_sha.update(img_len_bytes)
|
payload_sha.update(img_len_bytes)
|
||||||
@@ -400,8 +408,8 @@ class MMIPKGPacker:
|
|||||||
file_sha256 = payload_sha.digest()
|
file_sha256 = payload_sha.digest()
|
||||||
footer = FOOTER_MAGIC # 4 bytes
|
footer = FOOTER_MAGIC # 4 bytes
|
||||||
footer += file_sha256 # 32 bytes
|
footer += file_sha256 # 32 bytes
|
||||||
footer += struct.pack('B', FOOTER_VERSION) # 1 byte
|
footer += struct.pack("B", FOOTER_VERSION) # 1 byte
|
||||||
footer += b'\x00' * 3 # 3 bytes reserved
|
footer += b"\x00" * 3 # 3 bytes reserved
|
||||||
|
|
||||||
assert len(footer) == 40, f"Footer size mismatch: {len(footer)}"
|
assert len(footer) == 40, f"Footer size mismatch: {len(footer)}"
|
||||||
f.write(footer)
|
f.write(footer)
|
||||||
@@ -419,6 +427,7 @@ class MMIPKGPacker:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"写入文件失败: {e}")
|
print(f"写入文件失败: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@@ -429,10 +438,9 @@ class MMIPKGUnpacker:
|
|||||||
def __init__(self, verify_sha: bool = True):
|
def __init__(self, verify_sha: bool = True):
|
||||||
self.verify_sha = verify_sha
|
self.verify_sha = verify_sha
|
||||||
|
|
||||||
def import_to_db(self, package_path: str,
|
def import_to_db(
|
||||||
output_dir: Optional[str] = None,
|
self, package_path: str, output_dir: Optional[str] = None, replace_existing: bool = False, batch_size: int = 500
|
||||||
replace_existing: bool = False,
|
) -> bool:
|
||||||
batch_size: int = 500) -> bool:
|
|
||||||
"""导入到数据库"""
|
"""导入到数据库"""
|
||||||
try:
|
try:
|
||||||
if not os.path.exists(package_path):
|
if not os.path.exists(package_path):
|
||||||
@@ -451,7 +459,7 @@ class MMIPKGUnpacker:
|
|||||||
|
|
||||||
print(f"正在读取包: {package_path}")
|
print(f"正在读取包: {package_path}")
|
||||||
|
|
||||||
with open(package_path, 'rb') as f:
|
with open(package_path, "rb") as f:
|
||||||
# 读取 Header
|
# 读取 Header
|
||||||
header = f.read(32)
|
header = f.read(32)
|
||||||
if len(header) != 32:
|
if len(header) != 32:
|
||||||
@@ -461,15 +469,15 @@ class MMIPKGUnpacker:
|
|||||||
if magic != MAGIC:
|
if magic != MAGIC:
|
||||||
raise MMIPKGError(f"无效的 MAGIC: {magic}")
|
raise MMIPKGError(f"无效的 MAGIC: {magic}")
|
||||||
|
|
||||||
version = struct.unpack('B', header[4:5])[0]
|
version = struct.unpack("B", header[4:5])[0]
|
||||||
if version != VERSION:
|
if version != VERSION:
|
||||||
print(f"警告: 包版本 {version} 与当前版本 {VERSION} 不匹配")
|
print(f"警告: 包版本 {version} 与当前版本 {VERSION} 不匹配")
|
||||||
|
|
||||||
flags = struct.unpack('B', header[5:6])[0]
|
flags = struct.unpack("B", header[5:6])[0]
|
||||||
is_compressed = bool(flags & 0x01)
|
is_compressed = bool(flags & 0x01)
|
||||||
|
|
||||||
payload_uncompressed_len = struct.unpack('>Q', header[8:16])[0]
|
payload_uncompressed_len = struct.unpack(">Q", header[8:16])[0]
|
||||||
manifest_uncompressed_len = struct.unpack('>Q', header[16:24])[0]
|
manifest_uncompressed_len = struct.unpack(">Q", header[16:24])[0]
|
||||||
|
|
||||||
# 安全检查
|
# 安全检查
|
||||||
if manifest_uncompressed_len > MAX_MANIFEST_SIZE:
|
if manifest_uncompressed_len > MAX_MANIFEST_SIZE:
|
||||||
@@ -519,7 +527,9 @@ class MMIPKGUnpacker:
|
|||||||
# 方法2:如果流式失败,尝试直接解压(兼容旧格式)
|
# 方法2:如果流式失败,尝试直接解压(兼容旧格式)
|
||||||
print(f" 流式解压失败,尝试直接解压: {e}")
|
print(f" 流式解压失败,尝试直接解压: {e}")
|
||||||
try:
|
try:
|
||||||
payload_data = decompressor.decompress(compressed_data, max_output_size=payload_uncompressed_len)
|
payload_data = decompressor.decompress(
|
||||||
|
compressed_data, max_output_size=payload_uncompressed_len
|
||||||
|
)
|
||||||
except Exception as e2:
|
except Exception as e2:
|
||||||
raise MMIPKGError(f"解压失败: {e2}") from e2
|
raise MMIPKGError(f"解压失败: {e2}") from e2
|
||||||
else:
|
else:
|
||||||
@@ -537,7 +547,7 @@ class MMIPKGUnpacker:
|
|||||||
|
|
||||||
# 读取 manifest
|
# 读取 manifest
|
||||||
manifest_len_bytes = payload_stream.read(4)
|
manifest_len_bytes = payload_stream.read(4)
|
||||||
manifest_len = struct.unpack('>I', manifest_len_bytes)[0]
|
manifest_len = struct.unpack(">I", manifest_len_bytes)[0]
|
||||||
manifest_bytes = payload_stream.read(manifest_len)
|
manifest_bytes = payload_stream.read(manifest_len)
|
||||||
manifest = msgpack.unpackb(manifest_bytes, raw=False)
|
manifest = msgpack.unpackb(manifest_bytes, raw=False)
|
||||||
|
|
||||||
@@ -553,20 +563,21 @@ class MMIPKGUnpacker:
|
|||||||
print(f" 表情包数量: {len(items)}")
|
print(f" 表情包数量: {len(items)}")
|
||||||
|
|
||||||
# 导入表情包
|
# 导入表情包
|
||||||
return self._import_items(payload_stream, items, output_dir,
|
return self._import_items(payload_stream, items, output_dir, replace_existing, batch_size)
|
||||||
replace_existing, batch_size)
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"导入失败: {e}")
|
print(f"导入失败: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
return False
|
return False
|
||||||
finally:
|
finally:
|
||||||
if not db.is_closed():
|
if not db.is_closed():
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
def _import_items(self, payload_stream: BinaryIO, items: List[Dict],
|
def _import_items(
|
||||||
output_dir: str, replace_existing: bool, batch_size: int) -> bool:
|
self, payload_stream: BinaryIO, items: List[Dict], output_dir: str, replace_existing: bool, batch_size: int
|
||||||
|
) -> bool:
|
||||||
"""导入 items 到数据库"""
|
"""导入 items 到数据库"""
|
||||||
try:
|
try:
|
||||||
imported_count = 0
|
imported_count = 0
|
||||||
@@ -581,7 +592,7 @@ class MMIPKGUnpacker:
|
|||||||
BarColumn(),
|
BarColumn(),
|
||||||
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
|
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
|
||||||
TimeRemainingColumn(),
|
TimeRemainingColumn(),
|
||||||
console=console
|
console=console,
|
||||||
) as progress:
|
) as progress:
|
||||||
task = progress.add_task("[cyan]导入表情包...", total=len(items))
|
task = progress.add_task("[cyan]导入表情包...", total=len(items))
|
||||||
|
|
||||||
@@ -597,7 +608,7 @@ class MMIPKGUnpacker:
|
|||||||
progress.advance(task)
|
progress.advance(task)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
img_len = struct.unpack('>I', img_len_bytes)[0]
|
img_len = struct.unpack(">I", img_len_bytes)[0]
|
||||||
img_bytes = payload_stream.read(img_len)
|
img_bytes = payload_stream.read(img_len)
|
||||||
|
|
||||||
if len(img_bytes) != img_len:
|
if len(img_bytes) != img_len:
|
||||||
@@ -641,7 +652,7 @@ class MMIPKGUnpacker:
|
|||||||
file_path = os.path.join(output_dir, filename)
|
file_path = os.path.join(output_dir, filename)
|
||||||
counter += 1
|
counter += 1
|
||||||
|
|
||||||
with open(file_path, 'wb') as img_file:
|
with open(file_path, "wb") as img_file:
|
||||||
img_file.write(img_bytes)
|
img_file.write(img_bytes)
|
||||||
|
|
||||||
# 准备数据库记录
|
# 准备数据库记录
|
||||||
@@ -700,6 +711,7 @@ class MMIPKGUnpacker:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
console.print(f"[red]导入 items 失败: {e}[/red]")
|
console.print(f"[red]导入 items 失败: {e}[/red]")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@@ -719,8 +731,9 @@ def print_menu():
|
|||||||
console.print(" [2] [bold]导入表情包[/bold] (从 .mmipkg 文件导入到数据库)")
|
console.print(" [2] [bold]导入表情包[/bold] (从 .mmipkg 文件导入到数据库)")
|
||||||
console.print(" [0] [bold]退出[/bold]")
|
console.print(" [0] [bold]退出[/bold]")
|
||||||
console.print()
|
console.print()
|
||||||
def get_input(prompt: str, default: Optional[str] = None,
|
|
||||||
choices: Optional[List[str]] = None) -> str:
|
|
||||||
|
def get_input(prompt: str, default: Optional[str] = None, choices: Optional[List[str]] = None) -> str:
|
||||||
"""获取用户输入"""
|
"""获取用户输入"""
|
||||||
if default:
|
if default:
|
||||||
prompt = f"{prompt} (默认: {default})"
|
prompt = f"{prompt} (默认: {default})"
|
||||||
@@ -760,9 +773,9 @@ def get_yes_no(prompt: str, default: bool = False) -> bool:
|
|||||||
if not value:
|
if not value:
|
||||||
return default
|
return default
|
||||||
|
|
||||||
if value in ('y', 'yes', '是'):
|
if value in ("y", "yes", "是"):
|
||||||
return True
|
return True
|
||||||
elif value in ('n', 'no', '否'):
|
elif value in ("n", "no", "否"):
|
||||||
return False
|
return False
|
||||||
else:
|
else:
|
||||||
console.print(" [yellow]⚠ 请输入 y/yes/是 或 n/no/否[/yellow]")
|
console.print(" [yellow]⚠ 请输入 y/yes/是 或 n/no/否[/yellow]")
|
||||||
@@ -843,8 +856,8 @@ def interactive_export():
|
|||||||
output_path = get_input(" 输出文件路径", default_filename)
|
output_path = get_input(" 输出文件路径", default_filename)
|
||||||
|
|
||||||
# 确保有 .mmipkg 扩展名
|
# 确保有 .mmipkg 扩展名
|
||||||
if not output_path.endswith('.mmipkg'):
|
if not output_path.endswith(".mmipkg"):
|
||||||
output_path += '.mmipkg'
|
output_path += ".mmipkg"
|
||||||
|
|
||||||
# 获取包名称
|
# 获取包名称
|
||||||
default_pack_name = f"MaiBot表情包_{datetime.now().strftime('%Y%m%d')}"
|
default_pack_name = f"MaiBot表情包_{datetime.now().strftime('%Y%m%d')}"
|
||||||
@@ -853,9 +866,7 @@ def interactive_export():
|
|||||||
# 自定义 manifest
|
# 自定义 manifest
|
||||||
console.print("\n[yellow]2. 包信息设置(可选)[/yellow]")
|
console.print("\n[yellow]2. 包信息设置(可选)[/yellow]")
|
||||||
if get_yes_no(" 是否添加包的作者和介绍信息", False):
|
if get_yes_no(" 是否添加包的作者和介绍信息", False):
|
||||||
custom_manifest = {
|
custom_manifest = {"author": author} if (author := input(" 作者名称(可选): ").strip()) else {}
|
||||||
"author": author
|
|
||||||
} if (author := input(" 作者名称(可选): ").strip()) else {}
|
|
||||||
|
|
||||||
# 介绍信息
|
# 介绍信息
|
||||||
console.print(" 包介绍(限制 100 字以内):")
|
console.print(" 包介绍(限制 100 字以内):")
|
||||||
@@ -888,9 +899,9 @@ def interactive_export():
|
|||||||
console.print(" webp: 推荐,体积小且支持透明度")
|
console.print(" webp: 推荐,体积小且支持透明度")
|
||||||
console.print(" jpeg: 最小体积,但不支持透明度")
|
console.print(" jpeg: 最小体积,但不支持透明度")
|
||||||
console.print(" png: 无损,文件较大")
|
console.print(" png: 无损,文件较大")
|
||||||
reencode = get_input(" 选择格式", "webp", ['webp', 'jpeg', 'png'])
|
reencode = get_input(" 选择格式", "webp", ["webp", "jpeg", "png"])
|
||||||
|
|
||||||
quality = get_int(" 编码质量", 80, 1, 100) if reencode in ('webp', 'jpeg') else 80
|
quality = get_int(" 编码质量", 80, 1, 100) if reencode in ("webp", "jpeg") else 80
|
||||||
else:
|
else:
|
||||||
reencode = None
|
reencode = None
|
||||||
quality = 80
|
quality = 80
|
||||||
@@ -920,10 +931,7 @@ def interactive_export():
|
|||||||
# 开始导出
|
# 开始导出
|
||||||
console.print("\n[cyan]开始导出...[/cyan]")
|
console.print("\n[cyan]开始导出...[/cyan]")
|
||||||
packer = MMIPKGPacker(
|
packer = MMIPKGPacker(
|
||||||
use_compression=use_compression,
|
use_compression=use_compression, zstd_level=zstd_level, reencode=reencode, reencode_quality=quality
|
||||||
zstd_level=zstd_level,
|
|
||||||
reencode=reencode,
|
|
||||||
reencode_quality=quality
|
|
||||||
)
|
)
|
||||||
|
|
||||||
success = packer.pack_from_db(output_path, pack_name, custom_manifest)
|
success = packer.pack_from_db(output_path, pack_name, custom_manifest)
|
||||||
@@ -944,11 +952,11 @@ def interactive_import():
|
|||||||
|
|
||||||
# 选择导入模式
|
# 选择导入模式
|
||||||
print_import_mode_selection()
|
print_import_mode_selection()
|
||||||
import_mode = get_input("请选择", "1", ['1', '2'])
|
import_mode = get_input("请选择", "1", ["1", "2"])
|
||||||
|
|
||||||
input_files = []
|
input_files = []
|
||||||
|
|
||||||
if import_mode == '1':
|
if import_mode == "1":
|
||||||
# 自动扫描模式
|
# 自动扫描模式
|
||||||
import_dir = os.path.join(PROJECT_ROOT, "data", "import_emoji")
|
import_dir = os.path.join(PROJECT_ROOT, "data", "import_emoji")
|
||||||
os.makedirs(import_dir, exist_ok=True)
|
os.makedirs(import_dir, exist_ok=True)
|
||||||
@@ -957,7 +965,7 @@ def interactive_import():
|
|||||||
|
|
||||||
# 查找所有 .mmipkg 文件
|
# 查找所有 .mmipkg 文件
|
||||||
for file in os.listdir(import_dir):
|
for file in os.listdir(import_dir):
|
||||||
if file.endswith('.mmipkg'):
|
if file.endswith(".mmipkg"):
|
||||||
file_path = os.path.join(import_dir, file)
|
file_path = os.path.join(import_dir, file)
|
||||||
if os.path.isfile(file_path):
|
if os.path.isfile(file_path):
|
||||||
input_files.append(file_path)
|
input_files.append(file_path)
|
||||||
@@ -1032,7 +1040,7 @@ def interactive_import():
|
|||||||
BarColumn(),
|
BarColumn(),
|
||||||
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
|
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
|
||||||
TimeElapsedColumn(),
|
TimeElapsedColumn(),
|
||||||
console=console
|
console=console,
|
||||||
) as progress:
|
) as progress:
|
||||||
task = progress.add_task("[cyan]导入文件...", total=len(input_files))
|
task = progress.add_task("[cyan]导入文件...", total=len(input_files))
|
||||||
|
|
||||||
@@ -1044,10 +1052,7 @@ def interactive_import():
|
|||||||
console.print(f"[bold]{'=' * 70}[/bold]")
|
console.print(f"[bold]{'=' * 70}[/bold]")
|
||||||
|
|
||||||
success = unpacker.import_to_db(
|
success = unpacker.import_to_db(
|
||||||
input_path,
|
input_path, output_dir=output_dir, replace_existing=replace_existing, batch_size=batch_size
|
||||||
output_dir=output_dir,
|
|
||||||
replace_existing=replace_existing,
|
|
||||||
batch_size=batch_size
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if success:
|
if success:
|
||||||
@@ -1076,16 +1081,16 @@ def main():
|
|||||||
while True:
|
while True:
|
||||||
print_menu()
|
print_menu()
|
||||||
try:
|
try:
|
||||||
choice = get_input("请选择", "1", ['0', '1', '2'])
|
choice = get_input("请选择", "1", ["0", "1", "2"])
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
console.print("\n[green]再见![/green]")
|
console.print("\n[green]再见![/green]")
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
if choice == '0':
|
if choice == "0":
|
||||||
console.print("\n[green]再见![/green]")
|
console.print("\n[green]再见![/green]")
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
elif choice == '1':
|
elif choice == "1":
|
||||||
try:
|
try:
|
||||||
interactive_export()
|
interactive_export()
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
@@ -1093,6 +1098,7 @@ def main():
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
console.print(f"\n[red]✗ 发生错误: {e}[/red]")
|
console.print(f"\n[red]✗ 发生错误: {e}[/red]")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -1100,7 +1106,7 @@ def main():
|
|||||||
except (KeyboardInterrupt, EOFError):
|
except (KeyboardInterrupt, EOFError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
elif choice == '2':
|
elif choice == "2":
|
||||||
try:
|
try:
|
||||||
interactive_import()
|
interactive_import()
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
@@ -1108,6 +1114,7 @@ def main():
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
console.print(f"\n[red]✗ 发生错误: {e}[/red]")
|
console.print(f"\n[red]✗ 发生错误: {e}[/red]")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -1121,5 +1128,5 @@ def main():
|
|||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
sys.exit(main())
|
sys.exit(main())
|
||||||
|
|||||||
@@ -230,7 +230,7 @@ class HeartFChatting:
|
|||||||
if (message.is_mentioned or message.is_at) and global_config.chat.mentioned_bot_reply:
|
if (message.is_mentioned or message.is_at) and global_config.chat.mentioned_bot_reply:
|
||||||
mentioned_message = message
|
mentioned_message = message
|
||||||
|
|
||||||
logger.info(f"{self.log_prefix} 当前talk_value: {global_config.chat.get_talk_value(self.stream_id)}")
|
# logger.info(f"{self.log_prefix} 当前talk_value: {global_config.chat.get_talk_value(self.stream_id)}")
|
||||||
|
|
||||||
# *控制频率用
|
# *控制频率用
|
||||||
if mentioned_message:
|
if mentioned_message:
|
||||||
@@ -334,7 +334,6 @@ class HeartFChatting:
|
|||||||
self.consecutive_no_reply_count = 0
|
self.consecutive_no_reply_count = 0
|
||||||
reason = ""
|
reason = ""
|
||||||
|
|
||||||
|
|
||||||
await database_api.store_action_info(
|
await database_api.store_action_info(
|
||||||
chat_stream=self.chat_stream,
|
chat_stream=self.chat_stream,
|
||||||
action_build_into_prompt=False,
|
action_build_into_prompt=False,
|
||||||
@@ -411,7 +410,7 @@ class HeartFChatting:
|
|||||||
# asyncio.create_task(self.chat_history_summarizer.process())
|
# asyncio.create_task(self.chat_history_summarizer.process())
|
||||||
|
|
||||||
cycle_timers, thinking_id = self.start_cycle()
|
cycle_timers, thinking_id = self.start_cycle()
|
||||||
logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考")
|
logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考(频率: {global_config.chat.get_talk_value(self.stream_id)})")
|
||||||
|
|
||||||
# 第一步:动作检查
|
# 第一步:动作检查
|
||||||
available_actions: Dict[str, ActionInfo] = {}
|
available_actions: Dict[str, ActionInfo] = {}
|
||||||
|
|||||||
@@ -30,9 +30,11 @@ DATA_PATH = os.path.join(ROOT_PATH, "data")
|
|||||||
qa_manager = None
|
qa_manager = None
|
||||||
inspire_manager = None
|
inspire_manager = None
|
||||||
|
|
||||||
|
|
||||||
def get_qa_manager():
|
def get_qa_manager():
|
||||||
return qa_manager
|
return qa_manager
|
||||||
|
|
||||||
|
|
||||||
def lpmm_start_up(): # sourcery skip: extract-duplicate-method
|
def lpmm_start_up(): # sourcery skip: extract-duplicate-method
|
||||||
# 检查LPMM知识库是否启用
|
# 检查LPMM知识库是否启用
|
||||||
if global_config.lpmm_knowledge.enable:
|
if global_config.lpmm_knowledge.enable:
|
||||||
|
|||||||
@@ -92,9 +92,10 @@ class QAManager:
|
|||||||
# 过滤阈值
|
# 过滤阈值
|
||||||
result = dyn_select_top_k(result, 0.5, 1.0)
|
result = dyn_select_top_k(result, 0.5, 1.0)
|
||||||
|
|
||||||
for res in result:
|
if global_config.debug.show_lpmm_paragraph:
|
||||||
raw_paragraph = self.embed_manager.paragraphs_embedding_store.store[res[0]].str
|
for res in result:
|
||||||
logger.info(f"找到相关文段,相关系数:{res[1]:.8f}\n{raw_paragraph}\n\n")
|
raw_paragraph = self.embed_manager.paragraphs_embedding_store.store[res[0]].str
|
||||||
|
logger.info(f"找到相关文段,相关系数:{res[1]:.8f}\n{raw_paragraph}\n\n")
|
||||||
|
|
||||||
return result, ppr_node_weights
|
return result, ppr_node_weights
|
||||||
|
|
||||||
@@ -128,11 +129,10 @@ class QAManager:
|
|||||||
selected_knowledge = knowledge[:limit]
|
selected_knowledge = knowledge[:limit]
|
||||||
|
|
||||||
formatted_knowledge = [
|
formatted_knowledge = [
|
||||||
f"第{i + 1}条知识:{k[0]}\n 该条知识对于问题的相关性:{k[1]}"
|
f"第{i + 1}条知识:{k[0]}\n 该条知识对于问题的相关性:{k[1]}" for i, k in enumerate(selected_knowledge)
|
||||||
for i, k in enumerate(selected_knowledge)
|
|
||||||
]
|
]
|
||||||
# if max_score is not None:
|
# if max_score is not None:
|
||||||
# formatted_knowledge.insert(0, f"最高相关系数:{max_score}")
|
# formatted_knowledge.insert(0, f"最高相关系数:{max_score}")
|
||||||
|
|
||||||
found_knowledge = "\n".join(formatted_knowledge)
|
found_knowledge = "\n".join(formatted_knowledge)
|
||||||
if len(found_knowledge) > MAX_KNOWLEDGE_LENGTH:
|
if len(found_knowledge) > MAX_KNOWLEDGE_LENGTH:
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ from typing import Dict, Optional, Tuple, List, TYPE_CHECKING, Union
|
|||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from json_repair import repair_json
|
from json_repair import repair_json
|
||||||
|
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.config.config import global_config, model_config
|
from src.config.config import global_config, model_config
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
@@ -164,6 +163,45 @@ class ActionPlanner:
|
|||||||
return item[1]
|
return item[1]
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def _replace_message_ids_with_text(
|
||||||
|
self, text: Optional[str], message_id_list: List[Tuple[str, "DatabaseMessages"]]
|
||||||
|
) -> Optional[str]:
|
||||||
|
"""将文本中的 m+数字 消息ID替换为原消息内容,并添加双引号"""
|
||||||
|
if not text:
|
||||||
|
return text
|
||||||
|
|
||||||
|
id_to_message = {msg_id: msg for msg_id, msg in message_id_list}
|
||||||
|
|
||||||
|
# 匹配m后带2-4位数字,前后不是字母数字下划线
|
||||||
|
pattern = r"(?<![A-Za-z0-9_])m\d{2,4}(?![A-Za-z0-9_])"
|
||||||
|
|
||||||
|
matches = re.findall(pattern, text)
|
||||||
|
if matches:
|
||||||
|
available_ids = set(id_to_message.keys())
|
||||||
|
found_ids = set(matches)
|
||||||
|
missing_ids = found_ids - available_ids
|
||||||
|
if missing_ids:
|
||||||
|
logger.info(f"{self.log_prefix}planner理由中引用的消息ID不在当前上下文中: {missing_ids}, 可用ID: {list(available_ids)[:10]}...")
|
||||||
|
logger.info(f"{self.log_prefix}planner理由替换: 找到{len(matches)}个消息ID引用,其中{len(found_ids & available_ids)}个在上下文中")
|
||||||
|
|
||||||
|
def _replace(match: re.Match[str]) -> str:
|
||||||
|
msg_id = match.group(0)
|
||||||
|
message = id_to_message.get(msg_id)
|
||||||
|
if not message:
|
||||||
|
logger.warning(f"{self.log_prefix}planner理由引用 {msg_id} 未找到对应消息,保持原样")
|
||||||
|
return msg_id
|
||||||
|
|
||||||
|
msg_text = (message.processed_plain_text or message.display_message or "").strip()
|
||||||
|
if not msg_text:
|
||||||
|
logger.warning(f"{self.log_prefix}planner理由引用 {msg_id} 的消息内容为空,保持原样")
|
||||||
|
return msg_id
|
||||||
|
|
||||||
|
preview = msg_text if len(msg_text) <= 100 else f"{msg_text[:97]}..."
|
||||||
|
logger.info(f"{self.log_prefix}planner理由引用 {msg_id} -> 消息({preview})")
|
||||||
|
return f"消息({msg_text})"
|
||||||
|
|
||||||
|
return re.sub(pattern, _replace, text)
|
||||||
|
|
||||||
def _parse_single_action(
|
def _parse_single_action(
|
||||||
self,
|
self,
|
||||||
action_json: dict,
|
action_json: dict,
|
||||||
@@ -176,7 +214,10 @@ class ActionPlanner:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
action = action_json.get("action", "no_reply")
|
action = action_json.get("action", "no_reply")
|
||||||
reasoning = action_json.get("reason", "未提供原因")
|
original_reasoning = action_json.get("reason", "未提供原因")
|
||||||
|
reasoning = self._replace_message_ids_with_text(original_reasoning, message_id_list)
|
||||||
|
if reasoning is None:
|
||||||
|
reasoning = original_reasoning
|
||||||
action_data = {key: value for key, value in action_json.items() if key not in ["action", "reason"]}
|
action_data = {key: value for key, value in action_json.items() if key not in ["action", "reason"]}
|
||||||
# 非no_reply动作需要target_message_id
|
# 非no_reply动作需要target_message_id
|
||||||
target_message = None
|
target_message = None
|
||||||
@@ -573,9 +614,6 @@ class ActionPlanner:
|
|||||||
# 调用LLM
|
# 调用LLM
|
||||||
llm_content, (reasoning_content, _, _) = await self.planner_llm.generate_response_async(prompt=prompt)
|
llm_content, (reasoning_content, _, _) = await self.planner_llm.generate_response_async(prompt=prompt)
|
||||||
|
|
||||||
# logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
|
|
||||||
# logger.info(f"{self.log_prefix}规划器原始响应: {llm_content}")
|
|
||||||
|
|
||||||
if global_config.debug.show_planner_prompt:
|
if global_config.debug.show_planner_prompt:
|
||||||
logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
|
logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
|
||||||
logger.info(f"{self.log_prefix}规划器原始响应: {llm_content}")
|
logger.info(f"{self.log_prefix}规划器原始响应: {llm_content}")
|
||||||
@@ -604,6 +642,7 @@ class ActionPlanner:
|
|||||||
if llm_content:
|
if llm_content:
|
||||||
try:
|
try:
|
||||||
json_objects, extracted_reasoning = self._extract_json_from_markdown(llm_content)
|
json_objects, extracted_reasoning = self._extract_json_from_markdown(llm_content)
|
||||||
|
extracted_reasoning = self._replace_message_ids_with_text(extracted_reasoning, message_id_list) or ""
|
||||||
if json_objects:
|
if json_objects:
|
||||||
logger.debug(f"{self.log_prefix}从响应中提取到{len(json_objects)}个JSON对象")
|
logger.debug(f"{self.log_prefix}从响应中提取到{len(json_objects)}个JSON对象")
|
||||||
filtered_actions_list = list(filtered_actions.items())
|
filtered_actions_list = list(filtered_actions.items())
|
||||||
|
|||||||
@@ -226,7 +226,9 @@ class DefaultReplyer:
|
|||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
return False, llm_response
|
return False, llm_response
|
||||||
|
|
||||||
async def build_expression_habits(self, chat_history: str, target: str, reply_reason: str = "") -> Tuple[str, List[int]]:
|
async def build_expression_habits(
|
||||||
|
self, chat_history: str, target: str, reply_reason: str = ""
|
||||||
|
) -> Tuple[str, List[int]]:
|
||||||
# sourcery skip: for-append-to-extend
|
# sourcery skip: for-append-to-extend
|
||||||
"""构建表达习惯块
|
"""构建表达习惯块
|
||||||
|
|
||||||
|
|||||||
@@ -241,7 +241,9 @@ class PrivateReplyer:
|
|||||||
|
|
||||||
return f"{sender_relation}"
|
return f"{sender_relation}"
|
||||||
|
|
||||||
async def build_expression_habits(self, chat_history: str, target: str, reply_reason: str = "") -> Tuple[str, List[int]]:
|
async def build_expression_habits(
|
||||||
|
self, chat_history: str, target: str, reply_reason: str = ""
|
||||||
|
) -> Tuple[str, List[int]]:
|
||||||
# sourcery skip: for-append-to-extend
|
# sourcery skip: for-append-to-extend
|
||||||
"""构建表达习惯块
|
"""构建表达习惯块
|
||||||
|
|
||||||
|
|||||||
@@ -107,7 +107,7 @@ class ChatHistorySummarizer:
|
|||||||
self.last_check_time = current_time
|
self.last_check_time = current_time
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info(
|
logger.debug(
|
||||||
f"{self.log_prefix} 开始处理聊天概括,时间窗口: {self.last_check_time:.2f} -> {current_time:.2f}"
|
f"{self.log_prefix} 开始处理聊天概括,时间窗口: {self.last_check_time:.2f} -> {current_time:.2f}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -119,7 +119,7 @@ class ChatHistorySummarizer:
|
|||||||
before_count = len(self.current_batch.messages)
|
before_count = len(self.current_batch.messages)
|
||||||
self.current_batch.messages.extend(new_messages)
|
self.current_batch.messages.extend(new_messages)
|
||||||
self.current_batch.end_time = current_time
|
self.current_batch.end_time = current_time
|
||||||
logger.info(f"{self.log_prefix} 批次更新: {before_count} -> {len(self.current_batch.messages)} 条消息")
|
logger.info(f"{self.log_prefix} 更新聊天话题: {before_count} -> {len(self.current_batch.messages)} 条消息")
|
||||||
else:
|
else:
|
||||||
# 创建新批次
|
# 创建新批次
|
||||||
self.current_batch = MessageBatch(
|
self.current_batch = MessageBatch(
|
||||||
@@ -127,7 +127,7 @@ class ChatHistorySummarizer:
|
|||||||
start_time=new_messages[0].time if new_messages else current_time,
|
start_time=new_messages[0].time if new_messages else current_time,
|
||||||
end_time=current_time,
|
end_time=current_time,
|
||||||
)
|
)
|
||||||
logger.info(f"{self.log_prefix} 新建批次: {len(new_messages)} 条消息")
|
logger.info(f"{self.log_prefix} 新建聊天话题: {len(new_messages)} 条消息")
|
||||||
|
|
||||||
# 检查是否需要打包
|
# 检查是否需要打包
|
||||||
await self._check_and_package(current_time)
|
await self._check_and_package(current_time)
|
||||||
|
|||||||
@@ -204,8 +204,9 @@ class WebSocketLogHandler(logging.Handler):
|
|||||||
message = formatted_msg
|
message = formatted_msg
|
||||||
try:
|
try:
|
||||||
import json
|
import json
|
||||||
|
|
||||||
log_dict = json.loads(formatted_msg)
|
log_dict = json.loads(formatted_msg)
|
||||||
message = log_dict.get('event', formatted_msg)
|
message = log_dict.get("event", formatted_msg)
|
||||||
except (json.JSONDecodeError, ValueError):
|
except (json.JSONDecodeError, ValueError):
|
||||||
# 不是 JSON,直接使用消息
|
# 不是 JSON,直接使用消息
|
||||||
message = formatted_msg
|
message = formatted_msg
|
||||||
@@ -228,10 +229,7 @@ class WebSocketLogHandler(logging.Handler):
|
|||||||
import asyncio
|
import asyncio
|
||||||
from src.webui.logs_ws import broadcast_log
|
from src.webui.logs_ws import broadcast_log
|
||||||
|
|
||||||
asyncio.run_coroutine_threadsafe(
|
asyncio.run_coroutine_threadsafe(broadcast_log(log_data), self.loop)
|
||||||
broadcast_log(log_data),
|
|
||||||
self.loop
|
|
||||||
)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
# WebSocket 推送失败不影响日志记录
|
# WebSocket 推送失败不影响日志记录
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
from fastapi import FastAPI, APIRouter
|
from fastapi import FastAPI, APIRouter
|
||||||
from fastapi.middleware.cors import CORSMiddleware # 新增导入
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from uvicorn import Config, Server as UvicornServer
|
from uvicorn import Config, Server as UvicornServer
|
||||||
import asyncio
|
import asyncio
|
||||||
@@ -17,21 +16,6 @@ class Server:
|
|||||||
self._server: Optional[UvicornServer] = None
|
self._server: Optional[UvicornServer] = None
|
||||||
self.set_address(host, port)
|
self.set_address(host, port)
|
||||||
|
|
||||||
# 配置 CORS
|
|
||||||
origins = [
|
|
||||||
"http://localhost:7999", # 允许的前端源
|
|
||||||
"http://127.0.0.1:7999",
|
|
||||||
# 在生产环境中,您应该添加实际的前端域名
|
|
||||||
]
|
|
||||||
|
|
||||||
self.app.add_middleware(
|
|
||||||
CORSMiddleware,
|
|
||||||
allow_origins=origins,
|
|
||||||
allow_credentials=True, # 是否支持 cookie
|
|
||||||
allow_methods=["*"], # 允许所有 HTTP 方法
|
|
||||||
allow_headers=["*"], # 允许所有 HTTP 请求头
|
|
||||||
)
|
|
||||||
|
|
||||||
def register_router(self, router: APIRouter, prefix: str = ""):
|
def register_router(self, router: APIRouter, prefix: str = ""):
|
||||||
"""注册路由
|
"""注册路由
|
||||||
|
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ TEMPLATE_DIR = os.path.join(PROJECT_ROOT, "template")
|
|||||||
|
|
||||||
# 考虑到,实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
|
# 考虑到,实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
|
||||||
# 对该字段的更新,请严格参照语义化版本规范:https://semver.org/lang/zh-CN/
|
# 对该字段的更新,请严格参照语义化版本规范:https://semver.org/lang/zh-CN/
|
||||||
MMC_VERSION = "0.11.3"
|
MMC_VERSION = "0.11.5"
|
||||||
|
|
||||||
|
|
||||||
def get_key_comment(toml_table, key):
|
def get_key_comment(toml_table, key):
|
||||||
|
|||||||
@@ -581,9 +581,15 @@ class DebugConfig(ConfigBase):
|
|||||||
show_jargon_prompt: bool = False
|
show_jargon_prompt: bool = False
|
||||||
"""是否显示jargon相关提示词"""
|
"""是否显示jargon相关提示词"""
|
||||||
|
|
||||||
|
show_memory_prompt: bool = False
|
||||||
|
"""是否显示记忆检索相关prompt"""
|
||||||
|
|
||||||
show_planner_prompt: bool = False
|
show_planner_prompt: bool = False
|
||||||
"""是否显示planner相关提示词"""
|
"""是否显示planner相关提示词"""
|
||||||
|
|
||||||
|
show_lpmm_paragraph: bool = False
|
||||||
|
"""是否显示lpmm找到的相关文段日志"""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ExperimentalConfig(ConfigBase):
|
class ExperimentalConfig(ConfigBase):
|
||||||
|
|||||||
@@ -467,11 +467,7 @@ class ExpressionLearner:
|
|||||||
up_content: str,
|
up_content: str,
|
||||||
current_time: float,
|
current_time: float,
|
||||||
) -> None:
|
) -> None:
|
||||||
expr_obj = (
|
expr_obj = Expression.select().where((Expression.chat_id == self.chat_id) & (Expression.style == style)).first()
|
||||||
Expression.select()
|
|
||||||
.where((Expression.chat_id == self.chat_id) & (Expression.style == style))
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
if expr_obj:
|
if expr_obj:
|
||||||
await self._update_existing_expression(
|
await self._update_existing_expression(
|
||||||
|
|||||||
@@ -42,8 +42,6 @@ def init_prompt():
|
|||||||
Prompt(expression_evaluation_prompt, "expression_evaluation_prompt")
|
Prompt(expression_evaluation_prompt, "expression_evaluation_prompt")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ExpressionSelector:
|
class ExpressionSelector:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.llm_model = LLMRequest(
|
self.llm_model = LLMRequest(
|
||||||
@@ -262,7 +260,6 @@ class ExpressionSelector:
|
|||||||
# 4. 调用LLM
|
# 4. 调用LLM
|
||||||
content, (reasoning_content, model_name, _) = await self.llm_model.generate_response_async(prompt=prompt)
|
content, (reasoning_content, model_name, _) = await self.llm_model.generate_response_async(prompt=prompt)
|
||||||
|
|
||||||
|
|
||||||
# print(prompt)
|
# print(prompt)
|
||||||
|
|
||||||
if not content:
|
if not content:
|
||||||
|
|||||||
@@ -36,10 +36,7 @@ def _contains_bot_self_name(content: str) -> bool:
|
|||||||
|
|
||||||
target = content.strip().lower()
|
target = content.strip().lower()
|
||||||
nickname = str(getattr(bot_config, "nickname", "") or "").strip().lower()
|
nickname = str(getattr(bot_config, "nickname", "") or "").strip().lower()
|
||||||
alias_names = [
|
alias_names = [str(alias or "").strip().lower() for alias in getattr(bot_config, "alias_names", []) or []]
|
||||||
str(alias or "").strip().lower()
|
|
||||||
for alias in getattr(bot_config, "alias_names", []) or []
|
|
||||||
]
|
|
||||||
|
|
||||||
candidates = [name for name in [nickname, *alias_names] if name]
|
candidates = [name for name in [nickname, *alias_names] if name]
|
||||||
|
|
||||||
@@ -188,9 +185,7 @@ async def _enrich_raw_content_if_needed(
|
|||||||
# 获取该消息的前三条消息
|
# 获取该消息的前三条消息
|
||||||
try:
|
try:
|
||||||
previous_messages = get_raw_msg_before_timestamp_with_chat(
|
previous_messages = get_raw_msg_before_timestamp_with_chat(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id, timestamp=target_message.time, limit=3
|
||||||
timestamp=target_message.time,
|
|
||||||
limit=3
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if previous_messages:
|
if previous_messages:
|
||||||
@@ -245,7 +240,7 @@ def _should_infer_meaning(jargon_obj: Jargon) -> bool:
|
|||||||
last_inference = jargon_obj.last_inference_count or 0
|
last_inference = jargon_obj.last_inference_count or 0
|
||||||
|
|
||||||
# 阈值列表:3,6, 10, 20, 40, 60, 100
|
# 阈值列表:3,6, 10, 20, 40, 60, 100
|
||||||
thresholds = [3,6, 10, 20, 40, 60, 100]
|
thresholds = [3, 6, 10, 20, 40, 60, 100]
|
||||||
|
|
||||||
if count < thresholds[0]:
|
if count < thresholds[0]:
|
||||||
return False
|
return False
|
||||||
@@ -311,7 +306,9 @@ class JargonMiner:
|
|||||||
raw_content_list = []
|
raw_content_list = []
|
||||||
if raw_content_str:
|
if raw_content_str:
|
||||||
try:
|
try:
|
||||||
raw_content_list = json.loads(raw_content_str) if isinstance(raw_content_str, str) else raw_content_str
|
raw_content_list = (
|
||||||
|
json.loads(raw_content_str) if isinstance(raw_content_str, str) else raw_content_str
|
||||||
|
)
|
||||||
if not isinstance(raw_content_list, list):
|
if not isinstance(raw_content_list, list):
|
||||||
raw_content_list = [raw_content_list] if raw_content_list else []
|
raw_content_list = [raw_content_list] if raw_content_list else []
|
||||||
except (json.JSONDecodeError, TypeError):
|
except (json.JSONDecodeError, TypeError):
|
||||||
@@ -360,7 +357,6 @@ class JargonMiner:
|
|||||||
jargon_obj.save()
|
jargon_obj.save()
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
# 步骤2: 仅基于content推断
|
# 步骤2: 仅基于content推断
|
||||||
prompt2 = await global_prompt_manager.format_prompt(
|
prompt2 = await global_prompt_manager.format_prompt(
|
||||||
"jargon_inference_content_only_prompt",
|
"jargon_inference_content_only_prompt",
|
||||||
@@ -388,11 +384,10 @@ class JargonMiner:
|
|||||||
logger.error(f"jargon {content} 推断2解析失败: {e}")
|
logger.error(f"jargon {content} 推断2解析失败: {e}")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# logger.info(f"jargon {content} 推断2提示词: {prompt2}")
|
||||||
logger.info(f"jargon {content} 推断2提示词: {prompt2}")
|
# logger.info(f"jargon {content} 推断2结果: {response2}")
|
||||||
logger.info(f"jargon {content} 推断2结果: {response2}")
|
# logger.info(f"jargon {content} 推断1提示词: {prompt1}")
|
||||||
logger.info(f"jargon {content} 推断1提示词: {prompt1}")
|
# logger.info(f"jargon {content} 推断1结果: {response1}")
|
||||||
logger.info(f"jargon {content} 推断1结果: {response1}")
|
|
||||||
|
|
||||||
if global_config.debug.show_jargon_prompt:
|
if global_config.debug.show_jargon_prompt:
|
||||||
logger.info(f"jargon {content} 推断2提示词: {prompt2}")
|
logger.info(f"jargon {content} 推断2提示词: {prompt2}")
|
||||||
@@ -457,7 +452,9 @@ class JargonMiner:
|
|||||||
jargon_obj.is_complete = True
|
jargon_obj.is_complete = True
|
||||||
|
|
||||||
jargon_obj.save()
|
jargon_obj.save()
|
||||||
logger.debug(f"jargon {content} 推断完成: is_jargon={is_jargon}, meaning={jargon_obj.meaning}, last_inference_count={jargon_obj.last_inference_count}, is_complete={jargon_obj.is_complete}")
|
logger.debug(
|
||||||
|
f"jargon {content} 推断完成: is_jargon={is_jargon}, meaning={jargon_obj.meaning}, last_inference_count={jargon_obj.last_inference_count}, is_complete={jargon_obj.is_complete}"
|
||||||
|
)
|
||||||
|
|
||||||
# 固定输出推断结果,格式化为可读形式
|
# 固定输出推断结果,格式化为可读形式
|
||||||
if is_jargon:
|
if is_jargon:
|
||||||
@@ -475,6 +472,7 @@ class JargonMiner:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"jargon推断失败: {e}")
|
logger.error(f"jargon推断失败: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
def should_trigger(self) -> bool:
|
def should_trigger(self) -> bool:
|
||||||
@@ -571,10 +569,7 @@ class JargonMiner:
|
|||||||
if _contains_bot_self_name(content):
|
if _contains_bot_self_name(content):
|
||||||
logger.debug(f"解析阶段跳过包含机器人昵称/别名的词条: {content}")
|
logger.debug(f"解析阶段跳过包含机器人昵称/别名的词条: {content}")
|
||||||
continue
|
continue
|
||||||
entries.append({
|
entries.append({"content": content, "raw_content": raw_content_list})
|
||||||
"content": content,
|
|
||||||
"raw_content": raw_content_list
|
|
||||||
})
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"解析jargon JSON失败: {e}; 原始: {response}")
|
logger.error(f"解析jargon JSON失败: {e}; 原始: {response}")
|
||||||
return
|
return
|
||||||
@@ -612,19 +607,10 @@ class JargonMiner:
|
|||||||
# 根据all_global配置决定查询逻辑
|
# 根据all_global配置决定查询逻辑
|
||||||
if global_config.jargon.all_global:
|
if global_config.jargon.all_global:
|
||||||
# 开启all_global:无视chat_id,查询所有content匹配的记录(所有记录都是全局的)
|
# 开启all_global:无视chat_id,查询所有content匹配的记录(所有记录都是全局的)
|
||||||
query = (
|
query = Jargon.select().where(Jargon.content == content)
|
||||||
Jargon.select()
|
|
||||||
.where(Jargon.content == content)
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
# 关闭all_global:只查询chat_id匹配的记录(不考虑is_global)
|
# 关闭all_global:只查询chat_id匹配的记录(不考虑is_global)
|
||||||
query = (
|
query = Jargon.select().where((Jargon.chat_id == self.chat_id) & (Jargon.content == content))
|
||||||
Jargon.select()
|
|
||||||
.where(
|
|
||||||
(Jargon.chat_id == self.chat_id) &
|
|
||||||
(Jargon.content == content)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if query.exists():
|
if query.exists():
|
||||||
obj = query.get()
|
obj = query.get()
|
||||||
@@ -637,7 +623,9 @@ class JargonMiner:
|
|||||||
existing_raw_content = []
|
existing_raw_content = []
|
||||||
if obj.raw_content:
|
if obj.raw_content:
|
||||||
try:
|
try:
|
||||||
existing_raw_content = json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content
|
existing_raw_content = (
|
||||||
|
json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content
|
||||||
|
)
|
||||||
if not isinstance(existing_raw_content, list):
|
if not isinstance(existing_raw_content, list):
|
||||||
existing_raw_content = [existing_raw_content] if existing_raw_content else []
|
existing_raw_content = [existing_raw_content] if existing_raw_content else []
|
||||||
except (json.JSONDecodeError, TypeError):
|
except (json.JSONDecodeError, TypeError):
|
||||||
@@ -676,7 +664,7 @@ class JargonMiner:
|
|||||||
raw_content=json.dumps(raw_content_list, ensure_ascii=False),
|
raw_content=json.dumps(raw_content_list, ensure_ascii=False),
|
||||||
chat_id=self.chat_id,
|
chat_id=self.chat_id,
|
||||||
is_global=is_global_new,
|
is_global=is_global_new,
|
||||||
count=1
|
count=1,
|
||||||
)
|
)
|
||||||
saved += 1
|
saved += 1
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -720,11 +708,7 @@ async def extract_and_store_jargon(chat_id: str) -> None:
|
|||||||
|
|
||||||
|
|
||||||
def search_jargon(
|
def search_jargon(
|
||||||
keyword: str,
|
keyword: str, chat_id: Optional[str] = None, limit: int = 10, case_sensitive: bool = False, fuzzy: bool = True
|
||||||
chat_id: Optional[str] = None,
|
|
||||||
limit: int = 10,
|
|
||||||
case_sensitive: bool = False,
|
|
||||||
fuzzy: bool = True
|
|
||||||
) -> List[Dict[str, str]]:
|
) -> List[Dict[str, str]]:
|
||||||
"""
|
"""
|
||||||
搜索jargon,支持大小写不敏感和模糊搜索
|
搜索jargon,支持大小写不敏感和模糊搜索
|
||||||
@@ -747,10 +731,7 @@ def search_jargon(
|
|||||||
keyword = keyword.strip()
|
keyword = keyword.strip()
|
||||||
|
|
||||||
# 构建查询
|
# 构建查询
|
||||||
query = Jargon.select(
|
query = Jargon.select(Jargon.content, Jargon.meaning)
|
||||||
Jargon.content,
|
|
||||||
Jargon.meaning
|
|
||||||
)
|
|
||||||
|
|
||||||
# 构建搜索条件
|
# 构建搜索条件
|
||||||
if case_sensitive:
|
if case_sensitive:
|
||||||
@@ -760,7 +741,7 @@ def search_jargon(
|
|||||||
search_condition = Jargon.content.contains(keyword)
|
search_condition = Jargon.content.contains(keyword)
|
||||||
else:
|
else:
|
||||||
# 精确匹配
|
# 精确匹配
|
||||||
search_condition = (Jargon.content == keyword)
|
search_condition = Jargon.content == keyword
|
||||||
else:
|
else:
|
||||||
# 大小写不敏感
|
# 大小写不敏感
|
||||||
if fuzzy:
|
if fuzzy:
|
||||||
@@ -768,7 +749,7 @@ def search_jargon(
|
|||||||
search_condition = fn.LOWER(Jargon.content).contains(keyword.lower())
|
search_condition = fn.LOWER(Jargon.content).contains(keyword.lower())
|
||||||
else:
|
else:
|
||||||
# 精确匹配(使用LOWER函数)
|
# 精确匹配(使用LOWER函数)
|
||||||
search_condition = (fn.LOWER(Jargon.content) == keyword.lower())
|
search_condition = fn.LOWER(Jargon.content) == keyword.lower()
|
||||||
|
|
||||||
query = query.where(search_condition)
|
query = query.where(search_condition)
|
||||||
|
|
||||||
@@ -779,14 +760,10 @@ def search_jargon(
|
|||||||
else:
|
else:
|
||||||
# 关闭all_global:如果提供了chat_id,优先搜索该聊天或global的jargon
|
# 关闭all_global:如果提供了chat_id,优先搜索该聊天或global的jargon
|
||||||
if chat_id:
|
if chat_id:
|
||||||
query = query.where(
|
query = query.where((Jargon.chat_id == chat_id) | Jargon.is_global)
|
||||||
(Jargon.chat_id == chat_id) | Jargon.is_global
|
|
||||||
)
|
|
||||||
|
|
||||||
# 只返回有meaning的记录
|
# 只返回有meaning的记录
|
||||||
query = query.where(
|
query = query.where((Jargon.meaning.is_null(False)) & (Jargon.meaning != ""))
|
||||||
(Jargon.meaning.is_null(False)) & (Jargon.meaning != "")
|
|
||||||
)
|
|
||||||
|
|
||||||
# 按count降序排序,优先返回出现频率高的
|
# 按count降序排序,优先返回出现频率高的
|
||||||
query = query.order_by(Jargon.count.desc())
|
query = query.order_by(Jargon.count.desc())
|
||||||
@@ -797,10 +774,7 @@ def search_jargon(
|
|||||||
# 执行查询并返回结果
|
# 执行查询并返回结果
|
||||||
results = []
|
results = []
|
||||||
for jargon in query:
|
for jargon in query:
|
||||||
results.append({
|
results.append({"content": jargon.content or "", "meaning": jargon.meaning or ""})
|
||||||
"content": jargon.content or "",
|
|
||||||
"meaning": jargon.meaning or ""
|
|
||||||
})
|
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
@@ -840,10 +814,7 @@ async def store_jargon_from_answer(jargon_keyword: str, answer: str, chat_id: st
|
|||||||
if global_config.jargon.all_global:
|
if global_config.jargon.all_global:
|
||||||
query = Jargon.select().where(Jargon.content == jargon_keyword)
|
query = Jargon.select().where(Jargon.content == jargon_keyword)
|
||||||
else:
|
else:
|
||||||
query = Jargon.select().where(
|
query = Jargon.select().where((Jargon.chat_id == chat_id) & (Jargon.content == jargon_keyword))
|
||||||
(Jargon.chat_id == chat_id) &
|
|
||||||
(Jargon.content == jargon_keyword)
|
|
||||||
)
|
|
||||||
|
|
||||||
if query.exists():
|
if query.exists():
|
||||||
# 更新现有记录
|
# 更新现有记录
|
||||||
@@ -854,7 +825,9 @@ async def store_jargon_from_answer(jargon_keyword: str, answer: str, chat_id: st
|
|||||||
existing_raw_content = []
|
existing_raw_content = []
|
||||||
if obj.raw_content:
|
if obj.raw_content:
|
||||||
try:
|
try:
|
||||||
existing_raw_content = json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content
|
existing_raw_content = (
|
||||||
|
json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content
|
||||||
|
)
|
||||||
if not isinstance(existing_raw_content, list):
|
if not isinstance(existing_raw_content, list):
|
||||||
existing_raw_content = [existing_raw_content] if existing_raw_content else []
|
existing_raw_content = [existing_raw_content] if existing_raw_content else []
|
||||||
except (json.JSONDecodeError, TypeError):
|
except (json.JSONDecodeError, TypeError):
|
||||||
@@ -877,11 +850,9 @@ async def store_jargon_from_answer(jargon_keyword: str, answer: str, chat_id: st
|
|||||||
raw_content=json.dumps([raw_content], ensure_ascii=False),
|
raw_content=json.dumps([raw_content], ensure_ascii=False),
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
is_global=is_global_new,
|
is_global=is_global_new,
|
||||||
count=1
|
count=1,
|
||||||
)
|
)
|
||||||
logger.info(f"创建新jargon记录: {jargon_keyword}")
|
logger.info(f"创建新jargon记录: {jargon_keyword}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"存储jargon失败: {e}")
|
logger.error(f"存储jargon失败: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -116,9 +116,7 @@ class MessageBuilder:
|
|||||||
构建消息对象
|
构建消息对象
|
||||||
:return: Message对象
|
:return: Message对象
|
||||||
"""
|
"""
|
||||||
if len(self.__content) == 0 and not (
|
if len(self.__content) == 0 and not (self.__role == RoleType.Assistant and self.__tool_calls):
|
||||||
self.__role == RoleType.Assistant and self.__tool_calls
|
|
||||||
):
|
|
||||||
raise ValueError("内容不能为空")
|
raise ValueError("内容不能为空")
|
||||||
if self.__role == RoleType.Tool and self.__tool_call_id is None:
|
if self.__role == RoleType.Tool and self.__tool_call_id is None:
|
||||||
raise ValueError("Tool角色的工具调用ID不能为空")
|
raise ValueError("Tool角色的工具调用ID不能为空")
|
||||||
|
|||||||
44
src/main.py
44
src/main.py
@@ -36,25 +36,15 @@ class MainSystem:
|
|||||||
# 使用消息API替代直接的FastAPI实例
|
# 使用消息API替代直接的FastAPI实例
|
||||||
self.app: MessageServer = get_global_api()
|
self.app: MessageServer = get_global_api()
|
||||||
self.server: Server = get_global_server()
|
self.server: Server = get_global_server()
|
||||||
|
self.webui_server = None # 独立的 WebUI 服务器
|
||||||
|
|
||||||
# 注册 WebUI API 路由
|
# 设置独立的 WebUI 服务器
|
||||||
self._register_webui_routes()
|
self._setup_webui_server()
|
||||||
|
|
||||||
# 设置 WebUI(开发/生产模式)
|
def _setup_webui_server(self):
|
||||||
self._setup_webui()
|
"""设置独立的 WebUI 服务器"""
|
||||||
|
|
||||||
def _register_webui_routes(self):
|
|
||||||
"""注册 WebUI API 路由"""
|
|
||||||
try:
|
|
||||||
from src.webui.routes import router as webui_router
|
|
||||||
self.server.register_router(webui_router)
|
|
||||||
logger.info("WebUI API 路由已注册")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"注册 WebUI API 路由失败: {e}")
|
|
||||||
|
|
||||||
def _setup_webui(self):
|
|
||||||
"""设置 WebUI(根据环境变量决定模式)"""
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
webui_enabled = os.getenv("WEBUI_ENABLED", "false").lower() == "true"
|
webui_enabled = os.getenv("WEBUI_ENABLED", "false").lower() == "true"
|
||||||
if not webui_enabled:
|
if not webui_enabled:
|
||||||
logger.info("WebUI 已禁用")
|
logger.info("WebUI 已禁用")
|
||||||
@@ -63,10 +53,22 @@ class MainSystem:
|
|||||||
webui_mode = os.getenv("WEBUI_MODE", "production").lower()
|
webui_mode = os.getenv("WEBUI_MODE", "production").lower()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from src.webui.manager import setup_webui
|
from src.webui.webui_server import get_webui_server
|
||||||
setup_webui(mode=webui_mode)
|
|
||||||
|
self.webui_server = get_webui_server()
|
||||||
|
|
||||||
|
if webui_mode == "development":
|
||||||
|
logger.info("📝 WebUI 开发模式已启用")
|
||||||
|
logger.info("🌐 后端 API 将运行在 http://0.0.0.0:8001")
|
||||||
|
logger.info("💡 请手动启动前端开发服务器: cd MaiBot-Dashboard && bun dev")
|
||||||
|
logger.info("💡 前端将运行在 http://localhost:7999")
|
||||||
|
else:
|
||||||
|
logger.info("✅ WebUI 生产模式已启用")
|
||||||
|
logger.info(f"🌐 WebUI 将运行在 http://0.0.0.0:8001")
|
||||||
|
logger.info("💡 请确保已构建前端: cd MaiBot-Dashboard && bun run build")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"设置 WebUI 失败: {e}")
|
logger.error(f"❌ 初始化 WebUI 服务器失败: {e}")
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
"""初始化系统组件"""
|
"""初始化系统组件"""
|
||||||
@@ -161,6 +163,10 @@ class MainSystem:
|
|||||||
self.server.run(),
|
self.server.run(),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# 如果 WebUI 服务器已初始化,添加到任务列表
|
||||||
|
if self.webui_server:
|
||||||
|
tasks.append(self.webui_server.start())
|
||||||
|
|
||||||
await asyncio.gather(*tasks)
|
await asyncio.gather(*tasks)
|
||||||
|
|
||||||
# async def forget_memory_task(self):
|
# async def forget_memory_task(self):
|
||||||
|
|||||||
@@ -16,8 +16,8 @@ from src.llm_models.payload_content.message import MessageBuilder, RoleType, Mes
|
|||||||
|
|
||||||
logger = get_logger("memory_retrieval")
|
logger = get_logger("memory_retrieval")
|
||||||
|
|
||||||
THINKING_BACK_NOT_FOUND_RETENTION_SECONDS = 3600 # 未找到答案记录保留时长
|
THINKING_BACK_NOT_FOUND_RETENTION_SECONDS = 36000 # 未找到答案记录保留时长
|
||||||
THINKING_BACK_CLEANUP_INTERVAL_SECONDS = 300 # 清理频率
|
THINKING_BACK_CLEANUP_INTERVAL_SECONDS = 3000 # 清理频率
|
||||||
_last_not_found_cleanup_ts: float = 0.0
|
_last_not_found_cleanup_ts: float = 0.0
|
||||||
|
|
||||||
|
|
||||||
@@ -33,10 +33,7 @@ def _cleanup_stale_not_found_thinking_back() -> None:
|
|||||||
try:
|
try:
|
||||||
deleted_rows = (
|
deleted_rows = (
|
||||||
ThinkingBack.delete()
|
ThinkingBack.delete()
|
||||||
.where(
|
.where((ThinkingBack.found_answer == 0) & (ThinkingBack.update_time < threshold_time))
|
||||||
(ThinkingBack.found_answer == 0) &
|
|
||||||
(ThinkingBack.update_time < threshold_time)
|
|
||||||
)
|
|
||||||
.execute()
|
.execute()
|
||||||
)
|
)
|
||||||
if deleted_rows:
|
if deleted_rows:
|
||||||
@@ -45,6 +42,7 @@ def _cleanup_stale_not_found_thinking_back() -> None:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"清理未找到答案的thinking_back记录失败: {e}")
|
logger.error(f"清理未找到答案的thinking_back记录失败: {e}")
|
||||||
|
|
||||||
|
|
||||||
def init_memory_retrieval_prompt():
|
def init_memory_retrieval_prompt():
|
||||||
"""初始化记忆检索相关的 prompt 模板和工具"""
|
"""初始化记忆检索相关的 prompt 模板和工具"""
|
||||||
# 首先注册所有工具
|
# 首先注册所有工具
|
||||||
@@ -221,10 +219,7 @@ def _parse_react_response(response: str) -> Optional[Dict[str, Any]]:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def _retrieve_concepts_with_jargon(
|
async def _retrieve_concepts_with_jargon(concepts: List[str], chat_id: str) -> str:
|
||||||
concepts: List[str],
|
|
||||||
chat_id: str
|
|
||||||
) -> str:
|
|
||||||
"""对概念列表进行jargon检索
|
"""对概念列表进行jargon检索
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -246,25 +241,13 @@ async def _retrieve_concepts_with_jargon(
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
# 先尝试精确匹配
|
# 先尝试精确匹配
|
||||||
jargon_results = search_jargon(
|
jargon_results = search_jargon(keyword=concept, chat_id=chat_id, limit=10, case_sensitive=False, fuzzy=False)
|
||||||
keyword=concept,
|
|
||||||
chat_id=chat_id,
|
|
||||||
limit=10,
|
|
||||||
case_sensitive=False,
|
|
||||||
fuzzy=False
|
|
||||||
)
|
|
||||||
|
|
||||||
is_fuzzy_match = False
|
is_fuzzy_match = False
|
||||||
|
|
||||||
# 如果精确匹配未找到,尝试模糊搜索
|
# 如果精确匹配未找到,尝试模糊搜索
|
||||||
if not jargon_results:
|
if not jargon_results:
|
||||||
jargon_results = search_jargon(
|
jargon_results = search_jargon(keyword=concept, chat_id=chat_id, limit=10, case_sensitive=False, fuzzy=True)
|
||||||
keyword=concept,
|
|
||||||
chat_id=chat_id,
|
|
||||||
limit=10,
|
|
||||||
case_sensitive=False,
|
|
||||||
fuzzy=True
|
|
||||||
)
|
|
||||||
is_fuzzy_match = True
|
is_fuzzy_match = True
|
||||||
|
|
||||||
if jargon_results:
|
if jargon_results:
|
||||||
@@ -298,11 +281,7 @@ async def _retrieve_concepts_with_jargon(
|
|||||||
|
|
||||||
|
|
||||||
async def _react_agent_solve_question(
|
async def _react_agent_solve_question(
|
||||||
question: str,
|
question: str, chat_id: str, max_iterations: int = 5, timeout: float = 30.0, initial_info: str = ""
|
||||||
chat_id: str,
|
|
||||||
max_iterations: int = 5,
|
|
||||||
timeout: float = 30.0,
|
|
||||||
initial_info: str = ""
|
|
||||||
) -> Tuple[bool, str, List[Dict[str, Any]], bool]:
|
) -> Tuple[bool, str, List[Dict[str, Any]], bool]:
|
||||||
"""使用ReAct架构的Agent来解决问题
|
"""使用ReAct架构的Agent来解决问题
|
||||||
|
|
||||||
@@ -343,11 +322,12 @@ async def _react_agent_solve_question(
|
|||||||
remaining_iterations = max_iterations - current_iteration
|
remaining_iterations = max_iterations - current_iteration
|
||||||
is_final_iteration = current_iteration >= max_iterations
|
is_final_iteration = current_iteration >= max_iterations
|
||||||
|
|
||||||
|
|
||||||
if is_final_iteration:
|
if is_final_iteration:
|
||||||
# 最后一次迭代,使用最终prompt
|
# 最后一次迭代,使用最终prompt
|
||||||
tool_definitions = []
|
tool_definitions = []
|
||||||
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代,问题: {question}|可用工具数量: 0(最后一次迭代,不提供工具调用)")
|
logger.info(
|
||||||
|
f"ReAct Agent 第 {iteration + 1} 次迭代,问题: {question}|可用工具数量: 0(最后一次迭代,不提供工具调用)"
|
||||||
|
)
|
||||||
|
|
||||||
prompt = await global_prompt_manager.format_prompt(
|
prompt = await global_prompt_manager.format_prompt(
|
||||||
"memory_retrieval_react_final_prompt",
|
"memory_retrieval_react_final_prompt",
|
||||||
@@ -360,7 +340,8 @@ async def _react_agent_solve_question(
|
|||||||
max_iterations=max_iterations,
|
max_iterations=max_iterations,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"ReAct Agent 第 {iteration + 1} 次Prompt: {prompt}")
|
if global_config.debug.show_memory_prompt:
|
||||||
|
logger.info(f"ReAct Agent 第 {iteration + 1} 次Prompt: {prompt}")
|
||||||
success, response, reasoning_content, model_name, tool_calls = await llm_api.generate_with_model_with_tools(
|
success, response, reasoning_content, model_name, tool_calls = await llm_api.generate_with_model_with_tools(
|
||||||
prompt,
|
prompt,
|
||||||
model_config=model_config.model_task_config.tool_use,
|
model_config=model_config.model_task_config.tool_use,
|
||||||
@@ -370,7 +351,9 @@ async def _react_agent_solve_question(
|
|||||||
else:
|
else:
|
||||||
# 非最终迭代,使用head_prompt
|
# 非最终迭代,使用head_prompt
|
||||||
tool_definitions = tool_registry.get_tool_definitions()
|
tool_definitions = tool_registry.get_tool_definitions()
|
||||||
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代,问题: {question}|可用工具数量: {len(tool_definitions)}")
|
logger.info(
|
||||||
|
f"ReAct Agent 第 {iteration + 1} 次迭代,问题: {question}|可用工具数量: {len(tool_definitions)}"
|
||||||
|
)
|
||||||
|
|
||||||
head_prompt = await global_prompt_manager.format_prompt(
|
head_prompt = await global_prompt_manager.format_prompt(
|
||||||
"memory_retrieval_react_prompt_head",
|
"memory_retrieval_react_prompt_head",
|
||||||
@@ -398,53 +381,62 @@ async def _react_agent_solve_question(
|
|||||||
|
|
||||||
messages.extend(_conversation_messages)
|
messages.extend(_conversation_messages)
|
||||||
|
|
||||||
# 优化日志展示 - 合并所有消息到一条日志
|
if global_config.debug.show_memory_prompt:
|
||||||
log_lines = []
|
# 优化日志展示 - 合并所有消息到一条日志
|
||||||
for idx, msg in enumerate(messages, 1):
|
log_lines = []
|
||||||
role_name = msg.role.value if hasattr(msg.role, 'value') else str(msg.role)
|
for idx, msg in enumerate(messages, 1):
|
||||||
|
role_name = msg.role.value if hasattr(msg.role, "value") else str(msg.role)
|
||||||
|
|
||||||
# 处理内容 - 显示完整内容,不截断
|
# 处理内容 - 显示完整内容,不截断
|
||||||
if isinstance(msg.content, str):
|
if isinstance(msg.content, str):
|
||||||
full_content = msg.content
|
full_content = msg.content
|
||||||
content_type = "文本"
|
content_type = "文本"
|
||||||
elif isinstance(msg.content, list):
|
elif isinstance(msg.content, list):
|
||||||
text_parts = [item for item in msg.content if isinstance(item, str)]
|
text_parts = [item for item in msg.content if isinstance(item, str)]
|
||||||
image_count = len([item for item in msg.content if isinstance(item, tuple)])
|
image_count = len([item for item in msg.content if isinstance(item, tuple)])
|
||||||
full_content = "".join(text_parts) if text_parts else ""
|
full_content = "".join(text_parts) if text_parts else ""
|
||||||
content_type = f"混合({len(text_parts)}段文本, {image_count}张图片)"
|
content_type = f"混合({len(text_parts)}段文本, {image_count}张图片)"
|
||||||
else:
|
else:
|
||||||
full_content = str(msg.content)
|
full_content = str(msg.content)
|
||||||
content_type = "未知"
|
content_type = "未知"
|
||||||
|
|
||||||
# 构建单条消息的日志信息
|
# 构建单条消息的日志信息
|
||||||
msg_info = f"\n[消息 {idx}] 角色: {role_name} 内容类型: {content_type}\n========================================"
|
msg_info = f"\n[消息 {idx}] 角色: {role_name} 内容类型: {content_type}\n========================================"
|
||||||
|
|
||||||
if full_content:
|
if full_content:
|
||||||
msg_info += f"\n{full_content}"
|
msg_info += f"\n{full_content}"
|
||||||
|
|
||||||
if msg.tool_calls:
|
if msg.tool_calls:
|
||||||
msg_info += f"\n 工具调用: {len(msg.tool_calls)}个"
|
msg_info += f"\n 工具调用: {len(msg.tool_calls)}个"
|
||||||
for tool_call in msg.tool_calls:
|
for tool_call in msg.tool_calls:
|
||||||
msg_info += f"\n - {tool_call}"
|
msg_info += f"\n - {tool_call}"
|
||||||
|
|
||||||
if msg.tool_call_id:
|
if msg.tool_call_id:
|
||||||
msg_info += f"\n 工具调用ID: {msg.tool_call_id}"
|
msg_info += f"\n 工具调用ID: {msg.tool_call_id}"
|
||||||
|
|
||||||
log_lines.append(msg_info)
|
log_lines.append(msg_info)
|
||||||
|
|
||||||
# 合并所有消息为一条日志输出
|
# 合并所有消息为一条日志输出
|
||||||
logger.info(f"消息列表 (共{len(messages)}条):{''.join(log_lines)}")
|
logger.info(f"消息列表 (共{len(messages)}条):{''.join(log_lines)}")
|
||||||
|
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
success, response, reasoning_content, model_name, tool_calls = await llm_api.generate_with_model_with_tools_by_message_factory(
|
(
|
||||||
|
success,
|
||||||
|
response,
|
||||||
|
reasoning_content,
|
||||||
|
model_name,
|
||||||
|
tool_calls,
|
||||||
|
) = await llm_api.generate_with_model_with_tools_by_message_factory(
|
||||||
message_factory,
|
message_factory,
|
||||||
model_config=model_config.model_task_config.tool_use,
|
model_config=model_config.model_task_config.tool_use,
|
||||||
tool_options=tool_definitions,
|
tool_options=tool_definitions,
|
||||||
request_type="memory.react",
|
request_type="memory.react",
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 模型: {model_name} ,调用工具数量: {len(tool_calls) if tool_calls else 0} ,调用工具响应: {response}")
|
logger.info(
|
||||||
|
f"ReAct Agent 第 {iteration + 1} 次迭代 模型: {model_name} ,调用工具数量: {len(tool_calls) if tool_calls else 0} ,调用工具响应: {response}"
|
||||||
|
)
|
||||||
|
|
||||||
if not success:
|
if not success:
|
||||||
logger.error(f"ReAct Agent LLM调用失败: {response}")
|
logger.error(f"ReAct Agent LLM调用失败: {response}")
|
||||||
@@ -465,12 +457,7 @@ async def _react_agent_solve_question(
|
|||||||
assistant_message = assistant_builder.build()
|
assistant_message = assistant_builder.build()
|
||||||
|
|
||||||
# 记录思考步骤
|
# 记录思考步骤
|
||||||
step = {
|
step = {"iteration": iteration + 1, "thought": response, "actions": [], "observations": []}
|
||||||
"iteration": iteration + 1,
|
|
||||||
"thought": response,
|
|
||||||
"actions": [],
|
|
||||||
"observations": []
|
|
||||||
}
|
|
||||||
|
|
||||||
# 优先从思考内容中提取found_answer或not_enough_info
|
# 优先从思考内容中提取found_answer或not_enough_info
|
||||||
def extract_quoted_content(text, func_name, param_name):
|
def extract_quoted_content(text, func_name, param_name):
|
||||||
@@ -495,14 +482,14 @@ async def _react_agent_solve_question(
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
# 查找参数名和等号
|
# 查找参数名和等号
|
||||||
param_pattern = f'{param_name}='
|
param_pattern = f"{param_name}="
|
||||||
param_pos = text_lower.find(param_pattern, func_pos)
|
param_pos = text_lower.find(param_pattern, func_pos)
|
||||||
if param_pos == -1:
|
if param_pos == -1:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 跳过参数名、等号和空白
|
# 跳过参数名、等号和空白
|
||||||
start_pos = param_pos + len(param_pattern)
|
start_pos = param_pos + len(param_pattern)
|
||||||
while start_pos < len(text) and text[start_pos] in ' \t\n':
|
while start_pos < len(text) and text[start_pos] in " \t\n":
|
||||||
start_pos += 1
|
start_pos += 1
|
||||||
|
|
||||||
if start_pos >= len(text):
|
if start_pos >= len(text):
|
||||||
@@ -518,13 +505,13 @@ async def _react_agent_solve_question(
|
|||||||
while end_pos < len(text):
|
while end_pos < len(text):
|
||||||
if text[end_pos] == quote_char:
|
if text[end_pos] == quote_char:
|
||||||
# 检查是否是转义的引号
|
# 检查是否是转义的引号
|
||||||
if end_pos > start_pos + 1 and text[end_pos - 1] == '\\':
|
if end_pos > start_pos + 1 and text[end_pos - 1] == "\\":
|
||||||
end_pos += 1
|
end_pos += 1
|
||||||
continue
|
continue
|
||||||
# 找到匹配的引号
|
# 找到匹配的引号
|
||||||
content = text[start_pos + 1:end_pos]
|
content = text[start_pos + 1 : end_pos]
|
||||||
# 处理转义字符
|
# 处理转义字符
|
||||||
content = content.replace('\\"', '"').replace("\\'", "'").replace('\\\\', '\\')
|
content = content.replace('\\"', '"').replace("\\'", "'").replace("\\\\", "\\")
|
||||||
return content
|
return content
|
||||||
end_pos += 1
|
end_pos += 1
|
||||||
|
|
||||||
@@ -536,27 +523,35 @@ async def _react_agent_solve_question(
|
|||||||
|
|
||||||
# 只检查response(LLM的直接输出内容),不检查reasoning_content
|
# 只检查response(LLM的直接输出内容),不检查reasoning_content
|
||||||
if response:
|
if response:
|
||||||
found_answer_content = extract_quoted_content(response, 'found_answer', 'answer')
|
found_answer_content = extract_quoted_content(response, "found_answer", "answer")
|
||||||
if not found_answer_content:
|
if not found_answer_content:
|
||||||
not_enough_info_reason = extract_quoted_content(response, 'not_enough_info', 'reason')
|
not_enough_info_reason = extract_quoted_content(response, "not_enough_info", "reason")
|
||||||
|
|
||||||
# 如果从输出内容中找到了答案,直接返回
|
# 如果从输出内容中找到了答案,直接返回
|
||||||
if found_answer_content:
|
if found_answer_content:
|
||||||
step["actions"].append({"action_type": "found_answer", "action_params": {"answer": found_answer_content}})
|
step["actions"].append({"action_type": "found_answer", "action_params": {"answer": found_answer_content}})
|
||||||
step["observations"] = ["从LLM输出内容中检测到found_answer"]
|
step["observations"] = ["从LLM输出内容中检测到found_answer"]
|
||||||
thinking_steps.append(step)
|
thinking_steps.append(step)
|
||||||
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 从LLM输出内容中检测到found_answer: {found_answer_content[:100]}...")
|
logger.info(
|
||||||
|
f"ReAct Agent 第 {iteration + 1} 次迭代 从LLM输出内容中检测到found_answer: {found_answer_content[:100]}..."
|
||||||
|
)
|
||||||
return True, found_answer_content, thinking_steps, False
|
return True, found_answer_content, thinking_steps, False
|
||||||
|
|
||||||
if not_enough_info_reason:
|
if not_enough_info_reason:
|
||||||
step["actions"].append({"action_type": "not_enough_info", "action_params": {"reason": not_enough_info_reason}})
|
step["actions"].append(
|
||||||
|
{"action_type": "not_enough_info", "action_params": {"reason": not_enough_info_reason}}
|
||||||
|
)
|
||||||
step["observations"] = ["从LLM输出内容中检测到not_enough_info"]
|
step["observations"] = ["从LLM输出内容中检测到not_enough_info"]
|
||||||
thinking_steps.append(step)
|
thinking_steps.append(step)
|
||||||
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 从LLM输出内容中检测到not_enough_info: {not_enough_info_reason[:100]}...")
|
logger.info(
|
||||||
|
f"ReAct Agent 第 {iteration + 1} 次迭代 从LLM输出内容中检测到not_enough_info: {not_enough_info_reason[:100]}..."
|
||||||
|
)
|
||||||
return False, not_enough_info_reason, thinking_steps, False
|
return False, not_enough_info_reason, thinking_steps, False
|
||||||
|
|
||||||
if is_final_iteration:
|
if is_final_iteration:
|
||||||
step["actions"].append({"action_type": "not_enough_info", "action_params": {"reason": "已到达最后一次迭代,无法找到答案"}})
|
step["actions"].append(
|
||||||
|
{"action_type": "not_enough_info", "action_params": {"reason": "已到达最后一次迭代,无法找到答案"}}
|
||||||
|
)
|
||||||
step["observations"] = ["已到达最后一次迭代,无法找到答案"]
|
step["observations"] = ["已到达最后一次迭代,无法找到答案"]
|
||||||
thinking_steps.append(step)
|
thinking_steps.append(step)
|
||||||
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 已到达最后一次迭代,无法找到答案")
|
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 已到达最后一次迭代,无法找到答案")
|
||||||
@@ -596,7 +591,9 @@ async def _react_agent_solve_question(
|
|||||||
tool_name = tool_call.func_name
|
tool_name = tool_call.func_name
|
||||||
tool_args = tool_call.args or {}
|
tool_args = tool_call.args or {}
|
||||||
|
|
||||||
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 工具调用 {i+1}/{len(tool_calls)}: {tool_name}({tool_args})")
|
logger.info(
|
||||||
|
f"ReAct Agent 第 {iteration + 1} 次迭代 工具调用 {i + 1}/{len(tool_calls)}: {tool_name}({tool_args})"
|
||||||
|
)
|
||||||
|
|
||||||
# 普通工具调用
|
# 普通工具调用
|
||||||
tool = tool_registry.get_tool(tool_name)
|
tool = tool_registry.get_tool(tool_name)
|
||||||
@@ -606,6 +603,7 @@ async def _react_agent_solve_question(
|
|||||||
|
|
||||||
# 如果工具函数签名需要chat_id,添加它
|
# 如果工具函数签名需要chat_id,添加它
|
||||||
import inspect
|
import inspect
|
||||||
|
|
||||||
sig = inspect.signature(tool.execute_func)
|
sig = inspect.signature(tool.execute_func)
|
||||||
if "chat_id" in sig.parameters:
|
if "chat_id" in sig.parameters:
|
||||||
tool_params["chat_id"] = chat_id
|
tool_params["chat_id"] = chat_id
|
||||||
@@ -625,7 +623,7 @@ async def _react_agent_solve_question(
|
|||||||
step["actions"].append({"action_type": tool_name, "action_params": tool_args})
|
step["actions"].append({"action_type": tool_name, "action_params": tool_args})
|
||||||
else:
|
else:
|
||||||
error_msg = f"未知的工具类型: {tool_name}"
|
error_msg = f"未知的工具类型: {tool_name}"
|
||||||
logger.warning(f"ReAct Agent 第 {iteration + 1} 次迭代 工具 {i+1}/{len(tool_calls)} {error_msg}")
|
logger.warning(f"ReAct Agent 第 {iteration + 1} 次迭代 工具 {i + 1}/{len(tool_calls)} {error_msg}")
|
||||||
tool_tasks.append(asyncio.create_task(asyncio.sleep(0, result=f"查询{tool_name}失败: {error_msg}")))
|
tool_tasks.append(asyncio.create_task(asyncio.sleep(0, result=f"查询{tool_name}失败: {error_msg}")))
|
||||||
|
|
||||||
# 并行执行所有工具
|
# 并行执行所有工具
|
||||||
@@ -636,7 +634,7 @@ async def _react_agent_solve_question(
|
|||||||
for i, (tool_call_item, observation) in enumerate(zip(tool_calls, observations, strict=False)):
|
for i, (tool_call_item, observation) in enumerate(zip(tool_calls, observations, strict=False)):
|
||||||
if isinstance(observation, Exception):
|
if isinstance(observation, Exception):
|
||||||
observation = f"工具执行异常: {str(observation)}"
|
observation = f"工具执行异常: {str(observation)}"
|
||||||
logger.error(f"ReAct Agent 第 {iteration + 1} 次迭代 工具 {i+1} 执行异常: {observation}")
|
logger.error(f"ReAct Agent 第 {iteration + 1} 次迭代 工具 {i + 1} 执行异常: {observation}")
|
||||||
|
|
||||||
observation_text = observation if isinstance(observation, str) else str(observation)
|
observation_text = observation if isinstance(observation, str) else str(observation)
|
||||||
step["observations"].append(observation_text)
|
step["observations"].append(observation_text)
|
||||||
@@ -655,7 +653,9 @@ async def _react_agent_solve_question(
|
|||||||
# 迭代超时应该直接视为not_enough_info,而不是使用已有信息
|
# 迭代超时应该直接视为not_enough_info,而不是使用已有信息
|
||||||
# 只有Agent明确返回found_answer时,才认为找到了答案
|
# 只有Agent明确返回found_answer时,才认为找到了答案
|
||||||
if collected_info:
|
if collected_info:
|
||||||
logger.warning(f"ReAct Agent达到最大迭代次数或超时,但未明确返回found_answer。已收集信息: {collected_info[:100]}...")
|
logger.warning(
|
||||||
|
f"ReAct Agent达到最大迭代次数或超时,但未明确返回found_answer。已收集信息: {collected_info[:100]}..."
|
||||||
|
)
|
||||||
if is_timeout:
|
if is_timeout:
|
||||||
logger.warning("ReAct Agent超时,直接视为not_enough_info")
|
logger.warning("ReAct Agent超时,直接视为not_enough_info")
|
||||||
else:
|
else:
|
||||||
@@ -680,10 +680,7 @@ def _get_recent_query_history(chat_id: str, time_window_seconds: float = 300.0)
|
|||||||
# 查询最近时间窗口内的记录,按更新时间倒序
|
# 查询最近时间窗口内的记录,按更新时间倒序
|
||||||
records = (
|
records = (
|
||||||
ThinkingBack.select()
|
ThinkingBack.select()
|
||||||
.where(
|
.where((ThinkingBack.chat_id == chat_id) & (ThinkingBack.update_time >= start_time))
|
||||||
(ThinkingBack.chat_id == chat_id) &
|
|
||||||
(ThinkingBack.update_time >= start_time)
|
|
||||||
)
|
|
||||||
.order_by(ThinkingBack.update_time.desc())
|
.order_by(ThinkingBack.update_time.desc())
|
||||||
.limit(5) # 最多返回5条最近的记录
|
.limit(5) # 最多返回5条最近的记录
|
||||||
)
|
)
|
||||||
@@ -735,9 +732,9 @@ def _get_cached_memories(chat_id: str, time_window_seconds: float = 300.0) -> Li
|
|||||||
records = (
|
records = (
|
||||||
ThinkingBack.select()
|
ThinkingBack.select()
|
||||||
.where(
|
.where(
|
||||||
(ThinkingBack.chat_id == chat_id) &
|
(ThinkingBack.chat_id == chat_id)
|
||||||
(ThinkingBack.update_time >= start_time) &
|
& (ThinkingBack.update_time >= start_time)
|
||||||
(ThinkingBack.found_answer == 1)
|
& (ThinkingBack.found_answer == 1)
|
||||||
)
|
)
|
||||||
.order_by(ThinkingBack.update_time.desc())
|
.order_by(ThinkingBack.update_time.desc())
|
||||||
.limit(5) # 最多返回5条最近的记录
|
.limit(5) # 最多返回5条最近的记录
|
||||||
@@ -775,10 +772,7 @@ def _query_thinking_back(chat_id: str, question: str) -> Optional[Tuple[bool, st
|
|||||||
# 按更新时间倒序,获取最新的记录
|
# 按更新时间倒序,获取最新的记录
|
||||||
records = (
|
records = (
|
||||||
ThinkingBack.select()
|
ThinkingBack.select()
|
||||||
.where(
|
.where((ThinkingBack.chat_id == chat_id) & (ThinkingBack.question == question))
|
||||||
(ThinkingBack.chat_id == chat_id) &
|
|
||||||
(ThinkingBack.question == question)
|
|
||||||
)
|
|
||||||
.order_by(ThinkingBack.update_time.desc())
|
.order_by(ThinkingBack.update_time.desc())
|
||||||
.limit(1)
|
.limit(1)
|
||||||
)
|
)
|
||||||
@@ -857,6 +851,7 @@ async def _analyze_question_answer(question: str, answer: str, chat_id: str) ->
|
|||||||
jargon_keyword = analysis_result.get("jargon_keyword", "").strip()
|
jargon_keyword = analysis_result.get("jargon_keyword", "").strip()
|
||||||
if jargon_keyword:
|
if jargon_keyword:
|
||||||
from src.jargon.jargon_miner import store_jargon_from_answer
|
from src.jargon.jargon_miner import store_jargon_from_answer
|
||||||
|
|
||||||
await store_jargon_from_answer(jargon_keyword, answer, chat_id)
|
await store_jargon_from_answer(jargon_keyword, answer, chat_id)
|
||||||
else:
|
else:
|
||||||
logger.warning(f"分析为黑话但未提取到关键词,问题: {question[:50]}...")
|
logger.warning(f"分析为黑话但未提取到关键词,问题: {question[:50]}...")
|
||||||
@@ -882,14 +877,8 @@ async def _analyze_question_answer(question: str, answer: str, chat_id: str) ->
|
|||||||
logger.error(f"分析问题和答案时发生异常: {e}")
|
logger.error(f"分析问题和答案时发生异常: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _store_thinking_back(
|
def _store_thinking_back(
|
||||||
chat_id: str,
|
chat_id: str, question: str, context: str, found_answer: bool, answer: str, thinking_steps: List[Dict[str, Any]]
|
||||||
question: str,
|
|
||||||
context: str,
|
|
||||||
found_answer: bool,
|
|
||||||
answer: str,
|
|
||||||
thinking_steps: List[Dict[str, Any]]
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""存储或更新思考过程到数据库(如果已存在则更新,否则创建)
|
"""存储或更新思考过程到数据库(如果已存在则更新,否则创建)
|
||||||
|
|
||||||
@@ -907,10 +896,7 @@ def _store_thinking_back(
|
|||||||
# 先查询是否已存在相同chat_id和问题的记录
|
# 先查询是否已存在相同chat_id和问题的记录
|
||||||
existing = (
|
existing = (
|
||||||
ThinkingBack.select()
|
ThinkingBack.select()
|
||||||
.where(
|
.where((ThinkingBack.chat_id == chat_id) & (ThinkingBack.question == question))
|
||||||
(ThinkingBack.chat_id == chat_id) &
|
|
||||||
(ThinkingBack.question == question)
|
|
||||||
)
|
|
||||||
.order_by(ThinkingBack.update_time.desc())
|
.order_by(ThinkingBack.update_time.desc())
|
||||||
.limit(1)
|
.limit(1)
|
||||||
)
|
)
|
||||||
@@ -935,19 +921,14 @@ def _store_thinking_back(
|
|||||||
answer=answer,
|
answer=answer,
|
||||||
thinking_steps=json.dumps(thinking_steps, ensure_ascii=False),
|
thinking_steps=json.dumps(thinking_steps, ensure_ascii=False),
|
||||||
create_time=now,
|
create_time=now,
|
||||||
update_time=now
|
update_time=now,
|
||||||
)
|
)
|
||||||
logger.info(f"已创建思考过程到数据库,问题: {question[:50]}...")
|
logger.info(f"已创建思考过程到数据库,问题: {question[:50]}...")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"存储思考过程失败: {e}")
|
logger.error(f"存储思考过程失败: {e}")
|
||||||
|
|
||||||
|
|
||||||
async def _process_single_question(
|
async def _process_single_question(question: str, chat_id: str, context: str, initial_info: str = "") -> Optional[str]:
|
||||||
question: str,
|
|
||||||
chat_id: str,
|
|
||||||
context: str,
|
|
||||||
initial_info: str = ""
|
|
||||||
) -> Optional[str]:
|
|
||||||
"""处理单个问题的查询(包含缓存检查逻辑)
|
"""处理单个问题的查询(包含缓存检查逻辑)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -1015,7 +996,7 @@ async def _process_single_question(
|
|||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
max_iterations=global_config.memory.max_agent_iterations,
|
max_iterations=global_config.memory.max_agent_iterations,
|
||||||
timeout=120.0,
|
timeout=120.0,
|
||||||
initial_info=question_initial_info
|
initial_info=question_initial_info,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 存储到数据库(超时时不存储)
|
# 存储到数据库(超时时不存储)
|
||||||
@@ -1026,7 +1007,7 @@ async def _process_single_question(
|
|||||||
context=context,
|
context=context,
|
||||||
found_answer=found_answer,
|
found_answer=found_answer,
|
||||||
answer=answer,
|
answer=answer,
|
||||||
thinking_steps=thinking_steps
|
thinking_steps=thinking_steps,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.info(f"ReAct Agent超时,不存储到数据库,问题: {question[:50]}...")
|
logger.info(f"ReAct Agent超时,不存储到数据库,问题: {question[:50]}...")
|
||||||
@@ -1089,7 +1070,8 @@ async def build_memory_retrieval_prompt(
|
|||||||
request_type="memory.question",
|
request_type="memory.question",
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"记忆检索问题生成提示词: {question_prompt}")
|
if global_config.debug.show_memory_prompt:
|
||||||
|
logger.info(f"记忆检索问题生成提示词: {question_prompt}")
|
||||||
logger.info(f"记忆检索问题生成响应: {response}")
|
logger.info(f"记忆检索问题生成响应: {response}")
|
||||||
|
|
||||||
if not success:
|
if not success:
|
||||||
@@ -1112,7 +1094,6 @@ async def build_memory_retrieval_prompt(
|
|||||||
else:
|
else:
|
||||||
logger.info("概念检索未找到任何结果")
|
logger.info("概念检索未找到任何结果")
|
||||||
|
|
||||||
|
|
||||||
# 获取缓存的记忆(与question时使用相同的时间窗口和数量限制)
|
# 获取缓存的记忆(与question时使用相同的时间窗口和数量限制)
|
||||||
cached_memories = _get_cached_memories(chat_id, time_window_seconds=300.0)
|
cached_memories = _get_cached_memories(chat_id, time_window_seconds=300.0)
|
||||||
|
|
||||||
@@ -1141,12 +1122,7 @@ async def build_memory_retrieval_prompt(
|
|||||||
|
|
||||||
# 并行处理所有问题,将概念检索结果作为初始信息传递
|
# 并行处理所有问题,将概念检索结果作为初始信息传递
|
||||||
question_tasks = [
|
question_tasks = [
|
||||||
_process_single_question(
|
_process_single_question(question=question, chat_id=chat_id, context=message, initial_info=initial_info)
|
||||||
question=question,
|
|
||||||
chat_id=chat_id,
|
|
||||||
context=message,
|
|
||||||
initial_info=initial_info
|
|
||||||
)
|
|
||||||
for question in questions
|
for question in questions
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -1179,7 +1155,9 @@ async def build_memory_retrieval_prompt(
|
|||||||
|
|
||||||
if all_results:
|
if all_results:
|
||||||
retrieved_memory = "\n\n".join(all_results)
|
retrieved_memory = "\n\n".join(all_results)
|
||||||
logger.info(f"记忆检索成功,耗时: {(end_time - start_time):.3f}秒,包含 {len(all_results)} 条记忆(含缓存)")
|
logger.info(
|
||||||
|
f"记忆检索成功,耗时: {(end_time - start_time):.3f}秒,包含 {len(all_results)} 条记忆(含缓存)"
|
||||||
|
)
|
||||||
return f"你回忆起了以下信息:\n{retrieved_memory}\n如果与回复内容相关,可以参考这些回忆的信息。\n"
|
return f"你回忆起了以下信息:\n{retrieved_memory}\n如果与回复内容相关,可以参考这些回忆的信息。\n"
|
||||||
else:
|
else:
|
||||||
logger.debug("所有问题均未找到答案,且无缓存记忆")
|
logger.debug("所有问题均未找到答案,且无缓存记忆")
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
记忆系统工具函数
|
记忆系统工具函数
|
||||||
包含模糊查找、相似度计算等工具函数
|
包含模糊查找、相似度计算等工具函数
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
@@ -14,6 +15,7 @@ from src.common.logger import get_logger
|
|||||||
|
|
||||||
logger = get_logger("memory_utils")
|
logger = get_logger("memory_utils")
|
||||||
|
|
||||||
|
|
||||||
def parse_md_json(json_text: str) -> list[str]:
|
def parse_md_json(json_text: str) -> list[str]:
|
||||||
"""从Markdown格式的内容中提取JSON对象和推理内容"""
|
"""从Markdown格式的内容中提取JSON对象和推理内容"""
|
||||||
json_objects = []
|
json_objects = []
|
||||||
@@ -52,6 +54,7 @@ def parse_md_json(json_text: str) -> list[str]:
|
|||||||
|
|
||||||
return json_objects, reasoning_content
|
return json_objects, reasoning_content
|
||||||
|
|
||||||
|
|
||||||
def calculate_similarity(text1: str, text2: str) -> float:
|
def calculate_similarity(text1: str, text2: str) -> float:
|
||||||
"""
|
"""
|
||||||
计算两个文本的相似度
|
计算两个文本的相似度
|
||||||
@@ -97,10 +100,10 @@ def preprocess_text(text: str) -> str:
|
|||||||
text = text.lower()
|
text = text.lower()
|
||||||
|
|
||||||
# 移除标点符号和特殊字符
|
# 移除标点符号和特殊字符
|
||||||
text = re.sub(r'[^\w\s]', '', text)
|
text = re.sub(r"[^\w\s]", "", text)
|
||||||
|
|
||||||
# 移除多余空格
|
# 移除多余空格
|
||||||
text = re.sub(r'\s+', ' ', text).strip()
|
text = re.sub(r"\s+", " ", text).strip()
|
||||||
|
|
||||||
return text
|
return text
|
||||||
|
|
||||||
@@ -109,7 +112,6 @@ def preprocess_text(text: str) -> str:
|
|||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def parse_datetime_to_timestamp(value: str) -> float:
|
def parse_datetime_to_timestamp(value: str) -> float:
|
||||||
"""
|
"""
|
||||||
接受多种常见格式并转换为时间戳(秒)
|
接受多种常见格式并转换为时间戳(秒)
|
||||||
@@ -164,4 +166,3 @@ def parse_time_range(time_range: str) -> Tuple[float, float]:
|
|||||||
end_timestamp = parse_datetime_to_timestamp(end_str)
|
end_timestamp = parse_datetime_to_timestamp(end_str)
|
||||||
|
|
||||||
return start_timestamp, end_timestamp
|
return start_timestamp, end_timestamp
|
||||||
|
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ from .query_lpmm_knowledge import register_tool as register_lpmm_knowledge
|
|||||||
from .query_person_info import register_tool as register_query_person_info
|
from .query_person_info import register_tool as register_query_person_info
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
|
||||||
|
|
||||||
def init_all_tools():
|
def init_all_tools():
|
||||||
"""初始化并注册所有记忆检索工具"""
|
"""初始化并注册所有记忆检索工具"""
|
||||||
register_query_jargon()
|
register_query_jargon()
|
||||||
|
|||||||
@@ -15,10 +15,7 @@ logger = get_logger("memory_retrieval_tools")
|
|||||||
|
|
||||||
|
|
||||||
async def query_chat_history(
|
async def query_chat_history(
|
||||||
chat_id: str,
|
chat_id: str, keyword: Optional[str] = None, time_range: Optional[str] = None, fuzzy: bool = True
|
||||||
keyword: Optional[str] = None,
|
|
||||||
time_range: Optional[str] = None,
|
|
||||||
fuzzy: bool = True
|
|
||||||
) -> str:
|
) -> str:
|
||||||
"""根据时间或关键词在chat_history表中查询聊天记录概述
|
"""根据时间或关键词在chat_history表中查询聊天记录概述
|
||||||
|
|
||||||
@@ -50,17 +47,11 @@ async def query_chat_history(
|
|||||||
# 时间范围:查询与时间范围有交集的记录
|
# 时间范围:查询与时间范围有交集的记录
|
||||||
start_timestamp, end_timestamp = parse_time_range(time_range)
|
start_timestamp, end_timestamp = parse_time_range(time_range)
|
||||||
# 交集条件:start_time < end_timestamp AND end_time > start_timestamp
|
# 交集条件:start_time < end_timestamp AND end_time > start_timestamp
|
||||||
time_filter = (
|
time_filter = (ChatHistory.start_time < end_timestamp) & (ChatHistory.end_time > start_timestamp)
|
||||||
(ChatHistory.start_time < end_timestamp) &
|
|
||||||
(ChatHistory.end_time > start_timestamp)
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
# 时间点:查询包含该时间点的记录(start_time <= time_point <= end_time)
|
# 时间点:查询包含该时间点的记录(start_time <= time_point <= end_time)
|
||||||
target_timestamp = parse_datetime_to_timestamp(time_range)
|
target_timestamp = parse_datetime_to_timestamp(time_range)
|
||||||
time_filter = (
|
time_filter = (ChatHistory.start_time <= target_timestamp) & (ChatHistory.end_time >= target_timestamp)
|
||||||
(ChatHistory.start_time <= target_timestamp) &
|
|
||||||
(ChatHistory.end_time >= target_timestamp)
|
|
||||||
)
|
|
||||||
query = query.where(time_filter)
|
query = query.where(time_filter)
|
||||||
|
|
||||||
# 执行查询
|
# 执行查询
|
||||||
@@ -91,7 +82,9 @@ async def query_chat_history(
|
|||||||
record_keywords_list = []
|
record_keywords_list = []
|
||||||
if record.keywords:
|
if record.keywords:
|
||||||
try:
|
try:
|
||||||
keywords_data = json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
|
keywords_data = (
|
||||||
|
json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
|
||||||
|
)
|
||||||
if isinstance(keywords_data, list):
|
if isinstance(keywords_data, list):
|
||||||
record_keywords_list = [str(k).lower() for k in keywords_data]
|
record_keywords_list = [str(k).lower() for k in keywords_data]
|
||||||
except (json.JSONDecodeError, TypeError, ValueError):
|
except (json.JSONDecodeError, TypeError, ValueError):
|
||||||
@@ -102,20 +95,24 @@ async def query_chat_history(
|
|||||||
if fuzzy:
|
if fuzzy:
|
||||||
# 模糊匹配:只要包含任意一个关键词即匹配(OR关系)
|
# 模糊匹配:只要包含任意一个关键词即匹配(OR关系)
|
||||||
for kw in keywords_lower:
|
for kw in keywords_lower:
|
||||||
if (kw in theme or
|
if (
|
||||||
kw in summary or
|
kw in theme
|
||||||
kw in original_text or
|
or kw in summary
|
||||||
any(kw in k for k in record_keywords_list)):
|
or kw in original_text
|
||||||
|
or any(kw in k for k in record_keywords_list)
|
||||||
|
):
|
||||||
matched = True
|
matched = True
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
# 全匹配:必须包含所有关键词才匹配(AND关系)
|
# 全匹配:必须包含所有关键词才匹配(AND关系)
|
||||||
matched = True
|
matched = True
|
||||||
for kw in keywords_lower:
|
for kw in keywords_lower:
|
||||||
kw_matched = (kw in theme or
|
kw_matched = (
|
||||||
kw in summary or
|
kw in theme
|
||||||
kw in original_text or
|
or kw in summary
|
||||||
any(kw in k for k in record_keywords_list))
|
or kw in original_text
|
||||||
|
or any(kw in k for k in record_keywords_list)
|
||||||
|
)
|
||||||
if not kw_matched:
|
if not kw_matched:
|
||||||
matched = False
|
matched = False
|
||||||
break
|
break
|
||||||
@@ -160,6 +157,7 @@ async def query_chat_history(
|
|||||||
|
|
||||||
# 添加时间范围
|
# 添加时间范围
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
start_str = datetime.fromtimestamp(record.start_time).strftime("%Y-%m-%d %H:%M:%S")
|
start_str = datetime.fromtimestamp(record.start_time).strftime("%Y-%m-%d %H:%M:%S")
|
||||||
end_str = datetime.fromtimestamp(record.end_time).strftime("%Y-%m-%d %H:%M:%S")
|
end_str = datetime.fromtimestamp(record.end_time).strftime("%Y-%m-%d %H:%M:%S")
|
||||||
result_parts.append(f"时间:{start_str} - {end_str}")
|
result_parts.append(f"时间:{start_str} - {end_str}")
|
||||||
@@ -199,20 +197,20 @@ def register_tool():
|
|||||||
"name": "keyword",
|
"name": "keyword",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "关键词(可选,支持多个关键词,可用空格、逗号、斜杠等分隔,如:'麦麦 百度网盘' 或 '麦麦,百度网盘'。用于在主题、关键词、概括、原文中搜索)",
|
"description": "关键词(可选,支持多个关键词,可用空格、逗号、斜杠等分隔,如:'麦麦 百度网盘' 或 '麦麦,百度网盘'。用于在主题、关键词、概括、原文中搜索)",
|
||||||
"required": False
|
"required": False,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "time_range",
|
"name": "time_range",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "时间范围或时间点(可选)。格式:'YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS'(时间范围,查询与时间范围有交集的记录)或 'YYYY-MM-DD HH:MM:SS'(时间点,查询包含该时间点的记录)",
|
"description": "时间范围或时间点(可选)。格式:'YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS'(时间范围,查询与时间范围有交集的记录)或 'YYYY-MM-DD HH:MM:SS'(时间点,查询包含该时间点的记录)",
|
||||||
"required": False
|
"required": False,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "fuzzy",
|
"name": "fuzzy",
|
||||||
"type": "boolean",
|
"type": "boolean",
|
||||||
"description": "是否使用模糊匹配模式(默认True)。True表示模糊匹配(只要包含任意一个关键词即匹配,OR关系),False表示全匹配(必须包含所有关键词才匹配,AND关系)",
|
"description": "是否使用模糊匹配模式(默认True)。True表示模糊匹配(只要包含任意一个关键词即匹配,OR关系),False表示全匹配(必须包含所有关键词才匹配,AND关系)",
|
||||||
"required": False
|
"required": False,
|
||||||
}
|
},
|
||||||
],
|
],
|
||||||
execute_func=query_chat_history
|
execute_func=query_chat_history,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -73,5 +73,3 @@ def register_tool():
|
|||||||
],
|
],
|
||||||
execute_func=query_lpmm_knowledge,
|
execute_func=query_lpmm_knowledge,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -26,7 +26,9 @@ def _format_group_nick_names(group_nick_name_field) -> str:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# 解析JSON格式的群昵称列表
|
# 解析JSON格式的群昵称列表
|
||||||
group_nick_names_data = json.loads(group_nick_name_field) if isinstance(group_nick_name_field, str) else group_nick_name_field
|
group_nick_names_data = (
|
||||||
|
json.loads(group_nick_name_field) if isinstance(group_nick_name_field, str) else group_nick_name_field
|
||||||
|
)
|
||||||
|
|
||||||
if not isinstance(group_nick_names_data, list) or not group_nick_names_data:
|
if not isinstance(group_nick_names_data, list) or not group_nick_names_data:
|
||||||
return ""
|
return ""
|
||||||
@@ -71,9 +73,7 @@ async def query_person_info(person_name: str) -> str:
|
|||||||
return "用户名称为空"
|
return "用户名称为空"
|
||||||
|
|
||||||
# 构建查询条件(使用模糊查询)
|
# 构建查询条件(使用模糊查询)
|
||||||
query = PersonInfo.select().where(
|
query = PersonInfo.select().where(PersonInfo.person_name.contains(person_name))
|
||||||
PersonInfo.person_name.contains(person_name)
|
|
||||||
)
|
|
||||||
|
|
||||||
# 执行查询
|
# 执行查询
|
||||||
records = list(query.limit(20)) # 最多返回20条记录
|
records = list(query.limit(20)) # 最多返回20条记录
|
||||||
@@ -137,7 +137,11 @@ async def query_person_info(person_name: str) -> str:
|
|||||||
# 记忆点(memory_points)
|
# 记忆点(memory_points)
|
||||||
if record.memory_points:
|
if record.memory_points:
|
||||||
try:
|
try:
|
||||||
memory_points_data = json.loads(record.memory_points) if isinstance(record.memory_points, str) else record.memory_points
|
memory_points_data = (
|
||||||
|
json.loads(record.memory_points)
|
||||||
|
if isinstance(record.memory_points, str)
|
||||||
|
else record.memory_points
|
||||||
|
)
|
||||||
if isinstance(memory_points_data, list) and memory_points_data:
|
if isinstance(memory_points_data, list) and memory_points_data:
|
||||||
# 解析记忆点格式:category:content:weight
|
# 解析记忆点格式:category:content:weight
|
||||||
memory_list = []
|
memory_list = []
|
||||||
@@ -206,7 +210,11 @@ async def query_person_info(person_name: str) -> str:
|
|||||||
# 记忆点(memory_points)
|
# 记忆点(memory_points)
|
||||||
if record.memory_points:
|
if record.memory_points:
|
||||||
try:
|
try:
|
||||||
memory_points_data = json.loads(record.memory_points) if isinstance(record.memory_points, str) else record.memory_points
|
memory_points_data = (
|
||||||
|
json.loads(record.memory_points)
|
||||||
|
if isinstance(record.memory_points, str)
|
||||||
|
else record.memory_points
|
||||||
|
)
|
||||||
if isinstance(memory_points_data, list) and memory_points_data:
|
if isinstance(memory_points_data, list) and memory_points_data:
|
||||||
# 解析记忆点格式:category:content:weight
|
# 解析记忆点格式:category:content:weight
|
||||||
memory_list = []
|
memory_list = []
|
||||||
@@ -275,13 +283,7 @@ def register_tool():
|
|||||||
name="query_person_info",
|
name="query_person_info",
|
||||||
description="根据查询某个用户的所有信息。名称、昵称、平台、用户ID、qq号、群昵称等",
|
description="根据查询某个用户的所有信息。名称、昵称、平台、用户ID、qq号、群昵称等",
|
||||||
parameters=[
|
parameters=[
|
||||||
{
|
{"name": "person_name", "type": "string", "description": "用户名称,用于查询用户信息", "required": True}
|
||||||
"name": "person_name",
|
|
||||||
"type": "string",
|
|
||||||
"description": "用户名称,用于查询用户信息",
|
|
||||||
"required": True
|
|
||||||
}
|
|
||||||
],
|
],
|
||||||
execute_func=query_person_info
|
execute_func=query_person_info,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -82,11 +82,7 @@ class MemoryRetrievalTool:
|
|||||||
param_tuples.append(param_tuple)
|
param_tuples.append(param_tuple)
|
||||||
|
|
||||||
# 构建工具定义,格式与BaseTool.get_tool_definition()一致
|
# 构建工具定义,格式与BaseTool.get_tool_definition()一致
|
||||||
tool_def = {
|
tool_def = {"name": self.name, "description": self.description, "parameters": param_tuples}
|
||||||
"name": self.name,
|
|
||||||
"description": self.description,
|
|
||||||
"parameters": param_tuples
|
|
||||||
}
|
|
||||||
|
|
||||||
return tool_def
|
return tool_def
|
||||||
|
|
||||||
|
|||||||
@@ -162,7 +162,12 @@ def levenshtein_distance(s1: str, s2: str) -> int:
|
|||||||
class Person:
|
class Person:
|
||||||
@classmethod
|
@classmethod
|
||||||
def register_person(
|
def register_person(
|
||||||
cls, platform: str, user_id: str, nickname: str, group_id: Optional[str] = None, group_nick_name: Optional[str] = None
|
cls,
|
||||||
|
platform: str,
|
||||||
|
user_id: str,
|
||||||
|
nickname: str,
|
||||||
|
group_id: Optional[str] = None,
|
||||||
|
group_nick_name: Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
注册新用户的类方法
|
注册新用户的类方法
|
||||||
@@ -781,7 +786,11 @@ async def store_person_memory_from_answer(person_name: str, memory_content: str,
|
|||||||
if len(parts) >= 2:
|
if len(parts) >= 2:
|
||||||
existing_content = parts[1].strip()
|
existing_content = parts[1].strip()
|
||||||
# 简单相似度检查(如果内容相同或非常相似,则跳过)
|
# 简单相似度检查(如果内容相同或非常相似,则跳过)
|
||||||
if existing_content == memory_content or memory_content in existing_content or existing_content in memory_content:
|
if (
|
||||||
|
existing_content == memory_content
|
||||||
|
or memory_content in existing_content
|
||||||
|
or existing_content in memory_content
|
||||||
|
):
|
||||||
is_duplicate = True
|
is_duplicate = True
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|||||||
@@ -125,7 +125,6 @@ class ToolExecutor:
|
|||||||
prompt=prompt, tools=tools, raise_when_empty=False
|
prompt=prompt, tools=tools, raise_when_empty=False
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# 执行工具调用
|
# 执行工具调用
|
||||||
tool_results, used_tools = await self.execute_tool_calls(tool_calls)
|
tool_results, used_tools = await self.execute_tool_calls(tool_calls)
|
||||||
|
|
||||||
|
|||||||
@@ -102,13 +102,13 @@ class EmojiAction(BaseAction):
|
|||||||
|
|
||||||
# 5. 调用LLM
|
# 5. 调用LLM
|
||||||
models = llm_api.get_available_models()
|
models = llm_api.get_available_models()
|
||||||
chat_model_config = models.get("replyer") # 使用字典访问方式
|
chat_model_config = models.get("utils") # 使用字典访问方式
|
||||||
if not chat_model_config:
|
if not chat_model_config:
|
||||||
logger.error(f"{self.log_prefix} 未找到'replyer'模型配置,无法调用LLM")
|
logger.error(f"{self.log_prefix} 未找到'utils'模型配置,无法调用LLM")
|
||||||
return False, "未找到'replyer'模型配置"
|
return False, "未找到'utils'模型配置"
|
||||||
|
|
||||||
success, chosen_emotion, _, _ = await llm_api.generate_with_model(
|
success, chosen_emotion, _, _ = await llm_api.generate_with_model(
|
||||||
prompt, model_config=chat_model_config, request_type="emoji"
|
prompt, model_config=chat_model_config, request_type="emoji.select"
|
||||||
)
|
)
|
||||||
|
|
||||||
if not success:
|
if not success:
|
||||||
|
|||||||
@@ -319,6 +319,58 @@ async def update_bot_config_section(section_name: str, section_data: Any = Body(
|
|||||||
raise HTTPException(status_code=500, detail=f"更新配置节失败: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"更新配置节失败: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
# ===== 原始 TOML 文件操作接口 =====
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/bot/raw")
|
||||||
|
async def get_bot_config_raw():
|
||||||
|
"""获取麦麦主程序配置的原始 TOML 内容"""
|
||||||
|
try:
|
||||||
|
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
|
||||||
|
if not os.path.exists(config_path):
|
||||||
|
raise HTTPException(status_code=404, detail="配置文件不存在")
|
||||||
|
|
||||||
|
with open(config_path, "r", encoding="utf-8") as f:
|
||||||
|
raw_content = f.read()
|
||||||
|
|
||||||
|
return {"success": True, "content": raw_content}
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"读取配置文件失败: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/bot/raw")
|
||||||
|
async def update_bot_config_raw(raw_content: str = Body(..., embed=True)):
|
||||||
|
"""更新麦麦主程序配置(直接保存原始 TOML 内容,会先验证格式)"""
|
||||||
|
try:
|
||||||
|
# 验证 TOML 格式
|
||||||
|
try:
|
||||||
|
config_data = tomlkit.loads(raw_content)
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}")
|
||||||
|
|
||||||
|
# 验证配置数据结构
|
||||||
|
try:
|
||||||
|
Config.from_dict(config_data)
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}")
|
||||||
|
|
||||||
|
# 保存配置文件
|
||||||
|
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
|
||||||
|
with open(config_path, "w", encoding="utf-8") as f:
|
||||||
|
f.write(raw_content)
|
||||||
|
|
||||||
|
logger.info("麦麦主程序配置已更新(原始模式)")
|
||||||
|
return {"success": True, "message": "配置已保存"}
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"保存配置文件失败: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
@router.post("/model/section/{section_name}")
|
@router.post("/model/section/{section_name}")
|
||||||
async def update_model_config_section(section_name: str, section_data: Any = Body(...)):
|
async def update_model_config_section(section_name: str, section_data: Any = Body(...)):
|
||||||
"""更新模型配置的指定节(保留注释和格式)"""
|
"""更新模型配置的指定节(保留注释和格式)"""
|
||||||
@@ -364,3 +416,144 @@ async def update_model_config_section(section_name: str, section_data: Any = Bod
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"更新配置节失败: {e}")
|
logger.error(f"更新配置节失败: {e}")
|
||||||
raise HTTPException(status_code=500, detail=f"更新配置节失败: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"更新配置节失败: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
# ===== 适配器配置管理接口 =====
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/adapter-config/path")
|
||||||
|
async def get_adapter_config_path():
|
||||||
|
"""获取保存的适配器配置文件路径"""
|
||||||
|
try:
|
||||||
|
# 从 data/webui.json 读取路径偏好
|
||||||
|
webui_data_path = os.path.join("data", "webui.json")
|
||||||
|
if not os.path.exists(webui_data_path):
|
||||||
|
return {"success": True, "path": None}
|
||||||
|
|
||||||
|
import json
|
||||||
|
with open(webui_data_path, "r", encoding="utf-8") as f:
|
||||||
|
webui_data = json.load(f)
|
||||||
|
|
||||||
|
adapter_config_path = webui_data.get("adapter_config_path")
|
||||||
|
if not adapter_config_path:
|
||||||
|
return {"success": True, "path": None}
|
||||||
|
|
||||||
|
# 检查文件是否存在并返回最后修改时间
|
||||||
|
if os.path.exists(adapter_config_path):
|
||||||
|
import datetime
|
||||||
|
mtime = os.path.getmtime(adapter_config_path)
|
||||||
|
last_modified = datetime.datetime.fromtimestamp(mtime).isoformat()
|
||||||
|
return {"success": True, "path": adapter_config_path, "lastModified": last_modified}
|
||||||
|
else:
|
||||||
|
return {"success": True, "path": adapter_config_path, "lastModified": None}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取适配器配置路径失败: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"获取配置路径失败: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/adapter-config/path")
|
||||||
|
async def save_adapter_config_path(data: dict[str, str] = Body(...)):
|
||||||
|
"""保存适配器配置文件路径偏好"""
|
||||||
|
try:
|
||||||
|
path = data.get("path")
|
||||||
|
if not path:
|
||||||
|
raise HTTPException(status_code=400, detail="路径不能为空")
|
||||||
|
|
||||||
|
# 保存到 data/webui.json
|
||||||
|
webui_data_path = os.path.join("data", "webui.json")
|
||||||
|
import json
|
||||||
|
|
||||||
|
# 读取现有数据
|
||||||
|
if os.path.exists(webui_data_path):
|
||||||
|
with open(webui_data_path, "r", encoding="utf-8") as f:
|
||||||
|
webui_data = json.load(f)
|
||||||
|
else:
|
||||||
|
webui_data = {}
|
||||||
|
|
||||||
|
# 更新路径
|
||||||
|
webui_data["adapter_config_path"] = path
|
||||||
|
|
||||||
|
# 保存
|
||||||
|
os.makedirs("data", exist_ok=True)
|
||||||
|
with open(webui_data_path, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(webui_data, f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
logger.info(f"适配器配置路径已保存: {path}")
|
||||||
|
return {"success": True, "message": "路径已保存"}
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"保存适配器配置路径失败: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"保存路径失败: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/adapter-config")
|
||||||
|
async def get_adapter_config(path: str):
|
||||||
|
"""从指定路径读取适配器配置文件"""
|
||||||
|
try:
|
||||||
|
if not path:
|
||||||
|
raise HTTPException(status_code=400, detail="路径参数不能为空")
|
||||||
|
|
||||||
|
# 检查文件是否存在
|
||||||
|
if not os.path.exists(path):
|
||||||
|
raise HTTPException(status_code=404, detail=f"配置文件不存在: {path}")
|
||||||
|
|
||||||
|
# 检查文件扩展名
|
||||||
|
if not path.endswith(".toml"):
|
||||||
|
raise HTTPException(status_code=400, detail="只支持 .toml 格式的配置文件")
|
||||||
|
|
||||||
|
# 读取文件内容
|
||||||
|
with open(path, "r", encoding="utf-8") as f:
|
||||||
|
content = f.read()
|
||||||
|
|
||||||
|
logger.info(f"已读取适配器配置: {path}")
|
||||||
|
return {"success": True, "content": content}
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"读取适配器配置失败: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"读取配置失败: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/adapter-config")
|
||||||
|
async def save_adapter_config(data: dict[str, str] = Body(...)):
|
||||||
|
"""保存适配器配置到指定路径"""
|
||||||
|
try:
|
||||||
|
path = data.get("path")
|
||||||
|
content = data.get("content")
|
||||||
|
|
||||||
|
if not path:
|
||||||
|
raise HTTPException(status_code=400, detail="路径不能为空")
|
||||||
|
if content is None:
|
||||||
|
raise HTTPException(status_code=400, detail="配置内容不能为空")
|
||||||
|
|
||||||
|
# 检查文件扩展名
|
||||||
|
if not path.endswith(".toml"):
|
||||||
|
raise HTTPException(status_code=400, detail="只支持 .toml 格式的配置文件")
|
||||||
|
|
||||||
|
# 验证 TOML 格式
|
||||||
|
try:
|
||||||
|
import toml
|
||||||
|
toml.loads(content)
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}")
|
||||||
|
|
||||||
|
# 确保目录存在
|
||||||
|
os.makedirs(os.path.dirname(path) if os.path.dirname(path) else ".", exist_ok=True)
|
||||||
|
|
||||||
|
# 保存文件
|
||||||
|
with open(path, "w", encoding="utf-8") as f:
|
||||||
|
f.write(content)
|
||||||
|
|
||||||
|
logger.info(f"适配器配置已保存: {path}")
|
||||||
|
return {"success": True, "message": "配置已保存"}
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"保存适配器配置失败: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"保存配置失败: {str(e)}")
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
"""表情包管理 API 路由"""
|
"""表情包管理 API 路由"""
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, Header, Query
|
from fastapi import APIRouter, HTTPException, Header, Query
|
||||||
|
from fastapi.responses import FileResponse
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import Optional, List
|
from typing import Optional, List
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
@@ -7,6 +9,7 @@ from src.common.database.database_model import Emoji
|
|||||||
from .token_manager import get_token_manager
|
from .token_manager import get_token_manager
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
|
import os
|
||||||
|
|
||||||
logger = get_logger("webui.emoji")
|
logger = get_logger("webui.emoji")
|
||||||
|
|
||||||
@@ -16,6 +19,7 @@ router = APIRouter(prefix="/emoji", tags=["Emoji"])
|
|||||||
|
|
||||||
class EmojiResponse(BaseModel):
|
class EmojiResponse(BaseModel):
|
||||||
"""表情包响应"""
|
"""表情包响应"""
|
||||||
|
|
||||||
id: int
|
id: int
|
||||||
full_path: str
|
full_path: str
|
||||||
format: str
|
format: str
|
||||||
@@ -24,7 +28,7 @@ class EmojiResponse(BaseModel):
|
|||||||
query_count: int
|
query_count: int
|
||||||
is_registered: bool
|
is_registered: bool
|
||||||
is_banned: bool
|
is_banned: bool
|
||||||
emotion: Optional[List[str]] # 解析后的 JSON
|
emotion: Optional[str] # 直接返回字符串
|
||||||
record_time: float
|
record_time: float
|
||||||
register_time: Optional[float]
|
register_time: Optional[float]
|
||||||
usage_count: int
|
usage_count: int
|
||||||
@@ -33,6 +37,7 @@ class EmojiResponse(BaseModel):
|
|||||||
|
|
||||||
class EmojiListResponse(BaseModel):
|
class EmojiListResponse(BaseModel):
|
||||||
"""表情包列表响应"""
|
"""表情包列表响应"""
|
||||||
|
|
||||||
success: bool
|
success: bool
|
||||||
total: int
|
total: int
|
||||||
page: int
|
page: int
|
||||||
@@ -42,20 +47,23 @@ class EmojiListResponse(BaseModel):
|
|||||||
|
|
||||||
class EmojiDetailResponse(BaseModel):
|
class EmojiDetailResponse(BaseModel):
|
||||||
"""表情包详情响应"""
|
"""表情包详情响应"""
|
||||||
|
|
||||||
success: bool
|
success: bool
|
||||||
data: EmojiResponse
|
data: EmojiResponse
|
||||||
|
|
||||||
|
|
||||||
class EmojiUpdateRequest(BaseModel):
|
class EmojiUpdateRequest(BaseModel):
|
||||||
"""表情包更新请求"""
|
"""表情包更新请求"""
|
||||||
|
|
||||||
description: Optional[str] = None
|
description: Optional[str] = None
|
||||||
is_registered: Optional[bool] = None
|
is_registered: Optional[bool] = None
|
||||||
is_banned: Optional[bool] = None
|
is_banned: Optional[bool] = None
|
||||||
emotion: Optional[List[str]] = None
|
emotion: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class EmojiUpdateResponse(BaseModel):
|
class EmojiUpdateResponse(BaseModel):
|
||||||
"""表情包更新响应"""
|
"""表情包更新响应"""
|
||||||
|
|
||||||
success: bool
|
success: bool
|
||||||
message: str
|
message: str
|
||||||
data: Optional[EmojiResponse] = None
|
data: Optional[EmojiResponse] = None
|
||||||
@@ -63,10 +71,27 @@ class EmojiUpdateResponse(BaseModel):
|
|||||||
|
|
||||||
class EmojiDeleteResponse(BaseModel):
|
class EmojiDeleteResponse(BaseModel):
|
||||||
"""表情包删除响应"""
|
"""表情包删除响应"""
|
||||||
|
|
||||||
success: bool
|
success: bool
|
||||||
message: str
|
message: str
|
||||||
|
|
||||||
|
|
||||||
|
class BatchDeleteRequest(BaseModel):
|
||||||
|
"""批量删除请求"""
|
||||||
|
|
||||||
|
emoji_ids: List[int]
|
||||||
|
|
||||||
|
|
||||||
|
class BatchDeleteResponse(BaseModel):
|
||||||
|
"""批量删除响应"""
|
||||||
|
|
||||||
|
success: bool
|
||||||
|
message: str
|
||||||
|
deleted_count: int
|
||||||
|
failed_count: int
|
||||||
|
failed_ids: List[int] = []
|
||||||
|
|
||||||
|
|
||||||
def verify_auth_token(authorization: Optional[str]) -> bool:
|
def verify_auth_token(authorization: Optional[str]) -> bool:
|
||||||
"""验证认证 Token"""
|
"""验证认证 Token"""
|
||||||
if not authorization or not authorization.startswith("Bearer "):
|
if not authorization or not authorization.startswith("Bearer "):
|
||||||
@@ -81,16 +106,6 @@ def verify_auth_token(authorization: Optional[str]) -> bool:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def parse_emotion(emotion_str: Optional[str]) -> Optional[List[str]]:
|
|
||||||
"""解析情感标签 JSON 字符串"""
|
|
||||||
if not emotion_str:
|
|
||||||
return None
|
|
||||||
try:
|
|
||||||
return json.loads(emotion_str)
|
|
||||||
except (json.JSONDecodeError, TypeError):
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def emoji_to_response(emoji: Emoji) -> EmojiResponse:
|
def emoji_to_response(emoji: Emoji) -> EmojiResponse:
|
||||||
"""将 Emoji 模型转换为响应对象"""
|
"""将 Emoji 模型转换为响应对象"""
|
||||||
return EmojiResponse(
|
return EmojiResponse(
|
||||||
@@ -102,7 +117,7 @@ def emoji_to_response(emoji: Emoji) -> EmojiResponse:
|
|||||||
query_count=emoji.query_count,
|
query_count=emoji.query_count,
|
||||||
is_registered=emoji.is_registered,
|
is_registered=emoji.is_registered,
|
||||||
is_banned=emoji.is_banned,
|
is_banned=emoji.is_banned,
|
||||||
emotion=parse_emotion(emoji.emotion),
|
emotion=str(emoji.emotion) if emoji.emotion is not None else None,
|
||||||
record_time=emoji.record_time,
|
record_time=emoji.record_time,
|
||||||
register_time=emoji.register_time,
|
register_time=emoji.register_time,
|
||||||
usage_count=emoji.usage_count,
|
usage_count=emoji.usage_count,
|
||||||
@@ -118,7 +133,7 @@ async def get_emoji_list(
|
|||||||
is_registered: Optional[bool] = Query(None, description="是否已注册筛选"),
|
is_registered: Optional[bool] = Query(None, description="是否已注册筛选"),
|
||||||
is_banned: Optional[bool] = Query(None, description="是否被禁用筛选"),
|
is_banned: Optional[bool] = Query(None, description="是否被禁用筛选"),
|
||||||
format: Optional[str] = Query(None, description="格式筛选"),
|
format: Optional[str] = Query(None, description="格式筛选"),
|
||||||
authorization: Optional[str] = Header(None)
|
authorization: Optional[str] = Header(None),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
获取表情包列表
|
获取表情包列表
|
||||||
@@ -143,10 +158,7 @@ async def get_emoji_list(
|
|||||||
|
|
||||||
# 搜索过滤
|
# 搜索过滤
|
||||||
if search:
|
if search:
|
||||||
query = query.where(
|
query = query.where((Emoji.description.contains(search)) | (Emoji.emoji_hash.contains(search)))
|
||||||
(Emoji.description.contains(search)) |
|
|
||||||
(Emoji.emoji_hash.contains(search))
|
|
||||||
)
|
|
||||||
|
|
||||||
# 注册状态过滤
|
# 注册状态过滤
|
||||||
if is_registered is not None:
|
if is_registered is not None:
|
||||||
@@ -162,10 +174,9 @@ async def get_emoji_list(
|
|||||||
|
|
||||||
# 排序:使用次数倒序,然后按记录时间倒序
|
# 排序:使用次数倒序,然后按记录时间倒序
|
||||||
from peewee import Case
|
from peewee import Case
|
||||||
|
|
||||||
query = query.order_by(
|
query = query.order_by(
|
||||||
Emoji.usage_count.desc(),
|
Emoji.usage_count.desc(), Case(None, [(Emoji.record_time.is_null(), 1)], 0), Emoji.record_time.desc()
|
||||||
Case(None, [(Emoji.record_time.is_null(), 1)], 0),
|
|
||||||
Emoji.record_time.desc()
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 获取总数
|
# 获取总数
|
||||||
@@ -178,13 +189,7 @@ async def get_emoji_list(
|
|||||||
# 转换为响应对象
|
# 转换为响应对象
|
||||||
data = [emoji_to_response(emoji) for emoji in emojis]
|
data = [emoji_to_response(emoji) for emoji in emojis]
|
||||||
|
|
||||||
return EmojiListResponse(
|
return EmojiListResponse(success=True, total=total, page=page, page_size=page_size, data=data)
|
||||||
success=True,
|
|
||||||
total=total,
|
|
||||||
page=page,
|
|
||||||
page_size=page_size,
|
|
||||||
data=data
|
|
||||||
)
|
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
@@ -194,10 +199,7 @@ async def get_emoji_list(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/{emoji_id}", response_model=EmojiDetailResponse)
|
@router.get("/{emoji_id}", response_model=EmojiDetailResponse)
|
||||||
async def get_emoji_detail(
|
async def get_emoji_detail(emoji_id: int, authorization: Optional[str] = Header(None)):
|
||||||
emoji_id: int,
|
|
||||||
authorization: Optional[str] = Header(None)
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
获取表情包详细信息
|
获取表情包详细信息
|
||||||
|
|
||||||
@@ -216,10 +218,7 @@ async def get_emoji_detail(
|
|||||||
if not emoji:
|
if not emoji:
|
||||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
|
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
|
||||||
|
|
||||||
return EmojiDetailResponse(
|
return EmojiDetailResponse(success=True, data=emoji_to_response(emoji))
|
||||||
success=True,
|
|
||||||
data=emoji_to_response(emoji)
|
|
||||||
)
|
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
@@ -229,11 +228,7 @@ async def get_emoji_detail(
|
|||||||
|
|
||||||
|
|
||||||
@router.patch("/{emoji_id}", response_model=EmojiUpdateResponse)
|
@router.patch("/{emoji_id}", response_model=EmojiUpdateResponse)
|
||||||
async def update_emoji(
|
async def update_emoji(emoji_id: int, request: EmojiUpdateRequest, authorization: Optional[str] = Header(None)):
|
||||||
emoji_id: int,
|
|
||||||
request: EmojiUpdateRequest,
|
|
||||||
authorization: Optional[str] = Header(None)
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
增量更新表情包(只更新提供的字段)
|
增量更新表情包(只更新提供的字段)
|
||||||
|
|
||||||
@@ -259,16 +254,11 @@ async def update_emoji(
|
|||||||
if not update_data:
|
if not update_data:
|
||||||
raise HTTPException(status_code=400, detail="未提供任何需要更新的字段")
|
raise HTTPException(status_code=400, detail="未提供任何需要更新的字段")
|
||||||
|
|
||||||
# 处理情感标签(转换为 JSON)
|
# emotion 字段直接使用字符串,无需转换
|
||||||
if 'emotion' in update_data:
|
|
||||||
if update_data['emotion'] is None:
|
|
||||||
update_data['emotion'] = None
|
|
||||||
else:
|
|
||||||
update_data['emotion'] = json.dumps(update_data['emotion'], ensure_ascii=False)
|
|
||||||
|
|
||||||
# 如果注册状态从 False 变为 True,记录注册时间
|
# 如果注册状态从 False 变为 True,记录注册时间
|
||||||
if 'is_registered' in update_data and update_data['is_registered'] and not emoji.is_registered:
|
if "is_registered" in update_data and update_data["is_registered"] and not emoji.is_registered:
|
||||||
update_data['register_time'] = time.time()
|
update_data["register_time"] = time.time()
|
||||||
|
|
||||||
# 执行更新
|
# 执行更新
|
||||||
for field, value in update_data.items():
|
for field, value in update_data.items():
|
||||||
@@ -279,9 +269,7 @@ async def update_emoji(
|
|||||||
logger.info(f"表情包已更新: ID={emoji_id}, 字段: {list(update_data.keys())}")
|
logger.info(f"表情包已更新: ID={emoji_id}, 字段: {list(update_data.keys())}")
|
||||||
|
|
||||||
return EmojiUpdateResponse(
|
return EmojiUpdateResponse(
|
||||||
success=True,
|
success=True, message=f"成功更新 {len(update_data)} 个字段", data=emoji_to_response(emoji)
|
||||||
message=f"成功更新 {len(update_data)} 个字段",
|
|
||||||
data=emoji_to_response(emoji)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
@@ -292,10 +280,7 @@ async def update_emoji(
|
|||||||
|
|
||||||
|
|
||||||
@router.delete("/{emoji_id}", response_model=EmojiDeleteResponse)
|
@router.delete("/{emoji_id}", response_model=EmojiDeleteResponse)
|
||||||
async def delete_emoji(
|
async def delete_emoji(emoji_id: int, authorization: Optional[str] = Header(None)):
|
||||||
emoji_id: int,
|
|
||||||
authorization: Optional[str] = Header(None)
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
删除表情包
|
删除表情包
|
||||||
|
|
||||||
@@ -322,10 +307,7 @@ async def delete_emoji(
|
|||||||
|
|
||||||
logger.info(f"表情包已删除: ID={emoji_id}, hash={emoji_hash}")
|
logger.info(f"表情包已删除: ID={emoji_id}, hash={emoji_hash}")
|
||||||
|
|
||||||
return EmojiDeleteResponse(
|
return EmojiDeleteResponse(success=True, message=f"成功删除表情包: {emoji_hash}")
|
||||||
success=True,
|
|
||||||
message=f"成功删除表情包: {emoji_hash}"
|
|
||||||
)
|
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
@@ -335,9 +317,7 @@ async def delete_emoji(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/stats/summary")
|
@router.get("/stats/summary")
|
||||||
async def get_emoji_stats(
|
async def get_emoji_stats(authorization: Optional[str] = Header(None)):
|
||||||
authorization: Optional[str] = Header(None)
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
获取表情包统计数据
|
获取表情包统计数据
|
||||||
|
|
||||||
@@ -367,7 +347,7 @@ async def get_emoji_stats(
|
|||||||
"id": emoji.id,
|
"id": emoji.id,
|
||||||
"emoji_hash": emoji.emoji_hash,
|
"emoji_hash": emoji.emoji_hash,
|
||||||
"description": emoji.description,
|
"description": emoji.description,
|
||||||
"usage_count": emoji.usage_count
|
"usage_count": emoji.usage_count,
|
||||||
}
|
}
|
||||||
for emoji in top_used
|
for emoji in top_used
|
||||||
]
|
]
|
||||||
@@ -380,8 +360,8 @@ async def get_emoji_stats(
|
|||||||
"banned": banned,
|
"banned": banned,
|
||||||
"unregistered": total - registered,
|
"unregistered": total - registered,
|
||||||
"formats": formats,
|
"formats": formats,
|
||||||
"top_used": top_used_list
|
"top_used": top_used_list,
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
@@ -392,10 +372,7 @@ async def get_emoji_stats(
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/{emoji_id}/register", response_model=EmojiUpdateResponse)
|
@router.post("/{emoji_id}/register", response_model=EmojiUpdateResponse)
|
||||||
async def register_emoji(
|
async def register_emoji(emoji_id: int, authorization: Optional[str] = Header(None)):
|
||||||
emoji_id: int,
|
|
||||||
authorization: Optional[str] = Header(None)
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
注册表情包(快捷操作)
|
注册表情包(快捷操作)
|
||||||
|
|
||||||
@@ -427,11 +404,7 @@ async def register_emoji(
|
|||||||
|
|
||||||
logger.info(f"表情包已注册: ID={emoji_id}")
|
logger.info(f"表情包已注册: ID={emoji_id}")
|
||||||
|
|
||||||
return EmojiUpdateResponse(
|
return EmojiUpdateResponse(success=True, message="表情包注册成功", data=emoji_to_response(emoji))
|
||||||
success=True,
|
|
||||||
message="表情包注册成功",
|
|
||||||
data=emoji_to_response(emoji)
|
|
||||||
)
|
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
@@ -441,10 +414,7 @@ async def register_emoji(
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/{emoji_id}/ban", response_model=EmojiUpdateResponse)
|
@router.post("/{emoji_id}/ban", response_model=EmojiUpdateResponse)
|
||||||
async def ban_emoji(
|
async def ban_emoji(emoji_id: int, authorization: Optional[str] = Header(None)):
|
||||||
emoji_id: int,
|
|
||||||
authorization: Optional[str] = Header(None)
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
禁用表情包(快捷操作)
|
禁用表情包(快捷操作)
|
||||||
|
|
||||||
@@ -470,14 +440,123 @@ async def ban_emoji(
|
|||||||
|
|
||||||
logger.info(f"表情包已禁用: ID={emoji_id}")
|
logger.info(f"表情包已禁用: ID={emoji_id}")
|
||||||
|
|
||||||
return EmojiUpdateResponse(
|
return EmojiUpdateResponse(success=True, message="表情包禁用成功", data=emoji_to_response(emoji))
|
||||||
success=True,
|
|
||||||
message="表情包禁用成功",
|
|
||||||
data=emoji_to_response(emoji)
|
|
||||||
)
|
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception(f"禁用表情包失败: {e}")
|
logger.exception(f"禁用表情包失败: {e}")
|
||||||
raise HTTPException(status_code=500, detail=f"禁用表情包失败: {str(e)}") from e
|
raise HTTPException(status_code=500, detail=f"禁用表情包失败: {str(e)}") from e
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{emoji_id}/thumbnail")
|
||||||
|
async def get_emoji_thumbnail(
|
||||||
|
emoji_id: int,
|
||||||
|
token: Optional[str] = Query(None, description="访问令牌"),
|
||||||
|
authorization: Optional[str] = Header(None),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
获取表情包缩略图
|
||||||
|
|
||||||
|
Args:
|
||||||
|
emoji_id: 表情包ID
|
||||||
|
token: 访问令牌(通过 query parameter)
|
||||||
|
authorization: Authorization header
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
表情包图片文件
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 优先使用 query parameter 中的 token(用于 img 标签)
|
||||||
|
if token:
|
||||||
|
token_manager = get_token_manager()
|
||||||
|
if not token_manager.verify_token(token):
|
||||||
|
raise HTTPException(status_code=401, detail="Token 无效或已过期")
|
||||||
|
else:
|
||||||
|
# 如果没有 query token,则验证 Authorization header
|
||||||
|
verify_auth_token(authorization)
|
||||||
|
|
||||||
|
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
|
||||||
|
|
||||||
|
if not emoji:
|
||||||
|
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
|
||||||
|
|
||||||
|
# 检查文件是否存在
|
||||||
|
if not os.path.exists(emoji.full_path):
|
||||||
|
raise HTTPException(status_code=404, detail="表情包文件不存在")
|
||||||
|
|
||||||
|
# 根据格式设置 MIME 类型
|
||||||
|
mime_types = {
|
||||||
|
"png": "image/png",
|
||||||
|
"jpg": "image/jpeg",
|
||||||
|
"jpeg": "image/jpeg",
|
||||||
|
"gif": "image/gif",
|
||||||
|
"webp": "image/webp",
|
||||||
|
"bmp": "image/bmp",
|
||||||
|
}
|
||||||
|
|
||||||
|
media_type = mime_types.get(emoji.format.lower(), "application/octet-stream")
|
||||||
|
|
||||||
|
return FileResponse(path=emoji.full_path, media_type=media_type, filename=f"{emoji.emoji_hash}.{emoji.format}")
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"获取表情包缩略图失败: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"获取表情包缩略图失败: {str(e)}") from e
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/batch/delete", response_model=BatchDeleteResponse)
|
||||||
|
async def batch_delete_emojis(request: BatchDeleteRequest, authorization: Optional[str] = Header(None)):
|
||||||
|
"""
|
||||||
|
批量删除表情包
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: 包含emoji_ids列表的请求
|
||||||
|
authorization: Authorization header
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
批量删除结果
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
verify_auth_token(authorization)
|
||||||
|
|
||||||
|
if not request.emoji_ids:
|
||||||
|
raise HTTPException(status_code=400, detail="未提供要删除的表情包ID")
|
||||||
|
|
||||||
|
deleted_count = 0
|
||||||
|
failed_count = 0
|
||||||
|
failed_ids = []
|
||||||
|
|
||||||
|
for emoji_id in request.emoji_ids:
|
||||||
|
try:
|
||||||
|
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
|
||||||
|
if emoji:
|
||||||
|
emoji.delete_instance()
|
||||||
|
deleted_count += 1
|
||||||
|
logger.info(f"批量删除表情包: {emoji_id}")
|
||||||
|
else:
|
||||||
|
failed_count += 1
|
||||||
|
failed_ids.append(emoji_id)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"删除表情包 {emoji_id} 失败: {e}")
|
||||||
|
failed_count += 1
|
||||||
|
failed_ids.append(emoji_id)
|
||||||
|
|
||||||
|
message = f"成功删除 {deleted_count} 个表情包"
|
||||||
|
if failed_count > 0:
|
||||||
|
message += f",{failed_count} 个失败"
|
||||||
|
|
||||||
|
return BatchDeleteResponse(
|
||||||
|
success=True,
|
||||||
|
message=message,
|
||||||
|
deleted_count=deleted_count,
|
||||||
|
failed_count=failed_count,
|
||||||
|
failed_ids=failed_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"批量删除表情包失败: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"批量删除失败: {str(e)}") from e
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""表达方式管理 API 路由"""
|
"""表达方式管理 API 路由"""
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, Header, Query
|
from fastapi import APIRouter, HTTPException, Header, Query
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import Optional, List
|
from typing import Optional, List
|
||||||
@@ -15,6 +16,7 @@ router = APIRouter(prefix="/expression", tags=["Expression"])
|
|||||||
|
|
||||||
class ExpressionResponse(BaseModel):
|
class ExpressionResponse(BaseModel):
|
||||||
"""表达方式响应"""
|
"""表达方式响应"""
|
||||||
|
|
||||||
id: int
|
id: int
|
||||||
situation: str
|
situation: str
|
||||||
style: str
|
style: str
|
||||||
@@ -27,6 +29,7 @@ class ExpressionResponse(BaseModel):
|
|||||||
|
|
||||||
class ExpressionListResponse(BaseModel):
|
class ExpressionListResponse(BaseModel):
|
||||||
"""表达方式列表响应"""
|
"""表达方式列表响应"""
|
||||||
|
|
||||||
success: bool
|
success: bool
|
||||||
total: int
|
total: int
|
||||||
page: int
|
page: int
|
||||||
@@ -36,12 +39,14 @@ class ExpressionListResponse(BaseModel):
|
|||||||
|
|
||||||
class ExpressionDetailResponse(BaseModel):
|
class ExpressionDetailResponse(BaseModel):
|
||||||
"""表达方式详情响应"""
|
"""表达方式详情响应"""
|
||||||
|
|
||||||
success: bool
|
success: bool
|
||||||
data: ExpressionResponse
|
data: ExpressionResponse
|
||||||
|
|
||||||
|
|
||||||
class ExpressionCreateRequest(BaseModel):
|
class ExpressionCreateRequest(BaseModel):
|
||||||
"""表达方式创建请求"""
|
"""表达方式创建请求"""
|
||||||
|
|
||||||
situation: str
|
situation: str
|
||||||
style: str
|
style: str
|
||||||
context: Optional[str] = None
|
context: Optional[str] = None
|
||||||
@@ -51,6 +56,7 @@ class ExpressionCreateRequest(BaseModel):
|
|||||||
|
|
||||||
class ExpressionUpdateRequest(BaseModel):
|
class ExpressionUpdateRequest(BaseModel):
|
||||||
"""表达方式更新请求"""
|
"""表达方式更新请求"""
|
||||||
|
|
||||||
situation: Optional[str] = None
|
situation: Optional[str] = None
|
||||||
style: Optional[str] = None
|
style: Optional[str] = None
|
||||||
context: Optional[str] = None
|
context: Optional[str] = None
|
||||||
@@ -60,6 +66,7 @@ class ExpressionUpdateRequest(BaseModel):
|
|||||||
|
|
||||||
class ExpressionUpdateResponse(BaseModel):
|
class ExpressionUpdateResponse(BaseModel):
|
||||||
"""表达方式更新响应"""
|
"""表达方式更新响应"""
|
||||||
|
|
||||||
success: bool
|
success: bool
|
||||||
message: str
|
message: str
|
||||||
data: Optional[ExpressionResponse] = None
|
data: Optional[ExpressionResponse] = None
|
||||||
@@ -67,12 +74,14 @@ class ExpressionUpdateResponse(BaseModel):
|
|||||||
|
|
||||||
class ExpressionDeleteResponse(BaseModel):
|
class ExpressionDeleteResponse(BaseModel):
|
||||||
"""表达方式删除响应"""
|
"""表达方式删除响应"""
|
||||||
|
|
||||||
success: bool
|
success: bool
|
||||||
message: str
|
message: str
|
||||||
|
|
||||||
|
|
||||||
class ExpressionCreateResponse(BaseModel):
|
class ExpressionCreateResponse(BaseModel):
|
||||||
"""表达方式创建响应"""
|
"""表达方式创建响应"""
|
||||||
|
|
||||||
success: bool
|
success: bool
|
||||||
message: str
|
message: str
|
||||||
data: ExpressionResponse
|
data: ExpressionResponse
|
||||||
@@ -112,7 +121,7 @@ async def get_expression_list(
|
|||||||
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
|
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
|
||||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||||
chat_id: Optional[str] = Query(None, description="聊天ID筛选"),
|
chat_id: Optional[str] = Query(None, description="聊天ID筛选"),
|
||||||
authorization: Optional[str] = Header(None)
|
authorization: Optional[str] = Header(None),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
获取表达方式列表
|
获取表达方式列表
|
||||||
@@ -136,9 +145,9 @@ async def get_expression_list(
|
|||||||
# 搜索过滤
|
# 搜索过滤
|
||||||
if search:
|
if search:
|
||||||
query = query.where(
|
query = query.where(
|
||||||
(Expression.situation.contains(search)) |
|
(Expression.situation.contains(search))
|
||||||
(Expression.style.contains(search)) |
|
| (Expression.style.contains(search))
|
||||||
(Expression.context.contains(search))
|
| (Expression.context.contains(search))
|
||||||
)
|
)
|
||||||
|
|
||||||
# 聊天ID过滤
|
# 聊天ID过滤
|
||||||
@@ -147,9 +156,9 @@ async def get_expression_list(
|
|||||||
|
|
||||||
# 排序:最后活跃时间倒序(NULL 值放在最后)
|
# 排序:最后活跃时间倒序(NULL 值放在最后)
|
||||||
from peewee import Case
|
from peewee import Case
|
||||||
|
|
||||||
query = query.order_by(
|
query = query.order_by(
|
||||||
Case(None, [(Expression.last_active_time.is_null(), 1)], 0),
|
Case(None, [(Expression.last_active_time.is_null(), 1)], 0), Expression.last_active_time.desc()
|
||||||
Expression.last_active_time.desc()
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 获取总数
|
# 获取总数
|
||||||
@@ -162,13 +171,7 @@ async def get_expression_list(
|
|||||||
# 转换为响应对象
|
# 转换为响应对象
|
||||||
data = [expression_to_response(expr) for expr in expressions]
|
data = [expression_to_response(expr) for expr in expressions]
|
||||||
|
|
||||||
return ExpressionListResponse(
|
return ExpressionListResponse(success=True, total=total, page=page, page_size=page_size, data=data)
|
||||||
success=True,
|
|
||||||
total=total,
|
|
||||||
page=page,
|
|
||||||
page_size=page_size,
|
|
||||||
data=data
|
|
||||||
)
|
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
@@ -178,10 +181,7 @@ async def get_expression_list(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/{expression_id}", response_model=ExpressionDetailResponse)
|
@router.get("/{expression_id}", response_model=ExpressionDetailResponse)
|
||||||
async def get_expression_detail(
|
async def get_expression_detail(expression_id: int, authorization: Optional[str] = Header(None)):
|
||||||
expression_id: int,
|
|
||||||
authorization: Optional[str] = Header(None)
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
获取表达方式详细信息
|
获取表达方式详细信息
|
||||||
|
|
||||||
@@ -200,10 +200,7 @@ async def get_expression_detail(
|
|||||||
if not expression:
|
if not expression:
|
||||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
|
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
|
||||||
|
|
||||||
return ExpressionDetailResponse(
|
return ExpressionDetailResponse(success=True, data=expression_to_response(expression))
|
||||||
success=True,
|
|
||||||
data=expression_to_response(expression)
|
|
||||||
)
|
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
@@ -213,10 +210,7 @@ async def get_expression_detail(
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/", response_model=ExpressionCreateResponse)
|
@router.post("/", response_model=ExpressionCreateResponse)
|
||||||
async def create_expression(
|
async def create_expression(request: ExpressionCreateRequest, authorization: Optional[str] = Header(None)):
|
||||||
request: ExpressionCreateRequest,
|
|
||||||
authorization: Optional[str] = Header(None)
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
创建新的表达方式
|
创建新的表达方式
|
||||||
|
|
||||||
@@ -246,9 +240,7 @@ async def create_expression(
|
|||||||
logger.info(f"表达方式已创建: ID={expression.id}, situation={request.situation}")
|
logger.info(f"表达方式已创建: ID={expression.id}, situation={request.situation}")
|
||||||
|
|
||||||
return ExpressionCreateResponse(
|
return ExpressionCreateResponse(
|
||||||
success=True,
|
success=True, message="表达方式创建成功", data=expression_to_response(expression)
|
||||||
message="表达方式创建成功",
|
|
||||||
data=expression_to_response(expression)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
@@ -260,9 +252,7 @@ async def create_expression(
|
|||||||
|
|
||||||
@router.patch("/{expression_id}", response_model=ExpressionUpdateResponse)
|
@router.patch("/{expression_id}", response_model=ExpressionUpdateResponse)
|
||||||
async def update_expression(
|
async def update_expression(
|
||||||
expression_id: int,
|
expression_id: int, request: ExpressionUpdateRequest, authorization: Optional[str] = Header(None)
|
||||||
request: ExpressionUpdateRequest,
|
|
||||||
authorization: Optional[str] = Header(None)
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
增量更新表达方式(只更新提供的字段)
|
增量更新表达方式(只更新提供的字段)
|
||||||
@@ -290,7 +280,7 @@ async def update_expression(
|
|||||||
raise HTTPException(status_code=400, detail="未提供任何需要更新的字段")
|
raise HTTPException(status_code=400, detail="未提供任何需要更新的字段")
|
||||||
|
|
||||||
# 更新最后活跃时间
|
# 更新最后活跃时间
|
||||||
update_data['last_active_time'] = time.time()
|
update_data["last_active_time"] = time.time()
|
||||||
|
|
||||||
# 执行更新
|
# 执行更新
|
||||||
for field, value in update_data.items():
|
for field, value in update_data.items():
|
||||||
@@ -301,9 +291,7 @@ async def update_expression(
|
|||||||
logger.info(f"表达方式已更新: ID={expression_id}, 字段: {list(update_data.keys())}")
|
logger.info(f"表达方式已更新: ID={expression_id}, 字段: {list(update_data.keys())}")
|
||||||
|
|
||||||
return ExpressionUpdateResponse(
|
return ExpressionUpdateResponse(
|
||||||
success=True,
|
success=True, message=f"成功更新 {len(update_data)} 个字段", data=expression_to_response(expression)
|
||||||
message=f"成功更新 {len(update_data)} 个字段",
|
|
||||||
data=expression_to_response(expression)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
@@ -314,10 +302,7 @@ async def update_expression(
|
|||||||
|
|
||||||
|
|
||||||
@router.delete("/{expression_id}", response_model=ExpressionDeleteResponse)
|
@router.delete("/{expression_id}", response_model=ExpressionDeleteResponse)
|
||||||
async def delete_expression(
|
async def delete_expression(expression_id: int, authorization: Optional[str] = Header(None)):
|
||||||
expression_id: int,
|
|
||||||
authorization: Optional[str] = Header(None)
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
删除表达方式
|
删除表达方式
|
||||||
|
|
||||||
@@ -344,10 +329,7 @@ async def delete_expression(
|
|||||||
|
|
||||||
logger.info(f"表达方式已删除: ID={expression_id}, situation={situation}")
|
logger.info(f"表达方式已删除: ID={expression_id}, situation={situation}")
|
||||||
|
|
||||||
return ExpressionDeleteResponse(
|
return ExpressionDeleteResponse(success=True, message=f"成功删除表达方式: {situation}")
|
||||||
success=True,
|
|
||||||
message=f"成功删除表达方式: {situation}"
|
|
||||||
)
|
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
@@ -356,10 +338,55 @@ async def delete_expression(
|
|||||||
raise HTTPException(status_code=500, detail=f"删除表达方式失败: {str(e)}") from e
|
raise HTTPException(status_code=500, detail=f"删除表达方式失败: {str(e)}") from e
|
||||||
|
|
||||||
|
|
||||||
|
class BatchDeleteRequest(BaseModel):
|
||||||
|
"""批量删除请求"""
|
||||||
|
|
||||||
|
ids: List[int]
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/batch/delete", response_model=ExpressionDeleteResponse)
|
||||||
|
async def batch_delete_expressions(request: BatchDeleteRequest, authorization: Optional[str] = Header(None)):
|
||||||
|
"""
|
||||||
|
批量删除表达方式
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: 包含要删除的ID列表的请求
|
||||||
|
authorization: Authorization header
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
删除结果
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
verify_auth_token(authorization)
|
||||||
|
|
||||||
|
if not request.ids:
|
||||||
|
raise HTTPException(status_code=400, detail="未提供要删除的表达方式ID")
|
||||||
|
|
||||||
|
# 查找所有要删除的表达方式
|
||||||
|
expressions = Expression.select().where(Expression.id.in_(request.ids))
|
||||||
|
found_ids = [expr.id for expr in expressions]
|
||||||
|
|
||||||
|
# 检查是否有未找到的ID
|
||||||
|
not_found_ids = set(request.ids) - set(found_ids)
|
||||||
|
if not_found_ids:
|
||||||
|
logger.warning(f"部分表达方式未找到: {not_found_ids}")
|
||||||
|
|
||||||
|
# 执行批量删除
|
||||||
|
deleted_count = Expression.delete().where(Expression.id.in_(found_ids)).execute()
|
||||||
|
|
||||||
|
logger.info(f"批量删除了 {deleted_count} 个表达方式")
|
||||||
|
|
||||||
|
return ExpressionDeleteResponse(success=True, message=f"成功删除 {deleted_count} 个表达方式")
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"批量删除表达方式失败: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"批量删除表达方式失败: {str(e)}") from e
|
||||||
|
|
||||||
|
|
||||||
@router.get("/stats/summary")
|
@router.get("/stats/summary")
|
||||||
async def get_expression_stats(
|
async def get_expression_stats(authorization: Optional[str] = Header(None)):
|
||||||
authorization: Optional[str] = Header(None)
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
获取表达方式统计数据
|
获取表达方式统计数据
|
||||||
|
|
||||||
@@ -382,10 +409,11 @@ async def get_expression_stats(
|
|||||||
|
|
||||||
# 获取最近创建的记录数(7天内)
|
# 获取最近创建的记录数(7天内)
|
||||||
seven_days_ago = time.time() - (7 * 24 * 60 * 60)
|
seven_days_ago = time.time() - (7 * 24 * 60 * 60)
|
||||||
recent = Expression.select().where(
|
recent = (
|
||||||
(Expression.create_date.is_null(False)) &
|
Expression.select()
|
||||||
(Expression.create_date >= seven_days_ago)
|
.where((Expression.create_date.is_null(False)) & (Expression.create_date >= seven_days_ago))
|
||||||
).count()
|
.count()
|
||||||
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"success": True,
|
"success": True,
|
||||||
@@ -393,8 +421,8 @@ async def get_expression_stats(
|
|||||||
"total": total,
|
"total": total,
|
||||||
"recent_7days": recent,
|
"recent_7days": recent,
|
||||||
"chat_count": len(chat_stats),
|
"chat_count": len(chat_stats),
|
||||||
"top_chats": dict(sorted(chat_stats.items(), key=lambda x: x[1], reverse=True)[:10])
|
"top_chats": dict(sorted(chat_stats.items(), key=lambda x: x[1], reverse=True)[:10]),
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Git 镜像源服务 - 支持多镜像源、错误重试、Git 克隆和 Raw 文件获取"""
|
"""Git 镜像源服务 - 支持多镜像源、错误重试、Git 克隆和 Raw 文件获取"""
|
||||||
|
|
||||||
from typing import Optional, List, Dict, Any
|
from typing import Optional, List, Dict, Any
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import httpx
|
import httpx
|
||||||
@@ -15,6 +16,7 @@ logger = get_logger("webui.git_mirror")
|
|||||||
# 导入进度更新函数(避免循环导入)
|
# 导入进度更新函数(避免循环导入)
|
||||||
_update_progress = None
|
_update_progress = None
|
||||||
|
|
||||||
|
|
||||||
def set_update_progress_callback(callback):
|
def set_update_progress_callback(callback):
|
||||||
"""设置进度更新回调函数"""
|
"""设置进度更新回调函数"""
|
||||||
global _update_progress
|
global _update_progress
|
||||||
@@ -23,6 +25,7 @@ def set_update_progress_callback(callback):
|
|||||||
|
|
||||||
class MirrorType(str, Enum):
|
class MirrorType(str, Enum):
|
||||||
"""镜像源类型"""
|
"""镜像源类型"""
|
||||||
|
|
||||||
GH_PROXY = "gh-proxy" # gh-proxy 主节点
|
GH_PROXY = "gh-proxy" # gh-proxy 主节点
|
||||||
HK_GH_PROXY = "hk-gh-proxy" # gh-proxy 香港节点
|
HK_GH_PROXY = "hk-gh-proxy" # gh-proxy 香港节点
|
||||||
CDN_GH_PROXY = "cdn-gh-proxy" # gh-proxy CDN 节点
|
CDN_GH_PROXY = "cdn-gh-proxy" # gh-proxy CDN 节点
|
||||||
@@ -47,7 +50,7 @@ class GitMirrorConfig:
|
|||||||
"clone_prefix": "https://gh-proxy.org/https://github.com",
|
"clone_prefix": "https://gh-proxy.org/https://github.com",
|
||||||
"enabled": True,
|
"enabled": True,
|
||||||
"priority": 1,
|
"priority": 1,
|
||||||
"created_at": None
|
"created_at": None,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": "hk-gh-proxy",
|
"id": "hk-gh-proxy",
|
||||||
@@ -56,7 +59,7 @@ class GitMirrorConfig:
|
|||||||
"clone_prefix": "https://hk.gh-proxy.org/https://github.com",
|
"clone_prefix": "https://hk.gh-proxy.org/https://github.com",
|
||||||
"enabled": True,
|
"enabled": True,
|
||||||
"priority": 2,
|
"priority": 2,
|
||||||
"created_at": None
|
"created_at": None,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": "cdn-gh-proxy",
|
"id": "cdn-gh-proxy",
|
||||||
@@ -65,7 +68,7 @@ class GitMirrorConfig:
|
|||||||
"clone_prefix": "https://cdn.gh-proxy.org/https://github.com",
|
"clone_prefix": "https://cdn.gh-proxy.org/https://github.com",
|
||||||
"enabled": True,
|
"enabled": True,
|
||||||
"priority": 3,
|
"priority": 3,
|
||||||
"created_at": None
|
"created_at": None,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": "edgeone-gh-proxy",
|
"id": "edgeone-gh-proxy",
|
||||||
@@ -74,7 +77,7 @@ class GitMirrorConfig:
|
|||||||
"clone_prefix": "https://edgeone.gh-proxy.org/https://github.com",
|
"clone_prefix": "https://edgeone.gh-proxy.org/https://github.com",
|
||||||
"enabled": True,
|
"enabled": True,
|
||||||
"priority": 4,
|
"priority": 4,
|
||||||
"created_at": None
|
"created_at": None,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": "meyzh-github",
|
"id": "meyzh-github",
|
||||||
@@ -83,7 +86,7 @@ class GitMirrorConfig:
|
|||||||
"clone_prefix": "https://meyzh.github.io/https://github.com",
|
"clone_prefix": "https://meyzh.github.io/https://github.com",
|
||||||
"enabled": True,
|
"enabled": True,
|
||||||
"priority": 5,
|
"priority": 5,
|
||||||
"created_at": None
|
"created_at": None,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": "github",
|
"id": "github",
|
||||||
@@ -92,8 +95,8 @@ class GitMirrorConfig:
|
|||||||
"clone_prefix": "https://github.com",
|
"clone_prefix": "https://github.com",
|
||||||
"enabled": True,
|
"enabled": True,
|
||||||
"priority": 999,
|
"priority": 999,
|
||||||
"created_at": None
|
"created_at": None,
|
||||||
}
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@@ -106,7 +109,7 @@ class GitMirrorConfig:
|
|||||||
"""加载配置文件"""
|
"""加载配置文件"""
|
||||||
try:
|
try:
|
||||||
if self.config_file.exists():
|
if self.config_file.exists():
|
||||||
with open(self.config_file, 'r', encoding='utf-8') as f:
|
with open(self.config_file, "r", encoding="utf-8") as f:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
|
|
||||||
# 检查是否有镜像源配置
|
# 检查是否有镜像源配置
|
||||||
@@ -145,14 +148,14 @@ class GitMirrorConfig:
|
|||||||
# 读取现有配置
|
# 读取现有配置
|
||||||
existing_data = {}
|
existing_data = {}
|
||||||
if self.config_file.exists():
|
if self.config_file.exists():
|
||||||
with open(self.config_file, 'r', encoding='utf-8') as f:
|
with open(self.config_file, "r", encoding="utf-8") as f:
|
||||||
existing_data = json.load(f)
|
existing_data = json.load(f)
|
||||||
|
|
||||||
# 更新镜像源配置
|
# 更新镜像源配置
|
||||||
existing_data["git_mirrors"] = self.mirrors
|
existing_data["git_mirrors"] = self.mirrors
|
||||||
|
|
||||||
# 写入文件
|
# 写入文件
|
||||||
with open(self.config_file, 'w', encoding='utf-8') as f:
|
with open(self.config_file, "w", encoding="utf-8") as f:
|
||||||
json.dump(existing_data, f, indent=2, ensure_ascii=False)
|
json.dump(existing_data, f, indent=2, ensure_ascii=False)
|
||||||
|
|
||||||
logger.debug(f"配置已保存到 {self.config_file}")
|
logger.debug(f"配置已保存到 {self.config_file}")
|
||||||
@@ -182,7 +185,7 @@ class GitMirrorConfig:
|
|||||||
raw_prefix: str,
|
raw_prefix: str,
|
||||||
clone_prefix: str,
|
clone_prefix: str,
|
||||||
enabled: bool = True,
|
enabled: bool = True,
|
||||||
priority: Optional[int] = None
|
priority: Optional[int] = None,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
添加新的镜像源
|
添加新的镜像源
|
||||||
@@ -209,7 +212,7 @@ class GitMirrorConfig:
|
|||||||
"clone_prefix": clone_prefix,
|
"clone_prefix": clone_prefix,
|
||||||
"enabled": enabled,
|
"enabled": enabled,
|
||||||
"priority": priority,
|
"priority": priority,
|
||||||
"created_at": datetime.now().isoformat()
|
"created_at": datetime.now().isoformat(),
|
||||||
}
|
}
|
||||||
|
|
||||||
self.mirrors.append(new_mirror)
|
self.mirrors.append(new_mirror)
|
||||||
@@ -225,7 +228,7 @@ class GitMirrorConfig:
|
|||||||
raw_prefix: Optional[str] = None,
|
raw_prefix: Optional[str] = None,
|
||||||
clone_prefix: Optional[str] = None,
|
clone_prefix: Optional[str] = None,
|
||||||
enabled: Optional[bool] = None,
|
enabled: Optional[bool] = None,
|
||||||
priority: Optional[int] = None
|
priority: Optional[int] = None,
|
||||||
) -> Optional[Dict[str, Any]]:
|
) -> Optional[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
更新镜像源配置
|
更新镜像源配置
|
||||||
@@ -279,12 +282,7 @@ class GitMirrorConfig:
|
|||||||
class GitMirrorService:
|
class GitMirrorService:
|
||||||
"""Git 镜像源服务"""
|
"""Git 镜像源服务"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, max_retries: int = 3, timeout: int = 30, config: Optional[GitMirrorConfig] = None):
|
||||||
self,
|
|
||||||
max_retries: int = 3,
|
|
||||||
timeout: int = 30,
|
|
||||||
config: Optional[GitMirrorConfig] = None
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
初始化 Git 镜像源服务
|
初始化 Git 镜像源服务
|
||||||
|
|
||||||
@@ -323,46 +321,25 @@ class GitMirrorService:
|
|||||||
|
|
||||||
if not git_path:
|
if not git_path:
|
||||||
logger.warning("未找到 Git 可执行文件")
|
logger.warning("未找到 Git 可执行文件")
|
||||||
return {
|
return {"installed": False, "error": "系统中未找到 Git,请先安装 Git"}
|
||||||
"installed": False,
|
|
||||||
"error": "系统中未找到 Git,请先安装 Git"
|
|
||||||
}
|
|
||||||
|
|
||||||
# 获取 Git 版本
|
# 获取 Git 版本
|
||||||
result = subprocess.run(
|
result = subprocess.run(["git", "--version"], capture_output=True, text=True, timeout=5)
|
||||||
["git", "--version"],
|
|
||||||
capture_output=True,
|
|
||||||
text=True,
|
|
||||||
timeout=5
|
|
||||||
)
|
|
||||||
|
|
||||||
if result.returncode == 0:
|
if result.returncode == 0:
|
||||||
version = result.stdout.strip()
|
version = result.stdout.strip()
|
||||||
logger.info(f"检测到 Git: {version} at {git_path}")
|
logger.info(f"检测到 Git: {version} at {git_path}")
|
||||||
return {
|
return {"installed": True, "version": version, "path": git_path}
|
||||||
"installed": True,
|
|
||||||
"version": version,
|
|
||||||
"path": git_path
|
|
||||||
}
|
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Git 命令执行失败: {result.stderr}")
|
logger.warning(f"Git 命令执行失败: {result.stderr}")
|
||||||
return {
|
return {"installed": False, "error": f"Git 命令执行失败: {result.stderr}"}
|
||||||
"installed": False,
|
|
||||||
"error": f"Git 命令执行失败: {result.stderr}"
|
|
||||||
}
|
|
||||||
|
|
||||||
except subprocess.TimeoutExpired:
|
except subprocess.TimeoutExpired:
|
||||||
logger.error("Git 版本检测超时")
|
logger.error("Git 版本检测超时")
|
||||||
return {
|
return {"installed": False, "error": "Git 版本检测超时"}
|
||||||
"installed": False,
|
|
||||||
"error": "Git 版本检测超时"
|
|
||||||
}
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"检测 Git 时发生错误: {e}")
|
logger.error(f"检测 Git 时发生错误: {e}")
|
||||||
return {
|
return {"installed": False, "error": f"检测 Git 时发生错误: {str(e)}"}
|
||||||
"installed": False,
|
|
||||||
"error": f"检测 Git 时发生错误: {str(e)}"
|
|
||||||
}
|
|
||||||
|
|
||||||
async def fetch_raw_file(
|
async def fetch_raw_file(
|
||||||
self,
|
self,
|
||||||
@@ -371,7 +348,7 @@ class GitMirrorService:
|
|||||||
branch: str,
|
branch: str,
|
||||||
file_path: str,
|
file_path: str,
|
||||||
mirror_id: Optional[str] = None,
|
mirror_id: Optional[str] = None,
|
||||||
custom_url: Optional[str] = None
|
custom_url: Optional[str] = None,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
获取 GitHub 仓库的 Raw 文件内容
|
获取 GitHub 仓库的 Raw 文件内容
|
||||||
@@ -403,12 +380,7 @@ class GitMirrorService:
|
|||||||
# 使用指定的镜像源
|
# 使用指定的镜像源
|
||||||
mirror = self.config.get_mirror_by_id(mirror_id)
|
mirror = self.config.get_mirror_by_id(mirror_id)
|
||||||
if not mirror:
|
if not mirror:
|
||||||
return {
|
return {"success": False, "error": f"未找到镜像源: {mirror_id}", "mirror_used": None, "attempts": 0}
|
||||||
"success": False,
|
|
||||||
"error": f"未找到镜像源: {mirror_id}",
|
|
||||||
"mirror_used": None,
|
|
||||||
"attempts": 0
|
|
||||||
}
|
|
||||||
mirrors_to_try = [mirror]
|
mirrors_to_try = [mirror]
|
||||||
else:
|
else:
|
||||||
# 使用所有启用的镜像源
|
# 使用所有启用的镜像源
|
||||||
@@ -427,14 +399,12 @@ class GitMirrorService:
|
|||||||
progress=progress,
|
progress=progress,
|
||||||
message=f"正在尝试镜像源 {index}/{total_mirrors}: {mirror['name']}",
|
message=f"正在尝试镜像源 {index}/{total_mirrors}: {mirror['name']}",
|
||||||
total_plugins=0,
|
total_plugins=0,
|
||||||
loaded_plugins=0
|
loaded_plugins=0,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"推送进度失败: {e}")
|
logger.warning(f"推送进度失败: {e}")
|
||||||
|
|
||||||
result = await self._fetch_raw_from_mirror(
|
result = await self._fetch_raw_from_mirror(owner, repo, branch, file_path, mirror)
|
||||||
owner, repo, branch, file_path, mirror
|
|
||||||
)
|
|
||||||
|
|
||||||
if result["success"]:
|
if result["success"]:
|
||||||
# 成功,推送进度
|
# 成功,推送进度
|
||||||
@@ -445,7 +415,7 @@ class GitMirrorService:
|
|||||||
progress=70,
|
progress=70,
|
||||||
message=f"成功从 {mirror['name']} 获取数据",
|
message=f"成功从 {mirror['name']} 获取数据",
|
||||||
total_plugins=0,
|
total_plugins=0,
|
||||||
loaded_plugins=0
|
loaded_plugins=0,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"推送进度失败: {e}")
|
logger.warning(f"推送进度失败: {e}")
|
||||||
@@ -461,26 +431,16 @@ class GitMirrorService:
|
|||||||
progress=30 + int(index / total_mirrors * 40),
|
progress=30 + int(index / total_mirrors * 40),
|
||||||
message=f"镜像源 {mirror['name']} 失败,尝试下一个...",
|
message=f"镜像源 {mirror['name']} 失败,尝试下一个...",
|
||||||
total_plugins=0,
|
total_plugins=0,
|
||||||
loaded_plugins=0
|
loaded_plugins=0,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"推送进度失败: {e}")
|
logger.warning(f"推送进度失败: {e}")
|
||||||
|
|
||||||
# 所有镜像源都失败
|
# 所有镜像源都失败
|
||||||
return {
|
return {"success": False, "error": "所有镜像源均失败", "mirror_used": None, "attempts": len(mirrors_to_try)}
|
||||||
"success": False,
|
|
||||||
"error": "所有镜像源均失败",
|
|
||||||
"mirror_used": None,
|
|
||||||
"attempts": len(mirrors_to_try)
|
|
||||||
}
|
|
||||||
|
|
||||||
async def _fetch_raw_from_mirror(
|
async def _fetch_raw_from_mirror(
|
||||||
self,
|
self, owner: str, repo: str, branch: str, file_path: str, mirror: Dict[str, Any]
|
||||||
owner: str,
|
|
||||||
repo: str,
|
|
||||||
branch: str,
|
|
||||||
file_path: str,
|
|
||||||
mirror: Dict[str, Any]
|
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""从指定镜像源获取文件"""
|
"""从指定镜像源获取文件"""
|
||||||
# 构建 URL
|
# 构建 URL
|
||||||
@@ -508,7 +468,7 @@ class GitMirrorService:
|
|||||||
"data": response.text,
|
"data": response.text,
|
||||||
"mirror_used": mirror_type,
|
"mirror_used": mirror_type,
|
||||||
"attempts": attempts,
|
"attempts": attempts,
|
||||||
"url": url
|
"url": url,
|
||||||
}
|
}
|
||||||
except httpx.HTTPStatusError as e:
|
except httpx.HTTPStatusError as e:
|
||||||
last_error = f"HTTP {e.response.status_code}: {e}"
|
last_error = f"HTTP {e.response.status_code}: {e}"
|
||||||
@@ -520,13 +480,7 @@ class GitMirrorService:
|
|||||||
last_error = f"未知错误: {e}"
|
last_error = f"未知错误: {e}"
|
||||||
logger.error(f"错误 (尝试 {attempt + 1}/{self.max_retries}): {last_error}")
|
logger.error(f"错误 (尝试 {attempt + 1}/{self.max_retries}): {last_error}")
|
||||||
|
|
||||||
return {
|
return {"success": False, "error": last_error, "mirror_used": mirror_type, "attempts": attempts, "url": url}
|
||||||
"success": False,
|
|
||||||
"error": last_error,
|
|
||||||
"mirror_used": mirror_type,
|
|
||||||
"attempts": attempts,
|
|
||||||
"url": url
|
|
||||||
}
|
|
||||||
|
|
||||||
async def clone_repository(
|
async def clone_repository(
|
||||||
self,
|
self,
|
||||||
@@ -536,7 +490,7 @@ class GitMirrorService:
|
|||||||
branch: Optional[str] = None,
|
branch: Optional[str] = None,
|
||||||
mirror_id: Optional[str] = None,
|
mirror_id: Optional[str] = None,
|
||||||
custom_url: Optional[str] = None,
|
custom_url: Optional[str] = None,
|
||||||
depth: Optional[int] = None
|
depth: Optional[int] = None,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
克隆 GitHub 仓库
|
克隆 GitHub 仓库
|
||||||
@@ -569,12 +523,7 @@ class GitMirrorService:
|
|||||||
# 使用指定的镜像源
|
# 使用指定的镜像源
|
||||||
mirror = self.config.get_mirror_by_id(mirror_id)
|
mirror = self.config.get_mirror_by_id(mirror_id)
|
||||||
if not mirror:
|
if not mirror:
|
||||||
return {
|
return {"success": False, "error": f"未找到镜像源: {mirror_id}", "mirror_used": None, "attempts": 0}
|
||||||
"success": False,
|
|
||||||
"error": f"未找到镜像源: {mirror_id}",
|
|
||||||
"mirror_used": None,
|
|
||||||
"attempts": 0
|
|
||||||
}
|
|
||||||
mirrors_to_try = [mirror]
|
mirrors_to_try = [mirror]
|
||||||
else:
|
else:
|
||||||
# 使用所有启用的镜像源
|
# 使用所有启用的镜像源
|
||||||
@@ -582,20 +531,13 @@ class GitMirrorService:
|
|||||||
|
|
||||||
# 依次尝试每个镜像源
|
# 依次尝试每个镜像源
|
||||||
for mirror in mirrors_to_try:
|
for mirror in mirrors_to_try:
|
||||||
result = await self._clone_from_mirror(
|
result = await self._clone_from_mirror(owner, repo, target_path, branch, depth, mirror)
|
||||||
owner, repo, target_path, branch, depth, mirror
|
|
||||||
)
|
|
||||||
if result["success"]:
|
if result["success"]:
|
||||||
return result
|
return result
|
||||||
logger.warning(f"镜像源 {mirror['id']} 克隆失败: {result.get('error')}")
|
logger.warning(f"镜像源 {mirror['id']} 克隆失败: {result.get('error')}")
|
||||||
|
|
||||||
# 所有镜像源都失败
|
# 所有镜像源都失败
|
||||||
return {
|
return {"success": False, "error": "所有镜像源克隆均失败", "mirror_used": None, "attempts": len(mirrors_to_try)}
|
||||||
"success": False,
|
|
||||||
"error": "所有镜像源克隆均失败",
|
|
||||||
"mirror_used": None,
|
|
||||||
"attempts": len(mirrors_to_try)
|
|
||||||
}
|
|
||||||
|
|
||||||
async def _clone_from_mirror(
|
async def _clone_from_mirror(
|
||||||
self,
|
self,
|
||||||
@@ -604,7 +546,7 @@ class GitMirrorService:
|
|||||||
target_path: Path,
|
target_path: Path,
|
||||||
branch: Optional[str],
|
branch: Optional[str],
|
||||||
depth: Optional[int],
|
depth: Optional[int],
|
||||||
mirror: Dict[str, Any]
|
mirror: Dict[str, Any],
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""从指定镜像源克隆仓库"""
|
"""从指定镜像源克隆仓库"""
|
||||||
# 构建克隆 URL
|
# 构建克隆 URL
|
||||||
@@ -614,12 +556,7 @@ class GitMirrorService:
|
|||||||
return await self._clone_with_url(url, target_path, branch, depth, mirror["id"])
|
return await self._clone_with_url(url, target_path, branch, depth, mirror["id"])
|
||||||
|
|
||||||
async def _clone_with_url(
|
async def _clone_with_url(
|
||||||
self,
|
self, url: str, target_path: Path, branch: Optional[str], depth: Optional[int], mirror_type: str
|
||||||
url: str,
|
|
||||||
target_path: Path,
|
|
||||||
branch: Optional[str],
|
|
||||||
depth: Optional[int],
|
|
||||||
mirror_type: str
|
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""使用指定 URL 克隆仓库,支持重试"""
|
"""使用指定 URL 克隆仓库,支持重试"""
|
||||||
attempts = 0
|
attempts = 0
|
||||||
@@ -657,7 +594,7 @@ class GitMirrorService:
|
|||||||
stage="loading",
|
stage="loading",
|
||||||
progress=20 + attempt * 10,
|
progress=20 + attempt * 10,
|
||||||
message=f"正在克隆仓库 (尝试 {attempt + 1}/{self.max_retries})...",
|
message=f"正在克隆仓库 (尝试 {attempt + 1}/{self.max_retries})...",
|
||||||
operation="install"
|
operation="install",
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"推送进度失败: {e}")
|
logger.warning(f"推送进度失败: {e}")
|
||||||
@@ -670,7 +607,7 @@ class GitMirrorService:
|
|||||||
cmd,
|
cmd,
|
||||||
capture_output=True,
|
capture_output=True,
|
||||||
text=True,
|
text=True,
|
||||||
timeout=300 # 5分钟超时
|
timeout=300, # 5分钟超时
|
||||||
)
|
)
|
||||||
|
|
||||||
process = await loop.run_in_executor(None, run_git_clone)
|
process = await loop.run_in_executor(None, run_git_clone)
|
||||||
@@ -683,7 +620,7 @@ class GitMirrorService:
|
|||||||
"mirror_used": mirror_type,
|
"mirror_used": mirror_type,
|
||||||
"attempts": attempts,
|
"attempts": attempts,
|
||||||
"url": url,
|
"url": url,
|
||||||
"branch": branch or "default"
|
"branch": branch or "default",
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
last_error = f"Git 克隆失败: {process.stderr}"
|
last_error = f"Git 克隆失败: {process.stderr}"
|
||||||
@@ -710,13 +647,7 @@ class GitMirrorService:
|
|||||||
if target_path.exists():
|
if target_path.exists():
|
||||||
shutil.rmtree(target_path, ignore_errors=True)
|
shutil.rmtree(target_path, ignore_errors=True)
|
||||||
|
|
||||||
return {
|
return {"success": False, "error": last_error, "mirror_used": mirror_type, "attempts": attempts, "url": url}
|
||||||
"success": False,
|
|
||||||
"error": last_error,
|
|
||||||
"mirror_used": mirror_type,
|
|
||||||
"attempts": attempts,
|
|
||||||
"url": url
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# 全局服务实例
|
# 全局服务实例
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""WebSocket 日志推送模块"""
|
"""WebSocket 日志推送模块"""
|
||||||
|
|
||||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
||||||
from typing import Set
|
from typing import Set
|
||||||
import json
|
import json
|
||||||
@@ -49,7 +50,9 @@ def load_recent_logs(limit: int = 100) -> list[dict]:
|
|||||||
log_entry = json.loads(line.strip())
|
log_entry = json.loads(line.strip())
|
||||||
# 转换为前端期望的格式
|
# 转换为前端期望的格式
|
||||||
# 使用时间戳 + 计数器生成唯一 ID
|
# 使用时间戳 + 计数器生成唯一 ID
|
||||||
timestamp_id = log_entry.get("timestamp", "0").replace("-", "").replace(" ", "").replace(":", "")
|
timestamp_id = (
|
||||||
|
log_entry.get("timestamp", "0").replace("-", "").replace(" ", "").replace(":", "")
|
||||||
|
)
|
||||||
formatted_log = {
|
formatted_log = {
|
||||||
"id": f"{timestamp_id}_{log_counter}",
|
"id": f"{timestamp_id}_{log_counter}",
|
||||||
"timestamp": log_entry.get("timestamp", ""),
|
"timestamp": log_entry.get("timestamp", ""),
|
||||||
|
|||||||
@@ -1,108 +0,0 @@
|
|||||||
"""WebUI 管理器 - 处理开发/生产环境的 WebUI 启动"""
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
from .token_manager import get_token_manager
|
|
||||||
|
|
||||||
logger = get_logger("webui")
|
|
||||||
|
|
||||||
|
|
||||||
def setup_webui(mode: str = "production") -> bool:
|
|
||||||
"""
|
|
||||||
设置 WebUI
|
|
||||||
|
|
||||||
Args:
|
|
||||||
mode: 运行模式,"development" 或 "production"
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 是否成功设置
|
|
||||||
"""
|
|
||||||
# 初始化 Token 管理器(确保 token 文件存在)
|
|
||||||
token_manager = get_token_manager()
|
|
||||||
current_token = token_manager.get_token()
|
|
||||||
logger.info(f"🔑 WebUI Access Token: {current_token}")
|
|
||||||
logger.info("💡 请使用此 Token 登录 WebUI")
|
|
||||||
|
|
||||||
if mode == "development":
|
|
||||||
return setup_dev_mode()
|
|
||||||
else:
|
|
||||||
return setup_production_mode()
|
|
||||||
|
|
||||||
|
|
||||||
def setup_dev_mode() -> bool:
|
|
||||||
"""设置开发模式 - 仅启用 CORS,前端自行启动"""
|
|
||||||
from src.common.server import get_global_server
|
|
||||||
from .logs_ws import router as logs_router
|
|
||||||
|
|
||||||
# 注册 WebSocket 日志路由(开发模式也需要)
|
|
||||||
server = get_global_server()
|
|
||||||
server.register_router(logs_router)
|
|
||||||
logger.info("✅ WebSocket 日志推送路由已注册")
|
|
||||||
|
|
||||||
logger.info("📝 WebUI 开发模式已启用")
|
|
||||||
logger.info("🌐 请手动启动前端开发服务器: cd webui && npm run dev")
|
|
||||||
logger.info("💡 前端将运行在 http://localhost:7999")
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def setup_production_mode() -> bool:
|
|
||||||
"""设置生产模式 - 挂载静态文件"""
|
|
||||||
try:
|
|
||||||
from src.common.server import get_global_server
|
|
||||||
from starlette.responses import FileResponse
|
|
||||||
from .logs_ws import router as logs_router
|
|
||||||
import mimetypes
|
|
||||||
|
|
||||||
# 确保正确的 MIME 类型映射
|
|
||||||
mimetypes.init()
|
|
||||||
mimetypes.add_type('application/javascript', '.js')
|
|
||||||
mimetypes.add_type('application/javascript', '.mjs')
|
|
||||||
mimetypes.add_type('text/css', '.css')
|
|
||||||
mimetypes.add_type('application/json', '.json')
|
|
||||||
|
|
||||||
server = get_global_server()
|
|
||||||
|
|
||||||
# 注册 WebSocket 日志路由
|
|
||||||
server.register_router(logs_router)
|
|
||||||
logger.info("✅ WebSocket 日志推送路由已注册")
|
|
||||||
|
|
||||||
base_dir = Path(__file__).parent.parent.parent
|
|
||||||
static_path = base_dir / "webui" / "dist"
|
|
||||||
|
|
||||||
if not static_path.exists():
|
|
||||||
logger.warning(f"❌ WebUI 静态文件目录不存在: {static_path}")
|
|
||||||
logger.warning("💡 请先构建前端: cd webui && npm run build")
|
|
||||||
return False
|
|
||||||
|
|
||||||
if not (static_path / "index.html").exists():
|
|
||||||
logger.warning(f"❌ 未找到 index.html: {static_path / 'index.html'}")
|
|
||||||
logger.warning("💡 请确认前端已正确构建")
|
|
||||||
return False
|
|
||||||
|
|
||||||
# 处理 SPA 路由
|
|
||||||
@server.app.get("/{full_path:path}")
|
|
||||||
async def serve_spa(full_path: str):
|
|
||||||
"""服务单页应用"""
|
|
||||||
# API 路由不处理
|
|
||||||
if full_path.startswith("api/"):
|
|
||||||
return None
|
|
||||||
|
|
||||||
# 检查文件是否存在
|
|
||||||
file_path = static_path / full_path
|
|
||||||
if file_path.is_file():
|
|
||||||
# 自动检测 MIME 类型
|
|
||||||
media_type = mimetypes.guess_type(str(file_path))[0]
|
|
||||||
return FileResponse(file_path, media_type=media_type)
|
|
||||||
|
|
||||||
# 返回 index.html(SPA 路由)
|
|
||||||
return FileResponse(static_path / "index.html", media_type="text/html")
|
|
||||||
|
|
||||||
host = os.getenv("HOST", "127.0.0.1")
|
|
||||||
port = os.getenv("PORT", "8000")
|
|
||||||
logger.info("✅ WebUI 生产模式已挂载")
|
|
||||||
logger.info(f"🌐 访问 http://{host}:{port} 查看 WebUI")
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"挂载 WebUI 静态文件失败: {e}")
|
|
||||||
return False
|
|
||||||
@@ -1,4 +1,5 @@
|
|||||||
"""人物信息管理 API 路由"""
|
"""人物信息管理 API 路由"""
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, Header, Query
|
from fastapi import APIRouter, HTTPException, Header, Query
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import Optional, List, Dict
|
from typing import Optional, List, Dict
|
||||||
@@ -16,6 +17,7 @@ router = APIRouter(prefix="/person", tags=["Person"])
|
|||||||
|
|
||||||
class PersonInfoResponse(BaseModel):
|
class PersonInfoResponse(BaseModel):
|
||||||
"""人物信息响应"""
|
"""人物信息响应"""
|
||||||
|
|
||||||
id: int
|
id: int
|
||||||
is_known: bool
|
is_known: bool
|
||||||
person_id: str
|
person_id: str
|
||||||
@@ -33,6 +35,7 @@ class PersonInfoResponse(BaseModel):
|
|||||||
|
|
||||||
class PersonListResponse(BaseModel):
|
class PersonListResponse(BaseModel):
|
||||||
"""人物列表响应"""
|
"""人物列表响应"""
|
||||||
|
|
||||||
success: bool
|
success: bool
|
||||||
total: int
|
total: int
|
||||||
page: int
|
page: int
|
||||||
@@ -42,12 +45,14 @@ class PersonListResponse(BaseModel):
|
|||||||
|
|
||||||
class PersonDetailResponse(BaseModel):
|
class PersonDetailResponse(BaseModel):
|
||||||
"""人物详情响应"""
|
"""人物详情响应"""
|
||||||
|
|
||||||
success: bool
|
success: bool
|
||||||
data: PersonInfoResponse
|
data: PersonInfoResponse
|
||||||
|
|
||||||
|
|
||||||
class PersonUpdateRequest(BaseModel):
|
class PersonUpdateRequest(BaseModel):
|
||||||
"""人物信息更新请求"""
|
"""人物信息更新请求"""
|
||||||
|
|
||||||
person_name: Optional[str] = None
|
person_name: Optional[str] = None
|
||||||
name_reason: Optional[str] = None
|
name_reason: Optional[str] = None
|
||||||
nickname: Optional[str] = None
|
nickname: Optional[str] = None
|
||||||
@@ -57,6 +62,7 @@ class PersonUpdateRequest(BaseModel):
|
|||||||
|
|
||||||
class PersonUpdateResponse(BaseModel):
|
class PersonUpdateResponse(BaseModel):
|
||||||
"""人物信息更新响应"""
|
"""人物信息更新响应"""
|
||||||
|
|
||||||
success: bool
|
success: bool
|
||||||
message: str
|
message: str
|
||||||
data: Optional[PersonInfoResponse] = None
|
data: Optional[PersonInfoResponse] = None
|
||||||
@@ -64,10 +70,27 @@ class PersonUpdateResponse(BaseModel):
|
|||||||
|
|
||||||
class PersonDeleteResponse(BaseModel):
|
class PersonDeleteResponse(BaseModel):
|
||||||
"""人物删除响应"""
|
"""人物删除响应"""
|
||||||
|
|
||||||
success: bool
|
success: bool
|
||||||
message: str
|
message: str
|
||||||
|
|
||||||
|
|
||||||
|
class BatchDeleteRequest(BaseModel):
|
||||||
|
"""批量删除请求"""
|
||||||
|
|
||||||
|
person_ids: List[str]
|
||||||
|
|
||||||
|
|
||||||
|
class BatchDeleteResponse(BaseModel):
|
||||||
|
"""批量删除响应"""
|
||||||
|
|
||||||
|
success: bool
|
||||||
|
message: str
|
||||||
|
deleted_count: int
|
||||||
|
failed_count: int
|
||||||
|
failed_ids: List[str] = []
|
||||||
|
|
||||||
|
|
||||||
def verify_auth_token(authorization: Optional[str]) -> bool:
|
def verify_auth_token(authorization: Optional[str]) -> bool:
|
||||||
"""验证认证 Token"""
|
"""验证认证 Token"""
|
||||||
if not authorization or not authorization.startswith("Bearer "):
|
if not authorization or not authorization.startswith("Bearer "):
|
||||||
@@ -118,7 +141,7 @@ async def get_person_list(
|
|||||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||||
is_known: Optional[bool] = Query(None, description="是否已认识筛选"),
|
is_known: Optional[bool] = Query(None, description="是否已认识筛选"),
|
||||||
platform: Optional[str] = Query(None, description="平台筛选"),
|
platform: Optional[str] = Query(None, description="平台筛选"),
|
||||||
authorization: Optional[str] = Header(None)
|
authorization: Optional[str] = Header(None),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
获取人物信息列表
|
获取人物信息列表
|
||||||
@@ -143,9 +166,9 @@ async def get_person_list(
|
|||||||
# 搜索过滤
|
# 搜索过滤
|
||||||
if search:
|
if search:
|
||||||
query = query.where(
|
query = query.where(
|
||||||
(PersonInfo.person_name.contains(search)) |
|
(PersonInfo.person_name.contains(search))
|
||||||
(PersonInfo.nickname.contains(search)) |
|
| (PersonInfo.nickname.contains(search))
|
||||||
(PersonInfo.user_id.contains(search))
|
| (PersonInfo.user_id.contains(search))
|
||||||
)
|
)
|
||||||
|
|
||||||
# 已认识状态过滤
|
# 已认识状态过滤
|
||||||
@@ -159,10 +182,8 @@ async def get_person_list(
|
|||||||
# 排序:最后更新时间倒序(NULL 值放在最后)
|
# 排序:最后更新时间倒序(NULL 值放在最后)
|
||||||
# Peewee 不支持 nulls_last,使用 CASE WHEN 来实现
|
# Peewee 不支持 nulls_last,使用 CASE WHEN 来实现
|
||||||
from peewee import Case
|
from peewee import Case
|
||||||
query = query.order_by(
|
|
||||||
Case(None, [(PersonInfo.last_know.is_null(), 1)], 0),
|
query = query.order_by(Case(None, [(PersonInfo.last_know.is_null(), 1)], 0), PersonInfo.last_know.desc())
|
||||||
PersonInfo.last_know.desc()
|
|
||||||
)
|
|
||||||
|
|
||||||
# 获取总数
|
# 获取总数
|
||||||
total = query.count()
|
total = query.count()
|
||||||
@@ -174,13 +195,7 @@ async def get_person_list(
|
|||||||
# 转换为响应对象
|
# 转换为响应对象
|
||||||
data = [person_to_response(person) for person in persons]
|
data = [person_to_response(person) for person in persons]
|
||||||
|
|
||||||
return PersonListResponse(
|
return PersonListResponse(success=True, total=total, page=page, page_size=page_size, data=data)
|
||||||
success=True,
|
|
||||||
total=total,
|
|
||||||
page=page,
|
|
||||||
page_size=page_size,
|
|
||||||
data=data
|
|
||||||
)
|
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
@@ -190,10 +205,7 @@ async def get_person_list(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/{person_id}", response_model=PersonDetailResponse)
|
@router.get("/{person_id}", response_model=PersonDetailResponse)
|
||||||
async def get_person_detail(
|
async def get_person_detail(person_id: str, authorization: Optional[str] = Header(None)):
|
||||||
person_id: str,
|
|
||||||
authorization: Optional[str] = Header(None)
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
获取人物详细信息
|
获取人物详细信息
|
||||||
|
|
||||||
@@ -212,10 +224,7 @@ async def get_person_detail(
|
|||||||
if not person:
|
if not person:
|
||||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {person_id} 的人物信息")
|
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {person_id} 的人物信息")
|
||||||
|
|
||||||
return PersonDetailResponse(
|
return PersonDetailResponse(success=True, data=person_to_response(person))
|
||||||
success=True,
|
|
||||||
data=person_to_response(person)
|
|
||||||
)
|
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
@@ -225,11 +234,7 @@ async def get_person_detail(
|
|||||||
|
|
||||||
|
|
||||||
@router.patch("/{person_id}", response_model=PersonUpdateResponse)
|
@router.patch("/{person_id}", response_model=PersonUpdateResponse)
|
||||||
async def update_person(
|
async def update_person(person_id: str, request: PersonUpdateRequest, authorization: Optional[str] = Header(None)):
|
||||||
person_id: str,
|
|
||||||
request: PersonUpdateRequest,
|
|
||||||
authorization: Optional[str] = Header(None)
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
增量更新人物信息(只更新提供的字段)
|
增量更新人物信息(只更新提供的字段)
|
||||||
|
|
||||||
@@ -256,7 +261,7 @@ async def update_person(
|
|||||||
raise HTTPException(status_code=400, detail="未提供任何需要更新的字段")
|
raise HTTPException(status_code=400, detail="未提供任何需要更新的字段")
|
||||||
|
|
||||||
# 更新最后修改时间
|
# 更新最后修改时间
|
||||||
update_data['last_know'] = time.time()
|
update_data["last_know"] = time.time()
|
||||||
|
|
||||||
# 执行更新
|
# 执行更新
|
||||||
for field, value in update_data.items():
|
for field, value in update_data.items():
|
||||||
@@ -267,9 +272,7 @@ async def update_person(
|
|||||||
logger.info(f"人物信息已更新: {person_id}, 字段: {list(update_data.keys())}")
|
logger.info(f"人物信息已更新: {person_id}, 字段: {list(update_data.keys())}")
|
||||||
|
|
||||||
return PersonUpdateResponse(
|
return PersonUpdateResponse(
|
||||||
success=True,
|
success=True, message=f"成功更新 {len(update_data)} 个字段", data=person_to_response(person)
|
||||||
message=f"成功更新 {len(update_data)} 个字段",
|
|
||||||
data=person_to_response(person)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
@@ -280,10 +283,7 @@ async def update_person(
|
|||||||
|
|
||||||
|
|
||||||
@router.delete("/{person_id}", response_model=PersonDeleteResponse)
|
@router.delete("/{person_id}", response_model=PersonDeleteResponse)
|
||||||
async def delete_person(
|
async def delete_person(person_id: str, authorization: Optional[str] = Header(None)):
|
||||||
person_id: str,
|
|
||||||
authorization: Optional[str] = Header(None)
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
删除人物信息
|
删除人物信息
|
||||||
|
|
||||||
@@ -310,10 +310,7 @@ async def delete_person(
|
|||||||
|
|
||||||
logger.info(f"人物信息已删除: {person_id} ({person_name})")
|
logger.info(f"人物信息已删除: {person_id} ({person_name})")
|
||||||
|
|
||||||
return PersonDeleteResponse(
|
return PersonDeleteResponse(success=True, message=f"成功删除人物信息: {person_name}")
|
||||||
success=True,
|
|
||||||
message=f"成功删除人物信息: {person_name}"
|
|
||||||
)
|
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
@@ -323,9 +320,7 @@ async def delete_person(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/stats/summary")
|
@router.get("/stats/summary")
|
||||||
async def get_person_stats(
|
async def get_person_stats(authorization: Optional[str] = Header(None)):
|
||||||
authorization: Optional[str] = Header(None)
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
获取人物信息统计数据
|
获取人物信息统计数据
|
||||||
|
|
||||||
@@ -348,18 +343,66 @@ async def get_person_stats(
|
|||||||
platform = person.platform
|
platform = person.platform
|
||||||
platforms[platform] = platforms.get(platform, 0) + 1
|
platforms[platform] = platforms.get(platform, 0) + 1
|
||||||
|
|
||||||
return {
|
return {"success": True, "data": {"total": total, "known": known, "unknown": unknown, "platforms": platforms}}
|
||||||
"success": True,
|
|
||||||
"data": {
|
|
||||||
"total": total,
|
|
||||||
"known": known,
|
|
||||||
"unknown": unknown,
|
|
||||||
"platforms": platforms
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception(f"获取统计数据失败: {e}")
|
logger.exception(f"获取统计数据失败: {e}")
|
||||||
raise HTTPException(status_code=500, detail=f"获取统计数据失败: {str(e)}") from e
|
raise HTTPException(status_code=500, detail=f"获取统计数据失败: {str(e)}") from e
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/batch/delete", response_model=BatchDeleteResponse)
|
||||||
|
async def batch_delete_persons(request: BatchDeleteRequest, authorization: Optional[str] = Header(None)):
|
||||||
|
"""
|
||||||
|
批量删除人物信息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: 包含person_ids列表的请求
|
||||||
|
authorization: Authorization header
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
批量删除结果
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
verify_auth_token(authorization)
|
||||||
|
|
||||||
|
if not request.person_ids:
|
||||||
|
raise HTTPException(status_code=400, detail="未提供要删除的人物ID")
|
||||||
|
|
||||||
|
deleted_count = 0
|
||||||
|
failed_count = 0
|
||||||
|
failed_ids = []
|
||||||
|
|
||||||
|
for person_id in request.person_ids:
|
||||||
|
try:
|
||||||
|
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
|
||||||
|
if person:
|
||||||
|
person.delete_instance()
|
||||||
|
deleted_count += 1
|
||||||
|
logger.info(f"批量删除: {person_id}")
|
||||||
|
else:
|
||||||
|
failed_count += 1
|
||||||
|
failed_ids.append(person_id)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"删除 {person_id} 失败: {e}")
|
||||||
|
failed_count += 1
|
||||||
|
failed_ids.append(person_id)
|
||||||
|
|
||||||
|
message = f"成功删除 {deleted_count} 个人物"
|
||||||
|
if failed_count > 0:
|
||||||
|
message += f",{failed_count} 个失败"
|
||||||
|
|
||||||
|
return BatchDeleteResponse(
|
||||||
|
success=True,
|
||||||
|
message=message,
|
||||||
|
deleted_count=deleted_count,
|
||||||
|
failed_count=failed_count,
|
||||||
|
failed_ids=failed_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"批量删除人物信息失败: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"批量删除失败: {str(e)}") from e
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""WebSocket 插件加载进度推送模块"""
|
"""WebSocket 插件加载进度推送模块"""
|
||||||
|
|
||||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
||||||
from typing import Set, Dict, Any
|
from typing import Set, Dict, Any
|
||||||
import json
|
import json
|
||||||
@@ -22,7 +23,7 @@ current_progress: Dict[str, Any] = {
|
|||||||
"error": None,
|
"error": None,
|
||||||
"plugin_id": None, # 当前操作的插件 ID
|
"plugin_id": None, # 当前操作的插件 ID
|
||||||
"total_plugins": 0,
|
"total_plugins": 0,
|
||||||
"loaded_plugins": 0
|
"loaded_plugins": 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -57,7 +58,7 @@ async def update_progress(
|
|||||||
error: str = None,
|
error: str = None,
|
||||||
plugin_id: str = None,
|
plugin_id: str = None,
|
||||||
total_plugins: int = 0,
|
total_plugins: int = 0,
|
||||||
loaded_plugins: int = 0
|
loaded_plugins: int = 0,
|
||||||
):
|
):
|
||||||
"""更新并广播进度
|
"""更新并广播进度
|
||||||
|
|
||||||
@@ -80,7 +81,7 @@ async def update_progress(
|
|||||||
"plugin_id": plugin_id,
|
"plugin_id": plugin_id,
|
||||||
"total_plugins": total_plugins,
|
"total_plugins": total_plugins,
|
||||||
"loaded_plugins": loaded_plugins,
|
"loaded_plugins": loaded_plugins,
|
||||||
"timestamp": asyncio.get_event_loop().time()
|
"timestamp": asyncio.get_event_loop().time(),
|
||||||
}
|
}
|
||||||
|
|
||||||
await broadcast_progress(progress_data)
|
await broadcast_progress(progress_data)
|
||||||
|
|||||||
@@ -30,12 +30,12 @@ def parse_version(version_str: str) -> tuple[int, int, int]:
|
|||||||
(major, minor, patch) 三元组
|
(major, minor, patch) 三元组
|
||||||
"""
|
"""
|
||||||
# 移除 snapshot 等后缀
|
# 移除 snapshot 等后缀
|
||||||
base_version = version_str.split('.snapshot')[0].split('.dev')[0].split('.alpha')[0].split('.beta')[0]
|
base_version = version_str.split(".snapshot")[0].split(".dev")[0].split(".alpha")[0].split(".beta")[0]
|
||||||
|
|
||||||
parts = base_version.split('.')
|
parts = base_version.split(".")
|
||||||
if len(parts) < 3:
|
if len(parts) < 3:
|
||||||
# 补齐到 3 位
|
# 补齐到 3 位
|
||||||
parts.extend(['0'] * (3 - len(parts)))
|
parts.extend(["0"] * (3 - len(parts)))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
major = int(parts[0])
|
major = int(parts[0])
|
||||||
@@ -49,8 +49,10 @@ def parse_version(version_str: str) -> tuple[int, int, int]:
|
|||||||
|
|
||||||
# ============ 请求/响应模型 ============
|
# ============ 请求/响应模型 ============
|
||||||
|
|
||||||
|
|
||||||
class FetchRawFileRequest(BaseModel):
|
class FetchRawFileRequest(BaseModel):
|
||||||
"""获取 Raw 文件请求"""
|
"""获取 Raw 文件请求"""
|
||||||
|
|
||||||
owner: str = Field(..., description="仓库所有者", example="MaiM-with-u")
|
owner: str = Field(..., description="仓库所有者", example="MaiM-with-u")
|
||||||
repo: str = Field(..., description="仓库名称", example="plugin-repo")
|
repo: str = Field(..., description="仓库名称", example="plugin-repo")
|
||||||
branch: str = Field(..., description="分支名称", example="main")
|
branch: str = Field(..., description="分支名称", example="main")
|
||||||
@@ -61,6 +63,7 @@ class FetchRawFileRequest(BaseModel):
|
|||||||
|
|
||||||
class FetchRawFileResponse(BaseModel):
|
class FetchRawFileResponse(BaseModel):
|
||||||
"""获取 Raw 文件响应"""
|
"""获取 Raw 文件响应"""
|
||||||
|
|
||||||
success: bool = Field(..., description="是否成功")
|
success: bool = Field(..., description="是否成功")
|
||||||
data: Optional[str] = Field(None, description="文件内容")
|
data: Optional[str] = Field(None, description="文件内容")
|
||||||
error: Optional[str] = Field(None, description="错误信息")
|
error: Optional[str] = Field(None, description="错误信息")
|
||||||
@@ -71,6 +74,7 @@ class FetchRawFileResponse(BaseModel):
|
|||||||
|
|
||||||
class CloneRepositoryRequest(BaseModel):
|
class CloneRepositoryRequest(BaseModel):
|
||||||
"""克隆仓库请求"""
|
"""克隆仓库请求"""
|
||||||
|
|
||||||
owner: str = Field(..., description="仓库所有者", example="MaiM-with-u")
|
owner: str = Field(..., description="仓库所有者", example="MaiM-with-u")
|
||||||
repo: str = Field(..., description="仓库名称", example="plugin-repo")
|
repo: str = Field(..., description="仓库名称", example="plugin-repo")
|
||||||
target_path: str = Field(..., description="目标路径(相对于插件目录)")
|
target_path: str = Field(..., description="目标路径(相对于插件目录)")
|
||||||
@@ -82,6 +86,7 @@ class CloneRepositoryRequest(BaseModel):
|
|||||||
|
|
||||||
class CloneRepositoryResponse(BaseModel):
|
class CloneRepositoryResponse(BaseModel):
|
||||||
"""克隆仓库响应"""
|
"""克隆仓库响应"""
|
||||||
|
|
||||||
success: bool = Field(..., description="是否成功")
|
success: bool = Field(..., description="是否成功")
|
||||||
path: Optional[str] = Field(None, description="克隆路径")
|
path: Optional[str] = Field(None, description="克隆路径")
|
||||||
error: Optional[str] = Field(None, description="错误信息")
|
error: Optional[str] = Field(None, description="错误信息")
|
||||||
@@ -93,6 +98,7 @@ class CloneRepositoryResponse(BaseModel):
|
|||||||
|
|
||||||
class MirrorConfigResponse(BaseModel):
|
class MirrorConfigResponse(BaseModel):
|
||||||
"""镜像源配置响应"""
|
"""镜像源配置响应"""
|
||||||
|
|
||||||
id: str = Field(..., description="镜像源 ID")
|
id: str = Field(..., description="镜像源 ID")
|
||||||
name: str = Field(..., description="镜像源名称")
|
name: str = Field(..., description="镜像源名称")
|
||||||
raw_prefix: str = Field(..., description="Raw 文件前缀")
|
raw_prefix: str = Field(..., description="Raw 文件前缀")
|
||||||
@@ -103,12 +109,14 @@ class MirrorConfigResponse(BaseModel):
|
|||||||
|
|
||||||
class AvailableMirrorsResponse(BaseModel):
|
class AvailableMirrorsResponse(BaseModel):
|
||||||
"""可用镜像源列表响应"""
|
"""可用镜像源列表响应"""
|
||||||
|
|
||||||
mirrors: List[MirrorConfigResponse] = Field(..., description="镜像源列表")
|
mirrors: List[MirrorConfigResponse] = Field(..., description="镜像源列表")
|
||||||
default_priority: List[str] = Field(..., description="默认优先级顺序(ID 列表)")
|
default_priority: List[str] = Field(..., description="默认优先级顺序(ID 列表)")
|
||||||
|
|
||||||
|
|
||||||
class AddMirrorRequest(BaseModel):
|
class AddMirrorRequest(BaseModel):
|
||||||
"""添加镜像源请求"""
|
"""添加镜像源请求"""
|
||||||
|
|
||||||
id: str = Field(..., description="镜像源 ID", example="custom-mirror")
|
id: str = Field(..., description="镜像源 ID", example="custom-mirror")
|
||||||
name: str = Field(..., description="镜像源名称", example="自定义镜像源")
|
name: str = Field(..., description="镜像源名称", example="自定义镜像源")
|
||||||
raw_prefix: str = Field(..., description="Raw 文件前缀", example="https://example.com/raw")
|
raw_prefix: str = Field(..., description="Raw 文件前缀", example="https://example.com/raw")
|
||||||
@@ -119,6 +127,7 @@ class AddMirrorRequest(BaseModel):
|
|||||||
|
|
||||||
class UpdateMirrorRequest(BaseModel):
|
class UpdateMirrorRequest(BaseModel):
|
||||||
"""更新镜像源请求"""
|
"""更新镜像源请求"""
|
||||||
|
|
||||||
name: Optional[str] = Field(None, description="镜像源名称")
|
name: Optional[str] = Field(None, description="镜像源名称")
|
||||||
raw_prefix: Optional[str] = Field(None, description="Raw 文件前缀")
|
raw_prefix: Optional[str] = Field(None, description="Raw 文件前缀")
|
||||||
clone_prefix: Optional[str] = Field(None, description="克隆前缀")
|
clone_prefix: Optional[str] = Field(None, description="克隆前缀")
|
||||||
@@ -128,6 +137,7 @@ class UpdateMirrorRequest(BaseModel):
|
|||||||
|
|
||||||
class GitStatusResponse(BaseModel):
|
class GitStatusResponse(BaseModel):
|
||||||
"""Git 安装状态响应"""
|
"""Git 安装状态响应"""
|
||||||
|
|
||||||
installed: bool = Field(..., description="是否已安装 Git")
|
installed: bool = Field(..., description="是否已安装 Git")
|
||||||
version: Optional[str] = Field(None, description="Git 版本号")
|
version: Optional[str] = Field(None, description="Git 版本号")
|
||||||
path: Optional[str] = Field(None, description="Git 可执行文件路径")
|
path: Optional[str] = Field(None, description="Git 可执行文件路径")
|
||||||
@@ -136,6 +146,7 @@ class GitStatusResponse(BaseModel):
|
|||||||
|
|
||||||
class InstallPluginRequest(BaseModel):
|
class InstallPluginRequest(BaseModel):
|
||||||
"""安装插件请求"""
|
"""安装插件请求"""
|
||||||
|
|
||||||
plugin_id: str = Field(..., description="插件 ID")
|
plugin_id: str = Field(..., description="插件 ID")
|
||||||
repository_url: str = Field(..., description="插件仓库 URL")
|
repository_url: str = Field(..., description="插件仓库 URL")
|
||||||
branch: Optional[str] = Field("main", description="分支名称")
|
branch: Optional[str] = Field("main", description="分支名称")
|
||||||
@@ -144,6 +155,7 @@ class InstallPluginRequest(BaseModel):
|
|||||||
|
|
||||||
class VersionResponse(BaseModel):
|
class VersionResponse(BaseModel):
|
||||||
"""麦麦版本响应"""
|
"""麦麦版本响应"""
|
||||||
|
|
||||||
version: str = Field(..., description="麦麦版本号")
|
version: str = Field(..., description="麦麦版本号")
|
||||||
version_major: int = Field(..., description="主版本号")
|
version_major: int = Field(..., description="主版本号")
|
||||||
version_minor: int = Field(..., description="次版本号")
|
version_minor: int = Field(..., description="次版本号")
|
||||||
@@ -152,11 +164,13 @@ class VersionResponse(BaseModel):
|
|||||||
|
|
||||||
class UninstallPluginRequest(BaseModel):
|
class UninstallPluginRequest(BaseModel):
|
||||||
"""卸载插件请求"""
|
"""卸载插件请求"""
|
||||||
|
|
||||||
plugin_id: str = Field(..., description="插件 ID")
|
plugin_id: str = Field(..., description="插件 ID")
|
||||||
|
|
||||||
|
|
||||||
class UpdatePluginRequest(BaseModel):
|
class UpdatePluginRequest(BaseModel):
|
||||||
"""更新插件请求"""
|
"""更新插件请求"""
|
||||||
|
|
||||||
plugin_id: str = Field(..., description="插件 ID")
|
plugin_id: str = Field(..., description="插件 ID")
|
||||||
repository_url: str = Field(..., description="插件仓库 URL")
|
repository_url: str = Field(..., description="插件仓库 URL")
|
||||||
branch: Optional[str] = Field("main", description="分支名称")
|
branch: Optional[str] = Field("main", description="分支名称")
|
||||||
@@ -165,6 +179,7 @@ class UpdatePluginRequest(BaseModel):
|
|||||||
|
|
||||||
# ============ API 路由 ============
|
# ============ API 路由 ============
|
||||||
|
|
||||||
|
|
||||||
@router.get("/version", response_model=VersionResponse)
|
@router.get("/version", response_model=VersionResponse)
|
||||||
async def get_maimai_version() -> VersionResponse:
|
async def get_maimai_version() -> VersionResponse:
|
||||||
"""
|
"""
|
||||||
@@ -174,12 +189,7 @@ async def get_maimai_version() -> VersionResponse:
|
|||||||
"""
|
"""
|
||||||
major, minor, patch = parse_version(MMC_VERSION)
|
major, minor, patch = parse_version(MMC_VERSION)
|
||||||
|
|
||||||
return VersionResponse(
|
return VersionResponse(version=MMC_VERSION, version_major=major, version_minor=minor, version_patch=patch)
|
||||||
version=MMC_VERSION,
|
|
||||||
version_major=major,
|
|
||||||
version_minor=minor,
|
|
||||||
version_patch=patch
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/git-status", response_model=GitStatusResponse)
|
@router.get("/git-status", response_model=GitStatusResponse)
|
||||||
@@ -196,9 +206,7 @@ async def check_git_status() -> GitStatusResponse:
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/mirrors", response_model=AvailableMirrorsResponse)
|
@router.get("/mirrors", response_model=AvailableMirrorsResponse)
|
||||||
async def get_available_mirrors(
|
async def get_available_mirrors(authorization: Optional[str] = Header(None)) -> AvailableMirrorsResponse:
|
||||||
authorization: Optional[str] = Header(None)
|
|
||||||
) -> AvailableMirrorsResponse:
|
|
||||||
"""
|
"""
|
||||||
获取所有可用的镜像源配置
|
获取所有可用的镜像源配置
|
||||||
"""
|
"""
|
||||||
@@ -219,22 +227,16 @@ async def get_available_mirrors(
|
|||||||
raw_prefix=m["raw_prefix"],
|
raw_prefix=m["raw_prefix"],
|
||||||
clone_prefix=m["clone_prefix"],
|
clone_prefix=m["clone_prefix"],
|
||||||
enabled=m["enabled"],
|
enabled=m["enabled"],
|
||||||
priority=m["priority"]
|
priority=m["priority"],
|
||||||
)
|
)
|
||||||
for m in all_mirrors
|
for m in all_mirrors
|
||||||
]
|
]
|
||||||
|
|
||||||
return AvailableMirrorsResponse(
|
return AvailableMirrorsResponse(mirrors=mirrors, default_priority=config.get_default_priority_list())
|
||||||
mirrors=mirrors,
|
|
||||||
default_priority=config.get_default_priority_list()
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/mirrors", response_model=MirrorConfigResponse)
|
@router.post("/mirrors", response_model=MirrorConfigResponse)
|
||||||
async def add_mirror(
|
async def add_mirror(request: AddMirrorRequest, authorization: Optional[str] = Header(None)) -> MirrorConfigResponse:
|
||||||
request: AddMirrorRequest,
|
|
||||||
authorization: Optional[str] = Header(None)
|
|
||||||
) -> MirrorConfigResponse:
|
|
||||||
"""
|
"""
|
||||||
添加新的镜像源
|
添加新的镜像源
|
||||||
"""
|
"""
|
||||||
@@ -254,7 +256,7 @@ async def add_mirror(
|
|||||||
raw_prefix=request.raw_prefix,
|
raw_prefix=request.raw_prefix,
|
||||||
clone_prefix=request.clone_prefix,
|
clone_prefix=request.clone_prefix,
|
||||||
enabled=request.enabled,
|
enabled=request.enabled,
|
||||||
priority=request.priority
|
priority=request.priority,
|
||||||
)
|
)
|
||||||
|
|
||||||
return MirrorConfigResponse(
|
return MirrorConfigResponse(
|
||||||
@@ -263,7 +265,7 @@ async def add_mirror(
|
|||||||
raw_prefix=mirror["raw_prefix"],
|
raw_prefix=mirror["raw_prefix"],
|
||||||
clone_prefix=mirror["clone_prefix"],
|
clone_prefix=mirror["clone_prefix"],
|
||||||
enabled=mirror["enabled"],
|
enabled=mirror["enabled"],
|
||||||
priority=mirror["priority"]
|
priority=mirror["priority"],
|
||||||
)
|
)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise HTTPException(status_code=400, detail=str(e)) from e
|
raise HTTPException(status_code=400, detail=str(e)) from e
|
||||||
@@ -274,9 +276,7 @@ async def add_mirror(
|
|||||||
|
|
||||||
@router.put("/mirrors/{mirror_id}", response_model=MirrorConfigResponse)
|
@router.put("/mirrors/{mirror_id}", response_model=MirrorConfigResponse)
|
||||||
async def update_mirror(
|
async def update_mirror(
|
||||||
mirror_id: str,
|
mirror_id: str, request: UpdateMirrorRequest, authorization: Optional[str] = Header(None)
|
||||||
request: UpdateMirrorRequest,
|
|
||||||
authorization: Optional[str] = Header(None)
|
|
||||||
) -> MirrorConfigResponse:
|
) -> MirrorConfigResponse:
|
||||||
"""
|
"""
|
||||||
更新镜像源配置
|
更新镜像源配置
|
||||||
@@ -297,7 +297,7 @@ async def update_mirror(
|
|||||||
raw_prefix=request.raw_prefix,
|
raw_prefix=request.raw_prefix,
|
||||||
clone_prefix=request.clone_prefix,
|
clone_prefix=request.clone_prefix,
|
||||||
enabled=request.enabled,
|
enabled=request.enabled,
|
||||||
priority=request.priority
|
priority=request.priority,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not mirror:
|
if not mirror:
|
||||||
@@ -309,7 +309,7 @@ async def update_mirror(
|
|||||||
raw_prefix=mirror["raw_prefix"],
|
raw_prefix=mirror["raw_prefix"],
|
||||||
clone_prefix=mirror["clone_prefix"],
|
clone_prefix=mirror["clone_prefix"],
|
||||||
enabled=mirror["enabled"],
|
enabled=mirror["enabled"],
|
||||||
priority=mirror["priority"]
|
priority=mirror["priority"],
|
||||||
)
|
)
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
@@ -319,10 +319,7 @@ async def update_mirror(
|
|||||||
|
|
||||||
|
|
||||||
@router.delete("/mirrors/{mirror_id}")
|
@router.delete("/mirrors/{mirror_id}")
|
||||||
async def delete_mirror(
|
async def delete_mirror(mirror_id: str, authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
||||||
mirror_id: str,
|
|
||||||
authorization: Optional[str] = Header(None)
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""
|
"""
|
||||||
删除镜像源
|
删除镜像源
|
||||||
"""
|
"""
|
||||||
@@ -340,16 +337,12 @@ async def delete_mirror(
|
|||||||
if not success:
|
if not success:
|
||||||
raise HTTPException(status_code=404, detail=f"未找到镜像源: {mirror_id}")
|
raise HTTPException(status_code=404, detail=f"未找到镜像源: {mirror_id}")
|
||||||
|
|
||||||
return {
|
return {"success": True, "message": f"已删除镜像源: {mirror_id}"}
|
||||||
"success": True,
|
|
||||||
"message": f"已删除镜像源: {mirror_id}"
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/fetch-raw", response_model=FetchRawFileResponse)
|
@router.post("/fetch-raw", response_model=FetchRawFileResponse)
|
||||||
async def fetch_raw_file(
|
async def fetch_raw_file(
|
||||||
request: FetchRawFileRequest,
|
request: FetchRawFileRequest, authorization: Optional[str] = Header(None)
|
||||||
authorization: Optional[str] = Header(None)
|
|
||||||
) -> FetchRawFileResponse:
|
) -> FetchRawFileResponse:
|
||||||
"""
|
"""
|
||||||
获取 GitHub 仓库的 Raw 文件内容
|
获取 GitHub 仓库的 Raw 文件内容
|
||||||
@@ -376,7 +369,7 @@ async def fetch_raw_file(
|
|||||||
progress=10,
|
progress=10,
|
||||||
message=f"正在获取插件列表: {request.file_path}",
|
message=f"正在获取插件列表: {request.file_path}",
|
||||||
total_plugins=0,
|
total_plugins=0,
|
||||||
loaded_plugins=0
|
loaded_plugins=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -389,22 +382,19 @@ async def fetch_raw_file(
|
|||||||
branch=request.branch,
|
branch=request.branch,
|
||||||
file_path=request.file_path,
|
file_path=request.file_path,
|
||||||
mirror_id=request.mirror_id,
|
mirror_id=request.mirror_id,
|
||||||
custom_url=request.custom_url
|
custom_url=request.custom_url,
|
||||||
)
|
)
|
||||||
|
|
||||||
if result.get("success"):
|
if result.get("success"):
|
||||||
# 更新进度:成功获取
|
# 更新进度:成功获取
|
||||||
await update_progress(
|
await update_progress(
|
||||||
stage="loading",
|
stage="loading", progress=70, message="正在解析插件数据...", total_plugins=0, loaded_plugins=0
|
||||||
progress=70,
|
|
||||||
message="正在解析插件数据...",
|
|
||||||
total_plugins=0,
|
|
||||||
loaded_plugins=0
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 尝试解析插件数量
|
# 尝试解析插件数量
|
||||||
try:
|
try:
|
||||||
import json
|
import json
|
||||||
|
|
||||||
data = json.loads(result.get("data", "[]"))
|
data = json.loads(result.get("data", "[]"))
|
||||||
total = len(data) if isinstance(data, list) else 0
|
total = len(data) if isinstance(data, list) else 0
|
||||||
|
|
||||||
@@ -414,16 +404,12 @@ async def fetch_raw_file(
|
|||||||
progress=100,
|
progress=100,
|
||||||
message=f"成功加载 {total} 个插件",
|
message=f"成功加载 {total} 个插件",
|
||||||
total_plugins=total,
|
total_plugins=total,
|
||||||
loaded_plugins=total
|
loaded_plugins=total,
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
# 如果解析失败,仍然发送成功状态
|
# 如果解析失败,仍然发送成功状态
|
||||||
await update_progress(
|
await update_progress(
|
||||||
stage="success",
|
stage="success", progress=100, message="加载完成", total_plugins=0, loaded_plugins=0
|
||||||
progress=100,
|
|
||||||
message="加载完成",
|
|
||||||
total_plugins=0,
|
|
||||||
loaded_plugins=0
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return FetchRawFileResponse(**result)
|
return FetchRawFileResponse(**result)
|
||||||
@@ -433,12 +419,7 @@ async def fetch_raw_file(
|
|||||||
|
|
||||||
# 发送错误进度
|
# 发送错误进度
|
||||||
await update_progress(
|
await update_progress(
|
||||||
stage="error",
|
stage="error", progress=0, message="加载失败", error=str(e), total_plugins=0, loaded_plugins=0
|
||||||
progress=0,
|
|
||||||
message="加载失败",
|
|
||||||
error=str(e),
|
|
||||||
total_plugins=0,
|
|
||||||
loaded_plugins=0
|
|
||||||
)
|
)
|
||||||
|
|
||||||
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||||
@@ -446,8 +427,7 @@ async def fetch_raw_file(
|
|||||||
|
|
||||||
@router.post("/clone", response_model=CloneRepositoryResponse)
|
@router.post("/clone", response_model=CloneRepositoryResponse)
|
||||||
async def clone_repository(
|
async def clone_repository(
|
||||||
request: CloneRepositoryRequest,
|
request: CloneRepositoryRequest, authorization: Optional[str] = Header(None)
|
||||||
authorization: Optional[str] = Header(None)
|
|
||||||
) -> CloneRepositoryResponse:
|
) -> CloneRepositoryResponse:
|
||||||
"""
|
"""
|
||||||
克隆 GitHub 仓库到本地
|
克隆 GitHub 仓库到本地
|
||||||
@@ -460,9 +440,7 @@ async def clone_repository(
|
|||||||
if not token or not token_manager.verify_token(token):
|
if not token or not token_manager.verify_token(token):
|
||||||
raise HTTPException(status_code=401, detail="未授权:无效的访问令牌")
|
raise HTTPException(status_code=401, detail="未授权:无效的访问令牌")
|
||||||
|
|
||||||
logger.info(
|
logger.info(f"收到克隆仓库请求: {request.owner}/{request.repo} -> {request.target_path}")
|
||||||
f"收到克隆仓库请求: {request.owner}/{request.repo} -> {request.target_path}"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# TODO: 验证 target_path 的安全性,防止路径遍历攻击
|
# TODO: 验证 target_path 的安全性,防止路径遍历攻击
|
||||||
@@ -478,7 +456,7 @@ async def clone_repository(
|
|||||||
branch=request.branch,
|
branch=request.branch,
|
||||||
mirror_id=request.mirror_id,
|
mirror_id=request.mirror_id,
|
||||||
custom_url=request.custom_url,
|
custom_url=request.custom_url,
|
||||||
depth=request.depth
|
depth=request.depth,
|
||||||
)
|
)
|
||||||
|
|
||||||
return CloneRepositoryResponse(**result)
|
return CloneRepositoryResponse(**result)
|
||||||
@@ -489,10 +467,7 @@ async def clone_repository(
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/install")
|
@router.post("/install")
|
||||||
async def install_plugin(
|
async def install_plugin(request: InstallPluginRequest, authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
||||||
request: InstallPluginRequest,
|
|
||||||
authorization: Optional[str] = Header(None)
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""
|
"""
|
||||||
安装插件
|
安装插件
|
||||||
|
|
||||||
@@ -513,16 +488,16 @@ async def install_plugin(
|
|||||||
progress=5,
|
progress=5,
|
||||||
message=f"开始安装插件: {request.plugin_id}",
|
message=f"开始安装插件: {request.plugin_id}",
|
||||||
operation="install",
|
operation="install",
|
||||||
plugin_id=request.plugin_id
|
plugin_id=request.plugin_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 1. 解析仓库 URL
|
# 1. 解析仓库 URL
|
||||||
# repository_url 格式: https://github.com/owner/repo
|
# repository_url 格式: https://github.com/owner/repo
|
||||||
repo_url = request.repository_url.rstrip('/')
|
repo_url = request.repository_url.rstrip("/")
|
||||||
if repo_url.endswith('.git'):
|
if repo_url.endswith(".git"):
|
||||||
repo_url = repo_url[:-4]
|
repo_url = repo_url[:-4]
|
||||||
|
|
||||||
parts = repo_url.split('/')
|
parts = repo_url.split("/")
|
||||||
if len(parts) < 2:
|
if len(parts) < 2:
|
||||||
raise HTTPException(status_code=400, detail="无效的仓库 URL")
|
raise HTTPException(status_code=400, detail="无效的仓库 URL")
|
||||||
|
|
||||||
@@ -534,7 +509,7 @@ async def install_plugin(
|
|||||||
progress=10,
|
progress=10,
|
||||||
message=f"解析仓库信息: {owner}/{repo}",
|
message=f"解析仓库信息: {owner}/{repo}",
|
||||||
operation="install",
|
operation="install",
|
||||||
plugin_id=request.plugin_id
|
plugin_id=request.plugin_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 2. 确定插件安装路径
|
# 2. 确定插件安装路径
|
||||||
@@ -548,10 +523,10 @@ async def install_plugin(
|
|||||||
await update_progress(
|
await update_progress(
|
||||||
stage="error",
|
stage="error",
|
||||||
progress=0,
|
progress=0,
|
||||||
message=f"插件已存在",
|
message="插件已存在",
|
||||||
operation="install",
|
operation="install",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=request.plugin_id,
|
||||||
error="插件已安装,请先卸载"
|
error="插件已安装,请先卸载",
|
||||||
)
|
)
|
||||||
raise HTTPException(status_code=400, detail="插件已安装")
|
raise HTTPException(status_code=400, detail="插件已安装")
|
||||||
|
|
||||||
@@ -560,31 +535,26 @@ async def install_plugin(
|
|||||||
progress=15,
|
progress=15,
|
||||||
message=f"准备克隆到: {target_path}",
|
message=f"准备克隆到: {target_path}",
|
||||||
operation="install",
|
operation="install",
|
||||||
plugin_id=request.plugin_id
|
plugin_id=request.plugin_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 3. 克隆仓库(这里会自动推送 20%-80% 的进度)
|
# 3. 克隆仓库(这里会自动推送 20%-80% 的进度)
|
||||||
service = get_git_mirror_service()
|
service = get_git_mirror_service()
|
||||||
|
|
||||||
# 如果是 GitHub 仓库,使用镜像源
|
# 如果是 GitHub 仓库,使用镜像源
|
||||||
if 'github.com' in repo_url:
|
if "github.com" in repo_url:
|
||||||
result = await service.clone_repository(
|
result = await service.clone_repository(
|
||||||
owner=owner,
|
owner=owner,
|
||||||
repo=repo,
|
repo=repo,
|
||||||
target_path=target_path,
|
target_path=target_path,
|
||||||
branch=request.branch,
|
branch=request.branch,
|
||||||
mirror_id=request.mirror_id,
|
mirror_id=request.mirror_id,
|
||||||
depth=1 # 浅克隆,节省时间和空间
|
depth=1, # 浅克隆,节省时间和空间
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# 自定义仓库,直接使用 URL
|
# 自定义仓库,直接使用 URL
|
||||||
result = await service.clone_repository(
|
result = await service.clone_repository(
|
||||||
owner=owner,
|
owner=owner, repo=repo, target_path=target_path, branch=request.branch, custom_url=repo_url, depth=1
|
||||||
repo=repo,
|
|
||||||
target_path=target_path,
|
|
||||||
branch=request.branch,
|
|
||||||
custom_url=repo_url,
|
|
||||||
depth=1
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if not result.get("success"):
|
if not result.get("success"):
|
||||||
@@ -595,23 +565,20 @@ async def install_plugin(
|
|||||||
message="克隆仓库失败",
|
message="克隆仓库失败",
|
||||||
operation="install",
|
operation="install",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=request.plugin_id,
|
||||||
error=error_msg
|
error=error_msg,
|
||||||
)
|
)
|
||||||
raise HTTPException(status_code=500, detail=error_msg)
|
raise HTTPException(status_code=500, detail=error_msg)
|
||||||
|
|
||||||
# 4. 验证插件完整性
|
# 4. 验证插件完整性
|
||||||
await update_progress(
|
await update_progress(
|
||||||
stage="loading",
|
stage="loading", progress=85, message="验证插件文件...", operation="install", plugin_id=request.plugin_id
|
||||||
progress=85,
|
|
||||||
message="验证插件文件...",
|
|
||||||
operation="install",
|
|
||||||
plugin_id=request.plugin_id
|
|
||||||
)
|
)
|
||||||
|
|
||||||
manifest_path = target_path / "_manifest.json"
|
manifest_path = target_path / "_manifest.json"
|
||||||
if not manifest_path.exists():
|
if not manifest_path.exists():
|
||||||
# 清理失败的安装
|
# 清理失败的安装
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
shutil.rmtree(target_path, ignore_errors=True)
|
shutil.rmtree(target_path, ignore_errors=True)
|
||||||
|
|
||||||
await update_progress(
|
await update_progress(
|
||||||
@@ -620,26 +587,23 @@ async def install_plugin(
|
|||||||
message="插件缺少 _manifest.json",
|
message="插件缺少 _manifest.json",
|
||||||
operation="install",
|
operation="install",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=request.plugin_id,
|
||||||
error="无效的插件格式"
|
error="无效的插件格式",
|
||||||
)
|
)
|
||||||
raise HTTPException(status_code=400, detail="无效的插件:缺少 _manifest.json")
|
raise HTTPException(status_code=400, detail="无效的插件:缺少 _manifest.json")
|
||||||
|
|
||||||
# 5. 读取并验证 manifest
|
# 5. 读取并验证 manifest
|
||||||
await update_progress(
|
await update_progress(
|
||||||
stage="loading",
|
stage="loading", progress=90, message="读取插件配置...", operation="install", plugin_id=request.plugin_id
|
||||||
progress=90,
|
|
||||||
message="读取插件配置...",
|
|
||||||
operation="install",
|
|
||||||
plugin_id=request.plugin_id
|
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import json as json_module
|
import json as json_module
|
||||||
with open(manifest_path, 'r', encoding='utf-8') as f:
|
|
||||||
|
with open(manifest_path, "r", encoding="utf-8") as f:
|
||||||
manifest = json_module.load(f)
|
manifest = json_module.load(f)
|
||||||
|
|
||||||
# 基本验证
|
# 基本验证
|
||||||
required_fields = ['manifest_version', 'name', 'version', 'author']
|
required_fields = ["manifest_version", "name", "version", "author"]
|
||||||
for field in required_fields:
|
for field in required_fields:
|
||||||
if field not in manifest:
|
if field not in manifest:
|
||||||
raise ValueError(f"缺少必需字段: {field}")
|
raise ValueError(f"缺少必需字段: {field}")
|
||||||
@@ -647,6 +611,7 @@ async def install_plugin(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
# 清理失败的安装
|
# 清理失败的安装
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
shutil.rmtree(target_path, ignore_errors=True)
|
shutil.rmtree(target_path, ignore_errors=True)
|
||||||
|
|
||||||
await update_progress(
|
await update_progress(
|
||||||
@@ -655,7 +620,7 @@ async def install_plugin(
|
|||||||
message="_manifest.json 无效",
|
message="_manifest.json 无效",
|
||||||
operation="install",
|
operation="install",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=request.plugin_id,
|
||||||
error=str(e)
|
error=str(e),
|
||||||
)
|
)
|
||||||
raise HTTPException(status_code=400, detail=f"无效的 _manifest.json: {e}") from e
|
raise HTTPException(status_code=400, detail=f"无效的 _manifest.json: {e}") from e
|
||||||
|
|
||||||
@@ -665,16 +630,16 @@ async def install_plugin(
|
|||||||
progress=100,
|
progress=100,
|
||||||
message=f"成功安装插件: {manifest['name']} v{manifest['version']}",
|
message=f"成功安装插件: {manifest['name']} v{manifest['version']}",
|
||||||
operation="install",
|
operation="install",
|
||||||
plugin_id=request.plugin_id
|
plugin_id=request.plugin_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"success": True,
|
"success": True,
|
||||||
"message": "插件安装成功",
|
"message": "插件安装成功",
|
||||||
"plugin_id": request.plugin_id,
|
"plugin_id": request.plugin_id,
|
||||||
"plugin_name": manifest['name'],
|
"plugin_name": manifest["name"],
|
||||||
"version": manifest['version'],
|
"version": manifest["version"],
|
||||||
"path": str(target_path)
|
"path": str(target_path),
|
||||||
}
|
}
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
@@ -688,7 +653,7 @@ async def install_plugin(
|
|||||||
message="安装失败",
|
message="安装失败",
|
||||||
operation="install",
|
operation="install",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=request.plugin_id,
|
||||||
error=str(e)
|
error=str(e),
|
||||||
)
|
)
|
||||||
|
|
||||||
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||||
@@ -696,8 +661,7 @@ async def install_plugin(
|
|||||||
|
|
||||||
@router.post("/uninstall")
|
@router.post("/uninstall")
|
||||||
async def uninstall_plugin(
|
async def uninstall_plugin(
|
||||||
request: UninstallPluginRequest,
|
request: UninstallPluginRequest, authorization: Optional[str] = Header(None)
|
||||||
authorization: Optional[str] = Header(None)
|
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
卸载插件
|
卸载插件
|
||||||
@@ -719,7 +683,7 @@ async def uninstall_plugin(
|
|||||||
progress=10,
|
progress=10,
|
||||||
message=f"开始卸载插件: {request.plugin_id}",
|
message=f"开始卸载插件: {request.plugin_id}",
|
||||||
operation="uninstall",
|
operation="uninstall",
|
||||||
plugin_id=request.plugin_id
|
plugin_id=request.plugin_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 1. 检查插件是否存在
|
# 1. 检查插件是否存在
|
||||||
@@ -733,7 +697,7 @@ async def uninstall_plugin(
|
|||||||
message="插件不存在",
|
message="插件不存在",
|
||||||
operation="uninstall",
|
operation="uninstall",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=request.plugin_id,
|
||||||
error="插件未安装或已被删除"
|
error="插件未安装或已被删除",
|
||||||
)
|
)
|
||||||
raise HTTPException(status_code=404, detail="插件未安装")
|
raise HTTPException(status_code=404, detail="插件未安装")
|
||||||
|
|
||||||
@@ -742,7 +706,7 @@ async def uninstall_plugin(
|
|||||||
progress=30,
|
progress=30,
|
||||||
message=f"正在删除插件文件: {plugin_path}",
|
message=f"正在删除插件文件: {plugin_path}",
|
||||||
operation="uninstall",
|
operation="uninstall",
|
||||||
plugin_id=request.plugin_id
|
plugin_id=request.plugin_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 2. 读取插件信息(用于日志)
|
# 2. 读取插件信息(用于日志)
|
||||||
@@ -752,7 +716,8 @@ async def uninstall_plugin(
|
|||||||
if manifest_path.exists():
|
if manifest_path.exists():
|
||||||
try:
|
try:
|
||||||
import json as json_module
|
import json as json_module
|
||||||
with open(manifest_path, 'r', encoding='utf-8') as f:
|
|
||||||
|
with open(manifest_path, "r", encoding="utf-8") as f:
|
||||||
manifest = json_module.load(f)
|
manifest = json_module.load(f)
|
||||||
plugin_name = manifest.get("name", request.plugin_id)
|
plugin_name = manifest.get("name", request.plugin_id)
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -763,7 +728,7 @@ async def uninstall_plugin(
|
|||||||
progress=50,
|
progress=50,
|
||||||
message=f"正在删除 {plugin_name}...",
|
message=f"正在删除 {plugin_name}...",
|
||||||
operation="uninstall",
|
operation="uninstall",
|
||||||
plugin_id=request.plugin_id
|
plugin_id=request.plugin_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 3. 删除插件目录
|
# 3. 删除插件目录
|
||||||
@@ -773,6 +738,7 @@ async def uninstall_plugin(
|
|||||||
def remove_readonly(func, path, _):
|
def remove_readonly(func, path, _):
|
||||||
"""清除只读属性并删除文件"""
|
"""清除只读属性并删除文件"""
|
||||||
import os
|
import os
|
||||||
|
|
||||||
os.chmod(path, stat.S_IWRITE)
|
os.chmod(path, stat.S_IWRITE)
|
||||||
func(path)
|
func(path)
|
||||||
|
|
||||||
@@ -786,15 +752,10 @@ async def uninstall_plugin(
|
|||||||
progress=100,
|
progress=100,
|
||||||
message=f"成功卸载插件: {plugin_name}",
|
message=f"成功卸载插件: {plugin_name}",
|
||||||
operation="uninstall",
|
operation="uninstall",
|
||||||
plugin_id=request.plugin_id
|
plugin_id=request.plugin_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {"success": True, "message": "插件卸载成功", "plugin_id": request.plugin_id, "plugin_name": plugin_name}
|
||||||
"success": True,
|
|
||||||
"message": "插件卸载成功",
|
|
||||||
"plugin_id": request.plugin_id,
|
|
||||||
"plugin_name": plugin_name
|
|
||||||
}
|
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
@@ -807,7 +768,7 @@ async def uninstall_plugin(
|
|||||||
message="卸载失败",
|
message="卸载失败",
|
||||||
operation="uninstall",
|
operation="uninstall",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=request.plugin_id,
|
||||||
error="权限不足,无法删除插件文件"
|
error="权限不足,无法删除插件文件",
|
||||||
)
|
)
|
||||||
|
|
||||||
raise HTTPException(status_code=500, detail="权限不足,无法删除插件文件") from e
|
raise HTTPException(status_code=500, detail="权限不足,无法删除插件文件") from e
|
||||||
@@ -820,17 +781,14 @@ async def uninstall_plugin(
|
|||||||
message="卸载失败",
|
message="卸载失败",
|
||||||
operation="uninstall",
|
operation="uninstall",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=request.plugin_id,
|
||||||
error=str(e)
|
error=str(e),
|
||||||
)
|
)
|
||||||
|
|
||||||
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||||
|
|
||||||
|
|
||||||
@router.post("/update")
|
@router.post("/update")
|
||||||
async def update_plugin(
|
async def update_plugin(request: UpdatePluginRequest, authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
||||||
request: UpdatePluginRequest,
|
|
||||||
authorization: Optional[str] = Header(None)
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""
|
"""
|
||||||
更新插件
|
更新插件
|
||||||
|
|
||||||
@@ -851,7 +809,7 @@ async def update_plugin(
|
|||||||
progress=5,
|
progress=5,
|
||||||
message=f"开始更新插件: {request.plugin_id}",
|
message=f"开始更新插件: {request.plugin_id}",
|
||||||
operation="update",
|
operation="update",
|
||||||
plugin_id=request.plugin_id
|
plugin_id=request.plugin_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 1. 检查插件是否已安装
|
# 1. 检查插件是否已安装
|
||||||
@@ -865,7 +823,7 @@ async def update_plugin(
|
|||||||
message="插件不存在",
|
message="插件不存在",
|
||||||
operation="update",
|
operation="update",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=request.plugin_id,
|
||||||
error="插件未安装,请先安装"
|
error="插件未安装,请先安装",
|
||||||
)
|
)
|
||||||
raise HTTPException(status_code=404, detail="插件未安装")
|
raise HTTPException(status_code=404, detail="插件未安装")
|
||||||
|
|
||||||
@@ -877,10 +835,11 @@ async def update_plugin(
|
|||||||
if manifest_path.exists():
|
if manifest_path.exists():
|
||||||
try:
|
try:
|
||||||
import json as json_module
|
import json as json_module
|
||||||
with open(manifest_path, 'r', encoding='utf-8') as f:
|
|
||||||
|
with open(manifest_path, "r", encoding="utf-8") as f:
|
||||||
manifest = json_module.load(f)
|
manifest = json_module.load(f)
|
||||||
old_version = manifest.get("version", "unknown")
|
old_version = manifest.get("version", "unknown")
|
||||||
plugin_name = manifest.get("name", request.plugin_id)
|
_plugin_name = manifest.get("name", request.plugin_id)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -889,16 +848,12 @@ async def update_plugin(
|
|||||||
progress=10,
|
progress=10,
|
||||||
message=f"当前版本: {old_version},准备更新...",
|
message=f"当前版本: {old_version},准备更新...",
|
||||||
operation="update",
|
operation="update",
|
||||||
plugin_id=request.plugin_id
|
plugin_id=request.plugin_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 3. 删除旧版本
|
# 3. 删除旧版本
|
||||||
await update_progress(
|
await update_progress(
|
||||||
stage="loading",
|
stage="loading", progress=20, message="正在删除旧版本...", operation="update", plugin_id=request.plugin_id
|
||||||
progress=20,
|
|
||||||
message="正在删除旧版本...",
|
|
||||||
operation="update",
|
|
||||||
plugin_id=request.plugin_id
|
|
||||||
)
|
)
|
||||||
|
|
||||||
import shutil
|
import shutil
|
||||||
@@ -907,6 +862,7 @@ async def update_plugin(
|
|||||||
def remove_readonly(func, path, _):
|
def remove_readonly(func, path, _):
|
||||||
"""清除只读属性并删除文件"""
|
"""清除只读属性并删除文件"""
|
||||||
import os
|
import os
|
||||||
|
|
||||||
os.chmod(path, stat.S_IWRITE)
|
os.chmod(path, stat.S_IWRITE)
|
||||||
func(path)
|
func(path)
|
||||||
|
|
||||||
@@ -920,14 +876,14 @@ async def update_plugin(
|
|||||||
progress=30,
|
progress=30,
|
||||||
message="正在准备下载新版本...",
|
message="正在准备下载新版本...",
|
||||||
operation="update",
|
operation="update",
|
||||||
plugin_id=request.plugin_id
|
plugin_id=request.plugin_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
repo_url = request.repository_url.rstrip('/')
|
repo_url = request.repository_url.rstrip("/")
|
||||||
if repo_url.endswith('.git'):
|
if repo_url.endswith(".git"):
|
||||||
repo_url = repo_url[:-4]
|
repo_url = repo_url[:-4]
|
||||||
|
|
||||||
parts = repo_url.split('/')
|
parts = repo_url.split("/")
|
||||||
if len(parts) < 2:
|
if len(parts) < 2:
|
||||||
raise HTTPException(status_code=400, detail="无效的仓库 URL")
|
raise HTTPException(status_code=400, detail="无效的仓库 URL")
|
||||||
|
|
||||||
@@ -937,23 +893,18 @@ async def update_plugin(
|
|||||||
# 5. 克隆新版本(这里会推送 35%-85% 的进度)
|
# 5. 克隆新版本(这里会推送 35%-85% 的进度)
|
||||||
service = get_git_mirror_service()
|
service = get_git_mirror_service()
|
||||||
|
|
||||||
if 'github.com' in repo_url:
|
if "github.com" in repo_url:
|
||||||
result = await service.clone_repository(
|
result = await service.clone_repository(
|
||||||
owner=owner,
|
owner=owner,
|
||||||
repo=repo,
|
repo=repo,
|
||||||
target_path=plugin_path,
|
target_path=plugin_path,
|
||||||
branch=request.branch,
|
branch=request.branch,
|
||||||
mirror_id=request.mirror_id,
|
mirror_id=request.mirror_id,
|
||||||
depth=1
|
depth=1,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
result = await service.clone_repository(
|
result = await service.clone_repository(
|
||||||
owner=owner,
|
owner=owner, repo=repo, target_path=plugin_path, branch=request.branch, custom_url=repo_url, depth=1
|
||||||
repo=repo,
|
|
||||||
target_path=plugin_path,
|
|
||||||
branch=request.branch,
|
|
||||||
custom_url=repo_url,
|
|
||||||
depth=1
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if not result.get("success"):
|
if not result.get("success"):
|
||||||
@@ -964,17 +915,13 @@ async def update_plugin(
|
|||||||
message="下载新版本失败",
|
message="下载新版本失败",
|
||||||
operation="update",
|
operation="update",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=request.plugin_id,
|
||||||
error=error_msg
|
error=error_msg,
|
||||||
)
|
)
|
||||||
raise HTTPException(status_code=500, detail=error_msg)
|
raise HTTPException(status_code=500, detail=error_msg)
|
||||||
|
|
||||||
# 6. 验证新版本
|
# 6. 验证新版本
|
||||||
await update_progress(
|
await update_progress(
|
||||||
stage="loading",
|
stage="loading", progress=90, message="验证新版本...", operation="update", plugin_id=request.plugin_id
|
||||||
progress=90,
|
|
||||||
message="验证新版本...",
|
|
||||||
operation="update",
|
|
||||||
plugin_id=request.plugin_id
|
|
||||||
)
|
)
|
||||||
|
|
||||||
new_manifest_path = plugin_path / "_manifest.json"
|
new_manifest_path = plugin_path / "_manifest.json"
|
||||||
@@ -983,6 +930,7 @@ async def update_plugin(
|
|||||||
def remove_readonly(func, path, _):
|
def remove_readonly(func, path, _):
|
||||||
"""清除只读属性并删除文件"""
|
"""清除只读属性并删除文件"""
|
||||||
import os
|
import os
|
||||||
|
|
||||||
os.chmod(path, stat.S_IWRITE)
|
os.chmod(path, stat.S_IWRITE)
|
||||||
func(path)
|
func(path)
|
||||||
|
|
||||||
@@ -994,13 +942,13 @@ async def update_plugin(
|
|||||||
message="新版本缺少 _manifest.json",
|
message="新版本缺少 _manifest.json",
|
||||||
operation="update",
|
operation="update",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=request.plugin_id,
|
||||||
error="无效的插件格式"
|
error="无效的插件格式",
|
||||||
)
|
)
|
||||||
raise HTTPException(status_code=400, detail="无效的插件:缺少 _manifest.json")
|
raise HTTPException(status_code=400, detail="无效的插件:缺少 _manifest.json")
|
||||||
|
|
||||||
# 7. 读取新版本信息
|
# 7. 读取新版本信息
|
||||||
try:
|
try:
|
||||||
with open(new_manifest_path, 'r', encoding='utf-8') as f:
|
with open(new_manifest_path, "r", encoding="utf-8") as f:
|
||||||
new_manifest = json_module.load(f)
|
new_manifest = json_module.load(f)
|
||||||
|
|
||||||
new_version = new_manifest.get("version", "unknown")
|
new_version = new_manifest.get("version", "unknown")
|
||||||
@@ -1014,7 +962,7 @@ async def update_plugin(
|
|||||||
progress=100,
|
progress=100,
|
||||||
message=f"成功更新 {new_name}: {old_version} → {new_version}",
|
message=f"成功更新 {new_name}: {old_version} → {new_version}",
|
||||||
operation="update",
|
operation="update",
|
||||||
plugin_id=request.plugin_id
|
plugin_id=request.plugin_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@@ -1023,7 +971,7 @@ async def update_plugin(
|
|||||||
"plugin_id": request.plugin_id,
|
"plugin_id": request.plugin_id,
|
||||||
"plugin_name": new_name,
|
"plugin_name": new_name,
|
||||||
"old_version": old_version,
|
"old_version": old_version,
|
||||||
"new_version": new_version
|
"new_version": new_version,
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -1036,7 +984,7 @@ async def update_plugin(
|
|||||||
message="_manifest.json 无效",
|
message="_manifest.json 无效",
|
||||||
operation="update",
|
operation="update",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=request.plugin_id,
|
||||||
error=str(e)
|
error=str(e),
|
||||||
)
|
)
|
||||||
raise HTTPException(status_code=400, detail=f"无效的 _manifest.json: {e}") from e
|
raise HTTPException(status_code=400, detail=f"无效的 _manifest.json: {e}") from e
|
||||||
|
|
||||||
@@ -1046,21 +994,14 @@ async def update_plugin(
|
|||||||
logger.error(f"更新插件失败: {e}", exc_info=True)
|
logger.error(f"更新插件失败: {e}", exc_info=True)
|
||||||
|
|
||||||
await update_progress(
|
await update_progress(
|
||||||
stage="error",
|
stage="error", progress=0, message="更新失败", operation="update", plugin_id=request.plugin_id, error=str(e)
|
||||||
progress=0,
|
|
||||||
message="更新失败",
|
|
||||||
operation="update",
|
|
||||||
plugin_id=request.plugin_id,
|
|
||||||
error=str(e)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||||
|
|
||||||
|
|
||||||
@router.get("/installed")
|
@router.get("/installed")
|
||||||
async def get_installed_plugins(
|
async def get_installed_plugins(authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
||||||
authorization: Optional[str] = Header(None)
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""
|
"""
|
||||||
获取已安装的插件列表
|
获取已安装的插件列表
|
||||||
|
|
||||||
@@ -1081,10 +1022,7 @@ async def get_installed_plugins(
|
|||||||
if not plugins_dir.exists():
|
if not plugins_dir.exists():
|
||||||
logger.info("插件目录不存在,创建目录")
|
logger.info("插件目录不存在,创建目录")
|
||||||
plugins_dir.mkdir(exist_ok=True)
|
plugins_dir.mkdir(exist_ok=True)
|
||||||
return {
|
return {"success": True, "plugins": []}
|
||||||
"success": True,
|
|
||||||
"plugins": []
|
|
||||||
}
|
|
||||||
|
|
||||||
installed_plugins = []
|
installed_plugins = []
|
||||||
|
|
||||||
@@ -1098,7 +1036,7 @@ async def get_installed_plugins(
|
|||||||
plugin_id = plugin_path.name
|
plugin_id = plugin_path.name
|
||||||
|
|
||||||
# 跳过隐藏目录和特殊目录
|
# 跳过隐藏目录和特殊目录
|
||||||
if plugin_id.startswith('.') or plugin_id.startswith('__'):
|
if plugin_id.startswith(".") or plugin_id.startswith("__"):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 读取 _manifest.json
|
# 读取 _manifest.json
|
||||||
@@ -1110,20 +1048,23 @@ async def get_installed_plugins(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
import json as json_module
|
import json as json_module
|
||||||
with open(manifest_path, 'r', encoding='utf-8') as f:
|
|
||||||
|
with open(manifest_path, "r", encoding="utf-8") as f:
|
||||||
manifest = json_module.load(f)
|
manifest = json_module.load(f)
|
||||||
|
|
||||||
# 基本验证
|
# 基本验证
|
||||||
if 'name' not in manifest or 'version' not in manifest:
|
if "name" not in manifest or "version" not in manifest:
|
||||||
logger.warning(f"插件 {plugin_id} 的 _manifest.json 格式无效,跳过")
|
logger.warning(f"插件 {plugin_id} 的 _manifest.json 格式无效,跳过")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 添加到已安装列表(返回完整的 manifest 信息)
|
# 添加到已安装列表(返回完整的 manifest 信息)
|
||||||
installed_plugins.append({
|
installed_plugins.append(
|
||||||
"id": plugin_id,
|
{
|
||||||
"manifest": manifest, # 返回完整的 manifest 对象
|
"id": plugin_id,
|
||||||
"path": str(plugin_path.absolute())
|
"manifest": manifest, # 返回完整的 manifest 对象
|
||||||
})
|
"path": str(plugin_path.absolute()),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
logger.warning(f"插件 {plugin_id} 的 _manifest.json 解析失败: {e}")
|
logger.warning(f"插件 {plugin_id} 的 _manifest.json 解析失败: {e}")
|
||||||
@@ -1134,11 +1075,7 @@ async def get_installed_plugins(
|
|||||||
|
|
||||||
logger.info(f"找到 {len(installed_plugins)} 个已安装插件")
|
logger.info(f"找到 {len(installed_plugins)} 个已安装插件")
|
||||||
|
|
||||||
return {
|
return {"success": True, "plugins": installed_plugins, "total": len(installed_plugins)}
|
||||||
"success": True,
|
|
||||||
"plugins": installed_plugins,
|
|
||||||
"total": len(installed_plugins)
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取已安装插件列表失败: {e}", exc_info=True)
|
logger.error(f"获取已安装插件列表失败: {e}", exc_info=True)
|
||||||
|
|||||||
97
src/webui/routers/system.py
Normal file
97
src/webui/routers/system.py
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
"""
|
||||||
|
系统控制路由
|
||||||
|
|
||||||
|
提供系统重启、状态查询等功能
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from datetime import datetime
|
||||||
|
from fastapi import APIRouter, HTTPException
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from src.config.config import MMC_VERSION
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/system", tags=["system"])
|
||||||
|
|
||||||
|
# 记录启动时间
|
||||||
|
_start_time = time.time()
|
||||||
|
|
||||||
|
|
||||||
|
class RestartResponse(BaseModel):
|
||||||
|
"""重启响应"""
|
||||||
|
|
||||||
|
success: bool
|
||||||
|
message: str
|
||||||
|
|
||||||
|
|
||||||
|
class StatusResponse(BaseModel):
|
||||||
|
"""状态响应"""
|
||||||
|
|
||||||
|
running: bool
|
||||||
|
uptime: float
|
||||||
|
version: str
|
||||||
|
start_time: str
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/restart", response_model=RestartResponse)
|
||||||
|
async def restart_maibot():
|
||||||
|
"""
|
||||||
|
重启麦麦主程序
|
||||||
|
|
||||||
|
使用 os.execv 重启当前进程,配置更改将在重启后生效。
|
||||||
|
注意:此操作会使麦麦暂时离线。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 记录重启操作
|
||||||
|
print(f"[{datetime.now()}] WebUI 触发重启操作")
|
||||||
|
|
||||||
|
# 使用 os.execv 重启当前进程
|
||||||
|
# 这会替换当前进程,保持相同的 PID
|
||||||
|
python = sys.executable
|
||||||
|
args = [python] + sys.argv
|
||||||
|
|
||||||
|
# 返回成功响应(实际上这个响应可能不会发送,因为进程会立即重启)
|
||||||
|
# 但我们仍然返回它以保持 API 一致性
|
||||||
|
os.execv(python, args)
|
||||||
|
|
||||||
|
return RestartResponse(success=True, message="麦麦正在重启中...")
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=f"重启失败: {str(e)}") from e
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/status", response_model=StatusResponse)
|
||||||
|
async def get_maibot_status():
|
||||||
|
"""
|
||||||
|
获取麦麦运行状态
|
||||||
|
|
||||||
|
返回麦麦的运行状态、运行时长和版本信息。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
uptime = time.time() - _start_time
|
||||||
|
|
||||||
|
# 尝试获取版本信息(需要根据实际情况调整)
|
||||||
|
version = MMC_VERSION # 可以从配置或常量中读取
|
||||||
|
|
||||||
|
return StatusResponse(
|
||||||
|
running=True, uptime=uptime, version=version, start_time=datetime.fromtimestamp(_start_time).isoformat()
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=f"获取状态失败: {str(e)}") from e
|
||||||
|
|
||||||
|
|
||||||
|
# 可选:添加更多系统控制功能
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/reload-config")
|
||||||
|
async def reload_config():
|
||||||
|
"""
|
||||||
|
热重载配置(不重启进程)
|
||||||
|
|
||||||
|
仅重新加载配置文件,某些配置可能需要重启才能生效。
|
||||||
|
此功能需要在主程序中实现配置热重载逻辑。
|
||||||
|
"""
|
||||||
|
# 这里需要调用主程序的配置重载函数
|
||||||
|
# 示例:await app_instance.reload_config()
|
||||||
|
|
||||||
|
return {"success": True, "message": "配置重载功能待实现"}
|
||||||
@@ -1,4 +1,5 @@
|
|||||||
"""WebUI API 路由"""
|
"""WebUI API 路由"""
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, Header
|
from fastapi import APIRouter, HTTPException, Header
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
@@ -11,6 +12,7 @@ from .expression_routes import router as expression_router
|
|||||||
from .emoji_routes import router as emoji_router
|
from .emoji_routes import router as emoji_router
|
||||||
from .plugin_routes import router as plugin_router
|
from .plugin_routes import router as plugin_router
|
||||||
from .plugin_progress_ws import get_progress_router
|
from .plugin_progress_ws import get_progress_router
|
||||||
|
from .routers.system import router as system_router
|
||||||
|
|
||||||
logger = get_logger("webui.api")
|
logger = get_logger("webui.api")
|
||||||
|
|
||||||
@@ -31,37 +33,65 @@ router.include_router(emoji_router)
|
|||||||
router.include_router(plugin_router)
|
router.include_router(plugin_router)
|
||||||
# 注册插件进度 WebSocket 路由
|
# 注册插件进度 WebSocket 路由
|
||||||
router.include_router(get_progress_router())
|
router.include_router(get_progress_router())
|
||||||
|
# 注册系统控制路由
|
||||||
|
router.include_router(system_router)
|
||||||
|
|
||||||
|
|
||||||
class TokenVerifyRequest(BaseModel):
|
class TokenVerifyRequest(BaseModel):
|
||||||
"""Token 验证请求"""
|
"""Token 验证请求"""
|
||||||
|
|
||||||
token: str = Field(..., description="访问令牌")
|
token: str = Field(..., description="访问令牌")
|
||||||
|
|
||||||
|
|
||||||
class TokenVerifyResponse(BaseModel):
|
class TokenVerifyResponse(BaseModel):
|
||||||
"""Token 验证响应"""
|
"""Token 验证响应"""
|
||||||
|
|
||||||
valid: bool = Field(..., description="Token 是否有效")
|
valid: bool = Field(..., description="Token 是否有效")
|
||||||
message: str = Field(..., description="验证结果消息")
|
message: str = Field(..., description="验证结果消息")
|
||||||
|
|
||||||
|
|
||||||
class TokenUpdateRequest(BaseModel):
|
class TokenUpdateRequest(BaseModel):
|
||||||
"""Token 更新请求"""
|
"""Token 更新请求"""
|
||||||
|
|
||||||
new_token: str = Field(..., description="新的访问令牌", min_length=10)
|
new_token: str = Field(..., description="新的访问令牌", min_length=10)
|
||||||
|
|
||||||
|
|
||||||
class TokenUpdateResponse(BaseModel):
|
class TokenUpdateResponse(BaseModel):
|
||||||
"""Token 更新响应"""
|
"""Token 更新响应"""
|
||||||
|
|
||||||
success: bool = Field(..., description="是否更新成功")
|
success: bool = Field(..., description="是否更新成功")
|
||||||
message: str = Field(..., description="更新结果消息")
|
message: str = Field(..., description="更新结果消息")
|
||||||
|
|
||||||
|
|
||||||
class TokenRegenerateResponse(BaseModel):
|
class TokenRegenerateResponse(BaseModel):
|
||||||
"""Token 重新生成响应"""
|
"""Token 重新生成响应"""
|
||||||
|
|
||||||
success: bool = Field(..., description="是否生成成功")
|
success: bool = Field(..., description="是否生成成功")
|
||||||
token: str = Field(..., description="新生成的令牌")
|
token: str = Field(..., description="新生成的令牌")
|
||||||
message: str = Field(..., description="生成结果消息")
|
message: str = Field(..., description="生成结果消息")
|
||||||
|
|
||||||
|
|
||||||
|
class FirstSetupStatusResponse(BaseModel):
|
||||||
|
"""首次配置状态响应"""
|
||||||
|
|
||||||
|
is_first_setup: bool = Field(..., description="是否为首次配置")
|
||||||
|
message: str = Field(..., description="状态消息")
|
||||||
|
|
||||||
|
|
||||||
|
class CompleteSetupResponse(BaseModel):
|
||||||
|
"""完成配置响应"""
|
||||||
|
|
||||||
|
success: bool = Field(..., description="是否成功")
|
||||||
|
message: str = Field(..., description="结果消息")
|
||||||
|
|
||||||
|
|
||||||
|
class ResetSetupResponse(BaseModel):
|
||||||
|
"""重置配置响应"""
|
||||||
|
|
||||||
|
success: bool = Field(..., description="是否成功")
|
||||||
|
message: str = Field(..., description="结果消息")
|
||||||
|
|
||||||
|
|
||||||
@router.get("/health")
|
@router.get("/health")
|
||||||
async def health_check():
|
async def health_check():
|
||||||
"""健康检查"""
|
"""健康检查"""
|
||||||
@@ -84,25 +114,16 @@ async def verify_token(request: TokenVerifyRequest):
|
|||||||
is_valid = token_manager.verify_token(request.token)
|
is_valid = token_manager.verify_token(request.token)
|
||||||
|
|
||||||
if is_valid:
|
if is_valid:
|
||||||
return TokenVerifyResponse(
|
return TokenVerifyResponse(valid=True, message="Token 验证成功")
|
||||||
valid=True,
|
|
||||||
message="Token 验证成功"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
return TokenVerifyResponse(
|
return TokenVerifyResponse(valid=False, message="Token 无效或已过期")
|
||||||
valid=False,
|
|
||||||
message="Token 无效或已过期"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Token 验证失败: {e}")
|
logger.error(f"Token 验证失败: {e}")
|
||||||
raise HTTPException(status_code=500, detail="Token 验证失败") from e
|
raise HTTPException(status_code=500, detail="Token 验证失败") from e
|
||||||
|
|
||||||
|
|
||||||
@router.post("/auth/update", response_model=TokenUpdateResponse)
|
@router.post("/auth/update", response_model=TokenUpdateResponse)
|
||||||
async def update_token(
|
async def update_token(request: TokenUpdateRequest, authorization: Optional[str] = Header(None)):
|
||||||
request: TokenUpdateRequest,
|
|
||||||
authorization: Optional[str] = Header(None)
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
更新访问令牌(需要当前有效的 token)
|
更新访问令牌(需要当前有效的 token)
|
||||||
|
|
||||||
@@ -127,10 +148,7 @@ async def update_token(
|
|||||||
# 更新 token
|
# 更新 token
|
||||||
success, message = token_manager.update_token(request.new_token)
|
success, message = token_manager.update_token(request.new_token)
|
||||||
|
|
||||||
return TokenUpdateResponse(
|
return TokenUpdateResponse(success=success, message=message)
|
||||||
success=success,
|
|
||||||
message=message
|
|
||||||
)
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -163,14 +181,108 @@ async def regenerate_token(authorization: Optional[str] = Header(None)):
|
|||||||
# 重新生成 token
|
# 重新生成 token
|
||||||
new_token = token_manager.regenerate_token()
|
new_token = token_manager.regenerate_token()
|
||||||
|
|
||||||
return TokenRegenerateResponse(
|
return TokenRegenerateResponse(success=True, token=new_token, message="Token 已重新生成")
|
||||||
success=True,
|
|
||||||
token=new_token,
|
|
||||||
message="Token 已重新生成"
|
|
||||||
)
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Token 重新生成失败: {e}")
|
logger.error(f"Token 重新生成失败: {e}")
|
||||||
raise HTTPException(status_code=500, detail="Token 重新生成失败") from e
|
raise HTTPException(status_code=500, detail="Token 重新生成失败") from e
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/setup/status", response_model=FirstSetupStatusResponse)
|
||||||
|
async def get_setup_status(authorization: Optional[str] = Header(None)):
|
||||||
|
"""
|
||||||
|
获取首次配置状态
|
||||||
|
|
||||||
|
Args:
|
||||||
|
authorization: Authorization header (Bearer token)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
首次配置状态
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 验证 token
|
||||||
|
if not authorization or not authorization.startswith("Bearer "):
|
||||||
|
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||||
|
|
||||||
|
current_token = authorization.replace("Bearer ", "")
|
||||||
|
token_manager = get_token_manager()
|
||||||
|
|
||||||
|
if not token_manager.verify_token(current_token):
|
||||||
|
raise HTTPException(status_code=401, detail="Token 无效")
|
||||||
|
|
||||||
|
# 检查是否为首次配置
|
||||||
|
is_first = token_manager.is_first_setup()
|
||||||
|
|
||||||
|
return FirstSetupStatusResponse(is_first_setup=is_first, message="首次配置" if is_first else "已完成配置")
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取配置状态失败: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail="获取配置状态失败") from e
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/setup/complete", response_model=CompleteSetupResponse)
|
||||||
|
async def complete_setup(authorization: Optional[str] = Header(None)):
|
||||||
|
"""
|
||||||
|
标记首次配置完成
|
||||||
|
|
||||||
|
Args:
|
||||||
|
authorization: Authorization header (Bearer token)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
完成结果
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 验证 token
|
||||||
|
if not authorization or not authorization.startswith("Bearer "):
|
||||||
|
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||||
|
|
||||||
|
current_token = authorization.replace("Bearer ", "")
|
||||||
|
token_manager = get_token_manager()
|
||||||
|
|
||||||
|
if not token_manager.verify_token(current_token):
|
||||||
|
raise HTTPException(status_code=401, detail="Token 无效")
|
||||||
|
|
||||||
|
# 标记配置完成
|
||||||
|
success = token_manager.mark_setup_completed()
|
||||||
|
|
||||||
|
return CompleteSetupResponse(success=success, message="配置已完成" if success else "标记失败")
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"标记配置完成失败: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail="标记配置完成失败") from e
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/setup/reset", response_model=ResetSetupResponse)
|
||||||
|
async def reset_setup(authorization: Optional[str] = Header(None)):
|
||||||
|
"""
|
||||||
|
重置首次配置状态,允许重新进入配置向导
|
||||||
|
|
||||||
|
Args:
|
||||||
|
authorization: Authorization header (Bearer token)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
重置结果
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 验证 token
|
||||||
|
if not authorization or not authorization.startswith("Bearer "):
|
||||||
|
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||||
|
|
||||||
|
current_token = authorization.replace("Bearer ", "")
|
||||||
|
token_manager = get_token_manager()
|
||||||
|
|
||||||
|
if not token_manager.verify_token(current_token):
|
||||||
|
raise HTTPException(status_code=401, detail="Token 无效")
|
||||||
|
|
||||||
|
# 重置配置状态
|
||||||
|
success = token_manager.reset_setup_status()
|
||||||
|
|
||||||
|
return ResetSetupResponse(success=success, message="配置状态已重置" if success else "重置失败")
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"重置配置状态失败: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail="重置配置状态失败") from e
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
"""统计数据 API 路由"""
|
"""统计数据 API 路由"""
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException
|
from fastapi import APIRouter, HTTPException
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing import Dict, Any, List
|
from typing import Dict, Any, List
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from collections import defaultdict
|
from peewee import fn
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.common.database.database_model import LLMUsage, OnlineTime, Messages
|
from src.common.database.database_model import LLMUsage, OnlineTime, Messages
|
||||||
@@ -15,6 +16,7 @@ router = APIRouter(prefix="/statistics", tags=["statistics"])
|
|||||||
|
|
||||||
class StatisticsSummary(BaseModel):
|
class StatisticsSummary(BaseModel):
|
||||||
"""统计数据摘要"""
|
"""统计数据摘要"""
|
||||||
|
|
||||||
total_requests: int = Field(0, description="总请求数")
|
total_requests: int = Field(0, description="总请求数")
|
||||||
total_cost: float = Field(0.0, description="总花费")
|
total_cost: float = Field(0.0, description="总花费")
|
||||||
total_tokens: int = Field(0, description="总token数")
|
total_tokens: int = Field(0, description="总token数")
|
||||||
@@ -28,6 +30,7 @@ class StatisticsSummary(BaseModel):
|
|||||||
|
|
||||||
class ModelStatistics(BaseModel):
|
class ModelStatistics(BaseModel):
|
||||||
"""模型统计"""
|
"""模型统计"""
|
||||||
|
|
||||||
model_name: str
|
model_name: str
|
||||||
request_count: int
|
request_count: int
|
||||||
total_cost: float
|
total_cost: float
|
||||||
@@ -37,6 +40,7 @@ class ModelStatistics(BaseModel):
|
|||||||
|
|
||||||
class TimeSeriesData(BaseModel):
|
class TimeSeriesData(BaseModel):
|
||||||
"""时间序列数据"""
|
"""时间序列数据"""
|
||||||
|
|
||||||
timestamp: str
|
timestamp: str
|
||||||
requests: int = 0
|
requests: int = 0
|
||||||
cost: float = 0.0
|
cost: float = 0.0
|
||||||
@@ -45,6 +49,7 @@ class TimeSeriesData(BaseModel):
|
|||||||
|
|
||||||
class DashboardData(BaseModel):
|
class DashboardData(BaseModel):
|
||||||
"""仪表盘数据"""
|
"""仪表盘数据"""
|
||||||
|
|
||||||
summary: StatisticsSummary
|
summary: StatisticsSummary
|
||||||
model_stats: List[ModelStatistics]
|
model_stats: List[ModelStatistics]
|
||||||
hourly_data: List[TimeSeriesData]
|
hourly_data: List[TimeSeriesData]
|
||||||
@@ -88,7 +93,7 @@ async def get_dashboard_data(hours: int = 24):
|
|||||||
model_stats=model_stats,
|
model_stats=model_stats,
|
||||||
hourly_data=hourly_data,
|
hourly_data=hourly_data,
|
||||||
daily_data=daily_data,
|
daily_data=daily_data,
|
||||||
recent_activity=recent_activity
|
recent_activity=recent_activity,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取仪表盘数据失败: {e}")
|
logger.error(f"获取仪表盘数据失败: {e}")
|
||||||
@@ -96,39 +101,26 @@ async def get_dashboard_data(hours: int = 24):
|
|||||||
|
|
||||||
|
|
||||||
async def _get_summary_statistics(start_time: datetime, end_time: datetime) -> StatisticsSummary:
|
async def _get_summary_statistics(start_time: datetime, end_time: datetime) -> StatisticsSummary:
|
||||||
"""获取摘要统计数据"""
|
"""获取摘要统计数据(优化:使用数据库聚合)"""
|
||||||
summary = StatisticsSummary()
|
summary = StatisticsSummary()
|
||||||
|
|
||||||
# 查询 LLM 使用记录
|
# 使用聚合查询替代全量加载
|
||||||
llm_records = list(
|
query = LLMUsage.select(
|
||||||
LLMUsage.select()
|
fn.COUNT(LLMUsage.id).alias("total_requests"),
|
||||||
.where(LLMUsage.timestamp >= start_time)
|
fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("total_cost"),
|
||||||
.where(LLMUsage.timestamp <= end_time)
|
fn.COALESCE(fn.SUM(LLMUsage.prompt_tokens + LLMUsage.completion_tokens), 0).alias("total_tokens"),
|
||||||
)
|
fn.COALESCE(fn.AVG(LLMUsage.time_cost), 0).alias("avg_response_time"),
|
||||||
|
).where((LLMUsage.timestamp >= start_time) & (LLMUsage.timestamp <= end_time))
|
||||||
|
|
||||||
total_time_cost = 0.0
|
result = query.dicts().get()
|
||||||
time_cost_count = 0
|
summary.total_requests = result["total_requests"]
|
||||||
|
summary.total_cost = result["total_cost"]
|
||||||
|
summary.total_tokens = result["total_tokens"]
|
||||||
|
summary.avg_response_time = result["avg_response_time"] or 0.0
|
||||||
|
|
||||||
for record in llm_records:
|
# 查询在线时间 - 这个数据量通常不大,保留原逻辑
|
||||||
summary.total_requests += 1
|
|
||||||
summary.total_cost += record.cost or 0.0
|
|
||||||
summary.total_tokens += (record.prompt_tokens or 0) + (record.completion_tokens or 0)
|
|
||||||
|
|
||||||
if record.time_cost and record.time_cost > 0:
|
|
||||||
total_time_cost += record.time_cost
|
|
||||||
time_cost_count += 1
|
|
||||||
|
|
||||||
# 计算平均响应时间
|
|
||||||
if time_cost_count > 0:
|
|
||||||
summary.avg_response_time = total_time_cost / time_cost_count
|
|
||||||
|
|
||||||
# 查询在线时间
|
|
||||||
online_records = list(
|
online_records = list(
|
||||||
OnlineTime.select()
|
OnlineTime.select().where((OnlineTime.start_timestamp >= start_time) | (OnlineTime.end_timestamp >= start_time))
|
||||||
.where(
|
|
||||||
(OnlineTime.start_timestamp >= start_time) |
|
|
||||||
(OnlineTime.end_timestamp >= start_time)
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
for record in online_records:
|
for record in online_records:
|
||||||
@@ -137,16 +129,19 @@ async def _get_summary_statistics(start_time: datetime, end_time: datetime) -> S
|
|||||||
if end > start:
|
if end > start:
|
||||||
summary.online_time += (end - start).total_seconds()
|
summary.online_time += (end - start).total_seconds()
|
||||||
|
|
||||||
# 查询消息数量
|
# 查询消息数量 - 使用聚合优化
|
||||||
messages = list(
|
messages_query = Messages.select(fn.COUNT(Messages.id).alias("total")).where(
|
||||||
Messages.select()
|
(Messages.time >= start_time.timestamp()) & (Messages.time <= end_time.timestamp())
|
||||||
.where(Messages.time >= start_time.timestamp())
|
|
||||||
.where(Messages.time <= end_time.timestamp())
|
|
||||||
)
|
)
|
||||||
|
summary.total_messages = messages_query.scalar() or 0
|
||||||
|
|
||||||
summary.total_messages = len(messages)
|
# 统计回复数量
|
||||||
# 简单统计:如果 reply_to 不为空,则认为是回复
|
replies_query = Messages.select(fn.COUNT(Messages.id).alias("total")).where(
|
||||||
summary.total_replies = len([m for m in messages if m.reply_to])
|
(Messages.time >= start_time.timestamp())
|
||||||
|
& (Messages.time <= end_time.timestamp())
|
||||||
|
& (Messages.reply_to.is_null(False))
|
||||||
|
)
|
||||||
|
summary.total_replies = replies_query.scalar() or 0
|
||||||
|
|
||||||
# 计算派生指标
|
# 计算派生指标
|
||||||
if summary.online_time > 0:
|
if summary.online_time > 0:
|
||||||
@@ -158,113 +153,101 @@ async def _get_summary_statistics(start_time: datetime, end_time: datetime) -> S
|
|||||||
|
|
||||||
|
|
||||||
async def _get_model_statistics(start_time: datetime) -> List[ModelStatistics]:
|
async def _get_model_statistics(start_time: datetime) -> List[ModelStatistics]:
|
||||||
"""获取模型统计数据"""
|
"""获取模型统计数据(优化:使用数据库聚合和分组)"""
|
||||||
model_data = defaultdict(lambda: {
|
# 使用GROUP BY聚合,避免全量加载
|
||||||
'request_count': 0,
|
query = (
|
||||||
'total_cost': 0.0,
|
LLMUsage.select(
|
||||||
'total_tokens': 0,
|
fn.COALESCE(LLMUsage.model_assign_name, LLMUsage.model_name, "unknown").alias("model_name"),
|
||||||
'time_costs': []
|
fn.COUNT(LLMUsage.id).alias("request_count"),
|
||||||
})
|
fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("total_cost"),
|
||||||
|
fn.COALESCE(fn.SUM(LLMUsage.prompt_tokens + LLMUsage.completion_tokens), 0).alias("total_tokens"),
|
||||||
records = list(
|
fn.COALESCE(fn.AVG(LLMUsage.time_cost), 0).alias("avg_response_time"),
|
||||||
LLMUsage.select()
|
)
|
||||||
.where(LLMUsage.timestamp >= start_time)
|
.where(LLMUsage.timestamp >= start_time)
|
||||||
|
.group_by(fn.COALESCE(LLMUsage.model_assign_name, LLMUsage.model_name, "unknown"))
|
||||||
|
.order_by(fn.COUNT(LLMUsage.id).desc())
|
||||||
|
.limit(10) # 只取前10个
|
||||||
)
|
)
|
||||||
|
|
||||||
for record in records:
|
|
||||||
model_name = record.model_assign_name or record.model_name or "unknown"
|
|
||||||
model_data[model_name]['request_count'] += 1
|
|
||||||
model_data[model_name]['total_cost'] += record.cost or 0.0
|
|
||||||
model_data[model_name]['total_tokens'] += (record.prompt_tokens or 0) + (record.completion_tokens or 0)
|
|
||||||
|
|
||||||
if record.time_cost and record.time_cost > 0:
|
|
||||||
model_data[model_name]['time_costs'].append(record.time_cost)
|
|
||||||
|
|
||||||
# 转换为列表并排序
|
|
||||||
result = []
|
result = []
|
||||||
for model_name, data in model_data.items():
|
for row in query.dicts():
|
||||||
avg_time = sum(data['time_costs']) / len(data['time_costs']) if data['time_costs'] else 0.0
|
result.append(
|
||||||
result.append(ModelStatistics(
|
ModelStatistics(
|
||||||
model_name=model_name,
|
model_name=row["model_name"],
|
||||||
request_count=data['request_count'],
|
request_count=row["request_count"],
|
||||||
total_cost=data['total_cost'],
|
total_cost=row["total_cost"],
|
||||||
total_tokens=data['total_tokens'],
|
total_tokens=row["total_tokens"],
|
||||||
avg_response_time=avg_time
|
avg_response_time=row["avg_response_time"] or 0.0,
|
||||||
))
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# 按请求数排序
|
return result
|
||||||
result.sort(key=lambda x: x.request_count, reverse=True)
|
|
||||||
return result[:10] # 返回前10个
|
|
||||||
|
|
||||||
|
|
||||||
async def _get_hourly_statistics(start_time: datetime, end_time: datetime) -> List[TimeSeriesData]:
|
async def _get_hourly_statistics(start_time: datetime, end_time: datetime) -> List[TimeSeriesData]:
|
||||||
"""获取小时级统计数据"""
|
"""获取小时级统计数据(优化:使用数据库聚合)"""
|
||||||
# 创建小时桶
|
# SQLite的日期时间函数进行小时分组
|
||||||
hourly_buckets = defaultdict(lambda: {'requests': 0, 'cost': 0.0, 'tokens': 0})
|
# 使用strftime将timestamp格式化为小时级别
|
||||||
|
query = (
|
||||||
records = list(
|
LLMUsage.select(
|
||||||
LLMUsage.select()
|
fn.strftime("%Y-%m-%dT%H:00:00", LLMUsage.timestamp).alias("hour"),
|
||||||
.where(LLMUsage.timestamp >= start_time)
|
fn.COUNT(LLMUsage.id).alias("requests"),
|
||||||
.where(LLMUsage.timestamp <= end_time)
|
fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("cost"),
|
||||||
|
fn.COALESCE(fn.SUM(LLMUsage.prompt_tokens + LLMUsage.completion_tokens), 0).alias("tokens"),
|
||||||
|
)
|
||||||
|
.where((LLMUsage.timestamp >= start_time) & (LLMUsage.timestamp <= end_time))
|
||||||
|
.group_by(fn.strftime("%Y-%m-%dT%H:00:00", LLMUsage.timestamp))
|
||||||
)
|
)
|
||||||
|
|
||||||
for record in records:
|
# 转换为字典以快速查找
|
||||||
# 获取小时键(去掉分钟和秒)
|
data_dict = {row["hour"]: row for row in query.dicts()}
|
||||||
hour_key = record.timestamp.replace(minute=0, second=0, microsecond=0)
|
|
||||||
hour_str = hour_key.isoformat()
|
|
||||||
|
|
||||||
hourly_buckets[hour_str]['requests'] += 1
|
|
||||||
hourly_buckets[hour_str]['cost'] += record.cost or 0.0
|
|
||||||
hourly_buckets[hour_str]['tokens'] += (record.prompt_tokens or 0) + (record.completion_tokens or 0)
|
|
||||||
|
|
||||||
# 填充所有小时(包括没有数据的)
|
# 填充所有小时(包括没有数据的)
|
||||||
result = []
|
result = []
|
||||||
current = start_time.replace(minute=0, second=0, microsecond=0)
|
current = start_time.replace(minute=0, second=0, microsecond=0)
|
||||||
while current <= end_time:
|
while current <= end_time:
|
||||||
hour_str = current.isoformat()
|
hour_str = current.strftime("%Y-%m-%dT%H:00:00")
|
||||||
data = hourly_buckets.get(hour_str, {'requests': 0, 'cost': 0.0, 'tokens': 0})
|
if hour_str in data_dict:
|
||||||
result.append(TimeSeriesData(
|
row = data_dict[hour_str]
|
||||||
timestamp=hour_str,
|
result.append(
|
||||||
requests=data['requests'],
|
TimeSeriesData(timestamp=hour_str, requests=row["requests"], cost=row["cost"], tokens=row["tokens"])
|
||||||
cost=data['cost'],
|
)
|
||||||
tokens=data['tokens']
|
else:
|
||||||
))
|
result.append(TimeSeriesData(timestamp=hour_str, requests=0, cost=0.0, tokens=0))
|
||||||
current += timedelta(hours=1)
|
current += timedelta(hours=1)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
async def _get_daily_statistics(start_time: datetime, end_time: datetime) -> List[TimeSeriesData]:
|
async def _get_daily_statistics(start_time: datetime, end_time: datetime) -> List[TimeSeriesData]:
|
||||||
"""获取日级统计数据"""
|
"""获取日级统计数据(优化:使用数据库聚合)"""
|
||||||
daily_buckets = defaultdict(lambda: {'requests': 0, 'cost': 0.0, 'tokens': 0})
|
# 使用strftime按日期分组
|
||||||
|
query = (
|
||||||
records = list(
|
LLMUsage.select(
|
||||||
LLMUsage.select()
|
fn.strftime("%Y-%m-%dT00:00:00", LLMUsage.timestamp).alias("day"),
|
||||||
.where(LLMUsage.timestamp >= start_time)
|
fn.COUNT(LLMUsage.id).alias("requests"),
|
||||||
.where(LLMUsage.timestamp <= end_time)
|
fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("cost"),
|
||||||
|
fn.COALESCE(fn.SUM(LLMUsage.prompt_tokens + LLMUsage.completion_tokens), 0).alias("tokens"),
|
||||||
|
)
|
||||||
|
.where((LLMUsage.timestamp >= start_time) & (LLMUsage.timestamp <= end_time))
|
||||||
|
.group_by(fn.strftime("%Y-%m-%dT00:00:00", LLMUsage.timestamp))
|
||||||
)
|
)
|
||||||
|
|
||||||
for record in records:
|
# 转换为字典
|
||||||
# 获取日期键
|
data_dict = {row["day"]: row for row in query.dicts()}
|
||||||
day_key = record.timestamp.replace(hour=0, minute=0, second=0, microsecond=0)
|
|
||||||
day_str = day_key.isoformat()
|
|
||||||
|
|
||||||
daily_buckets[day_str]['requests'] += 1
|
|
||||||
daily_buckets[day_str]['cost'] += record.cost or 0.0
|
|
||||||
daily_buckets[day_str]['tokens'] += (record.prompt_tokens or 0) + (record.completion_tokens or 0)
|
|
||||||
|
|
||||||
# 填充所有天
|
# 填充所有天
|
||||||
result = []
|
result = []
|
||||||
current = start_time.replace(hour=0, minute=0, second=0, microsecond=0)
|
current = start_time.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||||
while current <= end_time:
|
while current <= end_time:
|
||||||
day_str = current.isoformat()
|
day_str = current.strftime("%Y-%m-%dT00:00:00")
|
||||||
data = daily_buckets.get(day_str, {'requests': 0, 'cost': 0.0, 'tokens': 0})
|
if day_str in data_dict:
|
||||||
result.append(TimeSeriesData(
|
row = data_dict[day_str]
|
||||||
timestamp=day_str,
|
result.append(
|
||||||
requests=data['requests'],
|
TimeSeriesData(timestamp=day_str, requests=row["requests"], cost=row["cost"], tokens=row["tokens"])
|
||||||
cost=data['cost'],
|
)
|
||||||
tokens=data['tokens']
|
else:
|
||||||
))
|
result.append(TimeSeriesData(timestamp=day_str, requests=0, cost=0.0, tokens=0))
|
||||||
current += timedelta(days=1)
|
current += timedelta(days=1)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
@@ -272,23 +255,21 @@ async def _get_daily_statistics(start_time: datetime, end_time: datetime) -> Lis
|
|||||||
|
|
||||||
async def _get_recent_activity(limit: int = 10) -> List[Dict[str, Any]]:
|
async def _get_recent_activity(limit: int = 10) -> List[Dict[str, Any]]:
|
||||||
"""获取最近活动"""
|
"""获取最近活动"""
|
||||||
records = list(
|
records = list(LLMUsage.select().order_by(LLMUsage.timestamp.desc()).limit(limit))
|
||||||
LLMUsage.select()
|
|
||||||
.order_by(LLMUsage.timestamp.desc())
|
|
||||||
.limit(limit)
|
|
||||||
)
|
|
||||||
|
|
||||||
activities = []
|
activities = []
|
||||||
for record in records:
|
for record in records:
|
||||||
activities.append({
|
activities.append(
|
||||||
'timestamp': record.timestamp.isoformat(),
|
{
|
||||||
'model': record.model_assign_name or record.model_name,
|
"timestamp": record.timestamp.isoformat(),
|
||||||
'request_type': record.request_type,
|
"model": record.model_assign_name or record.model_name,
|
||||||
'tokens': (record.prompt_tokens or 0) + (record.completion_tokens or 0),
|
"request_type": record.request_type,
|
||||||
'cost': record.cost or 0.0,
|
"tokens": (record.prompt_tokens or 0) + (record.completion_tokens or 0),
|
||||||
'time_cost': record.time_cost or 0.0,
|
"cost": record.cost or 0.0,
|
||||||
'status': record.status
|
"time_cost": record.time_cost or 0.0,
|
||||||
})
|
"status": record.status,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return activities
|
return activities
|
||||||
|
|
||||||
|
|||||||
@@ -79,7 +79,8 @@ class TokenManager:
|
|||||||
config = {
|
config = {
|
||||||
"access_token": token,
|
"access_token": token,
|
||||||
"created_at": self._get_current_timestamp(),
|
"created_at": self._get_current_timestamp(),
|
||||||
"updated_at": self._get_current_timestamp()
|
"updated_at": self._get_current_timestamp(),
|
||||||
|
"first_setup_completed": False, # 标记首次配置未完成
|
||||||
}
|
}
|
||||||
|
|
||||||
self._save_config(config)
|
self._save_config(config)
|
||||||
@@ -90,6 +91,7 @@ class TokenManager:
|
|||||||
def _get_current_timestamp(self) -> str:
|
def _get_current_timestamp(self) -> str:
|
||||||
"""获取当前时间戳字符串"""
|
"""获取当前时间戳字符串"""
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
return datetime.now().isoformat()
|
return datetime.now().isoformat()
|
||||||
|
|
||||||
def get_token(self) -> str:
|
def get_token(self) -> str:
|
||||||
@@ -231,6 +233,53 @@ class TokenManager:
|
|||||||
|
|
||||||
return True, "Token 格式正确"
|
return True, "Token 格式正确"
|
||||||
|
|
||||||
|
def is_first_setup(self) -> bool:
|
||||||
|
"""
|
||||||
|
检查是否为首次配置
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 是否为首次配置
|
||||||
|
"""
|
||||||
|
config = self._load_config()
|
||||||
|
return not config.get("first_setup_completed", False)
|
||||||
|
|
||||||
|
def mark_setup_completed(self) -> bool:
|
||||||
|
"""
|
||||||
|
标记首次配置已完成
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 是否标记成功
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
config = self._load_config()
|
||||||
|
config["first_setup_completed"] = True
|
||||||
|
config["setup_completed_at"] = self._get_current_timestamp()
|
||||||
|
self._save_config(config)
|
||||||
|
logger.info("首次配置已标记为完成")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"标记首次配置完成失败: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def reset_setup_status(self) -> bool:
|
||||||
|
"""
|
||||||
|
重置首次配置状态,允许重新进入配置向导
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 是否重置成功
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
config = self._load_config()
|
||||||
|
config["first_setup_completed"] = False
|
||||||
|
if "setup_completed_at" in config:
|
||||||
|
del config["setup_completed_at"]
|
||||||
|
self._save_config(config)
|
||||||
|
logger.info("首次配置状态已重置")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"重置首次配置状态失败: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
# 全局单例
|
# 全局单例
|
||||||
_token_manager_instance: Optional[TokenManager] = None
|
_token_manager_instance: Optional[TokenManager] = None
|
||||||
|
|||||||
148
src/webui/webui_server.py
Normal file
148
src/webui/webui_server.py
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
"""独立的 WebUI 服务器 - 运行在 0.0.0.0:8001"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import asyncio
|
||||||
|
import mimetypes
|
||||||
|
from pathlib import Path
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.responses import FileResponse
|
||||||
|
from uvicorn import Config, Server as UvicornServer
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger("webui_server")
|
||||||
|
|
||||||
|
|
||||||
|
class WebUIServer:
|
||||||
|
"""独立的 WebUI 服务器"""
|
||||||
|
|
||||||
|
def __init__(self, host: str = "0.0.0.0", port: int = 8001):
|
||||||
|
self.host = host
|
||||||
|
self.port = port
|
||||||
|
self.app = FastAPI(title="MaiBot WebUI")
|
||||||
|
self._server = None
|
||||||
|
|
||||||
|
# 显示 Access Token
|
||||||
|
self._show_access_token()
|
||||||
|
|
||||||
|
# 重要:先注册 API 路由,再设置静态文件
|
||||||
|
self._register_api_routes()
|
||||||
|
self._setup_static_files()
|
||||||
|
|
||||||
|
def _show_access_token(self):
|
||||||
|
"""显示 WebUI Access Token"""
|
||||||
|
try:
|
||||||
|
from src.webui.token_manager import get_token_manager
|
||||||
|
|
||||||
|
token_manager = get_token_manager()
|
||||||
|
current_token = token_manager.get_token()
|
||||||
|
logger.info(f"🔑 WebUI Access Token: {current_token}")
|
||||||
|
logger.info("💡 请使用此 Token 登录 WebUI")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ 获取 Access Token 失败: {e}")
|
||||||
|
|
||||||
|
def _setup_static_files(self):
|
||||||
|
"""设置静态文件服务"""
|
||||||
|
# 确保正确的 MIME 类型映射
|
||||||
|
mimetypes.init()
|
||||||
|
mimetypes.add_type("application/javascript", ".js")
|
||||||
|
mimetypes.add_type("application/javascript", ".mjs")
|
||||||
|
mimetypes.add_type("text/css", ".css")
|
||||||
|
mimetypes.add_type("application/json", ".json")
|
||||||
|
|
||||||
|
base_dir = Path(__file__).parent.parent.parent
|
||||||
|
static_path = base_dir / "webui" / "dist"
|
||||||
|
|
||||||
|
if not static_path.exists():
|
||||||
|
logger.warning(f"❌ WebUI 静态文件目录不存在: {static_path}")
|
||||||
|
logger.warning("💡 请先构建前端: cd webui && npm run build")
|
||||||
|
return
|
||||||
|
|
||||||
|
if not (static_path / "index.html").exists():
|
||||||
|
logger.warning(f"❌ 未找到 index.html: {static_path / 'index.html'}")
|
||||||
|
logger.warning("💡 请确认前端已正确构建")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 处理 SPA 路由 - 注意:这个路由优先级最低
|
||||||
|
@self.app.get("/{full_path:path}", include_in_schema=False)
|
||||||
|
async def serve_spa(full_path: str):
|
||||||
|
"""服务单页应用 - 只处理非 API 请求"""
|
||||||
|
# 如果是根路径,直接返回 index.html
|
||||||
|
if not full_path or full_path == "/":
|
||||||
|
return FileResponse(static_path / "index.html", media_type="text/html")
|
||||||
|
|
||||||
|
# 检查是否是静态文件
|
||||||
|
file_path = static_path / full_path
|
||||||
|
if file_path.is_file() and file_path.exists():
|
||||||
|
# 自动检测 MIME 类型
|
||||||
|
media_type = mimetypes.guess_type(str(file_path))[0]
|
||||||
|
return FileResponse(file_path, media_type=media_type)
|
||||||
|
|
||||||
|
# 其他路径返回 index.html(SPA 路由)
|
||||||
|
return FileResponse(static_path / "index.html", media_type="text/html")
|
||||||
|
|
||||||
|
logger.info(f"✅ WebUI 静态文件服务已配置: {static_path}")
|
||||||
|
|
||||||
|
def _register_api_routes(self):
|
||||||
|
"""注册所有 WebUI API 路由"""
|
||||||
|
try:
|
||||||
|
# 导入所有 WebUI 路由
|
||||||
|
from src.webui.routes import router as webui_router
|
||||||
|
from src.webui.logs_ws import router as logs_router
|
||||||
|
|
||||||
|
# 注册路由
|
||||||
|
self.app.include_router(webui_router)
|
||||||
|
self.app.include_router(logs_router)
|
||||||
|
|
||||||
|
logger.info("✅ WebUI API 路由已注册")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ 注册 WebUI API 路由失败: {e}")
|
||||||
|
|
||||||
|
async def start(self):
|
||||||
|
"""启动服务器"""
|
||||||
|
config = Config(
|
||||||
|
app=self.app,
|
||||||
|
host=self.host,
|
||||||
|
port=self.port,
|
||||||
|
log_config=None,
|
||||||
|
access_log=False,
|
||||||
|
)
|
||||||
|
self._server = UvicornServer(config=config)
|
||||||
|
|
||||||
|
logger.info("🌐 WebUI 服务器启动中...")
|
||||||
|
logger.info(f"🌐 访问地址: http://{self.host}:{self.port}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self._server.serve()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ WebUI 服务器运行错误: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def shutdown(self):
|
||||||
|
"""关闭服务器"""
|
||||||
|
if self._server:
|
||||||
|
logger.info("正在关闭 WebUI 服务器...")
|
||||||
|
self._server.should_exit = True
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(self._server.shutdown(), timeout=3.0)
|
||||||
|
logger.info("✅ WebUI 服务器已关闭")
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.warning("⚠️ WebUI 服务器关闭超时")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ WebUI 服务器关闭失败: {e}")
|
||||||
|
finally:
|
||||||
|
self._server = None
|
||||||
|
|
||||||
|
|
||||||
|
# 全局 WebUI 服务器实例
|
||||||
|
_webui_server = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_webui_server() -> WebUIServer:
|
||||||
|
"""获取全局 WebUI 服务器实例"""
|
||||||
|
global _webui_server
|
||||||
|
if _webui_server is None:
|
||||||
|
# 从环境变量读取配置
|
||||||
|
host = os.getenv("WEBUI_HOST", "0.0.0.0")
|
||||||
|
port = int(os.getenv("WEBUI_PORT", "8001"))
|
||||||
|
_webui_server = WebUIServer(host=host, port=port)
|
||||||
|
return _webui_server
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
[inner]
|
[inner]
|
||||||
version = "6.21.6"
|
version = "6.21.8"
|
||||||
|
|
||||||
#----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读----
|
#----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读----
|
||||||
#如果你想要修改配置文件,请递增version的值
|
#如果你想要修改配置文件,请递增version的值
|
||||||
@@ -211,6 +211,9 @@ show_prompt = false # 是否显示prompt
|
|||||||
show_replyer_prompt = false # 是否显示回复器prompt
|
show_replyer_prompt = false # 是否显示回复器prompt
|
||||||
show_replyer_reasoning = false # 是否显示回复器推理
|
show_replyer_reasoning = false # 是否显示回复器推理
|
||||||
show_jargon_prompt = false # 是否显示jargon相关提示词
|
show_jargon_prompt = false # 是否显示jargon相关提示词
|
||||||
|
show_memory_prompt = false # 是否显示记忆检索相关提示词
|
||||||
|
show_planner_prompt = false # 是否显示planner的prompt和原始返回结果
|
||||||
|
show_lpmm_paragraph = false # 是否显示lpmm找到的相关文段日志
|
||||||
|
|
||||||
[maim_message]
|
[maim_message]
|
||||||
auth_token = [] # 认证令牌,用于API验证,为空则不启用验证
|
auth_token = [] # 认证令牌,用于API验证,为空则不启用验证
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
|
# 麦麦主程序配置
|
||||||
HOST=127.0.0.1
|
HOST=127.0.0.1
|
||||||
PORT=8000
|
PORT=8000
|
||||||
|
|
||||||
# WebUI 配置
|
# WebUI 独立服务器配置
|
||||||
# WEBUI_ENABLED=true
|
WEBUI_ENABLED=true
|
||||||
# WEBUI_MODE=development # 开发模式(需手动启动前端: cd webui && npm run dev,端口 7999)
|
WEBUI_MODE=production # 模式: development(开发) 或 production(生产)
|
||||||
# WEBUI_MODE=production # 生产模式(需先构建前端: cd webui && npm run build)
|
WEBUI_HOST=0.0.0.0 # WebUI 服务器监听地址
|
||||||
|
WEBUI_PORT=8001 # WebUI 服务器端口
|
||||||
BIN
webui/dist/assets/KaTeX_AMS-Regular-BQhdFMY1.woff2
vendored
Normal file
BIN
webui/dist/assets/KaTeX_AMS-Regular-BQhdFMY1.woff2
vendored
Normal file
Binary file not shown.
BIN
webui/dist/assets/KaTeX_AMS-Regular-DMm9YOAa.woff
vendored
Normal file
BIN
webui/dist/assets/KaTeX_AMS-Regular-DMm9YOAa.woff
vendored
Normal file
Binary file not shown.
BIN
webui/dist/assets/KaTeX_AMS-Regular-DRggAlZN.ttf
vendored
Normal file
BIN
webui/dist/assets/KaTeX_AMS-Regular-DRggAlZN.ttf
vendored
Normal file
Binary file not shown.
BIN
webui/dist/assets/KaTeX_Caligraphic-Bold-ATXxdsX0.ttf
vendored
Normal file
BIN
webui/dist/assets/KaTeX_Caligraphic-Bold-ATXxdsX0.ttf
vendored
Normal file
Binary file not shown.
BIN
webui/dist/assets/KaTeX_Caligraphic-Bold-BEiXGLvX.woff
vendored
Normal file
BIN
webui/dist/assets/KaTeX_Caligraphic-Bold-BEiXGLvX.woff
vendored
Normal file
Binary file not shown.
BIN
webui/dist/assets/KaTeX_Caligraphic-Bold-Dq_IR9rO.woff2
vendored
Normal file
BIN
webui/dist/assets/KaTeX_Caligraphic-Bold-Dq_IR9rO.woff2
vendored
Normal file
Binary file not shown.
BIN
webui/dist/assets/KaTeX_Caligraphic-Regular-CTRA-rTL.woff
vendored
Normal file
BIN
webui/dist/assets/KaTeX_Caligraphic-Regular-CTRA-rTL.woff
vendored
Normal file
Binary file not shown.
BIN
webui/dist/assets/KaTeX_Caligraphic-Regular-Di6jR-x-.woff2
vendored
Normal file
BIN
webui/dist/assets/KaTeX_Caligraphic-Regular-Di6jR-x-.woff2
vendored
Normal file
Binary file not shown.
BIN
webui/dist/assets/KaTeX_Caligraphic-Regular-wX97UBjC.ttf
vendored
Normal file
BIN
webui/dist/assets/KaTeX_Caligraphic-Regular-wX97UBjC.ttf
vendored
Normal file
Binary file not shown.
BIN
webui/dist/assets/KaTeX_Fraktur-Bold-BdnERNNW.ttf
vendored
Normal file
BIN
webui/dist/assets/KaTeX_Fraktur-Bold-BdnERNNW.ttf
vendored
Normal file
Binary file not shown.
BIN
webui/dist/assets/KaTeX_Fraktur-Bold-BsDP51OF.woff
vendored
Normal file
BIN
webui/dist/assets/KaTeX_Fraktur-Bold-BsDP51OF.woff
vendored
Normal file
Binary file not shown.
BIN
webui/dist/assets/KaTeX_Fraktur-Bold-CL6g_b3V.woff2
vendored
Normal file
BIN
webui/dist/assets/KaTeX_Fraktur-Bold-CL6g_b3V.woff2
vendored
Normal file
Binary file not shown.
BIN
webui/dist/assets/KaTeX_Fraktur-Regular-CB_wures.ttf
vendored
Normal file
BIN
webui/dist/assets/KaTeX_Fraktur-Regular-CB_wures.ttf
vendored
Normal file
Binary file not shown.
BIN
webui/dist/assets/KaTeX_Fraktur-Regular-CTYiF6lA.woff2
vendored
Normal file
BIN
webui/dist/assets/KaTeX_Fraktur-Regular-CTYiF6lA.woff2
vendored
Normal file
Binary file not shown.
BIN
webui/dist/assets/KaTeX_Fraktur-Regular-Dxdc4cR9.woff
vendored
Normal file
BIN
webui/dist/assets/KaTeX_Fraktur-Regular-Dxdc4cR9.woff
vendored
Normal file
Binary file not shown.
BIN
webui/dist/assets/KaTeX_Main-Bold-Cx986IdX.woff2
vendored
Normal file
BIN
webui/dist/assets/KaTeX_Main-Bold-Cx986IdX.woff2
vendored
Normal file
Binary file not shown.
BIN
webui/dist/assets/KaTeX_Main-Bold-Jm3AIy58.woff
vendored
Normal file
BIN
webui/dist/assets/KaTeX_Main-Bold-Jm3AIy58.woff
vendored
Normal file
Binary file not shown.
BIN
webui/dist/assets/KaTeX_Main-Bold-waoOVXN0.ttf
vendored
Normal file
BIN
webui/dist/assets/KaTeX_Main-Bold-waoOVXN0.ttf
vendored
Normal file
Binary file not shown.
BIN
webui/dist/assets/KaTeX_Main-BoldItalic-DxDJ3AOS.woff2
vendored
Normal file
BIN
webui/dist/assets/KaTeX_Main-BoldItalic-DxDJ3AOS.woff2
vendored
Normal file
Binary file not shown.
BIN
webui/dist/assets/KaTeX_Main-BoldItalic-DzxPMmG6.ttf
vendored
Normal file
BIN
webui/dist/assets/KaTeX_Main-BoldItalic-DzxPMmG6.ttf
vendored
Normal file
Binary file not shown.
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user