merge: sync upstream/r-dev and resolve real conflicts

This commit is contained in:
A-Dawn
2026-03-24 15:36:26 +08:00
114 changed files with 15841 additions and 5236 deletions

View File

@@ -17,6 +17,7 @@
1. 尽量保持良好的注释
2. 如果原来的代码中有注释,则重构的时候,除非这部分代码被删除,否则相同功能的代码应该保留注释(可以对注释进行修改以保持准确性,但不应该删除注释)。
3. 如果原来的代码中没有注释,则重构的时候,如果某个功能块的代码较长或者逻辑较为复杂,则应该添加注释来解释这部分代码的功能和逻辑。
4. 对于类,方法以及模块的注释,首选使用的注释格式为 Google DocStr 格式,但保证语言为简体中文
## 类型注解规范
1. 重构代码时,如果原来的代码中有类型注解,则相同功能的代码应该保留类型注解(可以对类型注解进行修改以保持准确性,但不应该删除类型注解)。
2. 重构代码时,如果原来的代码中没有类型注解,则重构的时候,如果某个函数的功能较为复杂或者参数较多,则应该添加类型注解来提高代码的可读性和可维护性。(对于简单的变量,可以不添加类型注解)
@@ -35,3 +36,7 @@
# 运行/调试/构建/测试/依赖
优先使用uv
依赖项以 pyproject.toml 为准
# 语言规范
项目的首选语言为简体中文,无论是注释语言,日志展示语言,还是 WebUI 展示语言都应该首要以简体中文为首要实现目标

137
README.md
View File

@@ -1,21 +1,21 @@
<div align="center">
<!-- Language Switcher -->
<a href="README.md">简体中文</a> | <a href="docs/README_EN.md">English</a>
<a href="docs/README_CN.md">简体中文</a> | <a href="README.md">English</a>
<br>
<br>
<h1>麦麦 MaiBot <sub><small>MaiSaka</small></sub></h1>
<h1>MaiBot <sub><small>MaiSaka</small></sub></h1>
<!-- Badges Row -->
<p>
<img src="https://img.shields.io/badge/Python-3.10+-blue" alt="Python Version">
<img src="https://img.shields.io/github/license/Mai-with-u/MaiBot?label=%E5%8D%8F%E8%AE%AE" alt="License">
<img src="https://img.shields.io/badge/状态-开发中-yellow" alt="Status">
<img src="https://img.shields.io/github/contributors/Mai-with-u/MaiBot.svg?style=flat&label=%E8%B4%A1%E7%8C%AE%E8%80%85" alt="Contributors">
<img src="https://img.shields.io/github/forks/Mai-with-u/MaiBot.svg?style=flat&label=%E5%88%86%E6%94%AF%E6%95%B0" alt="Forks">
<img src="https://img.shields.io/github/stars/Mai-with-u/MaiBot?style=flat&label=%E6%98%9F%E6%A0%87%E6%95%B0" alt="Stars">
<img src="https://img.shields.io/github/license/Mai-with-u/MaiBot?label=License" alt="License">
<img src="https://img.shields.io/badge/Status-In%20Development-yellow" alt="Status">
<img src="https://img.shields.io/github/contributors/Mai-with-u/MaiBot.svg?style=flat&label=Contributors" alt="Contributors">
<img src="https://img.shields.io/github/forks/Mai-with-u/MaiBot.svg?style=flat&label=Forks" alt="Forks">
<img src="https://img.shields.io/github/stars/Mai-with-u/MaiBot?style=flat&label=Stars" alt="Stars">
<a href="https://deepwiki.com/DrSmoothl/MaiBot"><img src="https://deepwiki.com/badge.svg" alt="Ask DeepWiki"></a>
</p>
</div>
@@ -25,31 +25,24 @@
<!-- Mascot on the Right (Float) -->
<img src="depends-data/maimai-v2.png" align="right" width="40%" alt="MaiBot Character" style="margin-left: 20px; margin-bottom: 20px;">
## 介绍
## Introduction
麦麦MaiSaka 是一个基于大语言模型的可交互智能体
MaiSaka is an interactive agent based on large language models.
MaiSaka 不仅仅是一个机器人,不仅仅是一个可以帮你完成任务的“有帮助的助手”,她还是一个致力于了解你,并以真实人类的风格进行交互的数字生命,她不追求完美,她不追求高效,但追求亲切和真实。
MaiSaka is more than just a bot, and more than a "helpful assistant" that completes tasks. She is a digital life form that tries to understand you and interact in a genuinely human style. She does not pursue perfection or efficiency above all else. She pursues warmth and authenticity.
- 💭 **No one likes GPT-sounding dialogue**: MaiSaka uses a more natural conversational style. Instead of long-winded markdown-heavy replies, she chats in a way that feels casual, varied, and human.
- 🎭 **No longer stuck in rigid Q&A**: She knows when to speak, how to read the room, when to join a conversation, and when to stay quiet.
- 🧠 **MaiSaka becoming human**: In group conversations, MaiSaka imitates how people around her speak, learns new slang and in-group language, and keeps evolving.
- ❤️ **Always learning more about you**: Inspired by personality theory in psychology, MaiSaka gradually builds an understanding of your preferences, traits, habits, and behavior style.
- 🔌 **Plugin system**: Provides powerful APIs and an event system with virtually unlimited room for extension.
- 💭 **没有人喜欢GPT的语言风格**麦麦使用了更加自然贴合人类对话习惯的交互方式不是长篇大论或者markdown格式的分点而是或长或短的闲谈。
- 🎭 **不再是傻乎乎的一问一答**:懂得在合适的时间说话,把握聊天中的气氛,在合适的时候开口,在合适的时候闭嘴。
- 🧠 **麦麦·成为人类**:在多人对话中,麦麦会模仿其他人的的说话风格,还会自主理解新词或者小圈子里的黑话,不断进化。
- ❤️ **永远都在更加了解你**:基于心理学中人格理论,麦麦会不断积累对于你的了解,不论是你的信息,喜恶或是行为风格,她都记在心里。
- 🔌 **插件系统**:提供强大的 API 和事件系统,无限扩展可能。
### 快速导航
### Quick Navigation
<p>
<a href="https://www.bilibili.com/video/BV1amAneGE3P">🌟 演示视频</a> &nbsp;|&nbsp;
<a href="#-更新和安装">📦 快速入门</a> &nbsp;|&nbsp;
<a href="#-部署教程">📃 核心文档</a> &nbsp;|&nbsp;
<a href="#-讨论与社区">💬 加入社区</a>
<a href="https://www.bilibili.com/video/BV1amAneGE3P">🌟 Demo Video</a> &nbsp;|&nbsp;
<a href="#-updates-and-installation">📦 Quick Start</a> &nbsp;|&nbsp;
<a href="#-deployment-guide">📃 Core Documentation</a> &nbsp;|&nbsp;
<a href="#-discussion-and-community">💬 Join Community</a>
</p>
<!-- Clear float to ensure subsequent content starts below the image area if text is short -->
@@ -60,103 +53,103 @@ MaiSaka 不仅仅是一个机器人,不仅仅是一个可以帮你完成任务
<a href="https://www.bilibili.com/video/BV1amAneGE3P" target="_blank">
<picture>
<source media="(max-width: 600px)" srcset="depends-data/video.png" width="100%">
<img src="depends-data/video.png" width="60%" alt="麦麦演示视频" style="border-radius: 10px; box-shadow: 0 4px 8px rgba(0,0,0,0.1);">
<img src="depends-data/video.png" width="60%" alt="MaiSaka Demo Video" style="border-radius: 10px; box-shadow: 0 4px 8px rgba(0,0,0,0.1);">
</picture>
<br>
<small>前往观看麦麦演示视频</small>
<small>Watch the MaiSaka demo video</small>
</a>
</div>
---
## 🔥 更新和安装
## 🔥 Updates and Installation
> **最新版本: v1.0.0** ([📄 更新日志](changelogs/changelog.md))
> **Latest Version: v1.0.0** ([📄 Changelog](changelogs/changelog.md))
- **下载**: 前往 [Release](https://github.com/MaiM-with-u/MaiBot/releases/) 页面下载最新版本
- **启动器**: [Mailauncher](https://github.com/MaiM-with-u/mailauncher/releases/) (仅支持 MacOS, 早期开发中)
- **Download**: Visit the [Release](https://github.com/MaiM-with-u/MaiBot/releases/) page to get the latest version.
- **Launcher**: [Mailauncher](https://github.com/MaiM-with-u/mailauncher/releases/) (MacOS only, still in early development).
| 分支 | 说明 |
| Branch | Description |
| :--- | :--- |
| `main` | ✅ **稳定发布版本 (推荐)** |
| `dev` | 🚧 开发测试版本,包含新功能,可能不稳定 |
| `main` | ✅ **Stable release (recommended)** |
| `dev` | 🚧 Development testing branch with new features, may be unstable |
### 📚 部署教程
👉 **[🚀 最新版本部署教程](https://docs.mai-mai.org/manual/deployment/mmc_deploy_windows.html)**
### 📚 Deployment Guide
👉 **[🚀 Latest Deployment Guide](https://docs.mai-mai.org/manual/deployment/mmc_deploy_windows.html)**
---
## 💬 讨论与社区
## 💬 Discussion and Community
我们欢迎所有对 MaiBot 感兴趣的朋友加入!
We welcome everyone interested in MaiBot to join us.
| 类别 | 群组 | 说明 |
| Category | Group | Description |
| :--- | :--- | :--- |
| **技术交流** | [麦麦脑电图](https://qm.qq.com/q/RzmCiRtHEW) | 技术交流/答疑 |
| **技术交流** | [麦麦大脑磁共振](https://qm.qq.com/q/VQ3XZrWgMs) | 技术交流/答疑 |
| **技术交流** | [麦麦要当VTB](https://qm.qq.com/q/wGePTl1UyY) | 技术交流/答疑 |
| **闲聊吹水** | [麦麦之闲聊群](https://qm.qq.com/q/JxvHZnxyec) | 仅限闲聊,不答疑 |
| **插件开发** | [插件开发群](https://qm.qq.com/q/1036092828) | 进阶开发与测试 |
| **Technical** | [MaiBrain EEG](https://qm.qq.com/q/RzmCiRtHEW) | Technical discussion / Q&A |
| **Technical** | [MaiBrain MRI](https://qm.qq.com/q/VQ3XZrWgMs) | Technical discussion / Q&A |
| **Technical** | [Mai Wants to Be a VTuber](https://qm.qq.com/q/wGePTl1UyY) | Technical discussion / Q&A |
| **Casual Chat** | [Mai Casual Chat Group](https://qm.qq.com/q/JxvHZnxyec) | Casual chat only, no support |
| **Plugin Development** | [Plugin Dev Group](https://qm.qq.com/q/1036092828) | Advanced development and testing |
---
## 📚 文档
## 📚 Documentation
> [!NOTE]
> 部分内容可能更新不够及时,请注意版本对应。
> Some content may not be updated promptly, so please pay attention to version compatibility.
- **[📚 核心 Wiki 文档](https://docs.mai-mai.org)**: 最全面的文档中心,了解麦麦的一切。
- **[📚 Core Wiki Documentation](https://docs.mai-mai.org)**: The most comprehensive documentation hub for everything about MaiSaka.
### 🧩 衍生项目
### 🧩 Related Projects
- **[Amaidesu](https://github.com/MaiM-with-u/Amaidesu)**: 让麦麦在B站开播
- **[MoFox_Bot](https://github.com/MoFox-Studio/MoFox-Core)**: 基于 MaiCore 0.10.0 的增强型 Fork更稳定更有趣。
- **[MaiCraft](https://github.com/MaiM-with-u/Maicraft)**: 让麦麦陪你玩 Minecraft (暂时停止维护中)。
- **[Amaidesu](https://github.com/MaiM-with-u/Amaidesu)**: Let MaiSaka stream on Bilibili.
- **[MoFox_Bot](https://github.com/MoFox-Studio/MoFox-Core)**: An enhanced fork based on MaiCore 0.10.0, with improved stability and more fun features.
- **[MaiCraft](https://github.com/MaiM-with-u/Maicraft)**: Let MaiSaka accompany you in Minecraft (currently paused).
---
## 💡 设计理念
## 💡 Design Philosophy
> **千石可乐说:**
> - 这个项目最初只是为了给牛牛 bot 添加一点额外的功能,但是功能越写越多,最后决定重写。其目的是为了创造一个活跃在 QQ 群聊的"生命体"。目的并不是为了写一个功能齐全的机器人,而是一个尽可能让人感知到真实的类人存在。
> - 程序的功能设计理念基于一个核心的原则:"最像而不是好"。
> - 如果人类真的需要一个 AI 来陪伴自己,并不是所有人都需要一个完美的,能解决所有问题的"helpful assistant",而是一个会犯错的,拥有自己感知和想法的"生命形式"。
> **SengokuCola says:**
> - This project originally started as a few extra features for the NiuNiu bot, but it kept growing until a full rewrite became inevitable. The goal was to create a "life form" active in QQ group chats, not a feature-complete bot, but something as human-like and real-feeling as possible.
> - The core design principle is: "more lifelike, not merely better."
> - If people truly want AI companionship, not everyone needs a perfect "helpful assistant" that solves every problem. Some people may want a life form that can make mistakes and has its own perceptions and thoughts.
> **xxxxx说:**
> **xxxxx says:**
> *Code is open, but the soul is yours.*
---
## 🙋 贡献和致谢
## 🙋 Contributing and Acknowledgments
欢迎参与贡献!请先阅读 [贡献指南](docs-src/CONTRIBUTE.md)
Contributions are welcome. Please read the [Contribution Guide](docs-src/CONTRIBUTE.md) first.
### 🌟 贡献者
### 🌟 Contributors
<a href="https://github.com/MaiM-with-u/MaiBot/graphs/contributors">
<img alt="contributors" src="https://contrib.rocks/image?repo=MaiM-with-u/MaiBot" />
</a>
### ❤️ 特别致谢
### ❤️ Special Thanks
- **[萨卡班甲鱼](https://en.wikipedia.org/wiki/Sacabambaspis)**: 千石可乐很喜欢的生物。
- **[略nd](https://space.bilibili.com/1344099355)**: 🎨 为麦麦绘制早期的精美人设。
- **[NapCat](https://github.com/NapNeko/NapCatQQ)**: 🚀 现代化的基于 NTQQ 的 Bot 协议实现。
- **[Sacabambaspis](https://en.wikipedia.org/wiki/Sacabambaspis)**: SengokuCola's favorite creature.
- **[略nd](https://space.bilibili.com/1344099355)**: Drew MaiSaka's beautiful early character design.
- **[NapCat](https://github.com/NapNeko/NapCatQQ)**: A modern NTQQ-based bot protocol implementation.
---
## 📊 仓库状态
## 📊 Repository Status
![Alt](https://repobeats.axiom.co/api/embed/9faca9fccfc467931b87dd357b60c6362b5cfae0.svg "麦麦仓库状态")
![Alt](https://repobeats.axiom.co/api/embed/9faca9fccfc467931b87dd357b60c6362b5cfae0.svg "MaiBot Repository Status")
### Star 趋势
[![Star 趋势](https://starchart.cc/MaiM-with-u/MaiBot.svg?variant=adaptive)](https://starchart.cc/MaiM-with-u/MaiBot)
### Star History
[![Star History](https://starchart.cc/MaiM-with-u/MaiBot.svg?variant=adaptive)](https://starchart.cc/MaiM-with-u/MaiBot)
---
## 📌 注意事项 & License
## 📌 Notice & License
> [!IMPORTANT]
> 使用前请阅读 [用户协议 (EULA)](EULA.md) 和 [隐私协议](PRIVACY.md)。AI 生成内容请仔细甄别。
> Please read the [End User License Agreement (EULA)](EULA.md) and [Privacy Policy](PRIVACY.md) before use. Please evaluate AI-generated content carefully.
**License**: GPL-3.0

162
docs/README_CN.md Normal file
View File

@@ -0,0 +1,162 @@
<div align="center">
<!-- Language Switcher -->
<a href="README_CN.md">简体中文</a> | <a href="../README.md">English</a>
<br>
<br>
<h1>麦麦 MaiBot <sub><small>MaiSaka</small></sub></h1>
<!-- Badges Row -->
<p>
<img src="https://img.shields.io/badge/Python-3.10+-blue" alt="Python Version">
<img src="https://img.shields.io/github/license/Mai-with-u/MaiBot?label=%E5%8D%8F%E8%AE%AE" alt="License">
<img src="https://img.shields.io/badge/状态-开发中-yellow" alt="Status">
<img src="https://img.shields.io/github/contributors/Mai-with-u/MaiBot.svg?style=flat&label=%E8%B4%A1%E7%8C%AE%E8%80%85" alt="Contributors">
<img src="https://img.shields.io/github/forks/Mai-with-u/MaiBot.svg?style=flat&label=%E5%88%86%E6%94%AF%E6%95%B0" alt="Forks">
<img src="https://img.shields.io/github/stars/Mai-with-u/MaiBot?style=flat&label=%E6%98%9F%E6%A0%87%E6%95%B0" alt="Stars">
<a href="https://deepwiki.com/DrSmoothl/MaiBot"><img src="https://deepwiki.com/badge.svg" alt="Ask DeepWiki"></a>
</p>
</div>
<br>
<!-- Mascot on the Right (Float) -->
<img src="../depends-data/maimai-v2.png" align="right" width="40%" alt="MaiBot Character" style="margin-left: 20px; margin-bottom: 20px;">
## 介绍
麦麦MaiSaka 是一个基于大语言模型的可交互智能体
MaiSaka 不仅仅是一个机器人,不仅仅是一个可以帮你完成任务的“有帮助的助手”,她还是一个致力于了解你,并以真实人类的风格进行交互的数字生命,她不追求完美,她不追求高效,但追求亲切和真实。
- 💭 **没有人喜欢GPT的语言风格**麦麦使用了更加自然贴合人类对话习惯的交互方式不是长篇大论或者markdown格式的分点而是或长或短的闲谈。
- 🎭 **不再是傻乎乎的一问一答**:懂得在合适的时间说话,把握聊天中的气氛,在合适的时候开口,在合适的时候闭嘴。
- 🧠 **麦麦·成为人类**:在多人对话中,麦麦会模仿其他人的的说话风格,还会自主理解新词或者小圈子里的黑话,不断进化。
- ❤️ **永远都在更加了解你**:基于心理学中人格理论,麦麦会不断积累对于你的了解,不论是你的信息,喜恶或是行为风格,她都记在心里。
- 🔌 **插件系统**:提供强大的 API 和事件系统,无限扩展可能。
### 快速导航
<p>
<a href="https://www.bilibili.com/video/BV1amAneGE3P">🌟 演示视频</a> &nbsp;|&nbsp;
<a href="#-更新和安装">📦 快速入门</a> &nbsp;|&nbsp;
<a href="#-部署教程">📃 核心文档</a> &nbsp;|&nbsp;
<a href="#-讨论与社区">💬 加入社区</a>
</p>
<!-- Clear float to ensure subsequent content starts below the image area if text is short -->
<br clear="both">
<div align="center">
<br>
<a href="https://www.bilibili.com/video/BV1amAneGE3P" target="_blank">
<picture>
<source media="(max-width: 600px)" srcset="../depends-data/video.png" width="100%">
<img src="../depends-data/video.png" width="60%" alt="麦麦演示视频" style="border-radius: 10px; box-shadow: 0 4px 8px rgba(0,0,0,0.1);">
</picture>
<br>
<small>前往观看麦麦演示视频</small>
</a>
</div>
---
## 🔥 更新和安装
> **最新版本: v1.0.0** ([📄 更新日志](../changelogs/changelog.md))
- **下载**: 前往 [Release](https://github.com/MaiM-with-u/MaiBot/releases/) 页面下载最新版本
- **启动器**: [Mailauncher](https://github.com/MaiM-with-u/mailauncher/releases/) (仅支持 MacOS, 早期开发中)
| 分支 | 说明 |
| :--- | :--- |
| `main` | ✅ **稳定发布版本 (推荐)** |
| `dev` | 🚧 开发测试版本,包含新功能,可能不稳定 |
### 📚 部署教程
👉 **[🚀 最新版本部署教程](https://docs.mai-mai.org/manual/deployment/mmc_deploy_windows.html)**
---
## 💬 讨论与社区
我们欢迎所有对 MaiBot 感兴趣的朋友加入!
| 类别 | 群组 | 说明 |
| :--- | :--- | :--- |
| **技术交流** | [麦麦脑电图](https://qm.qq.com/q/RzmCiRtHEW) | 技术交流/答疑 |
| **技术交流** | [麦麦大脑磁共振](https://qm.qq.com/q/VQ3XZrWgMs) | 技术交流/答疑 |
| **技术交流** | [麦麦要当VTB](https://qm.qq.com/q/wGePTl1UyY) | 技术交流/答疑 |
| **闲聊吹水** | [麦麦之闲聊群](https://qm.qq.com/q/JxvHZnxyec) | 仅限闲聊,不答疑 |
| **插件开发** | [插件开发群](https://qm.qq.com/q/1036092828) | 进阶开发与测试 |
---
## 📚 文档
> [!NOTE]
> 部分内容可能更新不够及时,请注意版本对应。
- **[📚 核心 Wiki 文档](https://docs.mai-mai.org)**: 最全面的文档中心,了解麦麦的一切。
### 🧩 衍生项目
- **[Amaidesu](https://github.com/MaiM-with-u/Amaidesu)**: 让麦麦在B站开播
- **[MoFox_Bot](https://github.com/MoFox-Studio/MoFox-Core)**: 基于 MaiCore 0.10.0 的增强型 Fork更稳定更有趣。
- **[MaiCraft](https://github.com/MaiM-with-u/Maicraft)**: 让麦麦陪你玩 Minecraft (暂时停止维护中)。
---
## 💡 设计理念
> **千石可乐说:**
> - 这个项目最初只是为了给牛牛 bot 添加一点额外的功能,但是功能越写越多,最后决定重写。其目的是为了创造一个活跃在 QQ 群聊的"生命体"。目的并不是为了写一个功能齐全的机器人,而是一个尽可能让人感知到真实的类人存在。
> - 程序的功能设计理念基于一个核心的原则:"最像而不是好"。
> - 如果人类真的需要一个 AI 来陪伴自己,并不是所有人都需要一个完美的,能解决所有问题的"helpful assistant",而是一个会犯错的,拥有自己感知和想法的"生命形式"。
> **xxxxx说**
> *Code is open, but the soul is yours.*
---
## 🙋 贡献和致谢
欢迎参与贡献!请先阅读 [贡献指南](../docs-src/CONTRIBUTE.md)。
### 🌟 贡献者
<a href="https://github.com/MaiM-with-u/MaiBot/graphs/contributors">
<img alt="contributors" src="https://contrib.rocks/image?repo=MaiM-with-u/MaiBot" />
</a>
### ❤️ 特别致谢
- **[萨卡班甲鱼](https://en.wikipedia.org/wiki/Sacabambaspis)**: 千石可乐很喜欢的生物。
- **[略nd](https://space.bilibili.com/1344099355)**: 🎨 为麦麦绘制早期的精美人设。
- **[NapCat](https://github.com/NapNeko/NapCatQQ)**: 🚀 现代化的基于 NTQQ 的 Bot 协议实现。
---
## 📊 仓库状态
![Alt](https://repobeats.axiom.co/api/embed/9faca9fccfc467931b87dd357b60c6362b5cfae0.svg "麦麦仓库状态")
### Star 趋势
[![Star 趋势](https://starchart.cc/MaiM-with-u/MaiBot.svg?variant=adaptive)](https://starchart.cc/MaiM-with-u/MaiBot)
---
## 📌 注意事项 & License
> [!IMPORTANT]
> 使用前请阅读 [用户协议 (EULA)](../EULA.md) 和 [隐私协议](../PRIVACY.md)。AI 生成内容请仔细甄别。
**License**: GPL-3.0

View File

@@ -1,7 +1,7 @@
<div align="center">
<!-- Language Switcher -->
<a href="../README.md">简体中文</a> | <a href="README_EN.md">English</a>
<a href="README_CN.md">简体中文</a> | <a href="../README.md">English</a>
<br>
<br>

View File

@@ -41,7 +41,7 @@ This plan is based on the checked-in code, not on assumptions from previous draf
| `src/person_info/person_info.py:247` | `_is_bot_self(self, platform, user_id)` | Duplicate logic with same QQ fallback |
Wrong-order call sites (8 total):
- `src/bw_learner/expression_learner.py` x3 (lines 158, 241, 301)
- `src/learners/expression_learner.py` x3 (lines 158, 241, 301)
- `src/common/utils/utils_message.py` x4 (lines 370, 440, 476, 515)
- `src/webui/routers/chat/support.py` x1 (line 65)
@@ -122,7 +122,7 @@ Make `src/chat/utils/utils.py::is_bot_self(platform, user_id)` the only real imp
- `src/common/utils/system_utils.py`
- `src/chat/utils/utils.py`
- `src/person_info/person_info.py`
- `src/bw_learner/expression_learner.py`
- `src/learners/expression_learner.py`
- `src/common/utils/utils_message.py`
- `src/webui/routers/chat/support.py`
- tests
@@ -468,7 +468,7 @@ When stopping, name: the exact file(s), the blocking mismatch, why it is outside
| Phase | Allowed files |
|-------|---------------|
| Phase 0 | `src/common/utils/system_utils.py`, `src/chat/utils/utils.py`, `src/person_info/person_info.py`, `src/bw_learner/expression_learner.py`, `src/common/utils/utils_message.py`, `src/webui/routers/chat/support.py`, tests (including `pytests/utils_test/message_utils_test.py`) |
| Phase 0 | `src/common/utils/system_utils.py`, `src/chat/utils/utils.py`, `src/person_info/person_info.py`, `src/learners/expression_learner.py`, `src/common/utils/utils_message.py`, `src/webui/routers/chat/support.py`, tests (including `pytests/utils_test/message_utils_test.py`) |
| Phase 1 | `src/chat/utils/utils.py`, `src/chat/planner_actions/planner.py`, `src/chat/utils/statistic.py`, `src/common/message_repository.py`, `src/webui/routers/chat/support.py`, `src/services/send_service.py`, `src/chat/replyer/group_generator.py`, `src/chat/replyer/private_generator.py`, `src/chat/brain_chat/PFC/message_sender.py`, `src/person_info/person_info.py`, tests |
### INVALID OUTPUT EXAMPLES

View File

@@ -1,58 +1,40 @@
{
"manifest_version": 1,
"name": "发言频率控制插件|BetterFrequency Plugin",
"manifest_version": 2,
"version": "2.0.0",
"description": "控制聊天频率支持设置focus_value和talk_frequency调整值提供命令",
"name": "发言频率控制插件|BetterFrequency Plugin",
"description": "控制聊天频率,支持设置 focus_value 和 talk_frequency 调整值,并提供命令入口。",
"author": {
"name": "SengokuCola",
"url": "https://github.com/MaiM-with-u"
},
"license": "GPL-v3.0-or-later",
"host_application": {
"min_version": "1.0.0"
"urls": {
"repository": "https://github.com/SengokuCola/BetterFrequency",
"homepage": "https://github.com/SengokuCola/BetterFrequency",
"documentation": "https://github.com/SengokuCola/BetterFrequency",
"issues": "https://github.com/SengokuCola/BetterFrequency/issues"
},
"homepage_url": "https://github.com/SengokuCola/BetterFrequency",
"repository_url": "https://github.com/SengokuCola/BetterFrequency",
"keywords": [
"frequency",
"control",
"talk_frequency",
"plugin",
"shortcut"
"host_application": {
"min_version": "1.0.0",
"max_version": "1.0.0"
},
"sdk": {
"min_version": "2.0.0",
"max_version": "2.99.99"
},
"dependencies": [],
"capabilities": [
"send.text",
"frequency.set_adjust",
"frequency.get_current_talk_value",
"frequency.get_adjust"
],
"categories": [
"Chat",
"Frequency",
"Control"
],
"default_locale": "zh-CN",
"locales_path": "_locales",
"plugin_info": {
"is_built_in": false,
"plugin_type": "frequency",
"components": [
{
"type": "command",
"name": "set_talk_frequency",
"description": "设置当前聊天的talk_frequency调整值",
"pattern": "/chat talk_frequency <数字> 或 /chat t <数字>"
},
{
"type": "command",
"name": "show_frequency",
"description": "显示当前聊天的频率控制状态",
"pattern": "/chat show 或 /chat s"
}
],
"features": [
"设置talk_frequency调整值",
"调整当前聊天的发言频率",
"显示当前频率控制状态",
"实时频率控制调整",
"命令执行反馈(不保存消息)",
"支持完整命令和简化命令",
"快速操作支持"
"i18n": {
"default_locale": "zh-CN",
"locales_path": "_locales",
"supported_locales": [
"zh-CN"
]
},
"id": "SengokuCola.BetterFrequency"
}
"id": "sengokucola.betterfrequency"
}

View File

@@ -3,12 +3,18 @@
通过 /chat 命令设置和查看聊天频率。
"""
from maibot_sdk import MaiBotPlugin, Command
from maibot_sdk import Command, MaiBotPlugin
class BetterFrequencyPlugin(MaiBotPlugin):
"""聊天频率控制插件"""
async def on_load(self) -> None:
"""处理插件加载。"""
async def on_unload(self) -> None:
"""处理插件卸载。"""
@Command(
"set_talk_frequency",
description="设置当前聊天的talk_frequency值/chat talk_frequency <数字> 或 /chat t <数字>",
@@ -80,6 +86,25 @@ class BetterFrequencyPlugin(MaiBotPlugin):
await self.ctx.send.text(status_msg, stream_id)
return True, None, False
async def on_config_update(self, scope: str, config_data: dict[str, object], version: str) -> None:
"""处理配置热重载事件。
Args:
scope: 配置变更范围。
config_data: 最新配置数据。
version: 配置版本号。
"""
del scope
del config_data
del version
def create_plugin() -> BetterFrequencyPlugin:
"""创建聊天频率插件实例。
Returns:
BetterFrequencyPlugin: 新的聊天频率插件实例。
"""
def create_plugin():
return BetterFrequencyPlugin()

View File

@@ -1,67 +1,42 @@
{
"manifest_version": 1,
"name": "MCP桥接插件",
"manifest_version": 2,
"version": "2.0.0",
"description": "MCP (Model Context Protocol) 服务器的工具桥接到 MaiBot使麦麦能够调用外部 MCP 工具",
"name": "MCP桥接插件",
"description": "将 MCP (Model Context Protocol) 服务器的工具桥接到 MaiBot使麦麦能够调用外部 MCP 工具。",
"author": {
"name": "CharTyr",
"url": "https://github.com/CharTyr"
},
"license": "AGPL-3.0",
"host_application": {
"min_version": "0.11.6"
"urls": {
"repository": "https://github.com/CharTyr/MaiBot_MCPBridgePlugin",
"homepage": "https://github.com/CharTyr/MaiBot_MCPBridgePlugin",
"documentation": "https://github.com/CharTyr/MaiBot_MCPBridgePlugin",
"issues": "https://github.com/CharTyr/MaiBot_MCPBridgePlugin/issues"
},
"homepage_url": "https://github.com/CharTyr/MaiBot_MCPBridgePlugin",
"repository_url": "https://github.com/CharTyr/MaiBot_MCPBridgePlugin",
"keywords": [
"mcp",
"bridge",
"tool",
"integration",
"resources",
"prompts",
"post-process",
"cache",
"trace",
"permissions",
"import",
"export",
"claude-desktop",
"workflow",
"react",
"agent"
"host_application": {
"min_version": "0.11.6",
"max_version": "1.0.0"
},
"sdk": {
"min_version": "2.0.0",
"max_version": "2.99.99"
},
"dependencies": [
{
"type": "python_package",
"name": "mcp",
"version_spec": ">=0.0.0"
}
],
"categories": [
"工具扩展",
"外部集成"
"capabilities": [
"send.text"
],
"default_locale": "zh-CN",
"plugin_info": {
"is_built_in": false,
"components": [],
"features": [
"支持多个 MCP 服务器",
"自动发现并注册 MCP 工具",
"支持 stdio、SSE、HTTP、Streamable HTTP 四种传输方式",
"工具参数自动转换",
"心跳检测与自动重连",
"调用统计(次数、成功率、耗时)",
"WebUI 配置支持",
"Resources 支持(实验性)",
"Prompts 支持(实验性)",
"结果后处理LLM 摘要提炼)",
"工具禁用管理",
"调用链路追踪",
"工具调用缓存LRU",
"工具权限控制(群/用户级别)",
"配置导入导出Claude Desktop mcpServers",
"断路器模式(故障快速失败)",
"状态实时刷新",
"Workflow 硬流程(顺序执行多个工具)",
"Workflow 快速添加(表单式配置)",
"ReAct 软流程LLM 自主多轮调用)",
"双轨制架构(软流程 + 硬流程)"
"i18n": {
"default_locale": "zh-CN",
"supported_locales": [
"zh-CN"
]
},
"id": "MaiBot Community.MCPBridgePlugin"
"id": "chartyr.mcpbridge-plugin"
}

View File

@@ -1,68 +1,44 @@
{
"manifest_version": 1,
"name": "BetterEmoji",
"manifest_version": 2,
"version": "2.0.0",
"name": "BetterEmoji",
"description": "更好的表情包管理插件",
"author": {
"name": "SengokuCola",
"url": "https://github.com/SengokuCola"
},
"license": "GPL-v3.0-or-later",
"host_application": {
"min_version": "1.0.0"
"urls": {
"repository": "https://github.com/SengokuCola/BetterEmoji",
"homepage": "https://github.com/SengokuCola/BetterEmoji",
"documentation": "https://github.com/SengokuCola/BetterEmoji",
"issues": "https://github.com/SengokuCola/BetterEmoji/issues"
},
"homepage_url": "https://github.com/SengokuCola/BetterEmoji",
"repository_url": "https://github.com/SengokuCola/BetterEmoji",
"keywords": [
"emoji",
"manage",
"plugin"
"host_application": {
"min_version": "1.0.0",
"max_version": "1.0.0"
},
"sdk": {
"min_version": "2.0.0",
"max_version": "2.99.99"
},
"dependencies": [],
"capabilities": [
"emoji.get_random",
"emoji.get_count",
"emoji.get_info",
"emoji.get_all",
"emoji.register_emoji",
"emoji.delete_emoji",
"send.text",
"send.forward"
],
"categories": [
"Emoji",
"Management"
],
"default_locale": "zh-CN",
"locales_path": "_locales",
"plugin_info": {
"is_built_in": false,
"plugin_type": "emoji_manage",
"capabilities": [
"emoji.get_random",
"emoji.get_count",
"emoji.get_info",
"emoji.get_all",
"emoji.register_emoji",
"emoji.delete_emoji",
"send.text",
"send.forward"
],
"components": [
{
"type": "command",
"name": "add_emoji",
"description": "添加表情包",
"pattern": "/emoji add"
},
{
"type": "command",
"name": "emoji_list",
"description": "列表表情包",
"pattern": "/emoji list"
},
{
"type": "command",
"name": "delete_emoji",
"description": "删除表情包",
"pattern": "/emoji delete"
},
{
"type": "command",
"name": "random_emojis",
"description": "发送多张随机表情包",
"pattern": "/random_emojis"
}
"i18n": {
"default_locale": "zh-CN",
"locales_path": "_locales",
"supported_locales": [
"zh-CN"
]
},
"id": "SengokuCola.BetterEmoji"
}
"id": "sengokucola.betteremoji"
}

View File

@@ -3,17 +3,23 @@
通过 /emoji 命令管理表情包的添加、列表和删除。
"""
from maibot_sdk import Command, MaiBotPlugin
import base64
import datetime
import hashlib
import re
from maibot_sdk import MaiBotPlugin, Command
class EmojiManagePlugin(MaiBotPlugin):
"""表情包管理插件"""
async def on_load(self) -> None:
"""处理插件加载。"""
async def on_unload(self) -> None:
"""处理插件卸载。"""
# ===== 工具方法 =====
@staticmethod
@@ -208,6 +214,25 @@ class EmojiManagePlugin(MaiBotPlugin):
await self.ctx.send.forward(messages, stream_id)
return True, "已发送随机表情包", True
async def on_config_update(self, scope: str, config_data: dict[str, object], version: str) -> None:
"""处理配置热重载事件。
Args:
scope: 配置变更范围。
config_data: 最新配置数据。
version: 配置版本号。
"""
del scope
del config_data
del version
def create_plugin() -> EmojiManagePlugin:
"""创建表情包管理插件实例。
Returns:
EmojiManagePlugin: 新的表情包管理插件实例。
"""
def create_plugin():
return EmojiManagePlugin()

View File

@@ -1,88 +1,41 @@
{
"manifest_version": 1,
"name": "Hello World 示例插件 (Hello World Plugin)",
"manifest_version": 2,
"version": "2.0.0",
"description": "我的第一个MaiCore插件包含问候功能和时间查询等基础示例",
"name": "Hello World 示例插件 (Hello World Plugin)",
"description": "我的第一个 MaiCore 插件,包含问候功能和时间查询等基础示例",
"author": {
"name": "MaiBot开发团队",
"url": "https://github.com/MaiM-with-u"
},
"license": "GPL-v3.0-or-later",
"host_application": {
"min_version": "1.0.0"
"urls": {
"repository": "https://github.com/MaiM-with-u/maibot",
"homepage": "https://github.com/MaiM-with-u/maibot",
"documentation": "https://github.com/MaiM-with-u/maibot",
"issues": "https://github.com/MaiM-with-u/maibot/issues"
},
"homepage_url": "https://github.com/MaiM-with-u/maibot",
"repository_url": "https://github.com/MaiM-with-u/maibot",
"keywords": [
"demo",
"example",
"hello",
"greeting",
"tutorial"
"host_application": {
"min_version": "1.0.0",
"max_version": "1.0.0"
},
"sdk": {
"min_version": "2.0.0",
"max_version": "2.99.99"
},
"dependencies": [],
"capabilities": [
"send.text",
"send.forward",
"send.hybrid",
"emoji.get_random",
"config.get"
],
"categories": [
"Examples",
"Tutorial"
],
"default_locale": "zh-CN",
"locales_path": "_locales",
"plugin_info": {
"is_built_in": false,
"plugin_type": "example",
"capabilities": [
"send.text",
"send.forward",
"send.hybrid",
"emoji.get_random",
"config.get"
],
"components": [
{
"type": "tool",
"name": "compare_numbers",
"description": "比较两个数的大小"
},
{
"type": "action",
"name": "hello_greeting",
"description": "向用户发送问候消息"
},
{
"type": "action",
"name": "bye_greeting",
"description": "向用户发送告别消息",
"activation_modes": ["keyword"],
"keywords": ["再见", "bye", "88", "拜拜"]
},
{
"type": "command",
"name": "time",
"description": "查询当前时间",
"pattern": "/time"
},
{
"type": "command",
"name": "random_emojis",
"description": "发送多张随机表情包",
"pattern": "/random_emojis"
},
{
"type": "command",
"name": "test",
"description": "测试命令",
"pattern": "/test"
},
{
"type": "event_handler",
"name": "print_message_handler",
"description": "打印接收到的消息"
},
{
"type": "event_handler",
"name": "forward_messages_handler",
"description": "把接收到的消息转发到指定聊天ID"
}
"i18n": {
"default_locale": "zh-CN",
"locales_path": "_locales",
"supported_locales": [
"zh-CN"
]
},
"id": "MaiBot开发团队.maibot"
}
"id": "maibot-team.hello-world-plugin"
}

View File

@@ -3,16 +3,22 @@
你的第一个 MaiCore 插件,包含问候功能、时间查询等基础示例。
"""
from maibot_sdk import Action, Command, EventHandler, MaiBotPlugin, Tool
from maibot_sdk.types import ActivationType, EventType, ToolParameterInfo, ToolParamType
import datetime
import random
from maibot_sdk import MaiBotPlugin, Action, Command, Tool, EventHandler
from maibot_sdk.types import ActivationType, EventType, ToolParameterInfo, ToolParamType
class HelloWorldPlugin(MaiBotPlugin):
"""Hello World 示例插件"""
async def on_load(self) -> None:
"""处理插件加载。"""
async def on_unload(self) -> None:
"""处理插件卸载。"""
# ===== Tool 组件 =====
@Tool(
@@ -146,6 +152,25 @@ class HelloWorldPlugin(MaiBotPlugin):
return True, True, None, None, None
async def on_config_update(self, scope: str, config_data: dict[str, object], version: str) -> None:
"""处理配置热重载事件。
Args:
scope: 配置变更范围。
config_data: 最新配置数据。
version: 配置版本号。
"""
del scope
del config_data
del version
def create_plugin() -> HelloWorldPlugin:
"""创建 Hello World 示例插件实例。
Returns:
HelloWorldPlugin: 新的示例插件实例。
"""
def create_plugin():
return HelloWorldPlugin()

View File

@@ -19,7 +19,7 @@ dependencies = [
"jieba>=0.42.1",
"json-repair>=0.47.6",
"maim-message>=0.6.2",
"maibot-plugin-sdk>=1.2.3,<2.0.0",
"maibot-plugin-sdk>=2.0.0",
"msgpack>=1.1.2",
"numpy>=2.2.6",
"openai>=1.95.0",
@@ -55,6 +55,8 @@ dev = [
[tool.uv]
index-url = "https://pypi.tuna.tsinghua.edu.cn/simple"
[tool.uv.sources]
maibot-plugin-sdk = { path = "packages/maibot-plugin-sdk", editable = true }
[tool.ruff]

View File

@@ -0,0 +1,89 @@
"""测试表达方式自动检查任务的数据库读取行为。"""
from contextlib import contextmanager
from typing import Generator
import pytest
from sqlalchemy.pool import StaticPool
from sqlmodel import Session, SQLModel, create_engine
from src.bw_learner.expression_auto_check_task import ExpressionAutoCheckTask
from src.common.database.database_model import Expression
@pytest.fixture(name="expression_auto_check_engine")
def expression_auto_check_engine_fixture() -> Generator:
"""创建用于表达方式自动检查任务测试的内存数据库引擎。
Yields:
Generator: 供测试使用的 SQLite 内存引擎。
"""
engine = create_engine(
"sqlite://",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
SQLModel.metadata.create_all(engine)
yield engine
@pytest.mark.asyncio
async def test_select_expressions_uses_read_only_session(
monkeypatch: pytest.MonkeyPatch,
expression_auto_check_engine,
) -> None:
"""选择表达方式时应使用只读会话,并在离开会话后安全读取 ORM 字段。"""
import src.bw_learner.expression_auto_check_task as expression_auto_check_task_module
with Session(expression_auto_check_engine) as session:
session.add(
Expression(
situation="表达情绪高涨或生理反应",
style="发送💦表情符号",
content_list='["表达情绪高涨或生理反应"]',
count=1,
session_id="session-a",
checked=False,
rejected=False,
)
)
session.commit()
auto_commit_calls: list[bool] = []
@contextmanager
def fake_get_db_session(auto_commit: bool = True) -> Generator[Session, None, None]:
"""构造带自动提交语义的测试会话工厂。
Args:
auto_commit: 退出上下文时是否自动提交。
Yields:
Generator[Session, None, None]: SQLModel 会话对象。
"""
auto_commit_calls.append(auto_commit)
session = Session(expression_auto_check_engine)
try:
yield session
if auto_commit:
session.commit()
except Exception:
session.rollback()
raise
finally:
session.close()
monkeypatch.setattr(expression_auto_check_task_module, "get_db_session", fake_get_db_session)
monkeypatch.setattr(expression_auto_check_task_module.random, "sample", lambda entries, _count: list(entries))
task = ExpressionAutoCheckTask()
expressions = await task._select_expressions(1)
assert auto_commit_calls == [False]
assert len(expressions) == 1
assert expressions[0].id is not None
assert expressions[0].situation == "表达情绪高涨或生理反应"
assert expressions[0].style == "发送💦表情符号"

View File

@@ -0,0 +1,81 @@
"""测试表达方式学习器的数据库读取行为。"""
from contextlib import contextmanager
from typing import Generator
import pytest
from sqlalchemy.pool import StaticPool
from sqlmodel import Session, SQLModel, create_engine
from src.bw_learner.expression_learner import ExpressionLearner
from src.common.database.database_model import Expression
@pytest.fixture(name="expression_learner_engine")
def expression_learner_engine_fixture() -> Generator:
"""创建用于表达方式学习器测试的内存数据库引擎。
Yields:
Generator: 供测试使用的 SQLite 内存引擎。
"""
engine = create_engine(
"sqlite://",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
SQLModel.metadata.create_all(engine)
yield engine
def test_find_similar_expression_uses_read_only_session_and_history_content(
monkeypatch: pytest.MonkeyPatch,
expression_learner_engine,
) -> None:
"""查找相似表达方式时,应能在离开会话后安全使用结果,并比较历史情景内容。"""
import src.bw_learner.expression_learner as expression_learner_module
with Session(expression_learner_engine) as session:
session.add(
Expression(
situation="发送汗滴表情",
style="发送💦表情符号",
content_list='["表达情绪高涨或生理反应"]',
count=1,
session_id="session-a",
checked=False,
rejected=False,
)
)
session.commit()
@contextmanager
def fake_get_db_session(auto_commit: bool = True) -> Generator[Session, None, None]:
"""构造带自动提交语义的测试会话工厂。
Args:
auto_commit: 退出上下文时是否自动提交。
Yields:
Generator[Session, None, None]: SQLModel 会话对象。
"""
session = Session(expression_learner_engine)
try:
yield session
if auto_commit:
session.commit()
except Exception:
session.rollback()
raise
finally:
session.close()
monkeypatch.setattr(expression_learner_module, "get_db_session", fake_get_db_session)
learner = ExpressionLearner(session_id="session-a")
result = learner._find_similar_expression("表达情绪高涨或生理反应")
assert result is not None
expression, similarity = result
assert expression.item_id is not None
assert expression.style == "发送💦表情符号"
assert similarity == pytest.approx(1.0)

View File

@@ -0,0 +1,78 @@
"""测试表达方式表结构和基础插入行为。"""
from typing import Generator
import pytest
from sqlalchemy.pool import StaticPool
from sqlmodel import Session, SQLModel, create_engine
from src.common.database.database_model import Expression
@pytest.fixture(name="expression_engine")
def expression_engine_fixture() -> Generator:
"""创建仅用于表达方式表测试的内存数据库引擎。
Yields:
Generator: 供测试使用的 SQLite 内存引擎。
"""
engine = create_engine(
"sqlite://",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
SQLModel.metadata.create_all(engine)
yield engine
def test_expression_insert_assigns_auto_increment_id(expression_engine) -> None:
"""表达方式表在新库中应能自动分配自增主键。"""
with Session(expression_engine) as session:
expression = Expression(
situation="表达情绪高涨或生理反应",
style="发送💦表情符号",
content_list='["表达情绪高涨或生理反应"]',
count=1,
session_id="session-a",
checked=False,
rejected=False,
)
session.add(expression)
session.commit()
session.refresh(expression)
assert expression.id is not None
assert expression.id > 0
def test_expression_insert_allows_same_situation_style(expression_engine) -> None:
"""相同情景和风格的表达方式记录不应再被错误绑定到复合主键。"""
with Session(expression_engine) as session:
first_expression = Expression(
situation="对重复行为的默契响应",
style="持续性跟发相同内容",
content_list='["对重复行为的默契响应"]',
count=1,
session_id="session-a",
checked=False,
rejected=False,
)
second_expression = Expression(
situation="对重复行为的默契响应",
style="持续性跟发相同内容",
content_list='["对重复行为的默契响应-变体"]',
count=2,
session_id="session-b",
checked=False,
rejected=False,
)
session.add(first_expression)
session.add(second_expression)
session.commit()
session.refresh(first_expression)
session.refresh(second_expression)
assert first_expression.id is not None
assert second_expression.id is not None
assert first_expression.id != second_expression.id

View File

@@ -0,0 +1,90 @@
"""测试黑话学习器的数据库读取行为。"""
from contextlib import contextmanager
from typing import Generator
import pytest
from sqlalchemy.pool import StaticPool
from sqlmodel import Session, SQLModel, create_engine, select
from src.bw_learner.jargon_miner import JargonMiner
from src.common.database.database_model import Jargon
@pytest.fixture(name="jargon_miner_engine")
def jargon_miner_engine_fixture() -> Generator:
"""创建用于黑话学习器测试的内存数据库引擎。
Yields:
Generator: 供测试使用的 SQLite 内存引擎。
"""
engine = create_engine(
"sqlite://",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
SQLModel.metadata.create_all(engine)
yield engine
@pytest.mark.asyncio
async def test_process_extracted_entries_updates_existing_jargon_without_detached_session(
monkeypatch: pytest.MonkeyPatch,
jargon_miner_engine,
) -> None:
"""更新已有黑话时,不应因会话关闭导致 ORM 实例失效。"""
import src.bw_learner.jargon_miner as jargon_miner_module
with Session(jargon_miner_engine) as session:
session.add(
Jargon(
content="VF8V4L",
raw_content='["[1] first"]',
meaning="",
session_id_dict='{"session-a": 1}',
count=0,
is_jargon=True,
is_complete=False,
is_global=False,
last_inference_count=0,
)
)
session.commit()
@contextmanager
def fake_get_db_session(auto_commit: bool = True) -> Generator[Session, None, None]:
"""构造带自动提交语义的测试会话工厂。
Args:
auto_commit: 退出上下文时是否自动提交。
Yields:
Generator[Session, None, None]: SQLModel 会话对象。
"""
session = Session(jargon_miner_engine)
try:
yield session
if auto_commit:
session.commit()
except Exception:
session.rollback()
raise
finally:
session.close()
monkeypatch.setattr(jargon_miner_module, "get_db_session", fake_get_db_session)
jargon_miner = JargonMiner(session_id="session-a", session_name="测试群")
await jargon_miner.process_extracted_entries(
[{"content": "VF8V4L", "raw_content": {"[2] second"}}],
)
with Session(jargon_miner_engine) as session:
db_jargon = session.exec(select(Jargon).where(Jargon.content == "VF8V4L")).one()
assert db_jargon.count == 1
assert db_jargon.session_id_dict == '{"session-a": 2}'
assert sorted(db_jargon.raw_content and __import__("json").loads(db_jargon.raw_content)) == [
"[1] first",
"[2] second",
]

View File

@@ -0,0 +1,84 @@
"""测试黑话表结构和基础插入行为。"""
from typing import Generator
import pytest
from sqlalchemy.pool import StaticPool
from sqlmodel import Session, SQLModel, create_engine
from src.common.database.database_model import Jargon
@pytest.fixture(name="jargon_engine")
def jargon_engine_fixture() -> Generator:
"""创建仅用于黑话表测试的内存数据库引擎。
Yields:
Generator: 供测试使用的 SQLite 内存引擎。
"""
engine = create_engine(
"sqlite://",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
SQLModel.metadata.create_all(engine)
yield engine
def test_jargon_insert_assigns_auto_increment_id(jargon_engine) -> None:
"""黑话表在新库中应能自动分配自增主键。"""
with Session(jargon_engine) as session:
jargon = Jargon(
content="VF8V4L",
raw_content='["[1] test"]',
meaning="",
session_id_dict='{"session-a": 1}',
count=1,
is_jargon=True,
is_complete=False,
is_global=True,
last_inference_count=0,
)
session.add(jargon)
session.commit()
session.refresh(jargon)
assert jargon.id is not None
assert jargon.id > 0
def test_jargon_insert_allows_same_content_with_different_rows(jargon_engine) -> None:
"""黑话内容不应再被错误地绑成复合主键的一部分。"""
with Session(jargon_engine) as session:
first_jargon = Jargon(
content="表情1",
raw_content='["[1] first"]',
meaning="",
session_id_dict='{"session-a": 1}',
count=1,
is_jargon=True,
is_complete=False,
is_global=False,
last_inference_count=0,
)
second_jargon = Jargon(
content="表情1",
raw_content='["[1] second"]',
meaning="",
session_id_dict='{"session-b": 1}',
count=1,
is_jargon=True,
is_complete=False,
is_global=False,
last_inference_count=0,
)
session.add(first_jargon)
session.add(second_jargon)
session.commit()
session.refresh(first_jargon)
session.refresh(second_jargon)
assert first_jargon.id is not None
assert second_jargon.id is not None
assert first_jargon.id != second_jargon.id

View File

@@ -0,0 +1,355 @@
"""人物信息群名片字段兼容测试。"""
from __future__ import annotations
from importlib.util import module_from_spec, spec_from_file_location
from pathlib import Path
from types import ModuleType, SimpleNamespace
from typing import Any
import json
import sys
import pytest
from src.common.data_models.person_info_data_model import dump_group_cardname_records, parse_group_cardname_json
class _DummyLogger:
"""模拟日志记录器。"""
def debug(self, message: str) -> None:
"""记录调试日志。
Args:
message: 日志内容。
"""
del message
def info(self, message: str) -> None:
"""记录信息日志。
Args:
message: 日志内容。
"""
del message
def warning(self, message: str) -> None:
"""记录警告日志。
Args:
message: 日志内容。
"""
del message
def error(self, message: str) -> None:
"""记录错误日志。
Args:
message: 日志内容。
"""
del message
class _DummyStatement:
"""模拟 SQL 查询语句对象。"""
def where(self, condition: Any) -> "_DummyStatement":
"""附加过滤条件。
Args:
condition: 过滤条件。
Returns:
_DummyStatement: 当前语句对象。
"""
del condition
return self
def limit(self, value: int) -> "_DummyStatement":
"""限制返回条数。
Args:
value: 条数限制。
Returns:
_DummyStatement: 当前语句对象。
"""
del value
return self
class _DummyColumn:
"""模拟 SQLModel 列对象。"""
def is_not(self, value: Any) -> "_DummyColumn":
"""模拟 `IS NOT` 条件构造。
Args:
value: 比较值。
Returns:
_DummyColumn: 当前列对象。
"""
del value
return self
def __eq__(self, other: Any) -> "_DummyColumn":
"""模拟等值条件构造。
Args:
other: 比较值。
Returns:
_DummyColumn: 当前列对象。
"""
del other
return self
class _DummyResult:
"""模拟数据库查询结果。"""
def __init__(self, record: Any) -> None:
"""初始化查询结果。
Args:
record: 待返回的首条记录。
"""
self._record = record
def first(self) -> Any:
"""返回第一条记录。
Returns:
Any: 首条记录。
"""
return self._record
def all(self) -> list[Any]:
"""返回全部结果。
Returns:
list[Any]: 结果列表。
"""
if self._record is None:
return []
return self._record if isinstance(self._record, list) else [self._record]
class _DummySession:
"""模拟数据库 Session。"""
def __init__(self, record: Any) -> None:
"""初始化 Session。
Args:
record: `first()` 应返回的记录。
"""
self.record = record
self.added_records: list[Any] = []
def __enter__(self) -> "_DummySession":
"""进入上下文管理器。
Returns:
_DummySession: 当前 Session。
"""
return self
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
"""退出上下文管理器。
Args:
exc_type: 异常类型。
exc_val: 异常值。
exc_tb: 异常回溯。
"""
del exc_type
del exc_val
del exc_tb
def exec(self, statement: Any) -> _DummyResult:
"""执行查询。
Args:
statement: 查询语句。
Returns:
_DummyResult: 模拟结果对象。
"""
del statement
return _DummyResult(self.record)
def add(self, record: Any) -> None:
"""记录被添加的对象。
Args:
record: 被写入 Session 的对象。
"""
self.added_records.append(record)
class _DummyPersonInfoRecord:
"""模拟 `PersonInfo` ORM 模型。"""
person_id = "person_id"
person_name = "person_name"
def __init__(self, **kwargs: Any) -> None:
"""使用关键字参数初始化记录对象。
Args:
**kwargs: 字段值。
"""
for key, value in kwargs.items():
setattr(self, key, value)
def _load_person_module(monkeypatch: pytest.MonkeyPatch, session: _DummySession) -> ModuleType:
"""加载带依赖桩的 `person_info` 模块。
Args:
monkeypatch: Pytest monkeypatch 工具。
session: 提供给模块使用的假数据库 Session。
Returns:
ModuleType: 加载后的模块对象。
"""
logger_module = ModuleType("src.common.logger")
logger_module.get_logger = lambda name: _DummyLogger()
monkeypatch.setitem(sys.modules, "src.common.logger", logger_module)
database_module = ModuleType("src.common.database.database")
database_module.get_db_session = lambda: session
monkeypatch.setitem(sys.modules, "src.common.database.database", database_module)
database_model_module = ModuleType("src.common.database.database_model")
database_model_module.PersonInfo = _DummyPersonInfoRecord
monkeypatch.setitem(sys.modules, "src.common.database.database_model", database_model_module)
llm_module = ModuleType("src.llm_models.utils_model")
class _DummyLLMRequest:
"""模拟 LLMRequest。"""
def __init__(self, model_set: Any, request_type: str) -> None:
"""初始化假请求对象。
Args:
model_set: 模型配置。
request_type: 请求类型。
"""
del model_set
del request_type
llm_module.LLMRequest = _DummyLLMRequest
monkeypatch.setitem(sys.modules, "src.llm_models.utils_model", llm_module)
config_module = ModuleType("src.config.config")
config_module.global_config = SimpleNamespace(bot=SimpleNamespace(nickname="MaiBot"))
config_module.model_config = SimpleNamespace(model_task_config=SimpleNamespace(tool_use="tool_use", utils="utils"))
monkeypatch.setitem(sys.modules, "src.config.config", config_module)
chat_manager_module = ModuleType("src.chat.message_receive.chat_manager")
chat_manager_module.chat_manager = SimpleNamespace()
monkeypatch.setitem(sys.modules, "src.chat.message_receive.chat_manager", chat_manager_module)
module_path = Path(__file__).resolve().parents[2] / "src" / "person_info" / "person_info.py"
spec = spec_from_file_location("person_info_group_cardname_test_module", module_path)
assert spec is not None and spec.loader is not None
module = module_from_spec(spec)
monkeypatch.setitem(sys.modules, spec.name, module)
spec.loader.exec_module(module)
monkeypatch.setattr(module, "select", lambda *args: _DummyStatement())
monkeypatch.setattr(module, "col", lambda field: _DummyColumn())
return module
def test_parse_group_cardname_json_uses_canonical_key() -> None:
"""群名片 JSON 解析应只使用 `group_cardname` 键名。"""
parsed = parse_group_cardname_json(
json.dumps(
[
{"group_id": "1001", "group_cardname": "现行字段"},
],
ensure_ascii=False,
)
)
assert parsed is not None
assert [(item.group_id, item.group_cardname) for item in parsed] == [
("1001", "现行字段"),
]
def test_dump_group_cardname_records_uses_canonical_key() -> None:
"""群名片序列化应输出 `group_cardname` 键名。"""
dumped = dump_group_cardname_records(
[
{"group_id": "1001", "group_cardname": "群昵称"},
]
)
assert json.loads(dumped) == [{"group_id": "1001", "group_cardname": "群昵称"}]
def test_person_sync_to_database_uses_group_cardname_field(monkeypatch: pytest.MonkeyPatch) -> None:
"""同步人物信息时应写入数据库模型的 `group_cardname` 字段。"""
record = _DummyPersonInfoRecord()
session = _DummySession(record)
module = _load_person_module(monkeypatch, session)
person = module.Person.__new__(module.Person)
person.is_known = True
person.person_id = "person-1"
person.platform = "qq"
person.user_id = "10001"
person.nickname = "看番的龙"
person.person_name = "看番的龙"
person.name_reason = "测试"
person.know_times = 1
person.know_since = 1700000000.0
person.last_know = 1700000100.0
person.memory_points = ["喜好:番剧:0.8"]
person.group_cardname_list = [{"group_id": "20001", "group_cardname": "白泽大人"}]
person.sync_to_database()
assert record.group_cardname == '[{"group_id": "20001", "group_cardname": "白泽大人"}]'
assert not hasattr(record, "group_nickname")
def test_person_load_from_database_normalizes_group_cardname_payload(monkeypatch: pytest.MonkeyPatch) -> None:
"""从数据库加载人物信息时应读取标准 `group_cardname` 结构。"""
record = _DummyPersonInfoRecord(
user_id="10001",
platform="qq",
is_known=True,
user_nickname="看番的龙",
person_name="看番的龙",
name_reason=None,
know_counts=2,
memory_points='["喜好:番剧:0.8"]',
group_cardname=json.dumps(
[
{"group_id": "20001", "group_cardname": "白泽大人"},
],
ensure_ascii=False,
),
)
session = _DummySession(record)
module = _load_person_module(monkeypatch, session)
person = module.Person.__new__(module.Person)
person.person_id = "person-1"
person.memory_points = []
person.group_cardname_list = []
person.load_from_database()
assert person.group_cardname_list == [
{"group_id": "20001", "group_cardname": "白泽大人"},
]

View File

@@ -0,0 +1,170 @@
"""消息网关运行时状态同步测试。"""
from typing import Any, Dict
import pytest
from src.platform_io.manager import PlatformIOManager
from src.platform_io.types import RouteKey
from src.plugin_runtime.host.supervisor import PluginSupervisor
from src.plugin_runtime.protocol.envelope import Envelope, MessageType
def _make_request(method: str, plugin_id: str, payload: Dict[str, Any]) -> Envelope:
"""构造一个 RPC 请求信封。
Args:
method: RPC 方法名。
plugin_id: 目标插件 ID。
payload: 请求载荷。
Returns:
Envelope: 标准 RPC 请求信封。
"""
return Envelope(
request_id=1,
message_type=MessageType.REQUEST,
method=method,
plugin_id=plugin_id,
payload=payload,
)
@pytest.mark.asyncio
async def test_message_gateway_runtime_state_binds_send_and_receive_routes(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""消息网关就绪后应同时绑定发送表和接收表。"""
import src.plugin_runtime.host.supervisor as supervisor_module
platform_io_manager = PlatformIOManager()
monkeypatch.setattr(supervisor_module, "get_platform_io_manager", lambda: platform_io_manager)
supervisor = PluginSupervisor(plugin_dirs=[])
register_response = await supervisor._handle_register_plugin(
_make_request(
"plugin.register_components",
"napcat_plugin",
{
"plugin_id": "napcat_plugin",
"plugin_version": "1.0.0",
"components": [
{
"name": "napcat_gateway",
"component_type": "MESSAGE_GATEWAY",
"plugin_id": "napcat_plugin",
"metadata": {
"route_type": "duplex",
"platform": "qq",
"protocol": "napcat",
},
}
],
"capabilities_required": [],
},
)
)
assert register_response.error is None
response = await supervisor._handle_update_message_gateway_state(
_make_request(
"host.update_message_gateway_state",
"napcat_plugin",
{
"gateway_name": "napcat_gateway",
"ready": True,
"platform": "qq",
"account_id": "10001",
"scope": "primary",
"metadata": {},
},
)
)
assert response.error is None
assert response.payload["accepted"] is True
send_bindings = platform_io_manager.send_route_table.resolve_bindings(
RouteKey(platform="qq", account_id="10001", scope="primary")
)
receive_bindings = platform_io_manager.receive_route_table.resolve_bindings(
RouteKey(platform="qq", account_id="10001", scope="primary")
)
assert [binding.driver_id for binding in send_bindings] == ["gateway:napcat_plugin:napcat_gateway"]
assert [binding.driver_id for binding in receive_bindings] == ["gateway:napcat_plugin:napcat_gateway"]
@pytest.mark.asyncio
async def test_message_gateway_runtime_state_unbinds_routes_when_not_ready(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""消息网关断开后应撤销发送表和接收表中的绑定。"""
import src.plugin_runtime.host.supervisor as supervisor_module
platform_io_manager = PlatformIOManager()
monkeypatch.setattr(supervisor_module, "get_platform_io_manager", lambda: platform_io_manager)
supervisor = PluginSupervisor(plugin_dirs=[])
await supervisor._handle_register_plugin(
_make_request(
"plugin.register_components",
"napcat_plugin",
{
"plugin_id": "napcat_plugin",
"plugin_version": "1.0.0",
"components": [
{
"name": "napcat_gateway",
"component_type": "MESSAGE_GATEWAY",
"plugin_id": "napcat_plugin",
"metadata": {
"route_type": "duplex",
"platform": "qq",
"protocol": "napcat",
},
}
],
"capabilities_required": [],
},
)
)
await supervisor._handle_update_message_gateway_state(
_make_request(
"host.update_message_gateway_state",
"napcat_plugin",
{
"gateway_name": "napcat_gateway",
"ready": True,
"platform": "qq",
"account_id": "10001",
"scope": "primary",
"metadata": {},
},
)
)
response = await supervisor._handle_update_message_gateway_state(
_make_request(
"host.update_message_gateway_state",
"napcat_plugin",
{
"gateway_name": "napcat_gateway",
"ready": False,
"platform": "qq",
"account_id": "",
"scope": "",
"metadata": {},
},
)
)
assert response.error is None
assert response.payload["accepted"] is True
assert platform_io_manager.send_route_table.resolve_bindings(RouteKey(platform="qq", account_id="10001")) == []
assert (
platform_io_manager.receive_route_table.resolve_bindings(RouteKey(platform="qq", account_id="10001")) == []
)

View File

@@ -0,0 +1,132 @@
"""NapCat 插件与新 SDK 对接测试。"""
from pathlib import Path
from typing import Any, Dict, List
import importlib
import logging
import sys
import pytest
PROJECT_ROOT = Path(__file__).resolve().parents[1]
PLUGINS_ROOT = PROJECT_ROOT / "plugins"
SDK_ROOT = PROJECT_ROOT / "packages" / "maibot-plugin-sdk"
for import_path in (str(PLUGINS_ROOT), str(SDK_ROOT)):
if import_path not in sys.path:
sys.path.insert(0, import_path)
class _FakeGatewayCapability:
"""用于捕获消息网关状态上报的测试替身。"""
def __init__(self) -> None:
"""初始化测试替身。"""
self.calls: List[Dict[str, Any]] = []
async def update_state(
self,
gateway_name: str,
*,
ready: bool,
platform: str = "",
account_id: str = "",
scope: str = "",
metadata: Dict[str, Any] | None = None,
) -> bool:
"""记录一次状态上报请求。
Args:
gateway_name: 网关组件名称。
ready: 当前是否就绪。
platform: 平台名称。
account_id: 账号 ID。
scope: 路由作用域。
metadata: 附加元数据。
Returns:
bool: 始终返回 ``True``,模拟 Host 接受状态更新。
"""
self.calls.append(
{
"gateway_name": gateway_name,
"ready": ready,
"platform": platform,
"account_id": account_id,
"scope": scope,
"metadata": metadata or {},
}
)
return True
def _load_napcat_sdk_symbols() -> tuple[Any, Any, Any, Any]:
"""动态加载 NapCat 插件测试所需的符号。
Returns:
tuple[Any, Any, Any, Any]:
依次返回网关名常量、配置类、插件类和运行时状态管理器类。
"""
constants_module = importlib.import_module("napcat_adapter.constants")
config_module = importlib.import_module("napcat_adapter.config")
plugin_module = importlib.import_module("napcat_adapter.plugin")
runtime_state_module = importlib.import_module("napcat_adapter.runtime_state")
return (
constants_module.NAPCAT_GATEWAY_NAME,
config_module.NapCatServerConfig,
plugin_module.NapCatAdapterPlugin,
runtime_state_module.NapCatRuntimeStateManager,
)
def test_napcat_plugin_collects_duplex_message_gateway() -> None:
"""NapCat 插件应声明新的双工消息网关组件。"""
napcat_gateway_name, _napcat_server_config, napcat_plugin_cls, _runtime_state_cls = _load_napcat_sdk_symbols()
plugin = napcat_plugin_cls()
components = plugin.get_components()
gateway_components = [
component
for component in components
if component.get("type") == "MESSAGE_GATEWAY"
]
assert len(gateway_components) == 1
gateway_component = gateway_components[0]
assert gateway_component["name"] == napcat_gateway_name
assert gateway_component["metadata"]["route_type"] == "duplex"
assert gateway_component["metadata"]["platform"] == "qq"
assert gateway_component["metadata"]["protocol"] == "napcat"
@pytest.mark.asyncio
async def test_runtime_state_reports_via_gateway_capability() -> None:
"""NapCat 运行时状态应通过新的消息网关能力上报。"""
napcat_gateway_name, napcat_server_config_cls, _napcat_plugin_cls, runtime_state_cls = _load_napcat_sdk_symbols()
gateway_capability = _FakeGatewayCapability()
runtime_state_manager = runtime_state_cls(
gateway_capability=gateway_capability,
logger=logging.getLogger("test.napcat_adapter"),
gateway_name=napcat_gateway_name,
)
connected = await runtime_state_manager.report_connected(
"10001",
napcat_server_config_cls(connection_id="primary"),
)
await runtime_state_manager.report_disconnected()
assert connected is True
assert gateway_capability.calls[0]["gateway_name"] == napcat_gateway_name
assert gateway_capability.calls[0]["ready"] is True
assert gateway_capability.calls[0]["platform"] == "qq"
assert gateway_capability.calls[0]["account_id"] == "10001"
assert gateway_capability.calls[0]["scope"] == "primary"
assert gateway_capability.calls[1]["gateway_name"] == napcat_gateway_name
assert gateway_capability.calls[1]["ready"] is False
assert gateway_capability.calls[1]["platform"] == "qq"

View File

@@ -0,0 +1,209 @@
"""Platform IO 入站去重策略测试。"""
from types import SimpleNamespace
from typing import Any, Dict, List, Optional
import pytest
from src.platform_io.drivers.base import PlatformIODriver
from src.platform_io.manager import PlatformIOManager
from src.platform_io.types import DeliveryReceipt, DeliveryStatus, DriverDescriptor, DriverKind, InboundMessageEnvelope, RouteBinding, RouteKey
def _build_envelope(
*,
dedupe_key: str | None = None,
external_message_id: str | None = None,
session_message_id: str | None = None,
payload: Optional[Dict[str, Any]] = None,
) -> InboundMessageEnvelope:
"""构造测试用入站信封。
Args:
dedupe_key: 显式去重键。
external_message_id: 平台侧消息 ID。
session_message_id: 规范化消息对象上的消息 ID。
payload: 原始载荷。
Returns:
InboundMessageEnvelope: 测试用入站消息信封。
"""
session_message = None
if session_message_id is not None:
session_message = SimpleNamespace(message_id=session_message_id)
return InboundMessageEnvelope(
route_key=RouteKey(platform="qq", account_id="10001", scope="main"),
driver_id="plugin.napcat",
driver_kind=DriverKind.PLUGIN,
dedupe_key=dedupe_key,
external_message_id=external_message_id,
session_message=session_message,
payload=payload,
)
class _StubPlatformIODriver(PlatformIODriver):
"""测试用 Platform IO 驱动。"""
async def send_message(
self,
message: Any,
route_key: RouteKey,
metadata: Optional[Dict[str, Any]] = None,
) -> DeliveryReceipt:
"""返回一个固定的成功回执。
Args:
message: 待发送的消息对象。
route_key: 本次发送使用的路由键。
metadata: 额外发送元数据。
Returns:
DeliveryReceipt: 固定的成功回执。
"""
return DeliveryReceipt(
internal_message_id=str(getattr(message, "message_id", "stub-message-id")),
route_key=route_key,
status=DeliveryStatus.SENT,
driver_id=self.driver_id,
driver_kind=self.descriptor.kind,
)
def _build_manager() -> PlatformIOManager:
"""构造带有最小接收路由的 Broker 管理器。
Returns:
PlatformIOManager: 已注册测试驱动并绑定接收路由的 Broker。
"""
manager = PlatformIOManager()
driver = _StubPlatformIODriver(
DriverDescriptor(
driver_id="plugin.napcat",
kind=DriverKind.PLUGIN,
platform="qq",
account_id="10001",
scope="main",
)
)
manager.register_driver(driver)
manager.bind_receive_route(
RouteBinding(
route_key=RouteKey(platform="qq", account_id="10001", scope="main"),
driver_id=driver.driver_id,
driver_kind=driver.descriptor.kind,
)
)
return manager
class TestPlatformIODedupe:
"""Platform IO 去重测试。"""
@pytest.mark.asyncio
async def test_accept_inbound_dedupes_by_external_message_id(self) -> None:
"""相同平台消息 ID 的重复入站应被抑制。"""
manager = _build_manager()
accepted_envelopes: List[InboundMessageEnvelope] = []
async def dispatcher(envelope: InboundMessageEnvelope) -> None:
"""记录被成功接收的入站消息。
Args:
envelope: 被 Broker 接受的入站消息。
"""
accepted_envelopes.append(envelope)
manager.set_inbound_dispatcher(dispatcher)
first_envelope = _build_envelope(
external_message_id="msg-1",
payload={"message": "hello"},
)
second_envelope = _build_envelope(
external_message_id="msg-1",
payload={"message": "hello"},
)
assert await manager.accept_inbound(first_envelope) is True
assert await manager.accept_inbound(second_envelope) is False
assert len(accepted_envelopes) == 1
@pytest.mark.asyncio
async def test_accept_inbound_without_stable_identity_does_not_guess_duplicate(self) -> None:
"""缺少稳定身份时,不应仅凭 payload 内容猜测重复消息。"""
manager = _build_manager()
accepted_envelopes: List[InboundMessageEnvelope] = []
async def dispatcher(envelope: InboundMessageEnvelope) -> None:
"""记录被成功接收的入站消息。
Args:
envelope: 被 Broker 接受的入站消息。
"""
accepted_envelopes.append(envelope)
manager.set_inbound_dispatcher(dispatcher)
first_envelope = _build_envelope(payload={"message": "same-payload"})
second_envelope = _build_envelope(payload={"message": "same-payload"})
assert await manager.accept_inbound(first_envelope) is True
assert await manager.accept_inbound(second_envelope) is True
assert len(accepted_envelopes) == 2
def test_build_inbound_dedupe_key_prefers_explicit_identity(self) -> None:
"""去重键应只来自显式或稳定的技术身份。"""
explicit_envelope = _build_envelope(dedupe_key="dedupe-1", external_message_id="msg-1")
session_message_envelope = _build_envelope(session_message_id="session-1")
payload_only_envelope = _build_envelope(payload={"message": "hello"})
assert PlatformIOManager._build_inbound_dedupe_key(explicit_envelope) == "plugin.napcat:dedupe-1"
assert PlatformIOManager._build_inbound_dedupe_key(session_message_envelope) == "plugin.napcat:session-1"
assert PlatformIOManager._build_inbound_dedupe_key(payload_only_envelope) is None
@pytest.mark.asyncio
async def test_send_message_fans_out_to_all_matching_routes(self) -> None:
"""同一路由命中多条发送链路时应全部发送。"""
manager = PlatformIOManager()
first_driver = _StubPlatformIODriver(
DriverDescriptor(
driver_id="plugin.gateway_a",
kind=DriverKind.PLUGIN,
platform="qq",
)
)
second_driver = _StubPlatformIODriver(
DriverDescriptor(
driver_id="plugin.gateway_b",
kind=DriverKind.PLUGIN,
platform="qq",
)
)
manager.register_driver(first_driver)
manager.register_driver(second_driver)
manager.bind_send_route(
RouteBinding(
route_key=RouteKey(platform="qq"),
driver_id=first_driver.driver_id,
driver_kind=first_driver.descriptor.kind,
)
)
manager.bind_send_route(
RouteBinding(
route_key=RouteKey(platform="qq"),
driver_id=second_driver.driver_id,
driver_kind=second_driver.descriptor.kind,
)
)
message = SimpleNamespace(message_id="internal-msg-1")
result = await manager.send_message(message, RouteKey(platform="qq"))
assert result.has_success is True
assert [receipt.driver_id for receipt in result.sent_receipts] == [
"plugin.gateway_a",
"plugin.gateway_b",
]

View File

@@ -0,0 +1,178 @@
"""Platform IO legacy driver 回归测试。"""
from typing import Any, Dict, Optional
import pytest
from src.chat.utils import utils as chat_utils
from src.chat.message_receive import uni_message_sender
from src.platform_io.drivers.base import PlatformIODriver
from src.platform_io.drivers.legacy_driver import LegacyPlatformDriver
from src.platform_io.manager import PlatformIOManager
from src.platform_io.types import DeliveryReceipt, DeliveryStatus, DriverDescriptor, DriverKind, RouteBinding, RouteKey
class _PluginDriver(PlatformIODriver):
"""测试用插件发送驱动。"""
def __init__(self, driver_id: str, platform: str) -> None:
"""初始化测试驱动。
Args:
driver_id: 驱动 ID。
platform: 负责的平台名称。
"""
super().__init__(
DriverDescriptor(
driver_id=driver_id,
kind=DriverKind.PLUGIN,
platform=platform,
plugin_id="test.plugin",
)
)
async def send_message(
self,
message: Any,
route_key: RouteKey,
metadata: Optional[Dict[str, Any]] = None,
) -> DeliveryReceipt:
"""返回一个固定成功回执。
Args:
message: 待发送消息。
route_key: 当前路由键。
metadata: 发送元数据。
Returns:
DeliveryReceipt: 固定成功回执。
"""
del metadata
return DeliveryReceipt(
internal_message_id=str(message.message_id),
route_key=route_key,
status=DeliveryStatus.SENT,
driver_id=self.driver_id,
driver_kind=self.descriptor.kind,
)
@pytest.mark.asyncio
async def test_platform_io_uses_legacy_driver_when_no_explicit_send_route(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""没有显式发送路由时,应由 Platform IO 回退到 legacy driver。"""
manager = PlatformIOManager()
monkeypatch.setattr(chat_utils, "get_all_bot_accounts", lambda: {"qq": "bot-qq"})
try:
await manager.ensure_send_pipeline_ready()
fallback_drivers = manager.resolve_drivers(RouteKey(platform="qq"))
assert [driver.driver_id for driver in fallback_drivers] == ["legacy.send.qq"]
plugin_driver = _PluginDriver(driver_id="plugin.qq.sender", platform="qq")
await manager.add_driver(plugin_driver)
manager.bind_send_route(
RouteBinding(
route_key=RouteKey(platform="qq"),
driver_id=plugin_driver.driver_id,
driver_kind=plugin_driver.descriptor.kind,
)
)
explicit_drivers = manager.resolve_drivers(RouteKey(platform="qq"))
assert [driver.driver_id for driver in explicit_drivers] == ["plugin.qq.sender", "legacy.send.qq"]
finally:
await manager.stop()
@pytest.mark.asyncio
async def test_platform_io_broadcasts_to_plugin_and_legacy_driver(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""同一路由命中插件驱动与 legacy driver 时,应同时广播发送。"""
manager = PlatformIOManager()
legacy_calls: list[dict[str, Any]] = []
monkeypatch.setattr(chat_utils, "get_all_bot_accounts", lambda: {"qq": "bot-qq"})
async def _fake_send_prepared_message_to_platform(message: Any, show_log: bool = True) -> bool:
"""记录 legacy driver 调用。"""
legacy_calls.append({"message": message, "show_log": show_log})
return True
monkeypatch.setattr(
uni_message_sender,
"send_prepared_message_to_platform",
_fake_send_prepared_message_to_platform,
)
try:
await manager.ensure_send_pipeline_ready()
plugin_driver = _PluginDriver(driver_id="plugin.qq.sender", platform="qq")
await manager.add_driver(plugin_driver)
manager.bind_send_route(
RouteBinding(
route_key=RouteKey(platform="qq"),
driver_id=plugin_driver.driver_id,
driver_kind=plugin_driver.descriptor.kind,
)
)
message = type("FakeMessage", (), {"message_id": "message-1"})()
batch = await manager.send_message(
message=message,
route_key=RouteKey(platform="qq"),
metadata={"show_log": False},
)
assert sorted(receipt.driver_id for receipt in batch.sent_receipts) == [
"legacy.send.qq",
"plugin.qq.sender",
]
assert batch.failed_receipts == []
assert len(legacy_calls) == 1
assert legacy_calls[0]["message"] is message
assert legacy_calls[0]["show_log"] is False
finally:
await manager.stop()
@pytest.mark.asyncio
async def test_legacy_platform_driver_uses_prepared_universal_sender(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""legacy driver 应复用已预处理消息的旧链发送函数。"""
calls: list[dict[str, Any]] = []
async def _fake_send_prepared_message_to_platform(message: Any, show_log: bool = True) -> bool:
"""记录 legacy driver 调用。"""
calls.append({"message": message, "show_log": show_log})
return True
monkeypatch.setattr(
uni_message_sender,
"send_prepared_message_to_platform",
_fake_send_prepared_message_to_platform,
)
driver = LegacyPlatformDriver(
driver_id="legacy.send.qq",
platform="qq",
account_id="bot-qq",
)
message = type("FakeMessage", (), {"message_id": "message-1"})()
receipt = await driver.send_message(
message=message,
route_key=RouteKey(platform="qq"),
metadata={"show_log": False},
)
assert len(calls) == 1
assert calls[0]["message"] is message
assert calls[0]["show_log"] is False
assert receipt.status == DeliveryStatus.SENT
assert receipt.driver_id == "legacy.send.qq"

View File

@@ -0,0 +1,87 @@
from datetime import datetime
from pathlib import Path
import sys
from src.chat.message_receive.message import SessionMessage
from src.common.data_models.mai_message_data_model import GroupInfo, MessageInfo, UserInfo
from src.common.data_models.message_component_data_model import (
ForwardComponent,
ForwardNodeComponent,
ImageComponent,
MessageSequence,
ReplyComponent,
TextComponent,
VoiceComponent,
)
from src.plugin_runtime.host.message_utils import PluginMessageUtils
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
def test_plugin_message_utils_preserves_binary_components_and_reply_metadata() -> None:
message = SessionMessage(message_id="msg-1", timestamp=datetime.now(), platform="qq")
message.message_info = MessageInfo(
user_info=UserInfo(user_id="10001", user_nickname="tester"),
group_info=GroupInfo(group_id="20001", group_name="group"),
additional_config={"self_id": "999"},
)
message.session_id = "qq:20001:10001"
message.processed_plain_text = "binary payload"
message.display_message = "binary payload"
message.raw_message = MessageSequence(
components=[
TextComponent("hello"),
ImageComponent(binary_hash="", binary_data=b"image-bytes", content=""),
VoiceComponent(binary_hash="", binary_data=b"voice-bytes", content=""),
ReplyComponent(
target_message_id="origin-1",
target_message_content="origin text",
target_message_sender_id="42",
target_message_sender_nickname="alice",
target_message_sender_cardname="Alice",
),
ForwardNodeComponent(
forward_components=[
ForwardComponent(
user_nickname="bob",
user_id="43",
user_cardname="Bob",
message_id="forward-1",
content=[
TextComponent("node-text"),
ImageComponent(binary_hash="", binary_data=b"node-image", content=""),
],
)
]
),
]
)
message_dict = PluginMessageUtils._session_message_to_dict(message)
rebuilt_message = PluginMessageUtils._build_session_message_from_dict(dict(message_dict))
image_component = rebuilt_message.raw_message.components[1]
voice_component = rebuilt_message.raw_message.components[2]
reply_component = rebuilt_message.raw_message.components[3]
forward_component = rebuilt_message.raw_message.components[4]
assert isinstance(image_component, ImageComponent)
assert image_component.binary_data == b"image-bytes"
assert isinstance(voice_component, VoiceComponent)
assert voice_component.binary_data == b"voice-bytes"
assert isinstance(reply_component, ReplyComponent)
assert reply_component.target_message_id == "origin-1"
assert reply_component.target_message_content == "origin text"
assert reply_component.target_message_sender_id == "42"
assert reply_component.target_message_sender_nickname == "alice"
assert reply_component.target_message_sender_cardname == "Alice"
assert isinstance(forward_component, ForwardNodeComponent)
assert isinstance(forward_component.forward_components[0].content[1], ImageComponent)
assert forward_component.forward_components[0].content[1].binary_data == b"node-image"

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,284 @@
"""核心组件查询层与插件运行时聚合测试。"""
from types import SimpleNamespace
from typing import Any
import pytest
import src.plugin_runtime.integration as integration_module
from src.core.types import ActionInfo, ToolInfo
from src.plugin_runtime.component_query import component_query_service
from src.plugin_runtime.host.supervisor import PluginSupervisor
class _FakeRuntimeManager:
"""测试用插件运行时管理器。"""
def __init__(self, supervisor: PluginSupervisor, plugin_id: str, plugin_config: dict[str, Any]) -> None:
"""初始化测试用运行时管理器。
Args:
supervisor: 持有测试组件的监督器。
plugin_id: 目标插件 ID。
plugin_config: 需要返回的插件配置。
"""
self.supervisors = [supervisor]
self._plugin_id = plugin_id
self._plugin_config = plugin_config
def _get_supervisor_for_plugin(self, plugin_id: str) -> PluginSupervisor | None:
"""按插件 ID 返回对应监督器。
Args:
plugin_id: 目标插件 ID。
Returns:
PluginSupervisor | None: 命中时返回监督器。
"""
return self.supervisors[0] if plugin_id == self._plugin_id else None
def _load_plugin_config_for_supervisor(self, supervisor: Any, plugin_id: str) -> dict[str, Any]:
"""返回测试配置。
Args:
supervisor: 监督器实例。
plugin_id: 目标插件 ID。
Returns:
dict[str, Any]: 测试配置内容。
"""
del supervisor
if plugin_id != self._plugin_id:
return {}
return dict(self._plugin_config)
def _install_runtime_manager(
monkeypatch: pytest.MonkeyPatch,
supervisor: PluginSupervisor,
plugin_id: str,
plugin_config: dict[str, Any] | None = None,
) -> None:
"""为测试安装假的运行时管理器。
Args:
monkeypatch: pytest monkeypatch 对象。
supervisor: 持有测试组件的监督器。
plugin_id: 测试插件 ID。
plugin_config: 可选的测试配置内容。
"""
fake_manager = _FakeRuntimeManager(supervisor, plugin_id, plugin_config or {"enabled": True})
monkeypatch.setattr(integration_module, "get_plugin_runtime_manager", lambda: fake_manager)
@pytest.mark.asyncio
async def test_core_component_registry_reads_runtime_action_and_executor(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""核心查询层应直接读取运行时 Action并返回 RPC 执行闭包。"""
plugin_id = "runtime_action_bridge_plugin"
action_name = "runtime_action_bridge_test"
supervisor = PluginSupervisor(plugin_dirs=[])
captured: dict[str, Any] = {}
supervisor.component_registry.register_component(
name=action_name,
component_type="ACTION",
plugin_id=plugin_id,
metadata={
"description": "发送一个测试回复",
"enabled": True,
"activation_type": "keyword",
"activation_probability": 0.25,
"activation_keywords": ["测试", "hello"],
"action_parameters": {"target": "目标对象"},
"action_require": ["需要发送回复时使用"],
"associated_types": ["text"],
"parallel_action": True,
},
)
_install_runtime_manager(monkeypatch, supervisor, plugin_id, {"enabled": True, "mode": "test"})
async def fake_invoke_plugin(
method: str,
plugin_id: str,
component_name: str,
args: dict[str, Any] | None = None,
timeout_ms: int = 30000,
) -> Any:
"""模拟动作 RPC 调用。"""
captured["method"] = method
captured["plugin_id"] = plugin_id
captured["component_name"] = component_name
captured["args"] = args or {}
captured["timeout_ms"] = timeout_ms
return SimpleNamespace(payload={"success": True, "result": (True, "runtime action executed")})
monkeypatch.setattr(supervisor, "invoke_plugin", fake_invoke_plugin)
action_info = component_query_service.get_action_info(action_name)
assert isinstance(action_info, ActionInfo)
assert action_info.plugin_name == plugin_id
assert action_info.description == "发送一个测试回复"
assert action_info.activation_keywords == ["测试", "hello"]
assert action_info.random_activation_probability == 0.25
assert action_info.parallel_action is True
assert action_name in component_query_service.get_default_actions()
assert component_query_service.get_plugin_config(plugin_id) == {"enabled": True, "mode": "test"}
executor = component_query_service.get_action_executor(action_name)
assert executor is not None
success, reason = await executor(
action_data={"target": "MaiBot"},
action_reasoning="当前适合使用这个动作",
cycle_timers={"planner": 0.1},
thinking_id="tid-1",
chat_stream=SimpleNamespace(session_id="stream-1"),
log_prefix="[test]",
shutting_down=False,
plugin_config={"enabled": True},
)
assert success is True
assert reason == "runtime action executed"
assert captured["method"] == "plugin.invoke_action"
assert captured["plugin_id"] == plugin_id
assert captured["component_name"] == action_name
assert captured["args"]["stream_id"] == "stream-1"
assert captured["args"]["chat_id"] == "stream-1"
assert captured["args"]["reasoning"] == "当前适合使用这个动作"
assert captured["args"]["target"] == "MaiBot"
assert captured["args"]["action_data"] == {"target": "MaiBot"}
@pytest.mark.asyncio
async def test_core_component_registry_reads_runtime_command_and_executor(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""核心查询层应直接使用运行时命令匹配与执行闭包。"""
plugin_id = "runtime_command_bridge_plugin"
command_name = "runtime_command_bridge_test"
supervisor = PluginSupervisor(plugin_dirs=[])
captured: dict[str, Any] = {}
supervisor.component_registry.register_component(
name=command_name,
component_type="COMMAND",
plugin_id=plugin_id,
metadata={
"description": "测试命令",
"enabled": True,
"command_pattern": r"^/test(?:\s+.+)?$",
"aliases": ["/hello"],
"intercept_message_level": 1,
},
)
_install_runtime_manager(monkeypatch, supervisor, plugin_id, {"mode": "command"})
async def fake_invoke_plugin(
method: str,
plugin_id: str,
component_name: str,
args: dict[str, Any] | None = None,
timeout_ms: int = 30000,
) -> Any:
"""模拟命令 RPC 调用。"""
captured["method"] = method
captured["plugin_id"] = plugin_id
captured["component_name"] = component_name
captured["args"] = args or {}
captured["timeout_ms"] = timeout_ms
return SimpleNamespace(payload={"success": True, "result": (True, "command ok", True)})
monkeypatch.setattr(supervisor, "invoke_plugin", fake_invoke_plugin)
matched = component_query_service.find_command_by_text("/test hello")
assert matched is not None
command_executor, matched_groups, command_info = matched
assert matched_groups == {}
assert command_info.plugin_name == plugin_id
assert command_info.command_pattern == r"^/test(?:\s+.+)?$"
success, response_text, intercept = await command_executor(
message=SimpleNamespace(processed_plain_text="/test hello", session_id="stream-2"),
plugin_config={"mode": "command"},
matched_groups=matched_groups,
)
assert success is True
assert response_text == "command ok"
assert intercept is True
assert captured["method"] == "plugin.invoke_command"
assert captured["plugin_id"] == plugin_id
assert captured["component_name"] == command_name
assert captured["args"]["text"] == "/test hello"
assert captured["args"]["stream_id"] == "stream-2"
assert captured["args"]["plugin_config"] == {"mode": "command"}
@pytest.mark.asyncio
async def test_core_component_registry_reads_runtime_tools_and_executor(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""核心查询层应直接读取运行时 Tool并返回 RPC 执行闭包。"""
plugin_id = "runtime_tool_bridge_plugin"
tool_name = "runtime_tool_bridge_test"
supervisor = PluginSupervisor(plugin_dirs=[])
supervisor.component_registry.register_component(
name=tool_name,
component_type="TOOL",
plugin_id=plugin_id,
metadata={
"description": "测试工具",
"enabled": True,
"parameters": [
{
"name": "query",
"param_type": "string",
"description": "查询词",
"required": True,
}
],
},
)
_install_runtime_manager(monkeypatch, supervisor, plugin_id)
async def fake_invoke_plugin(
method: str,
plugin_id: str,
component_name: str,
args: dict[str, Any] | None = None,
timeout_ms: int = 30000,
) -> Any:
"""模拟工具 RPC 调用。"""
del timeout_ms
assert method == "plugin.invoke_tool"
assert plugin_id == "runtime_tool_bridge_plugin"
assert component_name == "runtime_tool_bridge_test"
assert args == {"query": "MaiBot"}
return SimpleNamespace(payload={"success": True, "result": {"content": "tool ok"}})
monkeypatch.setattr(supervisor, "invoke_plugin", fake_invoke_plugin)
tool_info = component_query_service.get_tool_info(tool_name)
assert isinstance(tool_info, ToolInfo)
assert tool_info.tool_description == "测试工具"
assert tool_name in component_query_service.get_llm_available_tools()
executor = component_query_service.get_tool_executor(tool_name)
assert executor is not None
assert await executor({"query": "MaiBot"}) == {"content": "tool ok"}

View File

@@ -0,0 +1,524 @@
"""插件 API 注册与调用测试。"""
from types import SimpleNamespace
from typing import Any, Dict, List
import pytest
from src.plugin_runtime.integration import PluginRuntimeManager
from src.plugin_runtime.host.supervisor import PluginSupervisor
from src.plugin_runtime.protocol.envelope import (
ComponentDeclaration,
Envelope,
MessageType,
RegisterPluginPayload,
UnregisterPluginPayload,
)
def _build_manager(*supervisors: PluginSupervisor) -> PluginRuntimeManager:
"""构造一个最小可用的插件运行时管理器。
Args:
*supervisors: 需要挂载的监督器列表。
Returns:
PluginRuntimeManager: 已注入监督器的运行时管理器。
"""
manager = PluginRuntimeManager()
if supervisors:
manager._builtin_supervisor = supervisors[0]
if len(supervisors) > 1:
manager._third_party_supervisor = supervisors[1]
return manager
async def _register_plugin(
supervisor: PluginSupervisor,
plugin_id: str,
components: List[Dict[str, Any]],
) -> Envelope:
"""通过 Supervisor 注册测试插件。
Args:
supervisor: 目标监督器。
plugin_id: 测试插件 ID。
components: 组件声明列表。
Returns:
Envelope: 注册响应信封。
"""
payload = RegisterPluginPayload(
plugin_id=plugin_id,
plugin_version="1.0.0",
components=[
ComponentDeclaration(
name=str(component.get("name", "") or ""),
component_type=str(component.get("component_type", "") or ""),
plugin_id=plugin_id,
metadata=component.get("metadata", {}) if isinstance(component.get("metadata"), dict) else {},
)
for component in components
],
)
return await supervisor._handle_register_plugin(
Envelope(
request_id=1,
message_type=MessageType.REQUEST,
method="plugin.register_components",
plugin_id=plugin_id,
payload=payload.model_dump(),
)
)
async def _unregister_plugin(supervisor: PluginSupervisor, plugin_id: str) -> Envelope:
"""通过 Supervisor 注销测试插件。
Args:
supervisor: 目标监督器。
plugin_id: 测试插件 ID。
Returns:
Envelope: 注销响应信封。
"""
payload = UnregisterPluginPayload(plugin_id=plugin_id, reason="test")
return await supervisor._handle_unregister_plugin(
Envelope(
request_id=2,
message_type=MessageType.REQUEST,
method="plugin.unregister",
plugin_id=plugin_id,
payload=payload.model_dump(),
)
)
@pytest.mark.asyncio
async def test_register_plugin_syncs_dedicated_api_registry() -> None:
"""插件注册时应将 API 同步到独立注册表,而不是通用组件表。"""
supervisor = PluginSupervisor(plugin_dirs=[])
response = await _register_plugin(
supervisor,
"provider",
[
{
"name": "render_html",
"component_type": "API",
"metadata": {
"description": "渲染 HTML",
"version": "1",
"public": True,
},
}
],
)
assert response.payload["accepted"] is True
assert response.payload["registered_components"] == 0
assert response.payload["registered_apis"] == 1
assert supervisor.api_registry.get_api("provider", "render_html") is not None
assert supervisor.component_registry.get_component("provider.render_html") is None
unregister_response = await _unregister_plugin(supervisor, "provider")
assert unregister_response.payload["removed_apis"] == 1
assert supervisor.api_registry.get_api("provider", "render_html") is None
@pytest.mark.asyncio
async def test_api_call_allows_public_api_between_plugins(monkeypatch: pytest.MonkeyPatch) -> None:
"""公开 API 应允许其他插件通过 Host 转发调用。"""
provider_supervisor = PluginSupervisor(plugin_dirs=[])
consumer_supervisor = PluginSupervisor(plugin_dirs=[])
await _register_plugin(
provider_supervisor,
"provider",
[
{
"name": "render_html",
"component_type": "API",
"metadata": {
"description": "渲染 HTML",
"version": "1",
"public": True,
},
}
],
)
await _register_plugin(consumer_supervisor, "consumer", [])
captured: Dict[str, Any] = {}
async def fake_invoke_api(
plugin_id: str,
component_name: str,
args: Dict[str, Any] | None = None,
timeout_ms: int = 30000,
) -> Any:
"""模拟 API RPC 调用。"""
captured["plugin_id"] = plugin_id
captured["component_name"] = component_name
captured["args"] = args or {}
captured["timeout_ms"] = timeout_ms
return SimpleNamespace(error=None, payload={"success": True, "result": {"image": "ok"}})
monkeypatch.setattr(provider_supervisor, "invoke_api", fake_invoke_api)
manager = _build_manager(provider_supervisor, consumer_supervisor)
result = await manager._cap_api_call(
"consumer",
"api.call",
{
"api_name": "provider.render_html",
"version": "1",
"args": {"html": "<div>Hello</div>"},
},
)
assert result == {"success": True, "result": {"image": "ok"}}
assert captured["plugin_id"] == "provider"
assert captured["component_name"] == "render_html"
assert captured["args"] == {"html": "<div>Hello</div>"}
@pytest.mark.asyncio
async def test_api_call_rejects_private_api_between_plugins() -> None:
"""未公开的 API 默认不允许跨插件调用。"""
provider_supervisor = PluginSupervisor(plugin_dirs=[])
consumer_supervisor = PluginSupervisor(plugin_dirs=[])
await _register_plugin(
provider_supervisor,
"provider",
[
{
"name": "secret_api",
"component_type": "API",
"metadata": {
"description": "私有 API",
"version": "1",
"public": False,
},
}
],
)
await _register_plugin(consumer_supervisor, "consumer", [])
manager = _build_manager(provider_supervisor, consumer_supervisor)
result = await manager._cap_api_call(
"consumer",
"api.call",
{
"api_name": "provider.secret_api",
"args": {},
},
)
assert result["success"] is False
assert "未公开" in str(result["error"])
@pytest.mark.asyncio
async def test_api_list_and_component_toggle_use_dedicated_registry() -> None:
"""API 列表与组件启停应直接作用于独立 API 注册表。"""
provider_supervisor = PluginSupervisor(plugin_dirs=[])
consumer_supervisor = PluginSupervisor(plugin_dirs=[])
await _register_plugin(
provider_supervisor,
"provider",
[
{
"name": "public_api",
"component_type": "API",
"metadata": {"version": "1", "public": True},
},
{
"name": "private_api",
"component_type": "API",
"metadata": {"version": "1", "public": False},
},
],
)
await _register_plugin(
consumer_supervisor,
"consumer",
[
{
"name": "self_private_api",
"component_type": "API",
"metadata": {"version": "1", "public": False},
}
],
)
manager = _build_manager(provider_supervisor, consumer_supervisor)
list_result = await manager._cap_api_list("consumer", "api.list", {})
assert list_result["success"] is True
api_names = {(item["plugin_id"], item["name"]) for item in list_result["apis"]}
assert ("provider", "public_api") in api_names
assert ("provider", "private_api") not in api_names
assert ("consumer", "self_private_api") in api_names
disable_result = await manager._cap_component_disable(
"consumer",
"component.disable",
{
"name": "provider.public_api",
"component_type": "API",
"scope": "global",
"stream_id": "",
},
)
assert disable_result["success"] is True
assert provider_supervisor.api_registry.get_api("provider", "public_api", enabled_only=True) is None
enable_result = await manager._cap_component_enable(
"consumer",
"component.enable",
{
"name": "provider.public_api",
"component_type": "API",
"scope": "global",
"stream_id": "",
},
)
assert enable_result["success"] is True
assert provider_supervisor.api_registry.get_api("provider", "public_api", enabled_only=True) is not None
@pytest.mark.asyncio
async def test_api_registry_supports_multiple_versions_with_distinct_handlers(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""同名 API 不同版本应可并存,并按版本路由到不同处理器。"""
provider_supervisor = PluginSupervisor(plugin_dirs=[])
consumer_supervisor = PluginSupervisor(plugin_dirs=[])
await _register_plugin(
provider_supervisor,
"provider",
[
{
"name": "render_html",
"component_type": "API",
"metadata": {
"description": "渲染 HTML v1",
"version": "1",
"public": True,
"handler_name": "handle_render_html_v1",
},
},
{
"name": "render_html",
"component_type": "API",
"metadata": {
"description": "渲染 HTML v2",
"version": "2",
"public": True,
"handler_name": "handle_render_html_v2",
},
},
],
)
await _register_plugin(consumer_supervisor, "consumer", [])
captured: Dict[str, Any] = {}
async def fake_invoke_api(
plugin_id: str,
component_name: str,
args: Dict[str, Any] | None = None,
timeout_ms: int = 30000,
) -> Any:
"""模拟多版本 API 调用。"""
captured["plugin_id"] = plugin_id
captured["component_name"] = component_name
captured["args"] = args or {}
captured["timeout_ms"] = timeout_ms
return SimpleNamespace(error=None, payload={"success": True, "result": {"image": "ok"}})
monkeypatch.setattr(provider_supervisor, "invoke_api", fake_invoke_api)
manager = _build_manager(provider_supervisor, consumer_supervisor)
ambiguous_result = await manager._cap_api_call(
"consumer",
"api.call",
{
"api_name": "provider.render_html",
"args": {"html": "<div>Hello</div>"},
},
)
assert ambiguous_result["success"] is False
assert "多个版本" in str(ambiguous_result["error"])
disable_ambiguous_result = await manager._cap_component_disable(
"consumer",
"component.disable",
{
"name": "provider.render_html",
"component_type": "API",
"scope": "global",
"stream_id": "",
},
)
assert disable_ambiguous_result["success"] is False
assert "多个版本" in str(disable_ambiguous_result["error"])
disable_v1_result = await manager._cap_component_disable(
"consumer",
"component.disable",
{
"name": "provider.render_html",
"component_type": "API",
"scope": "global",
"stream_id": "",
"version": "1",
},
)
assert disable_v1_result["success"] is True
assert provider_supervisor.api_registry.get_api("provider", "render_html", version="1", enabled_only=True) is None
assert provider_supervisor.api_registry.get_api("provider", "render_html", version="2", enabled_only=True) is not None
result = await manager._cap_api_call(
"consumer",
"api.call",
{
"api_name": "provider.render_html",
"version": "2",
"args": {"html": "<div>Hello</div>"},
},
)
assert result == {"success": True, "result": {"image": "ok"}}
assert captured["plugin_id"] == "provider"
assert captured["component_name"] == "handle_render_html_v2"
assert captured["args"] == {"html": "<div>Hello</div>"}
@pytest.mark.asyncio
async def test_api_replace_dynamic_can_offline_removed_entries(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""动态 API 替换后,被移除的 API 应返回明确下线错误。"""
supervisor = PluginSupervisor(plugin_dirs=[])
await _register_plugin(supervisor, "provider", [])
manager = _build_manager(supervisor)
captured: Dict[str, Any] = {}
async def fake_invoke_api(
plugin_id: str,
component_name: str,
args: Dict[str, Any] | None = None,
timeout_ms: int = 30000,
) -> Any:
"""模拟动态 API 调用。"""
captured["plugin_id"] = plugin_id
captured["component_name"] = component_name
captured["args"] = args or {}
captured["timeout_ms"] = timeout_ms
return SimpleNamespace(error=None, payload={"success": True, "result": {"ok": True}})
monkeypatch.setattr(supervisor, "invoke_api", fake_invoke_api)
replace_result = await manager._cap_api_replace_dynamic(
"provider",
"api.replace_dynamic",
{
"apis": [
{
"name": "mcp.search",
"type": "API",
"metadata": {
"version": "1",
"public": True,
"handler_name": "dynamic_search",
},
},
{
"name": "mcp.read",
"type": "API",
"metadata": {
"version": "1",
"public": True,
"handler_name": "dynamic_read",
},
},
],
"offline_reason": "MCP 服务器已关闭",
},
)
assert replace_result["success"] is True
assert replace_result["count"] == 2
list_result = await manager._cap_api_list("provider", "api.list", {"plugin_id": "provider"})
assert {(item["name"], item["version"]) for item in list_result["apis"]} == {
("mcp.read", "1"),
("mcp.search", "1"),
}
call_result = await manager._cap_api_call(
"provider",
"api.call",
{
"api_name": "provider.mcp.search",
"version": "1",
"args": {"query": "hello"},
},
)
assert call_result == {"success": True, "result": {"ok": True}}
assert captured["component_name"] == "dynamic_search"
assert captured["args"]["query"] == "hello"
assert captured["args"]["__maibot_api_name__"] == "mcp.search"
assert captured["args"]["__maibot_api_version__"] == "1"
second_replace_result = await manager._cap_api_replace_dynamic(
"provider",
"api.replace_dynamic",
{
"apis": [
{
"name": "mcp.read",
"type": "API",
"metadata": {
"version": "1",
"public": True,
"handler_name": "dynamic_read",
},
}
],
"offline_reason": "MCP 服务器已关闭",
},
)
assert second_replace_result["success"] is True
assert second_replace_result["count"] == 1
assert second_replace_result["offlined"] == 1
offlined_call_result = await manager._cap_api_call(
"provider",
"api.call",
{
"api_name": "provider.mcp.search",
"version": "1",
"args": {},
},
)
assert offlined_call_result["success"] is False
assert "MCP 服务器已关闭" in str(offlined_call_result["error"])
list_after_replace = await manager._cap_api_list("provider", "api.list", {"plugin_id": "provider"})
assert {(item["name"], item["version"]) for item in list_after_replace["apis"]} == {
("mcp.read", "1"),
}

View File

@@ -0,0 +1,154 @@
"""发送服务回归测试。"""
from types import SimpleNamespace
from typing import Any, Dict, List
import pytest
from src.chat.message_receive.chat_manager import BotChatSession
from src.services import send_service
class _FakePlatformIOManager:
"""用于测试的 Platform IO 管理器假对象。"""
def __init__(self, delivery_batch: Any) -> None:
"""初始化假 Platform IO 管理器。
Args:
delivery_batch: 发送时返回的批量回执。
"""
self._delivery_batch = delivery_batch
self.ensure_calls = 0
self.sent_messages: List[Dict[str, Any]] = []
async def ensure_send_pipeline_ready(self) -> None:
"""记录发送管线准备调用次数。"""
self.ensure_calls += 1
def build_route_key_from_message(self, message: Any) -> Any:
"""根据消息构造假的路由键。
Args:
message: 待发送的内部消息对象。
Returns:
Any: 简化后的路由键对象。
"""
del message
return SimpleNamespace(platform="qq")
async def send_message(self, message: Any, route_key: Any, metadata: Dict[str, Any]) -> Any:
"""记录发送请求并返回预设回执。
Args:
message: 待发送的内部消息对象。
route_key: 本次发送使用的路由键。
metadata: 发送元数据。
Returns:
Any: 预设的批量发送回执。
"""
self.sent_messages.append(
{
"message": message,
"route_key": route_key,
"metadata": metadata,
}
)
return self._delivery_batch
def _build_target_stream() -> BotChatSession:
"""构造一个最小可用的目标会话对象。
Returns:
BotChatSession: 测试用会话对象。
"""
return BotChatSession(
session_id="test-session",
platform="qq",
user_id="target-user",
group_id=None,
)
def test_inherit_platform_io_route_metadata_falls_back_to_bot_account(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""没有上下文消息时,也应回填当前平台账号用于账号级路由命中。"""
monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq" if platform == "qq" else "")
metadata = send_service._inherit_platform_io_route_metadata(_build_target_stream())
assert metadata["platform_io_account_id"] == "bot-qq"
assert metadata["platform_io_target_user_id"] == "target-user"
@pytest.mark.asyncio
async def test_text_to_stream_delegates_to_platform_io(monkeypatch: pytest.MonkeyPatch) -> None:
"""send service 应将发送职责统一交给 Platform IO。"""
fake_manager = _FakePlatformIOManager(
delivery_batch=SimpleNamespace(
has_success=True,
sent_receipts=[SimpleNamespace(driver_id="plugin.qq.sender")],
failed_receipts=[],
route_key=SimpleNamespace(platform="qq"),
)
)
stored_messages: List[Any] = []
monkeypatch.setattr(send_service, "get_platform_io_manager", lambda: fake_manager)
monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq")
monkeypatch.setattr(
send_service._chat_manager,
"get_session_by_session_id",
lambda stream_id: _build_target_stream() if stream_id == "test-session" else None,
)
monkeypatch.setattr(
send_service.MessageUtils,
"store_message_to_db",
lambda message: stored_messages.append(message),
)
result = await send_service.text_to_stream(text="你好", stream_id="test-session")
assert result is True
assert fake_manager.ensure_calls == 1
assert len(fake_manager.sent_messages) == 1
assert fake_manager.sent_messages[0]["metadata"] == {"show_log": False}
assert len(stored_messages) == 1
@pytest.mark.asyncio
async def test_text_to_stream_returns_false_when_platform_io_fails(monkeypatch: pytest.MonkeyPatch) -> None:
"""Platform IO 批量发送全部失败时,应直接向上返回失败。"""
fake_manager = _FakePlatformIOManager(
delivery_batch=SimpleNamespace(
has_success=False,
sent_receipts=[],
failed_receipts=[
SimpleNamespace(
driver_id="plugin.qq.sender",
status="failed",
error="network error",
)
],
route_key=SimpleNamespace(platform="qq"),
)
)
monkeypatch.setattr(send_service, "get_platform_io_manager", lambda: fake_manager)
monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq")
monkeypatch.setattr(
send_service._chat_manager,
"get_session_by_session_id",
lambda stream_id: _build_target_stream() if stream_id == "test-session" else None,
)
result = await send_service.text_to_stream(text="发送失败", stream_id="test-session")
assert result is False
assert fake_manager.ensure_calls == 1
assert len(fake_manager.sent_messages) == 1

View File

@@ -0,0 +1,115 @@
"""统计模块数据库会话行为测试。"""
from __future__ import annotations
from contextlib import contextmanager
from datetime import datetime, timedelta
from types import ModuleType
from typing import Any, Callable, Iterator
import sys
import pytest
from src.chat.utils import statistic
class _DummyResult:
"""模拟 SQLModel 查询结果对象。"""
def all(self) -> list[Any]:
"""返回空结果集。
Returns:
list[Any]: 空列表。
"""
return []
class _DummySession:
"""模拟数据库 Session。"""
def exec(self, statement: Any) -> _DummyResult:
"""执行查询语句并返回空结果。
Args:
statement: 待执行的查询语句。
Returns:
_DummyResult: 空结果对象。
"""
del statement
return _DummyResult()
def _build_fake_get_db_session(calls: list[bool]) -> Callable[[bool], Iterator[_DummySession]]:
"""构造一个记录 auto_commit 参数的假会话工厂。
Args:
calls: 用于记录每次调用 auto_commit 参数的列表。
Returns:
Callable[[bool], Iterator[_DummySession]]: 可替换 `get_db_session` 的上下文管理器工厂。
"""
@contextmanager
def _fake_get_db_session(auto_commit: bool = True) -> Iterator[_DummySession]:
"""记录会话参数并返回假 Session。
Args:
auto_commit: 是否启用自动提交。
Yields:
Iterator[_DummySession]: 假 Session 对象。
"""
calls.append(auto_commit)
yield _DummySession()
return _fake_get_db_session
def _build_statistic_task() -> statistic.StatisticOutputTask:
"""构造一个最小可用的统计任务实例。
Returns:
statistic.StatisticOutputTask: 跳过 `__init__` 的测试实例。
"""
task = statistic.StatisticOutputTask.__new__(statistic.StatisticOutputTask)
task.name_mapping = {}
return task
def _is_bot_self(platform: str, user_id: str) -> bool:
"""返回固定的非机器人身份判断结果。
Args:
platform: 平台名称。
user_id: 用户 ID。
Returns:
bool: 始终返回 ``False``。
"""
del platform
del user_id
return False
def test_statistic_read_queries_disable_auto_commit(monkeypatch: pytest.MonkeyPatch) -> None:
"""统计模块的纯读查询应关闭自动提交,避免 Session 退出后对象被 expire。"""
calls: list[bool] = []
now = datetime.now()
task = _build_statistic_task()
monkeypatch.setattr(statistic, "get_db_session", _build_fake_get_db_session(calls))
utils_module = ModuleType("src.chat.utils.utils")
utils_module.is_bot_self = _is_bot_self
monkeypatch.setitem(sys.modules, "src.chat.utils.utils", utils_module)
statistic.StatisticOutputTask._fetch_online_time_since(now)
statistic.StatisticOutputTask._fetch_model_usage_since(now)
task._collect_message_count_for_period([("last_hour", now - timedelta(hours=1))])
task._collect_interval_data(now, hours=1, interval_minutes=60)
task._collect_metrics_interval_data(now, hours=1, interval_hours=1)
assert calls == [False] * 9

View File

@@ -0,0 +1,42 @@
from types import SimpleNamespace
from src.chat.message_receive.chat_manager import ChatManager
from src.common.utils.utils_session import SessionUtils
def test_calculate_session_id_distinguishes_account_and_scope() -> None:
base_session_id = SessionUtils.calculate_session_id("qq", user_id="42")
same_base_session_id = SessionUtils.calculate_session_id("qq", user_id="42")
account_scoped_session_id = SessionUtils.calculate_session_id("qq", user_id="42", account_id="123")
route_scoped_session_id = SessionUtils.calculate_session_id("qq", user_id="42", account_id="123", scope="main")
assert base_session_id == same_base_session_id
assert account_scoped_session_id != base_session_id
assert route_scoped_session_id != account_scoped_session_id
def test_chat_manager_register_message_uses_route_metadata() -> None:
chat_manager = ChatManager()
message = SimpleNamespace(
platform="qq",
session_id="",
message_info=SimpleNamespace(
user_info=SimpleNamespace(user_id="42"),
group_info=SimpleNamespace(group_id="1000"),
additional_config={
"platform_io_account_id": "123",
"platform_io_scope": "main",
},
),
)
chat_manager.register_message(message)
assert message.session_id == SessionUtils.calculate_session_id(
"qq",
user_id="42",
group_id="1000",
account_id="123",
scope="main",
)
assert chat_manager.last_messages[message.session_id] is message

View File

@@ -1,27 +1,28 @@
import time
"""PFC 侧消息发送封装。"""
from typing import Optional
from maim_message import Seg
from rich.traceback import install
from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo
from src.chat.message_receive.chat_manager import BotChatSession
from src.chat.message_receive.message import MessageSending
from src.chat.message_receive.uni_message_sender import UniversalMessageSender
from src.chat.utils.utils import get_bot_account
from src.common.data_models.mai_message_data_model import MaiMessage
from src.common.logger import get_logger
from src.config.config import global_config
from src.services import send_service as send_api
install(extra_lines=3)
logger = get_logger("message_sender")
class DirectMessageSender:
"""直接消息发送器"""
"""直接消息发送器"""
def __init__(self, private_name: str):
def __init__(self, private_name: str) -> None:
"""初始化直接消息发送器。
Args:
private_name: 当前私聊实例的名称。
"""
self.private_name = private_name
async def send_message(
@@ -30,58 +31,31 @@ class DirectMessageSender:
content: str,
reply_to_message: Optional[MaiMessage] = None,
) -> None:
"""发送消息到聊天流
"""发送文本消息到聊天流
Args:
chat_stream: 聊天会话
content: 消息内容
reply_to_message: 要回复的消息(可选)
chat_stream: 目标聊天会话
content: 待发送的文本内容
reply_to_message: 可选的引用回复锚点消息。
Raises:
RuntimeError: 当消息发送失败时抛出。
"""
try:
# 创建消息内容
segments = Seg(type="seglist", data=[Seg(type="text", data=content)])
# 获取麦麦的信息
bot_user_id = get_bot_account(chat_stream.platform)
if not bot_user_id:
logger.error(f"[私聊][{self.private_name}]平台 {chat_stream.platform} 未配置机器人账号,无法发送消息")
raise RuntimeError(f"平台 {chat_stream.platform} 未配置机器人账号")
bot_user_info = UserInfo(
user_id=bot_user_id,
user_nickname=global_config.bot.nickname,
sent = await send_api.text_to_stream(
text=content,
stream_id=chat_stream.session_id,
set_reply=reply_to_message is not None,
reply_message=reply_to_message,
storage_message=True,
)
# 用当前时间作为message_id和之前那套sender一样
message_id = f"dm{round(time.time(), 2)}"
# 构建发送者信息(私聊时为接收者)
sender_info = None
if reply_to_message and reply_to_message.message_info and reply_to_message.message_info.user_info:
sender_info = reply_to_message.message_info.user_info
# 构建消息对象
message = MessageSending(
message_id=message_id,
session=chat_stream,
bot_user_info=bot_user_info,
sender_info=sender_info,
message_segment=segments,
reply=reply_to_message,
is_head=True,
is_emoji=False,
thinking_start_time=time.time(),
)
# 发送消息
message_sender = UniversalMessageSender()
sent = await message_sender.send_message(message, typing=False, set_reply=False, storage_message=True)
if sent:
logger.info(f"[私聊][{self.private_name}]PFC消息已发送: {content}")
else:
logger.error(f"[私聊][{self.private_name}]PFC消息发送失败")
raise RuntimeError("消息发送失败")
return
except Exception as e:
logger.error(f"[私聊][{self.private_name}]PFC消息发送失败: {str(e)}")
logger.error(f"[私聊][{self.private_name}]PFC消息发送失败")
raise RuntimeError("消息发送失败")
except Exception as exc:
logger.error(f"[私聊][{self.private_name}]PFC消息发送失败: {exc}")
raise

View File

@@ -8,8 +8,8 @@ from rich.traceback import install
from src.config.config import global_config
from src.common.logger import get_logger
from src.common.utils.utils_config import ExpressionConfigUtils
from src.bw_learner.expression_learner import ExpressionLearner
from src.bw_learner.jargon_miner import JargonMiner
from src.learners.expression_learner import ExpressionLearner
from src.learners.jargon_miner import JargonMiner
from src.chat.message_receive.chat_manager import BotChatSession
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.chat.message_receive.message import SessionMessage

View File

@@ -1,30 +1,32 @@
from datetime import datetime
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
import json
import time
import traceback
import random
import re
from typing import Dict, Optional, Tuple, List, TYPE_CHECKING
from rich.traceback import install
from datetime import datetime
from json_repair import repair_json
import time
import traceback
from json_repair import repair_json
from rich.traceback import install
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.logger import get_logger
from src.chat.logger.plan_reply_logger import PlanReplyLogger
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.chat.planner_actions.action_manager import ActionManager
from src.chat.utils.utils import get_chat_type_and_target_info
from src.common.data_models.info_data_model import ActionPlannerInfo
from src.common.logger import get_logger
from src.common.utils.utils_action import ActionUtils
from src.config.config import global_config, model_config
from src.core.types import ActionActivationType, ActionInfo, ComponentType
from src.llm_models.utils_model import LLMRequest
from src.plugin_runtime.component_query import component_query_service
from src.prompt.prompt_manager import prompt_manager
from src.services.message_service import (
build_readable_messages_with_id,
get_actions_by_timestamp_with_chat,
get_messages_before_time_in_chat,
)
from src.chat.utils.utils import get_chat_type_and_target_info
from src.chat.planner_actions.action_manager import ActionManager
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.core.types import ActionActivationType, ActionInfo, ComponentType
from src.core.component_registry import component_registry
if TYPE_CHECKING:
from src.common.data_models.info_data_model import TargetPersonInfo
@@ -320,7 +322,7 @@ class BrainPlanner:
current_available_actions_dict = self.action_manager.get_using_actions()
# 获取完整的动作信息
all_registered_actions: Dict[str, ActionInfo] = component_registry.get_components_by_type( # type: ignore
all_registered_actions: Dict[str, ActionInfo] = component_query_service.get_components_by_type( # type: ignore
ComponentType.ACTION
)
current_available_actions = {}

View File

@@ -1,734 +0,0 @@
import asyncio
import time
import traceback
import random
from typing import List, Optional, Dict, Any, Tuple, TYPE_CHECKING
from rich.traceback import install
from src.config.config import global_config
from src.common.logger import get_logger
from src.common.data_models.info_data_model import ActionPlannerInfo
from src.common.data_models.message_data_model import ReplyContentType
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.chat.utils.prompt_builder import global_prompt_manager
from src.chat.utils.timer_calculator import Timer
from src.chat.planner_actions.planner import ActionPlanner
from src.chat.planner_actions.action_modifier import ActionModifier
from src.chat.planner_actions.action_manager import ActionManager
from src.chat.heart_flow.hfc_utils import CycleDetail
from src.bw_learner.expression_learner import expression_learner_manager
from src.chat.heart_flow.frequency_control import frequency_control_manager
from src.bw_learner.message_recorder import extract_and_distribute_messages
from src.person_info.person_info import Person
from src.plugin_system.base.component_types import EventType, ActionInfo
from src.plugin_system.core import events_manager
from src.plugin_system.apis import generator_api, send_api, message_api, database_api
from src.chat.utils.chat_message_builder import (
build_readable_messages_with_id,
get_raw_msg_before_timestamp_with_chat,
)
from src.chat.utils.utils import record_replyer_action_temp
from src.memory_system.chat_history_summarizer import ChatHistorySummarizer
if TYPE_CHECKING:
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.data_models.message_data_model import ReplySetModel
ERROR_LOOP_INFO = {
"loop_plan_info": {
"action_result": {
"action_type": "error",
"action_data": {},
"reasoning": "循环处理失败",
},
},
"loop_action_info": {
"action_taken": False,
"reply_text": "",
"command": "",
"taken_time": time.time(),
},
}
install(extra_lines=3)
# 注释:原来的动作修改超时常量已移除,因为改为顺序执行
logger = get_logger("hfc") # Logger Name Changed
class HeartFChatting:
"""
管理一个连续的Focus Chat循环
用于在特定聊天流中生成回复。
其生命周期现在由其关联的 SubHeartflow 的 FOCUSED 状态控制。
"""
def __init__(self, chat_id: str):
"""
HeartFChatting 初始化函数
参数:
chat_id: 聊天流唯一标识符(如stream_id)
on_stop_focus_chat: 当收到stop_focus_chat命令时调用的回调函数
performance_version: 性能记录版本号,用于区分不同启动版本
"""
# 基础属性
self.stream_id: str = chat_id # 聊天流ID
self.chat_stream: ChatStream = get_chat_manager().get_stream(self.stream_id) # type: ignore
if not self.chat_stream:
raise ValueError(f"无法找到聊天流: {self.stream_id}")
self.log_prefix = f"[{get_chat_manager().get_stream_name(self.stream_id) or self.stream_id}]"
self.expression_learner = expression_learner_manager.get_expression_learner(self.stream_id)
self.action_manager = ActionManager()
self.action_planner = ActionPlanner(chat_id=self.stream_id, action_manager=self.action_manager)
self.action_modifier = ActionModifier(action_manager=self.action_manager, chat_id=self.stream_id)
# 循环控制内部状态
self.running: bool = False
self._loop_task: Optional[asyncio.Task] = None # 主循环任务
# 添加循环信息管理相关的属性
self.history_loop: List[CycleDetail] = []
self._cycle_counter = 0
self._current_cycle_detail: CycleDetail = None # type: ignore
self.last_read_time = time.time() - 2
self.is_mute = False
self.last_active_time = time.time() # 记录上一次非noreply时间
self.question_probability_multiplier = 1
self.questioned = False
# 跟踪连续 no_reply 次数,用于动态调整阈值
self.consecutive_no_reply_count = 0
# 聊天内容概括器
self.chat_history_summarizer = ChatHistorySummarizer(chat_id=self.stream_id)
async def start(self):
"""检查是否需要启动主循环,如果未激活则启动。"""
# 如果循环已经激活,直接返回
if self.running:
logger.debug(f"{self.log_prefix} HeartFChatting 已激活,无需重复启动")
return
try:
# 标记为活动状态,防止重复启动
self.running = True
self._loop_task = asyncio.create_task(self._main_chat_loop())
self._loop_task.add_done_callback(self._handle_loop_completion)
# 启动聊天内容概括器的后台定期检查循环
await self.chat_history_summarizer.start()
logger.info(f"{self.log_prefix} HeartFChatting 启动完成")
except Exception as e:
# 启动失败时重置状态
self.running = False
self._loop_task = None
logger.error(f"{self.log_prefix} HeartFChatting 启动失败: {e}")
raise
def _handle_loop_completion(self, task: asyncio.Task):
"""当 _hfc_loop 任务完成时执行的回调。"""
try:
if exception := task.exception():
logger.error(f"{self.log_prefix} HeartFChatting: 脱离了聊天(异常): {exception}")
logger.error(traceback.format_exc()) # Log full traceback for exceptions
else:
logger.info(f"{self.log_prefix} HeartFChatting: 脱离了聊天 (外部停止)")
except asyncio.CancelledError:
logger.info(f"{self.log_prefix} HeartFChatting: 结束了聊天")
def start_cycle(self) -> Tuple[Dict[str, float], str]:
self._cycle_counter += 1
self._current_cycle_detail = CycleDetail(self._cycle_counter)
self._current_cycle_detail.thinking_id = f"tid{str(round(time.time(), 2))}"
cycle_timers = {}
return cycle_timers, self._current_cycle_detail.thinking_id
def end_cycle(self, loop_info, cycle_timers):
self._current_cycle_detail.set_loop_info(loop_info)
self.history_loop.append(self._current_cycle_detail)
self._current_cycle_detail.timers = cycle_timers
self._current_cycle_detail.end_time = time.time()
def print_cycle_info(self, cycle_timers):
# 记录循环信息和计时器结果
timer_strings = []
for name, elapsed in cycle_timers.items():
if elapsed < 0.1:
# 不显示小于0.1秒的计时器
continue
formatted_time = f"{elapsed:.2f}"
timer_strings.append(f"{name}: {formatted_time}")
logger.info(
f"{self.log_prefix}{self._current_cycle_detail.cycle_id}次思考,"
f"耗时: {self._current_cycle_detail.end_time - self._current_cycle_detail.start_time:.1f}秒;" # type: ignore
+ (f"详情: {'; '.join(timer_strings)}" if timer_strings else "")
)
async def _loopbody(self):
recent_messages_list = message_api.get_messages_by_time_in_chat(
chat_id=self.stream_id,
start_time=self.last_read_time,
end_time=time.time(),
limit=20,
limit_mode="latest",
filter_mai=True,
filter_command=False,
filter_intercept_message_level=0,
)
# 根据连续 no_reply 次数动态调整阈值
# 3次 no_reply 时,阈值调高到 1.550%概率为150%概率为2
# 5次 no_reply 时,提高到 2大于等于两条消息的阈值
if self.consecutive_no_reply_count >= 5:
threshold = 2
elif self.consecutive_no_reply_count >= 3:
# 1.5 的含义50%概率为150%概率为2
threshold = 2 if random.random() < 0.5 else 1
else:
threshold = 1
if len(recent_messages_list) >= threshold:
# for message in recent_messages_list:
# print(message.processed_plain_text)
self.last_read_time = time.time()
# !此处使at或者提及必定回复
mentioned_message = None
for message in recent_messages_list:
if (message.is_mentioned or message.is_at) and global_config.chat.mentioned_bot_reply:
mentioned_message = message
# logger.info(f"{self.log_prefix} 当前talk_value: {global_config.chat.get_talk_value(self.stream_id)}")
# *控制频率用
if mentioned_message:
await self._observe(recent_messages_list=recent_messages_list, force_reply_message=mentioned_message)
elif (
random.random()
< global_config.chat.get_talk_value(self.stream_id)
* frequency_control_manager.get_or_create_frequency_control(self.stream_id).get_talk_frequency_adjust()
):
await self._observe(recent_messages_list=recent_messages_list)
else:
# 没有提到继续保持沉默等待5秒防止频繁触发
await asyncio.sleep(10)
return True
else:
await asyncio.sleep(0.2)
return True
return True
async def _send_and_store_reply(
self,
response_set: "ReplySetModel",
action_message: "DatabaseMessages",
cycle_timers: Dict[str, float],
thinking_id,
actions,
selected_expressions: Optional[List[int]] = None,
quote_message: Optional[bool] = None,
) -> Tuple[Dict[str, Any], str, Dict[str, float]]:
with Timer("回复发送", cycle_timers):
reply_text = await self._send_response(
reply_set=response_set,
message_data=action_message,
selected_expressions=selected_expressions,
quote_message=quote_message,
)
# 获取 platform如果不存在则从 chat_stream 获取,如果还是 None 则使用默认值
platform = action_message.chat_info.platform
if platform is None:
platform = getattr(self.chat_stream, "platform", "unknown")
person = Person(platform=platform, user_id=action_message.user_info.user_id)
person_name = person.person_name
action_prompt_display = f"你对{person_name}进行了回复:{reply_text}"
await database_api.store_action_info(
chat_stream=self.chat_stream,
action_build_into_prompt=False,
action_prompt_display=action_prompt_display,
action_done=True,
thinking_id=thinking_id,
action_data={"reply_text": reply_text},
action_name="reply",
)
# 构建循环信息
loop_info: Dict[str, Any] = {
"loop_plan_info": {
"action_result": actions,
},
"loop_action_info": {
"action_taken": True,
"reply_text": reply_text,
"command": "",
"taken_time": time.time(),
},
}
return loop_info, reply_text, cycle_timers
async def _observe(
self, # interest_value: float = 0.0,
recent_messages_list: Optional[List["DatabaseMessages"]] = None,
force_reply_message: Optional["DatabaseMessages"] = None,
) -> bool: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
if recent_messages_list is None:
recent_messages_list = []
_reply_text = "" # 初始化reply_text变量避免UnboundLocalError
start_time = time.time()
async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
# 通过 MessageRecorder 统一提取消息并分发给 expression_learner 和 jargon_miner
# 在 replyer 执行时触发,统一管理时间窗口,避免重复获取消息
asyncio.create_task(extract_and_distribute_messages(self.stream_id))
# 添加curious检测任务 - 检测聊天记录中的矛盾、冲突或需要提问的内容
# asyncio.create_task(check_and_make_question(self.stream_id))
# 添加聊天内容概括任务 - 累积、打包和压缩聊天记录
# 注意后台循环已在start()中启动,这里作为额外触发点,在有思考时立即处理
# asyncio.create_task(self.chat_history_summarizer.process())
cycle_timers, thinking_id = self.start_cycle()
logger.info(
f"{self.log_prefix} 开始第{self._cycle_counter}次思考(频率: {global_config.chat.get_talk_value(self.stream_id)})"
)
# 第一步:动作检查
available_actions: Dict[str, ActionInfo] = {}
try:
await self.action_modifier.modify_actions()
available_actions = self.action_manager.get_using_actions()
except Exception as e:
logger.error(f"{self.log_prefix} 动作修改失败: {e}")
# 执行planner
is_group_chat, chat_target_info, _ = self.action_planner.get_necessary_info()
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
chat_id=self.stream_id,
timestamp=time.time(),
limit=int(global_config.chat.max_context_size * 0.6),
filter_intercept_message_level=1,
)
chat_content_block, message_id_list = build_readable_messages_with_id(
messages=message_list_before_now,
timestamp_mode="normal_no_YMD",
read_mark=self.action_planner.last_obs_time_mark,
truncate=True,
show_actions=True,
)
prompt_info = await self.action_planner.build_planner_prompt(
is_group_chat=is_group_chat,
chat_target_info=chat_target_info,
current_available_actions=available_actions,
chat_content_block=chat_content_block,
message_id_list=message_id_list,
)
continue_flag, modified_message = await events_manager.handle_mai_events(
EventType.ON_PLAN, None, prompt_info[0], None, self.chat_stream.stream_id
)
if not continue_flag:
return False
if modified_message and modified_message._modify_flags.modify_llm_prompt:
prompt_info = (modified_message.llm_prompt, prompt_info[1])
with Timer("规划器", cycle_timers):
action_to_use_info = await self.action_planner.plan(
loop_start_time=self.last_read_time,
available_actions=available_actions,
force_reply_message=force_reply_message,
)
logger.info(
f"{self.log_prefix} 决定执行{len(action_to_use_info)}个动作: {' '.join([a.action_type for a in action_to_use_info])}"
)
# 3. 并行执行所有动作
action_tasks = [
asyncio.create_task(
self._execute_action(action, action_to_use_info, thinking_id, available_actions, cycle_timers)
)
for action in action_to_use_info
]
# 并行执行所有任务
results = await asyncio.gather(*action_tasks, return_exceptions=True)
# 处理执行结果
reply_loop_info = None
reply_text_from_reply = ""
action_success = False
action_reply_text = ""
excute_result_str = ""
for result in results:
excute_result_str += f"{result['action_type']} 执行结果:{result['result']}\n"
if isinstance(result, BaseException):
logger.error(f"{self.log_prefix} 动作执行异常: {result}")
continue
if result["action_type"] != "reply":
action_success = result["success"]
action_reply_text = result["result"]
elif result["action_type"] == "reply":
if result["success"]:
reply_loop_info = result["loop_info"]
reply_text_from_reply = result["result"]
else:
logger.warning(f"{self.log_prefix} 回复动作执行失败")
self.action_planner.add_plan_excute_log(result=excute_result_str)
# 构建最终的循环信息
if reply_loop_info:
# 如果有回复信息使用回复的loop_info作为基础
loop_info = reply_loop_info
# 更新动作执行信息
loop_info["loop_action_info"].update(
{
"action_taken": action_success,
"taken_time": time.time(),
}
)
_reply_text = reply_text_from_reply
else:
# 没有回复信息构建纯动作的loop_info
loop_info = {
"loop_plan_info": {
"action_result": action_to_use_info,
},
"loop_action_info": {
"action_taken": action_success,
"reply_text": action_reply_text,
"taken_time": time.time(),
},
}
_reply_text = action_reply_text
self.end_cycle(loop_info, cycle_timers)
self.print_cycle_info(cycle_timers)
end_time = time.time()
if end_time - start_time < global_config.chat.planner_smooth:
wait_time = global_config.chat.planner_smooth - (end_time - start_time)
await asyncio.sleep(wait_time)
else:
await asyncio.sleep(0.1)
return True
async def _main_chat_loop(self):
"""主循环,持续进行计划并可能回复消息,直到被外部取消。"""
try:
while self.running:
# 主循环
success = await self._loopbody()
await asyncio.sleep(0.1)
if not success:
break
except asyncio.CancelledError:
# 设置了关闭标志位后被取消是正常流程
logger.info(f"{self.log_prefix} 麦麦已关闭聊天")
except Exception:
logger.error(f"{self.log_prefix} 麦麦聊天意外错误将于3s后尝试重新启动")
print(traceback.format_exc())
await asyncio.sleep(3)
self._loop_task = asyncio.create_task(self._main_chat_loop())
logger.error(f"{self.log_prefix} 结束了当前聊天循环")
async def _handle_action(
self,
action: str,
action_reasoning: str,
action_data: dict,
cycle_timers: Dict[str, float],
thinking_id: str,
action_message: Optional["DatabaseMessages"] = None,
) -> tuple[bool, str, str]:
"""
处理规划动作,使用动作工厂创建相应的动作处理器
参数:
action: 动作类型
action_reasoning: 决策理由
action_data: 动作数据,包含不同动作需要的参数
cycle_timers: 计时器字典
thinking_id: 思考ID
action_message: 消息数据
返回:
tuple[bool, str, str]: (是否执行了动作, 思考消息ID, 命令)
"""
try:
# 使用工厂创建动作处理器实例
try:
action_handler = self.action_manager.create_action(
action_name=action,
action_data=action_data,
cycle_timers=cycle_timers,
thinking_id=thinking_id,
chat_stream=self.chat_stream,
log_prefix=self.log_prefix,
action_reasoning=action_reasoning,
action_message=action_message,
)
except Exception as e:
logger.error(f"{self.log_prefix} 创建动作处理器时出错: {e}")
traceback.print_exc()
return False, ""
# 处理动作并获取结果(固定记录一次动作信息)
result = await action_handler.execute()
success, action_text = result
return success, action_text
except Exception as e:
logger.error(f"{self.log_prefix} 处理{action}时出错: {e}")
traceback.print_exc()
return False, ""
async def _send_response(
self,
reply_set: "ReplySetModel",
message_data: "DatabaseMessages",
selected_expressions: Optional[List[int]] = None,
quote_message: Optional[bool] = None,
) -> str:
# 根据 llm_quote 配置决定是否使用 quote_message 参数
if global_config.chat.llm_quote:
# 如果配置为 true使用 llm_quote 参数决定是否引用回复
if quote_message is None:
logger.warning(f"{self.log_prefix} quote_message 参数为空,不引用")
need_reply = False
else:
need_reply = quote_message
if need_reply:
logger.info(f"{self.log_prefix} LLM 决定使用引用回复")
else:
# 如果配置为 false使用原来的模式
new_message_count = message_api.count_new_messages(
chat_id=self.chat_stream.stream_id, start_time=self.last_read_time, end_time=time.time()
)
need_reply = new_message_count >= random.randint(2, 3) or time.time() - self.last_read_time > 90
if need_reply:
logger.info(f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息使用引用回复或者上次回复时间超过90秒")
reply_text = ""
first_replied = False
for reply_content in reply_set.reply_data:
if reply_content.content_type != ReplyContentType.TEXT:
continue
data: str = reply_content.content # type: ignore
if not first_replied:
await send_api.text_to_stream(
text=data,
stream_id=self.chat_stream.stream_id,
reply_message=message_data,
set_reply=need_reply,
typing=False,
selected_expressions=selected_expressions,
)
first_replied = True
else:
await send_api.text_to_stream(
text=data,
stream_id=self.chat_stream.stream_id,
reply_message=message_data,
set_reply=False,
typing=True,
selected_expressions=selected_expressions,
)
reply_text += data
return reply_text
async def _execute_action(
self,
action_planner_info: ActionPlannerInfo,
chosen_action_plan_infos: List[ActionPlannerInfo],
thinking_id: str,
available_actions: Dict[str, ActionInfo],
cycle_timers: Dict[str, float],
):
"""执行单个动作的通用函数"""
try:
with Timer(f"动作{action_planner_info.action_type}", cycle_timers):
# 直接当场执行no_reply逻辑
if action_planner_info.action_type == "no_reply":
# 直接处理no_reply逻辑不再通过动作系统
reason = action_planner_info.reasoning or "选择不回复"
# logger.info(f"{self.log_prefix} 选择不回复,原因: {reason}")
# 增加连续 no_reply 计数
self.consecutive_no_reply_count += 1
await database_api.store_action_info(
chat_stream=self.chat_stream,
action_build_into_prompt=False,
action_prompt_display=reason,
action_done=True,
thinking_id=thinking_id,
action_data={},
action_name="no_reply",
action_reasoning=reason,
)
return {"action_type": "no_reply", "success": True, "result": "选择不回复", "command": ""}
elif action_planner_info.action_type == "reply":
# 直接当场执行reply逻辑
self.questioned = False
# 刷新主动发言状态
# 重置连续 no_reply 计数
self.consecutive_no_reply_count = 0
reason = action_planner_info.reasoning or ""
# 根据 think_mode 配置决定 think_level 的值
think_mode = global_config.chat.think_mode
if think_mode == "default":
think_level = 0
elif think_mode == "deep":
think_level = 1
elif think_mode == "dynamic":
# dynamic 模式:从 planner 返回的 action_data 中获取
think_level = action_planner_info.action_data.get("think_level", 1)
else:
# 默认使用 default 模式
think_level = 0
# 使用 action_reasoningplanner 的整体思考理由)作为 reply_reason
planner_reasoning = action_planner_info.action_reasoning or reason
record_replyer_action_temp(
chat_id=self.stream_id,
reason=reason,
think_level=think_level,
)
await database_api.store_action_info(
chat_stream=self.chat_stream,
action_build_into_prompt=False,
action_prompt_display=reason,
action_done=True,
thinking_id=thinking_id,
action_data={},
action_name="reply",
action_reasoning=reason,
)
# 从 Planner 的 action_data 中提取未知词语列表(仅在 reply 时使用)
unknown_words = None
quote_message = None
if isinstance(action_planner_info.action_data, dict):
uw = action_planner_info.action_data.get("unknown_words")
if isinstance(uw, list):
cleaned_uw: List[str] = []
for item in uw:
if isinstance(item, str):
s = item.strip()
if s:
cleaned_uw.append(s)
if cleaned_uw:
unknown_words = cleaned_uw
# 从 Planner 的 action_data 中提取 quote_message 参数
qm = action_planner_info.action_data.get("quote")
if qm is not None:
# 支持多种格式true/false, "true"/"false", 1/0
if isinstance(qm, bool):
quote_message = qm
elif isinstance(qm, str):
quote_message = qm.lower() in ("true", "1", "yes")
elif isinstance(qm, (int, float)):
quote_message = bool(qm)
logger.info(f"{self.log_prefix} {qm}引用回复设置: {quote_message}")
success, llm_response = await generator_api.generate_reply(
chat_stream=self.chat_stream,
reply_message=action_planner_info.action_message,
available_actions=available_actions,
chosen_actions=chosen_action_plan_infos,
reply_reason=planner_reasoning,
unknown_words=unknown_words,
enable_tool=global_config.tool.enable_tool,
request_type="replyer",
from_plugin=False,
reply_time_point=action_planner_info.action_data.get("loop_start_time", time.time()),
think_level=think_level,
)
if not success or not llm_response or not llm_response.reply_set:
if action_planner_info.action_message:
logger.info(f"{action_planner_info.action_message.processed_plain_text} 的回复生成失败")
else:
logger.info("回复生成失败")
return {"action_type": "reply", "success": False, "result": "回复生成失败", "loop_info": None}
response_set = llm_response.reply_set
selected_expressions = llm_response.selected_expressions
loop_info, reply_text, _ = await self._send_and_store_reply(
response_set=response_set,
action_message=action_planner_info.action_message, # type: ignore
cycle_timers=cycle_timers,
thinking_id=thinking_id,
actions=chosen_action_plan_infos,
selected_expressions=selected_expressions,
quote_message=quote_message,
)
self.last_active_time = time.time()
return {
"action_type": "reply",
"success": True,
"result": f"你使用reply动作' {action_planner_info.action_message.processed_plain_text} '这句话进行了回复,回复内容为: '{reply_text}'",
"loop_info": loop_info,
}
else:
# 执行普通动作
with Timer("动作执行", cycle_timers):
success, result = await self._handle_action(
action=action_planner_info.action_type,
action_reasoning=action_planner_info.action_reasoning or "",
action_data=action_planner_info.action_data or {},
cycle_timers=cycle_timers,
thinking_id=thinking_id,
action_message=action_planner_info.action_message,
)
self.last_active_time = time.time()
return {
"action_type": action_planner_info.action_type,
"success": success,
"result": result,
}
except Exception as e:
logger.error(f"{self.log_prefix} 执行动作时出错: {e}")
logger.error(f"{self.log_prefix} 错误信息: {traceback.format_exc()}")
return {
"action_type": action_planner_info.action_type,
"success": False,
"result": "",
"loop_info": None,
"error": str(e),
}

View File

@@ -1,377 +1,231 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from rich.traceback import install
from typing import List, Optional, TYPE_CHECKING
import asyncio
import random
import time
import traceback
from rich.traceback import install
from src.bw_learner.expression_learner import ExpressionLearner
from src.bw_learner.jargon_miner import JargonMiner
from src.chat.event_helpers import build_event_message
from src.chat.logger.plan_reply_logger import PlanReplyLogger
from src.chat.message_receive.chat_manager import BotChatSession
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.chat.planner_actions.action_manager import ActionManager
from src.chat.planner_actions.action_modifier import ActionModifier
from src.chat.planner_actions.planner import ActionPlanner
from src.chat.utils.prompt_builder import global_prompt_manager
from src.chat.utils.timer_calculator import Timer
from src.chat.utils.utils import record_replyer_action_temp
from src.common.data_models.info_data_model import ActionPlannerInfo
from src.common.data_models.message_component_data_model import MessageSequence, TextComponent
from src.chat.message_receive.chat_manager import chat_manager
from src.common.logger import get_logger
from src.common.utils.utils_config import ChatConfigUtils, ExpressionConfigUtils
from src.config.config import global_config
from src.config.file_watcher import FileChange
from src.core.event_bus import event_bus
from src.core.types import ActionInfo, EventType
from src.person_info.person_info import Person
from src.services import (
database_service as database_api,
generator_service as generator_api,
message_service as message_api,
send_service as send_api,
)
from src.services.message_service import build_readable_messages_with_id, get_messages_before_time_in_chat
from src.learners.expression_learner import ExpressionLearner
from src.learners.jargon_miner import JargonMiner
from .heartFC_utils import CycleDetail
if TYPE_CHECKING:
from src.chat.message_receive.message import SessionMessage
install(extra_lines=5)
logger = get_logger("heartFC_chat")
class HeartFChatting:
"""管理一个持续运行的 Focus Chat 会话。"""
"""
管理一个连续的Focus Chat聊天会话
用于在特定的聊天会话里面生成回复
"""
def __init__(self, session_id: str):
self.session_id = session_id
self.chat_stream: BotChatSession = _chat_manager.get_session_by_session_id(self.session_id) # type: ignore[assignment]
if not self.chat_stream:
raise ValueError(f"无法找到聊天会话 {self.session_id}")
"""
初始化 HeartFChatting 实例
session_name = _chat_manager.get_session_name(session_id) or session_id
Args:
session_id: 聊天会话ID
"""
# 基础属性
self.session_id = session_id
session_name = chat_manager.get_session_name(session_id) or session_id
self.log_prefix = f"[{session_name}]"
self.session_name = session_name
self.action_manager = ActionManager()
self.action_planner = ActionPlanner(chat_id=self.session_id, action_manager=self.action_manager)
self.action_modifier = ActionModifier(action_manager=self.action_manager, chat_id=self.session_id)
# 系统运行状态
self._running: bool = False
self._loop_task: Optional[asyncio.Task] = None
self._cycle_counter: int = 0
self._hfc_lock: asyncio.Lock = asyncio.Lock() # 用于保护 _hfc_func 的并发访问
# 聊天频率相关
self._consecutive_no_reply_count = 0 # 跟踪连续 no_reply 次数,用于动态调整阈值
self._talk_frequency_adjust: float = 1.0 # 发言频率修正值默认为1.0,可以根据需要调整
# HFC内消息缓存
self.message_cache: List[SessionMessage] = []
# Asyncio Event 用于控制循环的开始和结束
self._cycle_event = asyncio.Event()
self._hfc_lock = asyncio.Lock()
self._cycle_counter = 0
self._current_cycle_detail: Optional[CycleDetail] = None
self.history_loop: List[CycleDetail] = []
self.last_read_time = time.time() - 2
self.last_active_time = time.time()
self._talk_frequency_adjust = 1.0
self._consecutive_no_reply_count = 0
self.message_cache: List["SessionMessage"] = []
self._min_messages_for_extraction = 30
self._min_extraction_interval = 60
self._last_extraction_time = 0.0
# 表达方式相关内容
self._min_messages_for_extraction = 30 # 最少提取消息数
self._min_extraction_interval = 60 # 最小提取时间间隔,单位为秒
self._last_extraction_time: float = 0.0 # 上次提取的时间戳
expr_use, jargon_learn, expr_learn = ExpressionConfigUtils.get_expression_config_for_chat(session_id)
self._enable_expression_use = expr_use
self._enable_expression_learning = expr_learn
self._enable_jargon_learning = jargon_learn
self._expression_learner = ExpressionLearner(session_id)
self._jargon_miner = JargonMiner(session_id, session_name=session_name)
self._enable_expression_use = expr_use # 允许使用表达方式,但不一定启用学习
self._enable_expression_learning = expr_learn # 允许学习表达方式
self._enable_jargon_learning = jargon_learn # 允许学习黑话
# 表达学习器
self._expression_learner: ExpressionLearner = ExpressionLearner(session_id)
# 黑话挖掘器
self._jargon_miner: JargonMiner = JargonMiner(session_id, session_name=session_name)
# TODO: ChatSummarizer 聊天总结器重构
# ====== 公开方法 ======
async def start(self):
"""启动 HeartFChatting 的主循环"""
# 先检查是否已经启动运行
if self._running:
logger.debug(f"{self.log_prefix} HeartFChatting 已在运行中")
logger.debug(f"{self.log_prefix} 已经在运行中,无需重复启动")
return
try:
self._running = True
self._cycle_event.clear()
self._cycle_event.clear() # 确保事件初始状态为未设置
self._loop_task = asyncio.create_task(self.main_loop())
self._loop_task.add_done_callback(self._handle_loop_completion)
logger.info(f"{self.log_prefix} HeartFChatting 启动完成")
except Exception as exc:
logger.error(f"{self.log_prefix} HeartFChatting 启动失败: {exc}", exc_info=True)
self._running = False
self._cycle_event.set()
self._loop_task = None
except Exception as e:
logger.error(f"{self.log_prefix} 启动 HeartFChatting 失败: {e}", exc_info=True)
self._running = False # 确保状态正确
self._cycle_event.set() # 确保事件被设置,避免死锁
self._loop_task = None # 确保任务引用被清理
raise
async def stop(self):
"""停止 HeartFChatting 的主循环"""
if not self._running:
logger.debug(f"{self.log_prefix} HeartFChatting 已停止")
logger.debug(f"{self.log_prefix} HeartFChatting 已经停止,无需重复停止")
return
self._running = False
self._cycle_event.set()
self._cycle_event.set() # 触发事件,通知循环结束
if self._loop_task:
self._loop_task.cancel()
self._loop_task.cancel() # 取消主循环任务
try:
await self._loop_task
await self._loop_task # 等待任务完成
except asyncio.CancelledError:
logger.info(f"{self.log_prefix} HeartFChatting 主循环已取消")
except Exception as exc:
logger.error(f"{self.log_prefix} 停止 HeartFChatting 时发生错误: {exc}", exc_info=True)
logger.info(f"{self.log_prefix} HeartFChatting 主循环已成功取消")
except Exception as e:
logger.error(f"{self.log_prefix} 停止 HeartFChatting 时发生错误: {e}", exc_info=True)
finally:
self._loop_task = None
self._loop_task = None # 确保任务引用被清理
logger.info(f"{self.log_prefix} HeartFChatting 已停止")
def adjust_talk_frequency(self, new_value: float):
"""调整发言频率的调整值
Args:
new_value: 新的修正值,必须为非负数。值越大,修正发言频率越高;值越小,修正发言频率越低。
"""
self._talk_frequency_adjust = max(0.0, new_value)
async def register_message(self, message: "SessionMessage"):
"""注册一条消息到 HeartFChatting 的缓存中,并检测其是否产生提及,决定是否唤醒聊天
Args:
message: 待注册的消息对象
"""
self.message_cache.append(message)
# 先检查at必回复
if global_config.chat.inevitable_at_reply and message.is_at:
self.last_read_time = time.time()
async with self._hfc_lock:
await self._judge_and_response(mentioned_message=message, recent_messages_list=[message])
return
async with self._hfc_lock: # 确保与主循环逻辑的互斥访问
await self._judge_and_response(message)
return # 直接返回,避免同一条消息被主循环再次处理
# 再检查提及必回复
if global_config.chat.mentioned_bot_reply and message.is_mentioned:
self.last_read_time = time.time()
async with self._hfc_lock:
await self._judge_and_response(mentioned_message=message, recent_messages_list=[message])
# 直接获取锁,确保一定一定触发回复逻辑,不受当前是否正在执行主循环的影响
async with self._hfc_lock: # 确保与主循环逻辑的互斥访问
await self._judge_and_response(message)
return
async def main_loop(self):
try:
while self._running and not self._cycle_event.is_set():
if not self._hfc_lock.locked():
async with self._hfc_lock:
async with self._hfc_lock: # 确保主循环逻辑的互斥访问
await self._hfc_func()
await asyncio.sleep(0.1)
await asyncio.sleep(5)
except asyncio.CancelledError:
logger.info(f"{self.log_prefix} HeartFChatting: 主循环被取消")
except Exception as exc:
logger.error(f"{self.log_prefix} HeartFChatting: 主循环异常: {exc}", exc_info=True)
await self.stop()
logger.info(f"{self.log_prefix} HeartFChatting: 主循环被取消,正在关闭")
except Exception as e:
logger.error(f"{self.log_prefix} 麦麦聊天意外错误: {e}将于3s后尝试重新启动")
await self.stop() # 确保状态正确
await asyncio.sleep(3)
await self.start()
await self.start() # 尝试重新启动
async def _config_callback(self, file_change: Optional[FileChange] = None):
del file_change
expr_use, jargon_learn, expr_learn = ExpressionConfigUtils.get_expression_config_for_chat(self.session_id)
self._enable_expression_use = expr_use
self._enable_expression_learning = expr_learn
self._enable_jargon_learning = jargon_learn
"""配置文件变更回调函数"""
# TODO: 根据配置文件变动重新计算相关参数:
"""
需要计算的参数:
self._enable_expression_use = expr_use # 允许使用表达方式,但不一定启用学习
self._enable_expression_learning = expr_learn # 允许学习表达方式
self._enable_jargon_learning = jargon_learn # 允许学习黑话
"""
async def _hfc_func(self):
recent_messages_list = message_api.get_messages_by_time_in_chat(
chat_id=self.session_id,
start_time=self.last_read_time,
end_time=time.time(),
limit=20,
limit_mode="latest",
filter_mai=True,
filter_command=False,
filter_intercept_message_level=1,
)
# ====== 心流聊天核心逻辑 ======
async def _hfc_func(self, mentioned_message: Optional["SessionMessage"] = None):
"""心流聊天的主循环逻辑"""
if self._consecutive_no_reply_count >= 5:
threshold = 2
elif self._consecutive_no_reply_count >= 3:
threshold = 2 if random.random() < 0.5 else 1
else:
threshold = 1
if len(recent_messages_list) < 1:
if len(self.message_cache) < threshold:
await asyncio.sleep(0.2)
return True
self.last_read_time = time.time()
mentioned_message: Optional["SessionMessage"] = None
for message in recent_messages_list:
if global_config.chat.inevitable_at_reply and message.is_at:
mentioned_message = message
elif global_config.chat.mentioned_bot_reply and message.is_mentioned:
mentioned_message = message
talk_value = ChatConfigUtils.get_talk_value(self.session_id) * self._talk_frequency_adjust
if mentioned_message:
await self._judge_and_response(mentioned_message=mentioned_message, recent_messages_list=recent_messages_list)
elif random.random() < talk_value:
await self._judge_and_response(recent_messages_list=recent_messages_list)
talk_value_threshold = (
random.random() * ChatConfigUtils.get_talk_value(self.session_id) * self._talk_frequency_adjust
)
if mentioned_message and global_config.chat.mentioned_bot_reply:
await self._judge_and_response(mentioned_message)
elif random.random() < talk_value_threshold:
await self._judge_and_response()
return True
async def _judge_and_response(
self,
mentioned_message: Optional["SessionMessage"] = None,
recent_messages_list: Optional[List["SessionMessage"]] = None,
):
recent_messages = list(recent_messages_list or self.message_cache[-20:])
if recent_messages:
asyncio.create_task(self._trigger_expression_learning(recent_messages))
cycle_timers, thinking_id = self._start_cycle()
async def _judge_and_response(self, mentioned_message: Optional["SessionMessage"] = None):
"""判定和生成回复"""
asyncio.create_task(self._trigger_expression_learning(self.message_cache))
# TODO: 完成反思器之后的逻辑
start_time = time.time()
current_cycle_detail = self._start_cycle()
logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考")
try:
async with global_prompt_manager.async_message_scope(self._get_template_name()):
available_actions: Dict[str, ActionInfo] = {}
try:
await self.action_modifier.modify_actions()
available_actions = self.action_manager.get_using_actions()
except Exception as exc:
logger.error(f"{self.log_prefix} 动作修改失败: {exc}", exc_info=True)
# TODO: 动作检查逻辑
# TODO: Planner逻辑
# TODO: 动作执行逻辑
is_group_chat, chat_target_info, _ = self.action_planner.get_necessary_info()
message_list_before_now = get_messages_before_time_in_chat(
chat_id=self.session_id,
timestamp=time.time(),
limit=int(global_config.chat.max_context_size * 0.6),
filter_intercept_message_level=1,
)
chat_content_block, message_id_list = build_readable_messages_with_id(
messages=message_list_before_now,
timestamp_mode="normal_no_YMD",
read_mark=self.action_planner.last_obs_time_mark,
truncate=True,
show_actions=True,
)
prompt, filtered_actions = await self._build_planner_prompt_with_event(
available_actions=available_actions,
is_group_chat=is_group_chat,
chat_target_info=chat_target_info,
chat_content_block=chat_content_block,
message_id_list=message_id_list,
)
if prompt is None:
return False
with Timer("规划器", cycle_timers):
reasoning, action_to_use_info, llm_raw_output, llm_reasoning, llm_duration_ms = (
await self.action_planner._execute_main_planner(
prompt=prompt,
message_id_list=message_id_list,
filtered_actions=filtered_actions,
available_actions=available_actions,
loop_start_time=self.last_read_time,
)
)
action_to_use_info = self._ensure_force_reply_action(
actions=action_to_use_info,
force_reply_message=mentioned_message,
available_actions=available_actions,
)
self.action_planner.add_plan_log(reasoning, action_to_use_info)
self.action_planner.last_obs_time_mark = time.time()
self._log_plan(
prompt=prompt,
reasoning=reasoning,
llm_raw_output=llm_raw_output,
llm_reasoning=llm_reasoning,
llm_duration_ms=llm_duration_ms,
actions=action_to_use_info,
)
logger.info(
f"{self.log_prefix} 决定执行{len(action_to_use_info)}个动作: {' '.join([a.action_type for a in action_to_use_info])}"
)
action_tasks = [
asyncio.create_task(
self._execute_action(
action,
action_to_use_info,
thinking_id,
available_actions,
cycle_timers,
)
)
for action in action_to_use_info
]
results = await asyncio.gather(*action_tasks, return_exceptions=True)
reply_loop_info = None
reply_text_from_reply = ""
action_success = False
action_reply_text = ""
execute_result_str = ""
for result in results:
if isinstance(result, BaseException):
logger.error(f"{self.log_prefix} 动作执行异常: {result}", exc_info=True)
continue
execute_result_str += f"{result['action_type']} 执行结果:{result['result']}\n"
if result["action_type"] == "reply":
if result["success"]:
reply_loop_info = result["loop_info"]
reply_text_from_reply = result["result"]
else:
logger.warning(f"{self.log_prefix} reply 动作执行失败")
else:
action_success = result["success"]
action_reply_text = result["result"]
self.action_planner.add_plan_excute_log(result=execute_result_str)
if reply_loop_info:
loop_info = reply_loop_info
loop_info["loop_action_info"].update(
{
"action_taken": action_success,
"taken_time": time.time(),
}
)
else:
loop_info = {
"loop_plan_info": {
"action_result": action_to_use_info,
},
"loop_action_info": {
"action_taken": action_success,
"reply_text": action_reply_text,
"taken_time": time.time(),
},
}
reply_text_from_reply = action_reply_text
current_cycle_detail = self._end_cycle(self._current_cycle_detail, loop_info)
logger.debug(f"{self.log_prefix} 本轮最终输出: {reply_text_from_reply}")
return current_cycle_detail is not None
except Exception as exc:
logger.error(f"{self.log_prefix} 判定与回复流程失败: {exc}", exc_info=True)
if self._current_cycle_detail:
self._end_cycle(
self._current_cycle_detail,
{
"loop_plan_info": {"action_result": []},
"loop_action_info": {
"action_taken": False,
"reply_text": "",
"taken_time": time.time(),
"error": str(exc),
},
},
)
return False
cycle_detail = self._end_cycle(current_cycle_detail)
if wait_time := global_config.chat.planner_smooth - (time.time() - start_time) > 0:
await asyncio.sleep(wait_time)
else:
await asyncio.sleep(0.1) # 最小等待时间,避免过快循环
return True
def _handle_loop_completion(self, task: asyncio.Task):
"""当 _hfc_func 任务完成时执行的回调。"""
try:
if exception := task.exception():
logger.error(f"{self.log_prefix} HeartFChatting: 主循环异常退出: {exception}")
logger.error(traceback.format_exc())
logger.error(f"{self.log_prefix} HeartFChatting: 脱离了聊天(异常): {exception}")
logger.error(traceback.format_exc()) # Log full traceback for exceptions
else:
logger.info(f"{self.log_prefix} HeartFChatting: 主循环已退出")
logger.info(f"{self.log_prefix} HeartFChatting: 脱离了聊天 (外部停止)")
except asyncio.CancelledError:
logger.info(f"{self.log_prefix} HeartFChatting: 聊天已结束")
logger.info(f"{self.log_prefix} HeartFChatting: 结束了聊天")
# ====== 学习器触发逻辑 ======
async def _trigger_expression_learning(self, messages: List["SessionMessage"]):
if not messages:
return
self._expression_learner.add_messages(messages)
if time.time() - self._last_extraction_time < self._min_extraction_interval:
return
@@ -379,14 +233,12 @@ class HeartFChatting:
return
if not self._enable_expression_learning:
return
extraction_end_time = time.time()
logger.info(
f"聊天流 {self.session_name} 提取到 {len(messages)} 条消息,"
f"时间窗口: {self._last_extraction_time:.2f} - {extraction_end_time:.2f}"
)
self._last_extraction_time = extraction_end_time
try:
jargon_miner = self._jargon_miner if self._enable_jargon_learning else None
learnt_style = await self._expression_learner.learn(jargon_miner)
@@ -394,398 +246,43 @@ class HeartFChatting:
logger.info(f"{self.log_prefix} 表达学习完成")
else:
logger.debug(f"{self.log_prefix} 表达学习未获得有效结果")
except Exception as exc:
logger.error(f"{self.log_prefix} 表达学习失败: {exc}", exc_info=True)
except Exception as e:
logger.error(f"{self.log_prefix} 表达学习失败: {e}", exc_info=True)
def _start_cycle(self) -> Tuple[Dict[str, float], str]:
# ====== 记录循环执行信息相关逻辑 ======
def _start_cycle(self) -> CycleDetail:
self._cycle_counter += 1
self._current_cycle_detail = CycleDetail(cycle_id=self._cycle_counter)
self._current_cycle_detail.thinking_id = f"tid{str(round(time.time(), 2))}"
return self._current_cycle_detail.time_records, self._current_cycle_detail.thinking_id
current_cycle_detail = CycleDetail(cycle_id=self._cycle_counter)
current_cycle_detail.thinking_id = f"tid{str(round(time.time(), 2))}"
return current_cycle_detail
def _end_cycle(self, cycle_detail: Optional[CycleDetail], loop_info: Optional[Dict[str, Any]] = None):
if cycle_detail is None:
return None
cycle_detail.loop_plan_info = (loop_info or {}).get("loop_plan_info")
cycle_detail.loop_action_info = (loop_info or {}).get("loop_action_info")
def _end_cycle(self, cycle_detail: CycleDetail, only_long_execution: bool = True):
cycle_detail.end_time = time.time()
self.history_loop.append(cycle_detail)
timer_strings = [
timer_strings: List[str] = [
f"{name}: {duration:.2f}s"
for name, duration in cycle_detail.time_records.items()
if duration >= 0.1
if not only_long_execution or duration >= 0.1
]
logger.info(
f"{self.log_prefix}{cycle_detail.cycle_id} 个心流循环完成"
f"耗时: {cycle_detail.end_time - cycle_detail.start_time:.2f}s"
f"{self.log_prefix} {cycle_detail.cycle_id} 个心流循环完成"
f"耗时: {cycle_detail.end_time - cycle_detail.start_time:.2f}\n"
f"详细计时: {', '.join(timer_strings) if timer_strings else ''}"
)
return cycle_detail
async def _execute_action(
self,
action_planner_info: ActionPlannerInfo,
chosen_action_plan_infos: List[ActionPlannerInfo],
thinking_id: str,
available_actions: Dict[str, ActionInfo],
cycle_timers: Dict[str, float],
):
try:
with Timer(f"动作{action_planner_info.action_type}", cycle_timers):
if action_planner_info.action_type == "no_reply":
reason = action_planner_info.reasoning or "选择不回复"
self._consecutive_no_reply_count += 1
await database_api.store_action_info(
chat_stream=self.chat_stream,
display_prompt=reason,
thinking_id=thinking_id,
action_data={},
action_name="no_reply",
action_reasoning=reason,
)
return {
"action_type": "no_reply",
"success": True,
"result": "选择不回复",
"loop_info": None,
}
# ====== Action相关逻辑 ======
async def _execute_action(self, *args, **kwargs):
"""原ExecuteAction"""
raise NotImplementedError("执行动作的逻辑尚未实现") # TODO: 实现动作执行的逻辑,替换掉*args, **kwargs*占位符
if action_planner_info.action_type == "reply":
self._consecutive_no_reply_count = 0
reason = action_planner_info.reasoning or ""
think_level = self._get_think_level(action_planner_info)
planner_reasoning = action_planner_info.action_reasoning or reason
async def _execute_other_actions(self, *args, **kwargs):
"""原HandleAction"""
raise NotImplementedError(
"执行其他动作的逻辑尚未实现"
) # TODO: 实现其他动作执行的逻辑, 替换掉*args, **kwargs*占位符
record_replyer_action_temp(
chat_id=self.session_id,
reason=reason,
think_level=think_level,
)
await database_api.store_action_info(
chat_stream=self.chat_stream,
display_prompt=reason,
thinking_id=thinking_id,
action_data={},
action_name="reply",
action_reasoning=reason,
)
unknown_words, quote_message = self._extract_reply_metadata(action_planner_info)
success, llm_response = await generator_api.generate_reply(
chat_stream=self.chat_stream,
reply_message=action_planner_info.action_message,
available_actions=available_actions,
chosen_actions=chosen_action_plan_infos,
reply_reason=planner_reasoning,
unknown_words=unknown_words,
enable_tool=global_config.tool.enable_tool,
request_type="replyer",
from_plugin=False,
reply_time_point=action_planner_info.action_data.get("loop_start_time", time.time())
if action_planner_info.action_data
else time.time(),
think_level=think_level,
)
if not success or not llm_response or not llm_response.reply_set:
if action_planner_info.action_message:
logger.info(
f"{action_planner_info.action_message.processed_plain_text} 的回复生成失败"
)
else:
logger.info(f"{self.log_prefix} 回复生成失败")
return {
"action_type": "reply",
"success": False,
"result": "回复生成失败",
"loop_info": None,
}
loop_info, reply_text, _ = await self._send_and_store_reply(
response_set=llm_response.reply_set,
action_message=action_planner_info.action_message, # type: ignore[arg-type]
cycle_timers=cycle_timers,
thinking_id=thinking_id,
actions=chosen_action_plan_infos,
selected_expressions=llm_response.selected_expressions,
quote_message=quote_message,
)
self.last_active_time = time.time()
return {
"action_type": "reply",
"success": True,
"result": reply_text,
"loop_info": loop_info,
}
with Timer("动作执行", cycle_timers):
success, result = await self._handle_action(
action=action_planner_info.action_type,
action_reasoning=action_planner_info.action_reasoning or "",
action_data=action_planner_info.action_data or {},
cycle_timers=cycle_timers,
thinking_id=thinking_id,
action_message=action_planner_info.action_message,
)
if success:
self.last_active_time = time.time()
return {
"action_type": action_planner_info.action_type,
"success": success,
"result": result,
"loop_info": None,
}
except Exception as exc:
logger.error(f"{self.log_prefix} 执行动作时出错: {exc}", exc_info=True)
return {
"action_type": action_planner_info.action_type,
"success": False,
"result": "",
"loop_info": None,
"error": str(exc),
}
async def _handle_action(
self,
action: str,
action_reasoning: str,
action_data: dict,
cycle_timers: Dict[str, float],
thinking_id: str,
action_message: Optional["SessionMessage"] = None,
) -> Tuple[bool, str]:
try:
action_handler = self.action_manager.create_action(
action_name=action,
action_data=action_data,
action_reasoning=action_reasoning,
cycle_timers=cycle_timers,
thinking_id=thinking_id,
chat_stream=self.chat_stream,
log_prefix=self.log_prefix,
action_message=action_message,
)
if not action_handler:
logger.warning(f"{self.log_prefix} 未能创建动作处理器: {action}")
return False, ""
success, action_text = await action_handler.execute()
return success, action_text
except Exception as exc:
logger.error(f"{self.log_prefix} 处理动作 {action} 时出错: {exc}", exc_info=True)
return False, ""
async def _send_and_store_reply(
self,
response_set: MessageSequence,
action_message: "SessionMessage",
cycle_timers: Dict[str, float],
thinking_id: str,
actions: List[ActionPlannerInfo],
selected_expressions: Optional[List[int]] = None,
quote_message: Optional[bool] = None,
) -> Tuple[Dict[str, Any], str, Dict[str, float]]:
with Timer("回复发送", cycle_timers):
reply_text = await self._send_response(
reply_set=response_set,
message_data=action_message,
selected_expressions=selected_expressions,
quote_message=quote_message,
)
platform = action_message.platform or getattr(self.chat_stream, "platform", "unknown")
person = Person(platform=platform, user_id=action_message.message_info.user_info.user_id)
action_prompt_display = f"你对{person.person_name}进行了回复:{reply_text}"
await database_api.store_action_info(
chat_stream=self.chat_stream,
display_prompt=action_prompt_display,
thinking_id=thinking_id,
action_data={"reply_text": reply_text},
action_name="reply",
)
loop_info: Dict[str, Any] = {
"loop_plan_info": {
"action_result": actions,
},
"loop_action_info": {
"action_taken": True,
"reply_text": reply_text,
"command": "",
"taken_time": time.time(),
},
}
return loop_info, reply_text, cycle_timers
async def _send_response(
self,
reply_set: MessageSequence,
message_data: "SessionMessage",
selected_expressions: Optional[List[int]] = None,
quote_message: Optional[bool] = None,
) -> str:
if global_config.chat.llm_quote:
need_reply = bool(quote_message)
else:
new_message_count = message_api.count_new_messages(
chat_id=self.session_id,
start_time=self.last_read_time,
end_time=time.time(),
)
need_reply = new_message_count >= random.randint(2, 3) or time.time() - self.last_read_time > 90
reply_text = ""
first_replied = False
for component in reply_set.components:
if not isinstance(component, TextComponent):
continue
data = component.text
if not first_replied:
await send_api.text_to_stream(
text=data,
stream_id=self.session_id,
reply_message=message_data,
set_reply=need_reply,
typing=False,
selected_expressions=selected_expressions,
)
first_replied = True
else:
await send_api.text_to_stream(
text=data,
stream_id=self.session_id,
reply_message=message_data,
set_reply=False,
typing=True,
selected_expressions=selected_expressions,
)
reply_text += data
return reply_text
async def _build_planner_prompt_with_event(
self,
available_actions: Dict[str, ActionInfo],
is_group_chat: bool,
chat_target_info: Any,
chat_content_block: str,
message_id_list: List[Tuple[str, "SessionMessage"]],
) -> Tuple[Optional[str], Dict[str, ActionInfo]]:
filtered_actions = self.action_planner._filter_actions_by_activation_type(available_actions, chat_content_block)
prompt, _ = await self.action_planner.build_planner_prompt(
is_group_chat=is_group_chat,
chat_target_info=chat_target_info,
current_available_actions=filtered_actions,
chat_content_block=chat_content_block,
message_id_list=message_id_list,
)
event_message = build_event_message(EventType.ON_PLAN, llm_prompt=prompt, stream_id=self.session_id)
continue_flag, modified_message = await event_bus.emit(EventType.ON_PLAN, event_message)
if not continue_flag:
logger.info(f"{self.log_prefix} ON_PLAN 事件中止了本轮 HFC")
return None, filtered_actions
if modified_message and modified_message._modify_flags.modify_llm_prompt and modified_message.llm_prompt:
prompt = modified_message.llm_prompt
return prompt, filtered_actions
def _ensure_force_reply_action(
self,
actions: List[ActionPlannerInfo],
force_reply_message: Optional["SessionMessage"],
available_actions: Dict[str, ActionInfo],
) -> List[ActionPlannerInfo]:
if not force_reply_message:
return actions
has_reply_to_force_message = any(
action.action_type == "reply"
and action.action_message
and action.action_message.message_id == force_reply_message.message_id
for action in actions
)
if has_reply_to_force_message:
return actions
actions = [action for action in actions if action.action_type != "no_reply"]
actions.insert(
0,
ActionPlannerInfo(
action_type="reply",
reasoning="用户提及了我,必须回复该消息",
action_data={"loop_start_time": self.last_read_time},
action_message=force_reply_message,
available_actions=available_actions,
action_reasoning=None,
),
)
logger.info(f"{self.log_prefix} 检测到强制回复消息,已补充 reply 动作")
return actions
def _log_plan(
self,
prompt: str,
reasoning: str,
llm_raw_output: Optional[str],
llm_reasoning: Optional[str],
llm_duration_ms: Optional[float],
actions: List[ActionPlannerInfo],
) -> None:
try:
PlanReplyLogger.log_plan(
chat_id=self.session_id,
prompt=prompt,
reasoning=reasoning,
raw_output=llm_raw_output,
raw_reasoning=llm_reasoning,
actions=actions,
timing={
"llm_duration_ms": round(llm_duration_ms, 2) if llm_duration_ms is not None else None,
"loop_start_time": self.last_read_time,
},
extra=None,
)
except Exception:
logger.exception(f"{self.log_prefix} 记录 plan 日志失败")
def _extract_reply_metadata(
self,
action_planner_info: ActionPlannerInfo,
) -> Tuple[Optional[List[str]], Optional[bool]]:
unknown_words: Optional[List[str]] = None
quote_message: Optional[bool] = None
action_data = action_planner_info.action_data or {}
raw_unknown_words = action_data.get("unknown_words")
if isinstance(raw_unknown_words, list):
cleaned_unknown_words = []
for item in raw_unknown_words:
if isinstance(item, str) and (cleaned_item := item.strip()):
cleaned_unknown_words.append(cleaned_item)
if cleaned_unknown_words:
unknown_words = cleaned_unknown_words
raw_quote = action_data.get("quote")
if isinstance(raw_quote, bool):
quote_message = raw_quote
elif isinstance(raw_quote, str):
quote_message = raw_quote.lower() in {"true", "1", "yes"}
elif isinstance(raw_quote, (int, float)):
quote_message = bool(raw_quote)
return unknown_words, quote_message
def _get_think_level(self, action_planner_info: ActionPlannerInfo) -> int:
think_mode = global_config.chat.think_mode
if think_mode == "default":
return 0
if think_mode == "deep":
return 1
if think_mode == "dynamic":
action_data = action_planner_info.action_data or {}
return int(action_data.get("think_level", 1))
return 0
def _get_template_name(self) -> Optional[str]:
if self.chat_stream.context:
return self.chat_stream.context.template_name
return None
# ====== 响应发送相关方法 ======
async def _send_response(self, *args, **kwargs):
raise NotImplementedError("发送回复的逻辑尚未实现") # TODO: 实现发送回复的逻辑,替换掉*args, **kwargs*占位符
# 传入的消息至少应该是个MessageSequence实例最好是SessionMessage实例随后可直接转化为MessageSending实例

View File

@@ -1,42 +0,0 @@
import traceback
from typing import Any, Optional, Dict
from src.chat.message_receive.chat_stream import get_chat_manager
from src.common.logger import get_logger
from src.chat.heart_flow.heartFC_chat import HeartFChatting
from src.chat.brain_chat.brain_chat import BrainChatting
from src.chat.message_receive.chat_stream import ChatStream
logger = get_logger("heartflow")
class Heartflow:
"""主心流协调器,负责初始化并协调聊天"""
def __init__(self):
self.heartflow_chat_list: Dict[Any, HeartFChatting | BrainChatting] = {}
async def get_or_create_heartflow_chat(self, chat_id: Any) -> Optional[HeartFChatting | BrainChatting]:
"""获取或创建一个新的HeartFChatting实例"""
try:
if chat_id in self.heartflow_chat_list:
if chat := self.heartflow_chat_list.get(chat_id):
return chat
else:
chat_stream: ChatStream | None = get_chat_manager().get_stream(chat_id)
if not chat_stream:
raise ValueError(f"未找到 chat_id={chat_id} 的聊天流")
if chat_stream.group_info:
new_chat = HeartFChatting(chat_id=chat_id)
else:
new_chat = BrainChatting(chat_id=chat_id)
await new_chat.start()
self.heartflow_chat_list[chat_id] = new_chat
return new_chat
except Exception as e:
logger.error(f"创建心流聊天 {chat_id} 失败: {e}", exc_info=True)
traceback.print_exc()
return None
heartflow = Heartflow()

View File

@@ -1,19 +1,20 @@
from contextlib import suppress
import traceback
import os
from maim_message import MessageBase
from typing import Any, Dict, Optional
import os
import traceback
from maim_message import MessageBase
from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver
from src.common.logger import get_logger
from src.common.utils.utils_message import MessageUtils
from src.common.utils.utils_session import SessionUtils
from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver
from src.platform_io.route_key_factory import RouteKeyFactory
# from src.chat.brain_chat.PFC.pfc_manager import PFCManager
from src.core.announcement_manager import global_announcement_manager
from src.core.component_registry import component_registry
from src.plugin_runtime.component_query import component_query_service
from .message import SessionMessage
from .chat_manager import chat_manager
@@ -58,16 +59,22 @@ class ChatBot:
logger.error(f"创建PFC聊天失败: {e}")
logger.error(traceback.format_exc())
async def _process_commands(self, message: SessionMessage):
# sourcery skip: use-named-expression
"""使用新插件系统处理命令"""
async def _process_commands(self, message: SessionMessage) -> tuple[bool, Optional[str], bool]:
"""使用统一组件注册表处理命令。
Args:
message: 当前待处理的会话消息。
Returns:
tuple[bool, Optional[str], bool]: ``(是否命中命令, 命令响应文本, 是否继续后续处理)``。
"""
if not message.processed_plain_text:
return False, None, True # 没有文本内容,继续处理消息
try:
text = message.processed_plain_text
# 使用核心组件注册表查找命令
command_result = component_registry.find_command_by_text(text)
# 使用插件运行时统一查询服务查找命令
command_result = component_query_service.find_command_by_text(text)
if command_result:
command_executor, matched_groups, command_info = command_result
plugin_name = command_info.plugin_name
@@ -81,7 +88,7 @@ class ChatBot:
message.is_command = True
# 获取插件配置
plugin_config = component_registry.get_plugin_config(plugin_name)
plugin_config = component_query_service.get_plugin_config(plugin_name)
try:
# 调用命令执行器
@@ -112,88 +119,32 @@ class ChatBot:
# 命令出错时,根据命令的拦截设置决定是否继续处理消息
return True, str(e), False # 出错时继续处理消息
# 没有找到旧系统命令,尝试新版本插件运行时
new_cmd_result = await self._process_new_runtime_command(message)
return new_cmd_result if new_cmd_result is not None else (False, None, True)
return False, None, True
except Exception as e:
logger.error(f"处理命令时出错: {e}")
return False, None, True # 出错时继续处理消息
async def _process_new_runtime_command(self, message: SessionMessage):
"""尝试在新版本插件运行时中查找并执行命令
Returns:
(found, response, continue_processing) 三元组,
或 None 表示新运行时中也未找到匹配命令。
"""
from src.plugin_runtime.integration import get_plugin_runtime_manager
prm = get_plugin_runtime_manager()
if not prm.is_running:
return None
matched = prm.find_command_by_text(message.processed_plain_text)
if matched is None:
return None
command_name = matched["name"]
if message.session_id and command_name in global_announcement_manager.get_disabled_chat_commands(
message.session_id
):
logger.info(f"[新运行时] 用户禁用的命令,跳过处理: {matched['full_name']}")
return False, None, True
message.is_command = True
logger.info(f"[新运行时] 匹配命令: {matched['full_name']}")
try:
resp = await prm.invoke_plugin(
method="plugin.invoke_command",
plugin_id=matched["plugin_id"],
component_name=matched["name"],
args={
"text": message.processed_plain_text,
"stream_id": message.session_id or "",
"matched_groups": matched.get("matched_groups") or {},
},
timeout_ms=30000,
)
payload = resp.payload
success = payload.get("success", False)
cmd_result = payload.get("result")
# 拦截位优先从命令返回值中获取(支持运行时动态决定),
# 回退到组件 metadata 中的静态声明
if isinstance(cmd_result, (list, tuple)) and len(cmd_result) >= 3:
# 命令返回 (found, response_text, intercept_bool) 三元组
response_text = cmd_result[1] if cmd_result[1] is not None else ""
intercept = bool(cmd_result[2])
else:
response_text = cmd_result if cmd_result is not None else ""
intercept = bool(matched["metadata"].get("intercept_message_level", 0))
self._mark_command_message(message, int(intercept))
if success:
logger.info(f"[新运行时] 命令执行成功: {matched['full_name']}")
else:
logger.warning(f"[新运行时] 命令执行失败: {matched['full_name']} - {response_text}")
return True, response_text, not intercept
except Exception as e:
logger.error(f"[新运行时] 执行命令 {matched['full_name']} 异常: {e}", exc_info=True)
return True, str(e), True
@staticmethod
def _mark_command_message(message: SessionMessage, intercept_message_level: int) -> None:
"""标记消息已经被命令链消费。
Args:
message: 待标记的会话消息。
intercept_message_level: 命令设置的拦截级别。
"""
message.is_command = True
message.message_info.additional_config["intercept_message_level"] = intercept_message_level
@staticmethod
def _store_intercepted_command_message(message: SessionMessage) -> None:
"""将被命令链拦截的消息写入数据库。
Args:
message: 已完成命令处理的会话消息。
"""
MessageUtils.store_message_to_db(message)
async def _handle_command_processing_result(
@@ -310,13 +261,28 @@ class ChatBot:
# logger.debug(str(message_data))
maim_raw_message = MessageBase.from_dict(message_data)
message = SessionMessage.from_maim_message(maim_raw_message)
await self.receive_message(message)
except Exception as e:
logger.error(f"预处理消息失败: {e}")
traceback.print_exc()
async def receive_message(self, message: SessionMessage):
try:
group_info = message.message_info.group_info
user_info = message.message_info.user_info
account_id = None
scope = None
additional_config = message.message_info.additional_config
if isinstance(additional_config, dict):
account_id, scope = RouteKeyFactory.extract_components(additional_config)
session_id = SessionUtils.calculate_session_id(
message.platform,
user_id=message.message_info.user_info.user_id,
group_id=group_info.group_id if group_info else None,
account_id=account_id,
scope=scope,
)
message.session_id = session_id # 正确初始化session_id
@@ -359,24 +325,24 @@ class ChatBot:
platform = message.platform
user_id = user_info.user_id
group_id = group_info.group_id if group_info else None
_ = await chat_manager.get_or_create_session(platform, user_id, group_id) # 确保会话存在
try:
from src.services.memory_flow_service import memory_automation_service
await memory_automation_service.on_incoming_message(message)
except Exception as exc:
logger.warning(f"[长期记忆自动总结] 注册会话总结器失败: {exc}")
_ = await chat_manager.get_or_create_session(
platform,
user_id,
group_id,
account_id=account_id,
scope=scope,
) # 确保会话存在
# message.update_chat_stream(chat)
# 命令处理 - 使用新插件系统检查并处理命令
# 注意:命令返回的 response 当前只用于日志记录和流程判断,
# 不会在这里自动作为回复消息发送回会话。
is_command, cmd_result, continue_process = await self._process_commands(message)
# is_command, cmd_result, continue_process = await self._process_commands(message)
# 如果是命令且不需要继续处理,则直接返回
if is_command and await self._handle_command_processing_result(message, cmd_result, continue_process):
return
# # 如果是命令且不需要继续处理,则直接返回
# if is_command and await self._handle_command_processing_result(message, cmd_result, continue_process):
# return
# continue_flag, modified_message = await events_manager.handle_mai_events(EventType.ON_MESSAGE, message)
# if not continue_flag:

View File

@@ -1,15 +1,16 @@
import asyncio
from datetime import datetime
from typing import TYPE_CHECKING, Dict, List, Optional
from rich.traceback import install
from sqlmodel import select
from typing import Optional, TYPE_CHECKING, List, Dict
import asyncio
from src.common.logger import get_logger
from src.common.data_models.chat_session_data_model import MaiChatSession
from src.common.database.database_model import ChatSession
from src.common.database.database import get_db_session
from src.common.database.database_model import ChatSession
from src.common.logger import get_logger
from src.common.utils.utils_session import SessionUtils
from src.platform_io.route_key_factory import RouteKeyFactory
if TYPE_CHECKING:
from .message import SessionMessage
@@ -82,7 +83,12 @@ class ChatManager:
logger.error(f"初始化聊天管理器出现错误: {e}")
async def get_or_create_session(
self, platform: str, user_id: str, group_id: Optional[str] = None
self,
platform: str,
user_id: str,
group_id: Optional[str] = None,
account_id: Optional[str] = None,
scope: Optional[str] = None,
) -> BotChatSession:
"""获取会话,如果不存在则创建一个新会话;一个封装方法。
@@ -90,12 +96,20 @@ class ChatManager:
platform: 平台
user_id: 用户ID
group_id: 群ID如果是群聊
account_id: 平台账号 ID
scope: 路由作用域
Returns:
return (BotChatSession) 会话对象
Raises:
Exception: 获取或创建会话时发生错误
"""
session_id = SessionUtils.calculate_session_id(platform, user_id=user_id, group_id=group_id)
session_id = SessionUtils.calculate_session_id(
platform,
user_id=user_id,
group_id=group_id,
account_id=account_id,
scope=scope,
)
if session := self.get_session_by_session_id(session_id):
session.update_active_time()
return session
@@ -131,7 +145,18 @@ class ChatManager:
raise ValueError("消息缺少平台信息")
user_id = message.message_info.user_info.user_id
group_id = message.message_info.group_info.group_id if message.message_info.group_info else None
session_id = SessionUtils.calculate_session_id(platform, user_id=user_id, group_id=group_id)
account_id = None
scope = None
additional_config = message.message_info.additional_config
if isinstance(additional_config, dict):
account_id, scope = RouteKeyFactory.extract_components(additional_config)
session_id = SessionUtils.calculate_session_id(
platform,
user_id=user_id,
group_id=group_id,
account_id=account_id,
scope=scope,
)
message.session_id = session_id # 确保消息的session_id正确设置
self.last_messages[session_id] = message
@@ -188,7 +213,12 @@ class ChatManager:
return None
def get_session_by_info(
self, platform: str, user_id: Optional[str] = None, group_id: Optional[str] = None
self,
platform: str,
user_id: Optional[str] = None,
group_id: Optional[str] = None,
account_id: Optional[str] = None,
scope: Optional[str] = None,
) -> Optional[BotChatSession]:
"""根据平台、用户ID和群ID获取对应的会话
@@ -196,10 +226,18 @@ class ChatManager:
platform: 平台
user_id: 用户ID
group_id: 群ID如果是群聊
account_id: 平台账号 ID
scope: 路由作用域
Returns:
return (Optional[BotChatSession]): 会话对象如果不存在则返回None
"""
session_id = SessionUtils.calculate_session_id(platform, user_id=user_id, group_id=group_id)
session_id = SessionUtils.calculate_session_id(
platform,
user_id=user_id,
group_id=group_id,
account_id=account_id,
scope=scope,
)
return self.get_session_by_session_id(session_id)
def get_session_by_session_id(self, session_id: str) -> Optional[BotChatSession]:

View File

@@ -1,31 +1,37 @@
from rich.traceback import install
from typing import Optional
from typing import Any, Optional, Tuple
import asyncio
import traceback
from rich.traceback import install
from src.common.message_server.api import get_global_api
from src.common.logger import get_logger
from src.common.database.database import get_db_session
from src.chat.message_receive.message import SessionMessage
from src.chat.utils.utils import calculate_typing_time, truncate_message
from src.common.data_models.message_component_data_model import ReplyComponent
from src.chat.utils.utils import truncate_message
from src.chat.utils.utils import calculate_typing_time
from src.common.database.database import get_db_session
from src.common.logger import get_logger
from src.common.message_server.api import get_global_api
from src.webui.routers.chat.serializers import serialize_message_sequence
install(extra_lines=3)
logger = get_logger("sender")
# WebUI 聊天室的消息广播器(延迟导入避免循环依赖)
_webui_chat_broadcaster = None
_webui_chat_broadcaster: Optional[Tuple[Any, Optional[str]]] = None
# 虚拟群 ID 前缀(与 chat_routes.py 保持一致)
VIRTUAL_GROUP_ID_PREFIX = "webui_virtual_group_"
# TODO: 重构完成后完成webui相关
def get_webui_chat_broadcaster():
"""获取 WebUI 聊天室广播器"""
def get_webui_chat_broadcaster() -> Tuple[Any, Optional[str]]:
"""获取 WebUI 聊天室广播器
Returns:
Tuple[Any, Optional[str]]: ``(chat_manager, platform_name)`` 二元组;
若 WebUI 相关模块不可用,则元素会退化为 ``None``。
"""
global _webui_chat_broadcaster
if _webui_chat_broadcaster is None:
try:
@@ -38,102 +44,35 @@ def get_webui_chat_broadcaster():
def is_webui_virtual_group(group_id: str) -> bool:
"""检查是否是 WebUI 虚拟群"""
return group_id and group_id.startswith(VIRTUAL_GROUP_ID_PREFIX)
def parse_message_segments(segment) -> list:
"""解析消息段,转换为 WebUI 可用的格式
参考 NapCat 适配器的消息解析逻辑
"""检查是否是 WebUI 虚拟群
Args:
segment: Seg 消息段对象
group_id: 待判断的群 ID。
Returns:
list: 消息段列表,每个元素为 {"type": "...", "data": ...}
bool: 若群 ID 属于 WebUI 虚拟群则返回 ``True``。
"""
result = []
if segment is None:
return result
if segment.type == "seglist":
# 处理消息段列表
if segment.data:
for seg in segment.data:
result.extend(parse_message_segments(seg))
elif segment.type == "text":
# 文本消息
if segment.data:
result.append({"type": "text", "data": segment.data})
elif segment.type == "image":
# 图片消息base64
if segment.data:
result.append({"type": "image", "data": f"data:image/png;base64,{segment.data}"})
elif segment.type == "emoji":
# 表情包消息base64
if segment.data:
result.append({"type": "emoji", "data": f"data:image/gif;base64,{segment.data}"})
elif segment.type == "imageurl":
# 图片链接消息
if segment.data:
result.append({"type": "image", "data": segment.data})
elif segment.type == "face":
# 原生表情
result.append({"type": "face", "data": segment.data})
elif segment.type == "voice":
# 语音消息base64
if segment.data:
result.append({"type": "voice", "data": f"data:audio/wav;base64,{segment.data}"})
elif segment.type == "voiceurl":
# 语音链接
if segment.data:
result.append({"type": "voice", "data": segment.data})
elif segment.type == "video":
# 视频消息base64
if segment.data:
result.append({"type": "video", "data": f"data:video/mp4;base64,{segment.data}"})
elif segment.type == "videourl":
# 视频链接
if segment.data:
result.append({"type": "video", "data": segment.data})
elif segment.type == "music":
# 音乐消息
result.append({"type": "music", "data": segment.data})
elif segment.type == "file":
# 文件消息
result.append({"type": "file", "data": segment.data})
elif segment.type == "reply":
# 回复消息
result.append({"type": "reply", "data": segment.data})
elif segment.type == "forward":
# 转发消息
forward_items = []
if segment.data:
for item in segment.data:
forward_items.append(
{
"content": parse_message_segments(item.get("message_segment", {}))
if isinstance(item, dict)
else []
}
)
result.append({"type": "forward", "data": forward_items})
else:
# 未知类型,尝试作为文本处理
if segment.data:
result.append({"type": "unknown", "original_type": segment.type, "data": str(segment.data)})
return result
return bool(group_id) and group_id.startswith(VIRTUAL_GROUP_ID_PREFIX)
async def _send_message(message: MessageSending, show_log=True) -> bool:
"""合并后的消息发送函数包含WS发送和日志记录"""
async def _send_message(message: SessionMessage, show_log: bool = True) -> bool:
"""执行统一的消息发送流程。
发送顺序为:
1. WebUI 特殊链路
2. 旧版 ``maim_message`` / API Server 链路
Args:
message: 待发送的内部会话消息。
show_log: 是否输出发送成功日志。
Returns:
bool: 是否最终发送成功。
"""
message_preview = truncate_message(message.processed_plain_text, max_length=200)
platform = message.platform
group_id = message.session.group_id
group_info = message.message_info.group_info
group_id = group_info.group_id if group_info is not None else ""
try:
# 检查是否是 WebUI 平台的消息,或者是 WebUI 虚拟群的消息
@@ -146,7 +85,7 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
from src.config.config import global_config
# 解析消息段,获取富文本内容
message_segments = parse_message_segments(message.message_segment)
message_segments = serialize_message_sequence(message.raw_message)
# 判断消息类型
# 如果只有一个文本段,使用简单的 text 类型
@@ -185,7 +124,15 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
return True
# Fallback 逻辑: 尝试通过 API Server 发送
async def send_with_new_api(legacy_exception=None):
async def send_with_new_api(legacy_exception: Optional[Exception] = None) -> bool:
"""通过 API Server 回退链路发送消息。
Args:
legacy_exception: 旧发送链已经抛出的异常;若回退也失败,则重新抛出。
Returns:
bool: 回退链路是否发送成功。
"""
try:
from src.config.config import global_config
@@ -286,10 +233,24 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
raise e # 重新抛出其他异常
class UniversalMessageSender:
"""管理消息的注册、即时处理、发送和存储,并跟踪思考状态。"""
async def send_prepared_message_to_platform(message: SessionMessage, show_log: bool = True) -> bool:
"""发送一条已完成预处理的消息到底层平台。
def __init__(self):
Args:
message: 已经完成回复组件注入、文本处理等预处理的消息对象。
show_log: 是否输出发送成功日志。
Returns:
bool: 发送成功时返回 ``True``。
"""
return await _send_message(message, show_log=show_log)
class UniversalMessageSender:
"""旧链与 WebUI 的底层发送器。"""
def __init__(self) -> None:
"""初始化统一消息发送器。"""
pass
async def send_message(
@@ -300,18 +261,19 @@ class UniversalMessageSender:
reply_message_id: Optional[str] = None,
storage_message: bool = True,
show_log: bool = True,
):
"""
处理、发送并存储一条消息。
) -> bool:
"""通过旧链或 WebUI 发送并存储一条消息。
参数:
message: MessageSession 对象,待发送的消息
Args:
message: 待发送的内部消息对象
typing: 是否模拟打字等待。
set_reply: 是否构建回复引用消息。
set_reply: 是否构建引用回复消息。
reply_message_id: 被引用消息的 ID。
storage_message: 是否在发送成功后写入数据库。
show_log: 是否输出发送日志。
用法:
- typing=True 时,发送前会有打字等待。
Returns:
bool: 发送成功时返回 ``True``。
"""
if not message.message_id:
logger.error("消息缺少 message_id无法发送")
@@ -364,7 +326,7 @@ class UniversalMessageSender:
)
await asyncio.sleep(typing_time)
sent_msg = await _send_message(message, show_log=show_log)
sent_msg = await send_prepared_message_to_platform(message, show_log=show_log)
if not sent_msg:
return False

View File

@@ -3,8 +3,8 @@ from typing import Dict, Optional, Tuple
from src.chat.message_receive.chat_manager import BotChatSession
from src.chat.message_receive.message import SessionMessage
from src.common.logger import get_logger
from src.core.component_registry import component_registry, ActionExecutor
from src.core.types import ActionInfo
from src.plugin_runtime.component_query import ActionExecutor, component_query_service
logger = get_logger("action_manager")
@@ -28,7 +28,7 @@ class ActionManager:
"""
动作管理器,用于管理各种类型的动作
使用核心组件注册表的 executor-based 模式。
使用插件运行时统一查询服务的 executor-based 模式。
"""
def __init__(self):
@@ -38,7 +38,7 @@ class ActionManager:
self._using_actions: Dict[str, ActionInfo] = {}
# 初始化时将默认动作加载到使用中的动作
self._using_actions = component_registry.get_default_actions()
self._using_actions = component_query_service.get_default_actions()
# === 执行Action方法 ===
@@ -72,17 +72,17 @@ class ActionManager:
Optional[ActionHandle]: 执行句柄,如果动作未注册则返回 None
"""
try:
executor = component_registry.get_action_executor(action_name)
executor = component_query_service.get_action_executor(action_name)
if not executor:
logger.warning(f"{log_prefix} 未找到Action组件: {action_name}")
return None
info = component_registry.get_action_info(action_name)
info = component_query_service.get_action_info(action_name)
if not info:
logger.warning(f"{log_prefix} 未找到Action组件信息: {action_name}")
return None
plugin_config = component_registry.get_plugin_config(info.plugin_name) or {}
plugin_config = component_query_service.get_plugin_config(info.plugin_name) or {}
handle = ActionHandle(
executor,
@@ -133,5 +133,5 @@ class ActionManager:
def restore_actions(self) -> None:
"""恢复到默认动作集"""
actions_to_restore = list(self._using_actions.keys())
self._using_actions = component_registry.get_default_actions()
self._using_actions = component_query_service.get_default_actions()
logger.debug(f"恢复动作集: 从 {actions_to_restore} 恢复到默认动作集 {list(self._using_actions.keys())}")

View File

@@ -1,33 +1,36 @@
from collections import OrderedDict
from datetime import datetime
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import contextlib
import json
import time
import traceback
import random
import re
import contextlib
from typing import Dict, Optional, Tuple, List, TYPE_CHECKING, Union
from collections import OrderedDict
from rich.traceback import install
from datetime import datetime
import time
import traceback
from json_repair import repair_json
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.logger import get_logger
from rich.traceback import install
from src.chat.logger.plan_reply_logger import PlanReplyLogger
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.chat.message_receive.message import SessionMessage
from src.chat.planner_actions.action_manager import ActionManager
from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self
from src.common.data_models.info_data_model import ActionPlannerInfo
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.core.types import ActionActivationType, ActionInfo, ComponentType
from src.llm_models.utils_model import LLMRequest
from src.person_info.person_info import Person
from src.plugin_runtime.component_query import component_query_service
from src.prompt.prompt_manager import prompt_manager
from src.services.message_service import (
build_readable_messages_with_id,
replace_user_references,
get_messages_before_time_in_chat,
replace_user_references,
translate_pid_to_description,
)
from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self
from src.chat.planner_actions.action_manager import ActionManager
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.chat.message_receive.message import SessionMessage
from src.core.types import ActionActivationType, ActionInfo, ComponentType
from src.core.component_registry import component_registry
from src.person_info.person_info import Person
if TYPE_CHECKING:
from src.common.data_models.info_data_model import TargetPersonInfo
@@ -634,7 +637,7 @@ class ActionPlanner:
current_available_actions_dict = self.action_manager.get_using_actions()
# 获取完整的动作信息
all_registered_actions: Dict[str, ActionInfo] = component_registry.get_components_by_type( # type: ignore
all_registered_actions: Dict[str, ActionInfo] = component_query_service.get_components_by_type( # type: ignore
ComponentType.ACTION
)
current_available_actions = {}

View File

@@ -1,6 +1,7 @@
import traceback
import time
import asyncio
import importlib
import random
import re
@@ -16,7 +17,6 @@ from maim_message import BaseMessageInfo, MessageBase, Seg, UserInfo as MaimUser
from src.common.data_models.mai_message_data_model import MaiMessage
from src.chat.message_receive.message import SessionMessage
from src.chat.message_receive.chat_manager import BotChatSession
from src.chat.message_receive.uni_message_sender import UniversalMessageSender
from src.chat.utils.timer_calculator import Timer # <--- Import Timer
from src.chat.utils.utils import get_bot_account, get_chat_type_and_target_info, is_bot_self
from src.prompt.prompt_manager import prompt_manager
@@ -26,7 +26,7 @@ from src.services.message_service import (
replace_user_references,
translate_pid_to_description,
)
from src.bw_learner.expression_selector import expression_selector
from src.learners.expression_selector import expression_selector
# from src.memory_system.memory_activator import MemoryActivator
from src.person_info.person_info import Person
@@ -35,8 +35,7 @@ from src.services import llm_service as llm_api
from src.chat.logger.plan_reply_logger import PlanReplyLogger
from src.memory_system.memory_retrieval import init_memory_retrieval_sys, build_memory_retrieval_prompt
from src.memory_system.retrieval_tools import get_tool_registry
from src.bw_learner.jargon_explainer_old import explain_jargon_in_context, retrieve_concepts_with_jargon
from src.learners.jargon_explainer_old import explain_jargon_in_context, retrieve_concepts_with_jargon
from src.chat.utils.common_utils import TempMethodsExpression
init_memory_retrieval_sys()
@@ -51,10 +50,15 @@ class DefaultReplyer:
chat_stream: BotChatSession,
request_type: str = "replyer",
):
"""初始化群聊回复器。
Args:
chat_stream: 当前绑定的聊天会话。
request_type: LLM 请求类型标识。
"""
self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type)
self.chat_stream = chat_stream
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.session_id)
self.heart_fc_sender = UniversalMessageSender()
from src.chat.tool_executor import ToolExecutor
@@ -1129,7 +1133,10 @@ class DefaultReplyer:
user_id=bot_user_id,
user_nickname=global_config.bot.nickname,
),
additional_config={},
additional_config={
"platform_io_target_group_id": self.chat_stream.group_id,
"platform_io_target_user_id": self.chat_stream.user_id,
},
),
message_segment=message_segment,
)
@@ -1164,14 +1171,29 @@ class DefaultReplyer:
async def get_prompt_info(self, message: str, sender: str, target: str):
related_info = ""
start_time = time.time()
search_knowledge_tool = get_tool_registry().get_tool("search_long_term_memory")
if search_knowledge_tool is None:
logger.debug("长期记忆检索工具未注册,跳过获取知识内容")
try:
knowledge_module = importlib.import_module("src.plugins.built_in.knowledge.lpmm_get_knowledge")
except ImportError:
logger.debug("LPMM知识库工具模块不存在跳过获取知识库内容")
return ""
logger.debug(f"获取长期记忆内容,元消息:{message[:30]}...,消息长度: {len(message)}")
search_knowledge_tool = getattr(knowledge_module, "SearchKnowledgeFromLPMMTool", None)
if search_knowledge_tool is None:
logger.debug("LPMM知识库工具未提供 SearchKnowledgeFromLPMMTool跳过获取知识库内容")
return ""
logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
# 从LPMM知识库获取知识
try:
template_prompt = prompt_manager.get_prompt("memory_get_knowledge")
# 检查LPMM知识库是否启用
if not global_config.lpmm_knowledge.enable:
logger.debug("LPMM知识库未启用跳过获取知识库内容")
return ""
if global_config.lpmm_knowledge.lpmm_mode == "agent":
return ""
template_prompt = prompt_manager.get_prompt("lpmm_get_knowledge")
template_prompt.add_context("bot_name", global_config.bot.nickname)
template_prompt.add_context("time_now", lambda _: time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
template_prompt.add_context("chat_history", message)
@@ -1187,31 +1209,24 @@ class DefaultReplyer:
# logger.info(f"工具调用提示词: {prompt}")
# logger.info(f"工具调用: {tool_calls}")
if not tool_calls:
logger.debug("模型认为不需要使用长期记忆")
if tool_calls:
result = await self.tool_executor.execute_tool_call(tool_calls[0])
end_time = time.time()
if not result or not result.get("content"):
logger.debug("从LPMM知识库获取知识失败返回空知识...")
return ""
found_knowledge_from_lpmm = result.get("content", "")
logger.info(
f"从LPMM知识库获取知识相关信息{found_knowledge_from_lpmm[:100]}...,信息长度: {len(found_knowledge_from_lpmm)}"
)
related_info += found_knowledge_from_lpmm
logger.debug(f"获取知识库内容耗时: {(end_time - start_time):.3f}")
logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}")
return f"你有以下这些**知识**\n{related_info}\n请你**记住上面的知识**,之后可能会用到。\n"
else:
logger.debug("模型认为不需要使用LPMM知识库")
return ""
related_chunks: List[str] = []
for tool_call in tool_calls:
if tool_call.func_name != "search_long_term_memory":
continue
tool_args = dict(tool_call.args or {})
tool_args.setdefault("chat_id", self.chat_stream.session_id)
result_text = await search_knowledge_tool.execute(**tool_args)
if result_text and "未找到" not in result_text:
related_chunks.append(result_text)
if not related_chunks:
logger.debug("长期记忆未返回有效信息")
return ""
related_info = "\n".join(related_chunks)
end_time = time.time()
logger.info(f"从长期记忆获取知识,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}")
logger.debug(f"获取知识库内容耗时: {(end_time - start_time):.3f}")
logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}")
return f"你有以下这些**知识**\n{related_info}\n请你**记住上面的知识**,之后可能会用到。\n"
except Exception as e:
logger.error(f"获取知识库内容时发生异常: {str(e)}")
return ""

View File

@@ -16,7 +16,6 @@ from maim_message import BaseMessageInfo, MessageBase, Seg, UserInfo as MaimUser
from src.common.data_models.mai_message_data_model import MaiMessage
from src.chat.message_receive.message import SessionMessage
from src.chat.message_receive.chat_manager import BotChatSession
from src.chat.message_receive.uni_message_sender import UniversalMessageSender
from src.chat.utils.timer_calculator import Timer
from src.chat.utils.utils import get_bot_account, get_chat_type_and_target_info, is_bot_self
from src.prompt.prompt_manager import prompt_manager
@@ -27,13 +26,13 @@ from src.services.message_service import (
replace_user_references,
translate_pid_to_description,
)
from src.bw_learner.expression_selector import expression_selector
from src.learners.expression_selector import expression_selector
# from src.memory_system.memory_activator import MemoryActivator
from src.person_info.person_info import Person, is_person_known
from src.core.types import ActionInfo, EventType
from src.memory_system.memory_retrieval import init_memory_retrieval_sys, build_memory_retrieval_prompt
from src.bw_learner.jargon_explainer_old import explain_jargon_in_context
from src.learners.jargon_explainer_old import explain_jargon_in_context
init_memory_retrieval_sys()
@@ -47,10 +46,15 @@ class PrivateReplyer:
chat_stream: BotChatSession,
request_type: str = "replyer",
):
"""初始化私聊回复器。
Args:
chat_stream: 当前绑定的聊天会话。
request_type: LLM 请求类型标识。
"""
self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type)
self.chat_stream = chat_stream
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.session_id)
self.heart_fc_sender = UniversalMessageSender()
# self.memory_activator = MemoryActivator()
from src.chat.tool_executor import ToolExecutor
@@ -970,7 +974,9 @@ class PrivateReplyer:
user_nickname=global_config.bot.nickname,
),
group_info=None,
additional_config={},
additional_config={
"platform_io_target_user_id": self.chat_stream.user_id,
},
),
message_segment=message_segment,
)

View File

@@ -1,22 +1,20 @@
"""
工具执行器
"""工具执行器。
独立的工具执行组件,可以直接输入聊天消息内容,
自动判断并执行相应的工具,返回结构化的工具执行结果。
从 src.plugin_system.core.tool_use 迁移,使用新的核心组件注册表。
"""
from typing import Any, Dict, List, Optional, Tuple
import hashlib
import time
from typing import Any, Dict, List, Optional, Tuple
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.core.announcement_manager import global_announcement_manager
from src.core.component_registry import component_registry
from src.llm_models.payload_content import ToolCall
from src.llm_models.utils_model import LLMRequest
from src.plugin_runtime.component_query import component_query_service
from src.prompt.prompt_manager import prompt_manager
logger = get_logger("tool_use")
@@ -89,7 +87,7 @@ class ToolExecutor:
def _get_tool_definitions(self) -> List[Dict[str, Any]]:
"""获取 LLM 可用的工具定义列表"""
all_tools = component_registry.get_llm_available_tools()
all_tools = component_query_service.get_llm_available_tools()
user_disabled_tools = global_announcement_manager.get_disabled_chat_tools(self.chat_id)
return [info.get_llm_definition() for name, info in all_tools.items() if name not in user_disabled_tools]
@@ -152,7 +150,7 @@ class ToolExecutor:
function_args = tool_call.args or {}
function_args["llm_called"] = True
executor = component_registry.get_tool_executor(function_name)
executor = component_query_service.get_tool_executor(function_name)
if not executor:
logger.warning(f"未知工具名称: {function_name}")
return None

View File

@@ -436,14 +436,14 @@ class StatisticOutputTask(AsyncTask):
@staticmethod
def _fetch_online_time_since(query_start_time: datetime) -> list[tuple[datetime, datetime]]:
with get_db_session() as session:
with get_db_session(auto_commit=False) as session:
statement = select(OnlineTime).where(col(OnlineTime.end_timestamp) >= query_start_time)
records = session.exec(statement).all()
return [(record.start_timestamp, record.end_timestamp) for record in records]
@staticmethod
def _fetch_model_usage_since(query_start_time: datetime) -> list[dict[str, object]]:
with get_db_session() as session:
with get_db_session(auto_commit=False) as session:
statement = select(ModelUsage).where(col(ModelUsage.timestamp) >= query_start_time)
records = session.exec(statement).all()
return [
@@ -664,7 +664,7 @@ class StatisticOutputTask(AsyncTask):
}
query_start_timestamp = collect_period[-1][1]
with get_db_session() as session:
with get_db_session(auto_commit=False) as session:
statement = select(Messages).where(col(Messages.timestamp) >= query_start_timestamp)
messages = session.exec(statement).all()
for message in messages:
@@ -713,7 +713,7 @@ class StatisticOutputTask(AsyncTask):
# 使用 ActionRecords 中的 reply 动作次数作为回复数基准
try:
action_query_start_timestamp = collect_period[-1][1]
with get_db_session() as session:
with get_db_session(auto_commit=False) as session:
statement = select(ActionRecord).where(col(ActionRecord.timestamp) >= action_query_start_timestamp)
actions = session.exec(statement).all()
for action in actions:
@@ -1750,7 +1750,7 @@ class StatisticOutputTask(AsyncTask):
# 查询消息记录
query_start_timestamp = start_time.timestamp()
with get_db_session() as session:
with get_db_session(auto_commit=False) as session:
statement = select(Messages).where(col(Messages.timestamp) >= start_time)
messages = session.exec(statement).all()
for message in messages:
@@ -2131,7 +2131,7 @@ class StatisticOutputTask(AsyncTask):
# 查询消息记录
query_start_timestamp = start_time.timestamp()
with get_db_session() as session:
with get_db_session(auto_commit=False) as session:
statement = select(Messages).where(col(Messages.timestamp) >= start_time)
messages = session.exec(statement).all()
for message in messages:

View File

@@ -1,6 +1,6 @@
from dataclasses import dataclass
from dataclasses import asdict, dataclass
from datetime import datetime
from typing import Optional, List
from typing import Any, List, Mapping, Optional, Sequence
import json
@@ -15,6 +15,76 @@ class GroupCardnameInfo:
group_cardname: str
def _normalize_group_cardname_item(raw_item: Mapping[str, Any]) -> Optional[GroupCardnameInfo]:
"""将单条群名片数据规范化为统一结构。
Args:
raw_item: 原始群名片字典,必须包含 `group_id` 和 `group_cardname`。
Returns:
Optional[GroupCardnameInfo]: 规范化后的群名片信息;若数据不完整则返回 ``None``。
"""
group_id = str(raw_item.get("group_id") or "").strip()
group_cardname = str(raw_item.get("group_cardname") or "").strip()
if not group_id or not group_cardname:
return None
return GroupCardnameInfo(group_id=group_id, group_cardname=group_cardname)
def parse_group_cardname_json(group_cardname_json: Optional[str]) -> Optional[List[GroupCardnameInfo]]:
"""解析数据库中的群名片 JSON 字段。
Args:
group_cardname_json: 数据库存储的群名片 JSON 字符串。
Returns:
Optional[List[GroupCardnameInfo]]: 解析并规范化后的群名片列表;若字段为空或无有效项则返回 ``None``。
Raises:
json.JSONDecodeError: 当 JSON 文本格式非法时抛出。
TypeError: 当输入值类型不符合 `json.loads()` 要求时抛出。
"""
if not group_cardname_json:
return None
raw_items = json.loads(group_cardname_json)
if not isinstance(raw_items, list):
return None
normalized_items: List[GroupCardnameInfo] = []
for raw_item in raw_items:
if not isinstance(raw_item, Mapping):
continue
if normalized_item := _normalize_group_cardname_item(raw_item):
normalized_items.append(normalized_item)
return normalized_items or None
def dump_group_cardname_records(
group_cardname_records: Optional[Sequence[GroupCardnameInfo | Mapping[str, Any]]],
) -> str:
"""将群名片列表序列化为数据库使用的标准 JSON 字符串。
Args:
group_cardname_records: 待序列化的群名片列表,支持 `GroupCardnameInfo`
对象和包含 `group_id` / `group_cardname` 的字典。
Returns:
str: 统一使用 `group_cardname` 键名的 JSON 字符串。
"""
normalized_items: List[GroupCardnameInfo] = []
for raw_item in group_cardname_records or []:
if isinstance(raw_item, GroupCardnameInfo):
normalized_items.append(raw_item)
continue
if isinstance(raw_item, Mapping):
if normalized_item := _normalize_group_cardname_item(raw_item):
normalized_items.append(normalized_item)
return json.dumps([asdict(item) for item in normalized_items], ensure_ascii=False)
class MaiPersonInfo(BaseDatabaseDataModel[PersonInfo]):
def __init__(
self,
@@ -58,9 +128,16 @@ class MaiPersonInfo(BaseDatabaseDataModel[PersonInfo]):
"""最后一次被认识的时间"""
@classmethod
def from_db_instance(cls, db_record: "PersonInfo"):
nickname_json = json.loads(db_record.group_cardname) if db_record.group_cardname else None
group_cardname_list = [GroupCardnameInfo(**item) for item in nickname_json] if nickname_json else None
def from_db_instance(cls, db_record: "PersonInfo") -> "MaiPersonInfo":
"""从数据库记录构造人物信息数据模型。
Args:
db_record: 数据库中的人物信息记录。
Returns:
MaiPersonInfo: 转换后的数据模型对象。
"""
group_cardname_list = parse_group_cardname_json(db_record.group_cardname)
memory_points = json.loads(db_record.memory_points) if db_record.memory_points else None
return cls(
is_known=db_record.is_known,
@@ -78,9 +155,12 @@ class MaiPersonInfo(BaseDatabaseDataModel[PersonInfo]):
)
def to_db_instance(self) -> "PersonInfo":
group_cardname = (
json.dumps([gc.__dict__ for gc in self.group_cardname_list]) if self.group_cardname_list else None
)
"""将当前数据模型转换为数据库记录对象。
Returns:
PersonInfo: 可直接写入数据库的模型实例。
"""
group_cardname = dump_group_cardname_records(self.group_cardname_list)
return PersonInfo(
is_known=self.is_known,
person_id=self.person_id,

View File

@@ -1,8 +1,9 @@
from typing import Optional
from sqlalchemy import Column, Float, Enum as SQLEnum, DateTime
from sqlmodel import SQLModel, Field, LargeBinary
from enum import Enum
from datetime import datetime
from enum import Enum
from typing import Optional
from sqlalchemy import Column, DateTime, Enum as SQLEnum, Float
from sqlmodel import Field, LargeBinary, SQLModel
class ModelUser(str, Enum):
@@ -172,8 +173,8 @@ class Expression(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True) # 自增主键
situation: str = Field(index=True, max_length=255, primary_key=True) # 情景
style: str = Field(index=True, max_length=255, primary_key=True) # 风格
situation: str = Field(index=True, max_length=255) # 情景
style: str = Field(index=True, max_length=255) # 风格
# context: str # 上下文
# up_content: str
@@ -200,7 +201,7 @@ class Jargon(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True) # 自增主键
content: str = Field(index=True, max_length=255, primary_key=True) # 黑话内容
content: str = Field(index=True, max_length=255) # 黑话内容
raw_content: Optional[str] = Field(default=None, nullable=True) # 原始内容未处理的黑话内容为List[str]
meaning: str # 黑话含义

View File

@@ -1,9 +1,8 @@
# 定义模块颜色映射
from typing import Optional, Tuple, Dict
import itertools
import os
import sys
from typing import Dict, Optional, Tuple
MODULE_COLORS: Dict[str, Tuple[str, Optional[str], bool]] = {
@@ -54,15 +53,19 @@ MODULE_COLORS: Dict[str, Tuple[str, Optional[str], bool]] = {
"component_registry": ("#ffaf00", None, False),
"plugin_runtime.integration": ("#d75f00", None, False),
"plugin_runtime.host.supervisor": ("#ff5f00", None, False),
"plugin_runtime.host.runner_manager": ("#ff5f00", None, False),
"plugin_runtime.host.rpc_server": ("#ff8700", None, False),
"plugin_runtime.host.component_registry": ("#ffaf00", None, False),
"plugin_runtime.host.capability_service": ("#ffd700", None, False),
"plugin_runtime.host.event_dispatcher": ("#87d700", None, False),
"plugin_runtime.host.workflow_executor": ("#5fd7af", None, False),
"plugin_runtime.host.hook_dispatcher": ("#5fd7af", None, False),
"plugin_runtime.host.message_gateway": ("#5fd7d7", None, False),
"plugin_runtime.host.message_utils": ("#5faf87", None, False),
"plugin_runtime.runner.main": ("#d787ff", None, False),
"plugin_runtime.runner.rpc_client": ("#8787ff", None, False),
"plugin_runtime.runner.manifest_validator": ("#5fafff", None, False),
"plugin_runtime.runner.plugin_loader": ("#00afaf", None, False),
"plugin.maibot-team.napcat-adapter": ("#00af87", None, False),
"webui": ("#5f87ff", None, False),
"webui.app": ("#5f87d7", None, False),
"webui.api": ("#5fafff", None, False),
@@ -157,15 +160,20 @@ MODULE_ALIASES = {
"chat_history_summarizer": "聊天概括器",
"plugin_runtime.integration": "IPC插件系统",
"plugin_runtime.host.supervisor": "插件监督器",
"plugin_runtime.host.runner_manager": "插件监督器",
"plugin_runtime.host.rpc_server": "插件RPC服务",
"plugin_runtime.host.component_registry": "插件组件注册",
"plugin_runtime.host.capability_service": "插件能力服务",
"plugin_runtime.host.event_dispatcher": "插件事件分发",
"plugin_runtime.host.hook_dispatcher": "插件Hook分发",
"plugin_runtime.host.message_gateway": "插件消息网关",
"plugin_runtime.host.message_utils": "插件消息工具",
"plugin_runtime.host.workflow_executor": "插件工作流",
"plugin_runtime.runner.main": "插件运行器",
"plugin_runtime.runner.rpc_client": "插件RPC客户端",
"plugin_runtime.runner.manifest_validator": "插件清单校验",
"plugin_runtime.runner.plugin_loader": "插件加载器",
"plugin.maibot-team.napcat-adapter": "NapCat内置适配器",
"webui": "WebUI",
"webui.app": "WebUI应用",
"webui.api": "WebUI接口",

View File

@@ -21,7 +21,7 @@ class Server:
self._server: Optional[UvicornServer] = None
self.set_address(host, port)
def register_router(self, router: APIRouter, prefix: str = ""):
def register_router(self, router: APIRouter, prefix: str = ""):
"""注册路由
APIRouter 用于对相关的路由端点进行分组和模块化管理:

View File

@@ -5,13 +5,22 @@ import hashlib
class SessionUtils:
@staticmethod
def calculate_session_id(platform: str, *, user_id: Optional[str] = None, group_id: Optional[str] = None) -> str:
def calculate_session_id(
platform: str,
*,
user_id: Optional[str] = None,
group_id: Optional[str] = None,
account_id: Optional[str] = None,
scope: Optional[str] = None,
) -> str:
"""计算session_id
Args:
platform: 平台名称
user_id: 用户ID如果是私聊
group_id: 群ID如果是群聊
account_id: 当前平台账号 ID可选
scope: 当前路由作用域,可选
Returns:
str: 计算得到的会话ID
Raises:
@@ -19,8 +28,15 @@ class SessionUtils:
"""
if not user_id and not group_id:
raise ValueError("UserID 或 GroupID 必须提供其一")
route_components = []
if account_id:
route_components.append(f"account:{account_id}")
if scope:
route_components.append(f"scope:{scope}")
if group_id:
components = [platform, group_id]
components = [platform, *route_components, group_id]
else:
components = [platform, user_id, "private"]
components = [platform, *route_components, user_id, "private"]
return hashlib.md5("_".join(components).encode()).hexdigest()

View File

@@ -4,6 +4,7 @@ from typing import Any, Callable, Mapping, Sequence, TypeVar
import asyncio
import copy
import inspect
import sys
import tomlkit
@@ -61,6 +62,7 @@ MODEL_CONFIG_VERSION: str = "1.12.0"
logger = get_logger("config")
T = TypeVar("T", bound="ConfigBase")
ConfigReloadCallback = Callable[[Sequence[str]], object] | Callable[[], object]
class Config(ConfigBase):
@@ -190,7 +192,7 @@ class ConfigManager:
self.global_config: Config | None = None
self.model_config: ModelConfig | None = None
self._reload_lock: asyncio.Lock = asyncio.Lock()
self._reload_callbacks: list[Callable[[], object]] = []
self._reload_callbacks: list[ConfigReloadCallback] = []
self._file_watcher: FileWatcher | None = None
self._file_watcher_subscription_id: str | None = None
self._hot_reload_min_interval_s: float = 1.0
@@ -226,16 +228,125 @@ class ConfigManager:
raise RuntimeError(t("config.model_not_initialized"))
return self.model_config
def register_reload_callback(self, callback: Callable[[], object]) -> None:
def register_reload_callback(self, callback: ConfigReloadCallback) -> None:
"""注册配置热重载回调。
Args:
callback: 配置热重载回调。允许无参回调,也允许接收
``Sequence[str]`` 类型的变更范围列表。
"""
self._reload_callbacks.append(callback)
def unregister_reload_callback(self, callback: Callable[[], object]) -> None:
def unregister_reload_callback(self, callback: ConfigReloadCallback) -> None:
"""注销配置热重载回调。
Args:
callback: 先前注册过的回调对象。
"""
try:
self._reload_callbacks.remove(callback)
except ValueError:
return
async def reload_config(self) -> bool:
@staticmethod
def _normalize_changed_scopes(changed_scopes: Sequence[str] | None) -> tuple[str, ...]:
"""规范化配置变更范围列表。
Args:
changed_scopes: 原始配置变更范围。
Returns:
tuple[str, ...]: 去重后的配置变更范围元组。
"""
if not changed_scopes:
return ("bot", "model")
normalized_scopes: list[str] = []
for scope in changed_scopes:
normalized_scope = str(scope or "").strip().lower()
if normalized_scope not in {"bot", "model"}:
continue
if normalized_scope not in normalized_scopes:
normalized_scopes.append(normalized_scope)
return tuple(normalized_scopes)
@staticmethod
def _resolve_changed_scopes(changes: Sequence[FileChange]) -> tuple[str, ...]:
"""根据文件变更列表推断配置变更范围。
Args:
changes: 文件监听器返回的变更列表。
Returns:
tuple[str, ...]: 命中的配置变更范围元组。
"""
changed_scopes: list[str] = []
for change in changes:
file_name = change.path.name
if file_name == "bot_config.toml" and "bot" not in changed_scopes:
changed_scopes.append("bot")
if file_name == "model_config.toml" and "model" not in changed_scopes:
changed_scopes.append("model")
return tuple(changed_scopes)
@staticmethod
def _callback_accepts_scopes(callback: ConfigReloadCallback) -> bool:
"""判断回调是否接收配置变更范围参数。
Args:
callback: 待检测的回调对象。
Returns:
bool: 若回调可接收一个位置参数或可变位置参数,则返回 ``True``。
"""
try:
parameters = inspect.signature(callback).parameters.values()
except (TypeError, ValueError):
return False
positional_params = {
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
}
for parameter in parameters:
if parameter.kind == inspect.Parameter.VAR_POSITIONAL:
return True
if parameter.kind in positional_params:
return True
return False
async def _invoke_reload_callback(
self,
callback: ConfigReloadCallback,
changed_scopes: Sequence[str],
) -> None:
"""执行单个配置热重载回调。
Args:
callback: 要执行的回调对象。
changed_scopes: 本次热重载命中的配置范围。
"""
result = callback(changed_scopes) if self._callback_accepts_scopes(callback) else callback()
if asyncio.iscoroutine(result):
await result
async def reload_config(self, changed_scopes: Sequence[str] | None = None) -> bool:
"""重新加载主配置和模型配置。
Args:
changed_scopes: 本次触发热重载的配置范围。
Returns:
bool: 是否重载成功。
"""
normalized_scopes = self._normalize_changed_scopes(changed_scopes)
async with self._reload_lock:
try:
global_config_new, global_updated = load_config_from_file(
@@ -265,9 +376,7 @@ class ConfigManager:
for callback in list(self._reload_callbacks):
try:
result = callback()
if asyncio.iscoroutine(result):
await result
await self._invoke_reload_callback(callback, normalized_scopes)
except Exception as exc:
logger.warning(t("config.reload_callback_failed", error=exc))
return True
@@ -312,6 +421,12 @@ class ConfigManager:
self._file_watcher = None
async def _handle_file_changes(self, changes: Sequence[FileChange]) -> None:
"""处理主配置与模型配置文件变更。
Args:
changes: 当前批次收集到的文件变更列表。
"""
if not changes:
return
now_monotonic = asyncio.get_running_loop().time()
@@ -321,7 +436,11 @@ class ConfigManager:
self._last_hot_reload_monotonic = now_monotonic
logger.info(t("config.file_change_detected"))
try:
await asyncio.wait_for(self.reload_config(), timeout=self._hot_reload_timeout_s)
changed_scopes = self._resolve_changed_scopes(changes)
await asyncio.wait_for(
self.reload_config(changed_scopes=changed_scopes),
timeout=self._hot_reload_timeout_s,
)
except asyncio.TimeoutError:
logger.error(t("config.reload_timeout", timeout_seconds=self._hot_reload_timeout_s))

View File

@@ -1633,24 +1633,6 @@ class PluginRuntimeConfig(ConfigBase):
)
"""启用插件系统"""
builtin_plugin_dir: str = Field(
default="src/plugins/built_in",
json_schema_extra={
"x-widget": "input",
"x-icon": "folder",
},
)
"""内置插件目录(相对于项目根目录)"""
thirdparty_plugin_dir: str = Field(
default="plugins",
json_schema_extra={
"x-widget": "input",
"x-icon": "folder-open",
},
)
"""第三方插件目录(相对于项目根目录)"""
health_check_interval_sec: float = Field(
default=30.0,
json_schema_extra={
@@ -1678,14 +1660,14 @@ class PluginRuntimeConfig(ConfigBase):
)
"""等待 Runner 子进程启动并注册的超时时间(秒)"""
workflow_blocking_timeout_sec: float = Field(
default=120.0,
hook_blocking_timeout_sec: float = Field(
default=30,
json_schema_extra={
"x-widget": "number",
"x-icon": "timer",
},
)
"""Workflow 阻塞步骤的全局超时上限(秒)"""
"""Hook 阻塞步骤的全局超时上限(秒)"""
ipc_socket_path: str = Field(
default="",
@@ -1694,4 +1676,7 @@ class PluginRuntimeConfig(ConfigBase):
"x-icon": "link",
},
)
"""_wrap_\n 自定义 IPC Socket 路径(仅 Linux/macOS 生效)\n 留空则自动生成临时路径"""
"""
自定义 IPC Socket 路径(仅 Linux/macOS 生效)
留空则自动生成临时路径
"""

View File

@@ -1,239 +0,0 @@
"""
核心组件注册表
面向最终架构的组件管理:
- Action注册 ActionInfo + 执行器(本地 callable 或 IPC 路由)
- Command注册正则模式 + 执行器
- Tool注册工具定义 + 执行器
不依赖任何插件基类,组件执行器是纯 async callable。
"""
import re
from typing import Any, Awaitable, Callable, Dict, Optional, Pattern, Tuple
from src.common.logger import get_logger
from src.core.types import (
ActionInfo,
CommandInfo,
ComponentInfo,
ComponentType,
ToolInfo,
)
logger = get_logger("component_registry")
# 执行器类型
ActionExecutor = Callable[..., Awaitable[Any]]
CommandExecutor = Callable[..., Awaitable[Tuple[bool, Optional[str], bool]]]
ToolExecutor = Callable[..., Awaitable[Any]]
class ComponentRegistry:
"""核心组件注册表
管理 action、command、tool 三类组件。
每个组件由「元信息 + 执行器」构成,执行器是 async callable
不需要继承任何基类。
"""
def __init__(self):
# Action 注册
self._actions: Dict[str, ActionInfo] = {}
self._action_executors: Dict[str, ActionExecutor] = {}
self._default_actions: Dict[str, ActionInfo] = {}
# Command 注册
self._commands: Dict[str, CommandInfo] = {}
self._command_executors: Dict[str, CommandExecutor] = {}
self._command_patterns: Dict[Pattern, str] = {}
# Tool 注册
self._tools: Dict[str, ToolInfo] = {}
self._tool_executors: Dict[str, ToolExecutor] = {}
self._llm_available_tools: Dict[str, ToolInfo] = {}
# 插件配置plugin_name -> config dict
self._plugin_configs: Dict[str, dict] = {}
logger.info("核心组件注册表初始化完成")
# ========== Action ==========
def register_action(
self,
info: ActionInfo,
executor: ActionExecutor,
) -> bool:
"""注册 action
Args:
info: action 元信息
executor: 执行器async callable
"""
name = info.name
if name in self._actions:
logger.warning(f"Action {name} 已存在,跳过注册")
return False
self._actions[name] = info
self._action_executors[name] = executor
if info.enabled:
self._default_actions[name] = info
logger.debug(f"注册 Action: {name}")
return True
def get_action_info(self, name: str) -> Optional[ActionInfo]:
return self._actions.get(name)
def get_action_executor(self, name: str) -> Optional[ActionExecutor]:
return self._action_executors.get(name)
def get_default_actions(self) -> Dict[str, ActionInfo]:
return self._default_actions.copy()
def get_all_actions(self) -> Dict[str, ActionInfo]:
return self._actions.copy()
def remove_action(self, name: str) -> bool:
if name not in self._actions:
return False
del self._actions[name]
self._action_executors.pop(name, None)
self._default_actions.pop(name, None)
logger.debug(f"移除 Action: {name}")
return True
# ========== Command ==========
def register_command(
self,
info: CommandInfo,
executor: CommandExecutor,
) -> bool:
"""注册 command"""
name = info.name
if name in self._commands:
logger.warning(f"Command {name} 已存在,跳过注册")
return False
self._commands[name] = info
self._command_executors[name] = executor
if info.enabled and info.command_pattern:
pattern = re.compile(info.command_pattern, re.IGNORECASE | re.DOTALL)
self._command_patterns[pattern] = name
logger.debug(f"注册 Command: {name}")
return True
def find_command_by_text(self, text: str) -> Optional[Tuple[CommandExecutor, dict, CommandInfo]]:
"""根据文本查找匹配的命令
Returns:
(executor, matched_groups, command_info) 或 None
"""
candidates = [p for p in self._command_patterns if p.match(text)]
if not candidates:
return None
if len(candidates) > 1:
logger.warning(f"文本 '{text[:50]}' 匹配到多个命令模式,使用第一个")
pattern = candidates[0]
name = self._command_patterns[pattern]
return (
self._command_executors[name],
pattern.match(text).groupdict(), # type: ignore
self._commands[name],
)
def remove_command(self, name: str) -> bool:
if name not in self._commands:
return False
del self._commands[name]
self._command_executors.pop(name, None)
self._command_patterns = {k: v for k, v in self._command_patterns.items() if v != name}
logger.debug(f"移除 Command: {name}")
return True
# ========== Tool ==========
def register_tool(
self,
info: ToolInfo,
executor: ToolExecutor,
) -> bool:
"""注册 tool"""
name = info.name
if name in self._tools:
logger.warning(f"Tool {name} 已存在,跳过注册")
return False
self._tools[name] = info
self._tool_executors[name] = executor
if info.enabled:
self._llm_available_tools[name] = info
logger.debug(f"注册 Tool: {name}")
return True
def get_tool_info(self, name: str) -> Optional[ToolInfo]:
return self._tools.get(name)
def get_tool_executor(self, name: str) -> Optional[ToolExecutor]:
return self._tool_executors.get(name)
def get_llm_available_tools(self) -> Dict[str, ToolInfo]:
return self._llm_available_tools.copy()
def get_all_tools(self) -> Dict[str, ToolInfo]:
return self._tools.copy()
def remove_tool(self, name: str) -> bool:
if name not in self._tools:
return False
del self._tools[name]
self._tool_executors.pop(name, None)
self._llm_available_tools.pop(name, None)
logger.debug(f"移除 Tool: {name}")
return True
# ========== 通用查询 ==========
def get_component_info(self, name: str, component_type: ComponentType) -> Optional[ComponentInfo]:
"""获取组件元信息"""
match component_type:
case ComponentType.ACTION:
return self._actions.get(name)
case ComponentType.COMMAND:
return self._commands.get(name)
case ComponentType.TOOL:
return self._tools.get(name)
case _:
return None
def get_components_by_type(self, component_type: ComponentType) -> Dict[str, ComponentInfo]:
"""获取某类型的所有组件"""
match component_type:
case ComponentType.ACTION:
return dict(self._actions)
case ComponentType.COMMAND:
return dict(self._commands)
case ComponentType.TOOL:
return dict(self._tools)
case _:
return {}
# ========== 插件配置 ==========
def set_plugin_config(self, plugin_name: str, config: dict) -> None:
self._plugin_configs[plugin_name] = config
def get_plugin_config(self, plugin_name: str) -> Optional[dict]:
return self._plugin_configs.get(plugin_name)
# 全局单例
component_registry = ComponentRegistry()

View File

@@ -3,19 +3,19 @@
功能
1. 定期随机选取指定数量的表达方式
2. 使用LLM进行评估
2. 使用 LLM 进行评估
3. 通过评估的rejected=0, checked=1
4. 未通过评估的rejected=1, checked=1
"""
from typing import List
import asyncio
import json
import random
from typing import List
from sqlmodel import select
from src.bw_learner.expression_review_store import get_review_state, set_review_state
from src.learners.expression_review_store import get_review_state, set_review_state
from src.common.database.database import get_db_session
from src.common.database.database_model import Expression
from src.common.logger import get_logger
@@ -146,7 +146,8 @@ class ExpressionAutoCheckTask(AsyncTask):
选中的表达方式列表
"""
try:
with get_db_session() as session:
# 这里只做查询,避免退出上下文时自动提交导致 ORM 实例过期。
with get_db_session(auto_commit=False) as session:
statement = select(Expression)
all_expressions = session.exec(statement).all()

View File

@@ -329,7 +329,13 @@ class ExpressionLearner:
return filtered_expressions
# ====== DB 操作相关 ======
async def _upsert_expression_to_db(self, situation: str, style: str):
async def _upsert_expression_to_db(self, situation: str, style: str) -> None:
"""将表达方式写入数据库,存在时更新,不存在时新增。
Args:
situation: 表达方式对应的使用情景
style: 表达方式风格
"""
expr, similarity = self._find_similar_expression(situation) or (None, 0)
if expr:
# 根据相似度决定是否使用 LLM 总结
@@ -340,7 +346,13 @@ class ExpressionLearner:
# 没有找到匹配的记录,创建新记录
self._create_expression(situation, style)
def _create_expression(self, situation: str, style: str):
def _create_expression(self, situation: str, style: str) -> None:
"""创建新的表达方式记录。
Args:
situation: 表达方式对应的使用情景
style: 表达方式风格
"""
content_list = [situation]
try:
with get_db_session() as db:
@@ -353,6 +365,7 @@ class ExpressionLearner:
last_active_time=datetime.now(),
)
db.add(new_expr)
db.flush()
except Exception as e:
logger.error(f"创建表达方式失败: {e}")
@@ -448,25 +461,43 @@ class ExpressionLearner:
def _find_similar_expression(
self, situation: str, similarity_threshold: float = 0.75
) -> Optional[Tuple[MaiExpression, float]]:
"""在数据库中查找相似的表达方式"""
"""在数据库中查找相似的表达方式
Args:
situation: 当前待匹配的情景描述
similarity_threshold: 认定为相似表达方式的最低相似度阈值
Returns:
Optional[Tuple[MaiExpression, float]]: 若找到最相似的表达方式则返回
``(表达方式对象, 相似度)``否则返回 ``None``
"""
try:
with get_db_session() as session:
with get_db_session(auto_commit=False) as session:
statement = select(Expression).filter_by(session_id=self.session_id)
expressions = session.exec(statement).all()
best_match: Optional[Expression] = None
best_similarity = 0.0
best_match: Optional[MaiExpression] = None
best_similarity = 0.0
for db_expression in expressions:
expression = MaiExpression.from_db_instance(db_expression)
candidate_situations = [expression.situation, *expression.content]
for candidate_situation in candidate_situations:
normalized_candidate_situation = candidate_situation.strip()
if not normalized_candidate_situation:
continue
similarity = difflib.SequenceMatcher(
None,
situation,
normalized_candidate_situation,
).ratio()
if similarity > similarity_threshold and similarity > best_similarity:
best_similarity = similarity
best_match = expression
for expr in expressions:
content_list = json.loads(expr.content_list)
for situation in content_list:
similarity = difflib.SequenceMatcher(None, situation, expr.situation).ratio()
if similarity > similarity_threshold and similarity > best_similarity:
best_similarity = similarity
best_match = expr
if best_match:
logger.debug(f"找到相似表达方式情景 [ID: {best_match.id}],相似度: {best_similarity:.2f}")
return MaiExpression.from_db_instance(best_match), best_similarity
logger.debug(f"找到相似表达方式情景 [ID: {best_match.item_id}],相似度: {best_similarity:.2f}")
return best_match, best_similarity
except Exception as e:
logger.error(f"查找相似表达方式失败: {e}")

View File

@@ -9,7 +9,7 @@ from src.config.config import global_config, model_config
from src.common.logger import get_logger
from src.common.database.database_model import Expression
from src.prompt.prompt_manager import prompt_manager
from src.bw_learner.learner_utils_old import weighted_sample
from src.learners.learner_utils_old import weighted_sample
from src.chat.utils.common_utils import TempMethodsExpression
logger = get_logger("expression_selector")

View File

@@ -7,8 +7,8 @@ from src.common.database.database_model import Jargon
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config, global_config
from src.prompt.prompt_manager import prompt_manager
from src.bw_learner.jargon_explainer import search_jargon
from src.bw_learner.learner_utils_old import (
from src.learners.jargon_miner_old import search_jargon
from src.learners.learner_utils_old import (
is_bot_message,
contains_bot_self_name,
parse_chat_id_list,

View File

@@ -1,17 +1,18 @@
from collections import OrderedDict
from json_repair import repair_json
from sqlmodel import select
from typing import List, Optional, Dict, Callable, TypedDict, Set
from typing import Callable, Dict, List, Optional, Set, TypedDict
import asyncio
import json
import random
from src.common.logger import get_logger
from json_repair import repair_json
from sqlmodel import select
from src.common.data_models.jargon_data_model import MaiJargon
from src.common.database.database import get_db_session
from src.common.database.database_model import Jargon
from src.common.data_models.jargon_data_model import MaiJargon
from src.config.config import model_config, global_config
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from src.prompt.prompt_manager import prompt_manager
@@ -198,7 +199,7 @@ class JargonMiner:
async def process_extracted_entries(
self, entries: List[JargonEntry], person_name_filter: Optional[Callable[[str], bool]] = None
):
) -> None:
"""
处理已提取的黑话条目 expression_learner 路由过来的
@@ -229,7 +230,7 @@ class JargonMiner:
content = entry["content"]
raw_content_set = entry["raw_content"]
try:
with get_db_session() as session:
with get_db_session(auto_commit=False) as session:
jargon_items = session.exec(select(Jargon).filter_by(content=content)).all()
except Exception as e:
logger.error(f"查询黑话 '{content}' 失败: {e}")
@@ -273,11 +274,12 @@ class JargonMiner:
try:
with get_db_session() as session:
session.add(new_jargon)
session.flush()
saved += 1
self._add_to_cache(content)
except Exception as e:
logger.error(f"保存新黑话 '{content}' 失败: {e}")
continue
finally:
self._add_to_cache(content)
# 固定输出提取的jargon结果格式化为可读形式只要有提取结果就输出
if uniq_entries:
# 收集所有提取的jargon内容
@@ -304,7 +306,13 @@ class JargonMiner:
removed_content, _ = self.cache.popitem(last=False)
logger.debug(f"缓存已满,移除最旧的黑话: {removed_content}")
def _update_jargon(self, db_jargon: Jargon, raw_content_set: Set[str]):
def _update_jargon(self, db_jargon: Jargon, raw_content_set: Set[str]) -> None:
"""更新已有黑话记录并写回数据库。
Args:
db_jargon: 已命中的黑话 ORM 对象
raw_content_set: 本次新增的原始上下文集合
"""
db_jargon.count += 1
existing_raw_content: List[str] = []
if db_jargon.raw_content:
@@ -326,7 +334,17 @@ class JargonMiner:
try:
with get_db_session() as session:
session.add(db_jargon)
if db_jargon.id is None:
raise ValueError("黑话记录缺少 id无法更新数据库")
statement = select(Jargon).filter_by(id=db_jargon.id).limit(1)
if persisted_jargon := session.exec(statement).first():
persisted_jargon.count = db_jargon.count
persisted_jargon.raw_content = db_jargon.raw_content
persisted_jargon.session_id_dict = db_jargon.session_id_dict
persisted_jargon.is_global = db_jargon.is_global
session.add(persisted_jargon)
else:
logger.warning(f"黑话 ID {db_jargon.id} 在数据库中未找到,无法更新")
except Exception as e:
logger.error(f"更新黑话 '{db_jargon.content}' 失败: {e}")

View File

@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING
import asyncio
import time
from src.bw_learner.expression_auto_check_task import ExpressionAutoCheckTask
from src.learners.expression_auto_check_task import ExpressionAutoCheckTask
from src.chat.emoji_system.emoji_manager import emoji_manager
from src.chat.message_receive.bot import chat_bot
from src.chat.message_receive.chat_manager import chat_manager

View File

@@ -14,7 +14,7 @@ from src.common.database.database_model import ThinkingQuestion
from src.memory_system.retrieval_tools import get_tool_registry, init_all_tools
from src.llm_models.payload_content.message import MessageBuilder, RoleType, Message
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.bw_learner.jargon_explainer_old import retrieve_concepts_with_jargon
from src.learners.jargon_explainer_old import retrieve_concepts_with_jargon
logger = get_logger("memory_retrieval")

View File

@@ -4,7 +4,7 @@
"""
from src.common.logger import get_logger
from src.bw_learner.jargon_explainer_old import retrieve_concepts_with_jargon
from src.learners.jargon_explainer_old import retrieve_concepts_with_jargon
from .tool_registry import register_memory_retrieval_tool
logger = get_logger("memory_retrieval_tools")

View File

@@ -1,24 +1,24 @@
import hashlib
from datetime import datetime
from typing import Dict, Optional, Union
import asyncio
import hashlib
import json
import time
import random
import math
import random
import time
from json_repair import repair_json
from typing import Union, Optional, Dict, List
from datetime import datetime
from sqlalchemy import or_
from sqlmodel import col, select
from src.common.logger import get_logger
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.common.data_models.person_info_data_model import dump_group_cardname_records, parse_group_cardname_json
from src.common.database.database import get_db_session
from src.common.database.database_model import PersonInfo
from src.llm_models.utils_model import LLMRequest
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.services.memory_service import memory_service
from src.llm_models.utils_model import LLMRequest
logger = get_logger("person_info")
@@ -28,6 +28,32 @@ relation_selection_model = LLMRequest(
)
def _to_group_cardname_records(group_cardname_json: Optional[str]) -> list[dict[str, str]]:
"""将数据库中的群名片 JSON 转换为 `Person` 内部使用的结构。
Args:
group_cardname_json: 数据库存储的群名片 JSON 字符串。
Returns:
list[dict[str, str]]: 统一使用 `group_cardname` 键名的群名片列表。
Raises:
json.JSONDecodeError: 当 JSON 文本格式非法时抛出。
TypeError: 当输入值类型不符合 `json.loads()` 要求时抛出。
"""
group_cardname_list = parse_group_cardname_json(group_cardname_json)
if not group_cardname_list:
return []
return [
{
"group_id": group_cardname.group_id,
"group_cardname": group_cardname.group_cardname,
}
for group_cardname in group_cardname_list
]
def get_person_id(platform: str, user_id: Union[int, str]) -> str:
"""获取唯一id"""
if "-" in platform:
@@ -39,60 +65,16 @@ def get_person_id(platform: str, user_id: Union[int, str]) -> str:
def get_person_id_by_person_name(person_name: str) -> str:
"""根据用户名获取用户ID"""
clean_name = str(person_name or "").strip()
if not clean_name:
return ""
try:
with get_db_session() as session:
statement = (
select(PersonInfo)
.where(
or_(
col(PersonInfo.person_name) == clean_name,
col(PersonInfo.user_nickname) == clean_name,
)
)
.limit(1)
)
record = session.exec(statement).first()
if record and record.person_id:
return record.person_id
statement = (
select(PersonInfo)
.where(PersonInfo.group_cardname.contains(clean_name))
.limit(1)
)
statement = select(PersonInfo).where(col(PersonInfo.person_name) == person_name).limit(1)
record = session.exec(statement).first()
return record.person_id if record else ""
except Exception as e:
logger.error(f"根据用户名 {clean_name} 获取用户ID时出错: {e}")
logger.error(f"根据用户名 {person_name} 获取用户ID时出错: {e}")
return ""
def resolve_person_id_for_memory(
*,
person_name: str = "",
platform: str = "",
user_id: Optional[Union[int, str]] = None,
) -> str:
"""统一人物记忆链路中的 person_id 解析。
优先使用已知的人物名称/别名,其次退回到平台 + user_id 的稳定 ID。
"""
name_token = str(person_name or "").strip()
if name_token:
resolved = get_person_id_by_person_name(name_token)
if resolved:
return resolved
platform_token = str(platform or "").strip()
user_token = str(user_id or "").strip()
if platform_token and user_token:
return get_person_id(platform_token, user_token)
return ""
def is_person_known(
person_id: Optional[str] = None,
user_id: Optional[str] = None,
@@ -277,7 +259,7 @@ class Person:
person.know_since = time.time()
person.last_know = time.time()
person.memory_points = []
person.group_nick_name = [] # 初始化群昵称列表
person.group_cardname_list = [] # 初始化群名片列表
# 如果是群聊,添加群昵称
if group_id and group_nick_name:
@@ -315,7 +297,7 @@ class Person:
self.platform = platform
self.nickname = global_config.bot.nickname
self.person_name = global_config.bot.nickname
self.group_nick_name: list[dict[str, str]] = []
self.group_cardname_list: list[dict[str, str]] = []
return
self.user_id = ""
@@ -354,7 +336,7 @@ class Person:
self.know_since = None
self.last_know: Optional[float] = None
self.memory_points = []
self.group_nick_name: list[dict[str, str]] = [] # 群昵称列表,存储 {"group_id": str, "group_nick_name": str}
self.group_cardname_list: list[dict[str, str]] = [] # 群名片列表,存储 {"group_id": str, "group_cardname": str}
# 从数据库加载数据
self.load_from_database()
@@ -454,16 +436,16 @@ class Person:
return
# 检查是否已存在该群号的记录
for item in self.group_nick_name:
for item in self.group_cardname_list:
if item.get("group_id") == group_id:
# 更新现有记录
item["group_nick_name"] = group_nick_name
item["group_cardname"] = group_nick_name
self.sync_to_database()
logger.debug(f"更新用户 {self.person_id} 在群 {group_id} 的群昵称为 {group_nick_name}")
return
# 添加新记录
self.group_nick_name.append({"group_id": group_id, "group_nick_name": group_nick_name})
self.group_cardname_list.append({"group_id": group_id, "group_cardname": group_nick_name})
self.sync_to_database()
logger.debug(f"添加用户 {self.person_id} 在群 {group_id} 的群昵称 {group_nick_name}")
@@ -498,20 +480,15 @@ class Person:
else:
self.memory_points = []
# 处理group_nick_name字段JSON格式的列表
# 处理 group_cardname 字段JSON 格式的列表)
if record.group_cardname:
try:
loaded_group_nick_names = json.loads(record.group_cardname)
# 确保是列表格式
if isinstance(loaded_group_nick_names, list):
self.group_nick_name = loaded_group_nick_names
else:
self.group_nick_name = []
self.group_cardname_list = _to_group_cardname_records(record.group_cardname)
except (json.JSONDecodeError, TypeError):
logger.warning(f"解析用户 {self.person_id} 的group_cardname字段失败使用默认值")
self.group_nick_name = []
self.group_cardname_list = []
else:
self.group_nick_name = []
self.group_cardname_list = []
logger.debug(f"已从数据库加载用户 {self.person_id} 的信息")
else:
@@ -532,11 +509,7 @@ class Person:
if self.memory_points
else json.dumps([], ensure_ascii=False)
)
group_nickname_value = (
json.dumps(self.group_nick_name, ensure_ascii=False)
if self.group_nick_name
else json.dumps([], ensure_ascii=False)
)
group_cardname_value = dump_group_cardname_records(self.group_cardname_list)
first_known_time = datetime.fromtimestamp(self.know_since) if self.know_since else None
last_known_time = datetime.fromtimestamp(self.last_know) if self.last_know else None
@@ -556,7 +529,7 @@ class Person:
record.first_known_time = first_known_time
record.last_known_time = last_known_time
record.memory_points = memory_points_value
record.group_nickname = group_nickname_value
record.group_cardname = group_cardname_value
session.add(record)
logger.debug(f"已同步用户 {self.person_id} 的信息到数据库")
else:
@@ -572,7 +545,7 @@ class Person:
first_known_time=first_known_time,
last_known_time=last_known_time,
memory_points=memory_points_value,
group_nickname=group_nickname_value,
group_cardname=group_cardname_value,
)
session.add(record)
logger.debug(f"已创建用户 {self.person_id} 的信息到数据库")
@@ -583,79 +556,79 @@ class Person:
async def build_relationship(self, chat_content: str = "", info_type=""):
if not self.is_known:
return ""
# 构建points文本
nickname_str = ""
if self.person_name != self.nickname:
nickname_str = f"(ta在{self.platform}上的昵称是{self.nickname})"
async def _select_traits(query_text: str, traits: List[str], limit: int = 3) -> List[str]:
clean_traits = [trait.strip() for trait in traits if isinstance(trait, str) and trait.strip()]
if not clean_traits:
return []
if not query_text:
return clean_traits[:limit]
relation_info = ""
numbered_traits = "\n".join(f"{index}. {trait}" for index, trait in enumerate(clean_traits, start=1))
prompt = f"""当前关注内容:
{query_text}
points_text = ""
category_list = self.get_all_category()
候选人物信息:
{numbered_traits}
if chat_content:
prompt = f"""当前聊天内容:
{chat_content}
请从候选人物信息中选择与当前关注内容最相关的编号,并用<>包裹输出,不要输出其他内容。
例如:
<1><3>
如果都不相关,请输出<none>"""
分类列表:
{category_list}
**要求**:请你根据当前聊天内容,从以下分类中选择一个与聊天内容相关的分类,并用<>包裹输出,不要输出其他内容,不要输出引号或[],严格用<>包裹:
例如:
<分类1><分类2><分类3>......
如果没有相关的分类,请输出<none>"""
try:
response, _ = await relation_selection_model.generate_response_async(prompt)
selected_traits: List[str] = []
for raw_index in extract_categories_from_response(response):
if raw_index == "none":
return []
try:
trait_index = int(raw_index) - 1
except ValueError:
continue
if 0 <= trait_index < len(clean_traits):
trait = clean_traits[trait_index]
if trait not in selected_traits:
selected_traits.append(trait)
if selected_traits:
return selected_traits[:limit]
except Exception as e:
logger.debug(f"筛选人物画像信息失败,使用默认画像摘要: {e}")
response, _ = await relation_selection_model.generate_response_async(prompt)
# print(prompt)
# print(response)
category_list = extract_categories_from_response(response)
if "none" not in category_list:
for category in category_list:
random_memory = self.get_random_memory_by_category(category, 2)
if random_memory:
random_memory_str = "\n".join(
[get_memory_content_from_memory(memory) for memory in random_memory]
)
points_text = f"有关 {category} 的内容:{random_memory_str}"
break
elif info_type:
prompt = f"""你需要获取用户{self.person_name}的 **{info_type}** 信息。
return clean_traits[:limit]
profile = await memory_service.get_person_profile(self.person_id, limit=8)
relation_parts: List[str] = []
if profile.summary.strip():
relation_parts.append(profile.summary.strip())
query_text = str(chat_content or info_type or "").strip()
selected_traits = await _select_traits(query_text, profile.traits, limit=3)
if not selected_traits and not query_text:
selected_traits = [trait for trait in profile.traits if trait][:2]
for trait in selected_traits:
clean_trait = str(trait).strip()
if clean_trait and clean_trait not in relation_parts:
relation_parts.append(clean_trait)
for evidence in profile.evidence:
content = str(evidence.get("content", "") or "").strip()
if content and content not in relation_parts:
relation_parts.append(content)
if len(relation_parts) >= 4:
break
现有信息类别列表:
{category_list}
**要求**:请你根据**{info_type}**,从以下分类中选择一个与**{info_type}**相关的分类,并用<>包裹输出,不要输出其他内容,不要输出引号或[],严格用<>包裹:
例如:
<分类1><分类2><分类3>......
如果没有相关的分类,请输出<none>"""
response, _ = await relation_selection_model.generate_response_async(prompt)
# print(prompt)
# print(response)
category_list = extract_categories_from_response(response)
if "none" not in category_list:
for category in category_list:
random_memory = self.get_random_memory_by_category(category, 3)
if random_memory:
random_memory_str = "\n".join(
[get_memory_content_from_memory(memory) for memory in random_memory]
)
points_text = f"有关 {category} 的内容:{random_memory_str}"
break
else:
for category in category_list:
random_memory = self.get_random_memory_by_category(category, 1)[0]
if random_memory:
points_text = f"有关 {category} 的内容:{get_memory_content_from_memory(random_memory)}"
break
points_info = ""
if relation_parts:
points_info = f"你还记得有关{self.person_name}的内容:{''.join(relation_parts[:3])}"
if points_text:
points_info = f"你还记得有关{self.person_name}的内容:{points_text}"
if not (nickname_str or points_info):
return ""
return f"{self.person_name}:{nickname_str}{points_info}"
relation_info = f"{self.person_name}:{nickname_str}{points_info}"
return relation_info
class PersonInfoManager:
@@ -822,7 +795,7 @@ person_info_manager = PersonInfoManager()
async def store_person_memory_from_answer(person_name: str, memory_content: str, chat_id: str) -> None:
"""将人物事实写入统一长期记忆
"""将人物信息存入person_info的memory_points
Args:
person_name: 人物名称
@@ -830,11 +803,6 @@ async def store_person_memory_from_answer(person_name: str, memory_content: str,
chat_id: 聊天ID
"""
try:
content = str(memory_content or "").strip()
if not content:
logger.debug("人物记忆内容为空,跳过写入")
return
# 从 chat_id 获取 session
session = _chat_manager.get_session_by_session_id(chat_id)
if not session:
@@ -845,14 +813,16 @@ async def store_person_memory_from_answer(person_name: str, memory_content: str,
# 尝试从person_name查找person_id
# 首先尝试通过person_name查找
person_id = resolve_person_id_for_memory(
person_name=person_name,
platform=platform,
user_id=session.user_id,
)
person_id = get_person_id_by_person_name(person_name)
if not person_id:
logger.warning(f"无法确定person_id for person_name: {person_name}, chat_id: {chat_id}")
return
# 如果通过person_name找不到尝试从 session 获取 user_id
if platform and session.user_id:
user_id = session.user_id
person_id = get_person_id(platform, user_id)
else:
logger.warning(f"无法确定person_id for person_name: {person_name}, chat_id: {chat_id}")
return
# 创建或获取Person对象
person = Person(person_id=person_id)
@@ -861,34 +831,39 @@ async def store_person_memory_from_answer(person_name: str, memory_content: str,
logger.warning(f"用户 {person_name} (person_id: {person_id}) 尚未认识,无法存储记忆")
return
memory_hash = hashlib.sha256(f"{person_id}\n{content}".encode("utf-8")).hexdigest()[:16]
result = await memory_service.ingest_text(
external_id=f"person_fact:{person_id}:{memory_hash}",
source_type="person_fact",
text=content,
chat_id=chat_id,
person_ids=[person_id],
participants=[person.person_name or person_name],
timestamp=time.time(),
tags=["person_fact"],
metadata={
"person_id": person_id,
"person_name": person.person_name or person_name,
"platform": platform,
"source": "person_info.store_person_memory_from_answer",
},
respect_filter=True,
user_id=str(session.user_id or "").strip(),
group_id=str(session.group_id or "").strip(),
)
# 确定记忆分类可以根据memory_content判断这里使用通用分类
category = "其他" # 默认分类,可以根据需要调整
if result.success:
if result.detail == "chat_filtered":
logger.debug(f"人物长期记忆被聊天过滤策略跳过: {person_name} (person_id: {person_id})")
else:
logger.info(f"成功写入人物长期记忆: {person_name} (person_id: {person_id})")
# 记忆点格式category:content:weight
weight = "1.0" # 默认权重
memory_point = f"{category}:{memory_content}:{weight}"
# 添加到memory_points
if not person.memory_points:
person.memory_points = []
# 检查是否已存在相似的记忆点(避免重复)
is_duplicate = False
for existing_point in person.memory_points:
if existing_point and isinstance(existing_point, str):
parts = existing_point.split(":", 2)
if len(parts) >= 2:
existing_content = parts[1].strip()
# 简单相似度检查(如果内容相同或非常相似,则跳过)
if (
existing_content == memory_content
or memory_content in existing_content
or existing_content in memory_content
):
is_duplicate = True
break
if not is_duplicate:
person.memory_points.append(memory_point)
person.sync_to_database()
logger.info(f"成功添加记忆点到 {person_name} (person_id: {person_id}): {memory_point}")
else:
logger.warning(f"写入人物长期记忆失败: {person_name} (person_id: {person_id}) | {result.detail}")
logger.debug(f"记忆点已存在,跳过: {memory_point}")
except Exception as e:
logger.error(f"存储人物记忆失败: {e}")

View File

@@ -0,0 +1,34 @@
"""导出 Platform IO 层的公开入口。
当前仍处于地基阶段,调用方应优先从这里导入共享类型和全局管理器,
而不是直接依赖更底层的私有子模块。
"""
from .manager import PlatformIOManager, get_platform_io_manager
from .route_key_factory import RouteKeyFactory
from .routing import RouteTable
from .types import (
DeliveryBatch,
DeliveryReceipt,
DeliveryStatus,
DriverDescriptor,
DriverKind,
InboundMessageEnvelope,
RouteBinding,
RouteKey,
)
__all__ = [
"DeliveryBatch",
"DeliveryReceipt",
"DeliveryStatus",
"DriverDescriptor",
"DriverKind",
"InboundMessageEnvelope",
"PlatformIOManager",
"RouteKeyFactory",
"RouteBinding",
"RouteKey",
"RouteTable",
"get_platform_io_manager",
]

133
src/platform_io/dedupe.py Normal file
View File

@@ -0,0 +1,133 @@
"""提供 Platform IO 的轻量入站消息去重能力。
当前实现基于 ``dict + heapq``
- ``dict`` 保存去重键到过期时间的映射
- ``heapq`` 维护按过期时间排序的小顶堆
这样就不需要在每次检查时全表扫描,而是通过懒清理逐步弹出已经过期
或已经失效的堆节点。
"""
from typing import Dict, List, Tuple
import heapq
import time
class MessageDeduplicator:
"""使用基于 TTL 的内存缓存进行入站消息去重。
主要用于解决同一条外部消息被重复送入 Core 的问题,例如双路径并存、
适配器重试、重连或重复回调等场景。Broker 可以借助这个组件在进入
Core 前先拦住重复投递,避免重复处理、重复回复和重复入库。
当前实现使用 ``dict + heapq`` 维护过期时间:
- ``dict`` 负责 ``O(1)`` 级别的去重键查找
- ``heapq`` 负责按过期时间顺序做懒清理
这比“每次调用都全表扫描过期项”的实现更适合高吞吐消息场景。
Notes:
复杂度说明如下,设 ``n`` 为当前缓存中的有效去重键数量:
- 单次 ``mark_seen()`` 在常见路径下的时间复杂度接近 ``O(log n)``
- 从长期摊还角度看,``mark_seen()`` 的时间复杂度也接近 ``O(log n)``
- 如果某次调用恰好触发一批过期键的集中清理,则该次调用的最坏时间复杂度
可达到 ``O(k log n)``,其中 ``k`` 为本次被弹出或清理的键数量
- 空间复杂度为 ``O(n)``
"""
def __init__(self, ttl_seconds: float = 300.0, max_entries: int = 10000) -> None:
"""初始化去重器。
Args:
ttl_seconds: 每个去重键在缓存中的保留时长,单位为秒。
max_entries: 缓存允许保留的最大有效键数量,超出后会触发
机会性淘汰。
Raises:
ValueError: 当 ``ttl_seconds`` 或 ``max_entries`` 非正数时抛出。
"""
if ttl_seconds <= 0:
raise ValueError("ttl_seconds 必须大于 0")
if max_entries <= 0:
raise ValueError("max_entries 必须大于 0")
self._ttl_seconds = ttl_seconds
self._max_entries = max_entries
self._expire_heap: List[Tuple[float, str]] = []
self._seen: Dict[str, float] = {}
def mark_seen(self, dedupe_key: str) -> bool:
"""标记一条去重键已经出现过。
Args:
dedupe_key: 能稳定标识一条外部入站消息的去重键。
Returns:
bool: 若该键在当前 TTL 窗口内首次出现则返回 ``True``
否则返回 ``False``。
Notes:
方法会先基于小顶堆做一次懒清理,再判断当前键是否仍在有效期内。
如果缓存已达到上限,则会优先淘汰“最早过期的仍然有效的键”。
复杂度方面,常见路径下该方法接近 ``O(log n)``;如果恰好需要
集中清理一批过期键,则单次调用最坏可达到 ``O(k log n)``。
"""
now = time.monotonic()
self._purge_expired(now)
expires_at = self._seen.get(dedupe_key)
if expires_at is not None and expires_at > now:
return False
if len(self._seen) >= self._max_entries:
self._evict_earliest_live()
expires_at = now + self._ttl_seconds
self._seen[dedupe_key] = expires_at
heapq.heappush(self._expire_heap, (expires_at, dedupe_key))
return True
def clear(self) -> None:
"""清空全部去重缓存。"""
self._expire_heap.clear()
self._seen.clear()
def _purge_expired(self, now: float) -> None:
"""从缓存中清理已经过期的去重键。
Args:
now: 当前单调时钟时间戳。
Notes:
堆中可能存在旧版本节点。例如同一个 ``dedupe_key`` 被重新写入后,
旧的过期时间节点仍会留在堆里。这里会通过和 ``dict`` 中当前值比对,
跳过这类失效节点。
"""
while self._expire_heap and self._expire_heap[0][0] <= now:
expires_at, dedupe_key = heapq.heappop(self._expire_heap)
current_expires_at = self._seen.get(dedupe_key)
if current_expires_at is None:
continue
if current_expires_at != expires_at:
continue
self._seen.pop(dedupe_key, None)
def _evict_earliest_live(self) -> None:
"""当缓存达到容量上限时,淘汰一条最早过期的有效键。
Notes:
堆顶可能是已经过期或已失效的旧节点,因此这里同样需要循环弹出,
直到找到一条当前仍然在 ``dict`` 中生效的键。
"""
while self._expire_heap:
expires_at, dedupe_key = heapq.heappop(self._expire_heap)
current_expires_at = self._seen.get(dedupe_key)
if current_expires_at is None:
continue
if current_expires_at != expires_at:
continue
self._seen.pop(dedupe_key, None)
return

View File

@@ -0,0 +1,11 @@
"""导出 Platform IO 层的公开驱动类型。"""
from .base import PlatformIODriver
from .legacy_driver import LegacyPlatformDriver
from .plugin_driver import PluginPlatformDriver
__all__ = [
"LegacyPlatformDriver",
"PlatformIODriver",
"PluginPlatformDriver",
]

View File

@@ -0,0 +1,104 @@
"""定义 Platform IO 传输驱动的基础抽象协议。"""
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional
from src.platform_io.types import DeliveryReceipt, DriverDescriptor, InboundMessageEnvelope, RouteKey
if TYPE_CHECKING:
from src.chat.message_receive.message import SessionMessage
InboundHandler = Callable[[InboundMessageEnvelope], Awaitable[bool]]
class PlatformIODriver(ABC):
"""定义所有 Platform IO 驱动都必须实现的最小契约。
当前实现故意保持接口很小,让中间层可以先落地,再逐步把 legacy
与 plugin 路径的真实收发能力迁入这套协议之下。
"""
def __init__(self, descriptor: DriverDescriptor) -> None:
"""使用驱动描述对象初始化驱动。
Args:
descriptor: 注册到 Broker 中的静态驱动元数据。
"""
self._descriptor = descriptor
self._inbound_handler: Optional[InboundHandler] = None
@property
def descriptor(self) -> DriverDescriptor:
"""返回当前驱动的描述对象。
Returns:
DriverDescriptor: 当前驱动实例对应的描述对象。
"""
return self._descriptor
@property
def driver_id(self) -> str:
"""返回驱动标识。
Returns:
str: 当前驱动的唯一 ID。
"""
return self._descriptor.driver_id
def set_inbound_handler(self, handler: InboundHandler) -> None:
"""注册入站消息交回 Broker 的回调函数。
Args:
handler: 将规范化入站封装继续转发给 Broker 的异步回调。
"""
self._inbound_handler = handler
def clear_inbound_handler(self) -> None:
"""清除当前注册的入站回调函数。"""
self._inbound_handler = None
async def emit_inbound(self, envelope: InboundMessageEnvelope) -> bool:
"""将一条入站封装转交给 Broker 回调。
Args:
envelope: 由驱动产出的规范化入站封装。
Returns:
bool: 若 Broker 接受该入站消息则返回 ``True``,否则返回 ``False``。
"""
if self._inbound_handler is None:
return False
return await self._inbound_handler(envelope)
async def start(self) -> None:
"""启动驱动生命周期。
子类后续若需要初始化逻辑,可以覆盖这个钩子。
"""
return None
async def stop(self) -> None:
"""停止驱动生命周期。
子类后续若需要清理逻辑,可以覆盖这个钩子。
"""
return None
@abstractmethod
async def send_message(
self,
message: "SessionMessage",
route_key: RouteKey,
metadata: Optional[Dict[str, Any]] = None,
) -> DeliveryReceipt:
"""通过具体驱动发送一条消息。
Args:
message: 要投递的内部会话消息。
route_key: Broker 为本次投递选中的路由键。
metadata: 本次出站投递可选的 Broker 侧元数据。
Returns:
DeliveryReceipt: 规范化后的投递结果。
"""

View File

@@ -0,0 +1,92 @@
"""提供 Platform IO 的 legacy 传输驱动实现。"""
from typing import TYPE_CHECKING, Any, Dict, Optional
from src.platform_io.drivers.base import PlatformIODriver
from src.platform_io.types import DeliveryReceipt, DeliveryStatus, DriverDescriptor, DriverKind, RouteKey
if TYPE_CHECKING:
from src.chat.message_receive.message import SessionMessage
class LegacyPlatformDriver(PlatformIODriver):
"""面向 ``UniversalMessageSender`` 旧链的 Platform IO 驱动。"""
def __init__(
self,
driver_id: str,
platform: str,
account_id: Optional[str] = None,
scope: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> None:
"""初始化一个 legacy 驱动描述对象。
Args:
driver_id: Broker 内的唯一驱动 ID。
platform: 该 legacy 适配器链路负责的平台。
account_id: 可选的账号 ID。
scope: 可选的额外路由作用域。
metadata: 可选的额外驱动元数据。
"""
descriptor = DriverDescriptor(
driver_id=driver_id,
kind=DriverKind.LEGACY,
platform=platform,
account_id=account_id,
scope=scope,
metadata=metadata or {},
)
super().__init__(descriptor)
async def send_message(
self,
message: "SessionMessage",
route_key: RouteKey,
metadata: Optional[Dict[str, Any]] = None,
) -> DeliveryReceipt:
"""通过旧链发送一条已经过预处理的消息。
Args:
message: 要投递的内部会话消息。
route_key: Broker 为本次投递选择的路由键。
metadata: 本次出站投递可选的 Broker 侧元数据。
Returns:
DeliveryReceipt: 规范化后的发送回执。
"""
from src.chat.message_receive.uni_message_sender import send_prepared_message_to_platform
show_log = False
if isinstance(metadata, dict):
show_log = bool(metadata.get("show_log", False))
try:
sent = await send_prepared_message_to_platform(message, show_log=show_log)
except Exception as exc:
return DeliveryReceipt(
internal_message_id=message.message_id,
route_key=route_key,
status=DeliveryStatus.FAILED,
driver_id=self.driver_id,
driver_kind=self.descriptor.kind,
error=str(exc),
)
if not sent:
return DeliveryReceipt(
internal_message_id=message.message_id,
route_key=route_key,
status=DeliveryStatus.FAILED,
driver_id=self.driver_id,
driver_kind=self.descriptor.kind,
error="旧链发送失败",
)
return DeliveryReceipt(
internal_message_id=message.message_id,
route_key=route_key,
status=DeliveryStatus.SENT,
driver_id=self.driver_id,
driver_kind=self.descriptor.kind,
)

View File

@@ -0,0 +1,211 @@
"""提供 Platform IO 的插件消息网关驱动实现。"""
from typing import TYPE_CHECKING, Any, Dict, Optional, Protocol
from src.platform_io.drivers.base import PlatformIODriver
from src.platform_io.types import DeliveryReceipt, DeliveryStatus, DriverDescriptor, DriverKind, RouteKey
if TYPE_CHECKING:
from src.chat.message_receive.message import SessionMessage
class _GatewaySupervisorProtocol(Protocol):
"""消息网关驱动依赖的 Supervisor 最小协议。"""
async def invoke_message_gateway(
self,
plugin_id: str,
component_name: str,
args: Optional[Dict[str, Any]] = None,
timeout_ms: int = 30000,
) -> Any:
"""调用插件声明的消息网关方法。"""
class PluginPlatformDriver(PlatformIODriver):
"""面向插件消息网关链路的 Platform IO 驱动。"""
def __init__(
self,
driver_id: str,
platform: str,
supervisor: _GatewaySupervisorProtocol,
component_name: str,
*,
supports_send: bool,
account_id: Optional[str] = None,
scope: Optional[str] = None,
plugin_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> None:
"""初始化一个插件消息网关驱动。
Args:
driver_id: Broker 内的唯一驱动 ID。
platform: 该消息网关负责的平台名称。
supervisor: 持有该插件的 Supervisor。
component_name: 出站时要调用的网关组件名称。
supports_send: 当前驱动是否具备出站能力。
account_id: 可选的账号 ID 或 self ID。
scope: 可选的额外路由作用域。
plugin_id: 拥有该实现的插件 ID。
metadata: 可选的额外驱动元数据。
"""
descriptor = DriverDescriptor(
driver_id=driver_id,
kind=DriverKind.PLUGIN,
platform=platform,
account_id=account_id,
scope=scope,
plugin_id=plugin_id,
metadata=metadata or {},
)
super().__init__(descriptor)
self._supervisor = supervisor
self._component_name = component_name
self._supports_send = supports_send
async def send_message(
self,
message: "SessionMessage",
route_key: RouteKey,
metadata: Optional[Dict[str, Any]] = None,
) -> DeliveryReceipt:
"""通过插件消息网关发送消息。
Args:
message: 要投递的内部会话消息。
route_key: Broker 为本次投递选择的路由键。
metadata: 可选的发送元数据。
Returns:
DeliveryReceipt: 规范化后的发送回执。
"""
if not self._supports_send:
return DeliveryReceipt(
internal_message_id=message.message_id,
route_key=route_key,
status=DeliveryStatus.FAILED,
driver_id=self.driver_id,
driver_kind=self.descriptor.kind,
error="当前消息网关仅支持接收,不支持发送",
)
from src.plugin_runtime.host.message_utils import PluginMessageUtils
plugin_id = self.descriptor.plugin_id or ""
if not plugin_id:
return DeliveryReceipt(
internal_message_id=message.message_id,
route_key=route_key,
status=DeliveryStatus.FAILED,
driver_id=self.driver_id,
driver_kind=self.descriptor.kind,
error="插件消息网关驱动缺少 plugin_id",
)
try:
message_dict = PluginMessageUtils._session_message_to_dict(message)
response = await self._supervisor.invoke_message_gateway(
plugin_id=plugin_id,
component_name=self._component_name,
args={
"message": message_dict,
"route": {
"platform": route_key.platform,
"account_id": route_key.account_id,
"scope": route_key.scope,
},
"metadata": metadata or {},
},
timeout_ms=30000,
)
except Exception as exc:
return DeliveryReceipt(
internal_message_id=message.message_id,
route_key=route_key,
status=DeliveryStatus.FAILED,
driver_id=self.driver_id,
driver_kind=self.descriptor.kind,
error=str(exc),
)
return self._build_receipt(message.message_id, route_key, response)
def _build_receipt(self, internal_message_id: str, route_key: RouteKey, response: Any) -> DeliveryReceipt:
"""将网关调用响应归一化为出站回执。
Args:
internal_message_id: 内部消息 ID。
route_key: 本次投递的路由键。
response: Supervisor 返回的 RPC 响应对象。
Returns:
DeliveryReceipt: 标准化后的出站回执。
"""
if getattr(response, "error", None):
error = response.error.get("message", "消息网关发送失败")
return DeliveryReceipt(
internal_message_id=internal_message_id,
route_key=route_key,
status=DeliveryStatus.FAILED,
driver_id=self.driver_id,
driver_kind=self.descriptor.kind,
error=error,
)
payload = getattr(response, "payload", {})
invoke_success = bool(payload.get("success", False)) if isinstance(payload, dict) else False
if not invoke_success:
return DeliveryReceipt(
internal_message_id=internal_message_id,
route_key=route_key,
status=DeliveryStatus.FAILED,
driver_id=self.driver_id,
driver_kind=self.descriptor.kind,
error=str(payload.get("result", "消息网关发送失败")) if isinstance(payload, dict) else "消息网关发送失败",
)
result = payload.get("result") if isinstance(payload, dict) else None
if isinstance(result, dict):
if result.get("success") is False:
return DeliveryReceipt(
internal_message_id=internal_message_id,
route_key=route_key,
status=DeliveryStatus.FAILED,
driver_id=self.driver_id,
driver_kind=self.descriptor.kind,
error=str(result.get("error", "消息网关发送失败")),
metadata=result.get("metadata", {}) if isinstance(result.get("metadata"), dict) else {},
)
external_message_id = str(result.get("external_message_id") or result.get("message_id") or "") or None
return DeliveryReceipt(
internal_message_id=internal_message_id,
route_key=route_key,
status=DeliveryStatus.SENT,
driver_id=self.driver_id,
driver_kind=self.descriptor.kind,
external_message_id=external_message_id,
metadata=result.get("metadata", {}) if isinstance(result.get("metadata"), dict) else {},
)
if isinstance(result, str) and result.strip():
return DeliveryReceipt(
internal_message_id=internal_message_id,
route_key=route_key,
status=DeliveryStatus.SENT,
driver_id=self.driver_id,
driver_kind=self.descriptor.kind,
external_message_id=result.strip(),
)
return DeliveryReceipt(
internal_message_id=internal_message_id,
route_key=route_key,
status=DeliveryStatus.SENT,
driver_id=self.driver_id,
driver_kind=self.descriptor.kind,
)

611
src/platform_io/manager.py Normal file
View File

@@ -0,0 +1,611 @@
"""提供 Platform IO 层的中心 Broker 管理器。"""
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional
from src.common.logger import get_logger
from src.platform_io.drivers.base import PlatformIODriver
from .dedupe import MessageDeduplicator
from .outbound_tracker import OutboundTracker
from .route_key_factory import RouteKeyFactory
from .registry import DriverRegistry
from .routing import RouteTable
from .types import DeliveryBatch, DeliveryReceipt, DeliveryStatus, InboundMessageEnvelope, RouteBinding, RouteKey
if TYPE_CHECKING:
from src.chat.message_receive.message import SessionMessage
logger = get_logger("platform_io.manager")
InboundDispatcher = Callable[[InboundMessageEnvelope], Awaitable[None]]
class PlatformIOManager:
"""统一协调平台消息 IO 的路由、去重与状态跟踪。
与旧实现不同,这个管理器不再负责“多条链路谁该接管平台”的裁决,
只维护发送表和接收表两张轻量路由表:
- 发送时:解析所有命中的发送绑定并全部投递。
- 接收时:只校验当前驱动是否已登记为可接收链路,然后全部放行给上层。
- 去重时:仅对单条链路做技术性重放抑制,不做跨链路语义去重。
"""
def __init__(self) -> None:
"""初始化 Broker 管理器及其内存状态。"""
self._driver_registry = DriverRegistry()
self._send_route_table = RouteTable()
self._receive_route_table = RouteTable()
self._legacy_send_drivers: Dict[str, PlatformIODriver] = {}
self._deduplicator = MessageDeduplicator()
self._outbound_tracker = OutboundTracker()
self._inbound_dispatcher: Optional[InboundDispatcher] = None
self._started = False
@property
def is_started(self) -> bool:
"""返回 Broker 当前是否已进入运行态。
Returns:
bool: 若 Broker 已启动则返回 ``True``。
"""
return self._started
async def start(self) -> None:
"""启动 Broker并依次启动当前已注册的全部驱动。
Raises:
Exception: 当某个驱动启动失败时,异常会继续上抛;已成功启动的驱动
会被自动回滚停止。
"""
if self._started:
return
started_drivers: List[PlatformIODriver] = []
try:
for driver in self._driver_registry.list():
await driver.start()
started_drivers.append(driver)
except Exception:
for driver in reversed(started_drivers):
try:
await driver.stop()
except Exception:
logger.exception(f"回滚驱动停止失败: driver_id={driver.driver_id}")
raise
self._started = True
async def ensure_send_pipeline_ready(self) -> None:
"""确保出站发送管线已准备就绪。
该方法会先同步 legacy fallback driver再在需要时启动 Broker。
send service 应只调用这一层准备入口,而不是自行判断旧链或插件链。
"""
await self._sync_legacy_send_drivers()
if not self._started:
await self.start()
async def stop(self) -> None:
"""停止 Broker并按逆序停止全部已注册驱动。
停止完成后,会同步清空仅对当前运行周期有效的去重缓存和出站跟踪状态,
避免下一次启动时继续沿用上一个运行周期的瞬时内存数据。
Raises:
RuntimeError: 当一个或多个驱动停止失败时抛出汇总异常。
"""
if not self._started:
return
stop_errors: List[str] = []
for driver in reversed(self._driver_registry.list()):
try:
await driver.stop()
except Exception as exc:
stop_errors.append(f"{driver.driver_id}: {exc}")
logger.exception(f"驱动停止失败: driver_id={driver.driver_id}")
self._started = False
self._deduplicator.clear()
self._outbound_tracker.clear()
if stop_errors:
raise RuntimeError(f"部分驱动停止失败: {'; '.join(stop_errors)}")
async def add_driver(self, driver: PlatformIODriver) -> None:
"""向运行中的 Broker 注册并启动一个驱动。
如果 Broker 尚未启动,则该方法等价于 ``register_driver()``。
Args:
driver: 要添加的驱动实例。
Raises:
Exception: 当驱动启动失败时,注册会自动回滚,异常继续上抛。
"""
self._register_driver_internal(driver)
if not self._started:
return
try:
await driver.start()
except Exception:
self._unregister_driver_internal(driver.driver_id)
raise
async def remove_driver(self, driver_id: str) -> Optional[PlatformIODriver]:
"""从运行中的 Broker 停止并移除一个驱动。
如果 Broker 尚未启动,则该方法等价于 ``unregister_driver()``。
Args:
driver_id: 要移除的驱动 ID。
Returns:
Optional[PlatformIODriver]: 若驱动存在,则返回被移除的驱动实例。
Raises:
Exception: 当 Broker 运行中且驱动停止失败时,异常会继续上抛。
"""
if not self._started:
return self.unregister_driver(driver_id)
driver = self._driver_registry.get(driver_id)
if driver is None:
return None
await driver.stop()
return self._unregister_driver_internal(driver_id)
@property
def driver_registry(self) -> DriverRegistry:
"""返回管理器持有的驱动注册表。
Returns:
DriverRegistry: 用于保存全部已注册驱动的注册表。
"""
return self._driver_registry
@property
def send_route_table(self) -> RouteTable:
"""返回发送路由表。"""
return self._send_route_table
@property
def receive_route_table(self) -> RouteTable:
"""返回接收路由表。"""
return self._receive_route_table
@property
def deduplicator(self) -> MessageDeduplicator:
"""返回管理器持有的入站去重器。
Returns:
MessageDeduplicator: 用于抑制重复入站的去重器。
"""
return self._deduplicator
@property
def outbound_tracker(self) -> OutboundTracker:
"""返回管理器持有的出站跟踪器。
Returns:
OutboundTracker: 用于记录出站 pending 状态与回执的跟踪器。
"""
return self._outbound_tracker
def set_inbound_dispatcher(self, dispatcher: InboundDispatcher) -> None:
"""设置统一的入站分发回调。
Args:
dispatcher: 接收已通过 Broker 审核的入站封装,并继续送入
Core 下一处理阶段的异步回调。
"""
self._inbound_dispatcher = dispatcher
def clear_inbound_dispatcher(self) -> None:
"""清除当前的入站分发回调。"""
self._inbound_dispatcher = None
@property
def has_inbound_dispatcher(self) -> bool:
"""返回当前是否已经配置入站分发回调。
Returns:
bool: 若已经配置入站分发回调则返回 ``True``。
"""
return self._inbound_dispatcher is not None
def register_driver(self, driver: PlatformIODriver) -> None:
"""注册驱动,并把它的入站回调挂到 Broker。
Args:
driver: 要注册的驱动实例。
Raises:
RuntimeError: 当 Broker 已经处于运行态时抛出。此时应改用
``add_driver()`` 以保证驱动生命周期和注册状态一致。
"""
if self._started:
raise RuntimeError("Broker 运行中不允许直接 register_driver请改用 add_driver()")
self._register_driver_internal(driver)
def _register_driver_internal(self, driver: PlatformIODriver) -> None:
"""执行不带运行态限制的内部驱动注册。
Args:
driver: 要注册的驱动实例。
"""
driver.set_inbound_handler(self.accept_inbound)
self._driver_registry.register(driver)
def unregister_driver(self, driver_id: str) -> Optional[PlatformIODriver]:
"""从 Broker 注销一个驱动。
Args:
driver_id: 要移除的驱动 ID。
Returns:
Optional[PlatformIODriver]: 若驱动存在,则返回被移除的驱动实例。
Raises:
RuntimeError: 当 Broker 已经处于运行态时抛出。此时应改用
``remove_driver()``,避免驱动停止与路由解绑脱节。
"""
if self._started:
raise RuntimeError("Broker 运行中不允许直接 unregister_driver请改用 remove_driver()")
return self._unregister_driver_internal(driver_id)
def _unregister_driver_internal(self, driver_id: str) -> Optional[PlatformIODriver]:
"""执行不带运行态限制的内部驱动注销。
Args:
driver_id: 要移除的驱动 ID。
Returns:
Optional[PlatformIODriver]: 若驱动存在,则返回被移除的驱动实例。
"""
removed_driver = self._driver_registry.unregister(driver_id)
if removed_driver is None:
return None
removed_driver.clear_inbound_handler()
self._send_route_table.remove_bindings_by_driver(driver_id)
self._receive_route_table.remove_bindings_by_driver(driver_id)
self._legacy_send_drivers = {
platform: driver
for platform, driver in self._legacy_send_drivers.items()
if driver.driver_id != driver_id
}
return removed_driver
async def _sync_legacy_send_drivers(self) -> None:
"""根据当前配置同步 legacy fallback driver。"""
from src.chat.utils.utils import get_all_bot_accounts
from src.platform_io.drivers.legacy_driver import LegacyPlatformDriver
desired_accounts = get_all_bot_accounts()
desired_platforms = set(desired_accounts.keys())
current_platforms = set(self._legacy_send_drivers.keys())
for platform in sorted(current_platforms - desired_platforms):
await self._remove_legacy_send_driver(platform)
for platform, account_id in desired_accounts.items():
existing_driver = self._legacy_send_drivers.get(platform)
if existing_driver is not None and existing_driver.descriptor.account_id == account_id:
continue
if existing_driver is not None:
await self._remove_legacy_send_driver(platform)
driver = LegacyPlatformDriver(
driver_id=f"legacy.send.{platform}",
platform=platform,
account_id=account_id,
)
if self._started:
await self.add_driver(driver)
else:
self.register_driver(driver)
self._legacy_send_drivers[platform] = driver
async def _remove_legacy_send_driver(self, platform: str) -> None:
"""移除指定平台的 legacy fallback driver。
Args:
platform: 要移除的目标平台。
"""
driver = self._legacy_send_drivers.get(platform)
if driver is None:
return
if self._started:
await self.remove_driver(driver.driver_id)
else:
self.unregister_driver(driver.driver_id)
self._legacy_send_drivers.pop(platform, None)
def bind_send_route(self, binding: RouteBinding) -> None:
"""为某个路由键绑定发送驱动。
Args:
binding: 要保存的路由绑定。
Raises:
ValueError: 当绑定引用了不存在的驱动,或者绑定与驱动描述不一致时抛出。
"""
driver = self._driver_registry.get(binding.driver_id)
if driver is None:
raise ValueError(f"驱动 {binding.driver_id} 未注册,无法绑定路由")
self._validate_binding_against_driver(binding, driver)
self._send_route_table.bind(binding)
def bind_receive_route(self, binding: RouteBinding) -> None:
"""为某个路由键绑定接收驱动。
Args:
binding: 要保存的路由绑定。
Raises:
ValueError: 当绑定引用了不存在的驱动,或者绑定与驱动描述不一致时抛出。
"""
driver = self._driver_registry.get(binding.driver_id)
if driver is None:
raise ValueError(f"驱动 {binding.driver_id} 未注册,无法绑定路由")
self._validate_binding_against_driver(binding, driver)
self._receive_route_table.bind(binding)
def unbind_send_route(self, route_key: RouteKey, driver_id: Optional[str] = None) -> None:
"""移除发送路由绑定。
Args:
route_key: 要移除绑定的路由键。
driver_id: 可选的特定驱动 ID。
"""
self._send_route_table.unbind(route_key, driver_id)
def unbind_receive_route(self, route_key: RouteKey, driver_id: Optional[str] = None) -> None:
"""移除接收路由绑定。
Args:
route_key: 要移除绑定的路由键。
driver_id: 可选的特定驱动 ID。
"""
self._receive_route_table.unbind(route_key, driver_id)
def resolve_drivers(self, route_key: RouteKey) -> List[PlatformIODriver]:
"""解析某个路由键当前命中的全部发送驱动。
Args:
route_key: 要解析的路由键。
Returns:
List[PlatformIODriver]: 当前命中的全部发送驱动。
"""
drivers: List[PlatformIODriver] = []
seen_driver_ids: set[str] = set()
for binding in self._send_route_table.resolve_bindings(route_key):
driver = self._driver_registry.get(binding.driver_id)
if driver is not None and driver.driver_id not in seen_driver_ids:
drivers.append(driver)
seen_driver_ids.add(driver.driver_id)
fallback_driver = self._legacy_send_drivers.get(route_key.platform)
if fallback_driver is not None:
descriptor = fallback_driver.descriptor
account_matches = descriptor.account_id is None or route_key.account_id in (None, descriptor.account_id)
scope_matches = descriptor.scope is None or route_key.scope in (None, descriptor.scope)
if account_matches and scope_matches and fallback_driver.driver_id not in seen_driver_ids:
drivers.append(fallback_driver)
return drivers
@staticmethod
def build_route_key_from_message(message: "SessionMessage") -> RouteKey:
"""根据 ``SessionMessage`` 构造路由键。
Args:
message: 内部会话消息对象。
Returns:
RouteKey: 由消息内容提取出的规范化路由键。
"""
return RouteKeyFactory.from_session_message(message)
@staticmethod
def build_route_key_from_message_dict(message_dict: Dict[str, Any]) -> RouteKey:
"""根据消息字典构造路由键。
Args:
message_dict: Host 与插件之间传输的消息字典。
Returns:
RouteKey: 由消息字典提取出的规范化路由键。
"""
return RouteKeyFactory.from_message_dict(message_dict)
async def accept_inbound(self, envelope: InboundMessageEnvelope) -> bool:
"""处理一条由驱动上报的入站封装。
Args:
envelope: 由传输驱动产出的入站封装。
Returns:
bool: 若消息被接受并继续转发给入站分发器,则返回 ``True``
否则返回 ``False``。
"""
if not self._receive_route_table.has_binding_for_driver(envelope.route_key, envelope.driver_id):
logger.info(
f"忽略未登记到接收路由表的入站消息: route={envelope.route_key} "
f"driver={envelope.driver_id}"
)
return False
if self._inbound_dispatcher is None:
logger.debug("PlatformIOManager 尚未配置 inbound dispatcher暂不继续分发")
return False
dedupe_key = self._build_inbound_dedupe_key(envelope)
if dedupe_key is not None:
if not self._deduplicator.mark_seen(dedupe_key):
logger.info(f"忽略重复入站消息: dedupe_key={dedupe_key}")
return False
await self._inbound_dispatcher(envelope)
return True
async def send_message(
self,
message: "SessionMessage",
route_key: RouteKey,
metadata: Optional[Dict[str, Any]] = None,
) -> DeliveryBatch:
"""通过 Broker 选中的全部发送驱动广播一条消息。
Args:
message: 要投递的内部会话消息。
route_key: 本次出站投递选择的路由键。
metadata: 可选的额外 Broker 侧元数据。
Returns:
DeliveryBatch: 规范化后的批量出站回执。
"""
drivers = self.resolve_drivers(route_key)
if not drivers:
return DeliveryBatch(internal_message_id=message.message_id, route_key=route_key)
receipts: List[DeliveryReceipt] = []
for driver in drivers:
try:
self._outbound_tracker.begin_tracking(
internal_message_id=message.message_id,
route_key=route_key,
driver_id=driver.driver_id,
metadata=metadata,
)
except ValueError as exc:
receipts.append(
DeliveryReceipt(
internal_message_id=message.message_id,
route_key=route_key,
status=DeliveryStatus.FAILED,
driver_id=driver.driver_id,
driver_kind=driver.descriptor.kind,
error=str(exc),
)
)
continue
try:
receipt = await driver.send_message(message=message, route_key=route_key, metadata=metadata)
except Exception as exc:
receipt = DeliveryReceipt(
internal_message_id=message.message_id,
route_key=route_key,
status=DeliveryStatus.FAILED,
driver_id=driver.driver_id,
driver_kind=driver.descriptor.kind,
error=str(exc),
)
self._outbound_tracker.finish_tracking(receipt)
receipts.append(receipt)
return DeliveryBatch(
internal_message_id=message.message_id,
route_key=route_key,
receipts=receipts,
)
@staticmethod
def _build_inbound_dedupe_key(envelope: InboundMessageEnvelope) -> Optional[str]:
"""构造用于入站抑制的去重键。
Args:
envelope: 当前正在处理的入站封装。
Returns:
Optional[str]: 若可以构造稳定去重键则返回该键,否则返回 ``None``。
Notes:
这里仅接受上游显式提供的稳定消息身份,例如 ``dedupe_key``、
平台侧 ``external_message_id`` 或已经完成规范化的
``session_message.message_id``。Broker 不再根据 ``payload`` 内容
猜测语义去重键,避免把“短时间内两条内容刚好完全相同”的合法消息
误判为重复入站。
"""
raw_dedupe_key = envelope.dedupe_key or envelope.external_message_id
if raw_dedupe_key is None and envelope.session_message is not None:
raw_dedupe_key = envelope.session_message.message_id
if raw_dedupe_key is None:
return None
normalized_dedupe_key = str(raw_dedupe_key).strip()
if not normalized_dedupe_key:
return None
return f"{envelope.driver_id}:{normalized_dedupe_key}"
@staticmethod
def _validate_binding_against_driver(binding: RouteBinding, driver: PlatformIODriver) -> None:
"""校验路由绑定与驱动描述是否一致。
Args:
binding: 待校验的路由绑定。
driver: 被绑定的驱动实例。
Raises:
ValueError: 当绑定类型、平台或更细粒度路由维度与驱动描述冲突时抛出。
"""
descriptor = driver.descriptor
if binding.driver_kind != descriptor.kind:
raise ValueError(
f"路由绑定的 driver_kind={binding.driver_kind} 与驱动 {driver.driver_id} 的类型 "
f"{descriptor.kind} 不一致"
)
if binding.route_key.platform != descriptor.platform:
raise ValueError(
f"路由绑定的平台 {binding.route_key.platform} 与驱动 {driver.driver_id} 的平台 "
f"{descriptor.platform} 不一致"
)
if descriptor.account_id is not None and binding.route_key.account_id not in (None, descriptor.account_id):
raise ValueError(
f"路由绑定的 account_id={binding.route_key.account_id} 与驱动 {driver.driver_id}"
f"account_id={descriptor.account_id} 冲突"
)
if descriptor.scope is not None and binding.route_key.scope not in (None, descriptor.scope):
raise ValueError(
f"路由绑定的 scope={binding.route_key.scope} 与驱动 {driver.driver_id}"
f"scope={descriptor.scope} 冲突"
)
_platform_io_manager: Optional[PlatformIOManager] = None
def get_platform_io_manager() -> PlatformIOManager:
"""返回全局 ``PlatformIOManager`` 单例。
Returns:
PlatformIOManager: 进程级共享的 Broker 管理器实例。
"""
global _platform_io_manager
if _platform_io_manager is None:
_platform_io_manager = PlatformIOManager()
return _platform_io_manager

View File

@@ -0,0 +1,286 @@
"""跟踪 Platform IO 层的出站投递状态。
当前实现基于两组 ``dict + heapq``
- ``_pending`` 和 ``_pending_expire_heap`` 负责管理待完成的出站记录
- ``_receipts_by_external_id`` 和 ``_receipt_expire_heap`` 负责管理已完成回执索引
这样就不需要在每次读写时全表扫描过期项,而是通过懒清理逐步弹出已经过期
或已经失效的堆节点。
"""
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple
import heapq
import time
from .types import DeliveryReceipt, RouteKey
@dataclass(slots=True)
class PendingOutboundRecord:
"""表示一条仍在等待完成的出站投递记录。
Attributes:
internal_message_id: 正在跟踪的内部 ``SessionMessage.message_id``。
route_key: 该出站投递开始时使用的路由键。
driver_id: 负责这次出站投递的驱动 ID。
created_at: 开始跟踪时记录的单调时钟时间戳。
expires_at: 该待完成记录预计过期的单调时钟时间戳。
metadata: 与待完成记录一同保留的额外 Broker 侧元数据。
"""
internal_message_id: str
route_key: RouteKey
driver_id: str
created_at: float = field(default_factory=time.monotonic)
expires_at: float = 0.0
metadata: Dict[str, Any] = field(default_factory=dict)
@dataclass(slots=True)
class StoredDeliveryReceipt:
"""表示一条已完成并暂存的出站回执。
Attributes:
receipt: 规范化后的出站投递回执。
stored_at: 回执被写入索引时记录的单调时钟时间戳。
expires_at: 该回执索引预计过期的单调时钟时间戳。
"""
receipt: DeliveryReceipt
stored_at: float = field(default_factory=time.monotonic)
expires_at: float = 0.0
class OutboundTracker:
"""统一跟踪出站消息的 pending 状态与最终回执。
主要用于解决出站消息在发送过程中“状态散落在不同路径里”的问题:
- 发送开始后,需要在最终回执返回前保留一份 pending 状态
- 平台返回 ``external_message_id`` 后,需要保留一段时间的回执索引
当前实现使用 ``dict + heapq`` 做 TTL 管理:
- ``dict`` 提供 ``O(1)`` 级别的主键查询
- ``heapq`` 提供按过期时间排序的懒清理能力
这比“每次 begin/finish/get 都全表扫描”的实现更适合高吞吐出站场景。
Notes:
复杂度说明如下,设 ``p`` 为当前有效 pending 数量,``r`` 为当前有效回执数量:
- ``begin_tracking()``、``finish_tracking()`` 的常见路径时间复杂度接近
``O(log p)`` 或 ``O(log r)``
- ``get_pending()``、``get_receipt_by_external_id()`` 的查询本身是 ``O(1)``
,连同懒清理一起看,长期摊还复杂度接近 ``O(log n)``
- 如果某次调用恰好触发一批过期节点的集中清理,则该次调用的最坏时间复杂度
可达到 ``O(k log n)``,其中 ``k`` 为本次被弹出的节点数量
- 空间复杂度为 ``O(p + r)``
"""
def __init__(self, ttl_seconds: float = 1800.0) -> None:
"""初始化出站跟踪器。
Args:
ttl_seconds: 待完成记录与按外部消息 ID 建立的回执索引保留时长,
单位为秒。
Raises:
ValueError: 当 ``ttl_seconds`` 非正数时抛出。
"""
if ttl_seconds <= 0:
raise ValueError("ttl_seconds 必须大于 0")
self._ttl_seconds = ttl_seconds
self._pending: Dict[Tuple[str, str], PendingOutboundRecord] = {}
self._pending_expire_heap: List[Tuple[float, str, str]] = []
self._receipts_by_external_id: Dict[str, StoredDeliveryReceipt] = {}
self._receipt_expire_heap: List[Tuple[float, str]] = []
@staticmethod
def _build_pending_key(internal_message_id: str, driver_id: str) -> Tuple[str, str]:
"""构造单条出站跟踪记录的唯一键。
Args:
internal_message_id: 内部消息 ID。
driver_id: 负责当前投递的驱动 ID。
Returns:
Tuple[str, str]: ``(internal_message_id, driver_id)`` 组合键。
"""
return internal_message_id, driver_id
def begin_tracking(
self,
internal_message_id: str,
route_key: RouteKey,
driver_id: str,
metadata: Optional[Dict[str, Any]] = None,
) -> PendingOutboundRecord:
"""开始跟踪一次出站投递。
Args:
internal_message_id: 正在投递的内部消息 ID。
route_key: 这次出站投递选择的路由键。
driver_id: 负责本次投递的驱动 ID。
metadata: 可选的额外元数据,会一并保存在待完成记录中。
Returns:
PendingOutboundRecord: 新创建的待完成记录。
Raises:
ValueError: 当同一个 ``internal_message_id`` 与 ``driver_id`` 组合已经存在
未完成记录时抛出。
"""
now = time.monotonic()
self._cleanup_expired(now)
pending_key = self._build_pending_key(internal_message_id, driver_id)
if pending_key in self._pending:
raise ValueError(f"消息 {internal_message_id} 在驱动 {driver_id} 上已存在未完成的出站跟踪记录")
expires_at = now + self._ttl_seconds
record = PendingOutboundRecord(
internal_message_id=internal_message_id,
route_key=route_key,
driver_id=driver_id,
created_at=now,
expires_at=expires_at,
metadata=metadata or {},
)
self._pending[pending_key] = record
heapq.heappush(self._pending_expire_heap, (expires_at, internal_message_id, driver_id))
return record
def finish_tracking(self, receipt: DeliveryReceipt) -> Optional[PendingOutboundRecord]:
"""使用最终回执结束一条出站跟踪。
Args:
receipt: 规范化后的最终投递回执。
Returns:
Optional[PendingOutboundRecord]: 若此前存在待完成记录,则返回该记录。
"""
now = time.monotonic()
self._cleanup_expired(now)
pending_record: Optional[PendingOutboundRecord] = None
if receipt.driver_id:
pending_key = self._build_pending_key(receipt.internal_message_id, receipt.driver_id)
pending_record = self._pending.pop(pending_key, None)
else:
matched_records = [
key
for key, record in self._pending.items()
if record.internal_message_id == receipt.internal_message_id
]
if len(matched_records) == 1:
pending_record = self._pending.pop(matched_records[0], None)
if receipt.external_message_id:
expires_at = now + self._ttl_seconds
self._receipts_by_external_id[receipt.external_message_id] = StoredDeliveryReceipt(
receipt=receipt,
stored_at=now,
expires_at=expires_at,
)
heapq.heappush(self._receipt_expire_heap, (expires_at, receipt.external_message_id))
return pending_record
def get_pending(
self,
internal_message_id: str,
driver_id: Optional[str] = None,
) -> Optional[PendingOutboundRecord]:
"""根据内部消息 ID 查询待完成记录。
Args:
internal_message_id: 要查询的内部消息 ID。
driver_id: 可选的驱动 ID提供后仅返回该驱动上的待完成记录。
Returns:
Optional[PendingOutboundRecord]: 若记录仍存在,则返回对应待完成记录。
"""
self._cleanup_expired(time.monotonic())
if driver_id:
return self._pending.get(self._build_pending_key(internal_message_id, driver_id))
matched_records = [
record
for record in self._pending.values()
if record.internal_message_id == internal_message_id
]
if len(matched_records) == 1:
return matched_records[0]
return None
def get_receipt_by_external_id(self, external_message_id: str) -> Optional[DeliveryReceipt]:
"""根据外部平台消息 ID 查询已完成回执。
Args:
external_message_id: 要查询的平台侧消息 ID。
Returns:
Optional[DeliveryReceipt]: 若存在对应回执,则返回该回执。
"""
self._cleanup_expired(time.monotonic())
stored_receipt = self._receipts_by_external_id.get(external_message_id)
return stored_receipt.receipt if stored_receipt else None
def clear(self) -> None:
"""清空全部待完成记录与已保存回执。"""
self._pending.clear()
self._pending_expire_heap.clear()
self._receipts_by_external_id.clear()
self._receipt_expire_heap.clear()
def _cleanup_expired(self, now: float) -> None:
"""清理内存中已经过期的待完成记录与已保存回执。
Args:
now: 当前单调时钟时间戳。
"""
self._cleanup_expired_pending(now)
self._cleanup_expired_receipts(now)
def _cleanup_expired_pending(self, now: float) -> None:
"""清理已经过期的待完成记录。
Args:
now: 当前单调时钟时间戳。
Notes:
堆中可能存在已经失效的旧节点。例如某条记录提前 ``finish`` 后,
它原本的过期节点仍可能留在堆里。这里会通过和 ``dict`` 中当前记录的
``expires_at`` 对比,跳过这类旧节点。
"""
while self._pending_expire_heap and self._pending_expire_heap[0][0] <= now:
expires_at, internal_message_id, driver_id = heapq.heappop(self._pending_expire_heap)
pending_key = self._build_pending_key(internal_message_id, driver_id)
current_record = self._pending.get(pending_key)
if current_record is None:
continue
if current_record.expires_at != expires_at:
continue
self._pending.pop(pending_key, None)
def _cleanup_expired_receipts(self, now: float) -> None:
"""清理已经过期的回执索引。
Args:
now: 当前单调时钟时间戳。
Notes:
同一个 ``external_message_id`` 在极端情况下可能被重复写入索引,
因此这里同样需要通过 ``expires_at`` 和当前 ``dict`` 中的值比对,
跳过已经失效的旧堆节点。
"""
while self._receipt_expire_heap and self._receipt_expire_heap[0][0] <= now:
expires_at, external_message_id = heapq.heappop(self._receipt_expire_heap)
current_receipt = self._receipts_by_external_id.get(external_message_id)
if current_receipt is None:
continue
if current_receipt.expires_at != expires_at:
continue
self._receipts_by_external_id.pop(external_message_id, None)

View File

@@ -0,0 +1,70 @@
"""提供 Platform IO 的驱动注册与查询能力。"""
from typing import Dict, List, Optional
from src.platform_io.drivers.base import PlatformIODriver
from src.platform_io.types import DriverKind
class DriverRegistry:
"""集中保存已注册的 Platform IO 驱动,并提供基础查询接口。"""
def __init__(self) -> None:
"""初始化一个空的驱动注册表。"""
self._drivers: Dict[str, PlatformIODriver] = {}
def register(self, driver: PlatformIODriver) -> None:
"""注册一个驱动实例。
Args:
driver: 要注册的驱动实例。
Raises:
ValueError: 当驱动 ID 已经存在时抛出。
"""
if driver.driver_id in self._drivers:
raise ValueError(f"驱动 {driver.driver_id} 已注册")
self._drivers[driver.driver_id] = driver
def unregister(self, driver_id: str) -> Optional[PlatformIODriver]:
"""按驱动 ID 注销一个驱动。
Args:
driver_id: 要移除的驱动 ID。
Returns:
Optional[PlatformIODriver]: 若驱动存在,则返回被移除的驱动实例。
"""
return self._drivers.pop(driver_id, None)
def get(self, driver_id: str) -> Optional[PlatformIODriver]:
"""按驱动 ID 获取驱动实例。
Args:
driver_id: 要查询的驱动 ID。
Returns:
Optional[PlatformIODriver]: 若存在匹配驱动,则返回该驱动实例。
"""
return self._drivers.get(driver_id)
def list(self, *, kind: Optional[DriverKind] = None, platform: Optional[str] = None) -> List[PlatformIODriver]:
"""列出已注册驱动,并支持可选过滤。
Args:
kind: 可选的驱动类型过滤条件。
platform: 可选的平台名称过滤条件。
Returns:
List[PlatformIODriver]: 符合过滤条件的驱动列表。
"""
drivers = list(self._drivers.values())
if kind is not None:
drivers = [driver for driver in drivers if driver.descriptor.kind == kind]
if platform is not None:
drivers = [driver for driver in drivers if driver.descriptor.platform == platform]
return drivers
def clear(self) -> None:
"""清空全部已注册驱动。"""
self._drivers.clear()

View File

@@ -0,0 +1,150 @@
"""提供 Platform IO 路由键的统一提取与构造能力。
这层的目标不是直接接入具体消息链,而是先把“未来接线时用什么字段构造
RouteKey”约定下来避免 legacy 和 plugin 两条链路各自发明一套隐式规则。
"""
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
from .types import RouteKey
if TYPE_CHECKING:
from src.chat.message_receive.message import SessionMessage
class RouteKeyFactory:
"""统一构造 ``RouteKey`` 的工厂。
当前约定会优先从消息字典顶层、``message_info``、``additional_config`` 或传入 metadata 中提取
以下字段:
- account_id: ``platform_io_account_id`` / ``account_id`` / ``self_id`` / ``bot_account``
- scope: ``platform_io_scope`` / ``route_scope`` / ``adapter_scope`` / ``connection_id``
这样即使上游主链暂时还没有正式的 ``self_id`` 字段,中间层也能先统一
约定提取口径,等具体消息链接入时直接复用。
"""
ACCOUNT_ID_KEYS = (
"platform_io_account_id",
"account_id",
"self_id",
"bot_account",
)
SCOPE_KEYS = (
"platform_io_scope",
"route_scope",
"adapter_scope",
"connection_id",
)
@classmethod
def from_platform(
cls,
platform: str,
*,
account_id: Optional[str] = None,
scope: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> RouteKey:
"""根据平台名和可选 metadata 构造 ``RouteKey``。
Args:
platform: 平台名称。
account_id: 显式传入的账号 ID若为空则尝试从 metadata 提取。
scope: 显式传入的路由作用域;若为空,则尝试从 metadata 提取。
metadata: 可选的元数据字典。
Returns:
RouteKey: 构造出的规范化路由键。
"""
extracted_account_id, extracted_scope = cls.extract_components(metadata)
return RouteKey(
platform=platform,
account_id=account_id or extracted_account_id,
scope=scope or extracted_scope,
)
@classmethod
def from_message_dict(cls, message_dict: Dict[str, Any]) -> RouteKey:
"""从消息字典中提取 ``RouteKey``。
Args:
message_dict: Host 与插件之间传输的消息字典。
Returns:
RouteKey: 构造出的规范化路由键。
Raises:
ValueError: 当消息字典缺少有效 ``platform`` 字段时抛出。
"""
platform = str(message_dict.get("platform") or "").strip()
if not platform:
raise ValueError("消息字典缺少有效的 platform 字段,无法构造 RouteKey")
message_info = message_dict.get("message_info", {})
additional_config = {}
if isinstance(message_info, dict):
raw_additional_config = message_info.get("additional_config", {})
if isinstance(raw_additional_config, dict):
additional_config = raw_additional_config
explicit_account_id, explicit_scope = cls.extract_components(message_dict)
message_info_account_id, message_info_scope = cls.extract_components(message_info)
metadata_account_id, metadata_scope = cls.extract_components(additional_config)
return RouteKey(
platform=platform,
account_id=explicit_account_id or message_info_account_id or metadata_account_id,
scope=explicit_scope or message_info_scope or metadata_scope,
)
@classmethod
def from_session_message(cls, message: "SessionMessage") -> RouteKey:
"""从 ``SessionMessage`` 中提取 ``RouteKey``。
Args:
message: 内部会话消息对象。
Returns:
RouteKey: 构造出的规范化路由键。
"""
additional_config = message.message_info.additional_config or {}
metadata = additional_config if isinstance(additional_config, dict) else {}
return cls.from_platform(message.platform, metadata=metadata)
@classmethod
def extract_components(cls, mapping: Optional[Dict[str, Any]]) -> Tuple[Optional[str], Optional[str]]:
"""从任意字典中提取 ``account_id`` 与 ``scope``。
Args:
mapping: 待提取的字典;若为空或不是字典,则返回空结果。
Returns:
Tuple[Optional[str], Optional[str]]: ``(account_id, scope)``。
"""
if not mapping or not isinstance(mapping, dict):
return None, None
account_id = cls._pick_string(mapping, cls.ACCOUNT_ID_KEYS)
scope = cls._pick_string(mapping, cls.SCOPE_KEYS)
return account_id, scope
@staticmethod
def _pick_string(mapping: Dict[str, Any], keys: Tuple[str, ...]) -> Optional[str]:
"""按优先级从字典里挑选第一个有效字符串。
Args:
mapping: 待查询的字典。
keys: 按优先级排列的候选键名。
Returns:
Optional[str]: 第一个规范化后非空的字符串值;若不存在则返回 ``None``。
"""
for key in keys:
value = mapping.get(key)
if value is None:
continue
normalized = str(value).strip()
if normalized:
return normalized
return None

141
src/platform_io/routing.py Normal file
View File

@@ -0,0 +1,141 @@
"""提供 Platform IO 的轻量路由绑定表。"""
from typing import Dict, List, Optional
from .types import RouteBinding, RouteKey
class RouteTable:
"""维护单张路由绑定表。
该实现不负责裁决“唯一 owner”只负责保存绑定并按
``RouteKey.resolution_order()`` 解析出候选绑定列表。
"""
def __init__(self) -> None:
"""初始化空路由绑定表。"""
self._bindings: Dict[RouteKey, Dict[str, RouteBinding]] = {}
def bind(self, binding: RouteBinding) -> None:
"""注册或更新一条路由绑定。
Args:
binding: 要保存的路由绑定。
"""
self._bindings.setdefault(binding.route_key, {})[binding.driver_id] = binding
def unbind(self, route_key: RouteKey, driver_id: Optional[str] = None) -> List[RouteBinding]:
"""移除指定路由键上的绑定。
Args:
route_key: 要移除绑定的路由键。
driver_id: 可选的驱动 ID为空时移除该路由键下全部绑定。
Returns:
List[RouteBinding]: 被移除的绑定列表。
"""
binding_map = self._bindings.get(route_key)
if not binding_map:
return []
if driver_id is None:
removed = list(binding_map.values())
self._bindings.pop(route_key, None)
return self._sort_bindings(removed)
removed_binding = binding_map.pop(driver_id, None)
if not binding_map:
self._bindings.pop(route_key, None)
return [removed_binding] if removed_binding is not None else []
def remove_bindings_by_driver(self, driver_id: str) -> List[RouteBinding]:
"""移除某个驱动在整张表上的全部绑定。
Args:
driver_id: 要移除绑定的驱动 ID。
Returns:
List[RouteBinding]: 被移除的绑定列表。
"""
removed_bindings: List[RouteBinding] = []
empty_route_keys: List[RouteKey] = []
for route_key, binding_map in self._bindings.items():
removed_binding = binding_map.pop(driver_id, None)
if removed_binding is not None:
removed_bindings.append(removed_binding)
if not binding_map:
empty_route_keys.append(route_key)
for route_key in empty_route_keys:
self._bindings.pop(route_key, None)
return self._sort_bindings(removed_bindings)
def list_bindings(self, route_key: Optional[RouteKey] = None) -> List[RouteBinding]:
"""列出当前路由表中的绑定。
Args:
route_key: 可选的路由键过滤条件。
Returns:
List[RouteBinding]: 当前绑定列表。
"""
if route_key is None:
bindings: List[RouteBinding] = []
for binding_map in self._bindings.values():
bindings.extend(binding_map.values())
return self._sort_bindings(bindings)
binding_map = self._bindings.get(route_key, {})
return self._sort_bindings(list(binding_map.values()))
def resolve_bindings(self, route_key: RouteKey) -> List[RouteBinding]:
"""按从具体到宽泛的顺序解析路由候选绑定。
Args:
route_key: 待解析的路由键。
Returns:
List[RouteBinding]: 去重后的候选绑定列表。
"""
resolved_bindings: List[RouteBinding] = []
seen_driver_ids: set[str] = set()
for candidate_key in route_key.resolution_order():
for binding in self.list_bindings(candidate_key):
if binding.driver_id in seen_driver_ids:
continue
seen_driver_ids.add(binding.driver_id)
resolved_bindings.append(binding)
return resolved_bindings
def has_binding_for_driver(self, route_key: RouteKey, driver_id: str) -> bool:
"""判断指定驱动是否在当前路由键解析结果中。
Args:
route_key: 待解析的路由键。
driver_id: 目标驱动 ID。
Returns:
bool: 若驱动存在于解析结果中则返回 ``True``。
"""
return any(binding.driver_id == driver_id for binding in self.resolve_bindings(route_key))
@staticmethod
def _sort_bindings(bindings: List[RouteBinding]) -> List[RouteBinding]:
"""按优先级降序排列绑定列表。
Args:
bindings: 待排序的绑定列表。
Returns:
List[RouteBinding]: 排序后的绑定列表。
"""
return sorted(bindings, key=lambda item: item.priority, reverse=True)

264
src/platform_io/types.py Normal file
View File

@@ -0,0 +1,264 @@
"""定义 Platform IO 中间层共享的核心类型。
本模块放置路由、驱动、入站与出站等规范化数据结构,供 Broker
层在 legacy 适配器链路和 plugin 适配器链路之间复用。
"""
from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, Optional
if TYPE_CHECKING:
from src.chat.message_receive.message import SessionMessage
class DriverKind(str, Enum):
"""底层收发驱动类型枚举。"""
LEGACY = "legacy"
PLUGIN = "plugin"
class DeliveryStatus(str, Enum):
"""统一出站回执状态枚举。"""
PENDING = "pending"
SENT = "sent"
FAILED = "failed"
DROPPED = "dropped"
@dataclass(frozen=True, slots=True)
class RouteKey:
"""用于 Platform IO 路由决策的唯一键。
路由解析会按照“从最具体到最宽泛”的顺序进行回退,这样同一平台
后续就能自然支持按账号、自定义 scope 等更细粒度的归属控制。
Attributes:
platform: 平台名称,例如 ``qq``。
account_id: 机器人账号 ID 或 self ID用于区分同平台多身份。
scope: 额外路由作用域,预留给未来的连接实例、租户或子通道等维度。
"""
platform: str
account_id: Optional[str] = None
scope: Optional[str] = None
def __post_init__(self) -> None:
"""规范化并校验路由键字段。
Raises:
ValueError: 当 ``platform`` 规范化后为空时抛出。
"""
platform = str(self.platform).strip()
account_id = str(self.account_id).strip() if self.account_id is not None else None
scope = str(self.scope).strip() if self.scope is not None else None
if not platform:
raise ValueError("RouteKey.platform 不能为空")
object.__setattr__(self, "platform", platform)
object.__setattr__(self, "account_id", account_id or None)
object.__setattr__(self, "scope", scope or None)
def resolution_order(self) -> List["RouteKey"]:
"""返回从最具体到最宽泛的路由匹配顺序。
Returns:
List[RouteKey]: 按回退优先级排序的候选路由键列表。
"""
keys: List[RouteKey] = [self]
if self.account_id is not None and self.scope is not None:
keys.append(RouteKey(platform=self.platform, account_id=self.account_id, scope=None))
keys.append(RouteKey(platform=self.platform, account_id=None, scope=self.scope))
elif self.account_id is not None:
keys.append(RouteKey(platform=self.platform, account_id=None, scope=None))
elif self.scope is not None:
keys.append(RouteKey(platform=self.platform, account_id=None, scope=None))
default_key = RouteKey(platform=self.platform, account_id=None, scope=None)
if default_key not in keys:
keys.append(default_key)
return keys
def to_dedupe_scope(self) -> str:
"""生成跨驱动共享的去重作用域字符串。
Returns:
str: 用于入站消息去重的稳定文本作用域键。
"""
account_id = self.account_id or "*"
scope = self.scope or "*"
return f"{self.platform}:{account_id}:{scope}"
@dataclass(frozen=True, slots=True)
class DriverDescriptor:
"""描述一个已注册的 Platform IO 驱动。
Attributes:
driver_id: Broker 层内全局唯一的驱动标识。
kind: 驱动实现类型,例如 legacy 或 plugin。
platform: 驱动负责的平台名称。
account_id: 可选的账号 ID 或 self ID。
scope: 可选的额外路由作用域。
plugin_id: 当驱动来自插件适配器时,对应的插件 ID。
metadata: 预留给路由策略或观测能力的额外驱动元数据。
"""
driver_id: str
kind: DriverKind
platform: str
account_id: Optional[str] = None
scope: Optional[str] = None
plugin_id: Optional[str] = None
metadata: Dict[str, Any] = field(default_factory=dict)
def __post_init__(self) -> None:
"""规范化并校验驱动描述字段。
Raises:
ValueError: 当 ``driver_id`` 或 ``platform`` 规范化后为空时抛出。
"""
driver_id = str(self.driver_id).strip()
platform = str(self.platform).strip()
plugin_id = str(self.plugin_id).strip() if self.plugin_id is not None else None
if not driver_id:
raise ValueError("DriverDescriptor.driver_id 不能为空")
if not platform:
raise ValueError("DriverDescriptor.platform 不能为空")
object.__setattr__(self, "driver_id", driver_id)
object.__setattr__(self, "platform", platform)
object.__setattr__(self, "plugin_id", plugin_id or None)
@property
def route_key(self) -> RouteKey:
"""构造该驱动默认代表的路由键。
Returns:
RouteKey: 当前驱动描述对应的规范化路由键。
"""
return RouteKey(platform=self.platform, account_id=self.account_id, scope=self.scope)
@dataclass(frozen=True, slots=True)
class RouteBinding:
"""表示一条从路由键到驱动的绑定关系。
Attributes:
route_key: 该绑定覆盖的路由键。
driver_id: 拥有该路由的驱动 ID。
driver_kind: 绑定驱动的类型。
priority: 当同一路由键存在多条绑定时使用的相对优先级。
metadata: 预留给未来路由策略的额外绑定元数据。
"""
route_key: RouteKey
driver_id: str
driver_kind: DriverKind
priority: int = 0
metadata: Dict[str, Any] = field(default_factory=dict)
def __post_init__(self) -> None:
"""规范化并校验绑定字段。
Raises:
ValueError: 当 ``driver_id`` 规范化后为空时抛出。
"""
driver_id = str(self.driver_id).strip()
if not driver_id:
raise ValueError("RouteBinding.driver_id 不能为空")
object.__setattr__(self, "driver_id", driver_id)
@dataclass(slots=True)
class InboundMessageEnvelope:
"""封装一次由驱动产出的规范化入站消息。
Attributes:
route_key: 该入站消息解析出的路由键。
driver_id: 产出该消息的驱动 ID。
driver_kind: 产出该消息的驱动类型。
external_message_id: 可选的平台侧消息 ID用于去重。
dedupe_key: 可选的显式去重键。当外部消息没有稳定 ``message_id`` 时,
可由上游驱动提供稳定的技术性幂等键。若这里为空,中间层仅会继续
回退到 ``external_message_id`` 或 ``session_message.message_id``
不会再根据 ``payload`` 内容猜测语义去重键。
session_message: 可选的、已经完成规范化的 ``SessionMessage`` 对象。
payload: 可选的原始字典载荷,供延迟转换或调试使用。
metadata: 额外入站元数据,例如连接信息或追踪上下文。
"""
route_key: RouteKey
driver_id: str
driver_kind: DriverKind
external_message_id: Optional[str] = None
dedupe_key: Optional[str] = None
session_message: Optional["SessionMessage"] = None
payload: Optional[Dict[str, Any]] = None
metadata: Dict[str, Any] = field(default_factory=dict)
@dataclass(slots=True)
class DeliveryReceipt:
"""表示一次出站投递尝试的统一结果。
Attributes:
internal_message_id: Broker 跟踪的内部 ``SessionMessage.message_id``。
route_key: 本次投递使用的路由键。
status: 规范化后的投递状态。
driver_id: 实际处理该投递的驱动 ID可为空。
driver_kind: 实际处理该投递的驱动类型,可为空。
external_message_id: 驱动或适配器返回的平台侧消息 ID可为空。
error: 投递失败时的错误信息,可为空。
metadata: 预留给回执、时间戳或平台特有信息的额外元数据。
"""
internal_message_id: str
route_key: RouteKey
status: DeliveryStatus
driver_id: Optional[str] = None
driver_kind: Optional[DriverKind] = None
external_message_id: Optional[str] = None
error: Optional[str] = None
metadata: Dict[str, Any] = field(default_factory=dict)
@dataclass(slots=True)
class DeliveryBatch:
"""表示一次广播式出站投递的批量结果。
Attributes:
internal_message_id: 内部消息 ID。
route_key: 本次投递使用的路由键。
receipts: 各条路由的独立投递回执列表。
"""
internal_message_id: str
route_key: RouteKey
receipts: List[DeliveryReceipt] = field(default_factory=list)
@property
def sent_receipts(self) -> List[DeliveryReceipt]:
"""返回全部发送成功的回执。"""
return [receipt for receipt in self.receipts if receipt.status == DeliveryStatus.SENT]
@property
def failed_receipts(self) -> List[DeliveryReceipt]:
"""返回全部发送失败的回执。"""
return [receipt for receipt in self.receipts if receipt.status != DeliveryStatus.SENT]
@property
def has_success(self) -> bool:
"""返回当前批量投递是否至少命中一条成功回执。"""
return bool(self.sent_receipts)

View File

@@ -16,3 +16,9 @@ ENV_PLUGIN_DIRS = "MAIBOT_PLUGIN_DIRS"
ENV_HOST_VERSION = "MAIBOT_HOST_VERSION"
"""Runner 读取的 Host 应用版本号,用于 manifest 兼容性校验"""
ENV_EXTERNAL_PLUGIN_IDS = "MAIBOT_EXTERNAL_PLUGIN_IDS"
"""Runner 启动时可视为已满足的外部插件依赖版本映射JSON 对象)"""
ENV_GLOBAL_CONFIG_SNAPSHOT = "MAIBOT_GLOBAL_CONFIG_SNAPSHOT"
"""Runner 启动时注入的全局配置快照JSON 对象)"""

View File

@@ -1,12 +1,13 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Protocol
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Protocol, Sequence
from src.common.logger import get_logger
logger = get_logger("plugin_runtime.integration")
if TYPE_CHECKING:
from src.plugin_runtime.host.component_registry import RegisteredComponent
from src.plugin_runtime.host.api_registry import APIEntry
from src.plugin_runtime.host.component_registry import ComponentEntry
from src.plugin_runtime.host.supervisor import PluginSupervisor
@@ -14,18 +15,311 @@ class _RuntimeComponentManagerProtocol(Protocol):
@property
def supervisors(self) -> List["PluginSupervisor"]: ...
def _normalize_component_type(self, component_type: str) -> str: ...
def _is_api_component_type(self, component_type: str) -> bool: ...
def _serialize_api_entry(self, entry: "APIEntry") -> Dict[str, Any]: ...
def _serialize_api_component_entry(self, entry: "APIEntry") -> Dict[str, Any]: ...
def _is_api_visible_to_plugin(self, entry: "APIEntry", caller_plugin_id: str) -> bool: ...
def _normalize_api_reference(self, api_name: str, version: str = "") -> tuple[str, str]: ...
def _build_api_unavailable_error(self, entry: "APIEntry") -> str: ...
def _get_supervisor_for_plugin(self, plugin_id: str) -> Optional["PluginSupervisor"]: ...
def _resolve_api_target(
self,
caller_plugin_id: str,
api_name: str,
version: str = "",
) -> tuple[Optional["PluginSupervisor"], Optional["APIEntry"], Optional[str]]: ...
def _resolve_api_toggle_target(
self,
name: str,
version: str = "",
) -> tuple[Optional["PluginSupervisor"], Optional["APIEntry"], Optional[str]]: ...
def _resolve_component_toggle_target(
self, name: str, component_type: str
) -> tuple[Optional["RegisteredComponent"], Optional[str]]: ...
) -> tuple[Optional["ComponentEntry"], Optional[str]]: ...
def _find_duplicate_plugin_ids(self, plugin_dirs: List[Path]) -> Dict[str, List[Path]]: ...
def _iter_plugin_dirs(self) -> Iterable[Path]: ...
async def load_plugin_globally(self, plugin_id: str, reason: str = "manual") -> bool: ...
async def reload_plugins_globally(self, plugin_ids: Sequence[str], reason: str = "manual") -> bool: ...
class RuntimeComponentCapabilityMixin:
@staticmethod
def _normalize_component_type(component_type: str) -> str:
"""规范化组件类型名称。
Args:
component_type: 原始组件类型。
Returns:
str: 统一转为大写后的组件类型名。
"""
return str(component_type or "").strip().upper()
@classmethod
def _is_api_component_type(cls, component_type: str) -> bool:
"""判断组件类型是否为 API。
Args:
component_type: 原始组件类型。
Returns:
bool: 是否为 API 组件类型。
"""
return cls._normalize_component_type(component_type) == "API"
@staticmethod
def _serialize_api_entry(entry: "APIEntry") -> Dict[str, Any]:
"""将 API 组件条目序列化为能力返回值。
Args:
entry: API 组件条目。
Returns:
Dict[str, Any]: 适合通过能力层返回给插件的 API 元信息。
"""
return {
"name": entry.name,
"full_name": entry.full_name,
"plugin_id": entry.plugin_id,
"description": entry.description,
"version": entry.version,
"public": entry.public,
"enabled": entry.enabled,
"dynamic": entry.dynamic,
"offline_reason": entry.offline_reason,
"metadata": dict(entry.metadata),
}
@classmethod
def _serialize_api_component_entry(cls, entry: "APIEntry") -> Dict[str, Any]:
"""将 API 条目序列化为通用组件视图。
Args:
entry: API 组件条目。
Returns:
Dict[str, Any]: 适合 ``component.get_all_plugins`` 返回的组件结构。
"""
serialized_entry = cls._serialize_api_entry(entry)
return {
"name": serialized_entry["name"],
"full_name": serialized_entry["full_name"],
"type": "API",
"enabled": serialized_entry["enabled"],
"metadata": serialized_entry["metadata"],
}
@staticmethod
def _is_api_visible_to_plugin(entry: "APIEntry", caller_plugin_id: str) -> bool:
"""判断某个 API 是否对调用方可见。
Args:
entry: 目标 API 组件条目。
caller_plugin_id: 调用方插件 ID。
Returns:
bool: 是否允许当前插件可见并调用。
"""
return entry.plugin_id == caller_plugin_id or entry.public
@staticmethod
def _normalize_api_reference(api_name: str, version: str = "") -> tuple[str, str]:
"""规范化 API 名称与版本参数。
支持在 ``api_name`` 中直接携带 ``@version`` 后缀。
"""
normalized_api_name = str(api_name or "").strip()
normalized_version = str(version or "").strip()
if normalized_api_name and not normalized_version and "@" in normalized_api_name:
candidate_name, candidate_version = normalized_api_name.rsplit("@", 1)
candidate_name = candidate_name.strip()
candidate_version = candidate_version.strip()
if candidate_name and candidate_version:
normalized_api_name = candidate_name
normalized_version = candidate_version
return normalized_api_name, normalized_version
@staticmethod
def _build_api_unavailable_error(entry: "APIEntry") -> str:
"""构造 API 当前不可用时的错误信息。"""
if entry.offline_reason:
return entry.offline_reason
return f"API {entry.registry_key} 当前不可用"
def _resolve_api_target(
self: _RuntimeComponentManagerProtocol,
caller_plugin_id: str,
api_name: str,
version: str = "",
) -> tuple[Optional["PluginSupervisor"], Optional["APIEntry"], Optional[str]]:
"""解析 API 名称到唯一可调用的目标组件。
Args:
caller_plugin_id: 调用方插件 ID。
api_name: API 名称,支持 ``plugin_id.api_name`` 或唯一短名。
version: 可选的 API 版本。
Returns:
tuple[Optional[PluginSupervisor], Optional[APIEntry], Optional[str]]:
解析成功时返回 ``(监督器, API 条目, None)``,失败时返回错误信息。
"""
normalized_api_name, normalized_version = self._normalize_api_reference(api_name, version)
if not normalized_api_name:
return None, None, "缺少必要参数 api_name"
if "." in normalized_api_name:
target_plugin_id, target_api_name = normalized_api_name.rsplit(".", 1)
try:
supervisor = self._get_supervisor_for_plugin(target_plugin_id)
except RuntimeError as exc:
return None, None, str(exc)
if supervisor is None:
return None, None, f"未找到 API 提供方插件: {target_plugin_id}"
entries = supervisor.api_registry.get_apis(
plugin_id=target_plugin_id,
name=target_api_name,
version=normalized_version,
enabled_only=False,
)
visible_enabled_entries = [
entry
for entry in entries
if self._is_api_visible_to_plugin(entry, caller_plugin_id) and entry.enabled
]
visible_disabled_entries = [
entry
for entry in entries
if self._is_api_visible_to_plugin(entry, caller_plugin_id) and not entry.enabled
]
if len(visible_enabled_entries) == 1:
return supervisor, visible_enabled_entries[0], None
if len(visible_enabled_entries) > 1:
return None, None, f"API {normalized_api_name} 存在多个版本,请显式指定 version"
if visible_disabled_entries:
if len(visible_disabled_entries) == 1:
return None, None, self._build_api_unavailable_error(visible_disabled_entries[0])
return None, None, f"API {normalized_api_name} 存在多个已下线版本,请显式指定 version"
if any(not self._is_api_visible_to_plugin(entry, caller_plugin_id) for entry in entries):
return None, None, f"API {normalized_api_name} 未公开,禁止跨插件调用"
if normalized_version:
return None, None, f"未找到版本为 {normalized_version} 的 API: {normalized_api_name}"
return None, None, f"未找到 API: {normalized_api_name}"
visible_enabled_matches: List[tuple["PluginSupervisor", "APIEntry"]] = []
visible_disabled_matches: List[tuple["PluginSupervisor", "APIEntry"]] = []
hidden_match_exists = False
for supervisor in self.supervisors:
for entry in supervisor.api_registry.get_apis(
name=normalized_api_name,
version=normalized_version,
enabled_only=False,
):
if self._is_api_visible_to_plugin(entry, caller_plugin_id):
if entry.enabled:
visible_enabled_matches.append((supervisor, entry))
else:
visible_disabled_matches.append((supervisor, entry))
else:
hidden_match_exists = True
if len(visible_enabled_matches) == 1:
return visible_enabled_matches[0][0], visible_enabled_matches[0][1], None
if len(visible_enabled_matches) > 1:
return None, None, f"API 名称不唯一: {normalized_api_name},请使用 plugin_id.api_name 或显式指定 version"
if visible_disabled_matches:
if len(visible_disabled_matches) == 1:
return None, None, self._build_api_unavailable_error(visible_disabled_matches[0][1])
return None, None, f"API {normalized_api_name} 存在多个已下线版本,请使用 plugin_id.api_name@version"
if hidden_match_exists:
return None, None, f"API {normalized_api_name} 未公开,禁止跨插件调用"
if normalized_version:
return None, None, f"未找到版本为 {normalized_version} 的 API: {normalized_api_name}"
return None, None, f"未找到 API: {normalized_api_name}"
def _resolve_api_toggle_target(
self: _RuntimeComponentManagerProtocol,
name: str,
version: str = "",
) -> tuple[Optional["PluginSupervisor"], Optional["APIEntry"], Optional[str]]:
"""解析需要启用或禁用的 API 组件。
Args:
name: API 名称,支持 ``plugin_id.api_name`` 或唯一短名。
version: 可选的 API 版本。
Returns:
tuple[Optional[PluginSupervisor], Optional[APIEntry], Optional[str]]:
解析成功时返回 ``(监督器, API 条目, None)``,失败时返回错误信息。
"""
normalized_name, normalized_version = self._normalize_api_reference(name, version)
if not normalized_name:
return None, None, "缺少必要参数 name"
if "." in normalized_name:
plugin_id, api_name = normalized_name.rsplit(".", 1)
try:
supervisor = self._get_supervisor_for_plugin(plugin_id)
except RuntimeError as exc:
return None, None, str(exc)
if supervisor is None:
return None, None, f"未找到 API 提供方插件: {plugin_id}"
entries = supervisor.api_registry.get_apis(
plugin_id=plugin_id,
name=api_name,
version=normalized_version,
enabled_only=False,
)
if len(entries) == 1:
return supervisor, entries[0], None
if entries:
return None, None, f"API {normalized_name} 存在多个版本,请显式指定 version"
return None, None, f"未找到 API: {normalized_name}"
matches: List[tuple["PluginSupervisor", "APIEntry"]] = []
for supervisor in self.supervisors:
matches.extend(
(supervisor, entry)
for entry in supervisor.api_registry.get_apis(
name=normalized_name,
version=normalized_version,
enabled_only=False,
)
)
if len(matches) == 1:
return matches[0][0], matches[0][1], None
if len(matches) > 1:
return None, None, f"API 名称不唯一: {normalized_name},请使用 plugin_id.api_name 或显式指定 version"
return None, None, f"未找到 API: {normalized_name}"
async def _cap_component_get_all_plugins(
self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any]
) -> Any:
@@ -46,6 +340,10 @@ class RuntimeComponentCapabilityMixin:
}
for component in comps
]
components_list.extend(
self._serialize_api_component_entry(entry)
for entry in sv.api_registry.get_apis(plugin_id=pid, enabled_only=False)
)
result[pid] = {
"name": pid,
"version": reg.plugin_version,
@@ -96,30 +394,35 @@ class RuntimeComponentCapabilityMixin:
def _resolve_component_toggle_target(
self: _RuntimeComponentManagerProtocol, name: str, component_type: str
) -> tuple[Optional["RegisteredComponent"], Optional[str]]:
short_name_matches: List["RegisteredComponent"] = []
) -> tuple[Optional["ComponentEntry"], Optional[str]]:
normalized_component_type = self._normalize_component_type(component_type)
short_name_matches: List["ComponentEntry"] = []
for sv in self.supervisors:
comp = sv.component_registry.get_component(name)
if comp is not None and comp.component_type == component_type:
if comp is not None and comp.component_type == normalized_component_type:
return comp, None
short_name_matches.extend(
candidate
for candidate in sv.component_registry.get_components_by_type(component_type, enabled_only=False)
for candidate in sv.component_registry.get_components_by_type(
normalized_component_type,
enabled_only=False,
)
if candidate.name == name
)
if len(short_name_matches) == 1:
return short_name_matches[0], None
if len(short_name_matches) > 1:
return None, f"组件名不唯一: {name} ({component_type}),请使用完整名 plugin_id.component_name"
return None, f"未找到组件: {name} ({component_type})"
return None, f"组件名不唯一: {name} ({normalized_component_type}),请使用完整名 plugin_id.component_name"
return None, f"未找到组件: {name} ({normalized_component_type})"
async def _cap_component_enable(
self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any]
) -> Any:
name: str = args.get("name", "")
component_type: str = args.get("component_type", "")
version: str = args.get("version", "")
scope: str = args.get("scope", "global")
stream_id: str = args.get("stream_id", "")
if not name or not component_type:
@@ -127,6 +430,13 @@ class RuntimeComponentCapabilityMixin:
if scope != "global" or stream_id:
return {"success": False, "error": "当前仅支持全局组件启用,不支持 scope/stream_id 定位"}
if self._is_api_component_type(component_type):
supervisor, api_entry, error = self._resolve_api_toggle_target(name, version)
if supervisor is None or api_entry is None:
return {"success": False, "error": error or f"未找到 API: {name}"}
supervisor.api_registry.toggle_api_status(api_entry.registry_key, True)
return {"success": True}
comp, error = self._resolve_component_toggle_target(name, component_type)
if comp is None:
return {"success": False, "error": error or f"未找到组件: {name} ({component_type})"}
@@ -139,6 +449,7 @@ class RuntimeComponentCapabilityMixin:
) -> Any:
name: str = args.get("name", "")
component_type: str = args.get("component_type", "")
version: str = args.get("version", "")
scope: str = args.get("scope", "global")
stream_id: str = args.get("stream_id", "")
if not name or not component_type:
@@ -146,6 +457,13 @@ class RuntimeComponentCapabilityMixin:
if scope != "global" or stream_id:
return {"success": False, "error": "当前仅支持全局组件禁用,不支持 scope/stream_id 定位"}
if self._is_api_component_type(component_type):
supervisor, api_entry, error = self._resolve_api_toggle_target(name, version)
if supervisor is None or api_entry is None:
return {"success": False, "error": error or f"未找到 API: {name}"}
supervisor.api_registry.toggle_api_status(api_entry.registry_key, False)
return {"success": True}
comp, error = self._resolve_component_toggle_target(name, component_type)
if comp is None:
return {"success": False, "error": error or f"未找到组件: {name} ({component_type})"}
@@ -168,33 +486,14 @@ class RuntimeComponentCapabilityMixin:
return {"success": False, "error": f"检测到重复插件 ID拒绝热重载: {details}"}
try:
registered_supervisor = self._get_supervisor_for_plugin(plugin_name)
except RuntimeError as exc:
return {"success": False, "error": str(exc)}
loaded = await self.load_plugin_globally(plugin_name, reason=f"load {plugin_name}")
except Exception as e:
logger.error(f"[cap.component.load_plugin] 热重载失败: {e}")
return {"success": False, "error": str(e)}
if registered_supervisor is not None:
try:
reloaded = await registered_supervisor.reload_plugins(reason=f"load {plugin_name}")
if reloaded:
return {"success": True, "count": 1}
return {"success": False, "error": f"插件 {plugin_name} 热重载失败,已回滚"}
except Exception as e:
logger.error(f"[cap.component.load_plugin] 热重载失败: {e}")
return {"success": False, "error": str(e)}
for sv in self.supervisors:
for pdir in sv._plugin_dirs:
if (pdir / plugin_name).is_dir():
try:
reloaded = await sv.reload_plugins(reason=f"load {plugin_name}")
if reloaded:
return {"success": True, "count": 1}
return {"success": False, "error": f"插件 {plugin_name} 热重载失败,已回滚"}
except Exception as e:
logger.error(f"[cap.component.load_plugin] 热重载失败: {e}")
return {"success": False, "error": str(e)}
return {"success": False, "error": f"未找到插件: {plugin_name}"}
if loaded:
return {"success": True, "count": 1}
return {"success": False, "error": f"插件 {plugin_name} 热重载失败"}
async def _cap_component_unload_plugin(
self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any]
@@ -216,17 +515,204 @@ class RuntimeComponentCapabilityMixin:
return {"success": False, "error": f"检测到重复插件 ID拒绝热重载: {details}"}
try:
sv = self._get_supervisor_for_plugin(plugin_name)
reloaded = await self.reload_plugins_globally([plugin_name], reason=f"reload {plugin_name}")
except Exception as e:
logger.error(f"[cap.component.reload_plugin] 热重载失败: {e}")
return {"success": False, "error": str(e)}
if reloaded:
return {"success": True}
return {"success": False, "error": f"插件 {plugin_name} 热重载失败"}
async def _cap_api_call(
self: _RuntimeComponentManagerProtocol,
plugin_id: str,
capability: str,
args: Dict[str, Any],
) -> Any:
"""调用其他插件公开的 API。
Args:
plugin_id: 当前调用方插件 ID。
capability: 能力名称。
args: 能力参数。
Returns:
Any: API 调用结果。
"""
del capability
api_name = str(args.get("api_name", "") or "").strip()
version = str(args.get("version", "") or "").strip()
api_args = args.get("args", {})
if not isinstance(api_args, dict):
return {"success": False, "error": "参数 args 必须为字典"}
supervisor, entry, error = self._resolve_api_target(plugin_id, api_name, version)
if supervisor is None or entry is None:
return {"success": False, "error": error or "API 解析失败"}
invoke_args = dict(api_args)
if entry.dynamic:
invoke_args.setdefault("__maibot_api_name__", entry.name)
invoke_args.setdefault("__maibot_api_full_name__", entry.full_name)
invoke_args.setdefault("__maibot_api_version__", entry.version)
try:
response = await supervisor.invoke_api(
plugin_id=entry.plugin_id,
component_name=entry.handler_name,
args=invoke_args,
timeout_ms=30000,
)
except Exception as exc:
logger.error(f"[cap.api.call] 调用 API {entry.full_name} 失败: {exc}", exc_info=True)
return {"success": False, "error": str(exc)}
if response.error:
return {"success": False, "error": response.error.get("message", "API 调用失败")}
payload = response.payload if isinstance(response.payload, dict) else {}
if not bool(payload.get("success", False)):
result = payload.get("result")
return {"success": False, "error": "" if result is None else str(result)}
return {"success": True, "result": payload.get("result")}
async def _cap_api_get(
self: _RuntimeComponentManagerProtocol,
plugin_id: str,
capability: str,
args: Dict[str, Any],
) -> Any:
"""获取当前插件可见的单个 API 元信息。
Args:
plugin_id: 当前调用方插件 ID。
capability: 能力名称。
args: 能力参数。
Returns:
Any: API 元信息或 ``None``。
"""
del capability
api_name = str(args.get("api_name", "") or "").strip()
version = str(args.get("version", "") or "").strip()
if not api_name:
return {"success": False, "error": "缺少必要参数 api_name"}
supervisor, entry, _error = self._resolve_api_target(plugin_id, api_name, version)
if supervisor is None or entry is None:
return {"success": True, "api": None}
return {"success": True, "api": self._serialize_api_entry(entry)}
async def _cap_api_list(
self: _RuntimeComponentManagerProtocol,
plugin_id: str,
capability: str,
args: Dict[str, Any],
) -> Any:
"""列出当前插件可见的 API 列表。
Args:
plugin_id: 当前调用方插件 ID。
capability: 能力名称。
args: 能力参数。
Returns:
Any: API 元信息列表。
"""
del capability
target_plugin_id = str(args.get("plugin_id", "") or "").strip()
api_name, version = self._normalize_api_reference(
str(args.get("api_name", args.get("name", "")) or ""),
str(args.get("version", "") or ""),
)
apis: List[Dict[str, Any]] = []
for supervisor in self.supervisors:
apis.extend(
self._serialize_api_entry(entry)
for entry in supervisor.api_registry.get_apis(
plugin_id=target_plugin_id or None,
name=api_name,
version=version,
enabled_only=True,
)
if self._is_api_visible_to_plugin(entry, plugin_id)
)
apis.sort(key=lambda item: (str(item["plugin_id"]), str(item["name"]), str(item["version"])))
return {"success": True, "apis": apis}
async def _cap_api_replace_dynamic(
self: _RuntimeComponentManagerProtocol,
plugin_id: str,
capability: str,
args: Dict[str, Any],
) -> Any:
"""替换插件自行维护的动态 API 列表。"""
del capability
raw_apis = args.get("apis", [])
offline_reason = str(args.get("offline_reason", "") or "").strip() or "动态 API 已下线"
if not isinstance(raw_apis, list):
return {"success": False, "error": "参数 apis 必须为列表"}
try:
supervisor = self._get_supervisor_for_plugin(plugin_id)
except RuntimeError as exc:
return {"success": False, "error": str(exc)}
if sv is not None:
try:
reloaded = await sv.reload_plugins(reason=f"reload {plugin_name}")
if reloaded:
return {"success": True}
return {"success": False, "error": f"插件 {plugin_name} 热重载失败,已回滚"}
except Exception as e:
logger.error(f"[cap.component.reload_plugin] 热重载失败: {e}")
return {"success": False, "error": str(e)}
return {"success": False, "error": f"未找到插件: {plugin_name}"}
if supervisor is None:
return {"success": False, "error": f"未找到插件: {plugin_id}"}
normalized_components: List[Dict[str, Any]] = []
seen_registry_keys: set[str] = set()
for index, raw_api in enumerate(raw_apis):
if not isinstance(raw_api, dict):
return {"success": False, "error": f"apis[{index}] 必须为字典"}
api_name = str(raw_api.get("name", "") or "").strip()
component_type = str(raw_api.get("component_type", raw_api.get("type", "API")) or "").strip()
if not api_name:
return {"success": False, "error": f"apis[{index}] 缺少 name"}
if not self._is_api_component_type(component_type):
return {"success": False, "error": f"apis[{index}] 不是 API 组件"}
metadata = raw_api.get("metadata", {}) if isinstance(raw_api.get("metadata"), dict) else {}
normalized_metadata = dict(metadata)
normalized_metadata["dynamic"] = True
version = str(normalized_metadata.get("version", "1") or "1").strip() or "1"
registry_key = supervisor.api_registry.build_registry_key(plugin_id, api_name, version)
if registry_key in seen_registry_keys:
return {"success": False, "error": f"动态 API 重复声明: {registry_key}"}
seen_registry_keys.add(registry_key)
existing_entry = supervisor.api_registry.get_api(
plugin_id,
api_name,
version=version,
enabled_only=False,
)
if existing_entry is not None and not existing_entry.dynamic:
return {"success": False, "error": f"动态 API 不能覆盖静态 API: {registry_key}"}
normalized_components.append(
{
"name": api_name,
"component_type": "API",
"metadata": normalized_metadata,
}
)
registered_count, offlined_count = supervisor.api_registry.replace_plugin_dynamic_apis(
plugin_id,
normalized_components,
offline_reason=offline_reason,
)
return {
"success": True,
"count": registered_count,
"offlined": offlined_count,
}

View File

@@ -238,14 +238,14 @@ class RuntimeCoreCapabilityMixin:
return {"success": False, "value": None, "error": str(e)}
async def _cap_config_get_plugin(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
from src.core.component_registry import component_registry as core_registry
from src.plugin_runtime.component_query import component_query_service
plugin_name: str = args.get("plugin_name", plugin_id)
key: str = args.get("key", "")
default = args.get("default")
try:
config = core_registry.get_plugin_config(plugin_name)
config = component_query_service.get_plugin_config(plugin_name)
if config is None:
return {"success": False, "value": default, "error": f"未找到插件 {plugin_name} 的配置"}
@@ -258,11 +258,11 @@ class RuntimeCoreCapabilityMixin:
return {"success": False, "value": default, "error": str(e)}
async def _cap_config_get_all(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
from src.core.component_registry import component_registry as core_registry
from src.plugin_runtime.component_query import component_query_service
plugin_name: str = args.get("plugin_name", plugin_id)
try:
config = core_registry.get_plugin_config(plugin_name)
config = component_query_service.get_plugin_config(plugin_name)
if config is None:
return {"success": True, "value": {}}
return {"success": True, "value": config}

View File

@@ -648,10 +648,10 @@ class RuntimeDataCapabilityMixin:
return {"success": False, "error": str(e)}
async def _cap_tool_get_definitions(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
from src.core.component_registry import component_registry as core_registry
from src.plugin_runtime.component_query import component_query_service
try:
tools = core_registry.get_llm_available_tools()
tools = component_query_service.get_llm_available_tools()
return {
"success": True,
"tools": [{"name": name, "definition": info.get_llm_definition()} for name, info in tools.items()],

View File

@@ -1,6 +1,7 @@
from typing import TYPE_CHECKING
from src.common.logger import get_logger
from src.plugin_runtime.host.capability_service import CapabilityImpl
from src.plugin_runtime.host.supervisor import PluginSupervisor
if TYPE_CHECKING:
@@ -13,66 +14,80 @@ def register_capability_impls(manager: "PluginRuntimeManager", supervisor: Plugi
"""向指定 Supervisor 注册主程序提供的能力实现。"""
cap_service = supervisor.capability_service
cap_service.register_capability("send.text", manager._cap_send_text)
cap_service.register_capability("send.emoji", manager._cap_send_emoji)
cap_service.register_capability("send.image", manager._cap_send_image)
cap_service.register_capability("send.command", manager._cap_send_command)
cap_service.register_capability("send.custom", manager._cap_send_custom)
def _register(name: str, impl: CapabilityImpl) -> None:
"""注册单个能力实现。
cap_service.register_capability("llm.generate", manager._cap_llm_generate)
cap_service.register_capability("llm.generate_with_tools", manager._cap_llm_generate_with_tools)
cap_service.register_capability("llm.get_available_models", manager._cap_llm_get_available_models)
Args:
name: 能力名称。
impl: 能力实现函数。
"""
cap_service.register_capability(name, impl)
cap_service.register_capability("config.get", manager._cap_config_get)
cap_service.register_capability("config.get_plugin", manager._cap_config_get_plugin)
cap_service.register_capability("config.get_all", manager._cap_config_get_all)
_register("send.text", manager._cap_send_text)
_register("send.emoji", manager._cap_send_emoji)
_register("send.image", manager._cap_send_image)
_register("send.command", manager._cap_send_command)
_register("send.custom", manager._cap_send_custom)
cap_service.register_capability("database.query", manager._cap_database_query)
cap_service.register_capability("database.save", manager._cap_database_save)
cap_service.register_capability("database.get", manager._cap_database_get)
cap_service.register_capability("database.delete", manager._cap_database_delete)
cap_service.register_capability("database.count", manager._cap_database_count)
_register("llm.generate", manager._cap_llm_generate)
_register("llm.generate_with_tools", manager._cap_llm_generate_with_tools)
_register("llm.get_available_models", manager._cap_llm_get_available_models)
cap_service.register_capability("chat.get_all_streams", manager._cap_chat_get_all_streams)
cap_service.register_capability("chat.get_group_streams", manager._cap_chat_get_group_streams)
cap_service.register_capability("chat.get_private_streams", manager._cap_chat_get_private_streams)
cap_service.register_capability("chat.get_stream_by_group_id", manager._cap_chat_get_stream_by_group_id)
cap_service.register_capability("chat.get_stream_by_user_id", manager._cap_chat_get_stream_by_user_id)
_register("config.get", manager._cap_config_get)
_register("config.get_plugin", manager._cap_config_get_plugin)
_register("config.get_all", manager._cap_config_get_all)
cap_service.register_capability("message.get_by_time", manager._cap_message_get_by_time)
cap_service.register_capability("message.get_by_time_in_chat", manager._cap_message_get_by_time_in_chat)
cap_service.register_capability("message.get_recent", manager._cap_message_get_recent)
cap_service.register_capability("message.count_new", manager._cap_message_count_new)
cap_service.register_capability("message.build_readable", manager._cap_message_build_readable)
_register("database.query", manager._cap_database_query)
_register("database.save", manager._cap_database_save)
_register("database.get", manager._cap_database_get)
_register("database.delete", manager._cap_database_delete)
_register("database.count", manager._cap_database_count)
cap_service.register_capability("person.get_id", manager._cap_person_get_id)
cap_service.register_capability("person.get_value", manager._cap_person_get_value)
cap_service.register_capability("person.get_id_by_name", manager._cap_person_get_id_by_name)
_register("chat.get_all_streams", manager._cap_chat_get_all_streams)
_register("chat.get_group_streams", manager._cap_chat_get_group_streams)
_register("chat.get_private_streams", manager._cap_chat_get_private_streams)
_register("chat.get_stream_by_group_id", manager._cap_chat_get_stream_by_group_id)
_register("chat.get_stream_by_user_id", manager._cap_chat_get_stream_by_user_id)
cap_service.register_capability("emoji.get_by_description", manager._cap_emoji_get_by_description)
cap_service.register_capability("emoji.get_random", manager._cap_emoji_get_random)
cap_service.register_capability("emoji.get_count", manager._cap_emoji_get_count)
cap_service.register_capability("emoji.get_emotions", manager._cap_emoji_get_emotions)
cap_service.register_capability("emoji.get_all", manager._cap_emoji_get_all)
cap_service.register_capability("emoji.get_info", manager._cap_emoji_get_info)
cap_service.register_capability("emoji.register", manager._cap_emoji_register)
cap_service.register_capability("emoji.delete", manager._cap_emoji_delete)
_register("message.get_by_time", manager._cap_message_get_by_time)
_register("message.get_by_time_in_chat", manager._cap_message_get_by_time_in_chat)
_register("message.get_recent", manager._cap_message_get_recent)
_register("message.count_new", manager._cap_message_count_new)
_register("message.build_readable", manager._cap_message_build_readable)
cap_service.register_capability("frequency.get_current_talk_value", manager._cap_frequency_get_current_talk_value)
cap_service.register_capability("frequency.set_adjust", manager._cap_frequency_set_adjust)
cap_service.register_capability("frequency.get_adjust", manager._cap_frequency_get_adjust)
_register("person.get_id", manager._cap_person_get_id)
_register("person.get_value", manager._cap_person_get_value)
_register("person.get_id_by_name", manager._cap_person_get_id_by_name)
cap_service.register_capability("tool.get_definitions", manager._cap_tool_get_definitions)
_register("emoji.get_by_description", manager._cap_emoji_get_by_description)
_register("emoji.get_random", manager._cap_emoji_get_random)
_register("emoji.get_count", manager._cap_emoji_get_count)
_register("emoji.get_emotions", manager._cap_emoji_get_emotions)
_register("emoji.get_all", manager._cap_emoji_get_all)
_register("emoji.get_info", manager._cap_emoji_get_info)
_register("emoji.register", manager._cap_emoji_register)
_register("emoji.delete", manager._cap_emoji_delete)
cap_service.register_capability("component.get_all_plugins", manager._cap_component_get_all_plugins)
cap_service.register_capability("component.get_plugin_info", manager._cap_component_get_plugin_info)
cap_service.register_capability("component.list_loaded_plugins", manager._cap_component_list_loaded_plugins)
cap_service.register_capability("component.list_registered_plugins", manager._cap_component_list_registered_plugins)
cap_service.register_capability("component.enable", manager._cap_component_enable)
cap_service.register_capability("component.disable", manager._cap_component_disable)
cap_service.register_capability("component.load_plugin", manager._cap_component_load_plugin)
cap_service.register_capability("component.unload_plugin", manager._cap_component_unload_plugin)
cap_service.register_capability("component.reload_plugin", manager._cap_component_reload_plugin)
_register("frequency.get_current_talk_value", manager._cap_frequency_get_current_talk_value)
_register("frequency.set_adjust", manager._cap_frequency_set_adjust)
_register("frequency.get_adjust", manager._cap_frequency_get_adjust)
cap_service.register_capability("knowledge.search", manager._cap_knowledge_search)
_register("tool.get_definitions", manager._cap_tool_get_definitions)
_register("api.call", manager._cap_api_call)
_register("api.get", manager._cap_api_get)
_register("api.list", manager._cap_api_list)
_register("api.replace_dynamic", manager._cap_api_replace_dynamic)
_register("component.get_all_plugins", manager._cap_component_get_all_plugins)
_register("component.get_plugin_info", manager._cap_component_get_plugin_info)
_register("component.list_loaded_plugins", manager._cap_component_list_loaded_plugins)
_register("component.list_registered_plugins", manager._cap_component_list_registered_plugins)
_register("component.enable", manager._cap_component_enable)
_register("component.disable", manager._cap_component_disable)
_register("component.load_plugin", manager._cap_component_load_plugin)
_register("component.unload_plugin", manager._cap_component_unload_plugin)
_register("component.reload_plugin", manager._cap_component_reload_plugin)
_register("knowledge.search", manager._cap_knowledge_search)
logger.debug("已注册全部主程序能力实现")

View File

@@ -0,0 +1,709 @@
"""插件运行时统一组件查询服务。
该模块统一从插件运行时的 Host ComponentRegistry 中聚合只读视图,
供 HFC/PFC、Planner、ToolExecutor 和运行时能力层查询与调用。
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional, Tuple
from src.common.logger import get_logger
from src.core.types import ActionActivationType, ActionInfo, CommandInfo, ComponentInfo, ComponentType, ToolInfo
from src.llm_models.payload_content.tool_option import ToolParamType
if TYPE_CHECKING:
from src.plugin_runtime.host.component_registry import ActionEntry, CommandEntry, ComponentEntry, ToolEntry
from src.plugin_runtime.host.supervisor import PluginSupervisor
from src.plugin_runtime.integration import PluginRuntimeManager
logger = get_logger("plugin_runtime.component_query")
ActionExecutor = Callable[..., Awaitable[Any]]
CommandExecutor = Callable[..., Awaitable[Tuple[bool, Optional[str], bool]]]
ToolExecutor = Callable[..., Awaitable[Any]]
_HOST_COMPONENT_TYPE_MAP: Dict[ComponentType, str] = {
ComponentType.ACTION: "ACTION",
ComponentType.COMMAND: "COMMAND",
ComponentType.TOOL: "TOOL",
}
_TOOL_PARAM_TYPE_MAP: Dict[str, ToolParamType] = {
"string": ToolParamType.STRING,
"integer": ToolParamType.INTEGER,
"float": ToolParamType.FLOAT,
"boolean": ToolParamType.BOOLEAN,
"bool": ToolParamType.BOOLEAN,
}
class ComponentQueryService:
"""插件运行时统一组件查询服务。
该对象不维护独立状态,只读取插件系统中的注册结果。
所有注册、删除、配置写入等写操作都被显式禁用。
"""
@staticmethod
def _get_runtime_manager() -> "PluginRuntimeManager":
"""获取插件运行时管理器单例。
Returns:
PluginRuntimeManager: 当前全局插件运行时管理器。
"""
from src.plugin_runtime.integration import get_plugin_runtime_manager
return get_plugin_runtime_manager()
def _iter_supervisors(self) -> list["PluginSupervisor"]:
"""获取当前所有活跃的插件运行时监督器。
Returns:
list[PluginSupervisor]: 当前运行中的监督器列表。
"""
runtime_manager = self._get_runtime_manager()
return list(runtime_manager.supervisors)
def _iter_component_entries(
self,
component_type: ComponentType,
*,
enabled_only: bool = True,
) -> list[tuple["PluginSupervisor", "ComponentEntry"]]:
"""遍历指定类型的全部组件条目。
Args:
component_type: 目标组件类型。
enabled_only: 是否仅返回启用状态的组件。
Returns:
list[tuple[PluginSupervisor, ComponentEntry]]: ``(监督器, 组件条目)`` 列表。
"""
host_component_type = _HOST_COMPONENT_TYPE_MAP.get(component_type)
if host_component_type is None:
return []
collected_entries: list[tuple["PluginSupervisor", "ComponentEntry"]] = []
for supervisor in self._iter_supervisors():
for component in supervisor.component_registry.get_components_by_type(
host_component_type,
enabled_only=enabled_only,
):
collected_entries.append((supervisor, component))
return collected_entries
@staticmethod
def _coerce_action_activation_type(raw_value: Any) -> ActionActivationType:
"""规范化动作激活类型。
Args:
raw_value: 原始激活类型值。
Returns:
ActionActivationType: 规范化后的激活类型枚举。
"""
normalized_value = str(raw_value or "").strip().lower()
if normalized_value == ActionActivationType.NEVER.value:
return ActionActivationType.NEVER
if normalized_value == ActionActivationType.RANDOM.value:
return ActionActivationType.RANDOM
if normalized_value == ActionActivationType.KEYWORD.value:
return ActionActivationType.KEYWORD
return ActionActivationType.ALWAYS
@staticmethod
def _coerce_float(value: Any, default: float = 0.0) -> float:
"""将任意值安全转换为浮点数。
Args:
value: 待转换的输入值。
default: 转换失败时返回的默认值。
Returns:
float: 转换后的浮点结果。
"""
try:
return float(value)
except (TypeError, ValueError):
return default
@staticmethod
def _build_action_info(entry: "ActionEntry") -> ActionInfo:
"""将运行时 Action 条目转换为核心动作信息。
Args:
entry: 插件运行时中的 Action 条目。
Returns:
ActionInfo: 供核心 Planner 使用的动作信息。
"""
metadata = dict(entry.metadata)
raw_action_parameters = metadata.get("action_parameters")
action_parameters = (
{
str(param_name): str(param_description)
for param_name, param_description in raw_action_parameters.items()
}
if isinstance(raw_action_parameters, dict)
else {}
)
action_require = [
str(item)
for item in (metadata.get("action_require") or [])
if item is not None and str(item).strip()
]
associated_types = [
str(item)
for item in (metadata.get("associated_types") or [])
if item is not None and str(item).strip()
]
activation_keywords = [
str(item)
for item in (metadata.get("activation_keywords") or [])
if item is not None and str(item).strip()
]
return ActionInfo(
name=entry.name,
component_type=ComponentType.ACTION,
description=str(metadata.get("description", "") or ""),
enabled=bool(entry.enabled),
plugin_name=entry.plugin_id,
metadata=metadata,
action_parameters=action_parameters,
action_require=action_require,
associated_types=associated_types,
activation_type=ComponentQueryService._coerce_action_activation_type(metadata.get("activation_type")),
random_activation_probability=ComponentQueryService._coerce_float(
metadata.get("activation_probability"),
0.0,
),
activation_keywords=activation_keywords,
parallel_action=bool(metadata.get("parallel_action", False)),
)
@staticmethod
def _build_command_info(entry: "CommandEntry") -> CommandInfo:
"""将运行时 Command 条目转换为核心命令信息。
Args:
entry: 插件运行时中的 Command 条目。
Returns:
CommandInfo: 供核心命令链使用的命令信息。
"""
metadata = dict(entry.metadata)
return CommandInfo(
name=entry.name,
component_type=ComponentType.COMMAND,
description=str(metadata.get("description", "") or ""),
enabled=bool(entry.enabled),
plugin_name=entry.plugin_id,
metadata=metadata,
command_pattern=str(metadata.get("command_pattern", "") or ""),
)
@staticmethod
def _coerce_tool_param_type(raw_value: Any) -> ToolParamType:
"""规范化工具参数类型。
Args:
raw_value: 原始工具参数类型值。
Returns:
ToolParamType: 规范化后的工具参数类型。
"""
normalized_value = str(raw_value or "").strip().lower()
return _TOOL_PARAM_TYPE_MAP.get(normalized_value, ToolParamType.STRING)
@staticmethod
def _build_tool_parameters(entry: "ToolEntry") -> list[tuple[str, ToolParamType, str, bool, list[str] | None]]:
"""将运行时工具参数元数据转换为核心 ToolInfo 参数列表。
Args:
entry: 插件运行时中的 Tool 条目。
Returns:
list[tuple[str, ToolParamType, str, bool, list[str] | None]]: 转换后的参数列表。
"""
structured_parameters = entry.parameters if isinstance(entry.parameters, list) else []
if not structured_parameters and isinstance(entry.parameters_raw, dict):
structured_parameters = [
{"name": key, **value}
for key, value in entry.parameters_raw.items()
if isinstance(value, dict)
]
normalized_parameters: list[tuple[str, ToolParamType, str, bool, list[str] | None]] = []
for parameter in structured_parameters:
if not isinstance(parameter, dict):
continue
parameter_name = str(parameter.get("name", "") or "").strip()
if not parameter_name:
continue
enum_values = parameter.get("enum")
normalized_enum_values = (
[str(item) for item in enum_values if item is not None]
if isinstance(enum_values, list)
else None
)
normalized_parameters.append(
(
parameter_name,
ComponentQueryService._coerce_tool_param_type(parameter.get("param_type") or parameter.get("type")),
str(parameter.get("description", "") or ""),
bool(parameter.get("required", True)),
normalized_enum_values,
)
)
return normalized_parameters
@staticmethod
def _build_tool_info(entry: "ToolEntry") -> ToolInfo:
"""将运行时 Tool 条目转换为核心工具信息。
Args:
entry: 插件运行时中的 Tool 条目。
Returns:
ToolInfo: 供 ToolExecutor 与能力层使用的工具信息。
"""
return ToolInfo(
name=entry.name,
component_type=ComponentType.TOOL,
description=entry.description,
enabled=bool(entry.enabled),
plugin_name=entry.plugin_id,
metadata=dict(entry.metadata),
tool_parameters=ComponentQueryService._build_tool_parameters(entry),
tool_description=entry.description,
)
@staticmethod
def _log_duplicate_component(component_type: ComponentType, component_name: str) -> None:
"""记录重复组件名称冲突。
Args:
component_type: 组件类型。
component_name: 发生冲突的组件名称。
"""
logger.warning(f"检测到重复{component_type.value}名称 {component_name},将只保留首个匹配项")
def _get_unique_component_entry(
self,
component_type: ComponentType,
name: str,
) -> Optional[tuple["PluginSupervisor", "ComponentEntry"]]:
"""按组件短名解析唯一条目。
Args:
component_type: 目标组件类型。
name: 组件短名。
Returns:
Optional[tuple[PluginSupervisor, ComponentEntry]]: 唯一命中的组件条目。
"""
matched_entries = [
(supervisor, entry)
for supervisor, entry in self._iter_component_entries(component_type)
if entry.name == name
]
if not matched_entries:
return None
if len(matched_entries) > 1:
self._log_duplicate_component(component_type, name)
return matched_entries[0]
def _collect_unique_component_infos(
self,
component_type: ComponentType,
) -> Dict[str, ComponentInfo]:
"""收集某类组件的唯一信息视图。
Args:
component_type: 目标组件类型。
Returns:
Dict[str, ComponentInfo]: 组件名到核心组件信息的映射。
"""
collected_components: Dict[str, ComponentInfo] = {}
for _supervisor, entry in self._iter_component_entries(component_type):
if entry.name in collected_components:
self._log_duplicate_component(component_type, entry.name)
continue
if component_type == ComponentType.ACTION:
collected_components[entry.name] = self._build_action_info(entry) # type: ignore[arg-type]
elif component_type == ComponentType.COMMAND:
collected_components[entry.name] = self._build_command_info(entry) # type: ignore[arg-type]
elif component_type == ComponentType.TOOL:
collected_components[entry.name] = self._build_tool_info(entry) # type: ignore[arg-type]
return collected_components
@staticmethod
def _extract_stream_id_from_action_kwargs(kwargs: Dict[str, Any]) -> str:
"""从旧 ActionManager 参数中提取聊天流 ID。
Args:
kwargs: 旧动作执行器收到的关键字参数。
Returns:
str: 提取出的 ``stream_id``。
"""
chat_stream = kwargs.get("chat_stream")
if chat_stream is not None:
try:
return str(chat_stream.session_id)
except AttributeError:
pass
return str(kwargs.get("stream_id", "") or "")
@staticmethod
def _build_action_executor(supervisor: "PluginSupervisor", plugin_id: str, component_name: str) -> ActionExecutor:
"""构造动作执行 RPC 闭包。
Args:
supervisor: 负责该组件的监督器。
plugin_id: 插件 ID。
component_name: 组件名称。
Returns:
ActionExecutor: 兼容旧 Planner 的异步执行器。
"""
async def _executor(**kwargs: Any) -> tuple[bool, str]:
"""将核心动作调用桥接到插件运行时。
Args:
**kwargs: 旧 ActionManager 传入的上下文参数。
Returns:
tuple[bool, str]: ``(是否成功, 结果说明)``。
"""
invoke_args: Dict[str, Any] = {}
action_data = kwargs.get("action_data")
if isinstance(action_data, dict):
invoke_args.update(action_data)
stream_id = ComponentQueryService._extract_stream_id_from_action_kwargs(kwargs)
invoke_args["action_data"] = action_data if isinstance(action_data, dict) else {}
invoke_args["stream_id"] = stream_id
invoke_args["chat_id"] = stream_id
invoke_args["reasoning"] = str(kwargs.get("action_reasoning", "") or "")
if (thinking_id := kwargs.get("thinking_id")) is not None:
invoke_args["thinking_id"] = str(thinking_id)
if isinstance(kwargs.get("cycle_timers"), dict):
invoke_args["cycle_timers"] = kwargs["cycle_timers"]
if isinstance(kwargs.get("plugin_config"), dict):
invoke_args["plugin_config"] = kwargs["plugin_config"]
if isinstance(kwargs.get("log_prefix"), str):
invoke_args["log_prefix"] = kwargs["log_prefix"]
if isinstance(kwargs.get("shutting_down"), bool):
invoke_args["shutting_down"] = kwargs["shutting_down"]
try:
response = await supervisor.invoke_plugin(
method="plugin.invoke_action",
plugin_id=plugin_id,
component_name=component_name,
args=invoke_args,
timeout_ms=30000,
)
except Exception as exc:
logger.error(f"运行时 Action {plugin_id}.{component_name} 执行失败: {exc}", exc_info=True)
return False, str(exc)
payload = response.payload if isinstance(response.payload, dict) else {}
success = bool(payload.get("success", False))
result = payload.get("result")
if isinstance(result, (list, tuple)):
if len(result) >= 2:
return bool(result[0]), "" if result[1] is None else str(result[1])
if len(result) == 1:
return bool(result[0]), ""
if success:
return True, "" if result is None else str(result)
return False, "" if result is None else str(result)
return _executor
@staticmethod
def _build_command_executor(
supervisor: "PluginSupervisor",
plugin_id: str,
component_name: str,
metadata: Dict[str, Any],
) -> CommandExecutor:
"""构造命令执行 RPC 闭包。
Args:
supervisor: 负责该组件的监督器。
plugin_id: 插件 ID。
component_name: 组件名称。
metadata: 命令组件元数据。
Returns:
CommandExecutor: 兼容旧消息命令链的执行器。
"""
async def _executor(**kwargs: Any) -> tuple[bool, Optional[str], bool]:
"""将核心命令调用桥接到插件运行时。
Args:
**kwargs: 命令执行上下文参数。
Returns:
tuple[bool, Optional[str], bool]: ``(是否成功, 返回文本, 是否拦截后续消息)``。
"""
message = kwargs.get("message")
matched_groups = kwargs.get("matched_groups")
plugin_config = kwargs.get("plugin_config")
invoke_args: Dict[str, Any] = {
"text": str(getattr(message, "processed_plain_text", "") or ""),
"stream_id": str(getattr(message, "session_id", "") or ""),
"matched_groups": matched_groups if isinstance(matched_groups, dict) else {},
}
if isinstance(plugin_config, dict):
invoke_args["plugin_config"] = plugin_config
try:
response = await supervisor.invoke_plugin(
method="plugin.invoke_command",
plugin_id=plugin_id,
component_name=component_name,
args=invoke_args,
timeout_ms=30000,
)
except Exception as exc:
logger.error(f"运行时 Command {plugin_id}.{component_name} 执行失败: {exc}", exc_info=True)
return False, str(exc), True
payload = response.payload if isinstance(response.payload, dict) else {}
success = bool(payload.get("success", False))
result = payload.get("result")
intercept = bool(metadata.get("intercept_message_level", 0))
response_text: Optional[str]
if isinstance(result, (list, tuple)) and len(result) >= 3:
response_text = None if result[1] is None else str(result[1])
intercept = bool(result[2])
else:
response_text = None if result is None else str(result)
return success, response_text, intercept
return _executor
@staticmethod
def _build_tool_executor(supervisor: "PluginSupervisor", plugin_id: str, component_name: str) -> ToolExecutor:
"""构造工具执行 RPC 闭包。
Args:
supervisor: 负责该组件的监督器。
plugin_id: 插件 ID。
component_name: 组件名称。
Returns:
ToolExecutor: 兼容旧 ToolExecutor 的异步执行器。
"""
async def _executor(function_args: Dict[str, Any]) -> Any:
"""将核心工具调用桥接到插件运行时。
Args:
function_args: 工具调用参数。
Returns:
Any: 插件工具返回结果;若结果不是字典,则会包装为 ``{"content": ...}``。
"""
try:
response = await supervisor.invoke_plugin(
method="plugin.invoke_tool",
plugin_id=plugin_id,
component_name=component_name,
args=function_args,
timeout_ms=30000,
)
except Exception as exc:
logger.error(f"运行时 Tool {plugin_id}.{component_name} 执行失败: {exc}", exc_info=True)
return {"content": f"工具 {component_name} 执行失败: {exc}"}
payload = response.payload if isinstance(response.payload, dict) else {}
result = payload.get("result")
if isinstance(result, dict):
return result
return {"content": "" if result is None else str(result)}
return _executor
def get_action_info(self, name: str) -> Optional[ActionInfo]:
"""获取指定动作的信息。
Args:
name: 动作名称。
Returns:
Optional[ActionInfo]: 匹配到的动作信息。
"""
matched_entry = self._get_unique_component_entry(ComponentType.ACTION, name)
if matched_entry is None:
return None
_supervisor, entry = matched_entry
return self._build_action_info(entry) # type: ignore[arg-type]
def get_action_executor(self, name: str) -> Optional[ActionExecutor]:
"""获取指定动作的执行器。
Args:
name: 动作名称。
Returns:
Optional[ActionExecutor]: 运行时 RPC 执行闭包。
"""
matched_entry = self._get_unique_component_entry(ComponentType.ACTION, name)
if matched_entry is None:
return None
supervisor, entry = matched_entry
return self._build_action_executor(supervisor, entry.plugin_id, entry.name)
def get_default_actions(self) -> Dict[str, ActionInfo]:
"""获取当前默认启用的动作集合。
Returns:
Dict[str, ActionInfo]: 动作名到动作信息的映射。
"""
action_infos = self._collect_unique_component_infos(ComponentType.ACTION)
return {name: info for name, info in action_infos.items() if isinstance(info, ActionInfo) and info.enabled}
def find_command_by_text(self, text: str) -> Optional[Tuple[CommandExecutor, dict, CommandInfo]]:
"""根据文本查找匹配的命令。
Args:
text: 待匹配的文本内容。
Returns:
Optional[Tuple[CommandExecutor, dict, CommandInfo]]: 匹配结果。
"""
for supervisor in self._iter_supervisors():
match_result = supervisor.component_registry.find_command_by_text(text)
if match_result is None:
continue
entry, matched_groups = match_result
command_info = self._build_command_info(entry) # type: ignore[arg-type]
command_executor = self._build_command_executor(
supervisor,
entry.plugin_id,
entry.name,
dict(entry.metadata),
)
return command_executor, matched_groups, command_info
return None
def get_tool_info(self, name: str) -> Optional[ToolInfo]:
"""获取指定工具的信息。
Args:
name: 工具名称。
Returns:
Optional[ToolInfo]: 匹配到的工具信息。
"""
matched_entry = self._get_unique_component_entry(ComponentType.TOOL, name)
if matched_entry is None:
return None
_supervisor, entry = matched_entry
return self._build_tool_info(entry) # type: ignore[arg-type]
def get_tool_executor(self, name: str) -> Optional[ToolExecutor]:
"""获取指定工具的执行器。
Args:
name: 工具名称。
Returns:
Optional[ToolExecutor]: 运行时 RPC 执行闭包。
"""
matched_entry = self._get_unique_component_entry(ComponentType.TOOL, name)
if matched_entry is None:
return None
supervisor, entry = matched_entry
return self._build_tool_executor(supervisor, entry.plugin_id, entry.name)
def get_llm_available_tools(self) -> Dict[str, ToolInfo]:
"""获取当前可供 LLM 选择的工具集合。
Returns:
Dict[str, ToolInfo]: 工具名到工具信息的映射。
"""
tool_infos = self._collect_unique_component_infos(ComponentType.TOOL)
return {name: info for name, info in tool_infos.items() if isinstance(info, ToolInfo) and info.enabled}
def get_components_by_type(self, component_type: ComponentType) -> Dict[str, ComponentInfo]:
"""获取某类组件的全部信息。
Args:
component_type: 组件类型。
Returns:
Dict[str, ComponentInfo]: 组件名到组件信息的映射。
"""
return self._collect_unique_component_infos(component_type)
def get_plugin_config(self, plugin_name: str) -> Optional[dict]:
"""读取指定插件的配置文件内容。
Args:
plugin_name: 插件名称。
Returns:
Optional[dict]: 读取成功时返回配置字典;未找到时返回 ``None``。
"""
runtime_manager = self._get_runtime_manager()
try:
supervisor = runtime_manager._get_supervisor_for_plugin(plugin_name)
except RuntimeError as exc:
logger.error(f"读取插件配置失败: {exc}")
return None
if supervisor is None:
return None
try:
return runtime_manager._load_plugin_config_for_supervisor(supervisor, plugin_name)
except Exception as exc:
logger.error(f"读取插件 {plugin_name} 配置失败: {exc}", exc_info=True)
return None
component_query_service = ComponentQueryService()

View File

@@ -0,0 +1,349 @@
"""Host 侧插件 API 动态注册表。"""
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Set, Tuple
from src.common.logger import get_logger
logger = get_logger("plugin_runtime.host.api_registry")
@dataclass(slots=True)
class APIEntry:
"""API 组件条目。"""
name: str
plugin_id: str
description: str = ""
version: str = "1"
public: bool = False
metadata: Dict[str, Any] = field(default_factory=dict)
enabled: bool = True
handler_name: str = ""
dynamic: bool = False
offline_reason: str = ""
disabled_session: Set[str] = field(default_factory=set)
full_name: str = field(init=False)
registry_key: str = field(init=False)
def __post_init__(self) -> None:
"""规范化 API 条目字段。"""
self.name = str(self.name or "").strip()
self.plugin_id = str(self.plugin_id or "").strip()
self.description = str(self.description or "").strip()
self.version = str(self.version or "1").strip() or "1"
self.handler_name = str(self.handler_name or self.name).strip() or self.name
self.offline_reason = str(self.offline_reason or "").strip()
self.full_name = f"{self.plugin_id}.{self.name}"
self.registry_key = APIRegistry.build_registry_key(self.plugin_id, self.name, self.version)
@classmethod
def from_metadata(cls, name: str, plugin_id: str, metadata: Dict[str, Any]) -> "APIEntry":
"""根据 Runner 上报的元数据构造 API 条目。"""
safe_metadata = dict(metadata)
return cls(
name=name,
plugin_id=plugin_id,
description=str(safe_metadata.get("description", "") or ""),
version=str(safe_metadata.get("version", "1") or "1"),
public=bool(safe_metadata.get("public", False)),
metadata=safe_metadata,
enabled=bool(safe_metadata.get("enabled", True)),
handler_name=str(safe_metadata.get("handler_name", name) or name),
dynamic=bool(safe_metadata.get("dynamic", False)),
offline_reason=str(safe_metadata.get("offline_reason", "") or ""),
)
class APIRegistry:
"""Host 侧插件 API 动态注册表。
该注册表不直接面向 Runner而是复用插件组件注册/卸载事件,
维护面向 API 调用场景的专用索引。
"""
def __init__(self) -> None:
"""初始化 API 注册表。"""
self._apis: Dict[str, APIEntry] = {}
self._by_full_name: Dict[str, List[APIEntry]] = {}
self._by_plugin: Dict[str, List[APIEntry]] = {}
self._by_name: Dict[str, List[APIEntry]] = {}
def clear(self) -> None:
"""清空全部 API 注册状态。"""
self._apis.clear()
self._by_full_name.clear()
self._by_plugin.clear()
self._by_name.clear()
@staticmethod
def _is_api_component(component_type: Any) -> bool:
"""判断组件声明是否属于 API。"""
return str(component_type or "").strip().upper() == "API"
@staticmethod
def _normalize_query_version(version: Any) -> str:
"""规范化查询使用的版本字符串。"""
return str(version or "").strip()
@classmethod
def _split_reference(cls, reference: str, version: Any = "") -> Tuple[str, str]:
"""解析可能带 ``@version`` 后缀的 API 引用。"""
normalized_reference = str(reference or "").strip()
normalized_version = cls._normalize_query_version(version)
if normalized_reference and not normalized_version and "@" in normalized_reference:
candidate_reference, candidate_version = normalized_reference.rsplit("@", 1)
candidate_reference = candidate_reference.strip()
candidate_version = candidate_version.strip()
if candidate_reference and candidate_version:
normalized_reference = candidate_reference
normalized_version = candidate_version
return normalized_reference, normalized_version
@staticmethod
def build_registry_key(plugin_id: str, name: str, version: str) -> str:
"""构造 API 注册表唯一键。"""
normalized_full_name = f"{str(plugin_id or '').strip()}.{str(name or '').strip()}"
normalized_version = str(version or "1").strip() or "1"
return f"{normalized_full_name}@{normalized_version}"
@staticmethod
def check_api_enabled(entry: APIEntry, session_id: Optional[str] = None) -> bool:
"""判断 API 条目当前是否处于启用状态。"""
if session_id and session_id in entry.disabled_session:
return False
return entry.enabled
def register_api(self, name: str, plugin_id: str, metadata: Dict[str, Any]) -> bool:
"""注册单个 API 条目。"""
normalized_name = str(name or "").strip()
if not normalized_name:
logger.warning(f"插件 {plugin_id} 存在空 API 名称声明,已忽略")
return False
entry = APIEntry.from_metadata(name=normalized_name, plugin_id=plugin_id, metadata=metadata)
existing_entry = self._apis.get(entry.registry_key)
if existing_entry is not None:
logger.warning(f"API {entry.registry_key} 已存在,覆盖旧条目")
self._remove_entry(existing_entry)
self._apis[entry.registry_key] = entry
self._by_full_name.setdefault(entry.full_name, []).append(entry)
self._by_plugin.setdefault(plugin_id, []).append(entry)
self._by_name.setdefault(entry.name, []).append(entry)
return True
def register_plugin_apis(self, plugin_id: str, components: List[Dict[str, Any]]) -> int:
"""批量注册某个插件声明的全部 API。"""
count = 0
for component in components:
if not self._is_api_component(component.get("component_type")):
continue
if self.register_api(
name=str(component.get("name", "") or ""),
plugin_id=plugin_id,
metadata=component.get("metadata", {}) if isinstance(component.get("metadata"), dict) else {},
):
count += 1
return count
def replace_plugin_dynamic_apis(
self,
plugin_id: str,
components: List[Dict[str, Any]],
*,
offline_reason: str = "动态 API 已下线",
) -> Tuple[int, int]:
"""替换指定插件当前声明的动态 API 集合。"""
normalized_offline_reason = str(offline_reason or "").strip() or "动态 API 已下线"
desired_registry_keys: Set[str] = set()
registered_count = 0
for component in components:
if not self._is_api_component(component.get("component_type")):
continue
metadata = component.get("metadata", {}) if isinstance(component.get("metadata"), dict) else {}
dynamic_metadata = dict(metadata)
dynamic_metadata["dynamic"] = True
dynamic_metadata.pop("offline_reason", None)
entry = APIEntry.from_metadata(
name=str(component.get("name", "") or ""),
plugin_id=plugin_id,
metadata=dynamic_metadata,
)
desired_registry_keys.add(entry.registry_key)
if self.register_api(entry.name, plugin_id, dynamic_metadata):
registered_count += 1
offlined_count = 0
for entry in list(self._by_plugin.get(plugin_id, [])):
if not entry.dynamic or entry.registry_key in desired_registry_keys:
continue
entry.enabled = False
entry.offline_reason = normalized_offline_reason
entry.metadata["offline_reason"] = normalized_offline_reason
offlined_count += 1
return registered_count, offlined_count
def _remove_entry(self, entry: APIEntry) -> None:
"""从全部索引中移除单个 API 条目。"""
self._apis.pop(entry.registry_key, None)
full_name_entries = self._by_full_name.get(entry.full_name)
if full_name_entries is not None:
self._by_full_name[entry.full_name] = [
candidate for candidate in full_name_entries if candidate is not entry
]
if not self._by_full_name[entry.full_name]:
self._by_full_name.pop(entry.full_name, None)
plugin_entries = self._by_plugin.get(entry.plugin_id)
if plugin_entries is not None:
self._by_plugin[entry.plugin_id] = [candidate for candidate in plugin_entries if candidate is not entry]
if not self._by_plugin[entry.plugin_id]:
self._by_plugin.pop(entry.plugin_id, None)
name_entries = self._by_name.get(entry.name)
if name_entries is not None:
self._by_name[entry.name] = [candidate for candidate in name_entries if candidate is not entry]
if not self._by_name[entry.name]:
self._by_name.pop(entry.name, None)
def remove_apis_by_plugin(self, plugin_id: str) -> int:
"""移除某个插件的全部 API。"""
entries = list(self._by_plugin.get(plugin_id, []))
for entry in entries:
self._remove_entry(entry)
return len(entries)
def get_api_by_full_name(
self,
full_name: str,
*,
version: str = "",
enabled_only: bool = True,
session_id: Optional[str] = None,
) -> Optional[APIEntry]:
"""按完整名查询单个 API。"""
normalized_full_name, normalized_version = self._split_reference(full_name, version)
if not normalized_full_name:
return None
if normalized_version:
entry = self._apis.get(f"{normalized_full_name}@{normalized_version}")
if entry is None:
return None
if enabled_only and not self.check_api_enabled(entry, session_id):
return None
return entry
candidates = list(self._by_full_name.get(normalized_full_name, []))
filtered_entries = [
entry
for entry in candidates
if not enabled_only or self.check_api_enabled(entry, session_id)
]
if len(filtered_entries) != 1:
return None
return filtered_entries[0]
def get_api(
self,
plugin_id: str,
name: str,
*,
version: str = "",
enabled_only: bool = True,
session_id: Optional[str] = None,
) -> Optional[APIEntry]:
"""按插件 ID、短名与版本查询单个 API。"""
return self.get_api_by_full_name(
f"{plugin_id}.{name}",
version=version,
enabled_only=enabled_only,
session_id=session_id,
)
def get_apis(
self,
*,
plugin_id: Optional[str] = None,
name: str = "",
version: str = "",
enabled_only: bool = True,
session_id: Optional[str] = None,
) -> List[APIEntry]:
"""查询 API 列表。"""
normalized_name = str(name or "").strip()
normalized_version = self._normalize_query_version(version)
if plugin_id:
candidates = list(self._by_plugin.get(plugin_id, []))
elif normalized_name:
candidates = list(self._by_name.get(normalized_name, []))
else:
candidates = list(self._apis.values())
filtered_entries: List[APIEntry] = []
for entry in candidates:
if plugin_id and entry.plugin_id != plugin_id:
continue
if normalized_name and entry.name != normalized_name:
continue
if normalized_version and entry.version != normalized_version:
continue
if enabled_only and not self.check_api_enabled(entry, session_id):
continue
filtered_entries.append(entry)
filtered_entries.sort(key=lambda entry: (entry.plugin_id, entry.name, entry.version))
return filtered_entries
def toggle_api_status(
self,
full_name: str,
enabled: bool,
*,
version: str = "",
session_id: Optional[str] = None,
) -> bool:
"""设置指定 API 的启用状态。"""
entry = self.get_api_by_full_name(
full_name,
version=version,
enabled_only=False,
session_id=session_id,
)
if entry is None:
return False
if session_id:
if enabled:
entry.disabled_session.discard(session_id)
else:
entry.disabled_session.add(session_id)
else:
entry.enabled = enabled
if enabled:
entry.offline_reason = ""
entry.metadata.pop("offline_reason", None)
return True

View File

@@ -0,0 +1,67 @@
"""授权管理器
负责管理插件的能力授权以及校验
每个插件在 manifest 中声明能力需求Host 启动时签发能力令牌。
"""
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Set, Tuple
_ALWAYS_ALLOWED_CAPABILITIES = frozenset({"api.replace_dynamic"})
@dataclass
class CapabilityPermissionToken:
"""能力令牌"""
plugin_id: str
capabilities: Set[str] = field(default_factory=set)
class AuthorizationManager:
"""授权管理器
管理所有插件的能力令牌,提供授权校验。
"""
def __init__(self) -> None:
self._permission_tokens: Dict[str, CapabilityPermissionToken] = {}
def register_plugin(self, plugin_id: str, capabilities: List[str]) -> CapabilityPermissionToken:
"""为插件签发能力令牌"""
token = CapabilityPermissionToken(plugin_id=plugin_id, capabilities=set(capabilities))
self._permission_tokens[plugin_id] = token
return token
def revoke_permission_token(self, plugin_id: str):
"""移除插件的能力令牌。"""
self._permission_tokens.pop(plugin_id, None)
def clear(self) -> None:
"""清空所有能力令牌。"""
self._permission_tokens.clear()
def check_capability(self, plugin_id: str, capability: str) -> Tuple[bool, str]:
# sourcery skip: assign-if-exp, reintroduce-else, swap-if-else-branches, use-named-expression
"""检查插件是否有权调用某项能力
Returns:
return (bool, str): (是否有此能力, 原因)
"""
if capability in _ALWAYS_ALLOWED_CAPABILITIES:
return True, ""
token = self._permission_tokens.get(plugin_id)
if not token:
return False, f"插件 {plugin_id} 未注册能力令牌"
if capability not in token.capabilities:
return False, f"插件 {plugin_id} 未获授权能力: {capability}"
return True, ""
def get_token(self, plugin_id: str) -> Optional[CapabilityPermissionToken]:
"""获取插件的能力令牌"""
return self._permission_tokens.get(plugin_id)
def list_plugins(self) -> List[str]:
"""列出所有已注册的插件"""
return list(self._permission_tokens.keys())

View File

@@ -4,21 +4,19 @@ Host 端实现的能力服务,处理来自插件的 cap.* 请求。
每个能力方法被注册到 RPC Server接收 Runner 转发的请求并执行实际操作。
"""
from typing import Any, Awaitable, Callable, Dict, List
from typing import Any, Callable, Dict, List, Coroutine, TYPE_CHECKING
from src.common.logger import get_logger
from src.plugin_runtime.host.policy_engine import PolicyEngine
from src.plugin_runtime.protocol.envelope import (
CapabilityRequestPayload,
CapabilityResponsePayload,
Envelope,
)
from src.plugin_runtime.protocol.envelope import CapabilityRequestPayload, CapabilityResponsePayload, Envelope
from src.plugin_runtime.protocol.errors import ErrorCode, RPCError
if TYPE_CHECKING:
from src.plugin_runtime.host.authorization import AuthorizationManager
logger = get_logger("plugin_runtime.host.capability_service")
# 能力实现函数类型: (plugin_id, capability, args) -> result
CapabilityImpl = Callable[[str, str, Dict[str, Any]], Awaitable[Any]]
CapabilityImpl = Callable[[str, str, Dict[str, Any]], Coroutine[Any, Any, Any]]
class CapabilityService:
@@ -31,8 +29,13 @@ class CapabilityService:
4. 执行实际操作并返回结果
"""
def __init__(self, policy_engine: PolicyEngine) -> None:
self._policy = policy_engine
def __init__(self, authorization: "AuthorizationManager") -> None:
"""初始化能力服务。
Args:
authorization: 能力授权管理器。
"""
self._authorization = authorization
# capability_name -> implementation
self._implementations: Dict[str, CapabilityImpl] = {}
@@ -56,46 +59,32 @@ class CapabilityService:
try:
req = CapabilityRequestPayload.model_validate(envelope.payload)
except Exception as e:
return envelope.make_error_response(
ErrorCode.E_BAD_PAYLOAD.value,
f"能力调用 payload 格式错误: {e}",
)
except Exception as exc:
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, f"能力调用 payload 非法: {exc}")
capability = req.capability
args = req.args
# 1. 权限校验
allowed, reason = self._policy.check_capability(plugin_id, capability, envelope.generation)
allowed, reason = self._authorization.check_capability(plugin_id, capability)
if not allowed:
error_code = (
ErrorCode.E_GENERATION_MISMATCH if "generation 不匹配" in reason else ErrorCode.E_CAPABILITY_DENIED
)
return envelope.make_error_response(
error_code.value,
reason,
)
return envelope.make_error_response(ErrorCode.E_CAPABILITY_DENIED.value, reason)
# 2. 查找实现
impl = self._implementations.get(capability)
if impl is None:
return envelope.make_error_response(
ErrorCode.E_METHOD_NOT_ALLOWED.value,
f"未注册的能力: {capability}",
)
return envelope.make_error_response(ErrorCode.E_METHOD_NOT_ALLOWED.value, f"未注册的能力: {capability}")
# 3. 执行
try:
result = await impl(plugin_id, capability, req.args)
result = await impl(plugin_id, capability, args)
resp_payload = CapabilityResponsePayload(success=True, result=result)
return envelope.make_response(payload=resp_payload.model_dump())
except RPCError as e:
return envelope.make_error_response(e.code.value, e.message, e.details)
except Exception as e:
logger.error(f"能力 {capability} 执行异常: {e}", exc_info=True)
return envelope.make_error_response(
ErrorCode.E_CAPABILITY_FAILED.value,
str(e),
)
return envelope.make_error_response(ErrorCode.E_CAPABILITY_FAILED.value, str(e))
def list_capabilities(self) -> List[str]:
"""列出所有已注册的能力"""

View File

@@ -1,7 +1,7 @@
"""Host-side ComponentRegistry
对齐旧系统 component_registry.py 的核心能力:
- 按类型注册组件action / command / tool / event_handler / workflow_step
- 按类型注册组件action / command / tool / event_handler / workflow_handler / message_gateway
- 命名空间 (plugin_id.component_name)
- 命令正则匹配
- 组件启用/禁用
@@ -9,8 +9,10 @@
- 注册统计
"""
from typing import Any, Dict, List, Optional
from enum import Enum
from typing import Any, Dict, List, Optional, Set, TypedDict, Tuple
import contextlib
import re
from src.common.logger import get_logger
@@ -18,8 +20,28 @@ from src.common.logger import get_logger
logger = get_logger("plugin_runtime.host.component_registry")
class RegisteredComponent:
"""已注册的组件条目"""
class ComponentTypes(str, Enum):
ACTION = "ACTION"
COMMAND = "COMMAND"
TOOL = "TOOL"
EVENT_HANDLER = "EVENT_HANDLER"
HOOK_HANDLER = "HOOK_HANDLER"
MESSAGE_GATEWAY = "MESSAGE_GATEWAY"
class StatusDict(TypedDict):
total: int
action: int
command: int
tool: int
event_handler: int
hook_handler: int
message_gateway: int
plugins: int
class ComponentEntry:
"""组件条目"""
__slots__ = (
"name",
@@ -28,31 +50,120 @@ class RegisteredComponent:
"plugin_id",
"metadata",
"enabled",
"_compiled_pattern",
"compiled_pattern",
"disabled_session",
)
def __init__(
self,
name: str,
component_type: str,
plugin_id: str,
metadata: Dict[str, Any],
) -> None:
self.name = name
self.full_name = f"{plugin_id}.{name}"
self.component_type = component_type
self.plugin_id = plugin_id
self.metadata = metadata
self.enabled = metadata.get("enabled", True)
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
self.name: str = name
self.full_name: str = f"{plugin_id}.{name}"
self.component_type: ComponentTypes = ComponentTypes(component_type)
self.plugin_id: str = plugin_id
self.metadata: Dict[str, Any] = metadata
self.enabled: bool = metadata.get("enabled", True)
self.disabled_session: Set[str] = set()
# 预编译命令正则(仅 command 类型)
self._compiled_pattern: Optional[re.Pattern] = None
if component_type == "command":
if pattern := metadata.get("command_pattern", ""):
try:
self._compiled_pattern = re.compile(pattern)
except re.error as e:
logger.warning(f"命令 {self.full_name} 正则编译失败: {e}")
class ActionEntry(ComponentEntry):
"""Action 组件条目"""
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
super().__init__(name, component_type, plugin_id, metadata)
class CommandEntry(ComponentEntry):
"""Command 组件条目"""
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
super().__init__(name, component_type, plugin_id, metadata)
self.aliases: List[str] = metadata.get("aliases", [])
self.compiled_pattern: Optional[re.Pattern] = None
if pattern := metadata.get("command_pattern", ""):
try:
self.compiled_pattern = re.compile(pattern)
except (re.error, TypeError) as e:
logger.warning(f"命令 {self.full_name} 正则编译失败: {e}")
class ToolEntry(ComponentEntry):
"""Tool 组件条目"""
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
self.description: str = metadata.get("description", "")
self.parameters: List[Dict[str, Any]] = metadata.get("parameters", [])
self.parameters_raw: List[Dict[str, Any]] = metadata.get("parameters_raw", [])
super().__init__(name, component_type, plugin_id, metadata)
class EventHandlerEntry(ComponentEntry):
"""EventHandler 组件条目"""
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
self.event_type: str = metadata.get("event_type", "")
self.weight: int = metadata.get("weight", 0)
self.intercept_message: bool = metadata.get("intercept_message", False)
super().__init__(name, component_type, plugin_id, metadata)
class HookHandlerEntry(ComponentEntry):
"""WorkflowHandler 组件条目"""
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
self.stage: str = metadata.get("stage", "")
self.priority: int = metadata.get("priority", 0)
self.blocking: bool = metadata.get("blocking", False)
super().__init__(name, component_type, plugin_id, metadata)
class MessageGatewayEntry(ComponentEntry):
"""MessageGateway 组件条目"""
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
self.route_type: str = self._normalize_route_type(metadata.get("route_type", ""))
self.platform: str = str(metadata.get("platform", "") or "").strip()
self.protocol: str = str(metadata.get("protocol", "") or "").strip()
self.account_id: str = str(metadata.get("account_id", "") or "").strip()
self.scope: str = str(metadata.get("scope", "") or "").strip()
super().__init__(name, component_type, plugin_id, metadata)
@staticmethod
def _normalize_route_type(raw_value: Any) -> str:
"""规范化消息网关路由类型。
Args:
raw_value: 原始路由类型值。
Returns:
str: 规范化后的路由类型。
Raises:
ValueError: 当路由类型不受支持时抛出。
"""
normalized_value = str(raw_value or "").strip().lower()
route_type_aliases = {
"send": "send",
"receive": "receive",
"recv": "receive",
"recive": "receive",
"duplex": "duplex",
}
route_type = route_type_aliases.get(normalized_value)
if route_type is None:
raise ValueError(f"MessageGateway 路由类型不合法: {raw_value}")
return route_type
@property
def supports_send(self) -> bool:
"""返回当前网关是否支持出站。"""
return self.route_type in {"send", "duplex"}
@property
def supports_receive(self) -> bool:
"""返回当前网关是否支持入站。"""
return self.route_type in {"receive", "duplex"}
class ComponentRegistry:
@@ -64,19 +175,32 @@ class ComponentRegistry:
def __init__(self) -> None:
# 全量索引
self._components: Dict[str, RegisteredComponent] = {} # full_name -> comp
self._components: Dict[str, ComponentEntry] = {} # full_name -> comp
# 按类型索引
self._by_type: Dict[str, Dict[str, RegisteredComponent]] = {
"action": {},
"command": {},
"tool": {},
"event_handler": {},
"workflow_step": {},
}
self._by_type: Dict[ComponentTypes, Dict[str, ComponentEntry]] = {
comp_type: {} for comp_type in ComponentTypes
} # component_type -> (full_name -> comp)
# 按插件索引
self._by_plugin: Dict[str, List[RegisteredComponent]] = {}
self._by_plugin: Dict[str, List[ComponentEntry]] = {}
@staticmethod
def _normalize_component_type(component_type: str) -> ComponentTypes:
"""规范化组件类型输入。
Args:
component_type: 原始组件类型字符串。
Returns:
ComponentTypes: 规范化后的组件类型枚举。
Raises:
ValueError: 当组件类型不受支持时抛出。
"""
normalized_value = str(component_type or "").strip().upper()
return ComponentTypes(normalized_value)
def clear(self) -> None:
"""清空全部组件注册状态。"""
@@ -85,47 +209,64 @@ class ComponentRegistry:
type_dict.clear()
self._by_plugin.clear()
# ──── 注册 / 注销 ─────────────────────────────────────────
# ====== 注册 / 注销 ======
def register_component(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> bool:
"""注册单个组件
Args:
name: 组件名称不含插件id前缀
component_type: 组件类型(如 `ACTION`、`COMMAND` 等)
plugin_id: 插件id
metadata: 组件元数据
Returns:
success (bool): 是否成功注册(失败原因通常是组件类型无效)
"""
try:
normalized_type = self._normalize_component_type(component_type)
if normalized_type == ComponentTypes.ACTION:
comp = ActionEntry(name, normalized_type.value, plugin_id, metadata)
elif normalized_type == ComponentTypes.COMMAND:
comp = CommandEntry(name, normalized_type.value, plugin_id, metadata)
elif normalized_type == ComponentTypes.TOOL:
comp = ToolEntry(name, normalized_type.value, plugin_id, metadata)
elif normalized_type == ComponentTypes.EVENT_HANDLER:
comp = EventHandlerEntry(name, normalized_type.value, plugin_id, metadata)
elif normalized_type == ComponentTypes.HOOK_HANDLER:
comp = HookHandlerEntry(name, normalized_type.value, plugin_id, metadata)
elif normalized_type == ComponentTypes.MESSAGE_GATEWAY:
comp = MessageGatewayEntry(name, normalized_type.value, plugin_id, metadata)
else:
raise ValueError(f"组件类型 {component_type} 不存在")
except ValueError:
logger.error(f"组件类型 {component_type} 不存在")
return False
def register_component(
self,
name: str,
component_type: str,
plugin_id: str,
metadata: Dict[str, Any],
) -> bool:
"""注册单个组件。"""
comp = RegisteredComponent(name, component_type, plugin_id, metadata)
if comp.full_name in self._components:
logger.warning(f"组件 {comp.full_name} 已存在,覆盖")
old_comp = self._components[comp.full_name]
# 从 _by_plugin 列表中移除旧条目,防止幽灵组件堆积
old_list = self._by_plugin.get(old_comp.plugin_id)
if old_list is not None:
try:
with contextlib.suppress(ValueError):
old_list.remove(old_comp)
except ValueError:
pass
# 从旧类型索引中移除,防止类型变更时幽灵残留
if old_type_dict := self._by_type.get(old_comp.component_type):
old_type_dict.pop(comp.full_name, None)
self._components[comp.full_name] = comp
if component_type not in self._by_type:
self._by_type[component_type] = {}
self._by_type[component_type][comp.full_name] = comp
self._by_type[comp.component_type][comp.full_name] = comp
self._by_plugin.setdefault(plugin_id, []).append(comp)
return True
def register_plugin_components(
self,
plugin_id: str,
components: List[Dict[str, Any]],
) -> int:
"""批量注册一个插件的所有组件,返回成功注册数。"""
def register_plugin_components(self, plugin_id: str, components: List[Dict[str, Any]]) -> int:
"""批量注册一个插件的所有组件,返回成功注册数。
Args:
plugin_id (str): 插件id
components (List[Dict[str, Any]]): 组件字典列表,每个组件包含 name, component_type, metadata 等字段
Returns:
count (int): 成功注册的组件数量
"""
count = 0
for comp_data in components:
ok = self.register_component(
@@ -139,7 +280,13 @@ class ComponentRegistry:
return count
def remove_components_by_plugin(self, plugin_id: str) -> int:
"""移除某个插件的所有组件,返回移除数量。"""
"""移除某个插件的所有组件,返回移除数量。
Args:
plugin_id (str): 插件id
Returns:
count (int): 移除的组件数量
"""
comps = self._by_plugin.pop(plugin_id, [])
for comp in comps:
self._components.pop(comp.full_name, None)
@@ -147,106 +294,280 @@ class ComponentRegistry:
type_dict.pop(comp.full_name, None)
return len(comps)
# ──── 启用 / 禁用 ─────────────────────────────────────────
# ====== 启用 / 禁用 ======
def check_component_enabled(self, component: ComponentEntry, session_id: Optional[str] = None):
if session_id and session_id in component.disabled_session:
return False
return component.enabled
def set_component_enabled(self, full_name: str, enabled: bool) -> bool:
"""启用或禁用指定组件。"""
def toggle_component_status(self, full_name: str, enabled: bool, session_id: Optional[str] = None) -> bool:
"""启用或禁用指定组件。
Args:
full_name (str): 组件全名
enabled (bool): 使能情况
session_id (Optional[str]): 可选的会话ID仅对该会话禁用如果提供
Returns:
success (bool): 是否成功设置(失败原因通常是组件不存在)
"""
comp = self._components.get(full_name)
if comp is None:
return False
comp.enabled = enabled
if session_id:
if enabled:
comp.disabled_session.discard(session_id)
else:
comp.disabled_session.add(session_id)
else:
comp.enabled = enabled
return True
def set_plugin_enabled(self, plugin_id: str, enabled: bool) -> int:
"""批量启用或禁用某插件的所有组件。"""
def set_component_enabled(self, full_name: str, enabled: bool, session_id: Optional[str] = None) -> bool:
"""设置指定组件的启用状态。
Args:
full_name: 组件全名。
enabled: 目标启用状态。
session_id: 可选的会话 ID仅对该会话生效。
Returns:
bool: 是否设置成功。
"""
return self.toggle_component_status(full_name, enabled, session_id=session_id)
def toggle_plugin_status(self, plugin_id: str, enabled: bool, session_id: Optional[str] = None) -> int:
"""批量启用或禁用某插件的所有组件。
Args:
plugin_id (str): 插件id
enabled (bool): 使能情况
session_id (Optional[str]): 可选的会话ID仅对该会话禁用如果提供
Returns:
count (int): 成功设置的组件数量(失败原因通常是插件不存在)
"""
comps = self._by_plugin.get(plugin_id, [])
for comp in comps:
comp.enabled = enabled
if session_id:
if enabled:
comp.disabled_session.discard(session_id)
else:
comp.disabled_session.add(session_id)
else:
comp.enabled = enabled
return len(comps)
# ──── 查询方法 ─────────────────────────────────────────────
def get_component(self, full_name: str) -> Optional[ComponentEntry]:
"""按全名查询。
def get_component(self, full_name: str) -> Optional[RegisteredComponent]:
"""按全名查询。"""
Args:
full_name (str): 组件全名
Returns:
component (Optional[ComponentEntry]): 组件条目,未找到时为 None
"""
return self._components.get(full_name)
def get_components_by_type(self, component_type: str, *, enabled_only: bool = True) -> List[RegisteredComponent]:
"""按类型查询。"""
type_dict = self._by_type.get(component_type, {})
def get_components_by_type(
self, component_type: str, *, enabled_only: bool = True, session_id: Optional[str] = None
) -> List[ComponentEntry]:
"""按类型查询组件
Args:
component_type (str): 组件类型(如 `ACTION`、`COMMAND` 等)
enabled_only (bool): 是否仅返回启用的组件
session_id (Optional[str]): 可选的会话ID若提供则考虑会话禁用状态
Returns:
components (List[ComponentEntry]): 组件条目列表
"""
try:
comp_type = self._normalize_component_type(component_type)
except ValueError:
logger.error(f"组件类型 {component_type} 不存在")
raise
type_dict = self._by_type.get(comp_type, {})
if enabled_only:
return [c for c in type_dict.values() if c.enabled]
return [c for c in type_dict.values() if self.check_component_enabled(c, session_id)]
return list(type_dict.values())
def get_components_by_plugin(self, plugin_id: str, *, enabled_only: bool = True) -> List[RegisteredComponent]:
"""按插件查询。"""
comps = self._by_plugin.get(plugin_id, [])
return [c for c in comps if c.enabled] if enabled_only else list(comps)
def get_components_by_plugin(
self, plugin_id: str, *, enabled_only: bool = True, session_id: Optional[str] = None
) -> List[ComponentEntry]:
"""按插件查询组件。
def find_command_by_text(self, text: str) -> Optional[tuple[RegisteredComponent, Dict[str, Any]]]:
Args:
plugin_id (str): 插件ID
enabled_only (bool): 是否仅返回启用的组件
session_id (Optional[str]): 可选的会话ID若提供则考虑会话禁用状态
Returns:
components (List[ComponentEntry]): 组件条目列表
"""
comps = self._by_plugin.get(plugin_id, [])
return [c for c in comps if self.check_component_enabled(c, session_id)] if enabled_only else list(comps)
def find_command_by_text(
self, text: str, session_id: Optional[str] = None
) -> Optional[Tuple[ComponentEntry, Dict[str, Any]]]:
"""通过文本匹配命令正则,返回 (组件, matched_groups) 元组。
matched_groups 为正则命名捕获组 dict别名匹配时为空 dict。
Args:
text (str): 待匹配文本
session_id (Optional[str]): 可选的会话ID若提供则考虑会话禁用状态
Returns:
result (Optional[tuple[ComponentEntry, Dict[str, Any]]]): 匹配到的组件及正则捕获组,未找到时为 None
"""
for comp in self._by_type.get("command", {}).values():
if not comp.enabled:
for comp in self._by_type.get(ComponentTypes.COMMAND, {}).values():
if not self.check_component_enabled(comp, session_id):
continue
if comp._compiled_pattern:
m = comp._compiled_pattern.search(text)
if m:
if not isinstance(comp, CommandEntry):
continue
if comp.compiled_pattern:
if m := comp.compiled_pattern.search(text):
return comp, m.groupdict()
# 别名匹配
aliases = comp.metadata.get("aliases", [])
for alias in aliases:
for alias in comp.aliases:
if text.startswith(alias):
return comp, {}
return None
def get_event_handlers(self, event_type: str, *, enabled_only: bool = True) -> List[RegisteredComponent]:
"""获取特定事件类型的所有 event_handler按 weight 降序排列。"""
handlers = []
for comp in self._by_type.get("event_handler", {}).values():
if enabled_only and not comp.enabled:
def get_event_handlers(
self, event_type: str, *, enabled_only: bool = True, session_id: Optional[str] = None
) -> List[EventHandlerEntry]:
"""查询指定事件类型的事件处理器组件。
Args:
event_type (str): 事件类型
enabled_only (bool): 是否仅返回启用的组件
session_id (Optional[str]): 可选的会话ID若提供则考虑会话禁用状态
Returns:
handlers (List[EventHandlerEntry]): 符合条件的 EventHandler 组件列表,按 weight 降序排序
"""
handlers: List[EventHandlerEntry] = []
for comp in self._by_type.get(ComponentTypes.EVENT_HANDLER, {}).values():
if enabled_only and not self.check_component_enabled(comp, session_id):
continue
if comp.metadata.get("event_type") == event_type:
if not isinstance(comp, EventHandlerEntry):
continue
if comp.event_type == event_type:
handlers.append(comp)
handlers.sort(key=lambda c: c.metadata.get("weight", 0), reverse=True)
handlers.sort(key=lambda c: c.weight, reverse=True)
return handlers
def get_workflow_steps(self, stage: str, *, enabled_only: bool = True) -> List[RegisteredComponent]:
"""获取特定 workflow 阶段的所有步骤,按 priority 降序。"""
steps = []
for comp in self._by_type.get("workflow_step", {}).values():
if enabled_only and not comp.enabled:
def get_hook_handlers(
self, stage: str, *, enabled_only: bool = True, session_id: Optional[str] = None
) -> List[HookHandlerEntry]:
"""获取特定 hook 阶段的所有步骤,按 priority 降序。
Args:
stage: hook 名称
enabled_only: 是否仅返回启用的组件
session_id: 可选的会话ID若提供则考虑会话禁用状态
Returns:
handlers (List[HookHandlerEntry]): 符合条件的 HookHandler 组件列表,按 priority 降序排序
"""
handlers: List[HookHandlerEntry] = []
for comp in self._by_type.get(ComponentTypes.HOOK_HANDLER, {}).values():
if enabled_only and not self.check_component_enabled(comp, session_id):
continue
if comp.metadata.get("stage") == stage:
steps.append(comp)
steps.sort(key=lambda c: c.metadata.get("priority", 0), reverse=True)
return steps
if not isinstance(comp, HookHandlerEntry):
continue
if comp.stage == stage:
handlers.append(comp)
handlers.sort(key=lambda c: c.priority, reverse=True)
return handlers
def get_tools_for_llm(self, *, enabled_only: bool = True) -> List[Dict[str, Any]]:
"""获取可供 LLM 使用的工具列表openai function-calling 格式预览)。"""
result: List[Dict[str, Any]] = []
for comp in self.get_components_by_type("tool", enabled_only=enabled_only):
tool_def: Dict[str, Any] = {
"name": comp.full_name,
"description": comp.metadata.get("description", ""),
}
# 从结构化参数或原始参数构建 parameters
params = comp.metadata.get("parameters", [])
params_raw = comp.metadata.get("parameters_raw", {})
if params:
tool_def["parameters"] = params
elif params_raw:
tool_def["parameters"] = params_raw
result.append(tool_def)
return result
def get_message_gateway(
self,
plugin_id: str,
name: str,
*,
enabled_only: bool = True,
session_id: Optional[str] = None,
) -> Optional[MessageGatewayEntry]:
"""按插件和组件名获取单个消息网关。
# ──── 统计 ─────────────────────────────────────────────────
Args:
plugin_id: 插件 ID。
name: 网关组件名称。
enabled_only: 是否仅返回启用的组件。
session_id: 可选的会话 ID。
def get_stats(self) -> Dict[str, int]:
"""获取注册统计。"""
stats: Dict[str, int] = {"total": len(self._components)}
Returns:
Optional[MessageGatewayEntry]: 若存在则返回消息网关条目。
"""
component = self._components.get(f"{plugin_id}.{name}")
if not isinstance(component, MessageGatewayEntry):
return None
if enabled_only and not self.check_component_enabled(component, session_id):
return None
return component
def get_message_gateways(
self,
*,
plugin_id: Optional[str] = None,
platform: str = "",
route_type: str = "",
enabled_only: bool = True,
session_id: Optional[str] = None,
) -> List[MessageGatewayEntry]:
"""查询消息网关组件列表。
Args:
plugin_id: 可选的插件 ID 过滤条件。
platform: 可选的平台过滤条件。
route_type: 可选的路由类型过滤条件。
enabled_only: 是否仅返回启用的组件。
session_id: 可选的会话 ID。
Returns:
List[MessageGatewayEntry]: 符合条件的消息网关组件列表。
"""
normalized_platform = str(platform or "").strip()
normalized_route_type = str(route_type or "").strip().lower()
gateways: List[MessageGatewayEntry] = []
for comp in self._by_type.get(ComponentTypes.MESSAGE_GATEWAY, {}).values():
if not isinstance(comp, MessageGatewayEntry):
continue
if plugin_id and comp.plugin_id != plugin_id:
continue
if enabled_only and not self.check_component_enabled(comp, session_id):
continue
if normalized_platform and comp.platform != normalized_platform:
continue
if normalized_route_type and comp.route_type != normalized_route_type:
continue
gateways.append(comp)
return gateways
def get_tools(self, *, enabled_only: bool = True, session_id: Optional[str] = None) -> List[ToolEntry]:
"""查询所有工具组件。
Args:
enabled_only (bool): 是否仅返回启用的组件
session_id (Optional[str]): 可选的会话ID若提供则考虑会话禁用状态
Returns:
tools (List[ToolEntry]): 符合条件的 Tool 组件列表
"""
tools: List[ToolEntry] = []
for comp in self._by_type.get(ComponentTypes.TOOL, {}).values():
if enabled_only and not self.check_component_enabled(comp, session_id):
continue
if isinstance(comp, ToolEntry):
tools.append(comp)
return tools
# ====== 统计信息 ======
def get_stats(self) -> StatusDict:
"""获取注册统计。
Returns:
stats (StatusDict): 组件统计信息,包括总数、各类型数量、插件数量等
"""
stats: StatusDict = {"total": len(self._components)} # type: ignore
for comp_type, type_dict in self._by_type.items():
stats[comp_type] = len(type_dict)
stats[comp_type.value.lower()] = len(type_dict)
stats["plugins"] = len(self._by_plugin)
return stats

View File

@@ -4,40 +4,40 @@
1. 按事件类型查询已注册的 event_handler通过 ComponentRegistry
2. 按 weight 排序,依次通过 RPC 调用 Runner 中的处理器
3. 支持阻塞intercept_message和非阻塞分发
4. 事件结果历史记录
4. 事件结果历史记录(有上限)
"""
from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple
from dataclasses import dataclass, field
from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING
import asyncio
from src.common.logger import get_logger
from src.plugin_runtime.host.component_registry import ComponentRegistry, RegisteredComponent
from .message_utils import PluginMessageUtils, MessageDict
if TYPE_CHECKING:
from .supervisor import PluginRunnerSupervisor
from .component_registry import ComponentRegistry, EventHandlerEntry
from src.chat.message_receive.message import SessionMessage
logger = get_logger("plugin_runtime.host.event_dispatcher")
# invoke_fn 类型: async (plugin_id, component_name, args) -> response_payload dict
InvokeFn = Callable[[str, str, Dict[str, Any]], Awaitable[Dict[str, Any]]]
# 每个事件类型的最大历史记录数量,防止内存无限增长
_MAX_HISTORY_LENGTH = 100
@dataclass
class EventResult:
"""单个 EventHandler 的执行结果"""
__slots__ = ("handler_name", "success", "continue_processing", "modified_message", "custom_result")
def __init__(
self,
handler_name: str,
success: bool = True,
continue_processing: bool = True,
modified_message: Optional[Dict[str, Any]] = None,
custom_result: Any = None,
):
self.handler_name = handler_name
self.success = success
self.continue_processing = continue_processing
self.modified_message = modified_message
self.custom_result = custom_result
handler_name: str
success: bool = field(default=True)
continue_processing: bool = field(default=True)
modified_message: Optional[MessageDict] = field(default=None)
custom_result: Any = field(default=None)
class EventDispatcher:
@@ -48,17 +48,20 @@ class EventDispatcher:
再通过提供的 invoke_fn 回调 RPC 到 Runner 执行。
"""
def __init__(self, registry: ComponentRegistry) -> None:
self._registry: ComponentRegistry = registry
def __init__(self, component_registry: "ComponentRegistry") -> None:
self._component_registry: "ComponentRegistry" = component_registry
self._result_history: Dict[str, List[EventResult]] = {}
self._history_enabled: Set[str] = set()
# 保持 fire-and-forget task 的强引用,防止被 GC 回收
self._background_tasks: Set[asyncio.Task] = set()
def enable_history(self, event_type: str) -> None:
self._history_enabled.add(event_type)
self._result_history.setdefault(event_type, [])
def disable_history(self, event_type: str) -> None:
self._history_enabled.discard(event_type)
self._result_history.pop(event_type, None)
def get_history(self, event_type: str) -> List[EventResult]:
return self._result_history.get(event_type, [])
@@ -66,47 +69,58 @@ class EventDispatcher:
if event_type in self._result_history:
self._result_history[event_type] = []
async def stop(self):
"""停止 EventDispatcher取消所有未完成的后台任务"""
for task in self._background_tasks:
task.cancel()
await asyncio.gather(*self._background_tasks, return_exceptions=True)
self._background_tasks.clear()
async def dispatch_event(
self,
event_type: str,
invoke_fn: InvokeFn,
message: Optional[Dict[str, Any]] = None,
supervisor: "PluginRunnerSupervisor",
message: Optional["SessionMessage"] = None,
extra_args: Optional[Dict[str, Any]] = None,
) -> Tuple[bool, Optional[Dict[str, Any]]]:
"""分发事件到所有对应 handler。
) -> Tuple[bool, Optional["SessionMessage"]]:
"""分发事件到所有对应 handler 的便捷方法
内置了通过 PluginSupervisor.invoke_plugin 调用 plugin.emit_event 的逻辑,
无需调用方手动构造 invoke_fn 闭包。
Args:
event_type: 事件类型字符串
invoke_fn: 异步回调,签名 (plugin_id, component_name, args) -> response_payload dict
supervisor: PluginSupervisor 实例,用于调用 invoke_plugin
message: MaiMessages 序列化后的 dict可选
extra_args: 额外参数
Returns:
(should_continue, modified_message_dict)
(should_continue, modified_message_dict) (bool, SessionMessage | None): (是否继续后续执行, 可选的修改后的消息)
"""
handlers = self._registry.get_event_handlers(event_type)
if not handlers:
handler_entries = self._component_registry.get_event_handlers(event_type)
if not handler_entries:
return True, None
should_continue = True
modified_message: Optional[Dict[str, Any]] = None
intercept_handlers: List[RegisteredComponent] = []
async_handlers: List[RegisteredComponent] = []
modified_message: Optional[MessageDict] = (
PluginMessageUtils._session_message_to_dict(message) if message else None
)
intercept_handlers: List["EventHandlerEntry"] = []
non_blocking_handlers: List["EventHandlerEntry"] = []
for handler in handlers:
if handler.metadata.get("intercept_message", False):
intercept_handlers.append(handler)
for entry in handler_entries:
if entry.intercept_message:
intercept_handlers.append(entry)
else:
async_handlers.append(handler)
non_blocking_handlers.append(entry)
for handler in intercept_handlers:
for entry in intercept_handlers:
args = {
"event_type": event_type,
"message": modified_message or message,
"message": modified_message,
**(extra_args or {}),
}
result = await self._invoke_handler(invoke_fn, handler, args, event_type)
result = await self._invoke_handler(supervisor, entry, args, event_type)
if result and not result.continue_processing:
should_continue = False
break
@@ -114,47 +128,57 @@ class EventDispatcher:
modified_message = result.modified_message
if should_continue:
final_message = modified_message or message
for handler in async_handlers:
async_message = final_message.copy() if isinstance(final_message, dict) else final_message
final_message = modified_message
for entry in non_blocking_handlers:
async_message = final_message.copy() if final_message else final_message
args = {
"event_type": event_type,
"message": async_message,
**(extra_args or {}),
}
# 非阻塞:保持实例级强引用,防止 task 被 GC 回收
task = asyncio.create_task(self._invoke_handler(invoke_fn, handler, args, event_type))
task = asyncio.create_task(self._invoke_handler(supervisor, entry, args, event_type))
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
return should_continue, modified_message
try:
modified_message_obj = (
PluginMessageUtils._build_session_message_from_dict(modified_message) if modified_message else None # type: ignore
)
except Exception as e:
logger.error(f"构建修改后的 SessionMessage 失败: {e}")
modified_message_obj = None
return should_continue, modified_message_obj
async def _invoke_handler(
self,
invoke_fn: InvokeFn,
handler: RegisteredComponent,
supervisor: "PluginRunnerSupervisor",
handler_entry: "EventHandlerEntry",
args: Dict[str, Any],
event_type: str,
) -> Optional[EventResult]:
"""调用单个 handler 并收集结果。"""
try:
resp = await invoke_fn(handler.plugin_id, handler.name, args)
resp_envelope = await supervisor.invoke_plugin(
"plugin.emit_event", handler_entry.plugin_id, handler_entry.name, args
)
resp = resp_envelope.payload
result = EventResult(
handler_name=handler.full_name,
handler_name=handler_entry.full_name,
success=resp.get("success", True),
continue_processing=resp.get("continue_processing", True),
modified_message=resp.get("modified_message"),
custom_result=resp.get("custom_result"),
)
except Exception as e:
logger.error(f"EventHandler {handler.full_name} 执行失败: {e}", exc_info=True)
result = EventResult(
handler_name=handler.full_name,
success=False,
continue_processing=True,
)
logger.error(f"EventHandler {handler_entry.full_name} 执行失败: {e}", exc_info=True)
result = EventResult(handler_name=handler_entry.full_name, success=False, continue_processing=True)
if event_type in self._history_enabled:
self._result_history.setdefault(event_type, []).append(result)
history_list = self._result_history.setdefault(event_type, [])
history_list.append(result)
# 自动清理超出限制的旧记录,防止内存无限增长
if len(history_list) > _MAX_HISTORY_LENGTH:
# 保留最新的 _MAX_HISTORY_LENGTH 条记录
self._result_history[event_type] = history_list[-_MAX_HISTORY_LENGTH:]
return result

View File

@@ -0,0 +1,166 @@
"""
Hook Dispatch 系统
插件可以注册自己的Hook当特定函数被调用时Hook Dispatch系统会将调用转发给插件的Hook处理函数。
每个Hook的参数随Hook点位确定因此参数是易变的。插件开发者需要根据Hook点位的定义来编写Hook处理函数。
在参数/返回值匹配的情况下允许修改参数/返回值。
HookDispatcher 负责:
1. 按 stage 查询已注册的 hook_handler通过 ComponentRegistry
2. 按 priority 排序,区分 blocking 和非 blocking 模式
3. blocking 模式:依次同步调用,支持修改参数/提前终止
4. 非 blocking 模式:异步调用,不阻塞主流程
5. 支持通过 global_config.plugin_runtime.hook_blocking_timeout_sec 设置超时上限
"""
import asyncio
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING
from src.common.logger import get_logger
from src.config.config import global_config
if TYPE_CHECKING:
from .supervisor import PluginRunnerSupervisor
from .component_registry import ComponentRegistry, HookHandlerEntry
logger = get_logger("plugin_runtime.host.hook_dispatcher")
@dataclass
class HookResult:
"""单个 HookHandler 的执行结果"""
handler_name: str
success: bool = field(default=True)
continue_processing: bool = field(default=True)
modified_kwargs: Optional[Dict[str, Any]] = field(default=None)
custom_result: Any = field(default=None)
class HookDispatcher:
"""Host-side Hook 分发器
由业务层调用 hook_dispatch()
内部通过 ComponentRegistry 查询 handler
再通过提供的 invoke_fn 回调 RPC 到 Runner 执行。
"""
def __init__(self, component_registry: "ComponentRegistry") -> None:
"""初始化 HookDispatcher
Args:
component_registry: ComponentRegistry 实例,用于查询已注册的 hook_handler
"""
self._component_registry: "ComponentRegistry" = component_registry
self._background_tasks: Set[asyncio.Task] = set()
async def stop(self) -> None:
"""停止 HookDispatcher取消所有未完成的后台任务"""
for task in self._background_tasks:
task.cancel()
await asyncio.gather(*self._background_tasks, return_exceptions=True)
self._background_tasks.clear()
async def hook_dispatch(
self,
stage: str,
supervisor: "PluginRunnerSupervisor",
**kwargs: Any,
) -> Dict[str, Any]:
"""分发 hook 到所有对应 handler 的便捷方法。
内置了通过 PluginRunnerSupervisor.invoke_plugin 调用 plugin 的逻辑,
无需调用方手动构造 invoke_fn 闭包。
Args:
stage: hook 名称
supervisor: PluginRunnerSupervisor 实例,用于调用 invoke_plugin
**kwargs: 关键字参数,会展开传递给 handler
Returns:
modified_kwargs (Dict[str, Any]): 经过所有 handler 修改后的关键字参数
"""
handler_entries = self._component_registry.get_hook_handlers(stage)
if not handler_entries:
return kwargs
current_kwargs = kwargs.copy()
blocking_handlers: List["HookHandlerEntry"] = []
non_blocking_handlers: List["HookHandlerEntry"] = []
# 分离 blocking 和非 blocking handler
for entry in handler_entries:
if entry.blocking:
blocking_handlers.append(entry)
else:
non_blocking_handlers.append(entry)
# 处理 blocking handlers同步调用支持修改参数/提前终止)
timeout = global_config.plugin_runtime.hook_blocking_timeout_sec or 30.0
for entry in blocking_handlers:
hook_args = {"stage": stage, **current_kwargs}
try:
# 应用超时控制
result = await asyncio.wait_for(
self._invoke_handler(supervisor, entry, hook_args),
timeout=timeout,
)
except asyncio.TimeoutError:
logger.error(f"Blocking HookHandler {entry.full_name} 执行超时 (>{timeout}秒),跳过")
result = HookResult(handler_name=entry.full_name, success=False, continue_processing=True)
if result:
if result.modified_kwargs is not None:
current_kwargs = result.modified_kwargs
if not result.continue_processing:
logger.info(f"HookHandler {entry.full_name} 终止了后续处理")
break
# 处理 non-blocking handlers异步调用不阻塞主流程
for entry in non_blocking_handlers:
async_kwargs = current_kwargs.copy()
hook_args = {"stage": stage, **async_kwargs}
task = asyncio.create_task(
asyncio.wait_for(self._invoke_handler(supervisor, entry, hook_args), timeout=timeout)
)
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
return current_kwargs
async def _invoke_handler(
self,
supervisor: "PluginRunnerSupervisor",
handler_entry: "HookHandlerEntry",
args: Dict[str, Any],
) -> Optional[HookResult]:
"""调用单个 handler 并收集结果。
Args:
supervisor: PluginRunnerSupervisor 实例
handler_entry: HookHandlerEntry 实例
args: 传递给 handler 的参数字典
stage: hook 名称
Returns:
Optional[HookResult]: 执行结果,如果执行失败则返回 None
"""
try:
resp_envelope = await supervisor.invoke_plugin(
"plugin.invoke_hook", handler_entry.plugin_id, handler_entry.name, args
)
resp = resp_envelope.payload
result = HookResult(
handler_name=handler_entry.full_name,
success=resp.get("success", True),
continue_processing=resp.get("continue_processing", True),
modified_kwargs=resp.get("modified_kwargs"),
custom_result=resp.get("custom_result"),
)
except Exception as e:
logger.error(f"HookHandler {handler_entry.full_name} 执行失败:{e}", exc_info=True)
result = HookResult(handler_name=handler_entry.full_name, success=False, continue_processing=True)
return result

View File

@@ -0,0 +1,45 @@
import logging as stdlib_logging
from src.plugin_runtime.protocol.errors import ErrorCode
from src.plugin_runtime.protocol.envelope import Envelope, LogBatchPayload
class RunnerLogBridge:
"""将 Runner 进程上报的批量日志重放到主进程的 Logger 中。
Runner 通过 ``runner.log_batch`` IPC 事件批量到达。
每条 LogEntry 被重建为一个真实的 :class:`logging.LogRecord` 并直接
调用 ``logging.getLogger(entry.logger_name).handle(record)``
从而接入主进程已配置好的 structlog Handler 链。
"""
async def handle_log_batch(self, envelope: Envelope) -> Envelope:
"""IPC 事件处理器:解析批量日志并重放到主进程 Logger。
Args:
envelope: 方法名为 ``runner.log_batch`` 的 IPC 事件信封。
Returns:
空响应信封(事件模式下将被忽略)。
"""
try:
batch = LogBatchPayload.model_validate(envelope.payload)
except Exception as exc:
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
for entry in batch.entries:
# 重建一个与原始日志尽量相符的 LogRecord
record = stdlib_logging.LogRecord(
name=entry.logger_name,
level=entry.level,
pathname="<runner>",
lineno=0,
msg=entry.message,
args=(),
exc_info=None,
)
record.created = entry.timestamp_ms / 1000.0
record.msecs = entry.timestamp_ms % 1000
if entry.exception_text:
record.exc_text = entry.exception_text
stdlib_logging.getLogger(entry.logger_name).handle(record)
return envelope.make_response(payload={"accepted": True, "count": len(batch.entries)})

View File

@@ -0,0 +1,112 @@
"""Host 侧消息网关包装器。"""
from typing import TYPE_CHECKING, Any, Dict
from src.common.logger import get_logger
from src.platform_io import get_platform_io_manager
from .message_utils import PluginMessageUtils
if TYPE_CHECKING:
from src.chat.message_receive.message import SessionMessage
from .component_registry import ComponentRegistry
from .supervisor import PluginRunnerSupervisor
logger = get_logger("plugin_runtime.host.message_gateway")
class MessageGateway:
"""Host 侧消息网关包装器。"""
def __init__(self, component_registry: "ComponentRegistry") -> None:
"""初始化消息网关。
Args:
component_registry: 组件注册表。
"""
self._component_registry = component_registry
def build_session_message(self, external_message: Dict[str, Any]) -> "SessionMessage":
"""将标准消息字典转换为 ``SessionMessage``。
Args:
external_message: 外部消息的字典格式数据。
Returns:
SessionMessage: 转换后的内部消息对象。
Raises:
ValueError: 消息字典不合法时抛出。
"""
return PluginMessageUtils._build_session_message_from_dict(external_message)
def build_message_dict(self, internal_message: "SessionMessage") -> Dict[str, Any]:
"""将 ``SessionMessage`` 转换为标准消息字典。
Args:
internal_message: 内部消息对象。
Returns:
Dict[str, Any]: 供消息网关插件消费的标准消息字典。
"""
return dict(PluginMessageUtils._session_message_to_dict(internal_message))
async def receive_external_message(self, external_message: Dict[str, Any]) -> None:
"""接收外部消息并送入主消息链。
Args:
external_message: 外部消息的字典格式数据。
"""
try:
session_message = self.build_session_message(external_message)
except Exception as e:
logger.error(f"转换外部消息失败: {e}")
return
from src.chat.message_receive.bot import chat_bot
await chat_bot.receive_message(session_message)
async def send_message_to_external(
self,
internal_message: "SessionMessage",
supervisor: "PluginRunnerSupervisor",
*,
enabled_only: bool = True,
save_to_db: bool = True,
) -> bool:
"""将内部消息通过 Platform IO 发送到外部平台。
Args:
internal_message: 系统内部的 ``SessionMessage`` 对象。
supervisor: 当前持有该消息网关的 Supervisor。
enabled_only: 兼容旧签名的保留参数,当前未使用。
save_to_db: 发送成功后是否写入数据库。
Returns:
bool: 是否发送成功。
"""
del enabled_only
del supervisor
platform_io_manager = get_platform_io_manager()
if not platform_io_manager.is_started:
logger.warning("Platform IO 尚未启动,无法通过适配器链路发送消息")
return False
route_key = platform_io_manager.build_route_key_from_message(internal_message)
delivery_batch = await platform_io_manager.send_message(internal_message, route_key)
if not delivery_batch.has_success:
logger.warning("通过消息网关链路发送消息失败: 未命中任何成功回执")
return False
first_successful_receipt = delivery_batch.sent_receipts[0]
internal_message.message_id = first_successful_receipt.external_message_id or internal_message.message_id
if save_to_db:
try:
from src.common.utils.utils_message import MessageUtils
MessageUtils.store_message_to_db(internal_message)
except Exception as e:
logger.error(f"保存消息到数据库失败: {e}")
return True

View File

@@ -0,0 +1,487 @@
from datetime import datetime
from typing import Any, Dict, List, Optional, TypedDict
import base64
import hashlib
from src.common.logger import get_logger
from src.chat.message_receive.message import SessionMessage
from src.common.data_models.mai_message_data_model import UserInfo, GroupInfo, MessageInfo
from src.common.data_models.message_component_data_model import (
AtComponent,
DictComponent,
EmojiComponent,
ForwardComponent,
ForwardNodeComponent,
ImageComponent,
MessageSequence,
ReplyComponent,
StandardMessageComponents,
TextComponent,
VoiceComponent,
)
logger = get_logger("plugin_runtime.host.message_utils")
class UserInfoDict(TypedDict, total=False):
user_id: str
user_nickname: str
user_cardname: Optional[str]
class GroupInfoDict(TypedDict, total=False):
group_id: str
group_name: str
class MessageInfoDict(TypedDict, total=False):
user_info: UserInfoDict
group_info: Optional[GroupInfoDict]
additional_config: Dict[str, Any]
class MessageDict(TypedDict, total=False):
message_id: str
timestamp: str
platform: str
message_info: MessageInfoDict
raw_message: List[Dict[str, Any]]
is_mentioned: bool
is_at: bool
is_emoji: bool
is_picture: bool
is_command: bool
is_notify: bool
session_id: str
reply_to: Optional[str]
processed_plain_text: Optional[str]
display_message: Optional[str]
class PluginMessageUtils:
@staticmethod
def _message_sequence_to_dict(message_sequence: MessageSequence) -> List[Dict[str, Any]]:
"""将消息组件序列转换为插件运行时使用的字典结构。
Args:
message_sequence: 待转换的消息组件序列。
Returns:
List[Dict[str, Any]]: 供插件运行时协议使用的消息段字典列表。
"""
return [PluginMessageUtils._component_to_dict(component) for component in message_sequence.components]
@staticmethod
def _component_to_dict(component: StandardMessageComponents) -> Dict[str, Any]:
"""将单个消息组件转换为插件运行时字典结构。
Args:
component: 待转换的消息组件。
Returns:
Dict[str, Any]: 序列化后的消息组件字典。
"""
if isinstance(component, TextComponent):
return {"type": "text", "data": component.text}
if isinstance(component, ImageComponent):
serialized = {
"type": "image",
"data": component.content,
"hash": component.binary_hash,
}
if component.binary_data:
serialized["binary_data_base64"] = base64.b64encode(component.binary_data).decode("utf-8")
return serialized
if isinstance(component, EmojiComponent):
serialized = {
"type": "emoji",
"data": component.content,
"hash": component.binary_hash,
}
if component.binary_data:
serialized["binary_data_base64"] = base64.b64encode(component.binary_data).decode("utf-8")
return serialized
if isinstance(component, VoiceComponent):
serialized = {
"type": "voice",
"data": component.content,
"hash": component.binary_hash,
}
if component.binary_data:
serialized["binary_data_base64"] = base64.b64encode(component.binary_data).decode("utf-8")
return serialized
if isinstance(component, AtComponent):
return {
"type": "at",
"data": {
"target_user_id": component.target_user_id,
"target_user_nickname": component.target_user_nickname,
"target_user_cardname": component.target_user_cardname,
},
}
if isinstance(component, ReplyComponent):
return {
"type": "reply",
"data": {
"target_message_id": component.target_message_id,
"target_message_content": component.target_message_content,
"target_message_sender_id": component.target_message_sender_id,
"target_message_sender_nickname": component.target_message_sender_nickname,
"target_message_sender_cardname": component.target_message_sender_cardname,
},
}
if isinstance(component, ForwardNodeComponent):
return {
"type": "forward",
"data": [PluginMessageUtils._forward_component_to_dict(item) for item in component.forward_components],
}
return {"type": "dict", "data": component.data}
@staticmethod
def _forward_component_to_dict(component: ForwardComponent) -> Dict[str, Any]:
"""将单个转发节点组件转换为字典结构。
Args:
component: 待转换的转发节点组件。
Returns:
Dict[str, Any]: 序列化后的转发节点字典。
"""
return {
"user_id": component.user_id,
"user_nickname": component.user_nickname,
"user_cardname": component.user_cardname,
"message_id": component.message_id,
"content": [PluginMessageUtils._component_to_dict(item) for item in component.content],
}
@staticmethod
def _message_sequence_from_dict(raw_message_data: List[Dict[str, Any]]) -> MessageSequence:
"""从插件运行时字典结构恢复消息组件序列。
Args:
raw_message_data: 插件运行时消息段字典列表。
Returns:
MessageSequence: 恢复后的消息组件序列。
"""
components = [PluginMessageUtils._component_from_dict(item) for item in raw_message_data]
return MessageSequence(components=components)
@staticmethod
def _component_from_dict(item: Dict[str, Any]) -> StandardMessageComponents:
"""从插件运行时字典结构恢复单个消息组件。
Args:
item: 单个消息组件的字典表示。
Returns:
StandardMessageComponents: 恢复后的内部消息组件对象。
"""
item_type = str(item.get("type") or "").strip()
if item_type == "text":
return TextComponent(text=str(item.get("data") or ""))
if item_type == "image":
return PluginMessageUtils._build_binary_component(ImageComponent, item)
if item_type == "emoji":
return PluginMessageUtils._build_binary_component(EmojiComponent, item)
if item_type == "voice":
return PluginMessageUtils._build_binary_component(VoiceComponent, item)
if item_type == "at":
item_data = item.get("data", {})
if not isinstance(item_data, dict):
item_data = {}
return AtComponent(
target_user_id=str(item_data.get("target_user_id") or ""),
target_user_nickname=PluginMessageUtils._normalize_optional_string(item_data.get("target_user_nickname")),
target_user_cardname=PluginMessageUtils._normalize_optional_string(item_data.get("target_user_cardname")),
)
if item_type == "reply":
reply_data = item.get("data")
if isinstance(reply_data, dict):
return ReplyComponent(
target_message_id=str(reply_data.get("target_message_id") or ""),
target_message_content=PluginMessageUtils._normalize_optional_string(
reply_data.get("target_message_content")
),
target_message_sender_id=PluginMessageUtils._normalize_optional_string(
reply_data.get("target_message_sender_id")
),
target_message_sender_nickname=PluginMessageUtils._normalize_optional_string(
reply_data.get("target_message_sender_nickname")
),
target_message_sender_cardname=PluginMessageUtils._normalize_optional_string(
reply_data.get("target_message_sender_cardname")
),
)
return ReplyComponent(target_message_id=str(reply_data or ""))
if item_type == "forward":
forward_nodes: List[ForwardComponent] = []
raw_forward_nodes = item.get("data", [])
if isinstance(raw_forward_nodes, list):
for node in raw_forward_nodes:
if not isinstance(node, dict):
continue
raw_content = node.get("content", [])
node_components: List[StandardMessageComponents] = []
if isinstance(raw_content, list):
node_components = [
PluginMessageUtils._component_from_dict(content)
for content in raw_content
if isinstance(content, dict)
]
if not node_components:
node_components = [TextComponent(text="[empty forward node]")]
forward_nodes.append(
ForwardComponent(
user_nickname=str(node.get("user_nickname") or "未知用户"),
user_id=PluginMessageUtils._normalize_optional_string(node.get("user_id")),
user_cardname=PluginMessageUtils._normalize_optional_string(node.get("user_cardname")),
message_id=str(node.get("message_id") or ""),
content=node_components,
)
)
if not forward_nodes:
return DictComponent(data={"type": "forward", "data": item.get("data", [])})
return ForwardNodeComponent(forward_components=forward_nodes)
component_data = item.get("data")
if isinstance(component_data, dict):
return DictComponent(data=component_data)
return DictComponent(data=item)
@staticmethod
def _build_binary_component(component_cls: Any, item: Dict[str, Any]) -> StandardMessageComponents:
"""从字典构造带二进制负载的消息组件。
Args:
component_cls: 目标组件类型。
item: 消息组件字典。
Returns:
StandardMessageComponents: 构造后的组件对象。
"""
content = str(item.get("data") or "")
binary_hash = str(item.get("hash") or "")
raw_binary_base64 = item.get("binary_data_base64")
binary_data = b""
if isinstance(raw_binary_base64, str) and raw_binary_base64:
try:
binary_data = base64.b64decode(raw_binary_base64)
except Exception:
binary_data = b""
if not binary_hash and binary_data:
binary_hash = hashlib.sha256(binary_data).hexdigest()
return component_cls(binary_hash=binary_hash, content=content, binary_data=binary_data)
@staticmethod
def _normalize_optional_string(value: Any) -> Optional[str]:
"""将任意值规范化为可选字符串。
Args:
value: 待规范化的值。
Returns:
Optional[str]: 规范化后的字符串;若值为空则返回 ``None``。
"""
if value is None:
return None
normalized_value = str(value)
return normalized_value if normalized_value else None
@staticmethod
def _message_info_to_dict(message_info: MessageInfo) -> MessageInfoDict:
"""
将 MessageInfo 对象转换为字典格式
Args:
message_info: MessageInfo 对象
Returns:
字典格式的消息信息
"""
user_info_dict = UserInfoDict(
user_id=message_info.user_info.user_id,
user_nickname=message_info.user_info.user_nickname,
user_cardname=message_info.user_info.user_cardname,
)
group_info_dict: Optional[GroupInfoDict] = None
if message_info.group_info:
group_info_dict = GroupInfoDict(
group_id=message_info.group_info.group_id,
group_name=message_info.group_info.group_name,
)
return MessageInfoDict(
user_info=user_info_dict,
group_info=group_info_dict,
additional_config=message_info.additional_config,
)
@staticmethod
def _session_message_to_dict(session_message: SessionMessage) -> MessageDict:
"""
将 SessionMessage 对象转换为字典格式(复用 MessageSequence.to_dict 方法)
Args:
session_message: SessionMessage 对象
Returns:
字典格式的消息
"""
# 转换基本信息
message_dict = MessageDict(
message_id=session_message.message_id,
timestamp=str(session_message.timestamp.timestamp()), # 转换为时间戳字符串
platform=session_message.platform,
message_info=PluginMessageUtils._message_info_to_dict(session_message.message_info),
raw_message=PluginMessageUtils._message_sequence_to_dict(session_message.raw_message),
is_mentioned=session_message.is_mentioned,
is_at=session_message.is_at,
is_emoji=session_message.is_emoji,
is_picture=session_message.is_picture,
is_command=session_message.is_command,
is_notify=session_message.is_notify,
session_id=session_message.session_id,
)
# 添加可选字段
if session_message.reply_to is not None:
message_dict["reply_to"] = session_message.reply_to
if session_message.processed_plain_text is not None:
message_dict["processed_plain_text"] = session_message.processed_plain_text
if session_message.display_message is not None:
message_dict["display_message"] = session_message.display_message
return message_dict
@staticmethod
def _build_message_info_from_dict(message_info_dict: Dict[str, Any]) -> MessageInfo:
"""
从字典构建 MessageInfo 对象
Args:
message_info_dict: 包含消息信息的字典
Returns:
MessageInfo 对象
"""
# 构建用户信息
user_info_dict = message_info_dict.get("user_info")
if not user_info_dict or not isinstance(user_info_dict, dict):
raise ValueError("消息字典中 'user_info' 字段无效")
user_id = user_info_dict.get("user_id")
user_nickname = user_info_dict.get("user_nickname")
user_cardname = user_info_dict.get("user_cardname")
if not isinstance(user_id, str) or not isinstance(user_nickname, str) or not user_id or not user_nickname:
raise ValueError("消息字典中 'user_info' 字段缺少有效的 'user_id''user_nickname'")
user_cardname = str(user_cardname) if user_cardname is not None else None
user_info = UserInfo(user_id=user_id, user_nickname=user_nickname, user_cardname=user_cardname)
# 构建群信息
if group_info_dict := message_info_dict.get("group_info"):
group_id = group_info_dict.get("group_id")
group_name = group_info_dict.get("group_name")
if not isinstance(group_id, str) or not isinstance(group_name, str) or not group_id or not group_name:
raise ValueError("消息字典中 'group_info' 字段缺少有效的 'group_id''group_name'")
group_info = GroupInfo(group_id=group_id, group_name=group_name)
else:
group_info = None
# 获取额外配置
additional_config: Dict[str, Any] = message_info_dict.get("additional_config", {})
return MessageInfo(user_info=user_info, group_info=group_info, additional_config=additional_config)
@staticmethod
def _build_session_message_from_dict(message_dict: Dict[str, Any]) -> SessionMessage:
"""
从字典构建 SessionMessage 对象(递归处理消息组件)
Args:
message_dict: 包含消息完整信息的字典
Returns:
SessionMessage 对象
"""
# 提取基本信息
message_id = message_dict["message_id"]
timestamp_str: str = message_dict.get("timestamp", "")
platform = message_dict["platform"]
if not isinstance(message_id, str) or not message_id:
raise ValueError("消息字典中缺少有效的 'message_id' 字段")
if not isinstance(platform, str) or not platform:
raise ValueError("消息字典中缺少有效的 'platform' 字段")
# 解析时间戳
try:
timestamp_float = float(timestamp_str)
timestamp = datetime.fromtimestamp(timestamp_float)
except (ValueError, TypeError):
timestamp = datetime.now() # 如果解析失败,使用当前时间
# 创建 SessionMessage 实例
session_message = SessionMessage(message_id=message_id, timestamp=timestamp, platform=platform)
# 构建消息信息
session_message.message_info = PluginMessageUtils._build_message_info_from_dict(message_dict["message_info"])
# 构建原始消息组件序列(复用 MessageSequence.from_dict 方法)
raw_message_data = message_dict["raw_message"]
if isinstance(raw_message_data, list):
session_message.raw_message = PluginMessageUtils._message_sequence_from_dict(raw_message_data)
else:
raise ValueError("消息字典中 'raw_message' 字段必须是一个列表")
# 设置其他可选属性
session_message.is_mentioned = message_dict.get("is_mentioned", False)
if not isinstance(session_message.is_mentioned, bool):
session_message.is_mentioned = False
session_message.is_at = message_dict.get("is_at", False)
if not isinstance(session_message.is_at, bool):
session_message.is_at = False
session_message.is_emoji = message_dict.get("is_emoji", False)
if not isinstance(session_message.is_emoji, bool):
session_message.is_emoji = False
session_message.is_picture = message_dict.get("is_picture", False)
if not isinstance(session_message.is_picture, bool):
session_message.is_picture = False
session_message.is_command = message_dict.get("is_command", False)
if not isinstance(session_message.is_command, bool):
session_message.is_command = False
session_message.is_notify = message_dict.get("is_notify", False)
if not isinstance(session_message.is_notify, bool):
session_message.is_notify = False
session_message.session_id = message_dict.get("session_id", "")
if not isinstance(session_message.session_id, str):
session_message.session_id = ""
session_message.reply_to = message_dict.get("reply_to")
if session_message.reply_to is not None and not isinstance(session_message.reply_to, str):
session_message.reply_to = None
session_message.processed_plain_text = message_dict.get("processed_plain_text")
if session_message.processed_plain_text is not None and not isinstance(
session_message.processed_plain_text, str
):
session_message.processed_plain_text = None
session_message.display_message = message_dict.get("display_message")
if session_message.display_message is not None and not isinstance(session_message.display_message, str):
session_message.display_message = None
return session_message

View File

@@ -1,97 +0,0 @@
"""策略引擎
负责能力授权校验。
每个插件在 manifest 中声明能力需求Host 启动时签发能力令牌。
"""
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Set, Tuple
@dataclass
class CapabilityToken:
"""能力令牌"""
plugin_id: str
generation: int
capabilities: Set[str] = field(default_factory=set)
class PolicyEngine:
"""策略引擎
管理所有插件的能力令牌,提供授权校验。
"""
def __init__(self) -> None:
self._tokens: Dict[str, Dict[int, CapabilityToken]] = {}
def register_plugin(
self,
plugin_id: str,
generation: int,
capabilities: List[str],
) -> CapabilityToken:
"""为插件签发能力令牌"""
token = CapabilityToken(
plugin_id=plugin_id,
generation=generation,
capabilities=set(capabilities),
)
self._tokens.setdefault(plugin_id, {})[generation] = token
return token
def revoke_plugin(self, plugin_id: str, generation: Optional[int] = None) -> None:
"""撤销插件的能力令牌。"""
if generation is None:
self._tokens.pop(plugin_id, None)
return
generations = self._tokens.get(plugin_id)
if generations is None:
return
generations.pop(generation, None)
if not generations:
self._tokens.pop(plugin_id, None)
def clear(self) -> None:
"""清空所有能力令牌。"""
self._tokens.clear()
def check_capability(self, plugin_id: str, capability: str, generation: Optional[int] = None) -> Tuple[bool, str]:
"""检查插件是否有权调用某项能力
Returns:
(allowed, reason)
"""
generations = self._tokens.get(plugin_id)
if not generations:
return False, f"插件 {plugin_id} 未注册能力令牌"
if generation is None:
token = generations[max(generations)]
else:
token = generations.get(generation)
if token is None:
active_generation = max(generations)
return False, f"插件 {plugin_id} generation 不匹配: {generation} != {active_generation}"
if capability not in token.capabilities:
return False, f"插件 {plugin_id} 未获授权能力: {capability}"
if generation is not None and token.generation != generation:
return False, f"插件 {plugin_id} generation 不匹配: {generation} != {token.generation}"
return True, ""
def get_token(self, plugin_id: str) -> Optional[CapabilityToken]:
"""获取插件的能力令牌"""
generations = self._tokens.get(plugin_id)
if not generations:
return None
return generations[max(generations)]
def list_plugins(self) -> List[str]:
"""列出所有已注册的插件"""
return list(self._tokens.keys())

View File

@@ -7,7 +7,7 @@
4. 请求-响应关联与超时管理
"""
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple, Coroutine
import asyncio
import contextlib
@@ -32,7 +32,7 @@ from src.plugin_runtime.transport.base import Connection, TransportServer
logger = get_logger("plugin_runtime.host.rpc_server")
# RPC 方法处理器类型
MethodHandler = Callable[[Envelope], Awaitable[Envelope]]
MethodHandler = Callable[[Envelope], Coroutine[Any, Any, Envelope]]
class RPCServer:
@@ -55,108 +55,39 @@ class RPCServer:
self._id_gen = RequestIdGenerator()
self._connection: Optional[Connection] = None # 当前活跃的 Runner 连接
self._runner_id: Optional[str] = None
self._runner_generation: int = 0
self._staged_connection: Optional[Connection] = None
self._staged_runner_id: Optional[str] = None
self._staged_runner_generation: int = 0
self._staging_takeover: bool = False
# 方法处理器注册表
self._method_handlers: Dict[str, MethodHandler] = {}
# 等待响应的 pending 请求: request_id -> (Future, target_generation)
self._pending_requests: Dict[int, Tuple[asyncio.Future, int]] = {}
# 等待响应的 pending 请求: request_id -> Future
self._pending_requests: Dict[int, asyncio.Future[Envelope]] = {}
# 发送队列(背压控制)
self._send_queue: Optional[asyncio.Queue[Tuple[Connection, bytes, asyncio.Future[None]]]] = None
self._send_worker_task: Optional[asyncio.Task] = None
self._send_worker_task: Optional[asyncio.Task[None]] = None
# 运行状态
self._running: bool = False
self._tasks: List[asyncio.Task] = []
self._tasks: List[asyncio.Task[None]] = []
self._last_handshake_rejection_reason: str = ""
self._connection_lock: asyncio.Lock = asyncio.Lock()
@property
def session_token(self) -> str:
return self._session_token
def reset_session_token(self) -> str:
"""重新生成会话令牌(热重载时调用,防止旧 Runner 重连)"""
self._session_token = secrets.token_hex(32)
return self._session_token
def restore_session_token(self, token: str) -> None:
"""恢复指定的会话令牌(热重载回滚时调用)"""
self._session_token = token
@property
def runner_generation(self) -> int:
return self._runner_generation
@property
def staged_generation(self) -> int:
return self._staged_runner_generation
@property
def is_connected(self) -> bool:
return self._connection is not None and not self._connection.is_closed
def has_generation(self, generation: int) -> bool:
return generation == self._runner_generation or (
self._staged_connection is not None
and not self._staged_connection.is_closed
and generation == self._staged_runner_generation
)
@property
def last_handshake_rejection_reason(self) -> str:
"""返回最近一次握手被拒绝的原因。"""
return self._last_handshake_rejection_reason
def begin_staged_takeover(self) -> None:
"""允许新 Runner 以 staged 方式接入,待 Supervisor 验证后再切换为活跃连接"""
self._staging_takeover = True
async def commit_staged_takeover(self) -> None:
"""提交 staged Runner原活跃连接在提交后被关闭。"""
if self._staged_connection is None or self._staged_connection.is_closed:
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "没有可提交的新 Runner 连接")
old_connection = self._connection
old_generation = self._runner_generation
self._connection = self._staged_connection
self._runner_id = self._staged_runner_id
self._runner_generation = self._staged_runner_generation
self._staged_connection = None
self._staged_runner_id = None
self._staged_runner_generation = 0
self._staging_takeover = False
if stale_count := self._fail_pending_requests(
ErrorCode.E_PLUGIN_CRASHED,
"Runner 连接已被新 generation 接管",
generation=old_generation,
):
logger.info(f"已清理 {stale_count} 个旧 Runner 的 pending 请求")
if old_connection and old_connection is not self._connection and not old_connection.is_closed:
await old_connection.close()
async def rollback_staged_takeover(self) -> None:
"""放弃 staged Runner保留当前活跃连接。"""
staged_connection = self._staged_connection
staged_generation = self._staged_runner_generation
self._staged_connection = None
self._staged_runner_id = None
self._staged_runner_generation = 0
self._staging_takeover = False
self._fail_pending_requests(
ErrorCode.E_PLUGIN_CRASHED,
"新 Runner 预热失败,已回滚",
generation=staged_generation,
)
if staged_connection and not staged_connection.is_closed:
await staged_connection.close()
def clear_handshake_state(self) -> None:
"""清空最近一次握手拒绝状态"""
self._last_handshake_rejection_reason = ""
def register_method(self, method: str, handler: MethodHandler) -> None:
"""注册 RPC 方法处理器"""
@@ -165,6 +96,7 @@ class RPCServer:
async def start(self) -> None:
"""启动 RPC 服务器"""
self._running = True
self.clear_handshake_state()
self._send_queue = asyncio.Queue(maxsize=self._send_queue_size)
self._send_worker_task = asyncio.create_task(self._send_loop())
await self._transport.start(self._handle_connection)
@@ -173,14 +105,9 @@ class RPCServer:
async def stop(self) -> None:
"""停止 RPC 服务器"""
self._running = False
# 取消所有 pending 请求
for future, _generation in self._pending_requests.values():
if not future.done():
future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "服务器关闭"))
self._pending_requests.clear()
self._fail_queued_sends(ErrorCode.E_TIMEOUT, "服务器关闭")
self.clear_handshake_state()
self._fail_pending_requests(ErrorCode.E_SHUTTING_DOWN, "服务器正在关闭")
self._fail_queued_sends(ErrorCode.E_SHUTTING_DOWN, "服务器正在关闭")
if self._send_worker_task:
self._send_worker_task.cancel()
@@ -198,10 +125,6 @@ class RPCServer:
await self._connection.close()
self._connection = None
if self._staged_connection:
await self._staged_connection.close()
self._staged_connection = None
await self._transport.stop()
logger.info("RPC Server 已停止")
@@ -211,7 +134,6 @@ class RPCServer:
plugin_id: str = "",
payload: Optional[Dict[str, Any]] = None,
timeout_ms: int = 30000,
target_generation: Optional[int] = None,
) -> Envelope:
"""向 Runner 发送 RPC 请求并等待响应
@@ -227,18 +149,14 @@ class RPCServer:
Raises:
RPCError: 调用失败
"""
generation = target_generation or self._runner_generation
conn = self._get_connection_for_generation(generation)
if conn is None or conn.is_closed:
if not self._connection or self._connection.is_closed:
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接")
request_id = self._id_gen.next()
request_id = await self._id_gen.next()
envelope = Envelope(
request_id=request_id,
message_type=MessageType.REQUEST,
method=method,
plugin_id=plugin_id,
generation=generation,
timeout_ms=timeout_ms,
payload=payload or {},
)
@@ -246,12 +164,12 @@ class RPCServer:
# 注册 pending future
loop = asyncio.get_running_loop()
future: asyncio.Future[Envelope] = loop.create_future()
self._pending_requests[request_id] = (future, generation)
self._pending_requests[request_id] = future
try:
# 发送请求
data = self._codec.encode_envelope(envelope)
await self._enqueue_send(conn, data)
await self._enqueue_send(self._connection, data)
# 等待响应
timeout_sec = timeout_ms / 1000.0
@@ -265,150 +183,136 @@ class RPCServer:
raise
raise RPCError(ErrorCode.E_UNKNOWN, str(e)) from e
async def send_event(self, method: str, plugin_id: str = "", payload: Optional[Dict[str, Any]] = None) -> None:
"""向 Runner 发送单向事件(不等待响应)"""
conn = self._connection
if conn is None or conn.is_closed:
return
# ============ 内部方法 ============
# ========= 发送循环 =========
async def _send_loop(self) -> None:
"""后台发送循环:串行消费发送队列,统一执行连接写入。"""
if self._send_queue is None:
raise RuntimeError("没有消息队列")
request_id = self._id_gen.next()
envelope = Envelope(
request_id=request_id,
message_type=MessageType.EVENT,
method=method,
plugin_id=plugin_id,
generation=self._runner_generation,
payload=payload or {},
)
data = self._codec.encode_envelope(envelope)
await self._enqueue_send(conn, data)
while True:
try:
conn, data, send_future = await self._send_queue.get()
except asyncio.CancelledError:
break
# ─── 内部方法 ──────────────────────────────────────────────
try:
if conn.is_closed:
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接")
await conn.send_frame(data)
if not send_future.done():
send_future.set_result(None)
except asyncio.CancelledError:
if not send_future.done():
send_future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "服务器关闭"))
raise
except Exception as e:
send_error = RPCError.from_exception(e, {ConnectionError: ErrorCode.E_PLUGIN_CRASHED})
if not send_future.done():
send_future.set_exception(send_error)
finally:
self._send_queue.task_done()
# ====== 发送循环方法 ======
async def _handle_connection(self, conn: Connection) -> None:
"""处理新的 Runner 连接"""
logger.info("收到 Runner 连接")
previous_connection = self._connection
previous_generation = self._runner_generation
# 第一条消息必须是 runner.hello 握手
try:
role = await self._handle_handshake(conn)
if role is None:
await conn.close()
return
async with self._connection_lock:
self.clear_handshake_state()
success = await self._handle_handshake(conn)
if not success:
await conn.close()
return
logger.info("Runner staged 握手成功")
self._connection = conn
except Exception as e:
logger.error(f"握手失败: {e}")
await conn.close()
return
if role == "staged":
expected_generation = self._staged_runner_generation
logger.info(
f"Runner staged 握手成功: runner_id={self._staged_runner_id}, generation={self._staged_runner_generation}"
)
else:
self._connection = conn
expected_generation = self._runner_generation
logger.info(f"Runner 握手成功: runner_id={self._runner_id}, generation={self._runner_generation}")
if previous_connection and previous_connection is not conn and not previous_connection.is_closed:
logger.info("检测到新 Runner 已接管连接,关闭旧连接")
if stale_count := self._fail_pending_requests(
ErrorCode.E_PLUGIN_CRASHED,
"Runner 连接已被新 generation 接管",
generation=previous_generation,
):
logger.info(f"已清理 {stale_count} 个旧 Runner 的 pending 请求")
await previous_connection.close()
# 启动消息接收循环
try:
await self._recv_loop(conn, expected_generation=expected_generation)
await self._recv_loop(conn)
except Exception as e:
logger.error(f"连接异常断开: {e}")
finally:
if self._connection is conn:
self._connection = None
self._runner_id = None
self._fail_pending_requests(
ErrorCode.E_PLUGIN_CRASHED,
"Runner 连接已断开",
generation=expected_generation,
)
elif self._staged_connection is conn:
self._staged_connection = None
self._staged_runner_id = None
self._staged_runner_generation = 0
self._fail_pending_requests(
ErrorCode.E_PLUGIN_CRASHED,
"Staged Runner 连接已断开",
generation=expected_generation,
)
should_fail_pending_requests = False
async with self._connection_lock:
if self._connection is conn:
self._connection = None
should_fail_pending_requests = True
if should_fail_pending_requests:
self._fail_pending_requests(ErrorCode.E_PLUGIN_CRASHED, "Runner 连接已断开")
async def _handle_handshake(self, conn: Connection) -> Optional[str]:
async def _handle_handshake(self, conn: Connection) -> bool:
"""处理 runner.hello 握手"""
# 接收握手请求
data = await asyncio.wait_for(conn.recv_frame(), timeout=10.0)
envelope = self._codec.decode_envelope(data)
if envelope.method != "runner.hello":
logger.error(f"期望 runner.hello收到 {envelope.method}")
self._last_handshake_rejection_reason = "首条消息必须为 runner.hello"
error_resp = envelope.make_error_response(
ErrorCode.E_PROTOCOL_MISMATCH.value,
"首条消息必须为 runner.hello",
)
await conn.send_frame(self._codec.encode_envelope(error_resp))
return None
return False
# 解析握手 payload
hello = HelloPayload.model_validate(envelope.payload)
# 校验会话令牌
if hello.session_token != self._session_token:
logger.error("会话令牌不匹配")
resp_payload = HelloResponsePayload(
accepted=False,
reason="会话令牌无效",
)
self._last_handshake_rejection_reason = "会话令牌无效"
resp_payload = HelloResponsePayload(accepted=False, reason=self._last_handshake_rejection_reason)
resp = envelope.make_response(payload=resp_payload.model_dump())
await conn.send_frame(self._codec.encode_envelope(resp))
return None
return False
# 若已有活跃连接,直接拒绝新的握手,避免后来的连接抢占当前通道。
if self.is_connected:
logger.warning("拒绝新的 Runner 连接:已有活跃连接")
self._last_handshake_rejection_reason = "已有活跃 Runner 连接,拒绝新的握手"
resp_payload = HelloResponsePayload(accepted=False, reason=self._last_handshake_rejection_reason)
resp = envelope.make_response(payload=resp_payload.model_dump())
await conn.send_frame(self._codec.encode_envelope(resp))
return False
# 校验 SDK 版本
if not self._check_sdk_version(hello.sdk_version):
logger.error(f"SDK 版本不兼容: {hello.sdk_version}")
self._last_handshake_rejection_reason = (
f"SDK 版本 {hello.sdk_version} 不在支持范围 [{MIN_SDK_VERSION}, {MAX_SDK_VERSION}]"
)
resp_payload = HelloResponsePayload(
accepted=False,
reason=f"SDK 版本 {hello.sdk_version} 不在支持范围 [{MIN_SDK_VERSION}, {MAX_SDK_VERSION}]",
reason=self._last_handshake_rejection_reason,
)
resp = envelope.make_response(payload=resp_payload.model_dump())
await conn.send_frame(self._codec.encode_envelope(resp))
return None
return False
# 握手成功
role = "active"
assigned_generation = self._runner_generation + 1
if self._staging_takeover and self.is_connected:
role = "staged"
self._staged_connection = conn
self._staged_runner_id = hello.runner_id
self._staged_runner_generation = assigned_generation
else:
self._runner_id = hello.runner_id
self._runner_generation = assigned_generation
resp_payload = HelloResponsePayload(
accepted=True,
host_version=PROTOCOL_VERSION,
assigned_generation=assigned_generation,
)
# 发送响应
self.clear_handshake_state()
resp_payload = HelloResponsePayload(accepted=True, host_version=PROTOCOL_VERSION)
resp = envelope.make_response(payload=resp_payload.model_dump())
await conn.send_frame(self._codec.encode_envelope(resp))
return True
return role
def _check_sdk_version(self, sdk_version: str) -> bool:
"""检查 SDK 版本是否在支持范围内"""
try:
sdk_parts = _parse_version_tuple(sdk_version)
min_parts = _parse_version_tuple(MIN_SDK_VERSION)
max_parts = _parse_version_tuple(MAX_SDK_VERSION)
return min_parts <= sdk_parts <= max_parts
except (ValueError, AttributeError):
return False
async def _recv_loop(self, conn: Connection, expected_generation: int) -> None:
# ========= 接收循环 =========
async def _recv_loop(self, conn: Connection) -> None:
"""消息接收主循环"""
while self._running and not conn.is_closed:
try:
@@ -430,109 +334,40 @@ class RPCServer:
if envelope.is_response():
self._handle_response(envelope)
elif envelope.is_request():
if envelope.generation != expected_generation:
error_resp = envelope.make_error_response(
ErrorCode.E_GENERATION_MISMATCH.value,
f"过期 generation: {envelope.generation} != {expected_generation}",
)
await conn.send_frame(self._codec.encode_envelope(error_resp))
continue
# 异步处理请求Runner 发来的能力调用)
task = asyncio.create_task(self._handle_request(envelope, conn))
self._tasks.append(task)
task.add_done_callback(lambda t: self._tasks.remove(t) if t in self._tasks else None)
elif envelope.is_event():
if envelope.generation != expected_generation:
logger.warning(
f"忽略过期 generation 事件 {envelope.method}: {envelope.generation} != {expected_generation}"
)
continue
task = asyncio.create_task(self._handle_event(envelope))
elif envelope.is_broadcast():
task = asyncio.create_task(self._handle_broadcast(envelope))
self._tasks.append(task)
task.add_done_callback(lambda t: self._tasks.remove(t) if t in self._tasks else None)
else:
logger.warning(f"未知的消息类型: {envelope.message_type}")
continue
# ====== 接收循环内部方法 ======
def _handle_response(self, envelope: Envelope) -> None:
"""处理来自 Runner 的响应"""
pending = self._pending_requests.get(envelope.request_id)
if pending is None:
pending_future = self._pending_requests.pop(envelope.request_id, None)
if pending_future is None:
return
future, expected_generation = pending
if envelope.generation != expected_generation:
logger.warning(
f"忽略过期 generation 响应 {envelope.method}: {envelope.generation} != {expected_generation}"
)
return
self._pending_requests.pop(envelope.request_id, None)
if not future.done():
if not pending_future.done():
if envelope.error:
future.set_exception(RPCError.from_dict(envelope.error))
pending_future.set_exception(RPCError.from_dict(envelope.error))
else:
future.set_result(envelope)
async def _enqueue_send(self, conn: Connection, data: bytes) -> None:
"""通过发送队列串行发送消息,提供真实背压。"""
if conn.is_closed:
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接")
if self._send_queue is None:
await conn.send_frame(data)
return
loop = asyncio.get_running_loop()
send_future: asyncio.Future[None] = loop.create_future()
try:
self._send_queue.put_nowait((conn, data, send_future))
except asyncio.QueueFull:
raise RPCError(ErrorCode.E_BACKPRESSURE, "发送队列已满") from None
await send_future
async def _send_loop(self) -> None:
"""后台发送循环:串行消费发送队列,统一执行连接写入。"""
if self._send_queue is None:
return
while True:
try:
conn, data, send_future = await self._send_queue.get()
except asyncio.CancelledError:
break
try:
if conn.is_closed:
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接")
await conn.send_frame(data)
if not send_future.done():
send_future.set_result(None)
except asyncio.CancelledError:
if not send_future.done():
send_future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "服务器关闭"))
raise
except Exception as e:
send_error = e if isinstance(e, RPCError) else self._normalize_send_exception(e)
if not send_future.done():
send_future.set_exception(send_error)
finally:
self._send_queue.task_done()
@staticmethod
def _normalize_send_exception(error: Exception) -> RPCError:
if isinstance(error, ConnectionError):
return RPCError(ErrorCode.E_PLUGIN_CRASHED, str(error))
return RPCError(ErrorCode.E_UNKNOWN, str(error))
pending_future.set_result(envelope)
async def _handle_request(self, envelope: Envelope, conn: Connection) -> None:
"""处理来自 Runner 的请求(通常是能力调用 cap.*"""
handler = self._method_handlers.get(envelope.method)
if handler is None:
error_resp = envelope.make_error_response(
target_method = envelope.method
handler = self._method_handlers.get(target_method)
if not handler:
error_response = envelope.make_error_response(
ErrorCode.E_METHOD_NOT_ALLOWED.value,
f"未注册的方法: {envelope.method}",
)
await conn.send_frame(self._codec.encode_envelope(error_resp))
await conn.send_frame(self._codec.encode_envelope(error_response))
return
try:
@@ -546,59 +381,25 @@ class RPCServer:
error_resp = envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(e))
await conn.send_frame(self._codec.encode_envelope(error_resp))
async def _handle_event(self, envelope: Envelope) -> None:
"""处理来自 Runner 的事件"""
async def _handle_broadcast(self, envelope: Envelope) -> None:
if handler := self._method_handlers.get(envelope.method):
try:
result = await handler(envelope)
# 检查 handler 返回的信封是否包含错误信息
if result is not None and isinstance(result, Envelope) and result.error:
if result.error:
logger.warning(f"事件 {envelope.method} handler 返回错误: {result.error.get('message', '')}")
except Exception as e:
logger.error(f"处理事件 {envelope.method} 异常: {e}", exc_info=True)
@staticmethod
def _check_sdk_version(sdk_version: str) -> bool:
"""检查 SDK 版本是否在支持范围内"""
try:
sdk_parts = RPCServer._parse_version_tuple(sdk_version)
min_parts = RPCServer._parse_version_tuple(MIN_SDK_VERSION)
max_parts = RPCServer._parse_version_tuple(MAX_SDK_VERSION)
return min_parts <= sdk_parts <= max_parts
except (ValueError, AttributeError):
return False
@staticmethod
def _parse_version_tuple(version: str) -> Tuple[int, int, int]:
base_version = re.split(r"[-.](?:snapshot|dev|alpha|beta|rc)", version or "", flags=re.IGNORECASE)[0]
base_version = base_version.split("+", 1)[0]
parts = [part for part in base_version.split(".") if part != ""]
while len(parts) < 3:
parts.append("0")
return (int(parts[0]), int(parts[1]), int(parts[2]))
def _get_connection_for_generation(self, generation: int) -> Optional[Connection]:
if generation == self._runner_generation:
return self._connection
if generation == self._staged_runner_generation:
return self._staged_connection
return None
def _fail_pending_requests(
self,
error_code: ErrorCode,
message: str,
generation: Optional[int] = None,
) -> int:
stale_count = 0
for request_id, (future, request_generation) in list(self._pending_requests.items()):
if generation is not None and request_generation != generation:
continue
def _fail_pending_requests(self, error_code: ErrorCode, message: str) -> int:
"""失败所有等待中的请求(如连接断开时)"""
aborted_request_count = 0
for future in self._pending_requests.values():
if not future.done():
future.set_exception(RPCError(error_code, message))
stale_count += 1
self._pending_requests.pop(request_id, None)
return stale_count
aborted_request_count += 1
self._pending_requests.clear()
return aborted_request_count
def _fail_queued_sends(self, error_code: ErrorCode, message: str) -> int:
if self._send_queue is None:
@@ -617,3 +418,31 @@ class RPCServer:
self._send_queue.task_done()
return failed_count
async def _enqueue_send(self, conn: Connection, data: bytes) -> None:
"""通过发送队列串行发送消息,提供真实背压。"""
if conn.is_closed:
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接")
if self._send_queue is None:
await conn.send_frame(data)
return
loop = asyncio.get_running_loop()
send_future: asyncio.Future[None] = loop.create_future()
try:
self._send_queue.put_nowait((conn, data, send_future))
except asyncio.QueueFull:
raise RPCError(ErrorCode.E_BACK_PRESSURE, "发送队列已满") from None
await send_future
def _parse_version_tuple(version: str) -> Tuple[int, int, int]:
base_version = re.split(r"[-.](?:snapshot|dev|alpha|beta|rc)", version or "", flags=re.IGNORECASE)[0]
base_version = base_version.split("+", 1)[0]
parts = [part for part in base_version.split(".") if part != ""]
while len(parts) < 3:
parts.append("0")
return (int(parts[0]), int(parts[1]), int(parts[2]))

File diff suppressed because it is too large Load Diff

View File

@@ -1,422 +0,0 @@
"""Host-side WorkflowExecutor
6 阶段线性流转INGRESS → PRE_PROCESS → PLAN → TOOL_EXECUTE → POST_PROCESS → EGRESS
每个阶段执行顺序:
1. Host-side pre-filter: 根据 hook filter 条件过滤不相关的 hook
2. 按 priority 降序排列
3. 串行执行 blocking hook可修改 message返回 HookResult
4. 并发执行 non-blocking hook只读
5. 检查是否有 SKIP_STAGE 或 ABORT
6. PLAN 阶段内置 Command 匹配路由
支持:
- HookResult: CONTINUE / SKIP_STAGE / ABORT
- ErrorPolicy: ABORT / SKIP / LOG (per-hook)
- stage_outputs: 阶段间带命名空间的数据传递
- modification_log: 消息修改审计
"""
from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple
import asyncio
import time
import uuid
from src.common.logger import get_logger
from src.config.config import global_config
from src.plugin_runtime.host.component_registry import ComponentRegistry, RegisteredComponent
logger = get_logger("plugin_runtime.host.workflow_executor")
# 阶段顺序
STAGE_SEQUENCE: List[str] = [
"ingress",
"pre_process",
"plan",
"tool_execute",
"post_process",
"egress",
]
# HookResult 常量(与 SDK HookResult enum 值对应)
HOOK_CONTINUE = "continue"
HOOK_SKIP_STAGE = "skip_stage"
HOOK_ABORT = "abort"
# blocking hook 全局最大超时(秒):即使 hook 声明 timeout_ms=0 也不会无限等待
# 从配置文件读取,允许用户调整
def _get_blocking_timeout() -> float:
return global_config.plugin_runtime.workflow_blocking_timeout_sec
class ModificationRecord:
"""消息修改记录"""
__slots__ = ("stage", "hook_name", "timestamp", "fields_changed")
def __init__(self, stage: str, hook_name: str, fields_changed: List[str]) -> None:
self.stage = stage
self.hook_name = hook_name
self.timestamp = time.perf_counter()
self.fields_changed = fields_changed
class WorkflowContext:
"""Workflow 执行上下文"""
def __init__(self, trace_id: Optional[str] = None, stream_id: Optional[str] = None) -> None:
self.trace_id = trace_id or uuid.uuid4().hex
self.stream_id = stream_id
self.timings: Dict[str, float] = {}
self.errors: List[str] = []
# 阶段间数据传递(按 stage 命名空间隔离)
self.stage_outputs: Dict[str, Dict[str, Any]] = {}
# 消息修改审计日志
self.modification_log: List[ModificationRecord] = []
# PLAN 阶段命令匹配结果
self.matched_command: Optional[str] = None
def set_stage_output(self, stage: str, key: str, value: Any) -> None:
self.stage_outputs.setdefault(stage, {})[key] = value
def get_stage_output(self, stage: str, key: str, default: Any = None) -> Any:
return self.stage_outputs.get(stage, {}).get(key, default)
class WorkflowResult:
"""Workflow 执行结果"""
def __init__(
self,
status: str = "completed", # completed / aborted / failed
return_message: str = "",
stopped_at: str = "",
diagnostics: Optional[Dict[str, Any]] = None,
) -> None:
self.status = status
self.return_message = return_message
self.stopped_at = stopped_at
self.diagnostics = diagnostics or {}
# invoke_fn 签名
InvokeFn = Callable[[str, str, Dict[str, Any]], Awaitable[Dict[str, Any]]]
class WorkflowExecutor:
"""Host-side Workflow 执行器
实现 stage-based pipeline + per-stage hook chain with priority + early return。
"""
def __init__(self, registry: ComponentRegistry) -> None:
self._registry = registry
self._background_tasks: Set[asyncio.Task] = set()
async def execute(
self,
invoke_fn: InvokeFn,
message: Optional[Dict[str, Any]] = None,
stream_id: Optional[str] = None,
context: Optional[WorkflowContext] = None,
command_invoke_fn: Optional[InvokeFn] = None,
) -> Tuple[WorkflowResult, Optional[Dict[str, Any]], WorkflowContext]:
"""执行 workflow pipeline。
Args:
invoke_fn: 用于 workflow_step 的回调
command_invoke_fn: 用于 command 的回调(走 plugin.invoke_command
未传则复用 invoke_fn
Returns:
(result, final_message, context)
"""
ctx = context or WorkflowContext(stream_id=stream_id)
current_message = dict(message) if message else None
for stage in STAGE_SEQUENCE:
stage_start = time.perf_counter()
try:
# PLAN 阶段: 先做 Command 路由
if stage == "plan" and current_message:
cmd_result = await self._route_command(command_invoke_fn or invoke_fn, current_message, ctx)
if cmd_result is not None:
# 命令匹配成功,跳过 PLAN 阶段的 hook直接存结果进 stage_outputs
ctx.set_stage_output("plan", "command_result", cmd_result)
ctx.timings[stage] = time.perf_counter() - stage_start
continue
# 获取该阶段所有 hook已按 priority 降序排列)
all_steps = self._registry.get_workflow_steps(stage)
if not all_steps:
ctx.timings[stage] = time.perf_counter() - stage_start
continue
# 1. Pre-filter
filtered_steps = self._pre_filter(all_steps, current_message)
# 2. 分离 blocking 和 non-blocking
blocking_steps = [s for s in filtered_steps if s.metadata.get("blocking", True)]
nonblocking_steps = [s for s in filtered_steps if not s.metadata.get("blocking", True)]
# 3. 串行执行 blocking hook
skip_stage = False
for step in blocking_steps:
hook_result, modified, step_error = await self._invoke_step(
invoke_fn, step, stage, ctx, current_message
)
if step_error:
error_policy = step.metadata.get("error_policy", "abort")
ctx.errors.append(f"{step.full_name}: {step_error}")
if error_policy == "abort":
ctx.timings[stage] = time.perf_counter() - stage_start
return (
WorkflowResult(
status="failed",
return_message=step_error,
stopped_at=stage,
diagnostics={"step": step.full_name, "trace_id": ctx.trace_id},
),
current_message,
ctx,
)
elif error_policy == "skip":
logger.warning(f"[{ctx.trace_id}] hook {step.full_name} 异常(skip): {step_error}")
continue
else: # log
logger.warning(f"[{ctx.trace_id}] hook {step.full_name} 异常(log): {step_error}")
continue
# 更新消息(仅 blocking hook 有权修改)
if modified:
changed_fields = (
_diff_keys(current_message, modified) if current_message else list(modified.keys())
)
ctx.modification_log.append(ModificationRecord(stage, step.full_name, changed_fields))
current_message = modified
if hook_result == HOOK_ABORT:
ctx.timings[stage] = time.perf_counter() - stage_start
return (
WorkflowResult(
status="aborted",
return_message=f"aborted by {step.full_name}",
stopped_at=stage,
diagnostics={"step": step.full_name, "trace_id": ctx.trace_id},
),
current_message,
ctx,
)
if hook_result == HOOK_SKIP_STAGE:
skip_stage = True
break
# 4. 并发执行 non-blocking hook只读忽略返回值中的 modified_message
if nonblocking_steps and not skip_stage:
for step in nonblocking_steps:
self._track_background_task(
asyncio.create_task(
self._invoke_step_fire_and_forget(invoke_fn, step, stage, ctx, current_message)
)
)
ctx.timings[stage] = time.perf_counter() - stage_start
except Exception as e:
ctx.timings[stage] = time.perf_counter() - stage_start
ctx.errors.append(f"{stage}: {e}")
logger.error(f"[{ctx.trace_id}] 阶段 {stage} 未捕获异常: {e}", exc_info=True)
return (
WorkflowResult(
status="failed",
return_message=str(e),
stopped_at=stage,
diagnostics={"trace_id": ctx.trace_id},
),
current_message,
ctx,
)
return (
WorkflowResult(
status="completed",
return_message="workflow completed",
diagnostics={"trace_id": ctx.trace_id},
),
current_message,
ctx,
)
def _track_background_task(self, task: asyncio.Task) -> None:
"""保持 non-blocking workflow task 的强引用,直到任务结束。"""
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
# ─── 内部方法 ──────────────────────────────────────────────
def _pre_filter(
self,
steps: List[RegisteredComponent],
message: Optional[Dict[str, Any]],
) -> List[RegisteredComponent]:
"""根据 hook 声明的 filter 条件预过滤,避免无意义的 IPC 调用。"""
if not message:
return steps
result = []
for step in steps:
filter_cond = step.metadata.get("filter", {})
if not filter_cond:
result.append(step)
continue
if self._match_filter(filter_cond, message):
result.append(step)
return result
@staticmethod
def _match_filter(filter_cond: Dict[str, Any], message: Dict[str, Any]) -> bool:
"""简单 key-value 匹配过滤。
filter 中的每个 key 必须在 message 中存在且值相等,
全部匹配才通过。
"""
for key, expected in filter_cond.items():
actual = message.get(key)
if (isinstance(expected, list) and actual not in expected) or (
not isinstance(expected, list) and actual != expected
):
return False
return True
async def _invoke_step(
self,
invoke_fn: InvokeFn,
step: RegisteredComponent,
stage: str,
ctx: WorkflowContext,
message: Optional[Dict[str, Any]],
) -> Tuple[str, Optional[Dict[str, Any]], Optional[str]]:
"""调用单个 blocking hook。
Returns:
(hook_result, modified_message, error_string_or_None)
"""
timeout_ms = step.metadata.get("timeout_ms", 0)
# 使用 hook 声明的超时,但不超过全局安全阀
timeout_sec = timeout_ms / 1000 if timeout_ms > 0 else _get_blocking_timeout()
step_key = f"{stage}:{step.full_name}"
step_start = time.perf_counter()
try:
coro = invoke_fn(
step.plugin_id,
step.name,
{
"stage": stage,
"trace_id": ctx.trace_id,
"message": message,
"stage_outputs": ctx.stage_outputs,
},
)
resp = await asyncio.wait_for(coro, timeout=timeout_sec)
ctx.timings[step_key] = time.perf_counter() - step_start
hook_result = resp.get("hook_result", HOOK_CONTINUE)
modified_message = resp.get("modified_message")
# 存 stage output如果 hook 提供了)
stage_out = resp.get("stage_output")
if isinstance(stage_out, dict):
for k, v in stage_out.items():
ctx.set_stage_output(stage, k, v)
return hook_result, modified_message, None
except asyncio.TimeoutError:
ctx.timings[step_key] = time.perf_counter() - step_start
return HOOK_CONTINUE, None, f"timeout after {timeout_ms}ms"
except Exception as e:
ctx.timings[step_key] = time.perf_counter() - step_start
return HOOK_CONTINUE, None, str(e)
async def _invoke_step_fire_and_forget(
self,
invoke_fn: InvokeFn,
step: RegisteredComponent,
stage: str,
ctx: WorkflowContext,
message: Optional[Dict[str, Any]],
) -> None:
"""Non-blocking hook 调用,只读,忽略结果。"""
timeout_ms = step.metadata.get("timeout_ms", 0)
# 使用 hook 声明的超时,但无声明时回退到全局安全阀,防止 task 泄漏
timeout_sec = timeout_ms / 1000 if timeout_ms > 0 else _get_blocking_timeout()
try:
coro = invoke_fn(
step.plugin_id,
step.name,
{
"stage": stage,
"trace_id": ctx.trace_id,
"message": message,
"stage_outputs": ctx.stage_outputs,
},
)
await asyncio.wait_for(coro, timeout=timeout_sec)
except asyncio.TimeoutError:
logger.warning(f"[{ctx.trace_id}] non-blocking hook {step.full_name} 超时 ({timeout_sec}s)")
except Exception as e:
logger.debug(f"[{ctx.trace_id}] non-blocking hook {step.full_name}: {e}")
async def _route_command(
self,
invoke_fn: InvokeFn,
message: Dict[str, Any],
ctx: WorkflowContext,
) -> Optional[Dict[str, Any]]:
"""PLAN 阶段内置 Command 路由。
在 registry 中查找匹配的 command 组件,
匹配到则直接路由到对应 command handler返回执行结果。
不匹配则返回 None让 PLAN 阶段的 hook 继续执行。
"""
plain_text = message.get("plain_text", "")
if not plain_text:
return None
match_result = self._registry.find_command_by_text(plain_text)
if match_result is None:
return None
matched, matched_groups = match_result
ctx.matched_command = matched.full_name
logger.info(f"[{ctx.trace_id}] 命令匹配: {matched.full_name}")
try:
return await invoke_fn(
matched.plugin_id,
matched.name,
{
"text": plain_text,
"message": message,
"trace_id": ctx.trace_id,
"matched_groups": matched_groups,
},
)
except Exception as e:
logger.error(f"[{ctx.trace_id}] 命令 {matched.full_name} 执行失败: {e}", exc_info=True)
ctx.errors.append(f"command:{matched.full_name}: {e}")
return None
def _diff_keys(old: Dict[str, Any], new: Dict[str, Any]) -> List[str]:
"""返回 new 中与 old 不同的 key 列表。"""
return [k for k, v in new.items() if k not in old or old[k] != v]

View File

@@ -8,23 +8,27 @@
"""
from pathlib import Path
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Coroutine, Dict, Iterable, List, Optional, Sequence, Tuple
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Coroutine, Dict, Iterable, List, Optional, Sequence, Set, Tuple
import asyncio
import json
import tomlkit
from src.common.logger import get_logger
from src.config.config import global_config
from src.config.config import config_manager
from src.config.file_watcher import FileChange, FileWatcher
from src.platform_io import DeliveryBatch, InboundMessageEnvelope, get_platform_io_manager
from src.plugin_runtime.capabilities import (
RuntimeComponentCapabilityMixin,
RuntimeCoreCapabilityMixin,
RuntimeDataCapabilityMixin,
)
from src.plugin_runtime.capabilities.registry import register_capability_impls
from src.plugin_runtime.host.message_utils import MessageDict, PluginMessageUtils
from src.plugin_runtime.runner.manifest_validator import ManifestValidator
if TYPE_CHECKING:
from src.chat.message_receive.message import SessionMessage
from src.plugin_runtime.host.supervisor import PluginSupervisor
logger = get_logger("plugin_runtime.integration")
@@ -55,6 +59,7 @@ class PluginRuntimeManager(
"""
def __init__(self) -> None:
"""初始化插件运行时管理器。"""
from src.plugin_runtime.host.supervisor import PluginSupervisor
self._builtin_supervisor: Optional[PluginSupervisor] = None
@@ -63,6 +68,26 @@ class PluginRuntimeManager(
self._plugin_file_watcher: Optional[FileWatcher] = None
self._plugin_source_watcher_subscription_id: Optional[str] = None
self._plugin_config_watcher_subscriptions: Dict[str, Tuple[Path, str]] = {}
self._plugin_path_cache: Dict[str, Path] = {}
self._manifest_validator: ManifestValidator = ManifestValidator()
self._config_reload_callback: Callable[[Sequence[str]], Awaitable[None]] = self._handle_main_config_reload
self._config_reload_callback_registered: bool = False
async def _dispatch_platform_inbound(self, envelope: InboundMessageEnvelope) -> None:
"""接收 Platform IO 审核后的入站消息并送入主消息链。
Args:
envelope: Platform IO 产出的入站封装。
"""
session_message = envelope.session_message
if session_message is None and envelope.payload is not None:
session_message = PluginMessageUtils._build_session_message_from_dict(dict(envelope.payload))
if session_message is None:
raise ValueError("Platform IO 入站封装缺少可用的 SessionMessage 或 payload")
from src.chat.message_receive.bot import chat_bot
await chat_bot.receive_message(session_message)
# ─── 插件目录 ─────────────────────────────────────────────
@@ -78,6 +103,42 @@ class PluginRuntimeManager(
candidate = Path("plugins").resolve()
return [candidate] if candidate.is_dir() else []
@classmethod
def _discover_plugin_dependency_map(cls, plugin_dirs: Iterable[Path]) -> Dict[str, List[str]]:
"""扫描指定插件目录集合,返回 ``plugin_id -> dependencies`` 映射。"""
validator = ManifestValidator()
return validator.build_plugin_dependency_map(plugin_dirs)
@classmethod
def _build_group_start_order(
cls,
builtin_dirs: Sequence[Path],
third_party_dirs: Sequence[Path],
) -> List[str]:
"""根据跨 Supervisor 依赖关系决定 Runner 启动顺序。"""
builtin_dependencies = cls._discover_plugin_dependency_map(builtin_dirs)
third_party_dependencies = cls._discover_plugin_dependency_map(third_party_dirs)
builtin_plugin_ids = set(builtin_dependencies)
third_party_plugin_ids = set(third_party_dependencies)
builtin_needs_third_party = any(
dependency in third_party_plugin_ids
for dependencies in builtin_dependencies.values()
for dependency in dependencies
)
third_party_needs_builtin = any(
dependency in builtin_plugin_ids
for dependencies in third_party_dependencies.values()
for dependency in dependencies
)
if builtin_needs_third_party and third_party_needs_builtin:
raise RuntimeError("检测到跨 Supervisor 循环依赖,当前无法安全启动独立 Runner")
if builtin_needs_third_party:
return ["third_party", "builtin"]
return ["builtin", "third_party"]
# ─── 生命周期 ─────────────────────────────────────────────
async def start(self) -> None:
@@ -86,7 +147,7 @@ class PluginRuntimeManager(
logger.warning("PluginRuntimeManager 已在运行中,跳过重复启动")
return
_cfg = global_config.plugin_runtime
_cfg = config_manager.get_global_config().plugin_runtime
if not _cfg.enabled:
logger.info("插件运行时已在配置中禁用,跳过启动")
return
@@ -108,6 +169,8 @@ class PluginRuntimeManager(
logger.info("未找到任何插件目录,跳过插件运行时启动")
return
platform_io_manager = get_platform_io_manager()
# 从配置读取自定义 IPC socket 路径(留空则自动生成)
socket_path_base = _cfg.ipc_socket_path or None
@@ -132,19 +195,46 @@ class PluginRuntimeManager(
started_supervisors: List[PluginSupervisor] = []
try:
if self._builtin_supervisor:
await self._builtin_supervisor.start()
started_supervisors.append(self._builtin_supervisor)
if self._third_party_supervisor:
await self._third_party_supervisor.start()
started_supervisors.append(self._third_party_supervisor)
platform_io_manager.set_inbound_dispatcher(self._dispatch_platform_inbound)
await platform_io_manager.ensure_send_pipeline_ready()
supervisor_groups: Dict[str, Optional[PluginSupervisor]] = {
"builtin": self._builtin_supervisor,
"third_party": self._third_party_supervisor,
}
start_order = self._build_group_start_order(builtin_dirs, third_party_dirs)
for group_name in start_order:
supervisor = supervisor_groups.get(group_name)
if supervisor is None:
continue
external_plugin_versions = {
plugin_id: plugin_version
for started_supervisor in started_supervisors
for plugin_id, plugin_version in started_supervisor.get_loaded_plugin_versions().items()
}
supervisor.set_external_available_plugins(external_plugin_versions)
await supervisor.start()
started_supervisors.append(supervisor)
await self._start_plugin_file_watcher()
config_manager.register_reload_callback(self._config_reload_callback)
self._config_reload_callback_registered = True
self._started = True
logger.info(f"插件运行时已启动 — 内置: {builtin_dirs or ''}, 第三方: {third_party_dirs or ''}")
except Exception as e:
logger.error(f"插件运行时启动失败: {e}", exc_info=True)
await self._stop_plugin_file_watcher()
if self._config_reload_callback_registered:
config_manager.unregister_reload_callback(self._config_reload_callback)
self._config_reload_callback_registered = False
await asyncio.gather(*(sv.stop() for sv in started_supervisors), return_exceptions=True)
platform_io_manager.clear_inbound_dispatcher()
try:
await platform_io_manager.stop()
except Exception as platform_io_exc:
logger.warning(f"Platform IO 停止失败: {platform_io_exc}")
self._started = False
self._builtin_supervisor = None
self._third_party_supervisor = None
@@ -154,7 +244,11 @@ class PluginRuntimeManager(
if not self._started:
return
platform_io_manager = get_platform_io_manager()
await self._stop_plugin_file_watcher()
if self._config_reload_callback_registered:
config_manager.unregister_reload_callback(self._config_reload_callback)
self._config_reload_callback_registered = False
coroutines: List[Coroutine[Any, Any, None]] = []
if self._builtin_supervisor:
@@ -162,18 +256,32 @@ class PluginRuntimeManager(
if self._third_party_supervisor:
coroutines.append(self._third_party_supervisor.stop())
stop_errors: List[str] = []
try:
await asyncio.gather(*coroutines, return_exceptions=True)
logger.info("插件运行时已停止")
except Exception as e:
logger.error(f"插件运行时停止失败: {e}", exc_info=True)
results = await asyncio.gather(*coroutines, return_exceptions=True)
for result in results:
if isinstance(result, Exception):
stop_errors.append(str(result))
platform_io_manager.clear_inbound_dispatcher()
try:
await platform_io_manager.stop()
except Exception as exc:
stop_errors.append(f"Platform IO: {exc}")
if stop_errors:
logger.error(f"插件运行时停止过程中存在错误: {'; '.join(stop_errors)}")
else:
logger.info("插件运行时已停止")
finally:
self._started = False
self._builtin_supervisor = None
self._third_party_supervisor = None
self._plugin_path_cache.clear()
@property
def is_running(self) -> bool:
"""返回插件运行时是否处于启动状态。"""
return self._started
@property
@@ -181,11 +289,176 @@ class PluginRuntimeManager(
"""获取所有活跃的 Supervisor"""
return [s for s in (self._builtin_supervisor, self._third_party_supervisor) if s is not None]
def _build_registered_dependency_map(self) -> Dict[str, Set[str]]:
"""根据当前已注册插件构建全局依赖图。"""
dependency_map: Dict[str, Set[str]] = {}
for supervisor in self.supervisors:
for plugin_id, registration in getattr(supervisor, "_registered_plugins", {}).items():
dependency_map[plugin_id] = {
str(dependency or "").strip()
for dependency in getattr(registration, "dependencies", [])
if str(dependency or "").strip()
}
return dependency_map
@staticmethod
def _collect_reverse_dependents(
plugin_ids: Set[str],
dependency_map: Dict[str, Set[str]],
) -> Set[str]:
"""根据依赖图收集反向依赖闭包。"""
impacted_plugins: Set[str] = set(plugin_ids)
changed = True
while changed:
changed = False
for registered_plugin_id, dependencies in dependency_map.items():
if registered_plugin_id in impacted_plugins:
continue
if dependencies & impacted_plugins:
impacted_plugins.add(registered_plugin_id)
changed = True
return impacted_plugins
def _build_registered_supervisor_map(self) -> Dict[str, "PluginSupervisor"]:
"""构建当前已注册插件到所属 Supervisor 的映射。"""
return {
plugin_id: supervisor
for supervisor in self.supervisors
for plugin_id in supervisor.get_loaded_plugin_ids()
}
def _build_external_available_plugins_for_supervisor(self, target_supervisor: "PluginSupervisor") -> Dict[str, str]:
"""收集某个 Supervisor 可用的外部插件版本映射。"""
external_plugin_versions: Dict[str, str] = {}
for supervisor in self.supervisors:
if supervisor is target_supervisor:
continue
external_plugin_versions.update(supervisor.get_loaded_plugin_versions())
return external_plugin_versions
def _find_supervisor_by_plugin_directory(self, plugin_id: str) -> Optional["PluginSupervisor"]:
"""根据插件目录推断应负责该插件重载的 Supervisor。"""
for supervisor in self.supervisors:
if self._get_plugin_path_for_supervisor(supervisor, plugin_id) is not None:
return supervisor
return None
def _warn_skipped_cross_supervisor_reload(
self,
requested_loaded_plugin_ids: Set[str],
dependency_map: Dict[str, Set[str]],
supervisor_by_plugin: Dict[str, "PluginSupervisor"],
) -> None:
"""记录因跨 Supervisor 边界而未参与联动重载的插件。"""
if not requested_loaded_plugin_ids:
return
handled_plugin_ids: Set[str] = set()
for supervisor in self.supervisors:
local_requested_plugin_ids = {
plugin_id
for plugin_id in requested_loaded_plugin_ids
if supervisor_by_plugin.get(plugin_id) is supervisor
}
if not local_requested_plugin_ids:
continue
local_plugin_ids = set(supervisor.get_loaded_plugin_ids())
local_dependency_map = {
plugin_id: {
dependency
for dependency in dependency_map.get(plugin_id, set())
if dependency in local_plugin_ids
}
for plugin_id in local_plugin_ids
}
handled_plugin_ids.update(
self._collect_reverse_dependents(local_requested_plugin_ids, local_dependency_map)
)
impacted_plugin_ids = self._collect_reverse_dependents(requested_loaded_plugin_ids, dependency_map)
skipped_plugin_ids = sorted(impacted_plugin_ids - handled_plugin_ids)
if not skipped_plugin_ids:
return
logger.warning(
f"插件 {', '.join(sorted(requested_loaded_plugin_ids))} 存在跨 Supervisor 依赖方未联动重载: "
f"{', '.join(skipped_plugin_ids)}。当前仅在单个 Supervisor 内执行联动重载;"
"跨 Supervisor API 调用仍然可用。如需联动重载,请将相关插件放在同一个 Supervisor 内。"
)
async def reload_plugins_globally(self, plugin_ids: Sequence[str], reason: str = "manual") -> bool:
"""按 Supervisor 分组执行精确重载。
仅在单个 Supervisor 内执行依赖联动;跨 Supervisor 依赖方仅记录告警,
不再自动参与本次热重载。
"""
normalized_plugin_ids = [
normalized_plugin_id
for plugin_id in plugin_ids
if (normalized_plugin_id := str(plugin_id or "").strip())
]
if not normalized_plugin_ids:
return True
dependency_map = self._build_registered_dependency_map()
supervisor_by_plugin = self._build_registered_supervisor_map()
supervisor_roots: Dict["PluginSupervisor", List[str]] = {}
requested_loaded_plugin_ids: Set[str] = set()
missing_plugin_ids: List[str] = []
for plugin_id in normalized_plugin_ids:
supervisor = supervisor_by_plugin.get(plugin_id)
if supervisor is not None:
requested_loaded_plugin_ids.add(plugin_id)
else:
supervisor = self._find_supervisor_by_plugin_directory(plugin_id)
if supervisor is None:
missing_plugin_ids.append(plugin_id)
continue
if plugin_id not in supervisor_roots.setdefault(supervisor, []):
supervisor_roots[supervisor].append(plugin_id)
if missing_plugin_ids:
logger.warning(f"以下插件未找到可重载的 Supervisor已跳过: {', '.join(sorted(missing_plugin_ids))}")
self._warn_skipped_cross_supervisor_reload(
requested_loaded_plugin_ids=requested_loaded_plugin_ids,
dependency_map=dependency_map,
supervisor_by_plugin=supervisor_by_plugin,
)
success = True
for supervisor, root_plugin_ids in supervisor_roots.items():
if not root_plugin_ids:
continue
reloaded = await supervisor.reload_plugins(
plugin_ids=root_plugin_ids,
reason=reason,
external_available_plugins=self._build_external_available_plugins_for_supervisor(supervisor),
)
success = success and reloaded
return success and not missing_plugin_ids
async def notify_plugin_config_updated(
self,
plugin_id: str,
config_data: Optional[Dict[str, Any]] = None,
config_version: str = "",
config_scope: str = "self",
) -> bool:
"""向拥有该插件的 Supervisor 推送配置更新事件。
@@ -193,6 +466,7 @@ class PluginRuntimeManager(
plugin_id: 插件 ID
config_data: 可选的配置数据(如果为 None 则由 Supervisor 从磁盘加载)
config_version: 可选的配置版本字符串,供 Supervisor 进行版本控制
config_scope: 配置变更范围。
"""
if not self._started:
return False
@@ -209,23 +483,78 @@ class PluginRuntimeManager(
config_payload = (
config_data
if config_data is not None
else self._load_plugin_config_for_supervisor(plugin_id, plugin_dirs=sv._plugin_dirs)
else self._load_plugin_config_for_supervisor(sv, plugin_id)
)
await sv.notify_plugin_config_updated(
return await sv.notify_plugin_config_updated(
plugin_id=plugin_id,
config_data=config_payload,
config_version=config_version,
config_scope=config_scope,
)
return True
@staticmethod
def _normalize_config_reload_scopes(changed_scopes: Sequence[str]) -> tuple[str, ...]:
"""规范化配置热重载范围列表。
Args:
changed_scopes: 原始配置热重载范围列表。
Returns:
tuple[str, ...]: 去重后的有效配置范围元组。
"""
normalized_scopes: list[str] = []
for scope in changed_scopes:
normalized_scope = str(scope or "").strip().lower()
if normalized_scope not in {"bot", "model"}:
continue
if normalized_scope not in normalized_scopes:
normalized_scopes.append(normalized_scope)
return tuple(normalized_scopes)
async def _broadcast_config_reload(self, scope: str, config_data: Dict[str, Any]) -> None:
"""向订阅指定范围的插件广播配置热重载。
Args:
scope: 配置变更范围,仅支持 ``bot`` 或 ``model``。
config_data: 最新配置数据。
"""
for supervisor in self.supervisors:
for plugin_id in supervisor.get_config_reload_subscribers(scope):
delivered = await supervisor.notify_plugin_config_updated(
plugin_id=plugin_id,
config_data=config_data,
config_version="",
config_scope=scope,
)
if not delivered:
logger.warning(f"向插件 {plugin_id} 广播 {scope} 配置热重载失败")
async def _handle_main_config_reload(self, changed_scopes: Sequence[str]) -> None:
"""处理 bot/model 主配置热重载广播。
Args:
changed_scopes: 本次热重载命中的配置范围列表。
"""
if not self._started:
return
normalized_scopes = self._normalize_config_reload_scopes(changed_scopes)
if "bot" in normalized_scopes:
await self._broadcast_config_reload("bot", config_manager.get_global_config().model_dump(mode="json"))
if "model" in normalized_scopes:
await self._broadcast_config_reload("model", config_manager.get_model_config().model_dump(mode="json"))
# ─── 事件桥接 ──────────────────────────────────────────────
async def bridge_event(
self,
event_type_value: str,
message_dict: Optional[Dict[str, Any]] = None,
message_dict: Optional[MessageDict] = None,
extra_args: Optional[Dict[str, Any]] = None,
) -> Tuple[bool, Optional[Dict[str, Any]]]:
) -> Tuple[bool, Optional[MessageDict]]:
"""将事件分发到所有 Supervisor
Returns:
@@ -235,17 +564,23 @@ class PluginRuntimeManager(
return True, None
new_event_type: str = _EVENT_TYPE_MAP.get(event_type_value, event_type_value)
modified: Optional[Dict[str, Any]] = None
modified: Optional[MessageDict] = None
current_message: Optional["SessionMessage"] = (
PluginMessageUtils._build_session_message_from_dict(dict(message_dict))
if message_dict is not None
else None
)
for sv in self.supervisors:
try:
cont, mod = await sv.dispatch_event(
event_type=new_event_type,
message=modified or message_dict,
message=current_message,
extra_args=extra_args,
)
if mod is not None:
modified = mod
current_message = mod
modified = PluginMessageUtils._session_message_to_dict(mod)
if not cont:
return False, modified
except Exception as e:
@@ -295,6 +630,37 @@ class PluginRuntimeManager(
timeout_ms=timeout_ms,
)
async def try_send_message_via_platform_io(
self,
message: "SessionMessage",
) -> Optional[DeliveryBatch]:
"""尝试通过 Platform IO 中间层发送消息。
Args:
message: 待发送的内部会话消息。
Returns:
Optional[DeliveryBatch]: 若当前消息命中了至少一条发送路由,则返回
实际发送结果;若没有可用路由或 Platform IO 尚未启动,则返回 ``None``。
"""
if not self._started:
return None
platform_io_manager = get_platform_io_manager()
if not platform_io_manager.is_started:
return None
try:
route_key = platform_io_manager.build_route_key_from_message(message)
except Exception as exc:
logger.warning(f"根据消息构造 Platform IO 路由键失败: {exc}")
return None
if not platform_io_manager.resolve_drivers(route_key):
return None
return await platform_io_manager.send_message(message, route_key)
def _get_supervisors_for_plugin(self, plugin_id: str) -> List["PluginSupervisor"]:
"""返回当前持有指定插件的所有 Supervisor。
@@ -314,30 +680,38 @@ class PluginRuntimeManager(
raise RuntimeError(f"插件 {plugin_id} 同时存在于多个 Supervisor 中,无法安全路由")
return matches[0] if matches else None
@staticmethod
def _find_duplicate_plugin_ids(plugin_dirs: List[Path]) -> Dict[str, List[Path]]:
async def load_plugin_globally(self, plugin_id: str, reason: str = "manual") -> bool:
"""加载或重载单个插件,并为其补齐跨 Supervisor 外部依赖。"""
normalized_plugin_id = str(plugin_id or "").strip()
if not normalized_plugin_id:
return False
try:
registered_supervisor = self._get_supervisor_for_plugin(normalized_plugin_id)
except RuntimeError:
return False
if registered_supervisor is not None:
return await self.reload_plugins_globally([normalized_plugin_id], reason=reason)
supervisor = self._find_supervisor_by_plugin_directory(normalized_plugin_id)
if supervisor is None:
return False
return await supervisor.reload_plugins(
plugin_ids=[normalized_plugin_id],
reason=reason,
external_available_plugins=self._build_external_available_plugins_for_supervisor(supervisor),
)
@classmethod
def _find_duplicate_plugin_ids(cls, plugin_dirs: List[Path]) -> Dict[str, List[Path]]:
"""扫描插件目录,找出被多个目录重复声明的插件 ID。"""
plugin_locations: Dict[str, List[Path]] = {}
for base_dir in plugin_dirs:
if not base_dir.is_dir():
continue
for entry in base_dir.iterdir():
if not entry.is_dir():
continue
manifest_path = entry / "_manifest.json"
plugin_path = entry / "plugin.py"
if not manifest_path.exists() or not plugin_path.exists():
continue
plugin_id = entry.name
try:
with open(manifest_path, "r", encoding="utf-8") as manifest_file:
manifest = json.load(manifest_file)
plugin_id = str(manifest.get("name", entry.name)).strip() or entry.name
except Exception:
continue
plugin_locations.setdefault(plugin_id, []).append(entry)
validator = ManifestValidator()
for plugin_path, manifest in validator.iter_plugin_manifests(plugin_dirs):
plugin_locations.setdefault(manifest.id, []).append(plugin_path)
return {
plugin_id: sorted(dict.fromkeys(paths), key=lambda p: str(p))
@@ -370,6 +744,7 @@ class PluginRuntimeManager(
async def _stop_plugin_file_watcher(self) -> None:
"""停止插件文件监视器,并清理所有已注册订阅。"""
if self._plugin_file_watcher is None:
self._plugin_path_cache.clear()
return
for _plugin_id, (_config_path, subscription_id) in list(self._plugin_config_watcher_subscriptions.items()):
self._plugin_file_watcher.unsubscribe(subscription_id)
@@ -379,12 +754,79 @@ class PluginRuntimeManager(
self._plugin_source_watcher_subscription_id = None
await self._plugin_file_watcher.stop()
self._plugin_file_watcher = None
self._plugin_path_cache.clear()
def _iter_plugin_dirs(self) -> Iterable[Path]:
"""迭代所有 Supervisor 当前管理的插件根目录。"""
for supervisor in self.supervisors:
yield from getattr(supervisor, "_plugin_dirs", [])
@staticmethod
def _iter_candidate_plugin_paths(plugin_dirs: Iterable[Path]) -> Iterable[Path]:
"""迭代所有可能的插件目录路径。
Args:
plugin_dirs: 一个或多个插件根目录。
Yields:
Path: 单个插件目录路径。
"""
for plugin_dir in plugin_dirs:
plugin_root = Path(plugin_dir).resolve()
if not plugin_root.is_dir():
continue
for entry in plugin_root.iterdir():
if entry.is_dir():
yield entry.resolve()
def _read_plugin_id_from_plugin_path(self, plugin_path: Path) -> Optional[str]:
"""从单个插件目录中读取 manifest 声明的插件 ID。
Args:
plugin_path: 单个插件目录路径。
Returns:
Optional[str]: 解析成功时返回插件 ID否则返回 ``None``。
"""
return self._manifest_validator.read_plugin_id_from_plugin_path(plugin_path)
def _iter_discovered_plugin_paths(self, plugin_dirs: Iterable[Path]) -> Iterable[Tuple[str, Path]]:
"""迭代目录中可解析到的插件 ID 与实际目录路径。
Args:
plugin_dirs: 一个或多个插件根目录。
Yields:
Tuple[str, Path]: ``(plugin_id, plugin_path)`` 二元组。
"""
for plugin_path in self._iter_candidate_plugin_paths(plugin_dirs):
if plugin_id := self._read_plugin_id_from_plugin_path(plugin_path):
yield plugin_id, plugin_path
def _get_plugin_path_for_supervisor(self, supervisor: Any, plugin_id: str) -> Optional[Path]:
"""为指定 Supervisor 定位某个插件的实际目录。
Args:
supervisor: 目标 Supervisor。
plugin_id: 插件 ID。
Returns:
Optional[Path]: 插件目录路径;未找到时返回 ``None``。
"""
cached_path = self._plugin_path_cache.get(plugin_id)
if cached_path is not None:
for plugin_dir in getattr(supervisor, "_plugin_dirs", []):
if self._plugin_dir_matches(cached_path, Path(plugin_dir)):
return cached_path
for candidate_plugin_id, plugin_path in self._iter_discovered_plugin_paths(getattr(supervisor, "_plugin_dirs", [])):
if candidate_plugin_id != plugin_id:
continue
self._plugin_path_cache[plugin_id] = plugin_path
return plugin_path
return None
def _refresh_plugin_config_watch_subscriptions(self) -> None:
"""按当前已注册插件集合刷新 config.toml 的单插件订阅。
@@ -394,7 +836,11 @@ class PluginRuntimeManager(
if self._plugin_file_watcher is None:
return
desired_config_paths = dict(self._iter_registered_plugin_config_paths())
desired_plugin_paths = dict(self._iter_registered_plugin_paths())
self._plugin_path_cache = desired_plugin_paths.copy()
desired_config_paths = {
plugin_id: plugin_path / "config.toml" for plugin_id, plugin_path in desired_plugin_paths.items()
}
for plugin_id, (_old_path, subscription_id) in list(self._plugin_config_watcher_subscriptions.items()):
if desired_config_paths.get(plugin_id) == self._plugin_config_watcher_subscriptions[plugin_id][0]:
@@ -418,28 +864,35 @@ class PluginRuntimeManager(
"""为指定插件生成配置文件变更回调。"""
async def _callback(changes: Sequence[FileChange]) -> None:
"""将 watcher 事件转发到指定插件的配置处理逻辑。
Args:
changes: 当前批次收集到的文件变更列表。
"""
await self._handle_plugin_config_changes(plugin_id, changes)
return _callback
def _iter_registered_plugin_config_paths(self) -> Iterable[Tuple[str, Path]]:
"""迭代当前所有已注册插件的 config.toml 路径。"""
def _iter_registered_plugin_paths(self) -> Iterable[Tuple[str, Path]]:
"""迭代当前所有已注册插件的实际目录路径。"""
for supervisor in self.supervisors:
for plugin_id in getattr(supervisor, "_registered_plugins", {}).keys():
if config_path := self._get_plugin_config_path_for_supervisor(supervisor, plugin_id):
yield plugin_id, config_path
if plugin_path := self._get_plugin_path_for_supervisor(supervisor, plugin_id):
yield plugin_id, plugin_path
def _get_plugin_config_path_for_supervisor(self, supervisor: Any, plugin_id: str) -> Optional[Path]:
"""从指定 Supervisor 的插件目录中定位某个插件的 config.toml。"""
for plugin_dir in getattr(supervisor, "_plugin_dirs", []):
plugin_dir = Path(plugin_dir)
plugin_path = plugin_dir.resolve() / plugin_id
if plugin_path.is_dir():
return plugin_path / "config.toml"
return None
plugin_path = self._get_plugin_path_for_supervisor(supervisor, plugin_id)
return None if plugin_path is None else plugin_path / "config.toml"
async def _handle_plugin_config_changes(self, plugin_id: str, changes: Sequence[FileChange]) -> None:
"""处理单个插件配置文件变化,并仅向目标插件推送配置更新。"""
"""处理单个插件配置文件变化,并定向派发自配置更新。
Args:
plugin_id: 发生配置变更的插件 ID。
changes: 当前批次收集到的配置文件变更列表。
"""
if not self._started or not changes:
return
@@ -453,18 +906,24 @@ class PluginRuntimeManager(
return
try:
await supervisor.notify_plugin_config_updated(
config_payload = self._load_plugin_config_for_supervisor(supervisor, plugin_id)
delivered = await supervisor.notify_plugin_config_updated(
plugin_id=plugin_id,
config_data=self._load_plugin_config_for_supervisor(plugin_id, getattr(supervisor, "_plugin_dirs", [])),
config_data=config_payload,
config_version="",
config_scope="self",
)
if not delivered:
logger.warning(f"插件 {plugin_id} 配置文件变更后通知失败")
except Exception as exc:
logger.warning(f"插件 {plugin_id} 配置热更新通知失败: {exc}")
logger.warning(f"插件 {plugin_id} 配置文件变更处理失败: {exc}")
async def _handle_plugin_source_changes(self, changes: Sequence[FileChange]) -> None:
"""处理插件源码相关变化。
这里仅负责源码、清单等会影响插件装载状态的文件;配置文件的变化会由
单独的 per-plugin watcher 处理,避免把单插件配置更新放大成全量 reload。
单独的 per-plugin watcher 处理,并定向派发给目标插件的
``on_config_update()``,避免放大成不必要的跨插件 reload。
"""
if not self._started or not changes:
return
@@ -477,7 +936,7 @@ class PluginRuntimeManager(
logger.error(f"检测到重复插件 ID跳过本次插件热重载: {details}")
return
reload_supervisors: List[Any] = []
changed_plugin_ids: List[str] = []
changed_paths = [change.path.resolve() for change in changes]
for supervisor in self.supervisors:
@@ -485,13 +944,12 @@ class PluginRuntimeManager(
plugin_id = self._match_plugin_id_for_supervisor(supervisor, path)
if plugin_id is None:
continue
if (path.name in {"plugin.py", "_manifest.json"} or path.suffix == ".py") and supervisor not in reload_supervisors:
reload_supervisors.append(supervisor)
if path.name in {"plugin.py", "_manifest.json"} or path.suffix == ".py":
if plugin_id not in changed_plugin_ids:
changed_plugin_ids.append(plugin_id)
for supervisor in reload_supervisors:
await supervisor.reload_plugins(reason="file_watcher")
if reload_supervisors:
if changed_plugin_ids:
await self.reload_plugins_globally(changed_plugin_ids, reason="file_watcher")
self._refresh_plugin_config_watch_subscriptions()
@staticmethod
@@ -502,36 +960,47 @@ class PluginRuntimeManager(
def _match_plugin_id_for_supervisor(self, supervisor: Any, path: Path) -> Optional[str]:
"""根据变更路径为指定 Supervisor 推断受影响的插件 ID。"""
for plugin_id, _reg in getattr(supervisor, "_registered_plugins", {}).items():
for plugin_dir in getattr(supervisor, "_plugin_dirs", []):
plugin_dir = Path(plugin_dir)
candidate_dir = plugin_dir.resolve() / plugin_id
if path == candidate_dir or path.is_relative_to(candidate_dir):
return plugin_id
resolved_path = path.resolve()
for plugin_id in getattr(supervisor, "_registered_plugins", {}).keys():
plugin_path = self._get_plugin_path_for_supervisor(supervisor, plugin_id)
if plugin_path is not None and (resolved_path == plugin_path or resolved_path.is_relative_to(plugin_path)):
return plugin_id
for plugin_id, plugin_path in self._plugin_path_cache.items():
if not any(self._plugin_dir_matches(plugin_path, Path(plugin_dir)) for plugin_dir in getattr(supervisor, "_plugin_dirs", [])):
continue
if resolved_path == plugin_path or resolved_path.is_relative_to(plugin_path):
return plugin_id
for plugin_id, plugin_path in self._iter_discovered_plugin_paths(getattr(supervisor, "_plugin_dirs", [])):
if resolved_path == plugin_path or resolved_path.is_relative_to(plugin_path):
self._plugin_path_cache[plugin_id] = plugin_path
return plugin_id
for plugin_dir in getattr(supervisor, "_plugin_dirs", []):
plugin_dir = Path(plugin_dir)
plugin_root = plugin_dir.resolve()
if self._plugin_dir_matches(path, plugin_dir) and (relative_parts := path.relative_to(plugin_root).parts):
return relative_parts[0]
return None
@staticmethod
def _load_plugin_config_for_supervisor(plugin_id: str, plugin_dirs: Iterable[Path]) -> Dict[str, Any]:
def _load_plugin_config_for_supervisor(self, supervisor: Any, plugin_id: str) -> Dict[str, Any]:
"""从给定插件目录集合中读取目标插件的配置内容。"""
for plugin_dir in plugin_dirs:
plugin_path = plugin_dir.resolve() / plugin_id
if plugin_path.is_dir():
config_path = plugin_path / "config.toml"
if not config_path.exists():
return {}
with open(config_path, "r", encoding="utf-8") as handle:
return tomlkit.load(handle).unwrap()
return {}
plugin_path = self._get_plugin_path_for_supervisor(supervisor, plugin_id)
if plugin_path is None:
return {}
config_path = plugin_path / "config.toml"
if not config_path.exists():
return {}
with open(config_path, "r", encoding="utf-8") as handle:
return tomlkit.load(handle).unwrap()
# ─── 能力实现注册 ──────────────────────────────────────────
def _register_capability_impls(self, supervisor: "PluginSupervisor") -> None:
"""向指定 Supervisor 注册主程序能力实现。
Args:
supervisor: 需要注册能力实现的目标 Supervisor。
"""
register_capability_impls(self, supervisor)

View File

@@ -7,52 +7,52 @@
from enum import Enum
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
import logging as stdlib_logging
import time
from pydantic import BaseModel, Field
# ─── 协议常量 ──────────────────────────────────────────────────────
PROTOCOL_VERSION = "1.0"
# ====== 协议常量 ======
PROTOCOL_VERSION = "1.0.0"
# 支持的 SDK 版本范围Host 在握手时校验)
MIN_SDK_VERSION = "1.0.0"
MAX_SDK_VERSION = "1.99.99"
# ─── 消息类型 ──────────────────────────────────────────────────────
MAX_SDK_VERSION = "2.99.99"
# ====== 消息类型 ======
class MessageType(str, Enum):
"""RPC 消息类型"""
REQUEST = "request"
RESPONSE = "response"
EVENT = "event"
BROADCAST = "broadcast"
# ─── 请求 ID 生成器 ───────────────────────────────────────────────
class ConfigReloadScope(str, Enum):
"""配置热重载范围。"""
SELF = "self"
BOT = "bot"
MODEL = "model"
# ====== 请求 ID 生成器 ======
class RequestIdGenerator:
"""单调递增 int64 请求 ID 生成器(线程安全由调用方保证或使用 asyncio"""
"""单调递增 int64 请求 ID 生成器"""
def __init__(self, start: int = 1) -> None:
self._counter = start
def next(self) -> int:
async def next(self) -> int:
current = self._counter
self._counter += 1
return current
# ─── Envelope 模型 ─────────────────────────────────────────────────
# ====== Envelope 模型 ======
class Envelope(BaseModel):
"""RPC 统一信封
"""RPC 统一消息封装
所有 Host <-> Runner 消息均封装为此格式。
序列化流程Envelope -> .model_dump() -> MsgPack encode
@@ -60,15 +60,23 @@ class Envelope(BaseModel):
"""
protocol_version: str = Field(default=PROTOCOL_VERSION, description="协议版本")
"""协议版本"""
request_id: int = Field(description="单调递增请求 ID")
"""单调递增请求 ID"""
message_type: MessageType = Field(description="消息类型")
"""消息类型"""
method: str = Field(default="", description="RPC 方法名")
"""RPC 方法名"""
plugin_id: str = Field(default="", description="目标插件 ID")
timestamp_ms: int = Field(default_factory=lambda: int(time.time() * 1000), description="发送时间戳(ms)")
timeout_ms: int = Field(default=30000, description="相对超时(ms)")
generation: int = Field(default=0, description="Runner generation 编号")
"""目标插件 ID"""
timestamp_ms: int = Field(default_factory=lambda: int(time.time() * 1000), description="发送时间戳 (ms)")
"""发送时间戳 (ms)"""
timeout_ms: int = Field(default=30000, description="相对超时 (ms)")
"""相对超时 (ms)"""
payload: Dict[str, Any] = Field(default_factory=dict, description="业务数据")
error: Optional[Dict[str, Any]] = Field(default=None, description="错误信息(仅 response)")
"""业务数据"""
error: Optional[Dict[str, Any]] = Field(default=None, description="错误信息 (仅 response)")
"""错误信息 (仅 response)"""
def is_request(self) -> bool:
return self.message_type == MessageType.REQUEST
@@ -76,8 +84,8 @@ class Envelope(BaseModel):
def is_response(self) -> bool:
return self.message_type == MessageType.RESPONSE
def is_event(self) -> bool:
return self.message_type == MessageType.EVENT
def is_broadcast(self) -> bool:
return self.message_type == MessageType.BROADCAST
def make_response(
self, payload: Optional[Dict[str, Any]] = None, error: Optional[Dict[str, Any]] = None
@@ -89,7 +97,6 @@ class Envelope(BaseModel):
message_type=MessageType.RESPONSE,
method=self.method,
plugin_id=self.plugin_id,
generation=self.generation,
payload=payload or {},
error=error,
)
@@ -105,153 +112,302 @@ class Envelope(BaseModel):
)
# ─── 握手消息 ──────────────────────────────────────────────────────
# ====== 握手请求与响应 ======
class HelloPayload(BaseModel):
"""runner.hello 握手请求 payload"""
runner_id: str = Field(description="Runner 进程唯一标识")
"""Runner 进程唯一标识"""
sdk_version: str = Field(description="SDK 版本号")
"""SDK 版本号"""
session_token: str = Field(description="一次性会话令牌")
"""一次性会话令牌"""
class HelloResponsePayload(BaseModel):
"""runner.hello 握手响应 payload"""
accepted: bool = Field(description="是否接受连接")
"""是否接受连接"""
host_version: str = Field(default="", description="Host 版本号")
assigned_generation: int = Field(default=0, description="分配的 generation 编")
reason: str = Field(default="", description="拒绝原因(若 accepted=False)")
# ─── 组件注册消息 ──────────────────────────────────────────────────
"""Host 版本"""
reason: str = Field(default="", description="拒绝原因 (若 accepted=False)")
"""拒绝原因 (若 `accepted`=`False`)"""
# ====== 组件注册消息 ======
class ComponentDeclaration(BaseModel):
"""单个组件声明"""
name: str = Field(description="组件名称")
component_type: str = Field(description="组件类型: action/command/tool/event_handler")
"""组件名称"""
component_type: str = Field(
description="组件类型action/command/tool/event_handler/hook_handler/message_gateway"
)
"""组件类型:`action`/`command`/`tool`/`event_handler`/`hook_handler`/`message_gateway`"""
plugin_id: str = Field(description="所属插件 ID")
"""所属插件 ID"""
metadata: Dict[str, Any] = Field(default_factory=dict, description="组件元数据")
"""组件元数据"""
class RegisterComponentsPayload(BaseModel):
"""plugin.register_components 请求 payload"""
class RegisterPluginPayload(BaseModel):
"""插件组件注册请求载荷。
该模型同时用于 ``plugin.register_components`` 与兼容旧命名的
``plugin.register_plugin`` 请求。
"""
plugin_id: str = Field(description="插件 ID")
"""插件 ID"""
plugin_version: str = Field(default="1.0.0", description="插件版本")
"""插件版本"""
components: List[ComponentDeclaration] = Field(default_factory=list, description="组件列表")
"""组件列表"""
capabilities_required: List[str] = Field(default_factory=list, description="所需能力列表")
"""所需能力列表"""
dependencies: List[str] = Field(default_factory=list, description="插件级依赖插件 ID 列表")
"""插件级依赖插件 ID 列表"""
config_reload_subscriptions: List[str] = Field(default_factory=list, description="订阅的全局配置热重载范围")
"""订阅的全局配置热重载范围"""
class BootstrapPluginPayload(BaseModel):
"""plugin.bootstrap 请求 payload"""
plugin_id: str = Field(description="插件 ID")
"""插件 ID"""
plugin_version: str = Field(default="1.0.0", description="插件版本")
"""插件版本"""
capabilities_required: List[str] = Field(default_factory=list, description="所需能力列表")
"""所需能力列表"""
# ─── 调用消息 ──────────────────────────────────────────────────────
# ====== 插件调用请求和响应 ======
class InvokePayload(BaseModel):
"""plugin.invoke_* 请求 payload"""
"""plugin.invoke.* 请求 payload"""
component_name: str = Field(description="要调用的组件名称")
"""要调用的组件名称"""
args: Dict[str, Any] = Field(default_factory=dict, description="调用参数")
"""调用参数"""
class InvokeResultPayload(BaseModel):
"""plugin.invoke_* 响应 payload"""
"""plugin.invoke.* 响应 payload"""
success: bool = Field(description="是否成功")
"""是否成功"""
result: Any = Field(default=None, description="返回值")
"""返回值"""
# ─── 能力调用消息 ──────────────────────────────────────────────────
# ====== 能力调用消息 ======
class CapabilityRequestPayload(BaseModel):
"""cap.* 请求 payload插件 -> Host 能力调用)"""
capability: str = Field(description="能力名称,如 send.text, db.query")
"""能力名称,如 send.text, db.query"""
args: Dict[str, Any] = Field(default_factory=dict, description="调用参数")
"""调用参数"""
class CapabilityResponsePayload(BaseModel):
"""cap.* 响应 payload"""
success: bool = Field(description="是否成功")
"""是否成功"""
result: Any = Field(default=None, description="返回值")
"""返回值"""
# ─── 健康检查 ──────────────────────────────────────────────────────
# ====== 健康检查 ======
class HealthPayload(BaseModel):
"""plugin.health 响应 payload"""
healthy: bool = Field(description="是否健康")
"""是否健康"""
loaded_plugins: List[str] = Field(default_factory=list, description="已加载的插件列表")
uptime_ms: int = Field(default=0, description="运行时长(ms)")
"""已加载的插件列表"""
uptime_ms: int = Field(default=0, description="运行时长 (ms)")
"""运行时长 (ms)"""
class RunnerReadyPayload(BaseModel):
"""runner.ready 请求 payload"""
loaded_plugins: List[str] = Field(default_factory=list, description="已完成初始化的插件列表")
"""已完成初始化的插件列表"""
failed_plugins: List[str] = Field(default_factory=list, description="初始化失败的插件列表")
"""初始化失败的插件列表"""
# ─── 配置更新 ──────────────────────────────────────────────────────
# Host 侧现已支持配置更新推送:
# - 总配置热重载完成后PluginRuntimeManager 会向已加载插件推送配置更新事件。
# - 插件目录下的 config.toml 变化由现有 FileWatcher 监听并转发为 plugin.config_updated。
# ====== 配置更新 ======
class ConfigUpdatedPayload(BaseModel):
"""plugin.config_updated 事件 payload"""
plugin_id: str = Field(description="插件 ID")
"""插件 ID"""
config_scope: ConfigReloadScope = Field(description="配置变更范围")
"""配置变更范围"""
config_version: str = Field(description="新配置版本")
"""新配置版本"""
config_data: Dict[str, Any] = Field(default_factory=dict, description="配置内容")
"""配置内容"""
# ─── 关停 ──────────────────────────────────────────────────────────
# ====== 关停 ======
class ShutdownPayload(BaseModel):
"""plugin.shutdown / plugin.prepare_shutdown payload"""
reason: str = Field(default="normal", description="关停原因")
drain_timeout_ms: int = Field(default=5000, description="排空超时(ms)")
"""关停原因"""
drain_timeout_ms: int = Field(default=5000, description="排空超时 (ms)")
"""排空超时 (ms)"""
# ─── 日志传输 ──────────────────────────────────────────────────────
class UnregisterPluginPayload(BaseModel):
"""插件注销请求载荷。"""
plugin_id: str = Field(description="插件 ID")
"""插件 ID"""
reason: str = Field(default="manual", description="注销原因")
"""注销原因"""
class ReloadPluginPayload(BaseModel):
"""插件重载请求载荷。"""
plugin_id: str = Field(description="目标插件 ID")
"""目标插件 ID"""
reason: str = Field(default="manual", description="重载原因")
"""重载原因"""
external_available_plugins: Dict[str, str] = Field(
default_factory=dict,
description="可视为已满足的外部依赖插件版本映射",
)
"""可视为已满足的外部依赖插件版本映射"""
class ReloadPluginsPayload(BaseModel):
"""批量插件重载请求载荷。"""
plugin_ids: List[str] = Field(default_factory=list, description="目标插件 ID 列表")
"""目标插件 ID 列表"""
reason: str = Field(default="manual", description="重载原因")
"""重载原因"""
external_available_plugins: Dict[str, str] = Field(
default_factory=dict,
description="可视为已满足的外部依赖插件版本映射",
)
"""可视为已满足的外部依赖插件版本映射"""
class ReloadPluginResultPayload(BaseModel):
"""插件重载结果载荷。"""
success: bool = Field(description="是否重载成功")
"""是否重载成功"""
requested_plugin_id: str = Field(description="请求重载的插件 ID")
"""请求重载的插件 ID"""
reloaded_plugins: List[str] = Field(default_factory=list, description="成功完成重载的插件列表")
"""成功完成重载的插件列表"""
unloaded_plugins: List[str] = Field(default_factory=list, description="本次已卸载的插件列表")
"""本次已卸载的插件列表"""
failed_plugins: Dict[str, str] = Field(default_factory=dict, description="重载失败的插件及原因")
"""重载失败的插件及原因"""
class ReloadPluginsResultPayload(BaseModel):
"""批量插件重载结果载荷。"""
success: bool = Field(description="是否重载成功")
"""是否重载成功"""
requested_plugin_ids: List[str] = Field(default_factory=list, description="请求重载的插件 ID 列表")
"""请求重载的插件 ID 列表"""
reloaded_plugins: List[str] = Field(default_factory=list, description="成功完成重载的插件列表")
"""成功完成重载的插件列表"""
unloaded_plugins: List[str] = Field(default_factory=list, description="本次已卸载的插件列表")
"""本次已卸载的插件列表"""
failed_plugins: Dict[str, str] = Field(default_factory=dict, description="重载失败的插件及原因")
"""重载失败的插件及原因"""
class MessageGatewayStateUpdatePayload(BaseModel):
"""消息网关运行时状态更新载荷。"""
gateway_name: str = Field(description="消息网关组件名称")
"""消息网关组件名称"""
ready: bool = Field(description="当前链路是否已经就绪")
"""当前链路是否已经就绪"""
platform: str = Field(default="", description="当前链路负责的平台名称")
"""当前链路负责的平台名称"""
account_id: str = Field(default="", description="当前链路对应的账号 ID 或 self_id")
"""当前链路对应的账号 ID 或 self_id"""
scope: str = Field(default="", description="当前链路对应的可选路由作用域")
"""当前链路对应的可选路由作用域"""
metadata: Dict[str, Any] = Field(default_factory=dict, description="可选的运行时状态元数据")
"""可选的运行时状态元数据"""
class MessageGatewayStateUpdateResultPayload(BaseModel):
"""消息网关运行时状态更新结果载荷。"""
accepted: bool = Field(description="Host 是否接受了本次状态更新")
"""Host 是否接受了本次状态更新"""
ready: bool = Field(description="Host 记录的当前就绪状态")
"""Host 记录的当前就绪状态"""
route_key: Dict[str, Any] = Field(default_factory=dict, description="当前生效的路由键")
"""当前生效的路由键"""
class RouteMessagePayload(BaseModel):
"""消息网关向 Host 路由外部消息的请求载荷。"""
gateway_name: str = Field(description="接收消息的网关组件名称")
"""接收消息的网关组件名称"""
message: Dict[str, Any] = Field(description="符合 MessageDict 结构的标准消息字典")
"""符合 MessageDict 结构的标准消息字典"""
route_metadata: Dict[str, Any] = Field(default_factory=dict, description="可选的路由辅助元数据")
"""可选的路由辅助元数据"""
external_message_id: str = Field(default="", description="可选的外部平台消息 ID")
"""可选的外部平台消息 ID"""
dedupe_key: str = Field(default="", description="可选的显式去重键")
"""可选的显式去重键"""
class ReceiveExternalMessageResultPayload(BaseModel):
"""外部消息注入结果载荷。"""
accepted: bool = Field(description="Host 是否接受了本次消息注入")
"""Host 是否接受了本次消息注入"""
route_key: Dict[str, Any] = Field(default_factory=dict, description="本次消息使用的归一路由键")
"""本次消息使用的归一路由键"""
RegisterPluginPayload.model_rebuild()
# ====== 日志传输 ======
class LogEntry(BaseModel):
"""单条日志记录Runner → Host 传输格式)"""
timestamp_ms: int = Field(
description="日志时间戳Unix epoch 毫秒",
)
level: int = Field(
description=("stdlib logging 整数级别: 10=DEBUG, 20=INFO, 30=WARNING, 40=ERROR, 50=CRITICAL"),
)
logger_name: str = Field(
description="Logger 名称,如 plugin.my_plugin.submodule",
)
message: str = Field(
description="经 Formatter 格式化后的完整日志消息(含 exc_info 文本)",
)
timestamp_ms: int = Field(description="日志时间戳Unix epoch 毫秒")
"""日志时间戳Unix epoch 毫秒"""
level: int = Field(description="stdlib logging 整数级别10=DEBUG, 20=INFO, 30=WARNING, 40=ERROR, 50=CRITICAL")
"""stdlib logging 整数级别10=DEBUG, 20=INFO, 30=WARNING, 40=ERROR, 50=CRITICAL"""
logger_name: str = Field(description="Logger 名称,如 plugin.my_plugin.submodule")
"""Logger 名称,如 plugin.my_plugin.submodule"""
message: str = Field(description="经 Formatter 格式化后的完整日志消息(含 exc_info 文本)")
"""经 Formatter 格式化后的完整日志消息(含 exc_info 文本)"""
exception_text: str = Field(
default="",
description="原始异常摘要exc_text供结构化消费已嵌入 message 中",
)
"""原始异常摘要exc_text供结构化消费已嵌入 message 中"""
log_color_in_hex: Optional[str] = Field(default=None, description="日志颜色的十六进制字符串(如 #RRGGBB")
@property
def levelname(self) -> str:
@@ -262,6 +418,5 @@ class LogEntry(BaseModel):
class LogBatchPayload(BaseModel):
"""runner.log_batch 事件 payloadRunner 端向 Host 批量推送日志记录"""
entries: List[LogEntry] = Field(
description="本批次日志记录列表,按时间升序排列",
)
entries: List[LogEntry] = Field(description="本批次日志记录列表,按时间升序排列")
"""本批次日志记录列表,按时间升序排列"""

Some files were not shown because too many files have changed in this diff Show More