diff --git a/AGENTS.md b/AGENTS.md index b4caaaf1..b3456610 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -17,6 +17,7 @@ 1. 尽量保持良好的注释 2. 如果原来的代码中有注释,则重构的时候,除非这部分代码被删除,否则相同功能的代码应该保留注释(可以对注释进行修改以保持准确性,但不应该删除注释)。 3. 如果原来的代码中没有注释,则重构的时候,如果某个功能块的代码较长或者逻辑较为复杂,则应该添加注释来解释这部分代码的功能和逻辑。 +4. 对于类,方法以及模块的注释,首选使用的注释格式为 Google DocStr 格式,但保证语言为简体中文 ## 类型注解规范 1. 重构代码时,如果原来的代码中有类型注解,则相同功能的代码应该保留类型注解(可以对类型注解进行修改以保持准确性,但不应该删除类型注解)。 2. 重构代码时,如果原来的代码中没有类型注解,则重构的时候,如果某个函数的功能较为复杂或者参数较多,则应该添加类型注解来提高代码的可读性和可维护性。(对于简单的变量,可以不添加类型注解) @@ -35,3 +36,7 @@ # 运行/调试/构建/测试/依赖 优先使用uv 依赖项以 pyproject.toml 为准 + +# 语言规范 + +项目的首选语言为简体中文,无论是注释语言,日志展示语言,还是 WebUI 展示语言都应该首要以简体中文为首要实现目标 diff --git a/README.md b/README.md index 7c41c8cf..3f3851b7 100644 --- a/README.md +++ b/README.md @@ -1,21 +1,21 @@
- 简体中文 | English + 简体中文 | English

-

麦麦 MaiBot MaiSaka

+

MaiBot MaiSaka

Python Version - License - Status - Contributors - Forks - Stars + License + Status + Contributors + Forks + Stars Ask DeepWiki

@@ -25,31 +25,24 @@ MaiBot Character -## 介绍 +## Introduction -麦麦MaiSaka 是一个基于大语言模型的可交互智能体 +MaiSaka is an interactive agent based on large language models. -MaiSaka 不仅仅是一个机器人,不仅仅是一个可以帮你完成任务的“有帮助的助手”,她还是一个致力于了解你,并以真实人类的风格进行交互的数字生命,她不追求完美,她不追求高效,但追求亲切和真实。 +MaiSaka is more than just a bot, and more than a "helpful assistant" that completes tasks. She is a digital life form that tries to understand you and interact in a genuinely human style. She does not pursue perfection or efficiency above all else. She pursues warmth and authenticity. +- 💭 **No one likes GPT-sounding dialogue**: MaiSaka uses a more natural conversational style. Instead of long-winded markdown-heavy replies, she chats in a way that feels casual, varied, and human. +- 🎭 **No longer stuck in rigid Q&A**: She knows when to speak, how to read the room, when to join a conversation, and when to stay quiet. +- 🧠 **MaiSaka becoming human**: In group conversations, MaiSaka imitates how people around her speak, learns new slang and in-group language, and keeps evolving. +- ❤️ **Always learning more about you**: Inspired by personality theory in psychology, MaiSaka gradually builds an understanding of your preferences, traits, habits, and behavior style. +- 🔌 **Plugin system**: Provides powerful APIs and an event system with virtually unlimited room for extension. -- 💭 **没有人喜欢GPT的语言风格**:麦麦使用了更加自然,贴合人类对话习惯的交互方式,不是长篇大论或者markdown格式的分点,而是或长或短的闲谈。 - -- 🎭 **不再是傻乎乎的一问一答**:懂得在合适的时间说话,把握聊天中的气氛,在合适的时候开口,在合适的时候闭嘴。 - -- 🧠 **麦麦·成为人类**:在多人对话中,麦麦会模仿其他人的的说话风格,还会自主理解新词或者小圈子里的黑话,不断进化。 - -- ❤️ **永远都在更加了解你**:基于心理学中人格理论,麦麦会不断积累对于你的了解,不论是你的信息,喜恶或是行为风格,她都记在心里。 - -- 🔌 **插件系统**:提供强大的 API 和事件系统,无限扩展可能。 - - - -### 快速导航 +### Quick Navigation

- 🌟 演示视频  |  - 📦 快速入门  |  - 📃 核心文档  |  - 💬 加入社区 + 🌟 Demo Video  |  + 📦 Quick Start  |  + 📃 Core Documentation  |  + 💬 Join Community

@@ -60,103 +53,103 @@ MaiSaka 不仅仅是一个机器人,不仅仅是一个可以帮你完成任务 - 麦麦演示视频 + MaiSaka Demo Video
- 前往观看麦麦演示视频 + Watch the MaiSaka demo video
--- -## 🔥 更新和安装 +## 🔥 Updates and Installation -> **最新版本: v1.0.0** ([📄 更新日志](changelogs/changelog.md)) +> **Latest Version: v1.0.0** ([📄 Changelog](changelogs/changelog.md)) -- **下载**: 前往 [Release](https://github.com/MaiM-with-u/MaiBot/releases/) 页面下载最新版本 -- **启动器**: [Mailauncher](https://github.com/MaiM-with-u/mailauncher/releases/) (仅支持 MacOS, 早期开发中) +- **Download**: Visit the [Release](https://github.com/MaiM-with-u/MaiBot/releases/) page to get the latest version. +- **Launcher**: [Mailauncher](https://github.com/MaiM-with-u/mailauncher/releases/) (MacOS only, still in early development). -| 分支 | 说明 | +| Branch | Description | | :--- | :--- | -| `main` | ✅ **稳定发布版本 (推荐)** | -| `dev` | 🚧 开发测试版本,包含新功能,可能不稳定 | +| `main` | ✅ **Stable release (recommended)** | +| `dev` | 🚧 Development testing branch with new features, may be unstable | -### 📚 部署教程 -👉 **[🚀 最新版本部署教程](https://docs.mai-mai.org/manual/deployment/mmc_deploy_windows.html)** +### 📚 Deployment Guide +👉 **[🚀 Latest Deployment Guide](https://docs.mai-mai.org/manual/deployment/mmc_deploy_windows.html)** --- -## 💬 讨论与社区 +## 💬 Discussion and Community -我们欢迎所有对 MaiBot 感兴趣的朋友加入! +We welcome everyone interested in MaiBot to join us. -| 类别 | 群组 | 说明 | +| Category | Group | Description | | :--- | :--- | :--- | -| **技术交流** | [麦麦脑电图](https://qm.qq.com/q/RzmCiRtHEW) | 技术交流/答疑 | -| **技术交流** | [麦麦大脑磁共振](https://qm.qq.com/q/VQ3XZrWgMs) | 技术交流/答疑 | -| **技术交流** | [麦麦要当VTB](https://qm.qq.com/q/wGePTl1UyY) | 技术交流/答疑 | -| **闲聊吹水** | [麦麦之闲聊群](https://qm.qq.com/q/JxvHZnxyec) | 仅限闲聊,不答疑 | -| **插件开发** | [插件开发群](https://qm.qq.com/q/1036092828) | 进阶开发与测试 | +| **Technical** | [MaiBrain EEG](https://qm.qq.com/q/RzmCiRtHEW) | Technical discussion / Q&A | +| **Technical** | [MaiBrain MRI](https://qm.qq.com/q/VQ3XZrWgMs) | Technical discussion / Q&A | +| **Technical** | [Mai Wants to Be a VTuber](https://qm.qq.com/q/wGePTl1UyY) | Technical discussion / Q&A | +| **Casual Chat** | [Mai Casual Chat Group](https://qm.qq.com/q/JxvHZnxyec) | Casual chat only, no support | +| **Plugin Development** | [Plugin Dev Group](https://qm.qq.com/q/1036092828) | Advanced development and testing | --- -## 📚 文档 +## 📚 Documentation > [!NOTE] -> 部分内容可能更新不够及时,请注意版本对应。 +> Some content may not be updated promptly, so please pay attention to version compatibility. -- **[📚 核心 Wiki 文档](https://docs.mai-mai.org)**: 最全面的文档中心,了解麦麦的一切。 +- **[📚 Core Wiki Documentation](https://docs.mai-mai.org)**: The most comprehensive documentation hub for everything about MaiSaka. -### 🧩 衍生项目 +### 🧩 Related Projects -- **[Amaidesu](https://github.com/MaiM-with-u/Amaidesu)**: 让麦麦在B站开播 -- **[MoFox_Bot](https://github.com/MoFox-Studio/MoFox-Core)**: 基于 MaiCore 0.10.0 的增强型 Fork,更稳定更有趣。 -- **[MaiCraft](https://github.com/MaiM-with-u/Maicraft)**: 让麦麦陪你玩 Minecraft (暂时停止维护中)。 +- **[Amaidesu](https://github.com/MaiM-with-u/Amaidesu)**: Let MaiSaka stream on Bilibili. +- **[MoFox_Bot](https://github.com/MoFox-Studio/MoFox-Core)**: An enhanced fork based on MaiCore 0.10.0, with improved stability and more fun features. +- **[MaiCraft](https://github.com/MaiM-with-u/Maicraft)**: Let MaiSaka accompany you in Minecraft (currently paused). --- -## 💡 设计理念 +## 💡 Design Philosophy -> **千石可乐说:** -> - 这个项目最初只是为了给牛牛 bot 添加一点额外的功能,但是功能越写越多,最后决定重写。其目的是为了创造一个活跃在 QQ 群聊的"生命体"。目的并不是为了写一个功能齐全的机器人,而是一个尽可能让人感知到真实的类人存在。 -> - 程序的功能设计理念基于一个核心的原则:"最像而不是好"。 -> - 如果人类真的需要一个 AI 来陪伴自己,并不是所有人都需要一个完美的,能解决所有问题的"helpful assistant",而是一个会犯错的,拥有自己感知和想法的"生命形式"。 +> **SengokuCola says:** +> - This project originally started as a few extra features for the NiuNiu bot, but it kept growing until a full rewrite became inevitable. The goal was to create a "life form" active in QQ group chats, not a feature-complete bot, but something as human-like and real-feeling as possible. +> - The core design principle is: "more lifelike, not merely better." +> - If people truly want AI companionship, not everyone needs a perfect "helpful assistant" that solves every problem. Some people may want a life form that can make mistakes and has its own perceptions and thoughts. -> **xxxxx说:** +> **xxxxx says:** > *Code is open, but the soul is yours.* --- -## 🙋 贡献和致谢 +## 🙋 Contributing and Acknowledgments -欢迎参与贡献!请先阅读 [贡献指南](docs-src/CONTRIBUTE.md)。 +Contributions are welcome. Please read the [Contribution Guide](docs-src/CONTRIBUTE.md) first. -### 🌟 贡献者 +### 🌟 Contributors contributors -### ❤️ 特别致谢 +### ❤️ Special Thanks -- **[萨卡班甲鱼](https://en.wikipedia.org/wiki/Sacabambaspis)**: 千石可乐很喜欢的生物。 -- **[略nd](https://space.bilibili.com/1344099355)**: 🎨 为麦麦绘制早期的精美人设。 -- **[NapCat](https://github.com/NapNeko/NapCatQQ)**: 🚀 现代化的基于 NTQQ 的 Bot 协议实现。 +- **[Sacabambaspis](https://en.wikipedia.org/wiki/Sacabambaspis)**: SengokuCola's favorite creature. +- **[略nd](https://space.bilibili.com/1344099355)**: Drew MaiSaka's beautiful early character design. +- **[NapCat](https://github.com/NapNeko/NapCatQQ)**: A modern NTQQ-based bot protocol implementation. --- -## 📊 仓库状态 +## 📊 Repository Status -![Alt](https://repobeats.axiom.co/api/embed/9faca9fccfc467931b87dd357b60c6362b5cfae0.svg "麦麦仓库状态") +![Alt](https://repobeats.axiom.co/api/embed/9faca9fccfc467931b87dd357b60c6362b5cfae0.svg "MaiBot Repository Status") -### Star 趋势 -[![Star 趋势](https://starchart.cc/MaiM-with-u/MaiBot.svg?variant=adaptive)](https://starchart.cc/MaiM-with-u/MaiBot) +### Star History +[![Star History](https://starchart.cc/MaiM-with-u/MaiBot.svg?variant=adaptive)](https://starchart.cc/MaiM-with-u/MaiBot) --- -## 📌 注意事项 & License +## 📌 Notice & License > [!IMPORTANT] -> 使用前请阅读 [用户协议 (EULA)](EULA.md) 和 [隐私协议](PRIVACY.md)。AI 生成内容请仔细甄别。 +> Please read the [End User License Agreement (EULA)](EULA.md) and [Privacy Policy](PRIVACY.md) before use. Please evaluate AI-generated content carefully. **License**: GPL-3.0 diff --git a/docs/README_CN.md b/docs/README_CN.md new file mode 100644 index 00000000..9532cd00 --- /dev/null +++ b/docs/README_CN.md @@ -0,0 +1,162 @@ +
+ + + 简体中文 | English + +
+
+ +

麦麦 MaiBot MaiSaka

+ + +

+ Python Version + License + Status + Contributors + Forks + Stars + Ask DeepWiki +

+
+ +
+ + +MaiBot Character + +## 介绍 + +麦麦MaiSaka 是一个基于大语言模型的可交互智能体 + +MaiSaka 不仅仅是一个机器人,不仅仅是一个可以帮你完成任务的“有帮助的助手”,她还是一个致力于了解你,并以真实人类的风格进行交互的数字生命,她不追求完美,她不追求高效,但追求亲切和真实。 + + +- 💭 **没有人喜欢GPT的语言风格**:麦麦使用了更加自然,贴合人类对话习惯的交互方式,不是长篇大论或者markdown格式的分点,而是或长或短的闲谈。 + +- 🎭 **不再是傻乎乎的一问一答**:懂得在合适的时间说话,把握聊天中的气氛,在合适的时候开口,在合适的时候闭嘴。 + +- 🧠 **麦麦·成为人类**:在多人对话中,麦麦会模仿其他人的的说话风格,还会自主理解新词或者小圈子里的黑话,不断进化。 + +- ❤️ **永远都在更加了解你**:基于心理学中人格理论,麦麦会不断积累对于你的了解,不论是你的信息,喜恶或是行为风格,她都记在心里。 + +- 🔌 **插件系统**:提供强大的 API 和事件系统,无限扩展可能。 + + + +### 快速导航 +

+ 🌟 演示视频  |  + 📦 快速入门  |  + 📃 核心文档  |  + 💬 加入社区 +

+ + +
+ +
+
+ + + + 麦麦演示视频 + +
+ 前往观看麦麦演示视频 +
+
+ +--- + +## 🔥 更新和安装 + +> **最新版本: v1.0.0** ([📄 更新日志](../changelogs/changelog.md)) + +- **下载**: 前往 [Release](https://github.com/MaiM-with-u/MaiBot/releases/) 页面下载最新版本 +- **启动器**: [Mailauncher](https://github.com/MaiM-with-u/mailauncher/releases/) (仅支持 MacOS, 早期开发中) + +| 分支 | 说明 | +| :--- | :--- | +| `main` | ✅ **稳定发布版本 (推荐)** | +| `dev` | 🚧 开发测试版本,包含新功能,可能不稳定 | + +### 📚 部署教程 +👉 **[🚀 最新版本部署教程](https://docs.mai-mai.org/manual/deployment/mmc_deploy_windows.html)** + +--- + +## 💬 讨论与社区 + +我们欢迎所有对 MaiBot 感兴趣的朋友加入! + +| 类别 | 群组 | 说明 | +| :--- | :--- | :--- | +| **技术交流** | [麦麦脑电图](https://qm.qq.com/q/RzmCiRtHEW) | 技术交流/答疑 | +| **技术交流** | [麦麦大脑磁共振](https://qm.qq.com/q/VQ3XZrWgMs) | 技术交流/答疑 | +| **技术交流** | [麦麦要当VTB](https://qm.qq.com/q/wGePTl1UyY) | 技术交流/答疑 | +| **闲聊吹水** | [麦麦之闲聊群](https://qm.qq.com/q/JxvHZnxyec) | 仅限闲聊,不答疑 | +| **插件开发** | [插件开发群](https://qm.qq.com/q/1036092828) | 进阶开发与测试 | + +--- + +## 📚 文档 + +> [!NOTE] +> 部分内容可能更新不够及时,请注意版本对应。 + +- **[📚 核心 Wiki 文档](https://docs.mai-mai.org)**: 最全面的文档中心,了解麦麦的一切。 + +### 🧩 衍生项目 + +- **[Amaidesu](https://github.com/MaiM-with-u/Amaidesu)**: 让麦麦在B站开播 +- **[MoFox_Bot](https://github.com/MoFox-Studio/MoFox-Core)**: 基于 MaiCore 0.10.0 的增强型 Fork,更稳定更有趣。 +- **[MaiCraft](https://github.com/MaiM-with-u/Maicraft)**: 让麦麦陪你玩 Minecraft (暂时停止维护中)。 + +--- + +## 💡 设计理念 + +> **千石可乐说:** +> - 这个项目最初只是为了给牛牛 bot 添加一点额外的功能,但是功能越写越多,最后决定重写。其目的是为了创造一个活跃在 QQ 群聊的"生命体"。目的并不是为了写一个功能齐全的机器人,而是一个尽可能让人感知到真实的类人存在。 +> - 程序的功能设计理念基于一个核心的原则:"最像而不是好"。 +> - 如果人类真的需要一个 AI 来陪伴自己,并不是所有人都需要一个完美的,能解决所有问题的"helpful assistant",而是一个会犯错的,拥有自己感知和想法的"生命形式"。 + +> **xxxxx说:** +> *Code is open, but the soul is yours.* + +--- + +## 🙋 贡献和致谢 + +欢迎参与贡献!请先阅读 [贡献指南](../docs-src/CONTRIBUTE.md)。 + +### 🌟 贡献者 + + + contributors + + +### ❤️ 特别致谢 + +- **[萨卡班甲鱼](https://en.wikipedia.org/wiki/Sacabambaspis)**: 千石可乐很喜欢的生物。 +- **[略nd](https://space.bilibili.com/1344099355)**: 🎨 为麦麦绘制早期的精美人设。 +- **[NapCat](https://github.com/NapNeko/NapCatQQ)**: 🚀 现代化的基于 NTQQ 的 Bot 协议实现。 + +--- + +## 📊 仓库状态 + +![Alt](https://repobeats.axiom.co/api/embed/9faca9fccfc467931b87dd357b60c6362b5cfae0.svg "麦麦仓库状态") + +### Star 趋势 +[![Star 趋势](https://starchart.cc/MaiM-with-u/MaiBot.svg?variant=adaptive)](https://starchart.cc/MaiM-with-u/MaiBot) + +--- + +## 📌 注意事项 & License + +> [!IMPORTANT] +> 使用前请阅读 [用户协议 (EULA)](../EULA.md) 和 [隐私协议](../PRIVACY.md)。AI 生成内容请仔细甄别。 + +**License**: GPL-3.0 diff --git a/docs/README_EN.md b/docs/README_EN.md index f37002fa..9ecc53b7 100644 --- a/docs/README_EN.md +++ b/docs/README_EN.md @@ -1,7 +1,7 @@
- 简体中文 | English + 简体中文 | English

diff --git a/docs/minimal-cross-platform-plan.md b/docs/minimal-cross-platform-plan.md index d0b6707b..2f0a86bd 100644 --- a/docs/minimal-cross-platform-plan.md +++ b/docs/minimal-cross-platform-plan.md @@ -41,7 +41,7 @@ This plan is based on the checked-in code, not on assumptions from previous draf | `src/person_info/person_info.py:247` | `_is_bot_self(self, platform, user_id)` | Duplicate logic with same QQ fallback | Wrong-order call sites (8 total): -- `src/bw_learner/expression_learner.py` x3 (lines 158, 241, 301) +- `src/learners/expression_learner.py` x3 (lines 158, 241, 301) - `src/common/utils/utils_message.py` x4 (lines 370, 440, 476, 515) - `src/webui/routers/chat/support.py` x1 (line 65) @@ -122,7 +122,7 @@ Make `src/chat/utils/utils.py::is_bot_self(platform, user_id)` the only real imp - `src/common/utils/system_utils.py` - `src/chat/utils/utils.py` - `src/person_info/person_info.py` -- `src/bw_learner/expression_learner.py` +- `src/learners/expression_learner.py` - `src/common/utils/utils_message.py` - `src/webui/routers/chat/support.py` - tests @@ -468,7 +468,7 @@ When stopping, name: the exact file(s), the blocking mismatch, why it is outside | Phase | Allowed files | |-------|---------------| -| Phase 0 | `src/common/utils/system_utils.py`, `src/chat/utils/utils.py`, `src/person_info/person_info.py`, `src/bw_learner/expression_learner.py`, `src/common/utils/utils_message.py`, `src/webui/routers/chat/support.py`, tests (including `pytests/utils_test/message_utils_test.py`) | +| Phase 0 | `src/common/utils/system_utils.py`, `src/chat/utils/utils.py`, `src/person_info/person_info.py`, `src/learners/expression_learner.py`, `src/common/utils/utils_message.py`, `src/webui/routers/chat/support.py`, tests (including `pytests/utils_test/message_utils_test.py`) | | Phase 1 | `src/chat/utils/utils.py`, `src/chat/planner_actions/planner.py`, `src/chat/utils/statistic.py`, `src/common/message_repository.py`, `src/webui/routers/chat/support.py`, `src/services/send_service.py`, `src/chat/replyer/group_generator.py`, `src/chat/replyer/private_generator.py`, `src/chat/brain_chat/PFC/message_sender.py`, `src/person_info/person_info.py`, tests | ### INVALID OUTPUT EXAMPLES diff --git a/plugins/ChatFrequency/_manifest.json b/plugins/ChatFrequency/_manifest.json index 241242ed..56417665 100644 --- a/plugins/ChatFrequency/_manifest.json +++ b/plugins/ChatFrequency/_manifest.json @@ -1,58 +1,40 @@ { - "manifest_version": 1, - "name": "发言频率控制插件|BetterFrequency Plugin", + "manifest_version": 2, "version": "2.0.0", - "description": "控制聊天频率,支持设置focus_value和talk_frequency调整值,提供命令", + "name": "发言频率控制插件|BetterFrequency Plugin", + "description": "控制聊天频率,支持设置 focus_value 和 talk_frequency 调整值,并提供命令入口。", "author": { "name": "SengokuCola", "url": "https://github.com/MaiM-with-u" }, "license": "GPL-v3.0-or-later", - "host_application": { - "min_version": "1.0.0" + "urls": { + "repository": "https://github.com/SengokuCola/BetterFrequency", + "homepage": "https://github.com/SengokuCola/BetterFrequency", + "documentation": "https://github.com/SengokuCola/BetterFrequency", + "issues": "https://github.com/SengokuCola/BetterFrequency/issues" }, - "homepage_url": "https://github.com/SengokuCola/BetterFrequency", - "repository_url": "https://github.com/SengokuCola/BetterFrequency", - "keywords": [ - "frequency", - "control", - "talk_frequency", - "plugin", - "shortcut" + "host_application": { + "min_version": "1.0.0", + "max_version": "1.0.0" + }, + "sdk": { + "min_version": "2.0.0", + "max_version": "2.99.99" + }, + "dependencies": [], + "capabilities": [ + "send.text", + "frequency.set_adjust", + "frequency.get_current_talk_value", + "frequency.get_adjust" ], - "categories": [ - "Chat", - "Frequency", - "Control" - ], - "default_locale": "zh-CN", - "locales_path": "_locales", - "plugin_info": { - "is_built_in": false, - "plugin_type": "frequency", - "components": [ - { - "type": "command", - "name": "set_talk_frequency", - "description": "设置当前聊天的talk_frequency调整值", - "pattern": "/chat talk_frequency <数字> 或 /chat t <数字>" - }, - { - "type": "command", - "name": "show_frequency", - "description": "显示当前聊天的频率控制状态", - "pattern": "/chat show 或 /chat s" - } - ], - "features": [ - "设置talk_frequency调整值", - "调整当前聊天的发言频率", - "显示当前频率控制状态", - "实时频率控制调整", - "命令执行反馈(不保存消息)", - "支持完整命令和简化命令", - "快速操作支持" + "i18n": { + "default_locale": "zh-CN", + "locales_path": "_locales", + "supported_locales": [ + "zh-CN" ] }, - "id": "SengokuCola.BetterFrequency" -} \ No newline at end of file + "id": "sengokucola.betterfrequency" +} diff --git a/plugins/ChatFrequency/plugin.py b/plugins/ChatFrequency/plugin.py index b3f69384..0e9f5a0c 100644 --- a/plugins/ChatFrequency/plugin.py +++ b/plugins/ChatFrequency/plugin.py @@ -3,12 +3,18 @@ 通过 /chat 命令设置和查看聊天频率。 """ -from maibot_sdk import MaiBotPlugin, Command +from maibot_sdk import Command, MaiBotPlugin class BetterFrequencyPlugin(MaiBotPlugin): """聊天频率控制插件""" + async def on_load(self) -> None: + """处理插件加载。""" + + async def on_unload(self) -> None: + """处理插件卸载。""" + @Command( "set_talk_frequency", description="设置当前聊天的talk_frequency值:/chat talk_frequency <数字> 或 /chat t <数字>", @@ -80,6 +86,25 @@ class BetterFrequencyPlugin(MaiBotPlugin): await self.ctx.send.text(status_msg, stream_id) return True, None, False + async def on_config_update(self, scope: str, config_data: dict[str, object], version: str) -> None: + """处理配置热重载事件。 + + Args: + scope: 配置变更范围。 + config_data: 最新配置数据。 + version: 配置版本号。 + """ + + del scope + del config_data + del version + + +def create_plugin() -> BetterFrequencyPlugin: + """创建聊天频率插件实例。 + + Returns: + BetterFrequencyPlugin: 新的聊天频率插件实例。 + """ -def create_plugin(): return BetterFrequencyPlugin() diff --git a/plugins/MaiBot_MCPBridgePlugin/_manifest.json b/plugins/MaiBot_MCPBridgePlugin/_manifest.json index 85225a43..d2e08ab4 100644 --- a/plugins/MaiBot_MCPBridgePlugin/_manifest.json +++ b/plugins/MaiBot_MCPBridgePlugin/_manifest.json @@ -1,67 +1,42 @@ { - "manifest_version": 1, - "name": "MCP桥接插件", + "manifest_version": 2, "version": "2.0.0", - "description": "将 MCP (Model Context Protocol) 服务器的工具桥接到 MaiBot,使麦麦能够调用外部 MCP 工具", + "name": "MCP桥接插件", + "description": "将 MCP (Model Context Protocol) 服务器的工具桥接到 MaiBot,使麦麦能够调用外部 MCP 工具。", "author": { "name": "CharTyr", "url": "https://github.com/CharTyr" }, "license": "AGPL-3.0", - "host_application": { - "min_version": "0.11.6" + "urls": { + "repository": "https://github.com/CharTyr/MaiBot_MCPBridgePlugin", + "homepage": "https://github.com/CharTyr/MaiBot_MCPBridgePlugin", + "documentation": "https://github.com/CharTyr/MaiBot_MCPBridgePlugin", + "issues": "https://github.com/CharTyr/MaiBot_MCPBridgePlugin/issues" }, - "homepage_url": "https://github.com/CharTyr/MaiBot_MCPBridgePlugin", - "repository_url": "https://github.com/CharTyr/MaiBot_MCPBridgePlugin", - "keywords": [ - "mcp", - "bridge", - "tool", - "integration", - "resources", - "prompts", - "post-process", - "cache", - "trace", - "permissions", - "import", - "export", - "claude-desktop", - "workflow", - "react", - "agent" + "host_application": { + "min_version": "0.11.6", + "max_version": "1.0.0" + }, + "sdk": { + "min_version": "2.0.0", + "max_version": "2.99.99" + }, + "dependencies": [ + { + "type": "python_package", + "name": "mcp", + "version_spec": ">=0.0.0" + } ], - "categories": [ - "工具扩展", - "外部集成" + "capabilities": [ + "send.text" ], - "default_locale": "zh-CN", - "plugin_info": { - "is_built_in": false, - "components": [], - "features": [ - "支持多个 MCP 服务器", - "自动发现并注册 MCP 工具", - "支持 stdio、SSE、HTTP、Streamable HTTP 四种传输方式", - "工具参数自动转换", - "心跳检测与自动重连", - "调用统计(次数、成功率、耗时)", - "WebUI 配置支持", - "Resources 支持(实验性)", - "Prompts 支持(实验性)", - "结果后处理(LLM 摘要提炼)", - "工具禁用管理", - "调用链路追踪", - "工具调用缓存(LRU)", - "工具权限控制(群/用户级别)", - "配置导入导出(Claude Desktop mcpServers)", - "断路器模式(故障快速失败)", - "状态实时刷新", - "Workflow 硬流程(顺序执行多个工具)", - "Workflow 快速添加(表单式配置)", - "ReAct 软流程(LLM 自主多轮调用)", - "双轨制架构(软流程 + 硬流程)" + "i18n": { + "default_locale": "zh-CN", + "supported_locales": [ + "zh-CN" ] }, - "id": "MaiBot Community.MCPBridgePlugin" + "id": "chartyr.mcpbridge-plugin" } diff --git a/plugins/emoji_manage_plugin/_manifest.json b/plugins/emoji_manage_plugin/_manifest.json index 3af69023..998cb7da 100644 --- a/plugins/emoji_manage_plugin/_manifest.json +++ b/plugins/emoji_manage_plugin/_manifest.json @@ -1,68 +1,44 @@ { - "manifest_version": 1, - "name": "BetterEmoji", + "manifest_version": 2, "version": "2.0.0", + "name": "BetterEmoji", "description": "更好的表情包管理插件", "author": { "name": "SengokuCola", "url": "https://github.com/SengokuCola" }, "license": "GPL-v3.0-or-later", - "host_application": { - "min_version": "1.0.0" + "urls": { + "repository": "https://github.com/SengokuCola/BetterEmoji", + "homepage": "https://github.com/SengokuCola/BetterEmoji", + "documentation": "https://github.com/SengokuCola/BetterEmoji", + "issues": "https://github.com/SengokuCola/BetterEmoji/issues" }, - "homepage_url": "https://github.com/SengokuCola/BetterEmoji", - "repository_url": "https://github.com/SengokuCola/BetterEmoji", - "keywords": [ - "emoji", - "manage", - "plugin" + "host_application": { + "min_version": "1.0.0", + "max_version": "1.0.0" + }, + "sdk": { + "min_version": "2.0.0", + "max_version": "2.99.99" + }, + "dependencies": [], + "capabilities": [ + "emoji.get_random", + "emoji.get_count", + "emoji.get_info", + "emoji.get_all", + "emoji.register_emoji", + "emoji.delete_emoji", + "send.text", + "send.forward" ], - "categories": [ - "Emoji", - "Management" - ], - "default_locale": "zh-CN", - "locales_path": "_locales", - "plugin_info": { - "is_built_in": false, - "plugin_type": "emoji_manage", - "capabilities": [ - "emoji.get_random", - "emoji.get_count", - "emoji.get_info", - "emoji.get_all", - "emoji.register_emoji", - "emoji.delete_emoji", - "send.text", - "send.forward" - ], - "components": [ - { - "type": "command", - "name": "add_emoji", - "description": "添加表情包", - "pattern": "/emoji add" - }, - { - "type": "command", - "name": "emoji_list", - "description": "列表表情包", - "pattern": "/emoji list" - }, - { - "type": "command", - "name": "delete_emoji", - "description": "删除表情包", - "pattern": "/emoji delete" - }, - { - "type": "command", - "name": "random_emojis", - "description": "发送多张随机表情包", - "pattern": "/random_emojis" - } + "i18n": { + "default_locale": "zh-CN", + "locales_path": "_locales", + "supported_locales": [ + "zh-CN" ] }, - "id": "SengokuCola.BetterEmoji" -} \ No newline at end of file + "id": "sengokucola.betteremoji" +} diff --git a/plugins/emoji_manage_plugin/plugin.py b/plugins/emoji_manage_plugin/plugin.py index f3c5f677..9362c828 100644 --- a/plugins/emoji_manage_plugin/plugin.py +++ b/plugins/emoji_manage_plugin/plugin.py @@ -3,17 +3,23 @@ 通过 /emoji 命令管理表情包的添加、列表和删除。 """ +from maibot_sdk import Command, MaiBotPlugin + import base64 import datetime import hashlib import re -from maibot_sdk import MaiBotPlugin, Command - class EmojiManagePlugin(MaiBotPlugin): """表情包管理插件""" + async def on_load(self) -> None: + """处理插件加载。""" + + async def on_unload(self) -> None: + """处理插件卸载。""" + # ===== 工具方法 ===== @staticmethod @@ -208,6 +214,25 @@ class EmojiManagePlugin(MaiBotPlugin): await self.ctx.send.forward(messages, stream_id) return True, "已发送随机表情包", True + async def on_config_update(self, scope: str, config_data: dict[str, object], version: str) -> None: + """处理配置热重载事件。 + + Args: + scope: 配置变更范围。 + config_data: 最新配置数据。 + version: 配置版本号。 + """ + + del scope + del config_data + del version + + +def create_plugin() -> EmojiManagePlugin: + """创建表情包管理插件实例。 + + Returns: + EmojiManagePlugin: 新的表情包管理插件实例。 + """ -def create_plugin(): return EmojiManagePlugin() diff --git a/plugins/hello_world_plugin/_manifest.json b/plugins/hello_world_plugin/_manifest.json index dc9fc474..e2bc694d 100644 --- a/plugins/hello_world_plugin/_manifest.json +++ b/plugins/hello_world_plugin/_manifest.json @@ -1,88 +1,41 @@ { - "manifest_version": 1, - "name": "Hello World 示例插件 (Hello World Plugin)", + "manifest_version": 2, "version": "2.0.0", - "description": "我的第一个MaiCore插件,包含问候功能和时间查询等基础示例", + "name": "Hello World 示例插件 (Hello World Plugin)", + "description": "我的第一个 MaiCore 插件,包含问候功能和时间查询等基础示例", "author": { "name": "MaiBot开发团队", "url": "https://github.com/MaiM-with-u" }, "license": "GPL-v3.0-or-later", - "host_application": { - "min_version": "1.0.0" + "urls": { + "repository": "https://github.com/MaiM-with-u/maibot", + "homepage": "https://github.com/MaiM-with-u/maibot", + "documentation": "https://github.com/MaiM-with-u/maibot", + "issues": "https://github.com/MaiM-with-u/maibot/issues" }, - "homepage_url": "https://github.com/MaiM-with-u/maibot", - "repository_url": "https://github.com/MaiM-with-u/maibot", - "keywords": [ - "demo", - "example", - "hello", - "greeting", - "tutorial" + "host_application": { + "min_version": "1.0.0", + "max_version": "1.0.0" + }, + "sdk": { + "min_version": "2.0.0", + "max_version": "2.99.99" + }, + "dependencies": [], + "capabilities": [ + "send.text", + "send.forward", + "send.hybrid", + "emoji.get_random", + "config.get" ], - "categories": [ - "Examples", - "Tutorial" - ], - "default_locale": "zh-CN", - "locales_path": "_locales", - "plugin_info": { - "is_built_in": false, - "plugin_type": "example", - "capabilities": [ - "send.text", - "send.forward", - "send.hybrid", - "emoji.get_random", - "config.get" - ], - "components": [ - { - "type": "tool", - "name": "compare_numbers", - "description": "比较两个数的大小" - }, - { - "type": "action", - "name": "hello_greeting", - "description": "向用户发送问候消息" - }, - { - "type": "action", - "name": "bye_greeting", - "description": "向用户发送告别消息", - "activation_modes": ["keyword"], - "keywords": ["再见", "bye", "88", "拜拜"] - }, - { - "type": "command", - "name": "time", - "description": "查询当前时间", - "pattern": "/time" - }, - { - "type": "command", - "name": "random_emojis", - "description": "发送多张随机表情包", - "pattern": "/random_emojis" - }, - { - "type": "command", - "name": "test", - "description": "测试命令", - "pattern": "/test" - }, - { - "type": "event_handler", - "name": "print_message_handler", - "description": "打印接收到的消息" - }, - { - "type": "event_handler", - "name": "forward_messages_handler", - "description": "把接收到的消息转发到指定聊天ID" - } + "i18n": { + "default_locale": "zh-CN", + "locales_path": "_locales", + "supported_locales": [ + "zh-CN" ] }, - "id": "MaiBot开发团队.maibot" -} \ No newline at end of file + "id": "maibot-team.hello-world-plugin" +} diff --git a/plugins/hello_world_plugin/plugin.py b/plugins/hello_world_plugin/plugin.py index fbba9d10..4d1f37af 100644 --- a/plugins/hello_world_plugin/plugin.py +++ b/plugins/hello_world_plugin/plugin.py @@ -3,16 +3,22 @@ 你的第一个 MaiCore 插件,包含问候功能、时间查询等基础示例。 """ +from maibot_sdk import Action, Command, EventHandler, MaiBotPlugin, Tool +from maibot_sdk.types import ActivationType, EventType, ToolParameterInfo, ToolParamType + import datetime import random -from maibot_sdk import MaiBotPlugin, Action, Command, Tool, EventHandler -from maibot_sdk.types import ActivationType, EventType, ToolParameterInfo, ToolParamType - class HelloWorldPlugin(MaiBotPlugin): """Hello World 示例插件""" + async def on_load(self) -> None: + """处理插件加载。""" + + async def on_unload(self) -> None: + """处理插件卸载。""" + # ===== Tool 组件 ===== @Tool( @@ -146,6 +152,25 @@ class HelloWorldPlugin(MaiBotPlugin): return True, True, None, None, None + async def on_config_update(self, scope: str, config_data: dict[str, object], version: str) -> None: + """处理配置热重载事件。 + + Args: + scope: 配置变更范围。 + config_data: 最新配置数据。 + version: 配置版本号。 + """ + + del scope + del config_data + del version + + +def create_plugin() -> HelloWorldPlugin: + """创建 Hello World 示例插件实例。 + + Returns: + HelloWorldPlugin: 新的示例插件实例。 + """ -def create_plugin(): return HelloWorldPlugin() diff --git a/pyproject.toml b/pyproject.toml index 70aa42cf..90135e04 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,7 @@ dependencies = [ "jieba>=0.42.1", "json-repair>=0.47.6", "maim-message>=0.6.2", - "maibot-plugin-sdk>=1.2.3,<2.0.0", + "maibot-plugin-sdk>=2.0.0", "msgpack>=1.1.2", "numpy>=2.2.6", "openai>=1.95.0", @@ -55,6 +55,8 @@ dev = [ [tool.uv] index-url = "https://pypi.tuna.tsinghua.edu.cn/simple" +[tool.uv.sources] +maibot-plugin-sdk = { path = "packages/maibot-plugin-sdk", editable = true } [tool.ruff] diff --git a/pytests/common_test/test_expression_auto_check_task.py b/pytests/common_test/test_expression_auto_check_task.py new file mode 100644 index 00000000..da8c59e1 --- /dev/null +++ b/pytests/common_test/test_expression_auto_check_task.py @@ -0,0 +1,89 @@ +"""测试表达方式自动检查任务的数据库读取行为。""" + +from contextlib import contextmanager +from typing import Generator + +import pytest +from sqlalchemy.pool import StaticPool +from sqlmodel import Session, SQLModel, create_engine + +from src.bw_learner.expression_auto_check_task import ExpressionAutoCheckTask +from src.common.database.database_model import Expression + + +@pytest.fixture(name="expression_auto_check_engine") +def expression_auto_check_engine_fixture() -> Generator: + """创建用于表达方式自动检查任务测试的内存数据库引擎。 + + Yields: + Generator: 供测试使用的 SQLite 内存引擎。 + """ + + engine = create_engine( + "sqlite://", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + SQLModel.metadata.create_all(engine) + yield engine + + +@pytest.mark.asyncio +async def test_select_expressions_uses_read_only_session( + monkeypatch: pytest.MonkeyPatch, + expression_auto_check_engine, +) -> None: + """选择表达方式时应使用只读会话,并在离开会话后安全读取 ORM 字段。""" + + import src.bw_learner.expression_auto_check_task as expression_auto_check_task_module + + with Session(expression_auto_check_engine) as session: + session.add( + Expression( + situation="表达情绪高涨或生理反应", + style="发送💦表情符号", + content_list='["表达情绪高涨或生理反应"]', + count=1, + session_id="session-a", + checked=False, + rejected=False, + ) + ) + session.commit() + + auto_commit_calls: list[bool] = [] + + @contextmanager + def fake_get_db_session(auto_commit: bool = True) -> Generator[Session, None, None]: + """构造带自动提交语义的测试会话工厂。 + + Args: + auto_commit: 退出上下文时是否自动提交。 + + Yields: + Generator[Session, None, None]: SQLModel 会话对象。 + """ + + auto_commit_calls.append(auto_commit) + session = Session(expression_auto_check_engine) + try: + yield session + if auto_commit: + session.commit() + except Exception: + session.rollback() + raise + finally: + session.close() + + monkeypatch.setattr(expression_auto_check_task_module, "get_db_session", fake_get_db_session) + monkeypatch.setattr(expression_auto_check_task_module.random, "sample", lambda entries, _count: list(entries)) + + task = ExpressionAutoCheckTask() + expressions = await task._select_expressions(1) + + assert auto_commit_calls == [False] + assert len(expressions) == 1 + assert expressions[0].id is not None + assert expressions[0].situation == "表达情绪高涨或生理反应" + assert expressions[0].style == "发送💦表情符号" diff --git a/pytests/common_test/test_expression_learner.py b/pytests/common_test/test_expression_learner.py new file mode 100644 index 00000000..951aa424 --- /dev/null +++ b/pytests/common_test/test_expression_learner.py @@ -0,0 +1,81 @@ +"""测试表达方式学习器的数据库读取行为。""" + +from contextlib import contextmanager +from typing import Generator + +import pytest +from sqlalchemy.pool import StaticPool +from sqlmodel import Session, SQLModel, create_engine + +from src.bw_learner.expression_learner import ExpressionLearner +from src.common.database.database_model import Expression + + +@pytest.fixture(name="expression_learner_engine") +def expression_learner_engine_fixture() -> Generator: + """创建用于表达方式学习器测试的内存数据库引擎。 + + Yields: + Generator: 供测试使用的 SQLite 内存引擎。 + """ + engine = create_engine( + "sqlite://", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + SQLModel.metadata.create_all(engine) + yield engine + + +def test_find_similar_expression_uses_read_only_session_and_history_content( + monkeypatch: pytest.MonkeyPatch, + expression_learner_engine, +) -> None: + """查找相似表达方式时,应能在离开会话后安全使用结果,并比较历史情景内容。""" + import src.bw_learner.expression_learner as expression_learner_module + + with Session(expression_learner_engine) as session: + session.add( + Expression( + situation="发送汗滴表情", + style="发送💦表情符号", + content_list='["表达情绪高涨或生理反应"]', + count=1, + session_id="session-a", + checked=False, + rejected=False, + ) + ) + session.commit() + + @contextmanager + def fake_get_db_session(auto_commit: bool = True) -> Generator[Session, None, None]: + """构造带自动提交语义的测试会话工厂。 + + Args: + auto_commit: 退出上下文时是否自动提交。 + + Yields: + Generator[Session, None, None]: SQLModel 会话对象。 + """ + session = Session(expression_learner_engine) + try: + yield session + if auto_commit: + session.commit() + except Exception: + session.rollback() + raise + finally: + session.close() + + monkeypatch.setattr(expression_learner_module, "get_db_session", fake_get_db_session) + + learner = ExpressionLearner(session_id="session-a") + result = learner._find_similar_expression("表达情绪高涨或生理反应") + + assert result is not None + expression, similarity = result + assert expression.item_id is not None + assert expression.style == "发送💦表情符号" + assert similarity == pytest.approx(1.0) diff --git a/pytests/common_test/test_expression_schema.py b/pytests/common_test/test_expression_schema.py new file mode 100644 index 00000000..31fcd98f --- /dev/null +++ b/pytests/common_test/test_expression_schema.py @@ -0,0 +1,78 @@ +"""测试表达方式表结构和基础插入行为。""" + +from typing import Generator + +import pytest +from sqlalchemy.pool import StaticPool +from sqlmodel import Session, SQLModel, create_engine + +from src.common.database.database_model import Expression + + +@pytest.fixture(name="expression_engine") +def expression_engine_fixture() -> Generator: + """创建仅用于表达方式表测试的内存数据库引擎。 + + Yields: + Generator: 供测试使用的 SQLite 内存引擎。 + """ + engine = create_engine( + "sqlite://", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + SQLModel.metadata.create_all(engine) + yield engine + + +def test_expression_insert_assigns_auto_increment_id(expression_engine) -> None: + """表达方式表在新库中应能自动分配自增主键。""" + with Session(expression_engine) as session: + expression = Expression( + situation="表达情绪高涨或生理反应", + style="发送💦表情符号", + content_list='["表达情绪高涨或生理反应"]', + count=1, + session_id="session-a", + checked=False, + rejected=False, + ) + session.add(expression) + session.commit() + session.refresh(expression) + + assert expression.id is not None + assert expression.id > 0 + + +def test_expression_insert_allows_same_situation_style(expression_engine) -> None: + """相同情景和风格的表达方式记录不应再被错误绑定到复合主键。""" + with Session(expression_engine) as session: + first_expression = Expression( + situation="对重复行为的默契响应", + style="持续性跟发相同内容", + content_list='["对重复行为的默契响应"]', + count=1, + session_id="session-a", + checked=False, + rejected=False, + ) + second_expression = Expression( + situation="对重复行为的默契响应", + style="持续性跟发相同内容", + content_list='["对重复行为的默契响应-变体"]', + count=2, + session_id="session-b", + checked=False, + rejected=False, + ) + + session.add(first_expression) + session.add(second_expression) + session.commit() + session.refresh(first_expression) + session.refresh(second_expression) + + assert first_expression.id is not None + assert second_expression.id is not None + assert first_expression.id != second_expression.id diff --git a/pytests/common_test/test_jargon_miner.py b/pytests/common_test/test_jargon_miner.py new file mode 100644 index 00000000..bf81e4d2 --- /dev/null +++ b/pytests/common_test/test_jargon_miner.py @@ -0,0 +1,90 @@ +"""测试黑话学习器的数据库读取行为。""" + +from contextlib import contextmanager +from typing import Generator + +import pytest +from sqlalchemy.pool import StaticPool +from sqlmodel import Session, SQLModel, create_engine, select + +from src.bw_learner.jargon_miner import JargonMiner +from src.common.database.database_model import Jargon + + +@pytest.fixture(name="jargon_miner_engine") +def jargon_miner_engine_fixture() -> Generator: + """创建用于黑话学习器测试的内存数据库引擎。 + + Yields: + Generator: 供测试使用的 SQLite 内存引擎。 + """ + engine = create_engine( + "sqlite://", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + SQLModel.metadata.create_all(engine) + yield engine + + +@pytest.mark.asyncio +async def test_process_extracted_entries_updates_existing_jargon_without_detached_session( + monkeypatch: pytest.MonkeyPatch, + jargon_miner_engine, +) -> None: + """更新已有黑话时,不应因会话关闭导致 ORM 实例失效。""" + import src.bw_learner.jargon_miner as jargon_miner_module + + with Session(jargon_miner_engine) as session: + session.add( + Jargon( + content="VF8V4L", + raw_content='["[1] first"]', + meaning="", + session_id_dict='{"session-a": 1}', + count=0, + is_jargon=True, + is_complete=False, + is_global=False, + last_inference_count=0, + ) + ) + session.commit() + + @contextmanager + def fake_get_db_session(auto_commit: bool = True) -> Generator[Session, None, None]: + """构造带自动提交语义的测试会话工厂。 + + Args: + auto_commit: 退出上下文时是否自动提交。 + + Yields: + Generator[Session, None, None]: SQLModel 会话对象。 + """ + session = Session(jargon_miner_engine) + try: + yield session + if auto_commit: + session.commit() + except Exception: + session.rollback() + raise + finally: + session.close() + + monkeypatch.setattr(jargon_miner_module, "get_db_session", fake_get_db_session) + + jargon_miner = JargonMiner(session_id="session-a", session_name="测试群") + await jargon_miner.process_extracted_entries( + [{"content": "VF8V4L", "raw_content": {"[2] second"}}], + ) + + with Session(jargon_miner_engine) as session: + db_jargon = session.exec(select(Jargon).where(Jargon.content == "VF8V4L")).one() + + assert db_jargon.count == 1 + assert db_jargon.session_id_dict == '{"session-a": 2}' + assert sorted(db_jargon.raw_content and __import__("json").loads(db_jargon.raw_content)) == [ + "[1] first", + "[2] second", + ] diff --git a/pytests/common_test/test_jargon_schema.py b/pytests/common_test/test_jargon_schema.py new file mode 100644 index 00000000..909392ab --- /dev/null +++ b/pytests/common_test/test_jargon_schema.py @@ -0,0 +1,84 @@ +"""测试黑话表结构和基础插入行为。""" + +from typing import Generator + +import pytest +from sqlalchemy.pool import StaticPool +from sqlmodel import Session, SQLModel, create_engine + +from src.common.database.database_model import Jargon + + +@pytest.fixture(name="jargon_engine") +def jargon_engine_fixture() -> Generator: + """创建仅用于黑话表测试的内存数据库引擎。 + + Yields: + Generator: 供测试使用的 SQLite 内存引擎。 + """ + engine = create_engine( + "sqlite://", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + SQLModel.metadata.create_all(engine) + yield engine + + +def test_jargon_insert_assigns_auto_increment_id(jargon_engine) -> None: + """黑话表在新库中应能自动分配自增主键。""" + with Session(jargon_engine) as session: + jargon = Jargon( + content="VF8V4L", + raw_content='["[1] test"]', + meaning="", + session_id_dict='{"session-a": 1}', + count=1, + is_jargon=True, + is_complete=False, + is_global=True, + last_inference_count=0, + ) + session.add(jargon) + session.commit() + session.refresh(jargon) + + assert jargon.id is not None + assert jargon.id > 0 + + +def test_jargon_insert_allows_same_content_with_different_rows(jargon_engine) -> None: + """黑话内容不应再被错误地绑成复合主键的一部分。""" + with Session(jargon_engine) as session: + first_jargon = Jargon( + content="表情1", + raw_content='["[1] first"]', + meaning="", + session_id_dict='{"session-a": 1}', + count=1, + is_jargon=True, + is_complete=False, + is_global=False, + last_inference_count=0, + ) + second_jargon = Jargon( + content="表情1", + raw_content='["[1] second"]', + meaning="", + session_id_dict='{"session-b": 1}', + count=1, + is_jargon=True, + is_complete=False, + is_global=False, + last_inference_count=0, + ) + + session.add(first_jargon) + session.add(second_jargon) + session.commit() + session.refresh(first_jargon) + session.refresh(second_jargon) + + assert first_jargon.id is not None + assert second_jargon.id is not None + assert first_jargon.id != second_jargon.id diff --git a/pytests/common_test/test_person_info_group_cardname.py b/pytests/common_test/test_person_info_group_cardname.py new file mode 100644 index 00000000..62a63f43 --- /dev/null +++ b/pytests/common_test/test_person_info_group_cardname.py @@ -0,0 +1,355 @@ +"""人物信息群名片字段兼容测试。""" + +from __future__ import annotations + +from importlib.util import module_from_spec, spec_from_file_location +from pathlib import Path +from types import ModuleType, SimpleNamespace +from typing import Any + +import json +import sys + +import pytest + +from src.common.data_models.person_info_data_model import dump_group_cardname_records, parse_group_cardname_json + + +class _DummyLogger: + """模拟日志记录器。""" + + def debug(self, message: str) -> None: + """记录调试日志。 + + Args: + message: 日志内容。 + """ + del message + + def info(self, message: str) -> None: + """记录信息日志。 + + Args: + message: 日志内容。 + """ + del message + + def warning(self, message: str) -> None: + """记录警告日志。 + + Args: + message: 日志内容。 + """ + del message + + def error(self, message: str) -> None: + """记录错误日志。 + + Args: + message: 日志内容。 + """ + del message + + +class _DummyStatement: + """模拟 SQL 查询语句对象。""" + + def where(self, condition: Any) -> "_DummyStatement": + """附加过滤条件。 + + Args: + condition: 过滤条件。 + + Returns: + _DummyStatement: 当前语句对象。 + """ + del condition + return self + + def limit(self, value: int) -> "_DummyStatement": + """限制返回条数。 + + Args: + value: 条数限制。 + + Returns: + _DummyStatement: 当前语句对象。 + """ + del value + return self + + +class _DummyColumn: + """模拟 SQLModel 列对象。""" + + def is_not(self, value: Any) -> "_DummyColumn": + """模拟 `IS NOT` 条件构造。 + + Args: + value: 比较值。 + + Returns: + _DummyColumn: 当前列对象。 + """ + del value + return self + + def __eq__(self, other: Any) -> "_DummyColumn": + """模拟等值条件构造。 + + Args: + other: 比较值。 + + Returns: + _DummyColumn: 当前列对象。 + """ + del other + return self + + +class _DummyResult: + """模拟数据库查询结果。""" + + def __init__(self, record: Any) -> None: + """初始化查询结果。 + + Args: + record: 待返回的首条记录。 + """ + self._record = record + + def first(self) -> Any: + """返回第一条记录。 + + Returns: + Any: 首条记录。 + """ + return self._record + + def all(self) -> list[Any]: + """返回全部结果。 + + Returns: + list[Any]: 结果列表。 + """ + if self._record is None: + return [] + return self._record if isinstance(self._record, list) else [self._record] + + +class _DummySession: + """模拟数据库 Session。""" + + def __init__(self, record: Any) -> None: + """初始化 Session。 + + Args: + record: `first()` 应返回的记录。 + """ + self.record = record + self.added_records: list[Any] = [] + + def __enter__(self) -> "_DummySession": + """进入上下文管理器。 + + Returns: + _DummySession: 当前 Session。 + """ + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + """退出上下文管理器。 + + Args: + exc_type: 异常类型。 + exc_val: 异常值。 + exc_tb: 异常回溯。 + """ + del exc_type + del exc_val + del exc_tb + + def exec(self, statement: Any) -> _DummyResult: + """执行查询。 + + Args: + statement: 查询语句。 + + Returns: + _DummyResult: 模拟结果对象。 + """ + del statement + return _DummyResult(self.record) + + def add(self, record: Any) -> None: + """记录被添加的对象。 + + Args: + record: 被写入 Session 的对象。 + """ + self.added_records.append(record) + + +class _DummyPersonInfoRecord: + """模拟 `PersonInfo` ORM 模型。""" + + person_id = "person_id" + person_name = "person_name" + + def __init__(self, **kwargs: Any) -> None: + """使用关键字参数初始化记录对象。 + + Args: + **kwargs: 字段值。 + """ + for key, value in kwargs.items(): + setattr(self, key, value) + + +def _load_person_module(monkeypatch: pytest.MonkeyPatch, session: _DummySession) -> ModuleType: + """加载带依赖桩的 `person_info` 模块。 + + Args: + monkeypatch: Pytest monkeypatch 工具。 + session: 提供给模块使用的假数据库 Session。 + + Returns: + ModuleType: 加载后的模块对象。 + """ + logger_module = ModuleType("src.common.logger") + logger_module.get_logger = lambda name: _DummyLogger() + monkeypatch.setitem(sys.modules, "src.common.logger", logger_module) + + database_module = ModuleType("src.common.database.database") + database_module.get_db_session = lambda: session + monkeypatch.setitem(sys.modules, "src.common.database.database", database_module) + + database_model_module = ModuleType("src.common.database.database_model") + database_model_module.PersonInfo = _DummyPersonInfoRecord + monkeypatch.setitem(sys.modules, "src.common.database.database_model", database_model_module) + + llm_module = ModuleType("src.llm_models.utils_model") + + class _DummyLLMRequest: + """模拟 LLMRequest。""" + + def __init__(self, model_set: Any, request_type: str) -> None: + """初始化假请求对象。 + + Args: + model_set: 模型配置。 + request_type: 请求类型。 + """ + del model_set + del request_type + + llm_module.LLMRequest = _DummyLLMRequest + monkeypatch.setitem(sys.modules, "src.llm_models.utils_model", llm_module) + + config_module = ModuleType("src.config.config") + config_module.global_config = SimpleNamespace(bot=SimpleNamespace(nickname="MaiBot")) + config_module.model_config = SimpleNamespace(model_task_config=SimpleNamespace(tool_use="tool_use", utils="utils")) + monkeypatch.setitem(sys.modules, "src.config.config", config_module) + + chat_manager_module = ModuleType("src.chat.message_receive.chat_manager") + chat_manager_module.chat_manager = SimpleNamespace() + monkeypatch.setitem(sys.modules, "src.chat.message_receive.chat_manager", chat_manager_module) + + module_path = Path(__file__).resolve().parents[2] / "src" / "person_info" / "person_info.py" + spec = spec_from_file_location("person_info_group_cardname_test_module", module_path) + assert spec is not None and spec.loader is not None + + module = module_from_spec(spec) + monkeypatch.setitem(sys.modules, spec.name, module) + spec.loader.exec_module(module) + + monkeypatch.setattr(module, "select", lambda *args: _DummyStatement()) + monkeypatch.setattr(module, "col", lambda field: _DummyColumn()) + return module + + +def test_parse_group_cardname_json_uses_canonical_key() -> None: + """群名片 JSON 解析应只使用 `group_cardname` 键名。""" + parsed = parse_group_cardname_json( + json.dumps( + [ + {"group_id": "1001", "group_cardname": "现行字段"}, + ], + ensure_ascii=False, + ) + ) + + assert parsed is not None + assert [(item.group_id, item.group_cardname) for item in parsed] == [ + ("1001", "现行字段"), + ] + + +def test_dump_group_cardname_records_uses_canonical_key() -> None: + """群名片序列化应输出 `group_cardname` 键名。""" + dumped = dump_group_cardname_records( + [ + {"group_id": "1001", "group_cardname": "群昵称"}, + ] + ) + + assert json.loads(dumped) == [{"group_id": "1001", "group_cardname": "群昵称"}] + + +def test_person_sync_to_database_uses_group_cardname_field(monkeypatch: pytest.MonkeyPatch) -> None: + """同步人物信息时应写入数据库模型的 `group_cardname` 字段。""" + record = _DummyPersonInfoRecord() + session = _DummySession(record) + module = _load_person_module(monkeypatch, session) + + person = module.Person.__new__(module.Person) + person.is_known = True + person.person_id = "person-1" + person.platform = "qq" + person.user_id = "10001" + person.nickname = "看番的龙" + person.person_name = "看番的龙" + person.name_reason = "测试" + person.know_times = 1 + person.know_since = 1700000000.0 + person.last_know = 1700000100.0 + person.memory_points = ["喜好:番剧:0.8"] + person.group_cardname_list = [{"group_id": "20001", "group_cardname": "白泽大人"}] + + person.sync_to_database() + + assert record.group_cardname == '[{"group_id": "20001", "group_cardname": "白泽大人"}]' + assert not hasattr(record, "group_nickname") + + +def test_person_load_from_database_normalizes_group_cardname_payload(monkeypatch: pytest.MonkeyPatch) -> None: + """从数据库加载人物信息时应读取标准 `group_cardname` 结构。""" + record = _DummyPersonInfoRecord( + user_id="10001", + platform="qq", + is_known=True, + user_nickname="看番的龙", + person_name="看番的龙", + name_reason=None, + know_counts=2, + memory_points='["喜好:番剧:0.8"]', + group_cardname=json.dumps( + [ + {"group_id": "20001", "group_cardname": "白泽大人"}, + ], + ensure_ascii=False, + ), + ) + session = _DummySession(record) + module = _load_person_module(monkeypatch, session) + + person = module.Person.__new__(module.Person) + person.person_id = "person-1" + person.memory_points = [] + person.group_cardname_list = [] + + person.load_from_database() + + assert person.group_cardname_list == [ + {"group_id": "20001", "group_cardname": "白泽大人"}, + ] diff --git a/pytests/test_message_gateway_runtime.py b/pytests/test_message_gateway_runtime.py new file mode 100644 index 00000000..9650bc10 --- /dev/null +++ b/pytests/test_message_gateway_runtime.py @@ -0,0 +1,170 @@ +"""消息网关运行时状态同步测试。""" + +from typing import Any, Dict + +import pytest + +from src.platform_io.manager import PlatformIOManager +from src.platform_io.types import RouteKey +from src.plugin_runtime.host.supervisor import PluginSupervisor +from src.plugin_runtime.protocol.envelope import Envelope, MessageType + + +def _make_request(method: str, plugin_id: str, payload: Dict[str, Any]) -> Envelope: + """构造一个 RPC 请求信封。 + + Args: + method: RPC 方法名。 + plugin_id: 目标插件 ID。 + payload: 请求载荷。 + + Returns: + Envelope: 标准 RPC 请求信封。 + """ + + return Envelope( + request_id=1, + message_type=MessageType.REQUEST, + method=method, + plugin_id=plugin_id, + payload=payload, + ) + + +@pytest.mark.asyncio +async def test_message_gateway_runtime_state_binds_send_and_receive_routes( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """消息网关就绪后应同时绑定发送表和接收表。""" + + import src.plugin_runtime.host.supervisor as supervisor_module + + platform_io_manager = PlatformIOManager() + monkeypatch.setattr(supervisor_module, "get_platform_io_manager", lambda: platform_io_manager) + + supervisor = PluginSupervisor(plugin_dirs=[]) + register_response = await supervisor._handle_register_plugin( + _make_request( + "plugin.register_components", + "napcat_plugin", + { + "plugin_id": "napcat_plugin", + "plugin_version": "1.0.0", + "components": [ + { + "name": "napcat_gateway", + "component_type": "MESSAGE_GATEWAY", + "plugin_id": "napcat_plugin", + "metadata": { + "route_type": "duplex", + "platform": "qq", + "protocol": "napcat", + }, + } + ], + "capabilities_required": [], + }, + ) + ) + + assert register_response.error is None + response = await supervisor._handle_update_message_gateway_state( + _make_request( + "host.update_message_gateway_state", + "napcat_plugin", + { + "gateway_name": "napcat_gateway", + "ready": True, + "platform": "qq", + "account_id": "10001", + "scope": "primary", + "metadata": {}, + }, + ) + ) + + assert response.error is None + assert response.payload["accepted"] is True + + send_bindings = platform_io_manager.send_route_table.resolve_bindings( + RouteKey(platform="qq", account_id="10001", scope="primary") + ) + receive_bindings = platform_io_manager.receive_route_table.resolve_bindings( + RouteKey(platform="qq", account_id="10001", scope="primary") + ) + + assert [binding.driver_id for binding in send_bindings] == ["gateway:napcat_plugin:napcat_gateway"] + assert [binding.driver_id for binding in receive_bindings] == ["gateway:napcat_plugin:napcat_gateway"] + + +@pytest.mark.asyncio +async def test_message_gateway_runtime_state_unbinds_routes_when_not_ready( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """消息网关断开后应撤销发送表和接收表中的绑定。""" + + import src.plugin_runtime.host.supervisor as supervisor_module + + platform_io_manager = PlatformIOManager() + monkeypatch.setattr(supervisor_module, "get_platform_io_manager", lambda: platform_io_manager) + + supervisor = PluginSupervisor(plugin_dirs=[]) + await supervisor._handle_register_plugin( + _make_request( + "plugin.register_components", + "napcat_plugin", + { + "plugin_id": "napcat_plugin", + "plugin_version": "1.0.0", + "components": [ + { + "name": "napcat_gateway", + "component_type": "MESSAGE_GATEWAY", + "plugin_id": "napcat_plugin", + "metadata": { + "route_type": "duplex", + "platform": "qq", + "protocol": "napcat", + }, + } + ], + "capabilities_required": [], + }, + ) + ) + + await supervisor._handle_update_message_gateway_state( + _make_request( + "host.update_message_gateway_state", + "napcat_plugin", + { + "gateway_name": "napcat_gateway", + "ready": True, + "platform": "qq", + "account_id": "10001", + "scope": "primary", + "metadata": {}, + }, + ) + ) + response = await supervisor._handle_update_message_gateway_state( + _make_request( + "host.update_message_gateway_state", + "napcat_plugin", + { + "gateway_name": "napcat_gateway", + "ready": False, + "platform": "qq", + "account_id": "", + "scope": "", + "metadata": {}, + }, + ) + ) + + assert response.error is None + assert response.payload["accepted"] is True + assert platform_io_manager.send_route_table.resolve_bindings(RouteKey(platform="qq", account_id="10001")) == [] + assert ( + platform_io_manager.receive_route_table.resolve_bindings(RouteKey(platform="qq", account_id="10001")) == [] + ) diff --git a/pytests/test_napcat_adapter_sdk.py b/pytests/test_napcat_adapter_sdk.py new file mode 100644 index 00000000..c6b1fdbd --- /dev/null +++ b/pytests/test_napcat_adapter_sdk.py @@ -0,0 +1,132 @@ +"""NapCat 插件与新 SDK 对接测试。""" + +from pathlib import Path +from typing import Any, Dict, List + +import importlib +import logging +import sys + +import pytest + +PROJECT_ROOT = Path(__file__).resolve().parents[1] +PLUGINS_ROOT = PROJECT_ROOT / "plugins" +SDK_ROOT = PROJECT_ROOT / "packages" / "maibot-plugin-sdk" + +for import_path in (str(PLUGINS_ROOT), str(SDK_ROOT)): + if import_path not in sys.path: + sys.path.insert(0, import_path) + + +class _FakeGatewayCapability: + """用于捕获消息网关状态上报的测试替身。""" + + def __init__(self) -> None: + """初始化测试替身。""" + + self.calls: List[Dict[str, Any]] = [] + + async def update_state( + self, + gateway_name: str, + *, + ready: bool, + platform: str = "", + account_id: str = "", + scope: str = "", + metadata: Dict[str, Any] | None = None, + ) -> bool: + """记录一次状态上报请求。 + + Args: + gateway_name: 网关组件名称。 + ready: 当前是否就绪。 + platform: 平台名称。 + account_id: 账号 ID。 + scope: 路由作用域。 + metadata: 附加元数据。 + + Returns: + bool: 始终返回 ``True``,模拟 Host 接受状态更新。 + """ + + self.calls.append( + { + "gateway_name": gateway_name, + "ready": ready, + "platform": platform, + "account_id": account_id, + "scope": scope, + "metadata": metadata or {}, + } + ) + return True + + +def _load_napcat_sdk_symbols() -> tuple[Any, Any, Any, Any]: + """动态加载 NapCat 插件测试所需的符号。 + + Returns: + tuple[Any, Any, Any, Any]: + 依次返回网关名常量、配置类、插件类和运行时状态管理器类。 + """ + + constants_module = importlib.import_module("napcat_adapter.constants") + config_module = importlib.import_module("napcat_adapter.config") + plugin_module = importlib.import_module("napcat_adapter.plugin") + runtime_state_module = importlib.import_module("napcat_adapter.runtime_state") + return ( + constants_module.NAPCAT_GATEWAY_NAME, + config_module.NapCatServerConfig, + plugin_module.NapCatAdapterPlugin, + runtime_state_module.NapCatRuntimeStateManager, + ) + + +def test_napcat_plugin_collects_duplex_message_gateway() -> None: + """NapCat 插件应声明新的双工消息网关组件。""" + + napcat_gateway_name, _napcat_server_config, napcat_plugin_cls, _runtime_state_cls = _load_napcat_sdk_symbols() + plugin = napcat_plugin_cls() + components = plugin.get_components() + gateway_components = [ + component + for component in components + if component.get("type") == "MESSAGE_GATEWAY" + ] + + assert len(gateway_components) == 1 + gateway_component = gateway_components[0] + assert gateway_component["name"] == napcat_gateway_name + assert gateway_component["metadata"]["route_type"] == "duplex" + assert gateway_component["metadata"]["platform"] == "qq" + assert gateway_component["metadata"]["protocol"] == "napcat" + + +@pytest.mark.asyncio +async def test_runtime_state_reports_via_gateway_capability() -> None: + """NapCat 运行时状态应通过新的消息网关能力上报。""" + + napcat_gateway_name, napcat_server_config_cls, _napcat_plugin_cls, runtime_state_cls = _load_napcat_sdk_symbols() + gateway_capability = _FakeGatewayCapability() + runtime_state_manager = runtime_state_cls( + gateway_capability=gateway_capability, + logger=logging.getLogger("test.napcat_adapter"), + gateway_name=napcat_gateway_name, + ) + + connected = await runtime_state_manager.report_connected( + "10001", + napcat_server_config_cls(connection_id="primary"), + ) + await runtime_state_manager.report_disconnected() + + assert connected is True + assert gateway_capability.calls[0]["gateway_name"] == napcat_gateway_name + assert gateway_capability.calls[0]["ready"] is True + assert gateway_capability.calls[0]["platform"] == "qq" + assert gateway_capability.calls[0]["account_id"] == "10001" + assert gateway_capability.calls[0]["scope"] == "primary" + assert gateway_capability.calls[1]["gateway_name"] == napcat_gateway_name + assert gateway_capability.calls[1]["ready"] is False + assert gateway_capability.calls[1]["platform"] == "qq" diff --git a/pytests/test_platform_io_dedupe.py b/pytests/test_platform_io_dedupe.py new file mode 100644 index 00000000..d6bdd1dd --- /dev/null +++ b/pytests/test_platform_io_dedupe.py @@ -0,0 +1,209 @@ +"""Platform IO 入站去重策略测试。""" + +from types import SimpleNamespace +from typing import Any, Dict, List, Optional + +import pytest + +from src.platform_io.drivers.base import PlatformIODriver +from src.platform_io.manager import PlatformIOManager +from src.platform_io.types import DeliveryReceipt, DeliveryStatus, DriverDescriptor, DriverKind, InboundMessageEnvelope, RouteBinding, RouteKey + + +def _build_envelope( + *, + dedupe_key: str | None = None, + external_message_id: str | None = None, + session_message_id: str | None = None, + payload: Optional[Dict[str, Any]] = None, +) -> InboundMessageEnvelope: + """构造测试用入站信封。 + + Args: + dedupe_key: 显式去重键。 + external_message_id: 平台侧消息 ID。 + session_message_id: 规范化消息对象上的消息 ID。 + payload: 原始载荷。 + + Returns: + InboundMessageEnvelope: 测试用入站消息信封。 + """ + session_message = None + if session_message_id is not None: + session_message = SimpleNamespace(message_id=session_message_id) + + return InboundMessageEnvelope( + route_key=RouteKey(platform="qq", account_id="10001", scope="main"), + driver_id="plugin.napcat", + driver_kind=DriverKind.PLUGIN, + dedupe_key=dedupe_key, + external_message_id=external_message_id, + session_message=session_message, + payload=payload, + ) + + +class _StubPlatformIODriver(PlatformIODriver): + """测试用 Platform IO 驱动。""" + + async def send_message( + self, + message: Any, + route_key: RouteKey, + metadata: Optional[Dict[str, Any]] = None, + ) -> DeliveryReceipt: + """返回一个固定的成功回执。 + + Args: + message: 待发送的消息对象。 + route_key: 本次发送使用的路由键。 + metadata: 额外发送元数据。 + + Returns: + DeliveryReceipt: 固定的成功回执。 + """ + return DeliveryReceipt( + internal_message_id=str(getattr(message, "message_id", "stub-message-id")), + route_key=route_key, + status=DeliveryStatus.SENT, + driver_id=self.driver_id, + driver_kind=self.descriptor.kind, + ) + + +def _build_manager() -> PlatformIOManager: + """构造带有最小接收路由的 Broker 管理器。 + + Returns: + PlatformIOManager: 已注册测试驱动并绑定接收路由的 Broker。 + """ + manager = PlatformIOManager() + driver = _StubPlatformIODriver( + DriverDescriptor( + driver_id="plugin.napcat", + kind=DriverKind.PLUGIN, + platform="qq", + account_id="10001", + scope="main", + ) + ) + manager.register_driver(driver) + manager.bind_receive_route( + RouteBinding( + route_key=RouteKey(platform="qq", account_id="10001", scope="main"), + driver_id=driver.driver_id, + driver_kind=driver.descriptor.kind, + ) + ) + return manager + + +class TestPlatformIODedupe: + """Platform IO 去重测试。""" + + @pytest.mark.asyncio + async def test_accept_inbound_dedupes_by_external_message_id(self) -> None: + """相同平台消息 ID 的重复入站应被抑制。""" + manager = _build_manager() + accepted_envelopes: List[InboundMessageEnvelope] = [] + + async def dispatcher(envelope: InboundMessageEnvelope) -> None: + """记录被成功接收的入站消息。 + + Args: + envelope: 被 Broker 接受的入站消息。 + """ + accepted_envelopes.append(envelope) + + manager.set_inbound_dispatcher(dispatcher) + + first_envelope = _build_envelope( + external_message_id="msg-1", + payload={"message": "hello"}, + ) + second_envelope = _build_envelope( + external_message_id="msg-1", + payload={"message": "hello"}, + ) + + assert await manager.accept_inbound(first_envelope) is True + assert await manager.accept_inbound(second_envelope) is False + assert len(accepted_envelopes) == 1 + + @pytest.mark.asyncio + async def test_accept_inbound_without_stable_identity_does_not_guess_duplicate(self) -> None: + """缺少稳定身份时,不应仅凭 payload 内容猜测重复消息。""" + manager = _build_manager() + accepted_envelopes: List[InboundMessageEnvelope] = [] + + async def dispatcher(envelope: InboundMessageEnvelope) -> None: + """记录被成功接收的入站消息。 + + Args: + envelope: 被 Broker 接受的入站消息。 + """ + accepted_envelopes.append(envelope) + + manager.set_inbound_dispatcher(dispatcher) + + first_envelope = _build_envelope(payload={"message": "same-payload"}) + second_envelope = _build_envelope(payload={"message": "same-payload"}) + + assert await manager.accept_inbound(first_envelope) is True + assert await manager.accept_inbound(second_envelope) is True + assert len(accepted_envelopes) == 2 + + def test_build_inbound_dedupe_key_prefers_explicit_identity(self) -> None: + """去重键应只来自显式或稳定的技术身份。""" + explicit_envelope = _build_envelope(dedupe_key="dedupe-1", external_message_id="msg-1") + session_message_envelope = _build_envelope(session_message_id="session-1") + payload_only_envelope = _build_envelope(payload={"message": "hello"}) + + assert PlatformIOManager._build_inbound_dedupe_key(explicit_envelope) == "plugin.napcat:dedupe-1" + assert PlatformIOManager._build_inbound_dedupe_key(session_message_envelope) == "plugin.napcat:session-1" + assert PlatformIOManager._build_inbound_dedupe_key(payload_only_envelope) is None + + @pytest.mark.asyncio + async def test_send_message_fans_out_to_all_matching_routes(self) -> None: + """同一路由命中多条发送链路时应全部发送。""" + + manager = PlatformIOManager() + first_driver = _StubPlatformIODriver( + DriverDescriptor( + driver_id="plugin.gateway_a", + kind=DriverKind.PLUGIN, + platform="qq", + ) + ) + second_driver = _StubPlatformIODriver( + DriverDescriptor( + driver_id="plugin.gateway_b", + kind=DriverKind.PLUGIN, + platform="qq", + ) + ) + manager.register_driver(first_driver) + manager.register_driver(second_driver) + manager.bind_send_route( + RouteBinding( + route_key=RouteKey(platform="qq"), + driver_id=first_driver.driver_id, + driver_kind=first_driver.descriptor.kind, + ) + ) + manager.bind_send_route( + RouteBinding( + route_key=RouteKey(platform="qq"), + driver_id=second_driver.driver_id, + driver_kind=second_driver.descriptor.kind, + ) + ) + + message = SimpleNamespace(message_id="internal-msg-1") + result = await manager.send_message(message, RouteKey(platform="qq")) + + assert result.has_success is True + assert [receipt.driver_id for receipt in result.sent_receipts] == [ + "plugin.gateway_a", + "plugin.gateway_b", + ] diff --git a/pytests/test_platform_io_legacy_driver.py b/pytests/test_platform_io_legacy_driver.py new file mode 100644 index 00000000..76f14d8f --- /dev/null +++ b/pytests/test_platform_io_legacy_driver.py @@ -0,0 +1,178 @@ +"""Platform IO legacy driver 回归测试。""" + +from typing import Any, Dict, Optional + +import pytest + +from src.chat.utils import utils as chat_utils +from src.chat.message_receive import uni_message_sender +from src.platform_io.drivers.base import PlatformIODriver +from src.platform_io.drivers.legacy_driver import LegacyPlatformDriver +from src.platform_io.manager import PlatformIOManager +from src.platform_io.types import DeliveryReceipt, DeliveryStatus, DriverDescriptor, DriverKind, RouteBinding, RouteKey + + +class _PluginDriver(PlatformIODriver): + """测试用插件发送驱动。""" + + def __init__(self, driver_id: str, platform: str) -> None: + """初始化测试驱动。 + + Args: + driver_id: 驱动 ID。 + platform: 负责的平台名称。 + """ + super().__init__( + DriverDescriptor( + driver_id=driver_id, + kind=DriverKind.PLUGIN, + platform=platform, + plugin_id="test.plugin", + ) + ) + + async def send_message( + self, + message: Any, + route_key: RouteKey, + metadata: Optional[Dict[str, Any]] = None, + ) -> DeliveryReceipt: + """返回一个固定成功回执。 + + Args: + message: 待发送消息。 + route_key: 当前路由键。 + metadata: 发送元数据。 + + Returns: + DeliveryReceipt: 固定成功回执。 + """ + del metadata + return DeliveryReceipt( + internal_message_id=str(message.message_id), + route_key=route_key, + status=DeliveryStatus.SENT, + driver_id=self.driver_id, + driver_kind=self.descriptor.kind, + ) + + +@pytest.mark.asyncio +async def test_platform_io_uses_legacy_driver_when_no_explicit_send_route( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """没有显式发送路由时,应由 Platform IO 回退到 legacy driver。""" + manager = PlatformIOManager() + monkeypatch.setattr(chat_utils, "get_all_bot_accounts", lambda: {"qq": "bot-qq"}) + + try: + await manager.ensure_send_pipeline_ready() + + fallback_drivers = manager.resolve_drivers(RouteKey(platform="qq")) + assert [driver.driver_id for driver in fallback_drivers] == ["legacy.send.qq"] + + plugin_driver = _PluginDriver(driver_id="plugin.qq.sender", platform="qq") + await manager.add_driver(plugin_driver) + manager.bind_send_route( + RouteBinding( + route_key=RouteKey(platform="qq"), + driver_id=plugin_driver.driver_id, + driver_kind=plugin_driver.descriptor.kind, + ) + ) + + explicit_drivers = manager.resolve_drivers(RouteKey(platform="qq")) + assert [driver.driver_id for driver in explicit_drivers] == ["plugin.qq.sender", "legacy.send.qq"] + finally: + await manager.stop() + + +@pytest.mark.asyncio +async def test_platform_io_broadcasts_to_plugin_and_legacy_driver( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """同一路由命中插件驱动与 legacy driver 时,应同时广播发送。""" + + manager = PlatformIOManager() + legacy_calls: list[dict[str, Any]] = [] + monkeypatch.setattr(chat_utils, "get_all_bot_accounts", lambda: {"qq": "bot-qq"}) + + async def _fake_send_prepared_message_to_platform(message: Any, show_log: bool = True) -> bool: + """记录 legacy driver 调用。""" + + legacy_calls.append({"message": message, "show_log": show_log}) + return True + + monkeypatch.setattr( + uni_message_sender, + "send_prepared_message_to_platform", + _fake_send_prepared_message_to_platform, + ) + + try: + await manager.ensure_send_pipeline_ready() + + plugin_driver = _PluginDriver(driver_id="plugin.qq.sender", platform="qq") + await manager.add_driver(plugin_driver) + manager.bind_send_route( + RouteBinding( + route_key=RouteKey(platform="qq"), + driver_id=plugin_driver.driver_id, + driver_kind=plugin_driver.descriptor.kind, + ) + ) + + message = type("FakeMessage", (), {"message_id": "message-1"})() + batch = await manager.send_message( + message=message, + route_key=RouteKey(platform="qq"), + metadata={"show_log": False}, + ) + + assert sorted(receipt.driver_id for receipt in batch.sent_receipts) == [ + "legacy.send.qq", + "plugin.qq.sender", + ] + assert batch.failed_receipts == [] + assert len(legacy_calls) == 1 + assert legacy_calls[0]["message"] is message + assert legacy_calls[0]["show_log"] is False + finally: + await manager.stop() + + +@pytest.mark.asyncio +async def test_legacy_platform_driver_uses_prepared_universal_sender( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """legacy driver 应复用已预处理消息的旧链发送函数。""" + calls: list[dict[str, Any]] = [] + + async def _fake_send_prepared_message_to_platform(message: Any, show_log: bool = True) -> bool: + """记录 legacy driver 调用。""" + calls.append({"message": message, "show_log": show_log}) + return True + + monkeypatch.setattr( + uni_message_sender, + "send_prepared_message_to_platform", + _fake_send_prepared_message_to_platform, + ) + + driver = LegacyPlatformDriver( + driver_id="legacy.send.qq", + platform="qq", + account_id="bot-qq", + ) + message = type("FakeMessage", (), {"message_id": "message-1"})() + receipt = await driver.send_message( + message=message, + route_key=RouteKey(platform="qq"), + metadata={"show_log": False}, + ) + + assert len(calls) == 1 + assert calls[0]["message"] is message + assert calls[0]["show_log"] is False + assert receipt.status == DeliveryStatus.SENT + assert receipt.driver_id == "legacy.send.qq" diff --git a/pytests/test_plugin_message_utils_runtime.py b/pytests/test_plugin_message_utils_runtime.py new file mode 100644 index 00000000..cb4b5341 --- /dev/null +++ b/pytests/test_plugin_message_utils_runtime.py @@ -0,0 +1,87 @@ +from datetime import datetime +from pathlib import Path + +import sys + +from src.chat.message_receive.message import SessionMessage +from src.common.data_models.mai_message_data_model import GroupInfo, MessageInfo, UserInfo +from src.common.data_models.message_component_data_model import ( + ForwardComponent, + ForwardNodeComponent, + ImageComponent, + MessageSequence, + ReplyComponent, + TextComponent, + VoiceComponent, +) +from src.plugin_runtime.host.message_utils import PluginMessageUtils + + +PROJECT_ROOT = Path(__file__).resolve().parents[1] +if str(PROJECT_ROOT) not in sys.path: + sys.path.insert(0, str(PROJECT_ROOT)) + + +def test_plugin_message_utils_preserves_binary_components_and_reply_metadata() -> None: + message = SessionMessage(message_id="msg-1", timestamp=datetime.now(), platform="qq") + message.message_info = MessageInfo( + user_info=UserInfo(user_id="10001", user_nickname="tester"), + group_info=GroupInfo(group_id="20001", group_name="group"), + additional_config={"self_id": "999"}, + ) + message.session_id = "qq:20001:10001" + message.processed_plain_text = "binary payload" + message.display_message = "binary payload" + message.raw_message = MessageSequence( + components=[ + TextComponent("hello"), + ImageComponent(binary_hash="", binary_data=b"image-bytes", content=""), + VoiceComponent(binary_hash="", binary_data=b"voice-bytes", content=""), + ReplyComponent( + target_message_id="origin-1", + target_message_content="origin text", + target_message_sender_id="42", + target_message_sender_nickname="alice", + target_message_sender_cardname="Alice", + ), + ForwardNodeComponent( + forward_components=[ + ForwardComponent( + user_nickname="bob", + user_id="43", + user_cardname="Bob", + message_id="forward-1", + content=[ + TextComponent("node-text"), + ImageComponent(binary_hash="", binary_data=b"node-image", content=""), + ], + ) + ] + ), + ] + ) + + message_dict = PluginMessageUtils._session_message_to_dict(message) + rebuilt_message = PluginMessageUtils._build_session_message_from_dict(dict(message_dict)) + + image_component = rebuilt_message.raw_message.components[1] + voice_component = rebuilt_message.raw_message.components[2] + reply_component = rebuilt_message.raw_message.components[3] + forward_component = rebuilt_message.raw_message.components[4] + + assert isinstance(image_component, ImageComponent) + assert image_component.binary_data == b"image-bytes" + + assert isinstance(voice_component, VoiceComponent) + assert voice_component.binary_data == b"voice-bytes" + + assert isinstance(reply_component, ReplyComponent) + assert reply_component.target_message_id == "origin-1" + assert reply_component.target_message_content == "origin text" + assert reply_component.target_message_sender_id == "42" + assert reply_component.target_message_sender_nickname == "alice" + assert reply_component.target_message_sender_cardname == "Alice" + + assert isinstance(forward_component, ForwardNodeComponent) + assert isinstance(forward_component.forward_components[0].content[1], ImageComponent) + assert forward_component.forward_components[0].content[1].binary_data == b"node-image" diff --git a/pytests/test_plugin_runtime.py b/pytests/test_plugin_runtime.py index 2c703161..e3247f05 100644 --- a/pytests/test_plugin_runtime.py +++ b/pytests/test_plugin_runtime.py @@ -3,6 +3,7 @@ 验证协议层、传输层、RPC 通信链路的正确性。 """ +from pathlib import Path from types import SimpleNamespace import asyncio @@ -18,6 +19,104 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "packages", "maibot-plugin-sdk")) +def build_test_manifest( + plugin_id: str, + *, + version: str = "1.0.0", + name: str = "测试插件", + description: str = "测试插件描述", + dependencies: list[dict[str, str]] | None = None, + capabilities: list[str] | None = None, + host_min_version: str = "0.12.0", + host_max_version: str = "1.0.0", + sdk_min_version: str = "2.0.0", + sdk_max_version: str = "2.99.99", +) -> dict[str, object]: + """构造一个合法的 Manifest v2 测试样例。 + + Args: + plugin_id: 插件 ID。 + version: 插件版本。 + name: 展示名称。 + description: 插件描述。 + dependencies: 依赖声明列表。 + capabilities: 能力声明列表。 + host_min_version: Host 最低支持版本。 + host_max_version: Host 最高支持版本。 + sdk_min_version: SDK 最低支持版本。 + sdk_max_version: SDK 最高支持版本。 + + Returns: + dict[str, object]: 可直接序列化为 ``_manifest.json`` 的字典。 + """ + return { + "manifest_version": 2, + "version": version, + "name": name, + "description": description, + "author": { + "name": "tester", + "url": "https://example.com/tester", + }, + "license": "MIT", + "urls": { + "repository": f"https://example.com/{plugin_id}", + }, + "host_application": { + "min_version": host_min_version, + "max_version": host_max_version, + }, + "sdk": { + "min_version": sdk_min_version, + "max_version": sdk_max_version, + }, + "dependencies": dependencies or [], + "capabilities": capabilities or [], + "i18n": { + "default_locale": "zh-CN", + "supported_locales": ["zh-CN"], + }, + "id": plugin_id, + } + + +def build_test_manifest_model( + plugin_id: str, + *, + version: str = "1.0.0", + dependencies: list[dict[str, str]] | None = None, + capabilities: list[str] | None = None, + host_version: str = "1.0.0", + sdk_version: str = "2.0.1", +) -> object: + """构造一个已经通过校验的强类型 Manifest 测试对象。 + + Args: + plugin_id: 插件 ID。 + version: 插件版本。 + dependencies: 依赖声明列表。 + capabilities: 能力声明列表。 + host_version: 当前测试使用的 Host 版本。 + sdk_version: 当前测试使用的 SDK 版本。 + + Returns: + object: ``PluginManifest`` 实例。 + """ + from src.plugin_runtime.runner.manifest_validator import ManifestValidator + + validator = ManifestValidator(host_version=host_version, sdk_version=sdk_version) + manifest = validator.parse_manifest( + build_test_manifest( + plugin_id, + version=version, + dependencies=dependencies, + capabilities=capabilities, + ) + ) + assert manifest is not None + return manifest + + # ─── 协议层测试 ─────────────────────────────────────────── @@ -441,8 +540,8 @@ class TestSDK: def set_plugin_config(self, config): self.configs.append(config) - async def on_config_update(self, config, version): - self.updates.append((config, version, list(self.configs))) + async def on_config_update(self, scope, config, version): + self.updates.append((scope, config, version, list(self.configs))) runner = PluginRunner(host_address="dummy", session_token="token", plugin_dirs=[]) plugin = DummyPlugin() @@ -453,14 +552,60 @@ class TestSDK: message_type=MessageType.REQUEST, method="plugin.config_updated", plugin_id="demo_plugin", - payload={"config_data": {"enabled": True}, "config_version": "v2"}, + payload={ + "plugin_id": "demo_plugin", + "config_scope": "self", + "config_data": {"enabled": True}, + "config_version": "v2", + }, ) response = await runner._handle_config_updated(envelope) assert response.payload["acknowledged"] is True assert plugin.configs == [{"enabled": True}] - assert plugin.updates == [({"enabled": True}, "v2", [{"enabled": True}])] + assert plugin.updates == [("self", {"enabled": True}, "v2", [{"enabled": True}])] + + @pytest.mark.asyncio + async def test_runner_global_config_update_does_not_override_plugin_config(self): + """bot/model 广播不应覆盖插件自身配置缓存。""" + from src.plugin_runtime.protocol.envelope import Envelope, MessageType + from src.plugin_runtime.runner.runner_main import PluginRunner + + class DummyPlugin: + def __init__(self): + self.configs = [] + self.updates = [] + + def set_plugin_config(self, config): + self.configs.append(config) + + async def on_config_update(self, scope, config, version): + self.updates.append((scope, config, version, list(self.configs))) + + runner = PluginRunner(host_address="dummy", session_token="token", plugin_dirs=[]) + plugin = DummyPlugin() + runner._loader._loaded_plugins["demo_plugin"] = SimpleNamespace(instance=plugin) + plugin.set_plugin_config({"plugin_enabled": True}) + + envelope = Envelope( + request_id=1, + message_type=MessageType.REQUEST, + method="plugin.config_updated", + plugin_id="demo_plugin", + payload={ + "plugin_id": "demo_plugin", + "config_scope": "model", + "config_data": {"models": []}, + "config_version": "", + }, + ) + + response = await runner._handle_config_updated(envelope) + + assert response.payload["acknowledged"] is True + assert plugin.configs == [{"plugin_enabled": True}] + assert plugin.updates == [("model", {"models": []}, "", [{"plugin_enabled": True}])] @pytest.mark.asyncio async def test_runner_bootstraps_capabilities_before_on_load(self, monkeypatch): @@ -486,10 +631,10 @@ class TestSDK: "timeout_ms": timeout_ms, } ) - if method == "cap.request": + if method == "cap.call": bootstrap_methods = [call["method"] for call in self.calls[:-1]] assert "plugin.bootstrap" in bootstrap_methods - return SimpleNamespace(error=None, payload={"result": {"success": True}}) + return SimpleNamespace(error=None, payload={"success": True}) return SimpleNamespace(error=None, payload={"accepted": True}) async def disconnect(self): @@ -529,7 +674,102 @@ class TestSDK: await runner.run() methods = [call["method"] for call in runner._rpc_client.calls] - assert methods == ["plugin.bootstrap", "cap.request", "plugin.register_components", "runner.ready"] + assert methods == ["plugin.bootstrap", "plugin.register_components", "cap.call", "runner.ready"] + + @pytest.mark.asyncio + async def test_runner_batch_reload_merges_overlapping_reverse_dependents(self, monkeypatch): + """批量重载应只对重叠依赖闭包执行一次 unload/load。""" + from src.plugin_runtime.runner.runner_main import PluginRunner + + runner = PluginRunner(host_address="dummy", session_token="token", plugin_dirs=[]) + plugin_a_id = "test.plugin-a" + plugin_b_id = "test.plugin-b" + plugin_c_id = "test.plugin-c" + + def build_meta(plugin_id: str, dependencies: list[str]) -> SimpleNamespace: + return SimpleNamespace( + plugin_id=plugin_id, + dependencies=dependencies, + plugin_dir=f"/tmp/{plugin_id}", + version="1.0.0", + instance=SimpleNamespace(), + ) + + loaded_metas = { + plugin_a_id: build_meta(plugin_a_id, []), + plugin_b_id: build_meta(plugin_b_id, [plugin_a_id]), + plugin_c_id: build_meta(plugin_c_id, [plugin_b_id]), + } + reloaded_metas = { + plugin_id: build_meta(plugin_id, list(meta.dependencies)) + for plugin_id, meta in loaded_metas.items() + } + candidates = { + plugin_a_id: ( + "dir_plugin_a", + build_test_manifest_model(plugin_a_id), + "plugin_a/plugin.py", + ), + plugin_b_id: ( + "dir_plugin_b", + build_test_manifest_model( + plugin_b_id, + dependencies=[{"type": "plugin", "id": plugin_a_id, "version_spec": ">=1.0.0,<2.0.0"}], + ), + "plugin_b/plugin.py", + ), + plugin_c_id: ( + "dir_plugin_c", + build_test_manifest_model( + plugin_c_id, + dependencies=[{"type": "plugin", "id": plugin_b_id, "version_spec": ">=1.0.0,<2.0.0"}], + ), + "plugin_c/plugin.py", + ), + } + unloaded_plugins: list[str] = [] + activated_plugins: list[str] = [] + + monkeypatch.setattr(runner._loader, "discover_candidates", lambda plugin_dirs: (candidates, {})) + monkeypatch.setattr(runner._loader, "list_plugins", lambda: sorted(loaded_metas.keys())) + monkeypatch.setattr(runner._loader, "get_plugin", lambda plugin_id: loaded_metas.get(plugin_id)) + monkeypatch.setattr( + runner._loader, + "remove_loaded_plugin", + lambda plugin_id: loaded_metas.pop(plugin_id, None), + ) + monkeypatch.setattr(runner._loader, "purge_plugin_modules", lambda plugin_id, plugin_dir: []) + monkeypatch.setattr( + runner._loader, + "resolve_dependencies", + lambda reload_candidates, extra_available=None: (sorted(reload_candidates.keys()), {}), + ) + monkeypatch.setattr( + runner._loader, + "load_candidate", + lambda plugin_id, candidate: reloaded_metas[plugin_id], + ) + + async def fake_unload_plugin(meta, reason, purge_modules=False): + del reason, purge_modules + unloaded_plugins.append(meta.plugin_id) + loaded_metas.pop(meta.plugin_id, None) + + async def fake_activate_plugin(meta): + activated_plugins.append(meta.plugin_id) + loaded_metas[meta.plugin_id] = meta + return True + + monkeypatch.setattr(runner, "_unload_plugin", fake_unload_plugin) + monkeypatch.setattr(runner, "_activate_plugin", fake_activate_plugin) + + result = await runner._reload_plugins_by_ids([plugin_a_id, plugin_b_id], reason="manual") + + assert result.success is True + assert result.requested_plugin_ids == [plugin_a_id, plugin_b_id] + assert unloaded_plugins == [plugin_c_id, plugin_b_id, plugin_a_id] + assert activated_plugins == [plugin_a_id, plugin_b_id, plugin_c_id] + assert result.reloaded_plugins == [plugin_a_id, plugin_b_id, plugin_c_id] class TestPluginSdkUsage: @@ -712,65 +952,77 @@ class TestManifestValidator: def test_valid_manifest(self): from src.plugin_runtime.runner.manifest_validator import ManifestValidator - validator = ManifestValidator() - manifest = { - "manifest_version": 1, - "name": "test_plugin", - "version": "1.0.0", - "description": "测试插件", - "author": "test", - } + validator = ManifestValidator(host_version="1.0.0", sdk_version="2.0.1") + manifest = build_test_manifest("test.valid-plugin", capabilities=["send.text"]) assert validator.validate(manifest) is True assert len(validator.errors) == 0 + assert validator.warnings == [] def test_missing_required_fields(self): from src.plugin_runtime.runner.manifest_validator import ManifestValidator - validator = ManifestValidator() - manifest = {"manifest_version": 1} + validator = ManifestValidator(host_version="1.0.0", sdk_version="2.0.1") + manifest = {"manifest_version": 2} assert validator.validate(manifest) is False - assert len(validator.errors) >= 4 # name, version, description, author + assert len(validator.errors) >= 6 + assert any("缺少必需字段" in error for error in validator.errors) def test_unsupported_manifest_version(self): from src.plugin_runtime.runner.manifest_validator import ManifestValidator - validator = ManifestValidator() - manifest = { - "manifest_version": 999, - "name": "test", - "version": "1.0", - "description": "d", - "author": "a", - } + validator = ManifestValidator(host_version="1.0.0", sdk_version="2.0.1") + manifest = build_test_manifest("test.invalid-version") + manifest["manifest_version"] = 999 assert validator.validate(manifest) is False assert any("manifest_version" in e for e in validator.errors) def test_host_version_compatibility(self): from src.plugin_runtime.runner.manifest_validator import ManifestValidator - validator = ManifestValidator(host_version="0.8.5") - manifest = { - "name": "test", - "version": "1.0", - "description": "d", - "author": "a", - "host_application": {"min_version": "0.9.0"}, - } + validator = ManifestValidator(host_version="0.8.5", sdk_version="2.0.1") + manifest = build_test_manifest( + "test.host-check", + host_min_version="0.9.0", + host_max_version="1.0.0", + ) assert validator.validate(manifest) is False assert any("Host 版本不兼容" in e for e in validator.errors) - def test_recommended_fields_warning(self): + def test_sdk_version_compatibility(self): from src.plugin_runtime.runner.manifest_validator import ManifestValidator - validator = ManifestValidator() - manifest = { - "name": "test", - "version": "1.0", - "description": "d", - "author": "a", - } - validator.validate(manifest) - assert len(validator.warnings) >= 3 # license, keywords, categories + validator = ManifestValidator(host_version="1.0.0", sdk_version="1.9.9") + manifest = build_test_manifest("test.sdk-check") + assert validator.validate(manifest) is False + assert any("SDK 版本不兼容" in e for e in validator.errors) + + def test_extra_fields_are_rejected(self): + from src.plugin_runtime.runner.manifest_validator import ManifestValidator + + validator = ManifestValidator(host_version="1.0.0", sdk_version="2.0.1") + manifest = build_test_manifest("test.extra-field") + manifest["unexpected"] = True + + assert validator.validate(manifest) is False + assert any("存在未声明字段" in error for error in validator.errors) + + def test_python_package_conflict_rejects_manifest(self): + from src.plugin_runtime.runner.manifest_validator import ManifestValidator + + validator = ManifestValidator(host_version="1.0.0", sdk_version="2.0.1") + manifest = build_test_manifest( + "test.numpy-conflict", + dependencies=[ + { + "type": "python_package", + "name": "numpy", + "version_spec": ">=999.0.0", + } + ], + ) + + assert validator.validate(manifest) is False + assert any("Python 包依赖冲突" in error for error in validator.errors) class TestVersionComparator: @@ -812,59 +1064,83 @@ class TestDependencyResolution: loader = PluginLoader() candidates = { - "core": ("dir_core", {"name": "core", "version": "1.0", "description": "d", "author": "a"}, "plugin.py"), - "auth": ( - "dir_auth", - {"name": "auth", "version": "1.0", "description": "d", "author": "a", "dependencies": ["core"]}, + "test.core": ( + "dir_core", + build_test_manifest_model("test.core"), "plugin.py", ), - "api": ( + "test.auth": ( + "dir_auth", + build_test_manifest_model( + "test.auth", + dependencies=[ + {"type": "plugin", "id": "test.core", "version_spec": ">=1.0.0,<2.0.0"}, + ], + ), + "plugin.py", + ), + "test.api": ( "dir_api", - {"name": "api", "version": "1.0", "description": "d", "author": "a", "dependencies": ["core", "auth"]}, + build_test_manifest_model( + "test.api", + dependencies=[ + {"type": "plugin", "id": "test.core", "version_spec": ">=1.0.0,<2.0.0"}, + {"type": "plugin", "id": "test.auth", "version_spec": ">=1.0.0,<2.0.0"}, + ], + ), "plugin.py", ), } order, failed = loader._resolve_dependencies(candidates) assert len(failed) == 0 - assert order.index("core") < order.index("auth") - assert order.index("auth") < order.index("api") + assert order.index("test.core") < order.index("test.auth") + assert order.index("test.auth") < order.index("test.api") def test_missing_dependency(self): from src.plugin_runtime.runner.plugin_loader import PluginLoader loader = PluginLoader() candidates = { - "plugin_a": ( + "test.plugin-a": ( "dir_a", - { - "name": "plugin_a", - "version": "1.0", - "description": "d", - "author": "a", - "dependencies": ["nonexistent"], - }, + build_test_manifest_model( + "test.plugin-a", + dependencies=[ + {"type": "plugin", "id": "test.nonexistent", "version_spec": ">=1.0.0,<2.0.0"}, + ], + ), "plugin.py", ), } order, failed = loader._resolve_dependencies(candidates) - assert "plugin_a" in failed - assert "缺少依赖" in failed["plugin_a"] + assert "test.plugin-a" in failed + assert "依赖未满足" in failed["test.plugin-a"] def test_circular_dependency(self): from src.plugin_runtime.runner.plugin_loader import PluginLoader loader = PluginLoader() candidates = { - "a": ( + "test.a": ( "dir_a", - {"name": "a", "version": "1.0", "description": "d", "author": "x", "dependencies": ["b"]}, + build_test_manifest_model( + "test.a", + dependencies=[ + {"type": "plugin", "id": "test.b", "version_spec": ">=1.0.0,<2.0.0"}, + ], + ), "p.py", ), - "b": ( + "test.b": ( "dir_b", - {"name": "b", "version": "1.0", "description": "d", "author": "x", "dependencies": ["a"]}, + build_test_manifest_model( + "test.b", + dependencies=[ + {"type": "plugin", "id": "test.a", "version_spec": ">=1.0.0,<2.0.0"}, + ], + ), "p.py", ), } @@ -882,12 +1158,11 @@ class TestDependencyResolution: (plugin_dir / "_manifest.json").write_text( json.dumps( - { - "name": "grok_search_plugin", - "version": "1.0.0", - "description": "demo", - "author": "tester", - } + build_test_manifest( + "test.grok-search-plugin", + name="grok_search_plugin", + description="demo", + ) ), encoding="utf-8", ) @@ -907,14 +1182,130 @@ class TestDependencyResolution: loader = PluginLoader() loaded = loader.discover_and_load([str(plugin_root)]) - assert [meta.plugin_id for meta in loaded] == ["grok_search_plugin"] + assert [meta.plugin_id for meta in loaded] == ["test.grok-search-plugin"] assert loader.failed_plugins == {} assert loaded[0].instance.answer() == 42 + def test_loader_requires_sdk_plugin_to_override_on_config_update(self, tmp_path): + from src.plugin_runtime.runner.plugin_loader import PluginLoader + + plugin_root = tmp_path / "plugins" + plugin_root.mkdir() + plugin_dir = plugin_root / "demo_plugin" + plugin_dir.mkdir() + + (plugin_dir / "_manifest.json").write_text( + json.dumps( + build_test_manifest( + "test.demo-plugin", + name="demo_plugin", + description="demo", + ) + ), + encoding="utf-8", + ) + (plugin_dir / "plugin.py").write_text( + "from maibot_sdk import MaiBotPlugin\n\n" + "class DemoPlugin(MaiBotPlugin):\n" + " async def on_load(self):\n" + " pass\n\n" + " async def on_unload(self):\n" + " pass\n\n" + "def create_plugin():\n" + " return DemoPlugin()\n", + encoding="utf-8", + ) + + loader = PluginLoader() + loaded = loader.discover_and_load([str(plugin_root)]) + + assert loaded == [] + assert "test.demo-plugin" in loader.failed_plugins + assert "on_config_update" in loader.failed_plugins["test.demo-plugin"] + + def test_loader_requires_sdk_plugin_to_override_on_load(self, tmp_path): + from src.plugin_runtime.runner.plugin_loader import PluginLoader + + plugin_root = tmp_path / "plugins" + plugin_root.mkdir() + plugin_dir = plugin_root / "demo_plugin" + plugin_dir.mkdir() + + (plugin_dir / "_manifest.json").write_text( + json.dumps( + build_test_manifest( + "test.demo-plugin", + name="demo_plugin", + description="demo", + ) + ), + encoding="utf-8", + ) + (plugin_dir / "plugin.py").write_text( + "from maibot_sdk import MaiBotPlugin\n\n" + "class DemoPlugin(MaiBotPlugin):\n" + " async def on_unload(self):\n" + " pass\n\n" + " async def on_config_update(self, scope, config_data, version):\n" + " pass\n\n" + "def create_plugin():\n" + " return DemoPlugin()\n", + encoding="utf-8", + ) + + loader = PluginLoader() + loaded = loader.discover_and_load([str(plugin_root)]) + + assert loaded == [] + assert "test.demo-plugin" in loader.failed_plugins + assert "on_load" in loader.failed_plugins["test.demo-plugin"] + + def test_loader_requires_sdk_plugin_to_override_on_unload(self, tmp_path): + from src.plugin_runtime.runner.plugin_loader import PluginLoader + + plugin_root = tmp_path / "plugins" + plugin_root.mkdir() + plugin_dir = plugin_root / "demo_plugin" + plugin_dir.mkdir() + + (plugin_dir / "_manifest.json").write_text( + json.dumps( + build_test_manifest( + "test.demo-plugin", + name="demo_plugin", + description="demo", + ) + ), + encoding="utf-8", + ) + (plugin_dir / "plugin.py").write_text( + "from maibot_sdk import MaiBotPlugin\n\n" + "class DemoPlugin(MaiBotPlugin):\n" + " async def on_load(self):\n" + " pass\n\n" + " async def on_config_update(self, scope, config_data, version):\n" + " pass\n\n" + "def create_plugin():\n" + " return DemoPlugin()\n", + encoding="utf-8", + ) + + loader = PluginLoader() + loaded = loader.discover_and_load([str(plugin_root)]) + + assert loaded == [] + assert "test.demo-plugin" in loader.failed_plugins + assert "on_unload" in loader.failed_plugins["test.demo-plugin"] + def test_isolate_sys_path_preserves_plugin_dirs(self): + import builtins + import importlib + from src.plugin_runtime.runner import runner_main plugin_root = os.path.normpath("/tmp/maibot-plugin-root") + original_import = builtins.__import__ + original_import_module = importlib.import_module original_path = list(sys.path) original_meta_path = list(sys.meta_path) @@ -926,9 +1317,155 @@ class TestDependencyResolution: assert plugin_root in sys.path finally: + builtins.__import__ = original_import + importlib.import_module = original_import_module sys.path[:] = original_path sys.meta_path[:] = original_meta_path + def test_isolate_sys_path_blocks_disallowed_src_imports(self): + import builtins + import importlib + + from src.plugin_runtime.runner import runner_main + + original_import = builtins.__import__ + original_import_module = importlib.import_module + original_path = list(sys.path) + original_meta_path = list(sys.meta_path) + sys.modules.pop("src.forbidden_demo", None) + + try: + runner_main._isolate_sys_path([]) + plugin_globals = { + "__name__": "_maibot_plugin_demo", + "__package__": "_maibot_plugin_demo", + "importlib": importlib, + } + + with pytest.raises(ImportError, match="不允许导入主程序模块"): + exec('importlib.import_module("src.forbidden_demo")', plugin_globals) + finally: + builtins.__import__ = original_import + importlib.import_module = original_import_module + sys.path[:] = original_path + sys.meta_path[:] = original_meta_path + sys.modules.pop("src.forbidden_demo", None) + + def test_isolate_sys_path_blocks_preloaded_runtime_modules(self): + import builtins + import importlib + + from src.plugin_runtime.runner import runner_main + + original_import = builtins.__import__ + original_import_module = importlib.import_module + original_path = list(sys.path) + original_meta_path = list(sys.meta_path) + + try: + runner_main._isolate_sys_path([]) + plugin_globals = { + "__name__": "_maibot_plugin_demo", + "__package__": "_maibot_plugin_demo", + "importlib": importlib, + } + + with pytest.raises(ImportError, match="rpc_client"): + exec('importlib.import_module("src.plugin_runtime.runner.rpc_client")', plugin_globals) + finally: + builtins.__import__ = original_import + importlib.import_module = original_import_module + sys.path[:] = original_path + sys.meta_path[:] = original_meta_path + + def test_isolate_sys_path_keeps_legacy_logger_import_available(self): + import builtins + import importlib + + from src.plugin_runtime.runner import runner_main + + original_import = builtins.__import__ + original_import_module = importlib.import_module + original_path = list(sys.path) + original_meta_path = list(sys.meta_path) + + try: + runner_main._isolate_sys_path([]) + plugin_globals = { + "__name__": "_maibot_plugin_demo", + "__package__": "_maibot_plugin_demo", + "importlib": importlib, + } + + exec('logger_module = importlib.import_module("src.common.logger")', plugin_globals) + logger_module = plugin_globals["logger_module"] + assert callable(logger_module.get_logger) + finally: + builtins.__import__ = original_import + importlib.import_module = original_import_module + sys.path[:] = original_path + sys.meta_path[:] = original_meta_path + + def test_isolate_sys_path_keeps_runtime_imports_working(self): + import builtins + import importlib + + from src.plugin_runtime.runner import runner_main + + original_import = builtins.__import__ + original_import_module = importlib.import_module + original_path = list(sys.path) + original_meta_path = list(sys.meta_path) + + try: + runner_main._isolate_sys_path([]) + + uds_module = importlib.import_module("src.plugin_runtime.transport.uds") + assert hasattr(uds_module, "UDSTransportClient") + finally: + builtins.__import__ = original_import + importlib.import_module = original_import_module + sys.path[:] = original_path + sys.meta_path[:] = original_meta_path + + @pytest.mark.asyncio + async def test_async_main_removes_sensitive_runtime_env_vars(self, monkeypatch): + from src.plugin_runtime.runner import runner_main + + captured = {} + + class FakeRunner: + def __init__( + self, + host_address: str, + session_token: str, + plugin_dirs: list[str], + external_available_plugins: dict[str, str] | None = None, + ) -> None: + captured["host_address"] = host_address + captured["session_token"] = session_token + captured["plugin_dirs"] = plugin_dirs + captured["external_available_plugins"] = external_available_plugins or {} + + async def run(self) -> None: + assert os.environ.get(runner_main.ENV_IPC_ADDRESS) is None + assert os.environ.get(runner_main.ENV_SESSION_TOKEN) is None + + monkeypatch.setenv(runner_main.ENV_IPC_ADDRESS, "tcp://127.0.0.1:9999") + monkeypatch.setenv(runner_main.ENV_SESSION_TOKEN, "secret-token") + monkeypatch.setenv(runner_main.ENV_PLUGIN_DIRS, "/tmp/plugins") + monkeypatch.setenv(runner_main.ENV_EXTERNAL_PLUGIN_IDS, '{"demo.plugin":"1.0.0"}') + monkeypatch.setattr(runner_main, "_install_shutdown_signal_handlers", lambda callback: None) + monkeypatch.setattr(runner_main, "_isolate_sys_path", lambda plugin_dirs: None) + monkeypatch.setattr(runner_main, "PluginRunner", FakeRunner) + + await runner_main._async_main() + + assert captured["host_address"] == "tcp://127.0.0.1:9999" + assert captured["session_token"] == "secret-token" + assert captured["plugin_dirs"] == ["/tmp/plugins"] + assert captured["external_available_plugins"] == {"demo.plugin": "1.0.0"} + # ─── Host-side ComponentRegistry 测试 ────────────────────── @@ -973,6 +1510,30 @@ class TestComponentRegistry: assert stats["command"] == 1 assert stats["tool"] == 1 + def test_register_command_with_invalid_regex_only_warns(self, monkeypatch): + from src.plugin_runtime.host.component_registry import ComponentRegistry + + reg = ComponentRegistry() + warnings: list[str] = [] + monkeypatch.setattr( + "src.plugin_runtime.host.component_registry.logger.warning", + lambda message: warnings.append(str(message)), + ) + + success = reg.register_component( + "broken", + "command", + "plugin_a", + { + "command_pattern": "[", + }, + ) + + assert success is True + assert reg.get_component("plugin_a.broken") is not None + assert warnings + assert "plugin_a.broken" in warnings[0] + def test_query_by_type(self): from src.plugin_runtime.host.component_registry import ComponentRegistry @@ -1664,6 +2225,67 @@ class TestWorkflowExecutor: class TestRPCServer: """RPC Server 代际保护测试""" + @pytest.mark.asyncio + async def test_reject_second_active_runner_connection(self): + from src.plugin_runtime.host.rpc_server import RPCServer + from src.plugin_runtime.protocol.codec import MsgPackCodec + from src.plugin_runtime.protocol.envelope import Envelope, HelloPayload, HelloResponsePayload, MessageType + + class DummyTransport: + async def start(self, handler): + return None + + async def stop(self): + return None + + def get_address(self): + return "dummy" + + class FakeConnection: + def __init__(self, incoming_frames: list[bytes]): + self._incoming_frames = list(incoming_frames) + self.sent_frames: list[bytes] = [] + self.is_closed = False + + async def recv_frame(self): + return self._incoming_frames.pop(0) + + async def send_frame(self, data): + self.sent_frames.append(data) + + async def close(self): + self.is_closed = True + + codec = MsgPackCodec() + server = RPCServer(transport=DummyTransport(), session_token="session-token") + active_conn = SimpleNamespace(is_closed=False) + server._connection = active_conn + + hello = HelloPayload( + runner_id="runner-b", + sdk_version="1.0.0", + session_token="session-token", + ) + envelope = Envelope( + request_id=1, + message_type=MessageType.REQUEST, + method="runner.hello", + payload=hello.model_dump(), + ) + incoming_conn = FakeConnection([codec.encode_envelope(envelope)]) + + await server._handle_connection(incoming_conn) + + assert incoming_conn.is_closed is True + assert server._connection is active_conn + assert server.last_handshake_rejection_reason == "已有活跃 Runner 连接,拒绝新的握手" + assert len(incoming_conn.sent_frames) == 1 + + response = codec.decode_envelope(incoming_conn.sent_frames[0]) + response_payload = HelloResponsePayload.model_validate(response.payload) + assert response_payload.accepted is False + assert response_payload.reason == "已有活跃 Runner 连接,拒绝新的握手" + def test_ignore_stale_generation_response(self): from src.plugin_runtime.host.rpc_server import RPCServer from src.plugin_runtime.protocol.envelope import Envelope, MessageType @@ -2012,6 +2634,39 @@ class TestSupervisor: assert supervisor.component_registry.get_component("plugin_a.handler") is not None assert supervisor.component_registry.get_component("plugin_a.obsolete") is None + @pytest.mark.asyncio + async def test_reload_plugins_uses_batch_rpc_for_multiple_roots(self): + from src.plugin_runtime.host.supervisor import PluginSupervisor + from src.plugin_runtime.protocol.envelope import ReloadPluginsResultPayload + + supervisor = PluginSupervisor(plugin_dirs=[]) + sent_requests: list[tuple[str, dict[str, object], int]] = [] + + class FakeRPCServer: + async def send_request(self, method, payload, timeout_ms=5000, **kwargs): + del kwargs + sent_requests.append((method, payload, timeout_ms)) + return SimpleNamespace( + payload=ReloadPluginsResultPayload( + success=True, + requested_plugin_ids=["plugin_a", "plugin_b"], + reloaded_plugins=["plugin_a", "plugin_b", "plugin_c"], + unloaded_plugins=["plugin_c", "plugin_b", "plugin_a"], + ).model_dump() + ) + + supervisor._rpc_server = FakeRPCServer() + + reloaded = await supervisor.reload_plugins(["plugin_a", "plugin_b", "plugin_a"], reason="manual") + + assert reloaded is True + assert len(sent_requests) == 1 + method, payload, timeout_ms = sent_requests[0] + assert method == "plugin.reload_batch" + assert payload["plugin_ids"] == ["plugin_a", "plugin_b"] + assert payload["reason"] == "manual" + assert timeout_ms >= 10000 + @pytest.mark.asyncio async def test_reload_rolls_back_when_runner_ready_not_received(self, monkeypatch): from src.plugin_runtime.host.supervisor import PluginSupervisor @@ -2152,8 +2807,11 @@ class TestIntegration: self.supervisors = [FakeSupervisor("plugin_a"), FakeSupervisor("plugin_b")] monkeypatch.setattr(integration_module, "get_plugin_runtime_manager", lambda: FakeManager()) + manager = integration_module.PluginRuntimeManager() + manager._builtin_supervisor = FakeSupervisor("plugin_a") + manager._third_party_supervisor = FakeSupervisor("plugin_b") - result = await integration_module.PluginRuntimeManager._cap_component_enable( + result = await manager._cap_component_enable( "plugin_a", "component.enable", {"name": "shared", "component_type": "tool", "scope": "global", "stream_id": ""}, @@ -2182,8 +2840,10 @@ class TestIntegration: self.supervisors = [FakeSupervisor()] monkeypatch.setattr(integration_module, "get_plugin_runtime_manager", lambda: FakeManager()) + manager = integration_module.PluginRuntimeManager() + manager._builtin_supervisor = FakeSupervisor() - result = await integration_module.PluginRuntimeManager._cap_component_disable( + result = await manager._cap_component_disable( "plugin_a", "component.disable", {"name": "plugin_a.handler", "component_type": "tool", "scope": "stream", "stream_id": "s1"}, @@ -2197,6 +2857,8 @@ class TestIntegration: from src.plugin_runtime import integration as integration_module instances = [] + builtin_dir = Path("builtin") + thirdparty_dir = Path("thirdparty") class FakeCapabilityService: def register_capability(self, name, impl): @@ -2204,11 +2866,21 @@ class TestIntegration: class FakeSupervisor: def __init__(self, plugin_dirs=None, socket_path=None): - self.plugin_dirs = plugin_dirs or [] + self._plugin_dirs = plugin_dirs or [] self.capability_service = FakeCapabilityService() + self.external_plugin_versions = {} self.stopped = False instances.append(self) + def set_external_available_plugins(self, plugin_versions): + self.external_plugin_versions = dict(plugin_versions) + + def get_loaded_plugin_ids(self): + return [] + + def get_loaded_plugin_versions(self): + return {} + async def start(self): if len(instances) == 2 and self is instances[1]: raise RuntimeError("boom") @@ -2217,10 +2889,10 @@ class TestIntegration: self.stopped = True monkeypatch.setattr( - integration_module.PluginRuntimeManager, "_get_builtin_plugin_dirs", staticmethod(lambda: ["builtin"]) + integration_module.PluginRuntimeManager, "_get_builtin_plugin_dirs", staticmethod(lambda: [builtin_dir]) ) monkeypatch.setattr( - integration_module.PluginRuntimeManager, "_get_thirdparty_plugin_dirs", staticmethod(lambda: ["thirdparty"]) + integration_module.PluginRuntimeManager, "_get_third_party_plugin_dirs", staticmethod(lambda: [thirdparty_dir]) ) import src.plugin_runtime.host.supervisor as supervisor_module @@ -2238,6 +2910,7 @@ class TestIntegration: async def test_handle_plugin_source_changes_only_reload_matching_supervisor(self, monkeypatch, tmp_path): from src.config.file_watcher import FileChange from src.plugin_runtime import integration as integration_module + import json builtin_root = tmp_path / "src" / "plugins" / "built_in" thirdparty_root = tmp_path / "plugins" @@ -2247,6 +2920,10 @@ class TestIntegration: beta_dir.mkdir(parents=True) (alpha_dir / "config.toml").write_text("enabled = true\n", encoding="utf-8") (beta_dir / "config.toml").write_text("enabled = false\n", encoding="utf-8") + (alpha_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8") + (beta_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8") + (alpha_dir / "_manifest.json").write_text(json.dumps(build_test_manifest("test.alpha")), encoding="utf-8") + (beta_dir / "_manifest.json").write_text(json.dumps(build_test_manifest("test.beta")), encoding="utf-8") monkeypatch.chdir(tmp_path) @@ -2257,8 +2934,14 @@ class TestIntegration: self.reload_reasons = [] self.config_updates = [] - async def reload_plugins(self, reason="manual"): - self.reload_reasons.append(reason) + def get_loaded_plugin_ids(self): + return sorted(self._registered_plugins.keys()) + + def get_loaded_plugin_versions(self): + return {plugin_id: "1.0.0" for plugin_id in self._registered_plugins} + + async def reload_plugins(self, plugin_ids=None, reason="manual", external_available_plugins=None): + self.reload_reasons.append((plugin_ids, reason, external_available_plugins or {})) async def notify_plugin_config_updated(self, plugin_id, config_data, config_version=""): self.config_updates.append((plugin_id, config_data, config_version)) @@ -2266,8 +2949,8 @@ class TestIntegration: manager = integration_module.PluginRuntimeManager() manager._started = True - manager._builtin_supervisor = FakeSupervisor([builtin_root], {"alpha": object()}) - manager._third_party_supervisor = FakeSupervisor([thirdparty_root], {"beta": object()}) + manager._builtin_supervisor = FakeSupervisor([builtin_root], {"test.alpha": object()}) + manager._third_party_supervisor = FakeSupervisor([thirdparty_root], {"test.beta": object()}) changes = [ FileChange(change_type=1, path=beta_dir / "plugin.py"), @@ -2283,15 +2966,71 @@ class TestIntegration: await manager._handle_plugin_source_changes(changes) assert manager._builtin_supervisor.reload_reasons == [] - assert manager._third_party_supervisor.reload_reasons == ["file_watcher"] + assert manager._third_party_supervisor.reload_reasons == [ + (["test.beta"], "file_watcher", {"test.alpha": "1.0.0"}) + ] assert manager._builtin_supervisor.config_updates == [] assert manager._third_party_supervisor.config_updates == [] assert refresh_calls == [True] + @pytest.mark.asyncio + async def test_reload_plugins_globally_warns_and_skips_cross_supervisor_dependents(self, monkeypatch): + from src.plugin_runtime import integration as integration_module + + class FakeRegistration: + def __init__(self, dependencies): + self.dependencies = dependencies + + class FakeSupervisor: + def __init__(self, registrations): + self._registered_plugins = registrations + self.reload_calls = [] + + def get_loaded_plugin_ids(self): + return sorted(self._registered_plugins.keys()) + + def get_loaded_plugin_versions(self): + return {plugin_id: "1.0.0" for plugin_id in self._registered_plugins} + + async def reload_plugins(self, plugin_ids=None, reason="manual", external_available_plugins=None): + self.reload_calls.append((plugin_ids, reason, dict(sorted((external_available_plugins or {}).items())))) + return True + + builtin_supervisor = FakeSupervisor({"test.alpha": FakeRegistration([])}) + third_party_supervisor = FakeSupervisor( + { + "test.beta": FakeRegistration(["test.alpha"]), + "test.gamma": FakeRegistration(["test.beta"]), + } + ) + + manager = integration_module.PluginRuntimeManager() + manager._builtin_supervisor = builtin_supervisor + manager._third_party_supervisor = third_party_supervisor + warning_messages = [] + + monkeypatch.setattr( + integration_module.logger, + "warning", + lambda message: warning_messages.append(message), + ) + + reloaded = await manager.reload_plugins_globally(["test.alpha"], reason="manual") + + assert reloaded is True + assert builtin_supervisor.reload_calls == [ + (["test.alpha"], "manual", {"test.beta": "1.0.0", "test.gamma": "1.0.0"}) + ] + assert third_party_supervisor.reload_calls == [] + assert len(warning_messages) == 1 + assert "test.beta, test.gamma" in warning_messages[0] + assert "跨 Supervisor API 调用仍然可用" in warning_messages[0] + @pytest.mark.asyncio async def test_handle_plugin_config_changes_only_notify_target_plugin(self, monkeypatch, tmp_path): from src.plugin_runtime import integration as integration_module from src.config.file_watcher import FileChange + import json builtin_root = tmp_path / "src" / "plugins" / "built_in" thirdparty_root = tmp_path / "plugins" @@ -2301,6 +3040,10 @@ class TestIntegration: beta_dir.mkdir(parents=True) (alpha_dir / "config.toml").write_text("enabled = true\n", encoding="utf-8") (beta_dir / "config.toml").write_text("enabled = false\n", encoding="utf-8") + (alpha_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8") + (beta_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8") + (alpha_dir / "_manifest.json").write_text(json.dumps(build_test_manifest("test.alpha")), encoding="utf-8") + (beta_dir / "_manifest.json").write_text(json.dumps(build_test_manifest("test.beta")), encoding="utf-8") monkeypatch.chdir(tmp_path) @@ -2310,25 +3053,97 @@ class TestIntegration: self._registered_plugins = {plugin_id: object() for plugin_id in plugins} self.config_updates = [] - async def notify_plugin_config_updated(self, plugin_id, config_data, config_version=""): - self.config_updates.append((plugin_id, config_data, config_version)) + async def notify_plugin_config_updated( + self, + plugin_id, + config_data, + config_version="", + config_scope="self", + ): + self.config_updates.append((plugin_id, config_data, config_version, config_scope)) return True manager = integration_module.PluginRuntimeManager() manager._started = True - manager._builtin_supervisor = FakeSupervisor([builtin_root], ["alpha"]) - manager._third_party_supervisor = FakeSupervisor([thirdparty_root], ["beta"]) + manager._builtin_supervisor = FakeSupervisor([builtin_root], ["test.alpha"]) + manager._third_party_supervisor = FakeSupervisor([thirdparty_root], ["test.beta"]) await manager._handle_plugin_config_changes( - "alpha", + "test.alpha", [FileChange(change_type=1, path=alpha_dir / "config.toml")], ) - assert manager._builtin_supervisor.config_updates == [("alpha", {"enabled": True}, "")] + assert manager._builtin_supervisor.config_updates == [("test.alpha", {"enabled": True}, "", "self")] assert manager._third_party_supervisor.config_updates == [] + @pytest.mark.asyncio + async def test_handle_main_config_reload_only_notifies_subscribers(self, monkeypatch): + from src.plugin_runtime import integration as integration_module + + class FakeRegistration: + def __init__(self, subscriptions): + self.config_reload_subscriptions = subscriptions + + class FakeSupervisor: + def __init__(self, registrations): + self._registered_plugins = registrations + self.config_updates = [] + + def get_config_reload_subscribers(self, scope): + matched_plugins = [] + for plugin_id, registration in self._registered_plugins.items(): + if scope in registration.config_reload_subscriptions: + matched_plugins.append(plugin_id) + return matched_plugins + + async def notify_plugin_config_updated( + self, + plugin_id, + config_data, + config_version="", + config_scope="self", + ): + self.config_updates.append((plugin_id, config_data, config_version, config_scope)) + return True + + fake_global = SimpleNamespace(plugin_runtime=SimpleNamespace(enabled=True)) + monkeypatch.setattr( + integration_module.config_manager, + "get_global_config", + lambda: SimpleNamespace(model_dump=lambda: {"bot": {"name": "MaiBot"}}, plugin_runtime=fake_global.plugin_runtime), + ) + monkeypatch.setattr( + integration_module.config_manager, + "get_model_config", + lambda: SimpleNamespace(model_dump=lambda: {"models": [{"name": "demo"}]}), + ) + + manager = integration_module.PluginRuntimeManager() + manager._started = True + manager._builtin_supervisor = FakeSupervisor( + { + "test.alpha": FakeRegistration(["bot"]), + "test.beta": FakeRegistration([]), + } + ) + manager._third_party_supervisor = FakeSupervisor( + { + "test.gamma": FakeRegistration(["model"]), + } + ) + + await manager._handle_main_config_reload(["bot", "model"]) + + assert manager._builtin_supervisor.config_updates == [ + ("test.alpha", {"bot": {"name": "MaiBot"}}, "", "bot") + ] + assert manager._third_party_supervisor.config_updates == [ + ("test.gamma", {"models": [{"name": "demo"}]}, "", "model") + ] + def test_refresh_plugin_config_watch_subscriptions_registers_per_plugin(self, tmp_path): from src.plugin_runtime import integration as integration_module + import json builtin_root = tmp_path / "src" / "plugins" / "built_in" thirdparty_root = tmp_path / "plugins" @@ -2336,6 +3151,10 @@ class TestIntegration: beta_dir = thirdparty_root / "beta" alpha_dir.mkdir(parents=True) beta_dir.mkdir(parents=True) + (alpha_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8") + (beta_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8") + (alpha_dir / "_manifest.json").write_text(json.dumps(build_test_manifest("test.alpha")), encoding="utf-8") + (beta_dir / "_manifest.json").write_text(json.dumps(build_test_manifest("test.beta")), encoding="utf-8") class FakeWatcher: def __init__(self): @@ -2358,12 +3177,12 @@ class TestIntegration: manager = integration_module.PluginRuntimeManager() manager._plugin_file_watcher = FakeWatcher() - manager._builtin_supervisor = FakeSupervisor([builtin_root], ["alpha"]) - manager._third_party_supervisor = FakeSupervisor([thirdparty_root], ["beta"]) + manager._builtin_supervisor = FakeSupervisor([builtin_root], ["test.alpha"]) + manager._third_party_supervisor = FakeSupervisor([thirdparty_root], ["test.beta"]) manager._refresh_plugin_config_watch_subscriptions() - assert set(manager._plugin_config_watcher_subscriptions.keys()) == {"alpha", "beta"} + assert set(manager._plugin_config_watcher_subscriptions.keys()) == {"test.alpha", "test.beta"} assert { subscription["paths"][0] for subscription in manager._plugin_file_watcher.subscriptions } == {alpha_dir / "config.toml", beta_dir / "config.toml"} @@ -2372,55 +3191,30 @@ class TestIntegration: async def test_component_reload_plugin_returns_failure_when_reload_rolls_back(self, monkeypatch): from src.plugin_runtime import integration as integration_module - class FakeSupervisor: - def __init__(self): - self._registered_plugins = {"alpha": object()} + manager = integration_module.PluginRuntimeManager() + monkeypatch.setattr(manager, "reload_plugins_globally", lambda plugin_ids, reason="manual": asyncio.sleep(0, False)) - async def reload_plugins(self, reason="manual"): - return False - - class FakeManager: - def __init__(self): - self.supervisors = [FakeSupervisor()] - - monkeypatch.setattr(integration_module, "get_plugin_runtime_manager", lambda: FakeManager()) - - result = await integration_module.PluginRuntimeManager._cap_component_reload_plugin( + result = await manager._cap_component_reload_plugin( "plugin_a", "component.reload_plugin", {"plugin_name": "alpha"}, ) assert result["success"] is False - assert "已回滚" in result["error"] + assert result["error"] == "插件 alpha 热重载失败" @pytest.mark.asyncio async def test_component_load_plugin_returns_failure_when_reload_rolls_back(self, monkeypatch, tmp_path): from src.plugin_runtime import integration as integration_module - plugin_root = tmp_path / "plugins" - plugin_root.mkdir() - (plugin_root / "alpha").mkdir() + manager = integration_module.PluginRuntimeManager() + monkeypatch.setattr(manager, "load_plugin_globally", lambda plugin_id, reason="manual": asyncio.sleep(0, False)) - class FakeSupervisor: - def __init__(self): - self._registered_plugins = {} - self._plugin_dirs = [str(plugin_root)] - - async def reload_plugins(self, reason="manual"): - return False - - class FakeManager: - def __init__(self): - self.supervisors = [FakeSupervisor()] - - monkeypatch.setattr(integration_module, "get_plugin_runtime_manager", lambda: FakeManager()) - - result = await integration_module.PluginRuntimeManager._cap_component_load_plugin( + result = await manager._cap_component_load_plugin( "plugin_a", "component.load_plugin", {"plugin_name": "alpha"}, ) assert result["success"] is False - assert "已回滚" in result["error"] + assert result["error"] == "插件 alpha 热重载失败" diff --git a/pytests/test_plugin_runtime_action_bridge.py b/pytests/test_plugin_runtime_action_bridge.py new file mode 100644 index 00000000..e13dfaf3 --- /dev/null +++ b/pytests/test_plugin_runtime_action_bridge.py @@ -0,0 +1,284 @@ +"""核心组件查询层与插件运行时聚合测试。""" + +from types import SimpleNamespace +from typing import Any + +import pytest + +import src.plugin_runtime.integration as integration_module + +from src.core.types import ActionInfo, ToolInfo +from src.plugin_runtime.component_query import component_query_service +from src.plugin_runtime.host.supervisor import PluginSupervisor + + +class _FakeRuntimeManager: + """测试用插件运行时管理器。""" + + def __init__(self, supervisor: PluginSupervisor, plugin_id: str, plugin_config: dict[str, Any]) -> None: + """初始化测试用运行时管理器。 + + Args: + supervisor: 持有测试组件的监督器。 + plugin_id: 目标插件 ID。 + plugin_config: 需要返回的插件配置。 + """ + + self.supervisors = [supervisor] + self._plugin_id = plugin_id + self._plugin_config = plugin_config + + def _get_supervisor_for_plugin(self, plugin_id: str) -> PluginSupervisor | None: + """按插件 ID 返回对应监督器。 + + Args: + plugin_id: 目标插件 ID。 + + Returns: + PluginSupervisor | None: 命中时返回监督器。 + """ + + return self.supervisors[0] if plugin_id == self._plugin_id else None + + def _load_plugin_config_for_supervisor(self, supervisor: Any, plugin_id: str) -> dict[str, Any]: + """返回测试配置。 + + Args: + supervisor: 监督器实例。 + plugin_id: 目标插件 ID。 + + Returns: + dict[str, Any]: 测试配置内容。 + """ + + del supervisor + if plugin_id != self._plugin_id: + return {} + return dict(self._plugin_config) + + +def _install_runtime_manager( + monkeypatch: pytest.MonkeyPatch, + supervisor: PluginSupervisor, + plugin_id: str, + plugin_config: dict[str, Any] | None = None, +) -> None: + """为测试安装假的运行时管理器。 + + Args: + monkeypatch: pytest monkeypatch 对象。 + supervisor: 持有测试组件的监督器。 + plugin_id: 测试插件 ID。 + plugin_config: 可选的测试配置内容。 + """ + + fake_manager = _FakeRuntimeManager(supervisor, plugin_id, plugin_config or {"enabled": True}) + monkeypatch.setattr(integration_module, "get_plugin_runtime_manager", lambda: fake_manager) + + +@pytest.mark.asyncio +async def test_core_component_registry_reads_runtime_action_and_executor( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """核心查询层应直接读取运行时 Action,并返回 RPC 执行闭包。""" + + plugin_id = "runtime_action_bridge_plugin" + action_name = "runtime_action_bridge_test" + supervisor = PluginSupervisor(plugin_dirs=[]) + captured: dict[str, Any] = {} + + supervisor.component_registry.register_component( + name=action_name, + component_type="ACTION", + plugin_id=plugin_id, + metadata={ + "description": "发送一个测试回复", + "enabled": True, + "activation_type": "keyword", + "activation_probability": 0.25, + "activation_keywords": ["测试", "hello"], + "action_parameters": {"target": "目标对象"}, + "action_require": ["需要发送回复时使用"], + "associated_types": ["text"], + "parallel_action": True, + }, + ) + _install_runtime_manager(monkeypatch, supervisor, plugin_id, {"enabled": True, "mode": "test"}) + + async def fake_invoke_plugin( + method: str, + plugin_id: str, + component_name: str, + args: dict[str, Any] | None = None, + timeout_ms: int = 30000, + ) -> Any: + """模拟动作 RPC 调用。""" + + captured["method"] = method + captured["plugin_id"] = plugin_id + captured["component_name"] = component_name + captured["args"] = args or {} + captured["timeout_ms"] = timeout_ms + return SimpleNamespace(payload={"success": True, "result": (True, "runtime action executed")}) + + monkeypatch.setattr(supervisor, "invoke_plugin", fake_invoke_plugin) + + action_info = component_query_service.get_action_info(action_name) + assert isinstance(action_info, ActionInfo) + assert action_info.plugin_name == plugin_id + assert action_info.description == "发送一个测试回复" + assert action_info.activation_keywords == ["测试", "hello"] + assert action_info.random_activation_probability == 0.25 + assert action_info.parallel_action is True + assert action_name in component_query_service.get_default_actions() + assert component_query_service.get_plugin_config(plugin_id) == {"enabled": True, "mode": "test"} + + executor = component_query_service.get_action_executor(action_name) + assert executor is not None + + success, reason = await executor( + action_data={"target": "MaiBot"}, + action_reasoning="当前适合使用这个动作", + cycle_timers={"planner": 0.1}, + thinking_id="tid-1", + chat_stream=SimpleNamespace(session_id="stream-1"), + log_prefix="[test]", + shutting_down=False, + plugin_config={"enabled": True}, + ) + + assert success is True + assert reason == "runtime action executed" + assert captured["method"] == "plugin.invoke_action" + assert captured["plugin_id"] == plugin_id + assert captured["component_name"] == action_name + assert captured["args"]["stream_id"] == "stream-1" + assert captured["args"]["chat_id"] == "stream-1" + assert captured["args"]["reasoning"] == "当前适合使用这个动作" + assert captured["args"]["target"] == "MaiBot" + assert captured["args"]["action_data"] == {"target": "MaiBot"} + + +@pytest.mark.asyncio +async def test_core_component_registry_reads_runtime_command_and_executor( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """核心查询层应直接使用运行时命令匹配与执行闭包。""" + + plugin_id = "runtime_command_bridge_plugin" + command_name = "runtime_command_bridge_test" + supervisor = PluginSupervisor(plugin_dirs=[]) + captured: dict[str, Any] = {} + + supervisor.component_registry.register_component( + name=command_name, + component_type="COMMAND", + plugin_id=plugin_id, + metadata={ + "description": "测试命令", + "enabled": True, + "command_pattern": r"^/test(?:\s+.+)?$", + "aliases": ["/hello"], + "intercept_message_level": 1, + }, + ) + _install_runtime_manager(monkeypatch, supervisor, plugin_id, {"mode": "command"}) + + async def fake_invoke_plugin( + method: str, + plugin_id: str, + component_name: str, + args: dict[str, Any] | None = None, + timeout_ms: int = 30000, + ) -> Any: + """模拟命令 RPC 调用。""" + + captured["method"] = method + captured["plugin_id"] = plugin_id + captured["component_name"] = component_name + captured["args"] = args or {} + captured["timeout_ms"] = timeout_ms + return SimpleNamespace(payload={"success": True, "result": (True, "command ok", True)}) + + monkeypatch.setattr(supervisor, "invoke_plugin", fake_invoke_plugin) + + matched = component_query_service.find_command_by_text("/test hello") + assert matched is not None + command_executor, matched_groups, command_info = matched + + assert matched_groups == {} + assert command_info.plugin_name == plugin_id + assert command_info.command_pattern == r"^/test(?:\s+.+)?$" + + success, response_text, intercept = await command_executor( + message=SimpleNamespace(processed_plain_text="/test hello", session_id="stream-2"), + plugin_config={"mode": "command"}, + matched_groups=matched_groups, + ) + + assert success is True + assert response_text == "command ok" + assert intercept is True + assert captured["method"] == "plugin.invoke_command" + assert captured["plugin_id"] == plugin_id + assert captured["component_name"] == command_name + assert captured["args"]["text"] == "/test hello" + assert captured["args"]["stream_id"] == "stream-2" + assert captured["args"]["plugin_config"] == {"mode": "command"} + + +@pytest.mark.asyncio +async def test_core_component_registry_reads_runtime_tools_and_executor( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """核心查询层应直接读取运行时 Tool,并返回 RPC 执行闭包。""" + + plugin_id = "runtime_tool_bridge_plugin" + tool_name = "runtime_tool_bridge_test" + supervisor = PluginSupervisor(plugin_dirs=[]) + + supervisor.component_registry.register_component( + name=tool_name, + component_type="TOOL", + plugin_id=plugin_id, + metadata={ + "description": "测试工具", + "enabled": True, + "parameters": [ + { + "name": "query", + "param_type": "string", + "description": "查询词", + "required": True, + } + ], + }, + ) + _install_runtime_manager(monkeypatch, supervisor, plugin_id) + + async def fake_invoke_plugin( + method: str, + plugin_id: str, + component_name: str, + args: dict[str, Any] | None = None, + timeout_ms: int = 30000, + ) -> Any: + """模拟工具 RPC 调用。""" + + del timeout_ms + assert method == "plugin.invoke_tool" + assert plugin_id == "runtime_tool_bridge_plugin" + assert component_name == "runtime_tool_bridge_test" + assert args == {"query": "MaiBot"} + return SimpleNamespace(payload={"success": True, "result": {"content": "tool ok"}}) + + monkeypatch.setattr(supervisor, "invoke_plugin", fake_invoke_plugin) + + tool_info = component_query_service.get_tool_info(tool_name) + assert isinstance(tool_info, ToolInfo) + assert tool_info.tool_description == "测试工具" + assert tool_name in component_query_service.get_llm_available_tools() + + executor = component_query_service.get_tool_executor(tool_name) + assert executor is not None + assert await executor({"query": "MaiBot"}) == {"content": "tool ok"} diff --git a/pytests/test_plugin_runtime_api.py b/pytests/test_plugin_runtime_api.py new file mode 100644 index 00000000..58a8e6ba --- /dev/null +++ b/pytests/test_plugin_runtime_api.py @@ -0,0 +1,524 @@ +"""插件 API 注册与调用测试。""" + +from types import SimpleNamespace +from typing import Any, Dict, List + +import pytest + +from src.plugin_runtime.integration import PluginRuntimeManager +from src.plugin_runtime.host.supervisor import PluginSupervisor +from src.plugin_runtime.protocol.envelope import ( + ComponentDeclaration, + Envelope, + MessageType, + RegisterPluginPayload, + UnregisterPluginPayload, +) + + +def _build_manager(*supervisors: PluginSupervisor) -> PluginRuntimeManager: + """构造一个最小可用的插件运行时管理器。 + + Args: + *supervisors: 需要挂载的监督器列表。 + + Returns: + PluginRuntimeManager: 已注入监督器的运行时管理器。 + """ + + manager = PluginRuntimeManager() + if supervisors: + manager._builtin_supervisor = supervisors[0] + if len(supervisors) > 1: + manager._third_party_supervisor = supervisors[1] + return manager + + +async def _register_plugin( + supervisor: PluginSupervisor, + plugin_id: str, + components: List[Dict[str, Any]], +) -> Envelope: + """通过 Supervisor 注册测试插件。 + + Args: + supervisor: 目标监督器。 + plugin_id: 测试插件 ID。 + components: 组件声明列表。 + + Returns: + Envelope: 注册响应信封。 + """ + + payload = RegisterPluginPayload( + plugin_id=plugin_id, + plugin_version="1.0.0", + components=[ + ComponentDeclaration( + name=str(component.get("name", "") or ""), + component_type=str(component.get("component_type", "") or ""), + plugin_id=plugin_id, + metadata=component.get("metadata", {}) if isinstance(component.get("metadata"), dict) else {}, + ) + for component in components + ], + ) + return await supervisor._handle_register_plugin( + Envelope( + request_id=1, + message_type=MessageType.REQUEST, + method="plugin.register_components", + plugin_id=plugin_id, + payload=payload.model_dump(), + ) + ) + + +async def _unregister_plugin(supervisor: PluginSupervisor, plugin_id: str) -> Envelope: + """通过 Supervisor 注销测试插件。 + + Args: + supervisor: 目标监督器。 + plugin_id: 测试插件 ID。 + + Returns: + Envelope: 注销响应信封。 + """ + + payload = UnregisterPluginPayload(plugin_id=plugin_id, reason="test") + return await supervisor._handle_unregister_plugin( + Envelope( + request_id=2, + message_type=MessageType.REQUEST, + method="plugin.unregister", + plugin_id=plugin_id, + payload=payload.model_dump(), + ) + ) + + +@pytest.mark.asyncio +async def test_register_plugin_syncs_dedicated_api_registry() -> None: + """插件注册时应将 API 同步到独立注册表,而不是通用组件表。""" + + supervisor = PluginSupervisor(plugin_dirs=[]) + response = await _register_plugin( + supervisor, + "provider", + [ + { + "name": "render_html", + "component_type": "API", + "metadata": { + "description": "渲染 HTML", + "version": "1", + "public": True, + }, + } + ], + ) + + assert response.payload["accepted"] is True + assert response.payload["registered_components"] == 0 + assert response.payload["registered_apis"] == 1 + assert supervisor.api_registry.get_api("provider", "render_html") is not None + assert supervisor.component_registry.get_component("provider.render_html") is None + + unregister_response = await _unregister_plugin(supervisor, "provider") + assert unregister_response.payload["removed_apis"] == 1 + assert supervisor.api_registry.get_api("provider", "render_html") is None + + +@pytest.mark.asyncio +async def test_api_call_allows_public_api_between_plugins(monkeypatch: pytest.MonkeyPatch) -> None: + """公开 API 应允许其他插件通过 Host 转发调用。""" + + provider_supervisor = PluginSupervisor(plugin_dirs=[]) + consumer_supervisor = PluginSupervisor(plugin_dirs=[]) + await _register_plugin( + provider_supervisor, + "provider", + [ + { + "name": "render_html", + "component_type": "API", + "metadata": { + "description": "渲染 HTML", + "version": "1", + "public": True, + }, + } + ], + ) + await _register_plugin(consumer_supervisor, "consumer", []) + + captured: Dict[str, Any] = {} + + async def fake_invoke_api( + plugin_id: str, + component_name: str, + args: Dict[str, Any] | None = None, + timeout_ms: int = 30000, + ) -> Any: + """模拟 API RPC 调用。""" + + captured["plugin_id"] = plugin_id + captured["component_name"] = component_name + captured["args"] = args or {} + captured["timeout_ms"] = timeout_ms + return SimpleNamespace(error=None, payload={"success": True, "result": {"image": "ok"}}) + + monkeypatch.setattr(provider_supervisor, "invoke_api", fake_invoke_api) + + manager = _build_manager(provider_supervisor, consumer_supervisor) + result = await manager._cap_api_call( + "consumer", + "api.call", + { + "api_name": "provider.render_html", + "version": "1", + "args": {"html": "
Hello
"}, + }, + ) + + assert result == {"success": True, "result": {"image": "ok"}} + assert captured["plugin_id"] == "provider" + assert captured["component_name"] == "render_html" + assert captured["args"] == {"html": "
Hello
"} + + +@pytest.mark.asyncio +async def test_api_call_rejects_private_api_between_plugins() -> None: + """未公开的 API 默认不允许跨插件调用。""" + + provider_supervisor = PluginSupervisor(plugin_dirs=[]) + consumer_supervisor = PluginSupervisor(plugin_dirs=[]) + await _register_plugin( + provider_supervisor, + "provider", + [ + { + "name": "secret_api", + "component_type": "API", + "metadata": { + "description": "私有 API", + "version": "1", + "public": False, + }, + } + ], + ) + await _register_plugin(consumer_supervisor, "consumer", []) + + manager = _build_manager(provider_supervisor, consumer_supervisor) + result = await manager._cap_api_call( + "consumer", + "api.call", + { + "api_name": "provider.secret_api", + "args": {}, + }, + ) + + assert result["success"] is False + assert "未公开" in str(result["error"]) + + +@pytest.mark.asyncio +async def test_api_list_and_component_toggle_use_dedicated_registry() -> None: + """API 列表与组件启停应直接作用于独立 API 注册表。""" + + provider_supervisor = PluginSupervisor(plugin_dirs=[]) + consumer_supervisor = PluginSupervisor(plugin_dirs=[]) + await _register_plugin( + provider_supervisor, + "provider", + [ + { + "name": "public_api", + "component_type": "API", + "metadata": {"version": "1", "public": True}, + }, + { + "name": "private_api", + "component_type": "API", + "metadata": {"version": "1", "public": False}, + }, + ], + ) + await _register_plugin( + consumer_supervisor, + "consumer", + [ + { + "name": "self_private_api", + "component_type": "API", + "metadata": {"version": "1", "public": False}, + } + ], + ) + + manager = _build_manager(provider_supervisor, consumer_supervisor) + list_result = await manager._cap_api_list("consumer", "api.list", {}) + + assert list_result["success"] is True + api_names = {(item["plugin_id"], item["name"]) for item in list_result["apis"]} + assert ("provider", "public_api") in api_names + assert ("provider", "private_api") not in api_names + assert ("consumer", "self_private_api") in api_names + + disable_result = await manager._cap_component_disable( + "consumer", + "component.disable", + { + "name": "provider.public_api", + "component_type": "API", + "scope": "global", + "stream_id": "", + }, + ) + assert disable_result["success"] is True + assert provider_supervisor.api_registry.get_api("provider", "public_api", enabled_only=True) is None + + enable_result = await manager._cap_component_enable( + "consumer", + "component.enable", + { + "name": "provider.public_api", + "component_type": "API", + "scope": "global", + "stream_id": "", + }, + ) + assert enable_result["success"] is True + assert provider_supervisor.api_registry.get_api("provider", "public_api", enabled_only=True) is not None + + +@pytest.mark.asyncio +async def test_api_registry_supports_multiple_versions_with_distinct_handlers( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """同名 API 不同版本应可并存,并按版本路由到不同处理器。""" + + provider_supervisor = PluginSupervisor(plugin_dirs=[]) + consumer_supervisor = PluginSupervisor(plugin_dirs=[]) + await _register_plugin( + provider_supervisor, + "provider", + [ + { + "name": "render_html", + "component_type": "API", + "metadata": { + "description": "渲染 HTML v1", + "version": "1", + "public": True, + "handler_name": "handle_render_html_v1", + }, + }, + { + "name": "render_html", + "component_type": "API", + "metadata": { + "description": "渲染 HTML v2", + "version": "2", + "public": True, + "handler_name": "handle_render_html_v2", + }, + }, + ], + ) + await _register_plugin(consumer_supervisor, "consumer", []) + + captured: Dict[str, Any] = {} + + async def fake_invoke_api( + plugin_id: str, + component_name: str, + args: Dict[str, Any] | None = None, + timeout_ms: int = 30000, + ) -> Any: + """模拟多版本 API 调用。""" + + captured["plugin_id"] = plugin_id + captured["component_name"] = component_name + captured["args"] = args or {} + captured["timeout_ms"] = timeout_ms + return SimpleNamespace(error=None, payload={"success": True, "result": {"image": "ok"}}) + + monkeypatch.setattr(provider_supervisor, "invoke_api", fake_invoke_api) + manager = _build_manager(provider_supervisor, consumer_supervisor) + + ambiguous_result = await manager._cap_api_call( + "consumer", + "api.call", + { + "api_name": "provider.render_html", + "args": {"html": "
Hello
"}, + }, + ) + assert ambiguous_result["success"] is False + assert "多个版本" in str(ambiguous_result["error"]) + + disable_ambiguous_result = await manager._cap_component_disable( + "consumer", + "component.disable", + { + "name": "provider.render_html", + "component_type": "API", + "scope": "global", + "stream_id": "", + }, + ) + assert disable_ambiguous_result["success"] is False + assert "多个版本" in str(disable_ambiguous_result["error"]) + + disable_v1_result = await manager._cap_component_disable( + "consumer", + "component.disable", + { + "name": "provider.render_html", + "component_type": "API", + "scope": "global", + "stream_id": "", + "version": "1", + }, + ) + assert disable_v1_result["success"] is True + assert provider_supervisor.api_registry.get_api("provider", "render_html", version="1", enabled_only=True) is None + assert provider_supervisor.api_registry.get_api("provider", "render_html", version="2", enabled_only=True) is not None + + result = await manager._cap_api_call( + "consumer", + "api.call", + { + "api_name": "provider.render_html", + "version": "2", + "args": {"html": "
Hello
"}, + }, + ) + + assert result == {"success": True, "result": {"image": "ok"}} + assert captured["plugin_id"] == "provider" + assert captured["component_name"] == "handle_render_html_v2" + assert captured["args"] == {"html": "
Hello
"} + + +@pytest.mark.asyncio +async def test_api_replace_dynamic_can_offline_removed_entries( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """动态 API 替换后,被移除的 API 应返回明确下线错误。""" + + supervisor = PluginSupervisor(plugin_dirs=[]) + await _register_plugin(supervisor, "provider", []) + manager = _build_manager(supervisor) + + captured: Dict[str, Any] = {} + + async def fake_invoke_api( + plugin_id: str, + component_name: str, + args: Dict[str, Any] | None = None, + timeout_ms: int = 30000, + ) -> Any: + """模拟动态 API 调用。""" + + captured["plugin_id"] = plugin_id + captured["component_name"] = component_name + captured["args"] = args or {} + captured["timeout_ms"] = timeout_ms + return SimpleNamespace(error=None, payload={"success": True, "result": {"ok": True}}) + + monkeypatch.setattr(supervisor, "invoke_api", fake_invoke_api) + + replace_result = await manager._cap_api_replace_dynamic( + "provider", + "api.replace_dynamic", + { + "apis": [ + { + "name": "mcp.search", + "type": "API", + "metadata": { + "version": "1", + "public": True, + "handler_name": "dynamic_search", + }, + }, + { + "name": "mcp.read", + "type": "API", + "metadata": { + "version": "1", + "public": True, + "handler_name": "dynamic_read", + }, + }, + ], + "offline_reason": "MCP 服务器已关闭", + }, + ) + + assert replace_result["success"] is True + assert replace_result["count"] == 2 + list_result = await manager._cap_api_list("provider", "api.list", {"plugin_id": "provider"}) + assert {(item["name"], item["version"]) for item in list_result["apis"]} == { + ("mcp.read", "1"), + ("mcp.search", "1"), + } + + call_result = await manager._cap_api_call( + "provider", + "api.call", + { + "api_name": "provider.mcp.search", + "version": "1", + "args": {"query": "hello"}, + }, + ) + assert call_result == {"success": True, "result": {"ok": True}} + assert captured["component_name"] == "dynamic_search" + assert captured["args"]["query"] == "hello" + assert captured["args"]["__maibot_api_name__"] == "mcp.search" + assert captured["args"]["__maibot_api_version__"] == "1" + + second_replace_result = await manager._cap_api_replace_dynamic( + "provider", + "api.replace_dynamic", + { + "apis": [ + { + "name": "mcp.read", + "type": "API", + "metadata": { + "version": "1", + "public": True, + "handler_name": "dynamic_read", + }, + } + ], + "offline_reason": "MCP 服务器已关闭", + }, + ) + + assert second_replace_result["success"] is True + assert second_replace_result["count"] == 1 + assert second_replace_result["offlined"] == 1 + + offlined_call_result = await manager._cap_api_call( + "provider", + "api.call", + { + "api_name": "provider.mcp.search", + "version": "1", + "args": {}, + }, + ) + assert offlined_call_result["success"] is False + assert "MCP 服务器已关闭" in str(offlined_call_result["error"]) + + list_after_replace = await manager._cap_api_list("provider", "api.list", {"plugin_id": "provider"}) + assert {(item["name"], item["version"]) for item in list_after_replace["apis"]} == { + ("mcp.read", "1"), + } diff --git a/pytests/test_send_service.py b/pytests/test_send_service.py new file mode 100644 index 00000000..16aad080 --- /dev/null +++ b/pytests/test_send_service.py @@ -0,0 +1,154 @@ +"""发送服务回归测试。""" + +from types import SimpleNamespace +from typing import Any, Dict, List + +import pytest + +from src.chat.message_receive.chat_manager import BotChatSession +from src.services import send_service + + +class _FakePlatformIOManager: + """用于测试的 Platform IO 管理器假对象。""" + + def __init__(self, delivery_batch: Any) -> None: + """初始化假 Platform IO 管理器。 + + Args: + delivery_batch: 发送时返回的批量回执。 + """ + self._delivery_batch = delivery_batch + self.ensure_calls = 0 + self.sent_messages: List[Dict[str, Any]] = [] + + async def ensure_send_pipeline_ready(self) -> None: + """记录发送管线准备调用次数。""" + self.ensure_calls += 1 + + def build_route_key_from_message(self, message: Any) -> Any: + """根据消息构造假的路由键。 + + Args: + message: 待发送的内部消息对象。 + + Returns: + Any: 简化后的路由键对象。 + """ + del message + return SimpleNamespace(platform="qq") + + async def send_message(self, message: Any, route_key: Any, metadata: Dict[str, Any]) -> Any: + """记录发送请求并返回预设回执。 + + Args: + message: 待发送的内部消息对象。 + route_key: 本次发送使用的路由键。 + metadata: 发送元数据。 + + Returns: + Any: 预设的批量发送回执。 + """ + self.sent_messages.append( + { + "message": message, + "route_key": route_key, + "metadata": metadata, + } + ) + return self._delivery_batch + + +def _build_target_stream() -> BotChatSession: + """构造一个最小可用的目标会话对象。 + + Returns: + BotChatSession: 测试用会话对象。 + """ + return BotChatSession( + session_id="test-session", + platform="qq", + user_id="target-user", + group_id=None, + ) + + +def test_inherit_platform_io_route_metadata_falls_back_to_bot_account( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """没有上下文消息时,也应回填当前平台账号用于账号级路由命中。""" + + monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq" if platform == "qq" else "") + + metadata = send_service._inherit_platform_io_route_metadata(_build_target_stream()) + + assert metadata["platform_io_account_id"] == "bot-qq" + assert metadata["platform_io_target_user_id"] == "target-user" + + +@pytest.mark.asyncio +async def test_text_to_stream_delegates_to_platform_io(monkeypatch: pytest.MonkeyPatch) -> None: + """send service 应将发送职责统一交给 Platform IO。""" + fake_manager = _FakePlatformIOManager( + delivery_batch=SimpleNamespace( + has_success=True, + sent_receipts=[SimpleNamespace(driver_id="plugin.qq.sender")], + failed_receipts=[], + route_key=SimpleNamespace(platform="qq"), + ) + ) + stored_messages: List[Any] = [] + + monkeypatch.setattr(send_service, "get_platform_io_manager", lambda: fake_manager) + monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq") + monkeypatch.setattr( + send_service._chat_manager, + "get_session_by_session_id", + lambda stream_id: _build_target_stream() if stream_id == "test-session" else None, + ) + monkeypatch.setattr( + send_service.MessageUtils, + "store_message_to_db", + lambda message: stored_messages.append(message), + ) + + result = await send_service.text_to_stream(text="你好", stream_id="test-session") + + assert result is True + assert fake_manager.ensure_calls == 1 + assert len(fake_manager.sent_messages) == 1 + assert fake_manager.sent_messages[0]["metadata"] == {"show_log": False} + assert len(stored_messages) == 1 + + +@pytest.mark.asyncio +async def test_text_to_stream_returns_false_when_platform_io_fails(monkeypatch: pytest.MonkeyPatch) -> None: + """Platform IO 批量发送全部失败时,应直接向上返回失败。""" + fake_manager = _FakePlatformIOManager( + delivery_batch=SimpleNamespace( + has_success=False, + sent_receipts=[], + failed_receipts=[ + SimpleNamespace( + driver_id="plugin.qq.sender", + status="failed", + error="network error", + ) + ], + route_key=SimpleNamespace(platform="qq"), + ) + ) + + monkeypatch.setattr(send_service, "get_platform_io_manager", lambda: fake_manager) + monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq") + monkeypatch.setattr( + send_service._chat_manager, + "get_session_by_session_id", + lambda stream_id: _build_target_stream() if stream_id == "test-session" else None, + ) + + result = await send_service.text_to_stream(text="发送失败", stream_id="test-session") + + assert result is False + assert fake_manager.ensure_calls == 1 + assert len(fake_manager.sent_messages) == 1 diff --git a/pytests/utils_test/statistic_test.py b/pytests/utils_test/statistic_test.py new file mode 100644 index 00000000..d3d8c18a --- /dev/null +++ b/pytests/utils_test/statistic_test.py @@ -0,0 +1,115 @@ +"""统计模块数据库会话行为测试。""" + +from __future__ import annotations + +from contextlib import contextmanager +from datetime import datetime, timedelta +from types import ModuleType +from typing import Any, Callable, Iterator + +import sys + +import pytest + +from src.chat.utils import statistic + + +class _DummyResult: + """模拟 SQLModel 查询结果对象。""" + + def all(self) -> list[Any]: + """返回空结果集。 + + Returns: + list[Any]: 空列表。 + """ + return [] + + +class _DummySession: + """模拟数据库 Session。""" + + def exec(self, statement: Any) -> _DummyResult: + """执行查询语句并返回空结果。 + + Args: + statement: 待执行的查询语句。 + + Returns: + _DummyResult: 空结果对象。 + """ + del statement + return _DummyResult() + + +def _build_fake_get_db_session(calls: list[bool]) -> Callable[[bool], Iterator[_DummySession]]: + """构造一个记录 auto_commit 参数的假会话工厂。 + + Args: + calls: 用于记录每次调用 auto_commit 参数的列表。 + + Returns: + Callable[[bool], Iterator[_DummySession]]: 可替换 `get_db_session` 的上下文管理器工厂。 + """ + + @contextmanager + def _fake_get_db_session(auto_commit: bool = True) -> Iterator[_DummySession]: + """记录会话参数并返回假 Session。 + + Args: + auto_commit: 是否启用自动提交。 + + Yields: + Iterator[_DummySession]: 假 Session 对象。 + """ + calls.append(auto_commit) + yield _DummySession() + + return _fake_get_db_session + + +def _build_statistic_task() -> statistic.StatisticOutputTask: + """构造一个最小可用的统计任务实例。 + + Returns: + statistic.StatisticOutputTask: 跳过 `__init__` 的测试实例。 + """ + task = statistic.StatisticOutputTask.__new__(statistic.StatisticOutputTask) + task.name_mapping = {} + return task + + +def _is_bot_self(platform: str, user_id: str) -> bool: + """返回固定的非机器人身份判断结果。 + + Args: + platform: 平台名称。 + user_id: 用户 ID。 + + Returns: + bool: 始终返回 ``False``。 + """ + del platform + del user_id + return False + + +def test_statistic_read_queries_disable_auto_commit(monkeypatch: pytest.MonkeyPatch) -> None: + """统计模块的纯读查询应关闭自动提交,避免 Session 退出后对象被 expire。""" + calls: list[bool] = [] + now = datetime.now() + task = _build_statistic_task() + + monkeypatch.setattr(statistic, "get_db_session", _build_fake_get_db_session(calls)) + + utils_module = ModuleType("src.chat.utils.utils") + utils_module.is_bot_self = _is_bot_self + monkeypatch.setitem(sys.modules, "src.chat.utils.utils", utils_module) + + statistic.StatisticOutputTask._fetch_online_time_since(now) + statistic.StatisticOutputTask._fetch_model_usage_since(now) + task._collect_message_count_for_period([("last_hour", now - timedelta(hours=1))]) + task._collect_interval_data(now, hours=1, interval_minutes=60) + task._collect_metrics_interval_data(now, hours=1, interval_hours=1) + + assert calls == [False] * 9 diff --git a/pytests/utils_test/test_session_utils.py b/pytests/utils_test/test_session_utils.py new file mode 100644 index 00000000..c44e2eba --- /dev/null +++ b/pytests/utils_test/test_session_utils.py @@ -0,0 +1,42 @@ +from types import SimpleNamespace + +from src.chat.message_receive.chat_manager import ChatManager +from src.common.utils.utils_session import SessionUtils + + +def test_calculate_session_id_distinguishes_account_and_scope() -> None: + base_session_id = SessionUtils.calculate_session_id("qq", user_id="42") + same_base_session_id = SessionUtils.calculate_session_id("qq", user_id="42") + account_scoped_session_id = SessionUtils.calculate_session_id("qq", user_id="42", account_id="123") + route_scoped_session_id = SessionUtils.calculate_session_id("qq", user_id="42", account_id="123", scope="main") + + assert base_session_id == same_base_session_id + assert account_scoped_session_id != base_session_id + assert route_scoped_session_id != account_scoped_session_id + + +def test_chat_manager_register_message_uses_route_metadata() -> None: + chat_manager = ChatManager() + message = SimpleNamespace( + platform="qq", + session_id="", + message_info=SimpleNamespace( + user_info=SimpleNamespace(user_id="42"), + group_info=SimpleNamespace(group_id="1000"), + additional_config={ + "platform_io_account_id": "123", + "platform_io_scope": "main", + }, + ), + ) + + chat_manager.register_message(message) + + assert message.session_id == SessionUtils.calculate_session_id( + "qq", + user_id="42", + group_id="1000", + account_id="123", + scope="main", + ) + assert chat_manager.last_messages[message.session_id] is message diff --git a/src/chat/brain_chat/PFC/message_sender.py b/src/chat/brain_chat/PFC/message_sender.py index ec5fb5ba..b9da905c 100644 --- a/src/chat/brain_chat/PFC/message_sender.py +++ b/src/chat/brain_chat/PFC/message_sender.py @@ -1,27 +1,28 @@ -import time +"""PFC 侧消息发送封装。""" + from typing import Optional -from maim_message import Seg from rich.traceback import install -from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo from src.chat.message_receive.chat_manager import BotChatSession -from src.chat.message_receive.message import MessageSending -from src.chat.message_receive.uni_message_sender import UniversalMessageSender -from src.chat.utils.utils import get_bot_account +from src.common.data_models.mai_message_data_model import MaiMessage from src.common.logger import get_logger -from src.config.config import global_config +from src.services import send_service as send_api install(extra_lines=3) - logger = get_logger("message_sender") class DirectMessageSender: - """直接消息发送器""" + """直接消息发送器。""" - def __init__(self, private_name: str): + def __init__(self, private_name: str) -> None: + """初始化直接消息发送器。 + + Args: + private_name: 当前私聊实例的名称。 + """ self.private_name = private_name async def send_message( @@ -30,58 +31,31 @@ class DirectMessageSender: content: str, reply_to_message: Optional[MaiMessage] = None, ) -> None: - """发送消息到聊天流 + """发送文本消息到聊天流。 Args: - chat_stream: 聊天会话 - content: 消息内容 - reply_to_message: 要回复的消息(可选) + chat_stream: 目标聊天会话。 + content: 待发送的文本内容。 + reply_to_message: 可选的引用回复锚点消息。 + + Raises: + RuntimeError: 当消息发送失败时抛出。 """ try: - # 创建消息内容 - segments = Seg(type="seglist", data=[Seg(type="text", data=content)]) - - # 获取麦麦的信息 - bot_user_id = get_bot_account(chat_stream.platform) - if not bot_user_id: - logger.error(f"[私聊][{self.private_name}]平台 {chat_stream.platform} 未配置机器人账号,无法发送消息") - raise RuntimeError(f"平台 {chat_stream.platform} 未配置机器人账号") - bot_user_info = UserInfo( - user_id=bot_user_id, - user_nickname=global_config.bot.nickname, + sent = await send_api.text_to_stream( + text=content, + stream_id=chat_stream.session_id, + set_reply=reply_to_message is not None, + reply_message=reply_to_message, + storage_message=True, ) - # 用当前时间作为message_id,和之前那套sender一样 - message_id = f"dm{round(time.time(), 2)}" - - # 构建发送者信息(私聊时为接收者) - sender_info = None - if reply_to_message and reply_to_message.message_info and reply_to_message.message_info.user_info: - sender_info = reply_to_message.message_info.user_info - - # 构建消息对象 - message = MessageSending( - message_id=message_id, - session=chat_stream, - bot_user_info=bot_user_info, - sender_info=sender_info, - message_segment=segments, - reply=reply_to_message, - is_head=True, - is_emoji=False, - thinking_start_time=time.time(), - ) - - # 发送消息 - message_sender = UniversalMessageSender() - sent = await message_sender.send_message(message, typing=False, set_reply=False, storage_message=True) - if sent: logger.info(f"[私聊][{self.private_name}]PFC消息已发送: {content}") - else: - logger.error(f"[私聊][{self.private_name}]PFC消息发送失败") - raise RuntimeError("消息发送失败") + return - except Exception as e: - logger.error(f"[私聊][{self.private_name}]PFC消息发送失败: {str(e)}") + logger.error(f"[私聊][{self.private_name}]PFC消息发送失败") + raise RuntimeError("消息发送失败") + except Exception as exc: + logger.error(f"[私聊][{self.private_name}]PFC消息发送失败: {exc}") raise diff --git a/src/chat/brain_chat/brain_chat.py b/src/chat/brain_chat/brain_chat.py index 2b4863ac..1e9e648a 100644 --- a/src/chat/brain_chat/brain_chat.py +++ b/src/chat/brain_chat/brain_chat.py @@ -8,8 +8,8 @@ from rich.traceback import install from src.config.config import global_config from src.common.logger import get_logger from src.common.utils.utils_config import ExpressionConfigUtils -from src.bw_learner.expression_learner import ExpressionLearner -from src.bw_learner.jargon_miner import JargonMiner +from src.learners.expression_learner import ExpressionLearner +from src.learners.jargon_miner import JargonMiner from src.chat.message_receive.chat_manager import BotChatSession from src.chat.message_receive.chat_manager import chat_manager as _chat_manager from src.chat.message_receive.message import SessionMessage diff --git a/src/chat/brain_chat/brain_planner.py b/src/chat/brain_chat/brain_planner.py index 12b103a0..709be8ee 100644 --- a/src/chat/brain_chat/brain_planner.py +++ b/src/chat/brain_chat/brain_planner.py @@ -1,30 +1,32 @@ +from datetime import datetime +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple + import json -import time -import traceback import random import re -from typing import Dict, Optional, Tuple, List, TYPE_CHECKING -from rich.traceback import install -from datetime import datetime -from json_repair import repair_json +import time +import traceback + +from json_repair import repair_json +from rich.traceback import install -from src.llm_models.utils_model import LLMRequest -from src.config.config import global_config, model_config -from src.common.logger import get_logger from src.chat.logger.plan_reply_logger import PlanReplyLogger +from src.chat.message_receive.chat_manager import chat_manager as _chat_manager +from src.chat.planner_actions.action_manager import ActionManager +from src.chat.utils.utils import get_chat_type_and_target_info from src.common.data_models.info_data_model import ActionPlannerInfo +from src.common.logger import get_logger from src.common.utils.utils_action import ActionUtils +from src.config.config import global_config, model_config +from src.core.types import ActionActivationType, ActionInfo, ComponentType +from src.llm_models.utils_model import LLMRequest +from src.plugin_runtime.component_query import component_query_service from src.prompt.prompt_manager import prompt_manager from src.services.message_service import ( build_readable_messages_with_id, get_actions_by_timestamp_with_chat, get_messages_before_time_in_chat, ) -from src.chat.utils.utils import get_chat_type_and_target_info -from src.chat.planner_actions.action_manager import ActionManager -from src.chat.message_receive.chat_manager import chat_manager as _chat_manager -from src.core.types import ActionActivationType, ActionInfo, ComponentType -from src.core.component_registry import component_registry if TYPE_CHECKING: from src.common.data_models.info_data_model import TargetPersonInfo @@ -320,7 +322,7 @@ class BrainPlanner: current_available_actions_dict = self.action_manager.get_using_actions() # 获取完整的动作信息 - all_registered_actions: Dict[str, ActionInfo] = component_registry.get_components_by_type( # type: ignore + all_registered_actions: Dict[str, ActionInfo] = component_query_service.get_components_by_type( # type: ignore ComponentType.ACTION ) current_available_actions = {} diff --git a/src/chat/heart_flow/heartFC_chat - 副本.py b/src/chat/heart_flow/heartFC_chat - 副本.py deleted file mode 100644 index c805597d..00000000 --- a/src/chat/heart_flow/heartFC_chat - 副本.py +++ /dev/null @@ -1,734 +0,0 @@ -import asyncio -import time -import traceback -import random -from typing import List, Optional, Dict, Any, Tuple, TYPE_CHECKING -from rich.traceback import install - -from src.config.config import global_config -from src.common.logger import get_logger -from src.common.data_models.info_data_model import ActionPlannerInfo -from src.common.data_models.message_data_model import ReplyContentType -from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager -from src.chat.utils.prompt_builder import global_prompt_manager -from src.chat.utils.timer_calculator import Timer -from src.chat.planner_actions.planner import ActionPlanner -from src.chat.planner_actions.action_modifier import ActionModifier -from src.chat.planner_actions.action_manager import ActionManager -from src.chat.heart_flow.hfc_utils import CycleDetail -from src.bw_learner.expression_learner import expression_learner_manager -from src.chat.heart_flow.frequency_control import frequency_control_manager -from src.bw_learner.message_recorder import extract_and_distribute_messages -from src.person_info.person_info import Person -from src.plugin_system.base.component_types import EventType, ActionInfo -from src.plugin_system.core import events_manager -from src.plugin_system.apis import generator_api, send_api, message_api, database_api -from src.chat.utils.chat_message_builder import ( - build_readable_messages_with_id, - get_raw_msg_before_timestamp_with_chat, -) -from src.chat.utils.utils import record_replyer_action_temp -from src.memory_system.chat_history_summarizer import ChatHistorySummarizer - -if TYPE_CHECKING: - from src.common.data_models.database_data_model import DatabaseMessages - from src.common.data_models.message_data_model import ReplySetModel - - -ERROR_LOOP_INFO = { - "loop_plan_info": { - "action_result": { - "action_type": "error", - "action_data": {}, - "reasoning": "循环处理失败", - }, - }, - "loop_action_info": { - "action_taken": False, - "reply_text": "", - "command": "", - "taken_time": time.time(), - }, -} - - -install(extra_lines=3) - -# 注释:原来的动作修改超时常量已移除,因为改为顺序执行 - -logger = get_logger("hfc") # Logger Name Changed - - -class HeartFChatting: - """ - 管理一个连续的Focus Chat循环 - 用于在特定聊天流中生成回复。 - 其生命周期现在由其关联的 SubHeartflow 的 FOCUSED 状态控制。 - """ - - def __init__(self, chat_id: str): - """ - HeartFChatting 初始化函数 - - 参数: - chat_id: 聊天流唯一标识符(如stream_id) - on_stop_focus_chat: 当收到stop_focus_chat命令时调用的回调函数 - performance_version: 性能记录版本号,用于区分不同启动版本 - """ - # 基础属性 - self.stream_id: str = chat_id # 聊天流ID - self.chat_stream: ChatStream = get_chat_manager().get_stream(self.stream_id) # type: ignore - if not self.chat_stream: - raise ValueError(f"无法找到聊天流: {self.stream_id}") - self.log_prefix = f"[{get_chat_manager().get_stream_name(self.stream_id) or self.stream_id}]" - - self.expression_learner = expression_learner_manager.get_expression_learner(self.stream_id) - - self.action_manager = ActionManager() - self.action_planner = ActionPlanner(chat_id=self.stream_id, action_manager=self.action_manager) - self.action_modifier = ActionModifier(action_manager=self.action_manager, chat_id=self.stream_id) - - # 循环控制内部状态 - self.running: bool = False - self._loop_task: Optional[asyncio.Task] = None # 主循环任务 - - # 添加循环信息管理相关的属性 - self.history_loop: List[CycleDetail] = [] - self._cycle_counter = 0 - self._current_cycle_detail: CycleDetail = None # type: ignore - - self.last_read_time = time.time() - 2 - - self.is_mute = False - - self.last_active_time = time.time() # 记录上一次非noreply时间 - - self.question_probability_multiplier = 1 - self.questioned = False - - # 跟踪连续 no_reply 次数,用于动态调整阈值 - self.consecutive_no_reply_count = 0 - - # 聊天内容概括器 - self.chat_history_summarizer = ChatHistorySummarizer(chat_id=self.stream_id) - - async def start(self): - """检查是否需要启动主循环,如果未激活则启动。""" - - # 如果循环已经激活,直接返回 - if self.running: - logger.debug(f"{self.log_prefix} HeartFChatting 已激活,无需重复启动") - return - - try: - # 标记为活动状态,防止重复启动 - self.running = True - - self._loop_task = asyncio.create_task(self._main_chat_loop()) - self._loop_task.add_done_callback(self._handle_loop_completion) - - # 启动聊天内容概括器的后台定期检查循环 - await self.chat_history_summarizer.start() - - logger.info(f"{self.log_prefix} HeartFChatting 启动完成") - - except Exception as e: - # 启动失败时重置状态 - self.running = False - self._loop_task = None - logger.error(f"{self.log_prefix} HeartFChatting 启动失败: {e}") - raise - - def _handle_loop_completion(self, task: asyncio.Task): - """当 _hfc_loop 任务完成时执行的回调。""" - try: - if exception := task.exception(): - logger.error(f"{self.log_prefix} HeartFChatting: 脱离了聊天(异常): {exception}") - logger.error(traceback.format_exc()) # Log full traceback for exceptions - else: - logger.info(f"{self.log_prefix} HeartFChatting: 脱离了聊天 (外部停止)") - except asyncio.CancelledError: - logger.info(f"{self.log_prefix} HeartFChatting: 结束了聊天") - - def start_cycle(self) -> Tuple[Dict[str, float], str]: - self._cycle_counter += 1 - self._current_cycle_detail = CycleDetail(self._cycle_counter) - self._current_cycle_detail.thinking_id = f"tid{str(round(time.time(), 2))}" - cycle_timers = {} - return cycle_timers, self._current_cycle_detail.thinking_id - - def end_cycle(self, loop_info, cycle_timers): - self._current_cycle_detail.set_loop_info(loop_info) - self.history_loop.append(self._current_cycle_detail) - self._current_cycle_detail.timers = cycle_timers - self._current_cycle_detail.end_time = time.time() - - def print_cycle_info(self, cycle_timers): - # 记录循环信息和计时器结果 - timer_strings = [] - for name, elapsed in cycle_timers.items(): - if elapsed < 0.1: - # 不显示小于0.1秒的计时器 - continue - formatted_time = f"{elapsed:.2f}秒" - timer_strings.append(f"{name}: {formatted_time}") - - logger.info( - f"{self.log_prefix} 第{self._current_cycle_detail.cycle_id}次思考," - f"耗时: {self._current_cycle_detail.end_time - self._current_cycle_detail.start_time:.1f}秒;" # type: ignore - + (f"详情: {'; '.join(timer_strings)}" if timer_strings else "") - ) - - async def _loopbody(self): - recent_messages_list = message_api.get_messages_by_time_in_chat( - chat_id=self.stream_id, - start_time=self.last_read_time, - end_time=time.time(), - limit=20, - limit_mode="latest", - filter_mai=True, - filter_command=False, - filter_intercept_message_level=0, - ) - - # 根据连续 no_reply 次数动态调整阈值 - # 3次 no_reply 时,阈值调高到 1.5(50%概率为1,50%概率为2) - # 5次 no_reply 时,提高到 2(大于等于两条消息的阈值) - if self.consecutive_no_reply_count >= 5: - threshold = 2 - elif self.consecutive_no_reply_count >= 3: - # 1.5 的含义:50%概率为1,50%概率为2 - threshold = 2 if random.random() < 0.5 else 1 - else: - threshold = 1 - - if len(recent_messages_list) >= threshold: - # for message in recent_messages_list: - # print(message.processed_plain_text) - - self.last_read_time = time.time() - - # !此处使at或者提及必定回复 - mentioned_message = None - for message in recent_messages_list: - if (message.is_mentioned or message.is_at) and global_config.chat.mentioned_bot_reply: - mentioned_message = message - - # logger.info(f"{self.log_prefix} 当前talk_value: {global_config.chat.get_talk_value(self.stream_id)}") - - # *控制频率用 - if mentioned_message: - await self._observe(recent_messages_list=recent_messages_list, force_reply_message=mentioned_message) - elif ( - random.random() - < global_config.chat.get_talk_value(self.stream_id) - * frequency_control_manager.get_or_create_frequency_control(self.stream_id).get_talk_frequency_adjust() - ): - await self._observe(recent_messages_list=recent_messages_list) - else: - # 没有提到,继续保持沉默,等待5秒防止频繁触发 - await asyncio.sleep(10) - return True - else: - await asyncio.sleep(0.2) - return True - return True - - async def _send_and_store_reply( - self, - response_set: "ReplySetModel", - action_message: "DatabaseMessages", - cycle_timers: Dict[str, float], - thinking_id, - actions, - selected_expressions: Optional[List[int]] = None, - quote_message: Optional[bool] = None, - ) -> Tuple[Dict[str, Any], str, Dict[str, float]]: - with Timer("回复发送", cycle_timers): - reply_text = await self._send_response( - reply_set=response_set, - message_data=action_message, - selected_expressions=selected_expressions, - quote_message=quote_message, - ) - - # 获取 platform,如果不存在则从 chat_stream 获取,如果还是 None 则使用默认值 - platform = action_message.chat_info.platform - if platform is None: - platform = getattr(self.chat_stream, "platform", "unknown") - - person = Person(platform=platform, user_id=action_message.user_info.user_id) - person_name = person.person_name - action_prompt_display = f"你对{person_name}进行了回复:{reply_text}" - - await database_api.store_action_info( - chat_stream=self.chat_stream, - action_build_into_prompt=False, - action_prompt_display=action_prompt_display, - action_done=True, - thinking_id=thinking_id, - action_data={"reply_text": reply_text}, - action_name="reply", - ) - - # 构建循环信息 - loop_info: Dict[str, Any] = { - "loop_plan_info": { - "action_result": actions, - }, - "loop_action_info": { - "action_taken": True, - "reply_text": reply_text, - "command": "", - "taken_time": time.time(), - }, - } - - return loop_info, reply_text, cycle_timers - - async def _observe( - self, # interest_value: float = 0.0, - recent_messages_list: Optional[List["DatabaseMessages"]] = None, - force_reply_message: Optional["DatabaseMessages"] = None, - ) -> bool: # sourcery skip: merge-else-if-into-elif, remove-redundant-if - if recent_messages_list is None: - recent_messages_list = [] - _reply_text = "" # 初始化reply_text变量,避免UnboundLocalError - - start_time = time.time() - async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()): - # 通过 MessageRecorder 统一提取消息并分发给 expression_learner 和 jargon_miner - # 在 replyer 执行时触发,统一管理时间窗口,避免重复获取消息 - asyncio.create_task(extract_and_distribute_messages(self.stream_id)) - - # 添加curious检测任务 - 检测聊天记录中的矛盾、冲突或需要提问的内容 - # asyncio.create_task(check_and_make_question(self.stream_id)) - # 添加聊天内容概括任务 - 累积、打包和压缩聊天记录 - # 注意:后台循环已在start()中启动,这里作为额外触发点,在有思考时立即处理 - # asyncio.create_task(self.chat_history_summarizer.process()) - - cycle_timers, thinking_id = self.start_cycle() - logger.info( - f"{self.log_prefix} 开始第{self._cycle_counter}次思考(频率: {global_config.chat.get_talk_value(self.stream_id)})" - ) - - # 第一步:动作检查 - available_actions: Dict[str, ActionInfo] = {} - try: - await self.action_modifier.modify_actions() - available_actions = self.action_manager.get_using_actions() - except Exception as e: - logger.error(f"{self.log_prefix} 动作修改失败: {e}") - - # 执行planner - is_group_chat, chat_target_info, _ = self.action_planner.get_necessary_info() - - message_list_before_now = get_raw_msg_before_timestamp_with_chat( - chat_id=self.stream_id, - timestamp=time.time(), - limit=int(global_config.chat.max_context_size * 0.6), - filter_intercept_message_level=1, - ) - chat_content_block, message_id_list = build_readable_messages_with_id( - messages=message_list_before_now, - timestamp_mode="normal_no_YMD", - read_mark=self.action_planner.last_obs_time_mark, - truncate=True, - show_actions=True, - ) - - prompt_info = await self.action_planner.build_planner_prompt( - is_group_chat=is_group_chat, - chat_target_info=chat_target_info, - current_available_actions=available_actions, - chat_content_block=chat_content_block, - message_id_list=message_id_list, - ) - continue_flag, modified_message = await events_manager.handle_mai_events( - EventType.ON_PLAN, None, prompt_info[0], None, self.chat_stream.stream_id - ) - if not continue_flag: - return False - if modified_message and modified_message._modify_flags.modify_llm_prompt: - prompt_info = (modified_message.llm_prompt, prompt_info[1]) - - with Timer("规划器", cycle_timers): - action_to_use_info = await self.action_planner.plan( - loop_start_time=self.last_read_time, - available_actions=available_actions, - force_reply_message=force_reply_message, - ) - - logger.info( - f"{self.log_prefix} 决定执行{len(action_to_use_info)}个动作: {' '.join([a.action_type for a in action_to_use_info])}" - ) - - # 3. 并行执行所有动作 - action_tasks = [ - asyncio.create_task( - self._execute_action(action, action_to_use_info, thinking_id, available_actions, cycle_timers) - ) - for action in action_to_use_info - ] - - # 并行执行所有任务 - results = await asyncio.gather(*action_tasks, return_exceptions=True) - - # 处理执行结果 - reply_loop_info = None - reply_text_from_reply = "" - action_success = False - action_reply_text = "" - - excute_result_str = "" - for result in results: - excute_result_str += f"{result['action_type']} 执行结果:{result['result']}\n" - - if isinstance(result, BaseException): - logger.error(f"{self.log_prefix} 动作执行异常: {result}") - continue - - if result["action_type"] != "reply": - action_success = result["success"] - action_reply_text = result["result"] - elif result["action_type"] == "reply": - if result["success"]: - reply_loop_info = result["loop_info"] - reply_text_from_reply = result["result"] - else: - logger.warning(f"{self.log_prefix} 回复动作执行失败") - - self.action_planner.add_plan_excute_log(result=excute_result_str) - - # 构建最终的循环信息 - if reply_loop_info: - # 如果有回复信息,使用回复的loop_info作为基础 - loop_info = reply_loop_info - # 更新动作执行信息 - loop_info["loop_action_info"].update( - { - "action_taken": action_success, - "taken_time": time.time(), - } - ) - _reply_text = reply_text_from_reply - else: - # 没有回复信息,构建纯动作的loop_info - loop_info = { - "loop_plan_info": { - "action_result": action_to_use_info, - }, - "loop_action_info": { - "action_taken": action_success, - "reply_text": action_reply_text, - "taken_time": time.time(), - }, - } - _reply_text = action_reply_text - - self.end_cycle(loop_info, cycle_timers) - self.print_cycle_info(cycle_timers) - - end_time = time.time() - if end_time - start_time < global_config.chat.planner_smooth: - wait_time = global_config.chat.planner_smooth - (end_time - start_time) - await asyncio.sleep(wait_time) - else: - await asyncio.sleep(0.1) - return True - - async def _main_chat_loop(self): - """主循环,持续进行计划并可能回复消息,直到被外部取消。""" - try: - while self.running: - # 主循环 - success = await self._loopbody() - await asyncio.sleep(0.1) - if not success: - break - except asyncio.CancelledError: - # 设置了关闭标志位后被取消是正常流程 - logger.info(f"{self.log_prefix} 麦麦已关闭聊天") - except Exception: - logger.error(f"{self.log_prefix} 麦麦聊天意外错误,将于3s后尝试重新启动") - print(traceback.format_exc()) - await asyncio.sleep(3) - self._loop_task = asyncio.create_task(self._main_chat_loop()) - logger.error(f"{self.log_prefix} 结束了当前聊天循环") - - async def _handle_action( - self, - action: str, - action_reasoning: str, - action_data: dict, - cycle_timers: Dict[str, float], - thinking_id: str, - action_message: Optional["DatabaseMessages"] = None, - ) -> tuple[bool, str, str]: - """ - 处理规划动作,使用动作工厂创建相应的动作处理器 - - 参数: - action: 动作类型 - action_reasoning: 决策理由 - action_data: 动作数据,包含不同动作需要的参数 - cycle_timers: 计时器字典 - thinking_id: 思考ID - action_message: 消息数据 - 返回: - tuple[bool, str, str]: (是否执行了动作, 思考消息ID, 命令) - """ - try: - # 使用工厂创建动作处理器实例 - try: - action_handler = self.action_manager.create_action( - action_name=action, - action_data=action_data, - cycle_timers=cycle_timers, - thinking_id=thinking_id, - chat_stream=self.chat_stream, - log_prefix=self.log_prefix, - action_reasoning=action_reasoning, - action_message=action_message, - ) - except Exception as e: - logger.error(f"{self.log_prefix} 创建动作处理器时出错: {e}") - traceback.print_exc() - return False, "" - - # 处理动作并获取结果(固定记录一次动作信息) - result = await action_handler.execute() - success, action_text = result - - return success, action_text - - except Exception as e: - logger.error(f"{self.log_prefix} 处理{action}时出错: {e}") - traceback.print_exc() - return False, "" - - async def _send_response( - self, - reply_set: "ReplySetModel", - message_data: "DatabaseMessages", - selected_expressions: Optional[List[int]] = None, - quote_message: Optional[bool] = None, - ) -> str: - # 根据 llm_quote 配置决定是否使用 quote_message 参数 - if global_config.chat.llm_quote: - # 如果配置为 true,使用 llm_quote 参数决定是否引用回复 - if quote_message is None: - logger.warning(f"{self.log_prefix} quote_message 参数为空,不引用") - need_reply = False - else: - need_reply = quote_message - if need_reply: - logger.info(f"{self.log_prefix} LLM 决定使用引用回复") - else: - # 如果配置为 false,使用原来的模式 - new_message_count = message_api.count_new_messages( - chat_id=self.chat_stream.stream_id, start_time=self.last_read_time, end_time=time.time() - ) - need_reply = new_message_count >= random.randint(2, 3) or time.time() - self.last_read_time > 90 - if need_reply: - logger.info(f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息,使用引用回复,或者上次回复时间超过90秒") - - reply_text = "" - first_replied = False - for reply_content in reply_set.reply_data: - if reply_content.content_type != ReplyContentType.TEXT: - continue - data: str = reply_content.content # type: ignore - if not first_replied: - await send_api.text_to_stream( - text=data, - stream_id=self.chat_stream.stream_id, - reply_message=message_data, - set_reply=need_reply, - typing=False, - selected_expressions=selected_expressions, - ) - first_replied = True - else: - await send_api.text_to_stream( - text=data, - stream_id=self.chat_stream.stream_id, - reply_message=message_data, - set_reply=False, - typing=True, - selected_expressions=selected_expressions, - ) - reply_text += data - - return reply_text - - async def _execute_action( - self, - action_planner_info: ActionPlannerInfo, - chosen_action_plan_infos: List[ActionPlannerInfo], - thinking_id: str, - available_actions: Dict[str, ActionInfo], - cycle_timers: Dict[str, float], - ): - """执行单个动作的通用函数""" - try: - with Timer(f"动作{action_planner_info.action_type}", cycle_timers): - # 直接当场执行no_reply逻辑 - if action_planner_info.action_type == "no_reply": - # 直接处理no_reply逻辑,不再通过动作系统 - reason = action_planner_info.reasoning or "选择不回复" - # logger.info(f"{self.log_prefix} 选择不回复,原因: {reason}") - - # 增加连续 no_reply 计数 - self.consecutive_no_reply_count += 1 - - await database_api.store_action_info( - chat_stream=self.chat_stream, - action_build_into_prompt=False, - action_prompt_display=reason, - action_done=True, - thinking_id=thinking_id, - action_data={}, - action_name="no_reply", - action_reasoning=reason, - ) - - return {"action_type": "no_reply", "success": True, "result": "选择不回复", "command": ""} - - elif action_planner_info.action_type == "reply": - # 直接当场执行reply逻辑 - self.questioned = False - # 刷新主动发言状态 - # 重置连续 no_reply 计数 - self.consecutive_no_reply_count = 0 - - reason = action_planner_info.reasoning or "" - # 根据 think_mode 配置决定 think_level 的值 - think_mode = global_config.chat.think_mode - if think_mode == "default": - think_level = 0 - elif think_mode == "deep": - think_level = 1 - elif think_mode == "dynamic": - # dynamic 模式:从 planner 返回的 action_data 中获取 - think_level = action_planner_info.action_data.get("think_level", 1) - else: - # 默认使用 default 模式 - think_level = 0 - # 使用 action_reasoning(planner 的整体思考理由)作为 reply_reason - planner_reasoning = action_planner_info.action_reasoning or reason - - record_replyer_action_temp( - chat_id=self.stream_id, - reason=reason, - think_level=think_level, - ) - - await database_api.store_action_info( - chat_stream=self.chat_stream, - action_build_into_prompt=False, - action_prompt_display=reason, - action_done=True, - thinking_id=thinking_id, - action_data={}, - action_name="reply", - action_reasoning=reason, - ) - - # 从 Planner 的 action_data 中提取未知词语列表(仅在 reply 时使用) - unknown_words = None - quote_message = None - if isinstance(action_planner_info.action_data, dict): - uw = action_planner_info.action_data.get("unknown_words") - if isinstance(uw, list): - cleaned_uw: List[str] = [] - for item in uw: - if isinstance(item, str): - s = item.strip() - if s: - cleaned_uw.append(s) - if cleaned_uw: - unknown_words = cleaned_uw - - # 从 Planner 的 action_data 中提取 quote_message 参数 - qm = action_planner_info.action_data.get("quote") - if qm is not None: - # 支持多种格式:true/false, "true"/"false", 1/0 - if isinstance(qm, bool): - quote_message = qm - elif isinstance(qm, str): - quote_message = qm.lower() in ("true", "1", "yes") - elif isinstance(qm, (int, float)): - quote_message = bool(qm) - - logger.info(f"{self.log_prefix} {qm}引用回复设置: {quote_message}") - - success, llm_response = await generator_api.generate_reply( - chat_stream=self.chat_stream, - reply_message=action_planner_info.action_message, - available_actions=available_actions, - chosen_actions=chosen_action_plan_infos, - reply_reason=planner_reasoning, - unknown_words=unknown_words, - enable_tool=global_config.tool.enable_tool, - request_type="replyer", - from_plugin=False, - reply_time_point=action_planner_info.action_data.get("loop_start_time", time.time()), - think_level=think_level, - ) - - if not success or not llm_response or not llm_response.reply_set: - if action_planner_info.action_message: - logger.info(f"对 {action_planner_info.action_message.processed_plain_text} 的回复生成失败") - else: - logger.info("回复生成失败") - return {"action_type": "reply", "success": False, "result": "回复生成失败", "loop_info": None} - - response_set = llm_response.reply_set - selected_expressions = llm_response.selected_expressions - loop_info, reply_text, _ = await self._send_and_store_reply( - response_set=response_set, - action_message=action_planner_info.action_message, # type: ignore - cycle_timers=cycle_timers, - thinking_id=thinking_id, - actions=chosen_action_plan_infos, - selected_expressions=selected_expressions, - quote_message=quote_message, - ) - self.last_active_time = time.time() - return { - "action_type": "reply", - "success": True, - "result": f"你使用reply动作,对' {action_planner_info.action_message.processed_plain_text} '这句话进行了回复,回复内容为: '{reply_text}'", - "loop_info": loop_info, - } - - else: - # 执行普通动作 - with Timer("动作执行", cycle_timers): - success, result = await self._handle_action( - action=action_planner_info.action_type, - action_reasoning=action_planner_info.action_reasoning or "", - action_data=action_planner_info.action_data or {}, - cycle_timers=cycle_timers, - thinking_id=thinking_id, - action_message=action_planner_info.action_message, - ) - - self.last_active_time = time.time() - return { - "action_type": action_planner_info.action_type, - "success": success, - "result": result, - } - - except Exception as e: - logger.error(f"{self.log_prefix} 执行动作时出错: {e}") - logger.error(f"{self.log_prefix} 错误信息: {traceback.format_exc()}") - return { - "action_type": action_planner_info.action_type, - "success": False, - "result": "", - "loop_info": None, - "error": str(e), - } diff --git a/src/chat/heart_flow/heartFC_chat.py b/src/chat/heart_flow/heartFC_chat.py index af0beb4e..74d94773 100644 --- a/src/chat/heart_flow/heartFC_chat.py +++ b/src/chat/heart_flow/heartFC_chat.py @@ -1,377 +1,231 @@ -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +from rich.traceback import install +from typing import List, Optional, TYPE_CHECKING import asyncio import random import time import traceback -from rich.traceback import install - -from src.bw_learner.expression_learner import ExpressionLearner -from src.bw_learner.jargon_miner import JargonMiner -from src.chat.event_helpers import build_event_message -from src.chat.logger.plan_reply_logger import PlanReplyLogger -from src.chat.message_receive.chat_manager import BotChatSession -from src.chat.message_receive.chat_manager import chat_manager as _chat_manager -from src.chat.planner_actions.action_manager import ActionManager -from src.chat.planner_actions.action_modifier import ActionModifier -from src.chat.planner_actions.planner import ActionPlanner -from src.chat.utils.prompt_builder import global_prompt_manager -from src.chat.utils.timer_calculator import Timer -from src.chat.utils.utils import record_replyer_action_temp -from src.common.data_models.info_data_model import ActionPlannerInfo -from src.common.data_models.message_component_data_model import MessageSequence, TextComponent +from src.chat.message_receive.chat_manager import chat_manager from src.common.logger import get_logger from src.common.utils.utils_config import ChatConfigUtils, ExpressionConfigUtils from src.config.config import global_config from src.config.file_watcher import FileChange -from src.core.event_bus import event_bus -from src.core.types import ActionInfo, EventType -from src.person_info.person_info import Person -from src.services import ( - database_service as database_api, - generator_service as generator_api, - message_service as message_api, - send_service as send_api, -) -from src.services.message_service import build_readable_messages_with_id, get_messages_before_time_in_chat +from src.learners.expression_learner import ExpressionLearner +from src.learners.jargon_miner import JargonMiner from .heartFC_utils import CycleDetail if TYPE_CHECKING: from src.chat.message_receive.message import SessionMessage - install(extra_lines=5) logger = get_logger("heartFC_chat") class HeartFChatting: - """管理一个持续运行的 Focus Chat 会话。""" + """ + 管理一个连续的Focus Chat聊天会话 + 用于在特定的聊天会话里面生成回复 + """ def __init__(self, session_id: str): - self.session_id = session_id - self.chat_stream: BotChatSession = _chat_manager.get_session_by_session_id(self.session_id) # type: ignore[assignment] - if not self.chat_stream: - raise ValueError(f"无法找到聊天会话 {self.session_id}") + """ + 初始化 HeartFChatting 实例 - session_name = _chat_manager.get_session_name(session_id) or session_id + Args: + session_id: 聊天会话ID + """ + # 基础属性 + self.session_id = session_id + session_name = chat_manager.get_session_name(session_id) or session_id self.log_prefix = f"[{session_name}]" self.session_name = session_name - self.action_manager = ActionManager() - self.action_planner = ActionPlanner(chat_id=self.session_id, action_manager=self.action_manager) - self.action_modifier = ActionModifier(action_manager=self.action_manager, chat_id=self.session_id) - + # 系统运行状态 self._running: bool = False self._loop_task: Optional[asyncio.Task] = None + self._cycle_counter: int = 0 + self._hfc_lock: asyncio.Lock = asyncio.Lock() # 用于保护 _hfc_func 的并发访问 + # 聊天频率相关 + self._consecutive_no_reply_count = 0 # 跟踪连续 no_reply 次数,用于动态调整阈值 + self._talk_frequency_adjust: float = 1.0 # 发言频率修正值,默认为1.0,可以根据需要调整 + + # HFC内消息缓存 + self.message_cache: List[SessionMessage] = [] + + # Asyncio Event 用于控制循环的开始和结束 self._cycle_event = asyncio.Event() - self._hfc_lock = asyncio.Lock() - - self._cycle_counter = 0 - self._current_cycle_detail: Optional[CycleDetail] = None - self.history_loop: List[CycleDetail] = [] - - self.last_read_time = time.time() - 2 - self.last_active_time = time.time() - self._talk_frequency_adjust = 1.0 - self._consecutive_no_reply_count = 0 - - self.message_cache: List["SessionMessage"] = [] - - self._min_messages_for_extraction = 30 - self._min_extraction_interval = 60 - self._last_extraction_time = 0.0 + # 表达方式相关内容 + self._min_messages_for_extraction = 30 # 最少提取消息数 + self._min_extraction_interval = 60 # 最小提取时间间隔,单位为秒 + self._last_extraction_time: float = 0.0 # 上次提取的时间戳 expr_use, jargon_learn, expr_learn = ExpressionConfigUtils.get_expression_config_for_chat(session_id) - self._enable_expression_use = expr_use - self._enable_expression_learning = expr_learn - self._enable_jargon_learning = jargon_learn - self._expression_learner = ExpressionLearner(session_id) - self._jargon_miner = JargonMiner(session_id, session_name=session_name) + self._enable_expression_use = expr_use # 允许使用表达方式,但不一定启用学习 + self._enable_expression_learning = expr_learn # 允许学习表达方式 + self._enable_jargon_learning = jargon_learn # 允许学习黑话 + # 表达学习器 + self._expression_learner: ExpressionLearner = ExpressionLearner(session_id) + # 黑话挖掘器 + self._jargon_miner: JargonMiner = JargonMiner(session_id, session_name=session_name) + + # TODO: ChatSummarizer 聊天总结器重构 + + # ====== 公开方法 ====== async def start(self): + """启动 HeartFChatting 的主循环""" + # 先检查是否已经启动运行 if self._running: - logger.debug(f"{self.log_prefix} HeartFChatting 已在运行中") + logger.debug(f"{self.log_prefix} 已经在运行中,无需重复启动") return try: self._running = True - self._cycle_event.clear() + self._cycle_event.clear() # 确保事件初始状态为未设置 + self._loop_task = asyncio.create_task(self.main_loop()) self._loop_task.add_done_callback(self._handle_loop_completion) + logger.info(f"{self.log_prefix} HeartFChatting 启动完成") - except Exception as exc: - logger.error(f"{self.log_prefix} HeartFChatting 启动失败: {exc}", exc_info=True) - self._running = False - self._cycle_event.set() - self._loop_task = None + except Exception as e: + logger.error(f"{self.log_prefix} 启动 HeartFChatting 失败: {e}", exc_info=True) + self._running = False # 确保状态正确 + self._cycle_event.set() # 确保事件被设置,避免死锁 + self._loop_task = None # 确保任务引用被清理 raise async def stop(self): + """停止 HeartFChatting 的主循环""" if not self._running: - logger.debug(f"{self.log_prefix} HeartFChatting 已停止") + logger.debug(f"{self.log_prefix} HeartFChatting 已经停止,无需重复停止") return self._running = False - self._cycle_event.set() + self._cycle_event.set() # 触发事件,通知循环结束 if self._loop_task: - self._loop_task.cancel() + self._loop_task.cancel() # 取消主循环任务 try: - await self._loop_task + await self._loop_task # 等待任务完成 except asyncio.CancelledError: - logger.info(f"{self.log_prefix} HeartFChatting 主循环已取消") - except Exception as exc: - logger.error(f"{self.log_prefix} 停止 HeartFChatting 时发生错误: {exc}", exc_info=True) + logger.info(f"{self.log_prefix} HeartFChatting 主循环已成功取消") + except Exception as e: + logger.error(f"{self.log_prefix} 停止 HeartFChatting 时发生错误: {e}", exc_info=True) finally: - self._loop_task = None + self._loop_task = None # 确保任务引用被清理 logger.info(f"{self.log_prefix} HeartFChatting 已停止") def adjust_talk_frequency(self, new_value: float): + """调整发言频率的调整值 + + Args: + new_value: 新的修正值,必须为非负数。值越大,修正发言频率越高;值越小,修正发言频率越低。 + """ self._talk_frequency_adjust = max(0.0, new_value) async def register_message(self, message: "SessionMessage"): + """注册一条消息到 HeartFChatting 的缓存中,并检测其是否产生提及,决定是否唤醒聊天 + + Args: + message: 待注册的消息对象 + """ self.message_cache.append(message) - + # 先检查at必回复 if global_config.chat.inevitable_at_reply and message.is_at: - self.last_read_time = time.time() - async with self._hfc_lock: - await self._judge_and_response(mentioned_message=message, recent_messages_list=[message]) - return - + async with self._hfc_lock: # 确保与主循环逻辑的互斥访问 + await self._judge_and_response(message) + return # 直接返回,避免同一条消息被主循环再次处理 + # 再检查提及必回复 if global_config.chat.mentioned_bot_reply and message.is_mentioned: - self.last_read_time = time.time() - async with self._hfc_lock: - await self._judge_and_response(mentioned_message=message, recent_messages_list=[message]) + # 直接获取锁,确保一定一定触发回复逻辑,不受当前是否正在执行主循环的影响 + async with self._hfc_lock: # 确保与主循环逻辑的互斥访问 + await self._judge_and_response(message) return async def main_loop(self): try: while self._running and not self._cycle_event.is_set(): if not self._hfc_lock.locked(): - async with self._hfc_lock: + async with self._hfc_lock: # 确保主循环逻辑的互斥访问 await self._hfc_func() - await asyncio.sleep(0.1) + await asyncio.sleep(5) except asyncio.CancelledError: - logger.info(f"{self.log_prefix} HeartFChatting: 主循环被取消") - except Exception as exc: - logger.error(f"{self.log_prefix} HeartFChatting: 主循环异常: {exc}", exc_info=True) - await self.stop() + logger.info(f"{self.log_prefix} HeartFChatting: 主循环被取消,正在关闭") + except Exception as e: + logger.error(f"{self.log_prefix} 麦麦聊天意外错误: {e},将于3s后尝试重新启动") + await self.stop() # 确保状态正确 await asyncio.sleep(3) - await self.start() + await self.start() # 尝试重新启动 async def _config_callback(self, file_change: Optional[FileChange] = None): - del file_change - expr_use, jargon_learn, expr_learn = ExpressionConfigUtils.get_expression_config_for_chat(self.session_id) - self._enable_expression_use = expr_use - self._enable_expression_learning = expr_learn - self._enable_jargon_learning = jargon_learn + """配置文件变更回调函数""" + # TODO: 根据配置文件变动重新计算相关参数: + """ + 需要计算的参数: + self._enable_expression_use = expr_use # 允许使用表达方式,但不一定启用学习 + self._enable_expression_learning = expr_learn # 允许学习表达方式 + self._enable_jargon_learning = jargon_learn # 允许学习黑话 + """ - async def _hfc_func(self): - recent_messages_list = message_api.get_messages_by_time_in_chat( - chat_id=self.session_id, - start_time=self.last_read_time, - end_time=time.time(), - limit=20, - limit_mode="latest", - filter_mai=True, - filter_command=False, - filter_intercept_message_level=1, - ) + # ====== 心流聊天核心逻辑 ====== + async def _hfc_func(self, mentioned_message: Optional["SessionMessage"] = None): + """心流聊天的主循环逻辑""" + if self._consecutive_no_reply_count >= 5: + threshold = 2 + elif self._consecutive_no_reply_count >= 3: + threshold = 2 if random.random() < 0.5 else 1 + else: + threshold = 1 - if len(recent_messages_list) < 1: + if len(self.message_cache) < threshold: await asyncio.sleep(0.2) return True - self.last_read_time = time.time() - - mentioned_message: Optional["SessionMessage"] = None - for message in recent_messages_list: - if global_config.chat.inevitable_at_reply and message.is_at: - mentioned_message = message - elif global_config.chat.mentioned_bot_reply and message.is_mentioned: - mentioned_message = message - - talk_value = ChatConfigUtils.get_talk_value(self.session_id) * self._talk_frequency_adjust - if mentioned_message: - await self._judge_and_response(mentioned_message=mentioned_message, recent_messages_list=recent_messages_list) - elif random.random() < talk_value: - await self._judge_and_response(recent_messages_list=recent_messages_list) + talk_value_threshold = ( + random.random() * ChatConfigUtils.get_talk_value(self.session_id) * self._talk_frequency_adjust + ) + if mentioned_message and global_config.chat.mentioned_bot_reply: + await self._judge_and_response(mentioned_message) + elif random.random() < talk_value_threshold: + await self._judge_and_response() return True - async def _judge_and_response( - self, - mentioned_message: Optional["SessionMessage"] = None, - recent_messages_list: Optional[List["SessionMessage"]] = None, - ): - recent_messages = list(recent_messages_list or self.message_cache[-20:]) - if recent_messages: - asyncio.create_task(self._trigger_expression_learning(recent_messages)) - - cycle_timers, thinking_id = self._start_cycle() + async def _judge_and_response(self, mentioned_message: Optional["SessionMessage"] = None): + """判定和生成回复""" + asyncio.create_task(self._trigger_expression_learning(self.message_cache)) + # TODO: 完成反思器之后的逻辑 + start_time = time.time() + current_cycle_detail = self._start_cycle() logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考") - try: - async with global_prompt_manager.async_message_scope(self._get_template_name()): - available_actions: Dict[str, ActionInfo] = {} - try: - await self.action_modifier.modify_actions() - available_actions = self.action_manager.get_using_actions() - except Exception as exc: - logger.error(f"{self.log_prefix} 动作修改失败: {exc}", exc_info=True) + # TODO: 动作检查逻辑 + # TODO: Planner逻辑 + # TODO: 动作执行逻辑 - is_group_chat, chat_target_info, _ = self.action_planner.get_necessary_info() - message_list_before_now = get_messages_before_time_in_chat( - chat_id=self.session_id, - timestamp=time.time(), - limit=int(global_config.chat.max_context_size * 0.6), - filter_intercept_message_level=1, - ) - chat_content_block, message_id_list = build_readable_messages_with_id( - messages=message_list_before_now, - timestamp_mode="normal_no_YMD", - read_mark=self.action_planner.last_obs_time_mark, - truncate=True, - show_actions=True, - ) - - prompt, filtered_actions = await self._build_planner_prompt_with_event( - available_actions=available_actions, - is_group_chat=is_group_chat, - chat_target_info=chat_target_info, - chat_content_block=chat_content_block, - message_id_list=message_id_list, - ) - if prompt is None: - return False - - with Timer("规划器", cycle_timers): - reasoning, action_to_use_info, llm_raw_output, llm_reasoning, llm_duration_ms = ( - await self.action_planner._execute_main_planner( - prompt=prompt, - message_id_list=message_id_list, - filtered_actions=filtered_actions, - available_actions=available_actions, - loop_start_time=self.last_read_time, - ) - ) - - action_to_use_info = self._ensure_force_reply_action( - actions=action_to_use_info, - force_reply_message=mentioned_message, - available_actions=available_actions, - ) - self.action_planner.add_plan_log(reasoning, action_to_use_info) - self.action_planner.last_obs_time_mark = time.time() - self._log_plan( - prompt=prompt, - reasoning=reasoning, - llm_raw_output=llm_raw_output, - llm_reasoning=llm_reasoning, - llm_duration_ms=llm_duration_ms, - actions=action_to_use_info, - ) - - logger.info( - f"{self.log_prefix} 决定执行{len(action_to_use_info)}个动作: {' '.join([a.action_type for a in action_to_use_info])}" - ) - - action_tasks = [ - asyncio.create_task( - self._execute_action( - action, - action_to_use_info, - thinking_id, - available_actions, - cycle_timers, - ) - ) - for action in action_to_use_info - ] - results = await asyncio.gather(*action_tasks, return_exceptions=True) - - reply_loop_info = None - reply_text_from_reply = "" - action_success = False - action_reply_text = "" - execute_result_str = "" - - for result in results: - if isinstance(result, BaseException): - logger.error(f"{self.log_prefix} 动作执行异常: {result}", exc_info=True) - continue - - execute_result_str += f"{result['action_type']} 执行结果:{result['result']}\n" - if result["action_type"] == "reply": - if result["success"]: - reply_loop_info = result["loop_info"] - reply_text_from_reply = result["result"] - else: - logger.warning(f"{self.log_prefix} reply 动作执行失败") - else: - action_success = result["success"] - action_reply_text = result["result"] - - self.action_planner.add_plan_excute_log(result=execute_result_str) - - if reply_loop_info: - loop_info = reply_loop_info - loop_info["loop_action_info"].update( - { - "action_taken": action_success, - "taken_time": time.time(), - } - ) - else: - loop_info = { - "loop_plan_info": { - "action_result": action_to_use_info, - }, - "loop_action_info": { - "action_taken": action_success, - "reply_text": action_reply_text, - "taken_time": time.time(), - }, - } - reply_text_from_reply = action_reply_text - - current_cycle_detail = self._end_cycle(self._current_cycle_detail, loop_info) - logger.debug(f"{self.log_prefix} 本轮最终输出: {reply_text_from_reply}") - return current_cycle_detail is not None - except Exception as exc: - logger.error(f"{self.log_prefix} 判定与回复流程失败: {exc}", exc_info=True) - if self._current_cycle_detail: - self._end_cycle( - self._current_cycle_detail, - { - "loop_plan_info": {"action_result": []}, - "loop_action_info": { - "action_taken": False, - "reply_text": "", - "taken_time": time.time(), - "error": str(exc), - }, - }, - ) - return False + cycle_detail = self._end_cycle(current_cycle_detail) + if wait_time := global_config.chat.planner_smooth - (time.time() - start_time) > 0: + await asyncio.sleep(wait_time) + else: + await asyncio.sleep(0.1) # 最小等待时间,避免过快循环 + return True def _handle_loop_completion(self, task: asyncio.Task): + """当 _hfc_func 任务完成时执行的回调。""" try: if exception := task.exception(): - logger.error(f"{self.log_prefix} HeartFChatting: 主循环异常退出: {exception}") - logger.error(traceback.format_exc()) + logger.error(f"{self.log_prefix} HeartFChatting: 脱离了聊天(异常): {exception}") + logger.error(traceback.format_exc()) # Log full traceback for exceptions else: - logger.info(f"{self.log_prefix} HeartFChatting: 主循环已退出") + logger.info(f"{self.log_prefix} HeartFChatting: 脱离了聊天 (外部停止)") except asyncio.CancelledError: - logger.info(f"{self.log_prefix} HeartFChatting: 聊天已结束") + logger.info(f"{self.log_prefix} HeartFChatting: 结束了聊天") + # ====== 学习器触发逻辑 ====== async def _trigger_expression_learning(self, messages: List["SessionMessage"]): - if not messages: - return - self._expression_learner.add_messages(messages) if time.time() - self._last_extraction_time < self._min_extraction_interval: return @@ -379,14 +233,12 @@ class HeartFChatting: return if not self._enable_expression_learning: return - extraction_end_time = time.time() logger.info( f"聊天流 {self.session_name} 提取到 {len(messages)} 条消息," f"时间窗口: {self._last_extraction_time:.2f} - {extraction_end_time:.2f}" ) self._last_extraction_time = extraction_end_time - try: jargon_miner = self._jargon_miner if self._enable_jargon_learning else None learnt_style = await self._expression_learner.learn(jargon_miner) @@ -394,398 +246,43 @@ class HeartFChatting: logger.info(f"{self.log_prefix} 表达学习完成") else: logger.debug(f"{self.log_prefix} 表达学习未获得有效结果") - except Exception as exc: - logger.error(f"{self.log_prefix} 表达学习失败: {exc}", exc_info=True) + except Exception as e: + logger.error(f"{self.log_prefix} 表达学习失败: {e}", exc_info=True) - def _start_cycle(self) -> Tuple[Dict[str, float], str]: + # ====== 记录循环执行信息相关逻辑 ====== + def _start_cycle(self) -> CycleDetail: self._cycle_counter += 1 - self._current_cycle_detail = CycleDetail(cycle_id=self._cycle_counter) - self._current_cycle_detail.thinking_id = f"tid{str(round(time.time(), 2))}" - return self._current_cycle_detail.time_records, self._current_cycle_detail.thinking_id + current_cycle_detail = CycleDetail(cycle_id=self._cycle_counter) + current_cycle_detail.thinking_id = f"tid{str(round(time.time(), 2))}" + return current_cycle_detail - def _end_cycle(self, cycle_detail: Optional[CycleDetail], loop_info: Optional[Dict[str, Any]] = None): - if cycle_detail is None: - return None - - cycle_detail.loop_plan_info = (loop_info or {}).get("loop_plan_info") - cycle_detail.loop_action_info = (loop_info or {}).get("loop_action_info") + def _end_cycle(self, cycle_detail: CycleDetail, only_long_execution: bool = True): cycle_detail.end_time = time.time() - self.history_loop.append(cycle_detail) - - timer_strings = [ + timer_strings: List[str] = [ f"{name}: {duration:.2f}s" for name, duration in cycle_detail.time_records.items() - if duration >= 0.1 + if not only_long_execution or duration >= 0.1 ] logger.info( - f"{self.log_prefix} 第{cycle_detail.cycle_id} 个心流循环完成," - f"耗时: {cycle_detail.end_time - cycle_detail.start_time:.2f}s;" + f"{self.log_prefix} 第 {cycle_detail.cycle_id} 个心流循环完成" + f"耗时: {cycle_detail.end_time - cycle_detail.start_time:.2f}秒\n" f"详细计时: {', '.join(timer_strings) if timer_strings else '无'}" ) + return cycle_detail - async def _execute_action( - self, - action_planner_info: ActionPlannerInfo, - chosen_action_plan_infos: List[ActionPlannerInfo], - thinking_id: str, - available_actions: Dict[str, ActionInfo], - cycle_timers: Dict[str, float], - ): - try: - with Timer(f"动作{action_planner_info.action_type}", cycle_timers): - if action_planner_info.action_type == "no_reply": - reason = action_planner_info.reasoning or "选择不回复" - self._consecutive_no_reply_count += 1 - await database_api.store_action_info( - chat_stream=self.chat_stream, - display_prompt=reason, - thinking_id=thinking_id, - action_data={}, - action_name="no_reply", - action_reasoning=reason, - ) - return { - "action_type": "no_reply", - "success": True, - "result": "选择不回复", - "loop_info": None, - } + # ====== Action相关逻辑 ====== + async def _execute_action(self, *args, **kwargs): + """原ExecuteAction""" + raise NotImplementedError("执行动作的逻辑尚未实现") # TODO: 实现动作执行的逻辑,替换掉*args, **kwargs*占位符 - if action_planner_info.action_type == "reply": - self._consecutive_no_reply_count = 0 - reason = action_planner_info.reasoning or "" - think_level = self._get_think_level(action_planner_info) - planner_reasoning = action_planner_info.action_reasoning or reason + async def _execute_other_actions(self, *args, **kwargs): + """原HandleAction""" + raise NotImplementedError( + "执行其他动作的逻辑尚未实现" + ) # TODO: 实现其他动作执行的逻辑, 替换掉*args, **kwargs*占位符 - record_replyer_action_temp( - chat_id=self.session_id, - reason=reason, - think_level=think_level, - ) - await database_api.store_action_info( - chat_stream=self.chat_stream, - display_prompt=reason, - thinking_id=thinking_id, - action_data={}, - action_name="reply", - action_reasoning=reason, - ) - - unknown_words, quote_message = self._extract_reply_metadata(action_planner_info) - success, llm_response = await generator_api.generate_reply( - chat_stream=self.chat_stream, - reply_message=action_planner_info.action_message, - available_actions=available_actions, - chosen_actions=chosen_action_plan_infos, - reply_reason=planner_reasoning, - unknown_words=unknown_words, - enable_tool=global_config.tool.enable_tool, - request_type="replyer", - from_plugin=False, - reply_time_point=action_planner_info.action_data.get("loop_start_time", time.time()) - if action_planner_info.action_data - else time.time(), - think_level=think_level, - ) - if not success or not llm_response or not llm_response.reply_set: - if action_planner_info.action_message: - logger.info( - f"对 {action_planner_info.action_message.processed_plain_text} 的回复生成失败" - ) - else: - logger.info(f"{self.log_prefix} 回复生成失败") - return { - "action_type": "reply", - "success": False, - "result": "回复生成失败", - "loop_info": None, - } - - loop_info, reply_text, _ = await self._send_and_store_reply( - response_set=llm_response.reply_set, - action_message=action_planner_info.action_message, # type: ignore[arg-type] - cycle_timers=cycle_timers, - thinking_id=thinking_id, - actions=chosen_action_plan_infos, - selected_expressions=llm_response.selected_expressions, - quote_message=quote_message, - ) - self.last_active_time = time.time() - return { - "action_type": "reply", - "success": True, - "result": reply_text, - "loop_info": loop_info, - } - - with Timer("动作执行", cycle_timers): - success, result = await self._handle_action( - action=action_planner_info.action_type, - action_reasoning=action_planner_info.action_reasoning or "", - action_data=action_planner_info.action_data or {}, - cycle_timers=cycle_timers, - thinking_id=thinking_id, - action_message=action_planner_info.action_message, - ) - if success: - self.last_active_time = time.time() - return { - "action_type": action_planner_info.action_type, - "success": success, - "result": result, - "loop_info": None, - } - except Exception as exc: - logger.error(f"{self.log_prefix} 执行动作时出错: {exc}", exc_info=True) - return { - "action_type": action_planner_info.action_type, - "success": False, - "result": "", - "loop_info": None, - "error": str(exc), - } - - async def _handle_action( - self, - action: str, - action_reasoning: str, - action_data: dict, - cycle_timers: Dict[str, float], - thinking_id: str, - action_message: Optional["SessionMessage"] = None, - ) -> Tuple[bool, str]: - try: - action_handler = self.action_manager.create_action( - action_name=action, - action_data=action_data, - action_reasoning=action_reasoning, - cycle_timers=cycle_timers, - thinking_id=thinking_id, - chat_stream=self.chat_stream, - log_prefix=self.log_prefix, - action_message=action_message, - ) - if not action_handler: - logger.warning(f"{self.log_prefix} 未能创建动作处理器: {action}") - return False, "" - - success, action_text = await action_handler.execute() - return success, action_text - except Exception as exc: - logger.error(f"{self.log_prefix} 处理动作 {action} 时出错: {exc}", exc_info=True) - return False, "" - - async def _send_and_store_reply( - self, - response_set: MessageSequence, - action_message: "SessionMessage", - cycle_timers: Dict[str, float], - thinking_id: str, - actions: List[ActionPlannerInfo], - selected_expressions: Optional[List[int]] = None, - quote_message: Optional[bool] = None, - ) -> Tuple[Dict[str, Any], str, Dict[str, float]]: - with Timer("回复发送", cycle_timers): - reply_text = await self._send_response( - reply_set=response_set, - message_data=action_message, - selected_expressions=selected_expressions, - quote_message=quote_message, - ) - - platform = action_message.platform or getattr(self.chat_stream, "platform", "unknown") - person = Person(platform=platform, user_id=action_message.message_info.user_info.user_id) - action_prompt_display = f"你对{person.person_name}进行了回复:{reply_text}" - await database_api.store_action_info( - chat_stream=self.chat_stream, - display_prompt=action_prompt_display, - thinking_id=thinking_id, - action_data={"reply_text": reply_text}, - action_name="reply", - ) - - loop_info: Dict[str, Any] = { - "loop_plan_info": { - "action_result": actions, - }, - "loop_action_info": { - "action_taken": True, - "reply_text": reply_text, - "command": "", - "taken_time": time.time(), - }, - } - return loop_info, reply_text, cycle_timers - - async def _send_response( - self, - reply_set: MessageSequence, - message_data: "SessionMessage", - selected_expressions: Optional[List[int]] = None, - quote_message: Optional[bool] = None, - ) -> str: - if global_config.chat.llm_quote: - need_reply = bool(quote_message) - else: - new_message_count = message_api.count_new_messages( - chat_id=self.session_id, - start_time=self.last_read_time, - end_time=time.time(), - ) - need_reply = new_message_count >= random.randint(2, 3) or time.time() - self.last_read_time > 90 - - reply_text = "" - first_replied = False - for component in reply_set.components: - if not isinstance(component, TextComponent): - continue - data = component.text - if not first_replied: - await send_api.text_to_stream( - text=data, - stream_id=self.session_id, - reply_message=message_data, - set_reply=need_reply, - typing=False, - selected_expressions=selected_expressions, - ) - first_replied = True - else: - await send_api.text_to_stream( - text=data, - stream_id=self.session_id, - reply_message=message_data, - set_reply=False, - typing=True, - selected_expressions=selected_expressions, - ) - reply_text += data - return reply_text - - async def _build_planner_prompt_with_event( - self, - available_actions: Dict[str, ActionInfo], - is_group_chat: bool, - chat_target_info: Any, - chat_content_block: str, - message_id_list: List[Tuple[str, "SessionMessage"]], - ) -> Tuple[Optional[str], Dict[str, ActionInfo]]: - filtered_actions = self.action_planner._filter_actions_by_activation_type(available_actions, chat_content_block) - prompt, _ = await self.action_planner.build_planner_prompt( - is_group_chat=is_group_chat, - chat_target_info=chat_target_info, - current_available_actions=filtered_actions, - chat_content_block=chat_content_block, - message_id_list=message_id_list, - ) - event_message = build_event_message(EventType.ON_PLAN, llm_prompt=prompt, stream_id=self.session_id) - continue_flag, modified_message = await event_bus.emit(EventType.ON_PLAN, event_message) - if not continue_flag: - logger.info(f"{self.log_prefix} ON_PLAN 事件中止了本轮 HFC") - return None, filtered_actions - if modified_message and modified_message._modify_flags.modify_llm_prompt and modified_message.llm_prompt: - prompt = modified_message.llm_prompt - return prompt, filtered_actions - - def _ensure_force_reply_action( - self, - actions: List[ActionPlannerInfo], - force_reply_message: Optional["SessionMessage"], - available_actions: Dict[str, ActionInfo], - ) -> List[ActionPlannerInfo]: - if not force_reply_message: - return actions - - has_reply_to_force_message = any( - action.action_type == "reply" - and action.action_message - and action.action_message.message_id == force_reply_message.message_id - for action in actions - ) - if has_reply_to_force_message: - return actions - - actions = [action for action in actions if action.action_type != "no_reply"] - actions.insert( - 0, - ActionPlannerInfo( - action_type="reply", - reasoning="用户提及了我,必须回复该消息", - action_data={"loop_start_time": self.last_read_time}, - action_message=force_reply_message, - available_actions=available_actions, - action_reasoning=None, - ), - ) - logger.info(f"{self.log_prefix} 检测到强制回复消息,已补充 reply 动作") - return actions - - def _log_plan( - self, - prompt: str, - reasoning: str, - llm_raw_output: Optional[str], - llm_reasoning: Optional[str], - llm_duration_ms: Optional[float], - actions: List[ActionPlannerInfo], - ) -> None: - try: - PlanReplyLogger.log_plan( - chat_id=self.session_id, - prompt=prompt, - reasoning=reasoning, - raw_output=llm_raw_output, - raw_reasoning=llm_reasoning, - actions=actions, - timing={ - "llm_duration_ms": round(llm_duration_ms, 2) if llm_duration_ms is not None else None, - "loop_start_time": self.last_read_time, - }, - extra=None, - ) - except Exception: - logger.exception(f"{self.log_prefix} 记录 plan 日志失败") - - def _extract_reply_metadata( - self, - action_planner_info: ActionPlannerInfo, - ) -> Tuple[Optional[List[str]], Optional[bool]]: - unknown_words: Optional[List[str]] = None - quote_message: Optional[bool] = None - action_data = action_planner_info.action_data or {} - - raw_unknown_words = action_data.get("unknown_words") - if isinstance(raw_unknown_words, list): - cleaned_unknown_words = [] - for item in raw_unknown_words: - if isinstance(item, str) and (cleaned_item := item.strip()): - cleaned_unknown_words.append(cleaned_item) - if cleaned_unknown_words: - unknown_words = cleaned_unknown_words - - raw_quote = action_data.get("quote") - if isinstance(raw_quote, bool): - quote_message = raw_quote - elif isinstance(raw_quote, str): - quote_message = raw_quote.lower() in {"true", "1", "yes"} - elif isinstance(raw_quote, (int, float)): - quote_message = bool(raw_quote) - - return unknown_words, quote_message - - def _get_think_level(self, action_planner_info: ActionPlannerInfo) -> int: - think_mode = global_config.chat.think_mode - if think_mode == "default": - return 0 - if think_mode == "deep": - return 1 - if think_mode == "dynamic": - action_data = action_planner_info.action_data or {} - return int(action_data.get("think_level", 1)) - return 0 - - def _get_template_name(self) -> Optional[str]: - if self.chat_stream.context: - return self.chat_stream.context.template_name - return None + # ====== 响应发送相关方法 ====== + async def _send_response(self, *args, **kwargs): + raise NotImplementedError("发送回复的逻辑尚未实现") # TODO: 实现发送回复的逻辑,替换掉*args, **kwargs*占位符 + # 传入的消息至少应该是个MessageSequence实例,最好是SessionMessage实例,随后可直接转化为MessageSending实例 diff --git a/src/chat/heart_flow/heartflow.py b/src/chat/heart_flow/heartflow.py deleted file mode 100644 index febff2d5..00000000 --- a/src/chat/heart_flow/heartflow.py +++ /dev/null @@ -1,42 +0,0 @@ -import traceback -from typing import Any, Optional, Dict - -from src.chat.message_receive.chat_stream import get_chat_manager -from src.common.logger import get_logger -from src.chat.heart_flow.heartFC_chat import HeartFChatting -from src.chat.brain_chat.brain_chat import BrainChatting -from src.chat.message_receive.chat_stream import ChatStream - -logger = get_logger("heartflow") - - -class Heartflow: - """主心流协调器,负责初始化并协调聊天""" - - def __init__(self): - self.heartflow_chat_list: Dict[Any, HeartFChatting | BrainChatting] = {} - - async def get_or_create_heartflow_chat(self, chat_id: Any) -> Optional[HeartFChatting | BrainChatting]: - """获取或创建一个新的HeartFChatting实例""" - try: - if chat_id in self.heartflow_chat_list: - if chat := self.heartflow_chat_list.get(chat_id): - return chat - else: - chat_stream: ChatStream | None = get_chat_manager().get_stream(chat_id) - if not chat_stream: - raise ValueError(f"未找到 chat_id={chat_id} 的聊天流") - if chat_stream.group_info: - new_chat = HeartFChatting(chat_id=chat_id) - else: - new_chat = BrainChatting(chat_id=chat_id) - await new_chat.start() - self.heartflow_chat_list[chat_id] = new_chat - return new_chat - except Exception as e: - logger.error(f"创建心流聊天 {chat_id} 失败: {e}", exc_info=True) - traceback.print_exc() - return None - - -heartflow = Heartflow() diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py index df7d28fc..1fc4ef53 100644 --- a/src/chat/message_receive/bot.py +++ b/src/chat/message_receive/bot.py @@ -1,19 +1,20 @@ from contextlib import suppress -import traceback -import os - -from maim_message import MessageBase from typing import Any, Dict, Optional +import os +import traceback +from maim_message import MessageBase + +from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver from src.common.logger import get_logger from src.common.utils.utils_message import MessageUtils from src.common.utils.utils_session import SessionUtils -from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver +from src.platform_io.route_key_factory import RouteKeyFactory # from src.chat.brain_chat.PFC.pfc_manager import PFCManager from src.core.announcement_manager import global_announcement_manager -from src.core.component_registry import component_registry +from src.plugin_runtime.component_query import component_query_service from .message import SessionMessage from .chat_manager import chat_manager @@ -58,16 +59,22 @@ class ChatBot: logger.error(f"创建PFC聊天失败: {e}") logger.error(traceback.format_exc()) - async def _process_commands(self, message: SessionMessage): - # sourcery skip: use-named-expression - """使用新插件系统处理命令""" + async def _process_commands(self, message: SessionMessage) -> tuple[bool, Optional[str], bool]: + """使用统一组件注册表处理命令。 + + Args: + message: 当前待处理的会话消息。 + + Returns: + tuple[bool, Optional[str], bool]: ``(是否命中命令, 命令响应文本, 是否继续后续处理)``。 + """ if not message.processed_plain_text: return False, None, True # 没有文本内容,继续处理消息 try: text = message.processed_plain_text - # 使用核心组件注册表查找命令 - command_result = component_registry.find_command_by_text(text) + # 使用插件运行时统一查询服务查找命令 + command_result = component_query_service.find_command_by_text(text) if command_result: command_executor, matched_groups, command_info = command_result plugin_name = command_info.plugin_name @@ -81,7 +88,7 @@ class ChatBot: message.is_command = True # 获取插件配置 - plugin_config = component_registry.get_plugin_config(plugin_name) + plugin_config = component_query_service.get_plugin_config(plugin_name) try: # 调用命令执行器 @@ -112,88 +119,32 @@ class ChatBot: # 命令出错时,根据命令的拦截设置决定是否继续处理消息 return True, str(e), False # 出错时继续处理消息 - # 没有找到旧系统命令,尝试新版本插件运行时 - new_cmd_result = await self._process_new_runtime_command(message) - return new_cmd_result if new_cmd_result is not None else (False, None, True) + return False, None, True except Exception as e: logger.error(f"处理命令时出错: {e}") return False, None, True # 出错时继续处理消息 - async def _process_new_runtime_command(self, message: SessionMessage): - """尝试在新版本插件运行时中查找并执行命令 - - Returns: - (found, response, continue_processing) 三元组, - 或 None 表示新运行时中也未找到匹配命令。 - """ - from src.plugin_runtime.integration import get_plugin_runtime_manager - - prm = get_plugin_runtime_manager() - if not prm.is_running: - return None - - matched = prm.find_command_by_text(message.processed_plain_text) - if matched is None: - return None - - command_name = matched["name"] - if message.session_id and command_name in global_announcement_manager.get_disabled_chat_commands( - message.session_id - ): - logger.info(f"[新运行时] 用户禁用的命令,跳过处理: {matched['full_name']}") - return False, None, True - - message.is_command = True - logger.info(f"[新运行时] 匹配命令: {matched['full_name']}") - - try: - resp = await prm.invoke_plugin( - method="plugin.invoke_command", - plugin_id=matched["plugin_id"], - component_name=matched["name"], - args={ - "text": message.processed_plain_text, - "stream_id": message.session_id or "", - "matched_groups": matched.get("matched_groups") or {}, - }, - timeout_ms=30000, - ) - - payload = resp.payload - success = payload.get("success", False) - cmd_result = payload.get("result") - - # 拦截位优先从命令返回值中获取(支持运行时动态决定), - # 回退到组件 metadata 中的静态声明 - if isinstance(cmd_result, (list, tuple)) and len(cmd_result) >= 3: - # 命令返回 (found, response_text, intercept_bool) 三元组 - response_text = cmd_result[1] if cmd_result[1] is not None else "" - intercept = bool(cmd_result[2]) - else: - response_text = cmd_result if cmd_result is not None else "" - intercept = bool(matched["metadata"].get("intercept_message_level", 0)) - - self._mark_command_message(message, int(intercept)) - - if success: - logger.info(f"[新运行时] 命令执行成功: {matched['full_name']}") - else: - logger.warning(f"[新运行时] 命令执行失败: {matched['full_name']} - {response_text}") - - return True, response_text, not intercept - - except Exception as e: - logger.error(f"[新运行时] 执行命令 {matched['full_name']} 异常: {e}", exc_info=True) - return True, str(e), True - @staticmethod def _mark_command_message(message: SessionMessage, intercept_message_level: int) -> None: + """标记消息已经被命令链消费。 + + Args: + message: 待标记的会话消息。 + intercept_message_level: 命令设置的拦截级别。 + """ + message.is_command = True message.message_info.additional_config["intercept_message_level"] = intercept_message_level @staticmethod def _store_intercepted_command_message(message: SessionMessage) -> None: + """将被命令链拦截的消息写入数据库。 + + Args: + message: 已完成命令处理的会话消息。 + """ + MessageUtils.store_message_to_db(message) async def _handle_command_processing_result( @@ -310,13 +261,28 @@ class ChatBot: # logger.debug(str(message_data)) maim_raw_message = MessageBase.from_dict(message_data) message = SessionMessage.from_maim_message(maim_raw_message) + await self.receive_message(message) + + except Exception as e: + logger.error(f"预处理消息失败: {e}") + traceback.print_exc() + + async def receive_message(self, message: SessionMessage): + try: group_info = message.message_info.group_info user_info = message.message_info.user_info + account_id = None + scope = None + additional_config = message.message_info.additional_config + if isinstance(additional_config, dict): + account_id, scope = RouteKeyFactory.extract_components(additional_config) session_id = SessionUtils.calculate_session_id( message.platform, user_id=message.message_info.user_info.user_id, group_id=group_info.group_id if group_info else None, + account_id=account_id, + scope=scope, ) message.session_id = session_id # 正确初始化session_id @@ -359,24 +325,24 @@ class ChatBot: platform = message.platform user_id = user_info.user_id group_id = group_info.group_id if group_info else None - _ = await chat_manager.get_or_create_session(platform, user_id, group_id) # 确保会话存在 - try: - from src.services.memory_flow_service import memory_automation_service - - await memory_automation_service.on_incoming_message(message) - except Exception as exc: - logger.warning(f"[长期记忆自动总结] 注册会话总结器失败: {exc}") + _ = await chat_manager.get_or_create_session( + platform, + user_id, + group_id, + account_id=account_id, + scope=scope, + ) # 确保会话存在 # message.update_chat_stream(chat) # 命令处理 - 使用新插件系统检查并处理命令 # 注意:命令返回的 response 当前只用于日志记录和流程判断, # 不会在这里自动作为回复消息发送回会话。 - is_command, cmd_result, continue_process = await self._process_commands(message) + # is_command, cmd_result, continue_process = await self._process_commands(message) - # 如果是命令且不需要继续处理,则直接返回 - if is_command and await self._handle_command_processing_result(message, cmd_result, continue_process): - return + # # 如果是命令且不需要继续处理,则直接返回 + # if is_command and await self._handle_command_processing_result(message, cmd_result, continue_process): + # return # continue_flag, modified_message = await events_manager.handle_mai_events(EventType.ON_MESSAGE, message) # if not continue_flag: diff --git a/src/chat/message_receive/chat_manager.py b/src/chat/message_receive/chat_manager.py index b11d233c..48d89956 100644 --- a/src/chat/message_receive/chat_manager.py +++ b/src/chat/message_receive/chat_manager.py @@ -1,15 +1,16 @@ +import asyncio from datetime import datetime +from typing import TYPE_CHECKING, Dict, List, Optional + from rich.traceback import install from sqlmodel import select -from typing import Optional, TYPE_CHECKING, List, Dict -import asyncio - -from src.common.logger import get_logger from src.common.data_models.chat_session_data_model import MaiChatSession -from src.common.database.database_model import ChatSession from src.common.database.database import get_db_session +from src.common.database.database_model import ChatSession +from src.common.logger import get_logger from src.common.utils.utils_session import SessionUtils +from src.platform_io.route_key_factory import RouteKeyFactory if TYPE_CHECKING: from .message import SessionMessage @@ -82,7 +83,12 @@ class ChatManager: logger.error(f"初始化聊天管理器出现错误: {e}") async def get_or_create_session( - self, platform: str, user_id: str, group_id: Optional[str] = None + self, + platform: str, + user_id: str, + group_id: Optional[str] = None, + account_id: Optional[str] = None, + scope: Optional[str] = None, ) -> BotChatSession: """获取会话,如果不存在则创建一个新会话;一个封装方法。 @@ -90,12 +96,20 @@ class ChatManager: platform: 平台 user_id: 用户ID group_id: 群ID(如果是群聊) + account_id: 平台账号 ID + scope: 路由作用域 Returns: return (BotChatSession) 会话对象 Raises: Exception: 获取或创建会话时发生错误 """ - session_id = SessionUtils.calculate_session_id(platform, user_id=user_id, group_id=group_id) + session_id = SessionUtils.calculate_session_id( + platform, + user_id=user_id, + group_id=group_id, + account_id=account_id, + scope=scope, + ) if session := self.get_session_by_session_id(session_id): session.update_active_time() return session @@ -131,7 +145,18 @@ class ChatManager: raise ValueError("消息缺少平台信息") user_id = message.message_info.user_info.user_id group_id = message.message_info.group_info.group_id if message.message_info.group_info else None - session_id = SessionUtils.calculate_session_id(platform, user_id=user_id, group_id=group_id) + account_id = None + scope = None + additional_config = message.message_info.additional_config + if isinstance(additional_config, dict): + account_id, scope = RouteKeyFactory.extract_components(additional_config) + session_id = SessionUtils.calculate_session_id( + platform, + user_id=user_id, + group_id=group_id, + account_id=account_id, + scope=scope, + ) message.session_id = session_id # 确保消息的session_id正确设置 self.last_messages[session_id] = message @@ -188,7 +213,12 @@ class ChatManager: return None def get_session_by_info( - self, platform: str, user_id: Optional[str] = None, group_id: Optional[str] = None + self, + platform: str, + user_id: Optional[str] = None, + group_id: Optional[str] = None, + account_id: Optional[str] = None, + scope: Optional[str] = None, ) -> Optional[BotChatSession]: """根据平台、用户ID和群ID获取对应的会话 @@ -196,10 +226,18 @@ class ChatManager: platform: 平台 user_id: 用户ID group_id: 群ID(如果是群聊) + account_id: 平台账号 ID + scope: 路由作用域 Returns: return (Optional[BotChatSession]): 会话对象,如果不存在则返回None """ - session_id = SessionUtils.calculate_session_id(platform, user_id=user_id, group_id=group_id) + session_id = SessionUtils.calculate_session_id( + platform, + user_id=user_id, + group_id=group_id, + account_id=account_id, + scope=scope, + ) return self.get_session_by_session_id(session_id) def get_session_by_session_id(self, session_id: str) -> Optional[BotChatSession]: diff --git a/src/chat/message_receive/uni_message_sender.py b/src/chat/message_receive/uni_message_sender.py index 369c0c51..246a8350 100644 --- a/src/chat/message_receive/uni_message_sender.py +++ b/src/chat/message_receive/uni_message_sender.py @@ -1,31 +1,37 @@ -from rich.traceback import install -from typing import Optional +from typing import Any, Optional, Tuple import asyncio +import traceback +from rich.traceback import install -from src.common.message_server.api import get_global_api -from src.common.logger import get_logger -from src.common.database.database import get_db_session from src.chat.message_receive.message import SessionMessage +from src.chat.utils.utils import calculate_typing_time, truncate_message from src.common.data_models.message_component_data_model import ReplyComponent -from src.chat.utils.utils import truncate_message -from src.chat.utils.utils import calculate_typing_time +from src.common.database.database import get_db_session +from src.common.logger import get_logger +from src.common.message_server.api import get_global_api +from src.webui.routers.chat.serializers import serialize_message_sequence install(extra_lines=3) logger = get_logger("sender") # WebUI 聊天室的消息广播器(延迟导入避免循环依赖) -_webui_chat_broadcaster = None +_webui_chat_broadcaster: Optional[Tuple[Any, Optional[str]]] = None # 虚拟群 ID 前缀(与 chat_routes.py 保持一致) VIRTUAL_GROUP_ID_PREFIX = "webui_virtual_group_" # TODO: 重构完成后完成webui相关 -def get_webui_chat_broadcaster(): - """获取 WebUI 聊天室广播器""" +def get_webui_chat_broadcaster() -> Tuple[Any, Optional[str]]: + """获取 WebUI 聊天室广播器。 + + Returns: + Tuple[Any, Optional[str]]: ``(chat_manager, platform_name)`` 二元组; + 若 WebUI 相关模块不可用,则元素会退化为 ``None``。 + """ global _webui_chat_broadcaster if _webui_chat_broadcaster is None: try: @@ -38,102 +44,35 @@ def get_webui_chat_broadcaster(): def is_webui_virtual_group(group_id: str) -> bool: - """检查是否是 WebUI 虚拟群""" - return group_id and group_id.startswith(VIRTUAL_GROUP_ID_PREFIX) - - -def parse_message_segments(segment) -> list: - """解析消息段,转换为 WebUI 可用的格式 - - 参考 NapCat 适配器的消息解析逻辑 + """检查是否是 WebUI 虚拟群。 Args: - segment: Seg 消息段对象 + group_id: 待判断的群 ID。 Returns: - list: 消息段列表,每个元素为 {"type": "...", "data": ...} + bool: 若群 ID 属于 WebUI 虚拟群则返回 ``True``。 """ - - result = [] - - if segment is None: - return result - - if segment.type == "seglist": - # 处理消息段列表 - if segment.data: - for seg in segment.data: - result.extend(parse_message_segments(seg)) - elif segment.type == "text": - # 文本消息 - if segment.data: - result.append({"type": "text", "data": segment.data}) - elif segment.type == "image": - # 图片消息(base64) - if segment.data: - result.append({"type": "image", "data": f"data:image/png;base64,{segment.data}"}) - elif segment.type == "emoji": - # 表情包消息(base64) - if segment.data: - result.append({"type": "emoji", "data": f"data:image/gif;base64,{segment.data}"}) - elif segment.type == "imageurl": - # 图片链接消息 - if segment.data: - result.append({"type": "image", "data": segment.data}) - elif segment.type == "face": - # 原生表情 - result.append({"type": "face", "data": segment.data}) - elif segment.type == "voice": - # 语音消息(base64) - if segment.data: - result.append({"type": "voice", "data": f"data:audio/wav;base64,{segment.data}"}) - elif segment.type == "voiceurl": - # 语音链接 - if segment.data: - result.append({"type": "voice", "data": segment.data}) - elif segment.type == "video": - # 视频消息(base64) - if segment.data: - result.append({"type": "video", "data": f"data:video/mp4;base64,{segment.data}"}) - elif segment.type == "videourl": - # 视频链接 - if segment.data: - result.append({"type": "video", "data": segment.data}) - elif segment.type == "music": - # 音乐消息 - result.append({"type": "music", "data": segment.data}) - elif segment.type == "file": - # 文件消息 - result.append({"type": "file", "data": segment.data}) - elif segment.type == "reply": - # 回复消息 - result.append({"type": "reply", "data": segment.data}) - elif segment.type == "forward": - # 转发消息 - forward_items = [] - if segment.data: - for item in segment.data: - forward_items.append( - { - "content": parse_message_segments(item.get("message_segment", {})) - if isinstance(item, dict) - else [] - } - ) - result.append({"type": "forward", "data": forward_items}) - else: - # 未知类型,尝试作为文本处理 - if segment.data: - result.append({"type": "unknown", "original_type": segment.type, "data": str(segment.data)}) - - return result + return bool(group_id) and group_id.startswith(VIRTUAL_GROUP_ID_PREFIX) -async def _send_message(message: MessageSending, show_log=True) -> bool: - """合并后的消息发送函数,包含WS发送和日志记录""" +async def _send_message(message: SessionMessage, show_log: bool = True) -> bool: + """执行统一的消息发送流程。 + + 发送顺序为: + 1. WebUI 特殊链路 + 2. 旧版 ``maim_message`` / API Server 链路 + + Args: + message: 待发送的内部会话消息。 + show_log: 是否输出发送成功日志。 + + Returns: + bool: 是否最终发送成功。 + """ message_preview = truncate_message(message.processed_plain_text, max_length=200) platform = message.platform - group_id = message.session.group_id + group_info = message.message_info.group_info + group_id = group_info.group_id if group_info is not None else "" try: # 检查是否是 WebUI 平台的消息,或者是 WebUI 虚拟群的消息 @@ -146,7 +85,7 @@ async def _send_message(message: MessageSending, show_log=True) -> bool: from src.config.config import global_config # 解析消息段,获取富文本内容 - message_segments = parse_message_segments(message.message_segment) + message_segments = serialize_message_sequence(message.raw_message) # 判断消息类型 # 如果只有一个文本段,使用简单的 text 类型 @@ -185,7 +124,15 @@ async def _send_message(message: MessageSending, show_log=True) -> bool: return True # Fallback 逻辑: 尝试通过 API Server 发送 - async def send_with_new_api(legacy_exception=None): + async def send_with_new_api(legacy_exception: Optional[Exception] = None) -> bool: + """通过 API Server 回退链路发送消息。 + + Args: + legacy_exception: 旧发送链已经抛出的异常;若回退也失败,则重新抛出。 + + Returns: + bool: 回退链路是否发送成功。 + """ try: from src.config.config import global_config @@ -286,10 +233,24 @@ async def _send_message(message: MessageSending, show_log=True) -> bool: raise e # 重新抛出其他异常 -class UniversalMessageSender: - """管理消息的注册、即时处理、发送和存储,并跟踪思考状态。""" +async def send_prepared_message_to_platform(message: SessionMessage, show_log: bool = True) -> bool: + """发送一条已完成预处理的消息到底层平台。 - def __init__(self): + Args: + message: 已经完成回复组件注入、文本处理等预处理的消息对象。 + show_log: 是否输出发送成功日志。 + + Returns: + bool: 发送成功时返回 ``True``。 + """ + return await _send_message(message, show_log=show_log) + + +class UniversalMessageSender: + """旧链与 WebUI 的底层发送器。""" + + def __init__(self) -> None: + """初始化统一消息发送器。""" pass async def send_message( @@ -300,18 +261,19 @@ class UniversalMessageSender: reply_message_id: Optional[str] = None, storage_message: bool = True, show_log: bool = True, - ): - """ - 处理、发送并存储一条消息。 + ) -> bool: + """通过旧链或 WebUI 发送并存储一条消息。 - 参数: - message: MessageSession 对象,待发送的消息。 + Args: + message: 待发送的内部消息对象。 typing: 是否模拟打字等待。 - set_reply: 是否构建回复引用消息。 + set_reply: 是否构建引用回复消息。 + reply_message_id: 被引用消息的 ID。 + storage_message: 是否在发送成功后写入数据库。 + show_log: 是否输出发送日志。 - - 用法: - - typing=True 时,发送前会有打字等待。 + Returns: + bool: 发送成功时返回 ``True``。 """ if not message.message_id: logger.error("消息缺少 message_id,无法发送") @@ -364,7 +326,7 @@ class UniversalMessageSender: ) await asyncio.sleep(typing_time) - sent_msg = await _send_message(message, show_log=show_log) + sent_msg = await send_prepared_message_to_platform(message, show_log=show_log) if not sent_msg: return False diff --git a/src/chat/planner_actions/action_manager.py b/src/chat/planner_actions/action_manager.py index 167cdcab..8133ac18 100644 --- a/src/chat/planner_actions/action_manager.py +++ b/src/chat/planner_actions/action_manager.py @@ -3,8 +3,8 @@ from typing import Dict, Optional, Tuple from src.chat.message_receive.chat_manager import BotChatSession from src.chat.message_receive.message import SessionMessage from src.common.logger import get_logger -from src.core.component_registry import component_registry, ActionExecutor from src.core.types import ActionInfo +from src.plugin_runtime.component_query import ActionExecutor, component_query_service logger = get_logger("action_manager") @@ -28,7 +28,7 @@ class ActionManager: """ 动作管理器,用于管理各种类型的动作 - 使用核心组件注册表的 executor-based 模式。 + 使用插件运行时统一查询服务的 executor-based 模式。 """ def __init__(self): @@ -38,7 +38,7 @@ class ActionManager: self._using_actions: Dict[str, ActionInfo] = {} # 初始化时将默认动作加载到使用中的动作 - self._using_actions = component_registry.get_default_actions() + self._using_actions = component_query_service.get_default_actions() # === 执行Action方法 === @@ -72,17 +72,17 @@ class ActionManager: Optional[ActionHandle]: 执行句柄,如果动作未注册则返回 None """ try: - executor = component_registry.get_action_executor(action_name) + executor = component_query_service.get_action_executor(action_name) if not executor: logger.warning(f"{log_prefix} 未找到Action组件: {action_name}") return None - info = component_registry.get_action_info(action_name) + info = component_query_service.get_action_info(action_name) if not info: logger.warning(f"{log_prefix} 未找到Action组件信息: {action_name}") return None - plugin_config = component_registry.get_plugin_config(info.plugin_name) or {} + plugin_config = component_query_service.get_plugin_config(info.plugin_name) or {} handle = ActionHandle( executor, @@ -133,5 +133,5 @@ class ActionManager: def restore_actions(self) -> None: """恢复到默认动作集""" actions_to_restore = list(self._using_actions.keys()) - self._using_actions = component_registry.get_default_actions() + self._using_actions = component_query_service.get_default_actions() logger.debug(f"恢复动作集: 从 {actions_to_restore} 恢复到默认动作集 {list(self._using_actions.keys())}") diff --git a/src/chat/planner_actions/planner.py b/src/chat/planner_actions/planner.py index 5184abcb..b21efa6b 100644 --- a/src/chat/planner_actions/planner.py +++ b/src/chat/planner_actions/planner.py @@ -1,33 +1,36 @@ +from collections import OrderedDict +from datetime import datetime +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union + +import contextlib import json -import time -import traceback import random import re -import contextlib -from typing import Dict, Optional, Tuple, List, TYPE_CHECKING, Union -from collections import OrderedDict -from rich.traceback import install -from datetime import datetime +import time +import traceback + from json_repair import repair_json -from src.llm_models.utils_model import LLMRequest -from src.config.config import global_config, model_config -from src.common.logger import get_logger +from rich.traceback import install + from src.chat.logger.plan_reply_logger import PlanReplyLogger +from src.chat.message_receive.chat_manager import chat_manager as _chat_manager +from src.chat.message_receive.message import SessionMessage +from src.chat.planner_actions.action_manager import ActionManager +from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self from src.common.data_models.info_data_model import ActionPlannerInfo +from src.common.logger import get_logger +from src.config.config import global_config, model_config +from src.core.types import ActionActivationType, ActionInfo, ComponentType +from src.llm_models.utils_model import LLMRequest +from src.person_info.person_info import Person +from src.plugin_runtime.component_query import component_query_service from src.prompt.prompt_manager import prompt_manager from src.services.message_service import ( build_readable_messages_with_id, - replace_user_references, get_messages_before_time_in_chat, + replace_user_references, translate_pid_to_description, ) -from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self -from src.chat.planner_actions.action_manager import ActionManager -from src.chat.message_receive.chat_manager import chat_manager as _chat_manager -from src.chat.message_receive.message import SessionMessage -from src.core.types import ActionActivationType, ActionInfo, ComponentType -from src.core.component_registry import component_registry -from src.person_info.person_info import Person if TYPE_CHECKING: from src.common.data_models.info_data_model import TargetPersonInfo @@ -634,7 +637,7 @@ class ActionPlanner: current_available_actions_dict = self.action_manager.get_using_actions() # 获取完整的动作信息 - all_registered_actions: Dict[str, ActionInfo] = component_registry.get_components_by_type( # type: ignore + all_registered_actions: Dict[str, ActionInfo] = component_query_service.get_components_by_type( # type: ignore ComponentType.ACTION ) current_available_actions = {} diff --git a/src/chat/replyer/group_generator.py b/src/chat/replyer/group_generator.py index 003009b8..4ffa14a7 100644 --- a/src/chat/replyer/group_generator.py +++ b/src/chat/replyer/group_generator.py @@ -1,6 +1,7 @@ import traceback import time import asyncio +import importlib import random import re @@ -16,7 +17,6 @@ from maim_message import BaseMessageInfo, MessageBase, Seg, UserInfo as MaimUser from src.common.data_models.mai_message_data_model import MaiMessage from src.chat.message_receive.message import SessionMessage from src.chat.message_receive.chat_manager import BotChatSession -from src.chat.message_receive.uni_message_sender import UniversalMessageSender from src.chat.utils.timer_calculator import Timer # <--- Import Timer from src.chat.utils.utils import get_bot_account, get_chat_type_and_target_info, is_bot_self from src.prompt.prompt_manager import prompt_manager @@ -26,7 +26,7 @@ from src.services.message_service import ( replace_user_references, translate_pid_to_description, ) -from src.bw_learner.expression_selector import expression_selector +from src.learners.expression_selector import expression_selector # from src.memory_system.memory_activator import MemoryActivator from src.person_info.person_info import Person @@ -35,8 +35,7 @@ from src.services import llm_service as llm_api from src.chat.logger.plan_reply_logger import PlanReplyLogger from src.memory_system.memory_retrieval import init_memory_retrieval_sys, build_memory_retrieval_prompt -from src.memory_system.retrieval_tools import get_tool_registry -from src.bw_learner.jargon_explainer_old import explain_jargon_in_context, retrieve_concepts_with_jargon +from src.learners.jargon_explainer_old import explain_jargon_in_context, retrieve_concepts_with_jargon from src.chat.utils.common_utils import TempMethodsExpression init_memory_retrieval_sys() @@ -51,10 +50,15 @@ class DefaultReplyer: chat_stream: BotChatSession, request_type: str = "replyer", ): + """初始化群聊回复器。 + + Args: + chat_stream: 当前绑定的聊天会话。 + request_type: LLM 请求类型标识。 + """ self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type) self.chat_stream = chat_stream self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.session_id) - self.heart_fc_sender = UniversalMessageSender() from src.chat.tool_executor import ToolExecutor @@ -1129,7 +1133,10 @@ class DefaultReplyer: user_id=bot_user_id, user_nickname=global_config.bot.nickname, ), - additional_config={}, + additional_config={ + "platform_io_target_group_id": self.chat_stream.group_id, + "platform_io_target_user_id": self.chat_stream.user_id, + }, ), message_segment=message_segment, ) @@ -1164,14 +1171,29 @@ class DefaultReplyer: async def get_prompt_info(self, message: str, sender: str, target: str): related_info = "" start_time = time.time() - search_knowledge_tool = get_tool_registry().get_tool("search_long_term_memory") - if search_knowledge_tool is None: - logger.debug("长期记忆检索工具未注册,跳过获取知识内容") + try: + knowledge_module = importlib.import_module("src.plugins.built_in.knowledge.lpmm_get_knowledge") + except ImportError: + logger.debug("LPMM知识库工具模块不存在,跳过获取知识库内容") return "" - logger.debug(f"获取长期记忆内容,元消息:{message[:30]}...,消息长度: {len(message)}") + search_knowledge_tool = getattr(knowledge_module, "SearchKnowledgeFromLPMMTool", None) + if search_knowledge_tool is None: + logger.debug("LPMM知识库工具未提供 SearchKnowledgeFromLPMMTool,跳过获取知识库内容") + return "" + + logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}") + # 从LPMM知识库获取知识 try: - template_prompt = prompt_manager.get_prompt("memory_get_knowledge") + # 检查LPMM知识库是否启用 + if not global_config.lpmm_knowledge.enable: + logger.debug("LPMM知识库未启用,跳过获取知识库内容") + return "" + + if global_config.lpmm_knowledge.lpmm_mode == "agent": + return "" + + template_prompt = prompt_manager.get_prompt("lpmm_get_knowledge") template_prompt.add_context("bot_name", global_config.bot.nickname) template_prompt.add_context("time_now", lambda _: time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())) template_prompt.add_context("chat_history", message) @@ -1187,31 +1209,24 @@ class DefaultReplyer: # logger.info(f"工具调用提示词: {prompt}") # logger.info(f"工具调用: {tool_calls}") - if not tool_calls: - logger.debug("模型认为不需要使用长期记忆") + if tool_calls: + result = await self.tool_executor.execute_tool_call(tool_calls[0]) + end_time = time.time() + if not result or not result.get("content"): + logger.debug("从LPMM知识库获取知识失败,返回空知识...") + return "" + found_knowledge_from_lpmm = result.get("content", "") + logger.info( + f"从LPMM知识库获取知识,相关信息:{found_knowledge_from_lpmm[:100]}...,信息长度: {len(found_knowledge_from_lpmm)}" + ) + related_info += found_knowledge_from_lpmm + logger.debug(f"获取知识库内容耗时: {(end_time - start_time):.3f}秒") + logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}") + + return f"你有以下这些**知识**:\n{related_info}\n请你**记住上面的知识**,之后可能会用到。\n" + else: + logger.debug("模型认为不需要使用LPMM知识库") return "" - - related_chunks: List[str] = [] - for tool_call in tool_calls: - if tool_call.func_name != "search_long_term_memory": - continue - tool_args = dict(tool_call.args or {}) - tool_args.setdefault("chat_id", self.chat_stream.session_id) - result_text = await search_knowledge_tool.execute(**tool_args) - if result_text and "未找到" not in result_text: - related_chunks.append(result_text) - - if not related_chunks: - logger.debug("长期记忆未返回有效信息") - return "" - - related_info = "\n".join(related_chunks) - end_time = time.time() - logger.info(f"从长期记忆获取知识,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}") - logger.debug(f"获取知识库内容耗时: {(end_time - start_time):.3f}秒") - logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}") - - return f"你有以下这些**知识**:\n{related_info}\n请你**记住上面的知识**,之后可能会用到。\n" except Exception as e: logger.error(f"获取知识库内容时发生异常: {str(e)}") return "" diff --git a/src/chat/replyer/private_generator.py b/src/chat/replyer/private_generator.py index 3b70bb2c..c125a42f 100644 --- a/src/chat/replyer/private_generator.py +++ b/src/chat/replyer/private_generator.py @@ -16,7 +16,6 @@ from maim_message import BaseMessageInfo, MessageBase, Seg, UserInfo as MaimUser from src.common.data_models.mai_message_data_model import MaiMessage from src.chat.message_receive.message import SessionMessage from src.chat.message_receive.chat_manager import BotChatSession -from src.chat.message_receive.uni_message_sender import UniversalMessageSender from src.chat.utils.timer_calculator import Timer from src.chat.utils.utils import get_bot_account, get_chat_type_and_target_info, is_bot_self from src.prompt.prompt_manager import prompt_manager @@ -27,13 +26,13 @@ from src.services.message_service import ( replace_user_references, translate_pid_to_description, ) -from src.bw_learner.expression_selector import expression_selector +from src.learners.expression_selector import expression_selector # from src.memory_system.memory_activator import MemoryActivator from src.person_info.person_info import Person, is_person_known from src.core.types import ActionInfo, EventType from src.memory_system.memory_retrieval import init_memory_retrieval_sys, build_memory_retrieval_prompt -from src.bw_learner.jargon_explainer_old import explain_jargon_in_context +from src.learners.jargon_explainer_old import explain_jargon_in_context init_memory_retrieval_sys() @@ -47,10 +46,15 @@ class PrivateReplyer: chat_stream: BotChatSession, request_type: str = "replyer", ): + """初始化私聊回复器。 + + Args: + chat_stream: 当前绑定的聊天会话。 + request_type: LLM 请求类型标识。 + """ self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type) self.chat_stream = chat_stream self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.session_id) - self.heart_fc_sender = UniversalMessageSender() # self.memory_activator = MemoryActivator() from src.chat.tool_executor import ToolExecutor @@ -970,7 +974,9 @@ class PrivateReplyer: user_nickname=global_config.bot.nickname, ), group_info=None, - additional_config={}, + additional_config={ + "platform_io_target_user_id": self.chat_stream.user_id, + }, ), message_segment=message_segment, ) diff --git a/src/chat/tool_executor.py b/src/chat/tool_executor.py index d449f7a1..aa99fce8 100644 --- a/src/chat/tool_executor.py +++ b/src/chat/tool_executor.py @@ -1,22 +1,20 @@ -""" -工具执行器 +"""工具执行器。 独立的工具执行组件,可以直接输入聊天消息内容, 自动判断并执行相应的工具,返回结构化的工具执行结果。 - -从 src.plugin_system.core.tool_use 迁移,使用新的核心组件注册表。 """ +from typing import Any, Dict, List, Optional, Tuple + import hashlib import time -from typing import Any, Dict, List, Optional, Tuple from src.common.logger import get_logger from src.config.config import global_config, model_config from src.core.announcement_manager import global_announcement_manager -from src.core.component_registry import component_registry from src.llm_models.payload_content import ToolCall from src.llm_models.utils_model import LLMRequest +from src.plugin_runtime.component_query import component_query_service from src.prompt.prompt_manager import prompt_manager logger = get_logger("tool_use") @@ -89,7 +87,7 @@ class ToolExecutor: def _get_tool_definitions(self) -> List[Dict[str, Any]]: """获取 LLM 可用的工具定义列表""" - all_tools = component_registry.get_llm_available_tools() + all_tools = component_query_service.get_llm_available_tools() user_disabled_tools = global_announcement_manager.get_disabled_chat_tools(self.chat_id) return [info.get_llm_definition() for name, info in all_tools.items() if name not in user_disabled_tools] @@ -152,7 +150,7 @@ class ToolExecutor: function_args = tool_call.args or {} function_args["llm_called"] = True - executor = component_registry.get_tool_executor(function_name) + executor = component_query_service.get_tool_executor(function_name) if not executor: logger.warning(f"未知工具名称: {function_name}") return None diff --git a/src/chat/utils/statistic.py b/src/chat/utils/statistic.py index ede10a41..51e5e643 100644 --- a/src/chat/utils/statistic.py +++ b/src/chat/utils/statistic.py @@ -436,14 +436,14 @@ class StatisticOutputTask(AsyncTask): @staticmethod def _fetch_online_time_since(query_start_time: datetime) -> list[tuple[datetime, datetime]]: - with get_db_session() as session: + with get_db_session(auto_commit=False) as session: statement = select(OnlineTime).where(col(OnlineTime.end_timestamp) >= query_start_time) records = session.exec(statement).all() return [(record.start_timestamp, record.end_timestamp) for record in records] @staticmethod def _fetch_model_usage_since(query_start_time: datetime) -> list[dict[str, object]]: - with get_db_session() as session: + with get_db_session(auto_commit=False) as session: statement = select(ModelUsage).where(col(ModelUsage.timestamp) >= query_start_time) records = session.exec(statement).all() return [ @@ -664,7 +664,7 @@ class StatisticOutputTask(AsyncTask): } query_start_timestamp = collect_period[-1][1] - with get_db_session() as session: + with get_db_session(auto_commit=False) as session: statement = select(Messages).where(col(Messages.timestamp) >= query_start_timestamp) messages = session.exec(statement).all() for message in messages: @@ -713,7 +713,7 @@ class StatisticOutputTask(AsyncTask): # 使用 ActionRecords 中的 reply 动作次数作为回复数基准 try: action_query_start_timestamp = collect_period[-1][1] - with get_db_session() as session: + with get_db_session(auto_commit=False) as session: statement = select(ActionRecord).where(col(ActionRecord.timestamp) >= action_query_start_timestamp) actions = session.exec(statement).all() for action in actions: @@ -1750,7 +1750,7 @@ class StatisticOutputTask(AsyncTask): # 查询消息记录 query_start_timestamp = start_time.timestamp() - with get_db_session() as session: + with get_db_session(auto_commit=False) as session: statement = select(Messages).where(col(Messages.timestamp) >= start_time) messages = session.exec(statement).all() for message in messages: @@ -2131,7 +2131,7 @@ class StatisticOutputTask(AsyncTask): # 查询消息记录 query_start_timestamp = start_time.timestamp() - with get_db_session() as session: + with get_db_session(auto_commit=False) as session: statement = select(Messages).where(col(Messages.timestamp) >= start_time) messages = session.exec(statement).all() for message in messages: diff --git a/src/common/data_models/person_info_data_model.py b/src/common/data_models/person_info_data_model.py index 4cbb62d8..1b239356 100644 --- a/src/common/data_models/person_info_data_model.py +++ b/src/common/data_models/person_info_data_model.py @@ -1,6 +1,6 @@ -from dataclasses import dataclass +from dataclasses import asdict, dataclass from datetime import datetime -from typing import Optional, List +from typing import Any, List, Mapping, Optional, Sequence import json @@ -15,6 +15,76 @@ class GroupCardnameInfo: group_cardname: str +def _normalize_group_cardname_item(raw_item: Mapping[str, Any]) -> Optional[GroupCardnameInfo]: + """将单条群名片数据规范化为统一结构。 + + Args: + raw_item: 原始群名片字典,必须包含 `group_id` 和 `group_cardname`。 + + Returns: + Optional[GroupCardnameInfo]: 规范化后的群名片信息;若数据不完整则返回 ``None``。 + """ + group_id = str(raw_item.get("group_id") or "").strip() + group_cardname = str(raw_item.get("group_cardname") or "").strip() + if not group_id or not group_cardname: + return None + return GroupCardnameInfo(group_id=group_id, group_cardname=group_cardname) + + +def parse_group_cardname_json(group_cardname_json: Optional[str]) -> Optional[List[GroupCardnameInfo]]: + """解析数据库中的群名片 JSON 字段。 + + Args: + group_cardname_json: 数据库存储的群名片 JSON 字符串。 + + Returns: + Optional[List[GroupCardnameInfo]]: 解析并规范化后的群名片列表;若字段为空或无有效项则返回 ``None``。 + + Raises: + json.JSONDecodeError: 当 JSON 文本格式非法时抛出。 + TypeError: 当输入值类型不符合 `json.loads()` 要求时抛出。 + """ + if not group_cardname_json: + return None + + raw_items = json.loads(group_cardname_json) + if not isinstance(raw_items, list): + return None + + normalized_items: List[GroupCardnameInfo] = [] + for raw_item in raw_items: + if not isinstance(raw_item, Mapping): + continue + if normalized_item := _normalize_group_cardname_item(raw_item): + normalized_items.append(normalized_item) + + return normalized_items or None + + +def dump_group_cardname_records( + group_cardname_records: Optional[Sequence[GroupCardnameInfo | Mapping[str, Any]]], +) -> str: + """将群名片列表序列化为数据库使用的标准 JSON 字符串。 + + Args: + group_cardname_records: 待序列化的群名片列表,支持 `GroupCardnameInfo` + 对象和包含 `group_id` / `group_cardname` 的字典。 + + Returns: + str: 统一使用 `group_cardname` 键名的 JSON 字符串。 + """ + normalized_items: List[GroupCardnameInfo] = [] + for raw_item in group_cardname_records or []: + if isinstance(raw_item, GroupCardnameInfo): + normalized_items.append(raw_item) + continue + if isinstance(raw_item, Mapping): + if normalized_item := _normalize_group_cardname_item(raw_item): + normalized_items.append(normalized_item) + + return json.dumps([asdict(item) for item in normalized_items], ensure_ascii=False) + + class MaiPersonInfo(BaseDatabaseDataModel[PersonInfo]): def __init__( self, @@ -58,9 +128,16 @@ class MaiPersonInfo(BaseDatabaseDataModel[PersonInfo]): """最后一次被认识的时间""" @classmethod - def from_db_instance(cls, db_record: "PersonInfo"): - nickname_json = json.loads(db_record.group_cardname) if db_record.group_cardname else None - group_cardname_list = [GroupCardnameInfo(**item) for item in nickname_json] if nickname_json else None + def from_db_instance(cls, db_record: "PersonInfo") -> "MaiPersonInfo": + """从数据库记录构造人物信息数据模型。 + + Args: + db_record: 数据库中的人物信息记录。 + + Returns: + MaiPersonInfo: 转换后的数据模型对象。 + """ + group_cardname_list = parse_group_cardname_json(db_record.group_cardname) memory_points = json.loads(db_record.memory_points) if db_record.memory_points else None return cls( is_known=db_record.is_known, @@ -78,9 +155,12 @@ class MaiPersonInfo(BaseDatabaseDataModel[PersonInfo]): ) def to_db_instance(self) -> "PersonInfo": - group_cardname = ( - json.dumps([gc.__dict__ for gc in self.group_cardname_list]) if self.group_cardname_list else None - ) + """将当前数据模型转换为数据库记录对象。 + + Returns: + PersonInfo: 可直接写入数据库的模型实例。 + """ + group_cardname = dump_group_cardname_records(self.group_cardname_list) return PersonInfo( is_known=self.is_known, person_id=self.person_id, diff --git a/src/common/database/database_model.py b/src/common/database/database_model.py index a0993a77..5b274c43 100644 --- a/src/common/database/database_model.py +++ b/src/common/database/database_model.py @@ -1,8 +1,9 @@ -from typing import Optional -from sqlalchemy import Column, Float, Enum as SQLEnum, DateTime -from sqlmodel import SQLModel, Field, LargeBinary -from enum import Enum from datetime import datetime +from enum import Enum +from typing import Optional + +from sqlalchemy import Column, DateTime, Enum as SQLEnum, Float +from sqlmodel import Field, LargeBinary, SQLModel class ModelUser(str, Enum): @@ -172,8 +173,8 @@ class Expression(SQLModel, table=True): id: Optional[int] = Field(default=None, primary_key=True) # 自增主键 - situation: str = Field(index=True, max_length=255, primary_key=True) # 情景 - style: str = Field(index=True, max_length=255, primary_key=True) # 风格 + situation: str = Field(index=True, max_length=255) # 情景 + style: str = Field(index=True, max_length=255) # 风格 # context: str # 上下文 # up_content: str @@ -200,7 +201,7 @@ class Jargon(SQLModel, table=True): id: Optional[int] = Field(default=None, primary_key=True) # 自增主键 - content: str = Field(index=True, max_length=255, primary_key=True) # 黑话内容 + content: str = Field(index=True, max_length=255) # 黑话内容 raw_content: Optional[str] = Field(default=None, nullable=True) # 原始内容,未处理的黑话内容,为List[str] meaning: str # 黑话含义 diff --git a/src/common/logger_color_and_mapping.py b/src/common/logger_color_and_mapping.py index 1aabafbc..863c9c1e 100644 --- a/src/common/logger_color_and_mapping.py +++ b/src/common/logger_color_and_mapping.py @@ -1,9 +1,8 @@ # 定义模块颜色映射 -from typing import Optional, Tuple, Dict - import itertools import os import sys +from typing import Dict, Optional, Tuple MODULE_COLORS: Dict[str, Tuple[str, Optional[str], bool]] = { @@ -54,15 +53,19 @@ MODULE_COLORS: Dict[str, Tuple[str, Optional[str], bool]] = { "component_registry": ("#ffaf00", None, False), "plugin_runtime.integration": ("#d75f00", None, False), "plugin_runtime.host.supervisor": ("#ff5f00", None, False), + "plugin_runtime.host.runner_manager": ("#ff5f00", None, False), "plugin_runtime.host.rpc_server": ("#ff8700", None, False), "plugin_runtime.host.component_registry": ("#ffaf00", None, False), "plugin_runtime.host.capability_service": ("#ffd700", None, False), "plugin_runtime.host.event_dispatcher": ("#87d700", None, False), - "plugin_runtime.host.workflow_executor": ("#5fd7af", None, False), + "plugin_runtime.host.hook_dispatcher": ("#5fd7af", None, False), + "plugin_runtime.host.message_gateway": ("#5fd7d7", None, False), + "plugin_runtime.host.message_utils": ("#5faf87", None, False), "plugin_runtime.runner.main": ("#d787ff", None, False), "plugin_runtime.runner.rpc_client": ("#8787ff", None, False), "plugin_runtime.runner.manifest_validator": ("#5fafff", None, False), "plugin_runtime.runner.plugin_loader": ("#00afaf", None, False), + "plugin.maibot-team.napcat-adapter": ("#00af87", None, False), "webui": ("#5f87ff", None, False), "webui.app": ("#5f87d7", None, False), "webui.api": ("#5fafff", None, False), @@ -157,15 +160,20 @@ MODULE_ALIASES = { "chat_history_summarizer": "聊天概括器", "plugin_runtime.integration": "IPC插件系统", "plugin_runtime.host.supervisor": "插件监督器", + "plugin_runtime.host.runner_manager": "插件监督器", "plugin_runtime.host.rpc_server": "插件RPC服务", "plugin_runtime.host.component_registry": "插件组件注册", "plugin_runtime.host.capability_service": "插件能力服务", "plugin_runtime.host.event_dispatcher": "插件事件分发", + "plugin_runtime.host.hook_dispatcher": "插件Hook分发", + "plugin_runtime.host.message_gateway": "插件消息网关", + "plugin_runtime.host.message_utils": "插件消息工具", "plugin_runtime.host.workflow_executor": "插件工作流", "plugin_runtime.runner.main": "插件运行器", "plugin_runtime.runner.rpc_client": "插件RPC客户端", "plugin_runtime.runner.manifest_validator": "插件清单校验", "plugin_runtime.runner.plugin_loader": "插件加载器", + "plugin.maibot-team.napcat-adapter": "NapCat内置适配器", "webui": "WebUI", "webui.app": "WebUI应用", "webui.api": "WebUI接口", diff --git a/src/common/message_server/server.py b/src/common/message_server/server.py index 77a931e5..e75da4e7 100644 --- a/src/common/message_server/server.py +++ b/src/common/message_server/server.py @@ -21,7 +21,7 @@ class Server: self._server: Optional[UvicornServer] = None self.set_address(host, port) - def register_router(self, router: APIRouter, prefix: str = ""): + def register_router(self, router: APIRouter, prefix: str = ""): """注册路由 APIRouter 用于对相关的路由端点进行分组和模块化管理: diff --git a/src/common/utils/utils_session.py b/src/common/utils/utils_session.py index a383f5a2..1b6d8f72 100644 --- a/src/common/utils/utils_session.py +++ b/src/common/utils/utils_session.py @@ -5,13 +5,22 @@ import hashlib class SessionUtils: @staticmethod - def calculate_session_id(platform: str, *, user_id: Optional[str] = None, group_id: Optional[str] = None) -> str: + def calculate_session_id( + platform: str, + *, + user_id: Optional[str] = None, + group_id: Optional[str] = None, + account_id: Optional[str] = None, + scope: Optional[str] = None, + ) -> str: """计算session_id Args: platform: 平台名称 user_id: 用户ID(如果是私聊) group_id: 群ID(如果是群聊) + account_id: 当前平台账号 ID,可选 + scope: 当前路由作用域,可选 Returns: str: 计算得到的会话ID Raises: @@ -19,8 +28,15 @@ class SessionUtils: """ if not user_id and not group_id: raise ValueError("UserID 或 GroupID 必须提供其一") + + route_components = [] + if account_id: + route_components.append(f"account:{account_id}") + if scope: + route_components.append(f"scope:{scope}") + if group_id: - components = [platform, group_id] + components = [platform, *route_components, group_id] else: - components = [platform, user_id, "private"] + components = [platform, *route_components, user_id, "private"] return hashlib.md5("_".join(components).encode()).hexdigest() diff --git a/src/config/config.py b/src/config/config.py index ff5941bf..bee81efb 100644 --- a/src/config/config.py +++ b/src/config/config.py @@ -4,6 +4,7 @@ from typing import Any, Callable, Mapping, Sequence, TypeVar import asyncio import copy +import inspect import sys import tomlkit @@ -61,6 +62,7 @@ MODEL_CONFIG_VERSION: str = "1.12.0" logger = get_logger("config") T = TypeVar("T", bound="ConfigBase") +ConfigReloadCallback = Callable[[Sequence[str]], object] | Callable[[], object] class Config(ConfigBase): @@ -190,7 +192,7 @@ class ConfigManager: self.global_config: Config | None = None self.model_config: ModelConfig | None = None self._reload_lock: asyncio.Lock = asyncio.Lock() - self._reload_callbacks: list[Callable[[], object]] = [] + self._reload_callbacks: list[ConfigReloadCallback] = [] self._file_watcher: FileWatcher | None = None self._file_watcher_subscription_id: str | None = None self._hot_reload_min_interval_s: float = 1.0 @@ -226,16 +228,125 @@ class ConfigManager: raise RuntimeError(t("config.model_not_initialized")) return self.model_config - def register_reload_callback(self, callback: Callable[[], object]) -> None: + def register_reload_callback(self, callback: ConfigReloadCallback) -> None: + """注册配置热重载回调。 + + Args: + callback: 配置热重载回调。允许无参回调,也允许接收 + ``Sequence[str]`` 类型的变更范围列表。 + """ + self._reload_callbacks.append(callback) - def unregister_reload_callback(self, callback: Callable[[], object]) -> None: + def unregister_reload_callback(self, callback: ConfigReloadCallback) -> None: + """注销配置热重载回调。 + + Args: + callback: 先前注册过的回调对象。 + """ + try: self._reload_callbacks.remove(callback) except ValueError: return - async def reload_config(self) -> bool: + @staticmethod + def _normalize_changed_scopes(changed_scopes: Sequence[str] | None) -> tuple[str, ...]: + """规范化配置变更范围列表。 + + Args: + changed_scopes: 原始配置变更范围。 + + Returns: + tuple[str, ...]: 去重后的配置变更范围元组。 + """ + + if not changed_scopes: + return ("bot", "model") + + normalized_scopes: list[str] = [] + for scope in changed_scopes: + normalized_scope = str(scope or "").strip().lower() + if normalized_scope not in {"bot", "model"}: + continue + if normalized_scope not in normalized_scopes: + normalized_scopes.append(normalized_scope) + return tuple(normalized_scopes) + + @staticmethod + def _resolve_changed_scopes(changes: Sequence[FileChange]) -> tuple[str, ...]: + """根据文件变更列表推断配置变更范围。 + + Args: + changes: 文件监听器返回的变更列表。 + + Returns: + tuple[str, ...]: 命中的配置变更范围元组。 + """ + + changed_scopes: list[str] = [] + for change in changes: + file_name = change.path.name + if file_name == "bot_config.toml" and "bot" not in changed_scopes: + changed_scopes.append("bot") + if file_name == "model_config.toml" and "model" not in changed_scopes: + changed_scopes.append("model") + return tuple(changed_scopes) + + @staticmethod + def _callback_accepts_scopes(callback: ConfigReloadCallback) -> bool: + """判断回调是否接收配置变更范围参数。 + + Args: + callback: 待检测的回调对象。 + + Returns: + bool: 若回调可接收一个位置参数或可变位置参数,则返回 ``True``。 + """ + + try: + parameters = inspect.signature(callback).parameters.values() + except (TypeError, ValueError): + return False + + positional_params = { + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + } + for parameter in parameters: + if parameter.kind == inspect.Parameter.VAR_POSITIONAL: + return True + if parameter.kind in positional_params: + return True + return False + + async def _invoke_reload_callback( + self, + callback: ConfigReloadCallback, + changed_scopes: Sequence[str], + ) -> None: + """执行单个配置热重载回调。 + + Args: + callback: 要执行的回调对象。 + changed_scopes: 本次热重载命中的配置范围。 + """ + + result = callback(changed_scopes) if self._callback_accepts_scopes(callback) else callback() + if asyncio.iscoroutine(result): + await result + + async def reload_config(self, changed_scopes: Sequence[str] | None = None) -> bool: + """重新加载主配置和模型配置。 + + Args: + changed_scopes: 本次触发热重载的配置范围。 + + Returns: + bool: 是否重载成功。 + """ + + normalized_scopes = self._normalize_changed_scopes(changed_scopes) async with self._reload_lock: try: global_config_new, global_updated = load_config_from_file( @@ -265,9 +376,7 @@ class ConfigManager: for callback in list(self._reload_callbacks): try: - result = callback() - if asyncio.iscoroutine(result): - await result + await self._invoke_reload_callback(callback, normalized_scopes) except Exception as exc: logger.warning(t("config.reload_callback_failed", error=exc)) return True @@ -312,6 +421,12 @@ class ConfigManager: self._file_watcher = None async def _handle_file_changes(self, changes: Sequence[FileChange]) -> None: + """处理主配置与模型配置文件变更。 + + Args: + changes: 当前批次收集到的文件变更列表。 + """ + if not changes: return now_monotonic = asyncio.get_running_loop().time() @@ -321,7 +436,11 @@ class ConfigManager: self._last_hot_reload_monotonic = now_monotonic logger.info(t("config.file_change_detected")) try: - await asyncio.wait_for(self.reload_config(), timeout=self._hot_reload_timeout_s) + changed_scopes = self._resolve_changed_scopes(changes) + await asyncio.wait_for( + self.reload_config(changed_scopes=changed_scopes), + timeout=self._hot_reload_timeout_s, + ) except asyncio.TimeoutError: logger.error(t("config.reload_timeout", timeout_seconds=self._hot_reload_timeout_s)) diff --git a/src/config/official_configs.py b/src/config/official_configs.py index 20c2c2c8..fde3f800 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -1633,24 +1633,6 @@ class PluginRuntimeConfig(ConfigBase): ) """启用插件系统""" - builtin_plugin_dir: str = Field( - default="src/plugins/built_in", - json_schema_extra={ - "x-widget": "input", - "x-icon": "folder", - }, - ) - """内置插件目录(相对于项目根目录)""" - - thirdparty_plugin_dir: str = Field( - default="plugins", - json_schema_extra={ - "x-widget": "input", - "x-icon": "folder-open", - }, - ) - """第三方插件目录(相对于项目根目录)""" - health_check_interval_sec: float = Field( default=30.0, json_schema_extra={ @@ -1678,14 +1660,14 @@ class PluginRuntimeConfig(ConfigBase): ) """等待 Runner 子进程启动并注册的超时时间(秒)""" - workflow_blocking_timeout_sec: float = Field( - default=120.0, + hook_blocking_timeout_sec: float = Field( + default=30, json_schema_extra={ "x-widget": "number", "x-icon": "timer", }, ) - """Workflow 阻塞步骤的全局超时上限(秒)""" + """Hook 阻塞步骤的全局超时上限(秒)""" ipc_socket_path: str = Field( default="", @@ -1694,4 +1676,7 @@ class PluginRuntimeConfig(ConfigBase): "x-icon": "link", }, ) - """_wrap_\n 自定义 IPC Socket 路径(仅 Linux/macOS 生效)\n 留空则自动生成临时路径""" + """ + 自定义 IPC Socket 路径(仅 Linux/macOS 生效) + 留空则自动生成临时路径 + """ diff --git a/src/core/component_registry.py b/src/core/component_registry.py deleted file mode 100644 index bb58682a..00000000 --- a/src/core/component_registry.py +++ /dev/null @@ -1,239 +0,0 @@ -""" -核心组件注册表 - -面向最终架构的组件管理: -- Action:注册 ActionInfo + 执行器(本地 callable 或 IPC 路由) -- Command:注册正则模式 + 执行器 -- Tool:注册工具定义 + 执行器 - -不依赖任何插件基类,组件执行器是纯 async callable。 -""" - -import re -from typing import Any, Awaitable, Callable, Dict, Optional, Pattern, Tuple - -from src.common.logger import get_logger -from src.core.types import ( - ActionInfo, - CommandInfo, - ComponentInfo, - ComponentType, - ToolInfo, -) - -logger = get_logger("component_registry") - -# 执行器类型 -ActionExecutor = Callable[..., Awaitable[Any]] -CommandExecutor = Callable[..., Awaitable[Tuple[bool, Optional[str], bool]]] -ToolExecutor = Callable[..., Awaitable[Any]] - - -class ComponentRegistry: - """核心组件注册表 - - 管理 action、command、tool 三类组件。 - 每个组件由「元信息 + 执行器」构成,执行器是 async callable, - 不需要继承任何基类。 - """ - - def __init__(self): - # Action 注册 - self._actions: Dict[str, ActionInfo] = {} - self._action_executors: Dict[str, ActionExecutor] = {} - self._default_actions: Dict[str, ActionInfo] = {} - - # Command 注册 - self._commands: Dict[str, CommandInfo] = {} - self._command_executors: Dict[str, CommandExecutor] = {} - self._command_patterns: Dict[Pattern, str] = {} - - # Tool 注册 - self._tools: Dict[str, ToolInfo] = {} - self._tool_executors: Dict[str, ToolExecutor] = {} - self._llm_available_tools: Dict[str, ToolInfo] = {} - - # 插件配置(plugin_name -> config dict) - self._plugin_configs: Dict[str, dict] = {} - - logger.info("核心组件注册表初始化完成") - - # ========== Action ========== - - def register_action( - self, - info: ActionInfo, - executor: ActionExecutor, - ) -> bool: - """注册 action - - Args: - info: action 元信息 - executor: 执行器,async callable - """ - name = info.name - if name in self._actions: - logger.warning(f"Action {name} 已存在,跳过注册") - return False - - self._actions[name] = info - self._action_executors[name] = executor - - if info.enabled: - self._default_actions[name] = info - - logger.debug(f"注册 Action: {name}") - return True - - def get_action_info(self, name: str) -> Optional[ActionInfo]: - return self._actions.get(name) - - def get_action_executor(self, name: str) -> Optional[ActionExecutor]: - return self._action_executors.get(name) - - def get_default_actions(self) -> Dict[str, ActionInfo]: - return self._default_actions.copy() - - def get_all_actions(self) -> Dict[str, ActionInfo]: - return self._actions.copy() - - def remove_action(self, name: str) -> bool: - if name not in self._actions: - return False - del self._actions[name] - self._action_executors.pop(name, None) - self._default_actions.pop(name, None) - logger.debug(f"移除 Action: {name}") - return True - - # ========== Command ========== - - def register_command( - self, - info: CommandInfo, - executor: CommandExecutor, - ) -> bool: - """注册 command""" - name = info.name - if name in self._commands: - logger.warning(f"Command {name} 已存在,跳过注册") - return False - - self._commands[name] = info - self._command_executors[name] = executor - - if info.enabled and info.command_pattern: - pattern = re.compile(info.command_pattern, re.IGNORECASE | re.DOTALL) - self._command_patterns[pattern] = name - - logger.debug(f"注册 Command: {name}") - return True - - def find_command_by_text(self, text: str) -> Optional[Tuple[CommandExecutor, dict, CommandInfo]]: - """根据文本查找匹配的命令 - - Returns: - (executor, matched_groups, command_info) 或 None - """ - candidates = [p for p in self._command_patterns if p.match(text)] - if not candidates: - return None - if len(candidates) > 1: - logger.warning(f"文本 '{text[:50]}' 匹配到多个命令模式,使用第一个") - pattern = candidates[0] - name = self._command_patterns[pattern] - return ( - self._command_executors[name], - pattern.match(text).groupdict(), # type: ignore - self._commands[name], - ) - - def remove_command(self, name: str) -> bool: - if name not in self._commands: - return False - del self._commands[name] - self._command_executors.pop(name, None) - self._command_patterns = {k: v for k, v in self._command_patterns.items() if v != name} - logger.debug(f"移除 Command: {name}") - return True - - # ========== Tool ========== - - def register_tool( - self, - info: ToolInfo, - executor: ToolExecutor, - ) -> bool: - """注册 tool""" - name = info.name - if name in self._tools: - logger.warning(f"Tool {name} 已存在,跳过注册") - return False - - self._tools[name] = info - self._tool_executors[name] = executor - - if info.enabled: - self._llm_available_tools[name] = info - - logger.debug(f"注册 Tool: {name}") - return True - - def get_tool_info(self, name: str) -> Optional[ToolInfo]: - return self._tools.get(name) - - def get_tool_executor(self, name: str) -> Optional[ToolExecutor]: - return self._tool_executors.get(name) - - def get_llm_available_tools(self) -> Dict[str, ToolInfo]: - return self._llm_available_tools.copy() - - def get_all_tools(self) -> Dict[str, ToolInfo]: - return self._tools.copy() - - def remove_tool(self, name: str) -> bool: - if name not in self._tools: - return False - del self._tools[name] - self._tool_executors.pop(name, None) - self._llm_available_tools.pop(name, None) - logger.debug(f"移除 Tool: {name}") - return True - - # ========== 通用查询 ========== - - def get_component_info(self, name: str, component_type: ComponentType) -> Optional[ComponentInfo]: - """获取组件元信息""" - match component_type: - case ComponentType.ACTION: - return self._actions.get(name) - case ComponentType.COMMAND: - return self._commands.get(name) - case ComponentType.TOOL: - return self._tools.get(name) - case _: - return None - - def get_components_by_type(self, component_type: ComponentType) -> Dict[str, ComponentInfo]: - """获取某类型的所有组件""" - match component_type: - case ComponentType.ACTION: - return dict(self._actions) - case ComponentType.COMMAND: - return dict(self._commands) - case ComponentType.TOOL: - return dict(self._tools) - case _: - return {} - - # ========== 插件配置 ========== - - def set_plugin_config(self, plugin_name: str, config: dict) -> None: - self._plugin_configs[plugin_name] = config - - def get_plugin_config(self, plugin_name: str) -> Optional[dict]: - return self._plugin_configs.get(plugin_name) - - -# 全局单例 -component_registry = ComponentRegistry() diff --git a/src/bw_learner/expression_auto_check_task.py b/src/learners/expression_auto_check_task.py similarity index 96% rename from src/bw_learner/expression_auto_check_task.py rename to src/learners/expression_auto_check_task.py index d90eb4da..e5af1057 100644 --- a/src/bw_learner/expression_auto_check_task.py +++ b/src/learners/expression_auto_check_task.py @@ -3,19 +3,19 @@ 功能: 1. 定期随机选取指定数量的表达方式 -2. 使用LLM进行评估 +2. 使用 LLM 进行评估 3. 通过评估的:rejected=0, checked=1 4. 未通过评估的:rejected=1, checked=1 """ -from typing import List import asyncio import json import random +from typing import List from sqlmodel import select -from src.bw_learner.expression_review_store import get_review_state, set_review_state +from src.learners.expression_review_store import get_review_state, set_review_state from src.common.database.database import get_db_session from src.common.database.database_model import Expression from src.common.logger import get_logger @@ -146,7 +146,8 @@ class ExpressionAutoCheckTask(AsyncTask): 选中的表达方式列表 """ try: - with get_db_session() as session: + # 这里只做查询,避免退出上下文时自动提交导致 ORM 实例过期。 + with get_db_session(auto_commit=False) as session: statement = select(Expression) all_expressions = session.exec(statement).all() diff --git a/src/bw_learner/expression_learner.py b/src/learners/expression_learner.py similarity index 90% rename from src/bw_learner/expression_learner.py rename to src/learners/expression_learner.py index 43e4ee7d..b82ae1fa 100644 --- a/src/bw_learner/expression_learner.py +++ b/src/learners/expression_learner.py @@ -329,7 +329,13 @@ class ExpressionLearner: return filtered_expressions # ====== DB 操作相关 ====== - async def _upsert_expression_to_db(self, situation: str, style: str): + async def _upsert_expression_to_db(self, situation: str, style: str) -> None: + """将表达方式写入数据库,存在时更新,不存在时新增。 + + Args: + situation: 表达方式对应的使用情景。 + style: 表达方式风格。 + """ expr, similarity = self._find_similar_expression(situation) or (None, 0) if expr: # 根据相似度决定是否使用 LLM 总结 @@ -340,7 +346,13 @@ class ExpressionLearner: # 没有找到匹配的记录,创建新记录 self._create_expression(situation, style) - def _create_expression(self, situation: str, style: str): + def _create_expression(self, situation: str, style: str) -> None: + """创建新的表达方式记录。 + + Args: + situation: 表达方式对应的使用情景。 + style: 表达方式风格。 + """ content_list = [situation] try: with get_db_session() as db: @@ -353,6 +365,7 @@ class ExpressionLearner: last_active_time=datetime.now(), ) db.add(new_expr) + db.flush() except Exception as e: logger.error(f"创建表达方式失败: {e}") @@ -448,25 +461,43 @@ class ExpressionLearner: def _find_similar_expression( self, situation: str, similarity_threshold: float = 0.75 ) -> Optional[Tuple[MaiExpression, float]]: - """在数据库中查找相似的表达方式""" + """在数据库中查找相似的表达方式。 + + Args: + situation: 当前待匹配的情景描述。 + similarity_threshold: 认定为相似表达方式的最低相似度阈值。 + + Returns: + Optional[Tuple[MaiExpression, float]]: 若找到最相似的表达方式,则返回 + ``(表达方式对象, 相似度)``;否则返回 ``None``。 + """ try: - with get_db_session() as session: + with get_db_session(auto_commit=False) as session: statement = select(Expression).filter_by(session_id=self.session_id) expressions = session.exec(statement).all() - best_match: Optional[Expression] = None - best_similarity = 0.0 + best_match: Optional[MaiExpression] = None + best_similarity = 0.0 + + for db_expression in expressions: + expression = MaiExpression.from_db_instance(db_expression) + candidate_situations = [expression.situation, *expression.content] + for candidate_situation in candidate_situations: + normalized_candidate_situation = candidate_situation.strip() + if not normalized_candidate_situation: + continue + similarity = difflib.SequenceMatcher( + None, + situation, + normalized_candidate_situation, + ).ratio() + if similarity > similarity_threshold and similarity > best_similarity: + best_similarity = similarity + best_match = expression - for expr in expressions: - content_list = json.loads(expr.content_list) - for situation in content_list: - similarity = difflib.SequenceMatcher(None, situation, expr.situation).ratio() - if similarity > similarity_threshold and similarity > best_similarity: - best_similarity = similarity - best_match = expr if best_match: - logger.debug(f"找到相似表达方式情景 [ID: {best_match.id}],相似度: {best_similarity:.2f}") - return MaiExpression.from_db_instance(best_match), best_similarity + logger.debug(f"找到相似表达方式情景 [ID: {best_match.item_id}],相似度: {best_similarity:.2f}") + return best_match, best_similarity except Exception as e: logger.error(f"查找相似表达方式失败: {e}") diff --git a/src/bw_learner/expression_review_store.py b/src/learners/expression_review_store.py similarity index 100% rename from src/bw_learner/expression_review_store.py rename to src/learners/expression_review_store.py diff --git a/src/bw_learner/expression_selector.py b/src/learners/expression_selector.py similarity index 99% rename from src/bw_learner/expression_selector.py rename to src/learners/expression_selector.py index c6cfe469..c96e84cf 100644 --- a/src/bw_learner/expression_selector.py +++ b/src/learners/expression_selector.py @@ -9,7 +9,7 @@ from src.config.config import global_config, model_config from src.common.logger import get_logger from src.common.database.database_model import Expression from src.prompt.prompt_manager import prompt_manager -from src.bw_learner.learner_utils_old import weighted_sample +from src.learners.learner_utils_old import weighted_sample from src.chat.utils.common_utils import TempMethodsExpression logger = get_logger("expression_selector") diff --git a/src/bw_learner/expression_utils.py b/src/learners/expression_utils.py similarity index 100% rename from src/bw_learner/expression_utils.py rename to src/learners/expression_utils.py diff --git a/src/bw_learner/jargon_explainer.py b/src/learners/jargon_explainer.py similarity index 100% rename from src/bw_learner/jargon_explainer.py rename to src/learners/jargon_explainer.py diff --git a/src/bw_learner/jargon_explainer_old.py b/src/learners/jargon_explainer_old.py similarity index 99% rename from src/bw_learner/jargon_explainer_old.py rename to src/learners/jargon_explainer_old.py index 94031b4a..0cfafa82 100644 --- a/src/bw_learner/jargon_explainer_old.py +++ b/src/learners/jargon_explainer_old.py @@ -7,8 +7,8 @@ from src.common.database.database_model import Jargon from src.llm_models.utils_model import LLMRequest from src.config.config import model_config, global_config from src.prompt.prompt_manager import prompt_manager -from src.bw_learner.jargon_explainer import search_jargon -from src.bw_learner.learner_utils_old import ( +from src.learners.jargon_miner_old import search_jargon +from src.learners.learner_utils_old import ( is_bot_message, contains_bot_self_name, parse_chat_id_list, diff --git a/src/bw_learner/jargon_miner.py b/src/learners/jargon_miner.py similarity index 93% rename from src/bw_learner/jargon_miner.py rename to src/learners/jargon_miner.py index 2fbf8a2e..32926894 100644 --- a/src/bw_learner/jargon_miner.py +++ b/src/learners/jargon_miner.py @@ -1,17 +1,18 @@ from collections import OrderedDict -from json_repair import repair_json -from sqlmodel import select -from typing import List, Optional, Dict, Callable, TypedDict, Set +from typing import Callable, Dict, List, Optional, Set, TypedDict import asyncio import json import random -from src.common.logger import get_logger +from json_repair import repair_json +from sqlmodel import select + +from src.common.data_models.jargon_data_model import MaiJargon from src.common.database.database import get_db_session from src.common.database.database_model import Jargon -from src.common.data_models.jargon_data_model import MaiJargon -from src.config.config import model_config, global_config +from src.common.logger import get_logger +from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest from src.prompt.prompt_manager import prompt_manager @@ -198,7 +199,7 @@ class JargonMiner: async def process_extracted_entries( self, entries: List[JargonEntry], person_name_filter: Optional[Callable[[str], bool]] = None - ): + ) -> None: """ 处理已提取的黑话条目(从 expression_learner 路由过来的) @@ -229,7 +230,7 @@ class JargonMiner: content = entry["content"] raw_content_set = entry["raw_content"] try: - with get_db_session() as session: + with get_db_session(auto_commit=False) as session: jargon_items = session.exec(select(Jargon).filter_by(content=content)).all() except Exception as e: logger.error(f"查询黑话 '{content}' 失败: {e}") @@ -273,11 +274,12 @@ class JargonMiner: try: with get_db_session() as session: session.add(new_jargon) + session.flush() + saved += 1 + self._add_to_cache(content) except Exception as e: logger.error(f"保存新黑话 '{content}' 失败: {e}") continue - finally: - self._add_to_cache(content) # 固定输出提取的jargon结果,格式化为可读形式(只要有提取结果就输出) if uniq_entries: # 收集所有提取的jargon内容 @@ -304,7 +306,13 @@ class JargonMiner: removed_content, _ = self.cache.popitem(last=False) logger.debug(f"缓存已满,移除最旧的黑话: {removed_content}") - def _update_jargon(self, db_jargon: Jargon, raw_content_set: Set[str]): + def _update_jargon(self, db_jargon: Jargon, raw_content_set: Set[str]) -> None: + """更新已有黑话记录并写回数据库。 + + Args: + db_jargon: 已命中的黑话 ORM 对象。 + raw_content_set: 本次新增的原始上下文集合。 + """ db_jargon.count += 1 existing_raw_content: List[str] = [] if db_jargon.raw_content: @@ -326,7 +334,17 @@ class JargonMiner: try: with get_db_session() as session: - session.add(db_jargon) + if db_jargon.id is None: + raise ValueError("黑话记录缺少 id,无法更新数据库") + statement = select(Jargon).filter_by(id=db_jargon.id).limit(1) + if persisted_jargon := session.exec(statement).first(): + persisted_jargon.count = db_jargon.count + persisted_jargon.raw_content = db_jargon.raw_content + persisted_jargon.session_id_dict = db_jargon.session_id_dict + persisted_jargon.is_global = db_jargon.is_global + session.add(persisted_jargon) + else: + logger.warning(f"黑话 ID {db_jargon.id} 在数据库中未找到,无法更新") except Exception as e: logger.error(f"更新黑话 '{db_jargon.content}' 失败: {e}") diff --git a/src/bw_learner/learner_utils.py b/src/learners/learner_utils.py similarity index 100% rename from src/bw_learner/learner_utils.py rename to src/learners/learner_utils.py diff --git a/src/bw_learner/learner_utils_old.py b/src/learners/learner_utils_old.py similarity index 100% rename from src/bw_learner/learner_utils_old.py rename to src/learners/learner_utils_old.py diff --git a/src/main.py b/src/main.py index 1bfa91b0..6e568df5 100644 --- a/src/main.py +++ b/src/main.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING import asyncio import time -from src.bw_learner.expression_auto_check_task import ExpressionAutoCheckTask +from src.learners.expression_auto_check_task import ExpressionAutoCheckTask from src.chat.emoji_system.emoji_manager import emoji_manager from src.chat.message_receive.bot import chat_bot from src.chat.message_receive.chat_manager import chat_manager diff --git a/src/memory_system/memory_retrieval.py b/src/memory_system/memory_retrieval.py index 49e5ca02..2eadd05a 100644 --- a/src/memory_system/memory_retrieval.py +++ b/src/memory_system/memory_retrieval.py @@ -14,7 +14,7 @@ from src.common.database.database_model import ThinkingQuestion from src.memory_system.retrieval_tools import get_tool_registry, init_all_tools from src.llm_models.payload_content.message import MessageBuilder, RoleType, Message from src.chat.message_receive.chat_manager import chat_manager as _chat_manager -from src.bw_learner.jargon_explainer_old import retrieve_concepts_with_jargon +from src.learners.jargon_explainer_old import retrieve_concepts_with_jargon logger = get_logger("memory_retrieval") diff --git a/src/memory_system/retrieval_tools/query_words.py b/src/memory_system/retrieval_tools/query_words.py index 66fb3c46..ee28b934 100644 --- a/src/memory_system/retrieval_tools/query_words.py +++ b/src/memory_system/retrieval_tools/query_words.py @@ -4,7 +4,7 @@ """ from src.common.logger import get_logger -from src.bw_learner.jargon_explainer_old import retrieve_concepts_with_jargon +from src.learners.jargon_explainer_old import retrieve_concepts_with_jargon from .tool_registry import register_memory_retrieval_tool logger = get_logger("memory_retrieval_tools") diff --git a/src/person_info/person_info.py b/src/person_info/person_info.py index 960de4aa..15ef0049 100644 --- a/src/person_info/person_info.py +++ b/src/person_info/person_info.py @@ -1,24 +1,24 @@ -import hashlib +from datetime import datetime +from typing import Dict, Optional, Union + import asyncio +import hashlib import json -import time -import random import math +import random +import time from json_repair import repair_json -from typing import Union, Optional, Dict, List -from datetime import datetime -from sqlalchemy import or_ from sqlmodel import col, select -from src.common.logger import get_logger +from src.chat.message_receive.chat_manager import chat_manager as _chat_manager +from src.common.data_models.person_info_data_model import dump_group_cardname_records, parse_group_cardname_json from src.common.database.database import get_db_session from src.common.database.database_model import PersonInfo -from src.llm_models.utils_model import LLMRequest +from src.common.logger import get_logger from src.config.config import global_config, model_config -from src.chat.message_receive.chat_manager import chat_manager as _chat_manager -from src.services.memory_service import memory_service +from src.llm_models.utils_model import LLMRequest logger = get_logger("person_info") @@ -28,6 +28,32 @@ relation_selection_model = LLMRequest( ) +def _to_group_cardname_records(group_cardname_json: Optional[str]) -> list[dict[str, str]]: + """将数据库中的群名片 JSON 转换为 `Person` 内部使用的结构。 + + Args: + group_cardname_json: 数据库存储的群名片 JSON 字符串。 + + Returns: + list[dict[str, str]]: 统一使用 `group_cardname` 键名的群名片列表。 + + Raises: + json.JSONDecodeError: 当 JSON 文本格式非法时抛出。 + TypeError: 当输入值类型不符合 `json.loads()` 要求时抛出。 + """ + group_cardname_list = parse_group_cardname_json(group_cardname_json) + if not group_cardname_list: + return [] + + return [ + { + "group_id": group_cardname.group_id, + "group_cardname": group_cardname.group_cardname, + } + for group_cardname in group_cardname_list + ] + + def get_person_id(platform: str, user_id: Union[int, str]) -> str: """获取唯一id""" if "-" in platform: @@ -39,60 +65,16 @@ def get_person_id(platform: str, user_id: Union[int, str]) -> str: def get_person_id_by_person_name(person_name: str) -> str: """根据用户名获取用户ID""" - clean_name = str(person_name or "").strip() - if not clean_name: - return "" try: with get_db_session() as session: - statement = ( - select(PersonInfo) - .where( - or_( - col(PersonInfo.person_name) == clean_name, - col(PersonInfo.user_nickname) == clean_name, - ) - ) - .limit(1) - ) - record = session.exec(statement).first() - if record and record.person_id: - return record.person_id - - statement = ( - select(PersonInfo) - .where(PersonInfo.group_cardname.contains(clean_name)) - .limit(1) - ) + statement = select(PersonInfo).where(col(PersonInfo.person_name) == person_name).limit(1) record = session.exec(statement).first() return record.person_id if record else "" except Exception as e: - logger.error(f"根据用户名 {clean_name} 获取用户ID时出错: {e}") + logger.error(f"根据用户名 {person_name} 获取用户ID时出错: {e}") return "" -def resolve_person_id_for_memory( - *, - person_name: str = "", - platform: str = "", - user_id: Optional[Union[int, str]] = None, -) -> str: - """统一人物记忆链路中的 person_id 解析。 - - 优先使用已知的人物名称/别名,其次退回到平台 + user_id 的稳定 ID。 - """ - name_token = str(person_name or "").strip() - if name_token: - resolved = get_person_id_by_person_name(name_token) - if resolved: - return resolved - - platform_token = str(platform or "").strip() - user_token = str(user_id or "").strip() - if platform_token and user_token: - return get_person_id(platform_token, user_token) - return "" - - def is_person_known( person_id: Optional[str] = None, user_id: Optional[str] = None, @@ -277,7 +259,7 @@ class Person: person.know_since = time.time() person.last_know = time.time() person.memory_points = [] - person.group_nick_name = [] # 初始化群昵称列表 + person.group_cardname_list = [] # 初始化群名片列表 # 如果是群聊,添加群昵称 if group_id and group_nick_name: @@ -315,7 +297,7 @@ class Person: self.platform = platform self.nickname = global_config.bot.nickname self.person_name = global_config.bot.nickname - self.group_nick_name: list[dict[str, str]] = [] + self.group_cardname_list: list[dict[str, str]] = [] return self.user_id = "" @@ -354,7 +336,7 @@ class Person: self.know_since = None self.last_know: Optional[float] = None self.memory_points = [] - self.group_nick_name: list[dict[str, str]] = [] # 群昵称列表,存储 {"group_id": str, "group_nick_name": str} + self.group_cardname_list: list[dict[str, str]] = [] # 群名片列表,存储 {"group_id": str, "group_cardname": str} # 从数据库加载数据 self.load_from_database() @@ -454,16 +436,16 @@ class Person: return # 检查是否已存在该群号的记录 - for item in self.group_nick_name: + for item in self.group_cardname_list: if item.get("group_id") == group_id: # 更新现有记录 - item["group_nick_name"] = group_nick_name + item["group_cardname"] = group_nick_name self.sync_to_database() logger.debug(f"更新用户 {self.person_id} 在群 {group_id} 的群昵称为 {group_nick_name}") return # 添加新记录 - self.group_nick_name.append({"group_id": group_id, "group_nick_name": group_nick_name}) + self.group_cardname_list.append({"group_id": group_id, "group_cardname": group_nick_name}) self.sync_to_database() logger.debug(f"添加用户 {self.person_id} 在群 {group_id} 的群昵称 {group_nick_name}") @@ -498,20 +480,15 @@ class Person: else: self.memory_points = [] - # 处理group_nick_name字段(JSON格式的列表) + # 处理 group_cardname 字段(JSON 格式的列表) if record.group_cardname: try: - loaded_group_nick_names = json.loads(record.group_cardname) - # 确保是列表格式 - if isinstance(loaded_group_nick_names, list): - self.group_nick_name = loaded_group_nick_names - else: - self.group_nick_name = [] + self.group_cardname_list = _to_group_cardname_records(record.group_cardname) except (json.JSONDecodeError, TypeError): logger.warning(f"解析用户 {self.person_id} 的group_cardname字段失败,使用默认值") - self.group_nick_name = [] + self.group_cardname_list = [] else: - self.group_nick_name = [] + self.group_cardname_list = [] logger.debug(f"已从数据库加载用户 {self.person_id} 的信息") else: @@ -532,11 +509,7 @@ class Person: if self.memory_points else json.dumps([], ensure_ascii=False) ) - group_nickname_value = ( - json.dumps(self.group_nick_name, ensure_ascii=False) - if self.group_nick_name - else json.dumps([], ensure_ascii=False) - ) + group_cardname_value = dump_group_cardname_records(self.group_cardname_list) first_known_time = datetime.fromtimestamp(self.know_since) if self.know_since else None last_known_time = datetime.fromtimestamp(self.last_know) if self.last_know else None @@ -556,7 +529,7 @@ class Person: record.first_known_time = first_known_time record.last_known_time = last_known_time record.memory_points = memory_points_value - record.group_nickname = group_nickname_value + record.group_cardname = group_cardname_value session.add(record) logger.debug(f"已同步用户 {self.person_id} 的信息到数据库") else: @@ -572,7 +545,7 @@ class Person: first_known_time=first_known_time, last_known_time=last_known_time, memory_points=memory_points_value, - group_nickname=group_nickname_value, + group_cardname=group_cardname_value, ) session.add(record) logger.debug(f"已创建用户 {self.person_id} 的信息到数据库") @@ -583,79 +556,79 @@ class Person: async def build_relationship(self, chat_content: str = "", info_type=""): if not self.is_known: return "" + # 构建points文本 + nickname_str = "" if self.person_name != self.nickname: nickname_str = f"(ta在{self.platform}上的昵称是{self.nickname})" - async def _select_traits(query_text: str, traits: List[str], limit: int = 3) -> List[str]: - clean_traits = [trait.strip() for trait in traits if isinstance(trait, str) and trait.strip()] - if not clean_traits: - return [] - if not query_text: - return clean_traits[:limit] + relation_info = "" - numbered_traits = "\n".join(f"{index}. {trait}" for index, trait in enumerate(clean_traits, start=1)) - prompt = f"""当前关注内容: -{query_text} + points_text = "" + category_list = self.get_all_category() -候选人物信息: -{numbered_traits} + if chat_content: + prompt = f"""当前聊天内容: +{chat_content} -请从候选人物信息中选择与当前关注内容最相关的编号,并用<>包裹输出,不要输出其他内容。 -例如: -<1><3> -如果都不相关,请输出""" +分类列表: +{category_list} +**要求**:请你根据当前聊天内容,从以下分类中选择一个与聊天内容相关的分类,并用<>包裹输出,不要输出其他内容,不要输出引号或[],严格用<>包裹: +例如: +<分类1><分类2><分类3>...... +如果没有相关的分类,请输出""" - try: - response, _ = await relation_selection_model.generate_response_async(prompt) - selected_traits: List[str] = [] - for raw_index in extract_categories_from_response(response): - if raw_index == "none": - return [] - try: - trait_index = int(raw_index) - 1 - except ValueError: - continue - if 0 <= trait_index < len(clean_traits): - trait = clean_traits[trait_index] - if trait not in selected_traits: - selected_traits.append(trait) - if selected_traits: - return selected_traits[:limit] - except Exception as e: - logger.debug(f"筛选人物画像信息失败,使用默认画像摘要: {e}") + response, _ = await relation_selection_model.generate_response_async(prompt) + # print(prompt) + # print(response) + category_list = extract_categories_from_response(response) + if "none" not in category_list: + for category in category_list: + random_memory = self.get_random_memory_by_category(category, 2) + if random_memory: + random_memory_str = "\n".join( + [get_memory_content_from_memory(memory) for memory in random_memory] + ) + points_text = f"有关 {category} 的内容:{random_memory_str}" + break + elif info_type: + prompt = f"""你需要获取用户{self.person_name}的 **{info_type}** 信息。 - return clean_traits[:limit] - - profile = await memory_service.get_person_profile(self.person_id, limit=8) - relation_parts: List[str] = [] - if profile.summary.strip(): - relation_parts.append(profile.summary.strip()) - - query_text = str(chat_content or info_type or "").strip() - selected_traits = await _select_traits(query_text, profile.traits, limit=3) - if not selected_traits and not query_text: - selected_traits = [trait for trait in profile.traits if trait][:2] - - for trait in selected_traits: - clean_trait = str(trait).strip() - if clean_trait and clean_trait not in relation_parts: - relation_parts.append(clean_trait) - - for evidence in profile.evidence: - content = str(evidence.get("content", "") or "").strip() - if content and content not in relation_parts: - relation_parts.append(content) - if len(relation_parts) >= 4: - break +现有信息类别列表: +{category_list} +**要求**:请你根据**{info_type}**,从以下分类中选择一个与**{info_type}**相关的分类,并用<>包裹输出,不要输出其他内容,不要输出引号或[],严格用<>包裹: +例如: +<分类1><分类2><分类3>...... +如果没有相关的分类,请输出""" + response, _ = await relation_selection_model.generate_response_async(prompt) + # print(prompt) + # print(response) + category_list = extract_categories_from_response(response) + if "none" not in category_list: + for category in category_list: + random_memory = self.get_random_memory_by_category(category, 3) + if random_memory: + random_memory_str = "\n".join( + [get_memory_content_from_memory(memory) for memory in random_memory] + ) + points_text = f"有关 {category} 的内容:{random_memory_str}" + break + else: + for category in category_list: + random_memory = self.get_random_memory_by_category(category, 1)[0] + if random_memory: + points_text = f"有关 {category} 的内容:{get_memory_content_from_memory(random_memory)}" + break points_info = "" - if relation_parts: - points_info = f"你还记得有关{self.person_name}的内容:{';'.join(relation_parts[:3])}" + if points_text: + points_info = f"你还记得有关{self.person_name}的内容:{points_text}" if not (nickname_str or points_info): return "" - return f"{self.person_name}:{nickname_str}{points_info}" + relation_info = f"{self.person_name}:{nickname_str}{points_info}" + + return relation_info class PersonInfoManager: @@ -822,7 +795,7 @@ person_info_manager = PersonInfoManager() async def store_person_memory_from_answer(person_name: str, memory_content: str, chat_id: str) -> None: - """将人物事实写入统一长期记忆 + """将人物信息存入person_info的memory_points Args: person_name: 人物名称 @@ -830,11 +803,6 @@ async def store_person_memory_from_answer(person_name: str, memory_content: str, chat_id: 聊天ID """ try: - content = str(memory_content or "").strip() - if not content: - logger.debug("人物记忆内容为空,跳过写入") - return - # 从 chat_id 获取 session session = _chat_manager.get_session_by_session_id(chat_id) if not session: @@ -845,14 +813,16 @@ async def store_person_memory_from_answer(person_name: str, memory_content: str, # 尝试从person_name查找person_id # 首先尝试通过person_name查找 - person_id = resolve_person_id_for_memory( - person_name=person_name, - platform=platform, - user_id=session.user_id, - ) + person_id = get_person_id_by_person_name(person_name) + if not person_id: - logger.warning(f"无法确定person_id for person_name: {person_name}, chat_id: {chat_id}") - return + # 如果通过person_name找不到,尝试从 session 获取 user_id + if platform and session.user_id: + user_id = session.user_id + person_id = get_person_id(platform, user_id) + else: + logger.warning(f"无法确定person_id for person_name: {person_name}, chat_id: {chat_id}") + return # 创建或获取Person对象 person = Person(person_id=person_id) @@ -861,34 +831,39 @@ async def store_person_memory_from_answer(person_name: str, memory_content: str, logger.warning(f"用户 {person_name} (person_id: {person_id}) 尚未认识,无法存储记忆") return - memory_hash = hashlib.sha256(f"{person_id}\n{content}".encode("utf-8")).hexdigest()[:16] - result = await memory_service.ingest_text( - external_id=f"person_fact:{person_id}:{memory_hash}", - source_type="person_fact", - text=content, - chat_id=chat_id, - person_ids=[person_id], - participants=[person.person_name or person_name], - timestamp=time.time(), - tags=["person_fact"], - metadata={ - "person_id": person_id, - "person_name": person.person_name or person_name, - "platform": platform, - "source": "person_info.store_person_memory_from_answer", - }, - respect_filter=True, - user_id=str(session.user_id or "").strip(), - group_id=str(session.group_id or "").strip(), - ) + # 确定记忆分类(可以根据memory_content判断,这里使用通用分类) + category = "其他" # 默认分类,可以根据需要调整 - if result.success: - if result.detail == "chat_filtered": - logger.debug(f"人物长期记忆被聊天过滤策略跳过: {person_name} (person_id: {person_id})") - else: - logger.info(f"成功写入人物长期记忆: {person_name} (person_id: {person_id})") + # 记忆点格式:category:content:weight + weight = "1.0" # 默认权重 + memory_point = f"{category}:{memory_content}:{weight}" + + # 添加到memory_points + if not person.memory_points: + person.memory_points = [] + + # 检查是否已存在相似的记忆点(避免重复) + is_duplicate = False + for existing_point in person.memory_points: + if existing_point and isinstance(existing_point, str): + parts = existing_point.split(":", 2) + if len(parts) >= 2: + existing_content = parts[1].strip() + # 简单相似度检查(如果内容相同或非常相似,则跳过) + if ( + existing_content == memory_content + or memory_content in existing_content + or existing_content in memory_content + ): + is_duplicate = True + break + + if not is_duplicate: + person.memory_points.append(memory_point) + person.sync_to_database() + logger.info(f"成功添加记忆点到 {person_name} (person_id: {person_id}): {memory_point}") else: - logger.warning(f"写入人物长期记忆失败: {person_name} (person_id: {person_id}) | {result.detail}") + logger.debug(f"记忆点已存在,跳过: {memory_point}") except Exception as e: logger.error(f"存储人物记忆失败: {e}") diff --git a/src/platform_io/__init__.py b/src/platform_io/__init__.py new file mode 100644 index 00000000..c91535d1 --- /dev/null +++ b/src/platform_io/__init__.py @@ -0,0 +1,34 @@ +"""导出 Platform IO 层的公开入口。 + +当前仍处于地基阶段,调用方应优先从这里导入共享类型和全局管理器, +而不是直接依赖更底层的私有子模块。 +""" + +from .manager import PlatformIOManager, get_platform_io_manager +from .route_key_factory import RouteKeyFactory +from .routing import RouteTable +from .types import ( + DeliveryBatch, + DeliveryReceipt, + DeliveryStatus, + DriverDescriptor, + DriverKind, + InboundMessageEnvelope, + RouteBinding, + RouteKey, +) + +__all__ = [ + "DeliveryBatch", + "DeliveryReceipt", + "DeliveryStatus", + "DriverDescriptor", + "DriverKind", + "InboundMessageEnvelope", + "PlatformIOManager", + "RouteKeyFactory", + "RouteBinding", + "RouteKey", + "RouteTable", + "get_platform_io_manager", +] diff --git a/src/platform_io/dedupe.py b/src/platform_io/dedupe.py new file mode 100644 index 00000000..4c5c55a2 --- /dev/null +++ b/src/platform_io/dedupe.py @@ -0,0 +1,133 @@ +"""提供 Platform IO 的轻量入站消息去重能力。 + +当前实现基于 ``dict + heapq``: +- ``dict`` 保存去重键到过期时间的映射 +- ``heapq`` 维护按过期时间排序的小顶堆 + +这样就不需要在每次检查时全表扫描,而是通过懒清理逐步弹出已经过期 +或已经失效的堆节点。 +""" + +from typing import Dict, List, Tuple + +import heapq +import time + + +class MessageDeduplicator: + """使用基于 TTL 的内存缓存进行入站消息去重。 + + 主要用于解决同一条外部消息被重复送入 Core 的问题,例如双路径并存、 + 适配器重试、重连或重复回调等场景。Broker 可以借助这个组件在进入 + Core 前先拦住重复投递,避免重复处理、重复回复和重复入库。 + + 当前实现使用 ``dict + heapq`` 维护过期时间: + - ``dict`` 负责 ``O(1)`` 级别的去重键查找 + - ``heapq`` 负责按过期时间顺序做懒清理 + + 这比“每次调用都全表扫描过期项”的实现更适合高吞吐消息场景。 + + Notes: + 复杂度说明如下,设 ``n`` 为当前缓存中的有效去重键数量: + + - 单次 ``mark_seen()`` 在常见路径下的时间复杂度接近 ``O(log n)`` + - 从长期摊还角度看,``mark_seen()`` 的时间复杂度也接近 ``O(log n)`` + - 如果某次调用恰好触发一批过期键的集中清理,则该次调用的最坏时间复杂度 + 可达到 ``O(k log n)``,其中 ``k`` 为本次被弹出或清理的键数量 + - 空间复杂度为 ``O(n)`` + """ + + def __init__(self, ttl_seconds: float = 300.0, max_entries: int = 10000) -> None: + """初始化去重器。 + + Args: + ttl_seconds: 每个去重键在缓存中的保留时长,单位为秒。 + max_entries: 缓存允许保留的最大有效键数量,超出后会触发 + 机会性淘汰。 + + Raises: + ValueError: 当 ``ttl_seconds`` 或 ``max_entries`` 非正数时抛出。 + """ + if ttl_seconds <= 0: + raise ValueError("ttl_seconds 必须大于 0") + if max_entries <= 0: + raise ValueError("max_entries 必须大于 0") + + self._ttl_seconds = ttl_seconds + self._max_entries = max_entries + self._expire_heap: List[Tuple[float, str]] = [] + self._seen: Dict[str, float] = {} + + def mark_seen(self, dedupe_key: str) -> bool: + """标记一条去重键已经出现过。 + + Args: + dedupe_key: 能稳定标识一条外部入站消息的去重键。 + + Returns: + bool: 若该键在当前 TTL 窗口内首次出现则返回 ``True``, + 否则返回 ``False``。 + + Notes: + 方法会先基于小顶堆做一次懒清理,再判断当前键是否仍在有效期内。 + 如果缓存已达到上限,则会优先淘汰“最早过期的仍然有效的键”。 + + 复杂度方面,常见路径下该方法接近 ``O(log n)``;如果恰好需要 + 集中清理一批过期键,则单次调用最坏可达到 ``O(k log n)``。 + """ + now = time.monotonic() + self._purge_expired(now) + + expires_at = self._seen.get(dedupe_key) + if expires_at is not None and expires_at > now: + return False + + if len(self._seen) >= self._max_entries: + self._evict_earliest_live() + + expires_at = now + self._ttl_seconds + self._seen[dedupe_key] = expires_at + heapq.heappush(self._expire_heap, (expires_at, dedupe_key)) + return True + + def clear(self) -> None: + """清空全部去重缓存。""" + self._expire_heap.clear() + self._seen.clear() + + def _purge_expired(self, now: float) -> None: + """从缓存中清理已经过期的去重键。 + + Args: + now: 当前单调时钟时间戳。 + + Notes: + 堆中可能存在旧版本节点。例如同一个 ``dedupe_key`` 被重新写入后, + 旧的过期时间节点仍会留在堆里。这里会通过和 ``dict`` 中当前值比对, + 跳过这类失效节点。 + """ + while self._expire_heap and self._expire_heap[0][0] <= now: + expires_at, dedupe_key = heapq.heappop(self._expire_heap) + current_expires_at = self._seen.get(dedupe_key) + if current_expires_at is None: + continue + if current_expires_at != expires_at: + continue + self._seen.pop(dedupe_key, None) + + def _evict_earliest_live(self) -> None: + """当缓存达到容量上限时,淘汰一条最早过期的有效键。 + + Notes: + 堆顶可能是已经过期或已失效的旧节点,因此这里同样需要循环弹出, + 直到找到一条当前仍然在 ``dict`` 中生效的键。 + """ + while self._expire_heap: + expires_at, dedupe_key = heapq.heappop(self._expire_heap) + current_expires_at = self._seen.get(dedupe_key) + if current_expires_at is None: + continue + if current_expires_at != expires_at: + continue + self._seen.pop(dedupe_key, None) + return diff --git a/src/platform_io/drivers/__init__.py b/src/platform_io/drivers/__init__.py new file mode 100644 index 00000000..b12120cf --- /dev/null +++ b/src/platform_io/drivers/__init__.py @@ -0,0 +1,11 @@ +"""导出 Platform IO 层的公开驱动类型。""" + +from .base import PlatformIODriver +from .legacy_driver import LegacyPlatformDriver +from .plugin_driver import PluginPlatformDriver + +__all__ = [ + "LegacyPlatformDriver", + "PlatformIODriver", + "PluginPlatformDriver", +] diff --git a/src/platform_io/drivers/base.py b/src/platform_io/drivers/base.py new file mode 100644 index 00000000..c6173d8c --- /dev/null +++ b/src/platform_io/drivers/base.py @@ -0,0 +1,104 @@ +"""定义 Platform IO 传输驱动的基础抽象协议。""" + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional + +from src.platform_io.types import DeliveryReceipt, DriverDescriptor, InboundMessageEnvelope, RouteKey + +if TYPE_CHECKING: + from src.chat.message_receive.message import SessionMessage + +InboundHandler = Callable[[InboundMessageEnvelope], Awaitable[bool]] + + +class PlatformIODriver(ABC): + """定义所有 Platform IO 驱动都必须实现的最小契约。 + + 当前实现故意保持接口很小,让中间层可以先落地,再逐步把 legacy + 与 plugin 路径的真实收发能力迁入这套协议之下。 + """ + + def __init__(self, descriptor: DriverDescriptor) -> None: + """使用驱动描述对象初始化驱动。 + + Args: + descriptor: 注册到 Broker 中的静态驱动元数据。 + """ + self._descriptor = descriptor + self._inbound_handler: Optional[InboundHandler] = None + + @property + def descriptor(self) -> DriverDescriptor: + """返回当前驱动的描述对象。 + + Returns: + DriverDescriptor: 当前驱动实例对应的描述对象。 + """ + return self._descriptor + + @property + def driver_id(self) -> str: + """返回驱动标识。 + + Returns: + str: 当前驱动的唯一 ID。 + """ + return self._descriptor.driver_id + + def set_inbound_handler(self, handler: InboundHandler) -> None: + """注册入站消息交回 Broker 的回调函数。 + + Args: + handler: 将规范化入站封装继续转发给 Broker 的异步回调。 + """ + self._inbound_handler = handler + + def clear_inbound_handler(self) -> None: + """清除当前注册的入站回调函数。""" + self._inbound_handler = None + + async def emit_inbound(self, envelope: InboundMessageEnvelope) -> bool: + """将一条入站封装转交给 Broker 回调。 + + Args: + envelope: 由驱动产出的规范化入站封装。 + + Returns: + bool: 若 Broker 接受该入站消息则返回 ``True``,否则返回 ``False``。 + """ + + if self._inbound_handler is None: + return False + return await self._inbound_handler(envelope) + + async def start(self) -> None: + """启动驱动生命周期。 + + 子类后续若需要初始化逻辑,可以覆盖这个钩子。 + """ + return None + + async def stop(self) -> None: + """停止驱动生命周期。 + + 子类后续若需要清理逻辑,可以覆盖这个钩子。 + """ + return None + + @abstractmethod + async def send_message( + self, + message: "SessionMessage", + route_key: RouteKey, + metadata: Optional[Dict[str, Any]] = None, + ) -> DeliveryReceipt: + """通过具体驱动发送一条消息。 + + Args: + message: 要投递的内部会话消息。 + route_key: Broker 为本次投递选中的路由键。 + metadata: 本次出站投递可选的 Broker 侧元数据。 + + Returns: + DeliveryReceipt: 规范化后的投递结果。 + """ diff --git a/src/platform_io/drivers/legacy_driver.py b/src/platform_io/drivers/legacy_driver.py new file mode 100644 index 00000000..ef90c772 --- /dev/null +++ b/src/platform_io/drivers/legacy_driver.py @@ -0,0 +1,92 @@ +"""提供 Platform IO 的 legacy 传输驱动实现。""" + +from typing import TYPE_CHECKING, Any, Dict, Optional + +from src.platform_io.drivers.base import PlatformIODriver +from src.platform_io.types import DeliveryReceipt, DeliveryStatus, DriverDescriptor, DriverKind, RouteKey + +if TYPE_CHECKING: + from src.chat.message_receive.message import SessionMessage + + +class LegacyPlatformDriver(PlatformIODriver): + """面向 ``UniversalMessageSender`` 旧链的 Platform IO 驱动。""" + + def __init__( + self, + driver_id: str, + platform: str, + account_id: Optional[str] = None, + scope: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> None: + """初始化一个 legacy 驱动描述对象。 + + Args: + driver_id: Broker 内的唯一驱动 ID。 + platform: 该 legacy 适配器链路负责的平台。 + account_id: 可选的账号 ID。 + scope: 可选的额外路由作用域。 + metadata: 可选的额外驱动元数据。 + """ + descriptor = DriverDescriptor( + driver_id=driver_id, + kind=DriverKind.LEGACY, + platform=platform, + account_id=account_id, + scope=scope, + metadata=metadata or {}, + ) + super().__init__(descriptor) + + async def send_message( + self, + message: "SessionMessage", + route_key: RouteKey, + metadata: Optional[Dict[str, Any]] = None, + ) -> DeliveryReceipt: + """通过旧链发送一条已经过预处理的消息。 + + Args: + message: 要投递的内部会话消息。 + route_key: Broker 为本次投递选择的路由键。 + metadata: 本次出站投递可选的 Broker 侧元数据。 + + Returns: + DeliveryReceipt: 规范化后的发送回执。 + """ + from src.chat.message_receive.uni_message_sender import send_prepared_message_to_platform + + show_log = False + if isinstance(metadata, dict): + show_log = bool(metadata.get("show_log", False)) + + try: + sent = await send_prepared_message_to_platform(message, show_log=show_log) + except Exception as exc: + return DeliveryReceipt( + internal_message_id=message.message_id, + route_key=route_key, + status=DeliveryStatus.FAILED, + driver_id=self.driver_id, + driver_kind=self.descriptor.kind, + error=str(exc), + ) + + if not sent: + return DeliveryReceipt( + internal_message_id=message.message_id, + route_key=route_key, + status=DeliveryStatus.FAILED, + driver_id=self.driver_id, + driver_kind=self.descriptor.kind, + error="旧链发送失败", + ) + + return DeliveryReceipt( + internal_message_id=message.message_id, + route_key=route_key, + status=DeliveryStatus.SENT, + driver_id=self.driver_id, + driver_kind=self.descriptor.kind, + ) diff --git a/src/platform_io/drivers/plugin_driver.py b/src/platform_io/drivers/plugin_driver.py new file mode 100644 index 00000000..c03204ad --- /dev/null +++ b/src/platform_io/drivers/plugin_driver.py @@ -0,0 +1,211 @@ +"""提供 Platform IO 的插件消息网关驱动实现。""" + +from typing import TYPE_CHECKING, Any, Dict, Optional, Protocol + +from src.platform_io.drivers.base import PlatformIODriver +from src.platform_io.types import DeliveryReceipt, DeliveryStatus, DriverDescriptor, DriverKind, RouteKey + +if TYPE_CHECKING: + from src.chat.message_receive.message import SessionMessage + + +class _GatewaySupervisorProtocol(Protocol): + """消息网关驱动依赖的 Supervisor 最小协议。""" + + async def invoke_message_gateway( + self, + plugin_id: str, + component_name: str, + args: Optional[Dict[str, Any]] = None, + timeout_ms: int = 30000, + ) -> Any: + """调用插件声明的消息网关方法。""" + + +class PluginPlatformDriver(PlatformIODriver): + """面向插件消息网关链路的 Platform IO 驱动。""" + + def __init__( + self, + driver_id: str, + platform: str, + supervisor: _GatewaySupervisorProtocol, + component_name: str, + *, + supports_send: bool, + account_id: Optional[str] = None, + scope: Optional[str] = None, + plugin_id: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> None: + """初始化一个插件消息网关驱动。 + + Args: + driver_id: Broker 内的唯一驱动 ID。 + platform: 该消息网关负责的平台名称。 + supervisor: 持有该插件的 Supervisor。 + component_name: 出站时要调用的网关组件名称。 + supports_send: 当前驱动是否具备出站能力。 + account_id: 可选的账号 ID 或 self ID。 + scope: 可选的额外路由作用域。 + plugin_id: 拥有该实现的插件 ID。 + metadata: 可选的额外驱动元数据。 + """ + + descriptor = DriverDescriptor( + driver_id=driver_id, + kind=DriverKind.PLUGIN, + platform=platform, + account_id=account_id, + scope=scope, + plugin_id=plugin_id, + metadata=metadata or {}, + ) + super().__init__(descriptor) + self._supervisor = supervisor + self._component_name = component_name + self._supports_send = supports_send + + async def send_message( + self, + message: "SessionMessage", + route_key: RouteKey, + metadata: Optional[Dict[str, Any]] = None, + ) -> DeliveryReceipt: + """通过插件消息网关发送消息。 + + Args: + message: 要投递的内部会话消息。 + route_key: Broker 为本次投递选择的路由键。 + metadata: 可选的发送元数据。 + + Returns: + DeliveryReceipt: 规范化后的发送回执。 + """ + + if not self._supports_send: + return DeliveryReceipt( + internal_message_id=message.message_id, + route_key=route_key, + status=DeliveryStatus.FAILED, + driver_id=self.driver_id, + driver_kind=self.descriptor.kind, + error="当前消息网关仅支持接收,不支持发送", + ) + + from src.plugin_runtime.host.message_utils import PluginMessageUtils + + plugin_id = self.descriptor.plugin_id or "" + if not plugin_id: + return DeliveryReceipt( + internal_message_id=message.message_id, + route_key=route_key, + status=DeliveryStatus.FAILED, + driver_id=self.driver_id, + driver_kind=self.descriptor.kind, + error="插件消息网关驱动缺少 plugin_id", + ) + + try: + message_dict = PluginMessageUtils._session_message_to_dict(message) + response = await self._supervisor.invoke_message_gateway( + plugin_id=plugin_id, + component_name=self._component_name, + args={ + "message": message_dict, + "route": { + "platform": route_key.platform, + "account_id": route_key.account_id, + "scope": route_key.scope, + }, + "metadata": metadata or {}, + }, + timeout_ms=30000, + ) + except Exception as exc: + return DeliveryReceipt( + internal_message_id=message.message_id, + route_key=route_key, + status=DeliveryStatus.FAILED, + driver_id=self.driver_id, + driver_kind=self.descriptor.kind, + error=str(exc), + ) + + return self._build_receipt(message.message_id, route_key, response) + + def _build_receipt(self, internal_message_id: str, route_key: RouteKey, response: Any) -> DeliveryReceipt: + """将网关调用响应归一化为出站回执。 + + Args: + internal_message_id: 内部消息 ID。 + route_key: 本次投递的路由键。 + response: Supervisor 返回的 RPC 响应对象。 + + Returns: + DeliveryReceipt: 标准化后的出站回执。 + """ + + if getattr(response, "error", None): + error = response.error.get("message", "消息网关发送失败") + return DeliveryReceipt( + internal_message_id=internal_message_id, + route_key=route_key, + status=DeliveryStatus.FAILED, + driver_id=self.driver_id, + driver_kind=self.descriptor.kind, + error=error, + ) + + payload = getattr(response, "payload", {}) + invoke_success = bool(payload.get("success", False)) if isinstance(payload, dict) else False + if not invoke_success: + return DeliveryReceipt( + internal_message_id=internal_message_id, + route_key=route_key, + status=DeliveryStatus.FAILED, + driver_id=self.driver_id, + driver_kind=self.descriptor.kind, + error=str(payload.get("result", "消息网关发送失败")) if isinstance(payload, dict) else "消息网关发送失败", + ) + + result = payload.get("result") if isinstance(payload, dict) else None + if isinstance(result, dict): + if result.get("success") is False: + return DeliveryReceipt( + internal_message_id=internal_message_id, + route_key=route_key, + status=DeliveryStatus.FAILED, + driver_id=self.driver_id, + driver_kind=self.descriptor.kind, + error=str(result.get("error", "消息网关发送失败")), + metadata=result.get("metadata", {}) if isinstance(result.get("metadata"), dict) else {}, + ) + external_message_id = str(result.get("external_message_id") or result.get("message_id") or "") or None + return DeliveryReceipt( + internal_message_id=internal_message_id, + route_key=route_key, + status=DeliveryStatus.SENT, + driver_id=self.driver_id, + driver_kind=self.descriptor.kind, + external_message_id=external_message_id, + metadata=result.get("metadata", {}) if isinstance(result.get("metadata"), dict) else {}, + ) + + if isinstance(result, str) and result.strip(): + return DeliveryReceipt( + internal_message_id=internal_message_id, + route_key=route_key, + status=DeliveryStatus.SENT, + driver_id=self.driver_id, + driver_kind=self.descriptor.kind, + external_message_id=result.strip(), + ) + + return DeliveryReceipt( + internal_message_id=internal_message_id, + route_key=route_key, + status=DeliveryStatus.SENT, + driver_id=self.driver_id, + driver_kind=self.descriptor.kind, + ) diff --git a/src/platform_io/manager.py b/src/platform_io/manager.py new file mode 100644 index 00000000..dee553a6 --- /dev/null +++ b/src/platform_io/manager.py @@ -0,0 +1,611 @@ +"""提供 Platform IO 层的中心 Broker 管理器。""" + +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional + +from src.common.logger import get_logger +from src.platform_io.drivers.base import PlatformIODriver + +from .dedupe import MessageDeduplicator +from .outbound_tracker import OutboundTracker +from .route_key_factory import RouteKeyFactory +from .registry import DriverRegistry +from .routing import RouteTable +from .types import DeliveryBatch, DeliveryReceipt, DeliveryStatus, InboundMessageEnvelope, RouteBinding, RouteKey + +if TYPE_CHECKING: + from src.chat.message_receive.message import SessionMessage + +logger = get_logger("platform_io.manager") + +InboundDispatcher = Callable[[InboundMessageEnvelope], Awaitable[None]] + + +class PlatformIOManager: + """统一协调平台消息 IO 的路由、去重与状态跟踪。 + + 与旧实现不同,这个管理器不再负责“多条链路谁该接管平台”的裁决, + 只维护发送表和接收表两张轻量路由表: + + - 发送时:解析所有命中的发送绑定并全部投递。 + - 接收时:只校验当前驱动是否已登记为可接收链路,然后全部放行给上层。 + - 去重时:仅对单条链路做技术性重放抑制,不做跨链路语义去重。 + """ + + def __init__(self) -> None: + """初始化 Broker 管理器及其内存状态。""" + self._driver_registry = DriverRegistry() + self._send_route_table = RouteTable() + self._receive_route_table = RouteTable() + self._legacy_send_drivers: Dict[str, PlatformIODriver] = {} + self._deduplicator = MessageDeduplicator() + self._outbound_tracker = OutboundTracker() + self._inbound_dispatcher: Optional[InboundDispatcher] = None + self._started = False + + @property + def is_started(self) -> bool: + """返回 Broker 当前是否已进入运行态。 + + Returns: + bool: 若 Broker 已启动则返回 ``True``。 + """ + return self._started + + async def start(self) -> None: + """启动 Broker,并依次启动当前已注册的全部驱动。 + + Raises: + Exception: 当某个驱动启动失败时,异常会继续上抛;已成功启动的驱动 + 会被自动回滚停止。 + """ + if self._started: + return + + started_drivers: List[PlatformIODriver] = [] + try: + for driver in self._driver_registry.list(): + await driver.start() + started_drivers.append(driver) + except Exception: + for driver in reversed(started_drivers): + try: + await driver.stop() + except Exception: + logger.exception(f"回滚驱动停止失败: driver_id={driver.driver_id}") + raise + + self._started = True + + async def ensure_send_pipeline_ready(self) -> None: + """确保出站发送管线已准备就绪。 + + 该方法会先同步 legacy fallback driver,再在需要时启动 Broker。 + send service 应只调用这一层准备入口,而不是自行判断旧链或插件链。 + """ + await self._sync_legacy_send_drivers() + if not self._started: + await self.start() + + async def stop(self) -> None: + """停止 Broker,并按逆序停止全部已注册驱动。 + + 停止完成后,会同步清空仅对当前运行周期有效的去重缓存和出站跟踪状态, + 避免下一次启动时继续沿用上一个运行周期的瞬时内存数据。 + + Raises: + RuntimeError: 当一个或多个驱动停止失败时抛出汇总异常。 + """ + if not self._started: + return + + stop_errors: List[str] = [] + for driver in reversed(self._driver_registry.list()): + try: + await driver.stop() + except Exception as exc: + stop_errors.append(f"{driver.driver_id}: {exc}") + logger.exception(f"驱动停止失败: driver_id={driver.driver_id}") + + self._started = False + self._deduplicator.clear() + self._outbound_tracker.clear() + if stop_errors: + raise RuntimeError(f"部分驱动停止失败: {'; '.join(stop_errors)}") + + async def add_driver(self, driver: PlatformIODriver) -> None: + """向运行中的 Broker 注册并启动一个驱动。 + + 如果 Broker 尚未启动,则该方法等价于 ``register_driver()``。 + + Args: + driver: 要添加的驱动实例。 + + Raises: + Exception: 当驱动启动失败时,注册会自动回滚,异常继续上抛。 + """ + self._register_driver_internal(driver) + if not self._started: + return + + try: + await driver.start() + except Exception: + self._unregister_driver_internal(driver.driver_id) + raise + + async def remove_driver(self, driver_id: str) -> Optional[PlatformIODriver]: + """从运行中的 Broker 停止并移除一个驱动。 + + 如果 Broker 尚未启动,则该方法等价于 ``unregister_driver()``。 + + Args: + driver_id: 要移除的驱动 ID。 + + Returns: + Optional[PlatformIODriver]: 若驱动存在,则返回被移除的驱动实例。 + + Raises: + Exception: 当 Broker 运行中且驱动停止失败时,异常会继续上抛。 + """ + if not self._started: + return self.unregister_driver(driver_id) + + driver = self._driver_registry.get(driver_id) + if driver is None: + return None + + await driver.stop() + return self._unregister_driver_internal(driver_id) + + @property + def driver_registry(self) -> DriverRegistry: + """返回管理器持有的驱动注册表。 + + Returns: + DriverRegistry: 用于保存全部已注册驱动的注册表。 + """ + return self._driver_registry + + @property + def send_route_table(self) -> RouteTable: + """返回发送路由表。""" + + return self._send_route_table + + @property + def receive_route_table(self) -> RouteTable: + """返回接收路由表。""" + + return self._receive_route_table + + @property + def deduplicator(self) -> MessageDeduplicator: + """返回管理器持有的入站去重器。 + + Returns: + MessageDeduplicator: 用于抑制重复入站的去重器。 + """ + return self._deduplicator + + @property + def outbound_tracker(self) -> OutboundTracker: + """返回管理器持有的出站跟踪器。 + + Returns: + OutboundTracker: 用于记录出站 pending 状态与回执的跟踪器。 + """ + return self._outbound_tracker + + def set_inbound_dispatcher(self, dispatcher: InboundDispatcher) -> None: + """设置统一的入站分发回调。 + + Args: + dispatcher: 接收已通过 Broker 审核的入站封装,并继续送入 + Core 下一处理阶段的异步回调。 + """ + + self._inbound_dispatcher = dispatcher + + def clear_inbound_dispatcher(self) -> None: + """清除当前的入站分发回调。""" + self._inbound_dispatcher = None + + @property + def has_inbound_dispatcher(self) -> bool: + """返回当前是否已经配置入站分发回调。 + + Returns: + bool: 若已经配置入站分发回调则返回 ``True``。 + """ + return self._inbound_dispatcher is not None + + def register_driver(self, driver: PlatformIODriver) -> None: + """注册驱动,并把它的入站回调挂到 Broker。 + + Args: + driver: 要注册的驱动实例。 + + Raises: + RuntimeError: 当 Broker 已经处于运行态时抛出。此时应改用 + ``add_driver()`` 以保证驱动生命周期和注册状态一致。 + """ + if self._started: + raise RuntimeError("Broker 运行中不允许直接 register_driver,请改用 add_driver()") + + self._register_driver_internal(driver) + + def _register_driver_internal(self, driver: PlatformIODriver) -> None: + """执行不带运行态限制的内部驱动注册。 + + Args: + driver: 要注册的驱动实例。 + """ + driver.set_inbound_handler(self.accept_inbound) + self._driver_registry.register(driver) + + def unregister_driver(self, driver_id: str) -> Optional[PlatformIODriver]: + """从 Broker 注销一个驱动。 + + Args: + driver_id: 要移除的驱动 ID。 + + Returns: + Optional[PlatformIODriver]: 若驱动存在,则返回被移除的驱动实例。 + + Raises: + RuntimeError: 当 Broker 已经处于运行态时抛出。此时应改用 + ``remove_driver()``,避免驱动停止与路由解绑脱节。 + """ + if self._started: + raise RuntimeError("Broker 运行中不允许直接 unregister_driver,请改用 remove_driver()") + + return self._unregister_driver_internal(driver_id) + + def _unregister_driver_internal(self, driver_id: str) -> Optional[PlatformIODriver]: + """执行不带运行态限制的内部驱动注销。 + + Args: + driver_id: 要移除的驱动 ID。 + + Returns: + Optional[PlatformIODriver]: 若驱动存在,则返回被移除的驱动实例。 + """ + removed_driver = self._driver_registry.unregister(driver_id) + if removed_driver is None: + return None + + removed_driver.clear_inbound_handler() + self._send_route_table.remove_bindings_by_driver(driver_id) + self._receive_route_table.remove_bindings_by_driver(driver_id) + self._legacy_send_drivers = { + platform: driver + for platform, driver in self._legacy_send_drivers.items() + if driver.driver_id != driver_id + } + return removed_driver + + async def _sync_legacy_send_drivers(self) -> None: + """根据当前配置同步 legacy fallback driver。""" + from src.chat.utils.utils import get_all_bot_accounts + from src.platform_io.drivers.legacy_driver import LegacyPlatformDriver + + desired_accounts = get_all_bot_accounts() + desired_platforms = set(desired_accounts.keys()) + current_platforms = set(self._legacy_send_drivers.keys()) + + for platform in sorted(current_platforms - desired_platforms): + await self._remove_legacy_send_driver(platform) + + for platform, account_id in desired_accounts.items(): + existing_driver = self._legacy_send_drivers.get(platform) + if existing_driver is not None and existing_driver.descriptor.account_id == account_id: + continue + + if existing_driver is not None: + await self._remove_legacy_send_driver(platform) + + driver = LegacyPlatformDriver( + driver_id=f"legacy.send.{platform}", + platform=platform, + account_id=account_id, + ) + if self._started: + await self.add_driver(driver) + else: + self.register_driver(driver) + self._legacy_send_drivers[platform] = driver + + async def _remove_legacy_send_driver(self, platform: str) -> None: + """移除指定平台的 legacy fallback driver。 + + Args: + platform: 要移除的目标平台。 + """ + driver = self._legacy_send_drivers.get(platform) + if driver is None: + return + + if self._started: + await self.remove_driver(driver.driver_id) + else: + self.unregister_driver(driver.driver_id) + self._legacy_send_drivers.pop(platform, None) + + def bind_send_route(self, binding: RouteBinding) -> None: + """为某个路由键绑定发送驱动。 + + Args: + binding: 要保存的路由绑定。 + + Raises: + ValueError: 当绑定引用了不存在的驱动,或者绑定与驱动描述不一致时抛出。 + """ + driver = self._driver_registry.get(binding.driver_id) + if driver is None: + raise ValueError(f"驱动 {binding.driver_id} 未注册,无法绑定路由") + + self._validate_binding_against_driver(binding, driver) + self._send_route_table.bind(binding) + + def bind_receive_route(self, binding: RouteBinding) -> None: + """为某个路由键绑定接收驱动。 + + Args: + binding: 要保存的路由绑定。 + + Raises: + ValueError: 当绑定引用了不存在的驱动,或者绑定与驱动描述不一致时抛出。 + """ + driver = self._driver_registry.get(binding.driver_id) + if driver is None: + raise ValueError(f"驱动 {binding.driver_id} 未注册,无法绑定路由") + + self._validate_binding_against_driver(binding, driver) + self._receive_route_table.bind(binding) + + def unbind_send_route(self, route_key: RouteKey, driver_id: Optional[str] = None) -> None: + """移除发送路由绑定。 + + Args: + route_key: 要移除绑定的路由键。 + driver_id: 可选的特定驱动 ID。 + """ + + self._send_route_table.unbind(route_key, driver_id) + + def unbind_receive_route(self, route_key: RouteKey, driver_id: Optional[str] = None) -> None: + """移除接收路由绑定。 + + Args: + route_key: 要移除绑定的路由键。 + driver_id: 可选的特定驱动 ID。 + """ + + self._receive_route_table.unbind(route_key, driver_id) + + def resolve_drivers(self, route_key: RouteKey) -> List[PlatformIODriver]: + """解析某个路由键当前命中的全部发送驱动。 + + Args: + route_key: 要解析的路由键。 + + Returns: + List[PlatformIODriver]: 当前命中的全部发送驱动。 + """ + + drivers: List[PlatformIODriver] = [] + seen_driver_ids: set[str] = set() + for binding in self._send_route_table.resolve_bindings(route_key): + driver = self._driver_registry.get(binding.driver_id) + if driver is not None and driver.driver_id not in seen_driver_ids: + drivers.append(driver) + seen_driver_ids.add(driver.driver_id) + + fallback_driver = self._legacy_send_drivers.get(route_key.platform) + if fallback_driver is not None: + descriptor = fallback_driver.descriptor + account_matches = descriptor.account_id is None or route_key.account_id in (None, descriptor.account_id) + scope_matches = descriptor.scope is None or route_key.scope in (None, descriptor.scope) + if account_matches and scope_matches and fallback_driver.driver_id not in seen_driver_ids: + drivers.append(fallback_driver) + + return drivers + + @staticmethod + def build_route_key_from_message(message: "SessionMessage") -> RouteKey: + """根据 ``SessionMessage`` 构造路由键。 + + Args: + message: 内部会话消息对象。 + + Returns: + RouteKey: 由消息内容提取出的规范化路由键。 + """ + return RouteKeyFactory.from_session_message(message) + + @staticmethod + def build_route_key_from_message_dict(message_dict: Dict[str, Any]) -> RouteKey: + """根据消息字典构造路由键。 + + Args: + message_dict: Host 与插件之间传输的消息字典。 + + Returns: + RouteKey: 由消息字典提取出的规范化路由键。 + """ + return RouteKeyFactory.from_message_dict(message_dict) + + async def accept_inbound(self, envelope: InboundMessageEnvelope) -> bool: + """处理一条由驱动上报的入站封装。 + + Args: + envelope: 由传输驱动产出的入站封装。 + + Returns: + bool: 若消息被接受并继续转发给入站分发器,则返回 ``True``, + 否则返回 ``False``。 + """ + + if not self._receive_route_table.has_binding_for_driver(envelope.route_key, envelope.driver_id): + logger.info( + f"忽略未登记到接收路由表的入站消息: route={envelope.route_key} " + f"driver={envelope.driver_id}" + ) + return False + + if self._inbound_dispatcher is None: + logger.debug("PlatformIOManager 尚未配置 inbound dispatcher,暂不继续分发") + return False + + dedupe_key = self._build_inbound_dedupe_key(envelope) + if dedupe_key is not None: + if not self._deduplicator.mark_seen(dedupe_key): + logger.info(f"忽略重复入站消息: dedupe_key={dedupe_key}") + return False + + await self._inbound_dispatcher(envelope) + return True + + async def send_message( + self, + message: "SessionMessage", + route_key: RouteKey, + metadata: Optional[Dict[str, Any]] = None, + ) -> DeliveryBatch: + """通过 Broker 选中的全部发送驱动广播一条消息。 + + Args: + message: 要投递的内部会话消息。 + route_key: 本次出站投递选择的路由键。 + metadata: 可选的额外 Broker 侧元数据。 + + Returns: + DeliveryBatch: 规范化后的批量出站回执。 + """ + drivers = self.resolve_drivers(route_key) + if not drivers: + return DeliveryBatch(internal_message_id=message.message_id, route_key=route_key) + + receipts: List[DeliveryReceipt] = [] + for driver in drivers: + try: + self._outbound_tracker.begin_tracking( + internal_message_id=message.message_id, + route_key=route_key, + driver_id=driver.driver_id, + metadata=metadata, + ) + except ValueError as exc: + receipts.append( + DeliveryReceipt( + internal_message_id=message.message_id, + route_key=route_key, + status=DeliveryStatus.FAILED, + driver_id=driver.driver_id, + driver_kind=driver.descriptor.kind, + error=str(exc), + ) + ) + continue + + try: + receipt = await driver.send_message(message=message, route_key=route_key, metadata=metadata) + except Exception as exc: + receipt = DeliveryReceipt( + internal_message_id=message.message_id, + route_key=route_key, + status=DeliveryStatus.FAILED, + driver_id=driver.driver_id, + driver_kind=driver.descriptor.kind, + error=str(exc), + ) + + self._outbound_tracker.finish_tracking(receipt) + receipts.append(receipt) + + return DeliveryBatch( + internal_message_id=message.message_id, + route_key=route_key, + receipts=receipts, + ) + + @staticmethod + def _build_inbound_dedupe_key(envelope: InboundMessageEnvelope) -> Optional[str]: + """构造用于入站抑制的去重键。 + + Args: + envelope: 当前正在处理的入站封装。 + + Returns: + Optional[str]: 若可以构造稳定去重键则返回该键,否则返回 ``None``。 + + Notes: + 这里仅接受上游显式提供的稳定消息身份,例如 ``dedupe_key``、 + 平台侧 ``external_message_id`` 或已经完成规范化的 + ``session_message.message_id``。Broker 不再根据 ``payload`` 内容 + 猜测语义去重键,避免把“短时间内两条内容刚好完全相同”的合法消息 + 误判为重复入站。 + """ + raw_dedupe_key = envelope.dedupe_key or envelope.external_message_id + if raw_dedupe_key is None and envelope.session_message is not None: + raw_dedupe_key = envelope.session_message.message_id + if raw_dedupe_key is None: + return None + + normalized_dedupe_key = str(raw_dedupe_key).strip() + if not normalized_dedupe_key: + return None + + return f"{envelope.driver_id}:{normalized_dedupe_key}" + + @staticmethod + def _validate_binding_against_driver(binding: RouteBinding, driver: PlatformIODriver) -> None: + """校验路由绑定与驱动描述是否一致。 + + Args: + binding: 待校验的路由绑定。 + driver: 被绑定的驱动实例。 + + Raises: + ValueError: 当绑定类型、平台或更细粒度路由维度与驱动描述冲突时抛出。 + """ + descriptor = driver.descriptor + if binding.driver_kind != descriptor.kind: + raise ValueError( + f"路由绑定的 driver_kind={binding.driver_kind} 与驱动 {driver.driver_id} 的类型 " + f"{descriptor.kind} 不一致" + ) + + if binding.route_key.platform != descriptor.platform: + raise ValueError( + f"路由绑定的平台 {binding.route_key.platform} 与驱动 {driver.driver_id} 的平台 " + f"{descriptor.platform} 不一致" + ) + + if descriptor.account_id is not None and binding.route_key.account_id not in (None, descriptor.account_id): + raise ValueError( + f"路由绑定的 account_id={binding.route_key.account_id} 与驱动 {driver.driver_id} 的 " + f"account_id={descriptor.account_id} 冲突" + ) + + if descriptor.scope is not None and binding.route_key.scope not in (None, descriptor.scope): + raise ValueError( + f"路由绑定的 scope={binding.route_key.scope} 与驱动 {driver.driver_id} 的 " + f"scope={descriptor.scope} 冲突" + ) + + +_platform_io_manager: Optional[PlatformIOManager] = None + + +def get_platform_io_manager() -> PlatformIOManager: + """返回全局 ``PlatformIOManager`` 单例。 + + Returns: + PlatformIOManager: 进程级共享的 Broker 管理器实例。 + """ + + global _platform_io_manager + if _platform_io_manager is None: + _platform_io_manager = PlatformIOManager() + return _platform_io_manager diff --git a/src/platform_io/outbound_tracker.py b/src/platform_io/outbound_tracker.py new file mode 100644 index 00000000..3725691f --- /dev/null +++ b/src/platform_io/outbound_tracker.py @@ -0,0 +1,286 @@ +"""跟踪 Platform IO 层的出站投递状态。 + +当前实现基于两组 ``dict + heapq``: +- ``_pending`` 和 ``_pending_expire_heap`` 负责管理待完成的出站记录 +- ``_receipts_by_external_id`` 和 ``_receipt_expire_heap`` 负责管理已完成回执索引 + +这样就不需要在每次读写时全表扫描过期项,而是通过懒清理逐步弹出已经过期 +或已经失效的堆节点。 +""" + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple + +import heapq +import time + +from .types import DeliveryReceipt, RouteKey + + +@dataclass(slots=True) +class PendingOutboundRecord: + """表示一条仍在等待完成的出站投递记录。 + + Attributes: + internal_message_id: 正在跟踪的内部 ``SessionMessage.message_id``。 + route_key: 该出站投递开始时使用的路由键。 + driver_id: 负责这次出站投递的驱动 ID。 + created_at: 开始跟踪时记录的单调时钟时间戳。 + expires_at: 该待完成记录预计过期的单调时钟时间戳。 + metadata: 与待完成记录一同保留的额外 Broker 侧元数据。 + """ + + internal_message_id: str + route_key: RouteKey + driver_id: str + created_at: float = field(default_factory=time.monotonic) + expires_at: float = 0.0 + metadata: Dict[str, Any] = field(default_factory=dict) + + +@dataclass(slots=True) +class StoredDeliveryReceipt: + """表示一条已完成并暂存的出站回执。 + + Attributes: + receipt: 规范化后的出站投递回执。 + stored_at: 回执被写入索引时记录的单调时钟时间戳。 + expires_at: 该回执索引预计过期的单调时钟时间戳。 + """ + + receipt: DeliveryReceipt + stored_at: float = field(default_factory=time.monotonic) + expires_at: float = 0.0 + + +class OutboundTracker: + """统一跟踪出站消息的 pending 状态与最终回执。 + + 主要用于解决出站消息在发送过程中“状态散落在不同路径里”的问题: + - 发送开始后,需要在最终回执返回前保留一份 pending 状态 + - 平台返回 ``external_message_id`` 后,需要保留一段时间的回执索引 + + 当前实现使用 ``dict + heapq`` 做 TTL 管理: + - ``dict`` 提供 ``O(1)`` 级别的主键查询 + - ``heapq`` 提供按过期时间排序的懒清理能力 + + 这比“每次 begin/finish/get 都全表扫描”的实现更适合高吞吐出站场景。 + + Notes: + 复杂度说明如下,设 ``p`` 为当前有效 pending 数量,``r`` 为当前有效回执数量: + + - ``begin_tracking()``、``finish_tracking()`` 的常见路径时间复杂度接近 + ``O(log p)`` 或 ``O(log r)`` + - ``get_pending()``、``get_receipt_by_external_id()`` 的查询本身是 ``O(1)`` + ,连同懒清理一起看,长期摊还复杂度接近 ``O(log n)`` + - 如果某次调用恰好触发一批过期节点的集中清理,则该次调用的最坏时间复杂度 + 可达到 ``O(k log n)``,其中 ``k`` 为本次被弹出的节点数量 + - 空间复杂度为 ``O(p + r)`` + """ + + def __init__(self, ttl_seconds: float = 1800.0) -> None: + """初始化出站跟踪器。 + + Args: + ttl_seconds: 待完成记录与按外部消息 ID 建立的回执索引保留时长, + 单位为秒。 + + Raises: + ValueError: 当 ``ttl_seconds`` 非正数时抛出。 + """ + if ttl_seconds <= 0: + raise ValueError("ttl_seconds 必须大于 0") + + self._ttl_seconds = ttl_seconds + self._pending: Dict[Tuple[str, str], PendingOutboundRecord] = {} + self._pending_expire_heap: List[Tuple[float, str, str]] = [] + self._receipts_by_external_id: Dict[str, StoredDeliveryReceipt] = {} + self._receipt_expire_heap: List[Tuple[float, str]] = [] + + @staticmethod + def _build_pending_key(internal_message_id: str, driver_id: str) -> Tuple[str, str]: + """构造单条出站跟踪记录的唯一键。 + + Args: + internal_message_id: 内部消息 ID。 + driver_id: 负责当前投递的驱动 ID。 + + Returns: + Tuple[str, str]: ``(internal_message_id, driver_id)`` 组合键。 + """ + return internal_message_id, driver_id + + def begin_tracking( + self, + internal_message_id: str, + route_key: RouteKey, + driver_id: str, + metadata: Optional[Dict[str, Any]] = None, + ) -> PendingOutboundRecord: + """开始跟踪一次出站投递。 + + Args: + internal_message_id: 正在投递的内部消息 ID。 + route_key: 这次出站投递选择的路由键。 + driver_id: 负责本次投递的驱动 ID。 + metadata: 可选的额外元数据,会一并保存在待完成记录中。 + + Returns: + PendingOutboundRecord: 新创建的待完成记录。 + + Raises: + ValueError: 当同一个 ``internal_message_id`` 与 ``driver_id`` 组合已经存在 + 未完成记录时抛出。 + """ + now = time.monotonic() + self._cleanup_expired(now) + pending_key = self._build_pending_key(internal_message_id, driver_id) + + if pending_key in self._pending: + raise ValueError(f"消息 {internal_message_id} 在驱动 {driver_id} 上已存在未完成的出站跟踪记录") + + expires_at = now + self._ttl_seconds + record = PendingOutboundRecord( + internal_message_id=internal_message_id, + route_key=route_key, + driver_id=driver_id, + created_at=now, + expires_at=expires_at, + metadata=metadata or {}, + ) + self._pending[pending_key] = record + heapq.heappush(self._pending_expire_heap, (expires_at, internal_message_id, driver_id)) + return record + + def finish_tracking(self, receipt: DeliveryReceipt) -> Optional[PendingOutboundRecord]: + """使用最终回执结束一条出站跟踪。 + + Args: + receipt: 规范化后的最终投递回执。 + + Returns: + Optional[PendingOutboundRecord]: 若此前存在待完成记录,则返回该记录。 + """ + now = time.monotonic() + self._cleanup_expired(now) + + pending_record: Optional[PendingOutboundRecord] = None + if receipt.driver_id: + pending_key = self._build_pending_key(receipt.internal_message_id, receipt.driver_id) + pending_record = self._pending.pop(pending_key, None) + else: + matched_records = [ + key + for key, record in self._pending.items() + if record.internal_message_id == receipt.internal_message_id + ] + if len(matched_records) == 1: + pending_record = self._pending.pop(matched_records[0], None) + + if receipt.external_message_id: + expires_at = now + self._ttl_seconds + self._receipts_by_external_id[receipt.external_message_id] = StoredDeliveryReceipt( + receipt=receipt, + stored_at=now, + expires_at=expires_at, + ) + heapq.heappush(self._receipt_expire_heap, (expires_at, receipt.external_message_id)) + return pending_record + + def get_pending( + self, + internal_message_id: str, + driver_id: Optional[str] = None, + ) -> Optional[PendingOutboundRecord]: + """根据内部消息 ID 查询待完成记录。 + + Args: + internal_message_id: 要查询的内部消息 ID。 + driver_id: 可选的驱动 ID;提供后仅返回该驱动上的待完成记录。 + + Returns: + Optional[PendingOutboundRecord]: 若记录仍存在,则返回对应待完成记录。 + """ + self._cleanup_expired(time.monotonic()) + + if driver_id: + return self._pending.get(self._build_pending_key(internal_message_id, driver_id)) + + matched_records = [ + record + for record in self._pending.values() + if record.internal_message_id == internal_message_id + ] + if len(matched_records) == 1: + return matched_records[0] + return None + + def get_receipt_by_external_id(self, external_message_id: str) -> Optional[DeliveryReceipt]: + """根据外部平台消息 ID 查询已完成回执。 + + Args: + external_message_id: 要查询的平台侧消息 ID。 + + Returns: + Optional[DeliveryReceipt]: 若存在对应回执,则返回该回执。 + """ + self._cleanup_expired(time.monotonic()) + stored_receipt = self._receipts_by_external_id.get(external_message_id) + return stored_receipt.receipt if stored_receipt else None + + def clear(self) -> None: + """清空全部待完成记录与已保存回执。""" + self._pending.clear() + self._pending_expire_heap.clear() + self._receipts_by_external_id.clear() + self._receipt_expire_heap.clear() + + def _cleanup_expired(self, now: float) -> None: + """清理内存中已经过期的待完成记录与已保存回执。 + + Args: + now: 当前单调时钟时间戳。 + """ + self._cleanup_expired_pending(now) + self._cleanup_expired_receipts(now) + + def _cleanup_expired_pending(self, now: float) -> None: + """清理已经过期的待完成记录。 + + Args: + now: 当前单调时钟时间戳。 + + Notes: + 堆中可能存在已经失效的旧节点。例如某条记录提前 ``finish`` 后, + 它原本的过期节点仍可能留在堆里。这里会通过和 ``dict`` 中当前记录的 + ``expires_at`` 对比,跳过这类旧节点。 + """ + while self._pending_expire_heap and self._pending_expire_heap[0][0] <= now: + expires_at, internal_message_id, driver_id = heapq.heappop(self._pending_expire_heap) + pending_key = self._build_pending_key(internal_message_id, driver_id) + current_record = self._pending.get(pending_key) + if current_record is None: + continue + if current_record.expires_at != expires_at: + continue + self._pending.pop(pending_key, None) + + def _cleanup_expired_receipts(self, now: float) -> None: + """清理已经过期的回执索引。 + + Args: + now: 当前单调时钟时间戳。 + + Notes: + 同一个 ``external_message_id`` 在极端情况下可能被重复写入索引, + 因此这里同样需要通过 ``expires_at`` 和当前 ``dict`` 中的值比对, + 跳过已经失效的旧堆节点。 + """ + while self._receipt_expire_heap and self._receipt_expire_heap[0][0] <= now: + expires_at, external_message_id = heapq.heappop(self._receipt_expire_heap) + current_receipt = self._receipts_by_external_id.get(external_message_id) + if current_receipt is None: + continue + if current_receipt.expires_at != expires_at: + continue + self._receipts_by_external_id.pop(external_message_id, None) diff --git a/src/platform_io/registry.py b/src/platform_io/registry.py new file mode 100644 index 00000000..9ad8ea8a --- /dev/null +++ b/src/platform_io/registry.py @@ -0,0 +1,70 @@ +"""提供 Platform IO 的驱动注册与查询能力。""" + +from typing import Dict, List, Optional + +from src.platform_io.drivers.base import PlatformIODriver +from src.platform_io.types import DriverKind + + +class DriverRegistry: + """集中保存已注册的 Platform IO 驱动,并提供基础查询接口。""" + + def __init__(self) -> None: + """初始化一个空的驱动注册表。""" + self._drivers: Dict[str, PlatformIODriver] = {} + + def register(self, driver: PlatformIODriver) -> None: + """注册一个驱动实例。 + + Args: + driver: 要注册的驱动实例。 + + Raises: + ValueError: 当驱动 ID 已经存在时抛出。 + """ + if driver.driver_id in self._drivers: + raise ValueError(f"驱动 {driver.driver_id} 已注册") + self._drivers[driver.driver_id] = driver + + def unregister(self, driver_id: str) -> Optional[PlatformIODriver]: + """按驱动 ID 注销一个驱动。 + + Args: + driver_id: 要移除的驱动 ID。 + + Returns: + Optional[PlatformIODriver]: 若驱动存在,则返回被移除的驱动实例。 + """ + return self._drivers.pop(driver_id, None) + + def get(self, driver_id: str) -> Optional[PlatformIODriver]: + """按驱动 ID 获取驱动实例。 + + Args: + driver_id: 要查询的驱动 ID。 + + Returns: + Optional[PlatformIODriver]: 若存在匹配驱动,则返回该驱动实例。 + """ + return self._drivers.get(driver_id) + + def list(self, *, kind: Optional[DriverKind] = None, platform: Optional[str] = None) -> List[PlatformIODriver]: + """列出已注册驱动,并支持可选过滤。 + + Args: + kind: 可选的驱动类型过滤条件。 + platform: 可选的平台名称过滤条件。 + + Returns: + List[PlatformIODriver]: 符合过滤条件的驱动列表。 + """ + drivers = list(self._drivers.values()) + if kind is not None: + drivers = [driver for driver in drivers if driver.descriptor.kind == kind] + if platform is not None: + drivers = [driver for driver in drivers if driver.descriptor.platform == platform] + return drivers + + def clear(self) -> None: + """清空全部已注册驱动。""" + self._drivers.clear() diff --git a/src/platform_io/route_key_factory.py b/src/platform_io/route_key_factory.py new file mode 100644 index 00000000..05bac6e8 --- /dev/null +++ b/src/platform_io/route_key_factory.py @@ -0,0 +1,150 @@ +"""提供 Platform IO 路由键的统一提取与构造能力。 + +这层的目标不是直接接入具体消息链,而是先把“未来接线时用什么字段构造 +RouteKey”约定下来,避免 legacy 和 plugin 两条链路各自发明一套隐式规则。 +""" + +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple + +from .types import RouteKey + +if TYPE_CHECKING: + from src.chat.message_receive.message import SessionMessage + + +class RouteKeyFactory: + """统一构造 ``RouteKey`` 的工厂。 + + 当前约定会优先从消息字典顶层、``message_info``、``additional_config`` 或传入 metadata 中提取 + 以下字段: + + - account_id: ``platform_io_account_id`` / ``account_id`` / ``self_id`` / ``bot_account`` + - scope: ``platform_io_scope`` / ``route_scope`` / ``adapter_scope`` / ``connection_id`` + + 这样即使上游主链暂时还没有正式的 ``self_id`` 字段,中间层也能先统一 + 约定提取口径,等具体消息链接入时直接复用。 + """ + + ACCOUNT_ID_KEYS = ( + "platform_io_account_id", + "account_id", + "self_id", + "bot_account", + ) + SCOPE_KEYS = ( + "platform_io_scope", + "route_scope", + "adapter_scope", + "connection_id", + ) + + @classmethod + def from_platform( + cls, + platform: str, + *, + account_id: Optional[str] = None, + scope: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> RouteKey: + """根据平台名和可选 metadata 构造 ``RouteKey``。 + + Args: + platform: 平台名称。 + account_id: 显式传入的账号 ID;若为空,则尝试从 metadata 提取。 + scope: 显式传入的路由作用域;若为空,则尝试从 metadata 提取。 + metadata: 可选的元数据字典。 + + Returns: + RouteKey: 构造出的规范化路由键。 + """ + extracted_account_id, extracted_scope = cls.extract_components(metadata) + return RouteKey( + platform=platform, + account_id=account_id or extracted_account_id, + scope=scope or extracted_scope, + ) + + @classmethod + def from_message_dict(cls, message_dict: Dict[str, Any]) -> RouteKey: + """从消息字典中提取 ``RouteKey``。 + + Args: + message_dict: Host 与插件之间传输的消息字典。 + + Returns: + RouteKey: 构造出的规范化路由键。 + + Raises: + ValueError: 当消息字典缺少有效 ``platform`` 字段时抛出。 + """ + platform = str(message_dict.get("platform") or "").strip() + if not platform: + raise ValueError("消息字典缺少有效的 platform 字段,无法构造 RouteKey") + + message_info = message_dict.get("message_info", {}) + additional_config = {} + if isinstance(message_info, dict): + raw_additional_config = message_info.get("additional_config", {}) + if isinstance(raw_additional_config, dict): + additional_config = raw_additional_config + + explicit_account_id, explicit_scope = cls.extract_components(message_dict) + message_info_account_id, message_info_scope = cls.extract_components(message_info) + metadata_account_id, metadata_scope = cls.extract_components(additional_config) + return RouteKey( + platform=platform, + account_id=explicit_account_id or message_info_account_id or metadata_account_id, + scope=explicit_scope or message_info_scope or metadata_scope, + ) + + @classmethod + def from_session_message(cls, message: "SessionMessage") -> RouteKey: + """从 ``SessionMessage`` 中提取 ``RouteKey``。 + + Args: + message: 内部会话消息对象。 + + Returns: + RouteKey: 构造出的规范化路由键。 + """ + additional_config = message.message_info.additional_config or {} + metadata = additional_config if isinstance(additional_config, dict) else {} + return cls.from_platform(message.platform, metadata=metadata) + + @classmethod + def extract_components(cls, mapping: Optional[Dict[str, Any]]) -> Tuple[Optional[str], Optional[str]]: + """从任意字典中提取 ``account_id`` 与 ``scope``。 + + Args: + mapping: 待提取的字典;若为空或不是字典,则返回空结果。 + + Returns: + Tuple[Optional[str], Optional[str]]: ``(account_id, scope)``。 + """ + if not mapping or not isinstance(mapping, dict): + return None, None + + account_id = cls._pick_string(mapping, cls.ACCOUNT_ID_KEYS) + scope = cls._pick_string(mapping, cls.SCOPE_KEYS) + return account_id, scope + + @staticmethod + def _pick_string(mapping: Dict[str, Any], keys: Tuple[str, ...]) -> Optional[str]: + """按优先级从字典里挑选第一个有效字符串。 + + Args: + mapping: 待查询的字典。 + keys: 按优先级排列的候选键名。 + + Returns: + Optional[str]: 第一个规范化后非空的字符串值;若不存在则返回 ``None``。 + """ + for key in keys: + value = mapping.get(key) + if value is None: + continue + normalized = str(value).strip() + if normalized: + return normalized + return None diff --git a/src/platform_io/routing.py b/src/platform_io/routing.py new file mode 100644 index 00000000..2a9b41ef --- /dev/null +++ b/src/platform_io/routing.py @@ -0,0 +1,141 @@ +"""提供 Platform IO 的轻量路由绑定表。""" + +from typing import Dict, List, Optional + +from .types import RouteBinding, RouteKey + + +class RouteTable: + """维护单张路由绑定表。 + + 该实现不负责裁决“唯一 owner”,只负责保存绑定,并按 + ``RouteKey.resolution_order()`` 解析出候选绑定列表。 + """ + + def __init__(self) -> None: + """初始化空路由绑定表。""" + + self._bindings: Dict[RouteKey, Dict[str, RouteBinding]] = {} + + def bind(self, binding: RouteBinding) -> None: + """注册或更新一条路由绑定。 + + Args: + binding: 要保存的路由绑定。 + """ + + self._bindings.setdefault(binding.route_key, {})[binding.driver_id] = binding + + def unbind(self, route_key: RouteKey, driver_id: Optional[str] = None) -> List[RouteBinding]: + """移除指定路由键上的绑定。 + + Args: + route_key: 要移除绑定的路由键。 + driver_id: 可选的驱动 ID;为空时移除该路由键下全部绑定。 + + Returns: + List[RouteBinding]: 被移除的绑定列表。 + """ + + binding_map = self._bindings.get(route_key) + if not binding_map: + return [] + + if driver_id is None: + removed = list(binding_map.values()) + self._bindings.pop(route_key, None) + return self._sort_bindings(removed) + + removed_binding = binding_map.pop(driver_id, None) + if not binding_map: + self._bindings.pop(route_key, None) + return [removed_binding] if removed_binding is not None else [] + + def remove_bindings_by_driver(self, driver_id: str) -> List[RouteBinding]: + """移除某个驱动在整张表上的全部绑定。 + + Args: + driver_id: 要移除绑定的驱动 ID。 + + Returns: + List[RouteBinding]: 被移除的绑定列表。 + """ + + removed_bindings: List[RouteBinding] = [] + empty_route_keys: List[RouteKey] = [] + for route_key, binding_map in self._bindings.items(): + removed_binding = binding_map.pop(driver_id, None) + if removed_binding is not None: + removed_bindings.append(removed_binding) + if not binding_map: + empty_route_keys.append(route_key) + + for route_key in empty_route_keys: + self._bindings.pop(route_key, None) + + return self._sort_bindings(removed_bindings) + + def list_bindings(self, route_key: Optional[RouteKey] = None) -> List[RouteBinding]: + """列出当前路由表中的绑定。 + + Args: + route_key: 可选的路由键过滤条件。 + + Returns: + List[RouteBinding]: 当前绑定列表。 + """ + + if route_key is None: + bindings: List[RouteBinding] = [] + for binding_map in self._bindings.values(): + bindings.extend(binding_map.values()) + return self._sort_bindings(bindings) + + binding_map = self._bindings.get(route_key, {}) + return self._sort_bindings(list(binding_map.values())) + + def resolve_bindings(self, route_key: RouteKey) -> List[RouteBinding]: + """按从具体到宽泛的顺序解析路由候选绑定。 + + Args: + route_key: 待解析的路由键。 + + Returns: + List[RouteBinding]: 去重后的候选绑定列表。 + """ + + resolved_bindings: List[RouteBinding] = [] + seen_driver_ids: set[str] = set() + for candidate_key in route_key.resolution_order(): + for binding in self.list_bindings(candidate_key): + if binding.driver_id in seen_driver_ids: + continue + seen_driver_ids.add(binding.driver_id) + resolved_bindings.append(binding) + return resolved_bindings + + def has_binding_for_driver(self, route_key: RouteKey, driver_id: str) -> bool: + """判断指定驱动是否在当前路由键解析结果中。 + + Args: + route_key: 待解析的路由键。 + driver_id: 目标驱动 ID。 + + Returns: + bool: 若驱动存在于解析结果中则返回 ``True``。 + """ + + return any(binding.driver_id == driver_id for binding in self.resolve_bindings(route_key)) + + @staticmethod + def _sort_bindings(bindings: List[RouteBinding]) -> List[RouteBinding]: + """按优先级降序排列绑定列表。 + + Args: + bindings: 待排序的绑定列表。 + + Returns: + List[RouteBinding]: 排序后的绑定列表。 + """ + + return sorted(bindings, key=lambda item: item.priority, reverse=True) diff --git a/src/platform_io/types.py b/src/platform_io/types.py new file mode 100644 index 00000000..200eca51 --- /dev/null +++ b/src/platform_io/types.py @@ -0,0 +1,264 @@ +"""定义 Platform IO 中间层共享的核心类型。 + +本模块放置路由、驱动、入站与出站等规范化数据结构,供 Broker +层在 legacy 适配器链路和 plugin 适配器链路之间复用。 +""" + +from dataclasses import dataclass, field +from enum import Enum +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +if TYPE_CHECKING: + from src.chat.message_receive.message import SessionMessage + + +class DriverKind(str, Enum): + """底层收发驱动类型枚举。""" + + LEGACY = "legacy" + PLUGIN = "plugin" + + +class DeliveryStatus(str, Enum): + """统一出站回执状态枚举。""" + + PENDING = "pending" + SENT = "sent" + FAILED = "failed" + DROPPED = "dropped" + + +@dataclass(frozen=True, slots=True) +class RouteKey: + """用于 Platform IO 路由决策的唯一键。 + + 路由解析会按照“从最具体到最宽泛”的顺序进行回退,这样同一平台 + 后续就能自然支持按账号、自定义 scope 等更细粒度的归属控制。 + + Attributes: + platform: 平台名称,例如 ``qq``。 + account_id: 机器人账号 ID 或 self ID,用于区分同平台多身份。 + scope: 额外路由作用域,预留给未来的连接实例、租户或子通道等维度。 + """ + + platform: str + account_id: Optional[str] = None + scope: Optional[str] = None + + def __post_init__(self) -> None: + """规范化并校验路由键字段。 + + Raises: + ValueError: 当 ``platform`` 规范化后为空时抛出。 + """ + platform = str(self.platform).strip() + account_id = str(self.account_id).strip() if self.account_id is not None else None + scope = str(self.scope).strip() if self.scope is not None else None + + if not platform: + raise ValueError("RouteKey.platform 不能为空") + + object.__setattr__(self, "platform", platform) + object.__setattr__(self, "account_id", account_id or None) + object.__setattr__(self, "scope", scope or None) + + def resolution_order(self) -> List["RouteKey"]: + """返回从最具体到最宽泛的路由匹配顺序。 + + Returns: + List[RouteKey]: 按回退优先级排序的候选路由键列表。 + """ + + keys: List[RouteKey] = [self] + + if self.account_id is not None and self.scope is not None: + keys.append(RouteKey(platform=self.platform, account_id=self.account_id, scope=None)) + keys.append(RouteKey(platform=self.platform, account_id=None, scope=self.scope)) + elif self.account_id is not None: + keys.append(RouteKey(platform=self.platform, account_id=None, scope=None)) + elif self.scope is not None: + keys.append(RouteKey(platform=self.platform, account_id=None, scope=None)) + + default_key = RouteKey(platform=self.platform, account_id=None, scope=None) + if default_key not in keys: + keys.append(default_key) + + return keys + + def to_dedupe_scope(self) -> str: + """生成跨驱动共享的去重作用域字符串。 + + Returns: + str: 用于入站消息去重的稳定文本作用域键。 + """ + + account_id = self.account_id or "*" + scope = self.scope or "*" + return f"{self.platform}:{account_id}:{scope}" + + +@dataclass(frozen=True, slots=True) +class DriverDescriptor: + """描述一个已注册的 Platform IO 驱动。 + + Attributes: + driver_id: Broker 层内全局唯一的驱动标识。 + kind: 驱动实现类型,例如 legacy 或 plugin。 + platform: 驱动负责的平台名称。 + account_id: 可选的账号 ID 或 self ID。 + scope: 可选的额外路由作用域。 + plugin_id: 当驱动来自插件适配器时,对应的插件 ID。 + metadata: 预留给路由策略或观测能力的额外驱动元数据。 + """ + + driver_id: str + kind: DriverKind + platform: str + account_id: Optional[str] = None + scope: Optional[str] = None + plugin_id: Optional[str] = None + metadata: Dict[str, Any] = field(default_factory=dict) + + def __post_init__(self) -> None: + """规范化并校验驱动描述字段。 + + Raises: + ValueError: 当 ``driver_id`` 或 ``platform`` 规范化后为空时抛出。 + """ + driver_id = str(self.driver_id).strip() + platform = str(self.platform).strip() + plugin_id = str(self.plugin_id).strip() if self.plugin_id is not None else None + + if not driver_id: + raise ValueError("DriverDescriptor.driver_id 不能为空") + if not platform: + raise ValueError("DriverDescriptor.platform 不能为空") + + object.__setattr__(self, "driver_id", driver_id) + object.__setattr__(self, "platform", platform) + object.__setattr__(self, "plugin_id", plugin_id or None) + + @property + def route_key(self) -> RouteKey: + """构造该驱动默认代表的路由键。 + + Returns: + RouteKey: 当前驱动描述对应的规范化路由键。 + """ + return RouteKey(platform=self.platform, account_id=self.account_id, scope=self.scope) + + +@dataclass(frozen=True, slots=True) +class RouteBinding: + """表示一条从路由键到驱动的绑定关系。 + + Attributes: + route_key: 该绑定覆盖的路由键。 + driver_id: 拥有该路由的驱动 ID。 + driver_kind: 绑定驱动的类型。 + priority: 当同一路由键存在多条绑定时使用的相对优先级。 + metadata: 预留给未来路由策略的额外绑定元数据。 + """ + + route_key: RouteKey + driver_id: str + driver_kind: DriverKind + priority: int = 0 + metadata: Dict[str, Any] = field(default_factory=dict) + + def __post_init__(self) -> None: + """规范化并校验绑定字段。 + + Raises: + ValueError: 当 ``driver_id`` 规范化后为空时抛出。 + """ + driver_id = str(self.driver_id).strip() + if not driver_id: + raise ValueError("RouteBinding.driver_id 不能为空") + object.__setattr__(self, "driver_id", driver_id) + + +@dataclass(slots=True) +class InboundMessageEnvelope: + """封装一次由驱动产出的规范化入站消息。 + + Attributes: + route_key: 该入站消息解析出的路由键。 + driver_id: 产出该消息的驱动 ID。 + driver_kind: 产出该消息的驱动类型。 + external_message_id: 可选的平台侧消息 ID,用于去重。 + dedupe_key: 可选的显式去重键。当外部消息没有稳定 ``message_id`` 时, + 可由上游驱动提供稳定的技术性幂等键。若这里为空,中间层仅会继续 + 回退到 ``external_message_id`` 或 ``session_message.message_id``, + 不会再根据 ``payload`` 内容猜测语义去重键。 + session_message: 可选的、已经完成规范化的 ``SessionMessage`` 对象。 + payload: 可选的原始字典载荷,供延迟转换或调试使用。 + metadata: 额外入站元数据,例如连接信息或追踪上下文。 + """ + + route_key: RouteKey + driver_id: str + driver_kind: DriverKind + external_message_id: Optional[str] = None + dedupe_key: Optional[str] = None + session_message: Optional["SessionMessage"] = None + payload: Optional[Dict[str, Any]] = None + metadata: Dict[str, Any] = field(default_factory=dict) + + +@dataclass(slots=True) +class DeliveryReceipt: + """表示一次出站投递尝试的统一结果。 + + Attributes: + internal_message_id: Broker 跟踪的内部 ``SessionMessage.message_id``。 + route_key: 本次投递使用的路由键。 + status: 规范化后的投递状态。 + driver_id: 实际处理该投递的驱动 ID,可为空。 + driver_kind: 实际处理该投递的驱动类型,可为空。 + external_message_id: 驱动或适配器返回的平台侧消息 ID,可为空。 + error: 投递失败时的错误信息,可为空。 + metadata: 预留给回执、时间戳或平台特有信息的额外元数据。 + """ + + internal_message_id: str + route_key: RouteKey + status: DeliveryStatus + driver_id: Optional[str] = None + driver_kind: Optional[DriverKind] = None + external_message_id: Optional[str] = None + error: Optional[str] = None + metadata: Dict[str, Any] = field(default_factory=dict) + + +@dataclass(slots=True) +class DeliveryBatch: + """表示一次广播式出站投递的批量结果。 + + Attributes: + internal_message_id: 内部消息 ID。 + route_key: 本次投递使用的路由键。 + receipts: 各条路由的独立投递回执列表。 + """ + + internal_message_id: str + route_key: RouteKey + receipts: List[DeliveryReceipt] = field(default_factory=list) + + @property + def sent_receipts(self) -> List[DeliveryReceipt]: + """返回全部发送成功的回执。""" + + return [receipt for receipt in self.receipts if receipt.status == DeliveryStatus.SENT] + + @property + def failed_receipts(self) -> List[DeliveryReceipt]: + """返回全部发送失败的回执。""" + + return [receipt for receipt in self.receipts if receipt.status != DeliveryStatus.SENT] + + @property + def has_success(self) -> bool: + """返回当前批量投递是否至少命中一条成功回执。""" + + return bool(self.sent_receipts) diff --git a/src/plugin_runtime/__init__.py b/src/plugin_runtime/__init__.py index a881d399..7f2d789f 100644 --- a/src/plugin_runtime/__init__.py +++ b/src/plugin_runtime/__init__.py @@ -16,3 +16,9 @@ ENV_PLUGIN_DIRS = "MAIBOT_PLUGIN_DIRS" ENV_HOST_VERSION = "MAIBOT_HOST_VERSION" """Runner 读取的 Host 应用版本号,用于 manifest 兼容性校验""" + +ENV_EXTERNAL_PLUGIN_IDS = "MAIBOT_EXTERNAL_PLUGIN_IDS" +"""Runner 启动时可视为已满足的外部插件依赖版本映射(JSON 对象)""" + +ENV_GLOBAL_CONFIG_SNAPSHOT = "MAIBOT_GLOBAL_CONFIG_SNAPSHOT" +"""Runner 启动时注入的全局配置快照(JSON 对象)""" diff --git a/src/plugin_runtime/capabilities/components.py b/src/plugin_runtime/capabilities/components.py index aa7ceb46..33b54c64 100644 --- a/src/plugin_runtime/capabilities/components.py +++ b/src/plugin_runtime/capabilities/components.py @@ -1,12 +1,13 @@ from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Protocol +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Protocol, Sequence from src.common.logger import get_logger logger = get_logger("plugin_runtime.integration") if TYPE_CHECKING: - from src.plugin_runtime.host.component_registry import RegisteredComponent + from src.plugin_runtime.host.api_registry import APIEntry + from src.plugin_runtime.host.component_registry import ComponentEntry from src.plugin_runtime.host.supervisor import PluginSupervisor @@ -14,18 +15,311 @@ class _RuntimeComponentManagerProtocol(Protocol): @property def supervisors(self) -> List["PluginSupervisor"]: ... + def _normalize_component_type(self, component_type: str) -> str: ... + + def _is_api_component_type(self, component_type: str) -> bool: ... + + def _serialize_api_entry(self, entry: "APIEntry") -> Dict[str, Any]: ... + + def _serialize_api_component_entry(self, entry: "APIEntry") -> Dict[str, Any]: ... + + def _is_api_visible_to_plugin(self, entry: "APIEntry", caller_plugin_id: str) -> bool: ... + + def _normalize_api_reference(self, api_name: str, version: str = "") -> tuple[str, str]: ... + + def _build_api_unavailable_error(self, entry: "APIEntry") -> str: ... + def _get_supervisor_for_plugin(self, plugin_id: str) -> Optional["PluginSupervisor"]: ... + def _resolve_api_target( + self, + caller_plugin_id: str, + api_name: str, + version: str = "", + ) -> tuple[Optional["PluginSupervisor"], Optional["APIEntry"], Optional[str]]: ... + + def _resolve_api_toggle_target( + self, + name: str, + version: str = "", + ) -> tuple[Optional["PluginSupervisor"], Optional["APIEntry"], Optional[str]]: ... + def _resolve_component_toggle_target( self, name: str, component_type: str - ) -> tuple[Optional["RegisteredComponent"], Optional[str]]: ... + ) -> tuple[Optional["ComponentEntry"], Optional[str]]: ... def _find_duplicate_plugin_ids(self, plugin_dirs: List[Path]) -> Dict[str, List[Path]]: ... def _iter_plugin_dirs(self) -> Iterable[Path]: ... + async def load_plugin_globally(self, plugin_id: str, reason: str = "manual") -> bool: ... + + async def reload_plugins_globally(self, plugin_ids: Sequence[str], reason: str = "manual") -> bool: ... + class RuntimeComponentCapabilityMixin: + @staticmethod + def _normalize_component_type(component_type: str) -> str: + """规范化组件类型名称。 + + Args: + component_type: 原始组件类型。 + + Returns: + str: 统一转为大写后的组件类型名。 + """ + + return str(component_type or "").strip().upper() + + @classmethod + def _is_api_component_type(cls, component_type: str) -> bool: + """判断组件类型是否为 API。 + + Args: + component_type: 原始组件类型。 + + Returns: + bool: 是否为 API 组件类型。 + """ + + return cls._normalize_component_type(component_type) == "API" + + @staticmethod + def _serialize_api_entry(entry: "APIEntry") -> Dict[str, Any]: + """将 API 组件条目序列化为能力返回值。 + + Args: + entry: API 组件条目。 + + Returns: + Dict[str, Any]: 适合通过能力层返回给插件的 API 元信息。 + """ + + return { + "name": entry.name, + "full_name": entry.full_name, + "plugin_id": entry.plugin_id, + "description": entry.description, + "version": entry.version, + "public": entry.public, + "enabled": entry.enabled, + "dynamic": entry.dynamic, + "offline_reason": entry.offline_reason, + "metadata": dict(entry.metadata), + } + + @classmethod + def _serialize_api_component_entry(cls, entry: "APIEntry") -> Dict[str, Any]: + """将 API 条目序列化为通用组件视图。 + + Args: + entry: API 组件条目。 + + Returns: + Dict[str, Any]: 适合 ``component.get_all_plugins`` 返回的组件结构。 + """ + + serialized_entry = cls._serialize_api_entry(entry) + return { + "name": serialized_entry["name"], + "full_name": serialized_entry["full_name"], + "type": "API", + "enabled": serialized_entry["enabled"], + "metadata": serialized_entry["metadata"], + } + + @staticmethod + def _is_api_visible_to_plugin(entry: "APIEntry", caller_plugin_id: str) -> bool: + """判断某个 API 是否对调用方可见。 + + Args: + entry: 目标 API 组件条目。 + caller_plugin_id: 调用方插件 ID。 + + Returns: + bool: 是否允许当前插件可见并调用。 + """ + + return entry.plugin_id == caller_plugin_id or entry.public + + @staticmethod + def _normalize_api_reference(api_name: str, version: str = "") -> tuple[str, str]: + """规范化 API 名称与版本参数。 + + 支持在 ``api_name`` 中直接携带 ``@version`` 后缀。 + """ + + normalized_api_name = str(api_name or "").strip() + normalized_version = str(version or "").strip() + if normalized_api_name and not normalized_version and "@" in normalized_api_name: + candidate_name, candidate_version = normalized_api_name.rsplit("@", 1) + candidate_name = candidate_name.strip() + candidate_version = candidate_version.strip() + if candidate_name and candidate_version: + normalized_api_name = candidate_name + normalized_version = candidate_version + return normalized_api_name, normalized_version + + @staticmethod + def _build_api_unavailable_error(entry: "APIEntry") -> str: + """构造 API 当前不可用时的错误信息。""" + + if entry.offline_reason: + return entry.offline_reason + return f"API {entry.registry_key} 当前不可用" + + def _resolve_api_target( + self: _RuntimeComponentManagerProtocol, + caller_plugin_id: str, + api_name: str, + version: str = "", + ) -> tuple[Optional["PluginSupervisor"], Optional["APIEntry"], Optional[str]]: + """解析 API 名称到唯一可调用的目标组件。 + + Args: + caller_plugin_id: 调用方插件 ID。 + api_name: API 名称,支持 ``plugin_id.api_name`` 或唯一短名。 + version: 可选的 API 版本。 + + Returns: + tuple[Optional[PluginSupervisor], Optional[APIEntry], Optional[str]]: + 解析成功时返回 ``(监督器, API 条目, None)``,失败时返回错误信息。 + """ + + normalized_api_name, normalized_version = self._normalize_api_reference(api_name, version) + if not normalized_api_name: + return None, None, "缺少必要参数 api_name" + + if "." in normalized_api_name: + target_plugin_id, target_api_name = normalized_api_name.rsplit(".", 1) + try: + supervisor = self._get_supervisor_for_plugin(target_plugin_id) + except RuntimeError as exc: + return None, None, str(exc) + + if supervisor is None: + return None, None, f"未找到 API 提供方插件: {target_plugin_id}" + + entries = supervisor.api_registry.get_apis( + plugin_id=target_plugin_id, + name=target_api_name, + version=normalized_version, + enabled_only=False, + ) + visible_enabled_entries = [ + entry + for entry in entries + if self._is_api_visible_to_plugin(entry, caller_plugin_id) and entry.enabled + ] + visible_disabled_entries = [ + entry + for entry in entries + if self._is_api_visible_to_plugin(entry, caller_plugin_id) and not entry.enabled + ] + if len(visible_enabled_entries) == 1: + return supervisor, visible_enabled_entries[0], None + if len(visible_enabled_entries) > 1: + return None, None, f"API {normalized_api_name} 存在多个版本,请显式指定 version" + if visible_disabled_entries: + if len(visible_disabled_entries) == 1: + return None, None, self._build_api_unavailable_error(visible_disabled_entries[0]) + return None, None, f"API {normalized_api_name} 存在多个已下线版本,请显式指定 version" + if any(not self._is_api_visible_to_plugin(entry, caller_plugin_id) for entry in entries): + return None, None, f"API {normalized_api_name} 未公开,禁止跨插件调用" + if normalized_version: + return None, None, f"未找到版本为 {normalized_version} 的 API: {normalized_api_name}" + return None, None, f"未找到 API: {normalized_api_name}" + + visible_enabled_matches: List[tuple["PluginSupervisor", "APIEntry"]] = [] + visible_disabled_matches: List[tuple["PluginSupervisor", "APIEntry"]] = [] + hidden_match_exists = False + for supervisor in self.supervisors: + for entry in supervisor.api_registry.get_apis( + name=normalized_api_name, + version=normalized_version, + enabled_only=False, + ): + if self._is_api_visible_to_plugin(entry, caller_plugin_id): + if entry.enabled: + visible_enabled_matches.append((supervisor, entry)) + else: + visible_disabled_matches.append((supervisor, entry)) + else: + hidden_match_exists = True + + if len(visible_enabled_matches) == 1: + return visible_enabled_matches[0][0], visible_enabled_matches[0][1], None + if len(visible_enabled_matches) > 1: + return None, None, f"API 名称不唯一: {normalized_api_name},请使用 plugin_id.api_name 或显式指定 version" + if visible_disabled_matches: + if len(visible_disabled_matches) == 1: + return None, None, self._build_api_unavailable_error(visible_disabled_matches[0][1]) + return None, None, f"API {normalized_api_name} 存在多个已下线版本,请使用 plugin_id.api_name@version" + if hidden_match_exists: + return None, None, f"API {normalized_api_name} 未公开,禁止跨插件调用" + if normalized_version: + return None, None, f"未找到版本为 {normalized_version} 的 API: {normalized_api_name}" + return None, None, f"未找到 API: {normalized_api_name}" + + def _resolve_api_toggle_target( + self: _RuntimeComponentManagerProtocol, + name: str, + version: str = "", + ) -> tuple[Optional["PluginSupervisor"], Optional["APIEntry"], Optional[str]]: + """解析需要启用或禁用的 API 组件。 + + Args: + name: API 名称,支持 ``plugin_id.api_name`` 或唯一短名。 + version: 可选的 API 版本。 + + Returns: + tuple[Optional[PluginSupervisor], Optional[APIEntry], Optional[str]]: + 解析成功时返回 ``(监督器, API 条目, None)``,失败时返回错误信息。 + """ + + normalized_name, normalized_version = self._normalize_api_reference(name, version) + if not normalized_name: + return None, None, "缺少必要参数 name" + + if "." in normalized_name: + plugin_id, api_name = normalized_name.rsplit(".", 1) + try: + supervisor = self._get_supervisor_for_plugin(plugin_id) + except RuntimeError as exc: + return None, None, str(exc) + + if supervisor is None: + return None, None, f"未找到 API 提供方插件: {plugin_id}" + + entries = supervisor.api_registry.get_apis( + plugin_id=plugin_id, + name=api_name, + version=normalized_version, + enabled_only=False, + ) + if len(entries) == 1: + return supervisor, entries[0], None + if entries: + return None, None, f"API {normalized_name} 存在多个版本,请显式指定 version" + return None, None, f"未找到 API: {normalized_name}" + + matches: List[tuple["PluginSupervisor", "APIEntry"]] = [] + for supervisor in self.supervisors: + matches.extend( + (supervisor, entry) + for entry in supervisor.api_registry.get_apis( + name=normalized_name, + version=normalized_version, + enabled_only=False, + ) + ) + + if len(matches) == 1: + return matches[0][0], matches[0][1], None + if len(matches) > 1: + return None, None, f"API 名称不唯一: {normalized_name},请使用 plugin_id.api_name 或显式指定 version" + return None, None, f"未找到 API: {normalized_name}" + async def _cap_component_get_all_plugins( self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any] ) -> Any: @@ -46,6 +340,10 @@ class RuntimeComponentCapabilityMixin: } for component in comps ] + components_list.extend( + self._serialize_api_component_entry(entry) + for entry in sv.api_registry.get_apis(plugin_id=pid, enabled_only=False) + ) result[pid] = { "name": pid, "version": reg.plugin_version, @@ -96,30 +394,35 @@ class RuntimeComponentCapabilityMixin: def _resolve_component_toggle_target( self: _RuntimeComponentManagerProtocol, name: str, component_type: str - ) -> tuple[Optional["RegisteredComponent"], Optional[str]]: - short_name_matches: List["RegisteredComponent"] = [] + ) -> tuple[Optional["ComponentEntry"], Optional[str]]: + normalized_component_type = self._normalize_component_type(component_type) + short_name_matches: List["ComponentEntry"] = [] for sv in self.supervisors: comp = sv.component_registry.get_component(name) - if comp is not None and comp.component_type == component_type: + if comp is not None and comp.component_type == normalized_component_type: return comp, None short_name_matches.extend( candidate - for candidate in sv.component_registry.get_components_by_type(component_type, enabled_only=False) + for candidate in sv.component_registry.get_components_by_type( + normalized_component_type, + enabled_only=False, + ) if candidate.name == name ) if len(short_name_matches) == 1: return short_name_matches[0], None if len(short_name_matches) > 1: - return None, f"组件名不唯一: {name} ({component_type}),请使用完整名 plugin_id.component_name" - return None, f"未找到组件: {name} ({component_type})" + return None, f"组件名不唯一: {name} ({normalized_component_type}),请使用完整名 plugin_id.component_name" + return None, f"未找到组件: {name} ({normalized_component_type})" async def _cap_component_enable( self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any] ) -> Any: name: str = args.get("name", "") component_type: str = args.get("component_type", "") + version: str = args.get("version", "") scope: str = args.get("scope", "global") stream_id: str = args.get("stream_id", "") if not name or not component_type: @@ -127,6 +430,13 @@ class RuntimeComponentCapabilityMixin: if scope != "global" or stream_id: return {"success": False, "error": "当前仅支持全局组件启用,不支持 scope/stream_id 定位"} + if self._is_api_component_type(component_type): + supervisor, api_entry, error = self._resolve_api_toggle_target(name, version) + if supervisor is None or api_entry is None: + return {"success": False, "error": error or f"未找到 API: {name}"} + supervisor.api_registry.toggle_api_status(api_entry.registry_key, True) + return {"success": True} + comp, error = self._resolve_component_toggle_target(name, component_type) if comp is None: return {"success": False, "error": error or f"未找到组件: {name} ({component_type})"} @@ -139,6 +449,7 @@ class RuntimeComponentCapabilityMixin: ) -> Any: name: str = args.get("name", "") component_type: str = args.get("component_type", "") + version: str = args.get("version", "") scope: str = args.get("scope", "global") stream_id: str = args.get("stream_id", "") if not name or not component_type: @@ -146,6 +457,13 @@ class RuntimeComponentCapabilityMixin: if scope != "global" or stream_id: return {"success": False, "error": "当前仅支持全局组件禁用,不支持 scope/stream_id 定位"} + if self._is_api_component_type(component_type): + supervisor, api_entry, error = self._resolve_api_toggle_target(name, version) + if supervisor is None or api_entry is None: + return {"success": False, "error": error or f"未找到 API: {name}"} + supervisor.api_registry.toggle_api_status(api_entry.registry_key, False) + return {"success": True} + comp, error = self._resolve_component_toggle_target(name, component_type) if comp is None: return {"success": False, "error": error or f"未找到组件: {name} ({component_type})"} @@ -168,33 +486,14 @@ class RuntimeComponentCapabilityMixin: return {"success": False, "error": f"检测到重复插件 ID,拒绝热重载: {details}"} try: - registered_supervisor = self._get_supervisor_for_plugin(plugin_name) - except RuntimeError as exc: - return {"success": False, "error": str(exc)} + loaded = await self.load_plugin_globally(plugin_name, reason=f"load {plugin_name}") + except Exception as e: + logger.error(f"[cap.component.load_plugin] 热重载失败: {e}") + return {"success": False, "error": str(e)} - if registered_supervisor is not None: - try: - reloaded = await registered_supervisor.reload_plugins(reason=f"load {plugin_name}") - if reloaded: - return {"success": True, "count": 1} - return {"success": False, "error": f"插件 {plugin_name} 热重载失败,已回滚"} - except Exception as e: - logger.error(f"[cap.component.load_plugin] 热重载失败: {e}") - return {"success": False, "error": str(e)} - - for sv in self.supervisors: - for pdir in sv._plugin_dirs: - if (pdir / plugin_name).is_dir(): - try: - reloaded = await sv.reload_plugins(reason=f"load {plugin_name}") - if reloaded: - return {"success": True, "count": 1} - return {"success": False, "error": f"插件 {plugin_name} 热重载失败,已回滚"} - except Exception as e: - logger.error(f"[cap.component.load_plugin] 热重载失败: {e}") - return {"success": False, "error": str(e)} - - return {"success": False, "error": f"未找到插件: {plugin_name}"} + if loaded: + return {"success": True, "count": 1} + return {"success": False, "error": f"插件 {plugin_name} 热重载失败"} async def _cap_component_unload_plugin( self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any] @@ -216,17 +515,204 @@ class RuntimeComponentCapabilityMixin: return {"success": False, "error": f"检测到重复插件 ID,拒绝热重载: {details}"} try: - sv = self._get_supervisor_for_plugin(plugin_name) + reloaded = await self.reload_plugins_globally([plugin_name], reason=f"reload {plugin_name}") + except Exception as e: + logger.error(f"[cap.component.reload_plugin] 热重载失败: {e}") + return {"success": False, "error": str(e)} + + if reloaded: + return {"success": True} + return {"success": False, "error": f"插件 {plugin_name} 热重载失败"} + + async def _cap_api_call( + self: _RuntimeComponentManagerProtocol, + plugin_id: str, + capability: str, + args: Dict[str, Any], + ) -> Any: + """调用其他插件公开的 API。 + + Args: + plugin_id: 当前调用方插件 ID。 + capability: 能力名称。 + args: 能力参数。 + + Returns: + Any: API 调用结果。 + """ + + del capability + api_name = str(args.get("api_name", "") or "").strip() + version = str(args.get("version", "") or "").strip() + api_args = args.get("args", {}) + if not isinstance(api_args, dict): + return {"success": False, "error": "参数 args 必须为字典"} + + supervisor, entry, error = self._resolve_api_target(plugin_id, api_name, version) + if supervisor is None or entry is None: + return {"success": False, "error": error or "API 解析失败"} + + invoke_args = dict(api_args) + if entry.dynamic: + invoke_args.setdefault("__maibot_api_name__", entry.name) + invoke_args.setdefault("__maibot_api_full_name__", entry.full_name) + invoke_args.setdefault("__maibot_api_version__", entry.version) + + try: + response = await supervisor.invoke_api( + plugin_id=entry.plugin_id, + component_name=entry.handler_name, + args=invoke_args, + timeout_ms=30000, + ) + except Exception as exc: + logger.error(f"[cap.api.call] 调用 API {entry.full_name} 失败: {exc}", exc_info=True) + return {"success": False, "error": str(exc)} + + if response.error: + return {"success": False, "error": response.error.get("message", "API 调用失败")} + + payload = response.payload if isinstance(response.payload, dict) else {} + if not bool(payload.get("success", False)): + result = payload.get("result") + return {"success": False, "error": "" if result is None else str(result)} + return {"success": True, "result": payload.get("result")} + + async def _cap_api_get( + self: _RuntimeComponentManagerProtocol, + plugin_id: str, + capability: str, + args: Dict[str, Any], + ) -> Any: + """获取当前插件可见的单个 API 元信息。 + + Args: + plugin_id: 当前调用方插件 ID。 + capability: 能力名称。 + args: 能力参数。 + + Returns: + Any: API 元信息或 ``None``。 + """ + + del capability + api_name = str(args.get("api_name", "") or "").strip() + version = str(args.get("version", "") or "").strip() + if not api_name: + return {"success": False, "error": "缺少必要参数 api_name"} + + supervisor, entry, _error = self._resolve_api_target(plugin_id, api_name, version) + if supervisor is None or entry is None: + return {"success": True, "api": None} + return {"success": True, "api": self._serialize_api_entry(entry)} + + async def _cap_api_list( + self: _RuntimeComponentManagerProtocol, + plugin_id: str, + capability: str, + args: Dict[str, Any], + ) -> Any: + """列出当前插件可见的 API 列表。 + + Args: + plugin_id: 当前调用方插件 ID。 + capability: 能力名称。 + args: 能力参数。 + + Returns: + Any: API 元信息列表。 + """ + + del capability + target_plugin_id = str(args.get("plugin_id", "") or "").strip() + api_name, version = self._normalize_api_reference( + str(args.get("api_name", args.get("name", "")) or ""), + str(args.get("version", "") or ""), + ) + apis: List[Dict[str, Any]] = [] + for supervisor in self.supervisors: + apis.extend( + self._serialize_api_entry(entry) + for entry in supervisor.api_registry.get_apis( + plugin_id=target_plugin_id or None, + name=api_name, + version=version, + enabled_only=True, + ) + if self._is_api_visible_to_plugin(entry, plugin_id) + ) + + apis.sort(key=lambda item: (str(item["plugin_id"]), str(item["name"]), str(item["version"]))) + return {"success": True, "apis": apis} + + async def _cap_api_replace_dynamic( + self: _RuntimeComponentManagerProtocol, + plugin_id: str, + capability: str, + args: Dict[str, Any], + ) -> Any: + """替换插件自行维护的动态 API 列表。""" + + del capability + raw_apis = args.get("apis", []) + offline_reason = str(args.get("offline_reason", "") or "").strip() or "动态 API 已下线" + if not isinstance(raw_apis, list): + return {"success": False, "error": "参数 apis 必须为列表"} + + try: + supervisor = self._get_supervisor_for_plugin(plugin_id) except RuntimeError as exc: return {"success": False, "error": str(exc)} - if sv is not None: - try: - reloaded = await sv.reload_plugins(reason=f"reload {plugin_name}") - if reloaded: - return {"success": True} - return {"success": False, "error": f"插件 {plugin_name} 热重载失败,已回滚"} - except Exception as e: - logger.error(f"[cap.component.reload_plugin] 热重载失败: {e}") - return {"success": False, "error": str(e)} - return {"success": False, "error": f"未找到插件: {plugin_name}"} + if supervisor is None: + return {"success": False, "error": f"未找到插件: {plugin_id}"} + + normalized_components: List[Dict[str, Any]] = [] + seen_registry_keys: set[str] = set() + for index, raw_api in enumerate(raw_apis): + if not isinstance(raw_api, dict): + return {"success": False, "error": f"apis[{index}] 必须为字典"} + + api_name = str(raw_api.get("name", "") or "").strip() + component_type = str(raw_api.get("component_type", raw_api.get("type", "API")) or "").strip() + if not api_name: + return {"success": False, "error": f"apis[{index}] 缺少 name"} + if not self._is_api_component_type(component_type): + return {"success": False, "error": f"apis[{index}] 不是 API 组件"} + + metadata = raw_api.get("metadata", {}) if isinstance(raw_api.get("metadata"), dict) else {} + normalized_metadata = dict(metadata) + normalized_metadata["dynamic"] = True + version = str(normalized_metadata.get("version", "1") or "1").strip() or "1" + registry_key = supervisor.api_registry.build_registry_key(plugin_id, api_name, version) + if registry_key in seen_registry_keys: + return {"success": False, "error": f"动态 API 重复声明: {registry_key}"} + seen_registry_keys.add(registry_key) + + existing_entry = supervisor.api_registry.get_api( + plugin_id, + api_name, + version=version, + enabled_only=False, + ) + if existing_entry is not None and not existing_entry.dynamic: + return {"success": False, "error": f"动态 API 不能覆盖静态 API: {registry_key}"} + + normalized_components.append( + { + "name": api_name, + "component_type": "API", + "metadata": normalized_metadata, + } + ) + + registered_count, offlined_count = supervisor.api_registry.replace_plugin_dynamic_apis( + plugin_id, + normalized_components, + offline_reason=offline_reason, + ) + return { + "success": True, + "count": registered_count, + "offlined": offlined_count, + } diff --git a/src/plugin_runtime/capabilities/core.py b/src/plugin_runtime/capabilities/core.py index def5f03d..9bb1755b 100644 --- a/src/plugin_runtime/capabilities/core.py +++ b/src/plugin_runtime/capabilities/core.py @@ -238,14 +238,14 @@ class RuntimeCoreCapabilityMixin: return {"success": False, "value": None, "error": str(e)} async def _cap_config_get_plugin(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - from src.core.component_registry import component_registry as core_registry + from src.plugin_runtime.component_query import component_query_service plugin_name: str = args.get("plugin_name", plugin_id) key: str = args.get("key", "") default = args.get("default") try: - config = core_registry.get_plugin_config(plugin_name) + config = component_query_service.get_plugin_config(plugin_name) if config is None: return {"success": False, "value": default, "error": f"未找到插件 {plugin_name} 的配置"} @@ -258,11 +258,11 @@ class RuntimeCoreCapabilityMixin: return {"success": False, "value": default, "error": str(e)} async def _cap_config_get_all(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - from src.core.component_registry import component_registry as core_registry + from src.plugin_runtime.component_query import component_query_service plugin_name: str = args.get("plugin_name", plugin_id) try: - config = core_registry.get_plugin_config(plugin_name) + config = component_query_service.get_plugin_config(plugin_name) if config is None: return {"success": True, "value": {}} return {"success": True, "value": config} diff --git a/src/plugin_runtime/capabilities/data.py b/src/plugin_runtime/capabilities/data.py index c8139c16..32843d09 100644 --- a/src/plugin_runtime/capabilities/data.py +++ b/src/plugin_runtime/capabilities/data.py @@ -648,10 +648,10 @@ class RuntimeDataCapabilityMixin: return {"success": False, "error": str(e)} async def _cap_tool_get_definitions(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - from src.core.component_registry import component_registry as core_registry + from src.plugin_runtime.component_query import component_query_service try: - tools = core_registry.get_llm_available_tools() + tools = component_query_service.get_llm_available_tools() return { "success": True, "tools": [{"name": name, "definition": info.get_llm_definition()} for name, info in tools.items()], diff --git a/src/plugin_runtime/capabilities/registry.py b/src/plugin_runtime/capabilities/registry.py index abce97dc..7f87604d 100644 --- a/src/plugin_runtime/capabilities/registry.py +++ b/src/plugin_runtime/capabilities/registry.py @@ -1,6 +1,7 @@ from typing import TYPE_CHECKING from src.common.logger import get_logger +from src.plugin_runtime.host.capability_service import CapabilityImpl from src.plugin_runtime.host.supervisor import PluginSupervisor if TYPE_CHECKING: @@ -13,66 +14,80 @@ def register_capability_impls(manager: "PluginRuntimeManager", supervisor: Plugi """向指定 Supervisor 注册主程序提供的能力实现。""" cap_service = supervisor.capability_service - cap_service.register_capability("send.text", manager._cap_send_text) - cap_service.register_capability("send.emoji", manager._cap_send_emoji) - cap_service.register_capability("send.image", manager._cap_send_image) - cap_service.register_capability("send.command", manager._cap_send_command) - cap_service.register_capability("send.custom", manager._cap_send_custom) + def _register(name: str, impl: CapabilityImpl) -> None: + """注册单个能力实现。 - cap_service.register_capability("llm.generate", manager._cap_llm_generate) - cap_service.register_capability("llm.generate_with_tools", manager._cap_llm_generate_with_tools) - cap_service.register_capability("llm.get_available_models", manager._cap_llm_get_available_models) + Args: + name: 能力名称。 + impl: 能力实现函数。 + """ + cap_service.register_capability(name, impl) - cap_service.register_capability("config.get", manager._cap_config_get) - cap_service.register_capability("config.get_plugin", manager._cap_config_get_plugin) - cap_service.register_capability("config.get_all", manager._cap_config_get_all) + _register("send.text", manager._cap_send_text) + _register("send.emoji", manager._cap_send_emoji) + _register("send.image", manager._cap_send_image) + _register("send.command", manager._cap_send_command) + _register("send.custom", manager._cap_send_custom) - cap_service.register_capability("database.query", manager._cap_database_query) - cap_service.register_capability("database.save", manager._cap_database_save) - cap_service.register_capability("database.get", manager._cap_database_get) - cap_service.register_capability("database.delete", manager._cap_database_delete) - cap_service.register_capability("database.count", manager._cap_database_count) + _register("llm.generate", manager._cap_llm_generate) + _register("llm.generate_with_tools", manager._cap_llm_generate_with_tools) + _register("llm.get_available_models", manager._cap_llm_get_available_models) - cap_service.register_capability("chat.get_all_streams", manager._cap_chat_get_all_streams) - cap_service.register_capability("chat.get_group_streams", manager._cap_chat_get_group_streams) - cap_service.register_capability("chat.get_private_streams", manager._cap_chat_get_private_streams) - cap_service.register_capability("chat.get_stream_by_group_id", manager._cap_chat_get_stream_by_group_id) - cap_service.register_capability("chat.get_stream_by_user_id", manager._cap_chat_get_stream_by_user_id) + _register("config.get", manager._cap_config_get) + _register("config.get_plugin", manager._cap_config_get_plugin) + _register("config.get_all", manager._cap_config_get_all) - cap_service.register_capability("message.get_by_time", manager._cap_message_get_by_time) - cap_service.register_capability("message.get_by_time_in_chat", manager._cap_message_get_by_time_in_chat) - cap_service.register_capability("message.get_recent", manager._cap_message_get_recent) - cap_service.register_capability("message.count_new", manager._cap_message_count_new) - cap_service.register_capability("message.build_readable", manager._cap_message_build_readable) + _register("database.query", manager._cap_database_query) + _register("database.save", manager._cap_database_save) + _register("database.get", manager._cap_database_get) + _register("database.delete", manager._cap_database_delete) + _register("database.count", manager._cap_database_count) - cap_service.register_capability("person.get_id", manager._cap_person_get_id) - cap_service.register_capability("person.get_value", manager._cap_person_get_value) - cap_service.register_capability("person.get_id_by_name", manager._cap_person_get_id_by_name) + _register("chat.get_all_streams", manager._cap_chat_get_all_streams) + _register("chat.get_group_streams", manager._cap_chat_get_group_streams) + _register("chat.get_private_streams", manager._cap_chat_get_private_streams) + _register("chat.get_stream_by_group_id", manager._cap_chat_get_stream_by_group_id) + _register("chat.get_stream_by_user_id", manager._cap_chat_get_stream_by_user_id) - cap_service.register_capability("emoji.get_by_description", manager._cap_emoji_get_by_description) - cap_service.register_capability("emoji.get_random", manager._cap_emoji_get_random) - cap_service.register_capability("emoji.get_count", manager._cap_emoji_get_count) - cap_service.register_capability("emoji.get_emotions", manager._cap_emoji_get_emotions) - cap_service.register_capability("emoji.get_all", manager._cap_emoji_get_all) - cap_service.register_capability("emoji.get_info", manager._cap_emoji_get_info) - cap_service.register_capability("emoji.register", manager._cap_emoji_register) - cap_service.register_capability("emoji.delete", manager._cap_emoji_delete) + _register("message.get_by_time", manager._cap_message_get_by_time) + _register("message.get_by_time_in_chat", manager._cap_message_get_by_time_in_chat) + _register("message.get_recent", manager._cap_message_get_recent) + _register("message.count_new", manager._cap_message_count_new) + _register("message.build_readable", manager._cap_message_build_readable) - cap_service.register_capability("frequency.get_current_talk_value", manager._cap_frequency_get_current_talk_value) - cap_service.register_capability("frequency.set_adjust", manager._cap_frequency_set_adjust) - cap_service.register_capability("frequency.get_adjust", manager._cap_frequency_get_adjust) + _register("person.get_id", manager._cap_person_get_id) + _register("person.get_value", manager._cap_person_get_value) + _register("person.get_id_by_name", manager._cap_person_get_id_by_name) - cap_service.register_capability("tool.get_definitions", manager._cap_tool_get_definitions) + _register("emoji.get_by_description", manager._cap_emoji_get_by_description) + _register("emoji.get_random", manager._cap_emoji_get_random) + _register("emoji.get_count", manager._cap_emoji_get_count) + _register("emoji.get_emotions", manager._cap_emoji_get_emotions) + _register("emoji.get_all", manager._cap_emoji_get_all) + _register("emoji.get_info", manager._cap_emoji_get_info) + _register("emoji.register", manager._cap_emoji_register) + _register("emoji.delete", manager._cap_emoji_delete) - cap_service.register_capability("component.get_all_plugins", manager._cap_component_get_all_plugins) - cap_service.register_capability("component.get_plugin_info", manager._cap_component_get_plugin_info) - cap_service.register_capability("component.list_loaded_plugins", manager._cap_component_list_loaded_plugins) - cap_service.register_capability("component.list_registered_plugins", manager._cap_component_list_registered_plugins) - cap_service.register_capability("component.enable", manager._cap_component_enable) - cap_service.register_capability("component.disable", manager._cap_component_disable) - cap_service.register_capability("component.load_plugin", manager._cap_component_load_plugin) - cap_service.register_capability("component.unload_plugin", manager._cap_component_unload_plugin) - cap_service.register_capability("component.reload_plugin", manager._cap_component_reload_plugin) + _register("frequency.get_current_talk_value", manager._cap_frequency_get_current_talk_value) + _register("frequency.set_adjust", manager._cap_frequency_set_adjust) + _register("frequency.get_adjust", manager._cap_frequency_get_adjust) - cap_service.register_capability("knowledge.search", manager._cap_knowledge_search) + _register("tool.get_definitions", manager._cap_tool_get_definitions) + + _register("api.call", manager._cap_api_call) + _register("api.get", manager._cap_api_get) + _register("api.list", manager._cap_api_list) + _register("api.replace_dynamic", manager._cap_api_replace_dynamic) + + _register("component.get_all_plugins", manager._cap_component_get_all_plugins) + _register("component.get_plugin_info", manager._cap_component_get_plugin_info) + _register("component.list_loaded_plugins", manager._cap_component_list_loaded_plugins) + _register("component.list_registered_plugins", manager._cap_component_list_registered_plugins) + _register("component.enable", manager._cap_component_enable) + _register("component.disable", manager._cap_component_disable) + _register("component.load_plugin", manager._cap_component_load_plugin) + _register("component.unload_plugin", manager._cap_component_unload_plugin) + _register("component.reload_plugin", manager._cap_component_reload_plugin) + + _register("knowledge.search", manager._cap_knowledge_search) logger.debug("已注册全部主程序能力实现") diff --git a/src/plugin_runtime/component_query.py b/src/plugin_runtime/component_query.py new file mode 100644 index 00000000..7d23d202 --- /dev/null +++ b/src/plugin_runtime/component_query.py @@ -0,0 +1,709 @@ +"""插件运行时统一组件查询服务。 + +该模块统一从插件运行时的 Host ComponentRegistry 中聚合只读视图, +供 HFC/PFC、Planner、ToolExecutor 和运行时能力层查询与调用。 +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional, Tuple + +from src.common.logger import get_logger +from src.core.types import ActionActivationType, ActionInfo, CommandInfo, ComponentInfo, ComponentType, ToolInfo +from src.llm_models.payload_content.tool_option import ToolParamType + +if TYPE_CHECKING: + from src.plugin_runtime.host.component_registry import ActionEntry, CommandEntry, ComponentEntry, ToolEntry + from src.plugin_runtime.host.supervisor import PluginSupervisor + from src.plugin_runtime.integration import PluginRuntimeManager + +logger = get_logger("plugin_runtime.component_query") + +ActionExecutor = Callable[..., Awaitable[Any]] +CommandExecutor = Callable[..., Awaitable[Tuple[bool, Optional[str], bool]]] +ToolExecutor = Callable[..., Awaitable[Any]] + +_HOST_COMPONENT_TYPE_MAP: Dict[ComponentType, str] = { + ComponentType.ACTION: "ACTION", + ComponentType.COMMAND: "COMMAND", + ComponentType.TOOL: "TOOL", +} +_TOOL_PARAM_TYPE_MAP: Dict[str, ToolParamType] = { + "string": ToolParamType.STRING, + "integer": ToolParamType.INTEGER, + "float": ToolParamType.FLOAT, + "boolean": ToolParamType.BOOLEAN, + "bool": ToolParamType.BOOLEAN, +} + + +class ComponentQueryService: + """插件运行时统一组件查询服务。 + + 该对象不维护独立状态,只读取插件系统中的注册结果。 + 所有注册、删除、配置写入等写操作都被显式禁用。 + """ + + @staticmethod + def _get_runtime_manager() -> "PluginRuntimeManager": + """获取插件运行时管理器单例。 + + Returns: + PluginRuntimeManager: 当前全局插件运行时管理器。 + """ + + from src.plugin_runtime.integration import get_plugin_runtime_manager + + return get_plugin_runtime_manager() + + def _iter_supervisors(self) -> list["PluginSupervisor"]: + """获取当前所有活跃的插件运行时监督器。 + + Returns: + list[PluginSupervisor]: 当前运行中的监督器列表。 + """ + + runtime_manager = self._get_runtime_manager() + return list(runtime_manager.supervisors) + + def _iter_component_entries( + self, + component_type: ComponentType, + *, + enabled_only: bool = True, + ) -> list[tuple["PluginSupervisor", "ComponentEntry"]]: + """遍历指定类型的全部组件条目。 + + Args: + component_type: 目标组件类型。 + enabled_only: 是否仅返回启用状态的组件。 + + Returns: + list[tuple[PluginSupervisor, ComponentEntry]]: ``(监督器, 组件条目)`` 列表。 + """ + + host_component_type = _HOST_COMPONENT_TYPE_MAP.get(component_type) + if host_component_type is None: + return [] + + collected_entries: list[tuple["PluginSupervisor", "ComponentEntry"]] = [] + for supervisor in self._iter_supervisors(): + for component in supervisor.component_registry.get_components_by_type( + host_component_type, + enabled_only=enabled_only, + ): + collected_entries.append((supervisor, component)) + return collected_entries + + @staticmethod + def _coerce_action_activation_type(raw_value: Any) -> ActionActivationType: + """规范化动作激活类型。 + + Args: + raw_value: 原始激活类型值。 + + Returns: + ActionActivationType: 规范化后的激活类型枚举。 + """ + + normalized_value = str(raw_value or "").strip().lower() + if normalized_value == ActionActivationType.NEVER.value: + return ActionActivationType.NEVER + if normalized_value == ActionActivationType.RANDOM.value: + return ActionActivationType.RANDOM + if normalized_value == ActionActivationType.KEYWORD.value: + return ActionActivationType.KEYWORD + return ActionActivationType.ALWAYS + + @staticmethod + def _coerce_float(value: Any, default: float = 0.0) -> float: + """将任意值安全转换为浮点数。 + + Args: + value: 待转换的输入值。 + default: 转换失败时返回的默认值。 + + Returns: + float: 转换后的浮点结果。 + """ + + try: + return float(value) + except (TypeError, ValueError): + return default + + @staticmethod + def _build_action_info(entry: "ActionEntry") -> ActionInfo: + """将运行时 Action 条目转换为核心动作信息。 + + Args: + entry: 插件运行时中的 Action 条目。 + + Returns: + ActionInfo: 供核心 Planner 使用的动作信息。 + """ + + metadata = dict(entry.metadata) + raw_action_parameters = metadata.get("action_parameters") + action_parameters = ( + { + str(param_name): str(param_description) + for param_name, param_description in raw_action_parameters.items() + } + if isinstance(raw_action_parameters, dict) + else {} + ) + action_require = [ + str(item) + for item in (metadata.get("action_require") or []) + if item is not None and str(item).strip() + ] + associated_types = [ + str(item) + for item in (metadata.get("associated_types") or []) + if item is not None and str(item).strip() + ] + activation_keywords = [ + str(item) + for item in (metadata.get("activation_keywords") or []) + if item is not None and str(item).strip() + ] + + return ActionInfo( + name=entry.name, + component_type=ComponentType.ACTION, + description=str(metadata.get("description", "") or ""), + enabled=bool(entry.enabled), + plugin_name=entry.plugin_id, + metadata=metadata, + action_parameters=action_parameters, + action_require=action_require, + associated_types=associated_types, + activation_type=ComponentQueryService._coerce_action_activation_type(metadata.get("activation_type")), + random_activation_probability=ComponentQueryService._coerce_float( + metadata.get("activation_probability"), + 0.0, + ), + activation_keywords=activation_keywords, + parallel_action=bool(metadata.get("parallel_action", False)), + ) + + @staticmethod + def _build_command_info(entry: "CommandEntry") -> CommandInfo: + """将运行时 Command 条目转换为核心命令信息。 + + Args: + entry: 插件运行时中的 Command 条目。 + + Returns: + CommandInfo: 供核心命令链使用的命令信息。 + """ + + metadata = dict(entry.metadata) + return CommandInfo( + name=entry.name, + component_type=ComponentType.COMMAND, + description=str(metadata.get("description", "") or ""), + enabled=bool(entry.enabled), + plugin_name=entry.plugin_id, + metadata=metadata, + command_pattern=str(metadata.get("command_pattern", "") or ""), + ) + + @staticmethod + def _coerce_tool_param_type(raw_value: Any) -> ToolParamType: + """规范化工具参数类型。 + + Args: + raw_value: 原始工具参数类型值。 + + Returns: + ToolParamType: 规范化后的工具参数类型。 + """ + + normalized_value = str(raw_value or "").strip().lower() + return _TOOL_PARAM_TYPE_MAP.get(normalized_value, ToolParamType.STRING) + + @staticmethod + def _build_tool_parameters(entry: "ToolEntry") -> list[tuple[str, ToolParamType, str, bool, list[str] | None]]: + """将运行时工具参数元数据转换为核心 ToolInfo 参数列表。 + + Args: + entry: 插件运行时中的 Tool 条目。 + + Returns: + list[tuple[str, ToolParamType, str, bool, list[str] | None]]: 转换后的参数列表。 + """ + + structured_parameters = entry.parameters if isinstance(entry.parameters, list) else [] + if not structured_parameters and isinstance(entry.parameters_raw, dict): + structured_parameters = [ + {"name": key, **value} + for key, value in entry.parameters_raw.items() + if isinstance(value, dict) + ] + + normalized_parameters: list[tuple[str, ToolParamType, str, bool, list[str] | None]] = [] + for parameter in structured_parameters: + if not isinstance(parameter, dict): + continue + + parameter_name = str(parameter.get("name", "") or "").strip() + if not parameter_name: + continue + + enum_values = parameter.get("enum") + normalized_enum_values = ( + [str(item) for item in enum_values if item is not None] + if isinstance(enum_values, list) + else None + ) + normalized_parameters.append( + ( + parameter_name, + ComponentQueryService._coerce_tool_param_type(parameter.get("param_type") or parameter.get("type")), + str(parameter.get("description", "") or ""), + bool(parameter.get("required", True)), + normalized_enum_values, + ) + ) + return normalized_parameters + + @staticmethod + def _build_tool_info(entry: "ToolEntry") -> ToolInfo: + """将运行时 Tool 条目转换为核心工具信息。 + + Args: + entry: 插件运行时中的 Tool 条目。 + + Returns: + ToolInfo: 供 ToolExecutor 与能力层使用的工具信息。 + """ + + return ToolInfo( + name=entry.name, + component_type=ComponentType.TOOL, + description=entry.description, + enabled=bool(entry.enabled), + plugin_name=entry.plugin_id, + metadata=dict(entry.metadata), + tool_parameters=ComponentQueryService._build_tool_parameters(entry), + tool_description=entry.description, + ) + + @staticmethod + def _log_duplicate_component(component_type: ComponentType, component_name: str) -> None: + """记录重复组件名称冲突。 + + Args: + component_type: 组件类型。 + component_name: 发生冲突的组件名称。 + """ + + logger.warning(f"检测到重复{component_type.value}名称 {component_name},将只保留首个匹配项") + + def _get_unique_component_entry( + self, + component_type: ComponentType, + name: str, + ) -> Optional[tuple["PluginSupervisor", "ComponentEntry"]]: + """按组件短名解析唯一条目。 + + Args: + component_type: 目标组件类型。 + name: 组件短名。 + + Returns: + Optional[tuple[PluginSupervisor, ComponentEntry]]: 唯一命中的组件条目。 + """ + + matched_entries = [ + (supervisor, entry) + for supervisor, entry in self._iter_component_entries(component_type) + if entry.name == name + ] + if not matched_entries: + return None + if len(matched_entries) > 1: + self._log_duplicate_component(component_type, name) + return matched_entries[0] + + def _collect_unique_component_infos( + self, + component_type: ComponentType, + ) -> Dict[str, ComponentInfo]: + """收集某类组件的唯一信息视图。 + + Args: + component_type: 目标组件类型。 + + Returns: + Dict[str, ComponentInfo]: 组件名到核心组件信息的映射。 + """ + + collected_components: Dict[str, ComponentInfo] = {} + for _supervisor, entry in self._iter_component_entries(component_type): + if entry.name in collected_components: + self._log_duplicate_component(component_type, entry.name) + continue + + if component_type == ComponentType.ACTION: + collected_components[entry.name] = self._build_action_info(entry) # type: ignore[arg-type] + elif component_type == ComponentType.COMMAND: + collected_components[entry.name] = self._build_command_info(entry) # type: ignore[arg-type] + elif component_type == ComponentType.TOOL: + collected_components[entry.name] = self._build_tool_info(entry) # type: ignore[arg-type] + return collected_components + + @staticmethod + def _extract_stream_id_from_action_kwargs(kwargs: Dict[str, Any]) -> str: + """从旧 ActionManager 参数中提取聊天流 ID。 + + Args: + kwargs: 旧动作执行器收到的关键字参数。 + + Returns: + str: 提取出的 ``stream_id``。 + """ + + chat_stream = kwargs.get("chat_stream") + if chat_stream is not None: + try: + return str(chat_stream.session_id) + except AttributeError: + pass + + return str(kwargs.get("stream_id", "") or "") + + @staticmethod + def _build_action_executor(supervisor: "PluginSupervisor", plugin_id: str, component_name: str) -> ActionExecutor: + """构造动作执行 RPC 闭包。 + + Args: + supervisor: 负责该组件的监督器。 + plugin_id: 插件 ID。 + component_name: 组件名称。 + + Returns: + ActionExecutor: 兼容旧 Planner 的异步执行器。 + """ + + async def _executor(**kwargs: Any) -> tuple[bool, str]: + """将核心动作调用桥接到插件运行时。 + + Args: + **kwargs: 旧 ActionManager 传入的上下文参数。 + + Returns: + tuple[bool, str]: ``(是否成功, 结果说明)``。 + """ + + invoke_args: Dict[str, Any] = {} + action_data = kwargs.get("action_data") + if isinstance(action_data, dict): + invoke_args.update(action_data) + + stream_id = ComponentQueryService._extract_stream_id_from_action_kwargs(kwargs) + invoke_args["action_data"] = action_data if isinstance(action_data, dict) else {} + invoke_args["stream_id"] = stream_id + invoke_args["chat_id"] = stream_id + invoke_args["reasoning"] = str(kwargs.get("action_reasoning", "") or "") + + if (thinking_id := kwargs.get("thinking_id")) is not None: + invoke_args["thinking_id"] = str(thinking_id) + if isinstance(kwargs.get("cycle_timers"), dict): + invoke_args["cycle_timers"] = kwargs["cycle_timers"] + if isinstance(kwargs.get("plugin_config"), dict): + invoke_args["plugin_config"] = kwargs["plugin_config"] + if isinstance(kwargs.get("log_prefix"), str): + invoke_args["log_prefix"] = kwargs["log_prefix"] + if isinstance(kwargs.get("shutting_down"), bool): + invoke_args["shutting_down"] = kwargs["shutting_down"] + + try: + response = await supervisor.invoke_plugin( + method="plugin.invoke_action", + plugin_id=plugin_id, + component_name=component_name, + args=invoke_args, + timeout_ms=30000, + ) + except Exception as exc: + logger.error(f"运行时 Action {plugin_id}.{component_name} 执行失败: {exc}", exc_info=True) + return False, str(exc) + + payload = response.payload if isinstance(response.payload, dict) else {} + success = bool(payload.get("success", False)) + result = payload.get("result") + if isinstance(result, (list, tuple)): + if len(result) >= 2: + return bool(result[0]), "" if result[1] is None else str(result[1]) + if len(result) == 1: + return bool(result[0]), "" + if success: + return True, "" if result is None else str(result) + return False, "" if result is None else str(result) + + return _executor + + @staticmethod + def _build_command_executor( + supervisor: "PluginSupervisor", + plugin_id: str, + component_name: str, + metadata: Dict[str, Any], + ) -> CommandExecutor: + """构造命令执行 RPC 闭包。 + + Args: + supervisor: 负责该组件的监督器。 + plugin_id: 插件 ID。 + component_name: 组件名称。 + metadata: 命令组件元数据。 + + Returns: + CommandExecutor: 兼容旧消息命令链的执行器。 + """ + + async def _executor(**kwargs: Any) -> tuple[bool, Optional[str], bool]: + """将核心命令调用桥接到插件运行时。 + + Args: + **kwargs: 命令执行上下文参数。 + + Returns: + tuple[bool, Optional[str], bool]: ``(是否成功, 返回文本, 是否拦截后续消息)``。 + """ + + message = kwargs.get("message") + matched_groups = kwargs.get("matched_groups") + plugin_config = kwargs.get("plugin_config") + invoke_args: Dict[str, Any] = { + "text": str(getattr(message, "processed_plain_text", "") or ""), + "stream_id": str(getattr(message, "session_id", "") or ""), + "matched_groups": matched_groups if isinstance(matched_groups, dict) else {}, + } + if isinstance(plugin_config, dict): + invoke_args["plugin_config"] = plugin_config + + try: + response = await supervisor.invoke_plugin( + method="plugin.invoke_command", + plugin_id=plugin_id, + component_name=component_name, + args=invoke_args, + timeout_ms=30000, + ) + except Exception as exc: + logger.error(f"运行时 Command {plugin_id}.{component_name} 执行失败: {exc}", exc_info=True) + return False, str(exc), True + + payload = response.payload if isinstance(response.payload, dict) else {} + success = bool(payload.get("success", False)) + result = payload.get("result") + intercept = bool(metadata.get("intercept_message_level", 0)) + response_text: Optional[str] + + if isinstance(result, (list, tuple)) and len(result) >= 3: + response_text = None if result[1] is None else str(result[1]) + intercept = bool(result[2]) + else: + response_text = None if result is None else str(result) + + return success, response_text, intercept + + return _executor + + @staticmethod + def _build_tool_executor(supervisor: "PluginSupervisor", plugin_id: str, component_name: str) -> ToolExecutor: + """构造工具执行 RPC 闭包。 + + Args: + supervisor: 负责该组件的监督器。 + plugin_id: 插件 ID。 + component_name: 组件名称。 + + Returns: + ToolExecutor: 兼容旧 ToolExecutor 的异步执行器。 + """ + + async def _executor(function_args: Dict[str, Any]) -> Any: + """将核心工具调用桥接到插件运行时。 + + Args: + function_args: 工具调用参数。 + + Returns: + Any: 插件工具返回结果;若结果不是字典,则会包装为 ``{"content": ...}``。 + """ + + try: + response = await supervisor.invoke_plugin( + method="plugin.invoke_tool", + plugin_id=plugin_id, + component_name=component_name, + args=function_args, + timeout_ms=30000, + ) + except Exception as exc: + logger.error(f"运行时 Tool {plugin_id}.{component_name} 执行失败: {exc}", exc_info=True) + return {"content": f"工具 {component_name} 执行失败: {exc}"} + + payload = response.payload if isinstance(response.payload, dict) else {} + result = payload.get("result") + if isinstance(result, dict): + return result + return {"content": "" if result is None else str(result)} + + return _executor + + def get_action_info(self, name: str) -> Optional[ActionInfo]: + """获取指定动作的信息。 + + Args: + name: 动作名称。 + + Returns: + Optional[ActionInfo]: 匹配到的动作信息。 + """ + + matched_entry = self._get_unique_component_entry(ComponentType.ACTION, name) + if matched_entry is None: + return None + _supervisor, entry = matched_entry + return self._build_action_info(entry) # type: ignore[arg-type] + + def get_action_executor(self, name: str) -> Optional[ActionExecutor]: + """获取指定动作的执行器。 + + Args: + name: 动作名称。 + + Returns: + Optional[ActionExecutor]: 运行时 RPC 执行闭包。 + """ + + matched_entry = self._get_unique_component_entry(ComponentType.ACTION, name) + if matched_entry is None: + return None + supervisor, entry = matched_entry + return self._build_action_executor(supervisor, entry.plugin_id, entry.name) + + def get_default_actions(self) -> Dict[str, ActionInfo]: + """获取当前默认启用的动作集合。 + + Returns: + Dict[str, ActionInfo]: 动作名到动作信息的映射。 + """ + + action_infos = self._collect_unique_component_infos(ComponentType.ACTION) + return {name: info for name, info in action_infos.items() if isinstance(info, ActionInfo) and info.enabled} + + def find_command_by_text(self, text: str) -> Optional[Tuple[CommandExecutor, dict, CommandInfo]]: + """根据文本查找匹配的命令。 + + Args: + text: 待匹配的文本内容。 + + Returns: + Optional[Tuple[CommandExecutor, dict, CommandInfo]]: 匹配结果。 + """ + + for supervisor in self._iter_supervisors(): + match_result = supervisor.component_registry.find_command_by_text(text) + if match_result is None: + continue + + entry, matched_groups = match_result + command_info = self._build_command_info(entry) # type: ignore[arg-type] + command_executor = self._build_command_executor( + supervisor, + entry.plugin_id, + entry.name, + dict(entry.metadata), + ) + return command_executor, matched_groups, command_info + return None + + def get_tool_info(self, name: str) -> Optional[ToolInfo]: + """获取指定工具的信息。 + + Args: + name: 工具名称。 + + Returns: + Optional[ToolInfo]: 匹配到的工具信息。 + """ + + matched_entry = self._get_unique_component_entry(ComponentType.TOOL, name) + if matched_entry is None: + return None + _supervisor, entry = matched_entry + return self._build_tool_info(entry) # type: ignore[arg-type] + + def get_tool_executor(self, name: str) -> Optional[ToolExecutor]: + """获取指定工具的执行器。 + + Args: + name: 工具名称。 + + Returns: + Optional[ToolExecutor]: 运行时 RPC 执行闭包。 + """ + + matched_entry = self._get_unique_component_entry(ComponentType.TOOL, name) + if matched_entry is None: + return None + supervisor, entry = matched_entry + return self._build_tool_executor(supervisor, entry.plugin_id, entry.name) + + def get_llm_available_tools(self) -> Dict[str, ToolInfo]: + """获取当前可供 LLM 选择的工具集合。 + + Returns: + Dict[str, ToolInfo]: 工具名到工具信息的映射。 + """ + + tool_infos = self._collect_unique_component_infos(ComponentType.TOOL) + return {name: info for name, info in tool_infos.items() if isinstance(info, ToolInfo) and info.enabled} + + def get_components_by_type(self, component_type: ComponentType) -> Dict[str, ComponentInfo]: + """获取某类组件的全部信息。 + + Args: + component_type: 组件类型。 + + Returns: + Dict[str, ComponentInfo]: 组件名到组件信息的映射。 + """ + + return self._collect_unique_component_infos(component_type) + + def get_plugin_config(self, plugin_name: str) -> Optional[dict]: + """读取指定插件的配置文件内容。 + + Args: + plugin_name: 插件名称。 + + Returns: + Optional[dict]: 读取成功时返回配置字典;未找到时返回 ``None``。 + """ + + runtime_manager = self._get_runtime_manager() + try: + supervisor = runtime_manager._get_supervisor_for_plugin(plugin_name) + except RuntimeError as exc: + logger.error(f"读取插件配置失败: {exc}") + return None + + if supervisor is None: + return None + + try: + return runtime_manager._load_plugin_config_for_supervisor(supervisor, plugin_name) + except Exception as exc: + logger.error(f"读取插件 {plugin_name} 配置失败: {exc}", exc_info=True) + return None + + +component_query_service = ComponentQueryService() diff --git a/src/plugin_runtime/host/api_registry.py b/src/plugin_runtime/host/api_registry.py new file mode 100644 index 00000000..1cbc05f6 --- /dev/null +++ b/src/plugin_runtime/host/api_registry.py @@ -0,0 +1,349 @@ +"""Host 侧插件 API 动态注册表。""" + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Set, Tuple + +from src.common.logger import get_logger + +logger = get_logger("plugin_runtime.host.api_registry") + + +@dataclass(slots=True) +class APIEntry: + """API 组件条目。""" + + name: str + plugin_id: str + description: str = "" + version: str = "1" + public: bool = False + metadata: Dict[str, Any] = field(default_factory=dict) + enabled: bool = True + handler_name: str = "" + dynamic: bool = False + offline_reason: str = "" + disabled_session: Set[str] = field(default_factory=set) + full_name: str = field(init=False) + registry_key: str = field(init=False) + + def __post_init__(self) -> None: + """规范化 API 条目字段。""" + + self.name = str(self.name or "").strip() + self.plugin_id = str(self.plugin_id or "").strip() + self.description = str(self.description or "").strip() + self.version = str(self.version or "1").strip() or "1" + self.handler_name = str(self.handler_name or self.name).strip() or self.name + self.offline_reason = str(self.offline_reason or "").strip() + self.full_name = f"{self.plugin_id}.{self.name}" + self.registry_key = APIRegistry.build_registry_key(self.plugin_id, self.name, self.version) + + @classmethod + def from_metadata(cls, name: str, plugin_id: str, metadata: Dict[str, Any]) -> "APIEntry": + """根据 Runner 上报的元数据构造 API 条目。""" + + safe_metadata = dict(metadata) + return cls( + name=name, + plugin_id=plugin_id, + description=str(safe_metadata.get("description", "") or ""), + version=str(safe_metadata.get("version", "1") or "1"), + public=bool(safe_metadata.get("public", False)), + metadata=safe_metadata, + enabled=bool(safe_metadata.get("enabled", True)), + handler_name=str(safe_metadata.get("handler_name", name) or name), + dynamic=bool(safe_metadata.get("dynamic", False)), + offline_reason=str(safe_metadata.get("offline_reason", "") or ""), + ) + + +class APIRegistry: + """Host 侧插件 API 动态注册表。 + + 该注册表不直接面向 Runner,而是复用插件组件注册/卸载事件, + 维护面向 API 调用场景的专用索引。 + """ + + def __init__(self) -> None: + """初始化 API 注册表。""" + + self._apis: Dict[str, APIEntry] = {} + self._by_full_name: Dict[str, List[APIEntry]] = {} + self._by_plugin: Dict[str, List[APIEntry]] = {} + self._by_name: Dict[str, List[APIEntry]] = {} + + def clear(self) -> None: + """清空全部 API 注册状态。""" + + self._apis.clear() + self._by_full_name.clear() + self._by_plugin.clear() + self._by_name.clear() + + @staticmethod + def _is_api_component(component_type: Any) -> bool: + """判断组件声明是否属于 API。""" + + return str(component_type or "").strip().upper() == "API" + + @staticmethod + def _normalize_query_version(version: Any) -> str: + """规范化查询使用的版本字符串。""" + + return str(version or "").strip() + + @classmethod + def _split_reference(cls, reference: str, version: Any = "") -> Tuple[str, str]: + """解析可能带 ``@version`` 后缀的 API 引用。""" + + normalized_reference = str(reference or "").strip() + normalized_version = cls._normalize_query_version(version) + if normalized_reference and not normalized_version and "@" in normalized_reference: + candidate_reference, candidate_version = normalized_reference.rsplit("@", 1) + candidate_reference = candidate_reference.strip() + candidate_version = candidate_version.strip() + if candidate_reference and candidate_version: + normalized_reference = candidate_reference + normalized_version = candidate_version + return normalized_reference, normalized_version + + @staticmethod + def build_registry_key(plugin_id: str, name: str, version: str) -> str: + """构造 API 注册表唯一键。""" + + normalized_full_name = f"{str(plugin_id or '').strip()}.{str(name or '').strip()}" + normalized_version = str(version or "1").strip() or "1" + return f"{normalized_full_name}@{normalized_version}" + + @staticmethod + def check_api_enabled(entry: APIEntry, session_id: Optional[str] = None) -> bool: + """判断 API 条目当前是否处于启用状态。""" + + if session_id and session_id in entry.disabled_session: + return False + return entry.enabled + + def register_api(self, name: str, plugin_id: str, metadata: Dict[str, Any]) -> bool: + """注册单个 API 条目。""" + + normalized_name = str(name or "").strip() + if not normalized_name: + logger.warning(f"插件 {plugin_id} 存在空 API 名称声明,已忽略") + return False + + entry = APIEntry.from_metadata(name=normalized_name, plugin_id=plugin_id, metadata=metadata) + existing_entry = self._apis.get(entry.registry_key) + if existing_entry is not None: + logger.warning(f"API {entry.registry_key} 已存在,覆盖旧条目") + self._remove_entry(existing_entry) + + self._apis[entry.registry_key] = entry + self._by_full_name.setdefault(entry.full_name, []).append(entry) + self._by_plugin.setdefault(plugin_id, []).append(entry) + self._by_name.setdefault(entry.name, []).append(entry) + return True + + def register_plugin_apis(self, plugin_id: str, components: List[Dict[str, Any]]) -> int: + """批量注册某个插件声明的全部 API。""" + + count = 0 + for component in components: + if not self._is_api_component(component.get("component_type")): + continue + if self.register_api( + name=str(component.get("name", "") or ""), + plugin_id=plugin_id, + metadata=component.get("metadata", {}) if isinstance(component.get("metadata"), dict) else {}, + ): + count += 1 + return count + + def replace_plugin_dynamic_apis( + self, + plugin_id: str, + components: List[Dict[str, Any]], + *, + offline_reason: str = "动态 API 已下线", + ) -> Tuple[int, int]: + """替换指定插件当前声明的动态 API 集合。""" + + normalized_offline_reason = str(offline_reason or "").strip() or "动态 API 已下线" + desired_registry_keys: Set[str] = set() + registered_count = 0 + + for component in components: + if not self._is_api_component(component.get("component_type")): + continue + metadata = component.get("metadata", {}) if isinstance(component.get("metadata"), dict) else {} + dynamic_metadata = dict(metadata) + dynamic_metadata["dynamic"] = True + dynamic_metadata.pop("offline_reason", None) + + entry = APIEntry.from_metadata( + name=str(component.get("name", "") or ""), + plugin_id=plugin_id, + metadata=dynamic_metadata, + ) + desired_registry_keys.add(entry.registry_key) + if self.register_api(entry.name, plugin_id, dynamic_metadata): + registered_count += 1 + + offlined_count = 0 + for entry in list(self._by_plugin.get(plugin_id, [])): + if not entry.dynamic or entry.registry_key in desired_registry_keys: + continue + entry.enabled = False + entry.offline_reason = normalized_offline_reason + entry.metadata["offline_reason"] = normalized_offline_reason + offlined_count += 1 + + return registered_count, offlined_count + + def _remove_entry(self, entry: APIEntry) -> None: + """从全部索引中移除单个 API 条目。""" + + self._apis.pop(entry.registry_key, None) + + full_name_entries = self._by_full_name.get(entry.full_name) + if full_name_entries is not None: + self._by_full_name[entry.full_name] = [ + candidate for candidate in full_name_entries if candidate is not entry + ] + if not self._by_full_name[entry.full_name]: + self._by_full_name.pop(entry.full_name, None) + + plugin_entries = self._by_plugin.get(entry.plugin_id) + if plugin_entries is not None: + self._by_plugin[entry.plugin_id] = [candidate for candidate in plugin_entries if candidate is not entry] + if not self._by_plugin[entry.plugin_id]: + self._by_plugin.pop(entry.plugin_id, None) + + name_entries = self._by_name.get(entry.name) + if name_entries is not None: + self._by_name[entry.name] = [candidate for candidate in name_entries if candidate is not entry] + if not self._by_name[entry.name]: + self._by_name.pop(entry.name, None) + + def remove_apis_by_plugin(self, plugin_id: str) -> int: + """移除某个插件的全部 API。""" + + entries = list(self._by_plugin.get(plugin_id, [])) + for entry in entries: + self._remove_entry(entry) + return len(entries) + + def get_api_by_full_name( + self, + full_name: str, + *, + version: str = "", + enabled_only: bool = True, + session_id: Optional[str] = None, + ) -> Optional[APIEntry]: + """按完整名查询单个 API。""" + + normalized_full_name, normalized_version = self._split_reference(full_name, version) + if not normalized_full_name: + return None + + if normalized_version: + entry = self._apis.get(f"{normalized_full_name}@{normalized_version}") + if entry is None: + return None + if enabled_only and not self.check_api_enabled(entry, session_id): + return None + return entry + + candidates = list(self._by_full_name.get(normalized_full_name, [])) + filtered_entries = [ + entry + for entry in candidates + if not enabled_only or self.check_api_enabled(entry, session_id) + ] + if len(filtered_entries) != 1: + return None + return filtered_entries[0] + + def get_api( + self, + plugin_id: str, + name: str, + *, + version: str = "", + enabled_only: bool = True, + session_id: Optional[str] = None, + ) -> Optional[APIEntry]: + """按插件 ID、短名与版本查询单个 API。""" + + return self.get_api_by_full_name( + f"{plugin_id}.{name}", + version=version, + enabled_only=enabled_only, + session_id=session_id, + ) + + def get_apis( + self, + *, + plugin_id: Optional[str] = None, + name: str = "", + version: str = "", + enabled_only: bool = True, + session_id: Optional[str] = None, + ) -> List[APIEntry]: + """查询 API 列表。""" + + normalized_name = str(name or "").strip() + normalized_version = self._normalize_query_version(version) + + if plugin_id: + candidates = list(self._by_plugin.get(plugin_id, [])) + elif normalized_name: + candidates = list(self._by_name.get(normalized_name, [])) + else: + candidates = list(self._apis.values()) + + filtered_entries: List[APIEntry] = [] + for entry in candidates: + if plugin_id and entry.plugin_id != plugin_id: + continue + if normalized_name and entry.name != normalized_name: + continue + if normalized_version and entry.version != normalized_version: + continue + if enabled_only and not self.check_api_enabled(entry, session_id): + continue + filtered_entries.append(entry) + + filtered_entries.sort(key=lambda entry: (entry.plugin_id, entry.name, entry.version)) + return filtered_entries + + def toggle_api_status( + self, + full_name: str, + enabled: bool, + *, + version: str = "", + session_id: Optional[str] = None, + ) -> bool: + """设置指定 API 的启用状态。""" + + entry = self.get_api_by_full_name( + full_name, + version=version, + enabled_only=False, + session_id=session_id, + ) + if entry is None: + return False + if session_id: + if enabled: + entry.disabled_session.discard(session_id) + else: + entry.disabled_session.add(session_id) + else: + entry.enabled = enabled + if enabled: + entry.offline_reason = "" + entry.metadata.pop("offline_reason", None) + return True diff --git a/src/plugin_runtime/host/authorization.py b/src/plugin_runtime/host/authorization.py new file mode 100644 index 00000000..70593768 --- /dev/null +++ b/src/plugin_runtime/host/authorization.py @@ -0,0 +1,67 @@ +"""授权管理器 + +负责管理插件的能力授权以及校验 +每个插件在 manifest 中声明能力需求,Host 启动时签发能力令牌。 +""" + +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Set, Tuple + +_ALWAYS_ALLOWED_CAPABILITIES = frozenset({"api.replace_dynamic"}) + + +@dataclass +class CapabilityPermissionToken: + """能力令牌""" + + plugin_id: str + capabilities: Set[str] = field(default_factory=set) + + +class AuthorizationManager: + """授权管理器 + + 管理所有插件的能力令牌,提供授权校验。 + """ + + def __init__(self) -> None: + self._permission_tokens: Dict[str, CapabilityPermissionToken] = {} + + def register_plugin(self, plugin_id: str, capabilities: List[str]) -> CapabilityPermissionToken: + """为插件签发能力令牌""" + token = CapabilityPermissionToken(plugin_id=plugin_id, capabilities=set(capabilities)) + self._permission_tokens[plugin_id] = token + return token + + def revoke_permission_token(self, plugin_id: str): + """移除插件的能力令牌。""" + self._permission_tokens.pop(plugin_id, None) + + def clear(self) -> None: + """清空所有能力令牌。""" + self._permission_tokens.clear() + + def check_capability(self, plugin_id: str, capability: str) -> Tuple[bool, str]: + # sourcery skip: assign-if-exp, reintroduce-else, swap-if-else-branches, use-named-expression + """检查插件是否有权调用某项能力 + + Returns: + return (bool, str): (是否有此能力, 原因) + """ + if capability in _ALWAYS_ALLOWED_CAPABILITIES: + return True, "" + + token = self._permission_tokens.get(plugin_id) + if not token: + return False, f"插件 {plugin_id} 未注册能力令牌" + if capability not in token.capabilities: + return False, f"插件 {plugin_id} 未获授权能力: {capability}" + return True, "" + + def get_token(self, plugin_id: str) -> Optional[CapabilityPermissionToken]: + """获取插件的能力令牌""" + return self._permission_tokens.get(plugin_id) + + def list_plugins(self) -> List[str]: + """列出所有已注册的插件""" + return list(self._permission_tokens.keys()) diff --git a/src/plugin_runtime/host/capability_service.py b/src/plugin_runtime/host/capability_service.py index 6685ff60..0ff31fe1 100644 --- a/src/plugin_runtime/host/capability_service.py +++ b/src/plugin_runtime/host/capability_service.py @@ -4,21 +4,19 @@ Host 端实现的能力服务,处理来自插件的 cap.* 请求。 每个能力方法被注册到 RPC Server,接收 Runner 转发的请求并执行实际操作。 """ -from typing import Any, Awaitable, Callable, Dict, List +from typing import Any, Callable, Dict, List, Coroutine, TYPE_CHECKING from src.common.logger import get_logger -from src.plugin_runtime.host.policy_engine import PolicyEngine -from src.plugin_runtime.protocol.envelope import ( - CapabilityRequestPayload, - CapabilityResponsePayload, - Envelope, -) +from src.plugin_runtime.protocol.envelope import CapabilityRequestPayload, CapabilityResponsePayload, Envelope from src.plugin_runtime.protocol.errors import ErrorCode, RPCError +if TYPE_CHECKING: + from src.plugin_runtime.host.authorization import AuthorizationManager + logger = get_logger("plugin_runtime.host.capability_service") # 能力实现函数类型: (plugin_id, capability, args) -> result -CapabilityImpl = Callable[[str, str, Dict[str, Any]], Awaitable[Any]] +CapabilityImpl = Callable[[str, str, Dict[str, Any]], Coroutine[Any, Any, Any]] class CapabilityService: @@ -31,8 +29,13 @@ class CapabilityService: 4. 执行实际操作并返回结果 """ - def __init__(self, policy_engine: PolicyEngine) -> None: - self._policy = policy_engine + def __init__(self, authorization: "AuthorizationManager") -> None: + """初始化能力服务。 + + Args: + authorization: 能力授权管理器。 + """ + self._authorization = authorization # capability_name -> implementation self._implementations: Dict[str, CapabilityImpl] = {} @@ -56,46 +59,32 @@ class CapabilityService: try: req = CapabilityRequestPayload.model_validate(envelope.payload) - except Exception as e: - return envelope.make_error_response( - ErrorCode.E_BAD_PAYLOAD.value, - f"能力调用 payload 格式错误: {e}", - ) + except Exception as exc: + return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, f"能力调用 payload 非法: {exc}") capability = req.capability + args = req.args # 1. 权限校验 - allowed, reason = self._policy.check_capability(plugin_id, capability, envelope.generation) + allowed, reason = self._authorization.check_capability(plugin_id, capability) if not allowed: - error_code = ( - ErrorCode.E_GENERATION_MISMATCH if "generation 不匹配" in reason else ErrorCode.E_CAPABILITY_DENIED - ) - return envelope.make_error_response( - error_code.value, - reason, - ) + return envelope.make_error_response(ErrorCode.E_CAPABILITY_DENIED.value, reason) # 2. 查找实现 impl = self._implementations.get(capability) if impl is None: - return envelope.make_error_response( - ErrorCode.E_METHOD_NOT_ALLOWED.value, - f"未注册的能力: {capability}", - ) + return envelope.make_error_response(ErrorCode.E_METHOD_NOT_ALLOWED.value, f"未注册的能力: {capability}") # 3. 执行 try: - result = await impl(plugin_id, capability, req.args) + result = await impl(plugin_id, capability, args) resp_payload = CapabilityResponsePayload(success=True, result=result) return envelope.make_response(payload=resp_payload.model_dump()) except RPCError as e: return envelope.make_error_response(e.code.value, e.message, e.details) except Exception as e: logger.error(f"能力 {capability} 执行异常: {e}", exc_info=True) - return envelope.make_error_response( - ErrorCode.E_CAPABILITY_FAILED.value, - str(e), - ) + return envelope.make_error_response(ErrorCode.E_CAPABILITY_FAILED.value, str(e)) def list_capabilities(self) -> List[str]: """列出所有已注册的能力""" diff --git a/src/plugin_runtime/host/component_registry.py b/src/plugin_runtime/host/component_registry.py index 220a19c0..97fdca30 100644 --- a/src/plugin_runtime/host/component_registry.py +++ b/src/plugin_runtime/host/component_registry.py @@ -1,7 +1,7 @@ """Host-side ComponentRegistry 对齐旧系统 component_registry.py 的核心能力: -- 按类型注册组件(action / command / tool / event_handler / workflow_step) +- 按类型注册组件(action / command / tool / event_handler / workflow_handler / message_gateway) - 命名空间 (plugin_id.component_name) - 命令正则匹配 - 组件启用/禁用 @@ -9,8 +9,10 @@ - 注册统计 """ -from typing import Any, Dict, List, Optional +from enum import Enum +from typing import Any, Dict, List, Optional, Set, TypedDict, Tuple +import contextlib import re from src.common.logger import get_logger @@ -18,8 +20,28 @@ from src.common.logger import get_logger logger = get_logger("plugin_runtime.host.component_registry") -class RegisteredComponent: - """已注册的组件条目""" +class ComponentTypes(str, Enum): + ACTION = "ACTION" + COMMAND = "COMMAND" + TOOL = "TOOL" + EVENT_HANDLER = "EVENT_HANDLER" + HOOK_HANDLER = "HOOK_HANDLER" + MESSAGE_GATEWAY = "MESSAGE_GATEWAY" + + +class StatusDict(TypedDict): + total: int + action: int + command: int + tool: int + event_handler: int + hook_handler: int + message_gateway: int + plugins: int + + +class ComponentEntry: + """组件条目""" __slots__ = ( "name", @@ -28,31 +50,120 @@ class RegisteredComponent: "plugin_id", "metadata", "enabled", - "_compiled_pattern", + "compiled_pattern", + "disabled_session", ) - def __init__( - self, - name: str, - component_type: str, - plugin_id: str, - metadata: Dict[str, Any], - ) -> None: - self.name = name - self.full_name = f"{plugin_id}.{name}" - self.component_type = component_type - self.plugin_id = plugin_id - self.metadata = metadata - self.enabled = metadata.get("enabled", True) + def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None: + self.name: str = name + self.full_name: str = f"{plugin_id}.{name}" + self.component_type: ComponentTypes = ComponentTypes(component_type) + self.plugin_id: str = plugin_id + self.metadata: Dict[str, Any] = metadata + self.enabled: bool = metadata.get("enabled", True) + self.disabled_session: Set[str] = set() - # 预编译命令正则(仅 command 类型) - self._compiled_pattern: Optional[re.Pattern] = None - if component_type == "command": - if pattern := metadata.get("command_pattern", ""): - try: - self._compiled_pattern = re.compile(pattern) - except re.error as e: - logger.warning(f"命令 {self.full_name} 正则编译失败: {e}") + +class ActionEntry(ComponentEntry): + """Action 组件条目""" + + def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None: + super().__init__(name, component_type, plugin_id, metadata) + + +class CommandEntry(ComponentEntry): + """Command 组件条目""" + + def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None: + super().__init__(name, component_type, plugin_id, metadata) + self.aliases: List[str] = metadata.get("aliases", []) + self.compiled_pattern: Optional[re.Pattern] = None + if pattern := metadata.get("command_pattern", ""): + try: + self.compiled_pattern = re.compile(pattern) + except (re.error, TypeError) as e: + logger.warning(f"命令 {self.full_name} 正则编译失败: {e}") + + +class ToolEntry(ComponentEntry): + """Tool 组件条目""" + + def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None: + self.description: str = metadata.get("description", "") + self.parameters: List[Dict[str, Any]] = metadata.get("parameters", []) + self.parameters_raw: List[Dict[str, Any]] = metadata.get("parameters_raw", []) + super().__init__(name, component_type, plugin_id, metadata) + + +class EventHandlerEntry(ComponentEntry): + """EventHandler 组件条目""" + + def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None: + self.event_type: str = metadata.get("event_type", "") + self.weight: int = metadata.get("weight", 0) + self.intercept_message: bool = metadata.get("intercept_message", False) + super().__init__(name, component_type, plugin_id, metadata) + + +class HookHandlerEntry(ComponentEntry): + """WorkflowHandler 组件条目""" + + def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None: + self.stage: str = metadata.get("stage", "") + self.priority: int = metadata.get("priority", 0) + self.blocking: bool = metadata.get("blocking", False) + super().__init__(name, component_type, plugin_id, metadata) + + +class MessageGatewayEntry(ComponentEntry): + """MessageGateway 组件条目""" + + def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None: + self.route_type: str = self._normalize_route_type(metadata.get("route_type", "")) + self.platform: str = str(metadata.get("platform", "") or "").strip() + self.protocol: str = str(metadata.get("protocol", "") or "").strip() + self.account_id: str = str(metadata.get("account_id", "") or "").strip() + self.scope: str = str(metadata.get("scope", "") or "").strip() + super().__init__(name, component_type, plugin_id, metadata) + + @staticmethod + def _normalize_route_type(raw_value: Any) -> str: + """规范化消息网关路由类型。 + + Args: + raw_value: 原始路由类型值。 + + Returns: + str: 规范化后的路由类型。 + + Raises: + ValueError: 当路由类型不受支持时抛出。 + """ + + normalized_value = str(raw_value or "").strip().lower() + route_type_aliases = { + "send": "send", + "receive": "receive", + "recv": "receive", + "recive": "receive", + "duplex": "duplex", + } + route_type = route_type_aliases.get(normalized_value) + if route_type is None: + raise ValueError(f"MessageGateway 路由类型不合法: {raw_value}") + return route_type + + @property + def supports_send(self) -> bool: + """返回当前网关是否支持出站。""" + + return self.route_type in {"send", "duplex"} + + @property + def supports_receive(self) -> bool: + """返回当前网关是否支持入站。""" + + return self.route_type in {"receive", "duplex"} class ComponentRegistry: @@ -64,19 +175,32 @@ class ComponentRegistry: def __init__(self) -> None: # 全量索引 - self._components: Dict[str, RegisteredComponent] = {} # full_name -> comp + self._components: Dict[str, ComponentEntry] = {} # full_name -> comp # 按类型索引 - self._by_type: Dict[str, Dict[str, RegisteredComponent]] = { - "action": {}, - "command": {}, - "tool": {}, - "event_handler": {}, - "workflow_step": {}, - } + self._by_type: Dict[ComponentTypes, Dict[str, ComponentEntry]] = { + comp_type: {} for comp_type in ComponentTypes + } # component_type -> (full_name -> comp) # 按插件索引 - self._by_plugin: Dict[str, List[RegisteredComponent]] = {} + self._by_plugin: Dict[str, List[ComponentEntry]] = {} + + @staticmethod + def _normalize_component_type(component_type: str) -> ComponentTypes: + """规范化组件类型输入。 + + Args: + component_type: 原始组件类型字符串。 + + Returns: + ComponentTypes: 规范化后的组件类型枚举。 + + Raises: + ValueError: 当组件类型不受支持时抛出。 + """ + + normalized_value = str(component_type or "").strip().upper() + return ComponentTypes(normalized_value) def clear(self) -> None: """清空全部组件注册状态。""" @@ -85,47 +209,64 @@ class ComponentRegistry: type_dict.clear() self._by_plugin.clear() - # ──── 注册 / 注销 ───────────────────────────────────────── + # ====== 注册 / 注销 ====== + def register_component(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> bool: + """注册单个组件 + + Args: + name: 组件名称(不含插件id前缀) + component_type: 组件类型(如 `ACTION`、`COMMAND` 等) + plugin_id: 插件id + metadata: 组件元数据 + Returns: + success (bool): 是否成功注册(失败原因通常是组件类型无效) + """ + try: + normalized_type = self._normalize_component_type(component_type) + if normalized_type == ComponentTypes.ACTION: + comp = ActionEntry(name, normalized_type.value, plugin_id, metadata) + elif normalized_type == ComponentTypes.COMMAND: + comp = CommandEntry(name, normalized_type.value, plugin_id, metadata) + elif normalized_type == ComponentTypes.TOOL: + comp = ToolEntry(name, normalized_type.value, plugin_id, metadata) + elif normalized_type == ComponentTypes.EVENT_HANDLER: + comp = EventHandlerEntry(name, normalized_type.value, plugin_id, metadata) + elif normalized_type == ComponentTypes.HOOK_HANDLER: + comp = HookHandlerEntry(name, normalized_type.value, plugin_id, metadata) + elif normalized_type == ComponentTypes.MESSAGE_GATEWAY: + comp = MessageGatewayEntry(name, normalized_type.value, plugin_id, metadata) + else: + raise ValueError(f"组件类型 {component_type} 不存在") + except ValueError: + logger.error(f"组件类型 {component_type} 不存在") + return False - def register_component( - self, - name: str, - component_type: str, - plugin_id: str, - metadata: Dict[str, Any], - ) -> bool: - """注册单个组件。""" - comp = RegisteredComponent(name, component_type, plugin_id, metadata) if comp.full_name in self._components: logger.warning(f"组件 {comp.full_name} 已存在,覆盖") old_comp = self._components[comp.full_name] # 从 _by_plugin 列表中移除旧条目,防止幽灵组件堆积 old_list = self._by_plugin.get(old_comp.plugin_id) if old_list is not None: - try: + with contextlib.suppress(ValueError): old_list.remove(old_comp) - except ValueError: - pass # 从旧类型索引中移除,防止类型变更时幽灵残留 if old_type_dict := self._by_type.get(old_comp.component_type): old_type_dict.pop(comp.full_name, None) self._components[comp.full_name] = comp - - if component_type not in self._by_type: - self._by_type[component_type] = {} - self._by_type[component_type][comp.full_name] = comp - + self._by_type[comp.component_type][comp.full_name] = comp self._by_plugin.setdefault(plugin_id, []).append(comp) return True - def register_plugin_components( - self, - plugin_id: str, - components: List[Dict[str, Any]], - ) -> int: - """批量注册一个插件的所有组件,返回成功注册数。""" + def register_plugin_components(self, plugin_id: str, components: List[Dict[str, Any]]) -> int: + """批量注册一个插件的所有组件,返回成功注册数。 + Args: + plugin_id (str): 插件id + components (List[Dict[str, Any]]): 组件字典列表,每个组件包含 name, component_type, metadata 等字段 + Returns: + count (int): 成功注册的组件数量 + """ count = 0 for comp_data in components: ok = self.register_component( @@ -139,7 +280,13 @@ class ComponentRegistry: return count def remove_components_by_plugin(self, plugin_id: str) -> int: - """移除某个插件的所有组件,返回移除数量。""" + """移除某个插件的所有组件,返回移除数量。 + + Args: + plugin_id (str): 插件id + Returns: + count (int): 移除的组件数量 + """ comps = self._by_plugin.pop(plugin_id, []) for comp in comps: self._components.pop(comp.full_name, None) @@ -147,106 +294,280 @@ class ComponentRegistry: type_dict.pop(comp.full_name, None) return len(comps) - # ──── 启用 / 禁用 ───────────────────────────────────────── + # ====== 启用 / 禁用 ====== + def check_component_enabled(self, component: ComponentEntry, session_id: Optional[str] = None): + if session_id and session_id in component.disabled_session: + return False + return component.enabled - def set_component_enabled(self, full_name: str, enabled: bool) -> bool: - """启用或禁用指定组件。""" + def toggle_component_status(self, full_name: str, enabled: bool, session_id: Optional[str] = None) -> bool: + """启用或禁用指定组件。 + + Args: + full_name (str): 组件全名 + enabled (bool): 使能情况 + session_id (Optional[str]): 可选的会话ID,仅对该会话禁用(如果提供) + Returns: + success (bool): 是否成功设置(失败原因通常是组件不存在) + """ comp = self._components.get(full_name) if comp is None: return False - comp.enabled = enabled + if session_id: + if enabled: + comp.disabled_session.discard(session_id) + else: + comp.disabled_session.add(session_id) + else: + comp.enabled = enabled return True - def set_plugin_enabled(self, plugin_id: str, enabled: bool) -> int: - """批量启用或禁用某插件的所有组件。""" + def set_component_enabled(self, full_name: str, enabled: bool, session_id: Optional[str] = None) -> bool: + """设置指定组件的启用状态。 + + Args: + full_name: 组件全名。 + enabled: 目标启用状态。 + session_id: 可选的会话 ID,仅对该会话生效。 + + Returns: + bool: 是否设置成功。 + """ + + return self.toggle_component_status(full_name, enabled, session_id=session_id) + + def toggle_plugin_status(self, plugin_id: str, enabled: bool, session_id: Optional[str] = None) -> int: + """批量启用或禁用某插件的所有组件。 + + Args: + plugin_id (str): 插件id + enabled (bool): 使能情况 + session_id (Optional[str]): 可选的会话ID,仅对该会话禁用(如果提供) + Returns: + count (int): 成功设置的组件数量(失败原因通常是插件不存在) + """ comps = self._by_plugin.get(plugin_id, []) for comp in comps: - comp.enabled = enabled + if session_id: + if enabled: + comp.disabled_session.discard(session_id) + else: + comp.disabled_session.add(session_id) + else: + comp.enabled = enabled return len(comps) - # ──── 查询方法 ───────────────────────────────────────────── + def get_component(self, full_name: str) -> Optional[ComponentEntry]: + """按全名查询。 - def get_component(self, full_name: str) -> Optional[RegisteredComponent]: - """按全名查询。""" + Args: + full_name (str): 组件全名 + Returns: + component (Optional[ComponentEntry]): 组件条目,未找到时为 None + """ return self._components.get(full_name) - def get_components_by_type(self, component_type: str, *, enabled_only: bool = True) -> List[RegisteredComponent]: - """按类型查询。""" - type_dict = self._by_type.get(component_type, {}) + def get_components_by_type( + self, component_type: str, *, enabled_only: bool = True, session_id: Optional[str] = None + ) -> List[ComponentEntry]: + """按类型查询组件 + + Args: + component_type (str): 组件类型(如 `ACTION`、`COMMAND` 等) + enabled_only (bool): 是否仅返回启用的组件 + session_id (Optional[str]): 可选的会话ID,若提供则考虑会话禁用状态 + Returns: + components (List[ComponentEntry]): 组件条目列表 + """ + try: + comp_type = self._normalize_component_type(component_type) + except ValueError: + logger.error(f"组件类型 {component_type} 不存在") + raise + type_dict = self._by_type.get(comp_type, {}) if enabled_only: - return [c for c in type_dict.values() if c.enabled] + return [c for c in type_dict.values() if self.check_component_enabled(c, session_id)] return list(type_dict.values()) - def get_components_by_plugin(self, plugin_id: str, *, enabled_only: bool = True) -> List[RegisteredComponent]: - """按插件查询。""" - comps = self._by_plugin.get(plugin_id, []) - return [c for c in comps if c.enabled] if enabled_only else list(comps) + def get_components_by_plugin( + self, plugin_id: str, *, enabled_only: bool = True, session_id: Optional[str] = None + ) -> List[ComponentEntry]: + """按插件查询组件。 - def find_command_by_text(self, text: str) -> Optional[tuple[RegisteredComponent, Dict[str, Any]]]: + Args: + plugin_id (str): 插件ID + enabled_only (bool): 是否仅返回启用的组件 + session_id (Optional[str]): 可选的会话ID,若提供则考虑会话禁用状态 + Returns: + components (List[ComponentEntry]): 组件条目列表 + """ + comps = self._by_plugin.get(plugin_id, []) + return [c for c in comps if self.check_component_enabled(c, session_id)] if enabled_only else list(comps) + + def find_command_by_text( + self, text: str, session_id: Optional[str] = None + ) -> Optional[Tuple[ComponentEntry, Dict[str, Any]]]: """通过文本匹配命令正则,返回 (组件, matched_groups) 元组。 matched_groups 为正则命名捕获组 dict,别名匹配时为空 dict。 + Args: + text (str): 待匹配文本 + session_id (Optional[str]): 可选的会话ID,若提供则考虑会话禁用状态 + Returns: + result (Optional[tuple[ComponentEntry, Dict[str, Any]]]): 匹配到的组件及正则捕获组,未找到时为 None """ - for comp in self._by_type.get("command", {}).values(): - if not comp.enabled: + for comp in self._by_type.get(ComponentTypes.COMMAND, {}).values(): + if not self.check_component_enabled(comp, session_id): continue - if comp._compiled_pattern: - m = comp._compiled_pattern.search(text) - if m: + if not isinstance(comp, CommandEntry): + continue + if comp.compiled_pattern: + if m := comp.compiled_pattern.search(text): return comp, m.groupdict() # 别名匹配 - aliases = comp.metadata.get("aliases", []) - for alias in aliases: + for alias in comp.aliases: if text.startswith(alias): return comp, {} return None - def get_event_handlers(self, event_type: str, *, enabled_only: bool = True) -> List[RegisteredComponent]: - """获取特定事件类型的所有 event_handler,按 weight 降序排列。""" - handlers = [] - for comp in self._by_type.get("event_handler", {}).values(): - if enabled_only and not comp.enabled: + def get_event_handlers( + self, event_type: str, *, enabled_only: bool = True, session_id: Optional[str] = None + ) -> List[EventHandlerEntry]: + """查询指定事件类型的事件处理器组件。 + + Args: + event_type (str): 事件类型 + enabled_only (bool): 是否仅返回启用的组件 + session_id (Optional[str]): 可选的会话ID,若提供则考虑会话禁用状态 + Returns: + handlers (List[EventHandlerEntry]): 符合条件的 EventHandler 组件列表,按 weight 降序排序 + """ + handlers: List[EventHandlerEntry] = [] + for comp in self._by_type.get(ComponentTypes.EVENT_HANDLER, {}).values(): + if enabled_only and not self.check_component_enabled(comp, session_id): continue - if comp.metadata.get("event_type") == event_type: + if not isinstance(comp, EventHandlerEntry): + continue + if comp.event_type == event_type: handlers.append(comp) - handlers.sort(key=lambda c: c.metadata.get("weight", 0), reverse=True) + handlers.sort(key=lambda c: c.weight, reverse=True) return handlers - def get_workflow_steps(self, stage: str, *, enabled_only: bool = True) -> List[RegisteredComponent]: - """获取特定 workflow 阶段的所有步骤,按 priority 降序。""" - steps = [] - for comp in self._by_type.get("workflow_step", {}).values(): - if enabled_only and not comp.enabled: + def get_hook_handlers( + self, stage: str, *, enabled_only: bool = True, session_id: Optional[str] = None + ) -> List[HookHandlerEntry]: + """获取特定 hook 阶段的所有步骤,按 priority 降序。 + + Args: + stage: hook 名称 + enabled_only: 是否仅返回启用的组件 + session_id: 可选的会话ID,若提供则考虑会话禁用状态 + Returns: + handlers (List[HookHandlerEntry]): 符合条件的 HookHandler 组件列表,按 priority 降序排序 + """ + handlers: List[HookHandlerEntry] = [] + for comp in self._by_type.get(ComponentTypes.HOOK_HANDLER, {}).values(): + if enabled_only and not self.check_component_enabled(comp, session_id): continue - if comp.metadata.get("stage") == stage: - steps.append(comp) - steps.sort(key=lambda c: c.metadata.get("priority", 0), reverse=True) - return steps + if not isinstance(comp, HookHandlerEntry): + continue + if comp.stage == stage: + handlers.append(comp) + handlers.sort(key=lambda c: c.priority, reverse=True) + return handlers - def get_tools_for_llm(self, *, enabled_only: bool = True) -> List[Dict[str, Any]]: - """获取可供 LLM 使用的工具列表(openai function-calling 格式预览)。""" - result: List[Dict[str, Any]] = [] - for comp in self.get_components_by_type("tool", enabled_only=enabled_only): - tool_def: Dict[str, Any] = { - "name": comp.full_name, - "description": comp.metadata.get("description", ""), - } - # 从结构化参数或原始参数构建 parameters - params = comp.metadata.get("parameters", []) - params_raw = comp.metadata.get("parameters_raw", {}) - if params: - tool_def["parameters"] = params - elif params_raw: - tool_def["parameters"] = params_raw - result.append(tool_def) - return result + def get_message_gateway( + self, + plugin_id: str, + name: str, + *, + enabled_only: bool = True, + session_id: Optional[str] = None, + ) -> Optional[MessageGatewayEntry]: + """按插件和组件名获取单个消息网关。 - # ──── 统计 ───────────────────────────────────────────────── + Args: + plugin_id: 插件 ID。 + name: 网关组件名称。 + enabled_only: 是否仅返回启用的组件。 + session_id: 可选的会话 ID。 - def get_stats(self) -> Dict[str, int]: - """获取注册统计。""" - stats: Dict[str, int] = {"total": len(self._components)} + Returns: + Optional[MessageGatewayEntry]: 若存在则返回消息网关条目。 + """ + + component = self._components.get(f"{plugin_id}.{name}") + if not isinstance(component, MessageGatewayEntry): + return None + if enabled_only and not self.check_component_enabled(component, session_id): + return None + return component + + def get_message_gateways( + self, + *, + plugin_id: Optional[str] = None, + platform: str = "", + route_type: str = "", + enabled_only: bool = True, + session_id: Optional[str] = None, + ) -> List[MessageGatewayEntry]: + """查询消息网关组件列表。 + + Args: + plugin_id: 可选的插件 ID 过滤条件。 + platform: 可选的平台过滤条件。 + route_type: 可选的路由类型过滤条件。 + enabled_only: 是否仅返回启用的组件。 + session_id: 可选的会话 ID。 + + Returns: + List[MessageGatewayEntry]: 符合条件的消息网关组件列表。 + """ + + normalized_platform = str(platform or "").strip() + normalized_route_type = str(route_type or "").strip().lower() + gateways: List[MessageGatewayEntry] = [] + for comp in self._by_type.get(ComponentTypes.MESSAGE_GATEWAY, {}).values(): + if not isinstance(comp, MessageGatewayEntry): + continue + if plugin_id and comp.plugin_id != plugin_id: + continue + if enabled_only and not self.check_component_enabled(comp, session_id): + continue + if normalized_platform and comp.platform != normalized_platform: + continue + if normalized_route_type and comp.route_type != normalized_route_type: + continue + gateways.append(comp) + return gateways + + def get_tools(self, *, enabled_only: bool = True, session_id: Optional[str] = None) -> List[ToolEntry]: + """查询所有工具组件。 + + Args: + enabled_only (bool): 是否仅返回启用的组件 + session_id (Optional[str]): 可选的会话ID,若提供则考虑会话禁用状态 + Returns: + tools (List[ToolEntry]): 符合条件的 Tool 组件列表 + """ + tools: List[ToolEntry] = [] + for comp in self._by_type.get(ComponentTypes.TOOL, {}).values(): + if enabled_only and not self.check_component_enabled(comp, session_id): + continue + if isinstance(comp, ToolEntry): + tools.append(comp) + return tools + + # ====== 统计信息 ====== + def get_stats(self) -> StatusDict: + """获取注册统计。 + + Returns: + stats (StatusDict): 组件统计信息,包括总数、各类型数量、插件数量等 + """ + stats: StatusDict = {"total": len(self._components)} # type: ignore for comp_type, type_dict in self._by_type.items(): - stats[comp_type] = len(type_dict) + stats[comp_type.value.lower()] = len(type_dict) stats["plugins"] = len(self._by_plugin) return stats diff --git a/src/plugin_runtime/host/event_dispatcher.py b/src/plugin_runtime/host/event_dispatcher.py index 720e93d7..d252b6ee 100644 --- a/src/plugin_runtime/host/event_dispatcher.py +++ b/src/plugin_runtime/host/event_dispatcher.py @@ -4,40 +4,40 @@ 1. 按事件类型查询已注册的 event_handler(通过 ComponentRegistry) 2. 按 weight 排序,依次通过 RPC 调用 Runner 中的处理器 3. 支持阻塞(intercept_message)和非阻塞分发 -4. 事件结果历史记录 +4. 事件结果历史记录(有上限) """ -from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple +from dataclasses import dataclass, field +from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING import asyncio from src.common.logger import get_logger -from src.plugin_runtime.host.component_registry import ComponentRegistry, RegisteredComponent + +from .message_utils import PluginMessageUtils, MessageDict + +if TYPE_CHECKING: + from .supervisor import PluginRunnerSupervisor + from .component_registry import ComponentRegistry, EventHandlerEntry + from src.chat.message_receive.message import SessionMessage logger = get_logger("plugin_runtime.host.event_dispatcher") # invoke_fn 类型: async (plugin_id, component_name, args) -> response_payload dict InvokeFn = Callable[[str, str, Dict[str, Any]], Awaitable[Dict[str, Any]]] +# 每个事件类型的最大历史记录数量,防止内存无限增长 +_MAX_HISTORY_LENGTH = 100 +@dataclass class EventResult: """单个 EventHandler 的执行结果""" - __slots__ = ("handler_name", "success", "continue_processing", "modified_message", "custom_result") - - def __init__( - self, - handler_name: str, - success: bool = True, - continue_processing: bool = True, - modified_message: Optional[Dict[str, Any]] = None, - custom_result: Any = None, - ): - self.handler_name = handler_name - self.success = success - self.continue_processing = continue_processing - self.modified_message = modified_message - self.custom_result = custom_result + handler_name: str + success: bool = field(default=True) + continue_processing: bool = field(default=True) + modified_message: Optional[MessageDict] = field(default=None) + custom_result: Any = field(default=None) class EventDispatcher: @@ -48,17 +48,20 @@ class EventDispatcher: 再通过提供的 invoke_fn 回调 RPC 到 Runner 执行。 """ - def __init__(self, registry: ComponentRegistry) -> None: - self._registry: ComponentRegistry = registry + def __init__(self, component_registry: "ComponentRegistry") -> None: + self._component_registry: "ComponentRegistry" = component_registry self._result_history: Dict[str, List[EventResult]] = {} self._history_enabled: Set[str] = set() - # 保持 fire-and-forget task 的强引用,防止被 GC 回收 self._background_tasks: Set[asyncio.Task] = set() def enable_history(self, event_type: str) -> None: self._history_enabled.add(event_type) self._result_history.setdefault(event_type, []) + def disable_history(self, event_type: str) -> None: + self._history_enabled.discard(event_type) + self._result_history.pop(event_type, None) + def get_history(self, event_type: str) -> List[EventResult]: return self._result_history.get(event_type, []) @@ -66,47 +69,58 @@ class EventDispatcher: if event_type in self._result_history: self._result_history[event_type] = [] + async def stop(self): + """停止 EventDispatcher,取消所有未完成的后台任务""" + for task in self._background_tasks: + task.cancel() + await asyncio.gather(*self._background_tasks, return_exceptions=True) + self._background_tasks.clear() + async def dispatch_event( self, event_type: str, - invoke_fn: InvokeFn, - message: Optional[Dict[str, Any]] = None, + supervisor: "PluginRunnerSupervisor", + message: Optional["SessionMessage"] = None, extra_args: Optional[Dict[str, Any]] = None, - ) -> Tuple[bool, Optional[Dict[str, Any]]]: - """分发事件到所有对应 handler。 + ) -> Tuple[bool, Optional["SessionMessage"]]: + """分发事件到所有对应 handler 的便捷方法。 + + 内置了通过 PluginSupervisor.invoke_plugin 调用 plugin.emit_event 的逻辑, + 无需调用方手动构造 invoke_fn 闭包。 Args: event_type: 事件类型字符串 - invoke_fn: 异步回调,签名 (plugin_id, component_name, args) -> response_payload dict + supervisor: PluginSupervisor 实例,用于调用 invoke_plugin message: MaiMessages 序列化后的 dict(可选) extra_args: 额外参数 Returns: - (should_continue, modified_message_dict) + (should_continue, modified_message_dict) (bool, SessionMessage | None): (是否继续后续执行, 可选的修改后的消息) """ - handlers = self._registry.get_event_handlers(event_type) - if not handlers: + handler_entries = self._component_registry.get_event_handlers(event_type) + if not handler_entries: return True, None should_continue = True - modified_message: Optional[Dict[str, Any]] = None - intercept_handlers: List[RegisteredComponent] = [] - async_handlers: List[RegisteredComponent] = [] + modified_message: Optional[MessageDict] = ( + PluginMessageUtils._session_message_to_dict(message) if message else None + ) + intercept_handlers: List["EventHandlerEntry"] = [] + non_blocking_handlers: List["EventHandlerEntry"] = [] - for handler in handlers: - if handler.metadata.get("intercept_message", False): - intercept_handlers.append(handler) + for entry in handler_entries: + if entry.intercept_message: + intercept_handlers.append(entry) else: - async_handlers.append(handler) + non_blocking_handlers.append(entry) - for handler in intercept_handlers: + for entry in intercept_handlers: args = { "event_type": event_type, - "message": modified_message or message, + "message": modified_message, **(extra_args or {}), } - - result = await self._invoke_handler(invoke_fn, handler, args, event_type) + result = await self._invoke_handler(supervisor, entry, args, event_type) if result and not result.continue_processing: should_continue = False break @@ -114,47 +128,57 @@ class EventDispatcher: modified_message = result.modified_message if should_continue: - final_message = modified_message or message - for handler in async_handlers: - async_message = final_message.copy() if isinstance(final_message, dict) else final_message + final_message = modified_message + for entry in non_blocking_handlers: + async_message = final_message.copy() if final_message else final_message args = { "event_type": event_type, "message": async_message, **(extra_args or {}), } # 非阻塞:保持实例级强引用,防止 task 被 GC 回收 - task = asyncio.create_task(self._invoke_handler(invoke_fn, handler, args, event_type)) + task = asyncio.create_task(self._invoke_handler(supervisor, entry, args, event_type)) self._background_tasks.add(task) task.add_done_callback(self._background_tasks.discard) - - return should_continue, modified_message + try: + modified_message_obj = ( + PluginMessageUtils._build_session_message_from_dict(modified_message) if modified_message else None # type: ignore + ) + except Exception as e: + logger.error(f"构建修改后的 SessionMessage 失败: {e}") + modified_message_obj = None + return should_continue, modified_message_obj async def _invoke_handler( self, - invoke_fn: InvokeFn, - handler: RegisteredComponent, + supervisor: "PluginRunnerSupervisor", + handler_entry: "EventHandlerEntry", args: Dict[str, Any], event_type: str, ) -> Optional[EventResult]: """调用单个 handler 并收集结果。""" try: - resp = await invoke_fn(handler.plugin_id, handler.name, args) + resp_envelope = await supervisor.invoke_plugin( + "plugin.emit_event", handler_entry.plugin_id, handler_entry.name, args + ) + resp = resp_envelope.payload result = EventResult( - handler_name=handler.full_name, + handler_name=handler_entry.full_name, success=resp.get("success", True), continue_processing=resp.get("continue_processing", True), modified_message=resp.get("modified_message"), custom_result=resp.get("custom_result"), ) except Exception as e: - logger.error(f"EventHandler {handler.full_name} 执行失败: {e}", exc_info=True) - result = EventResult( - handler_name=handler.full_name, - success=False, - continue_processing=True, - ) + logger.error(f"EventHandler {handler_entry.full_name} 执行失败: {e}", exc_info=True) + result = EventResult(handler_name=handler_entry.full_name, success=False, continue_processing=True) if event_type in self._history_enabled: - self._result_history.setdefault(event_type, []).append(result) + history_list = self._result_history.setdefault(event_type, []) + history_list.append(result) + # 自动清理超出限制的旧记录,防止内存无限增长 + if len(history_list) > _MAX_HISTORY_LENGTH: + # 保留最新的 _MAX_HISTORY_LENGTH 条记录 + self._result_history[event_type] = history_list[-_MAX_HISTORY_LENGTH:] return result diff --git a/src/plugin_runtime/host/hook_dispatcher.py b/src/plugin_runtime/host/hook_dispatcher.py new file mode 100644 index 00000000..d5e88448 --- /dev/null +++ b/src/plugin_runtime/host/hook_dispatcher.py @@ -0,0 +1,166 @@ +""" +Hook Dispatch 系统 + +插件可以注册自己的Hook,当特定函数被调用时,Hook Dispatch系统会将调用转发给插件的Hook处理函数。 +每个Hook的参数随Hook点位确定,因此参数是易变的。插件开发者需要根据Hook点位的定义来编写Hook处理函数。 +在参数/返回值匹配的情况下允许修改参数/返回值。 + +HookDispatcher 负责: +1. 按 stage 查询已注册的 hook_handler(通过 ComponentRegistry) +2. 按 priority 排序,区分 blocking 和非 blocking 模式 +3. blocking 模式:依次同步调用,支持修改参数/提前终止 +4. 非 blocking 模式:异步调用,不阻塞主流程 +5. 支持通过 global_config.plugin_runtime.hook_blocking_timeout_sec 设置超时上限 +""" + +import asyncio +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING + +from src.common.logger import get_logger +from src.config.config import global_config + + +if TYPE_CHECKING: + from .supervisor import PluginRunnerSupervisor + from .component_registry import ComponentRegistry, HookHandlerEntry + +logger = get_logger("plugin_runtime.host.hook_dispatcher") + + +@dataclass +class HookResult: + """单个 HookHandler 的执行结果""" + + handler_name: str + success: bool = field(default=True) + continue_processing: bool = field(default=True) + modified_kwargs: Optional[Dict[str, Any]] = field(default=None) + custom_result: Any = field(default=None) + + +class HookDispatcher: + """Host-side Hook 分发器 + + 由业务层调用 hook_dispatch(), + 内部通过 ComponentRegistry 查询 handler, + 再通过提供的 invoke_fn 回调 RPC 到 Runner 执行。 + """ + + def __init__(self, component_registry: "ComponentRegistry") -> None: + """初始化 HookDispatcher + + Args: + component_registry: ComponentRegistry 实例,用于查询已注册的 hook_handler + """ + self._component_registry: "ComponentRegistry" = component_registry + self._background_tasks: Set[asyncio.Task] = set() + + async def stop(self) -> None: + """停止 HookDispatcher,取消所有未完成的后台任务""" + for task in self._background_tasks: + task.cancel() + await asyncio.gather(*self._background_tasks, return_exceptions=True) + self._background_tasks.clear() + + async def hook_dispatch( + self, + stage: str, + supervisor: "PluginRunnerSupervisor", + **kwargs: Any, + ) -> Dict[str, Any]: + """分发 hook 到所有对应 handler 的便捷方法。 + + 内置了通过 PluginRunnerSupervisor.invoke_plugin 调用 plugin 的逻辑, + 无需调用方手动构造 invoke_fn 闭包。 + + Args: + stage: hook 名称 + supervisor: PluginRunnerSupervisor 实例,用于调用 invoke_plugin + **kwargs: 关键字参数,会展开传递给 handler + + Returns: + modified_kwargs (Dict[str, Any]): 经过所有 handler 修改后的关键字参数 + """ + handler_entries = self._component_registry.get_hook_handlers(stage) + if not handler_entries: + return kwargs + + current_kwargs = kwargs.copy() + blocking_handlers: List["HookHandlerEntry"] = [] + non_blocking_handlers: List["HookHandlerEntry"] = [] + + # 分离 blocking 和非 blocking handler + for entry in handler_entries: + if entry.blocking: + blocking_handlers.append(entry) + else: + non_blocking_handlers.append(entry) + + # 处理 blocking handlers(同步调用,支持修改参数/提前终止) + timeout = global_config.plugin_runtime.hook_blocking_timeout_sec or 30.0 + for entry in blocking_handlers: + hook_args = {"stage": stage, **current_kwargs} + try: + # 应用超时控制 + result = await asyncio.wait_for( + self._invoke_handler(supervisor, entry, hook_args), + timeout=timeout, + ) + except asyncio.TimeoutError: + logger.error(f"Blocking HookHandler {entry.full_name} 执行超时 (>{timeout}秒),跳过") + result = HookResult(handler_name=entry.full_name, success=False, continue_processing=True) + + if result: + if result.modified_kwargs is not None: + current_kwargs = result.modified_kwargs + if not result.continue_processing: + logger.info(f"HookHandler {entry.full_name} 终止了后续处理") + break + + # 处理 non-blocking handlers(异步调用,不阻塞主流程) + for entry in non_blocking_handlers: + async_kwargs = current_kwargs.copy() + hook_args = {"stage": stage, **async_kwargs} + task = asyncio.create_task( + asyncio.wait_for(self._invoke_handler(supervisor, entry, hook_args), timeout=timeout) + ) + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.discard) + + return current_kwargs + + async def _invoke_handler( + self, + supervisor: "PluginRunnerSupervisor", + handler_entry: "HookHandlerEntry", + args: Dict[str, Any], + ) -> Optional[HookResult]: + """调用单个 handler 并收集结果。 + + Args: + supervisor: PluginRunnerSupervisor 实例 + handler_entry: HookHandlerEntry 实例 + args: 传递给 handler 的参数字典 + stage: hook 名称 + + Returns: + Optional[HookResult]: 执行结果,如果执行失败则返回 None + """ + try: + resp_envelope = await supervisor.invoke_plugin( + "plugin.invoke_hook", handler_entry.plugin_id, handler_entry.name, args + ) + resp = resp_envelope.payload + result = HookResult( + handler_name=handler_entry.full_name, + success=resp.get("success", True), + continue_processing=resp.get("continue_processing", True), + modified_kwargs=resp.get("modified_kwargs"), + custom_result=resp.get("custom_result"), + ) + except Exception as e: + logger.error(f"HookHandler {handler_entry.full_name} 执行失败:{e}", exc_info=True) + result = HookResult(handler_name=handler_entry.full_name, success=False, continue_processing=True) + + return result diff --git a/src/plugin_runtime/host/logger_bridge.py b/src/plugin_runtime/host/logger_bridge.py new file mode 100644 index 00000000..f2213dfe --- /dev/null +++ b/src/plugin_runtime/host/logger_bridge.py @@ -0,0 +1,45 @@ +import logging as stdlib_logging +from src.plugin_runtime.protocol.errors import ErrorCode +from src.plugin_runtime.protocol.envelope import Envelope, LogBatchPayload +class RunnerLogBridge: + """将 Runner 进程上报的批量日志重放到主进程的 Logger 中。 + + Runner 通过 ``runner.log_batch`` IPC 事件批量到达。 + 每条 LogEntry 被重建为一个真实的 :class:`logging.LogRecord` 并直接 + 调用 ``logging.getLogger(entry.logger_name).handle(record)``, + 从而接入主进程已配置好的 structlog Handler 链。 + """ + + async def handle_log_batch(self, envelope: Envelope) -> Envelope: + """IPC 事件处理器:解析批量日志并重放到主进程 Logger。 + + Args: + envelope: 方法名为 ``runner.log_batch`` 的 IPC 事件信封。 + + Returns: + 空响应信封(事件模式下将被忽略)。 + """ + try: + batch = LogBatchPayload.model_validate(envelope.payload) + except Exception as exc: + return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc)) + + for entry in batch.entries: + # 重建一个与原始日志尽量相符的 LogRecord + record = stdlib_logging.LogRecord( + name=entry.logger_name, + level=entry.level, + pathname="", + lineno=0, + msg=entry.message, + args=(), + exc_info=None, + ) + record.created = entry.timestamp_ms / 1000.0 + record.msecs = entry.timestamp_ms % 1000 + if entry.exception_text: + record.exc_text = entry.exception_text + + stdlib_logging.getLogger(entry.logger_name).handle(record) + + return envelope.make_response(payload={"accepted": True, "count": len(batch.entries)}) \ No newline at end of file diff --git a/src/plugin_runtime/host/message_gateway.py b/src/plugin_runtime/host/message_gateway.py new file mode 100644 index 00000000..90f94493 --- /dev/null +++ b/src/plugin_runtime/host/message_gateway.py @@ -0,0 +1,112 @@ +"""Host 侧消息网关包装器。""" + +from typing import TYPE_CHECKING, Any, Dict + +from src.common.logger import get_logger +from src.platform_io import get_platform_io_manager + +from .message_utils import PluginMessageUtils + +if TYPE_CHECKING: + from src.chat.message_receive.message import SessionMessage + from .component_registry import ComponentRegistry + from .supervisor import PluginRunnerSupervisor + +logger = get_logger("plugin_runtime.host.message_gateway") + + +class MessageGateway: + """Host 侧消息网关包装器。""" + + def __init__(self, component_registry: "ComponentRegistry") -> None: + """初始化消息网关。 + + Args: + component_registry: 组件注册表。 + """ + self._component_registry = component_registry + + def build_session_message(self, external_message: Dict[str, Any]) -> "SessionMessage": + """将标准消息字典转换为 ``SessionMessage``。 + + Args: + external_message: 外部消息的字典格式数据。 + + Returns: + SessionMessage: 转换后的内部消息对象。 + + Raises: + ValueError: 消息字典不合法时抛出。 + """ + return PluginMessageUtils._build_session_message_from_dict(external_message) + + def build_message_dict(self, internal_message: "SessionMessage") -> Dict[str, Any]: + """将 ``SessionMessage`` 转换为标准消息字典。 + + Args: + internal_message: 内部消息对象。 + + Returns: + Dict[str, Any]: 供消息网关插件消费的标准消息字典。 + """ + return dict(PluginMessageUtils._session_message_to_dict(internal_message)) + + async def receive_external_message(self, external_message: Dict[str, Any]) -> None: + """接收外部消息并送入主消息链。 + + Args: + external_message: 外部消息的字典格式数据。 + """ + try: + session_message = self.build_session_message(external_message) + except Exception as e: + logger.error(f"转换外部消息失败: {e}") + return + + from src.chat.message_receive.bot import chat_bot + + await chat_bot.receive_message(session_message) + + async def send_message_to_external( + self, + internal_message: "SessionMessage", + supervisor: "PluginRunnerSupervisor", + *, + enabled_only: bool = True, + save_to_db: bool = True, + ) -> bool: + """将内部消息通过 Platform IO 发送到外部平台。 + + Args: + internal_message: 系统内部的 ``SessionMessage`` 对象。 + supervisor: 当前持有该消息网关的 Supervisor。 + enabled_only: 兼容旧签名的保留参数,当前未使用。 + save_to_db: 发送成功后是否写入数据库。 + + Returns: + bool: 是否发送成功。 + """ + del enabled_only + del supervisor + + platform_io_manager = get_platform_io_manager() + if not platform_io_manager.is_started: + logger.warning("Platform IO 尚未启动,无法通过适配器链路发送消息") + return False + + route_key = platform_io_manager.build_route_key_from_message(internal_message) + delivery_batch = await platform_io_manager.send_message(internal_message, route_key) + if not delivery_batch.has_success: + logger.warning("通过消息网关链路发送消息失败: 未命中任何成功回执") + return False + + first_successful_receipt = delivery_batch.sent_receipts[0] + internal_message.message_id = first_successful_receipt.external_message_id or internal_message.message_id + if save_to_db: + try: + from src.common.utils.utils_message import MessageUtils + + MessageUtils.store_message_to_db(internal_message) + except Exception as e: + logger.error(f"保存消息到数据库失败: {e}") + return True diff --git a/src/plugin_runtime/host/message_utils.py b/src/plugin_runtime/host/message_utils.py new file mode 100644 index 00000000..2f6aa01b --- /dev/null +++ b/src/plugin_runtime/host/message_utils.py @@ -0,0 +1,487 @@ +from datetime import datetime +from typing import Any, Dict, List, Optional, TypedDict + +import base64 +import hashlib + +from src.common.logger import get_logger +from src.chat.message_receive.message import SessionMessage +from src.common.data_models.mai_message_data_model import UserInfo, GroupInfo, MessageInfo +from src.common.data_models.message_component_data_model import ( + AtComponent, + DictComponent, + EmojiComponent, + ForwardComponent, + ForwardNodeComponent, + ImageComponent, + MessageSequence, + ReplyComponent, + StandardMessageComponents, + TextComponent, + VoiceComponent, +) + +logger = get_logger("plugin_runtime.host.message_utils") + + +class UserInfoDict(TypedDict, total=False): + user_id: str + user_nickname: str + user_cardname: Optional[str] + + +class GroupInfoDict(TypedDict, total=False): + group_id: str + group_name: str + + +class MessageInfoDict(TypedDict, total=False): + user_info: UserInfoDict + group_info: Optional[GroupInfoDict] + additional_config: Dict[str, Any] + + +class MessageDict(TypedDict, total=False): + message_id: str + timestamp: str + platform: str + message_info: MessageInfoDict + raw_message: List[Dict[str, Any]] + is_mentioned: bool + is_at: bool + is_emoji: bool + is_picture: bool + is_command: bool + is_notify: bool + session_id: str + reply_to: Optional[str] + processed_plain_text: Optional[str] + display_message: Optional[str] + + +class PluginMessageUtils: + @staticmethod + def _message_sequence_to_dict(message_sequence: MessageSequence) -> List[Dict[str, Any]]: + """将消息组件序列转换为插件运行时使用的字典结构。 + + Args: + message_sequence: 待转换的消息组件序列。 + + Returns: + List[Dict[str, Any]]: 供插件运行时协议使用的消息段字典列表。 + """ + return [PluginMessageUtils._component_to_dict(component) for component in message_sequence.components] + + @staticmethod + def _component_to_dict(component: StandardMessageComponents) -> Dict[str, Any]: + """将单个消息组件转换为插件运行时字典结构。 + + Args: + component: 待转换的消息组件。 + + Returns: + Dict[str, Any]: 序列化后的消息组件字典。 + """ + if isinstance(component, TextComponent): + return {"type": "text", "data": component.text} + + if isinstance(component, ImageComponent): + serialized = { + "type": "image", + "data": component.content, + "hash": component.binary_hash, + } + if component.binary_data: + serialized["binary_data_base64"] = base64.b64encode(component.binary_data).decode("utf-8") + return serialized + + if isinstance(component, EmojiComponent): + serialized = { + "type": "emoji", + "data": component.content, + "hash": component.binary_hash, + } + if component.binary_data: + serialized["binary_data_base64"] = base64.b64encode(component.binary_data).decode("utf-8") + return serialized + + if isinstance(component, VoiceComponent): + serialized = { + "type": "voice", + "data": component.content, + "hash": component.binary_hash, + } + if component.binary_data: + serialized["binary_data_base64"] = base64.b64encode(component.binary_data).decode("utf-8") + return serialized + + if isinstance(component, AtComponent): + return { + "type": "at", + "data": { + "target_user_id": component.target_user_id, + "target_user_nickname": component.target_user_nickname, + "target_user_cardname": component.target_user_cardname, + }, + } + + if isinstance(component, ReplyComponent): + return { + "type": "reply", + "data": { + "target_message_id": component.target_message_id, + "target_message_content": component.target_message_content, + "target_message_sender_id": component.target_message_sender_id, + "target_message_sender_nickname": component.target_message_sender_nickname, + "target_message_sender_cardname": component.target_message_sender_cardname, + }, + } + + if isinstance(component, ForwardNodeComponent): + return { + "type": "forward", + "data": [PluginMessageUtils._forward_component_to_dict(item) for item in component.forward_components], + } + + return {"type": "dict", "data": component.data} + + @staticmethod + def _forward_component_to_dict(component: ForwardComponent) -> Dict[str, Any]: + """将单个转发节点组件转换为字典结构。 + + Args: + component: 待转换的转发节点组件。 + + Returns: + Dict[str, Any]: 序列化后的转发节点字典。 + """ + return { + "user_id": component.user_id, + "user_nickname": component.user_nickname, + "user_cardname": component.user_cardname, + "message_id": component.message_id, + "content": [PluginMessageUtils._component_to_dict(item) for item in component.content], + } + + @staticmethod + def _message_sequence_from_dict(raw_message_data: List[Dict[str, Any]]) -> MessageSequence: + """从插件运行时字典结构恢复消息组件序列。 + + Args: + raw_message_data: 插件运行时消息段字典列表。 + + Returns: + MessageSequence: 恢复后的消息组件序列。 + """ + components = [PluginMessageUtils._component_from_dict(item) for item in raw_message_data] + return MessageSequence(components=components) + + @staticmethod + def _component_from_dict(item: Dict[str, Any]) -> StandardMessageComponents: + """从插件运行时字典结构恢复单个消息组件。 + + Args: + item: 单个消息组件的字典表示。 + + Returns: + StandardMessageComponents: 恢复后的内部消息组件对象。 + """ + item_type = str(item.get("type") or "").strip() + if item_type == "text": + return TextComponent(text=str(item.get("data") or "")) + + if item_type == "image": + return PluginMessageUtils._build_binary_component(ImageComponent, item) + + if item_type == "emoji": + return PluginMessageUtils._build_binary_component(EmojiComponent, item) + + if item_type == "voice": + return PluginMessageUtils._build_binary_component(VoiceComponent, item) + + if item_type == "at": + item_data = item.get("data", {}) + if not isinstance(item_data, dict): + item_data = {} + return AtComponent( + target_user_id=str(item_data.get("target_user_id") or ""), + target_user_nickname=PluginMessageUtils._normalize_optional_string(item_data.get("target_user_nickname")), + target_user_cardname=PluginMessageUtils._normalize_optional_string(item_data.get("target_user_cardname")), + ) + + if item_type == "reply": + reply_data = item.get("data") + if isinstance(reply_data, dict): + return ReplyComponent( + target_message_id=str(reply_data.get("target_message_id") or ""), + target_message_content=PluginMessageUtils._normalize_optional_string( + reply_data.get("target_message_content") + ), + target_message_sender_id=PluginMessageUtils._normalize_optional_string( + reply_data.get("target_message_sender_id") + ), + target_message_sender_nickname=PluginMessageUtils._normalize_optional_string( + reply_data.get("target_message_sender_nickname") + ), + target_message_sender_cardname=PluginMessageUtils._normalize_optional_string( + reply_data.get("target_message_sender_cardname") + ), + ) + return ReplyComponent(target_message_id=str(reply_data or "")) + + if item_type == "forward": + forward_nodes: List[ForwardComponent] = [] + raw_forward_nodes = item.get("data", []) + if isinstance(raw_forward_nodes, list): + for node in raw_forward_nodes: + if not isinstance(node, dict): + continue + raw_content = node.get("content", []) + node_components: List[StandardMessageComponents] = [] + if isinstance(raw_content, list): + node_components = [ + PluginMessageUtils._component_from_dict(content) + for content in raw_content + if isinstance(content, dict) + ] + if not node_components: + node_components = [TextComponent(text="[empty forward node]")] + forward_nodes.append( + ForwardComponent( + user_nickname=str(node.get("user_nickname") or "未知用户"), + user_id=PluginMessageUtils._normalize_optional_string(node.get("user_id")), + user_cardname=PluginMessageUtils._normalize_optional_string(node.get("user_cardname")), + message_id=str(node.get("message_id") or ""), + content=node_components, + ) + ) + if not forward_nodes: + return DictComponent(data={"type": "forward", "data": item.get("data", [])}) + return ForwardNodeComponent(forward_components=forward_nodes) + + component_data = item.get("data") + if isinstance(component_data, dict): + return DictComponent(data=component_data) + return DictComponent(data=item) + + @staticmethod + def _build_binary_component(component_cls: Any, item: Dict[str, Any]) -> StandardMessageComponents: + """从字典构造带二进制负载的消息组件。 + + Args: + component_cls: 目标组件类型。 + item: 消息组件字典。 + + Returns: + StandardMessageComponents: 构造后的组件对象。 + """ + content = str(item.get("data") or "") + binary_hash = str(item.get("hash") or "") + raw_binary_base64 = item.get("binary_data_base64") + binary_data = b"" + if isinstance(raw_binary_base64, str) and raw_binary_base64: + try: + binary_data = base64.b64decode(raw_binary_base64) + except Exception: + binary_data = b"" + + if not binary_hash and binary_data: + binary_hash = hashlib.sha256(binary_data).hexdigest() + + return component_cls(binary_hash=binary_hash, content=content, binary_data=binary_data) + + @staticmethod + def _normalize_optional_string(value: Any) -> Optional[str]: + """将任意值规范化为可选字符串。 + + Args: + value: 待规范化的值。 + + Returns: + Optional[str]: 规范化后的字符串;若值为空则返回 ``None``。 + """ + if value is None: + return None + normalized_value = str(value) + return normalized_value if normalized_value else None + + @staticmethod + def _message_info_to_dict(message_info: MessageInfo) -> MessageInfoDict: + """ + 将 MessageInfo 对象转换为字典格式 + + Args: + message_info: MessageInfo 对象 + + Returns: + 字典格式的消息信息 + """ + user_info_dict = UserInfoDict( + user_id=message_info.user_info.user_id, + user_nickname=message_info.user_info.user_nickname, + user_cardname=message_info.user_info.user_cardname, + ) + + group_info_dict: Optional[GroupInfoDict] = None + if message_info.group_info: + group_info_dict = GroupInfoDict( + group_id=message_info.group_info.group_id, + group_name=message_info.group_info.group_name, + ) + + return MessageInfoDict( + user_info=user_info_dict, + group_info=group_info_dict, + additional_config=message_info.additional_config, + ) + + @staticmethod + def _session_message_to_dict(session_message: SessionMessage) -> MessageDict: + """ + 将 SessionMessage 对象转换为字典格式(复用 MessageSequence.to_dict 方法) + + Args: + session_message: SessionMessage 对象 + + Returns: + 字典格式的消息 + """ + # 转换基本信息 + message_dict = MessageDict( + message_id=session_message.message_id, + timestamp=str(session_message.timestamp.timestamp()), # 转换为时间戳字符串 + platform=session_message.platform, + message_info=PluginMessageUtils._message_info_to_dict(session_message.message_info), + raw_message=PluginMessageUtils._message_sequence_to_dict(session_message.raw_message), + is_mentioned=session_message.is_mentioned, + is_at=session_message.is_at, + is_emoji=session_message.is_emoji, + is_picture=session_message.is_picture, + is_command=session_message.is_command, + is_notify=session_message.is_notify, + session_id=session_message.session_id, + ) + + # 添加可选字段 + if session_message.reply_to is not None: + message_dict["reply_to"] = session_message.reply_to + if session_message.processed_plain_text is not None: + message_dict["processed_plain_text"] = session_message.processed_plain_text + if session_message.display_message is not None: + message_dict["display_message"] = session_message.display_message + + return message_dict + + @staticmethod + def _build_message_info_from_dict(message_info_dict: Dict[str, Any]) -> MessageInfo: + """ + 从字典构建 MessageInfo 对象 + + Args: + message_info_dict: 包含消息信息的字典 + + Returns: + MessageInfo 对象 + """ + # 构建用户信息 + user_info_dict = message_info_dict.get("user_info") + if not user_info_dict or not isinstance(user_info_dict, dict): + raise ValueError("消息字典中 'user_info' 字段无效") + user_id = user_info_dict.get("user_id") + user_nickname = user_info_dict.get("user_nickname") + user_cardname = user_info_dict.get("user_cardname") + if not isinstance(user_id, str) or not isinstance(user_nickname, str) or not user_id or not user_nickname: + raise ValueError("消息字典中 'user_info' 字段缺少有效的 'user_id' 或 'user_nickname'") + user_cardname = str(user_cardname) if user_cardname is not None else None + user_info = UserInfo(user_id=user_id, user_nickname=user_nickname, user_cardname=user_cardname) + + # 构建群信息 + if group_info_dict := message_info_dict.get("group_info"): + group_id = group_info_dict.get("group_id") + group_name = group_info_dict.get("group_name") + if not isinstance(group_id, str) or not isinstance(group_name, str) or not group_id or not group_name: + raise ValueError("消息字典中 'group_info' 字段缺少有效的 'group_id' 或 'group_name'") + group_info = GroupInfo(group_id=group_id, group_name=group_name) + else: + group_info = None + + # 获取额外配置 + additional_config: Dict[str, Any] = message_info_dict.get("additional_config", {}) + + return MessageInfo(user_info=user_info, group_info=group_info, additional_config=additional_config) + + @staticmethod + def _build_session_message_from_dict(message_dict: Dict[str, Any]) -> SessionMessage: + """ + 从字典构建 SessionMessage 对象(递归处理消息组件) + + Args: + message_dict: 包含消息完整信息的字典 + + Returns: + SessionMessage 对象 + """ + # 提取基本信息 + message_id = message_dict["message_id"] + timestamp_str: str = message_dict.get("timestamp", "") + platform = message_dict["platform"] + if not isinstance(message_id, str) or not message_id: + raise ValueError("消息字典中缺少有效的 'message_id' 字段") + if not isinstance(platform, str) or not platform: + raise ValueError("消息字典中缺少有效的 'platform' 字段") + + # 解析时间戳 + try: + timestamp_float = float(timestamp_str) + timestamp = datetime.fromtimestamp(timestamp_float) + except (ValueError, TypeError): + timestamp = datetime.now() # 如果解析失败,使用当前时间 + + # 创建 SessionMessage 实例 + session_message = SessionMessage(message_id=message_id, timestamp=timestamp, platform=platform) + + # 构建消息信息 + session_message.message_info = PluginMessageUtils._build_message_info_from_dict(message_dict["message_info"]) + + # 构建原始消息组件序列(复用 MessageSequence.from_dict 方法) + raw_message_data = message_dict["raw_message"] + if isinstance(raw_message_data, list): + session_message.raw_message = PluginMessageUtils._message_sequence_from_dict(raw_message_data) + else: + raise ValueError("消息字典中 'raw_message' 字段必须是一个列表") + + # 设置其他可选属性 + session_message.is_mentioned = message_dict.get("is_mentioned", False) + if not isinstance(session_message.is_mentioned, bool): + session_message.is_mentioned = False + session_message.is_at = message_dict.get("is_at", False) + if not isinstance(session_message.is_at, bool): + session_message.is_at = False + session_message.is_emoji = message_dict.get("is_emoji", False) + if not isinstance(session_message.is_emoji, bool): + session_message.is_emoji = False + session_message.is_picture = message_dict.get("is_picture", False) + if not isinstance(session_message.is_picture, bool): + session_message.is_picture = False + session_message.is_command = message_dict.get("is_command", False) + if not isinstance(session_message.is_command, bool): + session_message.is_command = False + session_message.is_notify = message_dict.get("is_notify", False) + if not isinstance(session_message.is_notify, bool): + session_message.is_notify = False + session_message.session_id = message_dict.get("session_id", "") + if not isinstance(session_message.session_id, str): + session_message.session_id = "" + session_message.reply_to = message_dict.get("reply_to") + if session_message.reply_to is not None and not isinstance(session_message.reply_to, str): + session_message.reply_to = None + session_message.processed_plain_text = message_dict.get("processed_plain_text") + if session_message.processed_plain_text is not None and not isinstance( + session_message.processed_plain_text, str + ): + session_message.processed_plain_text = None + session_message.display_message = message_dict.get("display_message") + if session_message.display_message is not None and not isinstance(session_message.display_message, str): + session_message.display_message = None + + return session_message diff --git a/src/plugin_runtime/host/policy_engine.py b/src/plugin_runtime/host/policy_engine.py deleted file mode 100644 index 61b32480..00000000 --- a/src/plugin_runtime/host/policy_engine.py +++ /dev/null @@ -1,97 +0,0 @@ -"""策略引擎 - -负责能力授权校验。 -每个插件在 manifest 中声明能力需求,Host 启动时签发能力令牌。 -""" - -from dataclasses import dataclass, field -from typing import Dict, List, Optional, Set, Tuple - - -@dataclass -class CapabilityToken: - """能力令牌""" - - plugin_id: str - generation: int - capabilities: Set[str] = field(default_factory=set) - - -class PolicyEngine: - """策略引擎 - - 管理所有插件的能力令牌,提供授权校验。 - """ - - def __init__(self) -> None: - self._tokens: Dict[str, Dict[int, CapabilityToken]] = {} - - def register_plugin( - self, - plugin_id: str, - generation: int, - capabilities: List[str], - ) -> CapabilityToken: - """为插件签发能力令牌""" - token = CapabilityToken( - plugin_id=plugin_id, - generation=generation, - capabilities=set(capabilities), - ) - self._tokens.setdefault(plugin_id, {})[generation] = token - return token - - def revoke_plugin(self, plugin_id: str, generation: Optional[int] = None) -> None: - """撤销插件的能力令牌。""" - if generation is None: - self._tokens.pop(plugin_id, None) - return - - generations = self._tokens.get(plugin_id) - if generations is None: - return - - generations.pop(generation, None) - if not generations: - self._tokens.pop(plugin_id, None) - - def clear(self) -> None: - """清空所有能力令牌。""" - self._tokens.clear() - - def check_capability(self, plugin_id: str, capability: str, generation: Optional[int] = None) -> Tuple[bool, str]: - """检查插件是否有权调用某项能力 - - Returns: - (allowed, reason) - """ - generations = self._tokens.get(plugin_id) - if not generations: - return False, f"插件 {plugin_id} 未注册能力令牌" - - if generation is None: - token = generations[max(generations)] - else: - token = generations.get(generation) - if token is None: - active_generation = max(generations) - return False, f"插件 {plugin_id} generation 不匹配: {generation} != {active_generation}" - - if capability not in token.capabilities: - return False, f"插件 {plugin_id} 未获授权能力: {capability}" - - if generation is not None and token.generation != generation: - return False, f"插件 {plugin_id} generation 不匹配: {generation} != {token.generation}" - - return True, "" - - def get_token(self, plugin_id: str) -> Optional[CapabilityToken]: - """获取插件的能力令牌""" - generations = self._tokens.get(plugin_id) - if not generations: - return None - return generations[max(generations)] - - def list_plugins(self) -> List[str]: - """列出所有已注册的插件""" - return list(self._tokens.keys()) diff --git a/src/plugin_runtime/host/rpc_server.py b/src/plugin_runtime/host/rpc_server.py index 79fe0d9a..eb6768c2 100644 --- a/src/plugin_runtime/host/rpc_server.py +++ b/src/plugin_runtime/host/rpc_server.py @@ -7,7 +7,7 @@ 4. 请求-响应关联与超时管理 """ -from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple, Coroutine import asyncio import contextlib @@ -32,7 +32,7 @@ from src.plugin_runtime.transport.base import Connection, TransportServer logger = get_logger("plugin_runtime.host.rpc_server") # RPC 方法处理器类型 -MethodHandler = Callable[[Envelope], Awaitable[Envelope]] +MethodHandler = Callable[[Envelope], Coroutine[Any, Any, Envelope]] class RPCServer: @@ -55,108 +55,39 @@ class RPCServer: self._id_gen = RequestIdGenerator() self._connection: Optional[Connection] = None # 当前活跃的 Runner 连接 - self._runner_id: Optional[str] = None - self._runner_generation: int = 0 - self._staged_connection: Optional[Connection] = None - self._staged_runner_id: Optional[str] = None - self._staged_runner_generation: int = 0 - self._staging_takeover: bool = False # 方法处理器注册表 self._method_handlers: Dict[str, MethodHandler] = {} - # 等待响应的 pending 请求: request_id -> (Future, target_generation) - self._pending_requests: Dict[int, Tuple[asyncio.Future, int]] = {} + # 等待响应的 pending 请求: request_id -> Future + self._pending_requests: Dict[int, asyncio.Future[Envelope]] = {} # 发送队列(背压控制) self._send_queue: Optional[asyncio.Queue[Tuple[Connection, bytes, asyncio.Future[None]]]] = None - self._send_worker_task: Optional[asyncio.Task] = None + self._send_worker_task: Optional[asyncio.Task[None]] = None # 运行状态 self._running: bool = False - self._tasks: List[asyncio.Task] = [] + self._tasks: List[asyncio.Task[None]] = [] + self._last_handshake_rejection_reason: str = "" + self._connection_lock: asyncio.Lock = asyncio.Lock() @property def session_token(self) -> str: return self._session_token - def reset_session_token(self) -> str: - """重新生成会话令牌(热重载时调用,防止旧 Runner 重连)""" - self._session_token = secrets.token_hex(32) - return self._session_token - - def restore_session_token(self, token: str) -> None: - """恢复指定的会话令牌(热重载回滚时调用)""" - self._session_token = token - - @property - def runner_generation(self) -> int: - return self._runner_generation - - @property - def staged_generation(self) -> int: - return self._staged_runner_generation - @property def is_connected(self) -> bool: return self._connection is not None and not self._connection.is_closed - def has_generation(self, generation: int) -> bool: - return generation == self._runner_generation or ( - self._staged_connection is not None - and not self._staged_connection.is_closed - and generation == self._staged_runner_generation - ) + @property + def last_handshake_rejection_reason(self) -> str: + """返回最近一次握手被拒绝的原因。""" + return self._last_handshake_rejection_reason - def begin_staged_takeover(self) -> None: - """允许新 Runner 以 staged 方式接入,待 Supervisor 验证后再切换为活跃连接。""" - self._staging_takeover = True - - async def commit_staged_takeover(self) -> None: - """提交 staged Runner,原活跃连接在提交后被关闭。""" - if self._staged_connection is None or self._staged_connection.is_closed: - raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "没有可提交的新 Runner 连接") - - old_connection = self._connection - old_generation = self._runner_generation - - self._connection = self._staged_connection - self._runner_id = self._staged_runner_id - self._runner_generation = self._staged_runner_generation - - self._staged_connection = None - self._staged_runner_id = None - self._staged_runner_generation = 0 - self._staging_takeover = False - - if stale_count := self._fail_pending_requests( - ErrorCode.E_PLUGIN_CRASHED, - "Runner 连接已被新 generation 接管", - generation=old_generation, - ): - logger.info(f"已清理 {stale_count} 个旧 Runner 的 pending 请求") - - if old_connection and old_connection is not self._connection and not old_connection.is_closed: - await old_connection.close() - - async def rollback_staged_takeover(self) -> None: - """放弃 staged Runner,保留当前活跃连接。""" - staged_connection = self._staged_connection - staged_generation = self._staged_runner_generation - - self._staged_connection = None - self._staged_runner_id = None - self._staged_runner_generation = 0 - self._staging_takeover = False - - self._fail_pending_requests( - ErrorCode.E_PLUGIN_CRASHED, - "新 Runner 预热失败,已回滚", - generation=staged_generation, - ) - - if staged_connection and not staged_connection.is_closed: - await staged_connection.close() + def clear_handshake_state(self) -> None: + """清空最近一次握手拒绝状态。""" + self._last_handshake_rejection_reason = "" def register_method(self, method: str, handler: MethodHandler) -> None: """注册 RPC 方法处理器""" @@ -165,6 +96,7 @@ class RPCServer: async def start(self) -> None: """启动 RPC 服务器""" self._running = True + self.clear_handshake_state() self._send_queue = asyncio.Queue(maxsize=self._send_queue_size) self._send_worker_task = asyncio.create_task(self._send_loop()) await self._transport.start(self._handle_connection) @@ -173,14 +105,9 @@ class RPCServer: async def stop(self) -> None: """停止 RPC 服务器""" self._running = False - - # 取消所有 pending 请求 - for future, _generation in self._pending_requests.values(): - if not future.done(): - future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "服务器关闭")) - self._pending_requests.clear() - - self._fail_queued_sends(ErrorCode.E_TIMEOUT, "服务器关闭") + self.clear_handshake_state() + self._fail_pending_requests(ErrorCode.E_SHUTTING_DOWN, "服务器正在关闭") + self._fail_queued_sends(ErrorCode.E_SHUTTING_DOWN, "服务器正在关闭") if self._send_worker_task: self._send_worker_task.cancel() @@ -198,10 +125,6 @@ class RPCServer: await self._connection.close() self._connection = None - if self._staged_connection: - await self._staged_connection.close() - self._staged_connection = None - await self._transport.stop() logger.info("RPC Server 已停止") @@ -211,7 +134,6 @@ class RPCServer: plugin_id: str = "", payload: Optional[Dict[str, Any]] = None, timeout_ms: int = 30000, - target_generation: Optional[int] = None, ) -> Envelope: """向 Runner 发送 RPC 请求并等待响应 @@ -227,18 +149,14 @@ class RPCServer: Raises: RPCError: 调用失败 """ - generation = target_generation or self._runner_generation - conn = self._get_connection_for_generation(generation) - if conn is None or conn.is_closed: + if not self._connection or self._connection.is_closed: raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接") - - request_id = self._id_gen.next() + request_id = await self._id_gen.next() envelope = Envelope( request_id=request_id, message_type=MessageType.REQUEST, method=method, plugin_id=plugin_id, - generation=generation, timeout_ms=timeout_ms, payload=payload or {}, ) @@ -246,12 +164,12 @@ class RPCServer: # 注册 pending future loop = asyncio.get_running_loop() future: asyncio.Future[Envelope] = loop.create_future() - self._pending_requests[request_id] = (future, generation) + self._pending_requests[request_id] = future try: # 发送请求 data = self._codec.encode_envelope(envelope) - await self._enqueue_send(conn, data) + await self._enqueue_send(self._connection, data) # 等待响应 timeout_sec = timeout_ms / 1000.0 @@ -265,150 +183,136 @@ class RPCServer: raise raise RPCError(ErrorCode.E_UNKNOWN, str(e)) from e - async def send_event(self, method: str, plugin_id: str = "", payload: Optional[Dict[str, Any]] = None) -> None: - """向 Runner 发送单向事件(不等待响应)""" - conn = self._connection - if conn is None or conn.is_closed: - return + # ============ 内部方法 ============ + # ========= 发送循环 ========= + async def _send_loop(self) -> None: + """后台发送循环:串行消费发送队列,统一执行连接写入。""" + if self._send_queue is None: + raise RuntimeError("没有消息队列") - request_id = self._id_gen.next() - envelope = Envelope( - request_id=request_id, - message_type=MessageType.EVENT, - method=method, - plugin_id=plugin_id, - generation=self._runner_generation, - payload=payload or {}, - ) - data = self._codec.encode_envelope(envelope) - await self._enqueue_send(conn, data) + while True: + try: + conn, data, send_future = await self._send_queue.get() + except asyncio.CancelledError: + break - # ─── 内部方法 ────────────────────────────────────────────── + try: + if conn.is_closed: + raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接") + await conn.send_frame(data) + if not send_future.done(): + send_future.set_result(None) + except asyncio.CancelledError: + if not send_future.done(): + send_future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "服务器关闭")) + raise + except Exception as e: + send_error = RPCError.from_exception(e, {ConnectionError: ErrorCode.E_PLUGIN_CRASHED}) + if not send_future.done(): + send_future.set_exception(send_error) + finally: + self._send_queue.task_done() + # ====== 发送循环方法 ====== async def _handle_connection(self, conn: Connection) -> None: """处理新的 Runner 连接""" logger.info("收到 Runner 连接") - previous_connection = self._connection - previous_generation = self._runner_generation - - # 第一条消息必须是 runner.hello 握手 try: - role = await self._handle_handshake(conn) - if role is None: - await conn.close() - return + async with self._connection_lock: + self.clear_handshake_state() + success = await self._handle_handshake(conn) + if not success: + await conn.close() + return + logger.info("Runner staged 握手成功") + self._connection = conn except Exception as e: logger.error(f"握手失败: {e}") await conn.close() return - if role == "staged": - expected_generation = self._staged_runner_generation - logger.info( - f"Runner staged 握手成功: runner_id={self._staged_runner_id}, generation={self._staged_runner_generation}" - ) - else: - self._connection = conn - expected_generation = self._runner_generation - logger.info(f"Runner 握手成功: runner_id={self._runner_id}, generation={self._runner_generation}") - - if previous_connection and previous_connection is not conn and not previous_connection.is_closed: - logger.info("检测到新 Runner 已接管连接,关闭旧连接") - if stale_count := self._fail_pending_requests( - ErrorCode.E_PLUGIN_CRASHED, - "Runner 连接已被新 generation 接管", - generation=previous_generation, - ): - logger.info(f"已清理 {stale_count} 个旧 Runner 的 pending 请求") - await previous_connection.close() - # 启动消息接收循环 try: - await self._recv_loop(conn, expected_generation=expected_generation) + await self._recv_loop(conn) except Exception as e: logger.error(f"连接异常断开: {e}") finally: - if self._connection is conn: - self._connection = None - self._runner_id = None - self._fail_pending_requests( - ErrorCode.E_PLUGIN_CRASHED, - "Runner 连接已断开", - generation=expected_generation, - ) - elif self._staged_connection is conn: - self._staged_connection = None - self._staged_runner_id = None - self._staged_runner_generation = 0 - self._fail_pending_requests( - ErrorCode.E_PLUGIN_CRASHED, - "Staged Runner 连接已断开", - generation=expected_generation, - ) + should_fail_pending_requests = False + async with self._connection_lock: + if self._connection is conn: + self._connection = None + should_fail_pending_requests = True + if should_fail_pending_requests: + self._fail_pending_requests(ErrorCode.E_PLUGIN_CRASHED, "Runner 连接已断开") - async def _handle_handshake(self, conn: Connection) -> Optional[str]: + async def _handle_handshake(self, conn: Connection) -> bool: """处理 runner.hello 握手""" # 接收握手请求 data = await asyncio.wait_for(conn.recv_frame(), timeout=10.0) envelope = self._codec.decode_envelope(data) - if envelope.method != "runner.hello": logger.error(f"期望 runner.hello,收到 {envelope.method}") + self._last_handshake_rejection_reason = "首条消息必须为 runner.hello" error_resp = envelope.make_error_response( ErrorCode.E_PROTOCOL_MISMATCH.value, "首条消息必须为 runner.hello", ) await conn.send_frame(self._codec.encode_envelope(error_resp)) - return None + return False # 解析握手 payload hello = HelloPayload.model_validate(envelope.payload) - # 校验会话令牌 if hello.session_token != self._session_token: logger.error("会话令牌不匹配") - resp_payload = HelloResponsePayload( - accepted=False, - reason="会话令牌无效", - ) + self._last_handshake_rejection_reason = "会话令牌无效" + resp_payload = HelloResponsePayload(accepted=False, reason=self._last_handshake_rejection_reason) resp = envelope.make_response(payload=resp_payload.model_dump()) await conn.send_frame(self._codec.encode_envelope(resp)) - return None + return False + + # 若已有活跃连接,直接拒绝新的握手,避免后来的连接抢占当前通道。 + if self.is_connected: + logger.warning("拒绝新的 Runner 连接:已有活跃连接") + self._last_handshake_rejection_reason = "已有活跃 Runner 连接,拒绝新的握手" + resp_payload = HelloResponsePayload(accepted=False, reason=self._last_handshake_rejection_reason) + resp = envelope.make_response(payload=resp_payload.model_dump()) + await conn.send_frame(self._codec.encode_envelope(resp)) + return False # 校验 SDK 版本 if not self._check_sdk_version(hello.sdk_version): logger.error(f"SDK 版本不兼容: {hello.sdk_version}") + self._last_handshake_rejection_reason = ( + f"SDK 版本 {hello.sdk_version} 不在支持范围 [{MIN_SDK_VERSION}, {MAX_SDK_VERSION}]" + ) resp_payload = HelloResponsePayload( accepted=False, - reason=f"SDK 版本 {hello.sdk_version} 不在支持范围 [{MIN_SDK_VERSION}, {MAX_SDK_VERSION}]", + reason=self._last_handshake_rejection_reason, ) resp = envelope.make_response(payload=resp_payload.model_dump()) await conn.send_frame(self._codec.encode_envelope(resp)) - return None + return False - # 握手成功 - role = "active" - assigned_generation = self._runner_generation + 1 - if self._staging_takeover and self.is_connected: - role = "staged" - self._staged_connection = conn - self._staged_runner_id = hello.runner_id - self._staged_runner_generation = assigned_generation - else: - self._runner_id = hello.runner_id - self._runner_generation = assigned_generation - - resp_payload = HelloResponsePayload( - accepted=True, - host_version=PROTOCOL_VERSION, - assigned_generation=assigned_generation, - ) + # 发送响应 + self.clear_handshake_state() + resp_payload = HelloResponsePayload(accepted=True, host_version=PROTOCOL_VERSION) resp = envelope.make_response(payload=resp_payload.model_dump()) await conn.send_frame(self._codec.encode_envelope(resp)) + return True - return role + def _check_sdk_version(self, sdk_version: str) -> bool: + """检查 SDK 版本是否在支持范围内""" + try: + sdk_parts = _parse_version_tuple(sdk_version) + min_parts = _parse_version_tuple(MIN_SDK_VERSION) + max_parts = _parse_version_tuple(MAX_SDK_VERSION) + return min_parts <= sdk_parts <= max_parts + except (ValueError, AttributeError): + return False - async def _recv_loop(self, conn: Connection, expected_generation: int) -> None: + # ========= 接收循环 ========= + async def _recv_loop(self, conn: Connection) -> None: """消息接收主循环""" while self._running and not conn.is_closed: try: @@ -430,109 +334,40 @@ class RPCServer: if envelope.is_response(): self._handle_response(envelope) elif envelope.is_request(): - if envelope.generation != expected_generation: - error_resp = envelope.make_error_response( - ErrorCode.E_GENERATION_MISMATCH.value, - f"过期 generation: {envelope.generation} != {expected_generation}", - ) - await conn.send_frame(self._codec.encode_envelope(error_resp)) - continue # 异步处理请求(Runner 发来的能力调用) task = asyncio.create_task(self._handle_request(envelope, conn)) self._tasks.append(task) task.add_done_callback(lambda t: self._tasks.remove(t) if t in self._tasks else None) - elif envelope.is_event(): - if envelope.generation != expected_generation: - logger.warning( - f"忽略过期 generation 事件 {envelope.method}: {envelope.generation} != {expected_generation}" - ) - continue - task = asyncio.create_task(self._handle_event(envelope)) + elif envelope.is_broadcast(): + task = asyncio.create_task(self._handle_broadcast(envelope)) self._tasks.append(task) task.add_done_callback(lambda t: self._tasks.remove(t) if t in self._tasks else None) + else: + logger.warning(f"未知的消息类型: {envelope.message_type}") + continue + # ====== 接收循环内部方法 ====== def _handle_response(self, envelope: Envelope) -> None: """处理来自 Runner 的响应""" - pending = self._pending_requests.get(envelope.request_id) - if pending is None: + pending_future = self._pending_requests.pop(envelope.request_id, None) + if pending_future is None: return - - future, expected_generation = pending - if envelope.generation != expected_generation: - logger.warning( - f"忽略过期 generation 响应 {envelope.method}: {envelope.generation} != {expected_generation}" - ) - return - - self._pending_requests.pop(envelope.request_id, None) - if not future.done(): + if not pending_future.done(): if envelope.error: - future.set_exception(RPCError.from_dict(envelope.error)) + pending_future.set_exception(RPCError.from_dict(envelope.error)) else: - future.set_result(envelope) - - async def _enqueue_send(self, conn: Connection, data: bytes) -> None: - """通过发送队列串行发送消息,提供真实背压。""" - if conn.is_closed: - raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接") - - if self._send_queue is None: - await conn.send_frame(data) - return - - loop = asyncio.get_running_loop() - send_future: asyncio.Future[None] = loop.create_future() - - try: - self._send_queue.put_nowait((conn, data, send_future)) - except asyncio.QueueFull: - raise RPCError(ErrorCode.E_BACKPRESSURE, "发送队列已满") from None - - await send_future - - async def _send_loop(self) -> None: - """后台发送循环:串行消费发送队列,统一执行连接写入。""" - if self._send_queue is None: - return - - while True: - try: - conn, data, send_future = await self._send_queue.get() - except asyncio.CancelledError: - break - - try: - if conn.is_closed: - raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接") - await conn.send_frame(data) - if not send_future.done(): - send_future.set_result(None) - except asyncio.CancelledError: - if not send_future.done(): - send_future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "服务器关闭")) - raise - except Exception as e: - send_error = e if isinstance(e, RPCError) else self._normalize_send_exception(e) - if not send_future.done(): - send_future.set_exception(send_error) - finally: - self._send_queue.task_done() - - @staticmethod - def _normalize_send_exception(error: Exception) -> RPCError: - if isinstance(error, ConnectionError): - return RPCError(ErrorCode.E_PLUGIN_CRASHED, str(error)) - return RPCError(ErrorCode.E_UNKNOWN, str(error)) + pending_future.set_result(envelope) async def _handle_request(self, envelope: Envelope, conn: Connection) -> None: """处理来自 Runner 的请求(通常是能力调用 cap.*)""" - handler = self._method_handlers.get(envelope.method) - if handler is None: - error_resp = envelope.make_error_response( + target_method = envelope.method + handler = self._method_handlers.get(target_method) + if not handler: + error_response = envelope.make_error_response( ErrorCode.E_METHOD_NOT_ALLOWED.value, f"未注册的方法: {envelope.method}", ) - await conn.send_frame(self._codec.encode_envelope(error_resp)) + await conn.send_frame(self._codec.encode_envelope(error_response)) return try: @@ -546,59 +381,25 @@ class RPCServer: error_resp = envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(e)) await conn.send_frame(self._codec.encode_envelope(error_resp)) - async def _handle_event(self, envelope: Envelope) -> None: - """处理来自 Runner 的事件""" + async def _handle_broadcast(self, envelope: Envelope) -> None: if handler := self._method_handlers.get(envelope.method): try: result = await handler(envelope) # 检查 handler 返回的信封是否包含错误信息 - if result is not None and isinstance(result, Envelope) and result.error: + if result.error: logger.warning(f"事件 {envelope.method} handler 返回错误: {result.error.get('message', '')}") except Exception as e: logger.error(f"处理事件 {envelope.method} 异常: {e}", exc_info=True) - @staticmethod - def _check_sdk_version(sdk_version: str) -> bool: - """检查 SDK 版本是否在支持范围内""" - try: - sdk_parts = RPCServer._parse_version_tuple(sdk_version) - min_parts = RPCServer._parse_version_tuple(MIN_SDK_VERSION) - max_parts = RPCServer._parse_version_tuple(MAX_SDK_VERSION) - return min_parts <= sdk_parts <= max_parts - except (ValueError, AttributeError): - return False - - @staticmethod - def _parse_version_tuple(version: str) -> Tuple[int, int, int]: - base_version = re.split(r"[-.](?:snapshot|dev|alpha|beta|rc)", version or "", flags=re.IGNORECASE)[0] - base_version = base_version.split("+", 1)[0] - parts = [part for part in base_version.split(".") if part != ""] - while len(parts) < 3: - parts.append("0") - return (int(parts[0]), int(parts[1]), int(parts[2])) - - def _get_connection_for_generation(self, generation: int) -> Optional[Connection]: - if generation == self._runner_generation: - return self._connection - if generation == self._staged_runner_generation: - return self._staged_connection - return None - - def _fail_pending_requests( - self, - error_code: ErrorCode, - message: str, - generation: Optional[int] = None, - ) -> int: - stale_count = 0 - for request_id, (future, request_generation) in list(self._pending_requests.items()): - if generation is not None and request_generation != generation: - continue + def _fail_pending_requests(self, error_code: ErrorCode, message: str) -> int: + """失败所有等待中的请求(如连接断开时)""" + aborted_request_count = 0 + for future in self._pending_requests.values(): if not future.done(): future.set_exception(RPCError(error_code, message)) - stale_count += 1 - self._pending_requests.pop(request_id, None) - return stale_count + aborted_request_count += 1 + self._pending_requests.clear() + return aborted_request_count def _fail_queued_sends(self, error_code: ErrorCode, message: str) -> int: if self._send_queue is None: @@ -617,3 +418,31 @@ class RPCServer: self._send_queue.task_done() return failed_count + + async def _enqueue_send(self, conn: Connection, data: bytes) -> None: + """通过发送队列串行发送消息,提供真实背压。""" + if conn.is_closed: + raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接") + + if self._send_queue is None: + await conn.send_frame(data) + return + + loop = asyncio.get_running_loop() + send_future: asyncio.Future[None] = loop.create_future() + + try: + self._send_queue.put_nowait((conn, data, send_future)) + except asyncio.QueueFull: + raise RPCError(ErrorCode.E_BACK_PRESSURE, "发送队列已满") from None + + await send_future + + +def _parse_version_tuple(version: str) -> Tuple[int, int, int]: + base_version = re.split(r"[-.](?:snapshot|dev|alpha|beta|rc)", version or "", flags=re.IGNORECASE)[0] + base_version = base_version.split("+", 1)[0] + parts = [part for part in base_version.split(".") if part != ""] + while len(parts) < 3: + parts.append("0") + return (int(parts[0]), int(parts[1]), int(parts[2])) diff --git a/src/plugin_runtime/host/supervisor.py b/src/plugin_runtime/host/supervisor.py index bfa00cbf..08638d16 100644 --- a/src/plugin_runtime/host/supervisor.py +++ b/src/plugin_runtime/host/supervisor.py @@ -1,97 +1,80 @@ -"""Supervisor - 插件生命周期管理 - -负责: -1. 拉起 Runner 子进程 -2. 健康检查 + 崩溃自动重启 -3. 代码热重载(generation 切换) -4. 优雅关停 -""" - -from typing import Any, Dict, List, Optional, Tuple +from dataclasses import dataclass, field +from pathlib import Path +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple import asyncio import contextlib -import logging as stdlib_logging +import json import os import sys -from pathlib import Path from src.common.logger import get_logger -from src.config.config import MMC_VERSION, global_config -from src.plugin_runtime import ENV_HOST_VERSION, ENV_IPC_ADDRESS, ENV_PLUGIN_DIRS, ENV_SESSION_TOKEN -from src.plugin_runtime.host.capability_service import CapabilityService -from src.plugin_runtime.host.component_registry import ComponentRegistry -from src.plugin_runtime.host.event_dispatcher import EventDispatcher -from src.plugin_runtime.host.policy_engine import PolicyEngine -from src.plugin_runtime.host.rpc_server import RPCServer -from src.plugin_runtime.host.workflow_executor import WorkflowExecutor, WorkflowContext, WorkflowResult +from src.config.config import config_manager, global_config +from src.platform_io import DriverKind, InboundMessageEnvelope, RouteBinding, RouteKey, get_platform_io_manager +from src.platform_io.drivers import PluginPlatformDriver +from src.platform_io.route_key_factory import RouteKeyFactory +from src.plugin_runtime import ( + ENV_EXTERNAL_PLUGIN_IDS, + ENV_GLOBAL_CONFIG_SNAPSHOT, + ENV_HOST_VERSION, + ENV_IPC_ADDRESS, + ENV_PLUGIN_DIRS, + ENV_SESSION_TOKEN, +) from src.plugin_runtime.protocol.envelope import ( BootstrapPluginPayload, + ConfigReloadScope, ConfigUpdatedPayload, Envelope, HealthPayload, - LogBatchPayload, - RegisterComponentsPayload, + MessageGatewayStateUpdatePayload, + MessageGatewayStateUpdateResultPayload, + PROTOCOL_VERSION, + ReceiveExternalMessageResultPayload, + RegisterPluginPayload, + ReloadPluginResultPayload, + ReloadPluginsPayload, + ReloadPluginsResultPayload, + RouteMessagePayload, RunnerReadyPayload, ShutdownPayload, + UnregisterPluginPayload, ) +from src.plugin_runtime.protocol.codec import MsgPackCodec from src.plugin_runtime.protocol.errors import ErrorCode, RPCError from src.plugin_runtime.transport.factory import create_transport_server -logger = get_logger("plugin_runtime.host.supervisor") +from .authorization import AuthorizationManager +from .api_registry import APIRegistry +from .capability_service import CapabilityService +from .component_registry import ComponentRegistry +from .event_dispatcher import EventDispatcher +from .hook_dispatcher import HookDispatcher +from .logger_bridge import RunnerLogBridge +from .message_gateway import MessageGateway +from .rpc_server import RPCServer + +if TYPE_CHECKING: + from src.chat.message_receive.message import SessionMessage + +logger = get_logger("plugin_runtime.host.runner_manager") + +@dataclass(slots=True) +class _MessageGatewayRuntimeState: + """保存消息网关当前的运行时连接状态。""" + + ready: bool = False + platform: Optional[str] = None + account_id: Optional[str] = None + scope: Optional[str] = None + metadata: Dict[str, Any] = field(default_factory=dict) -# ─── 日志桥 ────────────────────────────────────────────────────── +class PluginRunnerSupervisor: + """插件 Runner 监督器。 - -class RunnerLogBridge: - """将 Runner 进程上报的批量日志重放到主进程的 Logger 中。 - - Runner 通过 ``runner.log_batch`` IPC 事件批量到达。 - 每条 LogEntry 被重建为一个真实的 :class:`logging.LogRecord` 并直接 - 调用 ``logging.getLogger(entry.logger_name).handle(record)``, - 从而接入主进程已配置好的 structlog Handler 链。 - """ - - async def handle_log_batch(self, envelope: Envelope) -> Envelope: - """IPC 事件处理器:解析批量日志并重放到主进程 Logger。 - - Args: - envelope: 方法名为 ``runner.log_batch`` 的 IPC 事件信封。 - - Returns: - 空响应信封(事件模式下将被忽略)。 - """ - try: - batch = LogBatchPayload.model_validate(envelope.payload) - except Exception as exc: - return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc)) - - for entry in batch.entries: - # 重建一个与原始日志尽量相符的 LogRecord - record = stdlib_logging.LogRecord( - name=entry.logger_name, - level=entry.level, - pathname="", - lineno=0, - msg=entry.message, - args=(), - exc_info=None, - ) - record.created = entry.timestamp_ms / 1000.0 - record.msecs = entry.timestamp_ms % 1000 - if entry.exception_text: - record.exc_text = entry.exception_text - - stdlib_logging.getLogger(entry.logger_name).handle(record) - - return envelope.make_response(payload={"accepted": True, "count": len(batch.entries)}) - - -class PluginSupervisor: - """插件 Supervisor - - Host 端的核心管理器,负责整个插件 Runner 进程的生命周期。 + 负责 Host 侧与单个 Runner 子进程之间的生命周期、内部 RPC、 + 健康检查和插件级重载协调。 """ def __init__( @@ -101,196 +84,253 @@ class PluginSupervisor: health_check_interval_sec: Optional[float] = None, max_restart_attempts: Optional[int] = None, runner_spawn_timeout_sec: Optional[float] = None, - ): - _cfg = global_config.plugin_runtime - self._plugin_dirs = plugin_dirs or [] - self._health_interval = ( - health_check_interval_sec if health_check_interval_sec is not None else _cfg.health_check_interval_sec - ) - self._runner_spawn_timeout = ( - runner_spawn_timeout_sec if runner_spawn_timeout_sec is not None else _cfg.runner_spawn_timeout_sec - ) + ) -> None: + """初始化 Supervisor。 + + Args: + plugin_dirs: 由当前 Runner 负责加载的插件目录列表。 + socket_path: 自定义 IPC 地址;留空时由传输层自动生成。 + health_check_interval_sec: 健康检查间隔,单位秒。 + max_restart_attempts: 自动重启 Runner 的最大次数。 + runner_spawn_timeout_sec: 等待 Runner 建连并就绪的超时时间,单位秒。 + """ + runtime_config = global_config.plugin_runtime + self._plugin_dirs: List[Path] = plugin_dirs or [] + self._health_interval: float = health_check_interval_sec or runtime_config.health_check_interval_sec or 30.0 + self._runner_spawn_timeout: float = ( + runner_spawn_timeout_sec or runtime_config.runner_spawn_timeout_sec or 30.0 + ) + self._max_restart_attempts: int = max_restart_attempts or runtime_config.max_restart_attempts or 3 - # 基础设施 self._transport = create_transport_server(socket_path=socket_path) - self._policy = PolicyEngine() - self._capability_service = CapabilityService(self._policy) + self._authorization = AuthorizationManager() + self._capability_service = CapabilityService(self._authorization) + self._api_registry = APIRegistry() self._component_registry = ComponentRegistry() self._event_dispatcher = EventDispatcher(self._component_registry) - self._workflow_executor = WorkflowExecutor(self._component_registry) - - # 编解码 - from src.plugin_runtime.protocol.codec import MsgPackCodec + self._hook_dispatcher = HookDispatcher(self._component_registry) + self._message_gateway = MessageGateway(self._component_registry) + self._log_bridge = RunnerLogBridge() codec = MsgPackCodec() + self._rpc_server = RPCServer(transport=self._transport, codec=codec) - self._rpc_server = RPCServer( - transport=self._transport, - codec=codec, - ) - - # Runner 子进程 self._runner_process: Optional[asyncio.subprocess.Process] = None - self._runner_generation: int = 0 - self._max_restart_attempts: int = ( - max_restart_attempts if max_restart_attempts is not None else _cfg.max_restart_attempts - ) + self._registered_plugins: Dict[str, RegisterPluginPayload] = {} + self._message_gateway_states: Dict[str, Dict[str, _MessageGatewayRuntimeState]] = {} + self._external_available_plugins: Dict[str, str] = {} + self._runner_ready_events: asyncio.Event = asyncio.Event() + self._runner_ready_payloads: RunnerReadyPayload = RunnerReadyPayload() + self._health_task: Optional[asyncio.Task[None]] = None + self._stderr_drain_task: Optional[asyncio.Task[None]] = None self._restart_count: int = 0 + self._running: bool = False - # 已注册的插件组件信息 - self._registered_plugins: Dict[str, RegisterComponentsPayload] = {} - self._staged_registered_plugins: Dict[str, RegisterComponentsPayload] = {} - self._runner_ready_events: Dict[int, asyncio.Event] = {} - self._runner_ready_payloads: Dict[int, RunnerReadyPayload] = {} - - # 后台任务 - self._health_task: Optional[asyncio.Task] = None - # Runner stderr 流排空任务(仅保留 stderr,用于 IPC 建立前的启动日志倒空、致命错误输出等场景) - self._stderr_drain_task: Optional[asyncio.Task] = None - self._running = False - - # Runner 日志桥(将 Runner 上报的批量日志重放到主进程 Logger) - self._log_bridge: RunnerLogBridge = RunnerLogBridge() - - # 注册内部 RPC 方法 self._register_internal_methods() @property - def policy_engine(self) -> PolicyEngine: - return self._policy + def authorization_manager(self) -> AuthorizationManager: + """返回授权管理器。""" + return self._authorization @property def capability_service(self) -> CapabilityService: + """返回能力服务。""" return self._capability_service + @property + def api_registry(self) -> APIRegistry: + """返回 API 专用注册表。""" + return self._api_registry + @property def component_registry(self) -> ComponentRegistry: + """返回组件注册表。""" return self._component_registry @property def event_dispatcher(self) -> EventDispatcher: + """返回事件分发器。""" return self._event_dispatcher @property - def workflow_executor(self) -> WorkflowExecutor: - return self._workflow_executor + def hook_dispatcher(self) -> HookDispatcher: + """返回 Hook 分发器。""" + return self._hook_dispatcher + + @property + def message_gateway(self) -> MessageGateway: + """返回消息网关。""" + return self._message_gateway @property def rpc_server(self) -> RPCServer: + """返回底层 RPC 服务端。""" return self._rpc_server + def set_external_available_plugins(self, plugin_versions: Dict[str, str]) -> None: + """设置当前 Runner 启动/重载时可视为已满足的外部依赖版本映射。 + + Args: + plugin_versions: 外部插件版本映射,键为插件 ID,值为插件版本。 + """ + self._external_available_plugins = { + str(plugin_id or "").strip(): str(plugin_version or "").strip() + for plugin_id, plugin_version in plugin_versions.items() + if str(plugin_id or "").strip() and str(plugin_version or "").strip() + } + + def get_loaded_plugin_ids(self) -> List[str]: + """返回当前 Supervisor 已注册的插件 ID 列表。""" + + return sorted(self._registered_plugins.keys()) + + def get_loaded_plugin_versions(self) -> Dict[str, str]: + """返回当前 Supervisor 已注册插件的版本映射。 + + Returns: + Dict[str, str]: 已注册插件版本映射,键为插件 ID,值为插件版本。 + """ + return { + plugin_id: registration.plugin_version + for plugin_id, registration in self._registered_plugins.items() + } + + @staticmethod + def _normalize_reload_plugin_ids(plugin_ids: Optional[List[str] | str]) -> List[str]: + """规范化批量重载入参。 + + Args: + plugin_ids: 原始插件 ID 列表或单个插件 ID。 + + Returns: + List[str]: 去重且去空白后的插件 ID 列表。 + """ + + raw_plugin_ids: List[str] + if plugin_ids is None: + raw_plugin_ids = [] + elif isinstance(plugin_ids, str): + raw_plugin_ids = [plugin_ids] + else: + raw_plugin_ids = list(plugin_ids) + + normalized_plugin_ids: List[str] = [] + seen_plugin_ids: set[str] = set() + for plugin_id in raw_plugin_ids: + normalized_plugin_id = str(plugin_id or "").strip() + if not normalized_plugin_id or normalized_plugin_id in seen_plugin_ids: + continue + seen_plugin_ids.add(normalized_plugin_id) + normalized_plugin_ids.append(normalized_plugin_id) + return normalized_plugin_ids + async def dispatch_event( self, event_type: str, - message: Optional[Dict[str, Any]] = None, + message: Optional["SessionMessage"] = None, extra_args: Optional[Dict[str, Any]] = None, - ) -> Tuple[bool, Optional[Dict[str, Any]]]: - """分发事件到所有对应 handler 的快捷方法。""" + ) -> Tuple[bool, Optional["SessionMessage"]]: + """分发事件到已注册的事件处理器。 - async def _invoke(plugin_id: str, component_name: str, args: Dict[str, Any]) -> Dict[str, Any]: - resp = await self.invoke_plugin( - method="plugin.emit_event", - plugin_id=plugin_id, - component_name=component_name, - args=args, - ) - return resp.payload + Args: + event_type: 事件类型。 + message: 可选的消息对象。 + extra_args: 附加参数。 - return await self._event_dispatcher.dispatch_event( - event_type=event_type, - invoke_fn=_invoke, - message=message, - extra_args=extra_args, - ) + Returns: + Tuple[bool, Optional[SessionMessage]]: 是否继续处理,以及插件可能修改后的消息。 + """ + return await self._event_dispatcher.dispatch_event(event_type, self, message, extra_args) - async def execute_workflow( + async def dispatch_hook(self, stage: str, **kwargs: Any) -> Dict[str, Any]: + """分发 Hook 到已注册的 Hook 处理器。 + + Args: + stage: Hook 阶段名称。 + **kwargs: 传递给 Hook 的关键字参数。 + + Returns: + Dict[str, Any]: 经 Hook 修改后的参数字典。 + """ + return await self._hook_dispatcher.hook_dispatch(stage, self, **kwargs) + + async def send_message_to_external( self, - message: Optional[Dict[str, Any]] = None, - stream_id: Optional[str] = None, - context: Optional[WorkflowContext] = None, - ) -> Tuple[WorkflowResult, Optional[Dict[str, Any]], WorkflowContext]: - """执行 Workflow Pipeline 的快捷方法。""" + internal_message: "SessionMessage", + *, + enabled_only: bool = True, + save_to_db: bool = True, + ) -> bool: + """通过插件消息网关发送外部消息。 - async def _invoke(plugin_id: str, component_name: str, args: Dict[str, Any]) -> Dict[str, Any]: - resp = await self.invoke_plugin( - method="plugin.invoke_workflow_step", - plugin_id=plugin_id, - component_name=component_name, - args=args, - ) - payload = resp.payload - if payload.get("success"): - result = payload.get("result") - return result if isinstance(result, dict) else {} - raise RuntimeError(payload.get("result", "workflow step invoke failed")) + Args: + internal_message: 系统内部消息对象。 + enabled_only: 是否仅使用启用的网关组件。 + save_to_db: 发送成功后是否写入数据库。 - async def _command_invoke(plugin_id: str, component_name: str, args: Dict[str, Any]) -> Dict[str, Any]: - """命令走 plugin.invoke_command,保留原始返回值结构。""" - resp = await self.invoke_plugin( - method="plugin.invoke_command", - plugin_id=plugin_id, - component_name=component_name, - args=args, - ) - return resp.payload - - return await self._workflow_executor.execute( - invoke_fn=_invoke, - message=message, - stream_id=stream_id, - context=context, - command_invoke_fn=_command_invoke, + Returns: + bool: 是否发送成功。 + """ + return await self._message_gateway.send_message_to_external( + internal_message, + self, + enabled_only=enabled_only, + save_to_db=save_to_db, ) async def start(self) -> None: - """启动 Supervisor + """启动 Supervisor。""" + if self._running: + logger.warning("PluginRunnerSupervisor 已在运行,跳过重复启动") + return - 1. 启动 RPC Server - 2. 拉起 Runner 子进程 - 3. 启动健康检查 - """ self._running = True + self._restart_count = 0 + self._clear_runner_state() - # 启动 RPC Server - await self._rpc_server.start() - - # 计算预期 generation(与 reload_plugins 保持一致) - expected_generation = self._rpc_server.runner_generation + 1 - - # 拉起 Runner 进程 - await self._spawn_runner() - - # 等待 Runner 完成连接和初始化,避免 start() 返回时 Runner 尚未就绪 try: - await self._wait_for_runner_generation(expected_generation, timeout_sec=self._runner_spawn_timeout) - await self._wait_for_runner_ready(expected_generation, timeout_sec=self._runner_spawn_timeout) - except TimeoutError: - if not self._rpc_server.is_connected: - logger.warning(f"Runner 未在 {self._runner_spawn_timeout}s 内完成连接,后续操作可能失败") - else: - logger.warning(f"Runner 未在 {self._runner_spawn_timeout}s 内完成初始化,后续操作可能失败") + await self._rpc_server.start() + await self._spawn_runner() - # 启动健康检查 - self._health_task = asyncio.create_task(self._health_check_loop()) + try: + await self._wait_for_runner_connection(timeout_sec=self._runner_spawn_timeout) + await self._wait_for_runner_ready(timeout_sec=self._runner_spawn_timeout) + except TimeoutError: + if not self._rpc_server.is_connected: + logger.warning("Runner 未在限定时间内完成连接,后续操作可能失败") + else: + logger.warning("Runner 未在限定时间内完成初始化,后续操作可能失败") + except Exception: + await self._shutdown_runner(reason="startup_failed") + await self._rpc_server.stop() + self._clear_runner_state() + self._running = False + raise - logger.info("PluginSupervisor 已启动") + self._health_task = asyncio.create_task(self._health_check_loop(), name="PluginRunnerSupervisor.health") + logger.info("PluginRunnerSupervisor 已启动") async def stop(self) -> None: - """停止 Supervisor""" + """停止 Supervisor。""" + if not self._running: + return + self._running = False - # 停止健康检查 - if self._health_task: + if self._health_task is not None: self._health_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await self._health_task self._health_task = None - # 优雅关停 Runner - await self._shutdown_runner() - - # 停止 RPC Server + await self._event_dispatcher.stop() + await self._hook_dispatcher.stop() + await self._shutdown_runner(reason="host_stop") await self._rpc_server.stop() + self._clear_runner_state() - logger.info("PluginSupervisor 已停止") + logger.info("PluginRunnerSupervisor 已停止") async def invoke_plugin( self, @@ -300,444 +340,1068 @@ class PluginSupervisor: args: Optional[Dict[str, Any]] = None, timeout_ms: int = 30000, ) -> Envelope: - """调用插件组件 + """调用 Runner 内的插件组件。 - 由主进程业务逻辑调用,通过 RPC 转发给 Runner。 + Args: + method: RPC 方法名。 + plugin_id: 目标插件 ID。 + component_name: 组件名。 + args: 调用参数。 + timeout_ms: RPC 超时时间,单位毫秒。 + + Returns: + Envelope: RPC 响应信封。 """ return await self._rpc_server.send_request( - method=method, + method, + plugin_id, + {"component_name": component_name, "args": args or {}}, + timeout_ms, + ) + + async def invoke_message_gateway( + self, + plugin_id: str, + component_name: str, + args: Optional[Dict[str, Any]] = None, + timeout_ms: int = 30000, + ) -> Envelope: + """调用插件声明的消息网关方法。 + + Args: + plugin_id: 目标插件 ID。 + component_name: 消息网关组件名称。 + args: 传递给网关方法的关键字参数。 + timeout_ms: RPC 超时时间,单位毫秒。 + + Returns: + Envelope: Runner 返回的响应信封。 + """ + + return await self.invoke_plugin( + method="plugin.invoke_message_gateway", plugin_id=plugin_id, - payload={ - "component_name": component_name, - "args": args or {}, - }, + component_name=component_name, + args=args, timeout_ms=timeout_ms, ) - async def reload_plugins(self, reason: str = "manual") -> bool: - """热重载所有插件(进程级 generation 切换) + async def invoke_api( + self, + plugin_id: str, + component_name: str, + args: Optional[Dict[str, Any]] = None, + timeout_ms: int = 30000, + ) -> Envelope: + """调用插件声明的 API 方法。 - 1. 拉起新 Runner - 2. 等待新 Runner 完成注册和健康检查 - 3. 关停旧 Runner + Args: + plugin_id: 目标插件 ID。 + component_name: API 组件名称。 + args: 传递给 API 方法的关键字参数。 + timeout_ms: RPC 超时时间,单位毫秒。 + + Returns: + Envelope: Runner 返回的响应信封。 """ - logger.info(f"开始热重载插件,原因: {reason}") - # 保存旧进程引用和旧 session token(回滚时需要恢复) - old_process = self._runner_process - old_registered_plugins = dict(self._registered_plugins) - old_session_token = self._rpc_server.session_token - expected_generation = self._rpc_server.runner_generation + 1 + return await self.invoke_plugin( + method="plugin.invoke_api", + plugin_id=plugin_id, + component_name=component_name, + args=args, + timeout_ms=timeout_ms, + ) - # 允许新 Runner 以 staged 方式接入,验证通过后再切换活跃连接 - self._rpc_server.begin_staged_takeover() - self._staged_registered_plugins.clear() + async def reload_plugin( + self, + plugin_id: str, + reason: str = "manual", + external_available_plugins: Optional[Dict[str, str]] = None, + ) -> bool: + """按插件 ID 触发精确重载。 - # 重新生成 session token,防止被终止的旧 Runner 重连 - self._rpc_server.reset_session_token() + Args: + plugin_id: 目标插件 ID。 + reason: 重载原因。 + external_available_plugins: 视为已满足的外部依赖插件版本映射。 - # 注意:不在此处调用 _clear_runtime_state()。 - # 旧组件在新 Runner 完成注册前继续提供服务,避免热重载窗口期内 - # dispatch_event / execute_workflow 找不到任何组件导致消息静默丢失。 - # ComponentRegistry.register_component 对同名组件是覆盖式写入,安全。 - - # 拉起新 Runner + Returns: + bool: 是否重载成功。 + """ try: - await self._spawn_runner() - await self._wait_for_runner_generation( - expected_generation, - timeout_sec=self._runner_spawn_timeout, - allow_staged=True, + response = await self._rpc_server.send_request( + "plugin.reload", + plugin_id=plugin_id, + payload={ + "plugin_id": plugin_id, + "reason": reason, + "external_available_plugins": external_available_plugins or self._external_available_plugins, + }, + timeout_ms=max(int(self._runner_spawn_timeout * 1000), 10000), ) - await self._wait_for_runner_ready(expected_generation, timeout_sec=self._runner_spawn_timeout) - resp = await self._rpc_server.send_request( - "plugin.health", - timeout_ms=5000, - target_generation=expected_generation, - ) - health = HealthPayload.model_validate(resp.payload) - if not health.healthy: - raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "新 Runner 健康检查失败") - await self._rpc_server.commit_staged_takeover() - except Exception as e: - logger.error(f"新 Runner 健康检查失败: {e},回滚") - await self._terminate_process(self._runner_process, old_process) - await self._rpc_server.rollback_staged_takeover() - self._runner_process = old_process - self._rpc_server.restore_session_token(old_session_token) - self._staged_registered_plugins.clear() - self._registered_plugins = dict(old_registered_plugins) - self._rebuild_runtime_state() + except Exception as exc: + logger.error(f"插件 {plugin_id} 重载请求失败: {exc}") return False - self._runner_generation = self._rpc_server.runner_generation - self._registered_plugins = dict(self._staged_registered_plugins) - self._staged_registered_plugins.clear() - self._rebuild_runtime_state() + result = ReloadPluginResultPayload.model_validate(response.payload) + if not result.success: + logger.warning(f"插件 {plugin_id} 重载失败: {result.failed_plugins}") + return result.success - # 关停旧 Runner - if old_process and old_process.returncode is None: - try: - old_process.terminate() - await asyncio.wait_for(old_process.wait(), timeout=10.0) - except asyncio.TimeoutError: - old_process.kill() + async def reload_plugins( + self, + plugin_ids: Optional[List[str] | str] = None, + reason: str = "manual", + external_available_plugins: Optional[Dict[str, str]] = None, + ) -> bool: + """批量重载插件。 - logger.info("热重载完成") - return True + Args: + plugin_ids: 目标插件 ID 列表;为空时重载当前已注册的全部插件。 + reason: 重载原因。 + external_available_plugins: 视为已满足的外部依赖插件版本映射。 + + Returns: + bool: 是否全部重载成功。 + """ + ordered_plugin_ids = self._normalize_reload_plugin_ids(plugin_ids) + if not ordered_plugin_ids: + ordered_plugin_ids = list(self._registered_plugins.keys()) + if not ordered_plugin_ids: + return True + + if len(ordered_plugin_ids) == 1: + return await self.reload_plugin( + plugin_id=ordered_plugin_ids[0], + reason=reason, + external_available_plugins=external_available_plugins, + ) + + try: + response = await self._rpc_server.send_request( + "plugin.reload_batch", + payload=ReloadPluginsPayload( + plugin_ids=ordered_plugin_ids, + reason=reason, + external_available_plugins=external_available_plugins or self._external_available_plugins, + ).model_dump(), + timeout_ms=max(int(self._runner_spawn_timeout * 1000), 10000), + ) + except Exception as exc: + logger.error(f"插件批量重载请求失败: {exc}") + return False + + result = ReloadPluginsResultPayload.model_validate(response.payload) + if not result.success: + logger.warning(f"插件批量重载失败: {result.failed_plugins}") + return result.success async def notify_plugin_config_updated( self, plugin_id: str, - config_data: Dict[str, Any], + config_data: Optional[Dict[str, Any]] = None, config_version: str = "", + config_scope: str | ConfigReloadScope = "self", ) -> bool: - """通知指定插件其配置已更新。""" - if plugin_id not in self._registered_plugins: + """向 Runner 推送插件配置更新。 + + Args: + plugin_id: 目标插件 ID。 + config_data: 配置内容。 + config_version: 配置版本号。 + config_scope: 配置变更范围。 + + Returns: + bool: 请求是否成功送达并被 Runner 接受。 + """ + try: + normalized_scope = ConfigReloadScope(config_scope) + except ValueError: + logger.warning(f"插件 {plugin_id} 配置更新通知失败: 非法的 config_scope={config_scope}") return False payload = ConfigUpdatedPayload( plugin_id=plugin_id, + config_scope=normalized_scope, config_version=config_version, - config_data=config_data, + config_data=config_data or {}, ) - await self._rpc_server.send_request( - "plugin.config_updated", - plugin_id=plugin_id, - payload=payload.model_dump(), - timeout_ms=5000, - ) - return True + try: + response = await self._rpc_server.send_request( + "plugin.config_updated", + plugin_id=plugin_id, + payload=payload.model_dump(), + timeout_ms=10000, + ) + except Exception as exc: + logger.warning(f"插件 {plugin_id} 配置更新通知失败: {exc}") + return False - # ─── 内部方法 ────────────────────────────────────────────── + return bool(response.payload.get("acknowledged", False)) + + def get_config_reload_subscribers(self, scope: str) -> List[str]: + """返回订阅指定全局配置广播的插件列表。 + + Args: + scope: 配置变更范围,仅支持 ``bot`` 或 ``model``。 + + Returns: + List[str]: 已声明订阅该范围的插件 ID 列表。 + """ + + return [ + plugin_id + for plugin_id, registration in self._registered_plugins.items() + if scope in registration.config_reload_subscriptions + ] + + async def _wait_for_runner_connection(self, timeout_sec: float) -> None: + """等待 Runner 建立 RPC 连接。 + + Args: + timeout_sec: 超时时间,单位秒。 + + Raises: + TimeoutError: 在超时时间内 Runner 未完成连接。 + """ + + async def wait_for_connection() -> None: + """轮询等待 RPC 连接建立。""" + while True: + if self._rpc_server.is_connected: + return + + if not self._running: + raise RuntimeError("Supervisor 已停止,等待 Runner 连接已取消") + + if failure_reason := self._get_runner_startup_failure_reason(): + raise RuntimeError(f"等待 Runner 连接失败: {failure_reason}") + + await asyncio.sleep(0.1) + + try: + await asyncio.wait_for(wait_for_connection(), timeout=timeout_sec) + logger.info("Runner 已连接到 RPC Server") + except asyncio.TimeoutError as exc: + raise TimeoutError(f"等待 Runner 连接超时({timeout_sec}s)") from exc + + async def _wait_for_runner_ready(self, timeout_sec: float = 30.0) -> RunnerReadyPayload: + """等待 Runner 完成启动初始化。 + + Args: + timeout_sec: 超时时间,单位秒。 + + Returns: + RunnerReadyPayload: Runner 上报的就绪信息。 + + Raises: + TimeoutError: 在超时时间内 Runner 未完成初始化。 + """ + async def wait_for_ready() -> RunnerReadyPayload: + """轮询等待 Runner 上报就绪。""" + while True: + if self._runner_ready_events.is_set(): + return self._runner_ready_payloads + + if not self._running: + raise RuntimeError("Supervisor 已停止,等待 Runner 就绪已取消") + + if failure_reason := self._get_runner_startup_failure_reason(): + raise RuntimeError(f"等待 Runner 就绪失败: {failure_reason}") + + if not self._rpc_server.is_connected: + raise RuntimeError("等待 Runner 就绪失败: Runner RPC 连接已断开") + + await asyncio.sleep(0.1) + + try: + payload = await asyncio.wait_for(wait_for_ready(), timeout=timeout_sec) + logger.info("Runner 已完成初始化并上报就绪") + return payload + except asyncio.TimeoutError as exc: + raise TimeoutError(f"等待 Runner 就绪超时({timeout_sec}s)") from exc def _register_internal_methods(self) -> None: - """注册 Host 端的 RPC 方法处理器""" - # Runner -> Host 的能力调用统一走 capability_service - self._rpc_server.register_method("cap.request", self._capability_service.handle_capability_request) + """注册 Host 侧内部 RPC 方法。""" + self._rpc_server.register_method("cap.call", self._capability_service.handle_capability_request) + self._rpc_server.register_method("host.route_message", self._handle_route_message) + self._rpc_server.register_method("host.update_message_gateway_state", self._handle_update_message_gateway_state) self._rpc_server.register_method("plugin.bootstrap", self._handle_bootstrap_plugin) - # 插件注册 - self._rpc_server.register_method("plugin.register_components", self._handle_register_components) - self._rpc_server.register_method("runner.ready", self._handle_runner_ready) - # Runner 日志批量上报 + self._rpc_server.register_method("plugin.register_components", self._handle_register_plugin) + self._rpc_server.register_method("plugin.register_plugin", self._handle_register_plugin) + self._rpc_server.register_method("plugin.unregister", self._handle_unregister_plugin) self._rpc_server.register_method("runner.log_batch", self._log_bridge.handle_log_batch) + self._rpc_server.register_method("runner.ready", self._handle_runner_ready) async def _handle_bootstrap_plugin(self, envelope: Envelope) -> Envelope: - """处理插件 bootstrap 请求,仅同步能力令牌。""" + """处理插件 bootstrap 请求。 + + Args: + envelope: RPC 请求信封。 + + Returns: + Envelope: RPC 响应信封。 + """ try: - bootstrap = BootstrapPluginPayload.model_validate(envelope.payload) - except Exception as e: - return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(e)) + payload = BootstrapPluginPayload.model_validate(envelope.payload) + except Exception as exc: + return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc)) - active_generation = self._rpc_server.runner_generation - staged_generation = self._rpc_server.staged_generation - if envelope.generation not in {active_generation, staged_generation}: - return envelope.make_error_response( - ErrorCode.E_GENERATION_MISMATCH.value, - f"插件 bootstrap generation 过期: {envelope.generation} 不在已知代际中", - ) - - if bootstrap.capabilities_required: - self._policy.register_plugin( - plugin_id=bootstrap.plugin_id, - generation=envelope.generation, - capabilities=bootstrap.capabilities_required, - ) + if payload.capabilities_required: + self._authorization.register_plugin(payload.plugin_id, payload.capabilities_required) else: - self._policy.revoke_plugin(bootstrap.plugin_id, generation=envelope.generation) + self._authorization.revoke_permission_token(payload.plugin_id) - return envelope.make_response(payload={"accepted": True}) + return envelope.make_response(payload={"accepted": True, "plugin_id": payload.plugin_id}) - async def _handle_register_components(self, envelope: Envelope) -> Envelope: - """处理插件组件注册请求""" + async def _handle_register_plugin(self, envelope: Envelope) -> Envelope: + """处理插件组件注册请求。 + + Args: + envelope: RPC 请求信封。 + + Returns: + Envelope: RPC 响应信封。 + """ try: - reg = RegisterComponentsPayload.model_validate(envelope.payload) - except Exception as e: - return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(e)) + payload = RegisterPluginPayload.model_validate(envelope.payload) + except Exception as exc: + return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc)) - active_generation = self._rpc_server.runner_generation - staged_generation = self._rpc_server.staged_generation - if envelope.generation not in {active_generation, staged_generation}: + component_declarations = [component.model_dump() for component in payload.components] + runtime_components, api_components = self._split_component_declarations(component_declarations) + self._component_registry.remove_components_by_plugin(payload.plugin_id) + self._api_registry.remove_apis_by_plugin(payload.plugin_id) + await self._unregister_all_message_gateway_drivers_for_plugin(payload.plugin_id) + + registered_count = self._component_registry.register_plugin_components( + payload.plugin_id, + runtime_components, + ) + registered_api_count = self._api_registry.register_plugin_apis(payload.plugin_id, api_components) + self._registered_plugins[payload.plugin_id] = payload + self._message_gateway_states[payload.plugin_id] = {} + + return envelope.make_response( + payload={ + "accepted": True, + "plugin_id": payload.plugin_id, + "registered_components": registered_count, + "registered_apis": registered_api_count, + "message_gateways": len( + self._component_registry.get_message_gateways(plugin_id=payload.plugin_id, enabled_only=False) + ), + } + ) + + async def _handle_unregister_plugin(self, envelope: Envelope) -> Envelope: + """处理插件注销请求。 + + Args: + envelope: RPC 请求信封。 + + Returns: + Envelope: RPC 响应信封。 + """ + try: + payload = UnregisterPluginPayload.model_validate(envelope.payload) + except Exception as exc: + return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc)) + + removed_components = self._component_registry.remove_components_by_plugin(payload.plugin_id) + removed_apis = self._api_registry.remove_apis_by_plugin(payload.plugin_id) + self._authorization.revoke_permission_token(payload.plugin_id) + removed_registration = self._registered_plugins.pop(payload.plugin_id, None) is not None + await self._unregister_all_message_gateway_drivers_for_plugin(payload.plugin_id) + self._message_gateway_states.pop(payload.plugin_id, None) + + return envelope.make_response( + payload={ + "accepted": True, + "plugin_id": payload.plugin_id, + "reason": payload.reason, + "removed_components": removed_components, + "removed_apis": removed_apis, + "removed_registration": removed_registration, + } + ) + + @staticmethod + def _is_api_component(component: Dict[str, Any]) -> bool: + """判断组件声明是否属于 API。 + + Args: + component: 原始组件声明字典。 + + Returns: + bool: 是否为 API 组件。 + """ + + return str(component.get("component_type", "") or "").strip().upper() == "API" + + def _split_component_declarations( + self, + components: List[Dict[str, Any]], + ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: + """拆分通用组件声明和 API 声明。 + + Args: + components: Runner 上报的原始组件声明列表。 + + Returns: + Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: + 第一个列表为需要进入通用组件表的声明, + 第二个列表为需要进入 API 专用表的声明。 + """ + + runtime_components: List[Dict[str, Any]] = [] + api_components: List[Dict[str, Any]] = [] + for component in components: + if self._is_api_component(component): + api_components.append(component) + else: + runtime_components.append(component) + return runtime_components, api_components + + @staticmethod + def _build_message_gateway_driver_id(plugin_id: str, gateway_name: str) -> str: + """构造消息网关驱动 ID。 + + Args: + plugin_id: 插件 ID。 + gateway_name: 网关组件名称。 + + Returns: + str: 对应 Platform IO 中的驱动 ID。 + """ + + return f"gateway:{plugin_id}:{gateway_name}" + + @staticmethod + def _normalize_runtime_route_value(value: str) -> Optional[str]: + """规范化运行时路由字段。 + + Args: + value: 待规范化的原始字符串。 + + Returns: + Optional[str]: 规范化后非空则返回字符串,否则返回 ``None``。 + """ + + normalized_value = str(value or "").strip() + return normalized_value or None + + def _resolve_message_gateway_entry( + self, + plugin_id: str, + gateway_name: str, + ) -> Optional[Any]: + """解析指定插件的消息网关组件。 + + Args: + plugin_id: 插件 ID。 + gateway_name: 网关组件名称;为空时按兼容规则推断。 + + Returns: + Optional[Any]: 匹配到的消息网关组件条目。 + """ + + if gateway_name: + return self._component_registry.get_message_gateway( + plugin_id=plugin_id, + name=gateway_name, + enabled_only=False, + ) + + gateways = self._component_registry.get_message_gateways(plugin_id=plugin_id, enabled_only=False) + return gateways[0] if len(gateways) == 1 else None + + async def _register_message_gateway_driver( + self, + plugin_id: str, + gateway_entry: Any, + route_key: RouteKey, + ) -> None: + """为消息网关注册驱动并绑定发送/接收路由。 + + Args: + plugin_id: 插件 ID。 + gateway_entry: 消息网关组件条目。 + route_key: 当前链路对应的路由键。 + """ + + await self._unregister_message_gateway_driver(plugin_id, gateway_entry.name) + + platform_io_manager = get_platform_io_manager() + driver = PluginPlatformDriver( + driver_id=self._build_message_gateway_driver_id(plugin_id, gateway_entry.name), + platform=route_key.platform, + account_id=route_key.account_id, + scope=route_key.scope, + plugin_id=plugin_id, + component_name=gateway_entry.name, + supports_send=bool(gateway_entry.supports_send), + supervisor=self, + metadata={ + "protocol": gateway_entry.protocol, + "route_type": gateway_entry.route_type, + **gateway_entry.metadata, + }, + ) + + try: + if platform_io_manager.is_started: + await platform_io_manager.add_driver(driver) + else: + platform_io_manager.register_driver(driver) + except Exception: + with contextlib.suppress(Exception): + if platform_io_manager.is_started: + await platform_io_manager.remove_driver(driver.driver_id) + else: + platform_io_manager.unregister_driver(driver.driver_id) + raise + + binding_metadata = { + "plugin_id": plugin_id, + "gateway_name": gateway_entry.name, + "protocol": gateway_entry.protocol, + "route_type": gateway_entry.route_type, + **gateway_entry.metadata, + } + binding = RouteBinding( + route_key=route_key, + driver_id=driver.driver_id, + driver_kind=DriverKind.PLUGIN, + metadata=binding_metadata, + ) + if gateway_entry.supports_send: + platform_io_manager.bind_send_route(binding) + if gateway_entry.supports_receive: + platform_io_manager.bind_receive_route(binding) + + async def _unregister_message_gateway_driver(self, plugin_id: str, gateway_name: str) -> None: + """从 Platform IO 注销单个消息网关驱动。 + + Args: + plugin_id: 插件 ID。 + gateway_name: 网关组件名称。 + """ + + platform_io_manager = get_platform_io_manager() + driver_id = self._build_message_gateway_driver_id(plugin_id, gateway_name) + platform_io_manager.send_route_table.remove_bindings_by_driver(driver_id) + platform_io_manager.receive_route_table.remove_bindings_by_driver(driver_id) + + with contextlib.suppress(Exception): + if platform_io_manager.is_started: + await platform_io_manager.remove_driver(driver_id) + else: + platform_io_manager.unregister_driver(driver_id) + + async def _unregister_all_message_gateway_drivers_for_plugin(self, plugin_id: str) -> None: + """注销指定插件的全部消息网关驱动。 + + Args: + plugin_id: 插件 ID。 + """ + + gateway_names = list(self._message_gateway_states.get(plugin_id, {}).keys()) + for gateway_name in gateway_names: + await self._unregister_message_gateway_driver(plugin_id, gateway_name) + + def _build_message_gateway_route_key( + self, + gateway_entry: Any, + payload: MessageGatewayStateUpdatePayload, + ) -> RouteKey: + """根据消息网关运行时状态构造路由键。 + + Args: + gateway_entry: 消息网关组件条目。 + payload: 网关上报的运行时状态。 + + Returns: + RouteKey: 当前链路对应的路由键。 + + Raises: + ValueError: 当平台信息缺失时抛出。 + """ + + if not (platform := str(payload.platform or gateway_entry.platform or "").strip()): + raise ValueError(f"消息网关 {gateway_entry.full_name} 未提供有效的平台名称") + + return RouteKey( + platform=platform, + account_id=self._normalize_runtime_route_value(payload.account_id) or gateway_entry.account_id or None, + scope=self._normalize_runtime_route_value(payload.scope) or gateway_entry.scope or None, + ) + + def _apply_message_gateway_state( + self, + plugin_id: str, + gateway_entry: Any, + payload: MessageGatewayStateUpdatePayload, + ) -> Tuple[_MessageGatewayRuntimeState, Dict[str, Any]]: + """应用消息网关运行时状态,并同步 Platform IO 路由。 + + Args: + plugin_id: 插件 ID。 + gateway_entry: 消息网关组件条目。 + payload: 网关上报的运行时状态。 + + Returns: + Tuple[_MessageGatewayRuntimeState, Dict[str, Any]]: 更新后的状态与路由键字典。 + """ + + plugin_states = self._message_gateway_states.setdefault(plugin_id, {}) + if not payload.ready: + runtime_state = _MessageGatewayRuntimeState( + ready=False, + platform=self._normalize_runtime_route_value(payload.platform) or gateway_entry.platform or None, + account_id=self._normalize_runtime_route_value(payload.account_id) or gateway_entry.account_id or None, + scope=self._normalize_runtime_route_value(payload.scope) or gateway_entry.scope or None, + metadata=dict(payload.metadata), + ) + plugin_states[gateway_entry.name] = runtime_state + return runtime_state, {} + + route_key = self._build_message_gateway_route_key(gateway_entry, payload) + runtime_state = _MessageGatewayRuntimeState( + ready=True, + platform=route_key.platform, + account_id=route_key.account_id, + scope=route_key.scope, + metadata=dict(payload.metadata), + ) + plugin_states[gateway_entry.name] = runtime_state + return runtime_state, { + "platform": route_key.platform, + "account_id": route_key.account_id, + "scope": route_key.scope, + } + + @staticmethod + def _attach_inbound_route_metadata( + session_message: "SessionMessage", + route_key: RouteKey, + route_metadata: Dict[str, Any], + ) -> None: + """将入站路由信息写回消息的 ``additional_config``。 + + Args: + session_message: 已构造好的内部消息对象。 + route_key: Host 为该消息解析出的标准路由键。 + route_metadata: 插件通过 RPC 补充的原始路由辅助元数据。 + """ + + additional_config = session_message.message_info.additional_config + if not isinstance(additional_config, dict): + additional_config = {} + session_message.message_info.additional_config = additional_config + + for key, value in route_metadata.items(): + if value is None: + continue + normalized_value = str(value).strip() + if normalized_value: + additional_config[key] = value + + if route_key.account_id: + additional_config.setdefault("platform_io_account_id", route_key.account_id) + if route_key.scope: + additional_config.setdefault("platform_io_scope", route_key.scope) + + def _build_inbound_route_key( + self, + gateway_entry: Any, + runtime_state: _MessageGatewayRuntimeState, + message: Dict[str, Any], + route_metadata: Dict[str, Any], + ) -> RouteKey: + """为入站消息构造归一路由键。 + + Args: + gateway_entry: 接收消息的网关组件条目。 + runtime_state: 当前网关的运行时状态。 + message: 标准消息字典。 + route_metadata: 插件补充的路由辅助元数据。 + + Returns: + RouteKey: 供 Platform IO 使用的规范化路由键。 + """ + + platform = str( + message.get("platform") + or route_metadata.get("platform") + or runtime_state.platform + or gateway_entry.platform + or "" + ).strip() + if not platform: + raise ValueError(f"消息网关 {gateway_entry.full_name} 的入站消息缺少平台信息") + + try: + route_key = RouteKeyFactory.from_message_dict(message) + except Exception: + route_key = RouteKey(platform=platform) + + route_account_id, route_scope = RouteKeyFactory.extract_components(route_metadata) + account_id = route_key.account_id or route_account_id or runtime_state.account_id or gateway_entry.account_id or None + scope = route_key.scope or route_scope or runtime_state.scope or gateway_entry.scope or None + return RouteKey( + platform=platform, + account_id=account_id, + scope=scope, + ) + + async def _handle_update_message_gateway_state(self, envelope: Envelope) -> Envelope: + """处理消息网关上报的运行时状态更新。 + + Args: + envelope: RPC 请求信封。 + + Returns: + Envelope: 状态更新处理结果。 + """ + + try: + payload = MessageGatewayStateUpdatePayload.model_validate(envelope.payload) + except Exception as exc: + return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc)) + + gateway_entry = self._resolve_message_gateway_entry(envelope.plugin_id, payload.gateway_name) + if gateway_entry is None: return envelope.make_error_response( - ErrorCode.E_GENERATION_MISMATCH.value, - f"组件注册 generation 过期: {envelope.generation} 不在已知代际中", + ErrorCode.E_METHOD_NOT_ALLOWED.value, + f"插件 {envelope.plugin_id} 未声明消息网关 {payload.gateway_name or ''}", ) - if envelope.generation == staged_generation and staged_generation != 0: - self._staged_registered_plugins[reg.plugin_id] = reg - logger.info( - f"插件 {reg.plugin_id} v{reg.plugin_version} staged 注册成功," - f"组件数: {len(reg.components)}, 能力需求: {reg.capabilities_required}" + try: + if payload.ready: + route_key = self._build_message_gateway_route_key(gateway_entry, payload) + await self._register_message_gateway_driver(envelope.plugin_id, gateway_entry, route_key) + else: + await self._unregister_message_gateway_driver(envelope.plugin_id, gateway_entry.name) + runtime_state, route_key_dict = self._apply_message_gateway_state( + plugin_id=envelope.plugin_id, + gateway_entry=gateway_entry, + payload=payload, ) - return envelope.make_response(payload={"accepted": True, "staged": True}) + except Exception as exc: + return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc)) - self._registered_plugins[reg.plugin_id] = reg - - # 在策略引擎中注册插件 - self._policy.register_plugin( - plugin_id=reg.plugin_id, - generation=envelope.generation, - capabilities=reg.capabilities_required or [], + response = MessageGatewayStateUpdateResultPayload( + accepted=True, + ready=runtime_state.ready, + route_key=route_key_dict, ) + return envelope.make_response(payload=response.model_dump()) - # 同 generation 下重新注册时,以本次声明为准,避免残留幽灵组件 - self._component_registry.remove_components_by_plugin(reg.plugin_id) - self._component_registry.register_plugin_components( - plugin_id=reg.plugin_id, - components=[c.model_dump() for c in reg.components], + async def _handle_route_message(self, envelope: Envelope) -> Envelope: + """处理消息网关上报的外部入站消息。 + + Args: + envelope: RPC 请求信封。 + + Returns: + Envelope: 注入结果响应。 + """ + + try: + payload = RouteMessagePayload.model_validate(envelope.payload) + except Exception as exc: + return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc)) + + gateway_entry = self._resolve_message_gateway_entry(envelope.plugin_id, payload.gateway_name) + if gateway_entry is None or not bool(gateway_entry.supports_receive): + return envelope.make_error_response( + ErrorCode.E_METHOD_NOT_ALLOWED.value, + f"插件 {envelope.plugin_id} 未声明可接收的消息网关 {payload.gateway_name}", + ) + + runtime_state = self._message_gateway_states.get(envelope.plugin_id, {}).get( + gateway_entry.name, + _MessageGatewayRuntimeState(), ) + if not runtime_state.ready: + return envelope.make_error_response( + ErrorCode.E_METHOD_NOT_ALLOWED.value, + f"消息网关 {gateway_entry.full_name} 尚未就绪,不能注入外部消息", + ) - stats = self._component_registry.get_stats() - logger.info( - f"插件 {reg.plugin_id} v{reg.plugin_version} 注册成功," - f"组件数: {len(reg.components)}, 能力需求: {reg.capabilities_required}," - f"注册表总计: {stats}" + try: + route_key = self._build_inbound_route_key( + gateway_entry=gateway_entry, + runtime_state=runtime_state, + message=payload.message, + route_metadata=payload.route_metadata, + ) + session_message = self._message_gateway.build_session_message(payload.message) + self._attach_inbound_route_metadata(session_message, route_key, payload.route_metadata) + except Exception as exc: + return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc)) + + platform_io_manager = get_platform_io_manager() + accepted = await platform_io_manager.accept_inbound( + InboundMessageEnvelope( + route_key=route_key, + driver_id=self._build_message_gateway_driver_id(envelope.plugin_id, gateway_entry.name), + driver_kind=DriverKind.PLUGIN, + external_message_id=payload.external_message_id or str(payload.message.get("message_id") or "") or None, + dedupe_key=payload.dedupe_key or None, + session_message=session_message, + payload=payload.message, + metadata={ + "plugin_id": envelope.plugin_id, + "gateway_name": gateway_entry.name, + "protocol": gateway_entry.protocol, + **payload.route_metadata, + }, + ) ) - - return envelope.make_response(payload={"accepted": True}) + response = ReceiveExternalMessageResultPayload( + accepted=accepted, + route_key={ + "platform": route_key.platform, + "account_id": route_key.account_id, + "scope": route_key.scope, + }, + ) + return envelope.make_response(payload=response.model_dump()) async def _handle_runner_ready(self, envelope: Envelope) -> Envelope: - """处理 Runner 初始化完成信号。""" - try: - ready = RunnerReadyPayload.model_validate(envelope.payload) - except Exception as e: - return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(e)) + """处理 Runner 就绪通知。 - event = self._runner_ready_events.setdefault(envelope.generation, asyncio.Event()) - self._runner_ready_payloads[envelope.generation] = ready - event.set() - logger.info( - f"Runner generation={envelope.generation} 已就绪,成功插件数: {len(ready.loaded_plugins)}," - f"失败插件数: {len(ready.failed_plugins)}" - ) + Args: + envelope: RPC 请求信封。 + + Returns: + Envelope: RPC 响应信封。 + """ + try: + payload = RunnerReadyPayload.model_validate(envelope.payload) + except Exception as exc: + return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc)) + + self._runner_ready_payloads = payload + self._runner_ready_events.set() return envelope.make_response(payload={"accepted": True}) + def _build_runner_environment(self) -> Dict[str, str]: + """构建拉起 Runner 所需的环境变量。 + + Returns: + Dict[str, str]: 传递给 Runner 进程的环境变量映射。 + """ + global_config_snapshot = config_manager.get_global_config().model_dump(mode="json") + global_config_snapshot["model"] = config_manager.get_model_config().model_dump(mode="json") + return { + ENV_EXTERNAL_PLUGIN_IDS: json.dumps(self._external_available_plugins, ensure_ascii=False), + ENV_GLOBAL_CONFIG_SNAPSHOT: json.dumps(global_config_snapshot, ensure_ascii=False), + ENV_HOST_VERSION: PROTOCOL_VERSION, + ENV_IPC_ADDRESS: self._transport.get_address(), + ENV_PLUGIN_DIRS: os.pathsep.join(str(path) for path in self._plugin_dirs), + ENV_SESSION_TOKEN: self._rpc_server.session_token, + } + async def _spawn_runner(self) -> None: - """拉起 Runner 子进程""" - runner_module = "src.plugin_runtime.runner.runner_main" - address = self._transport.get_address() - token = self._rpc_server.session_token + """拉起 Runner 子进程。""" + if self._runner_process is not None and self._runner_process.returncode is None: + logger.warning("Runner 已在运行,跳过重复拉起") + return + + self._clear_runner_state() env = os.environ.copy() - env[ENV_IPC_ADDRESS] = address - env[ENV_SESSION_TOKEN] = token - env[ENV_PLUGIN_DIRS] = os.pathsep.join(str(p) for p in self._plugin_dirs) - env[ENV_HOST_VERSION] = MMC_VERSION + env.update(self._build_runner_environment()) self._runner_process = await asyncio.create_subprocess_exec( sys.executable, "-m", - runner_module, + "src.plugin_runtime.runner.runner_main", env=env, - # stdout 不捕获:Runner 的日志均通过 IPC 传㛹(RunnerIPCLogHandler) - stdout=None, - # stderr 捕获为 PIPE,仅用于 IPC 建立前的进程级致命错误输出 + stdout=asyncio.subprocess.DEVNULL, stderr=asyncio.subprocess.PIPE, ) - self._attach_stderr_drain(self._runner_process) - self._runner_generation = self._rpc_server.runner_generation - logger.info(f"Runner 子进程已启动: pid={self._runner_process.pid}, generation={self._runner_generation}") - - async def _shutdown_runner(self) -> None: - """优雅关停 Runner""" - if not self._runner_process or self._runner_process.returncode is not None: - return - - # 发送 prepare_shutdown - try: - if self._rpc_server.is_connected: - shutdown_payload = ShutdownPayload(reason="host_shutdown", drain_timeout_ms=5000) - await self._rpc_server.send_request( - "plugin.prepare_shutdown", - payload=shutdown_payload.model_dump(), - timeout_ms=5000, - ) - await self._rpc_server.send_request( - "plugin.shutdown", - payload=shutdown_payload.model_dump(), - timeout_ms=5000, - ) - except Exception as e: - logger.warning(f"发送关停命令失败: {e}") - - # 等待进程退出 - try: - await asyncio.wait_for(self._runner_process.wait(), timeout=10.0) - except asyncio.TimeoutError: - logger.warning("Runner 未在超时内退出,强制终止") - self._runner_process.kill() - await self._runner_process.wait() - - await self._cleanup_stderr_drain() - - async def _health_check_loop(self) -> None: - """周期性健康检查 + 崩溃自动重启""" - while self._running: - await asyncio.sleep(self._health_interval) - - # 检查 Runner 进程是否意外退出 - if self._runner_process and self._runner_process.returncode is not None: - exit_code = self._runner_process.returncode - logger.warning(f"Runner 进程已退出 (exit_code={exit_code})") - - if self._restart_count < self._max_restart_attempts: - self._restart_count += 1 - logger.info(f"尝试重启 Runner ({self._restart_count}/{self._max_restart_attempts})") - # 清理旧的组件注册 - for plugin_id in list(self._registered_plugins.keys()): - self._component_registry.remove_components_by_plugin(plugin_id) - self._policy.revoke_plugin(plugin_id) - self._registered_plugins.clear() - - try: - self._clear_runtime_state() - # 重新生成 session token,防止旧 Runner 僵尸进程用旧 token 重连 - self._rpc_server.reset_session_token() - await self._spawn_runner() - except Exception as e: - logger.error(f"Runner 重启失败: {e}", exc_info=True) - else: - logger.error(f"Runner 连续崩溃 {self._max_restart_attempts} 次,停止重启") - continue - - if not self._rpc_server.is_connected: - logger.warning("Runner 未连接,跳过健康检查") - continue - - try: - resp = await self._rpc_server.send_request("plugin.health", timeout_ms=5000) - health = HealthPayload.model_validate(resp.payload) - if not health.healthy: - logger.warning(f"Runner 健康检查异常: {health}") - else: - # 健康检查成功,重置重启计数 - self._restart_count = 0 - except RPCError as e: - logger.error(f"健康检查失败: {e}") - except asyncio.CancelledError: - break - except Exception as e: - logger.error(f"健康检查异常: {e}") - - async def _wait_for_runner_generation( - self, - expected_generation: int, - timeout_sec: float, - allow_staged: bool = False, - ) -> None: - """等待指定代际的 Runner 完成连接。""" - deadline = asyncio.get_running_loop().time() + timeout_sec - while asyncio.get_running_loop().time() < deadline: - if allow_staged and self._rpc_server.has_generation(expected_generation): - return - if self._rpc_server.is_connected and self._rpc_server.runner_generation >= expected_generation: - self._runner_generation = self._rpc_server.runner_generation - return - await asyncio.sleep(0.1) - raise TimeoutError(f"等待 Runner generation {expected_generation} 超时") - - async def _wait_for_runner_ready(self, expected_generation: int, timeout_sec: float) -> RunnerReadyPayload: - """等待指定代际的 Runner 完成初始化。""" - event = self._runner_ready_events.setdefault(expected_generation, asyncio.Event()) - await asyncio.wait_for(event.wait(), timeout=timeout_sec) - return self._runner_ready_payloads.get(expected_generation, RunnerReadyPayload()) - - def _clear_runtime_state(self) -> None: - """清空当前插件注册态。""" - self._component_registry.clear() - self._policy.clear() - self._registered_plugins.clear() - self._staged_registered_plugins.clear() - - def _rebuild_runtime_state(self) -> None: - """根据已记录的插件注册信息重建运行时状态。""" - self._component_registry.clear() - self._policy.clear() - for reg in self._registered_plugins.values(): - self._policy.register_plugin( - plugin_id=reg.plugin_id, - generation=self._rpc_server.runner_generation, - capabilities=reg.capabilities_required or [], - ) - self._component_registry.register_plugin_components( - plugin_id=reg.plugin_id, - components=[c.model_dump() for c in reg.components], + if self._runner_process.stderr is not None: + self._stderr_drain_task = asyncio.create_task( + self._drain_runner_stderr(self._runner_process.stderr), + name="PluginRunnerSupervisor.stderr", ) - def _attach_stderr_drain(self, process: asyncio.subprocess.Process) -> None: - """为 Runner stderr 创建排空任务,捕获 IPC 建立前的进程级错误输出。 + logger.info(f"Runner 已拉起,pid={self._runner_process.pid}") - stderr 中的内容通常是: - - Runner 启动早期(握手完成之前)的日志 - - 进程级致命错误(ImportError、SyntaxError等) - - 异常进程退出前的最后输出 - - 握手成功后,插件的所有日志均经由 RunnerIPCLogHandler 通过 IPC 传输。 - """ - if process.stderr is None: - return - task = asyncio.create_task( - self._drain_runner_stderr(process.stderr, process.pid), - name=f"runner_stderr_drain:{process.pid}", - ) - self._stderr_drain_task = task - task.add_done_callback( - lambda done_task: None if self._stderr_drain_task is not done_task else self._clear_stderr_drain_task() - ) - - def _clear_stderr_drain_task(self) -> None: - self._stderr_drain_task = None - - async def _drain_runner_stderr( - self, - stream: asyncio.StreamReader, - pid: int, - ) -> None: - """持续读取 Runner stderr 并转发到 Host Logger,防止 PIPE 锡死子进程。 + async def _drain_runner_stderr(self, stream: asyncio.StreamReader) -> None: + """持续排空 Runner 的 stderr。 Args: - stream: Runner 子进程的 stderr 流。 - pid: 子进程 PID,仅用于日志上下文。 + stream: Runner 的 stderr 流。 """ try: while True: line = await stream.readline() if not line: - break - if message := line.decode(errors="replace").rstrip(): - # 将 stderr 输出以 WARNING 级展示: - # 如果 Runner 正常运行,此流应当无输出; - # 有输出说明进程级错误发生,需要出现在主进程日志中 - logger.warning(f"[runner:{pid}:stderr] {message}") + return + if message := line.decode("utf-8", errors="replace").rstrip(): + logger.warning(f"[runner-stderr] {message}") except asyncio.CancelledError: raise except Exception as exc: - logger.debug(f"读取 Runner stderr 失败 (pid={pid}): {exc}") + logger.warning(f"排空 Runner stderr 失败: {exc}") - async def _cleanup_stderr_drain(self) -> None: - """等待并取消 stderr 排空任务。""" - if self._stderr_drain_task is None: - return - task = self._stderr_drain_task - self._stderr_drain_task = None - if not task.done(): - task.cancel() - with contextlib.suppress(Exception): - await asyncio.gather(task, return_exceptions=True) + async def _shutdown_runner(self, reason: str = "normal") -> None: + """优雅关闭 Runner 子进程。 - @staticmethod - async def _terminate_process( - process: Optional[asyncio.subprocess.Process], - keep_process: Optional[asyncio.subprocess.Process] = None, - ) -> None: - """终止指定进程,但跳过需要保留的旧进程引用。""" - if process is None or process is keep_process or process.returncode is not None: + Args: + reason: 关停原因。 + """ + process = self._runner_process + if process is None: return - process.terminate() + payload = ShutdownPayload(reason=reason) + + if process.returncode is None and self._rpc_server.is_connected: + with contextlib.suppress(Exception): + await self._rpc_server.send_request( + "plugin.prepare_shutdown", + payload=payload.model_dump(), + timeout_ms=payload.drain_timeout_ms, + ) + with contextlib.suppress(Exception): + await self._rpc_server.send_request( + "plugin.shutdown", + payload=payload.model_dump(), + timeout_ms=payload.drain_timeout_ms, + ) + + if process.returncode is None: + try: + await asyncio.wait_for(process.wait(), timeout=max(payload.drain_timeout_ms / 1000.0, 1.0)) + except asyncio.TimeoutError: + logger.warning("Runner 优雅退出超时,尝试 terminate") + process.terminate() + try: + await asyncio.wait_for(process.wait(), timeout=5.0) + except asyncio.TimeoutError: + logger.warning("Runner terminate 超时,尝试 kill") + process.kill() + with contextlib.suppress(Exception): + await asyncio.wait_for(process.wait(), timeout=5.0) + + self._runner_process = None + + if self._stderr_drain_task is not None: + self._stderr_drain_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await self._stderr_drain_task + self._stderr_drain_task = None + + for plugin_id in list(self._message_gateway_states.keys()): + await self._unregister_all_message_gateway_drivers_for_plugin(plugin_id) + self._clear_runner_state() + + async def _health_check_loop(self) -> None: + """周期性检查 Runner 健康状态,并在必要时重启。""" + timeout_ms = max(int(self._health_interval * 1000), 1000) + + while self._running: + try: + await asyncio.sleep(self._health_interval) + except asyncio.CancelledError: + return + + if not self._running: + return + + process = self._runner_process + if process is None or process.returncode is not None: + reason = "runner_process_exited" if process is not None else "runner_process_missing" + restarted = await self._restart_runner(reason=reason) + if not restarted: + return + continue + + try: + response = await self._rpc_server.send_request("plugin.health", timeout_ms=timeout_ms) + health = HealthPayload.model_validate(response.payload) + if not health.healthy: + restarted = await self._restart_runner(reason="health_check_unhealthy") + if not restarted: + return + except asyncio.CancelledError: + return + except (RPCError, Exception) as exc: + logger.warning(f"Runner 健康检查失败: {exc}") + restarted = await self._restart_runner(reason="health_check_failed") + if not restarted: + return + + async def _restart_runner(self, reason: str) -> bool: + """在 Runner 异常时执行整进程级重启。 + + Args: + reason: 触发重启的原因。 + + Returns: + bool: 是否重启成功。 + """ + if not self._running: + return False + + if self._restart_count >= self._max_restart_attempts: + logger.error(f"Runner 自动重启次数已达上限,停止重启。reason={reason}") + return False + + self._restart_count += 1 + logger.warning(f"准备重启 Runner,第 {self._restart_count} 次,reason={reason}") + + await self._shutdown_runner(reason=reason) + try: - await asyncio.wait_for(process.wait(), timeout=10.0) - except asyncio.TimeoutError: - process.kill() - await process.wait() + await self._spawn_runner() + await self._wait_for_runner_connection(timeout_sec=self._runner_spawn_timeout) + await self._wait_for_runner_ready(timeout_sec=self._runner_spawn_timeout) + except Exception as exc: + await self._shutdown_runner(reason="restart_failed") + logger.error(f"Runner 重启失败: {exc}", exc_info=True) + return False + + self._restart_count = 0 + logger.info("Runner 已成功重启") + return True + + def _clear_runner_state(self) -> None: + """清理当前 Runner 对应的 Host 侧注册状态。""" + self._authorization.clear() + self._api_registry.clear() + self._component_registry.clear() + self._registered_plugins.clear() + self._message_gateway_states.clear() + self._runner_ready_events = asyncio.Event() + self._runner_ready_payloads = RunnerReadyPayload() + self._rpc_server.clear_handshake_state() + + def _get_runner_startup_failure_reason(self) -> Optional[str]: + """获取 Runner 在启动阶段已经暴露出的失败原因。 + + Returns: + Optional[str]: 若已检测到失败则返回失败原因,否则返回 ``None``。 + """ + if handshake_reason := self._rpc_server.last_handshake_rejection_reason: + return f"握手被拒绝: {handshake_reason}" + + process = self._runner_process + if process is None: + return "Runner 进程不存在" + + if process.returncode is not None: + return f"Runner 进程已退出,退出码 {process.returncode}" + + return None + + +PluginSupervisor = PluginRunnerSupervisor diff --git a/src/plugin_runtime/host/workflow_executor.py b/src/plugin_runtime/host/workflow_executor.py deleted file mode 100644 index 3037e9dd..00000000 --- a/src/plugin_runtime/host/workflow_executor.py +++ /dev/null @@ -1,422 +0,0 @@ -"""Host-side WorkflowExecutor - -6 阶段线性流转(INGRESS → PRE_PROCESS → PLAN → TOOL_EXECUTE → POST_PROCESS → EGRESS) - -每个阶段执行顺序: -1. Host-side pre-filter: 根据 hook filter 条件过滤不相关的 hook -2. 按 priority 降序排列 -3. 串行执行 blocking hook(可修改 message,返回 HookResult) -4. 并发执行 non-blocking hook(只读) -5. 检查是否有 SKIP_STAGE 或 ABORT -6. PLAN 阶段内置 Command 匹配路由 - -支持: -- HookResult: CONTINUE / SKIP_STAGE / ABORT -- ErrorPolicy: ABORT / SKIP / LOG (per-hook) -- stage_outputs: 阶段间带命名空间的数据传递 -- modification_log: 消息修改审计 -""" - -from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple - -import asyncio -import time -import uuid - -from src.common.logger import get_logger -from src.config.config import global_config -from src.plugin_runtime.host.component_registry import ComponentRegistry, RegisteredComponent - -logger = get_logger("plugin_runtime.host.workflow_executor") - -# 阶段顺序 -STAGE_SEQUENCE: List[str] = [ - "ingress", - "pre_process", - "plan", - "tool_execute", - "post_process", - "egress", -] - -# HookResult 常量(与 SDK HookResult enum 值对应) -HOOK_CONTINUE = "continue" -HOOK_SKIP_STAGE = "skip_stage" -HOOK_ABORT = "abort" - - -# blocking hook 全局最大超时(秒):即使 hook 声明 timeout_ms=0 也不会无限等待 -# 从配置文件读取,允许用户调整 -def _get_blocking_timeout() -> float: - return global_config.plugin_runtime.workflow_blocking_timeout_sec - - -class ModificationRecord: - """消息修改记录""" - - __slots__ = ("stage", "hook_name", "timestamp", "fields_changed") - - def __init__(self, stage: str, hook_name: str, fields_changed: List[str]) -> None: - self.stage = stage - self.hook_name = hook_name - self.timestamp = time.perf_counter() - self.fields_changed = fields_changed - - -class WorkflowContext: - """Workflow 执行上下文""" - - def __init__(self, trace_id: Optional[str] = None, stream_id: Optional[str] = None) -> None: - self.trace_id = trace_id or uuid.uuid4().hex - self.stream_id = stream_id - self.timings: Dict[str, float] = {} - self.errors: List[str] = [] - # 阶段间数据传递(按 stage 命名空间隔离) - self.stage_outputs: Dict[str, Dict[str, Any]] = {} - # 消息修改审计日志 - self.modification_log: List[ModificationRecord] = [] - # PLAN 阶段命令匹配结果 - self.matched_command: Optional[str] = None - - def set_stage_output(self, stage: str, key: str, value: Any) -> None: - self.stage_outputs.setdefault(stage, {})[key] = value - - def get_stage_output(self, stage: str, key: str, default: Any = None) -> Any: - return self.stage_outputs.get(stage, {}).get(key, default) - - -class WorkflowResult: - """Workflow 执行结果""" - - def __init__( - self, - status: str = "completed", # completed / aborted / failed - return_message: str = "", - stopped_at: str = "", - diagnostics: Optional[Dict[str, Any]] = None, - ) -> None: - self.status = status - self.return_message = return_message - self.stopped_at = stopped_at - self.diagnostics = diagnostics or {} - - -# invoke_fn 签名 -InvokeFn = Callable[[str, str, Dict[str, Any]], Awaitable[Dict[str, Any]]] - - -class WorkflowExecutor: - """Host-side Workflow 执行器 - - 实现 stage-based pipeline + per-stage hook chain with priority + early return。 - """ - - def __init__(self, registry: ComponentRegistry) -> None: - self._registry = registry - self._background_tasks: Set[asyncio.Task] = set() - - async def execute( - self, - invoke_fn: InvokeFn, - message: Optional[Dict[str, Any]] = None, - stream_id: Optional[str] = None, - context: Optional[WorkflowContext] = None, - command_invoke_fn: Optional[InvokeFn] = None, - ) -> Tuple[WorkflowResult, Optional[Dict[str, Any]], WorkflowContext]: - """执行 workflow pipeline。 - - Args: - invoke_fn: 用于 workflow_step 的回调 - command_invoke_fn: 用于 command 的回调(走 plugin.invoke_command), - 未传则复用 invoke_fn - - Returns: - (result, final_message, context) - """ - ctx = context or WorkflowContext(stream_id=stream_id) - current_message = dict(message) if message else None - - for stage in STAGE_SEQUENCE: - stage_start = time.perf_counter() - - try: - # PLAN 阶段: 先做 Command 路由 - if stage == "plan" and current_message: - cmd_result = await self._route_command(command_invoke_fn or invoke_fn, current_message, ctx) - if cmd_result is not None: - # 命令匹配成功,跳过 PLAN 阶段的 hook,直接存结果进 stage_outputs - ctx.set_stage_output("plan", "command_result", cmd_result) - ctx.timings[stage] = time.perf_counter() - stage_start - continue - - # 获取该阶段所有 hook(已按 priority 降序排列) - all_steps = self._registry.get_workflow_steps(stage) - if not all_steps: - ctx.timings[stage] = time.perf_counter() - stage_start - continue - - # 1. Pre-filter - filtered_steps = self._pre_filter(all_steps, current_message) - - # 2. 分离 blocking 和 non-blocking - blocking_steps = [s for s in filtered_steps if s.metadata.get("blocking", True)] - nonblocking_steps = [s for s in filtered_steps if not s.metadata.get("blocking", True)] - - # 3. 串行执行 blocking hook - skip_stage = False - for step in blocking_steps: - hook_result, modified, step_error = await self._invoke_step( - invoke_fn, step, stage, ctx, current_message - ) - - if step_error: - error_policy = step.metadata.get("error_policy", "abort") - ctx.errors.append(f"{step.full_name}: {step_error}") - - if error_policy == "abort": - ctx.timings[stage] = time.perf_counter() - stage_start - return ( - WorkflowResult( - status="failed", - return_message=step_error, - stopped_at=stage, - diagnostics={"step": step.full_name, "trace_id": ctx.trace_id}, - ), - current_message, - ctx, - ) - elif error_policy == "skip": - logger.warning(f"[{ctx.trace_id}] hook {step.full_name} 异常(skip): {step_error}") - continue - else: # log - logger.warning(f"[{ctx.trace_id}] hook {step.full_name} 异常(log): {step_error}") - continue - - # 更新消息(仅 blocking hook 有权修改) - if modified: - changed_fields = ( - _diff_keys(current_message, modified) if current_message else list(modified.keys()) - ) - ctx.modification_log.append(ModificationRecord(stage, step.full_name, changed_fields)) - current_message = modified - - if hook_result == HOOK_ABORT: - ctx.timings[stage] = time.perf_counter() - stage_start - return ( - WorkflowResult( - status="aborted", - return_message=f"aborted by {step.full_name}", - stopped_at=stage, - diagnostics={"step": step.full_name, "trace_id": ctx.trace_id}, - ), - current_message, - ctx, - ) - - if hook_result == HOOK_SKIP_STAGE: - skip_stage = True - break - - # 4. 并发执行 non-blocking hook(只读,忽略返回值中的 modified_message) - if nonblocking_steps and not skip_stage: - for step in nonblocking_steps: - self._track_background_task( - asyncio.create_task( - self._invoke_step_fire_and_forget(invoke_fn, step, stage, ctx, current_message) - ) - ) - - ctx.timings[stage] = time.perf_counter() - stage_start - - except Exception as e: - ctx.timings[stage] = time.perf_counter() - stage_start - ctx.errors.append(f"{stage}: {e}") - logger.error(f"[{ctx.trace_id}] 阶段 {stage} 未捕获异常: {e}", exc_info=True) - return ( - WorkflowResult( - status="failed", - return_message=str(e), - stopped_at=stage, - diagnostics={"trace_id": ctx.trace_id}, - ), - current_message, - ctx, - ) - - return ( - WorkflowResult( - status="completed", - return_message="workflow completed", - diagnostics={"trace_id": ctx.trace_id}, - ), - current_message, - ctx, - ) - - def _track_background_task(self, task: asyncio.Task) -> None: - """保持 non-blocking workflow task 的强引用,直到任务结束。""" - self._background_tasks.add(task) - task.add_done_callback(self._background_tasks.discard) - - # ─── 内部方法 ────────────────────────────────────────────── - - def _pre_filter( - self, - steps: List[RegisteredComponent], - message: Optional[Dict[str, Any]], - ) -> List[RegisteredComponent]: - """根据 hook 声明的 filter 条件预过滤,避免无意义的 IPC 调用。""" - if not message: - return steps - - result = [] - for step in steps: - filter_cond = step.metadata.get("filter", {}) - if not filter_cond: - result.append(step) - continue - if self._match_filter(filter_cond, message): - result.append(step) - return result - - @staticmethod - def _match_filter(filter_cond: Dict[str, Any], message: Dict[str, Any]) -> bool: - """简单 key-value 匹配过滤。 - - filter 中的每个 key 必须在 message 中存在且值相等, - 全部匹配才通过。 - """ - for key, expected in filter_cond.items(): - actual = message.get(key) - if (isinstance(expected, list) and actual not in expected) or ( - not isinstance(expected, list) and actual != expected - ): - return False - return True - - async def _invoke_step( - self, - invoke_fn: InvokeFn, - step: RegisteredComponent, - stage: str, - ctx: WorkflowContext, - message: Optional[Dict[str, Any]], - ) -> Tuple[str, Optional[Dict[str, Any]], Optional[str]]: - """调用单个 blocking hook。 - - Returns: - (hook_result, modified_message, error_string_or_None) - """ - timeout_ms = step.metadata.get("timeout_ms", 0) - # 使用 hook 声明的超时,但不超过全局安全阀 - timeout_sec = timeout_ms / 1000 if timeout_ms > 0 else _get_blocking_timeout() - step_key = f"{stage}:{step.full_name}" - step_start = time.perf_counter() - - try: - coro = invoke_fn( - step.plugin_id, - step.name, - { - "stage": stage, - "trace_id": ctx.trace_id, - "message": message, - "stage_outputs": ctx.stage_outputs, - }, - ) - resp = await asyncio.wait_for(coro, timeout=timeout_sec) - ctx.timings[step_key] = time.perf_counter() - step_start - - hook_result = resp.get("hook_result", HOOK_CONTINUE) - modified_message = resp.get("modified_message") - # 存 stage output(如果 hook 提供了) - stage_out = resp.get("stage_output") - if isinstance(stage_out, dict): - for k, v in stage_out.items(): - ctx.set_stage_output(stage, k, v) - - return hook_result, modified_message, None - - except asyncio.TimeoutError: - ctx.timings[step_key] = time.perf_counter() - step_start - return HOOK_CONTINUE, None, f"timeout after {timeout_ms}ms" - - except Exception as e: - ctx.timings[step_key] = time.perf_counter() - step_start - return HOOK_CONTINUE, None, str(e) - - async def _invoke_step_fire_and_forget( - self, - invoke_fn: InvokeFn, - step: RegisteredComponent, - stage: str, - ctx: WorkflowContext, - message: Optional[Dict[str, Any]], - ) -> None: - """Non-blocking hook 调用,只读,忽略结果。""" - timeout_ms = step.metadata.get("timeout_ms", 0) - # 使用 hook 声明的超时,但无声明时回退到全局安全阀,防止 task 泄漏 - timeout_sec = timeout_ms / 1000 if timeout_ms > 0 else _get_blocking_timeout() - - try: - coro = invoke_fn( - step.plugin_id, - step.name, - { - "stage": stage, - "trace_id": ctx.trace_id, - "message": message, - "stage_outputs": ctx.stage_outputs, - }, - ) - await asyncio.wait_for(coro, timeout=timeout_sec) - except asyncio.TimeoutError: - logger.warning(f"[{ctx.trace_id}] non-blocking hook {step.full_name} 超时 ({timeout_sec}s)") - except Exception as e: - logger.debug(f"[{ctx.trace_id}] non-blocking hook {step.full_name}: {e}") - - async def _route_command( - self, - invoke_fn: InvokeFn, - message: Dict[str, Any], - ctx: WorkflowContext, - ) -> Optional[Dict[str, Any]]: - """PLAN 阶段内置 Command 路由。 - - 在 registry 中查找匹配的 command 组件, - 匹配到则直接路由到对应 command handler,返回执行结果。 - 不匹配则返回 None,让 PLAN 阶段的 hook 继续执行。 - """ - plain_text = message.get("plain_text", "") - if not plain_text: - return None - - match_result = self._registry.find_command_by_text(plain_text) - if match_result is None: - return None - - matched, matched_groups = match_result - - ctx.matched_command = matched.full_name - logger.info(f"[{ctx.trace_id}] 命令匹配: {matched.full_name}") - - try: - return await invoke_fn( - matched.plugin_id, - matched.name, - { - "text": plain_text, - "message": message, - "trace_id": ctx.trace_id, - "matched_groups": matched_groups, - }, - ) - except Exception as e: - logger.error(f"[{ctx.trace_id}] 命令 {matched.full_name} 执行失败: {e}", exc_info=True) - ctx.errors.append(f"command:{matched.full_name}: {e}") - return None - - -def _diff_keys(old: Dict[str, Any], new: Dict[str, Any]) -> List[str]: - """返回 new 中与 old 不同的 key 列表。""" - return [k for k, v in new.items() if k not in old or old[k] != v] diff --git a/src/plugin_runtime/integration.py b/src/plugin_runtime/integration.py index 04c8e324..c34f5ef5 100644 --- a/src/plugin_runtime/integration.py +++ b/src/plugin_runtime/integration.py @@ -8,23 +8,27 @@ """ from pathlib import Path -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Coroutine, Dict, Iterable, List, Optional, Sequence, Tuple +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Coroutine, Dict, Iterable, List, Optional, Sequence, Set, Tuple import asyncio -import json + import tomlkit from src.common.logger import get_logger -from src.config.config import global_config +from src.config.config import config_manager from src.config.file_watcher import FileChange, FileWatcher +from src.platform_io import DeliveryBatch, InboundMessageEnvelope, get_platform_io_manager from src.plugin_runtime.capabilities import ( RuntimeComponentCapabilityMixin, RuntimeCoreCapabilityMixin, RuntimeDataCapabilityMixin, ) from src.plugin_runtime.capabilities.registry import register_capability_impls +from src.plugin_runtime.host.message_utils import MessageDict, PluginMessageUtils +from src.plugin_runtime.runner.manifest_validator import ManifestValidator if TYPE_CHECKING: + from src.chat.message_receive.message import SessionMessage from src.plugin_runtime.host.supervisor import PluginSupervisor logger = get_logger("plugin_runtime.integration") @@ -55,6 +59,7 @@ class PluginRuntimeManager( """ def __init__(self) -> None: + """初始化插件运行时管理器。""" from src.plugin_runtime.host.supervisor import PluginSupervisor self._builtin_supervisor: Optional[PluginSupervisor] = None @@ -63,6 +68,26 @@ class PluginRuntimeManager( self._plugin_file_watcher: Optional[FileWatcher] = None self._plugin_source_watcher_subscription_id: Optional[str] = None self._plugin_config_watcher_subscriptions: Dict[str, Tuple[Path, str]] = {} + self._plugin_path_cache: Dict[str, Path] = {} + self._manifest_validator: ManifestValidator = ManifestValidator() + self._config_reload_callback: Callable[[Sequence[str]], Awaitable[None]] = self._handle_main_config_reload + self._config_reload_callback_registered: bool = False + + async def _dispatch_platform_inbound(self, envelope: InboundMessageEnvelope) -> None: + """接收 Platform IO 审核后的入站消息并送入主消息链。 + + Args: + envelope: Platform IO 产出的入站封装。 + """ + session_message = envelope.session_message + if session_message is None and envelope.payload is not None: + session_message = PluginMessageUtils._build_session_message_from_dict(dict(envelope.payload)) + if session_message is None: + raise ValueError("Platform IO 入站封装缺少可用的 SessionMessage 或 payload") + + from src.chat.message_receive.bot import chat_bot + + await chat_bot.receive_message(session_message) # ─── 插件目录 ───────────────────────────────────────────── @@ -78,6 +103,42 @@ class PluginRuntimeManager( candidate = Path("plugins").resolve() return [candidate] if candidate.is_dir() else [] + @classmethod + def _discover_plugin_dependency_map(cls, plugin_dirs: Iterable[Path]) -> Dict[str, List[str]]: + """扫描指定插件目录集合,返回 ``plugin_id -> dependencies`` 映射。""" + validator = ManifestValidator() + return validator.build_plugin_dependency_map(plugin_dirs) + + @classmethod + def _build_group_start_order( + cls, + builtin_dirs: Sequence[Path], + third_party_dirs: Sequence[Path], + ) -> List[str]: + """根据跨 Supervisor 依赖关系决定 Runner 启动顺序。""" + + builtin_dependencies = cls._discover_plugin_dependency_map(builtin_dirs) + third_party_dependencies = cls._discover_plugin_dependency_map(third_party_dirs) + builtin_plugin_ids = set(builtin_dependencies) + third_party_plugin_ids = set(third_party_dependencies) + + builtin_needs_third_party = any( + dependency in third_party_plugin_ids + for dependencies in builtin_dependencies.values() + for dependency in dependencies + ) + third_party_needs_builtin = any( + dependency in builtin_plugin_ids + for dependencies in third_party_dependencies.values() + for dependency in dependencies + ) + + if builtin_needs_third_party and third_party_needs_builtin: + raise RuntimeError("检测到跨 Supervisor 循环依赖,当前无法安全启动独立 Runner") + if builtin_needs_third_party: + return ["third_party", "builtin"] + return ["builtin", "third_party"] + # ─── 生命周期 ───────────────────────────────────────────── async def start(self) -> None: @@ -86,7 +147,7 @@ class PluginRuntimeManager( logger.warning("PluginRuntimeManager 已在运行中,跳过重复启动") return - _cfg = global_config.plugin_runtime + _cfg = config_manager.get_global_config().plugin_runtime if not _cfg.enabled: logger.info("插件运行时已在配置中禁用,跳过启动") return @@ -108,6 +169,8 @@ class PluginRuntimeManager( logger.info("未找到任何插件目录,跳过插件运行时启动") return + platform_io_manager = get_platform_io_manager() + # 从配置读取自定义 IPC socket 路径(留空则自动生成) socket_path_base = _cfg.ipc_socket_path or None @@ -132,19 +195,46 @@ class PluginRuntimeManager( started_supervisors: List[PluginSupervisor] = [] try: - if self._builtin_supervisor: - await self._builtin_supervisor.start() - started_supervisors.append(self._builtin_supervisor) - if self._third_party_supervisor: - await self._third_party_supervisor.start() - started_supervisors.append(self._third_party_supervisor) + platform_io_manager.set_inbound_dispatcher(self._dispatch_platform_inbound) + await platform_io_manager.ensure_send_pipeline_ready() + + supervisor_groups: Dict[str, Optional[PluginSupervisor]] = { + "builtin": self._builtin_supervisor, + "third_party": self._third_party_supervisor, + } + start_order = self._build_group_start_order(builtin_dirs, third_party_dirs) + + for group_name in start_order: + supervisor = supervisor_groups.get(group_name) + if supervisor is None: + continue + + external_plugin_versions = { + plugin_id: plugin_version + for started_supervisor in started_supervisors + for plugin_id, plugin_version in started_supervisor.get_loaded_plugin_versions().items() + } + supervisor.set_external_available_plugins(external_plugin_versions) + await supervisor.start() + started_supervisors.append(supervisor) + await self._start_plugin_file_watcher() + config_manager.register_reload_callback(self._config_reload_callback) + self._config_reload_callback_registered = True self._started = True logger.info(f"插件运行时已启动 — 内置: {builtin_dirs or '无'}, 第三方: {third_party_dirs or '无'}") except Exception as e: logger.error(f"插件运行时启动失败: {e}", exc_info=True) await self._stop_plugin_file_watcher() + if self._config_reload_callback_registered: + config_manager.unregister_reload_callback(self._config_reload_callback) + self._config_reload_callback_registered = False await asyncio.gather(*(sv.stop() for sv in started_supervisors), return_exceptions=True) + platform_io_manager.clear_inbound_dispatcher() + try: + await platform_io_manager.stop() + except Exception as platform_io_exc: + logger.warning(f"Platform IO 停止失败: {platform_io_exc}") self._started = False self._builtin_supervisor = None self._third_party_supervisor = None @@ -154,7 +244,11 @@ class PluginRuntimeManager( if not self._started: return + platform_io_manager = get_platform_io_manager() await self._stop_plugin_file_watcher() + if self._config_reload_callback_registered: + config_manager.unregister_reload_callback(self._config_reload_callback) + self._config_reload_callback_registered = False coroutines: List[Coroutine[Any, Any, None]] = [] if self._builtin_supervisor: @@ -162,18 +256,32 @@ class PluginRuntimeManager( if self._third_party_supervisor: coroutines.append(self._third_party_supervisor.stop()) + stop_errors: List[str] = [] try: - await asyncio.gather(*coroutines, return_exceptions=True) - logger.info("插件运行时已停止") - except Exception as e: - logger.error(f"插件运行时停止失败: {e}", exc_info=True) + results = await asyncio.gather(*coroutines, return_exceptions=True) + for result in results: + if isinstance(result, Exception): + stop_errors.append(str(result)) + + platform_io_manager.clear_inbound_dispatcher() + try: + await platform_io_manager.stop() + except Exception as exc: + stop_errors.append(f"Platform IO: {exc}") + + if stop_errors: + logger.error(f"插件运行时停止过程中存在错误: {'; '.join(stop_errors)}") + else: + logger.info("插件运行时已停止") finally: self._started = False self._builtin_supervisor = None self._third_party_supervisor = None + self._plugin_path_cache.clear() @property def is_running(self) -> bool: + """返回插件运行时是否处于启动状态。""" return self._started @property @@ -181,11 +289,176 @@ class PluginRuntimeManager( """获取所有活跃的 Supervisor""" return [s for s in (self._builtin_supervisor, self._third_party_supervisor) if s is not None] + def _build_registered_dependency_map(self) -> Dict[str, Set[str]]: + """根据当前已注册插件构建全局依赖图。""" + + dependency_map: Dict[str, Set[str]] = {} + for supervisor in self.supervisors: + for plugin_id, registration in getattr(supervisor, "_registered_plugins", {}).items(): + dependency_map[plugin_id] = { + str(dependency or "").strip() + for dependency in getattr(registration, "dependencies", []) + if str(dependency or "").strip() + } + return dependency_map + + @staticmethod + def _collect_reverse_dependents( + plugin_ids: Set[str], + dependency_map: Dict[str, Set[str]], + ) -> Set[str]: + """根据依赖图收集反向依赖闭包。""" + + impacted_plugins: Set[str] = set(plugin_ids) + changed = True + + while changed: + changed = False + for registered_plugin_id, dependencies in dependency_map.items(): + if registered_plugin_id in impacted_plugins: + continue + if dependencies & impacted_plugins: + impacted_plugins.add(registered_plugin_id) + changed = True + + return impacted_plugins + + def _build_registered_supervisor_map(self) -> Dict[str, "PluginSupervisor"]: + """构建当前已注册插件到所属 Supervisor 的映射。""" + + return { + plugin_id: supervisor + for supervisor in self.supervisors + for plugin_id in supervisor.get_loaded_plugin_ids() + } + + def _build_external_available_plugins_for_supervisor(self, target_supervisor: "PluginSupervisor") -> Dict[str, str]: + """收集某个 Supervisor 可用的外部插件版本映射。""" + + external_plugin_versions: Dict[str, str] = {} + for supervisor in self.supervisors: + if supervisor is target_supervisor: + continue + external_plugin_versions.update(supervisor.get_loaded_plugin_versions()) + return external_plugin_versions + + def _find_supervisor_by_plugin_directory(self, plugin_id: str) -> Optional["PluginSupervisor"]: + """根据插件目录推断应负责该插件重载的 Supervisor。""" + + for supervisor in self.supervisors: + if self._get_plugin_path_for_supervisor(supervisor, plugin_id) is not None: + return supervisor + return None + + def _warn_skipped_cross_supervisor_reload( + self, + requested_loaded_plugin_ids: Set[str], + dependency_map: Dict[str, Set[str]], + supervisor_by_plugin: Dict[str, "PluginSupervisor"], + ) -> None: + """记录因跨 Supervisor 边界而未参与联动重载的插件。""" + + if not requested_loaded_plugin_ids: + return + + handled_plugin_ids: Set[str] = set() + for supervisor in self.supervisors: + local_requested_plugin_ids = { + plugin_id + for plugin_id in requested_loaded_plugin_ids + if supervisor_by_plugin.get(plugin_id) is supervisor + } + if not local_requested_plugin_ids: + continue + + local_plugin_ids = set(supervisor.get_loaded_plugin_ids()) + local_dependency_map = { + plugin_id: { + dependency + for dependency in dependency_map.get(plugin_id, set()) + if dependency in local_plugin_ids + } + for plugin_id in local_plugin_ids + } + handled_plugin_ids.update( + self._collect_reverse_dependents(local_requested_plugin_ids, local_dependency_map) + ) + + impacted_plugin_ids = self._collect_reverse_dependents(requested_loaded_plugin_ids, dependency_map) + skipped_plugin_ids = sorted(impacted_plugin_ids - handled_plugin_ids) + if not skipped_plugin_ids: + return + + logger.warning( + f"插件 {', '.join(sorted(requested_loaded_plugin_ids))} 存在跨 Supervisor 依赖方未联动重载: " + f"{', '.join(skipped_plugin_ids)}。当前仅在单个 Supervisor 内执行联动重载;" + "跨 Supervisor API 调用仍然可用。如需联动重载,请将相关插件放在同一个 Supervisor 内。" + ) + + async def reload_plugins_globally(self, plugin_ids: Sequence[str], reason: str = "manual") -> bool: + """按 Supervisor 分组执行精确重载。 + + 仅在单个 Supervisor 内执行依赖联动;跨 Supervisor 依赖方仅记录告警, + 不再自动参与本次热重载。 + """ + + normalized_plugin_ids = [ + normalized_plugin_id + for plugin_id in plugin_ids + if (normalized_plugin_id := str(plugin_id or "").strip()) + ] + if not normalized_plugin_ids: + return True + + dependency_map = self._build_registered_dependency_map() + supervisor_by_plugin = self._build_registered_supervisor_map() + supervisor_roots: Dict["PluginSupervisor", List[str]] = {} + requested_loaded_plugin_ids: Set[str] = set() + missing_plugin_ids: List[str] = [] + + for plugin_id in normalized_plugin_ids: + supervisor = supervisor_by_plugin.get(plugin_id) + if supervisor is not None: + requested_loaded_plugin_ids.add(plugin_id) + else: + supervisor = self._find_supervisor_by_plugin_directory(plugin_id) + + if supervisor is None: + missing_plugin_ids.append(plugin_id) + continue + + if plugin_id not in supervisor_roots.setdefault(supervisor, []): + supervisor_roots[supervisor].append(plugin_id) + + if missing_plugin_ids: + logger.warning(f"以下插件未找到可重载的 Supervisor,已跳过: {', '.join(sorted(missing_plugin_ids))}") + + self._warn_skipped_cross_supervisor_reload( + requested_loaded_plugin_ids=requested_loaded_plugin_ids, + dependency_map=dependency_map, + supervisor_by_plugin=supervisor_by_plugin, + ) + + success = True + for supervisor, root_plugin_ids in supervisor_roots.items(): + if not root_plugin_ids: + continue + + reloaded = await supervisor.reload_plugins( + plugin_ids=root_plugin_ids, + reason=reason, + external_available_plugins=self._build_external_available_plugins_for_supervisor(supervisor), + ) + success = success and reloaded + + return success and not missing_plugin_ids + async def notify_plugin_config_updated( self, plugin_id: str, config_data: Optional[Dict[str, Any]] = None, config_version: str = "", + config_scope: str = "self", ) -> bool: """向拥有该插件的 Supervisor 推送配置更新事件。 @@ -193,6 +466,7 @@ class PluginRuntimeManager( plugin_id: 插件 ID config_data: 可选的配置数据(如果为 None 则由 Supervisor 从磁盘加载) config_version: 可选的配置版本字符串,供 Supervisor 进行版本控制 + config_scope: 配置变更范围。 """ if not self._started: return False @@ -209,23 +483,78 @@ class PluginRuntimeManager( config_payload = ( config_data if config_data is not None - else self._load_plugin_config_for_supervisor(plugin_id, plugin_dirs=sv._plugin_dirs) + else self._load_plugin_config_for_supervisor(sv, plugin_id) ) - await sv.notify_plugin_config_updated( + return await sv.notify_plugin_config_updated( plugin_id=plugin_id, config_data=config_payload, config_version=config_version, + config_scope=config_scope, ) - return True + + @staticmethod + def _normalize_config_reload_scopes(changed_scopes: Sequence[str]) -> tuple[str, ...]: + """规范化配置热重载范围列表。 + + Args: + changed_scopes: 原始配置热重载范围列表。 + + Returns: + tuple[str, ...]: 去重后的有效配置范围元组。 + """ + + normalized_scopes: list[str] = [] + for scope in changed_scopes: + normalized_scope = str(scope or "").strip().lower() + if normalized_scope not in {"bot", "model"}: + continue + if normalized_scope not in normalized_scopes: + normalized_scopes.append(normalized_scope) + return tuple(normalized_scopes) + + async def _broadcast_config_reload(self, scope: str, config_data: Dict[str, Any]) -> None: + """向订阅指定范围的插件广播配置热重载。 + + Args: + scope: 配置变更范围,仅支持 ``bot`` 或 ``model``。 + config_data: 最新配置数据。 + """ + + for supervisor in self.supervisors: + for plugin_id in supervisor.get_config_reload_subscribers(scope): + delivered = await supervisor.notify_plugin_config_updated( + plugin_id=plugin_id, + config_data=config_data, + config_version="", + config_scope=scope, + ) + if not delivered: + logger.warning(f"向插件 {plugin_id} 广播 {scope} 配置热重载失败") + + async def _handle_main_config_reload(self, changed_scopes: Sequence[str]) -> None: + """处理 bot/model 主配置热重载广播。 + + Args: + changed_scopes: 本次热重载命中的配置范围列表。 + """ + + if not self._started: + return + + normalized_scopes = self._normalize_config_reload_scopes(changed_scopes) + if "bot" in normalized_scopes: + await self._broadcast_config_reload("bot", config_manager.get_global_config().model_dump(mode="json")) + if "model" in normalized_scopes: + await self._broadcast_config_reload("model", config_manager.get_model_config().model_dump(mode="json")) # ─── 事件桥接 ────────────────────────────────────────────── async def bridge_event( self, event_type_value: str, - message_dict: Optional[Dict[str, Any]] = None, + message_dict: Optional[MessageDict] = None, extra_args: Optional[Dict[str, Any]] = None, - ) -> Tuple[bool, Optional[Dict[str, Any]]]: + ) -> Tuple[bool, Optional[MessageDict]]: """将事件分发到所有 Supervisor Returns: @@ -235,17 +564,23 @@ class PluginRuntimeManager( return True, None new_event_type: str = _EVENT_TYPE_MAP.get(event_type_value, event_type_value) - modified: Optional[Dict[str, Any]] = None + modified: Optional[MessageDict] = None + current_message: Optional["SessionMessage"] = ( + PluginMessageUtils._build_session_message_from_dict(dict(message_dict)) + if message_dict is not None + else None + ) for sv in self.supervisors: try: cont, mod = await sv.dispatch_event( event_type=new_event_type, - message=modified or message_dict, + message=current_message, extra_args=extra_args, ) if mod is not None: - modified = mod + current_message = mod + modified = PluginMessageUtils._session_message_to_dict(mod) if not cont: return False, modified except Exception as e: @@ -295,6 +630,37 @@ class PluginRuntimeManager( timeout_ms=timeout_ms, ) + async def try_send_message_via_platform_io( + self, + message: "SessionMessage", + ) -> Optional[DeliveryBatch]: + """尝试通过 Platform IO 中间层发送消息。 + + Args: + message: 待发送的内部会话消息。 + + Returns: + Optional[DeliveryBatch]: 若当前消息命中了至少一条发送路由,则返回 + 实际发送结果;若没有可用路由或 Platform IO 尚未启动,则返回 ``None``。 + """ + if not self._started: + return None + + platform_io_manager = get_platform_io_manager() + if not platform_io_manager.is_started: + return None + + try: + route_key = platform_io_manager.build_route_key_from_message(message) + except Exception as exc: + logger.warning(f"根据消息构造 Platform IO 路由键失败: {exc}") + return None + + if not platform_io_manager.resolve_drivers(route_key): + return None + + return await platform_io_manager.send_message(message, route_key) + def _get_supervisors_for_plugin(self, plugin_id: str) -> List["PluginSupervisor"]: """返回当前持有指定插件的所有 Supervisor。 @@ -314,30 +680,38 @@ class PluginRuntimeManager( raise RuntimeError(f"插件 {plugin_id} 同时存在于多个 Supervisor 中,无法安全路由") return matches[0] if matches else None - @staticmethod - def _find_duplicate_plugin_ids(plugin_dirs: List[Path]) -> Dict[str, List[Path]]: + async def load_plugin_globally(self, plugin_id: str, reason: str = "manual") -> bool: + """加载或重载单个插件,并为其补齐跨 Supervisor 外部依赖。""" + + normalized_plugin_id = str(plugin_id or "").strip() + if not normalized_plugin_id: + return False + + try: + registered_supervisor = self._get_supervisor_for_plugin(normalized_plugin_id) + except RuntimeError: + return False + + if registered_supervisor is not None: + return await self.reload_plugins_globally([normalized_plugin_id], reason=reason) + + supervisor = self._find_supervisor_by_plugin_directory(normalized_plugin_id) + if supervisor is None: + return False + + return await supervisor.reload_plugins( + plugin_ids=[normalized_plugin_id], + reason=reason, + external_available_plugins=self._build_external_available_plugins_for_supervisor(supervisor), + ) + + @classmethod + def _find_duplicate_plugin_ids(cls, plugin_dirs: List[Path]) -> Dict[str, List[Path]]: """扫描插件目录,找出被多个目录重复声明的插件 ID。""" plugin_locations: Dict[str, List[Path]] = {} - for base_dir in plugin_dirs: - if not base_dir.is_dir(): - continue - for entry in base_dir.iterdir(): - if not entry.is_dir(): - continue - manifest_path = entry / "_manifest.json" - plugin_path = entry / "plugin.py" - if not manifest_path.exists() or not plugin_path.exists(): - continue - - plugin_id = entry.name - try: - with open(manifest_path, "r", encoding="utf-8") as manifest_file: - manifest = json.load(manifest_file) - plugin_id = str(manifest.get("name", entry.name)).strip() or entry.name - except Exception: - continue - - plugin_locations.setdefault(plugin_id, []).append(entry) + validator = ManifestValidator() + for plugin_path, manifest in validator.iter_plugin_manifests(plugin_dirs): + plugin_locations.setdefault(manifest.id, []).append(plugin_path) return { plugin_id: sorted(dict.fromkeys(paths), key=lambda p: str(p)) @@ -370,6 +744,7 @@ class PluginRuntimeManager( async def _stop_plugin_file_watcher(self) -> None: """停止插件文件监视器,并清理所有已注册订阅。""" if self._plugin_file_watcher is None: + self._plugin_path_cache.clear() return for _plugin_id, (_config_path, subscription_id) in list(self._plugin_config_watcher_subscriptions.items()): self._plugin_file_watcher.unsubscribe(subscription_id) @@ -379,12 +754,79 @@ class PluginRuntimeManager( self._plugin_source_watcher_subscription_id = None await self._plugin_file_watcher.stop() self._plugin_file_watcher = None + self._plugin_path_cache.clear() def _iter_plugin_dirs(self) -> Iterable[Path]: """迭代所有 Supervisor 当前管理的插件根目录。""" for supervisor in self.supervisors: yield from getattr(supervisor, "_plugin_dirs", []) + @staticmethod + def _iter_candidate_plugin_paths(plugin_dirs: Iterable[Path]) -> Iterable[Path]: + """迭代所有可能的插件目录路径。 + + Args: + plugin_dirs: 一个或多个插件根目录。 + + Yields: + Path: 单个插件目录路径。 + """ + for plugin_dir in plugin_dirs: + plugin_root = Path(plugin_dir).resolve() + if not plugin_root.is_dir(): + continue + for entry in plugin_root.iterdir(): + if entry.is_dir(): + yield entry.resolve() + + def _read_plugin_id_from_plugin_path(self, plugin_path: Path) -> Optional[str]: + """从单个插件目录中读取 manifest 声明的插件 ID。 + + Args: + plugin_path: 单个插件目录路径。 + + Returns: + Optional[str]: 解析成功时返回插件 ID,否则返回 ``None``。 + """ + return self._manifest_validator.read_plugin_id_from_plugin_path(plugin_path) + + def _iter_discovered_plugin_paths(self, plugin_dirs: Iterable[Path]) -> Iterable[Tuple[str, Path]]: + """迭代目录中可解析到的插件 ID 与实际目录路径。 + + Args: + plugin_dirs: 一个或多个插件根目录。 + + Yields: + Tuple[str, Path]: ``(plugin_id, plugin_path)`` 二元组。 + """ + for plugin_path in self._iter_candidate_plugin_paths(plugin_dirs): + if plugin_id := self._read_plugin_id_from_plugin_path(plugin_path): + yield plugin_id, plugin_path + + def _get_plugin_path_for_supervisor(self, supervisor: Any, plugin_id: str) -> Optional[Path]: + """为指定 Supervisor 定位某个插件的实际目录。 + + Args: + supervisor: 目标 Supervisor。 + plugin_id: 插件 ID。 + + Returns: + Optional[Path]: 插件目录路径;未找到时返回 ``None``。 + """ + cached_path = self._plugin_path_cache.get(plugin_id) + if cached_path is not None: + for plugin_dir in getattr(supervisor, "_plugin_dirs", []): + if self._plugin_dir_matches(cached_path, Path(plugin_dir)): + return cached_path + + for candidate_plugin_id, plugin_path in self._iter_discovered_plugin_paths(getattr(supervisor, "_plugin_dirs", [])): + if candidate_plugin_id != plugin_id: + continue + self._plugin_path_cache[plugin_id] = plugin_path + return plugin_path + + return None + def _refresh_plugin_config_watch_subscriptions(self) -> None: """按当前已注册插件集合刷新 config.toml 的单插件订阅。 @@ -394,7 +836,11 @@ class PluginRuntimeManager( if self._plugin_file_watcher is None: return - desired_config_paths = dict(self._iter_registered_plugin_config_paths()) + desired_plugin_paths = dict(self._iter_registered_plugin_paths()) + self._plugin_path_cache = desired_plugin_paths.copy() + desired_config_paths = { + plugin_id: plugin_path / "config.toml" for plugin_id, plugin_path in desired_plugin_paths.items() + } for plugin_id, (_old_path, subscription_id) in list(self._plugin_config_watcher_subscriptions.items()): if desired_config_paths.get(plugin_id) == self._plugin_config_watcher_subscriptions[plugin_id][0]: @@ -418,28 +864,35 @@ class PluginRuntimeManager( """为指定插件生成配置文件变更回调。""" async def _callback(changes: Sequence[FileChange]) -> None: + """将 watcher 事件转发到指定插件的配置处理逻辑。 + + Args: + changes: 当前批次收集到的文件变更列表。 + """ await self._handle_plugin_config_changes(plugin_id, changes) return _callback - def _iter_registered_plugin_config_paths(self) -> Iterable[Tuple[str, Path]]: - """迭代当前所有已注册插件的 config.toml 路径。""" + def _iter_registered_plugin_paths(self) -> Iterable[Tuple[str, Path]]: + """迭代当前所有已注册插件的实际目录路径。""" for supervisor in self.supervisors: for plugin_id in getattr(supervisor, "_registered_plugins", {}).keys(): - if config_path := self._get_plugin_config_path_for_supervisor(supervisor, plugin_id): - yield plugin_id, config_path + if plugin_path := self._get_plugin_path_for_supervisor(supervisor, plugin_id): + yield plugin_id, plugin_path def _get_plugin_config_path_for_supervisor(self, supervisor: Any, plugin_id: str) -> Optional[Path]: """从指定 Supervisor 的插件目录中定位某个插件的 config.toml。""" - for plugin_dir in getattr(supervisor, "_plugin_dirs", []): - plugin_dir = Path(plugin_dir) - plugin_path = plugin_dir.resolve() / plugin_id - if plugin_path.is_dir(): - return plugin_path / "config.toml" - return None + plugin_path = self._get_plugin_path_for_supervisor(supervisor, plugin_id) + return None if plugin_path is None else plugin_path / "config.toml" async def _handle_plugin_config_changes(self, plugin_id: str, changes: Sequence[FileChange]) -> None: - """处理单个插件配置文件变化,并仅向目标插件推送配置更新。""" + """处理单个插件配置文件变化,并定向派发自配置热更新。 + + Args: + plugin_id: 发生配置变更的插件 ID。 + changes: 当前批次收集到的配置文件变更列表。 + + """ if not self._started or not changes: return @@ -453,18 +906,24 @@ class PluginRuntimeManager( return try: - await supervisor.notify_plugin_config_updated( + config_payload = self._load_plugin_config_for_supervisor(supervisor, plugin_id) + delivered = await supervisor.notify_plugin_config_updated( plugin_id=plugin_id, - config_data=self._load_plugin_config_for_supervisor(plugin_id, getattr(supervisor, "_plugin_dirs", [])), + config_data=config_payload, + config_version="", + config_scope="self", ) + if not delivered: + logger.warning(f"插件 {plugin_id} 配置文件变更后通知失败") except Exception as exc: - logger.warning(f"插件 {plugin_id} 配置热更新通知失败: {exc}") + logger.warning(f"插件 {plugin_id} 配置文件变更处理失败: {exc}") async def _handle_plugin_source_changes(self, changes: Sequence[FileChange]) -> None: """处理插件源码相关变化。 这里仅负责源码、清单等会影响插件装载状态的文件;配置文件的变化会由 - 单独的 per-plugin watcher 处理,避免把单插件配置更新放大成全量 reload。 + 单独的 per-plugin watcher 处理,并定向派发给目标插件的 + ``on_config_update()``,避免放大成不必要的跨插件 reload。 """ if not self._started or not changes: return @@ -477,7 +936,7 @@ class PluginRuntimeManager( logger.error(f"检测到重复插件 ID,跳过本次插件热重载: {details}") return - reload_supervisors: List[Any] = [] + changed_plugin_ids: List[str] = [] changed_paths = [change.path.resolve() for change in changes] for supervisor in self.supervisors: @@ -485,13 +944,12 @@ class PluginRuntimeManager( plugin_id = self._match_plugin_id_for_supervisor(supervisor, path) if plugin_id is None: continue - if (path.name in {"plugin.py", "_manifest.json"} or path.suffix == ".py") and supervisor not in reload_supervisors: - reload_supervisors.append(supervisor) + if path.name in {"plugin.py", "_manifest.json"} or path.suffix == ".py": + if plugin_id not in changed_plugin_ids: + changed_plugin_ids.append(plugin_id) - for supervisor in reload_supervisors: - await supervisor.reload_plugins(reason="file_watcher") - - if reload_supervisors: + if changed_plugin_ids: + await self.reload_plugins_globally(changed_plugin_ids, reason="file_watcher") self._refresh_plugin_config_watch_subscriptions() @staticmethod @@ -502,36 +960,47 @@ class PluginRuntimeManager( def _match_plugin_id_for_supervisor(self, supervisor: Any, path: Path) -> Optional[str]: """根据变更路径为指定 Supervisor 推断受影响的插件 ID。""" - for plugin_id, _reg in getattr(supervisor, "_registered_plugins", {}).items(): - for plugin_dir in getattr(supervisor, "_plugin_dirs", []): - plugin_dir = Path(plugin_dir) - candidate_dir = plugin_dir.resolve() / plugin_id - if path == candidate_dir or path.is_relative_to(candidate_dir): - return plugin_id + resolved_path = path.resolve() + + for plugin_id in getattr(supervisor, "_registered_plugins", {}).keys(): + plugin_path = self._get_plugin_path_for_supervisor(supervisor, plugin_id) + if plugin_path is not None and (resolved_path == plugin_path or resolved_path.is_relative_to(plugin_path)): + return plugin_id + + for plugin_id, plugin_path in self._plugin_path_cache.items(): + if not any(self._plugin_dir_matches(plugin_path, Path(plugin_dir)) for plugin_dir in getattr(supervisor, "_plugin_dirs", [])): + continue + if resolved_path == plugin_path or resolved_path.is_relative_to(plugin_path): + return plugin_id + + for plugin_id, plugin_path in self._iter_discovered_plugin_paths(getattr(supervisor, "_plugin_dirs", [])): + if resolved_path == plugin_path or resolved_path.is_relative_to(plugin_path): + self._plugin_path_cache[plugin_id] = plugin_path + return plugin_id - for plugin_dir in getattr(supervisor, "_plugin_dirs", []): - plugin_dir = Path(plugin_dir) - plugin_root = plugin_dir.resolve() - if self._plugin_dir_matches(path, plugin_dir) and (relative_parts := path.relative_to(plugin_root).parts): - return relative_parts[0] return None - @staticmethod - def _load_plugin_config_for_supervisor(plugin_id: str, plugin_dirs: Iterable[Path]) -> Dict[str, Any]: + def _load_plugin_config_for_supervisor(self, supervisor: Any, plugin_id: str) -> Dict[str, Any]: """从给定插件目录集合中读取目标插件的配置内容。""" - for plugin_dir in plugin_dirs: - plugin_path = plugin_dir.resolve() / plugin_id - if plugin_path.is_dir(): - config_path = plugin_path / "config.toml" - if not config_path.exists(): - return {} - with open(config_path, "r", encoding="utf-8") as handle: - return tomlkit.load(handle).unwrap() - return {} + plugin_path = self._get_plugin_path_for_supervisor(supervisor, plugin_id) + if plugin_path is None: + return {} + + config_path = plugin_path / "config.toml" + if not config_path.exists(): + return {} + + with open(config_path, "r", encoding="utf-8") as handle: + return tomlkit.load(handle).unwrap() # ─── 能力实现注册 ────────────────────────────────────────── def _register_capability_impls(self, supervisor: "PluginSupervisor") -> None: + """向指定 Supervisor 注册主程序能力实现。 + + Args: + supervisor: 需要注册能力实现的目标 Supervisor。 + """ register_capability_impls(self, supervisor) diff --git a/src/plugin_runtime/protocol/envelope.py b/src/plugin_runtime/protocol/envelope.py index bcfb2758..e738d019 100644 --- a/src/plugin_runtime/protocol/envelope.py +++ b/src/plugin_runtime/protocol/envelope.py @@ -7,52 +7,52 @@ from enum import Enum from typing import Any, Dict, List, Optional -from pydantic import BaseModel, Field - import logging as stdlib_logging import time +from pydantic import BaseModel, Field -# ─── 协议常量 ────────────────────────────────────────────────────── - -PROTOCOL_VERSION = "1.0" +# ====== 协议常量 ====== +PROTOCOL_VERSION = "1.0.0" # 支持的 SDK 版本范围(Host 在握手时校验) MIN_SDK_VERSION = "1.0.0" -MAX_SDK_VERSION = "1.99.99" - - -# ─── 消息类型 ────────────────────────────────────────────────────── +MAX_SDK_VERSION = "2.99.99" +# ====== 消息类型 ====== class MessageType(str, Enum): """RPC 消息类型""" REQUEST = "request" RESPONSE = "response" - EVENT = "event" + BROADCAST = "broadcast" -# ─── 请求 ID 生成器 ─────────────────────────────────────────────── +class ConfigReloadScope(str, Enum): + """配置热重载范围。""" + + SELF = "self" + BOT = "bot" + MODEL = "model" +# ====== 请求 ID 生成器 ====== class RequestIdGenerator: - """单调递增 int64 请求 ID 生成器(线程安全由调用方保证或使用 asyncio)""" + """单调递增 int64 请求 ID 生成器""" def __init__(self, start: int = 1) -> None: self._counter = start - def next(self) -> int: + async def next(self) -> int: current = self._counter self._counter += 1 return current -# ─── Envelope 模型 ───────────────────────────────────────────────── - - +# ====== Envelope 模型 ====== class Envelope(BaseModel): - """RPC 统一信封 + """RPC 统一消息封装 所有 Host <-> Runner 消息均封装为此格式。 序列化流程:Envelope -> .model_dump() -> MsgPack encode @@ -60,15 +60,23 @@ class Envelope(BaseModel): """ protocol_version: str = Field(default=PROTOCOL_VERSION, description="协议版本") + """协议版本""" request_id: int = Field(description="单调递增请求 ID") + """单调递增请求 ID""" message_type: MessageType = Field(description="消息类型") + """消息类型""" method: str = Field(default="", description="RPC 方法名") + """RPC 方法名""" plugin_id: str = Field(default="", description="目标插件 ID") - timestamp_ms: int = Field(default_factory=lambda: int(time.time() * 1000), description="发送时间戳(ms)") - timeout_ms: int = Field(default=30000, description="相对超时(ms)") - generation: int = Field(default=0, description="Runner generation 编号") + """目标插件 ID""" + timestamp_ms: int = Field(default_factory=lambda: int(time.time() * 1000), description="发送时间戳 (ms)") + """发送时间戳 (ms)""" + timeout_ms: int = Field(default=30000, description="相对超时 (ms)") + """相对超时 (ms)""" payload: Dict[str, Any] = Field(default_factory=dict, description="业务数据") - error: Optional[Dict[str, Any]] = Field(default=None, description="错误信息(仅 response)") + """业务数据""" + error: Optional[Dict[str, Any]] = Field(default=None, description="错误信息 (仅 response)") + """错误信息 (仅 response)""" def is_request(self) -> bool: return self.message_type == MessageType.REQUEST @@ -76,8 +84,8 @@ class Envelope(BaseModel): def is_response(self) -> bool: return self.message_type == MessageType.RESPONSE - def is_event(self) -> bool: - return self.message_type == MessageType.EVENT + def is_broadcast(self) -> bool: + return self.message_type == MessageType.BROADCAST def make_response( self, payload: Optional[Dict[str, Any]] = None, error: Optional[Dict[str, Any]] = None @@ -89,7 +97,6 @@ class Envelope(BaseModel): message_type=MessageType.RESPONSE, method=self.method, plugin_id=self.plugin_id, - generation=self.generation, payload=payload or {}, error=error, ) @@ -105,153 +112,302 @@ class Envelope(BaseModel): ) -# ─── 握手消息 ────────────────────────────────────────────────────── - - +# ====== 握手请求与响应 ====== class HelloPayload(BaseModel): """runner.hello 握手请求 payload""" runner_id: str = Field(description="Runner 进程唯一标识") + """Runner 进程唯一标识""" sdk_version: str = Field(description="SDK 版本号") + """SDK 版本号""" session_token: str = Field(description="一次性会话令牌") + """一次性会话令牌""" class HelloResponsePayload(BaseModel): """runner.hello 握手响应 payload""" accepted: bool = Field(description="是否接受连接") + """是否接受连接""" host_version: str = Field(default="", description="Host 版本号") - assigned_generation: int = Field(default=0, description="分配的 generation 编号") - reason: str = Field(default="", description="拒绝原因(若 accepted=False)") - - -# ─── 组件注册消息 ────────────────────────────────────────────────── + """Host 版本号""" + reason: str = Field(default="", description="拒绝原因 (若 accepted=False)") + """拒绝原因 (若 `accepted`=`False`)""" +# ====== 组件注册消息 ====== class ComponentDeclaration(BaseModel): """单个组件声明""" name: str = Field(description="组件名称") - component_type: str = Field(description="组件类型: action/command/tool/event_handler") + """组件名称""" + component_type: str = Field( + description="组件类型:action/command/tool/event_handler/hook_handler/message_gateway" + ) + """组件类型:`action`/`command`/`tool`/`event_handler`/`hook_handler`/`message_gateway`""" plugin_id: str = Field(description="所属插件 ID") + """所属插件 ID""" metadata: Dict[str, Any] = Field(default_factory=dict, description="组件元数据") + """组件元数据""" -class RegisterComponentsPayload(BaseModel): - """plugin.register_components 请求 payload""" +class RegisterPluginPayload(BaseModel): + """插件组件注册请求载荷。 + + 该模型同时用于 ``plugin.register_components`` 与兼容旧命名的 + ``plugin.register_plugin`` 请求。 + """ plugin_id: str = Field(description="插件 ID") + """插件 ID""" plugin_version: str = Field(default="1.0.0", description="插件版本") + """插件版本""" components: List[ComponentDeclaration] = Field(default_factory=list, description="组件列表") + """组件列表""" capabilities_required: List[str] = Field(default_factory=list, description="所需能力列表") + """所需能力列表""" + dependencies: List[str] = Field(default_factory=list, description="插件级依赖插件 ID 列表") + """插件级依赖插件 ID 列表""" + config_reload_subscriptions: List[str] = Field(default_factory=list, description="订阅的全局配置热重载范围") + """订阅的全局配置热重载范围""" class BootstrapPluginPayload(BaseModel): """plugin.bootstrap 请求 payload""" plugin_id: str = Field(description="插件 ID") + """插件 ID""" plugin_version: str = Field(default="1.0.0", description="插件版本") + """插件版本""" capabilities_required: List[str] = Field(default_factory=list, description="所需能力列表") + """所需能力列表""" -# ─── 调用消息 ────────────────────────────────────────────────────── - - +# ====== 插件调用请求和响应 ====== class InvokePayload(BaseModel): - """plugin.invoke_* 请求 payload""" + """plugin.invoke.* 请求 payload""" component_name: str = Field(description="要调用的组件名称") + """要调用的组件名称""" args: Dict[str, Any] = Field(default_factory=dict, description="调用参数") + """调用参数""" class InvokeResultPayload(BaseModel): - """plugin.invoke_* 响应 payload""" + """plugin.invoke.* 响应 payload""" success: bool = Field(description="是否成功") + """是否成功""" result: Any = Field(default=None, description="返回值") + """返回值""" -# ─── 能力调用消息 ────────────────────────────────────────────────── - - +# ====== 能力调用消息 ====== class CapabilityRequestPayload(BaseModel): """cap.* 请求 payload(插件 -> Host 能力调用)""" capability: str = Field(description="能力名称,如 send.text, db.query") + """能力名称,如 send.text, db.query""" args: Dict[str, Any] = Field(default_factory=dict, description="调用参数") + """调用参数""" class CapabilityResponsePayload(BaseModel): """cap.* 响应 payload""" success: bool = Field(description="是否成功") + """是否成功""" result: Any = Field(default=None, description="返回值") + """返回值""" -# ─── 健康检查 ────────────────────────────────────────────────────── - - +# ====== 健康检查 ====== class HealthPayload(BaseModel): """plugin.health 响应 payload""" healthy: bool = Field(description="是否健康") + """是否健康""" loaded_plugins: List[str] = Field(default_factory=list, description="已加载的插件列表") - uptime_ms: int = Field(default=0, description="运行时长(ms)") + """已加载的插件列表""" + uptime_ms: int = Field(default=0, description="运行时长 (ms)") + """运行时长 (ms)""" class RunnerReadyPayload(BaseModel): """runner.ready 请求 payload""" loaded_plugins: List[str] = Field(default_factory=list, description="已完成初始化的插件列表") + """已完成初始化的插件列表""" failed_plugins: List[str] = Field(default_factory=list, description="初始化失败的插件列表") + """初始化失败的插件列表""" -# ─── 配置更新 ────────────────────────────────────────────────────── - - -# Host 侧现已支持配置更新推送: -# - 总配置热重载完成后,PluginRuntimeManager 会向已加载插件推送配置更新事件。 -# - 插件目录下的 config.toml 变化由现有 FileWatcher 监听并转发为 plugin.config_updated。 +# ====== 配置更新 ====== class ConfigUpdatedPayload(BaseModel): """plugin.config_updated 事件 payload""" plugin_id: str = Field(description="插件 ID") + """插件 ID""" + config_scope: ConfigReloadScope = Field(description="配置变更范围") + """配置变更范围""" config_version: str = Field(description="新配置版本") + """新配置版本""" config_data: Dict[str, Any] = Field(default_factory=dict, description="配置内容") + """配置内容""" -# ─── 关停 ────────────────────────────────────────────────────────── - - +# ====== 关停 ====== class ShutdownPayload(BaseModel): """plugin.shutdown / plugin.prepare_shutdown payload""" reason: str = Field(default="normal", description="关停原因") - drain_timeout_ms: int = Field(default=5000, description="排空超时(ms)") + """关停原因""" + drain_timeout_ms: int = Field(default=5000, description="排空超时 (ms)") + """排空超时 (ms)""" -# ─── 日志传输 ────────────────────────────────────────────────────── +class UnregisterPluginPayload(BaseModel): + """插件注销请求载荷。""" + + plugin_id: str = Field(description="插件 ID") + """插件 ID""" + reason: str = Field(default="manual", description="注销原因") + """注销原因""" + + +class ReloadPluginPayload(BaseModel): + """插件重载请求载荷。""" + + plugin_id: str = Field(description="目标插件 ID") + """目标插件 ID""" + reason: str = Field(default="manual", description="重载原因") + """重载原因""" + external_available_plugins: Dict[str, str] = Field( + default_factory=dict, + description="可视为已满足的外部依赖插件版本映射", + ) + """可视为已满足的外部依赖插件版本映射""" + + +class ReloadPluginsPayload(BaseModel): + """批量插件重载请求载荷。""" + + plugin_ids: List[str] = Field(default_factory=list, description="目标插件 ID 列表") + """目标插件 ID 列表""" + reason: str = Field(default="manual", description="重载原因") + """重载原因""" + external_available_plugins: Dict[str, str] = Field( + default_factory=dict, + description="可视为已满足的外部依赖插件版本映射", + ) + """可视为已满足的外部依赖插件版本映射""" + + +class ReloadPluginResultPayload(BaseModel): + """插件重载结果载荷。""" + + success: bool = Field(description="是否重载成功") + """是否重载成功""" + requested_plugin_id: str = Field(description="请求重载的插件 ID") + """请求重载的插件 ID""" + reloaded_plugins: List[str] = Field(default_factory=list, description="成功完成重载的插件列表") + """成功完成重载的插件列表""" + unloaded_plugins: List[str] = Field(default_factory=list, description="本次已卸载的插件列表") + """本次已卸载的插件列表""" + failed_plugins: Dict[str, str] = Field(default_factory=dict, description="重载失败的插件及原因") + """重载失败的插件及原因""" + + +class ReloadPluginsResultPayload(BaseModel): + """批量插件重载结果载荷。""" + + success: bool = Field(description="是否重载成功") + """是否重载成功""" + requested_plugin_ids: List[str] = Field(default_factory=list, description="请求重载的插件 ID 列表") + """请求重载的插件 ID 列表""" + reloaded_plugins: List[str] = Field(default_factory=list, description="成功完成重载的插件列表") + """成功完成重载的插件列表""" + unloaded_plugins: List[str] = Field(default_factory=list, description="本次已卸载的插件列表") + """本次已卸载的插件列表""" + failed_plugins: Dict[str, str] = Field(default_factory=dict, description="重载失败的插件及原因") + """重载失败的插件及原因""" + + +class MessageGatewayStateUpdatePayload(BaseModel): + """消息网关运行时状态更新载荷。""" + + gateway_name: str = Field(description="消息网关组件名称") + """消息网关组件名称""" + ready: bool = Field(description="当前链路是否已经就绪") + """当前链路是否已经就绪""" + platform: str = Field(default="", description="当前链路负责的平台名称") + """当前链路负责的平台名称""" + account_id: str = Field(default="", description="当前链路对应的账号 ID 或 self_id") + """当前链路对应的账号 ID 或 self_id""" + scope: str = Field(default="", description="当前链路对应的可选路由作用域") + """当前链路对应的可选路由作用域""" + metadata: Dict[str, Any] = Field(default_factory=dict, description="可选的运行时状态元数据") + """可选的运行时状态元数据""" + + +class MessageGatewayStateUpdateResultPayload(BaseModel): + """消息网关运行时状态更新结果载荷。""" + + accepted: bool = Field(description="Host 是否接受了本次状态更新") + """Host 是否接受了本次状态更新""" + ready: bool = Field(description="Host 记录的当前就绪状态") + """Host 记录的当前就绪状态""" + route_key: Dict[str, Any] = Field(default_factory=dict, description="当前生效的路由键") + """当前生效的路由键""" + + +class RouteMessagePayload(BaseModel): + """消息网关向 Host 路由外部消息的请求载荷。""" + + gateway_name: str = Field(description="接收消息的网关组件名称") + """接收消息的网关组件名称""" + message: Dict[str, Any] = Field(description="符合 MessageDict 结构的标准消息字典") + """符合 MessageDict 结构的标准消息字典""" + route_metadata: Dict[str, Any] = Field(default_factory=dict, description="可选的路由辅助元数据") + """可选的路由辅助元数据""" + external_message_id: str = Field(default="", description="可选的外部平台消息 ID") + """可选的外部平台消息 ID""" + dedupe_key: str = Field(default="", description="可选的显式去重键") + """可选的显式去重键""" + + +class ReceiveExternalMessageResultPayload(BaseModel): + """外部消息注入结果载荷。""" + + accepted: bool = Field(description="Host 是否接受了本次消息注入") + """Host 是否接受了本次消息注入""" + route_key: Dict[str, Any] = Field(default_factory=dict, description="本次消息使用的归一路由键") + """本次消息使用的归一路由键""" + + +RegisterPluginPayload.model_rebuild() + + +# ====== 日志传输 ====== class LogEntry(BaseModel): """单条日志记录(Runner → Host 传输格式)""" - timestamp_ms: int = Field( - description="日志时间戳,Unix epoch 毫秒", - ) - level: int = Field( - description=("stdlib logging 整数级别: 10=DEBUG, 20=INFO, 30=WARNING, 40=ERROR, 50=CRITICAL"), - ) - logger_name: str = Field( - description="Logger 名称,如 plugin.my_plugin.submodule", - ) - message: str = Field( - description="经 Formatter 格式化后的完整日志消息(含 exc_info 文本)", - ) + timestamp_ms: int = Field(description="日志时间戳,Unix epoch 毫秒") + """日志时间戳,Unix epoch 毫秒""" + level: int = Field(description="stdlib logging 整数级别:10=DEBUG, 20=INFO, 30=WARNING, 40=ERROR, 50=CRITICAL") + """stdlib logging 整数级别:10=DEBUG, 20=INFO, 30=WARNING, 40=ERROR, 50=CRITICAL""" + logger_name: str = Field(description="Logger 名称,如 plugin.my_plugin.submodule") + """Logger 名称,如 plugin.my_plugin.submodule""" + message: str = Field(description="经 Formatter 格式化后的完整日志消息(含 exc_info 文本)") + """经 Formatter 格式化后的完整日志消息(含 exc_info 文本)""" exception_text: str = Field( default="", description="原始异常摘要(exc_text),供结构化消费;已嵌入 message 中", ) + """原始异常摘要(exc_text),供结构化消费;已嵌入 message 中""" + log_color_in_hex: Optional[str] = Field(default=None, description="日志颜色的十六进制字符串(如 #RRGGBB)") @property def levelname(self) -> str: @@ -262,6 +418,5 @@ class LogEntry(BaseModel): class LogBatchPayload(BaseModel): """runner.log_batch 事件 payload:Runner 端向 Host 批量推送日志记录""" - entries: List[LogEntry] = Field( - description="本批次日志记录列表,按时间升序排列", - ) + entries: List[LogEntry] = Field(description="本批次日志记录列表,按时间升序排列") + """本批次日志记录列表,按时间升序排列""" diff --git a/src/plugin_runtime/protocol/errors.py b/src/plugin_runtime/protocol/errors.py index dcae6b8f..d2b9228b 100644 --- a/src/plugin_runtime/protocol/errors.py +++ b/src/plugin_runtime/protocol/errors.py @@ -18,17 +18,17 @@ class ErrorCode(str, Enum): E_TIMEOUT = "E_TIMEOUT" E_BAD_PAYLOAD = "E_BAD_PAYLOAD" E_PROTOCOL_MISMATCH = "E_PROTOCOL_MISMATCH" + E_SHUTTING_DOWN = "E_SHUTTING_DOWN" # 权限与策略 E_UNAUTHORIZED = "E_UNAUTHORIZED" E_METHOD_NOT_ALLOWED = "E_METHOD_NOT_ALLOWED" - E_BACKPRESSURE = "E_BACKPRESSURE" + E_BACK_PRESSURE = "E_BACK_PRESSURE" E_HOST_OVERLOADED = "E_HOST_OVERLOADED" # 插件生命周期 E_PLUGIN_CRASHED = "E_PLUGIN_CRASHED" E_PLUGIN_NOT_FOUND = "E_PLUGIN_NOT_FOUND" - E_GENERATION_MISMATCH = "E_GENERATION_MISMATCH" E_RELOAD_IN_PROGRESS = "E_RELOAD_IN_PROGRESS" # 能力调用 @@ -65,3 +65,13 @@ class RPCError(Exception): message=data.get("message", ""), details=data.get("details", {}), ) + + @classmethod + def from_exception(cls, exception: Exception, code_mapping: Optional[Dict[type[Exception], ErrorCode]] = None): + if isinstance(exception, cls): + return exception + if code_mapping: + for exception_type, code in code_mapping.items(): + if isinstance(exception, exception_type): + return cls(code=code, message=str(exception)) + return cls(ErrorCode.E_UNKNOWN, str(exception)) diff --git a/src/plugin_runtime/runner/log_handler.py b/src/plugin_runtime/runner/log_handler.py index b5a0a328..6f42940f 100644 --- a/src/plugin_runtime/runner/log_handler.py +++ b/src/plugin_runtime/runner/log_handler.py @@ -66,6 +66,12 @@ class RunnerIPCLogHandler(logging.Handler): ALLOWED_LOGGER_PREFIXES: tuple[str, ...] = ("plugin.", "plugin_runtime.", "_maibot_plugin_") def __init__(self) -> None: + """初始化 Runner 端日志转发处理器。 + + 创建有界日志缓冲区,并准备与 RPC 客户端绑定的后台刷新任务。 + 此时不会启动任何异步任务;真正开始转发要等到 :meth:`start` + 被调用后才会发生。 + """ super().__init__() # deque(maxlen=N): append/popleft 在 CPython GIL 保护下线程安全 self._buffer: collections.deque[LogEntry] = collections.deque(maxlen=self.QUEUE_MAX) diff --git a/src/plugin_runtime/runner/manifest_validator.py b/src/plugin_runtime/runner/manifest_validator.py index b6990850..33c2b1e5 100644 --- a/src/plugin_runtime/runner/manifest_validator.py +++ b/src/plugin_runtime/runner/manifest_validator.py @@ -1,28 +1,55 @@ -"""Manifest 校验与版本兼容性 +"""Manifest 校验与解析。 -从旧系统的 ManifestValidator / VersionComparator 对齐移植, -适配新 plugin_runtime 的 _manifest.json 格式。 +集中负责插件 ``_manifest.json`` 的读取、结构校验、运行时兼容性判断, +以及插件依赖/Python 包依赖的解析逻辑。 """ -from typing import Any, Dict, List, Tuple +from functools import lru_cache +from importlib import metadata as importlib_metadata +from pathlib import Path +from typing import Annotated, Any, Dict, Iterable, List, Literal, Optional, Tuple, Union +import json import re +import tomllib + +from packaging.requirements import InvalidRequirement, Requirement +from packaging.specifiers import InvalidSpecifier, SpecifierSet +from packaging.utils import canonicalize_name +from packaging.version import InvalidVersion, Version +from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator, model_validator from src.common.logger import get_logger logger = get_logger("plugin_runtime.runner.manifest_validator") +_SEMVER_PATTERN = re.compile(r"^\d+\.\d+\.\d+$") +_PLUGIN_ID_PATTERN = re.compile(r"^[a-z0-9]+(?:[.-][a-z0-9]+)+$") +_PACKAGE_NAME_PATTERN = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]*$") +_HTTP_URL_PATTERN = re.compile(r"^https?://.+$") + class VersionComparator: - """语义化版本号比较器""" + """语义化版本号比较器。""" @staticmethod def normalize_version(version: str) -> str: + """将版本号规范化为三段式语义版本字符串。 + + Args: + version: 原始版本号字符串。 + + Returns: + str: 规范化后的 ``major.minor.patch`` 形式版本号。 + 当输入为空或格式非法时返回 ``0.0.0``。 + """ if not version: return "0.0.0" - normalized = re.sub(r"-snapshot\.\d+", "", version.strip()) + + normalized = re.sub(r"-snapshot\.\d+", "", str(version).strip()) if not re.match(r"^\d+(\.\d+){0,2}$", normalized): return "0.0.0" + parts = normalized.split(".") while len(parts) < 3: parts.append("0") @@ -30,6 +57,15 @@ class VersionComparator: @staticmethod def parse_version(version: str) -> Tuple[int, int, int]: + """将版本字符串解析为可比较的整数元组。 + + Args: + version: 原始版本号字符串。 + + Returns: + Tuple[int, int, int]: 三段式版本号对应的整数元组。 + 当解析失败时返回 ``(0, 0, 0)``。 + """ normalized = VersionComparator.normalize_version(version) try: parts = normalized.split(".") @@ -39,98 +75,1072 @@ class VersionComparator: @staticmethod def compare(v1: str, v2: str) -> int: + """比较两个版本号的大小关系。 + + Args: + v1: 第一个版本号。 + v2: 第二个版本号。 + + Returns: + int: ``-1`` 表示 ``v1 < v2``,``1`` 表示 ``v1 > v2``, + ``0`` 表示两者相等。 + """ t1 = VersionComparator.parse_version(v1) t2 = VersionComparator.parse_version(v2) if t1 < t2: return -1 - elif t1 > t2: + if t1 > t2: return 1 return 0 @staticmethod def is_in_range(version: str, min_version: str = "", max_version: str = "") -> Tuple[bool, str]: + """判断版本号是否落在给定闭区间内。 + + Args: + version: 待检查的版本号。 + min_version: 允许的最小版本号,留空表示不限制下界。 + max_version: 允许的最大版本号,留空表示不限制上界。 + + Returns: + Tuple[bool, str]: 第一项表示是否满足要求,第二项为失败原因; + 当校验通过时第二项为空字符串。 + """ if not min_version and not max_version: return True, "" - vn = VersionComparator.normalize_version(version) + + normalized_version = VersionComparator.normalize_version(version) if min_version: - mn = VersionComparator.normalize_version(min_version) - if VersionComparator.compare(vn, mn) < 0: - return False, f"版本 {vn} 低于最小要求 {mn}" + normalized_min_version = VersionComparator.normalize_version(min_version) + if VersionComparator.compare(normalized_version, normalized_min_version) < 0: + return False, f"版本 {normalized_version} 低于最小要求 {normalized_min_version}" if max_version: - mx = VersionComparator.normalize_version(max_version) - if VersionComparator.compare(vn, mx) > 0: - return False, f"版本 {vn} 高于最大支持 {mx}" + normalized_max_version = VersionComparator.normalize_version(max_version) + if VersionComparator.compare(normalized_version, normalized_max_version) > 0: + return False, f"版本 {normalized_version} 高于最大支持 {normalized_max_version}" return True, "" + @staticmethod + def is_valid_semver(version: str) -> bool: + """判断字符串是否为严格三段式语义版本号。 + + Args: + version: 待检查的版本号字符串。 + + Returns: + bool: 是否满足 ``X.Y.Z`` 格式。 + """ + return bool(_SEMVER_PATTERN.fullmatch(str(version or "").strip())) + + +class _StrictManifestModel(BaseModel): + """Manifest 解析使用的严格基类模型。""" + + model_config = ConfigDict(extra="forbid", frozen=True, str_strip_whitespace=True) + + +class ManifestAuthor(_StrictManifestModel): + """插件作者信息。""" + + name: str = Field(description="作者名称") + url: str = Field(description="作者主页地址") + + @field_validator("name") + @classmethod + def _validate_name(cls, value: str) -> str: + """校验作者名称。 + + Args: + value: 原始作者名称。 + + Returns: + str: 规范化后的作者名称。 + + Raises: + ValueError: 当字段为空时抛出。 + """ + if not value: + raise ValueError("不能为空") + return value + + @field_validator("url") + @classmethod + def _validate_url(cls, value: str) -> str: + """校验作者主页地址。 + + Args: + value: 原始主页地址。 + + Returns: + str: 规范化后的主页地址。 + + Raises: + ValueError: 当字段为空或不是 HTTP/HTTPS URL 时抛出。 + """ + if not value: + raise ValueError("不能为空") + if not _HTTP_URL_PATTERN.fullmatch(value): + raise ValueError("必须为 http:// 或 https:// 开头的 URL") + return value + + +class ManifestUrls(_StrictManifestModel): + """插件相关链接集合。""" + + repository: str = Field(description="插件仓库地址") + homepage: Optional[str] = Field(default=None, description="插件主页地址") + documentation: Optional[str] = Field(default=None, description="插件文档地址") + issues: Optional[str] = Field(default=None, description="插件问题反馈地址") + + @field_validator("repository") + @classmethod + def _validate_repository(cls, value: str) -> str: + """校验仓库地址。 + + Args: + value: 原始仓库地址。 + + Returns: + str: 规范化后的仓库地址。 + + Raises: + ValueError: 当字段为空或不是 HTTP/HTTPS URL 时抛出。 + """ + if not value: + raise ValueError("不能为空") + if not _HTTP_URL_PATTERN.fullmatch(value): + raise ValueError("必须为 http:// 或 https:// 开头的 URL") + return value + + @field_validator("homepage", "documentation", "issues") + @classmethod + def _validate_optional_url(cls, value: Optional[str]) -> Optional[str]: + """校验可选链接字段。 + + Args: + value: 原始链接值。 + + Returns: + Optional[str]: 合法的链接值。 + + Raises: + ValueError: 当提供的值不是 HTTP/HTTPS URL 时抛出。 + """ + if value is None: + return None + if not value: + raise ValueError("不能为空字符串") + if not _HTTP_URL_PATTERN.fullmatch(value): + raise ValueError("必须为 http:// 或 https:// 开头的 URL") + return value + + +class ManifestVersionRange(_StrictManifestModel): + """版本闭区间声明。""" + + min_version: str = Field(description="最小版本,闭区间") + max_version: str = Field(description="最大版本,闭区间") + + @field_validator("min_version", "max_version") + @classmethod + def _validate_version(cls, value: str) -> str: + """校验版本号格式。 + + Args: + value: 原始版本号。 + + Returns: + str: 合法的版本号。 + + Raises: + ValueError: 当版本号不是严格三段式语义版本时抛出。 + """ + if not VersionComparator.is_valid_semver(value): + raise ValueError("必须为严格三段式版本号,例如 1.0.0") + return value + + @model_validator(mode="after") + def _validate_range(self) -> "ManifestVersionRange": + """校验版本区间上下界关系。 + + Returns: + ManifestVersionRange: 当前对象本身。 + + Raises: + ValueError: 当最小版本大于最大版本时抛出。 + """ + if VersionComparator.compare(self.min_version, self.max_version) > 0: + raise ValueError("min_version 不能大于 max_version") + return self + + +class ManifestI18n(_StrictManifestModel): + """国际化配置。""" + + default_locale: str = Field(description="默认语言") + locales_path: Optional[str] = Field(default=None, description="语言资源目录") + supported_locales: List[str] = Field(default_factory=list, description="支持的语言列表") + + @field_validator("default_locale") + @classmethod + def _validate_default_locale(cls, value: str) -> str: + """校验默认语言。 + + Args: + value: 原始默认语言。 + + Returns: + str: 规范化后的默认语言。 + + Raises: + ValueError: 当字段为空时抛出。 + """ + if not value: + raise ValueError("不能为空") + return value + + @field_validator("locales_path") + @classmethod + def _validate_locales_path(cls, value: Optional[str]) -> Optional[str]: + """校验语言资源目录。 + + Args: + value: 原始语言资源目录。 + + Returns: + Optional[str]: 合法的目录值。 + + Raises: + ValueError: 当值为空字符串时抛出。 + """ + if value is None: + return None + if not value: + raise ValueError("不能为空字符串") + return value + + @field_validator("supported_locales") + @classmethod + def _validate_supported_locales(cls, value: List[str]) -> List[str]: + """校验支持语言列表。 + + Args: + value: 原始语言列表。 + + Returns: + List[str]: 去重后的语言列表。 + + Raises: + ValueError: 当列表项为空时抛出。 + """ + normalized_locales: List[str] = [] + for locale in value: + normalized_locale = str(locale or "").strip() + if not normalized_locale: + raise ValueError("语言列表中存在空值") + if normalized_locale not in normalized_locales: + normalized_locales.append(normalized_locale) + return normalized_locales + + @model_validator(mode="after") + def _validate_default_locale_membership(self) -> "ManifestI18n": + """校验默认语言是否位于支持列表中。 + + Returns: + ManifestI18n: 当前对象本身。 + + Raises: + ValueError: 当 ``supported_locales`` 非空但未包含 ``default_locale`` 时抛出。 + """ + if self.supported_locales and self.default_locale not in self.supported_locales: + raise ValueError("default_locale 必须包含在 supported_locales 中") + return self + + +class PluginDependencyDefinition(_StrictManifestModel): + """插件级依赖声明。""" + + type: Literal["plugin"] = Field(description="依赖类型") + id: str = Field(description="依赖插件 ID") + version_spec: str = Field(description="版本约束表达式") + + @field_validator("id") + @classmethod + def _validate_id(cls, value: str) -> str: + """校验依赖插件 ID。 + + Args: + value: 原始依赖插件 ID。 + + Returns: + str: 合法的依赖插件 ID。 + + Raises: + ValueError: 当 ID 不符合规则时抛出。 + """ + if not _PLUGIN_ID_PATTERN.fullmatch(value): + raise ValueError("必须使用小写字母/数字,并以点号或横线分隔,例如 github.author.plugin") + return value + + @field_validator("version_spec") + @classmethod + def _validate_version_spec(cls, value: str) -> str: + """校验插件依赖版本约束。 + + Args: + value: 原始版本约束表达式。 + + Returns: + str: 合法的版本约束表达式。 + + Raises: + ValueError: 当表达式无效时抛出。 + """ + if not value: + raise ValueError("不能为空") + try: + SpecifierSet(value) + except InvalidSpecifier as exc: + raise ValueError(f"无效的版本约束: {exc}") from exc + return value + + +class PythonPackageDependencyDefinition(_StrictManifestModel): + """Python 包依赖声明。""" + + type: Literal["python_package"] = Field(description="依赖类型") + name: str = Field(description="Python 包名") + version_spec: str = Field(description="版本约束表达式") + + @field_validator("name") + @classmethod + def _validate_name(cls, value: str) -> str: + """校验 Python 包名。 + + Args: + value: 原始包名。 + + Returns: + str: 合法的包名。 + + Raises: + ValueError: 当包名不合法时抛出。 + """ + if not _PACKAGE_NAME_PATTERN.fullmatch(value): + raise ValueError("包名只能包含字母、数字、点号、下划线和横线") + return value + + @field_validator("version_spec") + @classmethod + def _validate_version_spec(cls, value: str) -> str: + """校验 Python 包版本约束。 + + Args: + value: 原始版本约束表达式。 + + Returns: + str: 合法的版本约束表达式。 + + Raises: + ValueError: 当表达式无效时抛出。 + """ + if not value: + raise ValueError("不能为空") + try: + Requirement(f"placeholder{value}") + except InvalidRequirement as exc: + raise ValueError(f"无效的版本约束: {exc}") from exc + return value + + +ManifestDependencyDefinition = Annotated[ + Union[PluginDependencyDefinition, PythonPackageDependencyDefinition], + Field(discriminator="type"), +] + + +class PluginManifest(_StrictManifestModel): + """插件 Manifest v2 强类型模型。""" + + manifest_version: Literal[2] = Field(description="Manifest 协议版本") + version: str = Field(description="插件版本") + name: str = Field(description="插件展示名称") + description: str = Field(description="插件描述") + author: ManifestAuthor = Field(description="插件作者信息") + license: str = Field(description="插件协议") + urls: ManifestUrls = Field(description="插件相关链接") + host_application: ManifestVersionRange = Field(description="Host 兼容区间") + sdk: ManifestVersionRange = Field(description="SDK 兼容区间") + dependencies: List[ManifestDependencyDefinition] = Field(default_factory=list, description="依赖声明") + capabilities: List[str] = Field(description="插件声明的能力请求") + i18n: ManifestI18n = Field(description="国际化配置") + id: str = Field(description="稳定插件 ID") + + @field_validator("version") + @classmethod + def _validate_version(cls, value: str) -> str: + """校验插件版本号格式。 + + Args: + value: 原始插件版本号。 + + Returns: + str: 合法的插件版本号。 + + Raises: + ValueError: 当版本号不是严格三段式语义版本时抛出。 + """ + if not VersionComparator.is_valid_semver(value): + raise ValueError("必须为严格三段式版本号,例如 1.0.0") + return value + + @field_validator("name", "description", "license", "id") + @classmethod + def _validate_required_string(cls, value: str, info: Any) -> str: + """校验必填字符串字段。 + + Args: + value: 原始字段值。 + info: Pydantic 字段上下文。 + + Returns: + str: 合法的字段值。 + + Raises: + ValueError: 当字段为空或格式不合法时抛出。 + """ + if not value: + raise ValueError("不能为空") + if info.field_name == "id" and not _PLUGIN_ID_PATTERN.fullmatch(value): + raise ValueError("必须使用小写字母/数字,并以点号或横线分隔,例如 github.author.plugin") + return value + + @field_validator("capabilities") + @classmethod + def _validate_capabilities(cls, value: List[str]) -> List[str]: + """校验能力声明列表。 + + Args: + value: 原始能力声明列表。 + + Returns: + List[str]: 去重后的能力列表。 + + Raises: + ValueError: 当列表为空项或能力名为空时抛出。 + """ + normalized_capabilities: List[str] = [] + for capability in value: + normalized_capability = str(capability or "").strip() + if not normalized_capability: + raise ValueError("capabilities 中存在空能力名") + if normalized_capability not in normalized_capabilities: + normalized_capabilities.append(normalized_capability) + return normalized_capabilities + + @model_validator(mode="after") + def _validate_dependencies(self) -> "PluginManifest": + """校验依赖声明集合。 + + Returns: + PluginManifest: 当前对象本身。 + + Raises: + ValueError: 当依赖项重复或插件依赖自身时抛出。 + """ + plugin_dependency_ids: set[str] = set() + python_package_names: set[str] = set() + + for dependency in self.dependencies: + if isinstance(dependency, PluginDependencyDefinition): + if dependency.id == self.id: + raise ValueError("dependencies 中的插件依赖不能依赖自身") + if dependency.id in plugin_dependency_ids: + raise ValueError(f"存在重复的插件依赖声明: {dependency.id}") + plugin_dependency_ids.add(dependency.id) + continue + + normalized_package_name = canonicalize_name(dependency.name) + if normalized_package_name in python_package_names: + raise ValueError(f"存在重复的 Python 包依赖声明: {dependency.name}") + python_package_names.add(normalized_package_name) + + return self + + @property + def plugin_dependencies(self) -> List[PluginDependencyDefinition]: + """返回插件级依赖列表。 + + Returns: + List[PluginDependencyDefinition]: 所有 ``type=plugin`` 的依赖项。 + """ + return [dependency for dependency in self.dependencies if isinstance(dependency, PluginDependencyDefinition)] + + @property + def python_package_dependencies(self) -> List[PythonPackageDependencyDefinition]: + """返回 Python 包依赖列表。 + + Returns: + List[PythonPackageDependencyDefinition]: 所有 ``type=python_package`` 的依赖项。 + """ + return [ + dependency + for dependency in self.dependencies + if isinstance(dependency, PythonPackageDependencyDefinition) + ] + + @property + def plugin_dependency_ids(self) -> List[str]: + """返回插件级依赖的插件 ID 列表。 + + Returns: + List[str]: 所有插件级依赖的插件 ID。 + """ + return [dependency.id for dependency in self.plugin_dependencies] + class ManifestValidator: - """_manifest.json 校验器""" + """严格的插件 Manifest v2 校验器。""" - REQUIRED_FIELDS = ["name", "version", "description", "author"] - RECOMMENDED_FIELDS = ["license", "keywords", "categories"] - SUPPORTED_MANIFEST_VERSIONS = [1, 2] + SUPPORTED_MANIFEST_VERSIONS = [2] - def __init__(self, host_version: str = "") -> None: - self._host_version = host_version + def __init__( + self, + host_version: str = "", + sdk_version: str = "", + project_root: Optional[Path] = None, + ) -> None: + """初始化 Manifest 校验器。 + + Args: + host_version: 当前 Host 版本号;留空时自动从主程序 ``pyproject.toml`` 读取。 + sdk_version: 当前 SDK 版本号;留空时自动从运行环境中探测。 + project_root: 项目根目录;留空时自动推断。 + """ + self._project_root: Path = project_root or self._resolve_project_root() + self._host_version: str = host_version or self._detect_default_host_version(self._project_root) + self._sdk_version: str = sdk_version or self._detect_default_sdk_version(self._project_root) self.errors: List[str] = [] self.warnings: List[str] = [] def validate(self, manifest: Dict[str, Any]) -> bool: - """校验 manifest 数据,返回是否通过(errors 为空即通过)。""" + """校验 manifest 数据,返回是否通过。 + + Args: + manifest: 待校验的 Manifest 原始字典。 + + Returns: + bool: 校验是否通过。 + """ + return self.parse_manifest(manifest) is not None + + def parse_manifest(self, manifest: Dict[str, Any]) -> Optional[PluginManifest]: + """解析并校验 manifest 字典。 + + Args: + manifest: 待解析的 Manifest 原始字典。 + + Returns: + Optional[PluginManifest]: 解析成功时返回强类型 Manifest;失败时返回 ``None``。 + """ self.errors.clear() self.warnings.clear() - self._check_required_fields(manifest) - self._check_manifest_version(manifest) - self._check_author(manifest) - self._check_host_compatibility(manifest) - self._check_recommended(manifest) + try: + parsed_manifest = PluginManifest.model_validate(manifest) + except ValidationError as exc: + self.errors.extend(self._format_validation_errors(exc)) + self._log_errors() + return None + self._validate_runtime_compatibility(parsed_manifest) if self.errors: - for e in self.errors: - logger.error(f"Manifest 校验失败: {e}") - if self.warnings: - for w in self.warnings: - logger.warning(f"Manifest 警告: {w}") + self._log_errors() + return None - return len(self.errors) == 0 + return parsed_manifest - def _check_required_fields(self, manifest: Dict[str, Any]) -> None: - for field in self.REQUIRED_FIELDS: - if field not in manifest: - self.errors.append(f"缺少必需字段: {field}") - elif not manifest[field]: - self.errors.append(f"必需字段不能为空: {field}") + def load_from_plugin_path(self, plugin_path: Path, require_entrypoint: bool = True) -> Optional[PluginManifest]: + """从插件目录读取并解析 manifest。 - def _check_manifest_version(self, manifest: Dict[str, Any]) -> None: - mv = manifest.get("manifest_version") - if mv is not None and mv not in self.SUPPORTED_MANIFEST_VERSIONS: - self.errors.append(f"不支持的 manifest_version: {mv},支持: {self.SUPPORTED_MANIFEST_VERSIONS}") + Args: + plugin_path: 单个插件目录路径。 + require_entrypoint: 是否要求目录内存在 ``plugin.py`` 入口文件。 - def _check_author(self, manifest: Dict[str, Any]) -> None: - author = manifest.get("author") - if author is None: - return - if isinstance(author, dict): - if "name" not in author or not author["name"]: - self.errors.append("author 对象缺少 name 字段") - elif isinstance(author, str): - if not author.strip(): - self.errors.append("author 不能为空") - else: - self.errors.append("author 应为字符串或 {name, url} 对象") + Returns: + Optional[PluginManifest]: 解析成功时返回强类型 Manifest;失败时返回 ``None``。 + """ + self.errors.clear() + self.warnings.clear() - def _check_host_compatibility(self, manifest: Dict[str, Any]) -> None: - host_app = manifest.get("host_application") - if not isinstance(host_app, dict) or not self._host_version: - return - min_v = host_app.get("min_version", "") - max_v = host_app.get("max_version", "") - ok, msg = VersionComparator.is_in_range(self._host_version, min_v, max_v) - if not ok: - self.errors.append(f"Host 版本不兼容: {msg} (当前 Host: {self._host_version})") + manifest_path = plugin_path / "_manifest.json" + entrypoint_path = plugin_path / "plugin.py" - def _check_recommended(self, manifest: Dict[str, Any]) -> None: - for field in self.RECOMMENDED_FIELDS: - if field not in manifest or not manifest[field]: - self.warnings.append(f"建议填写字段: {field}") + if not manifest_path.is_file(): + self.errors.append("缺少 _manifest.json") + return None + if require_entrypoint and not entrypoint_path.is_file(): + self.errors.append("缺少 plugin.py") + return None + + try: + with manifest_path.open("r", encoding="utf-8") as manifest_file: + manifest_data = json.load(manifest_file) + except Exception as exc: + self.errors.append(f"manifest 解析失败: {exc}") + self._log_errors() + return None + + if not isinstance(manifest_data, dict): + self.errors.append("manifest 顶层必须为 JSON 对象") + self._log_errors() + return None + + return self.parse_manifest(manifest_data) + + def iter_plugin_manifests( + self, + plugin_dirs: Iterable[Path], + require_entrypoint: bool = True, + ) -> Iterable[Tuple[Path, PluginManifest]]: + """扫描插件根目录并迭代所有可成功解析的 Manifest。 + + Args: + plugin_dirs: 一个或多个插件根目录。 + require_entrypoint: 是否要求每个插件目录内存在 ``plugin.py``。 + + Yields: + Tuple[Path, PluginManifest]: ``(插件目录路径, 解析结果)`` 二元组。 + """ + for plugin_root in plugin_dirs: + normalized_root = Path(plugin_root).resolve() + if not normalized_root.is_dir(): + continue + + for candidate_path in sorted(entry.resolve() for entry in normalized_root.iterdir() if entry.is_dir()): + parsed_manifest = self.load_from_plugin_path(candidate_path, require_entrypoint=require_entrypoint) + if parsed_manifest is None: + continue + yield candidate_path, parsed_manifest + + def build_plugin_dependency_map( + self, + plugin_dirs: Iterable[Path], + require_entrypoint: bool = True, + ) -> Dict[str, List[str]]: + """扫描目录并构建 ``plugin_id -> 依赖插件 ID 列表`` 映射。 + + Args: + plugin_dirs: 一个或多个插件根目录。 + require_entrypoint: 是否要求每个插件目录内存在 ``plugin.py``。 + + Returns: + Dict[str, List[str]]: 所有成功解析到的插件依赖映射。 + """ + dependency_map: Dict[str, List[str]] = {} + for _plugin_path, manifest in self.iter_plugin_manifests(plugin_dirs, require_entrypoint=require_entrypoint): + dependency_map[manifest.id] = manifest.plugin_dependency_ids + return dependency_map + + def read_plugin_id_from_plugin_path(self, plugin_path: Path, require_entrypoint: bool = True) -> Optional[str]: + """从单个插件目录中读取规范化后的插件 ID。 + + Args: + plugin_path: 单个插件目录路径。 + require_entrypoint: 是否要求目录内存在 ``plugin.py``。 + + Returns: + Optional[str]: 解析成功时返回插件 ID,否则返回 ``None``。 + """ + manifest = self.load_from_plugin_path(plugin_path, require_entrypoint=require_entrypoint) + if manifest is None: + return None + return manifest.id + + def get_unsatisfied_plugin_dependencies( + self, + manifest: PluginManifest, + available_plugin_versions: Dict[str, str], + ) -> List[str]: + """返回当前 Manifest 尚未满足的插件依赖项。 + + Args: + manifest: 目标插件的强类型 Manifest。 + available_plugin_versions: 当前可用插件版本映射,键为插件 ID,值为插件版本。 + + Returns: + List[str]: 未满足依赖的错误描述列表。 + """ + unsatisfied_dependencies: List[str] = [] + for dependency in manifest.plugin_dependencies: + dependency_version = available_plugin_versions.get(dependency.id) + if not dependency_version: + unsatisfied_dependencies.append(f"{dependency.id} (未找到依赖插件)") + continue + + if not self._version_matches_specifier(dependency_version, dependency.version_spec): + unsatisfied_dependencies.append( + f"{dependency.id} (需要 {dependency.version_spec},当前 {dependency_version})" + ) + + return unsatisfied_dependencies + + def is_plugin_dependency_satisfied( + self, + dependency: PluginDependencyDefinition, + plugin_version: str, + ) -> bool: + """判断单个插件依赖是否被指定版本满足。 + + Args: + dependency: 插件级依赖声明。 + plugin_version: 当前可用的插件版本号。 + + Returns: + bool: 是否满足版本约束。 + """ + return self._version_matches_specifier(plugin_version, dependency.version_spec) + + def _validate_runtime_compatibility(self, manifest: PluginManifest) -> None: + """校验运行时版本兼容性与 Python 包依赖。 + + Args: + manifest: 已通过结构校验的强类型 Manifest。 + """ + host_ok, host_message = VersionComparator.is_in_range( + self._host_version, + manifest.host_application.min_version, + manifest.host_application.max_version, + ) + if not host_ok: + self.errors.append(f"Host 版本不兼容: {host_message} (当前 Host: {self._host_version})") + + sdk_ok, sdk_message = VersionComparator.is_in_range( + self._sdk_version, + manifest.sdk.min_version, + manifest.sdk.max_version, + ) + if not sdk_ok: + self.errors.append(f"SDK 版本不兼容: {sdk_message} (当前 SDK: {self._sdk_version})") + + self._validate_python_package_dependencies(manifest) + + def _validate_python_package_dependencies(self, manifest: PluginManifest) -> None: + """校验 Python 包依赖与主程序运行环境是否冲突。 + + Args: + manifest: 已通过结构校验的强类型 Manifest。 + """ + host_requirements = self._load_host_dependency_requirements(self._project_root) + + for dependency in manifest.python_package_dependencies: + normalized_package_name = canonicalize_name(dependency.name) + package_specifier = self._build_specifier_set(dependency.version_spec) + if package_specifier is None: + self.errors.append( + f"Python 包依赖 {dependency.name} 的版本约束无效: {dependency.version_spec}" + ) + continue + + installed_version = self._get_installed_package_version(dependency.name) + host_requirement = host_requirements.get(normalized_package_name) + + if installed_version is not None and not self._version_matches_specifier( + installed_version, + dependency.version_spec, + ): + self.errors.append( + f"Python 包依赖冲突: {dependency.name} 需要 {dependency.version_spec}," + f"当前运行环境为 {installed_version}" + ) + continue + + if host_requirement is None: + continue + + if not self._requirements_may_overlap(host_requirement.specifier, package_specifier): + host_specifier = str(host_requirement.specifier or "") + self.errors.append( + f"Python 包依赖冲突: {dependency.name} 需要 {dependency.version_spec}," + f"主程序依赖约束为 {host_specifier or '任意版本'}" + ) + + def _log_errors(self) -> None: + """输出当前累计的 Manifest 校验错误。""" + for error_message in self.errors: + logger.error(f"Manifest 校验失败: {error_message}") + + @classmethod + def _resolve_project_root(cls) -> Path: + """推断当前项目根目录。 + + Returns: + Path: 项目根目录路径。 + """ + return Path(__file__).resolve().parents[3] + + @classmethod + @lru_cache(maxsize=None) + def _detect_default_host_version(cls, project_root: Path) -> str: + """从主程序 ``pyproject.toml`` 探测 Host 版本号。 + + Args: + project_root: 项目根目录。 + + Returns: + str: 探测到的 Host 版本号;失败时返回空字符串。 + """ + pyproject_path = project_root / "pyproject.toml" + try: + with pyproject_path.open("rb") as pyproject_file: + pyproject_data = tomllib.load(pyproject_file) + except Exception: + return "" + + project_data = pyproject_data.get("project", {}) + if not isinstance(project_data, dict): + return "" + + raw_version = str(project_data.get("version", "") or "").strip() + if VersionComparator.is_valid_semver(raw_version): + return raw_version + return "" + + @classmethod + @lru_cache(maxsize=None) + def _detect_default_sdk_version(cls, project_root: Path) -> str: + """探测当前运行环境中的 SDK 版本号。 + + Args: + project_root: 项目根目录。 + + Returns: + str: 探测到的 SDK 版本号;失败时返回空字符串。 + """ + try: + raw_version = importlib_metadata.version("maibot-plugin-sdk") + if VersionComparator.is_valid_semver(raw_version): + return raw_version + except importlib_metadata.PackageNotFoundError: + pass + + sdk_pyproject_path = project_root / "packages" / "maibot-plugin-sdk" / "pyproject.toml" + try: + with sdk_pyproject_path.open("rb") as pyproject_file: + pyproject_data = tomllib.load(pyproject_file) + except Exception: + return "" + + project_data = pyproject_data.get("project", {}) + if not isinstance(project_data, dict): + return "" + + raw_version = str(project_data.get("version", "") or "").strip() + if VersionComparator.is_valid_semver(raw_version): + return raw_version + return "" + + @classmethod + @lru_cache(maxsize=None) + def _load_host_dependency_requirements(cls, project_root: Path) -> Dict[str, Requirement]: + """加载主程序 ``pyproject.toml`` 中声明的依赖约束。 + + Args: + project_root: 项目根目录。 + + Returns: + Dict[str, Requirement]: 以规范化包名为键的 Requirement 映射。 + """ + pyproject_path = project_root / "pyproject.toml" + try: + with pyproject_path.open("rb") as pyproject_file: + pyproject_data = tomllib.load(pyproject_file) + except Exception: + return {} + + project_data = pyproject_data.get("project", {}) + if not isinstance(project_data, dict): + return {} + + raw_dependencies = project_data.get("dependencies", []) + if not isinstance(raw_dependencies, list): + return {} + + requirements: Dict[str, Requirement] = {} + for raw_dependency in raw_dependencies: + dependency_text = str(raw_dependency or "").strip() + if not dependency_text: + continue + + try: + requirement = Requirement(dependency_text) + except InvalidRequirement: + continue + + requirements[canonicalize_name(requirement.name)] = requirement + + return requirements + + @staticmethod + def _get_installed_package_version(package_name: str) -> Optional[str]: + """获取当前运行环境中指定 Python 包的安装版本。 + + Args: + package_name: 待查询的包名。 + + Returns: + Optional[str]: 已安装版本号;未安装时返回 ``None``。 + """ + try: + return importlib_metadata.version(package_name) + except importlib_metadata.PackageNotFoundError: + return None + + @staticmethod + def _build_specifier_set(version_spec: str) -> Optional[SpecifierSet]: + """构造版本约束对象。 + + Args: + version_spec: 版本约束字符串。 + + Returns: + Optional[SpecifierSet]: 构造成功时返回约束对象,否则返回 ``None``。 + """ + try: + return SpecifierSet(version_spec) + except InvalidSpecifier: + return None + + @staticmethod + def _version_matches_specifier(version: str, version_spec: str) -> bool: + """判断版本是否满足给定约束。 + + Args: + version: 待判断的版本号。 + version_spec: 版本约束表达式。 + + Returns: + bool: 是否满足约束。 + """ + try: + normalized_version = Version(version) + specifier_set = SpecifierSet(version_spec) + except (InvalidVersion, InvalidSpecifier): + return False + return specifier_set.contains(normalized_version, prereleases=True) + + @classmethod + def _requirements_may_overlap(cls, left: SpecifierSet, right: SpecifierSet) -> bool: + """粗略判断两个版本约束是否存在交集。 + + Args: + left: 左侧版本约束。 + right: 右侧版本约束。 + + Returns: + bool: 若可能存在交集则返回 ``True``,否则返回 ``False``。 + """ + candidate_versions = cls._build_candidate_versions(left, right) + for candidate_version in candidate_versions: + if left.contains(candidate_version, prereleases=True) and right.contains(candidate_version, prereleases=True): + return True + return False + + @classmethod + def _build_candidate_versions(cls, left: SpecifierSet, right: SpecifierSet) -> List[Version]: + """为两个版本约束构造一组用于交集探测的候选版本。 + + Args: + left: 左侧版本约束。 + right: 右侧版本约束。 + + Returns: + List[Version]: 去重后的候选版本列表。 + """ + candidate_versions: List[Version] = [Version("0.0.0")] + for specifier in tuple(left) + tuple(right): + for candidate_version in cls._expand_candidate_versions(specifier.version): + if candidate_version not in candidate_versions: + candidate_versions.append(candidate_version) + return candidate_versions + + @staticmethod + def _expand_candidate_versions(raw_version: str) -> List[Version]: + """根据边界版本扩展出一组邻近候选版本。 + + Args: + raw_version: 约束中出现的边界版本字符串。 + + Returns: + List[Version]: 可用于交集探测的候选版本列表。 + """ + normalized_text = raw_version.replace("*", "0") + try: + boundary_version = Version(normalized_text) + except InvalidVersion: + return [] + + release_parts = list(boundary_version.release[:3]) + while len(release_parts) < 3: + release_parts.append(0) + major, minor, patch = release_parts[:3] + + candidates = { + Version(f"{major}.{minor}.{patch}"), + Version(f"{major}.{minor}.{patch + 1}"), + } + if patch > 0: + candidates.add(Version(f"{major}.{minor}.{patch - 1}")) + elif minor > 0: + candidates.add(Version(f"{major}.{minor - 1}.999")) + elif major > 0: + candidates.add(Version(f"{major - 1}.999.999")) + + return sorted(candidates) + + @classmethod + def _format_validation_errors(cls, exc: ValidationError) -> List[str]: + """将 Pydantic 校验错误转换为中文错误列表。 + + Args: + exc: Pydantic 抛出的校验异常。 + + Returns: + List[str]: 中文错误描述列表。 + """ + error_messages: List[str] = [] + for error in exc.errors(): + location = cls._format_error_location(error.get("loc", ())) + error_type = str(error.get("type", "")) + error_input = error.get("input") + error_context = error.get("ctx", {}) or {} + + if error_type == "missing": + error_messages.append(f"缺少必需字段: {location}") + elif error_type == "extra_forbidden": + error_messages.append(f"存在未声明字段: {location}") + elif error_type == "literal_error": + expected_values = error_context.get("expected") + error_messages.append(f"字段 {location} 的值不合法,必须为 {expected_values}") + elif error_type == "model_type": + error_messages.append(f"字段 {location} 必须为对象") + elif error_type.endswith("_type"): + error_messages.append(f"字段 {location} 的类型不正确") + elif error_type == "value_error": + error_messages.append(f"字段 {location} 校验失败: {error_context.get('error')}") + else: + error_messages.append(f"字段 {location} 校验失败: {error.get('msg', error_input)}") + + return error_messages + + @staticmethod + def _format_error_location(location: Tuple[Any, ...]) -> str: + """格式化校验错误字段路径。 + + Args: + location: Pydantic 提供的字段路径元组。 + + Returns: + str: 点号连接后的字段路径。 + """ + return ".".join(str(item) for item in location) if location else "" diff --git a/src/plugin_runtime/runner/plugin_loader.py b/src/plugin_runtime/runner/plugin_loader.py index 11ba45e7..6e85714b 100644 --- a/src/plugin_runtime/runner/plugin_loader.py +++ b/src/plugin_runtime/runner/plugin_loader.py @@ -13,16 +13,16 @@ from typing import Any, Dict, Iterator, List, Optional, Set, Tuple import contextlib import importlib import importlib.util -import json import os +import re import sys from src.common.logger import get_logger -from src.plugin_runtime.runner.manifest_validator import ManifestValidator +from src.plugin_runtime.runner.manifest_validator import ManifestValidator, PluginManifest logger = get_logger("plugin_runtime.runner.plugin_loader") -PluginCandidate = Tuple[Path, Dict[str, Any], Path] +PluginCandidate = Tuple[Path, PluginManifest, Path] class PluginMeta: @@ -32,28 +32,28 @@ class PluginMeta: self, plugin_id: str, plugin_dir: str, + module_name: str, plugin_instance: Any, - manifest: Dict[str, Any], + manifest: PluginManifest, ) -> None: + """初始化插件元数据。 + + Args: + plugin_id: 插件 ID。 + plugin_dir: 插件目录绝对路径。 + module_name: 插件入口模块名。 + plugin_instance: 插件实例对象。 + manifest: 解析后的强类型 Manifest。 + """ self.plugin_id = plugin_id self.plugin_dir = plugin_dir + self.module_name = module_name self.instance = plugin_instance self.manifest = manifest - self.version = manifest.get("version", "1.0.0") - self.capabilities_required = manifest.get("capabilities", []) - self.dependencies: List[str] = self._extract_dependencies(manifest) - - @staticmethod - def _extract_dependencies(manifest: Dict[str, Any]) -> List[str]: - raw = manifest.get("dependencies", []) - result: List[str] = [] - for dep in raw: - if isinstance(dep, str): - result.append(dep.strip()) - elif isinstance(dep, dict): - if name := str(dep.get("name", "")).strip(): - result.append(name) - return result + self.version = manifest.version + self.capabilities_required = list(manifest.capabilities) + self.dependencies: List[str] = list(manifest.plugin_dependency_ids) + self.component_handlers: Dict[str, str] = {} class PluginLoader: @@ -66,30 +66,52 @@ class PluginLoader: """ def __init__(self, host_version: str = "") -> None: + """初始化插件加载器。 + + Args: + host_version: Host 版本号,用于 manifest 兼容性校验。 + """ self._loaded_plugins: Dict[str, PluginMeta] = {} self._failed_plugins: Dict[str, str] = {} self._manifest_validator = ManifestValidator(host_version=host_version) self._compat_hook_installed = False - def discover_and_load(self, plugin_dirs: List[str]) -> List[PluginMeta]: - """扫描多个目录并加载所有插件(含依赖排序和 manifest 校验) + def discover_and_load( + self, + plugin_dirs: List[str], + extra_available: Optional[Dict[str, str]] = None, + ) -> List[PluginMeta]: + """扫描多个目录并加载所有插件。 Args: - plugin_dirs: 插件目录列表 + plugin_dirs: 插件目录列表。 + extra_available: 额外视为已满足的外部依赖插件版本映射。 Returns: - 成功加载的插件元数据列表(按依赖顺序) + List[PluginMeta]: 成功加载的插件元数据列表,按依赖顺序排列。 """ candidates, duplicate_candidates = self._discover_candidates(plugin_dirs) self._record_duplicate_candidates(duplicate_candidates) # 第二阶段:依赖解析(拓扑排序) - load_order, failed_deps = self._resolve_dependencies(candidates) + load_order, failed_deps = self._resolve_dependencies(candidates, extra_available=extra_available) self._record_failed_dependencies(failed_deps) # 第三阶段:按依赖顺序加载 return self._load_plugins_in_order(load_order, candidates) + def discover_candidates(self, plugin_dirs: List[str]) -> Tuple[Dict[str, PluginCandidate], Dict[str, List[Path]]]: + """扫描插件目录并返回候选插件。 + + Args: + plugin_dirs: 需要扫描的插件根目录列表。 + + Returns: + Tuple[Dict[str, PluginCandidate], Dict[str, List[Path]]]: + 候选插件映射和重复插件 ID 冲突映射。 + """ + return self._discover_candidates(plugin_dirs) + def _discover_candidates(self, plugin_dirs: List[str]) -> Tuple[Dict[str, PluginCandidate], Dict[str, List[Path]]]: """扫描插件目录并收集候选插件。""" candidates: Dict[str, PluginCandidate] = {} @@ -123,26 +145,17 @@ class PluginLoader: def _discover_single_candidate(self, plugin_dir: Path) -> Optional[Tuple[str, PluginCandidate]]: """发现并校验单个插件目录。""" - manifest_path = plugin_dir / "_manifest.json" plugin_path = plugin_dir / "plugin.py" - - if not manifest_path.exists() or not plugin_path.exists(): + if not plugin_path.exists(): return None - try: - with manifest_path.open("r", encoding="utf-8") as manifest_file: - manifest: Dict[str, Any] = json.load(manifest_file) - except Exception as e: - self._failed_plugins[plugin_dir.name] = f"manifest 解析失败: {e}" - logger.error(f"插件 {plugin_dir.name} manifest 解析失败: {e}") - return None - - if not self._manifest_validator.validate(manifest): + manifest = self._manifest_validator.load_from_plugin_path(plugin_dir) + if manifest is None: errors = "; ".join(self._manifest_validator.errors) self._failed_plugins[plugin_dir.name] = f"manifest 校验失败: {errors}" return None - plugin_id = str(manifest.get("name", plugin_dir.name)).strip() or plugin_dir.name + plugin_id = manifest.id return plugin_id, (plugin_dir, manifest, plugin_path) def _record_duplicate_candidates(self, duplicate_candidates: Dict[str, List[Path]]) -> None: @@ -170,7 +183,6 @@ class PluginLoader: plugin_dir, manifest, plugin_path = candidates[plugin_id] try: if meta := self._load_single_plugin(plugin_id, plugin_dir, manifest, plugin_path): - self._loaded_plugins[meta.plugin_id] = meta results.append(meta) except Exception as e: self._failed_plugins[plugin_id] = str(e) @@ -182,45 +194,193 @@ class PluginLoader: """获取已加载的插件""" return self._loaded_plugins.get(plugin_id) + def set_loaded_plugin(self, meta: PluginMeta) -> None: + """登记一个已经完成初始化的插件。 + + Args: + meta: 待登记的插件元数据。 + """ + self._loaded_plugins[meta.plugin_id] = meta + + def remove_loaded_plugin(self, plugin_id: str) -> Optional[PluginMeta]: + """移除一个已加载插件的元数据。 + + Args: + plugin_id: 待移除的插件 ID。 + + Returns: + Optional[PluginMeta]: 被移除的插件元数据;不存在时返回 ``None``。 + """ + return self._loaded_plugins.pop(plugin_id, None) + + def purge_plugin_modules(self, plugin_id: str, plugin_dir: str) -> List[str]: + """清理指定插件目录下的模块缓存。 + + Args: + plugin_id: 插件 ID。 + plugin_dir: 插件目录绝对路径。 + + Returns: + List[str]: 已从 ``sys.modules`` 中移除的模块名列表。 + """ + removed_modules: List[str] = [] + plugin_path = Path(plugin_dir).resolve() + synthetic_module_name = self._build_safe_module_name(plugin_id) + + for module_name, module in list(sys.modules.items()): + if module_name == synthetic_module_name: + removed_modules.append(module_name) + sys.modules.pop(module_name, None) + continue + + module_file = getattr(module, "__file__", None) + if module_file is None: + continue + + try: + module_path = Path(module_file).resolve() + except Exception: + continue + + if module_path.is_relative_to(plugin_path): + removed_modules.append(module_name) + sys.modules.pop(module_name, None) + + importlib.invalidate_caches() + return removed_modules + + @staticmethod + def _build_safe_module_name(plugin_id: str) -> str: + """将插件 ID 转换为可用于动态导入的安全模块名。 + + Args: + plugin_id: 原始插件 ID。 + + Returns: + str: 仅包含字母、数字和下划线的合成模块名。 + """ + normalized_plugin_id = re.sub(r"[^0-9A-Za-z_]", "_", str(plugin_id or "").strip()) + if normalized_plugin_id and normalized_plugin_id[0].isdigit(): + normalized_plugin_id = f"_{normalized_plugin_id}" + return f"_maibot_plugin_{normalized_plugin_id or 'plugin'}" + def list_plugins(self) -> List[str]: """列出所有已加载的插件 ID""" return list(self._loaded_plugins.keys()) @property def failed_plugins(self) -> Dict[str, str]: + """返回当前记录的失败插件原因映射。""" return dict(self._failed_plugins) + @property + def manifest_validator(self) -> ManifestValidator: + """返回当前加载器持有的 Manifest 校验器。 + + Returns: + ManifestValidator: 当前使用的 Manifest 校验器实例。 + """ + return self._manifest_validator + # ──── 依赖解析 ──────────────────────────────────────────── + def resolve_dependencies( + self, + candidates: Dict[str, PluginCandidate], + extra_available: Optional[Dict[str, str]] = None, + ) -> Tuple[List[str], Dict[str, str]]: + """解析候选插件的依赖顺序。 + + Args: + candidates: 待加载的候选插件集合。 + extra_available: 视为已满足的外部依赖插件版本映射。 + + Returns: + Tuple[List[str], Dict[str, str]]: 可加载顺序和失败原因映射。 + """ + return self._resolve_dependencies(candidates, extra_available=extra_available) + + def load_candidate(self, plugin_id: str, candidate: PluginCandidate) -> Optional[PluginMeta]: + """加载单个候选插件模块。 + + Args: + plugin_id: 插件 ID。 + candidate: 候选插件三元组。 + + Returns: + Optional[PluginMeta]: 加载成功的插件元数据;失败时返回 ``None``。 + """ + plugin_dir, manifest, plugin_path = candidate + return self._load_single_plugin(plugin_id, plugin_dir, manifest, plugin_path) + def _resolve_dependencies( self, candidates: Dict[str, PluginCandidate], + extra_available: Optional[Dict[str, str]] = None, ) -> Tuple[List[str], Dict[str, str]]: """拓扑排序解析加载顺序,返回 (有序列表, 失败项 {id: reason})。""" available = set(candidates.keys()) + satisfied_dependencies = { + str(plugin_id or "").strip(): str(plugin_version or "").strip() + for plugin_id, plugin_version in (extra_available or {}).items() + if str(plugin_id or "").strip() and str(plugin_version or "").strip() + } dep_graph: Dict[str, Set[str]] = {} failed: Dict[str, str] = {} for pid, (_, manifest, _) in candidates.items(): - raw_deps = manifest.get("dependencies", []) resolved: Set[str] = set() - missing: List[str] = [] - for dep in raw_deps: - dep_name = dep if isinstance(dep, str) else str(dep.get("name", "")) - dep_name = dep_name.strip() - if not dep_name or dep_name == pid: + missing_or_incompatible: List[str] = [] + + for dependency in manifest.plugin_dependencies: + dependency_id = dependency.id + if dependency_id in available: + dependency_manifest = candidates[dependency_id][1] + if not self._manifest_validator.is_plugin_dependency_satisfied( + dependency, + dependency_manifest.version, + ): + missing_or_incompatible.append( + f"{dependency_id} (需要 {dependency.version_spec},当前 {dependency_manifest.version})" + ) + continue + resolved.add(dependency_id) continue - if dep_name in available: - resolved.add(dep_name) - else: - missing.append(dep_name) - if missing: - failed[pid] = f"缺少依赖: {', '.join(missing)}" + + external_dependency_version = satisfied_dependencies.get(dependency_id) + if external_dependency_version is None: + missing_or_incompatible.append(f"{dependency_id} (未找到依赖插件)") + continue + + if not self._manifest_validator.is_plugin_dependency_satisfied( + dependency, + external_dependency_version, + ): + missing_or_incompatible.append( + f"{dependency_id} (需要 {dependency.version_spec},当前 {external_dependency_version})" + ) + + if missing_or_incompatible: + failed[pid] = f"依赖未满足: {', '.join(missing_or_incompatible)}" dep_graph[pid] = resolved - # 移除失败项 - for pid in failed: - dep_graph.pop(pid, None) + # 迭代传播“依赖自身加载失败”到上游依赖方,避免误报为循环依赖 + changed = True + while changed: + changed = False + failed_plugin_ids = set(failed) + for pid, dependencies in list(dep_graph.items()): + if pid in failed: + dep_graph.pop(pid, None) + continue + + failed_dependencies = sorted(dependency for dependency in dependencies if dependency in failed_plugin_ids) + if not failed_dependencies: + continue + + failed[pid] = f"依赖未满足: {', '.join(f'{dependency} (依赖插件加载失败)' for dependency in failed_dependencies)}" + dep_graph.pop(pid, None) + changed = True # Kahn 拓扑排序 indegree = {pid: len(deps) for pid, deps in dep_graph.items()} @@ -253,7 +413,7 @@ class PluginLoader: self, plugin_id: str, plugin_dir: Path, - manifest: Dict[str, Any], + manifest: PluginManifest, plugin_path: Path, ) -> Optional[PluginMeta]: """加载单个插件""" @@ -261,8 +421,12 @@ class PluginLoader: self._ensure_compat_hook() # 动态导入插件模块 - module_name = f"_maibot_plugin_{plugin_id}" - spec = importlib.util.spec_from_file_location(module_name, str(plugin_path)) + module_name = self._build_safe_module_name(plugin_id) + spec = importlib.util.spec_from_file_location( + module_name, + str(plugin_path), + submodule_search_locations=[str(plugin_dir)], + ) if spec is None or spec.loader is None: logger.error(f"无法创建模块 spec: {plugin_path}") return None @@ -271,37 +435,73 @@ class PluginLoader: sys.modules[module_name] = module plugin_parent_dir = plugin_dir.parent - with self._temporary_sys_path_entry(plugin_parent_dir): - spec.loader.exec_module(module) + try: + with self._temporary_sys_path_entry(plugin_parent_dir): + spec.loader.exec_module(module) - # 优先使用新版 create_plugin 工厂函数 - create_plugin = getattr(module, "create_plugin", None) - if create_plugin is not None: - instance = create_plugin() - logger.info(f"插件 {plugin_id} v{manifest.get('version', '?')} 加载成功") - return PluginMeta( - plugin_id=plugin_id, - plugin_dir=str(plugin_dir), - plugin_instance=instance, - manifest=manifest, - ) + # 优先使用新版 create_plugin 工厂函数 + create_plugin = getattr(module, "create_plugin", None) + if create_plugin is not None: + instance = create_plugin() + self._validate_sdk_plugin_contract(plugin_id, instance) + logger.info(f"插件 {plugin_id} v{manifest.version} 加载成功") + return PluginMeta( + plugin_id=plugin_id, + plugin_dir=str(plugin_dir), + module_name=module_name, + plugin_instance=instance, + manifest=manifest, + ) - # 回退:检测旧版 @register_plugin 标记的 BasePlugin 子类 - instance = self._try_load_legacy_plugin(module, plugin_id) - if instance is not None: - logger.info( - f"插件 {plugin_id} v{manifest.get('version', '?')} 通过旧版兼容层加载成功(请尽快迁移到 maibot_sdk)" - ) - return PluginMeta( - plugin_id=plugin_id, - plugin_dir=str(plugin_dir), - plugin_instance=instance, - manifest=manifest, - ) + # 回退:检测旧版 @register_plugin 标记的 BasePlugin 子类 + instance = self._try_load_legacy_plugin(module, plugin_id) + if instance is not None: + logger.info( + f"插件 {plugin_id} v{manifest.version} 通过旧版兼容层加载成功(请尽快迁移到 maibot_sdk)" + ) + return PluginMeta( + plugin_id=plugin_id, + plugin_dir=str(plugin_dir), + module_name=module_name, + plugin_instance=instance, + manifest=manifest, + ) + except Exception: + sys.modules.pop(module_name, None) + raise logger.error(f"插件 {plugin_id} 缺少 create_plugin 工厂函数且未检测到旧版 BasePlugin") return None + @staticmethod + def _validate_sdk_plugin_contract(plugin_id: str, instance: Any) -> None: + """校验 SDK 插件的基础契约。 + + Args: + plugin_id: 当前插件 ID。 + instance: ``create_plugin()`` 返回的插件实例。 + + Raises: + TypeError: 当插件未覆盖必需生命周期方法或订阅声明不合法时抛出。 + """ + + try: + from maibot_sdk.plugin import MaiBotPlugin + except ImportError: + return + + if not isinstance(instance, MaiBotPlugin): + return + + if type(instance).on_load is MaiBotPlugin.on_load: + raise TypeError(f"插件 {plugin_id} 必须实现 on_load()") + if type(instance).on_unload is MaiBotPlugin.on_unload: + raise TypeError(f"插件 {plugin_id} 必须实现 on_unload()") + if type(instance).on_config_update is MaiBotPlugin.on_config_update: + raise TypeError(f"插件 {plugin_id} 必须实现 on_config_update()") + + instance.get_config_reload_subscriptions() + @staticmethod @contextlib.contextmanager def _temporary_sys_path_entry(path: Path) -> Iterator[None]: diff --git a/src/plugin_runtime/runner/rpc_client.py b/src/plugin_runtime/runner/rpc_client.py index 6a1d59d5..dc917cc8 100644 --- a/src/plugin_runtime/runner/rpc_client.py +++ b/src/plugin_runtime/runner/rpc_client.py @@ -1,14 +1,6 @@ -"""Runner 端 RPC Client +"""Runner 端 RPC 客户端。""" -负责: -1. 连接 Host RPC Server -2. 发送握手(runner.hello) -3. 发送组件注册请求 -4. 接收并分发 Host 的调用请求 -5. 发送能力调用请求到 Host -""" - -from typing import Any, Awaitable, Callable, Dict, Optional, cast +from typing import Any, Awaitable, Callable, Dict, Optional, Set, cast import asyncio import contextlib @@ -29,12 +21,15 @@ from src.plugin_runtime.transport.factory import create_transport_client logger = get_logger("plugin_runtime.runner.rpc_client") -# RPC 方法处理器类型 MethodHandler = Callable[[Envelope], Awaitable[Envelope]] def _get_sdk_version() -> str: - """从 maibot_sdk 包元数据中读取实际版本号,失败时回退到 1.0.0。""" + """读取 SDK 版本号。 + + Returns: + str: 已安装的 SDK 版本;读取失败时回退到 ``1.0.0``。 + """ try: from importlib.metadata import version @@ -47,73 +42,78 @@ SDK_VERSION = _get_sdk_version() class RPCClient: - """Runner 端 RPC 客户端 - - 管理与 Host 的 IPC 连接,支持双向 RPC 调用。 - """ + """Runner 端 RPC 客户端。""" def __init__( self, host_address: str, session_token: str, codec: Optional[Codec] = None, - ): - self._host_address = host_address - self._session_token = session_token - self._codec = codec or MsgPackCodec() + ) -> None: + """初始化 RPC 客户端。 + + Args: + host_address: Host 的 IPC 地址。 + session_token: 握手用会话令牌。 + codec: 可选的编解码器实现。 + """ + self._host_address: str = host_address + self._session_token: str = session_token + self._codec: Codec = codec or MsgPackCodec() self._id_gen = RequestIdGenerator() self._connection: Optional[Connection] = None - self._runner_id = str(uuid.uuid4()) - self._generation: int = 0 - - # 方法处理器注册表(Host 发来的调用) + self._runner_id: str = str(uuid.uuid4()) self._method_handlers: Dict[str, MethodHandler] = {} - - # 等待响应的 pending 请求: request_id -> Future - self._pending_requests: Dict[int, asyncio.Future] = {} - - # 运行状态 - self._running = False - self._recv_task: Optional[asyncio.Task] = None - self._background_tasks: set[asyncio.Task] = set() - - @property - def generation(self) -> int: - return self._generation + self._pending_requests: Dict[int, asyncio.Future[Envelope]] = {} + self._running: bool = False + self._recv_task: Optional[asyncio.Task[None]] = None + self._background_tasks: Set[asyncio.Task[Any]] = set() @property def is_connected(self) -> bool: + """返回当前连接是否可用。""" return self._connection is not None and not self._connection.is_closed def register_method(self, method: str, handler: MethodHandler) -> None: - """注册方法处理器(处理 Host 发来的请求)""" + """注册 Host -> Runner 的 RPC 处理器。 + + Args: + method: RPC 方法名。 + handler: 方法处理函数。 + """ self._method_handlers[method] = handler def _require_connection(self) -> Connection: - """返回当前可用连接;若连接不可用则抛出 RPCError。""" + """返回当前可用连接。 + + Returns: + Connection: 当前连接对象。 + + Raises: + RPCError: 当前未连接到 Host。 + """ connection = self._connection if connection is None or connection.is_closed: raise RPCError(ErrorCode.E_UNKNOWN, "未连接到 Host") return cast(Connection, connection) async def connect_and_handshake(self) -> bool: - """连接 Host 并完成握手 + """连接 Host 并完成握手。 Returns: - 是否握手成功 + bool: 是否握手成功。 """ client = create_transport_client(self._host_address) self._connection = await client.connect() connection = self._require_connection() - # 发送 runner.hello hello = HelloPayload( runner_id=self._runner_id, sdk_version=SDK_VERSION, session_token=self._session_token, ) - request_id = self._id_gen.next() + request_id = await self._id_gen.next() envelope = Envelope( request_id=request_id, message_type=MessageType.REQUEST, @@ -121,33 +121,27 @@ class RPCClient: payload=hello.model_dump(), ) - data = self._codec.encode_envelope(envelope) - await connection.send_frame(data) + await connection.send_frame(self._codec.encode_envelope(envelope)) - # 接收握手响应 resp_data = await asyncio.wait_for(connection.recv_frame(), timeout=10.0) - resp = self._codec.decode_envelope(resp_data) + response = self._codec.decode_envelope(resp_data) + resp_payload = HelloResponsePayload.model_validate(response.payload) - resp_payload = HelloResponsePayload.model_validate(resp.payload) if not resp_payload.accepted: logger.error(f"握手被拒绝: {resp_payload.reason}") - await self._connection.close() - self._connection = None + await self.disconnect() return False - self._generation = resp_payload.assigned_generation - logger.info(f"握手成功: generation={self._generation}, host_version={resp_payload.host_version}") - - # 启动消息接收循环 + logger.info(f"握手成功: host_version={resp_payload.host_version}") self._running = True - self._recv_task = asyncio.create_task(self._recv_loop()) - + self._recv_task = asyncio.create_task(self._recv_loop(), name="RPCClient.recv") return True async def disconnect(self) -> None: - """断开连接""" + """断开与 Host 的连接并清理状态。""" self._running = False - if self._recv_task: + + if self._recv_task is not None: self._recv_task.cancel() with contextlib.suppress(asyncio.CancelledError): await self._recv_task @@ -160,13 +154,12 @@ class RPCClient: await asyncio.gather(*self._background_tasks, return_exceptions=True) self._background_tasks.clear() - # 取消所有 pending 请求 for future in self._pending_requests.values(): if not future.done(): future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "连接关闭")) self._pending_requests.clear() - if self._connection: + if self._connection is not None: await self._connection.close() self._connection = None @@ -177,16 +170,27 @@ class RPCClient: payload: Optional[Dict[str, Any]] = None, timeout_ms: int = 30000, ) -> Envelope: - """向 Host 发送 RPC 请求并等待响应""" - connection = self._require_connection() + """向 Host 发送 RPC 请求并等待响应。 - request_id = self._id_gen.next() + Args: + method: RPC 方法名。 + plugin_id: 目标插件 ID。 + payload: 请求载荷。 + timeout_ms: 超时时间,单位毫秒。 + + Returns: + Envelope: Host 返回的响应信封。 + + Raises: + RPCError: 发送失败、超时或连接异常。 + """ + connection = self._require_connection() + request_id = await self._id_gen.next() envelope = Envelope( request_id=request_id, message_type=MessageType.REQUEST, method=method, plugin_id=plugin_id, - generation=self._generation, timeout_ms=timeout_ms, payload=payload or {}, ) @@ -196,21 +200,16 @@ class RPCClient: self._pending_requests[request_id] = future try: - data = self._codec.encode_envelope(envelope) - await connection.send_frame(data) - - timeout_sec = timeout_ms / 1000.0 - return await asyncio.wait_for(future, timeout=timeout_sec) + await connection.send_frame(self._codec.encode_envelope(envelope)) + return await asyncio.wait_for(future, timeout=timeout_ms / 1000.0) except asyncio.TimeoutError: self._pending_requests.pop(request_id, None) raise RPCError(ErrorCode.E_TIMEOUT, f"请求 {method} 超时 ({timeout_ms}ms)") from None - except Exception as e: + except Exception as exc: self._pending_requests.pop(request_id, None) - if isinstance(e, RPCError): + if isinstance(exc, RPCError): raise - raise RPCError(ErrorCode.E_UNKNOWN, str(e)) from e - - # ─── 内部方法 ────────────────────────────────────────────── + raise RPCError(ErrorCode.E_UNKNOWN, str(exc)) from exc async def send_event( self, @@ -218,33 +217,30 @@ class RPCClient: plugin_id: str = "", payload: Optional[Dict[str, Any]] = None, ) -> None: - """向 Host 发送单向事件(fire-and-forget,不等待响应)。 + """向 Host 发送单向广播消息。 Args: - method: RPC 方法名,如 "runner.log_batch"。 - plugin_id: 目标插件 ID(可为空,表示 Runner 级消息)。 - payload: 事件数据。 + method: RPC 方法名。 + plugin_id: 目标插件 ID。 + payload: 广播载荷。 """ if not self.is_connected: return connection = self._require_connection() - - request_id = self._id_gen.next() + request_id = await self._id_gen.next() envelope = Envelope( request_id=request_id, - message_type=MessageType.EVENT, + message_type=MessageType.BROADCAST, method=method, plugin_id=plugin_id, - generation=self._generation, payload=payload or {}, ) - data = self._codec.encode_envelope(envelope) - await connection.send_frame(data) + await connection.send_frame(self._codec.encode_envelope(envelope)) async def _recv_loop(self) -> None: - """消息接收主循环""" - while self._running and self._connection and not self._connection.is_closed: + """持续接收 Host 发来的消息并分发。""" + while self._running and self._connection is not None and not self._connection.is_closed: try: data = await self._connection.recv_frame() except (asyncio.IncompleteReadError, ConnectionError): @@ -252,39 +248,47 @@ class RPCClient: break except asyncio.CancelledError: break - except Exception as e: - logger.error(f"接收帧失败: {e}") + except Exception as exc: + logger.error(f"接收帧失败: {exc}") break try: envelope = self._codec.decode_envelope(data) - except Exception as e: - logger.error(f"解码消息失败: {e}") + except Exception as exc: + logger.error(f"解码消息失败: {exc}") continue if envelope.is_response(): self._handle_response(envelope) elif envelope.is_request(): self._track_background_task(asyncio.create_task(self._handle_request(envelope))) - elif envelope.is_event(): - self._track_background_task(asyncio.create_task(self._handle_event(envelope))) + elif envelope.is_broadcast(): + self._track_background_task(asyncio.create_task(self._handle_broadcast(envelope))) def _handle_response(self, envelope: Envelope) -> None: - """处理来自 Host 的响应""" + """处理 Host 返回的响应。 + + Args: + envelope: 响应信封。 + """ future = self._pending_requests.pop(envelope.request_id, None) - if future and not future.done(): - if envelope.error: - future.set_exception(RPCError.from_dict(envelope.error)) - else: - future.set_result(envelope) + if future is None or future.done(): + return + if envelope.error: + future.set_exception(RPCError.from_dict(envelope.error)) + else: + future.set_result(envelope) async def _handle_request(self, envelope: Envelope) -> None: - """处理来自 Host 的请求(调用插件组件)""" + """处理 Host 发来的请求。 + + Args: + envelope: 请求信封。 + """ connection = self._connection if connection is None or connection.is_closed: logger.warning(f"处理请求 {envelope.method} 时连接已关闭,跳过响应") return - connection = cast(Connection, connection) handler = self._method_handlers.get(envelope.method) if handler is None: @@ -298,23 +302,34 @@ class RPCClient: try: response = await handler(envelope) await connection.send_frame(self._codec.encode_envelope(response)) - except RPCError as e: - error_resp = envelope.make_error_response(e.code.value, e.message, e.details) + except RPCError as exc: + error_resp = envelope.make_error_response(exc.code.value, exc.message, exc.details) await connection.send_frame(self._codec.encode_envelope(error_resp)) - except Exception as e: - logger.error(f"处理请求 {envelope.method} 异常: {e}", exc_info=True) - error_resp = envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(e)) + except Exception as exc: + logger.error(f"处理请求 {envelope.method} 异常: {exc}", exc_info=True) + error_resp = envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(exc)) await connection.send_frame(self._codec.encode_envelope(error_resp)) - async def _handle_event(self, envelope: Envelope) -> None: - """处理来自 Host 的事件""" - if handler := self._method_handlers.get(envelope.method): - try: - await handler(envelope) - except Exception as e: - logger.error(f"处理事件 {envelope.method} 异常: {e}", exc_info=True) + async def _handle_broadcast(self, envelope: Envelope) -> None: + """处理 Host 发来的广播事件。 - def _track_background_task(self, task: asyncio.Task) -> None: - """保持后台任务强引用,直到其完成或被取消。""" + Args: + envelope: 广播信封。 + """ + handler = self._method_handlers.get(envelope.method) + if handler is None: + return + + try: + await handler(envelope) + except Exception as exc: + logger.error(f"处理广播 {envelope.method} 异常: {exc}", exc_info=True) + + def _track_background_task(self, task: asyncio.Task[Any]) -> None: + """持有后台任务强引用直到其结束。 + + Args: + task: 后台任务。 + """ self._background_tasks.add(task) task.add_done_callback(self._background_tasks.discard) diff --git a/src/plugin_runtime/runner/runner_main.py b/src/plugin_runtime/runner/runner_main.py index dae1cfa1..d1ebc064 100644 --- a/src/plugin_runtime/runner/runner_main.py +++ b/src/plugin_runtime/runner/runner_main.py @@ -9,13 +9,13 @@ 6. 转发插件的能力调用到 Host """ -from typing import Any, Callable, List, Optional, Protocol, cast - from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Protocol, Set, cast import asyncio import contextlib import inspect +import json import logging as stdlib_logging import os import signal @@ -24,27 +24,59 @@ import time import tomllib from src.common.logger import get_console_handler, get_logger, initialize_logging -from src.plugin_runtime import ENV_HOST_VERSION, ENV_IPC_ADDRESS, ENV_PLUGIN_DIRS, ENV_SESSION_TOKEN +from src.plugin_runtime import ( + ENV_EXTERNAL_PLUGIN_IDS, + ENV_HOST_VERSION, + ENV_IPC_ADDRESS, + ENV_PLUGIN_DIRS, + ENV_SESSION_TOKEN, +) from src.plugin_runtime.protocol.envelope import ( BootstrapPluginPayload, ComponentDeclaration, + ConfigUpdatedPayload, Envelope, HealthPayload, InvokePayload, InvokeResultPayload, - RegisterComponentsPayload, + RegisterPluginPayload, + ReloadPluginPayload, + ReloadPluginResultPayload, + ReloadPluginsPayload, + ReloadPluginsResultPayload, RunnerReadyPayload, + UnregisterPluginPayload, ) from src.plugin_runtime.protocol.errors import ErrorCode from src.plugin_runtime.runner.log_handler import RunnerIPCLogHandler -from src.plugin_runtime.runner.plugin_loader import PluginLoader, PluginMeta +from src.plugin_runtime.runner.plugin_loader import PluginCandidate, PluginLoader, PluginMeta from src.plugin_runtime.runner.rpc_client import RPCClient logger = get_logger("plugin_runtime.runner.main") +_PLUGIN_ALLOWED_RAW_HOST_METHODS = frozenset( + { + "cap.call", + "host.route_message", + "host.update_message_gateway_state", + } +) + class _ContextAwarePlugin(Protocol): - def _set_context(self, context: Any) -> None: ... + """支持注入运行时上下文的插件协议。 + + 该协议用于描述 Runner 在激活插件时依赖的最小接口。 + 只要插件实例实现了 ``_set_context`` 方法,就可以被 Runner + 注入 ``PluginContext`` 或兼容层上下文对象。 + """ + + def _set_context(self, context: Any) -> None: + """为插件实例注入运行时上下文。 + + Args: + context: 由 Runner 构造的上下文对象。 + """ def _install_shutdown_signal_handlers( @@ -89,22 +121,37 @@ class PluginRunner: host_address: str, session_token: str, plugin_dirs: List[str], + external_available_plugins: Optional[Dict[str, str]] = None, ) -> None: + """初始化 Runner。 + + Args: + host_address: Host 的 IPC 地址。 + session_token: 握手用会话令牌。 + plugin_dirs: 当前 Runner 负责扫描的插件目录列表。 + external_available_plugins: 视为已满足的外部依赖插件版本映射。 + """ self._host_address: str = host_address self._session_token: str = session_token - self._plugin_dirs: list[str] = plugin_dirs + self._plugin_dirs: List[str] = plugin_dirs + self._external_available_plugins: Dict[str, str] = { + str(plugin_id or "").strip(): str(plugin_version or "").strip() + for plugin_id, plugin_version in (external_available_plugins or {}).items() + if str(plugin_id or "").strip() and str(plugin_version or "").strip() + } self._rpc_client: RPCClient = RPCClient(host_address, session_token) self._loader: PluginLoader = PluginLoader(host_version=os.getenv(ENV_HOST_VERSION, "")) self._start_time: float = time.monotonic() self._shutting_down: bool = False + self._reload_lock: asyncio.Lock = asyncio.Lock() # IPC 日志 Handler:握手成功后安装,将所有 stdlib logging 转发到 Host self._log_handler: Optional[RunnerIPCLogHandler] = None - self._suspended_console_handlers: list[stdlib_logging.Handler] = [] + self._suspended_console_handlers: List[stdlib_logging.Handler] = [] async def run(self) -> None: - """Runner 主入口""" + """运行 Runner 主循环。""" # 1. 连接 Host logger.info(f"Runner 启动,连接 Host: {self._host_address}") ok = await self._rpc_client.connect_and_handshake() @@ -119,36 +166,18 @@ class PluginRunner: self._register_handlers() # 3. 加载插件 - plugins = self._loader.discover_and_load(self._plugin_dirs) + plugins = self._loader.discover_and_load( + self._plugin_dirs, + extra_available=self._external_available_plugins, + ) logger.info(f"已加载 {len(plugins)} 个插件") # 4. 注入 PluginContext + 调用 on_load 生命周期钩子 - failed_plugins: set[str] = set() + failed_plugins: Set[str] = set(self._loader.failed_plugins.keys()) for meta in plugins: - instance = meta.instance - self._inject_context(meta.plugin_id, instance) - self._apply_plugin_config(meta) - if not await self._bootstrap_plugin(meta): - failed_plugins.add(meta.plugin_id) - continue - if hasattr(instance, "on_load"): - try: - ret = instance.on_load() - if asyncio.iscoroutine(ret): - await ret - except Exception as e: - logger.error(f"插件 {meta.plugin_id} on_load 失败,跳过注册: {e}", exc_info=True) - failed_plugins.add(meta.plugin_id) - await self._deactivate_plugin(meta) - - # 5. 向 Host 注册所有插件的组件(跳过 on_load 失败的插件) - for meta in plugins: - if meta.plugin_id in failed_plugins: - continue - ok = await self._register_plugin(meta) + ok = await self._activate_plugin(meta) if not ok: failed_plugins.add(meta.plugin_id) - await self._deactivate_plugin(meta) successful_plugins = [meta.plugin_id for meta in plugins if meta.plugin_id not in failed_plugins] await self._notify_ready(successful_plugins, sorted(failed_plugins)) @@ -217,7 +246,7 @@ class PluginRunner: """为插件实例创建并注入 PluginContext。 对新版 MaiBotPlugin(具有 _set_context 方法):创建 PluginContext 并注入。 - 对旧版 LegacyPluginAdapter(具有 _set_context 方法,由适配器代理):同上。 + 对旧版 LegacyPluginAdapter(具有 _set_context 方法,由兼容代理封装):同上。 """ if not hasattr(instance, "_set_context"): return @@ -232,9 +261,11 @@ class PluginRunner: bound_plugin_id = plugin_id async def _rpc_call( - method: str, plugin_id: str = "", payload: Optional[dict[str, Any]] = None + method: str, + plugin_id: str = "", + payload: Optional[Dict[str, Any]] = None, ) -> Any: - """桥接 PluginContext.call_capability → RPCClient.send_request。 + """桥接 PluginContext 的原始 RPC 调用到 Host。 无论调用方传入何种 plugin_id,实际发往 Host 的 plugin_id 始终绑定为当前插件实例,避免伪造其他插件身份申请能力。 @@ -243,21 +274,26 @@ class PluginRunner: logger.warning( f"插件 {bound_plugin_id} 尝试以 {plugin_id} 身份发起 RPC,已强制绑定回自身身份" ) + normalized_method = str(method or "").strip() + if normalized_method not in _PLUGIN_ALLOWED_RAW_HOST_METHODS: + raise PermissionError( + f"插件 {bound_plugin_id} 不允许直接调用 Host 原始 RPC 方法: " + f"{normalized_method or ''}" + ) resp = await rpc_client.send_request( - method=method, + method=normalized_method, plugin_id=bound_plugin_id, payload=payload or {}, ) - # 从响应信封中提取业务结果 if resp.error: raise RuntimeError(resp.error.get("message", "能力调用失败")) - return resp.payload.get("result") + return resp.payload ctx = PluginContext(plugin_id=plugin_id, rpc_call=_rpc_call) cast(_ContextAwarePlugin, instance)._set_context(ctx) logger.debug(f"已为插件 {plugin_id} 注入 PluginContext") - def _apply_plugin_config(self, meta: PluginMeta, config_data: Optional[dict[str, Any]] = None) -> None: + def _apply_plugin_config(self, meta: PluginMeta, config_data: Optional[Dict[str, Any]] = None) -> None: """在 Runner 侧为插件实例注入当前插件配置。""" instance = meta.instance if not hasattr(instance, "set_plugin_config"): @@ -270,7 +306,7 @@ class PluginRunner: logger.warning(f"插件 {meta.plugin_id} 配置注入失败: {exc}") @staticmethod - def _load_plugin_config(plugin_dir: str) -> dict[str, Any]: + def _load_plugin_config(plugin_dir: str) -> Dict[str, Any]: """从插件目录读取 config.toml。""" config_path = Path(plugin_dir) / "config.toml" if not config_path.exists(): @@ -286,16 +322,60 @@ class PluginRunner: return loaded if isinstance(loaded, dict) else {} def _register_handlers(self) -> None: - """注册方法处理器""" + """注册 Host -> Runner 的方法处理器。""" self._rpc_client.register_method("plugin.invoke_command", self._handle_invoke) self._rpc_client.register_method("plugin.invoke_action", self._handle_invoke) + self._rpc_client.register_method("plugin.invoke_api", self._handle_invoke) self._rpc_client.register_method("plugin.invoke_tool", self._handle_invoke) + self._rpc_client.register_method("plugin.invoke_message_gateway", self._handle_invoke) self._rpc_client.register_method("plugin.emit_event", self._handle_event_invoke) + self._rpc_client.register_method("plugin.invoke_hook", self._handle_hook_invoke) self._rpc_client.register_method("plugin.invoke_workflow_step", self._handle_workflow_step) self._rpc_client.register_method("plugin.health", self._handle_health) self._rpc_client.register_method("plugin.prepare_shutdown", self._handle_prepare_shutdown) self._rpc_client.register_method("plugin.shutdown", self._handle_shutdown) self._rpc_client.register_method("plugin.config_updated", self._handle_config_updated) + self._rpc_client.register_method("plugin.reload", self._handle_reload_plugin) + self._rpc_client.register_method("plugin.reload_batch", self._handle_reload_plugins) + + @staticmethod + def _resolve_component_handler_name(meta: PluginMeta, component_name: str) -> str: + """解析组件名对应的真实处理函数名。 + + Args: + meta: 已加载插件的元数据。 + component_name: Host 侧请求中的组件声明名。 + + Returns: + str: 实际应在插件实例上查找的方法名。 + """ + return str(meta.component_handlers.get(component_name, component_name) or component_name) + + def _resolve_component_handler(self, meta: PluginMeta, component_name: str) -> Any: + """根据组件声明名解析插件实例上的可调用处理函数。 + + Args: + meta: 已加载插件的元数据。 + component_name: Host 侧请求中的组件声明名。 + + Returns: + Any: 解析到的可调用对象;未找到时返回 ``None``。 + """ + instance = meta.instance + handler_name = self._resolve_component_handler_name(meta, component_name) + handler_method = getattr(instance, handler_name, None) + if handler_method is not None: + return handler_method + + if handler_name != component_name: + legacy_style_handler = getattr(instance, f"handle_{component_name}", None) + if legacy_style_handler is not None: + return legacy_style_handler + + prefixed_handler = getattr(instance, f"handle_{component_name}", None) + if prefixed_handler is not None: + return prefixed_handler + return getattr(instance, component_name, None) async def _bootstrap_plugin(self, meta: PluginMeta, capabilities_required: Optional[List[str]] = None) -> bool: """向 Host 同步插件 bootstrap 能力令牌。""" @@ -308,12 +388,14 @@ class PluginRunner: ) try: - await self._rpc_client.send_request( + response = await self._rpc_client.send_request( "plugin.bootstrap", plugin_id=meta.plugin_id, payload=payload.model_dump(), timeout_ms=10000, ) + if response.error: + raise RuntimeError(response.error.get("message", "插件 bootstrap 失败")) return True except Exception as e: logger.error(f"插件 {meta.plugin_id} bootstrap 失败: {e}") @@ -324,45 +406,500 @@ class PluginRunner: await self._bootstrap_plugin(meta, capabilities_required=[]) async def _register_plugin(self, meta: PluginMeta) -> bool: - """向 Host 注册单个插件""" + """向 Host 注册单个插件。 + + Args: + meta: 待注册的插件元数据。 + + Returns: + bool: 是否注册成功。 + """ # 收集插件组件声明 components: List[ComponentDeclaration] = [] + config_reload_subscriptions: List[str] = [] instance = meta.instance # 从插件实例获取组件声明(SDK 插件须实现 get_components 方法) if hasattr(instance, "get_components"): - components.extend( - ComponentDeclaration( - name=comp_info.get("name", ""), - component_type=comp_info.get("type", ""), - plugin_id=meta.plugin_id, - metadata=comp_info.get("metadata", {}), - ) - for comp_info in instance.get_components() - ) + meta.component_handlers.clear() + for comp_info in instance.get_components(): + if not isinstance(comp_info, dict): + continue - reg_payload = RegisterComponentsPayload( + component_name = str(comp_info.get("name", "") or "").strip() + raw_metadata = comp_info.get("metadata", {}) + component_metadata = raw_metadata if isinstance(raw_metadata, dict) else {} + + if component_name: + handler_name = str(component_metadata.get("handler_name", component_name) or component_name).strip() + meta.component_handlers[component_name] = handler_name or component_name + + components.append( + ComponentDeclaration( + name=component_name, + component_type=str(comp_info.get("type", "") or "").strip(), + plugin_id=meta.plugin_id, + metadata=component_metadata, + ) + ) + if hasattr(instance, "get_config_reload_subscriptions"): + config_reload_subscriptions = list(instance.get_config_reload_subscriptions()) + + reg_payload = RegisterPluginPayload( plugin_id=meta.plugin_id, plugin_version=meta.version, components=components, capabilities_required=meta.capabilities_required, + dependencies=meta.dependencies, + config_reload_subscriptions=config_reload_subscriptions, ) try: - _resp = await self._rpc_client.send_request( + response = await self._rpc_client.send_request( "plugin.register_components", plugin_id=meta.plugin_id, payload=reg_payload.model_dump(), timeout_ms=10000, ) + if response.error: + raise RuntimeError(response.error.get("message", "插件注册失败")) logger.info(f"插件 {meta.plugin_id} 注册完成") return True except Exception as e: logger.error(f"插件 {meta.plugin_id} 注册失败: {e}") return False + async def _unregister_plugin(self, plugin_id: str, reason: str) -> None: + """通知 Host 注销指定插件。 + + Args: + plugin_id: 目标插件 ID。 + reason: 注销原因。 + """ + payload = UnregisterPluginPayload(plugin_id=plugin_id, reason=reason) + try: + await self._rpc_client.send_request( + "plugin.unregister", + plugin_id=plugin_id, + payload=payload.model_dump(), + timeout_ms=10000, + ) + except Exception as exc: + logger.warning(f"插件 {plugin_id} 注销通知失败: {exc}") + + async def _invoke_plugin_on_load(self, meta: PluginMeta) -> bool: + """执行插件的 ``on_load`` 生命周期。 + + Args: + meta: 待初始化的插件元数据。 + + Returns: + bool: 生命周期是否执行成功。 + """ + instance = meta.instance + if not hasattr(instance, "on_load"): + return True + + try: + result = instance.on_load() + if asyncio.iscoroutine(result): + await result + return True + except Exception as exc: + logger.error(f"插件 {meta.plugin_id} on_load 失败: {exc}", exc_info=True) + return False + + async def _invoke_plugin_on_unload(self, meta: PluginMeta) -> None: + """执行插件的 ``on_unload`` 生命周期。 + + Args: + meta: 待卸载的插件元数据。 + """ + instance = meta.instance + if not hasattr(instance, "on_unload"): + return + + try: + result = instance.on_unload() + if asyncio.iscoroutine(result): + await result + except Exception as exc: + logger.error(f"插件 {meta.plugin_id} on_unload 失败: {exc}", exc_info=True) + + async def _activate_plugin(self, meta: PluginMeta) -> bool: + """完成插件注入、授权、生命周期和组件注册。 + + Args: + meta: 待激活的插件元数据。 + + Returns: + bool: 是否激活成功。 + """ + self._inject_context(meta.plugin_id, meta.instance) + self._apply_plugin_config(meta) + + if not await self._bootstrap_plugin(meta): + self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir) + return False + + if not await self._register_plugin(meta): + await self._invoke_plugin_on_unload(meta) + await self._deactivate_plugin(meta) + self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir) + return False + + if not await self._invoke_plugin_on_load(meta): + await self._unregister_plugin(meta.plugin_id, reason="on_load_failed") + await self._deactivate_plugin(meta) + self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir) + return False + + self._loader.set_loaded_plugin(meta) + return True + + async def _unload_plugin(self, meta: PluginMeta, reason: str, *, purge_modules: bool = True) -> None: + """卸载单个插件并清理 Host/Runner 两侧状态。 + + Args: + meta: 待卸载的插件元数据。 + reason: 卸载原因。 + purge_modules: 是否在卸载完成后清理插件模块缓存。 + """ + await self._invoke_plugin_on_unload(meta) + await self._unregister_plugin(meta.plugin_id, reason) + await self._deactivate_plugin(meta) + self._loader.remove_loaded_plugin(meta.plugin_id) + if purge_modules: + self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir) + + def _collect_reverse_dependents(self, plugin_id: str) -> Set[str]: + """收集依赖指定插件的所有已加载插件。 + + Args: + plugin_id: 根插件 ID。 + + Returns: + Set[str]: 目标插件及其所有反向依赖插件集合。 + """ + impacted_plugins: Set[str] = {plugin_id} + changed = True + + while changed: + changed = False + for loaded_plugin_id in self._loader.list_plugins(): + if loaded_plugin_id in impacted_plugins: + continue + + meta = self._loader.get_plugin(loaded_plugin_id) + if meta is None: + continue + + if any(dependency in impacted_plugins for dependency in meta.dependencies): + impacted_plugins.add(loaded_plugin_id) + changed = True + + return impacted_plugins + + def _collect_reverse_dependents_for_roots(self, plugin_ids: Set[str]) -> Set[str]: + """收集多个根插件对应的反向依赖并集。 + + Args: + plugin_ids: 根插件 ID 集合。 + + Returns: + Set[str]: 所有根插件及其反向依赖并集。 + """ + + impacted_plugins: Set[str] = set() + for plugin_id in sorted(plugin_ids): + impacted_plugins.update(self._collect_reverse_dependents(plugin_id)) + return impacted_plugins + + def _build_unload_order(self, plugin_ids: Set[str]) -> List[str]: + """构建受影响插件的卸载顺序。 + + Args: + plugin_ids: 需要卸载的插件集合。 + + Returns: + List[str]: 依赖方优先的卸载顺序。 + """ + dependency_graph: Dict[str, Set[str]] = {} + for plugin_id in plugin_ids: + meta = self._loader.get_plugin(plugin_id) + if meta is None: + dependency_graph[plugin_id] = set() + continue + dependency_graph[plugin_id] = {dependency for dependency in meta.dependencies if dependency in plugin_ids} + + indegree: Dict[str, int] = {plugin_id: len(dependencies) for plugin_id, dependencies in dependency_graph.items()} + reverse_graph: Dict[str, Set[str]] = {plugin_id: set() for plugin_id in dependency_graph} + + for plugin_id, dependencies in dependency_graph.items(): + for dependency in dependencies: + reverse_graph.setdefault(dependency, set()).add(plugin_id) + + queue: List[str] = sorted(plugin_id for plugin_id, degree in indegree.items() if degree == 0) + load_order: List[str] = [] + + while queue: + current_plugin_id = queue.pop(0) + load_order.append(current_plugin_id) + for dependent_plugin_id in sorted(reverse_graph.get(current_plugin_id, set())): + indegree[dependent_plugin_id] -= 1 + if indegree[dependent_plugin_id] == 0: + queue.append(dependent_plugin_id) + queue.sort() + + return list(reversed(load_order)) + + @staticmethod + def _normalize_requested_plugin_ids(plugin_ids: List[str]) -> List[str]: + """规范化批量重载请求中的插件 ID 列表。""" + + normalized_plugin_ids: List[str] = [] + seen_plugin_ids: Set[str] = set() + for plugin_id in plugin_ids: + normalized_plugin_id = str(plugin_id or "").strip() + if not normalized_plugin_id or normalized_plugin_id in seen_plugin_ids: + continue + seen_plugin_ids.add(normalized_plugin_id) + normalized_plugin_ids.append(normalized_plugin_id) + return normalized_plugin_ids + + @staticmethod + def _finalize_failed_reload_messages( + failed_plugins: Dict[str, str], + rollback_failures: Dict[str, str], + ) -> Dict[str, str]: + """在重载失败后补充回滚结果说明。""" + + finalized_failures: Dict[str, str] = {} + for failed_plugin_id, failure_reason in failed_plugins.items(): + rollback_failure = rollback_failures.get(failed_plugin_id) + if rollback_failure: + finalized_failures[failed_plugin_id] = ( + f"{failure_reason};且旧版本恢复失败: {rollback_failure}" + ) + else: + finalized_failures[failed_plugin_id] = f"{failure_reason}(已恢复旧版本)" + + for failed_plugin_id, rollback_failure in rollback_failures.items(): + if failed_plugin_id not in finalized_failures: + finalized_failures[failed_plugin_id] = f"旧版本恢复失败: {rollback_failure}" + + return finalized_failures + + async def _reload_plugin_by_id( + self, + plugin_id: str, + reason: str, + external_available_plugins: Optional[Dict[str, str]] = None, + ) -> ReloadPluginResultPayload: + """按插件 ID 在 Runner 进程内执行精确重载。 + + Args: + plugin_id: 目标插件 ID。 + reason: 重载原因。 + external_available_plugins: 视为已满足的外部依赖插件版本映射。 + + Returns: + ReloadPluginResultPayload: 结构化重载结果。 + """ + batch_result = await self._reload_plugins_by_ids( + [plugin_id], + reason, + external_available_plugins=external_available_plugins, + ) + return ReloadPluginResultPayload( + success=batch_result.success, + requested_plugin_id=plugin_id, + reloaded_plugins=batch_result.reloaded_plugins, + unloaded_plugins=batch_result.unloaded_plugins, + failed_plugins=batch_result.failed_plugins, + ) + + async def _reload_plugins_by_ids( + self, + plugin_ids: List[str], + reason: str, + external_available_plugins: Optional[Dict[str, str]] = None, + ) -> ReloadPluginsResultPayload: + """按插件 ID 列表在 Runner 进程内执行一次批量重载。""" + + normalized_plugin_ids = self._normalize_requested_plugin_ids(plugin_ids) + if not normalized_plugin_ids: + return ReloadPluginsResultPayload(success=True, requested_plugin_ids=[]) + + candidates, duplicate_candidates = self._loader.discover_candidates(self._plugin_dirs) + failed_plugins: Dict[str, str] = {} + normalized_external_available = { + str(candidate_plugin_id or "").strip(): str(candidate_plugin_version or "").strip() + for candidate_plugin_id, candidate_plugin_version in (external_available_plugins or {}).items() + if str(candidate_plugin_id or "").strip() and str(candidate_plugin_version or "").strip() + } + + loaded_plugin_ids = set(self._loader.list_plugins()) + reload_root_ids: Set[str] = set() + for plugin_id in normalized_plugin_ids: + if plugin_id in duplicate_candidates: + conflict_paths = ", ".join(str(path) for path in duplicate_candidates[plugin_id]) + failed_plugins[plugin_id] = f"检测到重复插件 ID: {conflict_paths}" + continue + + plugin_is_loaded = plugin_id in loaded_plugin_ids + plugin_has_candidate = plugin_id in candidates + if not plugin_is_loaded and not plugin_has_candidate: + failed_plugins[plugin_id] = "插件不存在或未找到合法的 manifest/plugin.py" + continue + + reload_root_ids.add(plugin_id) + + if not reload_root_ids: + return ReloadPluginsResultPayload( + success=False, + requested_plugin_ids=normalized_plugin_ids, + failed_plugins=failed_plugins, + ) + + target_plugin_ids: Set[str] = { + plugin_id for plugin_id in reload_root_ids if plugin_id not in loaded_plugin_ids + } + if loaded_root_plugin_ids := reload_root_ids & loaded_plugin_ids: + target_plugin_ids.update(self._collect_reverse_dependents_for_roots(loaded_root_plugin_ids)) + + unload_order = self._build_unload_order(target_plugin_ids & loaded_plugin_ids) + unloaded_plugins: List[str] = [] + retained_plugin_ids = loaded_plugin_ids - set(unload_order) + rollback_metas: Dict[str, PluginMeta] = {} + + for unload_plugin_id in unload_order: + meta = self._loader.get_plugin(unload_plugin_id) + if meta is None: + continue + rollback_metas[unload_plugin_id] = meta + await self._unload_plugin(meta, reason=reason, purge_modules=False) + self._loader.purge_plugin_modules(unload_plugin_id, meta.plugin_dir) + unloaded_plugins.append(unload_plugin_id) + + reload_candidates: Dict[str, PluginCandidate] = {} + for target_plugin_id in target_plugin_ids: + candidate = candidates.get(target_plugin_id) + if candidate is None: + failed_plugins[target_plugin_id] = "插件目录已不存在" + continue + reload_candidates[target_plugin_id] = candidate + + load_order, dependency_failures = self._loader.resolve_dependencies( + reload_candidates, + extra_available={ + **normalized_external_available, + **{ + retained_plugin_id: retained_meta.version + for retained_plugin_id in retained_plugin_ids + if (retained_meta := self._loader.get_plugin(retained_plugin_id)) is not None + }, + }, + ) + failed_plugins.update(dependency_failures) + + available_plugins = { + **normalized_external_available, + **{ + retained_plugin_id: retained_meta.version + for retained_plugin_id in retained_plugin_ids + if (retained_meta := self._loader.get_plugin(retained_plugin_id)) is not None + }, + } + reloaded_plugins: List[str] = [] + + for load_plugin_id in load_order: + if load_plugin_id in failed_plugins: + continue + + candidate = reload_candidates.get(load_plugin_id) + if candidate is None: + continue + + _, manifest, _ = candidate + if unsatisfied_dependencies := self._loader.manifest_validator.get_unsatisfied_plugin_dependencies( + manifest, + available_plugin_versions=available_plugins, + ): + failed_plugins[load_plugin_id] = f"依赖未满足: {', '.join(unsatisfied_dependencies)}" + continue + + meta = self._loader.load_candidate(load_plugin_id, candidate) + if meta is None: + failed_plugins[load_plugin_id] = "插件模块加载失败" + continue + + activated = await self._activate_plugin(meta) + if not activated: + failed_plugins[load_plugin_id] = "插件初始化失败" + continue + + available_plugins[load_plugin_id] = meta.version + reloaded_plugins.append(load_plugin_id) + + if failed_plugins: + rollback_failures: Dict[str, str] = {} + + for reloaded_plugin_id in reversed(reloaded_plugins): + reloaded_meta = self._loader.get_plugin(reloaded_plugin_id) + if reloaded_meta is None: + continue + + try: + await self._unload_plugin( + reloaded_meta, + reason=f"{reason}_rollback_cleanup", + purge_modules=False, + ) + except Exception as exc: + rollback_failures[reloaded_plugin_id] = f"清理失败: {exc}" + finally: + self._loader.purge_plugin_modules(reloaded_plugin_id, reloaded_meta.plugin_dir) + + for rollback_plugin_id in reversed(unload_order): + rollback_meta = rollback_metas.get(rollback_plugin_id) + if rollback_meta is None: + continue + + try: + restored = await self._activate_plugin(rollback_meta) + except Exception as exc: + rollback_failures[rollback_plugin_id] = str(exc) + continue + + if not restored: + rollback_failures[rollback_plugin_id] = "无法重新激活旧版本" + + return ReloadPluginsResultPayload( + success=False, + requested_plugin_ids=normalized_plugin_ids, + reloaded_plugins=[], + unloaded_plugins=unloaded_plugins, + failed_plugins=self._finalize_failed_reload_messages(failed_plugins, rollback_failures), + ) + + requested_plugin_success = all(plugin_id in reloaded_plugins for plugin_id in reload_root_ids) + + return ReloadPluginsResultPayload( + success=requested_plugin_success and not failed_plugins, + requested_plugin_ids=normalized_plugin_ids, + reloaded_plugins=reloaded_plugins, + unloaded_plugins=unloaded_plugins, + failed_plugins=failed_plugins, + ) + async def _notify_ready(self, loaded_plugins: List[str], failed_plugins: List[str]) -> None: - """通知 Host 当前 generation 已完成插件初始化。""" + """通知 Host 当前 Runner 已完成插件初始化。 + + Args: + loaded_plugins: 成功初始化的插件列表。 + failed_plugins: 初始化失败的插件列表。 + """ payload = RunnerReadyPayload( loaded_plugins=loaded_plugins, failed_plugins=failed_plugins, @@ -388,19 +925,13 @@ class PluginRunner: f"插件 {plugin_id} 未加载", ) - # 调用插件实例的组件方法 - instance = meta.instance component_name = invoke.component_name - - # 优先查找 handle_ 或直接 方法(新版 SDK 插件) - handler_method = getattr(instance, f"handle_{component_name}", None) - if handler_method is None: - handler_method = getattr(instance, component_name, None) + handler_method = self._resolve_component_handler(meta, component_name) # 回退: 旧版 LegacyPluginAdapter 通过 invoke_component 统一桥接 - if (handler_method is None or not callable(handler_method)) and hasattr(instance, "invoke_component"): + if (handler_method is None or not callable(handler_method)) and hasattr(meta.instance, "invoke_component"): try: - result = await instance.invoke_component(component_name, **invoke.args) + result = await meta.instance.invoke_component(component_name, **invoke.args) resp_payload = InvokeResultPayload(success=True, result=result) return envelope.make_response(payload=resp_payload.model_dump()) except Exception as e: @@ -447,11 +978,8 @@ class PluginRunner: f"插件 {plugin_id} 未加载", ) - instance = meta.instance component_name = invoke.component_name - handler_method = getattr(instance, f"handle_{component_name}", None) - if handler_method is None: - handler_method = getattr(instance, component_name, None) + handler_method = self._resolve_component_handler(meta, component_name) if handler_method is None or not callable(handler_method): return envelope.make_error_response( @@ -487,6 +1015,60 @@ class PluginRunner: logger.error(f"插件 {plugin_id} event_handler {component_name} 执行异常: {e}", exc_info=True) return envelope.make_response(payload={"success": False, "continue_processing": True}) + async def _handle_hook_invoke(self, envelope: Envelope) -> Envelope: + """处理 HookHandler 调用请求。 + + Args: + envelope: RPC 请求信封。 + + Returns: + Envelope: 标准化后的 Hook 调用结果。 + """ + try: + invoke = InvokePayload.model_validate(envelope.payload) + except Exception as exc: + return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc)) + + plugin_id = envelope.plugin_id + meta = self._loader.get_plugin(plugin_id) + if meta is None: + return envelope.make_error_response( + ErrorCode.E_PLUGIN_NOT_FOUND.value, + f"插件 {plugin_id} 未加载", + ) + + component_name = invoke.component_name + handler_method = self._resolve_component_handler(meta, component_name) + if handler_method is None or not callable(handler_method): + return envelope.make_error_response( + ErrorCode.E_METHOD_NOT_ALLOWED.value, + f"插件 {plugin_id} 无组件: {component_name}", + ) + + try: + raw = ( + await handler_method(**invoke.args) + if inspect.iscoroutinefunction(handler_method) + else handler_method(**invoke.args) + ) + except Exception as exc: + logger.error(f"插件 {plugin_id} hook_handler {component_name} 执行异常: {exc}", exc_info=True) + return envelope.make_response(payload={"success": False, "continue_processing": True}) + + if raw is None: + result = {"success": True, "continue_processing": True} + elif isinstance(raw, dict): + result = { + "success": True, + "continue_processing": raw.get("continue_processing", True), + "modified_kwargs": raw.get("modified_kwargs"), + "custom_result": raw.get("custom_result"), + } + else: + result = {"success": True, "continue_processing": True, "custom_result": raw} + + return envelope.make_response(payload=result) + async def _handle_workflow_step(self, envelope: Envelope) -> Envelope: """处理 WorkflowStep 调用请求 @@ -506,9 +1088,8 @@ class PluginRunner: f"插件 {plugin_id} 未加载", ) - instance = meta.instance component_name = invoke.component_name - handler_method = getattr(instance, f"handle_{component_name}", None) or getattr(instance, component_name, None) + handler_method = self._resolve_component_handler(meta, component_name) if handler_method is None or not callable(handler_method): return envelope.make_error_response( @@ -557,36 +1138,92 @@ class PluginRunner: async def _handle_shutdown(self, envelope: Envelope) -> Envelope: """处理关停 — 调用所有插件的 on_unload 后退出""" logger.info("收到 shutdown 信号,开始调用 on_unload") - for plugin_id in self._loader.list_plugins(): + for plugin_id in list(self._loader.list_plugins()): meta = self._loader.get_plugin(plugin_id) - if meta and hasattr(meta.instance, "on_unload"): - try: - ret = meta.instance.on_unload() - if asyncio.iscoroutine(ret): - await ret - except Exception as e: - logger.error(f"插件 {plugin_id} on_unload 失败: {e}", exc_info=True) + if meta is not None: + await self._unload_plugin(meta, reason="runner_shutdown") self._shutting_down = True return envelope.make_response(payload={"acknowledged": True}) async def _handle_config_updated(self, envelope: Envelope) -> Envelope: - """处理配置更新事件""" + """处理配置更新事件。""" + try: + payload = ConfigUpdatedPayload.model_validate(envelope.payload) + except Exception as exc: + return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc)) + plugin_id = envelope.plugin_id if meta := self._loader.get_plugin(plugin_id): try: - config_data = envelope.payload.get("config_data", {}) - config_version = envelope.payload.get("config_version", "") - self._apply_plugin_config(meta, config_data=config_data) - if hasattr(meta.instance, "on_config_update"): - ret = meta.instance.on_config_update(config_data, config_version) - # 兼容同步和异步的 on_config_update 实现 - if asyncio.iscoroutine(ret): - await ret + config_scope = payload.config_scope.value + if config_scope == "self": + self._apply_plugin_config(meta, config_data=payload.config_data) + if not hasattr(meta.instance, "on_config_update"): + raise AttributeError("插件缺少 on_config_update() 实现") + + ret = meta.instance.on_config_update( + config_scope, + payload.config_data, + payload.config_version, + ) + if asyncio.iscoroutine(ret): + await ret except Exception as e: logger.error(f"插件 {plugin_id} 配置更新失败: {e}") return envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(e)) return envelope.make_response(payload={"acknowledged": True}) + async def _handle_reload_plugin(self, envelope: Envelope) -> Envelope: + """处理按插件 ID 的精确重载请求。 + + Args: + envelope: RPC 请求信封。 + + Returns: + Envelope: 结构化重载结果。 + """ + try: + payload = ReloadPluginPayload.model_validate(envelope.payload) + except Exception as exc: + return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc)) + + if self._reload_lock.locked(): + return envelope.make_error_response( + ErrorCode.E_RELOAD_IN_PROGRESS.value, + f"插件 {payload.plugin_id} 重载请求被拒绝:已有重载任务正在执行", + ) + + async with self._reload_lock: + result = await self._reload_plugin_by_id( + payload.plugin_id, + payload.reason, + external_available_plugins=dict(payload.external_available_plugins), + ) + return envelope.make_response(payload=result.model_dump()) + + async def _handle_reload_plugins(self, envelope: Envelope) -> Envelope: + """处理批量插件重载请求。""" + + try: + payload = ReloadPluginsPayload.model_validate(envelope.payload) + except Exception as exc: + return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc)) + + if self._reload_lock.locked(): + requested_plugin_ids = ", ".join(self._normalize_requested_plugin_ids(payload.plugin_ids)) or "" + return envelope.make_error_response( + ErrorCode.E_RELOAD_IN_PROGRESS.value, + f"插件 {requested_plugin_ids} 批量重载请求被拒绝:已有重载任务正在执行", + ) + + async with self._reload_lock: + result = await self._reload_plugins_by_ids( + list(payload.plugin_ids), + payload.reason, + external_available_plugins=dict(payload.external_available_plugins), + ) + return envelope.make_response(payload=result.model_dump()) + def request_capability(self) -> RPCClient: """获取 RPC 客户端(供 SDK 使用,发起能力调用)""" return self._rpc_client @@ -598,9 +1235,14 @@ class PluginRunner: def _isolate_sys_path(plugin_dirs: List[str]) -> None: """清理 sys.path,限制 Runner 子进程只能访问标准库、SDK 和插件目录。 - 防止插件代码 import 主程序模块读取运行时数据。 + 同时阻止插件代码直接导入主程序内部 ``src.*`` 模块,并清理可直接从 + ``sys.modules`` 摸到的高权限叶子模块,避免绕过 SDK / capability 边界。 """ - import importlib.abc + from importlib import util as importlib_util + from types import ModuleType + + import builtins + import importlib import sysconfig # 保留: 标准库路径 + site-packages(含 SDK 和依赖) @@ -631,43 +1273,145 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None: for d in plugin_dir_paths: allowed.add(d) - # 添加项目根目录(使得 src.plugin_runtime / src.common 可导入) - runtime_root = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) - allowed.add(runtime_root) - preserved_paths = [p for p in sys.path if p in allowed] - for extra_path in [*plugin_dir_paths, runtime_root]: + for extra_path in plugin_dir_paths: if extra_path not in preserved_paths: preserved_paths.append(extra_path) sys.path[:] = preserved_paths - # 安装 import 钩子,阻止插件导入主程序核心模块 - # 仅允许 src.plugin_runtime 和 src.common,拒绝其他 src.* 子包 - class _PluginImportBlocker(importlib.abc.MetaPathFinder): - """阻止 Runner 子进程导入主程序核心模块。 + # 仅为旧版插件兼容层保留极小的 src.* 可见面: + # - src.plugin_system.*: 通过 maibot_sdk.compat 导入钩子重定向 + # - src.common.logger: 仓库内仍有少量旧插件沿用该日志入口 + allowed_src_exact_modules = frozenset( + { + "src", + "src.common", + "src.common.logger", + "src.common.logger_color_and_mapping", + } + ) + allowed_src_prefixes = ("src.plugin_system",) + plugin_module_prefix = "_maibot_plugin_" - 只放行 src.plugin_runtime 和 src.common, - 拒绝 src.chat_module / src.services 等主程序内部包。 - """ + def _is_allowed_src_module(fullname: str) -> bool: + """判断给定 src.* 模块是否在 Runner 允许列表中。""" + if fullname in allowed_src_exact_modules: + return True + return any(fullname == prefix or fullname.startswith(f"{prefix}.") for prefix in allowed_src_prefixes) - _ALLOWED_SRC_PREFIXES = ("src.plugin_runtime", "src.common") + def _resolve_requester_name(import_globals: Any = None) -> str: + """解析当前导入请求的发起模块名。""" + if isinstance(import_globals, dict): + for key in ("__name__", "__package__"): + value = import_globals.get(key) + if isinstance(value, str) and value: + return value - def find_module(self, fullname, path=None): - return self if self._should_block(fullname) else None + frame = inspect.currentframe() + try: + current = frame.f_back if frame is not None else None + while current is not None: + module_name = current.f_globals.get("__name__", "") + if not isinstance(module_name, str) or not module_name: + current = current.f_back + continue + if module_name == __name__ or module_name.startswith("importlib"): + current = current.f_back + continue + return module_name + return "" + finally: + del frame - def load_module(self, fullname): - raise ImportError(f"Runner 子进程不允许导入主程序模块: {fullname}") + def _is_plugin_import_request(import_globals: Any = None) -> bool: + """判断当前导入是否由插件模块直接发起。""" + requester_name = _resolve_requester_name(import_globals) + return requester_name.startswith(plugin_module_prefix) - def _should_block(self, fullname: str) -> bool: - # 放行非 src.* 的导入、以及 "src" 本身 - if not fullname.startswith("src.") or fullname == "src": - return False - # 放行白名单前缀 - return not any( - fullname == prefix or fullname.startswith(f"{prefix}.") for prefix in self._ALLOWED_SRC_PREFIXES - ) + def _format_block_message(fullname: str) -> str: + """构造统一的拒绝导入错误信息。""" + return ( + f"Runner 子进程不允许导入主程序模块: {fullname}。" + "请改用 maibot_sdk 或 src.plugin_system 兼容层提供的接口。" + ) - sys.meta_path.insert(0, _PluginImportBlocker()) + def _iter_requested_src_modules(name: str, fromlist: Any) -> List[str]: + """展开本次导入请求涉及的 src.* 模块名。""" + requested_modules = [name] + if not name.startswith("src") or not fromlist: + return requested_modules + + for item in fromlist: + if not isinstance(item, str) or not item or item == "*": + continue + requested_modules.append(f"{name}.{item}") + return requested_modules + + def _assert_plugin_import_allowed(name: str, import_globals: Any = None, fromlist: Any = ()) -> None: + """在插件发起导入时校验目标 src.* 模块是否允许访问。""" + if not _is_plugin_import_request(import_globals): + return + + for requested_module in _iter_requested_src_modules(name, fromlist): + if not requested_module.startswith("src"): + continue + if _is_allowed_src_module(requested_module): + continue + raise ImportError(_format_block_message(requested_module)) + + def _detach_module_from_parent(fullname: str, module: ModuleType) -> None: + """从父模块上移除已清理模块的属性引用。""" + parent_name, _, child_name = fullname.rpartition(".") + if not parent_name or not child_name: + return + + parent_module = sys.modules.get(parent_name) + if parent_module is None: + return + if getattr(parent_module, child_name, None) is module: + with contextlib.suppress(AttributeError): + delattr(parent_module, child_name) + + # 仅清理已加载的叶子模块,保留包对象给 Runner 自己的延迟导入和相对导入使用。 + existing_src_modules = sorted( + ( + (module_name, module) + for module_name, module in list(sys.modules.items()) + if module_name == "src" or module_name.startswith("src.") + ), + key=lambda item: item[0].count("."), + reverse=True, + ) + for module_name, module in existing_src_modules: + if _is_allowed_src_module(module_name) or hasattr(module, "__path__"): + continue + _detach_module_from_parent(module_name, module) + sys.modules.pop(module_name, None) + + # ``import`` 语句与 ``importlib.import_module`` 走的是不同入口,因此两边都需要兜底。 + builtins_module = cast(Any, builtins) + original_import = getattr(builtins_module, "__maibot_runner_original_import__", builtins.__import__) + builtins_module.__maibot_runner_original_import__ = original_import + + def _guarded_import(name: str, globals: Any = None, locals: Any = None, fromlist: Any = (), level: int = 0) -> Any: + if level == 0: + _assert_plugin_import_allowed(name, import_globals=globals, fromlist=fromlist) + return original_import(name, globals, locals, fromlist, level) + + cast(Any, _guarded_import).__maibot_runner_plugin_import_guard__ = True + builtins.__import__ = _guarded_import + + importlib_module = cast(Any, importlib) + original_import_module = getattr(importlib_module, "__maibot_runner_original_import_module__", importlib.import_module) + importlib_module.__maibot_runner_original_import_module__ = original_import_module + + def _guarded_import_module(name: str, package: Optional[str] = None) -> Any: + resolved_name = importlib_util.resolve_name(name, package) if name.startswith(".") else name + _assert_plugin_import_allowed(resolved_name) + return original_import_module(name, package) + + cast(Any, _guarded_import_module).__maibot_runner_plugin_import_guard__ = True + importlib.import_module = _guarded_import_module # ─── 进程入口 ────────────────────────────────────────────── @@ -675,8 +1419,9 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None: async def _async_main() -> None: """异步主入口""" - host_address = os.environ.get(ENV_IPC_ADDRESS, "") - session_token = os.environ.get(ENV_SESSION_TOKEN, "") + host_address = os.environ.pop(ENV_IPC_ADDRESS, "") + external_plugin_ids_raw = os.environ.get(ENV_EXTERNAL_PLUGIN_IDS, "") + session_token = os.environ.pop(ENV_SESSION_TOKEN, "") plugin_dirs_str = os.environ.get(ENV_PLUGIN_DIRS, "") if not host_address or not session_token: @@ -684,14 +1429,31 @@ async def _async_main() -> None: sys.exit(1) plugin_dirs = [d for d in plugin_dirs_str.split(os.pathsep) if d] + try: + external_plugin_ids = json.loads(external_plugin_ids_raw) if external_plugin_ids_raw else {} + except json.JSONDecodeError: + logger.warning("解析外部依赖插件版本映射失败,已回退为空映射") + external_plugin_ids = {} + if not isinstance(external_plugin_ids, dict): + logger.warning("外部依赖插件版本映射格式非法,已回退为空映射") + external_plugin_ids = {} # sys.path 隔离: 只保留标准库、SDK 包、插件目录 _isolate_sys_path(plugin_dirs) - runner = PluginRunner(host_address, session_token, plugin_dirs) + runner = PluginRunner( + host_address, + session_token, + plugin_dirs, + external_available_plugins={ + str(plugin_id): str(plugin_version) + for plugin_id, plugin_version in external_plugin_ids.items() + }, + ) # 注册信号处理 def _mark_runner_shutting_down() -> None: + """标记 Runner 即将进入关停流程。""" runner._shutting_down = True _install_shutdown_signal_handlers(_mark_runner_shutting_down) diff --git a/src/plugin_runtime/transport/named_pipe.py b/src/plugin_runtime/transport/named_pipe.py index a759507d..7fd39bc9 100644 --- a/src/plugin_runtime/transport/named_pipe.py +++ b/src/plugin_runtime/transport/named_pipe.py @@ -1,6 +1,9 @@ """Windows Named Pipe 传输实现。 适用于 Windows 平台,使用 asyncio ProactorEventLoop 的 named pipe 支持。 + +注意:Named Pipe 是 Windows 特有的 IPC 机制, +在 Linux/macOS 平台上不可用。Unix-like 平台请使用 UDS 传输。 """ from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, cast @@ -18,10 +21,12 @@ _DEFAULT_PIPE_PREFIX = "maibot-plugin" class _NamedPipeServerHandle(Protocol): + """Named Pipe 服务端句柄的协议定义。""" def close(self) -> None: ... class _NamedPipeEventLoop(Protocol): + """ProactorEventLoop 的协议定义,提供 named pipe 相关方法。""" async def start_serving_pipe( self, protocol_factory: Callable[[], asyncio.BaseProtocol], @@ -40,6 +45,15 @@ class _NamedPipeEventLoop(Protocol): def _normalize_pipe_address(pipe_name: Optional[str] = None) -> str: + """规范化 Named Pipe 地址。 + + Args: + pipe_name: 管道名称。如果以 '\\\\.\\pipe\\' 开头则直接使用, + 否则会自动添加前缀。如果为 None 则生成随机名称。 + + Returns: + 规范化的管道地址(格式:\\\\.\\pipe\\name) + """ if pipe_name and pipe_name.startswith(_PIPE_PREFIX): return pipe_name @@ -55,12 +69,21 @@ def _normalize_pipe_address(pipe_name: Optional[str] = None) -> str: class NamedPipeConnection(Connection): - """基于 Windows Named Pipe 的连接。""" + """基于 Windows Named Pipe 的连接。 + + 封装了底层 StreamReader/StreamWriter,提供分帧读写能力。 + """ - pass + def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None: + super().__init__(reader, writer) class _NamedPipeServerProtocol(asyncio.StreamReaderProtocol): + """Named Pipe 服务端协议实现。 + + 处理客户端连接的生命周期,包括连接建立、数据处理和连接关闭。 + """ + def __init__(self, handler: ConnectionHandler, loop: asyncio.AbstractEventLoop) -> None: self._reader: asyncio.StreamReader = asyncio.StreamReader() super().__init__(self._reader) @@ -69,39 +92,58 @@ class _NamedPipeServerProtocol(asyncio.StreamReaderProtocol): self._handler_task: Optional[asyncio.Task[None]] = None def connection_made(self, transport: asyncio.BaseTransport) -> None: + """连接建立时的回调。""" super().connection_made(transport) writer = asyncio.StreamWriter(cast(asyncio.WriteTransport, transport), self, self._reader, self._loop) connection = NamedPipeConnection(self._reader, writer) - self._handler_task = self._loop.create_task(self._run_handler(connection)) + # 使用 asyncio.create_task 确保任务正确调度 + self._handler_task = asyncio.create_task(self._run_handler(connection)) self._handler_task.add_done_callback(self._on_handler_done) async def _run_handler(self, connection: NamedPipeConnection) -> None: + """运行连接处理器。""" try: await self._handler(connection) finally: await connection.close() def _on_handler_done(self, task: asyncio.Task[None]) -> None: + """连接处理器完成时的回调。""" if task.cancelled(): return if exc := task.exception(): - self._loop.call_exception_handler( - { - "message": "Named pipe 连接处理失败", - "exception": exc, - "protocol": self, - } - ) + try: + self._loop.call_exception_handler( + { + "message": "Named pipe 连接处理失败", + "exception": exc, + "protocol": self, + } + ) + except Exception: + # 如果 loop 已经关闭,忽略异常 + pass class NamedPipeTransportServer(TransportServer): - """Windows Named Pipe 传输服务端。""" + """Windows Named Pipe 传输服务端。 + + 使用 ProactorEventLoop 的 start_serving_pipe 方法监听客户端连接。 + """ def __init__(self, pipe_name: Optional[str] = None) -> None: self._address: str = _normalize_pipe_address(pipe_name) self._servers: List[_NamedPipeServerHandle] = [] async def start(self, handler: ConnectionHandler) -> None: + """启动 Named Pipe 服务端。 + + Args: + handler: 新连接到来时的回调函数 + + Raises: + RuntimeError: 当在非 Windows 平台或事件循环不支持时 + """ if sys.platform != "win32": raise RuntimeError("Named pipe 仅支持 Windows") @@ -116,32 +158,49 @@ class NamedPipeTransportServer(TransportServer): ) async def stop(self) -> None: + """停止 Named Pipe 服务端并清理资源。""" for server in self._servers: server.close() + # 等待所有服务器句柄完全关闭 + await asyncio.gather( + *[asyncio.sleep(0.1) for _ in self._servers], + return_exceptions=True + ) self._servers.clear() - await asyncio.sleep(0) def get_address(self) -> str: return self._address class NamedPipeTransportClient(TransportClient): - """Windows Named Pipe 传输客户端。""" + """Windows Named Pipe 传输客户端。 + + 用于主动连接到 Named Pipe 服务端。 + """ def __init__(self, address: str) -> None: self._address: str = _normalize_pipe_address(address) async def connect(self) -> Connection: + """建立到 Named Pipe 服务端的连接。 + + Returns: + NamedPipeConnection: 连接对象 + + Raises: + NotImplementedError: 当在非 Windows 平台或事件循环不支持时 + """ if sys.platform != "win32": - raise RuntimeError("Named pipe 仅支持 Windows") + raise NotImplementedError("Named pipe 仅支持 Windows") loop = asyncio.get_running_loop() if not hasattr(loop, "create_pipe_connection"): - raise RuntimeError("当前事件循环不支持 Windows named pipe") + raise NotImplementedError("当前事件循环不支持 Windows named pipe") pipe_loop = cast(_NamedPipeEventLoop, loop) reader = asyncio.StreamReader() protocol = asyncio.StreamReaderProtocol(reader) transport, _protocol = await pipe_loop.create_pipe_connection(lambda: protocol, self._address) - writer = asyncio.StreamWriter(cast(asyncio.WriteTransport, transport), protocol, reader, loop) + # 使用返回的 protocol 创建 StreamWriter + writer = asyncio.StreamWriter(cast(asyncio.WriteTransport, transport), _protocol, reader, loop) return NamedPipeConnection(reader, writer) \ No newline at end of file diff --git a/src/plugin_runtime/transport/uds.py b/src/plugin_runtime/transport/uds.py index 47bf033b..af71ea5d 100644 --- a/src/plugin_runtime/transport/uds.py +++ b/src/plugin_runtime/transport/uds.py @@ -1,6 +1,9 @@ """Unix Domain Socket 传输实现 适用于 Linux / macOS 平台。 + +注意:UDS (Unix Domain Socket) 是 Unix-like 系统特有的 IPC 机制, +在 Windows 平台上不可用。Windows 平台请使用 Named Pipe 传输。 """ from pathlib import Path @@ -8,20 +11,30 @@ from typing import Optional import asyncio import os +import sys import tempfile from .base import Connection, ConnectionHandler, TransportClient, TransportServer class UDSConnection(Connection): - """基于 UDS 的连接""" + """基于 UDS 的连接 + + 封装了底层 StreamReader/StreamWriter,提供分帧读写能力。 + """ - pass # 直接复用 Connection 基类的分帧读写 + def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None: + super().__init__(reader, writer) # Unix domain socket 路径的系统限制(sun_path 字段长度) -# Linux: 108 字节, macOS: 104 字节 -_UDS_PATH_MAX = 104 +# Linux: 108 字节,macOS: 104 字节,其他 Unix: 通常 104 字节 +if sys.platform == "linux": + _UDS_PATH_MAX = 108 +elif sys.platform == "darwin": # macOS + _UDS_PATH_MAX = 104 +else: + _UDS_PATH_MAX = 104 # 保守默认值 class UDSTransportServer(TransportServer): @@ -44,6 +57,18 @@ class UDSTransportServer(TransportServer): self._server: Optional[asyncio.AbstractServer] = None async def start(self, handler: ConnectionHandler) -> None: + """启动 UDS 服务端 + + Args: + handler: 新连接到来时的回调函数 + + Raises: + RuntimeError: 当在非 Unix 平台(如 Windows)上调用时 + """ + # 平台检查:UDS 仅在 Unix-like 系统上可用 + if sys.platform == "win32": + raise RuntimeError("UDS 不支持 Windows 平台,请使用 Named Pipe") + # 清理残留 socket 文件 if self._socket_path.exists(): self._socket_path.unlink() @@ -58,10 +83,16 @@ class UDSTransportServer(TransportServer): finally: await conn.close() - self._server = await asyncio.start_unix_server(_on_connect, path=str(self._socket_path)) + try: + self._server = await asyncio.start_unix_server(_on_connect, path=str(self._socket_path)) - # 设置文件权限为仅当前用户可访问 - self._socket_path.chmod(0o600) + # 设置文件权限为仅当前用户可访问 + self._socket_path.chmod(0o600) + except Exception: + # 启动失败时清理可能创建的目录和 socket 文件 + if self._socket_path.exists(): + self._socket_path.unlink() + raise async def stop(self) -> None: if self._server: @@ -77,11 +108,26 @@ class UDSTransportServer(TransportServer): class UDSTransportClient(TransportClient): - """UDS 传输客户端""" + """UDS 传输客户端 + + 用于主动连接到 UDS 服务端。 + """ def __init__(self, socket_path: Path) -> None: self._socket_path: Path = socket_path async def connect(self) -> Connection: + """建立到 UDS 服务端的连接 + + Returns: + UDSConnection: 连接对象 + + Raises: + RuntimeError: 当在非 Unix 平台(如 Windows)上调用时 + """ + # 平台检查:UDS 仅在 Unix-like 系统上可用 + if sys.platform == "win32": + raise RuntimeError("UDS 不支持 Windows 平台,请使用 Named Pipe") + reader, writer = await asyncio.open_unix_connection(str(self._socket_path)) return UDSConnection(reader, writer) diff --git a/src/plugins/built_in/emoji_plugin/_manifest.json b/src/plugins/built_in/emoji_plugin/_manifest.json index d4d262e7..5b53abad 100644 --- a/src/plugins/built_in/emoji_plugin/_manifest.json +++ b/src/plugins/built_in/emoji_plugin/_manifest.json @@ -1,32 +1,28 @@ { - "manifest_version": 1, - "name": "Emoji插件 (Emoji Actions)", + "manifest_version": 2, "version": "2.0.0", - "description": "可以发送和管理Emoji", + "name": "Emoji插件 (Emoji Actions)", + "description": "可以发送和管理 Emoji", "author": { "name": "SengokuCola", "url": "https://github.com/MaiM-with-u" }, "license": "GPL-v3.0-or-later", + "urls": { + "repository": "https://github.com/MaiM-with-u/maibot", + "homepage": "https://github.com/MaiM-with-u/maibot", + "documentation": "https://github.com/MaiM-with-u/maibot", + "issues": "https://github.com/MaiM-with-u/maibot/issues" + }, "host_application": { - "min_version": "1.0.0" + "min_version": "1.0.0", + "max_version": "1.0.0" }, - "homepage_url": "https://github.com/MaiM-with-u/maibot", - "repository_url": "https://github.com/MaiM-with-u/maibot", - "keywords": ["emoji", "action", "built-in"], - "categories": ["Emoji"], - "default_locale": "zh-CN", - "plugin_info": { - "is_built_in": true, - "plugin_type": "action_provider", - "components": [ - { - "type": "action", - "name": "emoji", - "description": "发送表情包辅助表达情绪" - } - ] + "sdk": { + "min_version": "2.0.0", + "max_version": "2.99.99" }, + "dependencies": [], "capabilities": [ "emoji.get_random", "message.get_recent", @@ -34,5 +30,12 @@ "llm.generate", "send.emoji", "config.get" - ] + ], + "i18n": { + "default_locale": "zh-CN", + "supported_locales": [ + "zh-CN" + ] + }, + "id": "builtin.emoji-plugin" } diff --git a/src/plugins/built_in/emoji_plugin/plugin.py b/src/plugins/built_in/emoji_plugin/plugin.py index b946931b..cc6b87c5 100644 --- a/src/plugins/built_in/emoji_plugin/plugin.py +++ b/src/plugins/built_in/emoji_plugin/plugin.py @@ -3,11 +3,11 @@ 根据聊天上下文的情感,使用 LLM 选择并发送合适的表情包。 """ -import random - -from maibot_sdk import MaiBotPlugin, Action +from maibot_sdk import Action, MaiBotPlugin from maibot_sdk.types import ActivationType +import random + class EmojiPlugin(MaiBotPlugin): """表情包插件""" @@ -95,10 +95,35 @@ class EmojiPlugin(MaiBotPlugin): return True, f"成功发送表情包:[表情包:{chosen_emotion}]" return False, "发送表情包失败" - async def on_load(self): + async def on_load(self) -> None: + """处理插件加载。""" + # 从插件配置读取 emoji_chance 来覆盖默认概率 await self.ctx.config.get("emoji.emoji_chance") + async def on_unload(self) -> None: + """处理插件卸载。""" + + async def on_config_update(self, scope: str, config_data: dict[str, object], version: str) -> None: + """处理配置热重载事件。 + + Args: + scope: 配置变更范围。 + config_data: 最新配置数据。 + version: 配置版本号。 + """ + + del config_data + del version + if scope == "self": + await self.ctx.config.get("emoji.emoji_chance") + + +def create_plugin() -> EmojiPlugin: + """创建 Emoji 插件实例。 + + Returns: + EmojiPlugin: 新的 Emoji 插件实例。 + """ -def create_plugin(): return EmojiPlugin() diff --git a/src/plugins/built_in/plugin_management/_manifest.json b/src/plugins/built_in/plugin_management/_manifest.json index a5b52835..a2bfa9ce 100644 --- a/src/plugins/built_in/plugin_management/_manifest.json +++ b/src/plugins/built_in/plugin_management/_manifest.json @@ -1,51 +1,46 @@ { - "manifest_version": 1, - "name": "插件和组件管理 (Plugin and Component Management)", + "manifest_version": 2, "version": "2.0.0", - "description": "通过系统API管理插件和组件的生命周期,包括加载、卸载、启用和禁用等操作。", + "name": "插件和组件管理 (Plugin and Component Management)", + "description": "通过系统 API 管理插件和组件的生命周期,包括加载、卸载、启用和禁用等操作。", "author": { "name": "MaiBot团队", "url": "https://github.com/MaiM-with-u" }, "license": "GPL-v3.0-or-later", - "host_application": { - "min_version": "1.0.0" + "urls": { + "repository": "https://github.com/MaiM-with-u/maibot", + "homepage": "https://github.com/MaiM-with-u/maibot", + "documentation": "https://github.com/MaiM-with-u/maibot", + "issues": "https://github.com/MaiM-with-u/maibot/issues" }, - "homepage_url": "https://github.com/MaiM-with-u/maibot", - "repository_url": "https://github.com/MaiM-with-u/maibot", - "keywords": [ - "plugins", - "components", - "management", - "built-in" + "host_application": { + "min_version": "1.0.0", + "max_version": "1.0.0" + }, + "sdk": { + "min_version": "2.0.0", + "max_version": "2.99.99" + }, + "dependencies": [], + "capabilities": [ + "component.get_all_plugins", + "component.list_loaded_plugins", + "component.list_registered_plugins", + "component.enable", + "component.disable", + "component.load_plugin", + "component.unload_plugin", + "component.reload_plugin", + "send.text", + "config.get" ], - "categories": [ - "Core System", - "Plugin Management" - ], - "default_locale": "zh-CN", - "locales_path": "_locales", - "plugin_info": { - "is_built_in": true, - "plugin_type": "plugin_management", - "capabilities": [ - "component.get_all_plugins", - "component.list_loaded_plugins", - "component.list_registered_plugins", - "component.enable", - "component.disable", - "component.load_plugin", - "component.unload_plugin", - "component.reload_plugin", - "send.text", - "config.get" - ], - "components": [ - { - "type": "command", - "name": "management", - "description": "管理插件和组件的生命周期,包括加载、卸载、启用和禁用等操作。" - } + "i18n": { + "default_locale": "zh-CN", + "locales_path": "_locales", + "supported_locales": [ + "zh-CN" ] - } -} \ No newline at end of file + }, + "id": "builtin.plugin-management" +} diff --git a/src/plugins/built_in/plugin_management/plugin.py b/src/plugins/built_in/plugin_management/plugin.py index fe0888c6..aa2da795 100644 --- a/src/plugins/built_in/plugin_management/plugin.py +++ b/src/plugins/built_in/plugin_management/plugin.py @@ -3,7 +3,7 @@ 通过 /pm 命令管理插件和组件的生命周期。 """ -from maibot_sdk import MaiBotPlugin, Command +from maibot_sdk import Command, MaiBotPlugin _VALID_COMPONENT_TYPES = ("action", "command", "event_handler") @@ -44,6 +44,12 @@ HELP_COMPONENT = ( class PluginManagementPlugin(MaiBotPlugin): """插件和组件管理插件""" + async def on_load(self) -> None: + """处理插件加载。""" + + async def on_unload(self) -> None: + """处理插件卸载。""" + @Command( "management", description="管理插件和组件的生命周期", @@ -268,6 +274,25 @@ class PluginManagementPlugin(MaiBotPlugin): return components return [] + async def on_config_update(self, scope: str, config_data: dict[str, object], version: str) -> None: + """处理配置热重载事件。 + + Args: + scope: 配置变更范围。 + config_data: 最新配置数据。 + version: 配置版本号。 + """ + + del scope + del config_data + del version + + +def create_plugin() -> PluginManagementPlugin: + """创建插件管理插件实例。 + + Returns: + PluginManagementPlugin: 新的插件管理插件实例。 + """ -def create_plugin(): return PluginManagementPlugin() diff --git a/src/services/send_service.py b/src/services/send_service.py index 7af55716..134fb15e 100644 --- a/src/services/send_service.py +++ b/src/services/send_service.py @@ -1,155 +1,640 @@ """ -发送服务模块 +发送服务模块。 -提供发送各种类型消息的核心功能。 +统一封装内部模块的出站消息发送逻辑: + +1. 内部模块统一调用本模块。 +2. send service 只负责构造和预处理消息。 +3. 具体走插件链还是 legacy 旧链,由 Platform IO 内部统一决策。 """ -from typing import Dict, List, Optional, TYPE_CHECKING +from copy import deepcopy +from typing import Any, Dict, List, Optional +import asyncio +import base64 +import hashlib import time import traceback +from datetime import datetime -from maim_message import BaseMessageInfo, GroupInfo as MaimGroupInfo, MessageBase, Seg, UserInfo as MaimUserInfo - +from src.chat.message_receive.chat_manager import BotChatSession from src.chat.message_receive.chat_manager import chat_manager as _chat_manager from src.chat.message_receive.message import SessionMessage -from src.chat.message_receive.uni_message_sender import UniversalMessageSender -from src.chat.utils.utils import get_bot_account -from src.common.data_models.mai_message_data_model import MaiMessage -from src.common.data_models.message_component_data_model import DictComponent, MessageSequence +from src.chat.utils.utils import calculate_typing_time, get_bot_account +from src.common.data_models.mai_message_data_model import GroupInfo, MaiMessage, MessageInfo, UserInfo +from src.common.data_models.message_component_data_model import ( + AtComponent, + DictComponent, + EmojiComponent, + ForwardNodeComponent, + ImageComponent, + MessageSequence, + ReplyComponent, + StandardMessageComponents, + TextComponent, + VoiceComponent, +) from src.common.logger import get_logger +from src.common.utils.utils_message import MessageUtils from src.config.config import global_config - -if TYPE_CHECKING: - from src.chat.message_receive.message import SessionMessage +from src.platform_io import DeliveryBatch, get_platform_io_manager +from src.platform_io.route_key_factory import RouteKeyFactory logger = get_logger("send_service") -# ============================================================================= -# 内部实现函数 -# ============================================================================= +def _inherit_platform_io_route_metadata(target_stream: BotChatSession) -> Dict[str, object]: + """从目标会话继承 Platform IO 路由元数据。 + + Args: + target_stream: 当前消息要发送到的会话对象。 + + Returns: + Dict[str, object]: 可安全透传到出站消息 ``additional_config`` 中的 + 路由辅助字段。 + """ + inherited_metadata: Dict[str, object] = {} + + context_message = target_stream.context.message if target_stream.context else None + if context_message is not None: + additional_config = context_message.message_info.additional_config + if isinstance(additional_config, dict): + for key in (*RouteKeyFactory.ACCOUNT_ID_KEYS, *RouteKeyFactory.SCOPE_KEYS): + value = additional_config.get(key) + if value is None: + continue + normalized_value = str(value).strip() + if normalized_value: + inherited_metadata[key] = value + + # 当目标会话没有可继承的上下文消息时,至少补齐当前平台账号, + # 让按 ``platform + account_id`` 绑定的路由仍有机会命中。 + if not RouteKeyFactory.extract_components(inherited_metadata)[0]: + bot_account = get_bot_account(target_stream.platform) + if bot_account: + inherited_metadata["platform_io_account_id"] = bot_account + + if target_stream.group_id and (normalized_group_id := str(target_stream.group_id).strip()): + inherited_metadata["platform_io_target_group_id"] = normalized_group_id + + if target_stream.user_id and (normalized_user_id := str(target_stream.user_id).strip()): + inherited_metadata["platform_io_target_user_id"] = normalized_user_id + + return inherited_metadata + + +def _build_binary_component_from_base64(component_type: str, raw_data: str) -> StandardMessageComponents: + """根据 Base64 数据构造二进制消息组件。 + + Args: + component_type: 组件类型名称。 + raw_data: Base64 编码后的二进制数据。 + + Returns: + StandardMessageComponents: 转换后的内部消息组件。 + + Raises: + ValueError: 当组件类型不受支持时抛出。 + """ + binary_data = base64.b64decode(raw_data) + binary_hash = hashlib.sha256(binary_data).hexdigest() + + if component_type == "image": + return ImageComponent(binary_hash=binary_hash, binary_data=binary_data) + if component_type == "emoji": + return EmojiComponent(binary_hash=binary_hash, binary_data=binary_data) + if component_type == "voice": + return VoiceComponent(binary_hash=binary_hash, binary_data=binary_data) + raise ValueError(f"不支持的二进制组件类型: {component_type}") + + +def _build_message_sequence_from_custom_message( + message_type: str, + content: str | Dict[str, Any], +) -> MessageSequence: + """根据自定义消息类型构造内部消息组件序列。 + + Args: + message_type: 自定义消息类型。 + content: 自定义消息内容。 + + Returns: + MessageSequence: 转换后的消息组件序列。 + """ + normalized_type = message_type.strip().lower() + + if normalized_type == "text": + return MessageSequence(components=[TextComponent(text=str(content))]) + + if normalized_type in {"image", "emoji", "voice"}: + return MessageSequence( + components=[_build_binary_component_from_base64(normalized_type, str(content))] + ) + + if normalized_type == "at": + return MessageSequence(components=[AtComponent(target_user_id=str(content))]) + + if normalized_type == "reply": + return MessageSequence(components=[ReplyComponent(target_message_id=str(content))]) + + if normalized_type == "dict" and isinstance(content, dict): + return MessageSequence(components=[DictComponent(data=deepcopy(content))]) + + return MessageSequence( + components=[ + DictComponent( + data={ + "type": normalized_type, + "data": deepcopy(content), + } + ) + ] + ) + + +def _clone_message_sequence(message_sequence: MessageSequence) -> MessageSequence: + """复制消息组件序列,避免原对象被发送流程修改。 + + Args: + message_sequence: 原始消息组件序列。 + + Returns: + MessageSequence: 深拷贝后的消息组件序列。 + """ + return deepcopy(message_sequence) + + +def _detect_outbound_message_flags(message_sequence: MessageSequence) -> Dict[str, bool]: + """根据消息组件序列推断出站消息标记。 + + Args: + message_sequence: 待发送的消息组件序列。 + + Returns: + Dict[str, bool]: 包含 ``is_emoji``、``is_picture``、``is_command`` 的标记字典。 + """ + if len(message_sequence.components) != 1: + return { + "is_emoji": False, + "is_picture": False, + "is_command": False, + } + + component = message_sequence.components[0] + is_command = False + if isinstance(component, DictComponent) and isinstance(component.data, dict): + is_command = str(component.data.get("type") or "").strip().lower() == "command" + + return { + "is_emoji": isinstance(component, EmojiComponent), + "is_picture": isinstance(component, ImageComponent), + "is_command": is_command, + } + + +def _describe_message_sequence(message_sequence: MessageSequence) -> str: + """生成消息组件序列的简短描述文本。 + + Args: + message_sequence: 待描述的消息组件序列。 + + Returns: + str: 适用于日志的简短类型描述。 + """ + if len(message_sequence.components) != 1: + return "message_sequence" + + component = message_sequence.components[0] + if isinstance(component, DictComponent) and isinstance(component.data, dict): + custom_type = str(component.data.get("type") or "").strip() + return custom_type or "dict" + + if isinstance(component, TextComponent): + return component.format_name + + if isinstance(component, ImageComponent): + return component.format_name + + if isinstance(component, EmojiComponent): + return component.format_name + + if isinstance(component, VoiceComponent): + return component.format_name + + if isinstance(component, AtComponent): + return component.format_name + + if isinstance(component, ReplyComponent): + return component.format_name + + if isinstance(component, ForwardNodeComponent): + return component.format_name + + return "unknown" + + +def _build_processed_plain_text(message: SessionMessage) -> str: + """为出站消息构造轻量纯文本摘要。 + + Args: + message: 待发送的内部消息对象。 + + Returns: + str: 适用于日志与打字时长估算的纯文本摘要。 + """ + processed_parts: List[str] = [] + for component in message.raw_message.components: + if isinstance(component, TextComponent): + processed_parts.append(component.text) + continue + + if isinstance(component, ImageComponent): + processed_parts.append(component.content or "[图片]") + continue + + if isinstance(component, EmojiComponent): + processed_parts.append(component.content or "[表情]") + continue + + if isinstance(component, VoiceComponent): + processed_parts.append(component.content or "[语音]") + continue + + if isinstance(component, AtComponent): + at_target = component.target_user_cardname or component.target_user_nickname or component.target_user_id + processed_parts.append(f"@{at_target}") + continue + + if isinstance(component, ReplyComponent): + processed_parts.append(component.target_message_content or "[回复消息]") + continue + + if isinstance(component, DictComponent): + raw_type = component.data.get("type") if isinstance(component.data, dict) else None + if isinstance(raw_type, str) and raw_type.strip(): + processed_parts.append(f"[{raw_type.strip()}消息]") + else: + processed_parts.append("[自定义消息]") + continue + + return " ".join(part for part in processed_parts if part) + + +def _build_outbound_session_message( + message_sequence: MessageSequence, + stream_id: str, + display_message: str = "", + reply_message: Optional[MaiMessage] = None, + selected_expressions: Optional[List[int]] = None, +) -> Optional[SessionMessage]: + """根据目标会话构建待发送的内部消息对象。 + + Args: + message_sequence: 待发送的消息组件序列。 + stream_id: 目标会话 ID。 + display_message: 用于界面展示的文本内容。 + reply_message: 被回复的锚点消息。 + selected_expressions: 可选的表情候选索引列表。 + + Returns: + Optional[SessionMessage]: 构建成功时返回内部消息对象;若目标会话或 + 机器人账号不存在,则返回 ``None``。 + """ + target_stream = _chat_manager.get_session_by_session_id(stream_id) + if target_stream is None: + logger.error(f"[SendService] 未找到聊天流: {stream_id}") + return None + + bot_user_id = get_bot_account(target_stream.platform) + if not bot_user_id: + logger.error(f"[SendService] 平台 {target_stream.platform} 未配置机器人账号,无法发送消息") + return None + + current_time = time.time() + message_id = f"send_api_{int(current_time * 1000)}" + anchor_message = reply_message.deepcopy() if reply_message is not None else None + + group_info: Optional[GroupInfo] = None + if target_stream.group_id: + group_name = "" + if ( + target_stream.context + and target_stream.context.message + and target_stream.context.message.message_info.group_info + ): + group_name = target_stream.context.message.message_info.group_info.group_name + group_info = GroupInfo( + group_id=target_stream.group_id, + group_name=group_name, + ) + + additional_config: Dict[str, object] = _inherit_platform_io_route_metadata(target_stream) + if selected_expressions is not None: + additional_config["selected_expressions"] = selected_expressions + + outbound_message = SessionMessage( + message_id=message_id, + timestamp=datetime.fromtimestamp(current_time), + platform=target_stream.platform, + ) + outbound_message.message_info = MessageInfo( + user_info=UserInfo( + user_id=bot_user_id, + user_nickname=global_config.bot.nickname, + ), + group_info=group_info, + additional_config=additional_config, + ) + outbound_message.raw_message = _clone_message_sequence(message_sequence) + outbound_message.session_id = target_stream.session_id + outbound_message.display_message = display_message + outbound_message.reply_to = anchor_message.message_id if anchor_message is not None else None + message_flags = _detect_outbound_message_flags(outbound_message.raw_message) + outbound_message.is_emoji = message_flags["is_emoji"] + outbound_message.is_picture = message_flags["is_picture"] + outbound_message.is_command = message_flags["is_command"] + outbound_message.initialized = True + return outbound_message + + +def _ensure_reply_component(message: SessionMessage, reply_message_id: str) -> None: + """为消息补充回复组件。 + + Args: + message: 待发送的内部消息对象。 + reply_message_id: 被引用消息的 ID。 + """ + if message.raw_message.components: + first_component = message.raw_message.components[0] + if isinstance(first_component, ReplyComponent) and first_component.target_message_id == reply_message_id: + return + + message.raw_message.components.insert(0, ReplyComponent(target_message_id=reply_message_id)) + + +async def _prepare_message_for_platform_io( + message: SessionMessage, + *, + typing: bool, + set_reply: bool, + reply_message_id: Optional[str], +) -> None: + """为 Platform IO 发送链预处理消息。 + + Args: + message: 待发送的内部消息对象。 + typing: 是否模拟打字等待。 + set_reply: 是否构建引用回复组件。 + reply_message_id: 被引用消息的 ID。 + + Raises: + ValueError: 当要求设置引用回复但缺少 ``reply_message_id`` 时抛出。 + """ + if set_reply: + if not reply_message_id: + raise ValueError("set_reply=True 时必须提供 reply_message_id") + _ensure_reply_component(message, reply_message_id) + + message.processed_plain_text = _build_processed_plain_text(message) + if typing: + typing_time = calculate_typing_time( + input_string=message.processed_plain_text or "", + is_emoji=message.is_emoji, + ) + await asyncio.sleep(typing_time) + + +def _store_sent_message(message: SessionMessage) -> None: + """将已成功发送的消息写入数据库。 + + Args: + message: 已成功发送的内部消息对象。 + """ + MessageUtils.store_message_to_db(message) + + +def _log_platform_io_failures(delivery_batch: DeliveryBatch) -> None: + """输出 Platform IO 批量发送失败详情。 + + Args: + delivery_batch: Platform IO 返回的批量回执。 + """ + failed_details = "; ".join( + f"driver={receipt.driver_id} status={receipt.status} error={receipt.error}" + for receipt in delivery_batch.failed_receipts + ) or "未命中任何发送路由" + logger.warning( + "[SendService] Platform IO 发送失败: platform=%s %s", + delivery_batch.route_key.platform, + failed_details, + ) + + +async def _send_via_platform_io( + message: SessionMessage, + *, + typing: bool, + set_reply: bool, + reply_message_id: Optional[str], + storage_message: bool, + show_log: bool, +) -> bool: + """通过 Platform IO 发送消息。 + + Args: + message: 待发送的内部消息对象。 + typing: 是否模拟打字等待。 + set_reply: 是否设置引用回复。 + reply_message_id: 被引用消息的 ID。 + storage_message: 发送成功后是否写入数据库。 + show_log: 是否输出发送成功日志。 + + Returns: + bool: 发送成功时返回 ``True``。 + """ + platform_io_manager = get_platform_io_manager() + try: + await platform_io_manager.ensure_send_pipeline_ready() + except Exception as exc: + logger.error(f"[SendService] 准备 Platform IO 发送管线失败: {exc}") + logger.debug(traceback.format_exc()) + return False + + try: + route_key = platform_io_manager.build_route_key_from_message(message) + except Exception as exc: + logger.warning(f"[SendService] 根据消息构造 Platform IO 路由键失败: {exc}") + return False + + try: + await _prepare_message_for_platform_io( + message, + typing=typing, + set_reply=set_reply, + reply_message_id=reply_message_id, + ) + delivery_batch = await platform_io_manager.send_message( + message, + route_key, + metadata={"show_log": False}, + ) + except Exception as exc: + logger.error(f"[SendService] Platform IO 发送异常: {exc}") + logger.debug(traceback.format_exc()) + return False + + if delivery_batch.has_success: + if storage_message: + _store_sent_message(message) + if show_log: + successful_driver_ids = [ + receipt.driver_id or "unknown" + for receipt in delivery_batch.sent_receipts + ] + logger.info( + "[SendService] 已通过 Platform IO 将消息发往平台 '%s' (drivers: %s)", + route_key.platform, + ", ".join(successful_driver_ids), + ) + return True + + _log_platform_io_failures(delivery_batch) + return False + + +async def send_session_message( + message: SessionMessage, + *, + typing: bool = False, + set_reply: bool = False, + reply_message_id: Optional[str] = None, + storage_message: bool = True, + show_log: bool = True, +) -> bool: + """统一发送一条内部消息。 + + 该方法是内部模块的统一发送入口: + + 1. 构造并维护内部消息对象。 + 2. 由 Platform IO 统一决定走插件链还是 legacy 旧链。 + 3. send service 不再自行判断底层发送路径。 + + Args: + message: 待发送的内部消息对象。 + typing: 是否模拟打字等待。 + set_reply: 是否设置引用回复。 + reply_message_id: 被引用消息的 ID。 + storage_message: 发送成功后是否写入数据库。 + show_log: 是否输出发送日志。 + + Returns: + bool: 发送成功时返回 ``True``,否则返回 ``False``。 + """ + if not message.message_id: + logger.error("[SendService] 消息缺少 message_id,无法发送") + raise ValueError("消息缺少 message_id,无法发送") + + return await _send_via_platform_io( + message, + typing=typing, + set_reply=set_reply, + reply_message_id=reply_message_id, + storage_message=storage_message, + show_log=show_log, + ) async def _send_to_target( - message_segment: Seg, + message_sequence: MessageSequence, stream_id: str, display_message: str = "", typing: bool = False, set_reply: bool = False, - reply_message: Optional["SessionMessage"] = None, + reply_message: Optional[MaiMessage] = None, storage_message: bool = True, show_log: bool = True, selected_expressions: Optional[List[int]] = None, ) -> bool: - """向指定目标发送消息的内部实现""" + """向指定目标构建并发送消息。 + + Args: + message_sequence: 待发送的消息组件序列。 + stream_id: 目标会话 ID。 + display_message: 用于界面展示的文本内容。 + typing: 是否显示输入中状态。 + set_reply: 是否在发送时附带引用回复。 + reply_message: 被回复的消息对象。 + storage_message: 是否将发送结果写入消息存储。 + show_log: 是否输出发送日志。 + selected_expressions: 可选的表情候选索引列表。 + + Returns: + bool: 发送成功返回 ``True``,否则返回 ``False``。 + """ try: - if set_reply and not reply_message: + if set_reply and reply_message is None: logger.warning("[SendService] 使用引用回复,但未提供回复消息") return False if show_log: - logger.debug(f"[SendService] 发送{message_segment.type}消息到 {stream_id}") + logger.debug(f"[SendService] 发送{_describe_message_sequence(message_sequence)}消息到 {stream_id}") - target_stream = _chat_manager.get_session_by_session_id(stream_id) - if not target_stream: - logger.error(f"[SendService] 未找到聊天流: {stream_id}") - return False - - message_sender = UniversalMessageSender() - - current_time = time.time() - message_id = f"send_api_{int(current_time * 1000)}" - - anchor_message: Optional[MaiMessage] = None - if reply_message: - anchor_message = reply_message.deepcopy() - if anchor_message: - logger.debug( - f"[SendService] 找到匹配的回复消息,发送者: {anchor_message.message_info.user_info.user_id}" - ) - - group_info = None - if target_stream.group_id: - group_name = "" - if target_stream.context and target_stream.context.message and target_stream.context.message.message_info.group_info: - group_name = target_stream.context.message.message_info.group_info.group_name - group_info = MaimGroupInfo( - group_id=target_stream.group_id, - group_name=group_name, - platform=target_stream.platform, - ) - - additional_config: dict[str, object] = {} - if selected_expressions is not None: - additional_config["selected_expressions"] = selected_expressions - bot_user_id = get_bot_account(target_stream.platform) - if not bot_user_id: - logger.error(f"[SendService] 平台 {target_stream.platform} 未配置机器人账号,无法发送消息") - return False - - maim_message = MessageBase( - message_info=BaseMessageInfo( - platform=target_stream.platform, - message_id=message_id, - time=current_time, - user_info=MaimUserInfo( - user_id=bot_user_id, - user_nickname=global_config.bot.nickname, - platform=target_stream.platform, - ), - group_info=group_info, - additional_config=additional_config, - ), - message_segment=message_segment, + outbound_message = _build_outbound_session_message( + message_sequence=message_sequence, + stream_id=stream_id, + display_message=display_message, + reply_message=reply_message, + selected_expressions=selected_expressions, ) - bot_message = SessionMessage.from_maim_message(maim_message) - bot_message.session_id = target_stream.session_id - bot_message.display_message = display_message - bot_message.reply_to = anchor_message.message_id if anchor_message else None - bot_message.is_emoji = message_segment.type == "emoji" - bot_message.is_picture = message_segment.type == "image" - bot_message.is_command = message_segment.type == "command" + if outbound_message is None: + return False - sent_msg = await message_sender.send_message( - bot_message, + sent = await send_session_message( + outbound_message, typing=typing, set_reply=set_reply, - reply_message_id=anchor_message.message_id if anchor_message else None, + reply_message_id=reply_message.message_id if reply_message is not None else None, storage_message=storage_message, show_log=show_log, ) - - if sent_msg: + if sent: logger.debug(f"[SendService] 成功发送消息到 {stream_id}") return True - else: - logger.error("[SendService] 发送消息失败") - return False - except Exception as e: - logger.error(f"[SendService] 发送消息时出错: {e}") + logger.error("[SendService] 发送消息失败") + return False + except Exception as exc: + logger.error(f"[SendService] 发送消息时出错: {exc}") traceback.print_exc() return False -# ============================================================================= -# 公共函数 - 预定义类型的发送函数 -# ============================================================================= - - async def text_to_stream( text: str, stream_id: str, typing: bool = False, set_reply: bool = False, - reply_message: Optional["SessionMessage"] = None, + reply_message: Optional[MaiMessage] = None, storage_message: bool = True, selected_expressions: Optional[List[int]] = None, ) -> bool: - """向指定流发送文本消息""" + """向指定流发送文本消息。 + + Args: + text: 要发送的文本内容。 + stream_id: 目标会话 ID。 + typing: 是否显示输入中状态。 + set_reply: 是否附带引用回复。 + reply_message: 被回复的消息对象。 + storage_message: 是否在发送成功后写入数据库。 + selected_expressions: 可选的表情候选索引列表。 + + Returns: + bool: 发送成功时返回 ``True``。 + """ return await _send_to_target( - message_segment=Seg(type="text", data=text), + message_sequence=MessageSequence(components=[TextComponent(text=text)]), stream_id=stream_id, display_message="", typing=typing, @@ -165,11 +650,22 @@ async def emoji_to_stream( stream_id: str, storage_message: bool = True, set_reply: bool = False, - reply_message: Optional["SessionMessage"] = None, + reply_message: Optional[MaiMessage] = None, ) -> bool: - """向指定流发送表情包""" + """向指定流发送表情消息。 + + Args: + emoji_base64: 表情图片的 Base64 内容。 + stream_id: 目标会话 ID。 + storage_message: 是否在发送成功后写入数据库。 + set_reply: 是否附带引用回复。 + reply_message: 被回复的消息对象。 + + Returns: + bool: 发送成功时返回 ``True``。 + """ return await _send_to_target( - message_segment=Seg(type="emoji", data=emoji_base64), + message_sequence=_build_message_sequence_from_custom_message("emoji", emoji_base64), stream_id=stream_id, display_message="", typing=False, @@ -184,11 +680,22 @@ async def image_to_stream( stream_id: str, storage_message: bool = True, set_reply: bool = False, - reply_message: Optional["SessionMessage"] = None, + reply_message: Optional[MaiMessage] = None, ) -> bool: - """向指定流发送图片""" + """向指定流发送图片消息。 + + Args: + image_base64: 图片的 Base64 内容。 + stream_id: 目标会话 ID。 + storage_message: 是否在发送成功后写入数据库。 + set_reply: 是否附带引用回复。 + reply_message: 被回复的消息对象。 + + Returns: + bool: 发送成功时返回 ``True``。 + """ return await _send_to_target( - message_segment=Seg(type="image", data=image_base64), + message_sequence=_build_message_sequence_from_custom_message("image", image_base64), stream_id=stream_id, display_message="", typing=False, @@ -200,18 +707,33 @@ async def image_to_stream( async def custom_to_stream( message_type: str, - content: str | Dict, + content: str | Dict[str, Any], stream_id: str, display_message: str = "", typing: bool = False, - reply_message: Optional["SessionMessage"] = None, + reply_message: Optional[MaiMessage] = None, set_reply: bool = False, storage_message: bool = True, show_log: bool = True, ) -> bool: - """向指定流发送自定义类型消息""" + """向指定流发送自定义类型消息。 + + Args: + message_type: 自定义消息类型。 + content: 自定义消息内容。 + stream_id: 目标会话 ID。 + display_message: 用于展示的文本内容。 + typing: 是否显示输入中状态。 + reply_message: 被回复的消息对象。 + set_reply: 是否附带引用回复。 + storage_message: 是否在发送成功后写入数据库。 + show_log: 是否输出发送日志。 + + Returns: + bool: 发送成功时返回 ``True``。 + """ return await _send_to_target( - message_segment=Seg(type=message_type, data=content), # type: ignore + message_sequence=_build_message_sequence_from_custom_message(message_type, content), stream_id=stream_id, display_message=display_message, typing=typing, @@ -227,31 +749,33 @@ async def custom_reply_set_to_stream( stream_id: str, display_message: str = "", typing: bool = False, - reply_message: Optional["SessionMessage"] = None, + reply_message: Optional[MaiMessage] = None, set_reply: bool = False, storage_message: bool = True, show_log: bool = True, ) -> bool: - """向指定流发送消息组件序列。""" - flag: bool = True - for component in reply_set.components: - if isinstance(component, DictComponent): - message_seg = Seg(type="dict", data=component.data) # type: ignore - else: - message_seg = await component.to_seg() - status = await _send_to_target( - message_segment=message_seg, - stream_id=stream_id, - display_message=display_message, - typing=typing, - reply_message=reply_message, - set_reply=set_reply, - storage_message=storage_message, - show_log=show_log, - ) - if not status: - flag = False - logger.error(f"[SendService] 发送消息组件失败,组件类型:{type(component).__name__}") - set_reply = False + """向指定流发送消息组件序列。 - return flag + Args: + reply_set: 待发送的消息组件序列。 + stream_id: 目标会话 ID。 + display_message: 用于展示的文本内容。 + typing: 是否显示输入中状态。 + reply_message: 被回复的消息对象。 + set_reply: 是否附带引用回复。 + storage_message: 是否在发送成功后写入数据库。 + show_log: 是否输出发送日志。 + + Returns: + bool: 发送成功时返回 ``True``。 + """ + return await _send_to_target( + message_sequence=reply_set, + stream_id=stream_id, + display_message=display_message, + typing=typing, + reply_message=reply_message, + set_reply=set_reply, + storage_message=storage_message, + show_log=show_log, + ) diff --git a/src/webui/routers/chat/serializers.py b/src/webui/routers/chat/serializers.py new file mode 100644 index 00000000..32104f88 --- /dev/null +++ b/src/webui/routers/chat/serializers.py @@ -0,0 +1,175 @@ +"""提供 WebUI 聊天路由使用的消息序列化能力。""" + +from typing import Any, Dict, List, Optional + +import base64 + +from src.common.data_models.message_component_data_model import ( + AtComponent, + DictComponent, + EmojiComponent, + ForwardComponent, + ForwardNodeComponent, + ImageComponent, + MessageSequence, + ReplyComponent, + StandardMessageComponents, + TextComponent, + VoiceComponent, +) + + +def serialize_message_sequence(message_sequence: MessageSequence) -> List[Dict[str, Any]]: + """将内部统一消息组件序列转换为 WebUI 富文本消息段。 + + Args: + message_sequence: 内部统一消息组件序列。 + + Returns: + List[Dict[str, Any]]: 可直接广播给 WebUI 前端的消息段列表。 + """ + serialized_segments: List[Dict[str, Any]] = [] + for component in message_sequence.components: + serialized_segment = serialize_message_component(component) + if serialized_segment is not None: + serialized_segments.append(serialized_segment) + return serialized_segments + + +def serialize_message_component(component: StandardMessageComponents) -> Optional[Dict[str, Any]]: + """将单个内部消息组件转换为 WebUI 消息段。 + + Args: + component: 待序列化的内部消息组件。 + + Returns: + Optional[Dict[str, Any]]: 序列化后的 WebUI 消息段;若组件不应展示则返回 ``None``。 + """ + if isinstance(component, TextComponent): + return {"type": "text", "data": component.text} + + if isinstance(component, ImageComponent): + return _serialize_binary_component( + segment_type="image", + mime_type="image/png", + binary_data=component.binary_data, + fallback_text=component.content, + ) + + if isinstance(component, EmojiComponent): + return _serialize_binary_component( + segment_type="emoji", + mime_type="image/gif", + binary_data=component.binary_data, + fallback_text=component.content, + ) + + if isinstance(component, VoiceComponent): + return _serialize_binary_component( + segment_type="voice", + mime_type="audio/wav", + binary_data=component.binary_data, + fallback_text=component.content, + ) + + if isinstance(component, AtComponent): + return { + "type": "at", + "data": { + "target_user_id": component.target_user_id, + "target_user_nickname": component.target_user_nickname, + "target_user_cardname": component.target_user_cardname, + }, + } + + if isinstance(component, ReplyComponent): + return { + "type": "reply", + "data": { + "target_message_id": component.target_message_id, + "target_message_content": component.target_message_content, + "target_message_sender_id": component.target_message_sender_id, + "target_message_sender_nickname": component.target_message_sender_nickname, + "target_message_sender_cardname": component.target_message_sender_cardname, + }, + } + + if isinstance(component, ForwardNodeComponent): + return { + "type": "forward", + "data": [_serialize_forward_component(item) for item in component.forward_components], + } + + if isinstance(component, DictComponent): + return _serialize_dict_component(component.data) + + return {"type": "unknown", "data": str(component)} + + +def _serialize_binary_component( + segment_type: str, + mime_type: str, + binary_data: bytes, + fallback_text: str, +) -> Dict[str, Any]: + """序列化带二进制负载的消息组件。 + + Args: + segment_type: WebUI 消息段类型。 + mime_type: 对应的数据 MIME 类型。 + binary_data: 组件二进制数据。 + fallback_text: 二进制缺失时可退化展示的文本。 + + Returns: + Dict[str, Any]: 序列化后的 WebUI 消息段。 + """ + if binary_data: + encoded_payload = base64.b64encode(binary_data).decode() + return {"type": segment_type, "data": f"data:{mime_type};base64,{encoded_payload}"} + + if fallback_text: + return {"type": "text", "data": fallback_text} + + return {"type": "unknown", "original_type": segment_type, "data": ""} + + +def _serialize_forward_component(component: ForwardComponent) -> Dict[str, Any]: + """序列化单个转发节点。 + + Args: + component: 待序列化的转发节点组件。 + + Returns: + Dict[str, Any]: WebUI 可消费的转发节点字典。 + """ + return { + "message_id": component.message_id, + "user_id": component.user_id, + "user_nickname": component.user_nickname, + "user_cardname": component.user_cardname, + "content": serialize_message_sequence(MessageSequence(component.content)), + } + + +def _serialize_dict_component(data: Dict[str, Any]) -> Dict[str, Any]: + """最佳努力地序列化非标准字典组件。 + + Args: + data: 原始字典组件内容。 + + Returns: + Dict[str, Any]: 序列化后的 WebUI 消息段。 + """ + raw_type = str(data.get("type") or "dict").strip() + raw_payload = data.get("data", data) + + if raw_type in {"text", "image", "emoji", "voice", "video", "file", "music", "face"}: + return {"type": raw_type, "data": raw_payload} + + if raw_type == "reply": + return {"type": "reply", "data": raw_payload} + + if raw_type == "forward" and isinstance(raw_payload, list): + return {"type": "forward", "data": raw_payload} + + return {"type": "unknown", "original_type": raw_type, "data": raw_payload}