feat: enhance session ID calculation and plugin management
- Updated `calculate_session_id` method in `SessionUtils` to include optional `account_id` and `scope` parameters for more granular session ID generation. - Added new environment variables in `plugin_runtime` for external plugin dependencies and global configuration snapshots. - Introduced methods in `RuntimeComponentManagerProtocol` for loading and reloading plugins globally, accommodating external dependencies. - Enhanced `PluginRunnerSupervisor` to manage external available plugin IDs during plugin reloads. - Implemented dependency extraction and management in `PluginRuntimeManager` to handle cross-supervisor dependencies. - Added tests for session ID calculation and message registration in `ChatManager` to ensure correct behavior with new parameters.
This commit is contained in:
@@ -3,6 +3,7 @@
|
|||||||
验证协议层、传输层、RPC 通信链路的正确性。
|
验证协议层、传输层、RPC 通信链路的正确性。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
@@ -2362,6 +2363,8 @@ class TestIntegration:
|
|||||||
from src.plugin_runtime import integration as integration_module
|
from src.plugin_runtime import integration as integration_module
|
||||||
|
|
||||||
instances = []
|
instances = []
|
||||||
|
builtin_dir = Path("builtin")
|
||||||
|
thirdparty_dir = Path("thirdparty")
|
||||||
|
|
||||||
class FakeCapabilityService:
|
class FakeCapabilityService:
|
||||||
def register_capability(self, name, impl):
|
def register_capability(self, name, impl):
|
||||||
@@ -2369,11 +2372,18 @@ class TestIntegration:
|
|||||||
|
|
||||||
class FakeSupervisor:
|
class FakeSupervisor:
|
||||||
def __init__(self, plugin_dirs=None, socket_path=None):
|
def __init__(self, plugin_dirs=None, socket_path=None):
|
||||||
self.plugin_dirs = plugin_dirs or []
|
self._plugin_dirs = plugin_dirs or []
|
||||||
self.capability_service = FakeCapabilityService()
|
self.capability_service = FakeCapabilityService()
|
||||||
|
self.external_plugin_ids = []
|
||||||
self.stopped = False
|
self.stopped = False
|
||||||
instances.append(self)
|
instances.append(self)
|
||||||
|
|
||||||
|
def set_external_available_plugin_ids(self, plugin_ids):
|
||||||
|
self.external_plugin_ids = list(plugin_ids)
|
||||||
|
|
||||||
|
def get_loaded_plugin_ids(self):
|
||||||
|
return []
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
if len(instances) == 2 and self is instances[1]:
|
if len(instances) == 2 and self is instances[1]:
|
||||||
raise RuntimeError("boom")
|
raise RuntimeError("boom")
|
||||||
@@ -2382,10 +2392,10 @@ class TestIntegration:
|
|||||||
self.stopped = True
|
self.stopped = True
|
||||||
|
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
integration_module.PluginRuntimeManager, "_get_builtin_plugin_dirs", staticmethod(lambda: ["builtin"])
|
integration_module.PluginRuntimeManager, "_get_builtin_plugin_dirs", staticmethod(lambda: [builtin_dir])
|
||||||
)
|
)
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
integration_module.PluginRuntimeManager, "_get_thirdparty_plugin_dirs", staticmethod(lambda: ["thirdparty"])
|
integration_module.PluginRuntimeManager, "_get_third_party_plugin_dirs", staticmethod(lambda: [thirdparty_dir])
|
||||||
)
|
)
|
||||||
|
|
||||||
import src.plugin_runtime.host.supervisor as supervisor_module
|
import src.plugin_runtime.host.supervisor as supervisor_module
|
||||||
@@ -2427,8 +2437,11 @@ class TestIntegration:
|
|||||||
self.reload_reasons = []
|
self.reload_reasons = []
|
||||||
self.config_updates = []
|
self.config_updates = []
|
||||||
|
|
||||||
async def reload_plugins(self, plugin_ids=None, reason="manual"):
|
def get_loaded_plugin_ids(self):
|
||||||
self.reload_reasons.append((plugin_ids, reason))
|
return sorted(self._registered_plugins.keys())
|
||||||
|
|
||||||
|
async def reload_plugins(self, plugin_ids=None, reason="manual", external_available_plugins=None):
|
||||||
|
self.reload_reasons.append((plugin_ids, reason, external_available_plugins or []))
|
||||||
|
|
||||||
async def notify_plugin_config_updated(self, plugin_id, config_data, config_version=""):
|
async def notify_plugin_config_updated(self, plugin_id, config_data, config_version=""):
|
||||||
self.config_updates.append((plugin_id, config_data, config_version))
|
self.config_updates.append((plugin_id, config_data, config_version))
|
||||||
@@ -2453,11 +2466,59 @@ class TestIntegration:
|
|||||||
await manager._handle_plugin_source_changes(changes)
|
await manager._handle_plugin_source_changes(changes)
|
||||||
|
|
||||||
assert manager._builtin_supervisor.reload_reasons == []
|
assert manager._builtin_supervisor.reload_reasons == []
|
||||||
assert manager._third_party_supervisor.reload_reasons == [(["beta"], "file_watcher")]
|
assert manager._third_party_supervisor.reload_reasons == [(["beta"], "file_watcher", ["alpha"])]
|
||||||
assert manager._builtin_supervisor.config_updates == []
|
assert manager._builtin_supervisor.config_updates == []
|
||||||
assert manager._third_party_supervisor.config_updates == []
|
assert manager._third_party_supervisor.config_updates == []
|
||||||
assert refresh_calls == [True]
|
assert refresh_calls == [True]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reload_plugins_globally_warns_and_skips_cross_supervisor_dependents(self, monkeypatch):
|
||||||
|
from src.plugin_runtime import integration as integration_module
|
||||||
|
|
||||||
|
class FakeRegistration:
|
||||||
|
def __init__(self, dependencies):
|
||||||
|
self.dependencies = dependencies
|
||||||
|
|
||||||
|
class FakeSupervisor:
|
||||||
|
def __init__(self, registrations):
|
||||||
|
self._registered_plugins = registrations
|
||||||
|
self.reload_calls = []
|
||||||
|
|
||||||
|
def get_loaded_plugin_ids(self):
|
||||||
|
return sorted(self._registered_plugins.keys())
|
||||||
|
|
||||||
|
async def reload_plugins(self, plugin_ids=None, reason="manual", external_available_plugins=None):
|
||||||
|
self.reload_calls.append((plugin_ids, reason, sorted(external_available_plugins or [])))
|
||||||
|
return True
|
||||||
|
|
||||||
|
builtin_supervisor = FakeSupervisor({"alpha": FakeRegistration([])})
|
||||||
|
third_party_supervisor = FakeSupervisor(
|
||||||
|
{
|
||||||
|
"beta": FakeRegistration(["alpha"]),
|
||||||
|
"gamma": FakeRegistration(["beta"]),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
manager = integration_module.PluginRuntimeManager()
|
||||||
|
manager._builtin_supervisor = builtin_supervisor
|
||||||
|
manager._third_party_supervisor = third_party_supervisor
|
||||||
|
warning_messages = []
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
integration_module.logger,
|
||||||
|
"warning",
|
||||||
|
lambda message: warning_messages.append(message),
|
||||||
|
)
|
||||||
|
|
||||||
|
reloaded = await manager.reload_plugins_globally(["alpha"], reason="manual")
|
||||||
|
|
||||||
|
assert reloaded is True
|
||||||
|
assert builtin_supervisor.reload_calls == [(["alpha"], "manual", ["beta", "gamma"])]
|
||||||
|
assert third_party_supervisor.reload_calls == []
|
||||||
|
assert len(warning_messages) == 1
|
||||||
|
assert "beta, gamma" in warning_messages[0]
|
||||||
|
assert "跨 Supervisor API 调用仍然可用" in warning_messages[0]
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_handle_plugin_config_changes_only_notify_target_plugin(self, monkeypatch, tmp_path):
|
async def test_handle_plugin_config_changes_only_notify_target_plugin(self, monkeypatch, tmp_path):
|
||||||
from src.plugin_runtime import integration as integration_module
|
from src.plugin_runtime import integration as integration_module
|
||||||
@@ -2623,55 +2684,30 @@ class TestIntegration:
|
|||||||
async def test_component_reload_plugin_returns_failure_when_reload_rolls_back(self, monkeypatch):
|
async def test_component_reload_plugin_returns_failure_when_reload_rolls_back(self, monkeypatch):
|
||||||
from src.plugin_runtime import integration as integration_module
|
from src.plugin_runtime import integration as integration_module
|
||||||
|
|
||||||
class FakeSupervisor:
|
manager = integration_module.PluginRuntimeManager()
|
||||||
def __init__(self):
|
monkeypatch.setattr(manager, "reload_plugins_globally", lambda plugin_ids, reason="manual": asyncio.sleep(0, False))
|
||||||
self._registered_plugins = {"alpha": object()}
|
|
||||||
|
|
||||||
async def reload_plugins(self, reason="manual"):
|
result = await manager._cap_component_reload_plugin(
|
||||||
return False
|
|
||||||
|
|
||||||
class FakeManager:
|
|
||||||
def __init__(self):
|
|
||||||
self.supervisors = [FakeSupervisor()]
|
|
||||||
|
|
||||||
monkeypatch.setattr(integration_module, "get_plugin_runtime_manager", lambda: FakeManager())
|
|
||||||
|
|
||||||
result = await integration_module.PluginRuntimeManager._cap_component_reload_plugin(
|
|
||||||
"plugin_a",
|
"plugin_a",
|
||||||
"component.reload_plugin",
|
"component.reload_plugin",
|
||||||
{"plugin_name": "alpha"},
|
{"plugin_name": "alpha"},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result["success"] is False
|
assert result["success"] is False
|
||||||
assert "已回滚" in result["error"]
|
assert result["error"] == "插件 alpha 热重载失败"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_component_load_plugin_returns_failure_when_reload_rolls_back(self, monkeypatch, tmp_path):
|
async def test_component_load_plugin_returns_failure_when_reload_rolls_back(self, monkeypatch, tmp_path):
|
||||||
from src.plugin_runtime import integration as integration_module
|
from src.plugin_runtime import integration as integration_module
|
||||||
|
|
||||||
plugin_root = tmp_path / "plugins"
|
manager = integration_module.PluginRuntimeManager()
|
||||||
plugin_root.mkdir()
|
monkeypatch.setattr(manager, "load_plugin_globally", lambda plugin_id, reason="manual": asyncio.sleep(0, False))
|
||||||
(plugin_root / "alpha").mkdir()
|
|
||||||
|
|
||||||
class FakeSupervisor:
|
result = await manager._cap_component_load_plugin(
|
||||||
def __init__(self):
|
|
||||||
self._registered_plugins = {}
|
|
||||||
self._plugin_dirs = [str(plugin_root)]
|
|
||||||
|
|
||||||
async def reload_plugins(self, reason="manual"):
|
|
||||||
return False
|
|
||||||
|
|
||||||
class FakeManager:
|
|
||||||
def __init__(self):
|
|
||||||
self.supervisors = [FakeSupervisor()]
|
|
||||||
|
|
||||||
monkeypatch.setattr(integration_module, "get_plugin_runtime_manager", lambda: FakeManager())
|
|
||||||
|
|
||||||
result = await integration_module.PluginRuntimeManager._cap_component_load_plugin(
|
|
||||||
"plugin_a",
|
"plugin_a",
|
||||||
"component.load_plugin",
|
"component.load_plugin",
|
||||||
{"plugin_name": "alpha"},
|
{"plugin_name": "alpha"},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result["success"] is False
|
assert result["success"] is False
|
||||||
assert "已回滚" in result["error"]
|
assert result["error"] == "插件 alpha 热重载失败"
|
||||||
|
|||||||
42
pytests/utils_test/test_session_utils.py
Normal file
42
pytests/utils_test/test_session_utils.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
from src.chat.message_receive.chat_manager import ChatManager
|
||||||
|
from src.common.utils.utils_session import SessionUtils
|
||||||
|
|
||||||
|
|
||||||
|
def test_calculate_session_id_distinguishes_account_and_scope() -> None:
|
||||||
|
base_session_id = SessionUtils.calculate_session_id("qq", user_id="42")
|
||||||
|
same_base_session_id = SessionUtils.calculate_session_id("qq", user_id="42")
|
||||||
|
account_scoped_session_id = SessionUtils.calculate_session_id("qq", user_id="42", account_id="123")
|
||||||
|
route_scoped_session_id = SessionUtils.calculate_session_id("qq", user_id="42", account_id="123", scope="main")
|
||||||
|
|
||||||
|
assert base_session_id == same_base_session_id
|
||||||
|
assert account_scoped_session_id != base_session_id
|
||||||
|
assert route_scoped_session_id != account_scoped_session_id
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_manager_register_message_uses_route_metadata() -> None:
|
||||||
|
chat_manager = ChatManager()
|
||||||
|
message = SimpleNamespace(
|
||||||
|
platform="qq",
|
||||||
|
session_id="",
|
||||||
|
message_info=SimpleNamespace(
|
||||||
|
user_info=SimpleNamespace(user_id="42"),
|
||||||
|
group_info=SimpleNamespace(group_id="1000"),
|
||||||
|
additional_config={
|
||||||
|
"platform_io_account_id": "123",
|
||||||
|
"platform_io_scope": "main",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
chat_manager.register_message(message)
|
||||||
|
|
||||||
|
assert message.session_id == SessionUtils.calculate_session_id(
|
||||||
|
"qq",
|
||||||
|
user_id="42",
|
||||||
|
group_id="1000",
|
||||||
|
account_id="123",
|
||||||
|
scope="main",
|
||||||
|
)
|
||||||
|
assert chat_manager.last_messages[message.session_id] is message
|
||||||
@@ -10,6 +10,7 @@ from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiv
|
|||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.common.utils.utils_message import MessageUtils
|
from src.common.utils.utils_message import MessageUtils
|
||||||
from src.common.utils.utils_session import SessionUtils
|
from src.common.utils.utils_session import SessionUtils
|
||||||
|
from src.platform_io.route_key_factory import RouteKeyFactory
|
||||||
|
|
||||||
# from src.chat.brain_chat.PFC.pfc_manager import PFCManager
|
# from src.chat.brain_chat.PFC.pfc_manager import PFCManager
|
||||||
from src.core.announcement_manager import global_announcement_manager
|
from src.core.announcement_manager import global_announcement_manager
|
||||||
@@ -270,11 +271,18 @@ class ChatBot:
|
|||||||
try:
|
try:
|
||||||
group_info = message.message_info.group_info
|
group_info = message.message_info.group_info
|
||||||
user_info = message.message_info.user_info
|
user_info = message.message_info.user_info
|
||||||
|
account_id = None
|
||||||
|
scope = None
|
||||||
|
additional_config = message.message_info.additional_config
|
||||||
|
if isinstance(additional_config, dict):
|
||||||
|
account_id, scope = RouteKeyFactory.extract_components(additional_config)
|
||||||
|
|
||||||
session_id = SessionUtils.calculate_session_id(
|
session_id = SessionUtils.calculate_session_id(
|
||||||
message.platform,
|
message.platform,
|
||||||
user_id=message.message_info.user_info.user_id,
|
user_id=message.message_info.user_info.user_id,
|
||||||
group_id=group_info.group_id if group_info else None,
|
group_id=group_info.group_id if group_info else None,
|
||||||
|
account_id=account_id,
|
||||||
|
scope=scope,
|
||||||
)
|
)
|
||||||
|
|
||||||
message.session_id = session_id # 正确初始化session_id
|
message.session_id = session_id # 正确初始化session_id
|
||||||
@@ -317,7 +325,13 @@ class ChatBot:
|
|||||||
platform = message.platform
|
platform = message.platform
|
||||||
user_id = user_info.user_id
|
user_id = user_info.user_id
|
||||||
group_id = group_info.group_id if group_info else None
|
group_id = group_info.group_id if group_info else None
|
||||||
_ = await chat_manager.get_or_create_session(platform, user_id, group_id) # 确保会话存在
|
_ = await chat_manager.get_or_create_session(
|
||||||
|
platform,
|
||||||
|
user_id,
|
||||||
|
group_id,
|
||||||
|
account_id=account_id,
|
||||||
|
scope=scope,
|
||||||
|
) # 确保会话存在
|
||||||
|
|
||||||
# message.update_chat_stream(chat)
|
# message.update_chat_stream(chat)
|
||||||
|
|
||||||
|
|||||||
@@ -1,15 +1,16 @@
|
|||||||
|
import asyncio
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from typing import TYPE_CHECKING, Dict, List, Optional
|
||||||
|
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
from sqlmodel import select
|
from sqlmodel import select
|
||||||
from typing import Optional, TYPE_CHECKING, List, Dict
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
from src.common.data_models.chat_session_data_model import MaiChatSession
|
from src.common.data_models.chat_session_data_model import MaiChatSession
|
||||||
from src.common.database.database_model import ChatSession
|
|
||||||
from src.common.database.database import get_db_session
|
from src.common.database.database import get_db_session
|
||||||
|
from src.common.database.database_model import ChatSession
|
||||||
|
from src.common.logger import get_logger
|
||||||
from src.common.utils.utils_session import SessionUtils
|
from src.common.utils.utils_session import SessionUtils
|
||||||
|
from src.platform_io.route_key_factory import RouteKeyFactory
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .message import SessionMessage
|
from .message import SessionMessage
|
||||||
@@ -82,7 +83,12 @@ class ChatManager:
|
|||||||
logger.error(f"初始化聊天管理器出现错误: {e}")
|
logger.error(f"初始化聊天管理器出现错误: {e}")
|
||||||
|
|
||||||
async def get_or_create_session(
|
async def get_or_create_session(
|
||||||
self, platform: str, user_id: str, group_id: Optional[str] = None
|
self,
|
||||||
|
platform: str,
|
||||||
|
user_id: str,
|
||||||
|
group_id: Optional[str] = None,
|
||||||
|
account_id: Optional[str] = None,
|
||||||
|
scope: Optional[str] = None,
|
||||||
) -> BotChatSession:
|
) -> BotChatSession:
|
||||||
"""获取会话,如果不存在则创建一个新会话;一个封装方法。
|
"""获取会话,如果不存在则创建一个新会话;一个封装方法。
|
||||||
|
|
||||||
@@ -90,12 +96,20 @@ class ChatManager:
|
|||||||
platform: 平台
|
platform: 平台
|
||||||
user_id: 用户ID
|
user_id: 用户ID
|
||||||
group_id: 群ID(如果是群聊)
|
group_id: 群ID(如果是群聊)
|
||||||
|
account_id: 平台账号 ID
|
||||||
|
scope: 路由作用域
|
||||||
Returns:
|
Returns:
|
||||||
return (BotChatSession) 会话对象
|
return (BotChatSession) 会话对象
|
||||||
Raises:
|
Raises:
|
||||||
Exception: 获取或创建会话时发生错误
|
Exception: 获取或创建会话时发生错误
|
||||||
"""
|
"""
|
||||||
session_id = SessionUtils.calculate_session_id(platform, user_id=user_id, group_id=group_id)
|
session_id = SessionUtils.calculate_session_id(
|
||||||
|
platform,
|
||||||
|
user_id=user_id,
|
||||||
|
group_id=group_id,
|
||||||
|
account_id=account_id,
|
||||||
|
scope=scope,
|
||||||
|
)
|
||||||
if session := self.get_session_by_session_id(session_id):
|
if session := self.get_session_by_session_id(session_id):
|
||||||
session.update_active_time()
|
session.update_active_time()
|
||||||
return session
|
return session
|
||||||
@@ -131,7 +145,18 @@ class ChatManager:
|
|||||||
raise ValueError("消息缺少平台信息")
|
raise ValueError("消息缺少平台信息")
|
||||||
user_id = message.message_info.user_info.user_id
|
user_id = message.message_info.user_info.user_id
|
||||||
group_id = message.message_info.group_info.group_id if message.message_info.group_info else None
|
group_id = message.message_info.group_info.group_id if message.message_info.group_info else None
|
||||||
session_id = SessionUtils.calculate_session_id(platform, user_id=user_id, group_id=group_id)
|
account_id = None
|
||||||
|
scope = None
|
||||||
|
additional_config = message.message_info.additional_config
|
||||||
|
if isinstance(additional_config, dict):
|
||||||
|
account_id, scope = RouteKeyFactory.extract_components(additional_config)
|
||||||
|
session_id = SessionUtils.calculate_session_id(
|
||||||
|
platform,
|
||||||
|
user_id=user_id,
|
||||||
|
group_id=group_id,
|
||||||
|
account_id=account_id,
|
||||||
|
scope=scope,
|
||||||
|
)
|
||||||
message.session_id = session_id # 确保消息的session_id正确设置
|
message.session_id = session_id # 确保消息的session_id正确设置
|
||||||
self.last_messages[session_id] = message
|
self.last_messages[session_id] = message
|
||||||
|
|
||||||
@@ -188,7 +213,12 @@ class ChatManager:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def get_session_by_info(
|
def get_session_by_info(
|
||||||
self, platform: str, user_id: Optional[str] = None, group_id: Optional[str] = None
|
self,
|
||||||
|
platform: str,
|
||||||
|
user_id: Optional[str] = None,
|
||||||
|
group_id: Optional[str] = None,
|
||||||
|
account_id: Optional[str] = None,
|
||||||
|
scope: Optional[str] = None,
|
||||||
) -> Optional[BotChatSession]:
|
) -> Optional[BotChatSession]:
|
||||||
"""根据平台、用户ID和群ID获取对应的会话
|
"""根据平台、用户ID和群ID获取对应的会话
|
||||||
|
|
||||||
@@ -196,10 +226,18 @@ class ChatManager:
|
|||||||
platform: 平台
|
platform: 平台
|
||||||
user_id: 用户ID
|
user_id: 用户ID
|
||||||
group_id: 群ID(如果是群聊)
|
group_id: 群ID(如果是群聊)
|
||||||
|
account_id: 平台账号 ID
|
||||||
|
scope: 路由作用域
|
||||||
Returns:
|
Returns:
|
||||||
return (Optional[BotChatSession]): 会话对象,如果不存在则返回None
|
return (Optional[BotChatSession]): 会话对象,如果不存在则返回None
|
||||||
"""
|
"""
|
||||||
session_id = SessionUtils.calculate_session_id(platform, user_id=user_id, group_id=group_id)
|
session_id = SessionUtils.calculate_session_id(
|
||||||
|
platform,
|
||||||
|
user_id=user_id,
|
||||||
|
group_id=group_id,
|
||||||
|
account_id=account_id,
|
||||||
|
scope=scope,
|
||||||
|
)
|
||||||
return self.get_session_by_session_id(session_id)
|
return self.get_session_by_session_id(session_id)
|
||||||
|
|
||||||
def get_session_by_session_id(self, session_id: str) -> Optional[BotChatSession]:
|
def get_session_by_session_id(self, session_id: str) -> Optional[BotChatSession]:
|
||||||
|
|||||||
@@ -5,13 +5,22 @@ import hashlib
|
|||||||
|
|
||||||
class SessionUtils:
|
class SessionUtils:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def calculate_session_id(platform: str, *, user_id: Optional[str] = None, group_id: Optional[str] = None) -> str:
|
def calculate_session_id(
|
||||||
|
platform: str,
|
||||||
|
*,
|
||||||
|
user_id: Optional[str] = None,
|
||||||
|
group_id: Optional[str] = None,
|
||||||
|
account_id: Optional[str] = None,
|
||||||
|
scope: Optional[str] = None,
|
||||||
|
) -> str:
|
||||||
"""计算session_id
|
"""计算session_id
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
platform: 平台名称
|
platform: 平台名称
|
||||||
user_id: 用户ID(如果是私聊)
|
user_id: 用户ID(如果是私聊)
|
||||||
group_id: 群ID(如果是群聊)
|
group_id: 群ID(如果是群聊)
|
||||||
|
account_id: 当前平台账号 ID,可选
|
||||||
|
scope: 当前路由作用域,可选
|
||||||
Returns:
|
Returns:
|
||||||
str: 计算得到的会话ID
|
str: 计算得到的会话ID
|
||||||
Raises:
|
Raises:
|
||||||
@@ -19,8 +28,15 @@ class SessionUtils:
|
|||||||
"""
|
"""
|
||||||
if not user_id and not group_id:
|
if not user_id and not group_id:
|
||||||
raise ValueError("UserID 或 GroupID 必须提供其一")
|
raise ValueError("UserID 或 GroupID 必须提供其一")
|
||||||
|
|
||||||
|
route_components = []
|
||||||
|
if account_id:
|
||||||
|
route_components.append(f"account:{account_id}")
|
||||||
|
if scope:
|
||||||
|
route_components.append(f"scope:{scope}")
|
||||||
|
|
||||||
if group_id:
|
if group_id:
|
||||||
components = [platform, group_id]
|
components = [platform, *route_components, group_id]
|
||||||
else:
|
else:
|
||||||
components = [platform, user_id, "private"]
|
components = [platform, *route_components, user_id, "private"]
|
||||||
return hashlib.md5("_".join(components).encode()).hexdigest()
|
return hashlib.md5("_".join(components).encode()).hexdigest()
|
||||||
|
|||||||
@@ -16,3 +16,9 @@ ENV_PLUGIN_DIRS = "MAIBOT_PLUGIN_DIRS"
|
|||||||
|
|
||||||
ENV_HOST_VERSION = "MAIBOT_HOST_VERSION"
|
ENV_HOST_VERSION = "MAIBOT_HOST_VERSION"
|
||||||
"""Runner 读取的 Host 应用版本号,用于 manifest 兼容性校验"""
|
"""Runner 读取的 Host 应用版本号,用于 manifest 兼容性校验"""
|
||||||
|
|
||||||
|
ENV_EXTERNAL_PLUGIN_IDS = "MAIBOT_EXTERNAL_PLUGIN_IDS"
|
||||||
|
"""Runner 启动时可视为已满足的外部插件依赖列表(JSON 数组)"""
|
||||||
|
|
||||||
|
ENV_GLOBAL_CONFIG_SNAPSHOT = "MAIBOT_GLOBAL_CONFIG_SNAPSHOT"
|
||||||
|
"""Runner 启动时注入的全局配置快照(JSON 对象)"""
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Protocol
|
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Protocol, Sequence
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
@@ -15,8 +15,35 @@ class _RuntimeComponentManagerProtocol(Protocol):
|
|||||||
@property
|
@property
|
||||||
def supervisors(self) -> List["PluginSupervisor"]: ...
|
def supervisors(self) -> List["PluginSupervisor"]: ...
|
||||||
|
|
||||||
|
def _normalize_component_type(self, component_type: str) -> str: ...
|
||||||
|
|
||||||
|
def _is_api_component_type(self, component_type: str) -> bool: ...
|
||||||
|
|
||||||
|
def _serialize_api_entry(self, entry: "APIEntry") -> Dict[str, Any]: ...
|
||||||
|
|
||||||
|
def _serialize_api_component_entry(self, entry: "APIEntry") -> Dict[str, Any]: ...
|
||||||
|
|
||||||
|
def _is_api_visible_to_plugin(self, entry: "APIEntry", caller_plugin_id: str) -> bool: ...
|
||||||
|
|
||||||
|
def _normalize_api_reference(self, api_name: str, version: str = "") -> tuple[str, str]: ...
|
||||||
|
|
||||||
|
def _build_api_unavailable_error(self, entry: "APIEntry") -> str: ...
|
||||||
|
|
||||||
def _get_supervisor_for_plugin(self, plugin_id: str) -> Optional["PluginSupervisor"]: ...
|
def _get_supervisor_for_plugin(self, plugin_id: str) -> Optional["PluginSupervisor"]: ...
|
||||||
|
|
||||||
|
def _resolve_api_target(
|
||||||
|
self,
|
||||||
|
caller_plugin_id: str,
|
||||||
|
api_name: str,
|
||||||
|
version: str = "",
|
||||||
|
) -> tuple[Optional["PluginSupervisor"], Optional["APIEntry"], Optional[str]]: ...
|
||||||
|
|
||||||
|
def _resolve_api_toggle_target(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
version: str = "",
|
||||||
|
) -> tuple[Optional["PluginSupervisor"], Optional["APIEntry"], Optional[str]]: ...
|
||||||
|
|
||||||
def _resolve_component_toggle_target(
|
def _resolve_component_toggle_target(
|
||||||
self, name: str, component_type: str
|
self, name: str, component_type: str
|
||||||
) -> tuple[Optional["ComponentEntry"], Optional[str]]: ...
|
) -> tuple[Optional["ComponentEntry"], Optional[str]]: ...
|
||||||
@@ -25,6 +52,10 @@ class _RuntimeComponentManagerProtocol(Protocol):
|
|||||||
|
|
||||||
def _iter_plugin_dirs(self) -> Iterable[Path]: ...
|
def _iter_plugin_dirs(self) -> Iterable[Path]: ...
|
||||||
|
|
||||||
|
async def load_plugin_globally(self, plugin_id: str, reason: str = "manual") -> bool: ...
|
||||||
|
|
||||||
|
async def reload_plugins_globally(self, plugin_ids: Sequence[str], reason: str = "manual") -> bool: ...
|
||||||
|
|
||||||
|
|
||||||
class RuntimeComponentCapabilityMixin:
|
class RuntimeComponentCapabilityMixin:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -266,20 +297,22 @@ class RuntimeComponentCapabilityMixin:
|
|||||||
version=normalized_version,
|
version=normalized_version,
|
||||||
enabled_only=False,
|
enabled_only=False,
|
||||||
)
|
)
|
||||||
if not entries:
|
if len(entries) == 1:
|
||||||
return None, None, f"未找到 API: {normalized_name}"
|
return supervisor, entries[0], None
|
||||||
if len(entries) > 1:
|
if entries:
|
||||||
return None, None, f"API {normalized_name} 存在多个版本,请显式指定 version"
|
return None, None, f"API {normalized_name} 存在多个版本,请显式指定 version"
|
||||||
return supervisor, entries[0], None
|
return None, None, f"未找到 API: {normalized_name}"
|
||||||
|
|
||||||
matches: List[tuple["PluginSupervisor", "APIEntry"]] = []
|
matches: List[tuple["PluginSupervisor", "APIEntry"]] = []
|
||||||
for supervisor in self.supervisors:
|
for supervisor in self.supervisors:
|
||||||
for entry in supervisor.api_registry.get_apis(
|
matches.extend(
|
||||||
name=normalized_name,
|
(supervisor, entry)
|
||||||
version=normalized_version,
|
for entry in supervisor.api_registry.get_apis(
|
||||||
enabled_only=False,
|
name=normalized_name,
|
||||||
):
|
version=normalized_version,
|
||||||
matches.append((supervisor, entry))
|
enabled_only=False,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if len(matches) == 1:
|
if len(matches) == 1:
|
||||||
return matches[0][0], matches[0][1], None
|
return matches[0][0], matches[0][1], None
|
||||||
@@ -453,39 +486,14 @@ class RuntimeComponentCapabilityMixin:
|
|||||||
return {"success": False, "error": f"检测到重复插件 ID,拒绝热重载: {details}"}
|
return {"success": False, "error": f"检测到重复插件 ID,拒绝热重载: {details}"}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
registered_supervisor = self._get_supervisor_for_plugin(plugin_name)
|
loaded = await self.load_plugin_globally(plugin_name, reason=f"load {plugin_name}")
|
||||||
except RuntimeError as exc:
|
except Exception as e:
|
||||||
return {"success": False, "error": str(exc)}
|
logger.error(f"[cap.component.load_plugin] 热重载失败: {e}")
|
||||||
|
return {"success": False, "error": str(e)}
|
||||||
|
|
||||||
if registered_supervisor is not None:
|
if loaded:
|
||||||
try:
|
return {"success": True, "count": 1}
|
||||||
reloaded = await registered_supervisor.reload_plugins(
|
return {"success": False, "error": f"插件 {plugin_name} 热重载失败"}
|
||||||
plugin_ids=[plugin_name],
|
|
||||||
reason=f"load {plugin_name}",
|
|
||||||
)
|
|
||||||
if reloaded:
|
|
||||||
return {"success": True, "count": 1}
|
|
||||||
return {"success": False, "error": f"插件 {plugin_name} 热重载失败,已回滚"}
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[cap.component.load_plugin] 热重载失败: {e}")
|
|
||||||
return {"success": False, "error": str(e)}
|
|
||||||
|
|
||||||
for sv in self.supervisors:
|
|
||||||
for pdir in sv._plugin_dirs:
|
|
||||||
if (pdir / plugin_name).is_dir():
|
|
||||||
try:
|
|
||||||
reloaded = await sv.reload_plugins(
|
|
||||||
plugin_ids=[plugin_name],
|
|
||||||
reason=f"load {plugin_name}",
|
|
||||||
)
|
|
||||||
if reloaded:
|
|
||||||
return {"success": True, "count": 1}
|
|
||||||
return {"success": False, "error": f"插件 {plugin_name} 热重载失败,已回滚"}
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[cap.component.load_plugin] 热重载失败: {e}")
|
|
||||||
return {"success": False, "error": str(e)}
|
|
||||||
|
|
||||||
return {"success": False, "error": f"未找到插件: {plugin_name}"}
|
|
||||||
|
|
||||||
async def _cap_component_unload_plugin(
|
async def _cap_component_unload_plugin(
|
||||||
self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any]
|
self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any]
|
||||||
@@ -507,23 +515,14 @@ class RuntimeComponentCapabilityMixin:
|
|||||||
return {"success": False, "error": f"检测到重复插件 ID,拒绝热重载: {details}"}
|
return {"success": False, "error": f"检测到重复插件 ID,拒绝热重载: {details}"}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
sv = self._get_supervisor_for_plugin(plugin_name)
|
reloaded = await self.reload_plugins_globally([plugin_name], reason=f"reload {plugin_name}")
|
||||||
except RuntimeError as exc:
|
except Exception as e:
|
||||||
return {"success": False, "error": str(exc)}
|
logger.error(f"[cap.component.reload_plugin] 热重载失败: {e}")
|
||||||
|
return {"success": False, "error": str(e)}
|
||||||
|
|
||||||
if sv is not None:
|
if reloaded:
|
||||||
try:
|
return {"success": True}
|
||||||
reloaded = await sv.reload_plugins(
|
return {"success": False, "error": f"插件 {plugin_name} 热重载失败"}
|
||||||
plugin_ids=[plugin_name],
|
|
||||||
reason=f"reload {plugin_name}",
|
|
||||||
)
|
|
||||||
if reloaded:
|
|
||||||
return {"success": True}
|
|
||||||
return {"success": False, "error": f"插件 {plugin_name} 热重载失败,已回滚"}
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[cap.component.reload_plugin] 热重载失败: {e}")
|
|
||||||
return {"success": False, "error": str(e)}
|
|
||||||
return {"success": False, "error": f"未找到插件: {plugin_name}"}
|
|
||||||
|
|
||||||
async def _cap_api_call(
|
async def _cap_api_call(
|
||||||
self: _RuntimeComponentManagerProtocol,
|
self: _RuntimeComponentManagerProtocol,
|
||||||
@@ -632,15 +631,16 @@ class RuntimeComponentCapabilityMixin:
|
|||||||
)
|
)
|
||||||
apis: List[Dict[str, Any]] = []
|
apis: List[Dict[str, Any]] = []
|
||||||
for supervisor in self.supervisors:
|
for supervisor in self.supervisors:
|
||||||
for entry in supervisor.api_registry.get_apis(
|
apis.extend(
|
||||||
plugin_id=target_plugin_id or None,
|
self._serialize_api_entry(entry)
|
||||||
name=api_name,
|
for entry in supervisor.api_registry.get_apis(
|
||||||
version=version,
|
plugin_id=target_plugin_id or None,
|
||||||
enabled_only=True,
|
name=api_name,
|
||||||
):
|
version=version,
|
||||||
if not self._is_api_visible_to_plugin(entry, plugin_id):
|
enabled_only=True,
|
||||||
continue
|
)
|
||||||
apis.append(self._serialize_api_entry(entry))
|
if self._is_api_visible_to_plugin(entry, plugin_id)
|
||||||
|
)
|
||||||
|
|
||||||
apis.sort(key=lambda item: (str(item["plugin_id"]), str(item["name"]), str(item["version"])))
|
apis.sort(key=lambda item: (str(item["plugin_id"]), str(item["name"]), str(item["version"])))
|
||||||
return {"success": True, "apis": apis}
|
return {"success": True, "apis": apis}
|
||||||
|
|||||||
@@ -4,17 +4,26 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import contextlib
|
import contextlib
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config
|
from src.config.config import config_manager, global_config
|
||||||
from src.platform_io import DriverKind, InboundMessageEnvelope, RouteBinding, RouteKey, get_platform_io_manager
|
from src.platform_io import DriverKind, InboundMessageEnvelope, RouteBinding, RouteKey, get_platform_io_manager
|
||||||
from src.platform_io.drivers import PluginPlatformDriver
|
from src.platform_io.drivers import PluginPlatformDriver
|
||||||
from src.platform_io.route_key_factory import RouteKeyFactory
|
from src.platform_io.route_key_factory import RouteKeyFactory
|
||||||
from src.plugin_runtime import ENV_HOST_VERSION, ENV_IPC_ADDRESS, ENV_PLUGIN_DIRS, ENV_SESSION_TOKEN
|
from src.plugin_runtime import (
|
||||||
|
ENV_EXTERNAL_PLUGIN_IDS,
|
||||||
|
ENV_GLOBAL_CONFIG_SNAPSHOT,
|
||||||
|
ENV_HOST_VERSION,
|
||||||
|
ENV_IPC_ADDRESS,
|
||||||
|
ENV_PLUGIN_DIRS,
|
||||||
|
ENV_SESSION_TOKEN,
|
||||||
|
)
|
||||||
from src.plugin_runtime.protocol.envelope import (
|
from src.plugin_runtime.protocol.envelope import (
|
||||||
BootstrapPluginPayload,
|
BootstrapPluginPayload,
|
||||||
|
ConfigReloadScope,
|
||||||
ConfigUpdatedPayload,
|
ConfigUpdatedPayload,
|
||||||
Envelope,
|
Envelope,
|
||||||
HealthPayload,
|
HealthPayload,
|
||||||
@@ -107,6 +116,7 @@ class PluginRunnerSupervisor:
|
|||||||
self._runner_process: Optional[asyncio.subprocess.Process] = None
|
self._runner_process: Optional[asyncio.subprocess.Process] = None
|
||||||
self._registered_plugins: Dict[str, RegisterPluginPayload] = {}
|
self._registered_plugins: Dict[str, RegisterPluginPayload] = {}
|
||||||
self._message_gateway_states: Dict[str, Dict[str, _MessageGatewayRuntimeState]] = {}
|
self._message_gateway_states: Dict[str, Dict[str, _MessageGatewayRuntimeState]] = {}
|
||||||
|
self._external_available_plugin_ids: List[str] = []
|
||||||
self._runner_ready_events: asyncio.Event = asyncio.Event()
|
self._runner_ready_events: asyncio.Event = asyncio.Event()
|
||||||
self._runner_ready_payloads: RunnerReadyPayload = RunnerReadyPayload()
|
self._runner_ready_payloads: RunnerReadyPayload = RunnerReadyPayload()
|
||||||
self._health_task: Optional[asyncio.Task[None]] = None
|
self._health_task: Optional[asyncio.Task[None]] = None
|
||||||
@@ -156,6 +166,21 @@ class PluginRunnerSupervisor:
|
|||||||
"""返回底层 RPC 服务端。"""
|
"""返回底层 RPC 服务端。"""
|
||||||
return self._rpc_server
|
return self._rpc_server
|
||||||
|
|
||||||
|
def set_external_available_plugin_ids(self, plugin_ids: List[str]) -> None:
|
||||||
|
"""设置当前 Runner 启动/重载时可视为已满足的外部依赖列表。"""
|
||||||
|
|
||||||
|
normalized_plugin_ids = {
|
||||||
|
str(plugin_id or "").strip()
|
||||||
|
for plugin_id in plugin_ids
|
||||||
|
if str(plugin_id or "").strip()
|
||||||
|
}
|
||||||
|
self._external_available_plugin_ids = sorted(normalized_plugin_ids)
|
||||||
|
|
||||||
|
def get_loaded_plugin_ids(self) -> List[str]:
|
||||||
|
"""返回当前 Supervisor 已注册的插件 ID 列表。"""
|
||||||
|
|
||||||
|
return sorted(self._registered_plugins.keys())
|
||||||
|
|
||||||
async def dispatch_event(
|
async def dispatch_event(
|
||||||
self,
|
self,
|
||||||
event_type: str,
|
event_type: str,
|
||||||
@@ -344,12 +369,18 @@ class PluginRunnerSupervisor:
|
|||||||
timeout_ms=timeout_ms,
|
timeout_ms=timeout_ms,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def reload_plugin(self, plugin_id: str, reason: str = "manual") -> bool:
|
async def reload_plugin(
|
||||||
|
self,
|
||||||
|
plugin_id: str,
|
||||||
|
reason: str = "manual",
|
||||||
|
external_available_plugins: Optional[List[str]] = None,
|
||||||
|
) -> bool:
|
||||||
"""按插件 ID 触发精确重载。
|
"""按插件 ID 触发精确重载。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
plugin_id: 目标插件 ID。
|
plugin_id: 目标插件 ID。
|
||||||
reason: 重载原因。
|
reason: 重载原因。
|
||||||
|
external_available_plugins: 视为已满足的外部依赖插件 ID 列表。
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: 是否重载成功。
|
bool: 是否重载成功。
|
||||||
@@ -358,7 +389,11 @@ class PluginRunnerSupervisor:
|
|||||||
response = await self._rpc_server.send_request(
|
response = await self._rpc_server.send_request(
|
||||||
"plugin.reload",
|
"plugin.reload",
|
||||||
plugin_id=plugin_id,
|
plugin_id=plugin_id,
|
||||||
payload={"plugin_id": plugin_id, "reason": reason},
|
payload={
|
||||||
|
"plugin_id": plugin_id,
|
||||||
|
"reason": reason,
|
||||||
|
"external_available_plugins": external_available_plugins or self._external_available_plugin_ids,
|
||||||
|
},
|
||||||
timeout_ms=max(int(self._runner_spawn_timeout * 1000), 10000),
|
timeout_ms=max(int(self._runner_spawn_timeout * 1000), 10000),
|
||||||
)
|
)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
@@ -374,12 +409,14 @@ class PluginRunnerSupervisor:
|
|||||||
self,
|
self,
|
||||||
plugin_ids: Optional[List[str]] = None,
|
plugin_ids: Optional[List[str]] = None,
|
||||||
reason: str = "manual",
|
reason: str = "manual",
|
||||||
|
external_available_plugins: Optional[List[str]] = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""批量重载插件。
|
"""批量重载插件。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
plugin_ids: 目标插件 ID 列表;为空时重载当前已注册的全部插件。
|
plugin_ids: 目标插件 ID 列表;为空时重载当前已注册的全部插件。
|
||||||
reason: 重载原因。
|
reason: 重载原因。
|
||||||
|
external_available_plugins: 视为已满足的外部依赖插件 ID 列表。
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: 是否全部重载成功。
|
bool: 是否全部重载成功。
|
||||||
@@ -389,7 +426,11 @@ class PluginRunnerSupervisor:
|
|||||||
success = True
|
success = True
|
||||||
|
|
||||||
for plugin_id in ordered_plugin_ids:
|
for plugin_id in ordered_plugin_ids:
|
||||||
reloaded = await self.reload_plugin(plugin_id=plugin_id, reason=reason)
|
reloaded = await self.reload_plugin(
|
||||||
|
plugin_id=plugin_id,
|
||||||
|
reason=reason,
|
||||||
|
external_available_plugins=external_available_plugins,
|
||||||
|
)
|
||||||
success = success and reloaded
|
success = success and reloaded
|
||||||
|
|
||||||
return success
|
return success
|
||||||
@@ -399,7 +440,7 @@ class PluginRunnerSupervisor:
|
|||||||
plugin_id: str,
|
plugin_id: str,
|
||||||
config_data: Optional[Dict[str, Any]] = None,
|
config_data: Optional[Dict[str, Any]] = None,
|
||||||
config_version: str = "",
|
config_version: str = "",
|
||||||
config_scope: str = "self",
|
config_scope: str | ConfigReloadScope = "self",
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""向 Runner 推送插件配置更新。
|
"""向 Runner 推送插件配置更新。
|
||||||
|
|
||||||
@@ -412,9 +453,15 @@ class PluginRunnerSupervisor:
|
|||||||
Returns:
|
Returns:
|
||||||
bool: 请求是否成功送达并被 Runner 接受。
|
bool: 请求是否成功送达并被 Runner 接受。
|
||||||
"""
|
"""
|
||||||
|
try:
|
||||||
|
normalized_scope = ConfigReloadScope(config_scope)
|
||||||
|
except ValueError:
|
||||||
|
logger.warning(f"插件 {plugin_id} 配置更新通知失败: 非法的 config_scope={config_scope}")
|
||||||
|
return False
|
||||||
|
|
||||||
payload = ConfigUpdatedPayload(
|
payload = ConfigUpdatedPayload(
|
||||||
plugin_id=plugin_id,
|
plugin_id=plugin_id,
|
||||||
config_scope=config_scope,
|
config_scope=normalized_scope,
|
||||||
config_version=config_version,
|
config_version=config_version,
|
||||||
config_data=config_data or {},
|
config_data=config_data or {},
|
||||||
)
|
)
|
||||||
@@ -441,11 +488,11 @@ class PluginRunnerSupervisor:
|
|||||||
List[str]: 已声明订阅该范围的插件 ID 列表。
|
List[str]: 已声明订阅该范围的插件 ID 列表。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
matched_plugins: List[str] = []
|
return [
|
||||||
for plugin_id, registration in self._registered_plugins.items():
|
plugin_id
|
||||||
if scope in registration.config_reload_subscriptions:
|
for plugin_id, registration in self._registered_plugins.items()
|
||||||
matched_plugins.append(plugin_id)
|
if scope in registration.config_reload_subscriptions
|
||||||
return matched_plugins
|
]
|
||||||
|
|
||||||
async def _wait_for_runner_connection(self, timeout_sec: float) -> None:
|
async def _wait_for_runner_connection(self, timeout_sec: float) -> None:
|
||||||
"""等待 Runner 建立 RPC 连接。
|
"""等待 Runner 建立 RPC 连接。
|
||||||
@@ -706,10 +753,7 @@ class PluginRunnerSupervisor:
|
|||||||
)
|
)
|
||||||
|
|
||||||
gateways = self._component_registry.get_message_gateways(plugin_id=plugin_id, enabled_only=False)
|
gateways = self._component_registry.get_message_gateways(plugin_id=plugin_id, enabled_only=False)
|
||||||
if len(gateways) == 1:
|
return gateways[0] if len(gateways) == 1 else None
|
||||||
return gateways[0]
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def _register_message_gateway_driver(
|
async def _register_message_gateway_driver(
|
||||||
self,
|
self,
|
||||||
@@ -823,8 +867,7 @@ class PluginRunnerSupervisor:
|
|||||||
ValueError: 当平台信息缺失时抛出。
|
ValueError: 当平台信息缺失时抛出。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
platform = str(payload.platform or gateway_entry.platform or "").strip()
|
if not (platform := str(payload.platform or gateway_entry.platform or "").strip()):
|
||||||
if not platform:
|
|
||||||
raise ValueError(f"消息网关 {gateway_entry.full_name} 未提供有效的平台名称")
|
raise ValueError(f"消息网关 {gateway_entry.full_name} 未提供有效的平台名称")
|
||||||
|
|
||||||
return RouteKey(
|
return RouteKey(
|
||||||
@@ -1090,7 +1133,11 @@ class PluginRunnerSupervisor:
|
|||||||
Returns:
|
Returns:
|
||||||
Dict[str, str]: 传递给 Runner 进程的环境变量映射。
|
Dict[str, str]: 传递给 Runner 进程的环境变量映射。
|
||||||
"""
|
"""
|
||||||
|
global_config_snapshot = config_manager.get_global_config().model_dump()
|
||||||
|
global_config_snapshot["model"] = config_manager.get_model_config().model_dump()
|
||||||
return {
|
return {
|
||||||
|
ENV_EXTERNAL_PLUGIN_IDS: json.dumps(self._external_available_plugin_ids, ensure_ascii=False),
|
||||||
|
ENV_GLOBAL_CONFIG_SNAPSHOT: json.dumps(global_config_snapshot, ensure_ascii=False),
|
||||||
ENV_HOST_VERSION: PROTOCOL_VERSION,
|
ENV_HOST_VERSION: PROTOCOL_VERSION,
|
||||||
ENV_IPC_ADDRESS: self._transport.get_address(),
|
ENV_IPC_ADDRESS: self._transport.get_address(),
|
||||||
ENV_PLUGIN_DIRS: os.pathsep.join(str(path) for path in self._plugin_dirs),
|
ENV_PLUGIN_DIRS: os.pathsep.join(str(path) for path in self._plugin_dirs),
|
||||||
@@ -1136,8 +1183,7 @@ class PluginRunnerSupervisor:
|
|||||||
line = await stream.readline()
|
line = await stream.readline()
|
||||||
if not line:
|
if not line:
|
||||||
return
|
return
|
||||||
message = line.decode("utf-8", errors="replace").rstrip()
|
if message := line.decode("utf-8", errors="replace").rstrip():
|
||||||
if message:
|
|
||||||
logger.warning(f"[runner-stderr] {message}")
|
logger.warning(f"[runner-stderr] {message}")
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
raise
|
raise
|
||||||
|
|||||||
@@ -8,7 +8,7 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Coroutine, Dict, Iterable, List, Optional, Sequence, Tuple
|
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Coroutine, Dict, Iterable, List, Optional, Sequence, Set, Tuple
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
@@ -102,6 +102,77 @@ class PluginRuntimeManager(
|
|||||||
candidate = Path("plugins").resolve()
|
candidate = Path("plugins").resolve()
|
||||||
return [candidate] if candidate.is_dir() else []
|
return [candidate] if candidate.is_dir() else []
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _extract_manifest_dependencies(manifest: Dict[str, Any]) -> List[str]:
|
||||||
|
"""从插件 manifest 中提取规范化后的依赖插件 ID 列表。"""
|
||||||
|
|
||||||
|
dependencies: List[str] = []
|
||||||
|
for dependency in manifest.get("dependencies", []):
|
||||||
|
if isinstance(dependency, str):
|
||||||
|
normalized_dependency = dependency.strip()
|
||||||
|
elif isinstance(dependency, dict):
|
||||||
|
normalized_dependency = str(dependency.get("name", "") or "").strip()
|
||||||
|
else:
|
||||||
|
normalized_dependency = ""
|
||||||
|
|
||||||
|
if normalized_dependency:
|
||||||
|
dependencies.append(normalized_dependency)
|
||||||
|
return dependencies
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _discover_plugin_dependency_map(cls, plugin_dirs: Iterable[Path]) -> Dict[str, List[str]]:
|
||||||
|
"""扫描指定插件目录集合,返回 ``plugin_id -> dependencies`` 映射。"""
|
||||||
|
|
||||||
|
dependency_map: Dict[str, List[str]] = {}
|
||||||
|
for plugin_dir in cls._iter_candidate_plugin_paths(plugin_dirs):
|
||||||
|
manifest_path = plugin_dir / "_manifest.json"
|
||||||
|
entrypoint_path = plugin_dir / "plugin.py"
|
||||||
|
if not manifest_path.is_file() or not entrypoint_path.is_file():
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
with manifest_path.open("r", encoding="utf-8") as manifest_file:
|
||||||
|
manifest = json.load(manifest_file)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not isinstance(manifest, dict):
|
||||||
|
continue
|
||||||
|
|
||||||
|
plugin_id = str(manifest.get("name", plugin_dir.name) or "").strip() or plugin_dir.name
|
||||||
|
dependency_map[plugin_id] = cls._extract_manifest_dependencies(manifest)
|
||||||
|
return dependency_map
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _build_group_start_order(
|
||||||
|
cls,
|
||||||
|
builtin_dirs: Sequence[Path],
|
||||||
|
third_party_dirs: Sequence[Path],
|
||||||
|
) -> List[str]:
|
||||||
|
"""根据跨 Supervisor 依赖关系决定 Runner 启动顺序。"""
|
||||||
|
|
||||||
|
builtin_dependencies = cls._discover_plugin_dependency_map(builtin_dirs)
|
||||||
|
third_party_dependencies = cls._discover_plugin_dependency_map(third_party_dirs)
|
||||||
|
builtin_plugin_ids = set(builtin_dependencies)
|
||||||
|
third_party_plugin_ids = set(third_party_dependencies)
|
||||||
|
|
||||||
|
builtin_needs_third_party = any(
|
||||||
|
dependency in third_party_plugin_ids
|
||||||
|
for dependencies in builtin_dependencies.values()
|
||||||
|
for dependency in dependencies
|
||||||
|
)
|
||||||
|
third_party_needs_builtin = any(
|
||||||
|
dependency in builtin_plugin_ids
|
||||||
|
for dependencies in third_party_dependencies.values()
|
||||||
|
for dependency in dependencies
|
||||||
|
)
|
||||||
|
|
||||||
|
if builtin_needs_third_party and third_party_needs_builtin:
|
||||||
|
raise RuntimeError("检测到跨 Supervisor 循环依赖,当前无法安全启动独立 Runner")
|
||||||
|
if builtin_needs_third_party:
|
||||||
|
return ["third_party", "builtin"]
|
||||||
|
return ["builtin", "third_party"]
|
||||||
|
|
||||||
# ─── 生命周期 ─────────────────────────────────────────────
|
# ─── 生命周期 ─────────────────────────────────────────────
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
@@ -161,12 +232,26 @@ class PluginRuntimeManager(
|
|||||||
platform_io_manager.set_inbound_dispatcher(self._dispatch_platform_inbound)
|
platform_io_manager.set_inbound_dispatcher(self._dispatch_platform_inbound)
|
||||||
await platform_io_manager.ensure_send_pipeline_ready()
|
await platform_io_manager.ensure_send_pipeline_ready()
|
||||||
|
|
||||||
if self._builtin_supervisor:
|
supervisor_groups: Dict[str, Optional[PluginSupervisor]] = {
|
||||||
await self._builtin_supervisor.start()
|
"builtin": self._builtin_supervisor,
|
||||||
started_supervisors.append(self._builtin_supervisor)
|
"third_party": self._third_party_supervisor,
|
||||||
if self._third_party_supervisor:
|
}
|
||||||
await self._third_party_supervisor.start()
|
start_order = self._build_group_start_order(builtin_dirs, third_party_dirs)
|
||||||
started_supervisors.append(self._third_party_supervisor)
|
|
||||||
|
for group_name in start_order:
|
||||||
|
supervisor = supervisor_groups.get(group_name)
|
||||||
|
if supervisor is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
external_plugin_ids = [
|
||||||
|
plugin_id
|
||||||
|
for started_supervisor in started_supervisors
|
||||||
|
for plugin_id in started_supervisor.get_loaded_plugin_ids()
|
||||||
|
]
|
||||||
|
supervisor.set_external_available_plugin_ids(external_plugin_ids)
|
||||||
|
await supervisor.start()
|
||||||
|
started_supervisors.append(supervisor)
|
||||||
|
|
||||||
await self._start_plugin_file_watcher()
|
await self._start_plugin_file_watcher()
|
||||||
config_manager.register_reload_callback(self._config_reload_callback)
|
config_manager.register_reload_callback(self._config_reload_callback)
|
||||||
self._config_reload_callback_registered = True
|
self._config_reload_callback_registered = True
|
||||||
@@ -238,6 +323,171 @@ class PluginRuntimeManager(
|
|||||||
"""获取所有活跃的 Supervisor"""
|
"""获取所有活跃的 Supervisor"""
|
||||||
return [s for s in (self._builtin_supervisor, self._third_party_supervisor) if s is not None]
|
return [s for s in (self._builtin_supervisor, self._third_party_supervisor) if s is not None]
|
||||||
|
|
||||||
|
def _build_registered_dependency_map(self) -> Dict[str, Set[str]]:
|
||||||
|
"""根据当前已注册插件构建全局依赖图。"""
|
||||||
|
|
||||||
|
dependency_map: Dict[str, Set[str]] = {}
|
||||||
|
for supervisor in self.supervisors:
|
||||||
|
for plugin_id, registration in getattr(supervisor, "_registered_plugins", {}).items():
|
||||||
|
dependency_map[plugin_id] = {
|
||||||
|
str(dependency or "").strip()
|
||||||
|
for dependency in getattr(registration, "dependencies", [])
|
||||||
|
if str(dependency or "").strip()
|
||||||
|
}
|
||||||
|
return dependency_map
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _collect_reverse_dependents(
|
||||||
|
plugin_ids: Set[str],
|
||||||
|
dependency_map: Dict[str, Set[str]],
|
||||||
|
) -> Set[str]:
|
||||||
|
"""根据依赖图收集反向依赖闭包。"""
|
||||||
|
|
||||||
|
impacted_plugins: Set[str] = set(plugin_ids)
|
||||||
|
changed = True
|
||||||
|
|
||||||
|
while changed:
|
||||||
|
changed = False
|
||||||
|
for registered_plugin_id, dependencies in dependency_map.items():
|
||||||
|
if registered_plugin_id in impacted_plugins:
|
||||||
|
continue
|
||||||
|
if dependencies & impacted_plugins:
|
||||||
|
impacted_plugins.add(registered_plugin_id)
|
||||||
|
changed = True
|
||||||
|
|
||||||
|
return impacted_plugins
|
||||||
|
|
||||||
|
def _build_registered_supervisor_map(self) -> Dict[str, "PluginSupervisor"]:
|
||||||
|
"""构建当前已注册插件到所属 Supervisor 的映射。"""
|
||||||
|
|
||||||
|
return {
|
||||||
|
plugin_id: supervisor
|
||||||
|
for supervisor in self.supervisors
|
||||||
|
for plugin_id in supervisor.get_loaded_plugin_ids()
|
||||||
|
}
|
||||||
|
|
||||||
|
def _build_external_available_plugins_for_supervisor(self, target_supervisor: "PluginSupervisor") -> List[str]:
|
||||||
|
"""收集某个 Supervisor 可用的外部插件 ID 列表。"""
|
||||||
|
|
||||||
|
external_plugin_ids: Set[str] = set()
|
||||||
|
for supervisor in self.supervisors:
|
||||||
|
if supervisor is target_supervisor:
|
||||||
|
continue
|
||||||
|
external_plugin_ids.update(supervisor.get_loaded_plugin_ids())
|
||||||
|
return sorted(external_plugin_ids)
|
||||||
|
|
||||||
|
def _find_supervisor_by_plugin_directory(self, plugin_id: str) -> Optional["PluginSupervisor"]:
|
||||||
|
"""根据插件目录推断应负责该插件重载的 Supervisor。"""
|
||||||
|
|
||||||
|
for supervisor in self.supervisors:
|
||||||
|
for plugin_dir in supervisor._plugin_dirs:
|
||||||
|
if (Path(plugin_dir) / plugin_id).is_dir():
|
||||||
|
return supervisor
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _warn_skipped_cross_supervisor_reload(
|
||||||
|
self,
|
||||||
|
requested_loaded_plugin_ids: Set[str],
|
||||||
|
dependency_map: Dict[str, Set[str]],
|
||||||
|
supervisor_by_plugin: Dict[str, "PluginSupervisor"],
|
||||||
|
) -> None:
|
||||||
|
"""记录因跨 Supervisor 边界而未参与联动重载的插件。"""
|
||||||
|
|
||||||
|
if not requested_loaded_plugin_ids:
|
||||||
|
return
|
||||||
|
|
||||||
|
handled_plugin_ids: Set[str] = set()
|
||||||
|
for supervisor in self.supervisors:
|
||||||
|
local_requested_plugin_ids = {
|
||||||
|
plugin_id
|
||||||
|
for plugin_id in requested_loaded_plugin_ids
|
||||||
|
if supervisor_by_plugin.get(plugin_id) is supervisor
|
||||||
|
}
|
||||||
|
if not local_requested_plugin_ids:
|
||||||
|
continue
|
||||||
|
|
||||||
|
local_plugin_ids = set(supervisor.get_loaded_plugin_ids())
|
||||||
|
local_dependency_map = {
|
||||||
|
plugin_id: {
|
||||||
|
dependency
|
||||||
|
for dependency in dependency_map.get(plugin_id, set())
|
||||||
|
if dependency in local_plugin_ids
|
||||||
|
}
|
||||||
|
for plugin_id in local_plugin_ids
|
||||||
|
}
|
||||||
|
handled_plugin_ids.update(
|
||||||
|
self._collect_reverse_dependents(local_requested_plugin_ids, local_dependency_map)
|
||||||
|
)
|
||||||
|
|
||||||
|
impacted_plugin_ids = self._collect_reverse_dependents(requested_loaded_plugin_ids, dependency_map)
|
||||||
|
skipped_plugin_ids = sorted(impacted_plugin_ids - handled_plugin_ids)
|
||||||
|
if not skipped_plugin_ids:
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.warning(
|
||||||
|
f"插件 {', '.join(sorted(requested_loaded_plugin_ids))} 存在跨 Supervisor 依赖方未联动重载: "
|
||||||
|
f"{', '.join(skipped_plugin_ids)}。当前仅在单个 Supervisor 内执行联动重载;"
|
||||||
|
"跨 Supervisor API 调用仍然可用。如需联动重载,请将相关插件放在同一个 Supervisor 内。"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def reload_plugins_globally(self, plugin_ids: Sequence[str], reason: str = "manual") -> bool:
|
||||||
|
"""按 Supervisor 分组执行精确重载。
|
||||||
|
|
||||||
|
仅在单个 Supervisor 内执行依赖联动;跨 Supervisor 依赖方仅记录告警,
|
||||||
|
不再自动参与本次热重载。
|
||||||
|
"""
|
||||||
|
|
||||||
|
normalized_plugin_ids = [
|
||||||
|
normalized_plugin_id
|
||||||
|
for plugin_id in plugin_ids
|
||||||
|
if (normalized_plugin_id := str(plugin_id or "").strip())
|
||||||
|
]
|
||||||
|
if not normalized_plugin_ids:
|
||||||
|
return True
|
||||||
|
|
||||||
|
dependency_map = self._build_registered_dependency_map()
|
||||||
|
supervisor_by_plugin = self._build_registered_supervisor_map()
|
||||||
|
supervisor_roots: Dict["PluginSupervisor", List[str]] = {}
|
||||||
|
requested_loaded_plugin_ids: Set[str] = set()
|
||||||
|
missing_plugin_ids: List[str] = []
|
||||||
|
|
||||||
|
for plugin_id in normalized_plugin_ids:
|
||||||
|
supervisor = supervisor_by_plugin.get(plugin_id)
|
||||||
|
if supervisor is not None:
|
||||||
|
requested_loaded_plugin_ids.add(plugin_id)
|
||||||
|
else:
|
||||||
|
supervisor = self._find_supervisor_by_plugin_directory(plugin_id)
|
||||||
|
|
||||||
|
if supervisor is None:
|
||||||
|
missing_plugin_ids.append(plugin_id)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if plugin_id not in supervisor_roots.setdefault(supervisor, []):
|
||||||
|
supervisor_roots[supervisor].append(plugin_id)
|
||||||
|
|
||||||
|
if missing_plugin_ids:
|
||||||
|
logger.warning(f"以下插件未找到可重载的 Supervisor,已跳过: {', '.join(sorted(missing_plugin_ids))}")
|
||||||
|
|
||||||
|
self._warn_skipped_cross_supervisor_reload(
|
||||||
|
requested_loaded_plugin_ids=requested_loaded_plugin_ids,
|
||||||
|
dependency_map=dependency_map,
|
||||||
|
supervisor_by_plugin=supervisor_by_plugin,
|
||||||
|
)
|
||||||
|
|
||||||
|
success = True
|
||||||
|
for supervisor, root_plugin_ids in supervisor_roots.items():
|
||||||
|
if not root_plugin_ids:
|
||||||
|
continue
|
||||||
|
|
||||||
|
reloaded = await supervisor.reload_plugins(
|
||||||
|
plugin_ids=root_plugin_ids,
|
||||||
|
reason=reason,
|
||||||
|
external_available_plugins=self._build_external_available_plugins_for_supervisor(supervisor),
|
||||||
|
)
|
||||||
|
success = success and reloaded
|
||||||
|
|
||||||
|
return success and not missing_plugin_ids
|
||||||
|
|
||||||
async def notify_plugin_config_updated(
|
async def notify_plugin_config_updated(
|
||||||
self,
|
self,
|
||||||
plugin_id: str,
|
plugin_id: str,
|
||||||
@@ -465,6 +715,31 @@ class PluginRuntimeManager(
|
|||||||
raise RuntimeError(f"插件 {plugin_id} 同时存在于多个 Supervisor 中,无法安全路由")
|
raise RuntimeError(f"插件 {plugin_id} 同时存在于多个 Supervisor 中,无法安全路由")
|
||||||
return matches[0] if matches else None
|
return matches[0] if matches else None
|
||||||
|
|
||||||
|
async def load_plugin_globally(self, plugin_id: str, reason: str = "manual") -> bool:
|
||||||
|
"""加载或重载单个插件,并为其补齐跨 Supervisor 外部依赖。"""
|
||||||
|
|
||||||
|
normalized_plugin_id = str(plugin_id or "").strip()
|
||||||
|
if not normalized_plugin_id:
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
registered_supervisor = self._get_supervisor_for_plugin(normalized_plugin_id)
|
||||||
|
except RuntimeError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if registered_supervisor is not None:
|
||||||
|
return await self.reload_plugins_globally([normalized_plugin_id], reason=reason)
|
||||||
|
|
||||||
|
supervisor = self._find_supervisor_by_plugin_directory(normalized_plugin_id)
|
||||||
|
if supervisor is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return await supervisor.reload_plugins(
|
||||||
|
plugin_ids=[normalized_plugin_id],
|
||||||
|
reason=reason,
|
||||||
|
external_available_plugins=self._build_external_available_plugins_for_supervisor(supervisor),
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _find_duplicate_plugin_ids(plugin_dirs: List[Path]) -> Dict[str, List[Path]]:
|
def _find_duplicate_plugin_ids(plugin_dirs: List[Path]) -> Dict[str, List[Path]]:
|
||||||
"""扫描插件目录,找出被多个目录重复声明的插件 ID。"""
|
"""扫描插件目录,找出被多个目录重复声明的插件 ID。"""
|
||||||
@@ -729,7 +1004,7 @@ class PluginRuntimeManager(
|
|||||||
logger.error(f"检测到重复插件 ID,跳过本次插件热重载: {details}")
|
logger.error(f"检测到重复插件 ID,跳过本次插件热重载: {details}")
|
||||||
return
|
return
|
||||||
|
|
||||||
reload_supervisors: Dict[Any, List[str]] = {}
|
changed_plugin_ids: List[str] = []
|
||||||
changed_paths = [change.path.resolve() for change in changes]
|
changed_paths = [change.path.resolve() for change in changes]
|
||||||
|
|
||||||
for supervisor in self.supervisors:
|
for supervisor in self.supervisors:
|
||||||
@@ -738,14 +1013,11 @@ class PluginRuntimeManager(
|
|||||||
if plugin_id is None:
|
if plugin_id is None:
|
||||||
continue
|
continue
|
||||||
if path.name in {"plugin.py", "_manifest.json"} or path.suffix == ".py":
|
if path.name in {"plugin.py", "_manifest.json"} or path.suffix == ".py":
|
||||||
reload_supervisors.setdefault(supervisor, [])
|
if plugin_id not in changed_plugin_ids:
|
||||||
if plugin_id not in reload_supervisors[supervisor]:
|
changed_plugin_ids.append(plugin_id)
|
||||||
reload_supervisors[supervisor].append(plugin_id)
|
|
||||||
|
|
||||||
for supervisor, plugin_ids in reload_supervisors.items():
|
if changed_plugin_ids:
|
||||||
await supervisor.reload_plugins(plugin_ids=plugin_ids, reason="file_watcher")
|
await self.reload_plugins_globally(changed_plugin_ids, reason="file_watcher")
|
||||||
|
|
||||||
if reload_supervisors:
|
|
||||||
self._refresh_plugin_config_watch_subscriptions()
|
self._refresh_plugin_config_watch_subscriptions()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@@ -166,6 +166,8 @@ class RegisterPluginPayload(BaseModel):
|
|||||||
"""组件列表"""
|
"""组件列表"""
|
||||||
capabilities_required: List[str] = Field(default_factory=list, description="所需能力列表")
|
capabilities_required: List[str] = Field(default_factory=list, description="所需能力列表")
|
||||||
"""所需能力列表"""
|
"""所需能力列表"""
|
||||||
|
dependencies: List[str] = Field(default_factory=list, description="插件级依赖插件 ID 列表")
|
||||||
|
"""插件级依赖插件 ID 列表"""
|
||||||
config_reload_subscriptions: List[str] = Field(default_factory=list, description="订阅的全局配置热重载范围")
|
config_reload_subscriptions: List[str] = Field(default_factory=list, description="订阅的全局配置热重载范围")
|
||||||
"""订阅的全局配置热重载范围"""
|
"""订阅的全局配置热重载范围"""
|
||||||
|
|
||||||
@@ -280,6 +282,8 @@ class ReloadPluginPayload(BaseModel):
|
|||||||
"""目标插件 ID"""
|
"""目标插件 ID"""
|
||||||
reason: str = Field(default="manual", description="重载原因")
|
reason: str = Field(default="manual", description="重载原因")
|
||||||
"""重载原因"""
|
"""重载原因"""
|
||||||
|
external_available_plugins: List[str] = Field(default_factory=list, description="可视为已满足的外部依赖插件 ID")
|
||||||
|
"""可视为已满足的外部依赖插件 ID"""
|
||||||
|
|
||||||
|
|
||||||
class ReloadPluginResultPayload(BaseModel):
|
class ReloadPluginResultPayload(BaseModel):
|
||||||
|
|||||||
@@ -95,11 +95,16 @@ class PluginLoader:
|
|||||||
self._manifest_validator = ManifestValidator(host_version=host_version)
|
self._manifest_validator = ManifestValidator(host_version=host_version)
|
||||||
self._compat_hook_installed = False
|
self._compat_hook_installed = False
|
||||||
|
|
||||||
def discover_and_load(self, plugin_dirs: List[str]) -> List[PluginMeta]:
|
def discover_and_load(
|
||||||
|
self,
|
||||||
|
plugin_dirs: List[str],
|
||||||
|
extra_available: Optional[Set[str]] = None,
|
||||||
|
) -> List[PluginMeta]:
|
||||||
"""扫描多个目录并加载所有插件。
|
"""扫描多个目录并加载所有插件。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
plugin_dirs: 插件目录列表。
|
plugin_dirs: 插件目录列表。
|
||||||
|
extra_available: 额外视为已满足的外部依赖插件 ID 集合。
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[PluginMeta]: 成功加载的插件元数据列表,按依赖顺序排列。
|
List[PluginMeta]: 成功加载的插件元数据列表,按依赖顺序排列。
|
||||||
@@ -108,7 +113,7 @@ class PluginLoader:
|
|||||||
self._record_duplicate_candidates(duplicate_candidates)
|
self._record_duplicate_candidates(duplicate_candidates)
|
||||||
|
|
||||||
# 第二阶段:依赖解析(拓扑排序)
|
# 第二阶段:依赖解析(拓扑排序)
|
||||||
load_order, failed_deps = self._resolve_dependencies(candidates)
|
load_order, failed_deps = self._resolve_dependencies(candidates, extra_available=extra_available)
|
||||||
self._record_failed_dependencies(failed_deps)
|
self._record_failed_dependencies(failed_deps)
|
||||||
|
|
||||||
# 第三阶段:按依赖顺序加载
|
# 第三阶段:按依赖顺序加载
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, ca
|
|||||||
import asyncio
|
import asyncio
|
||||||
import contextlib
|
import contextlib
|
||||||
import inspect
|
import inspect
|
||||||
|
import json
|
||||||
import logging as stdlib_logging
|
import logging as stdlib_logging
|
||||||
import os
|
import os
|
||||||
import signal
|
import signal
|
||||||
@@ -23,7 +24,13 @@ import time
|
|||||||
import tomllib
|
import tomllib
|
||||||
|
|
||||||
from src.common.logger import get_console_handler, get_logger, initialize_logging
|
from src.common.logger import get_console_handler, get_logger, initialize_logging
|
||||||
from src.plugin_runtime import ENV_HOST_VERSION, ENV_IPC_ADDRESS, ENV_PLUGIN_DIRS, ENV_SESSION_TOKEN
|
from src.plugin_runtime import (
|
||||||
|
ENV_EXTERNAL_PLUGIN_IDS,
|
||||||
|
ENV_HOST_VERSION,
|
||||||
|
ENV_IPC_ADDRESS,
|
||||||
|
ENV_PLUGIN_DIRS,
|
||||||
|
ENV_SESSION_TOKEN,
|
||||||
|
)
|
||||||
from src.plugin_runtime.protocol.envelope import (
|
from src.plugin_runtime.protocol.envelope import (
|
||||||
BootstrapPluginPayload,
|
BootstrapPluginPayload,
|
||||||
ComponentDeclaration,
|
ComponentDeclaration,
|
||||||
@@ -112,6 +119,7 @@ class PluginRunner:
|
|||||||
host_address: str,
|
host_address: str,
|
||||||
session_token: str,
|
session_token: str,
|
||||||
plugin_dirs: List[str],
|
plugin_dirs: List[str],
|
||||||
|
external_available_plugin_ids: Optional[List[str]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""初始化 Runner。
|
"""初始化 Runner。
|
||||||
|
|
||||||
@@ -119,10 +127,16 @@ class PluginRunner:
|
|||||||
host_address: Host 的 IPC 地址。
|
host_address: Host 的 IPC 地址。
|
||||||
session_token: 握手用会话令牌。
|
session_token: 握手用会话令牌。
|
||||||
plugin_dirs: 当前 Runner 负责扫描的插件目录列表。
|
plugin_dirs: 当前 Runner 负责扫描的插件目录列表。
|
||||||
|
external_available_plugin_ids: 视为已满足的外部依赖插件 ID 列表。
|
||||||
"""
|
"""
|
||||||
self._host_address: str = host_address
|
self._host_address: str = host_address
|
||||||
self._session_token: str = session_token
|
self._session_token: str = session_token
|
||||||
self._plugin_dirs: List[str] = plugin_dirs
|
self._plugin_dirs: List[str] = plugin_dirs
|
||||||
|
self._external_available_plugin_ids: Set[str] = {
|
||||||
|
str(plugin_id or "").strip()
|
||||||
|
for plugin_id in (external_available_plugin_ids or [])
|
||||||
|
if str(plugin_id or "").strip()
|
||||||
|
}
|
||||||
|
|
||||||
self._rpc_client: RPCClient = RPCClient(host_address, session_token)
|
self._rpc_client: RPCClient = RPCClient(host_address, session_token)
|
||||||
self._loader: PluginLoader = PluginLoader(host_version=os.getenv(ENV_HOST_VERSION, ""))
|
self._loader: PluginLoader = PluginLoader(host_version=os.getenv(ENV_HOST_VERSION, ""))
|
||||||
@@ -150,7 +164,10 @@ class PluginRunner:
|
|||||||
self._register_handlers()
|
self._register_handlers()
|
||||||
|
|
||||||
# 3. 加载插件
|
# 3. 加载插件
|
||||||
plugins = self._loader.discover_and_load(self._plugin_dirs)
|
plugins = self._loader.discover_and_load(
|
||||||
|
self._plugin_dirs,
|
||||||
|
extra_available=self._external_available_plugin_ids,
|
||||||
|
)
|
||||||
logger.info(f"已加载 {len(plugins)} 个插件")
|
logger.info(f"已加载 {len(plugins)} 个插件")
|
||||||
|
|
||||||
# 4. 注入 PluginContext + 调用 on_load 生命周期钩子
|
# 4. 注入 PluginContext + 调用 on_load 生命周期钩子
|
||||||
@@ -379,6 +396,7 @@ class PluginRunner:
|
|||||||
plugin_version=meta.version,
|
plugin_version=meta.version,
|
||||||
components=components,
|
components=components,
|
||||||
capabilities_required=meta.capabilities_required,
|
capabilities_required=meta.capabilities_required,
|
||||||
|
dependencies=meta.dependencies,
|
||||||
config_reload_subscriptions=config_reload_subscriptions,
|
config_reload_subscriptions=config_reload_subscriptions,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -485,18 +503,20 @@ class PluginRunner:
|
|||||||
self._loader.set_loaded_plugin(meta)
|
self._loader.set_loaded_plugin(meta)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async def _unload_plugin(self, meta: PluginMeta, reason: str) -> None:
|
async def _unload_plugin(self, meta: PluginMeta, reason: str, *, purge_modules: bool = True) -> None:
|
||||||
"""卸载单个插件并清理 Host/Runner 两侧状态。
|
"""卸载单个插件并清理 Host/Runner 两侧状态。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
meta: 待卸载的插件元数据。
|
meta: 待卸载的插件元数据。
|
||||||
reason: 卸载原因。
|
reason: 卸载原因。
|
||||||
|
purge_modules: 是否在卸载完成后清理插件模块缓存。
|
||||||
"""
|
"""
|
||||||
await self._invoke_plugin_on_unload(meta)
|
await self._invoke_plugin_on_unload(meta)
|
||||||
await self._unregister_plugin(meta.plugin_id, reason)
|
await self._unregister_plugin(meta.plugin_id, reason)
|
||||||
await self._deactivate_plugin(meta)
|
await self._deactivate_plugin(meta)
|
||||||
self._loader.remove_loaded_plugin(meta.plugin_id)
|
self._loader.remove_loaded_plugin(meta.plugin_id)
|
||||||
self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir)
|
if purge_modules:
|
||||||
|
self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir)
|
||||||
|
|
||||||
def _collect_reverse_dependents(self, plugin_id: str) -> Set[str]:
|
def _collect_reverse_dependents(self, plugin_id: str) -> Set[str]:
|
||||||
"""收集依赖指定插件的所有已加载插件。
|
"""收集依赖指定插件的所有已加载插件。
|
||||||
@@ -564,18 +584,52 @@ class PluginRunner:
|
|||||||
|
|
||||||
return list(reversed(load_order))
|
return list(reversed(load_order))
|
||||||
|
|
||||||
async def _reload_plugin_by_id(self, plugin_id: str, reason: str) -> ReloadPluginResultPayload:
|
@staticmethod
|
||||||
|
def _finalize_failed_reload_messages(
|
||||||
|
failed_plugins: Dict[str, str],
|
||||||
|
rollback_failures: Dict[str, str],
|
||||||
|
) -> Dict[str, str]:
|
||||||
|
"""在重载失败后补充回滚结果说明。"""
|
||||||
|
|
||||||
|
finalized_failures: Dict[str, str] = {}
|
||||||
|
for failed_plugin_id, failure_reason in failed_plugins.items():
|
||||||
|
rollback_failure = rollback_failures.get(failed_plugin_id)
|
||||||
|
if rollback_failure:
|
||||||
|
finalized_failures[failed_plugin_id] = (
|
||||||
|
f"{failure_reason};且旧版本恢复失败: {rollback_failure}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
finalized_failures[failed_plugin_id] = f"{failure_reason}(已恢复旧版本)"
|
||||||
|
|
||||||
|
for failed_plugin_id, rollback_failure in rollback_failures.items():
|
||||||
|
if failed_plugin_id not in finalized_failures:
|
||||||
|
finalized_failures[failed_plugin_id] = f"旧版本恢复失败: {rollback_failure}"
|
||||||
|
|
||||||
|
return finalized_failures
|
||||||
|
|
||||||
|
async def _reload_plugin_by_id(
|
||||||
|
self,
|
||||||
|
plugin_id: str,
|
||||||
|
reason: str,
|
||||||
|
external_available_plugins: Optional[Set[str]] = None,
|
||||||
|
) -> ReloadPluginResultPayload:
|
||||||
"""按插件 ID 在 Runner 进程内执行精确重载。
|
"""按插件 ID 在 Runner 进程内执行精确重载。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
plugin_id: 目标插件 ID。
|
plugin_id: 目标插件 ID。
|
||||||
reason: 重载原因。
|
reason: 重载原因。
|
||||||
|
external_available_plugins: 视为已满足的外部依赖插件 ID 集合。
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ReloadPluginResultPayload: 结构化重载结果。
|
ReloadPluginResultPayload: 结构化重载结果。
|
||||||
"""
|
"""
|
||||||
candidates, duplicate_candidates = self._loader.discover_candidates(self._plugin_dirs)
|
candidates, duplicate_candidates = self._loader.discover_candidates(self._plugin_dirs)
|
||||||
failed_plugins: Dict[str, str] = {}
|
failed_plugins: Dict[str, str] = {}
|
||||||
|
normalized_external_available = {
|
||||||
|
str(candidate_plugin_id or "").strip()
|
||||||
|
for candidate_plugin_id in (external_available_plugins or set())
|
||||||
|
if str(candidate_plugin_id or "").strip()
|
||||||
|
}
|
||||||
|
|
||||||
if plugin_id in duplicate_candidates:
|
if plugin_id in duplicate_candidates:
|
||||||
conflict_paths = ", ".join(str(path) for path in duplicate_candidates[plugin_id])
|
conflict_paths = ", ".join(str(path) for path in duplicate_candidates[plugin_id])
|
||||||
@@ -603,29 +657,32 @@ class PluginRunner:
|
|||||||
unload_order = self._build_unload_order(target_plugin_ids & loaded_plugin_ids)
|
unload_order = self._build_unload_order(target_plugin_ids & loaded_plugin_ids)
|
||||||
unloaded_plugins: List[str] = []
|
unloaded_plugins: List[str] = []
|
||||||
retained_plugin_ids = loaded_plugin_ids - set(unload_order)
|
retained_plugin_ids = loaded_plugin_ids - set(unload_order)
|
||||||
|
rollback_metas: Dict[str, PluginMeta] = {}
|
||||||
|
|
||||||
for unload_plugin_id in unload_order:
|
for unload_plugin_id in unload_order:
|
||||||
meta = self._loader.get_plugin(unload_plugin_id)
|
meta = self._loader.get_plugin(unload_plugin_id)
|
||||||
if meta is None:
|
if meta is None:
|
||||||
continue
|
continue
|
||||||
await self._unload_plugin(meta, reason=reason)
|
rollback_metas[unload_plugin_id] = meta
|
||||||
|
await self._unload_plugin(meta, reason=reason, purge_modules=False)
|
||||||
|
self._loader.purge_plugin_modules(unload_plugin_id, meta.plugin_dir)
|
||||||
unloaded_plugins.append(unload_plugin_id)
|
unloaded_plugins.append(unload_plugin_id)
|
||||||
|
|
||||||
reload_candidates: Dict[str, Tuple[Path, Dict[str, Any], Path]] = {}
|
reload_candidates: Dict[str, Tuple[Path, Dict[str, Any], Path]] = {}
|
||||||
for target_plugin_id in target_plugin_ids:
|
for target_plugin_id in target_plugin_ids:
|
||||||
candidate = candidates.get(target_plugin_id)
|
candidate = candidates.get(target_plugin_id)
|
||||||
if candidate is None:
|
if candidate is None:
|
||||||
failed_plugins[target_plugin_id] = "插件目录已不存在,已保持卸载状态"
|
failed_plugins[target_plugin_id] = "插件目录已不存在"
|
||||||
continue
|
continue
|
||||||
reload_candidates[target_plugin_id] = candidate
|
reload_candidates[target_plugin_id] = candidate
|
||||||
|
|
||||||
load_order, dependency_failures = self._loader.resolve_dependencies(
|
load_order, dependency_failures = self._loader.resolve_dependencies(
|
||||||
reload_candidates,
|
reload_candidates,
|
||||||
extra_available=retained_plugin_ids,
|
extra_available=retained_plugin_ids | normalized_external_available,
|
||||||
)
|
)
|
||||||
failed_plugins.update(dependency_failures)
|
failed_plugins.update(dependency_failures)
|
||||||
|
|
||||||
available_plugins = set(retained_plugin_ids)
|
available_plugins = set(retained_plugin_ids) | normalized_external_available
|
||||||
reloaded_plugins: List[str] = []
|
reloaded_plugins: List[str] = []
|
||||||
|
|
||||||
for load_plugin_id in load_order:
|
for load_plugin_id in load_order:
|
||||||
@@ -656,7 +713,48 @@ class PluginRunner:
|
|||||||
available_plugins.add(load_plugin_id)
|
available_plugins.add(load_plugin_id)
|
||||||
reloaded_plugins.append(load_plugin_id)
|
reloaded_plugins.append(load_plugin_id)
|
||||||
|
|
||||||
requested_plugin_success = plugin_id in reloaded_plugins and not failed_plugins
|
if failed_plugins:
|
||||||
|
rollback_failures: Dict[str, str] = {}
|
||||||
|
|
||||||
|
for reloaded_plugin_id in reversed(reloaded_plugins):
|
||||||
|
reloaded_meta = self._loader.get_plugin(reloaded_plugin_id)
|
||||||
|
if reloaded_meta is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self._unload_plugin(
|
||||||
|
reloaded_meta,
|
||||||
|
reason=f"{reason}_rollback_cleanup",
|
||||||
|
purge_modules=False,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
rollback_failures[reloaded_plugin_id] = f"清理失败: {exc}"
|
||||||
|
finally:
|
||||||
|
self._loader.purge_plugin_modules(reloaded_plugin_id, reloaded_meta.plugin_dir)
|
||||||
|
|
||||||
|
for rollback_plugin_id in reversed(unload_order):
|
||||||
|
rollback_meta = rollback_metas.get(rollback_plugin_id)
|
||||||
|
if rollback_meta is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
restored = await self._activate_plugin(rollback_meta)
|
||||||
|
except Exception as exc:
|
||||||
|
rollback_failures[rollback_plugin_id] = str(exc)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not restored:
|
||||||
|
rollback_failures[rollback_plugin_id] = "无法重新激活旧版本"
|
||||||
|
|
||||||
|
return ReloadPluginResultPayload(
|
||||||
|
success=False,
|
||||||
|
requested_plugin_id=plugin_id,
|
||||||
|
reloaded_plugins=[],
|
||||||
|
unloaded_plugins=unloaded_plugins,
|
||||||
|
failed_plugins=self._finalize_failed_reload_messages(failed_plugins, rollback_failures),
|
||||||
|
)
|
||||||
|
|
||||||
|
requested_plugin_success = plugin_id in reloaded_plugins
|
||||||
|
|
||||||
return ReloadPluginResultPayload(
|
return ReloadPluginResultPayload(
|
||||||
success=requested_plugin_success,
|
success=requested_plugin_success,
|
||||||
@@ -978,7 +1076,11 @@ class PluginRunner:
|
|||||||
)
|
)
|
||||||
|
|
||||||
async with self._reload_lock:
|
async with self._reload_lock:
|
||||||
result = await self._reload_plugin_by_id(payload.plugin_id, payload.reason)
|
result = await self._reload_plugin_by_id(
|
||||||
|
payload.plugin_id,
|
||||||
|
payload.reason,
|
||||||
|
external_available_plugins=set(payload.external_available_plugins),
|
||||||
|
)
|
||||||
return envelope.make_response(payload=result.model_dump())
|
return envelope.make_response(payload=result.model_dump())
|
||||||
|
|
||||||
def request_capability(self) -> RPCClient:
|
def request_capability(self) -> RPCClient:
|
||||||
@@ -1073,6 +1175,7 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None:
|
|||||||
async def _async_main() -> None:
|
async def _async_main() -> None:
|
||||||
"""异步主入口"""
|
"""异步主入口"""
|
||||||
host_address = os.environ.get(ENV_IPC_ADDRESS, "")
|
host_address = os.environ.get(ENV_IPC_ADDRESS, "")
|
||||||
|
external_plugin_ids_raw = os.environ.get(ENV_EXTERNAL_PLUGIN_IDS, "")
|
||||||
session_token = os.environ.get(ENV_SESSION_TOKEN, "")
|
session_token = os.environ.get(ENV_SESSION_TOKEN, "")
|
||||||
plugin_dirs_str = os.environ.get(ENV_PLUGIN_DIRS, "")
|
plugin_dirs_str = os.environ.get(ENV_PLUGIN_DIRS, "")
|
||||||
|
|
||||||
@@ -1081,11 +1184,24 @@ async def _async_main() -> None:
|
|||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
plugin_dirs = [d for d in plugin_dirs_str.split(os.pathsep) if d]
|
plugin_dirs = [d for d in plugin_dirs_str.split(os.pathsep) if d]
|
||||||
|
try:
|
||||||
|
external_plugin_ids = json.loads(external_plugin_ids_raw) if external_plugin_ids_raw else []
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.warning("解析外部依赖插件列表失败,已回退为空列表")
|
||||||
|
external_plugin_ids = []
|
||||||
|
if not isinstance(external_plugin_ids, list):
|
||||||
|
logger.warning("外部依赖插件列表格式非法,已回退为空列表")
|
||||||
|
external_plugin_ids = []
|
||||||
|
|
||||||
# sys.path 隔离: 只保留标准库、SDK 包、插件目录
|
# sys.path 隔离: 只保留标准库、SDK 包、插件目录
|
||||||
_isolate_sys_path(plugin_dirs)
|
_isolate_sys_path(plugin_dirs)
|
||||||
|
|
||||||
runner = PluginRunner(host_address, session_token, plugin_dirs)
|
runner = PluginRunner(
|
||||||
|
host_address,
|
||||||
|
session_token,
|
||||||
|
plugin_dirs,
|
||||||
|
external_available_plugin_ids=[str(plugin_id) for plugin_id in external_plugin_ids],
|
||||||
|
)
|
||||||
|
|
||||||
# 注册信号处理
|
# 注册信号处理
|
||||||
def _mark_runner_shutting_down() -> None:
|
def _mark_runner_shutting_down() -> None:
|
||||||
|
|||||||
Reference in New Issue
Block a user