Merge branch 'Mai-with-u:main' into feat-lpmm知识库加强
This commit is contained in:
161
.github/workflows/docker-image-dev.yml
vendored
Normal file
161
.github/workflows/docker-image-dev.yml
vendored
Normal file
@@ -0,0 +1,161 @@
|
||||
name: Docker Build and Push (Dev)
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: '0 0 * * *'
|
||||
# push:
|
||||
# branches:
|
||||
# - dev
|
||||
workflow_dispatch: # 允许手动触发工作流
|
||||
inputs:
|
||||
branch:
|
||||
description: 'Branch to build'
|
||||
required: false
|
||||
default: 'dev'
|
||||
|
||||
# Workflow's jobs
|
||||
jobs:
|
||||
build-amd64:
|
||||
name: Build AMD64 Image
|
||||
runs-on: ubuntu-24.04
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
steps:
|
||||
- name: Check out git repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: dev
|
||||
fetch-depth: 0
|
||||
|
||||
# Clone required dependencies
|
||||
- name: Clone maim_message
|
||||
run: git clone https://github.com/MaiM-with-u/maim_message maim_message
|
||||
|
||||
- name: Clone lpmm
|
||||
run: git clone https://github.com/MaiM-with-u/MaiMBot-LPMM.git MaiMBot-LPMM
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
with:
|
||||
buildkitd-flags: --debug
|
||||
|
||||
# Log in docker hub
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
# Generate metadata for Docker images
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ secrets.DOCKERHUB_USERNAME }}/maibot
|
||||
|
||||
# Build and push AMD64 image by digest
|
||||
- name: Build and push AMD64
|
||||
id: build
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
platforms: linux/amd64
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
file: ./Dockerfile
|
||||
cache-from: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/maibot:dev-amd64-buildcache
|
||||
cache-to: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/maibot:dev-amd64-buildcache,mode=max
|
||||
outputs: type=image,name=${{ secrets.DOCKERHUB_USERNAME }}/maibot,push-by-digest=true,name-canonical=true,push=true
|
||||
build-args: |
|
||||
BUILD_DATE=$(date -u +'%Y-%m-%dT%H:%M:%SZ')
|
||||
VCS_REF=${{ github.sha }}
|
||||
|
||||
build-arm64:
|
||||
name: Build ARM64 Image
|
||||
runs-on: ubuntu-24.04-arm
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
steps:
|
||||
- name: Check out git repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: dev
|
||||
fetch-depth: 0
|
||||
|
||||
# Clone required dependencies
|
||||
- name: Clone maim_message
|
||||
run: git clone https://github.com/MaiM-with-u/maim_message maim_message
|
||||
|
||||
- name: Clone lpmm
|
||||
run: git clone https://github.com/MaiM-with-u/MaiMBot-LPMM.git MaiMBot-LPMM
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
with:
|
||||
buildkitd-flags: --debug
|
||||
|
||||
# Log in docker hub
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
# Generate metadata for Docker images
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ secrets.DOCKERHUB_USERNAME }}/maibot
|
||||
|
||||
# Build and push ARM64 image by digest
|
||||
- name: Build and push ARM64
|
||||
id: build
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
platforms: linux/arm64/v8
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
file: ./Dockerfile
|
||||
cache-from: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/maibot:dev-arm64-buildcache
|
||||
cache-to: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/maibot:dev-arm64-buildcache,mode=max
|
||||
outputs: type=image,name=${{ secrets.DOCKERHUB_USERNAME }}/maibot,push-by-digest=true,name-canonical=true,push=true
|
||||
build-args: |
|
||||
BUILD_DATE=$(date -u +'%Y-%m-%dT%H:%M:%SZ')
|
||||
VCS_REF=${{ github.sha }}
|
||||
|
||||
create-manifest:
|
||||
name: Create Multi-Arch Manifest
|
||||
runs-on: ubuntu-24.04
|
||||
needs:
|
||||
- build-amd64
|
||||
- build-arm64
|
||||
steps:
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
# Log in docker hub
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
# Generate metadata for Docker images
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ secrets.DOCKERHUB_USERNAME }}/maibot
|
||||
tags: |
|
||||
type=ref,event=branch
|
||||
type=sha,prefix=${{ github.ref_name }}-,enable=${{ github.ref_type == 'branch' }}
|
||||
|
||||
- name: Create and Push Manifest
|
||||
run: |
|
||||
# 为每个标签创建多架构镜像
|
||||
for tag in $(echo "${{ steps.meta.outputs.tags }}" | tr '\n' ' '); do
|
||||
echo "Creating manifest for $tag"
|
||||
docker buildx imagetools create -t $tag \
|
||||
${{ secrets.DOCKERHUB_USERNAME }}/maibot@${{ needs.build-amd64.outputs.digest }} \
|
||||
${{ secrets.DOCKERHUB_USERNAME }}/maibot@${{ needs.build-arm64.outputs.digest }}
|
||||
done
|
||||
@@ -1,8 +1,6 @@
|
||||
name: Docker Build and Push
|
||||
name: Docker Build and Push (Main)
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: '0 0 * * *'
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
@@ -13,6 +11,11 @@ on:
|
||||
- "*.*.*"
|
||||
- "*.*.*-*"
|
||||
workflow_dispatch: # 允许手动触发工作流
|
||||
inputs:
|
||||
branch:
|
||||
description: 'Branch to build'
|
||||
required: false
|
||||
default: 'main'
|
||||
|
||||
# Workflow's jobs
|
||||
jobs:
|
||||
@@ -25,7 +28,6 @@ jobs:
|
||||
- name: Check out git repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.event_name == 'schedule' && 'dev' || github.ref }}
|
||||
fetch-depth: 0
|
||||
|
||||
# Clone required dependencies
|
||||
@@ -79,7 +81,6 @@ jobs:
|
||||
- name: Check out git repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.event_name == 'schedule' && 'dev' || github.ref }}
|
||||
fetch-depth: 0
|
||||
|
||||
# Clone required dependencies
|
||||
@@ -164,4 +165,4 @@ jobs:
|
||||
docker buildx imagetools create -t $tag \
|
||||
${{ secrets.DOCKERHUB_USERNAME }}/maibot@${{ needs.build-amd64.outputs.digest }} \
|
||||
${{ secrets.DOCKERHUB_USERNAME }}/maibot@${{ needs.build-arm64.outputs.digest }}
|
||||
done
|
||||
done
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -25,6 +25,7 @@ run_na.bat
|
||||
run_all_in_wt.bat
|
||||
run.bat
|
||||
log_debug/
|
||||
NapCat.Shell.Windows.OneKey
|
||||
run_amds.bat
|
||||
run_none.bat
|
||||
docs-mai/
|
||||
@@ -333,4 +334,5 @@ config.toml
|
||||
|
||||
interested_rates.txt
|
||||
MaiBot.code-workspace
|
||||
*.lock
|
||||
*.lock
|
||||
actionlint
|
||||
|
||||
151
EULA.md
151
EULA.md
@@ -1,8 +1,9 @@
|
||||
# **MaiBot最终用户许可协议**
|
||||
**版本:V1.1**
|
||||
**更新日期:2025年7月10日**
|
||||
**生效日期:2025年3月18日**
|
||||
**适用的MaiBot版本号:所有版本**
|
||||
# **MaiBot最终用户许可协议**
|
||||
|
||||
**版本:V1.2**
|
||||
**更新日期:2025年12月01日**
|
||||
**生效日期:2025年12月01日**
|
||||
**适用的MaiBot版本号:所有版本**
|
||||
|
||||
**2025© MaiBot项目团队**
|
||||
|
||||
@@ -14,130 +15,120 @@
|
||||
|
||||
**1.2** 在运行或使用本项目之前,您**必须阅读并同意本协议的所有条款**。未成年人或其它无/不完全民事行为能力责任人请**在监护人的陪同下**阅读并同意本协议。如果您不同意,则不得运行或使用本项目。在这种情况下,您应立即从您的设备上卸载或删除本项目及其所有副本。
|
||||
|
||||
|
||||
## 二、许可授权
|
||||
|
||||
### 源代码许可
|
||||
|
||||
**2.1** 您**了解**本项目的源代码是基于GPLv3(GNU通用公共许可证第三版)开源协议发布的。您**可以自由使用、修改、分发**本项目的源代码,但**必须遵守**GPLv3许可证的要求。详细内容请参阅项目仓库中的LICENSE文件。
|
||||
|
||||
**2.2** 您**了解**本项目的源代码中可能包含第三方开源代码,这些代码的许可证可能与GPLv3许可证不同。您**同意**在使用这些代码时**遵守**相应的许可证要求。
|
||||
|
||||
**2.2** 您**了解**本项目的源代码中可能包含第三方开源代码,这些代码的许可证可能与GPLv3许可证不同。您**同意**在使用这些代码时**遵守**相应的许可证要求.
|
||||
|
||||
### 输入输出内容授权
|
||||
|
||||
**2.3** 您**了解**本项目是使用您的配置信息、提交的指令(以下简称“输入内容”)和生成的内容(以下简称“输出内容”)构建请求发送到第三方API生成回复的机器人项目。
|
||||
|
||||
**2.4** 您**了解**本项目是使用您的配置信息、提交的指令(以下简称“输入内容”)和生成的内容(以下简称“输出内容”)构建请求发送到第三方生成回复的机器人项目。
|
||||
**2.4** 您**授权**本项目使用您的输入和输出内容按照项目的隐私政策用于以下行为:
|
||||
- 调用第三方API生成回复;
|
||||
- 调用第三方API用于构建本项目专用的存储于您部署或使用的数据库中的知识库和记忆库;
|
||||
- 收集并记录本项目专用的存储于您部署或使用的设备中的日志;
|
||||
|
||||
- 调用第三方API生成回复;
|
||||
- 调用第三方API用于构建本项目专用的存储于您使用的数据库中的知识库和记忆库;
|
||||
- 调用第三方开发的插件系统功能;
|
||||
- 收集并记录本项目专用的存储于您使用的设备中的日志;
|
||||
|
||||
**2.4** 您**了解**本项目的源代码中包含第三方API的调用代码,这些API的使用可能受到第三方的服务条款和隐私政策的约束。在使用这些API时,您**必须遵守**相应的服务条款。
|
||||
|
||||
**2.5** 项目团队**不对**第三方API的服务质量、稳定性、准确性、安全性负责,亦**不对**第三方API的服务变更、终止、限制等行为负责。
|
||||
|
||||
|
||||
### 插件系统授权和责任免责
|
||||
|
||||
**2.6** 您**了解**本项目包含插件系统功能,允许加载和使用由第三方开发者(非MaiBot核心开发组成员)开发的插件。这些第三方插件可能具有独立的许可证条款和使用协议。
|
||||
|
||||
**2.7** 您**了解并同意**:
|
||||
- 第三方插件的开发、维护、分发由其各自的开发者负责,**与MaiBot项目团队无关**;
|
||||
- 第三方插件的功能、质量、安全性、合规性**完全由插件开发者负责**;
|
||||
- MaiBot项目团队**仅提供**插件系统的技术框架,**不对**任何第三方插件的内容、行为或后果承担责任;
|
||||
- 您使用任何第三方插件的风险**完全由您自行承担**;
|
||||
|
||||
**2.8** 在使用第三方插件前,您**应当**:
|
||||
- 仔细阅读并遵守插件开发者提供的许可证条款和使用协议;
|
||||
- 自行评估插件的安全性、合规性和适用性;
|
||||
- 确保插件的使用符合您所在地区的法律法规要求;
|
||||
|
||||
|
||||
## 三、用户行为
|
||||
|
||||
**3.1** 您**了解**本项目会将您的配置信息、输入指令和生成内容发送到第三方API,您**不应**在输入指令和生成内容中包含以下内容:
|
||||
- 涉及任何国家或地区秘密、商业秘密或其他可能会对国家或地区安全或者公共利益造成不利影响的数据;
|
||||
- 涉及个人隐私、个人信息或其他敏感信息的数据;
|
||||
- 任何侵犯他人合法权益的内容;
|
||||
- 任何违反国家或地区法律法规、政策规定的内容;
|
||||
**3.1** 您**了解**本项目会将您的配置信息、输入指令和生成内容发送到第三方,您**不应**在输入指令和生成内容中包含以下内容:
|
||||
|
||||
- 涉及任何国家或地区秘密、商业秘密或其他可能会对国家或地区安全或者公共利益造成不利影响的数据;
|
||||
- 涉及个人隐私、个人信息或其他敏感信息的数据;
|
||||
- 任何侵犯他人合法权益的内容;
|
||||
- 任何违反国家或地区法律法规、政策规定的内容;
|
||||
|
||||
**3.2** 您**不应**将本项目用于以下用途:
|
||||
- 违反任何国家或地区法律法规、政策规定的行为;
|
||||
|
||||
- 违反任何国家或地区法律法规、政策规定的行为;
|
||||
|
||||
**3.3** 您**应当**自行确保您被存储在本项目的知识库、记忆库和日志中的输入和输出内容的合法性与合规性以及存储行为的合法性与合规性。您需**自行承担**由此产生的任何法律责任。
|
||||
|
||||
**3.4** 对于第三方插件的使用,您**不应**:
|
||||
- 使用可能存在安全漏洞、恶意代码或违法内容的插件;
|
||||
- 通过插件进行任何违反法律法规的行为;
|
||||
- 将插件用于侵犯他人权益或危害系统安全的用途;
|
||||
|
||||
**3.5** 您**承诺**对使用第三方插件的行为及其后果承担**完全责任**,包括但不限于因插件缺陷、恶意行为或不当使用造成的任何损失或法律纠纷。
|
||||
- 安装、使用任何来源不明或未经验证的第三方插件;
|
||||
- 使用任何违反法律法规、政策规定或第三方平台规则的第三方插件;
|
||||
|
||||
**3.5** 您**应当**自行确保您安装和使用的第三方插件的合法性与合规性以及安装和使用行为的合法性与合规性。您需**自行承担**由此产生的任何法律责任。
|
||||
|
||||
**3.6** 由于本项目会将您的输入指令和生成内容发送到第三方,当您将本项目用于第三方交流环境(如与除您以外的人私聊、群聊、论坛、直播等)时,您**应当**事先明确告知其他交流参与者本项目的使用情况,包括但不限于:
|
||||
|
||||
- 本项目的输出内容是由人工智能生成的;
|
||||
- 本项目会将交流内容发送到第三方;
|
||||
- 本项目的隐私政策和用户行为要求;
|
||||
|
||||
您需**自行承担**由此产生的任何后果和法律责任。
|
||||
|
||||
**3.7** 项目团队**不鼓励**也**不支持**将本项目用于商业用途,但若您确实需要将本项目用于商业用途,您**应当**标明项目地址(如“本项目由MaiBot(<https://github.com/Mai-with-u/MaiBot>)驱动”),并**自行承担**由此产生的任何法律责任。
|
||||
|
||||
## 四、免责条款
|
||||
|
||||
**4.1** 本项目的输出内容依赖第三方API,**不受**项目团队控制,亦**不代表**项目团队的观点。
|
||||
|
||||
**4.2** 除本协议条目2.4提到的隐私政策之外,项目团队**不会**对您提供任何形式的担保,亦**不对**使用本项目的造成的任何后果负责。
|
||||
**4.2** 除本协议条目2.4提到的隐私政策之外,项目团队**不会**对您提供任何形式的担保,亦**不对**使用本项目的造成的任何直接或间接后果负责。
|
||||
|
||||
**4.3** 关于第三方插件,项目团队**明确声明**:
|
||||
- 项目团队**不对**任何第三方插件的功能、安全性、稳定性、合规性或适用性提供任何形式的保证或担保;
|
||||
- 项目团队**不对**因使用第三方插件而产生的任何直接或间接损失、数据丢失、系统故障、安全漏洞、法律纠纷或其他后果承担责任;
|
||||
- 第三方插件的质量问题、技术支持、bug修复等事宜应**直接联系插件开发者**,与项目团队无关;
|
||||
- 项目团队**保留**在不另行通知的情况下,对插件系统功能进行修改、限制或移除的权利;
|
||||
**4.3** 关于第三方插件,项目团队**声明**:
|
||||
|
||||
- 项目团队**不对**任何第三方插件的功能、安全性、稳定性、合规性或适用性提供任何形式的保证或担保;
|
||||
- 项目团队**不对**因使用第三方插件而产生的任何直接或间接后果承担责任;
|
||||
- 项目团队**不对**第三方插件的质量问题、技术支持、bug修复等事宜负责。如有相关问题,应**直接联系插件开发者**;
|
||||
|
||||
## 五、其他条款
|
||||
|
||||
**5.1** 项目团队有权**随时修改本协议的条款**,但**没有**义务通知您。修改后的协议将在本项目的新版本中生效,您应定期检查本协议的最新版本。
|
||||
**5.1** 项目团队有权**随时修改本协议的条款**,但**无义务**通知您。修改后的协议将在本项目的新版本中推送,您应定期检查本协议的最新版本。
|
||||
|
||||
**5.2** 项目团队**保留**本协议的最终解释权。
|
||||
|
||||
|
||||
## 附录:其他重要须知
|
||||
|
||||
### 一、过往版本使用条件追溯
|
||||
### 一、风险提示
|
||||
|
||||
**1.1** 对于本项目此前未配备 EULA 协议的版本,自本协议发布之日起,若用户希望继续使用本项目,应在本协议生效后的合理时间内,通过升级到最新版本并同意本协议全部条款。若在本版协议生效日(2025年3月18日)之后,用户仍使用此前无 EULA 协议的项目版本且未同意本协议,则用户无权继续使用,项目方有权采取措施阻止其使用行为,并保留追究相关法律责任的权利。
|
||||
**1.1** 隐私安全风险
|
||||
|
||||
- 本项目会将您的配置信息、输入指令和生成内容发送到第三方API,而这些API的服务质量、稳定性、准确性、安全性不受项目团队控制。
|
||||
- 本项目会收集您的输入和输出内容,用于构建本项目专用的知识库和记忆库,以提高回复的准确性和连贯性。
|
||||
|
||||
### 二、风险提示
|
||||
**因此,为了保障您的隐私信息安全,请注意以下事项:**
|
||||
|
||||
**2.1 隐私安全风险**
|
||||
- 避免在涉及个人隐私、个人信息或其他敏感信息的环境中使用本项目;
|
||||
- 避免在不可信的环境中使用本项目;
|
||||
|
||||
- 本项目会将您的配置信息、输入指令和生成内容发送到第三方API,而这些API的服务质量、稳定性、准确性、安全性不受项目团队控制。
|
||||
- 本项目会收集您的输入和输出内容,用于构建本项目专用的知识库和记忆库,以提高回复的准确性和连贯性。
|
||||
|
||||
**因此,为了保障您的隐私信息安全,请注意以下事项:**
|
||||
|
||||
- 避免在涉及个人隐私、个人信息或其他敏感信息的环境中使用本项目;
|
||||
- 避免在不可信的环境中使用本项目;
|
||||
|
||||
**2.2 精神健康风险**
|
||||
**1.2** 精神健康风险
|
||||
|
||||
本项目仅为工具型机器人,不具备情感交互能力。建议用户:
|
||||
- 避免过度依赖AI回复处理现实问题或情绪困扰;
|
||||
- 如感到心理不适,请及时寻求专业心理咨询服务。
|
||||
- 如遇心理困扰,请寻求专业帮助(全国心理援助热线:12355)。
|
||||
|
||||
**2.3 第三方插件风险**
|
||||
- 避免过度依赖AI回复处理现实问题或情绪困扰;
|
||||
- 如感到心理不适,请及时寻求专业心理咨询服务;
|
||||
- 如遇心理困扰,请寻求专业帮助(全国心理援助热线:12355);
|
||||
|
||||
**1.3** 第三方插件风险
|
||||
|
||||
本项目的插件系统允许加载第三方开发的插件,这可能带来以下风险:
|
||||
- **安全风险**:第三方插件可能包含恶意代码、安全漏洞或未知的安全威胁;
|
||||
- **稳定性风险**:插件可能导致系统崩溃、性能下降或功能异常;
|
||||
- **隐私风险**:插件可能收集、传输或泄露您的个人信息和数据;
|
||||
- **合规风险**:插件的功能或行为可能违反相关法律法规或平台规则;
|
||||
- **兼容性风险**:插件可能与主程序或其他插件产生冲突;
|
||||
|
||||
**因此,在使用第三方插件时,请务必:**
|
||||
- **安全风险**:第三方插件可能包含恶意代码、安全漏洞或未知的安全威胁;
|
||||
- **稳定性风险**:插件可能导致系统崩溃、性能下降或功能异常;
|
||||
- **隐私风险**:插件可能收集、传输或泄露您的个人信息和数据;
|
||||
- **合规风险**:插件的功能或行为可能违反相关法律法规或平台规则;
|
||||
- **兼容性风险**:插件可能与主程序或其他插件产生冲突;
|
||||
|
||||
- 仅从可信来源获取和安装插件;
|
||||
- 在安装前仔细了解插件的功能、权限和开发者信息;
|
||||
- 定期检查和更新已安装的插件;
|
||||
- 如发现插件异常行为,请立即停止使用并卸载;
|
||||
- 对插件的使用后果承担完全责任;
|
||||
**因此,在使用第三方插件时,请务必:**
|
||||
|
||||
### 三、其他
|
||||
**3.1 争议解决**
|
||||
- 本协议适用中国法律,争议提交相关地区法院管辖;
|
||||
- 若因GPLv3许可产生纠纷,以许可证官方解释为准。
|
||||
- 仅从可信来源获取和安装插件;
|
||||
- 在安装前仔细了解插件的功能、权限和开发者信息;
|
||||
- 定期检查和更新已安装的插件;
|
||||
- 如发现插件异常行为,请立即停止使用并卸载;
|
||||
|
||||
### 二、其他
|
||||
|
||||
**2.1** 争议解决
|
||||
|
||||
- 本协议适用中国法律,争议提交相关地区法院管辖;
|
||||
- 若因GPLv3许可产生纠纷,以许可证官方解释为准。
|
||||
|
||||
19
README.md
19
README.md
@@ -46,16 +46,20 @@
|
||||
## 🔥 更新和安装
|
||||
|
||||
|
||||
**最新版本: v0.11.5** ([更新日志](changelogs/changelog.md))
|
||||
**最新版本: v0.11.6** ([更新日志](changelogs/changelog.md))
|
||||
|
||||
|
||||
可前往 [Release](https://github.com/MaiM-with-u/MaiBot/releases/) 页面下载最新版本
|
||||
|
||||
可前往 [启动器发布页面](https://github.com/MaiM-with-u/mailauncher/releases/)下载最新启动器
|
||||
|
||||
注意,启动器处于早期开发版本,仅支持MacOS
|
||||
|
||||
**GitHub 分支说明:**
|
||||
- `main`: 稳定发布版本(推荐)
|
||||
|
||||
|
||||
- `dev`: 开发测试版本(不稳定)
|
||||
|
||||
- `classical`: 经典版本(停止维护)
|
||||
|
||||
### 最新版本部署教程
|
||||
@@ -69,18 +73,23 @@
|
||||
|
||||
## 💬 讨论
|
||||
|
||||
**技术交流群:**
|
||||
**技术交流群/答疑群:**
|
||||
[麦麦脑电图](https://qm.qq.com/q/RzmCiRtHEW) |
|
||||
[麦麦脑磁图](https://qm.qq.com/q/wlH5eT8OmQ) |
|
||||
[麦麦大脑磁共振](https://qm.qq.com/q/VQ3XZrWgMs) |
|
||||
[麦麦要当VTB](https://qm.qq.com/q/wGePTl1UyY)
|
||||
[麦麦要当VTB](https://qm.qq.com/q/wGePTl1UyY) |
|
||||
|
||||
为了维持技术交流和互帮互助的氛围,请不要在技术交流群讨论过多无关内容~
|
||||
|
||||
**聊天吹水群:**
|
||||
- [麦麦之闲聊群](https://qm.qq.com/q/JxvHZnxyec)
|
||||
|
||||
麦麦相关闲聊群,此群仅用于聊天,提问部署/技术问题可能不会快速得到答案
|
||||
|
||||
**插件开发/测试版讨论群:**
|
||||
- [插件开发群](https://qm.qq.com/q/1036092828)
|
||||
|
||||
进阶内容,包括插件开发,测试版使用等等
|
||||
|
||||
## 📚 文档
|
||||
|
||||
**部分内容可能更新不够及时,请注意版本对应**
|
||||
|
||||
131
bot.py
131
bot.py
@@ -5,16 +5,22 @@ import time
|
||||
import platform
|
||||
import traceback
|
||||
import shutil
|
||||
import sys
|
||||
import subprocess
|
||||
from dotenv import load_dotenv
|
||||
from pathlib import Path
|
||||
from rich.traceback import install
|
||||
from src.common.logger import initialize_logging, get_logger, shutdown_logging
|
||||
|
||||
# 设置工作目录为脚本所在目录
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
os.chdir(script_dir)
|
||||
|
||||
env_path = Path(__file__).parent / ".env"
|
||||
template_env_path = Path(__file__).parent / "template" / "template.env"
|
||||
|
||||
if env_path.exists():
|
||||
load_dotenv(str(env_path), override=True)
|
||||
print("成功加载环境变量配置")
|
||||
else:
|
||||
try:
|
||||
if template_env_path.exists():
|
||||
@@ -28,23 +34,86 @@ else:
|
||||
print(f"自动创建 .env 失败: {e}")
|
||||
raise
|
||||
|
||||
# 最早期初始化日志系统,确保所有后续模块都使用正确的日志格式
|
||||
from src.common.logger import initialize_logging, get_logger, shutdown_logging # noqa
|
||||
|
||||
initialize_logging()
|
||||
install(extra_lines=3)
|
||||
logger = get_logger("main")
|
||||
|
||||
# 定义重启退出码
|
||||
RESTART_EXIT_CODE = 42
|
||||
|
||||
def run_runner_process():
|
||||
"""
|
||||
Runner 进程逻辑:作为守护进程运行,负责启动和监控 Worker 进程。
|
||||
处理重启请求 (退出码 42) 和 Ctrl+C 信号。
|
||||
"""
|
||||
script_file = sys.argv[0]
|
||||
python_executable = sys.executable
|
||||
|
||||
# 设置环境变量,标记子进程为 Worker 进程
|
||||
env = os.environ.copy()
|
||||
env["MAIBOT_WORKER_PROCESS"] = "1"
|
||||
|
||||
while True:
|
||||
logger.info(f"正在启动 {script_file}...")
|
||||
|
||||
# 启动子进程 (Worker)
|
||||
# 使用 sys.executable 确保使用相同的 Python 解释器
|
||||
cmd = [python_executable, script_file] + sys.argv[1:]
|
||||
|
||||
process = subprocess.Popen(cmd, env=env)
|
||||
|
||||
try:
|
||||
# 等待子进程结束
|
||||
return_code = process.wait()
|
||||
|
||||
if return_code == RESTART_EXIT_CODE:
|
||||
logger.info("检测到重启请求 (退出码 42),正在重启...")
|
||||
time.sleep(1) # 稍作等待
|
||||
continue
|
||||
else:
|
||||
logger.info(f"程序已退出 (退出码 {return_code})")
|
||||
sys.exit(return_code)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
# 向子进程发送终止信号
|
||||
if process.poll() is None:
|
||||
# 在 Windows 上,Ctrl+C 通常已经发送给了子进程(如果它们共享控制台)
|
||||
# 但为了保险,我们可以尝试 terminate
|
||||
try:
|
||||
process.terminate()
|
||||
process.wait(timeout=5)
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.warning("子进程未响应,强制关闭...")
|
||||
process.kill()
|
||||
sys.exit(0)
|
||||
|
||||
# 检查是否是 Worker 进程
|
||||
# 如果没有设置 MAIBOT_WORKER_PROCESS 环境变量,说明是直接运行的脚本,
|
||||
# 此时应该作为 Runner 运行。
|
||||
if os.environ.get("MAIBOT_WORKER_PROCESS") != "1":
|
||||
if __name__ == "__main__":
|
||||
run_runner_process()
|
||||
# 如果作为模块导入,不执行 Runner 逻辑,但也不应该执行下面的 Worker 逻辑
|
||||
sys.exit(0)
|
||||
|
||||
# 以下是 Worker 进程的逻辑
|
||||
|
||||
# 最早期初始化日志系统,确保所有后续模块都使用正确的日志格式
|
||||
# from src.common.logger import initialize_logging, get_logger, shutdown_logging # noqa
|
||||
# initialize_logging()
|
||||
|
||||
from src.main import MainSystem # noqa
|
||||
from src.manager.async_task_manager import async_task_manager # noqa
|
||||
|
||||
|
||||
logger = get_logger("main")
|
||||
# logger = get_logger("main")
|
||||
|
||||
|
||||
install(extra_lines=3)
|
||||
# install(extra_lines=3)
|
||||
|
||||
# 设置工作目录为脚本所在目录
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
os.chdir(script_dir)
|
||||
# script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
# os.chdir(script_dir)
|
||||
logger.info(f"已设置工作目录为: {script_dir}")
|
||||
|
||||
|
||||
@@ -58,6 +127,33 @@ app = None
|
||||
loop = None
|
||||
|
||||
|
||||
def print_opensource_notice():
|
||||
"""打印开源项目提示,防止倒卖"""
|
||||
from colorama import init, Fore, Style
|
||||
|
||||
init()
|
||||
|
||||
notice_lines = [
|
||||
"",
|
||||
f"{Fore.CYAN}{'═' * 70}{Style.RESET_ALL}",
|
||||
f"{Fore.GREEN} ★ MaiBot - 开源 AI 聊天机器人 ★{Style.RESET_ALL}",
|
||||
f"{Fore.CYAN}{'─' * 70}{Style.RESET_ALL}",
|
||||
f"{Fore.YELLOW} 本项目是完全免费的开源软件,基于 GPL-3.0 协议发布{Style.RESET_ALL}",
|
||||
f"{Fore.WHITE} 如果有人向你「出售本软件」,你被骗了!{Style.RESET_ALL}",
|
||||
"",
|
||||
f"{Fore.WHITE} 官方仓库: {Fore.BLUE}https://github.com/MaiM-with-u/MaiBot {Style.RESET_ALL}",
|
||||
f"{Fore.WHITE} 官方文档: {Fore.BLUE}https://docs.mai-mai.org {Style.RESET_ALL}",
|
||||
f"{Fore.WHITE} 官方群聊: {Fore.BLUE}766798517{Style.RESET_ALL}",
|
||||
f"{Fore.CYAN}{'─' * 70}{Style.RESET_ALL}",
|
||||
f"{Fore.RED} ⚠ 将本软件作为「商品」倒卖、隐瞒开源性质均违反协议!{Style.RESET_ALL}",
|
||||
f"{Fore.CYAN}{'═' * 70}{Style.RESET_ALL}",
|
||||
"",
|
||||
]
|
||||
|
||||
for line in notice_lines:
|
||||
print(line)
|
||||
|
||||
|
||||
def easter_egg():
|
||||
# 彩蛋
|
||||
from colorama import init, Fore
|
||||
@@ -78,6 +174,7 @@ async def graceful_shutdown(): # sourcery skip: use-named-expression
|
||||
# 关闭 WebUI 服务器
|
||||
try:
|
||||
from src.webui.webui_server import get_webui_server
|
||||
|
||||
webui_server = get_webui_server()
|
||||
if webui_server and webui_server._server:
|
||||
await webui_server.shutdown()
|
||||
@@ -202,6 +299,9 @@ def raw_main():
|
||||
if platform.system().lower() != "windows":
|
||||
time.tzset() # type: ignore
|
||||
|
||||
# 打印开源提示(防止倒卖)
|
||||
print_opensource_notice()
|
||||
|
||||
check_eula()
|
||||
logger.info("检查EULA和隐私条款完成")
|
||||
|
||||
@@ -236,15 +336,15 @@ if __name__ == "__main__":
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.warning("收到中断信号,正在优雅关闭...")
|
||||
|
||||
|
||||
# 取消主任务
|
||||
if 'main_tasks' in locals() and main_tasks and not main_tasks.done():
|
||||
if "main_tasks" in locals() and main_tasks and not main_tasks.done():
|
||||
main_tasks.cancel()
|
||||
try:
|
||||
loop.run_until_complete(main_tasks)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
|
||||
# 执行优雅关闭
|
||||
if loop and not loop.is_closed():
|
||||
try:
|
||||
@@ -253,6 +353,15 @@ if __name__ == "__main__":
|
||||
logger.error(f"优雅关闭时发生错误: {ge}")
|
||||
# 新增:检测外部请求关闭
|
||||
|
||||
except SystemExit as e:
|
||||
# 捕获 SystemExit (例如 sys.exit()) 并保留退出代码
|
||||
if isinstance(e.code, int):
|
||||
exit_code = e.code
|
||||
else:
|
||||
exit_code = 1 if e.code else 0
|
||||
if exit_code == RESTART_EXIT_CODE:
|
||||
logger.info("收到重启信号,准备退出并请求重启...")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"主程序发生异常: {str(e)} {str(traceback.format_exc())}")
|
||||
exit_code = 1 # 标记发生错误
|
||||
|
||||
@@ -1,17 +1,52 @@
|
||||
# Changelog
|
||||
|
||||
## [0.11.5] - 2025-11-26
|
||||
### 主要功能更改
|
||||
|
||||
|
||||
## [0.11.6] - 2025-12-2
|
||||
### 🌟 重大更新
|
||||
- 大幅提高记忆检索能力,略微提高token消耗
|
||||
- 重构历史消息概括器,更好的主题记忆
|
||||
- 日志查看器性能革命性优化
|
||||
- 支持可视化查看麦麦LPMM知识图谱
|
||||
- 支持根据不同的模型提供商/模板/URL自动获取模型,可以不用手动输入模型了
|
||||
- 新增Baka引导系统,使用React-JoyTour实现很棒的用户引导系统(让Baka也能看懂!)
|
||||
- 本地聊天室功能!!你可以直接在WebUI网页和麦麦聊天!!
|
||||
- 使用cookie模式替换原有的LocalStorage Token存储,可能需要重新手动输入一遍Token
|
||||
- WebUI本地聊天室支持用户模拟和平台模拟的功能!
|
||||
- WebUI新增黑话管理 & 编辑界面
|
||||
|
||||
### 细节功能更改
|
||||
- 可选记忆识别中是否启用jargon
|
||||
- 解耦表情包识别和图片识别
|
||||
- 修复部分破损json的解析问题
|
||||
- 黑话更高的提取效率,增加提取准确性
|
||||
- 升级jargon,更快更精准
|
||||
- 新增Lpmm可视化
|
||||
|
||||
### webui细节更新
|
||||
- 修复侧边栏收起、UI及表格横向滚动等问题,优化Toast动画
|
||||
- 修复适配器配置、插件克隆、表情包注册等相关BUG
|
||||
- 新增适配器/模型预设模式及模板,自动填写URL和类型
|
||||
- 支持模型任务列表拖拽排序
|
||||
- 更新重启弹窗和首次引导内容
|
||||
- 多处界面命名及标题优化,如模型配置相关菜单重命名和描述更新
|
||||
- 修复聊天配置“提及回复”相关开关命名错误
|
||||
- 调试配置新增“显示记忆/Planner/LPMM Prompt”选项
|
||||
- 新增卡片尺寸、排序、字号、行间距等个性化功能
|
||||
- 聊天ID及群聊选择优化,显示可读名称
|
||||
- 聊天编辑界面精简字段,新增后端聊天列表API支持
|
||||
- 默认行间距减小,显示更紧凑
|
||||
- 修复页面滚动、表情包排序、发言频率为0等问题
|
||||
- 新增React异常Traceback界面及模型列表搜索
|
||||
- 更新WebUI Icon,修复适配器docker路径等问题
|
||||
- 插件配置可视化编辑,表单控件/元数据/布局类型扩展
|
||||
- 新增插件API与开发文档
|
||||
- 新增机器人状态卡片和快速操作按钮
|
||||
- 调整饼图显示、颜色算法,修复部分统计及解析错误
|
||||
- 新增缓存、WebSocket配置
|
||||
- 表情包支持上传和缩略图
|
||||
- 修复首页极端加载、重启后CtrlC失效、主程序配置移动端适配等问题
|
||||
- 新增表达反思设置和WebUI聊天室“思考中”占位组件
|
||||
- 细节如移除部分字段或UI控件、优化按钮/弹窗/编辑逻辑等
|
||||
|
||||
## [0.11.5] - 2025-11-21
|
||||
### 🌟 重大更新
|
||||
- WebUI 现支持手动重启麦麦,曲线救国版“热重载”
|
||||
|
||||
@@ -27,7 +27,7 @@ services:
|
||||
# image: infinitycat/maibot:dev
|
||||
environment:
|
||||
- TZ=Asia/Shanghai
|
||||
# - EULA_AGREE=99f08e0cab0190de853cb6af7d64d4de # 同意EULA
|
||||
# - EULA_AGREE=1b662741904d7155d1ce1c00b3530d0d # 同意EULA
|
||||
# - PRIVACY_AGREE=9943b855e72199d0f5016ea39052f1b6 # 同意EULA
|
||||
ports:
|
||||
- "18001:8001" # webui端口
|
||||
@@ -35,11 +35,12 @@ services:
|
||||
volumes:
|
||||
- ./docker-config/mmc/.env:/MaiMBot/.env # 持久化env配置文件
|
||||
- ./docker-config/mmc:/MaiMBot/config # 持久化bot配置文件
|
||||
- ./docker-config/adapters:/MaiMBot/adapters-config # adapter配置文件夹映射
|
||||
- ./data/MaiMBot/maibot_statistics.html:/MaiMBot/maibot_statistics.html #统计数据输出
|
||||
- ./data/MaiMBot:/MaiMBot/data # 共享目录
|
||||
- ./data/MaiMBot/plugins:/MaiMBot/plugins # 插件目录
|
||||
- ./data/MaiMBot/logs:/MaiMBot/logs # 日志目录
|
||||
- site-packages:/usr/local/lib/python3.13/site-packages # 持久化Python包
|
||||
# - site-packages:/usr/local/lib/python3.13/site-packages # 持久化Python包,需要时启用
|
||||
restart: always
|
||||
networks:
|
||||
- maim_bot
|
||||
@@ -86,8 +87,8 @@ services:
|
||||
# networks:
|
||||
# - maim_bot
|
||||
|
||||
volumes:
|
||||
site-packages:
|
||||
# volumes: # 若需要持久化Python包时启用
|
||||
# site-packages:
|
||||
networks:
|
||||
maim_bot:
|
||||
driver: bridge
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "MaiBot"
|
||||
version = "0.11.0"
|
||||
version = "0.11.6"
|
||||
description = "MaiCore 是一个基于大语言模型的可交互智能体"
|
||||
requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
@@ -14,6 +14,7 @@ dependencies = [
|
||||
"json-repair>=0.47.6",
|
||||
"maim-message",
|
||||
"matplotlib>=3.10.3",
|
||||
"msgpack>=1.1.2",
|
||||
"numpy>=2.2.6",
|
||||
"openai>=1.95.0",
|
||||
"pandas>=2.3.1",
|
||||
@@ -23,6 +24,7 @@ dependencies = [
|
||||
"pydantic>=2.11.7",
|
||||
"pypinyin>=0.54.0",
|
||||
"python-dotenv>=1.1.1",
|
||||
"python-multipart>=0.0.20",
|
||||
"quick-algo>=0.1.3",
|
||||
"rich>=14.0.0",
|
||||
"ruff>=0.12.2",
|
||||
@@ -32,6 +34,7 @@ dependencies = [
|
||||
"tomlkit>=0.13.3",
|
||||
"urllib3>=2.5.0",
|
||||
"uvicorn>=0.35.0",
|
||||
"zstandard>=0.25.0",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ pyarrow>=20.0.0
|
||||
pydantic>=2.11.7
|
||||
pypinyin>=0.54.0
|
||||
python-dotenv>=1.1.1
|
||||
python-multipart>=0.0.20
|
||||
quick-algo>=0.1.3
|
||||
rich>=14.0.0
|
||||
ruff>=0.12.2
|
||||
|
||||
@@ -235,13 +235,13 @@ class BrainChatting:
|
||||
if recent_messages_list is None:
|
||||
recent_messages_list = []
|
||||
_reply_text = "" # 初始化reply_text变量,避免UnboundLocalError
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# ReflectTracker Check
|
||||
# 在每次回复前检查一次上下文,看是否有反思问题得到了解答
|
||||
# -------------------------------------------------------------------------
|
||||
from src.express.reflect_tracker import reflect_tracker_manager
|
||||
|
||||
|
||||
tracker = reflect_tracker_manager.get_tracker(self.stream_id)
|
||||
if tracker:
|
||||
resolved = await tracker.trigger_tracker()
|
||||
@@ -254,6 +254,7 @@ class BrainChatting:
|
||||
# 检查是否需要提问表达反思
|
||||
# -------------------------------------------------------------------------
|
||||
from src.express.expression_reflector import expression_reflector_manager
|
||||
|
||||
reflector = expression_reflector_manager.get_or_create_reflector(self.stream_id)
|
||||
asyncio.create_task(reflector.check_and_ask())
|
||||
|
||||
|
||||
@@ -356,7 +356,7 @@ async def clean_unused_emojis(emoji_dir: str, emoji_objects: List["MaiEmoji"], r
|
||||
if cleaned_count > 0:
|
||||
logger.info(f"[清理] 在目录 {emoji_dir} 中清理了 {cleaned_count} 个破损表情包。")
|
||||
else:
|
||||
logger.info(f"[清理] 目录 {emoji_dir} 中没有需要清理的。")
|
||||
logger.debug(f"[清理] 目录 {emoji_dir} 中没有需要清理的。")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[错误] 清理未使用表情包文件时出错 ({emoji_dir}): {str(e)}")
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from datetime import datetime
|
||||
import time
|
||||
import asyncio
|
||||
from typing import Dict
|
||||
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
@@ -46,6 +47,8 @@ class FrequencyControl:
|
||||
self.frequency_model = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils_small, request_type="frequency.adjust"
|
||||
)
|
||||
# 频率调整锁,防止并发执行
|
||||
self._adjust_lock = asyncio.Lock()
|
||||
|
||||
def get_talk_frequency_adjust(self) -> float:
|
||||
"""获取发言频率调整值"""
|
||||
@@ -56,68 +59,78 @@ class FrequencyControl:
|
||||
self.talk_frequency_adjust = max(0.1, min(5.0, value))
|
||||
|
||||
async def trigger_frequency_adjust(self) -> None:
|
||||
msg_list = get_raw_msg_by_timestamp_with_chat(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=self.last_frequency_adjust_time,
|
||||
timestamp_end=time.time(),
|
||||
)
|
||||
|
||||
if time.time() - self.last_frequency_adjust_time < 160 or len(msg_list) <= 20:
|
||||
return
|
||||
else:
|
||||
new_msg_list = get_raw_msg_by_timestamp_with_chat(
|
||||
# 使用异步锁防止并发执行
|
||||
async with self._adjust_lock:
|
||||
# 在锁内检查,避免并发触发
|
||||
current_time = time.time()
|
||||
previous_adjust_time = self.last_frequency_adjust_time
|
||||
|
||||
msg_list = get_raw_msg_by_timestamp_with_chat(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=self.last_frequency_adjust_time,
|
||||
timestamp_end=time.time(),
|
||||
limit=20,
|
||||
limit_mode="latest",
|
||||
timestamp_start=previous_adjust_time,
|
||||
timestamp_end=current_time,
|
||||
)
|
||||
|
||||
message_str = build_readable_messages(
|
||||
new_msg_list,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
show_actions=False,
|
||||
)
|
||||
time_block = f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
bot_name = global_config.bot.nickname
|
||||
bot_nickname = (
|
||||
f",也有人叫你{','.join(global_config.bot.alias_names)}" if global_config.bot.alias_names else ""
|
||||
)
|
||||
name_block = f"你的名字是{bot_name}{bot_nickname},请注意哪些是你自己的发言。"
|
||||
if current_time - previous_adjust_time < 160 or len(msg_list) <= 20:
|
||||
return
|
||||
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
"frequency_adjust_prompt",
|
||||
name_block=name_block,
|
||||
time_block=time_block,
|
||||
message_str=message_str,
|
||||
)
|
||||
response, (reasoning_content, _, _) = await self.frequency_model.generate_response_async(
|
||||
prompt,
|
||||
)
|
||||
# 立即更新调整时间,防止并发触发
|
||||
self.last_frequency_adjust_time = current_time
|
||||
|
||||
# logger.info(f"频率调整 prompt: {prompt}")
|
||||
# logger.info(f"频率调整 response: {response}")
|
||||
try:
|
||||
new_msg_list = get_raw_msg_by_timestamp_with_chat(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=previous_adjust_time,
|
||||
timestamp_end=current_time,
|
||||
limit=20,
|
||||
limit_mode="latest",
|
||||
)
|
||||
|
||||
if global_config.debug.show_prompt:
|
||||
logger.info(f"频率调整 prompt: {prompt}")
|
||||
logger.info(f"频率调整 response: {response}")
|
||||
logger.info(f"频率调整 reasoning_content: {reasoning_content}")
|
||||
message_str = build_readable_messages(
|
||||
new_msg_list,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
show_actions=False,
|
||||
)
|
||||
time_block = f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
bot_name = global_config.bot.nickname
|
||||
bot_nickname = (
|
||||
f",也有人叫你{','.join(global_config.bot.alias_names)}" if global_config.bot.alias_names else ""
|
||||
)
|
||||
name_block = f"你的名字是{bot_name}{bot_nickname},请注意哪些是你自己的发言。"
|
||||
|
||||
final_value_by_api = frequency_api.get_current_talk_value(self.chat_id)
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
"frequency_adjust_prompt",
|
||||
name_block=name_block,
|
||||
time_block=time_block,
|
||||
message_str=message_str,
|
||||
)
|
||||
response, (reasoning_content, _, _) = await self.frequency_model.generate_response_async(
|
||||
prompt,
|
||||
)
|
||||
|
||||
# LLM依然输出过多内容时取消本次调整。合法最多4个字,但有的模型可能会输出一些markdown换行符等,需要长度宽限
|
||||
if len(response) < 20:
|
||||
if "过于频繁" in response:
|
||||
logger.info(f"频率调整: 过于频繁,调整值到{final_value_by_api}")
|
||||
self.talk_frequency_adjust = max(0.1, min(1.5, self.talk_frequency_adjust * 0.8))
|
||||
elif "过少" in response:
|
||||
logger.info(f"频率调整: 过少,调整值到{final_value_by_api}")
|
||||
self.talk_frequency_adjust = max(0.1, min(1.5, self.talk_frequency_adjust * 1.2))
|
||||
self.last_frequency_adjust_time = time.time()
|
||||
else:
|
||||
logger.info("频率调整:response不符合要求,取消本次调整")
|
||||
# logger.info(f"频率调整 prompt: {prompt}")
|
||||
# logger.info(f"频率调整 response: {response}")
|
||||
|
||||
if global_config.debug.show_prompt:
|
||||
logger.info(f"频率调整 prompt: {prompt}")
|
||||
logger.info(f"频率调整 response: {response}")
|
||||
logger.info(f"频率调整 reasoning_content: {reasoning_content}")
|
||||
|
||||
final_value_by_api = frequency_api.get_current_talk_value(self.chat_id)
|
||||
|
||||
# LLM依然输出过多内容时取消本次调整。合法最多4个字,但有的模型可能会输出一些markdown换行符等,需要长度宽限
|
||||
if len(response) < 20:
|
||||
if "过于频繁" in response:
|
||||
logger.info(f"频率调整: 过于频繁,调整值到{final_value_by_api}")
|
||||
self.talk_frequency_adjust = max(0.1, min(1.5, self.talk_frequency_adjust * 0.8))
|
||||
elif "过少" in response:
|
||||
logger.info(f"频率调整: 过少,调整值到{final_value_by_api}")
|
||||
self.talk_frequency_adjust = max(0.1, min(1.5, self.talk_frequency_adjust * 1.2))
|
||||
except Exception as e:
|
||||
logger.error(f"频率调整失败: {e}")
|
||||
# 即使失败也保持时间戳更新,避免频繁重试
|
||||
|
||||
|
||||
class FrequencyControlManager:
|
||||
|
||||
@@ -29,7 +29,7 @@ from src.chat.utils.chat_message_builder import (
|
||||
build_readable_messages_with_id,
|
||||
get_raw_msg_before_timestamp_with_chat,
|
||||
)
|
||||
from src.chat.utils.chat_history_summarizer import ChatHistorySummarizer
|
||||
from src.hippo_memorizer.chat_history_summarizer import ChatHistorySummarizer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
@@ -400,7 +400,7 @@ class HeartFChatting:
|
||||
# ReflectTracker Check
|
||||
# 在每次回复前检查一次上下文,看是否有反思问题得到了解答
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
|
||||
reflector = expression_reflector_manager.get_or_create_reflector(self.stream_id)
|
||||
await reflector.check_and_ask()
|
||||
tracker = reflect_tracker_manager.get_tracker(self.stream_id)
|
||||
@@ -410,7 +410,6 @@ class HeartFChatting:
|
||||
reflect_tracker_manager.remove_tracker(self.stream_id)
|
||||
logger.info(f"{self.log_prefix} ReflectTracker resolved and removed.")
|
||||
|
||||
|
||||
start_time = time.time()
|
||||
async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
|
||||
asyncio.create_task(self.expression_learner.trigger_learning_for_chat())
|
||||
@@ -427,7 +426,9 @@ class HeartFChatting:
|
||||
# asyncio.create_task(self.chat_history_summarizer.process())
|
||||
|
||||
cycle_timers, thinking_id = self.start_cycle()
|
||||
logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考(频率: {global_config.chat.get_talk_value(self.stream_id)})")
|
||||
logger.info(
|
||||
f"{self.log_prefix} 开始第{self._cycle_counter}次思考(频率: {global_config.chat.get_talk_value(self.stream_id)})"
|
||||
)
|
||||
|
||||
# 第一步:动作检查
|
||||
available_actions: Dict[str, ActionInfo] = {}
|
||||
|
||||
@@ -39,6 +39,11 @@ class HeartFCMessageReceiver:
|
||||
message_data: 原始消息字符串
|
||||
"""
|
||||
try:
|
||||
# 通知消息不处理
|
||||
if message.is_notify:
|
||||
logger.debug("通知消息,跳过处理")
|
||||
return
|
||||
|
||||
# 1. 消息解析与初始化
|
||||
userinfo = message.message_info.user_info
|
||||
chat = message.chat_stream
|
||||
|
||||
@@ -33,6 +33,11 @@ class MessageStorage:
|
||||
async def store_message(message: Union[MessageSending, MessageRecv], chat_stream: ChatStream) -> None:
|
||||
"""存储消息到数据库"""
|
||||
try:
|
||||
# 通知消息不存储
|
||||
if isinstance(message, MessageRecv) and message.is_notify:
|
||||
logger.debug("通知消息,跳过存储")
|
||||
return
|
||||
|
||||
pattern = r"<MainRule>.*?</MainRule>|<schedule>.*?</schedule>|<UserMessage>.*?</UserMessage>"
|
||||
|
||||
# print(message)
|
||||
|
||||
@@ -15,12 +15,72 @@ install(extra_lines=3)
|
||||
|
||||
logger = get_logger("sender")
|
||||
|
||||
# WebUI 聊天室的消息广播器(延迟导入避免循环依赖)
|
||||
_webui_chat_broadcaster = None
|
||||
|
||||
# 虚拟群 ID 前缀(与 chat_routes.py 保持一致)
|
||||
VIRTUAL_GROUP_ID_PREFIX = "webui_virtual_group_"
|
||||
|
||||
|
||||
def get_webui_chat_broadcaster():
|
||||
"""获取 WebUI 聊天室广播器"""
|
||||
global _webui_chat_broadcaster
|
||||
if _webui_chat_broadcaster is None:
|
||||
try:
|
||||
from src.webui.chat_routes import chat_manager, WEBUI_CHAT_PLATFORM
|
||||
|
||||
_webui_chat_broadcaster = (chat_manager, WEBUI_CHAT_PLATFORM)
|
||||
except ImportError:
|
||||
_webui_chat_broadcaster = (None, None)
|
||||
return _webui_chat_broadcaster
|
||||
|
||||
|
||||
def is_webui_virtual_group(group_id: str) -> bool:
|
||||
"""检查是否是 WebUI 虚拟群"""
|
||||
return group_id and group_id.startswith(VIRTUAL_GROUP_ID_PREFIX)
|
||||
|
||||
|
||||
async def _send_message(message: MessageSending, show_log=True) -> bool:
|
||||
"""合并后的消息发送函数,包含WS发送和日志记录"""
|
||||
message_preview = truncate_message(message.processed_plain_text, max_length=200)
|
||||
platform = message.message_info.platform
|
||||
group_id = message.message_info.group_info.group_id if message.message_info.group_info else None
|
||||
|
||||
try:
|
||||
# 检查是否是 WebUI 平台的消息,或者是 WebUI 虚拟群的消息
|
||||
chat_manager, webui_platform = get_webui_chat_broadcaster()
|
||||
is_webui_message = (platform == webui_platform) or is_webui_virtual_group(group_id)
|
||||
|
||||
if is_webui_message and chat_manager is not None:
|
||||
# WebUI 聊天室消息(包括虚拟身份模式),通过 WebSocket 广播
|
||||
import time
|
||||
from src.config.config import global_config
|
||||
|
||||
await chat_manager.broadcast(
|
||||
{
|
||||
"type": "bot_message",
|
||||
"content": message.processed_plain_text,
|
||||
"message_type": "text",
|
||||
"timestamp": time.time(),
|
||||
"group_id": group_id, # 包含群 ID 以便前端区分不同的聊天标签
|
||||
"sender": {
|
||||
"name": global_config.bot.nickname,
|
||||
"avatar": None,
|
||||
"is_bot": True,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# 注意:机器人消息会由 MessageStorage.store_message 自动保存到数据库
|
||||
# 无需手动保存
|
||||
|
||||
if show_log:
|
||||
if is_webui_virtual_group(group_id):
|
||||
logger.info(f"已将消息 '{message_preview}' 发往 WebUI 虚拟群 (平台: {platform})")
|
||||
else:
|
||||
logger.info(f"已将消息 '{message_preview}' 发往 WebUI 聊天室")
|
||||
return True
|
||||
|
||||
# 直接调用API发送消息
|
||||
await get_global_api().send_message(message)
|
||||
if show_log:
|
||||
|
||||
@@ -181,8 +181,12 @@ class ActionPlanner:
|
||||
found_ids = set(matches)
|
||||
missing_ids = found_ids - available_ids
|
||||
if missing_ids:
|
||||
logger.info(f"{self.log_prefix}planner理由中引用的消息ID不在当前上下文中: {missing_ids}, 可用ID: {list(available_ids)[:10]}...")
|
||||
logger.info(f"{self.log_prefix}planner理由替换: 找到{len(matches)}个消息ID引用,其中{len(found_ids & available_ids)}个在上下文中")
|
||||
logger.info(
|
||||
f"{self.log_prefix}planner理由中引用的消息ID不在当前上下文中: {missing_ids}, 可用ID: {list(available_ids)[:10]}..."
|
||||
)
|
||||
logger.info(
|
||||
f"{self.log_prefix}planner理由替换: 找到{len(matches)}个消息ID引用,其中{len(found_ids & available_ids)}个在上下文中"
|
||||
)
|
||||
|
||||
def _replace(match: re.Match[str]) -> str:
|
||||
msg_id = match.group(0)
|
||||
@@ -222,7 +226,8 @@ class ActionPlanner:
|
||||
# 非no_reply动作需要target_message_id
|
||||
target_message = None
|
||||
|
||||
if target_message_id := action_json.get("target_message_id"):
|
||||
target_message_id = action_json.get("target_message_id")
|
||||
if target_message_id:
|
||||
# 根据target_message_id查找原始消息
|
||||
target_message = self.find_message_by_id(target_message_id, message_id_list)
|
||||
if target_message is None:
|
||||
@@ -233,6 +238,14 @@ class ActionPlanner:
|
||||
target_message = message_id_list[-1][1]
|
||||
logger.debug(f"{self.log_prefix}动作'{action}'缺少target_message_id,使用最新消息作为target_message")
|
||||
|
||||
if action != "no_reply" and target_message is not None and self._is_message_from_self(target_message):
|
||||
logger.info(
|
||||
f"{self.log_prefix}Planner选择了自己的消息 {target_message_id or target_message.message_id} 作为目标,强制使用 no_reply"
|
||||
)
|
||||
reasoning = f"目标消息 {target_message_id or target_message.message_id} 来自机器人自身,违反不回复自身消息规则。原始理由: {reasoning}"
|
||||
action = "no_reply"
|
||||
target_message = None
|
||||
|
||||
# 验证action是否可用
|
||||
available_action_names = [action_name for action_name, _ in current_available_actions]
|
||||
internal_action_names = ["no_reply", "reply", "wait_time", "no_reply_until_call"]
|
||||
@@ -277,6 +290,16 @@ class ActionPlanner:
|
||||
|
||||
return action_planner_infos
|
||||
|
||||
def _is_message_from_self(self, message: "DatabaseMessages") -> bool:
|
||||
"""判断消息是否由机器人自身发送"""
|
||||
try:
|
||||
return str(message.user_info.user_id) == str(global_config.bot.qq_account) and (
|
||||
message.user_info.platform or ""
|
||||
) == (global_config.bot.platform or "")
|
||||
except AttributeError:
|
||||
logger.warning(f"{self.log_prefix}检测消息发送者失败,缺少必要字段")
|
||||
return False
|
||||
|
||||
async def plan(
|
||||
self,
|
||||
available_actions: Dict[str, ActionInfo],
|
||||
@@ -754,20 +777,20 @@ class ActionPlanner:
|
||||
json_content_start = json_start_pos + 7 # ```json的长度
|
||||
# 提取从```json之后到内容结尾的所有内容
|
||||
incomplete_json_str = content[json_content_start:].strip()
|
||||
|
||||
|
||||
# 提取JSON之前的内容作为推理文本
|
||||
if json_start_pos > 0:
|
||||
reasoning_content = content[:json_start_pos].strip()
|
||||
reasoning_content = re.sub(r"^//\s*", "", reasoning_content, flags=re.MULTILINE)
|
||||
reasoning_content = reasoning_content.strip()
|
||||
|
||||
|
||||
if incomplete_json_str:
|
||||
try:
|
||||
# 清理可能的注释和格式问题
|
||||
json_str = re.sub(r"//.*?\n", "\n", incomplete_json_str)
|
||||
json_str = re.sub(r"/\*.*?\*/", "", json_str, flags=re.DOTALL)
|
||||
json_str = json_str.strip()
|
||||
|
||||
|
||||
if json_str:
|
||||
# 尝试按行分割,每行可能是一个JSON对象
|
||||
lines = [line.strip() for line in json_str.split("\n") if line.strip()]
|
||||
@@ -782,7 +805,7 @@ class ActionPlanner:
|
||||
json_objects.append(item)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
|
||||
# 如果按行解析没有成功,尝试将整个块作为一个JSON对象或数组
|
||||
if not json_objects:
|
||||
try:
|
||||
|
||||
@@ -839,8 +839,6 @@ class DefaultReplyer:
|
||||
continue
|
||||
|
||||
timing_logs.append(f"{chinese_name}: {duration:.1f}s")
|
||||
if duration > 12:
|
||||
logger.warning(f"回复生成前信息获取耗时过长: {chinese_name} 耗时: {duration:.1f}s,请使用更快的模型")
|
||||
logger.info(f"回复准备: {'; '.join(timing_logs)}; {almost_zero_str} <0.1s")
|
||||
|
||||
expression_habits_block, selected_expressions = results_dict["expression_habits"]
|
||||
|
||||
@@ -760,8 +760,6 @@ class PrivateReplyer:
|
||||
continue
|
||||
|
||||
timing_logs.append(f"{chinese_name}: {duration:.1f}s")
|
||||
if duration > 12:
|
||||
logger.warning(f"回复生成前信息获取耗时过长: {chinese_name} 耗时: {duration:.1f}s,请使用更快的模型")
|
||||
logger.info(f"回复准备: {'; '.join(timing_logs)}; {almost_zero_str} <0.1s")
|
||||
|
||||
expression_habits_block, selected_expressions = results_dict["expression_habits"]
|
||||
|
||||
@@ -1,493 +0,0 @@
|
||||
"""
|
||||
聊天内容概括器
|
||||
用于累积、打包和压缩聊天记录
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from typing import List, Optional, Set
|
||||
from dataclasses import dataclass
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.plugin_system.apis import message_api
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages
|
||||
from src.person_info.person_info import Person
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
|
||||
logger = get_logger("chat_history_summarizer")
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageBatch:
|
||||
"""消息批次"""
|
||||
|
||||
messages: List[DatabaseMessages]
|
||||
start_time: float
|
||||
end_time: float
|
||||
is_preparing: bool = False # 是否处于准备结束模式
|
||||
|
||||
|
||||
class ChatHistorySummarizer:
|
||||
"""聊天内容概括器"""
|
||||
|
||||
def __init__(self, chat_id: str, check_interval: int = 60):
|
||||
"""
|
||||
初始化聊天内容概括器
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
check_interval: 定期检查间隔(秒),默认60秒
|
||||
"""
|
||||
self.chat_id = chat_id
|
||||
self._chat_display_name = self._get_chat_display_name()
|
||||
self.log_prefix = f"[{self._chat_display_name}]"
|
||||
|
||||
# 记录时间点,用于计算新消息
|
||||
self.last_check_time = time.time()
|
||||
|
||||
# 当前累积的消息批次
|
||||
self.current_batch: Optional[MessageBatch] = None
|
||||
|
||||
# LLM请求器,用于压缩聊天内容
|
||||
self.summarizer_llm = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils, request_type="chat_history_summarizer"
|
||||
)
|
||||
|
||||
# 后台循环相关
|
||||
self.check_interval = check_interval # 检查间隔(秒)
|
||||
self._periodic_task: Optional[asyncio.Task] = None
|
||||
self._running = False
|
||||
|
||||
def _get_chat_display_name(self) -> str:
|
||||
"""获取聊天显示名称"""
|
||||
try:
|
||||
chat_name = get_chat_manager().get_stream_name(self.chat_id)
|
||||
if chat_name:
|
||||
return chat_name
|
||||
# 如果获取失败,使用简化的chat_id显示
|
||||
if len(self.chat_id) > 20:
|
||||
return f"{self.chat_id[:8]}..."
|
||||
return self.chat_id
|
||||
except Exception:
|
||||
# 如果获取失败,使用简化的chat_id显示
|
||||
if len(self.chat_id) > 20:
|
||||
return f"{self.chat_id[:8]}..."
|
||||
return self.chat_id
|
||||
|
||||
async def process(self, current_time: Optional[float] = None):
|
||||
"""
|
||||
处理聊天内容概括
|
||||
|
||||
Args:
|
||||
current_time: 当前时间戳,如果为None则使用time.time()
|
||||
"""
|
||||
if current_time is None:
|
||||
current_time = time.time()
|
||||
|
||||
try:
|
||||
# 获取从上次检查时间到当前时间的新消息
|
||||
new_messages = message_api.get_messages_by_time_in_chat(
|
||||
chat_id=self.chat_id,
|
||||
start_time=self.last_check_time,
|
||||
end_time=current_time,
|
||||
limit=0,
|
||||
limit_mode="latest",
|
||||
filter_mai=False, # 不过滤bot消息,因为需要检查bot是否发言
|
||||
filter_command=False,
|
||||
)
|
||||
|
||||
if not new_messages:
|
||||
# 没有新消息,检查是否需要打包
|
||||
if self.current_batch and self.current_batch.messages:
|
||||
await self._check_and_package(current_time)
|
||||
self.last_check_time = current_time
|
||||
return
|
||||
|
||||
logger.debug(
|
||||
f"{self.log_prefix} 开始处理聊天概括,时间窗口: {self.last_check_time:.2f} -> {current_time:.2f}"
|
||||
)
|
||||
|
||||
# 有新消息,更新最后检查时间
|
||||
self.last_check_time = current_time
|
||||
|
||||
# 如果有当前批次,添加新消息
|
||||
if self.current_batch:
|
||||
before_count = len(self.current_batch.messages)
|
||||
self.current_batch.messages.extend(new_messages)
|
||||
self.current_batch.end_time = current_time
|
||||
logger.info(f"{self.log_prefix} 更新聊天话题: {before_count} -> {len(self.current_batch.messages)} 条消息")
|
||||
else:
|
||||
# 创建新批次
|
||||
self.current_batch = MessageBatch(
|
||||
messages=new_messages,
|
||||
start_time=new_messages[0].time if new_messages else current_time,
|
||||
end_time=current_time,
|
||||
)
|
||||
logger.info(f"{self.log_prefix} 新建聊天话题: {len(new_messages)} 条消息")
|
||||
|
||||
# 检查是否需要打包
|
||||
await self._check_and_package(current_time)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 处理聊天内容概括时出错: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
async def _check_and_package(self, current_time: float):
|
||||
"""检查是否需要打包"""
|
||||
if not self.current_batch or not self.current_batch.messages:
|
||||
return
|
||||
|
||||
messages = self.current_batch.messages
|
||||
message_count = len(messages)
|
||||
last_message_time = messages[-1].time if messages else current_time
|
||||
time_since_last_message = current_time - last_message_time
|
||||
|
||||
# 格式化时间差显示
|
||||
if time_since_last_message < 60:
|
||||
time_str = f"{time_since_last_message:.1f}秒"
|
||||
elif time_since_last_message < 3600:
|
||||
time_str = f"{time_since_last_message / 60:.1f}分钟"
|
||||
else:
|
||||
time_str = f"{time_since_last_message / 3600:.1f}小时"
|
||||
|
||||
preparing_status = "是" if self.current_batch.is_preparing else "否"
|
||||
|
||||
logger.info(
|
||||
f"{self.log_prefix} 批次状态检查 | 消息数: {message_count} | 距最后消息: {time_str} | 准备结束模式: {preparing_status}"
|
||||
)
|
||||
|
||||
# 检查打包条件
|
||||
should_package = False
|
||||
|
||||
# 条件1: 消息长度超过120,直接打包
|
||||
if message_count >= 120:
|
||||
should_package = True
|
||||
logger.info(f"{self.log_prefix} 触发打包条件: 消息数量达到 {message_count} 条(阈值: 120条)")
|
||||
|
||||
# 条件2: 最后一条消息的时间和当前时间差>600秒,直接打包
|
||||
elif time_since_last_message > 600:
|
||||
should_package = True
|
||||
logger.info(f"{self.log_prefix} 触发打包条件: 距最后消息 {time_str}(阈值: 10分钟)")
|
||||
|
||||
# 条件3: 消息长度超过100,进入准备结束模式
|
||||
elif message_count > 100:
|
||||
if not self.current_batch.is_preparing:
|
||||
self.current_batch.is_preparing = True
|
||||
logger.info(f"{self.log_prefix} 消息数量 {message_count} 条超过阈值(100条),进入准备结束模式")
|
||||
|
||||
# 在准备结束模式下,如果最后一条消息的时间和当前时间差>10秒,就打包
|
||||
if time_since_last_message > 10:
|
||||
should_package = True
|
||||
logger.info(f"{self.log_prefix} 触发打包条件: 准备结束模式下,距最后消息 {time_str}(阈值: 10秒)")
|
||||
|
||||
if should_package:
|
||||
await self._package_and_store()
|
||||
|
||||
async def _package_and_store(self):
|
||||
"""打包并存储聊天记录"""
|
||||
if not self.current_batch or not self.current_batch.messages:
|
||||
return
|
||||
|
||||
messages = self.current_batch.messages
|
||||
start_time = self.current_batch.start_time
|
||||
end_time = self.current_batch.end_time
|
||||
|
||||
logger.info(
|
||||
f"{self.log_prefix} 开始打包批次 | 消息数: {len(messages)} | 时间范围: {start_time:.2f} - {end_time:.2f}"
|
||||
)
|
||||
|
||||
# 检查是否有bot发言
|
||||
# 第一条消息前推600s到最后一条消息的时间内
|
||||
check_start_time = max(start_time - 600, 0)
|
||||
check_end_time = end_time
|
||||
|
||||
# 使用包含边界的时间范围查询
|
||||
bot_messages = message_api.get_messages_by_time_in_chat_inclusive(
|
||||
chat_id=self.chat_id,
|
||||
start_time=check_start_time,
|
||||
end_time=check_end_time,
|
||||
limit=0,
|
||||
limit_mode="latest",
|
||||
filter_mai=False,
|
||||
filter_command=False,
|
||||
)
|
||||
|
||||
# 检查是否有bot的发言
|
||||
has_bot_message = False
|
||||
bot_user_id = str(global_config.bot.qq_account)
|
||||
for msg in bot_messages:
|
||||
if msg.user_info.user_id == bot_user_id:
|
||||
has_bot_message = True
|
||||
break
|
||||
|
||||
if not has_bot_message:
|
||||
logger.info(
|
||||
f"{self.log_prefix} 批次内无Bot发言,丢弃批次 | 检查时间范围: {check_start_time:.2f} - {check_end_time:.2f}"
|
||||
)
|
||||
self.current_batch = None
|
||||
return
|
||||
|
||||
# 有bot发言,进行压缩和存储
|
||||
try:
|
||||
# 构建对话原文
|
||||
original_text = build_readable_messages(
|
||||
messages=messages,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
read_mark=0.0,
|
||||
truncate=False,
|
||||
show_actions=False,
|
||||
)
|
||||
|
||||
# 获取参与的所有人的昵称
|
||||
participants_set: Set[str] = set()
|
||||
for msg in messages:
|
||||
# 使用 msg.user_platform(扁平化字段)或 msg.user_info.platform
|
||||
platform = (
|
||||
getattr(msg, "user_platform", None)
|
||||
or (msg.user_info.platform if msg.user_info else None)
|
||||
or msg.chat_info.platform
|
||||
)
|
||||
person = Person(platform=platform, user_id=msg.user_info.user_id)
|
||||
person_name = person.person_name
|
||||
if person_name:
|
||||
participants_set.add(person_name)
|
||||
participants = list(participants_set)
|
||||
logger.info(f"{self.log_prefix} 批次参与者: {', '.join(participants) if participants else '未知'}")
|
||||
|
||||
# 使用LLM压缩聊天内容
|
||||
success, theme, keywords, summary = await self._compress_with_llm(original_text)
|
||||
|
||||
if not success:
|
||||
logger.warning(f"{self.log_prefix} LLM压缩失败,不存储到数据库 | 消息数: {len(messages)}")
|
||||
# 清空当前批次,避免重复处理
|
||||
self.current_batch = None
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"{self.log_prefix} LLM压缩完成 | 主题: {theme} | 关键词数: {len(keywords)} | 概括长度: {len(summary)} 字"
|
||||
)
|
||||
|
||||
# 存储到数据库
|
||||
await self._store_to_database(
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
original_text=original_text,
|
||||
participants=participants,
|
||||
theme=theme,
|
||||
keywords=keywords,
|
||||
summary=summary,
|
||||
)
|
||||
|
||||
logger.info(f"{self.log_prefix} 成功打包并存储聊天记录 | 消息数: {len(messages)} | 主题: {theme}")
|
||||
|
||||
# 清空当前批次
|
||||
self.current_batch = None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 打包和存储聊天记录时出错: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
# 出错时也清空批次,避免重复处理
|
||||
self.current_batch = None
|
||||
|
||||
async def _compress_with_llm(self, original_text: str) -> tuple[bool, str, List[str], str]:
|
||||
"""
|
||||
使用LLM压缩聊天内容
|
||||
|
||||
Returns:
|
||||
tuple[bool, str, List[str], str]: (是否成功, 主题, 关键词列表, 概括)
|
||||
"""
|
||||
prompt = f"""请对以下聊天记录进行概括,提取以下信息:
|
||||
|
||||
1. 主题:这段对话的主要内容,一个简短的标题(不超过20字)
|
||||
2. 关键词:这段对话的关键词,用列表形式返回(3-10个关键词)
|
||||
3. 概括:对这段话的平文本概括(50-200字)
|
||||
|
||||
请以JSON格式返回,格式如下:
|
||||
{{
|
||||
"theme": "主题",
|
||||
"keywords": ["关键词1", "关键词2", ...],
|
||||
"summary": "概括内容"
|
||||
}}
|
||||
|
||||
聊天记录:
|
||||
{original_text}
|
||||
|
||||
请直接返回JSON,不要包含其他内容。"""
|
||||
|
||||
try:
|
||||
response, _ = await self.summarizer_llm.generate_response_async(
|
||||
prompt=prompt,
|
||||
temperature=0.3,
|
||||
max_tokens=500,
|
||||
)
|
||||
|
||||
# 解析JSON响应
|
||||
import re
|
||||
|
||||
# 移除可能的markdown代码块标记
|
||||
json_str = response.strip()
|
||||
json_str = re.sub(r"^```json\s*", "", json_str, flags=re.MULTILINE)
|
||||
json_str = re.sub(r"^```\s*", "", json_str, flags=re.MULTILINE)
|
||||
json_str = json_str.strip()
|
||||
|
||||
# 尝试找到JSON对象的开始和结束位置
|
||||
# 查找第一个 { 和最后一个匹配的 }
|
||||
start_idx = json_str.find("{")
|
||||
if start_idx == -1:
|
||||
raise ValueError("未找到JSON对象开始标记")
|
||||
|
||||
# 从后往前查找最后一个 }
|
||||
end_idx = json_str.rfind("}")
|
||||
if end_idx == -1 or end_idx <= start_idx:
|
||||
raise ValueError("未找到JSON对象结束标记")
|
||||
|
||||
# 提取JSON字符串
|
||||
json_str = json_str[start_idx : end_idx + 1]
|
||||
|
||||
# 尝试解析JSON
|
||||
try:
|
||||
result = json.loads(json_str)
|
||||
except json.JSONDecodeError:
|
||||
# 如果解析失败,尝试修复字符串值中的中文引号
|
||||
# 简单方法:将字符串值中的中文引号替换为转义的英文引号
|
||||
# 使用状态机方法:遍历字符串,在字符串值内部替换中文引号
|
||||
fixed_chars = []
|
||||
in_string = False
|
||||
escape_next = False
|
||||
i = 0
|
||||
while i < len(json_str):
|
||||
char = json_str[i]
|
||||
if escape_next:
|
||||
fixed_chars.append(char)
|
||||
escape_next = False
|
||||
elif char == "\\":
|
||||
fixed_chars.append(char)
|
||||
escape_next = True
|
||||
elif char == '"' and not escape_next:
|
||||
fixed_chars.append(char)
|
||||
in_string = not in_string
|
||||
elif in_string and (char == '"' or char == '"'):
|
||||
# 在字符串值内部,将中文引号替换为转义的英文引号
|
||||
fixed_chars.append('\\"')
|
||||
else:
|
||||
fixed_chars.append(char)
|
||||
i += 1
|
||||
|
||||
json_str = "".join(fixed_chars)
|
||||
# 再次尝试解析
|
||||
result = json.loads(json_str)
|
||||
|
||||
theme = result.get("theme", "未命名对话")
|
||||
keywords = result.get("keywords", [])
|
||||
summary = result.get("summary", "无概括")
|
||||
|
||||
# 确保keywords是列表
|
||||
if isinstance(keywords, str):
|
||||
keywords = [keywords]
|
||||
|
||||
return True, theme, keywords, summary
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} LLM压缩聊天内容时出错: {e}")
|
||||
logger.error(f"{self.log_prefix} LLM响应: {response if 'response' in locals() else 'N/A'}")
|
||||
# 返回失败标志和默认值
|
||||
return False, "未命名对话", [], "压缩失败,无法生成概括"
|
||||
|
||||
async def _store_to_database(
|
||||
self,
|
||||
start_time: float,
|
||||
end_time: float,
|
||||
original_text: str,
|
||||
participants: List[str],
|
||||
theme: str,
|
||||
keywords: List[str],
|
||||
summary: str,
|
||||
):
|
||||
"""存储到数据库"""
|
||||
try:
|
||||
from src.common.database.database_model import ChatHistory
|
||||
from src.plugin_system.apis import database_api
|
||||
|
||||
# 准备数据
|
||||
data = {
|
||||
"chat_id": self.chat_id,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"original_text": original_text,
|
||||
"participants": json.dumps(participants, ensure_ascii=False),
|
||||
"theme": theme,
|
||||
"keywords": json.dumps(keywords, ensure_ascii=False),
|
||||
"summary": summary,
|
||||
"count": 0,
|
||||
}
|
||||
|
||||
# 使用db_save存储(使用start_time和chat_id作为唯一标识)
|
||||
# 由于可能有多条记录,我们使用组合键,但peewee不支持,所以使用start_time作为唯一标识
|
||||
# 但为了避免冲突,我们使用组合键:chat_id + start_time
|
||||
# 由于peewee不支持组合键,我们直接创建新记录(不提供key_field和key_value)
|
||||
saved_record = await database_api.db_save(
|
||||
ChatHistory,
|
||||
data=data,
|
||||
)
|
||||
|
||||
if saved_record:
|
||||
logger.debug(f"{self.log_prefix} 成功存储聊天历史记录到数据库")
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix} 存储聊天历史记录到数据库失败")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 存储到数据库时出错: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
async def start(self):
|
||||
"""启动后台定期检查循环"""
|
||||
if self._running:
|
||||
logger.warning(f"{self.log_prefix} 后台循环已在运行,无需重复启动")
|
||||
return
|
||||
|
||||
self._running = True
|
||||
self._periodic_task = asyncio.create_task(self._periodic_check_loop())
|
||||
logger.info(f"{self.log_prefix} 已启动后台定期检查循环 | 检查间隔: {self.check_interval}秒")
|
||||
|
||||
async def stop(self):
|
||||
"""停止后台定期检查循环"""
|
||||
self._running = False
|
||||
if self._periodic_task:
|
||||
self._periodic_task.cancel()
|
||||
try:
|
||||
await self._periodic_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._periodic_task = None
|
||||
logger.info(f"{self.log_prefix} 已停止后台定期检查循环")
|
||||
|
||||
async def _periodic_check_loop(self):
|
||||
"""后台定期检查循环"""
|
||||
try:
|
||||
while self._running:
|
||||
# 执行一次检查
|
||||
await self.process()
|
||||
|
||||
# 等待指定间隔后再次检查
|
||||
await asyncio.sleep(self.check_interval)
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"{self.log_prefix} 后台检查循环被取消")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 后台检查循环出错: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
self._running = False
|
||||
@@ -959,7 +959,7 @@ async def build_anonymous_messages(messages: List[DatabaseMessages], show_ids: b
|
||||
header = f"[{i + 1}] {anon_name}说 "
|
||||
else:
|
||||
header = f"{anon_name}说 "
|
||||
|
||||
|
||||
output_lines.append(header)
|
||||
stripped_line = content.strip()
|
||||
if stripped_line:
|
||||
|
||||
@@ -25,7 +25,7 @@ class MemoryForgetTask(AsyncTask):
|
||||
"""执行遗忘检查"""
|
||||
try:
|
||||
current_time = time.time()
|
||||
logger.info("[记忆遗忘] 开始遗忘检查...")
|
||||
# logger.info("[记忆遗忘] 开始遗忘检查...")
|
||||
|
||||
# 执行4个阶段的遗忘检查
|
||||
await self._forget_stage_1(current_time)
|
||||
@@ -33,7 +33,7 @@ class MemoryForgetTask(AsyncTask):
|
||||
await self._forget_stage_3(current_time)
|
||||
await self._forget_stage_4(current_time)
|
||||
|
||||
logger.info("[记忆遗忘] 遗忘检查完成")
|
||||
# logger.info("[记忆遗忘] 遗忘检查完成")
|
||||
except Exception as e:
|
||||
logger.error(f"[记忆遗忘] 执行遗忘检查时出错: {e}", exc_info=True)
|
||||
|
||||
|
||||
@@ -227,6 +227,8 @@ class StatisticOutputTask(AsyncTask):
|
||||
"",
|
||||
self._format_model_classified_stat(stats["last_hour"]),
|
||||
"",
|
||||
self._format_module_classified_stat(stats["last_hour"]),
|
||||
"",
|
||||
self._format_chat_stat(stats["last_hour"]),
|
||||
self.SEP_LINE,
|
||||
"",
|
||||
@@ -737,11 +739,13 @@ class StatisticOutputTask(AsyncTask):
|
||||
"""
|
||||
if stats[TOTAL_REQ_CNT] <= 0:
|
||||
return ""
|
||||
data_fmt = "{:<32} {:>10} {:>12} {:>12} {:>12} {:>9.2f}¥ {:>10.1f} {:>10.1f}"
|
||||
data_fmt = "{:<32} {:>10} {:>12} {:>12} {:>12} {:>9.2f}¥ {:>10.1f} {:>10.1f} {:>12} {:>12}"
|
||||
|
||||
total_replies = stats.get(TOTAL_REPLY_CNT, 0)
|
||||
|
||||
output = [
|
||||
"按模型分类统计:",
|
||||
" 模型名称 调用次数 输入Token 输出Token Token总量 累计花费 平均耗时(秒) 标准差(秒)",
|
||||
" 模型名称 调用次数 输入Token 输出Token Token总量 累计花费 平均耗时(秒) 标准差(秒) 每次回复平均调用次数 每次回复平均Token数",
|
||||
]
|
||||
for model_name, count in sorted(stats[REQ_CNT_BY_MODEL].items()):
|
||||
name = f"{model_name[:29]}..." if len(model_name) > 32 else model_name
|
||||
@@ -751,11 +755,19 @@ class StatisticOutputTask(AsyncTask):
|
||||
cost = stats[COST_BY_MODEL][model_name]
|
||||
avg_time_cost = stats[AVG_TIME_COST_BY_MODEL][model_name]
|
||||
std_time_cost = stats[STD_TIME_COST_BY_MODEL][model_name]
|
||||
|
||||
# 计算每次回复平均值
|
||||
avg_count_per_reply = count / total_replies if total_replies > 0 else 0.0
|
||||
avg_tokens_per_reply = tokens / total_replies if total_replies > 0 else 0.0
|
||||
|
||||
# 格式化大数字
|
||||
formatted_count = _format_large_number(count)
|
||||
formatted_in_tokens = _format_large_number(in_tokens)
|
||||
formatted_out_tokens = _format_large_number(out_tokens)
|
||||
formatted_tokens = _format_large_number(tokens)
|
||||
formatted_avg_count = _format_large_number(avg_count_per_reply) if total_replies > 0 else "N/A"
|
||||
formatted_avg_tokens = _format_large_number(avg_tokens_per_reply) if total_replies > 0 else "N/A"
|
||||
|
||||
output.append(
|
||||
data_fmt.format(
|
||||
name,
|
||||
@@ -766,6 +778,62 @@ class StatisticOutputTask(AsyncTask):
|
||||
cost,
|
||||
avg_time_cost,
|
||||
std_time_cost,
|
||||
formatted_avg_count,
|
||||
formatted_avg_tokens,
|
||||
)
|
||||
)
|
||||
|
||||
output.append("")
|
||||
return "\n".join(output)
|
||||
|
||||
@staticmethod
|
||||
def _format_module_classified_stat(stats: Dict[str, Any]) -> str:
|
||||
"""
|
||||
格式化按模块分类的统计数据
|
||||
"""
|
||||
if stats[TOTAL_REQ_CNT] <= 0:
|
||||
return ""
|
||||
data_fmt = "{:<32} {:>10} {:>12} {:>12} {:>12} {:>9.2f}¥ {:>10.1f} {:>10.1f} {:>12} {:>12}"
|
||||
|
||||
total_replies = stats.get(TOTAL_REPLY_CNT, 0)
|
||||
|
||||
output = [
|
||||
"按模块分类统计:",
|
||||
" 模块名称 调用次数 输入Token 输出Token Token总量 累计花费 平均耗时(秒) 标准差(秒) 每次回复平均调用次数 每次回复平均Token数",
|
||||
]
|
||||
for module_name, count in sorted(stats[REQ_CNT_BY_MODULE].items()):
|
||||
name = f"{module_name[:29]}..." if len(module_name) > 32 else module_name
|
||||
in_tokens = stats[IN_TOK_BY_MODULE][module_name]
|
||||
out_tokens = stats[OUT_TOK_BY_MODULE][module_name]
|
||||
tokens = stats[TOTAL_TOK_BY_MODULE][module_name]
|
||||
cost = stats[COST_BY_MODULE][module_name]
|
||||
avg_time_cost = stats[AVG_TIME_COST_BY_MODULE][module_name]
|
||||
std_time_cost = stats[STD_TIME_COST_BY_MODULE][module_name]
|
||||
|
||||
# 计算每次回复平均值
|
||||
avg_count_per_reply = count / total_replies if total_replies > 0 else 0.0
|
||||
avg_tokens_per_reply = tokens / total_replies if total_replies > 0 else 0.0
|
||||
|
||||
# 格式化大数字
|
||||
formatted_count = _format_large_number(count)
|
||||
formatted_in_tokens = _format_large_number(in_tokens)
|
||||
formatted_out_tokens = _format_large_number(out_tokens)
|
||||
formatted_tokens = _format_large_number(tokens)
|
||||
formatted_avg_count = _format_large_number(avg_count_per_reply) if total_replies > 0 else "N/A"
|
||||
formatted_avg_tokens = _format_large_number(avg_tokens_per_reply) if total_replies > 0 else "N/A"
|
||||
|
||||
output.append(
|
||||
data_fmt.format(
|
||||
name,
|
||||
formatted_count,
|
||||
formatted_in_tokens,
|
||||
formatted_out_tokens,
|
||||
formatted_tokens,
|
||||
cost,
|
||||
avg_time_cost,
|
||||
std_time_cost,
|
||||
formatted_avg_count,
|
||||
formatted_avg_tokens,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -849,6 +917,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
# format总在线时间
|
||||
|
||||
# 按模型分类统计
|
||||
total_replies = stat_data.get(TOTAL_REPLY_CNT, 0)
|
||||
model_rows = "\n".join(
|
||||
[
|
||||
f"<tr>"
|
||||
@@ -860,11 +929,13 @@ class StatisticOutputTask(AsyncTask):
|
||||
f"<td>{stat_data[COST_BY_MODEL][model_name]:.2f} ¥</td>"
|
||||
f"<td>{stat_data[AVG_TIME_COST_BY_MODEL][model_name]:.1f} 秒</td>"
|
||||
f"<td>{stat_data[STD_TIME_COST_BY_MODEL][model_name]:.1f} 秒</td>"
|
||||
f"<td>{_format_large_number(count / total_replies, html=True) if total_replies > 0 else 'N/A'}</td>"
|
||||
f"<td>{_format_large_number(stat_data[TOTAL_TOK_BY_MODEL][model_name] / total_replies, html=True) if total_replies > 0 else 'N/A'}</td>"
|
||||
f"</tr>"
|
||||
for model_name, count in sorted(stat_data[REQ_CNT_BY_MODEL].items())
|
||||
]
|
||||
if stat_data[REQ_CNT_BY_MODEL]
|
||||
else ["<tr><td colspan='8' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
|
||||
else ["<tr><td colspan='10' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
|
||||
)
|
||||
# 按请求类型分类统计
|
||||
type_rows = "\n".join(
|
||||
@@ -878,11 +949,13 @@ class StatisticOutputTask(AsyncTask):
|
||||
f"<td>{stat_data[COST_BY_TYPE][req_type]:.2f} ¥</td>"
|
||||
f"<td>{stat_data[AVG_TIME_COST_BY_TYPE][req_type]:.1f} 秒</td>"
|
||||
f"<td>{stat_data[STD_TIME_COST_BY_TYPE][req_type]:.1f} 秒</td>"
|
||||
f"<td>{_format_large_number(count / total_replies, html=True) if total_replies > 0 else 'N/A'}</td>"
|
||||
f"<td>{_format_large_number(stat_data[TOTAL_TOK_BY_TYPE][req_type] / total_replies, html=True) if total_replies > 0 else 'N/A'}</td>"
|
||||
f"</tr>"
|
||||
for req_type, count in sorted(stat_data[REQ_CNT_BY_TYPE].items())
|
||||
]
|
||||
if stat_data[REQ_CNT_BY_TYPE]
|
||||
else ["<tr><td colspan='8' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
|
||||
else ["<tr><td colspan='10' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
|
||||
)
|
||||
# 按模块分类统计
|
||||
module_rows = "\n".join(
|
||||
@@ -896,11 +969,13 @@ class StatisticOutputTask(AsyncTask):
|
||||
f"<td>{stat_data[COST_BY_MODULE][module_name]:.2f} ¥</td>"
|
||||
f"<td>{stat_data[AVG_TIME_COST_BY_MODULE][module_name]:.1f} 秒</td>"
|
||||
f"<td>{stat_data[STD_TIME_COST_BY_MODULE][module_name]:.1f} 秒</td>"
|
||||
f"<td>{_format_large_number(count / total_replies, html=True) if total_replies > 0 else 'N/A'}</td>"
|
||||
f"<td>{_format_large_number(stat_data[TOTAL_TOK_BY_MODULE][module_name] / total_replies, html=True) if total_replies > 0 else 'N/A'}</td>"
|
||||
f"</tr>"
|
||||
for module_name, count in sorted(stat_data[REQ_CNT_BY_MODULE].items())
|
||||
]
|
||||
if stat_data[REQ_CNT_BY_MODULE]
|
||||
else ["<tr><td colspan='8' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
|
||||
else ["<tr><td colspan='10' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
|
||||
)
|
||||
|
||||
# 聊天消息统计
|
||||
@@ -975,7 +1050,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
<h2>按模型分类统计</h2>
|
||||
<div class=\"table-wrap\">
|
||||
<table>
|
||||
<thead><tr><th>模型名称</th><th>调用次数</th><th>输入Token</th><th>输出Token</th><th>Token总量</th><th>累计花费</th><th>平均耗时(秒)</th><th>标准差(秒)</th></tr></thead>
|
||||
<thead><tr><th>模型名称</th><th>调用次数</th><th>输入Token</th><th>输出Token</th><th>Token总量</th><th>累计花费</th><th>平均耗时(秒)</th><th>标准差(秒)</th><th>每次回复平均调用次数</th><th>每次回复平均Token数</th></tr></thead>
|
||||
<tbody>
|
||||
{model_rows}
|
||||
</tbody>
|
||||
@@ -986,7 +1061,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
<div class=\"table-wrap\">
|
||||
<table>
|
||||
<thead>
|
||||
<tr><th>模块名称</th><th>调用次数</th><th>输入Token</th><th>输出Token</th><th>Token总量</th><th>累计花费</th><th>平均耗时(秒)</th><th>标准差(秒)</th></tr>
|
||||
<tr><th>模块名称</th><th>调用次数</th><th>输入Token</th><th>输出Token</th><th>Token总量</th><th>累计花费</th><th>平均耗时(秒)</th><th>标准差(秒)</th><th>每次回复平均调用次数</th><th>每次回复平均Token数</th></tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{module_rows}
|
||||
@@ -998,7 +1073,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
<div class=\"table-wrap\">
|
||||
<table>
|
||||
<thead>
|
||||
<tr><th>请求类型</th><th>调用次数</th><th>输入Token</th><th>输出Token</th><th>Token总量</th><th>累计花费</th><th>平均耗时(秒)</th><th>标准差(秒)</th></tr>
|
||||
<tr><th>请求类型</th><th>调用次数</th><th>输入Token</th><th>输出Token</th><th>Token总量</th><th>累计花费</th><th>平均耗时(秒)</th><th>标准差(秒)</th><th>每次回复平均调用次数</th><th>每次回复平均Token数</th></tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{type_rows}
|
||||
|
||||
@@ -164,6 +164,47 @@ class ImageManager:
|
||||
tag_str = ",".join(emotion_list)
|
||||
return f"[表情包:{tag_str}]"
|
||||
|
||||
async def _save_emoji_file_if_needed(self, image_base64: str, image_hash: str, image_format: str) -> None:
|
||||
"""如果启用了steal_emoji且表情包未注册,保存文件到data/emoji目录
|
||||
|
||||
Args:
|
||||
image_base64: 图片的base64编码
|
||||
image_hash: 图片的MD5哈希值
|
||||
image_format: 图片格式
|
||||
"""
|
||||
if not global_config.emoji.steal_emoji:
|
||||
return
|
||||
|
||||
try:
|
||||
from src.chat.emoji_system.emoji_manager import EMOJI_DIR
|
||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(EMOJI_DIR, exist_ok=True)
|
||||
|
||||
# 检查是否已存在该表情包(通过哈希值)
|
||||
emoji_manager = get_emoji_manager()
|
||||
existing_emoji = await emoji_manager.get_emoji_from_manager(image_hash)
|
||||
if existing_emoji:
|
||||
logger.debug(f"[自动保存] 表情包已注册,跳过保存: {image_hash[:8]}...")
|
||||
return
|
||||
|
||||
# 生成文件名:使用哈希值前8位 + 格式
|
||||
filename = f"{image_hash[:8]}.{image_format}"
|
||||
file_path = os.path.join(EMOJI_DIR, filename)
|
||||
|
||||
# 检查文件是否已存在(可能之前保存过但未注册)
|
||||
if not os.path.exists(file_path):
|
||||
# 保存文件
|
||||
if base64_to_image(image_base64, file_path):
|
||||
logger.info(f"[自动保存] 表情包已保存到 {file_path} (Hash: {image_hash[:8]}...)")
|
||||
else:
|
||||
logger.warning(f"[自动保存] 保存表情包文件失败: {file_path}")
|
||||
else:
|
||||
logger.debug(f"[自动保存] 表情包文件已存在,跳过: {file_path}")
|
||||
except Exception as save_error:
|
||||
logger.warning(f"[自动保存] 保存表情包文件时出错: {save_error}")
|
||||
|
||||
async def get_emoji_description(self, image_base64: str) -> str:
|
||||
"""获取表情包描述,优先使用EmojiDescriptionCache表中的缓存数据"""
|
||||
try:
|
||||
@@ -193,12 +234,18 @@ class ImageManager:
|
||||
cache_record = EmojiDescriptionCache.get_or_none(EmojiDescriptionCache.emoji_hash == image_hash)
|
||||
if cache_record:
|
||||
# 优先使用情感标签,如果没有则使用详细描述
|
||||
result_text = ""
|
||||
if cache_record.emotion_tags:
|
||||
logger.info(f"[缓存命中] 使用EmojiDescriptionCache表中的情感标签: {cache_record.emotion_tags[:50]}...")
|
||||
return f"[表情包:{cache_record.emotion_tags}]"
|
||||
result_text = f"[表情包:{cache_record.emotion_tags}]"
|
||||
elif cache_record.description:
|
||||
logger.info(f"[缓存命中] 使用EmojiDescriptionCache表中的描述: {cache_record.description[:50]}...")
|
||||
return f"[表情包:{cache_record.description}]"
|
||||
result_text = f"[表情包:{cache_record.description}]"
|
||||
|
||||
# 即使缓存命中,如果启用了steal_emoji,也检查是否需要保存文件
|
||||
if result_text:
|
||||
await self._save_emoji_file_if_needed(image_base64, image_hash, image_format)
|
||||
return result_text
|
||||
except Exception as e:
|
||||
logger.debug(f"查询EmojiDescriptionCache时出错: {e}")
|
||||
|
||||
@@ -290,6 +337,9 @@ class ImageManager:
|
||||
except Exception as e:
|
||||
logger.error(f"保存表情包描述和情感标签缓存失败: {str(e)}")
|
||||
|
||||
# 如果启用了steal_emoji,自动保存表情包文件到data/emoji目录
|
||||
await self._save_emoji_file_if_needed(image_base64, image_hash, image_format)
|
||||
|
||||
return f"[表情包:{final_emotion}]"
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -372,6 +372,7 @@ class ChatHistory(BaseModel):
|
||||
theme = TextField() # 主题:这段对话的主要内容,一个简短的标题
|
||||
keywords = TextField() # 关键词:这段对话的关键词,JSON格式存储
|
||||
summary = TextField() # 概括:对这段话的平文本概括
|
||||
key_point = TextField(null=True) # 关键信息:话题中的关键信息点,JSON格式存储
|
||||
count = IntegerField(default=0) # 被检索次数
|
||||
forget_times = IntegerField(default=0) # 被遗忘检查的次数
|
||||
|
||||
|
||||
67
src/common/toml_utils.py
Normal file
67
src/common/toml_utils.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""
|
||||
TOML 工具函数
|
||||
|
||||
提供 TOML 文件的格式化保存功能,确保数组等元素以美观的多行格式输出。
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
import tomlkit
|
||||
from tomlkit.items import AoT, Table, Array
|
||||
|
||||
|
||||
def _format_toml_value(obj: Any, threshold: int, depth: int = 0) -> Any:
|
||||
"""递归格式化 TOML 值,将数组转换为多行格式"""
|
||||
# 处理 AoT (Array of Tables) - 保持原样,递归处理内部
|
||||
if isinstance(obj, AoT):
|
||||
for item in obj:
|
||||
_format_toml_value(item, threshold, depth)
|
||||
return obj
|
||||
|
||||
# 处理字典类型 (dict 或 Table)
|
||||
if isinstance(obj, (dict, Table)):
|
||||
for k, v in obj.items():
|
||||
obj[k] = _format_toml_value(v, threshold, depth)
|
||||
return obj
|
||||
|
||||
# 处理列表类型 (list 或 Array)
|
||||
if isinstance(obj, (list, Array)):
|
||||
# 如果是纯 list (非 tomlkit Array) 且包含字典/表,视为 AoT 的列表形式
|
||||
# 保持结构递归处理,避免转换为 Inline Table Array (因为 Inline Table 必须单行,复杂对象不友好)
|
||||
if isinstance(obj, list) and not isinstance(obj, Array) and obj and isinstance(obj[0], (dict, Table)):
|
||||
for i, item in enumerate(obj):
|
||||
obj[i] = _format_toml_value(item, threshold, depth)
|
||||
return obj
|
||||
|
||||
# 决定是否多行:仅在顶层且长度超过阈值时
|
||||
should_multiline = (depth == 0 and len(obj) > threshold)
|
||||
|
||||
# 如果已经是 tomlkit Array,原地修改以保留注释
|
||||
if isinstance(obj, Array):
|
||||
obj.multiline(should_multiline)
|
||||
for i, item in enumerate(obj):
|
||||
obj[i] = _format_toml_value(item, threshold, depth + 1)
|
||||
return obj
|
||||
|
||||
# 普通 list:转换为 tomlkit 数组
|
||||
arr = tomlkit.array()
|
||||
arr.multiline(should_multiline)
|
||||
|
||||
for item in obj:
|
||||
arr.append(_format_toml_value(item, threshold, depth + 1))
|
||||
return arr
|
||||
|
||||
# 其他基本类型直接返回
|
||||
return obj
|
||||
|
||||
|
||||
def save_toml_with_format(data: Any, file_path: str, multiline_threshold: int = 1) -> None:
|
||||
"""格式化 TOML 数据并保存到文件"""
|
||||
formatted = _format_toml_value(data, multiline_threshold) if multiline_threshold >= 0 else data
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
tomlkit.dump(formatted, f)
|
||||
|
||||
|
||||
def format_toml_string(data: Any, multiline_threshold: int = 1) -> str:
|
||||
"""格式化 TOML 数据并返回字符串"""
|
||||
formatted = _format_toml_value(data, multiline_threshold) if multiline_threshold >= 0 else data
|
||||
return tomlkit.dumps(formatted)
|
||||
@@ -88,6 +88,9 @@ class TaskConfig(ConfigBase):
|
||||
temperature: float = 0.3
|
||||
"""模型温度"""
|
||||
|
||||
slow_threshold: float = 15.0
|
||||
"""慢请求阈值(秒),超过此值会输出警告日志"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelTaskConfig(ConfigBase):
|
||||
|
||||
@@ -11,6 +11,7 @@ from rich.traceback import install
|
||||
from typing import List, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.toml_utils import format_toml_string
|
||||
from src.config.config_base import ConfigBase
|
||||
from src.config.official_configs import (
|
||||
BotConfig,
|
||||
@@ -56,7 +57,7 @@ TEMPLATE_DIR = os.path.join(PROJECT_ROOT, "template")
|
||||
|
||||
# 考虑到,实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
|
||||
# 对该字段的更新,请严格参照语义化版本规范:https://semver.org/lang/zh-CN/
|
||||
MMC_VERSION = "0.11.6-snapshot.1"
|
||||
MMC_VERSION = "0.11.6"
|
||||
|
||||
|
||||
def get_key_comment(toml_table, key):
|
||||
@@ -252,7 +253,7 @@ def _update_config_generic(config_name: str, template_name: str):
|
||||
# 如果配置有更新,立即保存到文件
|
||||
if config_updated:
|
||||
with open(old_config_path, "w", encoding="utf-8") as f:
|
||||
f.write(tomlkit.dumps(old_config))
|
||||
f.write(format_toml_string(old_config))
|
||||
logger.info(f"已保存更新后的{config_name}配置文件")
|
||||
else:
|
||||
logger.info(f"未检测到{config_name}模板默认值变动")
|
||||
@@ -313,9 +314,9 @@ def _update_config_generic(config_name: str, template_name: str):
|
||||
logger.info(f"开始合并{config_name}新旧配置...")
|
||||
_update_dict(new_config, old_config)
|
||||
|
||||
# 保存更新后的配置(保留注释和格式)
|
||||
# 保存更新后的配置(保留注释和格式,数组多行格式化)
|
||||
with open(new_config_path, "w", encoding="utf-8") as f:
|
||||
f.write(tomlkit.dumps(new_config))
|
||||
f.write(format_toml_string(new_config))
|
||||
logger.info(f"{config_name}配置文件更新完成,建议检查新配置文件中的内容,以免丢失重要信息")
|
||||
|
||||
|
||||
|
||||
@@ -97,7 +97,7 @@ def _compute_weights(population: List[Dict]) -> List[float]:
|
||||
|
||||
# 如果checked,权重乘以3
|
||||
weights = []
|
||||
for base_weight, checked in zip(base_weights, checked_flags):
|
||||
for base_weight, checked in zip(base_weights, checked_flags, strict=False):
|
||||
if checked:
|
||||
weights.append(base_weight * 3.0)
|
||||
else:
|
||||
|
||||
@@ -2,6 +2,7 @@ import time
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import asyncio
|
||||
from typing import List, Optional, Tuple
|
||||
import traceback
|
||||
from src.common.logger import get_logger
|
||||
@@ -91,6 +92,9 @@ class ExpressionLearner:
|
||||
# 维护每个chat的上次学习时间
|
||||
self.last_learning_time: float = time.time()
|
||||
|
||||
# 学习锁,防止并发执行学习任务
|
||||
self._learning_lock = asyncio.Lock()
|
||||
|
||||
# 学习参数
|
||||
_, self.enable_learning, self.learning_intensity = global_config.expression.get_expression_config_for_chat(
|
||||
self.chat_id
|
||||
@@ -139,32 +143,45 @@ class ExpressionLearner:
|
||||
Returns:
|
||||
bool: 是否成功触发学习
|
||||
"""
|
||||
if not self.should_trigger_learning():
|
||||
return
|
||||
# 使用异步锁防止并发执行
|
||||
async with self._learning_lock:
|
||||
# 在锁内检查,避免并发触发
|
||||
# 如果锁被持有,其他协程会等待,但等待期间条件可能已变化,所以需要再次检查
|
||||
if not self.should_trigger_learning():
|
||||
return
|
||||
|
||||
try:
|
||||
logger.info(f"在聊天流 {self.chat_name} 学习表达方式")
|
||||
# 学习语言风格
|
||||
learnt_style = await self.learn_and_store(num=25)
|
||||
# 保存学习开始前的时间戳,用于获取消息范围
|
||||
learning_start_timestamp = time.time()
|
||||
previous_learning_time = self.last_learning_time
|
||||
|
||||
# 立即更新学习时间,防止并发触发
|
||||
self.last_learning_time = learning_start_timestamp
|
||||
|
||||
# 更新学习时间
|
||||
self.last_learning_time = time.time()
|
||||
try:
|
||||
logger.info(f"在聊天流 {self.chat_name} 学习表达方式")
|
||||
# 学习语言风格,传递学习开始前的时间戳
|
||||
learnt_style = await self.learn_and_store(num=25, timestamp_start=previous_learning_time)
|
||||
|
||||
if learnt_style:
|
||||
logger.info(f"聊天流 {self.chat_name} 表达学习完成")
|
||||
else:
|
||||
logger.warning(f"聊天流 {self.chat_name} 表达学习未获得有效结果")
|
||||
if learnt_style:
|
||||
logger.info(f"聊天流 {self.chat_name} 表达学习完成")
|
||||
else:
|
||||
logger.warning(f"聊天流 {self.chat_name} 表达学习未获得有效结果")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"为聊天流 {self.chat_name} 触发学习失败: {e}")
|
||||
traceback.print_exc()
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error(f"为聊天流 {self.chat_name} 触发学习失败: {e}")
|
||||
traceback.print_exc()
|
||||
# 即使失败也保持时间戳更新,避免频繁重试
|
||||
return
|
||||
|
||||
async def learn_and_store(self, num: int = 10) -> List[Tuple[str, str, str]]:
|
||||
async def learn_and_store(self, num: int = 10, timestamp_start: Optional[float] = None) -> List[Tuple[str, str, str]]:
|
||||
"""
|
||||
学习并存储表达方式
|
||||
|
||||
Args:
|
||||
num: 学习数量
|
||||
timestamp_start: 学习开始的时间戳,如果为None则使用self.last_learning_time
|
||||
"""
|
||||
learnt_expressions = await self.learn_expression(num)
|
||||
learnt_expressions = await self.learn_expression(num, timestamp_start=timestamp_start)
|
||||
|
||||
if learnt_expressions is None:
|
||||
logger.info("没有学习到表达风格")
|
||||
@@ -226,19 +243,19 @@ class ExpressionLearner:
|
||||
match_responses = []
|
||||
try:
|
||||
response = response.strip()
|
||||
|
||||
|
||||
# 尝试提取JSON代码块(如果存在)
|
||||
json_pattern = r"```json\s*(.*?)\s*```"
|
||||
matches = re.findall(json_pattern, response, re.DOTALL)
|
||||
if matches:
|
||||
response = matches[0].strip()
|
||||
|
||||
|
||||
# 移除可能的markdown代码块标记(如果没有找到```json,但可能有```)
|
||||
if not matches:
|
||||
response = re.sub(r"^```\s*", "", response, flags=re.MULTILINE)
|
||||
response = re.sub(r"```\s*$", "", response, flags=re.MULTILINE)
|
||||
response = response.strip()
|
||||
|
||||
|
||||
# 检查是否已经是标准JSON数组格式
|
||||
if response.startswith("[") and response.endswith("]"):
|
||||
match_responses = json.loads(response)
|
||||
@@ -374,18 +391,22 @@ class ExpressionLearner:
|
||||
|
||||
return matched_expressions
|
||||
|
||||
async def learn_expression(self, num: int = 10) -> Optional[List[Tuple[str, str, str, str]]]:
|
||||
async def learn_expression(self, num: int = 10, timestamp_start: Optional[float] = None) -> Optional[List[Tuple[str, str, str, str]]]:
|
||||
"""从指定聊天流学习表达方式
|
||||
|
||||
Args:
|
||||
num: 学习数量
|
||||
timestamp_start: 学习开始的时间戳,如果为None则使用self.last_learning_time
|
||||
"""
|
||||
current_time = time.time()
|
||||
|
||||
# 使用传入的时间戳,如果没有则使用self.last_learning_time
|
||||
start_timestamp = timestamp_start if timestamp_start is not None else self.last_learning_time
|
||||
|
||||
# 获取上次学习之后的消息
|
||||
random_msg = get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=self.last_learning_time,
|
||||
timestamp_start=start_timestamp,
|
||||
timestamp_end=current_time,
|
||||
limit=num,
|
||||
)
|
||||
|
||||
@@ -13,28 +13,28 @@ logger = get_logger("expression_reflector")
|
||||
|
||||
class ExpressionReflector:
|
||||
"""表达反思器,管理单个聊天流的表达反思提问"""
|
||||
|
||||
|
||||
def __init__(self, chat_id: str):
|
||||
self.chat_id = chat_id
|
||||
self.last_ask_time: float = 0.0
|
||||
|
||||
|
||||
async def check_and_ask(self) -> bool:
|
||||
"""
|
||||
检查是否需要提问表达反思,如果需要则提问
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 是否执行了提问
|
||||
"""
|
||||
try:
|
||||
logger.debug(f"[Expression Reflection] 开始检查是否需要提问 (stream_id: {self.chat_id})")
|
||||
|
||||
|
||||
if not global_config.expression.reflect:
|
||||
logger.debug(f"[Expression Reflection] 表达反思功能未启用,跳过")
|
||||
logger.debug("[Expression Reflection] 表达反思功能未启用,跳过")
|
||||
return False
|
||||
|
||||
operator_config = global_config.expression.reflect_operator_id
|
||||
if not operator_config:
|
||||
logger.debug(f"[Expression Reflection] Operator ID 未配置,跳过")
|
||||
logger.debug("[Expression Reflection] Operator ID 未配置,跳过")
|
||||
return False
|
||||
|
||||
# 检查是否在允许列表中
|
||||
@@ -48,7 +48,7 @@ class ExpressionReflector:
|
||||
allow_reflect_chat_ids.append(parsed_chat_id)
|
||||
else:
|
||||
logger.warning(f"[Expression Reflection] 无法解析 allow_reflect 配置项: {stream_config}")
|
||||
|
||||
|
||||
if self.chat_id not in allow_reflect_chat_ids:
|
||||
logger.info(f"[Expression Reflection] 当前聊天流 {self.chat_id} 不在允许列表中,跳过")
|
||||
return False
|
||||
@@ -56,17 +56,21 @@ class ExpressionReflector:
|
||||
# 检查上一次提问时间
|
||||
current_time = time.time()
|
||||
time_since_last_ask = current_time - self.last_ask_time
|
||||
|
||||
|
||||
# 5-10分钟间隔,随机选择
|
||||
min_interval = 10 * 60 # 5分钟
|
||||
max_interval = 15 * 60 # 10分钟
|
||||
interval = random.uniform(min_interval, max_interval)
|
||||
|
||||
logger.info(f"[Expression Reflection] 上次提问时间: {self.last_ask_time:.2f}, 当前时间: {current_time:.2f}, 已过时间: {time_since_last_ask:.2f}秒 ({time_since_last_ask/60:.2f}分钟), 需要间隔: {interval:.2f}秒 ({interval/60:.2f}分钟)")
|
||||
|
||||
|
||||
logger.info(
|
||||
f"[Expression Reflection] 上次提问时间: {self.last_ask_time:.2f}, 当前时间: {current_time:.2f}, 已过时间: {time_since_last_ask:.2f}秒 ({time_since_last_ask / 60:.2f}分钟), 需要间隔: {interval:.2f}秒 ({interval / 60:.2f}分钟)"
|
||||
)
|
||||
|
||||
if time_since_last_ask < interval:
|
||||
remaining_time = interval - time_since_last_ask
|
||||
logger.info(f"[Expression Reflection] 距离上次提问时间不足,还需等待 {remaining_time:.2f}秒 ({remaining_time/60:.2f}分钟),跳过")
|
||||
logger.info(
|
||||
f"[Expression Reflection] 距离上次提问时间不足,还需等待 {remaining_time:.2f}秒 ({remaining_time / 60:.2f}分钟),跳过"
|
||||
)
|
||||
return False
|
||||
|
||||
# 检查是否已经有针对该 Operator 的 Tracker 在运行
|
||||
@@ -77,56 +81,59 @@ class ExpressionReflector:
|
||||
|
||||
# 获取未检查的表达
|
||||
try:
|
||||
logger.info(f"[Expression Reflection] 查询未检查且未拒绝的表达")
|
||||
expressions = (Expression
|
||||
.select()
|
||||
.where((Expression.checked == False) & (Expression.rejected == False))
|
||||
.limit(50))
|
||||
|
||||
logger.info("[Expression Reflection] 查询未检查且未拒绝的表达")
|
||||
expressions = (
|
||||
Expression.select().where((~Expression.checked) & (~Expression.rejected)).limit(50)
|
||||
)
|
||||
|
||||
expr_list = list(expressions)
|
||||
logger.info(f"[Expression Reflection] 找到 {len(expr_list)} 个候选表达")
|
||||
|
||||
|
||||
if not expr_list:
|
||||
logger.info(f"[Expression Reflection] 没有可用的表达,跳过")
|
||||
logger.info("[Expression Reflection] 没有可用的表达,跳过")
|
||||
return False
|
||||
|
||||
target_expr: Expression = random.choice(expr_list)
|
||||
logger.info(f"[Expression Reflection] 随机选择了表达 ID: {target_expr.id}, Situation: {target_expr.situation}, Style: {target_expr.style}")
|
||||
|
||||
logger.info(
|
||||
f"[Expression Reflection] 随机选择了表达 ID: {target_expr.id}, Situation: {target_expr.situation}, Style: {target_expr.style}"
|
||||
)
|
||||
|
||||
# 生成询问文本
|
||||
ask_text = _generate_ask_text(target_expr)
|
||||
if not ask_text:
|
||||
logger.warning(f"[Expression Reflection] 生成询问文本失败,跳过")
|
||||
logger.warning("[Expression Reflection] 生成询问文本失败,跳过")
|
||||
return False
|
||||
|
||||
logger.info(f"[Expression Reflection] 准备向 Operator {operator_config} 发送提问")
|
||||
# 发送给 Operator
|
||||
await _send_to_operator(operator_config, ask_text, target_expr)
|
||||
|
||||
|
||||
# 更新上一次提问时间
|
||||
self.last_ask_time = current_time
|
||||
logger.info(f"[Expression Reflection] 提问成功,已更新上次提问时间为 {current_time:.2f}")
|
||||
|
||||
|
||||
return True
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Expression Reflection] 检查或提问过程中出错: {e}")
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"[Expression Reflection] 检查或提问过程中出错: {e}")
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
return False
|
||||
|
||||
|
||||
class ExpressionReflectorManager:
|
||||
"""表达反思管理器,管理多个聊天流的表达反思实例"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self.reflectors: Dict[str, ExpressionReflector] = {}
|
||||
|
||||
|
||||
def get_or_create_reflector(self, chat_id: str) -> ExpressionReflector:
|
||||
"""获取或创建指定聊天流的表达反思实例"""
|
||||
if chat_id not in self.reflectors:
|
||||
@@ -141,6 +148,7 @@ expression_reflector_manager = ExpressionReflectorManager()
|
||||
async def _check_tracker_exists(operator_config: str) -> bool:
|
||||
"""检查指定 Operator 是否已有活跃的 Tracker"""
|
||||
from src.express.reflect_tracker import reflect_tracker_manager
|
||||
|
||||
chat_manager = get_chat_manager()
|
||||
chat_stream = None
|
||||
|
||||
@@ -150,12 +158,12 @@ async def _check_tracker_exists(operator_config: str) -> bool:
|
||||
platform = parts[0]
|
||||
id_str = parts[1]
|
||||
stream_type = parts[2]
|
||||
|
||||
|
||||
user_info = None
|
||||
group_info = None
|
||||
|
||||
|
||||
from maim_message import UserInfo, GroupInfo
|
||||
|
||||
|
||||
if stream_type == "group":
|
||||
group_info = GroupInfo(group_id=id_str, platform=platform)
|
||||
user_info = UserInfo(user_id="system", user_nickname="System", platform=platform)
|
||||
@@ -203,12 +211,12 @@ async def _send_to_operator(operator_config: str, text: str, expr: Expression):
|
||||
platform = parts[0]
|
||||
id_str = parts[1]
|
||||
stream_type = parts[2]
|
||||
|
||||
|
||||
user_info = None
|
||||
group_info = None
|
||||
|
||||
|
||||
from maim_message import UserInfo, GroupInfo
|
||||
|
||||
|
||||
if stream_type == "group":
|
||||
group_info = GroupInfo(group_id=id_str, platform=platform)
|
||||
user_info = UserInfo(user_id="system", user_nickname="System", platform=platform)
|
||||
@@ -232,20 +240,13 @@ async def _send_to_operator(operator_config: str, text: str, expr: Expression):
|
||||
return
|
||||
|
||||
stream_id = chat_stream.stream_id
|
||||
|
||||
|
||||
# 注册 Tracker
|
||||
from src.express.reflect_tracker import ReflectTracker, reflect_tracker_manager
|
||||
|
||||
|
||||
tracker = ReflectTracker(chat_stream=chat_stream, expression=expr, created_time=time.time())
|
||||
reflect_tracker_manager.add_tracker(stream_id, tracker)
|
||||
|
||||
|
||||
# 发送消息
|
||||
await send_api.text_to_stream(
|
||||
text=text,
|
||||
stream_id=stream_id,
|
||||
typing=True
|
||||
)
|
||||
await send_api.text_to_stream(text=text, stream_id=stream_id, typing=True)
|
||||
logger.info(f"Sent expression reflect query to operator {operator_config} for expr {expr.id}")
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -128,7 +128,7 @@ class ExpressionSelector:
|
||||
|
||||
# 优化:一次性查询所有相关chat_id的表达方式,排除 rejected=1 的表达
|
||||
style_query = Expression.select().where(
|
||||
(Expression.chat_id.in_(related_chat_ids)) & (Expression.rejected == False)
|
||||
(Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected)
|
||||
)
|
||||
|
||||
style_exprs = [
|
||||
@@ -151,7 +151,6 @@ class ExpressionSelector:
|
||||
else:
|
||||
selected_style = []
|
||||
|
||||
logger.info(f"随机选择,为聊天室 {chat_id} 选择了 {len(selected_style)} 个表达方式")
|
||||
return selected_style
|
||||
|
||||
except Exception as e:
|
||||
@@ -294,7 +293,7 @@ class ExpressionSelector:
|
||||
if valid_expressions:
|
||||
self.update_expressions_last_active_time(valid_expressions)
|
||||
|
||||
logger.info(f"classic模式从{len(all_expressions)}个情境中选择了{len(valid_expressions)}个")
|
||||
logger.debug(f"从{len(all_expressions)}个情境中选择了{len(valid_expressions)}个")
|
||||
return valid_expressions, selected_ids
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -4,34 +4,32 @@ from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Expression
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.config.config import model_config, global_config
|
||||
from src.config.config import model_config
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
get_raw_msg_by_timestamp_with_chat,
|
||||
build_readable_messages,
|
||||
)
|
||||
from datetime import datetime
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
pass
|
||||
|
||||
logger = get_logger("reflect_tracker")
|
||||
|
||||
|
||||
class ReflectTracker:
|
||||
def __init__(self, chat_stream: ChatStream, expression: Expression, created_time: float):
|
||||
self.chat_stream = chat_stream
|
||||
self.expression = expression
|
||||
self.created_time = created_time
|
||||
# self.message_count = 0 # Replaced by checking message list length
|
||||
self.last_check_msg_count = 0
|
||||
self.last_check_msg_count = 0
|
||||
self.max_message_count = 30
|
||||
self.max_duration = 15 * 60 # 15 minutes
|
||||
|
||||
|
||||
# LLM for judging response
|
||||
self.judge_model = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils, request_type="reflect.tracker"
|
||||
)
|
||||
|
||||
self.judge_model = LLMRequest(model_set=model_config.model_task_config.utils, request_type="reflect.tracker")
|
||||
|
||||
self._init_prompts()
|
||||
|
||||
def _init_prompts(self):
|
||||
@@ -72,16 +70,16 @@ class ReflectTracker:
|
||||
if time.time() - self.created_time > self.max_duration:
|
||||
logger.info(f"ReflectTracker for expr {self.expression.id} timed out (duration).")
|
||||
return True
|
||||
|
||||
|
||||
# Fetch messages since creation
|
||||
msg_list = get_raw_msg_by_timestamp_with_chat(
|
||||
chat_id=self.chat_stream.stream_id,
|
||||
timestamp_start=self.created_time,
|
||||
timestamp_end=time.time(),
|
||||
)
|
||||
|
||||
|
||||
current_msg_count = len(msg_list)
|
||||
|
||||
|
||||
# Check message limit
|
||||
if current_msg_count > self.max_message_count:
|
||||
logger.info(f"ReflectTracker for expr {self.expression.id} timed out (message count).")
|
||||
@@ -90,9 +88,9 @@ class ReflectTracker:
|
||||
# If no new messages since last check, skip
|
||||
if current_msg_count <= self.last_check_msg_count:
|
||||
return False
|
||||
|
||||
|
||||
self.last_check_msg_count = current_msg_count
|
||||
|
||||
|
||||
# Build context block
|
||||
# Use simple readable format
|
||||
context_block = build_readable_messages(
|
||||
@@ -109,78 +107,83 @@ class ReflectTracker:
|
||||
"reflect_judge_prompt",
|
||||
situation=self.expression.situation,
|
||||
style=self.expression.style,
|
||||
context_block=context_block
|
||||
context_block=context_block,
|
||||
)
|
||||
|
||||
|
||||
logger.info(f"ReflectTracker LLM Prompt: {prompt}")
|
||||
|
||||
|
||||
response, _ = await self.judge_model.generate_response_async(prompt, temperature=0.1)
|
||||
|
||||
|
||||
logger.info(f"ReflectTracker LLM Response: {response}")
|
||||
|
||||
|
||||
# Parse JSON
|
||||
import json
|
||||
import re
|
||||
from json_repair import repair_json
|
||||
|
||||
|
||||
json_pattern = r"```json\s*(.*?)\s*```"
|
||||
matches = re.findall(json_pattern, response, re.DOTALL)
|
||||
if not matches:
|
||||
# Try to parse raw response if no code block
|
||||
matches = [response]
|
||||
|
||||
|
||||
json_obj = json.loads(repair_json(matches[0]))
|
||||
|
||||
|
||||
judgment = json_obj.get("judgment")
|
||||
|
||||
|
||||
if judgment == "Approve":
|
||||
self.expression.checked = True
|
||||
self.expression.rejected = False
|
||||
self.expression.save()
|
||||
logger.info(f"Expression {self.expression.id} approved by operator.")
|
||||
return True
|
||||
|
||||
|
||||
elif judgment == "Reject":
|
||||
self.expression.checked = True
|
||||
corrected_situation = json_obj.get("corrected_situation")
|
||||
corrected_style = json_obj.get("corrected_style")
|
||||
|
||||
|
||||
# 检查是否有更新
|
||||
has_update = bool(corrected_situation or corrected_style)
|
||||
|
||||
|
||||
if corrected_situation:
|
||||
self.expression.situation = corrected_situation
|
||||
if corrected_style:
|
||||
self.expression.style = corrected_style
|
||||
|
||||
|
||||
# 如果拒绝但未更新,标记为 rejected=1
|
||||
if not has_update:
|
||||
self.expression.rejected = True
|
||||
else:
|
||||
self.expression.rejected = False
|
||||
|
||||
|
||||
self.expression.save()
|
||||
|
||||
|
||||
if has_update:
|
||||
logger.info(f"Expression {self.expression.id} rejected and updated by operator. New situation: {corrected_situation}, New style: {corrected_style}")
|
||||
logger.info(
|
||||
f"Expression {self.expression.id} rejected and updated by operator. New situation: {corrected_situation}, New style: {corrected_style}"
|
||||
)
|
||||
else:
|
||||
logger.info(f"Expression {self.expression.id} rejected but no correction provided, marked as rejected=1.")
|
||||
logger.info(
|
||||
f"Expression {self.expression.id} rejected but no correction provided, marked as rejected=1."
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
elif judgment == "Ignore":
|
||||
logger.info(f"ReflectTracker for expr {self.expression.id} judged as Ignore.")
|
||||
return False
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in ReflectTracker check: {e}")
|
||||
return False
|
||||
|
||||
return False
|
||||
|
||||
|
||||
# Global manager for trackers
|
||||
class ReflectTrackerManager:
|
||||
def __init__(self):
|
||||
self.trackers: Dict[str, ReflectTracker] = {} # chat_id -> tracker
|
||||
self.trackers: Dict[str, ReflectTracker] = {} # chat_id -> tracker
|
||||
|
||||
def add_tracker(self, chat_id: str, tracker: ReflectTracker):
|
||||
self.trackers[chat_id] = tracker
|
||||
@@ -192,5 +195,5 @@ class ReflectTrackerManager:
|
||||
if chat_id in self.trackers:
|
||||
del self.trackers[chat_id]
|
||||
|
||||
reflect_tracker_manager = ReflectTrackerManager()
|
||||
|
||||
reflect_tracker_manager = ReflectTrackerManager()
|
||||
|
||||
924
src/hippo_memorizer/chat_history_summarizer.py
Normal file
924
src/hippo_memorizer/chat_history_summarizer.py
Normal file
@@ -0,0 +1,924 @@
|
||||
"""
|
||||
聊天内容概括器
|
||||
用于累积、打包和压缩聊天记录
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
from dataclasses import dataclass, field
|
||||
from json_repair import repair_json
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.plugin_system.apis import message_api
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages
|
||||
from src.person_info.person_info import Person
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
|
||||
logger = get_logger("chat_history_summarizer")
|
||||
|
||||
HIPPO_CACHE_DIR = Path(__file__).resolve().parents[2] / "data" / "hippo_memorizer"
|
||||
|
||||
|
||||
def init_prompt():
|
||||
"""初始化提示词模板"""
|
||||
|
||||
topic_analysis_prompt = """
|
||||
【历史话题标题列表】(仅标题,不含具体内容):
|
||||
{history_topics_block}
|
||||
|
||||
【本次聊天记录】(每条消息前有编号,用于后续引用):
|
||||
{messages_block}
|
||||
|
||||
请完成以下任务:
|
||||
**识别话题**
|
||||
1. 识别【本次聊天记录】中正在进行的一个或多个话题;
|
||||
2. 判断【历史话题标题列表】中的话题是否在【本次聊天记录】中出现,如果出现,则直接使用该历史话题标题字符串;
|
||||
|
||||
**选取消息**
|
||||
1. 对于每个话题(新话题或历史话题),从上述带编号的消息中选出与该话题强相关的消息编号列表;
|
||||
2. 每个话题用一句话清晰地描述正在发生的事件,必须包含时间(大致即可)、人物、主要事件和主题,保证精准且有区分度;
|
||||
|
||||
请先输出一段简短思考,说明有什么话题,哪些是不包含在历史话题中的,哪些是包含在历史话题中的,并说明为什么;
|
||||
然后严格以 JSON 格式输出【本次聊天记录】中涉及的话题,格式如下:
|
||||
[
|
||||
{{
|
||||
"topic": "话题",
|
||||
"message_indices": [1, 2, 5]
|
||||
}},
|
||||
...
|
||||
]
|
||||
"""
|
||||
Prompt(topic_analysis_prompt, "hippo_topic_analysis_prompt")
|
||||
|
||||
topic_summary_prompt = """
|
||||
请基于以下话题,对聊天记录片段进行概括,提取以下信息:
|
||||
|
||||
**话题**:{topic}
|
||||
|
||||
**要求**:
|
||||
1. 关键词:提取与话题相关的关键词,用列表形式返回(3-10个关键词)
|
||||
2. 概括:对这段话的平文本概括(50-200字),要求:
|
||||
- 仔细地转述发生的事件和聊天内容;
|
||||
- 可以适当摘取聊天记录中的原文;
|
||||
- 重点突出事件的发展过程和结果;
|
||||
- 围绕话题这个中心进行概括。
|
||||
3. 关键信息:提取话题中的关键信息点,用列表形式返回(3-8个关键信息点),每个关键信息点应该简洁明了。
|
||||
|
||||
请以JSON格式返回,格式如下:
|
||||
{{
|
||||
"keywords": ["关键词1", "关键词2", ...],
|
||||
"summary": "概括内容",
|
||||
"key_point": ["关键信息1", "关键信息2", ...]
|
||||
}}
|
||||
|
||||
聊天记录:
|
||||
{original_text}
|
||||
|
||||
请直接返回JSON,不要包含其他内容。
|
||||
"""
|
||||
Prompt(topic_summary_prompt, "hippo_topic_summary_prompt")
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageBatch:
|
||||
"""消息批次(用于触发话题检查的原始消息累积)"""
|
||||
|
||||
messages: List[DatabaseMessages]
|
||||
start_time: float
|
||||
end_time: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class TopicCacheItem:
|
||||
"""
|
||||
话题缓存项
|
||||
|
||||
Attributes:
|
||||
topic: 话题标题(一句话描述时间、人物、事件和主题)
|
||||
messages: 与该话题相关的消息字符串列表(已经通过 build 函数转成可读文本)
|
||||
participants: 涉及到的发言人昵称集合
|
||||
no_update_checks: 连续多少次“检查”没有新增内容
|
||||
"""
|
||||
|
||||
topic: str
|
||||
messages: List[str] = field(default_factory=list)
|
||||
participants: Set[str] = field(default_factory=set)
|
||||
no_update_checks: int = 0
|
||||
|
||||
|
||||
class ChatHistorySummarizer:
|
||||
"""聊天内容概括器"""
|
||||
|
||||
def __init__(self, chat_id: str, check_interval: int = 60):
|
||||
"""
|
||||
初始化聊天内容概括器
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
check_interval: 定期检查间隔(秒),默认60秒
|
||||
"""
|
||||
self.chat_id = chat_id
|
||||
self._chat_display_name = self._get_chat_display_name()
|
||||
self.log_prefix = f"[{self._chat_display_name}]"
|
||||
|
||||
# 记录时间点,用于计算新消息
|
||||
self.last_check_time = time.time()
|
||||
|
||||
# 记录上一次话题检查的时间,用于判断是否需要触发检查
|
||||
self.last_topic_check_time = time.time()
|
||||
|
||||
# 当前累积的消息批次
|
||||
self.current_batch: Optional[MessageBatch] = None
|
||||
|
||||
# 话题缓存:topic_str -> TopicCacheItem
|
||||
# 在内存中维护,并通过本地文件实时持久化
|
||||
self.topic_cache: Dict[str, TopicCacheItem] = {}
|
||||
self._safe_chat_id = self._sanitize_chat_id(self.chat_id)
|
||||
self._topic_cache_file = HIPPO_CACHE_DIR / f"{self._safe_chat_id}.json"
|
||||
# 注意:批次加载需要异步查询消息,所以在 start() 中调用
|
||||
|
||||
# LLM请求器,用于压缩聊天内容
|
||||
self.summarizer_llm = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils, request_type="chat_history_summarizer"
|
||||
)
|
||||
|
||||
# 后台循环相关
|
||||
self.check_interval = check_interval # 检查间隔(秒)
|
||||
self._periodic_task: Optional[asyncio.Task] = None
|
||||
self._running = False
|
||||
|
||||
def _get_chat_display_name(self) -> str:
|
||||
"""获取聊天显示名称"""
|
||||
try:
|
||||
chat_name = get_chat_manager().get_stream_name(self.chat_id)
|
||||
if chat_name:
|
||||
return chat_name
|
||||
# 如果获取失败,使用简化的chat_id显示
|
||||
if len(self.chat_id) > 20:
|
||||
return f"{self.chat_id[:8]}..."
|
||||
return self.chat_id
|
||||
except Exception:
|
||||
# 如果获取失败,使用简化的chat_id显示
|
||||
if len(self.chat_id) > 20:
|
||||
return f"{self.chat_id[:8]}..."
|
||||
return self.chat_id
|
||||
|
||||
def _sanitize_chat_id(self, chat_id: str) -> str:
|
||||
"""用于生成可作为文件名的 chat_id"""
|
||||
return re.sub(r"[^a-zA-Z0-9_.-]", "_", chat_id)
|
||||
|
||||
def _load_topic_cache_from_disk(self):
|
||||
"""在启动时加载本地话题缓存(同步部分),支持重启后继续"""
|
||||
try:
|
||||
if not self._topic_cache_file.exists():
|
||||
return
|
||||
|
||||
with self._topic_cache_file.open("r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
self.last_topic_check_time = data.get("last_topic_check_time", self.last_topic_check_time)
|
||||
topics_data = data.get("topics", {})
|
||||
loaded_count = 0
|
||||
for topic, payload in topics_data.items():
|
||||
self.topic_cache[topic] = TopicCacheItem(
|
||||
topic=topic,
|
||||
messages=payload.get("messages", []),
|
||||
participants=set(payload.get("participants", [])),
|
||||
no_update_checks=payload.get("no_update_checks", 0),
|
||||
)
|
||||
loaded_count += 1
|
||||
|
||||
if loaded_count:
|
||||
logger.info(f"{self.log_prefix} 已加载 {loaded_count} 个话题缓存,继续追踪")
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 加载话题缓存失败: {e}")
|
||||
|
||||
async def _load_batch_from_disk(self):
|
||||
"""在启动时加载聊天批次,支持重启后继续"""
|
||||
try:
|
||||
if not self._topic_cache_file.exists():
|
||||
return
|
||||
|
||||
with self._topic_cache_file.open("r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
batch_data = data.get("current_batch")
|
||||
if not batch_data:
|
||||
return
|
||||
|
||||
start_time = batch_data.get("start_time")
|
||||
end_time = batch_data.get("end_time")
|
||||
if not start_time or not end_time:
|
||||
return
|
||||
|
||||
# 根据时间范围重新查询消息
|
||||
messages = message_api.get_messages_by_time_in_chat(
|
||||
chat_id=self.chat_id,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
limit=0,
|
||||
limit_mode="latest",
|
||||
filter_mai=False,
|
||||
filter_command=False,
|
||||
)
|
||||
|
||||
if messages:
|
||||
self.current_batch = MessageBatch(
|
||||
messages=messages,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
logger.info(f"{self.log_prefix} 已恢复聊天批次,包含 {len(messages)} 条消息")
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 加载聊天批次失败: {e}")
|
||||
|
||||
def _persist_topic_cache(self):
|
||||
"""实时持久化话题缓存和聊天批次,避免重启后丢失"""
|
||||
try:
|
||||
# 如果既没有话题缓存也没有批次,删除缓存文件
|
||||
if not self.topic_cache and not self.current_batch:
|
||||
if self._topic_cache_file.exists():
|
||||
self._topic_cache_file.unlink()
|
||||
return
|
||||
|
||||
HIPPO_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||
data = {
|
||||
"chat_id": self.chat_id,
|
||||
"last_topic_check_time": self.last_topic_check_time,
|
||||
"topics": {
|
||||
topic: {
|
||||
"messages": item.messages,
|
||||
"participants": list(item.participants),
|
||||
"no_update_checks": item.no_update_checks,
|
||||
}
|
||||
for topic, item in self.topic_cache.items()
|
||||
},
|
||||
}
|
||||
|
||||
# 保存当前批次的时间范围(如果有)
|
||||
if self.current_batch:
|
||||
data["current_batch"] = {
|
||||
"start_time": self.current_batch.start_time,
|
||||
"end_time": self.current_batch.end_time,
|
||||
}
|
||||
|
||||
with self._topic_cache_file.open("w", encoding="utf-8") as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=2)
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 持久化话题缓存失败: {e}")
|
||||
|
||||
async def process(self, current_time: Optional[float] = None):
|
||||
"""
|
||||
处理聊天内容概括
|
||||
|
||||
Args:
|
||||
current_time: 当前时间戳,如果为None则使用time.time()
|
||||
"""
|
||||
if current_time is None:
|
||||
current_time = time.time()
|
||||
|
||||
try:
|
||||
# 获取从上次检查时间到当前时间的新消息
|
||||
new_messages = message_api.get_messages_by_time_in_chat(
|
||||
chat_id=self.chat_id,
|
||||
start_time=self.last_check_time,
|
||||
end_time=current_time,
|
||||
limit=0,
|
||||
limit_mode="latest",
|
||||
filter_mai=False, # 不过滤bot消息,因为需要检查bot是否发言
|
||||
filter_command=False,
|
||||
)
|
||||
|
||||
if not new_messages:
|
||||
# 没有新消息,检查是否需要进行“话题检查”
|
||||
if self.current_batch and self.current_batch.messages:
|
||||
await self._check_and_run_topic_check(current_time)
|
||||
self.last_check_time = current_time
|
||||
return
|
||||
|
||||
logger.debug(
|
||||
f"{self.log_prefix} 开始处理聊天概括,时间窗口: {self.last_check_time:.2f} -> {current_time:.2f}"
|
||||
)
|
||||
|
||||
# 有新消息,更新最后检查时间
|
||||
self.last_check_time = current_time
|
||||
|
||||
# 如果有当前批次,添加新消息
|
||||
if self.current_batch:
|
||||
before_count = len(self.current_batch.messages)
|
||||
self.current_batch.messages.extend(new_messages)
|
||||
self.current_batch.end_time = current_time
|
||||
logger.info(f"{self.log_prefix} 更新聊天检查批次: {before_count} -> {len(self.current_batch.messages)} 条消息")
|
||||
# 更新批次后持久化
|
||||
self._persist_topic_cache()
|
||||
else:
|
||||
# 创建新批次
|
||||
self.current_batch = MessageBatch(
|
||||
messages=new_messages,
|
||||
start_time=new_messages[0].time if new_messages else current_time,
|
||||
end_time=current_time,
|
||||
)
|
||||
logger.debug(f"{self.log_prefix} 新建聊天检查批次: {len(new_messages)} 条消息")
|
||||
# 创建批次后持久化
|
||||
self._persist_topic_cache()
|
||||
|
||||
# 检查是否需要触发“话题检查”
|
||||
await self._check_and_run_topic_check(current_time)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 处理聊天内容概括时出错: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
async def _check_and_run_topic_check(self, current_time: float):
|
||||
"""
|
||||
检查是否需要进行一次“话题检查”
|
||||
|
||||
触发条件:
|
||||
- 当前批次消息数 >= 100,或者
|
||||
- 距离上一次检查的时间 > 3600 秒(1小时)
|
||||
"""
|
||||
if not self.current_batch or not self.current_batch.messages:
|
||||
return
|
||||
|
||||
messages = self.current_batch.messages
|
||||
message_count = len(messages)
|
||||
time_since_last_check = current_time - self.last_topic_check_time
|
||||
|
||||
# 格式化时间差显示
|
||||
if time_since_last_check < 60:
|
||||
time_str = f"{time_since_last_check:.1f}秒"
|
||||
elif time_since_last_check < 3600:
|
||||
time_str = f"{time_since_last_check / 60:.1f}分钟"
|
||||
else:
|
||||
time_str = f"{time_since_last_check / 3600:.1f}小时"
|
||||
|
||||
logger.debug(
|
||||
f"{self.log_prefix} 批次状态检查 | 消息数: {message_count} | 距上次检查: {time_str}"
|
||||
)
|
||||
|
||||
# 检查“话题检查”触发条件
|
||||
should_check = False
|
||||
|
||||
# 条件1: 消息数量 >= 100,触发一次检查
|
||||
if message_count >= 80:
|
||||
should_check = True
|
||||
logger.info(f"{self.log_prefix} 触发检查条件: 消息数量达到 {message_count} 条(阈值: 100条)")
|
||||
|
||||
# 条件2: 距离上一次检查 > 3600 秒(1小时),触发一次检查
|
||||
elif time_since_last_check > 2400:
|
||||
should_check = True
|
||||
logger.info(f"{self.log_prefix} 触发检查条件: 距上次检查 {time_str}(阈值: 1小时)")
|
||||
|
||||
if should_check:
|
||||
await self._run_topic_check_and_update_cache(messages)
|
||||
# 本批次已经被处理为话题信息,可以清空
|
||||
self.current_batch = None
|
||||
# 更新上一次检查时间,并持久化
|
||||
self.last_topic_check_time = current_time
|
||||
self._persist_topic_cache()
|
||||
|
||||
async def _run_topic_check_and_update_cache(self, messages: List[DatabaseMessages]):
|
||||
"""
|
||||
执行一次“话题检查”:
|
||||
1. 首先确认这段消息里是否有 Bot 发言,没有则直接丢弃本次批次;
|
||||
2. 将消息编号并转成字符串,构造 LLM Prompt;
|
||||
3. 把历史话题标题列表放入 Prompt,要求 LLM:
|
||||
- 识别当前聊天中的话题(1 个或多个);
|
||||
- 为每个话题选出相关消息编号;
|
||||
- 若话题属于历史话题,则沿用原话题标题;
|
||||
4. LLM 返回 JSON:多个 {topic, message_indices};
|
||||
5. 更新本地话题缓存,并根据规则触发“话题打包存储”。
|
||||
"""
|
||||
if not messages:
|
||||
return
|
||||
|
||||
start_time = messages[0].time
|
||||
end_time = messages[-1].time
|
||||
|
||||
logger.info(
|
||||
f"{self.log_prefix} 开始话题检查 | 消息数: {len(messages)} | 时间范围: {start_time:.2f} - {end_time:.2f}"
|
||||
)
|
||||
|
||||
# 1. 检查当前批次内是否有 bot 发言(只检查当前批次,不往前推)
|
||||
# 原因:我们要记录的是 bot 参与过的对话片段,如果当前批次内 bot 没有发言,
|
||||
# 说明 bot 没有参与这段对话,不应该记录
|
||||
bot_user_id = str(global_config.bot.qq_account)
|
||||
has_bot_message = False
|
||||
|
||||
for msg in messages:
|
||||
if msg.user_info.user_id == bot_user_id:
|
||||
has_bot_message = True
|
||||
break
|
||||
|
||||
if not has_bot_message:
|
||||
logger.info(
|
||||
f"{self.log_prefix} 当前批次内无 Bot 发言,丢弃本次检查 | 时间范围: {start_time:.2f} - {end_time:.2f}"
|
||||
)
|
||||
return
|
||||
|
||||
# 2. 构造编号后的消息字符串和参与者信息
|
||||
numbered_lines, index_to_msg_str, index_to_msg_text, index_to_participants = self._build_numbered_messages_for_llm(messages)
|
||||
|
||||
# 3. 调用 LLM 识别话题,并得到 topic -> indices
|
||||
existing_topics = list(self.topic_cache.keys())
|
||||
success, topic_to_indices = await self._analyze_topics_with_llm(
|
||||
numbered_lines=numbered_lines,
|
||||
existing_topics=existing_topics,
|
||||
)
|
||||
|
||||
if not success or not topic_to_indices:
|
||||
logger.warning(f"{self.log_prefix} 话题识别失败或无有效话题,本次检查忽略")
|
||||
# 即使识别失败,也认为是一次“检查”,但不更新 no_update_checks(保持原状)
|
||||
return
|
||||
|
||||
# 4. 统计哪些话题在本次检查中有新增内容
|
||||
updated_topics: Set[str] = set()
|
||||
|
||||
for topic, indices in topic_to_indices.items():
|
||||
if not indices:
|
||||
continue
|
||||
|
||||
item = self.topic_cache.get(topic)
|
||||
if not item:
|
||||
# 新话题
|
||||
item = TopicCacheItem(topic=topic)
|
||||
self.topic_cache[topic] = item
|
||||
|
||||
# 收集属于该话题的消息文本(不带编号)
|
||||
topic_msg_texts: List[str] = []
|
||||
new_participants: Set[str] = set()
|
||||
for idx in indices:
|
||||
msg_text = index_to_msg_text.get(idx)
|
||||
if not msg_text:
|
||||
continue
|
||||
topic_msg_texts.append(msg_text)
|
||||
new_participants.update(index_to_participants.get(idx, set()))
|
||||
|
||||
if not topic_msg_texts:
|
||||
continue
|
||||
|
||||
# 将本次检查中属于该话题的所有消息合并为一个字符串(不带编号)
|
||||
merged_text = "\n".join(topic_msg_texts)
|
||||
item.messages.append(merged_text)
|
||||
item.participants.update(new_participants)
|
||||
# 本次检查中该话题有更新,重置计数
|
||||
item.no_update_checks = 0
|
||||
updated_topics.add(topic)
|
||||
|
||||
# 5. 对于本次没有更新的历史话题,no_update_checks + 1
|
||||
for topic, item in list(self.topic_cache.items()):
|
||||
if topic not in updated_topics:
|
||||
item.no_update_checks += 1
|
||||
|
||||
# 6. 检查是否有话题需要打包存储
|
||||
topics_to_finalize: List[str] = []
|
||||
for topic, item in self.topic_cache.items():
|
||||
if item.no_update_checks >= 3:
|
||||
logger.info(f"{self.log_prefix} 话题[{topic}] 连续 3 次检查无新增内容,触发打包存储")
|
||||
topics_to_finalize.append(topic)
|
||||
continue
|
||||
if len(item.messages) > 5:
|
||||
logger.info(f"{self.log_prefix} 话题[{topic}] 消息条数超过 4,触发打包存储")
|
||||
topics_to_finalize.append(topic)
|
||||
|
||||
for topic in topics_to_finalize:
|
||||
item = self.topic_cache.get(topic)
|
||||
if not item:
|
||||
continue
|
||||
try:
|
||||
await self._finalize_and_store_topic(
|
||||
topic=topic,
|
||||
item=item,
|
||||
# 这里的时间范围尽量覆盖最近一次检查的区间
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
finally:
|
||||
# 无论成功与否,都从缓存中删除,避免重复
|
||||
self.topic_cache.pop(topic, None)
|
||||
|
||||
def _build_numbered_messages_for_llm(
|
||||
self, messages: List[DatabaseMessages]
|
||||
) -> tuple[List[str], Dict[int, str], Dict[int, str], Dict[int, Set[str]]]:
|
||||
"""
|
||||
将消息转为带编号的字符串,供 LLM 选择使用。
|
||||
|
||||
返回:
|
||||
numbered_lines: ["1. xxx", "2. yyy", ...] # 带编号,用于 LLM 选择
|
||||
index_to_msg_str: idx -> "idx. xxx" # 带编号,用于 LLM 选择
|
||||
index_to_msg_text: idx -> "xxx" # 不带编号,用于最终存储
|
||||
index_to_participants: idx -> {nickname1, nickname2, ...}
|
||||
"""
|
||||
numbered_lines: List[str] = []
|
||||
index_to_msg_str: Dict[int, str] = {}
|
||||
index_to_msg_text: Dict[int, str] = {} # 不带编号的消息文本
|
||||
index_to_participants: Dict[int, Set[str]] = {}
|
||||
|
||||
for idx, msg in enumerate(messages, start=1):
|
||||
# 使用 build_readable_messages 生成可读文本
|
||||
try:
|
||||
text = build_readable_messages(
|
||||
messages=[msg],
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
read_mark=0.0,
|
||||
truncate=False,
|
||||
show_actions=False,
|
||||
).strip()
|
||||
except Exception:
|
||||
# 回退到简单文本
|
||||
text = getattr(msg, "processed_plain_text", "") or ""
|
||||
|
||||
# 获取发言人昵称
|
||||
participants: Set[str] = set()
|
||||
try:
|
||||
platform = (
|
||||
getattr(msg, "user_platform", None)
|
||||
or (msg.user_info.platform if msg.user_info else None)
|
||||
or msg.chat_info.platform
|
||||
)
|
||||
user_id = msg.user_info.user_id if msg.user_info else None
|
||||
if platform and user_id:
|
||||
person = Person(platform=platform, user_id=user_id)
|
||||
if person.person_name:
|
||||
participants.add(person.person_name)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 带编号的字符串(用于 LLM 选择)
|
||||
line = f"{idx}. {text}"
|
||||
numbered_lines.append(line)
|
||||
index_to_msg_str[idx] = line
|
||||
# 不带编号的文本(用于最终存储)
|
||||
index_to_msg_text[idx] = text
|
||||
index_to_participants[idx] = participants
|
||||
|
||||
return numbered_lines, index_to_msg_str, index_to_msg_text, index_to_participants
|
||||
|
||||
async def _analyze_topics_with_llm(
|
||||
self,
|
||||
numbered_lines: List[str],
|
||||
existing_topics: List[str],
|
||||
) -> tuple[bool, Dict[str, List[int]]]:
|
||||
"""
|
||||
使用 LLM 识别本次检查中的话题,并为每个话题选择相关消息编号。
|
||||
|
||||
要求:
|
||||
- 话题用一句话清晰描述正在发生的事件,包括时间、人物、主要事件和主题;
|
||||
- 可以有 1 个或多个话题;
|
||||
- 若某个话题与历史话题列表中的某个话题是同一件事,请直接使用历史话题的字符串;
|
||||
- 输出 JSON,格式:
|
||||
[
|
||||
{
|
||||
"topic": "话题标题字符串",
|
||||
"message_indices": [1, 2, 5]
|
||||
},
|
||||
...
|
||||
]
|
||||
"""
|
||||
if not numbered_lines:
|
||||
return False, {}
|
||||
|
||||
history_topics_block = (
|
||||
"\n".join(f"- {t}" for t in existing_topics) if existing_topics else "(当前无历史话题)"
|
||||
)
|
||||
messages_block = "\n".join(numbered_lines)
|
||||
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
"hippo_topic_analysis_prompt",
|
||||
history_topics_block=history_topics_block,
|
||||
messages_block=messages_block,
|
||||
)
|
||||
|
||||
try:
|
||||
response, _ = await self.summarizer_llm.generate_response_async(
|
||||
prompt=prompt,
|
||||
temperature=0.2,
|
||||
max_tokens=800,
|
||||
)
|
||||
|
||||
logger.info(f"{self.log_prefix} 话题识别LLM Prompt: {prompt}")
|
||||
logger.info(f"{self.log_prefix} 话题识别LLM Response: {response}")
|
||||
|
||||
# 尝试从响应中提取JSON代码块
|
||||
json_str = None
|
||||
json_pattern = r"```json\s*(.*?)\s*```"
|
||||
matches = re.findall(json_pattern, response, re.DOTALL)
|
||||
|
||||
if matches:
|
||||
# 找到JSON代码块,使用第一个匹配
|
||||
json_str = matches[0].strip()
|
||||
else:
|
||||
# 如果没有找到代码块,尝试查找JSON数组的开始和结束位置
|
||||
# 查找第一个 [ 和最后一个 ]
|
||||
start_idx = response.find('[')
|
||||
end_idx = response.rfind(']')
|
||||
if start_idx != -1 and end_idx != -1 and end_idx > start_idx:
|
||||
json_str = response[start_idx:end_idx + 1].strip()
|
||||
else:
|
||||
# 如果还是找不到,尝试直接使用整个响应(移除可能的markdown标记)
|
||||
json_str = response.strip()
|
||||
json_str = re.sub(r"^```json\s*", "", json_str, flags=re.MULTILINE)
|
||||
json_str = re.sub(r"^```\s*", "", json_str, flags=re.MULTILINE)
|
||||
json_str = json_str.strip()
|
||||
|
||||
# 使用json_repair修复可能的JSON错误
|
||||
if json_str:
|
||||
try:
|
||||
repaired_json = repair_json(json_str)
|
||||
result = json.loads(repaired_json) if isinstance(repaired_json, str) else repaired_json
|
||||
except Exception as repair_error:
|
||||
# 如果repair失败,尝试直接解析
|
||||
logger.warning(f"{self.log_prefix} JSON修复失败,尝试直接解析: {repair_error}")
|
||||
result = json.loads(json_str)
|
||||
else:
|
||||
raise ValueError("无法从响应中提取JSON内容")
|
||||
|
||||
if not isinstance(result, list):
|
||||
logger.error(f"{self.log_prefix} 话题识别返回的 JSON 不是列表: {result}")
|
||||
return False, {}
|
||||
|
||||
topic_to_indices: Dict[str, List[int]] = {}
|
||||
for item in result:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
topic = item.get("topic")
|
||||
indices = item.get("message_indices") or item.get("messages") or []
|
||||
if not topic or not isinstance(topic, str):
|
||||
continue
|
||||
if isinstance(indices, list):
|
||||
valid_indices: List[int] = []
|
||||
for v in indices:
|
||||
try:
|
||||
iv = int(v)
|
||||
if iv > 0:
|
||||
valid_indices.append(iv)
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
if valid_indices:
|
||||
topic_to_indices[topic] = valid_indices
|
||||
|
||||
return True, topic_to_indices
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 话题识别 LLM 调用或解析失败: {e}")
|
||||
logger.error(f"{self.log_prefix} LLM响应: {response if 'response' in locals() else 'N/A'}")
|
||||
return False, {}
|
||||
|
||||
async def _finalize_and_store_topic(
|
||||
self,
|
||||
topic: str,
|
||||
item: TopicCacheItem,
|
||||
start_time: float,
|
||||
end_time: float,
|
||||
):
|
||||
"""
|
||||
对某个话题进行最终打包存储:
|
||||
1. 将 messages(list[str]) 拼接为 original_text;
|
||||
2. 使用 LLM 对 original_text 进行总结,得到 summary 和 keywords,theme 直接使用话题字符串;
|
||||
3. 写入数据库 ChatHistory;
|
||||
4. 完成后,调用方会从缓存中删除该话题。
|
||||
"""
|
||||
if not item.messages:
|
||||
logger.info(f"{self.log_prefix} 话题[{topic}] 无消息内容,跳过打包")
|
||||
return
|
||||
|
||||
original_text = "\n".join(item.messages)
|
||||
|
||||
logger.info(
|
||||
f"{self.log_prefix} 开始打包话题[{topic}] | 消息数: {len(item.messages)} | 时间范围: {start_time:.2f} - {end_time:.2f}"
|
||||
)
|
||||
|
||||
# 使用 LLM 进行总结(基于话题名)
|
||||
success, keywords, summary, key_point = await self._compress_with_llm(original_text, topic)
|
||||
if not success:
|
||||
logger.warning(f"{self.log_prefix} 话题[{topic}] LLM 概括失败,不写入数据库")
|
||||
return
|
||||
|
||||
participants = list(item.participants)
|
||||
|
||||
await self._store_to_database(
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
original_text=original_text,
|
||||
participants=participants,
|
||||
theme=topic, # 主题直接使用话题名
|
||||
keywords=keywords,
|
||||
summary=summary,
|
||||
key_point=key_point,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"{self.log_prefix} 话题[{topic}] 成功打包并存储 | 消息数: {len(item.messages)} | 参与者数: {len(participants)}"
|
||||
)
|
||||
|
||||
async def _compress_with_llm(self, original_text: str, topic: str) -> tuple[bool, List[str], str, List[str]]:
|
||||
"""
|
||||
使用LLM压缩聊天内容(用于单个话题的最终总结)
|
||||
|
||||
Args:
|
||||
original_text: 聊天记录原文
|
||||
topic: 话题名称
|
||||
|
||||
Returns:
|
||||
tuple[bool, List[str], str, List[str]]: (是否成功, 关键词列表, 概括, 关键信息列表)
|
||||
"""
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
"hippo_topic_summary_prompt",
|
||||
topic=topic,
|
||||
original_text=original_text,
|
||||
)
|
||||
|
||||
try:
|
||||
response, _ = await self.summarizer_llm.generate_response_async(
|
||||
prompt=prompt,
|
||||
temperature=0.3,
|
||||
max_tokens=500,
|
||||
)
|
||||
|
||||
# 解析JSON响应
|
||||
json_str = response.strip()
|
||||
json_str = re.sub(r"^```json\s*", "", json_str, flags=re.MULTILINE)
|
||||
json_str = re.sub(r"^```\s*", "", json_str, flags=re.MULTILINE)
|
||||
json_str = json_str.strip()
|
||||
|
||||
# 查找JSON对象的开始与结束
|
||||
start_idx = json_str.find("{")
|
||||
if start_idx == -1:
|
||||
raise ValueError("未找到JSON对象开始标记")
|
||||
|
||||
end_idx = json_str.rfind("}")
|
||||
if end_idx == -1 or end_idx <= start_idx:
|
||||
logger.warning(f"{self.log_prefix} JSON缺少结束标记,尝试自动修复")
|
||||
extracted_json = json_str[start_idx:]
|
||||
else:
|
||||
extracted_json = json_str[start_idx : end_idx + 1]
|
||||
|
||||
def _parse_with_quote_fix(payload: str) -> Dict[str, Any]:
|
||||
fixed_chars: List[str] = []
|
||||
in_string = False
|
||||
escape_next = False
|
||||
i = 0
|
||||
while i < len(payload):
|
||||
char = payload[i]
|
||||
if escape_next:
|
||||
fixed_chars.append(char)
|
||||
escape_next = False
|
||||
elif char == "\\":
|
||||
fixed_chars.append(char)
|
||||
escape_next = True
|
||||
elif char == '"' and not escape_next:
|
||||
fixed_chars.append(char)
|
||||
in_string = not in_string
|
||||
elif in_string and char in {"“", "”"}:
|
||||
# 在字符串值内部,将中文引号替换为转义的英文引号
|
||||
fixed_chars.append('\\"')
|
||||
else:
|
||||
fixed_chars.append(char)
|
||||
i += 1
|
||||
|
||||
repaired = "".join(fixed_chars)
|
||||
return json.loads(repaired)
|
||||
|
||||
try:
|
||||
result = json.loads(extracted_json)
|
||||
except json.JSONDecodeError:
|
||||
try:
|
||||
repaired_json = repair_json(extracted_json)
|
||||
if isinstance(repaired_json, str):
|
||||
result = json.loads(repaired_json)
|
||||
else:
|
||||
result = repaired_json
|
||||
except Exception as repair_error:
|
||||
logger.warning(f"{self.log_prefix} repair_json 失败,使用引号修复: {repair_error}")
|
||||
result = _parse_with_quote_fix(extracted_json)
|
||||
|
||||
keywords = result.get("keywords", [])
|
||||
summary = result.get("summary", "无概括")
|
||||
key_point = result.get("key_point", [])
|
||||
|
||||
# 确保keywords和key_point是列表
|
||||
if isinstance(keywords, str):
|
||||
keywords = [keywords]
|
||||
if isinstance(key_point, str):
|
||||
key_point = [key_point]
|
||||
|
||||
return True, keywords, summary, key_point
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} LLM压缩聊天内容时出错: {e}")
|
||||
logger.error(f"{self.log_prefix} LLM响应: {response if 'response' in locals() else 'N/A'}")
|
||||
# 返回失败标志和默认值
|
||||
return False, [], "压缩失败,无法生成概括", []
|
||||
|
||||
async def _store_to_database(
|
||||
self,
|
||||
start_time: float,
|
||||
end_time: float,
|
||||
original_text: str,
|
||||
participants: List[str],
|
||||
theme: str,
|
||||
keywords: List[str],
|
||||
summary: str,
|
||||
key_point: Optional[List[str]] = None,
|
||||
):
|
||||
"""存储到数据库"""
|
||||
try:
|
||||
from src.common.database.database_model import ChatHistory
|
||||
from src.plugin_system.apis import database_api
|
||||
|
||||
# 准备数据
|
||||
data = {
|
||||
"chat_id": self.chat_id,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"original_text": original_text,
|
||||
"participants": json.dumps(participants, ensure_ascii=False),
|
||||
"theme": theme,
|
||||
"keywords": json.dumps(keywords, ensure_ascii=False),
|
||||
"summary": summary,
|
||||
"count": 0,
|
||||
}
|
||||
|
||||
# 存储 key_point(如果存在)
|
||||
if key_point is not None:
|
||||
data["key_point"] = json.dumps(key_point, ensure_ascii=False)
|
||||
|
||||
# 使用db_save存储(使用start_time和chat_id作为唯一标识)
|
||||
# 由于可能有多条记录,我们使用组合键,但peewee不支持,所以使用start_time作为唯一标识
|
||||
# 但为了避免冲突,我们使用组合键:chat_id + start_time
|
||||
# 由于peewee不支持组合键,我们直接创建新记录(不提供key_field和key_value)
|
||||
saved_record = await database_api.db_save(
|
||||
ChatHistory,
|
||||
data=data,
|
||||
)
|
||||
|
||||
if saved_record:
|
||||
logger.debug(f"{self.log_prefix} 成功存储聊天历史记录到数据库")
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix} 存储聊天历史记录到数据库失败")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 存储到数据库时出错: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
async def start(self):
|
||||
"""启动后台定期检查循环"""
|
||||
if self._running:
|
||||
logger.warning(f"{self.log_prefix} 后台循环已在运行,无需重复启动")
|
||||
return
|
||||
|
||||
# 加载聊天批次(如果有)
|
||||
await self._load_batch_from_disk()
|
||||
|
||||
self._running = True
|
||||
self._periodic_task = asyncio.create_task(self._periodic_check_loop())
|
||||
logger.info(f"{self.log_prefix} 已启动后台定期检查循环 | 检查间隔: {self.check_interval}秒")
|
||||
|
||||
async def stop(self):
|
||||
"""停止后台定期检查循环"""
|
||||
self._running = False
|
||||
if self._periodic_task:
|
||||
self._periodic_task.cancel()
|
||||
try:
|
||||
await self._periodic_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._periodic_task = None
|
||||
logger.info(f"{self.log_prefix} 已停止后台定期检查循环")
|
||||
|
||||
async def _periodic_check_loop(self):
|
||||
"""后台定期检查循环"""
|
||||
try:
|
||||
while self._running:
|
||||
# 执行一次检查
|
||||
await self.process()
|
||||
|
||||
# 等待指定间隔后再次检查
|
||||
await asyncio.sleep(self.check_interval)
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"{self.log_prefix} 后台检查循环被取消")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 后台检查循环出错: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
self._running = False
|
||||
|
||||
|
||||
init_prompt()
|
||||
|
||||
@@ -44,9 +44,7 @@ class JargonExplainer:
|
||||
request_type="jargon.explain",
|
||||
)
|
||||
|
||||
def match_jargon_from_messages(
|
||||
self, messages: List[Any]
|
||||
) -> List[Dict[str, str]]:
|
||||
def match_jargon_from_messages(self, messages: List[Any]) -> List[Dict[str, str]]:
|
||||
"""
|
||||
通过直接匹配数据库中的jargon字符串来提取黑话
|
||||
|
||||
@@ -57,7 +55,7 @@ class JargonExplainer:
|
||||
List[Dict[str, str]]: 提取到的黑话列表,每个元素包含content
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
if not messages:
|
||||
return []
|
||||
|
||||
@@ -67,8 +65,10 @@ class JargonExplainer:
|
||||
# 跳过机器人自己的消息
|
||||
if is_bot_message(msg):
|
||||
continue
|
||||
|
||||
msg_text = (getattr(msg, "display_message", None) or getattr(msg, "processed_plain_text", None) or "").strip()
|
||||
|
||||
msg_text = (
|
||||
getattr(msg, "display_message", None) or getattr(msg, "processed_plain_text", None) or ""
|
||||
).strip()
|
||||
if msg_text:
|
||||
message_texts.append(msg_text)
|
||||
|
||||
@@ -79,9 +79,7 @@ class JargonExplainer:
|
||||
combined_text = " ".join(message_texts)
|
||||
|
||||
# 查询所有有meaning的jargon记录
|
||||
query = Jargon.select().where(
|
||||
(Jargon.meaning.is_null(False)) & (Jargon.meaning != "")
|
||||
)
|
||||
query = Jargon.select().where((Jargon.meaning.is_null(False)) & (Jargon.meaning != ""))
|
||||
|
||||
# 根据all_global配置决定查询逻辑
|
||||
if global_config.jargon.all_global:
|
||||
@@ -98,7 +96,7 @@ class JargonExplainer:
|
||||
# 执行查询并匹配
|
||||
matched_jargon: Dict[str, Dict[str, str]] = {}
|
||||
query_time = time.time()
|
||||
|
||||
|
||||
for jargon in query:
|
||||
content = jargon.content or ""
|
||||
if not content or not content.strip():
|
||||
@@ -123,13 +121,13 @@ class JargonExplainer:
|
||||
pattern = re.escape(content)
|
||||
# 使用单词边界或中文字符边界来匹配,避免部分匹配
|
||||
# 对于中文,使用Unicode字符类;对于英文,使用单词边界
|
||||
if re.search(r'[\u4e00-\u9fff]', content):
|
||||
if re.search(r"[\u4e00-\u9fff]", content):
|
||||
# 包含中文,使用更宽松的匹配
|
||||
search_pattern = pattern
|
||||
else:
|
||||
# 纯英文/数字,使用单词边界
|
||||
search_pattern = r'\b' + pattern + r'\b'
|
||||
|
||||
search_pattern = r"\b" + pattern + r"\b"
|
||||
|
||||
if re.search(search_pattern, combined_text, re.IGNORECASE):
|
||||
# 找到匹配,记录(去重)
|
||||
if content not in matched_jargon:
|
||||
@@ -139,17 +137,15 @@ class JargonExplainer:
|
||||
total_time = match_time - start_time
|
||||
query_duration = query_time - start_time
|
||||
match_duration = match_time - query_time
|
||||
|
||||
logger.info(
|
||||
|
||||
logger.debug(
|
||||
f"黑话匹配完成: 查询耗时 {query_duration:.3f}s, 匹配耗时 {match_duration:.3f}s, "
|
||||
f"总耗时 {total_time:.3f}s, 匹配到 {len(matched_jargon)} 个黑话"
|
||||
)
|
||||
|
||||
return list(matched_jargon.values())
|
||||
|
||||
async def explain_jargon(
|
||||
self, messages: List[Any], chat_context: str
|
||||
) -> Optional[str]:
|
||||
async def explain_jargon(self, messages: List[Any], chat_context: str) -> Optional[str]:
|
||||
"""
|
||||
解释上下文中的黑话
|
||||
|
||||
@@ -183,7 +179,7 @@ class JargonExplainer:
|
||||
jargon_explanations: List[str] = []
|
||||
for entry in jargon_list:
|
||||
content = entry["content"]
|
||||
|
||||
|
||||
# 根据是否开启全局黑话,决定查询方式
|
||||
if global_config.jargon.all_global:
|
||||
# 开启全局黑话:查询所有is_global=True的记录
|
||||
@@ -239,9 +235,7 @@ class JargonExplainer:
|
||||
return summary
|
||||
|
||||
|
||||
async def explain_jargon_in_context(
|
||||
chat_id: str, messages: List[Any], chat_context: str
|
||||
) -> Optional[str]:
|
||||
async def explain_jargon_in_context(chat_id: str, messages: List[Any], chat_context: str) -> Optional[str]:
|
||||
"""
|
||||
解释上下文中的黑话(便捷函数)
|
||||
|
||||
@@ -256,3 +250,111 @@ async def explain_jargon_in_context(
|
||||
explainer = JargonExplainer(chat_id)
|
||||
return await explainer.explain_jargon(messages, chat_context)
|
||||
|
||||
|
||||
def match_jargon_from_text(chat_text: str, chat_id: str) -> List[str]:
|
||||
"""直接在聊天文本中匹配已知的jargon,返回出现过的黑话列表
|
||||
|
||||
Args:
|
||||
chat_text: 要匹配的聊天文本
|
||||
chat_id: 聊天ID
|
||||
|
||||
Returns:
|
||||
List[str]: 匹配到的黑话列表
|
||||
"""
|
||||
if not chat_text or not chat_text.strip():
|
||||
return []
|
||||
|
||||
query = Jargon.select().where((Jargon.meaning.is_null(False)) & (Jargon.meaning != ""))
|
||||
if global_config.jargon.all_global:
|
||||
query = query.where(Jargon.is_global)
|
||||
|
||||
query = query.order_by(Jargon.count.desc())
|
||||
|
||||
matched: Dict[str, None] = {}
|
||||
|
||||
for jargon in query:
|
||||
content = (jargon.content or "").strip()
|
||||
if not content:
|
||||
continue
|
||||
|
||||
if not global_config.jargon.all_global and not jargon.is_global:
|
||||
chat_id_list = parse_chat_id_list(jargon.chat_id)
|
||||
if not chat_id_list_contains(chat_id_list, chat_id):
|
||||
continue
|
||||
|
||||
pattern = re.escape(content)
|
||||
if re.search(r"[\u4e00-\u9fff]", content):
|
||||
search_pattern = pattern
|
||||
else:
|
||||
search_pattern = r"\b" + pattern + r"\b"
|
||||
|
||||
if re.search(search_pattern, chat_text, re.IGNORECASE):
|
||||
matched[content] = None
|
||||
|
||||
logger.info(f"匹配到 {len(matched)} 个黑话")
|
||||
|
||||
return list(matched.keys())
|
||||
|
||||
|
||||
async def retrieve_concepts_with_jargon(concepts: List[str], chat_id: str) -> str:
|
||||
"""对概念列表进行jargon检索
|
||||
|
||||
Args:
|
||||
concepts: 概念列表
|
||||
chat_id: 聊天ID
|
||||
|
||||
Returns:
|
||||
str: 检索结果字符串
|
||||
"""
|
||||
if not concepts:
|
||||
return ""
|
||||
|
||||
results = []
|
||||
exact_matches = [] # 收集所有精确匹配的概念
|
||||
for concept in concepts:
|
||||
concept = concept.strip()
|
||||
if not concept:
|
||||
continue
|
||||
|
||||
# 先尝试精确匹配
|
||||
jargon_results = search_jargon(keyword=concept, chat_id=chat_id, limit=10, case_sensitive=False, fuzzy=False)
|
||||
|
||||
is_fuzzy_match = False
|
||||
|
||||
# 如果精确匹配未找到,尝试模糊搜索
|
||||
if not jargon_results:
|
||||
jargon_results = search_jargon(keyword=concept, chat_id=chat_id, limit=10, case_sensitive=False, fuzzy=True)
|
||||
is_fuzzy_match = True
|
||||
|
||||
if jargon_results:
|
||||
# 找到结果
|
||||
if is_fuzzy_match:
|
||||
# 模糊匹配
|
||||
output_parts = [f"未精确匹配到'{concept}'"]
|
||||
for result in jargon_results:
|
||||
found_content = result.get("content", "").strip()
|
||||
meaning = result.get("meaning", "").strip()
|
||||
if found_content and meaning:
|
||||
output_parts.append(f"找到 '{found_content}' 的含义为:{meaning}")
|
||||
results.append(",".join(output_parts))
|
||||
logger.info(f"在jargon库中找到匹配(模糊搜索): {concept},找到{len(jargon_results)}条结果")
|
||||
else:
|
||||
# 精确匹配
|
||||
output_parts = []
|
||||
for result in jargon_results:
|
||||
meaning = result.get("meaning", "").strip()
|
||||
if meaning:
|
||||
output_parts.append(f"'{concept}' 为黑话或者网络简写,含义为:{meaning}")
|
||||
results.append(";".join(output_parts) if len(output_parts) > 1 else output_parts[0])
|
||||
exact_matches.append(concept) # 收集精确匹配的概念,稍后统一打印
|
||||
else:
|
||||
# 未找到,不返回占位信息,只记录日志
|
||||
logger.info(f"在jargon库中未找到匹配: {concept}")
|
||||
|
||||
# 合并所有精确匹配的日志
|
||||
if exact_matches:
|
||||
logger.info(f"找到黑话: {', '.join(exact_matches)},共找到{len(exact_matches)}条结果")
|
||||
|
||||
if results:
|
||||
return "【概念检索结果】\n" + "\n".join(results) + "\n"
|
||||
return ""
|
||||
@@ -17,20 +17,18 @@ from src.chat.utils.chat_message_builder import (
|
||||
)
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.jargon.jargon_utils import (
|
||||
is_bot_message,
|
||||
build_context_paragraph,
|
||||
contains_bot_self_name,
|
||||
parse_chat_id_list,
|
||||
is_bot_message,
|
||||
build_context_paragraph,
|
||||
contains_bot_self_name,
|
||||
parse_chat_id_list,
|
||||
chat_id_list_contains,
|
||||
update_chat_id_list
|
||||
update_chat_id_list,
|
||||
)
|
||||
|
||||
|
||||
logger = get_logger("jargon")
|
||||
|
||||
|
||||
|
||||
|
||||
def _init_prompt() -> None:
|
||||
prompt_str = """
|
||||
**聊天内容,其中的{bot_name}的发言内容是你自己的发言,[msg_id] 是消息ID**
|
||||
@@ -126,7 +124,6 @@ _init_prompt()
|
||||
_init_inference_prompts()
|
||||
|
||||
|
||||
|
||||
def _should_infer_meaning(jargon_obj: Jargon) -> bool:
|
||||
"""
|
||||
判断是否需要进行含义推断
|
||||
@@ -185,6 +182,9 @@ class JargonMiner:
|
||||
self.stream_name = stream_name if stream_name else self.chat_id
|
||||
self.cache_limit = 100
|
||||
self.cache: OrderedDict[str, None] = OrderedDict()
|
||||
|
||||
# 黑话提取锁,防止并发执行
|
||||
self._extraction_lock = asyncio.Lock()
|
||||
|
||||
def _add_to_cache(self, content: str) -> None:
|
||||
"""将提取到的黑话加入缓存,保持LRU语义"""
|
||||
@@ -211,7 +211,9 @@ class JargonMiner:
|
||||
processed_pairs = set()
|
||||
|
||||
for idx, msg in enumerate(messages):
|
||||
msg_text = (getattr(msg, "display_message", None) or getattr(msg, "processed_plain_text", None) or "").strip()
|
||||
msg_text = (
|
||||
getattr(msg, "display_message", None) or getattr(msg, "processed_plain_text", None) or ""
|
||||
).strip()
|
||||
if not msg_text or is_bot_message(msg):
|
||||
continue
|
||||
|
||||
@@ -270,7 +272,7 @@ class JargonMiner:
|
||||
prompt1 = await global_prompt_manager.format_prompt(
|
||||
"jargon_inference_with_context_prompt",
|
||||
content=content,
|
||||
bot_name = global_config.bot.nickname,
|
||||
bot_name=global_config.bot.nickname,
|
||||
raw_content_list=raw_content_text,
|
||||
)
|
||||
|
||||
@@ -437,262 +439,265 @@ class JargonMiner:
|
||||
return bool(recent_messages and len(recent_messages) >= self.min_messages_for_learning)
|
||||
|
||||
async def run_once(self) -> None:
|
||||
try:
|
||||
if not self.should_trigger():
|
||||
return
|
||||
|
||||
chat_stream = get_chat_manager().get_stream(self.chat_id)
|
||||
if not chat_stream:
|
||||
return
|
||||
|
||||
# 记录本次提取的时间窗口,避免重复提取
|
||||
extraction_start_time = self.last_learning_time
|
||||
extraction_end_time = time.time()
|
||||
|
||||
# 拉取学习窗口内的消息
|
||||
messages = get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=extraction_start_time,
|
||||
timestamp_end=extraction_end_time,
|
||||
limit=20,
|
||||
)
|
||||
if not messages:
|
||||
return
|
||||
|
||||
# 按时间排序,确保编号与上下文一致
|
||||
messages = sorted(messages, key=lambda msg: msg.time or 0)
|
||||
|
||||
chat_str, message_id_list = build_readable_messages_with_id(
|
||||
messages=messages,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="relative",
|
||||
truncate=False,
|
||||
show_actions=False,
|
||||
show_pic=True,
|
||||
pic_single=True,
|
||||
)
|
||||
if not chat_str.strip():
|
||||
return
|
||||
|
||||
msg_id_to_index: Dict[str, int] = {}
|
||||
for idx, (msg_id, _msg) in enumerate(message_id_list or []):
|
||||
if not msg_id:
|
||||
continue
|
||||
msg_id_to_index[msg_id] = idx
|
||||
if not msg_id_to_index:
|
||||
logger.warning("未能生成消息ID映射,跳过本次提取")
|
||||
return
|
||||
|
||||
prompt: str = await global_prompt_manager.format_prompt(
|
||||
"extract_jargon_prompt",
|
||||
bot_name=global_config.bot.nickname,
|
||||
chat_str=chat_str,
|
||||
)
|
||||
|
||||
response, _ = await self.llm.generate_response_async(prompt, temperature=0.2)
|
||||
if not response:
|
||||
return
|
||||
|
||||
if global_config.debug.show_jargon_prompt:
|
||||
logger.info(f"jargon提取提示词: {prompt}")
|
||||
logger.info(f"jargon提取结果: {response}")
|
||||
|
||||
# 解析为JSON
|
||||
entries: List[dict] = []
|
||||
# 使用异步锁防止并发执行
|
||||
async with self._extraction_lock:
|
||||
try:
|
||||
resp = response.strip()
|
||||
parsed = None
|
||||
if resp.startswith("[") and resp.endswith("]"):
|
||||
parsed = json.loads(resp)
|
||||
else:
|
||||
repaired = repair_json(resp)
|
||||
if isinstance(repaired, str):
|
||||
parsed = json.loads(repaired)
|
||||
else:
|
||||
parsed = repaired
|
||||
|
||||
if isinstance(parsed, dict):
|
||||
parsed = [parsed]
|
||||
|
||||
if not isinstance(parsed, list):
|
||||
# 在锁内检查,避免并发触发
|
||||
if not self.should_trigger():
|
||||
return
|
||||
|
||||
for item in parsed:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
chat_stream = get_chat_manager().get_stream(self.chat_id)
|
||||
if not chat_stream:
|
||||
return
|
||||
|
||||
content = str(item.get("content", "")).strip()
|
||||
msg_id_value = item.get("msg_id")
|
||||
|
||||
if not content:
|
||||
continue
|
||||
|
||||
if contains_bot_self_name(content):
|
||||
logger.info(f"解析阶段跳过包含机器人昵称/别名的词条: {content}")
|
||||
continue
|
||||
|
||||
msg_id_str = str(msg_id_value or "").strip()
|
||||
if not msg_id_str:
|
||||
logger.warning(f"解析jargon失败:msg_id缺失,content={content}")
|
||||
continue
|
||||
|
||||
msg_index = msg_id_to_index.get(msg_id_str)
|
||||
if msg_index is None:
|
||||
logger.warning(f"解析jargon失败:msg_id未找到,content={content}, msg_id={msg_id_str}")
|
||||
continue
|
||||
|
||||
target_msg = messages[msg_index]
|
||||
if is_bot_message(target_msg):
|
||||
logger.info(f"解析阶段跳过引用机器人自身消息的词条: content={content}, msg_id={msg_id_str}")
|
||||
continue
|
||||
|
||||
context_paragraph = build_context_paragraph(messages, msg_index)
|
||||
if not context_paragraph:
|
||||
logger.warning(f"解析jargon失败:上下文为空,content={content}, msg_id={msg_id_str}")
|
||||
continue
|
||||
|
||||
entries.append({"content": content, "raw_content": [context_paragraph]})
|
||||
cached_entries = self._collect_cached_entries(messages)
|
||||
if cached_entries:
|
||||
entries.extend(cached_entries)
|
||||
except Exception as e:
|
||||
logger.error(f"解析jargon JSON失败: {e}; 原始: {response}")
|
||||
return
|
||||
|
||||
if not entries:
|
||||
return
|
||||
|
||||
# 去重并合并raw_content(按 content 聚合)
|
||||
merged_entries: OrderedDict[str, Dict[str, List[str]]] = OrderedDict()
|
||||
for entry in entries:
|
||||
content_key = entry["content"]
|
||||
raw_list = entry.get("raw_content", []) or []
|
||||
if content_key in merged_entries:
|
||||
merged_entries[content_key]["raw_content"].extend(raw_list)
|
||||
else:
|
||||
merged_entries[content_key] = {
|
||||
"content": content_key,
|
||||
"raw_content": list(raw_list),
|
||||
}
|
||||
|
||||
uniq_entries = []
|
||||
for merged_entry in merged_entries.values():
|
||||
raw_content_list = merged_entry["raw_content"]
|
||||
if raw_content_list:
|
||||
merged_entry["raw_content"] = list(dict.fromkeys(raw_content_list))
|
||||
uniq_entries.append(merged_entry)
|
||||
|
||||
saved = 0
|
||||
updated = 0
|
||||
for entry in uniq_entries:
|
||||
content = entry["content"]
|
||||
raw_content_list = entry["raw_content"] # 已经是列表
|
||||
|
||||
|
||||
try:
|
||||
# 查询所有content匹配的记录
|
||||
query = Jargon.select().where(Jargon.content == content)
|
||||
|
||||
# 查找匹配的记录
|
||||
matched_obj = None
|
||||
for obj in query:
|
||||
if global_config.jargon.all_global:
|
||||
# 开启all_global:所有content匹配的记录都可以
|
||||
matched_obj = obj
|
||||
break
|
||||
else:
|
||||
# 关闭all_global:需要检查chat_id列表是否包含目标chat_id
|
||||
chat_id_list = parse_chat_id_list(obj.chat_id)
|
||||
if chat_id_list_contains(chat_id_list, self.chat_id):
|
||||
matched_obj = obj
|
||||
break
|
||||
|
||||
if matched_obj:
|
||||
obj = matched_obj
|
||||
try:
|
||||
obj.count = (obj.count or 0) + 1
|
||||
except Exception:
|
||||
obj.count = 1
|
||||
|
||||
# 合并raw_content列表:读取现有列表,追加新值,去重
|
||||
existing_raw_content = []
|
||||
if obj.raw_content:
|
||||
try:
|
||||
existing_raw_content = (
|
||||
json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content
|
||||
)
|
||||
if not isinstance(existing_raw_content, list):
|
||||
existing_raw_content = [existing_raw_content] if existing_raw_content else []
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
existing_raw_content = [obj.raw_content] if obj.raw_content else []
|
||||
|
||||
# 合并并去重
|
||||
merged_list = list(dict.fromkeys(existing_raw_content + raw_content_list))
|
||||
obj.raw_content = json.dumps(merged_list, ensure_ascii=False)
|
||||
|
||||
# 更新chat_id列表:增加当前chat_id的计数
|
||||
chat_id_list = parse_chat_id_list(obj.chat_id)
|
||||
updated_chat_id_list = update_chat_id_list(chat_id_list, self.chat_id, increment=1)
|
||||
obj.chat_id = json.dumps(updated_chat_id_list, ensure_ascii=False)
|
||||
|
||||
# 开启all_global时,确保记录标记为is_global=True
|
||||
if global_config.jargon.all_global:
|
||||
obj.is_global = True
|
||||
# 关闭all_global时,保持原有is_global不变(不修改)
|
||||
|
||||
obj.save()
|
||||
|
||||
# 检查是否需要推断(达到阈值且超过上次判定值)
|
||||
if _should_infer_meaning(obj):
|
||||
# 异步触发推断,不阻塞主流程
|
||||
# 重新加载对象以确保数据最新
|
||||
jargon_id = obj.id
|
||||
asyncio.create_task(self._infer_meaning_by_id(jargon_id))
|
||||
|
||||
updated += 1
|
||||
else:
|
||||
# 没找到匹配记录,创建新记录
|
||||
if global_config.jargon.all_global:
|
||||
# 开启all_global:新记录默认为is_global=True
|
||||
is_global_new = True
|
||||
else:
|
||||
# 关闭all_global:新记录is_global=False
|
||||
is_global_new = False
|
||||
|
||||
# 使用新格式创建chat_id列表:[[chat_id, count]]
|
||||
chat_id_list = [[self.chat_id, 1]]
|
||||
chat_id_json = json.dumps(chat_id_list, ensure_ascii=False)
|
||||
|
||||
Jargon.create(
|
||||
content=content,
|
||||
raw_content=json.dumps(raw_content_list, ensure_ascii=False),
|
||||
chat_id=chat_id_json,
|
||||
is_global=is_global_new,
|
||||
count=1,
|
||||
)
|
||||
saved += 1
|
||||
except Exception as e:
|
||||
logger.error(f"保存jargon失败: chat_id={self.chat_id}, content={content}, err={e}")
|
||||
continue
|
||||
finally:
|
||||
self._add_to_cache(content)
|
||||
|
||||
# 固定输出提取的jargon结果,格式化为可读形式(只要有提取结果就输出)
|
||||
if uniq_entries:
|
||||
# 收集所有提取的jargon内容
|
||||
jargon_list = [entry["content"] for entry in uniq_entries]
|
||||
jargon_str = ",".join(jargon_list)
|
||||
|
||||
# 输出格式化的结果(使用logger.info会自动应用jargon模块的颜色)
|
||||
logger.info(f"[{self.stream_name}]疑似黑话: {jargon_str}")
|
||||
|
||||
# 更新为本次提取的结束时间,确保不会重复提取相同的消息窗口
|
||||
# 记录本次提取的时间窗口,避免重复提取
|
||||
extraction_start_time = self.last_learning_time
|
||||
extraction_end_time = time.time()
|
||||
|
||||
# 立即更新学习时间,防止并发触发
|
||||
self.last_learning_time = extraction_end_time
|
||||
|
||||
if saved or updated:
|
||||
logger.info(f"jargon写入: 新增 {saved} 条,更新 {updated} 条,chat_id={self.chat_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"JargonMiner 运行失败: {e}")
|
||||
# 拉取学习窗口内的消息
|
||||
messages = get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=extraction_start_time,
|
||||
timestamp_end=extraction_end_time,
|
||||
limit=20,
|
||||
)
|
||||
if not messages:
|
||||
return
|
||||
|
||||
# 按时间排序,确保编号与上下文一致
|
||||
messages = sorted(messages, key=lambda msg: msg.time or 0)
|
||||
|
||||
chat_str, message_id_list = build_readable_messages_with_id(
|
||||
messages=messages,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="relative",
|
||||
truncate=False,
|
||||
show_actions=False,
|
||||
show_pic=True,
|
||||
pic_single=True,
|
||||
)
|
||||
if not chat_str.strip():
|
||||
return
|
||||
|
||||
msg_id_to_index: Dict[str, int] = {}
|
||||
for idx, (msg_id, _msg) in enumerate(message_id_list or []):
|
||||
if not msg_id:
|
||||
continue
|
||||
msg_id_to_index[msg_id] = idx
|
||||
if not msg_id_to_index:
|
||||
logger.warning("未能生成消息ID映射,跳过本次提取")
|
||||
return
|
||||
|
||||
prompt: str = await global_prompt_manager.format_prompt(
|
||||
"extract_jargon_prompt",
|
||||
bot_name=global_config.bot.nickname,
|
||||
chat_str=chat_str,
|
||||
)
|
||||
|
||||
response, _ = await self.llm.generate_response_async(prompt, temperature=0.2)
|
||||
if not response:
|
||||
return
|
||||
|
||||
if global_config.debug.show_jargon_prompt:
|
||||
logger.info(f"jargon提取提示词: {prompt}")
|
||||
logger.info(f"jargon提取结果: {response}")
|
||||
|
||||
# 解析为JSON
|
||||
entries: List[dict] = []
|
||||
try:
|
||||
resp = response.strip()
|
||||
parsed = None
|
||||
if resp.startswith("[") and resp.endswith("]"):
|
||||
parsed = json.loads(resp)
|
||||
else:
|
||||
repaired = repair_json(resp)
|
||||
if isinstance(repaired, str):
|
||||
parsed = json.loads(repaired)
|
||||
else:
|
||||
parsed = repaired
|
||||
|
||||
if isinstance(parsed, dict):
|
||||
parsed = [parsed]
|
||||
|
||||
if not isinstance(parsed, list):
|
||||
return
|
||||
|
||||
for item in parsed:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
|
||||
content = str(item.get("content", "")).strip()
|
||||
msg_id_value = item.get("msg_id")
|
||||
|
||||
if not content:
|
||||
continue
|
||||
|
||||
if contains_bot_self_name(content):
|
||||
logger.info(f"解析阶段跳过包含机器人昵称/别名的词条: {content}")
|
||||
continue
|
||||
|
||||
msg_id_str = str(msg_id_value or "").strip()
|
||||
if not msg_id_str:
|
||||
logger.warning(f"解析jargon失败:msg_id缺失,content={content}")
|
||||
continue
|
||||
|
||||
msg_index = msg_id_to_index.get(msg_id_str)
|
||||
if msg_index is None:
|
||||
logger.warning(f"解析jargon失败:msg_id未找到,content={content}, msg_id={msg_id_str}")
|
||||
continue
|
||||
|
||||
target_msg = messages[msg_index]
|
||||
if is_bot_message(target_msg):
|
||||
logger.info(f"解析阶段跳过引用机器人自身消息的词条: content={content}, msg_id={msg_id_str}")
|
||||
continue
|
||||
|
||||
context_paragraph = build_context_paragraph(messages, msg_index)
|
||||
if not context_paragraph:
|
||||
logger.warning(f"解析jargon失败:上下文为空,content={content}, msg_id={msg_id_str}")
|
||||
continue
|
||||
|
||||
entries.append({"content": content, "raw_content": [context_paragraph]})
|
||||
cached_entries = self._collect_cached_entries(messages)
|
||||
if cached_entries:
|
||||
entries.extend(cached_entries)
|
||||
except Exception as e:
|
||||
logger.error(f"解析jargon JSON失败: {e}; 原始: {response}")
|
||||
return
|
||||
|
||||
if not entries:
|
||||
return
|
||||
|
||||
# 去重并合并raw_content(按 content 聚合)
|
||||
merged_entries: OrderedDict[str, Dict[str, List[str]]] = OrderedDict()
|
||||
for entry in entries:
|
||||
content_key = entry["content"]
|
||||
raw_list = entry.get("raw_content", []) or []
|
||||
if content_key in merged_entries:
|
||||
merged_entries[content_key]["raw_content"].extend(raw_list)
|
||||
else:
|
||||
merged_entries[content_key] = {
|
||||
"content": content_key,
|
||||
"raw_content": list(raw_list),
|
||||
}
|
||||
|
||||
uniq_entries = []
|
||||
for merged_entry in merged_entries.values():
|
||||
raw_content_list = merged_entry["raw_content"]
|
||||
if raw_content_list:
|
||||
merged_entry["raw_content"] = list(dict.fromkeys(raw_content_list))
|
||||
uniq_entries.append(merged_entry)
|
||||
|
||||
saved = 0
|
||||
updated = 0
|
||||
for entry in uniq_entries:
|
||||
content = entry["content"]
|
||||
raw_content_list = entry["raw_content"] # 已经是列表
|
||||
|
||||
try:
|
||||
# 查询所有content匹配的记录
|
||||
query = Jargon.select().where(Jargon.content == content)
|
||||
|
||||
# 查找匹配的记录
|
||||
matched_obj = None
|
||||
for obj in query:
|
||||
if global_config.jargon.all_global:
|
||||
# 开启all_global:所有content匹配的记录都可以
|
||||
matched_obj = obj
|
||||
break
|
||||
else:
|
||||
# 关闭all_global:需要检查chat_id列表是否包含目标chat_id
|
||||
chat_id_list = parse_chat_id_list(obj.chat_id)
|
||||
if chat_id_list_contains(chat_id_list, self.chat_id):
|
||||
matched_obj = obj
|
||||
break
|
||||
|
||||
if matched_obj:
|
||||
obj = matched_obj
|
||||
try:
|
||||
obj.count = (obj.count or 0) + 1
|
||||
except Exception:
|
||||
obj.count = 1
|
||||
|
||||
# 合并raw_content列表:读取现有列表,追加新值,去重
|
||||
existing_raw_content = []
|
||||
if obj.raw_content:
|
||||
try:
|
||||
existing_raw_content = (
|
||||
json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content
|
||||
)
|
||||
if not isinstance(existing_raw_content, list):
|
||||
existing_raw_content = [existing_raw_content] if existing_raw_content else []
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
existing_raw_content = [obj.raw_content] if obj.raw_content else []
|
||||
|
||||
# 合并并去重
|
||||
merged_list = list(dict.fromkeys(existing_raw_content + raw_content_list))
|
||||
obj.raw_content = json.dumps(merged_list, ensure_ascii=False)
|
||||
|
||||
# 更新chat_id列表:增加当前chat_id的计数
|
||||
chat_id_list = parse_chat_id_list(obj.chat_id)
|
||||
updated_chat_id_list = update_chat_id_list(chat_id_list, self.chat_id, increment=1)
|
||||
obj.chat_id = json.dumps(updated_chat_id_list, ensure_ascii=False)
|
||||
|
||||
# 开启all_global时,确保记录标记为is_global=True
|
||||
if global_config.jargon.all_global:
|
||||
obj.is_global = True
|
||||
# 关闭all_global时,保持原有is_global不变(不修改)
|
||||
|
||||
obj.save()
|
||||
|
||||
# 检查是否需要推断(达到阈值且超过上次判定值)
|
||||
if _should_infer_meaning(obj):
|
||||
# 异步触发推断,不阻塞主流程
|
||||
# 重新加载对象以确保数据最新
|
||||
jargon_id = obj.id
|
||||
asyncio.create_task(self._infer_meaning_by_id(jargon_id))
|
||||
|
||||
updated += 1
|
||||
else:
|
||||
# 没找到匹配记录,创建新记录
|
||||
if global_config.jargon.all_global:
|
||||
# 开启all_global:新记录默认为is_global=True
|
||||
is_global_new = True
|
||||
else:
|
||||
# 关闭all_global:新记录is_global=False
|
||||
is_global_new = False
|
||||
|
||||
# 使用新格式创建chat_id列表:[[chat_id, count]]
|
||||
chat_id_list = [[self.chat_id, 1]]
|
||||
chat_id_json = json.dumps(chat_id_list, ensure_ascii=False)
|
||||
|
||||
Jargon.create(
|
||||
content=content,
|
||||
raw_content=json.dumps(raw_content_list, ensure_ascii=False),
|
||||
chat_id=chat_id_json,
|
||||
is_global=is_global_new,
|
||||
count=1,
|
||||
)
|
||||
saved += 1
|
||||
except Exception as e:
|
||||
logger.error(f"保存jargon失败: chat_id={self.chat_id}, content={content}, err={e}")
|
||||
continue
|
||||
finally:
|
||||
self._add_to_cache(content)
|
||||
|
||||
# 固定输出提取的jargon结果,格式化为可读形式(只要有提取结果就输出)
|
||||
if uniq_entries:
|
||||
# 收集所有提取的jargon内容
|
||||
jargon_list = [entry["content"] for entry in uniq_entries]
|
||||
jargon_str = ",".join(jargon_list)
|
||||
|
||||
# 输出格式化的结果(使用logger.info会自动应用jargon模块的颜色)
|
||||
logger.info(f"[{self.stream_name}]疑似黑话: {jargon_str}")
|
||||
|
||||
if saved or updated:
|
||||
logger.info(f"jargon写入: 新增 {saved} 条,更新 {updated} 条,chat_id={self.chat_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"JargonMiner 运行失败: {e}")
|
||||
# 即使失败也保持时间戳更新,避免频繁重试
|
||||
|
||||
|
||||
class JargonMinerManager:
|
||||
@@ -782,15 +787,15 @@ def search_jargon(
|
||||
# 如果记录是is_global=True,或者chat_id列表包含目标chat_id,则包含
|
||||
if not jargon.is_global and not chat_id_list_contains(chat_id_list, chat_id):
|
||||
continue
|
||||
|
||||
|
||||
# 只返回有meaning的记录
|
||||
if not jargon.meaning or jargon.meaning.strip() == "":
|
||||
continue
|
||||
|
||||
|
||||
results.append({"content": jargon.content or "", "meaning": jargon.meaning or ""})
|
||||
|
||||
|
||||
# 达到限制数量后停止
|
||||
if len(results) >= limit:
|
||||
break
|
||||
|
||||
return results
|
||||
return results
|
||||
|
||||
@@ -2,30 +2,29 @@ import json
|
||||
from typing import List, Dict, Optional, Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Jargon
|
||||
from src.config.config import global_config
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
build_readable_messages,
|
||||
build_readable_messages_with_id,
|
||||
)
|
||||
from src.chat.utils.utils import parse_platform_accounts
|
||||
|
||||
|
||||
logger = get_logger("jargon")
|
||||
|
||||
|
||||
def parse_chat_id_list(chat_id_value: Any) -> List[List[Any]]:
|
||||
"""
|
||||
解析chat_id字段,兼容旧格式(字符串)和新格式(JSON列表)
|
||||
|
||||
|
||||
Args:
|
||||
chat_id_value: 可能是字符串(旧格式)或JSON字符串(新格式)
|
||||
|
||||
|
||||
Returns:
|
||||
List[List[Any]]: 格式为 [[chat_id, count], ...] 的列表
|
||||
"""
|
||||
if not chat_id_value:
|
||||
return []
|
||||
|
||||
|
||||
# 如果是字符串,尝试解析为JSON
|
||||
if isinstance(chat_id_value, str):
|
||||
# 尝试解析JSON
|
||||
@@ -54,12 +53,12 @@ def parse_chat_id_list(chat_id_value: Any) -> List[List[Any]]:
|
||||
def update_chat_id_list(chat_id_list: List[List[Any]], target_chat_id: str, increment: int = 1) -> List[List[Any]]:
|
||||
"""
|
||||
更新chat_id列表,如果target_chat_id已存在则增加计数,否则添加新条目
|
||||
|
||||
|
||||
Args:
|
||||
chat_id_list: 当前的chat_id列表,格式为 [[chat_id, count], ...]
|
||||
target_chat_id: 要更新或添加的chat_id
|
||||
increment: 增加的计数,默认为1
|
||||
|
||||
|
||||
Returns:
|
||||
List[List[Any]]: 更新后的chat_id列表
|
||||
"""
|
||||
@@ -74,22 +73,22 @@ def update_chat_id_list(chat_id_list: List[List[Any]], target_chat_id: str, incr
|
||||
item.append(increment)
|
||||
found = True
|
||||
break
|
||||
|
||||
|
||||
if not found:
|
||||
# 未找到,添加新条目
|
||||
chat_id_list.append([target_chat_id, increment])
|
||||
|
||||
|
||||
return chat_id_list
|
||||
|
||||
|
||||
def chat_id_list_contains(chat_id_list: List[List[Any]], target_chat_id: str) -> bool:
|
||||
"""
|
||||
检查chat_id列表中是否包含指定的chat_id
|
||||
|
||||
|
||||
Args:
|
||||
chat_id_list: chat_id列表,格式为 [[chat_id, count], ...]
|
||||
target_chat_id: 要查找的chat_id
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 如果包含则返回True
|
||||
"""
|
||||
@@ -168,10 +167,7 @@ def is_bot_message(msg: Any) -> bool:
|
||||
.strip()
|
||||
.lower()
|
||||
)
|
||||
user_id = (
|
||||
str(getattr(msg, "user_id", "") or getattr(getattr(msg, "user_info", None), "user_id", "") or "")
|
||||
.strip()
|
||||
)
|
||||
user_id = str(getattr(msg, "user_id", "") or getattr(getattr(msg, "user_info", None), "user_id", "") or "").strip()
|
||||
|
||||
if not platform or not user_id:
|
||||
return False
|
||||
@@ -196,4 +192,4 @@ def is_bot_message(msg: Any) -> bool:
|
||||
bot_accounts[plat] = account
|
||||
|
||||
bot_account = bot_accounts.get(platform)
|
||||
return bool(bot_account and user_id == bot_account)
|
||||
return bool(bot_account and user_id == bot_account)
|
||||
|
||||
@@ -18,11 +18,12 @@ error_code_mapping = {
|
||||
class NetworkConnectionError(Exception):
|
||||
"""连接异常,常见于网络问题或服务器不可用"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
def __init__(self, message: str | None = None):
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
|
||||
def __str__(self):
|
||||
return "连接异常,请检查网络连接状态或URL是否正确"
|
||||
return self.message or "连接异常,请检查网络连接状态或URL是否正确"
|
||||
|
||||
|
||||
class ReqAbortException(Exception):
|
||||
|
||||
@@ -47,6 +47,21 @@ class LLMRequest:
|
||||
}
|
||||
"""模型使用量记录,用于进行负载均衡,对应为(total_tokens, penalty, usage_penalty),惩罚值是为了能在某个模型请求不给力或正在被使用的时候进行调整"""
|
||||
|
||||
def _check_slow_request(self, time_cost: float, model_name: str) -> None:
|
||||
"""检查请求是否过慢并输出警告日志
|
||||
|
||||
Args:
|
||||
time_cost: 请求耗时(秒)
|
||||
model_name: 使用的模型名称
|
||||
"""
|
||||
threshold = self.model_for_task.slow_threshold
|
||||
if time_cost > threshold:
|
||||
request_type_display = self.request_type or "未知任务"
|
||||
logger.warning(
|
||||
f"LLM请求耗时过长: {request_type_display} 使用模型 {model_name} 耗时 {time_cost:.1f}s(阈值: {threshold}s),请考虑使用更快的模型\n"
|
||||
f" 如果你认为该警告出现得过于频繁,请调整model_config.toml中对应任务的slow_threshold至符合你实际情况的合理值"
|
||||
)
|
||||
|
||||
async def generate_response_for_image(
|
||||
self,
|
||||
prompt: str,
|
||||
@@ -86,6 +101,8 @@ class LLMRequest:
|
||||
if not reasoning_content and content:
|
||||
content, extracted_reasoning = self._extract_reasoning(content)
|
||||
reasoning_content = extracted_reasoning
|
||||
time_cost = time.time() - start_time
|
||||
self._check_slow_request(time_cost, model_info.name)
|
||||
if usage := response.usage:
|
||||
llm_usage_recorder.record_usage_to_database(
|
||||
model_info=model_info,
|
||||
@@ -93,7 +110,7 @@ class LLMRequest:
|
||||
user_id="system",
|
||||
request_type=self.request_type,
|
||||
endpoint="/chat/completions",
|
||||
time_cost=time.time() - start_time,
|
||||
time_cost=time_cost,
|
||||
)
|
||||
return content, (reasoning_content, model_info.name, tool_calls)
|
||||
|
||||
@@ -198,7 +215,8 @@ class LLMRequest:
|
||||
tool_options=tool_built,
|
||||
)
|
||||
|
||||
logger.debug(f"LLM请求总耗时: {time.time() - start_time}")
|
||||
time_cost = time.time() - start_time
|
||||
logger.debug(f"LLM请求总耗时: {time_cost}")
|
||||
logger.debug(f"LLM生成内容: {response}")
|
||||
|
||||
content = response.content
|
||||
@@ -207,6 +225,7 @@ class LLMRequest:
|
||||
if not reasoning_content and content:
|
||||
content, extracted_reasoning = self._extract_reasoning(content)
|
||||
reasoning_content = extracted_reasoning
|
||||
self._check_slow_request(time_cost, model_info.name)
|
||||
if usage := response.usage:
|
||||
llm_usage_recorder.record_usage_to_database(
|
||||
model_info=model_info,
|
||||
@@ -214,7 +233,7 @@ class LLMRequest:
|
||||
user_id="system",
|
||||
request_type=self.request_type,
|
||||
endpoint="/chat/completions",
|
||||
time_cost=time.time() - start_time,
|
||||
time_cost=time_cost,
|
||||
)
|
||||
return content or "", (reasoning_content, model_info.name, tool_calls)
|
||||
|
||||
@@ -301,7 +320,7 @@ class LLMRequest:
|
||||
message_list=(compressed_messages or message_list),
|
||||
tool_options=tool_options,
|
||||
max_tokens=self.model_for_task.max_tokens if max_tokens is None else max_tokens,
|
||||
temperature=self.model_for_task.temperature if temperature is None else temperature,
|
||||
temperature=temperature if temperature is not None else (model_info.extra_params or {}).get("temperature", self.model_for_task.temperature),
|
||||
response_format=response_format,
|
||||
stream_response_handler=stream_response_handler,
|
||||
async_response_parser=async_response_parser,
|
||||
@@ -323,34 +342,45 @@ class LLMRequest:
|
||||
)
|
||||
except EmptyResponseException as e:
|
||||
# 空回复:通常为临时问题,单独记录并重试
|
||||
original_error_info = self._get_original_error_info(e)
|
||||
retry_remain -= 1
|
||||
if retry_remain <= 0:
|
||||
logger.error(f"模型 '{model_info.name}' 在多次出现空回复后仍然失败。")
|
||||
logger.error(f"模型 '{model_info.name}' 在多次出现空回复后仍然失败。{original_error_info}")
|
||||
raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e
|
||||
|
||||
logger.warning(f"模型 '{model_info.name}' 返回空回复(可重试)。剩余重试次数: {retry_remain}")
|
||||
logger.warning(f"模型 '{model_info.name}' 返回空回复(可重试){original_error_info}。剩余重试次数: {retry_remain}")
|
||||
await asyncio.sleep(api_provider.retry_interval)
|
||||
|
||||
except NetworkConnectionError as e:
|
||||
# 网络错误:单独记录并重试
|
||||
# 尝试从链式异常中获取原始错误信息以诊断具体原因
|
||||
original_error_info = self._get_original_error_info(e)
|
||||
|
||||
retry_remain -= 1
|
||||
if retry_remain <= 0:
|
||||
logger.error(f"模型 '{model_info.name}' 在网络错误重试用尽后仍然失败。")
|
||||
logger.error(f"模型 '{model_info.name}' 在网络错误重试用尽后仍然失败。{original_error_info}")
|
||||
raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e
|
||||
|
||||
logger.warning(f"模型 '{model_info.name}' 遇到网络错误(可重试): {str(e)}。剩余重试次数: {retry_remain}")
|
||||
logger.warning(
|
||||
f"模型 '{model_info.name}' 遇到网络错误(可重试): {str(e)}{original_error_info}\n"
|
||||
f" 常见原因: 如请求的API正常但APITimeoutError类型错误过多,请尝试调整模型配置中对应API Provider的timeout值\n"
|
||||
f" 其它可能原因: 网络波动、DNS 故障、连接超时、防火墙限制或代理问题\n"
|
||||
f" 剩余重试次数: {retry_remain}"
|
||||
)
|
||||
await asyncio.sleep(api_provider.retry_interval)
|
||||
|
||||
except RespNotOkException as e:
|
||||
original_error_info = self._get_original_error_info(e)
|
||||
|
||||
# 可重试的HTTP错误
|
||||
if e.status_code == 429 or e.status_code >= 500:
|
||||
retry_remain -= 1
|
||||
if retry_remain <= 0:
|
||||
logger.error(f"模型 '{model_info.name}' 在遇到 {e.status_code} 错误并用尽重试次数后仍然失败。")
|
||||
logger.error(f"模型 '{model_info.name}' 在遇到 {e.status_code} 错误并用尽重试次数后仍然失败。{original_error_info}")
|
||||
raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e
|
||||
|
||||
logger.warning(
|
||||
f"模型 '{model_info.name}' 遇到可重试的HTTP错误: {str(e)}。剩余重试次数: {retry_remain}"
|
||||
f"模型 '{model_info.name}' 遇到可重试的HTTP错误: {str(e)}{original_error_info}。剩余重试次数: {retry_remain}"
|
||||
)
|
||||
await asyncio.sleep(api_provider.retry_interval)
|
||||
continue
|
||||
@@ -363,13 +393,15 @@ class LLMRequest:
|
||||
continue
|
||||
|
||||
# 不可重试的HTTP错误
|
||||
logger.warning(f"模型 '{model_info.name}' 遇到不可重试的HTTP错误: {str(e)}")
|
||||
logger.warning(f"模型 '{model_info.name}' 遇到不可重试的HTTP错误: {str(e)}{original_error_info}")
|
||||
raise ModelAttemptFailed(f"模型 '{model_info.name}' 遇到硬错误", original_exception=e) from e
|
||||
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
logger.warning(f"模型 '{model_info.name}' 遇到未知的不可重试错误: {str(e)}")
|
||||
original_error_info = self._get_original_error_info(e)
|
||||
|
||||
logger.warning(f"模型 '{model_info.name}' 遇到未知的不可重试错误: {str(e)}{original_error_info}")
|
||||
raise ModelAttemptFailed(f"模型 '{model_info.name}' 遇到硬错误", original_exception=e) from e
|
||||
|
||||
raise ModelAttemptFailed(f"模型 '{model_info.name}' 未被尝试,因为重试次数已配置为0或更少。")
|
||||
@@ -483,3 +515,14 @@ class LLMRequest:
|
||||
content = re.sub(r"(?:<think>)?.*?</think>", "", content, flags=re.DOTALL, count=1).strip()
|
||||
reasoning = match[1].strip() if match else ""
|
||||
return content, reasoning
|
||||
|
||||
@staticmethod
|
||||
def _get_original_error_info(e: Exception) -> str:
|
||||
"""获取原始错误信息"""
|
||||
if e.__cause__:
|
||||
original_error_type = type(e.__cause__).__name__
|
||||
original_error_msg = str(e.__cause__)
|
||||
return (
|
||||
f"\n 底层异常类型: {original_error_type}\n 底层异常信息: {original_error_msg}"
|
||||
)
|
||||
return ""
|
||||
|
||||
@@ -56,7 +56,7 @@ class MainSystem:
|
||||
from src.webui.webui_server import get_webui_server
|
||||
|
||||
self.webui_server = get_webui_server()
|
||||
|
||||
|
||||
if webui_mode == "development":
|
||||
logger.info("📝 WebUI 开发模式已启用")
|
||||
logger.info("🌐 后端 API 将运行在 http://0.0.0.0:8001")
|
||||
@@ -64,9 +64,9 @@ class MainSystem:
|
||||
logger.info("💡 前端将运行在 http://localhost:7999")
|
||||
else:
|
||||
logger.info("✅ WebUI 生产模式已启用")
|
||||
logger.info(f"🌐 WebUI 将运行在 http://0.0.0.0:8001")
|
||||
logger.info("🌐 WebUI 将运行在 http://0.0.0.0:8001")
|
||||
logger.info("💡 请确保已构建前端: cd MaiBot-Dashboard && bun run build")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 初始化 WebUI 服务器失败: {e}")
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -8,7 +8,8 @@ import json
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Tuple
|
||||
from difflib import SequenceMatcher
|
||||
from typing import List
|
||||
from json_repair import repair_json
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -16,101 +17,56 @@ from src.common.logger import get_logger
|
||||
logger = get_logger("memory_utils")
|
||||
|
||||
|
||||
def parse_md_json(json_text: str) -> list[str]:
|
||||
"""从Markdown格式的内容中提取JSON对象和推理内容"""
|
||||
json_objects = []
|
||||
reasoning_content = ""
|
||||
|
||||
# 使用正则表达式查找```json包裹的JSON内容
|
||||
json_pattern = r"```json\s*(.*?)\s*```"
|
||||
matches = re.findall(json_pattern, json_text, re.DOTALL)
|
||||
|
||||
# 提取JSON之前的内容作为推理文本
|
||||
if matches:
|
||||
# 找到第一个```json的位置
|
||||
first_json_pos = json_text.find("```json")
|
||||
if first_json_pos > 0:
|
||||
reasoning_content = json_text[:first_json_pos].strip()
|
||||
# 清理推理内容中的注释标记
|
||||
reasoning_content = re.sub(r"^//\s*", "", reasoning_content, flags=re.MULTILINE)
|
||||
reasoning_content = reasoning_content.strip()
|
||||
|
||||
for match in matches:
|
||||
try:
|
||||
# 清理可能的注释和格式问题
|
||||
json_str = re.sub(r"//.*?\n", "\n", match) # 移除单行注释
|
||||
json_str = re.sub(r"/\*.*?\*/", "", json_str, flags=re.DOTALL) # 移除多行注释
|
||||
if json_str := json_str.strip():
|
||||
json_obj = json.loads(json_str)
|
||||
if isinstance(json_obj, dict):
|
||||
json_objects.append(json_obj)
|
||||
elif isinstance(json_obj, list):
|
||||
for item in json_obj:
|
||||
if isinstance(item, dict):
|
||||
json_objects.append(item)
|
||||
except Exception as e:
|
||||
logger.warning(f"解析JSON块失败: {e}, 块内容: {match[:100]}...")
|
||||
continue
|
||||
|
||||
return json_objects, reasoning_content
|
||||
|
||||
|
||||
def calculate_similarity(text1: str, text2: str) -> float:
|
||||
"""
|
||||
计算两个文本的相似度
|
||||
def parse_questions_json(response: str) -> Tuple[List[str], List[str]]:
|
||||
"""解析问题JSON,返回概念列表和问题列表
|
||||
|
||||
Args:
|
||||
text1: 第一个文本
|
||||
text2: 第二个文本
|
||||
response: LLM返回的响应
|
||||
|
||||
Returns:
|
||||
float: 相似度分数 (0-1)
|
||||
Tuple[List[str], List[str]]: (概念列表, 问题列表)
|
||||
"""
|
||||
try:
|
||||
# 预处理文本
|
||||
text1 = preprocess_text(text1)
|
||||
text2 = preprocess_text(text2)
|
||||
# 尝试提取JSON(可能包含在```json代码块中)
|
||||
json_pattern = r"```json\s*(.*?)\s*```"
|
||||
matches = re.findall(json_pattern, response, re.DOTALL)
|
||||
|
||||
# 使用SequenceMatcher计算相似度
|
||||
similarity = SequenceMatcher(None, text1, text2).ratio()
|
||||
if matches:
|
||||
json_str = matches[0]
|
||||
else:
|
||||
# 尝试直接解析整个响应
|
||||
json_str = response.strip()
|
||||
|
||||
# 如果其中一个文本包含另一个,提高相似度
|
||||
if text1 in text2 or text2 in text1:
|
||||
similarity = max(similarity, 0.8)
|
||||
# 修复可能的JSON错误
|
||||
repaired_json = repair_json(json_str)
|
||||
|
||||
return similarity
|
||||
# 解析JSON
|
||||
parsed = json.loads(repaired_json)
|
||||
|
||||
# 只支持新格式:包含concepts和questions的对象
|
||||
if not isinstance(parsed, dict):
|
||||
logger.warning(f"解析的JSON不是对象格式: {parsed}")
|
||||
return [], []
|
||||
|
||||
concepts_raw = parsed.get("concepts", [])
|
||||
questions_raw = parsed.get("questions", [])
|
||||
|
||||
# 确保是列表
|
||||
if not isinstance(concepts_raw, list):
|
||||
concepts_raw = []
|
||||
if not isinstance(questions_raw, list):
|
||||
questions_raw = []
|
||||
|
||||
# 确保所有元素都是字符串
|
||||
concepts = [c for c in concepts_raw if isinstance(c, str) and c.strip()]
|
||||
questions = [q for q in questions_raw if isinstance(q, str) and q.strip()]
|
||||
|
||||
return concepts, questions
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"计算相似度时出错: {e}")
|
||||
return 0.0
|
||||
|
||||
|
||||
def preprocess_text(text: str) -> str:
|
||||
"""
|
||||
预处理文本,提高匹配准确性
|
||||
|
||||
Args:
|
||||
text: 原始文本
|
||||
|
||||
Returns:
|
||||
str: 预处理后的文本
|
||||
"""
|
||||
try:
|
||||
# 转换为小写
|
||||
text = text.lower()
|
||||
|
||||
# 移除标点符号和特殊字符
|
||||
text = re.sub(r"[^\w\s]", "", text)
|
||||
|
||||
# 移除多余空格
|
||||
text = re.sub(r"\s+", " ", text).strip()
|
||||
|
||||
return text
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"预处理文本时出错: {e}")
|
||||
return text
|
||||
|
||||
logger.error(f"解析问题JSON失败: {e}, 响应内容: {response[:200]}...")
|
||||
return [], []
|
||||
|
||||
def parse_datetime_to_timestamp(value: str) -> float:
|
||||
"""
|
||||
@@ -140,29 +96,3 @@ def parse_datetime_to_timestamp(value: str) -> float:
|
||||
except Exception as e:
|
||||
last_err = e
|
||||
raise ValueError(f"无法解析时间: {value} ({last_err})")
|
||||
|
||||
|
||||
def parse_time_range(time_range: str) -> Tuple[float, float]:
|
||||
"""
|
||||
解析时间范围字符串,返回开始和结束时间戳
|
||||
|
||||
Args:
|
||||
time_range: 时间范围字符串,格式:"YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS"
|
||||
|
||||
Returns:
|
||||
Tuple[float, float]: (开始时间戳, 结束时间戳)
|
||||
"""
|
||||
if " - " not in time_range:
|
||||
raise ValueError(f"时间范围格式错误,应为 '开始时间 - 结束时间': {time_range}")
|
||||
|
||||
parts = time_range.split(" - ", 1)
|
||||
if len(parts) != 2:
|
||||
raise ValueError(f"时间范围格式错误: {time_range}")
|
||||
|
||||
start_str = parts[0].strip()
|
||||
end_str = parts[1].strip()
|
||||
|
||||
start_timestamp = parse_datetime_to_timestamp(start_str)
|
||||
end_timestamp = parse_datetime_to_timestamp(end_str)
|
||||
|
||||
return start_timestamp, end_timestamp
|
||||
|
||||
@@ -14,6 +14,7 @@ from .tool_registry import (
|
||||
from .query_chat_history import register_tool as register_query_chat_history
|
||||
from .query_lpmm_knowledge import register_tool as register_lpmm_knowledge
|
||||
from .query_person_info import register_tool as register_query_person_info
|
||||
from .found_answer import register_tool as register_found_answer
|
||||
from src.config.config import global_config
|
||||
|
||||
|
||||
@@ -21,6 +22,7 @@ def init_all_tools():
|
||||
"""初始化并注册所有记忆检索工具"""
|
||||
register_query_chat_history()
|
||||
register_query_person_info()
|
||||
register_found_answer() # 注册found_answer工具
|
||||
|
||||
if global_config.lpmm_knowledge.lpmm_mode == "agent":
|
||||
register_lpmm_knowledge()
|
||||
|
||||
40
src/memory_system/retrieval_tools/found_answer.py
Normal file
40
src/memory_system/retrieval_tools/found_answer.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""
|
||||
found_answer工具 - 用于在记忆检索过程中标记找到答案
|
||||
"""
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from .tool_registry import register_memory_retrieval_tool
|
||||
|
||||
logger = get_logger("memory_retrieval_tools")
|
||||
|
||||
|
||||
async def found_answer(answer: str) -> str:
|
||||
"""标记已找到问题的答案
|
||||
|
||||
Args:
|
||||
answer: 找到的答案内容
|
||||
|
||||
Returns:
|
||||
str: 确认信息
|
||||
"""
|
||||
# 这个工具主要用于标记,实际答案会通过返回值传递
|
||||
logger.info(f"找到答案: {answer}")
|
||||
return f"已确认找到答案: {answer}"
|
||||
|
||||
|
||||
def register_tool():
|
||||
"""注册found_answer工具"""
|
||||
register_memory_retrieval_tool(
|
||||
name="found_answer",
|
||||
description="当你在已收集的信息中找到了问题的明确答案时,调用此工具标记已找到答案。只有在检索到明确、具体的答案时才使用此工具,不要编造信息。",
|
||||
parameters=[
|
||||
{
|
||||
"name": "answer",
|
||||
"type": "string",
|
||||
"description": "找到的答案内容,必须基于已收集的信息,不要编造",
|
||||
"required": True,
|
||||
},
|
||||
],
|
||||
execute_func=found_answer,
|
||||
)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""
|
||||
根据时间或关键词在chat_history中查询 - 工具实现
|
||||
根据关键词或参与人在chat_history中查询记忆 - 工具实现
|
||||
从ChatHistory表的聊天记录概述库中查询
|
||||
"""
|
||||
|
||||
@@ -9,103 +9,92 @@ from src.common.logger import get_logger
|
||||
from src.common.database.database_model import ChatHistory
|
||||
from src.chat.utils.utils import parse_keywords_string
|
||||
from .tool_registry import register_memory_retrieval_tool
|
||||
from ..memory_utils import parse_datetime_to_timestamp, parse_time_range
|
||||
from datetime import datetime
|
||||
|
||||
logger = get_logger("memory_retrieval_tools")
|
||||
|
||||
|
||||
async def query_chat_history(
|
||||
chat_id: str, keyword: Optional[str] = None, time_range: Optional[str] = None, fuzzy: bool = True
|
||||
async def search_chat_history(
|
||||
chat_id: str, keyword: Optional[str] = None, participant: Optional[str] = None
|
||||
) -> str:
|
||||
"""根据时间或关键词在chat_history表中查询聊天记录概述
|
||||
"""根据关键词或参与人查询记忆,返回匹配的记忆id、记忆标题theme和关键词keywords
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
keyword: 关键词(可选,支持多个关键词,可用空格、逗号等分隔)
|
||||
time_range: 时间范围或时间点,格式:
|
||||
- 时间范围:"YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS"
|
||||
- 时间点:"YYYY-MM-DD HH:MM:SS"(查询包含该时间点的记录)
|
||||
fuzzy: 是否使用模糊匹配模式(默认True)
|
||||
- True: 模糊匹配,只要包含任意一个关键词即匹配(OR关系)
|
||||
- False: 全匹配,必须包含所有关键词才匹配(AND关系)
|
||||
keyword: 关键词(可选,支持多个关键词,可用空格、逗号等分隔。匹配规则:如果关键词数量<=2,必须全部匹配;如果关键词数量>2,允许n-1个关键词匹配)
|
||||
participant: 参与人昵称(可选)
|
||||
|
||||
Returns:
|
||||
str: 查询结果
|
||||
str: 查询结果,包含记忆id、theme和keywords
|
||||
"""
|
||||
try:
|
||||
# 检查参数
|
||||
if not keyword and not time_range:
|
||||
return "未指定查询参数(需要提供keyword或time_range之一)"
|
||||
if not keyword and not participant:
|
||||
return "未指定查询参数(需要提供keyword或participant之一)"
|
||||
|
||||
# 构建查询条件
|
||||
query = ChatHistory.select().where(ChatHistory.chat_id == chat_id)
|
||||
|
||||
# 时间过滤条件
|
||||
if time_range:
|
||||
# 判断是时间点还是时间范围
|
||||
if " - " in time_range:
|
||||
# 时间范围:查询与时间范围有交集的记录
|
||||
start_timestamp, end_timestamp = parse_time_range(time_range)
|
||||
# 交集条件:start_time < end_timestamp AND end_time > start_timestamp
|
||||
time_filter = (ChatHistory.start_time < end_timestamp) & (ChatHistory.end_time > start_timestamp)
|
||||
else:
|
||||
# 时间点:查询包含该时间点的记录(start_time <= time_point <= end_time)
|
||||
target_timestamp = parse_datetime_to_timestamp(time_range)
|
||||
time_filter = (ChatHistory.start_time <= target_timestamp) & (ChatHistory.end_time >= target_timestamp)
|
||||
query = query.where(time_filter)
|
||||
|
||||
# 执行查询
|
||||
records = list(query.order_by(ChatHistory.start_time.desc()).limit(50))
|
||||
|
||||
# 如果有关键词,进一步过滤
|
||||
if keyword:
|
||||
# 解析多个关键词(支持空格、逗号等分隔符)
|
||||
keywords_list = parse_keywords_string(keyword)
|
||||
if not keywords_list:
|
||||
keywords_list = [keyword.strip()] if keyword.strip() else []
|
||||
filtered_records = []
|
||||
|
||||
# 转换为小写以便匹配
|
||||
keywords_lower = [kw.lower() for kw in keywords_list if kw.strip()]
|
||||
for record in records:
|
||||
participant_matched = True # 如果没有participant条件,默认为True
|
||||
keyword_matched = True # 如果没有keyword条件,默认为True
|
||||
|
||||
if not keywords_lower:
|
||||
return "关键词为空"
|
||||
|
||||
filtered_records = []
|
||||
|
||||
for record in records:
|
||||
# 在theme、keywords、summary、original_text中搜索
|
||||
theme = (record.theme or "").lower()
|
||||
summary = (record.summary or "").lower()
|
||||
original_text = (record.original_text or "").lower()
|
||||
|
||||
# 解析record中的keywords JSON
|
||||
record_keywords_list = []
|
||||
if record.keywords:
|
||||
# 检查参与人匹配
|
||||
if participant:
|
||||
participant_matched = False
|
||||
participants_list = []
|
||||
if record.participants:
|
||||
try:
|
||||
keywords_data = (
|
||||
json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
|
||||
participants_data = (
|
||||
json.loads(record.participants)
|
||||
if isinstance(record.participants, str)
|
||||
else record.participants
|
||||
)
|
||||
if isinstance(keywords_data, list):
|
||||
record_keywords_list = [str(k).lower() for k in keywords_data]
|
||||
if isinstance(participants_data, list):
|
||||
participants_list = [str(p).lower() for p in participants_data]
|
||||
except (json.JSONDecodeError, TypeError, ValueError):
|
||||
pass
|
||||
|
||||
# 根据匹配模式检查关键词
|
||||
matched = False
|
||||
if fuzzy:
|
||||
# 模糊匹配:只要包含任意一个关键词即匹配(OR关系)
|
||||
for kw in keywords_lower:
|
||||
if (
|
||||
kw in theme
|
||||
or kw in summary
|
||||
or kw in original_text
|
||||
or any(kw in k for k in record_keywords_list)
|
||||
):
|
||||
matched = True
|
||||
break
|
||||
else:
|
||||
# 全匹配:必须包含所有关键词才匹配(AND关系)
|
||||
matched = True
|
||||
participant_lower = participant.lower().strip()
|
||||
if participant_lower and any(participant_lower in p for p in participants_list):
|
||||
participant_matched = True
|
||||
|
||||
# 检查关键词匹配
|
||||
if keyword:
|
||||
keyword_matched = False
|
||||
# 解析多个关键词(支持空格、逗号等分隔符)
|
||||
keywords_list = parse_keywords_string(keyword)
|
||||
if not keywords_list:
|
||||
keywords_list = [keyword.strip()] if keyword.strip() else []
|
||||
|
||||
# 转换为小写以便匹配
|
||||
keywords_lower = [kw.lower() for kw in keywords_list if kw.strip()]
|
||||
|
||||
if keywords_lower:
|
||||
# 在theme、keywords、summary、original_text中搜索
|
||||
theme = (record.theme or "").lower()
|
||||
summary = (record.summary or "").lower()
|
||||
original_text = (record.original_text or "").lower()
|
||||
|
||||
# 解析record中的keywords JSON
|
||||
record_keywords_list = []
|
||||
if record.keywords:
|
||||
try:
|
||||
keywords_data = (
|
||||
json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
|
||||
)
|
||||
if isinstance(keywords_data, list):
|
||||
record_keywords_list = [str(k).lower() for k in keywords_data]
|
||||
except (json.JSONDecodeError, TypeError, ValueError):
|
||||
pass
|
||||
|
||||
# 有容错的全匹配:如果关键词数量>2,允许n-1个关键词匹配;否则必须全部匹配
|
||||
matched_count = 0
|
||||
for kw in keywords_lower:
|
||||
kw_matched = (
|
||||
kw in theme
|
||||
@@ -113,73 +102,80 @@ async def query_chat_history(
|
||||
or kw in original_text
|
||||
or any(kw in k for k in record_keywords_list)
|
||||
)
|
||||
if not kw_matched:
|
||||
matched = False
|
||||
break
|
||||
if kw_matched:
|
||||
matched_count += 1
|
||||
|
||||
# 计算需要匹配的关键词数量
|
||||
total_keywords = len(keywords_lower)
|
||||
if total_keywords > 2:
|
||||
# 关键词数量>2,允许n-1个关键词匹配
|
||||
required_matches = total_keywords - 1
|
||||
else:
|
||||
# 关键词数量<=2,必须全部匹配
|
||||
required_matches = total_keywords
|
||||
|
||||
keyword_matched = matched_count >= required_matches
|
||||
|
||||
if matched:
|
||||
filtered_records.append(record)
|
||||
# 两者都匹配(如果同时有participant和keyword,需要两者都匹配;如果只有一个条件,只需要该条件匹配)
|
||||
matched = participant_matched and keyword_matched
|
||||
|
||||
if not filtered_records:
|
||||
keywords_str = "、".join(keywords_list)
|
||||
match_mode = "包含任意一个关键词" if fuzzy else "包含所有关键词"
|
||||
if time_range:
|
||||
return f"未找到{match_mode}'{keywords_str}'且在指定时间范围内的聊天记录概述"
|
||||
if matched:
|
||||
filtered_records.append(record)
|
||||
|
||||
if not filtered_records:
|
||||
if keyword and participant:
|
||||
keywords_str = "、".join(parse_keywords_string(keyword) if keyword else [])
|
||||
return f"未找到包含关键词'{keywords_str}'且参与人包含'{participant}'的聊天记录"
|
||||
elif keyword:
|
||||
keywords_str = "、".join(parse_keywords_string(keyword))
|
||||
keywords_list = parse_keywords_string(keyword)
|
||||
if len(keywords_list) > 2:
|
||||
required_count = len(keywords_list) - 1
|
||||
return f"未找到包含至少{required_count}个关键词(共{len(keywords_list)}个)'{keywords_str}'的聊天记录"
|
||||
else:
|
||||
return f"未找到{match_mode}'{keywords_str}'的聊天记录概述"
|
||||
|
||||
records = filtered_records
|
||||
|
||||
# 如果没有记录(可能是时间范围查询但没有匹配的记录)
|
||||
if not records:
|
||||
if time_range:
|
||||
return "未找到指定时间范围内的聊天记录概述"
|
||||
return f"未找到包含所有关键词'{keywords_str}'的聊天记录"
|
||||
elif participant:
|
||||
return f"未找到参与人包含'{participant}'的聊天记录"
|
||||
else:
|
||||
return "未找到相关聊天记录概述"
|
||||
return "未找到相关聊天记录"
|
||||
|
||||
# 对即将返回的记录增加使用计数
|
||||
records_to_use = records[:3]
|
||||
for record in records_to_use:
|
||||
try:
|
||||
ChatHistory.update(count=ChatHistory.count + 1).where(ChatHistory.id == record.id).execute()
|
||||
record.count = (record.count or 0) + 1
|
||||
except Exception as update_error:
|
||||
logger.error(f"更新聊天记录概述计数失败: {update_error}")
|
||||
|
||||
# 构建结果文本
|
||||
# 构建结果文本,返回id、theme和keywords
|
||||
results = []
|
||||
for record in records_to_use: # 最多返回3条记录
|
||||
for record in filtered_records[:20]: # 最多返回20条记录
|
||||
result_parts = []
|
||||
|
||||
# 添加记忆ID
|
||||
result_parts.append(f"记忆ID:{record.id}")
|
||||
|
||||
# 添加主题
|
||||
if record.theme:
|
||||
result_parts.append(f"主题:{record.theme}")
|
||||
else:
|
||||
result_parts.append("主题:(无)")
|
||||
|
||||
# 添加时间范围
|
||||
from datetime import datetime
|
||||
|
||||
start_str = datetime.fromtimestamp(record.start_time).strftime("%Y-%m-%d %H:%M:%S")
|
||||
end_str = datetime.fromtimestamp(record.end_time).strftime("%Y-%m-%d %H:%M:%S")
|
||||
result_parts.append(f"时间:{start_str} - {end_str}")
|
||||
|
||||
# 添加概括(优先使用summary,如果没有则使用original_text的前200字符)
|
||||
if record.summary:
|
||||
result_parts.append(f"概括:{record.summary}")
|
||||
elif record.original_text:
|
||||
text_preview = record.original_text[:200]
|
||||
if len(record.original_text) > 200:
|
||||
text_preview += "..."
|
||||
result_parts.append(f"内容:{text_preview}")
|
||||
# 添加关键词
|
||||
if record.keywords:
|
||||
try:
|
||||
keywords_data = json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
|
||||
if isinstance(keywords_data, list) and keywords_data:
|
||||
keywords_str = "、".join([str(k) for k in keywords_data])
|
||||
result_parts.append(f"关键词:{keywords_str}")
|
||||
else:
|
||||
result_parts.append("关键词:(无)")
|
||||
except (json.JSONDecodeError, TypeError, ValueError):
|
||||
result_parts.append("关键词:(无)")
|
||||
else:
|
||||
result_parts.append("关键词:(无)")
|
||||
|
||||
results.append("\n".join(result_parts))
|
||||
|
||||
if not results:
|
||||
return "未找到相关聊天记录概述"
|
||||
return "未找到相关聊天记录"
|
||||
|
||||
response_text = "\n\n---\n\n".join(results)
|
||||
if len(records) > len(records_to_use):
|
||||
omitted_count = len(records) - len(records_to_use)
|
||||
response_text += f"\n\n(还有{omitted_count}条历史记录已省略)"
|
||||
if len(filtered_records) > 20:
|
||||
omitted_count = len(filtered_records) - 20
|
||||
response_text += f"\n\n(还有{omitted_count}条记录已省略,可使用记忆ID查询详细信息)"
|
||||
return response_text
|
||||
|
||||
except Exception as e:
|
||||
@@ -187,30 +183,145 @@ async def query_chat_history(
|
||||
return f"查询失败: {str(e)}"
|
||||
|
||||
|
||||
async def get_chat_history_detail(chat_id: str, memory_ids: str) -> str:
|
||||
"""根据记忆ID,展示某条或某几条记忆的具体内容
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
memory_ids: 记忆ID,可以是单个ID(如"123")或多个ID(用逗号分隔,如"1,2,3")
|
||||
|
||||
Returns:
|
||||
str: 记忆的详细内容
|
||||
"""
|
||||
try:
|
||||
# 解析memory_ids
|
||||
id_list = []
|
||||
# 尝试解析为逗号分隔的ID列表
|
||||
try:
|
||||
id_list = [int(id_str.strip()) for id_str in memory_ids.split(",") if id_str.strip()]
|
||||
except ValueError:
|
||||
return f"无效的记忆ID格式: {memory_ids},请使用数字ID,多个ID用逗号分隔(如:'123' 或 '123,456')"
|
||||
|
||||
if not id_list:
|
||||
return "未提供有效的记忆ID"
|
||||
|
||||
# 查询记录
|
||||
query = ChatHistory.select().where((ChatHistory.chat_id == chat_id) & (ChatHistory.id.in_(id_list)))
|
||||
records = list(query.order_by(ChatHistory.start_time.desc()))
|
||||
|
||||
if not records:
|
||||
return f"未找到ID为{id_list}的记忆记录(可能ID不存在或不属于当前聊天)"
|
||||
|
||||
# 对即将返回的记录增加使用计数
|
||||
for record in records:
|
||||
try:
|
||||
ChatHistory.update(count=ChatHistory.count + 1).where(ChatHistory.id == record.id).execute()
|
||||
record.count = (record.count or 0) + 1
|
||||
except Exception as update_error:
|
||||
logger.error(f"更新聊天记录概述计数失败: {update_error}")
|
||||
|
||||
# 构建详细结果
|
||||
results = []
|
||||
for record in records:
|
||||
result_parts = []
|
||||
|
||||
# 添加记忆ID
|
||||
result_parts.append(f"记忆ID:{record.id}")
|
||||
|
||||
# 添加主题
|
||||
if record.theme:
|
||||
result_parts.append(f"主题:{record.theme}")
|
||||
|
||||
# 添加时间范围
|
||||
start_str = datetime.fromtimestamp(record.start_time).strftime("%Y-%m-%d %H:%M:%S")
|
||||
end_str = datetime.fromtimestamp(record.end_time).strftime("%Y-%m-%d %H:%M:%S")
|
||||
result_parts.append(f"时间:{start_str} - {end_str}")
|
||||
|
||||
# 添加参与人
|
||||
if record.participants:
|
||||
try:
|
||||
participants_data = (
|
||||
json.loads(record.participants) if isinstance(record.participants, str) else record.participants
|
||||
)
|
||||
if isinstance(participants_data, list) and participants_data:
|
||||
participants_str = "、".join([str(p) for p in participants_data])
|
||||
result_parts.append(f"参与人:{participants_str}")
|
||||
except (json.JSONDecodeError, TypeError, ValueError):
|
||||
pass
|
||||
|
||||
# 添加关键词
|
||||
if record.keywords:
|
||||
try:
|
||||
keywords_data = json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
|
||||
if isinstance(keywords_data, list) and keywords_data:
|
||||
keywords_str = "、".join([str(k) for k in keywords_data])
|
||||
result_parts.append(f"关键词:{keywords_str}")
|
||||
except (json.JSONDecodeError, TypeError, ValueError):
|
||||
pass
|
||||
|
||||
# 添加概括
|
||||
if record.summary:
|
||||
result_parts.append(f"概括:{record.summary}")
|
||||
|
||||
# 添加关键信息点
|
||||
if record.key_point:
|
||||
try:
|
||||
key_point_data = (
|
||||
json.loads(record.key_point) if isinstance(record.key_point, str) else record.key_point
|
||||
)
|
||||
if isinstance(key_point_data, list) and key_point_data:
|
||||
key_point_str = "\n".join([f" - {str(kp)}" for kp in key_point_data])
|
||||
result_parts.append(f"关键信息点:\n{key_point_str}")
|
||||
except (json.JSONDecodeError, TypeError, ValueError):
|
||||
pass
|
||||
|
||||
results.append("\n".join(result_parts))
|
||||
|
||||
if not results:
|
||||
return "未找到相关记忆记录"
|
||||
|
||||
response_text = "\n\n" + "=" * 50 + "\n\n".join(results)
|
||||
return response_text
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取记忆详情失败: {e}")
|
||||
return f"查询失败: {str(e)}"
|
||||
|
||||
|
||||
def register_tool():
|
||||
"""注册工具"""
|
||||
# 注册工具1:搜索记忆
|
||||
register_memory_retrieval_tool(
|
||||
name="query_chat_history",
|
||||
description="根据时间或关键词在聊天记录中查询。可以查询某个时间点发生了什么、某个时间范围内的事件,或根据关键词搜索消息概述。支持两种匹配模式:模糊匹配(默认,只要包含任意一个关键词即匹配)和全匹配(必须包含所有关键词才匹配)",
|
||||
name="search_chat_history",
|
||||
description="根据关键词或参与人查询记忆,返回匹配的记忆id、记忆标题theme和关键词keywords。用于快速搜索和定位相关记忆。匹配规则:如果关键词数量<=2,必须全部匹配;如果关键词数量>2,允许n-1个关键词匹配(容错匹配)。",
|
||||
parameters=[
|
||||
{
|
||||
"name": "keyword",
|
||||
"type": "string",
|
||||
"description": "关键词(可选,支持多个关键词,可用空格、逗号、斜杠等分隔,如:'麦麦 百度网盘' 或 '麦麦,百度网盘'。用于在主题、关键词、概括、原文中搜索)",
|
||||
"description": "关键词(可选,支持多个关键词,可用空格、逗号、斜杠等分隔,如:'麦麦 百度网盘' 或 '麦麦,百度网盘'。用于在主题、关键词、概括、原文中搜索。匹配规则:如果关键词数量<=2,必须全部匹配;如果关键词数量>2,允许n-1个关键词匹配)",
|
||||
"required": False,
|
||||
},
|
||||
{
|
||||
"name": "time_range",
|
||||
"name": "participant",
|
||||
"type": "string",
|
||||
"description": "时间范围或时间点(可选)。格式:'YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS'(时间范围,查询与时间范围有交集的记录)或 'YYYY-MM-DD HH:MM:SS'(时间点,查询包含该时间点的记录)",
|
||||
"required": False,
|
||||
},
|
||||
{
|
||||
"name": "fuzzy",
|
||||
"type": "boolean",
|
||||
"description": "是否使用模糊匹配模式(默认True)。True表示模糊匹配(只要包含任意一个关键词即匹配,OR关系),False表示全匹配(必须包含所有关键词才匹配,AND关系)",
|
||||
"description": "参与人昵称(可选),用于查询包含该参与人的记忆",
|
||||
"required": False,
|
||||
},
|
||||
],
|
||||
execute_func=query_chat_history,
|
||||
execute_func=search_chat_history,
|
||||
)
|
||||
|
||||
# 注册工具2:获取记忆详情
|
||||
register_memory_retrieval_tool(
|
||||
name="get_chat_history_detail",
|
||||
description="根据记忆ID,展示某条或某几条记忆的具体内容。包括主题、时间、参与人、关键词、概括和关键信息点等详细信息。需要先使用search_chat_history工具获取记忆ID。",
|
||||
parameters=[
|
||||
{
|
||||
"name": "memory_ids",
|
||||
"type": "string",
|
||||
"description": "记忆ID,可以是单个ID(如'123')或多个ID(用逗号分隔,如'123,456,789')",
|
||||
"required": True,
|
||||
},
|
||||
],
|
||||
execute_func=get_chat_history_detail,
|
||||
)
|
||||
|
||||
@@ -104,7 +104,7 @@ async def get_random(count: Optional[int] = 1) -> List[Tuple[str, str, str]]:
|
||||
return []
|
||||
|
||||
if len(valid_emojis) < count:
|
||||
logger.warning(
|
||||
logger.debug(
|
||||
f"[EmojiAPI] 有效表情包数量 ({len(valid_emojis)}) 少于请求的数量 ({count}),将返回所有有效表情包"
|
||||
)
|
||||
count = len(valid_emojis)
|
||||
|
||||
@@ -1,18 +1,263 @@
|
||||
"""
|
||||
插件系统配置类型定义
|
||||
|
||||
提供插件配置的类型定义,支持 WebUI 可视化配置编辑。
|
||||
"""
|
||||
|
||||
from typing import Any, Optional, List
|
||||
from typing import Any, Optional, List, Dict, Union
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConfigField:
|
||||
"""配置字段定义"""
|
||||
"""
|
||||
配置字段定义
|
||||
|
||||
type: type # 字段类型
|
||||
用于定义插件配置项的元数据,支持类型验证、UI 渲染等功能。
|
||||
|
||||
基础示例:
|
||||
ConfigField(type=str, default="", description="API密钥")
|
||||
|
||||
完整示例:
|
||||
ConfigField(
|
||||
type=str,
|
||||
default="",
|
||||
description="API密钥",
|
||||
input_type="password",
|
||||
placeholder="请输入API密钥",
|
||||
required=True,
|
||||
hint="从服务商控制台获取",
|
||||
order=1
|
||||
)
|
||||
"""
|
||||
|
||||
# === 基础字段(必需) ===
|
||||
type: type # 字段类型: str, int, float, bool, list, dict
|
||||
default: Any # 默认值
|
||||
description: str # 字段描述
|
||||
example: Optional[str] = None # 示例值
|
||||
description: str # 字段描述(也用作默认标签)
|
||||
|
||||
# === 验证相关 ===
|
||||
example: Optional[str] = None # 示例值(用于生成配置文件注释)
|
||||
required: bool = False # 是否必需
|
||||
choices: Optional[List[Any]] = field(default_factory=list) # 可选值列表
|
||||
choices: Optional[List[Any]] = field(default_factory=list) # 可选值列表(用于下拉选择)
|
||||
min: Optional[float] = None # 最小值(数字类型)
|
||||
max: Optional[float] = None # 最大值(数字类型)
|
||||
step: Optional[float] = None # 步进值(数字类型)
|
||||
pattern: Optional[str] = None # 正则验证(字符串类型)
|
||||
max_length: Optional[int] = None # 最大长度(字符串类型)
|
||||
|
||||
# === UI 显示控制 ===
|
||||
label: Optional[str] = None # 显示标签(默认使用 description)
|
||||
placeholder: Optional[str] = None # 输入框占位符
|
||||
hint: Optional[str] = None # 字段下方的提示文字
|
||||
icon: Optional[str] = None # 字段图标名称
|
||||
hidden: bool = False # 是否在 UI 中隐藏
|
||||
disabled: bool = False # 是否禁用编辑
|
||||
order: int = 0 # 排序权重(数字越小越靠前)
|
||||
|
||||
# === 输入控件类型 ===
|
||||
# 可选值: text, password, textarea, number, color, code, file, json
|
||||
# 不指定时根据 type 和 choices 自动推断
|
||||
input_type: Optional[str] = None
|
||||
|
||||
# === textarea 专用 ===
|
||||
rows: int = 3 # 文本域行数
|
||||
|
||||
# === 分组与布局 ===
|
||||
group: Optional[str] = None # 字段分组(在 section 内再细分)
|
||||
|
||||
# === 条件显示 ===
|
||||
depends_on: Optional[str] = None # 依赖的字段路径,如 "section.field"
|
||||
depends_value: Any = None # 依赖字段需要的值(当依赖字段等于此值时显示)
|
||||
|
||||
def get_ui_type(self) -> str:
|
||||
"""
|
||||
获取 UI 控件类型
|
||||
|
||||
如果指定了 input_type 则直接返回,否则根据 type 和 choices 自动推断。
|
||||
|
||||
Returns:
|
||||
控件类型字符串
|
||||
"""
|
||||
if self.input_type:
|
||||
return self.input_type
|
||||
|
||||
# 根据 type 和 choices 自动推断
|
||||
if self.type is bool:
|
||||
return "switch"
|
||||
elif self.type in (int, float):
|
||||
if self.min is not None and self.max is not None:
|
||||
return "slider"
|
||||
return "number"
|
||||
elif self.type is str:
|
||||
if self.choices:
|
||||
return "select"
|
||||
return "text"
|
||||
elif self.type is list:
|
||||
return "list"
|
||||
elif self.type is dict:
|
||||
return "json"
|
||||
else:
|
||||
return "text"
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
转换为可序列化的字典(用于 API 传输)
|
||||
|
||||
Returns:
|
||||
包含所有配置信息的字典
|
||||
"""
|
||||
return {
|
||||
"type": self.type.__name__ if isinstance(self.type, type) else str(self.type),
|
||||
"default": self.default,
|
||||
"description": self.description,
|
||||
"example": self.example,
|
||||
"required": self.required,
|
||||
"choices": self.choices if self.choices else None,
|
||||
"min": self.min,
|
||||
"max": self.max,
|
||||
"step": self.step,
|
||||
"pattern": self.pattern,
|
||||
"max_length": self.max_length,
|
||||
"label": self.label or self.description,
|
||||
"placeholder": self.placeholder,
|
||||
"hint": self.hint,
|
||||
"icon": self.icon,
|
||||
"hidden": self.hidden,
|
||||
"disabled": self.disabled,
|
||||
"order": self.order,
|
||||
"input_type": self.input_type,
|
||||
"ui_type": self.get_ui_type(),
|
||||
"rows": self.rows,
|
||||
"group": self.group,
|
||||
"depends_on": self.depends_on,
|
||||
"depends_value": self.depends_value,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConfigSection:
|
||||
"""
|
||||
配置节定义
|
||||
|
||||
用于描述配置文件中一个 section 的元数据。
|
||||
|
||||
示例:
|
||||
ConfigSection(
|
||||
title="API配置",
|
||||
description="外部API连接参数",
|
||||
icon="cloud",
|
||||
order=1
|
||||
)
|
||||
"""
|
||||
|
||||
title: str # 显示标题
|
||||
description: Optional[str] = None # 详细描述
|
||||
icon: Optional[str] = None # 图标名称
|
||||
collapsed: bool = False # 默认是否折叠
|
||||
order: int = 0 # 排序权重
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为可序列化的字典"""
|
||||
return {
|
||||
"title": self.title,
|
||||
"description": self.description,
|
||||
"icon": self.icon,
|
||||
"collapsed": self.collapsed,
|
||||
"order": self.order,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConfigTab:
|
||||
"""
|
||||
配置标签页定义
|
||||
|
||||
用于将多个 section 组织到一个标签页中。
|
||||
|
||||
示例:
|
||||
ConfigTab(
|
||||
id="general",
|
||||
title="通用设置",
|
||||
icon="settings",
|
||||
sections=["plugin", "api"]
|
||||
)
|
||||
"""
|
||||
|
||||
id: str # 标签页 ID
|
||||
title: str # 显示标题
|
||||
sections: List[str] = field(default_factory=list) # 包含的 section 名称列表
|
||||
icon: Optional[str] = None # 图标名称
|
||||
order: int = 0 # 排序权重
|
||||
badge: Optional[str] = None # 角标文字(如 "Beta", "New")
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为可序列化的字典"""
|
||||
return {
|
||||
"id": self.id,
|
||||
"title": self.title,
|
||||
"sections": self.sections,
|
||||
"icon": self.icon,
|
||||
"order": self.order,
|
||||
"badge": self.badge,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConfigLayout:
|
||||
"""
|
||||
配置页面布局定义
|
||||
|
||||
用于定义插件配置页面的整体布局结构。
|
||||
|
||||
布局类型:
|
||||
- "auto": 自动布局,sections 作为折叠面板显示
|
||||
- "tabs": 标签页布局
|
||||
- "pages": 分页布局(左侧导航 + 右侧内容)
|
||||
|
||||
简单示例(标签页布局):
|
||||
ConfigLayout(
|
||||
type="tabs",
|
||||
tabs=[
|
||||
ConfigTab(id="basic", title="基础", sections=["plugin", "api"]),
|
||||
ConfigTab(id="advanced", title="高级", sections=["debug"]),
|
||||
]
|
||||
)
|
||||
"""
|
||||
|
||||
type: str = "auto" # 布局类型: auto, tabs, pages
|
||||
tabs: List[ConfigTab] = field(default_factory=list) # 标签页列表
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为可序列化的字典"""
|
||||
return {
|
||||
"type": self.type,
|
||||
"tabs": [tab.to_dict() for tab in self.tabs],
|
||||
}
|
||||
|
||||
|
||||
def section_meta(
|
||||
title: str, description: Optional[str] = None, icon: Optional[str] = None, collapsed: bool = False, order: int = 0
|
||||
) -> Union[str, ConfigSection]:
|
||||
"""
|
||||
便捷函数:创建 section 元数据
|
||||
|
||||
可以在 config_section_descriptions 中使用,提供比纯字符串更丰富的信息。
|
||||
|
||||
Args:
|
||||
title: 显示标题
|
||||
description: 详细描述
|
||||
icon: 图标名称
|
||||
collapsed: 默认是否折叠
|
||||
order: 排序权重
|
||||
|
||||
Returns:
|
||||
ConfigSection 实例
|
||||
|
||||
示例:
|
||||
config_section_descriptions = {
|
||||
"api": section_meta("API配置", icon="cloud", order=1),
|
||||
"debug": section_meta("调试设置", collapsed=True, order=99),
|
||||
}
|
||||
"""
|
||||
return ConfigSection(title=title, description=description, icon=icon, collapsed=collapsed, order=order)
|
||||
|
||||
@@ -12,7 +12,11 @@ from src.plugin_system.base.component_types import (
|
||||
PluginInfo,
|
||||
PythonDependency,
|
||||
)
|
||||
from src.plugin_system.base.config_types import ConfigField
|
||||
from src.plugin_system.base.config_types import (
|
||||
ConfigField,
|
||||
ConfigSection,
|
||||
ConfigLayout,
|
||||
)
|
||||
from src.plugin_system.utils.manifest_utils import ManifestValidator
|
||||
|
||||
logger = get_logger("plugin_base")
|
||||
@@ -60,7 +64,10 @@ class PluginBase(ABC):
|
||||
def config_schema(self) -> Dict[str, Union[Dict[str, ConfigField], str]]:
|
||||
return {}
|
||||
|
||||
config_section_descriptions: Dict[str, str] = {}
|
||||
config_section_descriptions: Dict[str, Union[str, ConfigSection]] = {}
|
||||
|
||||
# 布局配置(可选,不定义则使用自动布局)
|
||||
config_layout: ConfigLayout = None
|
||||
|
||||
def __init__(self, plugin_dir: str):
|
||||
"""初始化插件
|
||||
@@ -205,6 +212,22 @@ class PluginBase(ABC):
|
||||
|
||||
return value
|
||||
|
||||
def _format_toml_value(self, value: Any) -> str:
|
||||
"""将Python值格式化为合法的TOML字符串"""
|
||||
if isinstance(value, str):
|
||||
return json.dumps(value, ensure_ascii=False)
|
||||
if isinstance(value, bool):
|
||||
return str(value).lower()
|
||||
if isinstance(value, (int, float)):
|
||||
return str(value)
|
||||
if isinstance(value, list):
|
||||
inner = ", ".join(self._format_toml_value(item) for item in value)
|
||||
return f"[{inner}]"
|
||||
if isinstance(value, dict):
|
||||
items = [f"{k} = {self._format_toml_value(v)}" for k, v in value.items()]
|
||||
return "{ " + ", ".join(items) + " }"
|
||||
return json.dumps(value, ensure_ascii=False)
|
||||
|
||||
def _generate_and_save_default_config(self, config_file_path: str):
|
||||
"""根据插件的Schema生成并保存默认配置文件"""
|
||||
if not self.config_schema:
|
||||
@@ -244,12 +267,7 @@ class PluginBase(ABC):
|
||||
|
||||
# 添加字段值
|
||||
value = field.default
|
||||
if isinstance(value, str):
|
||||
toml_str += f'{field_name} = "{value}"\n'
|
||||
elif isinstance(value, bool):
|
||||
toml_str += f"{field_name} = {str(value).lower()}\n"
|
||||
else:
|
||||
toml_str += f"{field_name} = {value}\n"
|
||||
toml_str += f"{field_name} = {self._format_toml_value(value)}\n"
|
||||
|
||||
toml_str += "\n"
|
||||
toml_str += "\n"
|
||||
@@ -422,19 +440,7 @@ class PluginBase(ABC):
|
||||
|
||||
# 添加字段值(使用迁移后的值)
|
||||
value = section_data.get(field_name, field.default)
|
||||
if isinstance(value, str):
|
||||
toml_str += f'{field_name} = "{value}"\n'
|
||||
elif isinstance(value, bool):
|
||||
toml_str += f"{field_name} = {str(value).lower()}\n"
|
||||
elif isinstance(value, list):
|
||||
# 格式化列表
|
||||
if all(isinstance(item, str) for item in value):
|
||||
formatted_list = "[" + ", ".join(f'"{item}"' for item in value) + "]"
|
||||
else:
|
||||
formatted_list = str(value)
|
||||
toml_str += f"{field_name} = {formatted_list}\n"
|
||||
else:
|
||||
toml_str += f"{field_name} = {value}\n"
|
||||
toml_str += f"{field_name} = {self._format_toml_value(value)}\n"
|
||||
|
||||
toml_str += "\n"
|
||||
toml_str += "\n"
|
||||
@@ -564,6 +570,93 @@ class PluginBase(ABC):
|
||||
|
||||
return current
|
||||
|
||||
def get_webui_config_schema(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取 WebUI 配置 Schema
|
||||
|
||||
返回完整的配置 schema,包含:
|
||||
- 插件基本信息
|
||||
- 所有 section 及其字段定义
|
||||
- 布局配置
|
||||
|
||||
用于 WebUI 动态生成配置表单。
|
||||
|
||||
Returns:
|
||||
Dict: 完整的配置 schema
|
||||
"""
|
||||
schema = {
|
||||
"plugin_id": self.plugin_name,
|
||||
"plugin_info": {
|
||||
"name": self.display_name,
|
||||
"version": self.plugin_version,
|
||||
"description": self.plugin_description,
|
||||
"author": self.plugin_author,
|
||||
},
|
||||
"sections": {},
|
||||
"layout": None,
|
||||
}
|
||||
|
||||
# 处理 sections
|
||||
for section_name, fields in self.config_schema.items():
|
||||
if not isinstance(fields, dict):
|
||||
continue
|
||||
|
||||
section_data = {
|
||||
"name": section_name,
|
||||
"title": section_name,
|
||||
"description": None,
|
||||
"icon": None,
|
||||
"collapsed": False,
|
||||
"order": 0,
|
||||
"fields": {},
|
||||
}
|
||||
|
||||
# 获取 section 元数据
|
||||
section_meta = self.config_section_descriptions.get(section_name)
|
||||
if section_meta:
|
||||
if isinstance(section_meta, str):
|
||||
section_data["title"] = section_meta
|
||||
elif isinstance(section_meta, ConfigSection):
|
||||
section_data["title"] = section_meta.title
|
||||
section_data["description"] = section_meta.description
|
||||
section_data["icon"] = section_meta.icon
|
||||
section_data["collapsed"] = section_meta.collapsed
|
||||
section_data["order"] = section_meta.order
|
||||
elif isinstance(section_meta, dict):
|
||||
section_data.update(section_meta)
|
||||
|
||||
# 处理字段
|
||||
for field_name, field_def in fields.items():
|
||||
if isinstance(field_def, ConfigField):
|
||||
field_data = field_def.to_dict()
|
||||
field_data["name"] = field_name
|
||||
section_data["fields"][field_name] = field_data
|
||||
|
||||
schema["sections"][section_name] = section_data
|
||||
|
||||
# 处理布局
|
||||
if self.config_layout:
|
||||
schema["layout"] = self.config_layout.to_dict()
|
||||
else:
|
||||
# 自动布局:按 section order 排序
|
||||
schema["layout"] = {
|
||||
"type": "auto",
|
||||
"tabs": [],
|
||||
}
|
||||
|
||||
return schema
|
||||
|
||||
def get_current_config_values(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取当前配置值
|
||||
|
||||
返回插件当前的配置值(已从配置文件加载)。
|
||||
|
||||
Returns:
|
||||
Dict: 当前配置值
|
||||
"""
|
||||
return self.config.copy()
|
||||
|
||||
@abstractmethod
|
||||
def register_plugin(self) -> bool:
|
||||
"""
|
||||
|
||||
@@ -95,7 +95,7 @@ class ToolExecutor:
|
||||
|
||||
# 如果没有可用工具,直接返回空内容
|
||||
if not tools:
|
||||
logger.info(f"{self.log_prefix}没有可用工具,直接返回空内容")
|
||||
logger.debug(f"{self.log_prefix}没有可用工具,直接返回空内容")
|
||||
if return_details:
|
||||
return [], [], ""
|
||||
else:
|
||||
|
||||
@@ -15,7 +15,7 @@ class SearchKnowledgeFromLPMMTool(BaseTool):
|
||||
description = "从知识库中搜索相关信息,如果你需要知识,就使用这个工具"
|
||||
parameters = [
|
||||
("query", ToolParamType.STRING, "搜索查询关键词", True, None),
|
||||
("limit", ToolParamType.INTEGER, "希望返回的相关知识条数,默认5", False, 5),
|
||||
("limit", ToolParamType.INTEGER, "希望返回的相关知识条数,默认5", False, None),
|
||||
]
|
||||
available_for_llm = global_config.lpmm_knowledge.enable
|
||||
|
||||
|
||||
127
src/webui/auth.py
Normal file
127
src/webui/auth.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""
|
||||
WebUI 认证模块
|
||||
提供统一的认证依赖,支持 Cookie 和 Header 两种方式
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from fastapi import HTTPException, Cookie, Header, Response, Request
|
||||
from src.common.logger import get_logger
|
||||
from .token_manager import get_token_manager
|
||||
|
||||
logger = get_logger("webui.auth")
|
||||
|
||||
# Cookie 配置
|
||||
COOKIE_NAME = "maibot_session"
|
||||
COOKIE_MAX_AGE = 7 * 24 * 60 * 60 # 7天
|
||||
|
||||
|
||||
def get_current_token(
|
||||
request: Request,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> str:
|
||||
"""
|
||||
获取当前请求的 token,优先从 Cookie 获取,其次从 Header 获取
|
||||
|
||||
Args:
|
||||
request: FastAPI Request 对象
|
||||
maibot_session: Cookie 中的 token
|
||||
authorization: Authorization Header (Bearer token)
|
||||
|
||||
Returns:
|
||||
验证通过的 token
|
||||
|
||||
Raises:
|
||||
HTTPException: 认证失败时抛出 401 错误
|
||||
"""
|
||||
token = None
|
||||
|
||||
# 优先从 Cookie 获取
|
||||
if maibot_session:
|
||||
token = maibot_session
|
||||
# 其次从 Header 获取(兼容旧版本)
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
token = authorization.replace("Bearer ", "")
|
||||
|
||||
if not token:
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
# 验证 token
|
||||
token_manager = get_token_manager()
|
||||
if not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="Token 无效或已过期")
|
||||
|
||||
return token
|
||||
|
||||
|
||||
def set_auth_cookie(response: Response, token: str) -> None:
|
||||
"""
|
||||
设置认证 Cookie
|
||||
|
||||
Args:
|
||||
response: FastAPI Response 对象
|
||||
token: 要设置的 token
|
||||
"""
|
||||
response.set_cookie(
|
||||
key=COOKIE_NAME,
|
||||
value=token,
|
||||
max_age=COOKIE_MAX_AGE,
|
||||
httponly=True, # 防止 JS 读取
|
||||
samesite="lax", # 允许同站导航时发送 Cookie(兼容开发环境代理)
|
||||
secure=False, # 本地开发不强制 HTTPS,生产环境建议设为 True
|
||||
path="/", # 确保 Cookie 在所有路径下可用
|
||||
)
|
||||
logger.debug(f"已设置认证 Cookie: {token[:8]}...")
|
||||
|
||||
|
||||
def clear_auth_cookie(response: Response) -> None:
|
||||
"""
|
||||
清除认证 Cookie
|
||||
|
||||
Args:
|
||||
response: FastAPI Response 对象
|
||||
"""
|
||||
response.delete_cookie(
|
||||
key=COOKIE_NAME,
|
||||
httponly=True,
|
||||
samesite="lax",
|
||||
path="/",
|
||||
)
|
||||
logger.debug("已清除认证 Cookie")
|
||||
|
||||
|
||||
def verify_auth_token_from_cookie_or_header(
|
||||
maibot_session: Optional[str] = None,
|
||||
authorization: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
验证认证 Token,支持从 Cookie 或 Header 获取
|
||||
|
||||
Args:
|
||||
maibot_session: Cookie 中的 token
|
||||
authorization: Authorization header (Bearer token)
|
||||
|
||||
Returns:
|
||||
验证成功返回 True
|
||||
|
||||
Raises:
|
||||
HTTPException: 认证失败时抛出 401 错误
|
||||
"""
|
||||
token = None
|
||||
|
||||
# 优先从 Cookie 获取
|
||||
if maibot_session:
|
||||
token = maibot_session
|
||||
# 其次从 Header 获取(兼容旧版本)
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
token = authorization.replace("Bearer ", "")
|
||||
|
||||
if not token:
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
# 验证 token
|
||||
token_manager = get_token_manager()
|
||||
if not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="Token 无效或已过期")
|
||||
|
||||
return True
|
||||
729
src/webui/chat_routes.py
Normal file
729
src/webui/chat_routes.py
Normal file
@@ -0,0 +1,729 @@
|
||||
"""本地聊天室路由 - WebUI 与麦麦直接对话
|
||||
|
||||
支持两种模式:
|
||||
1. WebUI 模式:使用 WebUI 平台独立身份聊天
|
||||
2. 虚拟身份模式:使用真实平台用户的身份,在虚拟群聊中与麦麦对话
|
||||
"""
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from typing import Dict, Any, Optional, List
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
|
||||
from pydantic import BaseModel
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Messages, PersonInfo
|
||||
from src.config.config import global_config
|
||||
from src.chat.message_receive.bot import chat_bot
|
||||
|
||||
logger = get_logger("webui.chat")
|
||||
|
||||
router = APIRouter(prefix="/api/chat", tags=["LocalChat"])
|
||||
|
||||
# WebUI 聊天的虚拟群组 ID
|
||||
WEBUI_CHAT_GROUP_ID = "webui_local_chat"
|
||||
WEBUI_CHAT_PLATFORM = "webui"
|
||||
|
||||
# 虚拟身份模式的群 ID 前缀
|
||||
VIRTUAL_GROUP_ID_PREFIX = "webui_virtual_group_"
|
||||
|
||||
# 固定的 WebUI 用户 ID 前缀
|
||||
WEBUI_USER_ID_PREFIX = "webui_user_"
|
||||
|
||||
|
||||
class VirtualIdentityConfig(BaseModel):
|
||||
"""虚拟身份配置"""
|
||||
|
||||
enabled: bool = False # 是否启用虚拟身份模式
|
||||
platform: Optional[str] = None # 目标平台(如 qq, discord 等)
|
||||
person_id: Optional[str] = None # PersonInfo 的 person_id
|
||||
user_id: Optional[str] = None # 原始平台用户 ID
|
||||
user_nickname: Optional[str] = None # 用户昵称
|
||||
group_id: Optional[str] = None # 虚拟群 ID(自动生成或用户指定)
|
||||
group_name: Optional[str] = None # 虚拟群名(用户自定义)
|
||||
|
||||
|
||||
class ChatHistoryMessage(BaseModel):
|
||||
"""聊天历史消息"""
|
||||
|
||||
id: str
|
||||
type: str # 'user' | 'bot' | 'system'
|
||||
content: str
|
||||
timestamp: float
|
||||
sender_name: str
|
||||
sender_id: Optional[str] = None
|
||||
is_bot: bool = False
|
||||
|
||||
|
||||
class ChatHistoryManager:
|
||||
"""聊天历史管理器 - 使用 SQLite 数据库存储"""
|
||||
|
||||
def __init__(self, max_messages: int = 200):
|
||||
self.max_messages = max_messages
|
||||
|
||||
def _message_to_dict(self, msg: Messages, group_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""将数据库消息转换为前端格式
|
||||
|
||||
Args:
|
||||
msg: 数据库消息对象
|
||||
group_id: 群 ID,用于判断是否是虚拟群
|
||||
"""
|
||||
# 判断是否是机器人消息
|
||||
user_id = msg.user_id or ""
|
||||
|
||||
# 对于虚拟群,通过比较机器人 QQ 账号来判断
|
||||
# 对于普通 WebUI 群,检查 user_id 是否以 webui_ 开头
|
||||
if group_id and group_id.startswith(VIRTUAL_GROUP_ID_PREFIX):
|
||||
# 虚拟群:user_id 等于机器人 QQ 账号的是机器人消息
|
||||
bot_qq = str(global_config.bot.qq_account)
|
||||
is_bot = user_id == bot_qq
|
||||
else:
|
||||
# 普通 WebUI 群:不以 webui_ 开头的是机器人消息
|
||||
is_bot = not user_id.startswith("webui_") and not user_id.startswith(WEBUI_USER_ID_PREFIX)
|
||||
|
||||
return {
|
||||
"id": msg.message_id,
|
||||
"type": "bot" if is_bot else "user",
|
||||
"content": msg.processed_plain_text or msg.display_message or "",
|
||||
"timestamp": msg.time,
|
||||
"sender_name": msg.user_nickname or (global_config.bot.nickname if is_bot else "未知用户"),
|
||||
"sender_id": "bot" if is_bot else user_id,
|
||||
"is_bot": is_bot,
|
||||
}
|
||||
|
||||
def get_history(self, limit: int = 50, group_id: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
"""从数据库获取最近的历史记录
|
||||
|
||||
Args:
|
||||
limit: 获取的消息数量
|
||||
group_id: 群 ID,默认为 WEBUI_CHAT_GROUP_ID
|
||||
"""
|
||||
target_group_id = group_id if group_id else WEBUI_CHAT_GROUP_ID
|
||||
try:
|
||||
# 查询指定群的消息,按时间排序
|
||||
messages = (
|
||||
Messages.select()
|
||||
.where(Messages.chat_info_group_id == target_group_id)
|
||||
.order_by(Messages.time.desc())
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
# 转换为列表并反转(使最旧的消息在前)
|
||||
# 传递 group_id 以便正确判断虚拟群中的机器人消息
|
||||
result = [self._message_to_dict(msg, target_group_id) for msg in messages]
|
||||
result.reverse()
|
||||
|
||||
logger.debug(f"从数据库加载了 {len(result)} 条聊天记录 (group_id={target_group_id})")
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"从数据库加载聊天记录失败: {e}")
|
||||
return []
|
||||
|
||||
def clear_history(self, group_id: Optional[str] = None) -> int:
|
||||
"""清空聊天历史记录
|
||||
|
||||
Args:
|
||||
group_id: 群 ID,默认清空 WebUI 默认聊天室
|
||||
"""
|
||||
target_group_id = group_id if group_id else WEBUI_CHAT_GROUP_ID
|
||||
try:
|
||||
deleted = Messages.delete().where(Messages.chat_info_group_id == target_group_id).execute()
|
||||
logger.info(f"已清空 {deleted} 条聊天记录 (group_id={target_group_id})")
|
||||
return deleted
|
||||
except Exception as e:
|
||||
logger.error(f"清空聊天记录失败: {e}")
|
||||
return 0
|
||||
|
||||
|
||||
# 全局聊天历史管理器
|
||||
chat_history = ChatHistoryManager()
|
||||
|
||||
|
||||
# 存储 WebSocket 连接
|
||||
class ChatConnectionManager:
|
||||
"""聊天连接管理器"""
|
||||
|
||||
def __init__(self):
|
||||
self.active_connections: Dict[str, WebSocket] = {}
|
||||
self.user_sessions: Dict[str, str] = {} # user_id -> session_id 映射
|
||||
|
||||
async def connect(self, websocket: WebSocket, session_id: str, user_id: str):
|
||||
await websocket.accept()
|
||||
self.active_connections[session_id] = websocket
|
||||
self.user_sessions[user_id] = session_id
|
||||
logger.info(f"WebUI 聊天会话已连接: session={session_id}, user={user_id}")
|
||||
|
||||
def disconnect(self, session_id: str, user_id: str):
|
||||
if session_id in self.active_connections:
|
||||
del self.active_connections[session_id]
|
||||
if user_id in self.user_sessions and self.user_sessions[user_id] == session_id:
|
||||
del self.user_sessions[user_id]
|
||||
logger.info(f"WebUI 聊天会话已断开: session={session_id}")
|
||||
|
||||
async def send_message(self, session_id: str, message: dict):
|
||||
if session_id in self.active_connections:
|
||||
try:
|
||||
await self.active_connections[session_id].send_json(message)
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e}")
|
||||
|
||||
async def broadcast(self, message: dict):
|
||||
"""广播消息给所有连接"""
|
||||
for session_id in list(self.active_connections.keys()):
|
||||
await self.send_message(session_id, message)
|
||||
|
||||
|
||||
chat_manager = ChatConnectionManager()
|
||||
|
||||
|
||||
def create_message_data(
|
||||
content: str,
|
||||
user_id: str,
|
||||
user_name: str,
|
||||
message_id: Optional[str] = None,
|
||||
is_at_bot: bool = True,
|
||||
virtual_config: Optional[VirtualIdentityConfig] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""创建符合麦麦消息格式的消息数据
|
||||
|
||||
Args:
|
||||
content: 消息内容
|
||||
user_id: 用户 ID
|
||||
user_name: 用户昵称
|
||||
message_id: 消息 ID(可选,自动生成)
|
||||
is_at_bot: 是否 @ 机器人
|
||||
virtual_config: 虚拟身份配置(可选,启用后使用真实平台身份)
|
||||
"""
|
||||
if message_id is None:
|
||||
message_id = str(uuid.uuid4())
|
||||
|
||||
# 确定使用的平台、群信息和用户信息
|
||||
if virtual_config and virtual_config.enabled:
|
||||
# 虚拟身份模式:使用真实平台身份
|
||||
platform = virtual_config.platform or WEBUI_CHAT_PLATFORM
|
||||
group_id = virtual_config.group_id or f"{VIRTUAL_GROUP_ID_PREFIX}{uuid.uuid4().hex[:8]}"
|
||||
group_name = virtual_config.group_name or "WebUI虚拟群聊"
|
||||
actual_user_id = virtual_config.user_id or user_id
|
||||
actual_user_name = virtual_config.user_nickname or user_name
|
||||
else:
|
||||
# 标准 WebUI 模式
|
||||
platform = WEBUI_CHAT_PLATFORM
|
||||
group_id = WEBUI_CHAT_GROUP_ID
|
||||
group_name = "WebUI本地聊天室"
|
||||
actual_user_id = user_id
|
||||
actual_user_name = user_name
|
||||
|
||||
return {
|
||||
"message_info": {
|
||||
"platform": platform,
|
||||
"message_id": message_id,
|
||||
"time": time.time(),
|
||||
"group_info": {
|
||||
"group_id": group_id,
|
||||
"group_name": group_name,
|
||||
"platform": platform,
|
||||
},
|
||||
"user_info": {
|
||||
"user_id": actual_user_id,
|
||||
"user_nickname": actual_user_name,
|
||||
"user_cardname": actual_user_name,
|
||||
"platform": platform,
|
||||
},
|
||||
"additional_config": {
|
||||
"at_bot": is_at_bot,
|
||||
},
|
||||
},
|
||||
"message_segment": {
|
||||
"type": "seglist",
|
||||
"data": [
|
||||
{
|
||||
"type": "text",
|
||||
"data": content,
|
||||
},
|
||||
{
|
||||
"type": "mention_bot",
|
||||
"data": "1.0",
|
||||
},
|
||||
],
|
||||
},
|
||||
"raw_message": content,
|
||||
"processed_plain_text": content,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/history")
|
||||
async def get_chat_history(
|
||||
limit: int = Query(default=50, ge=1, le=200),
|
||||
user_id: Optional[str] = Query(default=None), # 保留参数兼容性,但不用于过滤
|
||||
group_id: Optional[str] = Query(default=None), # 可选:指定群 ID 获取历史
|
||||
):
|
||||
"""获取聊天历史记录
|
||||
|
||||
所有 WebUI 用户共享同一个聊天室,因此返回所有历史记录
|
||||
如果指定了 group_id,则获取该虚拟群的历史记录
|
||||
"""
|
||||
target_group_id = group_id if group_id else WEBUI_CHAT_GROUP_ID
|
||||
history = chat_history.get_history(limit, target_group_id)
|
||||
return {
|
||||
"success": True,
|
||||
"messages": history,
|
||||
"total": len(history),
|
||||
}
|
||||
|
||||
|
||||
@router.get("/platforms")
|
||||
async def get_available_platforms():
|
||||
"""获取可用平台列表
|
||||
|
||||
从 PersonInfo 表中获取所有已知的平台
|
||||
"""
|
||||
try:
|
||||
from peewee import fn
|
||||
|
||||
# 查询所有不同的平台
|
||||
platforms = (
|
||||
PersonInfo.select(PersonInfo.platform, fn.COUNT(PersonInfo.id).alias("count"))
|
||||
.group_by(PersonInfo.platform)
|
||||
.order_by(fn.COUNT(PersonInfo.id).desc())
|
||||
)
|
||||
|
||||
result = []
|
||||
for p in platforms:
|
||||
if p.platform: # 排除空平台
|
||||
result.append({"platform": p.platform, "count": p.count})
|
||||
|
||||
return {"success": True, "platforms": result}
|
||||
except Exception as e:
|
||||
logger.error(f"获取平台列表失败: {e}")
|
||||
return {"success": False, "error": str(e), "platforms": []}
|
||||
|
||||
|
||||
@router.get("/persons")
|
||||
async def get_persons_by_platform(
|
||||
platform: str = Query(..., description="平台名称"),
|
||||
search: Optional[str] = Query(default=None, description="搜索关键词"),
|
||||
limit: int = Query(default=50, ge=1, le=200),
|
||||
):
|
||||
"""获取指定平台的用户列表
|
||||
|
||||
Args:
|
||||
platform: 平台名称(如 qq, discord 等)
|
||||
search: 搜索关键词(匹配昵称、用户名、user_id)
|
||||
limit: 返回数量限制
|
||||
"""
|
||||
try:
|
||||
# 构建查询
|
||||
query = PersonInfo.select().where(PersonInfo.platform == platform)
|
||||
|
||||
# 搜索过滤
|
||||
if search:
|
||||
query = query.where(
|
||||
(PersonInfo.person_name.contains(search))
|
||||
| (PersonInfo.nickname.contains(search))
|
||||
| (PersonInfo.user_id.contains(search))
|
||||
)
|
||||
|
||||
# 按最后交互时间排序,优先显示活跃用户
|
||||
from peewee import Case
|
||||
|
||||
query = query.order_by(Case(None, [(PersonInfo.last_know.is_null(), 1)], 0), PersonInfo.last_know.desc())
|
||||
query = query.limit(limit)
|
||||
|
||||
result = []
|
||||
for person in query:
|
||||
result.append(
|
||||
{
|
||||
"person_id": person.person_id,
|
||||
"user_id": person.user_id,
|
||||
"person_name": person.person_name,
|
||||
"nickname": person.nickname,
|
||||
"is_known": person.is_known,
|
||||
"platform": person.platform,
|
||||
"display_name": person.person_name or person.nickname or person.user_id,
|
||||
}
|
||||
)
|
||||
|
||||
return {"success": True, "persons": result, "total": len(result)}
|
||||
except Exception as e:
|
||||
logger.error(f"获取用户列表失败: {e}")
|
||||
return {"success": False, "error": str(e), "persons": []}
|
||||
|
||||
|
||||
@router.delete("/history")
|
||||
async def clear_chat_history(group_id: Optional[str] = Query(default=None)):
|
||||
"""清空聊天历史记录
|
||||
|
||||
Args:
|
||||
group_id: 可选,指定要清空的群 ID,默认清空 WebUI 默认聊天室
|
||||
"""
|
||||
deleted = chat_history.clear_history(group_id)
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"已清空 {deleted} 条聊天记录",
|
||||
}
|
||||
|
||||
|
||||
@router.websocket("/ws")
|
||||
async def websocket_chat(
|
||||
websocket: WebSocket,
|
||||
user_id: Optional[str] = Query(default=None),
|
||||
user_name: Optional[str] = Query(default="WebUI用户"),
|
||||
platform: Optional[str] = Query(default=None),
|
||||
person_id: Optional[str] = Query(default=None),
|
||||
group_name: Optional[str] = Query(default=None),
|
||||
group_id: Optional[str] = Query(default=None), # 前端传递的稳定 group_id
|
||||
):
|
||||
"""WebSocket 聊天端点
|
||||
|
||||
Args:
|
||||
user_id: 用户唯一标识(由前端生成并持久化)
|
||||
user_name: 用户显示昵称(可修改)
|
||||
platform: 虚拟身份模式的平台(可选)
|
||||
person_id: 虚拟身份模式的用户 person_id(可选)
|
||||
group_name: 虚拟身份模式的群名(可选)
|
||||
group_id: 虚拟身份模式的群 ID(可选,由前端生成并持久化)
|
||||
|
||||
虚拟身份模式可通过 URL 参数直接配置,或通过消息中的 set_virtual_identity 配置
|
||||
"""
|
||||
# 生成会话 ID(每次连接都是新的)
|
||||
session_id = str(uuid.uuid4())
|
||||
|
||||
# 如果没有提供 user_id,生成一个新的
|
||||
if not user_id:
|
||||
user_id = f"{WEBUI_USER_ID_PREFIX}{uuid.uuid4().hex[:16]}"
|
||||
elif not user_id.startswith(WEBUI_USER_ID_PREFIX):
|
||||
# 确保 user_id 有正确的前缀
|
||||
user_id = f"{WEBUI_USER_ID_PREFIX}{user_id}"
|
||||
|
||||
# 当前会话的虚拟身份配置(可通过消息动态更新)
|
||||
current_virtual_config: Optional[VirtualIdentityConfig] = None
|
||||
|
||||
# 如果 URL 参数中提供了虚拟身份信息,自动配置
|
||||
if platform and person_id:
|
||||
try:
|
||||
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
|
||||
if person:
|
||||
# 使用前端传递的 group_id,如果没有则生成一个稳定的
|
||||
virtual_group_id = group_id or f"{VIRTUAL_GROUP_ID_PREFIX}{platform}_{person.user_id}"
|
||||
current_virtual_config = VirtualIdentityConfig(
|
||||
enabled=True,
|
||||
platform=person.platform,
|
||||
person_id=person.person_id,
|
||||
user_id=person.user_id,
|
||||
user_nickname=person.person_name or person.nickname or person.user_id,
|
||||
group_id=virtual_group_id,
|
||||
group_name=group_name or "WebUI虚拟群聊",
|
||||
)
|
||||
logger.info(f"虚拟身份模式已通过 URL 参数激活: {current_virtual_config.user_nickname} @ {current_virtual_config.platform}, group_id={virtual_group_id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"通过 URL 参数配置虚拟身份失败: {e}")
|
||||
|
||||
await chat_manager.connect(websocket, session_id, user_id)
|
||||
|
||||
try:
|
||||
# 构建会话信息
|
||||
session_info_data = {
|
||||
"type": "session_info",
|
||||
"session_id": session_id,
|
||||
"user_id": user_id,
|
||||
"user_name": user_name,
|
||||
"bot_name": global_config.bot.nickname,
|
||||
}
|
||||
|
||||
# 如果有虚拟身份配置,添加到会话信息中
|
||||
if current_virtual_config and current_virtual_config.enabled:
|
||||
session_info_data["virtual_mode"] = True
|
||||
session_info_data["group_id"] = current_virtual_config.group_id
|
||||
session_info_data["virtual_identity"] = {
|
||||
"platform": current_virtual_config.platform,
|
||||
"user_id": current_virtual_config.user_id,
|
||||
"user_nickname": current_virtual_config.user_nickname,
|
||||
"group_name": current_virtual_config.group_name,
|
||||
}
|
||||
|
||||
# 发送会话信息(包含用户 ID,前端需要保存)
|
||||
await chat_manager.send_message(session_id, session_info_data)
|
||||
|
||||
# 发送历史记录(根据模式选择不同的群)
|
||||
if current_virtual_config and current_virtual_config.enabled:
|
||||
history = chat_history.get_history(50, current_virtual_config.group_id)
|
||||
else:
|
||||
history = chat_history.get_history(50)
|
||||
if history:
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "history",
|
||||
"messages": history,
|
||||
},
|
||||
)
|
||||
|
||||
# 发送欢迎消息(不保存到历史)
|
||||
if current_virtual_config and current_virtual_config.enabled:
|
||||
welcome_msg = f"已以 {current_virtual_config.user_nickname} 的身份连接到「{current_virtual_config.group_name}」,开始与 {global_config.bot.nickname} 对话吧!"
|
||||
else:
|
||||
welcome_msg = f"已连接到本地聊天室,可以开始与 {global_config.bot.nickname} 对话了!"
|
||||
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "system",
|
||||
"content": welcome_msg,
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
|
||||
while True:
|
||||
data = await websocket.receive_json()
|
||||
|
||||
if data.get("type") == "message":
|
||||
content = data.get("content", "").strip()
|
||||
if not content:
|
||||
continue
|
||||
|
||||
# 用户可以更新昵称
|
||||
current_user_name = data.get("user_name", user_name)
|
||||
|
||||
message_id = str(uuid.uuid4())
|
||||
timestamp = time.time()
|
||||
|
||||
# 确定发送者信息(根据是否使用虚拟身份)
|
||||
if current_virtual_config and current_virtual_config.enabled:
|
||||
sender_name = current_virtual_config.user_nickname or current_user_name
|
||||
sender_user_id = current_virtual_config.user_id or user_id
|
||||
else:
|
||||
sender_name = current_user_name
|
||||
sender_user_id = user_id
|
||||
|
||||
# 广播用户消息给所有连接(包括发送者)
|
||||
# 注意:用户消息会在 chat_bot.message_process 中自动保存到数据库
|
||||
await chat_manager.broadcast(
|
||||
{
|
||||
"type": "user_message",
|
||||
"content": content,
|
||||
"message_id": message_id,
|
||||
"timestamp": timestamp,
|
||||
"sender": {
|
||||
"name": sender_name,
|
||||
"user_id": sender_user_id,
|
||||
"is_bot": False,
|
||||
},
|
||||
"virtual_mode": current_virtual_config.enabled if current_virtual_config else False,
|
||||
}
|
||||
)
|
||||
|
||||
# 创建麦麦消息格式
|
||||
message_data = create_message_data(
|
||||
content=content,
|
||||
user_id=user_id,
|
||||
user_name=current_user_name,
|
||||
message_id=message_id,
|
||||
is_at_bot=True,
|
||||
virtual_config=current_virtual_config,
|
||||
)
|
||||
|
||||
try:
|
||||
# 显示正在输入状态
|
||||
await chat_manager.broadcast(
|
||||
{
|
||||
"type": "typing",
|
||||
"is_typing": True,
|
||||
}
|
||||
)
|
||||
|
||||
# 调用麦麦的消息处理
|
||||
await chat_bot.message_process(message_data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息时出错: {e}")
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "error",
|
||||
"content": f"处理消息时出错: {str(e)}",
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
finally:
|
||||
await chat_manager.broadcast(
|
||||
{
|
||||
"type": "typing",
|
||||
"is_typing": False,
|
||||
}
|
||||
)
|
||||
|
||||
elif data.get("type") == "ping":
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "pong",
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
|
||||
elif data.get("type") == "update_nickname":
|
||||
# 允许用户更新昵称
|
||||
if new_name := data.get("user_name", "").strip():
|
||||
current_user_name = new_name
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "nickname_updated",
|
||||
"user_name": current_user_name,
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
|
||||
elif data.get("type") == "set_virtual_identity":
|
||||
# 设置或更新虚拟身份配置
|
||||
virtual_data = data.get("config", {})
|
||||
if virtual_data.get("enabled"):
|
||||
# 验证必要字段
|
||||
if not virtual_data.get("platform") or not virtual_data.get("person_id"):
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "error",
|
||||
"content": "虚拟身份配置缺少必要字段: platform 和 person_id",
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
continue
|
||||
|
||||
# 获取用户信息
|
||||
try:
|
||||
person = PersonInfo.get_or_none(PersonInfo.person_id == virtual_data.get("person_id"))
|
||||
if not person:
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "error",
|
||||
"content": f"找不到用户: {virtual_data.get('person_id')}",
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
continue
|
||||
|
||||
# 生成虚拟群 ID
|
||||
custom_group_id = virtual_data.get("group_id")
|
||||
if custom_group_id:
|
||||
group_id = f"{VIRTUAL_GROUP_ID_PREFIX}{custom_group_id}"
|
||||
else:
|
||||
group_id = f"{VIRTUAL_GROUP_ID_PREFIX}{session_id[:8]}"
|
||||
|
||||
current_virtual_config = VirtualIdentityConfig(
|
||||
enabled=True,
|
||||
platform=person.platform,
|
||||
person_id=person.person_id,
|
||||
user_id=person.user_id,
|
||||
user_nickname=person.person_name or person.nickname or person.user_id,
|
||||
group_id=group_id,
|
||||
group_name=virtual_data.get("group_name", "WebUI虚拟群聊"),
|
||||
)
|
||||
|
||||
# 发送虚拟身份已激活的消息
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "virtual_identity_set",
|
||||
"config": {
|
||||
"enabled": True,
|
||||
"platform": current_virtual_config.platform,
|
||||
"user_id": current_virtual_config.user_id,
|
||||
"user_nickname": current_virtual_config.user_nickname,
|
||||
"group_id": current_virtual_config.group_id,
|
||||
"group_name": current_virtual_config.group_name,
|
||||
},
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
|
||||
# 加载虚拟群的历史记录
|
||||
virtual_history = chat_history.get_history(50, current_virtual_config.group_id)
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "history",
|
||||
"messages": virtual_history,
|
||||
"group_id": current_virtual_config.group_id,
|
||||
},
|
||||
)
|
||||
|
||||
# 发送系统消息
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "system",
|
||||
"content": f"已切换到虚拟身份模式:以 {current_virtual_config.user_nickname} 的身份在「{current_virtual_config.group_name}」与 {global_config.bot.nickname} 对话",
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"设置虚拟身份失败: {e}")
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "error",
|
||||
"content": f"设置虚拟身份失败: {str(e)}",
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
else:
|
||||
# 禁用虚拟身份模式
|
||||
current_virtual_config = None
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "virtual_identity_set",
|
||||
"config": {"enabled": False},
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
|
||||
# 重新加载默认聊天室历史
|
||||
default_history = chat_history.get_history(50, WEBUI_CHAT_GROUP_ID)
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "history",
|
||||
"messages": default_history,
|
||||
"group_id": WEBUI_CHAT_GROUP_ID,
|
||||
},
|
||||
)
|
||||
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "system",
|
||||
"content": "已切换回 WebUI 独立用户模式",
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
|
||||
except WebSocketDisconnect:
|
||||
logger.info(f"WebSocket 断开: session={session_id}, user={user_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket 错误: {e}")
|
||||
finally:
|
||||
chat_manager.disconnect(session_id, user_id)
|
||||
|
||||
|
||||
@router.get("/info")
|
||||
async def get_chat_info():
|
||||
"""获取聊天室信息"""
|
||||
return {
|
||||
"bot_name": global_config.bot.nickname,
|
||||
"platform": WEBUI_CHAT_PLATFORM,
|
||||
"group_id": WEBUI_CHAT_GROUP_ID,
|
||||
"active_sessions": len(chat_manager.active_connections),
|
||||
}
|
||||
|
||||
|
||||
def get_webui_chat_broadcaster() -> tuple:
|
||||
"""获取 WebUI 聊天广播器,供外部模块使用
|
||||
|
||||
Returns:
|
||||
(chat_manager, WEBUI_CHAT_PLATFORM) 元组
|
||||
"""
|
||||
return (chat_manager, WEBUI_CHAT_PLATFORM)
|
||||
@@ -5,9 +5,10 @@
|
||||
import os
|
||||
import tomlkit
|
||||
from fastapi import APIRouter, HTTPException, Body
|
||||
from typing import Any
|
||||
from typing import Any, Annotated
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.toml_utils import save_toml_with_format
|
||||
from src.config.config import Config, APIAdapterConfig, CONFIG_DIR, PROJECT_ROOT
|
||||
from src.config.official_configs import (
|
||||
BotConfig,
|
||||
@@ -41,6 +42,12 @@ from src.webui.config_schema import ConfigSchemaGenerator
|
||||
|
||||
logger = get_logger("webui")
|
||||
|
||||
# 模块级别的类型别名(解决 B008 ruff 错误)
|
||||
ConfigBody = Annotated[dict[str, Any], Body()]
|
||||
SectionBody = Annotated[Any, Body()]
|
||||
RawContentBody = Annotated[str, Body(embed=True)]
|
||||
PathBody = Annotated[dict[str, str], Body()]
|
||||
|
||||
router = APIRouter(prefix="/config", tags=["config"])
|
||||
|
||||
|
||||
@@ -90,7 +97,7 @@ async def get_bot_config_schema():
|
||||
return {"success": True, "schema": schema}
|
||||
except Exception as e:
|
||||
logger.error(f"获取配置架构失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取配置架构失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"获取配置架构失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/schema/model")
|
||||
@@ -101,7 +108,7 @@ async def get_model_config_schema():
|
||||
return {"success": True, "schema": schema}
|
||||
except Exception as e:
|
||||
logger.error(f"获取模型配置架构失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取模型配置架构失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"获取模型配置架构失败: {str(e)}") from e
|
||||
|
||||
|
||||
# ===== 子配置架构获取接口 =====
|
||||
@@ -174,7 +181,7 @@ async def get_config_section_schema(section_name: str):
|
||||
return {"success": True, "schema": schema}
|
||||
except Exception as e:
|
||||
logger.error(f"获取配置节架构失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取配置节架构失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"获取配置节架构失败: {str(e)}") from e
|
||||
|
||||
|
||||
# ===== 配置读取接口 =====
|
||||
@@ -196,7 +203,7 @@ async def get_bot_config():
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"读取配置文件失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/model")
|
||||
@@ -215,26 +222,25 @@ async def get_model_config():
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"读取配置文件失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}") from e
|
||||
|
||||
|
||||
# ===== 配置更新接口 =====
|
||||
|
||||
|
||||
@router.post("/bot")
|
||||
async def update_bot_config(config_data: dict[str, Any] = Body(...)):
|
||||
async def update_bot_config(config_data: ConfigBody):
|
||||
"""更新麦麦主程序配置"""
|
||||
try:
|
||||
# 验证配置数据
|
||||
try:
|
||||
Config.from_dict(config_data)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}")
|
||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
|
||||
|
||||
# 保存配置文件
|
||||
# 保存配置文件(格式化数组为多行)
|
||||
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
tomlkit.dump(config_data, f)
|
||||
save_toml_with_format(config_data, config_path)
|
||||
|
||||
logger.info("麦麦主程序配置已更新")
|
||||
return {"success": True, "message": "配置已保存"}
|
||||
@@ -242,23 +248,22 @@ async def update_bot_config(config_data: dict[str, Any] = Body(...)):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"保存配置文件失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post("/model")
|
||||
async def update_model_config(config_data: dict[str, Any] = Body(...)):
|
||||
async def update_model_config(config_data: ConfigBody):
|
||||
"""更新模型配置"""
|
||||
try:
|
||||
# 验证配置数据
|
||||
try:
|
||||
APIAdapterConfig.from_dict(config_data)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}")
|
||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
|
||||
|
||||
# 保存配置文件
|
||||
# 保存配置文件(格式化数组为多行)
|
||||
config_path = os.path.join(CONFIG_DIR, "model_config.toml")
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
tomlkit.dump(config_data, f)
|
||||
save_toml_with_format(config_data, config_path)
|
||||
|
||||
logger.info("模型配置已更新")
|
||||
return {"success": True, "message": "配置已保存"}
|
||||
@@ -266,14 +271,14 @@ async def update_model_config(config_data: dict[str, Any] = Body(...)):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"保存配置文件失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}") from e
|
||||
|
||||
|
||||
# ===== 配置节更新接口 =====
|
||||
|
||||
|
||||
@router.post("/bot/section/{section_name}")
|
||||
async def update_bot_config_section(section_name: str, section_data: Any = Body(...)):
|
||||
async def update_bot_config_section(section_name: str, section_data: SectionBody):
|
||||
"""更新麦麦主程序配置的指定节(保留注释和格式)"""
|
||||
try:
|
||||
# 读取现有配置
|
||||
@@ -304,11 +309,10 @@ async def update_bot_config_section(section_name: str, section_data: Any = Body(
|
||||
try:
|
||||
Config.from_dict(config_data)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}")
|
||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
|
||||
|
||||
# 保存配置(tomlkit.dump 会保留注释)
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
tomlkit.dump(config_data, f)
|
||||
# 保存配置(格式化数组为多行,保留注释)
|
||||
save_toml_with_format(config_data, config_path)
|
||||
|
||||
logger.info(f"配置节 '{section_name}' 已更新(保留注释)")
|
||||
return {"success": True, "message": f"配置节 '{section_name}' 已保存"}
|
||||
@@ -316,7 +320,7 @@ async def update_bot_config_section(section_name: str, section_data: Any = Body(
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"更新配置节失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"更新配置节失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"更新配置节失败: {str(e)}") from e
|
||||
|
||||
|
||||
# ===== 原始 TOML 文件操作接口 =====
|
||||
@@ -338,24 +342,24 @@ async def get_bot_config_raw():
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"读取配置文件失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post("/bot/raw")
|
||||
async def update_bot_config_raw(raw_content: str = Body(..., embed=True)):
|
||||
async def update_bot_config_raw(raw_content: RawContentBody):
|
||||
"""更新麦麦主程序配置(直接保存原始 TOML 内容,会先验证格式)"""
|
||||
try:
|
||||
# 验证 TOML 格式
|
||||
try:
|
||||
config_data = tomlkit.loads(raw_content)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}")
|
||||
raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}") from e
|
||||
|
||||
# 验证配置数据结构
|
||||
try:
|
||||
Config.from_dict(config_data)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}")
|
||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
|
||||
|
||||
# 保存配置文件
|
||||
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
|
||||
@@ -368,11 +372,11 @@ async def update_bot_config_raw(raw_content: str = Body(..., embed=True)):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"保存配置文件失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post("/model/section/{section_name}")
|
||||
async def update_model_config_section(section_name: str, section_data: Any = Body(...)):
|
||||
async def update_model_config_section(section_name: str, section_data: SectionBody):
|
||||
"""更新模型配置的指定节(保留注释和格式)"""
|
||||
try:
|
||||
# 读取现有配置
|
||||
@@ -403,11 +407,10 @@ async def update_model_config_section(section_name: str, section_data: Any = Bod
|
||||
try:
|
||||
APIAdapterConfig.from_dict(config_data)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}")
|
||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
|
||||
|
||||
# 保存配置(tomlkit.dump 会保留注释)
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
tomlkit.dump(config_data, f)
|
||||
# 保存配置(格式化数组为多行,保留注释)
|
||||
save_toml_with_format(config_data, config_path)
|
||||
|
||||
logger.info(f"配置节 '{section_name}' 已更新(保留注释)")
|
||||
return {"success": True, "message": f"配置节 '{section_name}' 已保存"}
|
||||
@@ -415,7 +418,7 @@ async def update_model_config_section(section_name: str, section_data: Any = Bod
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"更新配置节失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"更新配置节失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"更新配置节失败: {str(e)}") from e
|
||||
|
||||
|
||||
# ===== 适配器配置管理接口 =====
|
||||
@@ -425,11 +428,11 @@ def _normalize_adapter_path(path: str) -> str:
|
||||
"""将路径转换为绝对路径(如果是相对路径,则相对于项目根目录)"""
|
||||
if not path:
|
||||
return path
|
||||
|
||||
|
||||
# 如果已经是绝对路径,直接返回
|
||||
if os.path.isabs(path):
|
||||
return path
|
||||
|
||||
|
||||
# 相对路径,转换为相对于项目根目录的绝对路径
|
||||
return os.path.normpath(os.path.join(PROJECT_ROOT, path))
|
||||
|
||||
@@ -438,17 +441,17 @@ def _to_relative_path(path: str) -> str:
|
||||
"""尝试将绝对路径转换为相对于项目根目录的相对路径,如果无法转换则返回原路径"""
|
||||
if not path or not os.path.isabs(path):
|
||||
return path
|
||||
|
||||
|
||||
try:
|
||||
# 尝试获取相对路径
|
||||
rel_path = os.path.relpath(path, PROJECT_ROOT)
|
||||
# 如果相对路径不是以 .. 开头(说明文件在项目目录内),则返回相对路径
|
||||
if not rel_path.startswith('..'):
|
||||
if not rel_path.startswith(".."):
|
||||
return rel_path
|
||||
except (ValueError, TypeError):
|
||||
# 在 Windows 上,如果路径在不同驱动器,relpath 会抛出 ValueError
|
||||
pass
|
||||
|
||||
|
||||
# 无法转换为相对路径,返回绝对路径
|
||||
return path
|
||||
|
||||
@@ -463,6 +466,7 @@ async def get_adapter_config_path():
|
||||
return {"success": True, "path": None}
|
||||
|
||||
import json
|
||||
|
||||
with open(webui_data_path, "r", encoding="utf-8") as f:
|
||||
webui_data = json.load(f)
|
||||
|
||||
@@ -472,10 +476,11 @@ async def get_adapter_config_path():
|
||||
|
||||
# 将路径规范化为绝对路径
|
||||
abs_path = _normalize_adapter_path(adapter_config_path)
|
||||
|
||||
|
||||
# 检查文件是否存在并返回最后修改时间
|
||||
if os.path.exists(abs_path):
|
||||
import datetime
|
||||
|
||||
mtime = os.path.getmtime(abs_path)
|
||||
last_modified = datetime.datetime.fromtimestamp(mtime).isoformat()
|
||||
# 返回相对路径(如果可能)
|
||||
@@ -487,11 +492,11 @@ async def get_adapter_config_path():
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取适配器配置路径失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取配置路径失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"获取配置路径失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post("/adapter-config/path")
|
||||
async def save_adapter_config_path(data: dict[str, str] = Body(...)):
|
||||
async def save_adapter_config_path(data: PathBody):
|
||||
"""保存适配器配置文件路径偏好"""
|
||||
try:
|
||||
path = data.get("path")
|
||||
@@ -511,10 +516,10 @@ async def save_adapter_config_path(data: dict[str, str] = Body(...)):
|
||||
|
||||
# 将路径规范化为绝对路径
|
||||
abs_path = _normalize_adapter_path(path)
|
||||
|
||||
|
||||
# 尝试转换为相对路径保存(如果文件在项目目录内)
|
||||
save_path = _to_relative_path(abs_path)
|
||||
|
||||
|
||||
# 更新路径
|
||||
webui_data["adapter_config_path"] = save_path
|
||||
|
||||
@@ -530,7 +535,7 @@ async def save_adapter_config_path(data: dict[str, str] = Body(...)):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"保存适配器配置路径失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"保存路径失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"保存路径失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/adapter-config")
|
||||
@@ -542,7 +547,7 @@ async def get_adapter_config(path: str):
|
||||
|
||||
# 将路径规范化为绝对路径
|
||||
abs_path = _normalize_adapter_path(path)
|
||||
|
||||
|
||||
# 检查文件是否存在
|
||||
if not os.path.exists(abs_path):
|
||||
raise HTTPException(status_code=404, detail=f"配置文件不存在: {path}")
|
||||
@@ -562,11 +567,11 @@ async def get_adapter_config(path: str):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"读取适配器配置失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"读取配置失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"读取配置失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post("/adapter-config")
|
||||
async def save_adapter_config(data: dict[str, str] = Body(...)):
|
||||
async def save_adapter_config(data: PathBody):
|
||||
"""保存适配器配置到指定路径"""
|
||||
try:
|
||||
path = data.get("path")
|
||||
@@ -579,17 +584,16 @@ async def save_adapter_config(data: dict[str, str] = Body(...)):
|
||||
|
||||
# 将路径规范化为绝对路径
|
||||
abs_path = _normalize_adapter_path(path)
|
||||
|
||||
|
||||
# 检查文件扩展名
|
||||
if not abs_path.endswith(".toml"):
|
||||
raise HTTPException(status_code=400, detail="只支持 .toml 格式的配置文件")
|
||||
|
||||
# 验证 TOML 格式
|
||||
try:
|
||||
import toml
|
||||
toml.loads(content)
|
||||
tomlkit.loads(content)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}")
|
||||
raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}") from e
|
||||
|
||||
# 确保目录存在
|
||||
dir_path = os.path.dirname(abs_path)
|
||||
@@ -607,5 +611,4 @@ async def save_adapter_config(data: dict[str, str] = Body(...)):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"保存适配器配置失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"保存配置失败: {str(e)}")
|
||||
|
||||
raise HTTPException(status_code=500, detail=f"保存配置失败: {str(e)}") from e
|
||||
|
||||
@@ -117,7 +117,7 @@ class ConfigSchemaGenerator:
|
||||
if next_line.startswith('"""') or next_line.startswith("'''"):
|
||||
# 单行文档字符串
|
||||
if next_line.count('"""') == 2 or next_line.count("'''") == 2:
|
||||
description_lines.append(next_line.strip('"""').strip("'''").strip())
|
||||
description_lines.append(next_line.replace('"""', "").replace("'''", "").strip())
|
||||
else:
|
||||
# 多行文档字符串
|
||||
quote = '"""' if next_line.startswith('"""') else "'''"
|
||||
@@ -135,7 +135,7 @@ class ConfigSchemaGenerator:
|
||||
next_line = lines[i + 1].strip()
|
||||
if next_line.startswith('"""') or next_line.startswith("'''"):
|
||||
if next_line.count('"""') == 2 or next_line.count("'''") == 2:
|
||||
description_lines.append(next_line.strip('"""').strip("'''").strip())
|
||||
description_lines.append(next_line.replace('"""', "").replace("'''", "").strip())
|
||||
else:
|
||||
quote = '"""' if next_line.startswith('"""') else "'''"
|
||||
description_lines.append(next_line.strip(quote).strip())
|
||||
@@ -199,13 +199,13 @@ class ConfigSchemaGenerator:
|
||||
return FieldType.ARRAY, None, items
|
||||
|
||||
# 处理基本类型
|
||||
if field_type is bool or field_type == bool:
|
||||
if field_type is bool:
|
||||
return FieldType.BOOLEAN, None, None
|
||||
elif field_type is int or field_type == int:
|
||||
elif field_type is int:
|
||||
return FieldType.INTEGER, None, None
|
||||
elif field_type is float or field_type == float:
|
||||
elif field_type is float:
|
||||
return FieldType.NUMBER, None, None
|
||||
elif field_type is str or field_type == str:
|
||||
elif field_type is str:
|
||||
return FieldType.STRING, None, None
|
||||
elif field_type is dict or origin is dict:
|
||||
return FieldType.OBJECT, None, None
|
||||
|
||||
@@ -1,18 +1,175 @@
|
||||
"""表情包管理 API 路由"""
|
||||
""" 表情包管理 API 路由"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Header, Query
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi import APIRouter, HTTPException, Header, Query, UploadFile, File, Form, Cookie
|
||||
from fastapi.responses import FileResponse, JSONResponse
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, List
|
||||
from typing import Optional, List, Annotated
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Emoji
|
||||
from .token_manager import get_token_manager
|
||||
import json
|
||||
from .auth import verify_auth_token_from_cookie_or_header
|
||||
import time
|
||||
import os
|
||||
import hashlib
|
||||
from PIL import Image
|
||||
import io
|
||||
from pathlib import Path
|
||||
import threading
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
logger = get_logger("webui.emoji")
|
||||
|
||||
# ==================== 缩略图缓存配置 ====================
|
||||
# 缩略图缓存目录
|
||||
THUMBNAIL_CACHE_DIR = Path("data/emoji_thumbnails")
|
||||
# 缩略图尺寸 (宽, 高)
|
||||
THUMBNAIL_SIZE = (200, 200)
|
||||
# 缩略图质量 (WebP 格式, 1-100)
|
||||
THUMBNAIL_QUALITY = 80
|
||||
# 缓存锁,防止并发生成同一缩略图
|
||||
_thumbnail_locks: dict[str, threading.Lock] = {}
|
||||
_locks_lock = threading.Lock()
|
||||
# 缩略图生成专用线程池(避免阻塞事件循环)
|
||||
_thumbnail_executor = ThreadPoolExecutor(max_workers=2, thread_name_prefix="thumbnail")
|
||||
# 正在生成中的缩略图哈希集合(防止重复提交任务)
|
||||
_generating_thumbnails: set[str] = set()
|
||||
_generating_lock = threading.Lock()
|
||||
|
||||
|
||||
def _get_thumbnail_lock(file_hash: str) -> threading.Lock:
|
||||
"""获取指定文件哈希的锁,用于防止并发生成同一缩略图"""
|
||||
with _locks_lock:
|
||||
if file_hash not in _thumbnail_locks:
|
||||
_thumbnail_locks[file_hash] = threading.Lock()
|
||||
return _thumbnail_locks[file_hash]
|
||||
|
||||
|
||||
def _background_generate_thumbnail(source_path: str, file_hash: str) -> None:
|
||||
"""
|
||||
后台生成缩略图(在线程池中执行)
|
||||
|
||||
生成完成后自动从 generating 集合中移除
|
||||
"""
|
||||
try:
|
||||
_generate_thumbnail(source_path, file_hash)
|
||||
except Exception as e:
|
||||
logger.warning(f"后台生成缩略图失败 {file_hash}: {e}")
|
||||
finally:
|
||||
with _generating_lock:
|
||||
_generating_thumbnails.discard(file_hash)
|
||||
|
||||
|
||||
def _ensure_thumbnail_cache_dir() -> Path:
|
||||
"""确保缩略图缓存目录存在"""
|
||||
THUMBNAIL_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||
return THUMBNAIL_CACHE_DIR
|
||||
|
||||
|
||||
def _get_thumbnail_cache_path(file_hash: str) -> Path:
|
||||
"""获取缩略图缓存路径"""
|
||||
return THUMBNAIL_CACHE_DIR / f"{file_hash}.webp"
|
||||
|
||||
|
||||
def _generate_thumbnail(source_path: str, file_hash: str) -> Path:
|
||||
"""
|
||||
生成缩略图并保存到缓存目录
|
||||
|
||||
Args:
|
||||
source_path: 原图路径
|
||||
file_hash: 文件哈希值,用作缓存文件名
|
||||
|
||||
Returns:
|
||||
缩略图路径
|
||||
|
||||
Features:
|
||||
- GIF: 提取第一帧作为缩略图
|
||||
- 所有格式统一转为 WebP
|
||||
- 保持宽高比缩放
|
||||
"""
|
||||
_ensure_thumbnail_cache_dir()
|
||||
cache_path = _get_thumbnail_cache_path(file_hash)
|
||||
|
||||
# 使用锁防止并发生成同一缩略图
|
||||
lock = _get_thumbnail_lock(file_hash)
|
||||
with lock:
|
||||
# 双重检查,可能在等待锁时已被其他线程生成
|
||||
if cache_path.exists():
|
||||
return cache_path
|
||||
|
||||
try:
|
||||
with Image.open(source_path) as img:
|
||||
# GIF 处理:提取第一帧
|
||||
if hasattr(img, 'n_frames') and img.n_frames > 1:
|
||||
img.seek(0) # 确保在第一帧
|
||||
|
||||
# 转换为 RGB/RGBA(WebP 支持透明度)
|
||||
if img.mode in ('P', 'PA'):
|
||||
# 调色板模式转换为 RGBA 以保留透明度
|
||||
img = img.convert('RGBA')
|
||||
elif img.mode == 'LA':
|
||||
img = img.convert('RGBA')
|
||||
elif img.mode not in ('RGB', 'RGBA'):
|
||||
img = img.convert('RGB')
|
||||
|
||||
# 创建缩略图(保持宽高比)
|
||||
img.thumbnail(THUMBNAIL_SIZE, Image.Resampling.LANCZOS)
|
||||
|
||||
# 保存为 WebP 格式
|
||||
img.save(cache_path, 'WEBP', quality=THUMBNAIL_QUALITY, method=6)
|
||||
|
||||
logger.debug(f"生成缩略图: {file_hash} -> {cache_path}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"生成缩略图失败 {file_hash}: {e},将返回原图")
|
||||
# 生成失败时不创建缓存文件,下次会重试
|
||||
raise
|
||||
|
||||
return cache_path
|
||||
|
||||
|
||||
def cleanup_orphaned_thumbnails() -> tuple[int, int]:
|
||||
"""
|
||||
清理孤立的缩略图缓存(原图已不存在的缩略图)
|
||||
|
||||
Returns:
|
||||
(清理数量, 保留数量)
|
||||
"""
|
||||
if not THUMBNAIL_CACHE_DIR.exists():
|
||||
return 0, 0
|
||||
|
||||
# 获取所有表情包的哈希值
|
||||
valid_hashes = set()
|
||||
for emoji in Emoji.select(Emoji.emoji_hash):
|
||||
valid_hashes.add(emoji.emoji_hash)
|
||||
|
||||
cleaned = 0
|
||||
kept = 0
|
||||
|
||||
for cache_file in THUMBNAIL_CACHE_DIR.glob("*.webp"):
|
||||
file_hash = cache_file.stem
|
||||
if file_hash not in valid_hashes:
|
||||
try:
|
||||
cache_file.unlink()
|
||||
cleaned += 1
|
||||
logger.debug(f"清理孤立缩略图: {cache_file.name}")
|
||||
except Exception as e:
|
||||
logger.warning(f"清理缩略图失败 {cache_file.name}: {e}")
|
||||
else:
|
||||
kept += 1
|
||||
|
||||
if cleaned > 0:
|
||||
logger.info(f"清理孤立缩略图: 删除 {cleaned} 个,保留 {kept} 个")
|
||||
|
||||
return cleaned, kept
|
||||
|
||||
# 模块级别的类型别名(解决 B008 ruff 错误)
|
||||
EmojiFile = Annotated[UploadFile, File(description="表情包图片文件")]
|
||||
EmojiFiles = Annotated[List[UploadFile], File(description="多个表情包图片文件")]
|
||||
DescriptionForm = Annotated[str, Form(description="表情包描述")]
|
||||
EmotionForm = Annotated[str, Form(description="情感标签,多个用逗号分隔")]
|
||||
IsRegisteredForm = Annotated[bool, Form(description="是否直接注册")]
|
||||
|
||||
# 创建路由器
|
||||
router = APIRouter(prefix="/emoji", tags=["Emoji"])
|
||||
|
||||
@@ -92,18 +249,12 @@ class BatchDeleteResponse(BaseModel):
|
||||
failed_ids: List[int] = []
|
||||
|
||||
|
||||
def verify_auth_token(authorization: Optional[str]) -> bool:
|
||||
"""验证认证 Token"""
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
token = authorization.replace("Bearer ", "")
|
||||
token_manager = get_token_manager()
|
||||
|
||||
if not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="Token 无效或已过期")
|
||||
|
||||
return True
|
||||
def verify_auth_token(
|
||||
maibot_session: Optional[str] = None,
|
||||
authorization: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""验证认证 Token,支持 Cookie 和 Header"""
|
||||
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||
|
||||
|
||||
def emoji_to_response(emoji: Emoji) -> EmojiResponse:
|
||||
@@ -135,6 +286,7 @@ async def get_emoji_list(
|
||||
format: Optional[str] = Query(None, description="格式筛选"),
|
||||
sort_by: Optional[str] = Query("usage_count", description="排序字段"),
|
||||
sort_order: Optional[str] = Query("desc", description="排序方向"),
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
@@ -155,7 +307,7 @@ async def get_emoji_list(
|
||||
表情包列表
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
# 构建查询
|
||||
query = Emoji.select()
|
||||
@@ -213,7 +365,7 @@ async def get_emoji_list(
|
||||
|
||||
|
||||
@router.get("/{emoji_id}", response_model=EmojiDetailResponse)
|
||||
async def get_emoji_detail(emoji_id: int, authorization: Optional[str] = Header(None)):
|
||||
async def get_emoji_detail(emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
获取表情包详细信息
|
||||
|
||||
@@ -225,7 +377,7 @@ async def get_emoji_detail(emoji_id: int, authorization: Optional[str] = Header(
|
||||
表情包详细信息
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
|
||||
|
||||
@@ -242,7 +394,7 @@ async def get_emoji_detail(emoji_id: int, authorization: Optional[str] = Header(
|
||||
|
||||
|
||||
@router.patch("/{emoji_id}", response_model=EmojiUpdateResponse)
|
||||
async def update_emoji(emoji_id: int, request: EmojiUpdateRequest, authorization: Optional[str] = Header(None)):
|
||||
async def update_emoji(emoji_id: int, request: EmojiUpdateRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
增量更新表情包(只更新提供的字段)
|
||||
|
||||
@@ -255,7 +407,7 @@ async def update_emoji(emoji_id: int, request: EmojiUpdateRequest, authorization
|
||||
更新结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
|
||||
|
||||
@@ -294,7 +446,7 @@ async def update_emoji(emoji_id: int, request: EmojiUpdateRequest, authorization
|
||||
|
||||
|
||||
@router.delete("/{emoji_id}", response_model=EmojiDeleteResponse)
|
||||
async def delete_emoji(emoji_id: int, authorization: Optional[str] = Header(None)):
|
||||
async def delete_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
删除表情包
|
||||
|
||||
@@ -306,7 +458,7 @@ async def delete_emoji(emoji_id: int, authorization: Optional[str] = Header(None
|
||||
删除结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
|
||||
|
||||
@@ -331,7 +483,7 @@ async def delete_emoji(emoji_id: int, authorization: Optional[str] = Header(None
|
||||
|
||||
|
||||
@router.get("/stats/summary")
|
||||
async def get_emoji_stats(authorization: Optional[str] = Header(None)):
|
||||
async def get_emoji_stats(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
获取表情包统计数据
|
||||
|
||||
@@ -342,7 +494,7 @@ async def get_emoji_stats(authorization: Optional[str] = Header(None)):
|
||||
统计数据
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
total = Emoji.select().count()
|
||||
registered = Emoji.select().where(Emoji.is_registered).count()
|
||||
@@ -386,7 +538,7 @@ async def get_emoji_stats(authorization: Optional[str] = Header(None)):
|
||||
|
||||
|
||||
@router.post("/{emoji_id}/register", response_model=EmojiUpdateResponse)
|
||||
async def register_emoji(emoji_id: int, authorization: Optional[str] = Header(None)):
|
||||
async def register_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
注册表情包(快捷操作)
|
||||
|
||||
@@ -398,7 +550,7 @@ async def register_emoji(emoji_id: int, authorization: Optional[str] = Header(No
|
||||
更新结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
|
||||
|
||||
@@ -426,7 +578,7 @@ async def register_emoji(emoji_id: int, authorization: Optional[str] = Header(No
|
||||
|
||||
|
||||
@router.post("/{emoji_id}/ban", response_model=EmojiUpdateResponse)
|
||||
async def ban_emoji(emoji_id: int, authorization: Optional[str] = Header(None)):
|
||||
async def ban_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
禁用表情包(快捷操作)
|
||||
|
||||
@@ -438,7 +590,7 @@ async def ban_emoji(emoji_id: int, authorization: Optional[str] = Header(None)):
|
||||
更新结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
|
||||
|
||||
@@ -465,28 +617,47 @@ async def ban_emoji(emoji_id: int, authorization: Optional[str] = Header(None)):
|
||||
async def get_emoji_thumbnail(
|
||||
emoji_id: int,
|
||||
token: Optional[str] = Query(None, description="访问令牌"),
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
original: bool = Query(False, description="是否返回原图"),
|
||||
):
|
||||
"""
|
||||
获取表情包缩略图
|
||||
获取表情包缩略图(懒加载生成 + 缓存)
|
||||
|
||||
Args:
|
||||
emoji_id: 表情包ID
|
||||
token: 访问令牌(通过 query parameter)
|
||||
token: 访问令牌(通过 query parameter,用于向后兼容)
|
||||
maibot_session: Cookie 中的 token
|
||||
authorization: Authorization header
|
||||
original: 是否返回原图(用于详情页查看原图)
|
||||
|
||||
Returns:
|
||||
表情包图片文件
|
||||
表情包缩略图(WebP 格式)或原图
|
||||
|
||||
Features:
|
||||
- 懒加载:首次请求时生成缩略图
|
||||
- 缓存:后续请求直接返回缓存
|
||||
- GIF 支持:提取第一帧作为缩略图
|
||||
- 格式统一:所有缩略图统一为 WebP 格式
|
||||
"""
|
||||
try:
|
||||
# 优先使用 query parameter 中的 token(用于 img 标签)
|
||||
if token:
|
||||
token_manager = get_token_manager()
|
||||
if not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="Token 无效或已过期")
|
||||
else:
|
||||
# 如果没有 query token,则验证 Authorization header
|
||||
verify_auth_token(authorization)
|
||||
token_manager = get_token_manager()
|
||||
is_valid = False
|
||||
|
||||
# 1. 优先使用 Cookie
|
||||
if maibot_session and token_manager.verify_token(maibot_session):
|
||||
is_valid = True
|
||||
# 2. 其次使用 query parameter(用于向后兼容 img 标签)
|
||||
elif token and token_manager.verify_token(token):
|
||||
is_valid = True
|
||||
# 3. 最后使用 Authorization header
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
auth_token = authorization.replace("Bearer ", "")
|
||||
if token_manager.verify_token(auth_token):
|
||||
is_valid = True
|
||||
|
||||
if not is_valid:
|
||||
raise HTTPException(status_code=401, detail="Token 无效或已过期")
|
||||
|
||||
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
|
||||
|
||||
@@ -497,19 +668,59 @@ async def get_emoji_thumbnail(
|
||||
if not os.path.exists(emoji.full_path):
|
||||
raise HTTPException(status_code=404, detail="表情包文件不存在")
|
||||
|
||||
# 根据格式设置 MIME 类型
|
||||
mime_types = {
|
||||
"png": "image/png",
|
||||
"jpg": "image/jpeg",
|
||||
"jpeg": "image/jpeg",
|
||||
"gif": "image/gif",
|
||||
"webp": "image/webp",
|
||||
"bmp": "image/bmp",
|
||||
}
|
||||
# 如果请求原图,直接返回原文件
|
||||
if original:
|
||||
mime_types = {
|
||||
"png": "image/png",
|
||||
"jpg": "image/jpeg",
|
||||
"jpeg": "image/jpeg",
|
||||
"gif": "image/gif",
|
||||
"webp": "image/webp",
|
||||
"bmp": "image/bmp",
|
||||
}
|
||||
media_type = mime_types.get(emoji.format.lower(), "application/octet-stream")
|
||||
return FileResponse(
|
||||
path=emoji.full_path,
|
||||
media_type=media_type,
|
||||
filename=f"{emoji.emoji_hash}.{emoji.format}"
|
||||
)
|
||||
|
||||
media_type = mime_types.get(emoji.format.lower(), "application/octet-stream")
|
||||
|
||||
return FileResponse(path=emoji.full_path, media_type=media_type, filename=f"{emoji.emoji_hash}.{emoji.format}")
|
||||
# 尝试获取或生成缩略图
|
||||
cache_path = _get_thumbnail_cache_path(emoji.emoji_hash)
|
||||
|
||||
# 检查缓存是否存在
|
||||
if cache_path.exists():
|
||||
# 缓存命中,直接返回
|
||||
return FileResponse(
|
||||
path=str(cache_path),
|
||||
media_type="image/webp",
|
||||
filename=f"{emoji.emoji_hash}_thumb.webp"
|
||||
)
|
||||
|
||||
# 缓存未命中,触发后台生成并返回 202
|
||||
with _generating_lock:
|
||||
if emoji.emoji_hash not in _generating_thumbnails:
|
||||
# 标记为正在生成
|
||||
_generating_thumbnails.add(emoji.emoji_hash)
|
||||
# 提交到线程池后台生成
|
||||
_thumbnail_executor.submit(
|
||||
_background_generate_thumbnail,
|
||||
emoji.full_path,
|
||||
emoji.emoji_hash
|
||||
)
|
||||
|
||||
# 返回 202 Accepted,告诉前端缩略图正在生成中
|
||||
return JSONResponse(
|
||||
status_code=202,
|
||||
content={
|
||||
"status": "generating",
|
||||
"message": "缩略图正在生成中,请稍后重试",
|
||||
"emoji_id": emoji_id,
|
||||
},
|
||||
headers={
|
||||
"Retry-After": "1", # 建议 1 秒后重试
|
||||
}
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -519,7 +730,7 @@ async def get_emoji_thumbnail(
|
||||
|
||||
|
||||
@router.post("/batch/delete", response_model=BatchDeleteResponse)
|
||||
async def batch_delete_emojis(request: BatchDeleteRequest, authorization: Optional[str] = Header(None)):
|
||||
async def batch_delete_emojis(request: BatchDeleteRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
批量删除表情包
|
||||
|
||||
@@ -531,7 +742,7 @@ async def batch_delete_emojis(request: BatchDeleteRequest, authorization: Option
|
||||
批量删除结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
if not request.emoji_ids:
|
||||
raise HTTPException(status_code=400, detail="未提供要删除的表情包ID")
|
||||
@@ -572,3 +783,524 @@ async def batch_delete_emojis(request: BatchDeleteRequest, authorization: Option
|
||||
except Exception as e:
|
||||
logger.exception(f"批量删除表情包失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"批量删除失败: {str(e)}") from e
|
||||
|
||||
|
||||
# 表情包存储目录
|
||||
EMOJI_REGISTERED_DIR = os.path.join("data", "emoji_registed")
|
||||
|
||||
|
||||
class EmojiUploadResponse(BaseModel):
|
||||
"""表情包上传响应"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
data: Optional[EmojiResponse] = None
|
||||
|
||||
|
||||
@router.post("/upload", response_model=EmojiUploadResponse)
|
||||
async def upload_emoji(
|
||||
file: EmojiFile,
|
||||
description: DescriptionForm = "",
|
||||
emotion: EmotionForm = "",
|
||||
is_registered: IsRegisteredForm = True,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
上传并注册表情包
|
||||
|
||||
Args:
|
||||
file: 表情包图片文件 (支持 jpg, jpeg, png, gif, webp)
|
||||
description: 表情包描述
|
||||
emotion: 情感标签,多个用逗号分隔
|
||||
is_registered: 是否直接注册,默认为 True
|
||||
authorization: Authorization header
|
||||
|
||||
Returns:
|
||||
上传结果和表情包信息
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
# 验证文件类型
|
||||
if not file.content_type:
|
||||
raise HTTPException(status_code=400, detail="无法识别文件类型")
|
||||
|
||||
allowed_types = ["image/jpeg", "image/png", "image/gif", "image/webp"]
|
||||
if file.content_type not in allowed_types:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"不支持的文件类型: {file.content_type},支持: {', '.join(allowed_types)}",
|
||||
)
|
||||
|
||||
# 读取文件内容
|
||||
file_content = await file.read()
|
||||
|
||||
if not file_content:
|
||||
raise HTTPException(status_code=400, detail="文件内容为空")
|
||||
|
||||
# 验证图片并获取格式
|
||||
try:
|
||||
with Image.open(io.BytesIO(file_content)) as img:
|
||||
img_format = img.format.lower() if img.format else "png"
|
||||
# 验证图片可以正常打开
|
||||
img.verify()
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"无效的图片文件: {str(e)}") from e
|
||||
|
||||
# 重新打开图片(verify后需要重新打开)
|
||||
with Image.open(io.BytesIO(file_content)) as img:
|
||||
img_format = img.format.lower() if img.format else "png"
|
||||
|
||||
# 计算文件哈希
|
||||
emoji_hash = hashlib.md5(file_content).hexdigest()
|
||||
|
||||
# 检查是否已存在相同哈希的表情包
|
||||
existing_emoji = Emoji.get_or_none(Emoji.emoji_hash == emoji_hash)
|
||||
if existing_emoji:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=f"已存在相同的表情包 (ID: {existing_emoji.id})",
|
||||
)
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(EMOJI_REGISTERED_DIR, exist_ok=True)
|
||||
|
||||
# 生成文件名
|
||||
timestamp = int(time.time())
|
||||
filename = f"emoji_{timestamp}_{emoji_hash[:8]}.{img_format}"
|
||||
full_path = os.path.join(EMOJI_REGISTERED_DIR, filename)
|
||||
|
||||
# 如果文件已存在,添加随机后缀
|
||||
counter = 1
|
||||
while os.path.exists(full_path):
|
||||
filename = f"emoji_{timestamp}_{emoji_hash[:8]}_{counter}.{img_format}"
|
||||
full_path = os.path.join(EMOJI_REGISTERED_DIR, filename)
|
||||
counter += 1
|
||||
|
||||
# 保存文件
|
||||
with open(full_path, "wb") as f:
|
||||
f.write(file_content)
|
||||
|
||||
logger.info(f"表情包文件已保存: {full_path}")
|
||||
|
||||
# 处理情感标签
|
||||
emotion_str = ",".join(e.strip() for e in emotion.split(",") if e.strip()) if emotion else ""
|
||||
|
||||
# 创建数据库记录
|
||||
current_time = time.time()
|
||||
emoji = Emoji.create(
|
||||
full_path=full_path,
|
||||
format=img_format,
|
||||
emoji_hash=emoji_hash,
|
||||
description=description,
|
||||
emotion=emotion_str,
|
||||
query_count=0,
|
||||
is_registered=is_registered,
|
||||
is_banned=False,
|
||||
record_time=current_time,
|
||||
register_time=current_time if is_registered else None,
|
||||
usage_count=0,
|
||||
last_used_time=None,
|
||||
)
|
||||
|
||||
logger.info(f"表情包已上传并注册: ID={emoji.id}, hash={emoji_hash}")
|
||||
|
||||
return EmojiUploadResponse(
|
||||
success=True,
|
||||
message="表情包上传成功" + ("并已注册" if is_registered else ""),
|
||||
data=emoji_to_response(emoji),
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"上传表情包失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"上传失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post("/batch/upload")
|
||||
async def batch_upload_emoji(
|
||||
files: EmojiFiles,
|
||||
emotion: EmotionForm = "",
|
||||
is_registered: IsRegisteredForm = True,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
批量上传表情包
|
||||
|
||||
Args:
|
||||
files: 多个表情包图片文件
|
||||
emotion: 共用的情感标签
|
||||
is_registered: 是否直接注册
|
||||
authorization: Authorization header
|
||||
|
||||
Returns:
|
||||
批量上传结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
results = {
|
||||
"success": True,
|
||||
"total": len(files),
|
||||
"uploaded": 0,
|
||||
"failed": 0,
|
||||
"details": [],
|
||||
}
|
||||
|
||||
allowed_types = ["image/jpeg", "image/png", "image/gif", "image/webp"]
|
||||
os.makedirs(EMOJI_REGISTERED_DIR, exist_ok=True)
|
||||
|
||||
for file in files:
|
||||
try:
|
||||
# 验证文件类型
|
||||
if file.content_type not in allowed_types:
|
||||
results["failed"] += 1
|
||||
results["details"].append(
|
||||
{
|
||||
"filename": file.filename,
|
||||
"success": False,
|
||||
"error": f"不支持的文件类型: {file.content_type}",
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
# 读取文件内容
|
||||
file_content = await file.read()
|
||||
|
||||
if not file_content:
|
||||
results["failed"] += 1
|
||||
results["details"].append(
|
||||
{
|
||||
"filename": file.filename,
|
||||
"success": False,
|
||||
"error": "文件内容为空",
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
# 验证图片
|
||||
try:
|
||||
with Image.open(io.BytesIO(file_content)) as img:
|
||||
img_format = img.format.lower() if img.format else "png"
|
||||
except Exception as e:
|
||||
results["failed"] += 1
|
||||
results["details"].append(
|
||||
{
|
||||
"filename": file.filename,
|
||||
"success": False,
|
||||
"error": f"无效的图片: {str(e)}",
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
# 计算哈希
|
||||
emoji_hash = hashlib.md5(file_content).hexdigest()
|
||||
|
||||
# 检查重复
|
||||
if Emoji.get_or_none(Emoji.emoji_hash == emoji_hash):
|
||||
results["failed"] += 1
|
||||
results["details"].append(
|
||||
{
|
||||
"filename": file.filename,
|
||||
"success": False,
|
||||
"error": "已存在相同的表情包",
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
# 生成文件名并保存
|
||||
timestamp = int(time.time())
|
||||
filename = f"emoji_{timestamp}_{emoji_hash[:8]}.{img_format}"
|
||||
full_path = os.path.join(EMOJI_REGISTERED_DIR, filename)
|
||||
|
||||
counter = 1
|
||||
while os.path.exists(full_path):
|
||||
filename = f"emoji_{timestamp}_{emoji_hash[:8]}_{counter}.{img_format}"
|
||||
full_path = os.path.join(EMOJI_REGISTERED_DIR, filename)
|
||||
counter += 1
|
||||
|
||||
with open(full_path, "wb") as f:
|
||||
f.write(file_content)
|
||||
|
||||
# 处理情感标签
|
||||
emotion_str = ",".join(e.strip() for e in emotion.split(",") if e.strip()) if emotion else ""
|
||||
|
||||
# 创建数据库记录
|
||||
current_time = time.time()
|
||||
emoji = Emoji.create(
|
||||
full_path=full_path,
|
||||
format=img_format,
|
||||
emoji_hash=emoji_hash,
|
||||
description="", # 批量上传暂不设置描述
|
||||
emotion=emotion_str,
|
||||
query_count=0,
|
||||
is_registered=is_registered,
|
||||
is_banned=False,
|
||||
record_time=current_time,
|
||||
register_time=current_time if is_registered else None,
|
||||
usage_count=0,
|
||||
last_used_time=None,
|
||||
)
|
||||
|
||||
results["uploaded"] += 1
|
||||
results["details"].append(
|
||||
{
|
||||
"filename": file.filename,
|
||||
"success": True,
|
||||
"id": emoji.id,
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
results["failed"] += 1
|
||||
results["details"].append(
|
||||
{
|
||||
"filename": file.filename,
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
}
|
||||
)
|
||||
|
||||
results["message"] = f"成功上传 {results['uploaded']} 个,失败 {results['failed']} 个"
|
||||
return results
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"批量上传表情包失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"批量上传失败: {str(e)}") from e
|
||||
|
||||
|
||||
# ==================== 缩略图缓存管理 API ====================
|
||||
|
||||
|
||||
class ThumbnailCacheStatsResponse(BaseModel):
|
||||
"""缩略图缓存统计响应"""
|
||||
|
||||
success: bool
|
||||
cache_dir: str
|
||||
total_count: int
|
||||
total_size_mb: float
|
||||
emoji_count: int
|
||||
coverage_percent: float
|
||||
|
||||
|
||||
class ThumbnailCleanupResponse(BaseModel):
|
||||
"""缩略图清理响应"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
cleaned_count: int
|
||||
kept_count: int
|
||||
|
||||
|
||||
class ThumbnailPreheatResponse(BaseModel):
|
||||
"""缩略图预热响应"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
generated_count: int
|
||||
skipped_count: int
|
||||
failed_count: int
|
||||
|
||||
|
||||
@router.get("/thumbnail-cache/stats", response_model=ThumbnailCacheStatsResponse)
|
||||
async def get_thumbnail_cache_stats(
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
获取缩略图缓存统计信息
|
||||
|
||||
Returns:
|
||||
缓存目录、缓存数量、总大小、覆盖率等统计信息
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
_ensure_thumbnail_cache_dir()
|
||||
|
||||
# 统计缓存文件
|
||||
cache_files = list(THUMBNAIL_CACHE_DIR.glob("*.webp"))
|
||||
total_count = len(cache_files)
|
||||
total_size = sum(f.stat().st_size for f in cache_files)
|
||||
total_size_mb = round(total_size / (1024 * 1024), 2)
|
||||
|
||||
# 统计表情包总数
|
||||
emoji_count = Emoji.select().count()
|
||||
|
||||
# 计算覆盖率
|
||||
coverage_percent = round((total_count / emoji_count * 100) if emoji_count > 0 else 0, 1)
|
||||
|
||||
return ThumbnailCacheStatsResponse(
|
||||
success=True,
|
||||
cache_dir=str(THUMBNAIL_CACHE_DIR.absolute()),
|
||||
total_count=total_count,
|
||||
total_size_mb=total_size_mb,
|
||||
emoji_count=emoji_count,
|
||||
coverage_percent=coverage_percent,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"获取缩略图缓存统计失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取统计失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post("/thumbnail-cache/cleanup", response_model=ThumbnailCleanupResponse)
|
||||
async def cleanup_thumbnail_cache(
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
清理孤立的缩略图缓存(原图已删除的表情包对应的缩略图)
|
||||
|
||||
Returns:
|
||||
清理结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
cleaned, kept = cleanup_orphaned_thumbnails()
|
||||
|
||||
return ThumbnailCleanupResponse(
|
||||
success=True,
|
||||
message=f"清理完成:删除 {cleaned} 个孤立缓存,保留 {kept} 个有效缓存",
|
||||
cleaned_count=cleaned,
|
||||
kept_count=kept,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"清理缩略图缓存失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"清理失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post("/thumbnail-cache/preheat", response_model=ThumbnailPreheatResponse)
|
||||
async def preheat_thumbnail_cache(
|
||||
limit: int = Query(100, ge=1, le=1000, description="最多预热数量"),
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
预热缩略图缓存(提前生成未缓存的缩略图)
|
||||
|
||||
优先处理使用次数高的表情包
|
||||
|
||||
Args:
|
||||
limit: 最多预热数量 (1-1000)
|
||||
|
||||
Returns:
|
||||
预热结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
_ensure_thumbnail_cache_dir()
|
||||
|
||||
# 获取使用次数最高的表情包(未缓存的优先)
|
||||
emojis = (
|
||||
Emoji.select()
|
||||
.where(Emoji.is_banned == False) # noqa: E712 Peewee ORM requires == for boolean comparison
|
||||
.order_by(Emoji.usage_count.desc())
|
||||
.limit(limit * 2) # 多查一些,因为有些可能已缓存
|
||||
)
|
||||
|
||||
generated = 0
|
||||
skipped = 0
|
||||
failed = 0
|
||||
|
||||
for emoji in emojis:
|
||||
if generated >= limit:
|
||||
break
|
||||
|
||||
cache_path = _get_thumbnail_cache_path(emoji.emoji_hash)
|
||||
|
||||
# 已缓存,跳过
|
||||
if cache_path.exists():
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
# 原文件不存在,跳过
|
||||
if not os.path.exists(emoji.full_path):
|
||||
failed += 1
|
||||
continue
|
||||
|
||||
try:
|
||||
# 使用线程池异步生成缩略图,避免阻塞事件循环
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(
|
||||
_thumbnail_executor,
|
||||
_generate_thumbnail,
|
||||
emoji.full_path,
|
||||
emoji.emoji_hash
|
||||
)
|
||||
generated += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"预热缩略图失败 {emoji.emoji_hash}: {e}")
|
||||
failed += 1
|
||||
|
||||
return ThumbnailPreheatResponse(
|
||||
success=True,
|
||||
message=f"预热完成:生成 {generated} 个,跳过 {skipped} 个已缓存,失败 {failed} 个",
|
||||
generated_count=generated,
|
||||
skipped_count=skipped,
|
||||
failed_count=failed,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"预热缩略图缓存失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"预热失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.delete("/thumbnail-cache/clear", response_model=ThumbnailCleanupResponse)
|
||||
async def clear_all_thumbnail_cache(
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
清空所有缩略图缓存(下次访问时会重新生成)
|
||||
|
||||
Returns:
|
||||
清理结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
if not THUMBNAIL_CACHE_DIR.exists():
|
||||
return ThumbnailCleanupResponse(
|
||||
success=True,
|
||||
message="缓存目录不存在,无需清理",
|
||||
cleaned_count=0,
|
||||
kept_count=0,
|
||||
)
|
||||
|
||||
cleaned = 0
|
||||
for cache_file in THUMBNAIL_CACHE_DIR.glob("*.webp"):
|
||||
try:
|
||||
cache_file.unlink()
|
||||
cleaned += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"删除缓存文件失败 {cache_file.name}: {e}")
|
||||
|
||||
logger.info(f"已清空缩略图缓存: 删除 {cleaned} 个文件")
|
||||
|
||||
return ThumbnailCleanupResponse(
|
||||
success=True,
|
||||
message=f"已清空所有缩略图缓存:删除 {cleaned} 个文件",
|
||||
cleaned_count=cleaned,
|
||||
kept_count=0,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"清空缩略图缓存失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"清空失败: {str(e)}") from e
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
"""表达方式管理 API 路由"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Header, Query
|
||||
from fastapi import APIRouter, HTTPException, Header, Query, Cookie
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, List, Dict
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Expression, ChatStreams
|
||||
from .token_manager import get_token_manager
|
||||
from .auth import verify_auth_token_from_cookie_or_header
|
||||
import time
|
||||
|
||||
logger = get_logger("webui.expression")
|
||||
@@ -87,18 +87,12 @@ class ExpressionCreateResponse(BaseModel):
|
||||
data: ExpressionResponse
|
||||
|
||||
|
||||
def verify_auth_token(authorization: Optional[str]) -> bool:
|
||||
"""验证认证 Token"""
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
token = authorization.replace("Bearer ", "")
|
||||
token_manager = get_token_manager()
|
||||
|
||||
if not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="Token 无效或已过期")
|
||||
|
||||
return True
|
||||
def verify_auth_token(
|
||||
maibot_session: Optional[str] = None,
|
||||
authorization: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""验证认证 Token,支持 Cookie 和 Header"""
|
||||
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||
|
||||
|
||||
def expression_to_response(expression: Expression) -> ExpressionResponse:
|
||||
@@ -162,7 +156,7 @@ class ChatListResponse(BaseModel):
|
||||
|
||||
|
||||
@router.get("/chats", response_model=ChatListResponse)
|
||||
async def get_chat_list(authorization: Optional[str] = Header(None)):
|
||||
async def get_chat_list(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
获取所有聊天列表(用于下拉选择)
|
||||
|
||||
@@ -173,7 +167,7 @@ async def get_chat_list(authorization: Optional[str] = Header(None)):
|
||||
聊天列表
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
chat_list = []
|
||||
for cs in ChatStreams.select():
|
||||
@@ -205,6 +199,7 @@ async def get_expression_list(
|
||||
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
chat_id: Optional[str] = Query(None, description="聊天ID筛选"),
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
@@ -221,7 +216,7 @@ async def get_expression_list(
|
||||
表达方式列表
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
# 构建查询
|
||||
query = Expression.select()
|
||||
@@ -265,7 +260,7 @@ async def get_expression_list(
|
||||
|
||||
|
||||
@router.get("/{expression_id}", response_model=ExpressionDetailResponse)
|
||||
async def get_expression_detail(expression_id: int, authorization: Optional[str] = Header(None)):
|
||||
async def get_expression_detail(expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
获取表达方式详细信息
|
||||
|
||||
@@ -277,7 +272,7 @@ async def get_expression_detail(expression_id: int, authorization: Optional[str]
|
||||
表达方式详细信息
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
expression = Expression.get_or_none(Expression.id == expression_id)
|
||||
|
||||
@@ -294,7 +289,7 @@ async def get_expression_detail(expression_id: int, authorization: Optional[str]
|
||||
|
||||
|
||||
@router.post("/", response_model=ExpressionCreateResponse)
|
||||
async def create_expression(request: ExpressionCreateRequest, authorization: Optional[str] = Header(None)):
|
||||
async def create_expression(request: ExpressionCreateRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
创建新的表达方式
|
||||
|
||||
@@ -306,7 +301,7 @@ async def create_expression(request: ExpressionCreateRequest, authorization: Opt
|
||||
创建结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
current_time = time.time()
|
||||
|
||||
@@ -336,7 +331,7 @@ async def create_expression(request: ExpressionCreateRequest, authorization: Opt
|
||||
|
||||
@router.patch("/{expression_id}", response_model=ExpressionUpdateResponse)
|
||||
async def update_expression(
|
||||
expression_id: int, request: ExpressionUpdateRequest, authorization: Optional[str] = Header(None)
|
||||
expression_id: int, request: ExpressionUpdateRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
):
|
||||
"""
|
||||
增量更新表达方式(只更新提供的字段)
|
||||
@@ -350,7 +345,7 @@ async def update_expression(
|
||||
更新结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
expression = Expression.get_or_none(Expression.id == expression_id)
|
||||
|
||||
@@ -386,7 +381,7 @@ async def update_expression(
|
||||
|
||||
|
||||
@router.delete("/{expression_id}", response_model=ExpressionDeleteResponse)
|
||||
async def delete_expression(expression_id: int, authorization: Optional[str] = Header(None)):
|
||||
async def delete_expression(expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
删除表达方式
|
||||
|
||||
@@ -398,7 +393,7 @@ async def delete_expression(expression_id: int, authorization: Optional[str] = H
|
||||
删除结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
expression = Expression.get_or_none(Expression.id == expression_id)
|
||||
|
||||
@@ -429,7 +424,7 @@ class BatchDeleteRequest(BaseModel):
|
||||
|
||||
|
||||
@router.post("/batch/delete", response_model=ExpressionDeleteResponse)
|
||||
async def batch_delete_expressions(request: BatchDeleteRequest, authorization: Optional[str] = Header(None)):
|
||||
async def batch_delete_expressions(request: BatchDeleteRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
批量删除表达方式
|
||||
|
||||
@@ -441,7 +436,7 @@ async def batch_delete_expressions(request: BatchDeleteRequest, authorization: O
|
||||
删除结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
if not request.ids:
|
||||
raise HTTPException(status_code=400, detail="未提供要删除的表达方式ID")
|
||||
@@ -470,7 +465,7 @@ async def batch_delete_expressions(request: BatchDeleteRequest, authorization: O
|
||||
|
||||
|
||||
@router.get("/stats/summary")
|
||||
async def get_expression_stats(authorization: Optional[str] = Header(None)):
|
||||
async def get_expression_stats(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
获取表达方式统计数据
|
||||
|
||||
@@ -481,7 +476,7 @@ async def get_expression_stats(authorization: Optional[str] = Header(None)):
|
||||
统计数据
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
total = Expression.select().count()
|
||||
|
||||
|
||||
@@ -602,9 +602,9 @@ class GitMirrorService:
|
||||
# 执行 git clone(在线程池中运行以避免阻塞)
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
def run_git_clone():
|
||||
def run_git_clone(clone_cmd=cmd):
|
||||
return subprocess.run(
|
||||
cmd,
|
||||
clone_cmd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=300, # 5分钟超时
|
||||
|
||||
547
src/webui/jargon_routes.py
Normal file
547
src/webui/jargon_routes.py
Normal file
@@ -0,0 +1,547 @@
|
||||
"""黑话(俚语)管理路由"""
|
||||
|
||||
import json
|
||||
from typing import Optional, List, Annotated
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from pydantic import BaseModel, Field
|
||||
from peewee import fn
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Jargon, ChatStreams
|
||||
|
||||
logger = get_logger("webui.jargon")
|
||||
|
||||
router = APIRouter(prefix="/jargon", tags=["Jargon"])
|
||||
|
||||
|
||||
# ==================== 辅助函数 ====================
|
||||
|
||||
|
||||
def parse_chat_id_to_stream_ids(chat_id_str: str) -> List[str]:
|
||||
"""
|
||||
解析 chat_id 字段,提取所有 stream_id
|
||||
chat_id 格式: [["stream_id", user_id], ...] 或直接是 stream_id 字符串
|
||||
"""
|
||||
if not chat_id_str:
|
||||
return []
|
||||
|
||||
try:
|
||||
# 尝试解析为 JSON
|
||||
parsed = json.loads(chat_id_str)
|
||||
if isinstance(parsed, list):
|
||||
# 格式: [["stream_id", user_id], ...]
|
||||
stream_ids = []
|
||||
for item in parsed:
|
||||
if isinstance(item, list) and len(item) >= 1:
|
||||
stream_ids.append(str(item[0]))
|
||||
return stream_ids
|
||||
else:
|
||||
# 其他格式,返回原始字符串
|
||||
return [chat_id_str]
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
# 不是有效的 JSON,可能是直接的 stream_id
|
||||
return [chat_id_str]
|
||||
|
||||
|
||||
def get_display_name_for_chat_id(chat_id_str: str) -> str:
|
||||
"""
|
||||
获取 chat_id 的显示名称
|
||||
尝试解析 JSON 并查询 ChatStreams 表获取群聊名称
|
||||
"""
|
||||
stream_ids = parse_chat_id_to_stream_ids(chat_id_str)
|
||||
|
||||
if not stream_ids:
|
||||
return chat_id_str
|
||||
|
||||
# 查询所有 stream_id 对应的名称
|
||||
names = []
|
||||
for stream_id in stream_ids:
|
||||
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == stream_id)
|
||||
if chat_stream and chat_stream.group_name:
|
||||
names.append(chat_stream.group_name)
|
||||
else:
|
||||
# 如果没找到,显示截断的 stream_id
|
||||
names.append(stream_id[:8] + "..." if len(stream_id) > 8 else stream_id)
|
||||
|
||||
return ", ".join(names) if names else chat_id_str
|
||||
|
||||
|
||||
# ==================== 请求/响应模型 ====================
|
||||
|
||||
|
||||
class JargonResponse(BaseModel):
|
||||
"""黑话信息响应"""
|
||||
|
||||
id: int
|
||||
content: str
|
||||
raw_content: Optional[str] = None
|
||||
meaning: Optional[str] = None
|
||||
chat_id: str
|
||||
stream_id: Optional[str] = None # 解析后的 stream_id,用于前端编辑时匹配
|
||||
chat_name: Optional[str] = None # 解析后的聊天名称,用于前端显示
|
||||
is_global: bool = False
|
||||
count: int = 0
|
||||
is_jargon: Optional[bool] = None
|
||||
is_complete: bool = False
|
||||
inference_with_context: Optional[str] = None
|
||||
inference_content_only: Optional[str] = None
|
||||
|
||||
|
||||
class JargonListResponse(BaseModel):
|
||||
"""黑话列表响应"""
|
||||
|
||||
success: bool = True
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
data: List[JargonResponse]
|
||||
|
||||
|
||||
class JargonDetailResponse(BaseModel):
|
||||
"""黑话详情响应"""
|
||||
|
||||
success: bool = True
|
||||
data: JargonResponse
|
||||
|
||||
|
||||
class JargonCreateRequest(BaseModel):
|
||||
"""黑话创建请求"""
|
||||
|
||||
content: str = Field(..., description="黑话内容")
|
||||
raw_content: Optional[str] = Field(None, description="原始内容")
|
||||
meaning: Optional[str] = Field(None, description="含义")
|
||||
chat_id: str = Field(..., description="聊天ID")
|
||||
is_global: bool = Field(False, description="是否全局")
|
||||
|
||||
|
||||
class JargonUpdateRequest(BaseModel):
|
||||
"""黑话更新请求"""
|
||||
|
||||
content: Optional[str] = None
|
||||
raw_content: Optional[str] = None
|
||||
meaning: Optional[str] = None
|
||||
chat_id: Optional[str] = None
|
||||
is_global: Optional[bool] = None
|
||||
is_jargon: Optional[bool] = None
|
||||
|
||||
|
||||
class JargonCreateResponse(BaseModel):
|
||||
"""黑话创建响应"""
|
||||
|
||||
success: bool = True
|
||||
message: str
|
||||
data: JargonResponse
|
||||
|
||||
|
||||
class JargonUpdateResponse(BaseModel):
|
||||
"""黑话更新响应"""
|
||||
|
||||
success: bool = True
|
||||
message: str
|
||||
data: Optional[JargonResponse] = None
|
||||
|
||||
|
||||
class JargonDeleteResponse(BaseModel):
|
||||
"""黑话删除响应"""
|
||||
|
||||
success: bool = True
|
||||
message: str
|
||||
deleted_count: int = 0
|
||||
|
||||
|
||||
class BatchDeleteRequest(BaseModel):
|
||||
"""批量删除请求"""
|
||||
|
||||
ids: List[int] = Field(..., description="要删除的黑话ID列表")
|
||||
|
||||
|
||||
class JargonStatsResponse(BaseModel):
|
||||
"""黑话统计响应"""
|
||||
|
||||
success: bool = True
|
||||
data: dict
|
||||
|
||||
|
||||
class ChatInfoResponse(BaseModel):
|
||||
"""聊天信息响应"""
|
||||
|
||||
chat_id: str
|
||||
chat_name: str
|
||||
platform: Optional[str] = None
|
||||
is_group: bool = False
|
||||
|
||||
|
||||
class ChatListResponse(BaseModel):
|
||||
"""聊天列表响应"""
|
||||
|
||||
success: bool = True
|
||||
data: List[ChatInfoResponse]
|
||||
|
||||
|
||||
# ==================== 工具函数 ====================
|
||||
|
||||
|
||||
def jargon_to_dict(jargon: Jargon) -> dict:
|
||||
"""将 Jargon ORM 对象转换为字典"""
|
||||
# 解析 chat_id 获取显示名称和 stream_id
|
||||
chat_name = get_display_name_for_chat_id(jargon.chat_id) if jargon.chat_id else None
|
||||
stream_ids = parse_chat_id_to_stream_ids(jargon.chat_id) if jargon.chat_id else []
|
||||
stream_id = stream_ids[0] if stream_ids else None
|
||||
|
||||
return {
|
||||
"id": jargon.id,
|
||||
"content": jargon.content,
|
||||
"raw_content": jargon.raw_content,
|
||||
"meaning": jargon.meaning,
|
||||
"chat_id": jargon.chat_id,
|
||||
"stream_id": stream_id,
|
||||
"chat_name": chat_name,
|
||||
"is_global": jargon.is_global,
|
||||
"count": jargon.count,
|
||||
"is_jargon": jargon.is_jargon,
|
||||
"is_complete": jargon.is_complete,
|
||||
"inference_with_context": jargon.inference_with_context,
|
||||
"inference_content_only": jargon.inference_content_only,
|
||||
}
|
||||
|
||||
|
||||
# ==================== API 端点 ====================
|
||||
|
||||
|
||||
@router.get("/list", response_model=JargonListResponse)
|
||||
async def get_jargon_list(
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
chat_id: Optional[str] = Query(None, description="按聊天ID筛选"),
|
||||
is_jargon: Optional[bool] = Query(None, description="按是否是黑话筛选"),
|
||||
is_global: Optional[bool] = Query(None, description="按是否全局筛选"),
|
||||
):
|
||||
"""获取黑话列表"""
|
||||
try:
|
||||
# 构建查询
|
||||
query = Jargon.select()
|
||||
|
||||
# 搜索过滤
|
||||
if search:
|
||||
query = query.where(
|
||||
(Jargon.content.contains(search))
|
||||
| (Jargon.meaning.contains(search))
|
||||
| (Jargon.raw_content.contains(search))
|
||||
)
|
||||
|
||||
# 按聊天ID筛选(使用 contains 匹配,因为 chat_id 是 JSON 格式)
|
||||
if chat_id:
|
||||
# 从传入的 chat_id 中解析出 stream_id
|
||||
stream_ids = parse_chat_id_to_stream_ids(chat_id)
|
||||
if stream_ids:
|
||||
# 使用第一个 stream_id 进行模糊匹配
|
||||
query = query.where(Jargon.chat_id.contains(stream_ids[0]))
|
||||
else:
|
||||
# 如果无法解析,使用精确匹配
|
||||
query = query.where(Jargon.chat_id == chat_id)
|
||||
|
||||
# 按是否是黑话筛选
|
||||
if is_jargon is not None:
|
||||
query = query.where(Jargon.is_jargon == is_jargon)
|
||||
|
||||
# 按是否全局筛选
|
||||
if is_global is not None:
|
||||
query = query.where(Jargon.is_global == is_global)
|
||||
|
||||
# 获取总数
|
||||
total = query.count()
|
||||
|
||||
# 分页和排序(按使用次数降序)
|
||||
query = query.order_by(Jargon.count.desc(), Jargon.id.desc())
|
||||
query = query.paginate(page, page_size)
|
||||
|
||||
# 转换为响应格式
|
||||
data = [jargon_to_dict(j) for j in query]
|
||||
|
||||
return JargonListResponse(
|
||||
success=True,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
data=data,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取黑话列表失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取黑话列表失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/chats", response_model=ChatListResponse)
|
||||
async def get_chat_list():
|
||||
"""获取所有有黑话记录的聊天列表"""
|
||||
try:
|
||||
# 获取所有不同的 chat_id
|
||||
chat_ids = (
|
||||
Jargon.select(Jargon.chat_id)
|
||||
.distinct()
|
||||
.where(Jargon.chat_id.is_null(False))
|
||||
)
|
||||
|
||||
chat_id_list = [j.chat_id for j in chat_ids if j.chat_id]
|
||||
|
||||
# 用于按 stream_id 去重
|
||||
seen_stream_ids: set[str] = set()
|
||||
|
||||
for chat_id in chat_id_list:
|
||||
stream_ids = parse_chat_id_to_stream_ids(chat_id)
|
||||
if stream_ids:
|
||||
seen_stream_ids.add(stream_ids[0])
|
||||
|
||||
result = []
|
||||
for stream_id in seen_stream_ids:
|
||||
# 尝试从 ChatStreams 表获取聊天名称
|
||||
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == stream_id)
|
||||
if chat_stream:
|
||||
result.append(
|
||||
ChatInfoResponse(
|
||||
chat_id=stream_id, # 使用 stream_id,方便筛选匹配
|
||||
chat_name=chat_stream.group_name or stream_id,
|
||||
platform=chat_stream.platform,
|
||||
is_group=True,
|
||||
)
|
||||
)
|
||||
else:
|
||||
result.append(
|
||||
ChatInfoResponse(
|
||||
chat_id=stream_id, # 使用 stream_id
|
||||
chat_name=stream_id[:8] + "..." if len(stream_id) > 8 else stream_id,
|
||||
platform=None,
|
||||
is_group=False,
|
||||
)
|
||||
)
|
||||
|
||||
return ChatListResponse(success=True, data=result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取聊天列表失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取聊天列表失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/stats/summary", response_model=JargonStatsResponse)
|
||||
async def get_jargon_stats():
|
||||
"""获取黑话统计数据"""
|
||||
try:
|
||||
# 总数量
|
||||
total = Jargon.select().count()
|
||||
|
||||
# 已确认是黑话的数量
|
||||
confirmed_jargon = Jargon.select().where(Jargon.is_jargon).count()
|
||||
|
||||
# 已确认不是黑话的数量
|
||||
confirmed_not_jargon = Jargon.select().where(~Jargon.is_jargon).count()
|
||||
|
||||
# 未判定的数量
|
||||
pending = Jargon.select().where(Jargon.is_jargon.is_null()).count()
|
||||
|
||||
# 全局黑话数量
|
||||
global_count = Jargon.select().where(Jargon.is_global).count()
|
||||
|
||||
# 已完成推断的数量
|
||||
complete_count = Jargon.select().where(Jargon.is_complete).count()
|
||||
|
||||
# 关联的聊天数量
|
||||
chat_count = (
|
||||
Jargon.select(Jargon.chat_id)
|
||||
.distinct()
|
||||
.where(Jargon.chat_id.is_null(False))
|
||||
.count()
|
||||
)
|
||||
|
||||
# 按聊天统计 TOP 5
|
||||
top_chats = (
|
||||
Jargon.select(Jargon.chat_id, fn.COUNT(Jargon.id).alias("count"))
|
||||
.group_by(Jargon.chat_id)
|
||||
.order_by(fn.COUNT(Jargon.id).desc())
|
||||
.limit(5)
|
||||
)
|
||||
top_chats_dict = {j.chat_id: j.count for j in top_chats if j.chat_id}
|
||||
|
||||
return JargonStatsResponse(
|
||||
success=True,
|
||||
data={
|
||||
"total": total,
|
||||
"confirmed_jargon": confirmed_jargon,
|
||||
"confirmed_not_jargon": confirmed_not_jargon,
|
||||
"pending": pending,
|
||||
"global_count": global_count,
|
||||
"complete_count": complete_count,
|
||||
"chat_count": chat_count,
|
||||
"top_chats": top_chats_dict,
|
||||
},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取黑话统计失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取黑话统计失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/{jargon_id}", response_model=JargonDetailResponse)
|
||||
async def get_jargon_detail(jargon_id: int):
|
||||
"""获取黑话详情"""
|
||||
try:
|
||||
jargon = Jargon.get_or_none(Jargon.id == jargon_id)
|
||||
if not jargon:
|
||||
raise HTTPException(status_code=404, detail="黑话不存在")
|
||||
|
||||
return JargonDetailResponse(success=True, data=jargon_to_dict(jargon))
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"获取黑话详情失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取黑话详情失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post("/", response_model=JargonCreateResponse)
|
||||
async def create_jargon(request: JargonCreateRequest):
|
||||
"""创建黑话"""
|
||||
try:
|
||||
# 检查是否已存在相同内容的黑话
|
||||
existing = Jargon.get_or_none(
|
||||
(Jargon.content == request.content) & (Jargon.chat_id == request.chat_id)
|
||||
)
|
||||
if existing:
|
||||
raise HTTPException(status_code=400, detail="该聊天中已存在相同内容的黑话")
|
||||
|
||||
# 创建黑话
|
||||
jargon = Jargon.create(
|
||||
content=request.content,
|
||||
raw_content=request.raw_content,
|
||||
meaning=request.meaning,
|
||||
chat_id=request.chat_id,
|
||||
is_global=request.is_global,
|
||||
count=0,
|
||||
is_jargon=None,
|
||||
is_complete=False,
|
||||
)
|
||||
|
||||
logger.info(f"创建黑话成功: id={jargon.id}, content={request.content}")
|
||||
|
||||
return JargonCreateResponse(
|
||||
success=True,
|
||||
message="创建成功",
|
||||
data=jargon_to_dict(jargon),
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"创建黑话失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"创建黑话失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.patch("/{jargon_id}", response_model=JargonUpdateResponse)
|
||||
async def update_jargon(jargon_id: int, request: JargonUpdateRequest):
|
||||
"""更新黑话(增量更新)"""
|
||||
try:
|
||||
jargon = Jargon.get_or_none(Jargon.id == jargon_id)
|
||||
if not jargon:
|
||||
raise HTTPException(status_code=404, detail="黑话不存在")
|
||||
|
||||
# 增量更新字段
|
||||
update_data = request.model_dump(exclude_unset=True)
|
||||
if update_data:
|
||||
for field, value in update_data.items():
|
||||
if value is not None or field in ["meaning", "raw_content", "is_jargon"]:
|
||||
setattr(jargon, field, value)
|
||||
jargon.save()
|
||||
|
||||
logger.info(f"更新黑话成功: id={jargon_id}")
|
||||
|
||||
return JargonUpdateResponse(
|
||||
success=True,
|
||||
message="更新成功",
|
||||
data=jargon_to_dict(jargon),
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"更新黑话失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"更新黑话失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.delete("/{jargon_id}", response_model=JargonDeleteResponse)
|
||||
async def delete_jargon(jargon_id: int):
|
||||
"""删除黑话"""
|
||||
try:
|
||||
jargon = Jargon.get_or_none(Jargon.id == jargon_id)
|
||||
if not jargon:
|
||||
raise HTTPException(status_code=404, detail="黑话不存在")
|
||||
|
||||
content = jargon.content
|
||||
jargon.delete_instance()
|
||||
|
||||
logger.info(f"删除黑话成功: id={jargon_id}, content={content}")
|
||||
|
||||
return JargonDeleteResponse(
|
||||
success=True,
|
||||
message="删除成功",
|
||||
deleted_count=1,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"删除黑话失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"删除黑话失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post("/batch/delete", response_model=JargonDeleteResponse)
|
||||
async def batch_delete_jargons(request: BatchDeleteRequest):
|
||||
"""批量删除黑话"""
|
||||
try:
|
||||
if not request.ids:
|
||||
raise HTTPException(status_code=400, detail="ID列表不能为空")
|
||||
|
||||
deleted_count = Jargon.delete().where(Jargon.id.in_(request.ids)).execute()
|
||||
|
||||
logger.info(f"批量删除黑话成功: 删除了 {deleted_count} 条记录")
|
||||
|
||||
return JargonDeleteResponse(
|
||||
success=True,
|
||||
message=f"成功删除 {deleted_count} 条黑话",
|
||||
deleted_count=deleted_count,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"批量删除黑话失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"批量删除黑话失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post("/batch/set-jargon", response_model=JargonUpdateResponse)
|
||||
async def batch_set_jargon_status(
|
||||
ids: Annotated[List[int], Query(description="黑话ID列表")],
|
||||
is_jargon: Annotated[bool, Query(description="是否是黑话")],
|
||||
):
|
||||
"""批量设置黑话状态"""
|
||||
try:
|
||||
if not ids:
|
||||
raise HTTPException(status_code=400, detail="ID列表不能为空")
|
||||
|
||||
updated_count = (
|
||||
Jargon.update(is_jargon=is_jargon)
|
||||
.where(Jargon.id.in_(ids))
|
||||
.execute()
|
||||
)
|
||||
|
||||
logger.info(f"批量更新黑话状态成功: 更新了 {updated_count} 条记录,is_jargon={is_jargon}")
|
||||
|
||||
return JargonUpdateResponse(
|
||||
success=True,
|
||||
message=f"成功更新 {updated_count} 条黑话状态",
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"批量更新黑话状态失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"批量更新黑话状态失败: {str(e)}") from e
|
||||
@@ -1,4 +1,5 @@
|
||||
"""知识库图谱可视化 API 路由"""
|
||||
|
||||
from typing import List, Optional
|
||||
from fastapi import APIRouter, Query
|
||||
from pydantic import BaseModel
|
||||
@@ -11,6 +12,7 @@ router = APIRouter(prefix="/api/webui/knowledge", tags=["knowledge"])
|
||||
|
||||
class KnowledgeNode(BaseModel):
|
||||
"""知识节点"""
|
||||
|
||||
id: str
|
||||
type: str # 'entity' or 'paragraph'
|
||||
content: str
|
||||
@@ -19,6 +21,7 @@ class KnowledgeNode(BaseModel):
|
||||
|
||||
class KnowledgeEdge(BaseModel):
|
||||
"""知识边"""
|
||||
|
||||
source: str
|
||||
target: str
|
||||
weight: float
|
||||
@@ -28,12 +31,14 @@ class KnowledgeEdge(BaseModel):
|
||||
|
||||
class KnowledgeGraph(BaseModel):
|
||||
"""知识图谱"""
|
||||
|
||||
nodes: List[KnowledgeNode]
|
||||
edges: List[KnowledgeEdge]
|
||||
|
||||
|
||||
class KnowledgeStats(BaseModel):
|
||||
"""知识库统计信息"""
|
||||
|
||||
total_nodes: int
|
||||
total_edges: int
|
||||
entity_nodes: int
|
||||
@@ -45,7 +50,7 @@ def _load_kg_manager():
|
||||
"""延迟加载 KGManager"""
|
||||
try:
|
||||
from src.chat.knowledge.kg_manager import KGManager
|
||||
|
||||
|
||||
kg_manager = KGManager()
|
||||
kg_manager.load_from_file()
|
||||
return kg_manager
|
||||
@@ -58,31 +63,26 @@ def _convert_graph_to_json(kg_manager) -> KnowledgeGraph:
|
||||
"""将 DiGraph 转换为 JSON 格式"""
|
||||
if kg_manager is None or kg_manager.graph is None:
|
||||
return KnowledgeGraph(nodes=[], edges=[])
|
||||
|
||||
|
||||
graph = kg_manager.graph
|
||||
nodes = []
|
||||
edges = []
|
||||
|
||||
|
||||
# 转换节点
|
||||
node_list = graph.get_node_list()
|
||||
for node_id in node_list:
|
||||
try:
|
||||
node_data = graph[node_id]
|
||||
# 节点类型: "ent" -> "entity", "pg" -> "paragraph"
|
||||
node_type = "entity" if ('type' in node_data and node_data['type'] == 'ent') else "paragraph"
|
||||
content = node_data['content'] if 'content' in node_data else node_id
|
||||
create_time = node_data['create_time'] if 'create_time' in node_data else None
|
||||
|
||||
nodes.append(KnowledgeNode(
|
||||
id=node_id,
|
||||
type=node_type,
|
||||
content=content,
|
||||
create_time=create_time
|
||||
))
|
||||
node_type = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph"
|
||||
content = node_data["content"] if "content" in node_data else node_id
|
||||
create_time = node_data["create_time"] if "create_time" in node_data else None
|
||||
|
||||
nodes.append(KnowledgeNode(id=node_id, type=node_type, content=content, create_time=create_time))
|
||||
except Exception as e:
|
||||
logger.warning(f"跳过节点 {node_id}: {e}")
|
||||
continue
|
||||
|
||||
|
||||
# 转换边
|
||||
edge_list = graph.get_edge_list()
|
||||
for edge_tuple in edge_list:
|
||||
@@ -91,37 +91,35 @@ def _convert_graph_to_json(kg_manager) -> KnowledgeGraph:
|
||||
source, target = edge_tuple[0], edge_tuple[1]
|
||||
# 通过 graph[source, target] 获取边的属性数据
|
||||
edge_data = graph[source, target]
|
||||
|
||||
|
||||
# edge_data 支持 [] 操作符但不支持 .get()
|
||||
weight = edge_data['weight'] if 'weight' in edge_data else 1.0
|
||||
create_time = edge_data['create_time'] if 'create_time' in edge_data else None
|
||||
update_time = edge_data['update_time'] if 'update_time' in edge_data else None
|
||||
|
||||
edges.append(KnowledgeEdge(
|
||||
source=source,
|
||||
target=target,
|
||||
weight=weight,
|
||||
create_time=create_time,
|
||||
update_time=update_time
|
||||
))
|
||||
weight = edge_data["weight"] if "weight" in edge_data else 1.0
|
||||
create_time = edge_data["create_time"] if "create_time" in edge_data else None
|
||||
update_time = edge_data["update_time"] if "update_time" in edge_data else None
|
||||
|
||||
edges.append(
|
||||
KnowledgeEdge(
|
||||
source=source, target=target, weight=weight, create_time=create_time, update_time=update_time
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"跳过边 {edge_tuple}: {e}")
|
||||
continue
|
||||
|
||||
|
||||
return KnowledgeGraph(nodes=nodes, edges=edges)
|
||||
|
||||
|
||||
@router.get("/graph", response_model=KnowledgeGraph)
|
||||
async def get_knowledge_graph(
|
||||
limit: int = Query(100, ge=1, le=10000, description="返回的最大节点数"),
|
||||
node_type: str = Query("all", description="节点类型过滤: all, entity, paragraph")
|
||||
node_type: str = Query("all", description="节点类型过滤: all, entity, paragraph"),
|
||||
):
|
||||
"""获取知识图谱(限制节点数量)
|
||||
|
||||
|
||||
Args:
|
||||
limit: 返回的最大节点数,默认 100,最大 10000
|
||||
node_type: 节点类型过滤 - all(全部), entity(实体), paragraph(段落)
|
||||
|
||||
|
||||
Returns:
|
||||
KnowledgeGraph: 包含指定数量节点和相关边的知识图谱
|
||||
"""
|
||||
@@ -130,46 +128,43 @@ async def get_knowledge_graph(
|
||||
if kg_manager is None:
|
||||
logger.warning("KGManager 未初始化,返回空图谱")
|
||||
return KnowledgeGraph(nodes=[], edges=[])
|
||||
|
||||
|
||||
graph = kg_manager.graph
|
||||
all_node_list = graph.get_node_list()
|
||||
|
||||
|
||||
# 按类型过滤节点
|
||||
if node_type == "entity":
|
||||
all_node_list = [n for n in all_node_list if n in graph and 'type' in graph[n] and graph[n]['type'] == 'ent']
|
||||
all_node_list = [
|
||||
n for n in all_node_list if n in graph and "type" in graph[n] and graph[n]["type"] == "ent"
|
||||
]
|
||||
elif node_type == "paragraph":
|
||||
all_node_list = [n for n in all_node_list if n in graph and 'type' in graph[n] and graph[n]['type'] == 'pg']
|
||||
|
||||
all_node_list = [n for n in all_node_list if n in graph and "type" in graph[n] and graph[n]["type"] == "pg"]
|
||||
|
||||
# 限制节点数量
|
||||
total_nodes = len(all_node_list)
|
||||
if len(all_node_list) > limit:
|
||||
node_list = all_node_list[:limit]
|
||||
else:
|
||||
node_list = all_node_list
|
||||
|
||||
|
||||
logger.info(f"总节点数: {total_nodes}, 返回节点: {len(node_list)} (limit={limit}, type={node_type})")
|
||||
|
||||
|
||||
# 转换节点
|
||||
nodes = []
|
||||
node_ids = set()
|
||||
for node_id in node_list:
|
||||
try:
|
||||
node_data = graph[node_id]
|
||||
node_type_val = "entity" if ('type' in node_data and node_data['type'] == 'ent') else "paragraph"
|
||||
content = node_data['content'] if 'content' in node_data else node_id
|
||||
create_time = node_data['create_time'] if 'create_time' in node_data else None
|
||||
|
||||
nodes.append(KnowledgeNode(
|
||||
id=node_id,
|
||||
type=node_type_val,
|
||||
content=content,
|
||||
create_time=create_time
|
||||
))
|
||||
node_type_val = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph"
|
||||
content = node_data["content"] if "content" in node_data else node_id
|
||||
create_time = node_data["create_time"] if "create_time" in node_data else None
|
||||
|
||||
nodes.append(KnowledgeNode(id=node_id, type=node_type_val, content=content, create_time=create_time))
|
||||
node_ids.add(node_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"跳过节点 {node_id}: {e}")
|
||||
continue
|
||||
|
||||
|
||||
# 只获取涉及当前节点集的边(保证图的完整性)
|
||||
edges = []
|
||||
edge_list = graph.get_edge_list()
|
||||
@@ -179,27 +174,25 @@ async def get_knowledge_graph(
|
||||
# 只包含两端都在当前节点集中的边
|
||||
if source not in node_ids or target not in node_ids:
|
||||
continue
|
||||
|
||||
|
||||
edge_data = graph[source, target]
|
||||
weight = edge_data['weight'] if 'weight' in edge_data else 1.0
|
||||
create_time = edge_data['create_time'] if 'create_time' in edge_data else None
|
||||
update_time = edge_data['update_time'] if 'update_time' in edge_data else None
|
||||
|
||||
edges.append(KnowledgeEdge(
|
||||
source=source,
|
||||
target=target,
|
||||
weight=weight,
|
||||
create_time=create_time,
|
||||
update_time=update_time
|
||||
))
|
||||
weight = edge_data["weight"] if "weight" in edge_data else 1.0
|
||||
create_time = edge_data["create_time"] if "create_time" in edge_data else None
|
||||
update_time = edge_data["update_time"] if "update_time" in edge_data else None
|
||||
|
||||
edges.append(
|
||||
KnowledgeEdge(
|
||||
source=source, target=target, weight=weight, create_time=create_time, update_time=update_time
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"跳过边 {edge_tuple}: {e}")
|
||||
continue
|
||||
|
||||
|
||||
graph_data = KnowledgeGraph(nodes=nodes, edges=edges)
|
||||
logger.info(f"返回知识图谱: {len(nodes)} 个节点, {len(edges)} 条边")
|
||||
return graph_data
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取知识图谱失败: {e}", exc_info=True)
|
||||
return KnowledgeGraph(nodes=[], edges=[])
|
||||
@@ -208,71 +201,59 @@ async def get_knowledge_graph(
|
||||
@router.get("/stats", response_model=KnowledgeStats)
|
||||
async def get_knowledge_stats():
|
||||
"""获取知识库统计信息
|
||||
|
||||
|
||||
Returns:
|
||||
KnowledgeStats: 统计信息
|
||||
"""
|
||||
try:
|
||||
kg_manager = _load_kg_manager()
|
||||
if kg_manager is None or kg_manager.graph is None:
|
||||
return KnowledgeStats(
|
||||
total_nodes=0,
|
||||
total_edges=0,
|
||||
entity_nodes=0,
|
||||
paragraph_nodes=0,
|
||||
avg_connections=0.0
|
||||
)
|
||||
|
||||
return KnowledgeStats(total_nodes=0, total_edges=0, entity_nodes=0, paragraph_nodes=0, avg_connections=0.0)
|
||||
|
||||
graph = kg_manager.graph
|
||||
node_list = graph.get_node_list()
|
||||
edge_list = graph.get_edge_list()
|
||||
|
||||
|
||||
total_nodes = len(node_list)
|
||||
total_edges = len(edge_list)
|
||||
|
||||
|
||||
# 统计节点类型
|
||||
entity_nodes = 0
|
||||
paragraph_nodes = 0
|
||||
for node_id in node_list:
|
||||
try:
|
||||
node_data = graph[node_id]
|
||||
node_type = node_data['type'] if 'type' in node_data else 'ent'
|
||||
if node_type == 'ent':
|
||||
node_type = node_data["type"] if "type" in node_data else "ent"
|
||||
if node_type == "ent":
|
||||
entity_nodes += 1
|
||||
elif node_type == 'pg':
|
||||
elif node_type == "pg":
|
||||
paragraph_nodes += 1
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
|
||||
# 计算平均连接数
|
||||
avg_connections = (total_edges * 2) / total_nodes if total_nodes > 0 else 0.0
|
||||
|
||||
|
||||
return KnowledgeStats(
|
||||
total_nodes=total_nodes,
|
||||
total_edges=total_edges,
|
||||
entity_nodes=entity_nodes,
|
||||
paragraph_nodes=paragraph_nodes,
|
||||
avg_connections=round(avg_connections, 2)
|
||||
avg_connections=round(avg_connections, 2),
|
||||
)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取统计信息失败: {e}", exc_info=True)
|
||||
return KnowledgeStats(
|
||||
total_nodes=0,
|
||||
total_edges=0,
|
||||
entity_nodes=0,
|
||||
paragraph_nodes=0,
|
||||
avg_connections=0.0
|
||||
)
|
||||
return KnowledgeStats(total_nodes=0, total_edges=0, entity_nodes=0, paragraph_nodes=0, avg_connections=0.0)
|
||||
|
||||
|
||||
@router.get("/search", response_model=List[KnowledgeNode])
|
||||
async def search_knowledge_node(query: str = Query(..., min_length=1)):
|
||||
"""搜索知识节点
|
||||
|
||||
|
||||
Args:
|
||||
query: 搜索关键词
|
||||
|
||||
|
||||
Returns:
|
||||
List[KnowledgeNode]: 匹配的节点列表
|
||||
"""
|
||||
@@ -280,33 +261,28 @@ async def search_knowledge_node(query: str = Query(..., min_length=1)):
|
||||
kg_manager = _load_kg_manager()
|
||||
if kg_manager is None or kg_manager.graph is None:
|
||||
return []
|
||||
|
||||
|
||||
graph = kg_manager.graph
|
||||
node_list = graph.get_node_list()
|
||||
results = []
|
||||
query_lower = query.lower()
|
||||
|
||||
|
||||
# 在节点内容中搜索
|
||||
for node_id in node_list:
|
||||
try:
|
||||
node_data = graph[node_id]
|
||||
content = node_data['content'] if 'content' in node_data else node_id
|
||||
node_type = "entity" if ('type' in node_data and node_data['type'] == 'ent') else "paragraph"
|
||||
|
||||
content = node_data["content"] if "content" in node_data else node_id
|
||||
node_type = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph"
|
||||
|
||||
if query_lower in content.lower() or query_lower in node_id.lower():
|
||||
create_time = node_data['create_time'] if 'create_time' in node_data else None
|
||||
results.append(KnowledgeNode(
|
||||
id=node_id,
|
||||
type=node_type,
|
||||
content=content,
|
||||
create_time=create_time
|
||||
))
|
||||
create_time = node_data["create_time"] if "create_time" in node_data else None
|
||||
results.append(KnowledgeNode(id=node_id, type=node_type, content=content, create_time=create_time))
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
|
||||
logger.info(f"搜索 '{query}' 找到 {len(results)} 个节点")
|
||||
return results[:50] # 限制返回数量
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"搜索节点失败: {e}", exc_info=True)
|
||||
return []
|
||||
|
||||
@@ -43,25 +43,27 @@ def _normalize_url(url: str) -> str:
|
||||
def _parse_openai_response(data: dict) -> list[dict]:
|
||||
"""
|
||||
解析 OpenAI 格式的模型列表响应
|
||||
|
||||
|
||||
格式: { "data": [{ "id": "gpt-4", "object": "model", ... }] }
|
||||
"""
|
||||
models = []
|
||||
if "data" in data and isinstance(data["data"], list):
|
||||
for model in data["data"]:
|
||||
if isinstance(model, dict) and "id" in model:
|
||||
models.append({
|
||||
"id": model["id"],
|
||||
"name": model.get("name") or model["id"],
|
||||
"owned_by": model.get("owned_by", ""),
|
||||
})
|
||||
models.append(
|
||||
{
|
||||
"id": model["id"],
|
||||
"name": model.get("name") or model["id"],
|
||||
"owned_by": model.get("owned_by", ""),
|
||||
}
|
||||
)
|
||||
return models
|
||||
|
||||
|
||||
def _parse_gemini_response(data: dict) -> list[dict]:
|
||||
"""
|
||||
解析 Gemini 格式的模型列表响应
|
||||
|
||||
|
||||
格式: { "models": [{ "name": "models/gemini-pro", "displayName": "Gemini Pro", ... }] }
|
||||
"""
|
||||
models = []
|
||||
@@ -72,11 +74,13 @@ def _parse_gemini_response(data: dict) -> list[dict]:
|
||||
model_id = model["name"]
|
||||
if model_id.startswith("models/"):
|
||||
model_id = model_id[7:] # 去掉 "models/" 前缀
|
||||
models.append({
|
||||
"id": model_id,
|
||||
"name": model.get("displayName") or model_id,
|
||||
"owned_by": "google",
|
||||
})
|
||||
models.append(
|
||||
{
|
||||
"id": model_id,
|
||||
"name": model.get("displayName") or model_id,
|
||||
"owned_by": "google",
|
||||
}
|
||||
)
|
||||
return models
|
||||
|
||||
|
||||
@@ -89,55 +93,54 @@ async def _fetch_models_from_provider(
|
||||
) -> list[dict]:
|
||||
"""
|
||||
从提供商 API 获取模型列表
|
||||
|
||||
|
||||
Args:
|
||||
base_url: 提供商的基础 URL
|
||||
api_key: API 密钥
|
||||
endpoint: 获取模型列表的端点
|
||||
parser: 响应解析器类型 ('openai' | 'gemini')
|
||||
client_type: 客户端类型 ('openai' | 'gemini')
|
||||
|
||||
|
||||
Returns:
|
||||
模型列表
|
||||
"""
|
||||
url = f"{_normalize_url(base_url)}{endpoint}"
|
||||
|
||||
|
||||
# 根据客户端类型设置请求头
|
||||
headers = {}
|
||||
params = {}
|
||||
|
||||
|
||||
if client_type == "gemini":
|
||||
# Gemini 使用 URL 参数传递 API Key
|
||||
params["key"] = api_key
|
||||
else:
|
||||
# OpenAI 兼容格式使用 Authorization 头
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(url, headers=headers, params=params)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
except httpx.TimeoutException:
|
||||
raise HTTPException(status_code=504, detail="请求超时,请稍后重试")
|
||||
except httpx.TimeoutException as e:
|
||||
raise HTTPException(status_code=504, detail="请求超时,请稍后重试") from e
|
||||
except httpx.HTTPStatusError as e:
|
||||
# 注意:使用 502 Bad Gateway 而不是原始的 401/403,
|
||||
# 因为前端的 fetchWithAuth 会把 401 当作 WebUI 认证失败处理
|
||||
if e.response.status_code == 401:
|
||||
raise HTTPException(status_code=502, detail="API Key 无效或已过期")
|
||||
raise HTTPException(status_code=502, detail="API Key 无效或已过期") from e
|
||||
elif e.response.status_code == 403:
|
||||
raise HTTPException(status_code=502, detail="没有权限访问模型列表,请检查 API Key 权限")
|
||||
raise HTTPException(status_code=502, detail="没有权限访问模型列表,请检查 API Key 权限") from e
|
||||
elif e.response.status_code == 404:
|
||||
raise HTTPException(status_code=502, detail="该提供商不支持获取模型列表")
|
||||
raise HTTPException(status_code=502, detail="该提供商不支持获取模型列表") from e
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=f"上游服务请求失败 ({e.response.status_code}): {e.response.text[:200]}"
|
||||
)
|
||||
status_code=502, detail=f"上游服务请求失败 ({e.response.status_code}): {e.response.text[:200]}"
|
||||
) from e
|
||||
except Exception as e:
|
||||
logger.error(f"获取模型列表失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取模型列表失败: {str(e)}")
|
||||
|
||||
raise HTTPException(status_code=500, detail=f"获取模型列表失败: {str(e)}") from e
|
||||
|
||||
# 根据解析器类型解析响应
|
||||
if parser == "openai":
|
||||
return _parse_openai_response(data)
|
||||
@@ -150,26 +153,26 @@ async def _fetch_models_from_provider(
|
||||
def _get_provider_config(provider_name: str) -> Optional[dict]:
|
||||
"""
|
||||
从 model_config.toml 获取指定提供商的配置
|
||||
|
||||
|
||||
Args:
|
||||
provider_name: 提供商名称
|
||||
|
||||
|
||||
Returns:
|
||||
提供商配置,如果未找到则返回 None
|
||||
"""
|
||||
config_path = os.path.join(CONFIG_DIR, "model_config.toml")
|
||||
if not os.path.exists(config_path):
|
||||
return None
|
||||
|
||||
|
||||
try:
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config_data = tomlkit.load(f)
|
||||
|
||||
|
||||
providers = config_data.get("api_providers", [])
|
||||
for provider in providers:
|
||||
if provider.get("name") == provider_name:
|
||||
return dict(provider)
|
||||
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"读取提供商配置失败: {e}")
|
||||
@@ -184,23 +187,23 @@ async def get_provider_models(
|
||||
):
|
||||
"""
|
||||
获取指定提供商的可用模型列表
|
||||
|
||||
|
||||
通过提供商名称查找配置,然后请求对应的模型列表端点
|
||||
"""
|
||||
# 获取提供商配置
|
||||
provider_config = _get_provider_config(provider_name)
|
||||
if not provider_config:
|
||||
raise HTTPException(status_code=404, detail=f"未找到提供商: {provider_name}")
|
||||
|
||||
|
||||
base_url = provider_config.get("base_url")
|
||||
api_key = provider_config.get("api_key")
|
||||
client_type = provider_config.get("client_type", "openai")
|
||||
|
||||
|
||||
if not base_url:
|
||||
raise HTTPException(status_code=400, detail="提供商配置缺少 base_url")
|
||||
if not api_key:
|
||||
raise HTTPException(status_code=400, detail="提供商配置缺少 api_key")
|
||||
|
||||
|
||||
# 获取模型列表
|
||||
models = await _fetch_models_from_provider(
|
||||
base_url=base_url,
|
||||
@@ -209,7 +212,7 @@ async def get_provider_models(
|
||||
parser=parser,
|
||||
client_type=client_type,
|
||||
)
|
||||
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"models": models,
|
||||
@@ -236,9 +239,132 @@ async def get_models_by_url(
|
||||
parser=parser,
|
||||
client_type=client_type,
|
||||
)
|
||||
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"models": models,
|
||||
"count": len(models),
|
||||
}
|
||||
|
||||
|
||||
@router.get("/test-connection")
|
||||
async def test_provider_connection(
|
||||
base_url: str = Query(..., description="提供商的基础 URL"),
|
||||
api_key: Optional[str] = Query(None, description="API Key(可选,用于验证 Key 有效性)"),
|
||||
):
|
||||
"""
|
||||
测试提供商连接状态
|
||||
|
||||
分两步测试:
|
||||
1. 网络连通性测试:向 base_url 发送请求,检查是否能连接
|
||||
2. API Key 验证(可选):如果提供了 api_key,尝试获取模型列表验证 Key 是否有效
|
||||
|
||||
返回:
|
||||
- network_ok: 网络是否连通
|
||||
- api_key_valid: API Key 是否有效(仅在提供 api_key 时返回)
|
||||
- latency_ms: 响应延迟(毫秒)
|
||||
- error: 错误信息(如果有)
|
||||
"""
|
||||
import time
|
||||
|
||||
base_url = _normalize_url(base_url)
|
||||
if not base_url:
|
||||
raise HTTPException(status_code=400, detail="base_url 不能为空")
|
||||
|
||||
result = {
|
||||
"network_ok": False,
|
||||
"api_key_valid": None,
|
||||
"latency_ms": None,
|
||||
"error": None,
|
||||
"http_status": None,
|
||||
}
|
||||
|
||||
# 第一步:测试网络连通性
|
||||
try:
|
||||
start_time = time.time()
|
||||
async with httpx.AsyncClient(timeout=10.0, follow_redirects=True) as client:
|
||||
# 尝试 GET 请求 base_url(不需要 API Key)
|
||||
response = await client.get(base_url)
|
||||
latency = (time.time() - start_time) * 1000
|
||||
|
||||
result["network_ok"] = True
|
||||
result["latency_ms"] = round(latency, 2)
|
||||
result["http_status"] = response.status_code
|
||||
|
||||
except httpx.ConnectError as e:
|
||||
result["error"] = f"连接失败:无法连接到服务器 ({str(e)})"
|
||||
return result
|
||||
except httpx.TimeoutException:
|
||||
result["error"] = "连接超时:服务器响应时间过长"
|
||||
return result
|
||||
except httpx.RequestError as e:
|
||||
result["error"] = f"请求错误:{str(e)}"
|
||||
return result
|
||||
except Exception as e:
|
||||
result["error"] = f"未知错误:{str(e)}"
|
||||
return result
|
||||
|
||||
# 第二步:如果提供了 API Key,验证其有效性
|
||||
if api_key:
|
||||
try:
|
||||
start_time = time.time()
|
||||
async with httpx.AsyncClient(timeout=15.0, follow_redirects=True) as client:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
# 尝试获取模型列表
|
||||
models_url = f"{base_url}/models"
|
||||
response = await client.get(models_url, headers=headers)
|
||||
|
||||
if response.status_code == 200:
|
||||
result["api_key_valid"] = True
|
||||
elif response.status_code in (401, 403):
|
||||
result["api_key_valid"] = False
|
||||
result["error"] = "API Key 无效或已过期"
|
||||
else:
|
||||
# 其他状态码,可能是端点不支持,但 Key 可能是有效的
|
||||
result["api_key_valid"] = None
|
||||
|
||||
except Exception as e:
|
||||
# API Key 验证失败不影响网络连通性结果
|
||||
logger.warning(f"API Key 验证失败: {e}")
|
||||
result["api_key_valid"] = None
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.post("/test-connection-by-name")
|
||||
async def test_provider_connection_by_name(
|
||||
provider_name: str = Query(..., description="提供商名称"),
|
||||
):
|
||||
"""
|
||||
通过提供商名称测试连接(从配置文件读取信息)
|
||||
"""
|
||||
# 读取配置文件
|
||||
model_config_path = os.path.join(CONFIG_DIR, "model_config.toml")
|
||||
if not os.path.exists(model_config_path):
|
||||
raise HTTPException(status_code=404, detail="配置文件不存在")
|
||||
|
||||
with open(model_config_path, "r", encoding="utf-8") as f:
|
||||
config = tomlkit.load(f)
|
||||
|
||||
# 查找提供商
|
||||
providers = config.get("api_providers", [])
|
||||
provider = None
|
||||
for p in providers:
|
||||
if p.get("name") == provider_name:
|
||||
provider = p
|
||||
break
|
||||
|
||||
if not provider:
|
||||
raise HTTPException(status_code=404, detail=f"未找到提供商: {provider_name}")
|
||||
|
||||
base_url = provider.get("base_url", "")
|
||||
api_key = provider.get("api_key", "")
|
||||
|
||||
if not base_url:
|
||||
raise HTTPException(status_code=400, detail="提供商配置缺少 base_url")
|
||||
|
||||
# 调用测试接口
|
||||
return await test_provider_connection(base_url=base_url, api_key=api_key if api_key else None)
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
"""人物信息管理 API 路由"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Header, Query
|
||||
from fastapi import APIRouter, HTTPException, Header, Query, Cookie
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, List, Dict
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import PersonInfo
|
||||
from .token_manager import get_token_manager
|
||||
from .auth import verify_auth_token_from_cookie_or_header
|
||||
import json
|
||||
import time
|
||||
|
||||
@@ -91,18 +91,12 @@ class BatchDeleteResponse(BaseModel):
|
||||
failed_ids: List[str] = []
|
||||
|
||||
|
||||
def verify_auth_token(authorization: Optional[str]) -> bool:
|
||||
"""验证认证 Token"""
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
token = authorization.replace("Bearer ", "")
|
||||
token_manager = get_token_manager()
|
||||
|
||||
if not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="Token 无效或已过期")
|
||||
|
||||
return True
|
||||
def verify_auth_token(
|
||||
maibot_session: Optional[str] = None,
|
||||
authorization: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""验证认证 Token,支持 Cookie 和 Header"""
|
||||
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||
|
||||
|
||||
def parse_group_nick_name(group_nick_name_str: Optional[str]) -> Optional[List[Dict[str, str]]]:
|
||||
@@ -141,6 +135,7 @@ async def get_person_list(
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
is_known: Optional[bool] = Query(None, description="是否已认识筛选"),
|
||||
platform: Optional[str] = Query(None, description="平台筛选"),
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
@@ -158,7 +153,7 @@ async def get_person_list(
|
||||
人物信息列表
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
# 构建查询
|
||||
query = PersonInfo.select()
|
||||
@@ -205,7 +200,7 @@ async def get_person_list(
|
||||
|
||||
|
||||
@router.get("/{person_id}", response_model=PersonDetailResponse)
|
||||
async def get_person_detail(person_id: str, authorization: Optional[str] = Header(None)):
|
||||
async def get_person_detail(person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
获取人物详细信息
|
||||
|
||||
@@ -217,7 +212,7 @@ async def get_person_detail(person_id: str, authorization: Optional[str] = Heade
|
||||
人物详细信息
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
|
||||
|
||||
@@ -234,7 +229,7 @@ async def get_person_detail(person_id: str, authorization: Optional[str] = Heade
|
||||
|
||||
|
||||
@router.patch("/{person_id}", response_model=PersonUpdateResponse)
|
||||
async def update_person(person_id: str, request: PersonUpdateRequest, authorization: Optional[str] = Header(None)):
|
||||
async def update_person(person_id: str, request: PersonUpdateRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
增量更新人物信息(只更新提供的字段)
|
||||
|
||||
@@ -247,7 +242,7 @@ async def update_person(person_id: str, request: PersonUpdateRequest, authorizat
|
||||
更新结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
|
||||
|
||||
@@ -283,7 +278,7 @@ async def update_person(person_id: str, request: PersonUpdateRequest, authorizat
|
||||
|
||||
|
||||
@router.delete("/{person_id}", response_model=PersonDeleteResponse)
|
||||
async def delete_person(person_id: str, authorization: Optional[str] = Header(None)):
|
||||
async def delete_person(person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
删除人物信息
|
||||
|
||||
@@ -295,7 +290,7 @@ async def delete_person(person_id: str, authorization: Optional[str] = Header(No
|
||||
删除结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
|
||||
|
||||
@@ -320,7 +315,7 @@ async def delete_person(person_id: str, authorization: Optional[str] = Header(No
|
||||
|
||||
|
||||
@router.get("/stats/summary")
|
||||
async def get_person_stats(authorization: Optional[str] = Header(None)):
|
||||
async def get_person_stats(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
获取人物信息统计数据
|
||||
|
||||
@@ -331,7 +326,7 @@ async def get_person_stats(authorization: Optional[str] = Header(None)):
|
||||
统计数据
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
total = PersonInfo.select().count()
|
||||
known = PersonInfo.select().where(PersonInfo.is_known).count()
|
||||
@@ -353,7 +348,7 @@ async def get_person_stats(authorization: Optional[str] = Header(None)):
|
||||
|
||||
|
||||
@router.post("/batch/delete", response_model=BatchDeleteResponse)
|
||||
async def batch_delete_persons(request: BatchDeleteRequest, authorization: Optional[str] = Header(None)):
|
||||
async def batch_delete_persons(request: BatchDeleteRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
批量删除人物信息
|
||||
|
||||
@@ -365,7 +360,7 @@ async def batch_delete_persons(request: BatchDeleteRequest, authorization: Optio
|
||||
批量删除结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
if not request.person_ids:
|
||||
raise HTTPException(status_code=400, detail="未提供要删除的人物ID")
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from fastapi import APIRouter, HTTPException, Header
|
||||
from fastapi import APIRouter, HTTPException, Header, Cookie
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional, List, Dict, Any
|
||||
from pathlib import Path
|
||||
import json
|
||||
from src.common.logger import get_logger
|
||||
from src.common.toml_utils import save_toml_with_format
|
||||
from src.config.config import MMC_VERSION
|
||||
from .git_mirror_service import get_git_mirror_service, set_update_progress_callback
|
||||
from .token_manager import get_token_manager
|
||||
@@ -18,6 +19,20 @@ router = APIRouter(prefix="/plugins", tags=["插件管理"])
|
||||
set_update_progress_callback(update_progress)
|
||||
|
||||
|
||||
def get_token_from_cookie_or_header(
|
||||
maibot_session: Optional[str] = None,
|
||||
authorization: Optional[str] = None,
|
||||
) -> Optional[str]:
|
||||
"""从 Cookie 或 Header 获取 token"""
|
||||
# 优先从 Cookie 获取
|
||||
if maibot_session:
|
||||
return maibot_session
|
||||
# 其次从 Header 获取
|
||||
if authorization and authorization.startswith("Bearer "):
|
||||
return authorization.replace("Bearer ", "")
|
||||
return None
|
||||
|
||||
|
||||
def parse_version(version_str: str) -> tuple[int, int, int]:
|
||||
"""
|
||||
解析版本号字符串
|
||||
@@ -29,8 +44,11 @@ def parse_version(version_str: str) -> tuple[int, int, int]:
|
||||
Returns:
|
||||
(major, minor, patch) 三元组
|
||||
"""
|
||||
# 移除 snapshot 等后缀
|
||||
base_version = version_str.split(".snapshot")[0].split(".dev")[0].split(".alpha")[0].split(".beta")[0]
|
||||
# 移除 snapshot、dev、alpha、beta 等后缀(支持 - 和 . 分隔符)
|
||||
import re
|
||||
|
||||
# 匹配 -snapshot.X, .snapshot, -dev, .dev, -alpha, .alpha, -beta, .beta 等后缀
|
||||
base_version = re.split(r"[-.](?:snapshot|dev|alpha|beta|rc)", version_str, flags=re.IGNORECASE)[0]
|
||||
|
||||
parts = base_version.split(".")
|
||||
if len(parts) < 3:
|
||||
@@ -206,12 +224,12 @@ async def check_git_status() -> GitStatusResponse:
|
||||
|
||||
|
||||
@router.get("/mirrors", response_model=AvailableMirrorsResponse)
|
||||
async def get_available_mirrors(authorization: Optional[str] = Header(None)) -> AvailableMirrorsResponse:
|
||||
async def get_available_mirrors(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> AvailableMirrorsResponse:
|
||||
"""
|
||||
获取所有可用的镜像源配置
|
||||
"""
|
||||
# Token 验证
|
||||
token = authorization.replace("Bearer ", "") if authorization else None
|
||||
token = get_token_from_cookie_or_header(maibot_session, authorization)
|
||||
token_manager = get_token_manager()
|
||||
if not token or not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="未授权:无效的访问令牌")
|
||||
@@ -236,12 +254,12 @@ async def get_available_mirrors(authorization: Optional[str] = Header(None)) ->
|
||||
|
||||
|
||||
@router.post("/mirrors", response_model=MirrorConfigResponse)
|
||||
async def add_mirror(request: AddMirrorRequest, authorization: Optional[str] = Header(None)) -> MirrorConfigResponse:
|
||||
async def add_mirror(request: AddMirrorRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> MirrorConfigResponse:
|
||||
"""
|
||||
添加新的镜像源
|
||||
"""
|
||||
# Token 验证
|
||||
token = authorization.replace("Bearer ", "") if authorization else None
|
||||
token = get_token_from_cookie_or_header(maibot_session, authorization)
|
||||
token_manager = get_token_manager()
|
||||
if not token or not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="未授权:无效的访问令牌")
|
||||
@@ -276,13 +294,13 @@ async def add_mirror(request: AddMirrorRequest, authorization: Optional[str] = H
|
||||
|
||||
@router.put("/mirrors/{mirror_id}", response_model=MirrorConfigResponse)
|
||||
async def update_mirror(
|
||||
mirror_id: str, request: UpdateMirrorRequest, authorization: Optional[str] = Header(None)
|
||||
mirror_id: str, request: UpdateMirrorRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
) -> MirrorConfigResponse:
|
||||
"""
|
||||
更新镜像源配置
|
||||
"""
|
||||
# Token 验证
|
||||
token = authorization.replace("Bearer ", "") if authorization else None
|
||||
token = get_token_from_cookie_or_header(maibot_session, authorization)
|
||||
token_manager = get_token_manager()
|
||||
if not token or not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="未授权:无效的访问令牌")
|
||||
@@ -319,12 +337,12 @@ async def update_mirror(
|
||||
|
||||
|
||||
@router.delete("/mirrors/{mirror_id}")
|
||||
async def delete_mirror(mirror_id: str, authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
||||
async def delete_mirror(mirror_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
||||
"""
|
||||
删除镜像源
|
||||
"""
|
||||
# Token 验证
|
||||
token = authorization.replace("Bearer ", "") if authorization else None
|
||||
token = get_token_from_cookie_or_header(maibot_session, authorization)
|
||||
token_manager = get_token_manager()
|
||||
if not token or not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="未授权:无效的访问令牌")
|
||||
@@ -342,7 +360,7 @@ async def delete_mirror(mirror_id: str, authorization: Optional[str] = Header(No
|
||||
|
||||
@router.post("/fetch-raw", response_model=FetchRawFileResponse)
|
||||
async def fetch_raw_file(
|
||||
request: FetchRawFileRequest, authorization: Optional[str] = Header(None)
|
||||
request: FetchRawFileRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
) -> FetchRawFileResponse:
|
||||
"""
|
||||
获取 GitHub 仓库的 Raw 文件内容
|
||||
@@ -352,7 +370,7 @@ async def fetch_raw_file(
|
||||
注意:此接口可公开访问,用于获取插件仓库等公开资源
|
||||
"""
|
||||
# Token 验证(可选,用于日志记录)
|
||||
token = authorization.replace("Bearer ", "") if authorization else None
|
||||
token = get_token_from_cookie_or_header(maibot_session, authorization)
|
||||
token_manager = get_token_manager()
|
||||
is_authenticated = token and token_manager.verify_token(token)
|
||||
|
||||
@@ -427,7 +445,7 @@ async def fetch_raw_file(
|
||||
|
||||
@router.post("/clone", response_model=CloneRepositoryResponse)
|
||||
async def clone_repository(
|
||||
request: CloneRepositoryRequest, authorization: Optional[str] = Header(None)
|
||||
request: CloneRepositoryRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
) -> CloneRepositoryResponse:
|
||||
"""
|
||||
克隆 GitHub 仓库到本地
|
||||
@@ -435,7 +453,7 @@ async def clone_repository(
|
||||
支持多镜像源自动切换和错误重试
|
||||
"""
|
||||
# Token 验证
|
||||
token = authorization.replace("Bearer ", "") if authorization else None
|
||||
token = get_token_from_cookie_or_header(maibot_session, authorization)
|
||||
token_manager = get_token_manager()
|
||||
if not token or not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="未授权:无效的访问令牌")
|
||||
@@ -467,14 +485,14 @@ async def clone_repository(
|
||||
|
||||
|
||||
@router.post("/install")
|
||||
async def install_plugin(request: InstallPluginRequest, authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
||||
async def install_plugin(request: InstallPluginRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
||||
"""
|
||||
安装插件
|
||||
|
||||
从 Git 仓库克隆插件到本地插件目录
|
||||
"""
|
||||
# Token 验证
|
||||
token = authorization.replace("Bearer ", "") if authorization else None
|
||||
token = get_token_from_cookie_or_header(maibot_session, authorization)
|
||||
token_manager = get_token_manager()
|
||||
if not token or not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="未授权:无效的访问令牌")
|
||||
@@ -611,7 +629,7 @@ async def install_plugin(request: InstallPluginRequest, authorization: Optional[
|
||||
for field in required_fields:
|
||||
if field not in manifest:
|
||||
raise ValueError(f"缺少必需字段: {field}")
|
||||
|
||||
|
||||
# 将插件 ID 写入 manifest(用于后续准确识别)
|
||||
# 这样即使文件夹名称改变,也能通过 manifest 准确识别插件
|
||||
manifest["id"] = request.plugin_id
|
||||
@@ -671,7 +689,7 @@ async def install_plugin(request: InstallPluginRequest, authorization: Optional[
|
||||
|
||||
@router.post("/uninstall")
|
||||
async def uninstall_plugin(
|
||||
request: UninstallPluginRequest, authorization: Optional[str] = Header(None)
|
||||
request: UninstallPluginRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
卸载插件
|
||||
@@ -679,7 +697,7 @@ async def uninstall_plugin(
|
||||
删除插件目录及其所有文件
|
||||
"""
|
||||
# Token 验证
|
||||
token = authorization.replace("Bearer ", "") if authorization else None
|
||||
token = get_token_from_cookie_or_header(maibot_session, authorization)
|
||||
token_manager = get_token_manager()
|
||||
if not token or not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="未授权:无效的访问令牌")
|
||||
@@ -703,7 +721,7 @@ async def uninstall_plugin(
|
||||
plugin_path = plugins_dir / folder_name
|
||||
# 旧格式:点
|
||||
old_format_path = plugins_dir / request.plugin_id
|
||||
|
||||
|
||||
# 优先使用新格式,如果不存在则尝试旧格式
|
||||
if not plugin_path.exists():
|
||||
if old_format_path.exists():
|
||||
@@ -806,14 +824,14 @@ async def uninstall_plugin(
|
||||
|
||||
|
||||
@router.post("/update")
|
||||
async def update_plugin(request: UpdatePluginRequest, authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
||||
async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
||||
"""
|
||||
更新插件
|
||||
|
||||
删除旧版本,重新克隆新版本
|
||||
"""
|
||||
# Token 验证
|
||||
token = authorization.replace("Bearer ", "") if authorization else None
|
||||
token = get_token_from_cookie_or_header(maibot_session, authorization)
|
||||
token_manager = get_token_manager()
|
||||
if not token or not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="未授权:无效的访问令牌")
|
||||
@@ -837,7 +855,7 @@ async def update_plugin(request: UpdatePluginRequest, authorization: Optional[st
|
||||
plugin_path = plugins_dir / folder_name
|
||||
# 旧格式:点
|
||||
old_format_path = plugins_dir / request.plugin_id
|
||||
|
||||
|
||||
# 优先使用新格式,如果不存在则尝试旧格式
|
||||
if not plugin_path.exists():
|
||||
if old_format_path.exists():
|
||||
@@ -1025,14 +1043,14 @@ async def update_plugin(request: UpdatePluginRequest, authorization: Optional[st
|
||||
|
||||
|
||||
@router.get("/installed")
|
||||
async def get_installed_plugins(authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
||||
async def get_installed_plugins(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
||||
"""
|
||||
获取已安装的插件列表
|
||||
|
||||
扫描 plugins 目录,返回所有已安装插件的 ID 和基本信息
|
||||
"""
|
||||
# Token 验证
|
||||
token = authorization.replace("Bearer ", "") if authorization else None
|
||||
token = get_token_from_cookie_or_header(maibot_session, authorization)
|
||||
token_manager = get_token_manager()
|
||||
if not token or not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="未授权:无效的访问令牌")
|
||||
@@ -1090,21 +1108,21 @@ async def get_installed_plugins(authorization: Optional[str] = Header(None)) ->
|
||||
# 尝试从 author.name 和 repository_url 构建标准 ID
|
||||
author_name = None
|
||||
repo_name = None
|
||||
|
||||
|
||||
# 获取作者名
|
||||
if "author" in manifest:
|
||||
if isinstance(manifest["author"], dict) and "name" in manifest["author"]:
|
||||
author_name = manifest["author"]["name"]
|
||||
elif isinstance(manifest["author"], str):
|
||||
author_name = manifest["author"]
|
||||
|
||||
|
||||
# 从 repository_url 获取仓库名
|
||||
if "repository_url" in manifest:
|
||||
repo_url = manifest["repository_url"].rstrip("/")
|
||||
if repo_url.endswith(".git"):
|
||||
repo_url = repo_url[:-4]
|
||||
repo_name = repo_url.split("/")[-1]
|
||||
|
||||
|
||||
# 构建 ID
|
||||
if author_name and repo_name:
|
||||
# 标准格式: Author.RepoName
|
||||
@@ -1120,7 +1138,7 @@ async def get_installed_plugins(authorization: Optional[str] = Header(None)) ->
|
||||
else:
|
||||
# 直接使用文件夹名
|
||||
plugin_id = folder_name
|
||||
|
||||
|
||||
# 将推断的 ID 写入 manifest(方便下次识别)
|
||||
logger.info(f"为插件 {folder_name} 自动生成 ID: {plugin_id}")
|
||||
manifest["id"] = plugin_id
|
||||
@@ -1153,3 +1171,408 @@ async def get_installed_plugins(authorization: Optional[str] = Header(None)) ->
|
||||
except Exception as e:
|
||||
logger.error(f"获取已安装插件列表失败: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||
|
||||
|
||||
# ============ 插件配置管理 API ============
|
||||
|
||||
|
||||
class UpdatePluginConfigRequest(BaseModel):
|
||||
"""更新插件配置请求"""
|
||||
|
||||
config: Dict[str, Any] = Field(..., description="配置数据")
|
||||
|
||||
|
||||
@router.get("/config/{plugin_id}/schema")
|
||||
async def get_plugin_config_schema(plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
||||
"""
|
||||
获取插件配置 Schema
|
||||
|
||||
返回插件的完整配置 schema,包含所有 section、字段定义和布局信息。
|
||||
用于前端动态生成配置表单。
|
||||
"""
|
||||
# Token 验证
|
||||
token = get_token_from_cookie_or_header(maibot_session, authorization)
|
||||
token_manager = get_token_manager()
|
||||
if not token or not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="未授权:无效的访问令牌")
|
||||
|
||||
logger.info(f"获取插件配置 Schema: {plugin_id}")
|
||||
|
||||
try:
|
||||
# 尝试从已加载的插件中获取
|
||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
|
||||
# 查找插件实例
|
||||
plugin_instance = None
|
||||
|
||||
# 遍历所有已加载的插件
|
||||
for loaded_plugin_name in plugin_manager.list_loaded_plugins():
|
||||
instance = plugin_manager.get_plugin_instance(loaded_plugin_name)
|
||||
if instance:
|
||||
# 匹配 plugin_name 或 manifest 中的 id
|
||||
if instance.plugin_name == plugin_id:
|
||||
plugin_instance = instance
|
||||
break
|
||||
# 也尝试匹配 manifest 中的 id
|
||||
manifest_id = instance.get_manifest_info("id", "")
|
||||
if manifest_id == plugin_id:
|
||||
plugin_instance = instance
|
||||
break
|
||||
|
||||
if plugin_instance and hasattr(plugin_instance, "get_webui_config_schema"):
|
||||
# 从插件实例获取 schema
|
||||
schema = plugin_instance.get_webui_config_schema()
|
||||
return {"success": True, "schema": schema}
|
||||
|
||||
# 如果插件未加载,尝试从文件系统读取
|
||||
# 查找插件目录
|
||||
plugins_dir = Path("plugins")
|
||||
plugin_path = None
|
||||
|
||||
for p in plugins_dir.iterdir():
|
||||
if p.is_dir():
|
||||
manifest_path = p / "_manifest.json"
|
||||
if manifest_path.exists():
|
||||
try:
|
||||
with open(manifest_path, "r", encoding="utf-8") as f:
|
||||
manifest = json.load(f)
|
||||
if manifest.get("id") == plugin_id or p.name == plugin_id:
|
||||
plugin_path = p
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if not plugin_path:
|
||||
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
|
||||
|
||||
# 读取配置文件获取当前配置
|
||||
config_path = plugin_path / "config.toml"
|
||||
current_config = {}
|
||||
if config_path.exists():
|
||||
import tomlkit
|
||||
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
current_config = tomlkit.load(f)
|
||||
|
||||
# 构建基础 schema(无法获取完整的 ConfigField 信息)
|
||||
schema = {
|
||||
"plugin_id": plugin_id,
|
||||
"plugin_info": {
|
||||
"name": plugin_id,
|
||||
"version": "",
|
||||
"description": "",
|
||||
"author": "",
|
||||
},
|
||||
"sections": {},
|
||||
"layout": {"type": "auto", "tabs": []},
|
||||
"_note": "插件未加载,仅返回当前配置结构",
|
||||
}
|
||||
|
||||
# 从当前配置推断 schema
|
||||
for section_name, section_data in current_config.items():
|
||||
if isinstance(section_data, dict):
|
||||
schema["sections"][section_name] = {
|
||||
"name": section_name,
|
||||
"title": section_name,
|
||||
"description": None,
|
||||
"icon": None,
|
||||
"collapsed": False,
|
||||
"order": 0,
|
||||
"fields": {},
|
||||
}
|
||||
for field_name, field_value in section_data.items():
|
||||
# 推断字段类型
|
||||
field_type = type(field_value).__name__
|
||||
ui_type = "text"
|
||||
if isinstance(field_value, bool):
|
||||
ui_type = "switch"
|
||||
elif isinstance(field_value, (int, float)):
|
||||
ui_type = "number"
|
||||
elif isinstance(field_value, list):
|
||||
ui_type = "list"
|
||||
elif isinstance(field_value, dict):
|
||||
ui_type = "json"
|
||||
|
||||
schema["sections"][section_name]["fields"][field_name] = {
|
||||
"name": field_name,
|
||||
"type": field_type,
|
||||
"default": field_value,
|
||||
"description": field_name,
|
||||
"label": field_name,
|
||||
"ui_type": ui_type,
|
||||
"required": False,
|
||||
"hidden": False,
|
||||
"disabled": False,
|
||||
"order": 0,
|
||||
}
|
||||
|
||||
return {"success": True, "schema": schema}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"获取插件配置 Schema 失败: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/config/{plugin_id}")
|
||||
async def get_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
||||
"""
|
||||
获取插件当前配置值
|
||||
|
||||
返回插件的当前配置值。
|
||||
"""
|
||||
# Token 验证
|
||||
token = get_token_from_cookie_or_header(maibot_session, authorization)
|
||||
token_manager = get_token_manager()
|
||||
if not token or not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="未授权:无效的访问令牌")
|
||||
|
||||
logger.info(f"获取插件配置: {plugin_id}")
|
||||
|
||||
try:
|
||||
# 查找插件目录
|
||||
plugins_dir = Path("plugins")
|
||||
plugin_path = None
|
||||
|
||||
for p in plugins_dir.iterdir():
|
||||
if p.is_dir():
|
||||
manifest_path = p / "_manifest.json"
|
||||
if manifest_path.exists():
|
||||
try:
|
||||
with open(manifest_path, "r", encoding="utf-8") as f:
|
||||
manifest = json.load(f)
|
||||
if manifest.get("id") == plugin_id or p.name == plugin_id:
|
||||
plugin_path = p
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if not plugin_path:
|
||||
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
|
||||
|
||||
# 读取配置文件
|
||||
config_path = plugin_path / "config.toml"
|
||||
if not config_path.exists():
|
||||
return {"success": True, "config": {}, "message": "配置文件不存在"}
|
||||
|
||||
import tomlkit
|
||||
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config = tomlkit.load(f)
|
||||
|
||||
return {"success": True, "config": dict(config)}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"获取插件配置失败: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||
|
||||
|
||||
@router.put("/config/{plugin_id}")
|
||||
async def update_plugin_config(
|
||||
plugin_id: str, request: UpdatePluginConfigRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
更新插件配置
|
||||
|
||||
保存新的配置值到插件的配置文件。
|
||||
"""
|
||||
# Token 验证
|
||||
token = get_token_from_cookie_or_header(maibot_session, authorization)
|
||||
token_manager = get_token_manager()
|
||||
if not token or not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="未授权:无效的访问令牌")
|
||||
|
||||
logger.info(f"更新插件配置: {plugin_id}")
|
||||
|
||||
try:
|
||||
# 查找插件目录
|
||||
plugins_dir = Path("plugins")
|
||||
plugin_path = None
|
||||
|
||||
for p in plugins_dir.iterdir():
|
||||
if p.is_dir():
|
||||
manifest_path = p / "_manifest.json"
|
||||
if manifest_path.exists():
|
||||
try:
|
||||
with open(manifest_path, "r", encoding="utf-8") as f:
|
||||
manifest = json.load(f)
|
||||
if manifest.get("id") == plugin_id or p.name == plugin_id:
|
||||
plugin_path = p
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if not plugin_path:
|
||||
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
|
||||
|
||||
config_path = plugin_path / "config.toml"
|
||||
|
||||
# 备份旧配置
|
||||
import shutil
|
||||
import datetime
|
||||
|
||||
if config_path.exists():
|
||||
backup_name = f"config.toml.backup.{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}"
|
||||
backup_path = plugin_path / backup_name
|
||||
shutil.copy(config_path, backup_path)
|
||||
logger.info(f"已备份配置文件: {backup_path}")
|
||||
|
||||
# 写入新配置(使用 tomlkit 保留注释)
|
||||
import tomlkit
|
||||
|
||||
# 先读取原配置以保留注释和格式
|
||||
existing_doc = tomlkit.document()
|
||||
if config_path.exists():
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
existing_doc = tomlkit.load(f)
|
||||
# 更新值
|
||||
for key, value in request.config.items():
|
||||
existing_doc[key] = value
|
||||
save_toml_with_format(existing_doc, str(config_path))
|
||||
|
||||
logger.info(f"已更新插件配置: {plugin_id}")
|
||||
|
||||
return {"success": True, "message": "配置已保存", "note": "配置更改将在插件重新加载后生效"}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"更新插件配置失败: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post("/config/{plugin_id}/reset")
|
||||
async def reset_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
||||
"""
|
||||
重置插件配置为默认值
|
||||
|
||||
删除当前配置文件,下次加载插件时将使用默认配置。
|
||||
"""
|
||||
# Token 验证
|
||||
token = get_token_from_cookie_or_header(maibot_session, authorization)
|
||||
token_manager = get_token_manager()
|
||||
if not token or not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="未授权:无效的访问令牌")
|
||||
|
||||
logger.info(f"重置插件配置: {plugin_id}")
|
||||
|
||||
try:
|
||||
# 查找插件目录
|
||||
plugins_dir = Path("plugins")
|
||||
plugin_path = None
|
||||
|
||||
for p in plugins_dir.iterdir():
|
||||
if p.is_dir():
|
||||
manifest_path = p / "_manifest.json"
|
||||
if manifest_path.exists():
|
||||
try:
|
||||
with open(manifest_path, "r", encoding="utf-8") as f:
|
||||
manifest = json.load(f)
|
||||
if manifest.get("id") == plugin_id or p.name == plugin_id:
|
||||
plugin_path = p
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if not plugin_path:
|
||||
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
|
||||
|
||||
config_path = plugin_path / "config.toml"
|
||||
|
||||
if not config_path.exists():
|
||||
return {"success": True, "message": "配置文件不存在,无需重置"}
|
||||
|
||||
# 备份并删除
|
||||
import shutil
|
||||
import datetime
|
||||
|
||||
backup_name = f"config.toml.reset.{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}"
|
||||
backup_path = plugin_path / backup_name
|
||||
shutil.move(config_path, backup_path)
|
||||
|
||||
logger.info(f"已重置插件配置: {plugin_id},备份: {backup_path}")
|
||||
|
||||
return {"success": True, "message": "配置已重置,下次加载插件时将使用默认配置", "backup": str(backup_path)}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"重置插件配置失败: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post("/config/{plugin_id}/toggle")
|
||||
async def toggle_plugin(plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
||||
"""
|
||||
切换插件启用状态
|
||||
|
||||
切换插件配置中的 enabled 字段。
|
||||
"""
|
||||
# Token 验证
|
||||
token = get_token_from_cookie_or_header(maibot_session, authorization)
|
||||
token_manager = get_token_manager()
|
||||
if not token or not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="未授权:无效的访问令牌")
|
||||
|
||||
logger.info(f"切换插件状态: {plugin_id}")
|
||||
|
||||
try:
|
||||
# 查找插件目录
|
||||
plugins_dir = Path("plugins")
|
||||
plugin_path = None
|
||||
|
||||
for p in plugins_dir.iterdir():
|
||||
if p.is_dir():
|
||||
manifest_path = p / "_manifest.json"
|
||||
if manifest_path.exists():
|
||||
try:
|
||||
with open(manifest_path, "r", encoding="utf-8") as f:
|
||||
manifest = json.load(f)
|
||||
if manifest.get("id") == plugin_id or p.name == plugin_id:
|
||||
plugin_path = p
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if not plugin_path:
|
||||
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
|
||||
|
||||
config_path = plugin_path / "config.toml"
|
||||
|
||||
import tomlkit
|
||||
|
||||
# 读取当前配置(保留注释和格式)
|
||||
config = tomlkit.document()
|
||||
if config_path.exists():
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config = tomlkit.load(f)
|
||||
|
||||
# 切换 enabled 状态
|
||||
if "plugin" not in config:
|
||||
config["plugin"] = tomlkit.table()
|
||||
|
||||
current_enabled = config["plugin"].get("enabled", True)
|
||||
new_enabled = not current_enabled
|
||||
config["plugin"]["enabled"] = new_enabled
|
||||
|
||||
# 写入配置(保留注释,格式化数组)
|
||||
save_toml_with_format(config, str(config_path))
|
||||
|
||||
status = "启用" if new_enabled else "禁用"
|
||||
logger.info(f"已{status}插件: {plugin_id}")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"enabled": new_enabled,
|
||||
"message": f"插件已{status}",
|
||||
"note": "状态更改将在下次加载插件时生效",
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"切换插件状态失败: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||
|
||||
@@ -5,14 +5,15 @@
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from src.config.config import MMC_VERSION
|
||||
from src.common.logger import get_logger
|
||||
|
||||
router = APIRouter(prefix="/system", tags=["system"])
|
||||
logger = get_logger("webui_system")
|
||||
|
||||
# 记录启动时间
|
||||
_start_time = time.time()
|
||||
@@ -39,22 +40,23 @@ async def restart_maibot():
|
||||
"""
|
||||
重启麦麦主程序
|
||||
|
||||
使用 os.execv 重启当前进程,配置更改将在重启后生效。
|
||||
请求重启当前进程,配置更改将在重启后生效。
|
||||
注意:此操作会使麦麦暂时离线。
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
|
||||
try:
|
||||
# 记录重启操作
|
||||
print(f"[{datetime.now()}] WebUI 触发重启操作")
|
||||
logger.info("WebUI 触发重启操作")
|
||||
|
||||
# 定义延迟重启的异步任务
|
||||
async def delayed_restart():
|
||||
await asyncio.sleep(0.5) # 延迟0.5秒,确保响应已发送
|
||||
python = sys.executable
|
||||
args = [python] + sys.argv
|
||||
os.execv(python, args)
|
||||
|
||||
# 使用 os._exit(42) 退出当前进程,配合外部 runner 脚本进行重启
|
||||
# 42 是约定的重启状态码
|
||||
logger.info("WebUI 请求重启,退出代码 42")
|
||||
os._exit(42)
|
||||
|
||||
# 创建后台任务执行重启
|
||||
asyncio.create_task(delayed_restart())
|
||||
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
"""WebUI API 路由"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Header
|
||||
from fastapi import APIRouter, HTTPException, Header, Response, Request, Cookie
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional
|
||||
from src.common.logger import get_logger
|
||||
from .token_manager import get_token_manager
|
||||
from .auth import set_auth_cookie, clear_auth_cookie
|
||||
from .config_routes import router as config_router
|
||||
from .statistics_routes import router as statistics_router
|
||||
from .person_routes import router as person_router
|
||||
from .expression_routes import router as expression_router
|
||||
from .jargon_routes import router as jargon_router
|
||||
from .emoji_routes import router as emoji_router
|
||||
from .plugin_routes import router as plugin_router
|
||||
from .plugin_progress_ws import get_progress_router
|
||||
@@ -28,6 +30,8 @@ router.include_router(statistics_router)
|
||||
router.include_router(person_router)
|
||||
# 注册表达方式管理路由
|
||||
router.include_router(expression_router)
|
||||
# 注册黑话管理路由
|
||||
router.include_router(jargon_router)
|
||||
# 注册表情包管理路由
|
||||
router.include_router(emoji_router)
|
||||
# 注册插件管理路由
|
||||
@@ -51,6 +55,7 @@ class TokenVerifyResponse(BaseModel):
|
||||
|
||||
valid: bool = Field(..., description="Token 是否有效")
|
||||
message: str = Field(..., description="验证结果消息")
|
||||
is_first_setup: bool = Field(False, description="是否为首次设置")
|
||||
|
||||
|
||||
class TokenUpdateRequest(BaseModel):
|
||||
@@ -102,22 +107,27 @@ async def health_check():
|
||||
|
||||
|
||||
@router.post("/auth/verify", response_model=TokenVerifyResponse)
|
||||
async def verify_token(request: TokenVerifyRequest):
|
||||
async def verify_token(request: TokenVerifyRequest, response: Response):
|
||||
"""
|
||||
验证访问令牌
|
||||
验证访问令牌,验证成功后设置 HttpOnly Cookie
|
||||
|
||||
Args:
|
||||
request: 包含 token 的验证请求
|
||||
response: FastAPI Response 对象
|
||||
|
||||
Returns:
|
||||
验证结果
|
||||
验证结果(包含首次配置状态)
|
||||
"""
|
||||
try:
|
||||
token_manager = get_token_manager()
|
||||
is_valid = token_manager.verify_token(request.token)
|
||||
|
||||
if is_valid:
|
||||
return TokenVerifyResponse(valid=True, message="Token 验证成功")
|
||||
# 设置 HttpOnly Cookie
|
||||
set_auth_cookie(response, request.token)
|
||||
# 同时返回首次配置状态,避免额外请求
|
||||
is_first_setup = token_manager.is_first_setup()
|
||||
return TokenVerifyResponse(valid=True, message="Token 验证成功", is_first_setup=is_first_setup)
|
||||
else:
|
||||
return TokenVerifyResponse(valid=False, message="Token 无效或已过期")
|
||||
except Exception as e:
|
||||
@@ -125,24 +135,86 @@ async def verify_token(request: TokenVerifyRequest):
|
||||
raise HTTPException(status_code=500, detail="Token 验证失败") from e
|
||||
|
||||
|
||||
@router.post("/auth/logout")
|
||||
async def logout(response: Response):
|
||||
"""
|
||||
登出并清除认证 Cookie
|
||||
|
||||
Args:
|
||||
response: FastAPI Response 对象
|
||||
|
||||
Returns:
|
||||
登出结果
|
||||
"""
|
||||
clear_auth_cookie(response)
|
||||
return {"success": True, "message": "已成功登出"}
|
||||
|
||||
|
||||
@router.get("/auth/check")
|
||||
async def check_auth_status(
|
||||
request: Request,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
检查当前认证状态(用于前端判断是否已登录)
|
||||
|
||||
Returns:
|
||||
认证状态
|
||||
"""
|
||||
try:
|
||||
token = None
|
||||
|
||||
# 优先从 Cookie 获取
|
||||
if maibot_session:
|
||||
token = maibot_session
|
||||
# 其次从 Header 获取
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
token = authorization.replace("Bearer ", "")
|
||||
|
||||
if not token:
|
||||
return {"authenticated": False}
|
||||
|
||||
token_manager = get_token_manager()
|
||||
if token_manager.verify_token(token):
|
||||
return {"authenticated": True}
|
||||
else:
|
||||
return {"authenticated": False}
|
||||
except Exception:
|
||||
return {"authenticated": False}
|
||||
|
||||
|
||||
@router.post("/auth/update", response_model=TokenUpdateResponse)
|
||||
async def update_token(request: TokenUpdateRequest, authorization: Optional[str] = Header(None)):
|
||||
async def update_token(
|
||||
request: TokenUpdateRequest,
|
||||
response: Response,
|
||||
req: Request,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
更新访问令牌(需要当前有效的 token)
|
||||
|
||||
Args:
|
||||
request: 包含新 token 的更新请求
|
||||
response: FastAPI Response 对象
|
||||
maibot_session: Cookie 中的 token
|
||||
authorization: Authorization header (Bearer token)
|
||||
|
||||
Returns:
|
||||
更新结果
|
||||
"""
|
||||
try:
|
||||
# 验证当前 token
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
# 验证当前 token(优先 Cookie,其次 Header)
|
||||
current_token = None
|
||||
if maibot_session:
|
||||
current_token = maibot_session
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
current_token = authorization.replace("Bearer ", "")
|
||||
|
||||
if not current_token:
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
current_token = authorization.replace("Bearer ", "")
|
||||
token_manager = get_token_manager()
|
||||
|
||||
if not token_manager.verify_token(current_token):
|
||||
@@ -150,6 +222,10 @@ async def update_token(request: TokenUpdateRequest, authorization: Optional[str]
|
||||
|
||||
# 更新 token
|
||||
success, message = token_manager.update_token(request.new_token)
|
||||
|
||||
# 如果更新成功,更新 Cookie
|
||||
if success:
|
||||
set_auth_cookie(response, request.new_token)
|
||||
|
||||
return TokenUpdateResponse(success=success, message=message)
|
||||
except HTTPException:
|
||||
@@ -160,22 +236,34 @@ async def update_token(request: TokenUpdateRequest, authorization: Optional[str]
|
||||
|
||||
|
||||
@router.post("/auth/regenerate", response_model=TokenRegenerateResponse)
|
||||
async def regenerate_token(authorization: Optional[str] = Header(None)):
|
||||
async def regenerate_token(
|
||||
response: Response,
|
||||
request: Request,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
重新生成访问令牌(需要当前有效的 token)
|
||||
|
||||
Args:
|
||||
response: FastAPI Response 对象
|
||||
maibot_session: Cookie 中的 token
|
||||
authorization: Authorization header (Bearer token)
|
||||
|
||||
Returns:
|
||||
新生成的 token
|
||||
"""
|
||||
try:
|
||||
# 验证当前 token
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
# 验证当前 token(优先 Cookie,其次 Header)
|
||||
current_token = None
|
||||
if maibot_session:
|
||||
current_token = maibot_session
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
current_token = authorization.replace("Bearer ", "")
|
||||
|
||||
current_token = authorization.replace("Bearer ", "")
|
||||
if not current_token:
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
token_manager = get_token_manager()
|
||||
|
||||
if not token_manager.verify_token(current_token):
|
||||
@@ -183,6 +271,9 @@ async def regenerate_token(authorization: Optional[str] = Header(None)):
|
||||
|
||||
# 重新生成 token
|
||||
new_token = token_manager.regenerate_token()
|
||||
|
||||
# 更新 Cookie
|
||||
set_auth_cookie(response, new_token)
|
||||
|
||||
return TokenRegenerateResponse(success=True, token=new_token, message="Token 已重新生成")
|
||||
except HTTPException:
|
||||
@@ -193,22 +284,32 @@ async def regenerate_token(authorization: Optional[str] = Header(None)):
|
||||
|
||||
|
||||
@router.get("/setup/status", response_model=FirstSetupStatusResponse)
|
||||
async def get_setup_status(authorization: Optional[str] = Header(None)):
|
||||
async def get_setup_status(
|
||||
request: Request,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
获取首次配置状态
|
||||
|
||||
Args:
|
||||
maibot_session: Cookie 中的 token
|
||||
authorization: Authorization header (Bearer token)
|
||||
|
||||
Returns:
|
||||
首次配置状态
|
||||
"""
|
||||
try:
|
||||
# 验证 token
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
# 验证 token(优先 Cookie,其次 Header)
|
||||
current_token = None
|
||||
if maibot_session:
|
||||
current_token = maibot_session
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
current_token = authorization.replace("Bearer ", "")
|
||||
|
||||
if not current_token:
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
current_token = authorization.replace("Bearer ", "")
|
||||
token_manager = get_token_manager()
|
||||
|
||||
if not token_manager.verify_token(current_token):
|
||||
@@ -226,22 +327,32 @@ async def get_setup_status(authorization: Optional[str] = Header(None)):
|
||||
|
||||
|
||||
@router.post("/setup/complete", response_model=CompleteSetupResponse)
|
||||
async def complete_setup(authorization: Optional[str] = Header(None)):
|
||||
async def complete_setup(
|
||||
request: Request,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
标记首次配置完成
|
||||
|
||||
Args:
|
||||
maibot_session: Cookie 中的 token
|
||||
authorization: Authorization header (Bearer token)
|
||||
|
||||
Returns:
|
||||
完成结果
|
||||
"""
|
||||
try:
|
||||
# 验证 token
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
# 验证 token(优先 Cookie,其次 Header)
|
||||
current_token = None
|
||||
if maibot_session:
|
||||
current_token = maibot_session
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
current_token = authorization.replace("Bearer ", "")
|
||||
|
||||
if not current_token:
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
current_token = authorization.replace("Bearer ", "")
|
||||
token_manager = get_token_manager()
|
||||
|
||||
if not token_manager.verify_token(current_token):
|
||||
@@ -259,22 +370,32 @@ async def complete_setup(authorization: Optional[str] = Header(None)):
|
||||
|
||||
|
||||
@router.post("/setup/reset", response_model=ResetSetupResponse)
|
||||
async def reset_setup(authorization: Optional[str] = Header(None)):
|
||||
async def reset_setup(
|
||||
request: Request,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
重置首次配置状态,允许重新进入配置向导
|
||||
|
||||
Args:
|
||||
maibot_session: Cookie 中的 token
|
||||
authorization: Authorization header (Bearer token)
|
||||
|
||||
Returns:
|
||||
重置结果
|
||||
"""
|
||||
try:
|
||||
# 验证 token
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
# 验证 token(优先 Cookie,其次 Header)
|
||||
current_token = None
|
||||
if maibot_session:
|
||||
current_token = maibot_session
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
current_token = authorization.replace("Bearer ", "")
|
||||
|
||||
if not current_token:
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
current_token = authorization.replace("Bearer ", "")
|
||||
token_manager = get_token_manager()
|
||||
|
||||
if not token_manager.verify_token(current_token):
|
||||
|
||||
@@ -5,6 +5,7 @@ import asyncio
|
||||
import mimetypes
|
||||
from pathlib import Path
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import FileResponse
|
||||
from uvicorn import Config, Server as UvicornServer
|
||||
from src.common.logger import get_logger
|
||||
@@ -20,19 +21,39 @@ class WebUIServer:
|
||||
self.port = port
|
||||
self.app = FastAPI(title="MaiBot WebUI")
|
||||
self._server = None
|
||||
|
||||
|
||||
# 配置 CORS(支持开发环境跨域请求)
|
||||
self._setup_cors()
|
||||
|
||||
# 显示 Access Token
|
||||
self._show_access_token()
|
||||
|
||||
|
||||
# 重要:先注册 API 路由,再设置静态文件
|
||||
self._register_api_routes()
|
||||
self._setup_static_files()
|
||||
|
||||
def _setup_cors(self):
|
||||
"""配置 CORS 中间件"""
|
||||
# 开发环境需要允许前端开发服务器的跨域请求
|
||||
self.app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=[
|
||||
"http://localhost:5173", # Vite 开发服务器
|
||||
"http://127.0.0.1:5173",
|
||||
"http://localhost:8001", # 生产环境
|
||||
"http://127.0.0.1:8001",
|
||||
],
|
||||
allow_credentials=True, # 允许携带 Cookie
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
logger.debug("✅ CORS 中间件已配置")
|
||||
|
||||
def _show_access_token(self):
|
||||
"""显示 WebUI Access Token"""
|
||||
try:
|
||||
from src.webui.token_manager import get_token_manager
|
||||
|
||||
|
||||
token_manager = get_token_manager()
|
||||
current_token = token_manager.get_token()
|
||||
logger.info(f"🔑 WebUI Access Token: {current_token}")
|
||||
@@ -69,7 +90,7 @@ class WebUIServer:
|
||||
# 如果是根路径,直接返回 index.html
|
||||
if not full_path or full_path == "/":
|
||||
return FileResponse(static_path / "index.html", media_type="text/html")
|
||||
|
||||
|
||||
# 检查是否是静态文件
|
||||
file_path = static_path / full_path
|
||||
if file_path.is_file() and file_path.exists():
|
||||
@@ -88,15 +109,22 @@ class WebUIServer:
|
||||
# 导入所有 WebUI 路由
|
||||
from src.webui.routes import router as webui_router
|
||||
from src.webui.logs_ws import router as logs_router
|
||||
|
||||
|
||||
logger.info("开始导入 knowledge_routes...")
|
||||
from src.webui.knowledge_routes import router as knowledge_router
|
||||
|
||||
logger.info("knowledge_routes 导入成功")
|
||||
|
||||
# 导入本地聊天室路由
|
||||
from src.webui.chat_routes import router as chat_router
|
||||
|
||||
logger.info("chat_routes 导入成功")
|
||||
|
||||
# 注册路由
|
||||
self.app.include_router(webui_router)
|
||||
self.app.include_router(logs_router)
|
||||
self.app.include_router(knowledge_router)
|
||||
self.app.include_router(chat_router)
|
||||
logger.info(f"knowledge_router 路由前缀: {knowledge_router.prefix}")
|
||||
|
||||
logger.info("✅ WebUI API 路由已注册")
|
||||
@@ -116,6 +144,8 @@ class WebUIServer:
|
||||
|
||||
logger.info("🌐 WebUI 服务器启动中...")
|
||||
logger.info(f"🌐 访问地址: http://{self.host}:{self.port}")
|
||||
if self.host == "0.0.0.0":
|
||||
logger.info(f"本机访问请使用 http://localhost:{self.port}")
|
||||
|
||||
try:
|
||||
await self._server.serve()
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
version = "6.23.5"
|
||||
|
||||
#----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读----
|
||||
#如果你想要修改配置文件,请递增version的值
|
||||
#如果新增项目,请阅读src/config/official_configs.py中的说明
|
||||
# 如果你想要修改配置文件,请递增version的值
|
||||
# 如果新增项目,请阅读src/config/official_configs.py中的说明
|
||||
#
|
||||
# 版本格式:主版本号.次版本号.修订号,版本号递增规则如下:
|
||||
# 主版本号:MMC版本更新
|
||||
@@ -23,7 +23,7 @@ alias_names = ["麦叠", "牢麦"] # 麦麦的别名
|
||||
[personality]
|
||||
# 建议120字以内,描述人格特质 和 身份特征
|
||||
personality = "是一个女大学生,现在在读大二,会刷贴吧。"
|
||||
#アイデンティティがない 生まれないらららら
|
||||
# アイデンティティがない 生まれないらららら
|
||||
# 描述麦麦说话的表达风格,表达习惯,如要修改,可以酌情新增内容
|
||||
reply_style = "请回复的平淡一些,简短一些,说中文,不要刻意突出自身学科背景。可以参考贴吧,知乎和微博的回复风格。"
|
||||
|
||||
@@ -85,11 +85,11 @@ reflect_operator_id = "" # 表达反思操作员ID,格式:platform:id:type (
|
||||
allow_reflect = [] # 允许进行表达反思的聊天流ID列表,格式:["qq:123456:private", "qq:654321:group", ...],只有在此列表中的聊天流才会提出问题并跟踪。如果列表为空,则所有聊天流都可以进行表达反思(前提是 reflect = true)
|
||||
|
||||
|
||||
[chat] #麦麦的聊天设置
|
||||
talk_value = 1 #聊天频率,越小越沉默,范围0-1
|
||||
[chat] # 麦麦的聊天设置
|
||||
talk_value = 1 # 聊天频率,越小越沉默,范围0-1
|
||||
mentioned_bot_reply = true # 是否启用提及必回复
|
||||
max_context_size = 30 # 上下文长度
|
||||
planner_smooth = 2 #规划器平滑,增大数值会减小planner负荷,略微降低反应速度,推荐1-5,0为关闭,必须大于等于0
|
||||
planner_smooth = 2 # 规划器平滑,增大数值会减小planner负荷,略微降低反应速度,推荐1-5,0为关闭,必须大于等于0
|
||||
|
||||
enable_talk_value_rules = true # 是否启用动态发言频率规则
|
||||
|
||||
@@ -143,8 +143,8 @@ ban_words = [
|
||||
|
||||
ban_msgs_regex = [
|
||||
# 需要过滤的消息(原始消息)匹配的正则表达式,匹配到的消息将被过滤,若不了解正则表达式请勿修改
|
||||
#"https?://[^\\s]+", # 匹配https链接
|
||||
#"\\d{4}-\\d{2}-\\d{2}", # 匹配日期
|
||||
# "https?://[^\\s]+", # 匹配https链接
|
||||
# "\\d{4}-\\d{2}-\\d{2}", # 匹配日期
|
||||
]
|
||||
|
||||
|
||||
@@ -177,7 +177,7 @@ webui_graph_default_limit = 200 # WebUI /graph 默认返回的最大节点数,
|
||||
keyword_rules = [
|
||||
{ keywords = ["人机", "bot", "机器", "入机", "robot", "机器人", "ai", "AI"], reaction = "有人说你是人机或者机器人,否定这一事实,攻击或反问他。注意,只有认为你是机器人才回复,否则不要否认" },
|
||||
{ keywords = ["测试关键词回复", "test"], reaction = "回答测试成功" },
|
||||
#{ keywords = ["你好", "hello"], reaction = "你好,有什么可以帮你?" }
|
||||
# { keywords = ["你好", "hello"], reaction = "你好,有什么可以帮你?" }
|
||||
# 在此处添加更多规则,格式同上
|
||||
]
|
||||
|
||||
@@ -250,7 +250,7 @@ enable = true
|
||||
chat_prompts = []
|
||||
|
||||
|
||||
#此系统暂时移除,无效配置
|
||||
# 此系统暂时移除,无效配置
|
||||
[relationship]
|
||||
enable_relationship = true # 是否启用关系系统
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
[inner]
|
||||
version = "1.8.1"
|
||||
version = "1.8.2"
|
||||
|
||||
# 配置文件版本号迭代规则同bot_config.toml
|
||||
|
||||
@@ -46,7 +46,7 @@ name = "deepseek-v3" # 模型名称(可随意命名,在后面
|
||||
api_provider = "DeepSeek" # API服务商名称(对应在api_providers中配置的服务商名称)
|
||||
price_in = 2.0 # 输入价格(用于API调用统计,单位:元/ M token)(可选,若无该字段,默认值为0)
|
||||
price_out = 8.0 # 输出价格(用于API调用统计,单位:元/ M token)(可选,若无该字段,默认值为0)
|
||||
#force_stream_mode = true # 强制流式输出模式(若模型不支持非流式输出,请取消该注释,启用强制流式输出,若无该字段,默认值为false)
|
||||
# force_stream_mode = true # 强制流式输出模式(若模型不支持非流式输出,请取消该注释,启用强制流式输出,若无该字段,默认值为false)
|
||||
|
||||
[[models]]
|
||||
model_identifier = "deepseek-ai/DeepSeek-V3.2-Exp"
|
||||
@@ -56,6 +56,7 @@ price_in = 2.0
|
||||
price_out = 3.0
|
||||
[models.extra_params] # 可选的额外参数配置
|
||||
enable_thinking = false # 不启用思考
|
||||
# temperature = 0.5 # 可选:为该模型单独指定温度,会覆盖任务配置中的温度
|
||||
|
||||
[[models]]
|
||||
model_identifier = "deepseek-ai/DeepSeek-V3.2-Exp"
|
||||
@@ -64,7 +65,8 @@ api_provider = "SiliconFlow"
|
||||
price_in = 2.0
|
||||
price_out = 3.0
|
||||
[models.extra_params] # 可选的额外参数配置
|
||||
enable_thinking = true # 不启用思考
|
||||
enable_thinking = true # 启用思考
|
||||
# temperature = 0.7 # 可选:为该模型单独指定温度,会覆盖任务配置中的温度
|
||||
|
||||
[[models]]
|
||||
model_identifier = "Qwen/Qwen3-Next-80B-A3B-Instruct"
|
||||
@@ -89,8 +91,7 @@ api_provider = "SiliconFlow"
|
||||
price_in = 3.5
|
||||
price_out = 14.0
|
||||
[models.extra_params] # 可选的额外参数配置
|
||||
enable_thinking = true # 不启用思考
|
||||
|
||||
enable_thinking = true # 启用思考
|
||||
|
||||
[[models]]
|
||||
model_identifier = "deepseek-ai/DeepSeek-R1"
|
||||
@@ -134,51 +135,62 @@ price_out = 0
|
||||
model_list = ["siliconflow-deepseek-v3.2"] # 使用的模型列表,每个子项对应上面的模型名称(name)
|
||||
temperature = 0.2 # 模型温度,新V3建议0.1-0.3
|
||||
max_tokens = 2048 # 最大输出token数
|
||||
slow_threshold = 15.0 # 慢请求阈值(秒),模型等待回复时间超过此值会输出警告日志
|
||||
|
||||
[model_task_config.utils_small] # 在麦麦的一些组件中使用的小模型,消耗量较大,建议使用速度较快的小模型
|
||||
model_list = ["qwen3-30b","qwen3-next-80b"]
|
||||
temperature = 0.7
|
||||
max_tokens = 2048
|
||||
slow_threshold = 10.0
|
||||
|
||||
[model_task_config.tool_use] #工具调用模型,需要使用支持工具调用的模型
|
||||
model_list = ["qwen3-30b","qwen3-next-80b"]
|
||||
temperature = 0.7
|
||||
max_tokens = 800
|
||||
slow_threshold = 10.0
|
||||
|
||||
[model_task_config.replyer] # 首要回复模型,还用于表达器和表达方式学习
|
||||
model_list = ["siliconflow-deepseek-v3.2","siliconflow-deepseek-v3.2-think","siliconflow-glm-4.6","siliconflow-glm-4.6-think"]
|
||||
temperature = 0.3 # 模型温度,新V3建议0.1-0.3
|
||||
max_tokens = 2048
|
||||
slow_threshold = 25.0
|
||||
|
||||
[model_task_config.planner] #决策:负责决定麦麦该什么时候回复的模型
|
||||
model_list = ["siliconflow-deepseek-v3.2"]
|
||||
temperature = 0.3
|
||||
max_tokens = 800
|
||||
slow_threshold = 12.0
|
||||
|
||||
[model_task_config.vlm] # 图像识别模型
|
||||
model_list = ["qwen3-vl-30"]
|
||||
max_tokens = 256
|
||||
slow_threshold = 15.0
|
||||
|
||||
[model_task_config.voice] # 语音识别模型
|
||||
model_list = ["sensevoice-small"]
|
||||
slow_threshold = 12.0
|
||||
|
||||
#嵌入模型
|
||||
# 嵌入模型
|
||||
[model_task_config.embedding]
|
||||
model_list = ["bge-m3"]
|
||||
slow_threshold = 5.0
|
||||
|
||||
#------------LPMM知识库模型------------
|
||||
# ------------LPMM知识库模型------------
|
||||
|
||||
[model_task_config.lpmm_entity_extract] # 实体提取模型
|
||||
model_list = ["siliconflow-deepseek-v3.2"]
|
||||
temperature = 0.2
|
||||
max_tokens = 800
|
||||
slow_threshold = 20.0
|
||||
|
||||
[model_task_config.lpmm_rdf_build] # RDF构建模型
|
||||
model_list = ["siliconflow-deepseek-v3.2"]
|
||||
temperature = 0.2
|
||||
max_tokens = 800
|
||||
slow_threshold = 20.0
|
||||
|
||||
[model_task_config.lpmm_qa] # 问答模型
|
||||
model_list = ["siliconflow-deepseek-v3.2"]
|
||||
temperature = 0.7
|
||||
max_tokens = 800
|
||||
slow_threshold = 20.0
|
||||
|
||||
@@ -8,23 +8,23 @@ if edges:
|
||||
e = edges[0]
|
||||
print(f"Edge tuple: {e}")
|
||||
print(f"Edge tuple type: {type(e)}")
|
||||
|
||||
|
||||
edge_data = kg.graph[e[0], e[1]]
|
||||
print(f"\nEdge data type: {type(edge_data)}")
|
||||
print(f"Edge data: {edge_data}")
|
||||
print(f"Has 'get' method: {hasattr(edge_data, 'get')}")
|
||||
print(f"Is dict: {isinstance(edge_data, dict)}")
|
||||
|
||||
|
||||
# 尝试不同的访问方式
|
||||
try:
|
||||
print(f"\nUsing []: {edge_data['weight']}")
|
||||
except Exception as e:
|
||||
print(f"Using [] failed: {e}")
|
||||
|
||||
|
||||
try:
|
||||
print(f"Using .get(): {edge_data.get('weight')}")
|
||||
except Exception as e:
|
||||
print(f"Using .get() failed: {e}")
|
||||
|
||||
|
||||
# 查看所有属性
|
||||
print(f"\nDir: {[x for x in dir(edge_data) if not x.startswith('_')]}")
|
||||
|
||||
File diff suppressed because one or more lines are too long
16
webui/dist/assets/codemirror-BHeANvwm.js
vendored
Normal file
16
webui/dist/assets/codemirror-BHeANvwm.js
vendored
Normal file
File diff suppressed because one or more lines are too long
5
webui/dist/assets/dnd-Dyi3CnuX.js
vendored
Normal file
5
webui/dist/assets/dnd-Dyi3CnuX.js
vendored
Normal file
File diff suppressed because one or more lines are too long
1
webui/dist/assets/icons-Bom2zaMH.js
vendored
1
webui/dist/assets/icons-Bom2zaMH.js
vendored
File diff suppressed because one or more lines are too long
1
webui/dist/assets/icons-DUfC2NKX.js
vendored
Normal file
1
webui/dist/assets/icons-DUfC2NKX.js
vendored
Normal file
File diff suppressed because one or more lines are too long
1
webui/dist/assets/index-BwXMDuHV.css
vendored
1
webui/dist/assets/index-BwXMDuHV.css
vendored
File diff suppressed because one or more lines are too long
381
webui/dist/assets/index-CrIP7TYI.js
vendored
381
webui/dist/assets/index-CrIP7TYI.js
vendored
File diff suppressed because one or more lines are too long
52
webui/dist/assets/index-DJb_iiTR.js
vendored
Normal file
52
webui/dist/assets/index-DJb_iiTR.js
vendored
Normal file
File diff suppressed because one or more lines are too long
1
webui/dist/assets/index-QJDQd8Xo.css
vendored
Normal file
1
webui/dist/assets/index-QJDQd8Xo.css
vendored
Normal file
File diff suppressed because one or more lines are too long
295
webui/dist/assets/markdown-A1ShuLvG.js
vendored
Normal file
295
webui/dist/assets/markdown-A1ShuLvG.js
vendored
Normal file
File diff suppressed because one or more lines are too long
27
webui/dist/assets/misc-DyBU7ISD.js
vendored
Normal file
27
webui/dist/assets/misc-DyBU7ISD.js
vendored
Normal file
File diff suppressed because one or more lines are too long
45
webui/dist/assets/radix-core-C3XKqQJw.js
vendored
Normal file
45
webui/dist/assets/radix-core-C3XKqQJw.js
vendored
Normal file
File diff suppressed because one or more lines are too long
12
webui/dist/assets/radix-extra-BM7iD6Dt.js
vendored
Normal file
12
webui/dist/assets/radix-extra-BM7iD6Dt.js
vendored
Normal file
File diff suppressed because one or more lines are too long
2
webui/dist/assets/reactflow-B3n3_Vkw.js
vendored
Normal file
2
webui/dist/assets/reactflow-B3n3_Vkw.js
vendored
Normal file
File diff suppressed because one or more lines are too long
5
webui/dist/assets/router-CWhjJi2n.js
vendored
Normal file
5
webui/dist/assets/router-CWhjJi2n.js
vendored
Normal file
File diff suppressed because one or more lines are too long
2
webui/dist/assets/router-SinpzM5S.js
vendored
2
webui/dist/assets/router-SinpzM5S.js
vendored
File diff suppressed because one or more lines are too long
45
webui/dist/assets/ui-vendor-BLBhIcJ8.js
vendored
45
webui/dist/assets/ui-vendor-BLBhIcJ8.js
vendored
File diff suppressed because one or more lines are too long
11
webui/dist/assets/uppy-BHC3OXBx.js
vendored
Normal file
11
webui/dist/assets/uppy-BHC3OXBx.js
vendored
Normal file
File diff suppressed because one or more lines are too long
6
webui/dist/assets/utils-CCeOswSm.js
vendored
Normal file
6
webui/dist/assets/utils-CCeOswSm.js
vendored
Normal file
File diff suppressed because one or more lines are too long
20
webui/dist/index.html
vendored
20
webui/dist/index.html
vendored
@@ -7,13 +7,21 @@
|
||||
<link rel="icon" type="image/x-icon" href="/maimai.ico" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>MaiBot Dashboard</title>
|
||||
<script type="module" crossorigin src="/assets/index-CrIP7TYI.js"></script>
|
||||
<script type="module" crossorigin src="/assets/index-DJb_iiTR.js"></script>
|
||||
<link rel="modulepreload" crossorigin href="/assets/react-vendor-Dtc2IqVY.js">
|
||||
<link rel="modulepreload" crossorigin href="/assets/router-SinpzM5S.js">
|
||||
<link rel="modulepreload" crossorigin href="/assets/charts-BH1Uno6i.js">
|
||||
<link rel="modulepreload" crossorigin href="/assets/ui-vendor-BLBhIcJ8.js">
|
||||
<link rel="modulepreload" crossorigin href="/assets/icons-Bom2zaMH.js">
|
||||
<link rel="stylesheet" crossorigin href="/assets/index-BwXMDuHV.css">
|
||||
<link rel="modulepreload" crossorigin href="/assets/router-CWhjJi2n.js">
|
||||
<link rel="modulepreload" crossorigin href="/assets/utils-CCeOswSm.js">
|
||||
<link rel="modulepreload" crossorigin href="/assets/radix-core-C3XKqQJw.js">
|
||||
<link rel="modulepreload" crossorigin href="/assets/radix-extra-BM7iD6Dt.js">
|
||||
<link rel="modulepreload" crossorigin href="/assets/charts-Dhri-zxi.js">
|
||||
<link rel="modulepreload" crossorigin href="/assets/icons-DUfC2NKX.js">
|
||||
<link rel="modulepreload" crossorigin href="/assets/codemirror-BHeANvwm.js">
|
||||
<link rel="modulepreload" crossorigin href="/assets/misc-DyBU7ISD.js">
|
||||
<link rel="modulepreload" crossorigin href="/assets/dnd-Dyi3CnuX.js">
|
||||
<link rel="modulepreload" crossorigin href="/assets/uppy-BHC3OXBx.js">
|
||||
<link rel="modulepreload" crossorigin href="/assets/markdown-A1ShuLvG.js">
|
||||
<link rel="modulepreload" crossorigin href="/assets/reactflow-B3n3_Vkw.js">
|
||||
<link rel="stylesheet" crossorigin href="/assets/index-QJDQd8Xo.css">
|
||||
</head>
|
||||
<body>
|
||||
<div id="root" class="notranslate"></div>
|
||||
|
||||
Reference in New Issue
Block a user