From 03ed59e388237b4fc9edaea5c1dcb453717063c6 Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Tue, 24 Mar 2026 11:36:26 +0800 Subject: [PATCH 01/45] =?UTF-8?q?=E6=9B=B4=E6=94=B9=E6=96=87=E4=BB=B6?= =?UTF-8?q?=E7=BB=93=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 137 +++++++-------- docs/README_CN.md | 162 ++++++++++++++++++ docs/README_EN.md | 2 +- docs/minimal-cross-platform-plan.md | 6 +- src/chat/brain_chat/brain_chat.py | 4 +- src/chat/heart_flow/heartFC_chat - 副本.py | 4 +- src/chat/heart_flow/heartFC_chat.py | 4 +- src/chat/replyer/group_generator.py | 4 +- src/chat/replyer/private_generator.py | 4 +- .../expression_auto_check_task.py | 2 +- .../expression_learner.py | 0 .../expression_review_store.py | 0 .../expression_selector.py | 2 +- .../expression_utils.py | 0 .../jargon_explainer.py | 0 .../jargon_explainer_old.py | 4 +- src/{bw_learner => learners}/jargon_miner.py | 0 src/{bw_learner => learners}/learner_utils.py | 0 .../learner_utils_old.py | 0 src/main.py | 2 +- src/memory_system/memory_retrieval.py | 2 +- .../retrieval_tools/query_words.py | 2 +- 22 files changed, 248 insertions(+), 93 deletions(-) create mode 100644 docs/README_CN.md rename src/{bw_learner => learners}/expression_auto_check_task.py (98%) rename src/{bw_learner => learners}/expression_learner.py (100%) rename src/{bw_learner => learners}/expression_review_store.py (100%) rename src/{bw_learner => learners}/expression_selector.py (99%) rename src/{bw_learner => learners}/expression_utils.py (100%) rename src/{bw_learner => learners}/jargon_explainer.py (100%) rename src/{bw_learner => learners}/jargon_explainer_old.py (99%) rename src/{bw_learner => learners}/jargon_miner.py (100%) rename src/{bw_learner => learners}/learner_utils.py (100%) rename src/{bw_learner => learners}/learner_utils_old.py (100%) diff --git a/README.md b/README.md index 7c41c8cf..3f3851b7 100644 --- a/README.md +++ b/README.md @@ -1,21 +1,21 @@
- 简体中文 | English + 简体中文 | English

-

麦麦 MaiBot MaiSaka

+

MaiBot MaiSaka

Python Version - License - Status - Contributors - Forks - Stars + License + Status + Contributors + Forks + Stars Ask DeepWiki

@@ -25,31 +25,24 @@ MaiBot Character -## 介绍 +## Introduction -麦麦MaiSaka 是一个基于大语言模型的可交互智能体 +MaiSaka is an interactive agent based on large language models. -MaiSaka 不仅仅是一个机器人,不仅仅是一个可以帮你完成任务的“有帮助的助手”,她还是一个致力于了解你,并以真实人类的风格进行交互的数字生命,她不追求完美,她不追求高效,但追求亲切和真实。 +MaiSaka is more than just a bot, and more than a "helpful assistant" that completes tasks. She is a digital life form that tries to understand you and interact in a genuinely human style. She does not pursue perfection or efficiency above all else. She pursues warmth and authenticity. +- 💭 **No one likes GPT-sounding dialogue**: MaiSaka uses a more natural conversational style. Instead of long-winded markdown-heavy replies, she chats in a way that feels casual, varied, and human. +- 🎭 **No longer stuck in rigid Q&A**: She knows when to speak, how to read the room, when to join a conversation, and when to stay quiet. +- 🧠 **MaiSaka becoming human**: In group conversations, MaiSaka imitates how people around her speak, learns new slang and in-group language, and keeps evolving. +- ❤️ **Always learning more about you**: Inspired by personality theory in psychology, MaiSaka gradually builds an understanding of your preferences, traits, habits, and behavior style. +- 🔌 **Plugin system**: Provides powerful APIs and an event system with virtually unlimited room for extension. -- 💭 **没有人喜欢GPT的语言风格**:麦麦使用了更加自然,贴合人类对话习惯的交互方式,不是长篇大论或者markdown格式的分点,而是或长或短的闲谈。 - -- 🎭 **不再是傻乎乎的一问一答**:懂得在合适的时间说话,把握聊天中的气氛,在合适的时候开口,在合适的时候闭嘴。 - -- 🧠 **麦麦·成为人类**:在多人对话中,麦麦会模仿其他人的的说话风格,还会自主理解新词或者小圈子里的黑话,不断进化。 - -- ❤️ **永远都在更加了解你**:基于心理学中人格理论,麦麦会不断积累对于你的了解,不论是你的信息,喜恶或是行为风格,她都记在心里。 - -- 🔌 **插件系统**:提供强大的 API 和事件系统,无限扩展可能。 - - - -### 快速导航 +### Quick Navigation

- 🌟 演示视频  |  - 📦 快速入门  |  - 📃 核心文档  |  - 💬 加入社区 + 🌟 Demo Video  |  + 📦 Quick Start  |  + 📃 Core Documentation  |  + 💬 Join Community

@@ -60,103 +53,103 @@ MaiSaka 不仅仅是一个机器人,不仅仅是一个可以帮你完成任务 - 麦麦演示视频 + MaiSaka Demo Video
- 前往观看麦麦演示视频 + Watch the MaiSaka demo video
--- -## 🔥 更新和安装 +## 🔥 Updates and Installation -> **最新版本: v1.0.0** ([📄 更新日志](changelogs/changelog.md)) +> **Latest Version: v1.0.0** ([📄 Changelog](changelogs/changelog.md)) -- **下载**: 前往 [Release](https://github.com/MaiM-with-u/MaiBot/releases/) 页面下载最新版本 -- **启动器**: [Mailauncher](https://github.com/MaiM-with-u/mailauncher/releases/) (仅支持 MacOS, 早期开发中) +- **Download**: Visit the [Release](https://github.com/MaiM-with-u/MaiBot/releases/) page to get the latest version. +- **Launcher**: [Mailauncher](https://github.com/MaiM-with-u/mailauncher/releases/) (MacOS only, still in early development). -| 分支 | 说明 | +| Branch | Description | | :--- | :--- | -| `main` | ✅ **稳定发布版本 (推荐)** | -| `dev` | 🚧 开发测试版本,包含新功能,可能不稳定 | +| `main` | ✅ **Stable release (recommended)** | +| `dev` | 🚧 Development testing branch with new features, may be unstable | -### 📚 部署教程 -👉 **[🚀 最新版本部署教程](https://docs.mai-mai.org/manual/deployment/mmc_deploy_windows.html)** +### 📚 Deployment Guide +👉 **[🚀 Latest Deployment Guide](https://docs.mai-mai.org/manual/deployment/mmc_deploy_windows.html)** --- -## 💬 讨论与社区 +## 💬 Discussion and Community -我们欢迎所有对 MaiBot 感兴趣的朋友加入! +We welcome everyone interested in MaiBot to join us. -| 类别 | 群组 | 说明 | +| Category | Group | Description | | :--- | :--- | :--- | -| **技术交流** | [麦麦脑电图](https://qm.qq.com/q/RzmCiRtHEW) | 技术交流/答疑 | -| **技术交流** | [麦麦大脑磁共振](https://qm.qq.com/q/VQ3XZrWgMs) | 技术交流/答疑 | -| **技术交流** | [麦麦要当VTB](https://qm.qq.com/q/wGePTl1UyY) | 技术交流/答疑 | -| **闲聊吹水** | [麦麦之闲聊群](https://qm.qq.com/q/JxvHZnxyec) | 仅限闲聊,不答疑 | -| **插件开发** | [插件开发群](https://qm.qq.com/q/1036092828) | 进阶开发与测试 | +| **Technical** | [MaiBrain EEG](https://qm.qq.com/q/RzmCiRtHEW) | Technical discussion / Q&A | +| **Technical** | [MaiBrain MRI](https://qm.qq.com/q/VQ3XZrWgMs) | Technical discussion / Q&A | +| **Technical** | [Mai Wants to Be a VTuber](https://qm.qq.com/q/wGePTl1UyY) | Technical discussion / Q&A | +| **Casual Chat** | [Mai Casual Chat Group](https://qm.qq.com/q/JxvHZnxyec) | Casual chat only, no support | +| **Plugin Development** | [Plugin Dev Group](https://qm.qq.com/q/1036092828) | Advanced development and testing | --- -## 📚 文档 +## 📚 Documentation > [!NOTE] -> 部分内容可能更新不够及时,请注意版本对应。 +> Some content may not be updated promptly, so please pay attention to version compatibility. -- **[📚 核心 Wiki 文档](https://docs.mai-mai.org)**: 最全面的文档中心,了解麦麦的一切。 +- **[📚 Core Wiki Documentation](https://docs.mai-mai.org)**: The most comprehensive documentation hub for everything about MaiSaka. -### 🧩 衍生项目 +### 🧩 Related Projects -- **[Amaidesu](https://github.com/MaiM-with-u/Amaidesu)**: 让麦麦在B站开播 -- **[MoFox_Bot](https://github.com/MoFox-Studio/MoFox-Core)**: 基于 MaiCore 0.10.0 的增强型 Fork,更稳定更有趣。 -- **[MaiCraft](https://github.com/MaiM-with-u/Maicraft)**: 让麦麦陪你玩 Minecraft (暂时停止维护中)。 +- **[Amaidesu](https://github.com/MaiM-with-u/Amaidesu)**: Let MaiSaka stream on Bilibili. +- **[MoFox_Bot](https://github.com/MoFox-Studio/MoFox-Core)**: An enhanced fork based on MaiCore 0.10.0, with improved stability and more fun features. +- **[MaiCraft](https://github.com/MaiM-with-u/Maicraft)**: Let MaiSaka accompany you in Minecraft (currently paused). --- -## 💡 设计理念 +## 💡 Design Philosophy -> **千石可乐说:** -> - 这个项目最初只是为了给牛牛 bot 添加一点额外的功能,但是功能越写越多,最后决定重写。其目的是为了创造一个活跃在 QQ 群聊的"生命体"。目的并不是为了写一个功能齐全的机器人,而是一个尽可能让人感知到真实的类人存在。 -> - 程序的功能设计理念基于一个核心的原则:"最像而不是好"。 -> - 如果人类真的需要一个 AI 来陪伴自己,并不是所有人都需要一个完美的,能解决所有问题的"helpful assistant",而是一个会犯错的,拥有自己感知和想法的"生命形式"。 +> **SengokuCola says:** +> - This project originally started as a few extra features for the NiuNiu bot, but it kept growing until a full rewrite became inevitable. The goal was to create a "life form" active in QQ group chats, not a feature-complete bot, but something as human-like and real-feeling as possible. +> - The core design principle is: "more lifelike, not merely better." +> - If people truly want AI companionship, not everyone needs a perfect "helpful assistant" that solves every problem. Some people may want a life form that can make mistakes and has its own perceptions and thoughts. -> **xxxxx说:** +> **xxxxx says:** > *Code is open, but the soul is yours.* --- -## 🙋 贡献和致谢 +## 🙋 Contributing and Acknowledgments -欢迎参与贡献!请先阅读 [贡献指南](docs-src/CONTRIBUTE.md)。 +Contributions are welcome. Please read the [Contribution Guide](docs-src/CONTRIBUTE.md) first. -### 🌟 贡献者 +### 🌟 Contributors contributors -### ❤️ 特别致谢 +### ❤️ Special Thanks -- **[萨卡班甲鱼](https://en.wikipedia.org/wiki/Sacabambaspis)**: 千石可乐很喜欢的生物。 -- **[略nd](https://space.bilibili.com/1344099355)**: 🎨 为麦麦绘制早期的精美人设。 -- **[NapCat](https://github.com/NapNeko/NapCatQQ)**: 🚀 现代化的基于 NTQQ 的 Bot 协议实现。 +- **[Sacabambaspis](https://en.wikipedia.org/wiki/Sacabambaspis)**: SengokuCola's favorite creature. +- **[略nd](https://space.bilibili.com/1344099355)**: Drew MaiSaka's beautiful early character design. +- **[NapCat](https://github.com/NapNeko/NapCatQQ)**: A modern NTQQ-based bot protocol implementation. --- -## 📊 仓库状态 +## 📊 Repository Status -![Alt](https://repobeats.axiom.co/api/embed/9faca9fccfc467931b87dd357b60c6362b5cfae0.svg "麦麦仓库状态") +![Alt](https://repobeats.axiom.co/api/embed/9faca9fccfc467931b87dd357b60c6362b5cfae0.svg "MaiBot Repository Status") -### Star 趋势 -[![Star 趋势](https://starchart.cc/MaiM-with-u/MaiBot.svg?variant=adaptive)](https://starchart.cc/MaiM-with-u/MaiBot) +### Star History +[![Star History](https://starchart.cc/MaiM-with-u/MaiBot.svg?variant=adaptive)](https://starchart.cc/MaiM-with-u/MaiBot) --- -## 📌 注意事项 & License +## 📌 Notice & License > [!IMPORTANT] -> 使用前请阅读 [用户协议 (EULA)](EULA.md) 和 [隐私协议](PRIVACY.md)。AI 生成内容请仔细甄别。 +> Please read the [End User License Agreement (EULA)](EULA.md) and [Privacy Policy](PRIVACY.md) before use. Please evaluate AI-generated content carefully. **License**: GPL-3.0 diff --git a/docs/README_CN.md b/docs/README_CN.md new file mode 100644 index 00000000..9532cd00 --- /dev/null +++ b/docs/README_CN.md @@ -0,0 +1,162 @@ +
+ + + 简体中文 | English + +
+
+ +

麦麦 MaiBot MaiSaka

+ + +

+ Python Version + License + Status + Contributors + Forks + Stars + Ask DeepWiki +

+
+ +
+ + +MaiBot Character + +## 介绍 + +麦麦MaiSaka 是一个基于大语言模型的可交互智能体 + +MaiSaka 不仅仅是一个机器人,不仅仅是一个可以帮你完成任务的“有帮助的助手”,她还是一个致力于了解你,并以真实人类的风格进行交互的数字生命,她不追求完美,她不追求高效,但追求亲切和真实。 + + +- 💭 **没有人喜欢GPT的语言风格**:麦麦使用了更加自然,贴合人类对话习惯的交互方式,不是长篇大论或者markdown格式的分点,而是或长或短的闲谈。 + +- 🎭 **不再是傻乎乎的一问一答**:懂得在合适的时间说话,把握聊天中的气氛,在合适的时候开口,在合适的时候闭嘴。 + +- 🧠 **麦麦·成为人类**:在多人对话中,麦麦会模仿其他人的的说话风格,还会自主理解新词或者小圈子里的黑话,不断进化。 + +- ❤️ **永远都在更加了解你**:基于心理学中人格理论,麦麦会不断积累对于你的了解,不论是你的信息,喜恶或是行为风格,她都记在心里。 + +- 🔌 **插件系统**:提供强大的 API 和事件系统,无限扩展可能。 + + + +### 快速导航 +

+ 🌟 演示视频  |  + 📦 快速入门  |  + 📃 核心文档  |  + 💬 加入社区 +

+ + +
+ +
+
+ + + + 麦麦演示视频 + +
+ 前往观看麦麦演示视频 +
+
+ +--- + +## 🔥 更新和安装 + +> **最新版本: v1.0.0** ([📄 更新日志](../changelogs/changelog.md)) + +- **下载**: 前往 [Release](https://github.com/MaiM-with-u/MaiBot/releases/) 页面下载最新版本 +- **启动器**: [Mailauncher](https://github.com/MaiM-with-u/mailauncher/releases/) (仅支持 MacOS, 早期开发中) + +| 分支 | 说明 | +| :--- | :--- | +| `main` | ✅ **稳定发布版本 (推荐)** | +| `dev` | 🚧 开发测试版本,包含新功能,可能不稳定 | + +### 📚 部署教程 +👉 **[🚀 最新版本部署教程](https://docs.mai-mai.org/manual/deployment/mmc_deploy_windows.html)** + +--- + +## 💬 讨论与社区 + +我们欢迎所有对 MaiBot 感兴趣的朋友加入! + +| 类别 | 群组 | 说明 | +| :--- | :--- | :--- | +| **技术交流** | [麦麦脑电图](https://qm.qq.com/q/RzmCiRtHEW) | 技术交流/答疑 | +| **技术交流** | [麦麦大脑磁共振](https://qm.qq.com/q/VQ3XZrWgMs) | 技术交流/答疑 | +| **技术交流** | [麦麦要当VTB](https://qm.qq.com/q/wGePTl1UyY) | 技术交流/答疑 | +| **闲聊吹水** | [麦麦之闲聊群](https://qm.qq.com/q/JxvHZnxyec) | 仅限闲聊,不答疑 | +| **插件开发** | [插件开发群](https://qm.qq.com/q/1036092828) | 进阶开发与测试 | + +--- + +## 📚 文档 + +> [!NOTE] +> 部分内容可能更新不够及时,请注意版本对应。 + +- **[📚 核心 Wiki 文档](https://docs.mai-mai.org)**: 最全面的文档中心,了解麦麦的一切。 + +### 🧩 衍生项目 + +- **[Amaidesu](https://github.com/MaiM-with-u/Amaidesu)**: 让麦麦在B站开播 +- **[MoFox_Bot](https://github.com/MoFox-Studio/MoFox-Core)**: 基于 MaiCore 0.10.0 的增强型 Fork,更稳定更有趣。 +- **[MaiCraft](https://github.com/MaiM-with-u/Maicraft)**: 让麦麦陪你玩 Minecraft (暂时停止维护中)。 + +--- + +## 💡 设计理念 + +> **千石可乐说:** +> - 这个项目最初只是为了给牛牛 bot 添加一点额外的功能,但是功能越写越多,最后决定重写。其目的是为了创造一个活跃在 QQ 群聊的"生命体"。目的并不是为了写一个功能齐全的机器人,而是一个尽可能让人感知到真实的类人存在。 +> - 程序的功能设计理念基于一个核心的原则:"最像而不是好"。 +> - 如果人类真的需要一个 AI 来陪伴自己,并不是所有人都需要一个完美的,能解决所有问题的"helpful assistant",而是一个会犯错的,拥有自己感知和想法的"生命形式"。 + +> **xxxxx说:** +> *Code is open, but the soul is yours.* + +--- + +## 🙋 贡献和致谢 + +欢迎参与贡献!请先阅读 [贡献指南](../docs-src/CONTRIBUTE.md)。 + +### 🌟 贡献者 + + + contributors + + +### ❤️ 特别致谢 + +- **[萨卡班甲鱼](https://en.wikipedia.org/wiki/Sacabambaspis)**: 千石可乐很喜欢的生物。 +- **[略nd](https://space.bilibili.com/1344099355)**: 🎨 为麦麦绘制早期的精美人设。 +- **[NapCat](https://github.com/NapNeko/NapCatQQ)**: 🚀 现代化的基于 NTQQ 的 Bot 协议实现。 + +--- + +## 📊 仓库状态 + +![Alt](https://repobeats.axiom.co/api/embed/9faca9fccfc467931b87dd357b60c6362b5cfae0.svg "麦麦仓库状态") + +### Star 趋势 +[![Star 趋势](https://starchart.cc/MaiM-with-u/MaiBot.svg?variant=adaptive)](https://starchart.cc/MaiM-with-u/MaiBot) + +--- + +## 📌 注意事项 & License + +> [!IMPORTANT] +> 使用前请阅读 [用户协议 (EULA)](../EULA.md) 和 [隐私协议](../PRIVACY.md)。AI 生成内容请仔细甄别。 + +**License**: GPL-3.0 diff --git a/docs/README_EN.md b/docs/README_EN.md index f37002fa..9ecc53b7 100644 --- a/docs/README_EN.md +++ b/docs/README_EN.md @@ -1,7 +1,7 @@
- 简体中文 | English + 简体中文 | English

diff --git a/docs/minimal-cross-platform-plan.md b/docs/minimal-cross-platform-plan.md index d0b6707b..2f0a86bd 100644 --- a/docs/minimal-cross-platform-plan.md +++ b/docs/minimal-cross-platform-plan.md @@ -41,7 +41,7 @@ This plan is based on the checked-in code, not on assumptions from previous draf | `src/person_info/person_info.py:247` | `_is_bot_self(self, platform, user_id)` | Duplicate logic with same QQ fallback | Wrong-order call sites (8 total): -- `src/bw_learner/expression_learner.py` x3 (lines 158, 241, 301) +- `src/learners/expression_learner.py` x3 (lines 158, 241, 301) - `src/common/utils/utils_message.py` x4 (lines 370, 440, 476, 515) - `src/webui/routers/chat/support.py` x1 (line 65) @@ -122,7 +122,7 @@ Make `src/chat/utils/utils.py::is_bot_self(platform, user_id)` the only real imp - `src/common/utils/system_utils.py` - `src/chat/utils/utils.py` - `src/person_info/person_info.py` -- `src/bw_learner/expression_learner.py` +- `src/learners/expression_learner.py` - `src/common/utils/utils_message.py` - `src/webui/routers/chat/support.py` - tests @@ -468,7 +468,7 @@ When stopping, name: the exact file(s), the blocking mismatch, why it is outside | Phase | Allowed files | |-------|---------------| -| Phase 0 | `src/common/utils/system_utils.py`, `src/chat/utils/utils.py`, `src/person_info/person_info.py`, `src/bw_learner/expression_learner.py`, `src/common/utils/utils_message.py`, `src/webui/routers/chat/support.py`, tests (including `pytests/utils_test/message_utils_test.py`) | +| Phase 0 | `src/common/utils/system_utils.py`, `src/chat/utils/utils.py`, `src/person_info/person_info.py`, `src/learners/expression_learner.py`, `src/common/utils/utils_message.py`, `src/webui/routers/chat/support.py`, tests (including `pytests/utils_test/message_utils_test.py`) | | Phase 1 | `src/chat/utils/utils.py`, `src/chat/planner_actions/planner.py`, `src/chat/utils/statistic.py`, `src/common/message_repository.py`, `src/webui/routers/chat/support.py`, `src/services/send_service.py`, `src/chat/replyer/group_generator.py`, `src/chat/replyer/private_generator.py`, `src/chat/brain_chat/PFC/message_sender.py`, `src/person_info/person_info.py`, tests | ### INVALID OUTPUT EXAMPLES diff --git a/src/chat/brain_chat/brain_chat.py b/src/chat/brain_chat/brain_chat.py index 2b4863ac..1e9e648a 100644 --- a/src/chat/brain_chat/brain_chat.py +++ b/src/chat/brain_chat/brain_chat.py @@ -8,8 +8,8 @@ from rich.traceback import install from src.config.config import global_config from src.common.logger import get_logger from src.common.utils.utils_config import ExpressionConfigUtils -from src.bw_learner.expression_learner import ExpressionLearner -from src.bw_learner.jargon_miner import JargonMiner +from src.learners.expression_learner import ExpressionLearner +from src.learners.jargon_miner import JargonMiner from src.chat.message_receive.chat_manager import BotChatSession from src.chat.message_receive.chat_manager import chat_manager as _chat_manager from src.chat.message_receive.message import SessionMessage diff --git a/src/chat/heart_flow/heartFC_chat - 副本.py b/src/chat/heart_flow/heartFC_chat - 副本.py index c805597d..02f70281 100644 --- a/src/chat/heart_flow/heartFC_chat - 副本.py +++ b/src/chat/heart_flow/heartFC_chat - 副本.py @@ -16,9 +16,9 @@ from src.chat.planner_actions.planner import ActionPlanner from src.chat.planner_actions.action_modifier import ActionModifier from src.chat.planner_actions.action_manager import ActionManager from src.chat.heart_flow.hfc_utils import CycleDetail -from src.bw_learner.expression_learner import expression_learner_manager +from src.learners.expression_learner import expression_learner_manager from src.chat.heart_flow.frequency_control import frequency_control_manager -from src.bw_learner.message_recorder import extract_and_distribute_messages +from src.learners.message_recorder import extract_and_distribute_messages from src.person_info.person_info import Person from src.plugin_system.base.component_types import EventType, ActionInfo from src.plugin_system.core import events_manager diff --git a/src/chat/heart_flow/heartFC_chat.py b/src/chat/heart_flow/heartFC_chat.py index af0beb4e..2c1eb162 100644 --- a/src/chat/heart_flow/heartFC_chat.py +++ b/src/chat/heart_flow/heartFC_chat.py @@ -7,8 +7,8 @@ import traceback from rich.traceback import install -from src.bw_learner.expression_learner import ExpressionLearner -from src.bw_learner.jargon_miner import JargonMiner +from src.learners.expression_learner import ExpressionLearner +from src.learners.jargon_miner import JargonMiner from src.chat.event_helpers import build_event_message from src.chat.logger.plan_reply_logger import PlanReplyLogger from src.chat.message_receive.chat_manager import BotChatSession diff --git a/src/chat/replyer/group_generator.py b/src/chat/replyer/group_generator.py index 74b324be..e10aa147 100644 --- a/src/chat/replyer/group_generator.py +++ b/src/chat/replyer/group_generator.py @@ -27,7 +27,7 @@ from src.services.message_service import ( replace_user_references, translate_pid_to_description, ) -from src.bw_learner.expression_selector import expression_selector +from src.learners.expression_selector import expression_selector # from src.memory_system.memory_activator import MemoryActivator from src.person_info.person_info import Person @@ -36,7 +36,7 @@ from src.services import llm_service as llm_api from src.chat.logger.plan_reply_logger import PlanReplyLogger from src.memory_system.memory_retrieval import init_memory_retrieval_sys, build_memory_retrieval_prompt -from src.bw_learner.jargon_explainer_old import explain_jargon_in_context, retrieve_concepts_with_jargon +from src.learners.jargon_explainer_old import explain_jargon_in_context, retrieve_concepts_with_jargon from src.chat.utils.common_utils import TempMethodsExpression init_memory_retrieval_sys() diff --git a/src/chat/replyer/private_generator.py b/src/chat/replyer/private_generator.py index 3b70bb2c..ccd8e1e4 100644 --- a/src/chat/replyer/private_generator.py +++ b/src/chat/replyer/private_generator.py @@ -27,13 +27,13 @@ from src.services.message_service import ( replace_user_references, translate_pid_to_description, ) -from src.bw_learner.expression_selector import expression_selector +from src.learners.expression_selector import expression_selector # from src.memory_system.memory_activator import MemoryActivator from src.person_info.person_info import Person, is_person_known from src.core.types import ActionInfo, EventType from src.memory_system.memory_retrieval import init_memory_retrieval_sys, build_memory_retrieval_prompt -from src.bw_learner.jargon_explainer_old import explain_jargon_in_context +from src.learners.jargon_explainer_old import explain_jargon_in_context init_memory_retrieval_sys() diff --git a/src/bw_learner/expression_auto_check_task.py b/src/learners/expression_auto_check_task.py similarity index 98% rename from src/bw_learner/expression_auto_check_task.py rename to src/learners/expression_auto_check_task.py index d90eb4da..53b151b2 100644 --- a/src/bw_learner/expression_auto_check_task.py +++ b/src/learners/expression_auto_check_task.py @@ -15,7 +15,7 @@ import random from sqlmodel import select -from src.bw_learner.expression_review_store import get_review_state, set_review_state +from src.learners.expression_review_store import get_review_state, set_review_state from src.common.database.database import get_db_session from src.common.database.database_model import Expression from src.common.logger import get_logger diff --git a/src/bw_learner/expression_learner.py b/src/learners/expression_learner.py similarity index 100% rename from src/bw_learner/expression_learner.py rename to src/learners/expression_learner.py diff --git a/src/bw_learner/expression_review_store.py b/src/learners/expression_review_store.py similarity index 100% rename from src/bw_learner/expression_review_store.py rename to src/learners/expression_review_store.py diff --git a/src/bw_learner/expression_selector.py b/src/learners/expression_selector.py similarity index 99% rename from src/bw_learner/expression_selector.py rename to src/learners/expression_selector.py index c6cfe469..c96e84cf 100644 --- a/src/bw_learner/expression_selector.py +++ b/src/learners/expression_selector.py @@ -9,7 +9,7 @@ from src.config.config import global_config, model_config from src.common.logger import get_logger from src.common.database.database_model import Expression from src.prompt.prompt_manager import prompt_manager -from src.bw_learner.learner_utils_old import weighted_sample +from src.learners.learner_utils_old import weighted_sample from src.chat.utils.common_utils import TempMethodsExpression logger = get_logger("expression_selector") diff --git a/src/bw_learner/expression_utils.py b/src/learners/expression_utils.py similarity index 100% rename from src/bw_learner/expression_utils.py rename to src/learners/expression_utils.py diff --git a/src/bw_learner/jargon_explainer.py b/src/learners/jargon_explainer.py similarity index 100% rename from src/bw_learner/jargon_explainer.py rename to src/learners/jargon_explainer.py diff --git a/src/bw_learner/jargon_explainer_old.py b/src/learners/jargon_explainer_old.py similarity index 99% rename from src/bw_learner/jargon_explainer_old.py rename to src/learners/jargon_explainer_old.py index 4d144b2c..0cfafa82 100644 --- a/src/bw_learner/jargon_explainer_old.py +++ b/src/learners/jargon_explainer_old.py @@ -7,8 +7,8 @@ from src.common.database.database_model import Jargon from src.llm_models.utils_model import LLMRequest from src.config.config import model_config, global_config from src.prompt.prompt_manager import prompt_manager -from src.bw_learner.jargon_miner_old import search_jargon -from src.bw_learner.learner_utils_old import ( +from src.learners.jargon_miner_old import search_jargon +from src.learners.learner_utils_old import ( is_bot_message, contains_bot_self_name, parse_chat_id_list, diff --git a/src/bw_learner/jargon_miner.py b/src/learners/jargon_miner.py similarity index 100% rename from src/bw_learner/jargon_miner.py rename to src/learners/jargon_miner.py diff --git a/src/bw_learner/learner_utils.py b/src/learners/learner_utils.py similarity index 100% rename from src/bw_learner/learner_utils.py rename to src/learners/learner_utils.py diff --git a/src/bw_learner/learner_utils_old.py b/src/learners/learner_utils_old.py similarity index 100% rename from src/bw_learner/learner_utils_old.py rename to src/learners/learner_utils_old.py diff --git a/src/main.py b/src/main.py index 587c5634..30d1c86d 100644 --- a/src/main.py +++ b/src/main.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING import asyncio import time -from src.bw_learner.expression_auto_check_task import ExpressionAutoCheckTask +from src.learners.expression_auto_check_task import ExpressionAutoCheckTask from src.chat.emoji_system.emoji_manager import emoji_manager from src.chat.knowledge import lpmm_start_up from src.chat.message_receive.bot import chat_bot diff --git a/src/memory_system/memory_retrieval.py b/src/memory_system/memory_retrieval.py index 4193a16a..982db166 100644 --- a/src/memory_system/memory_retrieval.py +++ b/src/memory_system/memory_retrieval.py @@ -14,7 +14,7 @@ from src.common.database.database_model import ThinkingQuestion from src.memory_system.retrieval_tools import get_tool_registry, init_all_tools from src.llm_models.payload_content.message import MessageBuilder, RoleType, Message from src.chat.message_receive.chat_manager import chat_manager as _chat_manager -from src.bw_learner.jargon_explainer_old import retrieve_concepts_with_jargon +from src.learners.jargon_explainer_old import retrieve_concepts_with_jargon logger = get_logger("memory_retrieval") diff --git a/src/memory_system/retrieval_tools/query_words.py b/src/memory_system/retrieval_tools/query_words.py index 66fb3c46..ee28b934 100644 --- a/src/memory_system/retrieval_tools/query_words.py +++ b/src/memory_system/retrieval_tools/query_words.py @@ -4,7 +4,7 @@ """ from src.common.logger import get_logger -from src.bw_learner.jargon_explainer_old import retrieve_concepts_with_jargon +from src.learners.jargon_explainer_old import retrieve_concepts_with_jargon from .tool_registry import register_memory_retrieval_tool logger = get_logger("memory_retrieval_tools") From 668f41431ab5d88f147964de61b96691ed1a1705 Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Mon, 16 Mar 2026 16:57:14 +0800 Subject: [PATCH 02/45] =?UTF-8?q?=E7=A7=BB=E9=99=A4=E6=97=A0=E7=94=A8?= =?UTF-8?q?=E9=85=8D=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/config/official_configs.py | 23 ++++------------------- 1 file changed, 4 insertions(+), 19 deletions(-) diff --git a/src/config/official_configs.py b/src/config/official_configs.py index 6ed5c452..e6907611 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -1615,24 +1615,6 @@ class PluginRuntimeConfig(ConfigBase): ) """启用插件系统""" - builtin_plugin_dir: str = Field( - default="src/plugins/built_in", - json_schema_extra={ - "x-widget": "input", - "x-icon": "folder", - }, - ) - """内置插件目录(相对于项目根目录)""" - - thirdparty_plugin_dir: str = Field( - default="plugins", - json_schema_extra={ - "x-widget": "input", - "x-icon": "folder-open", - }, - ) - """第三方插件目录(相对于项目根目录)""" - health_check_interval_sec: float = Field( default=30.0, json_schema_extra={ @@ -1676,4 +1658,7 @@ class PluginRuntimeConfig(ConfigBase): "x-icon": "link", }, ) - """_wrap_\n 自定义 IPC Socket 路径(仅 Linux/macOS 生效)\n 留空则自动生成临时路径""" + """ + 自定义 IPC Socket 路径(仅 Linux/macOS 生效) + 留空则自动生成临时路径 + """ From 34190755992cd4f0685619a601f41a381f23aaef Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Mon, 16 Mar 2026 18:18:40 +0800 Subject: [PATCH 03/45] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E6=B3=A8?= =?UTF-8?q?=E9=87=8A;=E6=B7=BB=E5=8A=A0=E6=97=A5=E5=BF=97=E9=A2=9C?= =?UTF-8?q?=E8=89=B2=E8=87=AA=E5=AE=9A=E4=B9=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/plugin_runtime/protocol/envelope.py | 160 +++++++++++++----------- 1 file changed, 90 insertions(+), 70 deletions(-) diff --git a/src/plugin_runtime/protocol/envelope.py b/src/plugin_runtime/protocol/envelope.py index bcfb2758..56bb4582 100644 --- a/src/plugin_runtime/protocol/envelope.py +++ b/src/plugin_runtime/protocol/envelope.py @@ -13,46 +13,38 @@ import logging as stdlib_logging import time -# ─── 协议常量 ────────────────────────────────────────────────────── - -PROTOCOL_VERSION = "1.0" - +# ====== 协议常量 ====== +PROTOCOL_VERSION = "1.0.0" # 支持的 SDK 版本范围(Host 在握手时校验) MIN_SDK_VERSION = "1.0.0" MAX_SDK_VERSION = "1.99.99" -# ─── 消息类型 ────────────────────────────────────────────────────── - - +# ====== 消息类型 ====== class MessageType(str, Enum): """RPC 消息类型""" REQUEST = "request" RESPONSE = "response" - EVENT = "event" - - -# ─── 请求 ID 生成器 ─────────────────────────────────────────────── + BROADCAST = "broadcast" +# ====== 请求 ID 生成器 ====== class RequestIdGenerator: - """单调递增 int64 请求 ID 生成器(线程安全由调用方保证或使用 asyncio)""" + """单调递增 int64 请求 ID 生成器""" def __init__(self, start: int = 1) -> None: self._counter = start - def next(self) -> int: + async def next(self) -> int: current = self._counter self._counter += 1 return current -# ─── Envelope 模型 ───────────────────────────────────────────────── - - +# ====== Envelope 模型 ====== class Envelope(BaseModel): - """RPC 统一信封 + """RPC 统一消息封装 所有 Host <-> Runner 消息均封装为此格式。 序列化流程:Envelope -> .model_dump() -> MsgPack encode @@ -60,15 +52,25 @@ class Envelope(BaseModel): """ protocol_version: str = Field(default=PROTOCOL_VERSION, description="协议版本") + """协议版本""" request_id: int = Field(description="单调递增请求 ID") + """单调递增请求 ID""" message_type: MessageType = Field(description="消息类型") + """消息类型""" method: str = Field(default="", description="RPC 方法名") + """RPC 方法名""" plugin_id: str = Field(default="", description="目标插件 ID") - timestamp_ms: int = Field(default_factory=lambda: int(time.time() * 1000), description="发送时间戳(ms)") - timeout_ms: int = Field(default=30000, description="相对超时(ms)") + """目标插件 ID""" + timestamp_ms: int = Field(default_factory=lambda: int(time.time() * 1000), description="发送时间戳 (ms)") + """发送时间戳 (ms)""" + timeout_ms: int = Field(default=30000, description="相对超时 (ms)") + """相对超时 (ms)""" generation: int = Field(default=0, description="Runner generation 编号") + """Runner generation 编号""" payload: Dict[str, Any] = Field(default_factory=dict, description="业务数据") - error: Optional[Dict[str, Any]] = Field(default=None, description="错误信息(仅 response)") + """业务数据""" + error: Optional[Dict[str, Any]] = Field(default=None, description="错误信息 (仅 response)") + """错误信息 (仅 response)""" def is_request(self) -> bool: return self.message_type == MessageType.REQUEST @@ -76,8 +78,8 @@ class Envelope(BaseModel): def is_response(self) -> bool: return self.message_type == MessageType.RESPONSE - def is_event(self) -> bool: - return self.message_type == MessageType.EVENT + def is_broadcast(self) -> bool: + return self.message_type == MessageType.BROADCAST def make_response( self, payload: Optional[Dict[str, Any]] = None, error: Optional[Dict[str, Any]] = None @@ -105,153 +107,172 @@ class Envelope(BaseModel): ) -# ─── 握手消息 ────────────────────────────────────────────────────── - - +# ====== 握手请求与响应 ====== class HelloPayload(BaseModel): """runner.hello 握手请求 payload""" runner_id: str = Field(description="Runner 进程唯一标识") + """Runner 进程唯一标识""" sdk_version: str = Field(description="SDK 版本号") + """SDK 版本号""" session_token: str = Field(description="一次性会话令牌") + """一次性会话令牌""" class HelloResponsePayload(BaseModel): """runner.hello 握手响应 payload""" accepted: bool = Field(description="是否接受连接") + """是否接受连接""" host_version: str = Field(default="", description="Host 版本号") + """Host 版本号""" assigned_generation: int = Field(default=0, description="分配的 generation 编号") - reason: str = Field(default="", description="拒绝原因(若 accepted=False)") - - -# ─── 组件注册消息 ────────────────────────────────────────────────── + """分配的 generation 编号""" + reason: str = Field(default="", description="拒绝原因 (若 accepted=False)") + """拒绝原因 (若 `accepted`=`False`)""" +# ====== 组件注册消息 ====== class ComponentDeclaration(BaseModel): """单个组件声明""" name: str = Field(description="组件名称") - component_type: str = Field(description="组件类型: action/command/tool/event_handler") + """组件名称""" + component_type: str = Field( + description="组件类型:action/command/tool/event_handler/workflow_handler/message_gateway" + ) + """组件类型:`action`/`command`/`tool`/`event_handler`/`workflow_handler`/`message_gateway`""" plugin_id: str = Field(description="所属插件 ID") + """所属插件 ID""" metadata: Dict[str, Any] = Field(default_factory=dict, description="组件元数据") + """组件元数据""" -class RegisterComponentsPayload(BaseModel): - """plugin.register_components 请求 payload""" +class RegisterPluginPayload(BaseModel): + """plugin.register_plugin 请求 payload""" plugin_id: str = Field(description="插件 ID") + """插件 ID""" plugin_version: str = Field(default="1.0.0", description="插件版本") + """插件版本""" components: List[ComponentDeclaration] = Field(default_factory=list, description="组件列表") + """组件列表""" capabilities_required: List[str] = Field(default_factory=list, description="所需能力列表") + """所需能力列表""" class BootstrapPluginPayload(BaseModel): """plugin.bootstrap 请求 payload""" plugin_id: str = Field(description="插件 ID") + """插件 ID""" plugin_version: str = Field(default="1.0.0", description="插件版本") + """插件版本""" capabilities_required: List[str] = Field(default_factory=list, description="所需能力列表") + """所需能力列表""" -# ─── 调用消息 ────────────────────────────────────────────────────── - - +# ====== 插件调用请求和响应 ====== class InvokePayload(BaseModel): - """plugin.invoke_* 请求 payload""" + """plugin.invoke.* 请求 payload""" component_name: str = Field(description="要调用的组件名称") + """要调用的组件名称""" args: Dict[str, Any] = Field(default_factory=dict, description="调用参数") + """调用参数""" class InvokeResultPayload(BaseModel): - """plugin.invoke_* 响应 payload""" + """plugin.invoke.* 响应 payload""" success: bool = Field(description="是否成功") + """是否成功""" result: Any = Field(default=None, description="返回值") + """返回值""" -# ─── 能力调用消息 ────────────────────────────────────────────────── - - +# ====== 能力调用消息 ====== class CapabilityRequestPayload(BaseModel): """cap.* 请求 payload(插件 -> Host 能力调用)""" capability: str = Field(description="能力名称,如 send.text, db.query") + """能力名称,如 send.text, db.query""" args: Dict[str, Any] = Field(default_factory=dict, description="调用参数") + """调用参数""" class CapabilityResponsePayload(BaseModel): """cap.* 响应 payload""" success: bool = Field(description="是否成功") + """是否成功""" result: Any = Field(default=None, description="返回值") + """返回值""" -# ─── 健康检查 ────────────────────────────────────────────────────── - - +# ====== 健康检查 ====== class HealthPayload(BaseModel): """plugin.health 响应 payload""" healthy: bool = Field(description="是否健康") + """是否健康""" loaded_plugins: List[str] = Field(default_factory=list, description="已加载的插件列表") - uptime_ms: int = Field(default=0, description="运行时长(ms)") + """已加载的插件列表""" + uptime_ms: int = Field(default=0, description="运行时长 (ms)") + """运行时长 (ms)""" class RunnerReadyPayload(BaseModel): """runner.ready 请求 payload""" loaded_plugins: List[str] = Field(default_factory=list, description="已完成初始化的插件列表") + """已完成初始化的插件列表""" failed_plugins: List[str] = Field(default_factory=list, description="初始化失败的插件列表") + """初始化失败的插件列表""" -# ─── 配置更新 ────────────────────────────────────────────────────── - - -# Host 侧现已支持配置更新推送: -# - 总配置热重载完成后,PluginRuntimeManager 会向已加载插件推送配置更新事件。 -# - 插件目录下的 config.toml 变化由现有 FileWatcher 监听并转发为 plugin.config_updated。 +# ====== 配置更新 ====== class ConfigUpdatedPayload(BaseModel): """plugin.config_updated 事件 payload""" plugin_id: str = Field(description="插件 ID") + """插件 ID""" config_version: str = Field(description="新配置版本") + """新配置版本""" config_data: Dict[str, Any] = Field(default_factory=dict, description="配置内容") + """配置内容""" -# ─── 关停 ────────────────────────────────────────────────────────── - - +# ====== 关停 ====== class ShutdownPayload(BaseModel): """plugin.shutdown / plugin.prepare_shutdown payload""" reason: str = Field(default="normal", description="关停原因") - drain_timeout_ms: int = Field(default=5000, description="排空超时(ms)") + """关停原因""" + drain_timeout_ms: int = Field(default=5000, description="排空超时 (ms)") + """排空超时 (ms)""" -# ─── 日志传输 ────────────────────────────────────────────────────── +# ====== 日志传输 ====== class LogEntry(BaseModel): """单条日志记录(Runner → Host 传输格式)""" - timestamp_ms: int = Field( - description="日志时间戳,Unix epoch 毫秒", - ) - level: int = Field( - description=("stdlib logging 整数级别: 10=DEBUG, 20=INFO, 30=WARNING, 40=ERROR, 50=CRITICAL"), - ) - logger_name: str = Field( - description="Logger 名称,如 plugin.my_plugin.submodule", - ) - message: str = Field( - description="经 Formatter 格式化后的完整日志消息(含 exc_info 文本)", - ) + timestamp_ms: int = Field(description="日志时间戳,Unix epoch 毫秒") + """日志时间戳,Unix epoch 毫秒""" + level: int = Field(description="stdlib logging 整数级别:10=DEBUG, 20=INFO, 30=WARNING, 40=ERROR, 50=CRITICAL") + """stdlib logging 整数级别:10=DEBUG, 20=INFO, 30=WARNING, 40=ERROR, 50=CRITICAL""" + logger_name: str = Field(description="Logger 名称,如 plugin.my_plugin.submodule") + """Logger 名称,如 plugin.my_plugin.submodule""" + message: str = Field(description="经 Formatter 格式化后的完整日志消息(含 exc_info 文本)") + """经 Formatter 格式化后的完整日志消息(含 exc_info 文本)""" exception_text: str = Field( default="", description="原始异常摘要(exc_text),供结构化消费;已嵌入 message 中", ) + """原始异常摘要(exc_text),供结构化消费;已嵌入 message 中""" + log_color_in_hex: Optional[str] = Field(default=None, description="日志颜色的十六进制字符串(如 #RRGGBB)") @property def levelname(self) -> str: @@ -262,6 +283,5 @@ class LogEntry(BaseModel): class LogBatchPayload(BaseModel): """runner.log_batch 事件 payload:Runner 端向 Host 批量推送日志记录""" - entries: List[LogEntry] = Field( - description="本批次日志记录列表,按时间升序排列", - ) + entries: List[LogEntry] = Field(description="本批次日志记录列表,按时间升序排列") + """本批次日志记录列表,按时间升序排列""" From e1b2ecb5b13a01b41e5405e9c805b7add664a4b3 Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Mon, 16 Mar 2026 19:37:58 +0800 Subject: [PATCH 04/45] =?UTF-8?q?fix:=20(AI)=20=E6=9B=B4robust=E7=9A=84?= =?UTF-8?q?=E4=BC=A0=E8=BE=93?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/plugin_runtime/transport/named_pipe.py | 91 ++++++++++++++++++---- src/plugin_runtime/transport/uds.py | 62 +++++++++++++-- 2 files changed, 129 insertions(+), 24 deletions(-) diff --git a/src/plugin_runtime/transport/named_pipe.py b/src/plugin_runtime/transport/named_pipe.py index a759507d..7fd39bc9 100644 --- a/src/plugin_runtime/transport/named_pipe.py +++ b/src/plugin_runtime/transport/named_pipe.py @@ -1,6 +1,9 @@ """Windows Named Pipe 传输实现。 适用于 Windows 平台,使用 asyncio ProactorEventLoop 的 named pipe 支持。 + +注意:Named Pipe 是 Windows 特有的 IPC 机制, +在 Linux/macOS 平台上不可用。Unix-like 平台请使用 UDS 传输。 """ from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, cast @@ -18,10 +21,12 @@ _DEFAULT_PIPE_PREFIX = "maibot-plugin" class _NamedPipeServerHandle(Protocol): + """Named Pipe 服务端句柄的协议定义。""" def close(self) -> None: ... class _NamedPipeEventLoop(Protocol): + """ProactorEventLoop 的协议定义,提供 named pipe 相关方法。""" async def start_serving_pipe( self, protocol_factory: Callable[[], asyncio.BaseProtocol], @@ -40,6 +45,15 @@ class _NamedPipeEventLoop(Protocol): def _normalize_pipe_address(pipe_name: Optional[str] = None) -> str: + """规范化 Named Pipe 地址。 + + Args: + pipe_name: 管道名称。如果以 '\\\\.\\pipe\\' 开头则直接使用, + 否则会自动添加前缀。如果为 None 则生成随机名称。 + + Returns: + 规范化的管道地址(格式:\\\\.\\pipe\\name) + """ if pipe_name and pipe_name.startswith(_PIPE_PREFIX): return pipe_name @@ -55,12 +69,21 @@ def _normalize_pipe_address(pipe_name: Optional[str] = None) -> str: class NamedPipeConnection(Connection): - """基于 Windows Named Pipe 的连接。""" + """基于 Windows Named Pipe 的连接。 + + 封装了底层 StreamReader/StreamWriter,提供分帧读写能力。 + """ - pass + def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None: + super().__init__(reader, writer) class _NamedPipeServerProtocol(asyncio.StreamReaderProtocol): + """Named Pipe 服务端协议实现。 + + 处理客户端连接的生命周期,包括连接建立、数据处理和连接关闭。 + """ + def __init__(self, handler: ConnectionHandler, loop: asyncio.AbstractEventLoop) -> None: self._reader: asyncio.StreamReader = asyncio.StreamReader() super().__init__(self._reader) @@ -69,39 +92,58 @@ class _NamedPipeServerProtocol(asyncio.StreamReaderProtocol): self._handler_task: Optional[asyncio.Task[None]] = None def connection_made(self, transport: asyncio.BaseTransport) -> None: + """连接建立时的回调。""" super().connection_made(transport) writer = asyncio.StreamWriter(cast(asyncio.WriteTransport, transport), self, self._reader, self._loop) connection = NamedPipeConnection(self._reader, writer) - self._handler_task = self._loop.create_task(self._run_handler(connection)) + # 使用 asyncio.create_task 确保任务正确调度 + self._handler_task = asyncio.create_task(self._run_handler(connection)) self._handler_task.add_done_callback(self._on_handler_done) async def _run_handler(self, connection: NamedPipeConnection) -> None: + """运行连接处理器。""" try: await self._handler(connection) finally: await connection.close() def _on_handler_done(self, task: asyncio.Task[None]) -> None: + """连接处理器完成时的回调。""" if task.cancelled(): return if exc := task.exception(): - self._loop.call_exception_handler( - { - "message": "Named pipe 连接处理失败", - "exception": exc, - "protocol": self, - } - ) + try: + self._loop.call_exception_handler( + { + "message": "Named pipe 连接处理失败", + "exception": exc, + "protocol": self, + } + ) + except Exception: + # 如果 loop 已经关闭,忽略异常 + pass class NamedPipeTransportServer(TransportServer): - """Windows Named Pipe 传输服务端。""" + """Windows Named Pipe 传输服务端。 + + 使用 ProactorEventLoop 的 start_serving_pipe 方法监听客户端连接。 + """ def __init__(self, pipe_name: Optional[str] = None) -> None: self._address: str = _normalize_pipe_address(pipe_name) self._servers: List[_NamedPipeServerHandle] = [] async def start(self, handler: ConnectionHandler) -> None: + """启动 Named Pipe 服务端。 + + Args: + handler: 新连接到来时的回调函数 + + Raises: + RuntimeError: 当在非 Windows 平台或事件循环不支持时 + """ if sys.platform != "win32": raise RuntimeError("Named pipe 仅支持 Windows") @@ -116,32 +158,49 @@ class NamedPipeTransportServer(TransportServer): ) async def stop(self) -> None: + """停止 Named Pipe 服务端并清理资源。""" for server in self._servers: server.close() + # 等待所有服务器句柄完全关闭 + await asyncio.gather( + *[asyncio.sleep(0.1) for _ in self._servers], + return_exceptions=True + ) self._servers.clear() - await asyncio.sleep(0) def get_address(self) -> str: return self._address class NamedPipeTransportClient(TransportClient): - """Windows Named Pipe 传输客户端。""" + """Windows Named Pipe 传输客户端。 + + 用于主动连接到 Named Pipe 服务端。 + """ def __init__(self, address: str) -> None: self._address: str = _normalize_pipe_address(address) async def connect(self) -> Connection: + """建立到 Named Pipe 服务端的连接。 + + Returns: + NamedPipeConnection: 连接对象 + + Raises: + NotImplementedError: 当在非 Windows 平台或事件循环不支持时 + """ if sys.platform != "win32": - raise RuntimeError("Named pipe 仅支持 Windows") + raise NotImplementedError("Named pipe 仅支持 Windows") loop = asyncio.get_running_loop() if not hasattr(loop, "create_pipe_connection"): - raise RuntimeError("当前事件循环不支持 Windows named pipe") + raise NotImplementedError("当前事件循环不支持 Windows named pipe") pipe_loop = cast(_NamedPipeEventLoop, loop) reader = asyncio.StreamReader() protocol = asyncio.StreamReaderProtocol(reader) transport, _protocol = await pipe_loop.create_pipe_connection(lambda: protocol, self._address) - writer = asyncio.StreamWriter(cast(asyncio.WriteTransport, transport), protocol, reader, loop) + # 使用返回的 protocol 创建 StreamWriter + writer = asyncio.StreamWriter(cast(asyncio.WriteTransport, transport), _protocol, reader, loop) return NamedPipeConnection(reader, writer) \ No newline at end of file diff --git a/src/plugin_runtime/transport/uds.py b/src/plugin_runtime/transport/uds.py index 47bf033b..af71ea5d 100644 --- a/src/plugin_runtime/transport/uds.py +++ b/src/plugin_runtime/transport/uds.py @@ -1,6 +1,9 @@ """Unix Domain Socket 传输实现 适用于 Linux / macOS 平台。 + +注意:UDS (Unix Domain Socket) 是 Unix-like 系统特有的 IPC 机制, +在 Windows 平台上不可用。Windows 平台请使用 Named Pipe 传输。 """ from pathlib import Path @@ -8,20 +11,30 @@ from typing import Optional import asyncio import os +import sys import tempfile from .base import Connection, ConnectionHandler, TransportClient, TransportServer class UDSConnection(Connection): - """基于 UDS 的连接""" + """基于 UDS 的连接 + + 封装了底层 StreamReader/StreamWriter,提供分帧读写能力。 + """ - pass # 直接复用 Connection 基类的分帧读写 + def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None: + super().__init__(reader, writer) # Unix domain socket 路径的系统限制(sun_path 字段长度) -# Linux: 108 字节, macOS: 104 字节 -_UDS_PATH_MAX = 104 +# Linux: 108 字节,macOS: 104 字节,其他 Unix: 通常 104 字节 +if sys.platform == "linux": + _UDS_PATH_MAX = 108 +elif sys.platform == "darwin": # macOS + _UDS_PATH_MAX = 104 +else: + _UDS_PATH_MAX = 104 # 保守默认值 class UDSTransportServer(TransportServer): @@ -44,6 +57,18 @@ class UDSTransportServer(TransportServer): self._server: Optional[asyncio.AbstractServer] = None async def start(self, handler: ConnectionHandler) -> None: + """启动 UDS 服务端 + + Args: + handler: 新连接到来时的回调函数 + + Raises: + RuntimeError: 当在非 Unix 平台(如 Windows)上调用时 + """ + # 平台检查:UDS 仅在 Unix-like 系统上可用 + if sys.platform == "win32": + raise RuntimeError("UDS 不支持 Windows 平台,请使用 Named Pipe") + # 清理残留 socket 文件 if self._socket_path.exists(): self._socket_path.unlink() @@ -58,10 +83,16 @@ class UDSTransportServer(TransportServer): finally: await conn.close() - self._server = await asyncio.start_unix_server(_on_connect, path=str(self._socket_path)) + try: + self._server = await asyncio.start_unix_server(_on_connect, path=str(self._socket_path)) - # 设置文件权限为仅当前用户可访问 - self._socket_path.chmod(0o600) + # 设置文件权限为仅当前用户可访问 + self._socket_path.chmod(0o600) + except Exception: + # 启动失败时清理可能创建的目录和 socket 文件 + if self._socket_path.exists(): + self._socket_path.unlink() + raise async def stop(self) -> None: if self._server: @@ -77,11 +108,26 @@ class UDSTransportServer(TransportServer): class UDSTransportClient(TransportClient): - """UDS 传输客户端""" + """UDS 传输客户端 + + 用于主动连接到 UDS 服务端。 + """ def __init__(self, socket_path: Path) -> None: self._socket_path: Path = socket_path async def connect(self) -> Connection: + """建立到 UDS 服务端的连接 + + Returns: + UDSConnection: 连接对象 + + Raises: + RuntimeError: 当在非 Unix 平台(如 Windows)上调用时 + """ + # 平台检查:UDS 仅在 Unix-like 系统上可用 + if sys.platform == "win32": + raise RuntimeError("UDS 不支持 Windows 平台,请使用 Named Pipe") + reader, writer = await asyncio.open_unix_connection(str(self._socket_path)) return UDSConnection(reader, writer) From 49b620219de2e1333e1107aead43deda9e33d0dc Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Tue, 17 Mar 2026 01:30:31 +0800 Subject: [PATCH 05/45] =?UTF-8?q?refcator:=20=E9=87=8D=E5=91=BD=E5=90=8Dpo?= =?UTF-8?q?licy=E4=B8=BAauthorization;=E7=A7=BB=E9=99=A4envelope=E7=9A=84g?= =?UTF-8?q?eneration(runner=E4=B8=8D=E5=86=8D=E9=87=8D=E8=BD=BD);?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/plugin_runtime/host/authorization.py | 61 ++++++++++++ src/plugin_runtime/host/capability_service.py | 14 +-- src/plugin_runtime/host/policy_engine.py | 97 ------------------- src/plugin_runtime/protocol/envelope.py | 5 - 4 files changed, 69 insertions(+), 108 deletions(-) create mode 100644 src/plugin_runtime/host/authorization.py delete mode 100644 src/plugin_runtime/host/policy_engine.py diff --git a/src/plugin_runtime/host/authorization.py b/src/plugin_runtime/host/authorization.py new file mode 100644 index 00000000..d746c4d2 --- /dev/null +++ b/src/plugin_runtime/host/authorization.py @@ -0,0 +1,61 @@ +"""授权管理器 + +负责管理插件的能力授权以及校验 +每个插件在 manifest 中声明能力需求,Host 启动时签发能力令牌。 +""" + +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Set, Tuple + + +@dataclass +class CapabilityPermissionToken: + """能力令牌""" + + plugin_id: str + capabilities: Set[str] = field(default_factory=set) + + +class AuthorizationManager: + """授权管理器 + + 管理所有插件的能力令牌,提供授权校验。 + """ + + def __init__(self) -> None: + self._permission_tokens: Dict[str, CapabilityPermissionToken] = {} + + def register_plugin(self, plugin_id: str, capabilities: List[str]) -> CapabilityPermissionToken: + """为插件签发能力令牌""" + token = CapabilityPermissionToken(plugin_id=plugin_id, capabilities=set(capabilities)) + self._permission_tokens[plugin_id] = token + return token + + def revoke_permission_token(self, plugin_id: str): + """移除插件的能力令牌。""" + self._permission_tokens.pop(plugin_id, None) + + def clear(self) -> None: + """清空所有能力令牌。""" + self._permission_tokens.clear() + + def check_capability(self, plugin_id: str, capability: str) -> Tuple[bool, str]: + """检查插件是否有权调用某项能力 + + Returns: + return (bool, str): (是否有此能力, 原因) + """ + token = self._permission_tokens.get(plugin_id) + if not token: + return False, f"插件 {plugin_id} 未注册能力令牌" + if capability not in token.capabilities: + return False, f"插件 {plugin_id} 未获授权能力: {capability}" + return True, "" + + def get_token(self, plugin_id: str) -> Optional[CapabilityPermissionToken]: + """获取插件的能力令牌""" + return self._permission_tokens.get(plugin_id) + + def list_plugins(self) -> List[str]: + """列出所有已注册的插件""" + return list(self._permission_tokens.keys()) diff --git a/src/plugin_runtime/host/capability_service.py b/src/plugin_runtime/host/capability_service.py index 6685ff60..e0c56c2b 100644 --- a/src/plugin_runtime/host/capability_service.py +++ b/src/plugin_runtime/host/capability_service.py @@ -4,10 +4,9 @@ Host 端实现的能力服务,处理来自插件的 cap.* 请求。 每个能力方法被注册到 RPC Server,接收 Runner 转发的请求并执行实际操作。 """ -from typing import Any, Awaitable, Callable, Dict, List +from typing import Any, Callable, Dict, List, Coroutine, TYPE_CHECKING from src.common.logger import get_logger -from src.plugin_runtime.host.policy_engine import PolicyEngine from src.plugin_runtime.protocol.envelope import ( CapabilityRequestPayload, CapabilityResponsePayload, @@ -15,10 +14,13 @@ from src.plugin_runtime.protocol.envelope import ( ) from src.plugin_runtime.protocol.errors import ErrorCode, RPCError +if TYPE_CHECKING: + from src.plugin_runtime.host.authorization import AuthorizationManager + logger = get_logger("plugin_runtime.host.capability_service") # 能力实现函数类型: (plugin_id, capability, args) -> result -CapabilityImpl = Callable[[str, str, Dict[str, Any]], Awaitable[Any]] +CapabilityImpl = Callable[[str, str, Dict[str, Any]], Coroutine[Any, Any, Any]] class CapabilityService: @@ -31,8 +33,8 @@ class CapabilityService: 4. 执行实际操作并返回结果 """ - def __init__(self, policy_engine: PolicyEngine) -> None: - self._policy = policy_engine + def __init__(self, authorization: "AuthorizationManager") -> None: + self._authorization = authorization # capability_name -> implementation self._implementations: Dict[str, CapabilityImpl] = {} @@ -65,7 +67,7 @@ class CapabilityService: capability = req.capability # 1. 权限校验 - allowed, reason = self._policy.check_capability(plugin_id, capability, envelope.generation) + allowed, reason = self._authorization.check_capability(plugin_id, capability) if not allowed: error_code = ( ErrorCode.E_GENERATION_MISMATCH if "generation 不匹配" in reason else ErrorCode.E_CAPABILITY_DENIED diff --git a/src/plugin_runtime/host/policy_engine.py b/src/plugin_runtime/host/policy_engine.py deleted file mode 100644 index 61b32480..00000000 --- a/src/plugin_runtime/host/policy_engine.py +++ /dev/null @@ -1,97 +0,0 @@ -"""策略引擎 - -负责能力授权校验。 -每个插件在 manifest 中声明能力需求,Host 启动时签发能力令牌。 -""" - -from dataclasses import dataclass, field -from typing import Dict, List, Optional, Set, Tuple - - -@dataclass -class CapabilityToken: - """能力令牌""" - - plugin_id: str - generation: int - capabilities: Set[str] = field(default_factory=set) - - -class PolicyEngine: - """策略引擎 - - 管理所有插件的能力令牌,提供授权校验。 - """ - - def __init__(self) -> None: - self._tokens: Dict[str, Dict[int, CapabilityToken]] = {} - - def register_plugin( - self, - plugin_id: str, - generation: int, - capabilities: List[str], - ) -> CapabilityToken: - """为插件签发能力令牌""" - token = CapabilityToken( - plugin_id=plugin_id, - generation=generation, - capabilities=set(capabilities), - ) - self._tokens.setdefault(plugin_id, {})[generation] = token - return token - - def revoke_plugin(self, plugin_id: str, generation: Optional[int] = None) -> None: - """撤销插件的能力令牌。""" - if generation is None: - self._tokens.pop(plugin_id, None) - return - - generations = self._tokens.get(plugin_id) - if generations is None: - return - - generations.pop(generation, None) - if not generations: - self._tokens.pop(plugin_id, None) - - def clear(self) -> None: - """清空所有能力令牌。""" - self._tokens.clear() - - def check_capability(self, plugin_id: str, capability: str, generation: Optional[int] = None) -> Tuple[bool, str]: - """检查插件是否有权调用某项能力 - - Returns: - (allowed, reason) - """ - generations = self._tokens.get(plugin_id) - if not generations: - return False, f"插件 {plugin_id} 未注册能力令牌" - - if generation is None: - token = generations[max(generations)] - else: - token = generations.get(generation) - if token is None: - active_generation = max(generations) - return False, f"插件 {plugin_id} generation 不匹配: {generation} != {active_generation}" - - if capability not in token.capabilities: - return False, f"插件 {plugin_id} 未获授权能力: {capability}" - - if generation is not None and token.generation != generation: - return False, f"插件 {plugin_id} generation 不匹配: {generation} != {token.generation}" - - return True, "" - - def get_token(self, plugin_id: str) -> Optional[CapabilityToken]: - """获取插件的能力令牌""" - generations = self._tokens.get(plugin_id) - if not generations: - return None - return generations[max(generations)] - - def list_plugins(self) -> List[str]: - """列出所有已注册的插件""" - return list(self._tokens.keys()) diff --git a/src/plugin_runtime/protocol/envelope.py b/src/plugin_runtime/protocol/envelope.py index 56bb4582..ca9e8005 100644 --- a/src/plugin_runtime/protocol/envelope.py +++ b/src/plugin_runtime/protocol/envelope.py @@ -65,8 +65,6 @@ class Envelope(BaseModel): """发送时间戳 (ms)""" timeout_ms: int = Field(default=30000, description="相对超时 (ms)") """相对超时 (ms)""" - generation: int = Field(default=0, description="Runner generation 编号") - """Runner generation 编号""" payload: Dict[str, Any] = Field(default_factory=dict, description="业务数据") """业务数据""" error: Optional[Dict[str, Any]] = Field(default=None, description="错误信息 (仅 response)") @@ -91,7 +89,6 @@ class Envelope(BaseModel): message_type=MessageType.RESPONSE, method=self.method, plugin_id=self.plugin_id, - generation=self.generation, payload=payload or {}, error=error, ) @@ -126,8 +123,6 @@ class HelloResponsePayload(BaseModel): """是否接受连接""" host_version: str = Field(default="", description="Host 版本号") """Host 版本号""" - assigned_generation: int = Field(default=0, description="分配的 generation 编号") - """分配的 generation 编号""" reason: str = Field(default="", description="拒绝原因 (若 accepted=False)") """拒绝原因 (若 `accepted`=`False`)""" From 84a6524bd9c8ed66eca35801b0d952b552194148 Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Tue, 17 Mar 2026 20:00:19 +0800 Subject: [PATCH 06/45] =?UTF-8?q?refactor:=20=E7=A7=BB=E9=99=A4generation;?= =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E6=96=B0=E7=9A=84ErrorCode;=E4=BF=AE?= =?UTF-8?q?=E6=94=B9ErrorCode=E7=9A=84=E4=B8=80=E4=B8=AA=E5=90=8D=E7=A7=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/plugin_runtime/host/authorization.py | 1 + src/plugin_runtime/host/capability_service.py | 29 +- src/plugin_runtime/host/rpc_server.py | 444 +++++------------- src/plugin_runtime/protocol/errors.py | 16 +- 4 files changed, 138 insertions(+), 352 deletions(-) diff --git a/src/plugin_runtime/host/authorization.py b/src/plugin_runtime/host/authorization.py index d746c4d2..3fb48c6a 100644 --- a/src/plugin_runtime/host/authorization.py +++ b/src/plugin_runtime/host/authorization.py @@ -40,6 +40,7 @@ class AuthorizationManager: self._permission_tokens.clear() def check_capability(self, plugin_id: str, capability: str) -> Tuple[bool, str]: + # sourcery skip: assign-if-exp, reintroduce-else, swap-if-else-branches, use-named-expression """检查插件是否有权调用某项能力 Returns: diff --git a/src/plugin_runtime/host/capability_service.py b/src/plugin_runtime/host/capability_service.py index e0c56c2b..98366a07 100644 --- a/src/plugin_runtime/host/capability_service.py +++ b/src/plugin_runtime/host/capability_service.py @@ -7,11 +7,7 @@ Host 端实现的能力服务,处理来自插件的 cap.* 请求。 from typing import Any, Callable, Dict, List, Coroutine, TYPE_CHECKING from src.common.logger import get_logger -from src.plugin_runtime.protocol.envelope import ( - CapabilityRequestPayload, - CapabilityResponsePayload, - Envelope, -) +from src.plugin_runtime.protocol.envelope import CapabilityRequestPayload, CapabilityResponsePayload, Envelope from src.plugin_runtime.protocol.errors import ErrorCode, RPCError if TYPE_CHECKING: @@ -59,31 +55,19 @@ class CapabilityService: try: req = CapabilityRequestPayload.model_validate(envelope.payload) except Exception as e: - return envelope.make_error_response( - ErrorCode.E_BAD_PAYLOAD.value, - f"能力调用 payload 格式错误: {e}", - ) + return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, f"能力调用 payload 格式错误: {e}") capability = req.capability # 1. 权限校验 allowed, reason = self._authorization.check_capability(plugin_id, capability) if not allowed: - error_code = ( - ErrorCode.E_GENERATION_MISMATCH if "generation 不匹配" in reason else ErrorCode.E_CAPABILITY_DENIED - ) - return envelope.make_error_response( - error_code.value, - reason, - ) + return envelope.make_error_response(ErrorCode.E_CAPABILITY_DENIED.value, reason) # 2. 查找实现 impl = self._implementations.get(capability) if impl is None: - return envelope.make_error_response( - ErrorCode.E_METHOD_NOT_ALLOWED.value, - f"未注册的能力: {capability}", - ) + return envelope.make_error_response(ErrorCode.E_METHOD_NOT_ALLOWED.value, f"未注册的能力: {capability}") # 3. 执行 try: @@ -94,10 +78,7 @@ class CapabilityService: return envelope.make_error_response(e.code.value, e.message, e.details) except Exception as e: logger.error(f"能力 {capability} 执行异常: {e}", exc_info=True) - return envelope.make_error_response( - ErrorCode.E_CAPABILITY_FAILED.value, - str(e), - ) + return envelope.make_error_response(ErrorCode.E_CAPABILITY_FAILED.value, str(e)) def list_capabilities(self) -> List[str]: """列出所有已注册的能力""" diff --git a/src/plugin_runtime/host/rpc_server.py b/src/plugin_runtime/host/rpc_server.py index 79fe0d9a..75ef9b2a 100644 --- a/src/plugin_runtime/host/rpc_server.py +++ b/src/plugin_runtime/host/rpc_server.py @@ -7,7 +7,7 @@ 4. 请求-响应关联与超时管理 """ -from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple, Coroutine import asyncio import contextlib @@ -32,7 +32,7 @@ from src.plugin_runtime.transport.base import Connection, TransportServer logger = get_logger("plugin_runtime.host.rpc_server") # RPC 方法处理器类型 -MethodHandler = Callable[[Envelope], Awaitable[Envelope]] +MethodHandler = Callable[[Envelope], Coroutine[Any, Any, Envelope]] class RPCServer: @@ -55,109 +55,29 @@ class RPCServer: self._id_gen = RequestIdGenerator() self._connection: Optional[Connection] = None # 当前活跃的 Runner 连接 - self._runner_id: Optional[str] = None - self._runner_generation: int = 0 - self._staged_connection: Optional[Connection] = None - self._staged_runner_id: Optional[str] = None - self._staged_runner_generation: int = 0 - self._staging_takeover: bool = False # 方法处理器注册表 self._method_handlers: Dict[str, MethodHandler] = {} - # 等待响应的 pending 请求: request_id -> (Future, target_generation) - self._pending_requests: Dict[int, Tuple[asyncio.Future, int]] = {} + # 等待响应的 pending 请求: request_id -> Future + self._pending_requests: Dict[int, asyncio.Future[Envelope]] = {} # 发送队列(背压控制) self._send_queue: Optional[asyncio.Queue[Tuple[Connection, bytes, asyncio.Future[None]]]] = None - self._send_worker_task: Optional[asyncio.Task] = None + self._send_worker_task: Optional[asyncio.Task[None]] = None # 运行状态 self._running: bool = False - self._tasks: List[asyncio.Task] = [] + self._tasks: List[asyncio.Task[None]] = [] @property def session_token(self) -> str: return self._session_token - def reset_session_token(self) -> str: - """重新生成会话令牌(热重载时调用,防止旧 Runner 重连)""" - self._session_token = secrets.token_hex(32) - return self._session_token - - def restore_session_token(self, token: str) -> None: - """恢复指定的会话令牌(热重载回滚时调用)""" - self._session_token = token - - @property - def runner_generation(self) -> int: - return self._runner_generation - - @property - def staged_generation(self) -> int: - return self._staged_runner_generation - @property def is_connected(self) -> bool: return self._connection is not None and not self._connection.is_closed - def has_generation(self, generation: int) -> bool: - return generation == self._runner_generation or ( - self._staged_connection is not None - and not self._staged_connection.is_closed - and generation == self._staged_runner_generation - ) - - def begin_staged_takeover(self) -> None: - """允许新 Runner 以 staged 方式接入,待 Supervisor 验证后再切换为活跃连接。""" - self._staging_takeover = True - - async def commit_staged_takeover(self) -> None: - """提交 staged Runner,原活跃连接在提交后被关闭。""" - if self._staged_connection is None or self._staged_connection.is_closed: - raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "没有可提交的新 Runner 连接") - - old_connection = self._connection - old_generation = self._runner_generation - - self._connection = self._staged_connection - self._runner_id = self._staged_runner_id - self._runner_generation = self._staged_runner_generation - - self._staged_connection = None - self._staged_runner_id = None - self._staged_runner_generation = 0 - self._staging_takeover = False - - if stale_count := self._fail_pending_requests( - ErrorCode.E_PLUGIN_CRASHED, - "Runner 连接已被新 generation 接管", - generation=old_generation, - ): - logger.info(f"已清理 {stale_count} 个旧 Runner 的 pending 请求") - - if old_connection and old_connection is not self._connection and not old_connection.is_closed: - await old_connection.close() - - async def rollback_staged_takeover(self) -> None: - """放弃 staged Runner,保留当前活跃连接。""" - staged_connection = self._staged_connection - staged_generation = self._staged_runner_generation - - self._staged_connection = None - self._staged_runner_id = None - self._staged_runner_generation = 0 - self._staging_takeover = False - - self._fail_pending_requests( - ErrorCode.E_PLUGIN_CRASHED, - "新 Runner 预热失败,已回滚", - generation=staged_generation, - ) - - if staged_connection and not staged_connection.is_closed: - await staged_connection.close() - def register_method(self, method: str, handler: MethodHandler) -> None: """注册 RPC 方法处理器""" self._method_handlers[method] = handler @@ -173,14 +93,8 @@ class RPCServer: async def stop(self) -> None: """停止 RPC 服务器""" self._running = False - - # 取消所有 pending 请求 - for future, _generation in self._pending_requests.values(): - if not future.done(): - future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "服务器关闭")) - self._pending_requests.clear() - - self._fail_queued_sends(ErrorCode.E_TIMEOUT, "服务器关闭") + self._fail_pending_requests(ErrorCode.E_SHUTTING_DOWN, "服务器正在关闭") + self._fail_queued_sends(ErrorCode.E_SHUTTING_DOWN, "服务器正在关闭") if self._send_worker_task: self._send_worker_task.cancel() @@ -198,10 +112,6 @@ class RPCServer: await self._connection.close() self._connection = None - if self._staged_connection: - await self._staged_connection.close() - self._staged_connection = None - await self._transport.stop() logger.info("RPC Server 已停止") @@ -211,7 +121,6 @@ class RPCServer: plugin_id: str = "", payload: Optional[Dict[str, Any]] = None, timeout_ms: int = 30000, - target_generation: Optional[int] = None, ) -> Envelope: """向 Runner 发送 RPC 请求并等待响应 @@ -227,18 +136,14 @@ class RPCServer: Raises: RPCError: 调用失败 """ - generation = target_generation or self._runner_generation - conn = self._get_connection_for_generation(generation) - if conn is None or conn.is_closed: + if not self._connection or self._connection.is_closed: raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接") - - request_id = self._id_gen.next() + request_id = await self._id_gen.next() envelope = Envelope( request_id=request_id, message_type=MessageType.REQUEST, method=method, plugin_id=plugin_id, - generation=generation, timeout_ms=timeout_ms, payload=payload or {}, ) @@ -246,12 +151,12 @@ class RPCServer: # 注册 pending future loop = asyncio.get_running_loop() future: asyncio.Future[Envelope] = loop.create_future() - self._pending_requests[request_id] = (future, generation) + self._pending_requests[request_id] = future try: # 发送请求 data = self._codec.encode_envelope(envelope) - await self._enqueue_send(conn, data) + await self._enqueue_send(self._connection, data) # 等待响应 timeout_sec = timeout_ms / 1000.0 @@ -265,93 +170,66 @@ class RPCServer: raise raise RPCError(ErrorCode.E_UNKNOWN, str(e)) from e - async def send_event(self, method: str, plugin_id: str = "", payload: Optional[Dict[str, Any]] = None) -> None: - """向 Runner 发送单向事件(不等待响应)""" - conn = self._connection - if conn is None or conn.is_closed: - return + # ============ 内部方法 ============ + # ========= 发送循环 ========= + async def _send_loop(self) -> None: + """后台发送循环:串行消费发送队列,统一执行连接写入。""" + if self._send_queue is None: + raise RuntimeError("没有消息队列") - request_id = self._id_gen.next() - envelope = Envelope( - request_id=request_id, - message_type=MessageType.EVENT, - method=method, - plugin_id=plugin_id, - generation=self._runner_generation, - payload=payload or {}, - ) - data = self._codec.encode_envelope(envelope) - await self._enqueue_send(conn, data) + while True: + try: + conn, data, send_future = await self._send_queue.get() + except asyncio.CancelledError: + break - # ─── 内部方法 ────────────────────────────────────────────── + try: + if conn.is_closed: + raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接") + await conn.send_frame(data) + if not send_future.done(): + send_future.set_result(None) + except asyncio.CancelledError: + if not send_future.done(): + send_future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "服务器关闭")) + raise + except Exception as e: + send_error = RPCError.from_exception(e, {ConnectionError: ErrorCode.E_PLUGIN_CRASHED}) + if not send_future.done(): + send_future.set_exception(send_error) + finally: + self._send_queue.task_done() + # ====== 发送循环方法 ====== async def _handle_connection(self, conn: Connection) -> None: """处理新的 Runner 连接""" logger.info("收到 Runner 连接") - previous_connection = self._connection - previous_generation = self._runner_generation - # 第一条消息必须是 runner.hello 握手 try: - role = await self._handle_handshake(conn) - if role is None: + success = await self._handle_handshake(conn) + if not success: await conn.close() return except Exception as e: logger.error(f"握手失败: {e}") await conn.close() return - - if role == "staged": - expected_generation = self._staged_runner_generation - logger.info( - f"Runner staged 握手成功: runner_id={self._staged_runner_id}, generation={self._staged_runner_generation}" - ) - else: - self._connection = conn - expected_generation = self._runner_generation - logger.info(f"Runner 握手成功: runner_id={self._runner_id}, generation={self._runner_generation}") - - if previous_connection and previous_connection is not conn and not previous_connection.is_closed: - logger.info("检测到新 Runner 已接管连接,关闭旧连接") - if stale_count := self._fail_pending_requests( - ErrorCode.E_PLUGIN_CRASHED, - "Runner 连接已被新 generation 接管", - generation=previous_generation, - ): - logger.info(f"已清理 {stale_count} 个旧 Runner 的 pending 请求") - await previous_connection.close() - + logger.info("Runner staged 握手成功") + self._connection = conn # 启动消息接收循环 try: - await self._recv_loop(conn, expected_generation=expected_generation) + await self._recv_loop(conn) except Exception as e: logger.error(f"连接异常断开: {e}") finally: - if self._connection is conn: - self._connection = None - self._runner_id = None - self._fail_pending_requests( - ErrorCode.E_PLUGIN_CRASHED, - "Runner 连接已断开", - generation=expected_generation, - ) - elif self._staged_connection is conn: - self._staged_connection = None - self._staged_runner_id = None - self._staged_runner_generation = 0 - self._fail_pending_requests( - ErrorCode.E_PLUGIN_CRASHED, - "Staged Runner 连接已断开", - generation=expected_generation, - ) + self._connection = None + self._fail_pending_requests(ErrorCode.E_PLUGIN_CRASHED, "Runner 连接已断开") - async def _handle_handshake(self, conn: Connection) -> Optional[str]: + async def _handle_handshake(self, conn: Connection) -> bool: """处理 runner.hello 握手""" # 接收握手请求 data = await asyncio.wait_for(conn.recv_frame(), timeout=10.0) envelope = self._codec.decode_envelope(data) - if envelope.method != "runner.hello": logger.error(f"期望 runner.hello,收到 {envelope.method}") error_resp = envelope.make_error_response( @@ -359,21 +237,17 @@ class RPCServer: "首条消息必须为 runner.hello", ) await conn.send_frame(self._codec.encode_envelope(error_resp)) - return None + return False # 解析握手 payload hello = HelloPayload.model_validate(envelope.payload) - # 校验会话令牌 if hello.session_token != self._session_token: logger.error("会话令牌不匹配") - resp_payload = HelloResponsePayload( - accepted=False, - reason="会话令牌无效", - ) + resp_payload = HelloResponsePayload(accepted=False, reason="会话令牌无效") resp = envelope.make_response(payload=resp_payload.model_dump()) await conn.send_frame(self._codec.encode_envelope(resp)) - return None + return False # 校验 SDK 版本 if not self._check_sdk_version(hello.sdk_version): @@ -384,31 +258,26 @@ class RPCServer: ) resp = envelope.make_response(payload=resp_payload.model_dump()) await conn.send_frame(self._codec.encode_envelope(resp)) - return None + return False - # 握手成功 - role = "active" - assigned_generation = self._runner_generation + 1 - if self._staging_takeover and self.is_connected: - role = "staged" - self._staged_connection = conn - self._staged_runner_id = hello.runner_id - self._staged_runner_generation = assigned_generation - else: - self._runner_id = hello.runner_id - self._runner_generation = assigned_generation - - resp_payload = HelloResponsePayload( - accepted=True, - host_version=PROTOCOL_VERSION, - assigned_generation=assigned_generation, - ) + # 发送响应 + resp_payload = HelloResponsePayload(accepted=True, host_version=PROTOCOL_VERSION) resp = envelope.make_response(payload=resp_payload.model_dump()) await conn.send_frame(self._codec.encode_envelope(resp)) + return True - return role + def _check_sdk_version(self, sdk_version: str) -> bool: + """检查 SDK 版本是否在支持范围内""" + try: + sdk_parts = _parse_version_tuple(sdk_version) + min_parts = _parse_version_tuple(MIN_SDK_VERSION) + max_parts = _parse_version_tuple(MAX_SDK_VERSION) + return min_parts <= sdk_parts <= max_parts + except (ValueError, AttributeError): + return False - async def _recv_loop(self, conn: Connection, expected_generation: int) -> None: + # ========= 接收循环 ========= + async def _recv_loop(self, conn: Connection) -> None: """消息接收主循环""" while self._running and not conn.is_closed: try: @@ -430,109 +299,40 @@ class RPCServer: if envelope.is_response(): self._handle_response(envelope) elif envelope.is_request(): - if envelope.generation != expected_generation: - error_resp = envelope.make_error_response( - ErrorCode.E_GENERATION_MISMATCH.value, - f"过期 generation: {envelope.generation} != {expected_generation}", - ) - await conn.send_frame(self._codec.encode_envelope(error_resp)) - continue # 异步处理请求(Runner 发来的能力调用) task = asyncio.create_task(self._handle_request(envelope, conn)) self._tasks.append(task) task.add_done_callback(lambda t: self._tasks.remove(t) if t in self._tasks else None) - elif envelope.is_event(): - if envelope.generation != expected_generation: - logger.warning( - f"忽略过期 generation 事件 {envelope.method}: {envelope.generation} != {expected_generation}" - ) - continue - task = asyncio.create_task(self._handle_event(envelope)) + elif envelope.is_broadcast(): + task = asyncio.create_task(self._handle_broadcast(envelope)) self._tasks.append(task) task.add_done_callback(lambda t: self._tasks.remove(t) if t in self._tasks else None) + else: + logger.warning(f"未知的消息类型: {envelope.message_type}") + continue + # ====== 接收循环内部方法 ====== def _handle_response(self, envelope: Envelope) -> None: """处理来自 Runner 的响应""" - pending = self._pending_requests.get(envelope.request_id) - if pending is None: + pending_future = self._pending_requests.pop(envelope.request_id, None) + if pending_future is None: return - - future, expected_generation = pending - if envelope.generation != expected_generation: - logger.warning( - f"忽略过期 generation 响应 {envelope.method}: {envelope.generation} != {expected_generation}" - ) - return - - self._pending_requests.pop(envelope.request_id, None) - if not future.done(): + if not pending_future.done(): if envelope.error: - future.set_exception(RPCError.from_dict(envelope.error)) + pending_future.set_exception(RPCError.from_dict(envelope.error)) else: - future.set_result(envelope) - - async def _enqueue_send(self, conn: Connection, data: bytes) -> None: - """通过发送队列串行发送消息,提供真实背压。""" - if conn.is_closed: - raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接") - - if self._send_queue is None: - await conn.send_frame(data) - return - - loop = asyncio.get_running_loop() - send_future: asyncio.Future[None] = loop.create_future() - - try: - self._send_queue.put_nowait((conn, data, send_future)) - except asyncio.QueueFull: - raise RPCError(ErrorCode.E_BACKPRESSURE, "发送队列已满") from None - - await send_future - - async def _send_loop(self) -> None: - """后台发送循环:串行消费发送队列,统一执行连接写入。""" - if self._send_queue is None: - return - - while True: - try: - conn, data, send_future = await self._send_queue.get() - except asyncio.CancelledError: - break - - try: - if conn.is_closed: - raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接") - await conn.send_frame(data) - if not send_future.done(): - send_future.set_result(None) - except asyncio.CancelledError: - if not send_future.done(): - send_future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "服务器关闭")) - raise - except Exception as e: - send_error = e if isinstance(e, RPCError) else self._normalize_send_exception(e) - if not send_future.done(): - send_future.set_exception(send_error) - finally: - self._send_queue.task_done() - - @staticmethod - def _normalize_send_exception(error: Exception) -> RPCError: - if isinstance(error, ConnectionError): - return RPCError(ErrorCode.E_PLUGIN_CRASHED, str(error)) - return RPCError(ErrorCode.E_UNKNOWN, str(error)) + pending_future.set_result(envelope) async def _handle_request(self, envelope: Envelope, conn: Connection) -> None: """处理来自 Runner 的请求(通常是能力调用 cap.*)""" - handler = self._method_handlers.get(envelope.method) - if handler is None: - error_resp = envelope.make_error_response( + target_method = envelope.method + handler = self._method_handlers.get(target_method) + if not handler: + error_response = envelope.make_error_response( ErrorCode.E_METHOD_NOT_ALLOWED.value, f"未注册的方法: {envelope.method}", ) - await conn.send_frame(self._codec.encode_envelope(error_resp)) + await conn.send_frame(self._codec.encode_envelope(error_response)) return try: @@ -546,59 +346,25 @@ class RPCServer: error_resp = envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(e)) await conn.send_frame(self._codec.encode_envelope(error_resp)) - async def _handle_event(self, envelope: Envelope) -> None: - """处理来自 Runner 的事件""" + async def _handle_broadcast(self, envelope: Envelope) -> None: if handler := self._method_handlers.get(envelope.method): try: result = await handler(envelope) # 检查 handler 返回的信封是否包含错误信息 - if result is not None and isinstance(result, Envelope) and result.error: + if result.error: logger.warning(f"事件 {envelope.method} handler 返回错误: {result.error.get('message', '')}") except Exception as e: logger.error(f"处理事件 {envelope.method} 异常: {e}", exc_info=True) - @staticmethod - def _check_sdk_version(sdk_version: str) -> bool: - """检查 SDK 版本是否在支持范围内""" - try: - sdk_parts = RPCServer._parse_version_tuple(sdk_version) - min_parts = RPCServer._parse_version_tuple(MIN_SDK_VERSION) - max_parts = RPCServer._parse_version_tuple(MAX_SDK_VERSION) - return min_parts <= sdk_parts <= max_parts - except (ValueError, AttributeError): - return False - - @staticmethod - def _parse_version_tuple(version: str) -> Tuple[int, int, int]: - base_version = re.split(r"[-.](?:snapshot|dev|alpha|beta|rc)", version or "", flags=re.IGNORECASE)[0] - base_version = base_version.split("+", 1)[0] - parts = [part for part in base_version.split(".") if part != ""] - while len(parts) < 3: - parts.append("0") - return (int(parts[0]), int(parts[1]), int(parts[2])) - - def _get_connection_for_generation(self, generation: int) -> Optional[Connection]: - if generation == self._runner_generation: - return self._connection - if generation == self._staged_runner_generation: - return self._staged_connection - return None - - def _fail_pending_requests( - self, - error_code: ErrorCode, - message: str, - generation: Optional[int] = None, - ) -> int: - stale_count = 0 - for request_id, (future, request_generation) in list(self._pending_requests.items()): - if generation is not None and request_generation != generation: - continue + def _fail_pending_requests(self, error_code: ErrorCode, message: str) -> int: + """失败所有等待中的请求(如连接断开时)""" + aborted_request_count = 0 + for future in self._pending_requests.values(): if not future.done(): future.set_exception(RPCError(error_code, message)) - stale_count += 1 - self._pending_requests.pop(request_id, None) - return stale_count + aborted_request_count += 1 + self._pending_requests.clear() + return aborted_request_count def _fail_queued_sends(self, error_code: ErrorCode, message: str) -> int: if self._send_queue is None: @@ -617,3 +383,31 @@ class RPCServer: self._send_queue.task_done() return failed_count + + async def _enqueue_send(self, conn: Connection, data: bytes) -> None: + """通过发送队列串行发送消息,提供真实背压。""" + if conn.is_closed: + raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接") + + if self._send_queue is None: + await conn.send_frame(data) + return + + loop = asyncio.get_running_loop() + send_future: asyncio.Future[None] = loop.create_future() + + try: + self._send_queue.put_nowait((conn, data, send_future)) + except asyncio.QueueFull: + raise RPCError(ErrorCode.E_BACK_PRESSURE, "发送队列已满") from None + + await send_future + + +def _parse_version_tuple(version: str) -> Tuple[int, int, int]: + base_version = re.split(r"[-.](?:snapshot|dev|alpha|beta|rc)", version or "", flags=re.IGNORECASE)[0] + base_version = base_version.split("+", 1)[0] + parts = [part for part in base_version.split(".") if part != ""] + while len(parts) < 3: + parts.append("0") + return (int(parts[0]), int(parts[1]), int(parts[2])) diff --git a/src/plugin_runtime/protocol/errors.py b/src/plugin_runtime/protocol/errors.py index dcae6b8f..ed19760d 100644 --- a/src/plugin_runtime/protocol/errors.py +++ b/src/plugin_runtime/protocol/errors.py @@ -7,7 +7,7 @@ from enum import Enum from typing import Any, Dict, Optional -class ErrorCode(str, Enum): +class ErrorCode(Enum): """RPC 错误码枚举""" # 通用 @@ -18,17 +18,17 @@ class ErrorCode(str, Enum): E_TIMEOUT = "E_TIMEOUT" E_BAD_PAYLOAD = "E_BAD_PAYLOAD" E_PROTOCOL_MISMATCH = "E_PROTOCOL_MISMATCH" + E_SHUTTING_DOWN = "E_SHUTTING_DOWN" # 权限与策略 E_UNAUTHORIZED = "E_UNAUTHORIZED" E_METHOD_NOT_ALLOWED = "E_METHOD_NOT_ALLOWED" - E_BACKPRESSURE = "E_BACKPRESSURE" + E_BACK_PRESSURE = "E_BACK_PRESSURE" E_HOST_OVERLOADED = "E_HOST_OVERLOADED" # 插件生命周期 E_PLUGIN_CRASHED = "E_PLUGIN_CRASHED" E_PLUGIN_NOT_FOUND = "E_PLUGIN_NOT_FOUND" - E_GENERATION_MISMATCH = "E_GENERATION_MISMATCH" E_RELOAD_IN_PROGRESS = "E_RELOAD_IN_PROGRESS" # 能力调用 @@ -65,3 +65,13 @@ class RPCError(Exception): message=data.get("message", ""), details=data.get("details", {}), ) + + @classmethod + def from_exception(cls, exception: Exception, code_mapping: Optional[Dict[type[Exception], ErrorCode]] = None): + if isinstance(exception, cls): + return exception + if code_mapping: + for exception_type, code in code_mapping.items(): + if isinstance(exception, exception_type): + return cls(code=code, message=str(exception)) + return cls(ErrorCode.E_UNKNOWN, str(exception)) From ca6fd96d4c2e9eef1a469fd7653317ee8755be8e Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Tue, 17 Mar 2026 21:39:06 +0800 Subject: [PATCH 07/45] =?UTF-8?q?refactor:=20=E7=A1=AE=E8=AE=A4ErrorCode?= =?UTF-8?q?=E5=8F=AF=E4=BB=A5=E7=BB=A7=E6=89=BFstr=EF=BC=8C=E6=81=A2?= =?UTF-8?q?=E5=A4=8D=E5=8E=9F=E6=9D=A5=E8=AE=BE=E8=AE=A1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/plugin_runtime/protocol/errors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/plugin_runtime/protocol/errors.py b/src/plugin_runtime/protocol/errors.py index ed19760d..d2b9228b 100644 --- a/src/plugin_runtime/protocol/errors.py +++ b/src/plugin_runtime/protocol/errors.py @@ -7,7 +7,7 @@ from enum import Enum from typing import Any, Dict, Optional -class ErrorCode(Enum): +class ErrorCode(str, Enum): """RPC 错误码枚举""" # 通用 From 14a0c21cbffdb80a3ff276bd736e0431ae772b10 Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Wed, 18 Mar 2026 02:08:13 +0800 Subject: [PATCH 08/45] =?UTF-8?q?refactor:=20component=5Fregistry=E6=9B=B4?= =?UTF-8?q?=E6=98=93=E7=90=86=E8=A7=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/plugin_runtime/host/component_registry.py | 427 +++++++++++++----- src/plugin_runtime/host/logger_bridge.py | 45 ++ 2 files changed, 347 insertions(+), 125 deletions(-) create mode 100644 src/plugin_runtime/host/logger_bridge.py diff --git a/src/plugin_runtime/host/component_registry.py b/src/plugin_runtime/host/component_registry.py index 220a19c0..89ec82d2 100644 --- a/src/plugin_runtime/host/component_registry.py +++ b/src/plugin_runtime/host/component_registry.py @@ -1,7 +1,7 @@ """Host-side ComponentRegistry 对齐旧系统 component_registry.py 的核心能力: -- 按类型注册组件(action / command / tool / event_handler / workflow_step) +- 按类型注册组件(action / command / tool / event_handler / workflow_handler / message_gateway) - 命名空间 (plugin_id.component_name) - 命令正则匹配 - 组件启用/禁用 @@ -9,8 +9,10 @@ - 注册统计 """ -from typing import Any, Dict, List, Optional +from enum import Enum +from typing import Any, Dict, List, Optional, Set, TypedDict, Tuple +import contextlib import re from src.common.logger import get_logger @@ -18,8 +20,28 @@ from src.common.logger import get_logger logger = get_logger("plugin_runtime.host.component_registry") -class RegisteredComponent: - """已注册的组件条目""" +class ComponentTypes(str, Enum): + ACTION = "ACTION" + COMMAND = "COMMAND" + TOOL = "TOOL" + EVENT_HANDLER = "EVENT_HANDLER" + WORKFLOW_HANDLER = "WORKFLOW_HANDLER" + MESSAGE_GATEWAY = "MESSAGE_GATEWAY" + + +class StatusDict(TypedDict): + total: int + ACTION: int + COMMAND: int + TOOL: int + EVENT_HANDLER: int + WORKFLOW_HANDLER: int + MESSAGE_GATEWAY: int + plugins: int + + +class ComponentEntry: + """组件条目""" __slots__ = ( "name", @@ -28,31 +50,74 @@ class RegisteredComponent: "plugin_id", "metadata", "enabled", - "_compiled_pattern", + "compiled_pattern", + "disabled_session", ) - def __init__( - self, - name: str, - component_type: str, - plugin_id: str, - metadata: Dict[str, Any], - ) -> None: - self.name = name - self.full_name = f"{plugin_id}.{name}" - self.component_type = component_type - self.plugin_id = plugin_id - self.metadata = metadata - self.enabled = metadata.get("enabled", True) + def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None: + self.name: str = name + self.full_name: str = f"{plugin_id}.{name}" + self.component_type: ComponentTypes = ComponentTypes(component_type) + self.plugin_id: str = plugin_id + self.metadata: Dict[str, Any] = metadata + self.enabled: bool = metadata.get("enabled", True) + self.disabled_session: Set[str] = set() - # 预编译命令正则(仅 command 类型) - self._compiled_pattern: Optional[re.Pattern] = None - if component_type == "command": - if pattern := metadata.get("command_pattern", ""): - try: - self._compiled_pattern = re.compile(pattern) - except re.error as e: - logger.warning(f"命令 {self.full_name} 正则编译失败: {e}") + +class ActionEntry(ComponentEntry): + """Action 组件条目""" + + def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None: + super().__init__(name, component_type, plugin_id, metadata) + + +class CommandEntry(ComponentEntry): + """Command 组件条目""" + + def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None: + self.compiled_pattern: Optional[re.Pattern] = None + self.aliases: List[str] = metadata.get("aliases", []) + if pattern := metadata.get("command_pattern", ""): + try: + self.compiled_pattern = re.compile(pattern) + except re.error as e: + logger.warning(f"命令 {self.full_name} 正则编译失败: {e}") + super().__init__(name, component_type, plugin_id, metadata) + + +class ToolEntry(ComponentEntry): + """Tool 组件条目""" + + def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None: + self.description: str = metadata.get("description", "") + self.parameters: List[Dict[str, Any]] = metadata.get("parameters", []) + self.parameters_raw: List[Dict[str, Any]] = metadata.get("parameters_raw", []) + super().__init__(name, component_type, plugin_id, metadata) + + +class EventHandlerEntry(ComponentEntry): + """EventHandler 组件条目""" + + def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None: + self.event_type: str = metadata.get("event_type", "") + self.weight: int = metadata.get("weight", 0) + super().__init__(name, component_type, plugin_id, metadata) + + +class WorkflowHandlerEntry(ComponentEntry): + """WorkflowHandler 组件条目""" + + def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None: + self.stage: str = metadata.get("stage", "") + self.priority: int = metadata.get("priority", 0) + super().__init__(name, component_type, plugin_id, metadata) + + +class MessageGatewayEntry(ComponentEntry): + """MessageGateway 组件条目""" + + def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None: + super().__init__(name, component_type, plugin_id, metadata) class ComponentRegistry: @@ -64,19 +129,15 @@ class ComponentRegistry: def __init__(self) -> None: # 全量索引 - self._components: Dict[str, RegisteredComponent] = {} # full_name -> comp + self._components: Dict[str, ComponentEntry] = {} # full_name -> comp # 按类型索引 - self._by_type: Dict[str, Dict[str, RegisteredComponent]] = { - "action": {}, - "command": {}, - "tool": {}, - "event_handler": {}, - "workflow_step": {}, - } + self._by_type: Dict[ComponentTypes, Dict[str, ComponentEntry]] = { + comp_type: {} for comp_type in ComponentTypes + } # component_type -> (full_name -> comp) # 按插件索引 - self._by_plugin: Dict[str, List[RegisteredComponent]] = {} + self._by_plugin: Dict[str, List[ComponentEntry]] = {} def clear(self) -> None: """清空全部组件注册状态。""" @@ -85,47 +146,63 @@ class ComponentRegistry: type_dict.clear() self._by_plugin.clear() - # ──── 注册 / 注销 ───────────────────────────────────────── + # ====== 注册 / 注销 ====== + def register_component(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> bool: + """注册单个组件 + + Args: + name: 组件名称(不含插件id前缀) + component_type: 组件类型(如 `ACTION`、`COMMAND` 等) + plugin_id: 插件id + metadata: 组件元数据 + Returns: + success (bool): 是否成功注册(失败原因通常是组件类型无效) + """ + try: + if component_type == ComponentTypes.ACTION: + comp = ActionEntry(name, component_type, plugin_id, metadata) + elif component_type == ComponentTypes.COMMAND: + comp = CommandEntry(name, component_type, plugin_id, metadata) + elif component_type == ComponentTypes.TOOL: + comp = ToolEntry(name, component_type, plugin_id, metadata) + elif component_type == ComponentTypes.EVENT_HANDLER: + comp = EventHandlerEntry(name, component_type, plugin_id, metadata) + elif component_type == ComponentTypes.WORKFLOW_HANDLER: + comp = WorkflowHandlerEntry(name, component_type, plugin_id, metadata) + elif component_type == ComponentTypes.MESSAGE_GATEWAY: + comp = MessageGatewayEntry(name, component_type, plugin_id, metadata) + else: + raise ValueError(f"组件类型 {component_type} 不存在") + except ValueError: + logger.error(f"组件类型 {component_type} 不存在") + return False - def register_component( - self, - name: str, - component_type: str, - plugin_id: str, - metadata: Dict[str, Any], - ) -> bool: - """注册单个组件。""" - comp = RegisteredComponent(name, component_type, plugin_id, metadata) if comp.full_name in self._components: logger.warning(f"组件 {comp.full_name} 已存在,覆盖") old_comp = self._components[comp.full_name] # 从 _by_plugin 列表中移除旧条目,防止幽灵组件堆积 old_list = self._by_plugin.get(old_comp.plugin_id) if old_list is not None: - try: + with contextlib.suppress(ValueError): old_list.remove(old_comp) - except ValueError: - pass # 从旧类型索引中移除,防止类型变更时幽灵残留 if old_type_dict := self._by_type.get(old_comp.component_type): old_type_dict.pop(comp.full_name, None) self._components[comp.full_name] = comp - - if component_type not in self._by_type: - self._by_type[component_type] = {} - self._by_type[component_type][comp.full_name] = comp - + self._by_type[comp.component_type][comp.full_name] = comp self._by_plugin.setdefault(plugin_id, []).append(comp) return True - def register_plugin_components( - self, - plugin_id: str, - components: List[Dict[str, Any]], - ) -> int: - """批量注册一个插件的所有组件,返回成功注册数。""" + def register_plugin_components(self, plugin_id: str, components: List[Dict[str, Any]]) -> int: + """批量注册一个插件的所有组件,返回成功注册数。 + Args: + plugin_id (str): 插件id + components (List[Dict[str, Any]]): 组件字典列表,每个组件包含 name, component_type, metadata 等字段 + Returns: + count (int): 成功注册的组件数量 + """ count = 0 for comp_data in components: ok = self.register_component( @@ -139,7 +216,13 @@ class ComponentRegistry: return count def remove_components_by_plugin(self, plugin_id: str) -> int: - """移除某个插件的所有组件,返回移除数量。""" + """移除某个插件的所有组件,返回移除数量。 + + Args: + plugin_id (str): 插件id + Returns: + count (int): 移除的组件数量 + """ comps = self._by_plugin.pop(plugin_id, []) for comp in comps: self._components.pop(comp.full_name, None) @@ -147,106 +230,200 @@ class ComponentRegistry: type_dict.pop(comp.full_name, None) return len(comps) - # ──── 启用 / 禁用 ───────────────────────────────────────── + # ====== 启用 / 禁用 ====== + def check_component_enabled(self, component: ComponentEntry, session_id: Optional[str] = None): + if session_id and session_id in component.disabled_session: + return False + return component.enabled - def set_component_enabled(self, full_name: str, enabled: bool) -> bool: - """启用或禁用指定组件。""" + def toggle_component_status(self, full_name: str, enabled: bool, session_id: Optional[str] = None) -> bool: + """启用或禁用指定组件。 + + Args: + full_name (str): 组件全名 + enabled (bool): 使能情况 + session_id (Optional[str]): 可选的会话ID,仅对该会话禁用(如果提供) + Returns: + success (bool): 是否成功设置(失败原因通常是组件不存在) + """ comp = self._components.get(full_name) if comp is None: return False - comp.enabled = enabled + if session_id: + if enabled: + comp.disabled_session.discard(session_id) + else: + comp.disabled_session.add(session_id) + else: + comp.enabled = enabled return True - def set_plugin_enabled(self, plugin_id: str, enabled: bool) -> int: - """批量启用或禁用某插件的所有组件。""" + def toggle_plugin_status(self, plugin_id: str, enabled: bool, session_id: Optional[str] = None) -> int: + """批量启用或禁用某插件的所有组件。 + + Args: + plugin_id (str): 插件id + enabled (bool): 使能情况 + session_id (Optional[str]): 可选的会话ID,仅对该会话禁用(如果提供) + Returns: + count (int): 成功设置的组件数量(失败原因通常是插件不存在) + """ comps = self._by_plugin.get(plugin_id, []) for comp in comps: - comp.enabled = enabled + if session_id: + if enabled: + comp.disabled_session.discard(session_id) + else: + comp.disabled_session.add(session_id) + else: + comp.enabled = enabled return len(comps) - # ──── 查询方法 ───────────────────────────────────────────── + def get_component(self, full_name: str) -> Optional[ComponentEntry]: + """按全名查询。 - def get_component(self, full_name: str) -> Optional[RegisteredComponent]: - """按全名查询。""" + Args: + full_name (str): 组件全名 + Returns: + component (Optional[ComponentEntry]): 组件条目,未找到时为 None + """ return self._components.get(full_name) - def get_components_by_type(self, component_type: str, *, enabled_only: bool = True) -> List[RegisteredComponent]: - """按类型查询。""" - type_dict = self._by_type.get(component_type, {}) + def get_components_by_type( + self, component_type: str, *, enabled_only: bool = True, session_id: Optional[str] = None + ) -> List[ComponentEntry]: + """按类型查询组件 + + Args: + component_type (str): 组件类型(如 `ACTION`、`COMMAND` 等) + enabled_only (bool): 是否仅返回启用的组件 + session_id (Optional[str]): 可选的会话ID,若提供则考虑会话禁用状态 + Returns: + components (List[ComponentEntry]): 组件条目列表 + """ + try: + comp_type = ComponentTypes(component_type) + except ValueError: + logger.error(f"组件类型 {component_type} 不存在") + raise + type_dict = self._by_type.get(comp_type, {}) if enabled_only: - return [c for c in type_dict.values() if c.enabled] + return [c for c in type_dict.values() if self.check_component_enabled(c, session_id)] return list(type_dict.values()) - def get_components_by_plugin(self, plugin_id: str, *, enabled_only: bool = True) -> List[RegisteredComponent]: - """按插件查询。""" - comps = self._by_plugin.get(plugin_id, []) - return [c for c in comps if c.enabled] if enabled_only else list(comps) + def get_components_by_plugin( + self, plugin_id: str, *, enabled_only: bool = True, session_id: Optional[str] = None + ) -> List[ComponentEntry]: + """按插件查询组件。 - def find_command_by_text(self, text: str) -> Optional[tuple[RegisteredComponent, Dict[str, Any]]]: + Args: + plugin_id (str): 插件ID + enabled_only (bool): 是否仅返回启用的组件 + session_id (Optional[str]): 可选的会话ID,若提供则考虑会话禁用状态 + Returns: + components (List[ComponentEntry]): 组件条目列表 + """ + comps = self._by_plugin.get(plugin_id, []) + return [c for c in comps if self.check_component_enabled(c, session_id)] if enabled_only else list(comps) + + def find_command_by_text( + self, text: str, session_id: Optional[str] = None + ) -> Optional[Tuple[ComponentEntry, Dict[str, Any]]]: """通过文本匹配命令正则,返回 (组件, matched_groups) 元组。 matched_groups 为正则命名捕获组 dict,别名匹配时为空 dict。 + Args: + text (str): 待匹配文本 + session_id (Optional[str]): 可选的会话ID,若提供则考虑会话禁用状态 + Returns: + result (Optional[tuple[ComponentEntry, Dict[str, Any]]]): 匹配到的组件及正则捕获组,未找到时为 None """ - for comp in self._by_type.get("command", {}).values(): - if not comp.enabled: + for comp in self._by_type.get(ComponentTypes.COMMAND, {}).values(): + if not self.check_component_enabled(comp, session_id): continue - if comp._compiled_pattern: - m = comp._compiled_pattern.search(text) - if m: + if not isinstance(comp, CommandEntry): + continue + if comp.compiled_pattern: + if m := comp.compiled_pattern.search(text): return comp, m.groupdict() # 别名匹配 - aliases = comp.metadata.get("aliases", []) - for alias in aliases: + for alias in comp.aliases: if text.startswith(alias): return comp, {} return None - def get_event_handlers(self, event_type: str, *, enabled_only: bool = True) -> List[RegisteredComponent]: - """获取特定事件类型的所有 event_handler,按 weight 降序排列。""" - handlers = [] - for comp in self._by_type.get("event_handler", {}).values(): - if enabled_only and not comp.enabled: + def get_event_handlers( + self, event_type: str, *, enabled_only: bool = True, session_id: Optional[str] = None + ) -> List[EventHandlerEntry]: + """查询指定事件类型的事件处理器组件。 + + Args: + event_type (str): 事件类型 + enabled_only (bool): 是否仅返回启用的组件 + session_id (Optional[str]): 可选的会话ID,若提供则考虑会话禁用状态 + Returns: + handlers (List[EventHandlerEntry]): 符合条件的 EventHandler 组件列表,按 weight 降序排序 + """ + handlers: List[EventHandlerEntry] = [] + for comp in self._by_type.get(ComponentTypes.EVENT_HANDLER, {}).values(): + if enabled_only and not self.check_component_enabled(comp, session_id): continue - if comp.metadata.get("event_type") == event_type: + if not isinstance(comp, EventHandlerEntry): + continue + if comp.event_type == event_type: handlers.append(comp) - handlers.sort(key=lambda c: c.metadata.get("weight", 0), reverse=True) + handlers.sort(key=lambda c: c.weight, reverse=True) return handlers - def get_workflow_steps(self, stage: str, *, enabled_only: bool = True) -> List[RegisteredComponent]: - """获取特定 workflow 阶段的所有步骤,按 priority 降序。""" - steps = [] - for comp in self._by_type.get("workflow_step", {}).values(): - if enabled_only and not comp.enabled: + def get_workflow_handlers( + self, stage: str, *, enabled_only: bool = True, session_id: Optional[str] = None + ) -> List[WorkflowHandlerEntry]: + """获取特定 workflow 阶段的所有步骤,按 priority 降序。 + + Args: + stage: workflow 阶段名称 + enabled_only: 是否仅返回启用的组件 + session_id: 可选的会话ID,若提供则考虑会话禁用状态 + Returns: + handlers (List[WorkflowHandlerEntry]): 符合条件的 WorkflowHandler 组件列表,按 priority 降序排序 + """ + handlers: List[WorkflowHandlerEntry] = [] + for comp in self._by_type.get(ComponentTypes.WORKFLOW_HANDLER, {}).values(): + if enabled_only and not self.check_component_enabled(comp, session_id): continue - if comp.metadata.get("stage") == stage: - steps.append(comp) - steps.sort(key=lambda c: c.metadata.get("priority", 0), reverse=True) - return steps + if not isinstance(comp, WorkflowHandlerEntry): + continue + if comp.stage == stage: + handlers.append(comp) + handlers.sort(key=lambda c: c.priority, reverse=True) + return handlers - def get_tools_for_llm(self, *, enabled_only: bool = True) -> List[Dict[str, Any]]: - """获取可供 LLM 使用的工具列表(openai function-calling 格式预览)。""" - result: List[Dict[str, Any]] = [] - for comp in self.get_components_by_type("tool", enabled_only=enabled_only): - tool_def: Dict[str, Any] = { - "name": comp.full_name, - "description": comp.metadata.get("description", ""), - } - # 从结构化参数或原始参数构建 parameters - params = comp.metadata.get("parameters", []) - params_raw = comp.metadata.get("parameters_raw", {}) - if params: - tool_def["parameters"] = params - elif params_raw: - tool_def["parameters"] = params_raw - result.append(tool_def) - return result + def get_tools(self, *, enabled_only: bool = True, session_id: Optional[str] = None) -> List[ToolEntry]: + """查询所有工具组件。 + + Args: + enabled_only (bool): 是否仅返回启用的组件 + session_id (Optional[str]): 可选的会话ID,若提供则考虑会话禁用状态 + Returns: + tools (List[ToolEntry]): 符合条件的 Tool 组件列表 + """ + tools: List[ToolEntry] = [] + for comp in self._by_type.get(ComponentTypes.TOOL, {}).values(): + if enabled_only and not self.check_component_enabled(comp, session_id): + continue + if isinstance(comp, ToolEntry): + tools.append(comp) + return tools - # ──── 统计 ───────────────────────────────────────────────── - - def get_stats(self) -> Dict[str, int]: - """获取注册统计。""" - stats: Dict[str, int] = {"total": len(self._components)} + # ====== 统计信息 ====== + def get_stats(self) -> StatusDict: + """获取注册统计。 + + Returns: + stats (StatusDict): 组件统计信息,包括总数、各类型数量、插件数量等 + """ + stats: StatusDict = {"total": len(self._components)} # type: ignore for comp_type, type_dict in self._by_type.items(): - stats[comp_type] = len(type_dict) + stats[comp_type.value] = len(type_dict) stats["plugins"] = len(self._by_plugin) return stats diff --git a/src/plugin_runtime/host/logger_bridge.py b/src/plugin_runtime/host/logger_bridge.py new file mode 100644 index 00000000..f2213dfe --- /dev/null +++ b/src/plugin_runtime/host/logger_bridge.py @@ -0,0 +1,45 @@ +import logging as stdlib_logging +from src.plugin_runtime.protocol.errors import ErrorCode +from src.plugin_runtime.protocol.envelope import Envelope, LogBatchPayload +class RunnerLogBridge: + """将 Runner 进程上报的批量日志重放到主进程的 Logger 中。 + + Runner 通过 ``runner.log_batch`` IPC 事件批量到达。 + 每条 LogEntry 被重建为一个真实的 :class:`logging.LogRecord` 并直接 + 调用 ``logging.getLogger(entry.logger_name).handle(record)``, + 从而接入主进程已配置好的 structlog Handler 链。 + """ + + async def handle_log_batch(self, envelope: Envelope) -> Envelope: + """IPC 事件处理器:解析批量日志并重放到主进程 Logger。 + + Args: + envelope: 方法名为 ``runner.log_batch`` 的 IPC 事件信封。 + + Returns: + 空响应信封(事件模式下将被忽略)。 + """ + try: + batch = LogBatchPayload.model_validate(envelope.payload) + except Exception as exc: + return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc)) + + for entry in batch.entries: + # 重建一个与原始日志尽量相符的 LogRecord + record = stdlib_logging.LogRecord( + name=entry.logger_name, + level=entry.level, + pathname="", + lineno=0, + msg=entry.message, + args=(), + exc_info=None, + ) + record.created = entry.timestamp_ms / 1000.0 + record.msecs = entry.timestamp_ms % 1000 + if entry.exception_text: + record.exc_text = entry.exception_text + + stdlib_logging.getLogger(entry.logger_name).handle(record) + + return envelope.make_response(payload={"accepted": True, "count": len(batch.entries)}) \ No newline at end of file From 32519c688bcfb3a540b235d3f337629c490a9e6d Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Wed, 18 Mar 2026 15:29:08 +0800 Subject: [PATCH 09/45] refactor: event_dispatcher --- src/plugin_runtime/host/component_registry.py | 7 +- src/plugin_runtime/host/event_dispatcher.py | 114 ++++++++++-------- 2 files changed, 65 insertions(+), 56 deletions(-) diff --git a/src/plugin_runtime/host/component_registry.py b/src/plugin_runtime/host/component_registry.py index 89ec82d2..22f9a7e0 100644 --- a/src/plugin_runtime/host/component_registry.py +++ b/src/plugin_runtime/host/component_registry.py @@ -101,6 +101,7 @@ class EventHandlerEntry(ComponentEntry): def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None: self.event_type: str = metadata.get("event_type", "") self.weight: int = metadata.get("weight", 0) + self.intercept_message: bool = metadata.get("intercept_message", False) super().__init__(name, component_type, plugin_id, metadata) @@ -356,7 +357,7 @@ class ComponentRegistry: self, event_type: str, *, enabled_only: bool = True, session_id: Optional[str] = None ) -> List[EventHandlerEntry]: """查询指定事件类型的事件处理器组件。 - + Args: event_type (str): 事件类型 enabled_only (bool): 是否仅返回启用的组件 @@ -400,7 +401,7 @@ class ComponentRegistry: def get_tools(self, *, enabled_only: bool = True, session_id: Optional[str] = None) -> List[ToolEntry]: """查询所有工具组件。 - + Args: enabled_only (bool): 是否仅返回启用的组件 session_id (Optional[str]): 可选的会话ID,若提供则考虑会话禁用状态 @@ -418,7 +419,7 @@ class ComponentRegistry: # ====== 统计信息 ====== def get_stats(self) -> StatusDict: """获取注册统计。 - + Returns: stats (StatusDict): 组件统计信息,包括总数、各类型数量、插件数量等 """ diff --git a/src/plugin_runtime/host/event_dispatcher.py b/src/plugin_runtime/host/event_dispatcher.py index 720e93d7..f08591a8 100644 --- a/src/plugin_runtime/host/event_dispatcher.py +++ b/src/plugin_runtime/host/event_dispatcher.py @@ -4,40 +4,38 @@ 1. 按事件类型查询已注册的 event_handler(通过 ComponentRegistry) 2. 按 weight 排序,依次通过 RPC 调用 Runner 中的处理器 3. 支持阻塞(intercept_message)和非阻塞分发 -4. 事件结果历史记录 +4. 事件结果历史记录(有上限) """ -from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple +from dataclasses import dataclass, field +from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING import asyncio from src.common.logger import get_logger -from src.plugin_runtime.host.component_registry import ComponentRegistry, RegisteredComponent + + +if TYPE_CHECKING: + from .supervisor import PluginRunnerSupervisor + from .component_registry import ComponentRegistry, EventHandlerEntry logger = get_logger("plugin_runtime.host.event_dispatcher") # invoke_fn 类型: async (plugin_id, component_name, args) -> response_payload dict InvokeFn = Callable[[str, str, Dict[str, Any]], Awaitable[Dict[str, Any]]] +# 每个事件类型的最大历史记录数量,防止内存无限增长 +_MAX_HISTORY_LENGTH = 100 +@dataclass class EventResult: """单个 EventHandler 的执行结果""" - __slots__ = ("handler_name", "success", "continue_processing", "modified_message", "custom_result") - - def __init__( - self, - handler_name: str, - success: bool = True, - continue_processing: bool = True, - modified_message: Optional[Dict[str, Any]] = None, - custom_result: Any = None, - ): - self.handler_name = handler_name - self.success = success - self.continue_processing = continue_processing - self.modified_message = modified_message - self.custom_result = custom_result + handler_name: str + success: bool = field(default=True) + continue_processing: bool = field(default=True) + modified_message: Optional[Dict[str, Any]] = field(default=None) + custom_result: Any = field(default=None) class EventDispatcher: @@ -48,8 +46,8 @@ class EventDispatcher: 再通过提供的 invoke_fn 回调 RPC 到 Runner 执行。 """ - def __init__(self, registry: ComponentRegistry) -> None: - self._registry: ComponentRegistry = registry + def __init__(self, component_registry: "ComponentRegistry") -> None: + self._component_registry: "ComponentRegistry" = component_registry self._result_history: Dict[str, List[EventResult]] = {} self._history_enabled: Set[str] = set() # 保持 fire-and-forget task 的强引用,防止被 GC 回收 @@ -59,6 +57,10 @@ class EventDispatcher: self._history_enabled.add(event_type) self._result_history.setdefault(event_type, []) + def disable_history(self, event_type: str) -> None: + self._history_enabled.discard(event_type) + self._result_history.pop(event_type, None) + def get_history(self, event_type: str) -> List[EventResult]: return self._result_history.get(event_type, []) @@ -69,44 +71,46 @@ class EventDispatcher: async def dispatch_event( self, event_type: str, - invoke_fn: InvokeFn, - message: Optional[Dict[str, Any]] = None, + supervisor: "PluginRunnerSupervisor", + message_dict: Optional[Dict[str, Any]] = None, extra_args: Optional[Dict[str, Any]] = None, ) -> Tuple[bool, Optional[Dict[str, Any]]]: - """分发事件到所有对应 handler。 + """分发事件到所有对应 handler 的便捷方法。 + + 内置了通过 PluginSupervisor.invoke_plugin 调用 plugin.emit_event 的逻辑, + 无需调用方手动构造 invoke_fn 闭包。 Args: event_type: 事件类型字符串 - invoke_fn: 异步回调,签名 (plugin_id, component_name, args) -> response_payload dict + supervisor: PluginSupervisor 实例,用于调用 invoke_plugin message: MaiMessages 序列化后的 dict(可选) extra_args: 额外参数 Returns: - (should_continue, modified_message_dict) + (should_continue, modified_message_dict) (bool, Dict[str, Any] | None): (是否继续后续执行, 可选的修改后的消息字典) """ - handlers = self._registry.get_event_handlers(event_type) - if not handlers: + handler_entries = self._component_registry.get_event_handlers(event_type) + if not handler_entries: return True, None should_continue = True - modified_message: Optional[Dict[str, Any]] = None - intercept_handlers: List[RegisteredComponent] = [] - async_handlers: List[RegisteredComponent] = [] + modified_message: Optional[Dict[str, Any]] = message_dict + intercept_handlers: List["EventHandlerEntry"] = [] + non_blocking_handlers: List["EventHandlerEntry"] = [] - for handler in handlers: - if handler.metadata.get("intercept_message", False): - intercept_handlers.append(handler) + for entry in handler_entries: + if entry.intercept_message: + intercept_handlers.append(entry) else: - async_handlers.append(handler) + non_blocking_handlers.append(entry) - for handler in intercept_handlers: + for entry in intercept_handlers: args = { "event_type": event_type, - "message": modified_message or message, + "message": modified_message, **(extra_args or {}), } - - result = await self._invoke_handler(invoke_fn, handler, args, event_type) + result = await self._invoke_handler(supervisor, entry, args, event_type) if result and not result.continue_processing: should_continue = False break @@ -114,16 +118,16 @@ class EventDispatcher: modified_message = result.modified_message if should_continue: - final_message = modified_message or message - for handler in async_handlers: - async_message = final_message.copy() if isinstance(final_message, dict) else final_message + final_message = modified_message + for entry in non_blocking_handlers: + async_message = final_message.copy() if final_message else final_message args = { "event_type": event_type, "message": async_message, **(extra_args or {}), } # 非阻塞:保持实例级强引用,防止 task 被 GC 回收 - task = asyncio.create_task(self._invoke_handler(invoke_fn, handler, args, event_type)) + task = asyncio.create_task(self._invoke_handler(supervisor, entry, args, event_type)) self._background_tasks.add(task) task.add_done_callback(self._background_tasks.discard) @@ -131,30 +135,34 @@ class EventDispatcher: async def _invoke_handler( self, - invoke_fn: InvokeFn, - handler: RegisteredComponent, + supervisor: "PluginRunnerSupervisor", + handler_entry: "EventHandlerEntry", args: Dict[str, Any], event_type: str, ) -> Optional[EventResult]: """调用单个 handler 并收集结果。""" try: - resp = await invoke_fn(handler.plugin_id, handler.name, args) + resp_envelope = await supervisor.invoke_plugin( + "plugin.emit_event", handler_entry.plugin_id, handler_entry.name, args + ) + resp = resp_envelope.payload result = EventResult( - handler_name=handler.full_name, + handler_name=handler_entry.full_name, success=resp.get("success", True), continue_processing=resp.get("continue_processing", True), modified_message=resp.get("modified_message"), custom_result=resp.get("custom_result"), ) except Exception as e: - logger.error(f"EventHandler {handler.full_name} 执行失败: {e}", exc_info=True) - result = EventResult( - handler_name=handler.full_name, - success=False, - continue_processing=True, - ) + logger.error(f"EventHandler {handler_entry.full_name} 执行失败: {e}", exc_info=True) + result = EventResult(handler_name=handler_entry.full_name, success=False, continue_processing=True) if event_type in self._history_enabled: - self._result_history.setdefault(event_type, []).append(result) + history_list = self._result_history.setdefault(event_type, []) + history_list.append(result) + # 自动清理超出限制的旧记录,防止内存无限增长 + if len(history_list) > _MAX_HISTORY_LENGTH: + # 保留最新的 _MAX_HISTORY_LENGTH 条记录 + self._result_history[event_type] = history_list[-_MAX_HISTORY_LENGTH:] return result From 17248a4cbc2fcdda721a879c7d84f2520e54ea87 Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Wed, 18 Mar 2026 20:18:11 +0800 Subject: [PATCH 10/45] =?UTF-8?q?=E6=B7=BB=E5=8A=A0message=20gateway?= =?UTF-8?q?=E7=BB=84=E4=BB=B6=E7=B1=BB=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/plugin_runtime/host/component_registry.py | 25 ++ src/plugin_runtime/host/message_gateway.py | 309 ++++++++++++++++++ 2 files changed, 334 insertions(+) create mode 100644 src/plugin_runtime/host/message_gateway.py diff --git a/src/plugin_runtime/host/component_registry.py b/src/plugin_runtime/host/component_registry.py index 22f9a7e0..7ac7a518 100644 --- a/src/plugin_runtime/host/component_registry.py +++ b/src/plugin_runtime/host/component_registry.py @@ -118,6 +118,10 @@ class MessageGatewayEntry(ComponentEntry): """MessageGateway 组件条目""" def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None: + platform = metadata.get("platform") + if not platform or not isinstance(platform, str): + raise ValueError(f"MessageGateway 组件 {plugin_id}.{name} 缺少有效的 platform 字段") + self.platform: str = platform super().__init__(name, component_type, plugin_id, metadata) @@ -399,6 +403,27 @@ class ComponentRegistry: handlers.sort(key=lambda c: c.priority, reverse=True) return handlers + def get_message_gateways( + self, platform: str, *, enabled_only: bool = True, session_id: Optional[str] = None + ) -> Optional[MessageGatewayEntry]: + """查询消息网关组件。 + + Args: + platform (str): 平台名称 + enabled_only (bool): 是否仅返回启用的组件 + session_id (Optional[str]): 可选的会话ID,若提供则考虑会话禁用状态 + Returns: + gateway (Optional[MessageGatewayEntry]): 符合条件的 MessageGateway 组件,可能不存在 + """ + + for comp in self._by_type.get(ComponentTypes.MESSAGE_GATEWAY, {}).values(): + if not isinstance(comp, MessageGatewayEntry): + continue + if enabled_only and not self.check_component_enabled(comp, session_id): + continue + if comp.platform == platform: + return comp # 返回第一个 + def get_tools(self, *, enabled_only: bool = True, session_id: Optional[str] = None) -> List[ToolEntry]: """查询所有工具组件。 diff --git a/src/plugin_runtime/host/message_gateway.py b/src/plugin_runtime/host/message_gateway.py new file mode 100644 index 00000000..e995ed01 --- /dev/null +++ b/src/plugin_runtime/host/message_gateway.py @@ -0,0 +1,309 @@ +""" +Message Gateway 模块 +适配器专用,用于将其他平台的消息转换为系统内部的消息格式,并将系统消息转换为其他平台的格式。 +""" + +from datetime import datetime +from typing import Dict, Any, TYPE_CHECKING, TypedDict, Optional, List + +from src.common.logger import get_logger +from src.chat.message_receive.message import SessionMessage +from src.common.data_models.mai_message_data_model import UserInfo, GroupInfo, MessageInfo +from src.common.data_models.message_component_data_model import MessageSequence + +if TYPE_CHECKING: + from .component_registry import ComponentRegistry + from .supervisor import PluginRunnerSupervisor + +logger = get_logger("plugin_runtime.host.message_gateway") + + +class UserInfoDict(TypedDict, total=False): + user_id: str + user_nickname: str + user_cardname: Optional[str] + + +class GroupInfoDict(TypedDict, total=False): + group_id: str + group_name: str + + +class MessageInfoDict(TypedDict, total=False): + user_info: UserInfoDict + group_info: Optional[GroupInfoDict] + additional_config: Dict[str, Any] + + +class MessageDict(TypedDict, total=False): + message_id: str + timestamp: str + platform: str + message_info: MessageInfoDict + raw_message: List[Dict[str, Any]] + is_mentioned: bool + is_at: bool + is_emoji: bool + is_picture: bool + is_command: bool + is_notify: bool + session_id: str + reply_to: Optional[str] + processed_plain_text: Optional[str] + display_message: Optional[str] + + +class MessageGateway: + def __init__(self, component_registry: "ComponentRegistry") -> None: + self._component_registry = component_registry + + async def receive_external_message(self, external_message: Dict[str, Any]): + """ + 接收外部消息,转换为系统内部格式,并返回转换结果 + + Args: + external_message: 外部消息的字典格式数据 + + Returns: + 转换后的 SessionMessage 对象 + """ + # 使用递归函数将外部消息字典转换为 SessionMessage + try: + session_message = self._build_session_message_from_dict(external_message) + except Exception as e: + logger.error(f"转换外部消息失败: {e}") + return + from src.chat.message_receive.bot import chat_bot + + await chat_bot.receive_message(session_message) + + async def send_message_to_external( + self, + internal_message: SessionMessage, + supervisor: "PluginRunnerSupervisor", + *, + enabled_only: bool = True, + save_to_db: bool = True, + ) -> bool: + """ + 接收系统内部消息,转换为外部格式,并返回转换结果 + + Args: + internal_message: 系统内部的 SessionMessage 对象 + + Returns: + 转换是否成功 + """ + try: + # 将 SessionMessage 转换为字典格式 + message_dict = self._session_message_to_dict(internal_message) + except Exception as e: + logger.error(f"转换内部消息失败:{e}") + return False + gateway_entry = self._component_registry.get_message_gateways( + internal_message.platform, + enabled_only=enabled_only, + session_id=internal_message.session_id, + ) + if not gateway_entry: + logger.warning(f"未找到适配平台 {internal_message.platform} 的消息网关组件,无法发送消息到外部平台") + return False + args = {"platform": internal_message.platform, "message": message_dict} + try: + resp_envelope = await supervisor.invoke_plugin( + "plugin.emit_event", gateway_entry.plugin_id, gateway_entry.name, args + ) + logger.debug("信息发送成功") + except Exception as e: + logger.error(f"调用消息网关组件失败:{e}") + return False + + # 更新为实际id(如果组件返回了新的id) + actual_message_id = resp_envelope.payload.get("message_id") + try: + actual_message_id = str(actual_message_id) + except Exception: + actual_message_id = None + internal_message.message_id = actual_message_id or internal_message.message_id + if save_to_db: + try: + from src.common.utils.utils_message import MessageUtils + + MessageUtils.store_message_to_db(internal_message) + except Exception as e: + logger.error(f"保存消息到数据库失败: {e}") + return True + + def _message_info_to_dict(self, message_info: MessageInfo) -> MessageInfoDict: + """ + 将 MessageInfo 对象转换为字典格式 + + Args: + message_info: MessageInfo 对象 + + Returns: + 字典格式的消息信息 + """ + user_info_dict = UserInfoDict( + user_id=message_info.user_info.user_id, + user_nickname=message_info.user_info.user_nickname, + user_cardname=message_info.user_info.user_cardname, + ) + + group_info_dict: Optional[GroupInfoDict] = None + if message_info.group_info: + group_info_dict = GroupInfoDict( + group_id=message_info.group_info.group_id, + group_name=message_info.group_info.group_name, + ) + + return MessageInfoDict( + user_info=user_info_dict, + group_info=group_info_dict, + additional_config=message_info.additional_config, + ) + + def _session_message_to_dict(self, session_message: SessionMessage) -> MessageDict: + """ + 将 SessionMessage 对象转换为字典格式(复用 MessageSequence.to_dict 方法) + + Args: + session_message: SessionMessage 对象 + + Returns: + 字典格式的消息 + """ + # 转换基本信息 + message_dict = MessageDict( + message_id=session_message.message_id, + timestamp=str(session_message.timestamp.timestamp()), # 转换为时间戳字符串 + platform=session_message.platform, + message_info=self._message_info_to_dict(session_message.message_info), + raw_message=session_message.raw_message.to_dict(), # 复用 MessageSequence.to_dict() + is_mentioned=session_message.is_mentioned, + is_at=session_message.is_at, + is_emoji=session_message.is_emoji, + is_picture=session_message.is_picture, + is_command=session_message.is_command, + is_notify=session_message.is_notify, + session_id=session_message.session_id, + ) + + # 添加可选字段 + if session_message.reply_to is not None: + message_dict["reply_to"] = session_message.reply_to + if session_message.processed_plain_text is not None: + message_dict["processed_plain_text"] = session_message.processed_plain_text + if session_message.display_message is not None: + message_dict["display_message"] = session_message.display_message + + return message_dict + + def _build_message_info_from_dict(self, message_info_dict: Dict[str, Any]) -> MessageInfo: + """ + 从字典构建 MessageInfo 对象 + + Args: + message_info_dict: 包含消息信息的字典 + + Returns: + MessageInfo 对象 + """ + # 构建用户信息 + user_info_dict = message_info_dict.get("user_info") + if not user_info_dict or not isinstance(user_info_dict, dict): + raise ValueError("消息字典中 'user_info' 字段无效") + user_id = user_info_dict.get("user_id") + user_nickname = user_info_dict.get("user_nickname") + user_cardname = user_info_dict.get("user_cardname") + if not isinstance(user_id, str) or not isinstance(user_nickname, str) or not user_id or not user_nickname: + raise ValueError("消息字典中 'user_info' 字段缺少有效的 'user_id' 或 'user_nickname'") + user_cardname = str(user_cardname) if user_cardname is not None else None + user_info = UserInfo(user_id=user_id, user_nickname=user_nickname, user_cardname=user_cardname) + + # 构建群信息 + if group_info_dict := message_info_dict.get("group_info"): + group_id = group_info_dict.get("group_id") + group_name = group_info_dict.get("group_name") + if not isinstance(group_id, str) or not isinstance(group_name, str) or not group_id or not group_name: + raise ValueError("消息字典中 'group_info' 字段缺少有效的 'group_id' 或 'group_name'") + group_info = GroupInfo(group_id=group_id, group_name=group_name) + else: + group_info = None + + # 获取额外配置 + additional_config: Dict[str, Any] = message_info_dict.get("additional_config", {}) + + return MessageInfo(user_info=user_info, group_info=group_info, additional_config=additional_config) + + def _build_session_message_from_dict(self, message_dict: Dict[str, Any]) -> SessionMessage: + """ + 从字典构建 SessionMessage 对象(递归处理消息组件) + + Args: + message_dict: 包含消息完整信息的字典 + + Returns: + SessionMessage 对象 + """ + # 提取基本信息 + message_id = message_dict["message_id"] + timestamp_str: str = message_dict.get("timestamp", "") + platform = message_dict["platform"] + if not isinstance(message_id, str) or not message_id: + raise ValueError("消息字典中缺少有效的 'message_id' 字段") + if not isinstance(platform, str) or not platform: + raise ValueError("消息字典中缺少有效的 'platform' 字段") + + # 解析时间戳 + try: + timestamp_float = float(timestamp_str) + timestamp = datetime.fromtimestamp(timestamp_float) + except (ValueError, TypeError): + timestamp = datetime.now() # 如果解析失败,使用当前时间 + + # 创建 SessionMessage 实例 + session_message = SessionMessage(message_id=message_id, timestamp=timestamp, platform=platform) + + # 构建消息信息 + session_message.message_info = self._build_message_info_from_dict(message_dict["message_info"]) + + # 构建原始消息组件序列(复用 MessageSequence.from_dict 方法) + raw_message_data = message_dict["raw_message"] + if isinstance(raw_message_data, list): + session_message.raw_message = MessageSequence.from_dict(raw_message_data) + else: + raise ValueError("消息字典中 'raw_message' 字段必须是一个列表") + + # 设置其他可选属性 + session_message.is_mentioned = message_dict.get("is_mentioned", False) + if not isinstance(session_message.is_mentioned, bool): + session_message.is_mentioned = False + session_message.is_at = message_dict.get("is_at", False) + if not isinstance(session_message.is_at, bool): + session_message.is_at = False + session_message.is_emoji = message_dict.get("is_emoji", False) + if not isinstance(session_message.is_emoji, bool): + session_message.is_emoji = False + session_message.is_picture = message_dict.get("is_picture", False) + if not isinstance(session_message.is_picture, bool): + session_message.is_picture = False + session_message.is_command = message_dict.get("is_command", False) + if not isinstance(session_message.is_command, bool): + session_message.is_command = False + session_message.is_notify = message_dict.get("is_notify", False) + if not isinstance(session_message.is_notify, bool): + session_message.is_notify = False + session_message.reply_to = message_dict.get("reply_to") + if session_message.reply_to is not None and not isinstance(session_message.reply_to, str): + session_message.reply_to = None + session_message.processed_plain_text = message_dict.get("processed_plain_text") + if session_message.processed_plain_text is not None and not isinstance( + session_message.processed_plain_text, str + ): + session_message.processed_plain_text = None + session_message.display_message = message_dict.get("display_message") + if session_message.display_message is not None and not isinstance(session_message.display_message, str): + session_message.display_message = None + + return session_message From 593400c0aacc543fe8385bd10fcabe76dc184ee4 Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Wed, 18 Mar 2026 20:24:07 +0800 Subject: [PATCH 11/45] =?UTF-8?q?bot.py=E6=94=AF=E6=8C=81gateway=E7=9A=84?= =?UTF-8?q?=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/message_receive/bot.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py index 60586406..23e7de6e 100644 --- a/src/chat/message_receive/bot.py +++ b/src/chat/message_receive/bot.py @@ -310,6 +310,14 @@ class ChatBot: # logger.debug(str(message_data)) maim_raw_message = MessageBase.from_dict(message_data) message = SessionMessage.from_maim_message(maim_raw_message) + await self.receive_message(message) + + except Exception as e: + logger.error(f"预处理消息失败: {e}") + traceback.print_exc() + + async def receive_message(self, message: SessionMessage): + try: group_info = message.message_info.group_info user_info = message.message_info.user_info @@ -366,11 +374,11 @@ class ChatBot: # 命令处理 - 使用新插件系统检查并处理命令 # 注意:命令返回的 response 当前只用于日志记录和流程判断, # 不会在这里自动作为回复消息发送回会话。 - is_command, cmd_result, continue_process = await self._process_commands(message) + # is_command, cmd_result, continue_process = await self._process_commands(message) - # 如果是命令且不需要继续处理,则直接返回 - if is_command and await self._handle_command_processing_result(message, cmd_result, continue_process): - return + # # 如果是命令且不需要继续处理,则直接返回 + # if is_command and await self._handle_command_processing_result(message, cmd_result, continue_process): + # return # continue_flag, modified_message = await events_manager.handle_mai_events(EventType.ON_MESSAGE, message) # if not continue_flag: From 310d7798ba81cfd01981dee9dc87ef8100a2e2c9 Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Thu, 19 Mar 2026 00:35:19 +0800 Subject: [PATCH 12/45] =?UTF-8?q?refactor:=20hook=5Fdispatcher=E7=9B=B8?= =?UTF-8?q?=E5=85=B3=E7=9A=84=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/common/logger_color_and_mapping.py | 2 +- src/config/official_configs.py | 6 +- src/plugin_runtime/host/component_registry.py | 27 +- src/plugin_runtime/host/event_dispatcher.py | 8 +- src/plugin_runtime/host/hook_dispatcher.py | 166 +++++++ src/plugin_runtime/host/workflow_executor.py | 422 ------------------ src/plugin_runtime/protocol/envelope.py | 4 +- 7 files changed, 193 insertions(+), 442 deletions(-) create mode 100644 src/plugin_runtime/host/hook_dispatcher.py delete mode 100644 src/plugin_runtime/host/workflow_executor.py diff --git a/src/common/logger_color_and_mapping.py b/src/common/logger_color_and_mapping.py index 1aabafbc..f84caa21 100644 --- a/src/common/logger_color_and_mapping.py +++ b/src/common/logger_color_and_mapping.py @@ -58,7 +58,7 @@ MODULE_COLORS: Dict[str, Tuple[str, Optional[str], bool]] = { "plugin_runtime.host.component_registry": ("#ffaf00", None, False), "plugin_runtime.host.capability_service": ("#ffd700", None, False), "plugin_runtime.host.event_dispatcher": ("#87d700", None, False), - "plugin_runtime.host.workflow_executor": ("#5fd7af", None, False), + "plugin_runtime.host.hook_dispatcher": ("#5fd7af", None, False), "plugin_runtime.runner.main": ("#d787ff", None, False), "plugin_runtime.runner.rpc_client": ("#8787ff", None, False), "plugin_runtime.runner.manifest_validator": ("#5fafff", None, False), diff --git a/src/config/official_configs.py b/src/config/official_configs.py index e6907611..35360217 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -1642,14 +1642,14 @@ class PluginRuntimeConfig(ConfigBase): ) """等待 Runner 子进程启动并注册的超时时间(秒)""" - workflow_blocking_timeout_sec: float = Field( - default=120.0, + hook_blocking_timeout_sec: float = Field( + default=30, json_schema_extra={ "x-widget": "number", "x-icon": "timer", }, ) - """Workflow 阻塞步骤的全局超时上限(秒)""" + """Hook 阻塞步骤的全局超时上限(秒)""" ipc_socket_path: str = Field( default="", diff --git a/src/plugin_runtime/host/component_registry.py b/src/plugin_runtime/host/component_registry.py index 7ac7a518..95da0052 100644 --- a/src/plugin_runtime/host/component_registry.py +++ b/src/plugin_runtime/host/component_registry.py @@ -25,7 +25,7 @@ class ComponentTypes(str, Enum): COMMAND = "COMMAND" TOOL = "TOOL" EVENT_HANDLER = "EVENT_HANDLER" - WORKFLOW_HANDLER = "WORKFLOW_HANDLER" + HOOK_HANDLER = "HOOK_HANDLER" MESSAGE_GATEWAY = "MESSAGE_GATEWAY" @@ -35,7 +35,7 @@ class StatusDict(TypedDict): COMMAND: int TOOL: int EVENT_HANDLER: int - WORKFLOW_HANDLER: int + HOOK_HANDLER: int MESSAGE_GATEWAY: int plugins: int @@ -105,12 +105,13 @@ class EventHandlerEntry(ComponentEntry): super().__init__(name, component_type, plugin_id, metadata) -class WorkflowHandlerEntry(ComponentEntry): +class HookHandlerEntry(ComponentEntry): """WorkflowHandler 组件条目""" def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None: self.stage: str = metadata.get("stage", "") self.priority: int = metadata.get("priority", 0) + self.blocking: bool = metadata.get("blocking", False) super().__init__(name, component_type, plugin_id, metadata) @@ -172,8 +173,8 @@ class ComponentRegistry: comp = ToolEntry(name, component_type, plugin_id, metadata) elif component_type == ComponentTypes.EVENT_HANDLER: comp = EventHandlerEntry(name, component_type, plugin_id, metadata) - elif component_type == ComponentTypes.WORKFLOW_HANDLER: - comp = WorkflowHandlerEntry(name, component_type, plugin_id, metadata) + elif component_type == ComponentTypes.HOOK_HANDLER: + comp = HookHandlerEntry(name, component_type, plugin_id, metadata) elif component_type == ComponentTypes.MESSAGE_GATEWAY: comp = MessageGatewayEntry(name, component_type, plugin_id, metadata) else: @@ -380,23 +381,23 @@ class ComponentRegistry: handlers.sort(key=lambda c: c.weight, reverse=True) return handlers - def get_workflow_handlers( + def get_hook_handlers( self, stage: str, *, enabled_only: bool = True, session_id: Optional[str] = None - ) -> List[WorkflowHandlerEntry]: - """获取特定 workflow 阶段的所有步骤,按 priority 降序。 + ) -> List[HookHandlerEntry]: + """获取特定 hook 阶段的所有步骤,按 priority 降序。 Args: - stage: workflow 阶段名称 + stage: hook 名称 enabled_only: 是否仅返回启用的组件 session_id: 可选的会话ID,若提供则考虑会话禁用状态 Returns: - handlers (List[WorkflowHandlerEntry]): 符合条件的 WorkflowHandler 组件列表,按 priority 降序排序 + handlers (List[HookHandlerEntry]): 符合条件的 HookHandler 组件列表,按 priority 降序排序 """ - handlers: List[WorkflowHandlerEntry] = [] - for comp in self._by_type.get(ComponentTypes.WORKFLOW_HANDLER, {}).values(): + handlers: List[HookHandlerEntry] = [] + for comp in self._by_type.get(ComponentTypes.HOOK_HANDLER, {}).values(): if enabled_only and not self.check_component_enabled(comp, session_id): continue - if not isinstance(comp, WorkflowHandlerEntry): + if not isinstance(comp, HookHandlerEntry): continue if comp.stage == stage: handlers.append(comp) diff --git a/src/plugin_runtime/host/event_dispatcher.py b/src/plugin_runtime/host/event_dispatcher.py index f08591a8..29ae530b 100644 --- a/src/plugin_runtime/host/event_dispatcher.py +++ b/src/plugin_runtime/host/event_dispatcher.py @@ -50,7 +50,6 @@ class EventDispatcher: self._component_registry: "ComponentRegistry" = component_registry self._result_history: Dict[str, List[EventResult]] = {} self._history_enabled: Set[str] = set() - # 保持 fire-and-forget task 的强引用,防止被 GC 回收 self._background_tasks: Set[asyncio.Task] = set() def enable_history(self, event_type: str) -> None: @@ -68,6 +67,13 @@ class EventDispatcher: if event_type in self._result_history: self._result_history[event_type] = [] + async def stop(self): + """停止 EventDispatcher,取消所有未完成的后台任务""" + for task in self._background_tasks: + task.cancel() + await asyncio.gather(*self._background_tasks, return_exceptions=True) + self._background_tasks.clear() + async def dispatch_event( self, event_type: str, diff --git a/src/plugin_runtime/host/hook_dispatcher.py b/src/plugin_runtime/host/hook_dispatcher.py new file mode 100644 index 00000000..d5e88448 --- /dev/null +++ b/src/plugin_runtime/host/hook_dispatcher.py @@ -0,0 +1,166 @@ +""" +Hook Dispatch 系统 + +插件可以注册自己的Hook,当特定函数被调用时,Hook Dispatch系统会将调用转发给插件的Hook处理函数。 +每个Hook的参数随Hook点位确定,因此参数是易变的。插件开发者需要根据Hook点位的定义来编写Hook处理函数。 +在参数/返回值匹配的情况下允许修改参数/返回值。 + +HookDispatcher 负责: +1. 按 stage 查询已注册的 hook_handler(通过 ComponentRegistry) +2. 按 priority 排序,区分 blocking 和非 blocking 模式 +3. blocking 模式:依次同步调用,支持修改参数/提前终止 +4. 非 blocking 模式:异步调用,不阻塞主流程 +5. 支持通过 global_config.plugin_runtime.hook_blocking_timeout_sec 设置超时上限 +""" + +import asyncio +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING + +from src.common.logger import get_logger +from src.config.config import global_config + + +if TYPE_CHECKING: + from .supervisor import PluginRunnerSupervisor + from .component_registry import ComponentRegistry, HookHandlerEntry + +logger = get_logger("plugin_runtime.host.hook_dispatcher") + + +@dataclass +class HookResult: + """单个 HookHandler 的执行结果""" + + handler_name: str + success: bool = field(default=True) + continue_processing: bool = field(default=True) + modified_kwargs: Optional[Dict[str, Any]] = field(default=None) + custom_result: Any = field(default=None) + + +class HookDispatcher: + """Host-side Hook 分发器 + + 由业务层调用 hook_dispatch(), + 内部通过 ComponentRegistry 查询 handler, + 再通过提供的 invoke_fn 回调 RPC 到 Runner 执行。 + """ + + def __init__(self, component_registry: "ComponentRegistry") -> None: + """初始化 HookDispatcher + + Args: + component_registry: ComponentRegistry 实例,用于查询已注册的 hook_handler + """ + self._component_registry: "ComponentRegistry" = component_registry + self._background_tasks: Set[asyncio.Task] = set() + + async def stop(self) -> None: + """停止 HookDispatcher,取消所有未完成的后台任务""" + for task in self._background_tasks: + task.cancel() + await asyncio.gather(*self._background_tasks, return_exceptions=True) + self._background_tasks.clear() + + async def hook_dispatch( + self, + stage: str, + supervisor: "PluginRunnerSupervisor", + **kwargs: Any, + ) -> Dict[str, Any]: + """分发 hook 到所有对应 handler 的便捷方法。 + + 内置了通过 PluginRunnerSupervisor.invoke_plugin 调用 plugin 的逻辑, + 无需调用方手动构造 invoke_fn 闭包。 + + Args: + stage: hook 名称 + supervisor: PluginRunnerSupervisor 实例,用于调用 invoke_plugin + **kwargs: 关键字参数,会展开传递给 handler + + Returns: + modified_kwargs (Dict[str, Any]): 经过所有 handler 修改后的关键字参数 + """ + handler_entries = self._component_registry.get_hook_handlers(stage) + if not handler_entries: + return kwargs + + current_kwargs = kwargs.copy() + blocking_handlers: List["HookHandlerEntry"] = [] + non_blocking_handlers: List["HookHandlerEntry"] = [] + + # 分离 blocking 和非 blocking handler + for entry in handler_entries: + if entry.blocking: + blocking_handlers.append(entry) + else: + non_blocking_handlers.append(entry) + + # 处理 blocking handlers(同步调用,支持修改参数/提前终止) + timeout = global_config.plugin_runtime.hook_blocking_timeout_sec or 30.0 + for entry in blocking_handlers: + hook_args = {"stage": stage, **current_kwargs} + try: + # 应用超时控制 + result = await asyncio.wait_for( + self._invoke_handler(supervisor, entry, hook_args), + timeout=timeout, + ) + except asyncio.TimeoutError: + logger.error(f"Blocking HookHandler {entry.full_name} 执行超时 (>{timeout}秒),跳过") + result = HookResult(handler_name=entry.full_name, success=False, continue_processing=True) + + if result: + if result.modified_kwargs is not None: + current_kwargs = result.modified_kwargs + if not result.continue_processing: + logger.info(f"HookHandler {entry.full_name} 终止了后续处理") + break + + # 处理 non-blocking handlers(异步调用,不阻塞主流程) + for entry in non_blocking_handlers: + async_kwargs = current_kwargs.copy() + hook_args = {"stage": stage, **async_kwargs} + task = asyncio.create_task( + asyncio.wait_for(self._invoke_handler(supervisor, entry, hook_args), timeout=timeout) + ) + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.discard) + + return current_kwargs + + async def _invoke_handler( + self, + supervisor: "PluginRunnerSupervisor", + handler_entry: "HookHandlerEntry", + args: Dict[str, Any], + ) -> Optional[HookResult]: + """调用单个 handler 并收集结果。 + + Args: + supervisor: PluginRunnerSupervisor 实例 + handler_entry: HookHandlerEntry 实例 + args: 传递给 handler 的参数字典 + stage: hook 名称 + + Returns: + Optional[HookResult]: 执行结果,如果执行失败则返回 None + """ + try: + resp_envelope = await supervisor.invoke_plugin( + "plugin.invoke_hook", handler_entry.plugin_id, handler_entry.name, args + ) + resp = resp_envelope.payload + result = HookResult( + handler_name=handler_entry.full_name, + success=resp.get("success", True), + continue_processing=resp.get("continue_processing", True), + modified_kwargs=resp.get("modified_kwargs"), + custom_result=resp.get("custom_result"), + ) + except Exception as e: + logger.error(f"HookHandler {handler_entry.full_name} 执行失败:{e}", exc_info=True) + result = HookResult(handler_name=handler_entry.full_name, success=False, continue_processing=True) + + return result diff --git a/src/plugin_runtime/host/workflow_executor.py b/src/plugin_runtime/host/workflow_executor.py deleted file mode 100644 index 3037e9dd..00000000 --- a/src/plugin_runtime/host/workflow_executor.py +++ /dev/null @@ -1,422 +0,0 @@ -"""Host-side WorkflowExecutor - -6 阶段线性流转(INGRESS → PRE_PROCESS → PLAN → TOOL_EXECUTE → POST_PROCESS → EGRESS) - -每个阶段执行顺序: -1. Host-side pre-filter: 根据 hook filter 条件过滤不相关的 hook -2. 按 priority 降序排列 -3. 串行执行 blocking hook(可修改 message,返回 HookResult) -4. 并发执行 non-blocking hook(只读) -5. 检查是否有 SKIP_STAGE 或 ABORT -6. PLAN 阶段内置 Command 匹配路由 - -支持: -- HookResult: CONTINUE / SKIP_STAGE / ABORT -- ErrorPolicy: ABORT / SKIP / LOG (per-hook) -- stage_outputs: 阶段间带命名空间的数据传递 -- modification_log: 消息修改审计 -""" - -from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple - -import asyncio -import time -import uuid - -from src.common.logger import get_logger -from src.config.config import global_config -from src.plugin_runtime.host.component_registry import ComponentRegistry, RegisteredComponent - -logger = get_logger("plugin_runtime.host.workflow_executor") - -# 阶段顺序 -STAGE_SEQUENCE: List[str] = [ - "ingress", - "pre_process", - "plan", - "tool_execute", - "post_process", - "egress", -] - -# HookResult 常量(与 SDK HookResult enum 值对应) -HOOK_CONTINUE = "continue" -HOOK_SKIP_STAGE = "skip_stage" -HOOK_ABORT = "abort" - - -# blocking hook 全局最大超时(秒):即使 hook 声明 timeout_ms=0 也不会无限等待 -# 从配置文件读取,允许用户调整 -def _get_blocking_timeout() -> float: - return global_config.plugin_runtime.workflow_blocking_timeout_sec - - -class ModificationRecord: - """消息修改记录""" - - __slots__ = ("stage", "hook_name", "timestamp", "fields_changed") - - def __init__(self, stage: str, hook_name: str, fields_changed: List[str]) -> None: - self.stage = stage - self.hook_name = hook_name - self.timestamp = time.perf_counter() - self.fields_changed = fields_changed - - -class WorkflowContext: - """Workflow 执行上下文""" - - def __init__(self, trace_id: Optional[str] = None, stream_id: Optional[str] = None) -> None: - self.trace_id = trace_id or uuid.uuid4().hex - self.stream_id = stream_id - self.timings: Dict[str, float] = {} - self.errors: List[str] = [] - # 阶段间数据传递(按 stage 命名空间隔离) - self.stage_outputs: Dict[str, Dict[str, Any]] = {} - # 消息修改审计日志 - self.modification_log: List[ModificationRecord] = [] - # PLAN 阶段命令匹配结果 - self.matched_command: Optional[str] = None - - def set_stage_output(self, stage: str, key: str, value: Any) -> None: - self.stage_outputs.setdefault(stage, {})[key] = value - - def get_stage_output(self, stage: str, key: str, default: Any = None) -> Any: - return self.stage_outputs.get(stage, {}).get(key, default) - - -class WorkflowResult: - """Workflow 执行结果""" - - def __init__( - self, - status: str = "completed", # completed / aborted / failed - return_message: str = "", - stopped_at: str = "", - diagnostics: Optional[Dict[str, Any]] = None, - ) -> None: - self.status = status - self.return_message = return_message - self.stopped_at = stopped_at - self.diagnostics = diagnostics or {} - - -# invoke_fn 签名 -InvokeFn = Callable[[str, str, Dict[str, Any]], Awaitable[Dict[str, Any]]] - - -class WorkflowExecutor: - """Host-side Workflow 执行器 - - 实现 stage-based pipeline + per-stage hook chain with priority + early return。 - """ - - def __init__(self, registry: ComponentRegistry) -> None: - self._registry = registry - self._background_tasks: Set[asyncio.Task] = set() - - async def execute( - self, - invoke_fn: InvokeFn, - message: Optional[Dict[str, Any]] = None, - stream_id: Optional[str] = None, - context: Optional[WorkflowContext] = None, - command_invoke_fn: Optional[InvokeFn] = None, - ) -> Tuple[WorkflowResult, Optional[Dict[str, Any]], WorkflowContext]: - """执行 workflow pipeline。 - - Args: - invoke_fn: 用于 workflow_step 的回调 - command_invoke_fn: 用于 command 的回调(走 plugin.invoke_command), - 未传则复用 invoke_fn - - Returns: - (result, final_message, context) - """ - ctx = context or WorkflowContext(stream_id=stream_id) - current_message = dict(message) if message else None - - for stage in STAGE_SEQUENCE: - stage_start = time.perf_counter() - - try: - # PLAN 阶段: 先做 Command 路由 - if stage == "plan" and current_message: - cmd_result = await self._route_command(command_invoke_fn or invoke_fn, current_message, ctx) - if cmd_result is not None: - # 命令匹配成功,跳过 PLAN 阶段的 hook,直接存结果进 stage_outputs - ctx.set_stage_output("plan", "command_result", cmd_result) - ctx.timings[stage] = time.perf_counter() - stage_start - continue - - # 获取该阶段所有 hook(已按 priority 降序排列) - all_steps = self._registry.get_workflow_steps(stage) - if not all_steps: - ctx.timings[stage] = time.perf_counter() - stage_start - continue - - # 1. Pre-filter - filtered_steps = self._pre_filter(all_steps, current_message) - - # 2. 分离 blocking 和 non-blocking - blocking_steps = [s for s in filtered_steps if s.metadata.get("blocking", True)] - nonblocking_steps = [s for s in filtered_steps if not s.metadata.get("blocking", True)] - - # 3. 串行执行 blocking hook - skip_stage = False - for step in blocking_steps: - hook_result, modified, step_error = await self._invoke_step( - invoke_fn, step, stage, ctx, current_message - ) - - if step_error: - error_policy = step.metadata.get("error_policy", "abort") - ctx.errors.append(f"{step.full_name}: {step_error}") - - if error_policy == "abort": - ctx.timings[stage] = time.perf_counter() - stage_start - return ( - WorkflowResult( - status="failed", - return_message=step_error, - stopped_at=stage, - diagnostics={"step": step.full_name, "trace_id": ctx.trace_id}, - ), - current_message, - ctx, - ) - elif error_policy == "skip": - logger.warning(f"[{ctx.trace_id}] hook {step.full_name} 异常(skip): {step_error}") - continue - else: # log - logger.warning(f"[{ctx.trace_id}] hook {step.full_name} 异常(log): {step_error}") - continue - - # 更新消息(仅 blocking hook 有权修改) - if modified: - changed_fields = ( - _diff_keys(current_message, modified) if current_message else list(modified.keys()) - ) - ctx.modification_log.append(ModificationRecord(stage, step.full_name, changed_fields)) - current_message = modified - - if hook_result == HOOK_ABORT: - ctx.timings[stage] = time.perf_counter() - stage_start - return ( - WorkflowResult( - status="aborted", - return_message=f"aborted by {step.full_name}", - stopped_at=stage, - diagnostics={"step": step.full_name, "trace_id": ctx.trace_id}, - ), - current_message, - ctx, - ) - - if hook_result == HOOK_SKIP_STAGE: - skip_stage = True - break - - # 4. 并发执行 non-blocking hook(只读,忽略返回值中的 modified_message) - if nonblocking_steps and not skip_stage: - for step in nonblocking_steps: - self._track_background_task( - asyncio.create_task( - self._invoke_step_fire_and_forget(invoke_fn, step, stage, ctx, current_message) - ) - ) - - ctx.timings[stage] = time.perf_counter() - stage_start - - except Exception as e: - ctx.timings[stage] = time.perf_counter() - stage_start - ctx.errors.append(f"{stage}: {e}") - logger.error(f"[{ctx.trace_id}] 阶段 {stage} 未捕获异常: {e}", exc_info=True) - return ( - WorkflowResult( - status="failed", - return_message=str(e), - stopped_at=stage, - diagnostics={"trace_id": ctx.trace_id}, - ), - current_message, - ctx, - ) - - return ( - WorkflowResult( - status="completed", - return_message="workflow completed", - diagnostics={"trace_id": ctx.trace_id}, - ), - current_message, - ctx, - ) - - def _track_background_task(self, task: asyncio.Task) -> None: - """保持 non-blocking workflow task 的强引用,直到任务结束。""" - self._background_tasks.add(task) - task.add_done_callback(self._background_tasks.discard) - - # ─── 内部方法 ────────────────────────────────────────────── - - def _pre_filter( - self, - steps: List[RegisteredComponent], - message: Optional[Dict[str, Any]], - ) -> List[RegisteredComponent]: - """根据 hook 声明的 filter 条件预过滤,避免无意义的 IPC 调用。""" - if not message: - return steps - - result = [] - for step in steps: - filter_cond = step.metadata.get("filter", {}) - if not filter_cond: - result.append(step) - continue - if self._match_filter(filter_cond, message): - result.append(step) - return result - - @staticmethod - def _match_filter(filter_cond: Dict[str, Any], message: Dict[str, Any]) -> bool: - """简单 key-value 匹配过滤。 - - filter 中的每个 key 必须在 message 中存在且值相等, - 全部匹配才通过。 - """ - for key, expected in filter_cond.items(): - actual = message.get(key) - if (isinstance(expected, list) and actual not in expected) or ( - not isinstance(expected, list) and actual != expected - ): - return False - return True - - async def _invoke_step( - self, - invoke_fn: InvokeFn, - step: RegisteredComponent, - stage: str, - ctx: WorkflowContext, - message: Optional[Dict[str, Any]], - ) -> Tuple[str, Optional[Dict[str, Any]], Optional[str]]: - """调用单个 blocking hook。 - - Returns: - (hook_result, modified_message, error_string_or_None) - """ - timeout_ms = step.metadata.get("timeout_ms", 0) - # 使用 hook 声明的超时,但不超过全局安全阀 - timeout_sec = timeout_ms / 1000 if timeout_ms > 0 else _get_blocking_timeout() - step_key = f"{stage}:{step.full_name}" - step_start = time.perf_counter() - - try: - coro = invoke_fn( - step.plugin_id, - step.name, - { - "stage": stage, - "trace_id": ctx.trace_id, - "message": message, - "stage_outputs": ctx.stage_outputs, - }, - ) - resp = await asyncio.wait_for(coro, timeout=timeout_sec) - ctx.timings[step_key] = time.perf_counter() - step_start - - hook_result = resp.get("hook_result", HOOK_CONTINUE) - modified_message = resp.get("modified_message") - # 存 stage output(如果 hook 提供了) - stage_out = resp.get("stage_output") - if isinstance(stage_out, dict): - for k, v in stage_out.items(): - ctx.set_stage_output(stage, k, v) - - return hook_result, modified_message, None - - except asyncio.TimeoutError: - ctx.timings[step_key] = time.perf_counter() - step_start - return HOOK_CONTINUE, None, f"timeout after {timeout_ms}ms" - - except Exception as e: - ctx.timings[step_key] = time.perf_counter() - step_start - return HOOK_CONTINUE, None, str(e) - - async def _invoke_step_fire_and_forget( - self, - invoke_fn: InvokeFn, - step: RegisteredComponent, - stage: str, - ctx: WorkflowContext, - message: Optional[Dict[str, Any]], - ) -> None: - """Non-blocking hook 调用,只读,忽略结果。""" - timeout_ms = step.metadata.get("timeout_ms", 0) - # 使用 hook 声明的超时,但无声明时回退到全局安全阀,防止 task 泄漏 - timeout_sec = timeout_ms / 1000 if timeout_ms > 0 else _get_blocking_timeout() - - try: - coro = invoke_fn( - step.plugin_id, - step.name, - { - "stage": stage, - "trace_id": ctx.trace_id, - "message": message, - "stage_outputs": ctx.stage_outputs, - }, - ) - await asyncio.wait_for(coro, timeout=timeout_sec) - except asyncio.TimeoutError: - logger.warning(f"[{ctx.trace_id}] non-blocking hook {step.full_name} 超时 ({timeout_sec}s)") - except Exception as e: - logger.debug(f"[{ctx.trace_id}] non-blocking hook {step.full_name}: {e}") - - async def _route_command( - self, - invoke_fn: InvokeFn, - message: Dict[str, Any], - ctx: WorkflowContext, - ) -> Optional[Dict[str, Any]]: - """PLAN 阶段内置 Command 路由。 - - 在 registry 中查找匹配的 command 组件, - 匹配到则直接路由到对应 command handler,返回执行结果。 - 不匹配则返回 None,让 PLAN 阶段的 hook 继续执行。 - """ - plain_text = message.get("plain_text", "") - if not plain_text: - return None - - match_result = self._registry.find_command_by_text(plain_text) - if match_result is None: - return None - - matched, matched_groups = match_result - - ctx.matched_command = matched.full_name - logger.info(f"[{ctx.trace_id}] 命令匹配: {matched.full_name}") - - try: - return await invoke_fn( - matched.plugin_id, - matched.name, - { - "text": plain_text, - "message": message, - "trace_id": ctx.trace_id, - "matched_groups": matched_groups, - }, - ) - except Exception as e: - logger.error(f"[{ctx.trace_id}] 命令 {matched.full_name} 执行失败: {e}", exc_info=True) - ctx.errors.append(f"command:{matched.full_name}: {e}") - return None - - -def _diff_keys(old: Dict[str, Any], new: Dict[str, Any]) -> List[str]: - """返回 new 中与 old 不同的 key 列表。""" - return [k for k, v in new.items() if k not in old or old[k] != v] diff --git a/src/plugin_runtime/protocol/envelope.py b/src/plugin_runtime/protocol/envelope.py index ca9e8005..e81df019 100644 --- a/src/plugin_runtime/protocol/envelope.py +++ b/src/plugin_runtime/protocol/envelope.py @@ -134,9 +134,9 @@ class ComponentDeclaration(BaseModel): name: str = Field(description="组件名称") """组件名称""" component_type: str = Field( - description="组件类型:action/command/tool/event_handler/workflow_handler/message_gateway" + description="组件类型:action/command/tool/event_handler/hook_handler/message_gateway" ) - """组件类型:`action`/`command`/`tool`/`event_handler`/`workflow_handler`/`message_gateway`""" + """组件类型:`action`/`command`/`tool`/`event_handler`/`hook_handler`/`message_gateway`""" plugin_id: str = Field(description="所属插件 ID") """所属插件 ID""" metadata: Dict[str, Any] = Field(default_factory=dict, description="组件元数据") From 6cc7e37b1e91eae0ed439b7e7c343ac8d32c4e00 Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Thu, 19 Mar 2026 20:30:09 +0800 Subject: [PATCH 13/45] =?UTF-8?q?refactor:=20=E6=8F=90=E5=8F=96=E9=83=A8?= =?UTF-8?q?=E5=88=86=E5=85=B1=E5=90=8C=E6=96=B9=E6=B3=95=EF=BC=8C=E9=A2=84?= =?UTF-8?q?=E5=A4=87supervisor?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/plugin_runtime/host/event_dispatcher.py | 24 ++- src/plugin_runtime/host/message_gateway.py | 223 +------------------ src/plugin_runtime/host/message_utils.py | 224 ++++++++++++++++++++ 3 files changed, 247 insertions(+), 224 deletions(-) create mode 100644 src/plugin_runtime/host/message_utils.py diff --git a/src/plugin_runtime/host/event_dispatcher.py b/src/plugin_runtime/host/event_dispatcher.py index 29ae530b..d252b6ee 100644 --- a/src/plugin_runtime/host/event_dispatcher.py +++ b/src/plugin_runtime/host/event_dispatcher.py @@ -14,10 +14,12 @@ import asyncio from src.common.logger import get_logger +from .message_utils import PluginMessageUtils, MessageDict if TYPE_CHECKING: from .supervisor import PluginRunnerSupervisor from .component_registry import ComponentRegistry, EventHandlerEntry + from src.chat.message_receive.message import SessionMessage logger = get_logger("plugin_runtime.host.event_dispatcher") @@ -34,7 +36,7 @@ class EventResult: handler_name: str success: bool = field(default=True) continue_processing: bool = field(default=True) - modified_message: Optional[Dict[str, Any]] = field(default=None) + modified_message: Optional[MessageDict] = field(default=None) custom_result: Any = field(default=None) @@ -78,9 +80,9 @@ class EventDispatcher: self, event_type: str, supervisor: "PluginRunnerSupervisor", - message_dict: Optional[Dict[str, Any]] = None, + message: Optional["SessionMessage"] = None, extra_args: Optional[Dict[str, Any]] = None, - ) -> Tuple[bool, Optional[Dict[str, Any]]]: + ) -> Tuple[bool, Optional["SessionMessage"]]: """分发事件到所有对应 handler 的便捷方法。 内置了通过 PluginSupervisor.invoke_plugin 调用 plugin.emit_event 的逻辑, @@ -93,14 +95,16 @@ class EventDispatcher: extra_args: 额外参数 Returns: - (should_continue, modified_message_dict) (bool, Dict[str, Any] | None): (是否继续后续执行, 可选的修改后的消息字典) + (should_continue, modified_message_dict) (bool, SessionMessage | None): (是否继续后续执行, 可选的修改后的消息) """ handler_entries = self._component_registry.get_event_handlers(event_type) if not handler_entries: return True, None should_continue = True - modified_message: Optional[Dict[str, Any]] = message_dict + modified_message: Optional[MessageDict] = ( + PluginMessageUtils._session_message_to_dict(message) if message else None + ) intercept_handlers: List["EventHandlerEntry"] = [] non_blocking_handlers: List["EventHandlerEntry"] = [] @@ -136,8 +140,14 @@ class EventDispatcher: task = asyncio.create_task(self._invoke_handler(supervisor, entry, args, event_type)) self._background_tasks.add(task) task.add_done_callback(self._background_tasks.discard) - - return should_continue, modified_message + try: + modified_message_obj = ( + PluginMessageUtils._build_session_message_from_dict(modified_message) if modified_message else None # type: ignore + ) + except Exception as e: + logger.error(f"构建修改后的 SessionMessage 失败: {e}") + modified_message_obj = None + return should_continue, modified_message_obj async def _invoke_handler( self, diff --git a/src/plugin_runtime/host/message_gateway.py b/src/plugin_runtime/host/message_gateway.py index e995ed01..43777286 100644 --- a/src/plugin_runtime/host/message_gateway.py +++ b/src/plugin_runtime/host/message_gateway.py @@ -3,56 +3,19 @@ Message Gateway 模块 适配器专用,用于将其他平台的消息转换为系统内部的消息格式,并将系统消息转换为其他平台的格式。 """ -from datetime import datetime -from typing import Dict, Any, TYPE_CHECKING, TypedDict, Optional, List +from typing import Dict, Any, TYPE_CHECKING from src.common.logger import get_logger -from src.chat.message_receive.message import SessionMessage -from src.common.data_models.mai_message_data_model import UserInfo, GroupInfo, MessageInfo -from src.common.data_models.message_component_data_model import MessageSequence +from .message_utils import PluginMessageUtils if TYPE_CHECKING: + from src.chat.message_receive.message import SessionMessage from .component_registry import ComponentRegistry from .supervisor import PluginRunnerSupervisor logger = get_logger("plugin_runtime.host.message_gateway") -class UserInfoDict(TypedDict, total=False): - user_id: str - user_nickname: str - user_cardname: Optional[str] - - -class GroupInfoDict(TypedDict, total=False): - group_id: str - group_name: str - - -class MessageInfoDict(TypedDict, total=False): - user_info: UserInfoDict - group_info: Optional[GroupInfoDict] - additional_config: Dict[str, Any] - - -class MessageDict(TypedDict, total=False): - message_id: str - timestamp: str - platform: str - message_info: MessageInfoDict - raw_message: List[Dict[str, Any]] - is_mentioned: bool - is_at: bool - is_emoji: bool - is_picture: bool - is_command: bool - is_notify: bool - session_id: str - reply_to: Optional[str] - processed_plain_text: Optional[str] - display_message: Optional[str] - - class MessageGateway: def __init__(self, component_registry: "ComponentRegistry") -> None: self._component_registry = component_registry @@ -69,7 +32,7 @@ class MessageGateway: """ # 使用递归函数将外部消息字典转换为 SessionMessage try: - session_message = self._build_session_message_from_dict(external_message) + session_message = PluginMessageUtils._build_session_message_from_dict(external_message) except Exception as e: logger.error(f"转换外部消息失败: {e}") return @@ -79,7 +42,7 @@ class MessageGateway: async def send_message_to_external( self, - internal_message: SessionMessage, + internal_message: "SessionMessage", supervisor: "PluginRunnerSupervisor", *, enabled_only: bool = True, @@ -96,7 +59,7 @@ class MessageGateway: """ try: # 将 SessionMessage 转换为字典格式 - message_dict = self._session_message_to_dict(internal_message) + message_dict = PluginMessageUtils._session_message_to_dict(internal_message) except Exception as e: logger.error(f"转换内部消息失败:{e}") return False @@ -133,177 +96,3 @@ class MessageGateway: except Exception as e: logger.error(f"保存消息到数据库失败: {e}") return True - - def _message_info_to_dict(self, message_info: MessageInfo) -> MessageInfoDict: - """ - 将 MessageInfo 对象转换为字典格式 - - Args: - message_info: MessageInfo 对象 - - Returns: - 字典格式的消息信息 - """ - user_info_dict = UserInfoDict( - user_id=message_info.user_info.user_id, - user_nickname=message_info.user_info.user_nickname, - user_cardname=message_info.user_info.user_cardname, - ) - - group_info_dict: Optional[GroupInfoDict] = None - if message_info.group_info: - group_info_dict = GroupInfoDict( - group_id=message_info.group_info.group_id, - group_name=message_info.group_info.group_name, - ) - - return MessageInfoDict( - user_info=user_info_dict, - group_info=group_info_dict, - additional_config=message_info.additional_config, - ) - - def _session_message_to_dict(self, session_message: SessionMessage) -> MessageDict: - """ - 将 SessionMessage 对象转换为字典格式(复用 MessageSequence.to_dict 方法) - - Args: - session_message: SessionMessage 对象 - - Returns: - 字典格式的消息 - """ - # 转换基本信息 - message_dict = MessageDict( - message_id=session_message.message_id, - timestamp=str(session_message.timestamp.timestamp()), # 转换为时间戳字符串 - platform=session_message.platform, - message_info=self._message_info_to_dict(session_message.message_info), - raw_message=session_message.raw_message.to_dict(), # 复用 MessageSequence.to_dict() - is_mentioned=session_message.is_mentioned, - is_at=session_message.is_at, - is_emoji=session_message.is_emoji, - is_picture=session_message.is_picture, - is_command=session_message.is_command, - is_notify=session_message.is_notify, - session_id=session_message.session_id, - ) - - # 添加可选字段 - if session_message.reply_to is not None: - message_dict["reply_to"] = session_message.reply_to - if session_message.processed_plain_text is not None: - message_dict["processed_plain_text"] = session_message.processed_plain_text - if session_message.display_message is not None: - message_dict["display_message"] = session_message.display_message - - return message_dict - - def _build_message_info_from_dict(self, message_info_dict: Dict[str, Any]) -> MessageInfo: - """ - 从字典构建 MessageInfo 对象 - - Args: - message_info_dict: 包含消息信息的字典 - - Returns: - MessageInfo 对象 - """ - # 构建用户信息 - user_info_dict = message_info_dict.get("user_info") - if not user_info_dict or not isinstance(user_info_dict, dict): - raise ValueError("消息字典中 'user_info' 字段无效") - user_id = user_info_dict.get("user_id") - user_nickname = user_info_dict.get("user_nickname") - user_cardname = user_info_dict.get("user_cardname") - if not isinstance(user_id, str) or not isinstance(user_nickname, str) or not user_id or not user_nickname: - raise ValueError("消息字典中 'user_info' 字段缺少有效的 'user_id' 或 'user_nickname'") - user_cardname = str(user_cardname) if user_cardname is not None else None - user_info = UserInfo(user_id=user_id, user_nickname=user_nickname, user_cardname=user_cardname) - - # 构建群信息 - if group_info_dict := message_info_dict.get("group_info"): - group_id = group_info_dict.get("group_id") - group_name = group_info_dict.get("group_name") - if not isinstance(group_id, str) or not isinstance(group_name, str) or not group_id or not group_name: - raise ValueError("消息字典中 'group_info' 字段缺少有效的 'group_id' 或 'group_name'") - group_info = GroupInfo(group_id=group_id, group_name=group_name) - else: - group_info = None - - # 获取额外配置 - additional_config: Dict[str, Any] = message_info_dict.get("additional_config", {}) - - return MessageInfo(user_info=user_info, group_info=group_info, additional_config=additional_config) - - def _build_session_message_from_dict(self, message_dict: Dict[str, Any]) -> SessionMessage: - """ - 从字典构建 SessionMessage 对象(递归处理消息组件) - - Args: - message_dict: 包含消息完整信息的字典 - - Returns: - SessionMessage 对象 - """ - # 提取基本信息 - message_id = message_dict["message_id"] - timestamp_str: str = message_dict.get("timestamp", "") - platform = message_dict["platform"] - if not isinstance(message_id, str) or not message_id: - raise ValueError("消息字典中缺少有效的 'message_id' 字段") - if not isinstance(platform, str) or not platform: - raise ValueError("消息字典中缺少有效的 'platform' 字段") - - # 解析时间戳 - try: - timestamp_float = float(timestamp_str) - timestamp = datetime.fromtimestamp(timestamp_float) - except (ValueError, TypeError): - timestamp = datetime.now() # 如果解析失败,使用当前时间 - - # 创建 SessionMessage 实例 - session_message = SessionMessage(message_id=message_id, timestamp=timestamp, platform=platform) - - # 构建消息信息 - session_message.message_info = self._build_message_info_from_dict(message_dict["message_info"]) - - # 构建原始消息组件序列(复用 MessageSequence.from_dict 方法) - raw_message_data = message_dict["raw_message"] - if isinstance(raw_message_data, list): - session_message.raw_message = MessageSequence.from_dict(raw_message_data) - else: - raise ValueError("消息字典中 'raw_message' 字段必须是一个列表") - - # 设置其他可选属性 - session_message.is_mentioned = message_dict.get("is_mentioned", False) - if not isinstance(session_message.is_mentioned, bool): - session_message.is_mentioned = False - session_message.is_at = message_dict.get("is_at", False) - if not isinstance(session_message.is_at, bool): - session_message.is_at = False - session_message.is_emoji = message_dict.get("is_emoji", False) - if not isinstance(session_message.is_emoji, bool): - session_message.is_emoji = False - session_message.is_picture = message_dict.get("is_picture", False) - if not isinstance(session_message.is_picture, bool): - session_message.is_picture = False - session_message.is_command = message_dict.get("is_command", False) - if not isinstance(session_message.is_command, bool): - session_message.is_command = False - session_message.is_notify = message_dict.get("is_notify", False) - if not isinstance(session_message.is_notify, bool): - session_message.is_notify = False - session_message.reply_to = message_dict.get("reply_to") - if session_message.reply_to is not None and not isinstance(session_message.reply_to, str): - session_message.reply_to = None - session_message.processed_plain_text = message_dict.get("processed_plain_text") - if session_message.processed_plain_text is not None and not isinstance( - session_message.processed_plain_text, str - ): - session_message.processed_plain_text = None - session_message.display_message = message_dict.get("display_message") - if session_message.display_message is not None and not isinstance(session_message.display_message, str): - session_message.display_message = None - - return session_message diff --git a/src/plugin_runtime/host/message_utils.py b/src/plugin_runtime/host/message_utils.py new file mode 100644 index 00000000..428e3c48 --- /dev/null +++ b/src/plugin_runtime/host/message_utils.py @@ -0,0 +1,224 @@ +from datetime import datetime +from typing import Dict, Any, TypedDict, Optional, List + +from src.common.logger import get_logger +from src.chat.message_receive.message import SessionMessage +from src.common.data_models.mai_message_data_model import UserInfo, GroupInfo, MessageInfo +from src.common.data_models.message_component_data_model import MessageSequence + +logger = get_logger("plugin_runtime.host.message_utils") + + +class UserInfoDict(TypedDict, total=False): + user_id: str + user_nickname: str + user_cardname: Optional[str] + + +class GroupInfoDict(TypedDict, total=False): + group_id: str + group_name: str + + +class MessageInfoDict(TypedDict, total=False): + user_info: UserInfoDict + group_info: Optional[GroupInfoDict] + additional_config: Dict[str, Any] + + +class MessageDict(TypedDict, total=False): + message_id: str + timestamp: str + platform: str + message_info: MessageInfoDict + raw_message: List[Dict[str, Any]] + is_mentioned: bool + is_at: bool + is_emoji: bool + is_picture: bool + is_command: bool + is_notify: bool + session_id: str + reply_to: Optional[str] + processed_plain_text: Optional[str] + display_message: Optional[str] + + +class PluginMessageUtils: + @staticmethod + def _message_info_to_dict(message_info: MessageInfo) -> MessageInfoDict: + """ + 将 MessageInfo 对象转换为字典格式 + + Args: + message_info: MessageInfo 对象 + + Returns: + 字典格式的消息信息 + """ + user_info_dict = UserInfoDict( + user_id=message_info.user_info.user_id, + user_nickname=message_info.user_info.user_nickname, + user_cardname=message_info.user_info.user_cardname, + ) + + group_info_dict: Optional[GroupInfoDict] = None + if message_info.group_info: + group_info_dict = GroupInfoDict( + group_id=message_info.group_info.group_id, + group_name=message_info.group_info.group_name, + ) + + return MessageInfoDict( + user_info=user_info_dict, + group_info=group_info_dict, + additional_config=message_info.additional_config, + ) + + @staticmethod + def _session_message_to_dict(session_message: SessionMessage) -> MessageDict: + """ + 将 SessionMessage 对象转换为字典格式(复用 MessageSequence.to_dict 方法) + + Args: + session_message: SessionMessage 对象 + + Returns: + 字典格式的消息 + """ + # 转换基本信息 + message_dict = MessageDict( + message_id=session_message.message_id, + timestamp=str(session_message.timestamp.timestamp()), # 转换为时间戳字符串 + platform=session_message.platform, + message_info=PluginMessageUtils._message_info_to_dict(session_message.message_info), + raw_message=session_message.raw_message.to_dict(), # 复用 MessageSequence.to_dict() + is_mentioned=session_message.is_mentioned, + is_at=session_message.is_at, + is_emoji=session_message.is_emoji, + is_picture=session_message.is_picture, + is_command=session_message.is_command, + is_notify=session_message.is_notify, + session_id=session_message.session_id, + ) + + # 添加可选字段 + if session_message.reply_to is not None: + message_dict["reply_to"] = session_message.reply_to + if session_message.processed_plain_text is not None: + message_dict["processed_plain_text"] = session_message.processed_plain_text + if session_message.display_message is not None: + message_dict["display_message"] = session_message.display_message + + return message_dict + + @staticmethod + def _build_message_info_from_dict(message_info_dict: Dict[str, Any]) -> MessageInfo: + """ + 从字典构建 MessageInfo 对象 + + Args: + message_info_dict: 包含消息信息的字典 + + Returns: + MessageInfo 对象 + """ + # 构建用户信息 + user_info_dict = message_info_dict.get("user_info") + if not user_info_dict or not isinstance(user_info_dict, dict): + raise ValueError("消息字典中 'user_info' 字段无效") + user_id = user_info_dict.get("user_id") + user_nickname = user_info_dict.get("user_nickname") + user_cardname = user_info_dict.get("user_cardname") + if not isinstance(user_id, str) or not isinstance(user_nickname, str) or not user_id or not user_nickname: + raise ValueError("消息字典中 'user_info' 字段缺少有效的 'user_id' 或 'user_nickname'") + user_cardname = str(user_cardname) if user_cardname is not None else None + user_info = UserInfo(user_id=user_id, user_nickname=user_nickname, user_cardname=user_cardname) + + # 构建群信息 + if group_info_dict := message_info_dict.get("group_info"): + group_id = group_info_dict.get("group_id") + group_name = group_info_dict.get("group_name") + if not isinstance(group_id, str) or not isinstance(group_name, str) or not group_id or not group_name: + raise ValueError("消息字典中 'group_info' 字段缺少有效的 'group_id' 或 'group_name'") + group_info = GroupInfo(group_id=group_id, group_name=group_name) + else: + group_info = None + + # 获取额外配置 + additional_config: Dict[str, Any] = message_info_dict.get("additional_config", {}) + + return MessageInfo(user_info=user_info, group_info=group_info, additional_config=additional_config) + + @staticmethod + def _build_session_message_from_dict(message_dict: Dict[str, Any]) -> SessionMessage: + """ + 从字典构建 SessionMessage 对象(递归处理消息组件) + + Args: + message_dict: 包含消息完整信息的字典 + + Returns: + SessionMessage 对象 + """ + # 提取基本信息 + message_id = message_dict["message_id"] + timestamp_str: str = message_dict.get("timestamp", "") + platform = message_dict["platform"] + if not isinstance(message_id, str) or not message_id: + raise ValueError("消息字典中缺少有效的 'message_id' 字段") + if not isinstance(platform, str) or not platform: + raise ValueError("消息字典中缺少有效的 'platform' 字段") + + # 解析时间戳 + try: + timestamp_float = float(timestamp_str) + timestamp = datetime.fromtimestamp(timestamp_float) + except (ValueError, TypeError): + timestamp = datetime.now() # 如果解析失败,使用当前时间 + + # 创建 SessionMessage 实例 + session_message = SessionMessage(message_id=message_id, timestamp=timestamp, platform=platform) + + # 构建消息信息 + session_message.message_info = PluginMessageUtils._build_message_info_from_dict(message_dict["message_info"]) + + # 构建原始消息组件序列(复用 MessageSequence.from_dict 方法) + raw_message_data = message_dict["raw_message"] + if isinstance(raw_message_data, list): + session_message.raw_message = MessageSequence.from_dict(raw_message_data) + else: + raise ValueError("消息字典中 'raw_message' 字段必须是一个列表") + + # 设置其他可选属性 + session_message.is_mentioned = message_dict.get("is_mentioned", False) + if not isinstance(session_message.is_mentioned, bool): + session_message.is_mentioned = False + session_message.is_at = message_dict.get("is_at", False) + if not isinstance(session_message.is_at, bool): + session_message.is_at = False + session_message.is_emoji = message_dict.get("is_emoji", False) + if not isinstance(session_message.is_emoji, bool): + session_message.is_emoji = False + session_message.is_picture = message_dict.get("is_picture", False) + if not isinstance(session_message.is_picture, bool): + session_message.is_picture = False + session_message.is_command = message_dict.get("is_command", False) + if not isinstance(session_message.is_command, bool): + session_message.is_command = False + session_message.is_notify = message_dict.get("is_notify", False) + if not isinstance(session_message.is_notify, bool): + session_message.is_notify = False + session_message.reply_to = message_dict.get("reply_to") + if session_message.reply_to is not None and not isinstance(session_message.reply_to, str): + session_message.reply_to = None + session_message.processed_plain_text = message_dict.get("processed_plain_text") + if session_message.processed_plain_text is not None and not isinstance( + session_message.processed_plain_text, str + ): + session_message.processed_plain_text = None + session_message.display_message = message_dict.get("display_message") + if session_message.display_message is not None and not isinstance(session_message.display_message, str): + session_message.display_message = None + + return session_message From 3d22657707b757e73cb77bf52ee0e5ea924f7832 Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Fri, 20 Mar 2026 21:39:19 +0800 Subject: [PATCH 14/45] =?UTF-8?q?refactor:=20supervisor=E9=83=A8=E5=88=86?= =?UTF-8?q?=E6=96=B9=E6=B3=95=E9=87=8D=E5=86=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/plugin_runtime/host/supervisor.py | 681 ++++---------------------- 1 file changed, 87 insertions(+), 594 deletions(-) diff --git a/src/plugin_runtime/host/supervisor.py b/src/plugin_runtime/host/supervisor.py index bfa00cbf..82a5970b 100644 --- a/src/plugin_runtime/host/supervisor.py +++ b/src/plugin_runtime/host/supervisor.py @@ -1,98 +1,41 @@ -"""Supervisor - 插件生命周期管理 - -负责: -1. 拉起 Runner 子进程 -2. 健康检查 + 崩溃自动重启 -3. 代码热重载(generation 切换) -4. 优雅关停 -""" - -from typing import Any, Dict, List, Optional, Tuple +from pathlib import Path +from typing import Optional, List, Dict, Any, Tuple, TYPE_CHECKING import asyncio -import contextlib -import logging as stdlib_logging -import os -import sys -from pathlib import Path + from src.common.logger import get_logger -from src.config.config import MMC_VERSION, global_config -from src.plugin_runtime import ENV_HOST_VERSION, ENV_IPC_ADDRESS, ENV_PLUGIN_DIRS, ENV_SESSION_TOKEN -from src.plugin_runtime.host.capability_service import CapabilityService -from src.plugin_runtime.host.component_registry import ComponentRegistry -from src.plugin_runtime.host.event_dispatcher import EventDispatcher -from src.plugin_runtime.host.policy_engine import PolicyEngine -from src.plugin_runtime.host.rpc_server import RPCServer -from src.plugin_runtime.host.workflow_executor import WorkflowExecutor, WorkflowContext, WorkflowResult +from src.config.config import global_config +from src.plugin_runtime.transport.factory import create_transport_server from src.plugin_runtime.protocol.envelope import ( BootstrapPluginPayload, ConfigUpdatedPayload, Envelope, HealthPayload, LogBatchPayload, - RegisterComponentsPayload, + RegisterPluginPayload, RunnerReadyPayload, ShutdownPayload, ) -from src.plugin_runtime.protocol.errors import ErrorCode, RPCError -from src.plugin_runtime.transport.factory import create_transport_server -logger = get_logger("plugin_runtime.host.supervisor") +from .authorization import AuthorizationManager +from .capability_service import CapabilityService +from .rpc_server import RPCServer +from .logger_bridge import RunnerLogBridge +from .component_registry import ComponentRegistry +from .event_dispatcher import EventDispatcher +from .hook_dispatcher import HookDispatcher +from .message_gateway import MessageGateway +from .message_utils import PluginMessageUtils + +if TYPE_CHECKING: + from src.chat.message_receive.message import SessionMessage + +logger = get_logger("plugin_runtime.host.runner_manager") -# ─── 日志桥 ────────────────────────────────────────────────────── - - -class RunnerLogBridge: - """将 Runner 进程上报的批量日志重放到主进程的 Logger 中。 - - Runner 通过 ``runner.log_batch`` IPC 事件批量到达。 - 每条 LogEntry 被重建为一个真实的 :class:`logging.LogRecord` 并直接 - 调用 ``logging.getLogger(entry.logger_name).handle(record)``, - 从而接入主进程已配置好的 structlog Handler 链。 - """ - - async def handle_log_batch(self, envelope: Envelope) -> Envelope: - """IPC 事件处理器:解析批量日志并重放到主进程 Logger。 - - Args: - envelope: 方法名为 ``runner.log_batch`` 的 IPC 事件信封。 - - Returns: - 空响应信封(事件模式下将被忽略)。 - """ - try: - batch = LogBatchPayload.model_validate(envelope.payload) - except Exception as exc: - return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc)) - - for entry in batch.entries: - # 重建一个与原始日志尽量相符的 LogRecord - record = stdlib_logging.LogRecord( - name=entry.logger_name, - level=entry.level, - pathname="", - lineno=0, - msg=entry.message, - args=(), - exc_info=None, - ) - record.created = entry.timestamp_ms / 1000.0 - record.msecs = entry.timestamp_ms % 1000 - if entry.exception_text: - record.exc_text = entry.exception_text - - stdlib_logging.getLogger(entry.logger_name).handle(record) - - return envelope.make_response(payload={"accepted": True, "count": len(batch.entries)}) - - -class PluginSupervisor: - """插件 Supervisor - - Host 端的核心管理器,负责整个插件 Runner 进程的生命周期。 - """ +class PluginRunnerSupervisor: + """插件的Runner管理器,负责管理Runner的生命周期""" def __init__( self, @@ -103,45 +46,34 @@ class PluginSupervisor: runner_spawn_timeout_sec: Optional[float] = None, ): _cfg = global_config.plugin_runtime - self._plugin_dirs = plugin_dirs or [] - self._health_interval = ( - health_check_interval_sec if health_check_interval_sec is not None else _cfg.health_check_interval_sec - ) - self._runner_spawn_timeout = ( - runner_spawn_timeout_sec if runner_spawn_timeout_sec is not None else _cfg.runner_spawn_timeout_sec - ) + self._plugin_dirs: List[Path] = plugin_dirs or [] + self._health_interval = health_check_interval_sec or _cfg.health_check_interval_sec or 30.0 + self._runner_spawn_timeout = runner_spawn_timeout_sec or _cfg.runner_spawn_timeout_sec or 30.0 # 基础设施 self._transport = create_transport_server(socket_path=socket_path) - self._policy = PolicyEngine() - self._capability_service = CapabilityService(self._policy) + self._authorization = AuthorizationManager() + self._capability_service = CapabilityService(self._authorization) self._component_registry = ComponentRegistry() self._event_dispatcher = EventDispatcher(self._component_registry) - self._workflow_executor = WorkflowExecutor(self._component_registry) + self._hook_dispatcher = HookDispatcher(self._component_registry) + self._message_gateway = MessageGateway(self._component_registry) - # 编解码 + # 编解码和服务器 from src.plugin_runtime.protocol.codec import MsgPackCodec codec = MsgPackCodec() - - self._rpc_server = RPCServer( - transport=self._transport, - codec=codec, - ) + self._rpc_server = RPCServer(transport=self._transport, codec=codec) # Runner 子进程 self._runner_process: Optional[asyncio.subprocess.Process] = None - self._runner_generation: int = 0 - self._max_restart_attempts: int = ( - max_restart_attempts if max_restart_attempts is not None else _cfg.max_restart_attempts - ) + self._max_restart_attempts: int = max_restart_attempts or _cfg.max_restart_attempts or 3 self._restart_count: int = 0 # 已注册的插件组件信息 - self._registered_plugins: Dict[str, RegisterComponentsPayload] = {} - self._staged_registered_plugins: Dict[str, RegisterComponentsPayload] = {} - self._runner_ready_events: Dict[int, asyncio.Event] = {} - self._runner_ready_payloads: Dict[int, RunnerReadyPayload] = {} + self._registered_plugins: Dict[str, RegisterPluginPayload] = {} + self._runner_ready_events: asyncio.Event = asyncio.Event() + self._runner_ready_payloads: RunnerReadyPayload = RunnerReadyPayload() # 后台任务 self._health_task: Optional[asyncio.Task] = None @@ -153,11 +85,11 @@ class PluginSupervisor: self._log_bridge: RunnerLogBridge = RunnerLogBridge() # 注册内部 RPC 方法 - self._register_internal_methods() + self._register_internal_methods() # TODO: 完成内部方法注册 @property - def policy_engine(self) -> PolicyEngine: - return self._policy + def authorization_manager(self) -> AuthorizationManager: + return self._authorization @property def capability_service(self) -> CapabilityService: @@ -172,8 +104,12 @@ class PluginSupervisor: return self._event_dispatcher @property - def workflow_executor(self) -> WorkflowExecutor: - return self._workflow_executor + def hook_dispatcher(self) -> HookDispatcher: + return self._hook_dispatcher + + @property + def message_gateway(self) -> MessageGateway: + return self._message_gateway @property def rpc_server(self) -> RPCServer: @@ -182,64 +118,26 @@ class PluginSupervisor: async def dispatch_event( self, event_type: str, - message: Optional[Dict[str, Any]] = None, + message: Optional["SessionMessage"] = None, extra_args: Optional[Dict[str, Any]] = None, - ) -> Tuple[bool, Optional[Dict[str, Any]]]: + ) -> Tuple[bool, Optional["SessionMessage"]]: """分发事件到所有对应 handler 的快捷方法。""" + return await self._event_dispatcher.dispatch_event(event_type, self, message, extra_args) - async def _invoke(plugin_id: str, component_name: str, args: Dict[str, Any]) -> Dict[str, Any]: - resp = await self.invoke_plugin( - method="plugin.emit_event", - plugin_id=plugin_id, - component_name=component_name, - args=args, - ) - return resp.payload + async def dispatch_hook(self, stage: str, **kwargs): + """分发Hook事件到所有对应 handler 的快捷方法。""" + return await self._hook_dispatcher.hook_dispatch(stage, self, **kwargs) - return await self._event_dispatcher.dispatch_event( - event_type=event_type, - invoke_fn=_invoke, - message=message, - extra_args=extra_args, - ) - - async def execute_workflow( + async def send_message_to_external( self, - message: Optional[Dict[str, Any]] = None, - stream_id: Optional[str] = None, - context: Optional[WorkflowContext] = None, - ) -> Tuple[WorkflowResult, Optional[Dict[str, Any]], WorkflowContext]: - """执行 Workflow Pipeline 的快捷方法。""" - - async def _invoke(plugin_id: str, component_name: str, args: Dict[str, Any]) -> Dict[str, Any]: - resp = await self.invoke_plugin( - method="plugin.invoke_workflow_step", - plugin_id=plugin_id, - component_name=component_name, - args=args, - ) - payload = resp.payload - if payload.get("success"): - result = payload.get("result") - return result if isinstance(result, dict) else {} - raise RuntimeError(payload.get("result", "workflow step invoke failed")) - - async def _command_invoke(plugin_id: str, component_name: str, args: Dict[str, Any]) -> Dict[str, Any]: - """命令走 plugin.invoke_command,保留原始返回值结构。""" - resp = await self.invoke_plugin( - method="plugin.invoke_command", - plugin_id=plugin_id, - component_name=component_name, - args=args, - ) - return resp.payload - - return await self._workflow_executor.execute( - invoke_fn=_invoke, - message=message, - stream_id=stream_id, - context=context, - command_invoke_fn=_command_invoke, + internal_message: "SessionMessage", + *, + enabled_only: bool = True, + save_to_db: bool = True, + ) -> bool: + """发送系统内部消息到外部平台的快捷方法。""" + return await self._message_gateway.send_message_to_external( + internal_message, self, enabled_only=enabled_only, save_to_db=save_to_db ) async def start(self) -> None: @@ -253,17 +151,13 @@ class PluginSupervisor: # 启动 RPC Server await self._rpc_server.start() - - # 计算预期 generation(与 reload_plugins 保持一致) - expected_generation = self._rpc_server.runner_generation + 1 - # 拉起 Runner 进程 await self._spawn_runner() # 等待 Runner 完成连接和初始化,避免 start() 返回时 Runner 尚未就绪 try: - await self._wait_for_runner_generation(expected_generation, timeout_sec=self._runner_spawn_timeout) - await self._wait_for_runner_ready(expected_generation, timeout_sec=self._runner_spawn_timeout) + await self._wait_for_runner_connection(timeout_sec=self._runner_spawn_timeout) + await self._wait_for_runner_ready(timeout_sec=self._runner_spawn_timeout) except TimeoutError: if not self._rpc_server.is_connected: logger.warning(f"Runner 未在 {self._runner_spawn_timeout}s 内完成连接,后续操作可能失败") @@ -279,6 +173,10 @@ class PluginSupervisor: """停止 Supervisor""" self._running = False + # 停止组件 + await self._event_dispatcher.stop() + await self._hook_dispatcher.stop() + # 停止健康检查 if self._health_task: self._health_task.cancel() @@ -305,439 +203,34 @@ class PluginSupervisor: 由主进程业务逻辑调用,通过 RPC 转发给 Runner。 """ return await self._rpc_server.send_request( - method=method, - plugin_id=plugin_id, - payload={ - "component_name": component_name, - "args": args or {}, - }, - timeout_ms=timeout_ms, + method, + plugin_id, + {"component_name": component_name, "args": args or {}}, + timeout_ms, ) - async def reload_plugins(self, reason: str = "manual") -> bool: - """热重载所有插件(进程级 generation 切换) + async def reload_plugin(self, plugin_id: str, reason: str = "manual") -> bool: + raise NotImplementedError("等待SDK完成") # TODO: 完成对应的调用和请求逻辑 - 1. 拉起新 Runner - 2. 等待新 Runner 完成注册和健康检查 - 3. 关停旧 Runner - """ - logger.info(f"开始热重载插件,原因: {reason}") + async def _wait_for_runner_connection(self, timeout_sec: float) -> None: + """等待 Runner 连接上 RPC Server""" - # 保存旧进程引用和旧 session token(回滚时需要恢复) - old_process = self._runner_process - old_registered_plugins = dict(self._registered_plugins) - old_session_token = self._rpc_server.session_token - expected_generation = self._rpc_server.runner_generation + 1 + async def wait_for_connection(): + while self._running and not self._rpc_server.is_connected: + await asyncio.sleep(0.1) - # 允许新 Runner 以 staged 方式接入,验证通过后再切换活跃连接 - self._rpc_server.begin_staged_takeover() - self._staged_registered_plugins.clear() - - # 重新生成 session token,防止被终止的旧 Runner 重连 - self._rpc_server.reset_session_token() - - # 注意:不在此处调用 _clear_runtime_state()。 - # 旧组件在新 Runner 完成注册前继续提供服务,避免热重载窗口期内 - # dispatch_event / execute_workflow 找不到任何组件导致消息静默丢失。 - # ComponentRegistry.register_component 对同名组件是覆盖式写入,安全。 - - # 拉起新 Runner try: - await self._spawn_runner() - await self._wait_for_runner_generation( - expected_generation, - timeout_sec=self._runner_spawn_timeout, - allow_staged=True, - ) - await self._wait_for_runner_ready(expected_generation, timeout_sec=self._runner_spawn_timeout) - resp = await self._rpc_server.send_request( - "plugin.health", - timeout_ms=5000, - target_generation=expected_generation, - ) - health = HealthPayload.model_validate(resp.payload) - if not health.healthy: - raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "新 Runner 健康检查失败") - await self._rpc_server.commit_staged_takeover() - except Exception as e: - logger.error(f"新 Runner 健康检查失败: {e},回滚") - await self._terminate_process(self._runner_process, old_process) - await self._rpc_server.rollback_staged_takeover() - self._runner_process = old_process - self._rpc_server.restore_session_token(old_session_token) - self._staged_registered_plugins.clear() - self._registered_plugins = dict(old_registered_plugins) - self._rebuild_runtime_state() - return False + await asyncio.wait_for(wait_for_connection(), timeout=timeout_sec) + logger.info("Runner 已连接到 RPC Server") + except asyncio.TimeoutError as e: + raise TimeoutError(f"等待 Runner 连接超时({timeout_sec}s)") from e - self._runner_generation = self._rpc_server.runner_generation - self._registered_plugins = dict(self._staged_registered_plugins) - self._staged_registered_plugins.clear() - self._rebuild_runtime_state() + async def _wait_for_runner_ready(self, timeout_sec: float = 30.0) -> RunnerReadyPayload: + """等待 Runner 完成初始化并上报就绪""" - # 关停旧 Runner - if old_process and old_process.returncode is None: - try: - old_process.terminate() - await asyncio.wait_for(old_process.wait(), timeout=10.0) - except asyncio.TimeoutError: - old_process.kill() - - logger.info("热重载完成") - return True - - async def notify_plugin_config_updated( - self, - plugin_id: str, - config_data: Dict[str, Any], - config_version: str = "", - ) -> bool: - """通知指定插件其配置已更新。""" - if plugin_id not in self._registered_plugins: - return False - - payload = ConfigUpdatedPayload( - plugin_id=plugin_id, - config_version=config_version, - config_data=config_data, - ) - await self._rpc_server.send_request( - "plugin.config_updated", - plugin_id=plugin_id, - payload=payload.model_dump(), - timeout_ms=5000, - ) - return True - - # ─── 内部方法 ────────────────────────────────────────────── - - def _register_internal_methods(self) -> None: - """注册 Host 端的 RPC 方法处理器""" - # Runner -> Host 的能力调用统一走 capability_service - self._rpc_server.register_method("cap.request", self._capability_service.handle_capability_request) - self._rpc_server.register_method("plugin.bootstrap", self._handle_bootstrap_plugin) - # 插件注册 - self._rpc_server.register_method("plugin.register_components", self._handle_register_components) - self._rpc_server.register_method("runner.ready", self._handle_runner_ready) - # Runner 日志批量上报 - self._rpc_server.register_method("runner.log_batch", self._log_bridge.handle_log_batch) - - async def _handle_bootstrap_plugin(self, envelope: Envelope) -> Envelope: - """处理插件 bootstrap 请求,仅同步能力令牌。""" try: - bootstrap = BootstrapPluginPayload.model_validate(envelope.payload) - except Exception as e: - return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(e)) - - active_generation = self._rpc_server.runner_generation - staged_generation = self._rpc_server.staged_generation - if envelope.generation not in {active_generation, staged_generation}: - return envelope.make_error_response( - ErrorCode.E_GENERATION_MISMATCH.value, - f"插件 bootstrap generation 过期: {envelope.generation} 不在已知代际中", - ) - - if bootstrap.capabilities_required: - self._policy.register_plugin( - plugin_id=bootstrap.plugin_id, - generation=envelope.generation, - capabilities=bootstrap.capabilities_required, - ) - else: - self._policy.revoke_plugin(bootstrap.plugin_id, generation=envelope.generation) - - return envelope.make_response(payload={"accepted": True}) - - async def _handle_register_components(self, envelope: Envelope) -> Envelope: - """处理插件组件注册请求""" - try: - reg = RegisterComponentsPayload.model_validate(envelope.payload) - except Exception as e: - return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(e)) - - active_generation = self._rpc_server.runner_generation - staged_generation = self._rpc_server.staged_generation - if envelope.generation not in {active_generation, staged_generation}: - return envelope.make_error_response( - ErrorCode.E_GENERATION_MISMATCH.value, - f"组件注册 generation 过期: {envelope.generation} 不在已知代际中", - ) - - if envelope.generation == staged_generation and staged_generation != 0: - self._staged_registered_plugins[reg.plugin_id] = reg - logger.info( - f"插件 {reg.plugin_id} v{reg.plugin_version} staged 注册成功," - f"组件数: {len(reg.components)}, 能力需求: {reg.capabilities_required}" - ) - return envelope.make_response(payload={"accepted": True, "staged": True}) - - self._registered_plugins[reg.plugin_id] = reg - - # 在策略引擎中注册插件 - self._policy.register_plugin( - plugin_id=reg.plugin_id, - generation=envelope.generation, - capabilities=reg.capabilities_required or [], - ) - - # 同 generation 下重新注册时,以本次声明为准,避免残留幽灵组件 - self._component_registry.remove_components_by_plugin(reg.plugin_id) - self._component_registry.register_plugin_components( - plugin_id=reg.plugin_id, - components=[c.model_dump() for c in reg.components], - ) - - stats = self._component_registry.get_stats() - logger.info( - f"插件 {reg.plugin_id} v{reg.plugin_version} 注册成功," - f"组件数: {len(reg.components)}, 能力需求: {reg.capabilities_required}," - f"注册表总计: {stats}" - ) - - return envelope.make_response(payload={"accepted": True}) - - async def _handle_runner_ready(self, envelope: Envelope) -> Envelope: - """处理 Runner 初始化完成信号。""" - try: - ready = RunnerReadyPayload.model_validate(envelope.payload) - except Exception as e: - return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(e)) - - event = self._runner_ready_events.setdefault(envelope.generation, asyncio.Event()) - self._runner_ready_payloads[envelope.generation] = ready - event.set() - logger.info( - f"Runner generation={envelope.generation} 已就绪,成功插件数: {len(ready.loaded_plugins)}," - f"失败插件数: {len(ready.failed_plugins)}" - ) - return envelope.make_response(payload={"accepted": True}) - - async def _spawn_runner(self) -> None: - """拉起 Runner 子进程""" - runner_module = "src.plugin_runtime.runner.runner_main" - address = self._transport.get_address() - token = self._rpc_server.session_token - - env = os.environ.copy() - env[ENV_IPC_ADDRESS] = address - env[ENV_SESSION_TOKEN] = token - env[ENV_PLUGIN_DIRS] = os.pathsep.join(str(p) for p in self._plugin_dirs) - env[ENV_HOST_VERSION] = MMC_VERSION - - self._runner_process = await asyncio.create_subprocess_exec( - sys.executable, - "-m", - runner_module, - env=env, - # stdout 不捕获:Runner 的日志均通过 IPC 传㛹(RunnerIPCLogHandler) - stdout=None, - # stderr 捕获为 PIPE,仅用于 IPC 建立前的进程级致命错误输出 - stderr=asyncio.subprocess.PIPE, - ) - - self._attach_stderr_drain(self._runner_process) - self._runner_generation = self._rpc_server.runner_generation - logger.info(f"Runner 子进程已启动: pid={self._runner_process.pid}, generation={self._runner_generation}") - - async def _shutdown_runner(self) -> None: - """优雅关停 Runner""" - if not self._runner_process or self._runner_process.returncode is not None: - return - - # 发送 prepare_shutdown - try: - if self._rpc_server.is_connected: - shutdown_payload = ShutdownPayload(reason="host_shutdown", drain_timeout_ms=5000) - await self._rpc_server.send_request( - "plugin.prepare_shutdown", - payload=shutdown_payload.model_dump(), - timeout_ms=5000, - ) - await self._rpc_server.send_request( - "plugin.shutdown", - payload=shutdown_payload.model_dump(), - timeout_ms=5000, - ) - except Exception as e: - logger.warning(f"发送关停命令失败: {e}") - - # 等待进程退出 - try: - await asyncio.wait_for(self._runner_process.wait(), timeout=10.0) - except asyncio.TimeoutError: - logger.warning("Runner 未在超时内退出,强制终止") - self._runner_process.kill() - await self._runner_process.wait() - - await self._cleanup_stderr_drain() - - async def _health_check_loop(self) -> None: - """周期性健康检查 + 崩溃自动重启""" - while self._running: - await asyncio.sleep(self._health_interval) - - # 检查 Runner 进程是否意外退出 - if self._runner_process and self._runner_process.returncode is not None: - exit_code = self._runner_process.returncode - logger.warning(f"Runner 进程已退出 (exit_code={exit_code})") - - if self._restart_count < self._max_restart_attempts: - self._restart_count += 1 - logger.info(f"尝试重启 Runner ({self._restart_count}/{self._max_restart_attempts})") - # 清理旧的组件注册 - for plugin_id in list(self._registered_plugins.keys()): - self._component_registry.remove_components_by_plugin(plugin_id) - self._policy.revoke_plugin(plugin_id) - self._registered_plugins.clear() - - try: - self._clear_runtime_state() - # 重新生成 session token,防止旧 Runner 僵尸进程用旧 token 重连 - self._rpc_server.reset_session_token() - await self._spawn_runner() - except Exception as e: - logger.error(f"Runner 重启失败: {e}", exc_info=True) - else: - logger.error(f"Runner 连续崩溃 {self._max_restart_attempts} 次,停止重启") - continue - - if not self._rpc_server.is_connected: - logger.warning("Runner 未连接,跳过健康检查") - continue - - try: - resp = await self._rpc_server.send_request("plugin.health", timeout_ms=5000) - health = HealthPayload.model_validate(resp.payload) - if not health.healthy: - logger.warning(f"Runner 健康检查异常: {health}") - else: - # 健康检查成功,重置重启计数 - self._restart_count = 0 - except RPCError as e: - logger.error(f"健康检查失败: {e}") - except asyncio.CancelledError: - break - except Exception as e: - logger.error(f"健康检查异常: {e}") - - async def _wait_for_runner_generation( - self, - expected_generation: int, - timeout_sec: float, - allow_staged: bool = False, - ) -> None: - """等待指定代际的 Runner 完成连接。""" - deadline = asyncio.get_running_loop().time() + timeout_sec - while asyncio.get_running_loop().time() < deadline: - if allow_staged and self._rpc_server.has_generation(expected_generation): - return - if self._rpc_server.is_connected and self._rpc_server.runner_generation >= expected_generation: - self._runner_generation = self._rpc_server.runner_generation - return - await asyncio.sleep(0.1) - raise TimeoutError(f"等待 Runner generation {expected_generation} 超时") - - async def _wait_for_runner_ready(self, expected_generation: int, timeout_sec: float) -> RunnerReadyPayload: - """等待指定代际的 Runner 完成初始化。""" - event = self._runner_ready_events.setdefault(expected_generation, asyncio.Event()) - await asyncio.wait_for(event.wait(), timeout=timeout_sec) - return self._runner_ready_payloads.get(expected_generation, RunnerReadyPayload()) - - def _clear_runtime_state(self) -> None: - """清空当前插件注册态。""" - self._component_registry.clear() - self._policy.clear() - self._registered_plugins.clear() - self._staged_registered_plugins.clear() - - def _rebuild_runtime_state(self) -> None: - """根据已记录的插件注册信息重建运行时状态。""" - self._component_registry.clear() - self._policy.clear() - for reg in self._registered_plugins.values(): - self._policy.register_plugin( - plugin_id=reg.plugin_id, - generation=self._rpc_server.runner_generation, - capabilities=reg.capabilities_required or [], - ) - self._component_registry.register_plugin_components( - plugin_id=reg.plugin_id, - components=[c.model_dump() for c in reg.components], - ) - - def _attach_stderr_drain(self, process: asyncio.subprocess.Process) -> None: - """为 Runner stderr 创建排空任务,捕获 IPC 建立前的进程级错误输出。 - - stderr 中的内容通常是: - - Runner 启动早期(握手完成之前)的日志 - - 进程级致命错误(ImportError、SyntaxError等) - - 异常进程退出前的最后输出 - - 握手成功后,插件的所有日志均经由 RunnerIPCLogHandler 通过 IPC 传输。 - """ - if process.stderr is None: - return - task = asyncio.create_task( - self._drain_runner_stderr(process.stderr, process.pid), - name=f"runner_stderr_drain:{process.pid}", - ) - self._stderr_drain_task = task - task.add_done_callback( - lambda done_task: None if self._stderr_drain_task is not done_task else self._clear_stderr_drain_task() - ) - - def _clear_stderr_drain_task(self) -> None: - self._stderr_drain_task = None - - async def _drain_runner_stderr( - self, - stream: asyncio.StreamReader, - pid: int, - ) -> None: - """持续读取 Runner stderr 并转发到 Host Logger,防止 PIPE 锡死子进程。 - - Args: - stream: Runner 子进程的 stderr 流。 - pid: 子进程 PID,仅用于日志上下文。 - """ - try: - while True: - line = await stream.readline() - if not line: - break - if message := line.decode(errors="replace").rstrip(): - # 将 stderr 输出以 WARNING 级展示: - # 如果 Runner 正常运行,此流应当无输出; - # 有输出说明进程级错误发生,需要出现在主进程日志中 - logger.warning(f"[runner:{pid}:stderr] {message}") - except asyncio.CancelledError: - raise - except Exception as exc: - logger.debug(f"读取 Runner stderr 失败 (pid={pid}): {exc}") - - async def _cleanup_stderr_drain(self) -> None: - """等待并取消 stderr 排空任务。""" - if self._stderr_drain_task is None: - return - task = self._stderr_drain_task - self._stderr_drain_task = None - if not task.done(): - task.cancel() - with contextlib.suppress(Exception): - await asyncio.gather(task, return_exceptions=True) - - @staticmethod - async def _terminate_process( - process: Optional[asyncio.subprocess.Process], - keep_process: Optional[asyncio.subprocess.Process] = None, - ) -> None: - """终止指定进程,但跳过需要保留的旧进程引用。""" - if process is None or process is keep_process or process.returncode is not None: - return - - process.terminate() - try: - await asyncio.wait_for(process.wait(), timeout=10.0) - except asyncio.TimeoutError: - process.kill() - await process.wait() + await asyncio.wait_for(self._runner_ready_events.wait(), timeout=timeout_sec) + logger.info("Runner 已完成初始化并上报就绪") + return self._runner_ready_payloads + except asyncio.TimeoutError as e: + raise TimeoutError(f"等待 Runner 就绪超时({timeout_sec}s)") from e From 04f260e570070a792ff00906f5694877a001637e Mon Sep 17 00:00:00 2001 From: DrSmoothl <1787882683@qq.com> Date: Fri, 20 Mar 2026 01:15:17 +0800 Subject: [PATCH 15/45] =?UTF-8?q?=E5=AE=9E=E7=8E=B0=E5=AE=8C=E6=95=B4?= =?UTF-8?q?=E7=9A=84=E6=B6=88=E6=81=AF=E4=B8=AD=E9=97=B4=E5=B1=82=E5=9C=B0?= =?UTF-8?q?=E5=9F=BA=EF=BC=8C=E6=9A=82=E6=9C=AA=E6=8E=A5=E5=85=A5=E5=AE=9E?= =?UTF-8?q?=E9=99=85=E7=9A=84=E6=B6=88=E6=81=AF=E6=B5=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/platform_io/__init__.py | 35 ++ src/platform_io/dedupe.py | 133 ++++++ src/platform_io/drivers/__init__.py | 11 + src/platform_io/drivers/base.py | 104 +++++ src/platform_io/drivers/legacy_driver.py | 61 +++ src/platform_io/drivers/plugin_driver.py | 64 +++ src/platform_io/manager.py | 529 +++++++++++++++++++++++ src/platform_io/outbound_tracker.py | 242 +++++++++++ src/platform_io/registry.py | 70 +++ src/platform_io/route_key_factory.py | 150 +++++++ src/platform_io/routing.py | 202 +++++++++ src/platform_io/types.py | 240 ++++++++++ 12 files changed, 1841 insertions(+) create mode 100644 src/platform_io/__init__.py create mode 100644 src/platform_io/dedupe.py create mode 100644 src/platform_io/drivers/__init__.py create mode 100644 src/platform_io/drivers/base.py create mode 100644 src/platform_io/drivers/legacy_driver.py create mode 100644 src/platform_io/drivers/plugin_driver.py create mode 100644 src/platform_io/manager.py create mode 100644 src/platform_io/outbound_tracker.py create mode 100644 src/platform_io/registry.py create mode 100644 src/platform_io/route_key_factory.py create mode 100644 src/platform_io/routing.py create mode 100644 src/platform_io/types.py diff --git a/src/platform_io/__init__.py b/src/platform_io/__init__.py new file mode 100644 index 00000000..380ecbb6 --- /dev/null +++ b/src/platform_io/__init__.py @@ -0,0 +1,35 @@ +"""导出 Platform IO 层的公开入口。 + +当前仍处于地基阶段,调用方应优先从这里导入共享类型和全局管理器, +而不是直接依赖更底层的私有子模块。 +""" + +from .manager import PlatformIOManager, get_platform_io_manager +from .route_key_factory import RouteKeyFactory +from .routing import RouteBindingConflictError, RouteTable +from .types import ( + DeliveryReceipt, + DeliveryStatus, + DriverDescriptor, + DriverKind, + InboundMessageEnvelope, + RouteBinding, + RouteKey, + RouteMode, +) + +__all__ = [ + "DeliveryReceipt", + "DeliveryStatus", + "DriverDescriptor", + "DriverKind", + "InboundMessageEnvelope", + "PlatformIOManager", + "RouteKeyFactory", + "RouteBinding", + "RouteBindingConflictError", + "RouteKey", + "RouteMode", + "RouteTable", + "get_platform_io_manager", +] diff --git a/src/platform_io/dedupe.py b/src/platform_io/dedupe.py new file mode 100644 index 00000000..4c5c55a2 --- /dev/null +++ b/src/platform_io/dedupe.py @@ -0,0 +1,133 @@ +"""提供 Platform IO 的轻量入站消息去重能力。 + +当前实现基于 ``dict + heapq``: +- ``dict`` 保存去重键到过期时间的映射 +- ``heapq`` 维护按过期时间排序的小顶堆 + +这样就不需要在每次检查时全表扫描,而是通过懒清理逐步弹出已经过期 +或已经失效的堆节点。 +""" + +from typing import Dict, List, Tuple + +import heapq +import time + + +class MessageDeduplicator: + """使用基于 TTL 的内存缓存进行入站消息去重。 + + 主要用于解决同一条外部消息被重复送入 Core 的问题,例如双路径并存、 + 适配器重试、重连或重复回调等场景。Broker 可以借助这个组件在进入 + Core 前先拦住重复投递,避免重复处理、重复回复和重复入库。 + + 当前实现使用 ``dict + heapq`` 维护过期时间: + - ``dict`` 负责 ``O(1)`` 级别的去重键查找 + - ``heapq`` 负责按过期时间顺序做懒清理 + + 这比“每次调用都全表扫描过期项”的实现更适合高吞吐消息场景。 + + Notes: + 复杂度说明如下,设 ``n`` 为当前缓存中的有效去重键数量: + + - 单次 ``mark_seen()`` 在常见路径下的时间复杂度接近 ``O(log n)`` + - 从长期摊还角度看,``mark_seen()`` 的时间复杂度也接近 ``O(log n)`` + - 如果某次调用恰好触发一批过期键的集中清理,则该次调用的最坏时间复杂度 + 可达到 ``O(k log n)``,其中 ``k`` 为本次被弹出或清理的键数量 + - 空间复杂度为 ``O(n)`` + """ + + def __init__(self, ttl_seconds: float = 300.0, max_entries: int = 10000) -> None: + """初始化去重器。 + + Args: + ttl_seconds: 每个去重键在缓存中的保留时长,单位为秒。 + max_entries: 缓存允许保留的最大有效键数量,超出后会触发 + 机会性淘汰。 + + Raises: + ValueError: 当 ``ttl_seconds`` 或 ``max_entries`` 非正数时抛出。 + """ + if ttl_seconds <= 0: + raise ValueError("ttl_seconds 必须大于 0") + if max_entries <= 0: + raise ValueError("max_entries 必须大于 0") + + self._ttl_seconds = ttl_seconds + self._max_entries = max_entries + self._expire_heap: List[Tuple[float, str]] = [] + self._seen: Dict[str, float] = {} + + def mark_seen(self, dedupe_key: str) -> bool: + """标记一条去重键已经出现过。 + + Args: + dedupe_key: 能稳定标识一条外部入站消息的去重键。 + + Returns: + bool: 若该键在当前 TTL 窗口内首次出现则返回 ``True``, + 否则返回 ``False``。 + + Notes: + 方法会先基于小顶堆做一次懒清理,再判断当前键是否仍在有效期内。 + 如果缓存已达到上限,则会优先淘汰“最早过期的仍然有效的键”。 + + 复杂度方面,常见路径下该方法接近 ``O(log n)``;如果恰好需要 + 集中清理一批过期键,则单次调用最坏可达到 ``O(k log n)``。 + """ + now = time.monotonic() + self._purge_expired(now) + + expires_at = self._seen.get(dedupe_key) + if expires_at is not None and expires_at > now: + return False + + if len(self._seen) >= self._max_entries: + self._evict_earliest_live() + + expires_at = now + self._ttl_seconds + self._seen[dedupe_key] = expires_at + heapq.heappush(self._expire_heap, (expires_at, dedupe_key)) + return True + + def clear(self) -> None: + """清空全部去重缓存。""" + self._expire_heap.clear() + self._seen.clear() + + def _purge_expired(self, now: float) -> None: + """从缓存中清理已经过期的去重键。 + + Args: + now: 当前单调时钟时间戳。 + + Notes: + 堆中可能存在旧版本节点。例如同一个 ``dedupe_key`` 被重新写入后, + 旧的过期时间节点仍会留在堆里。这里会通过和 ``dict`` 中当前值比对, + 跳过这类失效节点。 + """ + while self._expire_heap and self._expire_heap[0][0] <= now: + expires_at, dedupe_key = heapq.heappop(self._expire_heap) + current_expires_at = self._seen.get(dedupe_key) + if current_expires_at is None: + continue + if current_expires_at != expires_at: + continue + self._seen.pop(dedupe_key, None) + + def _evict_earliest_live(self) -> None: + """当缓存达到容量上限时,淘汰一条最早过期的有效键。 + + Notes: + 堆顶可能是已经过期或已失效的旧节点,因此这里同样需要循环弹出, + 直到找到一条当前仍然在 ``dict`` 中生效的键。 + """ + while self._expire_heap: + expires_at, dedupe_key = heapq.heappop(self._expire_heap) + current_expires_at = self._seen.get(dedupe_key) + if current_expires_at is None: + continue + if current_expires_at != expires_at: + continue + self._seen.pop(dedupe_key, None) + return diff --git a/src/platform_io/drivers/__init__.py b/src/platform_io/drivers/__init__.py new file mode 100644 index 00000000..b12120cf --- /dev/null +++ b/src/platform_io/drivers/__init__.py @@ -0,0 +1,11 @@ +"""导出 Platform IO 层的公开驱动类型。""" + +from .base import PlatformIODriver +from .legacy_driver import LegacyPlatformDriver +from .plugin_driver import PluginPlatformDriver + +__all__ = [ + "LegacyPlatformDriver", + "PlatformIODriver", + "PluginPlatformDriver", +] diff --git a/src/platform_io/drivers/base.py b/src/platform_io/drivers/base.py new file mode 100644 index 00000000..c6173d8c --- /dev/null +++ b/src/platform_io/drivers/base.py @@ -0,0 +1,104 @@ +"""定义 Platform IO 传输驱动的基础抽象协议。""" + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional + +from src.platform_io.types import DeliveryReceipt, DriverDescriptor, InboundMessageEnvelope, RouteKey + +if TYPE_CHECKING: + from src.chat.message_receive.message import SessionMessage + +InboundHandler = Callable[[InboundMessageEnvelope], Awaitable[bool]] + + +class PlatformIODriver(ABC): + """定义所有 Platform IO 驱动都必须实现的最小契约。 + + 当前实现故意保持接口很小,让中间层可以先落地,再逐步把 legacy + 与 plugin 路径的真实收发能力迁入这套协议之下。 + """ + + def __init__(self, descriptor: DriverDescriptor) -> None: + """使用驱动描述对象初始化驱动。 + + Args: + descriptor: 注册到 Broker 中的静态驱动元数据。 + """ + self._descriptor = descriptor + self._inbound_handler: Optional[InboundHandler] = None + + @property + def descriptor(self) -> DriverDescriptor: + """返回当前驱动的描述对象。 + + Returns: + DriverDescriptor: 当前驱动实例对应的描述对象。 + """ + return self._descriptor + + @property + def driver_id(self) -> str: + """返回驱动标识。 + + Returns: + str: 当前驱动的唯一 ID。 + """ + return self._descriptor.driver_id + + def set_inbound_handler(self, handler: InboundHandler) -> None: + """注册入站消息交回 Broker 的回调函数。 + + Args: + handler: 将规范化入站封装继续转发给 Broker 的异步回调。 + """ + self._inbound_handler = handler + + def clear_inbound_handler(self) -> None: + """清除当前注册的入站回调函数。""" + self._inbound_handler = None + + async def emit_inbound(self, envelope: InboundMessageEnvelope) -> bool: + """将一条入站封装转交给 Broker 回调。 + + Args: + envelope: 由驱动产出的规范化入站封装。 + + Returns: + bool: 若 Broker 接受该入站消息则返回 ``True``,否则返回 ``False``。 + """ + + if self._inbound_handler is None: + return False + return await self._inbound_handler(envelope) + + async def start(self) -> None: + """启动驱动生命周期。 + + 子类后续若需要初始化逻辑,可以覆盖这个钩子。 + """ + return None + + async def stop(self) -> None: + """停止驱动生命周期。 + + 子类后续若需要清理逻辑,可以覆盖这个钩子。 + """ + return None + + @abstractmethod + async def send_message( + self, + message: "SessionMessage", + route_key: RouteKey, + metadata: Optional[Dict[str, Any]] = None, + ) -> DeliveryReceipt: + """通过具体驱动发送一条消息。 + + Args: + message: 要投递的内部会话消息。 + route_key: Broker 为本次投递选中的路由键。 + metadata: 本次出站投递可选的 Broker 侧元数据。 + + Returns: + DeliveryReceipt: 规范化后的投递结果。 + """ diff --git a/src/platform_io/drivers/legacy_driver.py b/src/platform_io/drivers/legacy_driver.py new file mode 100644 index 00000000..bd74d8c7 --- /dev/null +++ b/src/platform_io/drivers/legacy_driver.py @@ -0,0 +1,61 @@ +"""提供 Platform IO 的 legacy 传输驱动骨架。""" + +from typing import TYPE_CHECKING, Any, Dict, Optional + +from src.platform_io.drivers.base import PlatformIODriver +from src.platform_io.types import DeliveryReceipt, DriverDescriptor, DriverKind, RouteKey + +if TYPE_CHECKING: + from src.chat.message_receive.message import SessionMessage + + +class LegacyPlatformDriver(PlatformIODriver): + """面向 ``maim_message`` 旧链路的 Platform IO 驱动骨架。""" + + def __init__( + self, + driver_id: str, + platform: str, + account_id: Optional[str] = None, + scope: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> None: + """初始化一个 legacy 驱动描述对象。 + + Args: + driver_id: Broker 内的唯一驱动 ID。 + platform: 该 legacy 适配器链路负责的平台。 + account_id: 可选的账号 ID 或 self ID。 + scope: 可选的额外路由作用域。 + metadata: 可选的额外驱动元数据。 + """ + descriptor = DriverDescriptor( + driver_id=driver_id, + kind=DriverKind.LEGACY, + platform=platform, + account_id=account_id, + scope=scope, + metadata=metadata or {}, + ) + super().__init__(descriptor) + + async def send_message( + self, + message: "SessionMessage", + route_key: RouteKey, + metadata: Optional[Dict[str, Any]] = None, + ) -> DeliveryReceipt: + """通过 legacy 传输路径发送消息。 + + Args: + message: 要投递的内部会话消息。 + route_key: Broker 为本次投递选择的路由键。 + metadata: 本次出站投递可选的 Broker 侧元数据。 + + Returns: + DeliveryReceipt: 由驱动返回的规范化回执。 + + Raises: + NotImplementedError: 当前仍处于骨架阶段,尚未真正接入旧发送链。 + """ + raise NotImplementedError("LegacyPlatformDriver 仅完成地基实现,尚未接入旧发送链") diff --git a/src/platform_io/drivers/plugin_driver.py b/src/platform_io/drivers/plugin_driver.py new file mode 100644 index 00000000..9c139309 --- /dev/null +++ b/src/platform_io/drivers/plugin_driver.py @@ -0,0 +1,64 @@ +"""提供 Platform IO 的 plugin 传输驱动骨架。""" + +from typing import TYPE_CHECKING, Any, Dict, Optional + +from src.platform_io.drivers.base import PlatformIODriver +from src.platform_io.types import DeliveryReceipt, DriverDescriptor, DriverKind, RouteKey + +if TYPE_CHECKING: + from src.chat.message_receive.message import SessionMessage + + +class PluginPlatformDriver(PlatformIODriver): + """面向 ``MessageGateway`` 插件链路的 Platform IO 驱动骨架。""" + + def __init__( + self, + driver_id: str, + platform: str, + account_id: Optional[str] = None, + scope: Optional[str] = None, + plugin_id: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> None: + """初始化一个 plugin 驱动描述对象。 + + Args: + driver_id: Broker 内的唯一驱动 ID。 + platform: 该 plugin 适配器链路负责的平台。 + account_id: 可选的账号 ID 或 self ID。 + scope: 可选的额外路由作用域。 + plugin_id: 拥有该适配器实现的插件 ID,可为空。 + metadata: 可选的额外驱动元数据。 + """ + descriptor = DriverDescriptor( + driver_id=driver_id, + kind=DriverKind.PLUGIN, + platform=platform, + account_id=account_id, + scope=scope, + plugin_id=plugin_id, + metadata=metadata or {}, + ) + super().__init__(descriptor) + + async def send_message( + self, + message: "SessionMessage", + route_key: RouteKey, + metadata: Optional[Dict[str, Any]] = None, + ) -> DeliveryReceipt: + """通过 plugin 传输路径发送消息。 + + Args: + message: 要投递的内部会话消息。 + route_key: Broker 为本次投递选择的路由键。 + metadata: 本次出站投递可选的 Broker 侧元数据。 + + Returns: + DeliveryReceipt: 由驱动返回的规范化回执。 + + Raises: + NotImplementedError: 当前仍处于骨架阶段,尚未真正接入 MessageGateway。 + """ + raise NotImplementedError("PluginPlatformDriver 仅完成地基实现,尚未接入 MessageGateway") diff --git a/src/platform_io/manager.py b/src/platform_io/manager.py new file mode 100644 index 00000000..6135a567 --- /dev/null +++ b/src/platform_io/manager.py @@ -0,0 +1,529 @@ +"""提供 Platform IO 层的中心 Broker 管理器。""" + +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional + +import hashlib +import json + +from src.common.logger import get_logger +from src.platform_io.drivers.base import PlatformIODriver + +from .dedupe import MessageDeduplicator +from .outbound_tracker import OutboundTracker +from .route_key_factory import RouteKeyFactory +from .registry import DriverRegistry +from .routing import RouteTable +from .types import DeliveryReceipt, DeliveryStatus, InboundMessageEnvelope, RouteBinding, RouteKey + +if TYPE_CHECKING: + from src.chat.message_receive.message import SessionMessage + +logger = get_logger("platform_io.manager") + +InboundDispatcher = Callable[[InboundMessageEnvelope], Awaitable[None]] + + +class PlatformIOManager: + """统一协调双路径平台消息 IO 的路由、去重与状态跟踪。 + + 这个管理器预期会成为 legacy 适配器链路与 plugin 适配器链路之间的 + 唯一裁决点。当前地基阶段,它只提供共享状态和 Broker 侧契约,还没有 + 真正把生产流量切到新中间层。 + """ + + def __init__(self) -> None: + """初始化 Broker 管理器及其内存状态。""" + self._driver_registry = DriverRegistry() + self._route_table = RouteTable() + self._deduplicator = MessageDeduplicator() + self._outbound_tracker = OutboundTracker() + self._inbound_dispatcher: Optional[InboundDispatcher] = None + self._started = False + + @property + def is_started(self) -> bool: + """返回 Broker 当前是否已进入运行态。 + + Returns: + bool: 若 Broker 已启动则返回 ``True``。 + """ + return self._started + + async def start(self) -> None: + """启动 Broker,并依次启动当前已注册的全部驱动。 + + Raises: + Exception: 当某个驱动启动失败时,异常会继续上抛;已成功启动的驱动 + 会被自动回滚停止。 + """ + if self._started: + return + + started_drivers: list[PlatformIODriver] = [] + try: + for driver in self._driver_registry.list(): + await driver.start() + started_drivers.append(driver) + except Exception: + for driver in reversed(started_drivers): + try: + await driver.stop() + except Exception: + logger.exception("回滚驱动停止失败: driver_id=%s", driver.driver_id) + raise + + self._started = True + + async def stop(self) -> None: + """停止 Broker,并按逆序停止全部已注册驱动。 + + 停止完成后,会同步清空仅对当前运行周期有效的去重缓存和出站跟踪状态, + 避免下一次启动时继续沿用上一个运行周期的瞬时内存数据。 + + Raises: + RuntimeError: 当一个或多个驱动停止失败时抛出汇总异常。 + """ + if not self._started: + return + + stop_errors: list[str] = [] + for driver in reversed(self._driver_registry.list()): + try: + await driver.stop() + except Exception as exc: + stop_errors.append(f"{driver.driver_id}: {exc}") + logger.exception("驱动停止失败: driver_id=%s", driver.driver_id) + + self._started = False + self._deduplicator.clear() + self._outbound_tracker.clear() + if stop_errors: + raise RuntimeError(f"部分驱动停止失败: {'; '.join(stop_errors)}") + + async def add_driver(self, driver: PlatformIODriver) -> None: + """向运行中的 Broker 注册并启动一个驱动。 + + 如果 Broker 尚未启动,则该方法等价于 ``register_driver()``。 + + Args: + driver: 要添加的驱动实例。 + + Raises: + Exception: 当驱动启动失败时,注册会自动回滚,异常继续上抛。 + """ + self._register_driver_internal(driver) + if not self._started: + return + + try: + await driver.start() + except Exception: + self._unregister_driver_internal(driver.driver_id) + raise + + async def remove_driver(self, driver_id: str) -> Optional[PlatformIODriver]: + """从运行中的 Broker 停止并移除一个驱动。 + + 如果 Broker 尚未启动,则该方法等价于 ``unregister_driver()``。 + + Args: + driver_id: 要移除的驱动 ID。 + + Returns: + Optional[PlatformIODriver]: 若驱动存在,则返回被移除的驱动实例。 + + Raises: + Exception: 当 Broker 运行中且驱动停止失败时,异常会继续上抛。 + """ + if not self._started: + return self.unregister_driver(driver_id) + + driver = self._driver_registry.get(driver_id) + if driver is None: + return None + + await driver.stop() + return self._unregister_driver_internal(driver_id) + + @property + def driver_registry(self) -> DriverRegistry: + """返回管理器持有的驱动注册表。 + + Returns: + DriverRegistry: 用于保存全部已注册驱动的注册表。 + """ + return self._driver_registry + + @property + def route_table(self) -> RouteTable: + """返回管理器持有的路由绑定表。 + + Returns: + RouteTable: 用于归属解析的路由绑定表。 + """ + return self._route_table + + @property + def deduplicator(self) -> MessageDeduplicator: + """返回管理器持有的入站去重器。 + + Returns: + MessageDeduplicator: 用于抑制重复入站的去重器。 + """ + return self._deduplicator + + @property + def outbound_tracker(self) -> OutboundTracker: + """返回管理器持有的出站跟踪器。 + + Returns: + OutboundTracker: 用于记录出站 pending 状态与回执的跟踪器。 + """ + return self._outbound_tracker + + def set_inbound_dispatcher(self, dispatcher: InboundDispatcher) -> None: + """设置统一的入站分发回调。 + + Args: + dispatcher: 接收已通过 Broker 审核的入站封装,并继续送入 + Core 下一处理阶段的异步回调。 + """ + + self._inbound_dispatcher = dispatcher + + def clear_inbound_dispatcher(self) -> None: + """清除当前的入站分发回调。""" + self._inbound_dispatcher = None + + @property + def has_inbound_dispatcher(self) -> bool: + """返回当前是否已经配置入站分发回调。 + + Returns: + bool: 若已经配置入站分发回调则返回 ``True``。 + """ + return self._inbound_dispatcher is not None + + def register_driver(self, driver: PlatformIODriver) -> None: + """注册驱动,并把它的入站回调挂到 Broker。 + + Args: + driver: 要注册的驱动实例。 + + Raises: + RuntimeError: 当 Broker 已经处于运行态时抛出。此时应改用 + ``add_driver()`` 以保证驱动生命周期和注册状态一致。 + """ + if self._started: + raise RuntimeError("Broker 运行中不允许直接 register_driver,请改用 add_driver()") + + self._register_driver_internal(driver) + + def _register_driver_internal(self, driver: PlatformIODriver) -> None: + """执行不带运行态限制的内部驱动注册。 + + Args: + driver: 要注册的驱动实例。 + """ + driver.set_inbound_handler(self.accept_inbound) + self._driver_registry.register(driver) + + def unregister_driver(self, driver_id: str) -> Optional[PlatformIODriver]: + """从 Broker 注销一个驱动。 + + Args: + driver_id: 要移除的驱动 ID。 + + Returns: + Optional[PlatformIODriver]: 若驱动存在,则返回被移除的驱动实例。 + + Raises: + RuntimeError: 当 Broker 已经处于运行态时抛出。此时应改用 + ``remove_driver()``,避免驱动停止与路由解绑脱节。 + """ + if self._started: + raise RuntimeError("Broker 运行中不允许直接 unregister_driver,请改用 remove_driver()") + + return self._unregister_driver_internal(driver_id) + + def _unregister_driver_internal(self, driver_id: str) -> Optional[PlatformIODriver]: + """执行不带运行态限制的内部驱动注销。 + + Args: + driver_id: 要移除的驱动 ID。 + + Returns: + Optional[PlatformIODriver]: 若驱动存在,则返回被移除的驱动实例。 + """ + removed_driver = self._driver_registry.unregister(driver_id) + if removed_driver is None: + return None + + removed_driver.clear_inbound_handler() + self._route_table.remove_bindings_by_driver(driver_id) + return removed_driver + + def bind_route(self, binding: RouteBinding, *, replace: bool = False) -> None: + """为某个路由键绑定驱动。 + + Args: + binding: 要保存的路由绑定。 + replace: 是否允许替换已有的精确 active owner。 + + Raises: + ValueError: 当绑定引用了不存在的驱动,或者绑定与驱动描述不一致时抛出。 + """ + driver = self._driver_registry.get(binding.driver_id) + if driver is None: + raise ValueError(f"驱动 {binding.driver_id} 未注册,无法绑定路由") + + self._validate_binding_against_driver(binding, driver) + self._route_table.bind(binding, replace=replace) + + def unbind_route(self, route_key: RouteKey, driver_id: Optional[str] = None) -> None: + """移除一个或多个路由绑定。 + + Args: + route_key: 要移除绑定的路由键。 + driver_id: 可选的特定驱动 ID。 + """ + self._route_table.unbind(route_key, driver_id) + + def resolve_driver(self, route_key: RouteKey) -> Optional[PlatformIODriver]: + """解析某个路由键当前的 active 驱动。 + + Args: + route_key: 要解析的路由键。 + + Returns: + Optional[PlatformIODriver]: 若存在 active 驱动,则返回该驱动实例。 + """ + active_binding = self._route_table.get_active_binding(route_key) + if active_binding is None: + return None + return self._driver_registry.get(active_binding.driver_id) + + @staticmethod + def build_route_key_from_message(message: "SessionMessage") -> RouteKey: + """根据 ``SessionMessage`` 构造路由键。 + + Args: + message: 内部会话消息对象。 + + Returns: + RouteKey: 由消息内容提取出的规范化路由键。 + """ + return RouteKeyFactory.from_session_message(message) + + @staticmethod + def build_route_key_from_message_dict(message_dict: Dict[str, Any]) -> RouteKey: + """根据消息字典构造路由键。 + + Args: + message_dict: Host 与插件之间传输的消息字典。 + + Returns: + RouteKey: 由消息字典提取出的规范化路由键。 + """ + return RouteKeyFactory.from_message_dict(message_dict) + + async def accept_inbound(self, envelope: InboundMessageEnvelope) -> bool: + """处理一条由驱动上报的入站封装。 + + Args: + envelope: 由传输驱动产出的入站封装。 + + Returns: + bool: 若消息被接受并继续转发给入站分发器,则返回 ``True``, + 否则返回 ``False``。 + """ + + if not self._route_table.accepts_inbound(envelope.route_key, envelope.driver_id): + logger.info( + "忽略非 active owner 的入站消息: route=%s driver=%s", + envelope.route_key, + envelope.driver_id, + ) + return False + + if self._inbound_dispatcher is None: + logger.debug("PlatformIOManager 尚未配置 inbound dispatcher,暂不继续分发") + return False + + dedupe_key = self._build_inbound_dedupe_key(envelope) + if dedupe_key is not None: + if not self._deduplicator.mark_seen(dedupe_key): + logger.info("忽略重复入站消息: dedupe_key=%s", dedupe_key) + return False + + await self._inbound_dispatcher(envelope) + return True + + async def send_message( + self, + message: "SessionMessage", + route_key: RouteKey, + metadata: Optional[Dict[str, Any]] = None, + ) -> DeliveryReceipt: + """通过 Broker 选中的驱动发送一条消息。 + + Args: + message: 要投递的内部会话消息。 + route_key: 本次出站投递选择的路由键。 + metadata: 可选的额外 Broker 侧元数据。 + + Returns: + DeliveryReceipt: 规范化后的出站回执。若路由不存在、驱动缺失, + 或同一消息已存在未完成的出站跟踪,也会返回失败回执而不是抛异常。 + """ + + active_binding = self._route_table.get_active_binding(route_key) + if active_binding is None: + return DeliveryReceipt( + internal_message_id=message.message_id, + route_key=route_key, + status=DeliveryStatus.FAILED, + error="未找到 active 路由绑定", + ) + + driver = self._driver_registry.get(active_binding.driver_id) + if driver is None: + return DeliveryReceipt( + internal_message_id=message.message_id, + route_key=route_key, + status=DeliveryStatus.FAILED, + driver_id=active_binding.driver_id, + driver_kind=active_binding.driver_kind, + error="active 路由绑定对应的驱动不存在", + ) + + try: + self._outbound_tracker.begin_tracking( + internal_message_id=message.message_id, + route_key=route_key, + driver_id=driver.driver_id, + metadata=metadata, + ) + except ValueError as exc: + return DeliveryReceipt( + internal_message_id=message.message_id, + route_key=route_key, + status=DeliveryStatus.FAILED, + driver_id=driver.driver_id, + driver_kind=driver.descriptor.kind, + error=str(exc), + ) + + try: + receipt = await driver.send_message(message=message, route_key=route_key, metadata=metadata) + except Exception as exc: + receipt = DeliveryReceipt( + internal_message_id=message.message_id, + route_key=route_key, + status=DeliveryStatus.FAILED, + driver_id=driver.driver_id, + driver_kind=driver.descriptor.kind, + error=str(exc), + ) + + self._outbound_tracker.finish_tracking(receipt) + return receipt + + @staticmethod + def _build_inbound_dedupe_key(envelope: InboundMessageEnvelope) -> Optional[str]: + """构造用于入站抑制的去重键。 + + Args: + envelope: 当前正在处理的入站封装。 + + Returns: + Optional[str]: 若可以构造稳定去重键则返回该键,否则返回 ``None``。 + """ + raw_dedupe_key = envelope.dedupe_key or envelope.external_message_id + if raw_dedupe_key is None and envelope.session_message is not None: + raw_dedupe_key = envelope.session_message.message_id + if raw_dedupe_key is None and envelope.payload is not None: + raw_dedupe_key = PlatformIOManager._build_payload_fingerprint(envelope.payload) + if raw_dedupe_key is None: + return None + + normalized_dedupe_key = str(raw_dedupe_key).strip() + if not normalized_dedupe_key: + return None + + return f"{envelope.route_key.to_dedupe_scope()}:{normalized_dedupe_key}" + + @staticmethod + def _build_payload_fingerprint(payload: Dict[str, Any]) -> Optional[str]: + """根据消息载荷构造稳定指纹。 + + Args: + payload: 待构造指纹的原始载荷字典。 + + Returns: + Optional[str]: 若成功生成指纹则返回十六进制摘要,否则返回 ``None``。 + """ + try: + serialized_payload = json.dumps( + payload, + default=str, + ensure_ascii=True, + separators=(",", ":"), + sort_keys=True, + ) + except Exception: + return None + + return hashlib.sha256(serialized_payload.encode()).hexdigest() + + @staticmethod + def _validate_binding_against_driver(binding: RouteBinding, driver: PlatformIODriver) -> None: + """校验路由绑定与驱动描述是否一致。 + + Args: + binding: 待校验的路由绑定。 + driver: 被绑定的驱动实例。 + + Raises: + ValueError: 当绑定类型、平台或更细粒度路由维度与驱动描述冲突时抛出。 + """ + descriptor = driver.descriptor + if binding.driver_kind != descriptor.kind: + raise ValueError( + f"路由绑定的 driver_kind={binding.driver_kind} 与驱动 {driver.driver_id} 的类型 " + f"{descriptor.kind} 不一致" + ) + + if binding.route_key.platform != descriptor.platform: + raise ValueError( + f"路由绑定的平台 {binding.route_key.platform} 与驱动 {driver.driver_id} 的平台 " + f"{descriptor.platform} 不一致" + ) + + if descriptor.account_id is not None and binding.route_key.account_id not in (None, descriptor.account_id): + raise ValueError( + f"路由绑定的 account_id={binding.route_key.account_id} 与驱动 {driver.driver_id} 的 " + f"account_id={descriptor.account_id} 冲突" + ) + + if descriptor.scope is not None and binding.route_key.scope not in (None, descriptor.scope): + raise ValueError( + f"路由绑定的 scope={binding.route_key.scope} 与驱动 {driver.driver_id} 的 " + f"scope={descriptor.scope} 冲突" + ) + + +_platform_io_manager: Optional[PlatformIOManager] = None + + +def get_platform_io_manager() -> PlatformIOManager: + """返回全局 ``PlatformIOManager`` 单例。 + + Returns: + PlatformIOManager: 进程级共享的 Broker 管理器实例。 + """ + + global _platform_io_manager + if _platform_io_manager is None: + _platform_io_manager = PlatformIOManager() + return _platform_io_manager diff --git a/src/platform_io/outbound_tracker.py b/src/platform_io/outbound_tracker.py new file mode 100644 index 00000000..438aa566 --- /dev/null +++ b/src/platform_io/outbound_tracker.py @@ -0,0 +1,242 @@ +"""跟踪 Platform IO 层的出站投递状态。 + +当前实现基于两组 ``dict + heapq``: +- ``_pending`` 和 ``_pending_expire_heap`` 负责管理待完成的出站记录 +- ``_receipts_by_external_id`` 和 ``_receipt_expire_heap`` 负责管理已完成回执索引 + +这样就不需要在每次读写时全表扫描过期项,而是通过懒清理逐步弹出已经过期 +或已经失效的堆节点。 +""" + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple + +import heapq +import time + +from .types import DeliveryReceipt, RouteKey + + +@dataclass(slots=True) +class PendingOutboundRecord: + """表示一条仍在等待完成的出站投递记录。 + + Attributes: + internal_message_id: 正在跟踪的内部 ``SessionMessage.message_id``。 + route_key: 该出站投递开始时使用的路由键。 + driver_id: 负责这次出站投递的驱动 ID。 + created_at: 开始跟踪时记录的单调时钟时间戳。 + expires_at: 该待完成记录预计过期的单调时钟时间戳。 + metadata: 与待完成记录一同保留的额外 Broker 侧元数据。 + """ + + internal_message_id: str + route_key: RouteKey + driver_id: str + created_at: float = field(default_factory=time.monotonic) + expires_at: float = 0.0 + metadata: Dict[str, Any] = field(default_factory=dict) + + +@dataclass(slots=True) +class StoredDeliveryReceipt: + """表示一条已完成并暂存的出站回执。 + + Attributes: + receipt: 规范化后的出站投递回执。 + stored_at: 回执被写入索引时记录的单调时钟时间戳。 + expires_at: 该回执索引预计过期的单调时钟时间戳。 + """ + + receipt: DeliveryReceipt + stored_at: float = field(default_factory=time.monotonic) + expires_at: float = 0.0 + + +class OutboundTracker: + """统一跟踪出站消息的 pending 状态与最终回执。 + + 主要用于解决出站消息在发送过程中“状态散落在不同路径里”的问题: + - 发送开始后,需要在最终回执返回前保留一份 pending 状态 + - 平台返回 ``external_message_id`` 后,需要保留一段时间的回执索引 + + 当前实现使用 ``dict + heapq`` 做 TTL 管理: + - ``dict`` 提供 ``O(1)`` 级别的主键查询 + - ``heapq`` 提供按过期时间排序的懒清理能力 + + 这比“每次 begin/finish/get 都全表扫描”的实现更适合高吞吐出站场景。 + + Notes: + 复杂度说明如下,设 ``p`` 为当前有效 pending 数量,``r`` 为当前有效回执数量: + + - ``begin_tracking()``、``finish_tracking()`` 的常见路径时间复杂度接近 + ``O(log p)`` 或 ``O(log r)`` + - ``get_pending()``、``get_receipt_by_external_id()`` 的查询本身是 ``O(1)`` + ,连同懒清理一起看,长期摊还复杂度接近 ``O(log n)`` + - 如果某次调用恰好触发一批过期节点的集中清理,则该次调用的最坏时间复杂度 + 可达到 ``O(k log n)``,其中 ``k`` 为本次被弹出的节点数量 + - 空间复杂度为 ``O(p + r)`` + """ + + def __init__(self, ttl_seconds: float = 1800.0) -> None: + """初始化出站跟踪器。 + + Args: + ttl_seconds: 待完成记录与按外部消息 ID 建立的回执索引保留时长, + 单位为秒。 + + Raises: + ValueError: 当 ``ttl_seconds`` 非正数时抛出。 + """ + if ttl_seconds <= 0: + raise ValueError("ttl_seconds 必须大于 0") + + self._ttl_seconds = ttl_seconds + self._pending: Dict[str, PendingOutboundRecord] = {} + self._pending_expire_heap: List[Tuple[float, str]] = [] + self._receipts_by_external_id: Dict[str, StoredDeliveryReceipt] = {} + self._receipt_expire_heap: List[Tuple[float, str]] = [] + + def begin_tracking( + self, + internal_message_id: str, + route_key: RouteKey, + driver_id: str, + metadata: Optional[Dict[str, Any]] = None, + ) -> PendingOutboundRecord: + """开始跟踪一次出站投递。 + + Args: + internal_message_id: 正在投递的内部消息 ID。 + route_key: 这次出站投递选择的路由键。 + driver_id: 负责本次投递的驱动 ID。 + metadata: 可选的额外元数据,会一并保存在待完成记录中。 + + Returns: + PendingOutboundRecord: 新创建的待完成记录。 + + Raises: + ValueError: 当同一个 ``internal_message_id`` 已经存在未完成记录时抛出。 + """ + now = time.monotonic() + self._cleanup_expired(now) + + if internal_message_id in self._pending: + raise ValueError(f"消息 {internal_message_id} 已存在未完成的出站跟踪记录") + + expires_at = now + self._ttl_seconds + record = PendingOutboundRecord( + internal_message_id=internal_message_id, + route_key=route_key, + driver_id=driver_id, + created_at=now, + expires_at=expires_at, + metadata=metadata or {}, + ) + self._pending[internal_message_id] = record + heapq.heappush(self._pending_expire_heap, (expires_at, internal_message_id)) + return record + + def finish_tracking(self, receipt: DeliveryReceipt) -> Optional[PendingOutboundRecord]: + """使用最终回执结束一条出站跟踪。 + + Args: + receipt: 规范化后的最终投递回执。 + + Returns: + Optional[PendingOutboundRecord]: 若此前存在待完成记录,则返回该记录。 + """ + now = time.monotonic() + self._cleanup_expired(now) + + pending_record = self._pending.pop(receipt.internal_message_id, None) + if receipt.external_message_id: + expires_at = now + self._ttl_seconds + self._receipts_by_external_id[receipt.external_message_id] = StoredDeliveryReceipt( + receipt=receipt, + stored_at=now, + expires_at=expires_at, + ) + heapq.heappush(self._receipt_expire_heap, (expires_at, receipt.external_message_id)) + return pending_record + + def get_pending(self, internal_message_id: str) -> Optional[PendingOutboundRecord]: + """根据内部消息 ID 查询待完成记录。 + + Args: + internal_message_id: 要查询的内部消息 ID。 + + Returns: + Optional[PendingOutboundRecord]: 若记录仍存在,则返回对应待完成记录。 + """ + self._cleanup_expired(time.monotonic()) + return self._pending.get(internal_message_id) + + def get_receipt_by_external_id(self, external_message_id: str) -> Optional[DeliveryReceipt]: + """根据外部平台消息 ID 查询已完成回执。 + + Args: + external_message_id: 要查询的平台侧消息 ID。 + + Returns: + Optional[DeliveryReceipt]: 若存在对应回执,则返回该回执。 + """ + self._cleanup_expired(time.monotonic()) + stored_receipt = self._receipts_by_external_id.get(external_message_id) + return stored_receipt.receipt if stored_receipt else None + + def clear(self) -> None: + """清空全部待完成记录与已保存回执。""" + self._pending.clear() + self._pending_expire_heap.clear() + self._receipts_by_external_id.clear() + self._receipt_expire_heap.clear() + + def _cleanup_expired(self, now: float) -> None: + """清理内存中已经过期的待完成记录与已保存回执。 + + Args: + now: 当前单调时钟时间戳。 + """ + self._cleanup_expired_pending(now) + self._cleanup_expired_receipts(now) + + def _cleanup_expired_pending(self, now: float) -> None: + """清理已经过期的待完成记录。 + + Args: + now: 当前单调时钟时间戳。 + + Notes: + 堆中可能存在已经失效的旧节点。例如某条记录提前 ``finish`` 后, + 它原本的过期节点仍可能留在堆里。这里会通过和 ``dict`` 中当前记录的 + ``expires_at`` 对比,跳过这类旧节点。 + """ + while self._pending_expire_heap and self._pending_expire_heap[0][0] <= now: + expires_at, internal_message_id = heapq.heappop(self._pending_expire_heap) + current_record = self._pending.get(internal_message_id) + if current_record is None: + continue + if current_record.expires_at != expires_at: + continue + self._pending.pop(internal_message_id, None) + + def _cleanup_expired_receipts(self, now: float) -> None: + """清理已经过期的回执索引。 + + Args: + now: 当前单调时钟时间戳。 + + Notes: + 同一个 ``external_message_id`` 在极端情况下可能被重复写入索引, + 因此这里同样需要通过 ``expires_at`` 和当前 ``dict`` 中的值比对, + 跳过已经失效的旧堆节点。 + """ + while self._receipt_expire_heap and self._receipt_expire_heap[0][0] <= now: + expires_at, external_message_id = heapq.heappop(self._receipt_expire_heap) + current_receipt = self._receipts_by_external_id.get(external_message_id) + if current_receipt is None: + continue + if current_receipt.expires_at != expires_at: + continue + self._receipts_by_external_id.pop(external_message_id, None) diff --git a/src/platform_io/registry.py b/src/platform_io/registry.py new file mode 100644 index 00000000..9ad8ea8a --- /dev/null +++ b/src/platform_io/registry.py @@ -0,0 +1,70 @@ +"""提供 Platform IO 的驱动注册与查询能力。""" + +from typing import Dict, List, Optional + +from src.platform_io.drivers.base import PlatformIODriver +from src.platform_io.types import DriverKind + + +class DriverRegistry: + """集中保存已注册的 Platform IO 驱动,并提供基础查询接口。""" + + def __init__(self) -> None: + """初始化一个空的驱动注册表。""" + self._drivers: Dict[str, PlatformIODriver] = {} + + def register(self, driver: PlatformIODriver) -> None: + """注册一个驱动实例。 + + Args: + driver: 要注册的驱动实例。 + + Raises: + ValueError: 当驱动 ID 已经存在时抛出。 + """ + if driver.driver_id in self._drivers: + raise ValueError(f"驱动 {driver.driver_id} 已注册") + self._drivers[driver.driver_id] = driver + + def unregister(self, driver_id: str) -> Optional[PlatformIODriver]: + """按驱动 ID 注销一个驱动。 + + Args: + driver_id: 要移除的驱动 ID。 + + Returns: + Optional[PlatformIODriver]: 若驱动存在,则返回被移除的驱动实例。 + """ + return self._drivers.pop(driver_id, None) + + def get(self, driver_id: str) -> Optional[PlatformIODriver]: + """按驱动 ID 获取驱动实例。 + + Args: + driver_id: 要查询的驱动 ID。 + + Returns: + Optional[PlatformIODriver]: 若存在匹配驱动,则返回该驱动实例。 + """ + return self._drivers.get(driver_id) + + def list(self, *, kind: Optional[DriverKind] = None, platform: Optional[str] = None) -> List[PlatformIODriver]: + """列出已注册驱动,并支持可选过滤。 + + Args: + kind: 可选的驱动类型过滤条件。 + platform: 可选的平台名称过滤条件。 + + Returns: + List[PlatformIODriver]: 符合过滤条件的驱动列表。 + """ + drivers = list(self._drivers.values()) + if kind is not None: + drivers = [driver for driver in drivers if driver.descriptor.kind == kind] + if platform is not None: + drivers = [driver for driver in drivers if driver.descriptor.platform == platform] + return drivers + + def clear(self) -> None: + """清空全部已注册驱动。""" + self._drivers.clear() diff --git a/src/platform_io/route_key_factory.py b/src/platform_io/route_key_factory.py new file mode 100644 index 00000000..05bac6e8 --- /dev/null +++ b/src/platform_io/route_key_factory.py @@ -0,0 +1,150 @@ +"""提供 Platform IO 路由键的统一提取与构造能力。 + +这层的目标不是直接接入具体消息链,而是先把“未来接线时用什么字段构造 +RouteKey”约定下来,避免 legacy 和 plugin 两条链路各自发明一套隐式规则。 +""" + +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple + +from .types import RouteKey + +if TYPE_CHECKING: + from src.chat.message_receive.message import SessionMessage + + +class RouteKeyFactory: + """统一构造 ``RouteKey`` 的工厂。 + + 当前约定会优先从消息字典顶层、``message_info``、``additional_config`` 或传入 metadata 中提取 + 以下字段: + + - account_id: ``platform_io_account_id`` / ``account_id`` / ``self_id`` / ``bot_account`` + - scope: ``platform_io_scope`` / ``route_scope`` / ``adapter_scope`` / ``connection_id`` + + 这样即使上游主链暂时还没有正式的 ``self_id`` 字段,中间层也能先统一 + 约定提取口径,等具体消息链接入时直接复用。 + """ + + ACCOUNT_ID_KEYS = ( + "platform_io_account_id", + "account_id", + "self_id", + "bot_account", + ) + SCOPE_KEYS = ( + "platform_io_scope", + "route_scope", + "adapter_scope", + "connection_id", + ) + + @classmethod + def from_platform( + cls, + platform: str, + *, + account_id: Optional[str] = None, + scope: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> RouteKey: + """根据平台名和可选 metadata 构造 ``RouteKey``。 + + Args: + platform: 平台名称。 + account_id: 显式传入的账号 ID;若为空,则尝试从 metadata 提取。 + scope: 显式传入的路由作用域;若为空,则尝试从 metadata 提取。 + metadata: 可选的元数据字典。 + + Returns: + RouteKey: 构造出的规范化路由键。 + """ + extracted_account_id, extracted_scope = cls.extract_components(metadata) + return RouteKey( + platform=platform, + account_id=account_id or extracted_account_id, + scope=scope or extracted_scope, + ) + + @classmethod + def from_message_dict(cls, message_dict: Dict[str, Any]) -> RouteKey: + """从消息字典中提取 ``RouteKey``。 + + Args: + message_dict: Host 与插件之间传输的消息字典。 + + Returns: + RouteKey: 构造出的规范化路由键。 + + Raises: + ValueError: 当消息字典缺少有效 ``platform`` 字段时抛出。 + """ + platform = str(message_dict.get("platform") or "").strip() + if not platform: + raise ValueError("消息字典缺少有效的 platform 字段,无法构造 RouteKey") + + message_info = message_dict.get("message_info", {}) + additional_config = {} + if isinstance(message_info, dict): + raw_additional_config = message_info.get("additional_config", {}) + if isinstance(raw_additional_config, dict): + additional_config = raw_additional_config + + explicit_account_id, explicit_scope = cls.extract_components(message_dict) + message_info_account_id, message_info_scope = cls.extract_components(message_info) + metadata_account_id, metadata_scope = cls.extract_components(additional_config) + return RouteKey( + platform=platform, + account_id=explicit_account_id or message_info_account_id or metadata_account_id, + scope=explicit_scope or message_info_scope or metadata_scope, + ) + + @classmethod + def from_session_message(cls, message: "SessionMessage") -> RouteKey: + """从 ``SessionMessage`` 中提取 ``RouteKey``。 + + Args: + message: 内部会话消息对象。 + + Returns: + RouteKey: 构造出的规范化路由键。 + """ + additional_config = message.message_info.additional_config or {} + metadata = additional_config if isinstance(additional_config, dict) else {} + return cls.from_platform(message.platform, metadata=metadata) + + @classmethod + def extract_components(cls, mapping: Optional[Dict[str, Any]]) -> Tuple[Optional[str], Optional[str]]: + """从任意字典中提取 ``account_id`` 与 ``scope``。 + + Args: + mapping: 待提取的字典;若为空或不是字典,则返回空结果。 + + Returns: + Tuple[Optional[str], Optional[str]]: ``(account_id, scope)``。 + """ + if not mapping or not isinstance(mapping, dict): + return None, None + + account_id = cls._pick_string(mapping, cls.ACCOUNT_ID_KEYS) + scope = cls._pick_string(mapping, cls.SCOPE_KEYS) + return account_id, scope + + @staticmethod + def _pick_string(mapping: Dict[str, Any], keys: Tuple[str, ...]) -> Optional[str]: + """按优先级从字典里挑选第一个有效字符串。 + + Args: + mapping: 待查询的字典。 + keys: 按优先级排列的候选键名。 + + Returns: + Optional[str]: 第一个规范化后非空的字符串值;若不存在则返回 ``None``。 + """ + for key in keys: + value = mapping.get(key) + if value is None: + continue + normalized = str(value).strip() + if normalized: + return normalized + return None diff --git a/src/platform_io/routing.py b/src/platform_io/routing.py new file mode 100644 index 00000000..7f85bbfa --- /dev/null +++ b/src/platform_io/routing.py @@ -0,0 +1,202 @@ +"""提供 Platform IO 的路由绑定存储与归属解析能力。""" + +from typing import Dict, List, Optional + +from .types import RouteBinding, RouteKey, RouteMode + + +class RouteBindingConflictError(ValueError): + """当同一路由键出现多个 active owner 竞争时抛出。""" + + +class RouteTable: + """维护路由绑定并解析路由归属。 + + 这个表刻意保持轻量,只负责归属规则本身,不掺杂具体发送或接收逻辑。 + 它决定某个路由键当前由哪个驱动 active 接管,哪些驱动仅以 shadow + 方式旁路观测。 + """ + + def __init__(self) -> None: + """初始化一个空的路由绑定表。""" + self._bindings: Dict[RouteKey, Dict[str, RouteBinding]] = {} + + def bind(self, binding: RouteBinding, *, replace: bool = False) -> None: + """注册或更新一条路由绑定。 + + Args: + binding: 要注册的绑定对象。 + replace: 当精确路由键上已经存在 active owner 时,是否允许替换。 + + Raises: + RouteBindingConflictError: 当精确路由键上已存在其他 active owner, + 且 ``replace`` 为 ``False`` 时抛出。 + """ + + if binding.mode == RouteMode.DISABLED: + self.unbind(binding.route_key, binding.driver_id) + return + + if binding.mode == RouteMode.ACTIVE: + active_binding = self.get_active_binding(binding.route_key, exact_only=True) + if active_binding and active_binding.driver_id != binding.driver_id: + if not replace: + raise RouteBindingConflictError( + f"RouteKey {binding.route_key} 已由 {active_binding.driver_id} 接管," + f"拒绝绑定到 {binding.driver_id}" + ) + self.unbind(binding.route_key, active_binding.driver_id) + + self._bindings.setdefault(binding.route_key, {})[binding.driver_id] = binding + + def unbind(self, route_key: RouteKey, driver_id: Optional[str] = None) -> List[RouteBinding]: + """移除指定路由键上的绑定。 + + Args: + route_key: 要移除绑定的路由键。 + driver_id: 可选的特定驱动 ID;若为空,则移除该路由键上的全部绑定。 + + Returns: + List[RouteBinding]: 被移除的绑定列表。 + """ + + binding_map = self._bindings.get(route_key) + if not binding_map: + return [] + + if driver_id is None: + removed = list(binding_map.values()) + self._bindings.pop(route_key, None) + return removed + + removed_binding = binding_map.pop(driver_id, None) + if not binding_map: + self._bindings.pop(route_key, None) + return [removed_binding] if removed_binding else [] + + def remove_bindings_by_driver(self, driver_id: str) -> List[RouteBinding]: + """移除某个驱动在所有路由键上的绑定。 + + Args: + driver_id: 要移除绑定的驱动 ID。 + + Returns: + List[RouteBinding]: 被移除的绑定列表。 + """ + removed_bindings: List[RouteBinding] = [] + empty_route_keys: List[RouteKey] = [] + + for route_key, binding_map in self._bindings.items(): + removed_binding = binding_map.pop(driver_id, None) + if removed_binding is not None: + removed_bindings.append(removed_binding) + if not binding_map: + empty_route_keys.append(route_key) + + for route_key in empty_route_keys: + self._bindings.pop(route_key, None) + + return self._sort_bindings(removed_bindings) + + def list_bindings(self, route_key: Optional[RouteKey] = None) -> List[RouteBinding]: + """列出当前绑定。 + + Args: + route_key: 可选的路由键过滤条件;若为空,则返回全部路由键上的绑定。 + + Returns: + List[RouteBinding]: 按优先级降序排列的绑定列表。 + """ + + if route_key is None: + bindings: List[RouteBinding] = [] + for binding_map in self._bindings.values(): + bindings.extend(binding_map.values()) + return self._sort_bindings(bindings) + + binding_map = self._bindings.get(route_key, {}) + return self._sort_bindings(list(binding_map.values())) + + def get_active_binding(self, route_key: RouteKey, *, exact_only: bool = False) -> Optional[RouteBinding]: + """获取某个路由键当前生效的 active 绑定。 + + Args: + route_key: 要解析的路由键。 + exact_only: 是否只检查精确路由键而不做回退解析。 + + Returns: + Optional[RouteBinding]: 若存在 active owner,则返回对应绑定。 + """ + + candidate_keys = [route_key] if exact_only else route_key.resolution_order() + for candidate_key in candidate_keys: + binding_map = self._bindings.get(candidate_key, {}) + active_binding = self._pick_best_binding(binding_map, RouteMode.ACTIVE) + if active_binding is not None: + return active_binding + return None + + def get_shadow_bindings(self, route_key: RouteKey) -> List[RouteBinding]: + """获取某个精确路由键上的 shadow 绑定。 + + Args: + route_key: 要查看的路由键。 + + Returns: + List[RouteBinding]: 按优先级降序排列的 shadow 绑定列表。 + """ + binding_map = self._bindings.get(route_key, {}) + shadow_bindings = [binding for binding in binding_map.values() if binding.mode == RouteMode.SHADOW] + return self._sort_bindings(shadow_bindings) + + def accepts_inbound(self, route_key: RouteKey, driver_id: str) -> bool: + """判断某个驱动是否是当前允许入 Core 的 active owner。 + + Args: + route_key: 入站消息对应的路由键。 + driver_id: 希望将消息送入 Core 的驱动 ID。 + + Returns: + bool: 若该驱动是解析结果中的 active owner,则返回 ``True``。 + """ + + active_binding = self.get_active_binding(route_key) + return active_binding is not None and active_binding.driver_id == driver_id + + @staticmethod + def _sort_bindings(bindings: List[RouteBinding]) -> List[RouteBinding]: + """按优先级降序排列绑定列表。 + + Args: + bindings: 待排序的绑定列表。 + + Returns: + List[RouteBinding]: 排序后的绑定列表。 + """ + return sorted(bindings, key=lambda item: item.priority, reverse=True) + + @staticmethod + def _pick_best_binding( + binding_map: Dict[str, RouteBinding], + mode: RouteMode, + ) -> Optional[RouteBinding]: + """从绑定映射中挑选指定模式下优先级最高的一条绑定。 + + Args: + binding_map: 某个精确 ``RouteKey`` 对应的绑定映射。 + mode: 需要挑选的绑定模式。 + + Returns: + Optional[RouteBinding]: 若存在匹配模式的绑定,则返回优先级最高的一条。 + + Notes: + 这里使用单次线性扫描代替“先过滤成列表再排序”的做法,以减少 + 高频路由解析路径上的临时对象分配和排序开销。 + """ + best_binding: Optional[RouteBinding] = None + for binding in binding_map.values(): + if binding.mode != mode: + continue + if best_binding is None or binding.priority > best_binding.priority: + best_binding = binding + return best_binding diff --git a/src/platform_io/types.py b/src/platform_io/types.py new file mode 100644 index 00000000..c74dc246 --- /dev/null +++ b/src/platform_io/types.py @@ -0,0 +1,240 @@ +"""定义 Platform IO 中间层共享的核心类型。 + +本模块放置路由、驱动、入站与出站等规范化数据结构,供 Broker +层在 legacy 适配器链路和 plugin 适配器链路之间复用。 +""" + +from dataclasses import dataclass, field +from enum import Enum +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +if TYPE_CHECKING: + from src.chat.message_receive.message import SessionMessage + + +class DriverKind(str, Enum): + """底层收发驱动类型枚举。""" + + LEGACY = "legacy" + PLUGIN = "plugin" + + +class RouteMode(str, Enum): + """路由归属模式枚举。""" + + ACTIVE = "active" + SHADOW = "shadow" + DISABLED = "disabled" + + +class DeliveryStatus(str, Enum): + """统一出站回执状态枚举。""" + + PENDING = "pending" + SENT = "sent" + FAILED = "failed" + DROPPED = "dropped" + + +@dataclass(frozen=True, slots=True) +class RouteKey: + """用于 Platform IO 路由决策的唯一键。 + + 路由解析会按照“从最具体到最宽泛”的顺序进行回退,这样同一平台 + 后续就能自然支持按账号、自定义 scope 等更细粒度的归属控制。 + + Attributes: + platform: 平台名称,例如 ``qq``。 + account_id: 机器人账号 ID 或 self ID,用于区分同平台多身份。 + scope: 额外路由作用域,预留给未来的连接实例、租户或子通道等维度。 + """ + + platform: str + account_id: Optional[str] = None + scope: Optional[str] = None + + def __post_init__(self) -> None: + """规范化并校验路由键字段。 + + Raises: + ValueError: 当 ``platform`` 规范化后为空时抛出。 + """ + platform = str(self.platform).strip() + account_id = str(self.account_id).strip() if self.account_id is not None else None + scope = str(self.scope).strip() if self.scope is not None else None + + if not platform: + raise ValueError("RouteKey.platform 不能为空") + + object.__setattr__(self, "platform", platform) + object.__setattr__(self, "account_id", account_id or None) + object.__setattr__(self, "scope", scope or None) + + def resolution_order(self) -> List["RouteKey"]: + """返回从最具体到最宽泛的路由匹配顺序。 + + Returns: + List[RouteKey]: 按回退优先级排序的候选路由键列表。 + """ + + keys: List[RouteKey] = [self] + + if self.account_id is not None and self.scope is not None: + keys.append(RouteKey(platform=self.platform, account_id=self.account_id, scope=None)) + keys.append(RouteKey(platform=self.platform, account_id=None, scope=self.scope)) + elif self.account_id is not None: + keys.append(RouteKey(platform=self.platform, account_id=None, scope=None)) + elif self.scope is not None: + keys.append(RouteKey(platform=self.platform, account_id=None, scope=None)) + + default_key = RouteKey(platform=self.platform, account_id=None, scope=None) + if default_key not in keys: + keys.append(default_key) + + return keys + + def to_dedupe_scope(self) -> str: + """生成跨驱动共享的去重作用域字符串。 + + Returns: + str: 用于入站消息去重的稳定文本作用域键。 + """ + + account_id = self.account_id or "*" + scope = self.scope or "*" + return f"{self.platform}:{account_id}:{scope}" + + +@dataclass(frozen=True, slots=True) +class DriverDescriptor: + """描述一个已注册的 Platform IO 驱动。 + + Attributes: + driver_id: Broker 层内全局唯一的驱动标识。 + kind: 驱动实现类型,例如 legacy 或 plugin。 + platform: 驱动负责的平台名称。 + account_id: 可选的账号 ID 或 self ID。 + scope: 可选的额外路由作用域。 + plugin_id: 当驱动来自插件适配器时,对应的插件 ID。 + metadata: 预留给路由策略或观测能力的额外驱动元数据。 + """ + + driver_id: str + kind: DriverKind + platform: str + account_id: Optional[str] = None + scope: Optional[str] = None + plugin_id: Optional[str] = None + metadata: Dict[str, Any] = field(default_factory=dict) + + def __post_init__(self) -> None: + """规范化并校验驱动描述字段。 + + Raises: + ValueError: 当 ``driver_id`` 或 ``platform`` 规范化后为空时抛出。 + """ + driver_id = str(self.driver_id).strip() + platform = str(self.platform).strip() + plugin_id = str(self.plugin_id).strip() if self.plugin_id is not None else None + + if not driver_id: + raise ValueError("DriverDescriptor.driver_id 不能为空") + if not platform: + raise ValueError("DriverDescriptor.platform 不能为空") + + object.__setattr__(self, "driver_id", driver_id) + object.__setattr__(self, "platform", platform) + object.__setattr__(self, "plugin_id", plugin_id or None) + + @property + def route_key(self) -> RouteKey: + """构造该驱动默认代表的路由键。 + + Returns: + RouteKey: 当前驱动描述对应的规范化路由键。 + """ + return RouteKey(platform=self.platform, account_id=self.account_id, scope=self.scope) + + +@dataclass(frozen=True, slots=True) +class RouteBinding: + """表示一条从路由键到驱动的归属绑定关系。 + + Attributes: + route_key: 该绑定覆盖的路由键。 + driver_id: 拥有或旁路观察该路由的驱动 ID。 + driver_kind: 绑定驱动的类型。 + mode: 绑定模式,例如 active owner 或 shadow observer。 + priority: 当同模式下存在多条绑定时使用的相对优先级。 + metadata: 预留给未来路由策略的额外绑定元数据。 + """ + + route_key: RouteKey + driver_id: str + driver_kind: DriverKind + mode: RouteMode = RouteMode.ACTIVE + priority: int = 0 + metadata: Dict[str, Any] = field(default_factory=dict) + + def __post_init__(self) -> None: + """规范化并校验绑定字段。 + + Raises: + ValueError: 当 ``driver_id`` 规范化后为空时抛出。 + """ + driver_id = str(self.driver_id).strip() + if not driver_id: + raise ValueError("RouteBinding.driver_id 不能为空") + object.__setattr__(self, "driver_id", driver_id) + + +@dataclass(slots=True) +class InboundMessageEnvelope: + """封装一次由驱动产出的规范化入站消息。 + + Attributes: + route_key: 该入站消息解析出的路由键。 + driver_id: 产出该消息的驱动 ID。 + driver_kind: 产出该消息的驱动类型。 + external_message_id: 可选的平台侧消息 ID,用于去重。 + dedupe_key: 可选的显式去重键。当外部消息没有稳定 ``message_id`` 时, + 可由上游驱动提供消息指纹。若这里为空,中间层仍可能继续回退到 + ``session_message.message_id`` 或 ``payload`` 指纹。 + session_message: 可选的、已经完成规范化的 ``SessionMessage`` 对象。 + payload: 可选的原始字典载荷,供延迟转换或调试使用。 + metadata: 额外入站元数据,例如连接信息或追踪上下文。 + """ + + route_key: RouteKey + driver_id: str + driver_kind: DriverKind + external_message_id: Optional[str] = None + dedupe_key: Optional[str] = None + session_message: Optional["SessionMessage"] = None + payload: Optional[Dict[str, Any]] = None + metadata: Dict[str, Any] = field(default_factory=dict) + + +@dataclass(slots=True) +class DeliveryReceipt: + """表示一次出站投递尝试的统一结果。 + + Attributes: + internal_message_id: Broker 跟踪的内部 ``SessionMessage.message_id``。 + route_key: 本次投递使用的路由键。 + status: 规范化后的投递状态。 + driver_id: 实际处理该投递的驱动 ID,可为空。 + driver_kind: 实际处理该投递的驱动类型,可为空。 + external_message_id: 驱动或适配器返回的平台侧消息 ID,可为空。 + error: 投递失败时的错误信息,可为空。 + metadata: 预留给回执、时间戳或平台特有信息的额外元数据。 + """ + + internal_message_id: str + route_key: RouteKey + status: DeliveryStatus + driver_id: Optional[str] = None + driver_kind: Optional[DriverKind] = None + external_message_id: Optional[str] = None + error: Optional[str] = None + metadata: Dict[str, Any] = field(default_factory=dict) From 07256182fbf2f649bf55bf68d2ac2ac3bc0e1593 Mon Sep 17 00:00:00 2001 From: DrSmoothl <1787882683@qq.com> Date: Fri, 20 Mar 2026 01:20:22 +0800 Subject: [PATCH 16/45] =?UTF-8?q?refactor(manager):=20=E4=BD=BF=E7=94=A8?= =?UTF-8?q?=20List=20=E7=B1=BB=E5=9E=8B=E6=9B=BF=E4=BB=A3=20list=EF=BC=8C?= =?UTF-8?q?=E5=A2=9E=E5=BC=BA=E7=B1=BB=E5=9E=8B=E4=B8=80=E8=87=B4=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/platform_io/manager.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/platform_io/manager.py b/src/platform_io/manager.py index 6135a567..97835667 100644 --- a/src/platform_io/manager.py +++ b/src/platform_io/manager.py @@ -1,6 +1,6 @@ """提供 Platform IO 层的中心 Broker 管理器。""" -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional import hashlib import json @@ -59,7 +59,7 @@ class PlatformIOManager: if self._started: return - started_drivers: list[PlatformIODriver] = [] + started_drivers: List[PlatformIODriver] = [] try: for driver in self._driver_registry.list(): await driver.start() @@ -86,7 +86,7 @@ class PlatformIOManager: if not self._started: return - stop_errors: list[str] = [] + stop_errors: List[str] = [] for driver in reversed(self._driver_registry.list()): try: await driver.stop() From e4850c469feb68c8bd0dfbb66593b5dd959fbdd6 Mon Sep 17 00:00:00 2001 From: DrSmoothl <1787882683@qq.com> Date: Fri, 20 Mar 2026 22:23:47 +0800 Subject: [PATCH 17/45] feat: Enhance plugin loading and management - Added module_name parameter to PluginMeta for better module tracking. - Improved documentation for PluginMeta and PluginLoader methods. - Introduced methods for managing loaded plugins: set_loaded_plugin, remove_loaded_plugin, and purge_plugin_modules. - Enhanced dependency resolution in PluginLoader with resolve_dependencies method. - Implemented candidate discovery and loading in PluginLoader. - Added support for plugin reloading with _reload_plugin_by_id in PluginRunner. - Improved error handling and logging throughout the RPCClient and PluginRunner. - Added support for handling hook invocations in PluginRunner. - Refactored plugin registration and unregistration processes for clarity and efficiency. --- src/plugin_runtime/capabilities/components.py | 15 +- src/plugin_runtime/capabilities/registry.py | 115 ++-- src/plugin_runtime/host/capability_service.py | 21 +- src/plugin_runtime/host/supervisor.py | 590 +++++++++++++++--- src/plugin_runtime/integration.py | 30 +- src/plugin_runtime/protocol/envelope.py | 39 +- src/plugin_runtime/runner/plugin_loader.py | 188 +++++- src/plugin_runtime/runner/rpc_client.py | 241 +++---- src/plugin_runtime/runner/runner_main.py | 445 +++++++++++-- 9 files changed, 1351 insertions(+), 333 deletions(-) diff --git a/src/plugin_runtime/capabilities/components.py b/src/plugin_runtime/capabilities/components.py index aa7ceb46..4223525f 100644 --- a/src/plugin_runtime/capabilities/components.py +++ b/src/plugin_runtime/capabilities/components.py @@ -174,7 +174,10 @@ class RuntimeComponentCapabilityMixin: if registered_supervisor is not None: try: - reloaded = await registered_supervisor.reload_plugins(reason=f"load {plugin_name}") + reloaded = await registered_supervisor.reload_plugins( + plugin_ids=[plugin_name], + reason=f"load {plugin_name}", + ) if reloaded: return {"success": True, "count": 1} return {"success": False, "error": f"插件 {plugin_name} 热重载失败,已回滚"} @@ -186,7 +189,10 @@ class RuntimeComponentCapabilityMixin: for pdir in sv._plugin_dirs: if (pdir / plugin_name).is_dir(): try: - reloaded = await sv.reload_plugins(reason=f"load {plugin_name}") + reloaded = await sv.reload_plugins( + plugin_ids=[plugin_name], + reason=f"load {plugin_name}", + ) if reloaded: return {"success": True, "count": 1} return {"success": False, "error": f"插件 {plugin_name} 热重载失败,已回滚"} @@ -222,7 +228,10 @@ class RuntimeComponentCapabilityMixin: if sv is not None: try: - reloaded = await sv.reload_plugins(reason=f"reload {plugin_name}") + reloaded = await sv.reload_plugins( + plugin_ids=[plugin_name], + reason=f"reload {plugin_name}", + ) if reloaded: return {"success": True} return {"success": False, "error": f"插件 {plugin_name} 热重载失败,已回滚"} diff --git a/src/plugin_runtime/capabilities/registry.py b/src/plugin_runtime/capabilities/registry.py index abce97dc..ead5876a 100644 --- a/src/plugin_runtime/capabilities/registry.py +++ b/src/plugin_runtime/capabilities/registry.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from src.common.logger import get_logger from src.plugin_runtime.host.supervisor import PluginSupervisor @@ -12,67 +12,78 @@ logger = get_logger("plugin_runtime.integration") def register_capability_impls(manager: "PluginRuntimeManager", supervisor: PluginSupervisor) -> None: """向指定 Supervisor 注册主程序提供的能力实现。""" cap_service = supervisor.capability_service + rpc_server = supervisor.rpc_server - cap_service.register_capability("send.text", manager._cap_send_text) - cap_service.register_capability("send.emoji", manager._cap_send_emoji) - cap_service.register_capability("send.image", manager._cap_send_image) - cap_service.register_capability("send.command", manager._cap_send_command) - cap_service.register_capability("send.custom", manager._cap_send_custom) + def _register(name: str, impl: Any) -> None: + """注册单个能力实现及其 RPC 入口。 - cap_service.register_capability("llm.generate", manager._cap_llm_generate) - cap_service.register_capability("llm.generate_with_tools", manager._cap_llm_generate_with_tools) - cap_service.register_capability("llm.get_available_models", manager._cap_llm_get_available_models) + Args: + name: 能力名称。 + impl: 能力实现函数。 + """ + cap_service.register_capability(name, impl) + rpc_server.register_method(name, cap_service.handle_capability_request) - cap_service.register_capability("config.get", manager._cap_config_get) - cap_service.register_capability("config.get_plugin", manager._cap_config_get_plugin) - cap_service.register_capability("config.get_all", manager._cap_config_get_all) + _register("send.text", manager._cap_send_text) + _register("send.emoji", manager._cap_send_emoji) + _register("send.image", manager._cap_send_image) + _register("send.command", manager._cap_send_command) + _register("send.custom", manager._cap_send_custom) - cap_service.register_capability("database.query", manager._cap_database_query) - cap_service.register_capability("database.save", manager._cap_database_save) - cap_service.register_capability("database.get", manager._cap_database_get) - cap_service.register_capability("database.delete", manager._cap_database_delete) - cap_service.register_capability("database.count", manager._cap_database_count) + _register("llm.generate", manager._cap_llm_generate) + _register("llm.generate_with_tools", manager._cap_llm_generate_with_tools) + _register("llm.get_available_models", manager._cap_llm_get_available_models) - cap_service.register_capability("chat.get_all_streams", manager._cap_chat_get_all_streams) - cap_service.register_capability("chat.get_group_streams", manager._cap_chat_get_group_streams) - cap_service.register_capability("chat.get_private_streams", manager._cap_chat_get_private_streams) - cap_service.register_capability("chat.get_stream_by_group_id", manager._cap_chat_get_stream_by_group_id) - cap_service.register_capability("chat.get_stream_by_user_id", manager._cap_chat_get_stream_by_user_id) + _register("config.get", manager._cap_config_get) + _register("config.get_plugin", manager._cap_config_get_plugin) + _register("config.get_all", manager._cap_config_get_all) - cap_service.register_capability("message.get_by_time", manager._cap_message_get_by_time) - cap_service.register_capability("message.get_by_time_in_chat", manager._cap_message_get_by_time_in_chat) - cap_service.register_capability("message.get_recent", manager._cap_message_get_recent) - cap_service.register_capability("message.count_new", manager._cap_message_count_new) - cap_service.register_capability("message.build_readable", manager._cap_message_build_readable) + _register("database.query", manager._cap_database_query) + _register("database.save", manager._cap_database_save) + _register("database.get", manager._cap_database_get) + _register("database.delete", manager._cap_database_delete) + _register("database.count", manager._cap_database_count) - cap_service.register_capability("person.get_id", manager._cap_person_get_id) - cap_service.register_capability("person.get_value", manager._cap_person_get_value) - cap_service.register_capability("person.get_id_by_name", manager._cap_person_get_id_by_name) + _register("chat.get_all_streams", manager._cap_chat_get_all_streams) + _register("chat.get_group_streams", manager._cap_chat_get_group_streams) + _register("chat.get_private_streams", manager._cap_chat_get_private_streams) + _register("chat.get_stream_by_group_id", manager._cap_chat_get_stream_by_group_id) + _register("chat.get_stream_by_user_id", manager._cap_chat_get_stream_by_user_id) - cap_service.register_capability("emoji.get_by_description", manager._cap_emoji_get_by_description) - cap_service.register_capability("emoji.get_random", manager._cap_emoji_get_random) - cap_service.register_capability("emoji.get_count", manager._cap_emoji_get_count) - cap_service.register_capability("emoji.get_emotions", manager._cap_emoji_get_emotions) - cap_service.register_capability("emoji.get_all", manager._cap_emoji_get_all) - cap_service.register_capability("emoji.get_info", manager._cap_emoji_get_info) - cap_service.register_capability("emoji.register", manager._cap_emoji_register) - cap_service.register_capability("emoji.delete", manager._cap_emoji_delete) + _register("message.get_by_time", manager._cap_message_get_by_time) + _register("message.get_by_time_in_chat", manager._cap_message_get_by_time_in_chat) + _register("message.get_recent", manager._cap_message_get_recent) + _register("message.count_new", manager._cap_message_count_new) + _register("message.build_readable", manager._cap_message_build_readable) - cap_service.register_capability("frequency.get_current_talk_value", manager._cap_frequency_get_current_talk_value) - cap_service.register_capability("frequency.set_adjust", manager._cap_frequency_set_adjust) - cap_service.register_capability("frequency.get_adjust", manager._cap_frequency_get_adjust) + _register("person.get_id", manager._cap_person_get_id) + _register("person.get_value", manager._cap_person_get_value) + _register("person.get_id_by_name", manager._cap_person_get_id_by_name) - cap_service.register_capability("tool.get_definitions", manager._cap_tool_get_definitions) + _register("emoji.get_by_description", manager._cap_emoji_get_by_description) + _register("emoji.get_random", manager._cap_emoji_get_random) + _register("emoji.get_count", manager._cap_emoji_get_count) + _register("emoji.get_emotions", manager._cap_emoji_get_emotions) + _register("emoji.get_all", manager._cap_emoji_get_all) + _register("emoji.get_info", manager._cap_emoji_get_info) + _register("emoji.register", manager._cap_emoji_register) + _register("emoji.delete", manager._cap_emoji_delete) - cap_service.register_capability("component.get_all_plugins", manager._cap_component_get_all_plugins) - cap_service.register_capability("component.get_plugin_info", manager._cap_component_get_plugin_info) - cap_service.register_capability("component.list_loaded_plugins", manager._cap_component_list_loaded_plugins) - cap_service.register_capability("component.list_registered_plugins", manager._cap_component_list_registered_plugins) - cap_service.register_capability("component.enable", manager._cap_component_enable) - cap_service.register_capability("component.disable", manager._cap_component_disable) - cap_service.register_capability("component.load_plugin", manager._cap_component_load_plugin) - cap_service.register_capability("component.unload_plugin", manager._cap_component_unload_plugin) - cap_service.register_capability("component.reload_plugin", manager._cap_component_reload_plugin) + _register("frequency.get_current_talk_value", manager._cap_frequency_get_current_talk_value) + _register("frequency.set_adjust", manager._cap_frequency_set_adjust) + _register("frequency.get_adjust", manager._cap_frequency_get_adjust) - cap_service.register_capability("knowledge.search", manager._cap_knowledge_search) + _register("tool.get_definitions", manager._cap_tool_get_definitions) + + _register("component.get_all_plugins", manager._cap_component_get_all_plugins) + _register("component.get_plugin_info", manager._cap_component_get_plugin_info) + _register("component.list_loaded_plugins", manager._cap_component_list_loaded_plugins) + _register("component.list_registered_plugins", manager._cap_component_list_registered_plugins) + _register("component.enable", manager._cap_component_enable) + _register("component.disable", manager._cap_component_disable) + _register("component.load_plugin", manager._cap_component_load_plugin) + _register("component.unload_plugin", manager._cap_component_unload_plugin) + _register("component.reload_plugin", manager._cap_component_reload_plugin) + + _register("knowledge.search", manager._cap_knowledge_search) logger.debug("已注册全部主程序能力实现") diff --git a/src/plugin_runtime/host/capability_service.py b/src/plugin_runtime/host/capability_service.py index 98366a07..761b20ca 100644 --- a/src/plugin_runtime/host/capability_service.py +++ b/src/plugin_runtime/host/capability_service.py @@ -30,6 +30,11 @@ class CapabilityService: """ def __init__(self, authorization: "AuthorizationManager") -> None: + """初始化能力服务。 + + Args: + authorization: 能力授权管理器。 + """ self._authorization = authorization # capability_name -> implementation self._implementations: Dict[str, CapabilityImpl] = {} @@ -51,13 +56,19 @@ class CapabilityService: 校验权限后调用对应实现。 """ plugin_id = envelope.plugin_id + payload = envelope.payload if isinstance(envelope.payload, dict) else {} try: - req = CapabilityRequestPayload.model_validate(envelope.payload) - except Exception as e: - return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, f"能力调用 payload 格式错误: {e}") + req = CapabilityRequestPayload.model_validate(payload) + capability = req.capability + args = req.args + except Exception: + capability = envelope.method + raw_args = payload.get("args", payload) + args = raw_args if isinstance(raw_args, dict) else {} - capability = req.capability + if not capability: + return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, "能力调用缺少 capability") # 1. 权限校验 allowed, reason = self._authorization.check_capability(plugin_id, capability) @@ -71,7 +82,7 @@ class CapabilityService: # 3. 执行 try: - result = await impl(plugin_id, capability, req.args) + result = await impl(plugin_id, capability, args) resp_payload = CapabilityResponsePayload(success=True, result=result) return envelope.make_response(payload=resp_payload.model_dump()) except RPCError as e: diff --git a/src/plugin_runtime/host/supervisor.py b/src/plugin_runtime/host/supervisor.py index 82a5970b..5ae3bdee 100644 --- a/src/plugin_runtime/host/supervisor.py +++ b/src/plugin_runtime/host/supervisor.py @@ -1,32 +1,38 @@ from pathlib import Path -from typing import Optional, List, Dict, Any, Tuple, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple import asyncio - +import contextlib +import os +import sys from src.common.logger import get_logger from src.config.config import global_config -from src.plugin_runtime.transport.factory import create_transport_server +from src.plugin_runtime import ENV_HOST_VERSION, ENV_IPC_ADDRESS, ENV_PLUGIN_DIRS, ENV_SESSION_TOKEN from src.plugin_runtime.protocol.envelope import ( BootstrapPluginPayload, ConfigUpdatedPayload, Envelope, HealthPayload, - LogBatchPayload, + PROTOCOL_VERSION, RegisterPluginPayload, + ReloadPluginResultPayload, RunnerReadyPayload, ShutdownPayload, + UnregisterPluginPayload, ) +from src.plugin_runtime.protocol.codec import MsgPackCodec +from src.plugin_runtime.protocol.errors import ErrorCode, RPCError +from src.plugin_runtime.transport.factory import create_transport_server from .authorization import AuthorizationManager from .capability_service import CapabilityService -from .rpc_server import RPCServer -from .logger_bridge import RunnerLogBridge from .component_registry import ComponentRegistry from .event_dispatcher import EventDispatcher from .hook_dispatcher import HookDispatcher +from .logger_bridge import RunnerLogBridge from .message_gateway import MessageGateway -from .message_utils import PluginMessageUtils +from .rpc_server import RPCServer if TYPE_CHECKING: from src.chat.message_receive.message import SessionMessage @@ -35,7 +41,11 @@ logger = get_logger("plugin_runtime.host.runner_manager") class PluginRunnerSupervisor: - """插件的Runner管理器,负责管理Runner的生命周期""" + """插件 Runner 监督器。 + + 负责 Host 侧与单个 Runner 子进程之间的生命周期、内部 RPC、 + 健康检查和插件级重载协调。 + """ def __init__( self, @@ -44,13 +54,24 @@ class PluginRunnerSupervisor: health_check_interval_sec: Optional[float] = None, max_restart_attempts: Optional[int] = None, runner_spawn_timeout_sec: Optional[float] = None, - ): - _cfg = global_config.plugin_runtime - self._plugin_dirs: List[Path] = plugin_dirs or [] - self._health_interval = health_check_interval_sec or _cfg.health_check_interval_sec or 30.0 - self._runner_spawn_timeout = runner_spawn_timeout_sec or _cfg.runner_spawn_timeout_sec or 30.0 + ) -> None: + """初始化 Supervisor。 + + Args: + plugin_dirs: 由当前 Runner 负责加载的插件目录列表。 + socket_path: 自定义 IPC 地址;留空时由传输层自动生成。 + health_check_interval_sec: 健康检查间隔,单位秒。 + max_restart_attempts: 自动重启 Runner 的最大次数。 + runner_spawn_timeout_sec: 等待 Runner 建连并就绪的超时时间,单位秒。 + """ + runtime_config = global_config.plugin_runtime + self._plugin_dirs: List[Path] = plugin_dirs or [] + self._health_interval: float = health_check_interval_sec or runtime_config.health_check_interval_sec or 30.0 + self._runner_spawn_timeout: float = ( + runner_spawn_timeout_sec or runtime_config.runner_spawn_timeout_sec or 30.0 + ) + self._max_restart_attempts: int = max_restart_attempts or runtime_config.max_restart_attempts or 3 - # 基础设施 self._transport = create_transport_server(socket_path=socket_path) self._authorization = AuthorizationManager() self._capability_service = CapabilityService(self._authorization) @@ -58,61 +79,55 @@ class PluginRunnerSupervisor: self._event_dispatcher = EventDispatcher(self._component_registry) self._hook_dispatcher = HookDispatcher(self._component_registry) self._message_gateway = MessageGateway(self._component_registry) - - # 编解码和服务器 - from src.plugin_runtime.protocol.codec import MsgPackCodec + self._log_bridge = RunnerLogBridge() codec = MsgPackCodec() self._rpc_server = RPCServer(transport=self._transport, codec=codec) - # Runner 子进程 self._runner_process: Optional[asyncio.subprocess.Process] = None - self._max_restart_attempts: int = max_restart_attempts or _cfg.max_restart_attempts or 3 - self._restart_count: int = 0 - - # 已注册的插件组件信息 self._registered_plugins: Dict[str, RegisterPluginPayload] = {} self._runner_ready_events: asyncio.Event = asyncio.Event() self._runner_ready_payloads: RunnerReadyPayload = RunnerReadyPayload() + self._health_task: Optional[asyncio.Task[None]] = None + self._stderr_drain_task: Optional[asyncio.Task[None]] = None + self._restart_count: int = 0 + self._running: bool = False - # 后台任务 - self._health_task: Optional[asyncio.Task] = None - # Runner stderr 流排空任务(仅保留 stderr,用于 IPC 建立前的启动日志倒空、致命错误输出等场景) - self._stderr_drain_task: Optional[asyncio.Task] = None - self._running = False - - # Runner 日志桥(将 Runner 上报的批量日志重放到主进程 Logger) - self._log_bridge: RunnerLogBridge = RunnerLogBridge() - - # 注册内部 RPC 方法 - self._register_internal_methods() # TODO: 完成内部方法注册 + self._register_internal_methods() @property def authorization_manager(self) -> AuthorizationManager: + """返回授权管理器。""" return self._authorization @property def capability_service(self) -> CapabilityService: + """返回能力服务。""" return self._capability_service @property def component_registry(self) -> ComponentRegistry: + """返回组件注册表。""" return self._component_registry @property def event_dispatcher(self) -> EventDispatcher: + """返回事件分发器。""" return self._event_dispatcher @property def hook_dispatcher(self) -> HookDispatcher: + """返回 Hook 分发器。""" return self._hook_dispatcher @property def message_gateway(self) -> MessageGateway: + """返回消息网关。""" return self._message_gateway @property def rpc_server(self) -> RPCServer: + """返回底层 RPC 服务端。""" return self._rpc_server async def dispatch_event( @@ -121,11 +136,28 @@ class PluginRunnerSupervisor: message: Optional["SessionMessage"] = None, extra_args: Optional[Dict[str, Any]] = None, ) -> Tuple[bool, Optional["SessionMessage"]]: - """分发事件到所有对应 handler 的快捷方法。""" + """分发事件到已注册的事件处理器。 + + Args: + event_type: 事件类型。 + message: 可选的消息对象。 + extra_args: 附加参数。 + + Returns: + Tuple[bool, Optional[SessionMessage]]: 是否继续处理,以及插件可能修改后的消息。 + """ return await self._event_dispatcher.dispatch_event(event_type, self, message, extra_args) - async def dispatch_hook(self, stage: str, **kwargs): - """分发Hook事件到所有对应 handler 的快捷方法。""" + async def dispatch_hook(self, stage: str, **kwargs: Any) -> Dict[str, Any]: + """分发 Hook 到已注册的 Hook 处理器。 + + Args: + stage: Hook 阶段名称。 + **kwargs: 传递给 Hook 的关键字参数。 + + Returns: + Dict[str, Any]: 经 Hook 修改后的参数字典。 + """ return await self._hook_dispatcher.hook_dispatch(stage, self, **kwargs) async def send_message_to_external( @@ -135,60 +167,68 @@ class PluginRunnerSupervisor: enabled_only: bool = True, save_to_db: bool = True, ) -> bool: - """发送系统内部消息到外部平台的快捷方法。""" + """通过插件消息网关发送外部消息。 + + Args: + internal_message: 系统内部消息对象。 + enabled_only: 是否仅使用启用的网关组件。 + save_to_db: 发送成功后是否写入数据库。 + + Returns: + bool: 是否发送成功。 + """ return await self._message_gateway.send_message_to_external( - internal_message, self, enabled_only=enabled_only, save_to_db=save_to_db + internal_message, + self, + enabled_only=enabled_only, + save_to_db=save_to_db, ) async def start(self) -> None: - """启动 Supervisor + """启动 Supervisor。""" + if self._running: + logger.warning("PluginRunnerSupervisor 已在运行,跳过重复启动") + return - 1. 启动 RPC Server - 2. 拉起 Runner 子进程 - 3. 启动健康检查 - """ self._running = True + self._restart_count = 0 + self._clear_runner_state() - # 启动 RPC Server await self._rpc_server.start() - # 拉起 Runner 进程 await self._spawn_runner() - # 等待 Runner 完成连接和初始化,避免 start() 返回时 Runner 尚未就绪 try: await self._wait_for_runner_connection(timeout_sec=self._runner_spawn_timeout) await self._wait_for_runner_ready(timeout_sec=self._runner_spawn_timeout) except TimeoutError: if not self._rpc_server.is_connected: - logger.warning(f"Runner 未在 {self._runner_spawn_timeout}s 内完成连接,后续操作可能失败") + logger.warning("Runner 未在限定时间内完成连接,后续操作可能失败") else: - logger.warning(f"Runner 未在 {self._runner_spawn_timeout}s 内完成初始化,后续操作可能失败") + logger.warning("Runner 未在限定时间内完成初始化,后续操作可能失败") - # 启动健康检查 - self._health_task = asyncio.create_task(self._health_check_loop()) - - logger.info("PluginSupervisor 已启动") + self._health_task = asyncio.create_task(self._health_check_loop(), name="PluginRunnerSupervisor.health") + logger.info("PluginRunnerSupervisor 已启动") async def stop(self) -> None: - """停止 Supervisor""" + """停止 Supervisor。""" + if not self._running: + return + self._running = False - # 停止组件 - await self._event_dispatcher.stop() - await self._hook_dispatcher.stop() - - # 停止健康检查 - if self._health_task: + if self._health_task is not None: self._health_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await self._health_task self._health_task = None - # 优雅关停 Runner - await self._shutdown_runner() - - # 停止 RPC Server + await self._event_dispatcher.stop() + await self._hook_dispatcher.stop() + await self._shutdown_runner(reason="host_stop") await self._rpc_server.stop() + self._clear_runner_state() - logger.info("PluginSupervisor 已停止") + logger.info("PluginRunnerSupervisor 已停止") async def invoke_plugin( self, @@ -198,9 +238,17 @@ class PluginRunnerSupervisor: args: Optional[Dict[str, Any]] = None, timeout_ms: int = 30000, ) -> Envelope: - """调用插件组件 + """调用 Runner 内的插件组件。 - 由主进程业务逻辑调用,通过 RPC 转发给 Runner。 + Args: + method: RPC 方法名。 + plugin_id: 目标插件 ID。 + component_name: 组件名。 + args: 调用参数。 + timeout_ms: RPC 超时时间,单位毫秒。 + + Returns: + Envelope: RPC 响应信封。 """ return await self._rpc_server.send_request( method, @@ -210,27 +258,421 @@ class PluginRunnerSupervisor: ) async def reload_plugin(self, plugin_id: str, reason: str = "manual") -> bool: - raise NotImplementedError("等待SDK完成") # TODO: 完成对应的调用和请求逻辑 + """按插件 ID 触发精确重载。 + + Args: + plugin_id: 目标插件 ID。 + reason: 重载原因。 + + Returns: + bool: 是否重载成功。 + """ + try: + response = await self._rpc_server.send_request( + "plugin.reload", + plugin_id=plugin_id, + payload={"plugin_id": plugin_id, "reason": reason}, + timeout_ms=max(int(self._runner_spawn_timeout * 1000), 10000), + ) + except Exception as exc: + logger.error(f"插件 {plugin_id} 重载请求失败: {exc}") + return False + + result = ReloadPluginResultPayload.model_validate(response.payload) + if not result.success: + logger.warning(f"插件 {plugin_id} 重载失败: {result.failed_plugins}") + return result.success + + async def reload_plugins( + self, + plugin_ids: Optional[List[str]] = None, + reason: str = "manual", + ) -> bool: + """批量重载插件。 + + Args: + plugin_ids: 目标插件 ID 列表;为空时重载当前已注册的全部插件。 + reason: 重载原因。 + + Returns: + bool: 是否全部重载成功。 + """ + target_plugin_ids = plugin_ids or list(self._registered_plugins.keys()) + ordered_plugin_ids = list(dict.fromkeys(target_plugin_ids)) + success = True + + for plugin_id in ordered_plugin_ids: + reloaded = await self.reload_plugin(plugin_id=plugin_id, reason=reason) + success = success and reloaded + + return success + + async def notify_plugin_config_updated( + self, + plugin_id: str, + config_data: Optional[Dict[str, Any]] = None, + config_version: str = "", + ) -> bool: + """向 Runner 推送插件配置更新。 + + Args: + plugin_id: 目标插件 ID。 + config_data: 配置内容。 + config_version: 配置版本号。 + + Returns: + bool: 请求是否成功送达并被 Runner 接受。 + """ + payload = ConfigUpdatedPayload( + plugin_id=plugin_id, + config_version=config_version, + config_data=config_data or {}, + ) + try: + response = await self._rpc_server.send_request( + "plugin.config_updated", + plugin_id=plugin_id, + payload=payload.model_dump(), + timeout_ms=10000, + ) + except Exception as exc: + logger.warning(f"插件 {plugin_id} 配置更新通知失败: {exc}") + return False + + return bool(response.payload.get("acknowledged", False)) async def _wait_for_runner_connection(self, timeout_sec: float) -> None: - """等待 Runner 连接上 RPC Server""" + """等待 Runner 建立 RPC 连接。 - async def wait_for_connection(): + Args: + timeout_sec: 超时时间,单位秒。 + + Raises: + TimeoutError: 在超时时间内 Runner 未完成连接。 + """ + + async def wait_for_connection() -> None: + """轮询等待 RPC 连接建立。""" while self._running and not self._rpc_server.is_connected: await asyncio.sleep(0.1) try: await asyncio.wait_for(wait_for_connection(), timeout=timeout_sec) logger.info("Runner 已连接到 RPC Server") - except asyncio.TimeoutError as e: - raise TimeoutError(f"等待 Runner 连接超时({timeout_sec}s)") from e + except asyncio.TimeoutError as exc: + raise TimeoutError(f"等待 Runner 连接超时({timeout_sec}s)") from exc async def _wait_for_runner_ready(self, timeout_sec: float = 30.0) -> RunnerReadyPayload: - """等待 Runner 完成初始化并上报就绪""" + """等待 Runner 完成启动初始化。 + Args: + timeout_sec: 超时时间,单位秒。 + + Returns: + RunnerReadyPayload: Runner 上报的就绪信息。 + + Raises: + TimeoutError: 在超时时间内 Runner 未完成初始化。 + """ try: await asyncio.wait_for(self._runner_ready_events.wait(), timeout=timeout_sec) logger.info("Runner 已完成初始化并上报就绪") return self._runner_ready_payloads - except asyncio.TimeoutError as e: - raise TimeoutError(f"等待 Runner 就绪超时({timeout_sec}s)") from e + except asyncio.TimeoutError as exc: + raise TimeoutError(f"等待 Runner 就绪超时({timeout_sec}s)") from exc + + def _register_internal_methods(self) -> None: + """注册 Host 侧内部 RPC 方法。""" + self._rpc_server.register_method("cap.call", self._capability_service.handle_capability_request) + self._rpc_server.register_method("plugin.bootstrap", self._handle_bootstrap_plugin) + self._rpc_server.register_method("plugin.register_components", self._handle_register_plugin) + self._rpc_server.register_method("plugin.register_plugin", self._handle_register_plugin) + self._rpc_server.register_method("plugin.unregister", self._handle_unregister_plugin) + self._rpc_server.register_method("runner.log_batch", self._log_bridge.handle_log_batch) + self._rpc_server.register_method("runner.ready", self._handle_runner_ready) + + async def _handle_bootstrap_plugin(self, envelope: Envelope) -> Envelope: + """处理插件 bootstrap 请求。 + + Args: + envelope: RPC 请求信封。 + + Returns: + Envelope: RPC 响应信封。 + """ + try: + payload = BootstrapPluginPayload.model_validate(envelope.payload) + except Exception as exc: + return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc)) + + if payload.capabilities_required: + self._authorization.register_plugin(payload.plugin_id, payload.capabilities_required) + else: + self._authorization.revoke_permission_token(payload.plugin_id) + + return envelope.make_response(payload={"accepted": True, "plugin_id": payload.plugin_id}) + + async def _handle_register_plugin(self, envelope: Envelope) -> Envelope: + """处理插件组件注册请求。 + + Args: + envelope: RPC 请求信封。 + + Returns: + Envelope: RPC 响应信封。 + """ + try: + payload = RegisterPluginPayload.model_validate(envelope.payload) + except Exception as exc: + return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc)) + + self._component_registry.remove_components_by_plugin(payload.plugin_id) + registered_count = self._component_registry.register_plugin_components( + payload.plugin_id, + [component.model_dump() for component in payload.components], + ) + self._registered_plugins[payload.plugin_id] = payload + + return envelope.make_response( + payload={ + "accepted": True, + "plugin_id": payload.plugin_id, + "registered_components": registered_count, + } + ) + + async def _handle_unregister_plugin(self, envelope: Envelope) -> Envelope: + """处理插件注销请求。 + + Args: + envelope: RPC 请求信封。 + + Returns: + Envelope: RPC 响应信封。 + """ + try: + payload = UnregisterPluginPayload.model_validate(envelope.payload) + except Exception as exc: + return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc)) + + removed_components = self._component_registry.remove_components_by_plugin(payload.plugin_id) + self._authorization.revoke_permission_token(payload.plugin_id) + removed_registration = self._registered_plugins.pop(payload.plugin_id, None) is not None + + return envelope.make_response( + payload={ + "accepted": True, + "plugin_id": payload.plugin_id, + "reason": payload.reason, + "removed_components": removed_components, + "removed_registration": removed_registration, + } + ) + + async def _handle_runner_ready(self, envelope: Envelope) -> Envelope: + """处理 Runner 就绪通知。 + + Args: + envelope: RPC 请求信封。 + + Returns: + Envelope: RPC 响应信封。 + """ + try: + payload = RunnerReadyPayload.model_validate(envelope.payload) + except Exception as exc: + return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc)) + + self._runner_ready_payloads = payload + self._runner_ready_events.set() + return envelope.make_response(payload={"accepted": True}) + + def _build_runner_environment(self) -> Dict[str, str]: + """构建拉起 Runner 所需的环境变量。 + + Returns: + Dict[str, str]: 传递给 Runner 进程的环境变量映射。 + """ + return { + ENV_HOST_VERSION: PROTOCOL_VERSION, + ENV_IPC_ADDRESS: self._transport.get_address(), + ENV_PLUGIN_DIRS: os.pathsep.join(str(path) for path in self._plugin_dirs), + ENV_SESSION_TOKEN: self._rpc_server.session_token, + } + + async def _spawn_runner(self) -> None: + """拉起 Runner 子进程。""" + if self._runner_process is not None and self._runner_process.returncode is None: + logger.warning("Runner 已在运行,跳过重复拉起") + return + + self._clear_runner_state() + + env = os.environ.copy() + env.update(self._build_runner_environment()) + + self._runner_process = await asyncio.create_subprocess_exec( + sys.executable, + "-m", + "src.plugin_runtime.runner.runner_main", + env=env, + stdout=asyncio.subprocess.DEVNULL, + stderr=asyncio.subprocess.PIPE, + ) + + if self._runner_process.stderr is not None: + self._stderr_drain_task = asyncio.create_task( + self._drain_runner_stderr(self._runner_process.stderr), + name="PluginRunnerSupervisor.stderr", + ) + + logger.info(f"Runner 已拉起,pid={self._runner_process.pid}") + + async def _drain_runner_stderr(self, stream: asyncio.StreamReader) -> None: + """持续排空 Runner 的 stderr。 + + Args: + stream: Runner 的 stderr 流。 + """ + try: + while True: + line = await stream.readline() + if not line: + return + message = line.decode("utf-8", errors="replace").rstrip() + if message: + logger.warning(f"[runner-stderr] {message}") + except asyncio.CancelledError: + raise + except Exception as exc: + logger.warning(f"排空 Runner stderr 失败: {exc}") + + async def _shutdown_runner(self, reason: str = "normal") -> None: + """优雅关闭 Runner 子进程。 + + Args: + reason: 关停原因。 + """ + process = self._runner_process + if process is None: + return + + payload = ShutdownPayload(reason=reason) + + if process.returncode is None and self._rpc_server.is_connected: + with contextlib.suppress(Exception): + await self._rpc_server.send_request( + "plugin.prepare_shutdown", + payload=payload.model_dump(), + timeout_ms=payload.drain_timeout_ms, + ) + with contextlib.suppress(Exception): + await self._rpc_server.send_request( + "plugin.shutdown", + payload=payload.model_dump(), + timeout_ms=payload.drain_timeout_ms, + ) + + if process.returncode is None: + try: + await asyncio.wait_for(process.wait(), timeout=max(payload.drain_timeout_ms / 1000.0, 1.0)) + except asyncio.TimeoutError: + logger.warning("Runner 优雅退出超时,尝试 terminate") + process.terminate() + try: + await asyncio.wait_for(process.wait(), timeout=5.0) + except asyncio.TimeoutError: + logger.warning("Runner terminate 超时,尝试 kill") + process.kill() + with contextlib.suppress(Exception): + await asyncio.wait_for(process.wait(), timeout=5.0) + + self._runner_process = None + + if self._stderr_drain_task is not None: + self._stderr_drain_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await self._stderr_drain_task + self._stderr_drain_task = None + + self._clear_runner_state() + + async def _health_check_loop(self) -> None: + """周期性检查 Runner 健康状态,并在必要时重启。""" + timeout_ms = max(int(self._health_interval * 1000), 1000) + + while self._running: + try: + await asyncio.sleep(self._health_interval) + except asyncio.CancelledError: + return + + if not self._running: + return + + process = self._runner_process + if process is None or process.returncode is not None: + reason = "runner_process_exited" if process is not None else "runner_process_missing" + restarted = await self._restart_runner(reason=reason) + if not restarted: + return + continue + + try: + response = await self._rpc_server.send_request("plugin.health", timeout_ms=timeout_ms) + health = HealthPayload.model_validate(response.payload) + if not health.healthy: + restarted = await self._restart_runner(reason="health_check_unhealthy") + if not restarted: + return + except asyncio.CancelledError: + return + except (RPCError, Exception) as exc: + logger.warning(f"Runner 健康检查失败: {exc}") + restarted = await self._restart_runner(reason="health_check_failed") + if not restarted: + return + + async def _restart_runner(self, reason: str) -> bool: + """在 Runner 异常时执行整进程级重启。 + + Args: + reason: 触发重启的原因。 + + Returns: + bool: 是否重启成功。 + """ + if not self._running: + return False + + if self._restart_count >= self._max_restart_attempts: + logger.error(f"Runner 自动重启次数已达上限,停止重启。reason={reason}") + return False + + self._restart_count += 1 + logger.warning(f"准备重启 Runner,第 {self._restart_count} 次,reason={reason}") + + await self._shutdown_runner(reason=reason) + + try: + await self._spawn_runner() + await self._wait_for_runner_connection(timeout_sec=self._runner_spawn_timeout) + await self._wait_for_runner_ready(timeout_sec=self._runner_spawn_timeout) + except Exception as exc: + logger.error(f"Runner 重启失败: {exc}", exc_info=True) + return False + + self._restart_count = 0 + logger.info("Runner 已成功重启") + return True + + def _clear_runner_state(self) -> None: + """清理当前 Runner 对应的 Host 侧注册状态。""" + self._authorization.clear() + self._component_registry.clear() + self._registered_plugins.clear() + self._runner_ready_events = asyncio.Event() + self._runner_ready_payloads = RunnerReadyPayload() + + +PluginSupervisor = PluginRunnerSupervisor diff --git a/src/plugin_runtime/integration.py b/src/plugin_runtime/integration.py index 04c8e324..730da3e1 100644 --- a/src/plugin_runtime/integration.py +++ b/src/plugin_runtime/integration.py @@ -23,8 +23,10 @@ from src.plugin_runtime.capabilities import ( RuntimeDataCapabilityMixin, ) from src.plugin_runtime.capabilities.registry import register_capability_impls +from src.plugin_runtime.host.message_utils import MessageDict, PluginMessageUtils if TYPE_CHECKING: + from src.chat.message_receive.message import SessionMessage from src.plugin_runtime.host.supervisor import PluginSupervisor logger = get_logger("plugin_runtime.integration") @@ -223,9 +225,9 @@ class PluginRuntimeManager( async def bridge_event( self, event_type_value: str, - message_dict: Optional[Dict[str, Any]] = None, + message_dict: Optional[MessageDict] = None, extra_args: Optional[Dict[str, Any]] = None, - ) -> Tuple[bool, Optional[Dict[str, Any]]]: + ) -> Tuple[bool, Optional[MessageDict]]: """将事件分发到所有 Supervisor Returns: @@ -235,17 +237,23 @@ class PluginRuntimeManager( return True, None new_event_type: str = _EVENT_TYPE_MAP.get(event_type_value, event_type_value) - modified: Optional[Dict[str, Any]] = None + modified: Optional[MessageDict] = None + current_message: Optional["SessionMessage"] = ( + PluginMessageUtils._build_session_message_from_dict(dict(message_dict)) + if message_dict is not None + else None + ) for sv in self.supervisors: try: cont, mod = await sv.dispatch_event( event_type=new_event_type, - message=modified or message_dict, + message=current_message, extra_args=extra_args, ) if mod is not None: - modified = mod + current_message = mod + modified = PluginMessageUtils._session_message_to_dict(mod) if not cont: return False, modified except Exception as e: @@ -477,7 +485,7 @@ class PluginRuntimeManager( logger.error(f"检测到重复插件 ID,跳过本次插件热重载: {details}") return - reload_supervisors: List[Any] = [] + reload_supervisors: Dict[Any, List[str]] = {} changed_paths = [change.path.resolve() for change in changes] for supervisor in self.supervisors: @@ -485,11 +493,13 @@ class PluginRuntimeManager( plugin_id = self._match_plugin_id_for_supervisor(supervisor, path) if plugin_id is None: continue - if (path.name in {"plugin.py", "_manifest.json"} or path.suffix == ".py") and supervisor not in reload_supervisors: - reload_supervisors.append(supervisor) + if path.name in {"plugin.py", "_manifest.json"} or path.suffix == ".py": + reload_supervisors.setdefault(supervisor, []) + if plugin_id not in reload_supervisors[supervisor]: + reload_supervisors[supervisor].append(plugin_id) - for supervisor in reload_supervisors: - await supervisor.reload_plugins(reason="file_watcher") + for supervisor, plugin_ids in reload_supervisors.items(): + await supervisor.reload_plugins(plugin_ids=plugin_ids, reason="file_watcher") if reload_supervisors: self._refresh_plugin_config_watch_subscriptions() diff --git a/src/plugin_runtime/protocol/envelope.py b/src/plugin_runtime/protocol/envelope.py index e81df019..6f95f97f 100644 --- a/src/plugin_runtime/protocol/envelope.py +++ b/src/plugin_runtime/protocol/envelope.py @@ -144,7 +144,11 @@ class ComponentDeclaration(BaseModel): class RegisterPluginPayload(BaseModel): - """plugin.register_plugin 请求 payload""" + """插件组件注册请求载荷。 + + 该模型同时用于 ``plugin.register_components`` 与兼容旧命名的 + ``plugin.register_plugin`` 请求。 + """ plugin_id: str = Field(description="插件 ID") """插件 ID""" @@ -248,6 +252,39 @@ class ShutdownPayload(BaseModel): """排空超时 (ms)""" +class UnregisterPluginPayload(BaseModel): + """插件注销请求载荷。""" + + plugin_id: str = Field(description="插件 ID") + """插件 ID""" + reason: str = Field(default="manual", description="注销原因") + """注销原因""" + + +class ReloadPluginPayload(BaseModel): + """插件重载请求载荷。""" + + plugin_id: str = Field(description="目标插件 ID") + """目标插件 ID""" + reason: str = Field(default="manual", description="重载原因") + """重载原因""" + + +class ReloadPluginResultPayload(BaseModel): + """插件重载结果载荷。""" + + success: bool = Field(description="是否重载成功") + """是否重载成功""" + requested_plugin_id: str = Field(description="请求重载的插件 ID") + """请求重载的插件 ID""" + reloaded_plugins: List[str] = Field(default_factory=list, description="成功完成重载的插件列表") + """成功完成重载的插件列表""" + unloaded_plugins: List[str] = Field(default_factory=list, description="本次已卸载的插件列表") + """本次已卸载的插件列表""" + failed_plugins: Dict[str, str] = Field(default_factory=dict, description="重载失败的插件及原因") + """重载失败的插件及原因""" + + # ====== 日志传输 ====== diff --git a/src/plugin_runtime/runner/plugin_loader.py b/src/plugin_runtime/runner/plugin_loader.py index 11ba45e7..90c8bf47 100644 --- a/src/plugin_runtime/runner/plugin_loader.py +++ b/src/plugin_runtime/runner/plugin_loader.py @@ -32,11 +32,22 @@ class PluginMeta: self, plugin_id: str, plugin_dir: str, + module_name: str, plugin_instance: Any, manifest: Dict[str, Any], ) -> None: + """初始化插件元数据。 + + Args: + plugin_id: 插件 ID。 + plugin_dir: 插件目录绝对路径。 + module_name: 插件入口模块名。 + plugin_instance: 插件实例对象。 + manifest: 解析后的 manifest 内容。 + """ self.plugin_id = plugin_id self.plugin_dir = plugin_dir + self.module_name = module_name self.instance = plugin_instance self.manifest = manifest self.version = manifest.get("version", "1.0.0") @@ -45,6 +56,14 @@ class PluginMeta: @staticmethod def _extract_dependencies(manifest: Dict[str, Any]) -> List[str]: + """从 manifest 中提取依赖列表。 + + Args: + manifest: 插件 manifest。 + + Returns: + List[str]: 规范化后的依赖插件 ID 列表。 + """ raw = manifest.get("dependencies", []) result: List[str] = [] for dep in raw: @@ -66,19 +85,24 @@ class PluginLoader: """ def __init__(self, host_version: str = "") -> None: + """初始化插件加载器。 + + Args: + host_version: Host 版本号,用于 manifest 兼容性校验。 + """ self._loaded_plugins: Dict[str, PluginMeta] = {} self._failed_plugins: Dict[str, str] = {} self._manifest_validator = ManifestValidator(host_version=host_version) self._compat_hook_installed = False def discover_and_load(self, plugin_dirs: List[str]) -> List[PluginMeta]: - """扫描多个目录并加载所有插件(含依赖排序和 manifest 校验) + """扫描多个目录并加载所有插件。 Args: - plugin_dirs: 插件目录列表 + plugin_dirs: 插件目录列表。 Returns: - 成功加载的插件元数据列表(按依赖顺序) + List[PluginMeta]: 成功加载的插件元数据列表,按依赖顺序排列。 """ candidates, duplicate_candidates = self._discover_candidates(plugin_dirs) self._record_duplicate_candidates(duplicate_candidates) @@ -90,6 +114,18 @@ class PluginLoader: # 第三阶段:按依赖顺序加载 return self._load_plugins_in_order(load_order, candidates) + def discover_candidates(self, plugin_dirs: List[str]) -> Tuple[Dict[str, PluginCandidate], Dict[str, List[Path]]]: + """扫描插件目录并返回候选插件。 + + Args: + plugin_dirs: 需要扫描的插件根目录列表。 + + Returns: + Tuple[Dict[str, PluginCandidate], Dict[str, List[Path]]]: + 候选插件映射和重复插件 ID 冲突映射。 + """ + return self._discover_candidates(plugin_dirs) + def _discover_candidates(self, plugin_dirs: List[str]) -> Tuple[Dict[str, PluginCandidate], Dict[str, List[Path]]]: """扫描插件目录并收集候选插件。""" candidates: Dict[str, PluginCandidate] = {} @@ -170,7 +206,6 @@ class PluginLoader: plugin_dir, manifest, plugin_path = candidates[plugin_id] try: if meta := self._load_single_plugin(plugin_id, plugin_dir, manifest, plugin_path): - self._loaded_plugins[meta.plugin_id] = meta results.append(meta) except Exception as e: self._failed_plugins[plugin_id] = str(e) @@ -182,22 +217,109 @@ class PluginLoader: """获取已加载的插件""" return self._loaded_plugins.get(plugin_id) + def set_loaded_plugin(self, meta: PluginMeta) -> None: + """登记一个已经完成初始化的插件。 + + Args: + meta: 待登记的插件元数据。 + """ + self._loaded_plugins[meta.plugin_id] = meta + + def remove_loaded_plugin(self, plugin_id: str) -> Optional[PluginMeta]: + """移除一个已加载插件的元数据。 + + Args: + plugin_id: 待移除的插件 ID。 + + Returns: + Optional[PluginMeta]: 被移除的插件元数据;不存在时返回 ``None``。 + """ + return self._loaded_plugins.pop(plugin_id, None) + + def purge_plugin_modules(self, plugin_id: str, plugin_dir: str) -> List[str]: + """清理指定插件目录下的模块缓存。 + + Args: + plugin_id: 插件 ID。 + plugin_dir: 插件目录绝对路径。 + + Returns: + List[str]: 已从 ``sys.modules`` 中移除的模块名列表。 + """ + removed_modules: List[str] = [] + plugin_path = Path(plugin_dir).resolve() + synthetic_module_name = f"_maibot_plugin_{plugin_id}" + + for module_name, module in list(sys.modules.items()): + if module_name == synthetic_module_name: + removed_modules.append(module_name) + sys.modules.pop(module_name, None) + continue + + module_file = getattr(module, "__file__", None) + if module_file is None: + continue + + try: + module_path = Path(module_file).resolve() + except Exception: + continue + + if module_path.is_relative_to(plugin_path): + removed_modules.append(module_name) + sys.modules.pop(module_name, None) + + importlib.invalidate_caches() + return removed_modules + def list_plugins(self) -> List[str]: """列出所有已加载的插件 ID""" return list(self._loaded_plugins.keys()) @property def failed_plugins(self) -> Dict[str, str]: + """返回当前记录的失败插件原因映射。""" return dict(self._failed_plugins) # ──── 依赖解析 ──────────────────────────────────────────── + def resolve_dependencies( + self, + candidates: Dict[str, PluginCandidate], + extra_available: Optional[Set[str]] = None, + ) -> Tuple[List[str], Dict[str, str]]: + """解析候选插件的依赖顺序。 + + Args: + candidates: 待加载的候选插件集合。 + extra_available: 视为已满足的外部依赖插件 ID 集合。 + + Returns: + Tuple[List[str], Dict[str, str]]: 可加载顺序和失败原因映射。 + """ + return self._resolve_dependencies(candidates, extra_available=extra_available) + + def load_candidate(self, plugin_id: str, candidate: PluginCandidate) -> Optional[PluginMeta]: + """加载单个候选插件模块。 + + Args: + plugin_id: 插件 ID。 + candidate: 候选插件三元组。 + + Returns: + Optional[PluginMeta]: 加载成功的插件元数据;失败时返回 ``None``。 + """ + plugin_dir, manifest, plugin_path = candidate + return self._load_single_plugin(plugin_id, plugin_dir, manifest, plugin_path) + def _resolve_dependencies( self, candidates: Dict[str, PluginCandidate], + extra_available: Optional[Set[str]] = None, ) -> Tuple[List[str], Dict[str, str]]: """拓扑排序解析加载顺序,返回 (有序列表, 失败项 {id: reason})。""" available = set(candidates.keys()) + satisfied_dependencies = set(extra_available or set()) dep_graph: Dict[str, Set[str]] = {} failed: Dict[str, str] = {} @@ -212,6 +334,8 @@ class PluginLoader: continue if dep_name in available: resolved.add(dep_name) + elif dep_name in satisfied_dependencies: + continue else: missing.append(dep_name) if missing: @@ -271,33 +395,39 @@ class PluginLoader: sys.modules[module_name] = module plugin_parent_dir = plugin_dir.parent - with self._temporary_sys_path_entry(plugin_parent_dir): - spec.loader.exec_module(module) + try: + with self._temporary_sys_path_entry(plugin_parent_dir): + spec.loader.exec_module(module) - # 优先使用新版 create_plugin 工厂函数 - create_plugin = getattr(module, "create_plugin", None) - if create_plugin is not None: - instance = create_plugin() - logger.info(f"插件 {plugin_id} v{manifest.get('version', '?')} 加载成功") - return PluginMeta( - plugin_id=plugin_id, - plugin_dir=str(plugin_dir), - plugin_instance=instance, - manifest=manifest, - ) + # 优先使用新版 create_plugin 工厂函数 + create_plugin = getattr(module, "create_plugin", None) + if create_plugin is not None: + instance = create_plugin() + logger.info(f"插件 {plugin_id} v{manifest.get('version', '?')} 加载成功") + return PluginMeta( + plugin_id=plugin_id, + plugin_dir=str(plugin_dir), + module_name=module_name, + plugin_instance=instance, + manifest=manifest, + ) - # 回退:检测旧版 @register_plugin 标记的 BasePlugin 子类 - instance = self._try_load_legacy_plugin(module, plugin_id) - if instance is not None: - logger.info( - f"插件 {plugin_id} v{manifest.get('version', '?')} 通过旧版兼容层加载成功(请尽快迁移到 maibot_sdk)" - ) - return PluginMeta( - plugin_id=plugin_id, - plugin_dir=str(plugin_dir), - plugin_instance=instance, - manifest=manifest, - ) + # 回退:检测旧版 @register_plugin 标记的 BasePlugin 子类 + instance = self._try_load_legacy_plugin(module, plugin_id) + if instance is not None: + logger.info( + f"插件 {plugin_id} v{manifest.get('version', '?')} 通过旧版兼容层加载成功(请尽快迁移到 maibot_sdk)" + ) + return PluginMeta( + plugin_id=plugin_id, + plugin_dir=str(plugin_dir), + module_name=module_name, + plugin_instance=instance, + manifest=manifest, + ) + except Exception: + sys.modules.pop(module_name, None) + raise logger.error(f"插件 {plugin_id} 缺少 create_plugin 工厂函数且未检测到旧版 BasePlugin") return None diff --git a/src/plugin_runtime/runner/rpc_client.py b/src/plugin_runtime/runner/rpc_client.py index 6a1d59d5..dc917cc8 100644 --- a/src/plugin_runtime/runner/rpc_client.py +++ b/src/plugin_runtime/runner/rpc_client.py @@ -1,14 +1,6 @@ -"""Runner 端 RPC Client +"""Runner 端 RPC 客户端。""" -负责: -1. 连接 Host RPC Server -2. 发送握手(runner.hello) -3. 发送组件注册请求 -4. 接收并分发 Host 的调用请求 -5. 发送能力调用请求到 Host -""" - -from typing import Any, Awaitable, Callable, Dict, Optional, cast +from typing import Any, Awaitable, Callable, Dict, Optional, Set, cast import asyncio import contextlib @@ -29,12 +21,15 @@ from src.plugin_runtime.transport.factory import create_transport_client logger = get_logger("plugin_runtime.runner.rpc_client") -# RPC 方法处理器类型 MethodHandler = Callable[[Envelope], Awaitable[Envelope]] def _get_sdk_version() -> str: - """从 maibot_sdk 包元数据中读取实际版本号,失败时回退到 1.0.0。""" + """读取 SDK 版本号。 + + Returns: + str: 已安装的 SDK 版本;读取失败时回退到 ``1.0.0``。 + """ try: from importlib.metadata import version @@ -47,73 +42,78 @@ SDK_VERSION = _get_sdk_version() class RPCClient: - """Runner 端 RPC 客户端 - - 管理与 Host 的 IPC 连接,支持双向 RPC 调用。 - """ + """Runner 端 RPC 客户端。""" def __init__( self, host_address: str, session_token: str, codec: Optional[Codec] = None, - ): - self._host_address = host_address - self._session_token = session_token - self._codec = codec or MsgPackCodec() + ) -> None: + """初始化 RPC 客户端。 + + Args: + host_address: Host 的 IPC 地址。 + session_token: 握手用会话令牌。 + codec: 可选的编解码器实现。 + """ + self._host_address: str = host_address + self._session_token: str = session_token + self._codec: Codec = codec or MsgPackCodec() self._id_gen = RequestIdGenerator() self._connection: Optional[Connection] = None - self._runner_id = str(uuid.uuid4()) - self._generation: int = 0 - - # 方法处理器注册表(Host 发来的调用) + self._runner_id: str = str(uuid.uuid4()) self._method_handlers: Dict[str, MethodHandler] = {} - - # 等待响应的 pending 请求: request_id -> Future - self._pending_requests: Dict[int, asyncio.Future] = {} - - # 运行状态 - self._running = False - self._recv_task: Optional[asyncio.Task] = None - self._background_tasks: set[asyncio.Task] = set() - - @property - def generation(self) -> int: - return self._generation + self._pending_requests: Dict[int, asyncio.Future[Envelope]] = {} + self._running: bool = False + self._recv_task: Optional[asyncio.Task[None]] = None + self._background_tasks: Set[asyncio.Task[Any]] = set() @property def is_connected(self) -> bool: + """返回当前连接是否可用。""" return self._connection is not None and not self._connection.is_closed def register_method(self, method: str, handler: MethodHandler) -> None: - """注册方法处理器(处理 Host 发来的请求)""" + """注册 Host -> Runner 的 RPC 处理器。 + + Args: + method: RPC 方法名。 + handler: 方法处理函数。 + """ self._method_handlers[method] = handler def _require_connection(self) -> Connection: - """返回当前可用连接;若连接不可用则抛出 RPCError。""" + """返回当前可用连接。 + + Returns: + Connection: 当前连接对象。 + + Raises: + RPCError: 当前未连接到 Host。 + """ connection = self._connection if connection is None or connection.is_closed: raise RPCError(ErrorCode.E_UNKNOWN, "未连接到 Host") return cast(Connection, connection) async def connect_and_handshake(self) -> bool: - """连接 Host 并完成握手 + """连接 Host 并完成握手。 Returns: - 是否握手成功 + bool: 是否握手成功。 """ client = create_transport_client(self._host_address) self._connection = await client.connect() connection = self._require_connection() - # 发送 runner.hello hello = HelloPayload( runner_id=self._runner_id, sdk_version=SDK_VERSION, session_token=self._session_token, ) - request_id = self._id_gen.next() + request_id = await self._id_gen.next() envelope = Envelope( request_id=request_id, message_type=MessageType.REQUEST, @@ -121,33 +121,27 @@ class RPCClient: payload=hello.model_dump(), ) - data = self._codec.encode_envelope(envelope) - await connection.send_frame(data) + await connection.send_frame(self._codec.encode_envelope(envelope)) - # 接收握手响应 resp_data = await asyncio.wait_for(connection.recv_frame(), timeout=10.0) - resp = self._codec.decode_envelope(resp_data) + response = self._codec.decode_envelope(resp_data) + resp_payload = HelloResponsePayload.model_validate(response.payload) - resp_payload = HelloResponsePayload.model_validate(resp.payload) if not resp_payload.accepted: logger.error(f"握手被拒绝: {resp_payload.reason}") - await self._connection.close() - self._connection = None + await self.disconnect() return False - self._generation = resp_payload.assigned_generation - logger.info(f"握手成功: generation={self._generation}, host_version={resp_payload.host_version}") - - # 启动消息接收循环 + logger.info(f"握手成功: host_version={resp_payload.host_version}") self._running = True - self._recv_task = asyncio.create_task(self._recv_loop()) - + self._recv_task = asyncio.create_task(self._recv_loop(), name="RPCClient.recv") return True async def disconnect(self) -> None: - """断开连接""" + """断开与 Host 的连接并清理状态。""" self._running = False - if self._recv_task: + + if self._recv_task is not None: self._recv_task.cancel() with contextlib.suppress(asyncio.CancelledError): await self._recv_task @@ -160,13 +154,12 @@ class RPCClient: await asyncio.gather(*self._background_tasks, return_exceptions=True) self._background_tasks.clear() - # 取消所有 pending 请求 for future in self._pending_requests.values(): if not future.done(): future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "连接关闭")) self._pending_requests.clear() - if self._connection: + if self._connection is not None: await self._connection.close() self._connection = None @@ -177,16 +170,27 @@ class RPCClient: payload: Optional[Dict[str, Any]] = None, timeout_ms: int = 30000, ) -> Envelope: - """向 Host 发送 RPC 请求并等待响应""" - connection = self._require_connection() + """向 Host 发送 RPC 请求并等待响应。 - request_id = self._id_gen.next() + Args: + method: RPC 方法名。 + plugin_id: 目标插件 ID。 + payload: 请求载荷。 + timeout_ms: 超时时间,单位毫秒。 + + Returns: + Envelope: Host 返回的响应信封。 + + Raises: + RPCError: 发送失败、超时或连接异常。 + """ + connection = self._require_connection() + request_id = await self._id_gen.next() envelope = Envelope( request_id=request_id, message_type=MessageType.REQUEST, method=method, plugin_id=plugin_id, - generation=self._generation, timeout_ms=timeout_ms, payload=payload or {}, ) @@ -196,21 +200,16 @@ class RPCClient: self._pending_requests[request_id] = future try: - data = self._codec.encode_envelope(envelope) - await connection.send_frame(data) - - timeout_sec = timeout_ms / 1000.0 - return await asyncio.wait_for(future, timeout=timeout_sec) + await connection.send_frame(self._codec.encode_envelope(envelope)) + return await asyncio.wait_for(future, timeout=timeout_ms / 1000.0) except asyncio.TimeoutError: self._pending_requests.pop(request_id, None) raise RPCError(ErrorCode.E_TIMEOUT, f"请求 {method} 超时 ({timeout_ms}ms)") from None - except Exception as e: + except Exception as exc: self._pending_requests.pop(request_id, None) - if isinstance(e, RPCError): + if isinstance(exc, RPCError): raise - raise RPCError(ErrorCode.E_UNKNOWN, str(e)) from e - - # ─── 内部方法 ────────────────────────────────────────────── + raise RPCError(ErrorCode.E_UNKNOWN, str(exc)) from exc async def send_event( self, @@ -218,33 +217,30 @@ class RPCClient: plugin_id: str = "", payload: Optional[Dict[str, Any]] = None, ) -> None: - """向 Host 发送单向事件(fire-and-forget,不等待响应)。 + """向 Host 发送单向广播消息。 Args: - method: RPC 方法名,如 "runner.log_batch"。 - plugin_id: 目标插件 ID(可为空,表示 Runner 级消息)。 - payload: 事件数据。 + method: RPC 方法名。 + plugin_id: 目标插件 ID。 + payload: 广播载荷。 """ if not self.is_connected: return connection = self._require_connection() - - request_id = self._id_gen.next() + request_id = await self._id_gen.next() envelope = Envelope( request_id=request_id, - message_type=MessageType.EVENT, + message_type=MessageType.BROADCAST, method=method, plugin_id=plugin_id, - generation=self._generation, payload=payload or {}, ) - data = self._codec.encode_envelope(envelope) - await connection.send_frame(data) + await connection.send_frame(self._codec.encode_envelope(envelope)) async def _recv_loop(self) -> None: - """消息接收主循环""" - while self._running and self._connection and not self._connection.is_closed: + """持续接收 Host 发来的消息并分发。""" + while self._running and self._connection is not None and not self._connection.is_closed: try: data = await self._connection.recv_frame() except (asyncio.IncompleteReadError, ConnectionError): @@ -252,39 +248,47 @@ class RPCClient: break except asyncio.CancelledError: break - except Exception as e: - logger.error(f"接收帧失败: {e}") + except Exception as exc: + logger.error(f"接收帧失败: {exc}") break try: envelope = self._codec.decode_envelope(data) - except Exception as e: - logger.error(f"解码消息失败: {e}") + except Exception as exc: + logger.error(f"解码消息失败: {exc}") continue if envelope.is_response(): self._handle_response(envelope) elif envelope.is_request(): self._track_background_task(asyncio.create_task(self._handle_request(envelope))) - elif envelope.is_event(): - self._track_background_task(asyncio.create_task(self._handle_event(envelope))) + elif envelope.is_broadcast(): + self._track_background_task(asyncio.create_task(self._handle_broadcast(envelope))) def _handle_response(self, envelope: Envelope) -> None: - """处理来自 Host 的响应""" + """处理 Host 返回的响应。 + + Args: + envelope: 响应信封。 + """ future = self._pending_requests.pop(envelope.request_id, None) - if future and not future.done(): - if envelope.error: - future.set_exception(RPCError.from_dict(envelope.error)) - else: - future.set_result(envelope) + if future is None or future.done(): + return + if envelope.error: + future.set_exception(RPCError.from_dict(envelope.error)) + else: + future.set_result(envelope) async def _handle_request(self, envelope: Envelope) -> None: - """处理来自 Host 的请求(调用插件组件)""" + """处理 Host 发来的请求。 + + Args: + envelope: 请求信封。 + """ connection = self._connection if connection is None or connection.is_closed: logger.warning(f"处理请求 {envelope.method} 时连接已关闭,跳过响应") return - connection = cast(Connection, connection) handler = self._method_handlers.get(envelope.method) if handler is None: @@ -298,23 +302,34 @@ class RPCClient: try: response = await handler(envelope) await connection.send_frame(self._codec.encode_envelope(response)) - except RPCError as e: - error_resp = envelope.make_error_response(e.code.value, e.message, e.details) + except RPCError as exc: + error_resp = envelope.make_error_response(exc.code.value, exc.message, exc.details) await connection.send_frame(self._codec.encode_envelope(error_resp)) - except Exception as e: - logger.error(f"处理请求 {envelope.method} 异常: {e}", exc_info=True) - error_resp = envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(e)) + except Exception as exc: + logger.error(f"处理请求 {envelope.method} 异常: {exc}", exc_info=True) + error_resp = envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(exc)) await connection.send_frame(self._codec.encode_envelope(error_resp)) - async def _handle_event(self, envelope: Envelope) -> None: - """处理来自 Host 的事件""" - if handler := self._method_handlers.get(envelope.method): - try: - await handler(envelope) - except Exception as e: - logger.error(f"处理事件 {envelope.method} 异常: {e}", exc_info=True) + async def _handle_broadcast(self, envelope: Envelope) -> None: + """处理 Host 发来的广播事件。 - def _track_background_task(self, task: asyncio.Task) -> None: - """保持后台任务强引用,直到其完成或被取消。""" + Args: + envelope: 广播信封。 + """ + handler = self._method_handlers.get(envelope.method) + if handler is None: + return + + try: + await handler(envelope) + except Exception as exc: + logger.error(f"处理广播 {envelope.method} 异常: {exc}", exc_info=True) + + def _track_background_task(self, task: asyncio.Task[Any]) -> None: + """持有后台任务强引用直到其结束。 + + Args: + task: 后台任务。 + """ self._background_tasks.add(task) task.add_done_callback(self._background_tasks.discard) diff --git a/src/plugin_runtime/runner/runner_main.py b/src/plugin_runtime/runner/runner_main.py index dae1cfa1..771e685f 100644 --- a/src/plugin_runtime/runner/runner_main.py +++ b/src/plugin_runtime/runner/runner_main.py @@ -9,7 +9,7 @@ 6. 转发插件的能力调用到 Host """ -from typing import Any, Callable, List, Optional, Protocol, cast +from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, cast from pathlib import Path @@ -32,8 +32,11 @@ from src.plugin_runtime.protocol.envelope import ( HealthPayload, InvokePayload, InvokeResultPayload, - RegisterComponentsPayload, + RegisterPluginPayload, + ReloadPluginPayload, + ReloadPluginResultPayload, RunnerReadyPayload, + UnregisterPluginPayload, ) from src.plugin_runtime.protocol.errors import ErrorCode from src.plugin_runtime.runner.log_handler import RunnerIPCLogHandler @@ -44,7 +47,8 @@ logger = get_logger("plugin_runtime.runner.main") class _ContextAwarePlugin(Protocol): - def _set_context(self, context: Any) -> None: ... + def _set_context(self, context: Any) -> None: + """为插件注入上下文对象。""" def _install_shutdown_signal_handlers( @@ -90,21 +94,29 @@ class PluginRunner: session_token: str, plugin_dirs: List[str], ) -> None: + """初始化 Runner。 + + Args: + host_address: Host 的 IPC 地址。 + session_token: 握手用会话令牌。 + plugin_dirs: 当前 Runner 负责扫描的插件目录列表。 + """ self._host_address: str = host_address self._session_token: str = session_token - self._plugin_dirs: list[str] = plugin_dirs + self._plugin_dirs: List[str] = plugin_dirs self._rpc_client: RPCClient = RPCClient(host_address, session_token) self._loader: PluginLoader = PluginLoader(host_version=os.getenv(ENV_HOST_VERSION, "")) self._start_time: float = time.monotonic() self._shutting_down: bool = False + self._reload_lock: asyncio.Lock = asyncio.Lock() # IPC 日志 Handler:握手成功后安装,将所有 stdlib logging 转发到 Host self._log_handler: Optional[RunnerIPCLogHandler] = None - self._suspended_console_handlers: list[stdlib_logging.Handler] = [] + self._suspended_console_handlers: List[stdlib_logging.Handler] = [] async def run(self) -> None: - """Runner 主入口""" + """运行 Runner 主循环。""" # 1. 连接 Host logger.info(f"Runner 启动,连接 Host: {self._host_address}") ok = await self._rpc_client.connect_and_handshake() @@ -123,32 +135,11 @@ class PluginRunner: logger.info(f"已加载 {len(plugins)} 个插件") # 4. 注入 PluginContext + 调用 on_load 生命周期钩子 - failed_plugins: set[str] = set() + failed_plugins: Set[str] = set(self._loader.failed_plugins.keys()) for meta in plugins: - instance = meta.instance - self._inject_context(meta.plugin_id, instance) - self._apply_plugin_config(meta) - if not await self._bootstrap_plugin(meta): - failed_plugins.add(meta.plugin_id) - continue - if hasattr(instance, "on_load"): - try: - ret = instance.on_load() - if asyncio.iscoroutine(ret): - await ret - except Exception as e: - logger.error(f"插件 {meta.plugin_id} on_load 失败,跳过注册: {e}", exc_info=True) - failed_plugins.add(meta.plugin_id) - await self._deactivate_plugin(meta) - - # 5. 向 Host 注册所有插件的组件(跳过 on_load 失败的插件) - for meta in plugins: - if meta.plugin_id in failed_plugins: - continue - ok = await self._register_plugin(meta) + ok = await self._activate_plugin(meta) if not ok: failed_plugins.add(meta.plugin_id) - await self._deactivate_plugin(meta) successful_plugins = [meta.plugin_id for meta in plugins if meta.plugin_id not in failed_plugins] await self._notify_ready(successful_plugins, sorted(failed_plugins)) @@ -232,7 +223,9 @@ class PluginRunner: bound_plugin_id = plugin_id async def _rpc_call( - method: str, plugin_id: str = "", payload: Optional[dict[str, Any]] = None + method: str, + plugin_id: str = "", + payload: Optional[Dict[str, Any]] = None, ) -> Any: """桥接 PluginContext.call_capability → RPCClient.send_request。 @@ -257,7 +250,7 @@ class PluginRunner: cast(_ContextAwarePlugin, instance)._set_context(ctx) logger.debug(f"已为插件 {plugin_id} 注入 PluginContext") - def _apply_plugin_config(self, meta: PluginMeta, config_data: Optional[dict[str, Any]] = None) -> None: + def _apply_plugin_config(self, meta: PluginMeta, config_data: Optional[Dict[str, Any]] = None) -> None: """在 Runner 侧为插件实例注入当前插件配置。""" instance = meta.instance if not hasattr(instance, "set_plugin_config"): @@ -270,7 +263,7 @@ class PluginRunner: logger.warning(f"插件 {meta.plugin_id} 配置注入失败: {exc}") @staticmethod - def _load_plugin_config(plugin_dir: str) -> dict[str, Any]: + def _load_plugin_config(plugin_dir: str) -> Dict[str, Any]: """从插件目录读取 config.toml。""" config_path = Path(plugin_dir) / "config.toml" if not config_path.exists(): @@ -286,16 +279,18 @@ class PluginRunner: return loaded if isinstance(loaded, dict) else {} def _register_handlers(self) -> None: - """注册方法处理器""" + """注册 Host -> Runner 的方法处理器。""" self._rpc_client.register_method("plugin.invoke_command", self._handle_invoke) self._rpc_client.register_method("plugin.invoke_action", self._handle_invoke) self._rpc_client.register_method("plugin.invoke_tool", self._handle_invoke) self._rpc_client.register_method("plugin.emit_event", self._handle_event_invoke) + self._rpc_client.register_method("plugin.invoke_hook", self._handle_hook_invoke) self._rpc_client.register_method("plugin.invoke_workflow_step", self._handle_workflow_step) self._rpc_client.register_method("plugin.health", self._handle_health) self._rpc_client.register_method("plugin.prepare_shutdown", self._handle_prepare_shutdown) self._rpc_client.register_method("plugin.shutdown", self._handle_shutdown) self._rpc_client.register_method("plugin.config_updated", self._handle_config_updated) + self._rpc_client.register_method("plugin.reload", self._handle_reload_plugin) async def _bootstrap_plugin(self, meta: PluginMeta, capabilities_required: Optional[List[str]] = None) -> bool: """向 Host 同步插件 bootstrap 能力令牌。""" @@ -324,7 +319,14 @@ class PluginRunner: await self._bootstrap_plugin(meta, capabilities_required=[]) async def _register_plugin(self, meta: PluginMeta) -> bool: - """向 Host 注册单个插件""" + """向 Host 注册单个插件。 + + Args: + meta: 待注册的插件元数据。 + + Returns: + bool: 是否注册成功。 + """ # 收集插件组件声明 components: List[ComponentDeclaration] = [] instance = meta.instance @@ -341,7 +343,7 @@ class PluginRunner: for comp_info in instance.get_components() ) - reg_payload = RegisterComponentsPayload( + reg_payload = RegisterPluginPayload( plugin_id=meta.plugin_id, plugin_version=meta.version, components=components, @@ -361,8 +363,281 @@ class PluginRunner: logger.error(f"插件 {meta.plugin_id} 注册失败: {e}") return False + async def _unregister_plugin(self, plugin_id: str, reason: str) -> None: + """通知 Host 注销指定插件。 + + Args: + plugin_id: 目标插件 ID。 + reason: 注销原因。 + """ + payload = UnregisterPluginPayload(plugin_id=plugin_id, reason=reason) + try: + await self._rpc_client.send_request( + "plugin.unregister", + plugin_id=plugin_id, + payload=payload.model_dump(), + timeout_ms=10000, + ) + except Exception as exc: + logger.warning(f"插件 {plugin_id} 注销通知失败: {exc}") + + async def _invoke_plugin_on_load(self, meta: PluginMeta) -> bool: + """执行插件的 ``on_load`` 生命周期。 + + Args: + meta: 待初始化的插件元数据。 + + Returns: + bool: 生命周期是否执行成功。 + """ + instance = meta.instance + if not hasattr(instance, "on_load"): + return True + + try: + result = instance.on_load() + if asyncio.iscoroutine(result): + await result + return True + except Exception as exc: + logger.error(f"插件 {meta.plugin_id} on_load 失败: {exc}", exc_info=True) + return False + + async def _invoke_plugin_on_unload(self, meta: PluginMeta) -> None: + """执行插件的 ``on_unload`` 生命周期。 + + Args: + meta: 待卸载的插件元数据。 + """ + instance = meta.instance + if not hasattr(instance, "on_unload"): + return + + try: + result = instance.on_unload() + if asyncio.iscoroutine(result): + await result + except Exception as exc: + logger.error(f"插件 {meta.plugin_id} on_unload 失败: {exc}", exc_info=True) + + async def _activate_plugin(self, meta: PluginMeta) -> bool: + """完成插件注入、授权、生命周期和组件注册。 + + Args: + meta: 待激活的插件元数据。 + + Returns: + bool: 是否激活成功。 + """ + self._inject_context(meta.plugin_id, meta.instance) + self._apply_plugin_config(meta) + + if not await self._bootstrap_plugin(meta): + self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir) + return False + + if not await self._invoke_plugin_on_load(meta): + await self._deactivate_plugin(meta) + self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir) + return False + + if not await self._register_plugin(meta): + await self._invoke_plugin_on_unload(meta) + await self._deactivate_plugin(meta) + self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir) + return False + + self._loader.set_loaded_plugin(meta) + return True + + async def _unload_plugin(self, meta: PluginMeta, reason: str) -> None: + """卸载单个插件并清理 Host/Runner 两侧状态。 + + Args: + meta: 待卸载的插件元数据。 + reason: 卸载原因。 + """ + await self._invoke_plugin_on_unload(meta) + await self._unregister_plugin(meta.plugin_id, reason) + await self._deactivate_plugin(meta) + self._loader.remove_loaded_plugin(meta.plugin_id) + self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir) + + def _collect_reverse_dependents(self, plugin_id: str) -> Set[str]: + """收集依赖指定插件的所有已加载插件。 + + Args: + plugin_id: 根插件 ID。 + + Returns: + Set[str]: 目标插件及其所有反向依赖插件集合。 + """ + impacted_plugins: Set[str] = {plugin_id} + changed = True + + while changed: + changed = False + for loaded_plugin_id in self._loader.list_plugins(): + if loaded_plugin_id in impacted_plugins: + continue + + meta = self._loader.get_plugin(loaded_plugin_id) + if meta is None: + continue + + if any(dependency in impacted_plugins for dependency in meta.dependencies): + impacted_plugins.add(loaded_plugin_id) + changed = True + + return impacted_plugins + + def _build_unload_order(self, plugin_ids: Set[str]) -> List[str]: + """构建受影响插件的卸载顺序。 + + Args: + plugin_ids: 需要卸载的插件集合。 + + Returns: + List[str]: 依赖方优先的卸载顺序。 + """ + dependency_graph: Dict[str, Set[str]] = {} + for plugin_id in plugin_ids: + meta = self._loader.get_plugin(plugin_id) + if meta is None: + dependency_graph[plugin_id] = set() + continue + dependency_graph[plugin_id] = {dependency for dependency in meta.dependencies if dependency in plugin_ids} + + indegree: Dict[str, int] = {plugin_id: len(dependencies) for plugin_id, dependencies in dependency_graph.items()} + reverse_graph: Dict[str, Set[str]] = {plugin_id: set() for plugin_id in dependency_graph} + + for plugin_id, dependencies in dependency_graph.items(): + for dependency in dependencies: + reverse_graph.setdefault(dependency, set()).add(plugin_id) + + queue: List[str] = sorted(plugin_id for plugin_id, degree in indegree.items() if degree == 0) + load_order: List[str] = [] + + while queue: + current_plugin_id = queue.pop(0) + load_order.append(current_plugin_id) + for dependent_plugin_id in sorted(reverse_graph.get(current_plugin_id, set())): + indegree[dependent_plugin_id] -= 1 + if indegree[dependent_plugin_id] == 0: + queue.append(dependent_plugin_id) + queue.sort() + + return list(reversed(load_order)) + + async def _reload_plugin_by_id(self, plugin_id: str, reason: str) -> ReloadPluginResultPayload: + """按插件 ID 在 Runner 进程内执行精确重载。 + + Args: + plugin_id: 目标插件 ID。 + reason: 重载原因。 + + Returns: + ReloadPluginResultPayload: 结构化重载结果。 + """ + candidates, duplicate_candidates = self._loader.discover_candidates(self._plugin_dirs) + failed_plugins: Dict[str, str] = {} + + if plugin_id in duplicate_candidates: + conflict_paths = ", ".join(str(path) for path in duplicate_candidates[plugin_id]) + return ReloadPluginResultPayload( + success=False, + requested_plugin_id=plugin_id, + failed_plugins={plugin_id: f"检测到重复插件 ID: {conflict_paths}"}, + ) + + loaded_plugin_ids = set(self._loader.list_plugins()) + plugin_is_loaded = plugin_id in loaded_plugin_ids + plugin_has_candidate = plugin_id in candidates + + if not plugin_is_loaded and not plugin_has_candidate: + return ReloadPluginResultPayload( + success=False, + requested_plugin_id=plugin_id, + failed_plugins={plugin_id: "插件不存在或未找到合法的 manifest/plugin.py"}, + ) + + target_plugin_ids: Set[str] = {plugin_id} + if plugin_is_loaded: + target_plugin_ids = self._collect_reverse_dependents(plugin_id) + + unload_order = self._build_unload_order(target_plugin_ids & loaded_plugin_ids) + unloaded_plugins: List[str] = [] + retained_plugin_ids = loaded_plugin_ids - set(unload_order) + + for unload_plugin_id in unload_order: + meta = self._loader.get_plugin(unload_plugin_id) + if meta is None: + continue + await self._unload_plugin(meta, reason=reason) + unloaded_plugins.append(unload_plugin_id) + + reload_candidates: Dict[str, Tuple[Path, Dict[str, Any], Path]] = {} + for target_plugin_id in target_plugin_ids: + candidate = candidates.get(target_plugin_id) + if candidate is None: + failed_plugins[target_plugin_id] = "插件目录已不存在,已保持卸载状态" + continue + reload_candidates[target_plugin_id] = candidate + + load_order, dependency_failures = self._loader.resolve_dependencies( + reload_candidates, + extra_available=retained_plugin_ids, + ) + failed_plugins.update(dependency_failures) + + available_plugins = set(retained_plugin_ids) + reloaded_plugins: List[str] = [] + + for load_plugin_id in load_order: + if load_plugin_id in failed_plugins: + continue + + candidate = reload_candidates.get(load_plugin_id) + if candidate is None: + continue + + _, manifest, _ = candidate + dependencies = PluginMeta._extract_dependencies(manifest) + missing_dependencies = [dependency for dependency in dependencies if dependency not in available_plugins] + if missing_dependencies: + failed_plugins[load_plugin_id] = f"依赖未满足: {', '.join(missing_dependencies)}" + continue + + meta = self._loader.load_candidate(load_plugin_id, candidate) + if meta is None: + failed_plugins[load_plugin_id] = "插件模块加载失败" + continue + + activated = await self._activate_plugin(meta) + if not activated: + failed_plugins[load_plugin_id] = "插件初始化失败" + continue + + available_plugins.add(load_plugin_id) + reloaded_plugins.append(load_plugin_id) + + requested_plugin_success = plugin_id in reloaded_plugins and not failed_plugins + + return ReloadPluginResultPayload( + success=requested_plugin_success, + requested_plugin_id=plugin_id, + reloaded_plugins=reloaded_plugins, + unloaded_plugins=unloaded_plugins, + failed_plugins=failed_plugins, + ) + async def _notify_ready(self, loaded_plugins: List[str], failed_plugins: List[str]) -> None: - """通知 Host 当前 generation 已完成插件初始化。""" + """通知 Host 当前 Runner 已完成插件初始化。 + + Args: + loaded_plugins: 成功初始化的插件列表。 + failed_plugins: 初始化失败的插件列表。 + """ payload = RunnerReadyPayload( loaded_plugins=loaded_plugins, failed_plugins=failed_plugins, @@ -487,6 +762,61 @@ class PluginRunner: logger.error(f"插件 {plugin_id} event_handler {component_name} 执行异常: {e}", exc_info=True) return envelope.make_response(payload={"success": False, "continue_processing": True}) + async def _handle_hook_invoke(self, envelope: Envelope) -> Envelope: + """处理 HookHandler 调用请求。 + + Args: + envelope: RPC 请求信封。 + + Returns: + Envelope: 标准化后的 Hook 调用结果。 + """ + try: + invoke = InvokePayload.model_validate(envelope.payload) + except Exception as exc: + return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc)) + + plugin_id = envelope.plugin_id + meta = self._loader.get_plugin(plugin_id) + if meta is None: + return envelope.make_error_response( + ErrorCode.E_PLUGIN_NOT_FOUND.value, + f"插件 {plugin_id} 未加载", + ) + + instance = meta.instance + component_name = invoke.component_name + handler_method = getattr(instance, f"handle_{component_name}", None) or getattr(instance, component_name, None) + if handler_method is None or not callable(handler_method): + return envelope.make_error_response( + ErrorCode.E_METHOD_NOT_ALLOWED.value, + f"插件 {plugin_id} 无组件: {component_name}", + ) + + try: + raw = ( + await handler_method(**invoke.args) + if inspect.iscoroutinefunction(handler_method) + else handler_method(**invoke.args) + ) + except Exception as exc: + logger.error(f"插件 {plugin_id} hook_handler {component_name} 执行异常: {exc}", exc_info=True) + return envelope.make_response(payload={"success": False, "continue_processing": True}) + + if raw is None: + result = {"success": True, "continue_processing": True} + elif isinstance(raw, dict): + result = { + "success": True, + "continue_processing": raw.get("continue_processing", True), + "modified_kwargs": raw.get("modified_kwargs"), + "custom_result": raw.get("custom_result"), + } + else: + result = {"success": True, "continue_processing": True, "custom_result": raw} + + return envelope.make_response(payload=result) + async def _handle_workflow_step(self, envelope: Envelope) -> Envelope: """处理 WorkflowStep 调用请求 @@ -557,15 +887,10 @@ class PluginRunner: async def _handle_shutdown(self, envelope: Envelope) -> Envelope: """处理关停 — 调用所有插件的 on_unload 后退出""" logger.info("收到 shutdown 信号,开始调用 on_unload") - for plugin_id in self._loader.list_plugins(): + for plugin_id in list(self._loader.list_plugins()): meta = self._loader.get_plugin(plugin_id) - if meta and hasattr(meta.instance, "on_unload"): - try: - ret = meta.instance.on_unload() - if asyncio.iscoroutine(ret): - await ret - except Exception as e: - logger.error(f"插件 {plugin_id} on_unload 失败: {e}", exc_info=True) + if meta is not None: + await self._unload_plugin(meta, reason="runner_shutdown") self._shutting_down = True return envelope.make_response(payload={"acknowledged": True}) @@ -587,6 +912,30 @@ class PluginRunner: return envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(e)) return envelope.make_response(payload={"acknowledged": True}) + async def _handle_reload_plugin(self, envelope: Envelope) -> Envelope: + """处理按插件 ID 的精确重载请求。 + + Args: + envelope: RPC 请求信封。 + + Returns: + Envelope: 结构化重载结果。 + """ + try: + payload = ReloadPluginPayload.model_validate(envelope.payload) + except Exception as exc: + return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc)) + + if self._reload_lock.locked(): + return envelope.make_error_response( + ErrorCode.E_RELOAD_IN_PROGRESS.value, + f"插件 {payload.plugin_id} 重载请求被拒绝:已有重载任务正在执行", + ) + + async with self._reload_lock: + result = await self._reload_plugin_by_id(payload.plugin_id, payload.reason) + return envelope.make_response(payload=result.model_dump()) + def request_capability(self) -> RPCClient: """获取 RPC 客户端(供 SDK 使用,发起能力调用)""" return self._rpc_client @@ -652,13 +1001,16 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None: _ALLOWED_SRC_PREFIXES = ("src.plugin_runtime", "src.common") - def find_module(self, fullname, path=None): + def find_module(self, fullname: str, path: Any = None) -> Any: + """决定是否拦截指定模块导入。""" return self if self._should_block(fullname) else None - def load_module(self, fullname): + def load_module(self, fullname: str) -> None: + """阻止被拦截模块继续导入。""" raise ImportError(f"Runner 子进程不允许导入主程序模块: {fullname}") def _should_block(self, fullname: str) -> bool: + """判断给定模块名是否应被阻止导入。""" # 放行非 src.* 的导入、以及 "src" 本身 if not fullname.startswith("src.") or fullname == "src": return False @@ -692,6 +1044,7 @@ async def _async_main() -> None: # 注册信号处理 def _mark_runner_shutting_down() -> None: + """标记 Runner 即将进入关停流程。""" runner._shutting_down = True _install_shutdown_signal_handlers(_mark_runner_shutting_down) From 75cd50ee0f2daff6aef647a7b17471731da55992 Mon Sep 17 00:00:00 2001 From: DrSmoothl <1787882683@qq.com> Date: Fri, 20 Mar 2026 22:35:24 +0800 Subject: [PATCH 18/45] =?UTF-8?q?refactor:=20=E6=9B=B4=E6=96=B0=E8=83=BD?= =?UTF-8?q?=E5=8A=9B=E5=AE=9E=E7=8E=B0=E6=B3=A8=E5=86=8C=E5=92=8C=E8=AF=B7?= =?UTF-8?q?=E6=B1=82=E5=A4=84=E7=90=86=EF=BC=8C=E5=A2=9E=E5=BC=BA=E7=B1=BB?= =?UTF-8?q?=E5=9E=8B=E4=B8=80=E8=87=B4=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/plugin_runtime/capabilities/registry.py | 9 ++++----- src/plugin_runtime/host/capability_service.py | 15 +++++---------- src/plugin_runtime/runner/runner_main.py | 7 +++++-- 3 files changed, 14 insertions(+), 17 deletions(-) diff --git a/src/plugin_runtime/capabilities/registry.py b/src/plugin_runtime/capabilities/registry.py index ead5876a..96b190b4 100644 --- a/src/plugin_runtime/capabilities/registry.py +++ b/src/plugin_runtime/capabilities/registry.py @@ -1,6 +1,7 @@ -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING from src.common.logger import get_logger +from src.plugin_runtime.host.capability_service import CapabilityImpl from src.plugin_runtime.host.supervisor import PluginSupervisor if TYPE_CHECKING: @@ -12,17 +13,15 @@ logger = get_logger("plugin_runtime.integration") def register_capability_impls(manager: "PluginRuntimeManager", supervisor: PluginSupervisor) -> None: """向指定 Supervisor 注册主程序提供的能力实现。""" cap_service = supervisor.capability_service - rpc_server = supervisor.rpc_server - def _register(name: str, impl: Any) -> None: - """注册单个能力实现及其 RPC 入口。 + def _register(name: str, impl: CapabilityImpl) -> None: + """注册单个能力实现。 Args: name: 能力名称。 impl: 能力实现函数。 """ cap_service.register_capability(name, impl) - rpc_server.register_method(name, cap_service.handle_capability_request) _register("send.text", manager._cap_send_text) _register("send.emoji", manager._cap_send_emoji) diff --git a/src/plugin_runtime/host/capability_service.py b/src/plugin_runtime/host/capability_service.py index 761b20ca..0ff31fe1 100644 --- a/src/plugin_runtime/host/capability_service.py +++ b/src/plugin_runtime/host/capability_service.py @@ -56,19 +56,14 @@ class CapabilityService: 校验权限后调用对应实现。 """ plugin_id = envelope.plugin_id - payload = envelope.payload if isinstance(envelope.payload, dict) else {} try: - req = CapabilityRequestPayload.model_validate(payload) - capability = req.capability - args = req.args - except Exception: - capability = envelope.method - raw_args = payload.get("args", payload) - args = raw_args if isinstance(raw_args, dict) else {} + req = CapabilityRequestPayload.model_validate(envelope.payload) + except Exception as exc: + return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, f"能力调用 payload 非法: {exc}") - if not capability: - return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, "能力调用缺少 capability") + capability = req.capability + args = req.args # 1. 权限校验 allowed, reason = self._authorization.check_capability(plugin_id, capability) diff --git a/src/plugin_runtime/runner/runner_main.py b/src/plugin_runtime/runner/runner_main.py index 771e685f..bf36a05c 100644 --- a/src/plugin_runtime/runner/runner_main.py +++ b/src/plugin_runtime/runner/runner_main.py @@ -237,9 +237,12 @@ class PluginRunner: f"插件 {bound_plugin_id} 尝试以 {plugin_id} 身份发起 RPC,已强制绑定回自身身份" ) resp = await rpc_client.send_request( - method=method, + method="cap.call", plugin_id=bound_plugin_id, - payload=payload or {}, + payload={ + "capability": method, + "args": payload or {}, + }, ) # 从响应信封中提取业务结果 if resp.error: From 85f060621d7b48f40de790a0d366c8d8be67c2c2 Mon Sep 17 00:00:00 2001 From: DrSmoothl <1787882683@qq.com> Date: Sat, 21 Mar 2026 00:18:28 +0800 Subject: [PATCH 19/45] feat: Add NapCat adapter plugin and enhance message handling - Introduced a built-in NapCat adapter plugin for MVP message forwarding. - Implemented core functionalities for connecting to NapCat/OneBot v11 WebSocket service. - Added message serialization capabilities for WebUI chat routes. - Enhanced the RegisterPluginPayload to include optional adapter declarations. - Implemented methods for handling external messages and adapter declarations in the PluginRunner. - Improved the send_service to inherit platform IO route metadata for outgoing messages. --- .../message_receive/uni_message_sender.py | 173 ++--- src/chat/replyer/group_generator.py | 5 +- src/chat/replyer/private_generator.py | 4 +- src/platform_io/drivers/plugin_driver.py | 154 +++- src/plugin_runtime/host/message_gateway.py | 102 +-- src/plugin_runtime/host/message_utils.py | 3 + src/plugin_runtime/host/supervisor.py | 265 +++++++ src/plugin_runtime/integration.py | 92 ++- src/plugin_runtime/protocol/envelope.py | 48 +- src/plugin_runtime/runner/runner_main.py | 55 +- .../built_in/napcat_adapter/_manifest.json | 30 + src/plugins/built_in/napcat_adapter/plugin.py | 690 ++++++++++++++++++ src/services/send_service.py | 66 +- src/webui/routers/chat/serializers.py | 175 +++++ 14 files changed, 1683 insertions(+), 179 deletions(-) create mode 100644 src/plugins/built_in/napcat_adapter/_manifest.json create mode 100644 src/plugins/built_in/napcat_adapter/plugin.py create mode 100644 src/webui/routers/chat/serializers.py diff --git a/src/chat/message_receive/uni_message_sender.py b/src/chat/message_receive/uni_message_sender.py index 894af238..17d5d6d5 100644 --- a/src/chat/message_receive/uni_message_sender.py +++ b/src/chat/message_receive/uni_message_sender.py @@ -1,31 +1,37 @@ -from rich.traceback import install -from typing import Optional +from typing import Any, Optional, Tuple import asyncio +import traceback +from rich.traceback import install -from src.common.message_server.api import get_global_api -from src.common.logger import get_logger -from src.common.database.database import get_db_session from src.chat.message_receive.message import SessionMessage +from src.chat.utils.utils import calculate_typing_time, truncate_message from src.common.data_models.message_component_data_model import ReplyComponent -from src.chat.utils.utils import truncate_message -from src.chat.utils.utils import calculate_typing_time +from src.common.database.database import get_db_session +from src.common.logger import get_logger +from src.common.message_server.api import get_global_api +from src.webui.routers.chat.serializers import serialize_message_sequence install(extra_lines=3) logger = get_logger("sender") # WebUI 聊天室的消息广播器(延迟导入避免循环依赖) -_webui_chat_broadcaster = None +_webui_chat_broadcaster: Optional[Tuple[Any, Optional[str]]] = None # 虚拟群 ID 前缀(与 chat_routes.py 保持一致) VIRTUAL_GROUP_ID_PREFIX = "webui_virtual_group_" # TODO: 重构完成后完成webui相关 -def get_webui_chat_broadcaster(): - """获取 WebUI 聊天室广播器""" +def get_webui_chat_broadcaster() -> Tuple[Any, Optional[str]]: + """获取 WebUI 聊天室广播器。 + + Returns: + Tuple[Any, Optional[str]]: ``(chat_manager, platform_name)`` 二元组; + 若 WebUI 相关模块不可用,则元素会退化为 ``None``。 + """ global _webui_chat_broadcaster if _webui_chat_broadcaster is None: try: @@ -38,102 +44,36 @@ def get_webui_chat_broadcaster(): def is_webui_virtual_group(group_id: str) -> bool: - """检查是否是 WebUI 虚拟群""" - return group_id and group_id.startswith(VIRTUAL_GROUP_ID_PREFIX) - - -def parse_message_segments(segment) -> list: - """解析消息段,转换为 WebUI 可用的格式 - - 参考 NapCat 适配器的消息解析逻辑 + """检查是否是 WebUI 虚拟群。 Args: - segment: Seg 消息段对象 + group_id: 待判断的群 ID。 Returns: - list: 消息段列表,每个元素为 {"type": "...", "data": ...} + bool: 若群 ID 属于 WebUI 虚拟群则返回 ``True``。 """ - - result = [] - - if segment is None: - return result - - if segment.type == "seglist": - # 处理消息段列表 - if segment.data: - for seg in segment.data: - result.extend(parse_message_segments(seg)) - elif segment.type == "text": - # 文本消息 - if segment.data: - result.append({"type": "text", "data": segment.data}) - elif segment.type == "image": - # 图片消息(base64) - if segment.data: - result.append({"type": "image", "data": f"data:image/png;base64,{segment.data}"}) - elif segment.type == "emoji": - # 表情包消息(base64) - if segment.data: - result.append({"type": "emoji", "data": f"data:image/gif;base64,{segment.data}"}) - elif segment.type == "imageurl": - # 图片链接消息 - if segment.data: - result.append({"type": "image", "data": segment.data}) - elif segment.type == "face": - # 原生表情 - result.append({"type": "face", "data": segment.data}) - elif segment.type == "voice": - # 语音消息(base64) - if segment.data: - result.append({"type": "voice", "data": f"data:audio/wav;base64,{segment.data}"}) - elif segment.type == "voiceurl": - # 语音链接 - if segment.data: - result.append({"type": "voice", "data": segment.data}) - elif segment.type == "video": - # 视频消息(base64) - if segment.data: - result.append({"type": "video", "data": f"data:video/mp4;base64,{segment.data}"}) - elif segment.type == "videourl": - # 视频链接 - if segment.data: - result.append({"type": "video", "data": segment.data}) - elif segment.type == "music": - # 音乐消息 - result.append({"type": "music", "data": segment.data}) - elif segment.type == "file": - # 文件消息 - result.append({"type": "file", "data": segment.data}) - elif segment.type == "reply": - # 回复消息 - result.append({"type": "reply", "data": segment.data}) - elif segment.type == "forward": - # 转发消息 - forward_items = [] - if segment.data: - for item in segment.data: - forward_items.append( - { - "content": parse_message_segments(item.get("message_segment", {})) - if isinstance(item, dict) - else [] - } - ) - result.append({"type": "forward", "data": forward_items}) - else: - # 未知类型,尝试作为文本处理 - if segment.data: - result.append({"type": "unknown", "original_type": segment.type, "data": str(segment.data)}) - - return result + return bool(group_id) and group_id.startswith(VIRTUAL_GROUP_ID_PREFIX) -async def _send_message(message: MessageSending, show_log=True) -> bool: - """合并后的消息发送函数,包含WS发送和日志记录""" +async def _send_message(message: SessionMessage, show_log: bool = True) -> bool: + """执行统一的消息发送流程。 + + 发送顺序为: + 1. WebUI 特殊链路 + 2. Platform IO 适配器链路 + 3. 旧版 ``maim_message`` / API Server 链路 + + Args: + message: 待发送的内部会话消息。 + show_log: 是否输出发送成功日志。 + + Returns: + bool: 是否最终发送成功。 + """ message_preview = truncate_message(message.processed_plain_text, max_length=200) platform = message.platform - group_id = message.session.group_id + group_info = message.message_info.group_info + group_id = group_info.group_id if group_info is not None else "" try: # 检查是否是 WebUI 平台的消息,或者是 WebUI 虚拟群的消息 @@ -146,7 +86,7 @@ async def _send_message(message: MessageSending, show_log=True) -> bool: from src.config.config import global_config # 解析消息段,获取富文本内容 - message_segments = parse_message_segments(message.message_segment) + message_segments = serialize_message_sequence(message.raw_message) # 判断消息类型 # 如果只有一个文本段,使用简单的 text 类型 @@ -184,8 +124,38 @@ async def _send_message(message: MessageSending, show_log=True) -> bool: logger.info(f"已将消息 '{message_preview}' 发往 WebUI 聊天室") return True + try: + from src.platform_io import DeliveryStatus + from src.plugin_runtime.integration import get_plugin_runtime_manager + + receipt = await get_plugin_runtime_manager().try_send_message_via_platform_io(message) + if receipt is not None: + if receipt.status == DeliveryStatus.SENT: + if show_log: + logger.info( + f"已通过 Platform IO 将消息 '{message_preview}' 发往平台'{platform}' " + f"(driver: {receipt.driver_id or 'unknown'})" + ) + return True + + logger.warning( + f"Platform IO 发送失败: platform={platform} driver={receipt.driver_id} " + f"status={receipt.status} error={receipt.error}" + ) + return False + except Exception as exc: + logger.warning(f"检查 Platform IO 出站链路时出现异常,将回退旧发送链: {exc}") + # Fallback 逻辑: 尝试通过 API Server 发送 - async def send_with_new_api(legacy_exception=None): + async def send_with_new_api(legacy_exception: Optional[Exception] = None) -> bool: + """通过 API Server 回退链路发送消息。 + + Args: + legacy_exception: 旧发送链已经抛出的异常;若回退也失败,则重新抛出。 + + Returns: + bool: 回退链路是否发送成功。 + """ try: from src.config.config import global_config @@ -289,7 +259,8 @@ async def _send_message(message: MessageSending, show_log=True) -> bool: class UniversalMessageSender: """管理消息的注册、即时处理、发送和存储,并跟踪思考状态。""" - def __init__(self): + def __init__(self) -> None: + """初始化统一消息发送器。""" pass async def send_message( @@ -300,7 +271,7 @@ class UniversalMessageSender: reply_message_id: Optional[str] = None, storage_message: bool = True, show_log: bool = True, - ): + ) -> bool: """ 处理、发送并存储一条消息。 diff --git a/src/chat/replyer/group_generator.py b/src/chat/replyer/group_generator.py index e10aa147..75563df7 100644 --- a/src/chat/replyer/group_generator.py +++ b/src/chat/replyer/group_generator.py @@ -1129,7 +1129,10 @@ class DefaultReplyer: user_id=bot_user_id, user_nickname=global_config.bot.nickname, ), - additional_config={}, + additional_config={ + "platform_io_target_group_id": self.chat_stream.group_id, + "platform_io_target_user_id": self.chat_stream.user_id, + }, ), message_segment=message_segment, ) diff --git a/src/chat/replyer/private_generator.py b/src/chat/replyer/private_generator.py index ccd8e1e4..f642dd69 100644 --- a/src/chat/replyer/private_generator.py +++ b/src/chat/replyer/private_generator.py @@ -970,7 +970,9 @@ class PrivateReplyer: user_nickname=global_config.bot.nickname, ), group_info=None, - additional_config={}, + additional_config={ + "platform_io_target_user_id": self.chat_stream.user_id, + }, ), message_segment=message_segment, ) diff --git a/src/platform_io/drivers/plugin_driver.py b/src/platform_io/drivers/plugin_driver.py index 9c139309..dff980f8 100644 --- a/src/platform_io/drivers/plugin_driver.py +++ b/src/platform_io/drivers/plugin_driver.py @@ -1,34 +1,51 @@ -"""提供 Platform IO 的 plugin 传输驱动骨架。""" +"""提供 Platform IO 的插件适配器驱动实现。""" -from typing import TYPE_CHECKING, Any, Dict, Optional +from typing import TYPE_CHECKING, Any, Dict, Optional, Protocol from src.platform_io.drivers.base import PlatformIODriver -from src.platform_io.types import DeliveryReceipt, DriverDescriptor, DriverKind, RouteKey +from src.platform_io.types import DeliveryReceipt, DeliveryStatus, DriverDescriptor, DriverKind, RouteKey if TYPE_CHECKING: from src.chat.message_receive.message import SessionMessage +class _AdapterSupervisorProtocol(Protocol): + """适配器驱动依赖的 Supervisor 最小协议。""" + + async def invoke_adapter( + self, + plugin_id: str, + method_name: str, + args: Optional[Dict[str, Any]] = None, + timeout_ms: int = 30000, + ) -> Any: + """调用适配器插件专用方法。""" + + class PluginPlatformDriver(PlatformIODriver): - """面向 ``MessageGateway`` 插件链路的 Platform IO 驱动骨架。""" + """面向适配器插件链路的 Platform IO 驱动。""" def __init__( self, driver_id: str, platform: str, + supervisor: _AdapterSupervisorProtocol, + send_method: str = "send_to_platform", account_id: Optional[str] = None, scope: Optional[str] = None, plugin_id: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None, ) -> None: - """初始化一个 plugin 驱动描述对象。 + """初始化一个插件适配器驱动。 Args: driver_id: Broker 内的唯一驱动 ID。 - platform: 该 plugin 适配器链路负责的平台。 + platform: 该适配器负责的平台名称。 + supervisor: 持有该适配器插件的 Supervisor。 + send_method: 出站发送时要调用的插件方法名。 account_id: 可选的账号 ID 或 self ID。 scope: 可选的额外路由作用域。 - plugin_id: 拥有该适配器实现的插件 ID,可为空。 + plugin_id: 拥有该适配器实现的插件 ID。 metadata: 可选的额外驱动元数据。 """ descriptor = DriverDescriptor( @@ -41,6 +58,8 @@ class PluginPlatformDriver(PlatformIODriver): metadata=metadata or {}, ) super().__init__(descriptor) + self._supervisor = supervisor + self._send_method = send_method async def send_message( self, @@ -48,7 +67,7 @@ class PluginPlatformDriver(PlatformIODriver): route_key: RouteKey, metadata: Optional[Dict[str, Any]] = None, ) -> DeliveryReceipt: - """通过 plugin 传输路径发送消息。 + """通过适配器插件发送消息。 Args: message: 要投递的内部会话消息。 @@ -57,8 +76,119 @@ class PluginPlatformDriver(PlatformIODriver): Returns: DeliveryReceipt: 由驱动返回的规范化回执。 - - Raises: - NotImplementedError: 当前仍处于骨架阶段,尚未真正接入 MessageGateway。 """ - raise NotImplementedError("PluginPlatformDriver 仅完成地基实现,尚未接入 MessageGateway") + from src.plugin_runtime.host.message_utils import PluginMessageUtils + + plugin_id = self.descriptor.plugin_id or "" + if not plugin_id: + return DeliveryReceipt( + internal_message_id=message.message_id, + route_key=route_key, + status=DeliveryStatus.FAILED, + driver_id=self.driver_id, + driver_kind=self.descriptor.kind, + error="插件适配器驱动缺少 plugin_id", + ) + + try: + message_dict = PluginMessageUtils._session_message_to_dict(message) + response = await self._supervisor.invoke_adapter( + plugin_id=plugin_id, + method_name=self._send_method, + args={ + "message": message_dict, + "route": { + "platform": route_key.platform, + "account_id": route_key.account_id, + "scope": route_key.scope, + }, + "metadata": metadata or {}, + }, + timeout_ms=30000, + ) + except Exception as exc: + return DeliveryReceipt( + internal_message_id=message.message_id, + route_key=route_key, + status=DeliveryStatus.FAILED, + driver_id=self.driver_id, + driver_kind=self.descriptor.kind, + error=str(exc), + ) + + return self._build_receipt(message.message_id, route_key, response) + + def _build_receipt(self, internal_message_id: str, route_key: RouteKey, response: Any) -> DeliveryReceipt: + """将适配器调用响应归一化为出站回执。 + + Args: + internal_message_id: 内部消息 ID。 + route_key: 本次投递的路由键。 + response: Supervisor 返回的 RPC 响应对象。 + + Returns: + DeliveryReceipt: 标准化后的出站回执。 + """ + if getattr(response, "error", None): + error = response.error.get("message", "适配器发送失败") + return DeliveryReceipt( + internal_message_id=internal_message_id, + route_key=route_key, + status=DeliveryStatus.FAILED, + driver_id=self.driver_id, + driver_kind=self.descriptor.kind, + error=error, + ) + + payload = getattr(response, "payload", {}) + invoke_success = bool(payload.get("success", False)) if isinstance(payload, dict) else False + if not invoke_success: + return DeliveryReceipt( + internal_message_id=internal_message_id, + route_key=route_key, + status=DeliveryStatus.FAILED, + driver_id=self.driver_id, + driver_kind=self.descriptor.kind, + error=str(payload.get("result", "适配器发送失败")) if isinstance(payload, dict) else "适配器发送失败", + ) + + result = payload.get("result") if isinstance(payload, dict) else None + if isinstance(result, dict): + if result.get("success") is False: + return DeliveryReceipt( + internal_message_id=internal_message_id, + route_key=route_key, + status=DeliveryStatus.FAILED, + driver_id=self.driver_id, + driver_kind=self.descriptor.kind, + error=str(result.get("error", "适配器发送失败")), + metadata=result.get("metadata", {}) if isinstance(result.get("metadata"), dict) else {}, + ) + external_message_id = str(result.get("external_message_id") or result.get("message_id") or "") or None + return DeliveryReceipt( + internal_message_id=internal_message_id, + route_key=route_key, + status=DeliveryStatus.SENT, + driver_id=self.driver_id, + driver_kind=self.descriptor.kind, + external_message_id=external_message_id, + metadata=result.get("metadata", {}) if isinstance(result.get("metadata"), dict) else {}, + ) + + if isinstance(result, str) and result.strip(): + return DeliveryReceipt( + internal_message_id=internal_message_id, + route_key=route_key, + status=DeliveryStatus.SENT, + driver_id=self.driver_id, + driver_kind=self.descriptor.kind, + external_message_id=result.strip(), + ) + + return DeliveryReceipt( + internal_message_id=internal_message_id, + route_key=route_key, + status=DeliveryStatus.SENT, + driver_id=self.driver_id, + driver_kind=self.descriptor.kind, + ) diff --git a/src/plugin_runtime/host/message_gateway.py b/src/plugin_runtime/host/message_gateway.py index 43777286..9e8e9be6 100644 --- a/src/plugin_runtime/host/message_gateway.py +++ b/src/plugin_runtime/host/message_gateway.py @@ -3,9 +3,11 @@ Message Gateway 模块 适配器专用,用于将其他平台的消息转换为系统内部的消息格式,并将系统消息转换为其他平台的格式。 """ -from typing import Dict, Any, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Dict from src.common.logger import get_logger +from src.platform_io import DeliveryStatus, get_platform_io_manager + from .message_utils import PluginMessageUtils if TYPE_CHECKING: @@ -17,25 +19,53 @@ logger = get_logger("plugin_runtime.host.message_gateway") class MessageGateway: - def __init__(self, component_registry: "ComponentRegistry") -> None: - self._component_registry = component_registry + """Host 侧消息网关包装器。""" - async def receive_external_message(self, external_message: Dict[str, Any]): - """ - 接收外部消息,转换为系统内部格式,并返回转换结果 + def __init__(self, component_registry: "ComponentRegistry") -> None: + """初始化消息网关。 Args: - external_message: 外部消息的字典格式数据 + component_registry: 组件注册表。 + """ + self._component_registry = component_registry + + def build_session_message(self, external_message: Dict[str, Any]) -> "SessionMessage": + """将标准消息字典转换为 ``SessionMessage``。 + + Args: + external_message: 外部消息的字典格式数据。 Returns: - 转换后的 SessionMessage 对象 + SessionMessage: 转换后的内部消息对象。 + + Raises: + ValueError: 消息字典不合法时抛出。 + """ + return PluginMessageUtils._build_session_message_from_dict(external_message) + + def build_message_dict(self, internal_message: "SessionMessage") -> Dict[str, Any]: + """将 ``SessionMessage`` 转换为标准消息字典。 + + Args: + internal_message: 内部消息对象。 + + Returns: + Dict[str, Any]: 供适配器插件消费的标准消息字典。 + """ + return dict(PluginMessageUtils._session_message_to_dict(internal_message)) + + async def receive_external_message(self, external_message: Dict[str, Any]) -> None: + """接收外部消息并送入主消息链。 + + Args: + external_message: 外部消息的字典格式数据。 """ - # 使用递归函数将外部消息字典转换为 SessionMessage try: - session_message = PluginMessageUtils._build_session_message_from_dict(external_message) + session_message = self.build_session_message(external_message) except Exception as e: logger.error(f"转换外部消息失败: {e}") return + from src.chat.message_receive.bot import chat_bot await chat_bot.receive_message(session_message) @@ -48,46 +78,32 @@ class MessageGateway: enabled_only: bool = True, save_to_db: bool = True, ) -> bool: - """ - 接收系统内部消息,转换为外部格式,并返回转换结果 + """将内部消息通过 Platform IO 发送到外部平台。 Args: - internal_message: 系统内部的 SessionMessage 对象 + internal_message: 系统内部的 ``SessionMessage`` 对象。 + supervisor: 当前持有该消息网关的 Supervisor。 + enabled_only: 兼容旧签名的保留参数,当前由 Platform IO 统一裁决。 + save_to_db: 发送成功后是否写入数据库。 Returns: - 转换是否成功 + bool: 是否发送成功。 """ - try: - # 将 SessionMessage 转换为字典格式 - message_dict = PluginMessageUtils._session_message_to_dict(internal_message) - except Exception as e: - logger.error(f"转换内部消息失败:{e}") - return False - gateway_entry = self._component_registry.get_message_gateways( - internal_message.platform, - enabled_only=enabled_only, - session_id=internal_message.session_id, - ) - if not gateway_entry: - logger.warning(f"未找到适配平台 {internal_message.platform} 的消息网关组件,无法发送消息到外部平台") - return False - args = {"platform": internal_message.platform, "message": message_dict} - try: - resp_envelope = await supervisor.invoke_plugin( - "plugin.emit_event", gateway_entry.plugin_id, gateway_entry.name, args - ) - logger.debug("信息发送成功") - except Exception as e: - logger.error(f"调用消息网关组件失败:{e}") + del enabled_only + del supervisor + + platform_io_manager = get_platform_io_manager() + if not platform_io_manager.is_started: + logger.warning("Platform IO 尚未启动,无法通过适配器链路发送消息") return False - # 更新为实际id(如果组件返回了新的id) - actual_message_id = resp_envelope.payload.get("message_id") - try: - actual_message_id = str(actual_message_id) - except Exception: - actual_message_id = None - internal_message.message_id = actual_message_id or internal_message.message_id + route_key = platform_io_manager.build_route_key_from_message(internal_message) + receipt = await platform_io_manager.send_message(internal_message, route_key) + if receipt.status != DeliveryStatus.SENT: + logger.warning(f"通过适配器链路发送消息失败: {receipt.error or receipt.status}") + return False + + internal_message.message_id = receipt.external_message_id or internal_message.message_id if save_to_db: try: from src.common.utils.utils_message import MessageUtils diff --git a/src/plugin_runtime/host/message_utils.py b/src/plugin_runtime/host/message_utils.py index 428e3c48..aaebb529 100644 --- a/src/plugin_runtime/host/message_utils.py +++ b/src/plugin_runtime/host/message_utils.py @@ -209,6 +209,9 @@ class PluginMessageUtils: session_message.is_notify = message_dict.get("is_notify", False) if not isinstance(session_message.is_notify, bool): session_message.is_notify = False + session_message.session_id = message_dict.get("session_id", "") + if not isinstance(session_message.session_id, str): + session_message.session_id = "" session_message.reply_to = message_dict.get("reply_to") if session_message.reply_to is not None and not isinstance(session_message.reply_to, str): session_message.reply_to = None diff --git a/src/plugin_runtime/host/supervisor.py b/src/plugin_runtime/host/supervisor.py index 5ae3bdee..cdf3d4ee 100644 --- a/src/plugin_runtime/host/supervisor.py +++ b/src/plugin_runtime/host/supervisor.py @@ -8,13 +8,20 @@ import sys from src.common.logger import get_logger from src.config.config import global_config +from src.platform_io import DriverKind, InboundMessageEnvelope, RouteBinding, RouteKey, get_platform_io_manager +from src.platform_io.drivers import PluginPlatformDriver +from src.platform_io.route_key_factory import RouteKeyFactory +from src.platform_io.routing import RouteBindingConflictError from src.plugin_runtime import ENV_HOST_VERSION, ENV_IPC_ADDRESS, ENV_PLUGIN_DIRS, ENV_SESSION_TOKEN from src.plugin_runtime.protocol.envelope import ( + AdapterDeclarationPayload, BootstrapPluginPayload, ConfigUpdatedPayload, Envelope, HealthPayload, PROTOCOL_VERSION, + ReceiveExternalMessagePayload, + ReceiveExternalMessageResultPayload, RegisterPluginPayload, ReloadPluginResultPayload, RunnerReadyPayload, @@ -86,6 +93,7 @@ class PluginRunnerSupervisor: self._runner_process: Optional[asyncio.subprocess.Process] = None self._registered_plugins: Dict[str, RegisterPluginPayload] = {} + self._registered_adapters: Dict[str, AdapterDeclarationPayload] = {} self._runner_ready_events: asyncio.Event = asyncio.Event() self._runner_ready_payloads: RunnerReadyPayload = RunnerReadyPayload() self._health_task: Optional[asyncio.Task[None]] = None @@ -257,6 +265,32 @@ class PluginRunnerSupervisor: timeout_ms, ) + async def invoke_adapter( + self, + plugin_id: str, + method_name: str, + args: Optional[Dict[str, Any]] = None, + timeout_ms: int = 30000, + ) -> Envelope: + """调用适配器插件的专用方法。 + + Args: + plugin_id: 目标适配器插件 ID。 + method_name: 要调用的插件方法名,例如 ``send_to_platform``。 + args: 传递给插件方法的关键字参数。 + timeout_ms: RPC 超时时间,单位毫秒。 + + Returns: + Envelope: Runner 返回的响应信封。 + """ + return await self.invoke_plugin( + method="plugin.invoke_adapter", + plugin_id=plugin_id, + component_name=method_name, + args=args, + timeout_ms=timeout_ms, + ) + async def reload_plugin(self, plugin_id: str, reason: str = "manual") -> bool: """按插件 ID 触发精确重载。 @@ -384,6 +418,7 @@ class PluginRunnerSupervisor: def _register_internal_methods(self) -> None: """注册 Host 侧内部 RPC 方法。""" self._rpc_server.register_method("cap.call", self._capability_service.handle_capability_request) + self._rpc_server.register_method("host.receive_external_message", self._handle_receive_external_message) self._rpc_server.register_method("plugin.bootstrap", self._handle_bootstrap_plugin) self._rpc_server.register_method("plugin.register_components", self._handle_register_plugin) self._rpc_server.register_method("plugin.register_plugin", self._handle_register_plugin) @@ -427,6 +462,17 @@ class PluginRunnerSupervisor: return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc)) self._component_registry.remove_components_by_plugin(payload.plugin_id) + if payload.plugin_id in self._registered_adapters: + await self._unregister_adapter_driver(payload.plugin_id) + + try: + if payload.adapter is not None: + await self._register_adapter_driver(payload.plugin_id, payload.adapter) + except RouteBindingConflictError as exc: + return envelope.make_error_response(ErrorCode.E_METHOD_NOT_ALLOWED.value, str(exc)) + except Exception as exc: + return envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(exc)) + registered_count = self._component_registry.register_plugin_components( payload.plugin_id, [component.model_dump() for component in payload.components], @@ -438,6 +484,7 @@ class PluginRunnerSupervisor: "accepted": True, "plugin_id": payload.plugin_id, "registered_components": registered_count, + "adapter_registered": payload.adapter is not None, } ) @@ -458,6 +505,7 @@ class PluginRunnerSupervisor: removed_components = self._component_registry.remove_components_by_plugin(payload.plugin_id) self._authorization.revoke_permission_token(payload.plugin_id) removed_registration = self._registered_plugins.pop(payload.plugin_id, None) is not None + await self._unregister_adapter_driver(payload.plugin_id) return envelope.make_response( payload={ @@ -469,6 +517,221 @@ class PluginRunnerSupervisor: } ) + @staticmethod + def _build_adapter_driver_id(plugin_id: str) -> str: + """构造适配器驱动 ID。 + + Args: + plugin_id: 适配器插件 ID。 + + Returns: + str: 对应 Platform IO 中的驱动 ID。 + """ + return f"adapter:{plugin_id}" + + async def _register_adapter_driver(self, plugin_id: str, adapter: AdapterDeclarationPayload) -> None: + """将适配器插件注册到 Platform IO。 + + Args: + plugin_id: 适配器插件 ID。 + adapter: 经过校验的适配器声明。 + + Raises: + ValueError: 适配器路由冲突或驱动注册失败时抛出。 + """ + await self._unregister_adapter_driver(plugin_id) + + platform_io_manager = get_platform_io_manager() + driver = PluginPlatformDriver( + driver_id=self._build_adapter_driver_id(plugin_id), + platform=adapter.platform, + account_id=adapter.account_id or None, + scope=adapter.scope or None, + plugin_id=plugin_id, + send_method=adapter.send_method, + supervisor=self, + metadata={ + "protocol": adapter.protocol, + **adapter.metadata, + }, + ) + binding = RouteBinding( + route_key=driver.descriptor.route_key, + driver_id=driver.driver_id, + driver_kind=DriverKind.PLUGIN, + metadata={ + "plugin_id": plugin_id, + "protocol": adapter.protocol, + }, + ) + + try: + if platform_io_manager.is_started: + await platform_io_manager.add_driver(driver) + else: + platform_io_manager.register_driver(driver) + platform_io_manager.bind_route(binding) + except Exception: + with contextlib.suppress(Exception): + if platform_io_manager.is_started: + await platform_io_manager.remove_driver(driver.driver_id) + else: + platform_io_manager.unregister_driver(driver.driver_id) + raise + + self._registered_adapters[plugin_id] = adapter + + async def _unregister_adapter_driver(self, plugin_id: str) -> None: + """从 Platform IO 注销一个适配器驱动。 + + Args: + plugin_id: 适配器插件 ID。 + """ + platform_io_manager = get_platform_io_manager() + driver_id = self._build_adapter_driver_id(plugin_id) + + with contextlib.suppress(Exception): + if platform_io_manager.is_started: + await platform_io_manager.remove_driver(driver_id) + else: + platform_io_manager.unregister_driver(driver_id) + + self._registered_adapters.pop(plugin_id, None) + + async def _unregister_all_adapter_drivers(self) -> None: + """注销当前 Supervisor 管理的全部适配器驱动。""" + plugin_ids = list(self._registered_adapters.keys()) + for plugin_id in plugin_ids: + await self._unregister_adapter_driver(plugin_id) + + @staticmethod + def _attach_inbound_route_metadata( + session_message: "SessionMessage", + route_key: RouteKey, + route_metadata: Dict[str, Any], + ) -> None: + """将入站路由信息写回消息的 ``additional_config``。 + + Args: + session_message: 已构造好的内部消息对象。 + route_key: Host 为该消息解析出的标准路由键。 + route_metadata: 适配器通过 RPC 补充的原始路由辅助元数据。 + """ + additional_config = session_message.message_info.additional_config + if not isinstance(additional_config, dict): + additional_config = {} + session_message.message_info.additional_config = additional_config + + for key, value in route_metadata.items(): + if value is None: + continue + normalized_value = str(value).strip() + if normalized_value: + additional_config[key] = value + + if route_key.account_id: + additional_config.setdefault("platform_io_account_id", route_key.account_id) + if route_key.scope: + additional_config.setdefault("platform_io_scope", route_key.scope) + + def _build_inbound_route_key( + self, + adapter: AdapterDeclarationPayload, + message: Dict[str, Any], + route_metadata: Dict[str, Any], + ) -> RouteKey: + """为适配器入站消息构造归一路由键。 + + Args: + adapter: 当前适配器声明。 + message: 标准消息字典。 + route_metadata: 插件补充的路由辅助元数据。 + + Returns: + RouteKey: 供 Platform IO 使用的规范化路由键。 + + Raises: + ValueError: 消息平台字段与适配器平台声明不一致时抛出。 + """ + message_platform = str(message.get("platform") or adapter.platform).strip() + if message_platform != adapter.platform: + raise ValueError( + f"外部消息平台 {message_platform} 与适配器 {adapter.platform} 不一致" + ) + + try: + route_key = RouteKeyFactory.from_message_dict(message) + except Exception: + route_key = RouteKey(platform=message_platform) + + route_account_id, route_scope = RouteKeyFactory.extract_components(route_metadata) + account_id = route_key.account_id or route_account_id or adapter.account_id or None + scope = route_key.scope or route_scope or adapter.scope or None + return RouteKey( + platform=message_platform, + account_id=account_id, + scope=scope, + ) + + async def _handle_receive_external_message(self, envelope: Envelope) -> Envelope: + """处理适配器插件上报的外部入站消息。 + + Args: + envelope: RPC 请求信封。 + + Returns: + Envelope: 注入结果响应。 + """ + try: + payload = ReceiveExternalMessagePayload.model_validate(envelope.payload) + except Exception as exc: + return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc)) + + adapter = self._registered_adapters.get(envelope.plugin_id) + if adapter is None: + return envelope.make_error_response( + ErrorCode.E_METHOD_NOT_ALLOWED.value, + f"插件 {envelope.plugin_id} 未声明为适配器,不能注入外部消息", + ) + + try: + route_key = self._build_inbound_route_key( + adapter=adapter, + message=payload.message, + route_metadata=payload.route_metadata, + ) + session_message = self._message_gateway.build_session_message(payload.message) + self._attach_inbound_route_metadata(session_message, route_key, payload.route_metadata) + except Exception as exc: + return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc)) + + platform_io_manager = get_platform_io_manager() + accepted = await platform_io_manager.accept_inbound( + InboundMessageEnvelope( + route_key=route_key, + driver_id=self._build_adapter_driver_id(envelope.plugin_id), + driver_kind=DriverKind.PLUGIN, + external_message_id=payload.external_message_id or str(payload.message.get("message_id") or "") or None, + dedupe_key=payload.dedupe_key or None, + session_message=session_message, + payload=payload.message, + metadata={ + "plugin_id": envelope.plugin_id, + "protocol": adapter.protocol, + **payload.route_metadata, + }, + ) + ) + response = ReceiveExternalMessageResultPayload( + accepted=accepted, + route_key={ + "platform": route_key.platform, + "account_id": route_key.account_id, + "scope": route_key.scope, + }, + ) + return envelope.make_response(payload=response.model_dump()) + async def _handle_runner_ready(self, envelope: Envelope) -> Envelope: """处理 Runner 就绪通知。 @@ -595,6 +858,7 @@ class PluginRunnerSupervisor: await self._stderr_drain_task self._stderr_drain_task = None + await self._unregister_all_adapter_drivers() self._clear_runner_state() async def _health_check_loop(self) -> None: @@ -671,6 +935,7 @@ class PluginRunnerSupervisor: self._authorization.clear() self._component_registry.clear() self._registered_plugins.clear() + self._registered_adapters.clear() self._runner_ready_events = asyncio.Event() self._runner_ready_payloads = RunnerReadyPayload() diff --git a/src/plugin_runtime/integration.py b/src/plugin_runtime/integration.py index 730da3e1..30a3c150 100644 --- a/src/plugin_runtime/integration.py +++ b/src/plugin_runtime/integration.py @@ -12,11 +12,13 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable, Coroutine, Dict, Ite import asyncio import json + import tomlkit from src.common.logger import get_logger from src.config.config import global_config from src.config.file_watcher import FileChange, FileWatcher +from src.platform_io import DeliveryReceipt, InboundMessageEnvelope, get_platform_io_manager from src.plugin_runtime.capabilities import ( RuntimeComponentCapabilityMixin, RuntimeCoreCapabilityMixin, @@ -57,6 +59,7 @@ class PluginRuntimeManager( """ def __init__(self) -> None: + """初始化插件运行时管理器。""" from src.plugin_runtime.host.supervisor import PluginSupervisor self._builtin_supervisor: Optional[PluginSupervisor] = None @@ -66,6 +69,22 @@ class PluginRuntimeManager( self._plugin_source_watcher_subscription_id: Optional[str] = None self._plugin_config_watcher_subscriptions: Dict[str, Tuple[Path, str]] = {} + async def _dispatch_platform_inbound(self, envelope: InboundMessageEnvelope) -> None: + """接收 Platform IO 审核后的入站消息并送入主消息链。 + + Args: + envelope: Platform IO 产出的入站封装。 + """ + session_message = envelope.session_message + if session_message is None and envelope.payload is not None: + session_message = PluginMessageUtils._build_session_message_from_dict(dict(envelope.payload)) + if session_message is None: + raise ValueError("Platform IO 入站封装缺少可用的 SessionMessage 或 payload") + + from src.chat.message_receive.bot import chat_bot + + await chat_bot.receive_message(session_message) + # ─── 插件目录 ───────────────────────────────────────────── @staticmethod @@ -110,6 +129,8 @@ class PluginRuntimeManager( logger.info("未找到任何插件目录,跳过插件运行时启动") return + platform_io_manager = get_platform_io_manager() + # 从配置读取自定义 IPC socket 路径(留空则自动生成) socket_path_base = _cfg.ipc_socket_path or None @@ -134,6 +155,9 @@ class PluginRuntimeManager( started_supervisors: List[PluginSupervisor] = [] try: + platform_io_manager.set_inbound_dispatcher(self._dispatch_platform_inbound) + await platform_io_manager.start() + if self._builtin_supervisor: await self._builtin_supervisor.start() started_supervisors.append(self._builtin_supervisor) @@ -147,6 +171,11 @@ class PluginRuntimeManager( logger.error(f"插件运行时启动失败: {e}", exc_info=True) await self._stop_plugin_file_watcher() await asyncio.gather(*(sv.stop() for sv in started_supervisors), return_exceptions=True) + platform_io_manager.clear_inbound_dispatcher() + try: + await platform_io_manager.stop() + except Exception as platform_io_exc: + logger.warning(f"Platform IO 停止失败: {platform_io_exc}") self._started = False self._builtin_supervisor = None self._third_party_supervisor = None @@ -156,6 +185,7 @@ class PluginRuntimeManager( if not self._started: return + platform_io_manager = get_platform_io_manager() await self._stop_plugin_file_watcher() coroutines: List[Coroutine[Any, Any, None]] = [] @@ -164,11 +194,23 @@ class PluginRuntimeManager( if self._third_party_supervisor: coroutines.append(self._third_party_supervisor.stop()) + stop_errors: List[str] = [] try: - await asyncio.gather(*coroutines, return_exceptions=True) - logger.info("插件运行时已停止") - except Exception as e: - logger.error(f"插件运行时停止失败: {e}", exc_info=True) + results = await asyncio.gather(*coroutines, return_exceptions=True) + for result in results: + if isinstance(result, Exception): + stop_errors.append(str(result)) + + platform_io_manager.clear_inbound_dispatcher() + try: + await platform_io_manager.stop() + except Exception as exc: + stop_errors.append(f"Platform IO: {exc}") + + if stop_errors: + logger.error(f"插件运行时停止过程中存在错误: {'; '.join(stop_errors)}") + else: + logger.info("插件运行时已停止") finally: self._started = False self._builtin_supervisor = None @@ -176,6 +218,7 @@ class PluginRuntimeManager( @property def is_running(self) -> bool: + """返回插件运行时是否处于启动状态。""" return self._started @property @@ -303,6 +346,37 @@ class PluginRuntimeManager( timeout_ms=timeout_ms, ) + async def try_send_message_via_platform_io( + self, + message: "SessionMessage", + ) -> Optional[DeliveryReceipt]: + """尝试通过 Platform IO 中间层发送消息。 + + Args: + message: 待发送的内部会话消息。 + + Returns: + Optional[DeliveryReceipt]: 若当前消息存在 active 路由,则返回实际发送 + 结果;若没有可用路由或 Platform IO 尚未启动,则返回 ``None``。 + """ + if not self._started: + return None + + platform_io_manager = get_platform_io_manager() + if not platform_io_manager.is_started: + return None + + try: + route_key = platform_io_manager.build_route_key_from_message(message) + except Exception as exc: + logger.warning(f"根据消息构造 Platform IO 路由键失败: {exc}") + return None + + if platform_io_manager.resolve_driver(route_key) is None: + return None + + return await platform_io_manager.send_message(message, route_key) + def _get_supervisors_for_plugin(self, plugin_id: str) -> List["PluginSupervisor"]: """返回当前持有指定插件的所有 Supervisor。 @@ -426,6 +500,11 @@ class PluginRuntimeManager( """为指定插件生成配置文件变更回调。""" async def _callback(changes: Sequence[FileChange]) -> None: + """将 watcher 事件转发到指定插件的配置处理逻辑。 + + Args: + changes: 当前批次收集到的文件变更列表。 + """ await self._handle_plugin_config_changes(plugin_id, changes) return _callback @@ -542,6 +621,11 @@ class PluginRuntimeManager( # ─── 能力实现注册 ────────────────────────────────────────── def _register_capability_impls(self, supervisor: "PluginSupervisor") -> None: + """向指定 Supervisor 注册主程序能力实现。 + + Args: + supervisor: 需要注册能力实现的目标 Supervisor。 + """ register_capability_impls(self, supervisor) diff --git a/src/plugin_runtime/protocol/envelope.py b/src/plugin_runtime/protocol/envelope.py index 6f95f97f..0dfc6656 100644 --- a/src/plugin_runtime/protocol/envelope.py +++ b/src/plugin_runtime/protocol/envelope.py @@ -7,11 +7,11 @@ from enum import Enum from typing import Any, Dict, List, Optional -from pydantic import BaseModel, Field - import logging as stdlib_logging import time +from pydantic import BaseModel, Field + # ====== 协议常量 ====== PROTOCOL_VERSION = "1.0.0" @@ -156,6 +156,8 @@ class RegisterPluginPayload(BaseModel): """插件版本""" components: List[ComponentDeclaration] = Field(default_factory=list, description="组件列表") """组件列表""" + adapter: Optional["AdapterDeclarationPayload"] = Field(default=None, description="可选的适配器声明") + """可选的适配器声明""" capabilities_required: List[str] = Field(default_factory=list, description="所需能力列表") """所需能力列表""" @@ -285,6 +287,48 @@ class ReloadPluginResultPayload(BaseModel): """重载失败的插件及原因""" +class AdapterDeclarationPayload(BaseModel): + """适配器插件声明载荷。""" + + platform: str = Field(description="适配器负责的平台名称,例如 qq") + """适配器负责的平台名称,例如 qq""" + protocol: str = Field(default="", description="接入协议或实现名称,例如 napcat") + """接入协议或实现名称,例如 napcat""" + account_id: str = Field(default="", description="可选的账号 ID 或 self_id") + """可选的账号 ID 或 self_id""" + scope: str = Field(default="", description="可选的路由作用域") + """可选的路由作用域""" + send_method: str = Field(default="send_to_platform", description="Host 出站调用的插件方法名") + """Host 出站调用的插件方法名""" + metadata: Dict[str, Any] = Field(default_factory=dict, description="适配器附加元数据") + """适配器附加元数据""" + + +class ReceiveExternalMessagePayload(BaseModel): + """适配器插件向 Host 注入外部消息的请求载荷。""" + + message: Dict[str, Any] = Field(description="符合 MessageDict 结构的标准消息字典") + """符合 MessageDict 结构的标准消息字典""" + route_metadata: Dict[str, Any] = Field(default_factory=dict, description="可选的路由辅助元数据") + """可选的路由辅助元数据""" + external_message_id: str = Field(default="", description="可选的外部平台消息 ID") + """可选的外部平台消息 ID""" + dedupe_key: str = Field(default="", description="可选的显式去重键") + """可选的显式去重键""" + + +class ReceiveExternalMessageResultPayload(BaseModel): + """外部消息注入结果载荷。""" + + accepted: bool = Field(description="Host 是否接受了本次消息注入") + """Host 是否接受了本次消息注入""" + route_key: Dict[str, Any] = Field(default_factory=dict, description="本次消息使用的归一路由键") + """本次消息使用的归一路由键""" + + +RegisterPluginPayload.model_rebuild() + + # ====== 日志传输 ====== diff --git a/src/plugin_runtime/runner/runner_main.py b/src/plugin_runtime/runner/runner_main.py index bf36a05c..3ffb6b4b 100644 --- a/src/plugin_runtime/runner/runner_main.py +++ b/src/plugin_runtime/runner/runner_main.py @@ -9,9 +9,8 @@ 6. 转发插件的能力调用到 Host """ -from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, cast - from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, cast import asyncio import contextlib @@ -26,6 +25,7 @@ import tomllib from src.common.logger import get_console_handler, get_logger, initialize_logging from src.plugin_runtime import ENV_HOST_VERSION, ENV_IPC_ADDRESS, ENV_PLUGIN_DIRS, ENV_SESSION_TOKEN from src.plugin_runtime.protocol.envelope import ( + AdapterDeclarationPayload, BootstrapPluginPayload, ComponentDeclaration, Envelope, @@ -227,7 +227,7 @@ class PluginRunner: plugin_id: str = "", payload: Optional[Dict[str, Any]] = None, ) -> Any: - """桥接 PluginContext.call_capability → RPCClient.send_request。 + """桥接 PluginContext 的原始 RPC 调用到 Host。 无论调用方传入何种 plugin_id,实际发往 Host 的 plugin_id 始终绑定为当前插件实例,避免伪造其他插件身份申请能力。 @@ -237,17 +237,13 @@ class PluginRunner: f"插件 {bound_plugin_id} 尝试以 {plugin_id} 身份发起 RPC,已强制绑定回自身身份" ) resp = await rpc_client.send_request( - method="cap.call", + method=method, plugin_id=bound_plugin_id, - payload={ - "capability": method, - "args": payload or {}, - }, + payload=payload or {}, ) - # 从响应信封中提取业务结果 if resp.error: raise RuntimeError(resp.error.get("message", "能力调用失败")) - return resp.payload.get("result") + return resp.payload ctx = PluginContext(plugin_id=plugin_id, rpc_call=_rpc_call) cast(_ContextAwarePlugin, instance)._set_context(ctx) @@ -286,6 +282,7 @@ class PluginRunner: self._rpc_client.register_method("plugin.invoke_command", self._handle_invoke) self._rpc_client.register_method("plugin.invoke_action", self._handle_invoke) self._rpc_client.register_method("plugin.invoke_tool", self._handle_invoke) + self._rpc_client.register_method("plugin.invoke_adapter", self._handle_invoke) self._rpc_client.register_method("plugin.emit_event", self._handle_event_invoke) self._rpc_client.register_method("plugin.invoke_hook", self._handle_hook_invoke) self._rpc_client.register_method("plugin.invoke_workflow_step", self._handle_workflow_step) @@ -306,12 +303,14 @@ class PluginRunner: ) try: - await self._rpc_client.send_request( + response = await self._rpc_client.send_request( "plugin.bootstrap", plugin_id=meta.plugin_id, payload=payload.model_dump(), timeout_ms=10000, ) + if response.error: + raise RuntimeError(response.error.get("message", "插件 bootstrap 失败")) return True except Exception as e: logger.error(f"插件 {meta.plugin_id} bootstrap 失败: {e}") @@ -321,6 +320,29 @@ class PluginRunner: """撤销 bootstrap 期间为插件签发的能力令牌。""" await self._bootstrap_plugin(meta, capabilities_required=[]) + def _collect_adapter_declaration(self, meta: PluginMeta) -> Optional[AdapterDeclarationPayload]: + """从插件实例中提取适配器声明。 + + Args: + meta: 待提取声明的插件元数据。 + + Returns: + Optional[AdapterDeclarationPayload]: 若插件声明了适配器角色,则返回 + 经过校验的适配器声明;否则返回 ``None``。 + + Raises: + ValueError: 插件导出的适配器声明结构非法时抛出。 + """ + instance = meta.instance + if not hasattr(instance, "get_adapter_info"): + return None + + adapter_info = instance.get_adapter_info() + if adapter_info is None: + return None + + return AdapterDeclarationPayload.model_validate(adapter_info) + async def _register_plugin(self, meta: PluginMeta) -> bool: """向 Host 注册单个插件。 @@ -346,20 +368,29 @@ class PluginRunner: for comp_info in instance.get_components() ) + try: + adapter = self._collect_adapter_declaration(meta) + except Exception as exc: + logger.error(f"插件 {meta.plugin_id} 适配器声明非法: {exc}", exc_info=True) + return False + reg_payload = RegisterPluginPayload( plugin_id=meta.plugin_id, plugin_version=meta.version, components=components, + adapter=adapter, capabilities_required=meta.capabilities_required, ) try: - _resp = await self._rpc_client.send_request( + response = await self._rpc_client.send_request( "plugin.register_components", plugin_id=meta.plugin_id, payload=reg_payload.model_dump(), timeout_ms=10000, ) + if response.error: + raise RuntimeError(response.error.get("message", "插件注册失败")) logger.info(f"插件 {meta.plugin_id} 注册完成") return True except Exception as e: diff --git a/src/plugins/built_in/napcat_adapter/_manifest.json b/src/plugins/built_in/napcat_adapter/_manifest.json new file mode 100644 index 00000000..6f7e68fd --- /dev/null +++ b/src/plugins/built_in/napcat_adapter/_manifest.json @@ -0,0 +1,30 @@ +{ + "manifest_version": 1, + "name": "napcat_adapter_builtin", + "version": "0.1.0", + "description": "Built-in NapCat adapter plugin for MVP message forwarding.", + "author": { + "name": "OpenAI Codex" + }, + "license": "GPL-v3.0-or-later", + "host_application": { + "min_version": "1.0.0" + }, + "keywords": [ + "adapter", + "built-in", + "napcat", + "onebot", + "qq" + ], + "categories": [ + "Adapter", + "Built-in" + ], + "default_locale": "en-US", + "plugin_info": { + "is_built_in": true, + "plugin_type": "adapter" + }, + "capabilities": [] +} diff --git a/src/plugins/built_in/napcat_adapter/plugin.py b/src/plugins/built_in/napcat_adapter/plugin.py new file mode 100644 index 00000000..3eff518d --- /dev/null +++ b/src/plugins/built_in/napcat_adapter/plugin.py @@ -0,0 +1,690 @@ +"""内置 NapCat 适配器插件。 + +当前实现是一个 MVP 版本,目标仅限于跑通基础消息收发链路: +1. 作为客户端连接 NapCat / OneBot v11 WebSocket 服务。 +2. 将入站消息事件转换为 Host 侧的 ``MessageDict``。 +3. 将 Host 出站消息转换为 OneBot 动作并发送。 + +当前范围刻意收敛为: +- 单连接 +- 文本、@、reply 基础转发 +- 暂不处理 ``notice`` / ``meta_event`` +- 暂不支持图片、语音、文件等复杂媒体 +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast +from uuid import uuid4 + +import asyncio +import contextlib +import json +import time + +from maibot_sdk import Adapter, MaiBotPlugin + +if TYPE_CHECKING: + from aiohttp import ClientWebSocketResponse as AiohttpClientWebSocketResponse + +try: + from aiohttp import ClientSession, ClientTimeout, ClientWebSocketResponse, WSMsgType + + AIOHTTP_AVAILABLE = True +except ImportError: + ClientSession = cast(Any, None) + ClientTimeout = cast(Any, None) + ClientWebSocketResponse = cast(Any, None) + WSMsgType = cast(Any, None) + AIOHTTP_AVAILABLE = False + +if not TYPE_CHECKING: + AiohttpClientWebSocketResponse = Any + + +@Adapter(platform="qq", protocol="napcat", send_method="send_to_platform") +class NapCatAdapterPlugin(MaiBotPlugin): + """NapCat 适配器 MVP 实现。""" + + def __init__(self) -> None: + """初始化 NapCat 适配器插件实例。""" + super().__init__() + self._plugin_config: Dict[str, Any] = {} + self._connection_task: Optional[asyncio.Task[None]] = None + self._pending_actions: Dict[str, asyncio.Future[Dict[str, Any]]] = {} + self._background_tasks: set[asyncio.Task[Any]] = set() + self._send_lock = asyncio.Lock() + self._ws: Optional[AiohttpClientWebSocketResponse] = None + + def set_plugin_config(self, config: Dict[str, Any]) -> None: + """设置插件配置内容。 + + Args: + config: Runner 注入的 ``config.toml`` 解析结果。 + """ + self._plugin_config = config if isinstance(config, dict) else {} + + async def on_load(self) -> None: + """在插件加载时根据配置决定是否启动连接。""" + await self._restart_connection_if_needed() + + async def on_unload(self) -> None: + """在插件卸载时关闭连接并清理后台任务。""" + await self._stop_connection() + await self._cancel_background_tasks() + + async def on_config_update(self, new_config: Dict[str, Any], version: str) -> None: + """在配置更新后重载连接状态。 + + Args: + new_config: 最新的插件配置。 + version: 配置版本号。 + """ + del version + self.set_plugin_config(new_config) + await self._restart_connection_if_needed() + + async def send_to_platform( + self, + message: Dict[str, Any], + route: Optional[Dict[str, Any]] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> Dict[str, Any]: + """将 Host 出站消息发送到 NapCat。 + + Args: + message: Host 侧标准 ``MessageDict``。 + route: Platform IO 生成的路由信息。 + metadata: Platform IO 附带的投递元数据。 + **kwargs: 预留的扩展参数。 + + Returns: + Dict[str, Any]: 标准化后的发送结果。 + """ + del metadata + del kwargs + + ws = self._ws + if ws is None or ws.closed: + return {"success": False, "error": "NapCat is not connected"} + + try: + action_name, params = self._build_outbound_action(message, route or {}) + response = await self._call_action(action_name, params) + except Exception as exc: + return {"success": False, "error": str(exc)} + + if str(response.get("status", "")).lower() != "ok": + return { + "success": False, + "error": str(response.get("wording") or response.get("message") or "NapCat send failed"), + "metadata": {"retcode": response.get("retcode")}, + } + + response_data = response.get("data", {}) + external_message_id = "" + if isinstance(response_data, dict): + external_message_id = str(response_data.get("message_id") or "") + + return { + "success": True, + "external_message_id": external_message_id or None, + "metadata": {"action": action_name}, + } + + async def _restart_connection_if_needed(self) -> None: + """根据当前配置重启连接循环。""" + await self._stop_connection() + if not self._should_connect(): + self.ctx.logger.info("NapCat 适配器保持空闲状态,因为插件或配置未启用") + return + if not AIOHTTP_AVAILABLE: + self.ctx.logger.error("NapCat 适配器依赖 aiohttp,但当前环境未安装该依赖") + return + self._connection_task = asyncio.create_task(self._connection_loop(), name="napcat_adapter.connection") + + async def _stop_connection(self) -> None: + """停止当前连接并让所有等待中的动作失败返回。""" + connection_task = self._connection_task + self._connection_task = None + + ws = self._ws + if ws is not None and not ws.closed: + with contextlib.suppress(Exception): + await ws.close() + self._ws = None + + if connection_task is not None: + connection_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await connection_task + + self._fail_pending_actions("NapCat connection closed") + + async def _cancel_background_tasks(self) -> None: + """取消所有仍在运行的入站后台任务。""" + background_tasks = list(self._background_tasks) + for task in background_tasks: + task.cancel() + if background_tasks: + with contextlib.suppress(Exception): + await asyncio.gather(*background_tasks, return_exceptions=True) + self._background_tasks.clear() + + async def _connection_loop(self) -> None: + """维护单个 WebSocket 连接,并在断开后按配置重连。""" + assert ClientSession is not None + assert ClientTimeout is not None + + while self._should_connect(): + ws_url = self._get_string(self._connection_config(), "ws_url") + if not ws_url: + self.ctx.logger.warning("NapCat 适配器已启用,但 connection.ws_url 为空") + return + + headers = self._build_headers() + timeout = ClientTimeout(total=None, connect=10) + heartbeat = self._get_positive_float(self._connection_config(), "heartbeat_sec", 30.0) + + try: + async with ClientSession(headers=headers, timeout=timeout) as session: + async with session.ws_connect(ws_url, heartbeat=heartbeat or None) as ws: + self._ws = ws + self.ctx.logger.info(f"NapCat 适配器已连接: {ws_url}") + await self._receive_loop(ws) + except asyncio.CancelledError: + raise + except Exception as exc: + self.ctx.logger.warning(f"NapCat 适配器连接失败: {exc}") + finally: + self._ws = None + self._fail_pending_actions("NapCat connection interrupted") + + if not self._should_connect(): + break + + await asyncio.sleep(self._get_positive_float(self._connection_config(), "reconnect_delay_sec", 5.0)) + + async def _receive_loop(self, ws: AiohttpClientWebSocketResponse) -> None: + """持续消费 WebSocket 消息并分发处理。 + + Args: + ws: 当前活跃的 WebSocket 连接对象。 + """ + assert WSMsgType is not None + + async for ws_message in ws: + if ws_message.type != WSMsgType.TEXT: + if ws_message.type in {WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.ERROR}: + break + continue + + payload = self._parse_json_message(ws_message.data) + if payload is None: + continue + + if echo_id := str(payload.get("echo") or "").strip(): + self._resolve_pending_action(echo_id, payload) + continue + + if str(payload.get("post_type") or "").strip() != "message": + continue + + task = asyncio.create_task(self._handle_inbound_message(payload), name="napcat_adapter.inbound") + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.discard) + + async def _handle_inbound_message(self, payload: Dict[str, Any]) -> None: + """处理单条 NapCat 入站消息并注入 Host。 + + Args: + payload: NapCat / OneBot 推送的原始事件数据。 + """ + self_id = str(payload.get("self_id") or "").strip() + sender = payload.get("sender", {}) + if not isinstance(sender, dict): + sender = {} + + sender_user_id = str(payload.get("user_id") or sender.get("user_id") or "").strip() + if not sender_user_id: + return + + if self_id and sender_user_id == self_id and self._get_bool(self._filters_config(), "ignore_self_message", True): + return + + message_dict = self._build_inbound_message_dict(payload, self_id, sender_user_id, sender) + route_metadata: Dict[str, Any] = {} + if self_id: + route_metadata["self_id"] = self_id + if connection_id := self._get_string(self._connection_config(), "connection_id"): + route_metadata["connection_id"] = connection_id + + external_message_id = str(payload.get("message_id") or "").strip() + accepted = await self.ctx.adapter.receive_external_message( + message_dict, + route_metadata=route_metadata, + external_message_id=external_message_id, + dedupe_key=external_message_id, + ) + if not accepted: + self.ctx.logger.debug(f"Host 丢弃了 NapCat 入站消息: {external_message_id or '无消息 ID'}") + + def _build_inbound_message_dict( + self, + payload: Dict[str, Any], + self_id: str, + sender_user_id: str, + sender: Dict[str, Any], + ) -> Dict[str, Any]: + """构造 Host 侧可接受的 ``MessageDict``。 + + Args: + payload: NapCat 原始消息事件。 + self_id: 当前机器人账号 ID。 + sender_user_id: 发送者用户 ID。 + sender: 发送者信息字典。 + + Returns: + Dict[str, Any]: 规范化后的 ``MessageDict``。 + """ + message_type = str(payload.get("message_type") or "").strip() or "private" + group_id = str(payload.get("group_id") or "").strip() + group_name = str(payload.get("group_name") or "").strip() or (f"group_{group_id}" if group_id else "") + user_nickname = str(sender.get("nickname") or sender.get("card") or sender_user_id).strip() or sender_user_id + user_cardname = str(sender.get("card") or "").strip() or None + + raw_message, is_at = self._convert_inbound_segments(payload.get("message"), self_id) + raw_message_text = str(payload.get("raw_message") or "").strip() + if not raw_message: + raw_message = [{"type": "text", "data": raw_message_text or "[unsupported]"}] + + plain_text = self._build_plain_text(raw_message, raw_message_text) + timestamp_seconds = payload.get("time") + if not isinstance(timestamp_seconds, (int, float)): + timestamp_seconds = time.time() + + additional_config: Dict[str, Any] = {"self_id": self_id, "napcat_message_type": message_type} + if group_id: + additional_config["platform_io_target_group_id"] = group_id + else: + additional_config["platform_io_target_user_id"] = sender_user_id + + message_info: Dict[str, Any] = { + "user_info": { + "user_id": sender_user_id, + "user_nickname": user_nickname, + "user_cardname": user_cardname, + }, + "additional_config": additional_config, + } + if group_id: + message_info["group_info"] = {"group_id": group_id, "group_name": group_name} + + message_id = str(payload.get("message_id") or f"napcat-{uuid4().hex}").strip() + return { + "message_id": message_id, + "timestamp": str(float(timestamp_seconds)), + "platform": "qq", + "message_info": message_info, + "raw_message": raw_message, + "is_mentioned": is_at, + "is_at": is_at, + "is_emoji": False, + "is_picture": False, + "is_command": plain_text.startswith("/"), + "is_notify": False, + "session_id": "", + "processed_plain_text": plain_text, + "display_message": plain_text, + } + + def _convert_inbound_segments(self, message_payload: Any, self_id: str) -> tuple[List[Dict[str, Any]], bool]: + """将 OneBot 消息段转换为 Host 消息段结构。 + + Args: + message_payload: OneBot 原始 ``message`` 字段。 + self_id: 当前机器人账号 ID。 + + Returns: + tuple[List[Dict[str, Any]], bool]: 转换后的消息段列表,以及是否 @ 到当前机器人。 + """ + if isinstance(message_payload, str): + normalized_text = message_payload.strip() + return ([{"type": "text", "data": normalized_text}] if normalized_text else []), False + + if not isinstance(message_payload, list): + return [], False + + converted_segments: List[Dict[str, Any]] = [] + is_at = False + placeholder_texts = { + "face": "[face]", + "file": "[file]", + "image": "[image]", + "json": "[json]", + "record": "[voice]", + "video": "[video]", + "xml": "[xml]", + } + + for segment in message_payload: + if not isinstance(segment, dict): + continue + + segment_type = str(segment.get("type") or "").strip() + segment_data = segment.get("data", {}) + if not isinstance(segment_data, dict): + segment_data = {} + + if segment_type == "text": + if text_value := str(segment_data.get("text") or ""): + converted_segments.append({"type": "text", "data": text_value}) + continue + + if segment_type == "at": + if target_user_id := str(segment_data.get("qq") or "").strip(): + converted_segments.append( + { + "type": "at", + "data": { + "target_user_id": target_user_id, + "target_user_nickname": None, + "target_user_cardname": None, + }, + } + ) + if self_id and target_user_id == self_id: + is_at = True + continue + + if segment_type == "reply": + if target_message_id := str(segment_data.get("id") or "").strip(): + converted_segments.append({"type": "reply", "data": target_message_id}) + continue + + if placeholder := placeholder_texts.get(segment_type): + converted_segments.append({"type": "text", "data": placeholder}) + + return converted_segments, is_at + + def _build_outbound_action( + self, + message: Dict[str, Any], + route: Dict[str, Any], + ) -> tuple[str, Dict[str, Any]]: + """为 Host 出站消息构造 OneBot 动作。 + + Args: + message: Host 侧标准 ``MessageDict``。 + route: Platform IO 路由信息。 + + Returns: + tuple[str, Dict[str, Any]]: 动作名称与参数字典。 + """ + message_info = message.get("message_info", {}) + if not isinstance(message_info, dict): + message_info = {} + + group_info = message_info.get("group_info", {}) + if not isinstance(group_info, dict): + group_info = {} + + additional_config = message_info.get("additional_config", {}) + if not isinstance(additional_config, dict): + additional_config = {} + + raw_message = message.get("raw_message", []) + segments = self._convert_outbound_segments(raw_message) + + if target_group_id := str( + group_info.get("group_id") or additional_config.get("platform_io_target_group_id") or "" + ).strip(): + return "send_group_msg", {"group_id": target_group_id, "message": segments} + + if not ( + target_user_id := str( + additional_config.get("platform_io_target_user_id") + or additional_config.get("target_user_id") + or route.get("target_user_id") + or "" + ).strip() + ): + raise ValueError("Outbound private message is missing target_user_id") + + return "send_private_msg", {"message": segments, "user_id": target_user_id} + + def _convert_outbound_segments(self, raw_message: Any) -> List[Dict[str, Any]]: + """将 Host 消息段转换为 OneBot 消息段。 + + Args: + raw_message: Host 侧 ``raw_message`` 字段。 + + Returns: + List[Dict[str, Any]]: OneBot 消息段列表。 + """ + if not isinstance(raw_message, list): + return [{"type": "text", "data": {"text": ""}}] + + outbound_segments: List[Dict[str, Any]] = [] + for item in raw_message: + if not isinstance(item, dict): + continue + + item_type = str(item.get("type") or "").strip() + item_data = item.get("data") + + if item_type == "text": + text_value = str(item_data or "") + outbound_segments.append({"type": "text", "data": {"text": text_value}}) + continue + + if item_type == "at" and isinstance(item_data, dict): + if target_user_id := str(item_data.get("target_user_id") or "").strip(): + outbound_segments.append({"type": "at", "data": {"qq": target_user_id}}) + continue + + if item_type == "reply": + if target_message_id := str(item_data or "").strip(): + outbound_segments.append({"type": "reply", "data": {"id": target_message_id}}) + continue + + fallback_text = f"[unsupported:{item_type or 'unknown'}]" + outbound_segments.append({"type": "text", "data": {"text": fallback_text}}) + + if not outbound_segments: + outbound_segments.append({"type": "text", "data": {"text": ""}}) + return outbound_segments + + async def _call_action(self, action_name: str, params: Dict[str, Any]) -> Dict[str, Any]: + """发送 OneBot 动作并等待对应的 echo 响应。 + + Args: + action_name: OneBot 动作名称。 + params: 动作参数。 + + Returns: + Dict[str, Any]: NapCat 返回的原始响应字典。 + """ + ws = self._ws + if ws is None or ws.closed: + raise RuntimeError("NapCat is not connected") + + echo_id = uuid4().hex + loop = asyncio.get_running_loop() + response_future: asyncio.Future[Dict[str, Any]] = loop.create_future() + self._pending_actions[echo_id] = response_future + + request_payload = {"action": action_name, "params": params, "echo": echo_id} + try: + async with self._send_lock: + await ws.send_str(json.dumps(request_payload, ensure_ascii=False)) + timeout_seconds = self._get_positive_float(self._connection_config(), "action_timeout_sec", 15.0) + return await asyncio.wait_for(response_future, timeout=timeout_seconds) + finally: + self._pending_actions.pop(echo_id, None) + + def _resolve_pending_action(self, echo_id: str, payload: Dict[str, Any]) -> None: + """解析等待中的动作响应。 + + Args: + echo_id: 动作请求对应的 echo 标识。 + payload: NapCat 返回的响应载荷。 + """ + response_future = self._pending_actions.get(echo_id) + if response_future is None or response_future.done(): + return + response_future.set_result(payload) + + def _fail_pending_actions(self, error_message: str) -> None: + """让所有等待中的动作以异常方式结束。 + + Args: + error_message: 写入异常中的错误信息。 + """ + for response_future in self._pending_actions.values(): + if not response_future.done(): + response_future.set_exception(RuntimeError(error_message)) + self._pending_actions.clear() + + def _build_headers(self) -> Dict[str, str]: + """构造连接 NapCat 所需的请求头。 + + Returns: + Dict[str, str]: WebSocket 握手请求头。 + """ + access_token = self._get_string(self._connection_config(), "access_token") + return {"Authorization": f"Bearer {access_token}"} if access_token else {} + + def _parse_json_message(self, data: Any) -> Optional[Dict[str, Any]]: + """解析 WebSocket 文本消息中的 JSON 数据。 + + Args: + data: WebSocket 收到的原始文本数据。 + + Returns: + Optional[Dict[str, Any]]: 成功时返回字典,失败时返回 ``None``。 + """ + try: + payload = json.loads(str(data)) + except Exception as exc: + self.ctx.logger.warning(f"NapCat 适配器解析 JSON 载荷失败: {exc}") + return None + + return payload if isinstance(payload, dict) else None + + def _build_plain_text(self, raw_message: List[Dict[str, Any]], fallback_text: str) -> str: + """从标准消息段中提取可展示的纯文本。 + + Args: + raw_message: 标准化后的消息段列表。 + fallback_text: 当无法拼出文本时使用的回退文本。 + + Returns: + str: 用于 Host 展示和命令判断的纯文本内容。 + """ + plain_text_parts: List[str] = [] + for item in raw_message: + if not isinstance(item, dict): + continue + item_type = str(item.get("type") or "").strip() + item_data = item.get("data") + if item_type == "text": + plain_text_parts.append(str(item_data or "")) + elif item_type == "at" and isinstance(item_data, dict): + plain_text_parts.append(f"@{item_data.get('target_user_id') or ''}") + elif item_type == "reply": + plain_text_parts.append("[reply]") + + plain_text = "".join(part for part in plain_text_parts if part).strip() + return plain_text or fallback_text or "[unsupported]" + + def _plugin_section(self) -> Dict[str, Any]: + """读取插件配置中的 ``plugin`` 段。 + + Returns: + Dict[str, Any]: ``plugin`` 配置字典。 + """ + plugin_section = self._plugin_config.get("plugin", {}) + return plugin_section if isinstance(plugin_section, dict) else {} + + def _connection_config(self) -> Dict[str, Any]: + """读取插件配置中的 ``connection`` 段。 + + Returns: + Dict[str, Any]: ``connection`` 配置字典。 + """ + connection_config = self._plugin_config.get("connection", {}) + return connection_config if isinstance(connection_config, dict) else {} + + def _filters_config(self) -> Dict[str, Any]: + """读取插件配置中的 ``filters`` 段。 + + Returns: + Dict[str, Any]: ``filters`` 配置字典。 + """ + filters_config = self._plugin_config.get("filters", {}) + return filters_config if isinstance(filters_config, dict) else {} + + def _should_connect(self) -> bool: + """判断当前配置下是否应当启动连接。 + + Returns: + bool: 若启用了插件连接则返回 ``True``。 + """ + return self._get_bool(self._plugin_section(), "enabled", False) + + @staticmethod + def _get_bool(mapping: Dict[str, Any], key: str, default: bool) -> bool: + """安全读取布尔配置值。 + + Args: + mapping: 待读取的配置字典。 + key: 目标键名。 + default: 读取失败时的默认值。 + + Returns: + bool: 解析后的布尔值。 + """ + value = mapping.get(key, default) + return value if isinstance(value, bool) else default + + @staticmethod + def _get_positive_float(mapping: Dict[str, Any], key: str, default: float) -> float: + """安全读取正浮点数配置值。 + + Args: + mapping: 待读取的配置字典。 + key: 目标键名。 + default: 读取失败时的默认值。 + + Returns: + float: 合法的正浮点数;否则返回默认值。 + """ + value = mapping.get(key, default) + if isinstance(value, (int, float)) and float(value) > 0: + return float(value) + return default + + @staticmethod + def _get_string(mapping: Dict[str, Any], key: str) -> str: + """安全读取字符串配置值。 + + Args: + mapping: 待读取的配置字典。 + key: 目标键名。 + + Returns: + str: 去除首尾空白后的字符串值。 + """ + value = mapping.get(key) + return "" if value is None else str(value).strip() + + +def create_plugin() -> NapCatAdapterPlugin: + """创建插件实例。 + + Returns: + NapCatAdapterPlugin: NapCat 内置适配器插件实例。 + """ + return NapCatAdapterPlugin() diff --git a/src/services/send_service.py b/src/services/send_service.py index 7af55716..6ca7d005 100644 --- a/src/services/send_service.py +++ b/src/services/send_service.py @@ -4,7 +4,7 @@ 提供发送各种类型消息的核心功能。 """ -from typing import Dict, List, Optional, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Dict, List, Optional import time import traceback @@ -19,6 +19,7 @@ from src.common.data_models.mai_message_data_model import MaiMessage from src.common.data_models.message_component_data_model import DictComponent, MessageSequence from src.common.logger import get_logger from src.config.config import global_config +from src.platform_io.route_key_factory import RouteKeyFactory if TYPE_CHECKING: from src.chat.message_receive.message import SessionMessage @@ -31,6 +32,50 @@ logger = get_logger("send_service") # ============================================================================= +def _inherit_platform_io_route_metadata(target_stream: Any) -> Dict[str, object]: + """从目标会话上下文继承 Platform IO 路由元数据。 + + Args: + target_stream: 当前消息要发送到的会话对象。 + + Returns: + Dict[str, object]: 可安全透传到出站消息 ``additional_config`` 中的 + 路由辅助字段。 + """ + inherited_metadata: Dict[str, object] = {} + + context = getattr(target_stream, "context", None) + context_message = getattr(context, "message", None) + if context_message is None: + return inherited_metadata + + additional_config = getattr(context_message.message_info, "additional_config", {}) + if not isinstance(additional_config, dict): + return inherited_metadata + + for key in (*RouteKeyFactory.ACCOUNT_ID_KEYS, *RouteKeyFactory.SCOPE_KEYS): + value = additional_config.get(key) + if value is None: + continue + normalized_value = str(value).strip() + if normalized_value: + inherited_metadata[key] = value + + target_group_id = getattr(target_stream, "group_id", None) + if target_group_id is not None: + normalized_group_id = str(target_group_id).strip() + if normalized_group_id: + inherited_metadata["platform_io_target_group_id"] = normalized_group_id + + target_user_id = getattr(target_stream, "user_id", None) + if target_user_id is not None: + normalized_user_id = str(target_user_id).strip() + if normalized_user_id: + inherited_metadata["platform_io_target_user_id"] = normalized_user_id + + return inherited_metadata + + async def _send_to_target( message_segment: Seg, stream_id: str, @@ -42,7 +87,22 @@ async def _send_to_target( show_log: bool = True, selected_expressions: Optional[List[int]] = None, ) -> bool: - """向指定目标发送消息的内部实现""" + """向指定目标发送消息。 + + Args: + message_segment: 待发送的消息段。 + stream_id: 目标会话 ID。 + display_message: 用于界面展示的文本内容。 + typing: 是否显示输入中状态。 + set_reply: 是否在发送时附带引用回复。 + reply_message: 被回复的消息对象。 + storage_message: 是否将发送结果写入消息存储。 + show_log: 是否输出发送日志。 + selected_expressions: 可选的表情候选索引列表。 + + Returns: + bool: 发送成功返回 ``True``,否则返回 ``False``。 + """ try: if set_reply and not reply_message: logger.warning("[SendService] 使用引用回复,但未提供回复消息") @@ -80,7 +140,7 @@ async def _send_to_target( platform=target_stream.platform, ) - additional_config: dict[str, object] = {} + additional_config: Dict[str, object] = _inherit_platform_io_route_metadata(target_stream) if selected_expressions is not None: additional_config["selected_expressions"] = selected_expressions bot_user_id = get_bot_account(target_stream.platform) diff --git a/src/webui/routers/chat/serializers.py b/src/webui/routers/chat/serializers.py new file mode 100644 index 00000000..32104f88 --- /dev/null +++ b/src/webui/routers/chat/serializers.py @@ -0,0 +1,175 @@ +"""提供 WebUI 聊天路由使用的消息序列化能力。""" + +from typing import Any, Dict, List, Optional + +import base64 + +from src.common.data_models.message_component_data_model import ( + AtComponent, + DictComponent, + EmojiComponent, + ForwardComponent, + ForwardNodeComponent, + ImageComponent, + MessageSequence, + ReplyComponent, + StandardMessageComponents, + TextComponent, + VoiceComponent, +) + + +def serialize_message_sequence(message_sequence: MessageSequence) -> List[Dict[str, Any]]: + """将内部统一消息组件序列转换为 WebUI 富文本消息段。 + + Args: + message_sequence: 内部统一消息组件序列。 + + Returns: + List[Dict[str, Any]]: 可直接广播给 WebUI 前端的消息段列表。 + """ + serialized_segments: List[Dict[str, Any]] = [] + for component in message_sequence.components: + serialized_segment = serialize_message_component(component) + if serialized_segment is not None: + serialized_segments.append(serialized_segment) + return serialized_segments + + +def serialize_message_component(component: StandardMessageComponents) -> Optional[Dict[str, Any]]: + """将单个内部消息组件转换为 WebUI 消息段。 + + Args: + component: 待序列化的内部消息组件。 + + Returns: + Optional[Dict[str, Any]]: 序列化后的 WebUI 消息段;若组件不应展示则返回 ``None``。 + """ + if isinstance(component, TextComponent): + return {"type": "text", "data": component.text} + + if isinstance(component, ImageComponent): + return _serialize_binary_component( + segment_type="image", + mime_type="image/png", + binary_data=component.binary_data, + fallback_text=component.content, + ) + + if isinstance(component, EmojiComponent): + return _serialize_binary_component( + segment_type="emoji", + mime_type="image/gif", + binary_data=component.binary_data, + fallback_text=component.content, + ) + + if isinstance(component, VoiceComponent): + return _serialize_binary_component( + segment_type="voice", + mime_type="audio/wav", + binary_data=component.binary_data, + fallback_text=component.content, + ) + + if isinstance(component, AtComponent): + return { + "type": "at", + "data": { + "target_user_id": component.target_user_id, + "target_user_nickname": component.target_user_nickname, + "target_user_cardname": component.target_user_cardname, + }, + } + + if isinstance(component, ReplyComponent): + return { + "type": "reply", + "data": { + "target_message_id": component.target_message_id, + "target_message_content": component.target_message_content, + "target_message_sender_id": component.target_message_sender_id, + "target_message_sender_nickname": component.target_message_sender_nickname, + "target_message_sender_cardname": component.target_message_sender_cardname, + }, + } + + if isinstance(component, ForwardNodeComponent): + return { + "type": "forward", + "data": [_serialize_forward_component(item) for item in component.forward_components], + } + + if isinstance(component, DictComponent): + return _serialize_dict_component(component.data) + + return {"type": "unknown", "data": str(component)} + + +def _serialize_binary_component( + segment_type: str, + mime_type: str, + binary_data: bytes, + fallback_text: str, +) -> Dict[str, Any]: + """序列化带二进制负载的消息组件。 + + Args: + segment_type: WebUI 消息段类型。 + mime_type: 对应的数据 MIME 类型。 + binary_data: 组件二进制数据。 + fallback_text: 二进制缺失时可退化展示的文本。 + + Returns: + Dict[str, Any]: 序列化后的 WebUI 消息段。 + """ + if binary_data: + encoded_payload = base64.b64encode(binary_data).decode() + return {"type": segment_type, "data": f"data:{mime_type};base64,{encoded_payload}"} + + if fallback_text: + return {"type": "text", "data": fallback_text} + + return {"type": "unknown", "original_type": segment_type, "data": ""} + + +def _serialize_forward_component(component: ForwardComponent) -> Dict[str, Any]: + """序列化单个转发节点。 + + Args: + component: 待序列化的转发节点组件。 + + Returns: + Dict[str, Any]: WebUI 可消费的转发节点字典。 + """ + return { + "message_id": component.message_id, + "user_id": component.user_id, + "user_nickname": component.user_nickname, + "user_cardname": component.user_cardname, + "content": serialize_message_sequence(MessageSequence(component.content)), + } + + +def _serialize_dict_component(data: Dict[str, Any]) -> Dict[str, Any]: + """最佳努力地序列化非标准字典组件。 + + Args: + data: 原始字典组件内容。 + + Returns: + Dict[str, Any]: 序列化后的 WebUI 消息段。 + """ + raw_type = str(data.get("type") or "dict").strip() + raw_payload = data.get("data", data) + + if raw_type in {"text", "image", "emoji", "voice", "video", "file", "music", "face"}: + return {"type": raw_type, "data": raw_payload} + + if raw_type == "reply": + return {"type": "reply", "data": raw_payload} + + if raw_type == "forward" and isinstance(raw_payload, list): + return {"type": "forward", "data": raw_payload} + + return {"type": "unknown", "original_type": raw_type, "data": raw_payload} From 780cd4f767eb7f8f6bc2357ef062891bb2de2ca2 Mon Sep 17 00:00:00 2001 From: DrSmoothl <1787882683@qq.com> Date: Sat, 21 Mar 2026 00:46:34 +0800 Subject: [PATCH 20/45] =?UTF-8?q?refactor:=20=E6=9B=B4=E6=96=B0=E6=8F=92?= =?UTF-8?q?=E4=BB=B6=E5=92=8C=20RPC=20=E6=9C=8D=E5=8A=A1=E5=99=A8=E9=80=BB?= =?UTF-8?q?=E8=BE=91=EF=BC=8C=E5=A2=9E=E5=BC=BA=E6=8F=A1=E6=89=8B=E7=8A=B6?= =?UTF-8?q?=E6=80=81=E7=AE=A1=E7=90=86=E4=B8=8E=E9=85=8D=E7=BD=AE=E6=A0=A1?= =?UTF-8?q?=E9=AA=8C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pyproject.toml | 2 +- src/plugin_runtime/host/rpc_server.py | 23 +- src/plugin_runtime/host/supervisor.py | 79 +++++- src/plugin_runtime/integration.py | 162 ++++++++--- src/plugin_runtime/protocol/envelope.py | 2 +- src/plugins/built_in/napcat_adapter/plugin.py | 251 +++++++++++++++++- 6 files changed, 457 insertions(+), 62 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f41e9448..9887ac24 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,7 @@ dependencies = [ "jieba>=0.42.1", "json-repair>=0.47.6", "maim-message>=0.6.2", - "maibot-plugin-sdk>=1.2.3,<2.0.0", + "maibot-plugin-sdk>=2.0.0", "msgpack>=1.1.2", "numpy>=2.2.6", "openai>=1.95.0", diff --git a/src/plugin_runtime/host/rpc_server.py b/src/plugin_runtime/host/rpc_server.py index 75ef9b2a..2c422775 100644 --- a/src/plugin_runtime/host/rpc_server.py +++ b/src/plugin_runtime/host/rpc_server.py @@ -69,6 +69,7 @@ class RPCServer: # 运行状态 self._running: bool = False self._tasks: List[asyncio.Task[None]] = [] + self._last_handshake_rejection_reason: str = "" @property def session_token(self) -> str: @@ -78,6 +79,15 @@ class RPCServer: def is_connected(self) -> bool: return self._connection is not None and not self._connection.is_closed + @property + def last_handshake_rejection_reason(self) -> str: + """返回最近一次握手被拒绝的原因。""" + return self._last_handshake_rejection_reason + + def clear_handshake_state(self) -> None: + """清空最近一次握手拒绝状态。""" + self._last_handshake_rejection_reason = "" + def register_method(self, method: str, handler: MethodHandler) -> None: """注册 RPC 方法处理器""" self._method_handlers[method] = handler @@ -85,6 +95,7 @@ class RPCServer: async def start(self) -> None: """启动 RPC 服务器""" self._running = True + self.clear_handshake_state() self._send_queue = asyncio.Queue(maxsize=self._send_queue_size) self._send_worker_task = asyncio.create_task(self._send_loop()) await self._transport.start(self._handle_connection) @@ -93,6 +104,7 @@ class RPCServer: async def stop(self) -> None: """停止 RPC 服务器""" self._running = False + self.clear_handshake_state() self._fail_pending_requests(ErrorCode.E_SHUTTING_DOWN, "服务器正在关闭") self._fail_queued_sends(ErrorCode.E_SHUTTING_DOWN, "服务器正在关闭") @@ -204,6 +216,7 @@ class RPCServer: async def _handle_connection(self, conn: Connection) -> None: """处理新的 Runner 连接""" logger.info("收到 Runner 连接") + self.clear_handshake_state() # 第一条消息必须是 runner.hello 握手 try: success = await self._handle_handshake(conn) @@ -232,6 +245,7 @@ class RPCServer: envelope = self._codec.decode_envelope(data) if envelope.method != "runner.hello": logger.error(f"期望 runner.hello,收到 {envelope.method}") + self._last_handshake_rejection_reason = "首条消息必须为 runner.hello" error_resp = envelope.make_error_response( ErrorCode.E_PROTOCOL_MISMATCH.value, "首条消息必须为 runner.hello", @@ -244,7 +258,8 @@ class RPCServer: # 校验会话令牌 if hello.session_token != self._session_token: logger.error("会话令牌不匹配") - resp_payload = HelloResponsePayload(accepted=False, reason="会话令牌无效") + self._last_handshake_rejection_reason = "会话令牌无效" + resp_payload = HelloResponsePayload(accepted=False, reason=self._last_handshake_rejection_reason) resp = envelope.make_response(payload=resp_payload.model_dump()) await conn.send_frame(self._codec.encode_envelope(resp)) return False @@ -252,15 +267,19 @@ class RPCServer: # 校验 SDK 版本 if not self._check_sdk_version(hello.sdk_version): logger.error(f"SDK 版本不兼容: {hello.sdk_version}") + self._last_handshake_rejection_reason = ( + f"SDK 版本 {hello.sdk_version} 不在支持范围 [{MIN_SDK_VERSION}, {MAX_SDK_VERSION}]" + ) resp_payload = HelloResponsePayload( accepted=False, - reason=f"SDK 版本 {hello.sdk_version} 不在支持范围 [{MIN_SDK_VERSION}, {MAX_SDK_VERSION}]", + reason=self._last_handshake_rejection_reason, ) resp = envelope.make_response(payload=resp_payload.model_dump()) await conn.send_frame(self._codec.encode_envelope(resp)) return False # 发送响应 + self.clear_handshake_state() resp_payload = HelloResponsePayload(accepted=True, host_version=PROTOCOL_VERSION) resp = envelope.make_response(payload=resp_payload.model_dump()) await conn.send_frame(self._codec.encode_envelope(resp)) diff --git a/src/plugin_runtime/host/supervisor.py b/src/plugin_runtime/host/supervisor.py index cdf3d4ee..33091d5a 100644 --- a/src/plugin_runtime/host/supervisor.py +++ b/src/plugin_runtime/host/supervisor.py @@ -202,17 +202,24 @@ class PluginRunnerSupervisor: self._restart_count = 0 self._clear_runner_state() - await self._rpc_server.start() - await self._spawn_runner() - try: - await self._wait_for_runner_connection(timeout_sec=self._runner_spawn_timeout) - await self._wait_for_runner_ready(timeout_sec=self._runner_spawn_timeout) - except TimeoutError: - if not self._rpc_server.is_connected: - logger.warning("Runner 未在限定时间内完成连接,后续操作可能失败") - else: - logger.warning("Runner 未在限定时间内完成初始化,后续操作可能失败") + await self._rpc_server.start() + await self._spawn_runner() + + try: + await self._wait_for_runner_connection(timeout_sec=self._runner_spawn_timeout) + await self._wait_for_runner_ready(timeout_sec=self._runner_spawn_timeout) + except TimeoutError: + if not self._rpc_server.is_connected: + logger.warning("Runner 未在限定时间内完成连接,后续操作可能失败") + else: + logger.warning("Runner 未在限定时间内完成初始化,后续操作可能失败") + except Exception: + await self._shutdown_runner(reason="startup_failed") + await self._rpc_server.stop() + self._clear_runner_state() + self._running = False + raise self._health_task = asyncio.create_task(self._health_check_loop(), name="PluginRunnerSupervisor.health") logger.info("PluginRunnerSupervisor 已启动") @@ -387,7 +394,16 @@ class PluginRunnerSupervisor: async def wait_for_connection() -> None: """轮询等待 RPC 连接建立。""" - while self._running and not self._rpc_server.is_connected: + while True: + if self._rpc_server.is_connected: + return + + if not self._running: + raise RuntimeError("Supervisor 已停止,等待 Runner 连接已取消") + + if failure_reason := self._get_runner_startup_failure_reason(): + raise RuntimeError(f"等待 Runner 连接失败: {failure_reason}") + await asyncio.sleep(0.1) try: @@ -408,10 +424,27 @@ class PluginRunnerSupervisor: Raises: TimeoutError: 在超时时间内 Runner 未完成初始化。 """ + async def wait_for_ready() -> RunnerReadyPayload: + """轮询等待 Runner 上报就绪。""" + while True: + if self._runner_ready_events.is_set(): + return self._runner_ready_payloads + + if not self._running: + raise RuntimeError("Supervisor 已停止,等待 Runner 就绪已取消") + + if failure_reason := self._get_runner_startup_failure_reason(): + raise RuntimeError(f"等待 Runner 就绪失败: {failure_reason}") + + if not self._rpc_server.is_connected: + raise RuntimeError("等待 Runner 就绪失败: Runner RPC 连接已断开") + + await asyncio.sleep(0.1) + try: - await asyncio.wait_for(self._runner_ready_events.wait(), timeout=timeout_sec) + payload = await asyncio.wait_for(wait_for_ready(), timeout=timeout_sec) logger.info("Runner 已完成初始化并上报就绪") - return self._runner_ready_payloads + return payload except asyncio.TimeoutError as exc: raise TimeoutError(f"等待 Runner 就绪超时({timeout_sec}s)") from exc @@ -923,6 +956,7 @@ class PluginRunnerSupervisor: await self._wait_for_runner_connection(timeout_sec=self._runner_spawn_timeout) await self._wait_for_runner_ready(timeout_sec=self._runner_spawn_timeout) except Exception as exc: + await self._shutdown_runner(reason="restart_failed") logger.error(f"Runner 重启失败: {exc}", exc_info=True) return False @@ -938,6 +972,25 @@ class PluginRunnerSupervisor: self._registered_adapters.clear() self._runner_ready_events = asyncio.Event() self._runner_ready_payloads = RunnerReadyPayload() + self._rpc_server.clear_handshake_state() + + def _get_runner_startup_failure_reason(self) -> Optional[str]: + """获取 Runner 在启动阶段已经暴露出的失败原因。 + + Returns: + Optional[str]: 若已检测到失败则返回失败原因,否则返回 ``None``。 + """ + if handshake_reason := self._rpc_server.last_handshake_rejection_reason: + return f"握手被拒绝: {handshake_reason}" + + process = self._runner_process + if process is None: + return "Runner 进程不存在" + + if process.returncode is not None: + return f"Runner 进程已退出,退出码 {process.returncode}" + + return None PluginSupervisor = PluginRunnerSupervisor diff --git a/src/plugin_runtime/integration.py b/src/plugin_runtime/integration.py index 30a3c150..24cf09fc 100644 --- a/src/plugin_runtime/integration.py +++ b/src/plugin_runtime/integration.py @@ -68,6 +68,7 @@ class PluginRuntimeManager( self._plugin_file_watcher: Optional[FileWatcher] = None self._plugin_source_watcher_subscription_id: Optional[str] = None self._plugin_config_watcher_subscriptions: Dict[str, Tuple[Path, str]] = {} + self._plugin_path_cache: Dict[str, Path] = {} async def _dispatch_platform_inbound(self, envelope: InboundMessageEnvelope) -> None: """接收 Platform IO 审核后的入站消息并送入主消息链。 @@ -215,6 +216,7 @@ class PluginRuntimeManager( self._started = False self._builtin_supervisor = None self._third_party_supervisor = None + self._plugin_path_cache.clear() @property def is_running(self) -> bool: @@ -254,7 +256,7 @@ class PluginRuntimeManager( config_payload = ( config_data if config_data is not None - else self._load_plugin_config_for_supervisor(plugin_id, plugin_dirs=sv._plugin_dirs) + else self._load_plugin_config_for_supervisor(sv, plugin_id) ) await sv.notify_plugin_config_updated( plugin_id=plugin_id, @@ -452,6 +454,7 @@ class PluginRuntimeManager( async def _stop_plugin_file_watcher(self) -> None: """停止插件文件监视器,并清理所有已注册订阅。""" if self._plugin_file_watcher is None: + self._plugin_path_cache.clear() return for _plugin_id, (_config_path, subscription_id) in list(self._plugin_config_watcher_subscriptions.items()): self._plugin_file_watcher.unsubscribe(subscription_id) @@ -461,12 +464,95 @@ class PluginRuntimeManager( self._plugin_source_watcher_subscription_id = None await self._plugin_file_watcher.stop() self._plugin_file_watcher = None + self._plugin_path_cache.clear() def _iter_plugin_dirs(self) -> Iterable[Path]: """迭代所有 Supervisor 当前管理的插件根目录。""" for supervisor in self.supervisors: yield from getattr(supervisor, "_plugin_dirs", []) + @staticmethod + def _iter_candidate_plugin_paths(plugin_dirs: Iterable[Path]) -> Iterable[Path]: + """迭代所有可能的插件目录路径。 + + Args: + plugin_dirs: 一个或多个插件根目录。 + + Yields: + Path: 单个插件目录路径。 + """ + for plugin_dir in plugin_dirs: + plugin_root = Path(plugin_dir).resolve() + if not plugin_root.is_dir(): + continue + for entry in plugin_root.iterdir(): + if entry.is_dir(): + yield entry.resolve() + + @staticmethod + def _read_plugin_id_from_plugin_path(plugin_path: Path) -> Optional[str]: + """从单个插件目录中读取 manifest 声明的插件 ID。 + + Args: + plugin_path: 单个插件目录路径。 + + Returns: + Optional[str]: 解析成功时返回插件 ID,否则返回 ``None``。 + """ + manifest_path = plugin_path / "_manifest.json" + entrypoint_path = plugin_path / "plugin.py" + if not manifest_path.is_file() or not entrypoint_path.is_file(): + return None + + try: + with open(manifest_path, "r", encoding="utf-8") as manifest_file: + manifest = json.load(manifest_file) + except Exception: + return None + + if not isinstance(manifest, dict): + return None + + plugin_id = str(manifest.get("name", plugin_path.name)).strip() or plugin_path.name + return plugin_id or None + + def _iter_discovered_plugin_paths(self, plugin_dirs: Iterable[Path]) -> Iterable[Tuple[str, Path]]: + """迭代目录中可解析到的插件 ID 与实际目录路径。 + + Args: + plugin_dirs: 一个或多个插件根目录。 + + Yields: + Tuple[str, Path]: ``(plugin_id, plugin_path)`` 二元组。 + """ + for plugin_path in self._iter_candidate_plugin_paths(plugin_dirs): + if plugin_id := self._read_plugin_id_from_plugin_path(plugin_path): + yield plugin_id, plugin_path + + def _get_plugin_path_for_supervisor(self, supervisor: Any, plugin_id: str) -> Optional[Path]: + """为指定 Supervisor 定位某个插件的实际目录。 + + Args: + supervisor: 目标 Supervisor。 + plugin_id: 插件 ID。 + + Returns: + Optional[Path]: 插件目录路径;未找到时返回 ``None``。 + """ + cached_path = self._plugin_path_cache.get(plugin_id) + if cached_path is not None: + for plugin_dir in getattr(supervisor, "_plugin_dirs", []): + if self._plugin_dir_matches(cached_path, Path(plugin_dir)): + return cached_path + + for candidate_plugin_id, plugin_path in self._iter_discovered_plugin_paths(getattr(supervisor, "_plugin_dirs", [])): + if candidate_plugin_id != plugin_id: + continue + self._plugin_path_cache[plugin_id] = plugin_path + return plugin_path + + return None + def _refresh_plugin_config_watch_subscriptions(self) -> None: """按当前已注册插件集合刷新 config.toml 的单插件订阅。 @@ -476,7 +562,11 @@ class PluginRuntimeManager( if self._plugin_file_watcher is None: return - desired_config_paths = dict(self._iter_registered_plugin_config_paths()) + desired_plugin_paths = dict(self._iter_registered_plugin_paths()) + self._plugin_path_cache = desired_plugin_paths.copy() + desired_config_paths = { + plugin_id: plugin_path / "config.toml" for plugin_id, plugin_path in desired_plugin_paths.items() + } for plugin_id, (_old_path, subscription_id) in list(self._plugin_config_watcher_subscriptions.items()): if desired_config_paths.get(plugin_id) == self._plugin_config_watcher_subscriptions[plugin_id][0]: @@ -509,21 +599,17 @@ class PluginRuntimeManager( return _callback - def _iter_registered_plugin_config_paths(self) -> Iterable[Tuple[str, Path]]: - """迭代当前所有已注册插件的 config.toml 路径。""" + def _iter_registered_plugin_paths(self) -> Iterable[Tuple[str, Path]]: + """迭代当前所有已注册插件的实际目录路径。""" for supervisor in self.supervisors: for plugin_id in getattr(supervisor, "_registered_plugins", {}).keys(): - if config_path := self._get_plugin_config_path_for_supervisor(supervisor, plugin_id): - yield plugin_id, config_path + if plugin_path := self._get_plugin_path_for_supervisor(supervisor, plugin_id): + yield plugin_id, plugin_path def _get_plugin_config_path_for_supervisor(self, supervisor: Any, plugin_id: str) -> Optional[Path]: """从指定 Supervisor 的插件目录中定位某个插件的 config.toml。""" - for plugin_dir in getattr(supervisor, "_plugin_dirs", []): - plugin_dir = Path(plugin_dir) - plugin_path = plugin_dir.resolve() / plugin_id - if plugin_path.is_dir(): - return plugin_path / "config.toml" - return None + plugin_path = self._get_plugin_path_for_supervisor(supervisor, plugin_id) + return None if plugin_path is None else plugin_path / "config.toml" async def _handle_plugin_config_changes(self, plugin_id: str, changes: Sequence[FileChange]) -> None: """处理单个插件配置文件变化,并仅向目标插件推送配置更新。""" @@ -542,7 +628,7 @@ class PluginRuntimeManager( try: await supervisor.notify_plugin_config_updated( plugin_id=plugin_id, - config_data=self._load_plugin_config_for_supervisor(plugin_id, getattr(supervisor, "_plugin_dirs", [])), + config_data=self._load_plugin_config_for_supervisor(supervisor, plugin_id), ) except Exception as exc: logger.warning(f"插件 {plugin_id} 配置热更新通知失败: {exc}") @@ -591,32 +677,38 @@ class PluginRuntimeManager( def _match_plugin_id_for_supervisor(self, supervisor: Any, path: Path) -> Optional[str]: """根据变更路径为指定 Supervisor 推断受影响的插件 ID。""" - for plugin_id, _reg in getattr(supervisor, "_registered_plugins", {}).items(): - for plugin_dir in getattr(supervisor, "_plugin_dirs", []): - plugin_dir = Path(plugin_dir) - candidate_dir = plugin_dir.resolve() / plugin_id - if path == candidate_dir or path.is_relative_to(candidate_dir): - return plugin_id + resolved_path = path.resolve() + + for plugin_id in getattr(supervisor, "_registered_plugins", {}).keys(): + plugin_path = self._get_plugin_path_for_supervisor(supervisor, plugin_id) + if plugin_path is not None and (resolved_path == plugin_path or resolved_path.is_relative_to(plugin_path)): + return plugin_id + + for plugin_id, plugin_path in self._plugin_path_cache.items(): + if not any(self._plugin_dir_matches(plugin_path, Path(plugin_dir)) for plugin_dir in getattr(supervisor, "_plugin_dirs", [])): + continue + if resolved_path == plugin_path or resolved_path.is_relative_to(plugin_path): + return plugin_id + + for plugin_id, plugin_path in self._iter_discovered_plugin_paths(getattr(supervisor, "_plugin_dirs", [])): + if resolved_path == plugin_path or resolved_path.is_relative_to(plugin_path): + self._plugin_path_cache[plugin_id] = plugin_path + return plugin_id - for plugin_dir in getattr(supervisor, "_plugin_dirs", []): - plugin_dir = Path(plugin_dir) - plugin_root = plugin_dir.resolve() - if self._plugin_dir_matches(path, plugin_dir) and (relative_parts := path.relative_to(plugin_root).parts): - return relative_parts[0] return None - @staticmethod - def _load_plugin_config_for_supervisor(plugin_id: str, plugin_dirs: Iterable[Path]) -> Dict[str, Any]: + def _load_plugin_config_for_supervisor(self, supervisor: Any, plugin_id: str) -> Dict[str, Any]: """从给定插件目录集合中读取目标插件的配置内容。""" - for plugin_dir in plugin_dirs: - plugin_path = plugin_dir.resolve() / plugin_id - if plugin_path.is_dir(): - config_path = plugin_path / "config.toml" - if not config_path.exists(): - return {} - with open(config_path, "r", encoding="utf-8") as handle: - return tomlkit.load(handle).unwrap() - return {} + plugin_path = self._get_plugin_path_for_supervisor(supervisor, plugin_id) + if plugin_path is None: + return {} + + config_path = plugin_path / "config.toml" + if not config_path.exists(): + return {} + + with open(config_path, "r", encoding="utf-8") as handle: + return tomlkit.load(handle).unwrap() # ─── 能力实现注册 ────────────────────────────────────────── diff --git a/src/plugin_runtime/protocol/envelope.py b/src/plugin_runtime/protocol/envelope.py index 0dfc6656..d71e02c5 100644 --- a/src/plugin_runtime/protocol/envelope.py +++ b/src/plugin_runtime/protocol/envelope.py @@ -17,7 +17,7 @@ from pydantic import BaseModel, Field PROTOCOL_VERSION = "1.0.0" # 支持的 SDK 版本范围(Host 在握手时校验) MIN_SDK_VERSION = "1.0.0" -MAX_SDK_VERSION = "1.99.99" +MAX_SDK_VERSION = "2.99.99" # ====== 消息类型 ====== diff --git a/src/plugins/built_in/napcat_adapter/plugin.py b/src/plugins/built_in/napcat_adapter/plugin.py index 3eff518d..a481101f 100644 --- a/src/plugins/built_in/napcat_adapter/plugin.py +++ b/src/plugins/built_in/napcat_adapter/plugin.py @@ -14,7 +14,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast +from typing import Any, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, cast from uuid import uuid4 import asyncio @@ -42,6 +42,13 @@ if not TYPE_CHECKING: AiohttpClientWebSocketResponse = Any +SUPPORTED_CONFIG_VERSION = "0.1.0" +DEFAULT_RECONNECT_DELAY_SEC = 5.0 +DEFAULT_HEARTBEAT_SEC = 30.0 +DEFAULT_ACTION_TIMEOUT_SEC = 15.0 +DEFAULT_CHAT_LIST_TYPE = "whitelist" + + @Adapter(platform="qq", protocol="napcat", send_method="send_to_platform") class NapCatAdapterPlugin(MaiBotPlugin): """NapCat 适配器 MVP 实现。""" @@ -52,7 +59,7 @@ class NapCatAdapterPlugin(MaiBotPlugin): self._plugin_config: Dict[str, Any] = {} self._connection_task: Optional[asyncio.Task[None]] = None self._pending_actions: Dict[str, asyncio.Future[Dict[str, Any]]] = {} - self._background_tasks: set[asyncio.Task[Any]] = set() + self._background_tasks: Set[asyncio.Task[Any]] = set() self._send_lock = asyncio.Lock() self._ws: Optional[AiohttpClientWebSocketResponse] = None @@ -80,8 +87,9 @@ class NapCatAdapterPlugin(MaiBotPlugin): new_config: 最新的插件配置。 version: 配置版本号。 """ - del version self.set_plugin_config(new_config) + if version: + self.ctx.logger.debug(f"NapCat 适配器收到配置更新通知: {version}") await self._restart_connection_if_needed() async def send_to_platform( @@ -139,6 +147,8 @@ class NapCatAdapterPlugin(MaiBotPlugin): if not self._should_connect(): self.ctx.logger.info("NapCat 适配器保持空闲状态,因为插件或配置未启用") return + if not self._validate_current_config(): + return if not AIOHTTP_AVAILABLE: self.ctx.logger.error("NapCat 适配器依赖 aiohttp,但当前环境未安装该依赖") return @@ -185,7 +195,7 @@ class NapCatAdapterPlugin(MaiBotPlugin): headers = self._build_headers() timeout = ClientTimeout(total=None, connect=10) - heartbeat = self._get_positive_float(self._connection_config(), "heartbeat_sec", 30.0) + heartbeat = self._get_positive_float(self._connection_config(), "heartbeat_sec", DEFAULT_HEARTBEAT_SEC) try: async with ClientSession(headers=headers, timeout=timeout) as session: @@ -204,7 +214,13 @@ class NapCatAdapterPlugin(MaiBotPlugin): if not self._should_connect(): break - await asyncio.sleep(self._get_positive_float(self._connection_config(), "reconnect_delay_sec", 5.0)) + await asyncio.sleep( + self._get_positive_float( + self._connection_config(), + "reconnect_delay_sec", + DEFAULT_RECONNECT_DELAY_SEC, + ) + ) async def _receive_loop(self, ws: AiohttpClientWebSocketResponse) -> None: """持续消费 WebSocket 消息并分发处理。 @@ -250,8 +266,11 @@ class NapCatAdapterPlugin(MaiBotPlugin): if not sender_user_id: return + group_id = str(payload.get("group_id") or "").strip() if self_id and sender_user_id == self_id and self._get_bool(self._filters_config(), "ignore_self_message", True): return + if not self._is_inbound_chat_allowed(sender_user_id, group_id): + return message_dict = self._build_inbound_message_dict(payload, self_id, sender_user_id, sender) route_metadata: Dict[str, Any] = {} @@ -339,7 +358,7 @@ class NapCatAdapterPlugin(MaiBotPlugin): "display_message": plain_text, } - def _convert_inbound_segments(self, message_payload: Any, self_id: str) -> tuple[List[Dict[str, Any]], bool]: + def _convert_inbound_segments(self, message_payload: Any, self_id: str) -> Tuple[List[Dict[str, Any]], bool]: """将 OneBot 消息段转换为 Host 消息段结构。 Args: @@ -347,7 +366,7 @@ class NapCatAdapterPlugin(MaiBotPlugin): self_id: 当前机器人账号 ID。 Returns: - tuple[List[Dict[str, Any]], bool]: 转换后的消息段列表,以及是否 @ 到当前机器人。 + Tuple[List[Dict[str, Any]], bool]: 转换后的消息段列表,以及是否 @ 到当前机器人。 """ if isinstance(message_payload, str): normalized_text = message_payload.strip() @@ -412,7 +431,7 @@ class NapCatAdapterPlugin(MaiBotPlugin): self, message: Dict[str, Any], route: Dict[str, Any], - ) -> tuple[str, Dict[str, Any]]: + ) -> Tuple[str, Dict[str, Any]]: """为 Host 出站消息构造 OneBot 动作。 Args: @@ -420,7 +439,7 @@ class NapCatAdapterPlugin(MaiBotPlugin): route: Platform IO 路由信息。 Returns: - tuple[str, Dict[str, Any]]: 动作名称与参数字典。 + Tuple[str, Dict[str, Any]]: 动作名称与参数字典。 """ message_info = message.get("message_info", {}) if not isinstance(message_info, dict): @@ -519,7 +538,11 @@ class NapCatAdapterPlugin(MaiBotPlugin): try: async with self._send_lock: await ws.send_str(json.dumps(request_payload, ensure_ascii=False)) - timeout_seconds = self._get_positive_float(self._connection_config(), "action_timeout_sec", 15.0) + timeout_seconds = self._get_positive_float( + self._connection_config(), + "action_timeout_sec", + DEFAULT_ACTION_TIMEOUT_SEC, + ) return await asyncio.wait_for(response_future, timeout=timeout_seconds) finally: self._pending_actions.pop(echo_id, None) @@ -626,6 +649,173 @@ class NapCatAdapterPlugin(MaiBotPlugin): filters_config = self._plugin_config.get("filters", {}) return filters_config if isinstance(filters_config, dict) else {} + def _chat_config(self) -> Dict[str, Any]: + """读取插件配置中的 ``chat`` 段。 + + Returns: + Dict[str, Any]: ``chat`` 配置字典。 + """ + chat_config = self._plugin_config.get("chat", {}) + return chat_config if isinstance(chat_config, dict) else {} + + def _is_inbound_chat_allowed(self, sender_user_id: str, group_id: str) -> bool: + """检查入站消息是否通过聊天名单过滤。 + + Args: + sender_user_id: 发送者用户 ID。 + group_id: 群聊 ID;私聊时为空字符串。 + + Returns: + bool: 若消息允许继续进入 Host,则返回 ``True``。 + """ + chat_config = self._chat_config() + banned_user_ids = self._get_string_list(chat_config, "ban_user_id") + if sender_user_id in banned_user_ids: + self.ctx.logger.warning(f"NapCat 用户 {sender_user_id} 在全局禁止名单中,消息被丢弃") + return False + + if group_id: + group_list_type = self._get_list_mode(chat_config, "group_list_type", DEFAULT_CHAT_LIST_TYPE) + group_id_list = self._get_string_list(chat_config, "group_list") + if not self._is_id_allowed_by_list_policy(group_id, group_list_type, group_id_list): + self.ctx.logger.warning(f"NapCat 群聊 {group_id} 未通过聊天名单过滤,消息被丢弃") + return False + return True + + private_list_type = self._get_list_mode(chat_config, "private_list_type", DEFAULT_CHAT_LIST_TYPE) + private_id_list = self._get_string_list(chat_config, "private_list") + if not self._is_id_allowed_by_list_policy(sender_user_id, private_list_type, private_id_list): + self.ctx.logger.warning(f"NapCat 私聊用户 {sender_user_id} 未通过聊天名单过滤,消息被丢弃") + return False + return True + + def _is_id_allowed_by_list_policy( + self, + target_id: str, + list_type: str, + configured_ids: Set[str], + ) -> bool: + """根据白名单或黑名单规则判断目标 ID 是否允许通过。 + + Args: + target_id: 待检查的目标 ID。 + list_type: 名单模式,仅支持 ``whitelist`` 或 ``blacklist``。 + configured_ids: 配置中的 ID 集合。 + + Returns: + bool: 若目标 ID 允许通过,则返回 ``True``。 + """ + if list_type == "whitelist": + return target_id in configured_ids + return target_id not in configured_ids + + def _validate_current_config(self) -> bool: + """校验当前配置是否满足启动连接的前提条件。 + + Returns: + bool: 配置可用于启动连接时返回 ``True``。 + """ + if not self._validate_plugin_config_version(): + return False + + connection_config = self._connection_config() + ws_url = self._get_string(connection_config, "ws_url") + if not ws_url: + self.ctx.logger.warning("NapCat 适配器已启用,但 connection.ws_url 为空") + return False + + self._validate_positive_float_setting( + connection_config, + "connection", + "reconnect_delay_sec", + DEFAULT_RECONNECT_DELAY_SEC, + ) + self._validate_positive_float_setting( + connection_config, + "connection", + "heartbeat_sec", + DEFAULT_HEARTBEAT_SEC, + ) + self._validate_positive_float_setting( + connection_config, + "connection", + "action_timeout_sec", + DEFAULT_ACTION_TIMEOUT_SEC, + ) + self._validate_list_mode_setting(self._chat_config(), "chat", "group_list_type", DEFAULT_CHAT_LIST_TYPE) + self._validate_list_mode_setting(self._chat_config(), "chat", "private_list_type", DEFAULT_CHAT_LIST_TYPE) + return True + + def _validate_plugin_config_version(self) -> bool: + """校验插件配置版本是否与当前实现兼容。 + + Returns: + bool: 版本兼容时返回 ``True``。 + """ + config_version = self._get_string(self._plugin_section(), "config_version") + if not config_version: + self.ctx.logger.error( + f"NapCat 适配器配置缺少 plugin.config_version,当前插件要求版本 {SUPPORTED_CONFIG_VERSION}" + ) + return False + + if config_version != SUPPORTED_CONFIG_VERSION: + self.ctx.logger.error( + "NapCat 适配器配置版本不兼容: " + f"当前为 {config_version},当前插件要求 {SUPPORTED_CONFIG_VERSION}" + ) + return False + + return True + + def _validate_positive_float_setting( + self, + mapping: Dict[str, Any], + section_name: str, + key: str, + default: float, + ) -> None: + """校验正浮点数配置项,并在非法时输出告警日志。 + + Args: + mapping: 待读取的配置字典。 + section_name: 当前配置段名称。 + key: 目标配置键名。 + default: 配置非法时实际使用的默认值。 + """ + value = mapping.get(key, default) + if isinstance(value, (int, float)) and float(value) > 0: + return + + self.ctx.logger.warning( + "NapCat 适配器配置项取值无效,已回退到默认值: " + f"{section_name}.{key}={value!r},默认值为 {default}" + ) + + def _validate_list_mode_setting( + self, + mapping: Dict[str, Any], + section_name: str, + key: str, + default: str, + ) -> None: + """校验名单模式配置项,并在非法时输出告警日志。 + + Args: + mapping: 待读取的配置字典。 + section_name: 当前配置段名称。 + key: 目标配置键名。 + default: 配置非法时实际使用的默认值。 + """ + value = mapping.get(key, default) + if isinstance(value, str) and value.strip() in {"whitelist", "blacklist"}: + return + + self.ctx.logger.warning( + "NapCat 适配器配置项取值无效,已回退到默认值: " + f"{section_name}.{key}={value!r},默认值为 {default}" + ) + def _should_connect(self) -> bool: """判断当前配置下是否应当启动连接。 @@ -680,6 +870,47 @@ class NapCatAdapterPlugin(MaiBotPlugin): value = mapping.get(key) return "" if value is None else str(value).strip() + @staticmethod + def _get_list_mode(mapping: Dict[str, Any], key: str, default: str) -> str: + """安全读取名单模式配置值。 + + Args: + mapping: 待读取的配置字典。 + key: 目标键名。 + default: 读取失败时的默认值。 + + Returns: + str: 合法的名单模式字符串。 + """ + value = mapping.get(key, default) + if isinstance(value, str): + normalized_value = value.strip() + if normalized_value in {"whitelist", "blacklist"}: + return normalized_value + return default + + @staticmethod + def _get_string_list(mapping: Dict[str, Any], key: str) -> Set[str]: + """安全读取 ID 列表配置值。 + + Args: + mapping: 待读取的配置字典。 + key: 目标键名。 + + Returns: + Set[str]: 去重后的字符串 ID 集合。 + """ + value = mapping.get(key, []) + if not isinstance(value, list): + return set() + + normalized_values: Set[str] = set() + for item in value: + item_text = "" if item is None else str(item).strip() + if item_text: + normalized_values.add(item_text) + return normalized_values + def create_plugin() -> NapCatAdapterPlugin: """创建插件实例。 From a1859027efbf3fa5fb9178c815b350a2f7e2ae02 Mon Sep 17 00:00:00 2001 From: DrSmoothl <1787882683@qq.com> Date: Sat, 21 Mar 2026 00:53:05 +0800 Subject: [PATCH 21/45] =?UTF-8?q?refactor:=20=E4=BC=98=E5=8C=96=E9=A2=9C?= =?UTF-8?q?=E8=89=B2=E6=98=A0=E5=B0=84=E5=92=8C=E5=88=AB=E5=90=8D=E5=AE=9A?= =?UTF-8?q?=E4=B9=89=EF=BC=8C=E5=A2=9E=E5=BC=BA=E6=A8=A1=E5=9D=97=E4=B8=80?= =?UTF-8?q?=E8=87=B4=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/common/logger_color_and_mapping.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/common/logger_color_and_mapping.py b/src/common/logger_color_and_mapping.py index f84caa21..4044d2dc 100644 --- a/src/common/logger_color_and_mapping.py +++ b/src/common/logger_color_and_mapping.py @@ -1,9 +1,8 @@ # 定义模块颜色映射 -from typing import Optional, Tuple, Dict - import itertools import os import sys +from typing import Dict, Optional, Tuple MODULE_COLORS: Dict[str, Tuple[str, Optional[str], bool]] = { @@ -54,15 +53,19 @@ MODULE_COLORS: Dict[str, Tuple[str, Optional[str], bool]] = { "component_registry": ("#ffaf00", None, False), "plugin_runtime.integration": ("#d75f00", None, False), "plugin_runtime.host.supervisor": ("#ff5f00", None, False), + "plugin_runtime.host.runner_manager": ("#ff5f00", None, False), "plugin_runtime.host.rpc_server": ("#ff8700", None, False), "plugin_runtime.host.component_registry": ("#ffaf00", None, False), "plugin_runtime.host.capability_service": ("#ffd700", None, False), "plugin_runtime.host.event_dispatcher": ("#87d700", None, False), "plugin_runtime.host.hook_dispatcher": ("#5fd7af", None, False), + "plugin_runtime.host.message_gateway": ("#5fd7d7", None, False), + "plugin_runtime.host.message_utils": ("#5faf87", None, False), "plugin_runtime.runner.main": ("#d787ff", None, False), "plugin_runtime.runner.rpc_client": ("#8787ff", None, False), "plugin_runtime.runner.manifest_validator": ("#5fafff", None, False), "plugin_runtime.runner.plugin_loader": ("#00afaf", None, False), + "plugin.napcat_adapter_builtin": ("#00af87", None, False), "webui": ("#5f87ff", None, False), "webui.app": ("#5f87d7", None, False), "webui.api": ("#5fafff", None, False), @@ -157,15 +160,20 @@ MODULE_ALIASES = { "chat_history_summarizer": "聊天概括器", "plugin_runtime.integration": "IPC插件系统", "plugin_runtime.host.supervisor": "插件监督器", + "plugin_runtime.host.runner_manager": "插件监督器", "plugin_runtime.host.rpc_server": "插件RPC服务", "plugin_runtime.host.component_registry": "插件组件注册", "plugin_runtime.host.capability_service": "插件能力服务", "plugin_runtime.host.event_dispatcher": "插件事件分发", + "plugin_runtime.host.hook_dispatcher": "插件Hook分发", + "plugin_runtime.host.message_gateway": "插件消息网关", + "plugin_runtime.host.message_utils": "插件消息工具", "plugin_runtime.host.workflow_executor": "插件工作流", "plugin_runtime.runner.main": "插件运行器", "plugin_runtime.runner.rpc_client": "插件RPC客户端", "plugin_runtime.runner.manifest_validator": "插件清单校验", "plugin_runtime.runner.plugin_loader": "插件加载器", + "plugin.napcat_adapter_builtin": "NapCat内置适配器", "webui": "WebUI", "webui.app": "WebUI应用", "webui.api": "WebUI接口", From dd20cd4992b1c1bc250f0367317e9eef8587328f Mon Sep 17 00:00:00 2001 From: DrSmoothl <1787882683@qq.com> Date: Sat, 21 Mar 2026 00:59:21 +0800 Subject: [PATCH 22/45] =?UTF-8?q?refactor:=20=E5=A2=9E=E5=BC=BA=E6=96=87?= =?UTF-8?q?=E6=A1=A3=E6=B3=A8=E9=87=8A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/plugin_runtime/runner/log_handler.py | 6 ++ .../runner/manifest_validator.py | 69 +++++++++++++++++++ src/plugin_runtime/runner/runner_main.py | 13 +++- 3 files changed, 87 insertions(+), 1 deletion(-) diff --git a/src/plugin_runtime/runner/log_handler.py b/src/plugin_runtime/runner/log_handler.py index b5a0a328..6f42940f 100644 --- a/src/plugin_runtime/runner/log_handler.py +++ b/src/plugin_runtime/runner/log_handler.py @@ -66,6 +66,12 @@ class RunnerIPCLogHandler(logging.Handler): ALLOWED_LOGGER_PREFIXES: tuple[str, ...] = ("plugin.", "plugin_runtime.", "_maibot_plugin_") def __init__(self) -> None: + """初始化 Runner 端日志转发处理器。 + + 创建有界日志缓冲区,并准备与 RPC 客户端绑定的后台刷新任务。 + 此时不会启动任何异步任务;真正开始转发要等到 :meth:`start` + 被调用后才会发生。 + """ super().__init__() # deque(maxlen=N): append/popleft 在 CPython GIL 保护下线程安全 self._buffer: collections.deque[LogEntry] = collections.deque(maxlen=self.QUEUE_MAX) diff --git a/src/plugin_runtime/runner/manifest_validator.py b/src/plugin_runtime/runner/manifest_validator.py index b6990850..32429e01 100644 --- a/src/plugin_runtime/runner/manifest_validator.py +++ b/src/plugin_runtime/runner/manifest_validator.py @@ -18,6 +18,15 @@ class VersionComparator: @staticmethod def normalize_version(version: str) -> str: + """将版本号规范化为三段式语义版本字符串。 + + Args: + version: 原始版本号字符串。 + + Returns: + str: 规范化后的 ``major.minor.patch`` 形式版本号。 + 当输入为空或格式非法时返回 ``0.0.0``。 + """ if not version: return "0.0.0" normalized = re.sub(r"-snapshot\.\d+", "", version.strip()) @@ -30,6 +39,15 @@ class VersionComparator: @staticmethod def parse_version(version: str) -> Tuple[int, int, int]: + """将版本字符串解析为可比较的整数元组。 + + Args: + version: 原始版本号字符串。 + + Returns: + Tuple[int, int, int]: 三段式版本号对应的整数元组。 + 当解析失败时返回 ``(0, 0, 0)``。 + """ normalized = VersionComparator.normalize_version(version) try: parts = normalized.split(".") @@ -39,6 +57,16 @@ class VersionComparator: @staticmethod def compare(v1: str, v2: str) -> int: + """比较两个版本号的大小关系。 + + Args: + v1: 第一个版本号。 + v2: 第二个版本号。 + + Returns: + int: ``-1`` 表示 ``v1 < v2``,``1`` 表示 ``v1 > v2``, + ``0`` 表示两者相等。 + """ t1 = VersionComparator.parse_version(v1) t2 = VersionComparator.parse_version(v2) if t1 < t2: @@ -49,6 +77,17 @@ class VersionComparator: @staticmethod def is_in_range(version: str, min_version: str = "", max_version: str = "") -> Tuple[bool, str]: + """判断版本号是否落在给定闭区间内。 + + Args: + version: 待检查的版本号。 + min_version: 允许的最小版本号,留空表示不限制下界。 + max_version: 允许的最大版本号,留空表示不限制上界。 + + Returns: + Tuple[bool, str]: 第一项表示是否满足要求,第二项为失败原因; + 当校验通过时第二项为空字符串。 + """ if not min_version and not max_version: return True, "" vn = VersionComparator.normalize_version(version) @@ -71,6 +110,11 @@ class ManifestValidator: SUPPORTED_MANIFEST_VERSIONS = [1, 2] def __init__(self, host_version: str = "") -> None: + """初始化 Manifest 校验器。 + + Args: + host_version: 当前 Host 版本号,用于校验插件声明的兼容区间。 + """ self._host_version = host_version self.errors: List[str] = [] self.warnings: List[str] = [] @@ -96,6 +140,11 @@ class ManifestValidator: return len(self.errors) == 0 def _check_required_fields(self, manifest: Dict[str, Any]) -> None: + """检查 Manifest 中的必填字段是否存在且非空。 + + Args: + manifest: 待校验的 Manifest 数据。 + """ for field in self.REQUIRED_FIELDS: if field not in manifest: self.errors.append(f"缺少必需字段: {field}") @@ -103,11 +152,21 @@ class ManifestValidator: self.errors.append(f"必需字段不能为空: {field}") def _check_manifest_version(self, manifest: Dict[str, Any]) -> None: + """检查 Manifest 版本号是否在当前 Runner 支持范围内。 + + Args: + manifest: 待校验的 Manifest 数据。 + """ mv = manifest.get("manifest_version") if mv is not None and mv not in self.SUPPORTED_MANIFEST_VERSIONS: self.errors.append(f"不支持的 manifest_version: {mv},支持: {self.SUPPORTED_MANIFEST_VERSIONS}") def _check_author(self, manifest: Dict[str, Any]) -> None: + """校验 ``author`` 字段的结构与内容。 + + Args: + manifest: 待校验的 Manifest 数据。 + """ author = manifest.get("author") if author is None: return @@ -121,6 +180,11 @@ class ManifestValidator: self.errors.append("author 应为字符串或 {name, url} 对象") def _check_host_compatibility(self, manifest: Dict[str, Any]) -> None: + """检查插件声明的 Host 兼容范围是否包含当前 Host 版本。 + + Args: + manifest: 待校验的 Manifest 数据。 + """ host_app = manifest.get("host_application") if not isinstance(host_app, dict) or not self._host_version: return @@ -131,6 +195,11 @@ class ManifestValidator: self.errors.append(f"Host 版本不兼容: {msg} (当前 Host: {self._host_version})") def _check_recommended(self, manifest: Dict[str, Any]) -> None: + """检查推荐字段是否齐备,并记录为警告而非错误。 + + Args: + manifest: 待校验的 Manifest 数据。 + """ for field in self.RECOMMENDED_FIELDS: if field not in manifest or not manifest[field]: self.warnings.append(f"建议填写字段: {field}") diff --git a/src/plugin_runtime/runner/runner_main.py b/src/plugin_runtime/runner/runner_main.py index 3ffb6b4b..88f92494 100644 --- a/src/plugin_runtime/runner/runner_main.py +++ b/src/plugin_runtime/runner/runner_main.py @@ -47,8 +47,19 @@ logger = get_logger("plugin_runtime.runner.main") class _ContextAwarePlugin(Protocol): + """支持注入运行时上下文的插件协议。 + + 该协议用于描述 Runner 在激活插件时依赖的最小接口。 + 只要插件实例实现了 ``_set_context`` 方法,就可以被 Runner + 注入 ``PluginContext`` 或兼容层上下文对象。 + """ + def _set_context(self, context: Any) -> None: - """为插件注入上下文对象。""" + """为插件实例注入运行时上下文。 + + Args: + context: 由 Runner 构造的上下文对象。 + """ def _install_shutdown_signal_handlers( From 4e2e7a279e42d7f33a41bdeaf7c9b1bedd4f9953 Mon Sep 17 00:00:00 2001 From: DrSmoothl <1787882683@qq.com> Date: Sat, 21 Mar 2026 21:47:22 +0800 Subject: [PATCH 23/45] feat: Implement adapter runtime state management and update handling - Added support for adapter runtime state updates in the PluginRunnerSupervisor. - Introduced new payload classes: AdapterStateUpdatePayload and AdapterStateUpdateResultPayload for handling state updates. - Implemented methods to bind and unbind routes based on adapter connection status. - Enhanced the NapCat adapter to report connection state and manage runtime state. - Added tests for adapter runtime state synchronization and database session behavior in the statistic module. - Updated existing methods to ensure proper handling of adapter state and route bindings. --- .../test_person_info_group_cardname.py | 355 ++++++++++++++++++ pytests/test_adapter_runtime_state.py | 162 ++++++++ pytests/test_plugin_runtime.py | 6 +- pytests/utils_test/statistic_test.py | 115 ++++++ src/chat/utils/statistic.py | 12 +- .../data_models/person_info_data_model.py | 96 ++++- src/person_info/person_info.py | 79 ++-- src/plugin_runtime/host/supervisor.py | 274 +++++++++++++- src/plugin_runtime/protocol/envelope.py | 24 ++ src/plugin_runtime/runner/runner_main.py | 7 +- src/plugins/built_in/napcat_adapter/plugin.py | 168 ++++++++- 11 files changed, 1219 insertions(+), 79 deletions(-) create mode 100644 pytests/common_test/test_person_info_group_cardname.py create mode 100644 pytests/test_adapter_runtime_state.py create mode 100644 pytests/utils_test/statistic_test.py diff --git a/pytests/common_test/test_person_info_group_cardname.py b/pytests/common_test/test_person_info_group_cardname.py new file mode 100644 index 00000000..62a63f43 --- /dev/null +++ b/pytests/common_test/test_person_info_group_cardname.py @@ -0,0 +1,355 @@ +"""人物信息群名片字段兼容测试。""" + +from __future__ import annotations + +from importlib.util import module_from_spec, spec_from_file_location +from pathlib import Path +from types import ModuleType, SimpleNamespace +from typing import Any + +import json +import sys + +import pytest + +from src.common.data_models.person_info_data_model import dump_group_cardname_records, parse_group_cardname_json + + +class _DummyLogger: + """模拟日志记录器。""" + + def debug(self, message: str) -> None: + """记录调试日志。 + + Args: + message: 日志内容。 + """ + del message + + def info(self, message: str) -> None: + """记录信息日志。 + + Args: + message: 日志内容。 + """ + del message + + def warning(self, message: str) -> None: + """记录警告日志。 + + Args: + message: 日志内容。 + """ + del message + + def error(self, message: str) -> None: + """记录错误日志。 + + Args: + message: 日志内容。 + """ + del message + + +class _DummyStatement: + """模拟 SQL 查询语句对象。""" + + def where(self, condition: Any) -> "_DummyStatement": + """附加过滤条件。 + + Args: + condition: 过滤条件。 + + Returns: + _DummyStatement: 当前语句对象。 + """ + del condition + return self + + def limit(self, value: int) -> "_DummyStatement": + """限制返回条数。 + + Args: + value: 条数限制。 + + Returns: + _DummyStatement: 当前语句对象。 + """ + del value + return self + + +class _DummyColumn: + """模拟 SQLModel 列对象。""" + + def is_not(self, value: Any) -> "_DummyColumn": + """模拟 `IS NOT` 条件构造。 + + Args: + value: 比较值。 + + Returns: + _DummyColumn: 当前列对象。 + """ + del value + return self + + def __eq__(self, other: Any) -> "_DummyColumn": + """模拟等值条件构造。 + + Args: + other: 比较值。 + + Returns: + _DummyColumn: 当前列对象。 + """ + del other + return self + + +class _DummyResult: + """模拟数据库查询结果。""" + + def __init__(self, record: Any) -> None: + """初始化查询结果。 + + Args: + record: 待返回的首条记录。 + """ + self._record = record + + def first(self) -> Any: + """返回第一条记录。 + + Returns: + Any: 首条记录。 + """ + return self._record + + def all(self) -> list[Any]: + """返回全部结果。 + + Returns: + list[Any]: 结果列表。 + """ + if self._record is None: + return [] + return self._record if isinstance(self._record, list) else [self._record] + + +class _DummySession: + """模拟数据库 Session。""" + + def __init__(self, record: Any) -> None: + """初始化 Session。 + + Args: + record: `first()` 应返回的记录。 + """ + self.record = record + self.added_records: list[Any] = [] + + def __enter__(self) -> "_DummySession": + """进入上下文管理器。 + + Returns: + _DummySession: 当前 Session。 + """ + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + """退出上下文管理器。 + + Args: + exc_type: 异常类型。 + exc_val: 异常值。 + exc_tb: 异常回溯。 + """ + del exc_type + del exc_val + del exc_tb + + def exec(self, statement: Any) -> _DummyResult: + """执行查询。 + + Args: + statement: 查询语句。 + + Returns: + _DummyResult: 模拟结果对象。 + """ + del statement + return _DummyResult(self.record) + + def add(self, record: Any) -> None: + """记录被添加的对象。 + + Args: + record: 被写入 Session 的对象。 + """ + self.added_records.append(record) + + +class _DummyPersonInfoRecord: + """模拟 `PersonInfo` ORM 模型。""" + + person_id = "person_id" + person_name = "person_name" + + def __init__(self, **kwargs: Any) -> None: + """使用关键字参数初始化记录对象。 + + Args: + **kwargs: 字段值。 + """ + for key, value in kwargs.items(): + setattr(self, key, value) + + +def _load_person_module(monkeypatch: pytest.MonkeyPatch, session: _DummySession) -> ModuleType: + """加载带依赖桩的 `person_info` 模块。 + + Args: + monkeypatch: Pytest monkeypatch 工具。 + session: 提供给模块使用的假数据库 Session。 + + Returns: + ModuleType: 加载后的模块对象。 + """ + logger_module = ModuleType("src.common.logger") + logger_module.get_logger = lambda name: _DummyLogger() + monkeypatch.setitem(sys.modules, "src.common.logger", logger_module) + + database_module = ModuleType("src.common.database.database") + database_module.get_db_session = lambda: session + monkeypatch.setitem(sys.modules, "src.common.database.database", database_module) + + database_model_module = ModuleType("src.common.database.database_model") + database_model_module.PersonInfo = _DummyPersonInfoRecord + monkeypatch.setitem(sys.modules, "src.common.database.database_model", database_model_module) + + llm_module = ModuleType("src.llm_models.utils_model") + + class _DummyLLMRequest: + """模拟 LLMRequest。""" + + def __init__(self, model_set: Any, request_type: str) -> None: + """初始化假请求对象。 + + Args: + model_set: 模型配置。 + request_type: 请求类型。 + """ + del model_set + del request_type + + llm_module.LLMRequest = _DummyLLMRequest + monkeypatch.setitem(sys.modules, "src.llm_models.utils_model", llm_module) + + config_module = ModuleType("src.config.config") + config_module.global_config = SimpleNamespace(bot=SimpleNamespace(nickname="MaiBot")) + config_module.model_config = SimpleNamespace(model_task_config=SimpleNamespace(tool_use="tool_use", utils="utils")) + monkeypatch.setitem(sys.modules, "src.config.config", config_module) + + chat_manager_module = ModuleType("src.chat.message_receive.chat_manager") + chat_manager_module.chat_manager = SimpleNamespace() + monkeypatch.setitem(sys.modules, "src.chat.message_receive.chat_manager", chat_manager_module) + + module_path = Path(__file__).resolve().parents[2] / "src" / "person_info" / "person_info.py" + spec = spec_from_file_location("person_info_group_cardname_test_module", module_path) + assert spec is not None and spec.loader is not None + + module = module_from_spec(spec) + monkeypatch.setitem(sys.modules, spec.name, module) + spec.loader.exec_module(module) + + monkeypatch.setattr(module, "select", lambda *args: _DummyStatement()) + monkeypatch.setattr(module, "col", lambda field: _DummyColumn()) + return module + + +def test_parse_group_cardname_json_uses_canonical_key() -> None: + """群名片 JSON 解析应只使用 `group_cardname` 键名。""" + parsed = parse_group_cardname_json( + json.dumps( + [ + {"group_id": "1001", "group_cardname": "现行字段"}, + ], + ensure_ascii=False, + ) + ) + + assert parsed is not None + assert [(item.group_id, item.group_cardname) for item in parsed] == [ + ("1001", "现行字段"), + ] + + +def test_dump_group_cardname_records_uses_canonical_key() -> None: + """群名片序列化应输出 `group_cardname` 键名。""" + dumped = dump_group_cardname_records( + [ + {"group_id": "1001", "group_cardname": "群昵称"}, + ] + ) + + assert json.loads(dumped) == [{"group_id": "1001", "group_cardname": "群昵称"}] + + +def test_person_sync_to_database_uses_group_cardname_field(monkeypatch: pytest.MonkeyPatch) -> None: + """同步人物信息时应写入数据库模型的 `group_cardname` 字段。""" + record = _DummyPersonInfoRecord() + session = _DummySession(record) + module = _load_person_module(monkeypatch, session) + + person = module.Person.__new__(module.Person) + person.is_known = True + person.person_id = "person-1" + person.platform = "qq" + person.user_id = "10001" + person.nickname = "看番的龙" + person.person_name = "看番的龙" + person.name_reason = "测试" + person.know_times = 1 + person.know_since = 1700000000.0 + person.last_know = 1700000100.0 + person.memory_points = ["喜好:番剧:0.8"] + person.group_cardname_list = [{"group_id": "20001", "group_cardname": "白泽大人"}] + + person.sync_to_database() + + assert record.group_cardname == '[{"group_id": "20001", "group_cardname": "白泽大人"}]' + assert not hasattr(record, "group_nickname") + + +def test_person_load_from_database_normalizes_group_cardname_payload(monkeypatch: pytest.MonkeyPatch) -> None: + """从数据库加载人物信息时应读取标准 `group_cardname` 结构。""" + record = _DummyPersonInfoRecord( + user_id="10001", + platform="qq", + is_known=True, + user_nickname="看番的龙", + person_name="看番的龙", + name_reason=None, + know_counts=2, + memory_points='["喜好:番剧:0.8"]', + group_cardname=json.dumps( + [ + {"group_id": "20001", "group_cardname": "白泽大人"}, + ], + ensure_ascii=False, + ), + ) + session = _DummySession(record) + module = _load_person_module(monkeypatch, session) + + person = module.Person.__new__(module.Person) + person.person_id = "person-1" + person.memory_points = [] + person.group_cardname_list = [] + + person.load_from_database() + + assert person.group_cardname_list == [ + {"group_id": "20001", "group_cardname": "白泽大人"}, + ] diff --git a/pytests/test_adapter_runtime_state.py b/pytests/test_adapter_runtime_state.py new file mode 100644 index 00000000..e82f4c8c --- /dev/null +++ b/pytests/test_adapter_runtime_state.py @@ -0,0 +1,162 @@ +"""适配器运行时状态同步测试。""" + +from typing import Any, Dict + +import pytest + +from src.platform_io.manager import PlatformIOManager +from src.platform_io.types import RouteKey +from src.plugin_runtime.host.supervisor import PluginSupervisor +from src.plugin_runtime.protocol.envelope import ( + AdapterDeclarationPayload, + Envelope, + MessageType, +) + + +def _make_request(plugin_id: str, payload: Dict[str, Any]) -> Envelope: + """构造一个适配器状态更新 RPC 请求。 + + Args: + plugin_id: 目标适配器插件 ID。 + payload: 请求载荷。 + + Returns: + Envelope: 标准 RPC 请求信封。 + """ + return Envelope( + request_id=1, + message_type=MessageType.REQUEST, + method="host.update_adapter_state", + plugin_id=plugin_id, + payload=payload, + ) + + +@pytest.mark.asyncio +async def test_adapter_runtime_state_binds_and_unbinds_routes(monkeypatch: pytest.MonkeyPatch) -> None: + """连接建立后应绑定路由,断开后应撤销路由。""" + import src.plugin_runtime.host.supervisor as supervisor_module + + platform_io_manager = PlatformIOManager() + monkeypatch.setattr(supervisor_module, "get_platform_io_manager", lambda: platform_io_manager) + + supervisor = PluginSupervisor(plugin_dirs=[]) + adapter = AdapterDeclarationPayload(platform="qq", protocol="napcat") + await supervisor._register_adapter_driver("napcat_adapter_builtin", adapter) + + response = await supervisor._handle_update_adapter_state( + _make_request( + "napcat_adapter_builtin", + { + "connected": True, + "account_id": "10001", + "scope": "", + "metadata": {}, + }, + ) + ) + + assert response.error is None + assert response.payload["accepted"] is True + assert ( + platform_io_manager.route_table.get_active_binding( + RouteKey(platform="qq", account_id="10001"), + exact_only=True, + ).driver_id + == "adapter:napcat_adapter_builtin" + ) + assert ( + platform_io_manager.route_table.get_active_binding( + RouteKey(platform="qq"), + exact_only=True, + ).driver_id + == "adapter:napcat_adapter_builtin" + ) + + response = await supervisor._handle_update_adapter_state( + _make_request( + "napcat_adapter_builtin", + { + "connected": False, + "account_id": "", + "scope": "", + "metadata": {}, + }, + ) + ) + + assert response.error is None + assert response.payload["accepted"] is True + assert platform_io_manager.route_table.get_active_binding( + RouteKey(platform="qq", account_id="10001"), + exact_only=True, + ) is None + assert platform_io_manager.route_table.get_active_binding(RouteKey(platform="qq"), exact_only=True) is None + + +@pytest.mark.asyncio +async def test_platform_default_route_is_removed_when_multiple_exact_routes_exist( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """同一平台存在多个精确路由时不应保留默认平台路由。""" + import src.plugin_runtime.host.supervisor as supervisor_module + + platform_io_manager = PlatformIOManager() + monkeypatch.setattr(supervisor_module, "get_platform_io_manager", lambda: platform_io_manager) + + supervisor = PluginSupervisor(plugin_dirs=[]) + adapter = AdapterDeclarationPayload(platform="qq", protocol="napcat") + await supervisor._register_adapter_driver("adapter_a", adapter) + await supervisor._register_adapter_driver("adapter_b", adapter) + + await supervisor._handle_update_adapter_state( + _make_request( + "adapter_a", + { + "connected": True, + "account_id": "10001", + "scope": "", + "metadata": {}, + }, + ) + ) + assert ( + platform_io_manager.route_table.get_active_binding( + RouteKey(platform="qq"), + exact_only=True, + ).driver_id + == "adapter:adapter_a" + ) + + await supervisor._handle_update_adapter_state( + _make_request( + "adapter_b", + { + "connected": True, + "account_id": "10002", + "scope": "", + "metadata": {}, + }, + ) + ) + assert platform_io_manager.route_table.get_active_binding(RouteKey(platform="qq"), exact_only=True) is None + + await supervisor._handle_update_adapter_state( + _make_request( + "adapter_b", + { + "connected": False, + "account_id": "", + "scope": "", + "metadata": {}, + }, + ) + ) + assert ( + platform_io_manager.route_table.get_active_binding( + RouteKey(platform="qq"), + exact_only=True, + ).driver_id + == "adapter:adapter_a" + ) diff --git a/pytests/test_plugin_runtime.py b/pytests/test_plugin_runtime.py index 2c703161..5ab16c85 100644 --- a/pytests/test_plugin_runtime.py +++ b/pytests/test_plugin_runtime.py @@ -486,10 +486,10 @@ class TestSDK: "timeout_ms": timeout_ms, } ) - if method == "cap.request": + if method == "cap.call": bootstrap_methods = [call["method"] for call in self.calls[:-1]] assert "plugin.bootstrap" in bootstrap_methods - return SimpleNamespace(error=None, payload={"result": {"success": True}}) + return SimpleNamespace(error=None, payload={"success": True}) return SimpleNamespace(error=None, payload={"accepted": True}) async def disconnect(self): @@ -529,7 +529,7 @@ class TestSDK: await runner.run() methods = [call["method"] for call in runner._rpc_client.calls] - assert methods == ["plugin.bootstrap", "cap.request", "plugin.register_components", "runner.ready"] + assert methods == ["plugin.bootstrap", "plugin.register_components", "cap.call", "runner.ready"] class TestPluginSdkUsage: diff --git a/pytests/utils_test/statistic_test.py b/pytests/utils_test/statistic_test.py new file mode 100644 index 00000000..d3d8c18a --- /dev/null +++ b/pytests/utils_test/statistic_test.py @@ -0,0 +1,115 @@ +"""统计模块数据库会话行为测试。""" + +from __future__ import annotations + +from contextlib import contextmanager +from datetime import datetime, timedelta +from types import ModuleType +from typing import Any, Callable, Iterator + +import sys + +import pytest + +from src.chat.utils import statistic + + +class _DummyResult: + """模拟 SQLModel 查询结果对象。""" + + def all(self) -> list[Any]: + """返回空结果集。 + + Returns: + list[Any]: 空列表。 + """ + return [] + + +class _DummySession: + """模拟数据库 Session。""" + + def exec(self, statement: Any) -> _DummyResult: + """执行查询语句并返回空结果。 + + Args: + statement: 待执行的查询语句。 + + Returns: + _DummyResult: 空结果对象。 + """ + del statement + return _DummyResult() + + +def _build_fake_get_db_session(calls: list[bool]) -> Callable[[bool], Iterator[_DummySession]]: + """构造一个记录 auto_commit 参数的假会话工厂。 + + Args: + calls: 用于记录每次调用 auto_commit 参数的列表。 + + Returns: + Callable[[bool], Iterator[_DummySession]]: 可替换 `get_db_session` 的上下文管理器工厂。 + """ + + @contextmanager + def _fake_get_db_session(auto_commit: bool = True) -> Iterator[_DummySession]: + """记录会话参数并返回假 Session。 + + Args: + auto_commit: 是否启用自动提交。 + + Yields: + Iterator[_DummySession]: 假 Session 对象。 + """ + calls.append(auto_commit) + yield _DummySession() + + return _fake_get_db_session + + +def _build_statistic_task() -> statistic.StatisticOutputTask: + """构造一个最小可用的统计任务实例。 + + Returns: + statistic.StatisticOutputTask: 跳过 `__init__` 的测试实例。 + """ + task = statistic.StatisticOutputTask.__new__(statistic.StatisticOutputTask) + task.name_mapping = {} + return task + + +def _is_bot_self(platform: str, user_id: str) -> bool: + """返回固定的非机器人身份判断结果。 + + Args: + platform: 平台名称。 + user_id: 用户 ID。 + + Returns: + bool: 始终返回 ``False``。 + """ + del platform + del user_id + return False + + +def test_statistic_read_queries_disable_auto_commit(monkeypatch: pytest.MonkeyPatch) -> None: + """统计模块的纯读查询应关闭自动提交,避免 Session 退出后对象被 expire。""" + calls: list[bool] = [] + now = datetime.now() + task = _build_statistic_task() + + monkeypatch.setattr(statistic, "get_db_session", _build_fake_get_db_session(calls)) + + utils_module = ModuleType("src.chat.utils.utils") + utils_module.is_bot_self = _is_bot_self + monkeypatch.setitem(sys.modules, "src.chat.utils.utils", utils_module) + + statistic.StatisticOutputTask._fetch_online_time_since(now) + statistic.StatisticOutputTask._fetch_model_usage_since(now) + task._collect_message_count_for_period([("last_hour", now - timedelta(hours=1))]) + task._collect_interval_data(now, hours=1, interval_minutes=60) + task._collect_metrics_interval_data(now, hours=1, interval_hours=1) + + assert calls == [False] * 9 diff --git a/src/chat/utils/statistic.py b/src/chat/utils/statistic.py index ede10a41..51e5e643 100644 --- a/src/chat/utils/statistic.py +++ b/src/chat/utils/statistic.py @@ -436,14 +436,14 @@ class StatisticOutputTask(AsyncTask): @staticmethod def _fetch_online_time_since(query_start_time: datetime) -> list[tuple[datetime, datetime]]: - with get_db_session() as session: + with get_db_session(auto_commit=False) as session: statement = select(OnlineTime).where(col(OnlineTime.end_timestamp) >= query_start_time) records = session.exec(statement).all() return [(record.start_timestamp, record.end_timestamp) for record in records] @staticmethod def _fetch_model_usage_since(query_start_time: datetime) -> list[dict[str, object]]: - with get_db_session() as session: + with get_db_session(auto_commit=False) as session: statement = select(ModelUsage).where(col(ModelUsage.timestamp) >= query_start_time) records = session.exec(statement).all() return [ @@ -664,7 +664,7 @@ class StatisticOutputTask(AsyncTask): } query_start_timestamp = collect_period[-1][1] - with get_db_session() as session: + with get_db_session(auto_commit=False) as session: statement = select(Messages).where(col(Messages.timestamp) >= query_start_timestamp) messages = session.exec(statement).all() for message in messages: @@ -713,7 +713,7 @@ class StatisticOutputTask(AsyncTask): # 使用 ActionRecords 中的 reply 动作次数作为回复数基准 try: action_query_start_timestamp = collect_period[-1][1] - with get_db_session() as session: + with get_db_session(auto_commit=False) as session: statement = select(ActionRecord).where(col(ActionRecord.timestamp) >= action_query_start_timestamp) actions = session.exec(statement).all() for action in actions: @@ -1750,7 +1750,7 @@ class StatisticOutputTask(AsyncTask): # 查询消息记录 query_start_timestamp = start_time.timestamp() - with get_db_session() as session: + with get_db_session(auto_commit=False) as session: statement = select(Messages).where(col(Messages.timestamp) >= start_time) messages = session.exec(statement).all() for message in messages: @@ -2131,7 +2131,7 @@ class StatisticOutputTask(AsyncTask): # 查询消息记录 query_start_timestamp = start_time.timestamp() - with get_db_session() as session: + with get_db_session(auto_commit=False) as session: statement = select(Messages).where(col(Messages.timestamp) >= start_time) messages = session.exec(statement).all() for message in messages: diff --git a/src/common/data_models/person_info_data_model.py b/src/common/data_models/person_info_data_model.py index 4cbb62d8..1b239356 100644 --- a/src/common/data_models/person_info_data_model.py +++ b/src/common/data_models/person_info_data_model.py @@ -1,6 +1,6 @@ -from dataclasses import dataclass +from dataclasses import asdict, dataclass from datetime import datetime -from typing import Optional, List +from typing import Any, List, Mapping, Optional, Sequence import json @@ -15,6 +15,76 @@ class GroupCardnameInfo: group_cardname: str +def _normalize_group_cardname_item(raw_item: Mapping[str, Any]) -> Optional[GroupCardnameInfo]: + """将单条群名片数据规范化为统一结构。 + + Args: + raw_item: 原始群名片字典,必须包含 `group_id` 和 `group_cardname`。 + + Returns: + Optional[GroupCardnameInfo]: 规范化后的群名片信息;若数据不完整则返回 ``None``。 + """ + group_id = str(raw_item.get("group_id") or "").strip() + group_cardname = str(raw_item.get("group_cardname") or "").strip() + if not group_id or not group_cardname: + return None + return GroupCardnameInfo(group_id=group_id, group_cardname=group_cardname) + + +def parse_group_cardname_json(group_cardname_json: Optional[str]) -> Optional[List[GroupCardnameInfo]]: + """解析数据库中的群名片 JSON 字段。 + + Args: + group_cardname_json: 数据库存储的群名片 JSON 字符串。 + + Returns: + Optional[List[GroupCardnameInfo]]: 解析并规范化后的群名片列表;若字段为空或无有效项则返回 ``None``。 + + Raises: + json.JSONDecodeError: 当 JSON 文本格式非法时抛出。 + TypeError: 当输入值类型不符合 `json.loads()` 要求时抛出。 + """ + if not group_cardname_json: + return None + + raw_items = json.loads(group_cardname_json) + if not isinstance(raw_items, list): + return None + + normalized_items: List[GroupCardnameInfo] = [] + for raw_item in raw_items: + if not isinstance(raw_item, Mapping): + continue + if normalized_item := _normalize_group_cardname_item(raw_item): + normalized_items.append(normalized_item) + + return normalized_items or None + + +def dump_group_cardname_records( + group_cardname_records: Optional[Sequence[GroupCardnameInfo | Mapping[str, Any]]], +) -> str: + """将群名片列表序列化为数据库使用的标准 JSON 字符串。 + + Args: + group_cardname_records: 待序列化的群名片列表,支持 `GroupCardnameInfo` + 对象和包含 `group_id` / `group_cardname` 的字典。 + + Returns: + str: 统一使用 `group_cardname` 键名的 JSON 字符串。 + """ + normalized_items: List[GroupCardnameInfo] = [] + for raw_item in group_cardname_records or []: + if isinstance(raw_item, GroupCardnameInfo): + normalized_items.append(raw_item) + continue + if isinstance(raw_item, Mapping): + if normalized_item := _normalize_group_cardname_item(raw_item): + normalized_items.append(normalized_item) + + return json.dumps([asdict(item) for item in normalized_items], ensure_ascii=False) + + class MaiPersonInfo(BaseDatabaseDataModel[PersonInfo]): def __init__( self, @@ -58,9 +128,16 @@ class MaiPersonInfo(BaseDatabaseDataModel[PersonInfo]): """最后一次被认识的时间""" @classmethod - def from_db_instance(cls, db_record: "PersonInfo"): - nickname_json = json.loads(db_record.group_cardname) if db_record.group_cardname else None - group_cardname_list = [GroupCardnameInfo(**item) for item in nickname_json] if nickname_json else None + def from_db_instance(cls, db_record: "PersonInfo") -> "MaiPersonInfo": + """从数据库记录构造人物信息数据模型。 + + Args: + db_record: 数据库中的人物信息记录。 + + Returns: + MaiPersonInfo: 转换后的数据模型对象。 + """ + group_cardname_list = parse_group_cardname_json(db_record.group_cardname) memory_points = json.loads(db_record.memory_points) if db_record.memory_points else None return cls( is_known=db_record.is_known, @@ -78,9 +155,12 @@ class MaiPersonInfo(BaseDatabaseDataModel[PersonInfo]): ) def to_db_instance(self) -> "PersonInfo": - group_cardname = ( - json.dumps([gc.__dict__ for gc in self.group_cardname_list]) if self.group_cardname_list else None - ) + """将当前数据模型转换为数据库记录对象。 + + Returns: + PersonInfo: 可直接写入数据库的模型实例。 + """ + group_cardname = dump_group_cardname_records(self.group_cardname_list) return PersonInfo( is_known=self.is_known, person_id=self.person_id, diff --git a/src/person_info/person_info.py b/src/person_info/person_info.py index 799f56a0..15ef0049 100644 --- a/src/person_info/person_info.py +++ b/src/person_info/person_info.py @@ -1,22 +1,24 @@ -import hashlib +from datetime import datetime +from typing import Dict, Optional, Union + import asyncio +import hashlib import json -import time -import random import math +import random +import time from json_repair import repair_json -from typing import Union, Optional, Dict -from datetime import datetime from sqlmodel import col, select -from src.common.logger import get_logger +from src.chat.message_receive.chat_manager import chat_manager as _chat_manager +from src.common.data_models.person_info_data_model import dump_group_cardname_records, parse_group_cardname_json from src.common.database.database import get_db_session from src.common.database.database_model import PersonInfo -from src.llm_models.utils_model import LLMRequest +from src.common.logger import get_logger from src.config.config import global_config, model_config -from src.chat.message_receive.chat_manager import chat_manager as _chat_manager +from src.llm_models.utils_model import LLMRequest logger = get_logger("person_info") @@ -26,6 +28,32 @@ relation_selection_model = LLMRequest( ) +def _to_group_cardname_records(group_cardname_json: Optional[str]) -> list[dict[str, str]]: + """将数据库中的群名片 JSON 转换为 `Person` 内部使用的结构。 + + Args: + group_cardname_json: 数据库存储的群名片 JSON 字符串。 + + Returns: + list[dict[str, str]]: 统一使用 `group_cardname` 键名的群名片列表。 + + Raises: + json.JSONDecodeError: 当 JSON 文本格式非法时抛出。 + TypeError: 当输入值类型不符合 `json.loads()` 要求时抛出。 + """ + group_cardname_list = parse_group_cardname_json(group_cardname_json) + if not group_cardname_list: + return [] + + return [ + { + "group_id": group_cardname.group_id, + "group_cardname": group_cardname.group_cardname, + } + for group_cardname in group_cardname_list + ] + + def get_person_id(platform: str, user_id: Union[int, str]) -> str: """获取唯一id""" if "-" in platform: @@ -231,7 +259,7 @@ class Person: person.know_since = time.time() person.last_know = time.time() person.memory_points = [] - person.group_nick_name = [] # 初始化群昵称列表 + person.group_cardname_list = [] # 初始化群名片列表 # 如果是群聊,添加群昵称 if group_id and group_nick_name: @@ -269,7 +297,7 @@ class Person: self.platform = platform self.nickname = global_config.bot.nickname self.person_name = global_config.bot.nickname - self.group_nick_name: list[dict[str, str]] = [] + self.group_cardname_list: list[dict[str, str]] = [] return self.user_id = "" @@ -308,7 +336,7 @@ class Person: self.know_since = None self.last_know: Optional[float] = None self.memory_points = [] - self.group_nick_name: list[dict[str, str]] = [] # 群昵称列表,存储 {"group_id": str, "group_nick_name": str} + self.group_cardname_list: list[dict[str, str]] = [] # 群名片列表,存储 {"group_id": str, "group_cardname": str} # 从数据库加载数据 self.load_from_database() @@ -408,16 +436,16 @@ class Person: return # 检查是否已存在该群号的记录 - for item in self.group_nick_name: + for item in self.group_cardname_list: if item.get("group_id") == group_id: # 更新现有记录 - item["group_nick_name"] = group_nick_name + item["group_cardname"] = group_nick_name self.sync_to_database() logger.debug(f"更新用户 {self.person_id} 在群 {group_id} 的群昵称为 {group_nick_name}") return # 添加新记录 - self.group_nick_name.append({"group_id": group_id, "group_nick_name": group_nick_name}) + self.group_cardname_list.append({"group_id": group_id, "group_cardname": group_nick_name}) self.sync_to_database() logger.debug(f"添加用户 {self.person_id} 在群 {group_id} 的群昵称 {group_nick_name}") @@ -452,20 +480,15 @@ class Person: else: self.memory_points = [] - # 处理group_nick_name字段(JSON格式的列表) + # 处理 group_cardname 字段(JSON 格式的列表) if record.group_cardname: try: - loaded_group_nick_names = json.loads(record.group_cardname) - # 确保是列表格式 - if isinstance(loaded_group_nick_names, list): - self.group_nick_name = loaded_group_nick_names - else: - self.group_nick_name = [] + self.group_cardname_list = _to_group_cardname_records(record.group_cardname) except (json.JSONDecodeError, TypeError): logger.warning(f"解析用户 {self.person_id} 的group_cardname字段失败,使用默认值") - self.group_nick_name = [] + self.group_cardname_list = [] else: - self.group_nick_name = [] + self.group_cardname_list = [] logger.debug(f"已从数据库加载用户 {self.person_id} 的信息") else: @@ -486,11 +509,7 @@ class Person: if self.memory_points else json.dumps([], ensure_ascii=False) ) - group_nickname_value = ( - json.dumps(self.group_nick_name, ensure_ascii=False) - if self.group_nick_name - else json.dumps([], ensure_ascii=False) - ) + group_cardname_value = dump_group_cardname_records(self.group_cardname_list) first_known_time = datetime.fromtimestamp(self.know_since) if self.know_since else None last_known_time = datetime.fromtimestamp(self.last_know) if self.last_know else None @@ -510,7 +529,7 @@ class Person: record.first_known_time = first_known_time record.last_known_time = last_known_time record.memory_points = memory_points_value - record.group_nickname = group_nickname_value + record.group_cardname = group_cardname_value session.add(record) logger.debug(f"已同步用户 {self.person_id} 的信息到数据库") else: @@ -526,7 +545,7 @@ class Person: first_known_time=first_known_time, last_known_time=last_known_time, memory_points=memory_points_value, - group_nickname=group_nickname_value, + group_cardname=group_cardname_value, ) session.add(record) logger.debug(f"已创建用户 {self.person_id} 的信息到数据库") diff --git a/src/plugin_runtime/host/supervisor.py b/src/plugin_runtime/host/supervisor.py index 33091d5a..8a26af11 100644 --- a/src/plugin_runtime/host/supervisor.py +++ b/src/plugin_runtime/host/supervisor.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass, field from pathlib import Path from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple @@ -8,13 +9,15 @@ import sys from src.common.logger import get_logger from src.config.config import global_config -from src.platform_io import DriverKind, InboundMessageEnvelope, RouteBinding, RouteKey, get_platform_io_manager +from src.platform_io import DriverKind, InboundMessageEnvelope, RouteBinding, RouteKey, RouteMode, get_platform_io_manager from src.platform_io.drivers import PluginPlatformDriver from src.platform_io.route_key_factory import RouteKeyFactory from src.platform_io.routing import RouteBindingConflictError from src.plugin_runtime import ENV_HOST_VERSION, ENV_IPC_ADDRESS, ENV_PLUGIN_DIRS, ENV_SESSION_TOKEN from src.plugin_runtime.protocol.envelope import ( AdapterDeclarationPayload, + AdapterStateUpdatePayload, + AdapterStateUpdateResultPayload, BootstrapPluginPayload, ConfigUpdatedPayload, Envelope, @@ -46,6 +49,19 @@ if TYPE_CHECKING: logger = get_logger("plugin_runtime.host.runner_manager") +_ADAPTER_BINDING_ROLE_RUNTIME_EXACT = "runtime_exact" +_ADAPTER_BINDING_ROLE_PLATFORM_DEFAULT = "platform_default" + + +@dataclass(slots=True) +class _AdapterRuntimeState: + """保存适配器插件当前的运行时连接状态。""" + + connected: bool = False + account_id: Optional[str] = None + scope: Optional[str] = None + metadata: Dict[str, Any] = field(default_factory=dict) + class PluginRunnerSupervisor: """插件 Runner 监督器。 @@ -94,6 +110,7 @@ class PluginRunnerSupervisor: self._runner_process: Optional[asyncio.subprocess.Process] = None self._registered_plugins: Dict[str, RegisterPluginPayload] = {} self._registered_adapters: Dict[str, AdapterDeclarationPayload] = {} + self._adapter_runtime_states: Dict[str, _AdapterRuntimeState] = {} self._runner_ready_events: asyncio.Event = asyncio.Event() self._runner_ready_payloads: RunnerReadyPayload = RunnerReadyPayload() self._health_task: Optional[asyncio.Task[None]] = None @@ -452,6 +469,7 @@ class PluginRunnerSupervisor: """注册 Host 侧内部 RPC 方法。""" self._rpc_server.register_method("cap.call", self._capability_service.handle_capability_request) self._rpc_server.register_method("host.receive_external_message", self._handle_receive_external_message) + self._rpc_server.register_method("host.update_adapter_state", self._handle_update_adapter_state) self._rpc_server.register_method("plugin.bootstrap", self._handle_bootstrap_plugin) self._rpc_server.register_method("plugin.register_components", self._handle_register_plugin) self._rpc_server.register_method("plugin.register_plugin", self._handle_register_plugin) @@ -563,14 +581,14 @@ class PluginRunnerSupervisor: return f"adapter:{plugin_id}" async def _register_adapter_driver(self, plugin_id: str, adapter: AdapterDeclarationPayload) -> None: - """将适配器插件注册到 Platform IO。 + """将适配器插件驱动注册到 Platform IO。 Args: plugin_id: 适配器插件 ID。 adapter: 经过校验的适配器声明。 Raises: - ValueError: 适配器路由冲突或驱动注册失败时抛出。 + ValueError: 当驱动注册失败时抛出。 """ await self._unregister_adapter_driver(plugin_id) @@ -588,22 +606,12 @@ class PluginRunnerSupervisor: **adapter.metadata, }, ) - binding = RouteBinding( - route_key=driver.descriptor.route_key, - driver_id=driver.driver_id, - driver_kind=DriverKind.PLUGIN, - metadata={ - "plugin_id": plugin_id, - "protocol": adapter.protocol, - }, - ) try: if platform_io_manager.is_started: await platform_io_manager.add_driver(driver) else: platform_io_manager.register_driver(driver) - platform_io_manager.bind_route(binding) except Exception: with contextlib.suppress(Exception): if platform_io_manager.is_started: @@ -613,6 +621,7 @@ class PluginRunnerSupervisor: raise self._registered_adapters[plugin_id] = adapter + self._adapter_runtime_states[plugin_id] = _AdapterRuntimeState() async def _unregister_adapter_driver(self, plugin_id: str) -> None: """从 Platform IO 注销一个适配器驱动。 @@ -622,6 +631,9 @@ class PluginRunnerSupervisor: """ platform_io_manager = get_platform_io_manager() driver_id = self._build_adapter_driver_id(plugin_id) + adapter = self._registered_adapters.get(plugin_id) + + self._remove_adapter_route_bindings(plugin_id) with contextlib.suppress(Exception): if platform_io_manager.is_started: @@ -629,7 +641,11 @@ class PluginRunnerSupervisor: else: platform_io_manager.unregister_driver(driver_id) + if adapter is not None: + self._refresh_platform_default_route(adapter.platform) + self._registered_adapters.pop(plugin_id, None) + self._adapter_runtime_states.pop(plugin_id, None) async def _unregister_all_adapter_drivers(self) -> None: """注销当前 Supervisor 管理的全部适配器驱动。""" @@ -637,6 +653,198 @@ class PluginRunnerSupervisor: for plugin_id in plugin_ids: await self._unregister_adapter_driver(plugin_id) + def _remove_adapter_route_bindings(self, plugin_id: str) -> None: + """移除某个适配器驱动当前持有的全部路由绑定。 + + Args: + plugin_id: 适配器插件 ID。 + """ + platform_io_manager = get_platform_io_manager() + platform_io_manager.route_table.remove_bindings_by_driver(self._build_adapter_driver_id(plugin_id)) + + @staticmethod + def _normalize_runtime_route_value(value: str) -> Optional[str]: + """规范化适配器运行时路由字段。 + + Args: + value: 待规范化的原始字符串。 + + Returns: + Optional[str]: 规范化后非空则返回字符串,否则返回 ``None``。 + """ + normalized_value = str(value).strip() + return normalized_value or None + + def _build_runtime_route_key( + self, + adapter: AdapterDeclarationPayload, + payload: AdapterStateUpdatePayload, + ) -> RouteKey: + """根据运行时状态更新构造适配器生效路由键。 + + Args: + adapter: 当前适配器声明。 + payload: 适配器上报的运行时状态。 + + Returns: + RouteKey: 当前连接应接管的精确路由键。 + + Raises: + ValueError: 当静态声明与运行时上报的身份信息冲突时抛出。 + """ + runtime_account_id = self._normalize_runtime_route_value(payload.account_id) + runtime_scope = self._normalize_runtime_route_value(payload.scope) + + if adapter.account_id and runtime_account_id and adapter.account_id != runtime_account_id: + raise ValueError( + f"适配器声明的 account_id={adapter.account_id} 与运行时上报的 {runtime_account_id} 不一致" + ) + if adapter.scope and runtime_scope and adapter.scope != runtime_scope: + raise ValueError(f"适配器声明的 scope={adapter.scope} 与运行时上报的 {runtime_scope} 不一致") + + return RouteKey( + platform=adapter.platform, + account_id=runtime_account_id or adapter.account_id or None, + scope=runtime_scope or adapter.scope or None, + ) + + def _bind_runtime_exact_route( + self, + plugin_id: str, + adapter: AdapterDeclarationPayload, + route_key: RouteKey, + ) -> None: + """为适配器连接绑定精确生效路由。 + + Args: + plugin_id: 适配器插件 ID。 + adapter: 当前适配器声明。 + route_key: 当前连接对应的精确路由键。 + + Raises: + RouteBindingConflictError: 当目标路由已被其他 active owner 占用时抛出。 + """ + platform_io_manager = get_platform_io_manager() + platform_io_manager.bind_route( + RouteBinding( + route_key=route_key, + driver_id=self._build_adapter_driver_id(plugin_id), + driver_kind=DriverKind.PLUGIN, + metadata={ + "plugin_id": plugin_id, + "protocol": adapter.protocol, + "binding_role": _ADAPTER_BINDING_ROLE_RUNTIME_EXACT, + }, + ) + ) + + def _list_runtime_exact_bindings(self, platform: str) -> List[RouteBinding]: + """列出某个平台上由 Host 动态维护的精确适配器绑定。 + + Args: + platform: 目标平台名称。 + + Returns: + List[RouteBinding]: 当前平台上全部动态精确绑定。 + """ + platform_io_manager = get_platform_io_manager() + return [ + binding + for binding in platform_io_manager.route_table.list_bindings() + if binding.mode == RouteMode.ACTIVE + and binding.route_key.platform == platform + and binding.metadata.get("binding_role") == _ADAPTER_BINDING_ROLE_RUNTIME_EXACT + ] + + def _refresh_platform_default_route(self, platform: str) -> None: + """根据当前精确绑定数量刷新平台级默认路由。 + + 当某个平台恰好只存在一个动态精确绑定时,会为该绑定额外创建一条 + ``RouteKey(platform=)`` 形式的默认路由,方便缺少账号维度的 + 出站消息继续找到唯一 owner。若精确绑定数量变为 0 或大于 1,则撤销 + 由 Host 自动维护的默认路由,避免出现隐式歧义。 + + Args: + platform: 目标平台名称。 + """ + platform_io_manager = get_platform_io_manager() + default_route_key = RouteKey(platform=platform) + existing_default_binding = platform_io_manager.route_table.get_active_binding(default_route_key, exact_only=True) + + if existing_default_binding is not None: + binding_role = existing_default_binding.metadata.get("binding_role") + if binding_role != _ADAPTER_BINDING_ROLE_PLATFORM_DEFAULT: + return + platform_io_manager.unbind_route(default_route_key, existing_default_binding.driver_id) + + exact_bindings = self._list_runtime_exact_bindings(platform) + if len(exact_bindings) != 1: + return + + exact_binding = exact_bindings[0] + if exact_binding.route_key == default_route_key: + return + + platform_io_manager.bind_route( + RouteBinding( + route_key=default_route_key, + driver_id=exact_binding.driver_id, + driver_kind=exact_binding.driver_kind, + metadata={ + "plugin_id": exact_binding.metadata.get("plugin_id", ""), + "protocol": exact_binding.metadata.get("protocol", ""), + "binding_role": _ADAPTER_BINDING_ROLE_PLATFORM_DEFAULT, + }, + ), + replace=True, + ) + + def _apply_adapter_runtime_state( + self, + plugin_id: str, + adapter: AdapterDeclarationPayload, + payload: AdapterStateUpdatePayload, + ) -> Tuple[_AdapterRuntimeState, Dict[str, Any]]: + """应用适配器运行时状态,并同步 Platform IO 路由。 + + Args: + plugin_id: 适配器插件 ID。 + adapter: 当前适配器声明。 + payload: 适配器上报的运行时状态。 + + Returns: + Tuple[_AdapterRuntimeState, Dict[str, Any]]: 更新后的运行时状态,以及 + 供 RPC 响应返回的路由键字典。 + + Raises: + RouteBindingConflictError: 当新的精确路由与其他 active owner 冲突时抛出。 + ValueError: 当运行时路由信息不合法时抛出。 + """ + if not payload.connected: + self._remove_adapter_route_bindings(plugin_id) + self._refresh_platform_default_route(adapter.platform) + runtime_state = _AdapterRuntimeState(connected=False, metadata=dict(payload.metadata)) + self._adapter_runtime_states[plugin_id] = runtime_state + return runtime_state, {} + + route_key = self._build_runtime_route_key(adapter, payload) + self._remove_adapter_route_bindings(plugin_id) + self._bind_runtime_exact_route(plugin_id, adapter, route_key) + self._refresh_platform_default_route(adapter.platform) + + runtime_state = _AdapterRuntimeState( + connected=True, + account_id=route_key.account_id, + scope=route_key.scope, + metadata=dict(payload.metadata), + ) + self._adapter_runtime_states[plugin_id] = runtime_state + return runtime_state, { + "platform": route_key.platform, + "account_id": route_key.account_id, + "scope": route_key.scope, + } + @staticmethod def _attach_inbound_route_metadata( session_message: "SessionMessage", @@ -706,6 +914,45 @@ class PluginRunnerSupervisor: scope=scope, ) + async def _handle_update_adapter_state(self, envelope: Envelope) -> Envelope: + """处理适配器插件上报的运行时状态更新。 + + Args: + envelope: RPC 请求信封。 + + Returns: + Envelope: 状态更新处理结果。 + """ + try: + payload = AdapterStateUpdatePayload.model_validate(envelope.payload) + except Exception as exc: + return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc)) + + adapter = self._registered_adapters.get(envelope.plugin_id) + if adapter is None: + return envelope.make_error_response( + ErrorCode.E_METHOD_NOT_ALLOWED.value, + f"插件 {envelope.plugin_id} 未声明为适配器,不能更新运行时状态", + ) + + try: + runtime_state, route_key_dict = self._apply_adapter_runtime_state( + plugin_id=envelope.plugin_id, + adapter=adapter, + payload=payload, + ) + except RouteBindingConflictError as exc: + return envelope.make_error_response(ErrorCode.E_METHOD_NOT_ALLOWED.value, str(exc)) + except Exception as exc: + return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc)) + + response = AdapterStateUpdateResultPayload( + accepted=True, + connected=runtime_state.connected, + route_key=route_key_dict, + ) + return envelope.make_response(payload=response.model_dump()) + async def _handle_receive_external_message(self, envelope: Envelope) -> Envelope: """处理适配器插件上报的外部入站消息。 @@ -970,6 +1217,7 @@ class PluginRunnerSupervisor: self._component_registry.clear() self._registered_plugins.clear() self._registered_adapters.clear() + self._adapter_runtime_states.clear() self._runner_ready_events = asyncio.Event() self._runner_ready_payloads = RunnerReadyPayload() self._rpc_server.clear_handshake_state() diff --git a/src/plugin_runtime/protocol/envelope.py b/src/plugin_runtime/protocol/envelope.py index d71e02c5..f68657fa 100644 --- a/src/plugin_runtime/protocol/envelope.py +++ b/src/plugin_runtime/protocol/envelope.py @@ -304,6 +304,30 @@ class AdapterDeclarationPayload(BaseModel): """适配器附加元数据""" +class AdapterStateUpdatePayload(BaseModel): + """适配器运行时状态更新载荷。""" + + connected: bool = Field(description="适配器当前是否已连接并准备接管路由") + """适配器当前是否已连接并准备接管路由""" + account_id: str = Field(default="", description="当前连接对应的账号 ID 或 self_id") + """当前连接对应的账号 ID 或 self_id""" + scope: str = Field(default="", description="当前连接对应的可选路由作用域") + """当前连接对应的可选路由作用域""" + metadata: Dict[str, Any] = Field(default_factory=dict, description="可选的运行时状态元数据") + """可选的运行时状态元数据""" + + +class AdapterStateUpdateResultPayload(BaseModel): + """适配器运行时状态更新结果载荷。""" + + accepted: bool = Field(description="Host 是否接受了本次状态更新") + """Host 是否接受了本次状态更新""" + connected: bool = Field(description="Host 记录的当前连接状态") + """Host 记录的当前连接状态""" + route_key: Dict[str, Any] = Field(default_factory=dict, description="当前生效的路由键") + """当前生效的路由键""" + + class ReceiveExternalMessagePayload(BaseModel): """适配器插件向 Host 注入外部消息的请求载荷。""" diff --git a/src/plugin_runtime/runner/runner_main.py b/src/plugin_runtime/runner/runner_main.py index 88f92494..8078c88b 100644 --- a/src/plugin_runtime/runner/runner_main.py +++ b/src/plugin_runtime/runner/runner_main.py @@ -481,13 +481,14 @@ class PluginRunner: self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir) return False - if not await self._invoke_plugin_on_load(meta): + if not await self._register_plugin(meta): + await self._invoke_plugin_on_unload(meta) await self._deactivate_plugin(meta) self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir) return False - if not await self._register_plugin(meta): - await self._invoke_plugin_on_unload(meta) + if not await self._invoke_plugin_on_load(meta): + await self._unregister_plugin(meta.plugin_id, reason="on_load_failed") await self._deactivate_plugin(meta) self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir) return False diff --git a/src/plugins/built_in/napcat_adapter/plugin.py b/src/plugins/built_in/napcat_adapter/plugin.py index a481101f..c8bb837b 100644 --- a/src/plugins/built_in/napcat_adapter/plugin.py +++ b/src/plugins/built_in/napcat_adapter/plugin.py @@ -60,6 +60,9 @@ class NapCatAdapterPlugin(MaiBotPlugin): self._connection_task: Optional[asyncio.Task[None]] = None self._pending_actions: Dict[str, asyncio.Future[Dict[str, Any]]] = {} self._background_tasks: Set[asyncio.Task[Any]] = set() + self._reported_account_id: Optional[str] = None + self._reported_scope: Optional[str] = None + self._runtime_state_connected: bool = False self._send_lock = asyncio.Lock() self._ws: Optional[AiohttpClientWebSocketResponse] = None @@ -170,6 +173,7 @@ class NapCatAdapterPlugin(MaiBotPlugin): with contextlib.suppress(asyncio.CancelledError): await connection_task + await self._report_adapter_disconnected() self._fail_pending_actions("NapCat connection closed") async def _cancel_background_tasks(self) -> None: @@ -209,6 +213,7 @@ class NapCatAdapterPlugin(MaiBotPlugin): self.ctx.logger.warning(f"NapCat 适配器连接失败: {exc}") finally: self._ws = None + await self._report_adapter_disconnected() self._fail_pending_actions("NapCat connection interrupted") if not self._should_connect(): @@ -230,26 +235,39 @@ class NapCatAdapterPlugin(MaiBotPlugin): """ assert WSMsgType is not None - async for ws_message in ws: - if ws_message.type != WSMsgType.TEXT: - if ws_message.type in {WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.ERROR}: - break - continue + bootstrap_task = asyncio.create_task( + self._bootstrap_adapter_runtime_state(), + name="napcat_adapter.bootstrap", + ) + self._background_tasks.add(bootstrap_task) + bootstrap_task.add_done_callback(self._background_tasks.discard) - payload = self._parse_json_message(ws_message.data) - if payload is None: - continue + try: + async for ws_message in ws: + if ws_message.type != WSMsgType.TEXT: + if ws_message.type in {WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.ERROR}: + break + continue - if echo_id := str(payload.get("echo") or "").strip(): - self._resolve_pending_action(echo_id, payload) - continue + payload = self._parse_json_message(ws_message.data) + if payload is None: + continue - if str(payload.get("post_type") or "").strip() != "message": - continue + if echo_id := str(payload.get("echo") or "").strip(): + self._resolve_pending_action(echo_id, payload) + continue - task = asyncio.create_task(self._handle_inbound_message(payload), name="napcat_adapter.inbound") - self._background_tasks.add(task) - task.add_done_callback(self._background_tasks.discard) + if str(payload.get("post_type") or "").strip() != "message": + continue + + task = asyncio.create_task(self._handle_inbound_message(payload), name="napcat_adapter.inbound") + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.discard) + finally: + if not bootstrap_task.done(): + bootstrap_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await bootstrap_task async def _handle_inbound_message(self, payload: Dict[str, Any]) -> None: """处理单条 NapCat 入站消息并注入 Host。 @@ -258,6 +276,9 @@ class NapCatAdapterPlugin(MaiBotPlugin): payload: NapCat / OneBot 推送的原始事件数据。 """ self_id = str(payload.get("self_id") or "").strip() + if self_id: + await self._report_adapter_connected(self_id) + sender = payload.get("sender", {}) if not isinstance(sender, dict): sender = {} @@ -570,6 +591,121 @@ class NapCatAdapterPlugin(MaiBotPlugin): response_future.set_exception(RuntimeError(error_message)) self._pending_actions.clear() + async def _bootstrap_adapter_runtime_state(self) -> None: + """在连接建立后主动获取账号信息并激活适配器路由。 + + 该步骤会在 WebSocket 接收循环启动后异步执行,确保 `_call_action()` + 发出的 `get_login_info` 请求能够被同一连接上的接收循环消费到 echo + 响应,从而在真正收到业务消息前就完成 Host 侧 route 激活。 + """ + max_attempts = 3 + last_error: Optional[Exception] = None + for attempt in range(1, max_attempts + 1): + ws = self._ws + if ws is None or ws.closed: + return + + try: + response = await self._call_action("get_login_info", {}) + self_id = self._extract_self_id_from_login_response(response) + await self._report_adapter_connected(self_id) + return + except asyncio.CancelledError: + raise + except Exception as exc: + last_error = exc + self.ctx.logger.warning( + f"NapCat 适配器获取登录信息失败,第 {attempt}/{max_attempts} 次重试: {exc}" + ) + if attempt < max_attempts: + await asyncio.sleep(1.0) + + if last_error is not None: + self.ctx.logger.error(f"NapCat 适配器未能完成路由激活,连接将保持只接收状态: {last_error}") + + @staticmethod + def _extract_self_id_from_login_response(response: Dict[str, Any]) -> str: + """从 `get_login_info` 响应中提取当前账号 ID。 + + Args: + response: NapCat 返回的原始动作响应。 + + Returns: + str: 规范化后的 `self_id` 字符串。 + + Raises: + ValueError: 当响应中缺少有效账号 ID 时抛出。 + """ + if str(response.get("status") or "").lower() != "ok": + raise ValueError(str(response.get("wording") or response.get("message") or "get_login_info failed")) + + response_data = response.get("data", {}) + if not isinstance(response_data, dict): + raise ValueError("get_login_info 响应缺少 data 字段") + + self_id = str(response_data.get("user_id") or "").strip() + if not self_id: + raise ValueError("get_login_info 响应缺少有效的 user_id") + return self_id + + async def _report_adapter_connected(self, account_id: str) -> None: + """向 Host 上报当前连接已就绪。 + + Args: + account_id: 当前 NapCat 连接对应的机器人账号 ID。 + """ + normalized_account_id = str(account_id).strip() + if not normalized_account_id: + return + + scope = self._get_string(self._connection_config(), "connection_id").strip() + if ( + self._runtime_state_connected + and self._reported_account_id == normalized_account_id + and self._reported_scope == (scope or None) + ): + return + + accepted = False + try: + accepted = await self.ctx.adapter.update_runtime_state( + connected=True, + account_id=normalized_account_id, + scope=scope, + metadata={"ws_url": self._get_string(self._connection_config(), "ws_url")}, + ) + except Exception as exc: + self.ctx.logger.warning(f"NapCat 适配器上报连接就绪状态失败: {exc}") + return + + if not accepted: + self.ctx.logger.warning("NapCat 适配器连接已建立,但 Host 未接受运行时状态更新") + return + + self._runtime_state_connected = True + self._reported_account_id = normalized_account_id + self._reported_scope = scope or None + self.ctx.logger.info( + f"NapCat 适配器已激活路由: platform=qq account_id={normalized_account_id} " + f"scope={self._reported_scope or '*'}" + ) + + async def _report_adapter_disconnected(self) -> None: + """向 Host 上报当前连接已断开,并撤销适配器路由。""" + if not self._runtime_state_connected: + self._reported_account_id = None + self._reported_scope = None + return + + try: + await self.ctx.adapter.update_runtime_state(connected=False) + except Exception as exc: + self.ctx.logger.warning(f"NapCat 适配器上报断开状态失败: {exc}") + finally: + self._runtime_state_connected = False + self._reported_account_id = None + self._reported_scope = None + def _build_headers(self) -> Dict[str, str]: """构造连接 NapCat 所需的请求头。 From baabe4463ebea09d864b1dff8b78e04e78823751 Mon Sep 17 00:00:00 2001 From: DrSmoothl <1787882683@qq.com> Date: Sun, 22 Mar 2026 00:19:26 +0800 Subject: [PATCH 24/45] feat: add NapCat built-in adapter with configuration, filters, and transport layer - Implemented configuration parsing for NapCat adapter including server, chat, and filter settings. - Added message filtering logic to handle inbound chat messages based on user and group lists. - Developed a transport layer for WebSocket communication with the NapCat server. - Created a query service for fetching user and group information from the QQ platform. - Implemented runtime state management to report connection status to the host. - Added notice handling for various QQ platform events. --- pytests/test_napcat_adapter_codec.py | 70 ++ pytests/test_napcat_adapter_config.py | 91 ++ pytests/test_platform_io_dedupe.py | 164 +++ pytests/test_plugin_message_utils_runtime.py | 87 ++ src/platform_io/manager.py | 35 +- src/platform_io/types.py | 5 +- src/plugin_runtime/host/message_utils.py | 268 ++++- .../built_in/napcat_adapter/__init__.py | 1 + .../built_in/napcat_adapter/codec_inbound.py | 414 +++++++ .../built_in/napcat_adapter/codec_outbound.py | 192 +++ src/plugins/built_in/napcat_adapter/config.py | 398 +++++++ .../built_in/napcat_adapter/constants.py | 9 + .../built_in/napcat_adapter/filters.py | 68 ++ src/plugins/built_in/napcat_adapter/plugin.py | 1049 +++-------------- .../built_in/napcat_adapter/qq_notice.py | 224 ++++ .../built_in/napcat_adapter/qq_queries.py | 170 +++ .../built_in/napcat_adapter/runtime_state.py | 85 ++ .../built_in/napcat_adapter/transport.py | 322 +++++ 18 files changed, 2755 insertions(+), 897 deletions(-) create mode 100644 pytests/test_napcat_adapter_codec.py create mode 100644 pytests/test_napcat_adapter_config.py create mode 100644 pytests/test_platform_io_dedupe.py create mode 100644 pytests/test_plugin_message_utils_runtime.py create mode 100644 src/plugins/built_in/napcat_adapter/__init__.py create mode 100644 src/plugins/built_in/napcat_adapter/codec_inbound.py create mode 100644 src/plugins/built_in/napcat_adapter/codec_outbound.py create mode 100644 src/plugins/built_in/napcat_adapter/config.py create mode 100644 src/plugins/built_in/napcat_adapter/constants.py create mode 100644 src/plugins/built_in/napcat_adapter/filters.py create mode 100644 src/plugins/built_in/napcat_adapter/qq_notice.py create mode 100644 src/plugins/built_in/napcat_adapter/qq_queries.py create mode 100644 src/plugins/built_in/napcat_adapter/runtime_state.py create mode 100644 src/plugins/built_in/napcat_adapter/transport.py diff --git a/pytests/test_napcat_adapter_codec.py b/pytests/test_napcat_adapter_codec.py new file mode 100644 index 00000000..6f557e08 --- /dev/null +++ b/pytests/test_napcat_adapter_codec.py @@ -0,0 +1,70 @@ +from pathlib import Path +from typing import Any, Dict + +import importlib +import sys + + +BUILT_IN_PLUGIN_ROOT = Path(__file__).resolve().parents[1] / "src" / "plugins" / "built_in" +if str(BUILT_IN_PLUGIN_ROOT) not in sys.path: + sys.path.insert(0, str(BUILT_IN_PLUGIN_ROOT)) + +NapCatOutboundCodec = importlib.import_module("napcat_adapter.codec_outbound").NapCatOutboundCodec + + +def test_napcat_outbound_codec_supports_binary_and_forward_segments() -> None: + codec = NapCatOutboundCodec() + raw_message = [ + {"type": "text", "data": "hello"}, + {"type": "image", "data": "", "hash": "h1", "binary_data_base64": "aW1hZ2U="}, + {"type": "emoji", "data": "", "hash": "h2", "binary_data_base64": "ZW1vamk="}, + {"type": "voice", "data": "", "hash": "h3", "binary_data_base64": "dm9pY2U="}, + { + "type": "reply", + "data": { + "target_message_id": "origin-1", + "target_message_content": "origin text", + }, + }, + { + "type": "forward", + "data": [ + { + "user_id": "42", + "user_nickname": "alice", + "user_cardname": "Alice", + "message_id": "fwd-1", + "content": [{"type": "text", "data": "node-text"}], + } + ], + }, + ] + + converted = codec.convert_segments(raw_message) + + assert converted[0] == {"type": "text", "data": {"text": "hello"}} + assert converted[1]["type"] == "image" + assert converted[1]["data"]["file"] == "base64://aW1hZ2U=" + assert converted[2]["type"] == "image" + assert converted[2]["data"]["subtype"] == 1 + assert converted[3] == {"type": "record", "data": {"file": "base64://dm9pY2U="}} + assert converted[4] == {"type": "reply", "data": {"id": "origin-1"}} + assert converted[5]["type"] == "node" + assert converted[5]["data"]["name"] == "alice" + assert converted[5]["data"]["content"] == [{"type": "text", "data": {"text": "node-text"}}] + + +def test_napcat_outbound_codec_builds_private_action_from_route_metadata() -> None: + codec = NapCatOutboundCodec() + message: Dict[str, Any] = { + "message_info": { + "user_info": {"user_id": "10001", "user_nickname": "tester"}, + "additional_config": {}, + }, + "raw_message": [{"type": "text", "data": "hello"}], + } + + action_name, params = codec.build_outbound_action(message, {"target_user_id": "30001"}) + + assert action_name == "send_private_msg" + assert params == {"message": [{"type": "text", "data": {"text": "hello"}}], "user_id": "30001"} diff --git a/pytests/test_napcat_adapter_config.py b/pytests/test_napcat_adapter_config.py new file mode 100644 index 00000000..688b1a48 --- /dev/null +++ b/pytests/test_napcat_adapter_config.py @@ -0,0 +1,91 @@ +from pathlib import Path +from typing import List + +import importlib +import sys + + +BUILT_IN_PLUGIN_ROOT = Path(__file__).resolve().parents[1] / "src" / "plugins" / "built_in" +if str(BUILT_IN_PLUGIN_ROOT) not in sys.path: + sys.path.insert(0, str(BUILT_IN_PLUGIN_ROOT)) + +NapCatPluginSettings = importlib.import_module("napcat_adapter.config").NapCatPluginSettings + + +class DummyLogger: + """用于测试的轻量日志对象。""" + + def __init__(self) -> None: + """初始化测试日志对象。""" + self.warnings: List[str] = [] + self.errors: List[str] = [] + + def warning(self, message: str) -> None: + """记录警告日志。 + + Args: + message: 待记录的日志内容。 + """ + self.warnings.append(message) + + def error(self, message: str) -> None: + """记录错误日志。 + + Args: + message: 待记录的日志内容。 + """ + self.errors.append(message) + + +def test_parse_new_napcat_server_config() -> None: + logger = DummyLogger() + settings = NapCatPluginSettings.from_mapping( + { + "plugin": {"enabled": True, "config_version": "0.1.0"}, + "napcat_server": { + "host": "localhost", + "port": 8095, + "token": "secret", + "heartbeat_interval": 45, + "reconnect_delay_sec": 7, + "action_timeout_sec": 18, + "connection_id": "main", + }, + }, + logger, + ) + + assert settings.should_connect() is True + assert settings.napcat_server.host == "localhost" + assert settings.napcat_server.port == 8095 + assert settings.napcat_server.token == "secret" + assert settings.napcat_server.heartbeat_interval == 45.0 + assert settings.napcat_server.reconnect_delay_sec == 7.0 + assert settings.napcat_server.action_timeout_sec == 18.0 + assert settings.napcat_server.connection_id == "main" + assert settings.napcat_server.build_ws_url() == "ws://localhost:8095" + assert settings.validate(logger) is True + + +def test_parse_legacy_connection_ws_url_fallback() -> None: + logger = DummyLogger() + settings = NapCatPluginSettings.from_mapping( + { + "plugin": {"enabled": True, "config_version": "0.1.0"}, + "connection": { + "ws_url": "ws://127.0.0.1:3001", + "access_token": "legacy-token", + "heartbeat_sec": 35, + "action_timeout_sec": 12, + }, + }, + logger, + ) + + assert settings.napcat_server.host == "127.0.0.1" + assert settings.napcat_server.port == 3001 + assert settings.napcat_server.token == "legacy-token" + assert settings.napcat_server.heartbeat_interval == 35.0 + assert settings.napcat_server.action_timeout_sec == 12.0 + assert settings.validate(logger) is True + assert logger.warnings diff --git a/pytests/test_platform_io_dedupe.py b/pytests/test_platform_io_dedupe.py new file mode 100644 index 00000000..4a3cbb44 --- /dev/null +++ b/pytests/test_platform_io_dedupe.py @@ -0,0 +1,164 @@ +"""Platform IO 入站去重策略测试。""" + +from types import SimpleNamespace +from typing import Any, Dict, List, Optional + +import pytest + +from src.platform_io.drivers.base import PlatformIODriver +from src.platform_io.manager import PlatformIOManager +from src.platform_io.types import DeliveryReceipt, DeliveryStatus, DriverDescriptor, DriverKind, InboundMessageEnvelope, RouteBinding, RouteKey + + +def _build_envelope( + *, + dedupe_key: str | None = None, + external_message_id: str | None = None, + session_message_id: str | None = None, + payload: Optional[Dict[str, Any]] = None, +) -> InboundMessageEnvelope: + """构造测试用入站信封。 + + Args: + dedupe_key: 显式去重键。 + external_message_id: 平台侧消息 ID。 + session_message_id: 规范化消息对象上的消息 ID。 + payload: 原始载荷。 + + Returns: + InboundMessageEnvelope: 测试用入站消息信封。 + """ + session_message = None + if session_message_id is not None: + session_message = SimpleNamespace(message_id=session_message_id) + + return InboundMessageEnvelope( + route_key=RouteKey(platform="qq", account_id="10001", scope="main"), + driver_id="plugin.napcat", + driver_kind=DriverKind.PLUGIN, + dedupe_key=dedupe_key, + external_message_id=external_message_id, + session_message=session_message, + payload=payload, + ) + + +class _StubPlatformIODriver(PlatformIODriver): + """测试用 Platform IO 驱动。""" + + async def send_message( + self, + message: Any, + route_key: RouteKey, + metadata: Optional[Dict[str, Any]] = None, + ) -> DeliveryReceipt: + """返回一个固定的成功回执。 + + Args: + message: 待发送的消息对象。 + route_key: 本次发送使用的路由键。 + metadata: 额外发送元数据。 + + Returns: + DeliveryReceipt: 固定的成功回执。 + """ + return DeliveryReceipt( + internal_message_id=str(getattr(message, "message_id", "stub-message-id")), + route_key=route_key, + status=DeliveryStatus.SENT, + driver_id=self.driver_id, + driver_kind=self.descriptor.kind, + ) + + +def _build_manager() -> PlatformIOManager: + """构造带有最小 active owner 的 Broker 管理器。 + + Returns: + PlatformIOManager: 已注册测试驱动并绑定活动路由的 Broker。 + """ + manager = PlatformIOManager() + driver = _StubPlatformIODriver( + DriverDescriptor( + driver_id="plugin.napcat", + kind=DriverKind.PLUGIN, + platform="qq", + account_id="10001", + scope="main", + ) + ) + manager.register_driver(driver) + manager.bind_route( + RouteBinding( + route_key=RouteKey(platform="qq", account_id="10001", scope="main"), + driver_id=driver.driver_id, + driver_kind=driver.descriptor.kind, + ) + ) + return manager + + +class TestPlatformIODedupe: + """Platform IO 去重测试。""" + + @pytest.mark.asyncio + async def test_accept_inbound_dedupes_by_external_message_id(self) -> None: + """相同平台消息 ID 的重复入站应被抑制。""" + manager = _build_manager() + accepted_envelopes: List[InboundMessageEnvelope] = [] + + async def dispatcher(envelope: InboundMessageEnvelope) -> None: + """记录被成功接收的入站消息。 + + Args: + envelope: 被 Broker 接受的入站消息。 + """ + accepted_envelopes.append(envelope) + + manager.set_inbound_dispatcher(dispatcher) + + first_envelope = _build_envelope( + external_message_id="msg-1", + payload={"message": "hello"}, + ) + second_envelope = _build_envelope( + external_message_id="msg-1", + payload={"message": "hello"}, + ) + + assert await manager.accept_inbound(first_envelope) is True + assert await manager.accept_inbound(second_envelope) is False + assert len(accepted_envelopes) == 1 + + @pytest.mark.asyncio + async def test_accept_inbound_without_stable_identity_does_not_guess_duplicate(self) -> None: + """缺少稳定身份时,不应仅凭 payload 内容猜测重复消息。""" + manager = _build_manager() + accepted_envelopes: List[InboundMessageEnvelope] = [] + + async def dispatcher(envelope: InboundMessageEnvelope) -> None: + """记录被成功接收的入站消息。 + + Args: + envelope: 被 Broker 接受的入站消息。 + """ + accepted_envelopes.append(envelope) + + manager.set_inbound_dispatcher(dispatcher) + + first_envelope = _build_envelope(payload={"message": "same-payload"}) + second_envelope = _build_envelope(payload={"message": "same-payload"}) + + assert await manager.accept_inbound(first_envelope) is True + assert await manager.accept_inbound(second_envelope) is True + assert len(accepted_envelopes) == 2 + + def test_build_inbound_dedupe_key_prefers_explicit_identity(self) -> None: + """去重键应只来自显式或稳定的技术身份。""" + explicit_envelope = _build_envelope(dedupe_key="dedupe-1", external_message_id="msg-1") + session_message_envelope = _build_envelope(session_message_id="session-1") + payload_only_envelope = _build_envelope(payload={"message": "hello"}) + + assert PlatformIOManager._build_inbound_dedupe_key(explicit_envelope) == "qq:10001:main:dedupe-1" + assert PlatformIOManager._build_inbound_dedupe_key(session_message_envelope) == "qq:10001:main:session-1" + assert PlatformIOManager._build_inbound_dedupe_key(payload_only_envelope) is None diff --git a/pytests/test_plugin_message_utils_runtime.py b/pytests/test_plugin_message_utils_runtime.py new file mode 100644 index 00000000..cb4b5341 --- /dev/null +++ b/pytests/test_plugin_message_utils_runtime.py @@ -0,0 +1,87 @@ +from datetime import datetime +from pathlib import Path + +import sys + +from src.chat.message_receive.message import SessionMessage +from src.common.data_models.mai_message_data_model import GroupInfo, MessageInfo, UserInfo +from src.common.data_models.message_component_data_model import ( + ForwardComponent, + ForwardNodeComponent, + ImageComponent, + MessageSequence, + ReplyComponent, + TextComponent, + VoiceComponent, +) +from src.plugin_runtime.host.message_utils import PluginMessageUtils + + +PROJECT_ROOT = Path(__file__).resolve().parents[1] +if str(PROJECT_ROOT) not in sys.path: + sys.path.insert(0, str(PROJECT_ROOT)) + + +def test_plugin_message_utils_preserves_binary_components_and_reply_metadata() -> None: + message = SessionMessage(message_id="msg-1", timestamp=datetime.now(), platform="qq") + message.message_info = MessageInfo( + user_info=UserInfo(user_id="10001", user_nickname="tester"), + group_info=GroupInfo(group_id="20001", group_name="group"), + additional_config={"self_id": "999"}, + ) + message.session_id = "qq:20001:10001" + message.processed_plain_text = "binary payload" + message.display_message = "binary payload" + message.raw_message = MessageSequence( + components=[ + TextComponent("hello"), + ImageComponent(binary_hash="", binary_data=b"image-bytes", content=""), + VoiceComponent(binary_hash="", binary_data=b"voice-bytes", content=""), + ReplyComponent( + target_message_id="origin-1", + target_message_content="origin text", + target_message_sender_id="42", + target_message_sender_nickname="alice", + target_message_sender_cardname="Alice", + ), + ForwardNodeComponent( + forward_components=[ + ForwardComponent( + user_nickname="bob", + user_id="43", + user_cardname="Bob", + message_id="forward-1", + content=[ + TextComponent("node-text"), + ImageComponent(binary_hash="", binary_data=b"node-image", content=""), + ], + ) + ] + ), + ] + ) + + message_dict = PluginMessageUtils._session_message_to_dict(message) + rebuilt_message = PluginMessageUtils._build_session_message_from_dict(dict(message_dict)) + + image_component = rebuilt_message.raw_message.components[1] + voice_component = rebuilt_message.raw_message.components[2] + reply_component = rebuilt_message.raw_message.components[3] + forward_component = rebuilt_message.raw_message.components[4] + + assert isinstance(image_component, ImageComponent) + assert image_component.binary_data == b"image-bytes" + + assert isinstance(voice_component, VoiceComponent) + assert voice_component.binary_data == b"voice-bytes" + + assert isinstance(reply_component, ReplyComponent) + assert reply_component.target_message_id == "origin-1" + assert reply_component.target_message_content == "origin text" + assert reply_component.target_message_sender_id == "42" + assert reply_component.target_message_sender_nickname == "alice" + assert reply_component.target_message_sender_cardname == "Alice" + + assert isinstance(forward_component, ForwardNodeComponent) + assert isinstance(forward_component.forward_components[0].content[1], ImageComponent) + assert forward_component.forward_components[0].content[1].binary_data == b"node-image" diff --git a/src/platform_io/manager.py b/src/platform_io/manager.py index 97835667..b1fe3bdc 100644 --- a/src/platform_io/manager.py +++ b/src/platform_io/manager.py @@ -2,9 +2,6 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional -import hashlib -import json - from src.common.logger import get_logger from src.platform_io.drivers.base import PlatformIODriver @@ -438,12 +435,17 @@ class PlatformIOManager: Returns: Optional[str]: 若可以构造稳定去重键则返回该键,否则返回 ``None``。 + + Notes: + 这里仅接受上游显式提供的稳定消息身份,例如 ``dedupe_key``、 + 平台侧 ``external_message_id`` 或已经完成规范化的 + ``session_message.message_id``。Broker 不再根据 ``payload`` 内容 + 猜测语义去重键,避免把“短时间内两条内容刚好完全相同”的合法消息 + 误判为重复入站。 """ raw_dedupe_key = envelope.dedupe_key or envelope.external_message_id if raw_dedupe_key is None and envelope.session_message is not None: raw_dedupe_key = envelope.session_message.message_id - if raw_dedupe_key is None and envelope.payload is not None: - raw_dedupe_key = PlatformIOManager._build_payload_fingerprint(envelope.payload) if raw_dedupe_key is None: return None @@ -453,29 +455,6 @@ class PlatformIOManager: return f"{envelope.route_key.to_dedupe_scope()}:{normalized_dedupe_key}" - @staticmethod - def _build_payload_fingerprint(payload: Dict[str, Any]) -> Optional[str]: - """根据消息载荷构造稳定指纹。 - - Args: - payload: 待构造指纹的原始载荷字典。 - - Returns: - Optional[str]: 若成功生成指纹则返回十六进制摘要,否则返回 ``None``。 - """ - try: - serialized_payload = json.dumps( - payload, - default=str, - ensure_ascii=True, - separators=(",", ":"), - sort_keys=True, - ) - except Exception: - return None - - return hashlib.sha256(serialized_payload.encode()).hexdigest() - @staticmethod def _validate_binding_against_driver(binding: RouteBinding, driver: PlatformIODriver) -> None: """校验路由绑定与驱动描述是否一致。 diff --git a/src/platform_io/types.py b/src/platform_io/types.py index c74dc246..8729b637 100644 --- a/src/platform_io/types.py +++ b/src/platform_io/types.py @@ -198,8 +198,9 @@ class InboundMessageEnvelope: driver_kind: 产出该消息的驱动类型。 external_message_id: 可选的平台侧消息 ID,用于去重。 dedupe_key: 可选的显式去重键。当外部消息没有稳定 ``message_id`` 时, - 可由上游驱动提供消息指纹。若这里为空,中间层仍可能继续回退到 - ``session_message.message_id`` 或 ``payload`` 指纹。 + 可由上游驱动提供稳定的技术性幂等键。若这里为空,中间层仅会继续 + 回退到 ``external_message_id`` 或 ``session_message.message_id``, + 不会再根据 ``payload`` 内容猜测语义去重键。 session_message: 可选的、已经完成规范化的 ``SessionMessage`` 对象。 payload: 可选的原始字典载荷,供延迟转换或调试使用。 metadata: 额外入站元数据,例如连接信息或追踪上下文。 diff --git a/src/plugin_runtime/host/message_utils.py b/src/plugin_runtime/host/message_utils.py index aaebb529..2f6aa01b 100644 --- a/src/plugin_runtime/host/message_utils.py +++ b/src/plugin_runtime/host/message_utils.py @@ -1,10 +1,25 @@ from datetime import datetime -from typing import Dict, Any, TypedDict, Optional, List +from typing import Any, Dict, List, Optional, TypedDict + +import base64 +import hashlib from src.common.logger import get_logger from src.chat.message_receive.message import SessionMessage from src.common.data_models.mai_message_data_model import UserInfo, GroupInfo, MessageInfo -from src.common.data_models.message_component_data_model import MessageSequence +from src.common.data_models.message_component_data_model import ( + AtComponent, + DictComponent, + EmojiComponent, + ForwardComponent, + ForwardNodeComponent, + ImageComponent, + MessageSequence, + ReplyComponent, + StandardMessageComponents, + TextComponent, + VoiceComponent, +) logger = get_logger("plugin_runtime.host.message_utils") @@ -45,6 +60,251 @@ class MessageDict(TypedDict, total=False): class PluginMessageUtils: + @staticmethod + def _message_sequence_to_dict(message_sequence: MessageSequence) -> List[Dict[str, Any]]: + """将消息组件序列转换为插件运行时使用的字典结构。 + + Args: + message_sequence: 待转换的消息组件序列。 + + Returns: + List[Dict[str, Any]]: 供插件运行时协议使用的消息段字典列表。 + """ + return [PluginMessageUtils._component_to_dict(component) for component in message_sequence.components] + + @staticmethod + def _component_to_dict(component: StandardMessageComponents) -> Dict[str, Any]: + """将单个消息组件转换为插件运行时字典结构。 + + Args: + component: 待转换的消息组件。 + + Returns: + Dict[str, Any]: 序列化后的消息组件字典。 + """ + if isinstance(component, TextComponent): + return {"type": "text", "data": component.text} + + if isinstance(component, ImageComponent): + serialized = { + "type": "image", + "data": component.content, + "hash": component.binary_hash, + } + if component.binary_data: + serialized["binary_data_base64"] = base64.b64encode(component.binary_data).decode("utf-8") + return serialized + + if isinstance(component, EmojiComponent): + serialized = { + "type": "emoji", + "data": component.content, + "hash": component.binary_hash, + } + if component.binary_data: + serialized["binary_data_base64"] = base64.b64encode(component.binary_data).decode("utf-8") + return serialized + + if isinstance(component, VoiceComponent): + serialized = { + "type": "voice", + "data": component.content, + "hash": component.binary_hash, + } + if component.binary_data: + serialized["binary_data_base64"] = base64.b64encode(component.binary_data).decode("utf-8") + return serialized + + if isinstance(component, AtComponent): + return { + "type": "at", + "data": { + "target_user_id": component.target_user_id, + "target_user_nickname": component.target_user_nickname, + "target_user_cardname": component.target_user_cardname, + }, + } + + if isinstance(component, ReplyComponent): + return { + "type": "reply", + "data": { + "target_message_id": component.target_message_id, + "target_message_content": component.target_message_content, + "target_message_sender_id": component.target_message_sender_id, + "target_message_sender_nickname": component.target_message_sender_nickname, + "target_message_sender_cardname": component.target_message_sender_cardname, + }, + } + + if isinstance(component, ForwardNodeComponent): + return { + "type": "forward", + "data": [PluginMessageUtils._forward_component_to_dict(item) for item in component.forward_components], + } + + return {"type": "dict", "data": component.data} + + @staticmethod + def _forward_component_to_dict(component: ForwardComponent) -> Dict[str, Any]: + """将单个转发节点组件转换为字典结构。 + + Args: + component: 待转换的转发节点组件。 + + Returns: + Dict[str, Any]: 序列化后的转发节点字典。 + """ + return { + "user_id": component.user_id, + "user_nickname": component.user_nickname, + "user_cardname": component.user_cardname, + "message_id": component.message_id, + "content": [PluginMessageUtils._component_to_dict(item) for item in component.content], + } + + @staticmethod + def _message_sequence_from_dict(raw_message_data: List[Dict[str, Any]]) -> MessageSequence: + """从插件运行时字典结构恢复消息组件序列。 + + Args: + raw_message_data: 插件运行时消息段字典列表。 + + Returns: + MessageSequence: 恢复后的消息组件序列。 + """ + components = [PluginMessageUtils._component_from_dict(item) for item in raw_message_data] + return MessageSequence(components=components) + + @staticmethod + def _component_from_dict(item: Dict[str, Any]) -> StandardMessageComponents: + """从插件运行时字典结构恢复单个消息组件。 + + Args: + item: 单个消息组件的字典表示。 + + Returns: + StandardMessageComponents: 恢复后的内部消息组件对象。 + """ + item_type = str(item.get("type") or "").strip() + if item_type == "text": + return TextComponent(text=str(item.get("data") or "")) + + if item_type == "image": + return PluginMessageUtils._build_binary_component(ImageComponent, item) + + if item_type == "emoji": + return PluginMessageUtils._build_binary_component(EmojiComponent, item) + + if item_type == "voice": + return PluginMessageUtils._build_binary_component(VoiceComponent, item) + + if item_type == "at": + item_data = item.get("data", {}) + if not isinstance(item_data, dict): + item_data = {} + return AtComponent( + target_user_id=str(item_data.get("target_user_id") or ""), + target_user_nickname=PluginMessageUtils._normalize_optional_string(item_data.get("target_user_nickname")), + target_user_cardname=PluginMessageUtils._normalize_optional_string(item_data.get("target_user_cardname")), + ) + + if item_type == "reply": + reply_data = item.get("data") + if isinstance(reply_data, dict): + return ReplyComponent( + target_message_id=str(reply_data.get("target_message_id") or ""), + target_message_content=PluginMessageUtils._normalize_optional_string( + reply_data.get("target_message_content") + ), + target_message_sender_id=PluginMessageUtils._normalize_optional_string( + reply_data.get("target_message_sender_id") + ), + target_message_sender_nickname=PluginMessageUtils._normalize_optional_string( + reply_data.get("target_message_sender_nickname") + ), + target_message_sender_cardname=PluginMessageUtils._normalize_optional_string( + reply_data.get("target_message_sender_cardname") + ), + ) + return ReplyComponent(target_message_id=str(reply_data or "")) + + if item_type == "forward": + forward_nodes: List[ForwardComponent] = [] + raw_forward_nodes = item.get("data", []) + if isinstance(raw_forward_nodes, list): + for node in raw_forward_nodes: + if not isinstance(node, dict): + continue + raw_content = node.get("content", []) + node_components: List[StandardMessageComponents] = [] + if isinstance(raw_content, list): + node_components = [ + PluginMessageUtils._component_from_dict(content) + for content in raw_content + if isinstance(content, dict) + ] + if not node_components: + node_components = [TextComponent(text="[empty forward node]")] + forward_nodes.append( + ForwardComponent( + user_nickname=str(node.get("user_nickname") or "未知用户"), + user_id=PluginMessageUtils._normalize_optional_string(node.get("user_id")), + user_cardname=PluginMessageUtils._normalize_optional_string(node.get("user_cardname")), + message_id=str(node.get("message_id") or ""), + content=node_components, + ) + ) + if not forward_nodes: + return DictComponent(data={"type": "forward", "data": item.get("data", [])}) + return ForwardNodeComponent(forward_components=forward_nodes) + + component_data = item.get("data") + if isinstance(component_data, dict): + return DictComponent(data=component_data) + return DictComponent(data=item) + + @staticmethod + def _build_binary_component(component_cls: Any, item: Dict[str, Any]) -> StandardMessageComponents: + """从字典构造带二进制负载的消息组件。 + + Args: + component_cls: 目标组件类型。 + item: 消息组件字典。 + + Returns: + StandardMessageComponents: 构造后的组件对象。 + """ + content = str(item.get("data") or "") + binary_hash = str(item.get("hash") or "") + raw_binary_base64 = item.get("binary_data_base64") + binary_data = b"" + if isinstance(raw_binary_base64, str) and raw_binary_base64: + try: + binary_data = base64.b64decode(raw_binary_base64) + except Exception: + binary_data = b"" + + if not binary_hash and binary_data: + binary_hash = hashlib.sha256(binary_data).hexdigest() + + return component_cls(binary_hash=binary_hash, content=content, binary_data=binary_data) + + @staticmethod + def _normalize_optional_string(value: Any) -> Optional[str]: + """将任意值规范化为可选字符串。 + + Args: + value: 待规范化的值。 + + Returns: + Optional[str]: 规范化后的字符串;若值为空则返回 ``None``。 + """ + if value is None: + return None + normalized_value = str(value) + return normalized_value if normalized_value else None + @staticmethod def _message_info_to_dict(message_info: MessageInfo) -> MessageInfoDict: """ @@ -92,7 +352,7 @@ class PluginMessageUtils: timestamp=str(session_message.timestamp.timestamp()), # 转换为时间戳字符串 platform=session_message.platform, message_info=PluginMessageUtils._message_info_to_dict(session_message.message_info), - raw_message=session_message.raw_message.to_dict(), # 复用 MessageSequence.to_dict() + raw_message=PluginMessageUtils._message_sequence_to_dict(session_message.raw_message), is_mentioned=session_message.is_mentioned, is_at=session_message.is_at, is_emoji=session_message.is_emoji, @@ -186,7 +446,7 @@ class PluginMessageUtils: # 构建原始消息组件序列(复用 MessageSequence.from_dict 方法) raw_message_data = message_dict["raw_message"] if isinstance(raw_message_data, list): - session_message.raw_message = MessageSequence.from_dict(raw_message_data) + session_message.raw_message = PluginMessageUtils._message_sequence_from_dict(raw_message_data) else: raise ValueError("消息字典中 'raw_message' 字段必须是一个列表") diff --git a/src/plugins/built_in/napcat_adapter/__init__.py b/src/plugins/built_in/napcat_adapter/__init__.py new file mode 100644 index 00000000..fa82860f --- /dev/null +++ b/src/plugins/built_in/napcat_adapter/__init__.py @@ -0,0 +1 @@ +"""NapCat 内置适配器插件包。""" diff --git a/src/plugins/built_in/napcat_adapter/codec_inbound.py b/src/plugins/built_in/napcat_adapter/codec_inbound.py new file mode 100644 index 00000000..b8065585 --- /dev/null +++ b/src/plugins/built_in/napcat_adapter/codec_inbound.py @@ -0,0 +1,414 @@ +"""NapCat 入站消息编解码。""" + +from typing import Any, Dict, List, Mapping, Optional, Tuple +from uuid import uuid4 + +import hashlib +import json +import time + +from napcat_adapter.qq_queries import NapCatQueryService + + +class NapCatInboundCodec: + """NapCat 入站消息编码器。""" + + def __init__(self, logger: Any, query_service: NapCatQueryService) -> None: + """初始化入站消息编码器。 + + Args: + logger: 插件日志对象。 + query_service: QQ 查询服务。 + """ + self._logger = logger + self._query_service = query_service + + async def build_message_dict( + self, + payload: Mapping[str, Any], + self_id: str, + sender_user_id: str, + sender: Mapping[str, Any], + ) -> Dict[str, Any]: + """构造 Host 侧可接受的 ``MessageDict``。 + + Args: + payload: NapCat 原始消息事件。 + self_id: 当前机器人账号 ID。 + sender_user_id: 发送者用户 ID。 + sender: 发送者信息字典。 + + Returns: + Dict[str, Any]: 规范化后的 ``MessageDict``。 + """ + message_type = str(payload.get("message_type") or "").strip() or "private" + group_id = str(payload.get("group_id") or "").strip() + group_name = str(payload.get("group_name") or "").strip() or (f"group_{group_id}" if group_id else "") + user_nickname = str(sender.get("nickname") or sender.get("card") or sender_user_id).strip() or sender_user_id + user_cardname = str(sender.get("card") or "").strip() or None + + raw_message, is_at = await self.convert_segments(payload, self_id) + raw_message_text = str(payload.get("raw_message") or "").strip() + if not raw_message: + raw_message = [{"type": "text", "data": raw_message_text or "[unsupported]"}] + + plain_text = self.build_plain_text(raw_message, raw_message_text) + timestamp_seconds = payload.get("time") + if not isinstance(timestamp_seconds, (int, float)): + timestamp_seconds = time.time() + + additional_config: Dict[str, Any] = {"self_id": self_id, "napcat_message_type": message_type} + if group_id: + additional_config["platform_io_target_group_id"] = group_id + else: + additional_config["platform_io_target_user_id"] = sender_user_id + + message_info: Dict[str, Any] = { + "user_info": { + "user_id": sender_user_id, + "user_nickname": user_nickname, + "user_cardname": user_cardname, + }, + "additional_config": additional_config, + } + if group_id: + message_info["group_info"] = {"group_id": group_id, "group_name": group_name} + + message_id = str(payload.get("message_id") or f"napcat-{uuid4().hex}").strip() + return { + "message_id": message_id, + "timestamp": str(float(timestamp_seconds)), + "platform": "qq", + "message_info": message_info, + "raw_message": raw_message, + "is_mentioned": is_at, + "is_at": is_at, + "is_emoji": False, + "is_picture": False, + "is_command": plain_text.startswith("/"), + "is_notify": False, + "session_id": "", + "processed_plain_text": plain_text, + "display_message": plain_text, + } + + async def convert_segments(self, payload: Mapping[str, Any], self_id: str) -> Tuple[List[Dict[str, Any]], bool]: + """将 OneBot 消息段转换为 Host 消息段结构。 + + Args: + payload: OneBot 原始消息事件。 + self_id: 当前机器人账号 ID。 + + Returns: + Tuple[List[Dict[str, Any]], bool]: 转换后的消息段列表,以及是否 @ 到当前机器人。 + """ + message_payload = payload.get("message") + if isinstance(message_payload, str): + normalized_text = message_payload.strip() + return ([{"type": "text", "data": normalized_text}] if normalized_text else []), False + + if not isinstance(message_payload, list): + return [], False + + converted_segments: List[Dict[str, Any]] = [] + is_at = False + for segment in message_payload: + if not isinstance(segment, Mapping): + continue + + segment_type = str(segment.get("type") or "").strip() + segment_data = segment.get("data", {}) + if not isinstance(segment_data, Mapping): + segment_data = {} + + if segment_type == "text": + if text_value := str(segment_data.get("text") or ""): + converted_segments.append({"type": "text", "data": text_value}) + continue + + if segment_type == "at": + if target_user_id := str(segment_data.get("qq") or "").strip(): + converted_segments.append( + { + "type": "at", + "data": { + "target_user_id": target_user_id, + "target_user_nickname": None, + "target_user_cardname": None, + }, + } + ) + if self_id and target_user_id == self_id: + is_at = True + continue + + if segment_type == "reply": + if reply_segment := await self._build_reply_segment(segment_data): + converted_segments.append(reply_segment) + continue + + if segment_type == "face": + converted_segments.append({"type": "text", "data": "[face]"}) + continue + + if segment_type == "image": + converted_segments.append(await self._build_image_like_segment(segment_data, is_emoji=False)) + continue + + if segment_type == "record": + converted_segments.append(await self._build_record_segment(segment_data)) + continue + + if segment_type == "video": + converted_segments.append({"type": "text", "data": "[video]"}) + continue + + if segment_type == "file": + converted_segments.append({"type": "text", "data": "[file]"}) + continue + + if segment_type == "json": + converted_segments.append(self._build_json_text_segment(segment_data)) + continue + + if segment_type == "forward": + if forward_segment := await self._build_forward_segment(segment_data): + converted_segments.append(forward_segment) + continue + + if segment_type in {"xml", "share"}: + converted_segments.append({"type": "text", "data": f"[{segment_type}]"}) + + return converted_segments, is_at + + async def _build_reply_segment(self, segment_data: Mapping[str, Any]) -> Optional[Dict[str, Any]]: + """构造回复消息段。 + + Args: + segment_data: OneBot ``reply`` 段的 ``data`` 字典。 + + Returns: + Optional[Dict[str, Any]]: 转换后的回复消息段;缺少消息 ID 时返回 ``None``。 + """ + target_message_id = str(segment_data.get("id") or "").strip() + if not target_message_id: + return None + + message_detail = await self._query_service.get_message_detail(target_message_id) + reply_payload: Dict[str, Any] = {"target_message_id": target_message_id} + if message_detail is not None: + sender = message_detail.get("sender", {}) + if not isinstance(sender, Mapping): + sender = {} + reply_payload["target_message_content"] = str(message_detail.get("raw_message") or "").strip() or None + reply_payload["target_message_sender_id"] = str( + message_detail.get("user_id") or sender.get("user_id") or "" + ).strip() or None + reply_payload["target_message_sender_nickname"] = str(sender.get("nickname") or "").strip() or None + reply_payload["target_message_sender_cardname"] = str(sender.get("card") or "").strip() or None + + return {"type": "reply", "data": reply_payload} + + async def _build_image_like_segment( + self, + segment_data: Mapping[str, Any], + is_emoji: bool, + ) -> Dict[str, Any]: + """构造图片或表情消息段。 + + Args: + segment_data: OneBot ``image`` 段的 ``data`` 字典。 + is_emoji: 是否按表情组件处理。 + + Returns: + Dict[str, Any]: 转换后的图片或表情消息段。 + """ + subtype = segment_data.get("sub_type") + actual_is_emoji = is_emoji or (isinstance(subtype, int) and subtype not in {0, 4, 9}) + + image_url = str(segment_data.get("url") or "").strip() + binary_data = await self._query_service.download_binary(image_url) + if not binary_data: + return {"type": "text", "data": "[emoji]" if actual_is_emoji else "[image]"} + + return { + "type": "emoji" if actual_is_emoji else "image", + "data": "", + "hash": hashlib.sha256(binary_data).hexdigest(), + "binary_data_base64": self._encode_binary(binary_data), + } + + async def _build_record_segment(self, segment_data: Mapping[str, Any]) -> Dict[str, Any]: + """构造语音消息段。 + + Args: + segment_data: OneBot ``record`` 段的 ``data`` 字典。 + + Returns: + Dict[str, Any]: 转换后的语音或占位文本消息段。 + """ + file_name = str(segment_data.get("file") or "").strip() + file_id = str(segment_data.get("file_id") or "").strip() or None + if not file_name: + return {"type": "text", "data": "[voice]"} + + record_detail = await self._query_service.get_record_detail(file_name=file_name, file_id=file_id) + if record_detail is None: + return {"type": "text", "data": "[voice]"} + + record_base64 = str(record_detail.get("base64") or "").strip() + if not record_base64: + return {"type": "text", "data": "[voice]"} + + try: + binary_data = self._decode_binary(record_base64) + except Exception: + return {"type": "text", "data": "[voice]"} + + return { + "type": "voice", + "data": "", + "hash": hashlib.sha256(binary_data).hexdigest(), + "binary_data_base64": self._encode_binary(binary_data), + } + + async def _build_forward_segment(self, segment_data: Mapping[str, Any]) -> Optional[Dict[str, Any]]: + """构造合并转发消息段。 + + Args: + segment_data: OneBot ``forward`` 段的 ``data`` 字典。 + + Returns: + Optional[Dict[str, Any]]: 转换后的合并转发消息段;失败时返回 ``None``。 + """ + message_id = str(segment_data.get("id") or "").strip() + if not message_id: + return None + + forward_detail = await self._query_service.get_forward_message(message_id) + if forward_detail is None: + return {"type": "text", "data": "[forward]"} + + messages = forward_detail.get("messages", []) + if not isinstance(messages, list): + return {"type": "text", "data": "[forward]"} + + forward_nodes: List[Dict[str, Any]] = [] + for forward_message in messages: + if not isinstance(forward_message, Mapping): + continue + raw_content = forward_message.get("content", []) + content_segments = await self._convert_forward_content(raw_content, "") + sender = forward_message.get("sender", {}) + if not isinstance(sender, Mapping): + sender = {} + forward_nodes.append( + { + "user_id": str(sender.get("user_id") or sender.get("uin") or "").strip() or None, + "user_nickname": str(sender.get("nickname") or sender.get("name") or "未知用户"), + "user_cardname": str(sender.get("card") or "").strip() or None, + "message_id": str(forward_message.get("message_id") or uuid4().hex), + "content": content_segments or [{"type": "text", "data": "[empty]"}], + } + ) + + if not forward_nodes: + return {"type": "text", "data": "[forward]"} + return {"type": "forward", "data": forward_nodes} + + async def _convert_forward_content(self, raw_content: Any, self_id: str) -> List[Dict[str, Any]]: + """转换转发节点内部的消息段列表。 + + Args: + raw_content: 转发节点原始内容。 + self_id: 当前机器人账号 ID。 + + Returns: + List[Dict[str, Any]]: 转换后的消息段列表。 + """ + pseudo_payload: Dict[str, Any] = {"message": raw_content} + segments, _ = await self.convert_segments(pseudo_payload, self_id) + return segments + + def _build_json_text_segment(self, segment_data: Mapping[str, Any]) -> Dict[str, Any]: + """将 JSON 卡片最佳努力转换为文本占位。 + + Args: + segment_data: OneBot ``json`` 段的 ``data`` 字典。 + + Returns: + Dict[str, Any]: 转换后的文本消息段。 + """ + json_data = str(segment_data.get("data") or "").strip() + if not json_data: + return {"type": "text", "data": "[json]"} + + try: + parsed_json = json.loads(json_data) + except Exception: + return {"type": "text", "data": "[json]"} + + app_name = str(parsed_json.get("app") or "").strip() + prompt = "" + if isinstance(parsed_json.get("meta"), Mapping): + prompt = str(parsed_json["meta"].get("prompt") or "").strip() + text = prompt or app_name or "json" + return {"type": "text", "data": f"[json:{text}]"} + + @staticmethod + def _encode_binary(binary_data: bytes) -> str: + """将二进制内容编码为 Base64 字符串。 + + Args: + binary_data: 待编码的二进制内容。 + + Returns: + str: Base64 编码字符串。 + """ + import base64 + + return base64.b64encode(binary_data).decode("utf-8") + + @staticmethod + def _decode_binary(binary_base64: str) -> bytes: + """将 Base64 字符串解码为二进制内容。 + + Args: + binary_base64: Base64 字符串。 + + Returns: + bytes: 解码后的二进制内容。 + """ + import base64 + + return base64.b64decode(binary_base64) + + def build_plain_text(self, raw_message: List[Dict[str, Any]], fallback_text: str) -> str: + """从标准消息段中提取可展示的纯文本。 + + Args: + raw_message: 标准化后的消息段列表。 + fallback_text: 当无法拼出文本时使用的回退文本。 + + Returns: + str: 用于 Host 展示和命令判断的纯文本内容。 + """ + plain_text_parts: List[str] = [] + for item in raw_message: + if not isinstance(item, Mapping): + continue + item_type = str(item.get("type") or "").strip() + item_data = item.get("data") + if item_type == "text": + plain_text_parts.append(str(item_data or "")) + elif item_type == "at" and isinstance(item_data, Mapping): + plain_text_parts.append(f"@{item_data.get('target_user_id') or ''}") + elif item_type == "reply": + plain_text_parts.append("[reply]") + elif item_type == "forward": + plain_text_parts.append("[forward]") + elif item_type in {"image", "emoji", "voice"}: + plain_text_parts.append(f"[{item_type}]") + + plain_text = "".join(part for part in plain_text_parts if part).strip() + return plain_text or fallback_text or "[unsupported]" diff --git a/src/plugins/built_in/napcat_adapter/codec_outbound.py b/src/plugins/built_in/napcat_adapter/codec_outbound.py new file mode 100644 index 00000000..6adcb622 --- /dev/null +++ b/src/plugins/built_in/napcat_adapter/codec_outbound.py @@ -0,0 +1,192 @@ +"""NapCat 出站消息编解码。""" + +from typing import Any, Dict, List, Mapping, Tuple + + +class NapCatOutboundCodec: + """NapCat 出站消息编码器。""" + + def build_outbound_action( + self, + message: Mapping[str, Any], + route: Mapping[str, Any], + ) -> Tuple[str, Dict[str, Any]]: + """为 Host 出站消息构造 OneBot 动作。 + + Args: + message: Host 侧标准 ``MessageDict``。 + route: Platform IO 路由信息。 + + Returns: + Tuple[str, Dict[str, Any]]: 动作名称与参数字典。 + + Raises: + ValueError: 当私聊出站缺少目标用户 ID 时抛出。 + """ + message_info = message.get("message_info", {}) + if not isinstance(message_info, Mapping): + message_info = {} + + group_info = message_info.get("group_info", {}) + if not isinstance(group_info, Mapping): + group_info = {} + + additional_config = message_info.get("additional_config", {}) + if not isinstance(additional_config, Mapping): + additional_config = {} + + raw_message = message.get("raw_message", []) + segments = self.convert_segments(raw_message) + + if target_group_id := str( + group_info.get("group_id") or additional_config.get("platform_io_target_group_id") or "" + ).strip(): + return "send_group_msg", {"group_id": target_group_id, "message": segments} + + target_user_id = str( + additional_config.get("platform_io_target_user_id") + or additional_config.get("target_user_id") + or route.get("target_user_id") + or "" + ).strip() + if not target_user_id: + raise ValueError("Outbound private message is missing target_user_id") + + return "send_private_msg", {"message": segments, "user_id": target_user_id} + + def convert_segments(self, raw_message: Any) -> List[Dict[str, Any]]: + """将 Host 消息段转换为 OneBot 消息段。 + + Args: + raw_message: Host 侧 ``raw_message`` 字段。 + + Returns: + List[Dict[str, Any]]: OneBot 消息段列表。 + """ + if not isinstance(raw_message, list): + return [{"type": "text", "data": {"text": ""}}] + + outbound_segments: List[Dict[str, Any]] = [] + for item in raw_message: + if not isinstance(item, Mapping): + continue + + item_type = str(item.get("type") or "").strip() + item_data = item.get("data") + + if item_type == "text": + text_value = str(item_data or "") + outbound_segments.append({"type": "text", "data": {"text": text_value}}) + continue + + if item_type == "at" and isinstance(item_data, Mapping): + if target_user_id := str(item_data.get("target_user_id") or "").strip(): + outbound_segments.append({"type": "at", "data": {"qq": target_user_id}}) + continue + + if item_type == "reply": + if isinstance(item_data, Mapping): + target_message_id = str(item_data.get("target_message_id") or "").strip() + else: + target_message_id = str(item_data or "").strip() + if target_message_id: + outbound_segments.append({"type": "reply", "data": {"id": target_message_id}}) + continue + + if item_type == "image": + binary_base64 = str(item.get("binary_data_base64") or "").strip() + if binary_base64: + outbound_segments.append( + { + "type": "image", + "data": {"file": f"base64://{binary_base64}", "subtype": 0}, + } + ) + else: + outbound_segments.append({"type": "text", "data": {"text": "[image]"}}) + continue + + if item_type == "emoji": + binary_base64 = str(item.get("binary_data_base64") or "").strip() + if binary_base64: + outbound_segments.append( + { + "type": "image", + "data": { + "file": f"base64://{binary_base64}", + "subtype": 1, + "summary": "[动画表情]", + }, + } + ) + else: + outbound_segments.append({"type": "text", "data": {"text": "[emoji]"}}) + continue + + if item_type == "voice": + binary_base64 = str(item.get("binary_data_base64") or "").strip() + if binary_base64: + outbound_segments.append({"type": "record", "data": {"file": f"base64://{binary_base64}"}}) + else: + outbound_segments.append({"type": "text", "data": {"text": "[voice]"}}) + continue + + if item_type == "forward" and isinstance(item_data, list): + outbound_segments.extend(self._build_forward_nodes(item_data)) + continue + + if item_type == "dict" and isinstance(item_data, Mapping): + if dict_segment := self._build_dict_component_segment(item_data): + outbound_segments.append(dict_segment) + continue + + fallback_text = f"[unsupported:{item_type or 'unknown'}]" + outbound_segments.append({"type": "text", "data": {"text": fallback_text}}) + + if not outbound_segments: + outbound_segments.append({"type": "text", "data": {"text": ""}}) + return outbound_segments + + def _build_forward_nodes(self, forward_nodes: List[Any]) -> List[Dict[str, Any]]: + """构造 NapCat 转发节点列表。 + + Args: + forward_nodes: 内部转发节点列表。 + + Returns: + List[Dict[str, Any]]: NapCat 转发节点列表。 + """ + built_nodes: List[Dict[str, Any]] = [] + for node in forward_nodes: + if not isinstance(node, Mapping): + continue + raw_content = node.get("content", []) + node_segments = self.convert_segments(raw_content) + built_nodes.append( + { + "type": "node", + "data": { + "name": str(node.get("user_nickname") or node.get("user_cardname") or "QQ用户"), + "uin": str(node.get("user_id") or ""), + "content": node_segments, + }, + } + ) + return built_nodes + + def _build_dict_component_segment(self, item_data: Mapping[str, Any]) -> Dict[str, Any]: + """尽力将 ``DictComponent`` 转换为 NapCat 消息段。 + + Args: + item_data: ``DictComponent`` 原始数据。 + + Returns: + Dict[str, Any]: NapCat 消息段;不支持时返回占位文本段。 + """ + raw_type = str(item_data.get("type") or "").strip() + raw_payload = item_data.get("data", item_data) + if raw_type in {"file", "music", "video", "face"} and isinstance(raw_payload, Mapping): + return {"type": raw_type, "data": dict(raw_payload)} + if raw_type in {"image", "record", "reply", "at"} and isinstance(raw_payload, Mapping): + return {"type": raw_type, "data": dict(raw_payload)} + return {"type": "text", "data": {"text": f"[unsupported:{raw_type or 'dict'}]"}} diff --git a/src/plugins/built_in/napcat_adapter/config.py b/src/plugins/built_in/napcat_adapter/config.py new file mode 100644 index 00000000..eeb4acab --- /dev/null +++ b/src/plugins/built_in/napcat_adapter/config.py @@ -0,0 +1,398 @@ +"""NapCat 内置适配器配置解析。""" + +from dataclasses import dataclass, field +from typing import Any, Dict, Mapping, Optional, Set, Tuple +from urllib.parse import urlparse + +from napcat_adapter.constants import ( + DEFAULT_ACTION_TIMEOUT_SEC, + DEFAULT_CHAT_LIST_TYPE, + DEFAULT_HEARTBEAT_INTERVAL_SEC, + DEFAULT_NAPCAT_HOST, + DEFAULT_NAPCAT_PORT, + DEFAULT_RECONNECT_DELAY_SEC, + SUPPORTED_CONFIG_VERSION, +) + + +@dataclass(frozen=True) +class NapCatPluginOptions: + """插件级配置。""" + + enabled: bool = False + config_version: str = "" + + def should_connect(self) -> bool: + """判断当前配置下是否应当启动连接。 + + Returns: + bool: 若插件连接已启用,则返回 ``True``。 + """ + return self.enabled + + +@dataclass(frozen=True) +class NapCatServerConfig: + """NapCat 正向 WebSocket 连接配置。""" + + host: str = DEFAULT_NAPCAT_HOST + port: int = DEFAULT_NAPCAT_PORT + token: str = "" + heartbeat_interval: float = DEFAULT_HEARTBEAT_INTERVAL_SEC + reconnect_delay_sec: float = DEFAULT_RECONNECT_DELAY_SEC + action_timeout_sec: float = DEFAULT_ACTION_TIMEOUT_SEC + connection_id: str = "" + + def build_ws_url(self) -> str: + """构造正向 WebSocket 地址。 + + Returns: + str: 供适配器作为客户端连接的 NapCat WebSocket 地址。 + """ + return f"ws://{self.host}:{self.port}" + + +@dataclass(frozen=True) +class NapCatChatConfig: + """聊天名单配置。""" + + group_list_type: str = DEFAULT_CHAT_LIST_TYPE + group_list: Set[str] = field(default_factory=set) + private_list_type: str = DEFAULT_CHAT_LIST_TYPE + private_list: Set[str] = field(default_factory=set) + ban_user_id: Set[str] = field(default_factory=set) + + +@dataclass(frozen=True) +class NapCatFilterConfig: + """消息过滤配置。""" + + ignore_self_message: bool = True + + +@dataclass(frozen=True) +class NapCatPluginSettings: + """NapCat 插件完整配置。""" + + plugin: NapCatPluginOptions = field(default_factory=NapCatPluginOptions) + napcat_server: NapCatServerConfig = field(default_factory=NapCatServerConfig) + chat: NapCatChatConfig = field(default_factory=NapCatChatConfig) + filters: NapCatFilterConfig = field(default_factory=NapCatFilterConfig) + + @classmethod + def from_mapping(cls, raw_config: Mapping[str, Any], logger: Any) -> "NapCatPluginSettings": + """从 Runner 注入的原始配置字典解析插件配置。 + + Args: + raw_config: Runner 注入的原始配置内容。 + logger: 插件日志对象。 + + Returns: + NapCatPluginSettings: 规范化后的插件配置。 + """ + plugin_section = _as_mapping(raw_config.get("plugin")) + server_section = _as_mapping(raw_config.get("napcat_server")) + legacy_connection_section = _as_mapping(raw_config.get("connection")) + chat_section = _as_mapping(raw_config.get("chat")) + filters_section = _as_mapping(raw_config.get("filters")) + + if not server_section and legacy_connection_section: + logger.warning("NapCat 适配器检测到旧版 [connection] 配置段,请尽快迁移到 [napcat_server]") + server_section = legacy_connection_section + + legacy_host, legacy_port = _read_legacy_host_port(server_section, legacy_connection_section, logger) + parsed_host = _read_string(server_section, "host") or legacy_host or DEFAULT_NAPCAT_HOST + parsed_port = _read_positive_int( + mapping=server_section, + key="port", + default=legacy_port or DEFAULT_NAPCAT_PORT, + logger=logger, + setting_name="napcat_server.port", + ) + + return cls( + plugin=NapCatPluginOptions( + enabled=_read_bool(plugin_section, "enabled", False), + config_version=_read_string(plugin_section, "config_version"), + ), + napcat_server=NapCatServerConfig( + host=parsed_host, + port=parsed_port, + token=_read_string(server_section, "token") or _read_string(server_section, "access_token"), + heartbeat_interval=_read_positive_float( + mapping=server_section, + key="heartbeat_interval", + default=_read_positive_float( + mapping=server_section, + key="heartbeat_sec", + default=DEFAULT_HEARTBEAT_INTERVAL_SEC, + logger=logger, + setting_name="napcat_server.heartbeat_interval", + ), + logger=logger, + setting_name="napcat_server.heartbeat_interval", + ), + reconnect_delay_sec=_read_positive_float( + mapping=server_section, + key="reconnect_delay_sec", + default=DEFAULT_RECONNECT_DELAY_SEC, + logger=logger, + setting_name="napcat_server.reconnect_delay_sec", + ), + action_timeout_sec=_read_positive_float( + mapping=server_section, + key="action_timeout_sec", + default=DEFAULT_ACTION_TIMEOUT_SEC, + logger=logger, + setting_name="napcat_server.action_timeout_sec", + ), + connection_id=_read_string(server_section, "connection_id"), + ), + chat=NapCatChatConfig( + group_list_type=_read_list_mode( + mapping=chat_section, + key="group_list_type", + default=DEFAULT_CHAT_LIST_TYPE, + logger=logger, + setting_name="chat.group_list_type", + ), + group_list=_read_string_set(chat_section, "group_list"), + private_list_type=_read_list_mode( + mapping=chat_section, + key="private_list_type", + default=DEFAULT_CHAT_LIST_TYPE, + logger=logger, + setting_name="chat.private_list_type", + ), + private_list=_read_string_set(chat_section, "private_list"), + ban_user_id=_read_string_set(chat_section, "ban_user_id"), + ), + filters=NapCatFilterConfig( + ignore_self_message=_read_bool(filters_section, "ignore_self_message", True), + ), + ) + + def should_connect(self) -> bool: + """判断当前配置下是否应当启动连接。 + + Returns: + bool: 若插件连接已启用,则返回 ``True``。 + """ + return self.plugin.should_connect() + + def validate(self, logger: Any) -> bool: + """校验当前配置是否满足启动连接的前提条件。 + + Args: + logger: 插件日志对象。 + + Returns: + bool: 若配置满足启动连接的前提条件,则返回 ``True``。 + """ + config_version = self.plugin.config_version + if not config_version: + logger.error( + f"NapCat 适配器配置缺少 plugin.config_version,当前插件要求版本 {SUPPORTED_CONFIG_VERSION}" + ) + return False + + if config_version != SUPPORTED_CONFIG_VERSION: + logger.error( + "NapCat 适配器配置版本不兼容: " + f"当前为 {config_version},当前插件要求 {SUPPORTED_CONFIG_VERSION}" + ) + return False + + if not self.napcat_server.host: + logger.warning("NapCat 适配器已启用,但 napcat_server.host 为空") + return False + + if self.napcat_server.port <= 0: + logger.warning("NapCat 适配器已启用,但 napcat_server.port 不是正整数") + return False + + return True + + +def _as_mapping(value: Any) -> Dict[str, Any]: + """将任意值安全转换为字典。 + + Args: + value: 待转换的值。 + + Returns: + Dict[str, Any]: 若原值是映射,则返回普通字典;否则返回空字典。 + """ + return dict(value) if isinstance(value, Mapping) else {} + + +def _read_bool(mapping: Mapping[str, Any], key: str, default: bool) -> bool: + """安全读取布尔配置值。 + + Args: + mapping: 待读取的配置字典。 + key: 目标键名。 + default: 读取失败时的默认值。 + + Returns: + bool: 解析后的布尔值。 + """ + value = mapping.get(key, default) + return value if isinstance(value, bool) else default + + +def _read_string(mapping: Mapping[str, Any], key: str) -> str: + """安全读取字符串配置值。 + + Args: + mapping: 待读取的配置字典。 + key: 目标键名。 + + Returns: + str: 去除首尾空白后的字符串值。 + """ + value = mapping.get(key) + return "" if value is None else str(value).strip() + + +def _read_positive_float( + mapping: Mapping[str, Any], + key: str, + default: float, + logger: Any, + setting_name: str, +) -> float: + """安全读取正浮点数配置值。 + + Args: + mapping: 待读取的配置字典。 + key: 目标键名。 + default: 读取失败时的默认值。 + logger: 插件日志对象。 + setting_name: 用于日志输出的完整配置名。 + + Returns: + float: 合法的正浮点数;否则返回默认值。 + """ + value = mapping.get(key, default) + if isinstance(value, (int, float)) and float(value) > 0: + return float(value) + + if key in mapping: + logger.warning(f"NapCat 适配器配置项取值无效,已回退到默认值: {setting_name}={value!r},默认值为 {default}") + return default + + +def _read_positive_int( + mapping: Mapping[str, Any], + key: str, + default: int, + logger: Any, + setting_name: str, +) -> int: + """安全读取正整数配置值。 + + Args: + mapping: 待读取的配置字典。 + key: 目标键名。 + default: 读取失败时的默认值。 + logger: 插件日志对象。 + setting_name: 用于日志输出的完整配置名。 + + Returns: + int: 合法的正整数;否则返回默认值。 + """ + value = mapping.get(key, default) + if isinstance(value, int) and value > 0: + return value + + if isinstance(value, str) and value.isdigit() and int(value) > 0: + return int(value) + + if key in mapping: + logger.warning(f"NapCat 适配器配置项取值无效,已回退到默认值: {setting_name}={value!r},默认值为 {default}") + return default + + +def _read_list_mode( + mapping: Mapping[str, Any], + key: str, + default: str, + logger: Any, + setting_name: str, +) -> str: + """安全读取名单模式配置值。 + + Args: + mapping: 待读取的配置字典。 + key: 目标键名。 + default: 读取失败时的默认值。 + logger: 插件日志对象。 + setting_name: 用于日志输出的完整配置名。 + + Returns: + str: 合法的名单模式字符串。 + """ + value = mapping.get(key, default) + if isinstance(value, str): + normalized_value = value.strip() + if normalized_value in {"whitelist", "blacklist"}: + return normalized_value + + if key in mapping: + logger.warning(f"NapCat 适配器配置项取值无效,已回退到默认值: {setting_name}={value!r},默认值为 {default}") + return default + + +def _read_string_set(mapping: Mapping[str, Any], key: str) -> Set[str]: + """安全读取字符串集合配置值。 + + Args: + mapping: 待读取的配置字典。 + key: 目标键名。 + + Returns: + Set[str]: 规范化后的字符串集合。 + """ + value = mapping.get(key, []) + if not isinstance(value, list): + return set() + + normalized_values: Set[str] = set() + for item in value: + item_text = "" if item is None else str(item).strip() + if item_text: + normalized_values.add(item_text) + return normalized_values + + +def _read_legacy_host_port( + server_section: Mapping[str, Any], + legacy_connection_section: Mapping[str, Any], + logger: Any, +) -> Tuple[str, Optional[int]]: + """从旧版 ``ws_url`` 配置中提取主机与端口。 + + Args: + server_section: 新版 ``napcat_server`` 配置段。 + legacy_connection_section: 旧版 ``connection`` 配置段。 + logger: 插件日志对象。 + + Returns: + Tuple[str, Optional[int]]: 解析到的主机与端口;若未找到,则返回空主机与 ``None``。 + """ + legacy_ws_url = _read_string(server_section, "ws_url") or _read_string(legacy_connection_section, "ws_url") + if not legacy_ws_url: + return "", None + + parsed_url = urlparse(legacy_ws_url) + parsed_host = parsed_url.hostname or "" + parsed_port = parsed_url.port + + logger.warning( + "NapCat 适配器检测到旧版 ws_url 配置,已临时兼容解析,请尽快迁移到 napcat_server.host/port" + ) + if parsed_url.path not in {"", "/"}: + logger.warning("NapCat 适配器旧版 ws_url 包含路径,新的 napcat_server 配置不会保留该路径") + + return parsed_host, parsed_port diff --git a/src/plugins/built_in/napcat_adapter/constants.py b/src/plugins/built_in/napcat_adapter/constants.py new file mode 100644 index 00000000..bdddde6f --- /dev/null +++ b/src/plugins/built_in/napcat_adapter/constants.py @@ -0,0 +1,9 @@ +"""NapCat 内置适配器共享常量。""" + +SUPPORTED_CONFIG_VERSION = "0.1.0" +DEFAULT_NAPCAT_HOST = "127.0.0.1" +DEFAULT_NAPCAT_PORT = 3001 +DEFAULT_RECONNECT_DELAY_SEC = 5.0 +DEFAULT_HEARTBEAT_INTERVAL_SEC = 30.0 +DEFAULT_ACTION_TIMEOUT_SEC = 15.0 +DEFAULT_CHAT_LIST_TYPE = "whitelist" diff --git a/src/plugins/built_in/napcat_adapter/filters.py b/src/plugins/built_in/napcat_adapter/filters.py new file mode 100644 index 00000000..141cda85 --- /dev/null +++ b/src/plugins/built_in/napcat_adapter/filters.py @@ -0,0 +1,68 @@ +"""NapCat 入站消息过滤。""" + +from typing import Any, Set + +from napcat_adapter.config import NapCatChatConfig + + +class NapCatChatFilter: + """NapCat 聊天名单过滤器。""" + + def __init__(self, logger: Any) -> None: + """初始化聊天名单过滤器。 + + Args: + logger: 插件日志对象。 + """ + self._logger = logger + + def is_inbound_chat_allowed( + self, + sender_user_id: str, + group_id: str, + chat_config: NapCatChatConfig, + ) -> bool: + """检查入站消息是否通过聊天名单过滤。 + + Args: + sender_user_id: 发送者用户 ID。 + group_id: 群聊 ID;私聊时为空字符串。 + chat_config: 当前生效的聊天配置。 + + Returns: + bool: 若消息允许继续进入 Host,则返回 ``True``。 + """ + if sender_user_id in chat_config.ban_user_id: + self._logger.warning(f"NapCat 用户 {sender_user_id} 在全局禁止名单中,消息被丢弃") + return False + + if group_id: + if not self._is_id_allowed_by_list_policy(group_id, chat_config.group_list_type, chat_config.group_list): + self._logger.warning(f"NapCat 群聊 {group_id} 未通过聊天名单过滤,消息被丢弃") + return False + return True + + if not self._is_id_allowed_by_list_policy( + sender_user_id, + chat_config.private_list_type, + chat_config.private_list, + ): + self._logger.warning(f"NapCat 私聊用户 {sender_user_id} 未通过聊天名单过滤,消息被丢弃") + return False + return True + + @staticmethod + def _is_id_allowed_by_list_policy(target_id: str, list_type: str, configured_ids: Set[str]) -> bool: + """根据白名单或黑名单规则判断目标 ID 是否允许通过。 + + Args: + target_id: 待检查的目标 ID。 + list_type: 名单模式,仅支持 ``whitelist`` 或 ``blacklist``。 + configured_ids: 配置中的 ID 集合。 + + Returns: + bool: 若目标 ID 允许通过,则返回 ``True``。 + """ + if list_type == "whitelist": + return target_id in configured_ids + return target_id not in configured_ids diff --git a/src/plugins/built_in/napcat_adapter/plugin.py b/src/plugins/built_in/napcat_adapter/plugin.py index c8bb837b..b1e9bc8c 100644 --- a/src/plugins/built_in/napcat_adapter/plugin.py +++ b/src/plugins/built_in/napcat_adapter/plugin.py @@ -1,6 +1,6 @@ """内置 NapCat 适配器插件。 -当前实现是一个 MVP 版本,目标仅限于跑通基础消息收发链路: +当前实现维持 MVP 范围,目标是跑通基础消息收发链路: 1. 作为客户端连接 NapCat / OneBot v11 WebSocket 服务。 2. 将入站消息事件转换为 Host 侧的 ``MessageDict``。 3. 将 Host 出站消息转换为 OneBot 动作并发送。 @@ -8,45 +8,26 @@ 当前范围刻意收敛为: - 单连接 - 文本、@、reply 基础转发 -- 暂不处理 ``notice`` / ``meta_event`` +- 暂不处理 ``notice`` / ``meta_event`` 的完整语义归一化 - 暂不支持图片、语音、文件等复杂媒体 """ from __future__ import annotations -from typing import Any, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, cast -from uuid import uuid4 +from typing import Any, Dict, Mapping, Optional import asyncio -import contextlib -import json -import time from maibot_sdk import Adapter, MaiBotPlugin -if TYPE_CHECKING: - from aiohttp import ClientWebSocketResponse as AiohttpClientWebSocketResponse - -try: - from aiohttp import ClientSession, ClientTimeout, ClientWebSocketResponse, WSMsgType - - AIOHTTP_AVAILABLE = True -except ImportError: - ClientSession = cast(Any, None) - ClientTimeout = cast(Any, None) - ClientWebSocketResponse = cast(Any, None) - WSMsgType = cast(Any, None) - AIOHTTP_AVAILABLE = False - -if not TYPE_CHECKING: - AiohttpClientWebSocketResponse = Any - - -SUPPORTED_CONFIG_VERSION = "0.1.0" -DEFAULT_RECONNECT_DELAY_SEC = 5.0 -DEFAULT_HEARTBEAT_SEC = 30.0 -DEFAULT_ACTION_TIMEOUT_SEC = 15.0 -DEFAULT_CHAT_LIST_TYPE = "whitelist" +from napcat_adapter.codec_inbound import NapCatInboundCodec +from napcat_adapter.codec_outbound import NapCatOutboundCodec +from napcat_adapter.config import NapCatPluginSettings +from napcat_adapter.filters import NapCatChatFilter +from napcat_adapter.qq_notice import NapCatNoticeCodec +from napcat_adapter.qq_queries import NapCatQueryService +from napcat_adapter.runtime_state import NapCatRuntimeStateManager +from napcat_adapter.transport import NapCatTransportClient @Adapter(platform="qq", protocol="napcat", send_method="send_to_platform") @@ -57,14 +38,14 @@ class NapCatAdapterPlugin(MaiBotPlugin): """初始化 NapCat 适配器插件实例。""" super().__init__() self._plugin_config: Dict[str, Any] = {} - self._connection_task: Optional[asyncio.Task[None]] = None - self._pending_actions: Dict[str, asyncio.Future[Dict[str, Any]]] = {} - self._background_tasks: Set[asyncio.Task[Any]] = set() - self._reported_account_id: Optional[str] = None - self._reported_scope: Optional[str] = None - self._runtime_state_connected: bool = False - self._send_lock = asyncio.Lock() - self._ws: Optional[AiohttpClientWebSocketResponse] = None + self._settings: Optional[NapCatPluginSettings] = None + self._inbound_codec: Optional[NapCatInboundCodec] = None + self._outbound_codec = NapCatOutboundCodec() + self._chat_filter: Optional[NapCatChatFilter] = None + self._query_service: Optional[NapCatQueryService] = None + self._notice_codec: Optional[NapCatNoticeCodec] = None + self._runtime_state: Optional[NapCatRuntimeStateManager] = None + self._transport: Optional[NapCatTransportClient] = None def set_plugin_config(self, config: Dict[str, Any]) -> None: """设置插件配置内容。 @@ -79,9 +60,8 @@ class NapCatAdapterPlugin(MaiBotPlugin): await self._restart_connection_if_needed() async def on_unload(self) -> None: - """在插件卸载时关闭连接并清理后台任务。""" + """在插件卸载时关闭连接并清理运行时状态。""" await self._stop_connection() - await self._cancel_background_tasks() async def on_config_update(self, new_config: Dict[str, Any], version: str) -> None: """在配置更新后重载连接状态。 @@ -116,13 +96,14 @@ class NapCatAdapterPlugin(MaiBotPlugin): del metadata del kwargs - ws = self._ws - if ws is None or ws.closed: - return {"success": False, "error": "NapCat is not connected"} + self._ensure_runtime_components() + transport = self._transport + if transport is None: + return {"success": False, "error": "NapCat transport is not initialized"} try: - action_name, params = self._build_outbound_action(message, route or {}) - response = await self._call_action(action_name, params) + action_name, params = self._outbound_codec.build_outbound_action(message, route or {}) + response = await transport.call_action(action_name, params) except Exception as exc: return {"success": False, "error": str(exc)} @@ -135,7 +116,7 @@ class NapCatAdapterPlugin(MaiBotPlugin): response_data = response.get("data", {}) external_message_id = "" - if isinstance(response_data, dict): + if isinstance(response_data, Mapping): external_message_id = str(response_data.get("message_id") or "") return { @@ -144,143 +125,109 @@ class NapCatAdapterPlugin(MaiBotPlugin): "metadata": {"action": action_name}, } - async def _restart_connection_if_needed(self) -> None: - """根据当前配置重启连接循环。""" - await self._stop_connection() - if not self._should_connect(): - self.ctx.logger.info("NapCat 适配器保持空闲状态,因为插件或配置未启用") - return - if not self._validate_current_config(): - return - if not AIOHTTP_AVAILABLE: - self.ctx.logger.error("NapCat 适配器依赖 aiohttp,但当前环境未安装该依赖") - return - self._connection_task = asyncio.create_task(self._connection_loop(), name="napcat_adapter.connection") + def _ensure_runtime_components(self) -> None: + """确保运行时依赖对象已经完成初始化。""" + if self._chat_filter is None: + self._chat_filter = NapCatChatFilter(self.ctx.logger) - async def _stop_connection(self) -> None: - """停止当前连接并让所有等待中的动作失败返回。""" - connection_task = self._connection_task - self._connection_task = None - - ws = self._ws - if ws is not None and not ws.closed: - with contextlib.suppress(Exception): - await ws.close() - self._ws = None - - if connection_task is not None: - connection_task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await connection_task - - await self._report_adapter_disconnected() - self._fail_pending_actions("NapCat connection closed") - - async def _cancel_background_tasks(self) -> None: - """取消所有仍在运行的入站后台任务。""" - background_tasks = list(self._background_tasks) - for task in background_tasks: - task.cancel() - if background_tasks: - with contextlib.suppress(Exception): - await asyncio.gather(*background_tasks, return_exceptions=True) - self._background_tasks.clear() - - async def _connection_loop(self) -> None: - """维护单个 WebSocket 连接,并在断开后按配置重连。""" - assert ClientSession is not None - assert ClientTimeout is not None - - while self._should_connect(): - ws_url = self._get_string(self._connection_config(), "ws_url") - if not ws_url: - self.ctx.logger.warning("NapCat 适配器已启用,但 connection.ws_url 为空") - return - - headers = self._build_headers() - timeout = ClientTimeout(total=None, connect=10) - heartbeat = self._get_positive_float(self._connection_config(), "heartbeat_sec", DEFAULT_HEARTBEAT_SEC) - - try: - async with ClientSession(headers=headers, timeout=timeout) as session: - async with session.ws_connect(ws_url, heartbeat=heartbeat or None) as ws: - self._ws = ws - self.ctx.logger.info(f"NapCat 适配器已连接: {ws_url}") - await self._receive_loop(ws) - except asyncio.CancelledError: - raise - except Exception as exc: - self.ctx.logger.warning(f"NapCat 适配器连接失败: {exc}") - finally: - self._ws = None - await self._report_adapter_disconnected() - self._fail_pending_actions("NapCat connection interrupted") - - if not self._should_connect(): - break - - await asyncio.sleep( - self._get_positive_float( - self._connection_config(), - "reconnect_delay_sec", - DEFAULT_RECONNECT_DELAY_SEC, - ) + if self._transport is None: + self._transport = NapCatTransportClient( + logger=self.ctx.logger, + on_connection_opened=self._bootstrap_adapter_runtime_state, + on_connection_closed=self._handle_transport_disconnected, + on_payload=self._handle_transport_payload, ) - async def _receive_loop(self, ws: AiohttpClientWebSocketResponse) -> None: - """持续消费 WebSocket 消息并分发处理。 + if self._query_service is None: + self._query_service = NapCatQueryService(self.ctx.logger, self._transport) + + if self._inbound_codec is None: + self._inbound_codec = NapCatInboundCodec(self.ctx.logger, self._query_service) + + if self._notice_codec is None: + self._notice_codec = NapCatNoticeCodec(self.ctx.logger, self._query_service) + + if self._runtime_state is None: + self._runtime_state = NapCatRuntimeStateManager(self.ctx.adapter, self.ctx.logger) + + def _reload_settings(self) -> NapCatPluginSettings: + """重新解析当前插件配置。 + + Returns: + NapCatPluginSettings: 最新的规范化配置。 + """ + self._settings = NapCatPluginSettings.from_mapping(self._plugin_config, self.ctx.logger) + return self._settings + + async def _restart_connection_if_needed(self) -> None: + """根据当前配置重启连接循环。""" + self._ensure_runtime_components() + settings = self._reload_settings() + + await self._stop_connection() + if not settings.should_connect(): + self.ctx.logger.info("NapCat 适配器保持空闲状态,因为插件或配置未启用") + return + if not settings.validate(self.ctx.logger): + return + + transport = self._transport + assert transport is not None + if not transport.is_available(): + self.ctx.logger.error("NapCat 适配器依赖 aiohttp,但当前环境未安装该依赖") + return + + transport.configure(settings.napcat_server) + await transport.start() + + async def _stop_connection(self) -> None: + """停止当前连接。""" + transport = self._transport + if transport is not None: + await transport.stop() + return + + runtime_state = self._runtime_state + if runtime_state is not None: + await runtime_state.report_disconnected() + + async def _handle_transport_payload(self, payload: Dict[str, Any]) -> None: + """处理来自传输层的非 echo 载荷。 Args: - ws: 当前活跃的 WebSocket 连接对象。 + payload: NapCat 推送的原始事件数据。 """ - assert WSMsgType is not None - - bootstrap_task = asyncio.create_task( - self._bootstrap_adapter_runtime_state(), - name="napcat_adapter.bootstrap", - ) - self._background_tasks.add(bootstrap_task) - bootstrap_task.add_done_callback(self._background_tasks.discard) - - try: - async for ws_message in ws: - if ws_message.type != WSMsgType.TEXT: - if ws_message.type in {WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.ERROR}: - break - continue - - payload = self._parse_json_message(ws_message.data) - if payload is None: - continue - - if echo_id := str(payload.get("echo") or "").strip(): - self._resolve_pending_action(echo_id, payload) - continue - - if str(payload.get("post_type") or "").strip() != "message": - continue - - task = asyncio.create_task(self._handle_inbound_message(payload), name="napcat_adapter.inbound") - self._background_tasks.add(task) - task.add_done_callback(self._background_tasks.discard) - finally: - if not bootstrap_task.done(): - bootstrap_task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await bootstrap_task + post_type = str(payload.get("post_type") or "").strip() + if post_type == "message": + await self._handle_inbound_message(payload) + return + if post_type == "notice": + await self._handle_notice_event(payload) + return + if post_type == "meta_event": + await self._handle_meta_event(payload) async def _handle_inbound_message(self, payload: Dict[str, Any]) -> None: """处理单条 NapCat 入站消息并注入 Host。 Args: - payload: NapCat / OneBot 推送的原始事件数据。 + payload: NapCat / OneBot 推送的原始消息事件。 """ + self._ensure_runtime_components() + settings = self._settings or self._reload_settings() + chat_filter = self._chat_filter + inbound_codec = self._inbound_codec + runtime_state = self._runtime_state + assert chat_filter is not None + assert inbound_codec is not None + assert runtime_state is not None + self_id = str(payload.get("self_id") or "").strip() if self_id: - await self._report_adapter_connected(self_id) + await runtime_state.report_connected(self_id, settings.napcat_server) sender = payload.get("sender", {}) - if not isinstance(sender, dict): + if not isinstance(sender, Mapping): sender = {} sender_user_id = str(payload.get("user_id") or sender.get("user_id") or "").strip() @@ -288,17 +235,17 @@ class NapCatAdapterPlugin(MaiBotPlugin): return group_id = str(payload.get("group_id") or "").strip() - if self_id and sender_user_id == self_id and self._get_bool(self._filters_config(), "ignore_self_message", True): + if self_id and sender_user_id == self_id and settings.filters.ignore_self_message: return - if not self._is_inbound_chat_allowed(sender_user_id, group_id): + if not chat_filter.is_inbound_chat_allowed(sender_user_id, group_id, settings.chat): return - message_dict = self._build_inbound_message_dict(payload, self_id, sender_user_id, sender) + message_dict = await inbound_codec.build_message_dict(payload, self_id, sender_user_id, sender) route_metadata: Dict[str, Any] = {} if self_id: route_metadata["self_id"] = self_id - if connection_id := self._get_string(self._connection_config(), "connection_id"): - route_metadata["connection_id"] = connection_id + if settings.napcat_server.connection_id: + route_metadata["connection_id"] = settings.napcat_server.connection_id external_message_id = str(payload.get("message_id") or "").strip() accepted = await self.ctx.adapter.receive_external_message( @@ -310,305 +257,78 @@ class NapCatAdapterPlugin(MaiBotPlugin): if not accepted: self.ctx.logger.debug(f"Host 丢弃了 NapCat 入站消息: {external_message_id or '无消息 ID'}") - def _build_inbound_message_dict( - self, - payload: Dict[str, Any], - self_id: str, - sender_user_id: str, - sender: Dict[str, Any], - ) -> Dict[str, Any]: - """构造 Host 侧可接受的 ``MessageDict``。 + async def _handle_notice_event(self, payload: Dict[str, Any]) -> None: + """处理 NapCat ``notice`` 事件并注入 Host。 Args: - payload: NapCat 原始消息事件。 - self_id: 当前机器人账号 ID。 - sender_user_id: 发送者用户 ID。 - sender: 发送者信息字典。 - - Returns: - Dict[str, Any]: 规范化后的 ``MessageDict``。 + payload: NapCat 推送的通知事件。 """ - message_type = str(payload.get("message_type") or "").strip() or "private" - group_id = str(payload.get("group_id") or "").strip() - group_name = str(payload.get("group_name") or "").strip() or (f"group_{group_id}" if group_id else "") - user_nickname = str(sender.get("nickname") or sender.get("card") or sender_user_id).strip() or sender_user_id - user_cardname = str(sender.get("card") or "").strip() or None + self._ensure_runtime_components() + notice_codec = self._notice_codec + runtime_state = self._runtime_state + settings = self._settings or self._reload_settings() + assert notice_codec is not None + assert runtime_state is not None - raw_message, is_at = self._convert_inbound_segments(payload.get("message"), self_id) - raw_message_text = str(payload.get("raw_message") or "").strip() - if not raw_message: - raw_message = [{"type": "text", "data": raw_message_text or "[unsupported]"}] + self_id = str(payload.get("self_id") or "").strip() + if self_id: + await runtime_state.report_connected(self_id, settings.napcat_server) - plain_text = self._build_plain_text(raw_message, raw_message_text) - timestamp_seconds = payload.get("time") - if not isinstance(timestamp_seconds, (int, float)): - timestamp_seconds = time.time() - - additional_config: Dict[str, Any] = {"self_id": self_id, "napcat_message_type": message_type} - if group_id: - additional_config["platform_io_target_group_id"] = group_id - else: - additional_config["platform_io_target_user_id"] = sender_user_id - - message_info: Dict[str, Any] = { - "user_info": { - "user_id": sender_user_id, - "user_nickname": user_nickname, - "user_cardname": user_cardname, - }, - "additional_config": additional_config, - } - if group_id: - message_info["group_info"] = {"group_id": group_id, "group_name": group_name} - - message_id = str(payload.get("message_id") or f"napcat-{uuid4().hex}").strip() - return { - "message_id": message_id, - "timestamp": str(float(timestamp_seconds)), - "platform": "qq", - "message_info": message_info, - "raw_message": raw_message, - "is_mentioned": is_at, - "is_at": is_at, - "is_emoji": False, - "is_picture": False, - "is_command": plain_text.startswith("/"), - "is_notify": False, - "session_id": "", - "processed_plain_text": plain_text, - "display_message": plain_text, - } - - def _convert_inbound_segments(self, message_payload: Any, self_id: str) -> Tuple[List[Dict[str, Any]], bool]: - """将 OneBot 消息段转换为 Host 消息段结构。 - - Args: - message_payload: OneBot 原始 ``message`` 字段。 - self_id: 当前机器人账号 ID。 - - Returns: - Tuple[List[Dict[str, Any]], bool]: 转换后的消息段列表,以及是否 @ 到当前机器人。 - """ - if isinstance(message_payload, str): - normalized_text = message_payload.strip() - return ([{"type": "text", "data": normalized_text}] if normalized_text else []), False - - if not isinstance(message_payload, list): - return [], False - - converted_segments: List[Dict[str, Any]] = [] - is_at = False - placeholder_texts = { - "face": "[face]", - "file": "[file]", - "image": "[image]", - "json": "[json]", - "record": "[voice]", - "video": "[video]", - "xml": "[xml]", - } - - for segment in message_payload: - if not isinstance(segment, dict): - continue - - segment_type = str(segment.get("type") or "").strip() - segment_data = segment.get("data", {}) - if not isinstance(segment_data, dict): - segment_data = {} - - if segment_type == "text": - if text_value := str(segment_data.get("text") or ""): - converted_segments.append({"type": "text", "data": text_value}) - continue - - if segment_type == "at": - if target_user_id := str(segment_data.get("qq") or "").strip(): - converted_segments.append( - { - "type": "at", - "data": { - "target_user_id": target_user_id, - "target_user_nickname": None, - "target_user_cardname": None, - }, - } - ) - if self_id and target_user_id == self_id: - is_at = True - continue - - if segment_type == "reply": - if target_message_id := str(segment_data.get("id") or "").strip(): - converted_segments.append({"type": "reply", "data": target_message_id}) - continue - - if placeholder := placeholder_texts.get(segment_type): - converted_segments.append({"type": "text", "data": placeholder}) - - return converted_segments, is_at - - def _build_outbound_action( - self, - message: Dict[str, Any], - route: Dict[str, Any], - ) -> Tuple[str, Dict[str, Any]]: - """为 Host 出站消息构造 OneBot 动作。 - - Args: - message: Host 侧标准 ``MessageDict``。 - route: Platform IO 路由信息。 - - Returns: - Tuple[str, Dict[str, Any]]: 动作名称与参数字典。 - """ - message_info = message.get("message_info", {}) - if not isinstance(message_info, dict): - message_info = {} - - group_info = message_info.get("group_info", {}) - if not isinstance(group_info, dict): - group_info = {} - - additional_config = message_info.get("additional_config", {}) - if not isinstance(additional_config, dict): - additional_config = {} - - raw_message = message.get("raw_message", []) - segments = self._convert_outbound_segments(raw_message) - - if target_group_id := str( - group_info.get("group_id") or additional_config.get("platform_io_target_group_id") or "" - ).strip(): - return "send_group_msg", {"group_id": target_group_id, "message": segments} - - if not ( - target_user_id := str( - additional_config.get("platform_io_target_user_id") - or additional_config.get("target_user_id") - or route.get("target_user_id") - or "" - ).strip() - ): - raise ValueError("Outbound private message is missing target_user_id") - - return "send_private_msg", {"message": segments, "user_id": target_user_id} - - def _convert_outbound_segments(self, raw_message: Any) -> List[Dict[str, Any]]: - """将 Host 消息段转换为 OneBot 消息段。 - - Args: - raw_message: Host 侧 ``raw_message`` 字段。 - - Returns: - List[Dict[str, Any]]: OneBot 消息段列表。 - """ - if not isinstance(raw_message, list): - return [{"type": "text", "data": {"text": ""}}] - - outbound_segments: List[Dict[str, Any]] = [] - for item in raw_message: - if not isinstance(item, dict): - continue - - item_type = str(item.get("type") or "").strip() - item_data = item.get("data") - - if item_type == "text": - text_value = str(item_data or "") - outbound_segments.append({"type": "text", "data": {"text": text_value}}) - continue - - if item_type == "at" and isinstance(item_data, dict): - if target_user_id := str(item_data.get("target_user_id") or "").strip(): - outbound_segments.append({"type": "at", "data": {"qq": target_user_id}}) - continue - - if item_type == "reply": - if target_message_id := str(item_data or "").strip(): - outbound_segments.append({"type": "reply", "data": {"id": target_message_id}}) - continue - - fallback_text = f"[unsupported:{item_type or 'unknown'}]" - outbound_segments.append({"type": "text", "data": {"text": fallback_text}}) - - if not outbound_segments: - outbound_segments.append({"type": "text", "data": {"text": ""}}) - return outbound_segments - - async def _call_action(self, action_name: str, params: Dict[str, Any]) -> Dict[str, Any]: - """发送 OneBot 动作并等待对应的 echo 响应。 - - Args: - action_name: OneBot 动作名称。 - params: 动作参数。 - - Returns: - Dict[str, Any]: NapCat 返回的原始响应字典。 - """ - ws = self._ws - if ws is None or ws.closed: - raise RuntimeError("NapCat is not connected") - - echo_id = uuid4().hex - loop = asyncio.get_running_loop() - response_future: asyncio.Future[Dict[str, Any]] = loop.create_future() - self._pending_actions[echo_id] = response_future - - request_payload = {"action": action_name, "params": params, "echo": echo_id} - try: - async with self._send_lock: - await ws.send_str(json.dumps(request_payload, ensure_ascii=False)) - timeout_seconds = self._get_positive_float( - self._connection_config(), - "action_timeout_sec", - DEFAULT_ACTION_TIMEOUT_SEC, - ) - return await asyncio.wait_for(response_future, timeout=timeout_seconds) - finally: - self._pending_actions.pop(echo_id, None) - - def _resolve_pending_action(self, echo_id: str, payload: Dict[str, Any]) -> None: - """解析等待中的动作响应。 - - Args: - echo_id: 动作请求对应的 echo 标识。 - payload: NapCat 返回的响应载荷。 - """ - response_future = self._pending_actions.get(echo_id) - if response_future is None or response_future.done(): + message_dict = await notice_codec.build_notice_message_dict(payload) + if message_dict is None: return - response_future.set_result(payload) - def _fail_pending_actions(self, error_message: str) -> None: - """让所有等待中的动作以异常方式结束。 + route_metadata: Dict[str, Any] = {} + if self_id: + route_metadata["self_id"] = self_id + if settings.napcat_server.connection_id: + route_metadata["connection_id"] = settings.napcat_server.connection_id + + external_message_id = str(payload.get("message_id") or payload.get("notice_type") or "").strip() + accepted = await self.ctx.adapter.receive_external_message( + message_dict, + route_metadata=route_metadata, + external_message_id=external_message_id or None, + dedupe_key=external_message_id or None, + ) + if not accepted: + self.ctx.logger.debug(f"Host 丢弃了 NapCat 通知事件: {external_message_id or '无消息 ID'}") + + async def _handle_meta_event(self, payload: Dict[str, Any]) -> None: + """处理 NapCat ``meta_event`` 事件。 Args: - error_message: 写入异常中的错误信息。 + payload: NapCat 推送的元事件。 """ - for response_future in self._pending_actions.values(): - if not response_future.done(): - response_future.set_exception(RuntimeError(error_message)) - self._pending_actions.clear() + self._ensure_runtime_components() + notice_codec = self._notice_codec + runtime_state = self._runtime_state + settings = self._settings or self._reload_settings() + assert notice_codec is not None + assert runtime_state is not None + + self_id = str(payload.get("self_id") or "").strip() + if self_id: + await runtime_state.report_connected(self_id, settings.napcat_server) + + await notice_codec.handle_meta_event(payload) async def _bootstrap_adapter_runtime_state(self) -> None: - """在连接建立后主动获取账号信息并激活适配器路由。 + """在连接建立后主动获取账号信息并激活适配器路由。""" + transport = self._transport + query_service = self._query_service + runtime_state = self._runtime_state + settings = self._settings or self._reload_settings() + if transport is None or query_service is None or runtime_state is None: + return - 该步骤会在 WebSocket 接收循环启动后异步执行,确保 `_call_action()` - 发出的 `get_login_info` 请求能够被同一连接上的接收循环消费到 echo - 响应,从而在真正收到业务消息前就完成 Host 侧 route 激活。 - """ max_attempts = 3 last_error: Optional[Exception] = None for attempt in range(1, max_attempts + 1): - ws = self._ws - if ws is None or ws.closed: - return - try: - response = await self._call_action("get_login_info", {}) - self_id = self._extract_self_id_from_login_response(response) - await self._report_adapter_connected(self_id) + login_info = await query_service.get_login_info() + self_id = self._extract_self_id_from_login_response(login_info) + await runtime_state.report_connected(self_id, settings.napcat_server) return except asyncio.CancelledError: raise @@ -623,430 +343,33 @@ class NapCatAdapterPlugin(MaiBotPlugin): if last_error is not None: self.ctx.logger.error(f"NapCat 适配器未能完成路由激活,连接将保持只接收状态: {last_error}") + async def _handle_transport_disconnected(self) -> None: + """处理传输层断开事件。""" + runtime_state = self._runtime_state + if runtime_state is not None: + await runtime_state.report_disconnected() + @staticmethod - def _extract_self_id_from_login_response(response: Dict[str, Any]) -> str: - """从 `get_login_info` 响应中提取当前账号 ID。 + def _extract_self_id_from_login_response(response: Optional[Dict[str, Any]]) -> str: + """从 ``get_login_info`` 查询结果中提取当前账号 ID。 Args: - response: NapCat 返回的原始动作响应。 + response: NapCat 返回的登录信息字典。 Returns: - str: 规范化后的 `self_id` 字符串。 + str: 规范化后的账号 ID 字符串。 Raises: ValueError: 当响应中缺少有效账号 ID 时抛出。 """ - if str(response.get("status") or "").lower() != "ok": - raise ValueError(str(response.get("wording") or response.get("message") or "get_login_info failed")) - - response_data = response.get("data", {}) - if not isinstance(response_data, dict): + if not isinstance(response, Mapping): raise ValueError("get_login_info 响应缺少 data 字段") - self_id = str(response_data.get("user_id") or "").strip() + self_id = str(response.get("user_id") or "").strip() if not self_id: raise ValueError("get_login_info 响应缺少有效的 user_id") return self_id - async def _report_adapter_connected(self, account_id: str) -> None: - """向 Host 上报当前连接已就绪。 - - Args: - account_id: 当前 NapCat 连接对应的机器人账号 ID。 - """ - normalized_account_id = str(account_id).strip() - if not normalized_account_id: - return - - scope = self._get_string(self._connection_config(), "connection_id").strip() - if ( - self._runtime_state_connected - and self._reported_account_id == normalized_account_id - and self._reported_scope == (scope or None) - ): - return - - accepted = False - try: - accepted = await self.ctx.adapter.update_runtime_state( - connected=True, - account_id=normalized_account_id, - scope=scope, - metadata={"ws_url": self._get_string(self._connection_config(), "ws_url")}, - ) - except Exception as exc: - self.ctx.logger.warning(f"NapCat 适配器上报连接就绪状态失败: {exc}") - return - - if not accepted: - self.ctx.logger.warning("NapCat 适配器连接已建立,但 Host 未接受运行时状态更新") - return - - self._runtime_state_connected = True - self._reported_account_id = normalized_account_id - self._reported_scope = scope or None - self.ctx.logger.info( - f"NapCat 适配器已激活路由: platform=qq account_id={normalized_account_id} " - f"scope={self._reported_scope or '*'}" - ) - - async def _report_adapter_disconnected(self) -> None: - """向 Host 上报当前连接已断开,并撤销适配器路由。""" - if not self._runtime_state_connected: - self._reported_account_id = None - self._reported_scope = None - return - - try: - await self.ctx.adapter.update_runtime_state(connected=False) - except Exception as exc: - self.ctx.logger.warning(f"NapCat 适配器上报断开状态失败: {exc}") - finally: - self._runtime_state_connected = False - self._reported_account_id = None - self._reported_scope = None - - def _build_headers(self) -> Dict[str, str]: - """构造连接 NapCat 所需的请求头。 - - Returns: - Dict[str, str]: WebSocket 握手请求头。 - """ - access_token = self._get_string(self._connection_config(), "access_token") - return {"Authorization": f"Bearer {access_token}"} if access_token else {} - - def _parse_json_message(self, data: Any) -> Optional[Dict[str, Any]]: - """解析 WebSocket 文本消息中的 JSON 数据。 - - Args: - data: WebSocket 收到的原始文本数据。 - - Returns: - Optional[Dict[str, Any]]: 成功时返回字典,失败时返回 ``None``。 - """ - try: - payload = json.loads(str(data)) - except Exception as exc: - self.ctx.logger.warning(f"NapCat 适配器解析 JSON 载荷失败: {exc}") - return None - - return payload if isinstance(payload, dict) else None - - def _build_plain_text(self, raw_message: List[Dict[str, Any]], fallback_text: str) -> str: - """从标准消息段中提取可展示的纯文本。 - - Args: - raw_message: 标准化后的消息段列表。 - fallback_text: 当无法拼出文本时使用的回退文本。 - - Returns: - str: 用于 Host 展示和命令判断的纯文本内容。 - """ - plain_text_parts: List[str] = [] - for item in raw_message: - if not isinstance(item, dict): - continue - item_type = str(item.get("type") or "").strip() - item_data = item.get("data") - if item_type == "text": - plain_text_parts.append(str(item_data or "")) - elif item_type == "at" and isinstance(item_data, dict): - plain_text_parts.append(f"@{item_data.get('target_user_id') or ''}") - elif item_type == "reply": - plain_text_parts.append("[reply]") - - plain_text = "".join(part for part in plain_text_parts if part).strip() - return plain_text or fallback_text or "[unsupported]" - - def _plugin_section(self) -> Dict[str, Any]: - """读取插件配置中的 ``plugin`` 段。 - - Returns: - Dict[str, Any]: ``plugin`` 配置字典。 - """ - plugin_section = self._plugin_config.get("plugin", {}) - return plugin_section if isinstance(plugin_section, dict) else {} - - def _connection_config(self) -> Dict[str, Any]: - """读取插件配置中的 ``connection`` 段。 - - Returns: - Dict[str, Any]: ``connection`` 配置字典。 - """ - connection_config = self._plugin_config.get("connection", {}) - return connection_config if isinstance(connection_config, dict) else {} - - def _filters_config(self) -> Dict[str, Any]: - """读取插件配置中的 ``filters`` 段。 - - Returns: - Dict[str, Any]: ``filters`` 配置字典。 - """ - filters_config = self._plugin_config.get("filters", {}) - return filters_config if isinstance(filters_config, dict) else {} - - def _chat_config(self) -> Dict[str, Any]: - """读取插件配置中的 ``chat`` 段。 - - Returns: - Dict[str, Any]: ``chat`` 配置字典。 - """ - chat_config = self._plugin_config.get("chat", {}) - return chat_config if isinstance(chat_config, dict) else {} - - def _is_inbound_chat_allowed(self, sender_user_id: str, group_id: str) -> bool: - """检查入站消息是否通过聊天名单过滤。 - - Args: - sender_user_id: 发送者用户 ID。 - group_id: 群聊 ID;私聊时为空字符串。 - - Returns: - bool: 若消息允许继续进入 Host,则返回 ``True``。 - """ - chat_config = self._chat_config() - banned_user_ids = self._get_string_list(chat_config, "ban_user_id") - if sender_user_id in banned_user_ids: - self.ctx.logger.warning(f"NapCat 用户 {sender_user_id} 在全局禁止名单中,消息被丢弃") - return False - - if group_id: - group_list_type = self._get_list_mode(chat_config, "group_list_type", DEFAULT_CHAT_LIST_TYPE) - group_id_list = self._get_string_list(chat_config, "group_list") - if not self._is_id_allowed_by_list_policy(group_id, group_list_type, group_id_list): - self.ctx.logger.warning(f"NapCat 群聊 {group_id} 未通过聊天名单过滤,消息被丢弃") - return False - return True - - private_list_type = self._get_list_mode(chat_config, "private_list_type", DEFAULT_CHAT_LIST_TYPE) - private_id_list = self._get_string_list(chat_config, "private_list") - if not self._is_id_allowed_by_list_policy(sender_user_id, private_list_type, private_id_list): - self.ctx.logger.warning(f"NapCat 私聊用户 {sender_user_id} 未通过聊天名单过滤,消息被丢弃") - return False - return True - - def _is_id_allowed_by_list_policy( - self, - target_id: str, - list_type: str, - configured_ids: Set[str], - ) -> bool: - """根据白名单或黑名单规则判断目标 ID 是否允许通过。 - - Args: - target_id: 待检查的目标 ID。 - list_type: 名单模式,仅支持 ``whitelist`` 或 ``blacklist``。 - configured_ids: 配置中的 ID 集合。 - - Returns: - bool: 若目标 ID 允许通过,则返回 ``True``。 - """ - if list_type == "whitelist": - return target_id in configured_ids - return target_id not in configured_ids - - def _validate_current_config(self) -> bool: - """校验当前配置是否满足启动连接的前提条件。 - - Returns: - bool: 配置可用于启动连接时返回 ``True``。 - """ - if not self._validate_plugin_config_version(): - return False - - connection_config = self._connection_config() - ws_url = self._get_string(connection_config, "ws_url") - if not ws_url: - self.ctx.logger.warning("NapCat 适配器已启用,但 connection.ws_url 为空") - return False - - self._validate_positive_float_setting( - connection_config, - "connection", - "reconnect_delay_sec", - DEFAULT_RECONNECT_DELAY_SEC, - ) - self._validate_positive_float_setting( - connection_config, - "connection", - "heartbeat_sec", - DEFAULT_HEARTBEAT_SEC, - ) - self._validate_positive_float_setting( - connection_config, - "connection", - "action_timeout_sec", - DEFAULT_ACTION_TIMEOUT_SEC, - ) - self._validate_list_mode_setting(self._chat_config(), "chat", "group_list_type", DEFAULT_CHAT_LIST_TYPE) - self._validate_list_mode_setting(self._chat_config(), "chat", "private_list_type", DEFAULT_CHAT_LIST_TYPE) - return True - - def _validate_plugin_config_version(self) -> bool: - """校验插件配置版本是否与当前实现兼容。 - - Returns: - bool: 版本兼容时返回 ``True``。 - """ - config_version = self._get_string(self._plugin_section(), "config_version") - if not config_version: - self.ctx.logger.error( - f"NapCat 适配器配置缺少 plugin.config_version,当前插件要求版本 {SUPPORTED_CONFIG_VERSION}" - ) - return False - - if config_version != SUPPORTED_CONFIG_VERSION: - self.ctx.logger.error( - "NapCat 适配器配置版本不兼容: " - f"当前为 {config_version},当前插件要求 {SUPPORTED_CONFIG_VERSION}" - ) - return False - - return True - - def _validate_positive_float_setting( - self, - mapping: Dict[str, Any], - section_name: str, - key: str, - default: float, - ) -> None: - """校验正浮点数配置项,并在非法时输出告警日志。 - - Args: - mapping: 待读取的配置字典。 - section_name: 当前配置段名称。 - key: 目标配置键名。 - default: 配置非法时实际使用的默认值。 - """ - value = mapping.get(key, default) - if isinstance(value, (int, float)) and float(value) > 0: - return - - self.ctx.logger.warning( - "NapCat 适配器配置项取值无效,已回退到默认值: " - f"{section_name}.{key}={value!r},默认值为 {default}" - ) - - def _validate_list_mode_setting( - self, - mapping: Dict[str, Any], - section_name: str, - key: str, - default: str, - ) -> None: - """校验名单模式配置项,并在非法时输出告警日志。 - - Args: - mapping: 待读取的配置字典。 - section_name: 当前配置段名称。 - key: 目标配置键名。 - default: 配置非法时实际使用的默认值。 - """ - value = mapping.get(key, default) - if isinstance(value, str) and value.strip() in {"whitelist", "blacklist"}: - return - - self.ctx.logger.warning( - "NapCat 适配器配置项取值无效,已回退到默认值: " - f"{section_name}.{key}={value!r},默认值为 {default}" - ) - - def _should_connect(self) -> bool: - """判断当前配置下是否应当启动连接。 - - Returns: - bool: 若启用了插件连接则返回 ``True``。 - """ - return self._get_bool(self._plugin_section(), "enabled", False) - - @staticmethod - def _get_bool(mapping: Dict[str, Any], key: str, default: bool) -> bool: - """安全读取布尔配置值。 - - Args: - mapping: 待读取的配置字典。 - key: 目标键名。 - default: 读取失败时的默认值。 - - Returns: - bool: 解析后的布尔值。 - """ - value = mapping.get(key, default) - return value if isinstance(value, bool) else default - - @staticmethod - def _get_positive_float(mapping: Dict[str, Any], key: str, default: float) -> float: - """安全读取正浮点数配置值。 - - Args: - mapping: 待读取的配置字典。 - key: 目标键名。 - default: 读取失败时的默认值。 - - Returns: - float: 合法的正浮点数;否则返回默认值。 - """ - value = mapping.get(key, default) - if isinstance(value, (int, float)) and float(value) > 0: - return float(value) - return default - - @staticmethod - def _get_string(mapping: Dict[str, Any], key: str) -> str: - """安全读取字符串配置值。 - - Args: - mapping: 待读取的配置字典。 - key: 目标键名。 - - Returns: - str: 去除首尾空白后的字符串值。 - """ - value = mapping.get(key) - return "" if value is None else str(value).strip() - - @staticmethod - def _get_list_mode(mapping: Dict[str, Any], key: str, default: str) -> str: - """安全读取名单模式配置值。 - - Args: - mapping: 待读取的配置字典。 - key: 目标键名。 - default: 读取失败时的默认值。 - - Returns: - str: 合法的名单模式字符串。 - """ - value = mapping.get(key, default) - if isinstance(value, str): - normalized_value = value.strip() - if normalized_value in {"whitelist", "blacklist"}: - return normalized_value - return default - - @staticmethod - def _get_string_list(mapping: Dict[str, Any], key: str) -> Set[str]: - """安全读取 ID 列表配置值。 - - Args: - mapping: 待读取的配置字典。 - key: 目标键名。 - - Returns: - Set[str]: 去重后的字符串 ID 集合。 - """ - value = mapping.get(key, []) - if not isinstance(value, list): - return set() - - normalized_values: Set[str] = set() - for item in value: - item_text = "" if item is None else str(item).strip() - if item_text: - normalized_values.add(item_text) - return normalized_values - def create_plugin() -> NapCatAdapterPlugin: """创建插件实例。 diff --git a/src/plugins/built_in/napcat_adapter/qq_notice.py b/src/plugins/built_in/napcat_adapter/qq_notice.py new file mode 100644 index 00000000..f577cf98 --- /dev/null +++ b/src/plugins/built_in/napcat_adapter/qq_notice.py @@ -0,0 +1,224 @@ +"""NapCat QQ 平台通知与元事件处理。""" + +from typing import Any, Dict, Mapping, Optional +from uuid import uuid4 + +import time + +from napcat_adapter.qq_queries import NapCatQueryService + + +class NapCatNoticeCodec: + """NapCat QQ 通知事件编码器。""" + + def __init__(self, logger: Any, query_service: NapCatQueryService) -> None: + """初始化通知事件编码器。 + + Args: + logger: 插件日志对象。 + query_service: QQ 查询服务。 + """ + self._logger = logger + self._query_service = query_service + + async def build_notice_message_dict(self, payload: Mapping[str, Any]) -> Optional[Dict[str, Any]]: + """将 NapCat ``notice`` 事件转换为 Host 可接受的消息字典。 + + Args: + payload: NapCat 推送的原始通知事件。 + + Returns: + Optional[Dict[str, Any]]: 成功时返回标准 ``MessageDict``;无法识别时返回 ``None``。 + """ + notice_type = str(payload.get("notice_type") or "").strip() + if not notice_type: + return None + + group_id = str(payload.get("group_id") or "").strip() + user_id = str(payload.get("user_id") or payload.get("operator_id") or "").strip() + self_id = str(payload.get("self_id") or "").strip() + + user_info = await self._build_user_info(group_id=group_id, user_id=user_id) + group_info = await self._build_group_info(group_id) + notice_text = self._build_notice_text(payload, user_info.get("user_nickname", user_id or "系统")) + if not notice_text: + return None + + additional_config: Dict[str, Any] = { + "self_id": self_id, + "napcat_notice_type": notice_type, + "napcat_notice_sub_type": str(payload.get("sub_type") or "").strip(), + "napcat_notice_payload": dict(payload), + } + if group_id: + additional_config["platform_io_target_group_id"] = group_id + elif user_id: + additional_config["platform_io_target_user_id"] = user_id + + message_info: Dict[str, Any] = {"user_info": user_info, "additional_config": additional_config} + if group_info is not None: + message_info["group_info"] = group_info + + timestamp_seconds = payload.get("time") + if not isinstance(timestamp_seconds, (int, float)): + timestamp_seconds = time.time() + + return { + "message_id": f"napcat-notice-{uuid4().hex}", + "timestamp": str(float(timestamp_seconds)), + "platform": "qq", + "message_info": message_info, + "raw_message": [{"type": "text", "data": notice_text}], + "is_mentioned": False, + "is_at": False, + "is_emoji": False, + "is_picture": False, + "is_command": False, + "is_notify": True, + "session_id": "", + "processed_plain_text": notice_text, + "display_message": notice_text, + } + + async def handle_meta_event(self, payload: Mapping[str, Any]) -> None: + """处理 ``meta_event`` 事件的日志与状态观测。 + + Args: + payload: NapCat 推送的原始元事件。 + """ + meta_event_type = str(payload.get("meta_event_type") or "").strip() + self_id = str(payload.get("self_id") or "").strip() or "unknown" + + if meta_event_type == "lifecycle": + sub_type = str(payload.get("sub_type") or "").strip() + if sub_type == "connect": + self._logger.info(f"NapCat 元事件:Bot {self_id} 已建立连接") + else: + self._logger.debug(f"NapCat 生命周期事件: self_id={self_id} sub_type={sub_type}") + return + + if meta_event_type == "heartbeat": + status = payload.get("status", {}) + if not isinstance(status, Mapping): + status = {} + is_online = bool(status.get("online", False)) + is_good = bool(status.get("good", False)) + interval_ms = payload.get("interval") + self._logger.debug( + f"NapCat 心跳事件: self_id={self_id} online={is_online} good={is_good} interval={interval_ms}" + ) + if not is_online: + self._logger.warning(f"NapCat 心跳显示 Bot {self_id} 已离线") + elif not is_good: + self._logger.warning(f"NapCat 心跳显示 Bot {self_id} 状态异常") + + async def _build_user_info(self, group_id: str, user_id: str) -> Dict[str, Optional[str]]: + """构造通知消息的用户信息。 + + Args: + group_id: 群号;私聊或系统通知时为空字符串。 + user_id: 事件关联用户号。 + + Returns: + Dict[str, Optional[str]]: 规范化后的用户信息字典。 + """ + if not user_id: + return { + "user_id": "notice", + "user_nickname": "系统通知", + "user_cardname": None, + } + + member_info: Optional[Dict[str, Any]] + if group_id: + member_info = await self._query_service.get_group_member_info(group_id, user_id) + else: + member_info = await self._query_service.get_stranger_info(user_id) + + if member_info is None: + return { + "user_id": user_id, + "user_nickname": user_id, + "user_cardname": None, + } + + return { + "user_id": user_id, + "user_nickname": str(member_info.get("nickname") or user_id), + "user_cardname": self._normalize_optional_string(member_info.get("card")), + } + + async def _build_group_info(self, group_id: str) -> Optional[Dict[str, str]]: + """构造通知消息的群信息。 + + Args: + group_id: 群号。 + + Returns: + Optional[Dict[str, str]]: 群信息字典;若不是群通知则返回 ``None``。 + """ + if not group_id: + return None + + group_info = await self._query_service.get_group_info(group_id) + group_name = str(group_info.get("group_name") or f"group_{group_id}") if group_info else f"group_{group_id}" + return {"group_id": group_id, "group_name": group_name} + + def _build_notice_text(self, payload: Mapping[str, Any], actor_name: str) -> str: + """根据 NapCat 通知事件生成可读文本。 + + Args: + payload: 原始通知事件。 + actor_name: 事件操作者显示名。 + + Returns: + str: 生成的可读通知文本。 + """ + notice_type = str(payload.get("notice_type") or "").strip() + sub_type = str(payload.get("sub_type") or "").strip() + target_id = str(payload.get("target_id") or "").strip() + + if notice_type in {"group_recall", "friend_recall"}: + return f"{actor_name} 撤回了一条消息" + if notice_type == "notify" and sub_type == "poke": + target_text = f" -> {target_id}" if target_id else "" + return f"{actor_name} 发起了戳一戳{target_text}" + if notice_type == "notify" and sub_type == "group_name": + return f"{actor_name} 修改了群名称" + if notice_type == "group_ban" and sub_type == "ban": + duration = payload.get("duration") + return f"{actor_name} 触发了群禁言,时长 {duration} 秒" + if notice_type == "group_ban" and sub_type == "lift_ban": + return f"{actor_name} 触发了解除禁言" + if notice_type == "group_upload": + file_info = payload.get("file", {}) + file_name = "" + if isinstance(file_info, Mapping): + file_name = str(file_info.get("name") or "").strip() + return f"{actor_name} 上传了文件{f':{file_name}' if file_name else ''}" + if notice_type == "group_increase": + return f"{actor_name} 加入了群聊" + if notice_type == "group_decrease": + return f"{actor_name} 离开了群聊" + if notice_type == "group_admin": + return f"{actor_name} 的群管理员状态发生变化" + if notice_type == "essence": + return f"{actor_name} 触发了精华消息事件" + if notice_type == "group_msg_emoji_like": + return f"{actor_name} 给一条消息添加了表情回应" + return f"[notice] {notice_type}.{sub_type}".strip(".") + + @staticmethod + def _normalize_optional_string(value: Any) -> Optional[str]: + """将任意值规范化为可选字符串。 + + Args: + value: 待规范化的值。 + + Returns: + Optional[str]: 规范化后的字符串;若值为空则返回 ``None``。 + """ + if value is None: + return None + normalized_value = str(value).strip() + return normalized_value if normalized_value else None diff --git a/src/plugins/built_in/napcat_adapter/qq_queries.py b/src/plugins/built_in/napcat_adapter/qq_queries.py new file mode 100644 index 00000000..7d29803a --- /dev/null +++ b/src/plugins/built_in/napcat_adapter/qq_queries.py @@ -0,0 +1,170 @@ +"""NapCat QQ 平台查询能力。""" + +from typing import TYPE_CHECKING, Any, Dict, Optional + +import asyncio + +if TYPE_CHECKING: + from napcat_adapter.transport import NapCatTransportClient + +try: + from aiohttp import ClientSession, ClientTimeout + + AIOHTTP_AVAILABLE = True +except ImportError: + ClientSession = None # type: ignore[assignment] + ClientTimeout = None # type: ignore[assignment] + AIOHTTP_AVAILABLE = False + + +class NapCatQueryService: + """NapCat QQ 平台查询服务。""" + + def __init__(self, logger: Any, transport: "NapCatTransportClient") -> None: + """初始化查询服务。 + + Args: + logger: 插件日志对象。 + transport: NapCat 传输层客户端。 + """ + self._logger = logger + self._transport = transport + + async def get_login_info(self) -> Optional[Dict[str, Any]]: + """获取当前登录账号信息。 + + Returns: + Optional[Dict[str, Any]]: 登录信息字典;失败时返回 ``None``。 + """ + return await self._call_query("get_login_info", {}) + + async def get_group_info(self, group_id: str) -> Optional[Dict[str, Any]]: + """获取群信息。 + + Args: + group_id: 群号。 + + Returns: + Optional[Dict[str, Any]]: 群信息字典;失败时返回 ``None``。 + """ + return await self._call_query("get_group_info", {"group_id": group_id}) + + async def get_group_member_info(self, group_id: str, user_id: str) -> Optional[Dict[str, Any]]: + """获取群成员信息。 + + Args: + group_id: 群号。 + user_id: 用户号。 + + Returns: + Optional[Dict[str, Any]]: 群成员信息字典;失败时返回 ``None``。 + """ + return await self._call_query( + "get_group_member_info", + {"group_id": group_id, "user_id": user_id, "no_cache": True}, + ) + + async def get_stranger_info(self, user_id: str) -> Optional[Dict[str, Any]]: + """获取陌生人信息。 + + Args: + user_id: 用户号。 + + Returns: + Optional[Dict[str, Any]]: 陌生人信息字典;失败时返回 ``None``。 + """ + return await self._call_query("get_stranger_info", {"user_id": user_id}) + + async def get_message_detail(self, message_id: str) -> Optional[Dict[str, Any]]: + """获取消息详情。 + + Args: + message_id: 消息 ID。 + + Returns: + Optional[Dict[str, Any]]: 消息详情字典;失败时返回 ``None``。 + """ + return await self._call_query("get_msg", {"message_id": message_id}) + + async def get_forward_message(self, message_id: str) -> Optional[Dict[str, Any]]: + """获取合并转发消息详情。 + + Args: + message_id: 转发消息 ID。 + + Returns: + Optional[Dict[str, Any]]: 合并转发消息详情;失败时返回 ``None``。 + """ + return await self._call_query("get_forward_msg", {"message_id": message_id}) + + async def get_record_detail(self, file_name: str, file_id: Optional[str] = None) -> Optional[Dict[str, Any]]: + """获取语音文件详情。 + + Args: + file_name: 语音文件名。 + file_id: 可选文件 ID。 + + Returns: + Optional[Dict[str, Any]]: 语音详情字典;失败时返回 ``None``。 + """ + params: Dict[str, Any] = {"file": file_name, "out_format": "wav"} + if file_id: + params["file_id"] = file_id + return await self._call_query("get_record", params) + + async def download_binary(self, url: str) -> Optional[bytes]: + """下载远程二进制资源。 + + Args: + url: 资源 URL。 + + Returns: + Optional[bytes]: 下载到的二进制内容;失败时返回 ``None``。 + """ + if not url: + return None + if not AIOHTTP_AVAILABLE or ClientSession is None or ClientTimeout is None: + self._logger.warning("NapCat 查询层缺少 aiohttp,无法下载远程资源") + return None + + try: + timeout = ClientTimeout(total=15) + async with ClientSession(timeout=timeout) as session: + async with session.get(url) as response: + if response.status != 200: + self._logger.warning(f"NapCat 远程资源下载失败: status={response.status} url={url}") + return None + return await response.read() + except asyncio.CancelledError: + raise + except Exception as exc: + self._logger.warning(f"NapCat 远程资源下载失败: {exc}") + return None + + async def _call_query(self, action_name: str, params: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """调用 OneBot 查询动作并提取 ``data`` 字段。 + + Args: + action_name: OneBot 动作名。 + params: 动作参数。 + + Returns: + Optional[Dict[str, Any]]: 查询结果中的 ``data`` 字段;失败时返回 ``None``。 + """ + try: + response = await self._transport.call_action(action_name, params) + except asyncio.CancelledError: + raise + except Exception as exc: + self._logger.warning(f"NapCat 查询动作执行失败: action={action_name} error={exc}") + return None + + if str(response.get("status") or "").lower() != "ok": + self._logger.warning( + f"NapCat 查询动作返回失败: action={action_name} " + f"message={response.get('wording') or response.get('message') or 'unknown'}" + ) + return None + + response_data = response.get("data") + return response_data if isinstance(response_data, dict) else None diff --git a/src/plugins/built_in/napcat_adapter/runtime_state.py b/src/plugins/built_in/napcat_adapter/runtime_state.py new file mode 100644 index 00000000..b4dbfa09 --- /dev/null +++ b/src/plugins/built_in/napcat_adapter/runtime_state.py @@ -0,0 +1,85 @@ +"""NapCat 运行时路由状态管理。""" + +from typing import Any, Optional + +from napcat_adapter.config import NapCatServerConfig + + +class NapCatRuntimeStateManager: + """NapCat 适配器路由状态上报器。""" + + def __init__(self, adapter_capability: Any, logger: Any) -> None: + """初始化运行时状态管理器。 + + Args: + adapter_capability: SDK 提供的适配器能力对象。 + logger: 插件日志对象。 + """ + self._adapter_capability = adapter_capability + self._logger = logger + self._runtime_state_connected: bool = False + self._reported_account_id: Optional[str] = None + self._reported_scope: Optional[str] = None + + async def report_connected(self, account_id: str, server_config: NapCatServerConfig) -> bool: + """向 Host 上报当前连接已就绪。 + + Args: + account_id: 当前 NapCat 连接对应的机器人账号 ID。 + server_config: 当前生效的 NapCat 服务端配置。 + + Returns: + bool: 若 Host 接受了运行时状态更新,则返回 ``True``。 + """ + normalized_account_id = str(account_id).strip() + if not normalized_account_id: + return False + + scope = server_config.connection_id or None + if ( + self._runtime_state_connected + and self._reported_account_id == normalized_account_id + and self._reported_scope == scope + ): + return True + + accepted = False + try: + accepted = await self._adapter_capability.update_runtime_state( + connected=True, + account_id=normalized_account_id, + scope=server_config.connection_id, + metadata={"ws_url": server_config.build_ws_url()}, + ) + except Exception as exc: + self._logger.warning(f"NapCat 适配器上报连接就绪状态失败: {exc}") + return False + + if not accepted: + self._logger.warning("NapCat 适配器连接已建立,但 Host 未接受运行时状态更新") + return False + + self._runtime_state_connected = True + self._reported_account_id = normalized_account_id + self._reported_scope = scope + self._logger.info( + f"NapCat 适配器已激活路由: platform=qq account_id={normalized_account_id} " + f"scope={self._reported_scope or '*'}" + ) + return True + + async def report_disconnected(self) -> None: + """向 Host 上报当前连接已断开,并撤销适配器路由。""" + if not self._runtime_state_connected: + self._reported_account_id = None + self._reported_scope = None + return + + try: + await self._adapter_capability.update_runtime_state(connected=False) + except Exception as exc: + self._logger.warning(f"NapCat 适配器上报断开状态失败: {exc}") + finally: + self._runtime_state_connected = False + self._reported_account_id = None + self._reported_scope = None diff --git a/src/plugins/built_in/napcat_adapter/transport.py b/src/plugins/built_in/napcat_adapter/transport.py new file mode 100644 index 00000000..d20de097 --- /dev/null +++ b/src/plugins/built_in/napcat_adapter/transport.py @@ -0,0 +1,322 @@ +"""NapCat 正向 WebSocket 传输层。""" + +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional, Set, cast +from uuid import uuid4 + +import asyncio +import contextlib +import json + +from napcat_adapter.config import NapCatServerConfig + +if TYPE_CHECKING: + from aiohttp import ClientWebSocketResponse as AiohttpClientWebSocketResponse + +try: + from aiohttp import ClientSession, ClientTimeout, WSMsgType + + AIOHTTP_AVAILABLE = True +except ImportError: + ClientSession = cast(Any, None) + ClientTimeout = cast(Any, None) + WSMsgType = cast(Any, None) + AIOHTTP_AVAILABLE = False + +if not TYPE_CHECKING: + AiohttpClientWebSocketResponse = Any + + +class NapCatTransportClient: + """NapCat 正向 WebSocket 客户端。""" + + def __init__( + self, + logger: Any, + on_connection_opened: Callable[[], Awaitable[None]], + on_connection_closed: Callable[[], Awaitable[None]], + on_payload: Callable[[Dict[str, Any]], Awaitable[None]], + ) -> None: + """初始化传输层客户端。 + + Args: + logger: 插件日志对象。 + on_connection_opened: 连接建立后的异步回调。 + on_connection_closed: 连接断开后的异步回调。 + on_payload: 收到非 echo 载荷后的异步回调。 + """ + self._logger = logger + self._on_connection_opened = on_connection_opened + self._on_connection_closed = on_connection_closed + self._on_payload = on_payload + self._server_config: Optional[NapCatServerConfig] = None + self._connection_task: Optional[asyncio.Task[None]] = None + self._pending_actions: Dict[str, asyncio.Future[Dict[str, Any]]] = {} + self._background_tasks: Set[asyncio.Task[Any]] = set() + self._send_lock = asyncio.Lock() + self._ws: Optional[AiohttpClientWebSocketResponse] = None + self._stop_requested: bool = False + self._connection_active: bool = False + + @classmethod + def is_available(cls) -> bool: + """判断当前环境是否安装了传输层依赖。 + + Returns: + bool: 若已安装 ``aiohttp``,则返回 ``True``。 + """ + return AIOHTTP_AVAILABLE + + def configure(self, server_config: NapCatServerConfig) -> None: + """更新当前传输层使用的 NapCat 服务端配置。 + + Args: + server_config: 最新生效的 NapCat 服务端配置。 + """ + self._server_config = server_config + + async def start(self) -> None: + """启动 NapCat 正向 WebSocket 连接循环。 + + Raises: + RuntimeError: 当缺少配置或依赖时抛出。 + """ + if not self.is_available(): + raise RuntimeError("NapCat 适配器依赖 aiohttp,但当前环境未安装该依赖") + if self._server_config is None: + raise RuntimeError("NapCat 适配器尚未配置 napcat_server") + if self._connection_task is not None and not self._connection_task.done(): + return + + self._stop_requested = False + self._connection_task = asyncio.create_task(self._connection_loop(), name="napcat_adapter.connection") + + async def stop(self) -> None: + """停止当前连接并清理所有后台任务。""" + self._stop_requested = True + connection_task = self._connection_task + self._connection_task = None + + ws = self._ws + if ws is not None and not ws.closed: + with contextlib.suppress(Exception): + await ws.close() + self._ws = None + + if connection_task is not None: + connection_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await connection_task + + await self._cancel_background_tasks() + await self._notify_connection_closed() + self._fail_pending_actions("NapCat connection closed") + + async def call_action(self, action_name: str, params: Dict[str, Any]) -> Dict[str, Any]: + """发送 OneBot 动作并等待对应的 echo 响应。 + + Args: + action_name: OneBot 动作名称。 + params: 动作参数。 + + Returns: + Dict[str, Any]: NapCat 返回的原始响应字典。 + + Raises: + RuntimeError: 当连接不可用时抛出。 + """ + ws = self._ws + server_config = self._server_config + if ws is None or ws.closed or server_config is None: + raise RuntimeError("NapCat is not connected") + + echo_id = uuid4().hex + loop = asyncio.get_running_loop() + response_future: asyncio.Future[Dict[str, Any]] = loop.create_future() + self._pending_actions[echo_id] = response_future + + request_payload = {"action": action_name, "params": params, "echo": echo_id} + try: + async with self._send_lock: + await ws.send_str(json.dumps(request_payload, ensure_ascii=False)) + return await asyncio.wait_for(response_future, timeout=server_config.action_timeout_sec) + finally: + self._pending_actions.pop(echo_id, None) + + async def _connection_loop(self) -> None: + """维护单个 WebSocket 连接,并在断开后按配置重连。""" + assert ClientSession is not None + assert ClientTimeout is not None + + while not self._stop_requested: + server_config = self._server_config + if server_config is None: + return + + ws_url = server_config.build_ws_url() + timeout = ClientTimeout(total=None, connect=10) + + try: + async with ClientSession(headers=self._build_headers(server_config), timeout=timeout) as session: + async with session.ws_connect(ws_url, heartbeat=server_config.heartbeat_interval or None) as ws: + self._ws = ws + self._logger.info(f"NapCat 适配器已连接: {ws_url}") + await self._receive_loop(ws) + except asyncio.CancelledError: + raise + except Exception as exc: + self._logger.warning(f"NapCat 适配器连接失败: {exc}") + finally: + self._ws = None + await self._notify_connection_closed() + self._fail_pending_actions("NapCat connection interrupted") + + if self._stop_requested: + break + + await asyncio.sleep(server_config.reconnect_delay_sec) + + async def _receive_loop(self, ws: AiohttpClientWebSocketResponse) -> None: + """持续消费 WebSocket 消息并分发处理。 + + Args: + ws: 当前活跃的 WebSocket 连接对象。 + """ + assert WSMsgType is not None + + bootstrap_task = self._create_background_task( + self._notify_connection_opened(), + "napcat_adapter.bootstrap", + ) + try: + async for ws_message in ws: + if ws_message.type != WSMsgType.TEXT: + if ws_message.type in {WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.ERROR}: + break + continue + + payload = self._parse_json_message(ws_message.data) + if payload is None: + continue + + if echo_id := str(payload.get("echo") or "").strip(): + self._resolve_pending_action(echo_id, payload) + continue + + self._create_background_task(self._on_payload(payload), "napcat_adapter.payload") + finally: + if bootstrap_task is not None and not bootstrap_task.done(): + bootstrap_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await bootstrap_task + + def _create_background_task(self, coroutine: Awaitable[Any], name: str) -> asyncio.Task[Any]: + """创建并跟踪一个后台任务。 + + Args: + coroutine: 待执行的协程对象。 + name: 任务名。 + + Returns: + asyncio.Task[Any]: 已创建的后台任务。 + """ + task = asyncio.create_task(coroutine, name=name) + self._background_tasks.add(task) + task.add_done_callback(self._handle_background_task_completion) + return task + + def _handle_background_task_completion(self, task: asyncio.Task[Any]) -> None: + """处理后台任务结束后的清理与异常记录。 + + Args: + task: 已结束的后台任务。 + """ + self._background_tasks.discard(task) + if task.cancelled(): + return + + exception = task.exception() + if exception is not None: + self._logger.error(f"NapCat 适配器后台任务异常: {exception}", exc_info=True) + + async def _cancel_background_tasks(self) -> None: + """取消所有仍在运行的后台任务。""" + background_tasks = list(self._background_tasks) + for task in background_tasks: + task.cancel() + if background_tasks: + with contextlib.suppress(Exception): + await asyncio.gather(*background_tasks, return_exceptions=True) + self._background_tasks.clear() + + async def _notify_connection_opened(self) -> None: + """在连接建立后触发上层回调。""" + if self._connection_active: + return + + self._connection_active = True + try: + await self._on_connection_opened() + except Exception as exc: + self._logger.warning(f"NapCat 适配器连接建立回调失败: {exc}") + + async def _notify_connection_closed(self) -> None: + """在连接断开后触发上层回调。""" + if not self._connection_active: + return + + self._connection_active = False + try: + await self._on_connection_closed() + except Exception as exc: + self._logger.warning(f"NapCat 适配器断连回调失败: {exc}") + + def _resolve_pending_action(self, echo_id: str, payload: Dict[str, Any]) -> None: + """解析等待中的动作响应。 + + Args: + echo_id: 动作请求对应的 echo 标识。 + payload: NapCat 返回的响应载荷。 + """ + response_future = self._pending_actions.get(echo_id) + if response_future is None or response_future.done(): + return + response_future.set_result(payload) + + def _fail_pending_actions(self, error_message: str) -> None: + """让所有等待中的动作以异常方式结束。 + + Args: + error_message: 写入异常中的错误信息。 + """ + for response_future in self._pending_actions.values(): + if not response_future.done(): + response_future.set_exception(RuntimeError(error_message)) + self._pending_actions.clear() + + def _build_headers(self, server_config: NapCatServerConfig) -> Dict[str, str]: + """构造连接 NapCat 所需的请求头。 + + Args: + server_config: 当前生效的 NapCat 服务端配置。 + + Returns: + Dict[str, str]: WebSocket 握手请求头。 + """ + return {"Authorization": f"Bearer {server_config.token}"} if server_config.token else {} + + def _parse_json_message(self, data: Any) -> Optional[Dict[str, Any]]: + """解析 WebSocket 文本消息中的 JSON 数据。 + + Args: + data: WebSocket 收到的原始文本数据。 + + Returns: + Optional[Dict[str, Any]]: 成功时返回字典,失败时返回 ``None``。 + """ + try: + payload = json.loads(str(data)) + except Exception as exc: + self._logger.warning(f"NapCat 适配器解析 JSON 载荷失败: {exc}") + return None + + return payload if isinstance(payload, dict) else None From 56a6d2fd8ce13459395334c64d96ff6c991aef9c Mon Sep 17 00:00:00 2001 From: DrSmoothl <1787882683@qq.com> Date: Sun, 22 Mar 2026 00:22:24 +0800 Subject: [PATCH 25/45] =?UTF-8?q?refactor:=20=E4=BC=98=E5=8C=96=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E5=BA=93=E6=93=8D=E4=BD=9C=E5=92=8C=E6=A8=A1=E5=9E=8B?= =?UTF-8?q?=E5=AE=9A=E4=B9=89=EF=BC=8C=E5=A2=9E=E5=BC=BA=E8=A1=A8=E8=BE=BE?= =?UTF-8?q?=E6=96=B9=E5=BC=8F=E5=92=8C=E9=BB=91=E8=AF=9D=E8=A1=A8=E7=9A=84?= =?UTF-8?q?=E6=8F=92=E5=85=A5=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pytests/common_test/test_expression_schema.py | 78 +++++++++++++++++ pytests/common_test/test_jargon_schema.py | 84 +++++++++++++++++++ src/common/database/database_model.py | 15 ++-- src/learners/expression_learner.py | 17 +++- src/learners/jargon_miner.py | 18 ++-- 5 files changed, 195 insertions(+), 17 deletions(-) create mode 100644 pytests/common_test/test_expression_schema.py create mode 100644 pytests/common_test/test_jargon_schema.py diff --git a/pytests/common_test/test_expression_schema.py b/pytests/common_test/test_expression_schema.py new file mode 100644 index 00000000..31fcd98f --- /dev/null +++ b/pytests/common_test/test_expression_schema.py @@ -0,0 +1,78 @@ +"""测试表达方式表结构和基础插入行为。""" + +from typing import Generator + +import pytest +from sqlalchemy.pool import StaticPool +from sqlmodel import Session, SQLModel, create_engine + +from src.common.database.database_model import Expression + + +@pytest.fixture(name="expression_engine") +def expression_engine_fixture() -> Generator: + """创建仅用于表达方式表测试的内存数据库引擎。 + + Yields: + Generator: 供测试使用的 SQLite 内存引擎。 + """ + engine = create_engine( + "sqlite://", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + SQLModel.metadata.create_all(engine) + yield engine + + +def test_expression_insert_assigns_auto_increment_id(expression_engine) -> None: + """表达方式表在新库中应能自动分配自增主键。""" + with Session(expression_engine) as session: + expression = Expression( + situation="表达情绪高涨或生理反应", + style="发送💦表情符号", + content_list='["表达情绪高涨或生理反应"]', + count=1, + session_id="session-a", + checked=False, + rejected=False, + ) + session.add(expression) + session.commit() + session.refresh(expression) + + assert expression.id is not None + assert expression.id > 0 + + +def test_expression_insert_allows_same_situation_style(expression_engine) -> None: + """相同情景和风格的表达方式记录不应再被错误绑定到复合主键。""" + with Session(expression_engine) as session: + first_expression = Expression( + situation="对重复行为的默契响应", + style="持续性跟发相同内容", + content_list='["对重复行为的默契响应"]', + count=1, + session_id="session-a", + checked=False, + rejected=False, + ) + second_expression = Expression( + situation="对重复行为的默契响应", + style="持续性跟发相同内容", + content_list='["对重复行为的默契响应-变体"]', + count=2, + session_id="session-b", + checked=False, + rejected=False, + ) + + session.add(first_expression) + session.add(second_expression) + session.commit() + session.refresh(first_expression) + session.refresh(second_expression) + + assert first_expression.id is not None + assert second_expression.id is not None + assert first_expression.id != second_expression.id diff --git a/pytests/common_test/test_jargon_schema.py b/pytests/common_test/test_jargon_schema.py new file mode 100644 index 00000000..909392ab --- /dev/null +++ b/pytests/common_test/test_jargon_schema.py @@ -0,0 +1,84 @@ +"""测试黑话表结构和基础插入行为。""" + +from typing import Generator + +import pytest +from sqlalchemy.pool import StaticPool +from sqlmodel import Session, SQLModel, create_engine + +from src.common.database.database_model import Jargon + + +@pytest.fixture(name="jargon_engine") +def jargon_engine_fixture() -> Generator: + """创建仅用于黑话表测试的内存数据库引擎。 + + Yields: + Generator: 供测试使用的 SQLite 内存引擎。 + """ + engine = create_engine( + "sqlite://", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + SQLModel.metadata.create_all(engine) + yield engine + + +def test_jargon_insert_assigns_auto_increment_id(jargon_engine) -> None: + """黑话表在新库中应能自动分配自增主键。""" + with Session(jargon_engine) as session: + jargon = Jargon( + content="VF8V4L", + raw_content='["[1] test"]', + meaning="", + session_id_dict='{"session-a": 1}', + count=1, + is_jargon=True, + is_complete=False, + is_global=True, + last_inference_count=0, + ) + session.add(jargon) + session.commit() + session.refresh(jargon) + + assert jargon.id is not None + assert jargon.id > 0 + + +def test_jargon_insert_allows_same_content_with_different_rows(jargon_engine) -> None: + """黑话内容不应再被错误地绑成复合主键的一部分。""" + with Session(jargon_engine) as session: + first_jargon = Jargon( + content="表情1", + raw_content='["[1] first"]', + meaning="", + session_id_dict='{"session-a": 1}', + count=1, + is_jargon=True, + is_complete=False, + is_global=False, + last_inference_count=0, + ) + second_jargon = Jargon( + content="表情1", + raw_content='["[1] second"]', + meaning="", + session_id_dict='{"session-b": 1}', + count=1, + is_jargon=True, + is_complete=False, + is_global=False, + last_inference_count=0, + ) + + session.add(first_jargon) + session.add(second_jargon) + session.commit() + session.refresh(first_jargon) + session.refresh(second_jargon) + + assert first_jargon.id is not None + assert second_jargon.id is not None + assert first_jargon.id != second_jargon.id diff --git a/src/common/database/database_model.py b/src/common/database/database_model.py index a0993a77..5b274c43 100644 --- a/src/common/database/database_model.py +++ b/src/common/database/database_model.py @@ -1,8 +1,9 @@ -from typing import Optional -from sqlalchemy import Column, Float, Enum as SQLEnum, DateTime -from sqlmodel import SQLModel, Field, LargeBinary -from enum import Enum from datetime import datetime +from enum import Enum +from typing import Optional + +from sqlalchemy import Column, DateTime, Enum as SQLEnum, Float +from sqlmodel import Field, LargeBinary, SQLModel class ModelUser(str, Enum): @@ -172,8 +173,8 @@ class Expression(SQLModel, table=True): id: Optional[int] = Field(default=None, primary_key=True) # 自增主键 - situation: str = Field(index=True, max_length=255, primary_key=True) # 情景 - style: str = Field(index=True, max_length=255, primary_key=True) # 风格 + situation: str = Field(index=True, max_length=255) # 情景 + style: str = Field(index=True, max_length=255) # 风格 # context: str # 上下文 # up_content: str @@ -200,7 +201,7 @@ class Jargon(SQLModel, table=True): id: Optional[int] = Field(default=None, primary_key=True) # 自增主键 - content: str = Field(index=True, max_length=255, primary_key=True) # 黑话内容 + content: str = Field(index=True, max_length=255) # 黑话内容 raw_content: Optional[str] = Field(default=None, nullable=True) # 原始内容,未处理的黑话内容,为List[str] meaning: str # 黑话含义 diff --git a/src/learners/expression_learner.py b/src/learners/expression_learner.py index 43e4ee7d..156fedc5 100644 --- a/src/learners/expression_learner.py +++ b/src/learners/expression_learner.py @@ -329,7 +329,13 @@ class ExpressionLearner: return filtered_expressions # ====== DB 操作相关 ====== - async def _upsert_expression_to_db(self, situation: str, style: str): + async def _upsert_expression_to_db(self, situation: str, style: str) -> None: + """将表达方式写入数据库,存在时更新,不存在时新增。 + + Args: + situation: 表达方式对应的使用情景。 + style: 表达方式风格。 + """ expr, similarity = self._find_similar_expression(situation) or (None, 0) if expr: # 根据相似度决定是否使用 LLM 总结 @@ -340,7 +346,13 @@ class ExpressionLearner: # 没有找到匹配的记录,创建新记录 self._create_expression(situation, style) - def _create_expression(self, situation: str, style: str): + def _create_expression(self, situation: str, style: str) -> None: + """创建新的表达方式记录。 + + Args: + situation: 表达方式对应的使用情景。 + style: 表达方式风格。 + """ content_list = [situation] try: with get_db_session() as db: @@ -353,6 +365,7 @@ class ExpressionLearner: last_active_time=datetime.now(), ) db.add(new_expr) + db.flush() except Exception as e: logger.error(f"创建表达方式失败: {e}") diff --git a/src/learners/jargon_miner.py b/src/learners/jargon_miner.py index 2fbf8a2e..674e5cc0 100644 --- a/src/learners/jargon_miner.py +++ b/src/learners/jargon_miner.py @@ -1,17 +1,18 @@ from collections import OrderedDict -from json_repair import repair_json -from sqlmodel import select -from typing import List, Optional, Dict, Callable, TypedDict, Set +from typing import Callable, Dict, List, Optional, Set, TypedDict import asyncio import json import random -from src.common.logger import get_logger +from json_repair import repair_json +from sqlmodel import select + +from src.common.data_models.jargon_data_model import MaiJargon from src.common.database.database import get_db_session from src.common.database.database_model import Jargon -from src.common.data_models.jargon_data_model import MaiJargon -from src.config.config import model_config, global_config +from src.common.logger import get_logger +from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest from src.prompt.prompt_manager import prompt_manager @@ -273,11 +274,12 @@ class JargonMiner: try: with get_db_session() as session: session.add(new_jargon) + session.flush() + saved += 1 + self._add_to_cache(content) except Exception as e: logger.error(f"保存新黑话 '{content}' 失败: {e}") continue - finally: - self._add_to_cache(content) # 固定输出提取的jargon结果,格式化为可读形式(只要有提取结果就输出) if uniq_entries: # 收集所有提取的jargon内容 From 89df7ccf6b213916aa787290e8c3f3ac97a4fbb0 Mon Sep 17 00:00:00 2001 From: DrSmoothl <1787882683@qq.com> Date: Sun, 22 Mar 2026 00:43:34 +0800 Subject: [PATCH 26/45] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=20NapCat=20?= =?UTF-8?q?=E9=80=82=E9=85=8D=E5=99=A8=E7=9A=84=E5=85=A5=E7=AB=99=E6=B6=88?= =?UTF-8?q?=E6=81=AF=E7=BC=96=E8=A7=A3=E7=A0=81=E5=8A=9F=E8=83=BD=EF=BC=8C?= =?UTF-8?q?=E5=A2=9E=E5=BC=BA=E6=8F=92=E4=BB=B6=E9=85=8D=E7=BD=AE=E6=9B=B4?= =?UTF-8?q?=E6=96=B0=E9=80=BB=E8=BE=91=E5=92=8C=E6=95=B0=E6=8D=AE=E5=BA=93?= =?UTF-8?q?=E4=BA=A4=E4=BA=92=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../common_test/test_expression_learner.py | 81 ++++++++++++++ pytests/common_test/test_jargon_miner.py | 90 +++++++++++++++ pytests/test_napcat_adapter_codec.py | 81 ++++++++++++++ pytests/test_napcat_adapter_plugin.py | 60 ++++++++++ pytests/test_plugin_runtime.py | 35 ++++-- src/learners/expression_learner.py | 44 +++++--- src/learners/jargon_miner.py | 24 +++- src/plugin_runtime/integration.py | 26 ++++- .../built_in/napcat_adapter/codec_inbound.py | 104 +++++++++++++++++- src/plugins/built_in/napcat_adapter/plugin.py | 1 + 10 files changed, 511 insertions(+), 35 deletions(-) create mode 100644 pytests/common_test/test_expression_learner.py create mode 100644 pytests/common_test/test_jargon_miner.py create mode 100644 pytests/test_napcat_adapter_plugin.py diff --git a/pytests/common_test/test_expression_learner.py b/pytests/common_test/test_expression_learner.py new file mode 100644 index 00000000..951aa424 --- /dev/null +++ b/pytests/common_test/test_expression_learner.py @@ -0,0 +1,81 @@ +"""测试表达方式学习器的数据库读取行为。""" + +from contextlib import contextmanager +from typing import Generator + +import pytest +from sqlalchemy.pool import StaticPool +from sqlmodel import Session, SQLModel, create_engine + +from src.bw_learner.expression_learner import ExpressionLearner +from src.common.database.database_model import Expression + + +@pytest.fixture(name="expression_learner_engine") +def expression_learner_engine_fixture() -> Generator: + """创建用于表达方式学习器测试的内存数据库引擎。 + + Yields: + Generator: 供测试使用的 SQLite 内存引擎。 + """ + engine = create_engine( + "sqlite://", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + SQLModel.metadata.create_all(engine) + yield engine + + +def test_find_similar_expression_uses_read_only_session_and_history_content( + monkeypatch: pytest.MonkeyPatch, + expression_learner_engine, +) -> None: + """查找相似表达方式时,应能在离开会话后安全使用结果,并比较历史情景内容。""" + import src.bw_learner.expression_learner as expression_learner_module + + with Session(expression_learner_engine) as session: + session.add( + Expression( + situation="发送汗滴表情", + style="发送💦表情符号", + content_list='["表达情绪高涨或生理反应"]', + count=1, + session_id="session-a", + checked=False, + rejected=False, + ) + ) + session.commit() + + @contextmanager + def fake_get_db_session(auto_commit: bool = True) -> Generator[Session, None, None]: + """构造带自动提交语义的测试会话工厂。 + + Args: + auto_commit: 退出上下文时是否自动提交。 + + Yields: + Generator[Session, None, None]: SQLModel 会话对象。 + """ + session = Session(expression_learner_engine) + try: + yield session + if auto_commit: + session.commit() + except Exception: + session.rollback() + raise + finally: + session.close() + + monkeypatch.setattr(expression_learner_module, "get_db_session", fake_get_db_session) + + learner = ExpressionLearner(session_id="session-a") + result = learner._find_similar_expression("表达情绪高涨或生理反应") + + assert result is not None + expression, similarity = result + assert expression.item_id is not None + assert expression.style == "发送💦表情符号" + assert similarity == pytest.approx(1.0) diff --git a/pytests/common_test/test_jargon_miner.py b/pytests/common_test/test_jargon_miner.py new file mode 100644 index 00000000..bf81e4d2 --- /dev/null +++ b/pytests/common_test/test_jargon_miner.py @@ -0,0 +1,90 @@ +"""测试黑话学习器的数据库读取行为。""" + +from contextlib import contextmanager +from typing import Generator + +import pytest +from sqlalchemy.pool import StaticPool +from sqlmodel import Session, SQLModel, create_engine, select + +from src.bw_learner.jargon_miner import JargonMiner +from src.common.database.database_model import Jargon + + +@pytest.fixture(name="jargon_miner_engine") +def jargon_miner_engine_fixture() -> Generator: + """创建用于黑话学习器测试的内存数据库引擎。 + + Yields: + Generator: 供测试使用的 SQLite 内存引擎。 + """ + engine = create_engine( + "sqlite://", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + SQLModel.metadata.create_all(engine) + yield engine + + +@pytest.mark.asyncio +async def test_process_extracted_entries_updates_existing_jargon_without_detached_session( + monkeypatch: pytest.MonkeyPatch, + jargon_miner_engine, +) -> None: + """更新已有黑话时,不应因会话关闭导致 ORM 实例失效。""" + import src.bw_learner.jargon_miner as jargon_miner_module + + with Session(jargon_miner_engine) as session: + session.add( + Jargon( + content="VF8V4L", + raw_content='["[1] first"]', + meaning="", + session_id_dict='{"session-a": 1}', + count=0, + is_jargon=True, + is_complete=False, + is_global=False, + last_inference_count=0, + ) + ) + session.commit() + + @contextmanager + def fake_get_db_session(auto_commit: bool = True) -> Generator[Session, None, None]: + """构造带自动提交语义的测试会话工厂。 + + Args: + auto_commit: 退出上下文时是否自动提交。 + + Yields: + Generator[Session, None, None]: SQLModel 会话对象。 + """ + session = Session(jargon_miner_engine) + try: + yield session + if auto_commit: + session.commit() + except Exception: + session.rollback() + raise + finally: + session.close() + + monkeypatch.setattr(jargon_miner_module, "get_db_session", fake_get_db_session) + + jargon_miner = JargonMiner(session_id="session-a", session_name="测试群") + await jargon_miner.process_extracted_entries( + [{"content": "VF8V4L", "raw_content": {"[2] second"}}], + ) + + with Session(jargon_miner_engine) as session: + db_jargon = session.exec(select(Jargon).where(Jargon.content == "VF8V4L")).one() + + assert db_jargon.count == 1 + assert db_jargon.session_id_dict == '{"session-a": 2}' + assert sorted(db_jargon.raw_content and __import__("json").loads(db_jargon.raw_content)) == [ + "[1] first", + "[2] second", + ] diff --git a/pytests/test_napcat_adapter_codec.py b/pytests/test_napcat_adapter_codec.py index 6f557e08..97ed1d9e 100644 --- a/pytests/test_napcat_adapter_codec.py +++ b/pytests/test_napcat_adapter_codec.py @@ -3,12 +3,16 @@ from typing import Any, Dict import importlib import sys +from types import SimpleNamespace + +import pytest BUILT_IN_PLUGIN_ROOT = Path(__file__).resolve().parents[1] / "src" / "plugins" / "built_in" if str(BUILT_IN_PLUGIN_ROOT) not in sys.path: sys.path.insert(0, str(BUILT_IN_PLUGIN_ROOT)) +NapCatInboundCodec = importlib.import_module("napcat_adapter.codec_inbound").NapCatInboundCodec NapCatOutboundCodec = importlib.import_module("napcat_adapter.codec_outbound").NapCatOutboundCodec @@ -68,3 +72,80 @@ def test_napcat_outbound_codec_builds_private_action_from_route_metadata() -> No assert action_name == "send_private_msg" assert params == {"message": [{"type": "text", "data": {"text": "hello"}}], "user_id": "30001"} + + +class DummyQueryService: + """用于测试的轻量查询服务。""" + + async def download_binary(self, url: str) -> bytes: + """返回固定图片二进制。 + + Args: + url: 图片地址。 + + Returns: + bytes: 固定测试图片二进制。 + """ + if url: + return b"image-bytes" + return b"" + + async def get_message_detail(self, message_id: str) -> Dict[str, Any] | None: + """返回空消息详情。 + + Args: + message_id: 目标消息 ID。 + + Returns: + Dict[str, Any] | None: 固定空结果。 + """ + del message_id + return None + + async def get_record_detail(self, file_name: str, file_id: str | None = None) -> Dict[str, Any] | None: + """返回空语音详情。 + + Args: + file_name: 语音文件名。 + file_id: 可选文件 ID。 + + Returns: + Dict[str, Any] | None: 固定空结果。 + """ + del file_name + del file_id + return None + + async def get_forward_message(self, message_id: str) -> Dict[str, Any] | None: + """返回空转发详情。 + + Args: + message_id: 转发消息 ID。 + + Returns: + Dict[str, Any] | None: 固定空结果。 + """ + del message_id + return None + + +@pytest.mark.asyncio +async def test_napcat_inbound_codec_parses_cq_string_image_segments() -> None: + codec = NapCatInboundCodec(SimpleNamespace(debug=lambda message: None), DummyQueryService()) + payload = { + "message": "[CQ:image,file=test.png,sub_type=0,url=https://example.com/test.png][CQ:at,qq=10001] 看到是国人直接给你封了", + } + + raw_message, is_at = await codec.convert_segments(payload, "10001") + + assert raw_message[0]["type"] == "image" + assert raw_message[1] == { + "type": "at", + "data": { + "target_user_id": "10001", + "target_user_nickname": None, + "target_user_cardname": None, + }, + } + assert raw_message[2] == {"type": "text", "data": " 看到是国人直接给你封了"} + assert is_at is True diff --git a/pytests/test_napcat_adapter_plugin.py b/pytests/test_napcat_adapter_plugin.py new file mode 100644 index 00000000..ca550a39 --- /dev/null +++ b/pytests/test_napcat_adapter_plugin.py @@ -0,0 +1,60 @@ +"""NapCat 插件入口行为测试。""" + +from pathlib import Path +from typing import List +from types import SimpleNamespace + +import importlib +import sys + +import pytest + + +BUILT_IN_PLUGIN_ROOT = Path(__file__).resolve().parents[1] / "src" / "plugins" / "built_in" +if str(BUILT_IN_PLUGIN_ROOT) not in sys.path: + sys.path.insert(0, str(BUILT_IN_PLUGIN_ROOT)) + +NapCatAdapterPlugin = importlib.import_module("napcat_adapter.plugin").NapCatAdapterPlugin + + +class DummyLogger: + """用于测试的轻量日志对象。""" + + def __init__(self) -> None: + """初始化测试日志对象。""" + self.debug_messages: List[str] = [] + + def debug(self, message: str) -> None: + """记录调试日志。 + + Args: + message: 待记录的日志内容。 + """ + self.debug_messages.append(message) + + +@pytest.mark.asyncio +async def test_on_config_update_refreshes_settings_and_restarts(monkeypatch: pytest.MonkeyPatch) -> None: + """配置更新时应刷新插件配置、清空旧 settings,并触发连接重启。""" + plugin = NapCatAdapterPlugin() + plugin._ctx = SimpleNamespace(logger=DummyLogger()) + plugin._settings = object() + + restart_calls: List[dict] = [] + + async def fake_restart() -> None: + """记录一次重启调用。""" + restart_calls.append(dict(plugin._plugin_config)) + + monkeypatch.setattr(plugin, "_restart_connection_if_needed", fake_restart) + + new_config = { + "plugin": {"enabled": True, "config_version": "0.1.0"}, + "napcat_server": {"host": "127.0.0.1", "port": 3001}, + } + await plugin.on_config_update(new_config, "v2") + + assert plugin._plugin_config == new_config + assert plugin._settings is None + assert restart_calls == [new_config] + assert plugin.ctx.logger.debug_messages == ["NapCat 适配器收到配置更新通知: v2"] diff --git a/pytests/test_plugin_runtime.py b/pytests/test_plugin_runtime.py index 5ab16c85..20cceb82 100644 --- a/pytests/test_plugin_runtime.py +++ b/pytests/test_plugin_runtime.py @@ -2238,6 +2238,7 @@ class TestIntegration: async def test_handle_plugin_source_changes_only_reload_matching_supervisor(self, monkeypatch, tmp_path): from src.config.file_watcher import FileChange from src.plugin_runtime import integration as integration_module + import json builtin_root = tmp_path / "src" / "plugins" / "built_in" thirdparty_root = tmp_path / "plugins" @@ -2247,6 +2248,10 @@ class TestIntegration: beta_dir.mkdir(parents=True) (alpha_dir / "config.toml").write_text("enabled = true\n", encoding="utf-8") (beta_dir / "config.toml").write_text("enabled = false\n", encoding="utf-8") + (alpha_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8") + (beta_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8") + (alpha_dir / "_manifest.json").write_text(json.dumps({"name": "alpha"}), encoding="utf-8") + (beta_dir / "_manifest.json").write_text(json.dumps({"name": "beta"}), encoding="utf-8") monkeypatch.chdir(tmp_path) @@ -2257,8 +2262,8 @@ class TestIntegration: self.reload_reasons = [] self.config_updates = [] - async def reload_plugins(self, reason="manual"): - self.reload_reasons.append(reason) + async def reload_plugins(self, plugin_ids=None, reason="manual"): + self.reload_reasons.append((plugin_ids, reason)) async def notify_plugin_config_updated(self, plugin_id, config_data, config_version=""): self.config_updates.append((plugin_id, config_data, config_version)) @@ -2283,13 +2288,13 @@ class TestIntegration: await manager._handle_plugin_source_changes(changes) assert manager._builtin_supervisor.reload_reasons == [] - assert manager._third_party_supervisor.reload_reasons == ["file_watcher"] + assert manager._third_party_supervisor.reload_reasons == [(["beta"], "file_watcher")] assert manager._builtin_supervisor.config_updates == [] assert manager._third_party_supervisor.config_updates == [] assert refresh_calls == [True] @pytest.mark.asyncio - async def test_handle_plugin_config_changes_only_notify_target_plugin(self, monkeypatch, tmp_path): + async def test_handle_plugin_config_changes_only_reload_target_plugin(self, monkeypatch, tmp_path): from src.plugin_runtime import integration as integration_module from src.config.file_watcher import FileChange @@ -2308,27 +2313,35 @@ class TestIntegration: def __init__(self, plugin_dirs, plugins): self._plugin_dirs = plugin_dirs self._registered_plugins = {plugin_id: object() for plugin_id in plugins} - self.config_updates = [] + self.reload_calls = [] - async def notify_plugin_config_updated(self, plugin_id, config_data, config_version=""): - self.config_updates.append((plugin_id, config_data, config_version)) + async def reload_plugin(self, plugin_id, reason="manual"): + self.reload_calls.append((plugin_id, reason)) return True manager = integration_module.PluginRuntimeManager() manager._started = True manager._builtin_supervisor = FakeSupervisor([builtin_root], ["alpha"]) manager._third_party_supervisor = FakeSupervisor([thirdparty_root], ["beta"]) + refresh_calls = [] + + def fake_refresh() -> None: + refresh_calls.append(True) + + manager._refresh_plugin_config_watch_subscriptions = fake_refresh await manager._handle_plugin_config_changes( "alpha", [FileChange(change_type=1, path=alpha_dir / "config.toml")], ) - assert manager._builtin_supervisor.config_updates == [("alpha", {"enabled": True}, "")] - assert manager._third_party_supervisor.config_updates == [] + assert manager._builtin_supervisor.reload_calls == [("alpha", "config_file_changed")] + assert manager._third_party_supervisor.reload_calls == [] + assert refresh_calls == [True] def test_refresh_plugin_config_watch_subscriptions_registers_per_plugin(self, tmp_path): from src.plugin_runtime import integration as integration_module + import json builtin_root = tmp_path / "src" / "plugins" / "built_in" thirdparty_root = tmp_path / "plugins" @@ -2336,6 +2349,10 @@ class TestIntegration: beta_dir = thirdparty_root / "beta" alpha_dir.mkdir(parents=True) beta_dir.mkdir(parents=True) + (alpha_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8") + (beta_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8") + (alpha_dir / "_manifest.json").write_text(json.dumps({"name": "alpha"}), encoding="utf-8") + (beta_dir / "_manifest.json").write_text(json.dumps({"name": "beta"}), encoding="utf-8") class FakeWatcher: def __init__(self): diff --git a/src/learners/expression_learner.py b/src/learners/expression_learner.py index 156fedc5..b82ae1fa 100644 --- a/src/learners/expression_learner.py +++ b/src/learners/expression_learner.py @@ -461,25 +461,43 @@ class ExpressionLearner: def _find_similar_expression( self, situation: str, similarity_threshold: float = 0.75 ) -> Optional[Tuple[MaiExpression, float]]: - """在数据库中查找相似的表达方式""" + """在数据库中查找相似的表达方式。 + + Args: + situation: 当前待匹配的情景描述。 + similarity_threshold: 认定为相似表达方式的最低相似度阈值。 + + Returns: + Optional[Tuple[MaiExpression, float]]: 若找到最相似的表达方式,则返回 + ``(表达方式对象, 相似度)``;否则返回 ``None``。 + """ try: - with get_db_session() as session: + with get_db_session(auto_commit=False) as session: statement = select(Expression).filter_by(session_id=self.session_id) expressions = session.exec(statement).all() - best_match: Optional[Expression] = None - best_similarity = 0.0 + best_match: Optional[MaiExpression] = None + best_similarity = 0.0 + + for db_expression in expressions: + expression = MaiExpression.from_db_instance(db_expression) + candidate_situations = [expression.situation, *expression.content] + for candidate_situation in candidate_situations: + normalized_candidate_situation = candidate_situation.strip() + if not normalized_candidate_situation: + continue + similarity = difflib.SequenceMatcher( + None, + situation, + normalized_candidate_situation, + ).ratio() + if similarity > similarity_threshold and similarity > best_similarity: + best_similarity = similarity + best_match = expression - for expr in expressions: - content_list = json.loads(expr.content_list) - for situation in content_list: - similarity = difflib.SequenceMatcher(None, situation, expr.situation).ratio() - if similarity > similarity_threshold and similarity > best_similarity: - best_similarity = similarity - best_match = expr if best_match: - logger.debug(f"找到相似表达方式情景 [ID: {best_match.id}],相似度: {best_similarity:.2f}") - return MaiExpression.from_db_instance(best_match), best_similarity + logger.debug(f"找到相似表达方式情景 [ID: {best_match.item_id}],相似度: {best_similarity:.2f}") + return best_match, best_similarity except Exception as e: logger.error(f"查找相似表达方式失败: {e}") diff --git a/src/learners/jargon_miner.py b/src/learners/jargon_miner.py index 674e5cc0..32926894 100644 --- a/src/learners/jargon_miner.py +++ b/src/learners/jargon_miner.py @@ -199,7 +199,7 @@ class JargonMiner: async def process_extracted_entries( self, entries: List[JargonEntry], person_name_filter: Optional[Callable[[str], bool]] = None - ): + ) -> None: """ 处理已提取的黑话条目(从 expression_learner 路由过来的) @@ -230,7 +230,7 @@ class JargonMiner: content = entry["content"] raw_content_set = entry["raw_content"] try: - with get_db_session() as session: + with get_db_session(auto_commit=False) as session: jargon_items = session.exec(select(Jargon).filter_by(content=content)).all() except Exception as e: logger.error(f"查询黑话 '{content}' 失败: {e}") @@ -306,7 +306,13 @@ class JargonMiner: removed_content, _ = self.cache.popitem(last=False) logger.debug(f"缓存已满,移除最旧的黑话: {removed_content}") - def _update_jargon(self, db_jargon: Jargon, raw_content_set: Set[str]): + def _update_jargon(self, db_jargon: Jargon, raw_content_set: Set[str]) -> None: + """更新已有黑话记录并写回数据库。 + + Args: + db_jargon: 已命中的黑话 ORM 对象。 + raw_content_set: 本次新增的原始上下文集合。 + """ db_jargon.count += 1 existing_raw_content: List[str] = [] if db_jargon.raw_content: @@ -328,7 +334,17 @@ class JargonMiner: try: with get_db_session() as session: - session.add(db_jargon) + if db_jargon.id is None: + raise ValueError("黑话记录缺少 id,无法更新数据库") + statement = select(Jargon).filter_by(id=db_jargon.id).limit(1) + if persisted_jargon := session.exec(statement).first(): + persisted_jargon.count = db_jargon.count + persisted_jargon.raw_content = db_jargon.raw_content + persisted_jargon.session_id_dict = db_jargon.session_id_dict + persisted_jargon.is_global = db_jargon.is_global + session.add(persisted_jargon) + else: + logger.warning(f"黑话 ID {db_jargon.id} 在数据库中未找到,无法更新") except Exception as e: logger.error(f"更新黑话 '{db_jargon.content}' 失败: {e}") diff --git a/src/plugin_runtime/integration.py b/src/plugin_runtime/integration.py index 24cf09fc..bf85669b 100644 --- a/src/plugin_runtime/integration.py +++ b/src/plugin_runtime/integration.py @@ -612,7 +612,17 @@ class PluginRuntimeManager( return None if plugin_path is None else plugin_path / "config.toml" async def _handle_plugin_config_changes(self, plugin_id: str, changes: Sequence[FileChange]) -> None: - """处理单个插件配置文件变化,并仅向目标插件推送配置更新。""" + """处理单个插件配置文件变化,并精确重载目标插件。 + + Args: + plugin_id: 发生配置变更的插件 ID。 + changes: 当前批次收集到的配置文件变更列表。 + + Notes: + 这里选择“精确重载该插件”,而不是仅推送软性的配置更新通知。 + 这样可以保证没有实现 ``on_config_update()`` 的插件也能重新执行 + ``on_load()``,让磁盘上的 ``config.toml`` 修改对插件运行态真正生效。 + """ if not self._started or not changes: return @@ -626,18 +636,24 @@ class PluginRuntimeManager( return try: - await supervisor.notify_plugin_config_updated( + self._load_plugin_config_for_supervisor(supervisor, plugin_id) + reload_success = await supervisor.reload_plugin( plugin_id=plugin_id, - config_data=self._load_plugin_config_for_supervisor(supervisor, plugin_id), + reason="config_file_changed", ) + if reload_success: + self._refresh_plugin_config_watch_subscriptions() + else: + logger.warning(f"插件 {plugin_id} 配置文件变更后重载失败") except Exception as exc: - logger.warning(f"插件 {plugin_id} 配置热更新通知失败: {exc}") + logger.warning(f"插件 {plugin_id} 配置文件变更处理失败: {exc}") async def _handle_plugin_source_changes(self, changes: Sequence[FileChange]) -> None: """处理插件源码相关变化。 这里仅负责源码、清单等会影响插件装载状态的文件;配置文件的变化会由 - 单独的 per-plugin watcher 处理,避免把单插件配置更新放大成全量 reload。 + 单独的 per-plugin watcher 处理,并精确重载对应插件,避免放大成 + 不必要的跨插件 reload。 """ if not self._started or not changes: return diff --git a/src/plugins/built_in/napcat_adapter/codec_inbound.py b/src/plugins/built_in/napcat_adapter/codec_inbound.py index b8065585..8fb020dc 100644 --- a/src/plugins/built_in/napcat_adapter/codec_inbound.py +++ b/src/plugins/built_in/napcat_adapter/codec_inbound.py @@ -5,11 +5,15 @@ from uuid import uuid4 import hashlib import json +import re import time from napcat_adapter.qq_queries import NapCatQueryService +_CQ_SEGMENT_PATTERN = re.compile(r"\[CQ:(?P[a-zA-Z0-9_]+)(?P(?:,[^\]]*)?)\]") + + class NapCatInboundCodec: """NapCat 入站消息编码器。""" @@ -104,8 +108,12 @@ class NapCatInboundCodec: """ message_payload = payload.get("message") if isinstance(message_payload, str): - normalized_text = message_payload.strip() - return ([{"type": "text", "data": normalized_text}] if normalized_text else []), False + parsed_message_payload = self._parse_cq_message_text(message_payload) + if parsed_message_payload: + message_payload = parsed_message_payload + else: + normalized_text = self._decode_cq_entities(message_payload).strip() + return ([{"type": "text", "data": normalized_text}] if normalized_text else []), False if not isinstance(message_payload, list): return [], False @@ -223,8 +231,8 @@ class NapCatInboundCodec: Returns: Dict[str, Any]: 转换后的图片或表情消息段。 """ - subtype = segment_data.get("sub_type") - actual_is_emoji = is_emoji or (isinstance(subtype, int) and subtype not in {0, 4, 9}) + subtype = self._normalize_numeric_segment_value(segment_data.get("sub_type")) + actual_is_emoji = is_emoji or (subtype is not None and subtype not in {0, 4, 9}) image_url = str(segment_data.get("url") or "").strip() binary_data = await self._query_service.download_binary(image_url) @@ -412,3 +420,91 @@ class NapCatInboundCodec: plain_text = "".join(part for part in plain_text_parts if part).strip() return plain_text or fallback_text or "[unsupported]" + + def _parse_cq_message_text(self, message_text: str) -> List[Dict[str, Any]]: + """将 CQ 码字符串解析为 OneBot 风格消息段列表。 + + Args: + message_text: NapCat 在字符串模式下返回的消息内容。 + + Returns: + List[Dict[str, Any]]: 解析后的 OneBot 风格消息段列表。 + """ + parsed_segments: List[Dict[str, Any]] = [] + current_index = 0 + + for match in _CQ_SEGMENT_PATTERN.finditer(message_text): + prefix_text = self._decode_cq_entities(message_text[current_index : match.start()]) + if prefix_text: + parsed_segments.append({"type": "text", "data": {"text": prefix_text}}) + + segment_type = str(match.group("type") or "").strip() + segment_data = self._parse_cq_segment_data(match.group("params") or "") + if segment_type: + parsed_segments.append({"type": segment_type, "data": segment_data}) + current_index = match.end() + + suffix_text = self._decode_cq_entities(message_text[current_index:]) + if suffix_text: + parsed_segments.append({"type": "text", "data": {"text": suffix_text}}) + + return parsed_segments + + def _parse_cq_segment_data(self, raw_params: str) -> Dict[str, Any]: + """解析单个 CQ 段中的参数串。 + + Args: + raw_params: 形如 ``,key=value,key2=value2`` 的原始参数字符串。 + + Returns: + Dict[str, Any]: 解析后的参数字典。 + """ + parsed_data: Dict[str, Any] = {} + if not raw_params: + return parsed_data + + for item in raw_params.lstrip(",").split(","): + if not item or "=" not in item: + continue + key, value = item.split("=", 1) + normalized_key = key.strip() + if not normalized_key: + continue + decoded_value = self._decode_cq_entities(value) + parsed_data[normalized_key] = self._normalize_numeric_segment_value(decoded_value) + + return parsed_data + + @staticmethod + def _decode_cq_entities(text: str) -> str: + """解码 CQ 码中的 HTML 风格转义实体。 + + Args: + text: 待解码的 CQ 文本。 + + Returns: + str: 解码后的普通文本。 + """ + return ( + text.replace("&", "&") + .replace("[", "[") + .replace("]", "]") + .replace(",", ",") + ) + + @staticmethod + def _normalize_numeric_segment_value(value: Any) -> Any: + """将可安全识别的数字字符串转为整数。 + + Args: + value: 原始字段值。 + + Returns: + Any: 规范化后的字段值。 + """ + if isinstance(value, str): + stripped_value = value.strip() + if stripped_value.isdigit(): + return int(stripped_value) + return stripped_value + return value diff --git a/src/plugins/built_in/napcat_adapter/plugin.py b/src/plugins/built_in/napcat_adapter/plugin.py index b1e9bc8c..50900c5d 100644 --- a/src/plugins/built_in/napcat_adapter/plugin.py +++ b/src/plugins/built_in/napcat_adapter/plugin.py @@ -71,6 +71,7 @@ class NapCatAdapterPlugin(MaiBotPlugin): version: 配置版本号。 """ self.set_plugin_config(new_config) + self._settings = None if version: self.ctx.logger.debug(f"NapCat 适配器收到配置更新通知: {version}") await self._restart_connection_if_needed() From a0c653de4532cdba8e9a9d99e85b9433c8a87b46 Mon Sep 17 00:00:00 2001 From: DrSmoothl <1787882683@qq.com> Date: Sun, 22 Mar 2026 01:04:29 +0800 Subject: [PATCH 27/45] =?UTF-8?q?docs:=20=E6=9B=B4=E6=96=B0=E6=B3=A8?= =?UTF-8?q?=E9=87=8A=E8=A7=84=E8=8C=83=E5=92=8C=E8=AF=AD=E8=A8=80=E8=A7=84?= =?UTF-8?q?=E8=8C=83=EF=BC=8C=E5=BC=BA=E8=B0=83=E4=BD=BF=E7=94=A8=20Google?= =?UTF-8?q?=20DocStr=20=E6=A0=BC=E5=BC=8F=E5=92=8C=E7=AE=80=E4=BD=93?= =?UTF-8?q?=E4=B8=AD=E6=96=87?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- AGENTS.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/AGENTS.md b/AGENTS.md index b4caaaf1..b3456610 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -17,6 +17,7 @@ 1. 尽量保持良好的注释 2. 如果原来的代码中有注释,则重构的时候,除非这部分代码被删除,否则相同功能的代码应该保留注释(可以对注释进行修改以保持准确性,但不应该删除注释)。 3. 如果原来的代码中没有注释,则重构的时候,如果某个功能块的代码较长或者逻辑较为复杂,则应该添加注释来解释这部分代码的功能和逻辑。 +4. 对于类,方法以及模块的注释,首选使用的注释格式为 Google DocStr 格式,但保证语言为简体中文 ## 类型注解规范 1. 重构代码时,如果原来的代码中有类型注解,则相同功能的代码应该保留类型注解(可以对类型注解进行修改以保持准确性,但不应该删除类型注解)。 2. 重构代码时,如果原来的代码中没有类型注解,则重构的时候,如果某个函数的功能较为复杂或者参数较多,则应该添加类型注解来提高代码的可读性和可维护性。(对于简单的变量,可以不添加类型注解) @@ -35,3 +36,7 @@ # 运行/调试/构建/测试/依赖 优先使用uv 依赖项以 pyproject.toml 为准 + +# 语言规范 + +项目的首选语言为简体中文,无论是注释语言,日志展示语言,还是 WebUI 展示语言都应该首要以简体中文为首要实现目标 From 0066224251a644697f8c1ebeede83b565fdedfbf Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Sun, 22 Mar 2026 12:50:09 +0800 Subject: [PATCH 28/45] fix: remove nc ada --- .../built_in/napcat_adapter/__init__.py | 1 - .../built_in/napcat_adapter/_manifest.json | 30 -- .../built_in/napcat_adapter/codec_inbound.py | 510 ------------------ .../built_in/napcat_adapter/codec_outbound.py | 192 ------- src/plugins/built_in/napcat_adapter/config.py | 398 -------------- .../built_in/napcat_adapter/constants.py | 9 - .../built_in/napcat_adapter/filters.py | 68 --- src/plugins/built_in/napcat_adapter/plugin.py | 381 ------------- .../built_in/napcat_adapter/qq_notice.py | 224 -------- .../built_in/napcat_adapter/qq_queries.py | 170 ------ .../built_in/napcat_adapter/runtime_state.py | 85 --- .../built_in/napcat_adapter/transport.py | 322 ----------- 12 files changed, 2390 deletions(-) delete mode 100644 src/plugins/built_in/napcat_adapter/__init__.py delete mode 100644 src/plugins/built_in/napcat_adapter/_manifest.json delete mode 100644 src/plugins/built_in/napcat_adapter/codec_inbound.py delete mode 100644 src/plugins/built_in/napcat_adapter/codec_outbound.py delete mode 100644 src/plugins/built_in/napcat_adapter/config.py delete mode 100644 src/plugins/built_in/napcat_adapter/constants.py delete mode 100644 src/plugins/built_in/napcat_adapter/filters.py delete mode 100644 src/plugins/built_in/napcat_adapter/plugin.py delete mode 100644 src/plugins/built_in/napcat_adapter/qq_notice.py delete mode 100644 src/plugins/built_in/napcat_adapter/qq_queries.py delete mode 100644 src/plugins/built_in/napcat_adapter/runtime_state.py delete mode 100644 src/plugins/built_in/napcat_adapter/transport.py diff --git a/src/plugins/built_in/napcat_adapter/__init__.py b/src/plugins/built_in/napcat_adapter/__init__.py deleted file mode 100644 index fa82860f..00000000 --- a/src/plugins/built_in/napcat_adapter/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""NapCat 内置适配器插件包。""" diff --git a/src/plugins/built_in/napcat_adapter/_manifest.json b/src/plugins/built_in/napcat_adapter/_manifest.json deleted file mode 100644 index 6f7e68fd..00000000 --- a/src/plugins/built_in/napcat_adapter/_manifest.json +++ /dev/null @@ -1,30 +0,0 @@ -{ - "manifest_version": 1, - "name": "napcat_adapter_builtin", - "version": "0.1.0", - "description": "Built-in NapCat adapter plugin for MVP message forwarding.", - "author": { - "name": "OpenAI Codex" - }, - "license": "GPL-v3.0-or-later", - "host_application": { - "min_version": "1.0.0" - }, - "keywords": [ - "adapter", - "built-in", - "napcat", - "onebot", - "qq" - ], - "categories": [ - "Adapter", - "Built-in" - ], - "default_locale": "en-US", - "plugin_info": { - "is_built_in": true, - "plugin_type": "adapter" - }, - "capabilities": [] -} diff --git a/src/plugins/built_in/napcat_adapter/codec_inbound.py b/src/plugins/built_in/napcat_adapter/codec_inbound.py deleted file mode 100644 index 8fb020dc..00000000 --- a/src/plugins/built_in/napcat_adapter/codec_inbound.py +++ /dev/null @@ -1,510 +0,0 @@ -"""NapCat 入站消息编解码。""" - -from typing import Any, Dict, List, Mapping, Optional, Tuple -from uuid import uuid4 - -import hashlib -import json -import re -import time - -from napcat_adapter.qq_queries import NapCatQueryService - - -_CQ_SEGMENT_PATTERN = re.compile(r"\[CQ:(?P[a-zA-Z0-9_]+)(?P(?:,[^\]]*)?)\]") - - -class NapCatInboundCodec: - """NapCat 入站消息编码器。""" - - def __init__(self, logger: Any, query_service: NapCatQueryService) -> None: - """初始化入站消息编码器。 - - Args: - logger: 插件日志对象。 - query_service: QQ 查询服务。 - """ - self._logger = logger - self._query_service = query_service - - async def build_message_dict( - self, - payload: Mapping[str, Any], - self_id: str, - sender_user_id: str, - sender: Mapping[str, Any], - ) -> Dict[str, Any]: - """构造 Host 侧可接受的 ``MessageDict``。 - - Args: - payload: NapCat 原始消息事件。 - self_id: 当前机器人账号 ID。 - sender_user_id: 发送者用户 ID。 - sender: 发送者信息字典。 - - Returns: - Dict[str, Any]: 规范化后的 ``MessageDict``。 - """ - message_type = str(payload.get("message_type") or "").strip() or "private" - group_id = str(payload.get("group_id") or "").strip() - group_name = str(payload.get("group_name") or "").strip() or (f"group_{group_id}" if group_id else "") - user_nickname = str(sender.get("nickname") or sender.get("card") or sender_user_id).strip() or sender_user_id - user_cardname = str(sender.get("card") or "").strip() or None - - raw_message, is_at = await self.convert_segments(payload, self_id) - raw_message_text = str(payload.get("raw_message") or "").strip() - if not raw_message: - raw_message = [{"type": "text", "data": raw_message_text or "[unsupported]"}] - - plain_text = self.build_plain_text(raw_message, raw_message_text) - timestamp_seconds = payload.get("time") - if not isinstance(timestamp_seconds, (int, float)): - timestamp_seconds = time.time() - - additional_config: Dict[str, Any] = {"self_id": self_id, "napcat_message_type": message_type} - if group_id: - additional_config["platform_io_target_group_id"] = group_id - else: - additional_config["platform_io_target_user_id"] = sender_user_id - - message_info: Dict[str, Any] = { - "user_info": { - "user_id": sender_user_id, - "user_nickname": user_nickname, - "user_cardname": user_cardname, - }, - "additional_config": additional_config, - } - if group_id: - message_info["group_info"] = {"group_id": group_id, "group_name": group_name} - - message_id = str(payload.get("message_id") or f"napcat-{uuid4().hex}").strip() - return { - "message_id": message_id, - "timestamp": str(float(timestamp_seconds)), - "platform": "qq", - "message_info": message_info, - "raw_message": raw_message, - "is_mentioned": is_at, - "is_at": is_at, - "is_emoji": False, - "is_picture": False, - "is_command": plain_text.startswith("/"), - "is_notify": False, - "session_id": "", - "processed_plain_text": plain_text, - "display_message": plain_text, - } - - async def convert_segments(self, payload: Mapping[str, Any], self_id: str) -> Tuple[List[Dict[str, Any]], bool]: - """将 OneBot 消息段转换为 Host 消息段结构。 - - Args: - payload: OneBot 原始消息事件。 - self_id: 当前机器人账号 ID。 - - Returns: - Tuple[List[Dict[str, Any]], bool]: 转换后的消息段列表,以及是否 @ 到当前机器人。 - """ - message_payload = payload.get("message") - if isinstance(message_payload, str): - parsed_message_payload = self._parse_cq_message_text(message_payload) - if parsed_message_payload: - message_payload = parsed_message_payload - else: - normalized_text = self._decode_cq_entities(message_payload).strip() - return ([{"type": "text", "data": normalized_text}] if normalized_text else []), False - - if not isinstance(message_payload, list): - return [], False - - converted_segments: List[Dict[str, Any]] = [] - is_at = False - for segment in message_payload: - if not isinstance(segment, Mapping): - continue - - segment_type = str(segment.get("type") or "").strip() - segment_data = segment.get("data", {}) - if not isinstance(segment_data, Mapping): - segment_data = {} - - if segment_type == "text": - if text_value := str(segment_data.get("text") or ""): - converted_segments.append({"type": "text", "data": text_value}) - continue - - if segment_type == "at": - if target_user_id := str(segment_data.get("qq") or "").strip(): - converted_segments.append( - { - "type": "at", - "data": { - "target_user_id": target_user_id, - "target_user_nickname": None, - "target_user_cardname": None, - }, - } - ) - if self_id and target_user_id == self_id: - is_at = True - continue - - if segment_type == "reply": - if reply_segment := await self._build_reply_segment(segment_data): - converted_segments.append(reply_segment) - continue - - if segment_type == "face": - converted_segments.append({"type": "text", "data": "[face]"}) - continue - - if segment_type == "image": - converted_segments.append(await self._build_image_like_segment(segment_data, is_emoji=False)) - continue - - if segment_type == "record": - converted_segments.append(await self._build_record_segment(segment_data)) - continue - - if segment_type == "video": - converted_segments.append({"type": "text", "data": "[video]"}) - continue - - if segment_type == "file": - converted_segments.append({"type": "text", "data": "[file]"}) - continue - - if segment_type == "json": - converted_segments.append(self._build_json_text_segment(segment_data)) - continue - - if segment_type == "forward": - if forward_segment := await self._build_forward_segment(segment_data): - converted_segments.append(forward_segment) - continue - - if segment_type in {"xml", "share"}: - converted_segments.append({"type": "text", "data": f"[{segment_type}]"}) - - return converted_segments, is_at - - async def _build_reply_segment(self, segment_data: Mapping[str, Any]) -> Optional[Dict[str, Any]]: - """构造回复消息段。 - - Args: - segment_data: OneBot ``reply`` 段的 ``data`` 字典。 - - Returns: - Optional[Dict[str, Any]]: 转换后的回复消息段;缺少消息 ID 时返回 ``None``。 - """ - target_message_id = str(segment_data.get("id") or "").strip() - if not target_message_id: - return None - - message_detail = await self._query_service.get_message_detail(target_message_id) - reply_payload: Dict[str, Any] = {"target_message_id": target_message_id} - if message_detail is not None: - sender = message_detail.get("sender", {}) - if not isinstance(sender, Mapping): - sender = {} - reply_payload["target_message_content"] = str(message_detail.get("raw_message") or "").strip() or None - reply_payload["target_message_sender_id"] = str( - message_detail.get("user_id") or sender.get("user_id") or "" - ).strip() or None - reply_payload["target_message_sender_nickname"] = str(sender.get("nickname") or "").strip() or None - reply_payload["target_message_sender_cardname"] = str(sender.get("card") or "").strip() or None - - return {"type": "reply", "data": reply_payload} - - async def _build_image_like_segment( - self, - segment_data: Mapping[str, Any], - is_emoji: bool, - ) -> Dict[str, Any]: - """构造图片或表情消息段。 - - Args: - segment_data: OneBot ``image`` 段的 ``data`` 字典。 - is_emoji: 是否按表情组件处理。 - - Returns: - Dict[str, Any]: 转换后的图片或表情消息段。 - """ - subtype = self._normalize_numeric_segment_value(segment_data.get("sub_type")) - actual_is_emoji = is_emoji or (subtype is not None and subtype not in {0, 4, 9}) - - image_url = str(segment_data.get("url") or "").strip() - binary_data = await self._query_service.download_binary(image_url) - if not binary_data: - return {"type": "text", "data": "[emoji]" if actual_is_emoji else "[image]"} - - return { - "type": "emoji" if actual_is_emoji else "image", - "data": "", - "hash": hashlib.sha256(binary_data).hexdigest(), - "binary_data_base64": self._encode_binary(binary_data), - } - - async def _build_record_segment(self, segment_data: Mapping[str, Any]) -> Dict[str, Any]: - """构造语音消息段。 - - Args: - segment_data: OneBot ``record`` 段的 ``data`` 字典。 - - Returns: - Dict[str, Any]: 转换后的语音或占位文本消息段。 - """ - file_name = str(segment_data.get("file") or "").strip() - file_id = str(segment_data.get("file_id") or "").strip() or None - if not file_name: - return {"type": "text", "data": "[voice]"} - - record_detail = await self._query_service.get_record_detail(file_name=file_name, file_id=file_id) - if record_detail is None: - return {"type": "text", "data": "[voice]"} - - record_base64 = str(record_detail.get("base64") or "").strip() - if not record_base64: - return {"type": "text", "data": "[voice]"} - - try: - binary_data = self._decode_binary(record_base64) - except Exception: - return {"type": "text", "data": "[voice]"} - - return { - "type": "voice", - "data": "", - "hash": hashlib.sha256(binary_data).hexdigest(), - "binary_data_base64": self._encode_binary(binary_data), - } - - async def _build_forward_segment(self, segment_data: Mapping[str, Any]) -> Optional[Dict[str, Any]]: - """构造合并转发消息段。 - - Args: - segment_data: OneBot ``forward`` 段的 ``data`` 字典。 - - Returns: - Optional[Dict[str, Any]]: 转换后的合并转发消息段;失败时返回 ``None``。 - """ - message_id = str(segment_data.get("id") or "").strip() - if not message_id: - return None - - forward_detail = await self._query_service.get_forward_message(message_id) - if forward_detail is None: - return {"type": "text", "data": "[forward]"} - - messages = forward_detail.get("messages", []) - if not isinstance(messages, list): - return {"type": "text", "data": "[forward]"} - - forward_nodes: List[Dict[str, Any]] = [] - for forward_message in messages: - if not isinstance(forward_message, Mapping): - continue - raw_content = forward_message.get("content", []) - content_segments = await self._convert_forward_content(raw_content, "") - sender = forward_message.get("sender", {}) - if not isinstance(sender, Mapping): - sender = {} - forward_nodes.append( - { - "user_id": str(sender.get("user_id") or sender.get("uin") or "").strip() or None, - "user_nickname": str(sender.get("nickname") or sender.get("name") or "未知用户"), - "user_cardname": str(sender.get("card") or "").strip() or None, - "message_id": str(forward_message.get("message_id") or uuid4().hex), - "content": content_segments or [{"type": "text", "data": "[empty]"}], - } - ) - - if not forward_nodes: - return {"type": "text", "data": "[forward]"} - return {"type": "forward", "data": forward_nodes} - - async def _convert_forward_content(self, raw_content: Any, self_id: str) -> List[Dict[str, Any]]: - """转换转发节点内部的消息段列表。 - - Args: - raw_content: 转发节点原始内容。 - self_id: 当前机器人账号 ID。 - - Returns: - List[Dict[str, Any]]: 转换后的消息段列表。 - """ - pseudo_payload: Dict[str, Any] = {"message": raw_content} - segments, _ = await self.convert_segments(pseudo_payload, self_id) - return segments - - def _build_json_text_segment(self, segment_data: Mapping[str, Any]) -> Dict[str, Any]: - """将 JSON 卡片最佳努力转换为文本占位。 - - Args: - segment_data: OneBot ``json`` 段的 ``data`` 字典。 - - Returns: - Dict[str, Any]: 转换后的文本消息段。 - """ - json_data = str(segment_data.get("data") or "").strip() - if not json_data: - return {"type": "text", "data": "[json]"} - - try: - parsed_json = json.loads(json_data) - except Exception: - return {"type": "text", "data": "[json]"} - - app_name = str(parsed_json.get("app") or "").strip() - prompt = "" - if isinstance(parsed_json.get("meta"), Mapping): - prompt = str(parsed_json["meta"].get("prompt") or "").strip() - text = prompt or app_name or "json" - return {"type": "text", "data": f"[json:{text}]"} - - @staticmethod - def _encode_binary(binary_data: bytes) -> str: - """将二进制内容编码为 Base64 字符串。 - - Args: - binary_data: 待编码的二进制内容。 - - Returns: - str: Base64 编码字符串。 - """ - import base64 - - return base64.b64encode(binary_data).decode("utf-8") - - @staticmethod - def _decode_binary(binary_base64: str) -> bytes: - """将 Base64 字符串解码为二进制内容。 - - Args: - binary_base64: Base64 字符串。 - - Returns: - bytes: 解码后的二进制内容。 - """ - import base64 - - return base64.b64decode(binary_base64) - - def build_plain_text(self, raw_message: List[Dict[str, Any]], fallback_text: str) -> str: - """从标准消息段中提取可展示的纯文本。 - - Args: - raw_message: 标准化后的消息段列表。 - fallback_text: 当无法拼出文本时使用的回退文本。 - - Returns: - str: 用于 Host 展示和命令判断的纯文本内容。 - """ - plain_text_parts: List[str] = [] - for item in raw_message: - if not isinstance(item, Mapping): - continue - item_type = str(item.get("type") or "").strip() - item_data = item.get("data") - if item_type == "text": - plain_text_parts.append(str(item_data or "")) - elif item_type == "at" and isinstance(item_data, Mapping): - plain_text_parts.append(f"@{item_data.get('target_user_id') or ''}") - elif item_type == "reply": - plain_text_parts.append("[reply]") - elif item_type == "forward": - plain_text_parts.append("[forward]") - elif item_type in {"image", "emoji", "voice"}: - plain_text_parts.append(f"[{item_type}]") - - plain_text = "".join(part for part in plain_text_parts if part).strip() - return plain_text or fallback_text or "[unsupported]" - - def _parse_cq_message_text(self, message_text: str) -> List[Dict[str, Any]]: - """将 CQ 码字符串解析为 OneBot 风格消息段列表。 - - Args: - message_text: NapCat 在字符串模式下返回的消息内容。 - - Returns: - List[Dict[str, Any]]: 解析后的 OneBot 风格消息段列表。 - """ - parsed_segments: List[Dict[str, Any]] = [] - current_index = 0 - - for match in _CQ_SEGMENT_PATTERN.finditer(message_text): - prefix_text = self._decode_cq_entities(message_text[current_index : match.start()]) - if prefix_text: - parsed_segments.append({"type": "text", "data": {"text": prefix_text}}) - - segment_type = str(match.group("type") or "").strip() - segment_data = self._parse_cq_segment_data(match.group("params") or "") - if segment_type: - parsed_segments.append({"type": segment_type, "data": segment_data}) - current_index = match.end() - - suffix_text = self._decode_cq_entities(message_text[current_index:]) - if suffix_text: - parsed_segments.append({"type": "text", "data": {"text": suffix_text}}) - - return parsed_segments - - def _parse_cq_segment_data(self, raw_params: str) -> Dict[str, Any]: - """解析单个 CQ 段中的参数串。 - - Args: - raw_params: 形如 ``,key=value,key2=value2`` 的原始参数字符串。 - - Returns: - Dict[str, Any]: 解析后的参数字典。 - """ - parsed_data: Dict[str, Any] = {} - if not raw_params: - return parsed_data - - for item in raw_params.lstrip(",").split(","): - if not item or "=" not in item: - continue - key, value = item.split("=", 1) - normalized_key = key.strip() - if not normalized_key: - continue - decoded_value = self._decode_cq_entities(value) - parsed_data[normalized_key] = self._normalize_numeric_segment_value(decoded_value) - - return parsed_data - - @staticmethod - def _decode_cq_entities(text: str) -> str: - """解码 CQ 码中的 HTML 风格转义实体。 - - Args: - text: 待解码的 CQ 文本。 - - Returns: - str: 解码后的普通文本。 - """ - return ( - text.replace("&", "&") - .replace("[", "[") - .replace("]", "]") - .replace(",", ",") - ) - - @staticmethod - def _normalize_numeric_segment_value(value: Any) -> Any: - """将可安全识别的数字字符串转为整数。 - - Args: - value: 原始字段值。 - - Returns: - Any: 规范化后的字段值。 - """ - if isinstance(value, str): - stripped_value = value.strip() - if stripped_value.isdigit(): - return int(stripped_value) - return stripped_value - return value diff --git a/src/plugins/built_in/napcat_adapter/codec_outbound.py b/src/plugins/built_in/napcat_adapter/codec_outbound.py deleted file mode 100644 index 6adcb622..00000000 --- a/src/plugins/built_in/napcat_adapter/codec_outbound.py +++ /dev/null @@ -1,192 +0,0 @@ -"""NapCat 出站消息编解码。""" - -from typing import Any, Dict, List, Mapping, Tuple - - -class NapCatOutboundCodec: - """NapCat 出站消息编码器。""" - - def build_outbound_action( - self, - message: Mapping[str, Any], - route: Mapping[str, Any], - ) -> Tuple[str, Dict[str, Any]]: - """为 Host 出站消息构造 OneBot 动作。 - - Args: - message: Host 侧标准 ``MessageDict``。 - route: Platform IO 路由信息。 - - Returns: - Tuple[str, Dict[str, Any]]: 动作名称与参数字典。 - - Raises: - ValueError: 当私聊出站缺少目标用户 ID 时抛出。 - """ - message_info = message.get("message_info", {}) - if not isinstance(message_info, Mapping): - message_info = {} - - group_info = message_info.get("group_info", {}) - if not isinstance(group_info, Mapping): - group_info = {} - - additional_config = message_info.get("additional_config", {}) - if not isinstance(additional_config, Mapping): - additional_config = {} - - raw_message = message.get("raw_message", []) - segments = self.convert_segments(raw_message) - - if target_group_id := str( - group_info.get("group_id") or additional_config.get("platform_io_target_group_id") or "" - ).strip(): - return "send_group_msg", {"group_id": target_group_id, "message": segments} - - target_user_id = str( - additional_config.get("platform_io_target_user_id") - or additional_config.get("target_user_id") - or route.get("target_user_id") - or "" - ).strip() - if not target_user_id: - raise ValueError("Outbound private message is missing target_user_id") - - return "send_private_msg", {"message": segments, "user_id": target_user_id} - - def convert_segments(self, raw_message: Any) -> List[Dict[str, Any]]: - """将 Host 消息段转换为 OneBot 消息段。 - - Args: - raw_message: Host 侧 ``raw_message`` 字段。 - - Returns: - List[Dict[str, Any]]: OneBot 消息段列表。 - """ - if not isinstance(raw_message, list): - return [{"type": "text", "data": {"text": ""}}] - - outbound_segments: List[Dict[str, Any]] = [] - for item in raw_message: - if not isinstance(item, Mapping): - continue - - item_type = str(item.get("type") or "").strip() - item_data = item.get("data") - - if item_type == "text": - text_value = str(item_data or "") - outbound_segments.append({"type": "text", "data": {"text": text_value}}) - continue - - if item_type == "at" and isinstance(item_data, Mapping): - if target_user_id := str(item_data.get("target_user_id") or "").strip(): - outbound_segments.append({"type": "at", "data": {"qq": target_user_id}}) - continue - - if item_type == "reply": - if isinstance(item_data, Mapping): - target_message_id = str(item_data.get("target_message_id") or "").strip() - else: - target_message_id = str(item_data or "").strip() - if target_message_id: - outbound_segments.append({"type": "reply", "data": {"id": target_message_id}}) - continue - - if item_type == "image": - binary_base64 = str(item.get("binary_data_base64") or "").strip() - if binary_base64: - outbound_segments.append( - { - "type": "image", - "data": {"file": f"base64://{binary_base64}", "subtype": 0}, - } - ) - else: - outbound_segments.append({"type": "text", "data": {"text": "[image]"}}) - continue - - if item_type == "emoji": - binary_base64 = str(item.get("binary_data_base64") or "").strip() - if binary_base64: - outbound_segments.append( - { - "type": "image", - "data": { - "file": f"base64://{binary_base64}", - "subtype": 1, - "summary": "[动画表情]", - }, - } - ) - else: - outbound_segments.append({"type": "text", "data": {"text": "[emoji]"}}) - continue - - if item_type == "voice": - binary_base64 = str(item.get("binary_data_base64") or "").strip() - if binary_base64: - outbound_segments.append({"type": "record", "data": {"file": f"base64://{binary_base64}"}}) - else: - outbound_segments.append({"type": "text", "data": {"text": "[voice]"}}) - continue - - if item_type == "forward" and isinstance(item_data, list): - outbound_segments.extend(self._build_forward_nodes(item_data)) - continue - - if item_type == "dict" and isinstance(item_data, Mapping): - if dict_segment := self._build_dict_component_segment(item_data): - outbound_segments.append(dict_segment) - continue - - fallback_text = f"[unsupported:{item_type or 'unknown'}]" - outbound_segments.append({"type": "text", "data": {"text": fallback_text}}) - - if not outbound_segments: - outbound_segments.append({"type": "text", "data": {"text": ""}}) - return outbound_segments - - def _build_forward_nodes(self, forward_nodes: List[Any]) -> List[Dict[str, Any]]: - """构造 NapCat 转发节点列表。 - - Args: - forward_nodes: 内部转发节点列表。 - - Returns: - List[Dict[str, Any]]: NapCat 转发节点列表。 - """ - built_nodes: List[Dict[str, Any]] = [] - for node in forward_nodes: - if not isinstance(node, Mapping): - continue - raw_content = node.get("content", []) - node_segments = self.convert_segments(raw_content) - built_nodes.append( - { - "type": "node", - "data": { - "name": str(node.get("user_nickname") or node.get("user_cardname") or "QQ用户"), - "uin": str(node.get("user_id") or ""), - "content": node_segments, - }, - } - ) - return built_nodes - - def _build_dict_component_segment(self, item_data: Mapping[str, Any]) -> Dict[str, Any]: - """尽力将 ``DictComponent`` 转换为 NapCat 消息段。 - - Args: - item_data: ``DictComponent`` 原始数据。 - - Returns: - Dict[str, Any]: NapCat 消息段;不支持时返回占位文本段。 - """ - raw_type = str(item_data.get("type") or "").strip() - raw_payload = item_data.get("data", item_data) - if raw_type in {"file", "music", "video", "face"} and isinstance(raw_payload, Mapping): - return {"type": raw_type, "data": dict(raw_payload)} - if raw_type in {"image", "record", "reply", "at"} and isinstance(raw_payload, Mapping): - return {"type": raw_type, "data": dict(raw_payload)} - return {"type": "text", "data": {"text": f"[unsupported:{raw_type or 'dict'}]"}} diff --git a/src/plugins/built_in/napcat_adapter/config.py b/src/plugins/built_in/napcat_adapter/config.py deleted file mode 100644 index eeb4acab..00000000 --- a/src/plugins/built_in/napcat_adapter/config.py +++ /dev/null @@ -1,398 +0,0 @@ -"""NapCat 内置适配器配置解析。""" - -from dataclasses import dataclass, field -from typing import Any, Dict, Mapping, Optional, Set, Tuple -from urllib.parse import urlparse - -from napcat_adapter.constants import ( - DEFAULT_ACTION_TIMEOUT_SEC, - DEFAULT_CHAT_LIST_TYPE, - DEFAULT_HEARTBEAT_INTERVAL_SEC, - DEFAULT_NAPCAT_HOST, - DEFAULT_NAPCAT_PORT, - DEFAULT_RECONNECT_DELAY_SEC, - SUPPORTED_CONFIG_VERSION, -) - - -@dataclass(frozen=True) -class NapCatPluginOptions: - """插件级配置。""" - - enabled: bool = False - config_version: str = "" - - def should_connect(self) -> bool: - """判断当前配置下是否应当启动连接。 - - Returns: - bool: 若插件连接已启用,则返回 ``True``。 - """ - return self.enabled - - -@dataclass(frozen=True) -class NapCatServerConfig: - """NapCat 正向 WebSocket 连接配置。""" - - host: str = DEFAULT_NAPCAT_HOST - port: int = DEFAULT_NAPCAT_PORT - token: str = "" - heartbeat_interval: float = DEFAULT_HEARTBEAT_INTERVAL_SEC - reconnect_delay_sec: float = DEFAULT_RECONNECT_DELAY_SEC - action_timeout_sec: float = DEFAULT_ACTION_TIMEOUT_SEC - connection_id: str = "" - - def build_ws_url(self) -> str: - """构造正向 WebSocket 地址。 - - Returns: - str: 供适配器作为客户端连接的 NapCat WebSocket 地址。 - """ - return f"ws://{self.host}:{self.port}" - - -@dataclass(frozen=True) -class NapCatChatConfig: - """聊天名单配置。""" - - group_list_type: str = DEFAULT_CHAT_LIST_TYPE - group_list: Set[str] = field(default_factory=set) - private_list_type: str = DEFAULT_CHAT_LIST_TYPE - private_list: Set[str] = field(default_factory=set) - ban_user_id: Set[str] = field(default_factory=set) - - -@dataclass(frozen=True) -class NapCatFilterConfig: - """消息过滤配置。""" - - ignore_self_message: bool = True - - -@dataclass(frozen=True) -class NapCatPluginSettings: - """NapCat 插件完整配置。""" - - plugin: NapCatPluginOptions = field(default_factory=NapCatPluginOptions) - napcat_server: NapCatServerConfig = field(default_factory=NapCatServerConfig) - chat: NapCatChatConfig = field(default_factory=NapCatChatConfig) - filters: NapCatFilterConfig = field(default_factory=NapCatFilterConfig) - - @classmethod - def from_mapping(cls, raw_config: Mapping[str, Any], logger: Any) -> "NapCatPluginSettings": - """从 Runner 注入的原始配置字典解析插件配置。 - - Args: - raw_config: Runner 注入的原始配置内容。 - logger: 插件日志对象。 - - Returns: - NapCatPluginSettings: 规范化后的插件配置。 - """ - plugin_section = _as_mapping(raw_config.get("plugin")) - server_section = _as_mapping(raw_config.get("napcat_server")) - legacy_connection_section = _as_mapping(raw_config.get("connection")) - chat_section = _as_mapping(raw_config.get("chat")) - filters_section = _as_mapping(raw_config.get("filters")) - - if not server_section and legacy_connection_section: - logger.warning("NapCat 适配器检测到旧版 [connection] 配置段,请尽快迁移到 [napcat_server]") - server_section = legacy_connection_section - - legacy_host, legacy_port = _read_legacy_host_port(server_section, legacy_connection_section, logger) - parsed_host = _read_string(server_section, "host") or legacy_host or DEFAULT_NAPCAT_HOST - parsed_port = _read_positive_int( - mapping=server_section, - key="port", - default=legacy_port or DEFAULT_NAPCAT_PORT, - logger=logger, - setting_name="napcat_server.port", - ) - - return cls( - plugin=NapCatPluginOptions( - enabled=_read_bool(plugin_section, "enabled", False), - config_version=_read_string(plugin_section, "config_version"), - ), - napcat_server=NapCatServerConfig( - host=parsed_host, - port=parsed_port, - token=_read_string(server_section, "token") or _read_string(server_section, "access_token"), - heartbeat_interval=_read_positive_float( - mapping=server_section, - key="heartbeat_interval", - default=_read_positive_float( - mapping=server_section, - key="heartbeat_sec", - default=DEFAULT_HEARTBEAT_INTERVAL_SEC, - logger=logger, - setting_name="napcat_server.heartbeat_interval", - ), - logger=logger, - setting_name="napcat_server.heartbeat_interval", - ), - reconnect_delay_sec=_read_positive_float( - mapping=server_section, - key="reconnect_delay_sec", - default=DEFAULT_RECONNECT_DELAY_SEC, - logger=logger, - setting_name="napcat_server.reconnect_delay_sec", - ), - action_timeout_sec=_read_positive_float( - mapping=server_section, - key="action_timeout_sec", - default=DEFAULT_ACTION_TIMEOUT_SEC, - logger=logger, - setting_name="napcat_server.action_timeout_sec", - ), - connection_id=_read_string(server_section, "connection_id"), - ), - chat=NapCatChatConfig( - group_list_type=_read_list_mode( - mapping=chat_section, - key="group_list_type", - default=DEFAULT_CHAT_LIST_TYPE, - logger=logger, - setting_name="chat.group_list_type", - ), - group_list=_read_string_set(chat_section, "group_list"), - private_list_type=_read_list_mode( - mapping=chat_section, - key="private_list_type", - default=DEFAULT_CHAT_LIST_TYPE, - logger=logger, - setting_name="chat.private_list_type", - ), - private_list=_read_string_set(chat_section, "private_list"), - ban_user_id=_read_string_set(chat_section, "ban_user_id"), - ), - filters=NapCatFilterConfig( - ignore_self_message=_read_bool(filters_section, "ignore_self_message", True), - ), - ) - - def should_connect(self) -> bool: - """判断当前配置下是否应当启动连接。 - - Returns: - bool: 若插件连接已启用,则返回 ``True``。 - """ - return self.plugin.should_connect() - - def validate(self, logger: Any) -> bool: - """校验当前配置是否满足启动连接的前提条件。 - - Args: - logger: 插件日志对象。 - - Returns: - bool: 若配置满足启动连接的前提条件,则返回 ``True``。 - """ - config_version = self.plugin.config_version - if not config_version: - logger.error( - f"NapCat 适配器配置缺少 plugin.config_version,当前插件要求版本 {SUPPORTED_CONFIG_VERSION}" - ) - return False - - if config_version != SUPPORTED_CONFIG_VERSION: - logger.error( - "NapCat 适配器配置版本不兼容: " - f"当前为 {config_version},当前插件要求 {SUPPORTED_CONFIG_VERSION}" - ) - return False - - if not self.napcat_server.host: - logger.warning("NapCat 适配器已启用,但 napcat_server.host 为空") - return False - - if self.napcat_server.port <= 0: - logger.warning("NapCat 适配器已启用,但 napcat_server.port 不是正整数") - return False - - return True - - -def _as_mapping(value: Any) -> Dict[str, Any]: - """将任意值安全转换为字典。 - - Args: - value: 待转换的值。 - - Returns: - Dict[str, Any]: 若原值是映射,则返回普通字典;否则返回空字典。 - """ - return dict(value) if isinstance(value, Mapping) else {} - - -def _read_bool(mapping: Mapping[str, Any], key: str, default: bool) -> bool: - """安全读取布尔配置值。 - - Args: - mapping: 待读取的配置字典。 - key: 目标键名。 - default: 读取失败时的默认值。 - - Returns: - bool: 解析后的布尔值。 - """ - value = mapping.get(key, default) - return value if isinstance(value, bool) else default - - -def _read_string(mapping: Mapping[str, Any], key: str) -> str: - """安全读取字符串配置值。 - - Args: - mapping: 待读取的配置字典。 - key: 目标键名。 - - Returns: - str: 去除首尾空白后的字符串值。 - """ - value = mapping.get(key) - return "" if value is None else str(value).strip() - - -def _read_positive_float( - mapping: Mapping[str, Any], - key: str, - default: float, - logger: Any, - setting_name: str, -) -> float: - """安全读取正浮点数配置值。 - - Args: - mapping: 待读取的配置字典。 - key: 目标键名。 - default: 读取失败时的默认值。 - logger: 插件日志对象。 - setting_name: 用于日志输出的完整配置名。 - - Returns: - float: 合法的正浮点数;否则返回默认值。 - """ - value = mapping.get(key, default) - if isinstance(value, (int, float)) and float(value) > 0: - return float(value) - - if key in mapping: - logger.warning(f"NapCat 适配器配置项取值无效,已回退到默认值: {setting_name}={value!r},默认值为 {default}") - return default - - -def _read_positive_int( - mapping: Mapping[str, Any], - key: str, - default: int, - logger: Any, - setting_name: str, -) -> int: - """安全读取正整数配置值。 - - Args: - mapping: 待读取的配置字典。 - key: 目标键名。 - default: 读取失败时的默认值。 - logger: 插件日志对象。 - setting_name: 用于日志输出的完整配置名。 - - Returns: - int: 合法的正整数;否则返回默认值。 - """ - value = mapping.get(key, default) - if isinstance(value, int) and value > 0: - return value - - if isinstance(value, str) and value.isdigit() and int(value) > 0: - return int(value) - - if key in mapping: - logger.warning(f"NapCat 适配器配置项取值无效,已回退到默认值: {setting_name}={value!r},默认值为 {default}") - return default - - -def _read_list_mode( - mapping: Mapping[str, Any], - key: str, - default: str, - logger: Any, - setting_name: str, -) -> str: - """安全读取名单模式配置值。 - - Args: - mapping: 待读取的配置字典。 - key: 目标键名。 - default: 读取失败时的默认值。 - logger: 插件日志对象。 - setting_name: 用于日志输出的完整配置名。 - - Returns: - str: 合法的名单模式字符串。 - """ - value = mapping.get(key, default) - if isinstance(value, str): - normalized_value = value.strip() - if normalized_value in {"whitelist", "blacklist"}: - return normalized_value - - if key in mapping: - logger.warning(f"NapCat 适配器配置项取值无效,已回退到默认值: {setting_name}={value!r},默认值为 {default}") - return default - - -def _read_string_set(mapping: Mapping[str, Any], key: str) -> Set[str]: - """安全读取字符串集合配置值。 - - Args: - mapping: 待读取的配置字典。 - key: 目标键名。 - - Returns: - Set[str]: 规范化后的字符串集合。 - """ - value = mapping.get(key, []) - if not isinstance(value, list): - return set() - - normalized_values: Set[str] = set() - for item in value: - item_text = "" if item is None else str(item).strip() - if item_text: - normalized_values.add(item_text) - return normalized_values - - -def _read_legacy_host_port( - server_section: Mapping[str, Any], - legacy_connection_section: Mapping[str, Any], - logger: Any, -) -> Tuple[str, Optional[int]]: - """从旧版 ``ws_url`` 配置中提取主机与端口。 - - Args: - server_section: 新版 ``napcat_server`` 配置段。 - legacy_connection_section: 旧版 ``connection`` 配置段。 - logger: 插件日志对象。 - - Returns: - Tuple[str, Optional[int]]: 解析到的主机与端口;若未找到,则返回空主机与 ``None``。 - """ - legacy_ws_url = _read_string(server_section, "ws_url") or _read_string(legacy_connection_section, "ws_url") - if not legacy_ws_url: - return "", None - - parsed_url = urlparse(legacy_ws_url) - parsed_host = parsed_url.hostname or "" - parsed_port = parsed_url.port - - logger.warning( - "NapCat 适配器检测到旧版 ws_url 配置,已临时兼容解析,请尽快迁移到 napcat_server.host/port" - ) - if parsed_url.path not in {"", "/"}: - logger.warning("NapCat 适配器旧版 ws_url 包含路径,新的 napcat_server 配置不会保留该路径") - - return parsed_host, parsed_port diff --git a/src/plugins/built_in/napcat_adapter/constants.py b/src/plugins/built_in/napcat_adapter/constants.py deleted file mode 100644 index bdddde6f..00000000 --- a/src/plugins/built_in/napcat_adapter/constants.py +++ /dev/null @@ -1,9 +0,0 @@ -"""NapCat 内置适配器共享常量。""" - -SUPPORTED_CONFIG_VERSION = "0.1.0" -DEFAULT_NAPCAT_HOST = "127.0.0.1" -DEFAULT_NAPCAT_PORT = 3001 -DEFAULT_RECONNECT_DELAY_SEC = 5.0 -DEFAULT_HEARTBEAT_INTERVAL_SEC = 30.0 -DEFAULT_ACTION_TIMEOUT_SEC = 15.0 -DEFAULT_CHAT_LIST_TYPE = "whitelist" diff --git a/src/plugins/built_in/napcat_adapter/filters.py b/src/plugins/built_in/napcat_adapter/filters.py deleted file mode 100644 index 141cda85..00000000 --- a/src/plugins/built_in/napcat_adapter/filters.py +++ /dev/null @@ -1,68 +0,0 @@ -"""NapCat 入站消息过滤。""" - -from typing import Any, Set - -from napcat_adapter.config import NapCatChatConfig - - -class NapCatChatFilter: - """NapCat 聊天名单过滤器。""" - - def __init__(self, logger: Any) -> None: - """初始化聊天名单过滤器。 - - Args: - logger: 插件日志对象。 - """ - self._logger = logger - - def is_inbound_chat_allowed( - self, - sender_user_id: str, - group_id: str, - chat_config: NapCatChatConfig, - ) -> bool: - """检查入站消息是否通过聊天名单过滤。 - - Args: - sender_user_id: 发送者用户 ID。 - group_id: 群聊 ID;私聊时为空字符串。 - chat_config: 当前生效的聊天配置。 - - Returns: - bool: 若消息允许继续进入 Host,则返回 ``True``。 - """ - if sender_user_id in chat_config.ban_user_id: - self._logger.warning(f"NapCat 用户 {sender_user_id} 在全局禁止名单中,消息被丢弃") - return False - - if group_id: - if not self._is_id_allowed_by_list_policy(group_id, chat_config.group_list_type, chat_config.group_list): - self._logger.warning(f"NapCat 群聊 {group_id} 未通过聊天名单过滤,消息被丢弃") - return False - return True - - if not self._is_id_allowed_by_list_policy( - sender_user_id, - chat_config.private_list_type, - chat_config.private_list, - ): - self._logger.warning(f"NapCat 私聊用户 {sender_user_id} 未通过聊天名单过滤,消息被丢弃") - return False - return True - - @staticmethod - def _is_id_allowed_by_list_policy(target_id: str, list_type: str, configured_ids: Set[str]) -> bool: - """根据白名单或黑名单规则判断目标 ID 是否允许通过。 - - Args: - target_id: 待检查的目标 ID。 - list_type: 名单模式,仅支持 ``whitelist`` 或 ``blacklist``。 - configured_ids: 配置中的 ID 集合。 - - Returns: - bool: 若目标 ID 允许通过,则返回 ``True``。 - """ - if list_type == "whitelist": - return target_id in configured_ids - return target_id not in configured_ids diff --git a/src/plugins/built_in/napcat_adapter/plugin.py b/src/plugins/built_in/napcat_adapter/plugin.py deleted file mode 100644 index 50900c5d..00000000 --- a/src/plugins/built_in/napcat_adapter/plugin.py +++ /dev/null @@ -1,381 +0,0 @@ -"""内置 NapCat 适配器插件。 - -当前实现维持 MVP 范围,目标是跑通基础消息收发链路: -1. 作为客户端连接 NapCat / OneBot v11 WebSocket 服务。 -2. 将入站消息事件转换为 Host 侧的 ``MessageDict``。 -3. 将 Host 出站消息转换为 OneBot 动作并发送。 - -当前范围刻意收敛为: -- 单连接 -- 文本、@、reply 基础转发 -- 暂不处理 ``notice`` / ``meta_event`` 的完整语义归一化 -- 暂不支持图片、语音、文件等复杂媒体 -""" - -from __future__ import annotations - -from typing import Any, Dict, Mapping, Optional - -import asyncio - -from maibot_sdk import Adapter, MaiBotPlugin - -from napcat_adapter.codec_inbound import NapCatInboundCodec -from napcat_adapter.codec_outbound import NapCatOutboundCodec -from napcat_adapter.config import NapCatPluginSettings -from napcat_adapter.filters import NapCatChatFilter -from napcat_adapter.qq_notice import NapCatNoticeCodec -from napcat_adapter.qq_queries import NapCatQueryService -from napcat_adapter.runtime_state import NapCatRuntimeStateManager -from napcat_adapter.transport import NapCatTransportClient - - -@Adapter(platform="qq", protocol="napcat", send_method="send_to_platform") -class NapCatAdapterPlugin(MaiBotPlugin): - """NapCat 适配器 MVP 实现。""" - - def __init__(self) -> None: - """初始化 NapCat 适配器插件实例。""" - super().__init__() - self._plugin_config: Dict[str, Any] = {} - self._settings: Optional[NapCatPluginSettings] = None - self._inbound_codec: Optional[NapCatInboundCodec] = None - self._outbound_codec = NapCatOutboundCodec() - self._chat_filter: Optional[NapCatChatFilter] = None - self._query_service: Optional[NapCatQueryService] = None - self._notice_codec: Optional[NapCatNoticeCodec] = None - self._runtime_state: Optional[NapCatRuntimeStateManager] = None - self._transport: Optional[NapCatTransportClient] = None - - def set_plugin_config(self, config: Dict[str, Any]) -> None: - """设置插件配置内容。 - - Args: - config: Runner 注入的 ``config.toml`` 解析结果。 - """ - self._plugin_config = config if isinstance(config, dict) else {} - - async def on_load(self) -> None: - """在插件加载时根据配置决定是否启动连接。""" - await self._restart_connection_if_needed() - - async def on_unload(self) -> None: - """在插件卸载时关闭连接并清理运行时状态。""" - await self._stop_connection() - - async def on_config_update(self, new_config: Dict[str, Any], version: str) -> None: - """在配置更新后重载连接状态。 - - Args: - new_config: 最新的插件配置。 - version: 配置版本号。 - """ - self.set_plugin_config(new_config) - self._settings = None - if version: - self.ctx.logger.debug(f"NapCat 适配器收到配置更新通知: {version}") - await self._restart_connection_if_needed() - - async def send_to_platform( - self, - message: Dict[str, Any], - route: Optional[Dict[str, Any]] = None, - metadata: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> Dict[str, Any]: - """将 Host 出站消息发送到 NapCat。 - - Args: - message: Host 侧标准 ``MessageDict``。 - route: Platform IO 生成的路由信息。 - metadata: Platform IO 附带的投递元数据。 - **kwargs: 预留的扩展参数。 - - Returns: - Dict[str, Any]: 标准化后的发送结果。 - """ - del metadata - del kwargs - - self._ensure_runtime_components() - transport = self._transport - if transport is None: - return {"success": False, "error": "NapCat transport is not initialized"} - - try: - action_name, params = self._outbound_codec.build_outbound_action(message, route or {}) - response = await transport.call_action(action_name, params) - except Exception as exc: - return {"success": False, "error": str(exc)} - - if str(response.get("status", "")).lower() != "ok": - return { - "success": False, - "error": str(response.get("wording") or response.get("message") or "NapCat send failed"), - "metadata": {"retcode": response.get("retcode")}, - } - - response_data = response.get("data", {}) - external_message_id = "" - if isinstance(response_data, Mapping): - external_message_id = str(response_data.get("message_id") or "") - - return { - "success": True, - "external_message_id": external_message_id or None, - "metadata": {"action": action_name}, - } - - def _ensure_runtime_components(self) -> None: - """确保运行时依赖对象已经完成初始化。""" - if self._chat_filter is None: - self._chat_filter = NapCatChatFilter(self.ctx.logger) - - if self._transport is None: - self._transport = NapCatTransportClient( - logger=self.ctx.logger, - on_connection_opened=self._bootstrap_adapter_runtime_state, - on_connection_closed=self._handle_transport_disconnected, - on_payload=self._handle_transport_payload, - ) - - if self._query_service is None: - self._query_service = NapCatQueryService(self.ctx.logger, self._transport) - - if self._inbound_codec is None: - self._inbound_codec = NapCatInboundCodec(self.ctx.logger, self._query_service) - - if self._notice_codec is None: - self._notice_codec = NapCatNoticeCodec(self.ctx.logger, self._query_service) - - if self._runtime_state is None: - self._runtime_state = NapCatRuntimeStateManager(self.ctx.adapter, self.ctx.logger) - - def _reload_settings(self) -> NapCatPluginSettings: - """重新解析当前插件配置。 - - Returns: - NapCatPluginSettings: 最新的规范化配置。 - """ - self._settings = NapCatPluginSettings.from_mapping(self._plugin_config, self.ctx.logger) - return self._settings - - async def _restart_connection_if_needed(self) -> None: - """根据当前配置重启连接循环。""" - self._ensure_runtime_components() - settings = self._reload_settings() - - await self._stop_connection() - if not settings.should_connect(): - self.ctx.logger.info("NapCat 适配器保持空闲状态,因为插件或配置未启用") - return - if not settings.validate(self.ctx.logger): - return - - transport = self._transport - assert transport is not None - if not transport.is_available(): - self.ctx.logger.error("NapCat 适配器依赖 aiohttp,但当前环境未安装该依赖") - return - - transport.configure(settings.napcat_server) - await transport.start() - - async def _stop_connection(self) -> None: - """停止当前连接。""" - transport = self._transport - if transport is not None: - await transport.stop() - return - - runtime_state = self._runtime_state - if runtime_state is not None: - await runtime_state.report_disconnected() - - async def _handle_transport_payload(self, payload: Dict[str, Any]) -> None: - """处理来自传输层的非 echo 载荷。 - - Args: - payload: NapCat 推送的原始事件数据。 - """ - post_type = str(payload.get("post_type") or "").strip() - if post_type == "message": - await self._handle_inbound_message(payload) - return - if post_type == "notice": - await self._handle_notice_event(payload) - return - if post_type == "meta_event": - await self._handle_meta_event(payload) - - async def _handle_inbound_message(self, payload: Dict[str, Any]) -> None: - """处理单条 NapCat 入站消息并注入 Host。 - - Args: - payload: NapCat / OneBot 推送的原始消息事件。 - """ - self._ensure_runtime_components() - settings = self._settings or self._reload_settings() - chat_filter = self._chat_filter - inbound_codec = self._inbound_codec - runtime_state = self._runtime_state - assert chat_filter is not None - assert inbound_codec is not None - assert runtime_state is not None - - self_id = str(payload.get("self_id") or "").strip() - if self_id: - await runtime_state.report_connected(self_id, settings.napcat_server) - - sender = payload.get("sender", {}) - if not isinstance(sender, Mapping): - sender = {} - - sender_user_id = str(payload.get("user_id") or sender.get("user_id") or "").strip() - if not sender_user_id: - return - - group_id = str(payload.get("group_id") or "").strip() - if self_id and sender_user_id == self_id and settings.filters.ignore_self_message: - return - if not chat_filter.is_inbound_chat_allowed(sender_user_id, group_id, settings.chat): - return - - message_dict = await inbound_codec.build_message_dict(payload, self_id, sender_user_id, sender) - route_metadata: Dict[str, Any] = {} - if self_id: - route_metadata["self_id"] = self_id - if settings.napcat_server.connection_id: - route_metadata["connection_id"] = settings.napcat_server.connection_id - - external_message_id = str(payload.get("message_id") or "").strip() - accepted = await self.ctx.adapter.receive_external_message( - message_dict, - route_metadata=route_metadata, - external_message_id=external_message_id, - dedupe_key=external_message_id, - ) - if not accepted: - self.ctx.logger.debug(f"Host 丢弃了 NapCat 入站消息: {external_message_id or '无消息 ID'}") - - async def _handle_notice_event(self, payload: Dict[str, Any]) -> None: - """处理 NapCat ``notice`` 事件并注入 Host。 - - Args: - payload: NapCat 推送的通知事件。 - """ - self._ensure_runtime_components() - notice_codec = self._notice_codec - runtime_state = self._runtime_state - settings = self._settings or self._reload_settings() - assert notice_codec is not None - assert runtime_state is not None - - self_id = str(payload.get("self_id") or "").strip() - if self_id: - await runtime_state.report_connected(self_id, settings.napcat_server) - - message_dict = await notice_codec.build_notice_message_dict(payload) - if message_dict is None: - return - - route_metadata: Dict[str, Any] = {} - if self_id: - route_metadata["self_id"] = self_id - if settings.napcat_server.connection_id: - route_metadata["connection_id"] = settings.napcat_server.connection_id - - external_message_id = str(payload.get("message_id") or payload.get("notice_type") or "").strip() - accepted = await self.ctx.adapter.receive_external_message( - message_dict, - route_metadata=route_metadata, - external_message_id=external_message_id or None, - dedupe_key=external_message_id or None, - ) - if not accepted: - self.ctx.logger.debug(f"Host 丢弃了 NapCat 通知事件: {external_message_id or '无消息 ID'}") - - async def _handle_meta_event(self, payload: Dict[str, Any]) -> None: - """处理 NapCat ``meta_event`` 事件。 - - Args: - payload: NapCat 推送的元事件。 - """ - self._ensure_runtime_components() - notice_codec = self._notice_codec - runtime_state = self._runtime_state - settings = self._settings or self._reload_settings() - assert notice_codec is not None - assert runtime_state is not None - - self_id = str(payload.get("self_id") or "").strip() - if self_id: - await runtime_state.report_connected(self_id, settings.napcat_server) - - await notice_codec.handle_meta_event(payload) - - async def _bootstrap_adapter_runtime_state(self) -> None: - """在连接建立后主动获取账号信息并激活适配器路由。""" - transport = self._transport - query_service = self._query_service - runtime_state = self._runtime_state - settings = self._settings or self._reload_settings() - if transport is None or query_service is None or runtime_state is None: - return - - max_attempts = 3 - last_error: Optional[Exception] = None - for attempt in range(1, max_attempts + 1): - try: - login_info = await query_service.get_login_info() - self_id = self._extract_self_id_from_login_response(login_info) - await runtime_state.report_connected(self_id, settings.napcat_server) - return - except asyncio.CancelledError: - raise - except Exception as exc: - last_error = exc - self.ctx.logger.warning( - f"NapCat 适配器获取登录信息失败,第 {attempt}/{max_attempts} 次重试: {exc}" - ) - if attempt < max_attempts: - await asyncio.sleep(1.0) - - if last_error is not None: - self.ctx.logger.error(f"NapCat 适配器未能完成路由激活,连接将保持只接收状态: {last_error}") - - async def _handle_transport_disconnected(self) -> None: - """处理传输层断开事件。""" - runtime_state = self._runtime_state - if runtime_state is not None: - await runtime_state.report_disconnected() - - @staticmethod - def _extract_self_id_from_login_response(response: Optional[Dict[str, Any]]) -> str: - """从 ``get_login_info`` 查询结果中提取当前账号 ID。 - - Args: - response: NapCat 返回的登录信息字典。 - - Returns: - str: 规范化后的账号 ID 字符串。 - - Raises: - ValueError: 当响应中缺少有效账号 ID 时抛出。 - """ - if not isinstance(response, Mapping): - raise ValueError("get_login_info 响应缺少 data 字段") - - self_id = str(response.get("user_id") or "").strip() - if not self_id: - raise ValueError("get_login_info 响应缺少有效的 user_id") - return self_id - - -def create_plugin() -> NapCatAdapterPlugin: - """创建插件实例。 - - Returns: - NapCatAdapterPlugin: NapCat 内置适配器插件实例。 - """ - return NapCatAdapterPlugin() diff --git a/src/plugins/built_in/napcat_adapter/qq_notice.py b/src/plugins/built_in/napcat_adapter/qq_notice.py deleted file mode 100644 index f577cf98..00000000 --- a/src/plugins/built_in/napcat_adapter/qq_notice.py +++ /dev/null @@ -1,224 +0,0 @@ -"""NapCat QQ 平台通知与元事件处理。""" - -from typing import Any, Dict, Mapping, Optional -from uuid import uuid4 - -import time - -from napcat_adapter.qq_queries import NapCatQueryService - - -class NapCatNoticeCodec: - """NapCat QQ 通知事件编码器。""" - - def __init__(self, logger: Any, query_service: NapCatQueryService) -> None: - """初始化通知事件编码器。 - - Args: - logger: 插件日志对象。 - query_service: QQ 查询服务。 - """ - self._logger = logger - self._query_service = query_service - - async def build_notice_message_dict(self, payload: Mapping[str, Any]) -> Optional[Dict[str, Any]]: - """将 NapCat ``notice`` 事件转换为 Host 可接受的消息字典。 - - Args: - payload: NapCat 推送的原始通知事件。 - - Returns: - Optional[Dict[str, Any]]: 成功时返回标准 ``MessageDict``;无法识别时返回 ``None``。 - """ - notice_type = str(payload.get("notice_type") or "").strip() - if not notice_type: - return None - - group_id = str(payload.get("group_id") or "").strip() - user_id = str(payload.get("user_id") or payload.get("operator_id") or "").strip() - self_id = str(payload.get("self_id") or "").strip() - - user_info = await self._build_user_info(group_id=group_id, user_id=user_id) - group_info = await self._build_group_info(group_id) - notice_text = self._build_notice_text(payload, user_info.get("user_nickname", user_id or "系统")) - if not notice_text: - return None - - additional_config: Dict[str, Any] = { - "self_id": self_id, - "napcat_notice_type": notice_type, - "napcat_notice_sub_type": str(payload.get("sub_type") or "").strip(), - "napcat_notice_payload": dict(payload), - } - if group_id: - additional_config["platform_io_target_group_id"] = group_id - elif user_id: - additional_config["platform_io_target_user_id"] = user_id - - message_info: Dict[str, Any] = {"user_info": user_info, "additional_config": additional_config} - if group_info is not None: - message_info["group_info"] = group_info - - timestamp_seconds = payload.get("time") - if not isinstance(timestamp_seconds, (int, float)): - timestamp_seconds = time.time() - - return { - "message_id": f"napcat-notice-{uuid4().hex}", - "timestamp": str(float(timestamp_seconds)), - "platform": "qq", - "message_info": message_info, - "raw_message": [{"type": "text", "data": notice_text}], - "is_mentioned": False, - "is_at": False, - "is_emoji": False, - "is_picture": False, - "is_command": False, - "is_notify": True, - "session_id": "", - "processed_plain_text": notice_text, - "display_message": notice_text, - } - - async def handle_meta_event(self, payload: Mapping[str, Any]) -> None: - """处理 ``meta_event`` 事件的日志与状态观测。 - - Args: - payload: NapCat 推送的原始元事件。 - """ - meta_event_type = str(payload.get("meta_event_type") or "").strip() - self_id = str(payload.get("self_id") or "").strip() or "unknown" - - if meta_event_type == "lifecycle": - sub_type = str(payload.get("sub_type") or "").strip() - if sub_type == "connect": - self._logger.info(f"NapCat 元事件:Bot {self_id} 已建立连接") - else: - self._logger.debug(f"NapCat 生命周期事件: self_id={self_id} sub_type={sub_type}") - return - - if meta_event_type == "heartbeat": - status = payload.get("status", {}) - if not isinstance(status, Mapping): - status = {} - is_online = bool(status.get("online", False)) - is_good = bool(status.get("good", False)) - interval_ms = payload.get("interval") - self._logger.debug( - f"NapCat 心跳事件: self_id={self_id} online={is_online} good={is_good} interval={interval_ms}" - ) - if not is_online: - self._logger.warning(f"NapCat 心跳显示 Bot {self_id} 已离线") - elif not is_good: - self._logger.warning(f"NapCat 心跳显示 Bot {self_id} 状态异常") - - async def _build_user_info(self, group_id: str, user_id: str) -> Dict[str, Optional[str]]: - """构造通知消息的用户信息。 - - Args: - group_id: 群号;私聊或系统通知时为空字符串。 - user_id: 事件关联用户号。 - - Returns: - Dict[str, Optional[str]]: 规范化后的用户信息字典。 - """ - if not user_id: - return { - "user_id": "notice", - "user_nickname": "系统通知", - "user_cardname": None, - } - - member_info: Optional[Dict[str, Any]] - if group_id: - member_info = await self._query_service.get_group_member_info(group_id, user_id) - else: - member_info = await self._query_service.get_stranger_info(user_id) - - if member_info is None: - return { - "user_id": user_id, - "user_nickname": user_id, - "user_cardname": None, - } - - return { - "user_id": user_id, - "user_nickname": str(member_info.get("nickname") or user_id), - "user_cardname": self._normalize_optional_string(member_info.get("card")), - } - - async def _build_group_info(self, group_id: str) -> Optional[Dict[str, str]]: - """构造通知消息的群信息。 - - Args: - group_id: 群号。 - - Returns: - Optional[Dict[str, str]]: 群信息字典;若不是群通知则返回 ``None``。 - """ - if not group_id: - return None - - group_info = await self._query_service.get_group_info(group_id) - group_name = str(group_info.get("group_name") or f"group_{group_id}") if group_info else f"group_{group_id}" - return {"group_id": group_id, "group_name": group_name} - - def _build_notice_text(self, payload: Mapping[str, Any], actor_name: str) -> str: - """根据 NapCat 通知事件生成可读文本。 - - Args: - payload: 原始通知事件。 - actor_name: 事件操作者显示名。 - - Returns: - str: 生成的可读通知文本。 - """ - notice_type = str(payload.get("notice_type") or "").strip() - sub_type = str(payload.get("sub_type") or "").strip() - target_id = str(payload.get("target_id") or "").strip() - - if notice_type in {"group_recall", "friend_recall"}: - return f"{actor_name} 撤回了一条消息" - if notice_type == "notify" and sub_type == "poke": - target_text = f" -> {target_id}" if target_id else "" - return f"{actor_name} 发起了戳一戳{target_text}" - if notice_type == "notify" and sub_type == "group_name": - return f"{actor_name} 修改了群名称" - if notice_type == "group_ban" and sub_type == "ban": - duration = payload.get("duration") - return f"{actor_name} 触发了群禁言,时长 {duration} 秒" - if notice_type == "group_ban" and sub_type == "lift_ban": - return f"{actor_name} 触发了解除禁言" - if notice_type == "group_upload": - file_info = payload.get("file", {}) - file_name = "" - if isinstance(file_info, Mapping): - file_name = str(file_info.get("name") or "").strip() - return f"{actor_name} 上传了文件{f':{file_name}' if file_name else ''}" - if notice_type == "group_increase": - return f"{actor_name} 加入了群聊" - if notice_type == "group_decrease": - return f"{actor_name} 离开了群聊" - if notice_type == "group_admin": - return f"{actor_name} 的群管理员状态发生变化" - if notice_type == "essence": - return f"{actor_name} 触发了精华消息事件" - if notice_type == "group_msg_emoji_like": - return f"{actor_name} 给一条消息添加了表情回应" - return f"[notice] {notice_type}.{sub_type}".strip(".") - - @staticmethod - def _normalize_optional_string(value: Any) -> Optional[str]: - """将任意值规范化为可选字符串。 - - Args: - value: 待规范化的值。 - - Returns: - Optional[str]: 规范化后的字符串;若值为空则返回 ``None``。 - """ - if value is None: - return None - normalized_value = str(value).strip() - return normalized_value if normalized_value else None diff --git a/src/plugins/built_in/napcat_adapter/qq_queries.py b/src/plugins/built_in/napcat_adapter/qq_queries.py deleted file mode 100644 index 7d29803a..00000000 --- a/src/plugins/built_in/napcat_adapter/qq_queries.py +++ /dev/null @@ -1,170 +0,0 @@ -"""NapCat QQ 平台查询能力。""" - -from typing import TYPE_CHECKING, Any, Dict, Optional - -import asyncio - -if TYPE_CHECKING: - from napcat_adapter.transport import NapCatTransportClient - -try: - from aiohttp import ClientSession, ClientTimeout - - AIOHTTP_AVAILABLE = True -except ImportError: - ClientSession = None # type: ignore[assignment] - ClientTimeout = None # type: ignore[assignment] - AIOHTTP_AVAILABLE = False - - -class NapCatQueryService: - """NapCat QQ 平台查询服务。""" - - def __init__(self, logger: Any, transport: "NapCatTransportClient") -> None: - """初始化查询服务。 - - Args: - logger: 插件日志对象。 - transport: NapCat 传输层客户端。 - """ - self._logger = logger - self._transport = transport - - async def get_login_info(self) -> Optional[Dict[str, Any]]: - """获取当前登录账号信息。 - - Returns: - Optional[Dict[str, Any]]: 登录信息字典;失败时返回 ``None``。 - """ - return await self._call_query("get_login_info", {}) - - async def get_group_info(self, group_id: str) -> Optional[Dict[str, Any]]: - """获取群信息。 - - Args: - group_id: 群号。 - - Returns: - Optional[Dict[str, Any]]: 群信息字典;失败时返回 ``None``。 - """ - return await self._call_query("get_group_info", {"group_id": group_id}) - - async def get_group_member_info(self, group_id: str, user_id: str) -> Optional[Dict[str, Any]]: - """获取群成员信息。 - - Args: - group_id: 群号。 - user_id: 用户号。 - - Returns: - Optional[Dict[str, Any]]: 群成员信息字典;失败时返回 ``None``。 - """ - return await self._call_query( - "get_group_member_info", - {"group_id": group_id, "user_id": user_id, "no_cache": True}, - ) - - async def get_stranger_info(self, user_id: str) -> Optional[Dict[str, Any]]: - """获取陌生人信息。 - - Args: - user_id: 用户号。 - - Returns: - Optional[Dict[str, Any]]: 陌生人信息字典;失败时返回 ``None``。 - """ - return await self._call_query("get_stranger_info", {"user_id": user_id}) - - async def get_message_detail(self, message_id: str) -> Optional[Dict[str, Any]]: - """获取消息详情。 - - Args: - message_id: 消息 ID。 - - Returns: - Optional[Dict[str, Any]]: 消息详情字典;失败时返回 ``None``。 - """ - return await self._call_query("get_msg", {"message_id": message_id}) - - async def get_forward_message(self, message_id: str) -> Optional[Dict[str, Any]]: - """获取合并转发消息详情。 - - Args: - message_id: 转发消息 ID。 - - Returns: - Optional[Dict[str, Any]]: 合并转发消息详情;失败时返回 ``None``。 - """ - return await self._call_query("get_forward_msg", {"message_id": message_id}) - - async def get_record_detail(self, file_name: str, file_id: Optional[str] = None) -> Optional[Dict[str, Any]]: - """获取语音文件详情。 - - Args: - file_name: 语音文件名。 - file_id: 可选文件 ID。 - - Returns: - Optional[Dict[str, Any]]: 语音详情字典;失败时返回 ``None``。 - """ - params: Dict[str, Any] = {"file": file_name, "out_format": "wav"} - if file_id: - params["file_id"] = file_id - return await self._call_query("get_record", params) - - async def download_binary(self, url: str) -> Optional[bytes]: - """下载远程二进制资源。 - - Args: - url: 资源 URL。 - - Returns: - Optional[bytes]: 下载到的二进制内容;失败时返回 ``None``。 - """ - if not url: - return None - if not AIOHTTP_AVAILABLE or ClientSession is None or ClientTimeout is None: - self._logger.warning("NapCat 查询层缺少 aiohttp,无法下载远程资源") - return None - - try: - timeout = ClientTimeout(total=15) - async with ClientSession(timeout=timeout) as session: - async with session.get(url) as response: - if response.status != 200: - self._logger.warning(f"NapCat 远程资源下载失败: status={response.status} url={url}") - return None - return await response.read() - except asyncio.CancelledError: - raise - except Exception as exc: - self._logger.warning(f"NapCat 远程资源下载失败: {exc}") - return None - - async def _call_query(self, action_name: str, params: Dict[str, Any]) -> Optional[Dict[str, Any]]: - """调用 OneBot 查询动作并提取 ``data`` 字段。 - - Args: - action_name: OneBot 动作名。 - params: 动作参数。 - - Returns: - Optional[Dict[str, Any]]: 查询结果中的 ``data`` 字段;失败时返回 ``None``。 - """ - try: - response = await self._transport.call_action(action_name, params) - except asyncio.CancelledError: - raise - except Exception as exc: - self._logger.warning(f"NapCat 查询动作执行失败: action={action_name} error={exc}") - return None - - if str(response.get("status") or "").lower() != "ok": - self._logger.warning( - f"NapCat 查询动作返回失败: action={action_name} " - f"message={response.get('wording') or response.get('message') or 'unknown'}" - ) - return None - - response_data = response.get("data") - return response_data if isinstance(response_data, dict) else None diff --git a/src/plugins/built_in/napcat_adapter/runtime_state.py b/src/plugins/built_in/napcat_adapter/runtime_state.py deleted file mode 100644 index b4dbfa09..00000000 --- a/src/plugins/built_in/napcat_adapter/runtime_state.py +++ /dev/null @@ -1,85 +0,0 @@ -"""NapCat 运行时路由状态管理。""" - -from typing import Any, Optional - -from napcat_adapter.config import NapCatServerConfig - - -class NapCatRuntimeStateManager: - """NapCat 适配器路由状态上报器。""" - - def __init__(self, adapter_capability: Any, logger: Any) -> None: - """初始化运行时状态管理器。 - - Args: - adapter_capability: SDK 提供的适配器能力对象。 - logger: 插件日志对象。 - """ - self._adapter_capability = adapter_capability - self._logger = logger - self._runtime_state_connected: bool = False - self._reported_account_id: Optional[str] = None - self._reported_scope: Optional[str] = None - - async def report_connected(self, account_id: str, server_config: NapCatServerConfig) -> bool: - """向 Host 上报当前连接已就绪。 - - Args: - account_id: 当前 NapCat 连接对应的机器人账号 ID。 - server_config: 当前生效的 NapCat 服务端配置。 - - Returns: - bool: 若 Host 接受了运行时状态更新,则返回 ``True``。 - """ - normalized_account_id = str(account_id).strip() - if not normalized_account_id: - return False - - scope = server_config.connection_id or None - if ( - self._runtime_state_connected - and self._reported_account_id == normalized_account_id - and self._reported_scope == scope - ): - return True - - accepted = False - try: - accepted = await self._adapter_capability.update_runtime_state( - connected=True, - account_id=normalized_account_id, - scope=server_config.connection_id, - metadata={"ws_url": server_config.build_ws_url()}, - ) - except Exception as exc: - self._logger.warning(f"NapCat 适配器上报连接就绪状态失败: {exc}") - return False - - if not accepted: - self._logger.warning("NapCat 适配器连接已建立,但 Host 未接受运行时状态更新") - return False - - self._runtime_state_connected = True - self._reported_account_id = normalized_account_id - self._reported_scope = scope - self._logger.info( - f"NapCat 适配器已激活路由: platform=qq account_id={normalized_account_id} " - f"scope={self._reported_scope or '*'}" - ) - return True - - async def report_disconnected(self) -> None: - """向 Host 上报当前连接已断开,并撤销适配器路由。""" - if not self._runtime_state_connected: - self._reported_account_id = None - self._reported_scope = None - return - - try: - await self._adapter_capability.update_runtime_state(connected=False) - except Exception as exc: - self._logger.warning(f"NapCat 适配器上报断开状态失败: {exc}") - finally: - self._runtime_state_connected = False - self._reported_account_id = None - self._reported_scope = None diff --git a/src/plugins/built_in/napcat_adapter/transport.py b/src/plugins/built_in/napcat_adapter/transport.py deleted file mode 100644 index d20de097..00000000 --- a/src/plugins/built_in/napcat_adapter/transport.py +++ /dev/null @@ -1,322 +0,0 @@ -"""NapCat 正向 WebSocket 传输层。""" - -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional, Set, cast -from uuid import uuid4 - -import asyncio -import contextlib -import json - -from napcat_adapter.config import NapCatServerConfig - -if TYPE_CHECKING: - from aiohttp import ClientWebSocketResponse as AiohttpClientWebSocketResponse - -try: - from aiohttp import ClientSession, ClientTimeout, WSMsgType - - AIOHTTP_AVAILABLE = True -except ImportError: - ClientSession = cast(Any, None) - ClientTimeout = cast(Any, None) - WSMsgType = cast(Any, None) - AIOHTTP_AVAILABLE = False - -if not TYPE_CHECKING: - AiohttpClientWebSocketResponse = Any - - -class NapCatTransportClient: - """NapCat 正向 WebSocket 客户端。""" - - def __init__( - self, - logger: Any, - on_connection_opened: Callable[[], Awaitable[None]], - on_connection_closed: Callable[[], Awaitable[None]], - on_payload: Callable[[Dict[str, Any]], Awaitable[None]], - ) -> None: - """初始化传输层客户端。 - - Args: - logger: 插件日志对象。 - on_connection_opened: 连接建立后的异步回调。 - on_connection_closed: 连接断开后的异步回调。 - on_payload: 收到非 echo 载荷后的异步回调。 - """ - self._logger = logger - self._on_connection_opened = on_connection_opened - self._on_connection_closed = on_connection_closed - self._on_payload = on_payload - self._server_config: Optional[NapCatServerConfig] = None - self._connection_task: Optional[asyncio.Task[None]] = None - self._pending_actions: Dict[str, asyncio.Future[Dict[str, Any]]] = {} - self._background_tasks: Set[asyncio.Task[Any]] = set() - self._send_lock = asyncio.Lock() - self._ws: Optional[AiohttpClientWebSocketResponse] = None - self._stop_requested: bool = False - self._connection_active: bool = False - - @classmethod - def is_available(cls) -> bool: - """判断当前环境是否安装了传输层依赖。 - - Returns: - bool: 若已安装 ``aiohttp``,则返回 ``True``。 - """ - return AIOHTTP_AVAILABLE - - def configure(self, server_config: NapCatServerConfig) -> None: - """更新当前传输层使用的 NapCat 服务端配置。 - - Args: - server_config: 最新生效的 NapCat 服务端配置。 - """ - self._server_config = server_config - - async def start(self) -> None: - """启动 NapCat 正向 WebSocket 连接循环。 - - Raises: - RuntimeError: 当缺少配置或依赖时抛出。 - """ - if not self.is_available(): - raise RuntimeError("NapCat 适配器依赖 aiohttp,但当前环境未安装该依赖") - if self._server_config is None: - raise RuntimeError("NapCat 适配器尚未配置 napcat_server") - if self._connection_task is not None and not self._connection_task.done(): - return - - self._stop_requested = False - self._connection_task = asyncio.create_task(self._connection_loop(), name="napcat_adapter.connection") - - async def stop(self) -> None: - """停止当前连接并清理所有后台任务。""" - self._stop_requested = True - connection_task = self._connection_task - self._connection_task = None - - ws = self._ws - if ws is not None and not ws.closed: - with contextlib.suppress(Exception): - await ws.close() - self._ws = None - - if connection_task is not None: - connection_task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await connection_task - - await self._cancel_background_tasks() - await self._notify_connection_closed() - self._fail_pending_actions("NapCat connection closed") - - async def call_action(self, action_name: str, params: Dict[str, Any]) -> Dict[str, Any]: - """发送 OneBot 动作并等待对应的 echo 响应。 - - Args: - action_name: OneBot 动作名称。 - params: 动作参数。 - - Returns: - Dict[str, Any]: NapCat 返回的原始响应字典。 - - Raises: - RuntimeError: 当连接不可用时抛出。 - """ - ws = self._ws - server_config = self._server_config - if ws is None or ws.closed or server_config is None: - raise RuntimeError("NapCat is not connected") - - echo_id = uuid4().hex - loop = asyncio.get_running_loop() - response_future: asyncio.Future[Dict[str, Any]] = loop.create_future() - self._pending_actions[echo_id] = response_future - - request_payload = {"action": action_name, "params": params, "echo": echo_id} - try: - async with self._send_lock: - await ws.send_str(json.dumps(request_payload, ensure_ascii=False)) - return await asyncio.wait_for(response_future, timeout=server_config.action_timeout_sec) - finally: - self._pending_actions.pop(echo_id, None) - - async def _connection_loop(self) -> None: - """维护单个 WebSocket 连接,并在断开后按配置重连。""" - assert ClientSession is not None - assert ClientTimeout is not None - - while not self._stop_requested: - server_config = self._server_config - if server_config is None: - return - - ws_url = server_config.build_ws_url() - timeout = ClientTimeout(total=None, connect=10) - - try: - async with ClientSession(headers=self._build_headers(server_config), timeout=timeout) as session: - async with session.ws_connect(ws_url, heartbeat=server_config.heartbeat_interval or None) as ws: - self._ws = ws - self._logger.info(f"NapCat 适配器已连接: {ws_url}") - await self._receive_loop(ws) - except asyncio.CancelledError: - raise - except Exception as exc: - self._logger.warning(f"NapCat 适配器连接失败: {exc}") - finally: - self._ws = None - await self._notify_connection_closed() - self._fail_pending_actions("NapCat connection interrupted") - - if self._stop_requested: - break - - await asyncio.sleep(server_config.reconnect_delay_sec) - - async def _receive_loop(self, ws: AiohttpClientWebSocketResponse) -> None: - """持续消费 WebSocket 消息并分发处理。 - - Args: - ws: 当前活跃的 WebSocket 连接对象。 - """ - assert WSMsgType is not None - - bootstrap_task = self._create_background_task( - self._notify_connection_opened(), - "napcat_adapter.bootstrap", - ) - try: - async for ws_message in ws: - if ws_message.type != WSMsgType.TEXT: - if ws_message.type in {WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.ERROR}: - break - continue - - payload = self._parse_json_message(ws_message.data) - if payload is None: - continue - - if echo_id := str(payload.get("echo") or "").strip(): - self._resolve_pending_action(echo_id, payload) - continue - - self._create_background_task(self._on_payload(payload), "napcat_adapter.payload") - finally: - if bootstrap_task is not None and not bootstrap_task.done(): - bootstrap_task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await bootstrap_task - - def _create_background_task(self, coroutine: Awaitable[Any], name: str) -> asyncio.Task[Any]: - """创建并跟踪一个后台任务。 - - Args: - coroutine: 待执行的协程对象。 - name: 任务名。 - - Returns: - asyncio.Task[Any]: 已创建的后台任务。 - """ - task = asyncio.create_task(coroutine, name=name) - self._background_tasks.add(task) - task.add_done_callback(self._handle_background_task_completion) - return task - - def _handle_background_task_completion(self, task: asyncio.Task[Any]) -> None: - """处理后台任务结束后的清理与异常记录。 - - Args: - task: 已结束的后台任务。 - """ - self._background_tasks.discard(task) - if task.cancelled(): - return - - exception = task.exception() - if exception is not None: - self._logger.error(f"NapCat 适配器后台任务异常: {exception}", exc_info=True) - - async def _cancel_background_tasks(self) -> None: - """取消所有仍在运行的后台任务。""" - background_tasks = list(self._background_tasks) - for task in background_tasks: - task.cancel() - if background_tasks: - with contextlib.suppress(Exception): - await asyncio.gather(*background_tasks, return_exceptions=True) - self._background_tasks.clear() - - async def _notify_connection_opened(self) -> None: - """在连接建立后触发上层回调。""" - if self._connection_active: - return - - self._connection_active = True - try: - await self._on_connection_opened() - except Exception as exc: - self._logger.warning(f"NapCat 适配器连接建立回调失败: {exc}") - - async def _notify_connection_closed(self) -> None: - """在连接断开后触发上层回调。""" - if not self._connection_active: - return - - self._connection_active = False - try: - await self._on_connection_closed() - except Exception as exc: - self._logger.warning(f"NapCat 适配器断连回调失败: {exc}") - - def _resolve_pending_action(self, echo_id: str, payload: Dict[str, Any]) -> None: - """解析等待中的动作响应。 - - Args: - echo_id: 动作请求对应的 echo 标识。 - payload: NapCat 返回的响应载荷。 - """ - response_future = self._pending_actions.get(echo_id) - if response_future is None or response_future.done(): - return - response_future.set_result(payload) - - def _fail_pending_actions(self, error_message: str) -> None: - """让所有等待中的动作以异常方式结束。 - - Args: - error_message: 写入异常中的错误信息。 - """ - for response_future in self._pending_actions.values(): - if not response_future.done(): - response_future.set_exception(RuntimeError(error_message)) - self._pending_actions.clear() - - def _build_headers(self, server_config: NapCatServerConfig) -> Dict[str, str]: - """构造连接 NapCat 所需的请求头。 - - Args: - server_config: 当前生效的 NapCat 服务端配置。 - - Returns: - Dict[str, str]: WebSocket 握手请求头。 - """ - return {"Authorization": f"Bearer {server_config.token}"} if server_config.token else {} - - def _parse_json_message(self, data: Any) -> Optional[Dict[str, Any]]: - """解析 WebSocket 文本消息中的 JSON 数据。 - - Args: - data: WebSocket 收到的原始文本数据。 - - Returns: - Optional[Dict[str, Any]]: 成功时返回字典,失败时返回 ``None``。 - """ - try: - payload = json.loads(str(data)) - except Exception as exc: - self._logger.warning(f"NapCat 适配器解析 JSON 载荷失败: {exc}") - return None - - return payload if isinstance(payload, dict) else None From d07e8f90ef3458419ad37d2c132a638f71771a81 Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Sun, 22 Mar 2026 18:32:14 +0800 Subject: [PATCH 29/45] fix: remove nc ada pytest --- pytests/test_napcat_adapter_codec.py | 151 -------------------------- pytests/test_napcat_adapter_config.py | 91 ---------------- pytests/test_napcat_adapter_plugin.py | 60 ---------- 3 files changed, 302 deletions(-) delete mode 100644 pytests/test_napcat_adapter_codec.py delete mode 100644 pytests/test_napcat_adapter_config.py delete mode 100644 pytests/test_napcat_adapter_plugin.py diff --git a/pytests/test_napcat_adapter_codec.py b/pytests/test_napcat_adapter_codec.py deleted file mode 100644 index 97ed1d9e..00000000 --- a/pytests/test_napcat_adapter_codec.py +++ /dev/null @@ -1,151 +0,0 @@ -from pathlib import Path -from typing import Any, Dict - -import importlib -import sys -from types import SimpleNamespace - -import pytest - - -BUILT_IN_PLUGIN_ROOT = Path(__file__).resolve().parents[1] / "src" / "plugins" / "built_in" -if str(BUILT_IN_PLUGIN_ROOT) not in sys.path: - sys.path.insert(0, str(BUILT_IN_PLUGIN_ROOT)) - -NapCatInboundCodec = importlib.import_module("napcat_adapter.codec_inbound").NapCatInboundCodec -NapCatOutboundCodec = importlib.import_module("napcat_adapter.codec_outbound").NapCatOutboundCodec - - -def test_napcat_outbound_codec_supports_binary_and_forward_segments() -> None: - codec = NapCatOutboundCodec() - raw_message = [ - {"type": "text", "data": "hello"}, - {"type": "image", "data": "", "hash": "h1", "binary_data_base64": "aW1hZ2U="}, - {"type": "emoji", "data": "", "hash": "h2", "binary_data_base64": "ZW1vamk="}, - {"type": "voice", "data": "", "hash": "h3", "binary_data_base64": "dm9pY2U="}, - { - "type": "reply", - "data": { - "target_message_id": "origin-1", - "target_message_content": "origin text", - }, - }, - { - "type": "forward", - "data": [ - { - "user_id": "42", - "user_nickname": "alice", - "user_cardname": "Alice", - "message_id": "fwd-1", - "content": [{"type": "text", "data": "node-text"}], - } - ], - }, - ] - - converted = codec.convert_segments(raw_message) - - assert converted[0] == {"type": "text", "data": {"text": "hello"}} - assert converted[1]["type"] == "image" - assert converted[1]["data"]["file"] == "base64://aW1hZ2U=" - assert converted[2]["type"] == "image" - assert converted[2]["data"]["subtype"] == 1 - assert converted[3] == {"type": "record", "data": {"file": "base64://dm9pY2U="}} - assert converted[4] == {"type": "reply", "data": {"id": "origin-1"}} - assert converted[5]["type"] == "node" - assert converted[5]["data"]["name"] == "alice" - assert converted[5]["data"]["content"] == [{"type": "text", "data": {"text": "node-text"}}] - - -def test_napcat_outbound_codec_builds_private_action_from_route_metadata() -> None: - codec = NapCatOutboundCodec() - message: Dict[str, Any] = { - "message_info": { - "user_info": {"user_id": "10001", "user_nickname": "tester"}, - "additional_config": {}, - }, - "raw_message": [{"type": "text", "data": "hello"}], - } - - action_name, params = codec.build_outbound_action(message, {"target_user_id": "30001"}) - - assert action_name == "send_private_msg" - assert params == {"message": [{"type": "text", "data": {"text": "hello"}}], "user_id": "30001"} - - -class DummyQueryService: - """用于测试的轻量查询服务。""" - - async def download_binary(self, url: str) -> bytes: - """返回固定图片二进制。 - - Args: - url: 图片地址。 - - Returns: - bytes: 固定测试图片二进制。 - """ - if url: - return b"image-bytes" - return b"" - - async def get_message_detail(self, message_id: str) -> Dict[str, Any] | None: - """返回空消息详情。 - - Args: - message_id: 目标消息 ID。 - - Returns: - Dict[str, Any] | None: 固定空结果。 - """ - del message_id - return None - - async def get_record_detail(self, file_name: str, file_id: str | None = None) -> Dict[str, Any] | None: - """返回空语音详情。 - - Args: - file_name: 语音文件名。 - file_id: 可选文件 ID。 - - Returns: - Dict[str, Any] | None: 固定空结果。 - """ - del file_name - del file_id - return None - - async def get_forward_message(self, message_id: str) -> Dict[str, Any] | None: - """返回空转发详情。 - - Args: - message_id: 转发消息 ID。 - - Returns: - Dict[str, Any] | None: 固定空结果。 - """ - del message_id - return None - - -@pytest.mark.asyncio -async def test_napcat_inbound_codec_parses_cq_string_image_segments() -> None: - codec = NapCatInboundCodec(SimpleNamespace(debug=lambda message: None), DummyQueryService()) - payload = { - "message": "[CQ:image,file=test.png,sub_type=0,url=https://example.com/test.png][CQ:at,qq=10001] 看到是国人直接给你封了", - } - - raw_message, is_at = await codec.convert_segments(payload, "10001") - - assert raw_message[0]["type"] == "image" - assert raw_message[1] == { - "type": "at", - "data": { - "target_user_id": "10001", - "target_user_nickname": None, - "target_user_cardname": None, - }, - } - assert raw_message[2] == {"type": "text", "data": " 看到是国人直接给你封了"} - assert is_at is True diff --git a/pytests/test_napcat_adapter_config.py b/pytests/test_napcat_adapter_config.py deleted file mode 100644 index 688b1a48..00000000 --- a/pytests/test_napcat_adapter_config.py +++ /dev/null @@ -1,91 +0,0 @@ -from pathlib import Path -from typing import List - -import importlib -import sys - - -BUILT_IN_PLUGIN_ROOT = Path(__file__).resolve().parents[1] / "src" / "plugins" / "built_in" -if str(BUILT_IN_PLUGIN_ROOT) not in sys.path: - sys.path.insert(0, str(BUILT_IN_PLUGIN_ROOT)) - -NapCatPluginSettings = importlib.import_module("napcat_adapter.config").NapCatPluginSettings - - -class DummyLogger: - """用于测试的轻量日志对象。""" - - def __init__(self) -> None: - """初始化测试日志对象。""" - self.warnings: List[str] = [] - self.errors: List[str] = [] - - def warning(self, message: str) -> None: - """记录警告日志。 - - Args: - message: 待记录的日志内容。 - """ - self.warnings.append(message) - - def error(self, message: str) -> None: - """记录错误日志。 - - Args: - message: 待记录的日志内容。 - """ - self.errors.append(message) - - -def test_parse_new_napcat_server_config() -> None: - logger = DummyLogger() - settings = NapCatPluginSettings.from_mapping( - { - "plugin": {"enabled": True, "config_version": "0.1.0"}, - "napcat_server": { - "host": "localhost", - "port": 8095, - "token": "secret", - "heartbeat_interval": 45, - "reconnect_delay_sec": 7, - "action_timeout_sec": 18, - "connection_id": "main", - }, - }, - logger, - ) - - assert settings.should_connect() is True - assert settings.napcat_server.host == "localhost" - assert settings.napcat_server.port == 8095 - assert settings.napcat_server.token == "secret" - assert settings.napcat_server.heartbeat_interval == 45.0 - assert settings.napcat_server.reconnect_delay_sec == 7.0 - assert settings.napcat_server.action_timeout_sec == 18.0 - assert settings.napcat_server.connection_id == "main" - assert settings.napcat_server.build_ws_url() == "ws://localhost:8095" - assert settings.validate(logger) is True - - -def test_parse_legacy_connection_ws_url_fallback() -> None: - logger = DummyLogger() - settings = NapCatPluginSettings.from_mapping( - { - "plugin": {"enabled": True, "config_version": "0.1.0"}, - "connection": { - "ws_url": "ws://127.0.0.1:3001", - "access_token": "legacy-token", - "heartbeat_sec": 35, - "action_timeout_sec": 12, - }, - }, - logger, - ) - - assert settings.napcat_server.host == "127.0.0.1" - assert settings.napcat_server.port == 3001 - assert settings.napcat_server.token == "legacy-token" - assert settings.napcat_server.heartbeat_interval == 35.0 - assert settings.napcat_server.action_timeout_sec == 12.0 - assert settings.validate(logger) is True - assert logger.warnings diff --git a/pytests/test_napcat_adapter_plugin.py b/pytests/test_napcat_adapter_plugin.py deleted file mode 100644 index ca550a39..00000000 --- a/pytests/test_napcat_adapter_plugin.py +++ /dev/null @@ -1,60 +0,0 @@ -"""NapCat 插件入口行为测试。""" - -from pathlib import Path -from typing import List -from types import SimpleNamespace - -import importlib -import sys - -import pytest - - -BUILT_IN_PLUGIN_ROOT = Path(__file__).resolve().parents[1] / "src" / "plugins" / "built_in" -if str(BUILT_IN_PLUGIN_ROOT) not in sys.path: - sys.path.insert(0, str(BUILT_IN_PLUGIN_ROOT)) - -NapCatAdapterPlugin = importlib.import_module("napcat_adapter.plugin").NapCatAdapterPlugin - - -class DummyLogger: - """用于测试的轻量日志对象。""" - - def __init__(self) -> None: - """初始化测试日志对象。""" - self.debug_messages: List[str] = [] - - def debug(self, message: str) -> None: - """记录调试日志。 - - Args: - message: 待记录的日志内容。 - """ - self.debug_messages.append(message) - - -@pytest.mark.asyncio -async def test_on_config_update_refreshes_settings_and_restarts(monkeypatch: pytest.MonkeyPatch) -> None: - """配置更新时应刷新插件配置、清空旧 settings,并触发连接重启。""" - plugin = NapCatAdapterPlugin() - plugin._ctx = SimpleNamespace(logger=DummyLogger()) - plugin._settings = object() - - restart_calls: List[dict] = [] - - async def fake_restart() -> None: - """记录一次重启调用。""" - restart_calls.append(dict(plugin._plugin_config)) - - monkeypatch.setattr(plugin, "_restart_connection_if_needed", fake_restart) - - new_config = { - "plugin": {"enabled": True, "config_version": "0.1.0"}, - "napcat_server": {"host": "127.0.0.1", "port": 3001}, - } - await plugin.on_config_update(new_config, "v2") - - assert plugin._plugin_config == new_config - assert plugin._settings is None - assert restart_calls == [new_config] - assert plugin.ctx.logger.debug_messages == ["NapCat 适配器收到配置更新通知: v2"] From e26b27c28707ee8007f40aee6f9a394e459092f4 Mon Sep 17 00:00:00 2001 From: DrSmoothl <1787882683@qq.com> Date: Mon, 23 Mar 2026 10:54:29 +0800 Subject: [PATCH 30/45] refactor: update message gateway handling and remove adapter references - Changed the message sending method to return DeliveryBatch instead of DeliveryReceipt in integration.py. - Removed AdapterDeclarationPayload and related references from envelope.py, replacing them with MessageGatewayStateUpdatePayload and MessageGatewayStateUpdateResultPayload. - Updated runner_main.py to remove adapter-related logic and methods, focusing on message gateway functionality. - Added tests for message gateway runtime state synchronization and action bridge functionality in test files. --- pytests/test_adapter_runtime_state.py | 162 ---- pytests/test_message_gateway_runtime.py | 170 ++++ pytests/test_platform_io_dedupe.py | 49 +- pytests/test_plugin_runtime_action_bridge.py | 138 ++++ .../message_receive/uni_message_sender.py | 22 +- src/platform_io/__init__.py | 7 +- src/platform_io/drivers/plugin_driver.py | 67 +- src/platform_io/manager.py | 218 ++++-- src/platform_io/routing.py | 135 +--- src/platform_io/types.py | 49 +- src/plugin_runtime/host/component_registry.py | 111 ++- src/plugin_runtime/host/message_gateway.py | 20 +- src/plugin_runtime/host/supervisor.py | 735 +++++++++++------- src/plugin_runtime/integration.py | 10 +- src/plugin_runtime/protocol/envelope.py | 53 +- src/plugin_runtime/runner/runner_main.py | 35 +- 16 files changed, 1221 insertions(+), 760 deletions(-) delete mode 100644 pytests/test_adapter_runtime_state.py create mode 100644 pytests/test_message_gateway_runtime.py create mode 100644 pytests/test_plugin_runtime_action_bridge.py diff --git a/pytests/test_adapter_runtime_state.py b/pytests/test_adapter_runtime_state.py deleted file mode 100644 index e82f4c8c..00000000 --- a/pytests/test_adapter_runtime_state.py +++ /dev/null @@ -1,162 +0,0 @@ -"""适配器运行时状态同步测试。""" - -from typing import Any, Dict - -import pytest - -from src.platform_io.manager import PlatformIOManager -from src.platform_io.types import RouteKey -from src.plugin_runtime.host.supervisor import PluginSupervisor -from src.plugin_runtime.protocol.envelope import ( - AdapterDeclarationPayload, - Envelope, - MessageType, -) - - -def _make_request(plugin_id: str, payload: Dict[str, Any]) -> Envelope: - """构造一个适配器状态更新 RPC 请求。 - - Args: - plugin_id: 目标适配器插件 ID。 - payload: 请求载荷。 - - Returns: - Envelope: 标准 RPC 请求信封。 - """ - return Envelope( - request_id=1, - message_type=MessageType.REQUEST, - method="host.update_adapter_state", - plugin_id=plugin_id, - payload=payload, - ) - - -@pytest.mark.asyncio -async def test_adapter_runtime_state_binds_and_unbinds_routes(monkeypatch: pytest.MonkeyPatch) -> None: - """连接建立后应绑定路由,断开后应撤销路由。""" - import src.plugin_runtime.host.supervisor as supervisor_module - - platform_io_manager = PlatformIOManager() - monkeypatch.setattr(supervisor_module, "get_platform_io_manager", lambda: platform_io_manager) - - supervisor = PluginSupervisor(plugin_dirs=[]) - adapter = AdapterDeclarationPayload(platform="qq", protocol="napcat") - await supervisor._register_adapter_driver("napcat_adapter_builtin", adapter) - - response = await supervisor._handle_update_adapter_state( - _make_request( - "napcat_adapter_builtin", - { - "connected": True, - "account_id": "10001", - "scope": "", - "metadata": {}, - }, - ) - ) - - assert response.error is None - assert response.payload["accepted"] is True - assert ( - platform_io_manager.route_table.get_active_binding( - RouteKey(platform="qq", account_id="10001"), - exact_only=True, - ).driver_id - == "adapter:napcat_adapter_builtin" - ) - assert ( - platform_io_manager.route_table.get_active_binding( - RouteKey(platform="qq"), - exact_only=True, - ).driver_id - == "adapter:napcat_adapter_builtin" - ) - - response = await supervisor._handle_update_adapter_state( - _make_request( - "napcat_adapter_builtin", - { - "connected": False, - "account_id": "", - "scope": "", - "metadata": {}, - }, - ) - ) - - assert response.error is None - assert response.payload["accepted"] is True - assert platform_io_manager.route_table.get_active_binding( - RouteKey(platform="qq", account_id="10001"), - exact_only=True, - ) is None - assert platform_io_manager.route_table.get_active_binding(RouteKey(platform="qq"), exact_only=True) is None - - -@pytest.mark.asyncio -async def test_platform_default_route_is_removed_when_multiple_exact_routes_exist( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """同一平台存在多个精确路由时不应保留默认平台路由。""" - import src.plugin_runtime.host.supervisor as supervisor_module - - platform_io_manager = PlatformIOManager() - monkeypatch.setattr(supervisor_module, "get_platform_io_manager", lambda: platform_io_manager) - - supervisor = PluginSupervisor(plugin_dirs=[]) - adapter = AdapterDeclarationPayload(platform="qq", protocol="napcat") - await supervisor._register_adapter_driver("adapter_a", adapter) - await supervisor._register_adapter_driver("adapter_b", adapter) - - await supervisor._handle_update_adapter_state( - _make_request( - "adapter_a", - { - "connected": True, - "account_id": "10001", - "scope": "", - "metadata": {}, - }, - ) - ) - assert ( - platform_io_manager.route_table.get_active_binding( - RouteKey(platform="qq"), - exact_only=True, - ).driver_id - == "adapter:adapter_a" - ) - - await supervisor._handle_update_adapter_state( - _make_request( - "adapter_b", - { - "connected": True, - "account_id": "10002", - "scope": "", - "metadata": {}, - }, - ) - ) - assert platform_io_manager.route_table.get_active_binding(RouteKey(platform="qq"), exact_only=True) is None - - await supervisor._handle_update_adapter_state( - _make_request( - "adapter_b", - { - "connected": False, - "account_id": "", - "scope": "", - "metadata": {}, - }, - ) - ) - assert ( - platform_io_manager.route_table.get_active_binding( - RouteKey(platform="qq"), - exact_only=True, - ).driver_id - == "adapter:adapter_a" - ) diff --git a/pytests/test_message_gateway_runtime.py b/pytests/test_message_gateway_runtime.py new file mode 100644 index 00000000..9650bc10 --- /dev/null +++ b/pytests/test_message_gateway_runtime.py @@ -0,0 +1,170 @@ +"""消息网关运行时状态同步测试。""" + +from typing import Any, Dict + +import pytest + +from src.platform_io.manager import PlatformIOManager +from src.platform_io.types import RouteKey +from src.plugin_runtime.host.supervisor import PluginSupervisor +from src.plugin_runtime.protocol.envelope import Envelope, MessageType + + +def _make_request(method: str, plugin_id: str, payload: Dict[str, Any]) -> Envelope: + """构造一个 RPC 请求信封。 + + Args: + method: RPC 方法名。 + plugin_id: 目标插件 ID。 + payload: 请求载荷。 + + Returns: + Envelope: 标准 RPC 请求信封。 + """ + + return Envelope( + request_id=1, + message_type=MessageType.REQUEST, + method=method, + plugin_id=plugin_id, + payload=payload, + ) + + +@pytest.mark.asyncio +async def test_message_gateway_runtime_state_binds_send_and_receive_routes( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """消息网关就绪后应同时绑定发送表和接收表。""" + + import src.plugin_runtime.host.supervisor as supervisor_module + + platform_io_manager = PlatformIOManager() + monkeypatch.setattr(supervisor_module, "get_platform_io_manager", lambda: platform_io_manager) + + supervisor = PluginSupervisor(plugin_dirs=[]) + register_response = await supervisor._handle_register_plugin( + _make_request( + "plugin.register_components", + "napcat_plugin", + { + "plugin_id": "napcat_plugin", + "plugin_version": "1.0.0", + "components": [ + { + "name": "napcat_gateway", + "component_type": "MESSAGE_GATEWAY", + "plugin_id": "napcat_plugin", + "metadata": { + "route_type": "duplex", + "platform": "qq", + "protocol": "napcat", + }, + } + ], + "capabilities_required": [], + }, + ) + ) + + assert register_response.error is None + response = await supervisor._handle_update_message_gateway_state( + _make_request( + "host.update_message_gateway_state", + "napcat_plugin", + { + "gateway_name": "napcat_gateway", + "ready": True, + "platform": "qq", + "account_id": "10001", + "scope": "primary", + "metadata": {}, + }, + ) + ) + + assert response.error is None + assert response.payload["accepted"] is True + + send_bindings = platform_io_manager.send_route_table.resolve_bindings( + RouteKey(platform="qq", account_id="10001", scope="primary") + ) + receive_bindings = platform_io_manager.receive_route_table.resolve_bindings( + RouteKey(platform="qq", account_id="10001", scope="primary") + ) + + assert [binding.driver_id for binding in send_bindings] == ["gateway:napcat_plugin:napcat_gateway"] + assert [binding.driver_id for binding in receive_bindings] == ["gateway:napcat_plugin:napcat_gateway"] + + +@pytest.mark.asyncio +async def test_message_gateway_runtime_state_unbinds_routes_when_not_ready( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """消息网关断开后应撤销发送表和接收表中的绑定。""" + + import src.plugin_runtime.host.supervisor as supervisor_module + + platform_io_manager = PlatformIOManager() + monkeypatch.setattr(supervisor_module, "get_platform_io_manager", lambda: platform_io_manager) + + supervisor = PluginSupervisor(plugin_dirs=[]) + await supervisor._handle_register_plugin( + _make_request( + "plugin.register_components", + "napcat_plugin", + { + "plugin_id": "napcat_plugin", + "plugin_version": "1.0.0", + "components": [ + { + "name": "napcat_gateway", + "component_type": "MESSAGE_GATEWAY", + "plugin_id": "napcat_plugin", + "metadata": { + "route_type": "duplex", + "platform": "qq", + "protocol": "napcat", + }, + } + ], + "capabilities_required": [], + }, + ) + ) + + await supervisor._handle_update_message_gateway_state( + _make_request( + "host.update_message_gateway_state", + "napcat_plugin", + { + "gateway_name": "napcat_gateway", + "ready": True, + "platform": "qq", + "account_id": "10001", + "scope": "primary", + "metadata": {}, + }, + ) + ) + response = await supervisor._handle_update_message_gateway_state( + _make_request( + "host.update_message_gateway_state", + "napcat_plugin", + { + "gateway_name": "napcat_gateway", + "ready": False, + "platform": "qq", + "account_id": "", + "scope": "", + "metadata": {}, + }, + ) + ) + + assert response.error is None + assert response.payload["accepted"] is True + assert platform_io_manager.send_route_table.resolve_bindings(RouteKey(platform="qq", account_id="10001")) == [] + assert ( + platform_io_manager.receive_route_table.resolve_bindings(RouteKey(platform="qq", account_id="10001")) == [] + ) diff --git a/pytests/test_platform_io_dedupe.py b/pytests/test_platform_io_dedupe.py index 4a3cbb44..68ae95c6 100644 --- a/pytests/test_platform_io_dedupe.py +++ b/pytests/test_platform_io_dedupe.py @@ -159,6 +159,51 @@ class TestPlatformIODedupe: session_message_envelope = _build_envelope(session_message_id="session-1") payload_only_envelope = _build_envelope(payload={"message": "hello"}) - assert PlatformIOManager._build_inbound_dedupe_key(explicit_envelope) == "qq:10001:main:dedupe-1" - assert PlatformIOManager._build_inbound_dedupe_key(session_message_envelope) == "qq:10001:main:session-1" + assert PlatformIOManager._build_inbound_dedupe_key(explicit_envelope) == "plugin.napcat:dedupe-1" + assert PlatformIOManager._build_inbound_dedupe_key(session_message_envelope) == "plugin.napcat:session-1" assert PlatformIOManager._build_inbound_dedupe_key(payload_only_envelope) is None + + @pytest.mark.asyncio + async def test_send_message_fans_out_to_all_matching_routes(self) -> None: + """同一路由命中多条发送链路时应全部发送。""" + + manager = PlatformIOManager() + first_driver = _StubPlatformIODriver( + DriverDescriptor( + driver_id="plugin.gateway_a", + kind=DriverKind.PLUGIN, + platform="qq", + ) + ) + second_driver = _StubPlatformIODriver( + DriverDescriptor( + driver_id="plugin.gateway_b", + kind=DriverKind.PLUGIN, + platform="qq", + ) + ) + manager.register_driver(first_driver) + manager.register_driver(second_driver) + manager.bind_send_route( + RouteBinding( + route_key=RouteKey(platform="qq"), + driver_id=first_driver.driver_id, + driver_kind=first_driver.descriptor.kind, + ) + ) + manager.bind_send_route( + RouteBinding( + route_key=RouteKey(platform="qq"), + driver_id=second_driver.driver_id, + driver_kind=second_driver.descriptor.kind, + ) + ) + + message = SimpleNamespace(message_id="internal-msg-1") + result = await manager.send_message(message, RouteKey(platform="qq")) + + assert result.has_success is True + assert [receipt.driver_id for receipt in result.sent_receipts] == [ + "plugin.gateway_a", + "plugin.gateway_b", + ] diff --git a/pytests/test_plugin_runtime_action_bridge.py b/pytests/test_plugin_runtime_action_bridge.py new file mode 100644 index 00000000..f2364094 --- /dev/null +++ b/pytests/test_plugin_runtime_action_bridge.py @@ -0,0 +1,138 @@ +from types import SimpleNamespace +from typing import Any + +import pytest + +from src.core.component_registry import component_registry as core_component_registry +from src.plugin_runtime.host.supervisor import PluginSupervisor +from src.plugin_runtime.protocol.envelope import ComponentDeclaration, RegisterPluginPayload + + +def _build_action_payload(plugin_id: str, action_name: str) -> RegisterPluginPayload: + """构造用于测试的 runtime Action 注册载荷。 + + Args: + plugin_id: 插件 ID。 + action_name: Action 名称。 + + Returns: + RegisterPluginPayload: 测试用注册载荷。 + """ + return RegisterPluginPayload( + plugin_id=plugin_id, + plugin_version="1.0.0", + components=[ + ComponentDeclaration( + name=action_name, + component_type="ACTION", + plugin_id=plugin_id, + metadata={ + "description": "发送一个测试回复", + "enabled": True, + "activation_type": "keyword", + "activation_probability": 0.25, + "activation_keywords": ["测试", "hello"], + "action_parameters": {"target": "目标对象"}, + "action_require": ["需要发送回复时使用"], + "associated_types": ["text"], + "parallel_action": True, + }, + ) + ], + ) + + +@pytest.mark.asyncio +async def test_runtime_actions_are_mirrored_into_core_registry_and_invoked(monkeypatch: pytest.MonkeyPatch) -> None: + """运行时 Action 应镜像到旧核心注册表,并可由旧 Planner 执行。""" + plugin_id = "runtime_action_bridge_plugin" + action_name = "runtime_action_bridge_test" + payload = _build_action_payload(plugin_id=plugin_id, action_name=action_name) + supervisor = PluginSupervisor(plugin_dirs=[]) + captured: dict[str, Any] = {} + + core_component_registry.remove_action(action_name) + + async def fake_invoke_plugin( + method: str, + plugin_id: str, + component_name: str, + args: dict[str, Any] | None = None, + timeout_ms: int = 30000, + ) -> Any: + """模拟 plugin runtime Action 调用。 + + Args: + method: RPC 方法名。 + plugin_id: 插件 ID。 + component_name: 组件名称。 + args: 调用参数。 + timeout_ms: RPC 超时时间。 + + Returns: + Any: 伪造的 RPC 响应对象。 + """ + captured["method"] = method + captured["plugin_id"] = plugin_id + captured["component_name"] = component_name + captured["args"] = args or {} + captured["timeout_ms"] = timeout_ms + return SimpleNamespace(payload={"success": True, "result": (True, "runtime action executed")}) + + monkeypatch.setattr(supervisor, "invoke_plugin", fake_invoke_plugin) + + try: + supervisor._mirror_runtime_actions_to_core_registry(payload) + + action_info = core_component_registry.get_action_info(action_name) + assert action_info is not None + assert action_info.plugin_name == plugin_id + assert action_info.description == "发送一个测试回复" + assert action_info.activation_keywords == ["测试", "hello"] + assert action_info.random_activation_probability == 0.25 + assert action_info.parallel_action is True + + executor = core_component_registry.get_action_executor(action_name) + assert executor is not None + + success, reason = await executor( + action_data={"target": "MaiBot"}, + action_reasoning="当前适合使用这个动作", + cycle_timers={"planner": 0.1}, + thinking_id="tid-1", + chat_stream=SimpleNamespace(session_id="stream-1"), + log_prefix="[test]", + shutting_down=False, + plugin_config={"enabled": True}, + ) + + assert success is True + assert reason == "runtime action executed" + assert captured["method"] == "plugin.invoke_action" + assert captured["plugin_id"] == plugin_id + assert captured["component_name"] == action_name + assert captured["args"]["stream_id"] == "stream-1" + assert captured["args"]["chat_id"] == "stream-1" + assert captured["args"]["reasoning"] == "当前适合使用这个动作" + assert captured["args"]["target"] == "MaiBot" + assert captured["args"]["action_data"] == {"target": "MaiBot"} + finally: + supervisor._remove_core_action_mirrors(plugin_id) + core_component_registry.remove_action(action_name) + + +def test_clear_runner_state_removes_mirrored_runtime_actions() -> None: + """清理 Runner 状态时应同步移除旧核心注册表中的镜像 Action。""" + plugin_id = "runtime_action_bridge_cleanup_plugin" + action_name = "runtime_action_bridge_cleanup_test" + payload = _build_action_payload(plugin_id=plugin_id, action_name=action_name) + supervisor = PluginSupervisor(plugin_dirs=[]) + + core_component_registry.remove_action(action_name) + + supervisor._mirror_runtime_actions_to_core_registry(payload) + assert core_component_registry.get_action_info(action_name) is not None + + supervisor._clear_runner_state() + + assert core_component_registry.get_action_info(action_name) is None diff --git a/src/chat/message_receive/uni_message_sender.py b/src/chat/message_receive/uni_message_sender.py index 17d5d6d5..df74e459 100644 --- a/src/chat/message_receive/uni_message_sender.py +++ b/src/chat/message_receive/uni_message_sender.py @@ -125,23 +125,27 @@ async def _send_message(message: SessionMessage, show_log: bool = True) -> bool: return True try: - from src.platform_io import DeliveryStatus from src.plugin_runtime.integration import get_plugin_runtime_manager - receipt = await get_plugin_runtime_manager().try_send_message_via_platform_io(message) - if receipt is not None: - if receipt.status == DeliveryStatus.SENT: + delivery_batch = await get_plugin_runtime_manager().try_send_message_via_platform_io(message) + if delivery_batch is not None: + if delivery_batch.has_success: + successful_driver_ids = [ + receipt.driver_id or "unknown" + for receipt in delivery_batch.sent_receipts + ] if show_log: logger.info( f"已通过 Platform IO 将消息 '{message_preview}' 发往平台'{platform}' " - f"(driver: {receipt.driver_id or 'unknown'})" + f"(drivers: {', '.join(successful_driver_ids)})" ) return True - logger.warning( - f"Platform IO 发送失败: platform={platform} driver={receipt.driver_id} " - f"status={receipt.status} error={receipt.error}" - ) + failed_details = "; ".join( + f"driver={receipt.driver_id} status={receipt.status} error={receipt.error}" + for receipt in delivery_batch.failed_receipts + ) or "未命中任何发送路由" + logger.warning(f"Platform IO 发送失败: platform={platform} {failed_details}") return False except Exception as exc: logger.warning(f"检查 Platform IO 出站链路时出现异常,将回退旧发送链: {exc}") diff --git a/src/platform_io/__init__.py b/src/platform_io/__init__.py index 380ecbb6..c91535d1 100644 --- a/src/platform_io/__init__.py +++ b/src/platform_io/__init__.py @@ -6,8 +6,9 @@ from .manager import PlatformIOManager, get_platform_io_manager from .route_key_factory import RouteKeyFactory -from .routing import RouteBindingConflictError, RouteTable +from .routing import RouteTable from .types import ( + DeliveryBatch, DeliveryReceipt, DeliveryStatus, DriverDescriptor, @@ -15,10 +16,10 @@ from .types import ( InboundMessageEnvelope, RouteBinding, RouteKey, - RouteMode, ) __all__ = [ + "DeliveryBatch", "DeliveryReceipt", "DeliveryStatus", "DriverDescriptor", @@ -27,9 +28,7 @@ __all__ = [ "PlatformIOManager", "RouteKeyFactory", "RouteBinding", - "RouteBindingConflictError", "RouteKey", - "RouteMode", "RouteTable", "get_platform_io_manager", ] diff --git a/src/platform_io/drivers/plugin_driver.py b/src/platform_io/drivers/plugin_driver.py index dff980f8..c03204ad 100644 --- a/src/platform_io/drivers/plugin_driver.py +++ b/src/platform_io/drivers/plugin_driver.py @@ -1,4 +1,4 @@ -"""提供 Platform IO 的插件适配器驱动实现。""" +"""提供 Platform IO 的插件消息网关驱动实现。""" from typing import TYPE_CHECKING, Any, Dict, Optional, Protocol @@ -9,45 +9,49 @@ if TYPE_CHECKING: from src.chat.message_receive.message import SessionMessage -class _AdapterSupervisorProtocol(Protocol): - """适配器驱动依赖的 Supervisor 最小协议。""" +class _GatewaySupervisorProtocol(Protocol): + """消息网关驱动依赖的 Supervisor 最小协议。""" - async def invoke_adapter( + async def invoke_message_gateway( self, plugin_id: str, - method_name: str, + component_name: str, args: Optional[Dict[str, Any]] = None, timeout_ms: int = 30000, ) -> Any: - """调用适配器插件专用方法。""" + """调用插件声明的消息网关方法。""" class PluginPlatformDriver(PlatformIODriver): - """面向适配器插件链路的 Platform IO 驱动。""" + """面向插件消息网关链路的 Platform IO 驱动。""" def __init__( self, driver_id: str, platform: str, - supervisor: _AdapterSupervisorProtocol, - send_method: str = "send_to_platform", + supervisor: _GatewaySupervisorProtocol, + component_name: str, + *, + supports_send: bool, account_id: Optional[str] = None, scope: Optional[str] = None, plugin_id: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None, ) -> None: - """初始化一个插件适配器驱动。 + """初始化一个插件消息网关驱动。 Args: driver_id: Broker 内的唯一驱动 ID。 - platform: 该适配器负责的平台名称。 - supervisor: 持有该适配器插件的 Supervisor。 - send_method: 出站发送时要调用的插件方法名。 + platform: 该消息网关负责的平台名称。 + supervisor: 持有该插件的 Supervisor。 + component_name: 出站时要调用的网关组件名称。 + supports_send: 当前驱动是否具备出站能力。 account_id: 可选的账号 ID 或 self ID。 scope: 可选的额外路由作用域。 - plugin_id: 拥有该适配器实现的插件 ID。 + plugin_id: 拥有该实现的插件 ID。 metadata: 可选的额外驱动元数据。 """ + descriptor = DriverDescriptor( driver_id=driver_id, kind=DriverKind.PLUGIN, @@ -59,7 +63,8 @@ class PluginPlatformDriver(PlatformIODriver): ) super().__init__(descriptor) self._supervisor = supervisor - self._send_method = send_method + self._component_name = component_name + self._supports_send = supports_send async def send_message( self, @@ -67,16 +72,27 @@ class PluginPlatformDriver(PlatformIODriver): route_key: RouteKey, metadata: Optional[Dict[str, Any]] = None, ) -> DeliveryReceipt: - """通过适配器插件发送消息。 + """通过插件消息网关发送消息。 Args: message: 要投递的内部会话消息。 route_key: Broker 为本次投递选择的路由键。 - metadata: 本次出站投递可选的 Broker 侧元数据。 + metadata: 可选的发送元数据。 Returns: - DeliveryReceipt: 由驱动返回的规范化回执。 + DeliveryReceipt: 规范化后的发送回执。 """ + + if not self._supports_send: + return DeliveryReceipt( + internal_message_id=message.message_id, + route_key=route_key, + status=DeliveryStatus.FAILED, + driver_id=self.driver_id, + driver_kind=self.descriptor.kind, + error="当前消息网关仅支持接收,不支持发送", + ) + from src.plugin_runtime.host.message_utils import PluginMessageUtils plugin_id = self.descriptor.plugin_id or "" @@ -87,14 +103,14 @@ class PluginPlatformDriver(PlatformIODriver): status=DeliveryStatus.FAILED, driver_id=self.driver_id, driver_kind=self.descriptor.kind, - error="插件适配器驱动缺少 plugin_id", + error="插件消息网关驱动缺少 plugin_id", ) try: message_dict = PluginMessageUtils._session_message_to_dict(message) - response = await self._supervisor.invoke_adapter( + response = await self._supervisor.invoke_message_gateway( plugin_id=plugin_id, - method_name=self._send_method, + component_name=self._component_name, args={ "message": message_dict, "route": { @@ -119,7 +135,7 @@ class PluginPlatformDriver(PlatformIODriver): return self._build_receipt(message.message_id, route_key, response) def _build_receipt(self, internal_message_id: str, route_key: RouteKey, response: Any) -> DeliveryReceipt: - """将适配器调用响应归一化为出站回执。 + """将网关调用响应归一化为出站回执。 Args: internal_message_id: 内部消息 ID。 @@ -129,8 +145,9 @@ class PluginPlatformDriver(PlatformIODriver): Returns: DeliveryReceipt: 标准化后的出站回执。 """ + if getattr(response, "error", None): - error = response.error.get("message", "适配器发送失败") + error = response.error.get("message", "消息网关发送失败") return DeliveryReceipt( internal_message_id=internal_message_id, route_key=route_key, @@ -149,7 +166,7 @@ class PluginPlatformDriver(PlatformIODriver): status=DeliveryStatus.FAILED, driver_id=self.driver_id, driver_kind=self.descriptor.kind, - error=str(payload.get("result", "适配器发送失败")) if isinstance(payload, dict) else "适配器发送失败", + error=str(payload.get("result", "消息网关发送失败")) if isinstance(payload, dict) else "消息网关发送失败", ) result = payload.get("result") if isinstance(payload, dict) else None @@ -161,7 +178,7 @@ class PluginPlatformDriver(PlatformIODriver): status=DeliveryStatus.FAILED, driver_id=self.driver_id, driver_kind=self.descriptor.kind, - error=str(result.get("error", "适配器发送失败")), + error=str(result.get("error", "消息网关发送失败")), metadata=result.get("metadata", {}) if isinstance(result.get("metadata"), dict) else {}, ) external_message_id = str(result.get("external_message_id") or result.get("message_id") or "") or None diff --git a/src/platform_io/manager.py b/src/platform_io/manager.py index b1fe3bdc..c96a9ddd 100644 --- a/src/platform_io/manager.py +++ b/src/platform_io/manager.py @@ -10,7 +10,7 @@ from .outbound_tracker import OutboundTracker from .route_key_factory import RouteKeyFactory from .registry import DriverRegistry from .routing import RouteTable -from .types import DeliveryReceipt, DeliveryStatus, InboundMessageEnvelope, RouteBinding, RouteKey +from .types import DeliveryBatch, DeliveryReceipt, DeliveryStatus, InboundMessageEnvelope, RouteBinding, RouteKey if TYPE_CHECKING: from src.chat.message_receive.message import SessionMessage @@ -21,17 +21,21 @@ InboundDispatcher = Callable[[InboundMessageEnvelope], Awaitable[None]] class PlatformIOManager: - """统一协调双路径平台消息 IO 的路由、去重与状态跟踪。 + """统一协调平台消息 IO 的路由、去重与状态跟踪。 - 这个管理器预期会成为 legacy 适配器链路与 plugin 适配器链路之间的 - 唯一裁决点。当前地基阶段,它只提供共享状态和 Broker 侧契约,还没有 - 真正把生产流量切到新中间层。 + 与旧实现不同,这个管理器不再负责“多条链路谁该接管平台”的裁决, + 只维护发送表和接收表两张轻量路由表: + + - 发送时:解析所有命中的发送绑定并全部投递。 + - 接收时:只校验当前驱动是否已登记为可接收链路,然后全部放行给上层。 + - 去重时:仅对单条链路做技术性重放抑制,不做跨链路语义去重。 """ def __init__(self) -> None: """初始化 Broker 管理器及其内存状态。""" self._driver_registry = DriverRegistry() - self._route_table = RouteTable() + self._send_route_table = RouteTable() + self._receive_route_table = RouteTable() self._deduplicator = MessageDeduplicator() self._outbound_tracker = OutboundTracker() self._inbound_dispatcher: Optional[InboundDispatcher] = None @@ -152,13 +156,22 @@ class PlatformIOManager: return self._driver_registry @property - def route_table(self) -> RouteTable: - """返回管理器持有的路由绑定表。 + def send_route_table(self) -> RouteTable: + """返回发送路由表。""" - Returns: - RouteTable: 用于归属解析的路由绑定表。 - """ - return self._route_table + return self._send_route_table + + @property + def receive_route_table(self) -> RouteTable: + """返回接收路由表。""" + + return self._receive_route_table + + @property + def route_table(self) -> RouteTable: + """兼容旧接口,返回发送路由表。""" + + return self._send_route_table @property def deduplicator(self) -> MessageDeduplicator: @@ -257,15 +270,15 @@ class PlatformIOManager: return None removed_driver.clear_inbound_handler() - self._route_table.remove_bindings_by_driver(driver_id) + self._send_route_table.remove_bindings_by_driver(driver_id) + self._receive_route_table.remove_bindings_by_driver(driver_id) return removed_driver - def bind_route(self, binding: RouteBinding, *, replace: bool = False) -> None: - """为某个路由键绑定驱动。 + def bind_send_route(self, binding: RouteBinding) -> None: + """为某个路由键绑定发送驱动。 Args: binding: 要保存的路由绑定。 - replace: 是否允许替换已有的精确 active owner。 Raises: ValueError: 当绑定引用了不存在的驱动,或者绑定与驱动描述不一致时抛出。 @@ -275,30 +288,78 @@ class PlatformIOManager: raise ValueError(f"驱动 {binding.driver_id} 未注册,无法绑定路由") self._validate_binding_against_driver(binding, driver) - self._route_table.bind(binding, replace=replace) + self._send_route_table.bind(binding) - def unbind_route(self, route_key: RouteKey, driver_id: Optional[str] = None) -> None: - """移除一个或多个路由绑定。 + def bind_receive_route(self, binding: RouteBinding) -> None: + """为某个路由键绑定接收驱动。 + + Args: + binding: 要保存的路由绑定。 + + Raises: + ValueError: 当绑定引用了不存在的驱动,或者绑定与驱动描述不一致时抛出。 + """ + driver = self._driver_registry.get(binding.driver_id) + if driver is None: + raise ValueError(f"驱动 {binding.driver_id} 未注册,无法绑定路由") + + self._validate_binding_against_driver(binding, driver) + self._receive_route_table.bind(binding) + + def bind_route(self, binding: RouteBinding) -> None: + """兼容旧接口,默认同时绑定发送表和接收表。""" + + self.bind_send_route(binding) + self.bind_receive_route(binding) + + def unbind_send_route(self, route_key: RouteKey, driver_id: Optional[str] = None) -> None: + """移除发送路由绑定。 Args: route_key: 要移除绑定的路由键。 driver_id: 可选的特定驱动 ID。 """ - self._route_table.unbind(route_key, driver_id) - def resolve_driver(self, route_key: RouteKey) -> Optional[PlatformIODriver]: - """解析某个路由键当前的 active 驱动。 + self._send_route_table.unbind(route_key, driver_id) + + def unbind_receive_route(self, route_key: RouteKey, driver_id: Optional[str] = None) -> None: + """移除接收路由绑定。 + + Args: + route_key: 要移除绑定的路由键。 + driver_id: 可选的特定驱动 ID。 + """ + + self._receive_route_table.unbind(route_key, driver_id) + + def unbind_route(self, route_key: RouteKey, driver_id: Optional[str] = None) -> None: + """兼容旧接口,默认同时从发送表和接收表解绑。""" + + self.unbind_send_route(route_key, driver_id) + self.unbind_receive_route(route_key, driver_id) + + def resolve_drivers(self, route_key: RouteKey) -> List[PlatformIODriver]: + """解析某个路由键当前命中的全部发送驱动。 Args: route_key: 要解析的路由键。 Returns: - Optional[PlatformIODriver]: 若存在 active 驱动,则返回该驱动实例。 + List[PlatformIODriver]: 当前命中的全部发送驱动。 """ - active_binding = self._route_table.get_active_binding(route_key) - if active_binding is None: - return None - return self._driver_registry.get(active_binding.driver_id) + + drivers: List[PlatformIODriver] = [] + for binding in self._send_route_table.resolve_bindings(route_key): + driver = self._driver_registry.get(binding.driver_id) + if driver is not None: + drivers.append(driver) + return drivers + + def resolve_driver(self, route_key: RouteKey) -> Optional[PlatformIODriver]: + """兼容旧接口,返回首个命中的发送驱动。""" + + drivers = self.resolve_drivers(route_key) + return drivers[0] if drivers else None @staticmethod def build_route_key_from_message(message: "SessionMessage") -> RouteKey: @@ -335,9 +396,9 @@ class PlatformIOManager: 否则返回 ``False``。 """ - if not self._route_table.accepts_inbound(envelope.route_key, envelope.driver_id): + if not self._receive_route_table.has_binding_for_driver(envelope.route_key, envelope.driver_id): logger.info( - "忽略非 active owner 的入站消息: route=%s driver=%s", + "忽略未登记到接收路由表的入站消息: route=%s driver=%s", envelope.route_key, envelope.driver_id, ) @@ -361,8 +422,8 @@ class PlatformIOManager: message: "SessionMessage", route_key: RouteKey, metadata: Optional[Dict[str, Any]] = None, - ) -> DeliveryReceipt: - """通过 Broker 选中的驱动发送一条消息。 + ) -> DeliveryBatch: + """通过 Broker 选中的全部发送驱动广播一条消息。 Args: message: 要投递的内部会话消息。 @@ -370,61 +431,54 @@ class PlatformIOManager: metadata: 可选的额外 Broker 侧元数据。 Returns: - DeliveryReceipt: 规范化后的出站回执。若路由不存在、驱动缺失, - 或同一消息已存在未完成的出站跟踪,也会返回失败回执而不是抛异常。 + DeliveryBatch: 规范化后的批量出站回执。 """ + drivers = self.resolve_drivers(route_key) + if not drivers: + return DeliveryBatch(internal_message_id=message.message_id, route_key=route_key) - active_binding = self._route_table.get_active_binding(route_key) - if active_binding is None: - return DeliveryReceipt( - internal_message_id=message.message_id, - route_key=route_key, - status=DeliveryStatus.FAILED, - error="未找到 active 路由绑定", - ) + receipts: List[DeliveryReceipt] = [] + for driver in drivers: + try: + self._outbound_tracker.begin_tracking( + internal_message_id=message.message_id, + route_key=route_key, + driver_id=driver.driver_id, + metadata=metadata, + ) + except ValueError as exc: + receipts.append( + DeliveryReceipt( + internal_message_id=message.message_id, + route_key=route_key, + status=DeliveryStatus.FAILED, + driver_id=driver.driver_id, + driver_kind=driver.descriptor.kind, + error=str(exc), + ) + ) + continue - driver = self._driver_registry.get(active_binding.driver_id) - if driver is None: - return DeliveryReceipt( - internal_message_id=message.message_id, - route_key=route_key, - status=DeliveryStatus.FAILED, - driver_id=active_binding.driver_id, - driver_kind=active_binding.driver_kind, - error="active 路由绑定对应的驱动不存在", - ) + try: + receipt = await driver.send_message(message=message, route_key=route_key, metadata=metadata) + except Exception as exc: + receipt = DeliveryReceipt( + internal_message_id=message.message_id, + route_key=route_key, + status=DeliveryStatus.FAILED, + driver_id=driver.driver_id, + driver_kind=driver.descriptor.kind, + error=str(exc), + ) - try: - self._outbound_tracker.begin_tracking( - internal_message_id=message.message_id, - route_key=route_key, - driver_id=driver.driver_id, - metadata=metadata, - ) - except ValueError as exc: - return DeliveryReceipt( - internal_message_id=message.message_id, - route_key=route_key, - status=DeliveryStatus.FAILED, - driver_id=driver.driver_id, - driver_kind=driver.descriptor.kind, - error=str(exc), - ) + self._outbound_tracker.finish_tracking(receipt) + receipts.append(receipt) - try: - receipt = await driver.send_message(message=message, route_key=route_key, metadata=metadata) - except Exception as exc: - receipt = DeliveryReceipt( - internal_message_id=message.message_id, - route_key=route_key, - status=DeliveryStatus.FAILED, - driver_id=driver.driver_id, - driver_kind=driver.descriptor.kind, - error=str(exc), - ) - - self._outbound_tracker.finish_tracking(receipt) - return receipt + return DeliveryBatch( + internal_message_id=message.message_id, + route_key=route_key, + receipts=receipts, + ) @staticmethod def _build_inbound_dedupe_key(envelope: InboundMessageEnvelope) -> Optional[str]: @@ -453,7 +507,7 @@ class PlatformIOManager: if not normalized_dedupe_key: return None - return f"{envelope.route_key.to_dedupe_scope()}:{normalized_dedupe_key}" + return f"{envelope.driver_id}:{normalized_dedupe_key}" @staticmethod def _validate_binding_against_driver(binding: RouteBinding, driver: PlatformIODriver) -> None: diff --git a/src/platform_io/routing.py b/src/platform_io/routing.py index 7f85bbfa..2a9b41ef 100644 --- a/src/platform_io/routing.py +++ b/src/platform_io/routing.py @@ -1,52 +1,29 @@ -"""提供 Platform IO 的路由绑定存储与归属解析能力。""" +"""提供 Platform IO 的轻量路由绑定表。""" from typing import Dict, List, Optional -from .types import RouteBinding, RouteKey, RouteMode - - -class RouteBindingConflictError(ValueError): - """当同一路由键出现多个 active owner 竞争时抛出。""" +from .types import RouteBinding, RouteKey class RouteTable: - """维护路由绑定并解析路由归属。 + """维护单张路由绑定表。 - 这个表刻意保持轻量,只负责归属规则本身,不掺杂具体发送或接收逻辑。 - 它决定某个路由键当前由哪个驱动 active 接管,哪些驱动仅以 shadow - 方式旁路观测。 + 该实现不负责裁决“唯一 owner”,只负责保存绑定,并按 + ``RouteKey.resolution_order()`` 解析出候选绑定列表。 """ def __init__(self) -> None: - """初始化一个空的路由绑定表。""" + """初始化空路由绑定表。""" + self._bindings: Dict[RouteKey, Dict[str, RouteBinding]] = {} - def bind(self, binding: RouteBinding, *, replace: bool = False) -> None: + def bind(self, binding: RouteBinding) -> None: """注册或更新一条路由绑定。 Args: - binding: 要注册的绑定对象。 - replace: 当精确路由键上已经存在 active owner 时,是否允许替换。 - - Raises: - RouteBindingConflictError: 当精确路由键上已存在其他 active owner, - 且 ``replace`` 为 ``False`` 时抛出。 + binding: 要保存的路由绑定。 """ - if binding.mode == RouteMode.DISABLED: - self.unbind(binding.route_key, binding.driver_id) - return - - if binding.mode == RouteMode.ACTIVE: - active_binding = self.get_active_binding(binding.route_key, exact_only=True) - if active_binding and active_binding.driver_id != binding.driver_id: - if not replace: - raise RouteBindingConflictError( - f"RouteKey {binding.route_key} 已由 {active_binding.driver_id} 接管," - f"拒绝绑定到 {binding.driver_id}" - ) - self.unbind(binding.route_key, active_binding.driver_id) - self._bindings.setdefault(binding.route_key, {})[binding.driver_id] = binding def unbind(self, route_key: RouteKey, driver_id: Optional[str] = None) -> List[RouteBinding]: @@ -54,7 +31,7 @@ class RouteTable: Args: route_key: 要移除绑定的路由键。 - driver_id: 可选的特定驱动 ID;若为空,则移除该路由键上的全部绑定。 + driver_id: 可选的驱动 ID;为空时移除该路由键下全部绑定。 Returns: List[RouteBinding]: 被移除的绑定列表。 @@ -67,15 +44,15 @@ class RouteTable: if driver_id is None: removed = list(binding_map.values()) self._bindings.pop(route_key, None) - return removed + return self._sort_bindings(removed) removed_binding = binding_map.pop(driver_id, None) if not binding_map: self._bindings.pop(route_key, None) - return [removed_binding] if removed_binding else [] + return [removed_binding] if removed_binding is not None else [] def remove_bindings_by_driver(self, driver_id: str) -> List[RouteBinding]: - """移除某个驱动在所有路由键上的绑定。 + """移除某个驱动在整张表上的全部绑定。 Args: driver_id: 要移除绑定的驱动 ID。 @@ -83,9 +60,9 @@ class RouteTable: Returns: List[RouteBinding]: 被移除的绑定列表。 """ + removed_bindings: List[RouteBinding] = [] empty_route_keys: List[RouteKey] = [] - for route_key, binding_map in self._bindings.items(): removed_binding = binding_map.pop(driver_id, None) if removed_binding is not None: @@ -99,13 +76,13 @@ class RouteTable: return self._sort_bindings(removed_bindings) def list_bindings(self, route_key: Optional[RouteKey] = None) -> List[RouteBinding]: - """列出当前绑定。 + """列出当前路由表中的绑定。 Args: - route_key: 可选的路由键过滤条件;若为空,则返回全部路由键上的绑定。 + route_key: 可选的路由键过滤条件。 Returns: - List[RouteBinding]: 按优先级降序排列的绑定列表。 + List[RouteBinding]: 当前绑定列表。 """ if route_key is None: @@ -117,51 +94,38 @@ class RouteTable: binding_map = self._bindings.get(route_key, {}) return self._sort_bindings(list(binding_map.values())) - def get_active_binding(self, route_key: RouteKey, *, exact_only: bool = False) -> Optional[RouteBinding]: - """获取某个路由键当前生效的 active 绑定。 + def resolve_bindings(self, route_key: RouteKey) -> List[RouteBinding]: + """按从具体到宽泛的顺序解析路由候选绑定。 Args: - route_key: 要解析的路由键。 - exact_only: 是否只检查精确路由键而不做回退解析。 + route_key: 待解析的路由键。 Returns: - Optional[RouteBinding]: 若存在 active owner,则返回对应绑定。 + List[RouteBinding]: 去重后的候选绑定列表。 """ - candidate_keys = [route_key] if exact_only else route_key.resolution_order() - for candidate_key in candidate_keys: - binding_map = self._bindings.get(candidate_key, {}) - active_binding = self._pick_best_binding(binding_map, RouteMode.ACTIVE) - if active_binding is not None: - return active_binding - return None + resolved_bindings: List[RouteBinding] = [] + seen_driver_ids: set[str] = set() + for candidate_key in route_key.resolution_order(): + for binding in self.list_bindings(candidate_key): + if binding.driver_id in seen_driver_ids: + continue + seen_driver_ids.add(binding.driver_id) + resolved_bindings.append(binding) + return resolved_bindings - def get_shadow_bindings(self, route_key: RouteKey) -> List[RouteBinding]: - """获取某个精确路由键上的 shadow 绑定。 + def has_binding_for_driver(self, route_key: RouteKey, driver_id: str) -> bool: + """判断指定驱动是否在当前路由键解析结果中。 Args: - route_key: 要查看的路由键。 + route_key: 待解析的路由键。 + driver_id: 目标驱动 ID。 Returns: - List[RouteBinding]: 按优先级降序排列的 shadow 绑定列表。 - """ - binding_map = self._bindings.get(route_key, {}) - shadow_bindings = [binding for binding in binding_map.values() if binding.mode == RouteMode.SHADOW] - return self._sort_bindings(shadow_bindings) - - def accepts_inbound(self, route_key: RouteKey, driver_id: str) -> bool: - """判断某个驱动是否是当前允许入 Core 的 active owner。 - - Args: - route_key: 入站消息对应的路由键。 - driver_id: 希望将消息送入 Core 的驱动 ID。 - - Returns: - bool: 若该驱动是解析结果中的 active owner,则返回 ``True``。 + bool: 若驱动存在于解析结果中则返回 ``True``。 """ - active_binding = self.get_active_binding(route_key) - return active_binding is not None and active_binding.driver_id == driver_id + return any(binding.driver_id == driver_id for binding in self.resolve_bindings(route_key)) @staticmethod def _sort_bindings(bindings: List[RouteBinding]) -> List[RouteBinding]: @@ -173,30 +137,5 @@ class RouteTable: Returns: List[RouteBinding]: 排序后的绑定列表。 """ + return sorted(bindings, key=lambda item: item.priority, reverse=True) - - @staticmethod - def _pick_best_binding( - binding_map: Dict[str, RouteBinding], - mode: RouteMode, - ) -> Optional[RouteBinding]: - """从绑定映射中挑选指定模式下优先级最高的一条绑定。 - - Args: - binding_map: 某个精确 ``RouteKey`` 对应的绑定映射。 - mode: 需要挑选的绑定模式。 - - Returns: - Optional[RouteBinding]: 若存在匹配模式的绑定,则返回优先级最高的一条。 - - Notes: - 这里使用单次线性扫描代替“先过滤成列表再排序”的做法,以减少 - 高频路由解析路径上的临时对象分配和排序开销。 - """ - best_binding: Optional[RouteBinding] = None - for binding in binding_map.values(): - if binding.mode != mode: - continue - if best_binding is None or binding.priority > best_binding.priority: - best_binding = binding - return best_binding diff --git a/src/platform_io/types.py b/src/platform_io/types.py index 8729b637..200eca51 100644 --- a/src/platform_io/types.py +++ b/src/platform_io/types.py @@ -19,14 +19,6 @@ class DriverKind(str, Enum): PLUGIN = "plugin" -class RouteMode(str, Enum): - """路由归属模式枚举。""" - - ACTIVE = "active" - SHADOW = "shadow" - DISABLED = "disabled" - - class DeliveryStatus(str, Enum): """统一出站回执状态枚举。""" @@ -158,21 +150,19 @@ class DriverDescriptor: @dataclass(frozen=True, slots=True) class RouteBinding: - """表示一条从路由键到驱动的归属绑定关系。 + """表示一条从路由键到驱动的绑定关系。 Attributes: route_key: 该绑定覆盖的路由键。 - driver_id: 拥有或旁路观察该路由的驱动 ID。 + driver_id: 拥有该路由的驱动 ID。 driver_kind: 绑定驱动的类型。 - mode: 绑定模式,例如 active owner 或 shadow observer。 - priority: 当同模式下存在多条绑定时使用的相对优先级。 + priority: 当同一路由键存在多条绑定时使用的相对优先级。 metadata: 预留给未来路由策略的额外绑定元数据。 """ route_key: RouteKey driver_id: str driver_kind: DriverKind - mode: RouteMode = RouteMode.ACTIVE priority: int = 0 metadata: Dict[str, Any] = field(default_factory=dict) @@ -239,3 +229,36 @@ class DeliveryReceipt: external_message_id: Optional[str] = None error: Optional[str] = None metadata: Dict[str, Any] = field(default_factory=dict) + + +@dataclass(slots=True) +class DeliveryBatch: + """表示一次广播式出站投递的批量结果。 + + Attributes: + internal_message_id: 内部消息 ID。 + route_key: 本次投递使用的路由键。 + receipts: 各条路由的独立投递回执列表。 + """ + + internal_message_id: str + route_key: RouteKey + receipts: List[DeliveryReceipt] = field(default_factory=list) + + @property + def sent_receipts(self) -> List[DeliveryReceipt]: + """返回全部发送成功的回执。""" + + return [receipt for receipt in self.receipts if receipt.status == DeliveryStatus.SENT] + + @property + def failed_receipts(self) -> List[DeliveryReceipt]: + """返回全部发送失败的回执。""" + + return [receipt for receipt in self.receipts if receipt.status != DeliveryStatus.SENT] + + @property + def has_success(self) -> bool: + """返回当前批量投递是否至少命中一条成功回执。""" + + return bool(self.sent_receipts) diff --git a/src/plugin_runtime/host/component_registry.py b/src/plugin_runtime/host/component_registry.py index 95da0052..08b0ea3b 100644 --- a/src/plugin_runtime/host/component_registry.py +++ b/src/plugin_runtime/host/component_registry.py @@ -119,12 +119,52 @@ class MessageGatewayEntry(ComponentEntry): """MessageGateway 组件条目""" def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None: - platform = metadata.get("platform") - if not platform or not isinstance(platform, str): - raise ValueError(f"MessageGateway 组件 {plugin_id}.{name} 缺少有效的 platform 字段") - self.platform: str = platform + self.route_type: str = self._normalize_route_type(metadata.get("route_type", "")) + self.platform: str = str(metadata.get("platform", "") or "").strip() + self.protocol: str = str(metadata.get("protocol", "") or "").strip() + self.account_id: str = str(metadata.get("account_id", "") or "").strip() + self.scope: str = str(metadata.get("scope", "") or "").strip() super().__init__(name, component_type, plugin_id, metadata) + @staticmethod + def _normalize_route_type(raw_value: Any) -> str: + """规范化消息网关路由类型。 + + Args: + raw_value: 原始路由类型值。 + + Returns: + str: 规范化后的路由类型。 + + Raises: + ValueError: 当路由类型不受支持时抛出。 + """ + + normalized_value = str(raw_value or "").strip().lower() + route_type_aliases = { + "send": "send", + "receive": "receive", + "recv": "receive", + "recive": "receive", + "duplex": "duplex", + } + route_type = route_type_aliases.get(normalized_value) + if route_type is None: + raise ValueError(f"MessageGateway 路由类型不合法: {raw_value}") + return route_type + + @property + def supports_send(self) -> bool: + """返回当前网关是否支持出站。""" + + return self.route_type in {"send", "duplex"} + + @property + def supports_receive(self) -> bool: + """返回当前网关是否支持入站。""" + + return self.route_type in {"receive", "duplex"} + class ComponentRegistry: """Host-side 组件注册表 @@ -404,26 +444,71 @@ class ComponentRegistry: handlers.sort(key=lambda c: c.priority, reverse=True) return handlers - def get_message_gateways( - self, platform: str, *, enabled_only: bool = True, session_id: Optional[str] = None + def get_message_gateway( + self, + plugin_id: str, + name: str, + *, + enabled_only: bool = True, + session_id: Optional[str] = None, ) -> Optional[MessageGatewayEntry]: - """查询消息网关组件。 + """按插件和组件名获取单个消息网关。 Args: - platform (str): 平台名称 - enabled_only (bool): 是否仅返回启用的组件 - session_id (Optional[str]): 可选的会话ID,若提供则考虑会话禁用状态 + plugin_id: 插件 ID。 + name: 网关组件名称。 + enabled_only: 是否仅返回启用的组件。 + session_id: 可选的会话 ID。 + Returns: - gateway (Optional[MessageGatewayEntry]): 符合条件的 MessageGateway 组件,可能不存在 + Optional[MessageGatewayEntry]: 若存在则返回消息网关条目。 """ + component = self._components.get(f"{plugin_id}.{name}") + if not isinstance(component, MessageGatewayEntry): + return None + if enabled_only and not self.check_component_enabled(component, session_id): + return None + return component + + def get_message_gateways( + self, + *, + plugin_id: Optional[str] = None, + platform: str = "", + route_type: str = "", + enabled_only: bool = True, + session_id: Optional[str] = None, + ) -> List[MessageGatewayEntry]: + """查询消息网关组件列表。 + + Args: + plugin_id: 可选的插件 ID 过滤条件。 + platform: 可选的平台过滤条件。 + route_type: 可选的路由类型过滤条件。 + enabled_only: 是否仅返回启用的组件。 + session_id: 可选的会话 ID。 + + Returns: + List[MessageGatewayEntry]: 符合条件的消息网关组件列表。 + """ + + normalized_platform = str(platform or "").strip() + normalized_route_type = str(route_type or "").strip().lower() + gateways: List[MessageGatewayEntry] = [] for comp in self._by_type.get(ComponentTypes.MESSAGE_GATEWAY, {}).values(): if not isinstance(comp, MessageGatewayEntry): continue + if plugin_id and comp.plugin_id != plugin_id: + continue if enabled_only and not self.check_component_enabled(comp, session_id): continue - if comp.platform == platform: - return comp # 返回第一个 + if normalized_platform and comp.platform != normalized_platform: + continue + if normalized_route_type and comp.route_type != normalized_route_type: + continue + gateways.append(comp) + return gateways def get_tools(self, *, enabled_only: bool = True, session_id: Optional[str] = None) -> List[ToolEntry]: """查询所有工具组件。 diff --git a/src/plugin_runtime/host/message_gateway.py b/src/plugin_runtime/host/message_gateway.py index 9e8e9be6..90f94493 100644 --- a/src/plugin_runtime/host/message_gateway.py +++ b/src/plugin_runtime/host/message_gateway.py @@ -1,12 +1,9 @@ -""" -Message Gateway 模块 -适配器专用,用于将其他平台的消息转换为系统内部的消息格式,并将系统消息转换为其他平台的格式。 -""" +"""Host 侧消息网关包装器。""" from typing import TYPE_CHECKING, Any, Dict from src.common.logger import get_logger -from src.platform_io import DeliveryStatus, get_platform_io_manager +from src.platform_io import get_platform_io_manager from .message_utils import PluginMessageUtils @@ -50,7 +47,7 @@ class MessageGateway: internal_message: 内部消息对象。 Returns: - Dict[str, Any]: 供适配器插件消费的标准消息字典。 + Dict[str, Any]: 供消息网关插件消费的标准消息字典。 """ return dict(PluginMessageUtils._session_message_to_dict(internal_message)) @@ -83,7 +80,7 @@ class MessageGateway: Args: internal_message: 系统内部的 ``SessionMessage`` 对象。 supervisor: 当前持有该消息网关的 Supervisor。 - enabled_only: 兼容旧签名的保留参数,当前由 Platform IO 统一裁决。 + enabled_only: 兼容旧签名的保留参数,当前未使用。 save_to_db: 发送成功后是否写入数据库。 Returns: @@ -98,12 +95,13 @@ class MessageGateway: return False route_key = platform_io_manager.build_route_key_from_message(internal_message) - receipt = await platform_io_manager.send_message(internal_message, route_key) - if receipt.status != DeliveryStatus.SENT: - logger.warning(f"通过适配器链路发送消息失败: {receipt.error or receipt.status}") + delivery_batch = await platform_io_manager.send_message(internal_message, route_key) + if not delivery_batch.has_success: + logger.warning("通过消息网关链路发送消息失败: 未命中任何成功回执") return False - internal_message.message_id = receipt.external_message_id or internal_message.message_id + first_successful_receipt = delivery_batch.sent_receipts[0] + internal_message.message_id = first_successful_receipt.external_message_id or internal_message.message_id if save_to_db: try: from src.common.utils.utils_message import MessageUtils diff --git a/src/plugin_runtime/host/supervisor.py b/src/plugin_runtime/host/supervisor.py index 8a26af11..3588934e 100644 --- a/src/plugin_runtime/host/supervisor.py +++ b/src/plugin_runtime/host/supervisor.py @@ -9,24 +9,24 @@ import sys from src.common.logger import get_logger from src.config.config import global_config -from src.platform_io import DriverKind, InboundMessageEnvelope, RouteBinding, RouteKey, RouteMode, get_platform_io_manager +from src.core.component_registry import component_registry as core_component_registry +from src.core.types import ActionActivationType, ActionInfo, ComponentType as CoreComponentType +from src.platform_io import DriverKind, InboundMessageEnvelope, RouteBinding, RouteKey, get_platform_io_manager from src.platform_io.drivers import PluginPlatformDriver from src.platform_io.route_key_factory import RouteKeyFactory -from src.platform_io.routing import RouteBindingConflictError from src.plugin_runtime import ENV_HOST_VERSION, ENV_IPC_ADDRESS, ENV_PLUGIN_DIRS, ENV_SESSION_TOKEN from src.plugin_runtime.protocol.envelope import ( - AdapterDeclarationPayload, - AdapterStateUpdatePayload, - AdapterStateUpdateResultPayload, BootstrapPluginPayload, ConfigUpdatedPayload, Envelope, HealthPayload, + MessageGatewayStateUpdatePayload, + MessageGatewayStateUpdateResultPayload, PROTOCOL_VERSION, - ReceiveExternalMessagePayload, ReceiveExternalMessageResultPayload, RegisterPluginPayload, ReloadPluginResultPayload, + RouteMessagePayload, RunnerReadyPayload, ShutdownPayload, UnregisterPluginPayload, @@ -49,15 +49,12 @@ if TYPE_CHECKING: logger = get_logger("plugin_runtime.host.runner_manager") -_ADAPTER_BINDING_ROLE_RUNTIME_EXACT = "runtime_exact" -_ADAPTER_BINDING_ROLE_PLATFORM_DEFAULT = "platform_default" - - @dataclass(slots=True) -class _AdapterRuntimeState: - """保存适配器插件当前的运行时连接状态。""" +class _MessageGatewayRuntimeState: + """保存消息网关当前的运行时连接状态。""" - connected: bool = False + ready: bool = False + platform: Optional[str] = None account_id: Optional[str] = None scope: Optional[str] = None metadata: Dict[str, Any] = field(default_factory=dict) @@ -109,8 +106,8 @@ class PluginRunnerSupervisor: self._runner_process: Optional[asyncio.subprocess.Process] = None self._registered_plugins: Dict[str, RegisterPluginPayload] = {} - self._registered_adapters: Dict[str, AdapterDeclarationPayload] = {} - self._adapter_runtime_states: Dict[str, _AdapterRuntimeState] = {} + self._message_gateway_states: Dict[str, Dict[str, _MessageGatewayRuntimeState]] = {} + self._mirrored_core_actions: Dict[str, List[str]] = {} self._runner_ready_events: asyncio.Event = asyncio.Event() self._runner_ready_payloads: RunnerReadyPayload = RunnerReadyPayload() self._health_task: Optional[asyncio.Task[None]] = None @@ -289,28 +286,29 @@ class PluginRunnerSupervisor: timeout_ms, ) - async def invoke_adapter( + async def invoke_message_gateway( self, plugin_id: str, - method_name: str, + component_name: str, args: Optional[Dict[str, Any]] = None, timeout_ms: int = 30000, ) -> Envelope: - """调用适配器插件的专用方法。 + """调用插件声明的消息网关方法。 Args: - plugin_id: 目标适配器插件 ID。 - method_name: 要调用的插件方法名,例如 ``send_to_platform``。 - args: 传递给插件方法的关键字参数。 + plugin_id: 目标插件 ID。 + component_name: 消息网关组件名称。 + args: 传递给网关方法的关键字参数。 timeout_ms: RPC 超时时间,单位毫秒。 Returns: Envelope: Runner 返回的响应信封。 """ + return await self.invoke_plugin( - method="plugin.invoke_adapter", + method="plugin.invoke_message_gateway", plugin_id=plugin_id, - component_name=method_name, + component_name=component_name, args=args, timeout_ms=timeout_ms, ) @@ -468,8 +466,8 @@ class PluginRunnerSupervisor: def _register_internal_methods(self) -> None: """注册 Host 侧内部 RPC 方法。""" self._rpc_server.register_method("cap.call", self._capability_service.handle_capability_request) - self._rpc_server.register_method("host.receive_external_message", self._handle_receive_external_message) - self._rpc_server.register_method("host.update_adapter_state", self._handle_update_adapter_state) + self._rpc_server.register_method("host.route_message", self._handle_route_message) + self._rpc_server.register_method("host.update_message_gateway_state", self._handle_update_message_gateway_state) self._rpc_server.register_method("plugin.bootstrap", self._handle_bootstrap_plugin) self._rpc_server.register_method("plugin.register_components", self._handle_register_plugin) self._rpc_server.register_method("plugin.register_plugin", self._handle_register_plugin) @@ -512,30 +510,26 @@ class PluginRunnerSupervisor: except Exception as exc: return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc)) + self._remove_core_action_mirrors(payload.plugin_id) self._component_registry.remove_components_by_plugin(payload.plugin_id) - if payload.plugin_id in self._registered_adapters: - await self._unregister_adapter_driver(payload.plugin_id) - - try: - if payload.adapter is not None: - await self._register_adapter_driver(payload.plugin_id, payload.adapter) - except RouteBindingConflictError as exc: - return envelope.make_error_response(ErrorCode.E_METHOD_NOT_ALLOWED.value, str(exc)) - except Exception as exc: - return envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(exc)) + await self._unregister_all_message_gateway_drivers_for_plugin(payload.plugin_id) registered_count = self._component_registry.register_plugin_components( payload.plugin_id, [component.model_dump() for component in payload.components], ) self._registered_plugins[payload.plugin_id] = payload + self._message_gateway_states[payload.plugin_id] = {} + self._mirror_runtime_actions_to_core_registry(payload) return envelope.make_response( payload={ "accepted": True, "plugin_id": payload.plugin_id, "registered_components": registered_count, - "adapter_registered": payload.adapter is not None, + "message_gateways": len( + self._component_registry.get_message_gateways(plugin_id=payload.plugin_id, enabled_only=False) + ), } ) @@ -556,7 +550,9 @@ class PluginRunnerSupervisor: removed_components = self._component_registry.remove_components_by_plugin(payload.plugin_id) self._authorization.revoke_permission_token(payload.plugin_id) removed_registration = self._registered_plugins.pop(payload.plugin_id, None) is not None - await self._unregister_adapter_driver(payload.plugin_id) + self._remove_core_action_mirrors(payload.plugin_id) + await self._unregister_all_message_gateway_drivers_for_plugin(payload.plugin_id) + self._message_gateway_states.pop(payload.plugin_id, None) return envelope.make_response( payload={ @@ -569,41 +565,321 @@ class PluginRunnerSupervisor: ) @staticmethod - def _build_adapter_driver_id(plugin_id: str) -> str: - """构造适配器驱动 ID。 + def _coerce_action_activation_type(raw_value: Any) -> ActionActivationType: + """将运行时 Action 激活类型转换为旧核心枚举。 Args: - plugin_id: 适配器插件 ID。 + raw_value: 插件运行时声明中的激活类型值。 + + Returns: + ActionActivationType: 可供旧 Planner 使用的激活类型枚举。 + """ + normalized_value = str(raw_value or ActionActivationType.ALWAYS.value).strip().lower() + try: + return ActionActivationType(normalized_value) + except ValueError: + return ActionActivationType.ALWAYS + + @staticmethod + def _coerce_float(value: Any, default: float = 0.0) -> float: + """将任意输入尽量转换为浮点数。 + + Args: + value: 待转换的值。 + default: 转换失败时使用的默认值。 + + Returns: + float: 转换结果。 + """ + try: + return float(value) + except (TypeError, ValueError): + return default + + @staticmethod + def _build_core_action_info(plugin_id: str, component_name: str, metadata: Dict[str, Any]) -> ActionInfo: + """将运行时 Action 元数据映射为旧核心 ActionInfo。 + + Args: + plugin_id: 插件 ID。 + component_name: 组件名称。 + metadata: 运行时组件元数据。 + + Returns: + ActionInfo: 兼容旧 Planner 的动作定义。 + """ + activation_keywords = [ + str(item) + for item in (metadata.get("activation_keywords") or []) + if item is not None and str(item).strip() + ] + action_require = [ + str(item) + for item in (metadata.get("action_require") or []) + if item is not None and str(item).strip() + ] + associated_types = [ + str(item) + for item in (metadata.get("associated_types") or []) + if item is not None and str(item).strip() + ] + raw_action_parameters = metadata.get("action_parameters") or {} + action_parameters = { + str(param_name): str(param_description) + for param_name, param_description in raw_action_parameters.items() + } if isinstance(raw_action_parameters, dict) else {} + + return ActionInfo( + name=component_name, + component_type=CoreComponentType.ACTION, + description=str(metadata.get("description", "") or ""), + enabled=bool(metadata.get("enabled", True)), + plugin_name=plugin_id, + metadata=dict(metadata), + action_parameters=action_parameters, + action_require=action_require, + associated_types=associated_types, + activation_type=PluginRunnerSupervisor._coerce_action_activation_type(metadata.get("activation_type")), + random_activation_probability=PluginRunnerSupervisor._coerce_float( + metadata.get("activation_probability"), + 0.0, + ), + activation_keywords=activation_keywords, + parallel_action=bool(metadata.get("parallel_action", False)), + ) + + @staticmethod + def _extract_stream_id_from_action_kwargs(kwargs: Dict[str, Any]) -> str: + """从旧 ActionManager 传入参数中提取聊天流 ID。 + + Args: + kwargs: 旧动作执行器收到的关键字参数。 + + Returns: + str: 可用于新运行时 Action 的 ``stream_id``。 + """ + chat_stream = kwargs.get("chat_stream") + if chat_stream is not None: + try: + return str(chat_stream.session_id) + except AttributeError: + pass + + raw_stream_id = kwargs.get("stream_id", "") + return str(raw_stream_id or "") + + def _build_runtime_action_executor( + self, + plugin_id: str, + component_name: str, + ) -> Any: + """构造一个转发到 plugin runtime 的旧核心 Action 执行器。 + + Args: + plugin_id: 目标插件 ID。 + component_name: 目标 Action 组件名称。 + + Returns: + Callable[..., Coroutine[Any, Any, tuple[bool, str]]]: 兼容旧 ActionManager 的执行器。 + """ + + async def _executor(**kwargs: Any) -> tuple[bool, str]: + """将旧 Planner 的动作调用桥接到 plugin runtime。 + + Args: + **kwargs: 旧 ActionManager 传入的运行时上下文参数。 + + Returns: + tuple[bool, str]: ``(是否成功, 动作说明)``。 + """ + invoke_args: Dict[str, Any] = {} + action_data = kwargs.get("action_data") + if isinstance(action_data, dict): + invoke_args.update(action_data) + + stream_id = self._extract_stream_id_from_action_kwargs(kwargs) + invoke_args["action_data"] = action_data if isinstance(action_data, dict) else {} + invoke_args["stream_id"] = stream_id + invoke_args["chat_id"] = stream_id + invoke_args["reasoning"] = str(kwargs.get("action_reasoning", "") or "") + + thinking_id = kwargs.get("thinking_id") + if thinking_id is not None: + invoke_args["thinking_id"] = str(thinking_id) + + cycle_timers = kwargs.get("cycle_timers") + if isinstance(cycle_timers, dict): + invoke_args["cycle_timers"] = cycle_timers + + plugin_config = kwargs.get("plugin_config") + if isinstance(plugin_config, dict): + invoke_args["plugin_config"] = plugin_config + + log_prefix = kwargs.get("log_prefix") + if isinstance(log_prefix, str): + invoke_args["log_prefix"] = log_prefix + + shutting_down = kwargs.get("shutting_down") + if isinstance(shutting_down, bool): + invoke_args["shutting_down"] = shutting_down + + try: + response = await self.invoke_plugin( + method="plugin.invoke_action", + plugin_id=plugin_id, + component_name=component_name, + args=invoke_args, + timeout_ms=30000, + ) + except Exception as exc: + logger.error(f"运行时 Action {plugin_id}.{component_name} 执行失败: {exc}", exc_info=True) + return False, str(exc) + + payload = response.payload if isinstance(response.payload, dict) else {} + success = bool(payload.get("success", False)) + result = payload.get("result") + + if isinstance(result, (list, tuple)): + if len(result) >= 2: + return bool(result[0]), "" if result[1] is None else str(result[1]) + if len(result) == 1: + return bool(result[0]), "" + + if success: + return True, "" if result is None else str(result) + return False, "" if result is None else str(result) + + return _executor + + def _mirror_runtime_actions_to_core_registry(self, payload: RegisterPluginPayload) -> None: + """将 plugin runtime 中声明的 Action 镜像到旧核心注册表。 + + Args: + payload: 当前插件的注册载荷。 + """ + mirrored_action_names: List[str] = [] + + for component in payload.components: + if str(component.component_type).upper() != CoreComponentType.ACTION.name: + continue + + action_info = self._build_core_action_info( + plugin_id=payload.plugin_id, + component_name=component.name, + metadata=component.metadata, + ) + action_executor = self._build_runtime_action_executor( + plugin_id=payload.plugin_id, + component_name=component.name, + ) + registered = core_component_registry.register_action(action_info, action_executor) + if not registered: + logger.warning( + f"运行时 Action {payload.plugin_id}.{component.name} 无法镜像到旧核心注册表," + "可能与现有 Action 重名" + ) + continue + mirrored_action_names.append(component.name) + + if mirrored_action_names: + self._mirrored_core_actions[payload.plugin_id] = mirrored_action_names + + def _remove_core_action_mirrors(self, plugin_id: str) -> None: + """移除某个插件镜像到旧核心注册表的所有 Action。 + + Args: + plugin_id: 目标插件 ID。 + """ + mirrored_action_names = self._mirrored_core_actions.pop(plugin_id, []) + for action_name in mirrored_action_names: + core_component_registry.remove_action(action_name) + + @staticmethod + def _build_message_gateway_driver_id(plugin_id: str, gateway_name: str) -> str: + """构造消息网关驱动 ID。 + + Args: + plugin_id: 插件 ID。 + gateway_name: 网关组件名称。 Returns: str: 对应 Platform IO 中的驱动 ID。 """ - return f"adapter:{plugin_id}" - async def _register_adapter_driver(self, plugin_id: str, adapter: AdapterDeclarationPayload) -> None: - """将适配器插件驱动注册到 Platform IO。 + return f"gateway:{plugin_id}:{gateway_name}" + + @staticmethod + def _normalize_runtime_route_value(value: str) -> Optional[str]: + """规范化运行时路由字段。 Args: - plugin_id: 适配器插件 ID。 - adapter: 经过校验的适配器声明。 + value: 待规范化的原始字符串。 - Raises: - ValueError: 当驱动注册失败时抛出。 + Returns: + Optional[str]: 规范化后非空则返回字符串,否则返回 ``None``。 """ - await self._unregister_adapter_driver(plugin_id) + + normalized_value = str(value or "").strip() + return normalized_value or None + + def _resolve_message_gateway_entry( + self, + plugin_id: str, + gateway_name: str, + ) -> Optional[Any]: + """解析指定插件的消息网关组件。 + + Args: + plugin_id: 插件 ID。 + gateway_name: 网关组件名称;为空时按兼容规则推断。 + + Returns: + Optional[Any]: 匹配到的消息网关组件条目。 + """ + + if gateway_name: + return self._component_registry.get_message_gateway( + plugin_id=plugin_id, + name=gateway_name, + enabled_only=False, + ) + + gateways = self._component_registry.get_message_gateways(plugin_id=plugin_id, enabled_only=False) + if len(gateways) == 1: + return gateways[0] + + return None + + async def _register_message_gateway_driver( + self, + plugin_id: str, + gateway_entry: Any, + route_key: RouteKey, + ) -> None: + """为消息网关注册驱动并绑定发送/接收路由。 + + Args: + plugin_id: 插件 ID。 + gateway_entry: 消息网关组件条目。 + route_key: 当前链路对应的路由键。 + """ + + await self._unregister_message_gateway_driver(plugin_id, gateway_entry.name) platform_io_manager = get_platform_io_manager() driver = PluginPlatformDriver( - driver_id=self._build_adapter_driver_id(plugin_id), - platform=adapter.platform, - account_id=adapter.account_id or None, - scope=adapter.scope or None, + driver_id=self._build_message_gateway_driver_id(plugin_id, gateway_entry.name), + platform=route_key.platform, + account_id=route_key.account_id, + scope=route_key.scope, plugin_id=plugin_id, - send_method=adapter.send_method, + component_name=gateway_entry.name, + supports_send=bool(gateway_entry.supports_send), supervisor=self, metadata={ - "protocol": adapter.protocol, - **adapter.metadata, + "protocol": gateway_entry.protocol, + "route_type": gateway_entry.route_type, + **gateway_entry.metadata, }, ) @@ -620,20 +896,36 @@ class PluginRunnerSupervisor: platform_io_manager.unregister_driver(driver.driver_id) raise - self._registered_adapters[plugin_id] = adapter - self._adapter_runtime_states[plugin_id] = _AdapterRuntimeState() + binding_metadata = { + "plugin_id": plugin_id, + "gateway_name": gateway_entry.name, + "protocol": gateway_entry.protocol, + "route_type": gateway_entry.route_type, + **gateway_entry.metadata, + } + binding = RouteBinding( + route_key=route_key, + driver_id=driver.driver_id, + driver_kind=DriverKind.PLUGIN, + metadata=binding_metadata, + ) + if gateway_entry.supports_send: + platform_io_manager.bind_send_route(binding) + if gateway_entry.supports_receive: + platform_io_manager.bind_receive_route(binding) - async def _unregister_adapter_driver(self, plugin_id: str) -> None: - """从 Platform IO 注销一个适配器驱动。 + async def _unregister_message_gateway_driver(self, plugin_id: str, gateway_name: str) -> None: + """从 Platform IO 注销单个消息网关驱动。 Args: - plugin_id: 适配器插件 ID。 + plugin_id: 插件 ID。 + gateway_name: 网关组件名称。 """ - platform_io_manager = get_platform_io_manager() - driver_id = self._build_adapter_driver_id(plugin_id) - adapter = self._registered_adapters.get(plugin_id) - self._remove_adapter_route_bindings(plugin_id) + platform_io_manager = get_platform_io_manager() + driver_id = self._build_message_gateway_driver_id(plugin_id, gateway_name) + platform_io_manager.send_route_table.remove_bindings_by_driver(driver_id) + platform_io_manager.receive_route_table.remove_bindings_by_driver(driver_id) with contextlib.suppress(Exception): if platform_io_manager.is_started: @@ -641,204 +933,83 @@ class PluginRunnerSupervisor: else: platform_io_manager.unregister_driver(driver_id) - if adapter is not None: - self._refresh_platform_default_route(adapter.platform) - - self._registered_adapters.pop(plugin_id, None) - self._adapter_runtime_states.pop(plugin_id, None) - - async def _unregister_all_adapter_drivers(self) -> None: - """注销当前 Supervisor 管理的全部适配器驱动。""" - plugin_ids = list(self._registered_adapters.keys()) - for plugin_id in plugin_ids: - await self._unregister_adapter_driver(plugin_id) - - def _remove_adapter_route_bindings(self, plugin_id: str) -> None: - """移除某个适配器驱动当前持有的全部路由绑定。 + async def _unregister_all_message_gateway_drivers_for_plugin(self, plugin_id: str) -> None: + """注销指定插件的全部消息网关驱动。 Args: - plugin_id: 适配器插件 ID。 + plugin_id: 插件 ID。 """ - platform_io_manager = get_platform_io_manager() - platform_io_manager.route_table.remove_bindings_by_driver(self._build_adapter_driver_id(plugin_id)) - @staticmethod - def _normalize_runtime_route_value(value: str) -> Optional[str]: - """规范化适配器运行时路由字段。 + gateway_names = list(self._message_gateway_states.get(plugin_id, {}).keys()) + for gateway_name in gateway_names: + await self._unregister_message_gateway_driver(plugin_id, gateway_name) - Args: - value: 待规范化的原始字符串。 - - Returns: - Optional[str]: 规范化后非空则返回字符串,否则返回 ``None``。 - """ - normalized_value = str(value).strip() - return normalized_value or None - - def _build_runtime_route_key( + def _build_message_gateway_route_key( self, - adapter: AdapterDeclarationPayload, - payload: AdapterStateUpdatePayload, + gateway_entry: Any, + payload: MessageGatewayStateUpdatePayload, ) -> RouteKey: - """根据运行时状态更新构造适配器生效路由键。 + """根据消息网关运行时状态构造路由键。 Args: - adapter: 当前适配器声明。 - payload: 适配器上报的运行时状态。 + gateway_entry: 消息网关组件条目。 + payload: 网关上报的运行时状态。 Returns: - RouteKey: 当前连接应接管的精确路由键。 + RouteKey: 当前链路对应的路由键。 Raises: - ValueError: 当静态声明与运行时上报的身份信息冲突时抛出。 + ValueError: 当平台信息缺失时抛出。 """ - runtime_account_id = self._normalize_runtime_route_value(payload.account_id) - runtime_scope = self._normalize_runtime_route_value(payload.scope) - if adapter.account_id and runtime_account_id and adapter.account_id != runtime_account_id: - raise ValueError( - f"适配器声明的 account_id={adapter.account_id} 与运行时上报的 {runtime_account_id} 不一致" - ) - if adapter.scope and runtime_scope and adapter.scope != runtime_scope: - raise ValueError(f"适配器声明的 scope={adapter.scope} 与运行时上报的 {runtime_scope} 不一致") + platform = str(payload.platform or gateway_entry.platform or "").strip() + if not platform: + raise ValueError(f"消息网关 {gateway_entry.full_name} 未提供有效的平台名称") return RouteKey( - platform=adapter.platform, - account_id=runtime_account_id or adapter.account_id or None, - scope=runtime_scope or adapter.scope or None, + platform=platform, + account_id=self._normalize_runtime_route_value(payload.account_id) or gateway_entry.account_id or None, + scope=self._normalize_runtime_route_value(payload.scope) or gateway_entry.scope or None, ) - def _bind_runtime_exact_route( + def _apply_message_gateway_state( self, plugin_id: str, - adapter: AdapterDeclarationPayload, - route_key: RouteKey, - ) -> None: - """为适配器连接绑定精确生效路由。 + gateway_entry: Any, + payload: MessageGatewayStateUpdatePayload, + ) -> Tuple[_MessageGatewayRuntimeState, Dict[str, Any]]: + """应用消息网关运行时状态,并同步 Platform IO 路由。 Args: - plugin_id: 适配器插件 ID。 - adapter: 当前适配器声明。 - route_key: 当前连接对应的精确路由键。 + plugin_id: 插件 ID。 + gateway_entry: 消息网关组件条目。 + payload: 网关上报的运行时状态。 - Raises: - RouteBindingConflictError: 当目标路由已被其他 active owner 占用时抛出。 + Returns: + Tuple[_MessageGatewayRuntimeState, Dict[str, Any]]: 更新后的状态与路由键字典。 """ - platform_io_manager = get_platform_io_manager() - platform_io_manager.bind_route( - RouteBinding( - route_key=route_key, - driver_id=self._build_adapter_driver_id(plugin_id), - driver_kind=DriverKind.PLUGIN, - metadata={ - "plugin_id": plugin_id, - "protocol": adapter.protocol, - "binding_role": _ADAPTER_BINDING_ROLE_RUNTIME_EXACT, - }, + + plugin_states = self._message_gateway_states.setdefault(plugin_id, {}) + if not payload.ready: + runtime_state = _MessageGatewayRuntimeState( + ready=False, + platform=self._normalize_runtime_route_value(payload.platform) or gateway_entry.platform or None, + account_id=self._normalize_runtime_route_value(payload.account_id) or gateway_entry.account_id or None, + scope=self._normalize_runtime_route_value(payload.scope) or gateway_entry.scope or None, + metadata=dict(payload.metadata), ) - ) - - def _list_runtime_exact_bindings(self, platform: str) -> List[RouteBinding]: - """列出某个平台上由 Host 动态维护的精确适配器绑定。 - - Args: - platform: 目标平台名称。 - - Returns: - List[RouteBinding]: 当前平台上全部动态精确绑定。 - """ - platform_io_manager = get_platform_io_manager() - return [ - binding - for binding in platform_io_manager.route_table.list_bindings() - if binding.mode == RouteMode.ACTIVE - and binding.route_key.platform == platform - and binding.metadata.get("binding_role") == _ADAPTER_BINDING_ROLE_RUNTIME_EXACT - ] - - def _refresh_platform_default_route(self, platform: str) -> None: - """根据当前精确绑定数量刷新平台级默认路由。 - - 当某个平台恰好只存在一个动态精确绑定时,会为该绑定额外创建一条 - ``RouteKey(platform=)`` 形式的默认路由,方便缺少账号维度的 - 出站消息继续找到唯一 owner。若精确绑定数量变为 0 或大于 1,则撤销 - 由 Host 自动维护的默认路由,避免出现隐式歧义。 - - Args: - platform: 目标平台名称。 - """ - platform_io_manager = get_platform_io_manager() - default_route_key = RouteKey(platform=platform) - existing_default_binding = platform_io_manager.route_table.get_active_binding(default_route_key, exact_only=True) - - if existing_default_binding is not None: - binding_role = existing_default_binding.metadata.get("binding_role") - if binding_role != _ADAPTER_BINDING_ROLE_PLATFORM_DEFAULT: - return - platform_io_manager.unbind_route(default_route_key, existing_default_binding.driver_id) - - exact_bindings = self._list_runtime_exact_bindings(platform) - if len(exact_bindings) != 1: - return - - exact_binding = exact_bindings[0] - if exact_binding.route_key == default_route_key: - return - - platform_io_manager.bind_route( - RouteBinding( - route_key=default_route_key, - driver_id=exact_binding.driver_id, - driver_kind=exact_binding.driver_kind, - metadata={ - "plugin_id": exact_binding.metadata.get("plugin_id", ""), - "protocol": exact_binding.metadata.get("protocol", ""), - "binding_role": _ADAPTER_BINDING_ROLE_PLATFORM_DEFAULT, - }, - ), - replace=True, - ) - - def _apply_adapter_runtime_state( - self, - plugin_id: str, - adapter: AdapterDeclarationPayload, - payload: AdapterStateUpdatePayload, - ) -> Tuple[_AdapterRuntimeState, Dict[str, Any]]: - """应用适配器运行时状态,并同步 Platform IO 路由。 - - Args: - plugin_id: 适配器插件 ID。 - adapter: 当前适配器声明。 - payload: 适配器上报的运行时状态。 - - Returns: - Tuple[_AdapterRuntimeState, Dict[str, Any]]: 更新后的运行时状态,以及 - 供 RPC 响应返回的路由键字典。 - - Raises: - RouteBindingConflictError: 当新的精确路由与其他 active owner 冲突时抛出。 - ValueError: 当运行时路由信息不合法时抛出。 - """ - if not payload.connected: - self._remove_adapter_route_bindings(plugin_id) - self._refresh_platform_default_route(adapter.platform) - runtime_state = _AdapterRuntimeState(connected=False, metadata=dict(payload.metadata)) - self._adapter_runtime_states[plugin_id] = runtime_state + plugin_states[gateway_entry.name] = runtime_state return runtime_state, {} - route_key = self._build_runtime_route_key(adapter, payload) - self._remove_adapter_route_bindings(plugin_id) - self._bind_runtime_exact_route(plugin_id, adapter, route_key) - self._refresh_platform_default_route(adapter.platform) - - runtime_state = _AdapterRuntimeState( - connected=True, + route_key = self._build_message_gateway_route_key(gateway_entry, payload) + runtime_state = _MessageGatewayRuntimeState( + ready=True, + platform=route_key.platform, account_id=route_key.account_id, scope=route_key.scope, metadata=dict(payload.metadata), ) - self._adapter_runtime_states[plugin_id] = runtime_state + plugin_states[gateway_entry.name] = runtime_state return runtime_state, { "platform": route_key.platform, "account_id": route_key.account_id, @@ -856,8 +1027,9 @@ class PluginRunnerSupervisor: Args: session_message: 已构造好的内部消息对象。 route_key: Host 为该消息解析出的标准路由键。 - route_metadata: 适配器通过 RPC 补充的原始路由辅助元数据。 + route_metadata: 插件通过 RPC 补充的原始路由辅助元数据。 """ + additional_config = session_message.message_info.additional_config if not isinstance(additional_config, dict): additional_config = {} @@ -877,45 +1049,49 @@ class PluginRunnerSupervisor: def _build_inbound_route_key( self, - adapter: AdapterDeclarationPayload, + gateway_entry: Any, + runtime_state: _MessageGatewayRuntimeState, message: Dict[str, Any], route_metadata: Dict[str, Any], ) -> RouteKey: - """为适配器入站消息构造归一路由键。 + """为入站消息构造归一路由键。 Args: - adapter: 当前适配器声明。 + gateway_entry: 接收消息的网关组件条目。 + runtime_state: 当前网关的运行时状态。 message: 标准消息字典。 route_metadata: 插件补充的路由辅助元数据。 Returns: RouteKey: 供 Platform IO 使用的规范化路由键。 - - Raises: - ValueError: 消息平台字段与适配器平台声明不一致时抛出。 """ - message_platform = str(message.get("platform") or adapter.platform).strip() - if message_platform != adapter.platform: - raise ValueError( - f"外部消息平台 {message_platform} 与适配器 {adapter.platform} 不一致" - ) + + platform = str( + message.get("platform") + or route_metadata.get("platform") + or runtime_state.platform + or gateway_entry.platform + or "" + ).strip() + if not platform: + raise ValueError(f"消息网关 {gateway_entry.full_name} 的入站消息缺少平台信息") try: route_key = RouteKeyFactory.from_message_dict(message) except Exception: - route_key = RouteKey(platform=message_platform) + route_key = RouteKey(platform=platform) route_account_id, route_scope = RouteKeyFactory.extract_components(route_metadata) - account_id = route_key.account_id or route_account_id or adapter.account_id or None - scope = route_key.scope or route_scope or adapter.scope or None + account_id = route_key.account_id or route_account_id or runtime_state.account_id or gateway_entry.account_id or None + scope = route_key.scope or route_scope or runtime_state.scope or gateway_entry.scope or None return RouteKey( - platform=message_platform, + platform=platform, account_id=account_id, scope=scope, ) - async def _handle_update_adapter_state(self, envelope: Envelope) -> Envelope: - """处理适配器插件上报的运行时状态更新。 + async def _handle_update_message_gateway_state(self, envelope: Envelope) -> Envelope: + """处理消息网关上报的运行时状态更新。 Args: envelope: RPC 请求信封。 @@ -923,38 +1099,42 @@ class PluginRunnerSupervisor: Returns: Envelope: 状态更新处理结果。 """ + try: - payload = AdapterStateUpdatePayload.model_validate(envelope.payload) + payload = MessageGatewayStateUpdatePayload.model_validate(envelope.payload) except Exception as exc: return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc)) - adapter = self._registered_adapters.get(envelope.plugin_id) - if adapter is None: + gateway_entry = self._resolve_message_gateway_entry(envelope.plugin_id, payload.gateway_name) + if gateway_entry is None: return envelope.make_error_response( ErrorCode.E_METHOD_NOT_ALLOWED.value, - f"插件 {envelope.plugin_id} 未声明为适配器,不能更新运行时状态", + f"插件 {envelope.plugin_id} 未声明消息网关 {payload.gateway_name or ''}", ) try: - runtime_state, route_key_dict = self._apply_adapter_runtime_state( + if payload.ready: + route_key = self._build_message_gateway_route_key(gateway_entry, payload) + await self._register_message_gateway_driver(envelope.plugin_id, gateway_entry, route_key) + else: + await self._unregister_message_gateway_driver(envelope.plugin_id, gateway_entry.name) + runtime_state, route_key_dict = self._apply_message_gateway_state( plugin_id=envelope.plugin_id, - adapter=adapter, + gateway_entry=gateway_entry, payload=payload, ) - except RouteBindingConflictError as exc: - return envelope.make_error_response(ErrorCode.E_METHOD_NOT_ALLOWED.value, str(exc)) except Exception as exc: return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc)) - response = AdapterStateUpdateResultPayload( + response = MessageGatewayStateUpdateResultPayload( accepted=True, - connected=runtime_state.connected, + ready=runtime_state.ready, route_key=route_key_dict, ) return envelope.make_response(payload=response.model_dump()) - async def _handle_receive_external_message(self, envelope: Envelope) -> Envelope: - """处理适配器插件上报的外部入站消息。 + async def _handle_route_message(self, envelope: Envelope) -> Envelope: + """处理消息网关上报的外部入站消息。 Args: envelope: RPC 请求信封。 @@ -962,21 +1142,33 @@ class PluginRunnerSupervisor: Returns: Envelope: 注入结果响应。 """ + try: - payload = ReceiveExternalMessagePayload.model_validate(envelope.payload) + payload = RouteMessagePayload.model_validate(envelope.payload) except Exception as exc: return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc)) - adapter = self._registered_adapters.get(envelope.plugin_id) - if adapter is None: + gateway_entry = self._resolve_message_gateway_entry(envelope.plugin_id, payload.gateway_name) + if gateway_entry is None or not bool(gateway_entry.supports_receive): return envelope.make_error_response( ErrorCode.E_METHOD_NOT_ALLOWED.value, - f"插件 {envelope.plugin_id} 未声明为适配器,不能注入外部消息", + f"插件 {envelope.plugin_id} 未声明可接收的消息网关 {payload.gateway_name}", + ) + + runtime_state = self._message_gateway_states.get(envelope.plugin_id, {}).get( + gateway_entry.name, + _MessageGatewayRuntimeState(), + ) + if not runtime_state.ready: + return envelope.make_error_response( + ErrorCode.E_METHOD_NOT_ALLOWED.value, + f"消息网关 {gateway_entry.full_name} 尚未就绪,不能注入外部消息", ) try: route_key = self._build_inbound_route_key( - adapter=adapter, + gateway_entry=gateway_entry, + runtime_state=runtime_state, message=payload.message, route_metadata=payload.route_metadata, ) @@ -989,7 +1181,7 @@ class PluginRunnerSupervisor: accepted = await platform_io_manager.accept_inbound( InboundMessageEnvelope( route_key=route_key, - driver_id=self._build_adapter_driver_id(envelope.plugin_id), + driver_id=self._build_message_gateway_driver_id(envelope.plugin_id, gateway_entry.name), driver_kind=DriverKind.PLUGIN, external_message_id=payload.external_message_id or str(payload.message.get("message_id") or "") or None, dedupe_key=payload.dedupe_key or None, @@ -997,7 +1189,8 @@ class PluginRunnerSupervisor: payload=payload.message, metadata={ "plugin_id": envelope.plugin_id, - "protocol": adapter.protocol, + "gateway_name": gateway_entry.name, + "protocol": gateway_entry.protocol, **payload.route_metadata, }, ) @@ -1138,7 +1331,8 @@ class PluginRunnerSupervisor: await self._stderr_drain_task self._stderr_drain_task = None - await self._unregister_all_adapter_drivers() + for plugin_id in list(self._message_gateway_states.keys()): + await self._unregister_all_message_gateway_drivers_for_plugin(plugin_id) self._clear_runner_state() async def _health_check_loop(self) -> None: @@ -1213,11 +1407,12 @@ class PluginRunnerSupervisor: def _clear_runner_state(self) -> None: """清理当前 Runner 对应的 Host 侧注册状态。""" + for plugin_id in list(self._mirrored_core_actions.keys()): + self._remove_core_action_mirrors(plugin_id) self._authorization.clear() self._component_registry.clear() self._registered_plugins.clear() - self._registered_adapters.clear() - self._adapter_runtime_states.clear() + self._message_gateway_states.clear() self._runner_ready_events = asyncio.Event() self._runner_ready_payloads = RunnerReadyPayload() self._rpc_server.clear_handshake_state() diff --git a/src/plugin_runtime/integration.py b/src/plugin_runtime/integration.py index bf85669b..b74b2d46 100644 --- a/src/plugin_runtime/integration.py +++ b/src/plugin_runtime/integration.py @@ -18,7 +18,7 @@ import tomlkit from src.common.logger import get_logger from src.config.config import global_config from src.config.file_watcher import FileChange, FileWatcher -from src.platform_io import DeliveryReceipt, InboundMessageEnvelope, get_platform_io_manager +from src.platform_io import DeliveryBatch, InboundMessageEnvelope, get_platform_io_manager from src.plugin_runtime.capabilities import ( RuntimeComponentCapabilityMixin, RuntimeCoreCapabilityMixin, @@ -351,15 +351,15 @@ class PluginRuntimeManager( async def try_send_message_via_platform_io( self, message: "SessionMessage", - ) -> Optional[DeliveryReceipt]: + ) -> Optional[DeliveryBatch]: """尝试通过 Platform IO 中间层发送消息。 Args: message: 待发送的内部会话消息。 Returns: - Optional[DeliveryReceipt]: 若当前消息存在 active 路由,则返回实际发送 - 结果;若没有可用路由或 Platform IO 尚未启动,则返回 ``None``。 + Optional[DeliveryBatch]: 若当前消息命中了至少一条发送路由,则返回 + 实际发送结果;若没有可用路由或 Platform IO 尚未启动,则返回 ``None``。 """ if not self._started: return None @@ -374,7 +374,7 @@ class PluginRuntimeManager( logger.warning(f"根据消息构造 Platform IO 路由键失败: {exc}") return None - if platform_io_manager.resolve_driver(route_key) is None: + if not platform_io_manager.resolve_drivers(route_key): return None return await platform_io_manager.send_message(message, route_key) diff --git a/src/plugin_runtime/protocol/envelope.py b/src/plugin_runtime/protocol/envelope.py index f68657fa..cbbb71be 100644 --- a/src/plugin_runtime/protocol/envelope.py +++ b/src/plugin_runtime/protocol/envelope.py @@ -156,8 +156,6 @@ class RegisterPluginPayload(BaseModel): """插件版本""" components: List[ComponentDeclaration] = Field(default_factory=list, description="组件列表") """组件列表""" - adapter: Optional["AdapterDeclarationPayload"] = Field(default=None, description="可选的适配器声明") - """可选的适配器声明""" capabilities_required: List[str] = Field(default_factory=list, description="所需能力列表") """所需能力列表""" @@ -287,50 +285,39 @@ class ReloadPluginResultPayload(BaseModel): """重载失败的插件及原因""" -class AdapterDeclarationPayload(BaseModel): - """适配器插件声明载荷。""" +class MessageGatewayStateUpdatePayload(BaseModel): + """消息网关运行时状态更新载荷。""" - platform: str = Field(description="适配器负责的平台名称,例如 qq") - """适配器负责的平台名称,例如 qq""" - protocol: str = Field(default="", description="接入协议或实现名称,例如 napcat") - """接入协议或实现名称,例如 napcat""" - account_id: str = Field(default="", description="可选的账号 ID 或 self_id") - """可选的账号 ID 或 self_id""" - scope: str = Field(default="", description="可选的路由作用域") - """可选的路由作用域""" - send_method: str = Field(default="send_to_platform", description="Host 出站调用的插件方法名") - """Host 出站调用的插件方法名""" - metadata: Dict[str, Any] = Field(default_factory=dict, description="适配器附加元数据") - """适配器附加元数据""" - - -class AdapterStateUpdatePayload(BaseModel): - """适配器运行时状态更新载荷。""" - - connected: bool = Field(description="适配器当前是否已连接并准备接管路由") - """适配器当前是否已连接并准备接管路由""" - account_id: str = Field(default="", description="当前连接对应的账号 ID 或 self_id") - """当前连接对应的账号 ID 或 self_id""" - scope: str = Field(default="", description="当前连接对应的可选路由作用域") - """当前连接对应的可选路由作用域""" + gateway_name: str = Field(description="消息网关组件名称") + """消息网关组件名称""" + ready: bool = Field(description="当前链路是否已经就绪") + """当前链路是否已经就绪""" + platform: str = Field(default="", description="当前链路负责的平台名称") + """当前链路负责的平台名称""" + account_id: str = Field(default="", description="当前链路对应的账号 ID 或 self_id") + """当前链路对应的账号 ID 或 self_id""" + scope: str = Field(default="", description="当前链路对应的可选路由作用域") + """当前链路对应的可选路由作用域""" metadata: Dict[str, Any] = Field(default_factory=dict, description="可选的运行时状态元数据") """可选的运行时状态元数据""" -class AdapterStateUpdateResultPayload(BaseModel): - """适配器运行时状态更新结果载荷。""" +class MessageGatewayStateUpdateResultPayload(BaseModel): + """消息网关运行时状态更新结果载荷。""" accepted: bool = Field(description="Host 是否接受了本次状态更新") """Host 是否接受了本次状态更新""" - connected: bool = Field(description="Host 记录的当前连接状态") - """Host 记录的当前连接状态""" + ready: bool = Field(description="Host 记录的当前就绪状态") + """Host 记录的当前就绪状态""" route_key: Dict[str, Any] = Field(default_factory=dict, description="当前生效的路由键") """当前生效的路由键""" -class ReceiveExternalMessagePayload(BaseModel): - """适配器插件向 Host 注入外部消息的请求载荷。""" +class RouteMessagePayload(BaseModel): + """消息网关向 Host 路由外部消息的请求载荷。""" + gateway_name: str = Field(description="接收消息的网关组件名称") + """接收消息的网关组件名称""" message: Dict[str, Any] = Field(description="符合 MessageDict 结构的标准消息字典") """符合 MessageDict 结构的标准消息字典""" route_metadata: Dict[str, Any] = Field(default_factory=dict, description="可选的路由辅助元数据") diff --git a/src/plugin_runtime/runner/runner_main.py b/src/plugin_runtime/runner/runner_main.py index 8078c88b..3a50e2f7 100644 --- a/src/plugin_runtime/runner/runner_main.py +++ b/src/plugin_runtime/runner/runner_main.py @@ -25,7 +25,6 @@ import tomllib from src.common.logger import get_console_handler, get_logger, initialize_logging from src.plugin_runtime import ENV_HOST_VERSION, ENV_IPC_ADDRESS, ENV_PLUGIN_DIRS, ENV_SESSION_TOKEN from src.plugin_runtime.protocol.envelope import ( - AdapterDeclarationPayload, BootstrapPluginPayload, ComponentDeclaration, Envelope, @@ -219,7 +218,7 @@ class PluginRunner: """为插件实例创建并注入 PluginContext。 对新版 MaiBotPlugin(具有 _set_context 方法):创建 PluginContext 并注入。 - 对旧版 LegacyPluginAdapter(具有 _set_context 方法,由适配器代理):同上。 + 对旧版 LegacyPluginAdapter(具有 _set_context 方法,由兼容代理封装):同上。 """ if not hasattr(instance, "_set_context"): return @@ -293,7 +292,7 @@ class PluginRunner: self._rpc_client.register_method("plugin.invoke_command", self._handle_invoke) self._rpc_client.register_method("plugin.invoke_action", self._handle_invoke) self._rpc_client.register_method("plugin.invoke_tool", self._handle_invoke) - self._rpc_client.register_method("plugin.invoke_adapter", self._handle_invoke) + self._rpc_client.register_method("plugin.invoke_message_gateway", self._handle_invoke) self._rpc_client.register_method("plugin.emit_event", self._handle_event_invoke) self._rpc_client.register_method("plugin.invoke_hook", self._handle_hook_invoke) self._rpc_client.register_method("plugin.invoke_workflow_step", self._handle_workflow_step) @@ -331,29 +330,6 @@ class PluginRunner: """撤销 bootstrap 期间为插件签发的能力令牌。""" await self._bootstrap_plugin(meta, capabilities_required=[]) - def _collect_adapter_declaration(self, meta: PluginMeta) -> Optional[AdapterDeclarationPayload]: - """从插件实例中提取适配器声明。 - - Args: - meta: 待提取声明的插件元数据。 - - Returns: - Optional[AdapterDeclarationPayload]: 若插件声明了适配器角色,则返回 - 经过校验的适配器声明;否则返回 ``None``。 - - Raises: - ValueError: 插件导出的适配器声明结构非法时抛出。 - """ - instance = meta.instance - if not hasattr(instance, "get_adapter_info"): - return None - - adapter_info = instance.get_adapter_info() - if adapter_info is None: - return None - - return AdapterDeclarationPayload.model_validate(adapter_info) - async def _register_plugin(self, meta: PluginMeta) -> bool: """向 Host 注册单个插件。 @@ -379,17 +355,10 @@ class PluginRunner: for comp_info in instance.get_components() ) - try: - adapter = self._collect_adapter_declaration(meta) - except Exception as exc: - logger.error(f"插件 {meta.plugin_id} 适配器声明非法: {exc}", exc_info=True) - return False - reg_payload = RegisterPluginPayload( plugin_id=meta.plugin_id, plugin_version=meta.version, components=components, - adapter=adapter, capabilities_required=meta.capabilities_required, ) From d07915eea04b6b60b6f15b462ac93528e3338869 Mon Sep 17 00:00:00 2001 From: DrSmoothl <1787882683@qq.com> Date: Mon, 23 Mar 2026 11:38:46 +0800 Subject: [PATCH 31/45] Refactor message sending architecture and implement legacy driver support - Removed UniversalMessageSender from group_generator.py and private_generator.py. - Updated PlatformIOManager to manage legacy send drivers and ensure send pipeline readiness. - Enhanced LegacyPlatformDriver to utilize prepared messages for sending. - Refactored send_service to unify message sending logic and integrate with Platform IO. - Added regression tests for Platform IO legacy driver and send service functionality. --- pytests/test_platform_io_legacy_driver.py | 124 ++++ pytests/test_send_service.py | 141 ++++ src/chat/brain_chat/PFC/message_sender.py | 84 +-- .../message_receive/uni_message_sender.py | 63 +- src/chat/replyer/group_generator.py | 8 +- src/chat/replyer/private_generator.py | 8 +- src/common/message_server/server.py | 2 +- src/platform_io/drivers/legacy_driver.py | 51 +- src/platform_io/manager.py | 77 ++- src/plugin_runtime/integration.py | 2 +- src/services/send_service.py | 636 ++++++++++++++---- 11 files changed, 967 insertions(+), 229 deletions(-) create mode 100644 pytests/test_platform_io_legacy_driver.py create mode 100644 pytests/test_send_service.py diff --git a/pytests/test_platform_io_legacy_driver.py b/pytests/test_platform_io_legacy_driver.py new file mode 100644 index 00000000..2e94c1fc --- /dev/null +++ b/pytests/test_platform_io_legacy_driver.py @@ -0,0 +1,124 @@ +"""Platform IO legacy driver 回归测试。""" + +from typing import Any, Dict, Optional + +import pytest + +from src.chat.utils import utils as chat_utils +from src.chat.message_receive import uni_message_sender +from src.platform_io.drivers.base import PlatformIODriver +from src.platform_io.drivers.legacy_driver import LegacyPlatformDriver +from src.platform_io.manager import PlatformIOManager +from src.platform_io.types import DeliveryReceipt, DeliveryStatus, DriverDescriptor, DriverKind, RouteBinding, RouteKey + + +class _PluginDriver(PlatformIODriver): + """测试用插件发送驱动。""" + + def __init__(self, driver_id: str, platform: str) -> None: + """初始化测试驱动。 + + Args: + driver_id: 驱动 ID。 + platform: 负责的平台名称。 + """ + super().__init__( + DriverDescriptor( + driver_id=driver_id, + kind=DriverKind.PLUGIN, + platform=platform, + plugin_id="test.plugin", + ) + ) + + async def send_message( + self, + message: Any, + route_key: RouteKey, + metadata: Optional[Dict[str, Any]] = None, + ) -> DeliveryReceipt: + """返回一个固定成功回执。 + + Args: + message: 待发送消息。 + route_key: 当前路由键。 + metadata: 发送元数据。 + + Returns: + DeliveryReceipt: 固定成功回执。 + """ + del metadata + return DeliveryReceipt( + internal_message_id=str(message.message_id), + route_key=route_key, + status=DeliveryStatus.SENT, + driver_id=self.driver_id, + driver_kind=self.descriptor.kind, + ) + + +@pytest.mark.asyncio +async def test_platform_io_uses_legacy_driver_when_no_explicit_send_route( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """没有显式发送路由时,应由 Platform IO 回退到 legacy driver。""" + manager = PlatformIOManager() + monkeypatch.setattr(chat_utils, "get_all_bot_accounts", lambda: {"qq": "bot-qq"}) + + try: + await manager.ensure_send_pipeline_ready() + + fallback_drivers = manager.resolve_drivers(RouteKey(platform="qq")) + assert [driver.driver_id for driver in fallback_drivers] == ["legacy.send.qq"] + + plugin_driver = _PluginDriver(driver_id="plugin.qq.sender", platform="qq") + await manager.add_driver(plugin_driver) + manager.bind_send_route( + RouteBinding( + route_key=RouteKey(platform="qq"), + driver_id=plugin_driver.driver_id, + driver_kind=plugin_driver.descriptor.kind, + ) + ) + + explicit_drivers = manager.resolve_drivers(RouteKey(platform="qq")) + assert [driver.driver_id for driver in explicit_drivers] == ["plugin.qq.sender"] + finally: + await manager.stop() + + +@pytest.mark.asyncio +async def test_legacy_platform_driver_uses_prepared_universal_sender( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """legacy driver 应复用已预处理消息的旧链发送函数。""" + calls: list[dict[str, Any]] = [] + + async def _fake_send_prepared_message_to_platform(message: Any, show_log: bool = True) -> bool: + """记录 legacy driver 调用。""" + calls.append({"message": message, "show_log": show_log}) + return True + + monkeypatch.setattr( + uni_message_sender, + "send_prepared_message_to_platform", + _fake_send_prepared_message_to_platform, + ) + + driver = LegacyPlatformDriver( + driver_id="legacy.send.qq", + platform="qq", + account_id="bot-qq", + ) + message = type("FakeMessage", (), {"message_id": "message-1"})() + receipt = await driver.send_message( + message=message, + route_key=RouteKey(platform="qq"), + metadata={"show_log": False}, + ) + + assert len(calls) == 1 + assert calls[0]["message"] is message + assert calls[0]["show_log"] is False + assert receipt.status == DeliveryStatus.SENT + assert receipt.driver_id == "legacy.send.qq" diff --git a/pytests/test_send_service.py b/pytests/test_send_service.py new file mode 100644 index 00000000..4ddd4fa1 --- /dev/null +++ b/pytests/test_send_service.py @@ -0,0 +1,141 @@ +"""发送服务回归测试。""" + +from types import SimpleNamespace +from typing import Any, Dict, List + +import pytest + +from src.chat.message_receive.chat_manager import BotChatSession +from src.services import send_service + + +class _FakePlatformIOManager: + """用于测试的 Platform IO 管理器假对象。""" + + def __init__(self, delivery_batch: Any) -> None: + """初始化假 Platform IO 管理器。 + + Args: + delivery_batch: 发送时返回的批量回执。 + """ + self._delivery_batch = delivery_batch + self.ensure_calls = 0 + self.sent_messages: List[Dict[str, Any]] = [] + + async def ensure_send_pipeline_ready(self) -> None: + """记录发送管线准备调用次数。""" + self.ensure_calls += 1 + + def build_route_key_from_message(self, message: Any) -> Any: + """根据消息构造假的路由键。 + + Args: + message: 待发送的内部消息对象。 + + Returns: + Any: 简化后的路由键对象。 + """ + del message + return SimpleNamespace(platform="qq") + + async def send_message(self, message: Any, route_key: Any, metadata: Dict[str, Any]) -> Any: + """记录发送请求并返回预设回执。 + + Args: + message: 待发送的内部消息对象。 + route_key: 本次发送使用的路由键。 + metadata: 发送元数据。 + + Returns: + Any: 预设的批量发送回执。 + """ + self.sent_messages.append( + { + "message": message, + "route_key": route_key, + "metadata": metadata, + } + ) + return self._delivery_batch + + +def _build_target_stream() -> BotChatSession: + """构造一个最小可用的目标会话对象。 + + Returns: + BotChatSession: 测试用会话对象。 + """ + return BotChatSession( + session_id="test-session", + platform="qq", + user_id="target-user", + group_id=None, + ) + + +@pytest.mark.asyncio +async def test_text_to_stream_delegates_to_platform_io(monkeypatch: pytest.MonkeyPatch) -> None: + """send service 应将发送职责统一交给 Platform IO。""" + fake_manager = _FakePlatformIOManager( + delivery_batch=SimpleNamespace( + has_success=True, + sent_receipts=[SimpleNamespace(driver_id="plugin.qq.sender")], + failed_receipts=[], + route_key=SimpleNamespace(platform="qq"), + ) + ) + stored_messages: List[Any] = [] + + monkeypatch.setattr(send_service, "get_platform_io_manager", lambda: fake_manager) + monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq") + monkeypatch.setattr( + send_service._chat_manager, + "get_session_by_session_id", + lambda stream_id: _build_target_stream() if stream_id == "test-session" else None, + ) + monkeypatch.setattr( + send_service.MessageUtils, + "store_message_to_db", + lambda message: stored_messages.append(message), + ) + + result = await send_service.text_to_stream(text="你好", stream_id="test-session") + + assert result is True + assert fake_manager.ensure_calls == 1 + assert len(fake_manager.sent_messages) == 1 + assert fake_manager.sent_messages[0]["metadata"] == {"show_log": False} + assert len(stored_messages) == 1 + + +@pytest.mark.asyncio +async def test_text_to_stream_returns_false_when_platform_io_fails(monkeypatch: pytest.MonkeyPatch) -> None: + """Platform IO 批量发送全部失败时,应直接向上返回失败。""" + fake_manager = _FakePlatformIOManager( + delivery_batch=SimpleNamespace( + has_success=False, + sent_receipts=[], + failed_receipts=[ + SimpleNamespace( + driver_id="plugin.qq.sender", + status="failed", + error="network error", + ) + ], + route_key=SimpleNamespace(platform="qq"), + ) + ) + + monkeypatch.setattr(send_service, "get_platform_io_manager", lambda: fake_manager) + monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq") + monkeypatch.setattr( + send_service._chat_manager, + "get_session_by_session_id", + lambda stream_id: _build_target_stream() if stream_id == "test-session" else None, + ) + + result = await send_service.text_to_stream(text="发送失败", stream_id="test-session") + + assert result is False + assert fake_manager.ensure_calls == 1 + assert len(fake_manager.sent_messages) == 1 diff --git a/src/chat/brain_chat/PFC/message_sender.py b/src/chat/brain_chat/PFC/message_sender.py index ec5fb5ba..b9da905c 100644 --- a/src/chat/brain_chat/PFC/message_sender.py +++ b/src/chat/brain_chat/PFC/message_sender.py @@ -1,27 +1,28 @@ -import time +"""PFC 侧消息发送封装。""" + from typing import Optional -from maim_message import Seg from rich.traceback import install -from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo from src.chat.message_receive.chat_manager import BotChatSession -from src.chat.message_receive.message import MessageSending -from src.chat.message_receive.uni_message_sender import UniversalMessageSender -from src.chat.utils.utils import get_bot_account +from src.common.data_models.mai_message_data_model import MaiMessage from src.common.logger import get_logger -from src.config.config import global_config +from src.services import send_service as send_api install(extra_lines=3) - logger = get_logger("message_sender") class DirectMessageSender: - """直接消息发送器""" + """直接消息发送器。""" - def __init__(self, private_name: str): + def __init__(self, private_name: str) -> None: + """初始化直接消息发送器。 + + Args: + private_name: 当前私聊实例的名称。 + """ self.private_name = private_name async def send_message( @@ -30,58 +31,31 @@ class DirectMessageSender: content: str, reply_to_message: Optional[MaiMessage] = None, ) -> None: - """发送消息到聊天流 + """发送文本消息到聊天流。 Args: - chat_stream: 聊天会话 - content: 消息内容 - reply_to_message: 要回复的消息(可选) + chat_stream: 目标聊天会话。 + content: 待发送的文本内容。 + reply_to_message: 可选的引用回复锚点消息。 + + Raises: + RuntimeError: 当消息发送失败时抛出。 """ try: - # 创建消息内容 - segments = Seg(type="seglist", data=[Seg(type="text", data=content)]) - - # 获取麦麦的信息 - bot_user_id = get_bot_account(chat_stream.platform) - if not bot_user_id: - logger.error(f"[私聊][{self.private_name}]平台 {chat_stream.platform} 未配置机器人账号,无法发送消息") - raise RuntimeError(f"平台 {chat_stream.platform} 未配置机器人账号") - bot_user_info = UserInfo( - user_id=bot_user_id, - user_nickname=global_config.bot.nickname, + sent = await send_api.text_to_stream( + text=content, + stream_id=chat_stream.session_id, + set_reply=reply_to_message is not None, + reply_message=reply_to_message, + storage_message=True, ) - # 用当前时间作为message_id,和之前那套sender一样 - message_id = f"dm{round(time.time(), 2)}" - - # 构建发送者信息(私聊时为接收者) - sender_info = None - if reply_to_message and reply_to_message.message_info and reply_to_message.message_info.user_info: - sender_info = reply_to_message.message_info.user_info - - # 构建消息对象 - message = MessageSending( - message_id=message_id, - session=chat_stream, - bot_user_info=bot_user_info, - sender_info=sender_info, - message_segment=segments, - reply=reply_to_message, - is_head=True, - is_emoji=False, - thinking_start_time=time.time(), - ) - - # 发送消息 - message_sender = UniversalMessageSender() - sent = await message_sender.send_message(message, typing=False, set_reply=False, storage_message=True) - if sent: logger.info(f"[私聊][{self.private_name}]PFC消息已发送: {content}") - else: - logger.error(f"[私聊][{self.private_name}]PFC消息发送失败") - raise RuntimeError("消息发送失败") + return - except Exception as e: - logger.error(f"[私聊][{self.private_name}]PFC消息发送失败: {str(e)}") + logger.error(f"[私聊][{self.private_name}]PFC消息发送失败") + raise RuntimeError("消息发送失败") + except Exception as exc: + logger.error(f"[私聊][{self.private_name}]PFC消息发送失败: {exc}") raise diff --git a/src/chat/message_receive/uni_message_sender.py b/src/chat/message_receive/uni_message_sender.py index df74e459..cf42e092 100644 --- a/src/chat/message_receive/uni_message_sender.py +++ b/src/chat/message_receive/uni_message_sender.py @@ -60,8 +60,7 @@ async def _send_message(message: SessionMessage, show_log: bool = True) -> bool: 发送顺序为: 1. WebUI 特殊链路 - 2. Platform IO 适配器链路 - 3. 旧版 ``maim_message`` / API Server 链路 + 2. 旧版 ``maim_message`` / API Server 链路 Args: message: 待发送的内部会话消息。 @@ -124,32 +123,6 @@ async def _send_message(message: SessionMessage, show_log: bool = True) -> bool: logger.info(f"已将消息 '{message_preview}' 发往 WebUI 聊天室") return True - try: - from src.plugin_runtime.integration import get_plugin_runtime_manager - - delivery_batch = await get_plugin_runtime_manager().try_send_message_via_platform_io(message) - if delivery_batch is not None: - if delivery_batch.has_success: - successful_driver_ids = [ - receipt.driver_id or "unknown" - for receipt in delivery_batch.sent_receipts - ] - if show_log: - logger.info( - f"已通过 Platform IO 将消息 '{message_preview}' 发往平台'{platform}' " - f"(drivers: {', '.join(successful_driver_ids)})" - ) - return True - - failed_details = "; ".join( - f"driver={receipt.driver_id} status={receipt.status} error={receipt.error}" - for receipt in delivery_batch.failed_receipts - ) or "未命中任何发送路由" - logger.warning(f"Platform IO 发送失败: platform={platform} {failed_details}") - return False - except Exception as exc: - logger.warning(f"检查 Platform IO 出站链路时出现异常,将回退旧发送链: {exc}") - # Fallback 逻辑: 尝试通过 API Server 发送 async def send_with_new_api(legacy_exception: Optional[Exception] = None) -> bool: """通过 API Server 回退链路发送消息。 @@ -260,8 +233,21 @@ async def _send_message(message: SessionMessage, show_log: bool = True) -> bool: raise e # 重新抛出其他异常 +async def send_prepared_message_to_platform(message: SessionMessage, show_log: bool = True) -> bool: + """发送一条已完成预处理的消息到底层平台。 + + Args: + message: 已经完成回复组件注入、文本处理等预处理的消息对象。 + show_log: 是否输出发送成功日志。 + + Returns: + bool: 发送成功时返回 ``True``。 + """ + return await _send_message(message, show_log=show_log) + + class UniversalMessageSender: - """管理消息的注册、即时处理、发送和存储,并跟踪思考状态。""" + """旧链与 WebUI 的底层发送器。""" def __init__(self) -> None: """初始化统一消息发送器。""" @@ -276,17 +262,18 @@ class UniversalMessageSender: storage_message: bool = True, show_log: bool = True, ) -> bool: - """ - 处理、发送并存储一条消息。 + """通过旧链或 WebUI 发送并存储一条消息。 - 参数: - message: MessageSession 对象,待发送的消息。 + Args: + message: 待发送的内部消息对象。 typing: 是否模拟打字等待。 - set_reply: 是否构建回复引用消息。 + set_reply: 是否构建引用回复消息。 + reply_message_id: 被引用消息的 ID。 + storage_message: 是否在发送成功后写入数据库。 + show_log: 是否输出发送日志。 - - 用法: - - typing=True 时,发送前会有打字等待。 + Returns: + bool: 发送成功时返回 ``True``。 """ if not message.message_id: logger.error("消息缺少 message_id,无法发送") @@ -339,7 +326,7 @@ class UniversalMessageSender: ) await asyncio.sleep(typing_time) - sent_msg = await _send_message(message, show_log=show_log) + sent_msg = await send_prepared_message_to_platform(message, show_log=show_log) if not sent_msg: return False diff --git a/src/chat/replyer/group_generator.py b/src/chat/replyer/group_generator.py index 75563df7..4ffa14a7 100644 --- a/src/chat/replyer/group_generator.py +++ b/src/chat/replyer/group_generator.py @@ -17,7 +17,6 @@ from maim_message import BaseMessageInfo, MessageBase, Seg, UserInfo as MaimUser from src.common.data_models.mai_message_data_model import MaiMessage from src.chat.message_receive.message import SessionMessage from src.chat.message_receive.chat_manager import BotChatSession -from src.chat.message_receive.uni_message_sender import UniversalMessageSender from src.chat.utils.timer_calculator import Timer # <--- Import Timer from src.chat.utils.utils import get_bot_account, get_chat_type_and_target_info, is_bot_self from src.prompt.prompt_manager import prompt_manager @@ -51,10 +50,15 @@ class DefaultReplyer: chat_stream: BotChatSession, request_type: str = "replyer", ): + """初始化群聊回复器。 + + Args: + chat_stream: 当前绑定的聊天会话。 + request_type: LLM 请求类型标识。 + """ self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type) self.chat_stream = chat_stream self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.session_id) - self.heart_fc_sender = UniversalMessageSender() from src.chat.tool_executor import ToolExecutor diff --git a/src/chat/replyer/private_generator.py b/src/chat/replyer/private_generator.py index f642dd69..c125a42f 100644 --- a/src/chat/replyer/private_generator.py +++ b/src/chat/replyer/private_generator.py @@ -16,7 +16,6 @@ from maim_message import BaseMessageInfo, MessageBase, Seg, UserInfo as MaimUser from src.common.data_models.mai_message_data_model import MaiMessage from src.chat.message_receive.message import SessionMessage from src.chat.message_receive.chat_manager import BotChatSession -from src.chat.message_receive.uni_message_sender import UniversalMessageSender from src.chat.utils.timer_calculator import Timer from src.chat.utils.utils import get_bot_account, get_chat_type_and_target_info, is_bot_self from src.prompt.prompt_manager import prompt_manager @@ -47,10 +46,15 @@ class PrivateReplyer: chat_stream: BotChatSession, request_type: str = "replyer", ): + """初始化私聊回复器。 + + Args: + chat_stream: 当前绑定的聊天会话。 + request_type: LLM 请求类型标识。 + """ self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type) self.chat_stream = chat_stream self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.session_id) - self.heart_fc_sender = UniversalMessageSender() # self.memory_activator = MemoryActivator() from src.chat.tool_executor import ToolExecutor diff --git a/src/common/message_server/server.py b/src/common/message_server/server.py index 77a931e5..e75da4e7 100644 --- a/src/common/message_server/server.py +++ b/src/common/message_server/server.py @@ -21,7 +21,7 @@ class Server: self._server: Optional[UvicornServer] = None self.set_address(host, port) - def register_router(self, router: APIRouter, prefix: str = ""): + def register_router(self, router: APIRouter, prefix: str = ""): """注册路由 APIRouter 用于对相关的路由端点进行分组和模块化管理: diff --git a/src/platform_io/drivers/legacy_driver.py b/src/platform_io/drivers/legacy_driver.py index bd74d8c7..ef90c772 100644 --- a/src/platform_io/drivers/legacy_driver.py +++ b/src/platform_io/drivers/legacy_driver.py @@ -1,16 +1,16 @@ -"""提供 Platform IO 的 legacy 传输驱动骨架。""" +"""提供 Platform IO 的 legacy 传输驱动实现。""" from typing import TYPE_CHECKING, Any, Dict, Optional from src.platform_io.drivers.base import PlatformIODriver -from src.platform_io.types import DeliveryReceipt, DriverDescriptor, DriverKind, RouteKey +from src.platform_io.types import DeliveryReceipt, DeliveryStatus, DriverDescriptor, DriverKind, RouteKey if TYPE_CHECKING: from src.chat.message_receive.message import SessionMessage class LegacyPlatformDriver(PlatformIODriver): - """面向 ``maim_message`` 旧链路的 Platform IO 驱动骨架。""" + """面向 ``UniversalMessageSender`` 旧链的 Platform IO 驱动。""" def __init__( self, @@ -25,7 +25,7 @@ class LegacyPlatformDriver(PlatformIODriver): Args: driver_id: Broker 内的唯一驱动 ID。 platform: 该 legacy 适配器链路负责的平台。 - account_id: 可选的账号 ID 或 self ID。 + account_id: 可选的账号 ID。 scope: 可选的额外路由作用域。 metadata: 可选的额外驱动元数据。 """ @@ -45,7 +45,7 @@ class LegacyPlatformDriver(PlatformIODriver): route_key: RouteKey, metadata: Optional[Dict[str, Any]] = None, ) -> DeliveryReceipt: - """通过 legacy 传输路径发送消息。 + """通过旧链发送一条已经过预处理的消息。 Args: message: 要投递的内部会话消息。 @@ -53,9 +53,40 @@ class LegacyPlatformDriver(PlatformIODriver): metadata: 本次出站投递可选的 Broker 侧元数据。 Returns: - DeliveryReceipt: 由驱动返回的规范化回执。 - - Raises: - NotImplementedError: 当前仍处于骨架阶段,尚未真正接入旧发送链。 + DeliveryReceipt: 规范化后的发送回执。 """ - raise NotImplementedError("LegacyPlatformDriver 仅完成地基实现,尚未接入旧发送链") + from src.chat.message_receive.uni_message_sender import send_prepared_message_to_platform + + show_log = False + if isinstance(metadata, dict): + show_log = bool(metadata.get("show_log", False)) + + try: + sent = await send_prepared_message_to_platform(message, show_log=show_log) + except Exception as exc: + return DeliveryReceipt( + internal_message_id=message.message_id, + route_key=route_key, + status=DeliveryStatus.FAILED, + driver_id=self.driver_id, + driver_kind=self.descriptor.kind, + error=str(exc), + ) + + if not sent: + return DeliveryReceipt( + internal_message_id=message.message_id, + route_key=route_key, + status=DeliveryStatus.FAILED, + driver_id=self.driver_id, + driver_kind=self.descriptor.kind, + error="旧链发送失败", + ) + + return DeliveryReceipt( + internal_message_id=message.message_id, + route_key=route_key, + status=DeliveryStatus.SENT, + driver_id=self.driver_id, + driver_kind=self.descriptor.kind, + ) diff --git a/src/platform_io/manager.py b/src/platform_io/manager.py index c96a9ddd..cb5996b4 100644 --- a/src/platform_io/manager.py +++ b/src/platform_io/manager.py @@ -36,6 +36,7 @@ class PlatformIOManager: self._driver_registry = DriverRegistry() self._send_route_table = RouteTable() self._receive_route_table = RouteTable() + self._legacy_send_drivers: Dict[str, PlatformIODriver] = {} self._deduplicator = MessageDeduplicator() self._outbound_tracker = OutboundTracker() self._inbound_dispatcher: Optional[InboundDispatcher] = None @@ -75,6 +76,16 @@ class PlatformIOManager: self._started = True + async def ensure_send_pipeline_ready(self) -> None: + """确保出站发送管线已准备就绪。 + + 该方法会先同步 legacy fallback driver,再在需要时启动 Broker。 + send service 应只调用这一层准备入口,而不是自行判断旧链或插件链。 + """ + await self._sync_legacy_send_drivers() + if not self._started: + await self.start() + async def stop(self) -> None: """停止 Broker,并按逆序停止全部已注册驱动。 @@ -272,8 +283,60 @@ class PlatformIOManager: removed_driver.clear_inbound_handler() self._send_route_table.remove_bindings_by_driver(driver_id) self._receive_route_table.remove_bindings_by_driver(driver_id) + self._legacy_send_drivers = { + platform: driver + for platform, driver in self._legacy_send_drivers.items() + if driver.driver_id != driver_id + } return removed_driver + async def _sync_legacy_send_drivers(self) -> None: + """根据当前配置同步 legacy fallback driver。""" + from src.chat.utils.utils import get_all_bot_accounts + from src.platform_io.drivers.legacy_driver import LegacyPlatformDriver + + desired_accounts = get_all_bot_accounts() + desired_platforms = set(desired_accounts.keys()) + current_platforms = set(self._legacy_send_drivers.keys()) + + for platform in sorted(current_platforms - desired_platforms): + await self._remove_legacy_send_driver(platform) + + for platform, account_id in desired_accounts.items(): + existing_driver = self._legacy_send_drivers.get(platform) + if existing_driver is not None and existing_driver.descriptor.account_id == account_id: + continue + + if existing_driver is not None: + await self._remove_legacy_send_driver(platform) + + driver = LegacyPlatformDriver( + driver_id=f"legacy.send.{platform}", + platform=platform, + account_id=account_id, + ) + if self._started: + await self.add_driver(driver) + else: + self.register_driver(driver) + self._legacy_send_drivers[platform] = driver + + async def _remove_legacy_send_driver(self, platform: str) -> None: + """移除指定平台的 legacy fallback driver。 + + Args: + platform: 要移除的目标平台。 + """ + driver = self._legacy_send_drivers.get(platform) + if driver is None: + return + + if self._started: + await self.remove_driver(driver.driver_id) + else: + self.unregister_driver(driver.driver_id) + self._legacy_send_drivers.pop(platform, None) + def bind_send_route(self, binding: RouteBinding) -> None: """为某个路由键绑定发送驱动。 @@ -353,7 +416,19 @@ class PlatformIOManager: driver = self._driver_registry.get(binding.driver_id) if driver is not None: drivers.append(driver) - return drivers + if drivers: + return drivers + + fallback_driver = self._legacy_send_drivers.get(route_key.platform) + if fallback_driver is None: + return [] + + descriptor = fallback_driver.descriptor + if descriptor.account_id is not None and route_key.account_id not in (None, descriptor.account_id): + return [] + if descriptor.scope is not None and route_key.scope not in (None, descriptor.scope): + return [] + return [fallback_driver] def resolve_driver(self, route_key: RouteKey) -> Optional[PlatformIODriver]: """兼容旧接口,返回首个命中的发送驱动。""" diff --git a/src/plugin_runtime/integration.py b/src/plugin_runtime/integration.py index b74b2d46..ff51f419 100644 --- a/src/plugin_runtime/integration.py +++ b/src/plugin_runtime/integration.py @@ -157,7 +157,7 @@ class PluginRuntimeManager( started_supervisors: List[PluginSupervisor] = [] try: platform_io_manager.set_inbound_dispatcher(self._dispatch_platform_inbound) - await platform_io_manager.start() + await platform_io_manager.ensure_send_pipeline_ready() if self._builtin_supervisor: await self._builtin_supervisor.start() diff --git a/src/services/send_service.py b/src/services/send_service.py index 6ca7d005..7903cdeb 100644 --- a/src/services/send_service.py +++ b/src/services/send_service.py @@ -1,39 +1,51 @@ """ -发送服务模块 +发送服务模块。 -提供发送各种类型消息的核心功能。 +统一封装内部模块的出站消息发送逻辑: + +1. 内部模块统一调用本模块。 +2. send service 只负责构造和预处理消息。 +3. 具体走插件链还是 legacy 旧链,由 Platform IO 内部统一决策。 """ -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import Any, Dict, List, Optional +from maim_message import Seg + +import asyncio +import base64 +import hashlib import time import traceback +from datetime import datetime -from maim_message import BaseMessageInfo, GroupInfo as MaimGroupInfo, MessageBase, Seg, UserInfo as MaimUserInfo - +from src.chat.message_receive.chat_manager import BotChatSession from src.chat.message_receive.chat_manager import chat_manager as _chat_manager from src.chat.message_receive.message import SessionMessage -from src.chat.message_receive.uni_message_sender import UniversalMessageSender -from src.chat.utils.utils import get_bot_account -from src.common.data_models.mai_message_data_model import MaiMessage -from src.common.data_models.message_component_data_model import DictComponent, MessageSequence +from src.chat.utils.utils import calculate_typing_time, get_bot_account +from src.common.data_models.mai_message_data_model import GroupInfo, MaiMessage, MessageInfo, UserInfo +from src.common.data_models.message_component_data_model import ( + AtComponent, + DictComponent, + EmojiComponent, + ImageComponent, + MessageSequence, + ReplyComponent, + StandardMessageComponents, + TextComponent, + VoiceComponent, +) from src.common.logger import get_logger +from src.common.utils.utils_message import MessageUtils from src.config.config import global_config +from src.platform_io import DeliveryBatch, get_platform_io_manager from src.platform_io.route_key_factory import RouteKeyFactory -if TYPE_CHECKING: - from src.chat.message_receive.message import SessionMessage - logger = get_logger("send_service") -# ============================================================================= -# 内部实现函数 -# ============================================================================= - - -def _inherit_platform_io_route_metadata(target_stream: Any) -> Dict[str, object]: - """从目标会话上下文继承 Platform IO 路由元数据。 +def _inherit_platform_io_route_metadata(target_stream: BotChatSession) -> Dict[str, object]: + """从目标会话继承 Platform IO 路由元数据。 Args: target_stream: 当前消息要发送到的会话对象。 @@ -44,12 +56,11 @@ def _inherit_platform_io_route_metadata(target_stream: Any) -> Dict[str, object] """ inherited_metadata: Dict[str, object] = {} - context = getattr(target_stream, "context", None) - context_message = getattr(context, "message", None) + context_message = target_stream.context.message if target_stream.context else None if context_message is None: return inherited_metadata - additional_config = getattr(context_message.message_info, "additional_config", {}) + additional_config = context_message.message_info.additional_config if not isinstance(additional_config, dict): return inherited_metadata @@ -61,33 +72,412 @@ def _inherit_platform_io_route_metadata(target_stream: Any) -> Dict[str, object] if normalized_value: inherited_metadata[key] = value - target_group_id = getattr(target_stream, "group_id", None) - if target_group_id is not None: - normalized_group_id = str(target_group_id).strip() + if target_stream.group_id: + normalized_group_id = str(target_stream.group_id).strip() if normalized_group_id: inherited_metadata["platform_io_target_group_id"] = normalized_group_id - target_user_id = getattr(target_stream, "user_id", None) - if target_user_id is not None: - normalized_user_id = str(target_user_id).strip() + if target_stream.user_id: + normalized_user_id = str(target_stream.user_id).strip() if normalized_user_id: inherited_metadata["platform_io_target_user_id"] = normalized_user_id return inherited_metadata +def _build_component_from_seg(message_segment: Seg) -> StandardMessageComponents: + """将单个消息段转换为内部消息组件。 + + Args: + message_segment: 待转换的消息段。 + + Returns: + StandardMessageComponents: 转换后的内部消息组件。 + """ + segment_type = str(message_segment.type or "").strip().lower() + segment_data = message_segment.data + + if segment_type == "text": + return TextComponent(text=str(segment_data or "")) + + if segment_type == "image": + image_binary = base64.b64decode(str(segment_data or "")) + return ImageComponent( + binary_hash=hashlib.sha256(image_binary).hexdigest(), + binary_data=image_binary, + ) + + if segment_type == "emoji": + emoji_binary = base64.b64decode(str(segment_data or "")) + return EmojiComponent( + binary_hash=hashlib.sha256(emoji_binary).hexdigest(), + binary_data=emoji_binary, + ) + + if segment_type == "voice": + voice_binary = base64.b64decode(str(segment_data or "")) + return VoiceComponent( + binary_hash=hashlib.sha256(voice_binary).hexdigest(), + binary_data=voice_binary, + ) + + if segment_type == "at": + return AtComponent(target_user_id=str(segment_data or "")) + + if segment_type == "reply": + return ReplyComponent(target_message_id=str(segment_data or "")) + + if segment_type == "dict" and isinstance(segment_data, dict): + return DictComponent(data=segment_data) + + return DictComponent(data={"type": segment_type, "data": segment_data}) + + +def _build_message_sequence_from_seg(message_segment: Seg) -> MessageSequence: + """将消息段转换为内部消息组件序列。 + + Args: + message_segment: 待转换的消息段。 + + Returns: + MessageSequence: 转换后的消息组件序列。 + """ + if str(message_segment.type or "").strip().lower() == "seglist": + raw_segments = message_segment.data + if not isinstance(raw_segments, list): + raise ValueError("seglist 类型的消息段数据必须是列表") + components = [ + _build_component_from_seg(item) + for item in raw_segments + if isinstance(item, Seg) + ] + return MessageSequence(components=components) + + return MessageSequence(components=[_build_component_from_seg(message_segment)]) + + +def _build_processed_plain_text(message: SessionMessage) -> str: + """为出站消息构造轻量纯文本摘要。 + + Args: + message: 待发送的内部消息对象。 + + Returns: + str: 适用于日志与打字时长估算的纯文本摘要。 + """ + processed_parts: List[str] = [] + for component in message.raw_message.components: + if isinstance(component, TextComponent): + processed_parts.append(component.text) + continue + + if isinstance(component, ImageComponent): + processed_parts.append(component.content or "[图片]") + continue + + if isinstance(component, EmojiComponent): + processed_parts.append(component.content or "[表情]") + continue + + if isinstance(component, VoiceComponent): + processed_parts.append(component.content or "[语音]") + continue + + if isinstance(component, AtComponent): + at_target = component.target_user_cardname or component.target_user_nickname or component.target_user_id + processed_parts.append(f"@{at_target}") + continue + + if isinstance(component, ReplyComponent): + processed_parts.append(component.target_message_content or "[回复消息]") + continue + + if isinstance(component, DictComponent): + raw_type = component.data.get("type") if isinstance(component.data, dict) else None + if isinstance(raw_type, str) and raw_type.strip(): + processed_parts.append(f"[{raw_type.strip()}消息]") + else: + processed_parts.append("[自定义消息]") + continue + + return " ".join(part for part in processed_parts if part) + + +def _build_outbound_session_message( + message_segment: Seg, + stream_id: str, + display_message: str = "", + reply_message: Optional[MaiMessage] = None, + selected_expressions: Optional[List[int]] = None, +) -> Optional[SessionMessage]: + """根据目标会话构建待发送的内部消息对象。 + + Args: + message_segment: 待发送的消息段。 + stream_id: 目标会话 ID。 + display_message: 用于界面展示的文本内容。 + reply_message: 被回复的锚点消息。 + selected_expressions: 可选的表情候选索引列表。 + + Returns: + Optional[SessionMessage]: 构建成功时返回内部消息对象;若目标会话或 + 机器人账号不存在,则返回 ``None``。 + """ + target_stream = _chat_manager.get_session_by_session_id(stream_id) + if target_stream is None: + logger.error(f"[SendService] 未找到聊天流: {stream_id}") + return None + + bot_user_id = get_bot_account(target_stream.platform) + if not bot_user_id: + logger.error(f"[SendService] 平台 {target_stream.platform} 未配置机器人账号,无法发送消息") + return None + + current_time = time.time() + message_id = f"send_api_{int(current_time * 1000)}" + anchor_message = reply_message.deepcopy() if reply_message is not None else None + + group_info: Optional[GroupInfo] = None + if target_stream.group_id: + group_name = "" + if ( + target_stream.context + and target_stream.context.message + and target_stream.context.message.message_info.group_info + ): + group_name = target_stream.context.message.message_info.group_info.group_name + group_info = GroupInfo( + group_id=target_stream.group_id, + group_name=group_name, + ) + + additional_config: Dict[str, object] = _inherit_platform_io_route_metadata(target_stream) + if selected_expressions is not None: + additional_config["selected_expressions"] = selected_expressions + + outbound_message = SessionMessage( + message_id=message_id, + timestamp=datetime.fromtimestamp(current_time), + platform=target_stream.platform, + ) + outbound_message.message_info = MessageInfo( + user_info=UserInfo( + user_id=bot_user_id, + user_nickname=global_config.bot.nickname, + ), + group_info=group_info, + additional_config=additional_config, + ) + outbound_message.raw_message = _build_message_sequence_from_seg(message_segment) + outbound_message.session_id = target_stream.session_id + outbound_message.display_message = display_message + outbound_message.reply_to = anchor_message.message_id if anchor_message is not None else None + outbound_message.is_emoji = message_segment.type == "emoji" + outbound_message.is_picture = message_segment.type == "image" + outbound_message.is_command = message_segment.type == "command" + outbound_message.initialized = True + return outbound_message + + +def _ensure_reply_component(message: SessionMessage, reply_message_id: str) -> None: + """为消息补充回复组件。 + + Args: + message: 待发送的内部消息对象。 + reply_message_id: 被引用消息的 ID。 + """ + if message.raw_message.components: + first_component = message.raw_message.components[0] + if isinstance(first_component, ReplyComponent) and first_component.target_message_id == reply_message_id: + return + + message.raw_message.components.insert(0, ReplyComponent(target_message_id=reply_message_id)) + + +async def _prepare_message_for_platform_io( + message: SessionMessage, + *, + typing: bool, + set_reply: bool, + reply_message_id: Optional[str], +) -> None: + """为 Platform IO 发送链预处理消息。 + + Args: + message: 待发送的内部消息对象。 + typing: 是否模拟打字等待。 + set_reply: 是否构建引用回复组件。 + reply_message_id: 被引用消息的 ID。 + + Raises: + ValueError: 当要求设置引用回复但缺少 ``reply_message_id`` 时抛出。 + """ + if set_reply: + if not reply_message_id: + raise ValueError("set_reply=True 时必须提供 reply_message_id") + _ensure_reply_component(message, reply_message_id) + + message.processed_plain_text = _build_processed_plain_text(message) + if typing: + typing_time = calculate_typing_time( + input_string=message.processed_plain_text or "", + is_emoji=message.is_emoji, + ) + await asyncio.sleep(typing_time) + + +def _store_sent_message(message: SessionMessage) -> None: + """将已成功发送的消息写入数据库。 + + Args: + message: 已成功发送的内部消息对象。 + """ + MessageUtils.store_message_to_db(message) + + +def _log_platform_io_failures(delivery_batch: DeliveryBatch) -> None: + """输出 Platform IO 批量发送失败详情。 + + Args: + delivery_batch: Platform IO 返回的批量回执。 + """ + failed_details = "; ".join( + f"driver={receipt.driver_id} status={receipt.status} error={receipt.error}" + for receipt in delivery_batch.failed_receipts + ) or "未命中任何发送路由" + logger.warning( + "[SendService] Platform IO 发送失败: platform=%s %s", + delivery_batch.route_key.platform, + failed_details, + ) + + +async def _send_via_platform_io( + message: SessionMessage, + *, + typing: bool, + set_reply: bool, + reply_message_id: Optional[str], + storage_message: bool, + show_log: bool, +) -> bool: + """通过 Platform IO 发送消息。 + + Args: + message: 待发送的内部消息对象。 + typing: 是否模拟打字等待。 + set_reply: 是否设置引用回复。 + reply_message_id: 被引用消息的 ID。 + storage_message: 发送成功后是否写入数据库。 + show_log: 是否输出发送成功日志。 + + Returns: + bool: 发送成功时返回 ``True``。 + """ + platform_io_manager = get_platform_io_manager() + try: + await platform_io_manager.ensure_send_pipeline_ready() + except Exception as exc: + logger.error(f"[SendService] 准备 Platform IO 发送管线失败: {exc}") + logger.debug(traceback.format_exc()) + return False + + try: + route_key = platform_io_manager.build_route_key_from_message(message) + except Exception as exc: + logger.warning(f"[SendService] 根据消息构造 Platform IO 路由键失败: {exc}") + return False + + try: + await _prepare_message_for_platform_io( + message, + typing=typing, + set_reply=set_reply, + reply_message_id=reply_message_id, + ) + delivery_batch = await platform_io_manager.send_message( + message, + route_key, + metadata={"show_log": False}, + ) + except Exception as exc: + logger.error(f"[SendService] Platform IO 发送异常: {exc}") + logger.debug(traceback.format_exc()) + return False + + if delivery_batch.has_success: + if storage_message: + _store_sent_message(message) + if show_log: + successful_driver_ids = [ + receipt.driver_id or "unknown" + for receipt in delivery_batch.sent_receipts + ] + logger.info( + "[SendService] 已通过 Platform IO 将消息发往平台 '%s' (drivers: %s)", + route_key.platform, + ", ".join(successful_driver_ids), + ) + return True + + _log_platform_io_failures(delivery_batch) + return False + + +async def send_session_message( + message: SessionMessage, + *, + typing: bool = False, + set_reply: bool = False, + reply_message_id: Optional[str] = None, + storage_message: bool = True, + show_log: bool = True, +) -> bool: + """统一发送一条内部消息。 + + 该方法是内部模块的统一发送入口: + + 1. 构造并维护内部消息对象。 + 2. 由 Platform IO 统一决定走插件链还是 legacy 旧链。 + 3. send service 不再自行判断底层发送路径。 + + Args: + message: 待发送的内部消息对象。 + typing: 是否模拟打字等待。 + set_reply: 是否设置引用回复。 + reply_message_id: 被引用消息的 ID。 + storage_message: 发送成功后是否写入数据库。 + show_log: 是否输出发送日志。 + + Returns: + bool: 发送成功时返回 ``True``,否则返回 ``False``。 + """ + if not message.message_id: + logger.error("[SendService] 消息缺少 message_id,无法发送") + raise ValueError("消息缺少 message_id,无法发送") + + return await _send_via_platform_io( + message, + typing=typing, + set_reply=set_reply, + reply_message_id=reply_message_id, + storage_message=storage_message, + show_log=show_log, + ) + + async def _send_to_target( message_segment: Seg, stream_id: str, display_message: str = "", typing: bool = False, set_reply: bool = False, - reply_message: Optional["SessionMessage"] = None, + reply_message: Optional[MaiMessage] = None, storage_message: bool = True, show_log: bool = True, selected_expressions: Optional[List[int]] = None, ) -> bool: - """向指定目标发送消息。 + """向指定目标构建并发送消息。 Args: message_segment: 待发送的消息段。 @@ -104,110 +494,66 @@ async def _send_to_target( bool: 发送成功返回 ``True``,否则返回 ``False``。 """ try: - if set_reply and not reply_message: + if set_reply and reply_message is None: logger.warning("[SendService] 使用引用回复,但未提供回复消息") return False if show_log: logger.debug(f"[SendService] 发送{message_segment.type}消息到 {stream_id}") - target_stream = _chat_manager.get_session_by_session_id(stream_id) - if not target_stream: - logger.error(f"[SendService] 未找到聊天流: {stream_id}") - return False - - message_sender = UniversalMessageSender() - - current_time = time.time() - message_id = f"send_api_{int(current_time * 1000)}" - - anchor_message: Optional[MaiMessage] = None - if reply_message: - anchor_message = reply_message.deepcopy() - if anchor_message: - logger.debug( - f"[SendService] 找到匹配的回复消息,发送者: {anchor_message.message_info.user_info.user_id}" - ) - - group_info = None - if target_stream.group_id: - group_name = "" - if target_stream.context and target_stream.context.message and target_stream.context.message.message_info.group_info: - group_name = target_stream.context.message.message_info.group_info.group_name - group_info = MaimGroupInfo( - group_id=target_stream.group_id, - group_name=group_name, - platform=target_stream.platform, - ) - - additional_config: Dict[str, object] = _inherit_platform_io_route_metadata(target_stream) - if selected_expressions is not None: - additional_config["selected_expressions"] = selected_expressions - bot_user_id = get_bot_account(target_stream.platform) - if not bot_user_id: - logger.error(f"[SendService] 平台 {target_stream.platform} 未配置机器人账号,无法发送消息") - return False - - maim_message = MessageBase( - message_info=BaseMessageInfo( - platform=target_stream.platform, - message_id=message_id, - time=current_time, - user_info=MaimUserInfo( - user_id=bot_user_id, - user_nickname=global_config.bot.nickname, - platform=target_stream.platform, - ), - group_info=group_info, - additional_config=additional_config, - ), + outbound_message = _build_outbound_session_message( message_segment=message_segment, + stream_id=stream_id, + display_message=display_message, + reply_message=reply_message, + selected_expressions=selected_expressions, ) - bot_message = SessionMessage.from_maim_message(maim_message) - bot_message.session_id = target_stream.session_id - bot_message.display_message = display_message - bot_message.reply_to = anchor_message.message_id if anchor_message else None - bot_message.is_emoji = message_segment.type == "emoji" - bot_message.is_picture = message_segment.type == "image" - bot_message.is_command = message_segment.type == "command" + if outbound_message is None: + return False - sent_msg = await message_sender.send_message( - bot_message, + sent = await send_session_message( + outbound_message, typing=typing, set_reply=set_reply, - reply_message_id=anchor_message.message_id if anchor_message else None, + reply_message_id=reply_message.message_id if reply_message is not None else None, storage_message=storage_message, show_log=show_log, ) - - if sent_msg: + if sent: logger.debug(f"[SendService] 成功发送消息到 {stream_id}") return True - else: - logger.error("[SendService] 发送消息失败") - return False - except Exception as e: - logger.error(f"[SendService] 发送消息时出错: {e}") + logger.error("[SendService] 发送消息失败") + return False + except Exception as exc: + logger.error(f"[SendService] 发送消息时出错: {exc}") traceback.print_exc() return False -# ============================================================================= -# 公共函数 - 预定义类型的发送函数 -# ============================================================================= - - async def text_to_stream( text: str, stream_id: str, typing: bool = False, set_reply: bool = False, - reply_message: Optional["SessionMessage"] = None, + reply_message: Optional[MaiMessage] = None, storage_message: bool = True, selected_expressions: Optional[List[int]] = None, ) -> bool: - """向指定流发送文本消息""" + """向指定流发送文本消息。 + + Args: + text: 要发送的文本内容。 + stream_id: 目标会话 ID。 + typing: 是否显示输入中状态。 + set_reply: 是否附带引用回复。 + reply_message: 被回复的消息对象。 + storage_message: 是否在发送成功后写入数据库。 + selected_expressions: 可选的表情候选索引列表。 + + Returns: + bool: 发送成功时返回 ``True``。 + """ return await _send_to_target( message_segment=Seg(type="text", data=text), stream_id=stream_id, @@ -225,9 +571,20 @@ async def emoji_to_stream( stream_id: str, storage_message: bool = True, set_reply: bool = False, - reply_message: Optional["SessionMessage"] = None, + reply_message: Optional[MaiMessage] = None, ) -> bool: - """向指定流发送表情包""" + """向指定流发送表情消息。 + + Args: + emoji_base64: 表情图片的 Base64 内容。 + stream_id: 目标会话 ID。 + storage_message: 是否在发送成功后写入数据库。 + set_reply: 是否附带引用回复。 + reply_message: 被回复的消息对象。 + + Returns: + bool: 发送成功时返回 ``True``。 + """ return await _send_to_target( message_segment=Seg(type="emoji", data=emoji_base64), stream_id=stream_id, @@ -244,9 +601,20 @@ async def image_to_stream( stream_id: str, storage_message: bool = True, set_reply: bool = False, - reply_message: Optional["SessionMessage"] = None, + reply_message: Optional[MaiMessage] = None, ) -> bool: - """向指定流发送图片""" + """向指定流发送图片消息。 + + Args: + image_base64: 图片的 Base64 内容。 + stream_id: 目标会话 ID。 + storage_message: 是否在发送成功后写入数据库。 + set_reply: 是否附带引用回复。 + reply_message: 被回复的消息对象。 + + Returns: + bool: 发送成功时返回 ``True``。 + """ return await _send_to_target( message_segment=Seg(type="image", data=image_base64), stream_id=stream_id, @@ -260,18 +628,33 @@ async def image_to_stream( async def custom_to_stream( message_type: str, - content: str | Dict, + content: str | Dict[str, Any], stream_id: str, display_message: str = "", typing: bool = False, - reply_message: Optional["SessionMessage"] = None, + reply_message: Optional[MaiMessage] = None, set_reply: bool = False, storage_message: bool = True, show_log: bool = True, ) -> bool: - """向指定流发送自定义类型消息""" + """向指定流发送自定义类型消息。 + + Args: + message_type: 自定义消息类型。 + content: 自定义消息内容。 + stream_id: 目标会话 ID。 + display_message: 用于展示的文本内容。 + typing: 是否显示输入中状态。 + reply_message: 被回复的消息对象。 + set_reply: 是否附带引用回复。 + storage_message: 是否在发送成功后写入数据库。 + show_log: 是否输出发送日志。 + + Returns: + bool: 发送成功时返回 ``True``。 + """ return await _send_to_target( - message_segment=Seg(type=message_type, data=content), # type: ignore + message_segment=Seg(type=message_type, data=content), # type: ignore[arg-type] stream_id=stream_id, display_message=display_message, typing=typing, @@ -287,18 +670,33 @@ async def custom_reply_set_to_stream( stream_id: str, display_message: str = "", typing: bool = False, - reply_message: Optional["SessionMessage"] = None, + reply_message: Optional[MaiMessage] = None, set_reply: bool = False, storage_message: bool = True, show_log: bool = True, ) -> bool: - """向指定流发送消息组件序列。""" - flag: bool = True + """向指定流发送消息组件序列。 + + Args: + reply_set: 待发送的消息组件序列。 + stream_id: 目标会话 ID。 + display_message: 用于展示的文本内容。 + typing: 是否显示输入中状态。 + reply_message: 被回复的消息对象。 + set_reply: 是否附带引用回复。 + storage_message: 是否在发送成功后写入数据库。 + show_log: 是否输出发送日志。 + + Returns: + bool: 全部组件发送成功时返回 ``True``。 + """ + success = True for component in reply_set.components: if isinstance(component, DictComponent): - message_seg = Seg(type="dict", data=component.data) # type: ignore + message_seg = Seg(type="dict", data=component.data) # type: ignore[arg-type] else: message_seg = await component.to_seg() + status = await _send_to_target( message_segment=message_seg, stream_id=stream_id, @@ -310,8 +708,8 @@ async def custom_reply_set_to_stream( show_log=show_log, ) if not status: - flag = False + success = False logger.error(f"[SendService] 发送消息组件失败,组件类型:{type(component).__name__}") set_reply = False - return flag + return success From 18a0e7664ad23a1b582610557501f3d81923fd8a Mon Sep 17 00:00:00 2001 From: DrSmoothl <1787882683@qq.com> Date: Mon, 23 Mar 2026 16:14:13 +0800 Subject: [PATCH 32/45] Refactor plugin runtime components and enhance message handling - Removed unused core action mirror functionality from PluginRunnerSupervisor. - Simplified action and command execution logic in send_service.py. - Introduced ComponentQueryService for unified component querying in plugin runtime. - Enhanced message component handling with new binary component support. - Improved message sequence construction and detection of outbound message flags. - Updated methods for sending messages to streamline the process and improve readability. --- pytests/test_platform_io_dedupe.py | 6 +- pytests/test_plugin_runtime_action_bridge.py | 328 +++++--- src/chat/brain_chat/brain_planner.py | 32 +- src/chat/message_receive/bot.py | 114 +-- src/chat/planner_actions/action_manager.py | 14 +- src/chat/planner_actions/planner.py | 41 +- src/chat/tool_executor.py | 14 +- src/core/component_registry.py | 239 ------ src/platform_io/manager.py | 24 - src/plugin_runtime/capabilities/core.py | 8 +- src/plugin_runtime/capabilities/data.py | 4 +- src/plugin_runtime/component_query.py | 709 ++++++++++++++++++ src/plugin_runtime/host/component_registry.py | 72 +- src/plugin_runtime/host/supervisor.py | 238 ------ src/services/send_service.py | 265 ++++--- 15 files changed, 1255 insertions(+), 853 deletions(-) delete mode 100644 src/core/component_registry.py create mode 100644 src/plugin_runtime/component_query.py diff --git a/pytests/test_platform_io_dedupe.py b/pytests/test_platform_io_dedupe.py index 68ae95c6..d6bdd1dd 100644 --- a/pytests/test_platform_io_dedupe.py +++ b/pytests/test_platform_io_dedupe.py @@ -72,10 +72,10 @@ class _StubPlatformIODriver(PlatformIODriver): def _build_manager() -> PlatformIOManager: - """构造带有最小 active owner 的 Broker 管理器。 + """构造带有最小接收路由的 Broker 管理器。 Returns: - PlatformIOManager: 已注册测试驱动并绑定活动路由的 Broker。 + PlatformIOManager: 已注册测试驱动并绑定接收路由的 Broker。 """ manager = PlatformIOManager() driver = _StubPlatformIODriver( @@ -88,7 +88,7 @@ def _build_manager() -> PlatformIOManager: ) ) manager.register_driver(driver) - manager.bind_route( + manager.bind_receive_route( RouteBinding( route_key=RouteKey(platform="qq", account_id="10001", scope="main"), driver_id=driver.driver_id, diff --git a/pytests/test_plugin_runtime_action_bridge.py b/pytests/test_plugin_runtime_action_bridge.py index f2364094..e13dfaf3 100644 --- a/pytests/test_plugin_runtime_action_bridge.py +++ b/pytests/test_plugin_runtime_action_bridge.py @@ -1,57 +1,109 @@ +"""核心组件查询层与插件运行时聚合测试。""" + from types import SimpleNamespace from typing import Any import pytest -from src.core.component_registry import component_registry as core_component_registry +import src.plugin_runtime.integration as integration_module + +from src.core.types import ActionInfo, ToolInfo +from src.plugin_runtime.component_query import component_query_service from src.plugin_runtime.host.supervisor import PluginSupervisor -from src.plugin_runtime.protocol.envelope import ComponentDeclaration, RegisterPluginPayload -def _build_action_payload(plugin_id: str, action_name: str) -> RegisterPluginPayload: - """构造用于测试的 runtime Action 注册载荷。 +class _FakeRuntimeManager: + """测试用插件运行时管理器。""" + + def __init__(self, supervisor: PluginSupervisor, plugin_id: str, plugin_config: dict[str, Any]) -> None: + """初始化测试用运行时管理器。 + + Args: + supervisor: 持有测试组件的监督器。 + plugin_id: 目标插件 ID。 + plugin_config: 需要返回的插件配置。 + """ + + self.supervisors = [supervisor] + self._plugin_id = plugin_id + self._plugin_config = plugin_config + + def _get_supervisor_for_plugin(self, plugin_id: str) -> PluginSupervisor | None: + """按插件 ID 返回对应监督器。 + + Args: + plugin_id: 目标插件 ID。 + + Returns: + PluginSupervisor | None: 命中时返回监督器。 + """ + + return self.supervisors[0] if plugin_id == self._plugin_id else None + + def _load_plugin_config_for_supervisor(self, supervisor: Any, plugin_id: str) -> dict[str, Any]: + """返回测试配置。 + + Args: + supervisor: 监督器实例。 + plugin_id: 目标插件 ID。 + + Returns: + dict[str, Any]: 测试配置内容。 + """ + + del supervisor + if plugin_id != self._plugin_id: + return {} + return dict(self._plugin_config) + + +def _install_runtime_manager( + monkeypatch: pytest.MonkeyPatch, + supervisor: PluginSupervisor, + plugin_id: str, + plugin_config: dict[str, Any] | None = None, +) -> None: + """为测试安装假的运行时管理器。 Args: - plugin_id: 插件 ID。 - action_name: Action 名称。 - - Returns: - RegisterPluginPayload: 测试用注册载荷。 + monkeypatch: pytest monkeypatch 对象。 + supervisor: 持有测试组件的监督器。 + plugin_id: 测试插件 ID。 + plugin_config: 可选的测试配置内容。 """ - return RegisterPluginPayload( - plugin_id=plugin_id, - plugin_version="1.0.0", - components=[ - ComponentDeclaration( - name=action_name, - component_type="ACTION", - plugin_id=plugin_id, - metadata={ - "description": "发送一个测试回复", - "enabled": True, - "activation_type": "keyword", - "activation_probability": 0.25, - "activation_keywords": ["测试", "hello"], - "action_parameters": {"target": "目标对象"}, - "action_require": ["需要发送回复时使用"], - "associated_types": ["text"], - "parallel_action": True, - }, - ) - ], - ) + + fake_manager = _FakeRuntimeManager(supervisor, plugin_id, plugin_config or {"enabled": True}) + monkeypatch.setattr(integration_module, "get_plugin_runtime_manager", lambda: fake_manager) @pytest.mark.asyncio -async def test_runtime_actions_are_mirrored_into_core_registry_and_invoked(monkeypatch: pytest.MonkeyPatch) -> None: - """运行时 Action 应镜像到旧核心注册表,并可由旧 Planner 执行。""" +async def test_core_component_registry_reads_runtime_action_and_executor( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """核心查询层应直接读取运行时 Action,并返回 RPC 执行闭包。""" + plugin_id = "runtime_action_bridge_plugin" action_name = "runtime_action_bridge_test" - payload = _build_action_payload(plugin_id=plugin_id, action_name=action_name) supervisor = PluginSupervisor(plugin_dirs=[]) captured: dict[str, Any] = {} - core_component_registry.remove_action(action_name) + supervisor.component_registry.register_component( + name=action_name, + component_type="ACTION", + plugin_id=plugin_id, + metadata={ + "description": "发送一个测试回复", + "enabled": True, + "activation_type": "keyword", + "activation_probability": 0.25, + "activation_keywords": ["测试", "hello"], + "action_parameters": {"target": "目标对象"}, + "action_require": ["需要发送回复时使用"], + "associated_types": ["text"], + "parallel_action": True, + }, + ) + _install_runtime_manager(monkeypatch, supervisor, plugin_id, {"enabled": True, "mode": "test"}) async def fake_invoke_plugin( method: str, @@ -60,18 +112,8 @@ async def test_runtime_actions_are_mirrored_into_core_registry_and_invoked(monke args: dict[str, Any] | None = None, timeout_ms: int = 30000, ) -> Any: - """模拟 plugin runtime Action 调用。 + """模拟动作 RPC 调用。""" - Args: - method: RPC 方法名。 - plugin_id: 插件 ID。 - component_name: 组件名称。 - args: 调用参数。 - timeout_ms: RPC 超时时间。 - - Returns: - Any: 伪造的 RPC 响应对象。 - """ captured["method"] = method captured["plugin_id"] = plugin_id captured["component_name"] = component_name @@ -81,58 +123,162 @@ async def test_runtime_actions_are_mirrored_into_core_registry_and_invoked(monke monkeypatch.setattr(supervisor, "invoke_plugin", fake_invoke_plugin) - try: - supervisor._mirror_runtime_actions_to_core_registry(payload) + action_info = component_query_service.get_action_info(action_name) + assert isinstance(action_info, ActionInfo) + assert action_info.plugin_name == plugin_id + assert action_info.description == "发送一个测试回复" + assert action_info.activation_keywords == ["测试", "hello"] + assert action_info.random_activation_probability == 0.25 + assert action_info.parallel_action is True + assert action_name in component_query_service.get_default_actions() + assert component_query_service.get_plugin_config(plugin_id) == {"enabled": True, "mode": "test"} - action_info = core_component_registry.get_action_info(action_name) - assert action_info is not None - assert action_info.plugin_name == plugin_id - assert action_info.description == "发送一个测试回复" - assert action_info.activation_keywords == ["测试", "hello"] - assert action_info.random_activation_probability == 0.25 - assert action_info.parallel_action is True + executor = component_query_service.get_action_executor(action_name) + assert executor is not None - executor = core_component_registry.get_action_executor(action_name) - assert executor is not None + success, reason = await executor( + action_data={"target": "MaiBot"}, + action_reasoning="当前适合使用这个动作", + cycle_timers={"planner": 0.1}, + thinking_id="tid-1", + chat_stream=SimpleNamespace(session_id="stream-1"), + log_prefix="[test]", + shutting_down=False, + plugin_config={"enabled": True}, + ) - success, reason = await executor( - action_data={"target": "MaiBot"}, - action_reasoning="当前适合使用这个动作", - cycle_timers={"planner": 0.1}, - thinking_id="tid-1", - chat_stream=SimpleNamespace(session_id="stream-1"), - log_prefix="[test]", - shutting_down=False, - plugin_config={"enabled": True}, - ) - - assert success is True - assert reason == "runtime action executed" - assert captured["method"] == "plugin.invoke_action" - assert captured["plugin_id"] == plugin_id - assert captured["component_name"] == action_name - assert captured["args"]["stream_id"] == "stream-1" - assert captured["args"]["chat_id"] == "stream-1" - assert captured["args"]["reasoning"] == "当前适合使用这个动作" - assert captured["args"]["target"] == "MaiBot" - assert captured["args"]["action_data"] == {"target": "MaiBot"} - finally: - supervisor._remove_core_action_mirrors(plugin_id) - core_component_registry.remove_action(action_name) + assert success is True + assert reason == "runtime action executed" + assert captured["method"] == "plugin.invoke_action" + assert captured["plugin_id"] == plugin_id + assert captured["component_name"] == action_name + assert captured["args"]["stream_id"] == "stream-1" + assert captured["args"]["chat_id"] == "stream-1" + assert captured["args"]["reasoning"] == "当前适合使用这个动作" + assert captured["args"]["target"] == "MaiBot" + assert captured["args"]["action_data"] == {"target": "MaiBot"} -def test_clear_runner_state_removes_mirrored_runtime_actions() -> None: - """清理 Runner 状态时应同步移除旧核心注册表中的镜像 Action。""" - plugin_id = "runtime_action_bridge_cleanup_plugin" - action_name = "runtime_action_bridge_cleanup_test" - payload = _build_action_payload(plugin_id=plugin_id, action_name=action_name) +@pytest.mark.asyncio +async def test_core_component_registry_reads_runtime_command_and_executor( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """核心查询层应直接使用运行时命令匹配与执行闭包。""" + + plugin_id = "runtime_command_bridge_plugin" + command_name = "runtime_command_bridge_test" + supervisor = PluginSupervisor(plugin_dirs=[]) + captured: dict[str, Any] = {} + + supervisor.component_registry.register_component( + name=command_name, + component_type="COMMAND", + plugin_id=plugin_id, + metadata={ + "description": "测试命令", + "enabled": True, + "command_pattern": r"^/test(?:\s+.+)?$", + "aliases": ["/hello"], + "intercept_message_level": 1, + }, + ) + _install_runtime_manager(monkeypatch, supervisor, plugin_id, {"mode": "command"}) + + async def fake_invoke_plugin( + method: str, + plugin_id: str, + component_name: str, + args: dict[str, Any] | None = None, + timeout_ms: int = 30000, + ) -> Any: + """模拟命令 RPC 调用。""" + + captured["method"] = method + captured["plugin_id"] = plugin_id + captured["component_name"] = component_name + captured["args"] = args or {} + captured["timeout_ms"] = timeout_ms + return SimpleNamespace(payload={"success": True, "result": (True, "command ok", True)}) + + monkeypatch.setattr(supervisor, "invoke_plugin", fake_invoke_plugin) + + matched = component_query_service.find_command_by_text("/test hello") + assert matched is not None + command_executor, matched_groups, command_info = matched + + assert matched_groups == {} + assert command_info.plugin_name == plugin_id + assert command_info.command_pattern == r"^/test(?:\s+.+)?$" + + success, response_text, intercept = await command_executor( + message=SimpleNamespace(processed_plain_text="/test hello", session_id="stream-2"), + plugin_config={"mode": "command"}, + matched_groups=matched_groups, + ) + + assert success is True + assert response_text == "command ok" + assert intercept is True + assert captured["method"] == "plugin.invoke_command" + assert captured["plugin_id"] == plugin_id + assert captured["component_name"] == command_name + assert captured["args"]["text"] == "/test hello" + assert captured["args"]["stream_id"] == "stream-2" + assert captured["args"]["plugin_config"] == {"mode": "command"} + + +@pytest.mark.asyncio +async def test_core_component_registry_reads_runtime_tools_and_executor( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """核心查询层应直接读取运行时 Tool,并返回 RPC 执行闭包。""" + + plugin_id = "runtime_tool_bridge_plugin" + tool_name = "runtime_tool_bridge_test" supervisor = PluginSupervisor(plugin_dirs=[]) - core_component_registry.remove_action(action_name) + supervisor.component_registry.register_component( + name=tool_name, + component_type="TOOL", + plugin_id=plugin_id, + metadata={ + "description": "测试工具", + "enabled": True, + "parameters": [ + { + "name": "query", + "param_type": "string", + "description": "查询词", + "required": True, + } + ], + }, + ) + _install_runtime_manager(monkeypatch, supervisor, plugin_id) - supervisor._mirror_runtime_actions_to_core_registry(payload) - assert core_component_registry.get_action_info(action_name) is not None + async def fake_invoke_plugin( + method: str, + plugin_id: str, + component_name: str, + args: dict[str, Any] | None = None, + timeout_ms: int = 30000, + ) -> Any: + """模拟工具 RPC 调用。""" - supervisor._clear_runner_state() + del timeout_ms + assert method == "plugin.invoke_tool" + assert plugin_id == "runtime_tool_bridge_plugin" + assert component_name == "runtime_tool_bridge_test" + assert args == {"query": "MaiBot"} + return SimpleNamespace(payload={"success": True, "result": {"content": "tool ok"}}) - assert core_component_registry.get_action_info(action_name) is None + monkeypatch.setattr(supervisor, "invoke_plugin", fake_invoke_plugin) + + tool_info = component_query_service.get_tool_info(tool_name) + assert isinstance(tool_info, ToolInfo) + assert tool_info.tool_description == "测试工具" + assert tool_name in component_query_service.get_llm_available_tools() + + executor = component_query_service.get_tool_executor(tool_name) + assert executor is not None + assert await executor({"query": "MaiBot"}) == {"content": "tool ok"} diff --git a/src/chat/brain_chat/brain_planner.py b/src/chat/brain_chat/brain_planner.py index 12b103a0..709be8ee 100644 --- a/src/chat/brain_chat/brain_planner.py +++ b/src/chat/brain_chat/brain_planner.py @@ -1,30 +1,32 @@ +from datetime import datetime +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple + import json -import time -import traceback import random import re -from typing import Dict, Optional, Tuple, List, TYPE_CHECKING -from rich.traceback import install -from datetime import datetime -from json_repair import repair_json +import time +import traceback + +from json_repair import repair_json +from rich.traceback import install -from src.llm_models.utils_model import LLMRequest -from src.config.config import global_config, model_config -from src.common.logger import get_logger from src.chat.logger.plan_reply_logger import PlanReplyLogger +from src.chat.message_receive.chat_manager import chat_manager as _chat_manager +from src.chat.planner_actions.action_manager import ActionManager +from src.chat.utils.utils import get_chat_type_and_target_info from src.common.data_models.info_data_model import ActionPlannerInfo +from src.common.logger import get_logger from src.common.utils.utils_action import ActionUtils +from src.config.config import global_config, model_config +from src.core.types import ActionActivationType, ActionInfo, ComponentType +from src.llm_models.utils_model import LLMRequest +from src.plugin_runtime.component_query import component_query_service from src.prompt.prompt_manager import prompt_manager from src.services.message_service import ( build_readable_messages_with_id, get_actions_by_timestamp_with_chat, get_messages_before_time_in_chat, ) -from src.chat.utils.utils import get_chat_type_and_target_info -from src.chat.planner_actions.action_manager import ActionManager -from src.chat.message_receive.chat_manager import chat_manager as _chat_manager -from src.core.types import ActionActivationType, ActionInfo, ComponentType -from src.core.component_registry import component_registry if TYPE_CHECKING: from src.common.data_models.info_data_model import TargetPersonInfo @@ -320,7 +322,7 @@ class BrainPlanner: current_available_actions_dict = self.action_manager.get_using_actions() # 获取完整的动作信息 - all_registered_actions: Dict[str, ActionInfo] = component_registry.get_components_by_type( # type: ignore + all_registered_actions: Dict[str, ActionInfo] = component_query_service.get_components_by_type( # type: ignore ComponentType.ACTION ) current_available_actions = {} diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py index 23e7de6e..025150fc 100644 --- a/src/chat/message_receive/bot.py +++ b/src/chat/message_receive/bot.py @@ -1,19 +1,19 @@ from contextlib import suppress -import traceback -import os - -from maim_message import MessageBase from typing import Any, Dict, Optional +import os +import traceback +from maim_message import MessageBase + +from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver from src.common.logger import get_logger from src.common.utils.utils_message import MessageUtils from src.common.utils.utils_session import SessionUtils -from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver # from src.chat.brain_chat.PFC.pfc_manager import PFCManager from src.core.announcement_manager import global_announcement_manager -from src.core.component_registry import component_registry +from src.plugin_runtime.component_query import component_query_service from .message import SessionMessage from .chat_manager import chat_manager @@ -58,16 +58,22 @@ class ChatBot: logger.error(f"创建PFC聊天失败: {e}") logger.error(traceback.format_exc()) - async def _process_commands(self, message: SessionMessage): - # sourcery skip: use-named-expression - """使用新插件系统处理命令""" + async def _process_commands(self, message: SessionMessage) -> tuple[bool, Optional[str], bool]: + """使用统一组件注册表处理命令。 + + Args: + message: 当前待处理的会话消息。 + + Returns: + tuple[bool, Optional[str], bool]: ``(是否命中命令, 命令响应文本, 是否继续后续处理)``。 + """ if not message.processed_plain_text: return False, None, True # 没有文本内容,继续处理消息 try: text = message.processed_plain_text - # 使用核心组件注册表查找命令 - command_result = component_registry.find_command_by_text(text) + # 使用插件运行时统一查询服务查找命令 + command_result = component_query_service.find_command_by_text(text) if command_result: command_executor, matched_groups, command_info = command_result plugin_name = command_info.plugin_name @@ -81,7 +87,7 @@ class ChatBot: message.is_command = True # 获取插件配置 - plugin_config = component_registry.get_plugin_config(plugin_name) + plugin_config = component_query_service.get_plugin_config(plugin_name) try: # 调用命令执行器 @@ -112,88 +118,32 @@ class ChatBot: # 命令出错时,根据命令的拦截设置决定是否继续处理消息 return True, str(e), False # 出错时继续处理消息 - # 没有找到旧系统命令,尝试新版本插件运行时 - new_cmd_result = await self._process_new_runtime_command(message) - return new_cmd_result if new_cmd_result is not None else (False, None, True) + return False, None, True except Exception as e: logger.error(f"处理命令时出错: {e}") return False, None, True # 出错时继续处理消息 - async def _process_new_runtime_command(self, message: SessionMessage): - """尝试在新版本插件运行时中查找并执行命令 - - Returns: - (found, response, continue_processing) 三元组, - 或 None 表示新运行时中也未找到匹配命令。 - """ - from src.plugin_runtime.integration import get_plugin_runtime_manager - - prm = get_plugin_runtime_manager() - if not prm.is_running: - return None - - matched = prm.find_command_by_text(message.processed_plain_text) - if matched is None: - return None - - command_name = matched["name"] - if message.session_id and command_name in global_announcement_manager.get_disabled_chat_commands( - message.session_id - ): - logger.info(f"[新运行时] 用户禁用的命令,跳过处理: {matched['full_name']}") - return False, None, True - - message.is_command = True - logger.info(f"[新运行时] 匹配命令: {matched['full_name']}") - - try: - resp = await prm.invoke_plugin( - method="plugin.invoke_command", - plugin_id=matched["plugin_id"], - component_name=matched["name"], - args={ - "text": message.processed_plain_text, - "stream_id": message.session_id or "", - "matched_groups": matched.get("matched_groups") or {}, - }, - timeout_ms=30000, - ) - - payload = resp.payload - success = payload.get("success", False) - cmd_result = payload.get("result") - - # 拦截位优先从命令返回值中获取(支持运行时动态决定), - # 回退到组件 metadata 中的静态声明 - if isinstance(cmd_result, (list, tuple)) and len(cmd_result) >= 3: - # 命令返回 (found, response_text, intercept_bool) 三元组 - response_text = cmd_result[1] if cmd_result[1] is not None else "" - intercept = bool(cmd_result[2]) - else: - response_text = cmd_result if cmd_result is not None else "" - intercept = bool(matched["metadata"].get("intercept_message_level", 0)) - - self._mark_command_message(message, int(intercept)) - - if success: - logger.info(f"[新运行时] 命令执行成功: {matched['full_name']}") - else: - logger.warning(f"[新运行时] 命令执行失败: {matched['full_name']} - {response_text}") - - return True, response_text, not intercept - - except Exception as e: - logger.error(f"[新运行时] 执行命令 {matched['full_name']} 异常: {e}", exc_info=True) - return True, str(e), True - @staticmethod def _mark_command_message(message: SessionMessage, intercept_message_level: int) -> None: + """标记消息已经被命令链消费。 + + Args: + message: 待标记的会话消息。 + intercept_message_level: 命令设置的拦截级别。 + """ + message.is_command = True message.message_info.additional_config["intercept_message_level"] = intercept_message_level @staticmethod def _store_intercepted_command_message(message: SessionMessage) -> None: + """将被命令链拦截的消息写入数据库。 + + Args: + message: 已完成命令处理的会话消息。 + """ + MessageUtils.store_message_to_db(message) async def _handle_command_processing_result( diff --git a/src/chat/planner_actions/action_manager.py b/src/chat/planner_actions/action_manager.py index 167cdcab..8133ac18 100644 --- a/src/chat/planner_actions/action_manager.py +++ b/src/chat/planner_actions/action_manager.py @@ -3,8 +3,8 @@ from typing import Dict, Optional, Tuple from src.chat.message_receive.chat_manager import BotChatSession from src.chat.message_receive.message import SessionMessage from src.common.logger import get_logger -from src.core.component_registry import component_registry, ActionExecutor from src.core.types import ActionInfo +from src.plugin_runtime.component_query import ActionExecutor, component_query_service logger = get_logger("action_manager") @@ -28,7 +28,7 @@ class ActionManager: """ 动作管理器,用于管理各种类型的动作 - 使用核心组件注册表的 executor-based 模式。 + 使用插件运行时统一查询服务的 executor-based 模式。 """ def __init__(self): @@ -38,7 +38,7 @@ class ActionManager: self._using_actions: Dict[str, ActionInfo] = {} # 初始化时将默认动作加载到使用中的动作 - self._using_actions = component_registry.get_default_actions() + self._using_actions = component_query_service.get_default_actions() # === 执行Action方法 === @@ -72,17 +72,17 @@ class ActionManager: Optional[ActionHandle]: 执行句柄,如果动作未注册则返回 None """ try: - executor = component_registry.get_action_executor(action_name) + executor = component_query_service.get_action_executor(action_name) if not executor: logger.warning(f"{log_prefix} 未找到Action组件: {action_name}") return None - info = component_registry.get_action_info(action_name) + info = component_query_service.get_action_info(action_name) if not info: logger.warning(f"{log_prefix} 未找到Action组件信息: {action_name}") return None - plugin_config = component_registry.get_plugin_config(info.plugin_name) or {} + plugin_config = component_query_service.get_plugin_config(info.plugin_name) or {} handle = ActionHandle( executor, @@ -133,5 +133,5 @@ class ActionManager: def restore_actions(self) -> None: """恢复到默认动作集""" actions_to_restore = list(self._using_actions.keys()) - self._using_actions = component_registry.get_default_actions() + self._using_actions = component_query_service.get_default_actions() logger.debug(f"恢复动作集: 从 {actions_to_restore} 恢复到默认动作集 {list(self._using_actions.keys())}") diff --git a/src/chat/planner_actions/planner.py b/src/chat/planner_actions/planner.py index 5184abcb..b21efa6b 100644 --- a/src/chat/planner_actions/planner.py +++ b/src/chat/planner_actions/planner.py @@ -1,33 +1,36 @@ +from collections import OrderedDict +from datetime import datetime +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union + +import contextlib import json -import time -import traceback import random import re -import contextlib -from typing import Dict, Optional, Tuple, List, TYPE_CHECKING, Union -from collections import OrderedDict -from rich.traceback import install -from datetime import datetime +import time +import traceback + from json_repair import repair_json -from src.llm_models.utils_model import LLMRequest -from src.config.config import global_config, model_config -from src.common.logger import get_logger +from rich.traceback import install + from src.chat.logger.plan_reply_logger import PlanReplyLogger +from src.chat.message_receive.chat_manager import chat_manager as _chat_manager +from src.chat.message_receive.message import SessionMessage +from src.chat.planner_actions.action_manager import ActionManager +from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self from src.common.data_models.info_data_model import ActionPlannerInfo +from src.common.logger import get_logger +from src.config.config import global_config, model_config +from src.core.types import ActionActivationType, ActionInfo, ComponentType +from src.llm_models.utils_model import LLMRequest +from src.person_info.person_info import Person +from src.plugin_runtime.component_query import component_query_service from src.prompt.prompt_manager import prompt_manager from src.services.message_service import ( build_readable_messages_with_id, - replace_user_references, get_messages_before_time_in_chat, + replace_user_references, translate_pid_to_description, ) -from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self -from src.chat.planner_actions.action_manager import ActionManager -from src.chat.message_receive.chat_manager import chat_manager as _chat_manager -from src.chat.message_receive.message import SessionMessage -from src.core.types import ActionActivationType, ActionInfo, ComponentType -from src.core.component_registry import component_registry -from src.person_info.person_info import Person if TYPE_CHECKING: from src.common.data_models.info_data_model import TargetPersonInfo @@ -634,7 +637,7 @@ class ActionPlanner: current_available_actions_dict = self.action_manager.get_using_actions() # 获取完整的动作信息 - all_registered_actions: Dict[str, ActionInfo] = component_registry.get_components_by_type( # type: ignore + all_registered_actions: Dict[str, ActionInfo] = component_query_service.get_components_by_type( # type: ignore ComponentType.ACTION ) current_available_actions = {} diff --git a/src/chat/tool_executor.py b/src/chat/tool_executor.py index d449f7a1..aa99fce8 100644 --- a/src/chat/tool_executor.py +++ b/src/chat/tool_executor.py @@ -1,22 +1,20 @@ -""" -工具执行器 +"""工具执行器。 独立的工具执行组件,可以直接输入聊天消息内容, 自动判断并执行相应的工具,返回结构化的工具执行结果。 - -从 src.plugin_system.core.tool_use 迁移,使用新的核心组件注册表。 """ +from typing import Any, Dict, List, Optional, Tuple + import hashlib import time -from typing import Any, Dict, List, Optional, Tuple from src.common.logger import get_logger from src.config.config import global_config, model_config from src.core.announcement_manager import global_announcement_manager -from src.core.component_registry import component_registry from src.llm_models.payload_content import ToolCall from src.llm_models.utils_model import LLMRequest +from src.plugin_runtime.component_query import component_query_service from src.prompt.prompt_manager import prompt_manager logger = get_logger("tool_use") @@ -89,7 +87,7 @@ class ToolExecutor: def _get_tool_definitions(self) -> List[Dict[str, Any]]: """获取 LLM 可用的工具定义列表""" - all_tools = component_registry.get_llm_available_tools() + all_tools = component_query_service.get_llm_available_tools() user_disabled_tools = global_announcement_manager.get_disabled_chat_tools(self.chat_id) return [info.get_llm_definition() for name, info in all_tools.items() if name not in user_disabled_tools] @@ -152,7 +150,7 @@ class ToolExecutor: function_args = tool_call.args or {} function_args["llm_called"] = True - executor = component_registry.get_tool_executor(function_name) + executor = component_query_service.get_tool_executor(function_name) if not executor: logger.warning(f"未知工具名称: {function_name}") return None diff --git a/src/core/component_registry.py b/src/core/component_registry.py deleted file mode 100644 index bb58682a..00000000 --- a/src/core/component_registry.py +++ /dev/null @@ -1,239 +0,0 @@ -""" -核心组件注册表 - -面向最终架构的组件管理: -- Action:注册 ActionInfo + 执行器(本地 callable 或 IPC 路由) -- Command:注册正则模式 + 执行器 -- Tool:注册工具定义 + 执行器 - -不依赖任何插件基类,组件执行器是纯 async callable。 -""" - -import re -from typing import Any, Awaitable, Callable, Dict, Optional, Pattern, Tuple - -from src.common.logger import get_logger -from src.core.types import ( - ActionInfo, - CommandInfo, - ComponentInfo, - ComponentType, - ToolInfo, -) - -logger = get_logger("component_registry") - -# 执行器类型 -ActionExecutor = Callable[..., Awaitable[Any]] -CommandExecutor = Callable[..., Awaitable[Tuple[bool, Optional[str], bool]]] -ToolExecutor = Callable[..., Awaitable[Any]] - - -class ComponentRegistry: - """核心组件注册表 - - 管理 action、command、tool 三类组件。 - 每个组件由「元信息 + 执行器」构成,执行器是 async callable, - 不需要继承任何基类。 - """ - - def __init__(self): - # Action 注册 - self._actions: Dict[str, ActionInfo] = {} - self._action_executors: Dict[str, ActionExecutor] = {} - self._default_actions: Dict[str, ActionInfo] = {} - - # Command 注册 - self._commands: Dict[str, CommandInfo] = {} - self._command_executors: Dict[str, CommandExecutor] = {} - self._command_patterns: Dict[Pattern, str] = {} - - # Tool 注册 - self._tools: Dict[str, ToolInfo] = {} - self._tool_executors: Dict[str, ToolExecutor] = {} - self._llm_available_tools: Dict[str, ToolInfo] = {} - - # 插件配置(plugin_name -> config dict) - self._plugin_configs: Dict[str, dict] = {} - - logger.info("核心组件注册表初始化完成") - - # ========== Action ========== - - def register_action( - self, - info: ActionInfo, - executor: ActionExecutor, - ) -> bool: - """注册 action - - Args: - info: action 元信息 - executor: 执行器,async callable - """ - name = info.name - if name in self._actions: - logger.warning(f"Action {name} 已存在,跳过注册") - return False - - self._actions[name] = info - self._action_executors[name] = executor - - if info.enabled: - self._default_actions[name] = info - - logger.debug(f"注册 Action: {name}") - return True - - def get_action_info(self, name: str) -> Optional[ActionInfo]: - return self._actions.get(name) - - def get_action_executor(self, name: str) -> Optional[ActionExecutor]: - return self._action_executors.get(name) - - def get_default_actions(self) -> Dict[str, ActionInfo]: - return self._default_actions.copy() - - def get_all_actions(self) -> Dict[str, ActionInfo]: - return self._actions.copy() - - def remove_action(self, name: str) -> bool: - if name not in self._actions: - return False - del self._actions[name] - self._action_executors.pop(name, None) - self._default_actions.pop(name, None) - logger.debug(f"移除 Action: {name}") - return True - - # ========== Command ========== - - def register_command( - self, - info: CommandInfo, - executor: CommandExecutor, - ) -> bool: - """注册 command""" - name = info.name - if name in self._commands: - logger.warning(f"Command {name} 已存在,跳过注册") - return False - - self._commands[name] = info - self._command_executors[name] = executor - - if info.enabled and info.command_pattern: - pattern = re.compile(info.command_pattern, re.IGNORECASE | re.DOTALL) - self._command_patterns[pattern] = name - - logger.debug(f"注册 Command: {name}") - return True - - def find_command_by_text(self, text: str) -> Optional[Tuple[CommandExecutor, dict, CommandInfo]]: - """根据文本查找匹配的命令 - - Returns: - (executor, matched_groups, command_info) 或 None - """ - candidates = [p for p in self._command_patterns if p.match(text)] - if not candidates: - return None - if len(candidates) > 1: - logger.warning(f"文本 '{text[:50]}' 匹配到多个命令模式,使用第一个") - pattern = candidates[0] - name = self._command_patterns[pattern] - return ( - self._command_executors[name], - pattern.match(text).groupdict(), # type: ignore - self._commands[name], - ) - - def remove_command(self, name: str) -> bool: - if name not in self._commands: - return False - del self._commands[name] - self._command_executors.pop(name, None) - self._command_patterns = {k: v for k, v in self._command_patterns.items() if v != name} - logger.debug(f"移除 Command: {name}") - return True - - # ========== Tool ========== - - def register_tool( - self, - info: ToolInfo, - executor: ToolExecutor, - ) -> bool: - """注册 tool""" - name = info.name - if name in self._tools: - logger.warning(f"Tool {name} 已存在,跳过注册") - return False - - self._tools[name] = info - self._tool_executors[name] = executor - - if info.enabled: - self._llm_available_tools[name] = info - - logger.debug(f"注册 Tool: {name}") - return True - - def get_tool_info(self, name: str) -> Optional[ToolInfo]: - return self._tools.get(name) - - def get_tool_executor(self, name: str) -> Optional[ToolExecutor]: - return self._tool_executors.get(name) - - def get_llm_available_tools(self) -> Dict[str, ToolInfo]: - return self._llm_available_tools.copy() - - def get_all_tools(self) -> Dict[str, ToolInfo]: - return self._tools.copy() - - def remove_tool(self, name: str) -> bool: - if name not in self._tools: - return False - del self._tools[name] - self._tool_executors.pop(name, None) - self._llm_available_tools.pop(name, None) - logger.debug(f"移除 Tool: {name}") - return True - - # ========== 通用查询 ========== - - def get_component_info(self, name: str, component_type: ComponentType) -> Optional[ComponentInfo]: - """获取组件元信息""" - match component_type: - case ComponentType.ACTION: - return self._actions.get(name) - case ComponentType.COMMAND: - return self._commands.get(name) - case ComponentType.TOOL: - return self._tools.get(name) - case _: - return None - - def get_components_by_type(self, component_type: ComponentType) -> Dict[str, ComponentInfo]: - """获取某类型的所有组件""" - match component_type: - case ComponentType.ACTION: - return dict(self._actions) - case ComponentType.COMMAND: - return dict(self._commands) - case ComponentType.TOOL: - return dict(self._tools) - case _: - return {} - - # ========== 插件配置 ========== - - def set_plugin_config(self, plugin_name: str, config: dict) -> None: - self._plugin_configs[plugin_name] = config - - def get_plugin_config(self, plugin_name: str) -> Optional[dict]: - return self._plugin_configs.get(plugin_name) - - -# 全局单例 -component_registry = ComponentRegistry() diff --git a/src/platform_io/manager.py b/src/platform_io/manager.py index cb5996b4..be03e35d 100644 --- a/src/platform_io/manager.py +++ b/src/platform_io/manager.py @@ -178,12 +178,6 @@ class PlatformIOManager: return self._receive_route_table - @property - def route_table(self) -> RouteTable: - """兼容旧接口,返回发送路由表。""" - - return self._send_route_table - @property def deduplicator(self) -> MessageDeduplicator: """返回管理器持有的入站去重器。 @@ -369,12 +363,6 @@ class PlatformIOManager: self._validate_binding_against_driver(binding, driver) self._receive_route_table.bind(binding) - def bind_route(self, binding: RouteBinding) -> None: - """兼容旧接口,默认同时绑定发送表和接收表。""" - - self.bind_send_route(binding) - self.bind_receive_route(binding) - def unbind_send_route(self, route_key: RouteKey, driver_id: Optional[str] = None) -> None: """移除发送路由绑定。 @@ -395,12 +383,6 @@ class PlatformIOManager: self._receive_route_table.unbind(route_key, driver_id) - def unbind_route(self, route_key: RouteKey, driver_id: Optional[str] = None) -> None: - """兼容旧接口,默认同时从发送表和接收表解绑。""" - - self.unbind_send_route(route_key, driver_id) - self.unbind_receive_route(route_key, driver_id) - def resolve_drivers(self, route_key: RouteKey) -> List[PlatformIODriver]: """解析某个路由键当前命中的全部发送驱动。 @@ -430,12 +412,6 @@ class PlatformIOManager: return [] return [fallback_driver] - def resolve_driver(self, route_key: RouteKey) -> Optional[PlatformIODriver]: - """兼容旧接口,返回首个命中的发送驱动。""" - - drivers = self.resolve_drivers(route_key) - return drivers[0] if drivers else None - @staticmethod def build_route_key_from_message(message: "SessionMessage") -> RouteKey: """根据 ``SessionMessage`` 构造路由键。 diff --git a/src/plugin_runtime/capabilities/core.py b/src/plugin_runtime/capabilities/core.py index def5f03d..9bb1755b 100644 --- a/src/plugin_runtime/capabilities/core.py +++ b/src/plugin_runtime/capabilities/core.py @@ -238,14 +238,14 @@ class RuntimeCoreCapabilityMixin: return {"success": False, "value": None, "error": str(e)} async def _cap_config_get_plugin(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - from src.core.component_registry import component_registry as core_registry + from src.plugin_runtime.component_query import component_query_service plugin_name: str = args.get("plugin_name", plugin_id) key: str = args.get("key", "") default = args.get("default") try: - config = core_registry.get_plugin_config(plugin_name) + config = component_query_service.get_plugin_config(plugin_name) if config is None: return {"success": False, "value": default, "error": f"未找到插件 {plugin_name} 的配置"} @@ -258,11 +258,11 @@ class RuntimeCoreCapabilityMixin: return {"success": False, "value": default, "error": str(e)} async def _cap_config_get_all(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - from src.core.component_registry import component_registry as core_registry + from src.plugin_runtime.component_query import component_query_service plugin_name: str = args.get("plugin_name", plugin_id) try: - config = core_registry.get_plugin_config(plugin_name) + config = component_query_service.get_plugin_config(plugin_name) if config is None: return {"success": True, "value": {}} return {"success": True, "value": config} diff --git a/src/plugin_runtime/capabilities/data.py b/src/plugin_runtime/capabilities/data.py index c4ae0a56..fdf8d898 100644 --- a/src/plugin_runtime/capabilities/data.py +++ b/src/plugin_runtime/capabilities/data.py @@ -648,10 +648,10 @@ class RuntimeDataCapabilityMixin: return {"success": False, "error": str(e)} async def _cap_tool_get_definitions(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - from src.core.component_registry import component_registry as core_registry + from src.plugin_runtime.component_query import component_query_service try: - tools = core_registry.get_llm_available_tools() + tools = component_query_service.get_llm_available_tools() return { "success": True, "tools": [{"name": name, "definition": info.get_llm_definition()} for name, info in tools.items()], diff --git a/src/plugin_runtime/component_query.py b/src/plugin_runtime/component_query.py new file mode 100644 index 00000000..7d23d202 --- /dev/null +++ b/src/plugin_runtime/component_query.py @@ -0,0 +1,709 @@ +"""插件运行时统一组件查询服务。 + +该模块统一从插件运行时的 Host ComponentRegistry 中聚合只读视图, +供 HFC/PFC、Planner、ToolExecutor 和运行时能力层查询与调用。 +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional, Tuple + +from src.common.logger import get_logger +from src.core.types import ActionActivationType, ActionInfo, CommandInfo, ComponentInfo, ComponentType, ToolInfo +from src.llm_models.payload_content.tool_option import ToolParamType + +if TYPE_CHECKING: + from src.plugin_runtime.host.component_registry import ActionEntry, CommandEntry, ComponentEntry, ToolEntry + from src.plugin_runtime.host.supervisor import PluginSupervisor + from src.plugin_runtime.integration import PluginRuntimeManager + +logger = get_logger("plugin_runtime.component_query") + +ActionExecutor = Callable[..., Awaitable[Any]] +CommandExecutor = Callable[..., Awaitable[Tuple[bool, Optional[str], bool]]] +ToolExecutor = Callable[..., Awaitable[Any]] + +_HOST_COMPONENT_TYPE_MAP: Dict[ComponentType, str] = { + ComponentType.ACTION: "ACTION", + ComponentType.COMMAND: "COMMAND", + ComponentType.TOOL: "TOOL", +} +_TOOL_PARAM_TYPE_MAP: Dict[str, ToolParamType] = { + "string": ToolParamType.STRING, + "integer": ToolParamType.INTEGER, + "float": ToolParamType.FLOAT, + "boolean": ToolParamType.BOOLEAN, + "bool": ToolParamType.BOOLEAN, +} + + +class ComponentQueryService: + """插件运行时统一组件查询服务。 + + 该对象不维护独立状态,只读取插件系统中的注册结果。 + 所有注册、删除、配置写入等写操作都被显式禁用。 + """ + + @staticmethod + def _get_runtime_manager() -> "PluginRuntimeManager": + """获取插件运行时管理器单例。 + + Returns: + PluginRuntimeManager: 当前全局插件运行时管理器。 + """ + + from src.plugin_runtime.integration import get_plugin_runtime_manager + + return get_plugin_runtime_manager() + + def _iter_supervisors(self) -> list["PluginSupervisor"]: + """获取当前所有活跃的插件运行时监督器。 + + Returns: + list[PluginSupervisor]: 当前运行中的监督器列表。 + """ + + runtime_manager = self._get_runtime_manager() + return list(runtime_manager.supervisors) + + def _iter_component_entries( + self, + component_type: ComponentType, + *, + enabled_only: bool = True, + ) -> list[tuple["PluginSupervisor", "ComponentEntry"]]: + """遍历指定类型的全部组件条目。 + + Args: + component_type: 目标组件类型。 + enabled_only: 是否仅返回启用状态的组件。 + + Returns: + list[tuple[PluginSupervisor, ComponentEntry]]: ``(监督器, 组件条目)`` 列表。 + """ + + host_component_type = _HOST_COMPONENT_TYPE_MAP.get(component_type) + if host_component_type is None: + return [] + + collected_entries: list[tuple["PluginSupervisor", "ComponentEntry"]] = [] + for supervisor in self._iter_supervisors(): + for component in supervisor.component_registry.get_components_by_type( + host_component_type, + enabled_only=enabled_only, + ): + collected_entries.append((supervisor, component)) + return collected_entries + + @staticmethod + def _coerce_action_activation_type(raw_value: Any) -> ActionActivationType: + """规范化动作激活类型。 + + Args: + raw_value: 原始激活类型值。 + + Returns: + ActionActivationType: 规范化后的激活类型枚举。 + """ + + normalized_value = str(raw_value or "").strip().lower() + if normalized_value == ActionActivationType.NEVER.value: + return ActionActivationType.NEVER + if normalized_value == ActionActivationType.RANDOM.value: + return ActionActivationType.RANDOM + if normalized_value == ActionActivationType.KEYWORD.value: + return ActionActivationType.KEYWORD + return ActionActivationType.ALWAYS + + @staticmethod + def _coerce_float(value: Any, default: float = 0.0) -> float: + """将任意值安全转换为浮点数。 + + Args: + value: 待转换的输入值。 + default: 转换失败时返回的默认值。 + + Returns: + float: 转换后的浮点结果。 + """ + + try: + return float(value) + except (TypeError, ValueError): + return default + + @staticmethod + def _build_action_info(entry: "ActionEntry") -> ActionInfo: + """将运行时 Action 条目转换为核心动作信息。 + + Args: + entry: 插件运行时中的 Action 条目。 + + Returns: + ActionInfo: 供核心 Planner 使用的动作信息。 + """ + + metadata = dict(entry.metadata) + raw_action_parameters = metadata.get("action_parameters") + action_parameters = ( + { + str(param_name): str(param_description) + for param_name, param_description in raw_action_parameters.items() + } + if isinstance(raw_action_parameters, dict) + else {} + ) + action_require = [ + str(item) + for item in (metadata.get("action_require") or []) + if item is not None and str(item).strip() + ] + associated_types = [ + str(item) + for item in (metadata.get("associated_types") or []) + if item is not None and str(item).strip() + ] + activation_keywords = [ + str(item) + for item in (metadata.get("activation_keywords") or []) + if item is not None and str(item).strip() + ] + + return ActionInfo( + name=entry.name, + component_type=ComponentType.ACTION, + description=str(metadata.get("description", "") or ""), + enabled=bool(entry.enabled), + plugin_name=entry.plugin_id, + metadata=metadata, + action_parameters=action_parameters, + action_require=action_require, + associated_types=associated_types, + activation_type=ComponentQueryService._coerce_action_activation_type(metadata.get("activation_type")), + random_activation_probability=ComponentQueryService._coerce_float( + metadata.get("activation_probability"), + 0.0, + ), + activation_keywords=activation_keywords, + parallel_action=bool(metadata.get("parallel_action", False)), + ) + + @staticmethod + def _build_command_info(entry: "CommandEntry") -> CommandInfo: + """将运行时 Command 条目转换为核心命令信息。 + + Args: + entry: 插件运行时中的 Command 条目。 + + Returns: + CommandInfo: 供核心命令链使用的命令信息。 + """ + + metadata = dict(entry.metadata) + return CommandInfo( + name=entry.name, + component_type=ComponentType.COMMAND, + description=str(metadata.get("description", "") or ""), + enabled=bool(entry.enabled), + plugin_name=entry.plugin_id, + metadata=metadata, + command_pattern=str(metadata.get("command_pattern", "") or ""), + ) + + @staticmethod + def _coerce_tool_param_type(raw_value: Any) -> ToolParamType: + """规范化工具参数类型。 + + Args: + raw_value: 原始工具参数类型值。 + + Returns: + ToolParamType: 规范化后的工具参数类型。 + """ + + normalized_value = str(raw_value or "").strip().lower() + return _TOOL_PARAM_TYPE_MAP.get(normalized_value, ToolParamType.STRING) + + @staticmethod + def _build_tool_parameters(entry: "ToolEntry") -> list[tuple[str, ToolParamType, str, bool, list[str] | None]]: + """将运行时工具参数元数据转换为核心 ToolInfo 参数列表。 + + Args: + entry: 插件运行时中的 Tool 条目。 + + Returns: + list[tuple[str, ToolParamType, str, bool, list[str] | None]]: 转换后的参数列表。 + """ + + structured_parameters = entry.parameters if isinstance(entry.parameters, list) else [] + if not structured_parameters and isinstance(entry.parameters_raw, dict): + structured_parameters = [ + {"name": key, **value} + for key, value in entry.parameters_raw.items() + if isinstance(value, dict) + ] + + normalized_parameters: list[tuple[str, ToolParamType, str, bool, list[str] | None]] = [] + for parameter in structured_parameters: + if not isinstance(parameter, dict): + continue + + parameter_name = str(parameter.get("name", "") or "").strip() + if not parameter_name: + continue + + enum_values = parameter.get("enum") + normalized_enum_values = ( + [str(item) for item in enum_values if item is not None] + if isinstance(enum_values, list) + else None + ) + normalized_parameters.append( + ( + parameter_name, + ComponentQueryService._coerce_tool_param_type(parameter.get("param_type") or parameter.get("type")), + str(parameter.get("description", "") or ""), + bool(parameter.get("required", True)), + normalized_enum_values, + ) + ) + return normalized_parameters + + @staticmethod + def _build_tool_info(entry: "ToolEntry") -> ToolInfo: + """将运行时 Tool 条目转换为核心工具信息。 + + Args: + entry: 插件运行时中的 Tool 条目。 + + Returns: + ToolInfo: 供 ToolExecutor 与能力层使用的工具信息。 + """ + + return ToolInfo( + name=entry.name, + component_type=ComponentType.TOOL, + description=entry.description, + enabled=bool(entry.enabled), + plugin_name=entry.plugin_id, + metadata=dict(entry.metadata), + tool_parameters=ComponentQueryService._build_tool_parameters(entry), + tool_description=entry.description, + ) + + @staticmethod + def _log_duplicate_component(component_type: ComponentType, component_name: str) -> None: + """记录重复组件名称冲突。 + + Args: + component_type: 组件类型。 + component_name: 发生冲突的组件名称。 + """ + + logger.warning(f"检测到重复{component_type.value}名称 {component_name},将只保留首个匹配项") + + def _get_unique_component_entry( + self, + component_type: ComponentType, + name: str, + ) -> Optional[tuple["PluginSupervisor", "ComponentEntry"]]: + """按组件短名解析唯一条目。 + + Args: + component_type: 目标组件类型。 + name: 组件短名。 + + Returns: + Optional[tuple[PluginSupervisor, ComponentEntry]]: 唯一命中的组件条目。 + """ + + matched_entries = [ + (supervisor, entry) + for supervisor, entry in self._iter_component_entries(component_type) + if entry.name == name + ] + if not matched_entries: + return None + if len(matched_entries) > 1: + self._log_duplicate_component(component_type, name) + return matched_entries[0] + + def _collect_unique_component_infos( + self, + component_type: ComponentType, + ) -> Dict[str, ComponentInfo]: + """收集某类组件的唯一信息视图。 + + Args: + component_type: 目标组件类型。 + + Returns: + Dict[str, ComponentInfo]: 组件名到核心组件信息的映射。 + """ + + collected_components: Dict[str, ComponentInfo] = {} + for _supervisor, entry in self._iter_component_entries(component_type): + if entry.name in collected_components: + self._log_duplicate_component(component_type, entry.name) + continue + + if component_type == ComponentType.ACTION: + collected_components[entry.name] = self._build_action_info(entry) # type: ignore[arg-type] + elif component_type == ComponentType.COMMAND: + collected_components[entry.name] = self._build_command_info(entry) # type: ignore[arg-type] + elif component_type == ComponentType.TOOL: + collected_components[entry.name] = self._build_tool_info(entry) # type: ignore[arg-type] + return collected_components + + @staticmethod + def _extract_stream_id_from_action_kwargs(kwargs: Dict[str, Any]) -> str: + """从旧 ActionManager 参数中提取聊天流 ID。 + + Args: + kwargs: 旧动作执行器收到的关键字参数。 + + Returns: + str: 提取出的 ``stream_id``。 + """ + + chat_stream = kwargs.get("chat_stream") + if chat_stream is not None: + try: + return str(chat_stream.session_id) + except AttributeError: + pass + + return str(kwargs.get("stream_id", "") or "") + + @staticmethod + def _build_action_executor(supervisor: "PluginSupervisor", plugin_id: str, component_name: str) -> ActionExecutor: + """构造动作执行 RPC 闭包。 + + Args: + supervisor: 负责该组件的监督器。 + plugin_id: 插件 ID。 + component_name: 组件名称。 + + Returns: + ActionExecutor: 兼容旧 Planner 的异步执行器。 + """ + + async def _executor(**kwargs: Any) -> tuple[bool, str]: + """将核心动作调用桥接到插件运行时。 + + Args: + **kwargs: 旧 ActionManager 传入的上下文参数。 + + Returns: + tuple[bool, str]: ``(是否成功, 结果说明)``。 + """ + + invoke_args: Dict[str, Any] = {} + action_data = kwargs.get("action_data") + if isinstance(action_data, dict): + invoke_args.update(action_data) + + stream_id = ComponentQueryService._extract_stream_id_from_action_kwargs(kwargs) + invoke_args["action_data"] = action_data if isinstance(action_data, dict) else {} + invoke_args["stream_id"] = stream_id + invoke_args["chat_id"] = stream_id + invoke_args["reasoning"] = str(kwargs.get("action_reasoning", "") or "") + + if (thinking_id := kwargs.get("thinking_id")) is not None: + invoke_args["thinking_id"] = str(thinking_id) + if isinstance(kwargs.get("cycle_timers"), dict): + invoke_args["cycle_timers"] = kwargs["cycle_timers"] + if isinstance(kwargs.get("plugin_config"), dict): + invoke_args["plugin_config"] = kwargs["plugin_config"] + if isinstance(kwargs.get("log_prefix"), str): + invoke_args["log_prefix"] = kwargs["log_prefix"] + if isinstance(kwargs.get("shutting_down"), bool): + invoke_args["shutting_down"] = kwargs["shutting_down"] + + try: + response = await supervisor.invoke_plugin( + method="plugin.invoke_action", + plugin_id=plugin_id, + component_name=component_name, + args=invoke_args, + timeout_ms=30000, + ) + except Exception as exc: + logger.error(f"运行时 Action {plugin_id}.{component_name} 执行失败: {exc}", exc_info=True) + return False, str(exc) + + payload = response.payload if isinstance(response.payload, dict) else {} + success = bool(payload.get("success", False)) + result = payload.get("result") + if isinstance(result, (list, tuple)): + if len(result) >= 2: + return bool(result[0]), "" if result[1] is None else str(result[1]) + if len(result) == 1: + return bool(result[0]), "" + if success: + return True, "" if result is None else str(result) + return False, "" if result is None else str(result) + + return _executor + + @staticmethod + def _build_command_executor( + supervisor: "PluginSupervisor", + plugin_id: str, + component_name: str, + metadata: Dict[str, Any], + ) -> CommandExecutor: + """构造命令执行 RPC 闭包。 + + Args: + supervisor: 负责该组件的监督器。 + plugin_id: 插件 ID。 + component_name: 组件名称。 + metadata: 命令组件元数据。 + + Returns: + CommandExecutor: 兼容旧消息命令链的执行器。 + """ + + async def _executor(**kwargs: Any) -> tuple[bool, Optional[str], bool]: + """将核心命令调用桥接到插件运行时。 + + Args: + **kwargs: 命令执行上下文参数。 + + Returns: + tuple[bool, Optional[str], bool]: ``(是否成功, 返回文本, 是否拦截后续消息)``。 + """ + + message = kwargs.get("message") + matched_groups = kwargs.get("matched_groups") + plugin_config = kwargs.get("plugin_config") + invoke_args: Dict[str, Any] = { + "text": str(getattr(message, "processed_plain_text", "") or ""), + "stream_id": str(getattr(message, "session_id", "") or ""), + "matched_groups": matched_groups if isinstance(matched_groups, dict) else {}, + } + if isinstance(plugin_config, dict): + invoke_args["plugin_config"] = plugin_config + + try: + response = await supervisor.invoke_plugin( + method="plugin.invoke_command", + plugin_id=plugin_id, + component_name=component_name, + args=invoke_args, + timeout_ms=30000, + ) + except Exception as exc: + logger.error(f"运行时 Command {plugin_id}.{component_name} 执行失败: {exc}", exc_info=True) + return False, str(exc), True + + payload = response.payload if isinstance(response.payload, dict) else {} + success = bool(payload.get("success", False)) + result = payload.get("result") + intercept = bool(metadata.get("intercept_message_level", 0)) + response_text: Optional[str] + + if isinstance(result, (list, tuple)) and len(result) >= 3: + response_text = None if result[1] is None else str(result[1]) + intercept = bool(result[2]) + else: + response_text = None if result is None else str(result) + + return success, response_text, intercept + + return _executor + + @staticmethod + def _build_tool_executor(supervisor: "PluginSupervisor", plugin_id: str, component_name: str) -> ToolExecutor: + """构造工具执行 RPC 闭包。 + + Args: + supervisor: 负责该组件的监督器。 + plugin_id: 插件 ID。 + component_name: 组件名称。 + + Returns: + ToolExecutor: 兼容旧 ToolExecutor 的异步执行器。 + """ + + async def _executor(function_args: Dict[str, Any]) -> Any: + """将核心工具调用桥接到插件运行时。 + + Args: + function_args: 工具调用参数。 + + Returns: + Any: 插件工具返回结果;若结果不是字典,则会包装为 ``{"content": ...}``。 + """ + + try: + response = await supervisor.invoke_plugin( + method="plugin.invoke_tool", + plugin_id=plugin_id, + component_name=component_name, + args=function_args, + timeout_ms=30000, + ) + except Exception as exc: + logger.error(f"运行时 Tool {plugin_id}.{component_name} 执行失败: {exc}", exc_info=True) + return {"content": f"工具 {component_name} 执行失败: {exc}"} + + payload = response.payload if isinstance(response.payload, dict) else {} + result = payload.get("result") + if isinstance(result, dict): + return result + return {"content": "" if result is None else str(result)} + + return _executor + + def get_action_info(self, name: str) -> Optional[ActionInfo]: + """获取指定动作的信息。 + + Args: + name: 动作名称。 + + Returns: + Optional[ActionInfo]: 匹配到的动作信息。 + """ + + matched_entry = self._get_unique_component_entry(ComponentType.ACTION, name) + if matched_entry is None: + return None + _supervisor, entry = matched_entry + return self._build_action_info(entry) # type: ignore[arg-type] + + def get_action_executor(self, name: str) -> Optional[ActionExecutor]: + """获取指定动作的执行器。 + + Args: + name: 动作名称。 + + Returns: + Optional[ActionExecutor]: 运行时 RPC 执行闭包。 + """ + + matched_entry = self._get_unique_component_entry(ComponentType.ACTION, name) + if matched_entry is None: + return None + supervisor, entry = matched_entry + return self._build_action_executor(supervisor, entry.plugin_id, entry.name) + + def get_default_actions(self) -> Dict[str, ActionInfo]: + """获取当前默认启用的动作集合。 + + Returns: + Dict[str, ActionInfo]: 动作名到动作信息的映射。 + """ + + action_infos = self._collect_unique_component_infos(ComponentType.ACTION) + return {name: info for name, info in action_infos.items() if isinstance(info, ActionInfo) and info.enabled} + + def find_command_by_text(self, text: str) -> Optional[Tuple[CommandExecutor, dict, CommandInfo]]: + """根据文本查找匹配的命令。 + + Args: + text: 待匹配的文本内容。 + + Returns: + Optional[Tuple[CommandExecutor, dict, CommandInfo]]: 匹配结果。 + """ + + for supervisor in self._iter_supervisors(): + match_result = supervisor.component_registry.find_command_by_text(text) + if match_result is None: + continue + + entry, matched_groups = match_result + command_info = self._build_command_info(entry) # type: ignore[arg-type] + command_executor = self._build_command_executor( + supervisor, + entry.plugin_id, + entry.name, + dict(entry.metadata), + ) + return command_executor, matched_groups, command_info + return None + + def get_tool_info(self, name: str) -> Optional[ToolInfo]: + """获取指定工具的信息。 + + Args: + name: 工具名称。 + + Returns: + Optional[ToolInfo]: 匹配到的工具信息。 + """ + + matched_entry = self._get_unique_component_entry(ComponentType.TOOL, name) + if matched_entry is None: + return None + _supervisor, entry = matched_entry + return self._build_tool_info(entry) # type: ignore[arg-type] + + def get_tool_executor(self, name: str) -> Optional[ToolExecutor]: + """获取指定工具的执行器。 + + Args: + name: 工具名称。 + + Returns: + Optional[ToolExecutor]: 运行时 RPC 执行闭包。 + """ + + matched_entry = self._get_unique_component_entry(ComponentType.TOOL, name) + if matched_entry is None: + return None + supervisor, entry = matched_entry + return self._build_tool_executor(supervisor, entry.plugin_id, entry.name) + + def get_llm_available_tools(self) -> Dict[str, ToolInfo]: + """获取当前可供 LLM 选择的工具集合。 + + Returns: + Dict[str, ToolInfo]: 工具名到工具信息的映射。 + """ + + tool_infos = self._collect_unique_component_infos(ComponentType.TOOL) + return {name: info for name, info in tool_infos.items() if isinstance(info, ToolInfo) and info.enabled} + + def get_components_by_type(self, component_type: ComponentType) -> Dict[str, ComponentInfo]: + """获取某类组件的全部信息。 + + Args: + component_type: 组件类型。 + + Returns: + Dict[str, ComponentInfo]: 组件名到组件信息的映射。 + """ + + return self._collect_unique_component_infos(component_type) + + def get_plugin_config(self, plugin_name: str) -> Optional[dict]: + """读取指定插件的配置文件内容。 + + Args: + plugin_name: 插件名称。 + + Returns: + Optional[dict]: 读取成功时返回配置字典;未找到时返回 ``None``。 + """ + + runtime_manager = self._get_runtime_manager() + try: + supervisor = runtime_manager._get_supervisor_for_plugin(plugin_name) + except RuntimeError as exc: + logger.error(f"读取插件配置失败: {exc}") + return None + + if supervisor is None: + return None + + try: + return runtime_manager._load_plugin_config_for_supervisor(supervisor, plugin_name) + except Exception as exc: + logger.error(f"读取插件 {plugin_name} 配置失败: {exc}", exc_info=True) + return None + + +component_query_service = ComponentQueryService() diff --git a/src/plugin_runtime/host/component_registry.py b/src/plugin_runtime/host/component_registry.py index 08b0ea3b..1c073490 100644 --- a/src/plugin_runtime/host/component_registry.py +++ b/src/plugin_runtime/host/component_registry.py @@ -31,12 +31,12 @@ class ComponentTypes(str, Enum): class StatusDict(TypedDict): total: int - ACTION: int - COMMAND: int - TOOL: int - EVENT_HANDLER: int - HOOK_HANDLER: int - MESSAGE_GATEWAY: int + action: int + command: int + tool: int + event_handler: int + hook_handler: int + message_gateway: int plugins: int @@ -185,6 +185,23 @@ class ComponentRegistry: # 按插件索引 self._by_plugin: Dict[str, List[ComponentEntry]] = {} + @staticmethod + def _normalize_component_type(component_type: str) -> ComponentTypes: + """规范化组件类型输入。 + + Args: + component_type: 原始组件类型字符串。 + + Returns: + ComponentTypes: 规范化后的组件类型枚举。 + + Raises: + ValueError: 当组件类型不受支持时抛出。 + """ + + normalized_value = str(component_type or "").strip().upper() + return ComponentTypes(normalized_value) + def clear(self) -> None: """清空全部组件注册状态。""" self._components.clear() @@ -205,18 +222,19 @@ class ComponentRegistry: success (bool): 是否成功注册(失败原因通常是组件类型无效) """ try: - if component_type == ComponentTypes.ACTION: - comp = ActionEntry(name, component_type, plugin_id, metadata) - elif component_type == ComponentTypes.COMMAND: - comp = CommandEntry(name, component_type, plugin_id, metadata) - elif component_type == ComponentTypes.TOOL: - comp = ToolEntry(name, component_type, plugin_id, metadata) - elif component_type == ComponentTypes.EVENT_HANDLER: - comp = EventHandlerEntry(name, component_type, plugin_id, metadata) - elif component_type == ComponentTypes.HOOK_HANDLER: - comp = HookHandlerEntry(name, component_type, plugin_id, metadata) - elif component_type == ComponentTypes.MESSAGE_GATEWAY: - comp = MessageGatewayEntry(name, component_type, plugin_id, metadata) + normalized_type = self._normalize_component_type(component_type) + if normalized_type == ComponentTypes.ACTION: + comp = ActionEntry(name, normalized_type.value, plugin_id, metadata) + elif normalized_type == ComponentTypes.COMMAND: + comp = CommandEntry(name, normalized_type.value, plugin_id, metadata) + elif normalized_type == ComponentTypes.TOOL: + comp = ToolEntry(name, normalized_type.value, plugin_id, metadata) + elif normalized_type == ComponentTypes.EVENT_HANDLER: + comp = EventHandlerEntry(name, normalized_type.value, plugin_id, metadata) + elif normalized_type == ComponentTypes.HOOK_HANDLER: + comp = HookHandlerEntry(name, normalized_type.value, plugin_id, metadata) + elif normalized_type == ComponentTypes.MESSAGE_GATEWAY: + comp = MessageGatewayEntry(name, normalized_type.value, plugin_id, metadata) else: raise ValueError(f"组件类型 {component_type} 不存在") except ValueError: @@ -304,6 +322,20 @@ class ComponentRegistry: comp.enabled = enabled return True + def set_component_enabled(self, full_name: str, enabled: bool, session_id: Optional[str] = None) -> bool: + """设置指定组件的启用状态。 + + Args: + full_name: 组件全名。 + enabled: 目标启用状态。 + session_id: 可选的会话 ID,仅对该会话生效。 + + Returns: + bool: 是否设置成功。 + """ + + return self.toggle_component_status(full_name, enabled, session_id=session_id) + def toggle_plugin_status(self, plugin_id: str, enabled: bool, session_id: Optional[str] = None) -> int: """批量启用或禁用某插件的所有组件。 @@ -348,7 +380,7 @@ class ComponentRegistry: components (List[ComponentEntry]): 组件条目列表 """ try: - comp_type = ComponentTypes(component_type) + comp_type = self._normalize_component_type(component_type) except ValueError: logger.error(f"组件类型 {component_type} 不存在") raise @@ -536,6 +568,6 @@ class ComponentRegistry: """ stats: StatusDict = {"total": len(self._components)} # type: ignore for comp_type, type_dict in self._by_type.items(): - stats[comp_type.value] = len(type_dict) + stats[comp_type.value.lower()] = len(type_dict) stats["plugins"] = len(self._by_plugin) return stats diff --git a/src/plugin_runtime/host/supervisor.py b/src/plugin_runtime/host/supervisor.py index 3588934e..4a9885f8 100644 --- a/src/plugin_runtime/host/supervisor.py +++ b/src/plugin_runtime/host/supervisor.py @@ -9,8 +9,6 @@ import sys from src.common.logger import get_logger from src.config.config import global_config -from src.core.component_registry import component_registry as core_component_registry -from src.core.types import ActionActivationType, ActionInfo, ComponentType as CoreComponentType from src.platform_io import DriverKind, InboundMessageEnvelope, RouteBinding, RouteKey, get_platform_io_manager from src.platform_io.drivers import PluginPlatformDriver from src.platform_io.route_key_factory import RouteKeyFactory @@ -107,7 +105,6 @@ class PluginRunnerSupervisor: self._runner_process: Optional[asyncio.subprocess.Process] = None self._registered_plugins: Dict[str, RegisterPluginPayload] = {} self._message_gateway_states: Dict[str, Dict[str, _MessageGatewayRuntimeState]] = {} - self._mirrored_core_actions: Dict[str, List[str]] = {} self._runner_ready_events: asyncio.Event = asyncio.Event() self._runner_ready_payloads: RunnerReadyPayload = RunnerReadyPayload() self._health_task: Optional[asyncio.Task[None]] = None @@ -510,7 +507,6 @@ class PluginRunnerSupervisor: except Exception as exc: return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc)) - self._remove_core_action_mirrors(payload.plugin_id) self._component_registry.remove_components_by_plugin(payload.plugin_id) await self._unregister_all_message_gateway_drivers_for_plugin(payload.plugin_id) @@ -520,7 +516,6 @@ class PluginRunnerSupervisor: ) self._registered_plugins[payload.plugin_id] = payload self._message_gateway_states[payload.plugin_id] = {} - self._mirror_runtime_actions_to_core_registry(payload) return envelope.make_response( payload={ @@ -550,7 +545,6 @@ class PluginRunnerSupervisor: removed_components = self._component_registry.remove_components_by_plugin(payload.plugin_id) self._authorization.revoke_permission_token(payload.plugin_id) removed_registration = self._registered_plugins.pop(payload.plugin_id, None) is not None - self._remove_core_action_mirrors(payload.plugin_id) await self._unregister_all_message_gateway_drivers_for_plugin(payload.plugin_id) self._message_gateway_states.pop(payload.plugin_id, None) @@ -564,236 +558,6 @@ class PluginRunnerSupervisor: } ) - @staticmethod - def _coerce_action_activation_type(raw_value: Any) -> ActionActivationType: - """将运行时 Action 激活类型转换为旧核心枚举。 - - Args: - raw_value: 插件运行时声明中的激活类型值。 - - Returns: - ActionActivationType: 可供旧 Planner 使用的激活类型枚举。 - """ - normalized_value = str(raw_value or ActionActivationType.ALWAYS.value).strip().lower() - try: - return ActionActivationType(normalized_value) - except ValueError: - return ActionActivationType.ALWAYS - - @staticmethod - def _coerce_float(value: Any, default: float = 0.0) -> float: - """将任意输入尽量转换为浮点数。 - - Args: - value: 待转换的值。 - default: 转换失败时使用的默认值。 - - Returns: - float: 转换结果。 - """ - try: - return float(value) - except (TypeError, ValueError): - return default - - @staticmethod - def _build_core_action_info(plugin_id: str, component_name: str, metadata: Dict[str, Any]) -> ActionInfo: - """将运行时 Action 元数据映射为旧核心 ActionInfo。 - - Args: - plugin_id: 插件 ID。 - component_name: 组件名称。 - metadata: 运行时组件元数据。 - - Returns: - ActionInfo: 兼容旧 Planner 的动作定义。 - """ - activation_keywords = [ - str(item) - for item in (metadata.get("activation_keywords") or []) - if item is not None and str(item).strip() - ] - action_require = [ - str(item) - for item in (metadata.get("action_require") or []) - if item is not None and str(item).strip() - ] - associated_types = [ - str(item) - for item in (metadata.get("associated_types") or []) - if item is not None and str(item).strip() - ] - raw_action_parameters = metadata.get("action_parameters") or {} - action_parameters = { - str(param_name): str(param_description) - for param_name, param_description in raw_action_parameters.items() - } if isinstance(raw_action_parameters, dict) else {} - - return ActionInfo( - name=component_name, - component_type=CoreComponentType.ACTION, - description=str(metadata.get("description", "") or ""), - enabled=bool(metadata.get("enabled", True)), - plugin_name=plugin_id, - metadata=dict(metadata), - action_parameters=action_parameters, - action_require=action_require, - associated_types=associated_types, - activation_type=PluginRunnerSupervisor._coerce_action_activation_type(metadata.get("activation_type")), - random_activation_probability=PluginRunnerSupervisor._coerce_float( - metadata.get("activation_probability"), - 0.0, - ), - activation_keywords=activation_keywords, - parallel_action=bool(metadata.get("parallel_action", False)), - ) - - @staticmethod - def _extract_stream_id_from_action_kwargs(kwargs: Dict[str, Any]) -> str: - """从旧 ActionManager 传入参数中提取聊天流 ID。 - - Args: - kwargs: 旧动作执行器收到的关键字参数。 - - Returns: - str: 可用于新运行时 Action 的 ``stream_id``。 - """ - chat_stream = kwargs.get("chat_stream") - if chat_stream is not None: - try: - return str(chat_stream.session_id) - except AttributeError: - pass - - raw_stream_id = kwargs.get("stream_id", "") - return str(raw_stream_id or "") - - def _build_runtime_action_executor( - self, - plugin_id: str, - component_name: str, - ) -> Any: - """构造一个转发到 plugin runtime 的旧核心 Action 执行器。 - - Args: - plugin_id: 目标插件 ID。 - component_name: 目标 Action 组件名称。 - - Returns: - Callable[..., Coroutine[Any, Any, tuple[bool, str]]]: 兼容旧 ActionManager 的执行器。 - """ - - async def _executor(**kwargs: Any) -> tuple[bool, str]: - """将旧 Planner 的动作调用桥接到 plugin runtime。 - - Args: - **kwargs: 旧 ActionManager 传入的运行时上下文参数。 - - Returns: - tuple[bool, str]: ``(是否成功, 动作说明)``。 - """ - invoke_args: Dict[str, Any] = {} - action_data = kwargs.get("action_data") - if isinstance(action_data, dict): - invoke_args.update(action_data) - - stream_id = self._extract_stream_id_from_action_kwargs(kwargs) - invoke_args["action_data"] = action_data if isinstance(action_data, dict) else {} - invoke_args["stream_id"] = stream_id - invoke_args["chat_id"] = stream_id - invoke_args["reasoning"] = str(kwargs.get("action_reasoning", "") or "") - - thinking_id = kwargs.get("thinking_id") - if thinking_id is not None: - invoke_args["thinking_id"] = str(thinking_id) - - cycle_timers = kwargs.get("cycle_timers") - if isinstance(cycle_timers, dict): - invoke_args["cycle_timers"] = cycle_timers - - plugin_config = kwargs.get("plugin_config") - if isinstance(plugin_config, dict): - invoke_args["plugin_config"] = plugin_config - - log_prefix = kwargs.get("log_prefix") - if isinstance(log_prefix, str): - invoke_args["log_prefix"] = log_prefix - - shutting_down = kwargs.get("shutting_down") - if isinstance(shutting_down, bool): - invoke_args["shutting_down"] = shutting_down - - try: - response = await self.invoke_plugin( - method="plugin.invoke_action", - plugin_id=plugin_id, - component_name=component_name, - args=invoke_args, - timeout_ms=30000, - ) - except Exception as exc: - logger.error(f"运行时 Action {plugin_id}.{component_name} 执行失败: {exc}", exc_info=True) - return False, str(exc) - - payload = response.payload if isinstance(response.payload, dict) else {} - success = bool(payload.get("success", False)) - result = payload.get("result") - - if isinstance(result, (list, tuple)): - if len(result) >= 2: - return bool(result[0]), "" if result[1] is None else str(result[1]) - if len(result) == 1: - return bool(result[0]), "" - - if success: - return True, "" if result is None else str(result) - return False, "" if result is None else str(result) - - return _executor - - def _mirror_runtime_actions_to_core_registry(self, payload: RegisterPluginPayload) -> None: - """将 plugin runtime 中声明的 Action 镜像到旧核心注册表。 - - Args: - payload: 当前插件的注册载荷。 - """ - mirrored_action_names: List[str] = [] - - for component in payload.components: - if str(component.component_type).upper() != CoreComponentType.ACTION.name: - continue - - action_info = self._build_core_action_info( - plugin_id=payload.plugin_id, - component_name=component.name, - metadata=component.metadata, - ) - action_executor = self._build_runtime_action_executor( - plugin_id=payload.plugin_id, - component_name=component.name, - ) - registered = core_component_registry.register_action(action_info, action_executor) - if not registered: - logger.warning( - f"运行时 Action {payload.plugin_id}.{component.name} 无法镜像到旧核心注册表," - "可能与现有 Action 重名" - ) - continue - mirrored_action_names.append(component.name) - - if mirrored_action_names: - self._mirrored_core_actions[payload.plugin_id] = mirrored_action_names - - def _remove_core_action_mirrors(self, plugin_id: str) -> None: - """移除某个插件镜像到旧核心注册表的所有 Action。 - - Args: - plugin_id: 目标插件 ID。 - """ - mirrored_action_names = self._mirrored_core_actions.pop(plugin_id, []) - for action_name in mirrored_action_names: - core_component_registry.remove_action(action_name) - @staticmethod def _build_message_gateway_driver_id(plugin_id: str, gateway_name: str) -> str: """构造消息网关驱动 ID。 @@ -1407,8 +1171,6 @@ class PluginRunnerSupervisor: def _clear_runner_state(self) -> None: """清理当前 Runner 对应的 Host 侧注册状态。""" - for plugin_id in list(self._mirrored_core_actions.keys()): - self._remove_core_action_mirrors(plugin_id) self._authorization.clear() self._component_registry.clear() self._registered_plugins.clear() diff --git a/src/services/send_service.py b/src/services/send_service.py index 7903cdeb..54f2a9de 100644 --- a/src/services/send_service.py +++ b/src/services/send_service.py @@ -8,10 +8,9 @@ 3. 具体走插件链还是 legacy 旧链,由 Platform IO 内部统一决策。 """ +from copy import deepcopy from typing import Any, Dict, List, Optional -from maim_message import Seg - import asyncio import base64 import hashlib @@ -28,6 +27,7 @@ from src.common.data_models.message_component_data_model import ( AtComponent, DictComponent, EmojiComponent, + ForwardNodeComponent, ImageComponent, MessageSequence, ReplyComponent, @@ -72,88 +72,163 @@ def _inherit_platform_io_route_metadata(target_stream: BotChatSession) -> Dict[s if normalized_value: inherited_metadata[key] = value - if target_stream.group_id: - normalized_group_id = str(target_stream.group_id).strip() - if normalized_group_id: - inherited_metadata["platform_io_target_group_id"] = normalized_group_id + if target_stream.group_id and (normalized_group_id := str(target_stream.group_id).strip()): + inherited_metadata["platform_io_target_group_id"] = normalized_group_id - if target_stream.user_id: - normalized_user_id = str(target_stream.user_id).strip() - if normalized_user_id: - inherited_metadata["platform_io_target_user_id"] = normalized_user_id + if target_stream.user_id and (normalized_user_id := str(target_stream.user_id).strip()): + inherited_metadata["platform_io_target_user_id"] = normalized_user_id return inherited_metadata -def _build_component_from_seg(message_segment: Seg) -> StandardMessageComponents: - """将单个消息段转换为内部消息组件。 +def _build_binary_component_from_base64(component_type: str, raw_data: str) -> StandardMessageComponents: + """根据 Base64 数据构造二进制消息组件。 Args: - message_segment: 待转换的消息段。 + component_type: 组件类型名称。 + raw_data: Base64 编码后的二进制数据。 Returns: StandardMessageComponents: 转换后的内部消息组件。 + + Raises: + ValueError: 当组件类型不受支持时抛出。 """ - segment_type = str(message_segment.type or "").strip().lower() - segment_data = message_segment.data + binary_data = base64.b64decode(raw_data) + binary_hash = hashlib.sha256(binary_data).hexdigest() - if segment_type == "text": - return TextComponent(text=str(segment_data or "")) - - if segment_type == "image": - image_binary = base64.b64decode(str(segment_data or "")) - return ImageComponent( - binary_hash=hashlib.sha256(image_binary).hexdigest(), - binary_data=image_binary, - ) - - if segment_type == "emoji": - emoji_binary = base64.b64decode(str(segment_data or "")) - return EmojiComponent( - binary_hash=hashlib.sha256(emoji_binary).hexdigest(), - binary_data=emoji_binary, - ) - - if segment_type == "voice": - voice_binary = base64.b64decode(str(segment_data or "")) - return VoiceComponent( - binary_hash=hashlib.sha256(voice_binary).hexdigest(), - binary_data=voice_binary, - ) - - if segment_type == "at": - return AtComponent(target_user_id=str(segment_data or "")) - - if segment_type == "reply": - return ReplyComponent(target_message_id=str(segment_data or "")) - - if segment_type == "dict" and isinstance(segment_data, dict): - return DictComponent(data=segment_data) - - return DictComponent(data={"type": segment_type, "data": segment_data}) + if component_type == "image": + return ImageComponent(binary_hash=binary_hash, binary_data=binary_data) + if component_type == "emoji": + return EmojiComponent(binary_hash=binary_hash, binary_data=binary_data) + if component_type == "voice": + return VoiceComponent(binary_hash=binary_hash, binary_data=binary_data) + raise ValueError(f"不支持的二进制组件类型: {component_type}") -def _build_message_sequence_from_seg(message_segment: Seg) -> MessageSequence: - """将消息段转换为内部消息组件序列。 +def _build_message_sequence_from_custom_message( + message_type: str, + content: str | Dict[str, Any], +) -> MessageSequence: + """根据自定义消息类型构造内部消息组件序列。 Args: - message_segment: 待转换的消息段。 + message_type: 自定义消息类型。 + content: 自定义消息内容。 Returns: MessageSequence: 转换后的消息组件序列。 """ - if str(message_segment.type or "").strip().lower() == "seglist": - raw_segments = message_segment.data - if not isinstance(raw_segments, list): - raise ValueError("seglist 类型的消息段数据必须是列表") - components = [ - _build_component_from_seg(item) - for item in raw_segments - if isinstance(item, Seg) - ] - return MessageSequence(components=components) + normalized_type = message_type.strip().lower() - return MessageSequence(components=[_build_component_from_seg(message_segment)]) + if normalized_type == "text": + return MessageSequence(components=[TextComponent(text=str(content))]) + + if normalized_type in {"image", "emoji", "voice"}: + return MessageSequence( + components=[_build_binary_component_from_base64(normalized_type, str(content))] + ) + + if normalized_type == "at": + return MessageSequence(components=[AtComponent(target_user_id=str(content))]) + + if normalized_type == "reply": + return MessageSequence(components=[ReplyComponent(target_message_id=str(content))]) + + if normalized_type == "dict" and isinstance(content, dict): + return MessageSequence(components=[DictComponent(data=deepcopy(content))]) + + return MessageSequence( + components=[ + DictComponent( + data={ + "type": normalized_type, + "data": deepcopy(content), + } + ) + ] + ) + + +def _clone_message_sequence(message_sequence: MessageSequence) -> MessageSequence: + """复制消息组件序列,避免原对象被发送流程修改。 + + Args: + message_sequence: 原始消息组件序列。 + + Returns: + MessageSequence: 深拷贝后的消息组件序列。 + """ + return deepcopy(message_sequence) + + +def _detect_outbound_message_flags(message_sequence: MessageSequence) -> Dict[str, bool]: + """根据消息组件序列推断出站消息标记。 + + Args: + message_sequence: 待发送的消息组件序列。 + + Returns: + Dict[str, bool]: 包含 ``is_emoji``、``is_picture``、``is_command`` 的标记字典。 + """ + if len(message_sequence.components) != 1: + return { + "is_emoji": False, + "is_picture": False, + "is_command": False, + } + + component = message_sequence.components[0] + is_command = False + if isinstance(component, DictComponent) and isinstance(component.data, dict): + is_command = str(component.data.get("type") or "").strip().lower() == "command" + + return { + "is_emoji": isinstance(component, EmojiComponent), + "is_picture": isinstance(component, ImageComponent), + "is_command": is_command, + } + + +def _describe_message_sequence(message_sequence: MessageSequence) -> str: + """生成消息组件序列的简短描述文本。 + + Args: + message_sequence: 待描述的消息组件序列。 + + Returns: + str: 适用于日志的简短类型描述。 + """ + if len(message_sequence.components) != 1: + return "message_sequence" + + component = message_sequence.components[0] + if isinstance(component, DictComponent) and isinstance(component.data, dict): + custom_type = str(component.data.get("type") or "").strip() + return custom_type or "dict" + + if isinstance(component, TextComponent): + return component.format_name + + if isinstance(component, ImageComponent): + return component.format_name + + if isinstance(component, EmojiComponent): + return component.format_name + + if isinstance(component, VoiceComponent): + return component.format_name + + if isinstance(component, AtComponent): + return component.format_name + + if isinstance(component, ReplyComponent): + return component.format_name + + if isinstance(component, ForwardNodeComponent): + return component.format_name + + return "unknown" def _build_processed_plain_text(message: SessionMessage) -> str: @@ -204,7 +279,7 @@ def _build_processed_plain_text(message: SessionMessage) -> str: def _build_outbound_session_message( - message_segment: Seg, + message_sequence: MessageSequence, stream_id: str, display_message: str = "", reply_message: Optional[MaiMessage] = None, @@ -213,7 +288,7 @@ def _build_outbound_session_message( """根据目标会话构建待发送的内部消息对象。 Args: - message_segment: 待发送的消息段。 + message_sequence: 待发送的消息组件序列。 stream_id: 目标会话 ID。 display_message: 用于界面展示的文本内容。 reply_message: 被回复的锚点消息。 @@ -268,13 +343,14 @@ def _build_outbound_session_message( group_info=group_info, additional_config=additional_config, ) - outbound_message.raw_message = _build_message_sequence_from_seg(message_segment) + outbound_message.raw_message = _clone_message_sequence(message_sequence) outbound_message.session_id = target_stream.session_id outbound_message.display_message = display_message outbound_message.reply_to = anchor_message.message_id if anchor_message is not None else None - outbound_message.is_emoji = message_segment.type == "emoji" - outbound_message.is_picture = message_segment.type == "image" - outbound_message.is_command = message_segment.type == "command" + message_flags = _detect_outbound_message_flags(outbound_message.raw_message) + outbound_message.is_emoji = message_flags["is_emoji"] + outbound_message.is_picture = message_flags["is_picture"] + outbound_message.is_command = message_flags["is_command"] outbound_message.initialized = True return outbound_message @@ -467,7 +543,7 @@ async def send_session_message( async def _send_to_target( - message_segment: Seg, + message_sequence: MessageSequence, stream_id: str, display_message: str = "", typing: bool = False, @@ -480,7 +556,7 @@ async def _send_to_target( """向指定目标构建并发送消息。 Args: - message_segment: 待发送的消息段。 + message_sequence: 待发送的消息组件序列。 stream_id: 目标会话 ID。 display_message: 用于界面展示的文本内容。 typing: 是否显示输入中状态。 @@ -499,10 +575,10 @@ async def _send_to_target( return False if show_log: - logger.debug(f"[SendService] 发送{message_segment.type}消息到 {stream_id}") + logger.debug(f"[SendService] 发送{_describe_message_sequence(message_sequence)}消息到 {stream_id}") outbound_message = _build_outbound_session_message( - message_segment=message_segment, + message_sequence=message_sequence, stream_id=stream_id, display_message=display_message, reply_message=reply_message, @@ -555,7 +631,7 @@ async def text_to_stream( bool: 发送成功时返回 ``True``。 """ return await _send_to_target( - message_segment=Seg(type="text", data=text), + message_sequence=MessageSequence(components=[TextComponent(text=text)]), stream_id=stream_id, display_message="", typing=typing, @@ -586,7 +662,7 @@ async def emoji_to_stream( bool: 发送成功时返回 ``True``。 """ return await _send_to_target( - message_segment=Seg(type="emoji", data=emoji_base64), + message_sequence=_build_message_sequence_from_custom_message("emoji", emoji_base64), stream_id=stream_id, display_message="", typing=False, @@ -616,7 +692,7 @@ async def image_to_stream( bool: 发送成功时返回 ``True``。 """ return await _send_to_target( - message_segment=Seg(type="image", data=image_base64), + message_sequence=_build_message_sequence_from_custom_message("image", image_base64), stream_id=stream_id, display_message="", typing=False, @@ -654,7 +730,7 @@ async def custom_to_stream( bool: 发送成功时返回 ``True``。 """ return await _send_to_target( - message_segment=Seg(type=message_type, data=content), # type: ignore[arg-type] + message_sequence=_build_message_sequence_from_custom_message(message_type, content), stream_id=stream_id, display_message=display_message, typing=typing, @@ -688,28 +764,15 @@ async def custom_reply_set_to_stream( show_log: 是否输出发送日志。 Returns: - bool: 全部组件发送成功时返回 ``True``。 + bool: 发送成功时返回 ``True``。 """ - success = True - for component in reply_set.components: - if isinstance(component, DictComponent): - message_seg = Seg(type="dict", data=component.data) # type: ignore[arg-type] - else: - message_seg = await component.to_seg() - - status = await _send_to_target( - message_segment=message_seg, - stream_id=stream_id, - display_message=display_message, - typing=typing, - reply_message=reply_message, - set_reply=set_reply, - storage_message=storage_message, - show_log=show_log, - ) - if not status: - success = False - logger.error(f"[SendService] 发送消息组件失败,组件类型:{type(component).__name__}") - set_reply = False - - return success + return await _send_to_target( + message_sequence=reply_set, + stream_id=stream_id, + display_message=display_message, + typing=typing, + reply_message=reply_message, + set_reply=set_reply, + storage_message=storage_message, + show_log=show_log, + ) From 9dea6b0e6fdeae1be119eaeb0a1449ac64d10900 Mon Sep 17 00:00:00 2001 From: DrSmoothl <1787882683@qq.com> Date: Mon, 23 Mar 2026 17:18:05 +0800 Subject: [PATCH 33/45] feat: implement dedicated API registry and enhance API handling capabilities - Added APIEntry and APIRegistry classes for managing plugin APIs. - Updated PluginRunnerSupervisor to include API registry and methods for invoking APIs. - Enhanced PluginRuntimeManager to support API registration and invocation. - Created tests for API registration, invocation, and visibility between plugins. - Refactored component handling to distinguish between runtime components and APIs. --- pytests/test_plugin_runtime.py | 9 +- pytests/test_plugin_runtime_api.py | 294 +++++++++++++++ src/plugin_runtime/capabilities/components.py | 344 +++++++++++++++++- src/plugin_runtime/capabilities/registry.py | 4 + src/plugin_runtime/host/api_registry.py | 290 +++++++++++++++ src/plugin_runtime/host/supervisor.py | 81 ++++- src/plugin_runtime/runner/runner_main.py | 1 + 7 files changed, 1012 insertions(+), 11 deletions(-) create mode 100644 pytests/test_plugin_runtime_api.py create mode 100644 src/plugin_runtime/host/api_registry.py diff --git a/pytests/test_plugin_runtime.py b/pytests/test_plugin_runtime.py index 20cceb82..9dfc34d8 100644 --- a/pytests/test_plugin_runtime.py +++ b/pytests/test_plugin_runtime.py @@ -2152,8 +2152,11 @@ class TestIntegration: self.supervisors = [FakeSupervisor("plugin_a"), FakeSupervisor("plugin_b")] monkeypatch.setattr(integration_module, "get_plugin_runtime_manager", lambda: FakeManager()) + manager = integration_module.PluginRuntimeManager() + manager._builtin_supervisor = FakeSupervisor("plugin_a") + manager._third_party_supervisor = FakeSupervisor("plugin_b") - result = await integration_module.PluginRuntimeManager._cap_component_enable( + result = await manager._cap_component_enable( "plugin_a", "component.enable", {"name": "shared", "component_type": "tool", "scope": "global", "stream_id": ""}, @@ -2182,8 +2185,10 @@ class TestIntegration: self.supervisors = [FakeSupervisor()] monkeypatch.setattr(integration_module, "get_plugin_runtime_manager", lambda: FakeManager()) + manager = integration_module.PluginRuntimeManager() + manager._builtin_supervisor = FakeSupervisor() - result = await integration_module.PluginRuntimeManager._cap_component_disable( + result = await manager._cap_component_disable( "plugin_a", "component.disable", {"name": "plugin_a.handler", "component_type": "tool", "scope": "stream", "stream_id": "s1"}, diff --git a/pytests/test_plugin_runtime_api.py b/pytests/test_plugin_runtime_api.py new file mode 100644 index 00000000..fca7736a --- /dev/null +++ b/pytests/test_plugin_runtime_api.py @@ -0,0 +1,294 @@ +"""插件 API 注册与调用测试。""" + +from types import SimpleNamespace +from typing import Any, Dict, List + +import pytest + +from src.plugin_runtime.integration import PluginRuntimeManager +from src.plugin_runtime.host.supervisor import PluginSupervisor +from src.plugin_runtime.protocol.envelope import ( + ComponentDeclaration, + Envelope, + MessageType, + RegisterPluginPayload, + UnregisterPluginPayload, +) + + +def _build_manager(*supervisors: PluginSupervisor) -> PluginRuntimeManager: + """构造一个最小可用的插件运行时管理器。 + + Args: + *supervisors: 需要挂载的监督器列表。 + + Returns: + PluginRuntimeManager: 已注入监督器的运行时管理器。 + """ + + manager = PluginRuntimeManager() + if supervisors: + manager._builtin_supervisor = supervisors[0] + if len(supervisors) > 1: + manager._third_party_supervisor = supervisors[1] + return manager + + +async def _register_plugin( + supervisor: PluginSupervisor, + plugin_id: str, + components: List[Dict[str, Any]], +) -> Envelope: + """通过 Supervisor 注册测试插件。 + + Args: + supervisor: 目标监督器。 + plugin_id: 测试插件 ID。 + components: 组件声明列表。 + + Returns: + Envelope: 注册响应信封。 + """ + + payload = RegisterPluginPayload( + plugin_id=plugin_id, + plugin_version="1.0.0", + components=[ + ComponentDeclaration( + name=str(component.get("name", "") or ""), + component_type=str(component.get("component_type", "") or ""), + plugin_id=plugin_id, + metadata=component.get("metadata", {}) if isinstance(component.get("metadata"), dict) else {}, + ) + for component in components + ], + ) + return await supervisor._handle_register_plugin( + Envelope( + request_id=1, + message_type=MessageType.REQUEST, + method="plugin.register_components", + plugin_id=plugin_id, + payload=payload.model_dump(), + ) + ) + + +async def _unregister_plugin(supervisor: PluginSupervisor, plugin_id: str) -> Envelope: + """通过 Supervisor 注销测试插件。 + + Args: + supervisor: 目标监督器。 + plugin_id: 测试插件 ID。 + + Returns: + Envelope: 注销响应信封。 + """ + + payload = UnregisterPluginPayload(plugin_id=plugin_id, reason="test") + return await supervisor._handle_unregister_plugin( + Envelope( + request_id=2, + message_type=MessageType.REQUEST, + method="plugin.unregister", + plugin_id=plugin_id, + payload=payload.model_dump(), + ) + ) + + +@pytest.mark.asyncio +async def test_register_plugin_syncs_dedicated_api_registry() -> None: + """插件注册时应将 API 同步到独立注册表,而不是通用组件表。""" + + supervisor = PluginSupervisor(plugin_dirs=[]) + response = await _register_plugin( + supervisor, + "provider", + [ + { + "name": "render_html", + "component_type": "API", + "metadata": { + "description": "渲染 HTML", + "version": "1", + "public": True, + }, + } + ], + ) + + assert response.payload["accepted"] is True + assert response.payload["registered_components"] == 0 + assert response.payload["registered_apis"] == 1 + assert supervisor.api_registry.get_api("provider", "render_html") is not None + assert supervisor.component_registry.get_component("provider.render_html") is None + + unregister_response = await _unregister_plugin(supervisor, "provider") + assert unregister_response.payload["removed_apis"] == 1 + assert supervisor.api_registry.get_api("provider", "render_html") is None + + +@pytest.mark.asyncio +async def test_api_call_allows_public_api_between_plugins(monkeypatch: pytest.MonkeyPatch) -> None: + """公开 API 应允许其他插件通过 Host 转发调用。""" + + provider_supervisor = PluginSupervisor(plugin_dirs=[]) + consumer_supervisor = PluginSupervisor(plugin_dirs=[]) + await _register_plugin( + provider_supervisor, + "provider", + [ + { + "name": "render_html", + "component_type": "API", + "metadata": { + "description": "渲染 HTML", + "version": "1", + "public": True, + }, + } + ], + ) + await _register_plugin(consumer_supervisor, "consumer", []) + + captured: Dict[str, Any] = {} + + async def fake_invoke_api( + plugin_id: str, + component_name: str, + args: Dict[str, Any] | None = None, + timeout_ms: int = 30000, + ) -> Any: + """模拟 API RPC 调用。""" + + captured["plugin_id"] = plugin_id + captured["component_name"] = component_name + captured["args"] = args or {} + captured["timeout_ms"] = timeout_ms + return SimpleNamespace(error=None, payload={"success": True, "result": {"image": "ok"}}) + + monkeypatch.setattr(provider_supervisor, "invoke_api", fake_invoke_api) + + manager = _build_manager(provider_supervisor, consumer_supervisor) + result = await manager._cap_api_call( + "consumer", + "api.call", + { + "api_name": "provider.render_html", + "version": "1", + "args": {"html": "
Hello
"}, + }, + ) + + assert result == {"success": True, "result": {"image": "ok"}} + assert captured["plugin_id"] == "provider" + assert captured["component_name"] == "render_html" + assert captured["args"] == {"html": "
Hello
"} + + +@pytest.mark.asyncio +async def test_api_call_rejects_private_api_between_plugins() -> None: + """未公开的 API 默认不允许跨插件调用。""" + + provider_supervisor = PluginSupervisor(plugin_dirs=[]) + consumer_supervisor = PluginSupervisor(plugin_dirs=[]) + await _register_plugin( + provider_supervisor, + "provider", + [ + { + "name": "secret_api", + "component_type": "API", + "metadata": { + "description": "私有 API", + "version": "1", + "public": False, + }, + } + ], + ) + await _register_plugin(consumer_supervisor, "consumer", []) + + manager = _build_manager(provider_supervisor, consumer_supervisor) + result = await manager._cap_api_call( + "consumer", + "api.call", + { + "api_name": "provider.secret_api", + "args": {}, + }, + ) + + assert result["success"] is False + assert "未公开" in str(result["error"]) + + +@pytest.mark.asyncio +async def test_api_list_and_component_toggle_use_dedicated_registry() -> None: + """API 列表与组件启停应直接作用于独立 API 注册表。""" + + provider_supervisor = PluginSupervisor(plugin_dirs=[]) + consumer_supervisor = PluginSupervisor(plugin_dirs=[]) + await _register_plugin( + provider_supervisor, + "provider", + [ + { + "name": "public_api", + "component_type": "API", + "metadata": {"version": "1", "public": True}, + }, + { + "name": "private_api", + "component_type": "API", + "metadata": {"version": "1", "public": False}, + }, + ], + ) + await _register_plugin( + consumer_supervisor, + "consumer", + [ + { + "name": "self_private_api", + "component_type": "API", + "metadata": {"version": "1", "public": False}, + } + ], + ) + + manager = _build_manager(provider_supervisor, consumer_supervisor) + list_result = await manager._cap_api_list("consumer", "api.list", {}) + + assert list_result["success"] is True + api_names = {(item["plugin_id"], item["name"]) for item in list_result["apis"]} + assert ("provider", "public_api") in api_names + assert ("provider", "private_api") not in api_names + assert ("consumer", "self_private_api") in api_names + + disable_result = await manager._cap_component_disable( + "consumer", + "component.disable", + { + "name": "provider.public_api", + "component_type": "API", + "scope": "global", + "stream_id": "", + }, + ) + assert disable_result["success"] is True + assert provider_supervisor.api_registry.get_api("provider", "public_api", enabled_only=True) is None + + enable_result = await manager._cap_component_enable( + "consumer", + "component.enable", + { + "name": "provider.public_api", + "component_type": "API", + "scope": "global", + "stream_id": "", + }, + ) + assert enable_result["success"] is True + assert provider_supervisor.api_registry.get_api("provider", "public_api", enabled_only=True) is not None diff --git a/src/plugin_runtime/capabilities/components.py b/src/plugin_runtime/capabilities/components.py index 4223525f..2eede108 100644 --- a/src/plugin_runtime/capabilities/components.py +++ b/src/plugin_runtime/capabilities/components.py @@ -6,7 +6,8 @@ from src.common.logger import get_logger logger = get_logger("plugin_runtime.integration") if TYPE_CHECKING: - from src.plugin_runtime.host.component_registry import RegisteredComponent + from src.plugin_runtime.host.api_registry import APIEntry + from src.plugin_runtime.host.component_registry import ComponentEntry from src.plugin_runtime.host.supervisor import PluginSupervisor @@ -18,7 +19,7 @@ class _RuntimeComponentManagerProtocol(Protocol): def _resolve_component_toggle_target( self, name: str, component_type: str - ) -> tuple[Optional["RegisteredComponent"], Optional[str]]: ... + ) -> tuple[Optional["ComponentEntry"], Optional[str]]: ... def _find_duplicate_plugin_ids(self, plugin_dirs: List[Path]) -> Dict[str, List[Path]]: ... @@ -26,6 +27,203 @@ class _RuntimeComponentManagerProtocol(Protocol): class RuntimeComponentCapabilityMixin: + @staticmethod + def _normalize_component_type(component_type: str) -> str: + """规范化组件类型名称。 + + Args: + component_type: 原始组件类型。 + + Returns: + str: 统一转为大写后的组件类型名。 + """ + + return str(component_type or "").strip().upper() + + @classmethod + def _is_api_component_type(cls, component_type: str) -> bool: + """判断组件类型是否为 API。 + + Args: + component_type: 原始组件类型。 + + Returns: + bool: 是否为 API 组件类型。 + """ + + return cls._normalize_component_type(component_type) == "API" + + @staticmethod + def _serialize_api_entry(entry: "APIEntry") -> Dict[str, Any]: + """将 API 组件条目序列化为能力返回值。 + + Args: + entry: API 组件条目。 + + Returns: + Dict[str, Any]: 适合通过能力层返回给插件的 API 元信息。 + """ + + return { + "name": entry.name, + "full_name": entry.full_name, + "plugin_id": entry.plugin_id, + "description": entry.description, + "version": entry.version, + "public": entry.public, + "enabled": entry.enabled, + "metadata": dict(entry.metadata), + } + + @classmethod + def _serialize_api_component_entry(cls, entry: "APIEntry") -> Dict[str, Any]: + """将 API 条目序列化为通用组件视图。 + + Args: + entry: API 组件条目。 + + Returns: + Dict[str, Any]: 适合 ``component.get_all_plugins`` 返回的组件结构。 + """ + + serialized_entry = cls._serialize_api_entry(entry) + return { + "name": serialized_entry["name"], + "full_name": serialized_entry["full_name"], + "type": "API", + "enabled": serialized_entry["enabled"], + "metadata": serialized_entry["metadata"], + } + + @staticmethod + def _is_api_visible_to_plugin(entry: "APIEntry", caller_plugin_id: str) -> bool: + """判断某个 API 是否对调用方可见。 + + Args: + entry: 目标 API 组件条目。 + caller_plugin_id: 调用方插件 ID。 + + Returns: + bool: 是否允许当前插件可见并调用。 + """ + + return entry.plugin_id == caller_plugin_id or entry.public + + def _resolve_api_target( + self: _RuntimeComponentManagerProtocol, + caller_plugin_id: str, + api_name: str, + version: str = "", + ) -> tuple[Optional["PluginSupervisor"], Optional["APIEntry"], Optional[str]]: + """解析 API 名称到唯一可调用的目标组件。 + + Args: + caller_plugin_id: 调用方插件 ID。 + api_name: API 名称,支持 ``plugin_id.api_name`` 或唯一短名。 + version: 可选的 API 版本。 + + Returns: + tuple[Optional[PluginSupervisor], Optional[APIEntry], Optional[str]]: + 解析成功时返回 ``(监督器, API 条目, None)``,失败时返回错误信息。 + """ + + normalized_api_name = str(api_name or "").strip() + normalized_version = str(version or "").strip() + if not normalized_api_name: + return None, None, "缺少必要参数 api_name" + + if "." in normalized_api_name: + target_plugin_id, target_api_name = normalized_api_name.split(".", 1) + try: + supervisor = self._get_supervisor_for_plugin(target_plugin_id) + except RuntimeError as exc: + return None, None, str(exc) + + if supervisor is None: + return None, None, f"未找到 API 提供方插件: {target_plugin_id}" + + entry = supervisor.api_registry.get_api( + plugin_id=target_plugin_id, + name=target_api_name, + enabled_only=True, + ) + if entry is None: + return None, None, f"未找到 API: {normalized_api_name}" + if normalized_version and entry.version != normalized_version: + return None, None, f"未找到版本为 {normalized_version} 的 API: {normalized_api_name}" + if not self._is_api_visible_to_plugin(entry, caller_plugin_id): + return None, None, f"API {normalized_api_name} 未公开,禁止跨插件调用" + return supervisor, entry, None + + visible_matches: List[tuple["PluginSupervisor", "APIEntry"]] = [] + hidden_match_exists = False + for supervisor in self.supervisors: + for entry in supervisor.api_registry.get_apis(name=normalized_api_name, enabled_only=True): + if normalized_version and entry.version != normalized_version: + continue + if self._is_api_visible_to_plugin(entry, caller_plugin_id): + visible_matches.append((supervisor, entry)) + else: + hidden_match_exists = True + + if len(visible_matches) == 1: + return visible_matches[0][0], visible_matches[0][1], None + if len(visible_matches) > 1: + return None, None, f"API 名称不唯一: {normalized_api_name},请使用 plugin_id.api_name" + if hidden_match_exists: + return None, None, f"API {normalized_api_name} 未公开,禁止跨插件调用" + if normalized_version: + return None, None, f"未找到版本为 {normalized_version} 的 API: {normalized_api_name}" + return None, None, f"未找到 API: {normalized_api_name}" + + def _resolve_api_toggle_target( + self: _RuntimeComponentManagerProtocol, + name: str, + ) -> tuple[Optional["PluginSupervisor"], Optional["APIEntry"], Optional[str]]: + """解析需要启用或禁用的 API 组件。 + + Args: + name: API 名称,支持 ``plugin_id.api_name`` 或唯一短名。 + + Returns: + tuple[Optional[PluginSupervisor], Optional[APIEntry], Optional[str]]: + 解析成功时返回 ``(监督器, API 条目, None)``,失败时返回错误信息。 + """ + + normalized_name = str(name or "").strip() + if not normalized_name: + return None, None, "缺少必要参数 name" + + if "." in normalized_name: + plugin_id, api_name = normalized_name.split(".", 1) + try: + supervisor = self._get_supervisor_for_plugin(plugin_id) + except RuntimeError as exc: + return None, None, str(exc) + + if supervisor is None: + return None, None, f"未找到 API 提供方插件: {plugin_id}" + + entry = supervisor.api_registry.get_api( + plugin_id=plugin_id, + name=api_name, + enabled_only=False, + ) + if entry is None: + return None, None, f"未找到 API: {normalized_name}" + return supervisor, entry, None + + matches: List[tuple["PluginSupervisor", "APIEntry"]] = [] + for supervisor in self.supervisors: + for entry in supervisor.api_registry.get_apis(name=normalized_name, enabled_only=False): + matches.append((supervisor, entry)) + + if len(matches) == 1: + return matches[0][0], matches[0][1], None + if len(matches) > 1: + return None, None, f"API 名称不唯一: {normalized_name},请使用 plugin_id.api_name" + return None, None, f"未找到 API: {normalized_name}" + async def _cap_component_get_all_plugins( self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any] ) -> Any: @@ -46,6 +244,10 @@ class RuntimeComponentCapabilityMixin: } for component in comps ] + components_list.extend( + self._serialize_api_component_entry(entry) + for entry in sv.api_registry.get_apis(plugin_id=pid, enabled_only=False) + ) result[pid] = { "name": pid, "version": reg.plugin_version, @@ -96,24 +298,28 @@ class RuntimeComponentCapabilityMixin: def _resolve_component_toggle_target( self: _RuntimeComponentManagerProtocol, name: str, component_type: str - ) -> tuple[Optional["RegisteredComponent"], Optional[str]]: - short_name_matches: List["RegisteredComponent"] = [] + ) -> tuple[Optional["ComponentEntry"], Optional[str]]: + normalized_component_type = self._normalize_component_type(component_type) + short_name_matches: List["ComponentEntry"] = [] for sv in self.supervisors: comp = sv.component_registry.get_component(name) - if comp is not None and comp.component_type == component_type: + if comp is not None and comp.component_type == normalized_component_type: return comp, None short_name_matches.extend( candidate - for candidate in sv.component_registry.get_components_by_type(component_type, enabled_only=False) + for candidate in sv.component_registry.get_components_by_type( + normalized_component_type, + enabled_only=False, + ) if candidate.name == name ) if len(short_name_matches) == 1: return short_name_matches[0], None if len(short_name_matches) > 1: - return None, f"组件名不唯一: {name} ({component_type}),请使用完整名 plugin_id.component_name" - return None, f"未找到组件: {name} ({component_type})" + return None, f"组件名不唯一: {name} ({normalized_component_type}),请使用完整名 plugin_id.component_name" + return None, f"未找到组件: {name} ({normalized_component_type})" async def _cap_component_enable( self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any] @@ -127,6 +333,13 @@ class RuntimeComponentCapabilityMixin: if scope != "global" or stream_id: return {"success": False, "error": "当前仅支持全局组件启用,不支持 scope/stream_id 定位"} + if self._is_api_component_type(component_type): + supervisor, api_entry, error = self._resolve_api_toggle_target(name) + if supervisor is None or api_entry is None: + return {"success": False, "error": error or f"未找到 API: {name}"} + supervisor.api_registry.toggle_api_status(api_entry.full_name, True) + return {"success": True} + comp, error = self._resolve_component_toggle_target(name, component_type) if comp is None: return {"success": False, "error": error or f"未找到组件: {name} ({component_type})"} @@ -146,6 +359,13 @@ class RuntimeComponentCapabilityMixin: if scope != "global" or stream_id: return {"success": False, "error": "当前仅支持全局组件禁用,不支持 scope/stream_id 定位"} + if self._is_api_component_type(component_type): + supervisor, api_entry, error = self._resolve_api_toggle_target(name) + if supervisor is None or api_entry is None: + return {"success": False, "error": error or f"未找到 API: {name}"} + supervisor.api_registry.toggle_api_status(api_entry.full_name, False) + return {"success": True} + comp, error = self._resolve_component_toggle_target(name, component_type) if comp is None: return {"success": False, "error": error or f"未找到组件: {name} ({component_type})"} @@ -239,3 +459,111 @@ class RuntimeComponentCapabilityMixin: logger.error(f"[cap.component.reload_plugin] 热重载失败: {e}") return {"success": False, "error": str(e)} return {"success": False, "error": f"未找到插件: {plugin_name}"} + + async def _cap_api_call( + self: _RuntimeComponentManagerProtocol, + plugin_id: str, + capability: str, + args: Dict[str, Any], + ) -> Any: + """调用其他插件公开的 API。 + + Args: + plugin_id: 当前调用方插件 ID。 + capability: 能力名称。 + args: 能力参数。 + + Returns: + Any: API 调用结果。 + """ + + del capability + api_name = str(args.get("api_name", "") or "").strip() + version = str(args.get("version", "") or "").strip() + api_args = args.get("args", {}) + if not isinstance(api_args, dict): + return {"success": False, "error": "参数 args 必须为字典"} + + supervisor, entry, error = self._resolve_api_target(plugin_id, api_name, version) + if supervisor is None or entry is None: + return {"success": False, "error": error or "API 解析失败"} + + try: + response = await supervisor.invoke_api( + plugin_id=entry.plugin_id, + component_name=entry.name, + args=api_args, + timeout_ms=30000, + ) + except Exception as exc: + logger.error(f"[cap.api.call] 调用 API {entry.full_name} 失败: {exc}", exc_info=True) + return {"success": False, "error": str(exc)} + + if response.error: + return {"success": False, "error": response.error.get("message", "API 调用失败")} + + payload = response.payload if isinstance(response.payload, dict) else {} + if not bool(payload.get("success", False)): + result = payload.get("result") + return {"success": False, "error": "" if result is None else str(result)} + return {"success": True, "result": payload.get("result")} + + async def _cap_api_get( + self: _RuntimeComponentManagerProtocol, + plugin_id: str, + capability: str, + args: Dict[str, Any], + ) -> Any: + """获取当前插件可见的单个 API 元信息。 + + Args: + plugin_id: 当前调用方插件 ID。 + capability: 能力名称。 + args: 能力参数。 + + Returns: + Any: API 元信息或 ``None``。 + """ + + del capability + api_name = str(args.get("api_name", "") or "").strip() + version = str(args.get("version", "") or "").strip() + if not api_name: + return {"success": False, "error": "缺少必要参数 api_name"} + + supervisor, entry, _error = self._resolve_api_target(plugin_id, api_name, version) + if supervisor is None or entry is None: + return {"success": True, "api": None} + return {"success": True, "api": self._serialize_api_entry(entry)} + + async def _cap_api_list( + self: _RuntimeComponentManagerProtocol, + plugin_id: str, + capability: str, + args: Dict[str, Any], + ) -> Any: + """列出当前插件可见的 API 列表。 + + Args: + plugin_id: 当前调用方插件 ID。 + capability: 能力名称。 + args: 能力参数。 + + Returns: + Any: API 元信息列表。 + """ + + del capability + target_plugin_id = str(args.get("plugin_id", "") or "").strip() + apis: List[Dict[str, Any]] = [] + for supervisor in self.supervisors: + for entry in supervisor.api_registry.get_apis( + plugin_id=target_plugin_id or None, + enabled_only=True, + ): + if not self._is_api_visible_to_plugin(entry, plugin_id): + continue + apis.append(self._serialize_api_entry(entry)) + + apis.sort(key=lambda item: (str(item["plugin_id"]), str(item["name"]), str(item["version"]))) + return {"success": True, "apis": apis} diff --git a/src/plugin_runtime/capabilities/registry.py b/src/plugin_runtime/capabilities/registry.py index 96b190b4..31693833 100644 --- a/src/plugin_runtime/capabilities/registry.py +++ b/src/plugin_runtime/capabilities/registry.py @@ -74,6 +74,10 @@ def register_capability_impls(manager: "PluginRuntimeManager", supervisor: Plugi _register("tool.get_definitions", manager._cap_tool_get_definitions) + _register("api.call", manager._cap_api_call) + _register("api.get", manager._cap_api_get) + _register("api.list", manager._cap_api_list) + _register("component.get_all_plugins", manager._cap_component_get_all_plugins) _register("component.get_plugin_info", manager._cap_component_get_plugin_info) _register("component.list_loaded_plugins", manager._cap_component_list_loaded_plugins) diff --git a/src/plugin_runtime/host/api_registry.py b/src/plugin_runtime/host/api_registry.py new file mode 100644 index 00000000..84578ca5 --- /dev/null +++ b/src/plugin_runtime/host/api_registry.py @@ -0,0 +1,290 @@ +"""Host 侧插件 API 动态注册表。""" + +from typing import Any, Dict, List, Optional, Set + +from src.common.logger import get_logger + +logger = get_logger("plugin_runtime.host.api_registry") + + +class APIEntry: + """API 组件条目。""" + + __slots__ = ( + "description", + "disabled_session", + "enabled", + "full_name", + "metadata", + "name", + "plugin_id", + "public", + "version", + ) + + def __init__(self, name: str, plugin_id: str, metadata: Dict[str, Any]) -> None: + """初始化 API 组件条目。 + + Args: + name: API 名称。 + plugin_id: 所属插件 ID。 + metadata: API 元数据。 + """ + + self.name: str = name + self.full_name: str = f"{plugin_id}.{name}" + self.plugin_id: str = plugin_id + self.description: str = str(metadata.get("description", "") or "") + self.version: str = str(metadata.get("version", "1") or "1").strip() or "1" + self.public: bool = bool(metadata.get("public", False)) + self.metadata: Dict[str, Any] = dict(metadata) + self.enabled: bool = bool(metadata.get("enabled", True)) + self.disabled_session: Set[str] = set() + + +class APIRegistry: + """Host 侧插件 API 动态注册表。 + + 该注册表不直接面向 Runner,而是复用插件组件注册/卸载事件, + 维护面向 API 调用场景的专用索引。 + """ + + def __init__(self) -> None: + """初始化 API 注册表。""" + + self._apis: Dict[str, APIEntry] = {} + self._by_plugin: Dict[str, List[APIEntry]] = {} + self._by_name: Dict[str, List[APIEntry]] = {} + + def clear(self) -> None: + """清空全部 API 注册状态。""" + + self._apis.clear() + self._by_plugin.clear() + self._by_name.clear() + + @staticmethod + def _is_api_component(component_type: Any) -> bool: + """判断组件声明是否属于 API。 + + Args: + component_type: 原始组件类型值。 + + Returns: + bool: 是否为 API 组件。 + """ + + return str(component_type or "").strip().upper() == "API" + + @staticmethod + def check_api_enabled(entry: APIEntry, session_id: Optional[str] = None) -> bool: + """判断 API 条目当前是否处于启用状态。 + + Args: + entry: 待检查的 API 条目。 + session_id: 可选的会话 ID。 + + Returns: + bool: 当前是否可用。 + """ + + if session_id and session_id in entry.disabled_session: + return False + return entry.enabled + + def register_api(self, name: str, plugin_id: str, metadata: Dict[str, Any]) -> bool: + """注册单个 API 条目。 + + Args: + name: API 名称。 + plugin_id: 所属插件 ID。 + metadata: API 元数据。 + + Returns: + bool: 是否成功注册。 + """ + + normalized_name = str(name or "").strip() + if not normalized_name: + logger.warning(f"插件 {plugin_id} 存在空 API 名称声明,已忽略") + return False + + entry = APIEntry(name=normalized_name, plugin_id=plugin_id, metadata=metadata) + if entry.full_name in self._apis: + logger.warning(f"API {entry.full_name} 已存在,覆盖旧条目") + self._remove_entry(self._apis[entry.full_name]) + + self._apis[entry.full_name] = entry + self._by_plugin.setdefault(plugin_id, []).append(entry) + self._by_name.setdefault(entry.name, []).append(entry) + return True + + def register_plugin_apis(self, plugin_id: str, components: List[Dict[str, Any]]) -> int: + """批量注册某个插件声明的全部 API。 + + Args: + plugin_id: 插件 ID。 + components: 插件组件声明列表。 + + Returns: + int: 成功注册的 API 数量。 + """ + + count = 0 + for component in components: + if not self._is_api_component(component.get("component_type")): + continue + if self.register_api( + name=str(component.get("name", "") or ""), + plugin_id=plugin_id, + metadata=component.get("metadata", {}) if isinstance(component.get("metadata"), dict) else {}, + ): + count += 1 + return count + + def _remove_entry(self, entry: APIEntry) -> None: + """从全部索引中移除单个 API 条目。 + + Args: + entry: 待移除的 API 条目。 + """ + + self._apis.pop(entry.full_name, None) + plugin_entries = self._by_plugin.get(entry.plugin_id) + if plugin_entries is not None: + self._by_plugin[entry.plugin_id] = [candidate for candidate in plugin_entries if candidate is not entry] + if not self._by_plugin[entry.plugin_id]: + self._by_plugin.pop(entry.plugin_id, None) + + name_entries = self._by_name.get(entry.name) + if name_entries is not None: + self._by_name[entry.name] = [candidate for candidate in name_entries if candidate is not entry] + if not self._by_name[entry.name]: + self._by_name.pop(entry.name, None) + + def remove_apis_by_plugin(self, plugin_id: str) -> int: + """移除某个插件的全部 API。 + + Args: + plugin_id: 目标插件 ID。 + + Returns: + int: 被移除的 API 数量。 + """ + + entries = list(self._by_plugin.get(plugin_id, [])) + for entry in entries: + self._remove_entry(entry) + return len(entries) + + def get_api_by_full_name( + self, + full_name: str, + *, + enabled_only: bool = True, + session_id: Optional[str] = None, + ) -> Optional[APIEntry]: + """按完整名查询单个 API。 + + Args: + full_name: API 完整名,格式为 ``plugin_id.api_name``。 + enabled_only: 是否仅返回启用状态的 API。 + session_id: 可选的会话 ID。 + + Returns: + Optional[APIEntry]: 命中时返回 API 条目。 + """ + + entry = self._apis.get(full_name) + if entry is None: + return None + if enabled_only and not self.check_api_enabled(entry, session_id): + return None + return entry + + def get_api( + self, + plugin_id: str, + name: str, + *, + enabled_only: bool = True, + session_id: Optional[str] = None, + ) -> Optional[APIEntry]: + """按插件 ID 和短名查询单个 API。 + + Args: + plugin_id: 提供方插件 ID。 + name: API 短名。 + enabled_only: 是否仅返回启用状态的 API。 + session_id: 可选的会话 ID。 + + Returns: + Optional[APIEntry]: 命中时返回 API 条目。 + """ + + return self.get_api_by_full_name( + f"{plugin_id}.{name}", + enabled_only=enabled_only, + session_id=session_id, + ) + + def get_apis( + self, + *, + plugin_id: Optional[str] = None, + name: str = "", + enabled_only: bool = True, + session_id: Optional[str] = None, + ) -> List[APIEntry]: + """查询 API 列表。 + + Args: + plugin_id: 可选的插件 ID 过滤条件。 + name: 可选的 API 名称过滤条件。 + enabled_only: 是否仅返回启用状态的 API。 + session_id: 可选的会话 ID。 + + Returns: + List[APIEntry]: 符合条件的 API 条目列表。 + """ + + normalized_name = str(name or "").strip() + if plugin_id: + candidates = list(self._by_plugin.get(plugin_id, [])) + elif normalized_name: + candidates = list(self._by_name.get(normalized_name, [])) + else: + candidates = list(self._apis.values()) + + filtered_entries: List[APIEntry] = [] + for entry in candidates: + if normalized_name and entry.name != normalized_name: + continue + if enabled_only and not self.check_api_enabled(entry, session_id): + continue + filtered_entries.append(entry) + return filtered_entries + + def toggle_api_status(self, full_name: str, enabled: bool, session_id: Optional[str] = None) -> bool: + """设置指定 API 的启用状态。 + + Args: + full_name: API 完整名。 + enabled: 目标启用状态。 + session_id: 可选的会话 ID,仅对该会话生效。 + + Returns: + bool: 是否设置成功。 + """ + + entry = self._apis.get(full_name) + if entry is None: + return False + if session_id: + if enabled: + entry.disabled_session.discard(session_id) + else: + entry.disabled_session.add(session_id) + else: + entry.enabled = enabled + return True diff --git a/src/plugin_runtime/host/supervisor.py b/src/plugin_runtime/host/supervisor.py index 4a9885f8..1add64c6 100644 --- a/src/plugin_runtime/host/supervisor.py +++ b/src/plugin_runtime/host/supervisor.py @@ -34,6 +34,7 @@ from src.plugin_runtime.protocol.errors import ErrorCode, RPCError from src.plugin_runtime.transport.factory import create_transport_server from .authorization import AuthorizationManager +from .api_registry import APIRegistry from .capability_service import CapabilityService from .component_registry import ComponentRegistry from .event_dispatcher import EventDispatcher @@ -93,6 +94,7 @@ class PluginRunnerSupervisor: self._transport = create_transport_server(socket_path=socket_path) self._authorization = AuthorizationManager() self._capability_service = CapabilityService(self._authorization) + self._api_registry = APIRegistry() self._component_registry = ComponentRegistry() self._event_dispatcher = EventDispatcher(self._component_registry) self._hook_dispatcher = HookDispatcher(self._component_registry) @@ -124,6 +126,11 @@ class PluginRunnerSupervisor: """返回能力服务。""" return self._capability_service + @property + def api_registry(self) -> APIRegistry: + """返回 API 专用注册表。""" + return self._api_registry + @property def component_registry(self) -> ComponentRegistry: """返回组件注册表。""" @@ -310,6 +317,33 @@ class PluginRunnerSupervisor: timeout_ms=timeout_ms, ) + async def invoke_api( + self, + plugin_id: str, + component_name: str, + args: Optional[Dict[str, Any]] = None, + timeout_ms: int = 30000, + ) -> Envelope: + """调用插件声明的 API 方法。 + + Args: + plugin_id: 目标插件 ID。 + component_name: API 组件名称。 + args: 传递给 API 方法的关键字参数。 + timeout_ms: RPC 超时时间,单位毫秒。 + + Returns: + Envelope: Runner 返回的响应信封。 + """ + + return await self.invoke_plugin( + method="plugin.invoke_api", + plugin_id=plugin_id, + component_name=component_name, + args=args, + timeout_ms=timeout_ms, + ) + async def reload_plugin(self, plugin_id: str, reason: str = "manual") -> bool: """按插件 ID 触发精确重载。 @@ -507,13 +541,17 @@ class PluginRunnerSupervisor: except Exception as exc: return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc)) + component_declarations = [component.model_dump() for component in payload.components] + runtime_components, api_components = self._split_component_declarations(component_declarations) self._component_registry.remove_components_by_plugin(payload.plugin_id) + self._api_registry.remove_apis_by_plugin(payload.plugin_id) await self._unregister_all_message_gateway_drivers_for_plugin(payload.plugin_id) registered_count = self._component_registry.register_plugin_components( payload.plugin_id, - [component.model_dump() for component in payload.components], + runtime_components, ) + registered_api_count = self._api_registry.register_plugin_apis(payload.plugin_id, api_components) self._registered_plugins[payload.plugin_id] = payload self._message_gateway_states[payload.plugin_id] = {} @@ -522,6 +560,7 @@ class PluginRunnerSupervisor: "accepted": True, "plugin_id": payload.plugin_id, "registered_components": registered_count, + "registered_apis": registered_api_count, "message_gateways": len( self._component_registry.get_message_gateways(plugin_id=payload.plugin_id, enabled_only=False) ), @@ -543,6 +582,7 @@ class PluginRunnerSupervisor: return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc)) removed_components = self._component_registry.remove_components_by_plugin(payload.plugin_id) + removed_apis = self._api_registry.remove_apis_by_plugin(payload.plugin_id) self._authorization.revoke_permission_token(payload.plugin_id) removed_registration = self._registered_plugins.pop(payload.plugin_id, None) is not None await self._unregister_all_message_gateway_drivers_for_plugin(payload.plugin_id) @@ -554,10 +594,48 @@ class PluginRunnerSupervisor: "plugin_id": payload.plugin_id, "reason": payload.reason, "removed_components": removed_components, + "removed_apis": removed_apis, "removed_registration": removed_registration, } ) + @staticmethod + def _is_api_component(component: Dict[str, Any]) -> bool: + """判断组件声明是否属于 API。 + + Args: + component: 原始组件声明字典。 + + Returns: + bool: 是否为 API 组件。 + """ + + return str(component.get("component_type", "") or "").strip().upper() == "API" + + def _split_component_declarations( + self, + components: List[Dict[str, Any]], + ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: + """拆分通用组件声明和 API 声明。 + + Args: + components: Runner 上报的原始组件声明列表。 + + Returns: + Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: + 第一个列表为需要进入通用组件表的声明, + 第二个列表为需要进入 API 专用表的声明。 + """ + + runtime_components: List[Dict[str, Any]] = [] + api_components: List[Dict[str, Any]] = [] + for component in components: + if self._is_api_component(component): + api_components.append(component) + else: + runtime_components.append(component) + return runtime_components, api_components + @staticmethod def _build_message_gateway_driver_id(plugin_id: str, gateway_name: str) -> str: """构造消息网关驱动 ID。 @@ -1172,6 +1250,7 @@ class PluginRunnerSupervisor: def _clear_runner_state(self) -> None: """清理当前 Runner 对应的 Host 侧注册状态。""" self._authorization.clear() + self._api_registry.clear() self._component_registry.clear() self._registered_plugins.clear() self._message_gateway_states.clear() diff --git a/src/plugin_runtime/runner/runner_main.py b/src/plugin_runtime/runner/runner_main.py index 3a50e2f7..4bee714c 100644 --- a/src/plugin_runtime/runner/runner_main.py +++ b/src/plugin_runtime/runner/runner_main.py @@ -291,6 +291,7 @@ class PluginRunner: """注册 Host -> Runner 的方法处理器。""" self._rpc_client.register_method("plugin.invoke_command", self._handle_invoke) self._rpc_client.register_method("plugin.invoke_action", self._handle_invoke) + self._rpc_client.register_method("plugin.invoke_api", self._handle_invoke) self._rpc_client.register_method("plugin.invoke_tool", self._handle_invoke) self._rpc_client.register_method("plugin.invoke_message_gateway", self._handle_invoke) self._rpc_client.register_method("plugin.emit_event", self._handle_event_invoke) From d13767ee21ea762c3006a9d74b161b69a098dd00 Mon Sep 17 00:00:00 2001 From: DrSmoothl <1787882683@qq.com> Date: Mon, 23 Mar 2026 20:06:12 +0800 Subject: [PATCH 34/45] feat: Enhance plugin configuration management and SDK integration - Add support for configuration reload scopes in the plugin runtime. - Implement validation for SDK plugins to ensure required lifecycle methods are overridden. - Update the configuration update handling to include scope information. - Introduce tests for expression auto-check task and NapCat adapter SDK integration. - Refactor configuration management to support callbacks with variable arguments. - Improve plugin loading and error handling for configuration updates. - Ensure that plugins can manage their own configuration updates effectively. --- plugins/ChatFrequency/plugin.py | 29 +- plugins/emoji_manage_plugin/plugin.py | 31 ++- plugins/hello_world_plugin/plugin.py | 33 ++- pyproject.toml | 2 + .../test_expression_auto_check_task.py | 89 ++++++ pytests/test_napcat_adapter_sdk.py | 132 +++++++++ pytests/test_plugin_runtime.py | 263 ++++++++++++++++-- src/config/config.py | 135 ++++++++- src/learners/expression_auto_check_task.py | 7 +- src/plugin_runtime/host/supervisor.py | 19 ++ src/plugin_runtime/integration.py | 99 +++++-- src/plugin_runtime/protocol/envelope.py | 12 + src/plugin_runtime/runner/plugin_loader.py | 30 ++ src/plugin_runtime/runner/runner_main.py | 33 ++- src/plugins/built_in/emoji_plugin/plugin.py | 35 ++- .../built_in/plugin_management/plugin.py | 29 +- 16 files changed, 907 insertions(+), 71 deletions(-) create mode 100644 pytests/common_test/test_expression_auto_check_task.py create mode 100644 pytests/test_napcat_adapter_sdk.py diff --git a/plugins/ChatFrequency/plugin.py b/plugins/ChatFrequency/plugin.py index b3f69384..0e9f5a0c 100644 --- a/plugins/ChatFrequency/plugin.py +++ b/plugins/ChatFrequency/plugin.py @@ -3,12 +3,18 @@ 通过 /chat 命令设置和查看聊天频率。 """ -from maibot_sdk import MaiBotPlugin, Command +from maibot_sdk import Command, MaiBotPlugin class BetterFrequencyPlugin(MaiBotPlugin): """聊天频率控制插件""" + async def on_load(self) -> None: + """处理插件加载。""" + + async def on_unload(self) -> None: + """处理插件卸载。""" + @Command( "set_talk_frequency", description="设置当前聊天的talk_frequency值:/chat talk_frequency <数字> 或 /chat t <数字>", @@ -80,6 +86,25 @@ class BetterFrequencyPlugin(MaiBotPlugin): await self.ctx.send.text(status_msg, stream_id) return True, None, False + async def on_config_update(self, scope: str, config_data: dict[str, object], version: str) -> None: + """处理配置热重载事件。 + + Args: + scope: 配置变更范围。 + config_data: 最新配置数据。 + version: 配置版本号。 + """ + + del scope + del config_data + del version + + +def create_plugin() -> BetterFrequencyPlugin: + """创建聊天频率插件实例。 + + Returns: + BetterFrequencyPlugin: 新的聊天频率插件实例。 + """ -def create_plugin(): return BetterFrequencyPlugin() diff --git a/plugins/emoji_manage_plugin/plugin.py b/plugins/emoji_manage_plugin/plugin.py index f3c5f677..9362c828 100644 --- a/plugins/emoji_manage_plugin/plugin.py +++ b/plugins/emoji_manage_plugin/plugin.py @@ -3,17 +3,23 @@ 通过 /emoji 命令管理表情包的添加、列表和删除。 """ +from maibot_sdk import Command, MaiBotPlugin + import base64 import datetime import hashlib import re -from maibot_sdk import MaiBotPlugin, Command - class EmojiManagePlugin(MaiBotPlugin): """表情包管理插件""" + async def on_load(self) -> None: + """处理插件加载。""" + + async def on_unload(self) -> None: + """处理插件卸载。""" + # ===== 工具方法 ===== @staticmethod @@ -208,6 +214,25 @@ class EmojiManagePlugin(MaiBotPlugin): await self.ctx.send.forward(messages, stream_id) return True, "已发送随机表情包", True + async def on_config_update(self, scope: str, config_data: dict[str, object], version: str) -> None: + """处理配置热重载事件。 + + Args: + scope: 配置变更范围。 + config_data: 最新配置数据。 + version: 配置版本号。 + """ + + del scope + del config_data + del version + + +def create_plugin() -> EmojiManagePlugin: + """创建表情包管理插件实例。 + + Returns: + EmojiManagePlugin: 新的表情包管理插件实例。 + """ -def create_plugin(): return EmojiManagePlugin() diff --git a/plugins/hello_world_plugin/plugin.py b/plugins/hello_world_plugin/plugin.py index fbba9d10..4d1f37af 100644 --- a/plugins/hello_world_plugin/plugin.py +++ b/plugins/hello_world_plugin/plugin.py @@ -3,16 +3,22 @@ 你的第一个 MaiCore 插件,包含问候功能、时间查询等基础示例。 """ +from maibot_sdk import Action, Command, EventHandler, MaiBotPlugin, Tool +from maibot_sdk.types import ActivationType, EventType, ToolParameterInfo, ToolParamType + import datetime import random -from maibot_sdk import MaiBotPlugin, Action, Command, Tool, EventHandler -from maibot_sdk.types import ActivationType, EventType, ToolParameterInfo, ToolParamType - class HelloWorldPlugin(MaiBotPlugin): """Hello World 示例插件""" + async def on_load(self) -> None: + """处理插件加载。""" + + async def on_unload(self) -> None: + """处理插件卸载。""" + # ===== Tool 组件 ===== @Tool( @@ -146,6 +152,25 @@ class HelloWorldPlugin(MaiBotPlugin): return True, True, None, None, None + async def on_config_update(self, scope: str, config_data: dict[str, object], version: str) -> None: + """处理配置热重载事件。 + + Args: + scope: 配置变更范围。 + config_data: 最新配置数据。 + version: 配置版本号。 + """ + + del scope + del config_data + del version + + +def create_plugin() -> HelloWorldPlugin: + """创建 Hello World 示例插件实例。 + + Returns: + HelloWorldPlugin: 新的示例插件实例。 + """ -def create_plugin(): return HelloWorldPlugin() diff --git a/pyproject.toml b/pyproject.toml index 9887ac24..95c92acd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,8 @@ dev = [ [tool.uv] index-url = "https://pypi.tuna.tsinghua.edu.cn/simple" +[tool.uv.sources] +maibot-plugin-sdk = { path = "packages/maibot-plugin-sdk", editable = true } [tool.ruff] diff --git a/pytests/common_test/test_expression_auto_check_task.py b/pytests/common_test/test_expression_auto_check_task.py new file mode 100644 index 00000000..da8c59e1 --- /dev/null +++ b/pytests/common_test/test_expression_auto_check_task.py @@ -0,0 +1,89 @@ +"""测试表达方式自动检查任务的数据库读取行为。""" + +from contextlib import contextmanager +from typing import Generator + +import pytest +from sqlalchemy.pool import StaticPool +from sqlmodel import Session, SQLModel, create_engine + +from src.bw_learner.expression_auto_check_task import ExpressionAutoCheckTask +from src.common.database.database_model import Expression + + +@pytest.fixture(name="expression_auto_check_engine") +def expression_auto_check_engine_fixture() -> Generator: + """创建用于表达方式自动检查任务测试的内存数据库引擎。 + + Yields: + Generator: 供测试使用的 SQLite 内存引擎。 + """ + + engine = create_engine( + "sqlite://", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + SQLModel.metadata.create_all(engine) + yield engine + + +@pytest.mark.asyncio +async def test_select_expressions_uses_read_only_session( + monkeypatch: pytest.MonkeyPatch, + expression_auto_check_engine, +) -> None: + """选择表达方式时应使用只读会话,并在离开会话后安全读取 ORM 字段。""" + + import src.bw_learner.expression_auto_check_task as expression_auto_check_task_module + + with Session(expression_auto_check_engine) as session: + session.add( + Expression( + situation="表达情绪高涨或生理反应", + style="发送💦表情符号", + content_list='["表达情绪高涨或生理反应"]', + count=1, + session_id="session-a", + checked=False, + rejected=False, + ) + ) + session.commit() + + auto_commit_calls: list[bool] = [] + + @contextmanager + def fake_get_db_session(auto_commit: bool = True) -> Generator[Session, None, None]: + """构造带自动提交语义的测试会话工厂。 + + Args: + auto_commit: 退出上下文时是否自动提交。 + + Yields: + Generator[Session, None, None]: SQLModel 会话对象。 + """ + + auto_commit_calls.append(auto_commit) + session = Session(expression_auto_check_engine) + try: + yield session + if auto_commit: + session.commit() + except Exception: + session.rollback() + raise + finally: + session.close() + + monkeypatch.setattr(expression_auto_check_task_module, "get_db_session", fake_get_db_session) + monkeypatch.setattr(expression_auto_check_task_module.random, "sample", lambda entries, _count: list(entries)) + + task = ExpressionAutoCheckTask() + expressions = await task._select_expressions(1) + + assert auto_commit_calls == [False] + assert len(expressions) == 1 + assert expressions[0].id is not None + assert expressions[0].situation == "表达情绪高涨或生理反应" + assert expressions[0].style == "发送💦表情符号" diff --git a/pytests/test_napcat_adapter_sdk.py b/pytests/test_napcat_adapter_sdk.py new file mode 100644 index 00000000..c6b1fdbd --- /dev/null +++ b/pytests/test_napcat_adapter_sdk.py @@ -0,0 +1,132 @@ +"""NapCat 插件与新 SDK 对接测试。""" + +from pathlib import Path +from typing import Any, Dict, List + +import importlib +import logging +import sys + +import pytest + +PROJECT_ROOT = Path(__file__).resolve().parents[1] +PLUGINS_ROOT = PROJECT_ROOT / "plugins" +SDK_ROOT = PROJECT_ROOT / "packages" / "maibot-plugin-sdk" + +for import_path in (str(PLUGINS_ROOT), str(SDK_ROOT)): + if import_path not in sys.path: + sys.path.insert(0, import_path) + + +class _FakeGatewayCapability: + """用于捕获消息网关状态上报的测试替身。""" + + def __init__(self) -> None: + """初始化测试替身。""" + + self.calls: List[Dict[str, Any]] = [] + + async def update_state( + self, + gateway_name: str, + *, + ready: bool, + platform: str = "", + account_id: str = "", + scope: str = "", + metadata: Dict[str, Any] | None = None, + ) -> bool: + """记录一次状态上报请求。 + + Args: + gateway_name: 网关组件名称。 + ready: 当前是否就绪。 + platform: 平台名称。 + account_id: 账号 ID。 + scope: 路由作用域。 + metadata: 附加元数据。 + + Returns: + bool: 始终返回 ``True``,模拟 Host 接受状态更新。 + """ + + self.calls.append( + { + "gateway_name": gateway_name, + "ready": ready, + "platform": platform, + "account_id": account_id, + "scope": scope, + "metadata": metadata or {}, + } + ) + return True + + +def _load_napcat_sdk_symbols() -> tuple[Any, Any, Any, Any]: + """动态加载 NapCat 插件测试所需的符号。 + + Returns: + tuple[Any, Any, Any, Any]: + 依次返回网关名常量、配置类、插件类和运行时状态管理器类。 + """ + + constants_module = importlib.import_module("napcat_adapter.constants") + config_module = importlib.import_module("napcat_adapter.config") + plugin_module = importlib.import_module("napcat_adapter.plugin") + runtime_state_module = importlib.import_module("napcat_adapter.runtime_state") + return ( + constants_module.NAPCAT_GATEWAY_NAME, + config_module.NapCatServerConfig, + plugin_module.NapCatAdapterPlugin, + runtime_state_module.NapCatRuntimeStateManager, + ) + + +def test_napcat_plugin_collects_duplex_message_gateway() -> None: + """NapCat 插件应声明新的双工消息网关组件。""" + + napcat_gateway_name, _napcat_server_config, napcat_plugin_cls, _runtime_state_cls = _load_napcat_sdk_symbols() + plugin = napcat_plugin_cls() + components = plugin.get_components() + gateway_components = [ + component + for component in components + if component.get("type") == "MESSAGE_GATEWAY" + ] + + assert len(gateway_components) == 1 + gateway_component = gateway_components[0] + assert gateway_component["name"] == napcat_gateway_name + assert gateway_component["metadata"]["route_type"] == "duplex" + assert gateway_component["metadata"]["platform"] == "qq" + assert gateway_component["metadata"]["protocol"] == "napcat" + + +@pytest.mark.asyncio +async def test_runtime_state_reports_via_gateway_capability() -> None: + """NapCat 运行时状态应通过新的消息网关能力上报。""" + + napcat_gateway_name, napcat_server_config_cls, _napcat_plugin_cls, runtime_state_cls = _load_napcat_sdk_symbols() + gateway_capability = _FakeGatewayCapability() + runtime_state_manager = runtime_state_cls( + gateway_capability=gateway_capability, + logger=logging.getLogger("test.napcat_adapter"), + gateway_name=napcat_gateway_name, + ) + + connected = await runtime_state_manager.report_connected( + "10001", + napcat_server_config_cls(connection_id="primary"), + ) + await runtime_state_manager.report_disconnected() + + assert connected is True + assert gateway_capability.calls[0]["gateway_name"] == napcat_gateway_name + assert gateway_capability.calls[0]["ready"] is True + assert gateway_capability.calls[0]["platform"] == "qq" + assert gateway_capability.calls[0]["account_id"] == "10001" + assert gateway_capability.calls[0]["scope"] == "primary" + assert gateway_capability.calls[1]["gateway_name"] == napcat_gateway_name + assert gateway_capability.calls[1]["ready"] is False + assert gateway_capability.calls[1]["platform"] == "qq" diff --git a/pytests/test_plugin_runtime.py b/pytests/test_plugin_runtime.py index 9dfc34d8..9b46f897 100644 --- a/pytests/test_plugin_runtime.py +++ b/pytests/test_plugin_runtime.py @@ -441,8 +441,8 @@ class TestSDK: def set_plugin_config(self, config): self.configs.append(config) - async def on_config_update(self, config, version): - self.updates.append((config, version, list(self.configs))) + async def on_config_update(self, scope, config, version): + self.updates.append((scope, config, version, list(self.configs))) runner = PluginRunner(host_address="dummy", session_token="token", plugin_dirs=[]) plugin = DummyPlugin() @@ -453,14 +453,60 @@ class TestSDK: message_type=MessageType.REQUEST, method="plugin.config_updated", plugin_id="demo_plugin", - payload={"config_data": {"enabled": True}, "config_version": "v2"}, + payload={ + "plugin_id": "demo_plugin", + "config_scope": "self", + "config_data": {"enabled": True}, + "config_version": "v2", + }, ) response = await runner._handle_config_updated(envelope) assert response.payload["acknowledged"] is True assert plugin.configs == [{"enabled": True}] - assert plugin.updates == [({"enabled": True}, "v2", [{"enabled": True}])] + assert plugin.updates == [("self", {"enabled": True}, "v2", [{"enabled": True}])] + + @pytest.mark.asyncio + async def test_runner_global_config_update_does_not_override_plugin_config(self): + """bot/model 广播不应覆盖插件自身配置缓存。""" + from src.plugin_runtime.protocol.envelope import Envelope, MessageType + from src.plugin_runtime.runner.runner_main import PluginRunner + + class DummyPlugin: + def __init__(self): + self.configs = [] + self.updates = [] + + def set_plugin_config(self, config): + self.configs.append(config) + + async def on_config_update(self, scope, config, version): + self.updates.append((scope, config, version, list(self.configs))) + + runner = PluginRunner(host_address="dummy", session_token="token", plugin_dirs=[]) + plugin = DummyPlugin() + runner._loader._loaded_plugins["demo_plugin"] = SimpleNamespace(instance=plugin) + plugin.set_plugin_config({"plugin_enabled": True}) + + envelope = Envelope( + request_id=1, + message_type=MessageType.REQUEST, + method="plugin.config_updated", + plugin_id="demo_plugin", + payload={ + "plugin_id": "demo_plugin", + "config_scope": "model", + "config_data": {"models": []}, + "config_version": "", + }, + ) + + response = await runner._handle_config_updated(envelope) + + assert response.payload["acknowledged"] is True + assert plugin.configs == [{"plugin_enabled": True}] + assert plugin.updates == [("model", {"models": []}, "", [{"plugin_enabled": True}])] @pytest.mark.asyncio async def test_runner_bootstraps_capabilities_before_on_load(self, monkeypatch): @@ -911,6 +957,120 @@ class TestDependencyResolution: assert loader.failed_plugins == {} assert loaded[0].instance.answer() == 42 + def test_loader_requires_sdk_plugin_to_override_on_config_update(self, tmp_path): + from src.plugin_runtime.runner.plugin_loader import PluginLoader + + plugin_root = tmp_path / "plugins" + plugin_root.mkdir() + plugin_dir = plugin_root / "demo_plugin" + plugin_dir.mkdir() + + (plugin_dir / "_manifest.json").write_text( + json.dumps( + { + "name": "demo_plugin", + "version": "1.0.0", + "description": "demo", + "author": "tester", + } + ), + encoding="utf-8", + ) + (plugin_dir / "plugin.py").write_text( + "from maibot_sdk import MaiBotPlugin\n\n" + "class DemoPlugin(MaiBotPlugin):\n" + " async def on_load(self):\n" + " pass\n\n" + " async def on_unload(self):\n" + " pass\n\n" + "def create_plugin():\n" + " return DemoPlugin()\n", + encoding="utf-8", + ) + + loader = PluginLoader() + loaded = loader.discover_and_load([str(plugin_root)]) + + assert loaded == [] + assert "demo_plugin" in loader.failed_plugins + assert "on_config_update" in loader.failed_plugins["demo_plugin"] + + def test_loader_requires_sdk_plugin_to_override_on_load(self, tmp_path): + from src.plugin_runtime.runner.plugin_loader import PluginLoader + + plugin_root = tmp_path / "plugins" + plugin_root.mkdir() + plugin_dir = plugin_root / "demo_plugin" + plugin_dir.mkdir() + + (plugin_dir / "_manifest.json").write_text( + json.dumps( + { + "name": "demo_plugin", + "version": "1.0.0", + "description": "demo", + "author": "tester", + } + ), + encoding="utf-8", + ) + (plugin_dir / "plugin.py").write_text( + "from maibot_sdk import MaiBotPlugin\n\n" + "class DemoPlugin(MaiBotPlugin):\n" + " async def on_unload(self):\n" + " pass\n\n" + " async def on_config_update(self, scope, config_data, version):\n" + " pass\n\n" + "def create_plugin():\n" + " return DemoPlugin()\n", + encoding="utf-8", + ) + + loader = PluginLoader() + loaded = loader.discover_and_load([str(plugin_root)]) + + assert loaded == [] + assert "demo_plugin" in loader.failed_plugins + assert "on_load" in loader.failed_plugins["demo_plugin"] + + def test_loader_requires_sdk_plugin_to_override_on_unload(self, tmp_path): + from src.plugin_runtime.runner.plugin_loader import PluginLoader + + plugin_root = tmp_path / "plugins" + plugin_root.mkdir() + plugin_dir = plugin_root / "demo_plugin" + plugin_dir.mkdir() + + (plugin_dir / "_manifest.json").write_text( + json.dumps( + { + "name": "demo_plugin", + "version": "1.0.0", + "description": "demo", + "author": "tester", + } + ), + encoding="utf-8", + ) + (plugin_dir / "plugin.py").write_text( + "from maibot_sdk import MaiBotPlugin\n\n" + "class DemoPlugin(MaiBotPlugin):\n" + " async def on_load(self):\n" + " pass\n\n" + " async def on_config_update(self, scope, config_data, version):\n" + " pass\n\n" + "def create_plugin():\n" + " return DemoPlugin()\n", + encoding="utf-8", + ) + + loader = PluginLoader() + loaded = loader.discover_and_load([str(plugin_root)]) + + assert loaded == [] + assert "demo_plugin" in loader.failed_plugins + assert "on_unload" in loader.failed_plugins["demo_plugin"] + def test_isolate_sys_path_preserves_plugin_dirs(self): from src.plugin_runtime.runner import runner_main @@ -2299,9 +2459,10 @@ class TestIntegration: assert refresh_calls == [True] @pytest.mark.asyncio - async def test_handle_plugin_config_changes_only_reload_target_plugin(self, monkeypatch, tmp_path): + async def test_handle_plugin_config_changes_only_notify_target_plugin(self, monkeypatch, tmp_path): from src.plugin_runtime import integration as integration_module from src.config.file_watcher import FileChange + import json builtin_root = tmp_path / "src" / "plugins" / "built_in" thirdparty_root = tmp_path / "plugins" @@ -2311,6 +2472,10 @@ class TestIntegration: beta_dir.mkdir(parents=True) (alpha_dir / "config.toml").write_text("enabled = true\n", encoding="utf-8") (beta_dir / "config.toml").write_text("enabled = false\n", encoding="utf-8") + (alpha_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8") + (beta_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8") + (alpha_dir / "_manifest.json").write_text(json.dumps({"name": "alpha"}), encoding="utf-8") + (beta_dir / "_manifest.json").write_text(json.dumps({"name": "beta"}), encoding="utf-8") monkeypatch.chdir(tmp_path) @@ -2318,31 +2483,95 @@ class TestIntegration: def __init__(self, plugin_dirs, plugins): self._plugin_dirs = plugin_dirs self._registered_plugins = {plugin_id: object() for plugin_id in plugins} - self.reload_calls = [] + self.config_updates = [] - async def reload_plugin(self, plugin_id, reason="manual"): - self.reload_calls.append((plugin_id, reason)) + async def notify_plugin_config_updated( + self, + plugin_id, + config_data, + config_version="", + config_scope="self", + ): + self.config_updates.append((plugin_id, config_data, config_version, config_scope)) return True manager = integration_module.PluginRuntimeManager() manager._started = True manager._builtin_supervisor = FakeSupervisor([builtin_root], ["alpha"]) manager._third_party_supervisor = FakeSupervisor([thirdparty_root], ["beta"]) - refresh_calls = [] - - def fake_refresh() -> None: - refresh_calls.append(True) - - manager._refresh_plugin_config_watch_subscriptions = fake_refresh await manager._handle_plugin_config_changes( "alpha", [FileChange(change_type=1, path=alpha_dir / "config.toml")], ) - assert manager._builtin_supervisor.reload_calls == [("alpha", "config_file_changed")] - assert manager._third_party_supervisor.reload_calls == [] - assert refresh_calls == [True] + assert manager._builtin_supervisor.config_updates == [("alpha", {"enabled": True}, "", "self")] + assert manager._third_party_supervisor.config_updates == [] + + @pytest.mark.asyncio + async def test_handle_main_config_reload_only_notifies_subscribers(self, monkeypatch): + from src.plugin_runtime import integration as integration_module + + class FakeRegistration: + def __init__(self, subscriptions): + self.config_reload_subscriptions = subscriptions + + class FakeSupervisor: + def __init__(self, registrations): + self._registered_plugins = registrations + self.config_updates = [] + + def get_config_reload_subscribers(self, scope): + matched_plugins = [] + for plugin_id, registration in self._registered_plugins.items(): + if scope in registration.config_reload_subscriptions: + matched_plugins.append(plugin_id) + return matched_plugins + + async def notify_plugin_config_updated( + self, + plugin_id, + config_data, + config_version="", + config_scope="self", + ): + self.config_updates.append((plugin_id, config_data, config_version, config_scope)) + return True + + fake_global = SimpleNamespace(plugin_runtime=SimpleNamespace(enabled=True)) + monkeypatch.setattr( + integration_module.config_manager, + "get_global_config", + lambda: SimpleNamespace(model_dump=lambda: {"bot": {"name": "MaiBot"}}, plugin_runtime=fake_global.plugin_runtime), + ) + monkeypatch.setattr( + integration_module.config_manager, + "get_model_config", + lambda: SimpleNamespace(model_dump=lambda: {"models": [{"name": "demo"}]}), + ) + + manager = integration_module.PluginRuntimeManager() + manager._started = True + manager._builtin_supervisor = FakeSupervisor( + { + "alpha": FakeRegistration(["bot"]), + "beta": FakeRegistration([]), + } + ) + manager._third_party_supervisor = FakeSupervisor( + { + "gamma": FakeRegistration(["model"]), + } + ) + + await manager._handle_main_config_reload(["bot", "model"]) + + assert manager._builtin_supervisor.config_updates == [ + ("alpha", {"bot": {"name": "MaiBot"}}, "", "bot") + ] + assert manager._third_party_supervisor.config_updates == [ + ("gamma", {"models": [{"name": "demo"}]}, "", "model") + ] def test_refresh_plugin_config_watch_subscriptions_registers_per_plugin(self, tmp_path): from src.plugin_runtime import integration as integration_module diff --git a/src/config/config.py b/src/config/config.py index ff5941bf..bee81efb 100644 --- a/src/config/config.py +++ b/src/config/config.py @@ -4,6 +4,7 @@ from typing import Any, Callable, Mapping, Sequence, TypeVar import asyncio import copy +import inspect import sys import tomlkit @@ -61,6 +62,7 @@ MODEL_CONFIG_VERSION: str = "1.12.0" logger = get_logger("config") T = TypeVar("T", bound="ConfigBase") +ConfigReloadCallback = Callable[[Sequence[str]], object] | Callable[[], object] class Config(ConfigBase): @@ -190,7 +192,7 @@ class ConfigManager: self.global_config: Config | None = None self.model_config: ModelConfig | None = None self._reload_lock: asyncio.Lock = asyncio.Lock() - self._reload_callbacks: list[Callable[[], object]] = [] + self._reload_callbacks: list[ConfigReloadCallback] = [] self._file_watcher: FileWatcher | None = None self._file_watcher_subscription_id: str | None = None self._hot_reload_min_interval_s: float = 1.0 @@ -226,16 +228,125 @@ class ConfigManager: raise RuntimeError(t("config.model_not_initialized")) return self.model_config - def register_reload_callback(self, callback: Callable[[], object]) -> None: + def register_reload_callback(self, callback: ConfigReloadCallback) -> None: + """注册配置热重载回调。 + + Args: + callback: 配置热重载回调。允许无参回调,也允许接收 + ``Sequence[str]`` 类型的变更范围列表。 + """ + self._reload_callbacks.append(callback) - def unregister_reload_callback(self, callback: Callable[[], object]) -> None: + def unregister_reload_callback(self, callback: ConfigReloadCallback) -> None: + """注销配置热重载回调。 + + Args: + callback: 先前注册过的回调对象。 + """ + try: self._reload_callbacks.remove(callback) except ValueError: return - async def reload_config(self) -> bool: + @staticmethod + def _normalize_changed_scopes(changed_scopes: Sequence[str] | None) -> tuple[str, ...]: + """规范化配置变更范围列表。 + + Args: + changed_scopes: 原始配置变更范围。 + + Returns: + tuple[str, ...]: 去重后的配置变更范围元组。 + """ + + if not changed_scopes: + return ("bot", "model") + + normalized_scopes: list[str] = [] + for scope in changed_scopes: + normalized_scope = str(scope or "").strip().lower() + if normalized_scope not in {"bot", "model"}: + continue + if normalized_scope not in normalized_scopes: + normalized_scopes.append(normalized_scope) + return tuple(normalized_scopes) + + @staticmethod + def _resolve_changed_scopes(changes: Sequence[FileChange]) -> tuple[str, ...]: + """根据文件变更列表推断配置变更范围。 + + Args: + changes: 文件监听器返回的变更列表。 + + Returns: + tuple[str, ...]: 命中的配置变更范围元组。 + """ + + changed_scopes: list[str] = [] + for change in changes: + file_name = change.path.name + if file_name == "bot_config.toml" and "bot" not in changed_scopes: + changed_scopes.append("bot") + if file_name == "model_config.toml" and "model" not in changed_scopes: + changed_scopes.append("model") + return tuple(changed_scopes) + + @staticmethod + def _callback_accepts_scopes(callback: ConfigReloadCallback) -> bool: + """判断回调是否接收配置变更范围参数。 + + Args: + callback: 待检测的回调对象。 + + Returns: + bool: 若回调可接收一个位置参数或可变位置参数,则返回 ``True``。 + """ + + try: + parameters = inspect.signature(callback).parameters.values() + except (TypeError, ValueError): + return False + + positional_params = { + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + } + for parameter in parameters: + if parameter.kind == inspect.Parameter.VAR_POSITIONAL: + return True + if parameter.kind in positional_params: + return True + return False + + async def _invoke_reload_callback( + self, + callback: ConfigReloadCallback, + changed_scopes: Sequence[str], + ) -> None: + """执行单个配置热重载回调。 + + Args: + callback: 要执行的回调对象。 + changed_scopes: 本次热重载命中的配置范围。 + """ + + result = callback(changed_scopes) if self._callback_accepts_scopes(callback) else callback() + if asyncio.iscoroutine(result): + await result + + async def reload_config(self, changed_scopes: Sequence[str] | None = None) -> bool: + """重新加载主配置和模型配置。 + + Args: + changed_scopes: 本次触发热重载的配置范围。 + + Returns: + bool: 是否重载成功。 + """ + + normalized_scopes = self._normalize_changed_scopes(changed_scopes) async with self._reload_lock: try: global_config_new, global_updated = load_config_from_file( @@ -265,9 +376,7 @@ class ConfigManager: for callback in list(self._reload_callbacks): try: - result = callback() - if asyncio.iscoroutine(result): - await result + await self._invoke_reload_callback(callback, normalized_scopes) except Exception as exc: logger.warning(t("config.reload_callback_failed", error=exc)) return True @@ -312,6 +421,12 @@ class ConfigManager: self._file_watcher = None async def _handle_file_changes(self, changes: Sequence[FileChange]) -> None: + """处理主配置与模型配置文件变更。 + + Args: + changes: 当前批次收集到的文件变更列表。 + """ + if not changes: return now_monotonic = asyncio.get_running_loop().time() @@ -321,7 +436,11 @@ class ConfigManager: self._last_hot_reload_monotonic = now_monotonic logger.info(t("config.file_change_detected")) try: - await asyncio.wait_for(self.reload_config(), timeout=self._hot_reload_timeout_s) + changed_scopes = self._resolve_changed_scopes(changes) + await asyncio.wait_for( + self.reload_config(changed_scopes=changed_scopes), + timeout=self._hot_reload_timeout_s, + ) except asyncio.TimeoutError: logger.error(t("config.reload_timeout", timeout_seconds=self._hot_reload_timeout_s)) diff --git a/src/learners/expression_auto_check_task.py b/src/learners/expression_auto_check_task.py index 53b151b2..e5af1057 100644 --- a/src/learners/expression_auto_check_task.py +++ b/src/learners/expression_auto_check_task.py @@ -3,15 +3,15 @@ 功能: 1. 定期随机选取指定数量的表达方式 -2. 使用LLM进行评估 +2. 使用 LLM 进行评估 3. 通过评估的:rejected=0, checked=1 4. 未通过评估的:rejected=1, checked=1 """ -from typing import List import asyncio import json import random +from typing import List from sqlmodel import select @@ -146,7 +146,8 @@ class ExpressionAutoCheckTask(AsyncTask): 选中的表达方式列表 """ try: - with get_db_session() as session: + # 这里只做查询,避免退出上下文时自动提交导致 ORM 实例过期。 + with get_db_session(auto_commit=False) as session: statement = select(Expression) all_expressions = session.exec(statement).all() diff --git a/src/plugin_runtime/host/supervisor.py b/src/plugin_runtime/host/supervisor.py index 1add64c6..afe944e5 100644 --- a/src/plugin_runtime/host/supervisor.py +++ b/src/plugin_runtime/host/supervisor.py @@ -399,6 +399,7 @@ class PluginRunnerSupervisor: plugin_id: str, config_data: Optional[Dict[str, Any]] = None, config_version: str = "", + config_scope: str = "self", ) -> bool: """向 Runner 推送插件配置更新。 @@ -406,12 +407,14 @@ class PluginRunnerSupervisor: plugin_id: 目标插件 ID。 config_data: 配置内容。 config_version: 配置版本号。 + config_scope: 配置变更范围。 Returns: bool: 请求是否成功送达并被 Runner 接受。 """ payload = ConfigUpdatedPayload( plugin_id=plugin_id, + config_scope=config_scope, config_version=config_version, config_data=config_data or {}, ) @@ -428,6 +431,22 @@ class PluginRunnerSupervisor: return bool(response.payload.get("acknowledged", False)) + def get_config_reload_subscribers(self, scope: str) -> List[str]: + """返回订阅指定全局配置广播的插件列表。 + + Args: + scope: 配置变更范围,仅支持 ``bot`` 或 ``model``。 + + Returns: + List[str]: 已声明订阅该范围的插件 ID 列表。 + """ + + matched_plugins: List[str] = [] + for plugin_id, registration in self._registered_plugins.items(): + if scope in registration.config_reload_subscriptions: + matched_plugins.append(plugin_id) + return matched_plugins + async def _wait_for_runner_connection(self, timeout_sec: float) -> None: """等待 Runner 建立 RPC 连接。 diff --git a/src/plugin_runtime/integration.py b/src/plugin_runtime/integration.py index ff51f419..e45b40de 100644 --- a/src/plugin_runtime/integration.py +++ b/src/plugin_runtime/integration.py @@ -16,7 +16,7 @@ import json import tomlkit from src.common.logger import get_logger -from src.config.config import global_config +from src.config.config import config_manager from src.config.file_watcher import FileChange, FileWatcher from src.platform_io import DeliveryBatch, InboundMessageEnvelope, get_platform_io_manager from src.plugin_runtime.capabilities import ( @@ -69,6 +69,8 @@ class PluginRuntimeManager( self._plugin_source_watcher_subscription_id: Optional[str] = None self._plugin_config_watcher_subscriptions: Dict[str, Tuple[Path, str]] = {} self._plugin_path_cache: Dict[str, Path] = {} + self._config_reload_callback: Callable[[Sequence[str]], Awaitable[None]] = self._handle_main_config_reload + self._config_reload_callback_registered: bool = False async def _dispatch_platform_inbound(self, envelope: InboundMessageEnvelope) -> None: """接收 Platform IO 审核后的入站消息并送入主消息链。 @@ -108,7 +110,7 @@ class PluginRuntimeManager( logger.warning("PluginRuntimeManager 已在运行中,跳过重复启动") return - _cfg = global_config.plugin_runtime + _cfg = config_manager.get_global_config().plugin_runtime if not _cfg.enabled: logger.info("插件运行时已在配置中禁用,跳过启动") return @@ -166,11 +168,16 @@ class PluginRuntimeManager( await self._third_party_supervisor.start() started_supervisors.append(self._third_party_supervisor) await self._start_plugin_file_watcher() + config_manager.register_reload_callback(self._config_reload_callback) + self._config_reload_callback_registered = True self._started = True logger.info(f"插件运行时已启动 — 内置: {builtin_dirs or '无'}, 第三方: {third_party_dirs or '无'}") except Exception as e: logger.error(f"插件运行时启动失败: {e}", exc_info=True) await self._stop_plugin_file_watcher() + if self._config_reload_callback_registered: + config_manager.unregister_reload_callback(self._config_reload_callback) + self._config_reload_callback_registered = False await asyncio.gather(*(sv.stop() for sv in started_supervisors), return_exceptions=True) platform_io_manager.clear_inbound_dispatcher() try: @@ -188,6 +195,9 @@ class PluginRuntimeManager( platform_io_manager = get_platform_io_manager() await self._stop_plugin_file_watcher() + if self._config_reload_callback_registered: + config_manager.unregister_reload_callback(self._config_reload_callback) + self._config_reload_callback_registered = False coroutines: List[Coroutine[Any, Any, None]] = [] if self._builtin_supervisor: @@ -233,6 +243,7 @@ class PluginRuntimeManager( plugin_id: str, config_data: Optional[Dict[str, Any]] = None, config_version: str = "", + config_scope: str = "self", ) -> bool: """向拥有该插件的 Supervisor 推送配置更新事件。 @@ -240,6 +251,7 @@ class PluginRuntimeManager( plugin_id: 插件 ID config_data: 可选的配置数据(如果为 None 则由 Supervisor 从磁盘加载) config_version: 可选的配置版本字符串,供 Supervisor 进行版本控制 + config_scope: 配置变更范围。 """ if not self._started: return False @@ -258,12 +270,67 @@ class PluginRuntimeManager( if config_data is not None else self._load_plugin_config_for_supervisor(sv, plugin_id) ) - await sv.notify_plugin_config_updated( + return await sv.notify_plugin_config_updated( plugin_id=plugin_id, config_data=config_payload, config_version=config_version, + config_scope=config_scope, ) - return True + + @staticmethod + def _normalize_config_reload_scopes(changed_scopes: Sequence[str]) -> tuple[str, ...]: + """规范化配置热重载范围列表。 + + Args: + changed_scopes: 原始配置热重载范围列表。 + + Returns: + tuple[str, ...]: 去重后的有效配置范围元组。 + """ + + normalized_scopes: list[str] = [] + for scope in changed_scopes: + normalized_scope = str(scope or "").strip().lower() + if normalized_scope not in {"bot", "model"}: + continue + if normalized_scope not in normalized_scopes: + normalized_scopes.append(normalized_scope) + return tuple(normalized_scopes) + + async def _broadcast_config_reload(self, scope: str, config_data: Dict[str, Any]) -> None: + """向订阅指定范围的插件广播配置热重载。 + + Args: + scope: 配置变更范围,仅支持 ``bot`` 或 ``model``。 + config_data: 最新配置数据。 + """ + + for supervisor in self.supervisors: + for plugin_id in supervisor.get_config_reload_subscribers(scope): + delivered = await supervisor.notify_plugin_config_updated( + plugin_id=plugin_id, + config_data=config_data, + config_version="", + config_scope=scope, + ) + if not delivered: + logger.warning(f"向插件 {plugin_id} 广播 {scope} 配置热重载失败") + + async def _handle_main_config_reload(self, changed_scopes: Sequence[str]) -> None: + """处理 bot/model 主配置热重载广播。 + + Args: + changed_scopes: 本次热重载命中的配置范围列表。 + """ + + if not self._started: + return + + normalized_scopes = self._normalize_config_reload_scopes(changed_scopes) + if "bot" in normalized_scopes: + await self._broadcast_config_reload("bot", config_manager.get_global_config().model_dump()) + if "model" in normalized_scopes: + await self._broadcast_config_reload("model", config_manager.get_model_config().model_dump()) # ─── 事件桥接 ────────────────────────────────────────────── @@ -612,16 +679,12 @@ class PluginRuntimeManager( return None if plugin_path is None else plugin_path / "config.toml" async def _handle_plugin_config_changes(self, plugin_id: str, changes: Sequence[FileChange]) -> None: - """处理单个插件配置文件变化,并精确重载目标插件。 + """处理单个插件配置文件变化,并定向派发自配置热更新。 Args: plugin_id: 发生配置变更的插件 ID。 changes: 当前批次收集到的配置文件变更列表。 - Notes: - 这里选择“精确重载该插件”,而不是仅推送软性的配置更新通知。 - 这样可以保证没有实现 ``on_config_update()`` 的插件也能重新执行 - ``on_load()``,让磁盘上的 ``config.toml`` 修改对插件运行态真正生效。 """ if not self._started or not changes: return @@ -636,15 +699,15 @@ class PluginRuntimeManager( return try: - self._load_plugin_config_for_supervisor(supervisor, plugin_id) - reload_success = await supervisor.reload_plugin( + config_payload = self._load_plugin_config_for_supervisor(supervisor, plugin_id) + delivered = await supervisor.notify_plugin_config_updated( plugin_id=plugin_id, - reason="config_file_changed", + config_data=config_payload, + config_version="", + config_scope="self", ) - if reload_success: - self._refresh_plugin_config_watch_subscriptions() - else: - logger.warning(f"插件 {plugin_id} 配置文件变更后重载失败") + if not delivered: + logger.warning(f"插件 {plugin_id} 配置文件变更后通知失败") except Exception as exc: logger.warning(f"插件 {plugin_id} 配置文件变更处理失败: {exc}") @@ -652,8 +715,8 @@ class PluginRuntimeManager( """处理插件源码相关变化。 这里仅负责源码、清单等会影响插件装载状态的文件;配置文件的变化会由 - 单独的 per-plugin watcher 处理,并精确重载对应插件,避免放大成 - 不必要的跨插件 reload。 + 单独的 per-plugin watcher 处理,并定向派发给目标插件的 + ``on_config_update()``,避免放大成不必要的跨插件 reload。 """ if not self._started or not changes: return diff --git a/src/plugin_runtime/protocol/envelope.py b/src/plugin_runtime/protocol/envelope.py index cbbb71be..6078e4dc 100644 --- a/src/plugin_runtime/protocol/envelope.py +++ b/src/plugin_runtime/protocol/envelope.py @@ -29,6 +29,14 @@ class MessageType(str, Enum): BROADCAST = "broadcast" +class ConfigReloadScope(str, Enum): + """配置热重载范围。""" + + SELF = "self" + BOT = "bot" + MODEL = "model" + + # ====== 请求 ID 生成器 ====== class RequestIdGenerator: """单调递增 int64 请求 ID 生成器""" @@ -158,6 +166,8 @@ class RegisterPluginPayload(BaseModel): """组件列表""" capabilities_required: List[str] = Field(default_factory=list, description="所需能力列表") """所需能力列表""" + config_reload_subscriptions: List[str] = Field(default_factory=list, description="订阅的全局配置热重载范围") + """订阅的全局配置热重载范围""" class BootstrapPluginPayload(BaseModel): @@ -236,6 +246,8 @@ class ConfigUpdatedPayload(BaseModel): plugin_id: str = Field(description="插件 ID") """插件 ID""" + config_scope: ConfigReloadScope = Field(description="配置变更范围") + """配置变更范围""" config_version: str = Field(description="新配置版本") """新配置版本""" config_data: Dict[str, Any] = Field(default_factory=dict, description="配置内容") diff --git a/src/plugin_runtime/runner/plugin_loader.py b/src/plugin_runtime/runner/plugin_loader.py index 90c8bf47..a766eb04 100644 --- a/src/plugin_runtime/runner/plugin_loader.py +++ b/src/plugin_runtime/runner/plugin_loader.py @@ -403,6 +403,7 @@ class PluginLoader: create_plugin = getattr(module, "create_plugin", None) if create_plugin is not None: instance = create_plugin() + self._validate_sdk_plugin_contract(plugin_id, instance) logger.info(f"插件 {plugin_id} v{manifest.get('version', '?')} 加载成功") return PluginMeta( plugin_id=plugin_id, @@ -432,6 +433,35 @@ class PluginLoader: logger.error(f"插件 {plugin_id} 缺少 create_plugin 工厂函数且未检测到旧版 BasePlugin") return None + @staticmethod + def _validate_sdk_plugin_contract(plugin_id: str, instance: Any) -> None: + """校验 SDK 插件的基础契约。 + + Args: + plugin_id: 当前插件 ID。 + instance: ``create_plugin()`` 返回的插件实例。 + + Raises: + TypeError: 当插件未覆盖必需生命周期方法或订阅声明不合法时抛出。 + """ + + try: + from maibot_sdk.plugin import MaiBotPlugin + except ImportError: + return + + if not isinstance(instance, MaiBotPlugin): + return + + if type(instance).on_load is MaiBotPlugin.on_load: + raise TypeError(f"插件 {plugin_id} 必须实现 on_load()") + if type(instance).on_unload is MaiBotPlugin.on_unload: + raise TypeError(f"插件 {plugin_id} 必须实现 on_unload()") + if type(instance).on_config_update is MaiBotPlugin.on_config_update: + raise TypeError(f"插件 {plugin_id} 必须实现 on_config_update()") + + instance.get_config_reload_subscriptions() + @staticmethod @contextlib.contextmanager def _temporary_sys_path_entry(path: Path) -> Iterator[None]: diff --git a/src/plugin_runtime/runner/runner_main.py b/src/plugin_runtime/runner/runner_main.py index 4bee714c..b94b01d1 100644 --- a/src/plugin_runtime/runner/runner_main.py +++ b/src/plugin_runtime/runner/runner_main.py @@ -27,6 +27,7 @@ from src.plugin_runtime import ENV_HOST_VERSION, ENV_IPC_ADDRESS, ENV_PLUGIN_DIR from src.plugin_runtime.protocol.envelope import ( BootstrapPluginPayload, ComponentDeclaration, + ConfigUpdatedPayload, Envelope, HealthPayload, InvokePayload, @@ -342,6 +343,7 @@ class PluginRunner: """ # 收集插件组件声明 components: List[ComponentDeclaration] = [] + config_reload_subscriptions: List[str] = [] instance = meta.instance # 从插件实例获取组件声明(SDK 插件须实现 get_components 方法) @@ -355,12 +357,15 @@ class PluginRunner: ) for comp_info in instance.get_components() ) + if hasattr(instance, "get_config_reload_subscriptions"): + config_reload_subscriptions = list(instance.get_config_reload_subscriptions()) reg_payload = RegisterPluginPayload( plugin_id=meta.plugin_id, plugin_version=meta.version, components=components, capabilities_required=meta.capabilities_required, + config_reload_subscriptions=config_reload_subscriptions, ) try: @@ -911,18 +916,28 @@ class PluginRunner: return envelope.make_response(payload={"acknowledged": True}) async def _handle_config_updated(self, envelope: Envelope) -> Envelope: - """处理配置更新事件""" + """处理配置更新事件。""" + try: + payload = ConfigUpdatedPayload.model_validate(envelope.payload) + except Exception as exc: + return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc)) + plugin_id = envelope.plugin_id if meta := self._loader.get_plugin(plugin_id): try: - config_data = envelope.payload.get("config_data", {}) - config_version = envelope.payload.get("config_version", "") - self._apply_plugin_config(meta, config_data=config_data) - if hasattr(meta.instance, "on_config_update"): - ret = meta.instance.on_config_update(config_data, config_version) - # 兼容同步和异步的 on_config_update 实现 - if asyncio.iscoroutine(ret): - await ret + config_scope = payload.config_scope.value + if config_scope == "self": + self._apply_plugin_config(meta, config_data=payload.config_data) + if not hasattr(meta.instance, "on_config_update"): + raise AttributeError("插件缺少 on_config_update() 实现") + + ret = meta.instance.on_config_update( + config_scope, + payload.config_data, + payload.config_version, + ) + if asyncio.iscoroutine(ret): + await ret except Exception as e: logger.error(f"插件 {plugin_id} 配置更新失败: {e}") return envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(e)) diff --git a/src/plugins/built_in/emoji_plugin/plugin.py b/src/plugins/built_in/emoji_plugin/plugin.py index b946931b..cc6b87c5 100644 --- a/src/plugins/built_in/emoji_plugin/plugin.py +++ b/src/plugins/built_in/emoji_plugin/plugin.py @@ -3,11 +3,11 @@ 根据聊天上下文的情感,使用 LLM 选择并发送合适的表情包。 """ -import random - -from maibot_sdk import MaiBotPlugin, Action +from maibot_sdk import Action, MaiBotPlugin from maibot_sdk.types import ActivationType +import random + class EmojiPlugin(MaiBotPlugin): """表情包插件""" @@ -95,10 +95,35 @@ class EmojiPlugin(MaiBotPlugin): return True, f"成功发送表情包:[表情包:{chosen_emotion}]" return False, "发送表情包失败" - async def on_load(self): + async def on_load(self) -> None: + """处理插件加载。""" + # 从插件配置读取 emoji_chance 来覆盖默认概率 await self.ctx.config.get("emoji.emoji_chance") + async def on_unload(self) -> None: + """处理插件卸载。""" + + async def on_config_update(self, scope: str, config_data: dict[str, object], version: str) -> None: + """处理配置热重载事件。 + + Args: + scope: 配置变更范围。 + config_data: 最新配置数据。 + version: 配置版本号。 + """ + + del config_data + del version + if scope == "self": + await self.ctx.config.get("emoji.emoji_chance") + + +def create_plugin() -> EmojiPlugin: + """创建 Emoji 插件实例。 + + Returns: + EmojiPlugin: 新的 Emoji 插件实例。 + """ -def create_plugin(): return EmojiPlugin() diff --git a/src/plugins/built_in/plugin_management/plugin.py b/src/plugins/built_in/plugin_management/plugin.py index fe0888c6..aa2da795 100644 --- a/src/plugins/built_in/plugin_management/plugin.py +++ b/src/plugins/built_in/plugin_management/plugin.py @@ -3,7 +3,7 @@ 通过 /pm 命令管理插件和组件的生命周期。 """ -from maibot_sdk import MaiBotPlugin, Command +from maibot_sdk import Command, MaiBotPlugin _VALID_COMPONENT_TYPES = ("action", "command", "event_handler") @@ -44,6 +44,12 @@ HELP_COMPONENT = ( class PluginManagementPlugin(MaiBotPlugin): """插件和组件管理插件""" + async def on_load(self) -> None: + """处理插件加载。""" + + async def on_unload(self) -> None: + """处理插件卸载。""" + @Command( "management", description="管理插件和组件的生命周期", @@ -268,6 +274,25 @@ class PluginManagementPlugin(MaiBotPlugin): return components return [] + async def on_config_update(self, scope: str, config_data: dict[str, object], version: str) -> None: + """处理配置热重载事件。 + + Args: + scope: 配置变更范围。 + config_data: 最新配置数据。 + version: 配置版本号。 + """ + + del scope + del config_data + del version + + +def create_plugin() -> PluginManagementPlugin: + """创建插件管理插件实例。 + + Returns: + PluginManagementPlugin: 新的插件管理插件实例。 + """ -def create_plugin(): return PluginManagementPlugin() From 7a304ba54964273e52f8d6fa3d6aee7612164be0 Mon Sep 17 00:00:00 2001 From: DrSmoothl <1787882683@qq.com> Date: Mon, 23 Mar 2026 21:01:55 +0800 Subject: [PATCH 35/45] feat: Enhance API and Outbound Tracking Functionality - Add test for fallback to bot account in platform IO route metadata when context message is absent. - Improve PlatformIOManager to avoid duplicate driver entries and streamline fallback driver handling. - Refactor OutboundTracker to support tracking by both internal message ID and driver ID, enhancing the uniqueness of pending records. - Introduce dynamic API capabilities in RuntimeComponent, allowing plugins to replace their dynamic API lists. - Update APIRegistry to manage dynamic APIs more effectively, including registration and toggling of API statuses. - Implement authorization checks for dynamic API capabilities to ensure proper permissions. - Restrict direct calls to certain host RPC methods from plugins for enhanced security. - Refactor send_service to ensure fallback to current platform account when no context message is available. --- pytests/test_platform_io_legacy_driver.py | 56 ++- pytests/test_plugin_runtime_api.py | 230 +++++++++++++ pytests/test_send_service.py | 13 + src/platform_io/manager.py | 21 +- src/platform_io/outbound_tracker.py | 70 +++- src/plugin_runtime/capabilities/components.py | 211 ++++++++++-- src/plugin_runtime/capabilities/registry.py | 1 + src/plugin_runtime/host/api_registry.py | 319 +++++++++++------- src/plugin_runtime/host/authorization.py | 5 + src/plugin_runtime/runner/runner_main.py | 16 +- src/services/send_service.py | 29 +- 11 files changed, 771 insertions(+), 200 deletions(-) diff --git a/pytests/test_platform_io_legacy_driver.py b/pytests/test_platform_io_legacy_driver.py index 2e94c1fc..76f14d8f 100644 --- a/pytests/test_platform_io_legacy_driver.py +++ b/pytests/test_platform_io_legacy_driver.py @@ -82,7 +82,61 @@ async def test_platform_io_uses_legacy_driver_when_no_explicit_send_route( ) explicit_drivers = manager.resolve_drivers(RouteKey(platform="qq")) - assert [driver.driver_id for driver in explicit_drivers] == ["plugin.qq.sender"] + assert [driver.driver_id for driver in explicit_drivers] == ["plugin.qq.sender", "legacy.send.qq"] + finally: + await manager.stop() + + +@pytest.mark.asyncio +async def test_platform_io_broadcasts_to_plugin_and_legacy_driver( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """同一路由命中插件驱动与 legacy driver 时,应同时广播发送。""" + + manager = PlatformIOManager() + legacy_calls: list[dict[str, Any]] = [] + monkeypatch.setattr(chat_utils, "get_all_bot_accounts", lambda: {"qq": "bot-qq"}) + + async def _fake_send_prepared_message_to_platform(message: Any, show_log: bool = True) -> bool: + """记录 legacy driver 调用。""" + + legacy_calls.append({"message": message, "show_log": show_log}) + return True + + monkeypatch.setattr( + uni_message_sender, + "send_prepared_message_to_platform", + _fake_send_prepared_message_to_platform, + ) + + try: + await manager.ensure_send_pipeline_ready() + + plugin_driver = _PluginDriver(driver_id="plugin.qq.sender", platform="qq") + await manager.add_driver(plugin_driver) + manager.bind_send_route( + RouteBinding( + route_key=RouteKey(platform="qq"), + driver_id=plugin_driver.driver_id, + driver_kind=plugin_driver.descriptor.kind, + ) + ) + + message = type("FakeMessage", (), {"message_id": "message-1"})() + batch = await manager.send_message( + message=message, + route_key=RouteKey(platform="qq"), + metadata={"show_log": False}, + ) + + assert sorted(receipt.driver_id for receipt in batch.sent_receipts) == [ + "legacy.send.qq", + "plugin.qq.sender", + ] + assert batch.failed_receipts == [] + assert len(legacy_calls) == 1 + assert legacy_calls[0]["message"] is message + assert legacy_calls[0]["show_log"] is False finally: await manager.stop() diff --git a/pytests/test_plugin_runtime_api.py b/pytests/test_plugin_runtime_api.py index fca7736a..58a8e6ba 100644 --- a/pytests/test_plugin_runtime_api.py +++ b/pytests/test_plugin_runtime_api.py @@ -292,3 +292,233 @@ async def test_api_list_and_component_toggle_use_dedicated_registry() -> None: ) assert enable_result["success"] is True assert provider_supervisor.api_registry.get_api("provider", "public_api", enabled_only=True) is not None + + +@pytest.mark.asyncio +async def test_api_registry_supports_multiple_versions_with_distinct_handlers( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """同名 API 不同版本应可并存,并按版本路由到不同处理器。""" + + provider_supervisor = PluginSupervisor(plugin_dirs=[]) + consumer_supervisor = PluginSupervisor(plugin_dirs=[]) + await _register_plugin( + provider_supervisor, + "provider", + [ + { + "name": "render_html", + "component_type": "API", + "metadata": { + "description": "渲染 HTML v1", + "version": "1", + "public": True, + "handler_name": "handle_render_html_v1", + }, + }, + { + "name": "render_html", + "component_type": "API", + "metadata": { + "description": "渲染 HTML v2", + "version": "2", + "public": True, + "handler_name": "handle_render_html_v2", + }, + }, + ], + ) + await _register_plugin(consumer_supervisor, "consumer", []) + + captured: Dict[str, Any] = {} + + async def fake_invoke_api( + plugin_id: str, + component_name: str, + args: Dict[str, Any] | None = None, + timeout_ms: int = 30000, + ) -> Any: + """模拟多版本 API 调用。""" + + captured["plugin_id"] = plugin_id + captured["component_name"] = component_name + captured["args"] = args or {} + captured["timeout_ms"] = timeout_ms + return SimpleNamespace(error=None, payload={"success": True, "result": {"image": "ok"}}) + + monkeypatch.setattr(provider_supervisor, "invoke_api", fake_invoke_api) + manager = _build_manager(provider_supervisor, consumer_supervisor) + + ambiguous_result = await manager._cap_api_call( + "consumer", + "api.call", + { + "api_name": "provider.render_html", + "args": {"html": "
Hello
"}, + }, + ) + assert ambiguous_result["success"] is False + assert "多个版本" in str(ambiguous_result["error"]) + + disable_ambiguous_result = await manager._cap_component_disable( + "consumer", + "component.disable", + { + "name": "provider.render_html", + "component_type": "API", + "scope": "global", + "stream_id": "", + }, + ) + assert disable_ambiguous_result["success"] is False + assert "多个版本" in str(disable_ambiguous_result["error"]) + + disable_v1_result = await manager._cap_component_disable( + "consumer", + "component.disable", + { + "name": "provider.render_html", + "component_type": "API", + "scope": "global", + "stream_id": "", + "version": "1", + }, + ) + assert disable_v1_result["success"] is True + assert provider_supervisor.api_registry.get_api("provider", "render_html", version="1", enabled_only=True) is None + assert provider_supervisor.api_registry.get_api("provider", "render_html", version="2", enabled_only=True) is not None + + result = await manager._cap_api_call( + "consumer", + "api.call", + { + "api_name": "provider.render_html", + "version": "2", + "args": {"html": "
Hello
"}, + }, + ) + + assert result == {"success": True, "result": {"image": "ok"}} + assert captured["plugin_id"] == "provider" + assert captured["component_name"] == "handle_render_html_v2" + assert captured["args"] == {"html": "
Hello
"} + + +@pytest.mark.asyncio +async def test_api_replace_dynamic_can_offline_removed_entries( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """动态 API 替换后,被移除的 API 应返回明确下线错误。""" + + supervisor = PluginSupervisor(plugin_dirs=[]) + await _register_plugin(supervisor, "provider", []) + manager = _build_manager(supervisor) + + captured: Dict[str, Any] = {} + + async def fake_invoke_api( + plugin_id: str, + component_name: str, + args: Dict[str, Any] | None = None, + timeout_ms: int = 30000, + ) -> Any: + """模拟动态 API 调用。""" + + captured["plugin_id"] = plugin_id + captured["component_name"] = component_name + captured["args"] = args or {} + captured["timeout_ms"] = timeout_ms + return SimpleNamespace(error=None, payload={"success": True, "result": {"ok": True}}) + + monkeypatch.setattr(supervisor, "invoke_api", fake_invoke_api) + + replace_result = await manager._cap_api_replace_dynamic( + "provider", + "api.replace_dynamic", + { + "apis": [ + { + "name": "mcp.search", + "type": "API", + "metadata": { + "version": "1", + "public": True, + "handler_name": "dynamic_search", + }, + }, + { + "name": "mcp.read", + "type": "API", + "metadata": { + "version": "1", + "public": True, + "handler_name": "dynamic_read", + }, + }, + ], + "offline_reason": "MCP 服务器已关闭", + }, + ) + + assert replace_result["success"] is True + assert replace_result["count"] == 2 + list_result = await manager._cap_api_list("provider", "api.list", {"plugin_id": "provider"}) + assert {(item["name"], item["version"]) for item in list_result["apis"]} == { + ("mcp.read", "1"), + ("mcp.search", "1"), + } + + call_result = await manager._cap_api_call( + "provider", + "api.call", + { + "api_name": "provider.mcp.search", + "version": "1", + "args": {"query": "hello"}, + }, + ) + assert call_result == {"success": True, "result": {"ok": True}} + assert captured["component_name"] == "dynamic_search" + assert captured["args"]["query"] == "hello" + assert captured["args"]["__maibot_api_name__"] == "mcp.search" + assert captured["args"]["__maibot_api_version__"] == "1" + + second_replace_result = await manager._cap_api_replace_dynamic( + "provider", + "api.replace_dynamic", + { + "apis": [ + { + "name": "mcp.read", + "type": "API", + "metadata": { + "version": "1", + "public": True, + "handler_name": "dynamic_read", + }, + } + ], + "offline_reason": "MCP 服务器已关闭", + }, + ) + + assert second_replace_result["success"] is True + assert second_replace_result["count"] == 1 + assert second_replace_result["offlined"] == 1 + + offlined_call_result = await manager._cap_api_call( + "provider", + "api.call", + { + "api_name": "provider.mcp.search", + "version": "1", + "args": {}, + }, + ) + assert offlined_call_result["success"] is False + assert "MCP 服务器已关闭" in str(offlined_call_result["error"]) + + list_after_replace = await manager._cap_api_list("provider", "api.list", {"plugin_id": "provider"}) + assert {(item["name"], item["version"]) for item in list_after_replace["apis"]} == { + ("mcp.read", "1"), + } diff --git a/pytests/test_send_service.py b/pytests/test_send_service.py index 4ddd4fa1..16aad080 100644 --- a/pytests/test_send_service.py +++ b/pytests/test_send_service.py @@ -73,6 +73,19 @@ def _build_target_stream() -> BotChatSession: ) +def test_inherit_platform_io_route_metadata_falls_back_to_bot_account( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """没有上下文消息时,也应回填当前平台账号用于账号级路由命中。""" + + monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq" if platform == "qq" else "") + + metadata = send_service._inherit_platform_io_route_metadata(_build_target_stream()) + + assert metadata["platform_io_account_id"] == "bot-qq" + assert metadata["platform_io_target_user_id"] == "target-user" + + @pytest.mark.asyncio async def test_text_to_stream_delegates_to_platform_io(monkeypatch: pytest.MonkeyPatch) -> None: """send service 应将发送职责统一交给 Platform IO。""" diff --git a/src/platform_io/manager.py b/src/platform_io/manager.py index be03e35d..ab1b11e5 100644 --- a/src/platform_io/manager.py +++ b/src/platform_io/manager.py @@ -394,23 +394,22 @@ class PlatformIOManager: """ drivers: List[PlatformIODriver] = [] + seen_driver_ids: set[str] = set() for binding in self._send_route_table.resolve_bindings(route_key): driver = self._driver_registry.get(binding.driver_id) - if driver is not None: + if driver is not None and driver.driver_id not in seen_driver_ids: drivers.append(driver) - if drivers: - return drivers + seen_driver_ids.add(driver.driver_id) fallback_driver = self._legacy_send_drivers.get(route_key.platform) - if fallback_driver is None: - return [] + if fallback_driver is not None: + descriptor = fallback_driver.descriptor + account_matches = descriptor.account_id is None or route_key.account_id in (None, descriptor.account_id) + scope_matches = descriptor.scope is None or route_key.scope in (None, descriptor.scope) + if account_matches and scope_matches and fallback_driver.driver_id not in seen_driver_ids: + drivers.append(fallback_driver) - descriptor = fallback_driver.descriptor - if descriptor.account_id is not None and route_key.account_id not in (None, descriptor.account_id): - return [] - if descriptor.scope is not None and route_key.scope not in (None, descriptor.scope): - return [] - return [fallback_driver] + return drivers @staticmethod def build_route_key_from_message(message: "SessionMessage") -> RouteKey: diff --git a/src/platform_io/outbound_tracker.py b/src/platform_io/outbound_tracker.py index 438aa566..3725691f 100644 --- a/src/platform_io/outbound_tracker.py +++ b/src/platform_io/outbound_tracker.py @@ -92,11 +92,24 @@ class OutboundTracker: raise ValueError("ttl_seconds 必须大于 0") self._ttl_seconds = ttl_seconds - self._pending: Dict[str, PendingOutboundRecord] = {} - self._pending_expire_heap: List[Tuple[float, str]] = [] + self._pending: Dict[Tuple[str, str], PendingOutboundRecord] = {} + self._pending_expire_heap: List[Tuple[float, str, str]] = [] self._receipts_by_external_id: Dict[str, StoredDeliveryReceipt] = {} self._receipt_expire_heap: List[Tuple[float, str]] = [] + @staticmethod + def _build_pending_key(internal_message_id: str, driver_id: str) -> Tuple[str, str]: + """构造单条出站跟踪记录的唯一键。 + + Args: + internal_message_id: 内部消息 ID。 + driver_id: 负责当前投递的驱动 ID。 + + Returns: + Tuple[str, str]: ``(internal_message_id, driver_id)`` 组合键。 + """ + return internal_message_id, driver_id + def begin_tracking( self, internal_message_id: str, @@ -116,13 +129,15 @@ class OutboundTracker: PendingOutboundRecord: 新创建的待完成记录。 Raises: - ValueError: 当同一个 ``internal_message_id`` 已经存在未完成记录时抛出。 + ValueError: 当同一个 ``internal_message_id`` 与 ``driver_id`` 组合已经存在 + 未完成记录时抛出。 """ now = time.monotonic() self._cleanup_expired(now) + pending_key = self._build_pending_key(internal_message_id, driver_id) - if internal_message_id in self._pending: - raise ValueError(f"消息 {internal_message_id} 已存在未完成的出站跟踪记录") + if pending_key in self._pending: + raise ValueError(f"消息 {internal_message_id} 在驱动 {driver_id} 上已存在未完成的出站跟踪记录") expires_at = now + self._ttl_seconds record = PendingOutboundRecord( @@ -133,8 +148,8 @@ class OutboundTracker: expires_at=expires_at, metadata=metadata or {}, ) - self._pending[internal_message_id] = record - heapq.heappush(self._pending_expire_heap, (expires_at, internal_message_id)) + self._pending[pending_key] = record + heapq.heappush(self._pending_expire_heap, (expires_at, internal_message_id, driver_id)) return record def finish_tracking(self, receipt: DeliveryReceipt) -> Optional[PendingOutboundRecord]: @@ -149,7 +164,19 @@ class OutboundTracker: now = time.monotonic() self._cleanup_expired(now) - pending_record = self._pending.pop(receipt.internal_message_id, None) + pending_record: Optional[PendingOutboundRecord] = None + if receipt.driver_id: + pending_key = self._build_pending_key(receipt.internal_message_id, receipt.driver_id) + pending_record = self._pending.pop(pending_key, None) + else: + matched_records = [ + key + for key, record in self._pending.items() + if record.internal_message_id == receipt.internal_message_id + ] + if len(matched_records) == 1: + pending_record = self._pending.pop(matched_records[0], None) + if receipt.external_message_id: expires_at = now + self._ttl_seconds self._receipts_by_external_id[receipt.external_message_id] = StoredDeliveryReceipt( @@ -160,17 +187,33 @@ class OutboundTracker: heapq.heappush(self._receipt_expire_heap, (expires_at, receipt.external_message_id)) return pending_record - def get_pending(self, internal_message_id: str) -> Optional[PendingOutboundRecord]: + def get_pending( + self, + internal_message_id: str, + driver_id: Optional[str] = None, + ) -> Optional[PendingOutboundRecord]: """根据内部消息 ID 查询待完成记录。 Args: internal_message_id: 要查询的内部消息 ID。 + driver_id: 可选的驱动 ID;提供后仅返回该驱动上的待完成记录。 Returns: Optional[PendingOutboundRecord]: 若记录仍存在,则返回对应待完成记录。 """ self._cleanup_expired(time.monotonic()) - return self._pending.get(internal_message_id) + + if driver_id: + return self._pending.get(self._build_pending_key(internal_message_id, driver_id)) + + matched_records = [ + record + for record in self._pending.values() + if record.internal_message_id == internal_message_id + ] + if len(matched_records) == 1: + return matched_records[0] + return None def get_receipt_by_external_id(self, external_message_id: str) -> Optional[DeliveryReceipt]: """根据外部平台消息 ID 查询已完成回执。 @@ -213,13 +256,14 @@ class OutboundTracker: ``expires_at`` 对比,跳过这类旧节点。 """ while self._pending_expire_heap and self._pending_expire_heap[0][0] <= now: - expires_at, internal_message_id = heapq.heappop(self._pending_expire_heap) - current_record = self._pending.get(internal_message_id) + expires_at, internal_message_id, driver_id = heapq.heappop(self._pending_expire_heap) + pending_key = self._build_pending_key(internal_message_id, driver_id) + current_record = self._pending.get(pending_key) if current_record is None: continue if current_record.expires_at != expires_at: continue - self._pending.pop(internal_message_id, None) + self._pending.pop(pending_key, None) def _cleanup_expired_receipts(self, now: float) -> None: """清理已经过期的回执索引。 diff --git a/src/plugin_runtime/capabilities/components.py b/src/plugin_runtime/capabilities/components.py index 2eede108..67033fdd 100644 --- a/src/plugin_runtime/capabilities/components.py +++ b/src/plugin_runtime/capabilities/components.py @@ -72,6 +72,8 @@ class RuntimeComponentCapabilityMixin: "version": entry.version, "public": entry.public, "enabled": entry.enabled, + "dynamic": entry.dynamic, + "offline_reason": entry.offline_reason, "metadata": dict(entry.metadata), } @@ -109,6 +111,32 @@ class RuntimeComponentCapabilityMixin: return entry.plugin_id == caller_plugin_id or entry.public + @staticmethod + def _normalize_api_reference(api_name: str, version: str = "") -> tuple[str, str]: + """规范化 API 名称与版本参数。 + + 支持在 ``api_name`` 中直接携带 ``@version`` 后缀。 + """ + + normalized_api_name = str(api_name or "").strip() + normalized_version = str(version or "").strip() + if normalized_api_name and not normalized_version and "@" in normalized_api_name: + candidate_name, candidate_version = normalized_api_name.rsplit("@", 1) + candidate_name = candidate_name.strip() + candidate_version = candidate_version.strip() + if candidate_name and candidate_version: + normalized_api_name = candidate_name + normalized_version = candidate_version + return normalized_api_name, normalized_version + + @staticmethod + def _build_api_unavailable_error(entry: "APIEntry") -> str: + """构造 API 当前不可用时的错误信息。""" + + if entry.offline_reason: + return entry.offline_reason + return f"API {entry.registry_key} 当前不可用" + def _resolve_api_target( self: _RuntimeComponentManagerProtocol, caller_plugin_id: str, @@ -127,8 +155,7 @@ class RuntimeComponentCapabilityMixin: 解析成功时返回 ``(监督器, API 条目, None)``,失败时返回错误信息。 """ - normalized_api_name = str(api_name or "").strip() - normalized_version = str(version or "").strip() + normalized_api_name, normalized_version = self._normalize_api_reference(api_name, version) if not normalized_api_name: return None, None, "缺少必要参数 api_name" @@ -142,34 +169,61 @@ class RuntimeComponentCapabilityMixin: if supervisor is None: return None, None, f"未找到 API 提供方插件: {target_plugin_id}" - entry = supervisor.api_registry.get_api( + entries = supervisor.api_registry.get_apis( plugin_id=target_plugin_id, name=target_api_name, - enabled_only=True, + version=normalized_version, + enabled_only=False, ) - if entry is None: - return None, None, f"未找到 API: {normalized_api_name}" - if normalized_version and entry.version != normalized_version: - return None, None, f"未找到版本为 {normalized_version} 的 API: {normalized_api_name}" - if not self._is_api_visible_to_plugin(entry, caller_plugin_id): + visible_enabled_entries = [ + entry + for entry in entries + if self._is_api_visible_to_plugin(entry, caller_plugin_id) and entry.enabled + ] + visible_disabled_entries = [ + entry + for entry in entries + if self._is_api_visible_to_plugin(entry, caller_plugin_id) and not entry.enabled + ] + if len(visible_enabled_entries) == 1: + return supervisor, visible_enabled_entries[0], None + if len(visible_enabled_entries) > 1: + return None, None, f"API {normalized_api_name} 存在多个版本,请显式指定 version" + if visible_disabled_entries: + if len(visible_disabled_entries) == 1: + return None, None, self._build_api_unavailable_error(visible_disabled_entries[0]) + return None, None, f"API {normalized_api_name} 存在多个已下线版本,请显式指定 version" + if any(not self._is_api_visible_to_plugin(entry, caller_plugin_id) for entry in entries): return None, None, f"API {normalized_api_name} 未公开,禁止跨插件调用" - return supervisor, entry, None + if normalized_version: + return None, None, f"未找到版本为 {normalized_version} 的 API: {normalized_api_name}" + return None, None, f"未找到 API: {normalized_api_name}" - visible_matches: List[tuple["PluginSupervisor", "APIEntry"]] = [] + visible_enabled_matches: List[tuple["PluginSupervisor", "APIEntry"]] = [] + visible_disabled_matches: List[tuple["PluginSupervisor", "APIEntry"]] = [] hidden_match_exists = False for supervisor in self.supervisors: - for entry in supervisor.api_registry.get_apis(name=normalized_api_name, enabled_only=True): - if normalized_version and entry.version != normalized_version: - continue + for entry in supervisor.api_registry.get_apis( + name=normalized_api_name, + version=normalized_version, + enabled_only=False, + ): if self._is_api_visible_to_plugin(entry, caller_plugin_id): - visible_matches.append((supervisor, entry)) + if entry.enabled: + visible_enabled_matches.append((supervisor, entry)) + else: + visible_disabled_matches.append((supervisor, entry)) else: hidden_match_exists = True - if len(visible_matches) == 1: - return visible_matches[0][0], visible_matches[0][1], None - if len(visible_matches) > 1: - return None, None, f"API 名称不唯一: {normalized_api_name},请使用 plugin_id.api_name" + if len(visible_enabled_matches) == 1: + return visible_enabled_matches[0][0], visible_enabled_matches[0][1], None + if len(visible_enabled_matches) > 1: + return None, None, f"API 名称不唯一: {normalized_api_name},请使用 plugin_id.api_name 或显式指定 version" + if visible_disabled_matches: + if len(visible_disabled_matches) == 1: + return None, None, self._build_api_unavailable_error(visible_disabled_matches[0][1]) + return None, None, f"API {normalized_api_name} 存在多个已下线版本,请使用 plugin_id.api_name@version" if hidden_match_exists: return None, None, f"API {normalized_api_name} 未公开,禁止跨插件调用" if normalized_version: @@ -179,18 +233,20 @@ class RuntimeComponentCapabilityMixin: def _resolve_api_toggle_target( self: _RuntimeComponentManagerProtocol, name: str, + version: str = "", ) -> tuple[Optional["PluginSupervisor"], Optional["APIEntry"], Optional[str]]: """解析需要启用或禁用的 API 组件。 Args: name: API 名称,支持 ``plugin_id.api_name`` 或唯一短名。 + version: 可选的 API 版本。 Returns: tuple[Optional[PluginSupervisor], Optional[APIEntry], Optional[str]]: 解析成功时返回 ``(监督器, API 条目, None)``,失败时返回错误信息。 """ - normalized_name = str(name or "").strip() + normalized_name, normalized_version = self._normalize_api_reference(name, version) if not normalized_name: return None, None, "缺少必要参数 name" @@ -204,24 +260,31 @@ class RuntimeComponentCapabilityMixin: if supervisor is None: return None, None, f"未找到 API 提供方插件: {plugin_id}" - entry = supervisor.api_registry.get_api( + entries = supervisor.api_registry.get_apis( plugin_id=plugin_id, name=api_name, + version=normalized_version, enabled_only=False, ) - if entry is None: + if not entries: return None, None, f"未找到 API: {normalized_name}" - return supervisor, entry, None + if len(entries) > 1: + return None, None, f"API {normalized_name} 存在多个版本,请显式指定 version" + return supervisor, entries[0], None matches: List[tuple["PluginSupervisor", "APIEntry"]] = [] for supervisor in self.supervisors: - for entry in supervisor.api_registry.get_apis(name=normalized_name, enabled_only=False): + for entry in supervisor.api_registry.get_apis( + name=normalized_name, + version=normalized_version, + enabled_only=False, + ): matches.append((supervisor, entry)) if len(matches) == 1: return matches[0][0], matches[0][1], None if len(matches) > 1: - return None, None, f"API 名称不唯一: {normalized_name},请使用 plugin_id.api_name" + return None, None, f"API 名称不唯一: {normalized_name},请使用 plugin_id.api_name 或显式指定 version" return None, None, f"未找到 API: {normalized_name}" async def _cap_component_get_all_plugins( @@ -326,6 +389,7 @@ class RuntimeComponentCapabilityMixin: ) -> Any: name: str = args.get("name", "") component_type: str = args.get("component_type", "") + version: str = args.get("version", "") scope: str = args.get("scope", "global") stream_id: str = args.get("stream_id", "") if not name or not component_type: @@ -334,10 +398,10 @@ class RuntimeComponentCapabilityMixin: return {"success": False, "error": "当前仅支持全局组件启用,不支持 scope/stream_id 定位"} if self._is_api_component_type(component_type): - supervisor, api_entry, error = self._resolve_api_toggle_target(name) + supervisor, api_entry, error = self._resolve_api_toggle_target(name, version) if supervisor is None or api_entry is None: return {"success": False, "error": error or f"未找到 API: {name}"} - supervisor.api_registry.toggle_api_status(api_entry.full_name, True) + supervisor.api_registry.toggle_api_status(api_entry.registry_key, True) return {"success": True} comp, error = self._resolve_component_toggle_target(name, component_type) @@ -352,6 +416,7 @@ class RuntimeComponentCapabilityMixin: ) -> Any: name: str = args.get("name", "") component_type: str = args.get("component_type", "") + version: str = args.get("version", "") scope: str = args.get("scope", "global") stream_id: str = args.get("stream_id", "") if not name or not component_type: @@ -360,10 +425,10 @@ class RuntimeComponentCapabilityMixin: return {"success": False, "error": "当前仅支持全局组件禁用,不支持 scope/stream_id 定位"} if self._is_api_component_type(component_type): - supervisor, api_entry, error = self._resolve_api_toggle_target(name) + supervisor, api_entry, error = self._resolve_api_toggle_target(name, version) if supervisor is None or api_entry is None: return {"success": False, "error": error or f"未找到 API: {name}"} - supervisor.api_registry.toggle_api_status(api_entry.full_name, False) + supervisor.api_registry.toggle_api_status(api_entry.registry_key, False) return {"success": True} comp, error = self._resolve_component_toggle_target(name, component_type) @@ -488,11 +553,17 @@ class RuntimeComponentCapabilityMixin: if supervisor is None or entry is None: return {"success": False, "error": error or "API 解析失败"} + invoke_args = dict(api_args) + if entry.dynamic: + invoke_args.setdefault("__maibot_api_name__", entry.name) + invoke_args.setdefault("__maibot_api_full_name__", entry.full_name) + invoke_args.setdefault("__maibot_api_version__", entry.version) + try: response = await supervisor.invoke_api( plugin_id=entry.plugin_id, - component_name=entry.name, - args=api_args, + component_name=entry.handler_name, + args=invoke_args, timeout_ms=30000, ) except Exception as exc: @@ -555,10 +626,16 @@ class RuntimeComponentCapabilityMixin: del capability target_plugin_id = str(args.get("plugin_id", "") or "").strip() + api_name, version = self._normalize_api_reference( + str(args.get("api_name", args.get("name", "")) or ""), + str(args.get("version", "") or ""), + ) apis: List[Dict[str, Any]] = [] for supervisor in self.supervisors: for entry in supervisor.api_registry.get_apis( plugin_id=target_plugin_id or None, + name=api_name, + version=version, enabled_only=True, ): if not self._is_api_visible_to_plugin(entry, plugin_id): @@ -567,3 +644,75 @@ class RuntimeComponentCapabilityMixin: apis.sort(key=lambda item: (str(item["plugin_id"]), str(item["name"]), str(item["version"]))) return {"success": True, "apis": apis} + + async def _cap_api_replace_dynamic( + self: _RuntimeComponentManagerProtocol, + plugin_id: str, + capability: str, + args: Dict[str, Any], + ) -> Any: + """替换插件自行维护的动态 API 列表。""" + + del capability + raw_apis = args.get("apis", []) + offline_reason = str(args.get("offline_reason", "") or "").strip() or "动态 API 已下线" + if not isinstance(raw_apis, list): + return {"success": False, "error": "参数 apis 必须为列表"} + + try: + supervisor = self._get_supervisor_for_plugin(plugin_id) + except RuntimeError as exc: + return {"success": False, "error": str(exc)} + + if supervisor is None: + return {"success": False, "error": f"未找到插件: {plugin_id}"} + + normalized_components: List[Dict[str, Any]] = [] + seen_registry_keys: set[str] = set() + for index, raw_api in enumerate(raw_apis): + if not isinstance(raw_api, dict): + return {"success": False, "error": f"apis[{index}] 必须为字典"} + + api_name = str(raw_api.get("name", "") or "").strip() + component_type = str(raw_api.get("component_type", raw_api.get("type", "API")) or "").strip() + if not api_name: + return {"success": False, "error": f"apis[{index}] 缺少 name"} + if not self._is_api_component_type(component_type): + return {"success": False, "error": f"apis[{index}] 不是 API 组件"} + + metadata = raw_api.get("metadata", {}) if isinstance(raw_api.get("metadata"), dict) else {} + normalized_metadata = dict(metadata) + normalized_metadata["dynamic"] = True + version = str(normalized_metadata.get("version", "1") or "1").strip() or "1" + registry_key = supervisor.api_registry.build_registry_key(plugin_id, api_name, version) + if registry_key in seen_registry_keys: + return {"success": False, "error": f"动态 API 重复声明: {registry_key}"} + seen_registry_keys.add(registry_key) + + existing_entry = supervisor.api_registry.get_api( + plugin_id, + api_name, + version=version, + enabled_only=False, + ) + if existing_entry is not None and not existing_entry.dynamic: + return {"success": False, "error": f"动态 API 不能覆盖静态 API: {registry_key}"} + + normalized_components.append( + { + "name": api_name, + "component_type": "API", + "metadata": normalized_metadata, + } + ) + + registered_count, offlined_count = supervisor.api_registry.replace_plugin_dynamic_apis( + plugin_id, + normalized_components, + offline_reason=offline_reason, + ) + return { + "success": True, + "count": registered_count, + "offlined": offlined_count, + } diff --git a/src/plugin_runtime/capabilities/registry.py b/src/plugin_runtime/capabilities/registry.py index 31693833..7f87604d 100644 --- a/src/plugin_runtime/capabilities/registry.py +++ b/src/plugin_runtime/capabilities/registry.py @@ -77,6 +77,7 @@ def register_capability_impls(manager: "PluginRuntimeManager", supervisor: Plugi _register("api.call", manager._cap_api_call) _register("api.get", manager._cap_api_get) _register("api.list", manager._cap_api_list) + _register("api.replace_dynamic", manager._cap_api_replace_dynamic) _register("component.get_all_plugins", manager._cap_component_get_all_plugins) _register("component.get_plugin_info", manager._cap_component_get_plugin_info) diff --git a/src/plugin_runtime/host/api_registry.py b/src/plugin_runtime/host/api_registry.py index 84578ca5..1cbc05f6 100644 --- a/src/plugin_runtime/host/api_registry.py +++ b/src/plugin_runtime/host/api_registry.py @@ -1,45 +1,60 @@ """Host 侧插件 API 动态注册表。""" -from typing import Any, Dict, List, Optional, Set +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Set, Tuple from src.common.logger import get_logger logger = get_logger("plugin_runtime.host.api_registry") +@dataclass(slots=True) class APIEntry: """API 组件条目。""" - __slots__ = ( - "description", - "disabled_session", - "enabled", - "full_name", - "metadata", - "name", - "plugin_id", - "public", - "version", - ) + name: str + plugin_id: str + description: str = "" + version: str = "1" + public: bool = False + metadata: Dict[str, Any] = field(default_factory=dict) + enabled: bool = True + handler_name: str = "" + dynamic: bool = False + offline_reason: str = "" + disabled_session: Set[str] = field(default_factory=set) + full_name: str = field(init=False) + registry_key: str = field(init=False) - def __init__(self, name: str, plugin_id: str, metadata: Dict[str, Any]) -> None: - """初始化 API 组件条目。 + def __post_init__(self) -> None: + """规范化 API 条目字段。""" - Args: - name: API 名称。 - plugin_id: 所属插件 ID。 - metadata: API 元数据。 - """ + self.name = str(self.name or "").strip() + self.plugin_id = str(self.plugin_id or "").strip() + self.description = str(self.description or "").strip() + self.version = str(self.version or "1").strip() or "1" + self.handler_name = str(self.handler_name or self.name).strip() or self.name + self.offline_reason = str(self.offline_reason or "").strip() + self.full_name = f"{self.plugin_id}.{self.name}" + self.registry_key = APIRegistry.build_registry_key(self.plugin_id, self.name, self.version) - self.name: str = name - self.full_name: str = f"{plugin_id}.{name}" - self.plugin_id: str = plugin_id - self.description: str = str(metadata.get("description", "") or "") - self.version: str = str(metadata.get("version", "1") or "1").strip() or "1" - self.public: bool = bool(metadata.get("public", False)) - self.metadata: Dict[str, Any] = dict(metadata) - self.enabled: bool = bool(metadata.get("enabled", True)) - self.disabled_session: Set[str] = set() + @classmethod + def from_metadata(cls, name: str, plugin_id: str, metadata: Dict[str, Any]) -> "APIEntry": + """根据 Runner 上报的元数据构造 API 条目。""" + + safe_metadata = dict(metadata) + return cls( + name=name, + plugin_id=plugin_id, + description=str(safe_metadata.get("description", "") or ""), + version=str(safe_metadata.get("version", "1") or "1"), + public=bool(safe_metadata.get("public", False)), + metadata=safe_metadata, + enabled=bool(safe_metadata.get("enabled", True)), + handler_name=str(safe_metadata.get("handler_name", name) or name), + dynamic=bool(safe_metadata.get("dynamic", False)), + offline_reason=str(safe_metadata.get("offline_reason", "") or ""), + ) class APIRegistry: @@ -53,6 +68,7 @@ class APIRegistry: """初始化 API 注册表。""" self._apis: Dict[str, APIEntry] = {} + self._by_full_name: Dict[str, List[APIEntry]] = {} self._by_plugin: Dict[str, List[APIEntry]] = {} self._by_name: Dict[str, List[APIEntry]] = {} @@ -60,75 +76,75 @@ class APIRegistry: """清空全部 API 注册状态。""" self._apis.clear() + self._by_full_name.clear() self._by_plugin.clear() self._by_name.clear() @staticmethod def _is_api_component(component_type: Any) -> bool: - """判断组件声明是否属于 API。 - - Args: - component_type: 原始组件类型值。 - - Returns: - bool: 是否为 API 组件。 - """ + """判断组件声明是否属于 API。""" return str(component_type or "").strip().upper() == "API" + @staticmethod + def _normalize_query_version(version: Any) -> str: + """规范化查询使用的版本字符串。""" + + return str(version or "").strip() + + @classmethod + def _split_reference(cls, reference: str, version: Any = "") -> Tuple[str, str]: + """解析可能带 ``@version`` 后缀的 API 引用。""" + + normalized_reference = str(reference or "").strip() + normalized_version = cls._normalize_query_version(version) + if normalized_reference and not normalized_version and "@" in normalized_reference: + candidate_reference, candidate_version = normalized_reference.rsplit("@", 1) + candidate_reference = candidate_reference.strip() + candidate_version = candidate_version.strip() + if candidate_reference and candidate_version: + normalized_reference = candidate_reference + normalized_version = candidate_version + return normalized_reference, normalized_version + + @staticmethod + def build_registry_key(plugin_id: str, name: str, version: str) -> str: + """构造 API 注册表唯一键。""" + + normalized_full_name = f"{str(plugin_id or '').strip()}.{str(name or '').strip()}" + normalized_version = str(version or "1").strip() or "1" + return f"{normalized_full_name}@{normalized_version}" + @staticmethod def check_api_enabled(entry: APIEntry, session_id: Optional[str] = None) -> bool: - """判断 API 条目当前是否处于启用状态。 - - Args: - entry: 待检查的 API 条目。 - session_id: 可选的会话 ID。 - - Returns: - bool: 当前是否可用。 - """ + """判断 API 条目当前是否处于启用状态。""" if session_id and session_id in entry.disabled_session: return False return entry.enabled def register_api(self, name: str, plugin_id: str, metadata: Dict[str, Any]) -> bool: - """注册单个 API 条目。 - - Args: - name: API 名称。 - plugin_id: 所属插件 ID。 - metadata: API 元数据。 - - Returns: - bool: 是否成功注册。 - """ + """注册单个 API 条目。""" normalized_name = str(name or "").strip() if not normalized_name: logger.warning(f"插件 {plugin_id} 存在空 API 名称声明,已忽略") return False - entry = APIEntry(name=normalized_name, plugin_id=plugin_id, metadata=metadata) - if entry.full_name in self._apis: - logger.warning(f"API {entry.full_name} 已存在,覆盖旧条目") - self._remove_entry(self._apis[entry.full_name]) + entry = APIEntry.from_metadata(name=normalized_name, plugin_id=plugin_id, metadata=metadata) + existing_entry = self._apis.get(entry.registry_key) + if existing_entry is not None: + logger.warning(f"API {entry.registry_key} 已存在,覆盖旧条目") + self._remove_entry(existing_entry) - self._apis[entry.full_name] = entry + self._apis[entry.registry_key] = entry + self._by_full_name.setdefault(entry.full_name, []).append(entry) self._by_plugin.setdefault(plugin_id, []).append(entry) self._by_name.setdefault(entry.name, []).append(entry) return True def register_plugin_apis(self, plugin_id: str, components: List[Dict[str, Any]]) -> int: - """批量注册某个插件声明的全部 API。 - - Args: - plugin_id: 插件 ID。 - components: 插件组件声明列表。 - - Returns: - int: 成功注册的 API 数量。 - """ + """批量注册某个插件声明的全部 API。""" count = 0 for component in components: @@ -142,14 +158,60 @@ class APIRegistry: count += 1 return count + def replace_plugin_dynamic_apis( + self, + plugin_id: str, + components: List[Dict[str, Any]], + *, + offline_reason: str = "动态 API 已下线", + ) -> Tuple[int, int]: + """替换指定插件当前声明的动态 API 集合。""" + + normalized_offline_reason = str(offline_reason or "").strip() or "动态 API 已下线" + desired_registry_keys: Set[str] = set() + registered_count = 0 + + for component in components: + if not self._is_api_component(component.get("component_type")): + continue + metadata = component.get("metadata", {}) if isinstance(component.get("metadata"), dict) else {} + dynamic_metadata = dict(metadata) + dynamic_metadata["dynamic"] = True + dynamic_metadata.pop("offline_reason", None) + + entry = APIEntry.from_metadata( + name=str(component.get("name", "") or ""), + plugin_id=plugin_id, + metadata=dynamic_metadata, + ) + desired_registry_keys.add(entry.registry_key) + if self.register_api(entry.name, plugin_id, dynamic_metadata): + registered_count += 1 + + offlined_count = 0 + for entry in list(self._by_plugin.get(plugin_id, [])): + if not entry.dynamic or entry.registry_key in desired_registry_keys: + continue + entry.enabled = False + entry.offline_reason = normalized_offline_reason + entry.metadata["offline_reason"] = normalized_offline_reason + offlined_count += 1 + + return registered_count, offlined_count + def _remove_entry(self, entry: APIEntry) -> None: - """从全部索引中移除单个 API 条目。 + """从全部索引中移除单个 API 条目。""" - Args: - entry: 待移除的 API 条目。 - """ + self._apis.pop(entry.registry_key, None) + + full_name_entries = self._by_full_name.get(entry.full_name) + if full_name_entries is not None: + self._by_full_name[entry.full_name] = [ + candidate for candidate in full_name_entries if candidate is not entry + ] + if not self._by_full_name[entry.full_name]: + self._by_full_name.pop(entry.full_name, None) - self._apis.pop(entry.full_name, None) plugin_entries = self._by_plugin.get(entry.plugin_id) if plugin_entries is not None: self._by_plugin[entry.plugin_id] = [candidate for candidate in plugin_entries if candidate is not entry] @@ -163,14 +225,7 @@ class APIRegistry: self._by_name.pop(entry.name, None) def remove_apis_by_plugin(self, plugin_id: str) -> int: - """移除某个插件的全部 API。 - - Args: - plugin_id: 目标插件 ID。 - - Returns: - int: 被移除的 API 数量。 - """ + """移除某个插件的全部 API。""" entries = list(self._by_plugin.get(plugin_id, [])) for entry in entries: @@ -181,49 +236,48 @@ class APIRegistry: self, full_name: str, *, + version: str = "", enabled_only: bool = True, session_id: Optional[str] = None, ) -> Optional[APIEntry]: - """按完整名查询单个 API。 + """按完整名查询单个 API。""" - Args: - full_name: API 完整名,格式为 ``plugin_id.api_name``。 - enabled_only: 是否仅返回启用状态的 API。 - session_id: 可选的会话 ID。 - - Returns: - Optional[APIEntry]: 命中时返回 API 条目。 - """ - - entry = self._apis.get(full_name) - if entry is None: + normalized_full_name, normalized_version = self._split_reference(full_name, version) + if not normalized_full_name: return None - if enabled_only and not self.check_api_enabled(entry, session_id): + + if normalized_version: + entry = self._apis.get(f"{normalized_full_name}@{normalized_version}") + if entry is None: + return None + if enabled_only and not self.check_api_enabled(entry, session_id): + return None + return entry + + candidates = list(self._by_full_name.get(normalized_full_name, [])) + filtered_entries = [ + entry + for entry in candidates + if not enabled_only or self.check_api_enabled(entry, session_id) + ] + if len(filtered_entries) != 1: return None - return entry + return filtered_entries[0] def get_api( self, plugin_id: str, name: str, *, + version: str = "", enabled_only: bool = True, session_id: Optional[str] = None, ) -> Optional[APIEntry]: - """按插件 ID 和短名查询单个 API。 - - Args: - plugin_id: 提供方插件 ID。 - name: API 短名。 - enabled_only: 是否仅返回启用状态的 API。 - session_id: 可选的会话 ID。 - - Returns: - Optional[APIEntry]: 命中时返回 API 条目。 - """ + """按插件 ID、短名与版本查询单个 API。""" return self.get_api_by_full_name( f"{plugin_id}.{name}", + version=version, enabled_only=enabled_only, session_id=session_id, ) @@ -233,22 +287,15 @@ class APIRegistry: *, plugin_id: Optional[str] = None, name: str = "", + version: str = "", enabled_only: bool = True, session_id: Optional[str] = None, ) -> List[APIEntry]: - """查询 API 列表。 - - Args: - plugin_id: 可选的插件 ID 过滤条件。 - name: 可选的 API 名称过滤条件。 - enabled_only: 是否仅返回启用状态的 API。 - session_id: 可选的会话 ID。 - - Returns: - List[APIEntry]: 符合条件的 API 条目列表。 - """ + """查询 API 列表。""" normalized_name = str(name or "").strip() + normalized_version = self._normalize_query_version(version) + if plugin_id: candidates = list(self._by_plugin.get(plugin_id, [])) elif normalized_name: @@ -258,26 +305,35 @@ class APIRegistry: filtered_entries: List[APIEntry] = [] for entry in candidates: + if plugin_id and entry.plugin_id != plugin_id: + continue if normalized_name and entry.name != normalized_name: continue + if normalized_version and entry.version != normalized_version: + continue if enabled_only and not self.check_api_enabled(entry, session_id): continue filtered_entries.append(entry) + + filtered_entries.sort(key=lambda entry: (entry.plugin_id, entry.name, entry.version)) return filtered_entries - def toggle_api_status(self, full_name: str, enabled: bool, session_id: Optional[str] = None) -> bool: - """设置指定 API 的启用状态。 + def toggle_api_status( + self, + full_name: str, + enabled: bool, + *, + version: str = "", + session_id: Optional[str] = None, + ) -> bool: + """设置指定 API 的启用状态。""" - Args: - full_name: API 完整名。 - enabled: 目标启用状态。 - session_id: 可选的会话 ID,仅对该会话生效。 - - Returns: - bool: 是否设置成功。 - """ - - entry = self._apis.get(full_name) + entry = self.get_api_by_full_name( + full_name, + version=version, + enabled_only=False, + session_id=session_id, + ) if entry is None: return False if session_id: @@ -287,4 +343,7 @@ class APIRegistry: entry.disabled_session.add(session_id) else: entry.enabled = enabled + if enabled: + entry.offline_reason = "" + entry.metadata.pop("offline_reason", None) return True diff --git a/src/plugin_runtime/host/authorization.py b/src/plugin_runtime/host/authorization.py index 3fb48c6a..70593768 100644 --- a/src/plugin_runtime/host/authorization.py +++ b/src/plugin_runtime/host/authorization.py @@ -7,6 +7,8 @@ from dataclasses import dataclass, field from typing import Dict, List, Optional, Set, Tuple +_ALWAYS_ALLOWED_CAPABILITIES = frozenset({"api.replace_dynamic"}) + @dataclass class CapabilityPermissionToken: @@ -46,6 +48,9 @@ class AuthorizationManager: Returns: return (bool, str): (是否有此能力, 原因) """ + if capability in _ALWAYS_ALLOWED_CAPABILITIES: + return True, "" + token = self._permission_tokens.get(plugin_id) if not token: return False, f"插件 {plugin_id} 未注册能力令牌" diff --git a/src/plugin_runtime/runner/runner_main.py b/src/plugin_runtime/runner/runner_main.py index b94b01d1..b38946d6 100644 --- a/src/plugin_runtime/runner/runner_main.py +++ b/src/plugin_runtime/runner/runner_main.py @@ -45,6 +45,14 @@ from src.plugin_runtime.runner.rpc_client import RPCClient logger = get_logger("plugin_runtime.runner.main") +_PLUGIN_ALLOWED_RAW_HOST_METHODS = frozenset( + { + "cap.call", + "host.route_message", + "host.update_message_gateway_state", + } +) + class _ContextAwarePlugin(Protocol): """支持注入运行时上下文的插件协议。 @@ -247,8 +255,14 @@ class PluginRunner: logger.warning( f"插件 {bound_plugin_id} 尝试以 {plugin_id} 身份发起 RPC,已强制绑定回自身身份" ) + normalized_method = str(method or "").strip() + if normalized_method not in _PLUGIN_ALLOWED_RAW_HOST_METHODS: + raise PermissionError( + f"插件 {bound_plugin_id} 不允许直接调用 Host 原始 RPC 方法: " + f"{normalized_method or ''}" + ) resp = await rpc_client.send_request( - method=method, + method=normalized_method, plugin_id=bound_plugin_id, payload=payload or {}, ) diff --git a/src/services/send_service.py b/src/services/send_service.py index 54f2a9de..134fb15e 100644 --- a/src/services/send_service.py +++ b/src/services/send_service.py @@ -57,20 +57,23 @@ def _inherit_platform_io_route_metadata(target_stream: BotChatSession) -> Dict[s inherited_metadata: Dict[str, object] = {} context_message = target_stream.context.message if target_stream.context else None - if context_message is None: - return inherited_metadata + if context_message is not None: + additional_config = context_message.message_info.additional_config + if isinstance(additional_config, dict): + for key in (*RouteKeyFactory.ACCOUNT_ID_KEYS, *RouteKeyFactory.SCOPE_KEYS): + value = additional_config.get(key) + if value is None: + continue + normalized_value = str(value).strip() + if normalized_value: + inherited_metadata[key] = value - additional_config = context_message.message_info.additional_config - if not isinstance(additional_config, dict): - return inherited_metadata - - for key in (*RouteKeyFactory.ACCOUNT_ID_KEYS, *RouteKeyFactory.SCOPE_KEYS): - value = additional_config.get(key) - if value is None: - continue - normalized_value = str(value).strip() - if normalized_value: - inherited_metadata[key] = value + # 当目标会话没有可继承的上下文消息时,至少补齐当前平台账号, + # 让按 ``platform + account_id`` 绑定的路由仍有机会命中。 + if not RouteKeyFactory.extract_components(inherited_metadata)[0]: + bot_account = get_bot_account(target_stream.platform) + if bot_account: + inherited_metadata["platform_io_account_id"] = bot_account if target_stream.group_id and (normalized_group_id := str(target_stream.group_id).strip()): inherited_metadata["platform_io_target_group_id"] = normalized_group_id From 0c508995ddbfa42020485bdc23f4d5a62b1e54f1 Mon Sep 17 00:00:00 2001 From: DrSmoothl <1787882683@qq.com> Date: Mon, 23 Mar 2026 21:48:19 +0800 Subject: [PATCH 36/45] feat: enhance session ID calculation and plugin management - Updated `calculate_session_id` method in `SessionUtils` to include optional `account_id` and `scope` parameters for more granular session ID generation. - Added new environment variables in `plugin_runtime` for external plugin dependencies and global configuration snapshots. - Introduced methods in `RuntimeComponentManagerProtocol` for loading and reloading plugins globally, accommodating external dependencies. - Enhanced `PluginRunnerSupervisor` to manage external available plugin IDs during plugin reloads. - Implemented dependency extraction and management in `PluginRuntimeManager` to handle cross-supervisor dependencies. - Added tests for session ID calculation and message registration in `ChatManager` to ensure correct behavior with new parameters. --- pytests/test_plugin_runtime.py | 114 ++++--- pytests/utils_test/test_session_utils.py | 42 +++ src/chat/message_receive/bot.py | 16 +- src/chat/message_receive/chat_manager.py | 58 +++- src/common/utils/utils_session.py | 22 +- src/plugin_runtime/__init__.py | 6 + src/plugin_runtime/capabilities/components.py | 136 ++++---- src/plugin_runtime/host/supervisor.py | 86 +++-- src/plugin_runtime/integration.py | 302 +++++++++++++++++- src/plugin_runtime/protocol/envelope.py | 4 + src/plugin_runtime/runner/plugin_loader.py | 9 +- src/plugin_runtime/runner/runner_main.py | 140 +++++++- 12 files changed, 765 insertions(+), 170 deletions(-) create mode 100644 pytests/utils_test/test_session_utils.py diff --git a/pytests/test_plugin_runtime.py b/pytests/test_plugin_runtime.py index 9b46f897..e094d85b 100644 --- a/pytests/test_plugin_runtime.py +++ b/pytests/test_plugin_runtime.py @@ -3,6 +3,7 @@ 验证协议层、传输层、RPC 通信链路的正确性。 """ +from pathlib import Path from types import SimpleNamespace import asyncio @@ -2362,6 +2363,8 @@ class TestIntegration: from src.plugin_runtime import integration as integration_module instances = [] + builtin_dir = Path("builtin") + thirdparty_dir = Path("thirdparty") class FakeCapabilityService: def register_capability(self, name, impl): @@ -2369,11 +2372,18 @@ class TestIntegration: class FakeSupervisor: def __init__(self, plugin_dirs=None, socket_path=None): - self.plugin_dirs = plugin_dirs or [] + self._plugin_dirs = plugin_dirs or [] self.capability_service = FakeCapabilityService() + self.external_plugin_ids = [] self.stopped = False instances.append(self) + def set_external_available_plugin_ids(self, plugin_ids): + self.external_plugin_ids = list(plugin_ids) + + def get_loaded_plugin_ids(self): + return [] + async def start(self): if len(instances) == 2 and self is instances[1]: raise RuntimeError("boom") @@ -2382,10 +2392,10 @@ class TestIntegration: self.stopped = True monkeypatch.setattr( - integration_module.PluginRuntimeManager, "_get_builtin_plugin_dirs", staticmethod(lambda: ["builtin"]) + integration_module.PluginRuntimeManager, "_get_builtin_plugin_dirs", staticmethod(lambda: [builtin_dir]) ) monkeypatch.setattr( - integration_module.PluginRuntimeManager, "_get_thirdparty_plugin_dirs", staticmethod(lambda: ["thirdparty"]) + integration_module.PluginRuntimeManager, "_get_third_party_plugin_dirs", staticmethod(lambda: [thirdparty_dir]) ) import src.plugin_runtime.host.supervisor as supervisor_module @@ -2427,8 +2437,11 @@ class TestIntegration: self.reload_reasons = [] self.config_updates = [] - async def reload_plugins(self, plugin_ids=None, reason="manual"): - self.reload_reasons.append((plugin_ids, reason)) + def get_loaded_plugin_ids(self): + return sorted(self._registered_plugins.keys()) + + async def reload_plugins(self, plugin_ids=None, reason="manual", external_available_plugins=None): + self.reload_reasons.append((plugin_ids, reason, external_available_plugins or [])) async def notify_plugin_config_updated(self, plugin_id, config_data, config_version=""): self.config_updates.append((plugin_id, config_data, config_version)) @@ -2453,11 +2466,59 @@ class TestIntegration: await manager._handle_plugin_source_changes(changes) assert manager._builtin_supervisor.reload_reasons == [] - assert manager._third_party_supervisor.reload_reasons == [(["beta"], "file_watcher")] + assert manager._third_party_supervisor.reload_reasons == [(["beta"], "file_watcher", ["alpha"])] assert manager._builtin_supervisor.config_updates == [] assert manager._third_party_supervisor.config_updates == [] assert refresh_calls == [True] + @pytest.mark.asyncio + async def test_reload_plugins_globally_warns_and_skips_cross_supervisor_dependents(self, monkeypatch): + from src.plugin_runtime import integration as integration_module + + class FakeRegistration: + def __init__(self, dependencies): + self.dependencies = dependencies + + class FakeSupervisor: + def __init__(self, registrations): + self._registered_plugins = registrations + self.reload_calls = [] + + def get_loaded_plugin_ids(self): + return sorted(self._registered_plugins.keys()) + + async def reload_plugins(self, plugin_ids=None, reason="manual", external_available_plugins=None): + self.reload_calls.append((plugin_ids, reason, sorted(external_available_plugins or []))) + return True + + builtin_supervisor = FakeSupervisor({"alpha": FakeRegistration([])}) + third_party_supervisor = FakeSupervisor( + { + "beta": FakeRegistration(["alpha"]), + "gamma": FakeRegistration(["beta"]), + } + ) + + manager = integration_module.PluginRuntimeManager() + manager._builtin_supervisor = builtin_supervisor + manager._third_party_supervisor = third_party_supervisor + warning_messages = [] + + monkeypatch.setattr( + integration_module.logger, + "warning", + lambda message: warning_messages.append(message), + ) + + reloaded = await manager.reload_plugins_globally(["alpha"], reason="manual") + + assert reloaded is True + assert builtin_supervisor.reload_calls == [(["alpha"], "manual", ["beta", "gamma"])] + assert third_party_supervisor.reload_calls == [] + assert len(warning_messages) == 1 + assert "beta, gamma" in warning_messages[0] + assert "跨 Supervisor API 调用仍然可用" in warning_messages[0] + @pytest.mark.asyncio async def test_handle_plugin_config_changes_only_notify_target_plugin(self, monkeypatch, tmp_path): from src.plugin_runtime import integration as integration_module @@ -2623,55 +2684,30 @@ class TestIntegration: async def test_component_reload_plugin_returns_failure_when_reload_rolls_back(self, monkeypatch): from src.plugin_runtime import integration as integration_module - class FakeSupervisor: - def __init__(self): - self._registered_plugins = {"alpha": object()} + manager = integration_module.PluginRuntimeManager() + monkeypatch.setattr(manager, "reload_plugins_globally", lambda plugin_ids, reason="manual": asyncio.sleep(0, False)) - async def reload_plugins(self, reason="manual"): - return False - - class FakeManager: - def __init__(self): - self.supervisors = [FakeSupervisor()] - - monkeypatch.setattr(integration_module, "get_plugin_runtime_manager", lambda: FakeManager()) - - result = await integration_module.PluginRuntimeManager._cap_component_reload_plugin( + result = await manager._cap_component_reload_plugin( "plugin_a", "component.reload_plugin", {"plugin_name": "alpha"}, ) assert result["success"] is False - assert "已回滚" in result["error"] + assert result["error"] == "插件 alpha 热重载失败" @pytest.mark.asyncio async def test_component_load_plugin_returns_failure_when_reload_rolls_back(self, monkeypatch, tmp_path): from src.plugin_runtime import integration as integration_module - plugin_root = tmp_path / "plugins" - plugin_root.mkdir() - (plugin_root / "alpha").mkdir() + manager = integration_module.PluginRuntimeManager() + monkeypatch.setattr(manager, "load_plugin_globally", lambda plugin_id, reason="manual": asyncio.sleep(0, False)) - class FakeSupervisor: - def __init__(self): - self._registered_plugins = {} - self._plugin_dirs = [str(plugin_root)] - - async def reload_plugins(self, reason="manual"): - return False - - class FakeManager: - def __init__(self): - self.supervisors = [FakeSupervisor()] - - monkeypatch.setattr(integration_module, "get_plugin_runtime_manager", lambda: FakeManager()) - - result = await integration_module.PluginRuntimeManager._cap_component_load_plugin( + result = await manager._cap_component_load_plugin( "plugin_a", "component.load_plugin", {"plugin_name": "alpha"}, ) assert result["success"] is False - assert "已回滚" in result["error"] + assert result["error"] == "插件 alpha 热重载失败" diff --git a/pytests/utils_test/test_session_utils.py b/pytests/utils_test/test_session_utils.py new file mode 100644 index 00000000..c44e2eba --- /dev/null +++ b/pytests/utils_test/test_session_utils.py @@ -0,0 +1,42 @@ +from types import SimpleNamespace + +from src.chat.message_receive.chat_manager import ChatManager +from src.common.utils.utils_session import SessionUtils + + +def test_calculate_session_id_distinguishes_account_and_scope() -> None: + base_session_id = SessionUtils.calculate_session_id("qq", user_id="42") + same_base_session_id = SessionUtils.calculate_session_id("qq", user_id="42") + account_scoped_session_id = SessionUtils.calculate_session_id("qq", user_id="42", account_id="123") + route_scoped_session_id = SessionUtils.calculate_session_id("qq", user_id="42", account_id="123", scope="main") + + assert base_session_id == same_base_session_id + assert account_scoped_session_id != base_session_id + assert route_scoped_session_id != account_scoped_session_id + + +def test_chat_manager_register_message_uses_route_metadata() -> None: + chat_manager = ChatManager() + message = SimpleNamespace( + platform="qq", + session_id="", + message_info=SimpleNamespace( + user_info=SimpleNamespace(user_id="42"), + group_info=SimpleNamespace(group_id="1000"), + additional_config={ + "platform_io_account_id": "123", + "platform_io_scope": "main", + }, + ), + ) + + chat_manager.register_message(message) + + assert message.session_id == SessionUtils.calculate_session_id( + "qq", + user_id="42", + group_id="1000", + account_id="123", + scope="main", + ) + assert chat_manager.last_messages[message.session_id] is message diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py index 025150fc..1fc4ef53 100644 --- a/src/chat/message_receive/bot.py +++ b/src/chat/message_receive/bot.py @@ -10,6 +10,7 @@ from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiv from src.common.logger import get_logger from src.common.utils.utils_message import MessageUtils from src.common.utils.utils_session import SessionUtils +from src.platform_io.route_key_factory import RouteKeyFactory # from src.chat.brain_chat.PFC.pfc_manager import PFCManager from src.core.announcement_manager import global_announcement_manager @@ -270,11 +271,18 @@ class ChatBot: try: group_info = message.message_info.group_info user_info = message.message_info.user_info + account_id = None + scope = None + additional_config = message.message_info.additional_config + if isinstance(additional_config, dict): + account_id, scope = RouteKeyFactory.extract_components(additional_config) session_id = SessionUtils.calculate_session_id( message.platform, user_id=message.message_info.user_info.user_id, group_id=group_info.group_id if group_info else None, + account_id=account_id, + scope=scope, ) message.session_id = session_id # 正确初始化session_id @@ -317,7 +325,13 @@ class ChatBot: platform = message.platform user_id = user_info.user_id group_id = group_info.group_id if group_info else None - _ = await chat_manager.get_or_create_session(platform, user_id, group_id) # 确保会话存在 + _ = await chat_manager.get_or_create_session( + platform, + user_id, + group_id, + account_id=account_id, + scope=scope, + ) # 确保会话存在 # message.update_chat_stream(chat) diff --git a/src/chat/message_receive/chat_manager.py b/src/chat/message_receive/chat_manager.py index b11d233c..48d89956 100644 --- a/src/chat/message_receive/chat_manager.py +++ b/src/chat/message_receive/chat_manager.py @@ -1,15 +1,16 @@ +import asyncio from datetime import datetime +from typing import TYPE_CHECKING, Dict, List, Optional + from rich.traceback import install from sqlmodel import select -from typing import Optional, TYPE_CHECKING, List, Dict -import asyncio - -from src.common.logger import get_logger from src.common.data_models.chat_session_data_model import MaiChatSession -from src.common.database.database_model import ChatSession from src.common.database.database import get_db_session +from src.common.database.database_model import ChatSession +from src.common.logger import get_logger from src.common.utils.utils_session import SessionUtils +from src.platform_io.route_key_factory import RouteKeyFactory if TYPE_CHECKING: from .message import SessionMessage @@ -82,7 +83,12 @@ class ChatManager: logger.error(f"初始化聊天管理器出现错误: {e}") async def get_or_create_session( - self, platform: str, user_id: str, group_id: Optional[str] = None + self, + platform: str, + user_id: str, + group_id: Optional[str] = None, + account_id: Optional[str] = None, + scope: Optional[str] = None, ) -> BotChatSession: """获取会话,如果不存在则创建一个新会话;一个封装方法。 @@ -90,12 +96,20 @@ class ChatManager: platform: 平台 user_id: 用户ID group_id: 群ID(如果是群聊) + account_id: 平台账号 ID + scope: 路由作用域 Returns: return (BotChatSession) 会话对象 Raises: Exception: 获取或创建会话时发生错误 """ - session_id = SessionUtils.calculate_session_id(platform, user_id=user_id, group_id=group_id) + session_id = SessionUtils.calculate_session_id( + platform, + user_id=user_id, + group_id=group_id, + account_id=account_id, + scope=scope, + ) if session := self.get_session_by_session_id(session_id): session.update_active_time() return session @@ -131,7 +145,18 @@ class ChatManager: raise ValueError("消息缺少平台信息") user_id = message.message_info.user_info.user_id group_id = message.message_info.group_info.group_id if message.message_info.group_info else None - session_id = SessionUtils.calculate_session_id(platform, user_id=user_id, group_id=group_id) + account_id = None + scope = None + additional_config = message.message_info.additional_config + if isinstance(additional_config, dict): + account_id, scope = RouteKeyFactory.extract_components(additional_config) + session_id = SessionUtils.calculate_session_id( + platform, + user_id=user_id, + group_id=group_id, + account_id=account_id, + scope=scope, + ) message.session_id = session_id # 确保消息的session_id正确设置 self.last_messages[session_id] = message @@ -188,7 +213,12 @@ class ChatManager: return None def get_session_by_info( - self, platform: str, user_id: Optional[str] = None, group_id: Optional[str] = None + self, + platform: str, + user_id: Optional[str] = None, + group_id: Optional[str] = None, + account_id: Optional[str] = None, + scope: Optional[str] = None, ) -> Optional[BotChatSession]: """根据平台、用户ID和群ID获取对应的会话 @@ -196,10 +226,18 @@ class ChatManager: platform: 平台 user_id: 用户ID group_id: 群ID(如果是群聊) + account_id: 平台账号 ID + scope: 路由作用域 Returns: return (Optional[BotChatSession]): 会话对象,如果不存在则返回None """ - session_id = SessionUtils.calculate_session_id(platform, user_id=user_id, group_id=group_id) + session_id = SessionUtils.calculate_session_id( + platform, + user_id=user_id, + group_id=group_id, + account_id=account_id, + scope=scope, + ) return self.get_session_by_session_id(session_id) def get_session_by_session_id(self, session_id: str) -> Optional[BotChatSession]: diff --git a/src/common/utils/utils_session.py b/src/common/utils/utils_session.py index a383f5a2..1b6d8f72 100644 --- a/src/common/utils/utils_session.py +++ b/src/common/utils/utils_session.py @@ -5,13 +5,22 @@ import hashlib class SessionUtils: @staticmethod - def calculate_session_id(platform: str, *, user_id: Optional[str] = None, group_id: Optional[str] = None) -> str: + def calculate_session_id( + platform: str, + *, + user_id: Optional[str] = None, + group_id: Optional[str] = None, + account_id: Optional[str] = None, + scope: Optional[str] = None, + ) -> str: """计算session_id Args: platform: 平台名称 user_id: 用户ID(如果是私聊) group_id: 群ID(如果是群聊) + account_id: 当前平台账号 ID,可选 + scope: 当前路由作用域,可选 Returns: str: 计算得到的会话ID Raises: @@ -19,8 +28,15 @@ class SessionUtils: """ if not user_id and not group_id: raise ValueError("UserID 或 GroupID 必须提供其一") + + route_components = [] + if account_id: + route_components.append(f"account:{account_id}") + if scope: + route_components.append(f"scope:{scope}") + if group_id: - components = [platform, group_id] + components = [platform, *route_components, group_id] else: - components = [platform, user_id, "private"] + components = [platform, *route_components, user_id, "private"] return hashlib.md5("_".join(components).encode()).hexdigest() diff --git a/src/plugin_runtime/__init__.py b/src/plugin_runtime/__init__.py index a881d399..704ce514 100644 --- a/src/plugin_runtime/__init__.py +++ b/src/plugin_runtime/__init__.py @@ -16,3 +16,9 @@ ENV_PLUGIN_DIRS = "MAIBOT_PLUGIN_DIRS" ENV_HOST_VERSION = "MAIBOT_HOST_VERSION" """Runner 读取的 Host 应用版本号,用于 manifest 兼容性校验""" + +ENV_EXTERNAL_PLUGIN_IDS = "MAIBOT_EXTERNAL_PLUGIN_IDS" +"""Runner 启动时可视为已满足的外部插件依赖列表(JSON 数组)""" + +ENV_GLOBAL_CONFIG_SNAPSHOT = "MAIBOT_GLOBAL_CONFIG_SNAPSHOT" +"""Runner 启动时注入的全局配置快照(JSON 对象)""" diff --git a/src/plugin_runtime/capabilities/components.py b/src/plugin_runtime/capabilities/components.py index 67033fdd..2e4c111c 100644 --- a/src/plugin_runtime/capabilities/components.py +++ b/src/plugin_runtime/capabilities/components.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Protocol +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Protocol, Sequence from src.common.logger import get_logger @@ -15,8 +15,35 @@ class _RuntimeComponentManagerProtocol(Protocol): @property def supervisors(self) -> List["PluginSupervisor"]: ... + def _normalize_component_type(self, component_type: str) -> str: ... + + def _is_api_component_type(self, component_type: str) -> bool: ... + + def _serialize_api_entry(self, entry: "APIEntry") -> Dict[str, Any]: ... + + def _serialize_api_component_entry(self, entry: "APIEntry") -> Dict[str, Any]: ... + + def _is_api_visible_to_plugin(self, entry: "APIEntry", caller_plugin_id: str) -> bool: ... + + def _normalize_api_reference(self, api_name: str, version: str = "") -> tuple[str, str]: ... + + def _build_api_unavailable_error(self, entry: "APIEntry") -> str: ... + def _get_supervisor_for_plugin(self, plugin_id: str) -> Optional["PluginSupervisor"]: ... + def _resolve_api_target( + self, + caller_plugin_id: str, + api_name: str, + version: str = "", + ) -> tuple[Optional["PluginSupervisor"], Optional["APIEntry"], Optional[str]]: ... + + def _resolve_api_toggle_target( + self, + name: str, + version: str = "", + ) -> tuple[Optional["PluginSupervisor"], Optional["APIEntry"], Optional[str]]: ... + def _resolve_component_toggle_target( self, name: str, component_type: str ) -> tuple[Optional["ComponentEntry"], Optional[str]]: ... @@ -25,6 +52,10 @@ class _RuntimeComponentManagerProtocol(Protocol): def _iter_plugin_dirs(self) -> Iterable[Path]: ... + async def load_plugin_globally(self, plugin_id: str, reason: str = "manual") -> bool: ... + + async def reload_plugins_globally(self, plugin_ids: Sequence[str], reason: str = "manual") -> bool: ... + class RuntimeComponentCapabilityMixin: @staticmethod @@ -266,20 +297,22 @@ class RuntimeComponentCapabilityMixin: version=normalized_version, enabled_only=False, ) - if not entries: - return None, None, f"未找到 API: {normalized_name}" - if len(entries) > 1: + if len(entries) == 1: + return supervisor, entries[0], None + if entries: return None, None, f"API {normalized_name} 存在多个版本,请显式指定 version" - return supervisor, entries[0], None + return None, None, f"未找到 API: {normalized_name}" matches: List[tuple["PluginSupervisor", "APIEntry"]] = [] for supervisor in self.supervisors: - for entry in supervisor.api_registry.get_apis( - name=normalized_name, - version=normalized_version, - enabled_only=False, - ): - matches.append((supervisor, entry)) + matches.extend( + (supervisor, entry) + for entry in supervisor.api_registry.get_apis( + name=normalized_name, + version=normalized_version, + enabled_only=False, + ) + ) if len(matches) == 1: return matches[0][0], matches[0][1], None @@ -453,39 +486,14 @@ class RuntimeComponentCapabilityMixin: return {"success": False, "error": f"检测到重复插件 ID,拒绝热重载: {details}"} try: - registered_supervisor = self._get_supervisor_for_plugin(plugin_name) - except RuntimeError as exc: - return {"success": False, "error": str(exc)} + loaded = await self.load_plugin_globally(plugin_name, reason=f"load {plugin_name}") + except Exception as e: + logger.error(f"[cap.component.load_plugin] 热重载失败: {e}") + return {"success": False, "error": str(e)} - if registered_supervisor is not None: - try: - reloaded = await registered_supervisor.reload_plugins( - plugin_ids=[plugin_name], - reason=f"load {plugin_name}", - ) - if reloaded: - return {"success": True, "count": 1} - return {"success": False, "error": f"插件 {plugin_name} 热重载失败,已回滚"} - except Exception as e: - logger.error(f"[cap.component.load_plugin] 热重载失败: {e}") - return {"success": False, "error": str(e)} - - for sv in self.supervisors: - for pdir in sv._plugin_dirs: - if (pdir / plugin_name).is_dir(): - try: - reloaded = await sv.reload_plugins( - plugin_ids=[plugin_name], - reason=f"load {plugin_name}", - ) - if reloaded: - return {"success": True, "count": 1} - return {"success": False, "error": f"插件 {plugin_name} 热重载失败,已回滚"} - except Exception as e: - logger.error(f"[cap.component.load_plugin] 热重载失败: {e}") - return {"success": False, "error": str(e)} - - return {"success": False, "error": f"未找到插件: {plugin_name}"} + if loaded: + return {"success": True, "count": 1} + return {"success": False, "error": f"插件 {plugin_name} 热重载失败"} async def _cap_component_unload_plugin( self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any] @@ -507,23 +515,14 @@ class RuntimeComponentCapabilityMixin: return {"success": False, "error": f"检测到重复插件 ID,拒绝热重载: {details}"} try: - sv = self._get_supervisor_for_plugin(plugin_name) - except RuntimeError as exc: - return {"success": False, "error": str(exc)} + reloaded = await self.reload_plugins_globally([plugin_name], reason=f"reload {plugin_name}") + except Exception as e: + logger.error(f"[cap.component.reload_plugin] 热重载失败: {e}") + return {"success": False, "error": str(e)} - if sv is not None: - try: - reloaded = await sv.reload_plugins( - plugin_ids=[plugin_name], - reason=f"reload {plugin_name}", - ) - if reloaded: - return {"success": True} - return {"success": False, "error": f"插件 {plugin_name} 热重载失败,已回滚"} - except Exception as e: - logger.error(f"[cap.component.reload_plugin] 热重载失败: {e}") - return {"success": False, "error": str(e)} - return {"success": False, "error": f"未找到插件: {plugin_name}"} + if reloaded: + return {"success": True} + return {"success": False, "error": f"插件 {plugin_name} 热重载失败"} async def _cap_api_call( self: _RuntimeComponentManagerProtocol, @@ -632,15 +631,16 @@ class RuntimeComponentCapabilityMixin: ) apis: List[Dict[str, Any]] = [] for supervisor in self.supervisors: - for entry in supervisor.api_registry.get_apis( - plugin_id=target_plugin_id or None, - name=api_name, - version=version, - enabled_only=True, - ): - if not self._is_api_visible_to_plugin(entry, plugin_id): - continue - apis.append(self._serialize_api_entry(entry)) + apis.extend( + self._serialize_api_entry(entry) + for entry in supervisor.api_registry.get_apis( + plugin_id=target_plugin_id or None, + name=api_name, + version=version, + enabled_only=True, + ) + if self._is_api_visible_to_plugin(entry, plugin_id) + ) apis.sort(key=lambda item: (str(item["plugin_id"]), str(item["name"]), str(item["version"]))) return {"success": True, "apis": apis} diff --git a/src/plugin_runtime/host/supervisor.py b/src/plugin_runtime/host/supervisor.py index afe944e5..693eae51 100644 --- a/src/plugin_runtime/host/supervisor.py +++ b/src/plugin_runtime/host/supervisor.py @@ -4,17 +4,26 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple import asyncio import contextlib +import json import os import sys from src.common.logger import get_logger -from src.config.config import global_config +from src.config.config import config_manager, global_config from src.platform_io import DriverKind, InboundMessageEnvelope, RouteBinding, RouteKey, get_platform_io_manager from src.platform_io.drivers import PluginPlatformDriver from src.platform_io.route_key_factory import RouteKeyFactory -from src.plugin_runtime import ENV_HOST_VERSION, ENV_IPC_ADDRESS, ENV_PLUGIN_DIRS, ENV_SESSION_TOKEN +from src.plugin_runtime import ( + ENV_EXTERNAL_PLUGIN_IDS, + ENV_GLOBAL_CONFIG_SNAPSHOT, + ENV_HOST_VERSION, + ENV_IPC_ADDRESS, + ENV_PLUGIN_DIRS, + ENV_SESSION_TOKEN, +) from src.plugin_runtime.protocol.envelope import ( BootstrapPluginPayload, + ConfigReloadScope, ConfigUpdatedPayload, Envelope, HealthPayload, @@ -107,6 +116,7 @@ class PluginRunnerSupervisor: self._runner_process: Optional[asyncio.subprocess.Process] = None self._registered_plugins: Dict[str, RegisterPluginPayload] = {} self._message_gateway_states: Dict[str, Dict[str, _MessageGatewayRuntimeState]] = {} + self._external_available_plugin_ids: List[str] = [] self._runner_ready_events: asyncio.Event = asyncio.Event() self._runner_ready_payloads: RunnerReadyPayload = RunnerReadyPayload() self._health_task: Optional[asyncio.Task[None]] = None @@ -156,6 +166,21 @@ class PluginRunnerSupervisor: """返回底层 RPC 服务端。""" return self._rpc_server + def set_external_available_plugin_ids(self, plugin_ids: List[str]) -> None: + """设置当前 Runner 启动/重载时可视为已满足的外部依赖列表。""" + + normalized_plugin_ids = { + str(plugin_id or "").strip() + for plugin_id in plugin_ids + if str(plugin_id or "").strip() + } + self._external_available_plugin_ids = sorted(normalized_plugin_ids) + + def get_loaded_plugin_ids(self) -> List[str]: + """返回当前 Supervisor 已注册的插件 ID 列表。""" + + return sorted(self._registered_plugins.keys()) + async def dispatch_event( self, event_type: str, @@ -344,12 +369,18 @@ class PluginRunnerSupervisor: timeout_ms=timeout_ms, ) - async def reload_plugin(self, plugin_id: str, reason: str = "manual") -> bool: + async def reload_plugin( + self, + plugin_id: str, + reason: str = "manual", + external_available_plugins: Optional[List[str]] = None, + ) -> bool: """按插件 ID 触发精确重载。 Args: plugin_id: 目标插件 ID。 reason: 重载原因。 + external_available_plugins: 视为已满足的外部依赖插件 ID 列表。 Returns: bool: 是否重载成功。 @@ -358,7 +389,11 @@ class PluginRunnerSupervisor: response = await self._rpc_server.send_request( "plugin.reload", plugin_id=plugin_id, - payload={"plugin_id": plugin_id, "reason": reason}, + payload={ + "plugin_id": plugin_id, + "reason": reason, + "external_available_plugins": external_available_plugins or self._external_available_plugin_ids, + }, timeout_ms=max(int(self._runner_spawn_timeout * 1000), 10000), ) except Exception as exc: @@ -374,12 +409,14 @@ class PluginRunnerSupervisor: self, plugin_ids: Optional[List[str]] = None, reason: str = "manual", + external_available_plugins: Optional[List[str]] = None, ) -> bool: """批量重载插件。 Args: plugin_ids: 目标插件 ID 列表;为空时重载当前已注册的全部插件。 reason: 重载原因。 + external_available_plugins: 视为已满足的外部依赖插件 ID 列表。 Returns: bool: 是否全部重载成功。 @@ -389,7 +426,11 @@ class PluginRunnerSupervisor: success = True for plugin_id in ordered_plugin_ids: - reloaded = await self.reload_plugin(plugin_id=plugin_id, reason=reason) + reloaded = await self.reload_plugin( + plugin_id=plugin_id, + reason=reason, + external_available_plugins=external_available_plugins, + ) success = success and reloaded return success @@ -399,7 +440,7 @@ class PluginRunnerSupervisor: plugin_id: str, config_data: Optional[Dict[str, Any]] = None, config_version: str = "", - config_scope: str = "self", + config_scope: str | ConfigReloadScope = "self", ) -> bool: """向 Runner 推送插件配置更新。 @@ -412,9 +453,15 @@ class PluginRunnerSupervisor: Returns: bool: 请求是否成功送达并被 Runner 接受。 """ + try: + normalized_scope = ConfigReloadScope(config_scope) + except ValueError: + logger.warning(f"插件 {plugin_id} 配置更新通知失败: 非法的 config_scope={config_scope}") + return False + payload = ConfigUpdatedPayload( plugin_id=plugin_id, - config_scope=config_scope, + config_scope=normalized_scope, config_version=config_version, config_data=config_data or {}, ) @@ -441,11 +488,11 @@ class PluginRunnerSupervisor: List[str]: 已声明订阅该范围的插件 ID 列表。 """ - matched_plugins: List[str] = [] - for plugin_id, registration in self._registered_plugins.items(): - if scope in registration.config_reload_subscriptions: - matched_plugins.append(plugin_id) - return matched_plugins + return [ + plugin_id + for plugin_id, registration in self._registered_plugins.items() + if scope in registration.config_reload_subscriptions + ] async def _wait_for_runner_connection(self, timeout_sec: float) -> None: """等待 Runner 建立 RPC 连接。 @@ -706,10 +753,7 @@ class PluginRunnerSupervisor: ) gateways = self._component_registry.get_message_gateways(plugin_id=plugin_id, enabled_only=False) - if len(gateways) == 1: - return gateways[0] - - return None + return gateways[0] if len(gateways) == 1 else None async def _register_message_gateway_driver( self, @@ -823,8 +867,7 @@ class PluginRunnerSupervisor: ValueError: 当平台信息缺失时抛出。 """ - platform = str(payload.platform or gateway_entry.platform or "").strip() - if not platform: + if not (platform := str(payload.platform or gateway_entry.platform or "").strip()): raise ValueError(f"消息网关 {gateway_entry.full_name} 未提供有效的平台名称") return RouteKey( @@ -1090,7 +1133,11 @@ class PluginRunnerSupervisor: Returns: Dict[str, str]: 传递给 Runner 进程的环境变量映射。 """ + global_config_snapshot = config_manager.get_global_config().model_dump() + global_config_snapshot["model"] = config_manager.get_model_config().model_dump() return { + ENV_EXTERNAL_PLUGIN_IDS: json.dumps(self._external_available_plugin_ids, ensure_ascii=False), + ENV_GLOBAL_CONFIG_SNAPSHOT: json.dumps(global_config_snapshot, ensure_ascii=False), ENV_HOST_VERSION: PROTOCOL_VERSION, ENV_IPC_ADDRESS: self._transport.get_address(), ENV_PLUGIN_DIRS: os.pathsep.join(str(path) for path in self._plugin_dirs), @@ -1136,8 +1183,7 @@ class PluginRunnerSupervisor: line = await stream.readline() if not line: return - message = line.decode("utf-8", errors="replace").rstrip() - if message: + if message := line.decode("utf-8", errors="replace").rstrip(): logger.warning(f"[runner-stderr] {message}") except asyncio.CancelledError: raise diff --git a/src/plugin_runtime/integration.py b/src/plugin_runtime/integration.py index e45b40de..d48260e5 100644 --- a/src/plugin_runtime/integration.py +++ b/src/plugin_runtime/integration.py @@ -8,7 +8,7 @@ """ from pathlib import Path -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Coroutine, Dict, Iterable, List, Optional, Sequence, Tuple +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Coroutine, Dict, Iterable, List, Optional, Sequence, Set, Tuple import asyncio import json @@ -102,6 +102,77 @@ class PluginRuntimeManager( candidate = Path("plugins").resolve() return [candidate] if candidate.is_dir() else [] + @staticmethod + def _extract_manifest_dependencies(manifest: Dict[str, Any]) -> List[str]: + """从插件 manifest 中提取规范化后的依赖插件 ID 列表。""" + + dependencies: List[str] = [] + for dependency in manifest.get("dependencies", []): + if isinstance(dependency, str): + normalized_dependency = dependency.strip() + elif isinstance(dependency, dict): + normalized_dependency = str(dependency.get("name", "") or "").strip() + else: + normalized_dependency = "" + + if normalized_dependency: + dependencies.append(normalized_dependency) + return dependencies + + @classmethod + def _discover_plugin_dependency_map(cls, plugin_dirs: Iterable[Path]) -> Dict[str, List[str]]: + """扫描指定插件目录集合,返回 ``plugin_id -> dependencies`` 映射。""" + + dependency_map: Dict[str, List[str]] = {} + for plugin_dir in cls._iter_candidate_plugin_paths(plugin_dirs): + manifest_path = plugin_dir / "_manifest.json" + entrypoint_path = plugin_dir / "plugin.py" + if not manifest_path.is_file() or not entrypoint_path.is_file(): + continue + + try: + with manifest_path.open("r", encoding="utf-8") as manifest_file: + manifest = json.load(manifest_file) + except Exception: + continue + + if not isinstance(manifest, dict): + continue + + plugin_id = str(manifest.get("name", plugin_dir.name) or "").strip() or plugin_dir.name + dependency_map[plugin_id] = cls._extract_manifest_dependencies(manifest) + return dependency_map + + @classmethod + def _build_group_start_order( + cls, + builtin_dirs: Sequence[Path], + third_party_dirs: Sequence[Path], + ) -> List[str]: + """根据跨 Supervisor 依赖关系决定 Runner 启动顺序。""" + + builtin_dependencies = cls._discover_plugin_dependency_map(builtin_dirs) + third_party_dependencies = cls._discover_plugin_dependency_map(third_party_dirs) + builtin_plugin_ids = set(builtin_dependencies) + third_party_plugin_ids = set(third_party_dependencies) + + builtin_needs_third_party = any( + dependency in third_party_plugin_ids + for dependencies in builtin_dependencies.values() + for dependency in dependencies + ) + third_party_needs_builtin = any( + dependency in builtin_plugin_ids + for dependencies in third_party_dependencies.values() + for dependency in dependencies + ) + + if builtin_needs_third_party and third_party_needs_builtin: + raise RuntimeError("检测到跨 Supervisor 循环依赖,当前无法安全启动独立 Runner") + if builtin_needs_third_party: + return ["third_party", "builtin"] + return ["builtin", "third_party"] + # ─── 生命周期 ───────────────────────────────────────────── async def start(self) -> None: @@ -161,12 +232,26 @@ class PluginRuntimeManager( platform_io_manager.set_inbound_dispatcher(self._dispatch_platform_inbound) await platform_io_manager.ensure_send_pipeline_ready() - if self._builtin_supervisor: - await self._builtin_supervisor.start() - started_supervisors.append(self._builtin_supervisor) - if self._third_party_supervisor: - await self._third_party_supervisor.start() - started_supervisors.append(self._third_party_supervisor) + supervisor_groups: Dict[str, Optional[PluginSupervisor]] = { + "builtin": self._builtin_supervisor, + "third_party": self._third_party_supervisor, + } + start_order = self._build_group_start_order(builtin_dirs, third_party_dirs) + + for group_name in start_order: + supervisor = supervisor_groups.get(group_name) + if supervisor is None: + continue + + external_plugin_ids = [ + plugin_id + for started_supervisor in started_supervisors + for plugin_id in started_supervisor.get_loaded_plugin_ids() + ] + supervisor.set_external_available_plugin_ids(external_plugin_ids) + await supervisor.start() + started_supervisors.append(supervisor) + await self._start_plugin_file_watcher() config_manager.register_reload_callback(self._config_reload_callback) self._config_reload_callback_registered = True @@ -238,6 +323,171 @@ class PluginRuntimeManager( """获取所有活跃的 Supervisor""" return [s for s in (self._builtin_supervisor, self._third_party_supervisor) if s is not None] + def _build_registered_dependency_map(self) -> Dict[str, Set[str]]: + """根据当前已注册插件构建全局依赖图。""" + + dependency_map: Dict[str, Set[str]] = {} + for supervisor in self.supervisors: + for plugin_id, registration in getattr(supervisor, "_registered_plugins", {}).items(): + dependency_map[plugin_id] = { + str(dependency or "").strip() + for dependency in getattr(registration, "dependencies", []) + if str(dependency or "").strip() + } + return dependency_map + + @staticmethod + def _collect_reverse_dependents( + plugin_ids: Set[str], + dependency_map: Dict[str, Set[str]], + ) -> Set[str]: + """根据依赖图收集反向依赖闭包。""" + + impacted_plugins: Set[str] = set(plugin_ids) + changed = True + + while changed: + changed = False + for registered_plugin_id, dependencies in dependency_map.items(): + if registered_plugin_id in impacted_plugins: + continue + if dependencies & impacted_plugins: + impacted_plugins.add(registered_plugin_id) + changed = True + + return impacted_plugins + + def _build_registered_supervisor_map(self) -> Dict[str, "PluginSupervisor"]: + """构建当前已注册插件到所属 Supervisor 的映射。""" + + return { + plugin_id: supervisor + for supervisor in self.supervisors + for plugin_id in supervisor.get_loaded_plugin_ids() + } + + def _build_external_available_plugins_for_supervisor(self, target_supervisor: "PluginSupervisor") -> List[str]: + """收集某个 Supervisor 可用的外部插件 ID 列表。""" + + external_plugin_ids: Set[str] = set() + for supervisor in self.supervisors: + if supervisor is target_supervisor: + continue + external_plugin_ids.update(supervisor.get_loaded_plugin_ids()) + return sorted(external_plugin_ids) + + def _find_supervisor_by_plugin_directory(self, plugin_id: str) -> Optional["PluginSupervisor"]: + """根据插件目录推断应负责该插件重载的 Supervisor。""" + + for supervisor in self.supervisors: + for plugin_dir in supervisor._plugin_dirs: + if (Path(plugin_dir) / plugin_id).is_dir(): + return supervisor + return None + + def _warn_skipped_cross_supervisor_reload( + self, + requested_loaded_plugin_ids: Set[str], + dependency_map: Dict[str, Set[str]], + supervisor_by_plugin: Dict[str, "PluginSupervisor"], + ) -> None: + """记录因跨 Supervisor 边界而未参与联动重载的插件。""" + + if not requested_loaded_plugin_ids: + return + + handled_plugin_ids: Set[str] = set() + for supervisor in self.supervisors: + local_requested_plugin_ids = { + plugin_id + for plugin_id in requested_loaded_plugin_ids + if supervisor_by_plugin.get(plugin_id) is supervisor + } + if not local_requested_plugin_ids: + continue + + local_plugin_ids = set(supervisor.get_loaded_plugin_ids()) + local_dependency_map = { + plugin_id: { + dependency + for dependency in dependency_map.get(plugin_id, set()) + if dependency in local_plugin_ids + } + for plugin_id in local_plugin_ids + } + handled_plugin_ids.update( + self._collect_reverse_dependents(local_requested_plugin_ids, local_dependency_map) + ) + + impacted_plugin_ids = self._collect_reverse_dependents(requested_loaded_plugin_ids, dependency_map) + skipped_plugin_ids = sorted(impacted_plugin_ids - handled_plugin_ids) + if not skipped_plugin_ids: + return + + logger.warning( + f"插件 {', '.join(sorted(requested_loaded_plugin_ids))} 存在跨 Supervisor 依赖方未联动重载: " + f"{', '.join(skipped_plugin_ids)}。当前仅在单个 Supervisor 内执行联动重载;" + "跨 Supervisor API 调用仍然可用。如需联动重载,请将相关插件放在同一个 Supervisor 内。" + ) + + async def reload_plugins_globally(self, plugin_ids: Sequence[str], reason: str = "manual") -> bool: + """按 Supervisor 分组执行精确重载。 + + 仅在单个 Supervisor 内执行依赖联动;跨 Supervisor 依赖方仅记录告警, + 不再自动参与本次热重载。 + """ + + normalized_plugin_ids = [ + normalized_plugin_id + for plugin_id in plugin_ids + if (normalized_plugin_id := str(plugin_id or "").strip()) + ] + if not normalized_plugin_ids: + return True + + dependency_map = self._build_registered_dependency_map() + supervisor_by_plugin = self._build_registered_supervisor_map() + supervisor_roots: Dict["PluginSupervisor", List[str]] = {} + requested_loaded_plugin_ids: Set[str] = set() + missing_plugin_ids: List[str] = [] + + for plugin_id in normalized_plugin_ids: + supervisor = supervisor_by_plugin.get(plugin_id) + if supervisor is not None: + requested_loaded_plugin_ids.add(plugin_id) + else: + supervisor = self._find_supervisor_by_plugin_directory(plugin_id) + + if supervisor is None: + missing_plugin_ids.append(plugin_id) + continue + + if plugin_id not in supervisor_roots.setdefault(supervisor, []): + supervisor_roots[supervisor].append(plugin_id) + + if missing_plugin_ids: + logger.warning(f"以下插件未找到可重载的 Supervisor,已跳过: {', '.join(sorted(missing_plugin_ids))}") + + self._warn_skipped_cross_supervisor_reload( + requested_loaded_plugin_ids=requested_loaded_plugin_ids, + dependency_map=dependency_map, + supervisor_by_plugin=supervisor_by_plugin, + ) + + success = True + for supervisor, root_plugin_ids in supervisor_roots.items(): + if not root_plugin_ids: + continue + + reloaded = await supervisor.reload_plugins( + plugin_ids=root_plugin_ids, + reason=reason, + external_available_plugins=self._build_external_available_plugins_for_supervisor(supervisor), + ) + success = success and reloaded + + return success and not missing_plugin_ids + async def notify_plugin_config_updated( self, plugin_id: str, @@ -465,6 +715,31 @@ class PluginRuntimeManager( raise RuntimeError(f"插件 {plugin_id} 同时存在于多个 Supervisor 中,无法安全路由") return matches[0] if matches else None + async def load_plugin_globally(self, plugin_id: str, reason: str = "manual") -> bool: + """加载或重载单个插件,并为其补齐跨 Supervisor 外部依赖。""" + + normalized_plugin_id = str(plugin_id or "").strip() + if not normalized_plugin_id: + return False + + try: + registered_supervisor = self._get_supervisor_for_plugin(normalized_plugin_id) + except RuntimeError: + return False + + if registered_supervisor is not None: + return await self.reload_plugins_globally([normalized_plugin_id], reason=reason) + + supervisor = self._find_supervisor_by_plugin_directory(normalized_plugin_id) + if supervisor is None: + return False + + return await supervisor.reload_plugins( + plugin_ids=[normalized_plugin_id], + reason=reason, + external_available_plugins=self._build_external_available_plugins_for_supervisor(supervisor), + ) + @staticmethod def _find_duplicate_plugin_ids(plugin_dirs: List[Path]) -> Dict[str, List[Path]]: """扫描插件目录,找出被多个目录重复声明的插件 ID。""" @@ -729,7 +1004,7 @@ class PluginRuntimeManager( logger.error(f"检测到重复插件 ID,跳过本次插件热重载: {details}") return - reload_supervisors: Dict[Any, List[str]] = {} + changed_plugin_ids: List[str] = [] changed_paths = [change.path.resolve() for change in changes] for supervisor in self.supervisors: @@ -738,14 +1013,11 @@ class PluginRuntimeManager( if plugin_id is None: continue if path.name in {"plugin.py", "_manifest.json"} or path.suffix == ".py": - reload_supervisors.setdefault(supervisor, []) - if plugin_id not in reload_supervisors[supervisor]: - reload_supervisors[supervisor].append(plugin_id) + if plugin_id not in changed_plugin_ids: + changed_plugin_ids.append(plugin_id) - for supervisor, plugin_ids in reload_supervisors.items(): - await supervisor.reload_plugins(plugin_ids=plugin_ids, reason="file_watcher") - - if reload_supervisors: + if changed_plugin_ids: + await self.reload_plugins_globally(changed_plugin_ids, reason="file_watcher") self._refresh_plugin_config_watch_subscriptions() @staticmethod diff --git a/src/plugin_runtime/protocol/envelope.py b/src/plugin_runtime/protocol/envelope.py index 6078e4dc..ce40d855 100644 --- a/src/plugin_runtime/protocol/envelope.py +++ b/src/plugin_runtime/protocol/envelope.py @@ -166,6 +166,8 @@ class RegisterPluginPayload(BaseModel): """组件列表""" capabilities_required: List[str] = Field(default_factory=list, description="所需能力列表") """所需能力列表""" + dependencies: List[str] = Field(default_factory=list, description="插件级依赖插件 ID 列表") + """插件级依赖插件 ID 列表""" config_reload_subscriptions: List[str] = Field(default_factory=list, description="订阅的全局配置热重载范围") """订阅的全局配置热重载范围""" @@ -280,6 +282,8 @@ class ReloadPluginPayload(BaseModel): """目标插件 ID""" reason: str = Field(default="manual", description="重载原因") """重载原因""" + external_available_plugins: List[str] = Field(default_factory=list, description="可视为已满足的外部依赖插件 ID") + """可视为已满足的外部依赖插件 ID""" class ReloadPluginResultPayload(BaseModel): diff --git a/src/plugin_runtime/runner/plugin_loader.py b/src/plugin_runtime/runner/plugin_loader.py index a766eb04..f07eb593 100644 --- a/src/plugin_runtime/runner/plugin_loader.py +++ b/src/plugin_runtime/runner/plugin_loader.py @@ -95,11 +95,16 @@ class PluginLoader: self._manifest_validator = ManifestValidator(host_version=host_version) self._compat_hook_installed = False - def discover_and_load(self, plugin_dirs: List[str]) -> List[PluginMeta]: + def discover_and_load( + self, + plugin_dirs: List[str], + extra_available: Optional[Set[str]] = None, + ) -> List[PluginMeta]: """扫描多个目录并加载所有插件。 Args: plugin_dirs: 插件目录列表。 + extra_available: 额外视为已满足的外部依赖插件 ID 集合。 Returns: List[PluginMeta]: 成功加载的插件元数据列表,按依赖顺序排列。 @@ -108,7 +113,7 @@ class PluginLoader: self._record_duplicate_candidates(duplicate_candidates) # 第二阶段:依赖解析(拓扑排序) - load_order, failed_deps = self._resolve_dependencies(candidates) + load_order, failed_deps = self._resolve_dependencies(candidates, extra_available=extra_available) self._record_failed_dependencies(failed_deps) # 第三阶段:按依赖顺序加载 diff --git a/src/plugin_runtime/runner/runner_main.py b/src/plugin_runtime/runner/runner_main.py index b38946d6..c0f5e771 100644 --- a/src/plugin_runtime/runner/runner_main.py +++ b/src/plugin_runtime/runner/runner_main.py @@ -15,6 +15,7 @@ from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, ca import asyncio import contextlib import inspect +import json import logging as stdlib_logging import os import signal @@ -23,7 +24,13 @@ import time import tomllib from src.common.logger import get_console_handler, get_logger, initialize_logging -from src.plugin_runtime import ENV_HOST_VERSION, ENV_IPC_ADDRESS, ENV_PLUGIN_DIRS, ENV_SESSION_TOKEN +from src.plugin_runtime import ( + ENV_EXTERNAL_PLUGIN_IDS, + ENV_HOST_VERSION, + ENV_IPC_ADDRESS, + ENV_PLUGIN_DIRS, + ENV_SESSION_TOKEN, +) from src.plugin_runtime.protocol.envelope import ( BootstrapPluginPayload, ComponentDeclaration, @@ -112,6 +119,7 @@ class PluginRunner: host_address: str, session_token: str, plugin_dirs: List[str], + external_available_plugin_ids: Optional[List[str]] = None, ) -> None: """初始化 Runner。 @@ -119,10 +127,16 @@ class PluginRunner: host_address: Host 的 IPC 地址。 session_token: 握手用会话令牌。 plugin_dirs: 当前 Runner 负责扫描的插件目录列表。 + external_available_plugin_ids: 视为已满足的外部依赖插件 ID 列表。 """ self._host_address: str = host_address self._session_token: str = session_token self._plugin_dirs: List[str] = plugin_dirs + self._external_available_plugin_ids: Set[str] = { + str(plugin_id or "").strip() + for plugin_id in (external_available_plugin_ids or []) + if str(plugin_id or "").strip() + } self._rpc_client: RPCClient = RPCClient(host_address, session_token) self._loader: PluginLoader = PluginLoader(host_version=os.getenv(ENV_HOST_VERSION, "")) @@ -150,7 +164,10 @@ class PluginRunner: self._register_handlers() # 3. 加载插件 - plugins = self._loader.discover_and_load(self._plugin_dirs) + plugins = self._loader.discover_and_load( + self._plugin_dirs, + extra_available=self._external_available_plugin_ids, + ) logger.info(f"已加载 {len(plugins)} 个插件") # 4. 注入 PluginContext + 调用 on_load 生命周期钩子 @@ -379,6 +396,7 @@ class PluginRunner: plugin_version=meta.version, components=components, capabilities_required=meta.capabilities_required, + dependencies=meta.dependencies, config_reload_subscriptions=config_reload_subscriptions, ) @@ -485,18 +503,20 @@ class PluginRunner: self._loader.set_loaded_plugin(meta) return True - async def _unload_plugin(self, meta: PluginMeta, reason: str) -> None: + async def _unload_plugin(self, meta: PluginMeta, reason: str, *, purge_modules: bool = True) -> None: """卸载单个插件并清理 Host/Runner 两侧状态。 Args: meta: 待卸载的插件元数据。 reason: 卸载原因。 + purge_modules: 是否在卸载完成后清理插件模块缓存。 """ await self._invoke_plugin_on_unload(meta) await self._unregister_plugin(meta.plugin_id, reason) await self._deactivate_plugin(meta) self._loader.remove_loaded_plugin(meta.plugin_id) - self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir) + if purge_modules: + self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir) def _collect_reverse_dependents(self, plugin_id: str) -> Set[str]: """收集依赖指定插件的所有已加载插件。 @@ -564,18 +584,52 @@ class PluginRunner: return list(reversed(load_order)) - async def _reload_plugin_by_id(self, plugin_id: str, reason: str) -> ReloadPluginResultPayload: + @staticmethod + def _finalize_failed_reload_messages( + failed_plugins: Dict[str, str], + rollback_failures: Dict[str, str], + ) -> Dict[str, str]: + """在重载失败后补充回滚结果说明。""" + + finalized_failures: Dict[str, str] = {} + for failed_plugin_id, failure_reason in failed_plugins.items(): + rollback_failure = rollback_failures.get(failed_plugin_id) + if rollback_failure: + finalized_failures[failed_plugin_id] = ( + f"{failure_reason};且旧版本恢复失败: {rollback_failure}" + ) + else: + finalized_failures[failed_plugin_id] = f"{failure_reason}(已恢复旧版本)" + + for failed_plugin_id, rollback_failure in rollback_failures.items(): + if failed_plugin_id not in finalized_failures: + finalized_failures[failed_plugin_id] = f"旧版本恢复失败: {rollback_failure}" + + return finalized_failures + + async def _reload_plugin_by_id( + self, + plugin_id: str, + reason: str, + external_available_plugins: Optional[Set[str]] = None, + ) -> ReloadPluginResultPayload: """按插件 ID 在 Runner 进程内执行精确重载。 Args: plugin_id: 目标插件 ID。 reason: 重载原因。 + external_available_plugins: 视为已满足的外部依赖插件 ID 集合。 Returns: ReloadPluginResultPayload: 结构化重载结果。 """ candidates, duplicate_candidates = self._loader.discover_candidates(self._plugin_dirs) failed_plugins: Dict[str, str] = {} + normalized_external_available = { + str(candidate_plugin_id or "").strip() + for candidate_plugin_id in (external_available_plugins or set()) + if str(candidate_plugin_id or "").strip() + } if plugin_id in duplicate_candidates: conflict_paths = ", ".join(str(path) for path in duplicate_candidates[plugin_id]) @@ -603,29 +657,32 @@ class PluginRunner: unload_order = self._build_unload_order(target_plugin_ids & loaded_plugin_ids) unloaded_plugins: List[str] = [] retained_plugin_ids = loaded_plugin_ids - set(unload_order) + rollback_metas: Dict[str, PluginMeta] = {} for unload_plugin_id in unload_order: meta = self._loader.get_plugin(unload_plugin_id) if meta is None: continue - await self._unload_plugin(meta, reason=reason) + rollback_metas[unload_plugin_id] = meta + await self._unload_plugin(meta, reason=reason, purge_modules=False) + self._loader.purge_plugin_modules(unload_plugin_id, meta.plugin_dir) unloaded_plugins.append(unload_plugin_id) reload_candidates: Dict[str, Tuple[Path, Dict[str, Any], Path]] = {} for target_plugin_id in target_plugin_ids: candidate = candidates.get(target_plugin_id) if candidate is None: - failed_plugins[target_plugin_id] = "插件目录已不存在,已保持卸载状态" + failed_plugins[target_plugin_id] = "插件目录已不存在" continue reload_candidates[target_plugin_id] = candidate load_order, dependency_failures = self._loader.resolve_dependencies( reload_candidates, - extra_available=retained_plugin_ids, + extra_available=retained_plugin_ids | normalized_external_available, ) failed_plugins.update(dependency_failures) - available_plugins = set(retained_plugin_ids) + available_plugins = set(retained_plugin_ids) | normalized_external_available reloaded_plugins: List[str] = [] for load_plugin_id in load_order: @@ -656,7 +713,48 @@ class PluginRunner: available_plugins.add(load_plugin_id) reloaded_plugins.append(load_plugin_id) - requested_plugin_success = plugin_id in reloaded_plugins and not failed_plugins + if failed_plugins: + rollback_failures: Dict[str, str] = {} + + for reloaded_plugin_id in reversed(reloaded_plugins): + reloaded_meta = self._loader.get_plugin(reloaded_plugin_id) + if reloaded_meta is None: + continue + + try: + await self._unload_plugin( + reloaded_meta, + reason=f"{reason}_rollback_cleanup", + purge_modules=False, + ) + except Exception as exc: + rollback_failures[reloaded_plugin_id] = f"清理失败: {exc}" + finally: + self._loader.purge_plugin_modules(reloaded_plugin_id, reloaded_meta.plugin_dir) + + for rollback_plugin_id in reversed(unload_order): + rollback_meta = rollback_metas.get(rollback_plugin_id) + if rollback_meta is None: + continue + + try: + restored = await self._activate_plugin(rollback_meta) + except Exception as exc: + rollback_failures[rollback_plugin_id] = str(exc) + continue + + if not restored: + rollback_failures[rollback_plugin_id] = "无法重新激活旧版本" + + return ReloadPluginResultPayload( + success=False, + requested_plugin_id=plugin_id, + reloaded_plugins=[], + unloaded_plugins=unloaded_plugins, + failed_plugins=self._finalize_failed_reload_messages(failed_plugins, rollback_failures), + ) + + requested_plugin_success = plugin_id in reloaded_plugins return ReloadPluginResultPayload( success=requested_plugin_success, @@ -978,7 +1076,11 @@ class PluginRunner: ) async with self._reload_lock: - result = await self._reload_plugin_by_id(payload.plugin_id, payload.reason) + result = await self._reload_plugin_by_id( + payload.plugin_id, + payload.reason, + external_available_plugins=set(payload.external_available_plugins), + ) return envelope.make_response(payload=result.model_dump()) def request_capability(self) -> RPCClient: @@ -1073,6 +1175,7 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None: async def _async_main() -> None: """异步主入口""" host_address = os.environ.get(ENV_IPC_ADDRESS, "") + external_plugin_ids_raw = os.environ.get(ENV_EXTERNAL_PLUGIN_IDS, "") session_token = os.environ.get(ENV_SESSION_TOKEN, "") plugin_dirs_str = os.environ.get(ENV_PLUGIN_DIRS, "") @@ -1081,11 +1184,24 @@ async def _async_main() -> None: sys.exit(1) plugin_dirs = [d for d in plugin_dirs_str.split(os.pathsep) if d] + try: + external_plugin_ids = json.loads(external_plugin_ids_raw) if external_plugin_ids_raw else [] + except json.JSONDecodeError: + logger.warning("解析外部依赖插件列表失败,已回退为空列表") + external_plugin_ids = [] + if not isinstance(external_plugin_ids, list): + logger.warning("外部依赖插件列表格式非法,已回退为空列表") + external_plugin_ids = [] # sys.path 隔离: 只保留标准库、SDK 包、插件目录 _isolate_sys_path(plugin_dirs) - runner = PluginRunner(host_address, session_token, plugin_dirs) + runner = PluginRunner( + host_address, + session_token, + plugin_dirs, + external_available_plugin_ids=[str(plugin_id) for plugin_id in external_plugin_ids], + ) # 注册信号处理 def _mark_runner_shutting_down() -> None: From 1f02171a635e555b62bff2480412c6a4c44a3ce5 Mon Sep 17 00:00:00 2001 From: DrSmoothl <1787882683@qq.com> Date: Mon, 23 Mar 2026 22:59:01 +0800 Subject: [PATCH 37/45] Refactor plugin loader and runner to support enhanced manifest structure - Updated the PluginMeta class to utilize a strongly typed PluginManifest, improving type safety and clarity. - Refactored dependency extraction logic to streamline the handling of plugin dependencies. - Modified the PluginLoader to accommodate new manifest versioning and validation processes. - Enhanced the PluginRunner to work with a dictionary for external available plugins, allowing for version mapping. - Updated built-in plugins' manifest files to version 2, adding URLs and SDK versioning for better integration and documentation. - Improved error handling and logging for plugin loading and dependency resolution processes. --- plugins/ChatFrequency/_manifest.json | 76 +- plugins/MaiBot_MCPBridgePlugin/_manifest.json | 83 +- plugins/emoji_manage_plugin/_manifest.json | 88 +- plugins/hello_world_plugin/_manifest.json | 107 +- pytests/test_plugin_runtime.py | 387 ++++-- src/plugin_runtime/__init__.py | 2 +- src/plugin_runtime/capabilities/components.py | 4 +- src/plugin_runtime/host/supervisor.py | 41 +- src/plugin_runtime/integration.py | 114 +- src/plugin_runtime/protocol/envelope.py | 7 +- .../runner/manifest_validator.py | 1113 +++++++++++++++-- src/plugin_runtime/runner/plugin_loader.py | 166 ++- src/plugin_runtime/runner/runner_main.py | 79 +- .../built_in/emoji_plugin/_manifest.json | 43 +- .../built_in/plugin_management/_manifest.json | 77 +- 15 files changed, 1676 insertions(+), 711 deletions(-) diff --git a/plugins/ChatFrequency/_manifest.json b/plugins/ChatFrequency/_manifest.json index 241242ed..56417665 100644 --- a/plugins/ChatFrequency/_manifest.json +++ b/plugins/ChatFrequency/_manifest.json @@ -1,58 +1,40 @@ { - "manifest_version": 1, - "name": "发言频率控制插件|BetterFrequency Plugin", + "manifest_version": 2, "version": "2.0.0", - "description": "控制聊天频率,支持设置focus_value和talk_frequency调整值,提供命令", + "name": "发言频率控制插件|BetterFrequency Plugin", + "description": "控制聊天频率,支持设置 focus_value 和 talk_frequency 调整值,并提供命令入口。", "author": { "name": "SengokuCola", "url": "https://github.com/MaiM-with-u" }, "license": "GPL-v3.0-or-later", - "host_application": { - "min_version": "1.0.0" + "urls": { + "repository": "https://github.com/SengokuCola/BetterFrequency", + "homepage": "https://github.com/SengokuCola/BetterFrequency", + "documentation": "https://github.com/SengokuCola/BetterFrequency", + "issues": "https://github.com/SengokuCola/BetterFrequency/issues" }, - "homepage_url": "https://github.com/SengokuCola/BetterFrequency", - "repository_url": "https://github.com/SengokuCola/BetterFrequency", - "keywords": [ - "frequency", - "control", - "talk_frequency", - "plugin", - "shortcut" + "host_application": { + "min_version": "1.0.0", + "max_version": "1.0.0" + }, + "sdk": { + "min_version": "2.0.0", + "max_version": "2.99.99" + }, + "dependencies": [], + "capabilities": [ + "send.text", + "frequency.set_adjust", + "frequency.get_current_talk_value", + "frequency.get_adjust" ], - "categories": [ - "Chat", - "Frequency", - "Control" - ], - "default_locale": "zh-CN", - "locales_path": "_locales", - "plugin_info": { - "is_built_in": false, - "plugin_type": "frequency", - "components": [ - { - "type": "command", - "name": "set_talk_frequency", - "description": "设置当前聊天的talk_frequency调整值", - "pattern": "/chat talk_frequency <数字> 或 /chat t <数字>" - }, - { - "type": "command", - "name": "show_frequency", - "description": "显示当前聊天的频率控制状态", - "pattern": "/chat show 或 /chat s" - } - ], - "features": [ - "设置talk_frequency调整值", - "调整当前聊天的发言频率", - "显示当前频率控制状态", - "实时频率控制调整", - "命令执行反馈(不保存消息)", - "支持完整命令和简化命令", - "快速操作支持" + "i18n": { + "default_locale": "zh-CN", + "locales_path": "_locales", + "supported_locales": [ + "zh-CN" ] }, - "id": "SengokuCola.BetterFrequency" -} \ No newline at end of file + "id": "sengokucola.betterfrequency" +} diff --git a/plugins/MaiBot_MCPBridgePlugin/_manifest.json b/plugins/MaiBot_MCPBridgePlugin/_manifest.json index 85225a43..d2e08ab4 100644 --- a/plugins/MaiBot_MCPBridgePlugin/_manifest.json +++ b/plugins/MaiBot_MCPBridgePlugin/_manifest.json @@ -1,67 +1,42 @@ { - "manifest_version": 1, - "name": "MCP桥接插件", + "manifest_version": 2, "version": "2.0.0", - "description": "将 MCP (Model Context Protocol) 服务器的工具桥接到 MaiBot,使麦麦能够调用外部 MCP 工具", + "name": "MCP桥接插件", + "description": "将 MCP (Model Context Protocol) 服务器的工具桥接到 MaiBot,使麦麦能够调用外部 MCP 工具。", "author": { "name": "CharTyr", "url": "https://github.com/CharTyr" }, "license": "AGPL-3.0", - "host_application": { - "min_version": "0.11.6" + "urls": { + "repository": "https://github.com/CharTyr/MaiBot_MCPBridgePlugin", + "homepage": "https://github.com/CharTyr/MaiBot_MCPBridgePlugin", + "documentation": "https://github.com/CharTyr/MaiBot_MCPBridgePlugin", + "issues": "https://github.com/CharTyr/MaiBot_MCPBridgePlugin/issues" }, - "homepage_url": "https://github.com/CharTyr/MaiBot_MCPBridgePlugin", - "repository_url": "https://github.com/CharTyr/MaiBot_MCPBridgePlugin", - "keywords": [ - "mcp", - "bridge", - "tool", - "integration", - "resources", - "prompts", - "post-process", - "cache", - "trace", - "permissions", - "import", - "export", - "claude-desktop", - "workflow", - "react", - "agent" + "host_application": { + "min_version": "0.11.6", + "max_version": "1.0.0" + }, + "sdk": { + "min_version": "2.0.0", + "max_version": "2.99.99" + }, + "dependencies": [ + { + "type": "python_package", + "name": "mcp", + "version_spec": ">=0.0.0" + } ], - "categories": [ - "工具扩展", - "外部集成" + "capabilities": [ + "send.text" ], - "default_locale": "zh-CN", - "plugin_info": { - "is_built_in": false, - "components": [], - "features": [ - "支持多个 MCP 服务器", - "自动发现并注册 MCP 工具", - "支持 stdio、SSE、HTTP、Streamable HTTP 四种传输方式", - "工具参数自动转换", - "心跳检测与自动重连", - "调用统计(次数、成功率、耗时)", - "WebUI 配置支持", - "Resources 支持(实验性)", - "Prompts 支持(实验性)", - "结果后处理(LLM 摘要提炼)", - "工具禁用管理", - "调用链路追踪", - "工具调用缓存(LRU)", - "工具权限控制(群/用户级别)", - "配置导入导出(Claude Desktop mcpServers)", - "断路器模式(故障快速失败)", - "状态实时刷新", - "Workflow 硬流程(顺序执行多个工具)", - "Workflow 快速添加(表单式配置)", - "ReAct 软流程(LLM 自主多轮调用)", - "双轨制架构(软流程 + 硬流程)" + "i18n": { + "default_locale": "zh-CN", + "supported_locales": [ + "zh-CN" ] }, - "id": "MaiBot Community.MCPBridgePlugin" + "id": "chartyr.mcpbridge-plugin" } diff --git a/plugins/emoji_manage_plugin/_manifest.json b/plugins/emoji_manage_plugin/_manifest.json index 3af69023..998cb7da 100644 --- a/plugins/emoji_manage_plugin/_manifest.json +++ b/plugins/emoji_manage_plugin/_manifest.json @@ -1,68 +1,44 @@ { - "manifest_version": 1, - "name": "BetterEmoji", + "manifest_version": 2, "version": "2.0.0", + "name": "BetterEmoji", "description": "更好的表情包管理插件", "author": { "name": "SengokuCola", "url": "https://github.com/SengokuCola" }, "license": "GPL-v3.0-or-later", - "host_application": { - "min_version": "1.0.0" + "urls": { + "repository": "https://github.com/SengokuCola/BetterEmoji", + "homepage": "https://github.com/SengokuCola/BetterEmoji", + "documentation": "https://github.com/SengokuCola/BetterEmoji", + "issues": "https://github.com/SengokuCola/BetterEmoji/issues" }, - "homepage_url": "https://github.com/SengokuCola/BetterEmoji", - "repository_url": "https://github.com/SengokuCola/BetterEmoji", - "keywords": [ - "emoji", - "manage", - "plugin" + "host_application": { + "min_version": "1.0.0", + "max_version": "1.0.0" + }, + "sdk": { + "min_version": "2.0.0", + "max_version": "2.99.99" + }, + "dependencies": [], + "capabilities": [ + "emoji.get_random", + "emoji.get_count", + "emoji.get_info", + "emoji.get_all", + "emoji.register_emoji", + "emoji.delete_emoji", + "send.text", + "send.forward" ], - "categories": [ - "Emoji", - "Management" - ], - "default_locale": "zh-CN", - "locales_path": "_locales", - "plugin_info": { - "is_built_in": false, - "plugin_type": "emoji_manage", - "capabilities": [ - "emoji.get_random", - "emoji.get_count", - "emoji.get_info", - "emoji.get_all", - "emoji.register_emoji", - "emoji.delete_emoji", - "send.text", - "send.forward" - ], - "components": [ - { - "type": "command", - "name": "add_emoji", - "description": "添加表情包", - "pattern": "/emoji add" - }, - { - "type": "command", - "name": "emoji_list", - "description": "列表表情包", - "pattern": "/emoji list" - }, - { - "type": "command", - "name": "delete_emoji", - "description": "删除表情包", - "pattern": "/emoji delete" - }, - { - "type": "command", - "name": "random_emojis", - "description": "发送多张随机表情包", - "pattern": "/random_emojis" - } + "i18n": { + "default_locale": "zh-CN", + "locales_path": "_locales", + "supported_locales": [ + "zh-CN" ] }, - "id": "SengokuCola.BetterEmoji" -} \ No newline at end of file + "id": "sengokucola.betteremoji" +} diff --git a/plugins/hello_world_plugin/_manifest.json b/plugins/hello_world_plugin/_manifest.json index dc9fc474..e2bc694d 100644 --- a/plugins/hello_world_plugin/_manifest.json +++ b/plugins/hello_world_plugin/_manifest.json @@ -1,88 +1,41 @@ { - "manifest_version": 1, - "name": "Hello World 示例插件 (Hello World Plugin)", + "manifest_version": 2, "version": "2.0.0", - "description": "我的第一个MaiCore插件,包含问候功能和时间查询等基础示例", + "name": "Hello World 示例插件 (Hello World Plugin)", + "description": "我的第一个 MaiCore 插件,包含问候功能和时间查询等基础示例", "author": { "name": "MaiBot开发团队", "url": "https://github.com/MaiM-with-u" }, "license": "GPL-v3.0-or-later", - "host_application": { - "min_version": "1.0.0" + "urls": { + "repository": "https://github.com/MaiM-with-u/maibot", + "homepage": "https://github.com/MaiM-with-u/maibot", + "documentation": "https://github.com/MaiM-with-u/maibot", + "issues": "https://github.com/MaiM-with-u/maibot/issues" }, - "homepage_url": "https://github.com/MaiM-with-u/maibot", - "repository_url": "https://github.com/MaiM-with-u/maibot", - "keywords": [ - "demo", - "example", - "hello", - "greeting", - "tutorial" + "host_application": { + "min_version": "1.0.0", + "max_version": "1.0.0" + }, + "sdk": { + "min_version": "2.0.0", + "max_version": "2.99.99" + }, + "dependencies": [], + "capabilities": [ + "send.text", + "send.forward", + "send.hybrid", + "emoji.get_random", + "config.get" ], - "categories": [ - "Examples", - "Tutorial" - ], - "default_locale": "zh-CN", - "locales_path": "_locales", - "plugin_info": { - "is_built_in": false, - "plugin_type": "example", - "capabilities": [ - "send.text", - "send.forward", - "send.hybrid", - "emoji.get_random", - "config.get" - ], - "components": [ - { - "type": "tool", - "name": "compare_numbers", - "description": "比较两个数的大小" - }, - { - "type": "action", - "name": "hello_greeting", - "description": "向用户发送问候消息" - }, - { - "type": "action", - "name": "bye_greeting", - "description": "向用户发送告别消息", - "activation_modes": ["keyword"], - "keywords": ["再见", "bye", "88", "拜拜"] - }, - { - "type": "command", - "name": "time", - "description": "查询当前时间", - "pattern": "/time" - }, - { - "type": "command", - "name": "random_emojis", - "description": "发送多张随机表情包", - "pattern": "/random_emojis" - }, - { - "type": "command", - "name": "test", - "description": "测试命令", - "pattern": "/test" - }, - { - "type": "event_handler", - "name": "print_message_handler", - "description": "打印接收到的消息" - }, - { - "type": "event_handler", - "name": "forward_messages_handler", - "description": "把接收到的消息转发到指定聊天ID" - } + "i18n": { + "default_locale": "zh-CN", + "locales_path": "_locales", + "supported_locales": [ + "zh-CN" ] }, - "id": "MaiBot开发团队.maibot" -} \ No newline at end of file + "id": "maibot-team.hello-world-plugin" +} diff --git a/pytests/test_plugin_runtime.py b/pytests/test_plugin_runtime.py index e094d85b..1d93ae24 100644 --- a/pytests/test_plugin_runtime.py +++ b/pytests/test_plugin_runtime.py @@ -19,6 +19,104 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "packages", "maibot-plugin-sdk")) +def build_test_manifest( + plugin_id: str, + *, + version: str = "1.0.0", + name: str = "测试插件", + description: str = "测试插件描述", + dependencies: list[dict[str, str]] | None = None, + capabilities: list[str] | None = None, + host_min_version: str = "0.12.0", + host_max_version: str = "1.0.0", + sdk_min_version: str = "2.0.0", + sdk_max_version: str = "2.99.99", +) -> dict[str, object]: + """构造一个合法的 Manifest v2 测试样例。 + + Args: + plugin_id: 插件 ID。 + version: 插件版本。 + name: 展示名称。 + description: 插件描述。 + dependencies: 依赖声明列表。 + capabilities: 能力声明列表。 + host_min_version: Host 最低支持版本。 + host_max_version: Host 最高支持版本。 + sdk_min_version: SDK 最低支持版本。 + sdk_max_version: SDK 最高支持版本。 + + Returns: + dict[str, object]: 可直接序列化为 ``_manifest.json`` 的字典。 + """ + return { + "manifest_version": 2, + "version": version, + "name": name, + "description": description, + "author": { + "name": "tester", + "url": "https://example.com/tester", + }, + "license": "MIT", + "urls": { + "repository": f"https://example.com/{plugin_id}", + }, + "host_application": { + "min_version": host_min_version, + "max_version": host_max_version, + }, + "sdk": { + "min_version": sdk_min_version, + "max_version": sdk_max_version, + }, + "dependencies": dependencies or [], + "capabilities": capabilities or [], + "i18n": { + "default_locale": "zh-CN", + "supported_locales": ["zh-CN"], + }, + "id": plugin_id, + } + + +def build_test_manifest_model( + plugin_id: str, + *, + version: str = "1.0.0", + dependencies: list[dict[str, str]] | None = None, + capabilities: list[str] | None = None, + host_version: str = "1.0.0", + sdk_version: str = "2.0.1", +) -> object: + """构造一个已经通过校验的强类型 Manifest 测试对象。 + + Args: + plugin_id: 插件 ID。 + version: 插件版本。 + dependencies: 依赖声明列表。 + capabilities: 能力声明列表。 + host_version: 当前测试使用的 Host 版本。 + sdk_version: 当前测试使用的 SDK 版本。 + + Returns: + object: ``PluginManifest`` 实例。 + """ + from src.plugin_runtime.runner.manifest_validator import ManifestValidator + + validator = ManifestValidator(host_version=host_version, sdk_version=sdk_version) + manifest = validator.parse_manifest( + build_test_manifest( + plugin_id, + version=version, + dependencies=dependencies, + capabilities=capabilities, + ) + ) + assert manifest is not None + return manifest + + # ─── 协议层测试 ─────────────────────────────────────────── @@ -759,65 +857,77 @@ class TestManifestValidator: def test_valid_manifest(self): from src.plugin_runtime.runner.manifest_validator import ManifestValidator - validator = ManifestValidator() - manifest = { - "manifest_version": 1, - "name": "test_plugin", - "version": "1.0.0", - "description": "测试插件", - "author": "test", - } + validator = ManifestValidator(host_version="1.0.0", sdk_version="2.0.1") + manifest = build_test_manifest("test.valid-plugin", capabilities=["send.text"]) assert validator.validate(manifest) is True assert len(validator.errors) == 0 + assert validator.warnings == [] def test_missing_required_fields(self): from src.plugin_runtime.runner.manifest_validator import ManifestValidator - validator = ManifestValidator() - manifest = {"manifest_version": 1} + validator = ManifestValidator(host_version="1.0.0", sdk_version="2.0.1") + manifest = {"manifest_version": 2} assert validator.validate(manifest) is False - assert len(validator.errors) >= 4 # name, version, description, author + assert len(validator.errors) >= 6 + assert any("缺少必需字段" in error for error in validator.errors) def test_unsupported_manifest_version(self): from src.plugin_runtime.runner.manifest_validator import ManifestValidator - validator = ManifestValidator() - manifest = { - "manifest_version": 999, - "name": "test", - "version": "1.0", - "description": "d", - "author": "a", - } + validator = ManifestValidator(host_version="1.0.0", sdk_version="2.0.1") + manifest = build_test_manifest("test.invalid-version") + manifest["manifest_version"] = 999 assert validator.validate(manifest) is False assert any("manifest_version" in e for e in validator.errors) def test_host_version_compatibility(self): from src.plugin_runtime.runner.manifest_validator import ManifestValidator - validator = ManifestValidator(host_version="0.8.5") - manifest = { - "name": "test", - "version": "1.0", - "description": "d", - "author": "a", - "host_application": {"min_version": "0.9.0"}, - } + validator = ManifestValidator(host_version="0.8.5", sdk_version="2.0.1") + manifest = build_test_manifest( + "test.host-check", + host_min_version="0.9.0", + host_max_version="1.0.0", + ) assert validator.validate(manifest) is False assert any("Host 版本不兼容" in e for e in validator.errors) - def test_recommended_fields_warning(self): + def test_sdk_version_compatibility(self): from src.plugin_runtime.runner.manifest_validator import ManifestValidator - validator = ManifestValidator() - manifest = { - "name": "test", - "version": "1.0", - "description": "d", - "author": "a", - } - validator.validate(manifest) - assert len(validator.warnings) >= 3 # license, keywords, categories + validator = ManifestValidator(host_version="1.0.0", sdk_version="1.9.9") + manifest = build_test_manifest("test.sdk-check") + assert validator.validate(manifest) is False + assert any("SDK 版本不兼容" in e for e in validator.errors) + + def test_extra_fields_are_rejected(self): + from src.plugin_runtime.runner.manifest_validator import ManifestValidator + + validator = ManifestValidator(host_version="1.0.0", sdk_version="2.0.1") + manifest = build_test_manifest("test.extra-field") + manifest["unexpected"] = True + + assert validator.validate(manifest) is False + assert any("存在未声明字段" in error for error in validator.errors) + + def test_python_package_conflict_rejects_manifest(self): + from src.plugin_runtime.runner.manifest_validator import ManifestValidator + + validator = ManifestValidator(host_version="1.0.0", sdk_version="2.0.1") + manifest = build_test_manifest( + "test.numpy-conflict", + dependencies=[ + { + "type": "python_package", + "name": "numpy", + "version_spec": ">=999.0.0", + } + ], + ) + + assert validator.validate(manifest) is False + assert any("Python 包依赖冲突" in error for error in validator.errors) class TestVersionComparator: @@ -859,59 +969,83 @@ class TestDependencyResolution: loader = PluginLoader() candidates = { - "core": ("dir_core", {"name": "core", "version": "1.0", "description": "d", "author": "a"}, "plugin.py"), - "auth": ( - "dir_auth", - {"name": "auth", "version": "1.0", "description": "d", "author": "a", "dependencies": ["core"]}, + "test.core": ( + "dir_core", + build_test_manifest_model("test.core"), "plugin.py", ), - "api": ( + "test.auth": ( + "dir_auth", + build_test_manifest_model( + "test.auth", + dependencies=[ + {"type": "plugin", "id": "test.core", "version_spec": ">=1.0.0,<2.0.0"}, + ], + ), + "plugin.py", + ), + "test.api": ( "dir_api", - {"name": "api", "version": "1.0", "description": "d", "author": "a", "dependencies": ["core", "auth"]}, + build_test_manifest_model( + "test.api", + dependencies=[ + {"type": "plugin", "id": "test.core", "version_spec": ">=1.0.0,<2.0.0"}, + {"type": "plugin", "id": "test.auth", "version_spec": ">=1.0.0,<2.0.0"}, + ], + ), "plugin.py", ), } order, failed = loader._resolve_dependencies(candidates) assert len(failed) == 0 - assert order.index("core") < order.index("auth") - assert order.index("auth") < order.index("api") + assert order.index("test.core") < order.index("test.auth") + assert order.index("test.auth") < order.index("test.api") def test_missing_dependency(self): from src.plugin_runtime.runner.plugin_loader import PluginLoader loader = PluginLoader() candidates = { - "plugin_a": ( + "test.plugin-a": ( "dir_a", - { - "name": "plugin_a", - "version": "1.0", - "description": "d", - "author": "a", - "dependencies": ["nonexistent"], - }, + build_test_manifest_model( + "test.plugin-a", + dependencies=[ + {"type": "plugin", "id": "test.nonexistent", "version_spec": ">=1.0.0,<2.0.0"}, + ], + ), "plugin.py", ), } order, failed = loader._resolve_dependencies(candidates) - assert "plugin_a" in failed - assert "缺少依赖" in failed["plugin_a"] + assert "test.plugin-a" in failed + assert "依赖未满足" in failed["test.plugin-a"] def test_circular_dependency(self): from src.plugin_runtime.runner.plugin_loader import PluginLoader loader = PluginLoader() candidates = { - "a": ( + "test.a": ( "dir_a", - {"name": "a", "version": "1.0", "description": "d", "author": "x", "dependencies": ["b"]}, + build_test_manifest_model( + "test.a", + dependencies=[ + {"type": "plugin", "id": "test.b", "version_spec": ">=1.0.0,<2.0.0"}, + ], + ), "p.py", ), - "b": ( + "test.b": ( "dir_b", - {"name": "b", "version": "1.0", "description": "d", "author": "x", "dependencies": ["a"]}, + build_test_manifest_model( + "test.b", + dependencies=[ + {"type": "plugin", "id": "test.a", "version_spec": ">=1.0.0,<2.0.0"}, + ], + ), "p.py", ), } @@ -929,12 +1063,11 @@ class TestDependencyResolution: (plugin_dir / "_manifest.json").write_text( json.dumps( - { - "name": "grok_search_plugin", - "version": "1.0.0", - "description": "demo", - "author": "tester", - } + build_test_manifest( + "test.grok-search-plugin", + name="grok_search_plugin", + description="demo", + ) ), encoding="utf-8", ) @@ -954,7 +1087,7 @@ class TestDependencyResolution: loader = PluginLoader() loaded = loader.discover_and_load([str(plugin_root)]) - assert [meta.plugin_id for meta in loaded] == ["grok_search_plugin"] + assert [meta.plugin_id for meta in loaded] == ["test.grok-search-plugin"] assert loader.failed_plugins == {} assert loaded[0].instance.answer() == 42 @@ -968,12 +1101,11 @@ class TestDependencyResolution: (plugin_dir / "_manifest.json").write_text( json.dumps( - { - "name": "demo_plugin", - "version": "1.0.0", - "description": "demo", - "author": "tester", - } + build_test_manifest( + "test.demo-plugin", + name="demo_plugin", + description="demo", + ) ), encoding="utf-8", ) @@ -993,8 +1125,8 @@ class TestDependencyResolution: loaded = loader.discover_and_load([str(plugin_root)]) assert loaded == [] - assert "demo_plugin" in loader.failed_plugins - assert "on_config_update" in loader.failed_plugins["demo_plugin"] + assert "test.demo-plugin" in loader.failed_plugins + assert "on_config_update" in loader.failed_plugins["test.demo-plugin"] def test_loader_requires_sdk_plugin_to_override_on_load(self, tmp_path): from src.plugin_runtime.runner.plugin_loader import PluginLoader @@ -1006,12 +1138,11 @@ class TestDependencyResolution: (plugin_dir / "_manifest.json").write_text( json.dumps( - { - "name": "demo_plugin", - "version": "1.0.0", - "description": "demo", - "author": "tester", - } + build_test_manifest( + "test.demo-plugin", + name="demo_plugin", + description="demo", + ) ), encoding="utf-8", ) @@ -1031,8 +1162,8 @@ class TestDependencyResolution: loaded = loader.discover_and_load([str(plugin_root)]) assert loaded == [] - assert "demo_plugin" in loader.failed_plugins - assert "on_load" in loader.failed_plugins["demo_plugin"] + assert "test.demo-plugin" in loader.failed_plugins + assert "on_load" in loader.failed_plugins["test.demo-plugin"] def test_loader_requires_sdk_plugin_to_override_on_unload(self, tmp_path): from src.plugin_runtime.runner.plugin_loader import PluginLoader @@ -1044,12 +1175,11 @@ class TestDependencyResolution: (plugin_dir / "_manifest.json").write_text( json.dumps( - { - "name": "demo_plugin", - "version": "1.0.0", - "description": "demo", - "author": "tester", - } + build_test_manifest( + "test.demo-plugin", + name="demo_plugin", + description="demo", + ) ), encoding="utf-8", ) @@ -1069,8 +1199,8 @@ class TestDependencyResolution: loaded = loader.discover_and_load([str(plugin_root)]) assert loaded == [] - assert "demo_plugin" in loader.failed_plugins - assert "on_unload" in loader.failed_plugins["demo_plugin"] + assert "test.demo-plugin" in loader.failed_plugins + assert "on_unload" in loader.failed_plugins["test.demo-plugin"] def test_isolate_sys_path_preserves_plugin_dirs(self): from src.plugin_runtime.runner import runner_main @@ -2374,16 +2504,19 @@ class TestIntegration: def __init__(self, plugin_dirs=None, socket_path=None): self._plugin_dirs = plugin_dirs or [] self.capability_service = FakeCapabilityService() - self.external_plugin_ids = [] + self.external_plugin_versions = {} self.stopped = False instances.append(self) - def set_external_available_plugin_ids(self, plugin_ids): - self.external_plugin_ids = list(plugin_ids) + def set_external_available_plugins(self, plugin_versions): + self.external_plugin_versions = dict(plugin_versions) def get_loaded_plugin_ids(self): return [] + def get_loaded_plugin_versions(self): + return {} + async def start(self): if len(instances) == 2 and self is instances[1]: raise RuntimeError("boom") @@ -2425,8 +2558,8 @@ class TestIntegration: (beta_dir / "config.toml").write_text("enabled = false\n", encoding="utf-8") (alpha_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8") (beta_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8") - (alpha_dir / "_manifest.json").write_text(json.dumps({"name": "alpha"}), encoding="utf-8") - (beta_dir / "_manifest.json").write_text(json.dumps({"name": "beta"}), encoding="utf-8") + (alpha_dir / "_manifest.json").write_text(json.dumps(build_test_manifest("test.alpha")), encoding="utf-8") + (beta_dir / "_manifest.json").write_text(json.dumps(build_test_manifest("test.beta")), encoding="utf-8") monkeypatch.chdir(tmp_path) @@ -2440,8 +2573,11 @@ class TestIntegration: def get_loaded_plugin_ids(self): return sorted(self._registered_plugins.keys()) + def get_loaded_plugin_versions(self): + return {plugin_id: "1.0.0" for plugin_id in self._registered_plugins} + async def reload_plugins(self, plugin_ids=None, reason="manual", external_available_plugins=None): - self.reload_reasons.append((plugin_ids, reason, external_available_plugins or [])) + self.reload_reasons.append((plugin_ids, reason, external_available_plugins or {})) async def notify_plugin_config_updated(self, plugin_id, config_data, config_version=""): self.config_updates.append((plugin_id, config_data, config_version)) @@ -2449,8 +2585,8 @@ class TestIntegration: manager = integration_module.PluginRuntimeManager() manager._started = True - manager._builtin_supervisor = FakeSupervisor([builtin_root], {"alpha": object()}) - manager._third_party_supervisor = FakeSupervisor([thirdparty_root], {"beta": object()}) + manager._builtin_supervisor = FakeSupervisor([builtin_root], {"test.alpha": object()}) + manager._third_party_supervisor = FakeSupervisor([thirdparty_root], {"test.beta": object()}) changes = [ FileChange(change_type=1, path=beta_dir / "plugin.py"), @@ -2466,7 +2602,9 @@ class TestIntegration: await manager._handle_plugin_source_changes(changes) assert manager._builtin_supervisor.reload_reasons == [] - assert manager._third_party_supervisor.reload_reasons == [(["beta"], "file_watcher", ["alpha"])] + assert manager._third_party_supervisor.reload_reasons == [ + (["test.beta"], "file_watcher", {"test.alpha": "1.0.0"}) + ] assert manager._builtin_supervisor.config_updates == [] assert manager._third_party_supervisor.config_updates == [] assert refresh_calls == [True] @@ -2487,15 +2625,18 @@ class TestIntegration: def get_loaded_plugin_ids(self): return sorted(self._registered_plugins.keys()) + def get_loaded_plugin_versions(self): + return {plugin_id: "1.0.0" for plugin_id in self._registered_plugins} + async def reload_plugins(self, plugin_ids=None, reason="manual", external_available_plugins=None): - self.reload_calls.append((plugin_ids, reason, sorted(external_available_plugins or []))) + self.reload_calls.append((plugin_ids, reason, dict(sorted((external_available_plugins or {}).items())))) return True - builtin_supervisor = FakeSupervisor({"alpha": FakeRegistration([])}) + builtin_supervisor = FakeSupervisor({"test.alpha": FakeRegistration([])}) third_party_supervisor = FakeSupervisor( { - "beta": FakeRegistration(["alpha"]), - "gamma": FakeRegistration(["beta"]), + "test.beta": FakeRegistration(["test.alpha"]), + "test.gamma": FakeRegistration(["test.beta"]), } ) @@ -2510,13 +2651,15 @@ class TestIntegration: lambda message: warning_messages.append(message), ) - reloaded = await manager.reload_plugins_globally(["alpha"], reason="manual") + reloaded = await manager.reload_plugins_globally(["test.alpha"], reason="manual") assert reloaded is True - assert builtin_supervisor.reload_calls == [(["alpha"], "manual", ["beta", "gamma"])] + assert builtin_supervisor.reload_calls == [ + (["test.alpha"], "manual", {"test.beta": "1.0.0", "test.gamma": "1.0.0"}) + ] assert third_party_supervisor.reload_calls == [] assert len(warning_messages) == 1 - assert "beta, gamma" in warning_messages[0] + assert "test.beta, test.gamma" in warning_messages[0] assert "跨 Supervisor API 调用仍然可用" in warning_messages[0] @pytest.mark.asyncio @@ -2535,8 +2678,8 @@ class TestIntegration: (beta_dir / "config.toml").write_text("enabled = false\n", encoding="utf-8") (alpha_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8") (beta_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8") - (alpha_dir / "_manifest.json").write_text(json.dumps({"name": "alpha"}), encoding="utf-8") - (beta_dir / "_manifest.json").write_text(json.dumps({"name": "beta"}), encoding="utf-8") + (alpha_dir / "_manifest.json").write_text(json.dumps(build_test_manifest("test.alpha")), encoding="utf-8") + (beta_dir / "_manifest.json").write_text(json.dumps(build_test_manifest("test.beta")), encoding="utf-8") monkeypatch.chdir(tmp_path) @@ -2558,15 +2701,15 @@ class TestIntegration: manager = integration_module.PluginRuntimeManager() manager._started = True - manager._builtin_supervisor = FakeSupervisor([builtin_root], ["alpha"]) - manager._third_party_supervisor = FakeSupervisor([thirdparty_root], ["beta"]) + manager._builtin_supervisor = FakeSupervisor([builtin_root], ["test.alpha"]) + manager._third_party_supervisor = FakeSupervisor([thirdparty_root], ["test.beta"]) await manager._handle_plugin_config_changes( - "alpha", + "test.alpha", [FileChange(change_type=1, path=alpha_dir / "config.toml")], ) - assert manager._builtin_supervisor.config_updates == [("alpha", {"enabled": True}, "", "self")] + assert manager._builtin_supervisor.config_updates == [("test.alpha", {"enabled": True}, "", "self")] assert manager._third_party_supervisor.config_updates == [] @pytest.mark.asyncio @@ -2615,23 +2758,23 @@ class TestIntegration: manager._started = True manager._builtin_supervisor = FakeSupervisor( { - "alpha": FakeRegistration(["bot"]), - "beta": FakeRegistration([]), + "test.alpha": FakeRegistration(["bot"]), + "test.beta": FakeRegistration([]), } ) manager._third_party_supervisor = FakeSupervisor( { - "gamma": FakeRegistration(["model"]), + "test.gamma": FakeRegistration(["model"]), } ) await manager._handle_main_config_reload(["bot", "model"]) assert manager._builtin_supervisor.config_updates == [ - ("alpha", {"bot": {"name": "MaiBot"}}, "", "bot") + ("test.alpha", {"bot": {"name": "MaiBot"}}, "", "bot") ] assert manager._third_party_supervisor.config_updates == [ - ("gamma", {"models": [{"name": "demo"}]}, "", "model") + ("test.gamma", {"models": [{"name": "demo"}]}, "", "model") ] def test_refresh_plugin_config_watch_subscriptions_registers_per_plugin(self, tmp_path): @@ -2646,8 +2789,8 @@ class TestIntegration: beta_dir.mkdir(parents=True) (alpha_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8") (beta_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8") - (alpha_dir / "_manifest.json").write_text(json.dumps({"name": "alpha"}), encoding="utf-8") - (beta_dir / "_manifest.json").write_text(json.dumps({"name": "beta"}), encoding="utf-8") + (alpha_dir / "_manifest.json").write_text(json.dumps(build_test_manifest("test.alpha")), encoding="utf-8") + (beta_dir / "_manifest.json").write_text(json.dumps(build_test_manifest("test.beta")), encoding="utf-8") class FakeWatcher: def __init__(self): @@ -2670,12 +2813,12 @@ class TestIntegration: manager = integration_module.PluginRuntimeManager() manager._plugin_file_watcher = FakeWatcher() - manager._builtin_supervisor = FakeSupervisor([builtin_root], ["alpha"]) - manager._third_party_supervisor = FakeSupervisor([thirdparty_root], ["beta"]) + manager._builtin_supervisor = FakeSupervisor([builtin_root], ["test.alpha"]) + manager._third_party_supervisor = FakeSupervisor([thirdparty_root], ["test.beta"]) manager._refresh_plugin_config_watch_subscriptions() - assert set(manager._plugin_config_watcher_subscriptions.keys()) == {"alpha", "beta"} + assert set(manager._plugin_config_watcher_subscriptions.keys()) == {"test.alpha", "test.beta"} assert { subscription["paths"][0] for subscription in manager._plugin_file_watcher.subscriptions } == {alpha_dir / "config.toml", beta_dir / "config.toml"} diff --git a/src/plugin_runtime/__init__.py b/src/plugin_runtime/__init__.py index 704ce514..7f2d789f 100644 --- a/src/plugin_runtime/__init__.py +++ b/src/plugin_runtime/__init__.py @@ -18,7 +18,7 @@ ENV_HOST_VERSION = "MAIBOT_HOST_VERSION" """Runner 读取的 Host 应用版本号,用于 manifest 兼容性校验""" ENV_EXTERNAL_PLUGIN_IDS = "MAIBOT_EXTERNAL_PLUGIN_IDS" -"""Runner 启动时可视为已满足的外部插件依赖列表(JSON 数组)""" +"""Runner 启动时可视为已满足的外部插件依赖版本映射(JSON 对象)""" ENV_GLOBAL_CONFIG_SNAPSHOT = "MAIBOT_GLOBAL_CONFIG_SNAPSHOT" """Runner 启动时注入的全局配置快照(JSON 对象)""" diff --git a/src/plugin_runtime/capabilities/components.py b/src/plugin_runtime/capabilities/components.py index 2e4c111c..33b54c64 100644 --- a/src/plugin_runtime/capabilities/components.py +++ b/src/plugin_runtime/capabilities/components.py @@ -191,7 +191,7 @@ class RuntimeComponentCapabilityMixin: return None, None, "缺少必要参数 api_name" if "." in normalized_api_name: - target_plugin_id, target_api_name = normalized_api_name.split(".", 1) + target_plugin_id, target_api_name = normalized_api_name.rsplit(".", 1) try: supervisor = self._get_supervisor_for_plugin(target_plugin_id) except RuntimeError as exc: @@ -282,7 +282,7 @@ class RuntimeComponentCapabilityMixin: return None, None, "缺少必要参数 name" if "." in normalized_name: - plugin_id, api_name = normalized_name.split(".", 1) + plugin_id, api_name = normalized_name.rsplit(".", 1) try: supervisor = self._get_supervisor_for_plugin(plugin_id) except RuntimeError as exc: diff --git a/src/plugin_runtime/host/supervisor.py b/src/plugin_runtime/host/supervisor.py index 693eae51..ac953bb3 100644 --- a/src/plugin_runtime/host/supervisor.py +++ b/src/plugin_runtime/host/supervisor.py @@ -116,7 +116,7 @@ class PluginRunnerSupervisor: self._runner_process: Optional[asyncio.subprocess.Process] = None self._registered_plugins: Dict[str, RegisterPluginPayload] = {} self._message_gateway_states: Dict[str, Dict[str, _MessageGatewayRuntimeState]] = {} - self._external_available_plugin_ids: List[str] = [] + self._external_available_plugins: Dict[str, str] = {} self._runner_ready_events: asyncio.Event = asyncio.Event() self._runner_ready_payloads: RunnerReadyPayload = RunnerReadyPayload() self._health_task: Optional[asyncio.Task[None]] = None @@ -166,21 +166,34 @@ class PluginRunnerSupervisor: """返回底层 RPC 服务端。""" return self._rpc_server - def set_external_available_plugin_ids(self, plugin_ids: List[str]) -> None: - """设置当前 Runner 启动/重载时可视为已满足的外部依赖列表。""" + def set_external_available_plugins(self, plugin_versions: Dict[str, str]) -> None: + """设置当前 Runner 启动/重载时可视为已满足的外部依赖版本映射。 - normalized_plugin_ids = { - str(plugin_id or "").strip() - for plugin_id in plugin_ids - if str(plugin_id or "").strip() + Args: + plugin_versions: 外部插件版本映射,键为插件 ID,值为插件版本。 + """ + self._external_available_plugins = { + str(plugin_id or "").strip(): str(plugin_version or "").strip() + for plugin_id, plugin_version in plugin_versions.items() + if str(plugin_id or "").strip() and str(plugin_version or "").strip() } - self._external_available_plugin_ids = sorted(normalized_plugin_ids) def get_loaded_plugin_ids(self) -> List[str]: """返回当前 Supervisor 已注册的插件 ID 列表。""" return sorted(self._registered_plugins.keys()) + def get_loaded_plugin_versions(self) -> Dict[str, str]: + """返回当前 Supervisor 已注册插件的版本映射。 + + Returns: + Dict[str, str]: 已注册插件版本映射,键为插件 ID,值为插件版本。 + """ + return { + plugin_id: registration.plugin_version + for plugin_id, registration in self._registered_plugins.items() + } + async def dispatch_event( self, event_type: str, @@ -373,14 +386,14 @@ class PluginRunnerSupervisor: self, plugin_id: str, reason: str = "manual", - external_available_plugins: Optional[List[str]] = None, + external_available_plugins: Optional[Dict[str, str]] = None, ) -> bool: """按插件 ID 触发精确重载。 Args: plugin_id: 目标插件 ID。 reason: 重载原因。 - external_available_plugins: 视为已满足的外部依赖插件 ID 列表。 + external_available_plugins: 视为已满足的外部依赖插件版本映射。 Returns: bool: 是否重载成功。 @@ -392,7 +405,7 @@ class PluginRunnerSupervisor: payload={ "plugin_id": plugin_id, "reason": reason, - "external_available_plugins": external_available_plugins or self._external_available_plugin_ids, + "external_available_plugins": external_available_plugins or self._external_available_plugins, }, timeout_ms=max(int(self._runner_spawn_timeout * 1000), 10000), ) @@ -409,14 +422,14 @@ class PluginRunnerSupervisor: self, plugin_ids: Optional[List[str]] = None, reason: str = "manual", - external_available_plugins: Optional[List[str]] = None, + external_available_plugins: Optional[Dict[str, str]] = None, ) -> bool: """批量重载插件。 Args: plugin_ids: 目标插件 ID 列表;为空时重载当前已注册的全部插件。 reason: 重载原因。 - external_available_plugins: 视为已满足的外部依赖插件 ID 列表。 + external_available_plugins: 视为已满足的外部依赖插件版本映射。 Returns: bool: 是否全部重载成功。 @@ -1136,7 +1149,7 @@ class PluginRunnerSupervisor: global_config_snapshot = config_manager.get_global_config().model_dump() global_config_snapshot["model"] = config_manager.get_model_config().model_dump() return { - ENV_EXTERNAL_PLUGIN_IDS: json.dumps(self._external_available_plugin_ids, ensure_ascii=False), + ENV_EXTERNAL_PLUGIN_IDS: json.dumps(self._external_available_plugins, ensure_ascii=False), ENV_GLOBAL_CONFIG_SNAPSHOT: json.dumps(global_config_snapshot, ensure_ascii=False), ENV_HOST_VERSION: PROTOCOL_VERSION, ENV_IPC_ADDRESS: self._transport.get_address(), diff --git a/src/plugin_runtime/integration.py b/src/plugin_runtime/integration.py index d48260e5..092b9597 100644 --- a/src/plugin_runtime/integration.py +++ b/src/plugin_runtime/integration.py @@ -11,7 +11,6 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Awaitable, Callable, Coroutine, Dict, Iterable, List, Optional, Sequence, Set, Tuple import asyncio -import json import tomlkit @@ -26,6 +25,7 @@ from src.plugin_runtime.capabilities import ( ) from src.plugin_runtime.capabilities.registry import register_capability_impls from src.plugin_runtime.host.message_utils import MessageDict, PluginMessageUtils +from src.plugin_runtime.runner.manifest_validator import ManifestValidator if TYPE_CHECKING: from src.chat.message_receive.message import SessionMessage @@ -69,6 +69,7 @@ class PluginRuntimeManager( self._plugin_source_watcher_subscription_id: Optional[str] = None self._plugin_config_watcher_subscriptions: Dict[str, Tuple[Path, str]] = {} self._plugin_path_cache: Dict[str, Path] = {} + self._manifest_validator: ManifestValidator = ManifestValidator() self._config_reload_callback: Callable[[Sequence[str]], Awaitable[None]] = self._handle_main_config_reload self._config_reload_callback_registered: bool = False @@ -102,46 +103,11 @@ class PluginRuntimeManager( candidate = Path("plugins").resolve() return [candidate] if candidate.is_dir() else [] - @staticmethod - def _extract_manifest_dependencies(manifest: Dict[str, Any]) -> List[str]: - """从插件 manifest 中提取规范化后的依赖插件 ID 列表。""" - - dependencies: List[str] = [] - for dependency in manifest.get("dependencies", []): - if isinstance(dependency, str): - normalized_dependency = dependency.strip() - elif isinstance(dependency, dict): - normalized_dependency = str(dependency.get("name", "") or "").strip() - else: - normalized_dependency = "" - - if normalized_dependency: - dependencies.append(normalized_dependency) - return dependencies - @classmethod def _discover_plugin_dependency_map(cls, plugin_dirs: Iterable[Path]) -> Dict[str, List[str]]: """扫描指定插件目录集合,返回 ``plugin_id -> dependencies`` 映射。""" - - dependency_map: Dict[str, List[str]] = {} - for plugin_dir in cls._iter_candidate_plugin_paths(plugin_dirs): - manifest_path = plugin_dir / "_manifest.json" - entrypoint_path = plugin_dir / "plugin.py" - if not manifest_path.is_file() or not entrypoint_path.is_file(): - continue - - try: - with manifest_path.open("r", encoding="utf-8") as manifest_file: - manifest = json.load(manifest_file) - except Exception: - continue - - if not isinstance(manifest, dict): - continue - - plugin_id = str(manifest.get("name", plugin_dir.name) or "").strip() or plugin_dir.name - dependency_map[plugin_id] = cls._extract_manifest_dependencies(manifest) - return dependency_map + validator = ManifestValidator() + return validator.build_plugin_dependency_map(plugin_dirs) @classmethod def _build_group_start_order( @@ -243,12 +209,12 @@ class PluginRuntimeManager( if supervisor is None: continue - external_plugin_ids = [ - plugin_id + external_plugin_versions = { + plugin_id: plugin_version for started_supervisor in started_supervisors - for plugin_id in started_supervisor.get_loaded_plugin_ids() - ] - supervisor.set_external_available_plugin_ids(external_plugin_ids) + for plugin_id, plugin_version in started_supervisor.get_loaded_plugin_versions().items() + } + supervisor.set_external_available_plugins(external_plugin_versions) await supervisor.start() started_supervisors.append(supervisor) @@ -366,23 +332,22 @@ class PluginRuntimeManager( for plugin_id in supervisor.get_loaded_plugin_ids() } - def _build_external_available_plugins_for_supervisor(self, target_supervisor: "PluginSupervisor") -> List[str]: - """收集某个 Supervisor 可用的外部插件 ID 列表。""" + def _build_external_available_plugins_for_supervisor(self, target_supervisor: "PluginSupervisor") -> Dict[str, str]: + """收集某个 Supervisor 可用的外部插件版本映射。""" - external_plugin_ids: Set[str] = set() + external_plugin_versions: Dict[str, str] = {} for supervisor in self.supervisors: if supervisor is target_supervisor: continue - external_plugin_ids.update(supervisor.get_loaded_plugin_ids()) - return sorted(external_plugin_ids) + external_plugin_versions.update(supervisor.get_loaded_plugin_versions()) + return external_plugin_versions def _find_supervisor_by_plugin_directory(self, plugin_id: str) -> Optional["PluginSupervisor"]: """根据插件目录推断应负责该插件重载的 Supervisor。""" for supervisor in self.supervisors: - for plugin_dir in supervisor._plugin_dirs: - if (Path(plugin_dir) / plugin_id).is_dir(): - return supervisor + if self._get_plugin_path_for_supervisor(supervisor, plugin_id) is not None: + return supervisor return None def _warn_skipped_cross_supervisor_reload( @@ -740,30 +705,13 @@ class PluginRuntimeManager( external_available_plugins=self._build_external_available_plugins_for_supervisor(supervisor), ) - @staticmethod - def _find_duplicate_plugin_ids(plugin_dirs: List[Path]) -> Dict[str, List[Path]]: + @classmethod + def _find_duplicate_plugin_ids(cls, plugin_dirs: List[Path]) -> Dict[str, List[Path]]: """扫描插件目录,找出被多个目录重复声明的插件 ID。""" plugin_locations: Dict[str, List[Path]] = {} - for base_dir in plugin_dirs: - if not base_dir.is_dir(): - continue - for entry in base_dir.iterdir(): - if not entry.is_dir(): - continue - manifest_path = entry / "_manifest.json" - plugin_path = entry / "plugin.py" - if not manifest_path.exists() or not plugin_path.exists(): - continue - - plugin_id = entry.name - try: - with open(manifest_path, "r", encoding="utf-8") as manifest_file: - manifest = json.load(manifest_file) - plugin_id = str(manifest.get("name", entry.name)).strip() or entry.name - except Exception: - continue - - plugin_locations.setdefault(plugin_id, []).append(entry) + validator = ManifestValidator() + for plugin_path, manifest in validator.iter_plugin_manifests(plugin_dirs): + plugin_locations.setdefault(manifest.id, []).append(plugin_path) return { plugin_id: sorted(dict.fromkeys(paths), key=lambda p: str(p)) @@ -831,8 +779,7 @@ class PluginRuntimeManager( if entry.is_dir(): yield entry.resolve() - @staticmethod - def _read_plugin_id_from_plugin_path(plugin_path: Path) -> Optional[str]: + def _read_plugin_id_from_plugin_path(self, plugin_path: Path) -> Optional[str]: """从单个插件目录中读取 manifest 声明的插件 ID。 Args: @@ -841,22 +788,7 @@ class PluginRuntimeManager( Returns: Optional[str]: 解析成功时返回插件 ID,否则返回 ``None``。 """ - manifest_path = plugin_path / "_manifest.json" - entrypoint_path = plugin_path / "plugin.py" - if not manifest_path.is_file() or not entrypoint_path.is_file(): - return None - - try: - with open(manifest_path, "r", encoding="utf-8") as manifest_file: - manifest = json.load(manifest_file) - except Exception: - return None - - if not isinstance(manifest, dict): - return None - - plugin_id = str(manifest.get("name", plugin_path.name)).strip() or plugin_path.name - return plugin_id or None + return self._manifest_validator.read_plugin_id_from_plugin_path(plugin_path) def _iter_discovered_plugin_paths(self, plugin_dirs: Iterable[Path]) -> Iterable[Tuple[str, Path]]: """迭代目录中可解析到的插件 ID 与实际目录路径。 diff --git a/src/plugin_runtime/protocol/envelope.py b/src/plugin_runtime/protocol/envelope.py index ce40d855..c2c89a0f 100644 --- a/src/plugin_runtime/protocol/envelope.py +++ b/src/plugin_runtime/protocol/envelope.py @@ -282,8 +282,11 @@ class ReloadPluginPayload(BaseModel): """目标插件 ID""" reason: str = Field(default="manual", description="重载原因") """重载原因""" - external_available_plugins: List[str] = Field(default_factory=list, description="可视为已满足的外部依赖插件 ID") - """可视为已满足的外部依赖插件 ID""" + external_available_plugins: Dict[str, str] = Field( + default_factory=dict, + description="可视为已满足的外部依赖插件版本映射", + ) + """可视为已满足的外部依赖插件版本映射""" class ReloadPluginResultPayload(BaseModel): diff --git a/src/plugin_runtime/runner/manifest_validator.py b/src/plugin_runtime/runner/manifest_validator.py index 32429e01..33c2b1e5 100644 --- a/src/plugin_runtime/runner/manifest_validator.py +++ b/src/plugin_runtime/runner/manifest_validator.py @@ -1,20 +1,36 @@ -"""Manifest 校验与版本兼容性 +"""Manifest 校验与解析。 -从旧系统的 ManifestValidator / VersionComparator 对齐移植, -适配新 plugin_runtime 的 _manifest.json 格式。 +集中负责插件 ``_manifest.json`` 的读取、结构校验、运行时兼容性判断, +以及插件依赖/Python 包依赖的解析逻辑。 """ -from typing import Any, Dict, List, Tuple +from functools import lru_cache +from importlib import metadata as importlib_metadata +from pathlib import Path +from typing import Annotated, Any, Dict, Iterable, List, Literal, Optional, Tuple, Union +import json import re +import tomllib + +from packaging.requirements import InvalidRequirement, Requirement +from packaging.specifiers import InvalidSpecifier, SpecifierSet +from packaging.utils import canonicalize_name +from packaging.version import InvalidVersion, Version +from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator, model_validator from src.common.logger import get_logger logger = get_logger("plugin_runtime.runner.manifest_validator") +_SEMVER_PATTERN = re.compile(r"^\d+\.\d+\.\d+$") +_PLUGIN_ID_PATTERN = re.compile(r"^[a-z0-9]+(?:[.-][a-z0-9]+)+$") +_PACKAGE_NAME_PATTERN = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]*$") +_HTTP_URL_PATTERN = re.compile(r"^https?://.+$") + class VersionComparator: - """语义化版本号比较器""" + """语义化版本号比较器。""" @staticmethod def normalize_version(version: str) -> str: @@ -25,13 +41,15 @@ class VersionComparator: Returns: str: 规范化后的 ``major.minor.patch`` 形式版本号。 - 当输入为空或格式非法时返回 ``0.0.0``。 + 当输入为空或格式非法时返回 ``0.0.0``。 """ if not version: return "0.0.0" - normalized = re.sub(r"-snapshot\.\d+", "", version.strip()) + + normalized = re.sub(r"-snapshot\.\d+", "", str(version).strip()) if not re.match(r"^\d+(\.\d+){0,2}$", normalized): return "0.0.0" + parts = normalized.split(".") while len(parts) < 3: parts.append("0") @@ -46,7 +64,7 @@ class VersionComparator: Returns: Tuple[int, int, int]: 三段式版本号对应的整数元组。 - 当解析失败时返回 ``(0, 0, 0)``。 + 当解析失败时返回 ``(0, 0, 0)``。 """ normalized = VersionComparator.normalize_version(version) try: @@ -65,13 +83,13 @@ class VersionComparator: Returns: int: ``-1`` 表示 ``v1 < v2``,``1`` 表示 ``v1 > v2``, - ``0`` 表示两者相等。 + ``0`` 表示两者相等。 """ t1 = VersionComparator.parse_version(v1) t2 = VersionComparator.parse_version(v2) if t1 < t2: return -1 - elif t1 > t2: + if t1 > t2: return 1 return 0 @@ -86,120 +104,1043 @@ class VersionComparator: Returns: Tuple[bool, str]: 第一项表示是否满足要求,第二项为失败原因; - 当校验通过时第二项为空字符串。 + 当校验通过时第二项为空字符串。 """ if not min_version and not max_version: return True, "" - vn = VersionComparator.normalize_version(version) + + normalized_version = VersionComparator.normalize_version(version) if min_version: - mn = VersionComparator.normalize_version(min_version) - if VersionComparator.compare(vn, mn) < 0: - return False, f"版本 {vn} 低于最小要求 {mn}" + normalized_min_version = VersionComparator.normalize_version(min_version) + if VersionComparator.compare(normalized_version, normalized_min_version) < 0: + return False, f"版本 {normalized_version} 低于最小要求 {normalized_min_version}" if max_version: - mx = VersionComparator.normalize_version(max_version) - if VersionComparator.compare(vn, mx) > 0: - return False, f"版本 {vn} 高于最大支持 {mx}" + normalized_max_version = VersionComparator.normalize_version(max_version) + if VersionComparator.compare(normalized_version, normalized_max_version) > 0: + return False, f"版本 {normalized_version} 高于最大支持 {normalized_max_version}" return True, "" + @staticmethod + def is_valid_semver(version: str) -> bool: + """判断字符串是否为严格三段式语义版本号。 + + Args: + version: 待检查的版本号字符串。 + + Returns: + bool: 是否满足 ``X.Y.Z`` 格式。 + """ + return bool(_SEMVER_PATTERN.fullmatch(str(version or "").strip())) + + +class _StrictManifestModel(BaseModel): + """Manifest 解析使用的严格基类模型。""" + + model_config = ConfigDict(extra="forbid", frozen=True, str_strip_whitespace=True) + + +class ManifestAuthor(_StrictManifestModel): + """插件作者信息。""" + + name: str = Field(description="作者名称") + url: str = Field(description="作者主页地址") + + @field_validator("name") + @classmethod + def _validate_name(cls, value: str) -> str: + """校验作者名称。 + + Args: + value: 原始作者名称。 + + Returns: + str: 规范化后的作者名称。 + + Raises: + ValueError: 当字段为空时抛出。 + """ + if not value: + raise ValueError("不能为空") + return value + + @field_validator("url") + @classmethod + def _validate_url(cls, value: str) -> str: + """校验作者主页地址。 + + Args: + value: 原始主页地址。 + + Returns: + str: 规范化后的主页地址。 + + Raises: + ValueError: 当字段为空或不是 HTTP/HTTPS URL 时抛出。 + """ + if not value: + raise ValueError("不能为空") + if not _HTTP_URL_PATTERN.fullmatch(value): + raise ValueError("必须为 http:// 或 https:// 开头的 URL") + return value + + +class ManifestUrls(_StrictManifestModel): + """插件相关链接集合。""" + + repository: str = Field(description="插件仓库地址") + homepage: Optional[str] = Field(default=None, description="插件主页地址") + documentation: Optional[str] = Field(default=None, description="插件文档地址") + issues: Optional[str] = Field(default=None, description="插件问题反馈地址") + + @field_validator("repository") + @classmethod + def _validate_repository(cls, value: str) -> str: + """校验仓库地址。 + + Args: + value: 原始仓库地址。 + + Returns: + str: 规范化后的仓库地址。 + + Raises: + ValueError: 当字段为空或不是 HTTP/HTTPS URL 时抛出。 + """ + if not value: + raise ValueError("不能为空") + if not _HTTP_URL_PATTERN.fullmatch(value): + raise ValueError("必须为 http:// 或 https:// 开头的 URL") + return value + + @field_validator("homepage", "documentation", "issues") + @classmethod + def _validate_optional_url(cls, value: Optional[str]) -> Optional[str]: + """校验可选链接字段。 + + Args: + value: 原始链接值。 + + Returns: + Optional[str]: 合法的链接值。 + + Raises: + ValueError: 当提供的值不是 HTTP/HTTPS URL 时抛出。 + """ + if value is None: + return None + if not value: + raise ValueError("不能为空字符串") + if not _HTTP_URL_PATTERN.fullmatch(value): + raise ValueError("必须为 http:// 或 https:// 开头的 URL") + return value + + +class ManifestVersionRange(_StrictManifestModel): + """版本闭区间声明。""" + + min_version: str = Field(description="最小版本,闭区间") + max_version: str = Field(description="最大版本,闭区间") + + @field_validator("min_version", "max_version") + @classmethod + def _validate_version(cls, value: str) -> str: + """校验版本号格式。 + + Args: + value: 原始版本号。 + + Returns: + str: 合法的版本号。 + + Raises: + ValueError: 当版本号不是严格三段式语义版本时抛出。 + """ + if not VersionComparator.is_valid_semver(value): + raise ValueError("必须为严格三段式版本号,例如 1.0.0") + return value + + @model_validator(mode="after") + def _validate_range(self) -> "ManifestVersionRange": + """校验版本区间上下界关系。 + + Returns: + ManifestVersionRange: 当前对象本身。 + + Raises: + ValueError: 当最小版本大于最大版本时抛出。 + """ + if VersionComparator.compare(self.min_version, self.max_version) > 0: + raise ValueError("min_version 不能大于 max_version") + return self + + +class ManifestI18n(_StrictManifestModel): + """国际化配置。""" + + default_locale: str = Field(description="默认语言") + locales_path: Optional[str] = Field(default=None, description="语言资源目录") + supported_locales: List[str] = Field(default_factory=list, description="支持的语言列表") + + @field_validator("default_locale") + @classmethod + def _validate_default_locale(cls, value: str) -> str: + """校验默认语言。 + + Args: + value: 原始默认语言。 + + Returns: + str: 规范化后的默认语言。 + + Raises: + ValueError: 当字段为空时抛出。 + """ + if not value: + raise ValueError("不能为空") + return value + + @field_validator("locales_path") + @classmethod + def _validate_locales_path(cls, value: Optional[str]) -> Optional[str]: + """校验语言资源目录。 + + Args: + value: 原始语言资源目录。 + + Returns: + Optional[str]: 合法的目录值。 + + Raises: + ValueError: 当值为空字符串时抛出。 + """ + if value is None: + return None + if not value: + raise ValueError("不能为空字符串") + return value + + @field_validator("supported_locales") + @classmethod + def _validate_supported_locales(cls, value: List[str]) -> List[str]: + """校验支持语言列表。 + + Args: + value: 原始语言列表。 + + Returns: + List[str]: 去重后的语言列表。 + + Raises: + ValueError: 当列表项为空时抛出。 + """ + normalized_locales: List[str] = [] + for locale in value: + normalized_locale = str(locale or "").strip() + if not normalized_locale: + raise ValueError("语言列表中存在空值") + if normalized_locale not in normalized_locales: + normalized_locales.append(normalized_locale) + return normalized_locales + + @model_validator(mode="after") + def _validate_default_locale_membership(self) -> "ManifestI18n": + """校验默认语言是否位于支持列表中。 + + Returns: + ManifestI18n: 当前对象本身。 + + Raises: + ValueError: 当 ``supported_locales`` 非空但未包含 ``default_locale`` 时抛出。 + """ + if self.supported_locales and self.default_locale not in self.supported_locales: + raise ValueError("default_locale 必须包含在 supported_locales 中") + return self + + +class PluginDependencyDefinition(_StrictManifestModel): + """插件级依赖声明。""" + + type: Literal["plugin"] = Field(description="依赖类型") + id: str = Field(description="依赖插件 ID") + version_spec: str = Field(description="版本约束表达式") + + @field_validator("id") + @classmethod + def _validate_id(cls, value: str) -> str: + """校验依赖插件 ID。 + + Args: + value: 原始依赖插件 ID。 + + Returns: + str: 合法的依赖插件 ID。 + + Raises: + ValueError: 当 ID 不符合规则时抛出。 + """ + if not _PLUGIN_ID_PATTERN.fullmatch(value): + raise ValueError("必须使用小写字母/数字,并以点号或横线分隔,例如 github.author.plugin") + return value + + @field_validator("version_spec") + @classmethod + def _validate_version_spec(cls, value: str) -> str: + """校验插件依赖版本约束。 + + Args: + value: 原始版本约束表达式。 + + Returns: + str: 合法的版本约束表达式。 + + Raises: + ValueError: 当表达式无效时抛出。 + """ + if not value: + raise ValueError("不能为空") + try: + SpecifierSet(value) + except InvalidSpecifier as exc: + raise ValueError(f"无效的版本约束: {exc}") from exc + return value + + +class PythonPackageDependencyDefinition(_StrictManifestModel): + """Python 包依赖声明。""" + + type: Literal["python_package"] = Field(description="依赖类型") + name: str = Field(description="Python 包名") + version_spec: str = Field(description="版本约束表达式") + + @field_validator("name") + @classmethod + def _validate_name(cls, value: str) -> str: + """校验 Python 包名。 + + Args: + value: 原始包名。 + + Returns: + str: 合法的包名。 + + Raises: + ValueError: 当包名不合法时抛出。 + """ + if not _PACKAGE_NAME_PATTERN.fullmatch(value): + raise ValueError("包名只能包含字母、数字、点号、下划线和横线") + return value + + @field_validator("version_spec") + @classmethod + def _validate_version_spec(cls, value: str) -> str: + """校验 Python 包版本约束。 + + Args: + value: 原始版本约束表达式。 + + Returns: + str: 合法的版本约束表达式。 + + Raises: + ValueError: 当表达式无效时抛出。 + """ + if not value: + raise ValueError("不能为空") + try: + Requirement(f"placeholder{value}") + except InvalidRequirement as exc: + raise ValueError(f"无效的版本约束: {exc}") from exc + return value + + +ManifestDependencyDefinition = Annotated[ + Union[PluginDependencyDefinition, PythonPackageDependencyDefinition], + Field(discriminator="type"), +] + + +class PluginManifest(_StrictManifestModel): + """插件 Manifest v2 强类型模型。""" + + manifest_version: Literal[2] = Field(description="Manifest 协议版本") + version: str = Field(description="插件版本") + name: str = Field(description="插件展示名称") + description: str = Field(description="插件描述") + author: ManifestAuthor = Field(description="插件作者信息") + license: str = Field(description="插件协议") + urls: ManifestUrls = Field(description="插件相关链接") + host_application: ManifestVersionRange = Field(description="Host 兼容区间") + sdk: ManifestVersionRange = Field(description="SDK 兼容区间") + dependencies: List[ManifestDependencyDefinition] = Field(default_factory=list, description="依赖声明") + capabilities: List[str] = Field(description="插件声明的能力请求") + i18n: ManifestI18n = Field(description="国际化配置") + id: str = Field(description="稳定插件 ID") + + @field_validator("version") + @classmethod + def _validate_version(cls, value: str) -> str: + """校验插件版本号格式。 + + Args: + value: 原始插件版本号。 + + Returns: + str: 合法的插件版本号。 + + Raises: + ValueError: 当版本号不是严格三段式语义版本时抛出。 + """ + if not VersionComparator.is_valid_semver(value): + raise ValueError("必须为严格三段式版本号,例如 1.0.0") + return value + + @field_validator("name", "description", "license", "id") + @classmethod + def _validate_required_string(cls, value: str, info: Any) -> str: + """校验必填字符串字段。 + + Args: + value: 原始字段值。 + info: Pydantic 字段上下文。 + + Returns: + str: 合法的字段值。 + + Raises: + ValueError: 当字段为空或格式不合法时抛出。 + """ + if not value: + raise ValueError("不能为空") + if info.field_name == "id" and not _PLUGIN_ID_PATTERN.fullmatch(value): + raise ValueError("必须使用小写字母/数字,并以点号或横线分隔,例如 github.author.plugin") + return value + + @field_validator("capabilities") + @classmethod + def _validate_capabilities(cls, value: List[str]) -> List[str]: + """校验能力声明列表。 + + Args: + value: 原始能力声明列表。 + + Returns: + List[str]: 去重后的能力列表。 + + Raises: + ValueError: 当列表为空项或能力名为空时抛出。 + """ + normalized_capabilities: List[str] = [] + for capability in value: + normalized_capability = str(capability or "").strip() + if not normalized_capability: + raise ValueError("capabilities 中存在空能力名") + if normalized_capability not in normalized_capabilities: + normalized_capabilities.append(normalized_capability) + return normalized_capabilities + + @model_validator(mode="after") + def _validate_dependencies(self) -> "PluginManifest": + """校验依赖声明集合。 + + Returns: + PluginManifest: 当前对象本身。 + + Raises: + ValueError: 当依赖项重复或插件依赖自身时抛出。 + """ + plugin_dependency_ids: set[str] = set() + python_package_names: set[str] = set() + + for dependency in self.dependencies: + if isinstance(dependency, PluginDependencyDefinition): + if dependency.id == self.id: + raise ValueError("dependencies 中的插件依赖不能依赖自身") + if dependency.id in plugin_dependency_ids: + raise ValueError(f"存在重复的插件依赖声明: {dependency.id}") + plugin_dependency_ids.add(dependency.id) + continue + + normalized_package_name = canonicalize_name(dependency.name) + if normalized_package_name in python_package_names: + raise ValueError(f"存在重复的 Python 包依赖声明: {dependency.name}") + python_package_names.add(normalized_package_name) + + return self + + @property + def plugin_dependencies(self) -> List[PluginDependencyDefinition]: + """返回插件级依赖列表。 + + Returns: + List[PluginDependencyDefinition]: 所有 ``type=plugin`` 的依赖项。 + """ + return [dependency for dependency in self.dependencies if isinstance(dependency, PluginDependencyDefinition)] + + @property + def python_package_dependencies(self) -> List[PythonPackageDependencyDefinition]: + """返回 Python 包依赖列表。 + + Returns: + List[PythonPackageDependencyDefinition]: 所有 ``type=python_package`` 的依赖项。 + """ + return [ + dependency + for dependency in self.dependencies + if isinstance(dependency, PythonPackageDependencyDefinition) + ] + + @property + def plugin_dependency_ids(self) -> List[str]: + """返回插件级依赖的插件 ID 列表。 + + Returns: + List[str]: 所有插件级依赖的插件 ID。 + """ + return [dependency.id for dependency in self.plugin_dependencies] + class ManifestValidator: - """_manifest.json 校验器""" + """严格的插件 Manifest v2 校验器。""" - REQUIRED_FIELDS = ["name", "version", "description", "author"] - RECOMMENDED_FIELDS = ["license", "keywords", "categories"] - SUPPORTED_MANIFEST_VERSIONS = [1, 2] + SUPPORTED_MANIFEST_VERSIONS = [2] - def __init__(self, host_version: str = "") -> None: + def __init__( + self, + host_version: str = "", + sdk_version: str = "", + project_root: Optional[Path] = None, + ) -> None: """初始化 Manifest 校验器。 Args: - host_version: 当前 Host 版本号,用于校验插件声明的兼容区间。 + host_version: 当前 Host 版本号;留空时自动从主程序 ``pyproject.toml`` 读取。 + sdk_version: 当前 SDK 版本号;留空时自动从运行环境中探测。 + project_root: 项目根目录;留空时自动推断。 """ - self._host_version = host_version + self._project_root: Path = project_root or self._resolve_project_root() + self._host_version: str = host_version or self._detect_default_host_version(self._project_root) + self._sdk_version: str = sdk_version or self._detect_default_sdk_version(self._project_root) self.errors: List[str] = [] self.warnings: List[str] = [] def validate(self, manifest: Dict[str, Any]) -> bool: - """校验 manifest 数据,返回是否通过(errors 为空即通过)。""" + """校验 manifest 数据,返回是否通过。 + + Args: + manifest: 待校验的 Manifest 原始字典。 + + Returns: + bool: 校验是否通过。 + """ + return self.parse_manifest(manifest) is not None + + def parse_manifest(self, manifest: Dict[str, Any]) -> Optional[PluginManifest]: + """解析并校验 manifest 字典。 + + Args: + manifest: 待解析的 Manifest 原始字典。 + + Returns: + Optional[PluginManifest]: 解析成功时返回强类型 Manifest;失败时返回 ``None``。 + """ self.errors.clear() self.warnings.clear() - self._check_required_fields(manifest) - self._check_manifest_version(manifest) - self._check_author(manifest) - self._check_host_compatibility(manifest) - self._check_recommended(manifest) + try: + parsed_manifest = PluginManifest.model_validate(manifest) + except ValidationError as exc: + self.errors.extend(self._format_validation_errors(exc)) + self._log_errors() + return None + self._validate_runtime_compatibility(parsed_manifest) if self.errors: - for e in self.errors: - logger.error(f"Manifest 校验失败: {e}") - if self.warnings: - for w in self.warnings: - logger.warning(f"Manifest 警告: {w}") + self._log_errors() + return None - return len(self.errors) == 0 + return parsed_manifest - def _check_required_fields(self, manifest: Dict[str, Any]) -> None: - """检查 Manifest 中的必填字段是否存在且非空。 + def load_from_plugin_path(self, plugin_path: Path, require_entrypoint: bool = True) -> Optional[PluginManifest]: + """从插件目录读取并解析 manifest。 Args: - manifest: 待校验的 Manifest 数据。 - """ - for field in self.REQUIRED_FIELDS: - if field not in manifest: - self.errors.append(f"缺少必需字段: {field}") - elif not manifest[field]: - self.errors.append(f"必需字段不能为空: {field}") + plugin_path: 单个插件目录路径。 + require_entrypoint: 是否要求目录内存在 ``plugin.py`` 入口文件。 - def _check_manifest_version(self, manifest: Dict[str, Any]) -> None: - """检查 Manifest 版本号是否在当前 Runner 支持范围内。 + Returns: + Optional[PluginManifest]: 解析成功时返回强类型 Manifest;失败时返回 ``None``。 + """ + self.errors.clear() + self.warnings.clear() + + manifest_path = plugin_path / "_manifest.json" + entrypoint_path = plugin_path / "plugin.py" + + if not manifest_path.is_file(): + self.errors.append("缺少 _manifest.json") + return None + if require_entrypoint and not entrypoint_path.is_file(): + self.errors.append("缺少 plugin.py") + return None + + try: + with manifest_path.open("r", encoding="utf-8") as manifest_file: + manifest_data = json.load(manifest_file) + except Exception as exc: + self.errors.append(f"manifest 解析失败: {exc}") + self._log_errors() + return None + + if not isinstance(manifest_data, dict): + self.errors.append("manifest 顶层必须为 JSON 对象") + self._log_errors() + return None + + return self.parse_manifest(manifest_data) + + def iter_plugin_manifests( + self, + plugin_dirs: Iterable[Path], + require_entrypoint: bool = True, + ) -> Iterable[Tuple[Path, PluginManifest]]: + """扫描插件根目录并迭代所有可成功解析的 Manifest。 Args: - manifest: 待校验的 Manifest 数据。 - """ - mv = manifest.get("manifest_version") - if mv is not None and mv not in self.SUPPORTED_MANIFEST_VERSIONS: - self.errors.append(f"不支持的 manifest_version: {mv},支持: {self.SUPPORTED_MANIFEST_VERSIONS}") + plugin_dirs: 一个或多个插件根目录。 + require_entrypoint: 是否要求每个插件目录内存在 ``plugin.py``。 - def _check_author(self, manifest: Dict[str, Any]) -> None: - """校验 ``author`` 字段的结构与内容。 + Yields: + Tuple[Path, PluginManifest]: ``(插件目录路径, 解析结果)`` 二元组。 + """ + for plugin_root in plugin_dirs: + normalized_root = Path(plugin_root).resolve() + if not normalized_root.is_dir(): + continue + + for candidate_path in sorted(entry.resolve() for entry in normalized_root.iterdir() if entry.is_dir()): + parsed_manifest = self.load_from_plugin_path(candidate_path, require_entrypoint=require_entrypoint) + if parsed_manifest is None: + continue + yield candidate_path, parsed_manifest + + def build_plugin_dependency_map( + self, + plugin_dirs: Iterable[Path], + require_entrypoint: bool = True, + ) -> Dict[str, List[str]]: + """扫描目录并构建 ``plugin_id -> 依赖插件 ID 列表`` 映射。 Args: - manifest: 待校验的 Manifest 数据。 - """ - author = manifest.get("author") - if author is None: - return - if isinstance(author, dict): - if "name" not in author or not author["name"]: - self.errors.append("author 对象缺少 name 字段") - elif isinstance(author, str): - if not author.strip(): - self.errors.append("author 不能为空") - else: - self.errors.append("author 应为字符串或 {name, url} 对象") + plugin_dirs: 一个或多个插件根目录。 + require_entrypoint: 是否要求每个插件目录内存在 ``plugin.py``。 - def _check_host_compatibility(self, manifest: Dict[str, Any]) -> None: - """检查插件声明的 Host 兼容范围是否包含当前 Host 版本。 + Returns: + Dict[str, List[str]]: 所有成功解析到的插件依赖映射。 + """ + dependency_map: Dict[str, List[str]] = {} + for _plugin_path, manifest in self.iter_plugin_manifests(plugin_dirs, require_entrypoint=require_entrypoint): + dependency_map[manifest.id] = manifest.plugin_dependency_ids + return dependency_map + + def read_plugin_id_from_plugin_path(self, plugin_path: Path, require_entrypoint: bool = True) -> Optional[str]: + """从单个插件目录中读取规范化后的插件 ID。 Args: - manifest: 待校验的 Manifest 数据。 - """ - host_app = manifest.get("host_application") - if not isinstance(host_app, dict) or not self._host_version: - return - min_v = host_app.get("min_version", "") - max_v = host_app.get("max_version", "") - ok, msg = VersionComparator.is_in_range(self._host_version, min_v, max_v) - if not ok: - self.errors.append(f"Host 版本不兼容: {msg} (当前 Host: {self._host_version})") + plugin_path: 单个插件目录路径。 + require_entrypoint: 是否要求目录内存在 ``plugin.py``。 - def _check_recommended(self, manifest: Dict[str, Any]) -> None: - """检查推荐字段是否齐备,并记录为警告而非错误。 + Returns: + Optional[str]: 解析成功时返回插件 ID,否则返回 ``None``。 + """ + manifest = self.load_from_plugin_path(plugin_path, require_entrypoint=require_entrypoint) + if manifest is None: + return None + return manifest.id + + def get_unsatisfied_plugin_dependencies( + self, + manifest: PluginManifest, + available_plugin_versions: Dict[str, str], + ) -> List[str]: + """返回当前 Manifest 尚未满足的插件依赖项。 Args: - manifest: 待校验的 Manifest 数据。 + manifest: 目标插件的强类型 Manifest。 + available_plugin_versions: 当前可用插件版本映射,键为插件 ID,值为插件版本。 + + Returns: + List[str]: 未满足依赖的错误描述列表。 """ - for field in self.RECOMMENDED_FIELDS: - if field not in manifest or not manifest[field]: - self.warnings.append(f"建议填写字段: {field}") + unsatisfied_dependencies: List[str] = [] + for dependency in manifest.plugin_dependencies: + dependency_version = available_plugin_versions.get(dependency.id) + if not dependency_version: + unsatisfied_dependencies.append(f"{dependency.id} (未找到依赖插件)") + continue + + if not self._version_matches_specifier(dependency_version, dependency.version_spec): + unsatisfied_dependencies.append( + f"{dependency.id} (需要 {dependency.version_spec},当前 {dependency_version})" + ) + + return unsatisfied_dependencies + + def is_plugin_dependency_satisfied( + self, + dependency: PluginDependencyDefinition, + plugin_version: str, + ) -> bool: + """判断单个插件依赖是否被指定版本满足。 + + Args: + dependency: 插件级依赖声明。 + plugin_version: 当前可用的插件版本号。 + + Returns: + bool: 是否满足版本约束。 + """ + return self._version_matches_specifier(plugin_version, dependency.version_spec) + + def _validate_runtime_compatibility(self, manifest: PluginManifest) -> None: + """校验运行时版本兼容性与 Python 包依赖。 + + Args: + manifest: 已通过结构校验的强类型 Manifest。 + """ + host_ok, host_message = VersionComparator.is_in_range( + self._host_version, + manifest.host_application.min_version, + manifest.host_application.max_version, + ) + if not host_ok: + self.errors.append(f"Host 版本不兼容: {host_message} (当前 Host: {self._host_version})") + + sdk_ok, sdk_message = VersionComparator.is_in_range( + self._sdk_version, + manifest.sdk.min_version, + manifest.sdk.max_version, + ) + if not sdk_ok: + self.errors.append(f"SDK 版本不兼容: {sdk_message} (当前 SDK: {self._sdk_version})") + + self._validate_python_package_dependencies(manifest) + + def _validate_python_package_dependencies(self, manifest: PluginManifest) -> None: + """校验 Python 包依赖与主程序运行环境是否冲突。 + + Args: + manifest: 已通过结构校验的强类型 Manifest。 + """ + host_requirements = self._load_host_dependency_requirements(self._project_root) + + for dependency in manifest.python_package_dependencies: + normalized_package_name = canonicalize_name(dependency.name) + package_specifier = self._build_specifier_set(dependency.version_spec) + if package_specifier is None: + self.errors.append( + f"Python 包依赖 {dependency.name} 的版本约束无效: {dependency.version_spec}" + ) + continue + + installed_version = self._get_installed_package_version(dependency.name) + host_requirement = host_requirements.get(normalized_package_name) + + if installed_version is not None and not self._version_matches_specifier( + installed_version, + dependency.version_spec, + ): + self.errors.append( + f"Python 包依赖冲突: {dependency.name} 需要 {dependency.version_spec}," + f"当前运行环境为 {installed_version}" + ) + continue + + if host_requirement is None: + continue + + if not self._requirements_may_overlap(host_requirement.specifier, package_specifier): + host_specifier = str(host_requirement.specifier or "") + self.errors.append( + f"Python 包依赖冲突: {dependency.name} 需要 {dependency.version_spec}," + f"主程序依赖约束为 {host_specifier or '任意版本'}" + ) + + def _log_errors(self) -> None: + """输出当前累计的 Manifest 校验错误。""" + for error_message in self.errors: + logger.error(f"Manifest 校验失败: {error_message}") + + @classmethod + def _resolve_project_root(cls) -> Path: + """推断当前项目根目录。 + + Returns: + Path: 项目根目录路径。 + """ + return Path(__file__).resolve().parents[3] + + @classmethod + @lru_cache(maxsize=None) + def _detect_default_host_version(cls, project_root: Path) -> str: + """从主程序 ``pyproject.toml`` 探测 Host 版本号。 + + Args: + project_root: 项目根目录。 + + Returns: + str: 探测到的 Host 版本号;失败时返回空字符串。 + """ + pyproject_path = project_root / "pyproject.toml" + try: + with pyproject_path.open("rb") as pyproject_file: + pyproject_data = tomllib.load(pyproject_file) + except Exception: + return "" + + project_data = pyproject_data.get("project", {}) + if not isinstance(project_data, dict): + return "" + + raw_version = str(project_data.get("version", "") or "").strip() + if VersionComparator.is_valid_semver(raw_version): + return raw_version + return "" + + @classmethod + @lru_cache(maxsize=None) + def _detect_default_sdk_version(cls, project_root: Path) -> str: + """探测当前运行环境中的 SDK 版本号。 + + Args: + project_root: 项目根目录。 + + Returns: + str: 探测到的 SDK 版本号;失败时返回空字符串。 + """ + try: + raw_version = importlib_metadata.version("maibot-plugin-sdk") + if VersionComparator.is_valid_semver(raw_version): + return raw_version + except importlib_metadata.PackageNotFoundError: + pass + + sdk_pyproject_path = project_root / "packages" / "maibot-plugin-sdk" / "pyproject.toml" + try: + with sdk_pyproject_path.open("rb") as pyproject_file: + pyproject_data = tomllib.load(pyproject_file) + except Exception: + return "" + + project_data = pyproject_data.get("project", {}) + if not isinstance(project_data, dict): + return "" + + raw_version = str(project_data.get("version", "") or "").strip() + if VersionComparator.is_valid_semver(raw_version): + return raw_version + return "" + + @classmethod + @lru_cache(maxsize=None) + def _load_host_dependency_requirements(cls, project_root: Path) -> Dict[str, Requirement]: + """加载主程序 ``pyproject.toml`` 中声明的依赖约束。 + + Args: + project_root: 项目根目录。 + + Returns: + Dict[str, Requirement]: 以规范化包名为键的 Requirement 映射。 + """ + pyproject_path = project_root / "pyproject.toml" + try: + with pyproject_path.open("rb") as pyproject_file: + pyproject_data = tomllib.load(pyproject_file) + except Exception: + return {} + + project_data = pyproject_data.get("project", {}) + if not isinstance(project_data, dict): + return {} + + raw_dependencies = project_data.get("dependencies", []) + if not isinstance(raw_dependencies, list): + return {} + + requirements: Dict[str, Requirement] = {} + for raw_dependency in raw_dependencies: + dependency_text = str(raw_dependency or "").strip() + if not dependency_text: + continue + + try: + requirement = Requirement(dependency_text) + except InvalidRequirement: + continue + + requirements[canonicalize_name(requirement.name)] = requirement + + return requirements + + @staticmethod + def _get_installed_package_version(package_name: str) -> Optional[str]: + """获取当前运行环境中指定 Python 包的安装版本。 + + Args: + package_name: 待查询的包名。 + + Returns: + Optional[str]: 已安装版本号;未安装时返回 ``None``。 + """ + try: + return importlib_metadata.version(package_name) + except importlib_metadata.PackageNotFoundError: + return None + + @staticmethod + def _build_specifier_set(version_spec: str) -> Optional[SpecifierSet]: + """构造版本约束对象。 + + Args: + version_spec: 版本约束字符串。 + + Returns: + Optional[SpecifierSet]: 构造成功时返回约束对象,否则返回 ``None``。 + """ + try: + return SpecifierSet(version_spec) + except InvalidSpecifier: + return None + + @staticmethod + def _version_matches_specifier(version: str, version_spec: str) -> bool: + """判断版本是否满足给定约束。 + + Args: + version: 待判断的版本号。 + version_spec: 版本约束表达式。 + + Returns: + bool: 是否满足约束。 + """ + try: + normalized_version = Version(version) + specifier_set = SpecifierSet(version_spec) + except (InvalidVersion, InvalidSpecifier): + return False + return specifier_set.contains(normalized_version, prereleases=True) + + @classmethod + def _requirements_may_overlap(cls, left: SpecifierSet, right: SpecifierSet) -> bool: + """粗略判断两个版本约束是否存在交集。 + + Args: + left: 左侧版本约束。 + right: 右侧版本约束。 + + Returns: + bool: 若可能存在交集则返回 ``True``,否则返回 ``False``。 + """ + candidate_versions = cls._build_candidate_versions(left, right) + for candidate_version in candidate_versions: + if left.contains(candidate_version, prereleases=True) and right.contains(candidate_version, prereleases=True): + return True + return False + + @classmethod + def _build_candidate_versions(cls, left: SpecifierSet, right: SpecifierSet) -> List[Version]: + """为两个版本约束构造一组用于交集探测的候选版本。 + + Args: + left: 左侧版本约束。 + right: 右侧版本约束。 + + Returns: + List[Version]: 去重后的候选版本列表。 + """ + candidate_versions: List[Version] = [Version("0.0.0")] + for specifier in tuple(left) + tuple(right): + for candidate_version in cls._expand_candidate_versions(specifier.version): + if candidate_version not in candidate_versions: + candidate_versions.append(candidate_version) + return candidate_versions + + @staticmethod + def _expand_candidate_versions(raw_version: str) -> List[Version]: + """根据边界版本扩展出一组邻近候选版本。 + + Args: + raw_version: 约束中出现的边界版本字符串。 + + Returns: + List[Version]: 可用于交集探测的候选版本列表。 + """ + normalized_text = raw_version.replace("*", "0") + try: + boundary_version = Version(normalized_text) + except InvalidVersion: + return [] + + release_parts = list(boundary_version.release[:3]) + while len(release_parts) < 3: + release_parts.append(0) + major, minor, patch = release_parts[:3] + + candidates = { + Version(f"{major}.{minor}.{patch}"), + Version(f"{major}.{minor}.{patch + 1}"), + } + if patch > 0: + candidates.add(Version(f"{major}.{minor}.{patch - 1}")) + elif minor > 0: + candidates.add(Version(f"{major}.{minor - 1}.999")) + elif major > 0: + candidates.add(Version(f"{major - 1}.999.999")) + + return sorted(candidates) + + @classmethod + def _format_validation_errors(cls, exc: ValidationError) -> List[str]: + """将 Pydantic 校验错误转换为中文错误列表。 + + Args: + exc: Pydantic 抛出的校验异常。 + + Returns: + List[str]: 中文错误描述列表。 + """ + error_messages: List[str] = [] + for error in exc.errors(): + location = cls._format_error_location(error.get("loc", ())) + error_type = str(error.get("type", "")) + error_input = error.get("input") + error_context = error.get("ctx", {}) or {} + + if error_type == "missing": + error_messages.append(f"缺少必需字段: {location}") + elif error_type == "extra_forbidden": + error_messages.append(f"存在未声明字段: {location}") + elif error_type == "literal_error": + expected_values = error_context.get("expected") + error_messages.append(f"字段 {location} 的值不合法,必须为 {expected_values}") + elif error_type == "model_type": + error_messages.append(f"字段 {location} 必须为对象") + elif error_type.endswith("_type"): + error_messages.append(f"字段 {location} 的类型不正确") + elif error_type == "value_error": + error_messages.append(f"字段 {location} 校验失败: {error_context.get('error')}") + else: + error_messages.append(f"字段 {location} 校验失败: {error.get('msg', error_input)}") + + return error_messages + + @staticmethod + def _format_error_location(location: Tuple[Any, ...]) -> str: + """格式化校验错误字段路径。 + + Args: + location: Pydantic 提供的字段路径元组。 + + Returns: + str: 点号连接后的字段路径。 + """ + return ".".join(str(item) for item in location) if location else "" diff --git a/src/plugin_runtime/runner/plugin_loader.py b/src/plugin_runtime/runner/plugin_loader.py index f07eb593..3eaf9f23 100644 --- a/src/plugin_runtime/runner/plugin_loader.py +++ b/src/plugin_runtime/runner/plugin_loader.py @@ -13,16 +13,16 @@ from typing import Any, Dict, Iterator, List, Optional, Set, Tuple import contextlib import importlib import importlib.util -import json import os +import re import sys from src.common.logger import get_logger -from src.plugin_runtime.runner.manifest_validator import ManifestValidator +from src.plugin_runtime.runner.manifest_validator import ManifestValidator, PluginManifest logger = get_logger("plugin_runtime.runner.plugin_loader") -PluginCandidate = Tuple[Path, Dict[str, Any], Path] +PluginCandidate = Tuple[Path, PluginManifest, Path] class PluginMeta: @@ -34,7 +34,7 @@ class PluginMeta: plugin_dir: str, module_name: str, plugin_instance: Any, - manifest: Dict[str, Any], + manifest: PluginManifest, ) -> None: """初始化插件元数据。 @@ -43,36 +43,16 @@ class PluginMeta: plugin_dir: 插件目录绝对路径。 module_name: 插件入口模块名。 plugin_instance: 插件实例对象。 - manifest: 解析后的 manifest 内容。 + manifest: 解析后的强类型 Manifest。 """ self.plugin_id = plugin_id self.plugin_dir = plugin_dir self.module_name = module_name self.instance = plugin_instance self.manifest = manifest - self.version = manifest.get("version", "1.0.0") - self.capabilities_required = manifest.get("capabilities", []) - self.dependencies: List[str] = self._extract_dependencies(manifest) - - @staticmethod - def _extract_dependencies(manifest: Dict[str, Any]) -> List[str]: - """从 manifest 中提取依赖列表。 - - Args: - manifest: 插件 manifest。 - - Returns: - List[str]: 规范化后的依赖插件 ID 列表。 - """ - raw = manifest.get("dependencies", []) - result: List[str] = [] - for dep in raw: - if isinstance(dep, str): - result.append(dep.strip()) - elif isinstance(dep, dict): - if name := str(dep.get("name", "")).strip(): - result.append(name) - return result + self.version = manifest.version + self.capabilities_required = list(manifest.capabilities) + self.dependencies: List[str] = list(manifest.plugin_dependency_ids) class PluginLoader: @@ -98,13 +78,13 @@ class PluginLoader: def discover_and_load( self, plugin_dirs: List[str], - extra_available: Optional[Set[str]] = None, + extra_available: Optional[Dict[str, str]] = None, ) -> List[PluginMeta]: """扫描多个目录并加载所有插件。 Args: plugin_dirs: 插件目录列表。 - extra_available: 额外视为已满足的外部依赖插件 ID 集合。 + extra_available: 额外视为已满足的外部依赖插件版本映射。 Returns: List[PluginMeta]: 成功加载的插件元数据列表,按依赖顺序排列。 @@ -164,26 +144,17 @@ class PluginLoader: def _discover_single_candidate(self, plugin_dir: Path) -> Optional[Tuple[str, PluginCandidate]]: """发现并校验单个插件目录。""" - manifest_path = plugin_dir / "_manifest.json" plugin_path = plugin_dir / "plugin.py" - - if not manifest_path.exists() or not plugin_path.exists(): + if not plugin_path.exists(): return None - try: - with manifest_path.open("r", encoding="utf-8") as manifest_file: - manifest: Dict[str, Any] = json.load(manifest_file) - except Exception as e: - self._failed_plugins[plugin_dir.name] = f"manifest 解析失败: {e}" - logger.error(f"插件 {plugin_dir.name} manifest 解析失败: {e}") - return None - - if not self._manifest_validator.validate(manifest): + manifest = self._manifest_validator.load_from_plugin_path(plugin_dir) + if manifest is None: errors = "; ".join(self._manifest_validator.errors) self._failed_plugins[plugin_dir.name] = f"manifest 校验失败: {errors}" return None - plugin_id = str(manifest.get("name", plugin_dir.name)).strip() or plugin_dir.name + plugin_id = manifest.id return plugin_id, (plugin_dir, manifest, plugin_path) def _record_duplicate_candidates(self, duplicate_candidates: Dict[str, List[Path]]) -> None: @@ -253,7 +224,7 @@ class PluginLoader: """ removed_modules: List[str] = [] plugin_path = Path(plugin_dir).resolve() - synthetic_module_name = f"_maibot_plugin_{plugin_id}" + synthetic_module_name = self._build_safe_module_name(plugin_id) for module_name, module in list(sys.modules.items()): if module_name == synthetic_module_name: @@ -277,6 +248,21 @@ class PluginLoader: importlib.invalidate_caches() return removed_modules + @staticmethod + def _build_safe_module_name(plugin_id: str) -> str: + """将插件 ID 转换为可用于动态导入的安全模块名。 + + Args: + plugin_id: 原始插件 ID。 + + Returns: + str: 仅包含字母、数字和下划线的合成模块名。 + """ + normalized_plugin_id = re.sub(r"[^0-9A-Za-z_]", "_", str(plugin_id or "").strip()) + if normalized_plugin_id and normalized_plugin_id[0].isdigit(): + normalized_plugin_id = f"_{normalized_plugin_id}" + return f"_maibot_plugin_{normalized_plugin_id or 'plugin'}" + def list_plugins(self) -> List[str]: """列出所有已加载的插件 ID""" return list(self._loaded_plugins.keys()) @@ -286,18 +272,27 @@ class PluginLoader: """返回当前记录的失败插件原因映射。""" return dict(self._failed_plugins) + @property + def manifest_validator(self) -> ManifestValidator: + """返回当前加载器持有的 Manifest 校验器。 + + Returns: + ManifestValidator: 当前使用的 Manifest 校验器实例。 + """ + return self._manifest_validator + # ──── 依赖解析 ──────────────────────────────────────────── def resolve_dependencies( self, candidates: Dict[str, PluginCandidate], - extra_available: Optional[Set[str]] = None, + extra_available: Optional[Dict[str, str]] = None, ) -> Tuple[List[str], Dict[str, str]]: """解析候选插件的依赖顺序。 Args: candidates: 待加载的候选插件集合。 - extra_available: 视为已满足的外部依赖插件 ID 集合。 + extra_available: 视为已满足的外部依赖插件版本映射。 Returns: Tuple[List[str], Dict[str, str]]: 可加载顺序和失败原因映射。 @@ -320,36 +315,71 @@ class PluginLoader: def _resolve_dependencies( self, candidates: Dict[str, PluginCandidate], - extra_available: Optional[Set[str]] = None, + extra_available: Optional[Dict[str, str]] = None, ) -> Tuple[List[str], Dict[str, str]]: """拓扑排序解析加载顺序,返回 (有序列表, 失败项 {id: reason})。""" available = set(candidates.keys()) - satisfied_dependencies = set(extra_available or set()) + satisfied_dependencies = { + str(plugin_id or "").strip(): str(plugin_version or "").strip() + for plugin_id, plugin_version in (extra_available or {}).items() + if str(plugin_id or "").strip() and str(plugin_version or "").strip() + } dep_graph: Dict[str, Set[str]] = {} failed: Dict[str, str] = {} for pid, (_, manifest, _) in candidates.items(): - raw_deps = manifest.get("dependencies", []) resolved: Set[str] = set() - missing: List[str] = [] - for dep in raw_deps: - dep_name = dep if isinstance(dep, str) else str(dep.get("name", "")) - dep_name = dep_name.strip() - if not dep_name or dep_name == pid: + missing_or_incompatible: List[str] = [] + + for dependency in manifest.plugin_dependencies: + dependency_id = dependency.id + if dependency_id in available: + dependency_manifest = candidates[dependency_id][1] + if not self._manifest_validator.is_plugin_dependency_satisfied( + dependency, + dependency_manifest.version, + ): + missing_or_incompatible.append( + f"{dependency_id} (需要 {dependency.version_spec},当前 {dependency_manifest.version})" + ) + continue + resolved.add(dependency_id) continue - if dep_name in available: - resolved.add(dep_name) - elif dep_name in satisfied_dependencies: + + external_dependency_version = satisfied_dependencies.get(dependency_id) + if external_dependency_version is None: + missing_or_incompatible.append(f"{dependency_id} (未找到依赖插件)") continue - else: - missing.append(dep_name) - if missing: - failed[pid] = f"缺少依赖: {', '.join(missing)}" + + if not self._manifest_validator.is_plugin_dependency_satisfied( + dependency, + external_dependency_version, + ): + missing_or_incompatible.append( + f"{dependency_id} (需要 {dependency.version_spec},当前 {external_dependency_version})" + ) + + if missing_or_incompatible: + failed[pid] = f"依赖未满足: {', '.join(missing_or_incompatible)}" dep_graph[pid] = resolved - # 移除失败项 - for pid in failed: - dep_graph.pop(pid, None) + # 迭代传播“依赖自身加载失败”到上游依赖方,避免误报为循环依赖 + changed = True + while changed: + changed = False + failed_plugin_ids = set(failed) + for pid, dependencies in list(dep_graph.items()): + if pid in failed: + dep_graph.pop(pid, None) + continue + + failed_dependencies = sorted(dependency for dependency in dependencies if dependency in failed_plugin_ids) + if not failed_dependencies: + continue + + failed[pid] = f"依赖未满足: {', '.join(f'{dependency} (依赖插件加载失败)' for dependency in failed_dependencies)}" + dep_graph.pop(pid, None) + changed = True # Kahn 拓扑排序 indegree = {pid: len(deps) for pid, deps in dep_graph.items()} @@ -382,7 +412,7 @@ class PluginLoader: self, plugin_id: str, plugin_dir: Path, - manifest: Dict[str, Any], + manifest: PluginManifest, plugin_path: Path, ) -> Optional[PluginMeta]: """加载单个插件""" @@ -390,7 +420,7 @@ class PluginLoader: self._ensure_compat_hook() # 动态导入插件模块 - module_name = f"_maibot_plugin_{plugin_id}" + module_name = self._build_safe_module_name(plugin_id) spec = importlib.util.spec_from_file_location(module_name, str(plugin_path)) if spec is None or spec.loader is None: logger.error(f"无法创建模块 spec: {plugin_path}") @@ -409,7 +439,7 @@ class PluginLoader: if create_plugin is not None: instance = create_plugin() self._validate_sdk_plugin_contract(plugin_id, instance) - logger.info(f"插件 {plugin_id} v{manifest.get('version', '?')} 加载成功") + logger.info(f"插件 {plugin_id} v{manifest.version} 加载成功") return PluginMeta( plugin_id=plugin_id, plugin_dir=str(plugin_dir), @@ -422,7 +452,7 @@ class PluginLoader: instance = self._try_load_legacy_plugin(module, plugin_id) if instance is not None: logger.info( - f"插件 {plugin_id} v{manifest.get('version', '?')} 通过旧版兼容层加载成功(请尽快迁移到 maibot_sdk)" + f"插件 {plugin_id} v{manifest.version} 通过旧版兼容层加载成功(请尽快迁移到 maibot_sdk)" ) return PluginMeta( plugin_id=plugin_id, diff --git a/src/plugin_runtime/runner/runner_main.py b/src/plugin_runtime/runner/runner_main.py index c0f5e771..f5c32f7e 100644 --- a/src/plugin_runtime/runner/runner_main.py +++ b/src/plugin_runtime/runner/runner_main.py @@ -10,7 +10,7 @@ """ from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, cast +from typing import Any, Callable, Dict, List, Optional, Protocol, Set, cast import asyncio import contextlib @@ -47,7 +47,7 @@ from src.plugin_runtime.protocol.envelope import ( ) from src.plugin_runtime.protocol.errors import ErrorCode from src.plugin_runtime.runner.log_handler import RunnerIPCLogHandler -from src.plugin_runtime.runner.plugin_loader import PluginLoader, PluginMeta +from src.plugin_runtime.runner.plugin_loader import PluginCandidate, PluginLoader, PluginMeta from src.plugin_runtime.runner.rpc_client import RPCClient logger = get_logger("plugin_runtime.runner.main") @@ -119,7 +119,7 @@ class PluginRunner: host_address: str, session_token: str, plugin_dirs: List[str], - external_available_plugin_ids: Optional[List[str]] = None, + external_available_plugins: Optional[Dict[str, str]] = None, ) -> None: """初始化 Runner。 @@ -127,15 +127,15 @@ class PluginRunner: host_address: Host 的 IPC 地址。 session_token: 握手用会话令牌。 plugin_dirs: 当前 Runner 负责扫描的插件目录列表。 - external_available_plugin_ids: 视为已满足的外部依赖插件 ID 列表。 + external_available_plugins: 视为已满足的外部依赖插件版本映射。 """ self._host_address: str = host_address self._session_token: str = session_token self._plugin_dirs: List[str] = plugin_dirs - self._external_available_plugin_ids: Set[str] = { - str(plugin_id or "").strip() - for plugin_id in (external_available_plugin_ids or []) - if str(plugin_id or "").strip() + self._external_available_plugins: Dict[str, str] = { + str(plugin_id or "").strip(): str(plugin_version or "").strip() + for plugin_id, plugin_version in (external_available_plugins or {}).items() + if str(plugin_id or "").strip() and str(plugin_version or "").strip() } self._rpc_client: RPCClient = RPCClient(host_address, session_token) @@ -166,7 +166,7 @@ class PluginRunner: # 3. 加载插件 plugins = self._loader.discover_and_load( self._plugin_dirs, - extra_available=self._external_available_plugin_ids, + extra_available=self._external_available_plugins, ) logger.info(f"已加载 {len(plugins)} 个插件") @@ -611,14 +611,14 @@ class PluginRunner: self, plugin_id: str, reason: str, - external_available_plugins: Optional[Set[str]] = None, + external_available_plugins: Optional[Dict[str, str]] = None, ) -> ReloadPluginResultPayload: """按插件 ID 在 Runner 进程内执行精确重载。 Args: plugin_id: 目标插件 ID。 reason: 重载原因。 - external_available_plugins: 视为已满足的外部依赖插件 ID 集合。 + external_available_plugins: 视为已满足的外部依赖插件版本映射。 Returns: ReloadPluginResultPayload: 结构化重载结果。 @@ -626,9 +626,9 @@ class PluginRunner: candidates, duplicate_candidates = self._loader.discover_candidates(self._plugin_dirs) failed_plugins: Dict[str, str] = {} normalized_external_available = { - str(candidate_plugin_id or "").strip() - for candidate_plugin_id in (external_available_plugins or set()) - if str(candidate_plugin_id or "").strip() + str(candidate_plugin_id or "").strip(): str(candidate_plugin_version or "").strip() + for candidate_plugin_id, candidate_plugin_version in (external_available_plugins or {}).items() + if str(candidate_plugin_id or "").strip() and str(candidate_plugin_version or "").strip() } if plugin_id in duplicate_candidates: @@ -668,7 +668,7 @@ class PluginRunner: self._loader.purge_plugin_modules(unload_plugin_id, meta.plugin_dir) unloaded_plugins.append(unload_plugin_id) - reload_candidates: Dict[str, Tuple[Path, Dict[str, Any], Path]] = {} + reload_candidates: Dict[str, PluginCandidate] = {} for target_plugin_id in target_plugin_ids: candidate = candidates.get(target_plugin_id) if candidate is None: @@ -678,11 +678,25 @@ class PluginRunner: load_order, dependency_failures = self._loader.resolve_dependencies( reload_candidates, - extra_available=retained_plugin_ids | normalized_external_available, + extra_available={ + **normalized_external_available, + **{ + retained_plugin_id: retained_meta.version + for retained_plugin_id in retained_plugin_ids + if (retained_meta := self._loader.get_plugin(retained_plugin_id)) is not None + }, + }, ) failed_plugins.update(dependency_failures) - available_plugins = set(retained_plugin_ids) | normalized_external_available + available_plugins = { + **normalized_external_available, + **{ + retained_plugin_id: retained_meta.version + for retained_plugin_id in retained_plugin_ids + if (retained_meta := self._loader.get_plugin(retained_plugin_id)) is not None + }, + } reloaded_plugins: List[str] = [] for load_plugin_id in load_order: @@ -694,10 +708,12 @@ class PluginRunner: continue _, manifest, _ = candidate - dependencies = PluginMeta._extract_dependencies(manifest) - missing_dependencies = [dependency for dependency in dependencies if dependency not in available_plugins] - if missing_dependencies: - failed_plugins[load_plugin_id] = f"依赖未满足: {', '.join(missing_dependencies)}" + unsatisfied_dependencies = self._loader.manifest_validator.get_unsatisfied_plugin_dependencies( + manifest, + available_plugin_versions=available_plugins, + ) + if unsatisfied_dependencies: + failed_plugins[load_plugin_id] = f"依赖未满足: {', '.join(unsatisfied_dependencies)}" continue meta = self._loader.load_candidate(load_plugin_id, candidate) @@ -710,7 +726,7 @@ class PluginRunner: failed_plugins[load_plugin_id] = "插件初始化失败" continue - available_plugins.add(load_plugin_id) + available_plugins[load_plugin_id] = meta.version reloaded_plugins.append(load_plugin_id) if failed_plugins: @@ -1079,7 +1095,7 @@ class PluginRunner: result = await self._reload_plugin_by_id( payload.plugin_id, payload.reason, - external_available_plugins=set(payload.external_available_plugins), + external_available_plugins=dict(payload.external_available_plugins), ) return envelope.make_response(payload=result.model_dump()) @@ -1185,13 +1201,13 @@ async def _async_main() -> None: plugin_dirs = [d for d in plugin_dirs_str.split(os.pathsep) if d] try: - external_plugin_ids = json.loads(external_plugin_ids_raw) if external_plugin_ids_raw else [] + external_plugin_ids = json.loads(external_plugin_ids_raw) if external_plugin_ids_raw else {} except json.JSONDecodeError: - logger.warning("解析外部依赖插件列表失败,已回退为空列表") - external_plugin_ids = [] - if not isinstance(external_plugin_ids, list): - logger.warning("外部依赖插件列表格式非法,已回退为空列表") - external_plugin_ids = [] + logger.warning("解析外部依赖插件版本映射失败,已回退为空映射") + external_plugin_ids = {} + if not isinstance(external_plugin_ids, dict): + logger.warning("外部依赖插件版本映射格式非法,已回退为空映射") + external_plugin_ids = {} # sys.path 隔离: 只保留标准库、SDK 包、插件目录 _isolate_sys_path(plugin_dirs) @@ -1200,7 +1216,10 @@ async def _async_main() -> None: host_address, session_token, plugin_dirs, - external_available_plugin_ids=[str(plugin_id) for plugin_id in external_plugin_ids], + external_available_plugins={ + str(plugin_id): str(plugin_version) + for plugin_id, plugin_version in external_plugin_ids.items() + }, ) # 注册信号处理 diff --git a/src/plugins/built_in/emoji_plugin/_manifest.json b/src/plugins/built_in/emoji_plugin/_manifest.json index d4d262e7..5b53abad 100644 --- a/src/plugins/built_in/emoji_plugin/_manifest.json +++ b/src/plugins/built_in/emoji_plugin/_manifest.json @@ -1,32 +1,28 @@ { - "manifest_version": 1, - "name": "Emoji插件 (Emoji Actions)", + "manifest_version": 2, "version": "2.0.0", - "description": "可以发送和管理Emoji", + "name": "Emoji插件 (Emoji Actions)", + "description": "可以发送和管理 Emoji", "author": { "name": "SengokuCola", "url": "https://github.com/MaiM-with-u" }, "license": "GPL-v3.0-or-later", + "urls": { + "repository": "https://github.com/MaiM-with-u/maibot", + "homepage": "https://github.com/MaiM-with-u/maibot", + "documentation": "https://github.com/MaiM-with-u/maibot", + "issues": "https://github.com/MaiM-with-u/maibot/issues" + }, "host_application": { - "min_version": "1.0.0" + "min_version": "1.0.0", + "max_version": "1.0.0" }, - "homepage_url": "https://github.com/MaiM-with-u/maibot", - "repository_url": "https://github.com/MaiM-with-u/maibot", - "keywords": ["emoji", "action", "built-in"], - "categories": ["Emoji"], - "default_locale": "zh-CN", - "plugin_info": { - "is_built_in": true, - "plugin_type": "action_provider", - "components": [ - { - "type": "action", - "name": "emoji", - "description": "发送表情包辅助表达情绪" - } - ] + "sdk": { + "min_version": "2.0.0", + "max_version": "2.99.99" }, + "dependencies": [], "capabilities": [ "emoji.get_random", "message.get_recent", @@ -34,5 +30,12 @@ "llm.generate", "send.emoji", "config.get" - ] + ], + "i18n": { + "default_locale": "zh-CN", + "supported_locales": [ + "zh-CN" + ] + }, + "id": "builtin.emoji-plugin" } diff --git a/src/plugins/built_in/plugin_management/_manifest.json b/src/plugins/built_in/plugin_management/_manifest.json index a5b52835..a2bfa9ce 100644 --- a/src/plugins/built_in/plugin_management/_manifest.json +++ b/src/plugins/built_in/plugin_management/_manifest.json @@ -1,51 +1,46 @@ { - "manifest_version": 1, - "name": "插件和组件管理 (Plugin and Component Management)", + "manifest_version": 2, "version": "2.0.0", - "description": "通过系统API管理插件和组件的生命周期,包括加载、卸载、启用和禁用等操作。", + "name": "插件和组件管理 (Plugin and Component Management)", + "description": "通过系统 API 管理插件和组件的生命周期,包括加载、卸载、启用和禁用等操作。", "author": { "name": "MaiBot团队", "url": "https://github.com/MaiM-with-u" }, "license": "GPL-v3.0-or-later", - "host_application": { - "min_version": "1.0.0" + "urls": { + "repository": "https://github.com/MaiM-with-u/maibot", + "homepage": "https://github.com/MaiM-with-u/maibot", + "documentation": "https://github.com/MaiM-with-u/maibot", + "issues": "https://github.com/MaiM-with-u/maibot/issues" }, - "homepage_url": "https://github.com/MaiM-with-u/maibot", - "repository_url": "https://github.com/MaiM-with-u/maibot", - "keywords": [ - "plugins", - "components", - "management", - "built-in" + "host_application": { + "min_version": "1.0.0", + "max_version": "1.0.0" + }, + "sdk": { + "min_version": "2.0.0", + "max_version": "2.99.99" + }, + "dependencies": [], + "capabilities": [ + "component.get_all_plugins", + "component.list_loaded_plugins", + "component.list_registered_plugins", + "component.enable", + "component.disable", + "component.load_plugin", + "component.unload_plugin", + "component.reload_plugin", + "send.text", + "config.get" ], - "categories": [ - "Core System", - "Plugin Management" - ], - "default_locale": "zh-CN", - "locales_path": "_locales", - "plugin_info": { - "is_built_in": true, - "plugin_type": "plugin_management", - "capabilities": [ - "component.get_all_plugins", - "component.list_loaded_plugins", - "component.list_registered_plugins", - "component.enable", - "component.disable", - "component.load_plugin", - "component.unload_plugin", - "component.reload_plugin", - "send.text", - "config.get" - ], - "components": [ - { - "type": "command", - "name": "management", - "description": "管理插件和组件的生命周期,包括加载、卸载、启用和禁用等操作。" - } + "i18n": { + "default_locale": "zh-CN", + "locales_path": "_locales", + "supported_locales": [ + "zh-CN" ] - } -} \ No newline at end of file + }, + "id": "builtin.plugin-management" +} From a61b124c93e456196216fc5724112e62867aedac Mon Sep 17 00:00:00 2001 From: DrSmoothl <1787882683@qq.com> Date: Mon, 23 Mar 2026 23:05:05 +0800 Subject: [PATCH 38/45] feat: enhance global config handling and component resolution in plugin runtime --- src/plugin_runtime/host/supervisor.py | 4 +- src/plugin_runtime/integration.py | 4 +- src/plugin_runtime/runner/plugin_loader.py | 7 +- src/plugin_runtime/runner/runner_main.py | 90 ++++++++++++++++------ 4 files changed, 75 insertions(+), 30 deletions(-) diff --git a/src/plugin_runtime/host/supervisor.py b/src/plugin_runtime/host/supervisor.py index ac953bb3..d12014f6 100644 --- a/src/plugin_runtime/host/supervisor.py +++ b/src/plugin_runtime/host/supervisor.py @@ -1146,8 +1146,8 @@ class PluginRunnerSupervisor: Returns: Dict[str, str]: 传递给 Runner 进程的环境变量映射。 """ - global_config_snapshot = config_manager.get_global_config().model_dump() - global_config_snapshot["model"] = config_manager.get_model_config().model_dump() + global_config_snapshot = config_manager.get_global_config().model_dump(mode="json") + global_config_snapshot["model"] = config_manager.get_model_config().model_dump(mode="json") return { ENV_EXTERNAL_PLUGIN_IDS: json.dumps(self._external_available_plugins, ensure_ascii=False), ENV_GLOBAL_CONFIG_SNAPSHOT: json.dumps(global_config_snapshot, ensure_ascii=False), diff --git a/src/plugin_runtime/integration.py b/src/plugin_runtime/integration.py index 092b9597..c34f5ef5 100644 --- a/src/plugin_runtime/integration.py +++ b/src/plugin_runtime/integration.py @@ -543,9 +543,9 @@ class PluginRuntimeManager( normalized_scopes = self._normalize_config_reload_scopes(changed_scopes) if "bot" in normalized_scopes: - await self._broadcast_config_reload("bot", config_manager.get_global_config().model_dump()) + await self._broadcast_config_reload("bot", config_manager.get_global_config().model_dump(mode="json")) if "model" in normalized_scopes: - await self._broadcast_config_reload("model", config_manager.get_model_config().model_dump()) + await self._broadcast_config_reload("model", config_manager.get_model_config().model_dump(mode="json")) # ─── 事件桥接 ────────────────────────────────────────────── diff --git a/src/plugin_runtime/runner/plugin_loader.py b/src/plugin_runtime/runner/plugin_loader.py index 3eaf9f23..6e85714b 100644 --- a/src/plugin_runtime/runner/plugin_loader.py +++ b/src/plugin_runtime/runner/plugin_loader.py @@ -53,6 +53,7 @@ class PluginMeta: self.version = manifest.version self.capabilities_required = list(manifest.capabilities) self.dependencies: List[str] = list(manifest.plugin_dependency_ids) + self.component_handlers: Dict[str, str] = {} class PluginLoader: @@ -421,7 +422,11 @@ class PluginLoader: # 动态导入插件模块 module_name = self._build_safe_module_name(plugin_id) - spec = importlib.util.spec_from_file_location(module_name, str(plugin_path)) + spec = importlib.util.spec_from_file_location( + module_name, + str(plugin_path), + submodule_search_locations=[str(plugin_dir)], + ) if spec is None or spec.loader is None: logger.error(f"无法创建模块 spec: {plugin_path}") return None diff --git a/src/plugin_runtime/runner/runner_main.py b/src/plugin_runtime/runner/runner_main.py index f5c32f7e..39c741bd 100644 --- a/src/plugin_runtime/runner/runner_main.py +++ b/src/plugin_runtime/runner/runner_main.py @@ -335,6 +335,45 @@ class PluginRunner: self._rpc_client.register_method("plugin.config_updated", self._handle_config_updated) self._rpc_client.register_method("plugin.reload", self._handle_reload_plugin) + @staticmethod + def _resolve_component_handler_name(meta: PluginMeta, component_name: str) -> str: + """解析组件名对应的真实处理函数名。 + + Args: + meta: 已加载插件的元数据。 + component_name: Host 侧请求中的组件声明名。 + + Returns: + str: 实际应在插件实例上查找的方法名。 + """ + return str(meta.component_handlers.get(component_name, component_name) or component_name) + + def _resolve_component_handler(self, meta: PluginMeta, component_name: str) -> Any: + """根据组件声明名解析插件实例上的可调用处理函数。 + + Args: + meta: 已加载插件的元数据。 + component_name: Host 侧请求中的组件声明名。 + + Returns: + Any: 解析到的可调用对象;未找到时返回 ``None``。 + """ + instance = meta.instance + handler_name = self._resolve_component_handler_name(meta, component_name) + handler_method = getattr(instance, handler_name, None) + if handler_method is not None: + return handler_method + + if handler_name != component_name: + legacy_style_handler = getattr(instance, f"handle_{component_name}", None) + if legacy_style_handler is not None: + return legacy_style_handler + + prefixed_handler = getattr(instance, f"handle_{component_name}", None) + if prefixed_handler is not None: + return prefixed_handler + return getattr(instance, component_name, None) + async def _bootstrap_plugin(self, meta: PluginMeta, capabilities_required: Optional[List[str]] = None) -> bool: """向 Host 同步插件 bootstrap 能力令牌。""" payload = BootstrapPluginPayload( @@ -379,15 +418,27 @@ class PluginRunner: # 从插件实例获取组件声明(SDK 插件须实现 get_components 方法) if hasattr(instance, "get_components"): - components.extend( - ComponentDeclaration( - name=comp_info.get("name", ""), - component_type=comp_info.get("type", ""), - plugin_id=meta.plugin_id, - metadata=comp_info.get("metadata", {}), + meta.component_handlers.clear() + for comp_info in instance.get_components(): + if not isinstance(comp_info, dict): + continue + + component_name = str(comp_info.get("name", "") or "").strip() + raw_metadata = comp_info.get("metadata", {}) + component_metadata = raw_metadata if isinstance(raw_metadata, dict) else {} + handler_name = str(component_metadata.get("handler_name", component_name) or component_name).strip() + + if component_name: + meta.component_handlers[component_name] = handler_name or component_name + + components.append( + ComponentDeclaration( + name=component_name, + component_type=str(comp_info.get("type", "") or "").strip(), + plugin_id=meta.plugin_id, + metadata=component_metadata, + ) ) - for comp_info in instance.get_components() - ) if hasattr(instance, "get_config_reload_subscriptions"): config_reload_subscriptions = list(instance.get_config_reload_subscriptions()) @@ -812,19 +863,13 @@ class PluginRunner: f"插件 {plugin_id} 未加载", ) - # 调用插件实例的组件方法 - instance = meta.instance component_name = invoke.component_name - - # 优先查找 handle_ 或直接 方法(新版 SDK 插件) - handler_method = getattr(instance, f"handle_{component_name}", None) - if handler_method is None: - handler_method = getattr(instance, component_name, None) + handler_method = self._resolve_component_handler(meta, component_name) # 回退: 旧版 LegacyPluginAdapter 通过 invoke_component 统一桥接 - if (handler_method is None or not callable(handler_method)) and hasattr(instance, "invoke_component"): + if (handler_method is None or not callable(handler_method)) and hasattr(meta.instance, "invoke_component"): try: - result = await instance.invoke_component(component_name, **invoke.args) + result = await meta.instance.invoke_component(component_name, **invoke.args) resp_payload = InvokeResultPayload(success=True, result=result) return envelope.make_response(payload=resp_payload.model_dump()) except Exception as e: @@ -871,11 +916,8 @@ class PluginRunner: f"插件 {plugin_id} 未加载", ) - instance = meta.instance component_name = invoke.component_name - handler_method = getattr(instance, f"handle_{component_name}", None) - if handler_method is None: - handler_method = getattr(instance, component_name, None) + handler_method = self._resolve_component_handler(meta, component_name) if handler_method is None or not callable(handler_method): return envelope.make_error_response( @@ -933,9 +975,8 @@ class PluginRunner: f"插件 {plugin_id} 未加载", ) - instance = meta.instance component_name = invoke.component_name - handler_method = getattr(instance, f"handle_{component_name}", None) or getattr(instance, component_name, None) + handler_method = self._resolve_component_handler(meta, component_name) if handler_method is None or not callable(handler_method): return envelope.make_error_response( ErrorCode.E_METHOD_NOT_ALLOWED.value, @@ -985,9 +1026,8 @@ class PluginRunner: f"插件 {plugin_id} 未加载", ) - instance = meta.instance component_name = invoke.component_name - handler_method = getattr(instance, f"handle_{component_name}", None) or getattr(instance, component_name, None) + handler_method = self._resolve_component_handler(meta, component_name) if handler_method is None or not callable(handler_method): return envelope.make_error_response( From 17b7306188dc351240d69f0444e537d5b9324a59 Mon Sep 17 00:00:00 2001 From: DrSmoothl <1787882683@qq.com> Date: Mon, 23 Mar 2026 23:11:57 +0800 Subject: [PATCH 39/45] =?UTF-8?q?fix:=20=E4=BD=BF=E7=94=A8=20f-string=20?= =?UTF-8?q?=E6=94=B9=E8=BF=9B=E6=97=A5=E5=BF=97=E8=AE=B0=E5=BD=95=E6=A0=BC?= =?UTF-8?q?=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/platform_io/manager.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/platform_io/manager.py b/src/platform_io/manager.py index ab1b11e5..dee553a6 100644 --- a/src/platform_io/manager.py +++ b/src/platform_io/manager.py @@ -71,7 +71,7 @@ class PlatformIOManager: try: await driver.stop() except Exception: - logger.exception("回滚驱动停止失败: driver_id=%s", driver.driver_id) + logger.exception(f"回滚驱动停止失败: driver_id={driver.driver_id}") raise self._started = True @@ -104,7 +104,7 @@ class PlatformIOManager: await driver.stop() except Exception as exc: stop_errors.append(f"{driver.driver_id}: {exc}") - logger.exception("驱动停止失败: driver_id=%s", driver.driver_id) + logger.exception(f"驱动停止失败: driver_id={driver.driver_id}") self._started = False self._deduplicator.clear() @@ -448,9 +448,8 @@ class PlatformIOManager: if not self._receive_route_table.has_binding_for_driver(envelope.route_key, envelope.driver_id): logger.info( - "忽略未登记到接收路由表的入站消息: route=%s driver=%s", - envelope.route_key, - envelope.driver_id, + f"忽略未登记到接收路由表的入站消息: route={envelope.route_key} " + f"driver={envelope.driver_id}" ) return False @@ -461,7 +460,7 @@ class PlatformIOManager: dedupe_key = self._build_inbound_dedupe_key(envelope) if dedupe_key is not None: if not self._deduplicator.mark_seen(dedupe_key): - logger.info("忽略重复入站消息: dedupe_key=%s", dedupe_key) + logger.info(f"忽略重复入站消息: dedupe_key={dedupe_key}") return False await self._inbound_dispatcher(envelope) From 78858f70043c9c6b7e90b9fffffcee0dec1bfb5c Mon Sep 17 00:00:00 2001 From: DrSmoothl <1787882683@qq.com> Date: Mon, 23 Mar 2026 23:15:51 +0800 Subject: [PATCH 40/45] =?UTF-8?q?fix:=20=E6=9B=B4=E6=96=B0=E6=8F=92?= =?UTF-8?q?=E4=BB=B6=E9=80=82=E9=85=8D=E5=99=A8=E5=90=8D=E7=A7=B0=E4=BB=A5?= =?UTF-8?q?=E5=8F=8D=E6=98=A0=E6=96=B0=E7=9A=84=E5=91=BD=E5=90=8D=E7=BA=A6?= =?UTF-8?q?=E5=AE=9A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/common/logger_color_and_mapping.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/common/logger_color_and_mapping.py b/src/common/logger_color_and_mapping.py index 4044d2dc..863c9c1e 100644 --- a/src/common/logger_color_and_mapping.py +++ b/src/common/logger_color_and_mapping.py @@ -65,7 +65,7 @@ MODULE_COLORS: Dict[str, Tuple[str, Optional[str], bool]] = { "plugin_runtime.runner.rpc_client": ("#8787ff", None, False), "plugin_runtime.runner.manifest_validator": ("#5fafff", None, False), "plugin_runtime.runner.plugin_loader": ("#00afaf", None, False), - "plugin.napcat_adapter_builtin": ("#00af87", None, False), + "plugin.maibot-team.napcat-adapter": ("#00af87", None, False), "webui": ("#5f87ff", None, False), "webui.app": ("#5f87d7", None, False), "webui.api": ("#5fafff", None, False), @@ -173,7 +173,7 @@ MODULE_ALIASES = { "plugin_runtime.runner.rpc_client": "插件RPC客户端", "plugin_runtime.runner.manifest_validator": "插件清单校验", "plugin_runtime.runner.plugin_loader": "插件加载器", - "plugin.napcat_adapter_builtin": "NapCat内置适配器", + "plugin.maibot-team.napcat-adapter": "NapCat内置适配器", "webui": "WebUI", "webui.app": "WebUI应用", "webui.api": "WebUI接口", From 1b61e515541a74ee6faa8b99423bcaea8ae6d344 Mon Sep 17 00:00:00 2001 From: DrSmoothl <1787882683@qq.com> Date: Tue, 24 Mar 2026 10:55:58 +0800 Subject: [PATCH 41/45] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E6=89=B9?= =?UTF-8?q?=E9=87=8F=E6=8F=92=E4=BB=B6=E9=87=8D=E8=BD=BD=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=E5=8F=8A=E7=9B=B8=E5=85=B3=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pytests/test_plugin_runtime.py | 171 ++++++++++++++++++ src/plugin_runtime/host/component_registry.py | 6 +- src/plugin_runtime/host/supervisor.py | 67 ++++++- src/plugin_runtime/protocol/envelope.py | 29 +++ src/plugin_runtime/runner/runner_main.py | 166 ++++++++++++++--- 5 files changed, 398 insertions(+), 41 deletions(-) diff --git a/pytests/test_plugin_runtime.py b/pytests/test_plugin_runtime.py index 1d93ae24..f3e1e7ce 100644 --- a/pytests/test_plugin_runtime.py +++ b/pytests/test_plugin_runtime.py @@ -676,6 +676,101 @@ class TestSDK: methods = [call["method"] for call in runner._rpc_client.calls] assert methods == ["plugin.bootstrap", "plugin.register_components", "cap.call", "runner.ready"] + @pytest.mark.asyncio + async def test_runner_batch_reload_merges_overlapping_reverse_dependents(self, monkeypatch): + """批量重载应只对重叠依赖闭包执行一次 unload/load。""" + from src.plugin_runtime.runner.runner_main import PluginRunner + + runner = PluginRunner(host_address="dummy", session_token="token", plugin_dirs=[]) + plugin_a_id = "test.plugin-a" + plugin_b_id = "test.plugin-b" + plugin_c_id = "test.plugin-c" + + def build_meta(plugin_id: str, dependencies: list[str]) -> SimpleNamespace: + return SimpleNamespace( + plugin_id=plugin_id, + dependencies=dependencies, + plugin_dir=f"/tmp/{plugin_id}", + version="1.0.0", + instance=SimpleNamespace(), + ) + + loaded_metas = { + plugin_a_id: build_meta(plugin_a_id, []), + plugin_b_id: build_meta(plugin_b_id, [plugin_a_id]), + plugin_c_id: build_meta(plugin_c_id, [plugin_b_id]), + } + reloaded_metas = { + plugin_id: build_meta(plugin_id, list(meta.dependencies)) + for plugin_id, meta in loaded_metas.items() + } + candidates = { + plugin_a_id: ( + "dir_plugin_a", + build_test_manifest_model(plugin_a_id), + "plugin_a/plugin.py", + ), + plugin_b_id: ( + "dir_plugin_b", + build_test_manifest_model( + plugin_b_id, + dependencies=[{"type": "plugin", "id": plugin_a_id, "version_spec": ">=1.0.0,<2.0.0"}], + ), + "plugin_b/plugin.py", + ), + plugin_c_id: ( + "dir_plugin_c", + build_test_manifest_model( + plugin_c_id, + dependencies=[{"type": "plugin", "id": plugin_b_id, "version_spec": ">=1.0.0,<2.0.0"}], + ), + "plugin_c/plugin.py", + ), + } + unloaded_plugins: list[str] = [] + activated_plugins: list[str] = [] + + monkeypatch.setattr(runner._loader, "discover_candidates", lambda plugin_dirs: (candidates, {})) + monkeypatch.setattr(runner._loader, "list_plugins", lambda: sorted(loaded_metas.keys())) + monkeypatch.setattr(runner._loader, "get_plugin", lambda plugin_id: loaded_metas.get(plugin_id)) + monkeypatch.setattr( + runner._loader, + "remove_loaded_plugin", + lambda plugin_id: loaded_metas.pop(plugin_id, None), + ) + monkeypatch.setattr(runner._loader, "purge_plugin_modules", lambda plugin_id, plugin_dir: []) + monkeypatch.setattr( + runner._loader, + "resolve_dependencies", + lambda reload_candidates, extra_available=None: (sorted(reload_candidates.keys()), {}), + ) + monkeypatch.setattr( + runner._loader, + "load_candidate", + lambda plugin_id, candidate: reloaded_metas[plugin_id], + ) + + async def fake_unload_plugin(meta, reason, purge_modules=False): + del reason, purge_modules + unloaded_plugins.append(meta.plugin_id) + loaded_metas.pop(meta.plugin_id, None) + + async def fake_activate_plugin(meta): + activated_plugins.append(meta.plugin_id) + loaded_metas[meta.plugin_id] = meta + return True + + monkeypatch.setattr(runner, "_unload_plugin", fake_unload_plugin) + monkeypatch.setattr(runner, "_activate_plugin", fake_activate_plugin) + + result = await runner._reload_plugins_by_ids([plugin_a_id, plugin_b_id], reason="manual") + + assert result.success is True + assert result.requested_plugin_ids == [plugin_a_id, plugin_b_id] + assert unloaded_plugins == [plugin_c_id, plugin_b_id, plugin_a_id] + assert activated_plugins == [plugin_a_id, plugin_b_id, plugin_c_id] + assert result.reloaded_plugins == [plugin_a_id, plugin_b_id, plugin_c_id] + class TestPluginSdkUsage: """验证仓库内插件按新 SDK 归一化返回值工作。""" @@ -1220,6 +1315,25 @@ class TestDependencyResolution: sys.path[:] = original_path sys.meta_path[:] = original_meta_path + def test_isolate_sys_path_blocks_disallowed_src_imports(self): + import importlib + + from src.plugin_runtime.runner import runner_main + + original_path = list(sys.path) + original_meta_path = list(sys.meta_path) + sys.modules.pop("src.forbidden_demo", None) + + try: + runner_main._isolate_sys_path([]) + + with pytest.raises(ImportError, match="不允许导入主程序模块"): + importlib.import_module("src.forbidden_demo") + finally: + sys.path[:] = original_path + sys.meta_path[:] = original_meta_path + sys.modules.pop("src.forbidden_demo", None) + # ─── Host-side ComponentRegistry 测试 ────────────────────── @@ -1264,6 +1378,30 @@ class TestComponentRegistry: assert stats["command"] == 1 assert stats["tool"] == 1 + def test_register_command_with_invalid_regex_only_warns(self, monkeypatch): + from src.plugin_runtime.host.component_registry import ComponentRegistry + + reg = ComponentRegistry() + warnings: list[str] = [] + monkeypatch.setattr( + "src.plugin_runtime.host.component_registry.logger.warning", + lambda message: warnings.append(str(message)), + ) + + success = reg.register_component( + "broken", + "command", + "plugin_a", + { + "command_pattern": "[", + }, + ) + + assert success is True + assert reg.get_component("plugin_a.broken") is not None + assert warnings + assert "plugin_a.broken" in warnings[0] + def test_query_by_type(self): from src.plugin_runtime.host.component_registry import ComponentRegistry @@ -2303,6 +2441,39 @@ class TestSupervisor: assert supervisor.component_registry.get_component("plugin_a.handler") is not None assert supervisor.component_registry.get_component("plugin_a.obsolete") is None + @pytest.mark.asyncio + async def test_reload_plugins_uses_batch_rpc_for_multiple_roots(self): + from src.plugin_runtime.host.supervisor import PluginSupervisor + from src.plugin_runtime.protocol.envelope import ReloadPluginsResultPayload + + supervisor = PluginSupervisor(plugin_dirs=[]) + sent_requests: list[tuple[str, dict[str, object], int]] = [] + + class FakeRPCServer: + async def send_request(self, method, payload, timeout_ms=5000, **kwargs): + del kwargs + sent_requests.append((method, payload, timeout_ms)) + return SimpleNamespace( + payload=ReloadPluginsResultPayload( + success=True, + requested_plugin_ids=["plugin_a", "plugin_b"], + reloaded_plugins=["plugin_a", "plugin_b", "plugin_c"], + unloaded_plugins=["plugin_c", "plugin_b", "plugin_a"], + ).model_dump() + ) + + supervisor._rpc_server = FakeRPCServer() + + reloaded = await supervisor.reload_plugins(["plugin_a", "plugin_b", "plugin_a"], reason="manual") + + assert reloaded is True + assert len(sent_requests) == 1 + method, payload, timeout_ms = sent_requests[0] + assert method == "plugin.reload_batch" + assert payload["plugin_ids"] == ["plugin_a", "plugin_b"] + assert payload["reason"] == "manual" + assert timeout_ms >= 10000 + @pytest.mark.asyncio async def test_reload_rolls_back_when_runner_ready_not_received(self, monkeypatch): from src.plugin_runtime.host.supervisor import PluginSupervisor diff --git a/src/plugin_runtime/host/component_registry.py b/src/plugin_runtime/host/component_registry.py index 1c073490..97fdca30 100644 --- a/src/plugin_runtime/host/component_registry.py +++ b/src/plugin_runtime/host/component_registry.py @@ -75,14 +75,14 @@ class CommandEntry(ComponentEntry): """Command 组件条目""" def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None: - self.compiled_pattern: Optional[re.Pattern] = None + super().__init__(name, component_type, plugin_id, metadata) self.aliases: List[str] = metadata.get("aliases", []) + self.compiled_pattern: Optional[re.Pattern] = None if pattern := metadata.get("command_pattern", ""): try: self.compiled_pattern = re.compile(pattern) - except re.error as e: + except (re.error, TypeError) as e: logger.warning(f"命令 {self.full_name} 正则编译失败: {e}") - super().__init__(name, component_type, plugin_id, metadata) class ToolEntry(ComponentEntry): diff --git a/src/plugin_runtime/host/supervisor.py b/src/plugin_runtime/host/supervisor.py index d12014f6..08638d16 100644 --- a/src/plugin_runtime/host/supervisor.py +++ b/src/plugin_runtime/host/supervisor.py @@ -33,6 +33,8 @@ from src.plugin_runtime.protocol.envelope import ( ReceiveExternalMessageResultPayload, RegisterPluginPayload, ReloadPluginResultPayload, + ReloadPluginsPayload, + ReloadPluginsResultPayload, RouteMessagePayload, RunnerReadyPayload, ShutdownPayload, @@ -194,6 +196,35 @@ class PluginRunnerSupervisor: for plugin_id, registration in self._registered_plugins.items() } + @staticmethod + def _normalize_reload_plugin_ids(plugin_ids: Optional[List[str] | str]) -> List[str]: + """规范化批量重载入参。 + + Args: + plugin_ids: 原始插件 ID 列表或单个插件 ID。 + + Returns: + List[str]: 去重且去空白后的插件 ID 列表。 + """ + + raw_plugin_ids: List[str] + if plugin_ids is None: + raw_plugin_ids = [] + elif isinstance(plugin_ids, str): + raw_plugin_ids = [plugin_ids] + else: + raw_plugin_ids = list(plugin_ids) + + normalized_plugin_ids: List[str] = [] + seen_plugin_ids: set[str] = set() + for plugin_id in raw_plugin_ids: + normalized_plugin_id = str(plugin_id or "").strip() + if not normalized_plugin_id or normalized_plugin_id in seen_plugin_ids: + continue + seen_plugin_ids.add(normalized_plugin_id) + normalized_plugin_ids.append(normalized_plugin_id) + return normalized_plugin_ids + async def dispatch_event( self, event_type: str, @@ -420,7 +451,7 @@ class PluginRunnerSupervisor: async def reload_plugins( self, - plugin_ids: Optional[List[str]] = None, + plugin_ids: Optional[List[str] | str] = None, reason: str = "manual", external_available_plugins: Optional[Dict[str, str]] = None, ) -> bool: @@ -434,19 +465,37 @@ class PluginRunnerSupervisor: Returns: bool: 是否全部重载成功。 """ - target_plugin_ids = plugin_ids or list(self._registered_plugins.keys()) - ordered_plugin_ids = list(dict.fromkeys(target_plugin_ids)) - success = True + ordered_plugin_ids = self._normalize_reload_plugin_ids(plugin_ids) + if not ordered_plugin_ids: + ordered_plugin_ids = list(self._registered_plugins.keys()) + if not ordered_plugin_ids: + return True - for plugin_id in ordered_plugin_ids: - reloaded = await self.reload_plugin( - plugin_id=plugin_id, + if len(ordered_plugin_ids) == 1: + return await self.reload_plugin( + plugin_id=ordered_plugin_ids[0], reason=reason, external_available_plugins=external_available_plugins, ) - success = success and reloaded - return success + try: + response = await self._rpc_server.send_request( + "plugin.reload_batch", + payload=ReloadPluginsPayload( + plugin_ids=ordered_plugin_ids, + reason=reason, + external_available_plugins=external_available_plugins or self._external_available_plugins, + ).model_dump(), + timeout_ms=max(int(self._runner_spawn_timeout * 1000), 10000), + ) + except Exception as exc: + logger.error(f"插件批量重载请求失败: {exc}") + return False + + result = ReloadPluginsResultPayload.model_validate(response.payload) + if not result.success: + logger.warning(f"插件批量重载失败: {result.failed_plugins}") + return result.success async def notify_plugin_config_updated( self, diff --git a/src/plugin_runtime/protocol/envelope.py b/src/plugin_runtime/protocol/envelope.py index c2c89a0f..e738d019 100644 --- a/src/plugin_runtime/protocol/envelope.py +++ b/src/plugin_runtime/protocol/envelope.py @@ -289,6 +289,20 @@ class ReloadPluginPayload(BaseModel): """可视为已满足的外部依赖插件版本映射""" +class ReloadPluginsPayload(BaseModel): + """批量插件重载请求载荷。""" + + plugin_ids: List[str] = Field(default_factory=list, description="目标插件 ID 列表") + """目标插件 ID 列表""" + reason: str = Field(default="manual", description="重载原因") + """重载原因""" + external_available_plugins: Dict[str, str] = Field( + default_factory=dict, + description="可视为已满足的外部依赖插件版本映射", + ) + """可视为已满足的外部依赖插件版本映射""" + + class ReloadPluginResultPayload(BaseModel): """插件重载结果载荷。""" @@ -304,6 +318,21 @@ class ReloadPluginResultPayload(BaseModel): """重载失败的插件及原因""" +class ReloadPluginsResultPayload(BaseModel): + """批量插件重载结果载荷。""" + + success: bool = Field(description="是否重载成功") + """是否重载成功""" + requested_plugin_ids: List[str] = Field(default_factory=list, description="请求重载的插件 ID 列表") + """请求重载的插件 ID 列表""" + reloaded_plugins: List[str] = Field(default_factory=list, description="成功完成重载的插件列表") + """成功完成重载的插件列表""" + unloaded_plugins: List[str] = Field(default_factory=list, description="本次已卸载的插件列表") + """本次已卸载的插件列表""" + failed_plugins: Dict[str, str] = Field(default_factory=dict, description="重载失败的插件及原因") + """重载失败的插件及原因""" + + class MessageGatewayStateUpdatePayload(BaseModel): """消息网关运行时状态更新载荷。""" diff --git a/src/plugin_runtime/runner/runner_main.py b/src/plugin_runtime/runner/runner_main.py index 39c741bd..e66d2fab 100644 --- a/src/plugin_runtime/runner/runner_main.py +++ b/src/plugin_runtime/runner/runner_main.py @@ -42,6 +42,8 @@ from src.plugin_runtime.protocol.envelope import ( RegisterPluginPayload, ReloadPluginPayload, ReloadPluginResultPayload, + ReloadPluginsPayload, + ReloadPluginsResultPayload, RunnerReadyPayload, UnregisterPluginPayload, ) @@ -334,6 +336,7 @@ class PluginRunner: self._rpc_client.register_method("plugin.shutdown", self._handle_shutdown) self._rpc_client.register_method("plugin.config_updated", self._handle_config_updated) self._rpc_client.register_method("plugin.reload", self._handle_reload_plugin) + self._rpc_client.register_method("plugin.reload_batch", self._handle_reload_plugins) @staticmethod def _resolve_component_handler_name(meta: PluginMeta, component_name: str) -> str: @@ -597,6 +600,21 @@ class PluginRunner: return impacted_plugins + def _collect_reverse_dependents_for_roots(self, plugin_ids: Set[str]) -> Set[str]: + """收集多个根插件对应的反向依赖并集。 + + Args: + plugin_ids: 根插件 ID 集合。 + + Returns: + Set[str]: 所有根插件及其反向依赖并集。 + """ + + impacted_plugins: Set[str] = set() + for plugin_id in sorted(plugin_ids): + impacted_plugins.update(self._collect_reverse_dependents(plugin_id)) + return impacted_plugins + def _build_unload_order(self, plugin_ids: Set[str]) -> List[str]: """构建受影响插件的卸载顺序。 @@ -635,6 +653,20 @@ class PluginRunner: return list(reversed(load_order)) + @staticmethod + def _normalize_requested_plugin_ids(plugin_ids: List[str]) -> List[str]: + """规范化批量重载请求中的插件 ID 列表。""" + + normalized_plugin_ids: List[str] = [] + seen_plugin_ids: Set[str] = set() + for plugin_id in plugin_ids: + normalized_plugin_id = str(plugin_id or "").strip() + if not normalized_plugin_id or normalized_plugin_id in seen_plugin_ids: + continue + seen_plugin_ids.add(normalized_plugin_id) + normalized_plugin_ids.append(normalized_plugin_id) + return normalized_plugin_ids + @staticmethod def _finalize_failed_reload_messages( failed_plugins: Dict[str, str], @@ -674,6 +706,31 @@ class PluginRunner: Returns: ReloadPluginResultPayload: 结构化重载结果。 """ + batch_result = await self._reload_plugins_by_ids( + [plugin_id], + reason, + external_available_plugins=external_available_plugins, + ) + return ReloadPluginResultPayload( + success=batch_result.success, + requested_plugin_id=plugin_id, + reloaded_plugins=batch_result.reloaded_plugins, + unloaded_plugins=batch_result.unloaded_plugins, + failed_plugins=batch_result.failed_plugins, + ) + + async def _reload_plugins_by_ids( + self, + plugin_ids: List[str], + reason: str, + external_available_plugins: Optional[Dict[str, str]] = None, + ) -> ReloadPluginsResultPayload: + """按插件 ID 列表在 Runner 进程内执行一次批量重载。""" + + normalized_plugin_ids = self._normalize_requested_plugin_ids(plugin_ids) + if not normalized_plugin_ids: + return ReloadPluginsResultPayload(success=True, requested_plugin_ids=[]) + candidates, duplicate_candidates = self._loader.discover_candidates(self._plugin_dirs) failed_plugins: Dict[str, str] = {} normalized_external_available = { @@ -682,28 +739,35 @@ class PluginRunner: if str(candidate_plugin_id or "").strip() and str(candidate_plugin_version or "").strip() } - if plugin_id in duplicate_candidates: - conflict_paths = ", ".join(str(path) for path in duplicate_candidates[plugin_id]) - return ReloadPluginResultPayload( - success=False, - requested_plugin_id=plugin_id, - failed_plugins={plugin_id: f"检测到重复插件 ID: {conflict_paths}"}, - ) - loaded_plugin_ids = set(self._loader.list_plugins()) - plugin_is_loaded = plugin_id in loaded_plugin_ids - plugin_has_candidate = plugin_id in candidates + reload_root_ids: Set[str] = set() + for plugin_id in normalized_plugin_ids: + if plugin_id in duplicate_candidates: + conflict_paths = ", ".join(str(path) for path in duplicate_candidates[plugin_id]) + failed_plugins[plugin_id] = f"检测到重复插件 ID: {conflict_paths}" + continue - if not plugin_is_loaded and not plugin_has_candidate: - return ReloadPluginResultPayload( + plugin_is_loaded = plugin_id in loaded_plugin_ids + plugin_has_candidate = plugin_id in candidates + if not plugin_is_loaded and not plugin_has_candidate: + failed_plugins[plugin_id] = "插件不存在或未找到合法的 manifest/plugin.py" + continue + + reload_root_ids.add(plugin_id) + + if not reload_root_ids: + return ReloadPluginsResultPayload( success=False, - requested_plugin_id=plugin_id, - failed_plugins={plugin_id: "插件不存在或未找到合法的 manifest/plugin.py"}, + requested_plugin_ids=normalized_plugin_ids, + failed_plugins=failed_plugins, ) - target_plugin_ids: Set[str] = {plugin_id} - if plugin_is_loaded: - target_plugin_ids = self._collect_reverse_dependents(plugin_id) + target_plugin_ids: Set[str] = { + plugin_id for plugin_id in reload_root_ids if plugin_id not in loaded_plugin_ids + } + loaded_root_plugin_ids = reload_root_ids & loaded_plugin_ids + if loaded_root_plugin_ids: + target_plugin_ids.update(self._collect_reverse_dependents_for_roots(loaded_root_plugin_ids)) unload_order = self._build_unload_order(target_plugin_ids & loaded_plugin_ids) unloaded_plugins: List[str] = [] @@ -813,19 +877,19 @@ class PluginRunner: if not restored: rollback_failures[rollback_plugin_id] = "无法重新激活旧版本" - return ReloadPluginResultPayload( + return ReloadPluginsResultPayload( success=False, - requested_plugin_id=plugin_id, + requested_plugin_ids=normalized_plugin_ids, reloaded_plugins=[], unloaded_plugins=unloaded_plugins, failed_plugins=self._finalize_failed_reload_messages(failed_plugins, rollback_failures), ) - requested_plugin_success = plugin_id in reloaded_plugins + requested_plugin_success = all(plugin_id in reloaded_plugins for plugin_id in reload_root_ids) - return ReloadPluginResultPayload( - success=requested_plugin_success, - requested_plugin_id=plugin_id, + return ReloadPluginsResultPayload( + success=requested_plugin_success and not failed_plugins, + requested_plugin_ids=normalized_plugin_ids, reloaded_plugins=reloaded_plugins, unloaded_plugins=unloaded_plugins, failed_plugins=failed_plugins, @@ -1139,6 +1203,29 @@ class PluginRunner: ) return envelope.make_response(payload=result.model_dump()) + async def _handle_reload_plugins(self, envelope: Envelope) -> Envelope: + """处理批量插件重载请求。""" + + try: + payload = ReloadPluginsPayload.model_validate(envelope.payload) + except Exception as exc: + return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc)) + + if self._reload_lock.locked(): + requested_plugin_ids = ", ".join(self._normalize_requested_plugin_ids(payload.plugin_ids)) or "" + return envelope.make_error_response( + ErrorCode.E_RELOAD_IN_PROGRESS.value, + f"插件 {requested_plugin_ids} 批量重载请求被拒绝:已有重载任务正在执行", + ) + + async with self._reload_lock: + result = await self._reload_plugins_by_ids( + list(payload.plugin_ids), + payload.reason, + external_available_plugins=dict(payload.external_available_plugins), + ) + return envelope.make_response(payload=result.model_dump()) + def request_capability(self) -> RPCClient: """获取 RPC 客户端(供 SDK 使用,发起能力调用)""" return self._rpc_client @@ -1153,6 +1240,7 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None: 防止插件代码 import 主程序模块读取运行时数据。 """ import importlib.abc + from importlib.machinery import ModuleSpec import sysconfig # 保留: 标准库路径 + site-packages(含 SDK 和依赖) @@ -1195,6 +1283,20 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None: # 安装 import 钩子,阻止插件导入主程序核心模块 # 仅允许 src.plugin_runtime 和 src.common,拒绝其他 src.* 子包 + class _BlockedSrcModuleLoader(importlib.abc.Loader): + """阻止被 Runner 允许列表之外的主程序模块完成导入。""" + + def __init__(self, fullname: str) -> None: + self._fullname = fullname + + def create_module(self, spec: ModuleSpec) -> None: + del spec + return None + + def exec_module(self, module: Any) -> None: + del module + raise ImportError(f"Runner 子进程不允许导入主程序模块: {self._fullname}") + class _PluginImportBlocker(importlib.abc.MetaPathFinder): """阻止 Runner 子进程导入主程序核心模块。 @@ -1203,14 +1305,15 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None: """ _ALLOWED_SRC_PREFIXES = ("src.plugin_runtime", "src.common") + __maibot_runner_plugin_import_blocker__ = True - def find_module(self, fullname: str, path: Any = None) -> Any: + def find_spec(self, fullname: str, path: Any = None, target: Any = None) -> ModuleSpec | None: """决定是否拦截指定模块导入。""" - return self if self._should_block(fullname) else None - - def load_module(self, fullname: str) -> None: - """阻止被拦截模块继续导入。""" - raise ImportError(f"Runner 子进程不允许导入主程序模块: {fullname}") + del path, target + if not self._should_block(fullname): + return None + # Python 3.13+/3.14 会优先走 find_spec,不再依赖 find_module。 + return ModuleSpec(fullname, _BlockedSrcModuleLoader(fullname), is_package=True) def _should_block(self, fullname: str) -> bool: """判断给定模块名是否应被阻止导入。""" @@ -1222,6 +1325,11 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None: fullname == prefix or fullname.startswith(f"{prefix}.") for prefix in self._ALLOWED_SRC_PREFIXES ) + sys.meta_path[:] = [ + finder + for finder in sys.meta_path + if not getattr(finder, "__maibot_runner_plugin_import_blocker__", False) + ] sys.meta_path.insert(0, _PluginImportBlocker()) From f4a9afc452edc08809f93da32d23f8945b1f67d2 Mon Sep 17 00:00:00 2001 From: DrSmoothl <1787882683@qq.com> Date: Tue, 24 Mar 2026 11:43:23 +0800 Subject: [PATCH 42/45] =?UTF-8?q?feat:=20=E5=A2=9E=E5=BC=BA=20RPC=20?= =?UTF-8?q?=E6=9C=8D=E5=8A=A1=E5=99=A8=E8=BF=9E=E6=8E=A5=E5=A4=84=E7=90=86?= =?UTF-8?q?=EF=BC=8C=E6=B7=BB=E5=8A=A0=E8=BF=9E=E6=8E=A5=E9=94=81=E4=BB=A5?= =?UTF-8?q?=E9=98=B2=E6=AD=A2=E5=B9=B6=E5=8F=91=E8=BF=9E=E6=8E=A5=E9=97=AE?= =?UTF-8?q?=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pytests/test_plugin_runtime.py | 146 +++++++++++++++++++++++ src/plugin_runtime/host/rpc_server.py | 36 ++++-- src/plugin_runtime/runner/runner_main.py | 103 ++++++++++++---- 3 files changed, 252 insertions(+), 33 deletions(-) diff --git a/pytests/test_plugin_runtime.py b/pytests/test_plugin_runtime.py index f3e1e7ce..29227658 100644 --- a/pytests/test_plugin_runtime.py +++ b/pytests/test_plugin_runtime.py @@ -1298,9 +1298,12 @@ class TestDependencyResolution: assert "on_unload" in loader.failed_plugins["test.demo-plugin"] def test_isolate_sys_path_preserves_plugin_dirs(self): + import builtins + from src.plugin_runtime.runner import runner_main plugin_root = os.path.normpath("/tmp/maibot-plugin-root") + original_import = builtins.__import__ original_path = list(sys.path) original_meta_path = list(sys.meta_path) @@ -1312,14 +1315,17 @@ class TestDependencyResolution: assert plugin_root in sys.path finally: + builtins.__import__ = original_import sys.path[:] = original_path sys.meta_path[:] = original_meta_path def test_isolate_sys_path_blocks_disallowed_src_imports(self): + import builtins import importlib from src.plugin_runtime.runner import runner_main + original_import = builtins.__import__ original_path = list(sys.path) original_meta_path = list(sys.meta_path) sys.modules.pop("src.forbidden_demo", None) @@ -1330,10 +1336,89 @@ class TestDependencyResolution: with pytest.raises(ImportError, match="不允许导入主程序模块"): importlib.import_module("src.forbidden_demo") finally: + builtins.__import__ = original_import sys.path[:] = original_path sys.meta_path[:] = original_meta_path sys.modules.pop("src.forbidden_demo", None) + def test_isolate_sys_path_blocks_preloaded_runtime_modules(self): + import builtins + import importlib + + from src.plugin_runtime.runner import runner_main + + original_import = builtins.__import__ + original_path = list(sys.path) + original_meta_path = list(sys.meta_path) + + try: + runner_main._isolate_sys_path([]) + + with pytest.raises(ImportError, match="rpc_client"): + importlib.import_module("src.plugin_runtime.runner.rpc_client") + finally: + builtins.__import__ = original_import + sys.path[:] = original_path + sys.meta_path[:] = original_meta_path + + def test_isolate_sys_path_keeps_legacy_logger_import_available(self): + import builtins + import importlib + + from src.plugin_runtime.runner import runner_main + + original_import = builtins.__import__ + original_path = list(sys.path) + original_meta_path = list(sys.meta_path) + + try: + runner_main._isolate_sys_path([]) + + logger_module = importlib.import_module("src.common.logger") + assert callable(logger_module.get_logger) + finally: + builtins.__import__ = original_import + sys.path[:] = original_path + sys.meta_path[:] = original_meta_path + + @pytest.mark.asyncio + async def test_async_main_removes_sensitive_runtime_env_vars(self, monkeypatch): + from src.plugin_runtime.runner import runner_main + + captured = {} + + class FakeRunner: + def __init__( + self, + host_address: str, + session_token: str, + plugin_dirs: list[str], + external_available_plugins: dict[str, str] | None = None, + ) -> None: + captured["host_address"] = host_address + captured["session_token"] = session_token + captured["plugin_dirs"] = plugin_dirs + captured["external_available_plugins"] = external_available_plugins or {} + + async def run(self) -> None: + assert os.environ.get(runner_main.ENV_IPC_ADDRESS) is None + assert os.environ.get(runner_main.ENV_SESSION_TOKEN) is None + + monkeypatch.setenv(runner_main.ENV_IPC_ADDRESS, "tcp://127.0.0.1:9999") + monkeypatch.setenv(runner_main.ENV_SESSION_TOKEN, "secret-token") + monkeypatch.setenv(runner_main.ENV_PLUGIN_DIRS, "/tmp/plugins") + monkeypatch.setenv(runner_main.ENV_EXTERNAL_PLUGIN_IDS, '{"demo.plugin":"1.0.0"}') + monkeypatch.setattr(runner_main, "_install_shutdown_signal_handlers", lambda callback: None) + monkeypatch.setattr(runner_main, "_isolate_sys_path", lambda plugin_dirs: None) + monkeypatch.setattr(runner_main, "PluginRunner", FakeRunner) + + await runner_main._async_main() + + assert captured["host_address"] == "tcp://127.0.0.1:9999" + assert captured["session_token"] == "secret-token" + assert captured["plugin_dirs"] == ["/tmp/plugins"] + assert captured["external_available_plugins"] == {"demo.plugin": "1.0.0"} + # ─── Host-side ComponentRegistry 测试 ────────────────────── @@ -2093,6 +2178,67 @@ class TestWorkflowExecutor: class TestRPCServer: """RPC Server 代际保护测试""" + @pytest.mark.asyncio + async def test_reject_second_active_runner_connection(self): + from src.plugin_runtime.host.rpc_server import RPCServer + from src.plugin_runtime.protocol.codec import MsgPackCodec + from src.plugin_runtime.protocol.envelope import Envelope, HelloPayload, HelloResponsePayload, MessageType + + class DummyTransport: + async def start(self, handler): + return None + + async def stop(self): + return None + + def get_address(self): + return "dummy" + + class FakeConnection: + def __init__(self, incoming_frames: list[bytes]): + self._incoming_frames = list(incoming_frames) + self.sent_frames: list[bytes] = [] + self.is_closed = False + + async def recv_frame(self): + return self._incoming_frames.pop(0) + + async def send_frame(self, data): + self.sent_frames.append(data) + + async def close(self): + self.is_closed = True + + codec = MsgPackCodec() + server = RPCServer(transport=DummyTransport(), session_token="session-token") + active_conn = SimpleNamespace(is_closed=False) + server._connection = active_conn + + hello = HelloPayload( + runner_id="runner-b", + sdk_version="1.0.0", + session_token="session-token", + ) + envelope = Envelope( + request_id=1, + message_type=MessageType.REQUEST, + method="runner.hello", + payload=hello.model_dump(), + ) + incoming_conn = FakeConnection([codec.encode_envelope(envelope)]) + + await server._handle_connection(incoming_conn) + + assert incoming_conn.is_closed is True + assert server._connection is active_conn + assert server.last_handshake_rejection_reason == "已有活跃 Runner 连接,拒绝新的握手" + assert len(incoming_conn.sent_frames) == 1 + + response = codec.decode_envelope(incoming_conn.sent_frames[0]) + response_payload = HelloResponsePayload.model_validate(response.payload) + assert response_payload.accepted is False + assert response_payload.reason == "已有活跃 Runner 连接,拒绝新的握手" + def test_ignore_stale_generation_response(self): from src.plugin_runtime.host.rpc_server import RPCServer from src.plugin_runtime.protocol.envelope import Envelope, MessageType diff --git a/src/plugin_runtime/host/rpc_server.py b/src/plugin_runtime/host/rpc_server.py index 2c422775..eb6768c2 100644 --- a/src/plugin_runtime/host/rpc_server.py +++ b/src/plugin_runtime/host/rpc_server.py @@ -70,6 +70,7 @@ class RPCServer: self._running: bool = False self._tasks: List[asyncio.Task[None]] = [] self._last_handshake_rejection_reason: str = "" + self._connection_lock: asyncio.Lock = asyncio.Lock() @property def session_token(self) -> str: @@ -216,27 +217,33 @@ class RPCServer: async def _handle_connection(self, conn: Connection) -> None: """处理新的 Runner 连接""" logger.info("收到 Runner 连接") - self.clear_handshake_state() - # 第一条消息必须是 runner.hello 握手 try: - success = await self._handle_handshake(conn) - if not success: - await conn.close() - return + async with self._connection_lock: + self.clear_handshake_state() + success = await self._handle_handshake(conn) + if not success: + await conn.close() + return + logger.info("Runner staged 握手成功") + self._connection = conn except Exception as e: logger.error(f"握手失败: {e}") await conn.close() return - logger.info("Runner staged 握手成功") - self._connection = conn + # 启动消息接收循环 try: await self._recv_loop(conn) except Exception as e: logger.error(f"连接异常断开: {e}") finally: - self._connection = None - self._fail_pending_requests(ErrorCode.E_PLUGIN_CRASHED, "Runner 连接已断开") + should_fail_pending_requests = False + async with self._connection_lock: + if self._connection is conn: + self._connection = None + should_fail_pending_requests = True + if should_fail_pending_requests: + self._fail_pending_requests(ErrorCode.E_PLUGIN_CRASHED, "Runner 连接已断开") async def _handle_handshake(self, conn: Connection) -> bool: """处理 runner.hello 握手""" @@ -264,6 +271,15 @@ class RPCServer: await conn.send_frame(self._codec.encode_envelope(resp)) return False + # 若已有活跃连接,直接拒绝新的握手,避免后来的连接抢占当前通道。 + if self.is_connected: + logger.warning("拒绝新的 Runner 连接:已有活跃连接") + self._last_handshake_rejection_reason = "已有活跃 Runner 连接,拒绝新的握手" + resp_payload = HelloResponsePayload(accepted=False, reason=self._last_handshake_rejection_reason) + resp = envelope.make_response(payload=resp_payload.model_dump()) + await conn.send_frame(self._codec.encode_envelope(resp)) + return False + # 校验 SDK 版本 if not self._check_sdk_version(hello.sdk_version): logger.error(f"SDK 版本不兼容: {hello.sdk_version}") diff --git a/src/plugin_runtime/runner/runner_main.py b/src/plugin_runtime/runner/runner_main.py index e66d2fab..e4e47c68 100644 --- a/src/plugin_runtime/runner/runner_main.py +++ b/src/plugin_runtime/runner/runner_main.py @@ -1237,11 +1237,14 @@ class PluginRunner: def _isolate_sys_path(plugin_dirs: List[str]) -> None: """清理 sys.path,限制 Runner 子进程只能访问标准库、SDK 和插件目录。 - 防止插件代码 import 主程序模块读取运行时数据。 + 同时移除插件可直接访问的主程序内部模块缓存,避免通过 ``sys.modules`` + 或常规导入绕过 SDK / capability 边界。 """ + import builtins import importlib.abc from importlib.machinery import ModuleSpec import sysconfig + from types import ModuleType # 保留: 标准库路径 + site-packages(含 SDK 和依赖) stdlib_paths = set() @@ -1271,18 +1274,68 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None: for d in plugin_dir_paths: allowed.add(d) - # 添加项目根目录(使得 src.plugin_runtime / src.common 可导入) - runtime_root = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) - allowed.add(runtime_root) - preserved_paths = [p for p in sys.path if p in allowed] - for extra_path in [*plugin_dir_paths, runtime_root]: + for extra_path in plugin_dir_paths: if extra_path not in preserved_paths: preserved_paths.append(extra_path) sys.path[:] = preserved_paths - # 安装 import 钩子,阻止插件导入主程序核心模块 - # 仅允许 src.plugin_runtime 和 src.common,拒绝其他 src.* 子包 + # 仅为旧版插件兼容层保留极小的 src.* 可见面: + # - src.plugin_system.*: 通过 maibot_sdk.compat 导入钩子重定向 + # - src.common.logger: 仓库内仍有少量旧插件沿用该日志入口 + allowed_src_exact_modules = frozenset( + { + "src", + "src.common", + "src.common.logger", + "src.common.logger_color_and_mapping", + } + ) + allowed_src_prefixes = ("src.plugin_system",) + + def _is_allowed_src_module(fullname: str) -> bool: + """判断给定 src.* 模块是否在 Runner 允许列表中。""" + if fullname in allowed_src_exact_modules: + return True + return any(fullname == prefix or fullname.startswith(f"{prefix}.") for prefix in allowed_src_prefixes) + + def _format_block_message(fullname: str) -> str: + """构造统一的拒绝导入错误信息。""" + return ( + f"Runner 子进程不允许导入主程序模块: {fullname}。" + "请改用 maibot_sdk 或 src.plugin_system 兼容层提供的接口。" + ) + + def _detach_module_from_parent(fullname: str, module: ModuleType) -> None: + """从父模块上移除已清理模块的属性引用。""" + parent_name, _, child_name = fullname.rpartition(".") + if not parent_name or not child_name: + return + + parent_module = sys.modules.get(parent_name) + if parent_module is None: + return + if getattr(parent_module, child_name, None) is module: + with contextlib.suppress(AttributeError): + delattr(parent_module, child_name) + + # 清理主程序内部模块缓存,避免插件经由 sys.modules 直接拿到高权限对象。 + existing_src_modules = sorted( + ( + (module_name, module) + for module_name, module in list(sys.modules.items()) + if module_name == "src" or module_name.startswith("src.") + ), + key=lambda item: item[0].count("."), + reverse=True, + ) + for module_name, module in existing_src_modules: + if _is_allowed_src_module(module_name): + continue + _detach_module_from_parent(module_name, module) + sys.modules.pop(module_name, None) + + # 安装 import 钩子,阻止再次导入被清理掉的主程序内部模块。 class _BlockedSrcModuleLoader(importlib.abc.Loader): """阻止被 Runner 允许列表之外的主程序模块完成导入。""" @@ -1295,16 +1348,11 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None: def exec_module(self, module: Any) -> None: del module - raise ImportError(f"Runner 子进程不允许导入主程序模块: {self._fullname}") + raise ImportError(_format_block_message(self._fullname)) class _PluginImportBlocker(importlib.abc.MetaPathFinder): - """阻止 Runner 子进程导入主程序核心模块。 + """阻止 Runner 子进程重新导入主程序内部 src.* 模块。""" - 只放行 src.plugin_runtime 和 src.common, - 拒绝 src.chat_module / src.services 等主程序内部包。 - """ - - _ALLOWED_SRC_PREFIXES = ("src.plugin_runtime", "src.common") __maibot_runner_plugin_import_blocker__ = True def find_spec(self, fullname: str, path: Any = None, target: Any = None) -> ModuleSpec | None: @@ -1317,13 +1365,9 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None: def _should_block(self, fullname: str) -> bool: """判断给定模块名是否应被阻止导入。""" - # 放行非 src.* 的导入、以及 "src" 本身 - if not fullname.startswith("src.") or fullname == "src": + if not fullname.startswith("src"): return False - # 放行白名单前缀 - return not any( - fullname == prefix or fullname.startswith(f"{prefix}.") for prefix in self._ALLOWED_SRC_PREFIXES - ) + return not _is_allowed_src_module(fullname) sys.meta_path[:] = [ finder @@ -1332,15 +1376,28 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None: ] sys.meta_path.insert(0, _PluginImportBlocker()) + # ``import`` 语句在模块已存在于 sys.modules 时不会再经过 finder, + # 因此还需要在入口处补一层兜底。 + original_import = getattr(builtins, "__maibot_runner_original_import__", builtins.__import__) + builtins.__maibot_runner_original_import__ = original_import + + def _guarded_import(name: str, globals: Any = None, locals: Any = None, fromlist: Any = (), level: int = 0) -> Any: + if level == 0 and name.startswith("src") and not _is_allowed_src_module(name): + raise ImportError(_format_block_message(name)) + return original_import(name, globals, locals, fromlist, level) + + _guarded_import.__maibot_runner_plugin_import_guard__ = True + builtins.__import__ = _guarded_import + # ─── 进程入口 ────────────────────────────────────────────── async def _async_main() -> None: """异步主入口""" - host_address = os.environ.get(ENV_IPC_ADDRESS, "") + host_address = os.environ.pop(ENV_IPC_ADDRESS, "") external_plugin_ids_raw = os.environ.get(ENV_EXTERNAL_PLUGIN_IDS, "") - session_token = os.environ.get(ENV_SESSION_TOKEN, "") + session_token = os.environ.pop(ENV_SESSION_TOKEN, "") plugin_dirs_str = os.environ.get(ENV_PLUGIN_DIRS, "") if not host_address or not session_token: From d5581a1a970c8c0ae23194ba2f1bbcc2e2b283b9 Mon Sep 17 00:00:00 2001 From: DrSmoothl <1787882683@qq.com> Date: Tue, 24 Mar 2026 11:49:40 +0800 Subject: [PATCH 43/45] =?UTF-8?q?feat:=20=E5=A2=9E=E5=BC=BA=E6=8F=92?= =?UTF-8?q?=E4=BB=B6=E5=AF=BC=E5=85=A5=E7=AE=A1=E7=90=86=EF=BC=8C=E6=B7=BB?= =?UTF-8?q?=E5=8A=A0=E5=AF=BC=E5=85=A5=E8=AF=B7=E6=B1=82=E9=AA=8C=E8=AF=81?= =?UTF-8?q?=E5=92=8C=E6=A8=A1=E5=9D=97=E8=AE=BF=E9=97=AE=E6=8E=A7=E5=88=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pytests/test_plugin_runtime.py | 55 +++++++++- src/plugin_runtime/runner/runner_main.py | 124 +++++++++++++---------- 2 files changed, 124 insertions(+), 55 deletions(-) diff --git a/pytests/test_plugin_runtime.py b/pytests/test_plugin_runtime.py index 29227658..e3247f05 100644 --- a/pytests/test_plugin_runtime.py +++ b/pytests/test_plugin_runtime.py @@ -1299,11 +1299,13 @@ class TestDependencyResolution: def test_isolate_sys_path_preserves_plugin_dirs(self): import builtins + import importlib from src.plugin_runtime.runner import runner_main plugin_root = os.path.normpath("/tmp/maibot-plugin-root") original_import = builtins.__import__ + original_import_module = importlib.import_module original_path = list(sys.path) original_meta_path = list(sys.meta_path) @@ -1316,6 +1318,7 @@ class TestDependencyResolution: assert plugin_root in sys.path finally: builtins.__import__ = original_import + importlib.import_module = original_import_module sys.path[:] = original_path sys.meta_path[:] = original_meta_path @@ -1326,17 +1329,24 @@ class TestDependencyResolution: from src.plugin_runtime.runner import runner_main original_import = builtins.__import__ + original_import_module = importlib.import_module original_path = list(sys.path) original_meta_path = list(sys.meta_path) sys.modules.pop("src.forbidden_demo", None) try: runner_main._isolate_sys_path([]) + plugin_globals = { + "__name__": "_maibot_plugin_demo", + "__package__": "_maibot_plugin_demo", + "importlib": importlib, + } with pytest.raises(ImportError, match="不允许导入主程序模块"): - importlib.import_module("src.forbidden_demo") + exec('importlib.import_module("src.forbidden_demo")', plugin_globals) finally: builtins.__import__ = original_import + importlib.import_module = original_import_module sys.path[:] = original_path sys.meta_path[:] = original_meta_path sys.modules.pop("src.forbidden_demo", None) @@ -1348,16 +1358,23 @@ class TestDependencyResolution: from src.plugin_runtime.runner import runner_main original_import = builtins.__import__ + original_import_module = importlib.import_module original_path = list(sys.path) original_meta_path = list(sys.meta_path) try: runner_main._isolate_sys_path([]) + plugin_globals = { + "__name__": "_maibot_plugin_demo", + "__package__": "_maibot_plugin_demo", + "importlib": importlib, + } with pytest.raises(ImportError, match="rpc_client"): - importlib.import_module("src.plugin_runtime.runner.rpc_client") + exec('importlib.import_module("src.plugin_runtime.runner.rpc_client")', plugin_globals) finally: builtins.__import__ = original_import + importlib.import_module = original_import_module sys.path[:] = original_path sys.meta_path[:] = original_meta_path @@ -1368,16 +1385,46 @@ class TestDependencyResolution: from src.plugin_runtime.runner import runner_main original_import = builtins.__import__ + original_import_module = importlib.import_module + original_path = list(sys.path) + original_meta_path = list(sys.meta_path) + + try: + runner_main._isolate_sys_path([]) + plugin_globals = { + "__name__": "_maibot_plugin_demo", + "__package__": "_maibot_plugin_demo", + "importlib": importlib, + } + + exec('logger_module = importlib.import_module("src.common.logger")', plugin_globals) + logger_module = plugin_globals["logger_module"] + assert callable(logger_module.get_logger) + finally: + builtins.__import__ = original_import + importlib.import_module = original_import_module + sys.path[:] = original_path + sys.meta_path[:] = original_meta_path + + def test_isolate_sys_path_keeps_runtime_imports_working(self): + import builtins + import importlib + + from src.plugin_runtime.runner import runner_main + + original_import = builtins.__import__ + original_import_module = importlib.import_module original_path = list(sys.path) original_meta_path = list(sys.meta_path) try: runner_main._isolate_sys_path([]) - logger_module = importlib.import_module("src.common.logger") - assert callable(logger_module.get_logger) + uds_module = importlib.import_module("src.plugin_runtime.transport.uds") + assert hasattr(uds_module, "UDSTransportClient") finally: builtins.__import__ = original_import + importlib.import_module = original_import_module sys.path[:] = original_path sys.meta_path[:] = original_meta_path diff --git a/src/plugin_runtime/runner/runner_main.py b/src/plugin_runtime/runner/runner_main.py index e4e47c68..4dac4e05 100644 --- a/src/plugin_runtime/runner/runner_main.py +++ b/src/plugin_runtime/runner/runner_main.py @@ -1237,12 +1237,11 @@ class PluginRunner: def _isolate_sys_path(plugin_dirs: List[str]) -> None: """清理 sys.path,限制 Runner 子进程只能访问标准库、SDK 和插件目录。 - 同时移除插件可直接访问的主程序内部模块缓存,避免通过 ``sys.modules`` - 或常规导入绕过 SDK / capability 边界。 + 同时阻止插件代码直接导入主程序内部 ``src.*`` 模块,并清理可直接从 + ``sys.modules`` 摸到的高权限叶子模块,避免绕过 SDK / capability 边界。 """ import builtins - import importlib.abc - from importlib.machinery import ModuleSpec + import importlib import sysconfig from types import ModuleType @@ -1292,6 +1291,7 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None: } ) allowed_src_prefixes = ("src.plugin_system",) + plugin_module_prefix = "_maibot_plugin_" def _is_allowed_src_module(fullname: str) -> bool: """判断给定 src.* 模块是否在 Runner 允许列表中。""" @@ -1299,6 +1299,35 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None: return True return any(fullname == prefix or fullname.startswith(f"{prefix}.") for prefix in allowed_src_prefixes) + def _resolve_requester_name(import_globals: Any = None) -> str: + """解析当前导入请求的发起模块名。""" + if isinstance(import_globals, dict): + for key in ("__name__", "__package__"): + value = import_globals.get(key) + if isinstance(value, str) and value: + return value + + frame = inspect.currentframe() + try: + current = frame.f_back if frame is not None else None + while current is not None: + module_name = current.f_globals.get("__name__", "") + if not isinstance(module_name, str) or not module_name: + current = current.f_back + continue + if module_name == __name__ or module_name.startswith("importlib"): + current = current.f_back + continue + return module_name + return "" + finally: + del frame + + def _is_plugin_import_request(import_globals: Any = None) -> bool: + """判断当前导入是否由插件模块直接发起。""" + requester_name = _resolve_requester_name(import_globals) + return requester_name.startswith(plugin_module_prefix) + def _format_block_message(fullname: str) -> str: """构造统一的拒绝导入错误信息。""" return ( @@ -1306,6 +1335,30 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None: "请改用 maibot_sdk 或 src.plugin_system 兼容层提供的接口。" ) + def _iter_requested_src_modules(name: str, fromlist: Any) -> List[str]: + """展开本次导入请求涉及的 src.* 模块名。""" + requested_modules = [name] + if not name.startswith("src") or not fromlist: + return requested_modules + + for item in fromlist: + if not isinstance(item, str) or not item or item == "*": + continue + requested_modules.append(f"{name}.{item}") + return requested_modules + + def _assert_plugin_import_allowed(name: str, import_globals: Any = None, fromlist: Any = ()) -> None: + """在插件发起导入时校验目标 src.* 模块是否允许访问。""" + if not _is_plugin_import_request(import_globals): + return + + for requested_module in _iter_requested_src_modules(name, fromlist): + if not requested_module.startswith("src"): + continue + if _is_allowed_src_module(requested_module): + continue + raise ImportError(_format_block_message(requested_module)) + def _detach_module_from_parent(fullname: str, module: ModuleType) -> None: """从父模块上移除已清理模块的属性引用。""" parent_name, _, child_name = fullname.rpartition(".") @@ -1319,7 +1372,7 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None: with contextlib.suppress(AttributeError): delattr(parent_module, child_name) - # 清理主程序内部模块缓存,避免插件经由 sys.modules 直接拿到高权限对象。 + # 仅清理已加载的叶子模块,保留包对象给 Runner 自己的延迟导入和相对导入使用。 existing_src_modules = sorted( ( (module_name, module) @@ -1330,65 +1383,34 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None: reverse=True, ) for module_name, module in existing_src_modules: - if _is_allowed_src_module(module_name): + if _is_allowed_src_module(module_name) or hasattr(module, "__path__"): continue _detach_module_from_parent(module_name, module) sys.modules.pop(module_name, None) - # 安装 import 钩子,阻止再次导入被清理掉的主程序内部模块。 - class _BlockedSrcModuleLoader(importlib.abc.Loader): - """阻止被 Runner 允许列表之外的主程序模块完成导入。""" - - def __init__(self, fullname: str) -> None: - self._fullname = fullname - - def create_module(self, spec: ModuleSpec) -> None: - del spec - return None - - def exec_module(self, module: Any) -> None: - del module - raise ImportError(_format_block_message(self._fullname)) - - class _PluginImportBlocker(importlib.abc.MetaPathFinder): - """阻止 Runner 子进程重新导入主程序内部 src.* 模块。""" - - __maibot_runner_plugin_import_blocker__ = True - - def find_spec(self, fullname: str, path: Any = None, target: Any = None) -> ModuleSpec | None: - """决定是否拦截指定模块导入。""" - del path, target - if not self._should_block(fullname): - return None - # Python 3.13+/3.14 会优先走 find_spec,不再依赖 find_module。 - return ModuleSpec(fullname, _BlockedSrcModuleLoader(fullname), is_package=True) - - def _should_block(self, fullname: str) -> bool: - """判断给定模块名是否应被阻止导入。""" - if not fullname.startswith("src"): - return False - return not _is_allowed_src_module(fullname) - - sys.meta_path[:] = [ - finder - for finder in sys.meta_path - if not getattr(finder, "__maibot_runner_plugin_import_blocker__", False) - ] - sys.meta_path.insert(0, _PluginImportBlocker()) - - # ``import`` 语句在模块已存在于 sys.modules 时不会再经过 finder, - # 因此还需要在入口处补一层兜底。 + # ``import`` 语句与 ``importlib.import_module`` 走的是不同入口,因此两边都需要兜底。 original_import = getattr(builtins, "__maibot_runner_original_import__", builtins.__import__) builtins.__maibot_runner_original_import__ = original_import def _guarded_import(name: str, globals: Any = None, locals: Any = None, fromlist: Any = (), level: int = 0) -> Any: - if level == 0 and name.startswith("src") and not _is_allowed_src_module(name): - raise ImportError(_format_block_message(name)) + if level == 0: + _assert_plugin_import_allowed(name, import_globals=globals, fromlist=fromlist) return original_import(name, globals, locals, fromlist, level) _guarded_import.__maibot_runner_plugin_import_guard__ = True builtins.__import__ = _guarded_import + original_import_module = getattr(importlib, "__maibot_runner_original_import_module__", importlib.import_module) + importlib.__maibot_runner_original_import_module__ = original_import_module + + def _guarded_import_module(name: str, package: Optional[str] = None) -> Any: + resolved_name = importlib.util.resolve_name(name, package) if name.startswith(".") else name + _assert_plugin_import_allowed(resolved_name) + return original_import_module(name, package) + + _guarded_import_module.__maibot_runner_plugin_import_guard__ = True + importlib.import_module = _guarded_import_module + # ─── 进程入口 ────────────────────────────────────────────── From b8224bdb3c6d8b2846cc7b594345d4ed09c0ff12 Mon Sep 17 00:00:00 2001 From: DrSmoothl <1787882683@qq.com> Date: Tue, 24 Mar 2026 11:51:15 +0800 Subject: [PATCH 44/45] =?UTF-8?q?refactor:=20=E4=BC=98=E5=8C=96=E6=8F=92?= =?UTF-8?q?=E4=BB=B6=E5=A4=84=E7=90=86=E9=80=BB=E8=BE=91=EF=BC=8C=E6=94=B9?= =?UTF-8?q?=E8=BF=9B=E5=AF=BC=E5=85=A5=E4=BF=9D=E6=8A=A4=E6=9C=BA=E5=88=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/plugin_runtime/runner/runner_main.py | 30 +++++++++++++----------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/src/plugin_runtime/runner/runner_main.py b/src/plugin_runtime/runner/runner_main.py index 4dac4e05..d1ebc064 100644 --- a/src/plugin_runtime/runner/runner_main.py +++ b/src/plugin_runtime/runner/runner_main.py @@ -429,9 +429,9 @@ class PluginRunner: component_name = str(comp_info.get("name", "") or "").strip() raw_metadata = comp_info.get("metadata", {}) component_metadata = raw_metadata if isinstance(raw_metadata, dict) else {} - handler_name = str(component_metadata.get("handler_name", component_name) or component_name).strip() if component_name: + handler_name = str(component_metadata.get("handler_name", component_name) or component_name).strip() meta.component_handlers[component_name] = handler_name or component_name components.append( @@ -765,8 +765,7 @@ class PluginRunner: target_plugin_ids: Set[str] = { plugin_id for plugin_id in reload_root_ids if plugin_id not in loaded_plugin_ids } - loaded_root_plugin_ids = reload_root_ids & loaded_plugin_ids - if loaded_root_plugin_ids: + if loaded_root_plugin_ids := reload_root_ids & loaded_plugin_ids: target_plugin_ids.update(self._collect_reverse_dependents_for_roots(loaded_root_plugin_ids)) unload_order = self._build_unload_order(target_plugin_ids & loaded_plugin_ids) @@ -823,11 +822,10 @@ class PluginRunner: continue _, manifest, _ = candidate - unsatisfied_dependencies = self._loader.manifest_validator.get_unsatisfied_plugin_dependencies( + if unsatisfied_dependencies := self._loader.manifest_validator.get_unsatisfied_plugin_dependencies( manifest, available_plugin_versions=available_plugins, - ) - if unsatisfied_dependencies: + ): failed_plugins[load_plugin_id] = f"依赖未满足: {', '.join(unsatisfied_dependencies)}" continue @@ -1240,10 +1238,12 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None: 同时阻止插件代码直接导入主程序内部 ``src.*`` 模块,并清理可直接从 ``sys.modules`` 摸到的高权限叶子模块,避免绕过 SDK / capability 边界。 """ + from importlib import util as importlib_util + from types import ModuleType + import builtins import importlib import sysconfig - from types import ModuleType # 保留: 标准库路径 + site-packages(含 SDK 和依赖) stdlib_paths = set() @@ -1389,26 +1389,28 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None: sys.modules.pop(module_name, None) # ``import`` 语句与 ``importlib.import_module`` 走的是不同入口,因此两边都需要兜底。 - original_import = getattr(builtins, "__maibot_runner_original_import__", builtins.__import__) - builtins.__maibot_runner_original_import__ = original_import + builtins_module = cast(Any, builtins) + original_import = getattr(builtins_module, "__maibot_runner_original_import__", builtins.__import__) + builtins_module.__maibot_runner_original_import__ = original_import def _guarded_import(name: str, globals: Any = None, locals: Any = None, fromlist: Any = (), level: int = 0) -> Any: if level == 0: _assert_plugin_import_allowed(name, import_globals=globals, fromlist=fromlist) return original_import(name, globals, locals, fromlist, level) - _guarded_import.__maibot_runner_plugin_import_guard__ = True + cast(Any, _guarded_import).__maibot_runner_plugin_import_guard__ = True builtins.__import__ = _guarded_import - original_import_module = getattr(importlib, "__maibot_runner_original_import_module__", importlib.import_module) - importlib.__maibot_runner_original_import_module__ = original_import_module + importlib_module = cast(Any, importlib) + original_import_module = getattr(importlib_module, "__maibot_runner_original_import_module__", importlib.import_module) + importlib_module.__maibot_runner_original_import_module__ = original_import_module def _guarded_import_module(name: str, package: Optional[str] = None) -> Any: - resolved_name = importlib.util.resolve_name(name, package) if name.startswith(".") else name + resolved_name = importlib_util.resolve_name(name, package) if name.startswith(".") else name _assert_plugin_import_allowed(resolved_name) return original_import_module(name, package) - _guarded_import_module.__maibot_runner_plugin_import_guard__ = True + cast(Any, _guarded_import_module).__maibot_runner_plugin_import_guard__ = True importlib.import_module = _guarded_import_module From 2c279f703ca386306a4f066cd8a6b33df5510adb Mon Sep 17 00:00:00 2001 From: DrSmoothl <1787882683@qq.com> Date: Tue, 24 Mar 2026 15:32:09 +0800 Subject: [PATCH 45/45] =?UTF-8?q?Revert=20"feat=EF=BC=9A=E5=B0=9D=E8=AF=95?= =?UTF-8?q?=E5=BB=BA=E7=AB=8Bhfc=E9=80=BB=E8=BE=91"?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit bfc9781c4f10212ecfffc3129115ecc6fbd3fa89. --- src/chat/heart_flow/heartFC_chat - 副本.py | 734 ------------------ src/chat/heart_flow/heartFC_chat.py | 823 ++++----------------- src/chat/heart_flow/heartflow.py | 42 -- 3 files changed, 160 insertions(+), 1439 deletions(-) delete mode 100644 src/chat/heart_flow/heartFC_chat - 副本.py delete mode 100644 src/chat/heart_flow/heartflow.py diff --git a/src/chat/heart_flow/heartFC_chat - 副本.py b/src/chat/heart_flow/heartFC_chat - 副本.py deleted file mode 100644 index 02f70281..00000000 --- a/src/chat/heart_flow/heartFC_chat - 副本.py +++ /dev/null @@ -1,734 +0,0 @@ -import asyncio -import time -import traceback -import random -from typing import List, Optional, Dict, Any, Tuple, TYPE_CHECKING -from rich.traceback import install - -from src.config.config import global_config -from src.common.logger import get_logger -from src.common.data_models.info_data_model import ActionPlannerInfo -from src.common.data_models.message_data_model import ReplyContentType -from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager -from src.chat.utils.prompt_builder import global_prompt_manager -from src.chat.utils.timer_calculator import Timer -from src.chat.planner_actions.planner import ActionPlanner -from src.chat.planner_actions.action_modifier import ActionModifier -from src.chat.planner_actions.action_manager import ActionManager -from src.chat.heart_flow.hfc_utils import CycleDetail -from src.learners.expression_learner import expression_learner_manager -from src.chat.heart_flow.frequency_control import frequency_control_manager -from src.learners.message_recorder import extract_and_distribute_messages -from src.person_info.person_info import Person -from src.plugin_system.base.component_types import EventType, ActionInfo -from src.plugin_system.core import events_manager -from src.plugin_system.apis import generator_api, send_api, message_api, database_api -from src.chat.utils.chat_message_builder import ( - build_readable_messages_with_id, - get_raw_msg_before_timestamp_with_chat, -) -from src.chat.utils.utils import record_replyer_action_temp -from src.memory_system.chat_history_summarizer import ChatHistorySummarizer - -if TYPE_CHECKING: - from src.common.data_models.database_data_model import DatabaseMessages - from src.common.data_models.message_data_model import ReplySetModel - - -ERROR_LOOP_INFO = { - "loop_plan_info": { - "action_result": { - "action_type": "error", - "action_data": {}, - "reasoning": "循环处理失败", - }, - }, - "loop_action_info": { - "action_taken": False, - "reply_text": "", - "command": "", - "taken_time": time.time(), - }, -} - - -install(extra_lines=3) - -# 注释:原来的动作修改超时常量已移除,因为改为顺序执行 - -logger = get_logger("hfc") # Logger Name Changed - - -class HeartFChatting: - """ - 管理一个连续的Focus Chat循环 - 用于在特定聊天流中生成回复。 - 其生命周期现在由其关联的 SubHeartflow 的 FOCUSED 状态控制。 - """ - - def __init__(self, chat_id: str): - """ - HeartFChatting 初始化函数 - - 参数: - chat_id: 聊天流唯一标识符(如stream_id) - on_stop_focus_chat: 当收到stop_focus_chat命令时调用的回调函数 - performance_version: 性能记录版本号,用于区分不同启动版本 - """ - # 基础属性 - self.stream_id: str = chat_id # 聊天流ID - self.chat_stream: ChatStream = get_chat_manager().get_stream(self.stream_id) # type: ignore - if not self.chat_stream: - raise ValueError(f"无法找到聊天流: {self.stream_id}") - self.log_prefix = f"[{get_chat_manager().get_stream_name(self.stream_id) or self.stream_id}]" - - self.expression_learner = expression_learner_manager.get_expression_learner(self.stream_id) - - self.action_manager = ActionManager() - self.action_planner = ActionPlanner(chat_id=self.stream_id, action_manager=self.action_manager) - self.action_modifier = ActionModifier(action_manager=self.action_manager, chat_id=self.stream_id) - - # 循环控制内部状态 - self.running: bool = False - self._loop_task: Optional[asyncio.Task] = None # 主循环任务 - - # 添加循环信息管理相关的属性 - self.history_loop: List[CycleDetail] = [] - self._cycle_counter = 0 - self._current_cycle_detail: CycleDetail = None # type: ignore - - self.last_read_time = time.time() - 2 - - self.is_mute = False - - self.last_active_time = time.time() # 记录上一次非noreply时间 - - self.question_probability_multiplier = 1 - self.questioned = False - - # 跟踪连续 no_reply 次数,用于动态调整阈值 - self.consecutive_no_reply_count = 0 - - # 聊天内容概括器 - self.chat_history_summarizer = ChatHistorySummarizer(chat_id=self.stream_id) - - async def start(self): - """检查是否需要启动主循环,如果未激活则启动。""" - - # 如果循环已经激活,直接返回 - if self.running: - logger.debug(f"{self.log_prefix} HeartFChatting 已激活,无需重复启动") - return - - try: - # 标记为活动状态,防止重复启动 - self.running = True - - self._loop_task = asyncio.create_task(self._main_chat_loop()) - self._loop_task.add_done_callback(self._handle_loop_completion) - - # 启动聊天内容概括器的后台定期检查循环 - await self.chat_history_summarizer.start() - - logger.info(f"{self.log_prefix} HeartFChatting 启动完成") - - except Exception as e: - # 启动失败时重置状态 - self.running = False - self._loop_task = None - logger.error(f"{self.log_prefix} HeartFChatting 启动失败: {e}") - raise - - def _handle_loop_completion(self, task: asyncio.Task): - """当 _hfc_loop 任务完成时执行的回调。""" - try: - if exception := task.exception(): - logger.error(f"{self.log_prefix} HeartFChatting: 脱离了聊天(异常): {exception}") - logger.error(traceback.format_exc()) # Log full traceback for exceptions - else: - logger.info(f"{self.log_prefix} HeartFChatting: 脱离了聊天 (外部停止)") - except asyncio.CancelledError: - logger.info(f"{self.log_prefix} HeartFChatting: 结束了聊天") - - def start_cycle(self) -> Tuple[Dict[str, float], str]: - self._cycle_counter += 1 - self._current_cycle_detail = CycleDetail(self._cycle_counter) - self._current_cycle_detail.thinking_id = f"tid{str(round(time.time(), 2))}" - cycle_timers = {} - return cycle_timers, self._current_cycle_detail.thinking_id - - def end_cycle(self, loop_info, cycle_timers): - self._current_cycle_detail.set_loop_info(loop_info) - self.history_loop.append(self._current_cycle_detail) - self._current_cycle_detail.timers = cycle_timers - self._current_cycle_detail.end_time = time.time() - - def print_cycle_info(self, cycle_timers): - # 记录循环信息和计时器结果 - timer_strings = [] - for name, elapsed in cycle_timers.items(): - if elapsed < 0.1: - # 不显示小于0.1秒的计时器 - continue - formatted_time = f"{elapsed:.2f}秒" - timer_strings.append(f"{name}: {formatted_time}") - - logger.info( - f"{self.log_prefix} 第{self._current_cycle_detail.cycle_id}次思考," - f"耗时: {self._current_cycle_detail.end_time - self._current_cycle_detail.start_time:.1f}秒;" # type: ignore - + (f"详情: {'; '.join(timer_strings)}" if timer_strings else "") - ) - - async def _loopbody(self): - recent_messages_list = message_api.get_messages_by_time_in_chat( - chat_id=self.stream_id, - start_time=self.last_read_time, - end_time=time.time(), - limit=20, - limit_mode="latest", - filter_mai=True, - filter_command=False, - filter_intercept_message_level=0, - ) - - # 根据连续 no_reply 次数动态调整阈值 - # 3次 no_reply 时,阈值调高到 1.5(50%概率为1,50%概率为2) - # 5次 no_reply 时,提高到 2(大于等于两条消息的阈值) - if self.consecutive_no_reply_count >= 5: - threshold = 2 - elif self.consecutive_no_reply_count >= 3: - # 1.5 的含义:50%概率为1,50%概率为2 - threshold = 2 if random.random() < 0.5 else 1 - else: - threshold = 1 - - if len(recent_messages_list) >= threshold: - # for message in recent_messages_list: - # print(message.processed_plain_text) - - self.last_read_time = time.time() - - # !此处使at或者提及必定回复 - mentioned_message = None - for message in recent_messages_list: - if (message.is_mentioned or message.is_at) and global_config.chat.mentioned_bot_reply: - mentioned_message = message - - # logger.info(f"{self.log_prefix} 当前talk_value: {global_config.chat.get_talk_value(self.stream_id)}") - - # *控制频率用 - if mentioned_message: - await self._observe(recent_messages_list=recent_messages_list, force_reply_message=mentioned_message) - elif ( - random.random() - < global_config.chat.get_talk_value(self.stream_id) - * frequency_control_manager.get_or_create_frequency_control(self.stream_id).get_talk_frequency_adjust() - ): - await self._observe(recent_messages_list=recent_messages_list) - else: - # 没有提到,继续保持沉默,等待5秒防止频繁触发 - await asyncio.sleep(10) - return True - else: - await asyncio.sleep(0.2) - return True - return True - - async def _send_and_store_reply( - self, - response_set: "ReplySetModel", - action_message: "DatabaseMessages", - cycle_timers: Dict[str, float], - thinking_id, - actions, - selected_expressions: Optional[List[int]] = None, - quote_message: Optional[bool] = None, - ) -> Tuple[Dict[str, Any], str, Dict[str, float]]: - with Timer("回复发送", cycle_timers): - reply_text = await self._send_response( - reply_set=response_set, - message_data=action_message, - selected_expressions=selected_expressions, - quote_message=quote_message, - ) - - # 获取 platform,如果不存在则从 chat_stream 获取,如果还是 None 则使用默认值 - platform = action_message.chat_info.platform - if platform is None: - platform = getattr(self.chat_stream, "platform", "unknown") - - person = Person(platform=platform, user_id=action_message.user_info.user_id) - person_name = person.person_name - action_prompt_display = f"你对{person_name}进行了回复:{reply_text}" - - await database_api.store_action_info( - chat_stream=self.chat_stream, - action_build_into_prompt=False, - action_prompt_display=action_prompt_display, - action_done=True, - thinking_id=thinking_id, - action_data={"reply_text": reply_text}, - action_name="reply", - ) - - # 构建循环信息 - loop_info: Dict[str, Any] = { - "loop_plan_info": { - "action_result": actions, - }, - "loop_action_info": { - "action_taken": True, - "reply_text": reply_text, - "command": "", - "taken_time": time.time(), - }, - } - - return loop_info, reply_text, cycle_timers - - async def _observe( - self, # interest_value: float = 0.0, - recent_messages_list: Optional[List["DatabaseMessages"]] = None, - force_reply_message: Optional["DatabaseMessages"] = None, - ) -> bool: # sourcery skip: merge-else-if-into-elif, remove-redundant-if - if recent_messages_list is None: - recent_messages_list = [] - _reply_text = "" # 初始化reply_text变量,避免UnboundLocalError - - start_time = time.time() - async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()): - # 通过 MessageRecorder 统一提取消息并分发给 expression_learner 和 jargon_miner - # 在 replyer 执行时触发,统一管理时间窗口,避免重复获取消息 - asyncio.create_task(extract_and_distribute_messages(self.stream_id)) - - # 添加curious检测任务 - 检测聊天记录中的矛盾、冲突或需要提问的内容 - # asyncio.create_task(check_and_make_question(self.stream_id)) - # 添加聊天内容概括任务 - 累积、打包和压缩聊天记录 - # 注意:后台循环已在start()中启动,这里作为额外触发点,在有思考时立即处理 - # asyncio.create_task(self.chat_history_summarizer.process()) - - cycle_timers, thinking_id = self.start_cycle() - logger.info( - f"{self.log_prefix} 开始第{self._cycle_counter}次思考(频率: {global_config.chat.get_talk_value(self.stream_id)})" - ) - - # 第一步:动作检查 - available_actions: Dict[str, ActionInfo] = {} - try: - await self.action_modifier.modify_actions() - available_actions = self.action_manager.get_using_actions() - except Exception as e: - logger.error(f"{self.log_prefix} 动作修改失败: {e}") - - # 执行planner - is_group_chat, chat_target_info, _ = self.action_planner.get_necessary_info() - - message_list_before_now = get_raw_msg_before_timestamp_with_chat( - chat_id=self.stream_id, - timestamp=time.time(), - limit=int(global_config.chat.max_context_size * 0.6), - filter_intercept_message_level=1, - ) - chat_content_block, message_id_list = build_readable_messages_with_id( - messages=message_list_before_now, - timestamp_mode="normal_no_YMD", - read_mark=self.action_planner.last_obs_time_mark, - truncate=True, - show_actions=True, - ) - - prompt_info = await self.action_planner.build_planner_prompt( - is_group_chat=is_group_chat, - chat_target_info=chat_target_info, - current_available_actions=available_actions, - chat_content_block=chat_content_block, - message_id_list=message_id_list, - ) - continue_flag, modified_message = await events_manager.handle_mai_events( - EventType.ON_PLAN, None, prompt_info[0], None, self.chat_stream.stream_id - ) - if not continue_flag: - return False - if modified_message and modified_message._modify_flags.modify_llm_prompt: - prompt_info = (modified_message.llm_prompt, prompt_info[1]) - - with Timer("规划器", cycle_timers): - action_to_use_info = await self.action_planner.plan( - loop_start_time=self.last_read_time, - available_actions=available_actions, - force_reply_message=force_reply_message, - ) - - logger.info( - f"{self.log_prefix} 决定执行{len(action_to_use_info)}个动作: {' '.join([a.action_type for a in action_to_use_info])}" - ) - - # 3. 并行执行所有动作 - action_tasks = [ - asyncio.create_task( - self._execute_action(action, action_to_use_info, thinking_id, available_actions, cycle_timers) - ) - for action in action_to_use_info - ] - - # 并行执行所有任务 - results = await asyncio.gather(*action_tasks, return_exceptions=True) - - # 处理执行结果 - reply_loop_info = None - reply_text_from_reply = "" - action_success = False - action_reply_text = "" - - excute_result_str = "" - for result in results: - excute_result_str += f"{result['action_type']} 执行结果:{result['result']}\n" - - if isinstance(result, BaseException): - logger.error(f"{self.log_prefix} 动作执行异常: {result}") - continue - - if result["action_type"] != "reply": - action_success = result["success"] - action_reply_text = result["result"] - elif result["action_type"] == "reply": - if result["success"]: - reply_loop_info = result["loop_info"] - reply_text_from_reply = result["result"] - else: - logger.warning(f"{self.log_prefix} 回复动作执行失败") - - self.action_planner.add_plan_excute_log(result=excute_result_str) - - # 构建最终的循环信息 - if reply_loop_info: - # 如果有回复信息,使用回复的loop_info作为基础 - loop_info = reply_loop_info - # 更新动作执行信息 - loop_info["loop_action_info"].update( - { - "action_taken": action_success, - "taken_time": time.time(), - } - ) - _reply_text = reply_text_from_reply - else: - # 没有回复信息,构建纯动作的loop_info - loop_info = { - "loop_plan_info": { - "action_result": action_to_use_info, - }, - "loop_action_info": { - "action_taken": action_success, - "reply_text": action_reply_text, - "taken_time": time.time(), - }, - } - _reply_text = action_reply_text - - self.end_cycle(loop_info, cycle_timers) - self.print_cycle_info(cycle_timers) - - end_time = time.time() - if end_time - start_time < global_config.chat.planner_smooth: - wait_time = global_config.chat.planner_smooth - (end_time - start_time) - await asyncio.sleep(wait_time) - else: - await asyncio.sleep(0.1) - return True - - async def _main_chat_loop(self): - """主循环,持续进行计划并可能回复消息,直到被外部取消。""" - try: - while self.running: - # 主循环 - success = await self._loopbody() - await asyncio.sleep(0.1) - if not success: - break - except asyncio.CancelledError: - # 设置了关闭标志位后被取消是正常流程 - logger.info(f"{self.log_prefix} 麦麦已关闭聊天") - except Exception: - logger.error(f"{self.log_prefix} 麦麦聊天意外错误,将于3s后尝试重新启动") - print(traceback.format_exc()) - await asyncio.sleep(3) - self._loop_task = asyncio.create_task(self._main_chat_loop()) - logger.error(f"{self.log_prefix} 结束了当前聊天循环") - - async def _handle_action( - self, - action: str, - action_reasoning: str, - action_data: dict, - cycle_timers: Dict[str, float], - thinking_id: str, - action_message: Optional["DatabaseMessages"] = None, - ) -> tuple[bool, str, str]: - """ - 处理规划动作,使用动作工厂创建相应的动作处理器 - - 参数: - action: 动作类型 - action_reasoning: 决策理由 - action_data: 动作数据,包含不同动作需要的参数 - cycle_timers: 计时器字典 - thinking_id: 思考ID - action_message: 消息数据 - 返回: - tuple[bool, str, str]: (是否执行了动作, 思考消息ID, 命令) - """ - try: - # 使用工厂创建动作处理器实例 - try: - action_handler = self.action_manager.create_action( - action_name=action, - action_data=action_data, - cycle_timers=cycle_timers, - thinking_id=thinking_id, - chat_stream=self.chat_stream, - log_prefix=self.log_prefix, - action_reasoning=action_reasoning, - action_message=action_message, - ) - except Exception as e: - logger.error(f"{self.log_prefix} 创建动作处理器时出错: {e}") - traceback.print_exc() - return False, "" - - # 处理动作并获取结果(固定记录一次动作信息) - result = await action_handler.execute() - success, action_text = result - - return success, action_text - - except Exception as e: - logger.error(f"{self.log_prefix} 处理{action}时出错: {e}") - traceback.print_exc() - return False, "" - - async def _send_response( - self, - reply_set: "ReplySetModel", - message_data: "DatabaseMessages", - selected_expressions: Optional[List[int]] = None, - quote_message: Optional[bool] = None, - ) -> str: - # 根据 llm_quote 配置决定是否使用 quote_message 参数 - if global_config.chat.llm_quote: - # 如果配置为 true,使用 llm_quote 参数决定是否引用回复 - if quote_message is None: - logger.warning(f"{self.log_prefix} quote_message 参数为空,不引用") - need_reply = False - else: - need_reply = quote_message - if need_reply: - logger.info(f"{self.log_prefix} LLM 决定使用引用回复") - else: - # 如果配置为 false,使用原来的模式 - new_message_count = message_api.count_new_messages( - chat_id=self.chat_stream.stream_id, start_time=self.last_read_time, end_time=time.time() - ) - need_reply = new_message_count >= random.randint(2, 3) or time.time() - self.last_read_time > 90 - if need_reply: - logger.info(f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息,使用引用回复,或者上次回复时间超过90秒") - - reply_text = "" - first_replied = False - for reply_content in reply_set.reply_data: - if reply_content.content_type != ReplyContentType.TEXT: - continue - data: str = reply_content.content # type: ignore - if not first_replied: - await send_api.text_to_stream( - text=data, - stream_id=self.chat_stream.stream_id, - reply_message=message_data, - set_reply=need_reply, - typing=False, - selected_expressions=selected_expressions, - ) - first_replied = True - else: - await send_api.text_to_stream( - text=data, - stream_id=self.chat_stream.stream_id, - reply_message=message_data, - set_reply=False, - typing=True, - selected_expressions=selected_expressions, - ) - reply_text += data - - return reply_text - - async def _execute_action( - self, - action_planner_info: ActionPlannerInfo, - chosen_action_plan_infos: List[ActionPlannerInfo], - thinking_id: str, - available_actions: Dict[str, ActionInfo], - cycle_timers: Dict[str, float], - ): - """执行单个动作的通用函数""" - try: - with Timer(f"动作{action_planner_info.action_type}", cycle_timers): - # 直接当场执行no_reply逻辑 - if action_planner_info.action_type == "no_reply": - # 直接处理no_reply逻辑,不再通过动作系统 - reason = action_planner_info.reasoning or "选择不回复" - # logger.info(f"{self.log_prefix} 选择不回复,原因: {reason}") - - # 增加连续 no_reply 计数 - self.consecutive_no_reply_count += 1 - - await database_api.store_action_info( - chat_stream=self.chat_stream, - action_build_into_prompt=False, - action_prompt_display=reason, - action_done=True, - thinking_id=thinking_id, - action_data={}, - action_name="no_reply", - action_reasoning=reason, - ) - - return {"action_type": "no_reply", "success": True, "result": "选择不回复", "command": ""} - - elif action_planner_info.action_type == "reply": - # 直接当场执行reply逻辑 - self.questioned = False - # 刷新主动发言状态 - # 重置连续 no_reply 计数 - self.consecutive_no_reply_count = 0 - - reason = action_planner_info.reasoning or "" - # 根据 think_mode 配置决定 think_level 的值 - think_mode = global_config.chat.think_mode - if think_mode == "default": - think_level = 0 - elif think_mode == "deep": - think_level = 1 - elif think_mode == "dynamic": - # dynamic 模式:从 planner 返回的 action_data 中获取 - think_level = action_planner_info.action_data.get("think_level", 1) - else: - # 默认使用 default 模式 - think_level = 0 - # 使用 action_reasoning(planner 的整体思考理由)作为 reply_reason - planner_reasoning = action_planner_info.action_reasoning or reason - - record_replyer_action_temp( - chat_id=self.stream_id, - reason=reason, - think_level=think_level, - ) - - await database_api.store_action_info( - chat_stream=self.chat_stream, - action_build_into_prompt=False, - action_prompt_display=reason, - action_done=True, - thinking_id=thinking_id, - action_data={}, - action_name="reply", - action_reasoning=reason, - ) - - # 从 Planner 的 action_data 中提取未知词语列表(仅在 reply 时使用) - unknown_words = None - quote_message = None - if isinstance(action_planner_info.action_data, dict): - uw = action_planner_info.action_data.get("unknown_words") - if isinstance(uw, list): - cleaned_uw: List[str] = [] - for item in uw: - if isinstance(item, str): - s = item.strip() - if s: - cleaned_uw.append(s) - if cleaned_uw: - unknown_words = cleaned_uw - - # 从 Planner 的 action_data 中提取 quote_message 参数 - qm = action_planner_info.action_data.get("quote") - if qm is not None: - # 支持多种格式:true/false, "true"/"false", 1/0 - if isinstance(qm, bool): - quote_message = qm - elif isinstance(qm, str): - quote_message = qm.lower() in ("true", "1", "yes") - elif isinstance(qm, (int, float)): - quote_message = bool(qm) - - logger.info(f"{self.log_prefix} {qm}引用回复设置: {quote_message}") - - success, llm_response = await generator_api.generate_reply( - chat_stream=self.chat_stream, - reply_message=action_planner_info.action_message, - available_actions=available_actions, - chosen_actions=chosen_action_plan_infos, - reply_reason=planner_reasoning, - unknown_words=unknown_words, - enable_tool=global_config.tool.enable_tool, - request_type="replyer", - from_plugin=False, - reply_time_point=action_planner_info.action_data.get("loop_start_time", time.time()), - think_level=think_level, - ) - - if not success or not llm_response or not llm_response.reply_set: - if action_planner_info.action_message: - logger.info(f"对 {action_planner_info.action_message.processed_plain_text} 的回复生成失败") - else: - logger.info("回复生成失败") - return {"action_type": "reply", "success": False, "result": "回复生成失败", "loop_info": None} - - response_set = llm_response.reply_set - selected_expressions = llm_response.selected_expressions - loop_info, reply_text, _ = await self._send_and_store_reply( - response_set=response_set, - action_message=action_planner_info.action_message, # type: ignore - cycle_timers=cycle_timers, - thinking_id=thinking_id, - actions=chosen_action_plan_infos, - selected_expressions=selected_expressions, - quote_message=quote_message, - ) - self.last_active_time = time.time() - return { - "action_type": "reply", - "success": True, - "result": f"你使用reply动作,对' {action_planner_info.action_message.processed_plain_text} '这句话进行了回复,回复内容为: '{reply_text}'", - "loop_info": loop_info, - } - - else: - # 执行普通动作 - with Timer("动作执行", cycle_timers): - success, result = await self._handle_action( - action=action_planner_info.action_type, - action_reasoning=action_planner_info.action_reasoning or "", - action_data=action_planner_info.action_data or {}, - cycle_timers=cycle_timers, - thinking_id=thinking_id, - action_message=action_planner_info.action_message, - ) - - self.last_active_time = time.time() - return { - "action_type": action_planner_info.action_type, - "success": success, - "result": result, - } - - except Exception as e: - logger.error(f"{self.log_prefix} 执行动作时出错: {e}") - logger.error(f"{self.log_prefix} 错误信息: {traceback.format_exc()}") - return { - "action_type": action_planner_info.action_type, - "success": False, - "result": "", - "loop_info": None, - "error": str(e), - } diff --git a/src/chat/heart_flow/heartFC_chat.py b/src/chat/heart_flow/heartFC_chat.py index 2c1eb162..74d94773 100644 --- a/src/chat/heart_flow/heartFC_chat.py +++ b/src/chat/heart_flow/heartFC_chat.py @@ -1,377 +1,231 @@ -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +from rich.traceback import install +from typing import List, Optional, TYPE_CHECKING import asyncio import random import time import traceback -from rich.traceback import install - -from src.learners.expression_learner import ExpressionLearner -from src.learners.jargon_miner import JargonMiner -from src.chat.event_helpers import build_event_message -from src.chat.logger.plan_reply_logger import PlanReplyLogger -from src.chat.message_receive.chat_manager import BotChatSession -from src.chat.message_receive.chat_manager import chat_manager as _chat_manager -from src.chat.planner_actions.action_manager import ActionManager -from src.chat.planner_actions.action_modifier import ActionModifier -from src.chat.planner_actions.planner import ActionPlanner -from src.chat.utils.prompt_builder import global_prompt_manager -from src.chat.utils.timer_calculator import Timer -from src.chat.utils.utils import record_replyer_action_temp -from src.common.data_models.info_data_model import ActionPlannerInfo -from src.common.data_models.message_component_data_model import MessageSequence, TextComponent +from src.chat.message_receive.chat_manager import chat_manager from src.common.logger import get_logger from src.common.utils.utils_config import ChatConfigUtils, ExpressionConfigUtils from src.config.config import global_config from src.config.file_watcher import FileChange -from src.core.event_bus import event_bus -from src.core.types import ActionInfo, EventType -from src.person_info.person_info import Person -from src.services import ( - database_service as database_api, - generator_service as generator_api, - message_service as message_api, - send_service as send_api, -) -from src.services.message_service import build_readable_messages_with_id, get_messages_before_time_in_chat +from src.learners.expression_learner import ExpressionLearner +from src.learners.jargon_miner import JargonMiner from .heartFC_utils import CycleDetail if TYPE_CHECKING: from src.chat.message_receive.message import SessionMessage - install(extra_lines=5) logger = get_logger("heartFC_chat") class HeartFChatting: - """管理一个持续运行的 Focus Chat 会话。""" + """ + 管理一个连续的Focus Chat聊天会话 + 用于在特定的聊天会话里面生成回复 + """ def __init__(self, session_id: str): - self.session_id = session_id - self.chat_stream: BotChatSession = _chat_manager.get_session_by_session_id(self.session_id) # type: ignore[assignment] - if not self.chat_stream: - raise ValueError(f"无法找到聊天会话 {self.session_id}") + """ + 初始化 HeartFChatting 实例 - session_name = _chat_manager.get_session_name(session_id) or session_id + Args: + session_id: 聊天会话ID + """ + # 基础属性 + self.session_id = session_id + session_name = chat_manager.get_session_name(session_id) or session_id self.log_prefix = f"[{session_name}]" self.session_name = session_name - self.action_manager = ActionManager() - self.action_planner = ActionPlanner(chat_id=self.session_id, action_manager=self.action_manager) - self.action_modifier = ActionModifier(action_manager=self.action_manager, chat_id=self.session_id) - + # 系统运行状态 self._running: bool = False self._loop_task: Optional[asyncio.Task] = None + self._cycle_counter: int = 0 + self._hfc_lock: asyncio.Lock = asyncio.Lock() # 用于保护 _hfc_func 的并发访问 + # 聊天频率相关 + self._consecutive_no_reply_count = 0 # 跟踪连续 no_reply 次数,用于动态调整阈值 + self._talk_frequency_adjust: float = 1.0 # 发言频率修正值,默认为1.0,可以根据需要调整 + + # HFC内消息缓存 + self.message_cache: List[SessionMessage] = [] + + # Asyncio Event 用于控制循环的开始和结束 self._cycle_event = asyncio.Event() - self._hfc_lock = asyncio.Lock() - - self._cycle_counter = 0 - self._current_cycle_detail: Optional[CycleDetail] = None - self.history_loop: List[CycleDetail] = [] - - self.last_read_time = time.time() - 2 - self.last_active_time = time.time() - self._talk_frequency_adjust = 1.0 - self._consecutive_no_reply_count = 0 - - self.message_cache: List["SessionMessage"] = [] - - self._min_messages_for_extraction = 30 - self._min_extraction_interval = 60 - self._last_extraction_time = 0.0 + # 表达方式相关内容 + self._min_messages_for_extraction = 30 # 最少提取消息数 + self._min_extraction_interval = 60 # 最小提取时间间隔,单位为秒 + self._last_extraction_time: float = 0.0 # 上次提取的时间戳 expr_use, jargon_learn, expr_learn = ExpressionConfigUtils.get_expression_config_for_chat(session_id) - self._enable_expression_use = expr_use - self._enable_expression_learning = expr_learn - self._enable_jargon_learning = jargon_learn - self._expression_learner = ExpressionLearner(session_id) - self._jargon_miner = JargonMiner(session_id, session_name=session_name) + self._enable_expression_use = expr_use # 允许使用表达方式,但不一定启用学习 + self._enable_expression_learning = expr_learn # 允许学习表达方式 + self._enable_jargon_learning = jargon_learn # 允许学习黑话 + # 表达学习器 + self._expression_learner: ExpressionLearner = ExpressionLearner(session_id) + # 黑话挖掘器 + self._jargon_miner: JargonMiner = JargonMiner(session_id, session_name=session_name) + + # TODO: ChatSummarizer 聊天总结器重构 + + # ====== 公开方法 ====== async def start(self): + """启动 HeartFChatting 的主循环""" + # 先检查是否已经启动运行 if self._running: - logger.debug(f"{self.log_prefix} HeartFChatting 已在运行中") + logger.debug(f"{self.log_prefix} 已经在运行中,无需重复启动") return try: self._running = True - self._cycle_event.clear() + self._cycle_event.clear() # 确保事件初始状态为未设置 + self._loop_task = asyncio.create_task(self.main_loop()) self._loop_task.add_done_callback(self._handle_loop_completion) + logger.info(f"{self.log_prefix} HeartFChatting 启动完成") - except Exception as exc: - logger.error(f"{self.log_prefix} HeartFChatting 启动失败: {exc}", exc_info=True) - self._running = False - self._cycle_event.set() - self._loop_task = None + except Exception as e: + logger.error(f"{self.log_prefix} 启动 HeartFChatting 失败: {e}", exc_info=True) + self._running = False # 确保状态正确 + self._cycle_event.set() # 确保事件被设置,避免死锁 + self._loop_task = None # 确保任务引用被清理 raise async def stop(self): + """停止 HeartFChatting 的主循环""" if not self._running: - logger.debug(f"{self.log_prefix} HeartFChatting 已停止") + logger.debug(f"{self.log_prefix} HeartFChatting 已经停止,无需重复停止") return self._running = False - self._cycle_event.set() + self._cycle_event.set() # 触发事件,通知循环结束 if self._loop_task: - self._loop_task.cancel() + self._loop_task.cancel() # 取消主循环任务 try: - await self._loop_task + await self._loop_task # 等待任务完成 except asyncio.CancelledError: - logger.info(f"{self.log_prefix} HeartFChatting 主循环已取消") - except Exception as exc: - logger.error(f"{self.log_prefix} 停止 HeartFChatting 时发生错误: {exc}", exc_info=True) + logger.info(f"{self.log_prefix} HeartFChatting 主循环已成功取消") + except Exception as e: + logger.error(f"{self.log_prefix} 停止 HeartFChatting 时发生错误: {e}", exc_info=True) finally: - self._loop_task = None + self._loop_task = None # 确保任务引用被清理 logger.info(f"{self.log_prefix} HeartFChatting 已停止") def adjust_talk_frequency(self, new_value: float): + """调整发言频率的调整值 + + Args: + new_value: 新的修正值,必须为非负数。值越大,修正发言频率越高;值越小,修正发言频率越低。 + """ self._talk_frequency_adjust = max(0.0, new_value) async def register_message(self, message: "SessionMessage"): + """注册一条消息到 HeartFChatting 的缓存中,并检测其是否产生提及,决定是否唤醒聊天 + + Args: + message: 待注册的消息对象 + """ self.message_cache.append(message) - + # 先检查at必回复 if global_config.chat.inevitable_at_reply and message.is_at: - self.last_read_time = time.time() - async with self._hfc_lock: - await self._judge_and_response(mentioned_message=message, recent_messages_list=[message]) - return - + async with self._hfc_lock: # 确保与主循环逻辑的互斥访问 + await self._judge_and_response(message) + return # 直接返回,避免同一条消息被主循环再次处理 + # 再检查提及必回复 if global_config.chat.mentioned_bot_reply and message.is_mentioned: - self.last_read_time = time.time() - async with self._hfc_lock: - await self._judge_and_response(mentioned_message=message, recent_messages_list=[message]) + # 直接获取锁,确保一定一定触发回复逻辑,不受当前是否正在执行主循环的影响 + async with self._hfc_lock: # 确保与主循环逻辑的互斥访问 + await self._judge_and_response(message) return async def main_loop(self): try: while self._running and not self._cycle_event.is_set(): if not self._hfc_lock.locked(): - async with self._hfc_lock: + async with self._hfc_lock: # 确保主循环逻辑的互斥访问 await self._hfc_func() - await asyncio.sleep(0.1) + await asyncio.sleep(5) except asyncio.CancelledError: - logger.info(f"{self.log_prefix} HeartFChatting: 主循环被取消") - except Exception as exc: - logger.error(f"{self.log_prefix} HeartFChatting: 主循环异常: {exc}", exc_info=True) - await self.stop() + logger.info(f"{self.log_prefix} HeartFChatting: 主循环被取消,正在关闭") + except Exception as e: + logger.error(f"{self.log_prefix} 麦麦聊天意外错误: {e},将于3s后尝试重新启动") + await self.stop() # 确保状态正确 await asyncio.sleep(3) - await self.start() + await self.start() # 尝试重新启动 async def _config_callback(self, file_change: Optional[FileChange] = None): - del file_change - expr_use, jargon_learn, expr_learn = ExpressionConfigUtils.get_expression_config_for_chat(self.session_id) - self._enable_expression_use = expr_use - self._enable_expression_learning = expr_learn - self._enable_jargon_learning = jargon_learn + """配置文件变更回调函数""" + # TODO: 根据配置文件变动重新计算相关参数: + """ + 需要计算的参数: + self._enable_expression_use = expr_use # 允许使用表达方式,但不一定启用学习 + self._enable_expression_learning = expr_learn # 允许学习表达方式 + self._enable_jargon_learning = jargon_learn # 允许学习黑话 + """ - async def _hfc_func(self): - recent_messages_list = message_api.get_messages_by_time_in_chat( - chat_id=self.session_id, - start_time=self.last_read_time, - end_time=time.time(), - limit=20, - limit_mode="latest", - filter_mai=True, - filter_command=False, - filter_intercept_message_level=1, - ) + # ====== 心流聊天核心逻辑 ====== + async def _hfc_func(self, mentioned_message: Optional["SessionMessage"] = None): + """心流聊天的主循环逻辑""" + if self._consecutive_no_reply_count >= 5: + threshold = 2 + elif self._consecutive_no_reply_count >= 3: + threshold = 2 if random.random() < 0.5 else 1 + else: + threshold = 1 - if len(recent_messages_list) < 1: + if len(self.message_cache) < threshold: await asyncio.sleep(0.2) return True - self.last_read_time = time.time() - - mentioned_message: Optional["SessionMessage"] = None - for message in recent_messages_list: - if global_config.chat.inevitable_at_reply and message.is_at: - mentioned_message = message - elif global_config.chat.mentioned_bot_reply and message.is_mentioned: - mentioned_message = message - - talk_value = ChatConfigUtils.get_talk_value(self.session_id) * self._talk_frequency_adjust - if mentioned_message: - await self._judge_and_response(mentioned_message=mentioned_message, recent_messages_list=recent_messages_list) - elif random.random() < talk_value: - await self._judge_and_response(recent_messages_list=recent_messages_list) + talk_value_threshold = ( + random.random() * ChatConfigUtils.get_talk_value(self.session_id) * self._talk_frequency_adjust + ) + if mentioned_message and global_config.chat.mentioned_bot_reply: + await self._judge_and_response(mentioned_message) + elif random.random() < talk_value_threshold: + await self._judge_and_response() return True - async def _judge_and_response( - self, - mentioned_message: Optional["SessionMessage"] = None, - recent_messages_list: Optional[List["SessionMessage"]] = None, - ): - recent_messages = list(recent_messages_list or self.message_cache[-20:]) - if recent_messages: - asyncio.create_task(self._trigger_expression_learning(recent_messages)) - - cycle_timers, thinking_id = self._start_cycle() + async def _judge_and_response(self, mentioned_message: Optional["SessionMessage"] = None): + """判定和生成回复""" + asyncio.create_task(self._trigger_expression_learning(self.message_cache)) + # TODO: 完成反思器之后的逻辑 + start_time = time.time() + current_cycle_detail = self._start_cycle() logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考") - try: - async with global_prompt_manager.async_message_scope(self._get_template_name()): - available_actions: Dict[str, ActionInfo] = {} - try: - await self.action_modifier.modify_actions() - available_actions = self.action_manager.get_using_actions() - except Exception as exc: - logger.error(f"{self.log_prefix} 动作修改失败: {exc}", exc_info=True) + # TODO: 动作检查逻辑 + # TODO: Planner逻辑 + # TODO: 动作执行逻辑 - is_group_chat, chat_target_info, _ = self.action_planner.get_necessary_info() - message_list_before_now = get_messages_before_time_in_chat( - chat_id=self.session_id, - timestamp=time.time(), - limit=int(global_config.chat.max_context_size * 0.6), - filter_intercept_message_level=1, - ) - chat_content_block, message_id_list = build_readable_messages_with_id( - messages=message_list_before_now, - timestamp_mode="normal_no_YMD", - read_mark=self.action_planner.last_obs_time_mark, - truncate=True, - show_actions=True, - ) - - prompt, filtered_actions = await self._build_planner_prompt_with_event( - available_actions=available_actions, - is_group_chat=is_group_chat, - chat_target_info=chat_target_info, - chat_content_block=chat_content_block, - message_id_list=message_id_list, - ) - if prompt is None: - return False - - with Timer("规划器", cycle_timers): - reasoning, action_to_use_info, llm_raw_output, llm_reasoning, llm_duration_ms = ( - await self.action_planner._execute_main_planner( - prompt=prompt, - message_id_list=message_id_list, - filtered_actions=filtered_actions, - available_actions=available_actions, - loop_start_time=self.last_read_time, - ) - ) - - action_to_use_info = self._ensure_force_reply_action( - actions=action_to_use_info, - force_reply_message=mentioned_message, - available_actions=available_actions, - ) - self.action_planner.add_plan_log(reasoning, action_to_use_info) - self.action_planner.last_obs_time_mark = time.time() - self._log_plan( - prompt=prompt, - reasoning=reasoning, - llm_raw_output=llm_raw_output, - llm_reasoning=llm_reasoning, - llm_duration_ms=llm_duration_ms, - actions=action_to_use_info, - ) - - logger.info( - f"{self.log_prefix} 决定执行{len(action_to_use_info)}个动作: {' '.join([a.action_type for a in action_to_use_info])}" - ) - - action_tasks = [ - asyncio.create_task( - self._execute_action( - action, - action_to_use_info, - thinking_id, - available_actions, - cycle_timers, - ) - ) - for action in action_to_use_info - ] - results = await asyncio.gather(*action_tasks, return_exceptions=True) - - reply_loop_info = None - reply_text_from_reply = "" - action_success = False - action_reply_text = "" - execute_result_str = "" - - for result in results: - if isinstance(result, BaseException): - logger.error(f"{self.log_prefix} 动作执行异常: {result}", exc_info=True) - continue - - execute_result_str += f"{result['action_type']} 执行结果:{result['result']}\n" - if result["action_type"] == "reply": - if result["success"]: - reply_loop_info = result["loop_info"] - reply_text_from_reply = result["result"] - else: - logger.warning(f"{self.log_prefix} reply 动作执行失败") - else: - action_success = result["success"] - action_reply_text = result["result"] - - self.action_planner.add_plan_excute_log(result=execute_result_str) - - if reply_loop_info: - loop_info = reply_loop_info - loop_info["loop_action_info"].update( - { - "action_taken": action_success, - "taken_time": time.time(), - } - ) - else: - loop_info = { - "loop_plan_info": { - "action_result": action_to_use_info, - }, - "loop_action_info": { - "action_taken": action_success, - "reply_text": action_reply_text, - "taken_time": time.time(), - }, - } - reply_text_from_reply = action_reply_text - - current_cycle_detail = self._end_cycle(self._current_cycle_detail, loop_info) - logger.debug(f"{self.log_prefix} 本轮最终输出: {reply_text_from_reply}") - return current_cycle_detail is not None - except Exception as exc: - logger.error(f"{self.log_prefix} 判定与回复流程失败: {exc}", exc_info=True) - if self._current_cycle_detail: - self._end_cycle( - self._current_cycle_detail, - { - "loop_plan_info": {"action_result": []}, - "loop_action_info": { - "action_taken": False, - "reply_text": "", - "taken_time": time.time(), - "error": str(exc), - }, - }, - ) - return False + cycle_detail = self._end_cycle(current_cycle_detail) + if wait_time := global_config.chat.planner_smooth - (time.time() - start_time) > 0: + await asyncio.sleep(wait_time) + else: + await asyncio.sleep(0.1) # 最小等待时间,避免过快循环 + return True def _handle_loop_completion(self, task: asyncio.Task): + """当 _hfc_func 任务完成时执行的回调。""" try: if exception := task.exception(): - logger.error(f"{self.log_prefix} HeartFChatting: 主循环异常退出: {exception}") - logger.error(traceback.format_exc()) + logger.error(f"{self.log_prefix} HeartFChatting: 脱离了聊天(异常): {exception}") + logger.error(traceback.format_exc()) # Log full traceback for exceptions else: - logger.info(f"{self.log_prefix} HeartFChatting: 主循环已退出") + logger.info(f"{self.log_prefix} HeartFChatting: 脱离了聊天 (外部停止)") except asyncio.CancelledError: - logger.info(f"{self.log_prefix} HeartFChatting: 聊天已结束") + logger.info(f"{self.log_prefix} HeartFChatting: 结束了聊天") + # ====== 学习器触发逻辑 ====== async def _trigger_expression_learning(self, messages: List["SessionMessage"]): - if not messages: - return - self._expression_learner.add_messages(messages) if time.time() - self._last_extraction_time < self._min_extraction_interval: return @@ -379,14 +233,12 @@ class HeartFChatting: return if not self._enable_expression_learning: return - extraction_end_time = time.time() logger.info( f"聊天流 {self.session_name} 提取到 {len(messages)} 条消息," f"时间窗口: {self._last_extraction_time:.2f} - {extraction_end_time:.2f}" ) self._last_extraction_time = extraction_end_time - try: jargon_miner = self._jargon_miner if self._enable_jargon_learning else None learnt_style = await self._expression_learner.learn(jargon_miner) @@ -394,398 +246,43 @@ class HeartFChatting: logger.info(f"{self.log_prefix} 表达学习完成") else: logger.debug(f"{self.log_prefix} 表达学习未获得有效结果") - except Exception as exc: - logger.error(f"{self.log_prefix} 表达学习失败: {exc}", exc_info=True) + except Exception as e: + logger.error(f"{self.log_prefix} 表达学习失败: {e}", exc_info=True) - def _start_cycle(self) -> Tuple[Dict[str, float], str]: + # ====== 记录循环执行信息相关逻辑 ====== + def _start_cycle(self) -> CycleDetail: self._cycle_counter += 1 - self._current_cycle_detail = CycleDetail(cycle_id=self._cycle_counter) - self._current_cycle_detail.thinking_id = f"tid{str(round(time.time(), 2))}" - return self._current_cycle_detail.time_records, self._current_cycle_detail.thinking_id + current_cycle_detail = CycleDetail(cycle_id=self._cycle_counter) + current_cycle_detail.thinking_id = f"tid{str(round(time.time(), 2))}" + return current_cycle_detail - def _end_cycle(self, cycle_detail: Optional[CycleDetail], loop_info: Optional[Dict[str, Any]] = None): - if cycle_detail is None: - return None - - cycle_detail.loop_plan_info = (loop_info or {}).get("loop_plan_info") - cycle_detail.loop_action_info = (loop_info or {}).get("loop_action_info") + def _end_cycle(self, cycle_detail: CycleDetail, only_long_execution: bool = True): cycle_detail.end_time = time.time() - self.history_loop.append(cycle_detail) - - timer_strings = [ + timer_strings: List[str] = [ f"{name}: {duration:.2f}s" for name, duration in cycle_detail.time_records.items() - if duration >= 0.1 + if not only_long_execution or duration >= 0.1 ] logger.info( - f"{self.log_prefix} 第{cycle_detail.cycle_id} 个心流循环完成," - f"耗时: {cycle_detail.end_time - cycle_detail.start_time:.2f}s;" + f"{self.log_prefix} 第 {cycle_detail.cycle_id} 个心流循环完成" + f"耗时: {cycle_detail.end_time - cycle_detail.start_time:.2f}秒\n" f"详细计时: {', '.join(timer_strings) if timer_strings else '无'}" ) + return cycle_detail - async def _execute_action( - self, - action_planner_info: ActionPlannerInfo, - chosen_action_plan_infos: List[ActionPlannerInfo], - thinking_id: str, - available_actions: Dict[str, ActionInfo], - cycle_timers: Dict[str, float], - ): - try: - with Timer(f"动作{action_planner_info.action_type}", cycle_timers): - if action_planner_info.action_type == "no_reply": - reason = action_planner_info.reasoning or "选择不回复" - self._consecutive_no_reply_count += 1 - await database_api.store_action_info( - chat_stream=self.chat_stream, - display_prompt=reason, - thinking_id=thinking_id, - action_data={}, - action_name="no_reply", - action_reasoning=reason, - ) - return { - "action_type": "no_reply", - "success": True, - "result": "选择不回复", - "loop_info": None, - } + # ====== Action相关逻辑 ====== + async def _execute_action(self, *args, **kwargs): + """原ExecuteAction""" + raise NotImplementedError("执行动作的逻辑尚未实现") # TODO: 实现动作执行的逻辑,替换掉*args, **kwargs*占位符 - if action_planner_info.action_type == "reply": - self._consecutive_no_reply_count = 0 - reason = action_planner_info.reasoning or "" - think_level = self._get_think_level(action_planner_info) - planner_reasoning = action_planner_info.action_reasoning or reason + async def _execute_other_actions(self, *args, **kwargs): + """原HandleAction""" + raise NotImplementedError( + "执行其他动作的逻辑尚未实现" + ) # TODO: 实现其他动作执行的逻辑, 替换掉*args, **kwargs*占位符 - record_replyer_action_temp( - chat_id=self.session_id, - reason=reason, - think_level=think_level, - ) - await database_api.store_action_info( - chat_stream=self.chat_stream, - display_prompt=reason, - thinking_id=thinking_id, - action_data={}, - action_name="reply", - action_reasoning=reason, - ) - - unknown_words, quote_message = self._extract_reply_metadata(action_planner_info) - success, llm_response = await generator_api.generate_reply( - chat_stream=self.chat_stream, - reply_message=action_planner_info.action_message, - available_actions=available_actions, - chosen_actions=chosen_action_plan_infos, - reply_reason=planner_reasoning, - unknown_words=unknown_words, - enable_tool=global_config.tool.enable_tool, - request_type="replyer", - from_plugin=False, - reply_time_point=action_planner_info.action_data.get("loop_start_time", time.time()) - if action_planner_info.action_data - else time.time(), - think_level=think_level, - ) - if not success or not llm_response or not llm_response.reply_set: - if action_planner_info.action_message: - logger.info( - f"对 {action_planner_info.action_message.processed_plain_text} 的回复生成失败" - ) - else: - logger.info(f"{self.log_prefix} 回复生成失败") - return { - "action_type": "reply", - "success": False, - "result": "回复生成失败", - "loop_info": None, - } - - loop_info, reply_text, _ = await self._send_and_store_reply( - response_set=llm_response.reply_set, - action_message=action_planner_info.action_message, # type: ignore[arg-type] - cycle_timers=cycle_timers, - thinking_id=thinking_id, - actions=chosen_action_plan_infos, - selected_expressions=llm_response.selected_expressions, - quote_message=quote_message, - ) - self.last_active_time = time.time() - return { - "action_type": "reply", - "success": True, - "result": reply_text, - "loop_info": loop_info, - } - - with Timer("动作执行", cycle_timers): - success, result = await self._handle_action( - action=action_planner_info.action_type, - action_reasoning=action_planner_info.action_reasoning or "", - action_data=action_planner_info.action_data or {}, - cycle_timers=cycle_timers, - thinking_id=thinking_id, - action_message=action_planner_info.action_message, - ) - if success: - self.last_active_time = time.time() - return { - "action_type": action_planner_info.action_type, - "success": success, - "result": result, - "loop_info": None, - } - except Exception as exc: - logger.error(f"{self.log_prefix} 执行动作时出错: {exc}", exc_info=True) - return { - "action_type": action_planner_info.action_type, - "success": False, - "result": "", - "loop_info": None, - "error": str(exc), - } - - async def _handle_action( - self, - action: str, - action_reasoning: str, - action_data: dict, - cycle_timers: Dict[str, float], - thinking_id: str, - action_message: Optional["SessionMessage"] = None, - ) -> Tuple[bool, str]: - try: - action_handler = self.action_manager.create_action( - action_name=action, - action_data=action_data, - action_reasoning=action_reasoning, - cycle_timers=cycle_timers, - thinking_id=thinking_id, - chat_stream=self.chat_stream, - log_prefix=self.log_prefix, - action_message=action_message, - ) - if not action_handler: - logger.warning(f"{self.log_prefix} 未能创建动作处理器: {action}") - return False, "" - - success, action_text = await action_handler.execute() - return success, action_text - except Exception as exc: - logger.error(f"{self.log_prefix} 处理动作 {action} 时出错: {exc}", exc_info=True) - return False, "" - - async def _send_and_store_reply( - self, - response_set: MessageSequence, - action_message: "SessionMessage", - cycle_timers: Dict[str, float], - thinking_id: str, - actions: List[ActionPlannerInfo], - selected_expressions: Optional[List[int]] = None, - quote_message: Optional[bool] = None, - ) -> Tuple[Dict[str, Any], str, Dict[str, float]]: - with Timer("回复发送", cycle_timers): - reply_text = await self._send_response( - reply_set=response_set, - message_data=action_message, - selected_expressions=selected_expressions, - quote_message=quote_message, - ) - - platform = action_message.platform or getattr(self.chat_stream, "platform", "unknown") - person = Person(platform=platform, user_id=action_message.message_info.user_info.user_id) - action_prompt_display = f"你对{person.person_name}进行了回复:{reply_text}" - await database_api.store_action_info( - chat_stream=self.chat_stream, - display_prompt=action_prompt_display, - thinking_id=thinking_id, - action_data={"reply_text": reply_text}, - action_name="reply", - ) - - loop_info: Dict[str, Any] = { - "loop_plan_info": { - "action_result": actions, - }, - "loop_action_info": { - "action_taken": True, - "reply_text": reply_text, - "command": "", - "taken_time": time.time(), - }, - } - return loop_info, reply_text, cycle_timers - - async def _send_response( - self, - reply_set: MessageSequence, - message_data: "SessionMessage", - selected_expressions: Optional[List[int]] = None, - quote_message: Optional[bool] = None, - ) -> str: - if global_config.chat.llm_quote: - need_reply = bool(quote_message) - else: - new_message_count = message_api.count_new_messages( - chat_id=self.session_id, - start_time=self.last_read_time, - end_time=time.time(), - ) - need_reply = new_message_count >= random.randint(2, 3) or time.time() - self.last_read_time > 90 - - reply_text = "" - first_replied = False - for component in reply_set.components: - if not isinstance(component, TextComponent): - continue - data = component.text - if not first_replied: - await send_api.text_to_stream( - text=data, - stream_id=self.session_id, - reply_message=message_data, - set_reply=need_reply, - typing=False, - selected_expressions=selected_expressions, - ) - first_replied = True - else: - await send_api.text_to_stream( - text=data, - stream_id=self.session_id, - reply_message=message_data, - set_reply=False, - typing=True, - selected_expressions=selected_expressions, - ) - reply_text += data - return reply_text - - async def _build_planner_prompt_with_event( - self, - available_actions: Dict[str, ActionInfo], - is_group_chat: bool, - chat_target_info: Any, - chat_content_block: str, - message_id_list: List[Tuple[str, "SessionMessage"]], - ) -> Tuple[Optional[str], Dict[str, ActionInfo]]: - filtered_actions = self.action_planner._filter_actions_by_activation_type(available_actions, chat_content_block) - prompt, _ = await self.action_planner.build_planner_prompt( - is_group_chat=is_group_chat, - chat_target_info=chat_target_info, - current_available_actions=filtered_actions, - chat_content_block=chat_content_block, - message_id_list=message_id_list, - ) - event_message = build_event_message(EventType.ON_PLAN, llm_prompt=prompt, stream_id=self.session_id) - continue_flag, modified_message = await event_bus.emit(EventType.ON_PLAN, event_message) - if not continue_flag: - logger.info(f"{self.log_prefix} ON_PLAN 事件中止了本轮 HFC") - return None, filtered_actions - if modified_message and modified_message._modify_flags.modify_llm_prompt and modified_message.llm_prompt: - prompt = modified_message.llm_prompt - return prompt, filtered_actions - - def _ensure_force_reply_action( - self, - actions: List[ActionPlannerInfo], - force_reply_message: Optional["SessionMessage"], - available_actions: Dict[str, ActionInfo], - ) -> List[ActionPlannerInfo]: - if not force_reply_message: - return actions - - has_reply_to_force_message = any( - action.action_type == "reply" - and action.action_message - and action.action_message.message_id == force_reply_message.message_id - for action in actions - ) - if has_reply_to_force_message: - return actions - - actions = [action for action in actions if action.action_type != "no_reply"] - actions.insert( - 0, - ActionPlannerInfo( - action_type="reply", - reasoning="用户提及了我,必须回复该消息", - action_data={"loop_start_time": self.last_read_time}, - action_message=force_reply_message, - available_actions=available_actions, - action_reasoning=None, - ), - ) - logger.info(f"{self.log_prefix} 检测到强制回复消息,已补充 reply 动作") - return actions - - def _log_plan( - self, - prompt: str, - reasoning: str, - llm_raw_output: Optional[str], - llm_reasoning: Optional[str], - llm_duration_ms: Optional[float], - actions: List[ActionPlannerInfo], - ) -> None: - try: - PlanReplyLogger.log_plan( - chat_id=self.session_id, - prompt=prompt, - reasoning=reasoning, - raw_output=llm_raw_output, - raw_reasoning=llm_reasoning, - actions=actions, - timing={ - "llm_duration_ms": round(llm_duration_ms, 2) if llm_duration_ms is not None else None, - "loop_start_time": self.last_read_time, - }, - extra=None, - ) - except Exception: - logger.exception(f"{self.log_prefix} 记录 plan 日志失败") - - def _extract_reply_metadata( - self, - action_planner_info: ActionPlannerInfo, - ) -> Tuple[Optional[List[str]], Optional[bool]]: - unknown_words: Optional[List[str]] = None - quote_message: Optional[bool] = None - action_data = action_planner_info.action_data or {} - - raw_unknown_words = action_data.get("unknown_words") - if isinstance(raw_unknown_words, list): - cleaned_unknown_words = [] - for item in raw_unknown_words: - if isinstance(item, str) and (cleaned_item := item.strip()): - cleaned_unknown_words.append(cleaned_item) - if cleaned_unknown_words: - unknown_words = cleaned_unknown_words - - raw_quote = action_data.get("quote") - if isinstance(raw_quote, bool): - quote_message = raw_quote - elif isinstance(raw_quote, str): - quote_message = raw_quote.lower() in {"true", "1", "yes"} - elif isinstance(raw_quote, (int, float)): - quote_message = bool(raw_quote) - - return unknown_words, quote_message - - def _get_think_level(self, action_planner_info: ActionPlannerInfo) -> int: - think_mode = global_config.chat.think_mode - if think_mode == "default": - return 0 - if think_mode == "deep": - return 1 - if think_mode == "dynamic": - action_data = action_planner_info.action_data or {} - return int(action_data.get("think_level", 1)) - return 0 - - def _get_template_name(self) -> Optional[str]: - if self.chat_stream.context: - return self.chat_stream.context.template_name - return None + # ====== 响应发送相关方法 ====== + async def _send_response(self, *args, **kwargs): + raise NotImplementedError("发送回复的逻辑尚未实现") # TODO: 实现发送回复的逻辑,替换掉*args, **kwargs*占位符 + # 传入的消息至少应该是个MessageSequence实例,最好是SessionMessage实例,随后可直接转化为MessageSending实例 diff --git a/src/chat/heart_flow/heartflow.py b/src/chat/heart_flow/heartflow.py deleted file mode 100644 index febff2d5..00000000 --- a/src/chat/heart_flow/heartflow.py +++ /dev/null @@ -1,42 +0,0 @@ -import traceback -from typing import Any, Optional, Dict - -from src.chat.message_receive.chat_stream import get_chat_manager -from src.common.logger import get_logger -from src.chat.heart_flow.heartFC_chat import HeartFChatting -from src.chat.brain_chat.brain_chat import BrainChatting -from src.chat.message_receive.chat_stream import ChatStream - -logger = get_logger("heartflow") - - -class Heartflow: - """主心流协调器,负责初始化并协调聊天""" - - def __init__(self): - self.heartflow_chat_list: Dict[Any, HeartFChatting | BrainChatting] = {} - - async def get_or_create_heartflow_chat(self, chat_id: Any) -> Optional[HeartFChatting | BrainChatting]: - """获取或创建一个新的HeartFChatting实例""" - try: - if chat_id in self.heartflow_chat_list: - if chat := self.heartflow_chat_list.get(chat_id): - return chat - else: - chat_stream: ChatStream | None = get_chat_manager().get_stream(chat_id) - if not chat_stream: - raise ValueError(f"未找到 chat_id={chat_id} 的聊天流") - if chat_stream.group_info: - new_chat = HeartFChatting(chat_id=chat_id) - else: - new_chat = BrainChatting(chat_id=chat_id) - await new_chat.start() - self.heartflow_chat_list[chat_id] = new_chat - return new_chat - except Exception as e: - logger.error(f"创建心流聊天 {chat_id} 失败: {e}", exc_info=True) - traceback.print_exc() - return None - - -heartflow = Heartflow()