merge: sync upstream/r-dev and resolve real conflicts
This commit is contained in:
@@ -17,6 +17,7 @@
|
||||
1. 尽量保持良好的注释
|
||||
2. 如果原来的代码中有注释,则重构的时候,除非这部分代码被删除,否则相同功能的代码应该保留注释(可以对注释进行修改以保持准确性,但不应该删除注释)。
|
||||
3. 如果原来的代码中没有注释,则重构的时候,如果某个功能块的代码较长或者逻辑较为复杂,则应该添加注释来解释这部分代码的功能和逻辑。
|
||||
4. 对于类,方法以及模块的注释,首选使用的注释格式为 Google DocStr 格式,但保证语言为简体中文
|
||||
## 类型注解规范
|
||||
1. 重构代码时,如果原来的代码中有类型注解,则相同功能的代码应该保留类型注解(可以对类型注解进行修改以保持准确性,但不应该删除类型注解)。
|
||||
2. 重构代码时,如果原来的代码中没有类型注解,则重构的时候,如果某个函数的功能较为复杂或者参数较多,则应该添加类型注解来提高代码的可读性和可维护性。(对于简单的变量,可以不添加类型注解)
|
||||
@@ -35,3 +36,7 @@
|
||||
# 运行/调试/构建/测试/依赖
|
||||
优先使用uv
|
||||
依赖项以 pyproject.toml 为准
|
||||
|
||||
# 语言规范
|
||||
|
||||
项目的首选语言为简体中文,无论是注释语言,日志展示语言,还是 WebUI 展示语言都应该首要以简体中文为首要实现目标
|
||||
|
||||
137
README.md
137
README.md
@@ -1,21 +1,21 @@
|
||||
<div align="center">
|
||||
|
||||
<!-- Language Switcher -->
|
||||
<a href="README.md">简体中文</a> | <a href="docs/README_EN.md">English</a>
|
||||
<a href="docs/README_CN.md">简体中文</a> | <a href="README.md">English</a>
|
||||
|
||||
<br>
|
||||
<br>
|
||||
|
||||
<h1>麦麦 MaiBot <sub><small>MaiSaka</small></sub></h1>
|
||||
<h1>MaiBot <sub><small>MaiSaka</small></sub></h1>
|
||||
|
||||
<!-- Badges Row -->
|
||||
<p>
|
||||
<img src="https://img.shields.io/badge/Python-3.10+-blue" alt="Python Version">
|
||||
<img src="https://img.shields.io/github/license/Mai-with-u/MaiBot?label=%E5%8D%8F%E8%AE%AE" alt="License">
|
||||
<img src="https://img.shields.io/badge/状态-开发中-yellow" alt="Status">
|
||||
<img src="https://img.shields.io/github/contributors/Mai-with-u/MaiBot.svg?style=flat&label=%E8%B4%A1%E7%8C%AE%E8%80%85" alt="Contributors">
|
||||
<img src="https://img.shields.io/github/forks/Mai-with-u/MaiBot.svg?style=flat&label=%E5%88%86%E6%94%AF%E6%95%B0" alt="Forks">
|
||||
<img src="https://img.shields.io/github/stars/Mai-with-u/MaiBot?style=flat&label=%E6%98%9F%E6%A0%87%E6%95%B0" alt="Stars">
|
||||
<img src="https://img.shields.io/github/license/Mai-with-u/MaiBot?label=License" alt="License">
|
||||
<img src="https://img.shields.io/badge/Status-In%20Development-yellow" alt="Status">
|
||||
<img src="https://img.shields.io/github/contributors/Mai-with-u/MaiBot.svg?style=flat&label=Contributors" alt="Contributors">
|
||||
<img src="https://img.shields.io/github/forks/Mai-with-u/MaiBot.svg?style=flat&label=Forks" alt="Forks">
|
||||
<img src="https://img.shields.io/github/stars/Mai-with-u/MaiBot?style=flat&label=Stars" alt="Stars">
|
||||
<a href="https://deepwiki.com/DrSmoothl/MaiBot"><img src="https://deepwiki.com/badge.svg" alt="Ask DeepWiki"></a>
|
||||
</p>
|
||||
</div>
|
||||
@@ -25,31 +25,24 @@
|
||||
<!-- Mascot on the Right (Float) -->
|
||||
<img src="depends-data/maimai-v2.png" align="right" width="40%" alt="MaiBot Character" style="margin-left: 20px; margin-bottom: 20px;">
|
||||
|
||||
## 介绍
|
||||
## Introduction
|
||||
|
||||
麦麦MaiSaka 是一个基于大语言模型的可交互智能体
|
||||
MaiSaka is an interactive agent based on large language models.
|
||||
|
||||
MaiSaka 不仅仅是一个机器人,不仅仅是一个可以帮你完成任务的“有帮助的助手”,她还是一个致力于了解你,并以真实人类的风格进行交互的数字生命,她不追求完美,她不追求高效,但追求亲切和真实。
|
||||
MaiSaka is more than just a bot, and more than a "helpful assistant" that completes tasks. She is a digital life form that tries to understand you and interact in a genuinely human style. She does not pursue perfection or efficiency above all else. She pursues warmth and authenticity.
|
||||
|
||||
- 💭 **No one likes GPT-sounding dialogue**: MaiSaka uses a more natural conversational style. Instead of long-winded markdown-heavy replies, she chats in a way that feels casual, varied, and human.
|
||||
- 🎭 **No longer stuck in rigid Q&A**: She knows when to speak, how to read the room, when to join a conversation, and when to stay quiet.
|
||||
- 🧠 **MaiSaka becoming human**: In group conversations, MaiSaka imitates how people around her speak, learns new slang and in-group language, and keeps evolving.
|
||||
- ❤️ **Always learning more about you**: Inspired by personality theory in psychology, MaiSaka gradually builds an understanding of your preferences, traits, habits, and behavior style.
|
||||
- 🔌 **Plugin system**: Provides powerful APIs and an event system with virtually unlimited room for extension.
|
||||
|
||||
- 💭 **没有人喜欢GPT的语言风格**:麦麦使用了更加自然,贴合人类对话习惯的交互方式,不是长篇大论或者markdown格式的分点,而是或长或短的闲谈。
|
||||
|
||||
- 🎭 **不再是傻乎乎的一问一答**:懂得在合适的时间说话,把握聊天中的气氛,在合适的时候开口,在合适的时候闭嘴。
|
||||
|
||||
- 🧠 **麦麦·成为人类**:在多人对话中,麦麦会模仿其他人的的说话风格,还会自主理解新词或者小圈子里的黑话,不断进化。
|
||||
|
||||
- ❤️ **永远都在更加了解你**:基于心理学中人格理论,麦麦会不断积累对于你的了解,不论是你的信息,喜恶或是行为风格,她都记在心里。
|
||||
|
||||
- 🔌 **插件系统**:提供强大的 API 和事件系统,无限扩展可能。
|
||||
|
||||
|
||||
|
||||
### 快速导航
|
||||
### Quick Navigation
|
||||
<p>
|
||||
<a href="https://www.bilibili.com/video/BV1amAneGE3P">🌟 演示视频</a> |
|
||||
<a href="#-更新和安装">📦 快速入门</a> |
|
||||
<a href="#-部署教程">📃 核心文档</a> |
|
||||
<a href="#-讨论与社区">💬 加入社区</a>
|
||||
<a href="https://www.bilibili.com/video/BV1amAneGE3P">🌟 Demo Video</a> |
|
||||
<a href="#-updates-and-installation">📦 Quick Start</a> |
|
||||
<a href="#-deployment-guide">📃 Core Documentation</a> |
|
||||
<a href="#-discussion-and-community">💬 Join Community</a>
|
||||
</p>
|
||||
|
||||
<!-- Clear float to ensure subsequent content starts below the image area if text is short -->
|
||||
@@ -60,103 +53,103 @@ MaiSaka 不仅仅是一个机器人,不仅仅是一个可以帮你完成任务
|
||||
<a href="https://www.bilibili.com/video/BV1amAneGE3P" target="_blank">
|
||||
<picture>
|
||||
<source media="(max-width: 600px)" srcset="depends-data/video.png" width="100%">
|
||||
<img src="depends-data/video.png" width="60%" alt="麦麦演示视频" style="border-radius: 10px; box-shadow: 0 4px 8px rgba(0,0,0,0.1);">
|
||||
<img src="depends-data/video.png" width="60%" alt="MaiSaka Demo Video" style="border-radius: 10px; box-shadow: 0 4px 8px rgba(0,0,0,0.1);">
|
||||
</picture>
|
||||
<br>
|
||||
<small>前往观看麦麦演示视频</small>
|
||||
<small>Watch the MaiSaka demo video</small>
|
||||
</a>
|
||||
</div>
|
||||
|
||||
---
|
||||
|
||||
## 🔥 更新和安装
|
||||
## 🔥 Updates and Installation
|
||||
|
||||
> **最新版本: v1.0.0** ([📄 更新日志](changelogs/changelog.md))
|
||||
> **Latest Version: v1.0.0** ([📄 Changelog](changelogs/changelog.md))
|
||||
|
||||
- **下载**: 前往 [Release](https://github.com/MaiM-with-u/MaiBot/releases/) 页面下载最新版本
|
||||
- **启动器**: [Mailauncher](https://github.com/MaiM-with-u/mailauncher/releases/) (仅支持 MacOS, 早期开发中)
|
||||
- **Download**: Visit the [Release](https://github.com/MaiM-with-u/MaiBot/releases/) page to get the latest version.
|
||||
- **Launcher**: [Mailauncher](https://github.com/MaiM-with-u/mailauncher/releases/) (MacOS only, still in early development).
|
||||
|
||||
| 分支 | 说明 |
|
||||
| Branch | Description |
|
||||
| :--- | :--- |
|
||||
| `main` | ✅ **稳定发布版本 (推荐)** |
|
||||
| `dev` | 🚧 开发测试版本,包含新功能,可能不稳定 |
|
||||
| `main` | ✅ **Stable release (recommended)** |
|
||||
| `dev` | 🚧 Development testing branch with new features, may be unstable |
|
||||
|
||||
### 📚 部署教程
|
||||
👉 **[🚀 最新版本部署教程](https://docs.mai-mai.org/manual/deployment/mmc_deploy_windows.html)**
|
||||
### 📚 Deployment Guide
|
||||
👉 **[🚀 Latest Deployment Guide](https://docs.mai-mai.org/manual/deployment/mmc_deploy_windows.html)**
|
||||
|
||||
---
|
||||
|
||||
## 💬 讨论与社区
|
||||
## 💬 Discussion and Community
|
||||
|
||||
我们欢迎所有对 MaiBot 感兴趣的朋友加入!
|
||||
We welcome everyone interested in MaiBot to join us.
|
||||
|
||||
| 类别 | 群组 | 说明 |
|
||||
| Category | Group | Description |
|
||||
| :--- | :--- | :--- |
|
||||
| **技术交流** | [麦麦脑电图](https://qm.qq.com/q/RzmCiRtHEW) | 技术交流/答疑 |
|
||||
| **技术交流** | [麦麦大脑磁共振](https://qm.qq.com/q/VQ3XZrWgMs) | 技术交流/答疑 |
|
||||
| **技术交流** | [麦麦要当VTB](https://qm.qq.com/q/wGePTl1UyY) | 技术交流/答疑 |
|
||||
| **闲聊吹水** | [麦麦之闲聊群](https://qm.qq.com/q/JxvHZnxyec) | 仅限闲聊,不答疑 |
|
||||
| **插件开发** | [插件开发群](https://qm.qq.com/q/1036092828) | 进阶开发与测试 |
|
||||
| **Technical** | [MaiBrain EEG](https://qm.qq.com/q/RzmCiRtHEW) | Technical discussion / Q&A |
|
||||
| **Technical** | [MaiBrain MRI](https://qm.qq.com/q/VQ3XZrWgMs) | Technical discussion / Q&A |
|
||||
| **Technical** | [Mai Wants to Be a VTuber](https://qm.qq.com/q/wGePTl1UyY) | Technical discussion / Q&A |
|
||||
| **Casual Chat** | [Mai Casual Chat Group](https://qm.qq.com/q/JxvHZnxyec) | Casual chat only, no support |
|
||||
| **Plugin Development** | [Plugin Dev Group](https://qm.qq.com/q/1036092828) | Advanced development and testing |
|
||||
|
||||
---
|
||||
|
||||
## 📚 文档
|
||||
## 📚 Documentation
|
||||
|
||||
> [!NOTE]
|
||||
> 部分内容可能更新不够及时,请注意版本对应。
|
||||
> Some content may not be updated promptly, so please pay attention to version compatibility.
|
||||
|
||||
- **[📚 核心 Wiki 文档](https://docs.mai-mai.org)**: 最全面的文档中心,了解麦麦的一切。
|
||||
- **[📚 Core Wiki Documentation](https://docs.mai-mai.org)**: The most comprehensive documentation hub for everything about MaiSaka.
|
||||
|
||||
### 🧩 衍生项目
|
||||
### 🧩 Related Projects
|
||||
|
||||
- **[Amaidesu](https://github.com/MaiM-with-u/Amaidesu)**: 让麦麦在B站开播
|
||||
- **[MoFox_Bot](https://github.com/MoFox-Studio/MoFox-Core)**: 基于 MaiCore 0.10.0 的增强型 Fork,更稳定更有趣。
|
||||
- **[MaiCraft](https://github.com/MaiM-with-u/Maicraft)**: 让麦麦陪你玩 Minecraft (暂时停止维护中)。
|
||||
- **[Amaidesu](https://github.com/MaiM-with-u/Amaidesu)**: Let MaiSaka stream on Bilibili.
|
||||
- **[MoFox_Bot](https://github.com/MoFox-Studio/MoFox-Core)**: An enhanced fork based on MaiCore 0.10.0, with improved stability and more fun features.
|
||||
- **[MaiCraft](https://github.com/MaiM-with-u/Maicraft)**: Let MaiSaka accompany you in Minecraft (currently paused).
|
||||
|
||||
---
|
||||
|
||||
## 💡 设计理念
|
||||
## 💡 Design Philosophy
|
||||
|
||||
> **千石可乐说:**
|
||||
> - 这个项目最初只是为了给牛牛 bot 添加一点额外的功能,但是功能越写越多,最后决定重写。其目的是为了创造一个活跃在 QQ 群聊的"生命体"。目的并不是为了写一个功能齐全的机器人,而是一个尽可能让人感知到真实的类人存在。
|
||||
> - 程序的功能设计理念基于一个核心的原则:"最像而不是好"。
|
||||
> - 如果人类真的需要一个 AI 来陪伴自己,并不是所有人都需要一个完美的,能解决所有问题的"helpful assistant",而是一个会犯错的,拥有自己感知和想法的"生命形式"。
|
||||
> **SengokuCola says:**
|
||||
> - This project originally started as a few extra features for the NiuNiu bot, but it kept growing until a full rewrite became inevitable. The goal was to create a "life form" active in QQ group chats, not a feature-complete bot, but something as human-like and real-feeling as possible.
|
||||
> - The core design principle is: "more lifelike, not merely better."
|
||||
> - If people truly want AI companionship, not everyone needs a perfect "helpful assistant" that solves every problem. Some people may want a life form that can make mistakes and has its own perceptions and thoughts.
|
||||
|
||||
> **xxxxx说:**
|
||||
> **xxxxx says:**
|
||||
> *Code is open, but the soul is yours.*
|
||||
|
||||
---
|
||||
|
||||
## 🙋 贡献和致谢
|
||||
## 🙋 Contributing and Acknowledgments
|
||||
|
||||
欢迎参与贡献!请先阅读 [贡献指南](docs-src/CONTRIBUTE.md)。
|
||||
Contributions are welcome. Please read the [Contribution Guide](docs-src/CONTRIBUTE.md) first.
|
||||
|
||||
### 🌟 贡献者
|
||||
### 🌟 Contributors
|
||||
|
||||
<a href="https://github.com/MaiM-with-u/MaiBot/graphs/contributors">
|
||||
<img alt="contributors" src="https://contrib.rocks/image?repo=MaiM-with-u/MaiBot" />
|
||||
</a>
|
||||
|
||||
### ❤️ 特别致谢
|
||||
### ❤️ Special Thanks
|
||||
|
||||
- **[萨卡班甲鱼](https://en.wikipedia.org/wiki/Sacabambaspis)**: 千石可乐很喜欢的生物。
|
||||
- **[略nd](https://space.bilibili.com/1344099355)**: 🎨 为麦麦绘制早期的精美人设。
|
||||
- **[NapCat](https://github.com/NapNeko/NapCatQQ)**: 🚀 现代化的基于 NTQQ 的 Bot 协议实现。
|
||||
- **[Sacabambaspis](https://en.wikipedia.org/wiki/Sacabambaspis)**: SengokuCola's favorite creature.
|
||||
- **[略nd](https://space.bilibili.com/1344099355)**: Drew MaiSaka's beautiful early character design.
|
||||
- **[NapCat](https://github.com/NapNeko/NapCatQQ)**: A modern NTQQ-based bot protocol implementation.
|
||||
|
||||
---
|
||||
|
||||
## 📊 仓库状态
|
||||
## 📊 Repository Status
|
||||
|
||||

|
||||

|
||||
|
||||
### Star 趋势
|
||||
[](https://starchart.cc/MaiM-with-u/MaiBot)
|
||||
### Star History
|
||||
[](https://starchart.cc/MaiM-with-u/MaiBot)
|
||||
|
||||
---
|
||||
|
||||
## 📌 注意事项 & License
|
||||
## 📌 Notice & License
|
||||
|
||||
> [!IMPORTANT]
|
||||
> 使用前请阅读 [用户协议 (EULA)](EULA.md) 和 [隐私协议](PRIVACY.md)。AI 生成内容请仔细甄别。
|
||||
> Please read the [End User License Agreement (EULA)](EULA.md) and [Privacy Policy](PRIVACY.md) before use. Please evaluate AI-generated content carefully.
|
||||
|
||||
**License**: GPL-3.0
|
||||
|
||||
162
docs/README_CN.md
Normal file
162
docs/README_CN.md
Normal file
@@ -0,0 +1,162 @@
|
||||
<div align="center">
|
||||
|
||||
<!-- Language Switcher -->
|
||||
<a href="README_CN.md">简体中文</a> | <a href="../README.md">English</a>
|
||||
|
||||
<br>
|
||||
<br>
|
||||
|
||||
<h1>麦麦 MaiBot <sub><small>MaiSaka</small></sub></h1>
|
||||
|
||||
<!-- Badges Row -->
|
||||
<p>
|
||||
<img src="https://img.shields.io/badge/Python-3.10+-blue" alt="Python Version">
|
||||
<img src="https://img.shields.io/github/license/Mai-with-u/MaiBot?label=%E5%8D%8F%E8%AE%AE" alt="License">
|
||||
<img src="https://img.shields.io/badge/状态-开发中-yellow" alt="Status">
|
||||
<img src="https://img.shields.io/github/contributors/Mai-with-u/MaiBot.svg?style=flat&label=%E8%B4%A1%E7%8C%AE%E8%80%85" alt="Contributors">
|
||||
<img src="https://img.shields.io/github/forks/Mai-with-u/MaiBot.svg?style=flat&label=%E5%88%86%E6%94%AF%E6%95%B0" alt="Forks">
|
||||
<img src="https://img.shields.io/github/stars/Mai-with-u/MaiBot?style=flat&label=%E6%98%9F%E6%A0%87%E6%95%B0" alt="Stars">
|
||||
<a href="https://deepwiki.com/DrSmoothl/MaiBot"><img src="https://deepwiki.com/badge.svg" alt="Ask DeepWiki"></a>
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<br>
|
||||
|
||||
<!-- Mascot on the Right (Float) -->
|
||||
<img src="../depends-data/maimai-v2.png" align="right" width="40%" alt="MaiBot Character" style="margin-left: 20px; margin-bottom: 20px;">
|
||||
|
||||
## 介绍
|
||||
|
||||
麦麦MaiSaka 是一个基于大语言模型的可交互智能体
|
||||
|
||||
MaiSaka 不仅仅是一个机器人,不仅仅是一个可以帮你完成任务的“有帮助的助手”,她还是一个致力于了解你,并以真实人类的风格进行交互的数字生命,她不追求完美,她不追求高效,但追求亲切和真实。
|
||||
|
||||
|
||||
- 💭 **没有人喜欢GPT的语言风格**:麦麦使用了更加自然,贴合人类对话习惯的交互方式,不是长篇大论或者markdown格式的分点,而是或长或短的闲谈。
|
||||
|
||||
- 🎭 **不再是傻乎乎的一问一答**:懂得在合适的时间说话,把握聊天中的气氛,在合适的时候开口,在合适的时候闭嘴。
|
||||
|
||||
- 🧠 **麦麦·成为人类**:在多人对话中,麦麦会模仿其他人的的说话风格,还会自主理解新词或者小圈子里的黑话,不断进化。
|
||||
|
||||
- ❤️ **永远都在更加了解你**:基于心理学中人格理论,麦麦会不断积累对于你的了解,不论是你的信息,喜恶或是行为风格,她都记在心里。
|
||||
|
||||
- 🔌 **插件系统**:提供强大的 API 和事件系统,无限扩展可能。
|
||||
|
||||
|
||||
|
||||
### 快速导航
|
||||
<p>
|
||||
<a href="https://www.bilibili.com/video/BV1amAneGE3P">🌟 演示视频</a> |
|
||||
<a href="#-更新和安装">📦 快速入门</a> |
|
||||
<a href="#-部署教程">📃 核心文档</a> |
|
||||
<a href="#-讨论与社区">💬 加入社区</a>
|
||||
</p>
|
||||
|
||||
<!-- Clear float to ensure subsequent content starts below the image area if text is short -->
|
||||
<br clear="both">
|
||||
|
||||
<div align="center">
|
||||
<br>
|
||||
<a href="https://www.bilibili.com/video/BV1amAneGE3P" target="_blank">
|
||||
<picture>
|
||||
<source media="(max-width: 600px)" srcset="../depends-data/video.png" width="100%">
|
||||
<img src="../depends-data/video.png" width="60%" alt="麦麦演示视频" style="border-radius: 10px; box-shadow: 0 4px 8px rgba(0,0,0,0.1);">
|
||||
</picture>
|
||||
<br>
|
||||
<small>前往观看麦麦演示视频</small>
|
||||
</a>
|
||||
</div>
|
||||
|
||||
---
|
||||
|
||||
## 🔥 更新和安装
|
||||
|
||||
> **最新版本: v1.0.0** ([📄 更新日志](../changelogs/changelog.md))
|
||||
|
||||
- **下载**: 前往 [Release](https://github.com/MaiM-with-u/MaiBot/releases/) 页面下载最新版本
|
||||
- **启动器**: [Mailauncher](https://github.com/MaiM-with-u/mailauncher/releases/) (仅支持 MacOS, 早期开发中)
|
||||
|
||||
| 分支 | 说明 |
|
||||
| :--- | :--- |
|
||||
| `main` | ✅ **稳定发布版本 (推荐)** |
|
||||
| `dev` | 🚧 开发测试版本,包含新功能,可能不稳定 |
|
||||
|
||||
### 📚 部署教程
|
||||
👉 **[🚀 最新版本部署教程](https://docs.mai-mai.org/manual/deployment/mmc_deploy_windows.html)**
|
||||
|
||||
---
|
||||
|
||||
## 💬 讨论与社区
|
||||
|
||||
我们欢迎所有对 MaiBot 感兴趣的朋友加入!
|
||||
|
||||
| 类别 | 群组 | 说明 |
|
||||
| :--- | :--- | :--- |
|
||||
| **技术交流** | [麦麦脑电图](https://qm.qq.com/q/RzmCiRtHEW) | 技术交流/答疑 |
|
||||
| **技术交流** | [麦麦大脑磁共振](https://qm.qq.com/q/VQ3XZrWgMs) | 技术交流/答疑 |
|
||||
| **技术交流** | [麦麦要当VTB](https://qm.qq.com/q/wGePTl1UyY) | 技术交流/答疑 |
|
||||
| **闲聊吹水** | [麦麦之闲聊群](https://qm.qq.com/q/JxvHZnxyec) | 仅限闲聊,不答疑 |
|
||||
| **插件开发** | [插件开发群](https://qm.qq.com/q/1036092828) | 进阶开发与测试 |
|
||||
|
||||
---
|
||||
|
||||
## 📚 文档
|
||||
|
||||
> [!NOTE]
|
||||
> 部分内容可能更新不够及时,请注意版本对应。
|
||||
|
||||
- **[📚 核心 Wiki 文档](https://docs.mai-mai.org)**: 最全面的文档中心,了解麦麦的一切。
|
||||
|
||||
### 🧩 衍生项目
|
||||
|
||||
- **[Amaidesu](https://github.com/MaiM-with-u/Amaidesu)**: 让麦麦在B站开播
|
||||
- **[MoFox_Bot](https://github.com/MoFox-Studio/MoFox-Core)**: 基于 MaiCore 0.10.0 的增强型 Fork,更稳定更有趣。
|
||||
- **[MaiCraft](https://github.com/MaiM-with-u/Maicraft)**: 让麦麦陪你玩 Minecraft (暂时停止维护中)。
|
||||
|
||||
---
|
||||
|
||||
## 💡 设计理念
|
||||
|
||||
> **千石可乐说:**
|
||||
> - 这个项目最初只是为了给牛牛 bot 添加一点额外的功能,但是功能越写越多,最后决定重写。其目的是为了创造一个活跃在 QQ 群聊的"生命体"。目的并不是为了写一个功能齐全的机器人,而是一个尽可能让人感知到真实的类人存在。
|
||||
> - 程序的功能设计理念基于一个核心的原则:"最像而不是好"。
|
||||
> - 如果人类真的需要一个 AI 来陪伴自己,并不是所有人都需要一个完美的,能解决所有问题的"helpful assistant",而是一个会犯错的,拥有自己感知和想法的"生命形式"。
|
||||
|
||||
> **xxxxx说:**
|
||||
> *Code is open, but the soul is yours.*
|
||||
|
||||
---
|
||||
|
||||
## 🙋 贡献和致谢
|
||||
|
||||
欢迎参与贡献!请先阅读 [贡献指南](../docs-src/CONTRIBUTE.md)。
|
||||
|
||||
### 🌟 贡献者
|
||||
|
||||
<a href="https://github.com/MaiM-with-u/MaiBot/graphs/contributors">
|
||||
<img alt="contributors" src="https://contrib.rocks/image?repo=MaiM-with-u/MaiBot" />
|
||||
</a>
|
||||
|
||||
### ❤️ 特别致谢
|
||||
|
||||
- **[萨卡班甲鱼](https://en.wikipedia.org/wiki/Sacabambaspis)**: 千石可乐很喜欢的生物。
|
||||
- **[略nd](https://space.bilibili.com/1344099355)**: 🎨 为麦麦绘制早期的精美人设。
|
||||
- **[NapCat](https://github.com/NapNeko/NapCatQQ)**: 🚀 现代化的基于 NTQQ 的 Bot 协议实现。
|
||||
|
||||
---
|
||||
|
||||
## 📊 仓库状态
|
||||
|
||||

|
||||
|
||||
### Star 趋势
|
||||
[](https://starchart.cc/MaiM-with-u/MaiBot)
|
||||
|
||||
---
|
||||
|
||||
## 📌 注意事项 & License
|
||||
|
||||
> [!IMPORTANT]
|
||||
> 使用前请阅读 [用户协议 (EULA)](../EULA.md) 和 [隐私协议](../PRIVACY.md)。AI 生成内容请仔细甄别。
|
||||
|
||||
**License**: GPL-3.0
|
||||
@@ -1,7 +1,7 @@
|
||||
<div align="center">
|
||||
|
||||
<!-- Language Switcher -->
|
||||
<a href="../README.md">简体中文</a> | <a href="README_EN.md">English</a>
|
||||
<a href="README_CN.md">简体中文</a> | <a href="../README.md">English</a>
|
||||
|
||||
<br>
|
||||
<br>
|
||||
|
||||
@@ -41,7 +41,7 @@ This plan is based on the checked-in code, not on assumptions from previous draf
|
||||
| `src/person_info/person_info.py:247` | `_is_bot_self(self, platform, user_id)` | Duplicate logic with same QQ fallback |
|
||||
|
||||
Wrong-order call sites (8 total):
|
||||
- `src/bw_learner/expression_learner.py` x3 (lines 158, 241, 301)
|
||||
- `src/learners/expression_learner.py` x3 (lines 158, 241, 301)
|
||||
- `src/common/utils/utils_message.py` x4 (lines 370, 440, 476, 515)
|
||||
- `src/webui/routers/chat/support.py` x1 (line 65)
|
||||
|
||||
@@ -122,7 +122,7 @@ Make `src/chat/utils/utils.py::is_bot_self(platform, user_id)` the only real imp
|
||||
- `src/common/utils/system_utils.py`
|
||||
- `src/chat/utils/utils.py`
|
||||
- `src/person_info/person_info.py`
|
||||
- `src/bw_learner/expression_learner.py`
|
||||
- `src/learners/expression_learner.py`
|
||||
- `src/common/utils/utils_message.py`
|
||||
- `src/webui/routers/chat/support.py`
|
||||
- tests
|
||||
@@ -468,7 +468,7 @@ When stopping, name: the exact file(s), the blocking mismatch, why it is outside
|
||||
|
||||
| Phase | Allowed files |
|
||||
|-------|---------------|
|
||||
| Phase 0 | `src/common/utils/system_utils.py`, `src/chat/utils/utils.py`, `src/person_info/person_info.py`, `src/bw_learner/expression_learner.py`, `src/common/utils/utils_message.py`, `src/webui/routers/chat/support.py`, tests (including `pytests/utils_test/message_utils_test.py`) |
|
||||
| Phase 0 | `src/common/utils/system_utils.py`, `src/chat/utils/utils.py`, `src/person_info/person_info.py`, `src/learners/expression_learner.py`, `src/common/utils/utils_message.py`, `src/webui/routers/chat/support.py`, tests (including `pytests/utils_test/message_utils_test.py`) |
|
||||
| Phase 1 | `src/chat/utils/utils.py`, `src/chat/planner_actions/planner.py`, `src/chat/utils/statistic.py`, `src/common/message_repository.py`, `src/webui/routers/chat/support.py`, `src/services/send_service.py`, `src/chat/replyer/group_generator.py`, `src/chat/replyer/private_generator.py`, `src/chat/brain_chat/PFC/message_sender.py`, `src/person_info/person_info.py`, tests |
|
||||
|
||||
### INVALID OUTPUT EXAMPLES
|
||||
|
||||
@@ -1,58 +1,40 @@
|
||||
{
|
||||
"manifest_version": 1,
|
||||
"name": "发言频率控制插件|BetterFrequency Plugin",
|
||||
"manifest_version": 2,
|
||||
"version": "2.0.0",
|
||||
"description": "控制聊天频率,支持设置focus_value和talk_frequency调整值,提供命令",
|
||||
"name": "发言频率控制插件|BetterFrequency Plugin",
|
||||
"description": "控制聊天频率,支持设置 focus_value 和 talk_frequency 调整值,并提供命令入口。",
|
||||
"author": {
|
||||
"name": "SengokuCola",
|
||||
"url": "https://github.com/MaiM-with-u"
|
||||
},
|
||||
"license": "GPL-v3.0-or-later",
|
||||
"host_application": {
|
||||
"min_version": "1.0.0"
|
||||
"urls": {
|
||||
"repository": "https://github.com/SengokuCola/BetterFrequency",
|
||||
"homepage": "https://github.com/SengokuCola/BetterFrequency",
|
||||
"documentation": "https://github.com/SengokuCola/BetterFrequency",
|
||||
"issues": "https://github.com/SengokuCola/BetterFrequency/issues"
|
||||
},
|
||||
"homepage_url": "https://github.com/SengokuCola/BetterFrequency",
|
||||
"repository_url": "https://github.com/SengokuCola/BetterFrequency",
|
||||
"keywords": [
|
||||
"frequency",
|
||||
"control",
|
||||
"talk_frequency",
|
||||
"plugin",
|
||||
"shortcut"
|
||||
"host_application": {
|
||||
"min_version": "1.0.0",
|
||||
"max_version": "1.0.0"
|
||||
},
|
||||
"sdk": {
|
||||
"min_version": "2.0.0",
|
||||
"max_version": "2.99.99"
|
||||
},
|
||||
"dependencies": [],
|
||||
"capabilities": [
|
||||
"send.text",
|
||||
"frequency.set_adjust",
|
||||
"frequency.get_current_talk_value",
|
||||
"frequency.get_adjust"
|
||||
],
|
||||
"categories": [
|
||||
"Chat",
|
||||
"Frequency",
|
||||
"Control"
|
||||
],
|
||||
"default_locale": "zh-CN",
|
||||
"locales_path": "_locales",
|
||||
"plugin_info": {
|
||||
"is_built_in": false,
|
||||
"plugin_type": "frequency",
|
||||
"components": [
|
||||
{
|
||||
"type": "command",
|
||||
"name": "set_talk_frequency",
|
||||
"description": "设置当前聊天的talk_frequency调整值",
|
||||
"pattern": "/chat talk_frequency <数字> 或 /chat t <数字>"
|
||||
},
|
||||
{
|
||||
"type": "command",
|
||||
"name": "show_frequency",
|
||||
"description": "显示当前聊天的频率控制状态",
|
||||
"pattern": "/chat show 或 /chat s"
|
||||
}
|
||||
],
|
||||
"features": [
|
||||
"设置talk_frequency调整值",
|
||||
"调整当前聊天的发言频率",
|
||||
"显示当前频率控制状态",
|
||||
"实时频率控制调整",
|
||||
"命令执行反馈(不保存消息)",
|
||||
"支持完整命令和简化命令",
|
||||
"快速操作支持"
|
||||
"i18n": {
|
||||
"default_locale": "zh-CN",
|
||||
"locales_path": "_locales",
|
||||
"supported_locales": [
|
||||
"zh-CN"
|
||||
]
|
||||
},
|
||||
"id": "SengokuCola.BetterFrequency"
|
||||
}
|
||||
"id": "sengokucola.betterfrequency"
|
||||
}
|
||||
|
||||
@@ -3,12 +3,18 @@
|
||||
通过 /chat 命令设置和查看聊天频率。
|
||||
"""
|
||||
|
||||
from maibot_sdk import MaiBotPlugin, Command
|
||||
from maibot_sdk import Command, MaiBotPlugin
|
||||
|
||||
|
||||
class BetterFrequencyPlugin(MaiBotPlugin):
|
||||
"""聊天频率控制插件"""
|
||||
|
||||
async def on_load(self) -> None:
|
||||
"""处理插件加载。"""
|
||||
|
||||
async def on_unload(self) -> None:
|
||||
"""处理插件卸载。"""
|
||||
|
||||
@Command(
|
||||
"set_talk_frequency",
|
||||
description="设置当前聊天的talk_frequency值:/chat talk_frequency <数字> 或 /chat t <数字>",
|
||||
@@ -80,6 +86,25 @@ class BetterFrequencyPlugin(MaiBotPlugin):
|
||||
await self.ctx.send.text(status_msg, stream_id)
|
||||
return True, None, False
|
||||
|
||||
async def on_config_update(self, scope: str, config_data: dict[str, object], version: str) -> None:
|
||||
"""处理配置热重载事件。
|
||||
|
||||
Args:
|
||||
scope: 配置变更范围。
|
||||
config_data: 最新配置数据。
|
||||
version: 配置版本号。
|
||||
"""
|
||||
|
||||
del scope
|
||||
del config_data
|
||||
del version
|
||||
|
||||
|
||||
def create_plugin() -> BetterFrequencyPlugin:
|
||||
"""创建聊天频率插件实例。
|
||||
|
||||
Returns:
|
||||
BetterFrequencyPlugin: 新的聊天频率插件实例。
|
||||
"""
|
||||
|
||||
def create_plugin():
|
||||
return BetterFrequencyPlugin()
|
||||
|
||||
@@ -1,67 +1,42 @@
|
||||
{
|
||||
"manifest_version": 1,
|
||||
"name": "MCP桥接插件",
|
||||
"manifest_version": 2,
|
||||
"version": "2.0.0",
|
||||
"description": "将 MCP (Model Context Protocol) 服务器的工具桥接到 MaiBot,使麦麦能够调用外部 MCP 工具",
|
||||
"name": "MCP桥接插件",
|
||||
"description": "将 MCP (Model Context Protocol) 服务器的工具桥接到 MaiBot,使麦麦能够调用外部 MCP 工具。",
|
||||
"author": {
|
||||
"name": "CharTyr",
|
||||
"url": "https://github.com/CharTyr"
|
||||
},
|
||||
"license": "AGPL-3.0",
|
||||
"host_application": {
|
||||
"min_version": "0.11.6"
|
||||
"urls": {
|
||||
"repository": "https://github.com/CharTyr/MaiBot_MCPBridgePlugin",
|
||||
"homepage": "https://github.com/CharTyr/MaiBot_MCPBridgePlugin",
|
||||
"documentation": "https://github.com/CharTyr/MaiBot_MCPBridgePlugin",
|
||||
"issues": "https://github.com/CharTyr/MaiBot_MCPBridgePlugin/issues"
|
||||
},
|
||||
"homepage_url": "https://github.com/CharTyr/MaiBot_MCPBridgePlugin",
|
||||
"repository_url": "https://github.com/CharTyr/MaiBot_MCPBridgePlugin",
|
||||
"keywords": [
|
||||
"mcp",
|
||||
"bridge",
|
||||
"tool",
|
||||
"integration",
|
||||
"resources",
|
||||
"prompts",
|
||||
"post-process",
|
||||
"cache",
|
||||
"trace",
|
||||
"permissions",
|
||||
"import",
|
||||
"export",
|
||||
"claude-desktop",
|
||||
"workflow",
|
||||
"react",
|
||||
"agent"
|
||||
"host_application": {
|
||||
"min_version": "0.11.6",
|
||||
"max_version": "1.0.0"
|
||||
},
|
||||
"sdk": {
|
||||
"min_version": "2.0.0",
|
||||
"max_version": "2.99.99"
|
||||
},
|
||||
"dependencies": [
|
||||
{
|
||||
"type": "python_package",
|
||||
"name": "mcp",
|
||||
"version_spec": ">=0.0.0"
|
||||
}
|
||||
],
|
||||
"categories": [
|
||||
"工具扩展",
|
||||
"外部集成"
|
||||
"capabilities": [
|
||||
"send.text"
|
||||
],
|
||||
"default_locale": "zh-CN",
|
||||
"plugin_info": {
|
||||
"is_built_in": false,
|
||||
"components": [],
|
||||
"features": [
|
||||
"支持多个 MCP 服务器",
|
||||
"自动发现并注册 MCP 工具",
|
||||
"支持 stdio、SSE、HTTP、Streamable HTTP 四种传输方式",
|
||||
"工具参数自动转换",
|
||||
"心跳检测与自动重连",
|
||||
"调用统计(次数、成功率、耗时)",
|
||||
"WebUI 配置支持",
|
||||
"Resources 支持(实验性)",
|
||||
"Prompts 支持(实验性)",
|
||||
"结果后处理(LLM 摘要提炼)",
|
||||
"工具禁用管理",
|
||||
"调用链路追踪",
|
||||
"工具调用缓存(LRU)",
|
||||
"工具权限控制(群/用户级别)",
|
||||
"配置导入导出(Claude Desktop mcpServers)",
|
||||
"断路器模式(故障快速失败)",
|
||||
"状态实时刷新",
|
||||
"Workflow 硬流程(顺序执行多个工具)",
|
||||
"Workflow 快速添加(表单式配置)",
|
||||
"ReAct 软流程(LLM 自主多轮调用)",
|
||||
"双轨制架构(软流程 + 硬流程)"
|
||||
"i18n": {
|
||||
"default_locale": "zh-CN",
|
||||
"supported_locales": [
|
||||
"zh-CN"
|
||||
]
|
||||
},
|
||||
"id": "MaiBot Community.MCPBridgePlugin"
|
||||
"id": "chartyr.mcpbridge-plugin"
|
||||
}
|
||||
|
||||
@@ -1,68 +1,44 @@
|
||||
{
|
||||
"manifest_version": 1,
|
||||
"name": "BetterEmoji",
|
||||
"manifest_version": 2,
|
||||
"version": "2.0.0",
|
||||
"name": "BetterEmoji",
|
||||
"description": "更好的表情包管理插件",
|
||||
"author": {
|
||||
"name": "SengokuCola",
|
||||
"url": "https://github.com/SengokuCola"
|
||||
},
|
||||
"license": "GPL-v3.0-or-later",
|
||||
"host_application": {
|
||||
"min_version": "1.0.0"
|
||||
"urls": {
|
||||
"repository": "https://github.com/SengokuCola/BetterEmoji",
|
||||
"homepage": "https://github.com/SengokuCola/BetterEmoji",
|
||||
"documentation": "https://github.com/SengokuCola/BetterEmoji",
|
||||
"issues": "https://github.com/SengokuCola/BetterEmoji/issues"
|
||||
},
|
||||
"homepage_url": "https://github.com/SengokuCola/BetterEmoji",
|
||||
"repository_url": "https://github.com/SengokuCola/BetterEmoji",
|
||||
"keywords": [
|
||||
"emoji",
|
||||
"manage",
|
||||
"plugin"
|
||||
"host_application": {
|
||||
"min_version": "1.0.0",
|
||||
"max_version": "1.0.0"
|
||||
},
|
||||
"sdk": {
|
||||
"min_version": "2.0.0",
|
||||
"max_version": "2.99.99"
|
||||
},
|
||||
"dependencies": [],
|
||||
"capabilities": [
|
||||
"emoji.get_random",
|
||||
"emoji.get_count",
|
||||
"emoji.get_info",
|
||||
"emoji.get_all",
|
||||
"emoji.register_emoji",
|
||||
"emoji.delete_emoji",
|
||||
"send.text",
|
||||
"send.forward"
|
||||
],
|
||||
"categories": [
|
||||
"Emoji",
|
||||
"Management"
|
||||
],
|
||||
"default_locale": "zh-CN",
|
||||
"locales_path": "_locales",
|
||||
"plugin_info": {
|
||||
"is_built_in": false,
|
||||
"plugin_type": "emoji_manage",
|
||||
"capabilities": [
|
||||
"emoji.get_random",
|
||||
"emoji.get_count",
|
||||
"emoji.get_info",
|
||||
"emoji.get_all",
|
||||
"emoji.register_emoji",
|
||||
"emoji.delete_emoji",
|
||||
"send.text",
|
||||
"send.forward"
|
||||
],
|
||||
"components": [
|
||||
{
|
||||
"type": "command",
|
||||
"name": "add_emoji",
|
||||
"description": "添加表情包",
|
||||
"pattern": "/emoji add"
|
||||
},
|
||||
{
|
||||
"type": "command",
|
||||
"name": "emoji_list",
|
||||
"description": "列表表情包",
|
||||
"pattern": "/emoji list"
|
||||
},
|
||||
{
|
||||
"type": "command",
|
||||
"name": "delete_emoji",
|
||||
"description": "删除表情包",
|
||||
"pattern": "/emoji delete"
|
||||
},
|
||||
{
|
||||
"type": "command",
|
||||
"name": "random_emojis",
|
||||
"description": "发送多张随机表情包",
|
||||
"pattern": "/random_emojis"
|
||||
}
|
||||
"i18n": {
|
||||
"default_locale": "zh-CN",
|
||||
"locales_path": "_locales",
|
||||
"supported_locales": [
|
||||
"zh-CN"
|
||||
]
|
||||
},
|
||||
"id": "SengokuCola.BetterEmoji"
|
||||
}
|
||||
"id": "sengokucola.betteremoji"
|
||||
}
|
||||
|
||||
@@ -3,17 +3,23 @@
|
||||
通过 /emoji 命令管理表情包的添加、列表和删除。
|
||||
"""
|
||||
|
||||
from maibot_sdk import Command, MaiBotPlugin
|
||||
|
||||
import base64
|
||||
import datetime
|
||||
import hashlib
|
||||
import re
|
||||
|
||||
from maibot_sdk import MaiBotPlugin, Command
|
||||
|
||||
|
||||
class EmojiManagePlugin(MaiBotPlugin):
|
||||
"""表情包管理插件"""
|
||||
|
||||
async def on_load(self) -> None:
|
||||
"""处理插件加载。"""
|
||||
|
||||
async def on_unload(self) -> None:
|
||||
"""处理插件卸载。"""
|
||||
|
||||
# ===== 工具方法 =====
|
||||
|
||||
@staticmethod
|
||||
@@ -208,6 +214,25 @@ class EmojiManagePlugin(MaiBotPlugin):
|
||||
await self.ctx.send.forward(messages, stream_id)
|
||||
return True, "已发送随机表情包", True
|
||||
|
||||
async def on_config_update(self, scope: str, config_data: dict[str, object], version: str) -> None:
|
||||
"""处理配置热重载事件。
|
||||
|
||||
Args:
|
||||
scope: 配置变更范围。
|
||||
config_data: 最新配置数据。
|
||||
version: 配置版本号。
|
||||
"""
|
||||
|
||||
del scope
|
||||
del config_data
|
||||
del version
|
||||
|
||||
|
||||
def create_plugin() -> EmojiManagePlugin:
|
||||
"""创建表情包管理插件实例。
|
||||
|
||||
Returns:
|
||||
EmojiManagePlugin: 新的表情包管理插件实例。
|
||||
"""
|
||||
|
||||
def create_plugin():
|
||||
return EmojiManagePlugin()
|
||||
|
||||
@@ -1,88 +1,41 @@
|
||||
{
|
||||
"manifest_version": 1,
|
||||
"name": "Hello World 示例插件 (Hello World Plugin)",
|
||||
"manifest_version": 2,
|
||||
"version": "2.0.0",
|
||||
"description": "我的第一个MaiCore插件,包含问候功能和时间查询等基础示例",
|
||||
"name": "Hello World 示例插件 (Hello World Plugin)",
|
||||
"description": "我的第一个 MaiCore 插件,包含问候功能和时间查询等基础示例",
|
||||
"author": {
|
||||
"name": "MaiBot开发团队",
|
||||
"url": "https://github.com/MaiM-with-u"
|
||||
},
|
||||
"license": "GPL-v3.0-or-later",
|
||||
"host_application": {
|
||||
"min_version": "1.0.0"
|
||||
"urls": {
|
||||
"repository": "https://github.com/MaiM-with-u/maibot",
|
||||
"homepage": "https://github.com/MaiM-with-u/maibot",
|
||||
"documentation": "https://github.com/MaiM-with-u/maibot",
|
||||
"issues": "https://github.com/MaiM-with-u/maibot/issues"
|
||||
},
|
||||
"homepage_url": "https://github.com/MaiM-with-u/maibot",
|
||||
"repository_url": "https://github.com/MaiM-with-u/maibot",
|
||||
"keywords": [
|
||||
"demo",
|
||||
"example",
|
||||
"hello",
|
||||
"greeting",
|
||||
"tutorial"
|
||||
"host_application": {
|
||||
"min_version": "1.0.0",
|
||||
"max_version": "1.0.0"
|
||||
},
|
||||
"sdk": {
|
||||
"min_version": "2.0.0",
|
||||
"max_version": "2.99.99"
|
||||
},
|
||||
"dependencies": [],
|
||||
"capabilities": [
|
||||
"send.text",
|
||||
"send.forward",
|
||||
"send.hybrid",
|
||||
"emoji.get_random",
|
||||
"config.get"
|
||||
],
|
||||
"categories": [
|
||||
"Examples",
|
||||
"Tutorial"
|
||||
],
|
||||
"default_locale": "zh-CN",
|
||||
"locales_path": "_locales",
|
||||
"plugin_info": {
|
||||
"is_built_in": false,
|
||||
"plugin_type": "example",
|
||||
"capabilities": [
|
||||
"send.text",
|
||||
"send.forward",
|
||||
"send.hybrid",
|
||||
"emoji.get_random",
|
||||
"config.get"
|
||||
],
|
||||
"components": [
|
||||
{
|
||||
"type": "tool",
|
||||
"name": "compare_numbers",
|
||||
"description": "比较两个数的大小"
|
||||
},
|
||||
{
|
||||
"type": "action",
|
||||
"name": "hello_greeting",
|
||||
"description": "向用户发送问候消息"
|
||||
},
|
||||
{
|
||||
"type": "action",
|
||||
"name": "bye_greeting",
|
||||
"description": "向用户发送告别消息",
|
||||
"activation_modes": ["keyword"],
|
||||
"keywords": ["再见", "bye", "88", "拜拜"]
|
||||
},
|
||||
{
|
||||
"type": "command",
|
||||
"name": "time",
|
||||
"description": "查询当前时间",
|
||||
"pattern": "/time"
|
||||
},
|
||||
{
|
||||
"type": "command",
|
||||
"name": "random_emojis",
|
||||
"description": "发送多张随机表情包",
|
||||
"pattern": "/random_emojis"
|
||||
},
|
||||
{
|
||||
"type": "command",
|
||||
"name": "test",
|
||||
"description": "测试命令",
|
||||
"pattern": "/test"
|
||||
},
|
||||
{
|
||||
"type": "event_handler",
|
||||
"name": "print_message_handler",
|
||||
"description": "打印接收到的消息"
|
||||
},
|
||||
{
|
||||
"type": "event_handler",
|
||||
"name": "forward_messages_handler",
|
||||
"description": "把接收到的消息转发到指定聊天ID"
|
||||
}
|
||||
"i18n": {
|
||||
"default_locale": "zh-CN",
|
||||
"locales_path": "_locales",
|
||||
"supported_locales": [
|
||||
"zh-CN"
|
||||
]
|
||||
},
|
||||
"id": "MaiBot开发团队.maibot"
|
||||
}
|
||||
"id": "maibot-team.hello-world-plugin"
|
||||
}
|
||||
|
||||
@@ -3,16 +3,22 @@
|
||||
你的第一个 MaiCore 插件,包含问候功能、时间查询等基础示例。
|
||||
"""
|
||||
|
||||
from maibot_sdk import Action, Command, EventHandler, MaiBotPlugin, Tool
|
||||
from maibot_sdk.types import ActivationType, EventType, ToolParameterInfo, ToolParamType
|
||||
|
||||
import datetime
|
||||
import random
|
||||
|
||||
from maibot_sdk import MaiBotPlugin, Action, Command, Tool, EventHandler
|
||||
from maibot_sdk.types import ActivationType, EventType, ToolParameterInfo, ToolParamType
|
||||
|
||||
|
||||
class HelloWorldPlugin(MaiBotPlugin):
|
||||
"""Hello World 示例插件"""
|
||||
|
||||
async def on_load(self) -> None:
|
||||
"""处理插件加载。"""
|
||||
|
||||
async def on_unload(self) -> None:
|
||||
"""处理插件卸载。"""
|
||||
|
||||
# ===== Tool 组件 =====
|
||||
|
||||
@Tool(
|
||||
@@ -146,6 +152,25 @@ class HelloWorldPlugin(MaiBotPlugin):
|
||||
|
||||
return True, True, None, None, None
|
||||
|
||||
async def on_config_update(self, scope: str, config_data: dict[str, object], version: str) -> None:
|
||||
"""处理配置热重载事件。
|
||||
|
||||
Args:
|
||||
scope: 配置变更范围。
|
||||
config_data: 最新配置数据。
|
||||
version: 配置版本号。
|
||||
"""
|
||||
|
||||
del scope
|
||||
del config_data
|
||||
del version
|
||||
|
||||
|
||||
def create_plugin() -> HelloWorldPlugin:
|
||||
"""创建 Hello World 示例插件实例。
|
||||
|
||||
Returns:
|
||||
HelloWorldPlugin: 新的示例插件实例。
|
||||
"""
|
||||
|
||||
def create_plugin():
|
||||
return HelloWorldPlugin()
|
||||
|
||||
@@ -19,7 +19,7 @@ dependencies = [
|
||||
"jieba>=0.42.1",
|
||||
"json-repair>=0.47.6",
|
||||
"maim-message>=0.6.2",
|
||||
"maibot-plugin-sdk>=1.2.3,<2.0.0",
|
||||
"maibot-plugin-sdk>=2.0.0",
|
||||
"msgpack>=1.1.2",
|
||||
"numpy>=2.2.6",
|
||||
"openai>=1.95.0",
|
||||
@@ -55,6 +55,8 @@ dev = [
|
||||
[tool.uv]
|
||||
index-url = "https://pypi.tuna.tsinghua.edu.cn/simple"
|
||||
|
||||
[tool.uv.sources]
|
||||
maibot-plugin-sdk = { path = "packages/maibot-plugin-sdk", editable = true }
|
||||
|
||||
[tool.ruff]
|
||||
|
||||
|
||||
89
pytests/common_test/test_expression_auto_check_task.py
Normal file
89
pytests/common_test/test_expression_auto_check_task.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""测试表达方式自动检查任务的数据库读取行为。"""
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import Generator
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from sqlmodel import Session, SQLModel, create_engine
|
||||
|
||||
from src.bw_learner.expression_auto_check_task import ExpressionAutoCheckTask
|
||||
from src.common.database.database_model import Expression
|
||||
|
||||
|
||||
@pytest.fixture(name="expression_auto_check_engine")
|
||||
def expression_auto_check_engine_fixture() -> Generator:
|
||||
"""创建用于表达方式自动检查任务测试的内存数据库引擎。
|
||||
|
||||
Yields:
|
||||
Generator: 供测试使用的 SQLite 内存引擎。
|
||||
"""
|
||||
|
||||
engine = create_engine(
|
||||
"sqlite://",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
SQLModel.metadata.create_all(engine)
|
||||
yield engine
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_select_expressions_uses_read_only_session(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
expression_auto_check_engine,
|
||||
) -> None:
|
||||
"""选择表达方式时应使用只读会话,并在离开会话后安全读取 ORM 字段。"""
|
||||
|
||||
import src.bw_learner.expression_auto_check_task as expression_auto_check_task_module
|
||||
|
||||
with Session(expression_auto_check_engine) as session:
|
||||
session.add(
|
||||
Expression(
|
||||
situation="表达情绪高涨或生理反应",
|
||||
style="发送💦表情符号",
|
||||
content_list='["表达情绪高涨或生理反应"]',
|
||||
count=1,
|
||||
session_id="session-a",
|
||||
checked=False,
|
||||
rejected=False,
|
||||
)
|
||||
)
|
||||
session.commit()
|
||||
|
||||
auto_commit_calls: list[bool] = []
|
||||
|
||||
@contextmanager
|
||||
def fake_get_db_session(auto_commit: bool = True) -> Generator[Session, None, None]:
|
||||
"""构造带自动提交语义的测试会话工厂。
|
||||
|
||||
Args:
|
||||
auto_commit: 退出上下文时是否自动提交。
|
||||
|
||||
Yields:
|
||||
Generator[Session, None, None]: SQLModel 会话对象。
|
||||
"""
|
||||
|
||||
auto_commit_calls.append(auto_commit)
|
||||
session = Session(expression_auto_check_engine)
|
||||
try:
|
||||
yield session
|
||||
if auto_commit:
|
||||
session.commit()
|
||||
except Exception:
|
||||
session.rollback()
|
||||
raise
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
monkeypatch.setattr(expression_auto_check_task_module, "get_db_session", fake_get_db_session)
|
||||
monkeypatch.setattr(expression_auto_check_task_module.random, "sample", lambda entries, _count: list(entries))
|
||||
|
||||
task = ExpressionAutoCheckTask()
|
||||
expressions = await task._select_expressions(1)
|
||||
|
||||
assert auto_commit_calls == [False]
|
||||
assert len(expressions) == 1
|
||||
assert expressions[0].id is not None
|
||||
assert expressions[0].situation == "表达情绪高涨或生理反应"
|
||||
assert expressions[0].style == "发送💦表情符号"
|
||||
81
pytests/common_test/test_expression_learner.py
Normal file
81
pytests/common_test/test_expression_learner.py
Normal file
@@ -0,0 +1,81 @@
|
||||
"""测试表达方式学习器的数据库读取行为。"""
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import Generator
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from sqlmodel import Session, SQLModel, create_engine
|
||||
|
||||
from src.bw_learner.expression_learner import ExpressionLearner
|
||||
from src.common.database.database_model import Expression
|
||||
|
||||
|
||||
@pytest.fixture(name="expression_learner_engine")
|
||||
def expression_learner_engine_fixture() -> Generator:
|
||||
"""创建用于表达方式学习器测试的内存数据库引擎。
|
||||
|
||||
Yields:
|
||||
Generator: 供测试使用的 SQLite 内存引擎。
|
||||
"""
|
||||
engine = create_engine(
|
||||
"sqlite://",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
SQLModel.metadata.create_all(engine)
|
||||
yield engine
|
||||
|
||||
|
||||
def test_find_similar_expression_uses_read_only_session_and_history_content(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
expression_learner_engine,
|
||||
) -> None:
|
||||
"""查找相似表达方式时,应能在离开会话后安全使用结果,并比较历史情景内容。"""
|
||||
import src.bw_learner.expression_learner as expression_learner_module
|
||||
|
||||
with Session(expression_learner_engine) as session:
|
||||
session.add(
|
||||
Expression(
|
||||
situation="发送汗滴表情",
|
||||
style="发送💦表情符号",
|
||||
content_list='["表达情绪高涨或生理反应"]',
|
||||
count=1,
|
||||
session_id="session-a",
|
||||
checked=False,
|
||||
rejected=False,
|
||||
)
|
||||
)
|
||||
session.commit()
|
||||
|
||||
@contextmanager
|
||||
def fake_get_db_session(auto_commit: bool = True) -> Generator[Session, None, None]:
|
||||
"""构造带自动提交语义的测试会话工厂。
|
||||
|
||||
Args:
|
||||
auto_commit: 退出上下文时是否自动提交。
|
||||
|
||||
Yields:
|
||||
Generator[Session, None, None]: SQLModel 会话对象。
|
||||
"""
|
||||
session = Session(expression_learner_engine)
|
||||
try:
|
||||
yield session
|
||||
if auto_commit:
|
||||
session.commit()
|
||||
except Exception:
|
||||
session.rollback()
|
||||
raise
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
monkeypatch.setattr(expression_learner_module, "get_db_session", fake_get_db_session)
|
||||
|
||||
learner = ExpressionLearner(session_id="session-a")
|
||||
result = learner._find_similar_expression("表达情绪高涨或生理反应")
|
||||
|
||||
assert result is not None
|
||||
expression, similarity = result
|
||||
assert expression.item_id is not None
|
||||
assert expression.style == "发送💦表情符号"
|
||||
assert similarity == pytest.approx(1.0)
|
||||
78
pytests/common_test/test_expression_schema.py
Normal file
78
pytests/common_test/test_expression_schema.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""测试表达方式表结构和基础插入行为。"""
|
||||
|
||||
from typing import Generator
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from sqlmodel import Session, SQLModel, create_engine
|
||||
|
||||
from src.common.database.database_model import Expression
|
||||
|
||||
|
||||
@pytest.fixture(name="expression_engine")
|
||||
def expression_engine_fixture() -> Generator:
|
||||
"""创建仅用于表达方式表测试的内存数据库引擎。
|
||||
|
||||
Yields:
|
||||
Generator: 供测试使用的 SQLite 内存引擎。
|
||||
"""
|
||||
engine = create_engine(
|
||||
"sqlite://",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
SQLModel.metadata.create_all(engine)
|
||||
yield engine
|
||||
|
||||
|
||||
def test_expression_insert_assigns_auto_increment_id(expression_engine) -> None:
|
||||
"""表达方式表在新库中应能自动分配自增主键。"""
|
||||
with Session(expression_engine) as session:
|
||||
expression = Expression(
|
||||
situation="表达情绪高涨或生理反应",
|
||||
style="发送💦表情符号",
|
||||
content_list='["表达情绪高涨或生理反应"]',
|
||||
count=1,
|
||||
session_id="session-a",
|
||||
checked=False,
|
||||
rejected=False,
|
||||
)
|
||||
session.add(expression)
|
||||
session.commit()
|
||||
session.refresh(expression)
|
||||
|
||||
assert expression.id is not None
|
||||
assert expression.id > 0
|
||||
|
||||
|
||||
def test_expression_insert_allows_same_situation_style(expression_engine) -> None:
|
||||
"""相同情景和风格的表达方式记录不应再被错误绑定到复合主键。"""
|
||||
with Session(expression_engine) as session:
|
||||
first_expression = Expression(
|
||||
situation="对重复行为的默契响应",
|
||||
style="持续性跟发相同内容",
|
||||
content_list='["对重复行为的默契响应"]',
|
||||
count=1,
|
||||
session_id="session-a",
|
||||
checked=False,
|
||||
rejected=False,
|
||||
)
|
||||
second_expression = Expression(
|
||||
situation="对重复行为的默契响应",
|
||||
style="持续性跟发相同内容",
|
||||
content_list='["对重复行为的默契响应-变体"]',
|
||||
count=2,
|
||||
session_id="session-b",
|
||||
checked=False,
|
||||
rejected=False,
|
||||
)
|
||||
|
||||
session.add(first_expression)
|
||||
session.add(second_expression)
|
||||
session.commit()
|
||||
session.refresh(first_expression)
|
||||
session.refresh(second_expression)
|
||||
|
||||
assert first_expression.id is not None
|
||||
assert second_expression.id is not None
|
||||
assert first_expression.id != second_expression.id
|
||||
90
pytests/common_test/test_jargon_miner.py
Normal file
90
pytests/common_test/test_jargon_miner.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""测试黑话学习器的数据库读取行为。"""
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import Generator
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from sqlmodel import Session, SQLModel, create_engine, select
|
||||
|
||||
from src.bw_learner.jargon_miner import JargonMiner
|
||||
from src.common.database.database_model import Jargon
|
||||
|
||||
|
||||
@pytest.fixture(name="jargon_miner_engine")
|
||||
def jargon_miner_engine_fixture() -> Generator:
|
||||
"""创建用于黑话学习器测试的内存数据库引擎。
|
||||
|
||||
Yields:
|
||||
Generator: 供测试使用的 SQLite 内存引擎。
|
||||
"""
|
||||
engine = create_engine(
|
||||
"sqlite://",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
SQLModel.metadata.create_all(engine)
|
||||
yield engine
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_extracted_entries_updates_existing_jargon_without_detached_session(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
jargon_miner_engine,
|
||||
) -> None:
|
||||
"""更新已有黑话时,不应因会话关闭导致 ORM 实例失效。"""
|
||||
import src.bw_learner.jargon_miner as jargon_miner_module
|
||||
|
||||
with Session(jargon_miner_engine) as session:
|
||||
session.add(
|
||||
Jargon(
|
||||
content="VF8V4L",
|
||||
raw_content='["[1] first"]',
|
||||
meaning="",
|
||||
session_id_dict='{"session-a": 1}',
|
||||
count=0,
|
||||
is_jargon=True,
|
||||
is_complete=False,
|
||||
is_global=False,
|
||||
last_inference_count=0,
|
||||
)
|
||||
)
|
||||
session.commit()
|
||||
|
||||
@contextmanager
|
||||
def fake_get_db_session(auto_commit: bool = True) -> Generator[Session, None, None]:
|
||||
"""构造带自动提交语义的测试会话工厂。
|
||||
|
||||
Args:
|
||||
auto_commit: 退出上下文时是否自动提交。
|
||||
|
||||
Yields:
|
||||
Generator[Session, None, None]: SQLModel 会话对象。
|
||||
"""
|
||||
session = Session(jargon_miner_engine)
|
||||
try:
|
||||
yield session
|
||||
if auto_commit:
|
||||
session.commit()
|
||||
except Exception:
|
||||
session.rollback()
|
||||
raise
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
monkeypatch.setattr(jargon_miner_module, "get_db_session", fake_get_db_session)
|
||||
|
||||
jargon_miner = JargonMiner(session_id="session-a", session_name="测试群")
|
||||
await jargon_miner.process_extracted_entries(
|
||||
[{"content": "VF8V4L", "raw_content": {"[2] second"}}],
|
||||
)
|
||||
|
||||
with Session(jargon_miner_engine) as session:
|
||||
db_jargon = session.exec(select(Jargon).where(Jargon.content == "VF8V4L")).one()
|
||||
|
||||
assert db_jargon.count == 1
|
||||
assert db_jargon.session_id_dict == '{"session-a": 2}'
|
||||
assert sorted(db_jargon.raw_content and __import__("json").loads(db_jargon.raw_content)) == [
|
||||
"[1] first",
|
||||
"[2] second",
|
||||
]
|
||||
84
pytests/common_test/test_jargon_schema.py
Normal file
84
pytests/common_test/test_jargon_schema.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""测试黑话表结构和基础插入行为。"""
|
||||
|
||||
from typing import Generator
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from sqlmodel import Session, SQLModel, create_engine
|
||||
|
||||
from src.common.database.database_model import Jargon
|
||||
|
||||
|
||||
@pytest.fixture(name="jargon_engine")
|
||||
def jargon_engine_fixture() -> Generator:
|
||||
"""创建仅用于黑话表测试的内存数据库引擎。
|
||||
|
||||
Yields:
|
||||
Generator: 供测试使用的 SQLite 内存引擎。
|
||||
"""
|
||||
engine = create_engine(
|
||||
"sqlite://",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
SQLModel.metadata.create_all(engine)
|
||||
yield engine
|
||||
|
||||
|
||||
def test_jargon_insert_assigns_auto_increment_id(jargon_engine) -> None:
|
||||
"""黑话表在新库中应能自动分配自增主键。"""
|
||||
with Session(jargon_engine) as session:
|
||||
jargon = Jargon(
|
||||
content="VF8V4L",
|
||||
raw_content='["[1] test"]',
|
||||
meaning="",
|
||||
session_id_dict='{"session-a": 1}',
|
||||
count=1,
|
||||
is_jargon=True,
|
||||
is_complete=False,
|
||||
is_global=True,
|
||||
last_inference_count=0,
|
||||
)
|
||||
session.add(jargon)
|
||||
session.commit()
|
||||
session.refresh(jargon)
|
||||
|
||||
assert jargon.id is not None
|
||||
assert jargon.id > 0
|
||||
|
||||
|
||||
def test_jargon_insert_allows_same_content_with_different_rows(jargon_engine) -> None:
|
||||
"""黑话内容不应再被错误地绑成复合主键的一部分。"""
|
||||
with Session(jargon_engine) as session:
|
||||
first_jargon = Jargon(
|
||||
content="表情1",
|
||||
raw_content='["[1] first"]',
|
||||
meaning="",
|
||||
session_id_dict='{"session-a": 1}',
|
||||
count=1,
|
||||
is_jargon=True,
|
||||
is_complete=False,
|
||||
is_global=False,
|
||||
last_inference_count=0,
|
||||
)
|
||||
second_jargon = Jargon(
|
||||
content="表情1",
|
||||
raw_content='["[1] second"]',
|
||||
meaning="",
|
||||
session_id_dict='{"session-b": 1}',
|
||||
count=1,
|
||||
is_jargon=True,
|
||||
is_complete=False,
|
||||
is_global=False,
|
||||
last_inference_count=0,
|
||||
)
|
||||
|
||||
session.add(first_jargon)
|
||||
session.add(second_jargon)
|
||||
session.commit()
|
||||
session.refresh(first_jargon)
|
||||
session.refresh(second_jargon)
|
||||
|
||||
assert first_jargon.id is not None
|
||||
assert second_jargon.id is not None
|
||||
assert first_jargon.id != second_jargon.id
|
||||
355
pytests/common_test/test_person_info_group_cardname.py
Normal file
355
pytests/common_test/test_person_info_group_cardname.py
Normal file
@@ -0,0 +1,355 @@
|
||||
"""人物信息群名片字段兼容测试。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from importlib.util import module_from_spec, spec_from_file_location
|
||||
from pathlib import Path
|
||||
from types import ModuleType, SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
import json
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
from src.common.data_models.person_info_data_model import dump_group_cardname_records, parse_group_cardname_json
|
||||
|
||||
|
||||
class _DummyLogger:
|
||||
"""模拟日志记录器。"""
|
||||
|
||||
def debug(self, message: str) -> None:
|
||||
"""记录调试日志。
|
||||
|
||||
Args:
|
||||
message: 日志内容。
|
||||
"""
|
||||
del message
|
||||
|
||||
def info(self, message: str) -> None:
|
||||
"""记录信息日志。
|
||||
|
||||
Args:
|
||||
message: 日志内容。
|
||||
"""
|
||||
del message
|
||||
|
||||
def warning(self, message: str) -> None:
|
||||
"""记录警告日志。
|
||||
|
||||
Args:
|
||||
message: 日志内容。
|
||||
"""
|
||||
del message
|
||||
|
||||
def error(self, message: str) -> None:
|
||||
"""记录错误日志。
|
||||
|
||||
Args:
|
||||
message: 日志内容。
|
||||
"""
|
||||
del message
|
||||
|
||||
|
||||
class _DummyStatement:
|
||||
"""模拟 SQL 查询语句对象。"""
|
||||
|
||||
def where(self, condition: Any) -> "_DummyStatement":
|
||||
"""附加过滤条件。
|
||||
|
||||
Args:
|
||||
condition: 过滤条件。
|
||||
|
||||
Returns:
|
||||
_DummyStatement: 当前语句对象。
|
||||
"""
|
||||
del condition
|
||||
return self
|
||||
|
||||
def limit(self, value: int) -> "_DummyStatement":
|
||||
"""限制返回条数。
|
||||
|
||||
Args:
|
||||
value: 条数限制。
|
||||
|
||||
Returns:
|
||||
_DummyStatement: 当前语句对象。
|
||||
"""
|
||||
del value
|
||||
return self
|
||||
|
||||
|
||||
class _DummyColumn:
|
||||
"""模拟 SQLModel 列对象。"""
|
||||
|
||||
def is_not(self, value: Any) -> "_DummyColumn":
|
||||
"""模拟 `IS NOT` 条件构造。
|
||||
|
||||
Args:
|
||||
value: 比较值。
|
||||
|
||||
Returns:
|
||||
_DummyColumn: 当前列对象。
|
||||
"""
|
||||
del value
|
||||
return self
|
||||
|
||||
def __eq__(self, other: Any) -> "_DummyColumn":
|
||||
"""模拟等值条件构造。
|
||||
|
||||
Args:
|
||||
other: 比较值。
|
||||
|
||||
Returns:
|
||||
_DummyColumn: 当前列对象。
|
||||
"""
|
||||
del other
|
||||
return self
|
||||
|
||||
|
||||
class _DummyResult:
|
||||
"""模拟数据库查询结果。"""
|
||||
|
||||
def __init__(self, record: Any) -> None:
|
||||
"""初始化查询结果。
|
||||
|
||||
Args:
|
||||
record: 待返回的首条记录。
|
||||
"""
|
||||
self._record = record
|
||||
|
||||
def first(self) -> Any:
|
||||
"""返回第一条记录。
|
||||
|
||||
Returns:
|
||||
Any: 首条记录。
|
||||
"""
|
||||
return self._record
|
||||
|
||||
def all(self) -> list[Any]:
|
||||
"""返回全部结果。
|
||||
|
||||
Returns:
|
||||
list[Any]: 结果列表。
|
||||
"""
|
||||
if self._record is None:
|
||||
return []
|
||||
return self._record if isinstance(self._record, list) else [self._record]
|
||||
|
||||
|
||||
class _DummySession:
|
||||
"""模拟数据库 Session。"""
|
||||
|
||||
def __init__(self, record: Any) -> None:
|
||||
"""初始化 Session。
|
||||
|
||||
Args:
|
||||
record: `first()` 应返回的记录。
|
||||
"""
|
||||
self.record = record
|
||||
self.added_records: list[Any] = []
|
||||
|
||||
def __enter__(self) -> "_DummySession":
|
||||
"""进入上下文管理器。
|
||||
|
||||
Returns:
|
||||
_DummySession: 当前 Session。
|
||||
"""
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
"""退出上下文管理器。
|
||||
|
||||
Args:
|
||||
exc_type: 异常类型。
|
||||
exc_val: 异常值。
|
||||
exc_tb: 异常回溯。
|
||||
"""
|
||||
del exc_type
|
||||
del exc_val
|
||||
del exc_tb
|
||||
|
||||
def exec(self, statement: Any) -> _DummyResult:
|
||||
"""执行查询。
|
||||
|
||||
Args:
|
||||
statement: 查询语句。
|
||||
|
||||
Returns:
|
||||
_DummyResult: 模拟结果对象。
|
||||
"""
|
||||
del statement
|
||||
return _DummyResult(self.record)
|
||||
|
||||
def add(self, record: Any) -> None:
|
||||
"""记录被添加的对象。
|
||||
|
||||
Args:
|
||||
record: 被写入 Session 的对象。
|
||||
"""
|
||||
self.added_records.append(record)
|
||||
|
||||
|
||||
class _DummyPersonInfoRecord:
|
||||
"""模拟 `PersonInfo` ORM 模型。"""
|
||||
|
||||
person_id = "person_id"
|
||||
person_name = "person_name"
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
"""使用关键字参数初始化记录对象。
|
||||
|
||||
Args:
|
||||
**kwargs: 字段值。
|
||||
"""
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
|
||||
def _load_person_module(monkeypatch: pytest.MonkeyPatch, session: _DummySession) -> ModuleType:
|
||||
"""加载带依赖桩的 `person_info` 模块。
|
||||
|
||||
Args:
|
||||
monkeypatch: Pytest monkeypatch 工具。
|
||||
session: 提供给模块使用的假数据库 Session。
|
||||
|
||||
Returns:
|
||||
ModuleType: 加载后的模块对象。
|
||||
"""
|
||||
logger_module = ModuleType("src.common.logger")
|
||||
logger_module.get_logger = lambda name: _DummyLogger()
|
||||
monkeypatch.setitem(sys.modules, "src.common.logger", logger_module)
|
||||
|
||||
database_module = ModuleType("src.common.database.database")
|
||||
database_module.get_db_session = lambda: session
|
||||
monkeypatch.setitem(sys.modules, "src.common.database.database", database_module)
|
||||
|
||||
database_model_module = ModuleType("src.common.database.database_model")
|
||||
database_model_module.PersonInfo = _DummyPersonInfoRecord
|
||||
monkeypatch.setitem(sys.modules, "src.common.database.database_model", database_model_module)
|
||||
|
||||
llm_module = ModuleType("src.llm_models.utils_model")
|
||||
|
||||
class _DummyLLMRequest:
|
||||
"""模拟 LLMRequest。"""
|
||||
|
||||
def __init__(self, model_set: Any, request_type: str) -> None:
|
||||
"""初始化假请求对象。
|
||||
|
||||
Args:
|
||||
model_set: 模型配置。
|
||||
request_type: 请求类型。
|
||||
"""
|
||||
del model_set
|
||||
del request_type
|
||||
|
||||
llm_module.LLMRequest = _DummyLLMRequest
|
||||
monkeypatch.setitem(sys.modules, "src.llm_models.utils_model", llm_module)
|
||||
|
||||
config_module = ModuleType("src.config.config")
|
||||
config_module.global_config = SimpleNamespace(bot=SimpleNamespace(nickname="MaiBot"))
|
||||
config_module.model_config = SimpleNamespace(model_task_config=SimpleNamespace(tool_use="tool_use", utils="utils"))
|
||||
monkeypatch.setitem(sys.modules, "src.config.config", config_module)
|
||||
|
||||
chat_manager_module = ModuleType("src.chat.message_receive.chat_manager")
|
||||
chat_manager_module.chat_manager = SimpleNamespace()
|
||||
monkeypatch.setitem(sys.modules, "src.chat.message_receive.chat_manager", chat_manager_module)
|
||||
|
||||
module_path = Path(__file__).resolve().parents[2] / "src" / "person_info" / "person_info.py"
|
||||
spec = spec_from_file_location("person_info_group_cardname_test_module", module_path)
|
||||
assert spec is not None and spec.loader is not None
|
||||
|
||||
module = module_from_spec(spec)
|
||||
monkeypatch.setitem(sys.modules, spec.name, module)
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
monkeypatch.setattr(module, "select", lambda *args: _DummyStatement())
|
||||
monkeypatch.setattr(module, "col", lambda field: _DummyColumn())
|
||||
return module
|
||||
|
||||
|
||||
def test_parse_group_cardname_json_uses_canonical_key() -> None:
|
||||
"""群名片 JSON 解析应只使用 `group_cardname` 键名。"""
|
||||
parsed = parse_group_cardname_json(
|
||||
json.dumps(
|
||||
[
|
||||
{"group_id": "1001", "group_cardname": "现行字段"},
|
||||
],
|
||||
ensure_ascii=False,
|
||||
)
|
||||
)
|
||||
|
||||
assert parsed is not None
|
||||
assert [(item.group_id, item.group_cardname) for item in parsed] == [
|
||||
("1001", "现行字段"),
|
||||
]
|
||||
|
||||
|
||||
def test_dump_group_cardname_records_uses_canonical_key() -> None:
|
||||
"""群名片序列化应输出 `group_cardname` 键名。"""
|
||||
dumped = dump_group_cardname_records(
|
||||
[
|
||||
{"group_id": "1001", "group_cardname": "群昵称"},
|
||||
]
|
||||
)
|
||||
|
||||
assert json.loads(dumped) == [{"group_id": "1001", "group_cardname": "群昵称"}]
|
||||
|
||||
|
||||
def test_person_sync_to_database_uses_group_cardname_field(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""同步人物信息时应写入数据库模型的 `group_cardname` 字段。"""
|
||||
record = _DummyPersonInfoRecord()
|
||||
session = _DummySession(record)
|
||||
module = _load_person_module(monkeypatch, session)
|
||||
|
||||
person = module.Person.__new__(module.Person)
|
||||
person.is_known = True
|
||||
person.person_id = "person-1"
|
||||
person.platform = "qq"
|
||||
person.user_id = "10001"
|
||||
person.nickname = "看番的龙"
|
||||
person.person_name = "看番的龙"
|
||||
person.name_reason = "测试"
|
||||
person.know_times = 1
|
||||
person.know_since = 1700000000.0
|
||||
person.last_know = 1700000100.0
|
||||
person.memory_points = ["喜好:番剧:0.8"]
|
||||
person.group_cardname_list = [{"group_id": "20001", "group_cardname": "白泽大人"}]
|
||||
|
||||
person.sync_to_database()
|
||||
|
||||
assert record.group_cardname == '[{"group_id": "20001", "group_cardname": "白泽大人"}]'
|
||||
assert not hasattr(record, "group_nickname")
|
||||
|
||||
|
||||
def test_person_load_from_database_normalizes_group_cardname_payload(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""从数据库加载人物信息时应读取标准 `group_cardname` 结构。"""
|
||||
record = _DummyPersonInfoRecord(
|
||||
user_id="10001",
|
||||
platform="qq",
|
||||
is_known=True,
|
||||
user_nickname="看番的龙",
|
||||
person_name="看番的龙",
|
||||
name_reason=None,
|
||||
know_counts=2,
|
||||
memory_points='["喜好:番剧:0.8"]',
|
||||
group_cardname=json.dumps(
|
||||
[
|
||||
{"group_id": "20001", "group_cardname": "白泽大人"},
|
||||
],
|
||||
ensure_ascii=False,
|
||||
),
|
||||
)
|
||||
session = _DummySession(record)
|
||||
module = _load_person_module(monkeypatch, session)
|
||||
|
||||
person = module.Person.__new__(module.Person)
|
||||
person.person_id = "person-1"
|
||||
person.memory_points = []
|
||||
person.group_cardname_list = []
|
||||
|
||||
person.load_from_database()
|
||||
|
||||
assert person.group_cardname_list == [
|
||||
{"group_id": "20001", "group_cardname": "白泽大人"},
|
||||
]
|
||||
170
pytests/test_message_gateway_runtime.py
Normal file
170
pytests/test_message_gateway_runtime.py
Normal file
@@ -0,0 +1,170 @@
|
||||
"""消息网关运行时状态同步测试。"""
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
import pytest
|
||||
|
||||
from src.platform_io.manager import PlatformIOManager
|
||||
from src.platform_io.types import RouteKey
|
||||
from src.plugin_runtime.host.supervisor import PluginSupervisor
|
||||
from src.plugin_runtime.protocol.envelope import Envelope, MessageType
|
||||
|
||||
|
||||
def _make_request(method: str, plugin_id: str, payload: Dict[str, Any]) -> Envelope:
|
||||
"""构造一个 RPC 请求信封。
|
||||
|
||||
Args:
|
||||
method: RPC 方法名。
|
||||
plugin_id: 目标插件 ID。
|
||||
payload: 请求载荷。
|
||||
|
||||
Returns:
|
||||
Envelope: 标准 RPC 请求信封。
|
||||
"""
|
||||
|
||||
return Envelope(
|
||||
request_id=1,
|
||||
message_type=MessageType.REQUEST,
|
||||
method=method,
|
||||
plugin_id=plugin_id,
|
||||
payload=payload,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_gateway_runtime_state_binds_send_and_receive_routes(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""消息网关就绪后应同时绑定发送表和接收表。"""
|
||||
|
||||
import src.plugin_runtime.host.supervisor as supervisor_module
|
||||
|
||||
platform_io_manager = PlatformIOManager()
|
||||
monkeypatch.setattr(supervisor_module, "get_platform_io_manager", lambda: platform_io_manager)
|
||||
|
||||
supervisor = PluginSupervisor(plugin_dirs=[])
|
||||
register_response = await supervisor._handle_register_plugin(
|
||||
_make_request(
|
||||
"plugin.register_components",
|
||||
"napcat_plugin",
|
||||
{
|
||||
"plugin_id": "napcat_plugin",
|
||||
"plugin_version": "1.0.0",
|
||||
"components": [
|
||||
{
|
||||
"name": "napcat_gateway",
|
||||
"component_type": "MESSAGE_GATEWAY",
|
||||
"plugin_id": "napcat_plugin",
|
||||
"metadata": {
|
||||
"route_type": "duplex",
|
||||
"platform": "qq",
|
||||
"protocol": "napcat",
|
||||
},
|
||||
}
|
||||
],
|
||||
"capabilities_required": [],
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
assert register_response.error is None
|
||||
response = await supervisor._handle_update_message_gateway_state(
|
||||
_make_request(
|
||||
"host.update_message_gateway_state",
|
||||
"napcat_plugin",
|
||||
{
|
||||
"gateway_name": "napcat_gateway",
|
||||
"ready": True,
|
||||
"platform": "qq",
|
||||
"account_id": "10001",
|
||||
"scope": "primary",
|
||||
"metadata": {},
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
assert response.error is None
|
||||
assert response.payload["accepted"] is True
|
||||
|
||||
send_bindings = platform_io_manager.send_route_table.resolve_bindings(
|
||||
RouteKey(platform="qq", account_id="10001", scope="primary")
|
||||
)
|
||||
receive_bindings = platform_io_manager.receive_route_table.resolve_bindings(
|
||||
RouteKey(platform="qq", account_id="10001", scope="primary")
|
||||
)
|
||||
|
||||
assert [binding.driver_id for binding in send_bindings] == ["gateway:napcat_plugin:napcat_gateway"]
|
||||
assert [binding.driver_id for binding in receive_bindings] == ["gateway:napcat_plugin:napcat_gateway"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_gateway_runtime_state_unbinds_routes_when_not_ready(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""消息网关断开后应撤销发送表和接收表中的绑定。"""
|
||||
|
||||
import src.plugin_runtime.host.supervisor as supervisor_module
|
||||
|
||||
platform_io_manager = PlatformIOManager()
|
||||
monkeypatch.setattr(supervisor_module, "get_platform_io_manager", lambda: platform_io_manager)
|
||||
|
||||
supervisor = PluginSupervisor(plugin_dirs=[])
|
||||
await supervisor._handle_register_plugin(
|
||||
_make_request(
|
||||
"plugin.register_components",
|
||||
"napcat_plugin",
|
||||
{
|
||||
"plugin_id": "napcat_plugin",
|
||||
"plugin_version": "1.0.0",
|
||||
"components": [
|
||||
{
|
||||
"name": "napcat_gateway",
|
||||
"component_type": "MESSAGE_GATEWAY",
|
||||
"plugin_id": "napcat_plugin",
|
||||
"metadata": {
|
||||
"route_type": "duplex",
|
||||
"platform": "qq",
|
||||
"protocol": "napcat",
|
||||
},
|
||||
}
|
||||
],
|
||||
"capabilities_required": [],
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
await supervisor._handle_update_message_gateway_state(
|
||||
_make_request(
|
||||
"host.update_message_gateway_state",
|
||||
"napcat_plugin",
|
||||
{
|
||||
"gateway_name": "napcat_gateway",
|
||||
"ready": True,
|
||||
"platform": "qq",
|
||||
"account_id": "10001",
|
||||
"scope": "primary",
|
||||
"metadata": {},
|
||||
},
|
||||
)
|
||||
)
|
||||
response = await supervisor._handle_update_message_gateway_state(
|
||||
_make_request(
|
||||
"host.update_message_gateway_state",
|
||||
"napcat_plugin",
|
||||
{
|
||||
"gateway_name": "napcat_gateway",
|
||||
"ready": False,
|
||||
"platform": "qq",
|
||||
"account_id": "",
|
||||
"scope": "",
|
||||
"metadata": {},
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
assert response.error is None
|
||||
assert response.payload["accepted"] is True
|
||||
assert platform_io_manager.send_route_table.resolve_bindings(RouteKey(platform="qq", account_id="10001")) == []
|
||||
assert (
|
||||
platform_io_manager.receive_route_table.resolve_bindings(RouteKey(platform="qq", account_id="10001")) == []
|
||||
)
|
||||
132
pytests/test_napcat_adapter_sdk.py
Normal file
132
pytests/test_napcat_adapter_sdk.py
Normal file
@@ -0,0 +1,132 @@
|
||||
"""NapCat 插件与新 SDK 对接测试。"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
||||
PLUGINS_ROOT = PROJECT_ROOT / "plugins"
|
||||
SDK_ROOT = PROJECT_ROOT / "packages" / "maibot-plugin-sdk"
|
||||
|
||||
for import_path in (str(PLUGINS_ROOT), str(SDK_ROOT)):
|
||||
if import_path not in sys.path:
|
||||
sys.path.insert(0, import_path)
|
||||
|
||||
|
||||
class _FakeGatewayCapability:
|
||||
"""用于捕获消息网关状态上报的测试替身。"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""初始化测试替身。"""
|
||||
|
||||
self.calls: List[Dict[str, Any]] = []
|
||||
|
||||
async def update_state(
|
||||
self,
|
||||
gateway_name: str,
|
||||
*,
|
||||
ready: bool,
|
||||
platform: str = "",
|
||||
account_id: str = "",
|
||||
scope: str = "",
|
||||
metadata: Dict[str, Any] | None = None,
|
||||
) -> bool:
|
||||
"""记录一次状态上报请求。
|
||||
|
||||
Args:
|
||||
gateway_name: 网关组件名称。
|
||||
ready: 当前是否就绪。
|
||||
platform: 平台名称。
|
||||
account_id: 账号 ID。
|
||||
scope: 路由作用域。
|
||||
metadata: 附加元数据。
|
||||
|
||||
Returns:
|
||||
bool: 始终返回 ``True``,模拟 Host 接受状态更新。
|
||||
"""
|
||||
|
||||
self.calls.append(
|
||||
{
|
||||
"gateway_name": gateway_name,
|
||||
"ready": ready,
|
||||
"platform": platform,
|
||||
"account_id": account_id,
|
||||
"scope": scope,
|
||||
"metadata": metadata or {},
|
||||
}
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
def _load_napcat_sdk_symbols() -> tuple[Any, Any, Any, Any]:
|
||||
"""动态加载 NapCat 插件测试所需的符号。
|
||||
|
||||
Returns:
|
||||
tuple[Any, Any, Any, Any]:
|
||||
依次返回网关名常量、配置类、插件类和运行时状态管理器类。
|
||||
"""
|
||||
|
||||
constants_module = importlib.import_module("napcat_adapter.constants")
|
||||
config_module = importlib.import_module("napcat_adapter.config")
|
||||
plugin_module = importlib.import_module("napcat_adapter.plugin")
|
||||
runtime_state_module = importlib.import_module("napcat_adapter.runtime_state")
|
||||
return (
|
||||
constants_module.NAPCAT_GATEWAY_NAME,
|
||||
config_module.NapCatServerConfig,
|
||||
plugin_module.NapCatAdapterPlugin,
|
||||
runtime_state_module.NapCatRuntimeStateManager,
|
||||
)
|
||||
|
||||
|
||||
def test_napcat_plugin_collects_duplex_message_gateway() -> None:
|
||||
"""NapCat 插件应声明新的双工消息网关组件。"""
|
||||
|
||||
napcat_gateway_name, _napcat_server_config, napcat_plugin_cls, _runtime_state_cls = _load_napcat_sdk_symbols()
|
||||
plugin = napcat_plugin_cls()
|
||||
components = plugin.get_components()
|
||||
gateway_components = [
|
||||
component
|
||||
for component in components
|
||||
if component.get("type") == "MESSAGE_GATEWAY"
|
||||
]
|
||||
|
||||
assert len(gateway_components) == 1
|
||||
gateway_component = gateway_components[0]
|
||||
assert gateway_component["name"] == napcat_gateway_name
|
||||
assert gateway_component["metadata"]["route_type"] == "duplex"
|
||||
assert gateway_component["metadata"]["platform"] == "qq"
|
||||
assert gateway_component["metadata"]["protocol"] == "napcat"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runtime_state_reports_via_gateway_capability() -> None:
|
||||
"""NapCat 运行时状态应通过新的消息网关能力上报。"""
|
||||
|
||||
napcat_gateway_name, napcat_server_config_cls, _napcat_plugin_cls, runtime_state_cls = _load_napcat_sdk_symbols()
|
||||
gateway_capability = _FakeGatewayCapability()
|
||||
runtime_state_manager = runtime_state_cls(
|
||||
gateway_capability=gateway_capability,
|
||||
logger=logging.getLogger("test.napcat_adapter"),
|
||||
gateway_name=napcat_gateway_name,
|
||||
)
|
||||
|
||||
connected = await runtime_state_manager.report_connected(
|
||||
"10001",
|
||||
napcat_server_config_cls(connection_id="primary"),
|
||||
)
|
||||
await runtime_state_manager.report_disconnected()
|
||||
|
||||
assert connected is True
|
||||
assert gateway_capability.calls[0]["gateway_name"] == napcat_gateway_name
|
||||
assert gateway_capability.calls[0]["ready"] is True
|
||||
assert gateway_capability.calls[0]["platform"] == "qq"
|
||||
assert gateway_capability.calls[0]["account_id"] == "10001"
|
||||
assert gateway_capability.calls[0]["scope"] == "primary"
|
||||
assert gateway_capability.calls[1]["gateway_name"] == napcat_gateway_name
|
||||
assert gateway_capability.calls[1]["ready"] is False
|
||||
assert gateway_capability.calls[1]["platform"] == "qq"
|
||||
209
pytests/test_platform_io_dedupe.py
Normal file
209
pytests/test_platform_io_dedupe.py
Normal file
@@ -0,0 +1,209 @@
|
||||
"""Platform IO 入站去重策略测试。"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from src.platform_io.drivers.base import PlatformIODriver
|
||||
from src.platform_io.manager import PlatformIOManager
|
||||
from src.platform_io.types import DeliveryReceipt, DeliveryStatus, DriverDescriptor, DriverKind, InboundMessageEnvelope, RouteBinding, RouteKey
|
||||
|
||||
|
||||
def _build_envelope(
|
||||
*,
|
||||
dedupe_key: str | None = None,
|
||||
external_message_id: str | None = None,
|
||||
session_message_id: str | None = None,
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
) -> InboundMessageEnvelope:
|
||||
"""构造测试用入站信封。
|
||||
|
||||
Args:
|
||||
dedupe_key: 显式去重键。
|
||||
external_message_id: 平台侧消息 ID。
|
||||
session_message_id: 规范化消息对象上的消息 ID。
|
||||
payload: 原始载荷。
|
||||
|
||||
Returns:
|
||||
InboundMessageEnvelope: 测试用入站消息信封。
|
||||
"""
|
||||
session_message = None
|
||||
if session_message_id is not None:
|
||||
session_message = SimpleNamespace(message_id=session_message_id)
|
||||
|
||||
return InboundMessageEnvelope(
|
||||
route_key=RouteKey(platform="qq", account_id="10001", scope="main"),
|
||||
driver_id="plugin.napcat",
|
||||
driver_kind=DriverKind.PLUGIN,
|
||||
dedupe_key=dedupe_key,
|
||||
external_message_id=external_message_id,
|
||||
session_message=session_message,
|
||||
payload=payload,
|
||||
)
|
||||
|
||||
|
||||
class _StubPlatformIODriver(PlatformIODriver):
|
||||
"""测试用 Platform IO 驱动。"""
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
message: Any,
|
||||
route_key: RouteKey,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> DeliveryReceipt:
|
||||
"""返回一个固定的成功回执。
|
||||
|
||||
Args:
|
||||
message: 待发送的消息对象。
|
||||
route_key: 本次发送使用的路由键。
|
||||
metadata: 额外发送元数据。
|
||||
|
||||
Returns:
|
||||
DeliveryReceipt: 固定的成功回执。
|
||||
"""
|
||||
return DeliveryReceipt(
|
||||
internal_message_id=str(getattr(message, "message_id", "stub-message-id")),
|
||||
route_key=route_key,
|
||||
status=DeliveryStatus.SENT,
|
||||
driver_id=self.driver_id,
|
||||
driver_kind=self.descriptor.kind,
|
||||
)
|
||||
|
||||
|
||||
def _build_manager() -> PlatformIOManager:
|
||||
"""构造带有最小接收路由的 Broker 管理器。
|
||||
|
||||
Returns:
|
||||
PlatformIOManager: 已注册测试驱动并绑定接收路由的 Broker。
|
||||
"""
|
||||
manager = PlatformIOManager()
|
||||
driver = _StubPlatformIODriver(
|
||||
DriverDescriptor(
|
||||
driver_id="plugin.napcat",
|
||||
kind=DriverKind.PLUGIN,
|
||||
platform="qq",
|
||||
account_id="10001",
|
||||
scope="main",
|
||||
)
|
||||
)
|
||||
manager.register_driver(driver)
|
||||
manager.bind_receive_route(
|
||||
RouteBinding(
|
||||
route_key=RouteKey(platform="qq", account_id="10001", scope="main"),
|
||||
driver_id=driver.driver_id,
|
||||
driver_kind=driver.descriptor.kind,
|
||||
)
|
||||
)
|
||||
return manager
|
||||
|
||||
|
||||
class TestPlatformIODedupe:
|
||||
"""Platform IO 去重测试。"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_accept_inbound_dedupes_by_external_message_id(self) -> None:
|
||||
"""相同平台消息 ID 的重复入站应被抑制。"""
|
||||
manager = _build_manager()
|
||||
accepted_envelopes: List[InboundMessageEnvelope] = []
|
||||
|
||||
async def dispatcher(envelope: InboundMessageEnvelope) -> None:
|
||||
"""记录被成功接收的入站消息。
|
||||
|
||||
Args:
|
||||
envelope: 被 Broker 接受的入站消息。
|
||||
"""
|
||||
accepted_envelopes.append(envelope)
|
||||
|
||||
manager.set_inbound_dispatcher(dispatcher)
|
||||
|
||||
first_envelope = _build_envelope(
|
||||
external_message_id="msg-1",
|
||||
payload={"message": "hello"},
|
||||
)
|
||||
second_envelope = _build_envelope(
|
||||
external_message_id="msg-1",
|
||||
payload={"message": "hello"},
|
||||
)
|
||||
|
||||
assert await manager.accept_inbound(first_envelope) is True
|
||||
assert await manager.accept_inbound(second_envelope) is False
|
||||
assert len(accepted_envelopes) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_accept_inbound_without_stable_identity_does_not_guess_duplicate(self) -> None:
|
||||
"""缺少稳定身份时,不应仅凭 payload 内容猜测重复消息。"""
|
||||
manager = _build_manager()
|
||||
accepted_envelopes: List[InboundMessageEnvelope] = []
|
||||
|
||||
async def dispatcher(envelope: InboundMessageEnvelope) -> None:
|
||||
"""记录被成功接收的入站消息。
|
||||
|
||||
Args:
|
||||
envelope: 被 Broker 接受的入站消息。
|
||||
"""
|
||||
accepted_envelopes.append(envelope)
|
||||
|
||||
manager.set_inbound_dispatcher(dispatcher)
|
||||
|
||||
first_envelope = _build_envelope(payload={"message": "same-payload"})
|
||||
second_envelope = _build_envelope(payload={"message": "same-payload"})
|
||||
|
||||
assert await manager.accept_inbound(first_envelope) is True
|
||||
assert await manager.accept_inbound(second_envelope) is True
|
||||
assert len(accepted_envelopes) == 2
|
||||
|
||||
def test_build_inbound_dedupe_key_prefers_explicit_identity(self) -> None:
|
||||
"""去重键应只来自显式或稳定的技术身份。"""
|
||||
explicit_envelope = _build_envelope(dedupe_key="dedupe-1", external_message_id="msg-1")
|
||||
session_message_envelope = _build_envelope(session_message_id="session-1")
|
||||
payload_only_envelope = _build_envelope(payload={"message": "hello"})
|
||||
|
||||
assert PlatformIOManager._build_inbound_dedupe_key(explicit_envelope) == "plugin.napcat:dedupe-1"
|
||||
assert PlatformIOManager._build_inbound_dedupe_key(session_message_envelope) == "plugin.napcat:session-1"
|
||||
assert PlatformIOManager._build_inbound_dedupe_key(payload_only_envelope) is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_message_fans_out_to_all_matching_routes(self) -> None:
|
||||
"""同一路由命中多条发送链路时应全部发送。"""
|
||||
|
||||
manager = PlatformIOManager()
|
||||
first_driver = _StubPlatformIODriver(
|
||||
DriverDescriptor(
|
||||
driver_id="plugin.gateway_a",
|
||||
kind=DriverKind.PLUGIN,
|
||||
platform="qq",
|
||||
)
|
||||
)
|
||||
second_driver = _StubPlatformIODriver(
|
||||
DriverDescriptor(
|
||||
driver_id="plugin.gateway_b",
|
||||
kind=DriverKind.PLUGIN,
|
||||
platform="qq",
|
||||
)
|
||||
)
|
||||
manager.register_driver(first_driver)
|
||||
manager.register_driver(second_driver)
|
||||
manager.bind_send_route(
|
||||
RouteBinding(
|
||||
route_key=RouteKey(platform="qq"),
|
||||
driver_id=first_driver.driver_id,
|
||||
driver_kind=first_driver.descriptor.kind,
|
||||
)
|
||||
)
|
||||
manager.bind_send_route(
|
||||
RouteBinding(
|
||||
route_key=RouteKey(platform="qq"),
|
||||
driver_id=second_driver.driver_id,
|
||||
driver_kind=second_driver.descriptor.kind,
|
||||
)
|
||||
)
|
||||
|
||||
message = SimpleNamespace(message_id="internal-msg-1")
|
||||
result = await manager.send_message(message, RouteKey(platform="qq"))
|
||||
|
||||
assert result.has_success is True
|
||||
assert [receipt.driver_id for receipt in result.sent_receipts] == [
|
||||
"plugin.gateway_a",
|
||||
"plugin.gateway_b",
|
||||
]
|
||||
178
pytests/test_platform_io_legacy_driver.py
Normal file
178
pytests/test_platform_io_legacy_driver.py
Normal file
@@ -0,0 +1,178 @@
|
||||
"""Platform IO legacy driver 回归测试。"""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from src.chat.utils import utils as chat_utils
|
||||
from src.chat.message_receive import uni_message_sender
|
||||
from src.platform_io.drivers.base import PlatformIODriver
|
||||
from src.platform_io.drivers.legacy_driver import LegacyPlatformDriver
|
||||
from src.platform_io.manager import PlatformIOManager
|
||||
from src.platform_io.types import DeliveryReceipt, DeliveryStatus, DriverDescriptor, DriverKind, RouteBinding, RouteKey
|
||||
|
||||
|
||||
class _PluginDriver(PlatformIODriver):
|
||||
"""测试用插件发送驱动。"""
|
||||
|
||||
def __init__(self, driver_id: str, platform: str) -> None:
|
||||
"""初始化测试驱动。
|
||||
|
||||
Args:
|
||||
driver_id: 驱动 ID。
|
||||
platform: 负责的平台名称。
|
||||
"""
|
||||
super().__init__(
|
||||
DriverDescriptor(
|
||||
driver_id=driver_id,
|
||||
kind=DriverKind.PLUGIN,
|
||||
platform=platform,
|
||||
plugin_id="test.plugin",
|
||||
)
|
||||
)
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
message: Any,
|
||||
route_key: RouteKey,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> DeliveryReceipt:
|
||||
"""返回一个固定成功回执。
|
||||
|
||||
Args:
|
||||
message: 待发送消息。
|
||||
route_key: 当前路由键。
|
||||
metadata: 发送元数据。
|
||||
|
||||
Returns:
|
||||
DeliveryReceipt: 固定成功回执。
|
||||
"""
|
||||
del metadata
|
||||
return DeliveryReceipt(
|
||||
internal_message_id=str(message.message_id),
|
||||
route_key=route_key,
|
||||
status=DeliveryStatus.SENT,
|
||||
driver_id=self.driver_id,
|
||||
driver_kind=self.descriptor.kind,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_platform_io_uses_legacy_driver_when_no_explicit_send_route(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""没有显式发送路由时,应由 Platform IO 回退到 legacy driver。"""
|
||||
manager = PlatformIOManager()
|
||||
monkeypatch.setattr(chat_utils, "get_all_bot_accounts", lambda: {"qq": "bot-qq"})
|
||||
|
||||
try:
|
||||
await manager.ensure_send_pipeline_ready()
|
||||
|
||||
fallback_drivers = manager.resolve_drivers(RouteKey(platform="qq"))
|
||||
assert [driver.driver_id for driver in fallback_drivers] == ["legacy.send.qq"]
|
||||
|
||||
plugin_driver = _PluginDriver(driver_id="plugin.qq.sender", platform="qq")
|
||||
await manager.add_driver(plugin_driver)
|
||||
manager.bind_send_route(
|
||||
RouteBinding(
|
||||
route_key=RouteKey(platform="qq"),
|
||||
driver_id=plugin_driver.driver_id,
|
||||
driver_kind=plugin_driver.descriptor.kind,
|
||||
)
|
||||
)
|
||||
|
||||
explicit_drivers = manager.resolve_drivers(RouteKey(platform="qq"))
|
||||
assert [driver.driver_id for driver in explicit_drivers] == ["plugin.qq.sender", "legacy.send.qq"]
|
||||
finally:
|
||||
await manager.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_platform_io_broadcasts_to_plugin_and_legacy_driver(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""同一路由命中插件驱动与 legacy driver 时,应同时广播发送。"""
|
||||
|
||||
manager = PlatformIOManager()
|
||||
legacy_calls: list[dict[str, Any]] = []
|
||||
monkeypatch.setattr(chat_utils, "get_all_bot_accounts", lambda: {"qq": "bot-qq"})
|
||||
|
||||
async def _fake_send_prepared_message_to_platform(message: Any, show_log: bool = True) -> bool:
|
||||
"""记录 legacy driver 调用。"""
|
||||
|
||||
legacy_calls.append({"message": message, "show_log": show_log})
|
||||
return True
|
||||
|
||||
monkeypatch.setattr(
|
||||
uni_message_sender,
|
||||
"send_prepared_message_to_platform",
|
||||
_fake_send_prepared_message_to_platform,
|
||||
)
|
||||
|
||||
try:
|
||||
await manager.ensure_send_pipeline_ready()
|
||||
|
||||
plugin_driver = _PluginDriver(driver_id="plugin.qq.sender", platform="qq")
|
||||
await manager.add_driver(plugin_driver)
|
||||
manager.bind_send_route(
|
||||
RouteBinding(
|
||||
route_key=RouteKey(platform="qq"),
|
||||
driver_id=plugin_driver.driver_id,
|
||||
driver_kind=plugin_driver.descriptor.kind,
|
||||
)
|
||||
)
|
||||
|
||||
message = type("FakeMessage", (), {"message_id": "message-1"})()
|
||||
batch = await manager.send_message(
|
||||
message=message,
|
||||
route_key=RouteKey(platform="qq"),
|
||||
metadata={"show_log": False},
|
||||
)
|
||||
|
||||
assert sorted(receipt.driver_id for receipt in batch.sent_receipts) == [
|
||||
"legacy.send.qq",
|
||||
"plugin.qq.sender",
|
||||
]
|
||||
assert batch.failed_receipts == []
|
||||
assert len(legacy_calls) == 1
|
||||
assert legacy_calls[0]["message"] is message
|
||||
assert legacy_calls[0]["show_log"] is False
|
||||
finally:
|
||||
await manager.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_legacy_platform_driver_uses_prepared_universal_sender(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""legacy driver 应复用已预处理消息的旧链发送函数。"""
|
||||
calls: list[dict[str, Any]] = []
|
||||
|
||||
async def _fake_send_prepared_message_to_platform(message: Any, show_log: bool = True) -> bool:
|
||||
"""记录 legacy driver 调用。"""
|
||||
calls.append({"message": message, "show_log": show_log})
|
||||
return True
|
||||
|
||||
monkeypatch.setattr(
|
||||
uni_message_sender,
|
||||
"send_prepared_message_to_platform",
|
||||
_fake_send_prepared_message_to_platform,
|
||||
)
|
||||
|
||||
driver = LegacyPlatformDriver(
|
||||
driver_id="legacy.send.qq",
|
||||
platform="qq",
|
||||
account_id="bot-qq",
|
||||
)
|
||||
message = type("FakeMessage", (), {"message_id": "message-1"})()
|
||||
receipt = await driver.send_message(
|
||||
message=message,
|
||||
route_key=RouteKey(platform="qq"),
|
||||
metadata={"show_log": False},
|
||||
)
|
||||
|
||||
assert len(calls) == 1
|
||||
assert calls[0]["message"] is message
|
||||
assert calls[0]["show_log"] is False
|
||||
assert receipt.status == DeliveryStatus.SENT
|
||||
assert receipt.driver_id == "legacy.send.qq"
|
||||
87
pytests/test_plugin_message_utils_runtime.py
Normal file
87
pytests/test_plugin_message_utils_runtime.py
Normal file
@@ -0,0 +1,87 @@
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import sys
|
||||
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.common.data_models.mai_message_data_model import GroupInfo, MessageInfo, UserInfo
|
||||
from src.common.data_models.message_component_data_model import (
|
||||
ForwardComponent,
|
||||
ForwardNodeComponent,
|
||||
ImageComponent,
|
||||
MessageSequence,
|
||||
ReplyComponent,
|
||||
TextComponent,
|
||||
VoiceComponent,
|
||||
)
|
||||
from src.plugin_runtime.host.message_utils import PluginMessageUtils
|
||||
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
|
||||
def test_plugin_message_utils_preserves_binary_components_and_reply_metadata() -> None:
|
||||
message = SessionMessage(message_id="msg-1", timestamp=datetime.now(), platform="qq")
|
||||
message.message_info = MessageInfo(
|
||||
user_info=UserInfo(user_id="10001", user_nickname="tester"),
|
||||
group_info=GroupInfo(group_id="20001", group_name="group"),
|
||||
additional_config={"self_id": "999"},
|
||||
)
|
||||
message.session_id = "qq:20001:10001"
|
||||
message.processed_plain_text = "binary payload"
|
||||
message.display_message = "binary payload"
|
||||
message.raw_message = MessageSequence(
|
||||
components=[
|
||||
TextComponent("hello"),
|
||||
ImageComponent(binary_hash="", binary_data=b"image-bytes", content=""),
|
||||
VoiceComponent(binary_hash="", binary_data=b"voice-bytes", content=""),
|
||||
ReplyComponent(
|
||||
target_message_id="origin-1",
|
||||
target_message_content="origin text",
|
||||
target_message_sender_id="42",
|
||||
target_message_sender_nickname="alice",
|
||||
target_message_sender_cardname="Alice",
|
||||
),
|
||||
ForwardNodeComponent(
|
||||
forward_components=[
|
||||
ForwardComponent(
|
||||
user_nickname="bob",
|
||||
user_id="43",
|
||||
user_cardname="Bob",
|
||||
message_id="forward-1",
|
||||
content=[
|
||||
TextComponent("node-text"),
|
||||
ImageComponent(binary_hash="", binary_data=b"node-image", content=""),
|
||||
],
|
||||
)
|
||||
]
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
message_dict = PluginMessageUtils._session_message_to_dict(message)
|
||||
rebuilt_message = PluginMessageUtils._build_session_message_from_dict(dict(message_dict))
|
||||
|
||||
image_component = rebuilt_message.raw_message.components[1]
|
||||
voice_component = rebuilt_message.raw_message.components[2]
|
||||
reply_component = rebuilt_message.raw_message.components[3]
|
||||
forward_component = rebuilt_message.raw_message.components[4]
|
||||
|
||||
assert isinstance(image_component, ImageComponent)
|
||||
assert image_component.binary_data == b"image-bytes"
|
||||
|
||||
assert isinstance(voice_component, VoiceComponent)
|
||||
assert voice_component.binary_data == b"voice-bytes"
|
||||
|
||||
assert isinstance(reply_component, ReplyComponent)
|
||||
assert reply_component.target_message_id == "origin-1"
|
||||
assert reply_component.target_message_content == "origin text"
|
||||
assert reply_component.target_message_sender_id == "42"
|
||||
assert reply_component.target_message_sender_nickname == "alice"
|
||||
assert reply_component.target_message_sender_cardname == "Alice"
|
||||
|
||||
assert isinstance(forward_component, ForwardNodeComponent)
|
||||
assert isinstance(forward_component.forward_components[0].content[1], ImageComponent)
|
||||
assert forward_component.forward_components[0].content[1].binary_data == b"node-image"
|
||||
File diff suppressed because it is too large
Load Diff
284
pytests/test_plugin_runtime_action_bridge.py
Normal file
284
pytests/test_plugin_runtime_action_bridge.py
Normal file
@@ -0,0 +1,284 @@
|
||||
"""核心组件查询层与插件运行时聚合测试。"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
import src.plugin_runtime.integration as integration_module
|
||||
|
||||
from src.core.types import ActionInfo, ToolInfo
|
||||
from src.plugin_runtime.component_query import component_query_service
|
||||
from src.plugin_runtime.host.supervisor import PluginSupervisor
|
||||
|
||||
|
||||
class _FakeRuntimeManager:
|
||||
"""测试用插件运行时管理器。"""
|
||||
|
||||
def __init__(self, supervisor: PluginSupervisor, plugin_id: str, plugin_config: dict[str, Any]) -> None:
|
||||
"""初始化测试用运行时管理器。
|
||||
|
||||
Args:
|
||||
supervisor: 持有测试组件的监督器。
|
||||
plugin_id: 目标插件 ID。
|
||||
plugin_config: 需要返回的插件配置。
|
||||
"""
|
||||
|
||||
self.supervisors = [supervisor]
|
||||
self._plugin_id = plugin_id
|
||||
self._plugin_config = plugin_config
|
||||
|
||||
def _get_supervisor_for_plugin(self, plugin_id: str) -> PluginSupervisor | None:
|
||||
"""按插件 ID 返回对应监督器。
|
||||
|
||||
Args:
|
||||
plugin_id: 目标插件 ID。
|
||||
|
||||
Returns:
|
||||
PluginSupervisor | None: 命中时返回监督器。
|
||||
"""
|
||||
|
||||
return self.supervisors[0] if plugin_id == self._plugin_id else None
|
||||
|
||||
def _load_plugin_config_for_supervisor(self, supervisor: Any, plugin_id: str) -> dict[str, Any]:
|
||||
"""返回测试配置。
|
||||
|
||||
Args:
|
||||
supervisor: 监督器实例。
|
||||
plugin_id: 目标插件 ID。
|
||||
|
||||
Returns:
|
||||
dict[str, Any]: 测试配置内容。
|
||||
"""
|
||||
|
||||
del supervisor
|
||||
if plugin_id != self._plugin_id:
|
||||
return {}
|
||||
return dict(self._plugin_config)
|
||||
|
||||
|
||||
def _install_runtime_manager(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
supervisor: PluginSupervisor,
|
||||
plugin_id: str,
|
||||
plugin_config: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""为测试安装假的运行时管理器。
|
||||
|
||||
Args:
|
||||
monkeypatch: pytest monkeypatch 对象。
|
||||
supervisor: 持有测试组件的监督器。
|
||||
plugin_id: 测试插件 ID。
|
||||
plugin_config: 可选的测试配置内容。
|
||||
"""
|
||||
|
||||
fake_manager = _FakeRuntimeManager(supervisor, plugin_id, plugin_config or {"enabled": True})
|
||||
monkeypatch.setattr(integration_module, "get_plugin_runtime_manager", lambda: fake_manager)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_core_component_registry_reads_runtime_action_and_executor(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""核心查询层应直接读取运行时 Action,并返回 RPC 执行闭包。"""
|
||||
|
||||
plugin_id = "runtime_action_bridge_plugin"
|
||||
action_name = "runtime_action_bridge_test"
|
||||
supervisor = PluginSupervisor(plugin_dirs=[])
|
||||
captured: dict[str, Any] = {}
|
||||
|
||||
supervisor.component_registry.register_component(
|
||||
name=action_name,
|
||||
component_type="ACTION",
|
||||
plugin_id=plugin_id,
|
||||
metadata={
|
||||
"description": "发送一个测试回复",
|
||||
"enabled": True,
|
||||
"activation_type": "keyword",
|
||||
"activation_probability": 0.25,
|
||||
"activation_keywords": ["测试", "hello"],
|
||||
"action_parameters": {"target": "目标对象"},
|
||||
"action_require": ["需要发送回复时使用"],
|
||||
"associated_types": ["text"],
|
||||
"parallel_action": True,
|
||||
},
|
||||
)
|
||||
_install_runtime_manager(monkeypatch, supervisor, plugin_id, {"enabled": True, "mode": "test"})
|
||||
|
||||
async def fake_invoke_plugin(
|
||||
method: str,
|
||||
plugin_id: str,
|
||||
component_name: str,
|
||||
args: dict[str, Any] | None = None,
|
||||
timeout_ms: int = 30000,
|
||||
) -> Any:
|
||||
"""模拟动作 RPC 调用。"""
|
||||
|
||||
captured["method"] = method
|
||||
captured["plugin_id"] = plugin_id
|
||||
captured["component_name"] = component_name
|
||||
captured["args"] = args or {}
|
||||
captured["timeout_ms"] = timeout_ms
|
||||
return SimpleNamespace(payload={"success": True, "result": (True, "runtime action executed")})
|
||||
|
||||
monkeypatch.setattr(supervisor, "invoke_plugin", fake_invoke_plugin)
|
||||
|
||||
action_info = component_query_service.get_action_info(action_name)
|
||||
assert isinstance(action_info, ActionInfo)
|
||||
assert action_info.plugin_name == plugin_id
|
||||
assert action_info.description == "发送一个测试回复"
|
||||
assert action_info.activation_keywords == ["测试", "hello"]
|
||||
assert action_info.random_activation_probability == 0.25
|
||||
assert action_info.parallel_action is True
|
||||
assert action_name in component_query_service.get_default_actions()
|
||||
assert component_query_service.get_plugin_config(plugin_id) == {"enabled": True, "mode": "test"}
|
||||
|
||||
executor = component_query_service.get_action_executor(action_name)
|
||||
assert executor is not None
|
||||
|
||||
success, reason = await executor(
|
||||
action_data={"target": "MaiBot"},
|
||||
action_reasoning="当前适合使用这个动作",
|
||||
cycle_timers={"planner": 0.1},
|
||||
thinking_id="tid-1",
|
||||
chat_stream=SimpleNamespace(session_id="stream-1"),
|
||||
log_prefix="[test]",
|
||||
shutting_down=False,
|
||||
plugin_config={"enabled": True},
|
||||
)
|
||||
|
||||
assert success is True
|
||||
assert reason == "runtime action executed"
|
||||
assert captured["method"] == "plugin.invoke_action"
|
||||
assert captured["plugin_id"] == plugin_id
|
||||
assert captured["component_name"] == action_name
|
||||
assert captured["args"]["stream_id"] == "stream-1"
|
||||
assert captured["args"]["chat_id"] == "stream-1"
|
||||
assert captured["args"]["reasoning"] == "当前适合使用这个动作"
|
||||
assert captured["args"]["target"] == "MaiBot"
|
||||
assert captured["args"]["action_data"] == {"target": "MaiBot"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_core_component_registry_reads_runtime_command_and_executor(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""核心查询层应直接使用运行时命令匹配与执行闭包。"""
|
||||
|
||||
plugin_id = "runtime_command_bridge_plugin"
|
||||
command_name = "runtime_command_bridge_test"
|
||||
supervisor = PluginSupervisor(plugin_dirs=[])
|
||||
captured: dict[str, Any] = {}
|
||||
|
||||
supervisor.component_registry.register_component(
|
||||
name=command_name,
|
||||
component_type="COMMAND",
|
||||
plugin_id=plugin_id,
|
||||
metadata={
|
||||
"description": "测试命令",
|
||||
"enabled": True,
|
||||
"command_pattern": r"^/test(?:\s+.+)?$",
|
||||
"aliases": ["/hello"],
|
||||
"intercept_message_level": 1,
|
||||
},
|
||||
)
|
||||
_install_runtime_manager(monkeypatch, supervisor, plugin_id, {"mode": "command"})
|
||||
|
||||
async def fake_invoke_plugin(
|
||||
method: str,
|
||||
plugin_id: str,
|
||||
component_name: str,
|
||||
args: dict[str, Any] | None = None,
|
||||
timeout_ms: int = 30000,
|
||||
) -> Any:
|
||||
"""模拟命令 RPC 调用。"""
|
||||
|
||||
captured["method"] = method
|
||||
captured["plugin_id"] = plugin_id
|
||||
captured["component_name"] = component_name
|
||||
captured["args"] = args or {}
|
||||
captured["timeout_ms"] = timeout_ms
|
||||
return SimpleNamespace(payload={"success": True, "result": (True, "command ok", True)})
|
||||
|
||||
monkeypatch.setattr(supervisor, "invoke_plugin", fake_invoke_plugin)
|
||||
|
||||
matched = component_query_service.find_command_by_text("/test hello")
|
||||
assert matched is not None
|
||||
command_executor, matched_groups, command_info = matched
|
||||
|
||||
assert matched_groups == {}
|
||||
assert command_info.plugin_name == plugin_id
|
||||
assert command_info.command_pattern == r"^/test(?:\s+.+)?$"
|
||||
|
||||
success, response_text, intercept = await command_executor(
|
||||
message=SimpleNamespace(processed_plain_text="/test hello", session_id="stream-2"),
|
||||
plugin_config={"mode": "command"},
|
||||
matched_groups=matched_groups,
|
||||
)
|
||||
|
||||
assert success is True
|
||||
assert response_text == "command ok"
|
||||
assert intercept is True
|
||||
assert captured["method"] == "plugin.invoke_command"
|
||||
assert captured["plugin_id"] == plugin_id
|
||||
assert captured["component_name"] == command_name
|
||||
assert captured["args"]["text"] == "/test hello"
|
||||
assert captured["args"]["stream_id"] == "stream-2"
|
||||
assert captured["args"]["plugin_config"] == {"mode": "command"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_core_component_registry_reads_runtime_tools_and_executor(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""核心查询层应直接读取运行时 Tool,并返回 RPC 执行闭包。"""
|
||||
|
||||
plugin_id = "runtime_tool_bridge_plugin"
|
||||
tool_name = "runtime_tool_bridge_test"
|
||||
supervisor = PluginSupervisor(plugin_dirs=[])
|
||||
|
||||
supervisor.component_registry.register_component(
|
||||
name=tool_name,
|
||||
component_type="TOOL",
|
||||
plugin_id=plugin_id,
|
||||
metadata={
|
||||
"description": "测试工具",
|
||||
"enabled": True,
|
||||
"parameters": [
|
||||
{
|
||||
"name": "query",
|
||||
"param_type": "string",
|
||||
"description": "查询词",
|
||||
"required": True,
|
||||
}
|
||||
],
|
||||
},
|
||||
)
|
||||
_install_runtime_manager(monkeypatch, supervisor, plugin_id)
|
||||
|
||||
async def fake_invoke_plugin(
|
||||
method: str,
|
||||
plugin_id: str,
|
||||
component_name: str,
|
||||
args: dict[str, Any] | None = None,
|
||||
timeout_ms: int = 30000,
|
||||
) -> Any:
|
||||
"""模拟工具 RPC 调用。"""
|
||||
|
||||
del timeout_ms
|
||||
assert method == "plugin.invoke_tool"
|
||||
assert plugin_id == "runtime_tool_bridge_plugin"
|
||||
assert component_name == "runtime_tool_bridge_test"
|
||||
assert args == {"query": "MaiBot"}
|
||||
return SimpleNamespace(payload={"success": True, "result": {"content": "tool ok"}})
|
||||
|
||||
monkeypatch.setattr(supervisor, "invoke_plugin", fake_invoke_plugin)
|
||||
|
||||
tool_info = component_query_service.get_tool_info(tool_name)
|
||||
assert isinstance(tool_info, ToolInfo)
|
||||
assert tool_info.tool_description == "测试工具"
|
||||
assert tool_name in component_query_service.get_llm_available_tools()
|
||||
|
||||
executor = component_query_service.get_tool_executor(tool_name)
|
||||
assert executor is not None
|
||||
assert await executor({"query": "MaiBot"}) == {"content": "tool ok"}
|
||||
524
pytests/test_plugin_runtime_api.py
Normal file
524
pytests/test_plugin_runtime_api.py
Normal file
@@ -0,0 +1,524 @@
|
||||
"""插件 API 注册与调用测试。"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import pytest
|
||||
|
||||
from src.plugin_runtime.integration import PluginRuntimeManager
|
||||
from src.plugin_runtime.host.supervisor import PluginSupervisor
|
||||
from src.plugin_runtime.protocol.envelope import (
|
||||
ComponentDeclaration,
|
||||
Envelope,
|
||||
MessageType,
|
||||
RegisterPluginPayload,
|
||||
UnregisterPluginPayload,
|
||||
)
|
||||
|
||||
|
||||
def _build_manager(*supervisors: PluginSupervisor) -> PluginRuntimeManager:
|
||||
"""构造一个最小可用的插件运行时管理器。
|
||||
|
||||
Args:
|
||||
*supervisors: 需要挂载的监督器列表。
|
||||
|
||||
Returns:
|
||||
PluginRuntimeManager: 已注入监督器的运行时管理器。
|
||||
"""
|
||||
|
||||
manager = PluginRuntimeManager()
|
||||
if supervisors:
|
||||
manager._builtin_supervisor = supervisors[0]
|
||||
if len(supervisors) > 1:
|
||||
manager._third_party_supervisor = supervisors[1]
|
||||
return manager
|
||||
|
||||
|
||||
async def _register_plugin(
|
||||
supervisor: PluginSupervisor,
|
||||
plugin_id: str,
|
||||
components: List[Dict[str, Any]],
|
||||
) -> Envelope:
|
||||
"""通过 Supervisor 注册测试插件。
|
||||
|
||||
Args:
|
||||
supervisor: 目标监督器。
|
||||
plugin_id: 测试插件 ID。
|
||||
components: 组件声明列表。
|
||||
|
||||
Returns:
|
||||
Envelope: 注册响应信封。
|
||||
"""
|
||||
|
||||
payload = RegisterPluginPayload(
|
||||
plugin_id=plugin_id,
|
||||
plugin_version="1.0.0",
|
||||
components=[
|
||||
ComponentDeclaration(
|
||||
name=str(component.get("name", "") or ""),
|
||||
component_type=str(component.get("component_type", "") or ""),
|
||||
plugin_id=plugin_id,
|
||||
metadata=component.get("metadata", {}) if isinstance(component.get("metadata"), dict) else {},
|
||||
)
|
||||
for component in components
|
||||
],
|
||||
)
|
||||
return await supervisor._handle_register_plugin(
|
||||
Envelope(
|
||||
request_id=1,
|
||||
message_type=MessageType.REQUEST,
|
||||
method="plugin.register_components",
|
||||
plugin_id=plugin_id,
|
||||
payload=payload.model_dump(),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
async def _unregister_plugin(supervisor: PluginSupervisor, plugin_id: str) -> Envelope:
|
||||
"""通过 Supervisor 注销测试插件。
|
||||
|
||||
Args:
|
||||
supervisor: 目标监督器。
|
||||
plugin_id: 测试插件 ID。
|
||||
|
||||
Returns:
|
||||
Envelope: 注销响应信封。
|
||||
"""
|
||||
|
||||
payload = UnregisterPluginPayload(plugin_id=plugin_id, reason="test")
|
||||
return await supervisor._handle_unregister_plugin(
|
||||
Envelope(
|
||||
request_id=2,
|
||||
message_type=MessageType.REQUEST,
|
||||
method="plugin.unregister",
|
||||
plugin_id=plugin_id,
|
||||
payload=payload.model_dump(),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_plugin_syncs_dedicated_api_registry() -> None:
|
||||
"""插件注册时应将 API 同步到独立注册表,而不是通用组件表。"""
|
||||
|
||||
supervisor = PluginSupervisor(plugin_dirs=[])
|
||||
response = await _register_plugin(
|
||||
supervisor,
|
||||
"provider",
|
||||
[
|
||||
{
|
||||
"name": "render_html",
|
||||
"component_type": "API",
|
||||
"metadata": {
|
||||
"description": "渲染 HTML",
|
||||
"version": "1",
|
||||
"public": True,
|
||||
},
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
assert response.payload["accepted"] is True
|
||||
assert response.payload["registered_components"] == 0
|
||||
assert response.payload["registered_apis"] == 1
|
||||
assert supervisor.api_registry.get_api("provider", "render_html") is not None
|
||||
assert supervisor.component_registry.get_component("provider.render_html") is None
|
||||
|
||||
unregister_response = await _unregister_plugin(supervisor, "provider")
|
||||
assert unregister_response.payload["removed_apis"] == 1
|
||||
assert supervisor.api_registry.get_api("provider", "render_html") is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_call_allows_public_api_between_plugins(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""公开 API 应允许其他插件通过 Host 转发调用。"""
|
||||
|
||||
provider_supervisor = PluginSupervisor(plugin_dirs=[])
|
||||
consumer_supervisor = PluginSupervisor(plugin_dirs=[])
|
||||
await _register_plugin(
|
||||
provider_supervisor,
|
||||
"provider",
|
||||
[
|
||||
{
|
||||
"name": "render_html",
|
||||
"component_type": "API",
|
||||
"metadata": {
|
||||
"description": "渲染 HTML",
|
||||
"version": "1",
|
||||
"public": True,
|
||||
},
|
||||
}
|
||||
],
|
||||
)
|
||||
await _register_plugin(consumer_supervisor, "consumer", [])
|
||||
|
||||
captured: Dict[str, Any] = {}
|
||||
|
||||
async def fake_invoke_api(
|
||||
plugin_id: str,
|
||||
component_name: str,
|
||||
args: Dict[str, Any] | None = None,
|
||||
timeout_ms: int = 30000,
|
||||
) -> Any:
|
||||
"""模拟 API RPC 调用。"""
|
||||
|
||||
captured["plugin_id"] = plugin_id
|
||||
captured["component_name"] = component_name
|
||||
captured["args"] = args or {}
|
||||
captured["timeout_ms"] = timeout_ms
|
||||
return SimpleNamespace(error=None, payload={"success": True, "result": {"image": "ok"}})
|
||||
|
||||
monkeypatch.setattr(provider_supervisor, "invoke_api", fake_invoke_api)
|
||||
|
||||
manager = _build_manager(provider_supervisor, consumer_supervisor)
|
||||
result = await manager._cap_api_call(
|
||||
"consumer",
|
||||
"api.call",
|
||||
{
|
||||
"api_name": "provider.render_html",
|
||||
"version": "1",
|
||||
"args": {"html": "<div>Hello</div>"},
|
||||
},
|
||||
)
|
||||
|
||||
assert result == {"success": True, "result": {"image": "ok"}}
|
||||
assert captured["plugin_id"] == "provider"
|
||||
assert captured["component_name"] == "render_html"
|
||||
assert captured["args"] == {"html": "<div>Hello</div>"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_call_rejects_private_api_between_plugins() -> None:
|
||||
"""未公开的 API 默认不允许跨插件调用。"""
|
||||
|
||||
provider_supervisor = PluginSupervisor(plugin_dirs=[])
|
||||
consumer_supervisor = PluginSupervisor(plugin_dirs=[])
|
||||
await _register_plugin(
|
||||
provider_supervisor,
|
||||
"provider",
|
||||
[
|
||||
{
|
||||
"name": "secret_api",
|
||||
"component_type": "API",
|
||||
"metadata": {
|
||||
"description": "私有 API",
|
||||
"version": "1",
|
||||
"public": False,
|
||||
},
|
||||
}
|
||||
],
|
||||
)
|
||||
await _register_plugin(consumer_supervisor, "consumer", [])
|
||||
|
||||
manager = _build_manager(provider_supervisor, consumer_supervisor)
|
||||
result = await manager._cap_api_call(
|
||||
"consumer",
|
||||
"api.call",
|
||||
{
|
||||
"api_name": "provider.secret_api",
|
||||
"args": {},
|
||||
},
|
||||
)
|
||||
|
||||
assert result["success"] is False
|
||||
assert "未公开" in str(result["error"])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_list_and_component_toggle_use_dedicated_registry() -> None:
|
||||
"""API 列表与组件启停应直接作用于独立 API 注册表。"""
|
||||
|
||||
provider_supervisor = PluginSupervisor(plugin_dirs=[])
|
||||
consumer_supervisor = PluginSupervisor(plugin_dirs=[])
|
||||
await _register_plugin(
|
||||
provider_supervisor,
|
||||
"provider",
|
||||
[
|
||||
{
|
||||
"name": "public_api",
|
||||
"component_type": "API",
|
||||
"metadata": {"version": "1", "public": True},
|
||||
},
|
||||
{
|
||||
"name": "private_api",
|
||||
"component_type": "API",
|
||||
"metadata": {"version": "1", "public": False},
|
||||
},
|
||||
],
|
||||
)
|
||||
await _register_plugin(
|
||||
consumer_supervisor,
|
||||
"consumer",
|
||||
[
|
||||
{
|
||||
"name": "self_private_api",
|
||||
"component_type": "API",
|
||||
"metadata": {"version": "1", "public": False},
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
manager = _build_manager(provider_supervisor, consumer_supervisor)
|
||||
list_result = await manager._cap_api_list("consumer", "api.list", {})
|
||||
|
||||
assert list_result["success"] is True
|
||||
api_names = {(item["plugin_id"], item["name"]) for item in list_result["apis"]}
|
||||
assert ("provider", "public_api") in api_names
|
||||
assert ("provider", "private_api") not in api_names
|
||||
assert ("consumer", "self_private_api") in api_names
|
||||
|
||||
disable_result = await manager._cap_component_disable(
|
||||
"consumer",
|
||||
"component.disable",
|
||||
{
|
||||
"name": "provider.public_api",
|
||||
"component_type": "API",
|
||||
"scope": "global",
|
||||
"stream_id": "",
|
||||
},
|
||||
)
|
||||
assert disable_result["success"] is True
|
||||
assert provider_supervisor.api_registry.get_api("provider", "public_api", enabled_only=True) is None
|
||||
|
||||
enable_result = await manager._cap_component_enable(
|
||||
"consumer",
|
||||
"component.enable",
|
||||
{
|
||||
"name": "provider.public_api",
|
||||
"component_type": "API",
|
||||
"scope": "global",
|
||||
"stream_id": "",
|
||||
},
|
||||
)
|
||||
assert enable_result["success"] is True
|
||||
assert provider_supervisor.api_registry.get_api("provider", "public_api", enabled_only=True) is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_registry_supports_multiple_versions_with_distinct_handlers(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""同名 API 不同版本应可并存,并按版本路由到不同处理器。"""
|
||||
|
||||
provider_supervisor = PluginSupervisor(plugin_dirs=[])
|
||||
consumer_supervisor = PluginSupervisor(plugin_dirs=[])
|
||||
await _register_plugin(
|
||||
provider_supervisor,
|
||||
"provider",
|
||||
[
|
||||
{
|
||||
"name": "render_html",
|
||||
"component_type": "API",
|
||||
"metadata": {
|
||||
"description": "渲染 HTML v1",
|
||||
"version": "1",
|
||||
"public": True,
|
||||
"handler_name": "handle_render_html_v1",
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "render_html",
|
||||
"component_type": "API",
|
||||
"metadata": {
|
||||
"description": "渲染 HTML v2",
|
||||
"version": "2",
|
||||
"public": True,
|
||||
"handler_name": "handle_render_html_v2",
|
||||
},
|
||||
},
|
||||
],
|
||||
)
|
||||
await _register_plugin(consumer_supervisor, "consumer", [])
|
||||
|
||||
captured: Dict[str, Any] = {}
|
||||
|
||||
async def fake_invoke_api(
|
||||
plugin_id: str,
|
||||
component_name: str,
|
||||
args: Dict[str, Any] | None = None,
|
||||
timeout_ms: int = 30000,
|
||||
) -> Any:
|
||||
"""模拟多版本 API 调用。"""
|
||||
|
||||
captured["plugin_id"] = plugin_id
|
||||
captured["component_name"] = component_name
|
||||
captured["args"] = args or {}
|
||||
captured["timeout_ms"] = timeout_ms
|
||||
return SimpleNamespace(error=None, payload={"success": True, "result": {"image": "ok"}})
|
||||
|
||||
monkeypatch.setattr(provider_supervisor, "invoke_api", fake_invoke_api)
|
||||
manager = _build_manager(provider_supervisor, consumer_supervisor)
|
||||
|
||||
ambiguous_result = await manager._cap_api_call(
|
||||
"consumer",
|
||||
"api.call",
|
||||
{
|
||||
"api_name": "provider.render_html",
|
||||
"args": {"html": "<div>Hello</div>"},
|
||||
},
|
||||
)
|
||||
assert ambiguous_result["success"] is False
|
||||
assert "多个版本" in str(ambiguous_result["error"])
|
||||
|
||||
disable_ambiguous_result = await manager._cap_component_disable(
|
||||
"consumer",
|
||||
"component.disable",
|
||||
{
|
||||
"name": "provider.render_html",
|
||||
"component_type": "API",
|
||||
"scope": "global",
|
||||
"stream_id": "",
|
||||
},
|
||||
)
|
||||
assert disable_ambiguous_result["success"] is False
|
||||
assert "多个版本" in str(disable_ambiguous_result["error"])
|
||||
|
||||
disable_v1_result = await manager._cap_component_disable(
|
||||
"consumer",
|
||||
"component.disable",
|
||||
{
|
||||
"name": "provider.render_html",
|
||||
"component_type": "API",
|
||||
"scope": "global",
|
||||
"stream_id": "",
|
||||
"version": "1",
|
||||
},
|
||||
)
|
||||
assert disable_v1_result["success"] is True
|
||||
assert provider_supervisor.api_registry.get_api("provider", "render_html", version="1", enabled_only=True) is None
|
||||
assert provider_supervisor.api_registry.get_api("provider", "render_html", version="2", enabled_only=True) is not None
|
||||
|
||||
result = await manager._cap_api_call(
|
||||
"consumer",
|
||||
"api.call",
|
||||
{
|
||||
"api_name": "provider.render_html",
|
||||
"version": "2",
|
||||
"args": {"html": "<div>Hello</div>"},
|
||||
},
|
||||
)
|
||||
|
||||
assert result == {"success": True, "result": {"image": "ok"}}
|
||||
assert captured["plugin_id"] == "provider"
|
||||
assert captured["component_name"] == "handle_render_html_v2"
|
||||
assert captured["args"] == {"html": "<div>Hello</div>"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_replace_dynamic_can_offline_removed_entries(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""动态 API 替换后,被移除的 API 应返回明确下线错误。"""
|
||||
|
||||
supervisor = PluginSupervisor(plugin_dirs=[])
|
||||
await _register_plugin(supervisor, "provider", [])
|
||||
manager = _build_manager(supervisor)
|
||||
|
||||
captured: Dict[str, Any] = {}
|
||||
|
||||
async def fake_invoke_api(
|
||||
plugin_id: str,
|
||||
component_name: str,
|
||||
args: Dict[str, Any] | None = None,
|
||||
timeout_ms: int = 30000,
|
||||
) -> Any:
|
||||
"""模拟动态 API 调用。"""
|
||||
|
||||
captured["plugin_id"] = plugin_id
|
||||
captured["component_name"] = component_name
|
||||
captured["args"] = args or {}
|
||||
captured["timeout_ms"] = timeout_ms
|
||||
return SimpleNamespace(error=None, payload={"success": True, "result": {"ok": True}})
|
||||
|
||||
monkeypatch.setattr(supervisor, "invoke_api", fake_invoke_api)
|
||||
|
||||
replace_result = await manager._cap_api_replace_dynamic(
|
||||
"provider",
|
||||
"api.replace_dynamic",
|
||||
{
|
||||
"apis": [
|
||||
{
|
||||
"name": "mcp.search",
|
||||
"type": "API",
|
||||
"metadata": {
|
||||
"version": "1",
|
||||
"public": True,
|
||||
"handler_name": "dynamic_search",
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "mcp.read",
|
||||
"type": "API",
|
||||
"metadata": {
|
||||
"version": "1",
|
||||
"public": True,
|
||||
"handler_name": "dynamic_read",
|
||||
},
|
||||
},
|
||||
],
|
||||
"offline_reason": "MCP 服务器已关闭",
|
||||
},
|
||||
)
|
||||
|
||||
assert replace_result["success"] is True
|
||||
assert replace_result["count"] == 2
|
||||
list_result = await manager._cap_api_list("provider", "api.list", {"plugin_id": "provider"})
|
||||
assert {(item["name"], item["version"]) for item in list_result["apis"]} == {
|
||||
("mcp.read", "1"),
|
||||
("mcp.search", "1"),
|
||||
}
|
||||
|
||||
call_result = await manager._cap_api_call(
|
||||
"provider",
|
||||
"api.call",
|
||||
{
|
||||
"api_name": "provider.mcp.search",
|
||||
"version": "1",
|
||||
"args": {"query": "hello"},
|
||||
},
|
||||
)
|
||||
assert call_result == {"success": True, "result": {"ok": True}}
|
||||
assert captured["component_name"] == "dynamic_search"
|
||||
assert captured["args"]["query"] == "hello"
|
||||
assert captured["args"]["__maibot_api_name__"] == "mcp.search"
|
||||
assert captured["args"]["__maibot_api_version__"] == "1"
|
||||
|
||||
second_replace_result = await manager._cap_api_replace_dynamic(
|
||||
"provider",
|
||||
"api.replace_dynamic",
|
||||
{
|
||||
"apis": [
|
||||
{
|
||||
"name": "mcp.read",
|
||||
"type": "API",
|
||||
"metadata": {
|
||||
"version": "1",
|
||||
"public": True,
|
||||
"handler_name": "dynamic_read",
|
||||
},
|
||||
}
|
||||
],
|
||||
"offline_reason": "MCP 服务器已关闭",
|
||||
},
|
||||
)
|
||||
|
||||
assert second_replace_result["success"] is True
|
||||
assert second_replace_result["count"] == 1
|
||||
assert second_replace_result["offlined"] == 1
|
||||
|
||||
offlined_call_result = await manager._cap_api_call(
|
||||
"provider",
|
||||
"api.call",
|
||||
{
|
||||
"api_name": "provider.mcp.search",
|
||||
"version": "1",
|
||||
"args": {},
|
||||
},
|
||||
)
|
||||
assert offlined_call_result["success"] is False
|
||||
assert "MCP 服务器已关闭" in str(offlined_call_result["error"])
|
||||
|
||||
list_after_replace = await manager._cap_api_list("provider", "api.list", {"plugin_id": "provider"})
|
||||
assert {(item["name"], item["version"]) for item in list_after_replace["apis"]} == {
|
||||
("mcp.read", "1"),
|
||||
}
|
||||
154
pytests/test_send_service.py
Normal file
154
pytests/test_send_service.py
Normal file
@@ -0,0 +1,154 @@
|
||||
"""发送服务回归测试。"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import pytest
|
||||
|
||||
from src.chat.message_receive.chat_manager import BotChatSession
|
||||
from src.services import send_service
|
||||
|
||||
|
||||
class _FakePlatformIOManager:
|
||||
"""用于测试的 Platform IO 管理器假对象。"""
|
||||
|
||||
def __init__(self, delivery_batch: Any) -> None:
|
||||
"""初始化假 Platform IO 管理器。
|
||||
|
||||
Args:
|
||||
delivery_batch: 发送时返回的批量回执。
|
||||
"""
|
||||
self._delivery_batch = delivery_batch
|
||||
self.ensure_calls = 0
|
||||
self.sent_messages: List[Dict[str, Any]] = []
|
||||
|
||||
async def ensure_send_pipeline_ready(self) -> None:
|
||||
"""记录发送管线准备调用次数。"""
|
||||
self.ensure_calls += 1
|
||||
|
||||
def build_route_key_from_message(self, message: Any) -> Any:
|
||||
"""根据消息构造假的路由键。
|
||||
|
||||
Args:
|
||||
message: 待发送的内部消息对象。
|
||||
|
||||
Returns:
|
||||
Any: 简化后的路由键对象。
|
||||
"""
|
||||
del message
|
||||
return SimpleNamespace(platform="qq")
|
||||
|
||||
async def send_message(self, message: Any, route_key: Any, metadata: Dict[str, Any]) -> Any:
|
||||
"""记录发送请求并返回预设回执。
|
||||
|
||||
Args:
|
||||
message: 待发送的内部消息对象。
|
||||
route_key: 本次发送使用的路由键。
|
||||
metadata: 发送元数据。
|
||||
|
||||
Returns:
|
||||
Any: 预设的批量发送回执。
|
||||
"""
|
||||
self.sent_messages.append(
|
||||
{
|
||||
"message": message,
|
||||
"route_key": route_key,
|
||||
"metadata": metadata,
|
||||
}
|
||||
)
|
||||
return self._delivery_batch
|
||||
|
||||
|
||||
def _build_target_stream() -> BotChatSession:
|
||||
"""构造一个最小可用的目标会话对象。
|
||||
|
||||
Returns:
|
||||
BotChatSession: 测试用会话对象。
|
||||
"""
|
||||
return BotChatSession(
|
||||
session_id="test-session",
|
||||
platform="qq",
|
||||
user_id="target-user",
|
||||
group_id=None,
|
||||
)
|
||||
|
||||
|
||||
def test_inherit_platform_io_route_metadata_falls_back_to_bot_account(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""没有上下文消息时,也应回填当前平台账号用于账号级路由命中。"""
|
||||
|
||||
monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq" if platform == "qq" else "")
|
||||
|
||||
metadata = send_service._inherit_platform_io_route_metadata(_build_target_stream())
|
||||
|
||||
assert metadata["platform_io_account_id"] == "bot-qq"
|
||||
assert metadata["platform_io_target_user_id"] == "target-user"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_to_stream_delegates_to_platform_io(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""send service 应将发送职责统一交给 Platform IO。"""
|
||||
fake_manager = _FakePlatformIOManager(
|
||||
delivery_batch=SimpleNamespace(
|
||||
has_success=True,
|
||||
sent_receipts=[SimpleNamespace(driver_id="plugin.qq.sender")],
|
||||
failed_receipts=[],
|
||||
route_key=SimpleNamespace(platform="qq"),
|
||||
)
|
||||
)
|
||||
stored_messages: List[Any] = []
|
||||
|
||||
monkeypatch.setattr(send_service, "get_platform_io_manager", lambda: fake_manager)
|
||||
monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq")
|
||||
monkeypatch.setattr(
|
||||
send_service._chat_manager,
|
||||
"get_session_by_session_id",
|
||||
lambda stream_id: _build_target_stream() if stream_id == "test-session" else None,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
send_service.MessageUtils,
|
||||
"store_message_to_db",
|
||||
lambda message: stored_messages.append(message),
|
||||
)
|
||||
|
||||
result = await send_service.text_to_stream(text="你好", stream_id="test-session")
|
||||
|
||||
assert result is True
|
||||
assert fake_manager.ensure_calls == 1
|
||||
assert len(fake_manager.sent_messages) == 1
|
||||
assert fake_manager.sent_messages[0]["metadata"] == {"show_log": False}
|
||||
assert len(stored_messages) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_to_stream_returns_false_when_platform_io_fails(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Platform IO 批量发送全部失败时,应直接向上返回失败。"""
|
||||
fake_manager = _FakePlatformIOManager(
|
||||
delivery_batch=SimpleNamespace(
|
||||
has_success=False,
|
||||
sent_receipts=[],
|
||||
failed_receipts=[
|
||||
SimpleNamespace(
|
||||
driver_id="plugin.qq.sender",
|
||||
status="failed",
|
||||
error="network error",
|
||||
)
|
||||
],
|
||||
route_key=SimpleNamespace(platform="qq"),
|
||||
)
|
||||
)
|
||||
|
||||
monkeypatch.setattr(send_service, "get_platform_io_manager", lambda: fake_manager)
|
||||
monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq")
|
||||
monkeypatch.setattr(
|
||||
send_service._chat_manager,
|
||||
"get_session_by_session_id",
|
||||
lambda stream_id: _build_target_stream() if stream_id == "test-session" else None,
|
||||
)
|
||||
|
||||
result = await send_service.text_to_stream(text="发送失败", stream_id="test-session")
|
||||
|
||||
assert result is False
|
||||
assert fake_manager.ensure_calls == 1
|
||||
assert len(fake_manager.sent_messages) == 1
|
||||
115
pytests/utils_test/statistic_test.py
Normal file
115
pytests/utils_test/statistic_test.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""统计模块数据库会话行为测试。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime, timedelta
|
||||
from types import ModuleType
|
||||
from typing import Any, Callable, Iterator
|
||||
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
from src.chat.utils import statistic
|
||||
|
||||
|
||||
class _DummyResult:
|
||||
"""模拟 SQLModel 查询结果对象。"""
|
||||
|
||||
def all(self) -> list[Any]:
|
||||
"""返回空结果集。
|
||||
|
||||
Returns:
|
||||
list[Any]: 空列表。
|
||||
"""
|
||||
return []
|
||||
|
||||
|
||||
class _DummySession:
|
||||
"""模拟数据库 Session。"""
|
||||
|
||||
def exec(self, statement: Any) -> _DummyResult:
|
||||
"""执行查询语句并返回空结果。
|
||||
|
||||
Args:
|
||||
statement: 待执行的查询语句。
|
||||
|
||||
Returns:
|
||||
_DummyResult: 空结果对象。
|
||||
"""
|
||||
del statement
|
||||
return _DummyResult()
|
||||
|
||||
|
||||
def _build_fake_get_db_session(calls: list[bool]) -> Callable[[bool], Iterator[_DummySession]]:
|
||||
"""构造一个记录 auto_commit 参数的假会话工厂。
|
||||
|
||||
Args:
|
||||
calls: 用于记录每次调用 auto_commit 参数的列表。
|
||||
|
||||
Returns:
|
||||
Callable[[bool], Iterator[_DummySession]]: 可替换 `get_db_session` 的上下文管理器工厂。
|
||||
"""
|
||||
|
||||
@contextmanager
|
||||
def _fake_get_db_session(auto_commit: bool = True) -> Iterator[_DummySession]:
|
||||
"""记录会话参数并返回假 Session。
|
||||
|
||||
Args:
|
||||
auto_commit: 是否启用自动提交。
|
||||
|
||||
Yields:
|
||||
Iterator[_DummySession]: 假 Session 对象。
|
||||
"""
|
||||
calls.append(auto_commit)
|
||||
yield _DummySession()
|
||||
|
||||
return _fake_get_db_session
|
||||
|
||||
|
||||
def _build_statistic_task() -> statistic.StatisticOutputTask:
|
||||
"""构造一个最小可用的统计任务实例。
|
||||
|
||||
Returns:
|
||||
statistic.StatisticOutputTask: 跳过 `__init__` 的测试实例。
|
||||
"""
|
||||
task = statistic.StatisticOutputTask.__new__(statistic.StatisticOutputTask)
|
||||
task.name_mapping = {}
|
||||
return task
|
||||
|
||||
|
||||
def _is_bot_self(platform: str, user_id: str) -> bool:
|
||||
"""返回固定的非机器人身份判断结果。
|
||||
|
||||
Args:
|
||||
platform: 平台名称。
|
||||
user_id: 用户 ID。
|
||||
|
||||
Returns:
|
||||
bool: 始终返回 ``False``。
|
||||
"""
|
||||
del platform
|
||||
del user_id
|
||||
return False
|
||||
|
||||
|
||||
def test_statistic_read_queries_disable_auto_commit(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""统计模块的纯读查询应关闭自动提交,避免 Session 退出后对象被 expire。"""
|
||||
calls: list[bool] = []
|
||||
now = datetime.now()
|
||||
task = _build_statistic_task()
|
||||
|
||||
monkeypatch.setattr(statistic, "get_db_session", _build_fake_get_db_session(calls))
|
||||
|
||||
utils_module = ModuleType("src.chat.utils.utils")
|
||||
utils_module.is_bot_self = _is_bot_self
|
||||
monkeypatch.setitem(sys.modules, "src.chat.utils.utils", utils_module)
|
||||
|
||||
statistic.StatisticOutputTask._fetch_online_time_since(now)
|
||||
statistic.StatisticOutputTask._fetch_model_usage_since(now)
|
||||
task._collect_message_count_for_period([("last_hour", now - timedelta(hours=1))])
|
||||
task._collect_interval_data(now, hours=1, interval_minutes=60)
|
||||
task._collect_metrics_interval_data(now, hours=1, interval_hours=1)
|
||||
|
||||
assert calls == [False] * 9
|
||||
42
pytests/utils_test/test_session_utils.py
Normal file
42
pytests/utils_test/test_session_utils.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
from src.chat.message_receive.chat_manager import ChatManager
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
|
||||
|
||||
def test_calculate_session_id_distinguishes_account_and_scope() -> None:
|
||||
base_session_id = SessionUtils.calculate_session_id("qq", user_id="42")
|
||||
same_base_session_id = SessionUtils.calculate_session_id("qq", user_id="42")
|
||||
account_scoped_session_id = SessionUtils.calculate_session_id("qq", user_id="42", account_id="123")
|
||||
route_scoped_session_id = SessionUtils.calculate_session_id("qq", user_id="42", account_id="123", scope="main")
|
||||
|
||||
assert base_session_id == same_base_session_id
|
||||
assert account_scoped_session_id != base_session_id
|
||||
assert route_scoped_session_id != account_scoped_session_id
|
||||
|
||||
|
||||
def test_chat_manager_register_message_uses_route_metadata() -> None:
|
||||
chat_manager = ChatManager()
|
||||
message = SimpleNamespace(
|
||||
platform="qq",
|
||||
session_id="",
|
||||
message_info=SimpleNamespace(
|
||||
user_info=SimpleNamespace(user_id="42"),
|
||||
group_info=SimpleNamespace(group_id="1000"),
|
||||
additional_config={
|
||||
"platform_io_account_id": "123",
|
||||
"platform_io_scope": "main",
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
chat_manager.register_message(message)
|
||||
|
||||
assert message.session_id == SessionUtils.calculate_session_id(
|
||||
"qq",
|
||||
user_id="42",
|
||||
group_id="1000",
|
||||
account_id="123",
|
||||
scope="main",
|
||||
)
|
||||
assert chat_manager.last_messages[message.session_id] is message
|
||||
@@ -1,27 +1,28 @@
|
||||
import time
|
||||
"""PFC 侧消息发送封装。"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from maim_message import Seg
|
||||
from rich.traceback import install
|
||||
|
||||
from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo
|
||||
from src.chat.message_receive.chat_manager import BotChatSession
|
||||
from src.chat.message_receive.message import MessageSending
|
||||
from src.chat.message_receive.uni_message_sender import UniversalMessageSender
|
||||
from src.chat.utils.utils import get_bot_account
|
||||
from src.common.data_models.mai_message_data_model import MaiMessage
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.services import send_service as send_api
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
|
||||
logger = get_logger("message_sender")
|
||||
|
||||
|
||||
class DirectMessageSender:
|
||||
"""直接消息发送器"""
|
||||
"""直接消息发送器。"""
|
||||
|
||||
def __init__(self, private_name: str):
|
||||
def __init__(self, private_name: str) -> None:
|
||||
"""初始化直接消息发送器。
|
||||
|
||||
Args:
|
||||
private_name: 当前私聊实例的名称。
|
||||
"""
|
||||
self.private_name = private_name
|
||||
|
||||
async def send_message(
|
||||
@@ -30,58 +31,31 @@ class DirectMessageSender:
|
||||
content: str,
|
||||
reply_to_message: Optional[MaiMessage] = None,
|
||||
) -> None:
|
||||
"""发送消息到聊天流
|
||||
"""发送文本消息到聊天流。
|
||||
|
||||
Args:
|
||||
chat_stream: 聊天会话
|
||||
content: 消息内容
|
||||
reply_to_message: 要回复的消息(可选)
|
||||
chat_stream: 目标聊天会话。
|
||||
content: 待发送的文本内容。
|
||||
reply_to_message: 可选的引用回复锚点消息。
|
||||
|
||||
Raises:
|
||||
RuntimeError: 当消息发送失败时抛出。
|
||||
"""
|
||||
try:
|
||||
# 创建消息内容
|
||||
segments = Seg(type="seglist", data=[Seg(type="text", data=content)])
|
||||
|
||||
# 获取麦麦的信息
|
||||
bot_user_id = get_bot_account(chat_stream.platform)
|
||||
if not bot_user_id:
|
||||
logger.error(f"[私聊][{self.private_name}]平台 {chat_stream.platform} 未配置机器人账号,无法发送消息")
|
||||
raise RuntimeError(f"平台 {chat_stream.platform} 未配置机器人账号")
|
||||
bot_user_info = UserInfo(
|
||||
user_id=bot_user_id,
|
||||
user_nickname=global_config.bot.nickname,
|
||||
sent = await send_api.text_to_stream(
|
||||
text=content,
|
||||
stream_id=chat_stream.session_id,
|
||||
set_reply=reply_to_message is not None,
|
||||
reply_message=reply_to_message,
|
||||
storage_message=True,
|
||||
)
|
||||
|
||||
# 用当前时间作为message_id,和之前那套sender一样
|
||||
message_id = f"dm{round(time.time(), 2)}"
|
||||
|
||||
# 构建发送者信息(私聊时为接收者)
|
||||
sender_info = None
|
||||
if reply_to_message and reply_to_message.message_info and reply_to_message.message_info.user_info:
|
||||
sender_info = reply_to_message.message_info.user_info
|
||||
|
||||
# 构建消息对象
|
||||
message = MessageSending(
|
||||
message_id=message_id,
|
||||
session=chat_stream,
|
||||
bot_user_info=bot_user_info,
|
||||
sender_info=sender_info,
|
||||
message_segment=segments,
|
||||
reply=reply_to_message,
|
||||
is_head=True,
|
||||
is_emoji=False,
|
||||
thinking_start_time=time.time(),
|
||||
)
|
||||
|
||||
# 发送消息
|
||||
message_sender = UniversalMessageSender()
|
||||
sent = await message_sender.send_message(message, typing=False, set_reply=False, storage_message=True)
|
||||
|
||||
if sent:
|
||||
logger.info(f"[私聊][{self.private_name}]PFC消息已发送: {content}")
|
||||
else:
|
||||
logger.error(f"[私聊][{self.private_name}]PFC消息发送失败")
|
||||
raise RuntimeError("消息发送失败")
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]PFC消息发送失败: {str(e)}")
|
||||
logger.error(f"[私聊][{self.private_name}]PFC消息发送失败")
|
||||
raise RuntimeError("消息发送失败")
|
||||
except Exception as exc:
|
||||
logger.error(f"[私聊][{self.private_name}]PFC消息发送失败: {exc}")
|
||||
raise
|
||||
|
||||
@@ -8,8 +8,8 @@ from rich.traceback import install
|
||||
from src.config.config import global_config
|
||||
from src.common.logger import get_logger
|
||||
from src.common.utils.utils_config import ExpressionConfigUtils
|
||||
from src.bw_learner.expression_learner import ExpressionLearner
|
||||
from src.bw_learner.jargon_miner import JargonMiner
|
||||
from src.learners.expression_learner import ExpressionLearner
|
||||
from src.learners.jargon_miner import JargonMiner
|
||||
from src.chat.message_receive.chat_manager import BotChatSession
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
|
||||
@@ -1,30 +1,32 @@
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
|
||||
import json
|
||||
import time
|
||||
import traceback
|
||||
import random
|
||||
import re
|
||||
from typing import Dict, Optional, Tuple, List, TYPE_CHECKING
|
||||
from rich.traceback import install
|
||||
from datetime import datetime
|
||||
from json_repair import repair_json
|
||||
import time
|
||||
import traceback
|
||||
|
||||
from json_repair import repair_json
|
||||
from rich.traceback import install
|
||||
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.logger.plan_reply_logger import PlanReplyLogger
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from src.chat.planner_actions.action_manager import ActionManager
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info
|
||||
from src.common.data_models.info_data_model import ActionPlannerInfo
|
||||
from src.common.logger import get_logger
|
||||
from src.common.utils.utils_action import ActionUtils
|
||||
from src.config.config import global_config, model_config
|
||||
from src.core.types import ActionActivationType, ActionInfo, ComponentType
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.plugin_runtime.component_query import component_query_service
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.services.message_service import (
|
||||
build_readable_messages_with_id,
|
||||
get_actions_by_timestamp_with_chat,
|
||||
get_messages_before_time_in_chat,
|
||||
)
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info
|
||||
from src.chat.planner_actions.action_manager import ActionManager
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from src.core.types import ActionActivationType, ActionInfo, ComponentType
|
||||
from src.core.component_registry import component_registry
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.common.data_models.info_data_model import TargetPersonInfo
|
||||
@@ -320,7 +322,7 @@ class BrainPlanner:
|
||||
current_available_actions_dict = self.action_manager.get_using_actions()
|
||||
|
||||
# 获取完整的动作信息
|
||||
all_registered_actions: Dict[str, ActionInfo] = component_registry.get_components_by_type( # type: ignore
|
||||
all_registered_actions: Dict[str, ActionInfo] = component_query_service.get_components_by_type( # type: ignore
|
||||
ComponentType.ACTION
|
||||
)
|
||||
current_available_actions = {}
|
||||
|
||||
@@ -1,734 +0,0 @@
|
||||
import asyncio
|
||||
import time
|
||||
import traceback
|
||||
import random
|
||||
from typing import List, Optional, Dict, Any, Tuple, TYPE_CHECKING
|
||||
from rich.traceback import install
|
||||
|
||||
from src.config.config import global_config
|
||||
from src.common.logger import get_logger
|
||||
from src.common.data_models.info_data_model import ActionPlannerInfo
|
||||
from src.common.data_models.message_data_model import ReplyContentType
|
||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||
from src.chat.utils.prompt_builder import global_prompt_manager
|
||||
from src.chat.utils.timer_calculator import Timer
|
||||
from src.chat.planner_actions.planner import ActionPlanner
|
||||
from src.chat.planner_actions.action_modifier import ActionModifier
|
||||
from src.chat.planner_actions.action_manager import ActionManager
|
||||
from src.chat.heart_flow.hfc_utils import CycleDetail
|
||||
from src.bw_learner.expression_learner import expression_learner_manager
|
||||
from src.chat.heart_flow.frequency_control import frequency_control_manager
|
||||
from src.bw_learner.message_recorder import extract_and_distribute_messages
|
||||
from src.person_info.person_info import Person
|
||||
from src.plugin_system.base.component_types import EventType, ActionInfo
|
||||
from src.plugin_system.core import events_manager
|
||||
from src.plugin_system.apis import generator_api, send_api, message_api, database_api
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
build_readable_messages_with_id,
|
||||
get_raw_msg_before_timestamp_with_chat,
|
||||
)
|
||||
from src.chat.utils.utils import record_replyer_action_temp
|
||||
from src.memory_system.chat_history_summarizer import ChatHistorySummarizer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.data_models.message_data_model import ReplySetModel
|
||||
|
||||
|
||||
ERROR_LOOP_INFO = {
|
||||
"loop_plan_info": {
|
||||
"action_result": {
|
||||
"action_type": "error",
|
||||
"action_data": {},
|
||||
"reasoning": "循环处理失败",
|
||||
},
|
||||
},
|
||||
"loop_action_info": {
|
||||
"action_taken": False,
|
||||
"reply_text": "",
|
||||
"command": "",
|
||||
"taken_time": time.time(),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
# 注释:原来的动作修改超时常量已移除,因为改为顺序执行
|
||||
|
||||
logger = get_logger("hfc") # Logger Name Changed
|
||||
|
||||
|
||||
class HeartFChatting:
|
||||
"""
|
||||
管理一个连续的Focus Chat循环
|
||||
用于在特定聊天流中生成回复。
|
||||
其生命周期现在由其关联的 SubHeartflow 的 FOCUSED 状态控制。
|
||||
"""
|
||||
|
||||
def __init__(self, chat_id: str):
|
||||
"""
|
||||
HeartFChatting 初始化函数
|
||||
|
||||
参数:
|
||||
chat_id: 聊天流唯一标识符(如stream_id)
|
||||
on_stop_focus_chat: 当收到stop_focus_chat命令时调用的回调函数
|
||||
performance_version: 性能记录版本号,用于区分不同启动版本
|
||||
"""
|
||||
# 基础属性
|
||||
self.stream_id: str = chat_id # 聊天流ID
|
||||
self.chat_stream: ChatStream = get_chat_manager().get_stream(self.stream_id) # type: ignore
|
||||
if not self.chat_stream:
|
||||
raise ValueError(f"无法找到聊天流: {self.stream_id}")
|
||||
self.log_prefix = f"[{get_chat_manager().get_stream_name(self.stream_id) or self.stream_id}]"
|
||||
|
||||
self.expression_learner = expression_learner_manager.get_expression_learner(self.stream_id)
|
||||
|
||||
self.action_manager = ActionManager()
|
||||
self.action_planner = ActionPlanner(chat_id=self.stream_id, action_manager=self.action_manager)
|
||||
self.action_modifier = ActionModifier(action_manager=self.action_manager, chat_id=self.stream_id)
|
||||
|
||||
# 循环控制内部状态
|
||||
self.running: bool = False
|
||||
self._loop_task: Optional[asyncio.Task] = None # 主循环任务
|
||||
|
||||
# 添加循环信息管理相关的属性
|
||||
self.history_loop: List[CycleDetail] = []
|
||||
self._cycle_counter = 0
|
||||
self._current_cycle_detail: CycleDetail = None # type: ignore
|
||||
|
||||
self.last_read_time = time.time() - 2
|
||||
|
||||
self.is_mute = False
|
||||
|
||||
self.last_active_time = time.time() # 记录上一次非noreply时间
|
||||
|
||||
self.question_probability_multiplier = 1
|
||||
self.questioned = False
|
||||
|
||||
# 跟踪连续 no_reply 次数,用于动态调整阈值
|
||||
self.consecutive_no_reply_count = 0
|
||||
|
||||
# 聊天内容概括器
|
||||
self.chat_history_summarizer = ChatHistorySummarizer(chat_id=self.stream_id)
|
||||
|
||||
async def start(self):
|
||||
"""检查是否需要启动主循环,如果未激活则启动。"""
|
||||
|
||||
# 如果循环已经激活,直接返回
|
||||
if self.running:
|
||||
logger.debug(f"{self.log_prefix} HeartFChatting 已激活,无需重复启动")
|
||||
return
|
||||
|
||||
try:
|
||||
# 标记为活动状态,防止重复启动
|
||||
self.running = True
|
||||
|
||||
self._loop_task = asyncio.create_task(self._main_chat_loop())
|
||||
self._loop_task.add_done_callback(self._handle_loop_completion)
|
||||
|
||||
# 启动聊天内容概括器的后台定期检查循环
|
||||
await self.chat_history_summarizer.start()
|
||||
|
||||
logger.info(f"{self.log_prefix} HeartFChatting 启动完成")
|
||||
|
||||
except Exception as e:
|
||||
# 启动失败时重置状态
|
||||
self.running = False
|
||||
self._loop_task = None
|
||||
logger.error(f"{self.log_prefix} HeartFChatting 启动失败: {e}")
|
||||
raise
|
||||
|
||||
def _handle_loop_completion(self, task: asyncio.Task):
|
||||
"""当 _hfc_loop 任务完成时执行的回调。"""
|
||||
try:
|
||||
if exception := task.exception():
|
||||
logger.error(f"{self.log_prefix} HeartFChatting: 脱离了聊天(异常): {exception}")
|
||||
logger.error(traceback.format_exc()) # Log full traceback for exceptions
|
||||
else:
|
||||
logger.info(f"{self.log_prefix} HeartFChatting: 脱离了聊天 (外部停止)")
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"{self.log_prefix} HeartFChatting: 结束了聊天")
|
||||
|
||||
def start_cycle(self) -> Tuple[Dict[str, float], str]:
|
||||
self._cycle_counter += 1
|
||||
self._current_cycle_detail = CycleDetail(self._cycle_counter)
|
||||
self._current_cycle_detail.thinking_id = f"tid{str(round(time.time(), 2))}"
|
||||
cycle_timers = {}
|
||||
return cycle_timers, self._current_cycle_detail.thinking_id
|
||||
|
||||
def end_cycle(self, loop_info, cycle_timers):
|
||||
self._current_cycle_detail.set_loop_info(loop_info)
|
||||
self.history_loop.append(self._current_cycle_detail)
|
||||
self._current_cycle_detail.timers = cycle_timers
|
||||
self._current_cycle_detail.end_time = time.time()
|
||||
|
||||
def print_cycle_info(self, cycle_timers):
|
||||
# 记录循环信息和计时器结果
|
||||
timer_strings = []
|
||||
for name, elapsed in cycle_timers.items():
|
||||
if elapsed < 0.1:
|
||||
# 不显示小于0.1秒的计时器
|
||||
continue
|
||||
formatted_time = f"{elapsed:.2f}秒"
|
||||
timer_strings.append(f"{name}: {formatted_time}")
|
||||
|
||||
logger.info(
|
||||
f"{self.log_prefix} 第{self._current_cycle_detail.cycle_id}次思考,"
|
||||
f"耗时: {self._current_cycle_detail.end_time - self._current_cycle_detail.start_time:.1f}秒;" # type: ignore
|
||||
+ (f"详情: {'; '.join(timer_strings)}" if timer_strings else "")
|
||||
)
|
||||
|
||||
async def _loopbody(self):
|
||||
recent_messages_list = message_api.get_messages_by_time_in_chat(
|
||||
chat_id=self.stream_id,
|
||||
start_time=self.last_read_time,
|
||||
end_time=time.time(),
|
||||
limit=20,
|
||||
limit_mode="latest",
|
||||
filter_mai=True,
|
||||
filter_command=False,
|
||||
filter_intercept_message_level=0,
|
||||
)
|
||||
|
||||
# 根据连续 no_reply 次数动态调整阈值
|
||||
# 3次 no_reply 时,阈值调高到 1.5(50%概率为1,50%概率为2)
|
||||
# 5次 no_reply 时,提高到 2(大于等于两条消息的阈值)
|
||||
if self.consecutive_no_reply_count >= 5:
|
||||
threshold = 2
|
||||
elif self.consecutive_no_reply_count >= 3:
|
||||
# 1.5 的含义:50%概率为1,50%概率为2
|
||||
threshold = 2 if random.random() < 0.5 else 1
|
||||
else:
|
||||
threshold = 1
|
||||
|
||||
if len(recent_messages_list) >= threshold:
|
||||
# for message in recent_messages_list:
|
||||
# print(message.processed_plain_text)
|
||||
|
||||
self.last_read_time = time.time()
|
||||
|
||||
# !此处使at或者提及必定回复
|
||||
mentioned_message = None
|
||||
for message in recent_messages_list:
|
||||
if (message.is_mentioned or message.is_at) and global_config.chat.mentioned_bot_reply:
|
||||
mentioned_message = message
|
||||
|
||||
# logger.info(f"{self.log_prefix} 当前talk_value: {global_config.chat.get_talk_value(self.stream_id)}")
|
||||
|
||||
# *控制频率用
|
||||
if mentioned_message:
|
||||
await self._observe(recent_messages_list=recent_messages_list, force_reply_message=mentioned_message)
|
||||
elif (
|
||||
random.random()
|
||||
< global_config.chat.get_talk_value(self.stream_id)
|
||||
* frequency_control_manager.get_or_create_frequency_control(self.stream_id).get_talk_frequency_adjust()
|
||||
):
|
||||
await self._observe(recent_messages_list=recent_messages_list)
|
||||
else:
|
||||
# 没有提到,继续保持沉默,等待5秒防止频繁触发
|
||||
await asyncio.sleep(10)
|
||||
return True
|
||||
else:
|
||||
await asyncio.sleep(0.2)
|
||||
return True
|
||||
return True
|
||||
|
||||
async def _send_and_store_reply(
|
||||
self,
|
||||
response_set: "ReplySetModel",
|
||||
action_message: "DatabaseMessages",
|
||||
cycle_timers: Dict[str, float],
|
||||
thinking_id,
|
||||
actions,
|
||||
selected_expressions: Optional[List[int]] = None,
|
||||
quote_message: Optional[bool] = None,
|
||||
) -> Tuple[Dict[str, Any], str, Dict[str, float]]:
|
||||
with Timer("回复发送", cycle_timers):
|
||||
reply_text = await self._send_response(
|
||||
reply_set=response_set,
|
||||
message_data=action_message,
|
||||
selected_expressions=selected_expressions,
|
||||
quote_message=quote_message,
|
||||
)
|
||||
|
||||
# 获取 platform,如果不存在则从 chat_stream 获取,如果还是 None 则使用默认值
|
||||
platform = action_message.chat_info.platform
|
||||
if platform is None:
|
||||
platform = getattr(self.chat_stream, "platform", "unknown")
|
||||
|
||||
person = Person(platform=platform, user_id=action_message.user_info.user_id)
|
||||
person_name = person.person_name
|
||||
action_prompt_display = f"你对{person_name}进行了回复:{reply_text}"
|
||||
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.chat_stream,
|
||||
action_build_into_prompt=False,
|
||||
action_prompt_display=action_prompt_display,
|
||||
action_done=True,
|
||||
thinking_id=thinking_id,
|
||||
action_data={"reply_text": reply_text},
|
||||
action_name="reply",
|
||||
)
|
||||
|
||||
# 构建循环信息
|
||||
loop_info: Dict[str, Any] = {
|
||||
"loop_plan_info": {
|
||||
"action_result": actions,
|
||||
},
|
||||
"loop_action_info": {
|
||||
"action_taken": True,
|
||||
"reply_text": reply_text,
|
||||
"command": "",
|
||||
"taken_time": time.time(),
|
||||
},
|
||||
}
|
||||
|
||||
return loop_info, reply_text, cycle_timers
|
||||
|
||||
async def _observe(
|
||||
self, # interest_value: float = 0.0,
|
||||
recent_messages_list: Optional[List["DatabaseMessages"]] = None,
|
||||
force_reply_message: Optional["DatabaseMessages"] = None,
|
||||
) -> bool: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
|
||||
if recent_messages_list is None:
|
||||
recent_messages_list = []
|
||||
_reply_text = "" # 初始化reply_text变量,避免UnboundLocalError
|
||||
|
||||
start_time = time.time()
|
||||
async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
|
||||
# 通过 MessageRecorder 统一提取消息并分发给 expression_learner 和 jargon_miner
|
||||
# 在 replyer 执行时触发,统一管理时间窗口,避免重复获取消息
|
||||
asyncio.create_task(extract_and_distribute_messages(self.stream_id))
|
||||
|
||||
# 添加curious检测任务 - 检测聊天记录中的矛盾、冲突或需要提问的内容
|
||||
# asyncio.create_task(check_and_make_question(self.stream_id))
|
||||
# 添加聊天内容概括任务 - 累积、打包和压缩聊天记录
|
||||
# 注意:后台循环已在start()中启动,这里作为额外触发点,在有思考时立即处理
|
||||
# asyncio.create_task(self.chat_history_summarizer.process())
|
||||
|
||||
cycle_timers, thinking_id = self.start_cycle()
|
||||
logger.info(
|
||||
f"{self.log_prefix} 开始第{self._cycle_counter}次思考(频率: {global_config.chat.get_talk_value(self.stream_id)})"
|
||||
)
|
||||
|
||||
# 第一步:动作检查
|
||||
available_actions: Dict[str, ActionInfo] = {}
|
||||
try:
|
||||
await self.action_modifier.modify_actions()
|
||||
available_actions = self.action_manager.get_using_actions()
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 动作修改失败: {e}")
|
||||
|
||||
# 执行planner
|
||||
is_group_chat, chat_target_info, _ = self.action_planner.get_necessary_info()
|
||||
|
||||
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=self.stream_id,
|
||||
timestamp=time.time(),
|
||||
limit=int(global_config.chat.max_context_size * 0.6),
|
||||
filter_intercept_message_level=1,
|
||||
)
|
||||
chat_content_block, message_id_list = build_readable_messages_with_id(
|
||||
messages=message_list_before_now,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
read_mark=self.action_planner.last_obs_time_mark,
|
||||
truncate=True,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
prompt_info = await self.action_planner.build_planner_prompt(
|
||||
is_group_chat=is_group_chat,
|
||||
chat_target_info=chat_target_info,
|
||||
current_available_actions=available_actions,
|
||||
chat_content_block=chat_content_block,
|
||||
message_id_list=message_id_list,
|
||||
)
|
||||
continue_flag, modified_message = await events_manager.handle_mai_events(
|
||||
EventType.ON_PLAN, None, prompt_info[0], None, self.chat_stream.stream_id
|
||||
)
|
||||
if not continue_flag:
|
||||
return False
|
||||
if modified_message and modified_message._modify_flags.modify_llm_prompt:
|
||||
prompt_info = (modified_message.llm_prompt, prompt_info[1])
|
||||
|
||||
with Timer("规划器", cycle_timers):
|
||||
action_to_use_info = await self.action_planner.plan(
|
||||
loop_start_time=self.last_read_time,
|
||||
available_actions=available_actions,
|
||||
force_reply_message=force_reply_message,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"{self.log_prefix} 决定执行{len(action_to_use_info)}个动作: {' '.join([a.action_type for a in action_to_use_info])}"
|
||||
)
|
||||
|
||||
# 3. 并行执行所有动作
|
||||
action_tasks = [
|
||||
asyncio.create_task(
|
||||
self._execute_action(action, action_to_use_info, thinking_id, available_actions, cycle_timers)
|
||||
)
|
||||
for action in action_to_use_info
|
||||
]
|
||||
|
||||
# 并行执行所有任务
|
||||
results = await asyncio.gather(*action_tasks, return_exceptions=True)
|
||||
|
||||
# 处理执行结果
|
||||
reply_loop_info = None
|
||||
reply_text_from_reply = ""
|
||||
action_success = False
|
||||
action_reply_text = ""
|
||||
|
||||
excute_result_str = ""
|
||||
for result in results:
|
||||
excute_result_str += f"{result['action_type']} 执行结果:{result['result']}\n"
|
||||
|
||||
if isinstance(result, BaseException):
|
||||
logger.error(f"{self.log_prefix} 动作执行异常: {result}")
|
||||
continue
|
||||
|
||||
if result["action_type"] != "reply":
|
||||
action_success = result["success"]
|
||||
action_reply_text = result["result"]
|
||||
elif result["action_type"] == "reply":
|
||||
if result["success"]:
|
||||
reply_loop_info = result["loop_info"]
|
||||
reply_text_from_reply = result["result"]
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix} 回复动作执行失败")
|
||||
|
||||
self.action_planner.add_plan_excute_log(result=excute_result_str)
|
||||
|
||||
# 构建最终的循环信息
|
||||
if reply_loop_info:
|
||||
# 如果有回复信息,使用回复的loop_info作为基础
|
||||
loop_info = reply_loop_info
|
||||
# 更新动作执行信息
|
||||
loop_info["loop_action_info"].update(
|
||||
{
|
||||
"action_taken": action_success,
|
||||
"taken_time": time.time(),
|
||||
}
|
||||
)
|
||||
_reply_text = reply_text_from_reply
|
||||
else:
|
||||
# 没有回复信息,构建纯动作的loop_info
|
||||
loop_info = {
|
||||
"loop_plan_info": {
|
||||
"action_result": action_to_use_info,
|
||||
},
|
||||
"loop_action_info": {
|
||||
"action_taken": action_success,
|
||||
"reply_text": action_reply_text,
|
||||
"taken_time": time.time(),
|
||||
},
|
||||
}
|
||||
_reply_text = action_reply_text
|
||||
|
||||
self.end_cycle(loop_info, cycle_timers)
|
||||
self.print_cycle_info(cycle_timers)
|
||||
|
||||
end_time = time.time()
|
||||
if end_time - start_time < global_config.chat.planner_smooth:
|
||||
wait_time = global_config.chat.planner_smooth - (end_time - start_time)
|
||||
await asyncio.sleep(wait_time)
|
||||
else:
|
||||
await asyncio.sleep(0.1)
|
||||
return True
|
||||
|
||||
async def _main_chat_loop(self):
|
||||
"""主循环,持续进行计划并可能回复消息,直到被外部取消。"""
|
||||
try:
|
||||
while self.running:
|
||||
# 主循环
|
||||
success = await self._loopbody()
|
||||
await asyncio.sleep(0.1)
|
||||
if not success:
|
||||
break
|
||||
except asyncio.CancelledError:
|
||||
# 设置了关闭标志位后被取消是正常流程
|
||||
logger.info(f"{self.log_prefix} 麦麦已关闭聊天")
|
||||
except Exception:
|
||||
logger.error(f"{self.log_prefix} 麦麦聊天意外错误,将于3s后尝试重新启动")
|
||||
print(traceback.format_exc())
|
||||
await asyncio.sleep(3)
|
||||
self._loop_task = asyncio.create_task(self._main_chat_loop())
|
||||
logger.error(f"{self.log_prefix} 结束了当前聊天循环")
|
||||
|
||||
async def _handle_action(
|
||||
self,
|
||||
action: str,
|
||||
action_reasoning: str,
|
||||
action_data: dict,
|
||||
cycle_timers: Dict[str, float],
|
||||
thinking_id: str,
|
||||
action_message: Optional["DatabaseMessages"] = None,
|
||||
) -> tuple[bool, str, str]:
|
||||
"""
|
||||
处理规划动作,使用动作工厂创建相应的动作处理器
|
||||
|
||||
参数:
|
||||
action: 动作类型
|
||||
action_reasoning: 决策理由
|
||||
action_data: 动作数据,包含不同动作需要的参数
|
||||
cycle_timers: 计时器字典
|
||||
thinking_id: 思考ID
|
||||
action_message: 消息数据
|
||||
返回:
|
||||
tuple[bool, str, str]: (是否执行了动作, 思考消息ID, 命令)
|
||||
"""
|
||||
try:
|
||||
# 使用工厂创建动作处理器实例
|
||||
try:
|
||||
action_handler = self.action_manager.create_action(
|
||||
action_name=action,
|
||||
action_data=action_data,
|
||||
cycle_timers=cycle_timers,
|
||||
thinking_id=thinking_id,
|
||||
chat_stream=self.chat_stream,
|
||||
log_prefix=self.log_prefix,
|
||||
action_reasoning=action_reasoning,
|
||||
action_message=action_message,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 创建动作处理器时出错: {e}")
|
||||
traceback.print_exc()
|
||||
return False, ""
|
||||
|
||||
# 处理动作并获取结果(固定记录一次动作信息)
|
||||
result = await action_handler.execute()
|
||||
success, action_text = result
|
||||
|
||||
return success, action_text
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 处理{action}时出错: {e}")
|
||||
traceback.print_exc()
|
||||
return False, ""
|
||||
|
||||
async def _send_response(
|
||||
self,
|
||||
reply_set: "ReplySetModel",
|
||||
message_data: "DatabaseMessages",
|
||||
selected_expressions: Optional[List[int]] = None,
|
||||
quote_message: Optional[bool] = None,
|
||||
) -> str:
|
||||
# 根据 llm_quote 配置决定是否使用 quote_message 参数
|
||||
if global_config.chat.llm_quote:
|
||||
# 如果配置为 true,使用 llm_quote 参数决定是否引用回复
|
||||
if quote_message is None:
|
||||
logger.warning(f"{self.log_prefix} quote_message 参数为空,不引用")
|
||||
need_reply = False
|
||||
else:
|
||||
need_reply = quote_message
|
||||
if need_reply:
|
||||
logger.info(f"{self.log_prefix} LLM 决定使用引用回复")
|
||||
else:
|
||||
# 如果配置为 false,使用原来的模式
|
||||
new_message_count = message_api.count_new_messages(
|
||||
chat_id=self.chat_stream.stream_id, start_time=self.last_read_time, end_time=time.time()
|
||||
)
|
||||
need_reply = new_message_count >= random.randint(2, 3) or time.time() - self.last_read_time > 90
|
||||
if need_reply:
|
||||
logger.info(f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息,使用引用回复,或者上次回复时间超过90秒")
|
||||
|
||||
reply_text = ""
|
||||
first_replied = False
|
||||
for reply_content in reply_set.reply_data:
|
||||
if reply_content.content_type != ReplyContentType.TEXT:
|
||||
continue
|
||||
data: str = reply_content.content # type: ignore
|
||||
if not first_replied:
|
||||
await send_api.text_to_stream(
|
||||
text=data,
|
||||
stream_id=self.chat_stream.stream_id,
|
||||
reply_message=message_data,
|
||||
set_reply=need_reply,
|
||||
typing=False,
|
||||
selected_expressions=selected_expressions,
|
||||
)
|
||||
first_replied = True
|
||||
else:
|
||||
await send_api.text_to_stream(
|
||||
text=data,
|
||||
stream_id=self.chat_stream.stream_id,
|
||||
reply_message=message_data,
|
||||
set_reply=False,
|
||||
typing=True,
|
||||
selected_expressions=selected_expressions,
|
||||
)
|
||||
reply_text += data
|
||||
|
||||
return reply_text
|
||||
|
||||
async def _execute_action(
|
||||
self,
|
||||
action_planner_info: ActionPlannerInfo,
|
||||
chosen_action_plan_infos: List[ActionPlannerInfo],
|
||||
thinking_id: str,
|
||||
available_actions: Dict[str, ActionInfo],
|
||||
cycle_timers: Dict[str, float],
|
||||
):
|
||||
"""执行单个动作的通用函数"""
|
||||
try:
|
||||
with Timer(f"动作{action_planner_info.action_type}", cycle_timers):
|
||||
# 直接当场执行no_reply逻辑
|
||||
if action_planner_info.action_type == "no_reply":
|
||||
# 直接处理no_reply逻辑,不再通过动作系统
|
||||
reason = action_planner_info.reasoning or "选择不回复"
|
||||
# logger.info(f"{self.log_prefix} 选择不回复,原因: {reason}")
|
||||
|
||||
# 增加连续 no_reply 计数
|
||||
self.consecutive_no_reply_count += 1
|
||||
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.chat_stream,
|
||||
action_build_into_prompt=False,
|
||||
action_prompt_display=reason,
|
||||
action_done=True,
|
||||
thinking_id=thinking_id,
|
||||
action_data={},
|
||||
action_name="no_reply",
|
||||
action_reasoning=reason,
|
||||
)
|
||||
|
||||
return {"action_type": "no_reply", "success": True, "result": "选择不回复", "command": ""}
|
||||
|
||||
elif action_planner_info.action_type == "reply":
|
||||
# 直接当场执行reply逻辑
|
||||
self.questioned = False
|
||||
# 刷新主动发言状态
|
||||
# 重置连续 no_reply 计数
|
||||
self.consecutive_no_reply_count = 0
|
||||
|
||||
reason = action_planner_info.reasoning or ""
|
||||
# 根据 think_mode 配置决定 think_level 的值
|
||||
think_mode = global_config.chat.think_mode
|
||||
if think_mode == "default":
|
||||
think_level = 0
|
||||
elif think_mode == "deep":
|
||||
think_level = 1
|
||||
elif think_mode == "dynamic":
|
||||
# dynamic 模式:从 planner 返回的 action_data 中获取
|
||||
think_level = action_planner_info.action_data.get("think_level", 1)
|
||||
else:
|
||||
# 默认使用 default 模式
|
||||
think_level = 0
|
||||
# 使用 action_reasoning(planner 的整体思考理由)作为 reply_reason
|
||||
planner_reasoning = action_planner_info.action_reasoning or reason
|
||||
|
||||
record_replyer_action_temp(
|
||||
chat_id=self.stream_id,
|
||||
reason=reason,
|
||||
think_level=think_level,
|
||||
)
|
||||
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.chat_stream,
|
||||
action_build_into_prompt=False,
|
||||
action_prompt_display=reason,
|
||||
action_done=True,
|
||||
thinking_id=thinking_id,
|
||||
action_data={},
|
||||
action_name="reply",
|
||||
action_reasoning=reason,
|
||||
)
|
||||
|
||||
# 从 Planner 的 action_data 中提取未知词语列表(仅在 reply 时使用)
|
||||
unknown_words = None
|
||||
quote_message = None
|
||||
if isinstance(action_planner_info.action_data, dict):
|
||||
uw = action_planner_info.action_data.get("unknown_words")
|
||||
if isinstance(uw, list):
|
||||
cleaned_uw: List[str] = []
|
||||
for item in uw:
|
||||
if isinstance(item, str):
|
||||
s = item.strip()
|
||||
if s:
|
||||
cleaned_uw.append(s)
|
||||
if cleaned_uw:
|
||||
unknown_words = cleaned_uw
|
||||
|
||||
# 从 Planner 的 action_data 中提取 quote_message 参数
|
||||
qm = action_planner_info.action_data.get("quote")
|
||||
if qm is not None:
|
||||
# 支持多种格式:true/false, "true"/"false", 1/0
|
||||
if isinstance(qm, bool):
|
||||
quote_message = qm
|
||||
elif isinstance(qm, str):
|
||||
quote_message = qm.lower() in ("true", "1", "yes")
|
||||
elif isinstance(qm, (int, float)):
|
||||
quote_message = bool(qm)
|
||||
|
||||
logger.info(f"{self.log_prefix} {qm}引用回复设置: {quote_message}")
|
||||
|
||||
success, llm_response = await generator_api.generate_reply(
|
||||
chat_stream=self.chat_stream,
|
||||
reply_message=action_planner_info.action_message,
|
||||
available_actions=available_actions,
|
||||
chosen_actions=chosen_action_plan_infos,
|
||||
reply_reason=planner_reasoning,
|
||||
unknown_words=unknown_words,
|
||||
enable_tool=global_config.tool.enable_tool,
|
||||
request_type="replyer",
|
||||
from_plugin=False,
|
||||
reply_time_point=action_planner_info.action_data.get("loop_start_time", time.time()),
|
||||
think_level=think_level,
|
||||
)
|
||||
|
||||
if not success or not llm_response or not llm_response.reply_set:
|
||||
if action_planner_info.action_message:
|
||||
logger.info(f"对 {action_planner_info.action_message.processed_plain_text} 的回复生成失败")
|
||||
else:
|
||||
logger.info("回复生成失败")
|
||||
return {"action_type": "reply", "success": False, "result": "回复生成失败", "loop_info": None}
|
||||
|
||||
response_set = llm_response.reply_set
|
||||
selected_expressions = llm_response.selected_expressions
|
||||
loop_info, reply_text, _ = await self._send_and_store_reply(
|
||||
response_set=response_set,
|
||||
action_message=action_planner_info.action_message, # type: ignore
|
||||
cycle_timers=cycle_timers,
|
||||
thinking_id=thinking_id,
|
||||
actions=chosen_action_plan_infos,
|
||||
selected_expressions=selected_expressions,
|
||||
quote_message=quote_message,
|
||||
)
|
||||
self.last_active_time = time.time()
|
||||
return {
|
||||
"action_type": "reply",
|
||||
"success": True,
|
||||
"result": f"你使用reply动作,对' {action_planner_info.action_message.processed_plain_text} '这句话进行了回复,回复内容为: '{reply_text}'",
|
||||
"loop_info": loop_info,
|
||||
}
|
||||
|
||||
else:
|
||||
# 执行普通动作
|
||||
with Timer("动作执行", cycle_timers):
|
||||
success, result = await self._handle_action(
|
||||
action=action_planner_info.action_type,
|
||||
action_reasoning=action_planner_info.action_reasoning or "",
|
||||
action_data=action_planner_info.action_data or {},
|
||||
cycle_timers=cycle_timers,
|
||||
thinking_id=thinking_id,
|
||||
action_message=action_planner_info.action_message,
|
||||
)
|
||||
|
||||
self.last_active_time = time.time()
|
||||
return {
|
||||
"action_type": action_planner_info.action_type,
|
||||
"success": success,
|
||||
"result": result,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 执行动作时出错: {e}")
|
||||
logger.error(f"{self.log_prefix} 错误信息: {traceback.format_exc()}")
|
||||
return {
|
||||
"action_type": action_planner_info.action_type,
|
||||
"success": False,
|
||||
"result": "",
|
||||
"loop_info": None,
|
||||
"error": str(e),
|
||||
}
|
||||
@@ -1,377 +1,231 @@
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
from rich.traceback import install
|
||||
from typing import List, Optional, TYPE_CHECKING
|
||||
|
||||
import asyncio
|
||||
import random
|
||||
import time
|
||||
import traceback
|
||||
|
||||
from rich.traceback import install
|
||||
|
||||
from src.bw_learner.expression_learner import ExpressionLearner
|
||||
from src.bw_learner.jargon_miner import JargonMiner
|
||||
from src.chat.event_helpers import build_event_message
|
||||
from src.chat.logger.plan_reply_logger import PlanReplyLogger
|
||||
from src.chat.message_receive.chat_manager import BotChatSession
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from src.chat.planner_actions.action_manager import ActionManager
|
||||
from src.chat.planner_actions.action_modifier import ActionModifier
|
||||
from src.chat.planner_actions.planner import ActionPlanner
|
||||
from src.chat.utils.prompt_builder import global_prompt_manager
|
||||
from src.chat.utils.timer_calculator import Timer
|
||||
from src.chat.utils.utils import record_replyer_action_temp
|
||||
from src.common.data_models.info_data_model import ActionPlannerInfo
|
||||
from src.common.data_models.message_component_data_model import MessageSequence, TextComponent
|
||||
from src.chat.message_receive.chat_manager import chat_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.common.utils.utils_config import ChatConfigUtils, ExpressionConfigUtils
|
||||
from src.config.config import global_config
|
||||
from src.config.file_watcher import FileChange
|
||||
from src.core.event_bus import event_bus
|
||||
from src.core.types import ActionInfo, EventType
|
||||
from src.person_info.person_info import Person
|
||||
from src.services import (
|
||||
database_service as database_api,
|
||||
generator_service as generator_api,
|
||||
message_service as message_api,
|
||||
send_service as send_api,
|
||||
)
|
||||
from src.services.message_service import build_readable_messages_with_id, get_messages_before_time_in_chat
|
||||
from src.learners.expression_learner import ExpressionLearner
|
||||
from src.learners.jargon_miner import JargonMiner
|
||||
|
||||
from .heartFC_utils import CycleDetail
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
|
||||
|
||||
install(extra_lines=5)
|
||||
|
||||
logger = get_logger("heartFC_chat")
|
||||
|
||||
|
||||
class HeartFChatting:
|
||||
"""管理一个持续运行的 Focus Chat 会话。"""
|
||||
"""
|
||||
管理一个连续的Focus Chat聊天会话
|
||||
用于在特定的聊天会话里面生成回复
|
||||
"""
|
||||
|
||||
def __init__(self, session_id: str):
|
||||
self.session_id = session_id
|
||||
self.chat_stream: BotChatSession = _chat_manager.get_session_by_session_id(self.session_id) # type: ignore[assignment]
|
||||
if not self.chat_stream:
|
||||
raise ValueError(f"无法找到聊天会话 {self.session_id}")
|
||||
"""
|
||||
初始化 HeartFChatting 实例
|
||||
|
||||
session_name = _chat_manager.get_session_name(session_id) or session_id
|
||||
Args:
|
||||
session_id: 聊天会话ID
|
||||
"""
|
||||
# 基础属性
|
||||
self.session_id = session_id
|
||||
session_name = chat_manager.get_session_name(session_id) or session_id
|
||||
self.log_prefix = f"[{session_name}]"
|
||||
self.session_name = session_name
|
||||
|
||||
self.action_manager = ActionManager()
|
||||
self.action_planner = ActionPlanner(chat_id=self.session_id, action_manager=self.action_manager)
|
||||
self.action_modifier = ActionModifier(action_manager=self.action_manager, chat_id=self.session_id)
|
||||
|
||||
# 系统运行状态
|
||||
self._running: bool = False
|
||||
self._loop_task: Optional[asyncio.Task] = None
|
||||
self._cycle_counter: int = 0
|
||||
self._hfc_lock: asyncio.Lock = asyncio.Lock() # 用于保护 _hfc_func 的并发访问
|
||||
# 聊天频率相关
|
||||
self._consecutive_no_reply_count = 0 # 跟踪连续 no_reply 次数,用于动态调整阈值
|
||||
self._talk_frequency_adjust: float = 1.0 # 发言频率修正值,默认为1.0,可以根据需要调整
|
||||
|
||||
# HFC内消息缓存
|
||||
self.message_cache: List[SessionMessage] = []
|
||||
|
||||
# Asyncio Event 用于控制循环的开始和结束
|
||||
self._cycle_event = asyncio.Event()
|
||||
self._hfc_lock = asyncio.Lock()
|
||||
|
||||
self._cycle_counter = 0
|
||||
self._current_cycle_detail: Optional[CycleDetail] = None
|
||||
self.history_loop: List[CycleDetail] = []
|
||||
|
||||
self.last_read_time = time.time() - 2
|
||||
self.last_active_time = time.time()
|
||||
self._talk_frequency_adjust = 1.0
|
||||
self._consecutive_no_reply_count = 0
|
||||
|
||||
self.message_cache: List["SessionMessage"] = []
|
||||
|
||||
self._min_messages_for_extraction = 30
|
||||
self._min_extraction_interval = 60
|
||||
self._last_extraction_time = 0.0
|
||||
|
||||
# 表达方式相关内容
|
||||
self._min_messages_for_extraction = 30 # 最少提取消息数
|
||||
self._min_extraction_interval = 60 # 最小提取时间间隔,单位为秒
|
||||
self._last_extraction_time: float = 0.0 # 上次提取的时间戳
|
||||
expr_use, jargon_learn, expr_learn = ExpressionConfigUtils.get_expression_config_for_chat(session_id)
|
||||
self._enable_expression_use = expr_use
|
||||
self._enable_expression_learning = expr_learn
|
||||
self._enable_jargon_learning = jargon_learn
|
||||
self._expression_learner = ExpressionLearner(session_id)
|
||||
self._jargon_miner = JargonMiner(session_id, session_name=session_name)
|
||||
self._enable_expression_use = expr_use # 允许使用表达方式,但不一定启用学习
|
||||
self._enable_expression_learning = expr_learn # 允许学习表达方式
|
||||
self._enable_jargon_learning = jargon_learn # 允许学习黑话
|
||||
# 表达学习器
|
||||
self._expression_learner: ExpressionLearner = ExpressionLearner(session_id)
|
||||
# 黑话挖掘器
|
||||
self._jargon_miner: JargonMiner = JargonMiner(session_id, session_name=session_name)
|
||||
|
||||
# TODO: ChatSummarizer 聊天总结器重构
|
||||
|
||||
# ====== 公开方法 ======
|
||||
|
||||
async def start(self):
|
||||
"""启动 HeartFChatting 的主循环"""
|
||||
# 先检查是否已经启动运行
|
||||
if self._running:
|
||||
logger.debug(f"{self.log_prefix} HeartFChatting 已在运行中")
|
||||
logger.debug(f"{self.log_prefix} 已经在运行中,无需重复启动")
|
||||
return
|
||||
|
||||
try:
|
||||
self._running = True
|
||||
self._cycle_event.clear()
|
||||
self._cycle_event.clear() # 确保事件初始状态为未设置
|
||||
|
||||
self._loop_task = asyncio.create_task(self.main_loop())
|
||||
self._loop_task.add_done_callback(self._handle_loop_completion)
|
||||
|
||||
logger.info(f"{self.log_prefix} HeartFChatting 启动完成")
|
||||
except Exception as exc:
|
||||
logger.error(f"{self.log_prefix} HeartFChatting 启动失败: {exc}", exc_info=True)
|
||||
self._running = False
|
||||
self._cycle_event.set()
|
||||
self._loop_task = None
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 启动 HeartFChatting 失败: {e}", exc_info=True)
|
||||
self._running = False # 确保状态正确
|
||||
self._cycle_event.set() # 确保事件被设置,避免死锁
|
||||
self._loop_task = None # 确保任务引用被清理
|
||||
raise
|
||||
|
||||
async def stop(self):
|
||||
"""停止 HeartFChatting 的主循环"""
|
||||
if not self._running:
|
||||
logger.debug(f"{self.log_prefix} HeartFChatting 已停止")
|
||||
logger.debug(f"{self.log_prefix} HeartFChatting 已经停止,无需重复停止")
|
||||
return
|
||||
|
||||
self._running = False
|
||||
self._cycle_event.set()
|
||||
self._cycle_event.set() # 触发事件,通知循环结束
|
||||
|
||||
if self._loop_task:
|
||||
self._loop_task.cancel()
|
||||
self._loop_task.cancel() # 取消主循环任务
|
||||
try:
|
||||
await self._loop_task
|
||||
await self._loop_task # 等待任务完成
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"{self.log_prefix} HeartFChatting 主循环已取消")
|
||||
except Exception as exc:
|
||||
logger.error(f"{self.log_prefix} 停止 HeartFChatting 时发生错误: {exc}", exc_info=True)
|
||||
logger.info(f"{self.log_prefix} HeartFChatting 主循环已成功取消")
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 停止 HeartFChatting 时发生错误: {e}", exc_info=True)
|
||||
finally:
|
||||
self._loop_task = None
|
||||
self._loop_task = None # 确保任务引用被清理
|
||||
|
||||
logger.info(f"{self.log_prefix} HeartFChatting 已停止")
|
||||
|
||||
def adjust_talk_frequency(self, new_value: float):
|
||||
"""调整发言频率的调整值
|
||||
|
||||
Args:
|
||||
new_value: 新的修正值,必须为非负数。值越大,修正发言频率越高;值越小,修正发言频率越低。
|
||||
"""
|
||||
self._talk_frequency_adjust = max(0.0, new_value)
|
||||
|
||||
async def register_message(self, message: "SessionMessage"):
|
||||
"""注册一条消息到 HeartFChatting 的缓存中,并检测其是否产生提及,决定是否唤醒聊天
|
||||
|
||||
Args:
|
||||
message: 待注册的消息对象
|
||||
"""
|
||||
self.message_cache.append(message)
|
||||
|
||||
# 先检查at必回复
|
||||
if global_config.chat.inevitable_at_reply and message.is_at:
|
||||
self.last_read_time = time.time()
|
||||
async with self._hfc_lock:
|
||||
await self._judge_and_response(mentioned_message=message, recent_messages_list=[message])
|
||||
return
|
||||
|
||||
async with self._hfc_lock: # 确保与主循环逻辑的互斥访问
|
||||
await self._judge_and_response(message)
|
||||
return # 直接返回,避免同一条消息被主循环再次处理
|
||||
# 再检查提及必回复
|
||||
if global_config.chat.mentioned_bot_reply and message.is_mentioned:
|
||||
self.last_read_time = time.time()
|
||||
async with self._hfc_lock:
|
||||
await self._judge_and_response(mentioned_message=message, recent_messages_list=[message])
|
||||
# 直接获取锁,确保一定一定触发回复逻辑,不受当前是否正在执行主循环的影响
|
||||
async with self._hfc_lock: # 确保与主循环逻辑的互斥访问
|
||||
await self._judge_and_response(message)
|
||||
return
|
||||
|
||||
async def main_loop(self):
|
||||
try:
|
||||
while self._running and not self._cycle_event.is_set():
|
||||
if not self._hfc_lock.locked():
|
||||
async with self._hfc_lock:
|
||||
async with self._hfc_lock: # 确保主循环逻辑的互斥访问
|
||||
await self._hfc_func()
|
||||
await asyncio.sleep(0.1)
|
||||
await asyncio.sleep(5)
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"{self.log_prefix} HeartFChatting: 主循环被取消")
|
||||
except Exception as exc:
|
||||
logger.error(f"{self.log_prefix} HeartFChatting: 主循环异常: {exc}", exc_info=True)
|
||||
await self.stop()
|
||||
logger.info(f"{self.log_prefix} HeartFChatting: 主循环被取消,正在关闭")
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 麦麦聊天意外错误: {e},将于3s后尝试重新启动")
|
||||
await self.stop() # 确保状态正确
|
||||
await asyncio.sleep(3)
|
||||
await self.start()
|
||||
await self.start() # 尝试重新启动
|
||||
|
||||
async def _config_callback(self, file_change: Optional[FileChange] = None):
|
||||
del file_change
|
||||
expr_use, jargon_learn, expr_learn = ExpressionConfigUtils.get_expression_config_for_chat(self.session_id)
|
||||
self._enable_expression_use = expr_use
|
||||
self._enable_expression_learning = expr_learn
|
||||
self._enable_jargon_learning = jargon_learn
|
||||
"""配置文件变更回调函数"""
|
||||
# TODO: 根据配置文件变动重新计算相关参数:
|
||||
"""
|
||||
需要计算的参数:
|
||||
self._enable_expression_use = expr_use # 允许使用表达方式,但不一定启用学习
|
||||
self._enable_expression_learning = expr_learn # 允许学习表达方式
|
||||
self._enable_jargon_learning = jargon_learn # 允许学习黑话
|
||||
"""
|
||||
|
||||
async def _hfc_func(self):
|
||||
recent_messages_list = message_api.get_messages_by_time_in_chat(
|
||||
chat_id=self.session_id,
|
||||
start_time=self.last_read_time,
|
||||
end_time=time.time(),
|
||||
limit=20,
|
||||
limit_mode="latest",
|
||||
filter_mai=True,
|
||||
filter_command=False,
|
||||
filter_intercept_message_level=1,
|
||||
)
|
||||
# ====== 心流聊天核心逻辑 ======
|
||||
async def _hfc_func(self, mentioned_message: Optional["SessionMessage"] = None):
|
||||
"""心流聊天的主循环逻辑"""
|
||||
if self._consecutive_no_reply_count >= 5:
|
||||
threshold = 2
|
||||
elif self._consecutive_no_reply_count >= 3:
|
||||
threshold = 2 if random.random() < 0.5 else 1
|
||||
else:
|
||||
threshold = 1
|
||||
|
||||
if len(recent_messages_list) < 1:
|
||||
if len(self.message_cache) < threshold:
|
||||
await asyncio.sleep(0.2)
|
||||
return True
|
||||
|
||||
self.last_read_time = time.time()
|
||||
|
||||
mentioned_message: Optional["SessionMessage"] = None
|
||||
for message in recent_messages_list:
|
||||
if global_config.chat.inevitable_at_reply and message.is_at:
|
||||
mentioned_message = message
|
||||
elif global_config.chat.mentioned_bot_reply and message.is_mentioned:
|
||||
mentioned_message = message
|
||||
|
||||
talk_value = ChatConfigUtils.get_talk_value(self.session_id) * self._talk_frequency_adjust
|
||||
if mentioned_message:
|
||||
await self._judge_and_response(mentioned_message=mentioned_message, recent_messages_list=recent_messages_list)
|
||||
elif random.random() < talk_value:
|
||||
await self._judge_and_response(recent_messages_list=recent_messages_list)
|
||||
talk_value_threshold = (
|
||||
random.random() * ChatConfigUtils.get_talk_value(self.session_id) * self._talk_frequency_adjust
|
||||
)
|
||||
if mentioned_message and global_config.chat.mentioned_bot_reply:
|
||||
await self._judge_and_response(mentioned_message)
|
||||
elif random.random() < talk_value_threshold:
|
||||
await self._judge_and_response()
|
||||
return True
|
||||
|
||||
async def _judge_and_response(
|
||||
self,
|
||||
mentioned_message: Optional["SessionMessage"] = None,
|
||||
recent_messages_list: Optional[List["SessionMessage"]] = None,
|
||||
):
|
||||
recent_messages = list(recent_messages_list or self.message_cache[-20:])
|
||||
if recent_messages:
|
||||
asyncio.create_task(self._trigger_expression_learning(recent_messages))
|
||||
|
||||
cycle_timers, thinking_id = self._start_cycle()
|
||||
async def _judge_and_response(self, mentioned_message: Optional["SessionMessage"] = None):
|
||||
"""判定和生成回复"""
|
||||
asyncio.create_task(self._trigger_expression_learning(self.message_cache))
|
||||
# TODO: 完成反思器之后的逻辑
|
||||
start_time = time.time()
|
||||
current_cycle_detail = self._start_cycle()
|
||||
logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考")
|
||||
|
||||
try:
|
||||
async with global_prompt_manager.async_message_scope(self._get_template_name()):
|
||||
available_actions: Dict[str, ActionInfo] = {}
|
||||
try:
|
||||
await self.action_modifier.modify_actions()
|
||||
available_actions = self.action_manager.get_using_actions()
|
||||
except Exception as exc:
|
||||
logger.error(f"{self.log_prefix} 动作修改失败: {exc}", exc_info=True)
|
||||
# TODO: 动作检查逻辑
|
||||
# TODO: Planner逻辑
|
||||
# TODO: 动作执行逻辑
|
||||
|
||||
is_group_chat, chat_target_info, _ = self.action_planner.get_necessary_info()
|
||||
message_list_before_now = get_messages_before_time_in_chat(
|
||||
chat_id=self.session_id,
|
||||
timestamp=time.time(),
|
||||
limit=int(global_config.chat.max_context_size * 0.6),
|
||||
filter_intercept_message_level=1,
|
||||
)
|
||||
chat_content_block, message_id_list = build_readable_messages_with_id(
|
||||
messages=message_list_before_now,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
read_mark=self.action_planner.last_obs_time_mark,
|
||||
truncate=True,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
prompt, filtered_actions = await self._build_planner_prompt_with_event(
|
||||
available_actions=available_actions,
|
||||
is_group_chat=is_group_chat,
|
||||
chat_target_info=chat_target_info,
|
||||
chat_content_block=chat_content_block,
|
||||
message_id_list=message_id_list,
|
||||
)
|
||||
if prompt is None:
|
||||
return False
|
||||
|
||||
with Timer("规划器", cycle_timers):
|
||||
reasoning, action_to_use_info, llm_raw_output, llm_reasoning, llm_duration_ms = (
|
||||
await self.action_planner._execute_main_planner(
|
||||
prompt=prompt,
|
||||
message_id_list=message_id_list,
|
||||
filtered_actions=filtered_actions,
|
||||
available_actions=available_actions,
|
||||
loop_start_time=self.last_read_time,
|
||||
)
|
||||
)
|
||||
|
||||
action_to_use_info = self._ensure_force_reply_action(
|
||||
actions=action_to_use_info,
|
||||
force_reply_message=mentioned_message,
|
||||
available_actions=available_actions,
|
||||
)
|
||||
self.action_planner.add_plan_log(reasoning, action_to_use_info)
|
||||
self.action_planner.last_obs_time_mark = time.time()
|
||||
self._log_plan(
|
||||
prompt=prompt,
|
||||
reasoning=reasoning,
|
||||
llm_raw_output=llm_raw_output,
|
||||
llm_reasoning=llm_reasoning,
|
||||
llm_duration_ms=llm_duration_ms,
|
||||
actions=action_to_use_info,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"{self.log_prefix} 决定执行{len(action_to_use_info)}个动作: {' '.join([a.action_type for a in action_to_use_info])}"
|
||||
)
|
||||
|
||||
action_tasks = [
|
||||
asyncio.create_task(
|
||||
self._execute_action(
|
||||
action,
|
||||
action_to_use_info,
|
||||
thinking_id,
|
||||
available_actions,
|
||||
cycle_timers,
|
||||
)
|
||||
)
|
||||
for action in action_to_use_info
|
||||
]
|
||||
results = await asyncio.gather(*action_tasks, return_exceptions=True)
|
||||
|
||||
reply_loop_info = None
|
||||
reply_text_from_reply = ""
|
||||
action_success = False
|
||||
action_reply_text = ""
|
||||
execute_result_str = ""
|
||||
|
||||
for result in results:
|
||||
if isinstance(result, BaseException):
|
||||
logger.error(f"{self.log_prefix} 动作执行异常: {result}", exc_info=True)
|
||||
continue
|
||||
|
||||
execute_result_str += f"{result['action_type']} 执行结果:{result['result']}\n"
|
||||
if result["action_type"] == "reply":
|
||||
if result["success"]:
|
||||
reply_loop_info = result["loop_info"]
|
||||
reply_text_from_reply = result["result"]
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix} reply 动作执行失败")
|
||||
else:
|
||||
action_success = result["success"]
|
||||
action_reply_text = result["result"]
|
||||
|
||||
self.action_planner.add_plan_excute_log(result=execute_result_str)
|
||||
|
||||
if reply_loop_info:
|
||||
loop_info = reply_loop_info
|
||||
loop_info["loop_action_info"].update(
|
||||
{
|
||||
"action_taken": action_success,
|
||||
"taken_time": time.time(),
|
||||
}
|
||||
)
|
||||
else:
|
||||
loop_info = {
|
||||
"loop_plan_info": {
|
||||
"action_result": action_to_use_info,
|
||||
},
|
||||
"loop_action_info": {
|
||||
"action_taken": action_success,
|
||||
"reply_text": action_reply_text,
|
||||
"taken_time": time.time(),
|
||||
},
|
||||
}
|
||||
reply_text_from_reply = action_reply_text
|
||||
|
||||
current_cycle_detail = self._end_cycle(self._current_cycle_detail, loop_info)
|
||||
logger.debug(f"{self.log_prefix} 本轮最终输出: {reply_text_from_reply}")
|
||||
return current_cycle_detail is not None
|
||||
except Exception as exc:
|
||||
logger.error(f"{self.log_prefix} 判定与回复流程失败: {exc}", exc_info=True)
|
||||
if self._current_cycle_detail:
|
||||
self._end_cycle(
|
||||
self._current_cycle_detail,
|
||||
{
|
||||
"loop_plan_info": {"action_result": []},
|
||||
"loop_action_info": {
|
||||
"action_taken": False,
|
||||
"reply_text": "",
|
||||
"taken_time": time.time(),
|
||||
"error": str(exc),
|
||||
},
|
||||
},
|
||||
)
|
||||
return False
|
||||
cycle_detail = self._end_cycle(current_cycle_detail)
|
||||
if wait_time := global_config.chat.planner_smooth - (time.time() - start_time) > 0:
|
||||
await asyncio.sleep(wait_time)
|
||||
else:
|
||||
await asyncio.sleep(0.1) # 最小等待时间,避免过快循环
|
||||
return True
|
||||
|
||||
def _handle_loop_completion(self, task: asyncio.Task):
|
||||
"""当 _hfc_func 任务完成时执行的回调。"""
|
||||
try:
|
||||
if exception := task.exception():
|
||||
logger.error(f"{self.log_prefix} HeartFChatting: 主循环异常退出: {exception}")
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"{self.log_prefix} HeartFChatting: 脱离了聊天(异常): {exception}")
|
||||
logger.error(traceback.format_exc()) # Log full traceback for exceptions
|
||||
else:
|
||||
logger.info(f"{self.log_prefix} HeartFChatting: 主循环已退出")
|
||||
logger.info(f"{self.log_prefix} HeartFChatting: 脱离了聊天 (外部停止)")
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"{self.log_prefix} HeartFChatting: 聊天已结束")
|
||||
logger.info(f"{self.log_prefix} HeartFChatting: 结束了聊天")
|
||||
|
||||
# ====== 学习器触发逻辑 ======
|
||||
async def _trigger_expression_learning(self, messages: List["SessionMessage"]):
|
||||
if not messages:
|
||||
return
|
||||
|
||||
self._expression_learner.add_messages(messages)
|
||||
if time.time() - self._last_extraction_time < self._min_extraction_interval:
|
||||
return
|
||||
@@ -379,14 +233,12 @@ class HeartFChatting:
|
||||
return
|
||||
if not self._enable_expression_learning:
|
||||
return
|
||||
|
||||
extraction_end_time = time.time()
|
||||
logger.info(
|
||||
f"聊天流 {self.session_name} 提取到 {len(messages)} 条消息,"
|
||||
f"时间窗口: {self._last_extraction_time:.2f} - {extraction_end_time:.2f}"
|
||||
)
|
||||
self._last_extraction_time = extraction_end_time
|
||||
|
||||
try:
|
||||
jargon_miner = self._jargon_miner if self._enable_jargon_learning else None
|
||||
learnt_style = await self._expression_learner.learn(jargon_miner)
|
||||
@@ -394,398 +246,43 @@ class HeartFChatting:
|
||||
logger.info(f"{self.log_prefix} 表达学习完成")
|
||||
else:
|
||||
logger.debug(f"{self.log_prefix} 表达学习未获得有效结果")
|
||||
except Exception as exc:
|
||||
logger.error(f"{self.log_prefix} 表达学习失败: {exc}", exc_info=True)
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 表达学习失败: {e}", exc_info=True)
|
||||
|
||||
def _start_cycle(self) -> Tuple[Dict[str, float], str]:
|
||||
# ====== 记录循环执行信息相关逻辑 ======
|
||||
def _start_cycle(self) -> CycleDetail:
|
||||
self._cycle_counter += 1
|
||||
self._current_cycle_detail = CycleDetail(cycle_id=self._cycle_counter)
|
||||
self._current_cycle_detail.thinking_id = f"tid{str(round(time.time(), 2))}"
|
||||
return self._current_cycle_detail.time_records, self._current_cycle_detail.thinking_id
|
||||
current_cycle_detail = CycleDetail(cycle_id=self._cycle_counter)
|
||||
current_cycle_detail.thinking_id = f"tid{str(round(time.time(), 2))}"
|
||||
return current_cycle_detail
|
||||
|
||||
def _end_cycle(self, cycle_detail: Optional[CycleDetail], loop_info: Optional[Dict[str, Any]] = None):
|
||||
if cycle_detail is None:
|
||||
return None
|
||||
|
||||
cycle_detail.loop_plan_info = (loop_info or {}).get("loop_plan_info")
|
||||
cycle_detail.loop_action_info = (loop_info or {}).get("loop_action_info")
|
||||
def _end_cycle(self, cycle_detail: CycleDetail, only_long_execution: bool = True):
|
||||
cycle_detail.end_time = time.time()
|
||||
self.history_loop.append(cycle_detail)
|
||||
|
||||
timer_strings = [
|
||||
timer_strings: List[str] = [
|
||||
f"{name}: {duration:.2f}s"
|
||||
for name, duration in cycle_detail.time_records.items()
|
||||
if duration >= 0.1
|
||||
if not only_long_execution or duration >= 0.1
|
||||
]
|
||||
logger.info(
|
||||
f"{self.log_prefix} 第{cycle_detail.cycle_id} 个心流循环完成,"
|
||||
f"耗时: {cycle_detail.end_time - cycle_detail.start_time:.2f}s;"
|
||||
f"{self.log_prefix} 第 {cycle_detail.cycle_id} 个心流循环完成"
|
||||
f"耗时: {cycle_detail.end_time - cycle_detail.start_time:.2f}秒\n"
|
||||
f"详细计时: {', '.join(timer_strings) if timer_strings else '无'}"
|
||||
)
|
||||
|
||||
return cycle_detail
|
||||
|
||||
async def _execute_action(
|
||||
self,
|
||||
action_planner_info: ActionPlannerInfo,
|
||||
chosen_action_plan_infos: List[ActionPlannerInfo],
|
||||
thinking_id: str,
|
||||
available_actions: Dict[str, ActionInfo],
|
||||
cycle_timers: Dict[str, float],
|
||||
):
|
||||
try:
|
||||
with Timer(f"动作{action_planner_info.action_type}", cycle_timers):
|
||||
if action_planner_info.action_type == "no_reply":
|
||||
reason = action_planner_info.reasoning or "选择不回复"
|
||||
self._consecutive_no_reply_count += 1
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.chat_stream,
|
||||
display_prompt=reason,
|
||||
thinking_id=thinking_id,
|
||||
action_data={},
|
||||
action_name="no_reply",
|
||||
action_reasoning=reason,
|
||||
)
|
||||
return {
|
||||
"action_type": "no_reply",
|
||||
"success": True,
|
||||
"result": "选择不回复",
|
||||
"loop_info": None,
|
||||
}
|
||||
# ====== Action相关逻辑 ======
|
||||
async def _execute_action(self, *args, **kwargs):
|
||||
"""原ExecuteAction"""
|
||||
raise NotImplementedError("执行动作的逻辑尚未实现") # TODO: 实现动作执行的逻辑,替换掉*args, **kwargs*占位符
|
||||
|
||||
if action_planner_info.action_type == "reply":
|
||||
self._consecutive_no_reply_count = 0
|
||||
reason = action_planner_info.reasoning or ""
|
||||
think_level = self._get_think_level(action_planner_info)
|
||||
planner_reasoning = action_planner_info.action_reasoning or reason
|
||||
async def _execute_other_actions(self, *args, **kwargs):
|
||||
"""原HandleAction"""
|
||||
raise NotImplementedError(
|
||||
"执行其他动作的逻辑尚未实现"
|
||||
) # TODO: 实现其他动作执行的逻辑, 替换掉*args, **kwargs*占位符
|
||||
|
||||
record_replyer_action_temp(
|
||||
chat_id=self.session_id,
|
||||
reason=reason,
|
||||
think_level=think_level,
|
||||
)
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.chat_stream,
|
||||
display_prompt=reason,
|
||||
thinking_id=thinking_id,
|
||||
action_data={},
|
||||
action_name="reply",
|
||||
action_reasoning=reason,
|
||||
)
|
||||
|
||||
unknown_words, quote_message = self._extract_reply_metadata(action_planner_info)
|
||||
success, llm_response = await generator_api.generate_reply(
|
||||
chat_stream=self.chat_stream,
|
||||
reply_message=action_planner_info.action_message,
|
||||
available_actions=available_actions,
|
||||
chosen_actions=chosen_action_plan_infos,
|
||||
reply_reason=planner_reasoning,
|
||||
unknown_words=unknown_words,
|
||||
enable_tool=global_config.tool.enable_tool,
|
||||
request_type="replyer",
|
||||
from_plugin=False,
|
||||
reply_time_point=action_planner_info.action_data.get("loop_start_time", time.time())
|
||||
if action_planner_info.action_data
|
||||
else time.time(),
|
||||
think_level=think_level,
|
||||
)
|
||||
if not success or not llm_response or not llm_response.reply_set:
|
||||
if action_planner_info.action_message:
|
||||
logger.info(
|
||||
f"对 {action_planner_info.action_message.processed_plain_text} 的回复生成失败"
|
||||
)
|
||||
else:
|
||||
logger.info(f"{self.log_prefix} 回复生成失败")
|
||||
return {
|
||||
"action_type": "reply",
|
||||
"success": False,
|
||||
"result": "回复生成失败",
|
||||
"loop_info": None,
|
||||
}
|
||||
|
||||
loop_info, reply_text, _ = await self._send_and_store_reply(
|
||||
response_set=llm_response.reply_set,
|
||||
action_message=action_planner_info.action_message, # type: ignore[arg-type]
|
||||
cycle_timers=cycle_timers,
|
||||
thinking_id=thinking_id,
|
||||
actions=chosen_action_plan_infos,
|
||||
selected_expressions=llm_response.selected_expressions,
|
||||
quote_message=quote_message,
|
||||
)
|
||||
self.last_active_time = time.time()
|
||||
return {
|
||||
"action_type": "reply",
|
||||
"success": True,
|
||||
"result": reply_text,
|
||||
"loop_info": loop_info,
|
||||
}
|
||||
|
||||
with Timer("动作执行", cycle_timers):
|
||||
success, result = await self._handle_action(
|
||||
action=action_planner_info.action_type,
|
||||
action_reasoning=action_planner_info.action_reasoning or "",
|
||||
action_data=action_planner_info.action_data or {},
|
||||
cycle_timers=cycle_timers,
|
||||
thinking_id=thinking_id,
|
||||
action_message=action_planner_info.action_message,
|
||||
)
|
||||
if success:
|
||||
self.last_active_time = time.time()
|
||||
return {
|
||||
"action_type": action_planner_info.action_type,
|
||||
"success": success,
|
||||
"result": result,
|
||||
"loop_info": None,
|
||||
}
|
||||
except Exception as exc:
|
||||
logger.error(f"{self.log_prefix} 执行动作时出错: {exc}", exc_info=True)
|
||||
return {
|
||||
"action_type": action_planner_info.action_type,
|
||||
"success": False,
|
||||
"result": "",
|
||||
"loop_info": None,
|
||||
"error": str(exc),
|
||||
}
|
||||
|
||||
async def _handle_action(
|
||||
self,
|
||||
action: str,
|
||||
action_reasoning: str,
|
||||
action_data: dict,
|
||||
cycle_timers: Dict[str, float],
|
||||
thinking_id: str,
|
||||
action_message: Optional["SessionMessage"] = None,
|
||||
) -> Tuple[bool, str]:
|
||||
try:
|
||||
action_handler = self.action_manager.create_action(
|
||||
action_name=action,
|
||||
action_data=action_data,
|
||||
action_reasoning=action_reasoning,
|
||||
cycle_timers=cycle_timers,
|
||||
thinking_id=thinking_id,
|
||||
chat_stream=self.chat_stream,
|
||||
log_prefix=self.log_prefix,
|
||||
action_message=action_message,
|
||||
)
|
||||
if not action_handler:
|
||||
logger.warning(f"{self.log_prefix} 未能创建动作处理器: {action}")
|
||||
return False, ""
|
||||
|
||||
success, action_text = await action_handler.execute()
|
||||
return success, action_text
|
||||
except Exception as exc:
|
||||
logger.error(f"{self.log_prefix} 处理动作 {action} 时出错: {exc}", exc_info=True)
|
||||
return False, ""
|
||||
|
||||
async def _send_and_store_reply(
|
||||
self,
|
||||
response_set: MessageSequence,
|
||||
action_message: "SessionMessage",
|
||||
cycle_timers: Dict[str, float],
|
||||
thinking_id: str,
|
||||
actions: List[ActionPlannerInfo],
|
||||
selected_expressions: Optional[List[int]] = None,
|
||||
quote_message: Optional[bool] = None,
|
||||
) -> Tuple[Dict[str, Any], str, Dict[str, float]]:
|
||||
with Timer("回复发送", cycle_timers):
|
||||
reply_text = await self._send_response(
|
||||
reply_set=response_set,
|
||||
message_data=action_message,
|
||||
selected_expressions=selected_expressions,
|
||||
quote_message=quote_message,
|
||||
)
|
||||
|
||||
platform = action_message.platform or getattr(self.chat_stream, "platform", "unknown")
|
||||
person = Person(platform=platform, user_id=action_message.message_info.user_info.user_id)
|
||||
action_prompt_display = f"你对{person.person_name}进行了回复:{reply_text}"
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.chat_stream,
|
||||
display_prompt=action_prompt_display,
|
||||
thinking_id=thinking_id,
|
||||
action_data={"reply_text": reply_text},
|
||||
action_name="reply",
|
||||
)
|
||||
|
||||
loop_info: Dict[str, Any] = {
|
||||
"loop_plan_info": {
|
||||
"action_result": actions,
|
||||
},
|
||||
"loop_action_info": {
|
||||
"action_taken": True,
|
||||
"reply_text": reply_text,
|
||||
"command": "",
|
||||
"taken_time": time.time(),
|
||||
},
|
||||
}
|
||||
return loop_info, reply_text, cycle_timers
|
||||
|
||||
async def _send_response(
|
||||
self,
|
||||
reply_set: MessageSequence,
|
||||
message_data: "SessionMessage",
|
||||
selected_expressions: Optional[List[int]] = None,
|
||||
quote_message: Optional[bool] = None,
|
||||
) -> str:
|
||||
if global_config.chat.llm_quote:
|
||||
need_reply = bool(quote_message)
|
||||
else:
|
||||
new_message_count = message_api.count_new_messages(
|
||||
chat_id=self.session_id,
|
||||
start_time=self.last_read_time,
|
||||
end_time=time.time(),
|
||||
)
|
||||
need_reply = new_message_count >= random.randint(2, 3) or time.time() - self.last_read_time > 90
|
||||
|
||||
reply_text = ""
|
||||
first_replied = False
|
||||
for component in reply_set.components:
|
||||
if not isinstance(component, TextComponent):
|
||||
continue
|
||||
data = component.text
|
||||
if not first_replied:
|
||||
await send_api.text_to_stream(
|
||||
text=data,
|
||||
stream_id=self.session_id,
|
||||
reply_message=message_data,
|
||||
set_reply=need_reply,
|
||||
typing=False,
|
||||
selected_expressions=selected_expressions,
|
||||
)
|
||||
first_replied = True
|
||||
else:
|
||||
await send_api.text_to_stream(
|
||||
text=data,
|
||||
stream_id=self.session_id,
|
||||
reply_message=message_data,
|
||||
set_reply=False,
|
||||
typing=True,
|
||||
selected_expressions=selected_expressions,
|
||||
)
|
||||
reply_text += data
|
||||
return reply_text
|
||||
|
||||
async def _build_planner_prompt_with_event(
|
||||
self,
|
||||
available_actions: Dict[str, ActionInfo],
|
||||
is_group_chat: bool,
|
||||
chat_target_info: Any,
|
||||
chat_content_block: str,
|
||||
message_id_list: List[Tuple[str, "SessionMessage"]],
|
||||
) -> Tuple[Optional[str], Dict[str, ActionInfo]]:
|
||||
filtered_actions = self.action_planner._filter_actions_by_activation_type(available_actions, chat_content_block)
|
||||
prompt, _ = await self.action_planner.build_planner_prompt(
|
||||
is_group_chat=is_group_chat,
|
||||
chat_target_info=chat_target_info,
|
||||
current_available_actions=filtered_actions,
|
||||
chat_content_block=chat_content_block,
|
||||
message_id_list=message_id_list,
|
||||
)
|
||||
event_message = build_event_message(EventType.ON_PLAN, llm_prompt=prompt, stream_id=self.session_id)
|
||||
continue_flag, modified_message = await event_bus.emit(EventType.ON_PLAN, event_message)
|
||||
if not continue_flag:
|
||||
logger.info(f"{self.log_prefix} ON_PLAN 事件中止了本轮 HFC")
|
||||
return None, filtered_actions
|
||||
if modified_message and modified_message._modify_flags.modify_llm_prompt and modified_message.llm_prompt:
|
||||
prompt = modified_message.llm_prompt
|
||||
return prompt, filtered_actions
|
||||
|
||||
def _ensure_force_reply_action(
|
||||
self,
|
||||
actions: List[ActionPlannerInfo],
|
||||
force_reply_message: Optional["SessionMessage"],
|
||||
available_actions: Dict[str, ActionInfo],
|
||||
) -> List[ActionPlannerInfo]:
|
||||
if not force_reply_message:
|
||||
return actions
|
||||
|
||||
has_reply_to_force_message = any(
|
||||
action.action_type == "reply"
|
||||
and action.action_message
|
||||
and action.action_message.message_id == force_reply_message.message_id
|
||||
for action in actions
|
||||
)
|
||||
if has_reply_to_force_message:
|
||||
return actions
|
||||
|
||||
actions = [action for action in actions if action.action_type != "no_reply"]
|
||||
actions.insert(
|
||||
0,
|
||||
ActionPlannerInfo(
|
||||
action_type="reply",
|
||||
reasoning="用户提及了我,必须回复该消息",
|
||||
action_data={"loop_start_time": self.last_read_time},
|
||||
action_message=force_reply_message,
|
||||
available_actions=available_actions,
|
||||
action_reasoning=None,
|
||||
),
|
||||
)
|
||||
logger.info(f"{self.log_prefix} 检测到强制回复消息,已补充 reply 动作")
|
||||
return actions
|
||||
|
||||
def _log_plan(
|
||||
self,
|
||||
prompt: str,
|
||||
reasoning: str,
|
||||
llm_raw_output: Optional[str],
|
||||
llm_reasoning: Optional[str],
|
||||
llm_duration_ms: Optional[float],
|
||||
actions: List[ActionPlannerInfo],
|
||||
) -> None:
|
||||
try:
|
||||
PlanReplyLogger.log_plan(
|
||||
chat_id=self.session_id,
|
||||
prompt=prompt,
|
||||
reasoning=reasoning,
|
||||
raw_output=llm_raw_output,
|
||||
raw_reasoning=llm_reasoning,
|
||||
actions=actions,
|
||||
timing={
|
||||
"llm_duration_ms": round(llm_duration_ms, 2) if llm_duration_ms is not None else None,
|
||||
"loop_start_time": self.last_read_time,
|
||||
},
|
||||
extra=None,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(f"{self.log_prefix} 记录 plan 日志失败")
|
||||
|
||||
def _extract_reply_metadata(
|
||||
self,
|
||||
action_planner_info: ActionPlannerInfo,
|
||||
) -> Tuple[Optional[List[str]], Optional[bool]]:
|
||||
unknown_words: Optional[List[str]] = None
|
||||
quote_message: Optional[bool] = None
|
||||
action_data = action_planner_info.action_data or {}
|
||||
|
||||
raw_unknown_words = action_data.get("unknown_words")
|
||||
if isinstance(raw_unknown_words, list):
|
||||
cleaned_unknown_words = []
|
||||
for item in raw_unknown_words:
|
||||
if isinstance(item, str) and (cleaned_item := item.strip()):
|
||||
cleaned_unknown_words.append(cleaned_item)
|
||||
if cleaned_unknown_words:
|
||||
unknown_words = cleaned_unknown_words
|
||||
|
||||
raw_quote = action_data.get("quote")
|
||||
if isinstance(raw_quote, bool):
|
||||
quote_message = raw_quote
|
||||
elif isinstance(raw_quote, str):
|
||||
quote_message = raw_quote.lower() in {"true", "1", "yes"}
|
||||
elif isinstance(raw_quote, (int, float)):
|
||||
quote_message = bool(raw_quote)
|
||||
|
||||
return unknown_words, quote_message
|
||||
|
||||
def _get_think_level(self, action_planner_info: ActionPlannerInfo) -> int:
|
||||
think_mode = global_config.chat.think_mode
|
||||
if think_mode == "default":
|
||||
return 0
|
||||
if think_mode == "deep":
|
||||
return 1
|
||||
if think_mode == "dynamic":
|
||||
action_data = action_planner_info.action_data or {}
|
||||
return int(action_data.get("think_level", 1))
|
||||
return 0
|
||||
|
||||
def _get_template_name(self) -> Optional[str]:
|
||||
if self.chat_stream.context:
|
||||
return self.chat_stream.context.template_name
|
||||
return None
|
||||
# ====== 响应发送相关方法 ======
|
||||
async def _send_response(self, *args, **kwargs):
|
||||
raise NotImplementedError("发送回复的逻辑尚未实现") # TODO: 实现发送回复的逻辑,替换掉*args, **kwargs*占位符
|
||||
# 传入的消息至少应该是个MessageSequence实例,最好是SessionMessage实例,随后可直接转化为MessageSending实例
|
||||
|
||||
@@ -1,42 +0,0 @@
|
||||
import traceback
|
||||
from typing import Any, Optional, Dict
|
||||
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.heart_flow.heartFC_chat import HeartFChatting
|
||||
from src.chat.brain_chat.brain_chat import BrainChatting
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
|
||||
logger = get_logger("heartflow")
|
||||
|
||||
|
||||
class Heartflow:
|
||||
"""主心流协调器,负责初始化并协调聊天"""
|
||||
|
||||
def __init__(self):
|
||||
self.heartflow_chat_list: Dict[Any, HeartFChatting | BrainChatting] = {}
|
||||
|
||||
async def get_or_create_heartflow_chat(self, chat_id: Any) -> Optional[HeartFChatting | BrainChatting]:
|
||||
"""获取或创建一个新的HeartFChatting实例"""
|
||||
try:
|
||||
if chat_id in self.heartflow_chat_list:
|
||||
if chat := self.heartflow_chat_list.get(chat_id):
|
||||
return chat
|
||||
else:
|
||||
chat_stream: ChatStream | None = get_chat_manager().get_stream(chat_id)
|
||||
if not chat_stream:
|
||||
raise ValueError(f"未找到 chat_id={chat_id} 的聊天流")
|
||||
if chat_stream.group_info:
|
||||
new_chat = HeartFChatting(chat_id=chat_id)
|
||||
else:
|
||||
new_chat = BrainChatting(chat_id=chat_id)
|
||||
await new_chat.start()
|
||||
self.heartflow_chat_list[chat_id] = new_chat
|
||||
return new_chat
|
||||
except Exception as e:
|
||||
logger.error(f"创建心流聊天 {chat_id} 失败: {e}", exc_info=True)
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
|
||||
heartflow = Heartflow()
|
||||
@@ -1,19 +1,20 @@
|
||||
from contextlib import suppress
|
||||
import traceback
|
||||
import os
|
||||
|
||||
from maim_message import MessageBase
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import os
|
||||
import traceback
|
||||
|
||||
from maim_message import MessageBase
|
||||
|
||||
from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver
|
||||
from src.common.logger import get_logger
|
||||
from src.common.utils.utils_message import MessageUtils
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver
|
||||
from src.platform_io.route_key_factory import RouteKeyFactory
|
||||
|
||||
# from src.chat.brain_chat.PFC.pfc_manager import PFCManager
|
||||
from src.core.announcement_manager import global_announcement_manager
|
||||
from src.core.component_registry import component_registry
|
||||
from src.plugin_runtime.component_query import component_query_service
|
||||
|
||||
from .message import SessionMessage
|
||||
from .chat_manager import chat_manager
|
||||
@@ -58,16 +59,22 @@ class ChatBot:
|
||||
logger.error(f"创建PFC聊天失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
async def _process_commands(self, message: SessionMessage):
|
||||
# sourcery skip: use-named-expression
|
||||
"""使用新插件系统处理命令"""
|
||||
async def _process_commands(self, message: SessionMessage) -> tuple[bool, Optional[str], bool]:
|
||||
"""使用统一组件注册表处理命令。
|
||||
|
||||
Args:
|
||||
message: 当前待处理的会话消息。
|
||||
|
||||
Returns:
|
||||
tuple[bool, Optional[str], bool]: ``(是否命中命令, 命令响应文本, 是否继续后续处理)``。
|
||||
"""
|
||||
if not message.processed_plain_text:
|
||||
return False, None, True # 没有文本内容,继续处理消息
|
||||
try:
|
||||
text = message.processed_plain_text
|
||||
|
||||
# 使用核心组件注册表查找命令
|
||||
command_result = component_registry.find_command_by_text(text)
|
||||
# 使用插件运行时统一查询服务查找命令
|
||||
command_result = component_query_service.find_command_by_text(text)
|
||||
if command_result:
|
||||
command_executor, matched_groups, command_info = command_result
|
||||
plugin_name = command_info.plugin_name
|
||||
@@ -81,7 +88,7 @@ class ChatBot:
|
||||
message.is_command = True
|
||||
|
||||
# 获取插件配置
|
||||
plugin_config = component_registry.get_plugin_config(plugin_name)
|
||||
plugin_config = component_query_service.get_plugin_config(plugin_name)
|
||||
|
||||
try:
|
||||
# 调用命令执行器
|
||||
@@ -112,88 +119,32 @@ class ChatBot:
|
||||
# 命令出错时,根据命令的拦截设置决定是否继续处理消息
|
||||
return True, str(e), False # 出错时继续处理消息
|
||||
|
||||
# 没有找到旧系统命令,尝试新版本插件运行时
|
||||
new_cmd_result = await self._process_new_runtime_command(message)
|
||||
return new_cmd_result if new_cmd_result is not None else (False, None, True)
|
||||
return False, None, True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理命令时出错: {e}")
|
||||
return False, None, True # 出错时继续处理消息
|
||||
|
||||
async def _process_new_runtime_command(self, message: SessionMessage):
|
||||
"""尝试在新版本插件运行时中查找并执行命令
|
||||
|
||||
Returns:
|
||||
(found, response, continue_processing) 三元组,
|
||||
或 None 表示新运行时中也未找到匹配命令。
|
||||
"""
|
||||
from src.plugin_runtime.integration import get_plugin_runtime_manager
|
||||
|
||||
prm = get_plugin_runtime_manager()
|
||||
if not prm.is_running:
|
||||
return None
|
||||
|
||||
matched = prm.find_command_by_text(message.processed_plain_text)
|
||||
if matched is None:
|
||||
return None
|
||||
|
||||
command_name = matched["name"]
|
||||
if message.session_id and command_name in global_announcement_manager.get_disabled_chat_commands(
|
||||
message.session_id
|
||||
):
|
||||
logger.info(f"[新运行时] 用户禁用的命令,跳过处理: {matched['full_name']}")
|
||||
return False, None, True
|
||||
|
||||
message.is_command = True
|
||||
logger.info(f"[新运行时] 匹配命令: {matched['full_name']}")
|
||||
|
||||
try:
|
||||
resp = await prm.invoke_plugin(
|
||||
method="plugin.invoke_command",
|
||||
plugin_id=matched["plugin_id"],
|
||||
component_name=matched["name"],
|
||||
args={
|
||||
"text": message.processed_plain_text,
|
||||
"stream_id": message.session_id or "",
|
||||
"matched_groups": matched.get("matched_groups") or {},
|
||||
},
|
||||
timeout_ms=30000,
|
||||
)
|
||||
|
||||
payload = resp.payload
|
||||
success = payload.get("success", False)
|
||||
cmd_result = payload.get("result")
|
||||
|
||||
# 拦截位优先从命令返回值中获取(支持运行时动态决定),
|
||||
# 回退到组件 metadata 中的静态声明
|
||||
if isinstance(cmd_result, (list, tuple)) and len(cmd_result) >= 3:
|
||||
# 命令返回 (found, response_text, intercept_bool) 三元组
|
||||
response_text = cmd_result[1] if cmd_result[1] is not None else ""
|
||||
intercept = bool(cmd_result[2])
|
||||
else:
|
||||
response_text = cmd_result if cmd_result is not None else ""
|
||||
intercept = bool(matched["metadata"].get("intercept_message_level", 0))
|
||||
|
||||
self._mark_command_message(message, int(intercept))
|
||||
|
||||
if success:
|
||||
logger.info(f"[新运行时] 命令执行成功: {matched['full_name']}")
|
||||
else:
|
||||
logger.warning(f"[新运行时] 命令执行失败: {matched['full_name']} - {response_text}")
|
||||
|
||||
return True, response_text, not intercept
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[新运行时] 执行命令 {matched['full_name']} 异常: {e}", exc_info=True)
|
||||
return True, str(e), True
|
||||
|
||||
@staticmethod
|
||||
def _mark_command_message(message: SessionMessage, intercept_message_level: int) -> None:
|
||||
"""标记消息已经被命令链消费。
|
||||
|
||||
Args:
|
||||
message: 待标记的会话消息。
|
||||
intercept_message_level: 命令设置的拦截级别。
|
||||
"""
|
||||
|
||||
message.is_command = True
|
||||
message.message_info.additional_config["intercept_message_level"] = intercept_message_level
|
||||
|
||||
@staticmethod
|
||||
def _store_intercepted_command_message(message: SessionMessage) -> None:
|
||||
"""将被命令链拦截的消息写入数据库。
|
||||
|
||||
Args:
|
||||
message: 已完成命令处理的会话消息。
|
||||
"""
|
||||
|
||||
MessageUtils.store_message_to_db(message)
|
||||
|
||||
async def _handle_command_processing_result(
|
||||
@@ -310,13 +261,28 @@ class ChatBot:
|
||||
# logger.debug(str(message_data))
|
||||
maim_raw_message = MessageBase.from_dict(message_data)
|
||||
message = SessionMessage.from_maim_message(maim_raw_message)
|
||||
await self.receive_message(message)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"预处理消息失败: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
async def receive_message(self, message: SessionMessage):
|
||||
try:
|
||||
group_info = message.message_info.group_info
|
||||
user_info = message.message_info.user_info
|
||||
account_id = None
|
||||
scope = None
|
||||
additional_config = message.message_info.additional_config
|
||||
if isinstance(additional_config, dict):
|
||||
account_id, scope = RouteKeyFactory.extract_components(additional_config)
|
||||
|
||||
session_id = SessionUtils.calculate_session_id(
|
||||
message.platform,
|
||||
user_id=message.message_info.user_info.user_id,
|
||||
group_id=group_info.group_id if group_info else None,
|
||||
account_id=account_id,
|
||||
scope=scope,
|
||||
)
|
||||
|
||||
message.session_id = session_id # 正确初始化session_id
|
||||
@@ -359,24 +325,24 @@ class ChatBot:
|
||||
platform = message.platform
|
||||
user_id = user_info.user_id
|
||||
group_id = group_info.group_id if group_info else None
|
||||
_ = await chat_manager.get_or_create_session(platform, user_id, group_id) # 确保会话存在
|
||||
try:
|
||||
from src.services.memory_flow_service import memory_automation_service
|
||||
|
||||
await memory_automation_service.on_incoming_message(message)
|
||||
except Exception as exc:
|
||||
logger.warning(f"[长期记忆自动总结] 注册会话总结器失败: {exc}")
|
||||
_ = await chat_manager.get_or_create_session(
|
||||
platform,
|
||||
user_id,
|
||||
group_id,
|
||||
account_id=account_id,
|
||||
scope=scope,
|
||||
) # 确保会话存在
|
||||
|
||||
# message.update_chat_stream(chat)
|
||||
|
||||
# 命令处理 - 使用新插件系统检查并处理命令
|
||||
# 注意:命令返回的 response 当前只用于日志记录和流程判断,
|
||||
# 不会在这里自动作为回复消息发送回会话。
|
||||
is_command, cmd_result, continue_process = await self._process_commands(message)
|
||||
# is_command, cmd_result, continue_process = await self._process_commands(message)
|
||||
|
||||
# 如果是命令且不需要继续处理,则直接返回
|
||||
if is_command and await self._handle_command_processing_result(message, cmd_result, continue_process):
|
||||
return
|
||||
# # 如果是命令且不需要继续处理,则直接返回
|
||||
# if is_command and await self._handle_command_processing_result(message, cmd_result, continue_process):
|
||||
# return
|
||||
|
||||
# continue_flag, modified_message = await events_manager.handle_mai_events(EventType.ON_MESSAGE, message)
|
||||
# if not continue_flag:
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional
|
||||
|
||||
from rich.traceback import install
|
||||
from sqlmodel import select
|
||||
from typing import Optional, TYPE_CHECKING, List, Dict
|
||||
|
||||
import asyncio
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.data_models.chat_session_data_model import MaiChatSession
|
||||
from src.common.database.database_model import ChatSession
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import ChatSession
|
||||
from src.common.logger import get_logger
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
from src.platform_io.route_key_factory import RouteKeyFactory
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .message import SessionMessage
|
||||
@@ -82,7 +83,12 @@ class ChatManager:
|
||||
logger.error(f"初始化聊天管理器出现错误: {e}")
|
||||
|
||||
async def get_or_create_session(
|
||||
self, platform: str, user_id: str, group_id: Optional[str] = None
|
||||
self,
|
||||
platform: str,
|
||||
user_id: str,
|
||||
group_id: Optional[str] = None,
|
||||
account_id: Optional[str] = None,
|
||||
scope: Optional[str] = None,
|
||||
) -> BotChatSession:
|
||||
"""获取会话,如果不存在则创建一个新会话;一个封装方法。
|
||||
|
||||
@@ -90,12 +96,20 @@ class ChatManager:
|
||||
platform: 平台
|
||||
user_id: 用户ID
|
||||
group_id: 群ID(如果是群聊)
|
||||
account_id: 平台账号 ID
|
||||
scope: 路由作用域
|
||||
Returns:
|
||||
return (BotChatSession) 会话对象
|
||||
Raises:
|
||||
Exception: 获取或创建会话时发生错误
|
||||
"""
|
||||
session_id = SessionUtils.calculate_session_id(platform, user_id=user_id, group_id=group_id)
|
||||
session_id = SessionUtils.calculate_session_id(
|
||||
platform,
|
||||
user_id=user_id,
|
||||
group_id=group_id,
|
||||
account_id=account_id,
|
||||
scope=scope,
|
||||
)
|
||||
if session := self.get_session_by_session_id(session_id):
|
||||
session.update_active_time()
|
||||
return session
|
||||
@@ -131,7 +145,18 @@ class ChatManager:
|
||||
raise ValueError("消息缺少平台信息")
|
||||
user_id = message.message_info.user_info.user_id
|
||||
group_id = message.message_info.group_info.group_id if message.message_info.group_info else None
|
||||
session_id = SessionUtils.calculate_session_id(platform, user_id=user_id, group_id=group_id)
|
||||
account_id = None
|
||||
scope = None
|
||||
additional_config = message.message_info.additional_config
|
||||
if isinstance(additional_config, dict):
|
||||
account_id, scope = RouteKeyFactory.extract_components(additional_config)
|
||||
session_id = SessionUtils.calculate_session_id(
|
||||
platform,
|
||||
user_id=user_id,
|
||||
group_id=group_id,
|
||||
account_id=account_id,
|
||||
scope=scope,
|
||||
)
|
||||
message.session_id = session_id # 确保消息的session_id正确设置
|
||||
self.last_messages[session_id] = message
|
||||
|
||||
@@ -188,7 +213,12 @@ class ChatManager:
|
||||
return None
|
||||
|
||||
def get_session_by_info(
|
||||
self, platform: str, user_id: Optional[str] = None, group_id: Optional[str] = None
|
||||
self,
|
||||
platform: str,
|
||||
user_id: Optional[str] = None,
|
||||
group_id: Optional[str] = None,
|
||||
account_id: Optional[str] = None,
|
||||
scope: Optional[str] = None,
|
||||
) -> Optional[BotChatSession]:
|
||||
"""根据平台、用户ID和群ID获取对应的会话
|
||||
|
||||
@@ -196,10 +226,18 @@ class ChatManager:
|
||||
platform: 平台
|
||||
user_id: 用户ID
|
||||
group_id: 群ID(如果是群聊)
|
||||
account_id: 平台账号 ID
|
||||
scope: 路由作用域
|
||||
Returns:
|
||||
return (Optional[BotChatSession]): 会话对象,如果不存在则返回None
|
||||
"""
|
||||
session_id = SessionUtils.calculate_session_id(platform, user_id=user_id, group_id=group_id)
|
||||
session_id = SessionUtils.calculate_session_id(
|
||||
platform,
|
||||
user_id=user_id,
|
||||
group_id=group_id,
|
||||
account_id=account_id,
|
||||
scope=scope,
|
||||
)
|
||||
return self.get_session_by_session_id(session_id)
|
||||
|
||||
def get_session_by_session_id(self, session_id: str) -> Optional[BotChatSession]:
|
||||
|
||||
@@ -1,31 +1,37 @@
|
||||
from rich.traceback import install
|
||||
from typing import Optional
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
import asyncio
|
||||
import traceback
|
||||
|
||||
from rich.traceback import install
|
||||
|
||||
from src.common.message_server.api import get_global_api
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database import get_db_session
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.chat.utils.utils import calculate_typing_time, truncate_message
|
||||
from src.common.data_models.message_component_data_model import ReplyComponent
|
||||
from src.chat.utils.utils import truncate_message
|
||||
from src.chat.utils.utils import calculate_typing_time
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.logger import get_logger
|
||||
from src.common.message_server.api import get_global_api
|
||||
from src.webui.routers.chat.serializers import serialize_message_sequence
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
logger = get_logger("sender")
|
||||
|
||||
# WebUI 聊天室的消息广播器(延迟导入避免循环依赖)
|
||||
_webui_chat_broadcaster = None
|
||||
_webui_chat_broadcaster: Optional[Tuple[Any, Optional[str]]] = None
|
||||
|
||||
# 虚拟群 ID 前缀(与 chat_routes.py 保持一致)
|
||||
VIRTUAL_GROUP_ID_PREFIX = "webui_virtual_group_"
|
||||
|
||||
|
||||
# TODO: 重构完成后完成webui相关
|
||||
def get_webui_chat_broadcaster():
|
||||
"""获取 WebUI 聊天室广播器"""
|
||||
def get_webui_chat_broadcaster() -> Tuple[Any, Optional[str]]:
|
||||
"""获取 WebUI 聊天室广播器。
|
||||
|
||||
Returns:
|
||||
Tuple[Any, Optional[str]]: ``(chat_manager, platform_name)`` 二元组;
|
||||
若 WebUI 相关模块不可用,则元素会退化为 ``None``。
|
||||
"""
|
||||
global _webui_chat_broadcaster
|
||||
if _webui_chat_broadcaster is None:
|
||||
try:
|
||||
@@ -38,102 +44,35 @@ def get_webui_chat_broadcaster():
|
||||
|
||||
|
||||
def is_webui_virtual_group(group_id: str) -> bool:
|
||||
"""检查是否是 WebUI 虚拟群"""
|
||||
return group_id and group_id.startswith(VIRTUAL_GROUP_ID_PREFIX)
|
||||
|
||||
|
||||
def parse_message_segments(segment) -> list:
|
||||
"""解析消息段,转换为 WebUI 可用的格式
|
||||
|
||||
参考 NapCat 适配器的消息解析逻辑
|
||||
"""检查是否是 WebUI 虚拟群。
|
||||
|
||||
Args:
|
||||
segment: Seg 消息段对象
|
||||
group_id: 待判断的群 ID。
|
||||
|
||||
Returns:
|
||||
list: 消息段列表,每个元素为 {"type": "...", "data": ...}
|
||||
bool: 若群 ID 属于 WebUI 虚拟群则返回 ``True``。
|
||||
"""
|
||||
|
||||
result = []
|
||||
|
||||
if segment is None:
|
||||
return result
|
||||
|
||||
if segment.type == "seglist":
|
||||
# 处理消息段列表
|
||||
if segment.data:
|
||||
for seg in segment.data:
|
||||
result.extend(parse_message_segments(seg))
|
||||
elif segment.type == "text":
|
||||
# 文本消息
|
||||
if segment.data:
|
||||
result.append({"type": "text", "data": segment.data})
|
||||
elif segment.type == "image":
|
||||
# 图片消息(base64)
|
||||
if segment.data:
|
||||
result.append({"type": "image", "data": f"data:image/png;base64,{segment.data}"})
|
||||
elif segment.type == "emoji":
|
||||
# 表情包消息(base64)
|
||||
if segment.data:
|
||||
result.append({"type": "emoji", "data": f"data:image/gif;base64,{segment.data}"})
|
||||
elif segment.type == "imageurl":
|
||||
# 图片链接消息
|
||||
if segment.data:
|
||||
result.append({"type": "image", "data": segment.data})
|
||||
elif segment.type == "face":
|
||||
# 原生表情
|
||||
result.append({"type": "face", "data": segment.data})
|
||||
elif segment.type == "voice":
|
||||
# 语音消息(base64)
|
||||
if segment.data:
|
||||
result.append({"type": "voice", "data": f"data:audio/wav;base64,{segment.data}"})
|
||||
elif segment.type == "voiceurl":
|
||||
# 语音链接
|
||||
if segment.data:
|
||||
result.append({"type": "voice", "data": segment.data})
|
||||
elif segment.type == "video":
|
||||
# 视频消息(base64)
|
||||
if segment.data:
|
||||
result.append({"type": "video", "data": f"data:video/mp4;base64,{segment.data}"})
|
||||
elif segment.type == "videourl":
|
||||
# 视频链接
|
||||
if segment.data:
|
||||
result.append({"type": "video", "data": segment.data})
|
||||
elif segment.type == "music":
|
||||
# 音乐消息
|
||||
result.append({"type": "music", "data": segment.data})
|
||||
elif segment.type == "file":
|
||||
# 文件消息
|
||||
result.append({"type": "file", "data": segment.data})
|
||||
elif segment.type == "reply":
|
||||
# 回复消息
|
||||
result.append({"type": "reply", "data": segment.data})
|
||||
elif segment.type == "forward":
|
||||
# 转发消息
|
||||
forward_items = []
|
||||
if segment.data:
|
||||
for item in segment.data:
|
||||
forward_items.append(
|
||||
{
|
||||
"content": parse_message_segments(item.get("message_segment", {}))
|
||||
if isinstance(item, dict)
|
||||
else []
|
||||
}
|
||||
)
|
||||
result.append({"type": "forward", "data": forward_items})
|
||||
else:
|
||||
# 未知类型,尝试作为文本处理
|
||||
if segment.data:
|
||||
result.append({"type": "unknown", "original_type": segment.type, "data": str(segment.data)})
|
||||
|
||||
return result
|
||||
return bool(group_id) and group_id.startswith(VIRTUAL_GROUP_ID_PREFIX)
|
||||
|
||||
|
||||
async def _send_message(message: MessageSending, show_log=True) -> bool:
|
||||
"""合并后的消息发送函数,包含WS发送和日志记录"""
|
||||
async def _send_message(message: SessionMessage, show_log: bool = True) -> bool:
|
||||
"""执行统一的消息发送流程。
|
||||
|
||||
发送顺序为:
|
||||
1. WebUI 特殊链路
|
||||
2. 旧版 ``maim_message`` / API Server 链路
|
||||
|
||||
Args:
|
||||
message: 待发送的内部会话消息。
|
||||
show_log: 是否输出发送成功日志。
|
||||
|
||||
Returns:
|
||||
bool: 是否最终发送成功。
|
||||
"""
|
||||
message_preview = truncate_message(message.processed_plain_text, max_length=200)
|
||||
platform = message.platform
|
||||
group_id = message.session.group_id
|
||||
group_info = message.message_info.group_info
|
||||
group_id = group_info.group_id if group_info is not None else ""
|
||||
|
||||
try:
|
||||
# 检查是否是 WebUI 平台的消息,或者是 WebUI 虚拟群的消息
|
||||
@@ -146,7 +85,7 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
|
||||
from src.config.config import global_config
|
||||
|
||||
# 解析消息段,获取富文本内容
|
||||
message_segments = parse_message_segments(message.message_segment)
|
||||
message_segments = serialize_message_sequence(message.raw_message)
|
||||
|
||||
# 判断消息类型
|
||||
# 如果只有一个文本段,使用简单的 text 类型
|
||||
@@ -185,7 +124,15 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
|
||||
return True
|
||||
|
||||
# Fallback 逻辑: 尝试通过 API Server 发送
|
||||
async def send_with_new_api(legacy_exception=None):
|
||||
async def send_with_new_api(legacy_exception: Optional[Exception] = None) -> bool:
|
||||
"""通过 API Server 回退链路发送消息。
|
||||
|
||||
Args:
|
||||
legacy_exception: 旧发送链已经抛出的异常;若回退也失败,则重新抛出。
|
||||
|
||||
Returns:
|
||||
bool: 回退链路是否发送成功。
|
||||
"""
|
||||
try:
|
||||
from src.config.config import global_config
|
||||
|
||||
@@ -286,10 +233,24 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
|
||||
raise e # 重新抛出其他异常
|
||||
|
||||
|
||||
class UniversalMessageSender:
|
||||
"""管理消息的注册、即时处理、发送和存储,并跟踪思考状态。"""
|
||||
async def send_prepared_message_to_platform(message: SessionMessage, show_log: bool = True) -> bool:
|
||||
"""发送一条已完成预处理的消息到底层平台。
|
||||
|
||||
def __init__(self):
|
||||
Args:
|
||||
message: 已经完成回复组件注入、文本处理等预处理的消息对象。
|
||||
show_log: 是否输出发送成功日志。
|
||||
|
||||
Returns:
|
||||
bool: 发送成功时返回 ``True``。
|
||||
"""
|
||||
return await _send_message(message, show_log=show_log)
|
||||
|
||||
|
||||
class UniversalMessageSender:
|
||||
"""旧链与 WebUI 的底层发送器。"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""初始化统一消息发送器。"""
|
||||
pass
|
||||
|
||||
async def send_message(
|
||||
@@ -300,18 +261,19 @@ class UniversalMessageSender:
|
||||
reply_message_id: Optional[str] = None,
|
||||
storage_message: bool = True,
|
||||
show_log: bool = True,
|
||||
):
|
||||
"""
|
||||
处理、发送并存储一条消息。
|
||||
) -> bool:
|
||||
"""通过旧链或 WebUI 发送并存储一条消息。
|
||||
|
||||
参数:
|
||||
message: MessageSession 对象,待发送的消息。
|
||||
Args:
|
||||
message: 待发送的内部消息对象。
|
||||
typing: 是否模拟打字等待。
|
||||
set_reply: 是否构建回复引用消息。
|
||||
set_reply: 是否构建引用回复消息。
|
||||
reply_message_id: 被引用消息的 ID。
|
||||
storage_message: 是否在发送成功后写入数据库。
|
||||
show_log: 是否输出发送日志。
|
||||
|
||||
|
||||
用法:
|
||||
- typing=True 时,发送前会有打字等待。
|
||||
Returns:
|
||||
bool: 发送成功时返回 ``True``。
|
||||
"""
|
||||
if not message.message_id:
|
||||
logger.error("消息缺少 message_id,无法发送")
|
||||
@@ -364,7 +326,7 @@ class UniversalMessageSender:
|
||||
)
|
||||
await asyncio.sleep(typing_time)
|
||||
|
||||
sent_msg = await _send_message(message, show_log=show_log)
|
||||
sent_msg = await send_prepared_message_to_platform(message, show_log=show_log)
|
||||
if not sent_msg:
|
||||
return False
|
||||
|
||||
|
||||
@@ -3,8 +3,8 @@ from typing import Dict, Optional, Tuple
|
||||
from src.chat.message_receive.chat_manager import BotChatSession
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.common.logger import get_logger
|
||||
from src.core.component_registry import component_registry, ActionExecutor
|
||||
from src.core.types import ActionInfo
|
||||
from src.plugin_runtime.component_query import ActionExecutor, component_query_service
|
||||
|
||||
logger = get_logger("action_manager")
|
||||
|
||||
@@ -28,7 +28,7 @@ class ActionManager:
|
||||
"""
|
||||
动作管理器,用于管理各种类型的动作
|
||||
|
||||
使用核心组件注册表的 executor-based 模式。
|
||||
使用插件运行时统一查询服务的 executor-based 模式。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
@@ -38,7 +38,7 @@ class ActionManager:
|
||||
self._using_actions: Dict[str, ActionInfo] = {}
|
||||
|
||||
# 初始化时将默认动作加载到使用中的动作
|
||||
self._using_actions = component_registry.get_default_actions()
|
||||
self._using_actions = component_query_service.get_default_actions()
|
||||
|
||||
# === 执行Action方法 ===
|
||||
|
||||
@@ -72,17 +72,17 @@ class ActionManager:
|
||||
Optional[ActionHandle]: 执行句柄,如果动作未注册则返回 None
|
||||
"""
|
||||
try:
|
||||
executor = component_registry.get_action_executor(action_name)
|
||||
executor = component_query_service.get_action_executor(action_name)
|
||||
if not executor:
|
||||
logger.warning(f"{log_prefix} 未找到Action组件: {action_name}")
|
||||
return None
|
||||
|
||||
info = component_registry.get_action_info(action_name)
|
||||
info = component_query_service.get_action_info(action_name)
|
||||
if not info:
|
||||
logger.warning(f"{log_prefix} 未找到Action组件信息: {action_name}")
|
||||
return None
|
||||
|
||||
plugin_config = component_registry.get_plugin_config(info.plugin_name) or {}
|
||||
plugin_config = component_query_service.get_plugin_config(info.plugin_name) or {}
|
||||
|
||||
handle = ActionHandle(
|
||||
executor,
|
||||
@@ -133,5 +133,5 @@ class ActionManager:
|
||||
def restore_actions(self) -> None:
|
||||
"""恢复到默认动作集"""
|
||||
actions_to_restore = list(self._using_actions.keys())
|
||||
self._using_actions = component_registry.get_default_actions()
|
||||
self._using_actions = component_query_service.get_default_actions()
|
||||
logger.debug(f"恢复动作集: 从 {actions_to_restore} 恢复到默认动作集 {list(self._using_actions.keys())}")
|
||||
|
||||
@@ -1,33 +1,36 @@
|
||||
from collections import OrderedDict
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import contextlib
|
||||
import json
|
||||
import time
|
||||
import traceback
|
||||
import random
|
||||
import re
|
||||
import contextlib
|
||||
from typing import Dict, Optional, Tuple, List, TYPE_CHECKING, Union
|
||||
from collections import OrderedDict
|
||||
from rich.traceback import install
|
||||
from datetime import datetime
|
||||
import time
|
||||
import traceback
|
||||
|
||||
from json_repair import repair_json
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
from src.common.logger import get_logger
|
||||
from rich.traceback import install
|
||||
|
||||
from src.chat.logger.plan_reply_logger import PlanReplyLogger
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.chat.planner_actions.action_manager import ActionManager
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self
|
||||
from src.common.data_models.info_data_model import ActionPlannerInfo
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.core.types import ActionActivationType, ActionInfo, ComponentType
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.person_info.person_info import Person
|
||||
from src.plugin_runtime.component_query import component_query_service
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.services.message_service import (
|
||||
build_readable_messages_with_id,
|
||||
replace_user_references,
|
||||
get_messages_before_time_in_chat,
|
||||
replace_user_references,
|
||||
translate_pid_to_description,
|
||||
)
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self
|
||||
from src.chat.planner_actions.action_manager import ActionManager
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.core.types import ActionActivationType, ActionInfo, ComponentType
|
||||
from src.core.component_registry import component_registry
|
||||
from src.person_info.person_info import Person
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.common.data_models.info_data_model import TargetPersonInfo
|
||||
@@ -634,7 +637,7 @@ class ActionPlanner:
|
||||
current_available_actions_dict = self.action_manager.get_using_actions()
|
||||
|
||||
# 获取完整的动作信息
|
||||
all_registered_actions: Dict[str, ActionInfo] = component_registry.get_components_by_type( # type: ignore
|
||||
all_registered_actions: Dict[str, ActionInfo] = component_query_service.get_components_by_type( # type: ignore
|
||||
ComponentType.ACTION
|
||||
)
|
||||
current_available_actions = {}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import traceback
|
||||
import time
|
||||
import asyncio
|
||||
import importlib
|
||||
import random
|
||||
import re
|
||||
|
||||
@@ -16,7 +17,6 @@ from maim_message import BaseMessageInfo, MessageBase, Seg, UserInfo as MaimUser
|
||||
from src.common.data_models.mai_message_data_model import MaiMessage
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.chat.message_receive.chat_manager import BotChatSession
|
||||
from src.chat.message_receive.uni_message_sender import UniversalMessageSender
|
||||
from src.chat.utils.timer_calculator import Timer # <--- Import Timer
|
||||
from src.chat.utils.utils import get_bot_account, get_chat_type_and_target_info, is_bot_self
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
@@ -26,7 +26,7 @@ from src.services.message_service import (
|
||||
replace_user_references,
|
||||
translate_pid_to_description,
|
||||
)
|
||||
from src.bw_learner.expression_selector import expression_selector
|
||||
from src.learners.expression_selector import expression_selector
|
||||
|
||||
# from src.memory_system.memory_activator import MemoryActivator
|
||||
from src.person_info.person_info import Person
|
||||
@@ -35,8 +35,7 @@ from src.services import llm_service as llm_api
|
||||
|
||||
from src.chat.logger.plan_reply_logger import PlanReplyLogger
|
||||
from src.memory_system.memory_retrieval import init_memory_retrieval_sys, build_memory_retrieval_prompt
|
||||
from src.memory_system.retrieval_tools import get_tool_registry
|
||||
from src.bw_learner.jargon_explainer_old import explain_jargon_in_context, retrieve_concepts_with_jargon
|
||||
from src.learners.jargon_explainer_old import explain_jargon_in_context, retrieve_concepts_with_jargon
|
||||
from src.chat.utils.common_utils import TempMethodsExpression
|
||||
|
||||
init_memory_retrieval_sys()
|
||||
@@ -51,10 +50,15 @@ class DefaultReplyer:
|
||||
chat_stream: BotChatSession,
|
||||
request_type: str = "replyer",
|
||||
):
|
||||
"""初始化群聊回复器。
|
||||
|
||||
Args:
|
||||
chat_stream: 当前绑定的聊天会话。
|
||||
request_type: LLM 请求类型标识。
|
||||
"""
|
||||
self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type)
|
||||
self.chat_stream = chat_stream
|
||||
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.session_id)
|
||||
self.heart_fc_sender = UniversalMessageSender()
|
||||
|
||||
from src.chat.tool_executor import ToolExecutor
|
||||
|
||||
@@ -1129,7 +1133,10 @@ class DefaultReplyer:
|
||||
user_id=bot_user_id,
|
||||
user_nickname=global_config.bot.nickname,
|
||||
),
|
||||
additional_config={},
|
||||
additional_config={
|
||||
"platform_io_target_group_id": self.chat_stream.group_id,
|
||||
"platform_io_target_user_id": self.chat_stream.user_id,
|
||||
},
|
||||
),
|
||||
message_segment=message_segment,
|
||||
)
|
||||
@@ -1164,14 +1171,29 @@ class DefaultReplyer:
|
||||
async def get_prompt_info(self, message: str, sender: str, target: str):
|
||||
related_info = ""
|
||||
start_time = time.time()
|
||||
search_knowledge_tool = get_tool_registry().get_tool("search_long_term_memory")
|
||||
if search_knowledge_tool is None:
|
||||
logger.debug("长期记忆检索工具未注册,跳过获取知识内容")
|
||||
try:
|
||||
knowledge_module = importlib.import_module("src.plugins.built_in.knowledge.lpmm_get_knowledge")
|
||||
except ImportError:
|
||||
logger.debug("LPMM知识库工具模块不存在,跳过获取知识库内容")
|
||||
return ""
|
||||
|
||||
logger.debug(f"获取长期记忆内容,元消息:{message[:30]}...,消息长度: {len(message)}")
|
||||
search_knowledge_tool = getattr(knowledge_module, "SearchKnowledgeFromLPMMTool", None)
|
||||
if search_knowledge_tool is None:
|
||||
logger.debug("LPMM知识库工具未提供 SearchKnowledgeFromLPMMTool,跳过获取知识库内容")
|
||||
return ""
|
||||
|
||||
logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
|
||||
# 从LPMM知识库获取知识
|
||||
try:
|
||||
template_prompt = prompt_manager.get_prompt("memory_get_knowledge")
|
||||
# 检查LPMM知识库是否启用
|
||||
if not global_config.lpmm_knowledge.enable:
|
||||
logger.debug("LPMM知识库未启用,跳过获取知识库内容")
|
||||
return ""
|
||||
|
||||
if global_config.lpmm_knowledge.lpmm_mode == "agent":
|
||||
return ""
|
||||
|
||||
template_prompt = prompt_manager.get_prompt("lpmm_get_knowledge")
|
||||
template_prompt.add_context("bot_name", global_config.bot.nickname)
|
||||
template_prompt.add_context("time_now", lambda _: time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
|
||||
template_prompt.add_context("chat_history", message)
|
||||
@@ -1187,31 +1209,24 @@ class DefaultReplyer:
|
||||
# logger.info(f"工具调用提示词: {prompt}")
|
||||
# logger.info(f"工具调用: {tool_calls}")
|
||||
|
||||
if not tool_calls:
|
||||
logger.debug("模型认为不需要使用长期记忆")
|
||||
if tool_calls:
|
||||
result = await self.tool_executor.execute_tool_call(tool_calls[0])
|
||||
end_time = time.time()
|
||||
if not result or not result.get("content"):
|
||||
logger.debug("从LPMM知识库获取知识失败,返回空知识...")
|
||||
return ""
|
||||
found_knowledge_from_lpmm = result.get("content", "")
|
||||
logger.info(
|
||||
f"从LPMM知识库获取知识,相关信息:{found_knowledge_from_lpmm[:100]}...,信息长度: {len(found_knowledge_from_lpmm)}"
|
||||
)
|
||||
related_info += found_knowledge_from_lpmm
|
||||
logger.debug(f"获取知识库内容耗时: {(end_time - start_time):.3f}秒")
|
||||
logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}")
|
||||
|
||||
return f"你有以下这些**知识**:\n{related_info}\n请你**记住上面的知识**,之后可能会用到。\n"
|
||||
else:
|
||||
logger.debug("模型认为不需要使用LPMM知识库")
|
||||
return ""
|
||||
|
||||
related_chunks: List[str] = []
|
||||
for tool_call in tool_calls:
|
||||
if tool_call.func_name != "search_long_term_memory":
|
||||
continue
|
||||
tool_args = dict(tool_call.args or {})
|
||||
tool_args.setdefault("chat_id", self.chat_stream.session_id)
|
||||
result_text = await search_knowledge_tool.execute(**tool_args)
|
||||
if result_text and "未找到" not in result_text:
|
||||
related_chunks.append(result_text)
|
||||
|
||||
if not related_chunks:
|
||||
logger.debug("长期记忆未返回有效信息")
|
||||
return ""
|
||||
|
||||
related_info = "\n".join(related_chunks)
|
||||
end_time = time.time()
|
||||
logger.info(f"从长期记忆获取知识,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}")
|
||||
logger.debug(f"获取知识库内容耗时: {(end_time - start_time):.3f}秒")
|
||||
logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}")
|
||||
|
||||
return f"你有以下这些**知识**:\n{related_info}\n请你**记住上面的知识**,之后可能会用到。\n"
|
||||
except Exception as e:
|
||||
logger.error(f"获取知识库内容时发生异常: {str(e)}")
|
||||
return ""
|
||||
|
||||
@@ -16,7 +16,6 @@ from maim_message import BaseMessageInfo, MessageBase, Seg, UserInfo as MaimUser
|
||||
from src.common.data_models.mai_message_data_model import MaiMessage
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.chat.message_receive.chat_manager import BotChatSession
|
||||
from src.chat.message_receive.uni_message_sender import UniversalMessageSender
|
||||
from src.chat.utils.timer_calculator import Timer
|
||||
from src.chat.utils.utils import get_bot_account, get_chat_type_and_target_info, is_bot_self
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
@@ -27,13 +26,13 @@ from src.services.message_service import (
|
||||
replace_user_references,
|
||||
translate_pid_to_description,
|
||||
)
|
||||
from src.bw_learner.expression_selector import expression_selector
|
||||
from src.learners.expression_selector import expression_selector
|
||||
|
||||
# from src.memory_system.memory_activator import MemoryActivator
|
||||
from src.person_info.person_info import Person, is_person_known
|
||||
from src.core.types import ActionInfo, EventType
|
||||
from src.memory_system.memory_retrieval import init_memory_retrieval_sys, build_memory_retrieval_prompt
|
||||
from src.bw_learner.jargon_explainer_old import explain_jargon_in_context
|
||||
from src.learners.jargon_explainer_old import explain_jargon_in_context
|
||||
|
||||
init_memory_retrieval_sys()
|
||||
|
||||
@@ -47,10 +46,15 @@ class PrivateReplyer:
|
||||
chat_stream: BotChatSession,
|
||||
request_type: str = "replyer",
|
||||
):
|
||||
"""初始化私聊回复器。
|
||||
|
||||
Args:
|
||||
chat_stream: 当前绑定的聊天会话。
|
||||
request_type: LLM 请求类型标识。
|
||||
"""
|
||||
self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type)
|
||||
self.chat_stream = chat_stream
|
||||
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.session_id)
|
||||
self.heart_fc_sender = UniversalMessageSender()
|
||||
# self.memory_activator = MemoryActivator()
|
||||
|
||||
from src.chat.tool_executor import ToolExecutor
|
||||
@@ -970,7 +974,9 @@ class PrivateReplyer:
|
||||
user_nickname=global_config.bot.nickname,
|
||||
),
|
||||
group_info=None,
|
||||
additional_config={},
|
||||
additional_config={
|
||||
"platform_io_target_user_id": self.chat_stream.user_id,
|
||||
},
|
||||
),
|
||||
message_segment=message_segment,
|
||||
)
|
||||
|
||||
@@ -1,22 +1,20 @@
|
||||
"""
|
||||
工具执行器
|
||||
"""工具执行器。
|
||||
|
||||
独立的工具执行组件,可以直接输入聊天消息内容,
|
||||
自动判断并执行相应的工具,返回结构化的工具执行结果。
|
||||
|
||||
从 src.plugin_system.core.tool_use 迁移,使用新的核心组件注册表。
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import hashlib
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.core.announcement_manager import global_announcement_manager
|
||||
from src.core.component_registry import component_registry
|
||||
from src.llm_models.payload_content import ToolCall
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.plugin_runtime.component_query import component_query_service
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
|
||||
logger = get_logger("tool_use")
|
||||
@@ -89,7 +87,7 @@ class ToolExecutor:
|
||||
|
||||
def _get_tool_definitions(self) -> List[Dict[str, Any]]:
|
||||
"""获取 LLM 可用的工具定义列表"""
|
||||
all_tools = component_registry.get_llm_available_tools()
|
||||
all_tools = component_query_service.get_llm_available_tools()
|
||||
user_disabled_tools = global_announcement_manager.get_disabled_chat_tools(self.chat_id)
|
||||
return [info.get_llm_definition() for name, info in all_tools.items() if name not in user_disabled_tools]
|
||||
|
||||
@@ -152,7 +150,7 @@ class ToolExecutor:
|
||||
function_args = tool_call.args or {}
|
||||
function_args["llm_called"] = True
|
||||
|
||||
executor = component_registry.get_tool_executor(function_name)
|
||||
executor = component_query_service.get_tool_executor(function_name)
|
||||
if not executor:
|
||||
logger.warning(f"未知工具名称: {function_name}")
|
||||
return None
|
||||
|
||||
@@ -436,14 +436,14 @@ class StatisticOutputTask(AsyncTask):
|
||||
|
||||
@staticmethod
|
||||
def _fetch_online_time_since(query_start_time: datetime) -> list[tuple[datetime, datetime]]:
|
||||
with get_db_session() as session:
|
||||
with get_db_session(auto_commit=False) as session:
|
||||
statement = select(OnlineTime).where(col(OnlineTime.end_timestamp) >= query_start_time)
|
||||
records = session.exec(statement).all()
|
||||
return [(record.start_timestamp, record.end_timestamp) for record in records]
|
||||
|
||||
@staticmethod
|
||||
def _fetch_model_usage_since(query_start_time: datetime) -> list[dict[str, object]]:
|
||||
with get_db_session() as session:
|
||||
with get_db_session(auto_commit=False) as session:
|
||||
statement = select(ModelUsage).where(col(ModelUsage.timestamp) >= query_start_time)
|
||||
records = session.exec(statement).all()
|
||||
return [
|
||||
@@ -664,7 +664,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
}
|
||||
|
||||
query_start_timestamp = collect_period[-1][1]
|
||||
with get_db_session() as session:
|
||||
with get_db_session(auto_commit=False) as session:
|
||||
statement = select(Messages).where(col(Messages.timestamp) >= query_start_timestamp)
|
||||
messages = session.exec(statement).all()
|
||||
for message in messages:
|
||||
@@ -713,7 +713,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
# 使用 ActionRecords 中的 reply 动作次数作为回复数基准
|
||||
try:
|
||||
action_query_start_timestamp = collect_period[-1][1]
|
||||
with get_db_session() as session:
|
||||
with get_db_session(auto_commit=False) as session:
|
||||
statement = select(ActionRecord).where(col(ActionRecord.timestamp) >= action_query_start_timestamp)
|
||||
actions = session.exec(statement).all()
|
||||
for action in actions:
|
||||
@@ -1750,7 +1750,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
|
||||
# 查询消息记录
|
||||
query_start_timestamp = start_time.timestamp()
|
||||
with get_db_session() as session:
|
||||
with get_db_session(auto_commit=False) as session:
|
||||
statement = select(Messages).where(col(Messages.timestamp) >= start_time)
|
||||
messages = session.exec(statement).all()
|
||||
for message in messages:
|
||||
@@ -2131,7 +2131,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
|
||||
# 查询消息记录
|
||||
query_start_timestamp = start_time.timestamp()
|
||||
with get_db_session() as session:
|
||||
with get_db_session(auto_commit=False) as session:
|
||||
statement = select(Messages).where(col(Messages.timestamp) >= start_time)
|
||||
messages = session.exec(statement).all()
|
||||
for message in messages:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import asdict, dataclass
|
||||
from datetime import datetime
|
||||
from typing import Optional, List
|
||||
from typing import Any, List, Mapping, Optional, Sequence
|
||||
|
||||
import json
|
||||
|
||||
@@ -15,6 +15,76 @@ class GroupCardnameInfo:
|
||||
group_cardname: str
|
||||
|
||||
|
||||
def _normalize_group_cardname_item(raw_item: Mapping[str, Any]) -> Optional[GroupCardnameInfo]:
|
||||
"""将单条群名片数据规范化为统一结构。
|
||||
|
||||
Args:
|
||||
raw_item: 原始群名片字典,必须包含 `group_id` 和 `group_cardname`。
|
||||
|
||||
Returns:
|
||||
Optional[GroupCardnameInfo]: 规范化后的群名片信息;若数据不完整则返回 ``None``。
|
||||
"""
|
||||
group_id = str(raw_item.get("group_id") or "").strip()
|
||||
group_cardname = str(raw_item.get("group_cardname") or "").strip()
|
||||
if not group_id or not group_cardname:
|
||||
return None
|
||||
return GroupCardnameInfo(group_id=group_id, group_cardname=group_cardname)
|
||||
|
||||
|
||||
def parse_group_cardname_json(group_cardname_json: Optional[str]) -> Optional[List[GroupCardnameInfo]]:
|
||||
"""解析数据库中的群名片 JSON 字段。
|
||||
|
||||
Args:
|
||||
group_cardname_json: 数据库存储的群名片 JSON 字符串。
|
||||
|
||||
Returns:
|
||||
Optional[List[GroupCardnameInfo]]: 解析并规范化后的群名片列表;若字段为空或无有效项则返回 ``None``。
|
||||
|
||||
Raises:
|
||||
json.JSONDecodeError: 当 JSON 文本格式非法时抛出。
|
||||
TypeError: 当输入值类型不符合 `json.loads()` 要求时抛出。
|
||||
"""
|
||||
if not group_cardname_json:
|
||||
return None
|
||||
|
||||
raw_items = json.loads(group_cardname_json)
|
||||
if not isinstance(raw_items, list):
|
||||
return None
|
||||
|
||||
normalized_items: List[GroupCardnameInfo] = []
|
||||
for raw_item in raw_items:
|
||||
if not isinstance(raw_item, Mapping):
|
||||
continue
|
||||
if normalized_item := _normalize_group_cardname_item(raw_item):
|
||||
normalized_items.append(normalized_item)
|
||||
|
||||
return normalized_items or None
|
||||
|
||||
|
||||
def dump_group_cardname_records(
|
||||
group_cardname_records: Optional[Sequence[GroupCardnameInfo | Mapping[str, Any]]],
|
||||
) -> str:
|
||||
"""将群名片列表序列化为数据库使用的标准 JSON 字符串。
|
||||
|
||||
Args:
|
||||
group_cardname_records: 待序列化的群名片列表,支持 `GroupCardnameInfo`
|
||||
对象和包含 `group_id` / `group_cardname` 的字典。
|
||||
|
||||
Returns:
|
||||
str: 统一使用 `group_cardname` 键名的 JSON 字符串。
|
||||
"""
|
||||
normalized_items: List[GroupCardnameInfo] = []
|
||||
for raw_item in group_cardname_records or []:
|
||||
if isinstance(raw_item, GroupCardnameInfo):
|
||||
normalized_items.append(raw_item)
|
||||
continue
|
||||
if isinstance(raw_item, Mapping):
|
||||
if normalized_item := _normalize_group_cardname_item(raw_item):
|
||||
normalized_items.append(normalized_item)
|
||||
|
||||
return json.dumps([asdict(item) for item in normalized_items], ensure_ascii=False)
|
||||
|
||||
|
||||
class MaiPersonInfo(BaseDatabaseDataModel[PersonInfo]):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -58,9 +128,16 @@ class MaiPersonInfo(BaseDatabaseDataModel[PersonInfo]):
|
||||
"""最后一次被认识的时间"""
|
||||
|
||||
@classmethod
|
||||
def from_db_instance(cls, db_record: "PersonInfo"):
|
||||
nickname_json = json.loads(db_record.group_cardname) if db_record.group_cardname else None
|
||||
group_cardname_list = [GroupCardnameInfo(**item) for item in nickname_json] if nickname_json else None
|
||||
def from_db_instance(cls, db_record: "PersonInfo") -> "MaiPersonInfo":
|
||||
"""从数据库记录构造人物信息数据模型。
|
||||
|
||||
Args:
|
||||
db_record: 数据库中的人物信息记录。
|
||||
|
||||
Returns:
|
||||
MaiPersonInfo: 转换后的数据模型对象。
|
||||
"""
|
||||
group_cardname_list = parse_group_cardname_json(db_record.group_cardname)
|
||||
memory_points = json.loads(db_record.memory_points) if db_record.memory_points else None
|
||||
return cls(
|
||||
is_known=db_record.is_known,
|
||||
@@ -78,9 +155,12 @@ class MaiPersonInfo(BaseDatabaseDataModel[PersonInfo]):
|
||||
)
|
||||
|
||||
def to_db_instance(self) -> "PersonInfo":
|
||||
group_cardname = (
|
||||
json.dumps([gc.__dict__ for gc in self.group_cardname_list]) if self.group_cardname_list else None
|
||||
)
|
||||
"""将当前数据模型转换为数据库记录对象。
|
||||
|
||||
Returns:
|
||||
PersonInfo: 可直接写入数据库的模型实例。
|
||||
"""
|
||||
group_cardname = dump_group_cardname_records(self.group_cardname_list)
|
||||
return PersonInfo(
|
||||
is_known=self.is_known,
|
||||
person_id=self.person_id,
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
from typing import Optional
|
||||
from sqlalchemy import Column, Float, Enum as SQLEnum, DateTime
|
||||
from sqlmodel import SQLModel, Field, LargeBinary
|
||||
from enum import Enum
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Column, DateTime, Enum as SQLEnum, Float
|
||||
from sqlmodel import Field, LargeBinary, SQLModel
|
||||
|
||||
|
||||
class ModelUser(str, Enum):
|
||||
@@ -172,8 +173,8 @@ class Expression(SQLModel, table=True):
|
||||
|
||||
id: Optional[int] = Field(default=None, primary_key=True) # 自增主键
|
||||
|
||||
situation: str = Field(index=True, max_length=255, primary_key=True) # 情景
|
||||
style: str = Field(index=True, max_length=255, primary_key=True) # 风格
|
||||
situation: str = Field(index=True, max_length=255) # 情景
|
||||
style: str = Field(index=True, max_length=255) # 风格
|
||||
|
||||
# context: str # 上下文
|
||||
# up_content: str
|
||||
@@ -200,7 +201,7 @@ class Jargon(SQLModel, table=True):
|
||||
|
||||
id: Optional[int] = Field(default=None, primary_key=True) # 自增主键
|
||||
|
||||
content: str = Field(index=True, max_length=255, primary_key=True) # 黑话内容
|
||||
content: str = Field(index=True, max_length=255) # 黑话内容
|
||||
raw_content: Optional[str] = Field(default=None, nullable=True) # 原始内容,未处理的黑话内容,为List[str]
|
||||
|
||||
meaning: str # 黑话含义
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
# 定义模块颜色映射
|
||||
from typing import Optional, Tuple, Dict
|
||||
|
||||
import itertools
|
||||
import os
|
||||
import sys
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
|
||||
MODULE_COLORS: Dict[str, Tuple[str, Optional[str], bool]] = {
|
||||
@@ -54,15 +53,19 @@ MODULE_COLORS: Dict[str, Tuple[str, Optional[str], bool]] = {
|
||||
"component_registry": ("#ffaf00", None, False),
|
||||
"plugin_runtime.integration": ("#d75f00", None, False),
|
||||
"plugin_runtime.host.supervisor": ("#ff5f00", None, False),
|
||||
"plugin_runtime.host.runner_manager": ("#ff5f00", None, False),
|
||||
"plugin_runtime.host.rpc_server": ("#ff8700", None, False),
|
||||
"plugin_runtime.host.component_registry": ("#ffaf00", None, False),
|
||||
"plugin_runtime.host.capability_service": ("#ffd700", None, False),
|
||||
"plugin_runtime.host.event_dispatcher": ("#87d700", None, False),
|
||||
"plugin_runtime.host.workflow_executor": ("#5fd7af", None, False),
|
||||
"plugin_runtime.host.hook_dispatcher": ("#5fd7af", None, False),
|
||||
"plugin_runtime.host.message_gateway": ("#5fd7d7", None, False),
|
||||
"plugin_runtime.host.message_utils": ("#5faf87", None, False),
|
||||
"plugin_runtime.runner.main": ("#d787ff", None, False),
|
||||
"plugin_runtime.runner.rpc_client": ("#8787ff", None, False),
|
||||
"plugin_runtime.runner.manifest_validator": ("#5fafff", None, False),
|
||||
"plugin_runtime.runner.plugin_loader": ("#00afaf", None, False),
|
||||
"plugin.maibot-team.napcat-adapter": ("#00af87", None, False),
|
||||
"webui": ("#5f87ff", None, False),
|
||||
"webui.app": ("#5f87d7", None, False),
|
||||
"webui.api": ("#5fafff", None, False),
|
||||
@@ -157,15 +160,20 @@ MODULE_ALIASES = {
|
||||
"chat_history_summarizer": "聊天概括器",
|
||||
"plugin_runtime.integration": "IPC插件系统",
|
||||
"plugin_runtime.host.supervisor": "插件监督器",
|
||||
"plugin_runtime.host.runner_manager": "插件监督器",
|
||||
"plugin_runtime.host.rpc_server": "插件RPC服务",
|
||||
"plugin_runtime.host.component_registry": "插件组件注册",
|
||||
"plugin_runtime.host.capability_service": "插件能力服务",
|
||||
"plugin_runtime.host.event_dispatcher": "插件事件分发",
|
||||
"plugin_runtime.host.hook_dispatcher": "插件Hook分发",
|
||||
"plugin_runtime.host.message_gateway": "插件消息网关",
|
||||
"plugin_runtime.host.message_utils": "插件消息工具",
|
||||
"plugin_runtime.host.workflow_executor": "插件工作流",
|
||||
"plugin_runtime.runner.main": "插件运行器",
|
||||
"plugin_runtime.runner.rpc_client": "插件RPC客户端",
|
||||
"plugin_runtime.runner.manifest_validator": "插件清单校验",
|
||||
"plugin_runtime.runner.plugin_loader": "插件加载器",
|
||||
"plugin.maibot-team.napcat-adapter": "NapCat内置适配器",
|
||||
"webui": "WebUI",
|
||||
"webui.app": "WebUI应用",
|
||||
"webui.api": "WebUI接口",
|
||||
|
||||
@@ -21,7 +21,7 @@ class Server:
|
||||
self._server: Optional[UvicornServer] = None
|
||||
self.set_address(host, port)
|
||||
|
||||
def register_router(self, router: APIRouter, prefix: str = ""):
|
||||
def register_router(self, router: APIRouter, prefix: str = ""):
|
||||
"""注册路由
|
||||
|
||||
APIRouter 用于对相关的路由端点进行分组和模块化管理:
|
||||
|
||||
@@ -5,13 +5,22 @@ import hashlib
|
||||
|
||||
class SessionUtils:
|
||||
@staticmethod
|
||||
def calculate_session_id(platform: str, *, user_id: Optional[str] = None, group_id: Optional[str] = None) -> str:
|
||||
def calculate_session_id(
|
||||
platform: str,
|
||||
*,
|
||||
user_id: Optional[str] = None,
|
||||
group_id: Optional[str] = None,
|
||||
account_id: Optional[str] = None,
|
||||
scope: Optional[str] = None,
|
||||
) -> str:
|
||||
"""计算session_id
|
||||
|
||||
Args:
|
||||
platform: 平台名称
|
||||
user_id: 用户ID(如果是私聊)
|
||||
group_id: 群ID(如果是群聊)
|
||||
account_id: 当前平台账号 ID,可选
|
||||
scope: 当前路由作用域,可选
|
||||
Returns:
|
||||
str: 计算得到的会话ID
|
||||
Raises:
|
||||
@@ -19,8 +28,15 @@ class SessionUtils:
|
||||
"""
|
||||
if not user_id and not group_id:
|
||||
raise ValueError("UserID 或 GroupID 必须提供其一")
|
||||
|
||||
route_components = []
|
||||
if account_id:
|
||||
route_components.append(f"account:{account_id}")
|
||||
if scope:
|
||||
route_components.append(f"scope:{scope}")
|
||||
|
||||
if group_id:
|
||||
components = [platform, group_id]
|
||||
components = [platform, *route_components, group_id]
|
||||
else:
|
||||
components = [platform, user_id, "private"]
|
||||
components = [platform, *route_components, user_id, "private"]
|
||||
return hashlib.md5("_".join(components).encode()).hexdigest()
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Any, Callable, Mapping, Sequence, TypeVar
|
||||
|
||||
import asyncio
|
||||
import copy
|
||||
import inspect
|
||||
import sys
|
||||
|
||||
import tomlkit
|
||||
@@ -61,6 +62,7 @@ MODEL_CONFIG_VERSION: str = "1.12.0"
|
||||
logger = get_logger("config")
|
||||
|
||||
T = TypeVar("T", bound="ConfigBase")
|
||||
ConfigReloadCallback = Callable[[Sequence[str]], object] | Callable[[], object]
|
||||
|
||||
|
||||
class Config(ConfigBase):
|
||||
@@ -190,7 +192,7 @@ class ConfigManager:
|
||||
self.global_config: Config | None = None
|
||||
self.model_config: ModelConfig | None = None
|
||||
self._reload_lock: asyncio.Lock = asyncio.Lock()
|
||||
self._reload_callbacks: list[Callable[[], object]] = []
|
||||
self._reload_callbacks: list[ConfigReloadCallback] = []
|
||||
self._file_watcher: FileWatcher | None = None
|
||||
self._file_watcher_subscription_id: str | None = None
|
||||
self._hot_reload_min_interval_s: float = 1.0
|
||||
@@ -226,16 +228,125 @@ class ConfigManager:
|
||||
raise RuntimeError(t("config.model_not_initialized"))
|
||||
return self.model_config
|
||||
|
||||
def register_reload_callback(self, callback: Callable[[], object]) -> None:
|
||||
def register_reload_callback(self, callback: ConfigReloadCallback) -> None:
|
||||
"""注册配置热重载回调。
|
||||
|
||||
Args:
|
||||
callback: 配置热重载回调。允许无参回调,也允许接收
|
||||
``Sequence[str]`` 类型的变更范围列表。
|
||||
"""
|
||||
|
||||
self._reload_callbacks.append(callback)
|
||||
|
||||
def unregister_reload_callback(self, callback: Callable[[], object]) -> None:
|
||||
def unregister_reload_callback(self, callback: ConfigReloadCallback) -> None:
|
||||
"""注销配置热重载回调。
|
||||
|
||||
Args:
|
||||
callback: 先前注册过的回调对象。
|
||||
"""
|
||||
|
||||
try:
|
||||
self._reload_callbacks.remove(callback)
|
||||
except ValueError:
|
||||
return
|
||||
|
||||
async def reload_config(self) -> bool:
|
||||
@staticmethod
|
||||
def _normalize_changed_scopes(changed_scopes: Sequence[str] | None) -> tuple[str, ...]:
|
||||
"""规范化配置变更范围列表。
|
||||
|
||||
Args:
|
||||
changed_scopes: 原始配置变更范围。
|
||||
|
||||
Returns:
|
||||
tuple[str, ...]: 去重后的配置变更范围元组。
|
||||
"""
|
||||
|
||||
if not changed_scopes:
|
||||
return ("bot", "model")
|
||||
|
||||
normalized_scopes: list[str] = []
|
||||
for scope in changed_scopes:
|
||||
normalized_scope = str(scope or "").strip().lower()
|
||||
if normalized_scope not in {"bot", "model"}:
|
||||
continue
|
||||
if normalized_scope not in normalized_scopes:
|
||||
normalized_scopes.append(normalized_scope)
|
||||
return tuple(normalized_scopes)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_changed_scopes(changes: Sequence[FileChange]) -> tuple[str, ...]:
|
||||
"""根据文件变更列表推断配置变更范围。
|
||||
|
||||
Args:
|
||||
changes: 文件监听器返回的变更列表。
|
||||
|
||||
Returns:
|
||||
tuple[str, ...]: 命中的配置变更范围元组。
|
||||
"""
|
||||
|
||||
changed_scopes: list[str] = []
|
||||
for change in changes:
|
||||
file_name = change.path.name
|
||||
if file_name == "bot_config.toml" and "bot" not in changed_scopes:
|
||||
changed_scopes.append("bot")
|
||||
if file_name == "model_config.toml" and "model" not in changed_scopes:
|
||||
changed_scopes.append("model")
|
||||
return tuple(changed_scopes)
|
||||
|
||||
@staticmethod
|
||||
def _callback_accepts_scopes(callback: ConfigReloadCallback) -> bool:
|
||||
"""判断回调是否接收配置变更范围参数。
|
||||
|
||||
Args:
|
||||
callback: 待检测的回调对象。
|
||||
|
||||
Returns:
|
||||
bool: 若回调可接收一个位置参数或可变位置参数,则返回 ``True``。
|
||||
"""
|
||||
|
||||
try:
|
||||
parameters = inspect.signature(callback).parameters.values()
|
||||
except (TypeError, ValueError):
|
||||
return False
|
||||
|
||||
positional_params = {
|
||||
inspect.Parameter.POSITIONAL_ONLY,
|
||||
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||
}
|
||||
for parameter in parameters:
|
||||
if parameter.kind == inspect.Parameter.VAR_POSITIONAL:
|
||||
return True
|
||||
if parameter.kind in positional_params:
|
||||
return True
|
||||
return False
|
||||
|
||||
async def _invoke_reload_callback(
|
||||
self,
|
||||
callback: ConfigReloadCallback,
|
||||
changed_scopes: Sequence[str],
|
||||
) -> None:
|
||||
"""执行单个配置热重载回调。
|
||||
|
||||
Args:
|
||||
callback: 要执行的回调对象。
|
||||
changed_scopes: 本次热重载命中的配置范围。
|
||||
"""
|
||||
|
||||
result = callback(changed_scopes) if self._callback_accepts_scopes(callback) else callback()
|
||||
if asyncio.iscoroutine(result):
|
||||
await result
|
||||
|
||||
async def reload_config(self, changed_scopes: Sequence[str] | None = None) -> bool:
|
||||
"""重新加载主配置和模型配置。
|
||||
|
||||
Args:
|
||||
changed_scopes: 本次触发热重载的配置范围。
|
||||
|
||||
Returns:
|
||||
bool: 是否重载成功。
|
||||
"""
|
||||
|
||||
normalized_scopes = self._normalize_changed_scopes(changed_scopes)
|
||||
async with self._reload_lock:
|
||||
try:
|
||||
global_config_new, global_updated = load_config_from_file(
|
||||
@@ -265,9 +376,7 @@ class ConfigManager:
|
||||
|
||||
for callback in list(self._reload_callbacks):
|
||||
try:
|
||||
result = callback()
|
||||
if asyncio.iscoroutine(result):
|
||||
await result
|
||||
await self._invoke_reload_callback(callback, normalized_scopes)
|
||||
except Exception as exc:
|
||||
logger.warning(t("config.reload_callback_failed", error=exc))
|
||||
return True
|
||||
@@ -312,6 +421,12 @@ class ConfigManager:
|
||||
self._file_watcher = None
|
||||
|
||||
async def _handle_file_changes(self, changes: Sequence[FileChange]) -> None:
|
||||
"""处理主配置与模型配置文件变更。
|
||||
|
||||
Args:
|
||||
changes: 当前批次收集到的文件变更列表。
|
||||
"""
|
||||
|
||||
if not changes:
|
||||
return
|
||||
now_monotonic = asyncio.get_running_loop().time()
|
||||
@@ -321,7 +436,11 @@ class ConfigManager:
|
||||
self._last_hot_reload_monotonic = now_monotonic
|
||||
logger.info(t("config.file_change_detected"))
|
||||
try:
|
||||
await asyncio.wait_for(self.reload_config(), timeout=self._hot_reload_timeout_s)
|
||||
changed_scopes = self._resolve_changed_scopes(changes)
|
||||
await asyncio.wait_for(
|
||||
self.reload_config(changed_scopes=changed_scopes),
|
||||
timeout=self._hot_reload_timeout_s,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(t("config.reload_timeout", timeout_seconds=self._hot_reload_timeout_s))
|
||||
|
||||
|
||||
@@ -1633,24 +1633,6 @@ class PluginRuntimeConfig(ConfigBase):
|
||||
)
|
||||
"""启用插件系统"""
|
||||
|
||||
builtin_plugin_dir: str = Field(
|
||||
default="src/plugins/built_in",
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "folder",
|
||||
},
|
||||
)
|
||||
"""内置插件目录(相对于项目根目录)"""
|
||||
|
||||
thirdparty_plugin_dir: str = Field(
|
||||
default="plugins",
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "folder-open",
|
||||
},
|
||||
)
|
||||
"""第三方插件目录(相对于项目根目录)"""
|
||||
|
||||
health_check_interval_sec: float = Field(
|
||||
default=30.0,
|
||||
json_schema_extra={
|
||||
@@ -1678,14 +1660,14 @@ class PluginRuntimeConfig(ConfigBase):
|
||||
)
|
||||
"""等待 Runner 子进程启动并注册的超时时间(秒)"""
|
||||
|
||||
workflow_blocking_timeout_sec: float = Field(
|
||||
default=120.0,
|
||||
hook_blocking_timeout_sec: float = Field(
|
||||
default=30,
|
||||
json_schema_extra={
|
||||
"x-widget": "number",
|
||||
"x-icon": "timer",
|
||||
},
|
||||
)
|
||||
"""Workflow 阻塞步骤的全局超时上限(秒)"""
|
||||
"""Hook 阻塞步骤的全局超时上限(秒)"""
|
||||
|
||||
ipc_socket_path: str = Field(
|
||||
default="",
|
||||
@@ -1694,4 +1676,7 @@ class PluginRuntimeConfig(ConfigBase):
|
||||
"x-icon": "link",
|
||||
},
|
||||
)
|
||||
"""_wrap_\n 自定义 IPC Socket 路径(仅 Linux/macOS 生效)\n 留空则自动生成临时路径"""
|
||||
"""
|
||||
自定义 IPC Socket 路径(仅 Linux/macOS 生效)
|
||||
留空则自动生成临时路径
|
||||
"""
|
||||
|
||||
@@ -1,239 +0,0 @@
|
||||
"""
|
||||
核心组件注册表
|
||||
|
||||
面向最终架构的组件管理:
|
||||
- Action:注册 ActionInfo + 执行器(本地 callable 或 IPC 路由)
|
||||
- Command:注册正则模式 + 执行器
|
||||
- Tool:注册工具定义 + 执行器
|
||||
|
||||
不依赖任何插件基类,组件执行器是纯 async callable。
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Any, Awaitable, Callable, Dict, Optional, Pattern, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.core.types import (
|
||||
ActionInfo,
|
||||
CommandInfo,
|
||||
ComponentInfo,
|
||||
ComponentType,
|
||||
ToolInfo,
|
||||
)
|
||||
|
||||
logger = get_logger("component_registry")
|
||||
|
||||
# 执行器类型
|
||||
ActionExecutor = Callable[..., Awaitable[Any]]
|
||||
CommandExecutor = Callable[..., Awaitable[Tuple[bool, Optional[str], bool]]]
|
||||
ToolExecutor = Callable[..., Awaitable[Any]]
|
||||
|
||||
|
||||
class ComponentRegistry:
|
||||
"""核心组件注册表
|
||||
|
||||
管理 action、command、tool 三类组件。
|
||||
每个组件由「元信息 + 执行器」构成,执行器是 async callable,
|
||||
不需要继承任何基类。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# Action 注册
|
||||
self._actions: Dict[str, ActionInfo] = {}
|
||||
self._action_executors: Dict[str, ActionExecutor] = {}
|
||||
self._default_actions: Dict[str, ActionInfo] = {}
|
||||
|
||||
# Command 注册
|
||||
self._commands: Dict[str, CommandInfo] = {}
|
||||
self._command_executors: Dict[str, CommandExecutor] = {}
|
||||
self._command_patterns: Dict[Pattern, str] = {}
|
||||
|
||||
# Tool 注册
|
||||
self._tools: Dict[str, ToolInfo] = {}
|
||||
self._tool_executors: Dict[str, ToolExecutor] = {}
|
||||
self._llm_available_tools: Dict[str, ToolInfo] = {}
|
||||
|
||||
# 插件配置(plugin_name -> config dict)
|
||||
self._plugin_configs: Dict[str, dict] = {}
|
||||
|
||||
logger.info("核心组件注册表初始化完成")
|
||||
|
||||
# ========== Action ==========
|
||||
|
||||
def register_action(
|
||||
self,
|
||||
info: ActionInfo,
|
||||
executor: ActionExecutor,
|
||||
) -> bool:
|
||||
"""注册 action
|
||||
|
||||
Args:
|
||||
info: action 元信息
|
||||
executor: 执行器,async callable
|
||||
"""
|
||||
name = info.name
|
||||
if name in self._actions:
|
||||
logger.warning(f"Action {name} 已存在,跳过注册")
|
||||
return False
|
||||
|
||||
self._actions[name] = info
|
||||
self._action_executors[name] = executor
|
||||
|
||||
if info.enabled:
|
||||
self._default_actions[name] = info
|
||||
|
||||
logger.debug(f"注册 Action: {name}")
|
||||
return True
|
||||
|
||||
def get_action_info(self, name: str) -> Optional[ActionInfo]:
|
||||
return self._actions.get(name)
|
||||
|
||||
def get_action_executor(self, name: str) -> Optional[ActionExecutor]:
|
||||
return self._action_executors.get(name)
|
||||
|
||||
def get_default_actions(self) -> Dict[str, ActionInfo]:
|
||||
return self._default_actions.copy()
|
||||
|
||||
def get_all_actions(self) -> Dict[str, ActionInfo]:
|
||||
return self._actions.copy()
|
||||
|
||||
def remove_action(self, name: str) -> bool:
|
||||
if name not in self._actions:
|
||||
return False
|
||||
del self._actions[name]
|
||||
self._action_executors.pop(name, None)
|
||||
self._default_actions.pop(name, None)
|
||||
logger.debug(f"移除 Action: {name}")
|
||||
return True
|
||||
|
||||
# ========== Command ==========
|
||||
|
||||
def register_command(
|
||||
self,
|
||||
info: CommandInfo,
|
||||
executor: CommandExecutor,
|
||||
) -> bool:
|
||||
"""注册 command"""
|
||||
name = info.name
|
||||
if name in self._commands:
|
||||
logger.warning(f"Command {name} 已存在,跳过注册")
|
||||
return False
|
||||
|
||||
self._commands[name] = info
|
||||
self._command_executors[name] = executor
|
||||
|
||||
if info.enabled and info.command_pattern:
|
||||
pattern = re.compile(info.command_pattern, re.IGNORECASE | re.DOTALL)
|
||||
self._command_patterns[pattern] = name
|
||||
|
||||
logger.debug(f"注册 Command: {name}")
|
||||
return True
|
||||
|
||||
def find_command_by_text(self, text: str) -> Optional[Tuple[CommandExecutor, dict, CommandInfo]]:
|
||||
"""根据文本查找匹配的命令
|
||||
|
||||
Returns:
|
||||
(executor, matched_groups, command_info) 或 None
|
||||
"""
|
||||
candidates = [p for p in self._command_patterns if p.match(text)]
|
||||
if not candidates:
|
||||
return None
|
||||
if len(candidates) > 1:
|
||||
logger.warning(f"文本 '{text[:50]}' 匹配到多个命令模式,使用第一个")
|
||||
pattern = candidates[0]
|
||||
name = self._command_patterns[pattern]
|
||||
return (
|
||||
self._command_executors[name],
|
||||
pattern.match(text).groupdict(), # type: ignore
|
||||
self._commands[name],
|
||||
)
|
||||
|
||||
def remove_command(self, name: str) -> bool:
|
||||
if name not in self._commands:
|
||||
return False
|
||||
del self._commands[name]
|
||||
self._command_executors.pop(name, None)
|
||||
self._command_patterns = {k: v for k, v in self._command_patterns.items() if v != name}
|
||||
logger.debug(f"移除 Command: {name}")
|
||||
return True
|
||||
|
||||
# ========== Tool ==========
|
||||
|
||||
def register_tool(
|
||||
self,
|
||||
info: ToolInfo,
|
||||
executor: ToolExecutor,
|
||||
) -> bool:
|
||||
"""注册 tool"""
|
||||
name = info.name
|
||||
if name in self._tools:
|
||||
logger.warning(f"Tool {name} 已存在,跳过注册")
|
||||
return False
|
||||
|
||||
self._tools[name] = info
|
||||
self._tool_executors[name] = executor
|
||||
|
||||
if info.enabled:
|
||||
self._llm_available_tools[name] = info
|
||||
|
||||
logger.debug(f"注册 Tool: {name}")
|
||||
return True
|
||||
|
||||
def get_tool_info(self, name: str) -> Optional[ToolInfo]:
|
||||
return self._tools.get(name)
|
||||
|
||||
def get_tool_executor(self, name: str) -> Optional[ToolExecutor]:
|
||||
return self._tool_executors.get(name)
|
||||
|
||||
def get_llm_available_tools(self) -> Dict[str, ToolInfo]:
|
||||
return self._llm_available_tools.copy()
|
||||
|
||||
def get_all_tools(self) -> Dict[str, ToolInfo]:
|
||||
return self._tools.copy()
|
||||
|
||||
def remove_tool(self, name: str) -> bool:
|
||||
if name not in self._tools:
|
||||
return False
|
||||
del self._tools[name]
|
||||
self._tool_executors.pop(name, None)
|
||||
self._llm_available_tools.pop(name, None)
|
||||
logger.debug(f"移除 Tool: {name}")
|
||||
return True
|
||||
|
||||
# ========== 通用查询 ==========
|
||||
|
||||
def get_component_info(self, name: str, component_type: ComponentType) -> Optional[ComponentInfo]:
|
||||
"""获取组件元信息"""
|
||||
match component_type:
|
||||
case ComponentType.ACTION:
|
||||
return self._actions.get(name)
|
||||
case ComponentType.COMMAND:
|
||||
return self._commands.get(name)
|
||||
case ComponentType.TOOL:
|
||||
return self._tools.get(name)
|
||||
case _:
|
||||
return None
|
||||
|
||||
def get_components_by_type(self, component_type: ComponentType) -> Dict[str, ComponentInfo]:
|
||||
"""获取某类型的所有组件"""
|
||||
match component_type:
|
||||
case ComponentType.ACTION:
|
||||
return dict(self._actions)
|
||||
case ComponentType.COMMAND:
|
||||
return dict(self._commands)
|
||||
case ComponentType.TOOL:
|
||||
return dict(self._tools)
|
||||
case _:
|
||||
return {}
|
||||
|
||||
# ========== 插件配置 ==========
|
||||
|
||||
def set_plugin_config(self, plugin_name: str, config: dict) -> None:
|
||||
self._plugin_configs[plugin_name] = config
|
||||
|
||||
def get_plugin_config(self, plugin_name: str) -> Optional[dict]:
|
||||
return self._plugin_configs.get(plugin_name)
|
||||
|
||||
|
||||
# 全局单例
|
||||
component_registry = ComponentRegistry()
|
||||
@@ -3,19 +3,19 @@
|
||||
|
||||
功能:
|
||||
1. 定期随机选取指定数量的表达方式
|
||||
2. 使用LLM进行评估
|
||||
2. 使用 LLM 进行评估
|
||||
3. 通过评估的:rejected=0, checked=1
|
||||
4. 未通过评估的:rejected=1, checked=1
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
import asyncio
|
||||
import json
|
||||
import random
|
||||
from typing import List
|
||||
|
||||
from sqlmodel import select
|
||||
|
||||
from src.bw_learner.expression_review_store import get_review_state, set_review_state
|
||||
from src.learners.expression_review_store import get_review_state, set_review_state
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import Expression
|
||||
from src.common.logger import get_logger
|
||||
@@ -146,7 +146,8 @@ class ExpressionAutoCheckTask(AsyncTask):
|
||||
选中的表达方式列表
|
||||
"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
# 这里只做查询,避免退出上下文时自动提交导致 ORM 实例过期。
|
||||
with get_db_session(auto_commit=False) as session:
|
||||
statement = select(Expression)
|
||||
all_expressions = session.exec(statement).all()
|
||||
|
||||
@@ -329,7 +329,13 @@ class ExpressionLearner:
|
||||
return filtered_expressions
|
||||
|
||||
# ====== DB 操作相关 ======
|
||||
async def _upsert_expression_to_db(self, situation: str, style: str):
|
||||
async def _upsert_expression_to_db(self, situation: str, style: str) -> None:
|
||||
"""将表达方式写入数据库,存在时更新,不存在时新增。
|
||||
|
||||
Args:
|
||||
situation: 表达方式对应的使用情景。
|
||||
style: 表达方式风格。
|
||||
"""
|
||||
expr, similarity = self._find_similar_expression(situation) or (None, 0)
|
||||
if expr:
|
||||
# 根据相似度决定是否使用 LLM 总结
|
||||
@@ -340,7 +346,13 @@ class ExpressionLearner:
|
||||
# 没有找到匹配的记录,创建新记录
|
||||
self._create_expression(situation, style)
|
||||
|
||||
def _create_expression(self, situation: str, style: str):
|
||||
def _create_expression(self, situation: str, style: str) -> None:
|
||||
"""创建新的表达方式记录。
|
||||
|
||||
Args:
|
||||
situation: 表达方式对应的使用情景。
|
||||
style: 表达方式风格。
|
||||
"""
|
||||
content_list = [situation]
|
||||
try:
|
||||
with get_db_session() as db:
|
||||
@@ -353,6 +365,7 @@ class ExpressionLearner:
|
||||
last_active_time=datetime.now(),
|
||||
)
|
||||
db.add(new_expr)
|
||||
db.flush()
|
||||
except Exception as e:
|
||||
logger.error(f"创建表达方式失败: {e}")
|
||||
|
||||
@@ -448,25 +461,43 @@ class ExpressionLearner:
|
||||
def _find_similar_expression(
|
||||
self, situation: str, similarity_threshold: float = 0.75
|
||||
) -> Optional[Tuple[MaiExpression, float]]:
|
||||
"""在数据库中查找相似的表达方式"""
|
||||
"""在数据库中查找相似的表达方式。
|
||||
|
||||
Args:
|
||||
situation: 当前待匹配的情景描述。
|
||||
similarity_threshold: 认定为相似表达方式的最低相似度阈值。
|
||||
|
||||
Returns:
|
||||
Optional[Tuple[MaiExpression, float]]: 若找到最相似的表达方式,则返回
|
||||
``(表达方式对象, 相似度)``;否则返回 ``None``。
|
||||
"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
with get_db_session(auto_commit=False) as session:
|
||||
statement = select(Expression).filter_by(session_id=self.session_id)
|
||||
expressions = session.exec(statement).all()
|
||||
|
||||
best_match: Optional[Expression] = None
|
||||
best_similarity = 0.0
|
||||
best_match: Optional[MaiExpression] = None
|
||||
best_similarity = 0.0
|
||||
|
||||
for db_expression in expressions:
|
||||
expression = MaiExpression.from_db_instance(db_expression)
|
||||
candidate_situations = [expression.situation, *expression.content]
|
||||
for candidate_situation in candidate_situations:
|
||||
normalized_candidate_situation = candidate_situation.strip()
|
||||
if not normalized_candidate_situation:
|
||||
continue
|
||||
similarity = difflib.SequenceMatcher(
|
||||
None,
|
||||
situation,
|
||||
normalized_candidate_situation,
|
||||
).ratio()
|
||||
if similarity > similarity_threshold and similarity > best_similarity:
|
||||
best_similarity = similarity
|
||||
best_match = expression
|
||||
|
||||
for expr in expressions:
|
||||
content_list = json.loads(expr.content_list)
|
||||
for situation in content_list:
|
||||
similarity = difflib.SequenceMatcher(None, situation, expr.situation).ratio()
|
||||
if similarity > similarity_threshold and similarity > best_similarity:
|
||||
best_similarity = similarity
|
||||
best_match = expr
|
||||
if best_match:
|
||||
logger.debug(f"找到相似表达方式情景 [ID: {best_match.id}],相似度: {best_similarity:.2f}")
|
||||
return MaiExpression.from_db_instance(best_match), best_similarity
|
||||
logger.debug(f"找到相似表达方式情景 [ID: {best_match.item_id}],相似度: {best_similarity:.2f}")
|
||||
return best_match, best_similarity
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"查找相似表达方式失败: {e}")
|
||||
@@ -9,7 +9,7 @@ from src.config.config import global_config, model_config
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Expression
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.bw_learner.learner_utils_old import weighted_sample
|
||||
from src.learners.learner_utils_old import weighted_sample
|
||||
from src.chat.utils.common_utils import TempMethodsExpression
|
||||
|
||||
logger = get_logger("expression_selector")
|
||||
@@ -7,8 +7,8 @@ from src.common.database.database_model import Jargon
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config, global_config
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.bw_learner.jargon_explainer import search_jargon
|
||||
from src.bw_learner.learner_utils_old import (
|
||||
from src.learners.jargon_miner_old import search_jargon
|
||||
from src.learners.learner_utils_old import (
|
||||
is_bot_message,
|
||||
contains_bot_self_name,
|
||||
parse_chat_id_list,
|
||||
@@ -1,17 +1,18 @@
|
||||
from collections import OrderedDict
|
||||
from json_repair import repair_json
|
||||
from sqlmodel import select
|
||||
from typing import List, Optional, Dict, Callable, TypedDict, Set
|
||||
from typing import Callable, Dict, List, Optional, Set, TypedDict
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import random
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from json_repair import repair_json
|
||||
from sqlmodel import select
|
||||
|
||||
from src.common.data_models.jargon_data_model import MaiJargon
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import Jargon
|
||||
from src.common.data_models.jargon_data_model import MaiJargon
|
||||
from src.config.config import model_config, global_config
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
|
||||
@@ -198,7 +199,7 @@ class JargonMiner:
|
||||
|
||||
async def process_extracted_entries(
|
||||
self, entries: List[JargonEntry], person_name_filter: Optional[Callable[[str], bool]] = None
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
处理已提取的黑话条目(从 expression_learner 路由过来的)
|
||||
|
||||
@@ -229,7 +230,7 @@ class JargonMiner:
|
||||
content = entry["content"]
|
||||
raw_content_set = entry["raw_content"]
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
with get_db_session(auto_commit=False) as session:
|
||||
jargon_items = session.exec(select(Jargon).filter_by(content=content)).all()
|
||||
except Exception as e:
|
||||
logger.error(f"查询黑话 '{content}' 失败: {e}")
|
||||
@@ -273,11 +274,12 @@ class JargonMiner:
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
session.add(new_jargon)
|
||||
session.flush()
|
||||
saved += 1
|
||||
self._add_to_cache(content)
|
||||
except Exception as e:
|
||||
logger.error(f"保存新黑话 '{content}' 失败: {e}")
|
||||
continue
|
||||
finally:
|
||||
self._add_to_cache(content)
|
||||
# 固定输出提取的jargon结果,格式化为可读形式(只要有提取结果就输出)
|
||||
if uniq_entries:
|
||||
# 收集所有提取的jargon内容
|
||||
@@ -304,7 +306,13 @@ class JargonMiner:
|
||||
removed_content, _ = self.cache.popitem(last=False)
|
||||
logger.debug(f"缓存已满,移除最旧的黑话: {removed_content}")
|
||||
|
||||
def _update_jargon(self, db_jargon: Jargon, raw_content_set: Set[str]):
|
||||
def _update_jargon(self, db_jargon: Jargon, raw_content_set: Set[str]) -> None:
|
||||
"""更新已有黑话记录并写回数据库。
|
||||
|
||||
Args:
|
||||
db_jargon: 已命中的黑话 ORM 对象。
|
||||
raw_content_set: 本次新增的原始上下文集合。
|
||||
"""
|
||||
db_jargon.count += 1
|
||||
existing_raw_content: List[str] = []
|
||||
if db_jargon.raw_content:
|
||||
@@ -326,7 +334,17 @@ class JargonMiner:
|
||||
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
session.add(db_jargon)
|
||||
if db_jargon.id is None:
|
||||
raise ValueError("黑话记录缺少 id,无法更新数据库")
|
||||
statement = select(Jargon).filter_by(id=db_jargon.id).limit(1)
|
||||
if persisted_jargon := session.exec(statement).first():
|
||||
persisted_jargon.count = db_jargon.count
|
||||
persisted_jargon.raw_content = db_jargon.raw_content
|
||||
persisted_jargon.session_id_dict = db_jargon.session_id_dict
|
||||
persisted_jargon.is_global = db_jargon.is_global
|
||||
session.add(persisted_jargon)
|
||||
else:
|
||||
logger.warning(f"黑话 ID {db_jargon.id} 在数据库中未找到,无法更新")
|
||||
except Exception as e:
|
||||
logger.error(f"更新黑话 '{db_jargon.content}' 失败: {e}")
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
from src.bw_learner.expression_auto_check_task import ExpressionAutoCheckTask
|
||||
from src.learners.expression_auto_check_task import ExpressionAutoCheckTask
|
||||
from src.chat.emoji_system.emoji_manager import emoji_manager
|
||||
from src.chat.message_receive.bot import chat_bot
|
||||
from src.chat.message_receive.chat_manager import chat_manager
|
||||
|
||||
@@ -14,7 +14,7 @@ from src.common.database.database_model import ThinkingQuestion
|
||||
from src.memory_system.retrieval_tools import get_tool_registry, init_all_tools
|
||||
from src.llm_models.payload_content.message import MessageBuilder, RoleType, Message
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from src.bw_learner.jargon_explainer_old import retrieve_concepts_with_jargon
|
||||
from src.learners.jargon_explainer_old import retrieve_concepts_with_jargon
|
||||
|
||||
logger = get_logger("memory_retrieval")
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
"""
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.bw_learner.jargon_explainer_old import retrieve_concepts_with_jargon
|
||||
from src.learners.jargon_explainer_old import retrieve_concepts_with_jargon
|
||||
from .tool_registry import register_memory_retrieval_tool
|
||||
|
||||
logger = get_logger("memory_retrieval_tools")
|
||||
|
||||
@@ -1,24 +1,24 @@
|
||||
import hashlib
|
||||
from datetime import datetime
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import time
|
||||
import random
|
||||
import math
|
||||
import random
|
||||
import time
|
||||
|
||||
from json_repair import repair_json
|
||||
from typing import Union, Optional, Dict, List
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import or_
|
||||
from sqlmodel import col, select
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from src.common.data_models.person_info_data_model import dump_group_cardname_records, parse_group_cardname_json
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import PersonInfo
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from src.services.memory_service import memory_service
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
|
||||
logger = get_logger("person_info")
|
||||
@@ -28,6 +28,32 @@ relation_selection_model = LLMRequest(
|
||||
)
|
||||
|
||||
|
||||
def _to_group_cardname_records(group_cardname_json: Optional[str]) -> list[dict[str, str]]:
|
||||
"""将数据库中的群名片 JSON 转换为 `Person` 内部使用的结构。
|
||||
|
||||
Args:
|
||||
group_cardname_json: 数据库存储的群名片 JSON 字符串。
|
||||
|
||||
Returns:
|
||||
list[dict[str, str]]: 统一使用 `group_cardname` 键名的群名片列表。
|
||||
|
||||
Raises:
|
||||
json.JSONDecodeError: 当 JSON 文本格式非法时抛出。
|
||||
TypeError: 当输入值类型不符合 `json.loads()` 要求时抛出。
|
||||
"""
|
||||
group_cardname_list = parse_group_cardname_json(group_cardname_json)
|
||||
if not group_cardname_list:
|
||||
return []
|
||||
|
||||
return [
|
||||
{
|
||||
"group_id": group_cardname.group_id,
|
||||
"group_cardname": group_cardname.group_cardname,
|
||||
}
|
||||
for group_cardname in group_cardname_list
|
||||
]
|
||||
|
||||
|
||||
def get_person_id(platform: str, user_id: Union[int, str]) -> str:
|
||||
"""获取唯一id"""
|
||||
if "-" in platform:
|
||||
@@ -39,60 +65,16 @@ def get_person_id(platform: str, user_id: Union[int, str]) -> str:
|
||||
|
||||
def get_person_id_by_person_name(person_name: str) -> str:
|
||||
"""根据用户名获取用户ID"""
|
||||
clean_name = str(person_name or "").strip()
|
||||
if not clean_name:
|
||||
return ""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
statement = (
|
||||
select(PersonInfo)
|
||||
.where(
|
||||
or_(
|
||||
col(PersonInfo.person_name) == clean_name,
|
||||
col(PersonInfo.user_nickname) == clean_name,
|
||||
)
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
record = session.exec(statement).first()
|
||||
if record and record.person_id:
|
||||
return record.person_id
|
||||
|
||||
statement = (
|
||||
select(PersonInfo)
|
||||
.where(PersonInfo.group_cardname.contains(clean_name))
|
||||
.limit(1)
|
||||
)
|
||||
statement = select(PersonInfo).where(col(PersonInfo.person_name) == person_name).limit(1)
|
||||
record = session.exec(statement).first()
|
||||
return record.person_id if record else ""
|
||||
except Exception as e:
|
||||
logger.error(f"根据用户名 {clean_name} 获取用户ID时出错: {e}")
|
||||
logger.error(f"根据用户名 {person_name} 获取用户ID时出错: {e}")
|
||||
return ""
|
||||
|
||||
|
||||
def resolve_person_id_for_memory(
|
||||
*,
|
||||
person_name: str = "",
|
||||
platform: str = "",
|
||||
user_id: Optional[Union[int, str]] = None,
|
||||
) -> str:
|
||||
"""统一人物记忆链路中的 person_id 解析。
|
||||
|
||||
优先使用已知的人物名称/别名,其次退回到平台 + user_id 的稳定 ID。
|
||||
"""
|
||||
name_token = str(person_name or "").strip()
|
||||
if name_token:
|
||||
resolved = get_person_id_by_person_name(name_token)
|
||||
if resolved:
|
||||
return resolved
|
||||
|
||||
platform_token = str(platform or "").strip()
|
||||
user_token = str(user_id or "").strip()
|
||||
if platform_token and user_token:
|
||||
return get_person_id(platform_token, user_token)
|
||||
return ""
|
||||
|
||||
|
||||
def is_person_known(
|
||||
person_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
@@ -277,7 +259,7 @@ class Person:
|
||||
person.know_since = time.time()
|
||||
person.last_know = time.time()
|
||||
person.memory_points = []
|
||||
person.group_nick_name = [] # 初始化群昵称列表
|
||||
person.group_cardname_list = [] # 初始化群名片列表
|
||||
|
||||
# 如果是群聊,添加群昵称
|
||||
if group_id and group_nick_name:
|
||||
@@ -315,7 +297,7 @@ class Person:
|
||||
self.platform = platform
|
||||
self.nickname = global_config.bot.nickname
|
||||
self.person_name = global_config.bot.nickname
|
||||
self.group_nick_name: list[dict[str, str]] = []
|
||||
self.group_cardname_list: list[dict[str, str]] = []
|
||||
return
|
||||
|
||||
self.user_id = ""
|
||||
@@ -354,7 +336,7 @@ class Person:
|
||||
self.know_since = None
|
||||
self.last_know: Optional[float] = None
|
||||
self.memory_points = []
|
||||
self.group_nick_name: list[dict[str, str]] = [] # 群昵称列表,存储 {"group_id": str, "group_nick_name": str}
|
||||
self.group_cardname_list: list[dict[str, str]] = [] # 群名片列表,存储 {"group_id": str, "group_cardname": str}
|
||||
|
||||
# 从数据库加载数据
|
||||
self.load_from_database()
|
||||
@@ -454,16 +436,16 @@ class Person:
|
||||
return
|
||||
|
||||
# 检查是否已存在该群号的记录
|
||||
for item in self.group_nick_name:
|
||||
for item in self.group_cardname_list:
|
||||
if item.get("group_id") == group_id:
|
||||
# 更新现有记录
|
||||
item["group_nick_name"] = group_nick_name
|
||||
item["group_cardname"] = group_nick_name
|
||||
self.sync_to_database()
|
||||
logger.debug(f"更新用户 {self.person_id} 在群 {group_id} 的群昵称为 {group_nick_name}")
|
||||
return
|
||||
|
||||
# 添加新记录
|
||||
self.group_nick_name.append({"group_id": group_id, "group_nick_name": group_nick_name})
|
||||
self.group_cardname_list.append({"group_id": group_id, "group_cardname": group_nick_name})
|
||||
self.sync_to_database()
|
||||
logger.debug(f"添加用户 {self.person_id} 在群 {group_id} 的群昵称 {group_nick_name}")
|
||||
|
||||
@@ -498,20 +480,15 @@ class Person:
|
||||
else:
|
||||
self.memory_points = []
|
||||
|
||||
# 处理group_nick_name字段(JSON格式的列表)
|
||||
# 处理 group_cardname 字段(JSON 格式的列表)
|
||||
if record.group_cardname:
|
||||
try:
|
||||
loaded_group_nick_names = json.loads(record.group_cardname)
|
||||
# 确保是列表格式
|
||||
if isinstance(loaded_group_nick_names, list):
|
||||
self.group_nick_name = loaded_group_nick_names
|
||||
else:
|
||||
self.group_nick_name = []
|
||||
self.group_cardname_list = _to_group_cardname_records(record.group_cardname)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
logger.warning(f"解析用户 {self.person_id} 的group_cardname字段失败,使用默认值")
|
||||
self.group_nick_name = []
|
||||
self.group_cardname_list = []
|
||||
else:
|
||||
self.group_nick_name = []
|
||||
self.group_cardname_list = []
|
||||
|
||||
logger.debug(f"已从数据库加载用户 {self.person_id} 的信息")
|
||||
else:
|
||||
@@ -532,11 +509,7 @@ class Person:
|
||||
if self.memory_points
|
||||
else json.dumps([], ensure_ascii=False)
|
||||
)
|
||||
group_nickname_value = (
|
||||
json.dumps(self.group_nick_name, ensure_ascii=False)
|
||||
if self.group_nick_name
|
||||
else json.dumps([], ensure_ascii=False)
|
||||
)
|
||||
group_cardname_value = dump_group_cardname_records(self.group_cardname_list)
|
||||
first_known_time = datetime.fromtimestamp(self.know_since) if self.know_since else None
|
||||
last_known_time = datetime.fromtimestamp(self.last_know) if self.last_know else None
|
||||
|
||||
@@ -556,7 +529,7 @@ class Person:
|
||||
record.first_known_time = first_known_time
|
||||
record.last_known_time = last_known_time
|
||||
record.memory_points = memory_points_value
|
||||
record.group_nickname = group_nickname_value
|
||||
record.group_cardname = group_cardname_value
|
||||
session.add(record)
|
||||
logger.debug(f"已同步用户 {self.person_id} 的信息到数据库")
|
||||
else:
|
||||
@@ -572,7 +545,7 @@ class Person:
|
||||
first_known_time=first_known_time,
|
||||
last_known_time=last_known_time,
|
||||
memory_points=memory_points_value,
|
||||
group_nickname=group_nickname_value,
|
||||
group_cardname=group_cardname_value,
|
||||
)
|
||||
session.add(record)
|
||||
logger.debug(f"已创建用户 {self.person_id} 的信息到数据库")
|
||||
@@ -583,79 +556,79 @@ class Person:
|
||||
async def build_relationship(self, chat_content: str = "", info_type=""):
|
||||
if not self.is_known:
|
||||
return ""
|
||||
# 构建points文本
|
||||
|
||||
nickname_str = ""
|
||||
if self.person_name != self.nickname:
|
||||
nickname_str = f"(ta在{self.platform}上的昵称是{self.nickname})"
|
||||
|
||||
async def _select_traits(query_text: str, traits: List[str], limit: int = 3) -> List[str]:
|
||||
clean_traits = [trait.strip() for trait in traits if isinstance(trait, str) and trait.strip()]
|
||||
if not clean_traits:
|
||||
return []
|
||||
if not query_text:
|
||||
return clean_traits[:limit]
|
||||
relation_info = ""
|
||||
|
||||
numbered_traits = "\n".join(f"{index}. {trait}" for index, trait in enumerate(clean_traits, start=1))
|
||||
prompt = f"""当前关注内容:
|
||||
{query_text}
|
||||
points_text = ""
|
||||
category_list = self.get_all_category()
|
||||
|
||||
候选人物信息:
|
||||
{numbered_traits}
|
||||
if chat_content:
|
||||
prompt = f"""当前聊天内容:
|
||||
{chat_content}
|
||||
|
||||
请从候选人物信息中选择与当前关注内容最相关的编号,并用<>包裹输出,不要输出其他内容。
|
||||
例如:
|
||||
<1><3>
|
||||
如果都不相关,请输出<none>"""
|
||||
分类列表:
|
||||
{category_list}
|
||||
**要求**:请你根据当前聊天内容,从以下分类中选择一个与聊天内容相关的分类,并用<>包裹输出,不要输出其他内容,不要输出引号或[],严格用<>包裹:
|
||||
例如:
|
||||
<分类1><分类2><分类3>......
|
||||
如果没有相关的分类,请输出<none>"""
|
||||
|
||||
try:
|
||||
response, _ = await relation_selection_model.generate_response_async(prompt)
|
||||
selected_traits: List[str] = []
|
||||
for raw_index in extract_categories_from_response(response):
|
||||
if raw_index == "none":
|
||||
return []
|
||||
try:
|
||||
trait_index = int(raw_index) - 1
|
||||
except ValueError:
|
||||
continue
|
||||
if 0 <= trait_index < len(clean_traits):
|
||||
trait = clean_traits[trait_index]
|
||||
if trait not in selected_traits:
|
||||
selected_traits.append(trait)
|
||||
if selected_traits:
|
||||
return selected_traits[:limit]
|
||||
except Exception as e:
|
||||
logger.debug(f"筛选人物画像信息失败,使用默认画像摘要: {e}")
|
||||
response, _ = await relation_selection_model.generate_response_async(prompt)
|
||||
# print(prompt)
|
||||
# print(response)
|
||||
category_list = extract_categories_from_response(response)
|
||||
if "none" not in category_list:
|
||||
for category in category_list:
|
||||
random_memory = self.get_random_memory_by_category(category, 2)
|
||||
if random_memory:
|
||||
random_memory_str = "\n".join(
|
||||
[get_memory_content_from_memory(memory) for memory in random_memory]
|
||||
)
|
||||
points_text = f"有关 {category} 的内容:{random_memory_str}"
|
||||
break
|
||||
elif info_type:
|
||||
prompt = f"""你需要获取用户{self.person_name}的 **{info_type}** 信息。
|
||||
|
||||
return clean_traits[:limit]
|
||||
|
||||
profile = await memory_service.get_person_profile(self.person_id, limit=8)
|
||||
relation_parts: List[str] = []
|
||||
if profile.summary.strip():
|
||||
relation_parts.append(profile.summary.strip())
|
||||
|
||||
query_text = str(chat_content or info_type or "").strip()
|
||||
selected_traits = await _select_traits(query_text, profile.traits, limit=3)
|
||||
if not selected_traits and not query_text:
|
||||
selected_traits = [trait for trait in profile.traits if trait][:2]
|
||||
|
||||
for trait in selected_traits:
|
||||
clean_trait = str(trait).strip()
|
||||
if clean_trait and clean_trait not in relation_parts:
|
||||
relation_parts.append(clean_trait)
|
||||
|
||||
for evidence in profile.evidence:
|
||||
content = str(evidence.get("content", "") or "").strip()
|
||||
if content and content not in relation_parts:
|
||||
relation_parts.append(content)
|
||||
if len(relation_parts) >= 4:
|
||||
break
|
||||
现有信息类别列表:
|
||||
{category_list}
|
||||
**要求**:请你根据**{info_type}**,从以下分类中选择一个与**{info_type}**相关的分类,并用<>包裹输出,不要输出其他内容,不要输出引号或[],严格用<>包裹:
|
||||
例如:
|
||||
<分类1><分类2><分类3>......
|
||||
如果没有相关的分类,请输出<none>"""
|
||||
response, _ = await relation_selection_model.generate_response_async(prompt)
|
||||
# print(prompt)
|
||||
# print(response)
|
||||
category_list = extract_categories_from_response(response)
|
||||
if "none" not in category_list:
|
||||
for category in category_list:
|
||||
random_memory = self.get_random_memory_by_category(category, 3)
|
||||
if random_memory:
|
||||
random_memory_str = "\n".join(
|
||||
[get_memory_content_from_memory(memory) for memory in random_memory]
|
||||
)
|
||||
points_text = f"有关 {category} 的内容:{random_memory_str}"
|
||||
break
|
||||
else:
|
||||
for category in category_list:
|
||||
random_memory = self.get_random_memory_by_category(category, 1)[0]
|
||||
if random_memory:
|
||||
points_text = f"有关 {category} 的内容:{get_memory_content_from_memory(random_memory)}"
|
||||
break
|
||||
|
||||
points_info = ""
|
||||
if relation_parts:
|
||||
points_info = f"你还记得有关{self.person_name}的内容:{';'.join(relation_parts[:3])}"
|
||||
if points_text:
|
||||
points_info = f"你还记得有关{self.person_name}的内容:{points_text}"
|
||||
|
||||
if not (nickname_str or points_info):
|
||||
return ""
|
||||
return f"{self.person_name}:{nickname_str}{points_info}"
|
||||
relation_info = f"{self.person_name}:{nickname_str}{points_info}"
|
||||
|
||||
return relation_info
|
||||
|
||||
|
||||
class PersonInfoManager:
|
||||
@@ -822,7 +795,7 @@ person_info_manager = PersonInfoManager()
|
||||
|
||||
|
||||
async def store_person_memory_from_answer(person_name: str, memory_content: str, chat_id: str) -> None:
|
||||
"""将人物事实写入统一长期记忆
|
||||
"""将人物信息存入person_info的memory_points
|
||||
|
||||
Args:
|
||||
person_name: 人物名称
|
||||
@@ -830,11 +803,6 @@ async def store_person_memory_from_answer(person_name: str, memory_content: str,
|
||||
chat_id: 聊天ID
|
||||
"""
|
||||
try:
|
||||
content = str(memory_content or "").strip()
|
||||
if not content:
|
||||
logger.debug("人物记忆内容为空,跳过写入")
|
||||
return
|
||||
|
||||
# 从 chat_id 获取 session
|
||||
session = _chat_manager.get_session_by_session_id(chat_id)
|
||||
if not session:
|
||||
@@ -845,14 +813,16 @@ async def store_person_memory_from_answer(person_name: str, memory_content: str,
|
||||
|
||||
# 尝试从person_name查找person_id
|
||||
# 首先尝试通过person_name查找
|
||||
person_id = resolve_person_id_for_memory(
|
||||
person_name=person_name,
|
||||
platform=platform,
|
||||
user_id=session.user_id,
|
||||
)
|
||||
person_id = get_person_id_by_person_name(person_name)
|
||||
|
||||
if not person_id:
|
||||
logger.warning(f"无法确定person_id for person_name: {person_name}, chat_id: {chat_id}")
|
||||
return
|
||||
# 如果通过person_name找不到,尝试从 session 获取 user_id
|
||||
if platform and session.user_id:
|
||||
user_id = session.user_id
|
||||
person_id = get_person_id(platform, user_id)
|
||||
else:
|
||||
logger.warning(f"无法确定person_id for person_name: {person_name}, chat_id: {chat_id}")
|
||||
return
|
||||
|
||||
# 创建或获取Person对象
|
||||
person = Person(person_id=person_id)
|
||||
@@ -861,34 +831,39 @@ async def store_person_memory_from_answer(person_name: str, memory_content: str,
|
||||
logger.warning(f"用户 {person_name} (person_id: {person_id}) 尚未认识,无法存储记忆")
|
||||
return
|
||||
|
||||
memory_hash = hashlib.sha256(f"{person_id}\n{content}".encode("utf-8")).hexdigest()[:16]
|
||||
result = await memory_service.ingest_text(
|
||||
external_id=f"person_fact:{person_id}:{memory_hash}",
|
||||
source_type="person_fact",
|
||||
text=content,
|
||||
chat_id=chat_id,
|
||||
person_ids=[person_id],
|
||||
participants=[person.person_name or person_name],
|
||||
timestamp=time.time(),
|
||||
tags=["person_fact"],
|
||||
metadata={
|
||||
"person_id": person_id,
|
||||
"person_name": person.person_name or person_name,
|
||||
"platform": platform,
|
||||
"source": "person_info.store_person_memory_from_answer",
|
||||
},
|
||||
respect_filter=True,
|
||||
user_id=str(session.user_id or "").strip(),
|
||||
group_id=str(session.group_id or "").strip(),
|
||||
)
|
||||
# 确定记忆分类(可以根据memory_content判断,这里使用通用分类)
|
||||
category = "其他" # 默认分类,可以根据需要调整
|
||||
|
||||
if result.success:
|
||||
if result.detail == "chat_filtered":
|
||||
logger.debug(f"人物长期记忆被聊天过滤策略跳过: {person_name} (person_id: {person_id})")
|
||||
else:
|
||||
logger.info(f"成功写入人物长期记忆: {person_name} (person_id: {person_id})")
|
||||
# 记忆点格式:category:content:weight
|
||||
weight = "1.0" # 默认权重
|
||||
memory_point = f"{category}:{memory_content}:{weight}"
|
||||
|
||||
# 添加到memory_points
|
||||
if not person.memory_points:
|
||||
person.memory_points = []
|
||||
|
||||
# 检查是否已存在相似的记忆点(避免重复)
|
||||
is_duplicate = False
|
||||
for existing_point in person.memory_points:
|
||||
if existing_point and isinstance(existing_point, str):
|
||||
parts = existing_point.split(":", 2)
|
||||
if len(parts) >= 2:
|
||||
existing_content = parts[1].strip()
|
||||
# 简单相似度检查(如果内容相同或非常相似,则跳过)
|
||||
if (
|
||||
existing_content == memory_content
|
||||
or memory_content in existing_content
|
||||
or existing_content in memory_content
|
||||
):
|
||||
is_duplicate = True
|
||||
break
|
||||
|
||||
if not is_duplicate:
|
||||
person.memory_points.append(memory_point)
|
||||
person.sync_to_database()
|
||||
logger.info(f"成功添加记忆点到 {person_name} (person_id: {person_id}): {memory_point}")
|
||||
else:
|
||||
logger.warning(f"写入人物长期记忆失败: {person_name} (person_id: {person_id}) | {result.detail}")
|
||||
logger.debug(f"记忆点已存在,跳过: {memory_point}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"存储人物记忆失败: {e}")
|
||||
|
||||
34
src/platform_io/__init__.py
Normal file
34
src/platform_io/__init__.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""导出 Platform IO 层的公开入口。
|
||||
|
||||
当前仍处于地基阶段,调用方应优先从这里导入共享类型和全局管理器,
|
||||
而不是直接依赖更底层的私有子模块。
|
||||
"""
|
||||
|
||||
from .manager import PlatformIOManager, get_platform_io_manager
|
||||
from .route_key_factory import RouteKeyFactory
|
||||
from .routing import RouteTable
|
||||
from .types import (
|
||||
DeliveryBatch,
|
||||
DeliveryReceipt,
|
||||
DeliveryStatus,
|
||||
DriverDescriptor,
|
||||
DriverKind,
|
||||
InboundMessageEnvelope,
|
||||
RouteBinding,
|
||||
RouteKey,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"DeliveryBatch",
|
||||
"DeliveryReceipt",
|
||||
"DeliveryStatus",
|
||||
"DriverDescriptor",
|
||||
"DriverKind",
|
||||
"InboundMessageEnvelope",
|
||||
"PlatformIOManager",
|
||||
"RouteKeyFactory",
|
||||
"RouteBinding",
|
||||
"RouteKey",
|
||||
"RouteTable",
|
||||
"get_platform_io_manager",
|
||||
]
|
||||
133
src/platform_io/dedupe.py
Normal file
133
src/platform_io/dedupe.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""提供 Platform IO 的轻量入站消息去重能力。
|
||||
|
||||
当前实现基于 ``dict + heapq``:
|
||||
- ``dict`` 保存去重键到过期时间的映射
|
||||
- ``heapq`` 维护按过期时间排序的小顶堆
|
||||
|
||||
这样就不需要在每次检查时全表扫描,而是通过懒清理逐步弹出已经过期
|
||||
或已经失效的堆节点。
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import heapq
|
||||
import time
|
||||
|
||||
|
||||
class MessageDeduplicator:
|
||||
"""使用基于 TTL 的内存缓存进行入站消息去重。
|
||||
|
||||
主要用于解决同一条外部消息被重复送入 Core 的问题,例如双路径并存、
|
||||
适配器重试、重连或重复回调等场景。Broker 可以借助这个组件在进入
|
||||
Core 前先拦住重复投递,避免重复处理、重复回复和重复入库。
|
||||
|
||||
当前实现使用 ``dict + heapq`` 维护过期时间:
|
||||
- ``dict`` 负责 ``O(1)`` 级别的去重键查找
|
||||
- ``heapq`` 负责按过期时间顺序做懒清理
|
||||
|
||||
这比“每次调用都全表扫描过期项”的实现更适合高吞吐消息场景。
|
||||
|
||||
Notes:
|
||||
复杂度说明如下,设 ``n`` 为当前缓存中的有效去重键数量:
|
||||
|
||||
- 单次 ``mark_seen()`` 在常见路径下的时间复杂度接近 ``O(log n)``
|
||||
- 从长期摊还角度看,``mark_seen()`` 的时间复杂度也接近 ``O(log n)``
|
||||
- 如果某次调用恰好触发一批过期键的集中清理,则该次调用的最坏时间复杂度
|
||||
可达到 ``O(k log n)``,其中 ``k`` 为本次被弹出或清理的键数量
|
||||
- 空间复杂度为 ``O(n)``
|
||||
"""
|
||||
|
||||
def __init__(self, ttl_seconds: float = 300.0, max_entries: int = 10000) -> None:
|
||||
"""初始化去重器。
|
||||
|
||||
Args:
|
||||
ttl_seconds: 每个去重键在缓存中的保留时长,单位为秒。
|
||||
max_entries: 缓存允许保留的最大有效键数量,超出后会触发
|
||||
机会性淘汰。
|
||||
|
||||
Raises:
|
||||
ValueError: 当 ``ttl_seconds`` 或 ``max_entries`` 非正数时抛出。
|
||||
"""
|
||||
if ttl_seconds <= 0:
|
||||
raise ValueError("ttl_seconds 必须大于 0")
|
||||
if max_entries <= 0:
|
||||
raise ValueError("max_entries 必须大于 0")
|
||||
|
||||
self._ttl_seconds = ttl_seconds
|
||||
self._max_entries = max_entries
|
||||
self._expire_heap: List[Tuple[float, str]] = []
|
||||
self._seen: Dict[str, float] = {}
|
||||
|
||||
def mark_seen(self, dedupe_key: str) -> bool:
|
||||
"""标记一条去重键已经出现过。
|
||||
|
||||
Args:
|
||||
dedupe_key: 能稳定标识一条外部入站消息的去重键。
|
||||
|
||||
Returns:
|
||||
bool: 若该键在当前 TTL 窗口内首次出现则返回 ``True``,
|
||||
否则返回 ``False``。
|
||||
|
||||
Notes:
|
||||
方法会先基于小顶堆做一次懒清理,再判断当前键是否仍在有效期内。
|
||||
如果缓存已达到上限,则会优先淘汰“最早过期的仍然有效的键”。
|
||||
|
||||
复杂度方面,常见路径下该方法接近 ``O(log n)``;如果恰好需要
|
||||
集中清理一批过期键,则单次调用最坏可达到 ``O(k log n)``。
|
||||
"""
|
||||
now = time.monotonic()
|
||||
self._purge_expired(now)
|
||||
|
||||
expires_at = self._seen.get(dedupe_key)
|
||||
if expires_at is not None and expires_at > now:
|
||||
return False
|
||||
|
||||
if len(self._seen) >= self._max_entries:
|
||||
self._evict_earliest_live()
|
||||
|
||||
expires_at = now + self._ttl_seconds
|
||||
self._seen[dedupe_key] = expires_at
|
||||
heapq.heappush(self._expire_heap, (expires_at, dedupe_key))
|
||||
return True
|
||||
|
||||
def clear(self) -> None:
|
||||
"""清空全部去重缓存。"""
|
||||
self._expire_heap.clear()
|
||||
self._seen.clear()
|
||||
|
||||
def _purge_expired(self, now: float) -> None:
|
||||
"""从缓存中清理已经过期的去重键。
|
||||
|
||||
Args:
|
||||
now: 当前单调时钟时间戳。
|
||||
|
||||
Notes:
|
||||
堆中可能存在旧版本节点。例如同一个 ``dedupe_key`` 被重新写入后,
|
||||
旧的过期时间节点仍会留在堆里。这里会通过和 ``dict`` 中当前值比对,
|
||||
跳过这类失效节点。
|
||||
"""
|
||||
while self._expire_heap and self._expire_heap[0][0] <= now:
|
||||
expires_at, dedupe_key = heapq.heappop(self._expire_heap)
|
||||
current_expires_at = self._seen.get(dedupe_key)
|
||||
if current_expires_at is None:
|
||||
continue
|
||||
if current_expires_at != expires_at:
|
||||
continue
|
||||
self._seen.pop(dedupe_key, None)
|
||||
|
||||
def _evict_earliest_live(self) -> None:
|
||||
"""当缓存达到容量上限时,淘汰一条最早过期的有效键。
|
||||
|
||||
Notes:
|
||||
堆顶可能是已经过期或已失效的旧节点,因此这里同样需要循环弹出,
|
||||
直到找到一条当前仍然在 ``dict`` 中生效的键。
|
||||
"""
|
||||
while self._expire_heap:
|
||||
expires_at, dedupe_key = heapq.heappop(self._expire_heap)
|
||||
current_expires_at = self._seen.get(dedupe_key)
|
||||
if current_expires_at is None:
|
||||
continue
|
||||
if current_expires_at != expires_at:
|
||||
continue
|
||||
self._seen.pop(dedupe_key, None)
|
||||
return
|
||||
11
src/platform_io/drivers/__init__.py
Normal file
11
src/platform_io/drivers/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
"""导出 Platform IO 层的公开驱动类型。"""
|
||||
|
||||
from .base import PlatformIODriver
|
||||
from .legacy_driver import LegacyPlatformDriver
|
||||
from .plugin_driver import PluginPlatformDriver
|
||||
|
||||
__all__ = [
|
||||
"LegacyPlatformDriver",
|
||||
"PlatformIODriver",
|
||||
"PluginPlatformDriver",
|
||||
]
|
||||
104
src/platform_io/drivers/base.py
Normal file
104
src/platform_io/drivers/base.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""定义 Platform IO 传输驱动的基础抽象协议。"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional
|
||||
|
||||
from src.platform_io.types import DeliveryReceipt, DriverDescriptor, InboundMessageEnvelope, RouteKey
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
|
||||
InboundHandler = Callable[[InboundMessageEnvelope], Awaitable[bool]]
|
||||
|
||||
|
||||
class PlatformIODriver(ABC):
|
||||
"""定义所有 Platform IO 驱动都必须实现的最小契约。
|
||||
|
||||
当前实现故意保持接口很小,让中间层可以先落地,再逐步把 legacy
|
||||
与 plugin 路径的真实收发能力迁入这套协议之下。
|
||||
"""
|
||||
|
||||
def __init__(self, descriptor: DriverDescriptor) -> None:
|
||||
"""使用驱动描述对象初始化驱动。
|
||||
|
||||
Args:
|
||||
descriptor: 注册到 Broker 中的静态驱动元数据。
|
||||
"""
|
||||
self._descriptor = descriptor
|
||||
self._inbound_handler: Optional[InboundHandler] = None
|
||||
|
||||
@property
|
||||
def descriptor(self) -> DriverDescriptor:
|
||||
"""返回当前驱动的描述对象。
|
||||
|
||||
Returns:
|
||||
DriverDescriptor: 当前驱动实例对应的描述对象。
|
||||
"""
|
||||
return self._descriptor
|
||||
|
||||
@property
|
||||
def driver_id(self) -> str:
|
||||
"""返回驱动标识。
|
||||
|
||||
Returns:
|
||||
str: 当前驱动的唯一 ID。
|
||||
"""
|
||||
return self._descriptor.driver_id
|
||||
|
||||
def set_inbound_handler(self, handler: InboundHandler) -> None:
|
||||
"""注册入站消息交回 Broker 的回调函数。
|
||||
|
||||
Args:
|
||||
handler: 将规范化入站封装继续转发给 Broker 的异步回调。
|
||||
"""
|
||||
self._inbound_handler = handler
|
||||
|
||||
def clear_inbound_handler(self) -> None:
|
||||
"""清除当前注册的入站回调函数。"""
|
||||
self._inbound_handler = None
|
||||
|
||||
async def emit_inbound(self, envelope: InboundMessageEnvelope) -> bool:
|
||||
"""将一条入站封装转交给 Broker 回调。
|
||||
|
||||
Args:
|
||||
envelope: 由驱动产出的规范化入站封装。
|
||||
|
||||
Returns:
|
||||
bool: 若 Broker 接受该入站消息则返回 ``True``,否则返回 ``False``。
|
||||
"""
|
||||
|
||||
if self._inbound_handler is None:
|
||||
return False
|
||||
return await self._inbound_handler(envelope)
|
||||
|
||||
async def start(self) -> None:
|
||||
"""启动驱动生命周期。
|
||||
|
||||
子类后续若需要初始化逻辑,可以覆盖这个钩子。
|
||||
"""
|
||||
return None
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""停止驱动生命周期。
|
||||
|
||||
子类后续若需要清理逻辑,可以覆盖这个钩子。
|
||||
"""
|
||||
return None
|
||||
|
||||
@abstractmethod
|
||||
async def send_message(
|
||||
self,
|
||||
message: "SessionMessage",
|
||||
route_key: RouteKey,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> DeliveryReceipt:
|
||||
"""通过具体驱动发送一条消息。
|
||||
|
||||
Args:
|
||||
message: 要投递的内部会话消息。
|
||||
route_key: Broker 为本次投递选中的路由键。
|
||||
metadata: 本次出站投递可选的 Broker 侧元数据。
|
||||
|
||||
Returns:
|
||||
DeliveryReceipt: 规范化后的投递结果。
|
||||
"""
|
||||
92
src/platform_io/drivers/legacy_driver.py
Normal file
92
src/platform_io/drivers/legacy_driver.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""提供 Platform IO 的 legacy 传输驱动实现。"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
|
||||
from src.platform_io.drivers.base import PlatformIODriver
|
||||
from src.platform_io.types import DeliveryReceipt, DeliveryStatus, DriverDescriptor, DriverKind, RouteKey
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
|
||||
|
||||
class LegacyPlatformDriver(PlatformIODriver):
|
||||
"""面向 ``UniversalMessageSender`` 旧链的 Platform IO 驱动。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
driver_id: str,
|
||||
platform: str,
|
||||
account_id: Optional[str] = None,
|
||||
scope: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""初始化一个 legacy 驱动描述对象。
|
||||
|
||||
Args:
|
||||
driver_id: Broker 内的唯一驱动 ID。
|
||||
platform: 该 legacy 适配器链路负责的平台。
|
||||
account_id: 可选的账号 ID。
|
||||
scope: 可选的额外路由作用域。
|
||||
metadata: 可选的额外驱动元数据。
|
||||
"""
|
||||
descriptor = DriverDescriptor(
|
||||
driver_id=driver_id,
|
||||
kind=DriverKind.LEGACY,
|
||||
platform=platform,
|
||||
account_id=account_id,
|
||||
scope=scope,
|
||||
metadata=metadata or {},
|
||||
)
|
||||
super().__init__(descriptor)
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
message: "SessionMessage",
|
||||
route_key: RouteKey,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> DeliveryReceipt:
|
||||
"""通过旧链发送一条已经过预处理的消息。
|
||||
|
||||
Args:
|
||||
message: 要投递的内部会话消息。
|
||||
route_key: Broker 为本次投递选择的路由键。
|
||||
metadata: 本次出站投递可选的 Broker 侧元数据。
|
||||
|
||||
Returns:
|
||||
DeliveryReceipt: 规范化后的发送回执。
|
||||
"""
|
||||
from src.chat.message_receive.uni_message_sender import send_prepared_message_to_platform
|
||||
|
||||
show_log = False
|
||||
if isinstance(metadata, dict):
|
||||
show_log = bool(metadata.get("show_log", False))
|
||||
|
||||
try:
|
||||
sent = await send_prepared_message_to_platform(message, show_log=show_log)
|
||||
except Exception as exc:
|
||||
return DeliveryReceipt(
|
||||
internal_message_id=message.message_id,
|
||||
route_key=route_key,
|
||||
status=DeliveryStatus.FAILED,
|
||||
driver_id=self.driver_id,
|
||||
driver_kind=self.descriptor.kind,
|
||||
error=str(exc),
|
||||
)
|
||||
|
||||
if not sent:
|
||||
return DeliveryReceipt(
|
||||
internal_message_id=message.message_id,
|
||||
route_key=route_key,
|
||||
status=DeliveryStatus.FAILED,
|
||||
driver_id=self.driver_id,
|
||||
driver_kind=self.descriptor.kind,
|
||||
error="旧链发送失败",
|
||||
)
|
||||
|
||||
return DeliveryReceipt(
|
||||
internal_message_id=message.message_id,
|
||||
route_key=route_key,
|
||||
status=DeliveryStatus.SENT,
|
||||
driver_id=self.driver_id,
|
||||
driver_kind=self.descriptor.kind,
|
||||
)
|
||||
211
src/platform_io/drivers/plugin_driver.py
Normal file
211
src/platform_io/drivers/plugin_driver.py
Normal file
@@ -0,0 +1,211 @@
|
||||
"""提供 Platform IO 的插件消息网关驱动实现。"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Protocol
|
||||
|
||||
from src.platform_io.drivers.base import PlatformIODriver
|
||||
from src.platform_io.types import DeliveryReceipt, DeliveryStatus, DriverDescriptor, DriverKind, RouteKey
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
|
||||
|
||||
class _GatewaySupervisorProtocol(Protocol):
|
||||
"""消息网关驱动依赖的 Supervisor 最小协议。"""
|
||||
|
||||
async def invoke_message_gateway(
|
||||
self,
|
||||
plugin_id: str,
|
||||
component_name: str,
|
||||
args: Optional[Dict[str, Any]] = None,
|
||||
timeout_ms: int = 30000,
|
||||
) -> Any:
|
||||
"""调用插件声明的消息网关方法。"""
|
||||
|
||||
|
||||
class PluginPlatformDriver(PlatformIODriver):
|
||||
"""面向插件消息网关链路的 Platform IO 驱动。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
driver_id: str,
|
||||
platform: str,
|
||||
supervisor: _GatewaySupervisorProtocol,
|
||||
component_name: str,
|
||||
*,
|
||||
supports_send: bool,
|
||||
account_id: Optional[str] = None,
|
||||
scope: Optional[str] = None,
|
||||
plugin_id: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""初始化一个插件消息网关驱动。
|
||||
|
||||
Args:
|
||||
driver_id: Broker 内的唯一驱动 ID。
|
||||
platform: 该消息网关负责的平台名称。
|
||||
supervisor: 持有该插件的 Supervisor。
|
||||
component_name: 出站时要调用的网关组件名称。
|
||||
supports_send: 当前驱动是否具备出站能力。
|
||||
account_id: 可选的账号 ID 或 self ID。
|
||||
scope: 可选的额外路由作用域。
|
||||
plugin_id: 拥有该实现的插件 ID。
|
||||
metadata: 可选的额外驱动元数据。
|
||||
"""
|
||||
|
||||
descriptor = DriverDescriptor(
|
||||
driver_id=driver_id,
|
||||
kind=DriverKind.PLUGIN,
|
||||
platform=platform,
|
||||
account_id=account_id,
|
||||
scope=scope,
|
||||
plugin_id=plugin_id,
|
||||
metadata=metadata or {},
|
||||
)
|
||||
super().__init__(descriptor)
|
||||
self._supervisor = supervisor
|
||||
self._component_name = component_name
|
||||
self._supports_send = supports_send
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
message: "SessionMessage",
|
||||
route_key: RouteKey,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> DeliveryReceipt:
|
||||
"""通过插件消息网关发送消息。
|
||||
|
||||
Args:
|
||||
message: 要投递的内部会话消息。
|
||||
route_key: Broker 为本次投递选择的路由键。
|
||||
metadata: 可选的发送元数据。
|
||||
|
||||
Returns:
|
||||
DeliveryReceipt: 规范化后的发送回执。
|
||||
"""
|
||||
|
||||
if not self._supports_send:
|
||||
return DeliveryReceipt(
|
||||
internal_message_id=message.message_id,
|
||||
route_key=route_key,
|
||||
status=DeliveryStatus.FAILED,
|
||||
driver_id=self.driver_id,
|
||||
driver_kind=self.descriptor.kind,
|
||||
error="当前消息网关仅支持接收,不支持发送",
|
||||
)
|
||||
|
||||
from src.plugin_runtime.host.message_utils import PluginMessageUtils
|
||||
|
||||
plugin_id = self.descriptor.plugin_id or ""
|
||||
if not plugin_id:
|
||||
return DeliveryReceipt(
|
||||
internal_message_id=message.message_id,
|
||||
route_key=route_key,
|
||||
status=DeliveryStatus.FAILED,
|
||||
driver_id=self.driver_id,
|
||||
driver_kind=self.descriptor.kind,
|
||||
error="插件消息网关驱动缺少 plugin_id",
|
||||
)
|
||||
|
||||
try:
|
||||
message_dict = PluginMessageUtils._session_message_to_dict(message)
|
||||
response = await self._supervisor.invoke_message_gateway(
|
||||
plugin_id=plugin_id,
|
||||
component_name=self._component_name,
|
||||
args={
|
||||
"message": message_dict,
|
||||
"route": {
|
||||
"platform": route_key.platform,
|
||||
"account_id": route_key.account_id,
|
||||
"scope": route_key.scope,
|
||||
},
|
||||
"metadata": metadata or {},
|
||||
},
|
||||
timeout_ms=30000,
|
||||
)
|
||||
except Exception as exc:
|
||||
return DeliveryReceipt(
|
||||
internal_message_id=message.message_id,
|
||||
route_key=route_key,
|
||||
status=DeliveryStatus.FAILED,
|
||||
driver_id=self.driver_id,
|
||||
driver_kind=self.descriptor.kind,
|
||||
error=str(exc),
|
||||
)
|
||||
|
||||
return self._build_receipt(message.message_id, route_key, response)
|
||||
|
||||
def _build_receipt(self, internal_message_id: str, route_key: RouteKey, response: Any) -> DeliveryReceipt:
|
||||
"""将网关调用响应归一化为出站回执。
|
||||
|
||||
Args:
|
||||
internal_message_id: 内部消息 ID。
|
||||
route_key: 本次投递的路由键。
|
||||
response: Supervisor 返回的 RPC 响应对象。
|
||||
|
||||
Returns:
|
||||
DeliveryReceipt: 标准化后的出站回执。
|
||||
"""
|
||||
|
||||
if getattr(response, "error", None):
|
||||
error = response.error.get("message", "消息网关发送失败")
|
||||
return DeliveryReceipt(
|
||||
internal_message_id=internal_message_id,
|
||||
route_key=route_key,
|
||||
status=DeliveryStatus.FAILED,
|
||||
driver_id=self.driver_id,
|
||||
driver_kind=self.descriptor.kind,
|
||||
error=error,
|
||||
)
|
||||
|
||||
payload = getattr(response, "payload", {})
|
||||
invoke_success = bool(payload.get("success", False)) if isinstance(payload, dict) else False
|
||||
if not invoke_success:
|
||||
return DeliveryReceipt(
|
||||
internal_message_id=internal_message_id,
|
||||
route_key=route_key,
|
||||
status=DeliveryStatus.FAILED,
|
||||
driver_id=self.driver_id,
|
||||
driver_kind=self.descriptor.kind,
|
||||
error=str(payload.get("result", "消息网关发送失败")) if isinstance(payload, dict) else "消息网关发送失败",
|
||||
)
|
||||
|
||||
result = payload.get("result") if isinstance(payload, dict) else None
|
||||
if isinstance(result, dict):
|
||||
if result.get("success") is False:
|
||||
return DeliveryReceipt(
|
||||
internal_message_id=internal_message_id,
|
||||
route_key=route_key,
|
||||
status=DeliveryStatus.FAILED,
|
||||
driver_id=self.driver_id,
|
||||
driver_kind=self.descriptor.kind,
|
||||
error=str(result.get("error", "消息网关发送失败")),
|
||||
metadata=result.get("metadata", {}) if isinstance(result.get("metadata"), dict) else {},
|
||||
)
|
||||
external_message_id = str(result.get("external_message_id") or result.get("message_id") or "") or None
|
||||
return DeliveryReceipt(
|
||||
internal_message_id=internal_message_id,
|
||||
route_key=route_key,
|
||||
status=DeliveryStatus.SENT,
|
||||
driver_id=self.driver_id,
|
||||
driver_kind=self.descriptor.kind,
|
||||
external_message_id=external_message_id,
|
||||
metadata=result.get("metadata", {}) if isinstance(result.get("metadata"), dict) else {},
|
||||
)
|
||||
|
||||
if isinstance(result, str) and result.strip():
|
||||
return DeliveryReceipt(
|
||||
internal_message_id=internal_message_id,
|
||||
route_key=route_key,
|
||||
status=DeliveryStatus.SENT,
|
||||
driver_id=self.driver_id,
|
||||
driver_kind=self.descriptor.kind,
|
||||
external_message_id=result.strip(),
|
||||
)
|
||||
|
||||
return DeliveryReceipt(
|
||||
internal_message_id=internal_message_id,
|
||||
route_key=route_key,
|
||||
status=DeliveryStatus.SENT,
|
||||
driver_id=self.driver_id,
|
||||
driver_kind=self.descriptor.kind,
|
||||
)
|
||||
611
src/platform_io/manager.py
Normal file
611
src/platform_io/manager.py
Normal file
@@ -0,0 +1,611 @@
|
||||
"""提供 Platform IO 层的中心 Broker 管理器。"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.platform_io.drivers.base import PlatformIODriver
|
||||
|
||||
from .dedupe import MessageDeduplicator
|
||||
from .outbound_tracker import OutboundTracker
|
||||
from .route_key_factory import RouteKeyFactory
|
||||
from .registry import DriverRegistry
|
||||
from .routing import RouteTable
|
||||
from .types import DeliveryBatch, DeliveryReceipt, DeliveryStatus, InboundMessageEnvelope, RouteBinding, RouteKey
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
|
||||
logger = get_logger("platform_io.manager")
|
||||
|
||||
InboundDispatcher = Callable[[InboundMessageEnvelope], Awaitable[None]]
|
||||
|
||||
|
||||
class PlatformIOManager:
|
||||
"""统一协调平台消息 IO 的路由、去重与状态跟踪。
|
||||
|
||||
与旧实现不同,这个管理器不再负责“多条链路谁该接管平台”的裁决,
|
||||
只维护发送表和接收表两张轻量路由表:
|
||||
|
||||
- 发送时:解析所有命中的发送绑定并全部投递。
|
||||
- 接收时:只校验当前驱动是否已登记为可接收链路,然后全部放行给上层。
|
||||
- 去重时:仅对单条链路做技术性重放抑制,不做跨链路语义去重。
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""初始化 Broker 管理器及其内存状态。"""
|
||||
self._driver_registry = DriverRegistry()
|
||||
self._send_route_table = RouteTable()
|
||||
self._receive_route_table = RouteTable()
|
||||
self._legacy_send_drivers: Dict[str, PlatformIODriver] = {}
|
||||
self._deduplicator = MessageDeduplicator()
|
||||
self._outbound_tracker = OutboundTracker()
|
||||
self._inbound_dispatcher: Optional[InboundDispatcher] = None
|
||||
self._started = False
|
||||
|
||||
@property
|
||||
def is_started(self) -> bool:
|
||||
"""返回 Broker 当前是否已进入运行态。
|
||||
|
||||
Returns:
|
||||
bool: 若 Broker 已启动则返回 ``True``。
|
||||
"""
|
||||
return self._started
|
||||
|
||||
async def start(self) -> None:
|
||||
"""启动 Broker,并依次启动当前已注册的全部驱动。
|
||||
|
||||
Raises:
|
||||
Exception: 当某个驱动启动失败时,异常会继续上抛;已成功启动的驱动
|
||||
会被自动回滚停止。
|
||||
"""
|
||||
if self._started:
|
||||
return
|
||||
|
||||
started_drivers: List[PlatformIODriver] = []
|
||||
try:
|
||||
for driver in self._driver_registry.list():
|
||||
await driver.start()
|
||||
started_drivers.append(driver)
|
||||
except Exception:
|
||||
for driver in reversed(started_drivers):
|
||||
try:
|
||||
await driver.stop()
|
||||
except Exception:
|
||||
logger.exception(f"回滚驱动停止失败: driver_id={driver.driver_id}")
|
||||
raise
|
||||
|
||||
self._started = True
|
||||
|
||||
async def ensure_send_pipeline_ready(self) -> None:
|
||||
"""确保出站发送管线已准备就绪。
|
||||
|
||||
该方法会先同步 legacy fallback driver,再在需要时启动 Broker。
|
||||
send service 应只调用这一层准备入口,而不是自行判断旧链或插件链。
|
||||
"""
|
||||
await self._sync_legacy_send_drivers()
|
||||
if not self._started:
|
||||
await self.start()
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""停止 Broker,并按逆序停止全部已注册驱动。
|
||||
|
||||
停止完成后,会同步清空仅对当前运行周期有效的去重缓存和出站跟踪状态,
|
||||
避免下一次启动时继续沿用上一个运行周期的瞬时内存数据。
|
||||
|
||||
Raises:
|
||||
RuntimeError: 当一个或多个驱动停止失败时抛出汇总异常。
|
||||
"""
|
||||
if not self._started:
|
||||
return
|
||||
|
||||
stop_errors: List[str] = []
|
||||
for driver in reversed(self._driver_registry.list()):
|
||||
try:
|
||||
await driver.stop()
|
||||
except Exception as exc:
|
||||
stop_errors.append(f"{driver.driver_id}: {exc}")
|
||||
logger.exception(f"驱动停止失败: driver_id={driver.driver_id}")
|
||||
|
||||
self._started = False
|
||||
self._deduplicator.clear()
|
||||
self._outbound_tracker.clear()
|
||||
if stop_errors:
|
||||
raise RuntimeError(f"部分驱动停止失败: {'; '.join(stop_errors)}")
|
||||
|
||||
async def add_driver(self, driver: PlatformIODriver) -> None:
|
||||
"""向运行中的 Broker 注册并启动一个驱动。
|
||||
|
||||
如果 Broker 尚未启动,则该方法等价于 ``register_driver()``。
|
||||
|
||||
Args:
|
||||
driver: 要添加的驱动实例。
|
||||
|
||||
Raises:
|
||||
Exception: 当驱动启动失败时,注册会自动回滚,异常继续上抛。
|
||||
"""
|
||||
self._register_driver_internal(driver)
|
||||
if not self._started:
|
||||
return
|
||||
|
||||
try:
|
||||
await driver.start()
|
||||
except Exception:
|
||||
self._unregister_driver_internal(driver.driver_id)
|
||||
raise
|
||||
|
||||
async def remove_driver(self, driver_id: str) -> Optional[PlatformIODriver]:
|
||||
"""从运行中的 Broker 停止并移除一个驱动。
|
||||
|
||||
如果 Broker 尚未启动,则该方法等价于 ``unregister_driver()``。
|
||||
|
||||
Args:
|
||||
driver_id: 要移除的驱动 ID。
|
||||
|
||||
Returns:
|
||||
Optional[PlatformIODriver]: 若驱动存在,则返回被移除的驱动实例。
|
||||
|
||||
Raises:
|
||||
Exception: 当 Broker 运行中且驱动停止失败时,异常会继续上抛。
|
||||
"""
|
||||
if not self._started:
|
||||
return self.unregister_driver(driver_id)
|
||||
|
||||
driver = self._driver_registry.get(driver_id)
|
||||
if driver is None:
|
||||
return None
|
||||
|
||||
await driver.stop()
|
||||
return self._unregister_driver_internal(driver_id)
|
||||
|
||||
@property
|
||||
def driver_registry(self) -> DriverRegistry:
|
||||
"""返回管理器持有的驱动注册表。
|
||||
|
||||
Returns:
|
||||
DriverRegistry: 用于保存全部已注册驱动的注册表。
|
||||
"""
|
||||
return self._driver_registry
|
||||
|
||||
@property
|
||||
def send_route_table(self) -> RouteTable:
|
||||
"""返回发送路由表。"""
|
||||
|
||||
return self._send_route_table
|
||||
|
||||
@property
|
||||
def receive_route_table(self) -> RouteTable:
|
||||
"""返回接收路由表。"""
|
||||
|
||||
return self._receive_route_table
|
||||
|
||||
@property
|
||||
def deduplicator(self) -> MessageDeduplicator:
|
||||
"""返回管理器持有的入站去重器。
|
||||
|
||||
Returns:
|
||||
MessageDeduplicator: 用于抑制重复入站的去重器。
|
||||
"""
|
||||
return self._deduplicator
|
||||
|
||||
@property
|
||||
def outbound_tracker(self) -> OutboundTracker:
|
||||
"""返回管理器持有的出站跟踪器。
|
||||
|
||||
Returns:
|
||||
OutboundTracker: 用于记录出站 pending 状态与回执的跟踪器。
|
||||
"""
|
||||
return self._outbound_tracker
|
||||
|
||||
def set_inbound_dispatcher(self, dispatcher: InboundDispatcher) -> None:
|
||||
"""设置统一的入站分发回调。
|
||||
|
||||
Args:
|
||||
dispatcher: 接收已通过 Broker 审核的入站封装,并继续送入
|
||||
Core 下一处理阶段的异步回调。
|
||||
"""
|
||||
|
||||
self._inbound_dispatcher = dispatcher
|
||||
|
||||
def clear_inbound_dispatcher(self) -> None:
|
||||
"""清除当前的入站分发回调。"""
|
||||
self._inbound_dispatcher = None
|
||||
|
||||
@property
|
||||
def has_inbound_dispatcher(self) -> bool:
|
||||
"""返回当前是否已经配置入站分发回调。
|
||||
|
||||
Returns:
|
||||
bool: 若已经配置入站分发回调则返回 ``True``。
|
||||
"""
|
||||
return self._inbound_dispatcher is not None
|
||||
|
||||
def register_driver(self, driver: PlatformIODriver) -> None:
|
||||
"""注册驱动,并把它的入站回调挂到 Broker。
|
||||
|
||||
Args:
|
||||
driver: 要注册的驱动实例。
|
||||
|
||||
Raises:
|
||||
RuntimeError: 当 Broker 已经处于运行态时抛出。此时应改用
|
||||
``add_driver()`` 以保证驱动生命周期和注册状态一致。
|
||||
"""
|
||||
if self._started:
|
||||
raise RuntimeError("Broker 运行中不允许直接 register_driver,请改用 add_driver()")
|
||||
|
||||
self._register_driver_internal(driver)
|
||||
|
||||
def _register_driver_internal(self, driver: PlatformIODriver) -> None:
|
||||
"""执行不带运行态限制的内部驱动注册。
|
||||
|
||||
Args:
|
||||
driver: 要注册的驱动实例。
|
||||
"""
|
||||
driver.set_inbound_handler(self.accept_inbound)
|
||||
self._driver_registry.register(driver)
|
||||
|
||||
def unregister_driver(self, driver_id: str) -> Optional[PlatformIODriver]:
|
||||
"""从 Broker 注销一个驱动。
|
||||
|
||||
Args:
|
||||
driver_id: 要移除的驱动 ID。
|
||||
|
||||
Returns:
|
||||
Optional[PlatformIODriver]: 若驱动存在,则返回被移除的驱动实例。
|
||||
|
||||
Raises:
|
||||
RuntimeError: 当 Broker 已经处于运行态时抛出。此时应改用
|
||||
``remove_driver()``,避免驱动停止与路由解绑脱节。
|
||||
"""
|
||||
if self._started:
|
||||
raise RuntimeError("Broker 运行中不允许直接 unregister_driver,请改用 remove_driver()")
|
||||
|
||||
return self._unregister_driver_internal(driver_id)
|
||||
|
||||
def _unregister_driver_internal(self, driver_id: str) -> Optional[PlatformIODriver]:
|
||||
"""执行不带运行态限制的内部驱动注销。
|
||||
|
||||
Args:
|
||||
driver_id: 要移除的驱动 ID。
|
||||
|
||||
Returns:
|
||||
Optional[PlatformIODriver]: 若驱动存在,则返回被移除的驱动实例。
|
||||
"""
|
||||
removed_driver = self._driver_registry.unregister(driver_id)
|
||||
if removed_driver is None:
|
||||
return None
|
||||
|
||||
removed_driver.clear_inbound_handler()
|
||||
self._send_route_table.remove_bindings_by_driver(driver_id)
|
||||
self._receive_route_table.remove_bindings_by_driver(driver_id)
|
||||
self._legacy_send_drivers = {
|
||||
platform: driver
|
||||
for platform, driver in self._legacy_send_drivers.items()
|
||||
if driver.driver_id != driver_id
|
||||
}
|
||||
return removed_driver
|
||||
|
||||
async def _sync_legacy_send_drivers(self) -> None:
|
||||
"""根据当前配置同步 legacy fallback driver。"""
|
||||
from src.chat.utils.utils import get_all_bot_accounts
|
||||
from src.platform_io.drivers.legacy_driver import LegacyPlatformDriver
|
||||
|
||||
desired_accounts = get_all_bot_accounts()
|
||||
desired_platforms = set(desired_accounts.keys())
|
||||
current_platforms = set(self._legacy_send_drivers.keys())
|
||||
|
||||
for platform in sorted(current_platforms - desired_platforms):
|
||||
await self._remove_legacy_send_driver(platform)
|
||||
|
||||
for platform, account_id in desired_accounts.items():
|
||||
existing_driver = self._legacy_send_drivers.get(platform)
|
||||
if existing_driver is not None and existing_driver.descriptor.account_id == account_id:
|
||||
continue
|
||||
|
||||
if existing_driver is not None:
|
||||
await self._remove_legacy_send_driver(platform)
|
||||
|
||||
driver = LegacyPlatformDriver(
|
||||
driver_id=f"legacy.send.{platform}",
|
||||
platform=platform,
|
||||
account_id=account_id,
|
||||
)
|
||||
if self._started:
|
||||
await self.add_driver(driver)
|
||||
else:
|
||||
self.register_driver(driver)
|
||||
self._legacy_send_drivers[platform] = driver
|
||||
|
||||
async def _remove_legacy_send_driver(self, platform: str) -> None:
|
||||
"""移除指定平台的 legacy fallback driver。
|
||||
|
||||
Args:
|
||||
platform: 要移除的目标平台。
|
||||
"""
|
||||
driver = self._legacy_send_drivers.get(platform)
|
||||
if driver is None:
|
||||
return
|
||||
|
||||
if self._started:
|
||||
await self.remove_driver(driver.driver_id)
|
||||
else:
|
||||
self.unregister_driver(driver.driver_id)
|
||||
self._legacy_send_drivers.pop(platform, None)
|
||||
|
||||
def bind_send_route(self, binding: RouteBinding) -> None:
|
||||
"""为某个路由键绑定发送驱动。
|
||||
|
||||
Args:
|
||||
binding: 要保存的路由绑定。
|
||||
|
||||
Raises:
|
||||
ValueError: 当绑定引用了不存在的驱动,或者绑定与驱动描述不一致时抛出。
|
||||
"""
|
||||
driver = self._driver_registry.get(binding.driver_id)
|
||||
if driver is None:
|
||||
raise ValueError(f"驱动 {binding.driver_id} 未注册,无法绑定路由")
|
||||
|
||||
self._validate_binding_against_driver(binding, driver)
|
||||
self._send_route_table.bind(binding)
|
||||
|
||||
def bind_receive_route(self, binding: RouteBinding) -> None:
|
||||
"""为某个路由键绑定接收驱动。
|
||||
|
||||
Args:
|
||||
binding: 要保存的路由绑定。
|
||||
|
||||
Raises:
|
||||
ValueError: 当绑定引用了不存在的驱动,或者绑定与驱动描述不一致时抛出。
|
||||
"""
|
||||
driver = self._driver_registry.get(binding.driver_id)
|
||||
if driver is None:
|
||||
raise ValueError(f"驱动 {binding.driver_id} 未注册,无法绑定路由")
|
||||
|
||||
self._validate_binding_against_driver(binding, driver)
|
||||
self._receive_route_table.bind(binding)
|
||||
|
||||
def unbind_send_route(self, route_key: RouteKey, driver_id: Optional[str] = None) -> None:
|
||||
"""移除发送路由绑定。
|
||||
|
||||
Args:
|
||||
route_key: 要移除绑定的路由键。
|
||||
driver_id: 可选的特定驱动 ID。
|
||||
"""
|
||||
|
||||
self._send_route_table.unbind(route_key, driver_id)
|
||||
|
||||
def unbind_receive_route(self, route_key: RouteKey, driver_id: Optional[str] = None) -> None:
|
||||
"""移除接收路由绑定。
|
||||
|
||||
Args:
|
||||
route_key: 要移除绑定的路由键。
|
||||
driver_id: 可选的特定驱动 ID。
|
||||
"""
|
||||
|
||||
self._receive_route_table.unbind(route_key, driver_id)
|
||||
|
||||
def resolve_drivers(self, route_key: RouteKey) -> List[PlatformIODriver]:
|
||||
"""解析某个路由键当前命中的全部发送驱动。
|
||||
|
||||
Args:
|
||||
route_key: 要解析的路由键。
|
||||
|
||||
Returns:
|
||||
List[PlatformIODriver]: 当前命中的全部发送驱动。
|
||||
"""
|
||||
|
||||
drivers: List[PlatformIODriver] = []
|
||||
seen_driver_ids: set[str] = set()
|
||||
for binding in self._send_route_table.resolve_bindings(route_key):
|
||||
driver = self._driver_registry.get(binding.driver_id)
|
||||
if driver is not None and driver.driver_id not in seen_driver_ids:
|
||||
drivers.append(driver)
|
||||
seen_driver_ids.add(driver.driver_id)
|
||||
|
||||
fallback_driver = self._legacy_send_drivers.get(route_key.platform)
|
||||
if fallback_driver is not None:
|
||||
descriptor = fallback_driver.descriptor
|
||||
account_matches = descriptor.account_id is None or route_key.account_id in (None, descriptor.account_id)
|
||||
scope_matches = descriptor.scope is None or route_key.scope in (None, descriptor.scope)
|
||||
if account_matches and scope_matches and fallback_driver.driver_id not in seen_driver_ids:
|
||||
drivers.append(fallback_driver)
|
||||
|
||||
return drivers
|
||||
|
||||
@staticmethod
|
||||
def build_route_key_from_message(message: "SessionMessage") -> RouteKey:
|
||||
"""根据 ``SessionMessage`` 构造路由键。
|
||||
|
||||
Args:
|
||||
message: 内部会话消息对象。
|
||||
|
||||
Returns:
|
||||
RouteKey: 由消息内容提取出的规范化路由键。
|
||||
"""
|
||||
return RouteKeyFactory.from_session_message(message)
|
||||
|
||||
@staticmethod
|
||||
def build_route_key_from_message_dict(message_dict: Dict[str, Any]) -> RouteKey:
|
||||
"""根据消息字典构造路由键。
|
||||
|
||||
Args:
|
||||
message_dict: Host 与插件之间传输的消息字典。
|
||||
|
||||
Returns:
|
||||
RouteKey: 由消息字典提取出的规范化路由键。
|
||||
"""
|
||||
return RouteKeyFactory.from_message_dict(message_dict)
|
||||
|
||||
async def accept_inbound(self, envelope: InboundMessageEnvelope) -> bool:
|
||||
"""处理一条由驱动上报的入站封装。
|
||||
|
||||
Args:
|
||||
envelope: 由传输驱动产出的入站封装。
|
||||
|
||||
Returns:
|
||||
bool: 若消息被接受并继续转发给入站分发器,则返回 ``True``,
|
||||
否则返回 ``False``。
|
||||
"""
|
||||
|
||||
if not self._receive_route_table.has_binding_for_driver(envelope.route_key, envelope.driver_id):
|
||||
logger.info(
|
||||
f"忽略未登记到接收路由表的入站消息: route={envelope.route_key} "
|
||||
f"driver={envelope.driver_id}"
|
||||
)
|
||||
return False
|
||||
|
||||
if self._inbound_dispatcher is None:
|
||||
logger.debug("PlatformIOManager 尚未配置 inbound dispatcher,暂不继续分发")
|
||||
return False
|
||||
|
||||
dedupe_key = self._build_inbound_dedupe_key(envelope)
|
||||
if dedupe_key is not None:
|
||||
if not self._deduplicator.mark_seen(dedupe_key):
|
||||
logger.info(f"忽略重复入站消息: dedupe_key={dedupe_key}")
|
||||
return False
|
||||
|
||||
await self._inbound_dispatcher(envelope)
|
||||
return True
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
message: "SessionMessage",
|
||||
route_key: RouteKey,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> DeliveryBatch:
|
||||
"""通过 Broker 选中的全部发送驱动广播一条消息。
|
||||
|
||||
Args:
|
||||
message: 要投递的内部会话消息。
|
||||
route_key: 本次出站投递选择的路由键。
|
||||
metadata: 可选的额外 Broker 侧元数据。
|
||||
|
||||
Returns:
|
||||
DeliveryBatch: 规范化后的批量出站回执。
|
||||
"""
|
||||
drivers = self.resolve_drivers(route_key)
|
||||
if not drivers:
|
||||
return DeliveryBatch(internal_message_id=message.message_id, route_key=route_key)
|
||||
|
||||
receipts: List[DeliveryReceipt] = []
|
||||
for driver in drivers:
|
||||
try:
|
||||
self._outbound_tracker.begin_tracking(
|
||||
internal_message_id=message.message_id,
|
||||
route_key=route_key,
|
||||
driver_id=driver.driver_id,
|
||||
metadata=metadata,
|
||||
)
|
||||
except ValueError as exc:
|
||||
receipts.append(
|
||||
DeliveryReceipt(
|
||||
internal_message_id=message.message_id,
|
||||
route_key=route_key,
|
||||
status=DeliveryStatus.FAILED,
|
||||
driver_id=driver.driver_id,
|
||||
driver_kind=driver.descriptor.kind,
|
||||
error=str(exc),
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
receipt = await driver.send_message(message=message, route_key=route_key, metadata=metadata)
|
||||
except Exception as exc:
|
||||
receipt = DeliveryReceipt(
|
||||
internal_message_id=message.message_id,
|
||||
route_key=route_key,
|
||||
status=DeliveryStatus.FAILED,
|
||||
driver_id=driver.driver_id,
|
||||
driver_kind=driver.descriptor.kind,
|
||||
error=str(exc),
|
||||
)
|
||||
|
||||
self._outbound_tracker.finish_tracking(receipt)
|
||||
receipts.append(receipt)
|
||||
|
||||
return DeliveryBatch(
|
||||
internal_message_id=message.message_id,
|
||||
route_key=route_key,
|
||||
receipts=receipts,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _build_inbound_dedupe_key(envelope: InboundMessageEnvelope) -> Optional[str]:
|
||||
"""构造用于入站抑制的去重键。
|
||||
|
||||
Args:
|
||||
envelope: 当前正在处理的入站封装。
|
||||
|
||||
Returns:
|
||||
Optional[str]: 若可以构造稳定去重键则返回该键,否则返回 ``None``。
|
||||
|
||||
Notes:
|
||||
这里仅接受上游显式提供的稳定消息身份,例如 ``dedupe_key``、
|
||||
平台侧 ``external_message_id`` 或已经完成规范化的
|
||||
``session_message.message_id``。Broker 不再根据 ``payload`` 内容
|
||||
猜测语义去重键,避免把“短时间内两条内容刚好完全相同”的合法消息
|
||||
误判为重复入站。
|
||||
"""
|
||||
raw_dedupe_key = envelope.dedupe_key or envelope.external_message_id
|
||||
if raw_dedupe_key is None and envelope.session_message is not None:
|
||||
raw_dedupe_key = envelope.session_message.message_id
|
||||
if raw_dedupe_key is None:
|
||||
return None
|
||||
|
||||
normalized_dedupe_key = str(raw_dedupe_key).strip()
|
||||
if not normalized_dedupe_key:
|
||||
return None
|
||||
|
||||
return f"{envelope.driver_id}:{normalized_dedupe_key}"
|
||||
|
||||
@staticmethod
|
||||
def _validate_binding_against_driver(binding: RouteBinding, driver: PlatformIODriver) -> None:
|
||||
"""校验路由绑定与驱动描述是否一致。
|
||||
|
||||
Args:
|
||||
binding: 待校验的路由绑定。
|
||||
driver: 被绑定的驱动实例。
|
||||
|
||||
Raises:
|
||||
ValueError: 当绑定类型、平台或更细粒度路由维度与驱动描述冲突时抛出。
|
||||
"""
|
||||
descriptor = driver.descriptor
|
||||
if binding.driver_kind != descriptor.kind:
|
||||
raise ValueError(
|
||||
f"路由绑定的 driver_kind={binding.driver_kind} 与驱动 {driver.driver_id} 的类型 "
|
||||
f"{descriptor.kind} 不一致"
|
||||
)
|
||||
|
||||
if binding.route_key.platform != descriptor.platform:
|
||||
raise ValueError(
|
||||
f"路由绑定的平台 {binding.route_key.platform} 与驱动 {driver.driver_id} 的平台 "
|
||||
f"{descriptor.platform} 不一致"
|
||||
)
|
||||
|
||||
if descriptor.account_id is not None and binding.route_key.account_id not in (None, descriptor.account_id):
|
||||
raise ValueError(
|
||||
f"路由绑定的 account_id={binding.route_key.account_id} 与驱动 {driver.driver_id} 的 "
|
||||
f"account_id={descriptor.account_id} 冲突"
|
||||
)
|
||||
|
||||
if descriptor.scope is not None and binding.route_key.scope not in (None, descriptor.scope):
|
||||
raise ValueError(
|
||||
f"路由绑定的 scope={binding.route_key.scope} 与驱动 {driver.driver_id} 的 "
|
||||
f"scope={descriptor.scope} 冲突"
|
||||
)
|
||||
|
||||
|
||||
_platform_io_manager: Optional[PlatformIOManager] = None
|
||||
|
||||
|
||||
def get_platform_io_manager() -> PlatformIOManager:
|
||||
"""返回全局 ``PlatformIOManager`` 单例。
|
||||
|
||||
Returns:
|
||||
PlatformIOManager: 进程级共享的 Broker 管理器实例。
|
||||
"""
|
||||
|
||||
global _platform_io_manager
|
||||
if _platform_io_manager is None:
|
||||
_platform_io_manager = PlatformIOManager()
|
||||
return _platform_io_manager
|
||||
286
src/platform_io/outbound_tracker.py
Normal file
286
src/platform_io/outbound_tracker.py
Normal file
@@ -0,0 +1,286 @@
|
||||
"""跟踪 Platform IO 层的出站投递状态。
|
||||
|
||||
当前实现基于两组 ``dict + heapq``:
|
||||
- ``_pending`` 和 ``_pending_expire_heap`` 负责管理待完成的出站记录
|
||||
- ``_receipts_by_external_id`` 和 ``_receipt_expire_heap`` 负责管理已完成回执索引
|
||||
|
||||
这样就不需要在每次读写时全表扫描过期项,而是通过懒清理逐步弹出已经过期
|
||||
或已经失效的堆节点。
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import heapq
|
||||
import time
|
||||
|
||||
from .types import DeliveryReceipt, RouteKey
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class PendingOutboundRecord:
|
||||
"""表示一条仍在等待完成的出站投递记录。
|
||||
|
||||
Attributes:
|
||||
internal_message_id: 正在跟踪的内部 ``SessionMessage.message_id``。
|
||||
route_key: 该出站投递开始时使用的路由键。
|
||||
driver_id: 负责这次出站投递的驱动 ID。
|
||||
created_at: 开始跟踪时记录的单调时钟时间戳。
|
||||
expires_at: 该待完成记录预计过期的单调时钟时间戳。
|
||||
metadata: 与待完成记录一同保留的额外 Broker 侧元数据。
|
||||
"""
|
||||
|
||||
internal_message_id: str
|
||||
route_key: RouteKey
|
||||
driver_id: str
|
||||
created_at: float = field(default_factory=time.monotonic)
|
||||
expires_at: float = 0.0
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class StoredDeliveryReceipt:
|
||||
"""表示一条已完成并暂存的出站回执。
|
||||
|
||||
Attributes:
|
||||
receipt: 规范化后的出站投递回执。
|
||||
stored_at: 回执被写入索引时记录的单调时钟时间戳。
|
||||
expires_at: 该回执索引预计过期的单调时钟时间戳。
|
||||
"""
|
||||
|
||||
receipt: DeliveryReceipt
|
||||
stored_at: float = field(default_factory=time.monotonic)
|
||||
expires_at: float = 0.0
|
||||
|
||||
|
||||
class OutboundTracker:
|
||||
"""统一跟踪出站消息的 pending 状态与最终回执。
|
||||
|
||||
主要用于解决出站消息在发送过程中“状态散落在不同路径里”的问题:
|
||||
- 发送开始后,需要在最终回执返回前保留一份 pending 状态
|
||||
- 平台返回 ``external_message_id`` 后,需要保留一段时间的回执索引
|
||||
|
||||
当前实现使用 ``dict + heapq`` 做 TTL 管理:
|
||||
- ``dict`` 提供 ``O(1)`` 级别的主键查询
|
||||
- ``heapq`` 提供按过期时间排序的懒清理能力
|
||||
|
||||
这比“每次 begin/finish/get 都全表扫描”的实现更适合高吞吐出站场景。
|
||||
|
||||
Notes:
|
||||
复杂度说明如下,设 ``p`` 为当前有效 pending 数量,``r`` 为当前有效回执数量:
|
||||
|
||||
- ``begin_tracking()``、``finish_tracking()`` 的常见路径时间复杂度接近
|
||||
``O(log p)`` 或 ``O(log r)``
|
||||
- ``get_pending()``、``get_receipt_by_external_id()`` 的查询本身是 ``O(1)``
|
||||
,连同懒清理一起看,长期摊还复杂度接近 ``O(log n)``
|
||||
- 如果某次调用恰好触发一批过期节点的集中清理,则该次调用的最坏时间复杂度
|
||||
可达到 ``O(k log n)``,其中 ``k`` 为本次被弹出的节点数量
|
||||
- 空间复杂度为 ``O(p + r)``
|
||||
"""
|
||||
|
||||
def __init__(self, ttl_seconds: float = 1800.0) -> None:
|
||||
"""初始化出站跟踪器。
|
||||
|
||||
Args:
|
||||
ttl_seconds: 待完成记录与按外部消息 ID 建立的回执索引保留时长,
|
||||
单位为秒。
|
||||
|
||||
Raises:
|
||||
ValueError: 当 ``ttl_seconds`` 非正数时抛出。
|
||||
"""
|
||||
if ttl_seconds <= 0:
|
||||
raise ValueError("ttl_seconds 必须大于 0")
|
||||
|
||||
self._ttl_seconds = ttl_seconds
|
||||
self._pending: Dict[Tuple[str, str], PendingOutboundRecord] = {}
|
||||
self._pending_expire_heap: List[Tuple[float, str, str]] = []
|
||||
self._receipts_by_external_id: Dict[str, StoredDeliveryReceipt] = {}
|
||||
self._receipt_expire_heap: List[Tuple[float, str]] = []
|
||||
|
||||
@staticmethod
|
||||
def _build_pending_key(internal_message_id: str, driver_id: str) -> Tuple[str, str]:
|
||||
"""构造单条出站跟踪记录的唯一键。
|
||||
|
||||
Args:
|
||||
internal_message_id: 内部消息 ID。
|
||||
driver_id: 负责当前投递的驱动 ID。
|
||||
|
||||
Returns:
|
||||
Tuple[str, str]: ``(internal_message_id, driver_id)`` 组合键。
|
||||
"""
|
||||
return internal_message_id, driver_id
|
||||
|
||||
def begin_tracking(
|
||||
self,
|
||||
internal_message_id: str,
|
||||
route_key: RouteKey,
|
||||
driver_id: str,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> PendingOutboundRecord:
|
||||
"""开始跟踪一次出站投递。
|
||||
|
||||
Args:
|
||||
internal_message_id: 正在投递的内部消息 ID。
|
||||
route_key: 这次出站投递选择的路由键。
|
||||
driver_id: 负责本次投递的驱动 ID。
|
||||
metadata: 可选的额外元数据,会一并保存在待完成记录中。
|
||||
|
||||
Returns:
|
||||
PendingOutboundRecord: 新创建的待完成记录。
|
||||
|
||||
Raises:
|
||||
ValueError: 当同一个 ``internal_message_id`` 与 ``driver_id`` 组合已经存在
|
||||
未完成记录时抛出。
|
||||
"""
|
||||
now = time.monotonic()
|
||||
self._cleanup_expired(now)
|
||||
pending_key = self._build_pending_key(internal_message_id, driver_id)
|
||||
|
||||
if pending_key in self._pending:
|
||||
raise ValueError(f"消息 {internal_message_id} 在驱动 {driver_id} 上已存在未完成的出站跟踪记录")
|
||||
|
||||
expires_at = now + self._ttl_seconds
|
||||
record = PendingOutboundRecord(
|
||||
internal_message_id=internal_message_id,
|
||||
route_key=route_key,
|
||||
driver_id=driver_id,
|
||||
created_at=now,
|
||||
expires_at=expires_at,
|
||||
metadata=metadata or {},
|
||||
)
|
||||
self._pending[pending_key] = record
|
||||
heapq.heappush(self._pending_expire_heap, (expires_at, internal_message_id, driver_id))
|
||||
return record
|
||||
|
||||
def finish_tracking(self, receipt: DeliveryReceipt) -> Optional[PendingOutboundRecord]:
|
||||
"""使用最终回执结束一条出站跟踪。
|
||||
|
||||
Args:
|
||||
receipt: 规范化后的最终投递回执。
|
||||
|
||||
Returns:
|
||||
Optional[PendingOutboundRecord]: 若此前存在待完成记录,则返回该记录。
|
||||
"""
|
||||
now = time.monotonic()
|
||||
self._cleanup_expired(now)
|
||||
|
||||
pending_record: Optional[PendingOutboundRecord] = None
|
||||
if receipt.driver_id:
|
||||
pending_key = self._build_pending_key(receipt.internal_message_id, receipt.driver_id)
|
||||
pending_record = self._pending.pop(pending_key, None)
|
||||
else:
|
||||
matched_records = [
|
||||
key
|
||||
for key, record in self._pending.items()
|
||||
if record.internal_message_id == receipt.internal_message_id
|
||||
]
|
||||
if len(matched_records) == 1:
|
||||
pending_record = self._pending.pop(matched_records[0], None)
|
||||
|
||||
if receipt.external_message_id:
|
||||
expires_at = now + self._ttl_seconds
|
||||
self._receipts_by_external_id[receipt.external_message_id] = StoredDeliveryReceipt(
|
||||
receipt=receipt,
|
||||
stored_at=now,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
heapq.heappush(self._receipt_expire_heap, (expires_at, receipt.external_message_id))
|
||||
return pending_record
|
||||
|
||||
def get_pending(
|
||||
self,
|
||||
internal_message_id: str,
|
||||
driver_id: Optional[str] = None,
|
||||
) -> Optional[PendingOutboundRecord]:
|
||||
"""根据内部消息 ID 查询待完成记录。
|
||||
|
||||
Args:
|
||||
internal_message_id: 要查询的内部消息 ID。
|
||||
driver_id: 可选的驱动 ID;提供后仅返回该驱动上的待完成记录。
|
||||
|
||||
Returns:
|
||||
Optional[PendingOutboundRecord]: 若记录仍存在,则返回对应待完成记录。
|
||||
"""
|
||||
self._cleanup_expired(time.monotonic())
|
||||
|
||||
if driver_id:
|
||||
return self._pending.get(self._build_pending_key(internal_message_id, driver_id))
|
||||
|
||||
matched_records = [
|
||||
record
|
||||
for record in self._pending.values()
|
||||
if record.internal_message_id == internal_message_id
|
||||
]
|
||||
if len(matched_records) == 1:
|
||||
return matched_records[0]
|
||||
return None
|
||||
|
||||
def get_receipt_by_external_id(self, external_message_id: str) -> Optional[DeliveryReceipt]:
|
||||
"""根据外部平台消息 ID 查询已完成回执。
|
||||
|
||||
Args:
|
||||
external_message_id: 要查询的平台侧消息 ID。
|
||||
|
||||
Returns:
|
||||
Optional[DeliveryReceipt]: 若存在对应回执,则返回该回执。
|
||||
"""
|
||||
self._cleanup_expired(time.monotonic())
|
||||
stored_receipt = self._receipts_by_external_id.get(external_message_id)
|
||||
return stored_receipt.receipt if stored_receipt else None
|
||||
|
||||
def clear(self) -> None:
|
||||
"""清空全部待完成记录与已保存回执。"""
|
||||
self._pending.clear()
|
||||
self._pending_expire_heap.clear()
|
||||
self._receipts_by_external_id.clear()
|
||||
self._receipt_expire_heap.clear()
|
||||
|
||||
def _cleanup_expired(self, now: float) -> None:
|
||||
"""清理内存中已经过期的待完成记录与已保存回执。
|
||||
|
||||
Args:
|
||||
now: 当前单调时钟时间戳。
|
||||
"""
|
||||
self._cleanup_expired_pending(now)
|
||||
self._cleanup_expired_receipts(now)
|
||||
|
||||
def _cleanup_expired_pending(self, now: float) -> None:
|
||||
"""清理已经过期的待完成记录。
|
||||
|
||||
Args:
|
||||
now: 当前单调时钟时间戳。
|
||||
|
||||
Notes:
|
||||
堆中可能存在已经失效的旧节点。例如某条记录提前 ``finish`` 后,
|
||||
它原本的过期节点仍可能留在堆里。这里会通过和 ``dict`` 中当前记录的
|
||||
``expires_at`` 对比,跳过这类旧节点。
|
||||
"""
|
||||
while self._pending_expire_heap and self._pending_expire_heap[0][0] <= now:
|
||||
expires_at, internal_message_id, driver_id = heapq.heappop(self._pending_expire_heap)
|
||||
pending_key = self._build_pending_key(internal_message_id, driver_id)
|
||||
current_record = self._pending.get(pending_key)
|
||||
if current_record is None:
|
||||
continue
|
||||
if current_record.expires_at != expires_at:
|
||||
continue
|
||||
self._pending.pop(pending_key, None)
|
||||
|
||||
def _cleanup_expired_receipts(self, now: float) -> None:
|
||||
"""清理已经过期的回执索引。
|
||||
|
||||
Args:
|
||||
now: 当前单调时钟时间戳。
|
||||
|
||||
Notes:
|
||||
同一个 ``external_message_id`` 在极端情况下可能被重复写入索引,
|
||||
因此这里同样需要通过 ``expires_at`` 和当前 ``dict`` 中的值比对,
|
||||
跳过已经失效的旧堆节点。
|
||||
"""
|
||||
while self._receipt_expire_heap and self._receipt_expire_heap[0][0] <= now:
|
||||
expires_at, external_message_id = heapq.heappop(self._receipt_expire_heap)
|
||||
current_receipt = self._receipts_by_external_id.get(external_message_id)
|
||||
if current_receipt is None:
|
||||
continue
|
||||
if current_receipt.expires_at != expires_at:
|
||||
continue
|
||||
self._receipts_by_external_id.pop(external_message_id, None)
|
||||
70
src/platform_io/registry.py
Normal file
70
src/platform_io/registry.py
Normal file
@@ -0,0 +1,70 @@
|
||||
"""提供 Platform IO 的驱动注册与查询能力。"""
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from src.platform_io.drivers.base import PlatformIODriver
|
||||
from src.platform_io.types import DriverKind
|
||||
|
||||
|
||||
class DriverRegistry:
|
||||
"""集中保存已注册的 Platform IO 驱动,并提供基础查询接口。"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""初始化一个空的驱动注册表。"""
|
||||
self._drivers: Dict[str, PlatformIODriver] = {}
|
||||
|
||||
def register(self, driver: PlatformIODriver) -> None:
|
||||
"""注册一个驱动实例。
|
||||
|
||||
Args:
|
||||
driver: 要注册的驱动实例。
|
||||
|
||||
Raises:
|
||||
ValueError: 当驱动 ID 已经存在时抛出。
|
||||
"""
|
||||
if driver.driver_id in self._drivers:
|
||||
raise ValueError(f"驱动 {driver.driver_id} 已注册")
|
||||
self._drivers[driver.driver_id] = driver
|
||||
|
||||
def unregister(self, driver_id: str) -> Optional[PlatformIODriver]:
|
||||
"""按驱动 ID 注销一个驱动。
|
||||
|
||||
Args:
|
||||
driver_id: 要移除的驱动 ID。
|
||||
|
||||
Returns:
|
||||
Optional[PlatformIODriver]: 若驱动存在,则返回被移除的驱动实例。
|
||||
"""
|
||||
return self._drivers.pop(driver_id, None)
|
||||
|
||||
def get(self, driver_id: str) -> Optional[PlatformIODriver]:
|
||||
"""按驱动 ID 获取驱动实例。
|
||||
|
||||
Args:
|
||||
driver_id: 要查询的驱动 ID。
|
||||
|
||||
Returns:
|
||||
Optional[PlatformIODriver]: 若存在匹配驱动,则返回该驱动实例。
|
||||
"""
|
||||
return self._drivers.get(driver_id)
|
||||
|
||||
def list(self, *, kind: Optional[DriverKind] = None, platform: Optional[str] = None) -> List[PlatformIODriver]:
|
||||
"""列出已注册驱动,并支持可选过滤。
|
||||
|
||||
Args:
|
||||
kind: 可选的驱动类型过滤条件。
|
||||
platform: 可选的平台名称过滤条件。
|
||||
|
||||
Returns:
|
||||
List[PlatformIODriver]: 符合过滤条件的驱动列表。
|
||||
"""
|
||||
drivers = list(self._drivers.values())
|
||||
if kind is not None:
|
||||
drivers = [driver for driver in drivers if driver.descriptor.kind == kind]
|
||||
if platform is not None:
|
||||
drivers = [driver for driver in drivers if driver.descriptor.platform == platform]
|
||||
return drivers
|
||||
|
||||
def clear(self) -> None:
|
||||
"""清空全部已注册驱动。"""
|
||||
self._drivers.clear()
|
||||
150
src/platform_io/route_key_factory.py
Normal file
150
src/platform_io/route_key_factory.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""提供 Platform IO 路由键的统一提取与构造能力。
|
||||
|
||||
这层的目标不是直接接入具体消息链,而是先把“未来接线时用什么字段构造
|
||||
RouteKey”约定下来,避免 legacy 和 plugin 两条链路各自发明一套隐式规则。
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
|
||||
|
||||
from .types import RouteKey
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
|
||||
|
||||
class RouteKeyFactory:
|
||||
"""统一构造 ``RouteKey`` 的工厂。
|
||||
|
||||
当前约定会优先从消息字典顶层、``message_info``、``additional_config`` 或传入 metadata 中提取
|
||||
以下字段:
|
||||
|
||||
- account_id: ``platform_io_account_id`` / ``account_id`` / ``self_id`` / ``bot_account``
|
||||
- scope: ``platform_io_scope`` / ``route_scope`` / ``adapter_scope`` / ``connection_id``
|
||||
|
||||
这样即使上游主链暂时还没有正式的 ``self_id`` 字段,中间层也能先统一
|
||||
约定提取口径,等具体消息链接入时直接复用。
|
||||
"""
|
||||
|
||||
ACCOUNT_ID_KEYS = (
|
||||
"platform_io_account_id",
|
||||
"account_id",
|
||||
"self_id",
|
||||
"bot_account",
|
||||
)
|
||||
SCOPE_KEYS = (
|
||||
"platform_io_scope",
|
||||
"route_scope",
|
||||
"adapter_scope",
|
||||
"connection_id",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_platform(
|
||||
cls,
|
||||
platform: str,
|
||||
*,
|
||||
account_id: Optional[str] = None,
|
||||
scope: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> RouteKey:
|
||||
"""根据平台名和可选 metadata 构造 ``RouteKey``。
|
||||
|
||||
Args:
|
||||
platform: 平台名称。
|
||||
account_id: 显式传入的账号 ID;若为空,则尝试从 metadata 提取。
|
||||
scope: 显式传入的路由作用域;若为空,则尝试从 metadata 提取。
|
||||
metadata: 可选的元数据字典。
|
||||
|
||||
Returns:
|
||||
RouteKey: 构造出的规范化路由键。
|
||||
"""
|
||||
extracted_account_id, extracted_scope = cls.extract_components(metadata)
|
||||
return RouteKey(
|
||||
platform=platform,
|
||||
account_id=account_id or extracted_account_id,
|
||||
scope=scope or extracted_scope,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_message_dict(cls, message_dict: Dict[str, Any]) -> RouteKey:
|
||||
"""从消息字典中提取 ``RouteKey``。
|
||||
|
||||
Args:
|
||||
message_dict: Host 与插件之间传输的消息字典。
|
||||
|
||||
Returns:
|
||||
RouteKey: 构造出的规范化路由键。
|
||||
|
||||
Raises:
|
||||
ValueError: 当消息字典缺少有效 ``platform`` 字段时抛出。
|
||||
"""
|
||||
platform = str(message_dict.get("platform") or "").strip()
|
||||
if not platform:
|
||||
raise ValueError("消息字典缺少有效的 platform 字段,无法构造 RouteKey")
|
||||
|
||||
message_info = message_dict.get("message_info", {})
|
||||
additional_config = {}
|
||||
if isinstance(message_info, dict):
|
||||
raw_additional_config = message_info.get("additional_config", {})
|
||||
if isinstance(raw_additional_config, dict):
|
||||
additional_config = raw_additional_config
|
||||
|
||||
explicit_account_id, explicit_scope = cls.extract_components(message_dict)
|
||||
message_info_account_id, message_info_scope = cls.extract_components(message_info)
|
||||
metadata_account_id, metadata_scope = cls.extract_components(additional_config)
|
||||
return RouteKey(
|
||||
platform=platform,
|
||||
account_id=explicit_account_id or message_info_account_id or metadata_account_id,
|
||||
scope=explicit_scope or message_info_scope or metadata_scope,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_session_message(cls, message: "SessionMessage") -> RouteKey:
|
||||
"""从 ``SessionMessage`` 中提取 ``RouteKey``。
|
||||
|
||||
Args:
|
||||
message: 内部会话消息对象。
|
||||
|
||||
Returns:
|
||||
RouteKey: 构造出的规范化路由键。
|
||||
"""
|
||||
additional_config = message.message_info.additional_config or {}
|
||||
metadata = additional_config if isinstance(additional_config, dict) else {}
|
||||
return cls.from_platform(message.platform, metadata=metadata)
|
||||
|
||||
@classmethod
|
||||
def extract_components(cls, mapping: Optional[Dict[str, Any]]) -> Tuple[Optional[str], Optional[str]]:
|
||||
"""从任意字典中提取 ``account_id`` 与 ``scope``。
|
||||
|
||||
Args:
|
||||
mapping: 待提取的字典;若为空或不是字典,则返回空结果。
|
||||
|
||||
Returns:
|
||||
Tuple[Optional[str], Optional[str]]: ``(account_id, scope)``。
|
||||
"""
|
||||
if not mapping or not isinstance(mapping, dict):
|
||||
return None, None
|
||||
|
||||
account_id = cls._pick_string(mapping, cls.ACCOUNT_ID_KEYS)
|
||||
scope = cls._pick_string(mapping, cls.SCOPE_KEYS)
|
||||
return account_id, scope
|
||||
|
||||
@staticmethod
|
||||
def _pick_string(mapping: Dict[str, Any], keys: Tuple[str, ...]) -> Optional[str]:
|
||||
"""按优先级从字典里挑选第一个有效字符串。
|
||||
|
||||
Args:
|
||||
mapping: 待查询的字典。
|
||||
keys: 按优先级排列的候选键名。
|
||||
|
||||
Returns:
|
||||
Optional[str]: 第一个规范化后非空的字符串值;若不存在则返回 ``None``。
|
||||
"""
|
||||
for key in keys:
|
||||
value = mapping.get(key)
|
||||
if value is None:
|
||||
continue
|
||||
normalized = str(value).strip()
|
||||
if normalized:
|
||||
return normalized
|
||||
return None
|
||||
141
src/platform_io/routing.py
Normal file
141
src/platform_io/routing.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""提供 Platform IO 的轻量路由绑定表。"""
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from .types import RouteBinding, RouteKey
|
||||
|
||||
|
||||
class RouteTable:
|
||||
"""维护单张路由绑定表。
|
||||
|
||||
该实现不负责裁决“唯一 owner”,只负责保存绑定,并按
|
||||
``RouteKey.resolution_order()`` 解析出候选绑定列表。
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""初始化空路由绑定表。"""
|
||||
|
||||
self._bindings: Dict[RouteKey, Dict[str, RouteBinding]] = {}
|
||||
|
||||
def bind(self, binding: RouteBinding) -> None:
|
||||
"""注册或更新一条路由绑定。
|
||||
|
||||
Args:
|
||||
binding: 要保存的路由绑定。
|
||||
"""
|
||||
|
||||
self._bindings.setdefault(binding.route_key, {})[binding.driver_id] = binding
|
||||
|
||||
def unbind(self, route_key: RouteKey, driver_id: Optional[str] = None) -> List[RouteBinding]:
|
||||
"""移除指定路由键上的绑定。
|
||||
|
||||
Args:
|
||||
route_key: 要移除绑定的路由键。
|
||||
driver_id: 可选的驱动 ID;为空时移除该路由键下全部绑定。
|
||||
|
||||
Returns:
|
||||
List[RouteBinding]: 被移除的绑定列表。
|
||||
"""
|
||||
|
||||
binding_map = self._bindings.get(route_key)
|
||||
if not binding_map:
|
||||
return []
|
||||
|
||||
if driver_id is None:
|
||||
removed = list(binding_map.values())
|
||||
self._bindings.pop(route_key, None)
|
||||
return self._sort_bindings(removed)
|
||||
|
||||
removed_binding = binding_map.pop(driver_id, None)
|
||||
if not binding_map:
|
||||
self._bindings.pop(route_key, None)
|
||||
return [removed_binding] if removed_binding is not None else []
|
||||
|
||||
def remove_bindings_by_driver(self, driver_id: str) -> List[RouteBinding]:
|
||||
"""移除某个驱动在整张表上的全部绑定。
|
||||
|
||||
Args:
|
||||
driver_id: 要移除绑定的驱动 ID。
|
||||
|
||||
Returns:
|
||||
List[RouteBinding]: 被移除的绑定列表。
|
||||
"""
|
||||
|
||||
removed_bindings: List[RouteBinding] = []
|
||||
empty_route_keys: List[RouteKey] = []
|
||||
for route_key, binding_map in self._bindings.items():
|
||||
removed_binding = binding_map.pop(driver_id, None)
|
||||
if removed_binding is not None:
|
||||
removed_bindings.append(removed_binding)
|
||||
if not binding_map:
|
||||
empty_route_keys.append(route_key)
|
||||
|
||||
for route_key in empty_route_keys:
|
||||
self._bindings.pop(route_key, None)
|
||||
|
||||
return self._sort_bindings(removed_bindings)
|
||||
|
||||
def list_bindings(self, route_key: Optional[RouteKey] = None) -> List[RouteBinding]:
|
||||
"""列出当前路由表中的绑定。
|
||||
|
||||
Args:
|
||||
route_key: 可选的路由键过滤条件。
|
||||
|
||||
Returns:
|
||||
List[RouteBinding]: 当前绑定列表。
|
||||
"""
|
||||
|
||||
if route_key is None:
|
||||
bindings: List[RouteBinding] = []
|
||||
for binding_map in self._bindings.values():
|
||||
bindings.extend(binding_map.values())
|
||||
return self._sort_bindings(bindings)
|
||||
|
||||
binding_map = self._bindings.get(route_key, {})
|
||||
return self._sort_bindings(list(binding_map.values()))
|
||||
|
||||
def resolve_bindings(self, route_key: RouteKey) -> List[RouteBinding]:
|
||||
"""按从具体到宽泛的顺序解析路由候选绑定。
|
||||
|
||||
Args:
|
||||
route_key: 待解析的路由键。
|
||||
|
||||
Returns:
|
||||
List[RouteBinding]: 去重后的候选绑定列表。
|
||||
"""
|
||||
|
||||
resolved_bindings: List[RouteBinding] = []
|
||||
seen_driver_ids: set[str] = set()
|
||||
for candidate_key in route_key.resolution_order():
|
||||
for binding in self.list_bindings(candidate_key):
|
||||
if binding.driver_id in seen_driver_ids:
|
||||
continue
|
||||
seen_driver_ids.add(binding.driver_id)
|
||||
resolved_bindings.append(binding)
|
||||
return resolved_bindings
|
||||
|
||||
def has_binding_for_driver(self, route_key: RouteKey, driver_id: str) -> bool:
|
||||
"""判断指定驱动是否在当前路由键解析结果中。
|
||||
|
||||
Args:
|
||||
route_key: 待解析的路由键。
|
||||
driver_id: 目标驱动 ID。
|
||||
|
||||
Returns:
|
||||
bool: 若驱动存在于解析结果中则返回 ``True``。
|
||||
"""
|
||||
|
||||
return any(binding.driver_id == driver_id for binding in self.resolve_bindings(route_key))
|
||||
|
||||
@staticmethod
|
||||
def _sort_bindings(bindings: List[RouteBinding]) -> List[RouteBinding]:
|
||||
"""按优先级降序排列绑定列表。
|
||||
|
||||
Args:
|
||||
bindings: 待排序的绑定列表。
|
||||
|
||||
Returns:
|
||||
List[RouteBinding]: 排序后的绑定列表。
|
||||
"""
|
||||
|
||||
return sorted(bindings, key=lambda item: item.priority, reverse=True)
|
||||
264
src/platform_io/types.py
Normal file
264
src/platform_io/types.py
Normal file
@@ -0,0 +1,264 @@
|
||||
"""定义 Platform IO 中间层共享的核心类型。
|
||||
|
||||
本模块放置路由、驱动、入站与出站等规范化数据结构,供 Broker
|
||||
层在 legacy 适配器链路和 plugin 适配器链路之间复用。
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
|
||||
|
||||
class DriverKind(str, Enum):
|
||||
"""底层收发驱动类型枚举。"""
|
||||
|
||||
LEGACY = "legacy"
|
||||
PLUGIN = "plugin"
|
||||
|
||||
|
||||
class DeliveryStatus(str, Enum):
|
||||
"""统一出站回执状态枚举。"""
|
||||
|
||||
PENDING = "pending"
|
||||
SENT = "sent"
|
||||
FAILED = "failed"
|
||||
DROPPED = "dropped"
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class RouteKey:
|
||||
"""用于 Platform IO 路由决策的唯一键。
|
||||
|
||||
路由解析会按照“从最具体到最宽泛”的顺序进行回退,这样同一平台
|
||||
后续就能自然支持按账号、自定义 scope 等更细粒度的归属控制。
|
||||
|
||||
Attributes:
|
||||
platform: 平台名称,例如 ``qq``。
|
||||
account_id: 机器人账号 ID 或 self ID,用于区分同平台多身份。
|
||||
scope: 额外路由作用域,预留给未来的连接实例、租户或子通道等维度。
|
||||
"""
|
||||
|
||||
platform: str
|
||||
account_id: Optional[str] = None
|
||||
scope: Optional[str] = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""规范化并校验路由键字段。
|
||||
|
||||
Raises:
|
||||
ValueError: 当 ``platform`` 规范化后为空时抛出。
|
||||
"""
|
||||
platform = str(self.platform).strip()
|
||||
account_id = str(self.account_id).strip() if self.account_id is not None else None
|
||||
scope = str(self.scope).strip() if self.scope is not None else None
|
||||
|
||||
if not platform:
|
||||
raise ValueError("RouteKey.platform 不能为空")
|
||||
|
||||
object.__setattr__(self, "platform", platform)
|
||||
object.__setattr__(self, "account_id", account_id or None)
|
||||
object.__setattr__(self, "scope", scope or None)
|
||||
|
||||
def resolution_order(self) -> List["RouteKey"]:
|
||||
"""返回从最具体到最宽泛的路由匹配顺序。
|
||||
|
||||
Returns:
|
||||
List[RouteKey]: 按回退优先级排序的候选路由键列表。
|
||||
"""
|
||||
|
||||
keys: List[RouteKey] = [self]
|
||||
|
||||
if self.account_id is not None and self.scope is not None:
|
||||
keys.append(RouteKey(platform=self.platform, account_id=self.account_id, scope=None))
|
||||
keys.append(RouteKey(platform=self.platform, account_id=None, scope=self.scope))
|
||||
elif self.account_id is not None:
|
||||
keys.append(RouteKey(platform=self.platform, account_id=None, scope=None))
|
||||
elif self.scope is not None:
|
||||
keys.append(RouteKey(platform=self.platform, account_id=None, scope=None))
|
||||
|
||||
default_key = RouteKey(platform=self.platform, account_id=None, scope=None)
|
||||
if default_key not in keys:
|
||||
keys.append(default_key)
|
||||
|
||||
return keys
|
||||
|
||||
def to_dedupe_scope(self) -> str:
|
||||
"""生成跨驱动共享的去重作用域字符串。
|
||||
|
||||
Returns:
|
||||
str: 用于入站消息去重的稳定文本作用域键。
|
||||
"""
|
||||
|
||||
account_id = self.account_id or "*"
|
||||
scope = self.scope or "*"
|
||||
return f"{self.platform}:{account_id}:{scope}"
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class DriverDescriptor:
|
||||
"""描述一个已注册的 Platform IO 驱动。
|
||||
|
||||
Attributes:
|
||||
driver_id: Broker 层内全局唯一的驱动标识。
|
||||
kind: 驱动实现类型,例如 legacy 或 plugin。
|
||||
platform: 驱动负责的平台名称。
|
||||
account_id: 可选的账号 ID 或 self ID。
|
||||
scope: 可选的额外路由作用域。
|
||||
plugin_id: 当驱动来自插件适配器时,对应的插件 ID。
|
||||
metadata: 预留给路由策略或观测能力的额外驱动元数据。
|
||||
"""
|
||||
|
||||
driver_id: str
|
||||
kind: DriverKind
|
||||
platform: str
|
||||
account_id: Optional[str] = None
|
||||
scope: Optional[str] = None
|
||||
plugin_id: Optional[str] = None
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""规范化并校验驱动描述字段。
|
||||
|
||||
Raises:
|
||||
ValueError: 当 ``driver_id`` 或 ``platform`` 规范化后为空时抛出。
|
||||
"""
|
||||
driver_id = str(self.driver_id).strip()
|
||||
platform = str(self.platform).strip()
|
||||
plugin_id = str(self.plugin_id).strip() if self.plugin_id is not None else None
|
||||
|
||||
if not driver_id:
|
||||
raise ValueError("DriverDescriptor.driver_id 不能为空")
|
||||
if not platform:
|
||||
raise ValueError("DriverDescriptor.platform 不能为空")
|
||||
|
||||
object.__setattr__(self, "driver_id", driver_id)
|
||||
object.__setattr__(self, "platform", platform)
|
||||
object.__setattr__(self, "plugin_id", plugin_id or None)
|
||||
|
||||
@property
|
||||
def route_key(self) -> RouteKey:
|
||||
"""构造该驱动默认代表的路由键。
|
||||
|
||||
Returns:
|
||||
RouteKey: 当前驱动描述对应的规范化路由键。
|
||||
"""
|
||||
return RouteKey(platform=self.platform, account_id=self.account_id, scope=self.scope)
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class RouteBinding:
|
||||
"""表示一条从路由键到驱动的绑定关系。
|
||||
|
||||
Attributes:
|
||||
route_key: 该绑定覆盖的路由键。
|
||||
driver_id: 拥有该路由的驱动 ID。
|
||||
driver_kind: 绑定驱动的类型。
|
||||
priority: 当同一路由键存在多条绑定时使用的相对优先级。
|
||||
metadata: 预留给未来路由策略的额外绑定元数据。
|
||||
"""
|
||||
|
||||
route_key: RouteKey
|
||||
driver_id: str
|
||||
driver_kind: DriverKind
|
||||
priority: int = 0
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""规范化并校验绑定字段。
|
||||
|
||||
Raises:
|
||||
ValueError: 当 ``driver_id`` 规范化后为空时抛出。
|
||||
"""
|
||||
driver_id = str(self.driver_id).strip()
|
||||
if not driver_id:
|
||||
raise ValueError("RouteBinding.driver_id 不能为空")
|
||||
object.__setattr__(self, "driver_id", driver_id)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class InboundMessageEnvelope:
|
||||
"""封装一次由驱动产出的规范化入站消息。
|
||||
|
||||
Attributes:
|
||||
route_key: 该入站消息解析出的路由键。
|
||||
driver_id: 产出该消息的驱动 ID。
|
||||
driver_kind: 产出该消息的驱动类型。
|
||||
external_message_id: 可选的平台侧消息 ID,用于去重。
|
||||
dedupe_key: 可选的显式去重键。当外部消息没有稳定 ``message_id`` 时,
|
||||
可由上游驱动提供稳定的技术性幂等键。若这里为空,中间层仅会继续
|
||||
回退到 ``external_message_id`` 或 ``session_message.message_id``,
|
||||
不会再根据 ``payload`` 内容猜测语义去重键。
|
||||
session_message: 可选的、已经完成规范化的 ``SessionMessage`` 对象。
|
||||
payload: 可选的原始字典载荷,供延迟转换或调试使用。
|
||||
metadata: 额外入站元数据,例如连接信息或追踪上下文。
|
||||
"""
|
||||
|
||||
route_key: RouteKey
|
||||
driver_id: str
|
||||
driver_kind: DriverKind
|
||||
external_message_id: Optional[str] = None
|
||||
dedupe_key: Optional[str] = None
|
||||
session_message: Optional["SessionMessage"] = None
|
||||
payload: Optional[Dict[str, Any]] = None
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class DeliveryReceipt:
|
||||
"""表示一次出站投递尝试的统一结果。
|
||||
|
||||
Attributes:
|
||||
internal_message_id: Broker 跟踪的内部 ``SessionMessage.message_id``。
|
||||
route_key: 本次投递使用的路由键。
|
||||
status: 规范化后的投递状态。
|
||||
driver_id: 实际处理该投递的驱动 ID,可为空。
|
||||
driver_kind: 实际处理该投递的驱动类型,可为空。
|
||||
external_message_id: 驱动或适配器返回的平台侧消息 ID,可为空。
|
||||
error: 投递失败时的错误信息,可为空。
|
||||
metadata: 预留给回执、时间戳或平台特有信息的额外元数据。
|
||||
"""
|
||||
|
||||
internal_message_id: str
|
||||
route_key: RouteKey
|
||||
status: DeliveryStatus
|
||||
driver_id: Optional[str] = None
|
||||
driver_kind: Optional[DriverKind] = None
|
||||
external_message_id: Optional[str] = None
|
||||
error: Optional[str] = None
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class DeliveryBatch:
|
||||
"""表示一次广播式出站投递的批量结果。
|
||||
|
||||
Attributes:
|
||||
internal_message_id: 内部消息 ID。
|
||||
route_key: 本次投递使用的路由键。
|
||||
receipts: 各条路由的独立投递回执列表。
|
||||
"""
|
||||
|
||||
internal_message_id: str
|
||||
route_key: RouteKey
|
||||
receipts: List[DeliveryReceipt] = field(default_factory=list)
|
||||
|
||||
@property
|
||||
def sent_receipts(self) -> List[DeliveryReceipt]:
|
||||
"""返回全部发送成功的回执。"""
|
||||
|
||||
return [receipt for receipt in self.receipts if receipt.status == DeliveryStatus.SENT]
|
||||
|
||||
@property
|
||||
def failed_receipts(self) -> List[DeliveryReceipt]:
|
||||
"""返回全部发送失败的回执。"""
|
||||
|
||||
return [receipt for receipt in self.receipts if receipt.status != DeliveryStatus.SENT]
|
||||
|
||||
@property
|
||||
def has_success(self) -> bool:
|
||||
"""返回当前批量投递是否至少命中一条成功回执。"""
|
||||
|
||||
return bool(self.sent_receipts)
|
||||
@@ -16,3 +16,9 @@ ENV_PLUGIN_DIRS = "MAIBOT_PLUGIN_DIRS"
|
||||
|
||||
ENV_HOST_VERSION = "MAIBOT_HOST_VERSION"
|
||||
"""Runner 读取的 Host 应用版本号,用于 manifest 兼容性校验"""
|
||||
|
||||
ENV_EXTERNAL_PLUGIN_IDS = "MAIBOT_EXTERNAL_PLUGIN_IDS"
|
||||
"""Runner 启动时可视为已满足的外部插件依赖版本映射(JSON 对象)"""
|
||||
|
||||
ENV_GLOBAL_CONFIG_SNAPSHOT = "MAIBOT_GLOBAL_CONFIG_SNAPSHOT"
|
||||
"""Runner 启动时注入的全局配置快照(JSON 对象)"""
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Protocol
|
||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Protocol, Sequence
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("plugin_runtime.integration")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.plugin_runtime.host.component_registry import RegisteredComponent
|
||||
from src.plugin_runtime.host.api_registry import APIEntry
|
||||
from src.plugin_runtime.host.component_registry import ComponentEntry
|
||||
from src.plugin_runtime.host.supervisor import PluginSupervisor
|
||||
|
||||
|
||||
@@ -14,18 +15,311 @@ class _RuntimeComponentManagerProtocol(Protocol):
|
||||
@property
|
||||
def supervisors(self) -> List["PluginSupervisor"]: ...
|
||||
|
||||
def _normalize_component_type(self, component_type: str) -> str: ...
|
||||
|
||||
def _is_api_component_type(self, component_type: str) -> bool: ...
|
||||
|
||||
def _serialize_api_entry(self, entry: "APIEntry") -> Dict[str, Any]: ...
|
||||
|
||||
def _serialize_api_component_entry(self, entry: "APIEntry") -> Dict[str, Any]: ...
|
||||
|
||||
def _is_api_visible_to_plugin(self, entry: "APIEntry", caller_plugin_id: str) -> bool: ...
|
||||
|
||||
def _normalize_api_reference(self, api_name: str, version: str = "") -> tuple[str, str]: ...
|
||||
|
||||
def _build_api_unavailable_error(self, entry: "APIEntry") -> str: ...
|
||||
|
||||
def _get_supervisor_for_plugin(self, plugin_id: str) -> Optional["PluginSupervisor"]: ...
|
||||
|
||||
def _resolve_api_target(
|
||||
self,
|
||||
caller_plugin_id: str,
|
||||
api_name: str,
|
||||
version: str = "",
|
||||
) -> tuple[Optional["PluginSupervisor"], Optional["APIEntry"], Optional[str]]: ...
|
||||
|
||||
def _resolve_api_toggle_target(
|
||||
self,
|
||||
name: str,
|
||||
version: str = "",
|
||||
) -> tuple[Optional["PluginSupervisor"], Optional["APIEntry"], Optional[str]]: ...
|
||||
|
||||
def _resolve_component_toggle_target(
|
||||
self, name: str, component_type: str
|
||||
) -> tuple[Optional["RegisteredComponent"], Optional[str]]: ...
|
||||
) -> tuple[Optional["ComponentEntry"], Optional[str]]: ...
|
||||
|
||||
def _find_duplicate_plugin_ids(self, plugin_dirs: List[Path]) -> Dict[str, List[Path]]: ...
|
||||
|
||||
def _iter_plugin_dirs(self) -> Iterable[Path]: ...
|
||||
|
||||
async def load_plugin_globally(self, plugin_id: str, reason: str = "manual") -> bool: ...
|
||||
|
||||
async def reload_plugins_globally(self, plugin_ids: Sequence[str], reason: str = "manual") -> bool: ...
|
||||
|
||||
|
||||
class RuntimeComponentCapabilityMixin:
|
||||
@staticmethod
|
||||
def _normalize_component_type(component_type: str) -> str:
|
||||
"""规范化组件类型名称。
|
||||
|
||||
Args:
|
||||
component_type: 原始组件类型。
|
||||
|
||||
Returns:
|
||||
str: 统一转为大写后的组件类型名。
|
||||
"""
|
||||
|
||||
return str(component_type or "").strip().upper()
|
||||
|
||||
@classmethod
|
||||
def _is_api_component_type(cls, component_type: str) -> bool:
|
||||
"""判断组件类型是否为 API。
|
||||
|
||||
Args:
|
||||
component_type: 原始组件类型。
|
||||
|
||||
Returns:
|
||||
bool: 是否为 API 组件类型。
|
||||
"""
|
||||
|
||||
return cls._normalize_component_type(component_type) == "API"
|
||||
|
||||
@staticmethod
|
||||
def _serialize_api_entry(entry: "APIEntry") -> Dict[str, Any]:
|
||||
"""将 API 组件条目序列化为能力返回值。
|
||||
|
||||
Args:
|
||||
entry: API 组件条目。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 适合通过能力层返回给插件的 API 元信息。
|
||||
"""
|
||||
|
||||
return {
|
||||
"name": entry.name,
|
||||
"full_name": entry.full_name,
|
||||
"plugin_id": entry.plugin_id,
|
||||
"description": entry.description,
|
||||
"version": entry.version,
|
||||
"public": entry.public,
|
||||
"enabled": entry.enabled,
|
||||
"dynamic": entry.dynamic,
|
||||
"offline_reason": entry.offline_reason,
|
||||
"metadata": dict(entry.metadata),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _serialize_api_component_entry(cls, entry: "APIEntry") -> Dict[str, Any]:
|
||||
"""将 API 条目序列化为通用组件视图。
|
||||
|
||||
Args:
|
||||
entry: API 组件条目。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 适合 ``component.get_all_plugins`` 返回的组件结构。
|
||||
"""
|
||||
|
||||
serialized_entry = cls._serialize_api_entry(entry)
|
||||
return {
|
||||
"name": serialized_entry["name"],
|
||||
"full_name": serialized_entry["full_name"],
|
||||
"type": "API",
|
||||
"enabled": serialized_entry["enabled"],
|
||||
"metadata": serialized_entry["metadata"],
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _is_api_visible_to_plugin(entry: "APIEntry", caller_plugin_id: str) -> bool:
|
||||
"""判断某个 API 是否对调用方可见。
|
||||
|
||||
Args:
|
||||
entry: 目标 API 组件条目。
|
||||
caller_plugin_id: 调用方插件 ID。
|
||||
|
||||
Returns:
|
||||
bool: 是否允许当前插件可见并调用。
|
||||
"""
|
||||
|
||||
return entry.plugin_id == caller_plugin_id or entry.public
|
||||
|
||||
@staticmethod
|
||||
def _normalize_api_reference(api_name: str, version: str = "") -> tuple[str, str]:
|
||||
"""规范化 API 名称与版本参数。
|
||||
|
||||
支持在 ``api_name`` 中直接携带 ``@version`` 后缀。
|
||||
"""
|
||||
|
||||
normalized_api_name = str(api_name or "").strip()
|
||||
normalized_version = str(version or "").strip()
|
||||
if normalized_api_name and not normalized_version and "@" in normalized_api_name:
|
||||
candidate_name, candidate_version = normalized_api_name.rsplit("@", 1)
|
||||
candidate_name = candidate_name.strip()
|
||||
candidate_version = candidate_version.strip()
|
||||
if candidate_name and candidate_version:
|
||||
normalized_api_name = candidate_name
|
||||
normalized_version = candidate_version
|
||||
return normalized_api_name, normalized_version
|
||||
|
||||
@staticmethod
|
||||
def _build_api_unavailable_error(entry: "APIEntry") -> str:
|
||||
"""构造 API 当前不可用时的错误信息。"""
|
||||
|
||||
if entry.offline_reason:
|
||||
return entry.offline_reason
|
||||
return f"API {entry.registry_key} 当前不可用"
|
||||
|
||||
def _resolve_api_target(
|
||||
self: _RuntimeComponentManagerProtocol,
|
||||
caller_plugin_id: str,
|
||||
api_name: str,
|
||||
version: str = "",
|
||||
) -> tuple[Optional["PluginSupervisor"], Optional["APIEntry"], Optional[str]]:
|
||||
"""解析 API 名称到唯一可调用的目标组件。
|
||||
|
||||
Args:
|
||||
caller_plugin_id: 调用方插件 ID。
|
||||
api_name: API 名称,支持 ``plugin_id.api_name`` 或唯一短名。
|
||||
version: 可选的 API 版本。
|
||||
|
||||
Returns:
|
||||
tuple[Optional[PluginSupervisor], Optional[APIEntry], Optional[str]]:
|
||||
解析成功时返回 ``(监督器, API 条目, None)``,失败时返回错误信息。
|
||||
"""
|
||||
|
||||
normalized_api_name, normalized_version = self._normalize_api_reference(api_name, version)
|
||||
if not normalized_api_name:
|
||||
return None, None, "缺少必要参数 api_name"
|
||||
|
||||
if "." in normalized_api_name:
|
||||
target_plugin_id, target_api_name = normalized_api_name.rsplit(".", 1)
|
||||
try:
|
||||
supervisor = self._get_supervisor_for_plugin(target_plugin_id)
|
||||
except RuntimeError as exc:
|
||||
return None, None, str(exc)
|
||||
|
||||
if supervisor is None:
|
||||
return None, None, f"未找到 API 提供方插件: {target_plugin_id}"
|
||||
|
||||
entries = supervisor.api_registry.get_apis(
|
||||
plugin_id=target_plugin_id,
|
||||
name=target_api_name,
|
||||
version=normalized_version,
|
||||
enabled_only=False,
|
||||
)
|
||||
visible_enabled_entries = [
|
||||
entry
|
||||
for entry in entries
|
||||
if self._is_api_visible_to_plugin(entry, caller_plugin_id) and entry.enabled
|
||||
]
|
||||
visible_disabled_entries = [
|
||||
entry
|
||||
for entry in entries
|
||||
if self._is_api_visible_to_plugin(entry, caller_plugin_id) and not entry.enabled
|
||||
]
|
||||
if len(visible_enabled_entries) == 1:
|
||||
return supervisor, visible_enabled_entries[0], None
|
||||
if len(visible_enabled_entries) > 1:
|
||||
return None, None, f"API {normalized_api_name} 存在多个版本,请显式指定 version"
|
||||
if visible_disabled_entries:
|
||||
if len(visible_disabled_entries) == 1:
|
||||
return None, None, self._build_api_unavailable_error(visible_disabled_entries[0])
|
||||
return None, None, f"API {normalized_api_name} 存在多个已下线版本,请显式指定 version"
|
||||
if any(not self._is_api_visible_to_plugin(entry, caller_plugin_id) for entry in entries):
|
||||
return None, None, f"API {normalized_api_name} 未公开,禁止跨插件调用"
|
||||
if normalized_version:
|
||||
return None, None, f"未找到版本为 {normalized_version} 的 API: {normalized_api_name}"
|
||||
return None, None, f"未找到 API: {normalized_api_name}"
|
||||
|
||||
visible_enabled_matches: List[tuple["PluginSupervisor", "APIEntry"]] = []
|
||||
visible_disabled_matches: List[tuple["PluginSupervisor", "APIEntry"]] = []
|
||||
hidden_match_exists = False
|
||||
for supervisor in self.supervisors:
|
||||
for entry in supervisor.api_registry.get_apis(
|
||||
name=normalized_api_name,
|
||||
version=normalized_version,
|
||||
enabled_only=False,
|
||||
):
|
||||
if self._is_api_visible_to_plugin(entry, caller_plugin_id):
|
||||
if entry.enabled:
|
||||
visible_enabled_matches.append((supervisor, entry))
|
||||
else:
|
||||
visible_disabled_matches.append((supervisor, entry))
|
||||
else:
|
||||
hidden_match_exists = True
|
||||
|
||||
if len(visible_enabled_matches) == 1:
|
||||
return visible_enabled_matches[0][0], visible_enabled_matches[0][1], None
|
||||
if len(visible_enabled_matches) > 1:
|
||||
return None, None, f"API 名称不唯一: {normalized_api_name},请使用 plugin_id.api_name 或显式指定 version"
|
||||
if visible_disabled_matches:
|
||||
if len(visible_disabled_matches) == 1:
|
||||
return None, None, self._build_api_unavailable_error(visible_disabled_matches[0][1])
|
||||
return None, None, f"API {normalized_api_name} 存在多个已下线版本,请使用 plugin_id.api_name@version"
|
||||
if hidden_match_exists:
|
||||
return None, None, f"API {normalized_api_name} 未公开,禁止跨插件调用"
|
||||
if normalized_version:
|
||||
return None, None, f"未找到版本为 {normalized_version} 的 API: {normalized_api_name}"
|
||||
return None, None, f"未找到 API: {normalized_api_name}"
|
||||
|
||||
def _resolve_api_toggle_target(
|
||||
self: _RuntimeComponentManagerProtocol,
|
||||
name: str,
|
||||
version: str = "",
|
||||
) -> tuple[Optional["PluginSupervisor"], Optional["APIEntry"], Optional[str]]:
|
||||
"""解析需要启用或禁用的 API 组件。
|
||||
|
||||
Args:
|
||||
name: API 名称,支持 ``plugin_id.api_name`` 或唯一短名。
|
||||
version: 可选的 API 版本。
|
||||
|
||||
Returns:
|
||||
tuple[Optional[PluginSupervisor], Optional[APIEntry], Optional[str]]:
|
||||
解析成功时返回 ``(监督器, API 条目, None)``,失败时返回错误信息。
|
||||
"""
|
||||
|
||||
normalized_name, normalized_version = self._normalize_api_reference(name, version)
|
||||
if not normalized_name:
|
||||
return None, None, "缺少必要参数 name"
|
||||
|
||||
if "." in normalized_name:
|
||||
plugin_id, api_name = normalized_name.rsplit(".", 1)
|
||||
try:
|
||||
supervisor = self._get_supervisor_for_plugin(plugin_id)
|
||||
except RuntimeError as exc:
|
||||
return None, None, str(exc)
|
||||
|
||||
if supervisor is None:
|
||||
return None, None, f"未找到 API 提供方插件: {plugin_id}"
|
||||
|
||||
entries = supervisor.api_registry.get_apis(
|
||||
plugin_id=plugin_id,
|
||||
name=api_name,
|
||||
version=normalized_version,
|
||||
enabled_only=False,
|
||||
)
|
||||
if len(entries) == 1:
|
||||
return supervisor, entries[0], None
|
||||
if entries:
|
||||
return None, None, f"API {normalized_name} 存在多个版本,请显式指定 version"
|
||||
return None, None, f"未找到 API: {normalized_name}"
|
||||
|
||||
matches: List[tuple["PluginSupervisor", "APIEntry"]] = []
|
||||
for supervisor in self.supervisors:
|
||||
matches.extend(
|
||||
(supervisor, entry)
|
||||
for entry in supervisor.api_registry.get_apis(
|
||||
name=normalized_name,
|
||||
version=normalized_version,
|
||||
enabled_only=False,
|
||||
)
|
||||
)
|
||||
|
||||
if len(matches) == 1:
|
||||
return matches[0][0], matches[0][1], None
|
||||
if len(matches) > 1:
|
||||
return None, None, f"API 名称不唯一: {normalized_name},请使用 plugin_id.api_name 或显式指定 version"
|
||||
return None, None, f"未找到 API: {normalized_name}"
|
||||
|
||||
async def _cap_component_get_all_plugins(
|
||||
self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any]
|
||||
) -> Any:
|
||||
@@ -46,6 +340,10 @@ class RuntimeComponentCapabilityMixin:
|
||||
}
|
||||
for component in comps
|
||||
]
|
||||
components_list.extend(
|
||||
self._serialize_api_component_entry(entry)
|
||||
for entry in sv.api_registry.get_apis(plugin_id=pid, enabled_only=False)
|
||||
)
|
||||
result[pid] = {
|
||||
"name": pid,
|
||||
"version": reg.plugin_version,
|
||||
@@ -96,30 +394,35 @@ class RuntimeComponentCapabilityMixin:
|
||||
|
||||
def _resolve_component_toggle_target(
|
||||
self: _RuntimeComponentManagerProtocol, name: str, component_type: str
|
||||
) -> tuple[Optional["RegisteredComponent"], Optional[str]]:
|
||||
short_name_matches: List["RegisteredComponent"] = []
|
||||
) -> tuple[Optional["ComponentEntry"], Optional[str]]:
|
||||
normalized_component_type = self._normalize_component_type(component_type)
|
||||
short_name_matches: List["ComponentEntry"] = []
|
||||
for sv in self.supervisors:
|
||||
comp = sv.component_registry.get_component(name)
|
||||
if comp is not None and comp.component_type == component_type:
|
||||
if comp is not None and comp.component_type == normalized_component_type:
|
||||
return comp, None
|
||||
|
||||
short_name_matches.extend(
|
||||
candidate
|
||||
for candidate in sv.component_registry.get_components_by_type(component_type, enabled_only=False)
|
||||
for candidate in sv.component_registry.get_components_by_type(
|
||||
normalized_component_type,
|
||||
enabled_only=False,
|
||||
)
|
||||
if candidate.name == name
|
||||
)
|
||||
|
||||
if len(short_name_matches) == 1:
|
||||
return short_name_matches[0], None
|
||||
if len(short_name_matches) > 1:
|
||||
return None, f"组件名不唯一: {name} ({component_type}),请使用完整名 plugin_id.component_name"
|
||||
return None, f"未找到组件: {name} ({component_type})"
|
||||
return None, f"组件名不唯一: {name} ({normalized_component_type}),请使用完整名 plugin_id.component_name"
|
||||
return None, f"未找到组件: {name} ({normalized_component_type})"
|
||||
|
||||
async def _cap_component_enable(
|
||||
self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any]
|
||||
) -> Any:
|
||||
name: str = args.get("name", "")
|
||||
component_type: str = args.get("component_type", "")
|
||||
version: str = args.get("version", "")
|
||||
scope: str = args.get("scope", "global")
|
||||
stream_id: str = args.get("stream_id", "")
|
||||
if not name or not component_type:
|
||||
@@ -127,6 +430,13 @@ class RuntimeComponentCapabilityMixin:
|
||||
if scope != "global" or stream_id:
|
||||
return {"success": False, "error": "当前仅支持全局组件启用,不支持 scope/stream_id 定位"}
|
||||
|
||||
if self._is_api_component_type(component_type):
|
||||
supervisor, api_entry, error = self._resolve_api_toggle_target(name, version)
|
||||
if supervisor is None or api_entry is None:
|
||||
return {"success": False, "error": error or f"未找到 API: {name}"}
|
||||
supervisor.api_registry.toggle_api_status(api_entry.registry_key, True)
|
||||
return {"success": True}
|
||||
|
||||
comp, error = self._resolve_component_toggle_target(name, component_type)
|
||||
if comp is None:
|
||||
return {"success": False, "error": error or f"未找到组件: {name} ({component_type})"}
|
||||
@@ -139,6 +449,7 @@ class RuntimeComponentCapabilityMixin:
|
||||
) -> Any:
|
||||
name: str = args.get("name", "")
|
||||
component_type: str = args.get("component_type", "")
|
||||
version: str = args.get("version", "")
|
||||
scope: str = args.get("scope", "global")
|
||||
stream_id: str = args.get("stream_id", "")
|
||||
if not name or not component_type:
|
||||
@@ -146,6 +457,13 @@ class RuntimeComponentCapabilityMixin:
|
||||
if scope != "global" or stream_id:
|
||||
return {"success": False, "error": "当前仅支持全局组件禁用,不支持 scope/stream_id 定位"}
|
||||
|
||||
if self._is_api_component_type(component_type):
|
||||
supervisor, api_entry, error = self._resolve_api_toggle_target(name, version)
|
||||
if supervisor is None or api_entry is None:
|
||||
return {"success": False, "error": error or f"未找到 API: {name}"}
|
||||
supervisor.api_registry.toggle_api_status(api_entry.registry_key, False)
|
||||
return {"success": True}
|
||||
|
||||
comp, error = self._resolve_component_toggle_target(name, component_type)
|
||||
if comp is None:
|
||||
return {"success": False, "error": error or f"未找到组件: {name} ({component_type})"}
|
||||
@@ -168,33 +486,14 @@ class RuntimeComponentCapabilityMixin:
|
||||
return {"success": False, "error": f"检测到重复插件 ID,拒绝热重载: {details}"}
|
||||
|
||||
try:
|
||||
registered_supervisor = self._get_supervisor_for_plugin(plugin_name)
|
||||
except RuntimeError as exc:
|
||||
return {"success": False, "error": str(exc)}
|
||||
loaded = await self.load_plugin_globally(plugin_name, reason=f"load {plugin_name}")
|
||||
except Exception as e:
|
||||
logger.error(f"[cap.component.load_plugin] 热重载失败: {e}")
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
if registered_supervisor is not None:
|
||||
try:
|
||||
reloaded = await registered_supervisor.reload_plugins(reason=f"load {plugin_name}")
|
||||
if reloaded:
|
||||
return {"success": True, "count": 1}
|
||||
return {"success": False, "error": f"插件 {plugin_name} 热重载失败,已回滚"}
|
||||
except Exception as e:
|
||||
logger.error(f"[cap.component.load_plugin] 热重载失败: {e}")
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
for sv in self.supervisors:
|
||||
for pdir in sv._plugin_dirs:
|
||||
if (pdir / plugin_name).is_dir():
|
||||
try:
|
||||
reloaded = await sv.reload_plugins(reason=f"load {plugin_name}")
|
||||
if reloaded:
|
||||
return {"success": True, "count": 1}
|
||||
return {"success": False, "error": f"插件 {plugin_name} 热重载失败,已回滚"}
|
||||
except Exception as e:
|
||||
logger.error(f"[cap.component.load_plugin] 热重载失败: {e}")
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
return {"success": False, "error": f"未找到插件: {plugin_name}"}
|
||||
if loaded:
|
||||
return {"success": True, "count": 1}
|
||||
return {"success": False, "error": f"插件 {plugin_name} 热重载失败"}
|
||||
|
||||
async def _cap_component_unload_plugin(
|
||||
self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any]
|
||||
@@ -216,17 +515,204 @@ class RuntimeComponentCapabilityMixin:
|
||||
return {"success": False, "error": f"检测到重复插件 ID,拒绝热重载: {details}"}
|
||||
|
||||
try:
|
||||
sv = self._get_supervisor_for_plugin(plugin_name)
|
||||
reloaded = await self.reload_plugins_globally([plugin_name], reason=f"reload {plugin_name}")
|
||||
except Exception as e:
|
||||
logger.error(f"[cap.component.reload_plugin] 热重载失败: {e}")
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
if reloaded:
|
||||
return {"success": True}
|
||||
return {"success": False, "error": f"插件 {plugin_name} 热重载失败"}
|
||||
|
||||
async def _cap_api_call(
|
||||
self: _RuntimeComponentManagerProtocol,
|
||||
plugin_id: str,
|
||||
capability: str,
|
||||
args: Dict[str, Any],
|
||||
) -> Any:
|
||||
"""调用其他插件公开的 API。
|
||||
|
||||
Args:
|
||||
plugin_id: 当前调用方插件 ID。
|
||||
capability: 能力名称。
|
||||
args: 能力参数。
|
||||
|
||||
Returns:
|
||||
Any: API 调用结果。
|
||||
"""
|
||||
|
||||
del capability
|
||||
api_name = str(args.get("api_name", "") or "").strip()
|
||||
version = str(args.get("version", "") or "").strip()
|
||||
api_args = args.get("args", {})
|
||||
if not isinstance(api_args, dict):
|
||||
return {"success": False, "error": "参数 args 必须为字典"}
|
||||
|
||||
supervisor, entry, error = self._resolve_api_target(plugin_id, api_name, version)
|
||||
if supervisor is None or entry is None:
|
||||
return {"success": False, "error": error or "API 解析失败"}
|
||||
|
||||
invoke_args = dict(api_args)
|
||||
if entry.dynamic:
|
||||
invoke_args.setdefault("__maibot_api_name__", entry.name)
|
||||
invoke_args.setdefault("__maibot_api_full_name__", entry.full_name)
|
||||
invoke_args.setdefault("__maibot_api_version__", entry.version)
|
||||
|
||||
try:
|
||||
response = await supervisor.invoke_api(
|
||||
plugin_id=entry.plugin_id,
|
||||
component_name=entry.handler_name,
|
||||
args=invoke_args,
|
||||
timeout_ms=30000,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"[cap.api.call] 调用 API {entry.full_name} 失败: {exc}", exc_info=True)
|
||||
return {"success": False, "error": str(exc)}
|
||||
|
||||
if response.error:
|
||||
return {"success": False, "error": response.error.get("message", "API 调用失败")}
|
||||
|
||||
payload = response.payload if isinstance(response.payload, dict) else {}
|
||||
if not bool(payload.get("success", False)):
|
||||
result = payload.get("result")
|
||||
return {"success": False, "error": "" if result is None else str(result)}
|
||||
return {"success": True, "result": payload.get("result")}
|
||||
|
||||
async def _cap_api_get(
|
||||
self: _RuntimeComponentManagerProtocol,
|
||||
plugin_id: str,
|
||||
capability: str,
|
||||
args: Dict[str, Any],
|
||||
) -> Any:
|
||||
"""获取当前插件可见的单个 API 元信息。
|
||||
|
||||
Args:
|
||||
plugin_id: 当前调用方插件 ID。
|
||||
capability: 能力名称。
|
||||
args: 能力参数。
|
||||
|
||||
Returns:
|
||||
Any: API 元信息或 ``None``。
|
||||
"""
|
||||
|
||||
del capability
|
||||
api_name = str(args.get("api_name", "") or "").strip()
|
||||
version = str(args.get("version", "") or "").strip()
|
||||
if not api_name:
|
||||
return {"success": False, "error": "缺少必要参数 api_name"}
|
||||
|
||||
supervisor, entry, _error = self._resolve_api_target(plugin_id, api_name, version)
|
||||
if supervisor is None or entry is None:
|
||||
return {"success": True, "api": None}
|
||||
return {"success": True, "api": self._serialize_api_entry(entry)}
|
||||
|
||||
async def _cap_api_list(
|
||||
self: _RuntimeComponentManagerProtocol,
|
||||
plugin_id: str,
|
||||
capability: str,
|
||||
args: Dict[str, Any],
|
||||
) -> Any:
|
||||
"""列出当前插件可见的 API 列表。
|
||||
|
||||
Args:
|
||||
plugin_id: 当前调用方插件 ID。
|
||||
capability: 能力名称。
|
||||
args: 能力参数。
|
||||
|
||||
Returns:
|
||||
Any: API 元信息列表。
|
||||
"""
|
||||
|
||||
del capability
|
||||
target_plugin_id = str(args.get("plugin_id", "") or "").strip()
|
||||
api_name, version = self._normalize_api_reference(
|
||||
str(args.get("api_name", args.get("name", "")) or ""),
|
||||
str(args.get("version", "") or ""),
|
||||
)
|
||||
apis: List[Dict[str, Any]] = []
|
||||
for supervisor in self.supervisors:
|
||||
apis.extend(
|
||||
self._serialize_api_entry(entry)
|
||||
for entry in supervisor.api_registry.get_apis(
|
||||
plugin_id=target_plugin_id or None,
|
||||
name=api_name,
|
||||
version=version,
|
||||
enabled_only=True,
|
||||
)
|
||||
if self._is_api_visible_to_plugin(entry, plugin_id)
|
||||
)
|
||||
|
||||
apis.sort(key=lambda item: (str(item["plugin_id"]), str(item["name"]), str(item["version"])))
|
||||
return {"success": True, "apis": apis}
|
||||
|
||||
async def _cap_api_replace_dynamic(
|
||||
self: _RuntimeComponentManagerProtocol,
|
||||
plugin_id: str,
|
||||
capability: str,
|
||||
args: Dict[str, Any],
|
||||
) -> Any:
|
||||
"""替换插件自行维护的动态 API 列表。"""
|
||||
|
||||
del capability
|
||||
raw_apis = args.get("apis", [])
|
||||
offline_reason = str(args.get("offline_reason", "") or "").strip() or "动态 API 已下线"
|
||||
if not isinstance(raw_apis, list):
|
||||
return {"success": False, "error": "参数 apis 必须为列表"}
|
||||
|
||||
try:
|
||||
supervisor = self._get_supervisor_for_plugin(plugin_id)
|
||||
except RuntimeError as exc:
|
||||
return {"success": False, "error": str(exc)}
|
||||
|
||||
if sv is not None:
|
||||
try:
|
||||
reloaded = await sv.reload_plugins(reason=f"reload {plugin_name}")
|
||||
if reloaded:
|
||||
return {"success": True}
|
||||
return {"success": False, "error": f"插件 {plugin_name} 热重载失败,已回滚"}
|
||||
except Exception as e:
|
||||
logger.error(f"[cap.component.reload_plugin] 热重载失败: {e}")
|
||||
return {"success": False, "error": str(e)}
|
||||
return {"success": False, "error": f"未找到插件: {plugin_name}"}
|
||||
if supervisor is None:
|
||||
return {"success": False, "error": f"未找到插件: {plugin_id}"}
|
||||
|
||||
normalized_components: List[Dict[str, Any]] = []
|
||||
seen_registry_keys: set[str] = set()
|
||||
for index, raw_api in enumerate(raw_apis):
|
||||
if not isinstance(raw_api, dict):
|
||||
return {"success": False, "error": f"apis[{index}] 必须为字典"}
|
||||
|
||||
api_name = str(raw_api.get("name", "") or "").strip()
|
||||
component_type = str(raw_api.get("component_type", raw_api.get("type", "API")) or "").strip()
|
||||
if not api_name:
|
||||
return {"success": False, "error": f"apis[{index}] 缺少 name"}
|
||||
if not self._is_api_component_type(component_type):
|
||||
return {"success": False, "error": f"apis[{index}] 不是 API 组件"}
|
||||
|
||||
metadata = raw_api.get("metadata", {}) if isinstance(raw_api.get("metadata"), dict) else {}
|
||||
normalized_metadata = dict(metadata)
|
||||
normalized_metadata["dynamic"] = True
|
||||
version = str(normalized_metadata.get("version", "1") or "1").strip() or "1"
|
||||
registry_key = supervisor.api_registry.build_registry_key(plugin_id, api_name, version)
|
||||
if registry_key in seen_registry_keys:
|
||||
return {"success": False, "error": f"动态 API 重复声明: {registry_key}"}
|
||||
seen_registry_keys.add(registry_key)
|
||||
|
||||
existing_entry = supervisor.api_registry.get_api(
|
||||
plugin_id,
|
||||
api_name,
|
||||
version=version,
|
||||
enabled_only=False,
|
||||
)
|
||||
if existing_entry is not None and not existing_entry.dynamic:
|
||||
return {"success": False, "error": f"动态 API 不能覆盖静态 API: {registry_key}"}
|
||||
|
||||
normalized_components.append(
|
||||
{
|
||||
"name": api_name,
|
||||
"component_type": "API",
|
||||
"metadata": normalized_metadata,
|
||||
}
|
||||
)
|
||||
|
||||
registered_count, offlined_count = supervisor.api_registry.replace_plugin_dynamic_apis(
|
||||
plugin_id,
|
||||
normalized_components,
|
||||
offline_reason=offline_reason,
|
||||
)
|
||||
return {
|
||||
"success": True,
|
||||
"count": registered_count,
|
||||
"offlined": offlined_count,
|
||||
}
|
||||
|
||||
@@ -238,14 +238,14 @@ class RuntimeCoreCapabilityMixin:
|
||||
return {"success": False, "value": None, "error": str(e)}
|
||||
|
||||
async def _cap_config_get_plugin(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
from src.core.component_registry import component_registry as core_registry
|
||||
from src.plugin_runtime.component_query import component_query_service
|
||||
|
||||
plugin_name: str = args.get("plugin_name", plugin_id)
|
||||
key: str = args.get("key", "")
|
||||
default = args.get("default")
|
||||
|
||||
try:
|
||||
config = core_registry.get_plugin_config(plugin_name)
|
||||
config = component_query_service.get_plugin_config(plugin_name)
|
||||
if config is None:
|
||||
return {"success": False, "value": default, "error": f"未找到插件 {plugin_name} 的配置"}
|
||||
|
||||
@@ -258,11 +258,11 @@ class RuntimeCoreCapabilityMixin:
|
||||
return {"success": False, "value": default, "error": str(e)}
|
||||
|
||||
async def _cap_config_get_all(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
from src.core.component_registry import component_registry as core_registry
|
||||
from src.plugin_runtime.component_query import component_query_service
|
||||
|
||||
plugin_name: str = args.get("plugin_name", plugin_id)
|
||||
try:
|
||||
config = core_registry.get_plugin_config(plugin_name)
|
||||
config = component_query_service.get_plugin_config(plugin_name)
|
||||
if config is None:
|
||||
return {"success": True, "value": {}}
|
||||
return {"success": True, "value": config}
|
||||
|
||||
@@ -648,10 +648,10 @@ class RuntimeDataCapabilityMixin:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
async def _cap_tool_get_definitions(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
from src.core.component_registry import component_registry as core_registry
|
||||
from src.plugin_runtime.component_query import component_query_service
|
||||
|
||||
try:
|
||||
tools = core_registry.get_llm_available_tools()
|
||||
tools = component_query_service.get_llm_available_tools()
|
||||
return {
|
||||
"success": True,
|
||||
"tools": [{"name": name, "definition": info.get_llm_definition()} for name, info in tools.items()],
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_runtime.host.capability_service import CapabilityImpl
|
||||
from src.plugin_runtime.host.supervisor import PluginSupervisor
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -13,66 +14,80 @@ def register_capability_impls(manager: "PluginRuntimeManager", supervisor: Plugi
|
||||
"""向指定 Supervisor 注册主程序提供的能力实现。"""
|
||||
cap_service = supervisor.capability_service
|
||||
|
||||
cap_service.register_capability("send.text", manager._cap_send_text)
|
||||
cap_service.register_capability("send.emoji", manager._cap_send_emoji)
|
||||
cap_service.register_capability("send.image", manager._cap_send_image)
|
||||
cap_service.register_capability("send.command", manager._cap_send_command)
|
||||
cap_service.register_capability("send.custom", manager._cap_send_custom)
|
||||
def _register(name: str, impl: CapabilityImpl) -> None:
|
||||
"""注册单个能力实现。
|
||||
|
||||
cap_service.register_capability("llm.generate", manager._cap_llm_generate)
|
||||
cap_service.register_capability("llm.generate_with_tools", manager._cap_llm_generate_with_tools)
|
||||
cap_service.register_capability("llm.get_available_models", manager._cap_llm_get_available_models)
|
||||
Args:
|
||||
name: 能力名称。
|
||||
impl: 能力实现函数。
|
||||
"""
|
||||
cap_service.register_capability(name, impl)
|
||||
|
||||
cap_service.register_capability("config.get", manager._cap_config_get)
|
||||
cap_service.register_capability("config.get_plugin", manager._cap_config_get_plugin)
|
||||
cap_service.register_capability("config.get_all", manager._cap_config_get_all)
|
||||
_register("send.text", manager._cap_send_text)
|
||||
_register("send.emoji", manager._cap_send_emoji)
|
||||
_register("send.image", manager._cap_send_image)
|
||||
_register("send.command", manager._cap_send_command)
|
||||
_register("send.custom", manager._cap_send_custom)
|
||||
|
||||
cap_service.register_capability("database.query", manager._cap_database_query)
|
||||
cap_service.register_capability("database.save", manager._cap_database_save)
|
||||
cap_service.register_capability("database.get", manager._cap_database_get)
|
||||
cap_service.register_capability("database.delete", manager._cap_database_delete)
|
||||
cap_service.register_capability("database.count", manager._cap_database_count)
|
||||
_register("llm.generate", manager._cap_llm_generate)
|
||||
_register("llm.generate_with_tools", manager._cap_llm_generate_with_tools)
|
||||
_register("llm.get_available_models", manager._cap_llm_get_available_models)
|
||||
|
||||
cap_service.register_capability("chat.get_all_streams", manager._cap_chat_get_all_streams)
|
||||
cap_service.register_capability("chat.get_group_streams", manager._cap_chat_get_group_streams)
|
||||
cap_service.register_capability("chat.get_private_streams", manager._cap_chat_get_private_streams)
|
||||
cap_service.register_capability("chat.get_stream_by_group_id", manager._cap_chat_get_stream_by_group_id)
|
||||
cap_service.register_capability("chat.get_stream_by_user_id", manager._cap_chat_get_stream_by_user_id)
|
||||
_register("config.get", manager._cap_config_get)
|
||||
_register("config.get_plugin", manager._cap_config_get_plugin)
|
||||
_register("config.get_all", manager._cap_config_get_all)
|
||||
|
||||
cap_service.register_capability("message.get_by_time", manager._cap_message_get_by_time)
|
||||
cap_service.register_capability("message.get_by_time_in_chat", manager._cap_message_get_by_time_in_chat)
|
||||
cap_service.register_capability("message.get_recent", manager._cap_message_get_recent)
|
||||
cap_service.register_capability("message.count_new", manager._cap_message_count_new)
|
||||
cap_service.register_capability("message.build_readable", manager._cap_message_build_readable)
|
||||
_register("database.query", manager._cap_database_query)
|
||||
_register("database.save", manager._cap_database_save)
|
||||
_register("database.get", manager._cap_database_get)
|
||||
_register("database.delete", manager._cap_database_delete)
|
||||
_register("database.count", manager._cap_database_count)
|
||||
|
||||
cap_service.register_capability("person.get_id", manager._cap_person_get_id)
|
||||
cap_service.register_capability("person.get_value", manager._cap_person_get_value)
|
||||
cap_service.register_capability("person.get_id_by_name", manager._cap_person_get_id_by_name)
|
||||
_register("chat.get_all_streams", manager._cap_chat_get_all_streams)
|
||||
_register("chat.get_group_streams", manager._cap_chat_get_group_streams)
|
||||
_register("chat.get_private_streams", manager._cap_chat_get_private_streams)
|
||||
_register("chat.get_stream_by_group_id", manager._cap_chat_get_stream_by_group_id)
|
||||
_register("chat.get_stream_by_user_id", manager._cap_chat_get_stream_by_user_id)
|
||||
|
||||
cap_service.register_capability("emoji.get_by_description", manager._cap_emoji_get_by_description)
|
||||
cap_service.register_capability("emoji.get_random", manager._cap_emoji_get_random)
|
||||
cap_service.register_capability("emoji.get_count", manager._cap_emoji_get_count)
|
||||
cap_service.register_capability("emoji.get_emotions", manager._cap_emoji_get_emotions)
|
||||
cap_service.register_capability("emoji.get_all", manager._cap_emoji_get_all)
|
||||
cap_service.register_capability("emoji.get_info", manager._cap_emoji_get_info)
|
||||
cap_service.register_capability("emoji.register", manager._cap_emoji_register)
|
||||
cap_service.register_capability("emoji.delete", manager._cap_emoji_delete)
|
||||
_register("message.get_by_time", manager._cap_message_get_by_time)
|
||||
_register("message.get_by_time_in_chat", manager._cap_message_get_by_time_in_chat)
|
||||
_register("message.get_recent", manager._cap_message_get_recent)
|
||||
_register("message.count_new", manager._cap_message_count_new)
|
||||
_register("message.build_readable", manager._cap_message_build_readable)
|
||||
|
||||
cap_service.register_capability("frequency.get_current_talk_value", manager._cap_frequency_get_current_talk_value)
|
||||
cap_service.register_capability("frequency.set_adjust", manager._cap_frequency_set_adjust)
|
||||
cap_service.register_capability("frequency.get_adjust", manager._cap_frequency_get_adjust)
|
||||
_register("person.get_id", manager._cap_person_get_id)
|
||||
_register("person.get_value", manager._cap_person_get_value)
|
||||
_register("person.get_id_by_name", manager._cap_person_get_id_by_name)
|
||||
|
||||
cap_service.register_capability("tool.get_definitions", manager._cap_tool_get_definitions)
|
||||
_register("emoji.get_by_description", manager._cap_emoji_get_by_description)
|
||||
_register("emoji.get_random", manager._cap_emoji_get_random)
|
||||
_register("emoji.get_count", manager._cap_emoji_get_count)
|
||||
_register("emoji.get_emotions", manager._cap_emoji_get_emotions)
|
||||
_register("emoji.get_all", manager._cap_emoji_get_all)
|
||||
_register("emoji.get_info", manager._cap_emoji_get_info)
|
||||
_register("emoji.register", manager._cap_emoji_register)
|
||||
_register("emoji.delete", manager._cap_emoji_delete)
|
||||
|
||||
cap_service.register_capability("component.get_all_plugins", manager._cap_component_get_all_plugins)
|
||||
cap_service.register_capability("component.get_plugin_info", manager._cap_component_get_plugin_info)
|
||||
cap_service.register_capability("component.list_loaded_plugins", manager._cap_component_list_loaded_plugins)
|
||||
cap_service.register_capability("component.list_registered_plugins", manager._cap_component_list_registered_plugins)
|
||||
cap_service.register_capability("component.enable", manager._cap_component_enable)
|
||||
cap_service.register_capability("component.disable", manager._cap_component_disable)
|
||||
cap_service.register_capability("component.load_plugin", manager._cap_component_load_plugin)
|
||||
cap_service.register_capability("component.unload_plugin", manager._cap_component_unload_plugin)
|
||||
cap_service.register_capability("component.reload_plugin", manager._cap_component_reload_plugin)
|
||||
_register("frequency.get_current_talk_value", manager._cap_frequency_get_current_talk_value)
|
||||
_register("frequency.set_adjust", manager._cap_frequency_set_adjust)
|
||||
_register("frequency.get_adjust", manager._cap_frequency_get_adjust)
|
||||
|
||||
cap_service.register_capability("knowledge.search", manager._cap_knowledge_search)
|
||||
_register("tool.get_definitions", manager._cap_tool_get_definitions)
|
||||
|
||||
_register("api.call", manager._cap_api_call)
|
||||
_register("api.get", manager._cap_api_get)
|
||||
_register("api.list", manager._cap_api_list)
|
||||
_register("api.replace_dynamic", manager._cap_api_replace_dynamic)
|
||||
|
||||
_register("component.get_all_plugins", manager._cap_component_get_all_plugins)
|
||||
_register("component.get_plugin_info", manager._cap_component_get_plugin_info)
|
||||
_register("component.list_loaded_plugins", manager._cap_component_list_loaded_plugins)
|
||||
_register("component.list_registered_plugins", manager._cap_component_list_registered_plugins)
|
||||
_register("component.enable", manager._cap_component_enable)
|
||||
_register("component.disable", manager._cap_component_disable)
|
||||
_register("component.load_plugin", manager._cap_component_load_plugin)
|
||||
_register("component.unload_plugin", manager._cap_component_unload_plugin)
|
||||
_register("component.reload_plugin", manager._cap_component_reload_plugin)
|
||||
|
||||
_register("knowledge.search", manager._cap_knowledge_search)
|
||||
logger.debug("已注册全部主程序能力实现")
|
||||
|
||||
709
src/plugin_runtime/component_query.py
Normal file
709
src/plugin_runtime/component_query.py
Normal file
@@ -0,0 +1,709 @@
|
||||
"""插件运行时统一组件查询服务。
|
||||
|
||||
该模块统一从插件运行时的 Host ComponentRegistry 中聚合只读视图,
|
||||
供 HFC/PFC、Planner、ToolExecutor 和运行时能力层查询与调用。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.core.types import ActionActivationType, ActionInfo, CommandInfo, ComponentInfo, ComponentType, ToolInfo
|
||||
from src.llm_models.payload_content.tool_option import ToolParamType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.plugin_runtime.host.component_registry import ActionEntry, CommandEntry, ComponentEntry, ToolEntry
|
||||
from src.plugin_runtime.host.supervisor import PluginSupervisor
|
||||
from src.plugin_runtime.integration import PluginRuntimeManager
|
||||
|
||||
logger = get_logger("plugin_runtime.component_query")
|
||||
|
||||
ActionExecutor = Callable[..., Awaitable[Any]]
|
||||
CommandExecutor = Callable[..., Awaitable[Tuple[bool, Optional[str], bool]]]
|
||||
ToolExecutor = Callable[..., Awaitable[Any]]
|
||||
|
||||
_HOST_COMPONENT_TYPE_MAP: Dict[ComponentType, str] = {
|
||||
ComponentType.ACTION: "ACTION",
|
||||
ComponentType.COMMAND: "COMMAND",
|
||||
ComponentType.TOOL: "TOOL",
|
||||
}
|
||||
_TOOL_PARAM_TYPE_MAP: Dict[str, ToolParamType] = {
|
||||
"string": ToolParamType.STRING,
|
||||
"integer": ToolParamType.INTEGER,
|
||||
"float": ToolParamType.FLOAT,
|
||||
"boolean": ToolParamType.BOOLEAN,
|
||||
"bool": ToolParamType.BOOLEAN,
|
||||
}
|
||||
|
||||
|
||||
class ComponentQueryService:
|
||||
"""插件运行时统一组件查询服务。
|
||||
|
||||
该对象不维护独立状态,只读取插件系统中的注册结果。
|
||||
所有注册、删除、配置写入等写操作都被显式禁用。
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _get_runtime_manager() -> "PluginRuntimeManager":
|
||||
"""获取插件运行时管理器单例。
|
||||
|
||||
Returns:
|
||||
PluginRuntimeManager: 当前全局插件运行时管理器。
|
||||
"""
|
||||
|
||||
from src.plugin_runtime.integration import get_plugin_runtime_manager
|
||||
|
||||
return get_plugin_runtime_manager()
|
||||
|
||||
def _iter_supervisors(self) -> list["PluginSupervisor"]:
|
||||
"""获取当前所有活跃的插件运行时监督器。
|
||||
|
||||
Returns:
|
||||
list[PluginSupervisor]: 当前运行中的监督器列表。
|
||||
"""
|
||||
|
||||
runtime_manager = self._get_runtime_manager()
|
||||
return list(runtime_manager.supervisors)
|
||||
|
||||
def _iter_component_entries(
|
||||
self,
|
||||
component_type: ComponentType,
|
||||
*,
|
||||
enabled_only: bool = True,
|
||||
) -> list[tuple["PluginSupervisor", "ComponentEntry"]]:
|
||||
"""遍历指定类型的全部组件条目。
|
||||
|
||||
Args:
|
||||
component_type: 目标组件类型。
|
||||
enabled_only: 是否仅返回启用状态的组件。
|
||||
|
||||
Returns:
|
||||
list[tuple[PluginSupervisor, ComponentEntry]]: ``(监督器, 组件条目)`` 列表。
|
||||
"""
|
||||
|
||||
host_component_type = _HOST_COMPONENT_TYPE_MAP.get(component_type)
|
||||
if host_component_type is None:
|
||||
return []
|
||||
|
||||
collected_entries: list[tuple["PluginSupervisor", "ComponentEntry"]] = []
|
||||
for supervisor in self._iter_supervisors():
|
||||
for component in supervisor.component_registry.get_components_by_type(
|
||||
host_component_type,
|
||||
enabled_only=enabled_only,
|
||||
):
|
||||
collected_entries.append((supervisor, component))
|
||||
return collected_entries
|
||||
|
||||
@staticmethod
|
||||
def _coerce_action_activation_type(raw_value: Any) -> ActionActivationType:
|
||||
"""规范化动作激活类型。
|
||||
|
||||
Args:
|
||||
raw_value: 原始激活类型值。
|
||||
|
||||
Returns:
|
||||
ActionActivationType: 规范化后的激活类型枚举。
|
||||
"""
|
||||
|
||||
normalized_value = str(raw_value or "").strip().lower()
|
||||
if normalized_value == ActionActivationType.NEVER.value:
|
||||
return ActionActivationType.NEVER
|
||||
if normalized_value == ActionActivationType.RANDOM.value:
|
||||
return ActionActivationType.RANDOM
|
||||
if normalized_value == ActionActivationType.KEYWORD.value:
|
||||
return ActionActivationType.KEYWORD
|
||||
return ActionActivationType.ALWAYS
|
||||
|
||||
@staticmethod
|
||||
def _coerce_float(value: Any, default: float = 0.0) -> float:
|
||||
"""将任意值安全转换为浮点数。
|
||||
|
||||
Args:
|
||||
value: 待转换的输入值。
|
||||
default: 转换失败时返回的默认值。
|
||||
|
||||
Returns:
|
||||
float: 转换后的浮点结果。
|
||||
"""
|
||||
|
||||
try:
|
||||
return float(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
@staticmethod
|
||||
def _build_action_info(entry: "ActionEntry") -> ActionInfo:
|
||||
"""将运行时 Action 条目转换为核心动作信息。
|
||||
|
||||
Args:
|
||||
entry: 插件运行时中的 Action 条目。
|
||||
|
||||
Returns:
|
||||
ActionInfo: 供核心 Planner 使用的动作信息。
|
||||
"""
|
||||
|
||||
metadata = dict(entry.metadata)
|
||||
raw_action_parameters = metadata.get("action_parameters")
|
||||
action_parameters = (
|
||||
{
|
||||
str(param_name): str(param_description)
|
||||
for param_name, param_description in raw_action_parameters.items()
|
||||
}
|
||||
if isinstance(raw_action_parameters, dict)
|
||||
else {}
|
||||
)
|
||||
action_require = [
|
||||
str(item)
|
||||
for item in (metadata.get("action_require") or [])
|
||||
if item is not None and str(item).strip()
|
||||
]
|
||||
associated_types = [
|
||||
str(item)
|
||||
for item in (metadata.get("associated_types") or [])
|
||||
if item is not None and str(item).strip()
|
||||
]
|
||||
activation_keywords = [
|
||||
str(item)
|
||||
for item in (metadata.get("activation_keywords") or [])
|
||||
if item is not None and str(item).strip()
|
||||
]
|
||||
|
||||
return ActionInfo(
|
||||
name=entry.name,
|
||||
component_type=ComponentType.ACTION,
|
||||
description=str(metadata.get("description", "") or ""),
|
||||
enabled=bool(entry.enabled),
|
||||
plugin_name=entry.plugin_id,
|
||||
metadata=metadata,
|
||||
action_parameters=action_parameters,
|
||||
action_require=action_require,
|
||||
associated_types=associated_types,
|
||||
activation_type=ComponentQueryService._coerce_action_activation_type(metadata.get("activation_type")),
|
||||
random_activation_probability=ComponentQueryService._coerce_float(
|
||||
metadata.get("activation_probability"),
|
||||
0.0,
|
||||
),
|
||||
activation_keywords=activation_keywords,
|
||||
parallel_action=bool(metadata.get("parallel_action", False)),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _build_command_info(entry: "CommandEntry") -> CommandInfo:
|
||||
"""将运行时 Command 条目转换为核心命令信息。
|
||||
|
||||
Args:
|
||||
entry: 插件运行时中的 Command 条目。
|
||||
|
||||
Returns:
|
||||
CommandInfo: 供核心命令链使用的命令信息。
|
||||
"""
|
||||
|
||||
metadata = dict(entry.metadata)
|
||||
return CommandInfo(
|
||||
name=entry.name,
|
||||
component_type=ComponentType.COMMAND,
|
||||
description=str(metadata.get("description", "") or ""),
|
||||
enabled=bool(entry.enabled),
|
||||
plugin_name=entry.plugin_id,
|
||||
metadata=metadata,
|
||||
command_pattern=str(metadata.get("command_pattern", "") or ""),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _coerce_tool_param_type(raw_value: Any) -> ToolParamType:
|
||||
"""规范化工具参数类型。
|
||||
|
||||
Args:
|
||||
raw_value: 原始工具参数类型值。
|
||||
|
||||
Returns:
|
||||
ToolParamType: 规范化后的工具参数类型。
|
||||
"""
|
||||
|
||||
normalized_value = str(raw_value or "").strip().lower()
|
||||
return _TOOL_PARAM_TYPE_MAP.get(normalized_value, ToolParamType.STRING)
|
||||
|
||||
@staticmethod
|
||||
def _build_tool_parameters(entry: "ToolEntry") -> list[tuple[str, ToolParamType, str, bool, list[str] | None]]:
|
||||
"""将运行时工具参数元数据转换为核心 ToolInfo 参数列表。
|
||||
|
||||
Args:
|
||||
entry: 插件运行时中的 Tool 条目。
|
||||
|
||||
Returns:
|
||||
list[tuple[str, ToolParamType, str, bool, list[str] | None]]: 转换后的参数列表。
|
||||
"""
|
||||
|
||||
structured_parameters = entry.parameters if isinstance(entry.parameters, list) else []
|
||||
if not structured_parameters and isinstance(entry.parameters_raw, dict):
|
||||
structured_parameters = [
|
||||
{"name": key, **value}
|
||||
for key, value in entry.parameters_raw.items()
|
||||
if isinstance(value, dict)
|
||||
]
|
||||
|
||||
normalized_parameters: list[tuple[str, ToolParamType, str, bool, list[str] | None]] = []
|
||||
for parameter in structured_parameters:
|
||||
if not isinstance(parameter, dict):
|
||||
continue
|
||||
|
||||
parameter_name = str(parameter.get("name", "") or "").strip()
|
||||
if not parameter_name:
|
||||
continue
|
||||
|
||||
enum_values = parameter.get("enum")
|
||||
normalized_enum_values = (
|
||||
[str(item) for item in enum_values if item is not None]
|
||||
if isinstance(enum_values, list)
|
||||
else None
|
||||
)
|
||||
normalized_parameters.append(
|
||||
(
|
||||
parameter_name,
|
||||
ComponentQueryService._coerce_tool_param_type(parameter.get("param_type") or parameter.get("type")),
|
||||
str(parameter.get("description", "") or ""),
|
||||
bool(parameter.get("required", True)),
|
||||
normalized_enum_values,
|
||||
)
|
||||
)
|
||||
return normalized_parameters
|
||||
|
||||
@staticmethod
|
||||
def _build_tool_info(entry: "ToolEntry") -> ToolInfo:
|
||||
"""将运行时 Tool 条目转换为核心工具信息。
|
||||
|
||||
Args:
|
||||
entry: 插件运行时中的 Tool 条目。
|
||||
|
||||
Returns:
|
||||
ToolInfo: 供 ToolExecutor 与能力层使用的工具信息。
|
||||
"""
|
||||
|
||||
return ToolInfo(
|
||||
name=entry.name,
|
||||
component_type=ComponentType.TOOL,
|
||||
description=entry.description,
|
||||
enabled=bool(entry.enabled),
|
||||
plugin_name=entry.plugin_id,
|
||||
metadata=dict(entry.metadata),
|
||||
tool_parameters=ComponentQueryService._build_tool_parameters(entry),
|
||||
tool_description=entry.description,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _log_duplicate_component(component_type: ComponentType, component_name: str) -> None:
|
||||
"""记录重复组件名称冲突。
|
||||
|
||||
Args:
|
||||
component_type: 组件类型。
|
||||
component_name: 发生冲突的组件名称。
|
||||
"""
|
||||
|
||||
logger.warning(f"检测到重复{component_type.value}名称 {component_name},将只保留首个匹配项")
|
||||
|
||||
def _get_unique_component_entry(
|
||||
self,
|
||||
component_type: ComponentType,
|
||||
name: str,
|
||||
) -> Optional[tuple["PluginSupervisor", "ComponentEntry"]]:
|
||||
"""按组件短名解析唯一条目。
|
||||
|
||||
Args:
|
||||
component_type: 目标组件类型。
|
||||
name: 组件短名。
|
||||
|
||||
Returns:
|
||||
Optional[tuple[PluginSupervisor, ComponentEntry]]: 唯一命中的组件条目。
|
||||
"""
|
||||
|
||||
matched_entries = [
|
||||
(supervisor, entry)
|
||||
for supervisor, entry in self._iter_component_entries(component_type)
|
||||
if entry.name == name
|
||||
]
|
||||
if not matched_entries:
|
||||
return None
|
||||
if len(matched_entries) > 1:
|
||||
self._log_duplicate_component(component_type, name)
|
||||
return matched_entries[0]
|
||||
|
||||
def _collect_unique_component_infos(
|
||||
self,
|
||||
component_type: ComponentType,
|
||||
) -> Dict[str, ComponentInfo]:
|
||||
"""收集某类组件的唯一信息视图。
|
||||
|
||||
Args:
|
||||
component_type: 目标组件类型。
|
||||
|
||||
Returns:
|
||||
Dict[str, ComponentInfo]: 组件名到核心组件信息的映射。
|
||||
"""
|
||||
|
||||
collected_components: Dict[str, ComponentInfo] = {}
|
||||
for _supervisor, entry in self._iter_component_entries(component_type):
|
||||
if entry.name in collected_components:
|
||||
self._log_duplicate_component(component_type, entry.name)
|
||||
continue
|
||||
|
||||
if component_type == ComponentType.ACTION:
|
||||
collected_components[entry.name] = self._build_action_info(entry) # type: ignore[arg-type]
|
||||
elif component_type == ComponentType.COMMAND:
|
||||
collected_components[entry.name] = self._build_command_info(entry) # type: ignore[arg-type]
|
||||
elif component_type == ComponentType.TOOL:
|
||||
collected_components[entry.name] = self._build_tool_info(entry) # type: ignore[arg-type]
|
||||
return collected_components
|
||||
|
||||
@staticmethod
|
||||
def _extract_stream_id_from_action_kwargs(kwargs: Dict[str, Any]) -> str:
|
||||
"""从旧 ActionManager 参数中提取聊天流 ID。
|
||||
|
||||
Args:
|
||||
kwargs: 旧动作执行器收到的关键字参数。
|
||||
|
||||
Returns:
|
||||
str: 提取出的 ``stream_id``。
|
||||
"""
|
||||
|
||||
chat_stream = kwargs.get("chat_stream")
|
||||
if chat_stream is not None:
|
||||
try:
|
||||
return str(chat_stream.session_id)
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
return str(kwargs.get("stream_id", "") or "")
|
||||
|
||||
@staticmethod
|
||||
def _build_action_executor(supervisor: "PluginSupervisor", plugin_id: str, component_name: str) -> ActionExecutor:
|
||||
"""构造动作执行 RPC 闭包。
|
||||
|
||||
Args:
|
||||
supervisor: 负责该组件的监督器。
|
||||
plugin_id: 插件 ID。
|
||||
component_name: 组件名称。
|
||||
|
||||
Returns:
|
||||
ActionExecutor: 兼容旧 Planner 的异步执行器。
|
||||
"""
|
||||
|
||||
async def _executor(**kwargs: Any) -> tuple[bool, str]:
|
||||
"""将核心动作调用桥接到插件运行时。
|
||||
|
||||
Args:
|
||||
**kwargs: 旧 ActionManager 传入的上下文参数。
|
||||
|
||||
Returns:
|
||||
tuple[bool, str]: ``(是否成功, 结果说明)``。
|
||||
"""
|
||||
|
||||
invoke_args: Dict[str, Any] = {}
|
||||
action_data = kwargs.get("action_data")
|
||||
if isinstance(action_data, dict):
|
||||
invoke_args.update(action_data)
|
||||
|
||||
stream_id = ComponentQueryService._extract_stream_id_from_action_kwargs(kwargs)
|
||||
invoke_args["action_data"] = action_data if isinstance(action_data, dict) else {}
|
||||
invoke_args["stream_id"] = stream_id
|
||||
invoke_args["chat_id"] = stream_id
|
||||
invoke_args["reasoning"] = str(kwargs.get("action_reasoning", "") or "")
|
||||
|
||||
if (thinking_id := kwargs.get("thinking_id")) is not None:
|
||||
invoke_args["thinking_id"] = str(thinking_id)
|
||||
if isinstance(kwargs.get("cycle_timers"), dict):
|
||||
invoke_args["cycle_timers"] = kwargs["cycle_timers"]
|
||||
if isinstance(kwargs.get("plugin_config"), dict):
|
||||
invoke_args["plugin_config"] = kwargs["plugin_config"]
|
||||
if isinstance(kwargs.get("log_prefix"), str):
|
||||
invoke_args["log_prefix"] = kwargs["log_prefix"]
|
||||
if isinstance(kwargs.get("shutting_down"), bool):
|
||||
invoke_args["shutting_down"] = kwargs["shutting_down"]
|
||||
|
||||
try:
|
||||
response = await supervisor.invoke_plugin(
|
||||
method="plugin.invoke_action",
|
||||
plugin_id=plugin_id,
|
||||
component_name=component_name,
|
||||
args=invoke_args,
|
||||
timeout_ms=30000,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"运行时 Action {plugin_id}.{component_name} 执行失败: {exc}", exc_info=True)
|
||||
return False, str(exc)
|
||||
|
||||
payload = response.payload if isinstance(response.payload, dict) else {}
|
||||
success = bool(payload.get("success", False))
|
||||
result = payload.get("result")
|
||||
if isinstance(result, (list, tuple)):
|
||||
if len(result) >= 2:
|
||||
return bool(result[0]), "" if result[1] is None else str(result[1])
|
||||
if len(result) == 1:
|
||||
return bool(result[0]), ""
|
||||
if success:
|
||||
return True, "" if result is None else str(result)
|
||||
return False, "" if result is None else str(result)
|
||||
|
||||
return _executor
|
||||
|
||||
@staticmethod
|
||||
def _build_command_executor(
|
||||
supervisor: "PluginSupervisor",
|
||||
plugin_id: str,
|
||||
component_name: str,
|
||||
metadata: Dict[str, Any],
|
||||
) -> CommandExecutor:
|
||||
"""构造命令执行 RPC 闭包。
|
||||
|
||||
Args:
|
||||
supervisor: 负责该组件的监督器。
|
||||
plugin_id: 插件 ID。
|
||||
component_name: 组件名称。
|
||||
metadata: 命令组件元数据。
|
||||
|
||||
Returns:
|
||||
CommandExecutor: 兼容旧消息命令链的执行器。
|
||||
"""
|
||||
|
||||
async def _executor(**kwargs: Any) -> tuple[bool, Optional[str], bool]:
|
||||
"""将核心命令调用桥接到插件运行时。
|
||||
|
||||
Args:
|
||||
**kwargs: 命令执行上下文参数。
|
||||
|
||||
Returns:
|
||||
tuple[bool, Optional[str], bool]: ``(是否成功, 返回文本, 是否拦截后续消息)``。
|
||||
"""
|
||||
|
||||
message = kwargs.get("message")
|
||||
matched_groups = kwargs.get("matched_groups")
|
||||
plugin_config = kwargs.get("plugin_config")
|
||||
invoke_args: Dict[str, Any] = {
|
||||
"text": str(getattr(message, "processed_plain_text", "") or ""),
|
||||
"stream_id": str(getattr(message, "session_id", "") or ""),
|
||||
"matched_groups": matched_groups if isinstance(matched_groups, dict) else {},
|
||||
}
|
||||
if isinstance(plugin_config, dict):
|
||||
invoke_args["plugin_config"] = plugin_config
|
||||
|
||||
try:
|
||||
response = await supervisor.invoke_plugin(
|
||||
method="plugin.invoke_command",
|
||||
plugin_id=plugin_id,
|
||||
component_name=component_name,
|
||||
args=invoke_args,
|
||||
timeout_ms=30000,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"运行时 Command {plugin_id}.{component_name} 执行失败: {exc}", exc_info=True)
|
||||
return False, str(exc), True
|
||||
|
||||
payload = response.payload if isinstance(response.payload, dict) else {}
|
||||
success = bool(payload.get("success", False))
|
||||
result = payload.get("result")
|
||||
intercept = bool(metadata.get("intercept_message_level", 0))
|
||||
response_text: Optional[str]
|
||||
|
||||
if isinstance(result, (list, tuple)) and len(result) >= 3:
|
||||
response_text = None if result[1] is None else str(result[1])
|
||||
intercept = bool(result[2])
|
||||
else:
|
||||
response_text = None if result is None else str(result)
|
||||
|
||||
return success, response_text, intercept
|
||||
|
||||
return _executor
|
||||
|
||||
@staticmethod
|
||||
def _build_tool_executor(supervisor: "PluginSupervisor", plugin_id: str, component_name: str) -> ToolExecutor:
|
||||
"""构造工具执行 RPC 闭包。
|
||||
|
||||
Args:
|
||||
supervisor: 负责该组件的监督器。
|
||||
plugin_id: 插件 ID。
|
||||
component_name: 组件名称。
|
||||
|
||||
Returns:
|
||||
ToolExecutor: 兼容旧 ToolExecutor 的异步执行器。
|
||||
"""
|
||||
|
||||
async def _executor(function_args: Dict[str, Any]) -> Any:
|
||||
"""将核心工具调用桥接到插件运行时。
|
||||
|
||||
Args:
|
||||
function_args: 工具调用参数。
|
||||
|
||||
Returns:
|
||||
Any: 插件工具返回结果;若结果不是字典,则会包装为 ``{"content": ...}``。
|
||||
"""
|
||||
|
||||
try:
|
||||
response = await supervisor.invoke_plugin(
|
||||
method="plugin.invoke_tool",
|
||||
plugin_id=plugin_id,
|
||||
component_name=component_name,
|
||||
args=function_args,
|
||||
timeout_ms=30000,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"运行时 Tool {plugin_id}.{component_name} 执行失败: {exc}", exc_info=True)
|
||||
return {"content": f"工具 {component_name} 执行失败: {exc}"}
|
||||
|
||||
payload = response.payload if isinstance(response.payload, dict) else {}
|
||||
result = payload.get("result")
|
||||
if isinstance(result, dict):
|
||||
return result
|
||||
return {"content": "" if result is None else str(result)}
|
||||
|
||||
return _executor
|
||||
|
||||
def get_action_info(self, name: str) -> Optional[ActionInfo]:
|
||||
"""获取指定动作的信息。
|
||||
|
||||
Args:
|
||||
name: 动作名称。
|
||||
|
||||
Returns:
|
||||
Optional[ActionInfo]: 匹配到的动作信息。
|
||||
"""
|
||||
|
||||
matched_entry = self._get_unique_component_entry(ComponentType.ACTION, name)
|
||||
if matched_entry is None:
|
||||
return None
|
||||
_supervisor, entry = matched_entry
|
||||
return self._build_action_info(entry) # type: ignore[arg-type]
|
||||
|
||||
def get_action_executor(self, name: str) -> Optional[ActionExecutor]:
|
||||
"""获取指定动作的执行器。
|
||||
|
||||
Args:
|
||||
name: 动作名称。
|
||||
|
||||
Returns:
|
||||
Optional[ActionExecutor]: 运行时 RPC 执行闭包。
|
||||
"""
|
||||
|
||||
matched_entry = self._get_unique_component_entry(ComponentType.ACTION, name)
|
||||
if matched_entry is None:
|
||||
return None
|
||||
supervisor, entry = matched_entry
|
||||
return self._build_action_executor(supervisor, entry.plugin_id, entry.name)
|
||||
|
||||
def get_default_actions(self) -> Dict[str, ActionInfo]:
|
||||
"""获取当前默认启用的动作集合。
|
||||
|
||||
Returns:
|
||||
Dict[str, ActionInfo]: 动作名到动作信息的映射。
|
||||
"""
|
||||
|
||||
action_infos = self._collect_unique_component_infos(ComponentType.ACTION)
|
||||
return {name: info for name, info in action_infos.items() if isinstance(info, ActionInfo) and info.enabled}
|
||||
|
||||
def find_command_by_text(self, text: str) -> Optional[Tuple[CommandExecutor, dict, CommandInfo]]:
|
||||
"""根据文本查找匹配的命令。
|
||||
|
||||
Args:
|
||||
text: 待匹配的文本内容。
|
||||
|
||||
Returns:
|
||||
Optional[Tuple[CommandExecutor, dict, CommandInfo]]: 匹配结果。
|
||||
"""
|
||||
|
||||
for supervisor in self._iter_supervisors():
|
||||
match_result = supervisor.component_registry.find_command_by_text(text)
|
||||
if match_result is None:
|
||||
continue
|
||||
|
||||
entry, matched_groups = match_result
|
||||
command_info = self._build_command_info(entry) # type: ignore[arg-type]
|
||||
command_executor = self._build_command_executor(
|
||||
supervisor,
|
||||
entry.plugin_id,
|
||||
entry.name,
|
||||
dict(entry.metadata),
|
||||
)
|
||||
return command_executor, matched_groups, command_info
|
||||
return None
|
||||
|
||||
def get_tool_info(self, name: str) -> Optional[ToolInfo]:
|
||||
"""获取指定工具的信息。
|
||||
|
||||
Args:
|
||||
name: 工具名称。
|
||||
|
||||
Returns:
|
||||
Optional[ToolInfo]: 匹配到的工具信息。
|
||||
"""
|
||||
|
||||
matched_entry = self._get_unique_component_entry(ComponentType.TOOL, name)
|
||||
if matched_entry is None:
|
||||
return None
|
||||
_supervisor, entry = matched_entry
|
||||
return self._build_tool_info(entry) # type: ignore[arg-type]
|
||||
|
||||
def get_tool_executor(self, name: str) -> Optional[ToolExecutor]:
|
||||
"""获取指定工具的执行器。
|
||||
|
||||
Args:
|
||||
name: 工具名称。
|
||||
|
||||
Returns:
|
||||
Optional[ToolExecutor]: 运行时 RPC 执行闭包。
|
||||
"""
|
||||
|
||||
matched_entry = self._get_unique_component_entry(ComponentType.TOOL, name)
|
||||
if matched_entry is None:
|
||||
return None
|
||||
supervisor, entry = matched_entry
|
||||
return self._build_tool_executor(supervisor, entry.plugin_id, entry.name)
|
||||
|
||||
def get_llm_available_tools(self) -> Dict[str, ToolInfo]:
|
||||
"""获取当前可供 LLM 选择的工具集合。
|
||||
|
||||
Returns:
|
||||
Dict[str, ToolInfo]: 工具名到工具信息的映射。
|
||||
"""
|
||||
|
||||
tool_infos = self._collect_unique_component_infos(ComponentType.TOOL)
|
||||
return {name: info for name, info in tool_infos.items() if isinstance(info, ToolInfo) and info.enabled}
|
||||
|
||||
def get_components_by_type(self, component_type: ComponentType) -> Dict[str, ComponentInfo]:
|
||||
"""获取某类组件的全部信息。
|
||||
|
||||
Args:
|
||||
component_type: 组件类型。
|
||||
|
||||
Returns:
|
||||
Dict[str, ComponentInfo]: 组件名到组件信息的映射。
|
||||
"""
|
||||
|
||||
return self._collect_unique_component_infos(component_type)
|
||||
|
||||
def get_plugin_config(self, plugin_name: str) -> Optional[dict]:
|
||||
"""读取指定插件的配置文件内容。
|
||||
|
||||
Args:
|
||||
plugin_name: 插件名称。
|
||||
|
||||
Returns:
|
||||
Optional[dict]: 读取成功时返回配置字典;未找到时返回 ``None``。
|
||||
"""
|
||||
|
||||
runtime_manager = self._get_runtime_manager()
|
||||
try:
|
||||
supervisor = runtime_manager._get_supervisor_for_plugin(plugin_name)
|
||||
except RuntimeError as exc:
|
||||
logger.error(f"读取插件配置失败: {exc}")
|
||||
return None
|
||||
|
||||
if supervisor is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
return runtime_manager._load_plugin_config_for_supervisor(supervisor, plugin_name)
|
||||
except Exception as exc:
|
||||
logger.error(f"读取插件 {plugin_name} 配置失败: {exc}", exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
component_query_service = ComponentQueryService()
|
||||
349
src/plugin_runtime/host/api_registry.py
Normal file
349
src/plugin_runtime/host/api_registry.py
Normal file
@@ -0,0 +1,349 @@
|
||||
"""Host 侧插件 API 动态注册表。"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("plugin_runtime.host.api_registry")
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class APIEntry:
|
||||
"""API 组件条目。"""
|
||||
|
||||
name: str
|
||||
plugin_id: str
|
||||
description: str = ""
|
||||
version: str = "1"
|
||||
public: bool = False
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
enabled: bool = True
|
||||
handler_name: str = ""
|
||||
dynamic: bool = False
|
||||
offline_reason: str = ""
|
||||
disabled_session: Set[str] = field(default_factory=set)
|
||||
full_name: str = field(init=False)
|
||||
registry_key: str = field(init=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""规范化 API 条目字段。"""
|
||||
|
||||
self.name = str(self.name or "").strip()
|
||||
self.plugin_id = str(self.plugin_id or "").strip()
|
||||
self.description = str(self.description or "").strip()
|
||||
self.version = str(self.version or "1").strip() or "1"
|
||||
self.handler_name = str(self.handler_name or self.name).strip() or self.name
|
||||
self.offline_reason = str(self.offline_reason or "").strip()
|
||||
self.full_name = f"{self.plugin_id}.{self.name}"
|
||||
self.registry_key = APIRegistry.build_registry_key(self.plugin_id, self.name, self.version)
|
||||
|
||||
@classmethod
|
||||
def from_metadata(cls, name: str, plugin_id: str, metadata: Dict[str, Any]) -> "APIEntry":
|
||||
"""根据 Runner 上报的元数据构造 API 条目。"""
|
||||
|
||||
safe_metadata = dict(metadata)
|
||||
return cls(
|
||||
name=name,
|
||||
plugin_id=plugin_id,
|
||||
description=str(safe_metadata.get("description", "") or ""),
|
||||
version=str(safe_metadata.get("version", "1") or "1"),
|
||||
public=bool(safe_metadata.get("public", False)),
|
||||
metadata=safe_metadata,
|
||||
enabled=bool(safe_metadata.get("enabled", True)),
|
||||
handler_name=str(safe_metadata.get("handler_name", name) or name),
|
||||
dynamic=bool(safe_metadata.get("dynamic", False)),
|
||||
offline_reason=str(safe_metadata.get("offline_reason", "") or ""),
|
||||
)
|
||||
|
||||
|
||||
class APIRegistry:
|
||||
"""Host 侧插件 API 动态注册表。
|
||||
|
||||
该注册表不直接面向 Runner,而是复用插件组件注册/卸载事件,
|
||||
维护面向 API 调用场景的专用索引。
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""初始化 API 注册表。"""
|
||||
|
||||
self._apis: Dict[str, APIEntry] = {}
|
||||
self._by_full_name: Dict[str, List[APIEntry]] = {}
|
||||
self._by_plugin: Dict[str, List[APIEntry]] = {}
|
||||
self._by_name: Dict[str, List[APIEntry]] = {}
|
||||
|
||||
def clear(self) -> None:
|
||||
"""清空全部 API 注册状态。"""
|
||||
|
||||
self._apis.clear()
|
||||
self._by_full_name.clear()
|
||||
self._by_plugin.clear()
|
||||
self._by_name.clear()
|
||||
|
||||
@staticmethod
|
||||
def _is_api_component(component_type: Any) -> bool:
|
||||
"""判断组件声明是否属于 API。"""
|
||||
|
||||
return str(component_type or "").strip().upper() == "API"
|
||||
|
||||
@staticmethod
|
||||
def _normalize_query_version(version: Any) -> str:
|
||||
"""规范化查询使用的版本字符串。"""
|
||||
|
||||
return str(version or "").strip()
|
||||
|
||||
@classmethod
|
||||
def _split_reference(cls, reference: str, version: Any = "") -> Tuple[str, str]:
|
||||
"""解析可能带 ``@version`` 后缀的 API 引用。"""
|
||||
|
||||
normalized_reference = str(reference or "").strip()
|
||||
normalized_version = cls._normalize_query_version(version)
|
||||
if normalized_reference and not normalized_version and "@" in normalized_reference:
|
||||
candidate_reference, candidate_version = normalized_reference.rsplit("@", 1)
|
||||
candidate_reference = candidate_reference.strip()
|
||||
candidate_version = candidate_version.strip()
|
||||
if candidate_reference and candidate_version:
|
||||
normalized_reference = candidate_reference
|
||||
normalized_version = candidate_version
|
||||
return normalized_reference, normalized_version
|
||||
|
||||
@staticmethod
|
||||
def build_registry_key(plugin_id: str, name: str, version: str) -> str:
|
||||
"""构造 API 注册表唯一键。"""
|
||||
|
||||
normalized_full_name = f"{str(plugin_id or '').strip()}.{str(name or '').strip()}"
|
||||
normalized_version = str(version or "1").strip() or "1"
|
||||
return f"{normalized_full_name}@{normalized_version}"
|
||||
|
||||
@staticmethod
|
||||
def check_api_enabled(entry: APIEntry, session_id: Optional[str] = None) -> bool:
|
||||
"""判断 API 条目当前是否处于启用状态。"""
|
||||
|
||||
if session_id and session_id in entry.disabled_session:
|
||||
return False
|
||||
return entry.enabled
|
||||
|
||||
def register_api(self, name: str, plugin_id: str, metadata: Dict[str, Any]) -> bool:
|
||||
"""注册单个 API 条目。"""
|
||||
|
||||
normalized_name = str(name or "").strip()
|
||||
if not normalized_name:
|
||||
logger.warning(f"插件 {plugin_id} 存在空 API 名称声明,已忽略")
|
||||
return False
|
||||
|
||||
entry = APIEntry.from_metadata(name=normalized_name, plugin_id=plugin_id, metadata=metadata)
|
||||
existing_entry = self._apis.get(entry.registry_key)
|
||||
if existing_entry is not None:
|
||||
logger.warning(f"API {entry.registry_key} 已存在,覆盖旧条目")
|
||||
self._remove_entry(existing_entry)
|
||||
|
||||
self._apis[entry.registry_key] = entry
|
||||
self._by_full_name.setdefault(entry.full_name, []).append(entry)
|
||||
self._by_plugin.setdefault(plugin_id, []).append(entry)
|
||||
self._by_name.setdefault(entry.name, []).append(entry)
|
||||
return True
|
||||
|
||||
def register_plugin_apis(self, plugin_id: str, components: List[Dict[str, Any]]) -> int:
|
||||
"""批量注册某个插件声明的全部 API。"""
|
||||
|
||||
count = 0
|
||||
for component in components:
|
||||
if not self._is_api_component(component.get("component_type")):
|
||||
continue
|
||||
if self.register_api(
|
||||
name=str(component.get("name", "") or ""),
|
||||
plugin_id=plugin_id,
|
||||
metadata=component.get("metadata", {}) if isinstance(component.get("metadata"), dict) else {},
|
||||
):
|
||||
count += 1
|
||||
return count
|
||||
|
||||
def replace_plugin_dynamic_apis(
|
||||
self,
|
||||
plugin_id: str,
|
||||
components: List[Dict[str, Any]],
|
||||
*,
|
||||
offline_reason: str = "动态 API 已下线",
|
||||
) -> Tuple[int, int]:
|
||||
"""替换指定插件当前声明的动态 API 集合。"""
|
||||
|
||||
normalized_offline_reason = str(offline_reason or "").strip() or "动态 API 已下线"
|
||||
desired_registry_keys: Set[str] = set()
|
||||
registered_count = 0
|
||||
|
||||
for component in components:
|
||||
if not self._is_api_component(component.get("component_type")):
|
||||
continue
|
||||
metadata = component.get("metadata", {}) if isinstance(component.get("metadata"), dict) else {}
|
||||
dynamic_metadata = dict(metadata)
|
||||
dynamic_metadata["dynamic"] = True
|
||||
dynamic_metadata.pop("offline_reason", None)
|
||||
|
||||
entry = APIEntry.from_metadata(
|
||||
name=str(component.get("name", "") or ""),
|
||||
plugin_id=plugin_id,
|
||||
metadata=dynamic_metadata,
|
||||
)
|
||||
desired_registry_keys.add(entry.registry_key)
|
||||
if self.register_api(entry.name, plugin_id, dynamic_metadata):
|
||||
registered_count += 1
|
||||
|
||||
offlined_count = 0
|
||||
for entry in list(self._by_plugin.get(plugin_id, [])):
|
||||
if not entry.dynamic or entry.registry_key in desired_registry_keys:
|
||||
continue
|
||||
entry.enabled = False
|
||||
entry.offline_reason = normalized_offline_reason
|
||||
entry.metadata["offline_reason"] = normalized_offline_reason
|
||||
offlined_count += 1
|
||||
|
||||
return registered_count, offlined_count
|
||||
|
||||
def _remove_entry(self, entry: APIEntry) -> None:
|
||||
"""从全部索引中移除单个 API 条目。"""
|
||||
|
||||
self._apis.pop(entry.registry_key, None)
|
||||
|
||||
full_name_entries = self._by_full_name.get(entry.full_name)
|
||||
if full_name_entries is not None:
|
||||
self._by_full_name[entry.full_name] = [
|
||||
candidate for candidate in full_name_entries if candidate is not entry
|
||||
]
|
||||
if not self._by_full_name[entry.full_name]:
|
||||
self._by_full_name.pop(entry.full_name, None)
|
||||
|
||||
plugin_entries = self._by_plugin.get(entry.plugin_id)
|
||||
if plugin_entries is not None:
|
||||
self._by_plugin[entry.plugin_id] = [candidate for candidate in plugin_entries if candidate is not entry]
|
||||
if not self._by_plugin[entry.plugin_id]:
|
||||
self._by_plugin.pop(entry.plugin_id, None)
|
||||
|
||||
name_entries = self._by_name.get(entry.name)
|
||||
if name_entries is not None:
|
||||
self._by_name[entry.name] = [candidate for candidate in name_entries if candidate is not entry]
|
||||
if not self._by_name[entry.name]:
|
||||
self._by_name.pop(entry.name, None)
|
||||
|
||||
def remove_apis_by_plugin(self, plugin_id: str) -> int:
|
||||
"""移除某个插件的全部 API。"""
|
||||
|
||||
entries = list(self._by_plugin.get(plugin_id, []))
|
||||
for entry in entries:
|
||||
self._remove_entry(entry)
|
||||
return len(entries)
|
||||
|
||||
def get_api_by_full_name(
|
||||
self,
|
||||
full_name: str,
|
||||
*,
|
||||
version: str = "",
|
||||
enabled_only: bool = True,
|
||||
session_id: Optional[str] = None,
|
||||
) -> Optional[APIEntry]:
|
||||
"""按完整名查询单个 API。"""
|
||||
|
||||
normalized_full_name, normalized_version = self._split_reference(full_name, version)
|
||||
if not normalized_full_name:
|
||||
return None
|
||||
|
||||
if normalized_version:
|
||||
entry = self._apis.get(f"{normalized_full_name}@{normalized_version}")
|
||||
if entry is None:
|
||||
return None
|
||||
if enabled_only and not self.check_api_enabled(entry, session_id):
|
||||
return None
|
||||
return entry
|
||||
|
||||
candidates = list(self._by_full_name.get(normalized_full_name, []))
|
||||
filtered_entries = [
|
||||
entry
|
||||
for entry in candidates
|
||||
if not enabled_only or self.check_api_enabled(entry, session_id)
|
||||
]
|
||||
if len(filtered_entries) != 1:
|
||||
return None
|
||||
return filtered_entries[0]
|
||||
|
||||
def get_api(
|
||||
self,
|
||||
plugin_id: str,
|
||||
name: str,
|
||||
*,
|
||||
version: str = "",
|
||||
enabled_only: bool = True,
|
||||
session_id: Optional[str] = None,
|
||||
) -> Optional[APIEntry]:
|
||||
"""按插件 ID、短名与版本查询单个 API。"""
|
||||
|
||||
return self.get_api_by_full_name(
|
||||
f"{plugin_id}.{name}",
|
||||
version=version,
|
||||
enabled_only=enabled_only,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
def get_apis(
|
||||
self,
|
||||
*,
|
||||
plugin_id: Optional[str] = None,
|
||||
name: str = "",
|
||||
version: str = "",
|
||||
enabled_only: bool = True,
|
||||
session_id: Optional[str] = None,
|
||||
) -> List[APIEntry]:
|
||||
"""查询 API 列表。"""
|
||||
|
||||
normalized_name = str(name or "").strip()
|
||||
normalized_version = self._normalize_query_version(version)
|
||||
|
||||
if plugin_id:
|
||||
candidates = list(self._by_plugin.get(plugin_id, []))
|
||||
elif normalized_name:
|
||||
candidates = list(self._by_name.get(normalized_name, []))
|
||||
else:
|
||||
candidates = list(self._apis.values())
|
||||
|
||||
filtered_entries: List[APIEntry] = []
|
||||
for entry in candidates:
|
||||
if plugin_id and entry.plugin_id != plugin_id:
|
||||
continue
|
||||
if normalized_name and entry.name != normalized_name:
|
||||
continue
|
||||
if normalized_version and entry.version != normalized_version:
|
||||
continue
|
||||
if enabled_only and not self.check_api_enabled(entry, session_id):
|
||||
continue
|
||||
filtered_entries.append(entry)
|
||||
|
||||
filtered_entries.sort(key=lambda entry: (entry.plugin_id, entry.name, entry.version))
|
||||
return filtered_entries
|
||||
|
||||
def toggle_api_status(
|
||||
self,
|
||||
full_name: str,
|
||||
enabled: bool,
|
||||
*,
|
||||
version: str = "",
|
||||
session_id: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""设置指定 API 的启用状态。"""
|
||||
|
||||
entry = self.get_api_by_full_name(
|
||||
full_name,
|
||||
version=version,
|
||||
enabled_only=False,
|
||||
session_id=session_id,
|
||||
)
|
||||
if entry is None:
|
||||
return False
|
||||
if session_id:
|
||||
if enabled:
|
||||
entry.disabled_session.discard(session_id)
|
||||
else:
|
||||
entry.disabled_session.add(session_id)
|
||||
else:
|
||||
entry.enabled = enabled
|
||||
if enabled:
|
||||
entry.offline_reason = ""
|
||||
entry.metadata.pop("offline_reason", None)
|
||||
return True
|
||||
67
src/plugin_runtime/host/authorization.py
Normal file
67
src/plugin_runtime/host/authorization.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""授权管理器
|
||||
|
||||
负责管理插件的能力授权以及校验
|
||||
每个插件在 manifest 中声明能力需求,Host 启动时签发能力令牌。
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
|
||||
_ALWAYS_ALLOWED_CAPABILITIES = frozenset({"api.replace_dynamic"})
|
||||
|
||||
|
||||
@dataclass
|
||||
class CapabilityPermissionToken:
|
||||
"""能力令牌"""
|
||||
|
||||
plugin_id: str
|
||||
capabilities: Set[str] = field(default_factory=set)
|
||||
|
||||
|
||||
class AuthorizationManager:
|
||||
"""授权管理器
|
||||
|
||||
管理所有插件的能力令牌,提供授权校验。
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._permission_tokens: Dict[str, CapabilityPermissionToken] = {}
|
||||
|
||||
def register_plugin(self, plugin_id: str, capabilities: List[str]) -> CapabilityPermissionToken:
|
||||
"""为插件签发能力令牌"""
|
||||
token = CapabilityPermissionToken(plugin_id=plugin_id, capabilities=set(capabilities))
|
||||
self._permission_tokens[plugin_id] = token
|
||||
return token
|
||||
|
||||
def revoke_permission_token(self, plugin_id: str):
|
||||
"""移除插件的能力令牌。"""
|
||||
self._permission_tokens.pop(plugin_id, None)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""清空所有能力令牌。"""
|
||||
self._permission_tokens.clear()
|
||||
|
||||
def check_capability(self, plugin_id: str, capability: str) -> Tuple[bool, str]:
|
||||
# sourcery skip: assign-if-exp, reintroduce-else, swap-if-else-branches, use-named-expression
|
||||
"""检查插件是否有权调用某项能力
|
||||
|
||||
Returns:
|
||||
return (bool, str): (是否有此能力, 原因)
|
||||
"""
|
||||
if capability in _ALWAYS_ALLOWED_CAPABILITIES:
|
||||
return True, ""
|
||||
|
||||
token = self._permission_tokens.get(plugin_id)
|
||||
if not token:
|
||||
return False, f"插件 {plugin_id} 未注册能力令牌"
|
||||
if capability not in token.capabilities:
|
||||
return False, f"插件 {plugin_id} 未获授权能力: {capability}"
|
||||
return True, ""
|
||||
|
||||
def get_token(self, plugin_id: str) -> Optional[CapabilityPermissionToken]:
|
||||
"""获取插件的能力令牌"""
|
||||
return self._permission_tokens.get(plugin_id)
|
||||
|
||||
def list_plugins(self) -> List[str]:
|
||||
"""列出所有已注册的插件"""
|
||||
return list(self._permission_tokens.keys())
|
||||
@@ -4,21 +4,19 @@ Host 端实现的能力服务,处理来自插件的 cap.* 请求。
|
||||
每个能力方法被注册到 RPC Server,接收 Runner 转发的请求并执行实际操作。
|
||||
"""
|
||||
|
||||
from typing import Any, Awaitable, Callable, Dict, List
|
||||
from typing import Any, Callable, Dict, List, Coroutine, TYPE_CHECKING
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_runtime.host.policy_engine import PolicyEngine
|
||||
from src.plugin_runtime.protocol.envelope import (
|
||||
CapabilityRequestPayload,
|
||||
CapabilityResponsePayload,
|
||||
Envelope,
|
||||
)
|
||||
from src.plugin_runtime.protocol.envelope import CapabilityRequestPayload, CapabilityResponsePayload, Envelope
|
||||
from src.plugin_runtime.protocol.errors import ErrorCode, RPCError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.plugin_runtime.host.authorization import AuthorizationManager
|
||||
|
||||
logger = get_logger("plugin_runtime.host.capability_service")
|
||||
|
||||
# 能力实现函数类型: (plugin_id, capability, args) -> result
|
||||
CapabilityImpl = Callable[[str, str, Dict[str, Any]], Awaitable[Any]]
|
||||
CapabilityImpl = Callable[[str, str, Dict[str, Any]], Coroutine[Any, Any, Any]]
|
||||
|
||||
|
||||
class CapabilityService:
|
||||
@@ -31,8 +29,13 @@ class CapabilityService:
|
||||
4. 执行实际操作并返回结果
|
||||
"""
|
||||
|
||||
def __init__(self, policy_engine: PolicyEngine) -> None:
|
||||
self._policy = policy_engine
|
||||
def __init__(self, authorization: "AuthorizationManager") -> None:
|
||||
"""初始化能力服务。
|
||||
|
||||
Args:
|
||||
authorization: 能力授权管理器。
|
||||
"""
|
||||
self._authorization = authorization
|
||||
# capability_name -> implementation
|
||||
self._implementations: Dict[str, CapabilityImpl] = {}
|
||||
|
||||
@@ -56,46 +59,32 @@ class CapabilityService:
|
||||
|
||||
try:
|
||||
req = CapabilityRequestPayload.model_validate(envelope.payload)
|
||||
except Exception as e:
|
||||
return envelope.make_error_response(
|
||||
ErrorCode.E_BAD_PAYLOAD.value,
|
||||
f"能力调用 payload 格式错误: {e}",
|
||||
)
|
||||
except Exception as exc:
|
||||
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, f"能力调用 payload 非法: {exc}")
|
||||
|
||||
capability = req.capability
|
||||
args = req.args
|
||||
|
||||
# 1. 权限校验
|
||||
allowed, reason = self._policy.check_capability(plugin_id, capability, envelope.generation)
|
||||
allowed, reason = self._authorization.check_capability(plugin_id, capability)
|
||||
if not allowed:
|
||||
error_code = (
|
||||
ErrorCode.E_GENERATION_MISMATCH if "generation 不匹配" in reason else ErrorCode.E_CAPABILITY_DENIED
|
||||
)
|
||||
return envelope.make_error_response(
|
||||
error_code.value,
|
||||
reason,
|
||||
)
|
||||
return envelope.make_error_response(ErrorCode.E_CAPABILITY_DENIED.value, reason)
|
||||
|
||||
# 2. 查找实现
|
||||
impl = self._implementations.get(capability)
|
||||
if impl is None:
|
||||
return envelope.make_error_response(
|
||||
ErrorCode.E_METHOD_NOT_ALLOWED.value,
|
||||
f"未注册的能力: {capability}",
|
||||
)
|
||||
return envelope.make_error_response(ErrorCode.E_METHOD_NOT_ALLOWED.value, f"未注册的能力: {capability}")
|
||||
|
||||
# 3. 执行
|
||||
try:
|
||||
result = await impl(plugin_id, capability, req.args)
|
||||
result = await impl(plugin_id, capability, args)
|
||||
resp_payload = CapabilityResponsePayload(success=True, result=result)
|
||||
return envelope.make_response(payload=resp_payload.model_dump())
|
||||
except RPCError as e:
|
||||
return envelope.make_error_response(e.code.value, e.message, e.details)
|
||||
except Exception as e:
|
||||
logger.error(f"能力 {capability} 执行异常: {e}", exc_info=True)
|
||||
return envelope.make_error_response(
|
||||
ErrorCode.E_CAPABILITY_FAILED.value,
|
||||
str(e),
|
||||
)
|
||||
return envelope.make_error_response(ErrorCode.E_CAPABILITY_FAILED.value, str(e))
|
||||
|
||||
def list_capabilities(self) -> List[str]:
|
||||
"""列出所有已注册的能力"""
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Host-side ComponentRegistry
|
||||
|
||||
对齐旧系统 component_registry.py 的核心能力:
|
||||
- 按类型注册组件(action / command / tool / event_handler / workflow_step)
|
||||
- 按类型注册组件(action / command / tool / event_handler / workflow_handler / message_gateway)
|
||||
- 命名空间 (plugin_id.component_name)
|
||||
- 命令正则匹配
|
||||
- 组件启用/禁用
|
||||
@@ -9,8 +9,10 @@
|
||||
- 注册统计
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Set, TypedDict, Tuple
|
||||
|
||||
import contextlib
|
||||
import re
|
||||
|
||||
from src.common.logger import get_logger
|
||||
@@ -18,8 +20,28 @@ from src.common.logger import get_logger
|
||||
logger = get_logger("plugin_runtime.host.component_registry")
|
||||
|
||||
|
||||
class RegisteredComponent:
|
||||
"""已注册的组件条目"""
|
||||
class ComponentTypes(str, Enum):
|
||||
ACTION = "ACTION"
|
||||
COMMAND = "COMMAND"
|
||||
TOOL = "TOOL"
|
||||
EVENT_HANDLER = "EVENT_HANDLER"
|
||||
HOOK_HANDLER = "HOOK_HANDLER"
|
||||
MESSAGE_GATEWAY = "MESSAGE_GATEWAY"
|
||||
|
||||
|
||||
class StatusDict(TypedDict):
|
||||
total: int
|
||||
action: int
|
||||
command: int
|
||||
tool: int
|
||||
event_handler: int
|
||||
hook_handler: int
|
||||
message_gateway: int
|
||||
plugins: int
|
||||
|
||||
|
||||
class ComponentEntry:
|
||||
"""组件条目"""
|
||||
|
||||
__slots__ = (
|
||||
"name",
|
||||
@@ -28,31 +50,120 @@ class RegisteredComponent:
|
||||
"plugin_id",
|
||||
"metadata",
|
||||
"enabled",
|
||||
"_compiled_pattern",
|
||||
"compiled_pattern",
|
||||
"disabled_session",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
component_type: str,
|
||||
plugin_id: str,
|
||||
metadata: Dict[str, Any],
|
||||
) -> None:
|
||||
self.name = name
|
||||
self.full_name = f"{plugin_id}.{name}"
|
||||
self.component_type = component_type
|
||||
self.plugin_id = plugin_id
|
||||
self.metadata = metadata
|
||||
self.enabled = metadata.get("enabled", True)
|
||||
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
|
||||
self.name: str = name
|
||||
self.full_name: str = f"{plugin_id}.{name}"
|
||||
self.component_type: ComponentTypes = ComponentTypes(component_type)
|
||||
self.plugin_id: str = plugin_id
|
||||
self.metadata: Dict[str, Any] = metadata
|
||||
self.enabled: bool = metadata.get("enabled", True)
|
||||
self.disabled_session: Set[str] = set()
|
||||
|
||||
# 预编译命令正则(仅 command 类型)
|
||||
self._compiled_pattern: Optional[re.Pattern] = None
|
||||
if component_type == "command":
|
||||
if pattern := metadata.get("command_pattern", ""):
|
||||
try:
|
||||
self._compiled_pattern = re.compile(pattern)
|
||||
except re.error as e:
|
||||
logger.warning(f"命令 {self.full_name} 正则编译失败: {e}")
|
||||
|
||||
class ActionEntry(ComponentEntry):
|
||||
"""Action 组件条目"""
|
||||
|
||||
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
|
||||
super().__init__(name, component_type, plugin_id, metadata)
|
||||
|
||||
|
||||
class CommandEntry(ComponentEntry):
|
||||
"""Command 组件条目"""
|
||||
|
||||
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
|
||||
super().__init__(name, component_type, plugin_id, metadata)
|
||||
self.aliases: List[str] = metadata.get("aliases", [])
|
||||
self.compiled_pattern: Optional[re.Pattern] = None
|
||||
if pattern := metadata.get("command_pattern", ""):
|
||||
try:
|
||||
self.compiled_pattern = re.compile(pattern)
|
||||
except (re.error, TypeError) as e:
|
||||
logger.warning(f"命令 {self.full_name} 正则编译失败: {e}")
|
||||
|
||||
|
||||
class ToolEntry(ComponentEntry):
|
||||
"""Tool 组件条目"""
|
||||
|
||||
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
|
||||
self.description: str = metadata.get("description", "")
|
||||
self.parameters: List[Dict[str, Any]] = metadata.get("parameters", [])
|
||||
self.parameters_raw: List[Dict[str, Any]] = metadata.get("parameters_raw", [])
|
||||
super().__init__(name, component_type, plugin_id, metadata)
|
||||
|
||||
|
||||
class EventHandlerEntry(ComponentEntry):
|
||||
"""EventHandler 组件条目"""
|
||||
|
||||
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
|
||||
self.event_type: str = metadata.get("event_type", "")
|
||||
self.weight: int = metadata.get("weight", 0)
|
||||
self.intercept_message: bool = metadata.get("intercept_message", False)
|
||||
super().__init__(name, component_type, plugin_id, metadata)
|
||||
|
||||
|
||||
class HookHandlerEntry(ComponentEntry):
|
||||
"""WorkflowHandler 组件条目"""
|
||||
|
||||
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
|
||||
self.stage: str = metadata.get("stage", "")
|
||||
self.priority: int = metadata.get("priority", 0)
|
||||
self.blocking: bool = metadata.get("blocking", False)
|
||||
super().__init__(name, component_type, plugin_id, metadata)
|
||||
|
||||
|
||||
class MessageGatewayEntry(ComponentEntry):
|
||||
"""MessageGateway 组件条目"""
|
||||
|
||||
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
|
||||
self.route_type: str = self._normalize_route_type(metadata.get("route_type", ""))
|
||||
self.platform: str = str(metadata.get("platform", "") or "").strip()
|
||||
self.protocol: str = str(metadata.get("protocol", "") or "").strip()
|
||||
self.account_id: str = str(metadata.get("account_id", "") or "").strip()
|
||||
self.scope: str = str(metadata.get("scope", "") or "").strip()
|
||||
super().__init__(name, component_type, plugin_id, metadata)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_route_type(raw_value: Any) -> str:
|
||||
"""规范化消息网关路由类型。
|
||||
|
||||
Args:
|
||||
raw_value: 原始路由类型值。
|
||||
|
||||
Returns:
|
||||
str: 规范化后的路由类型。
|
||||
|
||||
Raises:
|
||||
ValueError: 当路由类型不受支持时抛出。
|
||||
"""
|
||||
|
||||
normalized_value = str(raw_value or "").strip().lower()
|
||||
route_type_aliases = {
|
||||
"send": "send",
|
||||
"receive": "receive",
|
||||
"recv": "receive",
|
||||
"recive": "receive",
|
||||
"duplex": "duplex",
|
||||
}
|
||||
route_type = route_type_aliases.get(normalized_value)
|
||||
if route_type is None:
|
||||
raise ValueError(f"MessageGateway 路由类型不合法: {raw_value}")
|
||||
return route_type
|
||||
|
||||
@property
|
||||
def supports_send(self) -> bool:
|
||||
"""返回当前网关是否支持出站。"""
|
||||
|
||||
return self.route_type in {"send", "duplex"}
|
||||
|
||||
@property
|
||||
def supports_receive(self) -> bool:
|
||||
"""返回当前网关是否支持入站。"""
|
||||
|
||||
return self.route_type in {"receive", "duplex"}
|
||||
|
||||
|
||||
class ComponentRegistry:
|
||||
@@ -64,19 +175,32 @@ class ComponentRegistry:
|
||||
|
||||
def __init__(self) -> None:
|
||||
# 全量索引
|
||||
self._components: Dict[str, RegisteredComponent] = {} # full_name -> comp
|
||||
self._components: Dict[str, ComponentEntry] = {} # full_name -> comp
|
||||
|
||||
# 按类型索引
|
||||
self._by_type: Dict[str, Dict[str, RegisteredComponent]] = {
|
||||
"action": {},
|
||||
"command": {},
|
||||
"tool": {},
|
||||
"event_handler": {},
|
||||
"workflow_step": {},
|
||||
}
|
||||
self._by_type: Dict[ComponentTypes, Dict[str, ComponentEntry]] = {
|
||||
comp_type: {} for comp_type in ComponentTypes
|
||||
} # component_type -> (full_name -> comp)
|
||||
|
||||
# 按插件索引
|
||||
self._by_plugin: Dict[str, List[RegisteredComponent]] = {}
|
||||
self._by_plugin: Dict[str, List[ComponentEntry]] = {}
|
||||
|
||||
@staticmethod
|
||||
def _normalize_component_type(component_type: str) -> ComponentTypes:
|
||||
"""规范化组件类型输入。
|
||||
|
||||
Args:
|
||||
component_type: 原始组件类型字符串。
|
||||
|
||||
Returns:
|
||||
ComponentTypes: 规范化后的组件类型枚举。
|
||||
|
||||
Raises:
|
||||
ValueError: 当组件类型不受支持时抛出。
|
||||
"""
|
||||
|
||||
normalized_value = str(component_type or "").strip().upper()
|
||||
return ComponentTypes(normalized_value)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""清空全部组件注册状态。"""
|
||||
@@ -85,47 +209,64 @@ class ComponentRegistry:
|
||||
type_dict.clear()
|
||||
self._by_plugin.clear()
|
||||
|
||||
# ──── 注册 / 注销 ─────────────────────────────────────────
|
||||
# ====== 注册 / 注销 ======
|
||||
def register_component(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> bool:
|
||||
"""注册单个组件
|
||||
|
||||
Args:
|
||||
name: 组件名称(不含插件id前缀)
|
||||
component_type: 组件类型(如 `ACTION`、`COMMAND` 等)
|
||||
plugin_id: 插件id
|
||||
metadata: 组件元数据
|
||||
Returns:
|
||||
success (bool): 是否成功注册(失败原因通常是组件类型无效)
|
||||
"""
|
||||
try:
|
||||
normalized_type = self._normalize_component_type(component_type)
|
||||
if normalized_type == ComponentTypes.ACTION:
|
||||
comp = ActionEntry(name, normalized_type.value, plugin_id, metadata)
|
||||
elif normalized_type == ComponentTypes.COMMAND:
|
||||
comp = CommandEntry(name, normalized_type.value, plugin_id, metadata)
|
||||
elif normalized_type == ComponentTypes.TOOL:
|
||||
comp = ToolEntry(name, normalized_type.value, plugin_id, metadata)
|
||||
elif normalized_type == ComponentTypes.EVENT_HANDLER:
|
||||
comp = EventHandlerEntry(name, normalized_type.value, plugin_id, metadata)
|
||||
elif normalized_type == ComponentTypes.HOOK_HANDLER:
|
||||
comp = HookHandlerEntry(name, normalized_type.value, plugin_id, metadata)
|
||||
elif normalized_type == ComponentTypes.MESSAGE_GATEWAY:
|
||||
comp = MessageGatewayEntry(name, normalized_type.value, plugin_id, metadata)
|
||||
else:
|
||||
raise ValueError(f"组件类型 {component_type} 不存在")
|
||||
except ValueError:
|
||||
logger.error(f"组件类型 {component_type} 不存在")
|
||||
return False
|
||||
|
||||
def register_component(
|
||||
self,
|
||||
name: str,
|
||||
component_type: str,
|
||||
plugin_id: str,
|
||||
metadata: Dict[str, Any],
|
||||
) -> bool:
|
||||
"""注册单个组件。"""
|
||||
comp = RegisteredComponent(name, component_type, plugin_id, metadata)
|
||||
if comp.full_name in self._components:
|
||||
logger.warning(f"组件 {comp.full_name} 已存在,覆盖")
|
||||
old_comp = self._components[comp.full_name]
|
||||
# 从 _by_plugin 列表中移除旧条目,防止幽灵组件堆积
|
||||
old_list = self._by_plugin.get(old_comp.plugin_id)
|
||||
if old_list is not None:
|
||||
try:
|
||||
with contextlib.suppress(ValueError):
|
||||
old_list.remove(old_comp)
|
||||
except ValueError:
|
||||
pass
|
||||
# 从旧类型索引中移除,防止类型变更时幽灵残留
|
||||
if old_type_dict := self._by_type.get(old_comp.component_type):
|
||||
old_type_dict.pop(comp.full_name, None)
|
||||
|
||||
self._components[comp.full_name] = comp
|
||||
|
||||
if component_type not in self._by_type:
|
||||
self._by_type[component_type] = {}
|
||||
self._by_type[component_type][comp.full_name] = comp
|
||||
|
||||
self._by_type[comp.component_type][comp.full_name] = comp
|
||||
self._by_plugin.setdefault(plugin_id, []).append(comp)
|
||||
|
||||
return True
|
||||
|
||||
def register_plugin_components(
|
||||
self,
|
||||
plugin_id: str,
|
||||
components: List[Dict[str, Any]],
|
||||
) -> int:
|
||||
"""批量注册一个插件的所有组件,返回成功注册数。"""
|
||||
def register_plugin_components(self, plugin_id: str, components: List[Dict[str, Any]]) -> int:
|
||||
"""批量注册一个插件的所有组件,返回成功注册数。
|
||||
Args:
|
||||
plugin_id (str): 插件id
|
||||
components (List[Dict[str, Any]]): 组件字典列表,每个组件包含 name, component_type, metadata 等字段
|
||||
Returns:
|
||||
count (int): 成功注册的组件数量
|
||||
"""
|
||||
count = 0
|
||||
for comp_data in components:
|
||||
ok = self.register_component(
|
||||
@@ -139,7 +280,13 @@ class ComponentRegistry:
|
||||
return count
|
||||
|
||||
def remove_components_by_plugin(self, plugin_id: str) -> int:
|
||||
"""移除某个插件的所有组件,返回移除数量。"""
|
||||
"""移除某个插件的所有组件,返回移除数量。
|
||||
|
||||
Args:
|
||||
plugin_id (str): 插件id
|
||||
Returns:
|
||||
count (int): 移除的组件数量
|
||||
"""
|
||||
comps = self._by_plugin.pop(plugin_id, [])
|
||||
for comp in comps:
|
||||
self._components.pop(comp.full_name, None)
|
||||
@@ -147,106 +294,280 @@ class ComponentRegistry:
|
||||
type_dict.pop(comp.full_name, None)
|
||||
return len(comps)
|
||||
|
||||
# ──── 启用 / 禁用 ─────────────────────────────────────────
|
||||
# ====== 启用 / 禁用 ======
|
||||
def check_component_enabled(self, component: ComponentEntry, session_id: Optional[str] = None):
|
||||
if session_id and session_id in component.disabled_session:
|
||||
return False
|
||||
return component.enabled
|
||||
|
||||
def set_component_enabled(self, full_name: str, enabled: bool) -> bool:
|
||||
"""启用或禁用指定组件。"""
|
||||
def toggle_component_status(self, full_name: str, enabled: bool, session_id: Optional[str] = None) -> bool:
|
||||
"""启用或禁用指定组件。
|
||||
|
||||
Args:
|
||||
full_name (str): 组件全名
|
||||
enabled (bool): 使能情况
|
||||
session_id (Optional[str]): 可选的会话ID,仅对该会话禁用(如果提供)
|
||||
Returns:
|
||||
success (bool): 是否成功设置(失败原因通常是组件不存在)
|
||||
"""
|
||||
comp = self._components.get(full_name)
|
||||
if comp is None:
|
||||
return False
|
||||
comp.enabled = enabled
|
||||
if session_id:
|
||||
if enabled:
|
||||
comp.disabled_session.discard(session_id)
|
||||
else:
|
||||
comp.disabled_session.add(session_id)
|
||||
else:
|
||||
comp.enabled = enabled
|
||||
return True
|
||||
|
||||
def set_plugin_enabled(self, plugin_id: str, enabled: bool) -> int:
|
||||
"""批量启用或禁用某插件的所有组件。"""
|
||||
def set_component_enabled(self, full_name: str, enabled: bool, session_id: Optional[str] = None) -> bool:
|
||||
"""设置指定组件的启用状态。
|
||||
|
||||
Args:
|
||||
full_name: 组件全名。
|
||||
enabled: 目标启用状态。
|
||||
session_id: 可选的会话 ID,仅对该会话生效。
|
||||
|
||||
Returns:
|
||||
bool: 是否设置成功。
|
||||
"""
|
||||
|
||||
return self.toggle_component_status(full_name, enabled, session_id=session_id)
|
||||
|
||||
def toggle_plugin_status(self, plugin_id: str, enabled: bool, session_id: Optional[str] = None) -> int:
|
||||
"""批量启用或禁用某插件的所有组件。
|
||||
|
||||
Args:
|
||||
plugin_id (str): 插件id
|
||||
enabled (bool): 使能情况
|
||||
session_id (Optional[str]): 可选的会话ID,仅对该会话禁用(如果提供)
|
||||
Returns:
|
||||
count (int): 成功设置的组件数量(失败原因通常是插件不存在)
|
||||
"""
|
||||
comps = self._by_plugin.get(plugin_id, [])
|
||||
for comp in comps:
|
||||
comp.enabled = enabled
|
||||
if session_id:
|
||||
if enabled:
|
||||
comp.disabled_session.discard(session_id)
|
||||
else:
|
||||
comp.disabled_session.add(session_id)
|
||||
else:
|
||||
comp.enabled = enabled
|
||||
return len(comps)
|
||||
|
||||
# ──── 查询方法 ─────────────────────────────────────────────
|
||||
def get_component(self, full_name: str) -> Optional[ComponentEntry]:
|
||||
"""按全名查询。
|
||||
|
||||
def get_component(self, full_name: str) -> Optional[RegisteredComponent]:
|
||||
"""按全名查询。"""
|
||||
Args:
|
||||
full_name (str): 组件全名
|
||||
Returns:
|
||||
component (Optional[ComponentEntry]): 组件条目,未找到时为 None
|
||||
"""
|
||||
return self._components.get(full_name)
|
||||
|
||||
def get_components_by_type(self, component_type: str, *, enabled_only: bool = True) -> List[RegisteredComponent]:
|
||||
"""按类型查询。"""
|
||||
type_dict = self._by_type.get(component_type, {})
|
||||
def get_components_by_type(
|
||||
self, component_type: str, *, enabled_only: bool = True, session_id: Optional[str] = None
|
||||
) -> List[ComponentEntry]:
|
||||
"""按类型查询组件
|
||||
|
||||
Args:
|
||||
component_type (str): 组件类型(如 `ACTION`、`COMMAND` 等)
|
||||
enabled_only (bool): 是否仅返回启用的组件
|
||||
session_id (Optional[str]): 可选的会话ID,若提供则考虑会话禁用状态
|
||||
Returns:
|
||||
components (List[ComponentEntry]): 组件条目列表
|
||||
"""
|
||||
try:
|
||||
comp_type = self._normalize_component_type(component_type)
|
||||
except ValueError:
|
||||
logger.error(f"组件类型 {component_type} 不存在")
|
||||
raise
|
||||
type_dict = self._by_type.get(comp_type, {})
|
||||
if enabled_only:
|
||||
return [c for c in type_dict.values() if c.enabled]
|
||||
return [c for c in type_dict.values() if self.check_component_enabled(c, session_id)]
|
||||
return list(type_dict.values())
|
||||
|
||||
def get_components_by_plugin(self, plugin_id: str, *, enabled_only: bool = True) -> List[RegisteredComponent]:
|
||||
"""按插件查询。"""
|
||||
comps = self._by_plugin.get(plugin_id, [])
|
||||
return [c for c in comps if c.enabled] if enabled_only else list(comps)
|
||||
def get_components_by_plugin(
|
||||
self, plugin_id: str, *, enabled_only: bool = True, session_id: Optional[str] = None
|
||||
) -> List[ComponentEntry]:
|
||||
"""按插件查询组件。
|
||||
|
||||
def find_command_by_text(self, text: str) -> Optional[tuple[RegisteredComponent, Dict[str, Any]]]:
|
||||
Args:
|
||||
plugin_id (str): 插件ID
|
||||
enabled_only (bool): 是否仅返回启用的组件
|
||||
session_id (Optional[str]): 可选的会话ID,若提供则考虑会话禁用状态
|
||||
Returns:
|
||||
components (List[ComponentEntry]): 组件条目列表
|
||||
"""
|
||||
comps = self._by_plugin.get(plugin_id, [])
|
||||
return [c for c in comps if self.check_component_enabled(c, session_id)] if enabled_only else list(comps)
|
||||
|
||||
def find_command_by_text(
|
||||
self, text: str, session_id: Optional[str] = None
|
||||
) -> Optional[Tuple[ComponentEntry, Dict[str, Any]]]:
|
||||
"""通过文本匹配命令正则,返回 (组件, matched_groups) 元组。
|
||||
|
||||
matched_groups 为正则命名捕获组 dict,别名匹配时为空 dict。
|
||||
Args:
|
||||
text (str): 待匹配文本
|
||||
session_id (Optional[str]): 可选的会话ID,若提供则考虑会话禁用状态
|
||||
Returns:
|
||||
result (Optional[tuple[ComponentEntry, Dict[str, Any]]]): 匹配到的组件及正则捕获组,未找到时为 None
|
||||
"""
|
||||
for comp in self._by_type.get("command", {}).values():
|
||||
if not comp.enabled:
|
||||
for comp in self._by_type.get(ComponentTypes.COMMAND, {}).values():
|
||||
if not self.check_component_enabled(comp, session_id):
|
||||
continue
|
||||
if comp._compiled_pattern:
|
||||
m = comp._compiled_pattern.search(text)
|
||||
if m:
|
||||
if not isinstance(comp, CommandEntry):
|
||||
continue
|
||||
if comp.compiled_pattern:
|
||||
if m := comp.compiled_pattern.search(text):
|
||||
return comp, m.groupdict()
|
||||
# 别名匹配
|
||||
aliases = comp.metadata.get("aliases", [])
|
||||
for alias in aliases:
|
||||
for alias in comp.aliases:
|
||||
if text.startswith(alias):
|
||||
return comp, {}
|
||||
return None
|
||||
|
||||
def get_event_handlers(self, event_type: str, *, enabled_only: bool = True) -> List[RegisteredComponent]:
|
||||
"""获取特定事件类型的所有 event_handler,按 weight 降序排列。"""
|
||||
handlers = []
|
||||
for comp in self._by_type.get("event_handler", {}).values():
|
||||
if enabled_only and not comp.enabled:
|
||||
def get_event_handlers(
|
||||
self, event_type: str, *, enabled_only: bool = True, session_id: Optional[str] = None
|
||||
) -> List[EventHandlerEntry]:
|
||||
"""查询指定事件类型的事件处理器组件。
|
||||
|
||||
Args:
|
||||
event_type (str): 事件类型
|
||||
enabled_only (bool): 是否仅返回启用的组件
|
||||
session_id (Optional[str]): 可选的会话ID,若提供则考虑会话禁用状态
|
||||
Returns:
|
||||
handlers (List[EventHandlerEntry]): 符合条件的 EventHandler 组件列表,按 weight 降序排序
|
||||
"""
|
||||
handlers: List[EventHandlerEntry] = []
|
||||
for comp in self._by_type.get(ComponentTypes.EVENT_HANDLER, {}).values():
|
||||
if enabled_only and not self.check_component_enabled(comp, session_id):
|
||||
continue
|
||||
if comp.metadata.get("event_type") == event_type:
|
||||
if not isinstance(comp, EventHandlerEntry):
|
||||
continue
|
||||
if comp.event_type == event_type:
|
||||
handlers.append(comp)
|
||||
handlers.sort(key=lambda c: c.metadata.get("weight", 0), reverse=True)
|
||||
handlers.sort(key=lambda c: c.weight, reverse=True)
|
||||
return handlers
|
||||
|
||||
def get_workflow_steps(self, stage: str, *, enabled_only: bool = True) -> List[RegisteredComponent]:
|
||||
"""获取特定 workflow 阶段的所有步骤,按 priority 降序。"""
|
||||
steps = []
|
||||
for comp in self._by_type.get("workflow_step", {}).values():
|
||||
if enabled_only and not comp.enabled:
|
||||
def get_hook_handlers(
|
||||
self, stage: str, *, enabled_only: bool = True, session_id: Optional[str] = None
|
||||
) -> List[HookHandlerEntry]:
|
||||
"""获取特定 hook 阶段的所有步骤,按 priority 降序。
|
||||
|
||||
Args:
|
||||
stage: hook 名称
|
||||
enabled_only: 是否仅返回启用的组件
|
||||
session_id: 可选的会话ID,若提供则考虑会话禁用状态
|
||||
Returns:
|
||||
handlers (List[HookHandlerEntry]): 符合条件的 HookHandler 组件列表,按 priority 降序排序
|
||||
"""
|
||||
handlers: List[HookHandlerEntry] = []
|
||||
for comp in self._by_type.get(ComponentTypes.HOOK_HANDLER, {}).values():
|
||||
if enabled_only and not self.check_component_enabled(comp, session_id):
|
||||
continue
|
||||
if comp.metadata.get("stage") == stage:
|
||||
steps.append(comp)
|
||||
steps.sort(key=lambda c: c.metadata.get("priority", 0), reverse=True)
|
||||
return steps
|
||||
if not isinstance(comp, HookHandlerEntry):
|
||||
continue
|
||||
if comp.stage == stage:
|
||||
handlers.append(comp)
|
||||
handlers.sort(key=lambda c: c.priority, reverse=True)
|
||||
return handlers
|
||||
|
||||
def get_tools_for_llm(self, *, enabled_only: bool = True) -> List[Dict[str, Any]]:
|
||||
"""获取可供 LLM 使用的工具列表(openai function-calling 格式预览)。"""
|
||||
result: List[Dict[str, Any]] = []
|
||||
for comp in self.get_components_by_type("tool", enabled_only=enabled_only):
|
||||
tool_def: Dict[str, Any] = {
|
||||
"name": comp.full_name,
|
||||
"description": comp.metadata.get("description", ""),
|
||||
}
|
||||
# 从结构化参数或原始参数构建 parameters
|
||||
params = comp.metadata.get("parameters", [])
|
||||
params_raw = comp.metadata.get("parameters_raw", {})
|
||||
if params:
|
||||
tool_def["parameters"] = params
|
||||
elif params_raw:
|
||||
tool_def["parameters"] = params_raw
|
||||
result.append(tool_def)
|
||||
return result
|
||||
def get_message_gateway(
|
||||
self,
|
||||
plugin_id: str,
|
||||
name: str,
|
||||
*,
|
||||
enabled_only: bool = True,
|
||||
session_id: Optional[str] = None,
|
||||
) -> Optional[MessageGatewayEntry]:
|
||||
"""按插件和组件名获取单个消息网关。
|
||||
|
||||
# ──── 统计 ─────────────────────────────────────────────────
|
||||
Args:
|
||||
plugin_id: 插件 ID。
|
||||
name: 网关组件名称。
|
||||
enabled_only: 是否仅返回启用的组件。
|
||||
session_id: 可选的会话 ID。
|
||||
|
||||
def get_stats(self) -> Dict[str, int]:
|
||||
"""获取注册统计。"""
|
||||
stats: Dict[str, int] = {"total": len(self._components)}
|
||||
Returns:
|
||||
Optional[MessageGatewayEntry]: 若存在则返回消息网关条目。
|
||||
"""
|
||||
|
||||
component = self._components.get(f"{plugin_id}.{name}")
|
||||
if not isinstance(component, MessageGatewayEntry):
|
||||
return None
|
||||
if enabled_only and not self.check_component_enabled(component, session_id):
|
||||
return None
|
||||
return component
|
||||
|
||||
def get_message_gateways(
|
||||
self,
|
||||
*,
|
||||
plugin_id: Optional[str] = None,
|
||||
platform: str = "",
|
||||
route_type: str = "",
|
||||
enabled_only: bool = True,
|
||||
session_id: Optional[str] = None,
|
||||
) -> List[MessageGatewayEntry]:
|
||||
"""查询消息网关组件列表。
|
||||
|
||||
Args:
|
||||
plugin_id: 可选的插件 ID 过滤条件。
|
||||
platform: 可选的平台过滤条件。
|
||||
route_type: 可选的路由类型过滤条件。
|
||||
enabled_only: 是否仅返回启用的组件。
|
||||
session_id: 可选的会话 ID。
|
||||
|
||||
Returns:
|
||||
List[MessageGatewayEntry]: 符合条件的消息网关组件列表。
|
||||
"""
|
||||
|
||||
normalized_platform = str(platform or "").strip()
|
||||
normalized_route_type = str(route_type or "").strip().lower()
|
||||
gateways: List[MessageGatewayEntry] = []
|
||||
for comp in self._by_type.get(ComponentTypes.MESSAGE_GATEWAY, {}).values():
|
||||
if not isinstance(comp, MessageGatewayEntry):
|
||||
continue
|
||||
if plugin_id and comp.plugin_id != plugin_id:
|
||||
continue
|
||||
if enabled_only and not self.check_component_enabled(comp, session_id):
|
||||
continue
|
||||
if normalized_platform and comp.platform != normalized_platform:
|
||||
continue
|
||||
if normalized_route_type and comp.route_type != normalized_route_type:
|
||||
continue
|
||||
gateways.append(comp)
|
||||
return gateways
|
||||
|
||||
def get_tools(self, *, enabled_only: bool = True, session_id: Optional[str] = None) -> List[ToolEntry]:
|
||||
"""查询所有工具组件。
|
||||
|
||||
Args:
|
||||
enabled_only (bool): 是否仅返回启用的组件
|
||||
session_id (Optional[str]): 可选的会话ID,若提供则考虑会话禁用状态
|
||||
Returns:
|
||||
tools (List[ToolEntry]): 符合条件的 Tool 组件列表
|
||||
"""
|
||||
tools: List[ToolEntry] = []
|
||||
for comp in self._by_type.get(ComponentTypes.TOOL, {}).values():
|
||||
if enabled_only and not self.check_component_enabled(comp, session_id):
|
||||
continue
|
||||
if isinstance(comp, ToolEntry):
|
||||
tools.append(comp)
|
||||
return tools
|
||||
|
||||
# ====== 统计信息 ======
|
||||
def get_stats(self) -> StatusDict:
|
||||
"""获取注册统计。
|
||||
|
||||
Returns:
|
||||
stats (StatusDict): 组件统计信息,包括总数、各类型数量、插件数量等
|
||||
"""
|
||||
stats: StatusDict = {"total": len(self._components)} # type: ignore
|
||||
for comp_type, type_dict in self._by_type.items():
|
||||
stats[comp_type] = len(type_dict)
|
||||
stats[comp_type.value.lower()] = len(type_dict)
|
||||
stats["plugins"] = len(self._by_plugin)
|
||||
return stats
|
||||
|
||||
@@ -4,40 +4,40 @@
|
||||
1. 按事件类型查询已注册的 event_handler(通过 ComponentRegistry)
|
||||
2. 按 weight 排序,依次通过 RPC 调用 Runner 中的处理器
|
||||
3. 支持阻塞(intercept_message)和非阻塞分发
|
||||
4. 事件结果历史记录
|
||||
4. 事件结果历史记录(有上限)
|
||||
"""
|
||||
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING
|
||||
|
||||
import asyncio
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry, RegisteredComponent
|
||||
|
||||
from .message_utils import PluginMessageUtils, MessageDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .supervisor import PluginRunnerSupervisor
|
||||
from .component_registry import ComponentRegistry, EventHandlerEntry
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
|
||||
logger = get_logger("plugin_runtime.host.event_dispatcher")
|
||||
|
||||
# invoke_fn 类型: async (plugin_id, component_name, args) -> response_payload dict
|
||||
InvokeFn = Callable[[str, str, Dict[str, Any]], Awaitable[Dict[str, Any]]]
|
||||
# 每个事件类型的最大历史记录数量,防止内存无限增长
|
||||
_MAX_HISTORY_LENGTH = 100
|
||||
|
||||
|
||||
@dataclass
|
||||
class EventResult:
|
||||
"""单个 EventHandler 的执行结果"""
|
||||
|
||||
__slots__ = ("handler_name", "success", "continue_processing", "modified_message", "custom_result")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
handler_name: str,
|
||||
success: bool = True,
|
||||
continue_processing: bool = True,
|
||||
modified_message: Optional[Dict[str, Any]] = None,
|
||||
custom_result: Any = None,
|
||||
):
|
||||
self.handler_name = handler_name
|
||||
self.success = success
|
||||
self.continue_processing = continue_processing
|
||||
self.modified_message = modified_message
|
||||
self.custom_result = custom_result
|
||||
handler_name: str
|
||||
success: bool = field(default=True)
|
||||
continue_processing: bool = field(default=True)
|
||||
modified_message: Optional[MessageDict] = field(default=None)
|
||||
custom_result: Any = field(default=None)
|
||||
|
||||
|
||||
class EventDispatcher:
|
||||
@@ -48,17 +48,20 @@ class EventDispatcher:
|
||||
再通过提供的 invoke_fn 回调 RPC 到 Runner 执行。
|
||||
"""
|
||||
|
||||
def __init__(self, registry: ComponentRegistry) -> None:
|
||||
self._registry: ComponentRegistry = registry
|
||||
def __init__(self, component_registry: "ComponentRegistry") -> None:
|
||||
self._component_registry: "ComponentRegistry" = component_registry
|
||||
self._result_history: Dict[str, List[EventResult]] = {}
|
||||
self._history_enabled: Set[str] = set()
|
||||
# 保持 fire-and-forget task 的强引用,防止被 GC 回收
|
||||
self._background_tasks: Set[asyncio.Task] = set()
|
||||
|
||||
def enable_history(self, event_type: str) -> None:
|
||||
self._history_enabled.add(event_type)
|
||||
self._result_history.setdefault(event_type, [])
|
||||
|
||||
def disable_history(self, event_type: str) -> None:
|
||||
self._history_enabled.discard(event_type)
|
||||
self._result_history.pop(event_type, None)
|
||||
|
||||
def get_history(self, event_type: str) -> List[EventResult]:
|
||||
return self._result_history.get(event_type, [])
|
||||
|
||||
@@ -66,47 +69,58 @@ class EventDispatcher:
|
||||
if event_type in self._result_history:
|
||||
self._result_history[event_type] = []
|
||||
|
||||
async def stop(self):
|
||||
"""停止 EventDispatcher,取消所有未完成的后台任务"""
|
||||
for task in self._background_tasks:
|
||||
task.cancel()
|
||||
await asyncio.gather(*self._background_tasks, return_exceptions=True)
|
||||
self._background_tasks.clear()
|
||||
|
||||
async def dispatch_event(
|
||||
self,
|
||||
event_type: str,
|
||||
invoke_fn: InvokeFn,
|
||||
message: Optional[Dict[str, Any]] = None,
|
||||
supervisor: "PluginRunnerSupervisor",
|
||||
message: Optional["SessionMessage"] = None,
|
||||
extra_args: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[bool, Optional[Dict[str, Any]]]:
|
||||
"""分发事件到所有对应 handler。
|
||||
) -> Tuple[bool, Optional["SessionMessage"]]:
|
||||
"""分发事件到所有对应 handler 的便捷方法。
|
||||
|
||||
内置了通过 PluginSupervisor.invoke_plugin 调用 plugin.emit_event 的逻辑,
|
||||
无需调用方手动构造 invoke_fn 闭包。
|
||||
|
||||
Args:
|
||||
event_type: 事件类型字符串
|
||||
invoke_fn: 异步回调,签名 (plugin_id, component_name, args) -> response_payload dict
|
||||
supervisor: PluginSupervisor 实例,用于调用 invoke_plugin
|
||||
message: MaiMessages 序列化后的 dict(可选)
|
||||
extra_args: 额外参数
|
||||
|
||||
Returns:
|
||||
(should_continue, modified_message_dict)
|
||||
(should_continue, modified_message_dict) (bool, SessionMessage | None): (是否继续后续执行, 可选的修改后的消息)
|
||||
"""
|
||||
handlers = self._registry.get_event_handlers(event_type)
|
||||
if not handlers:
|
||||
handler_entries = self._component_registry.get_event_handlers(event_type)
|
||||
if not handler_entries:
|
||||
return True, None
|
||||
|
||||
should_continue = True
|
||||
modified_message: Optional[Dict[str, Any]] = None
|
||||
intercept_handlers: List[RegisteredComponent] = []
|
||||
async_handlers: List[RegisteredComponent] = []
|
||||
modified_message: Optional[MessageDict] = (
|
||||
PluginMessageUtils._session_message_to_dict(message) if message else None
|
||||
)
|
||||
intercept_handlers: List["EventHandlerEntry"] = []
|
||||
non_blocking_handlers: List["EventHandlerEntry"] = []
|
||||
|
||||
for handler in handlers:
|
||||
if handler.metadata.get("intercept_message", False):
|
||||
intercept_handlers.append(handler)
|
||||
for entry in handler_entries:
|
||||
if entry.intercept_message:
|
||||
intercept_handlers.append(entry)
|
||||
else:
|
||||
async_handlers.append(handler)
|
||||
non_blocking_handlers.append(entry)
|
||||
|
||||
for handler in intercept_handlers:
|
||||
for entry in intercept_handlers:
|
||||
args = {
|
||||
"event_type": event_type,
|
||||
"message": modified_message or message,
|
||||
"message": modified_message,
|
||||
**(extra_args or {}),
|
||||
}
|
||||
|
||||
result = await self._invoke_handler(invoke_fn, handler, args, event_type)
|
||||
result = await self._invoke_handler(supervisor, entry, args, event_type)
|
||||
if result and not result.continue_processing:
|
||||
should_continue = False
|
||||
break
|
||||
@@ -114,47 +128,57 @@ class EventDispatcher:
|
||||
modified_message = result.modified_message
|
||||
|
||||
if should_continue:
|
||||
final_message = modified_message or message
|
||||
for handler in async_handlers:
|
||||
async_message = final_message.copy() if isinstance(final_message, dict) else final_message
|
||||
final_message = modified_message
|
||||
for entry in non_blocking_handlers:
|
||||
async_message = final_message.copy() if final_message else final_message
|
||||
args = {
|
||||
"event_type": event_type,
|
||||
"message": async_message,
|
||||
**(extra_args or {}),
|
||||
}
|
||||
# 非阻塞:保持实例级强引用,防止 task 被 GC 回收
|
||||
task = asyncio.create_task(self._invoke_handler(invoke_fn, handler, args, event_type))
|
||||
task = asyncio.create_task(self._invoke_handler(supervisor, entry, args, event_type))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
|
||||
return should_continue, modified_message
|
||||
try:
|
||||
modified_message_obj = (
|
||||
PluginMessageUtils._build_session_message_from_dict(modified_message) if modified_message else None # type: ignore
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"构建修改后的 SessionMessage 失败: {e}")
|
||||
modified_message_obj = None
|
||||
return should_continue, modified_message_obj
|
||||
|
||||
async def _invoke_handler(
|
||||
self,
|
||||
invoke_fn: InvokeFn,
|
||||
handler: RegisteredComponent,
|
||||
supervisor: "PluginRunnerSupervisor",
|
||||
handler_entry: "EventHandlerEntry",
|
||||
args: Dict[str, Any],
|
||||
event_type: str,
|
||||
) -> Optional[EventResult]:
|
||||
"""调用单个 handler 并收集结果。"""
|
||||
try:
|
||||
resp = await invoke_fn(handler.plugin_id, handler.name, args)
|
||||
resp_envelope = await supervisor.invoke_plugin(
|
||||
"plugin.emit_event", handler_entry.plugin_id, handler_entry.name, args
|
||||
)
|
||||
resp = resp_envelope.payload
|
||||
result = EventResult(
|
||||
handler_name=handler.full_name,
|
||||
handler_name=handler_entry.full_name,
|
||||
success=resp.get("success", True),
|
||||
continue_processing=resp.get("continue_processing", True),
|
||||
modified_message=resp.get("modified_message"),
|
||||
custom_result=resp.get("custom_result"),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"EventHandler {handler.full_name} 执行失败: {e}", exc_info=True)
|
||||
result = EventResult(
|
||||
handler_name=handler.full_name,
|
||||
success=False,
|
||||
continue_processing=True,
|
||||
)
|
||||
logger.error(f"EventHandler {handler_entry.full_name} 执行失败: {e}", exc_info=True)
|
||||
result = EventResult(handler_name=handler_entry.full_name, success=False, continue_processing=True)
|
||||
|
||||
if event_type in self._history_enabled:
|
||||
self._result_history.setdefault(event_type, []).append(result)
|
||||
history_list = self._result_history.setdefault(event_type, [])
|
||||
history_list.append(result)
|
||||
# 自动清理超出限制的旧记录,防止内存无限增长
|
||||
if len(history_list) > _MAX_HISTORY_LENGTH:
|
||||
# 保留最新的 _MAX_HISTORY_LENGTH 条记录
|
||||
self._result_history[event_type] = history_list[-_MAX_HISTORY_LENGTH:]
|
||||
|
||||
return result
|
||||
|
||||
166
src/plugin_runtime/host/hook_dispatcher.py
Normal file
166
src/plugin_runtime/host/hook_dispatcher.py
Normal file
@@ -0,0 +1,166 @@
|
||||
"""
|
||||
Hook Dispatch 系统
|
||||
|
||||
插件可以注册自己的Hook,当特定函数被调用时,Hook Dispatch系统会将调用转发给插件的Hook处理函数。
|
||||
每个Hook的参数随Hook点位确定,因此参数是易变的。插件开发者需要根据Hook点位的定义来编写Hook处理函数。
|
||||
在参数/返回值匹配的情况下允许修改参数/返回值。
|
||||
|
||||
HookDispatcher 负责:
|
||||
1. 按 stage 查询已注册的 hook_handler(通过 ComponentRegistry)
|
||||
2. 按 priority 排序,区分 blocking 和非 blocking 模式
|
||||
3. blocking 模式:依次同步调用,支持修改参数/提前终止
|
||||
4. 非 blocking 模式:异步调用,不阻塞主流程
|
||||
5. 支持通过 global_config.plugin_runtime.hook_blocking_timeout_sec 设置超时上限
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .supervisor import PluginRunnerSupervisor
|
||||
from .component_registry import ComponentRegistry, HookHandlerEntry
|
||||
|
||||
logger = get_logger("plugin_runtime.host.hook_dispatcher")
|
||||
|
||||
|
||||
@dataclass
|
||||
class HookResult:
|
||||
"""单个 HookHandler 的执行结果"""
|
||||
|
||||
handler_name: str
|
||||
success: bool = field(default=True)
|
||||
continue_processing: bool = field(default=True)
|
||||
modified_kwargs: Optional[Dict[str, Any]] = field(default=None)
|
||||
custom_result: Any = field(default=None)
|
||||
|
||||
|
||||
class HookDispatcher:
|
||||
"""Host-side Hook 分发器
|
||||
|
||||
由业务层调用 hook_dispatch(),
|
||||
内部通过 ComponentRegistry 查询 handler,
|
||||
再通过提供的 invoke_fn 回调 RPC 到 Runner 执行。
|
||||
"""
|
||||
|
||||
def __init__(self, component_registry: "ComponentRegistry") -> None:
|
||||
"""初始化 HookDispatcher
|
||||
|
||||
Args:
|
||||
component_registry: ComponentRegistry 实例,用于查询已注册的 hook_handler
|
||||
"""
|
||||
self._component_registry: "ComponentRegistry" = component_registry
|
||||
self._background_tasks: Set[asyncio.Task] = set()
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""停止 HookDispatcher,取消所有未完成的后台任务"""
|
||||
for task in self._background_tasks:
|
||||
task.cancel()
|
||||
await asyncio.gather(*self._background_tasks, return_exceptions=True)
|
||||
self._background_tasks.clear()
|
||||
|
||||
async def hook_dispatch(
|
||||
self,
|
||||
stage: str,
|
||||
supervisor: "PluginRunnerSupervisor",
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
"""分发 hook 到所有对应 handler 的便捷方法。
|
||||
|
||||
内置了通过 PluginRunnerSupervisor.invoke_plugin 调用 plugin 的逻辑,
|
||||
无需调用方手动构造 invoke_fn 闭包。
|
||||
|
||||
Args:
|
||||
stage: hook 名称
|
||||
supervisor: PluginRunnerSupervisor 实例,用于调用 invoke_plugin
|
||||
**kwargs: 关键字参数,会展开传递给 handler
|
||||
|
||||
Returns:
|
||||
modified_kwargs (Dict[str, Any]): 经过所有 handler 修改后的关键字参数
|
||||
"""
|
||||
handler_entries = self._component_registry.get_hook_handlers(stage)
|
||||
if not handler_entries:
|
||||
return kwargs
|
||||
|
||||
current_kwargs = kwargs.copy()
|
||||
blocking_handlers: List["HookHandlerEntry"] = []
|
||||
non_blocking_handlers: List["HookHandlerEntry"] = []
|
||||
|
||||
# 分离 blocking 和非 blocking handler
|
||||
for entry in handler_entries:
|
||||
if entry.blocking:
|
||||
blocking_handlers.append(entry)
|
||||
else:
|
||||
non_blocking_handlers.append(entry)
|
||||
|
||||
# 处理 blocking handlers(同步调用,支持修改参数/提前终止)
|
||||
timeout = global_config.plugin_runtime.hook_blocking_timeout_sec or 30.0
|
||||
for entry in blocking_handlers:
|
||||
hook_args = {"stage": stage, **current_kwargs}
|
||||
try:
|
||||
# 应用超时控制
|
||||
result = await asyncio.wait_for(
|
||||
self._invoke_handler(supervisor, entry, hook_args),
|
||||
timeout=timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"Blocking HookHandler {entry.full_name} 执行超时 (>{timeout}秒),跳过")
|
||||
result = HookResult(handler_name=entry.full_name, success=False, continue_processing=True)
|
||||
|
||||
if result:
|
||||
if result.modified_kwargs is not None:
|
||||
current_kwargs = result.modified_kwargs
|
||||
if not result.continue_processing:
|
||||
logger.info(f"HookHandler {entry.full_name} 终止了后续处理")
|
||||
break
|
||||
|
||||
# 处理 non-blocking handlers(异步调用,不阻塞主流程)
|
||||
for entry in non_blocking_handlers:
|
||||
async_kwargs = current_kwargs.copy()
|
||||
hook_args = {"stage": stage, **async_kwargs}
|
||||
task = asyncio.create_task(
|
||||
asyncio.wait_for(self._invoke_handler(supervisor, entry, hook_args), timeout=timeout)
|
||||
)
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
|
||||
return current_kwargs
|
||||
|
||||
async def _invoke_handler(
|
||||
self,
|
||||
supervisor: "PluginRunnerSupervisor",
|
||||
handler_entry: "HookHandlerEntry",
|
||||
args: Dict[str, Any],
|
||||
) -> Optional[HookResult]:
|
||||
"""调用单个 handler 并收集结果。
|
||||
|
||||
Args:
|
||||
supervisor: PluginRunnerSupervisor 实例
|
||||
handler_entry: HookHandlerEntry 实例
|
||||
args: 传递给 handler 的参数字典
|
||||
stage: hook 名称
|
||||
|
||||
Returns:
|
||||
Optional[HookResult]: 执行结果,如果执行失败则返回 None
|
||||
"""
|
||||
try:
|
||||
resp_envelope = await supervisor.invoke_plugin(
|
||||
"plugin.invoke_hook", handler_entry.plugin_id, handler_entry.name, args
|
||||
)
|
||||
resp = resp_envelope.payload
|
||||
result = HookResult(
|
||||
handler_name=handler_entry.full_name,
|
||||
success=resp.get("success", True),
|
||||
continue_processing=resp.get("continue_processing", True),
|
||||
modified_kwargs=resp.get("modified_kwargs"),
|
||||
custom_result=resp.get("custom_result"),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"HookHandler {handler_entry.full_name} 执行失败:{e}", exc_info=True)
|
||||
result = HookResult(handler_name=handler_entry.full_name, success=False, continue_processing=True)
|
||||
|
||||
return result
|
||||
45
src/plugin_runtime/host/logger_bridge.py
Normal file
45
src/plugin_runtime/host/logger_bridge.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import logging as stdlib_logging
|
||||
from src.plugin_runtime.protocol.errors import ErrorCode
|
||||
from src.plugin_runtime.protocol.envelope import Envelope, LogBatchPayload
|
||||
class RunnerLogBridge:
|
||||
"""将 Runner 进程上报的批量日志重放到主进程的 Logger 中。
|
||||
|
||||
Runner 通过 ``runner.log_batch`` IPC 事件批量到达。
|
||||
每条 LogEntry 被重建为一个真实的 :class:`logging.LogRecord` 并直接
|
||||
调用 ``logging.getLogger(entry.logger_name).handle(record)``,
|
||||
从而接入主进程已配置好的 structlog Handler 链。
|
||||
"""
|
||||
|
||||
async def handle_log_batch(self, envelope: Envelope) -> Envelope:
|
||||
"""IPC 事件处理器:解析批量日志并重放到主进程 Logger。
|
||||
|
||||
Args:
|
||||
envelope: 方法名为 ``runner.log_batch`` 的 IPC 事件信封。
|
||||
|
||||
Returns:
|
||||
空响应信封(事件模式下将被忽略)。
|
||||
"""
|
||||
try:
|
||||
batch = LogBatchPayload.model_validate(envelope.payload)
|
||||
except Exception as exc:
|
||||
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
|
||||
|
||||
for entry in batch.entries:
|
||||
# 重建一个与原始日志尽量相符的 LogRecord
|
||||
record = stdlib_logging.LogRecord(
|
||||
name=entry.logger_name,
|
||||
level=entry.level,
|
||||
pathname="<runner>",
|
||||
lineno=0,
|
||||
msg=entry.message,
|
||||
args=(),
|
||||
exc_info=None,
|
||||
)
|
||||
record.created = entry.timestamp_ms / 1000.0
|
||||
record.msecs = entry.timestamp_ms % 1000
|
||||
if entry.exception_text:
|
||||
record.exc_text = entry.exception_text
|
||||
|
||||
stdlib_logging.getLogger(entry.logger_name).handle(record)
|
||||
|
||||
return envelope.make_response(payload={"accepted": True, "count": len(batch.entries)})
|
||||
112
src/plugin_runtime/host/message_gateway.py
Normal file
112
src/plugin_runtime/host/message_gateway.py
Normal file
@@ -0,0 +1,112 @@
|
||||
"""Host 侧消息网关包装器。"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.platform_io import get_platform_io_manager
|
||||
|
||||
from .message_utils import PluginMessageUtils
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from .component_registry import ComponentRegistry
|
||||
from .supervisor import PluginRunnerSupervisor
|
||||
|
||||
logger = get_logger("plugin_runtime.host.message_gateway")
|
||||
|
||||
|
||||
class MessageGateway:
|
||||
"""Host 侧消息网关包装器。"""
|
||||
|
||||
def __init__(self, component_registry: "ComponentRegistry") -> None:
|
||||
"""初始化消息网关。
|
||||
|
||||
Args:
|
||||
component_registry: 组件注册表。
|
||||
"""
|
||||
self._component_registry = component_registry
|
||||
|
||||
def build_session_message(self, external_message: Dict[str, Any]) -> "SessionMessage":
|
||||
"""将标准消息字典转换为 ``SessionMessage``。
|
||||
|
||||
Args:
|
||||
external_message: 外部消息的字典格式数据。
|
||||
|
||||
Returns:
|
||||
SessionMessage: 转换后的内部消息对象。
|
||||
|
||||
Raises:
|
||||
ValueError: 消息字典不合法时抛出。
|
||||
"""
|
||||
return PluginMessageUtils._build_session_message_from_dict(external_message)
|
||||
|
||||
def build_message_dict(self, internal_message: "SessionMessage") -> Dict[str, Any]:
|
||||
"""将 ``SessionMessage`` 转换为标准消息字典。
|
||||
|
||||
Args:
|
||||
internal_message: 内部消息对象。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 供消息网关插件消费的标准消息字典。
|
||||
"""
|
||||
return dict(PluginMessageUtils._session_message_to_dict(internal_message))
|
||||
|
||||
async def receive_external_message(self, external_message: Dict[str, Any]) -> None:
|
||||
"""接收外部消息并送入主消息链。
|
||||
|
||||
Args:
|
||||
external_message: 外部消息的字典格式数据。
|
||||
"""
|
||||
try:
|
||||
session_message = self.build_session_message(external_message)
|
||||
except Exception as e:
|
||||
logger.error(f"转换外部消息失败: {e}")
|
||||
return
|
||||
|
||||
from src.chat.message_receive.bot import chat_bot
|
||||
|
||||
await chat_bot.receive_message(session_message)
|
||||
|
||||
async def send_message_to_external(
|
||||
self,
|
||||
internal_message: "SessionMessage",
|
||||
supervisor: "PluginRunnerSupervisor",
|
||||
*,
|
||||
enabled_only: bool = True,
|
||||
save_to_db: bool = True,
|
||||
) -> bool:
|
||||
"""将内部消息通过 Platform IO 发送到外部平台。
|
||||
|
||||
Args:
|
||||
internal_message: 系统内部的 ``SessionMessage`` 对象。
|
||||
supervisor: 当前持有该消息网关的 Supervisor。
|
||||
enabled_only: 兼容旧签名的保留参数,当前未使用。
|
||||
save_to_db: 发送成功后是否写入数据库。
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功。
|
||||
"""
|
||||
del enabled_only
|
||||
del supervisor
|
||||
|
||||
platform_io_manager = get_platform_io_manager()
|
||||
if not platform_io_manager.is_started:
|
||||
logger.warning("Platform IO 尚未启动,无法通过适配器链路发送消息")
|
||||
return False
|
||||
|
||||
route_key = platform_io_manager.build_route_key_from_message(internal_message)
|
||||
delivery_batch = await platform_io_manager.send_message(internal_message, route_key)
|
||||
if not delivery_batch.has_success:
|
||||
logger.warning("通过消息网关链路发送消息失败: 未命中任何成功回执")
|
||||
return False
|
||||
|
||||
first_successful_receipt = delivery_batch.sent_receipts[0]
|
||||
internal_message.message_id = first_successful_receipt.external_message_id or internal_message.message_id
|
||||
if save_to_db:
|
||||
try:
|
||||
from src.common.utils.utils_message import MessageUtils
|
||||
|
||||
MessageUtils.store_message_to_db(internal_message)
|
||||
except Exception as e:
|
||||
logger.error(f"保存消息到数据库失败: {e}")
|
||||
return True
|
||||
487
src/plugin_runtime/host/message_utils.py
Normal file
487
src/plugin_runtime/host/message_utils.py
Normal file
@@ -0,0 +1,487 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, TypedDict
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.common.data_models.mai_message_data_model import UserInfo, GroupInfo, MessageInfo
|
||||
from src.common.data_models.message_component_data_model import (
|
||||
AtComponent,
|
||||
DictComponent,
|
||||
EmojiComponent,
|
||||
ForwardComponent,
|
||||
ForwardNodeComponent,
|
||||
ImageComponent,
|
||||
MessageSequence,
|
||||
ReplyComponent,
|
||||
StandardMessageComponents,
|
||||
TextComponent,
|
||||
VoiceComponent,
|
||||
)
|
||||
|
||||
logger = get_logger("plugin_runtime.host.message_utils")
|
||||
|
||||
|
||||
class UserInfoDict(TypedDict, total=False):
|
||||
user_id: str
|
||||
user_nickname: str
|
||||
user_cardname: Optional[str]
|
||||
|
||||
|
||||
class GroupInfoDict(TypedDict, total=False):
|
||||
group_id: str
|
||||
group_name: str
|
||||
|
||||
|
||||
class MessageInfoDict(TypedDict, total=False):
|
||||
user_info: UserInfoDict
|
||||
group_info: Optional[GroupInfoDict]
|
||||
additional_config: Dict[str, Any]
|
||||
|
||||
|
||||
class MessageDict(TypedDict, total=False):
|
||||
message_id: str
|
||||
timestamp: str
|
||||
platform: str
|
||||
message_info: MessageInfoDict
|
||||
raw_message: List[Dict[str, Any]]
|
||||
is_mentioned: bool
|
||||
is_at: bool
|
||||
is_emoji: bool
|
||||
is_picture: bool
|
||||
is_command: bool
|
||||
is_notify: bool
|
||||
session_id: str
|
||||
reply_to: Optional[str]
|
||||
processed_plain_text: Optional[str]
|
||||
display_message: Optional[str]
|
||||
|
||||
|
||||
class PluginMessageUtils:
|
||||
@staticmethod
|
||||
def _message_sequence_to_dict(message_sequence: MessageSequence) -> List[Dict[str, Any]]:
|
||||
"""将消息组件序列转换为插件运行时使用的字典结构。
|
||||
|
||||
Args:
|
||||
message_sequence: 待转换的消息组件序列。
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 供插件运行时协议使用的消息段字典列表。
|
||||
"""
|
||||
return [PluginMessageUtils._component_to_dict(component) for component in message_sequence.components]
|
||||
|
||||
@staticmethod
|
||||
def _component_to_dict(component: StandardMessageComponents) -> Dict[str, Any]:
|
||||
"""将单个消息组件转换为插件运行时字典结构。
|
||||
|
||||
Args:
|
||||
component: 待转换的消息组件。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 序列化后的消息组件字典。
|
||||
"""
|
||||
if isinstance(component, TextComponent):
|
||||
return {"type": "text", "data": component.text}
|
||||
|
||||
if isinstance(component, ImageComponent):
|
||||
serialized = {
|
||||
"type": "image",
|
||||
"data": component.content,
|
||||
"hash": component.binary_hash,
|
||||
}
|
||||
if component.binary_data:
|
||||
serialized["binary_data_base64"] = base64.b64encode(component.binary_data).decode("utf-8")
|
||||
return serialized
|
||||
|
||||
if isinstance(component, EmojiComponent):
|
||||
serialized = {
|
||||
"type": "emoji",
|
||||
"data": component.content,
|
||||
"hash": component.binary_hash,
|
||||
}
|
||||
if component.binary_data:
|
||||
serialized["binary_data_base64"] = base64.b64encode(component.binary_data).decode("utf-8")
|
||||
return serialized
|
||||
|
||||
if isinstance(component, VoiceComponent):
|
||||
serialized = {
|
||||
"type": "voice",
|
||||
"data": component.content,
|
||||
"hash": component.binary_hash,
|
||||
}
|
||||
if component.binary_data:
|
||||
serialized["binary_data_base64"] = base64.b64encode(component.binary_data).decode("utf-8")
|
||||
return serialized
|
||||
|
||||
if isinstance(component, AtComponent):
|
||||
return {
|
||||
"type": "at",
|
||||
"data": {
|
||||
"target_user_id": component.target_user_id,
|
||||
"target_user_nickname": component.target_user_nickname,
|
||||
"target_user_cardname": component.target_user_cardname,
|
||||
},
|
||||
}
|
||||
|
||||
if isinstance(component, ReplyComponent):
|
||||
return {
|
||||
"type": "reply",
|
||||
"data": {
|
||||
"target_message_id": component.target_message_id,
|
||||
"target_message_content": component.target_message_content,
|
||||
"target_message_sender_id": component.target_message_sender_id,
|
||||
"target_message_sender_nickname": component.target_message_sender_nickname,
|
||||
"target_message_sender_cardname": component.target_message_sender_cardname,
|
||||
},
|
||||
}
|
||||
|
||||
if isinstance(component, ForwardNodeComponent):
|
||||
return {
|
||||
"type": "forward",
|
||||
"data": [PluginMessageUtils._forward_component_to_dict(item) for item in component.forward_components],
|
||||
}
|
||||
|
||||
return {"type": "dict", "data": component.data}
|
||||
|
||||
@staticmethod
|
||||
def _forward_component_to_dict(component: ForwardComponent) -> Dict[str, Any]:
|
||||
"""将单个转发节点组件转换为字典结构。
|
||||
|
||||
Args:
|
||||
component: 待转换的转发节点组件。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 序列化后的转发节点字典。
|
||||
"""
|
||||
return {
|
||||
"user_id": component.user_id,
|
||||
"user_nickname": component.user_nickname,
|
||||
"user_cardname": component.user_cardname,
|
||||
"message_id": component.message_id,
|
||||
"content": [PluginMessageUtils._component_to_dict(item) for item in component.content],
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _message_sequence_from_dict(raw_message_data: List[Dict[str, Any]]) -> MessageSequence:
|
||||
"""从插件运行时字典结构恢复消息组件序列。
|
||||
|
||||
Args:
|
||||
raw_message_data: 插件运行时消息段字典列表。
|
||||
|
||||
Returns:
|
||||
MessageSequence: 恢复后的消息组件序列。
|
||||
"""
|
||||
components = [PluginMessageUtils._component_from_dict(item) for item in raw_message_data]
|
||||
return MessageSequence(components=components)
|
||||
|
||||
@staticmethod
|
||||
def _component_from_dict(item: Dict[str, Any]) -> StandardMessageComponents:
|
||||
"""从插件运行时字典结构恢复单个消息组件。
|
||||
|
||||
Args:
|
||||
item: 单个消息组件的字典表示。
|
||||
|
||||
Returns:
|
||||
StandardMessageComponents: 恢复后的内部消息组件对象。
|
||||
"""
|
||||
item_type = str(item.get("type") or "").strip()
|
||||
if item_type == "text":
|
||||
return TextComponent(text=str(item.get("data") or ""))
|
||||
|
||||
if item_type == "image":
|
||||
return PluginMessageUtils._build_binary_component(ImageComponent, item)
|
||||
|
||||
if item_type == "emoji":
|
||||
return PluginMessageUtils._build_binary_component(EmojiComponent, item)
|
||||
|
||||
if item_type == "voice":
|
||||
return PluginMessageUtils._build_binary_component(VoiceComponent, item)
|
||||
|
||||
if item_type == "at":
|
||||
item_data = item.get("data", {})
|
||||
if not isinstance(item_data, dict):
|
||||
item_data = {}
|
||||
return AtComponent(
|
||||
target_user_id=str(item_data.get("target_user_id") or ""),
|
||||
target_user_nickname=PluginMessageUtils._normalize_optional_string(item_data.get("target_user_nickname")),
|
||||
target_user_cardname=PluginMessageUtils._normalize_optional_string(item_data.get("target_user_cardname")),
|
||||
)
|
||||
|
||||
if item_type == "reply":
|
||||
reply_data = item.get("data")
|
||||
if isinstance(reply_data, dict):
|
||||
return ReplyComponent(
|
||||
target_message_id=str(reply_data.get("target_message_id") or ""),
|
||||
target_message_content=PluginMessageUtils._normalize_optional_string(
|
||||
reply_data.get("target_message_content")
|
||||
),
|
||||
target_message_sender_id=PluginMessageUtils._normalize_optional_string(
|
||||
reply_data.get("target_message_sender_id")
|
||||
),
|
||||
target_message_sender_nickname=PluginMessageUtils._normalize_optional_string(
|
||||
reply_data.get("target_message_sender_nickname")
|
||||
),
|
||||
target_message_sender_cardname=PluginMessageUtils._normalize_optional_string(
|
||||
reply_data.get("target_message_sender_cardname")
|
||||
),
|
||||
)
|
||||
return ReplyComponent(target_message_id=str(reply_data or ""))
|
||||
|
||||
if item_type == "forward":
|
||||
forward_nodes: List[ForwardComponent] = []
|
||||
raw_forward_nodes = item.get("data", [])
|
||||
if isinstance(raw_forward_nodes, list):
|
||||
for node in raw_forward_nodes:
|
||||
if not isinstance(node, dict):
|
||||
continue
|
||||
raw_content = node.get("content", [])
|
||||
node_components: List[StandardMessageComponents] = []
|
||||
if isinstance(raw_content, list):
|
||||
node_components = [
|
||||
PluginMessageUtils._component_from_dict(content)
|
||||
for content in raw_content
|
||||
if isinstance(content, dict)
|
||||
]
|
||||
if not node_components:
|
||||
node_components = [TextComponent(text="[empty forward node]")]
|
||||
forward_nodes.append(
|
||||
ForwardComponent(
|
||||
user_nickname=str(node.get("user_nickname") or "未知用户"),
|
||||
user_id=PluginMessageUtils._normalize_optional_string(node.get("user_id")),
|
||||
user_cardname=PluginMessageUtils._normalize_optional_string(node.get("user_cardname")),
|
||||
message_id=str(node.get("message_id") or ""),
|
||||
content=node_components,
|
||||
)
|
||||
)
|
||||
if not forward_nodes:
|
||||
return DictComponent(data={"type": "forward", "data": item.get("data", [])})
|
||||
return ForwardNodeComponent(forward_components=forward_nodes)
|
||||
|
||||
component_data = item.get("data")
|
||||
if isinstance(component_data, dict):
|
||||
return DictComponent(data=component_data)
|
||||
return DictComponent(data=item)
|
||||
|
||||
@staticmethod
|
||||
def _build_binary_component(component_cls: Any, item: Dict[str, Any]) -> StandardMessageComponents:
|
||||
"""从字典构造带二进制负载的消息组件。
|
||||
|
||||
Args:
|
||||
component_cls: 目标组件类型。
|
||||
item: 消息组件字典。
|
||||
|
||||
Returns:
|
||||
StandardMessageComponents: 构造后的组件对象。
|
||||
"""
|
||||
content = str(item.get("data") or "")
|
||||
binary_hash = str(item.get("hash") or "")
|
||||
raw_binary_base64 = item.get("binary_data_base64")
|
||||
binary_data = b""
|
||||
if isinstance(raw_binary_base64, str) and raw_binary_base64:
|
||||
try:
|
||||
binary_data = base64.b64decode(raw_binary_base64)
|
||||
except Exception:
|
||||
binary_data = b""
|
||||
|
||||
if not binary_hash and binary_data:
|
||||
binary_hash = hashlib.sha256(binary_data).hexdigest()
|
||||
|
||||
return component_cls(binary_hash=binary_hash, content=content, binary_data=binary_data)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_optional_string(value: Any) -> Optional[str]:
|
||||
"""将任意值规范化为可选字符串。
|
||||
|
||||
Args:
|
||||
value: 待规范化的值。
|
||||
|
||||
Returns:
|
||||
Optional[str]: 规范化后的字符串;若值为空则返回 ``None``。
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
normalized_value = str(value)
|
||||
return normalized_value if normalized_value else None
|
||||
|
||||
@staticmethod
|
||||
def _message_info_to_dict(message_info: MessageInfo) -> MessageInfoDict:
|
||||
"""
|
||||
将 MessageInfo 对象转换为字典格式
|
||||
|
||||
Args:
|
||||
message_info: MessageInfo 对象
|
||||
|
||||
Returns:
|
||||
字典格式的消息信息
|
||||
"""
|
||||
user_info_dict = UserInfoDict(
|
||||
user_id=message_info.user_info.user_id,
|
||||
user_nickname=message_info.user_info.user_nickname,
|
||||
user_cardname=message_info.user_info.user_cardname,
|
||||
)
|
||||
|
||||
group_info_dict: Optional[GroupInfoDict] = None
|
||||
if message_info.group_info:
|
||||
group_info_dict = GroupInfoDict(
|
||||
group_id=message_info.group_info.group_id,
|
||||
group_name=message_info.group_info.group_name,
|
||||
)
|
||||
|
||||
return MessageInfoDict(
|
||||
user_info=user_info_dict,
|
||||
group_info=group_info_dict,
|
||||
additional_config=message_info.additional_config,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _session_message_to_dict(session_message: SessionMessage) -> MessageDict:
|
||||
"""
|
||||
将 SessionMessage 对象转换为字典格式(复用 MessageSequence.to_dict 方法)
|
||||
|
||||
Args:
|
||||
session_message: SessionMessage 对象
|
||||
|
||||
Returns:
|
||||
字典格式的消息
|
||||
"""
|
||||
# 转换基本信息
|
||||
message_dict = MessageDict(
|
||||
message_id=session_message.message_id,
|
||||
timestamp=str(session_message.timestamp.timestamp()), # 转换为时间戳字符串
|
||||
platform=session_message.platform,
|
||||
message_info=PluginMessageUtils._message_info_to_dict(session_message.message_info),
|
||||
raw_message=PluginMessageUtils._message_sequence_to_dict(session_message.raw_message),
|
||||
is_mentioned=session_message.is_mentioned,
|
||||
is_at=session_message.is_at,
|
||||
is_emoji=session_message.is_emoji,
|
||||
is_picture=session_message.is_picture,
|
||||
is_command=session_message.is_command,
|
||||
is_notify=session_message.is_notify,
|
||||
session_id=session_message.session_id,
|
||||
)
|
||||
|
||||
# 添加可选字段
|
||||
if session_message.reply_to is not None:
|
||||
message_dict["reply_to"] = session_message.reply_to
|
||||
if session_message.processed_plain_text is not None:
|
||||
message_dict["processed_plain_text"] = session_message.processed_plain_text
|
||||
if session_message.display_message is not None:
|
||||
message_dict["display_message"] = session_message.display_message
|
||||
|
||||
return message_dict
|
||||
|
||||
@staticmethod
|
||||
def _build_message_info_from_dict(message_info_dict: Dict[str, Any]) -> MessageInfo:
|
||||
"""
|
||||
从字典构建 MessageInfo 对象
|
||||
|
||||
Args:
|
||||
message_info_dict: 包含消息信息的字典
|
||||
|
||||
Returns:
|
||||
MessageInfo 对象
|
||||
"""
|
||||
# 构建用户信息
|
||||
user_info_dict = message_info_dict.get("user_info")
|
||||
if not user_info_dict or not isinstance(user_info_dict, dict):
|
||||
raise ValueError("消息字典中 'user_info' 字段无效")
|
||||
user_id = user_info_dict.get("user_id")
|
||||
user_nickname = user_info_dict.get("user_nickname")
|
||||
user_cardname = user_info_dict.get("user_cardname")
|
||||
if not isinstance(user_id, str) or not isinstance(user_nickname, str) or not user_id or not user_nickname:
|
||||
raise ValueError("消息字典中 'user_info' 字段缺少有效的 'user_id' 或 'user_nickname'")
|
||||
user_cardname = str(user_cardname) if user_cardname is not None else None
|
||||
user_info = UserInfo(user_id=user_id, user_nickname=user_nickname, user_cardname=user_cardname)
|
||||
|
||||
# 构建群信息
|
||||
if group_info_dict := message_info_dict.get("group_info"):
|
||||
group_id = group_info_dict.get("group_id")
|
||||
group_name = group_info_dict.get("group_name")
|
||||
if not isinstance(group_id, str) or not isinstance(group_name, str) or not group_id or not group_name:
|
||||
raise ValueError("消息字典中 'group_info' 字段缺少有效的 'group_id' 或 'group_name'")
|
||||
group_info = GroupInfo(group_id=group_id, group_name=group_name)
|
||||
else:
|
||||
group_info = None
|
||||
|
||||
# 获取额外配置
|
||||
additional_config: Dict[str, Any] = message_info_dict.get("additional_config", {})
|
||||
|
||||
return MessageInfo(user_info=user_info, group_info=group_info, additional_config=additional_config)
|
||||
|
||||
@staticmethod
|
||||
def _build_session_message_from_dict(message_dict: Dict[str, Any]) -> SessionMessage:
|
||||
"""
|
||||
从字典构建 SessionMessage 对象(递归处理消息组件)
|
||||
|
||||
Args:
|
||||
message_dict: 包含消息完整信息的字典
|
||||
|
||||
Returns:
|
||||
SessionMessage 对象
|
||||
"""
|
||||
# 提取基本信息
|
||||
message_id = message_dict["message_id"]
|
||||
timestamp_str: str = message_dict.get("timestamp", "")
|
||||
platform = message_dict["platform"]
|
||||
if not isinstance(message_id, str) or not message_id:
|
||||
raise ValueError("消息字典中缺少有效的 'message_id' 字段")
|
||||
if not isinstance(platform, str) or not platform:
|
||||
raise ValueError("消息字典中缺少有效的 'platform' 字段")
|
||||
|
||||
# 解析时间戳
|
||||
try:
|
||||
timestamp_float = float(timestamp_str)
|
||||
timestamp = datetime.fromtimestamp(timestamp_float)
|
||||
except (ValueError, TypeError):
|
||||
timestamp = datetime.now() # 如果解析失败,使用当前时间
|
||||
|
||||
# 创建 SessionMessage 实例
|
||||
session_message = SessionMessage(message_id=message_id, timestamp=timestamp, platform=platform)
|
||||
|
||||
# 构建消息信息
|
||||
session_message.message_info = PluginMessageUtils._build_message_info_from_dict(message_dict["message_info"])
|
||||
|
||||
# 构建原始消息组件序列(复用 MessageSequence.from_dict 方法)
|
||||
raw_message_data = message_dict["raw_message"]
|
||||
if isinstance(raw_message_data, list):
|
||||
session_message.raw_message = PluginMessageUtils._message_sequence_from_dict(raw_message_data)
|
||||
else:
|
||||
raise ValueError("消息字典中 'raw_message' 字段必须是一个列表")
|
||||
|
||||
# 设置其他可选属性
|
||||
session_message.is_mentioned = message_dict.get("is_mentioned", False)
|
||||
if not isinstance(session_message.is_mentioned, bool):
|
||||
session_message.is_mentioned = False
|
||||
session_message.is_at = message_dict.get("is_at", False)
|
||||
if not isinstance(session_message.is_at, bool):
|
||||
session_message.is_at = False
|
||||
session_message.is_emoji = message_dict.get("is_emoji", False)
|
||||
if not isinstance(session_message.is_emoji, bool):
|
||||
session_message.is_emoji = False
|
||||
session_message.is_picture = message_dict.get("is_picture", False)
|
||||
if not isinstance(session_message.is_picture, bool):
|
||||
session_message.is_picture = False
|
||||
session_message.is_command = message_dict.get("is_command", False)
|
||||
if not isinstance(session_message.is_command, bool):
|
||||
session_message.is_command = False
|
||||
session_message.is_notify = message_dict.get("is_notify", False)
|
||||
if not isinstance(session_message.is_notify, bool):
|
||||
session_message.is_notify = False
|
||||
session_message.session_id = message_dict.get("session_id", "")
|
||||
if not isinstance(session_message.session_id, str):
|
||||
session_message.session_id = ""
|
||||
session_message.reply_to = message_dict.get("reply_to")
|
||||
if session_message.reply_to is not None and not isinstance(session_message.reply_to, str):
|
||||
session_message.reply_to = None
|
||||
session_message.processed_plain_text = message_dict.get("processed_plain_text")
|
||||
if session_message.processed_plain_text is not None and not isinstance(
|
||||
session_message.processed_plain_text, str
|
||||
):
|
||||
session_message.processed_plain_text = None
|
||||
session_message.display_message = message_dict.get("display_message")
|
||||
if session_message.display_message is not None and not isinstance(session_message.display_message, str):
|
||||
session_message.display_message = None
|
||||
|
||||
return session_message
|
||||
@@ -1,97 +0,0 @@
|
||||
"""策略引擎
|
||||
|
||||
负责能力授权校验。
|
||||
每个插件在 manifest 中声明能力需求,Host 启动时签发能力令牌。
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
|
||||
|
||||
@dataclass
|
||||
class CapabilityToken:
|
||||
"""能力令牌"""
|
||||
|
||||
plugin_id: str
|
||||
generation: int
|
||||
capabilities: Set[str] = field(default_factory=set)
|
||||
|
||||
|
||||
class PolicyEngine:
|
||||
"""策略引擎
|
||||
|
||||
管理所有插件的能力令牌,提供授权校验。
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._tokens: Dict[str, Dict[int, CapabilityToken]] = {}
|
||||
|
||||
def register_plugin(
|
||||
self,
|
||||
plugin_id: str,
|
||||
generation: int,
|
||||
capabilities: List[str],
|
||||
) -> CapabilityToken:
|
||||
"""为插件签发能力令牌"""
|
||||
token = CapabilityToken(
|
||||
plugin_id=plugin_id,
|
||||
generation=generation,
|
||||
capabilities=set(capabilities),
|
||||
)
|
||||
self._tokens.setdefault(plugin_id, {})[generation] = token
|
||||
return token
|
||||
|
||||
def revoke_plugin(self, plugin_id: str, generation: Optional[int] = None) -> None:
|
||||
"""撤销插件的能力令牌。"""
|
||||
if generation is None:
|
||||
self._tokens.pop(plugin_id, None)
|
||||
return
|
||||
|
||||
generations = self._tokens.get(plugin_id)
|
||||
if generations is None:
|
||||
return
|
||||
|
||||
generations.pop(generation, None)
|
||||
if not generations:
|
||||
self._tokens.pop(plugin_id, None)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""清空所有能力令牌。"""
|
||||
self._tokens.clear()
|
||||
|
||||
def check_capability(self, plugin_id: str, capability: str, generation: Optional[int] = None) -> Tuple[bool, str]:
|
||||
"""检查插件是否有权调用某项能力
|
||||
|
||||
Returns:
|
||||
(allowed, reason)
|
||||
"""
|
||||
generations = self._tokens.get(plugin_id)
|
||||
if not generations:
|
||||
return False, f"插件 {plugin_id} 未注册能力令牌"
|
||||
|
||||
if generation is None:
|
||||
token = generations[max(generations)]
|
||||
else:
|
||||
token = generations.get(generation)
|
||||
if token is None:
|
||||
active_generation = max(generations)
|
||||
return False, f"插件 {plugin_id} generation 不匹配: {generation} != {active_generation}"
|
||||
|
||||
if capability not in token.capabilities:
|
||||
return False, f"插件 {plugin_id} 未获授权能力: {capability}"
|
||||
|
||||
if generation is not None and token.generation != generation:
|
||||
return False, f"插件 {plugin_id} generation 不匹配: {generation} != {token.generation}"
|
||||
|
||||
return True, ""
|
||||
|
||||
def get_token(self, plugin_id: str) -> Optional[CapabilityToken]:
|
||||
"""获取插件的能力令牌"""
|
||||
generations = self._tokens.get(plugin_id)
|
||||
if not generations:
|
||||
return None
|
||||
return generations[max(generations)]
|
||||
|
||||
def list_plugins(self) -> List[str]:
|
||||
"""列出所有已注册的插件"""
|
||||
return list(self._tokens.keys())
|
||||
@@ -7,7 +7,7 @@
|
||||
4. 请求-响应关联与超时管理
|
||||
"""
|
||||
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Coroutine
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
@@ -32,7 +32,7 @@ from src.plugin_runtime.transport.base import Connection, TransportServer
|
||||
logger = get_logger("plugin_runtime.host.rpc_server")
|
||||
|
||||
# RPC 方法处理器类型
|
||||
MethodHandler = Callable[[Envelope], Awaitable[Envelope]]
|
||||
MethodHandler = Callable[[Envelope], Coroutine[Any, Any, Envelope]]
|
||||
|
||||
|
||||
class RPCServer:
|
||||
@@ -55,108 +55,39 @@ class RPCServer:
|
||||
|
||||
self._id_gen = RequestIdGenerator()
|
||||
self._connection: Optional[Connection] = None # 当前活跃的 Runner 连接
|
||||
self._runner_id: Optional[str] = None
|
||||
self._runner_generation: int = 0
|
||||
self._staged_connection: Optional[Connection] = None
|
||||
self._staged_runner_id: Optional[str] = None
|
||||
self._staged_runner_generation: int = 0
|
||||
self._staging_takeover: bool = False
|
||||
|
||||
# 方法处理器注册表
|
||||
self._method_handlers: Dict[str, MethodHandler] = {}
|
||||
|
||||
# 等待响应的 pending 请求: request_id -> (Future, target_generation)
|
||||
self._pending_requests: Dict[int, Tuple[asyncio.Future, int]] = {}
|
||||
# 等待响应的 pending 请求: request_id -> Future
|
||||
self._pending_requests: Dict[int, asyncio.Future[Envelope]] = {}
|
||||
|
||||
# 发送队列(背压控制)
|
||||
self._send_queue: Optional[asyncio.Queue[Tuple[Connection, bytes, asyncio.Future[None]]]] = None
|
||||
self._send_worker_task: Optional[asyncio.Task] = None
|
||||
self._send_worker_task: Optional[asyncio.Task[None]] = None
|
||||
|
||||
# 运行状态
|
||||
self._running: bool = False
|
||||
self._tasks: List[asyncio.Task] = []
|
||||
self._tasks: List[asyncio.Task[None]] = []
|
||||
self._last_handshake_rejection_reason: str = ""
|
||||
self._connection_lock: asyncio.Lock = asyncio.Lock()
|
||||
|
||||
@property
|
||||
def session_token(self) -> str:
|
||||
return self._session_token
|
||||
|
||||
def reset_session_token(self) -> str:
|
||||
"""重新生成会话令牌(热重载时调用,防止旧 Runner 重连)"""
|
||||
self._session_token = secrets.token_hex(32)
|
||||
return self._session_token
|
||||
|
||||
def restore_session_token(self, token: str) -> None:
|
||||
"""恢复指定的会话令牌(热重载回滚时调用)"""
|
||||
self._session_token = token
|
||||
|
||||
@property
|
||||
def runner_generation(self) -> int:
|
||||
return self._runner_generation
|
||||
|
||||
@property
|
||||
def staged_generation(self) -> int:
|
||||
return self._staged_runner_generation
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self._connection is not None and not self._connection.is_closed
|
||||
|
||||
def has_generation(self, generation: int) -> bool:
|
||||
return generation == self._runner_generation or (
|
||||
self._staged_connection is not None
|
||||
and not self._staged_connection.is_closed
|
||||
and generation == self._staged_runner_generation
|
||||
)
|
||||
@property
|
||||
def last_handshake_rejection_reason(self) -> str:
|
||||
"""返回最近一次握手被拒绝的原因。"""
|
||||
return self._last_handshake_rejection_reason
|
||||
|
||||
def begin_staged_takeover(self) -> None:
|
||||
"""允许新 Runner 以 staged 方式接入,待 Supervisor 验证后再切换为活跃连接。"""
|
||||
self._staging_takeover = True
|
||||
|
||||
async def commit_staged_takeover(self) -> None:
|
||||
"""提交 staged Runner,原活跃连接在提交后被关闭。"""
|
||||
if self._staged_connection is None or self._staged_connection.is_closed:
|
||||
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "没有可提交的新 Runner 连接")
|
||||
|
||||
old_connection = self._connection
|
||||
old_generation = self._runner_generation
|
||||
|
||||
self._connection = self._staged_connection
|
||||
self._runner_id = self._staged_runner_id
|
||||
self._runner_generation = self._staged_runner_generation
|
||||
|
||||
self._staged_connection = None
|
||||
self._staged_runner_id = None
|
||||
self._staged_runner_generation = 0
|
||||
self._staging_takeover = False
|
||||
|
||||
if stale_count := self._fail_pending_requests(
|
||||
ErrorCode.E_PLUGIN_CRASHED,
|
||||
"Runner 连接已被新 generation 接管",
|
||||
generation=old_generation,
|
||||
):
|
||||
logger.info(f"已清理 {stale_count} 个旧 Runner 的 pending 请求")
|
||||
|
||||
if old_connection and old_connection is not self._connection and not old_connection.is_closed:
|
||||
await old_connection.close()
|
||||
|
||||
async def rollback_staged_takeover(self) -> None:
|
||||
"""放弃 staged Runner,保留当前活跃连接。"""
|
||||
staged_connection = self._staged_connection
|
||||
staged_generation = self._staged_runner_generation
|
||||
|
||||
self._staged_connection = None
|
||||
self._staged_runner_id = None
|
||||
self._staged_runner_generation = 0
|
||||
self._staging_takeover = False
|
||||
|
||||
self._fail_pending_requests(
|
||||
ErrorCode.E_PLUGIN_CRASHED,
|
||||
"新 Runner 预热失败,已回滚",
|
||||
generation=staged_generation,
|
||||
)
|
||||
|
||||
if staged_connection and not staged_connection.is_closed:
|
||||
await staged_connection.close()
|
||||
def clear_handshake_state(self) -> None:
|
||||
"""清空最近一次握手拒绝状态。"""
|
||||
self._last_handshake_rejection_reason = ""
|
||||
|
||||
def register_method(self, method: str, handler: MethodHandler) -> None:
|
||||
"""注册 RPC 方法处理器"""
|
||||
@@ -165,6 +96,7 @@ class RPCServer:
|
||||
async def start(self) -> None:
|
||||
"""启动 RPC 服务器"""
|
||||
self._running = True
|
||||
self.clear_handshake_state()
|
||||
self._send_queue = asyncio.Queue(maxsize=self._send_queue_size)
|
||||
self._send_worker_task = asyncio.create_task(self._send_loop())
|
||||
await self._transport.start(self._handle_connection)
|
||||
@@ -173,14 +105,9 @@ class RPCServer:
|
||||
async def stop(self) -> None:
|
||||
"""停止 RPC 服务器"""
|
||||
self._running = False
|
||||
|
||||
# 取消所有 pending 请求
|
||||
for future, _generation in self._pending_requests.values():
|
||||
if not future.done():
|
||||
future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "服务器关闭"))
|
||||
self._pending_requests.clear()
|
||||
|
||||
self._fail_queued_sends(ErrorCode.E_TIMEOUT, "服务器关闭")
|
||||
self.clear_handshake_state()
|
||||
self._fail_pending_requests(ErrorCode.E_SHUTTING_DOWN, "服务器正在关闭")
|
||||
self._fail_queued_sends(ErrorCode.E_SHUTTING_DOWN, "服务器正在关闭")
|
||||
|
||||
if self._send_worker_task:
|
||||
self._send_worker_task.cancel()
|
||||
@@ -198,10 +125,6 @@ class RPCServer:
|
||||
await self._connection.close()
|
||||
self._connection = None
|
||||
|
||||
if self._staged_connection:
|
||||
await self._staged_connection.close()
|
||||
self._staged_connection = None
|
||||
|
||||
await self._transport.stop()
|
||||
logger.info("RPC Server 已停止")
|
||||
|
||||
@@ -211,7 +134,6 @@ class RPCServer:
|
||||
plugin_id: str = "",
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
timeout_ms: int = 30000,
|
||||
target_generation: Optional[int] = None,
|
||||
) -> Envelope:
|
||||
"""向 Runner 发送 RPC 请求并等待响应
|
||||
|
||||
@@ -227,18 +149,14 @@ class RPCServer:
|
||||
Raises:
|
||||
RPCError: 调用失败
|
||||
"""
|
||||
generation = target_generation or self._runner_generation
|
||||
conn = self._get_connection_for_generation(generation)
|
||||
if conn is None or conn.is_closed:
|
||||
if not self._connection or self._connection.is_closed:
|
||||
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接")
|
||||
|
||||
request_id = self._id_gen.next()
|
||||
request_id = await self._id_gen.next()
|
||||
envelope = Envelope(
|
||||
request_id=request_id,
|
||||
message_type=MessageType.REQUEST,
|
||||
method=method,
|
||||
plugin_id=plugin_id,
|
||||
generation=generation,
|
||||
timeout_ms=timeout_ms,
|
||||
payload=payload or {},
|
||||
)
|
||||
@@ -246,12 +164,12 @@ class RPCServer:
|
||||
# 注册 pending future
|
||||
loop = asyncio.get_running_loop()
|
||||
future: asyncio.Future[Envelope] = loop.create_future()
|
||||
self._pending_requests[request_id] = (future, generation)
|
||||
self._pending_requests[request_id] = future
|
||||
|
||||
try:
|
||||
# 发送请求
|
||||
data = self._codec.encode_envelope(envelope)
|
||||
await self._enqueue_send(conn, data)
|
||||
await self._enqueue_send(self._connection, data)
|
||||
|
||||
# 等待响应
|
||||
timeout_sec = timeout_ms / 1000.0
|
||||
@@ -265,150 +183,136 @@ class RPCServer:
|
||||
raise
|
||||
raise RPCError(ErrorCode.E_UNKNOWN, str(e)) from e
|
||||
|
||||
async def send_event(self, method: str, plugin_id: str = "", payload: Optional[Dict[str, Any]] = None) -> None:
|
||||
"""向 Runner 发送单向事件(不等待响应)"""
|
||||
conn = self._connection
|
||||
if conn is None or conn.is_closed:
|
||||
return
|
||||
# ============ 内部方法 ============
|
||||
# ========= 发送循环 =========
|
||||
async def _send_loop(self) -> None:
|
||||
"""后台发送循环:串行消费发送队列,统一执行连接写入。"""
|
||||
if self._send_queue is None:
|
||||
raise RuntimeError("没有消息队列")
|
||||
|
||||
request_id = self._id_gen.next()
|
||||
envelope = Envelope(
|
||||
request_id=request_id,
|
||||
message_type=MessageType.EVENT,
|
||||
method=method,
|
||||
plugin_id=plugin_id,
|
||||
generation=self._runner_generation,
|
||||
payload=payload or {},
|
||||
)
|
||||
data = self._codec.encode_envelope(envelope)
|
||||
await self._enqueue_send(conn, data)
|
||||
while True:
|
||||
try:
|
||||
conn, data, send_future = await self._send_queue.get()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
|
||||
# ─── 内部方法 ──────────────────────────────────────────────
|
||||
try:
|
||||
if conn.is_closed:
|
||||
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接")
|
||||
await conn.send_frame(data)
|
||||
if not send_future.done():
|
||||
send_future.set_result(None)
|
||||
except asyncio.CancelledError:
|
||||
if not send_future.done():
|
||||
send_future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "服务器关闭"))
|
||||
raise
|
||||
except Exception as e:
|
||||
send_error = RPCError.from_exception(e, {ConnectionError: ErrorCode.E_PLUGIN_CRASHED})
|
||||
if not send_future.done():
|
||||
send_future.set_exception(send_error)
|
||||
finally:
|
||||
self._send_queue.task_done()
|
||||
|
||||
# ====== 发送循环方法 ======
|
||||
async def _handle_connection(self, conn: Connection) -> None:
|
||||
"""处理新的 Runner 连接"""
|
||||
logger.info("收到 Runner 连接")
|
||||
previous_connection = self._connection
|
||||
previous_generation = self._runner_generation
|
||||
|
||||
# 第一条消息必须是 runner.hello 握手
|
||||
try:
|
||||
role = await self._handle_handshake(conn)
|
||||
if role is None:
|
||||
await conn.close()
|
||||
return
|
||||
async with self._connection_lock:
|
||||
self.clear_handshake_state()
|
||||
success = await self._handle_handshake(conn)
|
||||
if not success:
|
||||
await conn.close()
|
||||
return
|
||||
logger.info("Runner staged 握手成功")
|
||||
self._connection = conn
|
||||
except Exception as e:
|
||||
logger.error(f"握手失败: {e}")
|
||||
await conn.close()
|
||||
return
|
||||
|
||||
if role == "staged":
|
||||
expected_generation = self._staged_runner_generation
|
||||
logger.info(
|
||||
f"Runner staged 握手成功: runner_id={self._staged_runner_id}, generation={self._staged_runner_generation}"
|
||||
)
|
||||
else:
|
||||
self._connection = conn
|
||||
expected_generation = self._runner_generation
|
||||
logger.info(f"Runner 握手成功: runner_id={self._runner_id}, generation={self._runner_generation}")
|
||||
|
||||
if previous_connection and previous_connection is not conn and not previous_connection.is_closed:
|
||||
logger.info("检测到新 Runner 已接管连接,关闭旧连接")
|
||||
if stale_count := self._fail_pending_requests(
|
||||
ErrorCode.E_PLUGIN_CRASHED,
|
||||
"Runner 连接已被新 generation 接管",
|
||||
generation=previous_generation,
|
||||
):
|
||||
logger.info(f"已清理 {stale_count} 个旧 Runner 的 pending 请求")
|
||||
await previous_connection.close()
|
||||
|
||||
# 启动消息接收循环
|
||||
try:
|
||||
await self._recv_loop(conn, expected_generation=expected_generation)
|
||||
await self._recv_loop(conn)
|
||||
except Exception as e:
|
||||
logger.error(f"连接异常断开: {e}")
|
||||
finally:
|
||||
if self._connection is conn:
|
||||
self._connection = None
|
||||
self._runner_id = None
|
||||
self._fail_pending_requests(
|
||||
ErrorCode.E_PLUGIN_CRASHED,
|
||||
"Runner 连接已断开",
|
||||
generation=expected_generation,
|
||||
)
|
||||
elif self._staged_connection is conn:
|
||||
self._staged_connection = None
|
||||
self._staged_runner_id = None
|
||||
self._staged_runner_generation = 0
|
||||
self._fail_pending_requests(
|
||||
ErrorCode.E_PLUGIN_CRASHED,
|
||||
"Staged Runner 连接已断开",
|
||||
generation=expected_generation,
|
||||
)
|
||||
should_fail_pending_requests = False
|
||||
async with self._connection_lock:
|
||||
if self._connection is conn:
|
||||
self._connection = None
|
||||
should_fail_pending_requests = True
|
||||
if should_fail_pending_requests:
|
||||
self._fail_pending_requests(ErrorCode.E_PLUGIN_CRASHED, "Runner 连接已断开")
|
||||
|
||||
async def _handle_handshake(self, conn: Connection) -> Optional[str]:
|
||||
async def _handle_handshake(self, conn: Connection) -> bool:
|
||||
"""处理 runner.hello 握手"""
|
||||
# 接收握手请求
|
||||
data = await asyncio.wait_for(conn.recv_frame(), timeout=10.0)
|
||||
envelope = self._codec.decode_envelope(data)
|
||||
|
||||
if envelope.method != "runner.hello":
|
||||
logger.error(f"期望 runner.hello,收到 {envelope.method}")
|
||||
self._last_handshake_rejection_reason = "首条消息必须为 runner.hello"
|
||||
error_resp = envelope.make_error_response(
|
||||
ErrorCode.E_PROTOCOL_MISMATCH.value,
|
||||
"首条消息必须为 runner.hello",
|
||||
)
|
||||
await conn.send_frame(self._codec.encode_envelope(error_resp))
|
||||
return None
|
||||
return False
|
||||
|
||||
# 解析握手 payload
|
||||
hello = HelloPayload.model_validate(envelope.payload)
|
||||
|
||||
# 校验会话令牌
|
||||
if hello.session_token != self._session_token:
|
||||
logger.error("会话令牌不匹配")
|
||||
resp_payload = HelloResponsePayload(
|
||||
accepted=False,
|
||||
reason="会话令牌无效",
|
||||
)
|
||||
self._last_handshake_rejection_reason = "会话令牌无效"
|
||||
resp_payload = HelloResponsePayload(accepted=False, reason=self._last_handshake_rejection_reason)
|
||||
resp = envelope.make_response(payload=resp_payload.model_dump())
|
||||
await conn.send_frame(self._codec.encode_envelope(resp))
|
||||
return None
|
||||
return False
|
||||
|
||||
# 若已有活跃连接,直接拒绝新的握手,避免后来的连接抢占当前通道。
|
||||
if self.is_connected:
|
||||
logger.warning("拒绝新的 Runner 连接:已有活跃连接")
|
||||
self._last_handshake_rejection_reason = "已有活跃 Runner 连接,拒绝新的握手"
|
||||
resp_payload = HelloResponsePayload(accepted=False, reason=self._last_handshake_rejection_reason)
|
||||
resp = envelope.make_response(payload=resp_payload.model_dump())
|
||||
await conn.send_frame(self._codec.encode_envelope(resp))
|
||||
return False
|
||||
|
||||
# 校验 SDK 版本
|
||||
if not self._check_sdk_version(hello.sdk_version):
|
||||
logger.error(f"SDK 版本不兼容: {hello.sdk_version}")
|
||||
self._last_handshake_rejection_reason = (
|
||||
f"SDK 版本 {hello.sdk_version} 不在支持范围 [{MIN_SDK_VERSION}, {MAX_SDK_VERSION}]"
|
||||
)
|
||||
resp_payload = HelloResponsePayload(
|
||||
accepted=False,
|
||||
reason=f"SDK 版本 {hello.sdk_version} 不在支持范围 [{MIN_SDK_VERSION}, {MAX_SDK_VERSION}]",
|
||||
reason=self._last_handshake_rejection_reason,
|
||||
)
|
||||
resp = envelope.make_response(payload=resp_payload.model_dump())
|
||||
await conn.send_frame(self._codec.encode_envelope(resp))
|
||||
return None
|
||||
return False
|
||||
|
||||
# 握手成功
|
||||
role = "active"
|
||||
assigned_generation = self._runner_generation + 1
|
||||
if self._staging_takeover and self.is_connected:
|
||||
role = "staged"
|
||||
self._staged_connection = conn
|
||||
self._staged_runner_id = hello.runner_id
|
||||
self._staged_runner_generation = assigned_generation
|
||||
else:
|
||||
self._runner_id = hello.runner_id
|
||||
self._runner_generation = assigned_generation
|
||||
|
||||
resp_payload = HelloResponsePayload(
|
||||
accepted=True,
|
||||
host_version=PROTOCOL_VERSION,
|
||||
assigned_generation=assigned_generation,
|
||||
)
|
||||
# 发送响应
|
||||
self.clear_handshake_state()
|
||||
resp_payload = HelloResponsePayload(accepted=True, host_version=PROTOCOL_VERSION)
|
||||
resp = envelope.make_response(payload=resp_payload.model_dump())
|
||||
await conn.send_frame(self._codec.encode_envelope(resp))
|
||||
return True
|
||||
|
||||
return role
|
||||
def _check_sdk_version(self, sdk_version: str) -> bool:
|
||||
"""检查 SDK 版本是否在支持范围内"""
|
||||
try:
|
||||
sdk_parts = _parse_version_tuple(sdk_version)
|
||||
min_parts = _parse_version_tuple(MIN_SDK_VERSION)
|
||||
max_parts = _parse_version_tuple(MAX_SDK_VERSION)
|
||||
return min_parts <= sdk_parts <= max_parts
|
||||
except (ValueError, AttributeError):
|
||||
return False
|
||||
|
||||
async def _recv_loop(self, conn: Connection, expected_generation: int) -> None:
|
||||
# ========= 接收循环 =========
|
||||
async def _recv_loop(self, conn: Connection) -> None:
|
||||
"""消息接收主循环"""
|
||||
while self._running and not conn.is_closed:
|
||||
try:
|
||||
@@ -430,109 +334,40 @@ class RPCServer:
|
||||
if envelope.is_response():
|
||||
self._handle_response(envelope)
|
||||
elif envelope.is_request():
|
||||
if envelope.generation != expected_generation:
|
||||
error_resp = envelope.make_error_response(
|
||||
ErrorCode.E_GENERATION_MISMATCH.value,
|
||||
f"过期 generation: {envelope.generation} != {expected_generation}",
|
||||
)
|
||||
await conn.send_frame(self._codec.encode_envelope(error_resp))
|
||||
continue
|
||||
# 异步处理请求(Runner 发来的能力调用)
|
||||
task = asyncio.create_task(self._handle_request(envelope, conn))
|
||||
self._tasks.append(task)
|
||||
task.add_done_callback(lambda t: self._tasks.remove(t) if t in self._tasks else None)
|
||||
elif envelope.is_event():
|
||||
if envelope.generation != expected_generation:
|
||||
logger.warning(
|
||||
f"忽略过期 generation 事件 {envelope.method}: {envelope.generation} != {expected_generation}"
|
||||
)
|
||||
continue
|
||||
task = asyncio.create_task(self._handle_event(envelope))
|
||||
elif envelope.is_broadcast():
|
||||
task = asyncio.create_task(self._handle_broadcast(envelope))
|
||||
self._tasks.append(task)
|
||||
task.add_done_callback(lambda t: self._tasks.remove(t) if t in self._tasks else None)
|
||||
else:
|
||||
logger.warning(f"未知的消息类型: {envelope.message_type}")
|
||||
continue
|
||||
|
||||
# ====== 接收循环内部方法 ======
|
||||
def _handle_response(self, envelope: Envelope) -> None:
|
||||
"""处理来自 Runner 的响应"""
|
||||
pending = self._pending_requests.get(envelope.request_id)
|
||||
if pending is None:
|
||||
pending_future = self._pending_requests.pop(envelope.request_id, None)
|
||||
if pending_future is None:
|
||||
return
|
||||
|
||||
future, expected_generation = pending
|
||||
if envelope.generation != expected_generation:
|
||||
logger.warning(
|
||||
f"忽略过期 generation 响应 {envelope.method}: {envelope.generation} != {expected_generation}"
|
||||
)
|
||||
return
|
||||
|
||||
self._pending_requests.pop(envelope.request_id, None)
|
||||
if not future.done():
|
||||
if not pending_future.done():
|
||||
if envelope.error:
|
||||
future.set_exception(RPCError.from_dict(envelope.error))
|
||||
pending_future.set_exception(RPCError.from_dict(envelope.error))
|
||||
else:
|
||||
future.set_result(envelope)
|
||||
|
||||
async def _enqueue_send(self, conn: Connection, data: bytes) -> None:
|
||||
"""通过发送队列串行发送消息,提供真实背压。"""
|
||||
if conn.is_closed:
|
||||
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接")
|
||||
|
||||
if self._send_queue is None:
|
||||
await conn.send_frame(data)
|
||||
return
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
send_future: asyncio.Future[None] = loop.create_future()
|
||||
|
||||
try:
|
||||
self._send_queue.put_nowait((conn, data, send_future))
|
||||
except asyncio.QueueFull:
|
||||
raise RPCError(ErrorCode.E_BACKPRESSURE, "发送队列已满") from None
|
||||
|
||||
await send_future
|
||||
|
||||
async def _send_loop(self) -> None:
|
||||
"""后台发送循环:串行消费发送队列,统一执行连接写入。"""
|
||||
if self._send_queue is None:
|
||||
return
|
||||
|
||||
while True:
|
||||
try:
|
||||
conn, data, send_future = await self._send_queue.get()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
|
||||
try:
|
||||
if conn.is_closed:
|
||||
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接")
|
||||
await conn.send_frame(data)
|
||||
if not send_future.done():
|
||||
send_future.set_result(None)
|
||||
except asyncio.CancelledError:
|
||||
if not send_future.done():
|
||||
send_future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "服务器关闭"))
|
||||
raise
|
||||
except Exception as e:
|
||||
send_error = e if isinstance(e, RPCError) else self._normalize_send_exception(e)
|
||||
if not send_future.done():
|
||||
send_future.set_exception(send_error)
|
||||
finally:
|
||||
self._send_queue.task_done()
|
||||
|
||||
@staticmethod
|
||||
def _normalize_send_exception(error: Exception) -> RPCError:
|
||||
if isinstance(error, ConnectionError):
|
||||
return RPCError(ErrorCode.E_PLUGIN_CRASHED, str(error))
|
||||
return RPCError(ErrorCode.E_UNKNOWN, str(error))
|
||||
pending_future.set_result(envelope)
|
||||
|
||||
async def _handle_request(self, envelope: Envelope, conn: Connection) -> None:
|
||||
"""处理来自 Runner 的请求(通常是能力调用 cap.*)"""
|
||||
handler = self._method_handlers.get(envelope.method)
|
||||
if handler is None:
|
||||
error_resp = envelope.make_error_response(
|
||||
target_method = envelope.method
|
||||
handler = self._method_handlers.get(target_method)
|
||||
if not handler:
|
||||
error_response = envelope.make_error_response(
|
||||
ErrorCode.E_METHOD_NOT_ALLOWED.value,
|
||||
f"未注册的方法: {envelope.method}",
|
||||
)
|
||||
await conn.send_frame(self._codec.encode_envelope(error_resp))
|
||||
await conn.send_frame(self._codec.encode_envelope(error_response))
|
||||
return
|
||||
|
||||
try:
|
||||
@@ -546,59 +381,25 @@ class RPCServer:
|
||||
error_resp = envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(e))
|
||||
await conn.send_frame(self._codec.encode_envelope(error_resp))
|
||||
|
||||
async def _handle_event(self, envelope: Envelope) -> None:
|
||||
"""处理来自 Runner 的事件"""
|
||||
async def _handle_broadcast(self, envelope: Envelope) -> None:
|
||||
if handler := self._method_handlers.get(envelope.method):
|
||||
try:
|
||||
result = await handler(envelope)
|
||||
# 检查 handler 返回的信封是否包含错误信息
|
||||
if result is not None and isinstance(result, Envelope) and result.error:
|
||||
if result.error:
|
||||
logger.warning(f"事件 {envelope.method} handler 返回错误: {result.error.get('message', '')}")
|
||||
except Exception as e:
|
||||
logger.error(f"处理事件 {envelope.method} 异常: {e}", exc_info=True)
|
||||
|
||||
@staticmethod
|
||||
def _check_sdk_version(sdk_version: str) -> bool:
|
||||
"""检查 SDK 版本是否在支持范围内"""
|
||||
try:
|
||||
sdk_parts = RPCServer._parse_version_tuple(sdk_version)
|
||||
min_parts = RPCServer._parse_version_tuple(MIN_SDK_VERSION)
|
||||
max_parts = RPCServer._parse_version_tuple(MAX_SDK_VERSION)
|
||||
return min_parts <= sdk_parts <= max_parts
|
||||
except (ValueError, AttributeError):
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _parse_version_tuple(version: str) -> Tuple[int, int, int]:
|
||||
base_version = re.split(r"[-.](?:snapshot|dev|alpha|beta|rc)", version or "", flags=re.IGNORECASE)[0]
|
||||
base_version = base_version.split("+", 1)[0]
|
||||
parts = [part for part in base_version.split(".") if part != ""]
|
||||
while len(parts) < 3:
|
||||
parts.append("0")
|
||||
return (int(parts[0]), int(parts[1]), int(parts[2]))
|
||||
|
||||
def _get_connection_for_generation(self, generation: int) -> Optional[Connection]:
|
||||
if generation == self._runner_generation:
|
||||
return self._connection
|
||||
if generation == self._staged_runner_generation:
|
||||
return self._staged_connection
|
||||
return None
|
||||
|
||||
def _fail_pending_requests(
|
||||
self,
|
||||
error_code: ErrorCode,
|
||||
message: str,
|
||||
generation: Optional[int] = None,
|
||||
) -> int:
|
||||
stale_count = 0
|
||||
for request_id, (future, request_generation) in list(self._pending_requests.items()):
|
||||
if generation is not None and request_generation != generation:
|
||||
continue
|
||||
def _fail_pending_requests(self, error_code: ErrorCode, message: str) -> int:
|
||||
"""失败所有等待中的请求(如连接断开时)"""
|
||||
aborted_request_count = 0
|
||||
for future in self._pending_requests.values():
|
||||
if not future.done():
|
||||
future.set_exception(RPCError(error_code, message))
|
||||
stale_count += 1
|
||||
self._pending_requests.pop(request_id, None)
|
||||
return stale_count
|
||||
aborted_request_count += 1
|
||||
self._pending_requests.clear()
|
||||
return aborted_request_count
|
||||
|
||||
def _fail_queued_sends(self, error_code: ErrorCode, message: str) -> int:
|
||||
if self._send_queue is None:
|
||||
@@ -617,3 +418,31 @@ class RPCServer:
|
||||
self._send_queue.task_done()
|
||||
|
||||
return failed_count
|
||||
|
||||
async def _enqueue_send(self, conn: Connection, data: bytes) -> None:
|
||||
"""通过发送队列串行发送消息,提供真实背压。"""
|
||||
if conn.is_closed:
|
||||
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接")
|
||||
|
||||
if self._send_queue is None:
|
||||
await conn.send_frame(data)
|
||||
return
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
send_future: asyncio.Future[None] = loop.create_future()
|
||||
|
||||
try:
|
||||
self._send_queue.put_nowait((conn, data, send_future))
|
||||
except asyncio.QueueFull:
|
||||
raise RPCError(ErrorCode.E_BACK_PRESSURE, "发送队列已满") from None
|
||||
|
||||
await send_future
|
||||
|
||||
|
||||
def _parse_version_tuple(version: str) -> Tuple[int, int, int]:
|
||||
base_version = re.split(r"[-.](?:snapshot|dev|alpha|beta|rc)", version or "", flags=re.IGNORECASE)[0]
|
||||
base_version = base_version.split("+", 1)[0]
|
||||
parts = [part for part in base_version.split(".") if part != ""]
|
||||
while len(parts) < 3:
|
||||
parts.append("0")
|
||||
return (int(parts[0]), int(parts[1]), int(parts[2]))
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,422 +0,0 @@
|
||||
"""Host-side WorkflowExecutor
|
||||
|
||||
6 阶段线性流转(INGRESS → PRE_PROCESS → PLAN → TOOL_EXECUTE → POST_PROCESS → EGRESS)
|
||||
|
||||
每个阶段执行顺序:
|
||||
1. Host-side pre-filter: 根据 hook filter 条件过滤不相关的 hook
|
||||
2. 按 priority 降序排列
|
||||
3. 串行执行 blocking hook(可修改 message,返回 HookResult)
|
||||
4. 并发执行 non-blocking hook(只读)
|
||||
5. 检查是否有 SKIP_STAGE 或 ABORT
|
||||
6. PLAN 阶段内置 Command 匹配路由
|
||||
|
||||
支持:
|
||||
- HookResult: CONTINUE / SKIP_STAGE / ABORT
|
||||
- ErrorPolicy: ABORT / SKIP / LOG (per-hook)
|
||||
- stage_outputs: 阶段间带命名空间的数据传递
|
||||
- modification_log: 消息修改审计
|
||||
"""
|
||||
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
import uuid
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry, RegisteredComponent
|
||||
|
||||
logger = get_logger("plugin_runtime.host.workflow_executor")
|
||||
|
||||
# 阶段顺序
|
||||
STAGE_SEQUENCE: List[str] = [
|
||||
"ingress",
|
||||
"pre_process",
|
||||
"plan",
|
||||
"tool_execute",
|
||||
"post_process",
|
||||
"egress",
|
||||
]
|
||||
|
||||
# HookResult 常量(与 SDK HookResult enum 值对应)
|
||||
HOOK_CONTINUE = "continue"
|
||||
HOOK_SKIP_STAGE = "skip_stage"
|
||||
HOOK_ABORT = "abort"
|
||||
|
||||
|
||||
# blocking hook 全局最大超时(秒):即使 hook 声明 timeout_ms=0 也不会无限等待
|
||||
# 从配置文件读取,允许用户调整
|
||||
def _get_blocking_timeout() -> float:
|
||||
return global_config.plugin_runtime.workflow_blocking_timeout_sec
|
||||
|
||||
|
||||
class ModificationRecord:
|
||||
"""消息修改记录"""
|
||||
|
||||
__slots__ = ("stage", "hook_name", "timestamp", "fields_changed")
|
||||
|
||||
def __init__(self, stage: str, hook_name: str, fields_changed: List[str]) -> None:
|
||||
self.stage = stage
|
||||
self.hook_name = hook_name
|
||||
self.timestamp = time.perf_counter()
|
||||
self.fields_changed = fields_changed
|
||||
|
||||
|
||||
class WorkflowContext:
|
||||
"""Workflow 执行上下文"""
|
||||
|
||||
def __init__(self, trace_id: Optional[str] = None, stream_id: Optional[str] = None) -> None:
|
||||
self.trace_id = trace_id or uuid.uuid4().hex
|
||||
self.stream_id = stream_id
|
||||
self.timings: Dict[str, float] = {}
|
||||
self.errors: List[str] = []
|
||||
# 阶段间数据传递(按 stage 命名空间隔离)
|
||||
self.stage_outputs: Dict[str, Dict[str, Any]] = {}
|
||||
# 消息修改审计日志
|
||||
self.modification_log: List[ModificationRecord] = []
|
||||
# PLAN 阶段命令匹配结果
|
||||
self.matched_command: Optional[str] = None
|
||||
|
||||
def set_stage_output(self, stage: str, key: str, value: Any) -> None:
|
||||
self.stage_outputs.setdefault(stage, {})[key] = value
|
||||
|
||||
def get_stage_output(self, stage: str, key: str, default: Any = None) -> Any:
|
||||
return self.stage_outputs.get(stage, {}).get(key, default)
|
||||
|
||||
|
||||
class WorkflowResult:
|
||||
"""Workflow 执行结果"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
status: str = "completed", # completed / aborted / failed
|
||||
return_message: str = "",
|
||||
stopped_at: str = "",
|
||||
diagnostics: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
self.status = status
|
||||
self.return_message = return_message
|
||||
self.stopped_at = stopped_at
|
||||
self.diagnostics = diagnostics or {}
|
||||
|
||||
|
||||
# invoke_fn 签名
|
||||
InvokeFn = Callable[[str, str, Dict[str, Any]], Awaitable[Dict[str, Any]]]
|
||||
|
||||
|
||||
class WorkflowExecutor:
|
||||
"""Host-side Workflow 执行器
|
||||
|
||||
实现 stage-based pipeline + per-stage hook chain with priority + early return。
|
||||
"""
|
||||
|
||||
def __init__(self, registry: ComponentRegistry) -> None:
|
||||
self._registry = registry
|
||||
self._background_tasks: Set[asyncio.Task] = set()
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
invoke_fn: InvokeFn,
|
||||
message: Optional[Dict[str, Any]] = None,
|
||||
stream_id: Optional[str] = None,
|
||||
context: Optional[WorkflowContext] = None,
|
||||
command_invoke_fn: Optional[InvokeFn] = None,
|
||||
) -> Tuple[WorkflowResult, Optional[Dict[str, Any]], WorkflowContext]:
|
||||
"""执行 workflow pipeline。
|
||||
|
||||
Args:
|
||||
invoke_fn: 用于 workflow_step 的回调
|
||||
command_invoke_fn: 用于 command 的回调(走 plugin.invoke_command),
|
||||
未传则复用 invoke_fn
|
||||
|
||||
Returns:
|
||||
(result, final_message, context)
|
||||
"""
|
||||
ctx = context or WorkflowContext(stream_id=stream_id)
|
||||
current_message = dict(message) if message else None
|
||||
|
||||
for stage in STAGE_SEQUENCE:
|
||||
stage_start = time.perf_counter()
|
||||
|
||||
try:
|
||||
# PLAN 阶段: 先做 Command 路由
|
||||
if stage == "plan" and current_message:
|
||||
cmd_result = await self._route_command(command_invoke_fn or invoke_fn, current_message, ctx)
|
||||
if cmd_result is not None:
|
||||
# 命令匹配成功,跳过 PLAN 阶段的 hook,直接存结果进 stage_outputs
|
||||
ctx.set_stage_output("plan", "command_result", cmd_result)
|
||||
ctx.timings[stage] = time.perf_counter() - stage_start
|
||||
continue
|
||||
|
||||
# 获取该阶段所有 hook(已按 priority 降序排列)
|
||||
all_steps = self._registry.get_workflow_steps(stage)
|
||||
if not all_steps:
|
||||
ctx.timings[stage] = time.perf_counter() - stage_start
|
||||
continue
|
||||
|
||||
# 1. Pre-filter
|
||||
filtered_steps = self._pre_filter(all_steps, current_message)
|
||||
|
||||
# 2. 分离 blocking 和 non-blocking
|
||||
blocking_steps = [s for s in filtered_steps if s.metadata.get("blocking", True)]
|
||||
nonblocking_steps = [s for s in filtered_steps if not s.metadata.get("blocking", True)]
|
||||
|
||||
# 3. 串行执行 blocking hook
|
||||
skip_stage = False
|
||||
for step in blocking_steps:
|
||||
hook_result, modified, step_error = await self._invoke_step(
|
||||
invoke_fn, step, stage, ctx, current_message
|
||||
)
|
||||
|
||||
if step_error:
|
||||
error_policy = step.metadata.get("error_policy", "abort")
|
||||
ctx.errors.append(f"{step.full_name}: {step_error}")
|
||||
|
||||
if error_policy == "abort":
|
||||
ctx.timings[stage] = time.perf_counter() - stage_start
|
||||
return (
|
||||
WorkflowResult(
|
||||
status="failed",
|
||||
return_message=step_error,
|
||||
stopped_at=stage,
|
||||
diagnostics={"step": step.full_name, "trace_id": ctx.trace_id},
|
||||
),
|
||||
current_message,
|
||||
ctx,
|
||||
)
|
||||
elif error_policy == "skip":
|
||||
logger.warning(f"[{ctx.trace_id}] hook {step.full_name} 异常(skip): {step_error}")
|
||||
continue
|
||||
else: # log
|
||||
logger.warning(f"[{ctx.trace_id}] hook {step.full_name} 异常(log): {step_error}")
|
||||
continue
|
||||
|
||||
# 更新消息(仅 blocking hook 有权修改)
|
||||
if modified:
|
||||
changed_fields = (
|
||||
_diff_keys(current_message, modified) if current_message else list(modified.keys())
|
||||
)
|
||||
ctx.modification_log.append(ModificationRecord(stage, step.full_name, changed_fields))
|
||||
current_message = modified
|
||||
|
||||
if hook_result == HOOK_ABORT:
|
||||
ctx.timings[stage] = time.perf_counter() - stage_start
|
||||
return (
|
||||
WorkflowResult(
|
||||
status="aborted",
|
||||
return_message=f"aborted by {step.full_name}",
|
||||
stopped_at=stage,
|
||||
diagnostics={"step": step.full_name, "trace_id": ctx.trace_id},
|
||||
),
|
||||
current_message,
|
||||
ctx,
|
||||
)
|
||||
|
||||
if hook_result == HOOK_SKIP_STAGE:
|
||||
skip_stage = True
|
||||
break
|
||||
|
||||
# 4. 并发执行 non-blocking hook(只读,忽略返回值中的 modified_message)
|
||||
if nonblocking_steps and not skip_stage:
|
||||
for step in nonblocking_steps:
|
||||
self._track_background_task(
|
||||
asyncio.create_task(
|
||||
self._invoke_step_fire_and_forget(invoke_fn, step, stage, ctx, current_message)
|
||||
)
|
||||
)
|
||||
|
||||
ctx.timings[stage] = time.perf_counter() - stage_start
|
||||
|
||||
except Exception as e:
|
||||
ctx.timings[stage] = time.perf_counter() - stage_start
|
||||
ctx.errors.append(f"{stage}: {e}")
|
||||
logger.error(f"[{ctx.trace_id}] 阶段 {stage} 未捕获异常: {e}", exc_info=True)
|
||||
return (
|
||||
WorkflowResult(
|
||||
status="failed",
|
||||
return_message=str(e),
|
||||
stopped_at=stage,
|
||||
diagnostics={"trace_id": ctx.trace_id},
|
||||
),
|
||||
current_message,
|
||||
ctx,
|
||||
)
|
||||
|
||||
return (
|
||||
WorkflowResult(
|
||||
status="completed",
|
||||
return_message="workflow completed",
|
||||
diagnostics={"trace_id": ctx.trace_id},
|
||||
),
|
||||
current_message,
|
||||
ctx,
|
||||
)
|
||||
|
||||
def _track_background_task(self, task: asyncio.Task) -> None:
|
||||
"""保持 non-blocking workflow task 的强引用,直到任务结束。"""
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
|
||||
# ─── 内部方法 ──────────────────────────────────────────────
|
||||
|
||||
def _pre_filter(
|
||||
self,
|
||||
steps: List[RegisteredComponent],
|
||||
message: Optional[Dict[str, Any]],
|
||||
) -> List[RegisteredComponent]:
|
||||
"""根据 hook 声明的 filter 条件预过滤,避免无意义的 IPC 调用。"""
|
||||
if not message:
|
||||
return steps
|
||||
|
||||
result = []
|
||||
for step in steps:
|
||||
filter_cond = step.metadata.get("filter", {})
|
||||
if not filter_cond:
|
||||
result.append(step)
|
||||
continue
|
||||
if self._match_filter(filter_cond, message):
|
||||
result.append(step)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _match_filter(filter_cond: Dict[str, Any], message: Dict[str, Any]) -> bool:
|
||||
"""简单 key-value 匹配过滤。
|
||||
|
||||
filter 中的每个 key 必须在 message 中存在且值相等,
|
||||
全部匹配才通过。
|
||||
"""
|
||||
for key, expected in filter_cond.items():
|
||||
actual = message.get(key)
|
||||
if (isinstance(expected, list) and actual not in expected) or (
|
||||
not isinstance(expected, list) and actual != expected
|
||||
):
|
||||
return False
|
||||
return True
|
||||
|
||||
async def _invoke_step(
|
||||
self,
|
||||
invoke_fn: InvokeFn,
|
||||
step: RegisteredComponent,
|
||||
stage: str,
|
||||
ctx: WorkflowContext,
|
||||
message: Optional[Dict[str, Any]],
|
||||
) -> Tuple[str, Optional[Dict[str, Any]], Optional[str]]:
|
||||
"""调用单个 blocking hook。
|
||||
|
||||
Returns:
|
||||
(hook_result, modified_message, error_string_or_None)
|
||||
"""
|
||||
timeout_ms = step.metadata.get("timeout_ms", 0)
|
||||
# 使用 hook 声明的超时,但不超过全局安全阀
|
||||
timeout_sec = timeout_ms / 1000 if timeout_ms > 0 else _get_blocking_timeout()
|
||||
step_key = f"{stage}:{step.full_name}"
|
||||
step_start = time.perf_counter()
|
||||
|
||||
try:
|
||||
coro = invoke_fn(
|
||||
step.plugin_id,
|
||||
step.name,
|
||||
{
|
||||
"stage": stage,
|
||||
"trace_id": ctx.trace_id,
|
||||
"message": message,
|
||||
"stage_outputs": ctx.stage_outputs,
|
||||
},
|
||||
)
|
||||
resp = await asyncio.wait_for(coro, timeout=timeout_sec)
|
||||
ctx.timings[step_key] = time.perf_counter() - step_start
|
||||
|
||||
hook_result = resp.get("hook_result", HOOK_CONTINUE)
|
||||
modified_message = resp.get("modified_message")
|
||||
# 存 stage output(如果 hook 提供了)
|
||||
stage_out = resp.get("stage_output")
|
||||
if isinstance(stage_out, dict):
|
||||
for k, v in stage_out.items():
|
||||
ctx.set_stage_output(stage, k, v)
|
||||
|
||||
return hook_result, modified_message, None
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
ctx.timings[step_key] = time.perf_counter() - step_start
|
||||
return HOOK_CONTINUE, None, f"timeout after {timeout_ms}ms"
|
||||
|
||||
except Exception as e:
|
||||
ctx.timings[step_key] = time.perf_counter() - step_start
|
||||
return HOOK_CONTINUE, None, str(e)
|
||||
|
||||
async def _invoke_step_fire_and_forget(
|
||||
self,
|
||||
invoke_fn: InvokeFn,
|
||||
step: RegisteredComponent,
|
||||
stage: str,
|
||||
ctx: WorkflowContext,
|
||||
message: Optional[Dict[str, Any]],
|
||||
) -> None:
|
||||
"""Non-blocking hook 调用,只读,忽略结果。"""
|
||||
timeout_ms = step.metadata.get("timeout_ms", 0)
|
||||
# 使用 hook 声明的超时,但无声明时回退到全局安全阀,防止 task 泄漏
|
||||
timeout_sec = timeout_ms / 1000 if timeout_ms > 0 else _get_blocking_timeout()
|
||||
|
||||
try:
|
||||
coro = invoke_fn(
|
||||
step.plugin_id,
|
||||
step.name,
|
||||
{
|
||||
"stage": stage,
|
||||
"trace_id": ctx.trace_id,
|
||||
"message": message,
|
||||
"stage_outputs": ctx.stage_outputs,
|
||||
},
|
||||
)
|
||||
await asyncio.wait_for(coro, timeout=timeout_sec)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"[{ctx.trace_id}] non-blocking hook {step.full_name} 超时 ({timeout_sec}s)")
|
||||
except Exception as e:
|
||||
logger.debug(f"[{ctx.trace_id}] non-blocking hook {step.full_name}: {e}")
|
||||
|
||||
async def _route_command(
|
||||
self,
|
||||
invoke_fn: InvokeFn,
|
||||
message: Dict[str, Any],
|
||||
ctx: WorkflowContext,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""PLAN 阶段内置 Command 路由。
|
||||
|
||||
在 registry 中查找匹配的 command 组件,
|
||||
匹配到则直接路由到对应 command handler,返回执行结果。
|
||||
不匹配则返回 None,让 PLAN 阶段的 hook 继续执行。
|
||||
"""
|
||||
plain_text = message.get("plain_text", "")
|
||||
if not plain_text:
|
||||
return None
|
||||
|
||||
match_result = self._registry.find_command_by_text(plain_text)
|
||||
if match_result is None:
|
||||
return None
|
||||
|
||||
matched, matched_groups = match_result
|
||||
|
||||
ctx.matched_command = matched.full_name
|
||||
logger.info(f"[{ctx.trace_id}] 命令匹配: {matched.full_name}")
|
||||
|
||||
try:
|
||||
return await invoke_fn(
|
||||
matched.plugin_id,
|
||||
matched.name,
|
||||
{
|
||||
"text": plain_text,
|
||||
"message": message,
|
||||
"trace_id": ctx.trace_id,
|
||||
"matched_groups": matched_groups,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[{ctx.trace_id}] 命令 {matched.full_name} 执行失败: {e}", exc_info=True)
|
||||
ctx.errors.append(f"command:{matched.full_name}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _diff_keys(old: Dict[str, Any], new: Dict[str, Any]) -> List[str]:
|
||||
"""返回 new 中与 old 不同的 key 列表。"""
|
||||
return [k for k, v in new.items() if k not in old or old[k] != v]
|
||||
@@ -8,23 +8,27 @@
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Coroutine, Dict, Iterable, List, Optional, Sequence, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Coroutine, Dict, Iterable, List, Optional, Sequence, Set, Tuple
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
import tomlkit
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.config.config import config_manager
|
||||
from src.config.file_watcher import FileChange, FileWatcher
|
||||
from src.platform_io import DeliveryBatch, InboundMessageEnvelope, get_platform_io_manager
|
||||
from src.plugin_runtime.capabilities import (
|
||||
RuntimeComponentCapabilityMixin,
|
||||
RuntimeCoreCapabilityMixin,
|
||||
RuntimeDataCapabilityMixin,
|
||||
)
|
||||
from src.plugin_runtime.capabilities.registry import register_capability_impls
|
||||
from src.plugin_runtime.host.message_utils import MessageDict, PluginMessageUtils
|
||||
from src.plugin_runtime.runner.manifest_validator import ManifestValidator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.plugin_runtime.host.supervisor import PluginSupervisor
|
||||
|
||||
logger = get_logger("plugin_runtime.integration")
|
||||
@@ -55,6 +59,7 @@ class PluginRuntimeManager(
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""初始化插件运行时管理器。"""
|
||||
from src.plugin_runtime.host.supervisor import PluginSupervisor
|
||||
|
||||
self._builtin_supervisor: Optional[PluginSupervisor] = None
|
||||
@@ -63,6 +68,26 @@ class PluginRuntimeManager(
|
||||
self._plugin_file_watcher: Optional[FileWatcher] = None
|
||||
self._plugin_source_watcher_subscription_id: Optional[str] = None
|
||||
self._plugin_config_watcher_subscriptions: Dict[str, Tuple[Path, str]] = {}
|
||||
self._plugin_path_cache: Dict[str, Path] = {}
|
||||
self._manifest_validator: ManifestValidator = ManifestValidator()
|
||||
self._config_reload_callback: Callable[[Sequence[str]], Awaitable[None]] = self._handle_main_config_reload
|
||||
self._config_reload_callback_registered: bool = False
|
||||
|
||||
async def _dispatch_platform_inbound(self, envelope: InboundMessageEnvelope) -> None:
|
||||
"""接收 Platform IO 审核后的入站消息并送入主消息链。
|
||||
|
||||
Args:
|
||||
envelope: Platform IO 产出的入站封装。
|
||||
"""
|
||||
session_message = envelope.session_message
|
||||
if session_message is None and envelope.payload is not None:
|
||||
session_message = PluginMessageUtils._build_session_message_from_dict(dict(envelope.payload))
|
||||
if session_message is None:
|
||||
raise ValueError("Platform IO 入站封装缺少可用的 SessionMessage 或 payload")
|
||||
|
||||
from src.chat.message_receive.bot import chat_bot
|
||||
|
||||
await chat_bot.receive_message(session_message)
|
||||
|
||||
# ─── 插件目录 ─────────────────────────────────────────────
|
||||
|
||||
@@ -78,6 +103,42 @@ class PluginRuntimeManager(
|
||||
candidate = Path("plugins").resolve()
|
||||
return [candidate] if candidate.is_dir() else []
|
||||
|
||||
@classmethod
|
||||
def _discover_plugin_dependency_map(cls, plugin_dirs: Iterable[Path]) -> Dict[str, List[str]]:
|
||||
"""扫描指定插件目录集合,返回 ``plugin_id -> dependencies`` 映射。"""
|
||||
validator = ManifestValidator()
|
||||
return validator.build_plugin_dependency_map(plugin_dirs)
|
||||
|
||||
@classmethod
|
||||
def _build_group_start_order(
|
||||
cls,
|
||||
builtin_dirs: Sequence[Path],
|
||||
third_party_dirs: Sequence[Path],
|
||||
) -> List[str]:
|
||||
"""根据跨 Supervisor 依赖关系决定 Runner 启动顺序。"""
|
||||
|
||||
builtin_dependencies = cls._discover_plugin_dependency_map(builtin_dirs)
|
||||
third_party_dependencies = cls._discover_plugin_dependency_map(third_party_dirs)
|
||||
builtin_plugin_ids = set(builtin_dependencies)
|
||||
third_party_plugin_ids = set(third_party_dependencies)
|
||||
|
||||
builtin_needs_third_party = any(
|
||||
dependency in third_party_plugin_ids
|
||||
for dependencies in builtin_dependencies.values()
|
||||
for dependency in dependencies
|
||||
)
|
||||
third_party_needs_builtin = any(
|
||||
dependency in builtin_plugin_ids
|
||||
for dependencies in third_party_dependencies.values()
|
||||
for dependency in dependencies
|
||||
)
|
||||
|
||||
if builtin_needs_third_party and third_party_needs_builtin:
|
||||
raise RuntimeError("检测到跨 Supervisor 循环依赖,当前无法安全启动独立 Runner")
|
||||
if builtin_needs_third_party:
|
||||
return ["third_party", "builtin"]
|
||||
return ["builtin", "third_party"]
|
||||
|
||||
# ─── 生命周期 ─────────────────────────────────────────────
|
||||
|
||||
async def start(self) -> None:
|
||||
@@ -86,7 +147,7 @@ class PluginRuntimeManager(
|
||||
logger.warning("PluginRuntimeManager 已在运行中,跳过重复启动")
|
||||
return
|
||||
|
||||
_cfg = global_config.plugin_runtime
|
||||
_cfg = config_manager.get_global_config().plugin_runtime
|
||||
if not _cfg.enabled:
|
||||
logger.info("插件运行时已在配置中禁用,跳过启动")
|
||||
return
|
||||
@@ -108,6 +169,8 @@ class PluginRuntimeManager(
|
||||
logger.info("未找到任何插件目录,跳过插件运行时启动")
|
||||
return
|
||||
|
||||
platform_io_manager = get_platform_io_manager()
|
||||
|
||||
# 从配置读取自定义 IPC socket 路径(留空则自动生成)
|
||||
socket_path_base = _cfg.ipc_socket_path or None
|
||||
|
||||
@@ -132,19 +195,46 @@ class PluginRuntimeManager(
|
||||
|
||||
started_supervisors: List[PluginSupervisor] = []
|
||||
try:
|
||||
if self._builtin_supervisor:
|
||||
await self._builtin_supervisor.start()
|
||||
started_supervisors.append(self._builtin_supervisor)
|
||||
if self._third_party_supervisor:
|
||||
await self._third_party_supervisor.start()
|
||||
started_supervisors.append(self._third_party_supervisor)
|
||||
platform_io_manager.set_inbound_dispatcher(self._dispatch_platform_inbound)
|
||||
await platform_io_manager.ensure_send_pipeline_ready()
|
||||
|
||||
supervisor_groups: Dict[str, Optional[PluginSupervisor]] = {
|
||||
"builtin": self._builtin_supervisor,
|
||||
"third_party": self._third_party_supervisor,
|
||||
}
|
||||
start_order = self._build_group_start_order(builtin_dirs, third_party_dirs)
|
||||
|
||||
for group_name in start_order:
|
||||
supervisor = supervisor_groups.get(group_name)
|
||||
if supervisor is None:
|
||||
continue
|
||||
|
||||
external_plugin_versions = {
|
||||
plugin_id: plugin_version
|
||||
for started_supervisor in started_supervisors
|
||||
for plugin_id, plugin_version in started_supervisor.get_loaded_plugin_versions().items()
|
||||
}
|
||||
supervisor.set_external_available_plugins(external_plugin_versions)
|
||||
await supervisor.start()
|
||||
started_supervisors.append(supervisor)
|
||||
|
||||
await self._start_plugin_file_watcher()
|
||||
config_manager.register_reload_callback(self._config_reload_callback)
|
||||
self._config_reload_callback_registered = True
|
||||
self._started = True
|
||||
logger.info(f"插件运行时已启动 — 内置: {builtin_dirs or '无'}, 第三方: {third_party_dirs or '无'}")
|
||||
except Exception as e:
|
||||
logger.error(f"插件运行时启动失败: {e}", exc_info=True)
|
||||
await self._stop_plugin_file_watcher()
|
||||
if self._config_reload_callback_registered:
|
||||
config_manager.unregister_reload_callback(self._config_reload_callback)
|
||||
self._config_reload_callback_registered = False
|
||||
await asyncio.gather(*(sv.stop() for sv in started_supervisors), return_exceptions=True)
|
||||
platform_io_manager.clear_inbound_dispatcher()
|
||||
try:
|
||||
await platform_io_manager.stop()
|
||||
except Exception as platform_io_exc:
|
||||
logger.warning(f"Platform IO 停止失败: {platform_io_exc}")
|
||||
self._started = False
|
||||
self._builtin_supervisor = None
|
||||
self._third_party_supervisor = None
|
||||
@@ -154,7 +244,11 @@ class PluginRuntimeManager(
|
||||
if not self._started:
|
||||
return
|
||||
|
||||
platform_io_manager = get_platform_io_manager()
|
||||
await self._stop_plugin_file_watcher()
|
||||
if self._config_reload_callback_registered:
|
||||
config_manager.unregister_reload_callback(self._config_reload_callback)
|
||||
self._config_reload_callback_registered = False
|
||||
|
||||
coroutines: List[Coroutine[Any, Any, None]] = []
|
||||
if self._builtin_supervisor:
|
||||
@@ -162,18 +256,32 @@ class PluginRuntimeManager(
|
||||
if self._third_party_supervisor:
|
||||
coroutines.append(self._third_party_supervisor.stop())
|
||||
|
||||
stop_errors: List[str] = []
|
||||
try:
|
||||
await asyncio.gather(*coroutines, return_exceptions=True)
|
||||
logger.info("插件运行时已停止")
|
||||
except Exception as e:
|
||||
logger.error(f"插件运行时停止失败: {e}", exc_info=True)
|
||||
results = await asyncio.gather(*coroutines, return_exceptions=True)
|
||||
for result in results:
|
||||
if isinstance(result, Exception):
|
||||
stop_errors.append(str(result))
|
||||
|
||||
platform_io_manager.clear_inbound_dispatcher()
|
||||
try:
|
||||
await platform_io_manager.stop()
|
||||
except Exception as exc:
|
||||
stop_errors.append(f"Platform IO: {exc}")
|
||||
|
||||
if stop_errors:
|
||||
logger.error(f"插件运行时停止过程中存在错误: {'; '.join(stop_errors)}")
|
||||
else:
|
||||
logger.info("插件运行时已停止")
|
||||
finally:
|
||||
self._started = False
|
||||
self._builtin_supervisor = None
|
||||
self._third_party_supervisor = None
|
||||
self._plugin_path_cache.clear()
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
"""返回插件运行时是否处于启动状态。"""
|
||||
return self._started
|
||||
|
||||
@property
|
||||
@@ -181,11 +289,176 @@ class PluginRuntimeManager(
|
||||
"""获取所有活跃的 Supervisor"""
|
||||
return [s for s in (self._builtin_supervisor, self._third_party_supervisor) if s is not None]
|
||||
|
||||
def _build_registered_dependency_map(self) -> Dict[str, Set[str]]:
|
||||
"""根据当前已注册插件构建全局依赖图。"""
|
||||
|
||||
dependency_map: Dict[str, Set[str]] = {}
|
||||
for supervisor in self.supervisors:
|
||||
for plugin_id, registration in getattr(supervisor, "_registered_plugins", {}).items():
|
||||
dependency_map[plugin_id] = {
|
||||
str(dependency or "").strip()
|
||||
for dependency in getattr(registration, "dependencies", [])
|
||||
if str(dependency or "").strip()
|
||||
}
|
||||
return dependency_map
|
||||
|
||||
@staticmethod
|
||||
def _collect_reverse_dependents(
|
||||
plugin_ids: Set[str],
|
||||
dependency_map: Dict[str, Set[str]],
|
||||
) -> Set[str]:
|
||||
"""根据依赖图收集反向依赖闭包。"""
|
||||
|
||||
impacted_plugins: Set[str] = set(plugin_ids)
|
||||
changed = True
|
||||
|
||||
while changed:
|
||||
changed = False
|
||||
for registered_plugin_id, dependencies in dependency_map.items():
|
||||
if registered_plugin_id in impacted_plugins:
|
||||
continue
|
||||
if dependencies & impacted_plugins:
|
||||
impacted_plugins.add(registered_plugin_id)
|
||||
changed = True
|
||||
|
||||
return impacted_plugins
|
||||
|
||||
def _build_registered_supervisor_map(self) -> Dict[str, "PluginSupervisor"]:
|
||||
"""构建当前已注册插件到所属 Supervisor 的映射。"""
|
||||
|
||||
return {
|
||||
plugin_id: supervisor
|
||||
for supervisor in self.supervisors
|
||||
for plugin_id in supervisor.get_loaded_plugin_ids()
|
||||
}
|
||||
|
||||
def _build_external_available_plugins_for_supervisor(self, target_supervisor: "PluginSupervisor") -> Dict[str, str]:
|
||||
"""收集某个 Supervisor 可用的外部插件版本映射。"""
|
||||
|
||||
external_plugin_versions: Dict[str, str] = {}
|
||||
for supervisor in self.supervisors:
|
||||
if supervisor is target_supervisor:
|
||||
continue
|
||||
external_plugin_versions.update(supervisor.get_loaded_plugin_versions())
|
||||
return external_plugin_versions
|
||||
|
||||
def _find_supervisor_by_plugin_directory(self, plugin_id: str) -> Optional["PluginSupervisor"]:
|
||||
"""根据插件目录推断应负责该插件重载的 Supervisor。"""
|
||||
|
||||
for supervisor in self.supervisors:
|
||||
if self._get_plugin_path_for_supervisor(supervisor, plugin_id) is not None:
|
||||
return supervisor
|
||||
return None
|
||||
|
||||
def _warn_skipped_cross_supervisor_reload(
|
||||
self,
|
||||
requested_loaded_plugin_ids: Set[str],
|
||||
dependency_map: Dict[str, Set[str]],
|
||||
supervisor_by_plugin: Dict[str, "PluginSupervisor"],
|
||||
) -> None:
|
||||
"""记录因跨 Supervisor 边界而未参与联动重载的插件。"""
|
||||
|
||||
if not requested_loaded_plugin_ids:
|
||||
return
|
||||
|
||||
handled_plugin_ids: Set[str] = set()
|
||||
for supervisor in self.supervisors:
|
||||
local_requested_plugin_ids = {
|
||||
plugin_id
|
||||
for plugin_id in requested_loaded_plugin_ids
|
||||
if supervisor_by_plugin.get(plugin_id) is supervisor
|
||||
}
|
||||
if not local_requested_plugin_ids:
|
||||
continue
|
||||
|
||||
local_plugin_ids = set(supervisor.get_loaded_plugin_ids())
|
||||
local_dependency_map = {
|
||||
plugin_id: {
|
||||
dependency
|
||||
for dependency in dependency_map.get(plugin_id, set())
|
||||
if dependency in local_plugin_ids
|
||||
}
|
||||
for plugin_id in local_plugin_ids
|
||||
}
|
||||
handled_plugin_ids.update(
|
||||
self._collect_reverse_dependents(local_requested_plugin_ids, local_dependency_map)
|
||||
)
|
||||
|
||||
impacted_plugin_ids = self._collect_reverse_dependents(requested_loaded_plugin_ids, dependency_map)
|
||||
skipped_plugin_ids = sorted(impacted_plugin_ids - handled_plugin_ids)
|
||||
if not skipped_plugin_ids:
|
||||
return
|
||||
|
||||
logger.warning(
|
||||
f"插件 {', '.join(sorted(requested_loaded_plugin_ids))} 存在跨 Supervisor 依赖方未联动重载: "
|
||||
f"{', '.join(skipped_plugin_ids)}。当前仅在单个 Supervisor 内执行联动重载;"
|
||||
"跨 Supervisor API 调用仍然可用。如需联动重载,请将相关插件放在同一个 Supervisor 内。"
|
||||
)
|
||||
|
||||
async def reload_plugins_globally(self, plugin_ids: Sequence[str], reason: str = "manual") -> bool:
|
||||
"""按 Supervisor 分组执行精确重载。
|
||||
|
||||
仅在单个 Supervisor 内执行依赖联动;跨 Supervisor 依赖方仅记录告警,
|
||||
不再自动参与本次热重载。
|
||||
"""
|
||||
|
||||
normalized_plugin_ids = [
|
||||
normalized_plugin_id
|
||||
for plugin_id in plugin_ids
|
||||
if (normalized_plugin_id := str(plugin_id or "").strip())
|
||||
]
|
||||
if not normalized_plugin_ids:
|
||||
return True
|
||||
|
||||
dependency_map = self._build_registered_dependency_map()
|
||||
supervisor_by_plugin = self._build_registered_supervisor_map()
|
||||
supervisor_roots: Dict["PluginSupervisor", List[str]] = {}
|
||||
requested_loaded_plugin_ids: Set[str] = set()
|
||||
missing_plugin_ids: List[str] = []
|
||||
|
||||
for plugin_id in normalized_plugin_ids:
|
||||
supervisor = supervisor_by_plugin.get(plugin_id)
|
||||
if supervisor is not None:
|
||||
requested_loaded_plugin_ids.add(plugin_id)
|
||||
else:
|
||||
supervisor = self._find_supervisor_by_plugin_directory(plugin_id)
|
||||
|
||||
if supervisor is None:
|
||||
missing_plugin_ids.append(plugin_id)
|
||||
continue
|
||||
|
||||
if plugin_id not in supervisor_roots.setdefault(supervisor, []):
|
||||
supervisor_roots[supervisor].append(plugin_id)
|
||||
|
||||
if missing_plugin_ids:
|
||||
logger.warning(f"以下插件未找到可重载的 Supervisor,已跳过: {', '.join(sorted(missing_plugin_ids))}")
|
||||
|
||||
self._warn_skipped_cross_supervisor_reload(
|
||||
requested_loaded_plugin_ids=requested_loaded_plugin_ids,
|
||||
dependency_map=dependency_map,
|
||||
supervisor_by_plugin=supervisor_by_plugin,
|
||||
)
|
||||
|
||||
success = True
|
||||
for supervisor, root_plugin_ids in supervisor_roots.items():
|
||||
if not root_plugin_ids:
|
||||
continue
|
||||
|
||||
reloaded = await supervisor.reload_plugins(
|
||||
plugin_ids=root_plugin_ids,
|
||||
reason=reason,
|
||||
external_available_plugins=self._build_external_available_plugins_for_supervisor(supervisor),
|
||||
)
|
||||
success = success and reloaded
|
||||
|
||||
return success and not missing_plugin_ids
|
||||
|
||||
async def notify_plugin_config_updated(
|
||||
self,
|
||||
plugin_id: str,
|
||||
config_data: Optional[Dict[str, Any]] = None,
|
||||
config_version: str = "",
|
||||
config_scope: str = "self",
|
||||
) -> bool:
|
||||
"""向拥有该插件的 Supervisor 推送配置更新事件。
|
||||
|
||||
@@ -193,6 +466,7 @@ class PluginRuntimeManager(
|
||||
plugin_id: 插件 ID
|
||||
config_data: 可选的配置数据(如果为 None 则由 Supervisor 从磁盘加载)
|
||||
config_version: 可选的配置版本字符串,供 Supervisor 进行版本控制
|
||||
config_scope: 配置变更范围。
|
||||
"""
|
||||
if not self._started:
|
||||
return False
|
||||
@@ -209,23 +483,78 @@ class PluginRuntimeManager(
|
||||
config_payload = (
|
||||
config_data
|
||||
if config_data is not None
|
||||
else self._load_plugin_config_for_supervisor(plugin_id, plugin_dirs=sv._plugin_dirs)
|
||||
else self._load_plugin_config_for_supervisor(sv, plugin_id)
|
||||
)
|
||||
await sv.notify_plugin_config_updated(
|
||||
return await sv.notify_plugin_config_updated(
|
||||
plugin_id=plugin_id,
|
||||
config_data=config_payload,
|
||||
config_version=config_version,
|
||||
config_scope=config_scope,
|
||||
)
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _normalize_config_reload_scopes(changed_scopes: Sequence[str]) -> tuple[str, ...]:
|
||||
"""规范化配置热重载范围列表。
|
||||
|
||||
Args:
|
||||
changed_scopes: 原始配置热重载范围列表。
|
||||
|
||||
Returns:
|
||||
tuple[str, ...]: 去重后的有效配置范围元组。
|
||||
"""
|
||||
|
||||
normalized_scopes: list[str] = []
|
||||
for scope in changed_scopes:
|
||||
normalized_scope = str(scope or "").strip().lower()
|
||||
if normalized_scope not in {"bot", "model"}:
|
||||
continue
|
||||
if normalized_scope not in normalized_scopes:
|
||||
normalized_scopes.append(normalized_scope)
|
||||
return tuple(normalized_scopes)
|
||||
|
||||
async def _broadcast_config_reload(self, scope: str, config_data: Dict[str, Any]) -> None:
|
||||
"""向订阅指定范围的插件广播配置热重载。
|
||||
|
||||
Args:
|
||||
scope: 配置变更范围,仅支持 ``bot`` 或 ``model``。
|
||||
config_data: 最新配置数据。
|
||||
"""
|
||||
|
||||
for supervisor in self.supervisors:
|
||||
for plugin_id in supervisor.get_config_reload_subscribers(scope):
|
||||
delivered = await supervisor.notify_plugin_config_updated(
|
||||
plugin_id=plugin_id,
|
||||
config_data=config_data,
|
||||
config_version="",
|
||||
config_scope=scope,
|
||||
)
|
||||
if not delivered:
|
||||
logger.warning(f"向插件 {plugin_id} 广播 {scope} 配置热重载失败")
|
||||
|
||||
async def _handle_main_config_reload(self, changed_scopes: Sequence[str]) -> None:
|
||||
"""处理 bot/model 主配置热重载广播。
|
||||
|
||||
Args:
|
||||
changed_scopes: 本次热重载命中的配置范围列表。
|
||||
"""
|
||||
|
||||
if not self._started:
|
||||
return
|
||||
|
||||
normalized_scopes = self._normalize_config_reload_scopes(changed_scopes)
|
||||
if "bot" in normalized_scopes:
|
||||
await self._broadcast_config_reload("bot", config_manager.get_global_config().model_dump(mode="json"))
|
||||
if "model" in normalized_scopes:
|
||||
await self._broadcast_config_reload("model", config_manager.get_model_config().model_dump(mode="json"))
|
||||
|
||||
# ─── 事件桥接 ──────────────────────────────────────────────
|
||||
|
||||
async def bridge_event(
|
||||
self,
|
||||
event_type_value: str,
|
||||
message_dict: Optional[Dict[str, Any]] = None,
|
||||
message_dict: Optional[MessageDict] = None,
|
||||
extra_args: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[bool, Optional[Dict[str, Any]]]:
|
||||
) -> Tuple[bool, Optional[MessageDict]]:
|
||||
"""将事件分发到所有 Supervisor
|
||||
|
||||
Returns:
|
||||
@@ -235,17 +564,23 @@ class PluginRuntimeManager(
|
||||
return True, None
|
||||
|
||||
new_event_type: str = _EVENT_TYPE_MAP.get(event_type_value, event_type_value)
|
||||
modified: Optional[Dict[str, Any]] = None
|
||||
modified: Optional[MessageDict] = None
|
||||
current_message: Optional["SessionMessage"] = (
|
||||
PluginMessageUtils._build_session_message_from_dict(dict(message_dict))
|
||||
if message_dict is not None
|
||||
else None
|
||||
)
|
||||
|
||||
for sv in self.supervisors:
|
||||
try:
|
||||
cont, mod = await sv.dispatch_event(
|
||||
event_type=new_event_type,
|
||||
message=modified or message_dict,
|
||||
message=current_message,
|
||||
extra_args=extra_args,
|
||||
)
|
||||
if mod is not None:
|
||||
modified = mod
|
||||
current_message = mod
|
||||
modified = PluginMessageUtils._session_message_to_dict(mod)
|
||||
if not cont:
|
||||
return False, modified
|
||||
except Exception as e:
|
||||
@@ -295,6 +630,37 @@ class PluginRuntimeManager(
|
||||
timeout_ms=timeout_ms,
|
||||
)
|
||||
|
||||
async def try_send_message_via_platform_io(
|
||||
self,
|
||||
message: "SessionMessage",
|
||||
) -> Optional[DeliveryBatch]:
|
||||
"""尝试通过 Platform IO 中间层发送消息。
|
||||
|
||||
Args:
|
||||
message: 待发送的内部会话消息。
|
||||
|
||||
Returns:
|
||||
Optional[DeliveryBatch]: 若当前消息命中了至少一条发送路由,则返回
|
||||
实际发送结果;若没有可用路由或 Platform IO 尚未启动,则返回 ``None``。
|
||||
"""
|
||||
if not self._started:
|
||||
return None
|
||||
|
||||
platform_io_manager = get_platform_io_manager()
|
||||
if not platform_io_manager.is_started:
|
||||
return None
|
||||
|
||||
try:
|
||||
route_key = platform_io_manager.build_route_key_from_message(message)
|
||||
except Exception as exc:
|
||||
logger.warning(f"根据消息构造 Platform IO 路由键失败: {exc}")
|
||||
return None
|
||||
|
||||
if not platform_io_manager.resolve_drivers(route_key):
|
||||
return None
|
||||
|
||||
return await platform_io_manager.send_message(message, route_key)
|
||||
|
||||
def _get_supervisors_for_plugin(self, plugin_id: str) -> List["PluginSupervisor"]:
|
||||
"""返回当前持有指定插件的所有 Supervisor。
|
||||
|
||||
@@ -314,30 +680,38 @@ class PluginRuntimeManager(
|
||||
raise RuntimeError(f"插件 {plugin_id} 同时存在于多个 Supervisor 中,无法安全路由")
|
||||
return matches[0] if matches else None
|
||||
|
||||
@staticmethod
|
||||
def _find_duplicate_plugin_ids(plugin_dirs: List[Path]) -> Dict[str, List[Path]]:
|
||||
async def load_plugin_globally(self, plugin_id: str, reason: str = "manual") -> bool:
|
||||
"""加载或重载单个插件,并为其补齐跨 Supervisor 外部依赖。"""
|
||||
|
||||
normalized_plugin_id = str(plugin_id or "").strip()
|
||||
if not normalized_plugin_id:
|
||||
return False
|
||||
|
||||
try:
|
||||
registered_supervisor = self._get_supervisor_for_plugin(normalized_plugin_id)
|
||||
except RuntimeError:
|
||||
return False
|
||||
|
||||
if registered_supervisor is not None:
|
||||
return await self.reload_plugins_globally([normalized_plugin_id], reason=reason)
|
||||
|
||||
supervisor = self._find_supervisor_by_plugin_directory(normalized_plugin_id)
|
||||
if supervisor is None:
|
||||
return False
|
||||
|
||||
return await supervisor.reload_plugins(
|
||||
plugin_ids=[normalized_plugin_id],
|
||||
reason=reason,
|
||||
external_available_plugins=self._build_external_available_plugins_for_supervisor(supervisor),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _find_duplicate_plugin_ids(cls, plugin_dirs: List[Path]) -> Dict[str, List[Path]]:
|
||||
"""扫描插件目录,找出被多个目录重复声明的插件 ID。"""
|
||||
plugin_locations: Dict[str, List[Path]] = {}
|
||||
for base_dir in plugin_dirs:
|
||||
if not base_dir.is_dir():
|
||||
continue
|
||||
for entry in base_dir.iterdir():
|
||||
if not entry.is_dir():
|
||||
continue
|
||||
manifest_path = entry / "_manifest.json"
|
||||
plugin_path = entry / "plugin.py"
|
||||
if not manifest_path.exists() or not plugin_path.exists():
|
||||
continue
|
||||
|
||||
plugin_id = entry.name
|
||||
try:
|
||||
with open(manifest_path, "r", encoding="utf-8") as manifest_file:
|
||||
manifest = json.load(manifest_file)
|
||||
plugin_id = str(manifest.get("name", entry.name)).strip() or entry.name
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
plugin_locations.setdefault(plugin_id, []).append(entry)
|
||||
validator = ManifestValidator()
|
||||
for plugin_path, manifest in validator.iter_plugin_manifests(plugin_dirs):
|
||||
plugin_locations.setdefault(manifest.id, []).append(plugin_path)
|
||||
|
||||
return {
|
||||
plugin_id: sorted(dict.fromkeys(paths), key=lambda p: str(p))
|
||||
@@ -370,6 +744,7 @@ class PluginRuntimeManager(
|
||||
async def _stop_plugin_file_watcher(self) -> None:
|
||||
"""停止插件文件监视器,并清理所有已注册订阅。"""
|
||||
if self._plugin_file_watcher is None:
|
||||
self._plugin_path_cache.clear()
|
||||
return
|
||||
for _plugin_id, (_config_path, subscription_id) in list(self._plugin_config_watcher_subscriptions.items()):
|
||||
self._plugin_file_watcher.unsubscribe(subscription_id)
|
||||
@@ -379,12 +754,79 @@ class PluginRuntimeManager(
|
||||
self._plugin_source_watcher_subscription_id = None
|
||||
await self._plugin_file_watcher.stop()
|
||||
self._plugin_file_watcher = None
|
||||
self._plugin_path_cache.clear()
|
||||
|
||||
def _iter_plugin_dirs(self) -> Iterable[Path]:
|
||||
"""迭代所有 Supervisor 当前管理的插件根目录。"""
|
||||
for supervisor in self.supervisors:
|
||||
yield from getattr(supervisor, "_plugin_dirs", [])
|
||||
|
||||
@staticmethod
|
||||
def _iter_candidate_plugin_paths(plugin_dirs: Iterable[Path]) -> Iterable[Path]:
|
||||
"""迭代所有可能的插件目录路径。
|
||||
|
||||
Args:
|
||||
plugin_dirs: 一个或多个插件根目录。
|
||||
|
||||
Yields:
|
||||
Path: 单个插件目录路径。
|
||||
"""
|
||||
for plugin_dir in plugin_dirs:
|
||||
plugin_root = Path(plugin_dir).resolve()
|
||||
if not plugin_root.is_dir():
|
||||
continue
|
||||
for entry in plugin_root.iterdir():
|
||||
if entry.is_dir():
|
||||
yield entry.resolve()
|
||||
|
||||
def _read_plugin_id_from_plugin_path(self, plugin_path: Path) -> Optional[str]:
|
||||
"""从单个插件目录中读取 manifest 声明的插件 ID。
|
||||
|
||||
Args:
|
||||
plugin_path: 单个插件目录路径。
|
||||
|
||||
Returns:
|
||||
Optional[str]: 解析成功时返回插件 ID,否则返回 ``None``。
|
||||
"""
|
||||
return self._manifest_validator.read_plugin_id_from_plugin_path(plugin_path)
|
||||
|
||||
def _iter_discovered_plugin_paths(self, plugin_dirs: Iterable[Path]) -> Iterable[Tuple[str, Path]]:
|
||||
"""迭代目录中可解析到的插件 ID 与实际目录路径。
|
||||
|
||||
Args:
|
||||
plugin_dirs: 一个或多个插件根目录。
|
||||
|
||||
Yields:
|
||||
Tuple[str, Path]: ``(plugin_id, plugin_path)`` 二元组。
|
||||
"""
|
||||
for plugin_path in self._iter_candidate_plugin_paths(plugin_dirs):
|
||||
if plugin_id := self._read_plugin_id_from_plugin_path(plugin_path):
|
||||
yield plugin_id, plugin_path
|
||||
|
||||
def _get_plugin_path_for_supervisor(self, supervisor: Any, plugin_id: str) -> Optional[Path]:
|
||||
"""为指定 Supervisor 定位某个插件的实际目录。
|
||||
|
||||
Args:
|
||||
supervisor: 目标 Supervisor。
|
||||
plugin_id: 插件 ID。
|
||||
|
||||
Returns:
|
||||
Optional[Path]: 插件目录路径;未找到时返回 ``None``。
|
||||
"""
|
||||
cached_path = self._plugin_path_cache.get(plugin_id)
|
||||
if cached_path is not None:
|
||||
for plugin_dir in getattr(supervisor, "_plugin_dirs", []):
|
||||
if self._plugin_dir_matches(cached_path, Path(plugin_dir)):
|
||||
return cached_path
|
||||
|
||||
for candidate_plugin_id, plugin_path in self._iter_discovered_plugin_paths(getattr(supervisor, "_plugin_dirs", [])):
|
||||
if candidate_plugin_id != plugin_id:
|
||||
continue
|
||||
self._plugin_path_cache[plugin_id] = plugin_path
|
||||
return plugin_path
|
||||
|
||||
return None
|
||||
|
||||
def _refresh_plugin_config_watch_subscriptions(self) -> None:
|
||||
"""按当前已注册插件集合刷新 config.toml 的单插件订阅。
|
||||
|
||||
@@ -394,7 +836,11 @@ class PluginRuntimeManager(
|
||||
if self._plugin_file_watcher is None:
|
||||
return
|
||||
|
||||
desired_config_paths = dict(self._iter_registered_plugin_config_paths())
|
||||
desired_plugin_paths = dict(self._iter_registered_plugin_paths())
|
||||
self._plugin_path_cache = desired_plugin_paths.copy()
|
||||
desired_config_paths = {
|
||||
plugin_id: plugin_path / "config.toml" for plugin_id, plugin_path in desired_plugin_paths.items()
|
||||
}
|
||||
|
||||
for plugin_id, (_old_path, subscription_id) in list(self._plugin_config_watcher_subscriptions.items()):
|
||||
if desired_config_paths.get(plugin_id) == self._plugin_config_watcher_subscriptions[plugin_id][0]:
|
||||
@@ -418,28 +864,35 @@ class PluginRuntimeManager(
|
||||
"""为指定插件生成配置文件变更回调。"""
|
||||
|
||||
async def _callback(changes: Sequence[FileChange]) -> None:
|
||||
"""将 watcher 事件转发到指定插件的配置处理逻辑。
|
||||
|
||||
Args:
|
||||
changes: 当前批次收集到的文件变更列表。
|
||||
"""
|
||||
await self._handle_plugin_config_changes(plugin_id, changes)
|
||||
|
||||
return _callback
|
||||
|
||||
def _iter_registered_plugin_config_paths(self) -> Iterable[Tuple[str, Path]]:
|
||||
"""迭代当前所有已注册插件的 config.toml 路径。"""
|
||||
def _iter_registered_plugin_paths(self) -> Iterable[Tuple[str, Path]]:
|
||||
"""迭代当前所有已注册插件的实际目录路径。"""
|
||||
for supervisor in self.supervisors:
|
||||
for plugin_id in getattr(supervisor, "_registered_plugins", {}).keys():
|
||||
if config_path := self._get_plugin_config_path_for_supervisor(supervisor, plugin_id):
|
||||
yield plugin_id, config_path
|
||||
if plugin_path := self._get_plugin_path_for_supervisor(supervisor, plugin_id):
|
||||
yield plugin_id, plugin_path
|
||||
|
||||
def _get_plugin_config_path_for_supervisor(self, supervisor: Any, plugin_id: str) -> Optional[Path]:
|
||||
"""从指定 Supervisor 的插件目录中定位某个插件的 config.toml。"""
|
||||
for plugin_dir in getattr(supervisor, "_plugin_dirs", []):
|
||||
plugin_dir = Path(plugin_dir)
|
||||
plugin_path = plugin_dir.resolve() / plugin_id
|
||||
if plugin_path.is_dir():
|
||||
return plugin_path / "config.toml"
|
||||
return None
|
||||
plugin_path = self._get_plugin_path_for_supervisor(supervisor, plugin_id)
|
||||
return None if plugin_path is None else plugin_path / "config.toml"
|
||||
|
||||
async def _handle_plugin_config_changes(self, plugin_id: str, changes: Sequence[FileChange]) -> None:
|
||||
"""处理单个插件配置文件变化,并仅向目标插件推送配置更新。"""
|
||||
"""处理单个插件配置文件变化,并定向派发自配置热更新。
|
||||
|
||||
Args:
|
||||
plugin_id: 发生配置变更的插件 ID。
|
||||
changes: 当前批次收集到的配置文件变更列表。
|
||||
|
||||
"""
|
||||
if not self._started or not changes:
|
||||
return
|
||||
|
||||
@@ -453,18 +906,24 @@ class PluginRuntimeManager(
|
||||
return
|
||||
|
||||
try:
|
||||
await supervisor.notify_plugin_config_updated(
|
||||
config_payload = self._load_plugin_config_for_supervisor(supervisor, plugin_id)
|
||||
delivered = await supervisor.notify_plugin_config_updated(
|
||||
plugin_id=plugin_id,
|
||||
config_data=self._load_plugin_config_for_supervisor(plugin_id, getattr(supervisor, "_plugin_dirs", [])),
|
||||
config_data=config_payload,
|
||||
config_version="",
|
||||
config_scope="self",
|
||||
)
|
||||
if not delivered:
|
||||
logger.warning(f"插件 {plugin_id} 配置文件变更后通知失败")
|
||||
except Exception as exc:
|
||||
logger.warning(f"插件 {plugin_id} 配置热更新通知失败: {exc}")
|
||||
logger.warning(f"插件 {plugin_id} 配置文件变更处理失败: {exc}")
|
||||
|
||||
async def _handle_plugin_source_changes(self, changes: Sequence[FileChange]) -> None:
|
||||
"""处理插件源码相关变化。
|
||||
|
||||
这里仅负责源码、清单等会影响插件装载状态的文件;配置文件的变化会由
|
||||
单独的 per-plugin watcher 处理,避免把单插件配置更新放大成全量 reload。
|
||||
单独的 per-plugin watcher 处理,并定向派发给目标插件的
|
||||
``on_config_update()``,避免放大成不必要的跨插件 reload。
|
||||
"""
|
||||
if not self._started or not changes:
|
||||
return
|
||||
@@ -477,7 +936,7 @@ class PluginRuntimeManager(
|
||||
logger.error(f"检测到重复插件 ID,跳过本次插件热重载: {details}")
|
||||
return
|
||||
|
||||
reload_supervisors: List[Any] = []
|
||||
changed_plugin_ids: List[str] = []
|
||||
changed_paths = [change.path.resolve() for change in changes]
|
||||
|
||||
for supervisor in self.supervisors:
|
||||
@@ -485,13 +944,12 @@ class PluginRuntimeManager(
|
||||
plugin_id = self._match_plugin_id_for_supervisor(supervisor, path)
|
||||
if plugin_id is None:
|
||||
continue
|
||||
if (path.name in {"plugin.py", "_manifest.json"} or path.suffix == ".py") and supervisor not in reload_supervisors:
|
||||
reload_supervisors.append(supervisor)
|
||||
if path.name in {"plugin.py", "_manifest.json"} or path.suffix == ".py":
|
||||
if plugin_id not in changed_plugin_ids:
|
||||
changed_plugin_ids.append(plugin_id)
|
||||
|
||||
for supervisor in reload_supervisors:
|
||||
await supervisor.reload_plugins(reason="file_watcher")
|
||||
|
||||
if reload_supervisors:
|
||||
if changed_plugin_ids:
|
||||
await self.reload_plugins_globally(changed_plugin_ids, reason="file_watcher")
|
||||
self._refresh_plugin_config_watch_subscriptions()
|
||||
|
||||
@staticmethod
|
||||
@@ -502,36 +960,47 @@ class PluginRuntimeManager(
|
||||
|
||||
def _match_plugin_id_for_supervisor(self, supervisor: Any, path: Path) -> Optional[str]:
|
||||
"""根据变更路径为指定 Supervisor 推断受影响的插件 ID。"""
|
||||
for plugin_id, _reg in getattr(supervisor, "_registered_plugins", {}).items():
|
||||
for plugin_dir in getattr(supervisor, "_plugin_dirs", []):
|
||||
plugin_dir = Path(plugin_dir)
|
||||
candidate_dir = plugin_dir.resolve() / plugin_id
|
||||
if path == candidate_dir or path.is_relative_to(candidate_dir):
|
||||
return plugin_id
|
||||
resolved_path = path.resolve()
|
||||
|
||||
for plugin_id in getattr(supervisor, "_registered_plugins", {}).keys():
|
||||
plugin_path = self._get_plugin_path_for_supervisor(supervisor, plugin_id)
|
||||
if plugin_path is not None and (resolved_path == plugin_path or resolved_path.is_relative_to(plugin_path)):
|
||||
return plugin_id
|
||||
|
||||
for plugin_id, plugin_path in self._plugin_path_cache.items():
|
||||
if not any(self._plugin_dir_matches(plugin_path, Path(plugin_dir)) for plugin_dir in getattr(supervisor, "_plugin_dirs", [])):
|
||||
continue
|
||||
if resolved_path == plugin_path or resolved_path.is_relative_to(plugin_path):
|
||||
return plugin_id
|
||||
|
||||
for plugin_id, plugin_path in self._iter_discovered_plugin_paths(getattr(supervisor, "_plugin_dirs", [])):
|
||||
if resolved_path == plugin_path or resolved_path.is_relative_to(plugin_path):
|
||||
self._plugin_path_cache[plugin_id] = plugin_path
|
||||
return plugin_id
|
||||
|
||||
for plugin_dir in getattr(supervisor, "_plugin_dirs", []):
|
||||
plugin_dir = Path(plugin_dir)
|
||||
plugin_root = plugin_dir.resolve()
|
||||
if self._plugin_dir_matches(path, plugin_dir) and (relative_parts := path.relative_to(plugin_root).parts):
|
||||
return relative_parts[0]
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _load_plugin_config_for_supervisor(plugin_id: str, plugin_dirs: Iterable[Path]) -> Dict[str, Any]:
|
||||
def _load_plugin_config_for_supervisor(self, supervisor: Any, plugin_id: str) -> Dict[str, Any]:
|
||||
"""从给定插件目录集合中读取目标插件的配置内容。"""
|
||||
for plugin_dir in plugin_dirs:
|
||||
plugin_path = plugin_dir.resolve() / plugin_id
|
||||
if plugin_path.is_dir():
|
||||
config_path = plugin_path / "config.toml"
|
||||
if not config_path.exists():
|
||||
return {}
|
||||
with open(config_path, "r", encoding="utf-8") as handle:
|
||||
return tomlkit.load(handle).unwrap()
|
||||
return {}
|
||||
plugin_path = self._get_plugin_path_for_supervisor(supervisor, plugin_id)
|
||||
if plugin_path is None:
|
||||
return {}
|
||||
|
||||
config_path = plugin_path / "config.toml"
|
||||
if not config_path.exists():
|
||||
return {}
|
||||
|
||||
with open(config_path, "r", encoding="utf-8") as handle:
|
||||
return tomlkit.load(handle).unwrap()
|
||||
|
||||
# ─── 能力实现注册 ──────────────────────────────────────────
|
||||
|
||||
def _register_capability_impls(self, supervisor: "PluginSupervisor") -> None:
|
||||
"""向指定 Supervisor 注册主程序能力实现。
|
||||
|
||||
Args:
|
||||
supervisor: 需要注册能力实现的目标 Supervisor。
|
||||
"""
|
||||
register_capability_impls(self, supervisor)
|
||||
|
||||
|
||||
|
||||
@@ -7,52 +7,52 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import logging as stdlib_logging
|
||||
import time
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# ─── 协议常量 ──────────────────────────────────────────────────────
|
||||
|
||||
PROTOCOL_VERSION = "1.0"
|
||||
|
||||
# ====== 协议常量 ======
|
||||
PROTOCOL_VERSION = "1.0.0"
|
||||
# 支持的 SDK 版本范围(Host 在握手时校验)
|
||||
MIN_SDK_VERSION = "1.0.0"
|
||||
MAX_SDK_VERSION = "1.99.99"
|
||||
|
||||
|
||||
# ─── 消息类型 ──────────────────────────────────────────────────────
|
||||
MAX_SDK_VERSION = "2.99.99"
|
||||
|
||||
|
||||
# ====== 消息类型 ======
|
||||
class MessageType(str, Enum):
|
||||
"""RPC 消息类型"""
|
||||
|
||||
REQUEST = "request"
|
||||
RESPONSE = "response"
|
||||
EVENT = "event"
|
||||
BROADCAST = "broadcast"
|
||||
|
||||
|
||||
# ─── 请求 ID 生成器 ───────────────────────────────────────────────
|
||||
class ConfigReloadScope(str, Enum):
|
||||
"""配置热重载范围。"""
|
||||
|
||||
SELF = "self"
|
||||
BOT = "bot"
|
||||
MODEL = "model"
|
||||
|
||||
|
||||
# ====== 请求 ID 生成器 ======
|
||||
class RequestIdGenerator:
|
||||
"""单调递增 int64 请求 ID 生成器(线程安全由调用方保证或使用 asyncio)"""
|
||||
"""单调递增 int64 请求 ID 生成器"""
|
||||
|
||||
def __init__(self, start: int = 1) -> None:
|
||||
self._counter = start
|
||||
|
||||
def next(self) -> int:
|
||||
async def next(self) -> int:
|
||||
current = self._counter
|
||||
self._counter += 1
|
||||
return current
|
||||
|
||||
|
||||
# ─── Envelope 模型 ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
# ====== Envelope 模型 ======
|
||||
class Envelope(BaseModel):
|
||||
"""RPC 统一信封
|
||||
"""RPC 统一消息封装
|
||||
|
||||
所有 Host <-> Runner 消息均封装为此格式。
|
||||
序列化流程:Envelope -> .model_dump() -> MsgPack encode
|
||||
@@ -60,15 +60,23 @@ class Envelope(BaseModel):
|
||||
"""
|
||||
|
||||
protocol_version: str = Field(default=PROTOCOL_VERSION, description="协议版本")
|
||||
"""协议版本"""
|
||||
request_id: int = Field(description="单调递增请求 ID")
|
||||
"""单调递增请求 ID"""
|
||||
message_type: MessageType = Field(description="消息类型")
|
||||
"""消息类型"""
|
||||
method: str = Field(default="", description="RPC 方法名")
|
||||
"""RPC 方法名"""
|
||||
plugin_id: str = Field(default="", description="目标插件 ID")
|
||||
timestamp_ms: int = Field(default_factory=lambda: int(time.time() * 1000), description="发送时间戳(ms)")
|
||||
timeout_ms: int = Field(default=30000, description="相对超时(ms)")
|
||||
generation: int = Field(default=0, description="Runner generation 编号")
|
||||
"""目标插件 ID"""
|
||||
timestamp_ms: int = Field(default_factory=lambda: int(time.time() * 1000), description="发送时间戳 (ms)")
|
||||
"""发送时间戳 (ms)"""
|
||||
timeout_ms: int = Field(default=30000, description="相对超时 (ms)")
|
||||
"""相对超时 (ms)"""
|
||||
payload: Dict[str, Any] = Field(default_factory=dict, description="业务数据")
|
||||
error: Optional[Dict[str, Any]] = Field(default=None, description="错误信息(仅 response)")
|
||||
"""业务数据"""
|
||||
error: Optional[Dict[str, Any]] = Field(default=None, description="错误信息 (仅 response)")
|
||||
"""错误信息 (仅 response)"""
|
||||
|
||||
def is_request(self) -> bool:
|
||||
return self.message_type == MessageType.REQUEST
|
||||
@@ -76,8 +84,8 @@ class Envelope(BaseModel):
|
||||
def is_response(self) -> bool:
|
||||
return self.message_type == MessageType.RESPONSE
|
||||
|
||||
def is_event(self) -> bool:
|
||||
return self.message_type == MessageType.EVENT
|
||||
def is_broadcast(self) -> bool:
|
||||
return self.message_type == MessageType.BROADCAST
|
||||
|
||||
def make_response(
|
||||
self, payload: Optional[Dict[str, Any]] = None, error: Optional[Dict[str, Any]] = None
|
||||
@@ -89,7 +97,6 @@ class Envelope(BaseModel):
|
||||
message_type=MessageType.RESPONSE,
|
||||
method=self.method,
|
||||
plugin_id=self.plugin_id,
|
||||
generation=self.generation,
|
||||
payload=payload or {},
|
||||
error=error,
|
||||
)
|
||||
@@ -105,153 +112,302 @@ class Envelope(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
# ─── 握手消息 ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
# ====== 握手请求与响应 ======
|
||||
class HelloPayload(BaseModel):
|
||||
"""runner.hello 握手请求 payload"""
|
||||
|
||||
runner_id: str = Field(description="Runner 进程唯一标识")
|
||||
"""Runner 进程唯一标识"""
|
||||
sdk_version: str = Field(description="SDK 版本号")
|
||||
"""SDK 版本号"""
|
||||
session_token: str = Field(description="一次性会话令牌")
|
||||
"""一次性会话令牌"""
|
||||
|
||||
|
||||
class HelloResponsePayload(BaseModel):
|
||||
"""runner.hello 握手响应 payload"""
|
||||
|
||||
accepted: bool = Field(description="是否接受连接")
|
||||
"""是否接受连接"""
|
||||
host_version: str = Field(default="", description="Host 版本号")
|
||||
assigned_generation: int = Field(default=0, description="分配的 generation 编号")
|
||||
reason: str = Field(default="", description="拒绝原因(若 accepted=False)")
|
||||
|
||||
|
||||
# ─── 组件注册消息 ──────────────────────────────────────────────────
|
||||
"""Host 版本号"""
|
||||
reason: str = Field(default="", description="拒绝原因 (若 accepted=False)")
|
||||
"""拒绝原因 (若 `accepted`=`False`)"""
|
||||
|
||||
|
||||
# ====== 组件注册消息 ======
|
||||
class ComponentDeclaration(BaseModel):
|
||||
"""单个组件声明"""
|
||||
|
||||
name: str = Field(description="组件名称")
|
||||
component_type: str = Field(description="组件类型: action/command/tool/event_handler")
|
||||
"""组件名称"""
|
||||
component_type: str = Field(
|
||||
description="组件类型:action/command/tool/event_handler/hook_handler/message_gateway"
|
||||
)
|
||||
"""组件类型:`action`/`command`/`tool`/`event_handler`/`hook_handler`/`message_gateway`"""
|
||||
plugin_id: str = Field(description="所属插件 ID")
|
||||
"""所属插件 ID"""
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict, description="组件元数据")
|
||||
"""组件元数据"""
|
||||
|
||||
|
||||
class RegisterComponentsPayload(BaseModel):
|
||||
"""plugin.register_components 请求 payload"""
|
||||
class RegisterPluginPayload(BaseModel):
|
||||
"""插件组件注册请求载荷。
|
||||
|
||||
该模型同时用于 ``plugin.register_components`` 与兼容旧命名的
|
||||
``plugin.register_plugin`` 请求。
|
||||
"""
|
||||
|
||||
plugin_id: str = Field(description="插件 ID")
|
||||
"""插件 ID"""
|
||||
plugin_version: str = Field(default="1.0.0", description="插件版本")
|
||||
"""插件版本"""
|
||||
components: List[ComponentDeclaration] = Field(default_factory=list, description="组件列表")
|
||||
"""组件列表"""
|
||||
capabilities_required: List[str] = Field(default_factory=list, description="所需能力列表")
|
||||
"""所需能力列表"""
|
||||
dependencies: List[str] = Field(default_factory=list, description="插件级依赖插件 ID 列表")
|
||||
"""插件级依赖插件 ID 列表"""
|
||||
config_reload_subscriptions: List[str] = Field(default_factory=list, description="订阅的全局配置热重载范围")
|
||||
"""订阅的全局配置热重载范围"""
|
||||
|
||||
|
||||
class BootstrapPluginPayload(BaseModel):
|
||||
"""plugin.bootstrap 请求 payload"""
|
||||
|
||||
plugin_id: str = Field(description="插件 ID")
|
||||
"""插件 ID"""
|
||||
plugin_version: str = Field(default="1.0.0", description="插件版本")
|
||||
"""插件版本"""
|
||||
capabilities_required: List[str] = Field(default_factory=list, description="所需能力列表")
|
||||
"""所需能力列表"""
|
||||
|
||||
|
||||
# ─── 调用消息 ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
# ====== 插件调用请求和响应 ======
|
||||
class InvokePayload(BaseModel):
|
||||
"""plugin.invoke_* 请求 payload"""
|
||||
"""plugin.invoke.* 请求 payload"""
|
||||
|
||||
component_name: str = Field(description="要调用的组件名称")
|
||||
"""要调用的组件名称"""
|
||||
args: Dict[str, Any] = Field(default_factory=dict, description="调用参数")
|
||||
"""调用参数"""
|
||||
|
||||
|
||||
class InvokeResultPayload(BaseModel):
|
||||
"""plugin.invoke_* 响应 payload"""
|
||||
"""plugin.invoke.* 响应 payload"""
|
||||
|
||||
success: bool = Field(description="是否成功")
|
||||
"""是否成功"""
|
||||
result: Any = Field(default=None, description="返回值")
|
||||
"""返回值"""
|
||||
|
||||
|
||||
# ─── 能力调用消息 ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
# ====== 能力调用消息 ======
|
||||
class CapabilityRequestPayload(BaseModel):
|
||||
"""cap.* 请求 payload(插件 -> Host 能力调用)"""
|
||||
|
||||
capability: str = Field(description="能力名称,如 send.text, db.query")
|
||||
"""能力名称,如 send.text, db.query"""
|
||||
args: Dict[str, Any] = Field(default_factory=dict, description="调用参数")
|
||||
"""调用参数"""
|
||||
|
||||
|
||||
class CapabilityResponsePayload(BaseModel):
|
||||
"""cap.* 响应 payload"""
|
||||
|
||||
success: bool = Field(description="是否成功")
|
||||
"""是否成功"""
|
||||
result: Any = Field(default=None, description="返回值")
|
||||
"""返回值"""
|
||||
|
||||
|
||||
# ─── 健康检查 ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
# ====== 健康检查 ======
|
||||
class HealthPayload(BaseModel):
|
||||
"""plugin.health 响应 payload"""
|
||||
|
||||
healthy: bool = Field(description="是否健康")
|
||||
"""是否健康"""
|
||||
loaded_plugins: List[str] = Field(default_factory=list, description="已加载的插件列表")
|
||||
uptime_ms: int = Field(default=0, description="运行时长(ms)")
|
||||
"""已加载的插件列表"""
|
||||
uptime_ms: int = Field(default=0, description="运行时长 (ms)")
|
||||
"""运行时长 (ms)"""
|
||||
|
||||
|
||||
class RunnerReadyPayload(BaseModel):
|
||||
"""runner.ready 请求 payload"""
|
||||
|
||||
loaded_plugins: List[str] = Field(default_factory=list, description="已完成初始化的插件列表")
|
||||
"""已完成初始化的插件列表"""
|
||||
failed_plugins: List[str] = Field(default_factory=list, description="初始化失败的插件列表")
|
||||
"""初始化失败的插件列表"""
|
||||
|
||||
|
||||
# ─── 配置更新 ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
# Host 侧现已支持配置更新推送:
|
||||
# - 总配置热重载完成后,PluginRuntimeManager 会向已加载插件推送配置更新事件。
|
||||
# - 插件目录下的 config.toml 变化由现有 FileWatcher 监听并转发为 plugin.config_updated。
|
||||
# ====== 配置更新 ======
|
||||
class ConfigUpdatedPayload(BaseModel):
|
||||
"""plugin.config_updated 事件 payload"""
|
||||
|
||||
plugin_id: str = Field(description="插件 ID")
|
||||
"""插件 ID"""
|
||||
config_scope: ConfigReloadScope = Field(description="配置变更范围")
|
||||
"""配置变更范围"""
|
||||
config_version: str = Field(description="新配置版本")
|
||||
"""新配置版本"""
|
||||
config_data: Dict[str, Any] = Field(default_factory=dict, description="配置内容")
|
||||
"""配置内容"""
|
||||
|
||||
|
||||
# ─── 关停 ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
# ====== 关停 ======
|
||||
class ShutdownPayload(BaseModel):
|
||||
"""plugin.shutdown / plugin.prepare_shutdown payload"""
|
||||
|
||||
reason: str = Field(default="normal", description="关停原因")
|
||||
drain_timeout_ms: int = Field(default=5000, description="排空超时(ms)")
|
||||
"""关停原因"""
|
||||
drain_timeout_ms: int = Field(default=5000, description="排空超时 (ms)")
|
||||
"""排空超时 (ms)"""
|
||||
|
||||
|
||||
# ─── 日志传输 ──────────────────────────────────────────────────────
|
||||
class UnregisterPluginPayload(BaseModel):
|
||||
"""插件注销请求载荷。"""
|
||||
|
||||
plugin_id: str = Field(description="插件 ID")
|
||||
"""插件 ID"""
|
||||
reason: str = Field(default="manual", description="注销原因")
|
||||
"""注销原因"""
|
||||
|
||||
|
||||
class ReloadPluginPayload(BaseModel):
|
||||
"""插件重载请求载荷。"""
|
||||
|
||||
plugin_id: str = Field(description="目标插件 ID")
|
||||
"""目标插件 ID"""
|
||||
reason: str = Field(default="manual", description="重载原因")
|
||||
"""重载原因"""
|
||||
external_available_plugins: Dict[str, str] = Field(
|
||||
default_factory=dict,
|
||||
description="可视为已满足的外部依赖插件版本映射",
|
||||
)
|
||||
"""可视为已满足的外部依赖插件版本映射"""
|
||||
|
||||
|
||||
class ReloadPluginsPayload(BaseModel):
|
||||
"""批量插件重载请求载荷。"""
|
||||
|
||||
plugin_ids: List[str] = Field(default_factory=list, description="目标插件 ID 列表")
|
||||
"""目标插件 ID 列表"""
|
||||
reason: str = Field(default="manual", description="重载原因")
|
||||
"""重载原因"""
|
||||
external_available_plugins: Dict[str, str] = Field(
|
||||
default_factory=dict,
|
||||
description="可视为已满足的外部依赖插件版本映射",
|
||||
)
|
||||
"""可视为已满足的外部依赖插件版本映射"""
|
||||
|
||||
|
||||
class ReloadPluginResultPayload(BaseModel):
|
||||
"""插件重载结果载荷。"""
|
||||
|
||||
success: bool = Field(description="是否重载成功")
|
||||
"""是否重载成功"""
|
||||
requested_plugin_id: str = Field(description="请求重载的插件 ID")
|
||||
"""请求重载的插件 ID"""
|
||||
reloaded_plugins: List[str] = Field(default_factory=list, description="成功完成重载的插件列表")
|
||||
"""成功完成重载的插件列表"""
|
||||
unloaded_plugins: List[str] = Field(default_factory=list, description="本次已卸载的插件列表")
|
||||
"""本次已卸载的插件列表"""
|
||||
failed_plugins: Dict[str, str] = Field(default_factory=dict, description="重载失败的插件及原因")
|
||||
"""重载失败的插件及原因"""
|
||||
|
||||
|
||||
class ReloadPluginsResultPayload(BaseModel):
|
||||
"""批量插件重载结果载荷。"""
|
||||
|
||||
success: bool = Field(description="是否重载成功")
|
||||
"""是否重载成功"""
|
||||
requested_plugin_ids: List[str] = Field(default_factory=list, description="请求重载的插件 ID 列表")
|
||||
"""请求重载的插件 ID 列表"""
|
||||
reloaded_plugins: List[str] = Field(default_factory=list, description="成功完成重载的插件列表")
|
||||
"""成功完成重载的插件列表"""
|
||||
unloaded_plugins: List[str] = Field(default_factory=list, description="本次已卸载的插件列表")
|
||||
"""本次已卸载的插件列表"""
|
||||
failed_plugins: Dict[str, str] = Field(default_factory=dict, description="重载失败的插件及原因")
|
||||
"""重载失败的插件及原因"""
|
||||
|
||||
|
||||
class MessageGatewayStateUpdatePayload(BaseModel):
|
||||
"""消息网关运行时状态更新载荷。"""
|
||||
|
||||
gateway_name: str = Field(description="消息网关组件名称")
|
||||
"""消息网关组件名称"""
|
||||
ready: bool = Field(description="当前链路是否已经就绪")
|
||||
"""当前链路是否已经就绪"""
|
||||
platform: str = Field(default="", description="当前链路负责的平台名称")
|
||||
"""当前链路负责的平台名称"""
|
||||
account_id: str = Field(default="", description="当前链路对应的账号 ID 或 self_id")
|
||||
"""当前链路对应的账号 ID 或 self_id"""
|
||||
scope: str = Field(default="", description="当前链路对应的可选路由作用域")
|
||||
"""当前链路对应的可选路由作用域"""
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict, description="可选的运行时状态元数据")
|
||||
"""可选的运行时状态元数据"""
|
||||
|
||||
|
||||
class MessageGatewayStateUpdateResultPayload(BaseModel):
|
||||
"""消息网关运行时状态更新结果载荷。"""
|
||||
|
||||
accepted: bool = Field(description="Host 是否接受了本次状态更新")
|
||||
"""Host 是否接受了本次状态更新"""
|
||||
ready: bool = Field(description="Host 记录的当前就绪状态")
|
||||
"""Host 记录的当前就绪状态"""
|
||||
route_key: Dict[str, Any] = Field(default_factory=dict, description="当前生效的路由键")
|
||||
"""当前生效的路由键"""
|
||||
|
||||
|
||||
class RouteMessagePayload(BaseModel):
|
||||
"""消息网关向 Host 路由外部消息的请求载荷。"""
|
||||
|
||||
gateway_name: str = Field(description="接收消息的网关组件名称")
|
||||
"""接收消息的网关组件名称"""
|
||||
message: Dict[str, Any] = Field(description="符合 MessageDict 结构的标准消息字典")
|
||||
"""符合 MessageDict 结构的标准消息字典"""
|
||||
route_metadata: Dict[str, Any] = Field(default_factory=dict, description="可选的路由辅助元数据")
|
||||
"""可选的路由辅助元数据"""
|
||||
external_message_id: str = Field(default="", description="可选的外部平台消息 ID")
|
||||
"""可选的外部平台消息 ID"""
|
||||
dedupe_key: str = Field(default="", description="可选的显式去重键")
|
||||
"""可选的显式去重键"""
|
||||
|
||||
|
||||
class ReceiveExternalMessageResultPayload(BaseModel):
|
||||
"""外部消息注入结果载荷。"""
|
||||
|
||||
accepted: bool = Field(description="Host 是否接受了本次消息注入")
|
||||
"""Host 是否接受了本次消息注入"""
|
||||
route_key: Dict[str, Any] = Field(default_factory=dict, description="本次消息使用的归一路由键")
|
||||
"""本次消息使用的归一路由键"""
|
||||
|
||||
|
||||
RegisterPluginPayload.model_rebuild()
|
||||
|
||||
|
||||
# ====== 日志传输 ======
|
||||
|
||||
|
||||
class LogEntry(BaseModel):
|
||||
"""单条日志记录(Runner → Host 传输格式)"""
|
||||
|
||||
timestamp_ms: int = Field(
|
||||
description="日志时间戳,Unix epoch 毫秒",
|
||||
)
|
||||
level: int = Field(
|
||||
description=("stdlib logging 整数级别: 10=DEBUG, 20=INFO, 30=WARNING, 40=ERROR, 50=CRITICAL"),
|
||||
)
|
||||
logger_name: str = Field(
|
||||
description="Logger 名称,如 plugin.my_plugin.submodule",
|
||||
)
|
||||
message: str = Field(
|
||||
description="经 Formatter 格式化后的完整日志消息(含 exc_info 文本)",
|
||||
)
|
||||
timestamp_ms: int = Field(description="日志时间戳,Unix epoch 毫秒")
|
||||
"""日志时间戳,Unix epoch 毫秒"""
|
||||
level: int = Field(description="stdlib logging 整数级别:10=DEBUG, 20=INFO, 30=WARNING, 40=ERROR, 50=CRITICAL")
|
||||
"""stdlib logging 整数级别:10=DEBUG, 20=INFO, 30=WARNING, 40=ERROR, 50=CRITICAL"""
|
||||
logger_name: str = Field(description="Logger 名称,如 plugin.my_plugin.submodule")
|
||||
"""Logger 名称,如 plugin.my_plugin.submodule"""
|
||||
message: str = Field(description="经 Formatter 格式化后的完整日志消息(含 exc_info 文本)")
|
||||
"""经 Formatter 格式化后的完整日志消息(含 exc_info 文本)"""
|
||||
exception_text: str = Field(
|
||||
default="",
|
||||
description="原始异常摘要(exc_text),供结构化消费;已嵌入 message 中",
|
||||
)
|
||||
"""原始异常摘要(exc_text),供结构化消费;已嵌入 message 中"""
|
||||
log_color_in_hex: Optional[str] = Field(default=None, description="日志颜色的十六进制字符串(如 #RRGGBB)")
|
||||
|
||||
@property
|
||||
def levelname(self) -> str:
|
||||
@@ -262,6 +418,5 @@ class LogEntry(BaseModel):
|
||||
class LogBatchPayload(BaseModel):
|
||||
"""runner.log_batch 事件 payload:Runner 端向 Host 批量推送日志记录"""
|
||||
|
||||
entries: List[LogEntry] = Field(
|
||||
description="本批次日志记录列表,按时间升序排列",
|
||||
)
|
||||
entries: List[LogEntry] = Field(description="本批次日志记录列表,按时间升序排列")
|
||||
"""本批次日志记录列表,按时间升序排列"""
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user