diff --git a/pytests/test_tool_availability.py b/pytests/test_tool_availability.py index 9946900c..1f6368f5 100644 --- a/pytests/test_tool_availability.py +++ b/pytests/test_tool_availability.py @@ -1,4 +1,6 @@ from types import SimpleNamespace +import importlib.util +import sys import pytest @@ -94,3 +96,127 @@ def test_plugin_tool_session_disable_still_filters_specific_chat(monkeypatch: py assert "mute" not in disabled_specs assert "mute" in enabled_specs + + +def test_plugin_tool_allowed_session_filters_tool_exposure(monkeypatch: pytest.MonkeyPatch) -> None: + service = ComponentQueryService() + registry = ComponentRegistry() + supervisor = SimpleNamespace(component_registry=registry) + monkeypatch.setattr(service, "_iter_supervisors", lambda: [supervisor]) + + registry.register_plugin_components( + "mute_plugin", + [ + { + "name": "mute", + "component_type": "TOOL", + "chat_scope": "group", + "allowed_session": ["qq:10001", "raw-group-id", "exact-session-id"], + "metadata": {"description": "mute group member"}, + } + ], + ) + + platform_group_specs = service.get_llm_available_tool_specs( + context=ToolAvailabilityContext( + session_id="hashed-session-1", + is_group_chat=True, + group_id="10001", + platform="qq", + ) + ) + raw_group_specs = service.get_llm_available_tool_specs( + context=ToolAvailabilityContext( + session_id="hashed-session-2", + is_group_chat=True, + group_id="raw-group-id", + platform="qq", + ) + ) + exact_session_specs = service.get_llm_available_tool_specs( + context=ToolAvailabilityContext(session_id="exact-session-id", is_group_chat=True) + ) + blocked_specs = service.get_llm_available_tool_specs( + context=ToolAvailabilityContext( + session_id="blocked-session", + is_group_chat=True, + group_id="20002", + platform="qq", + ) + ) + + entry = registry.get_component("mute_plugin.mute") + assert entry is not None + assert entry.allowed_session == {"qq:10001", "raw-group-id", "exact-session-id"} + assert "allowed_session" not in entry.metadata + assert "mute" in platform_group_specs + assert "mute" in raw_group_specs + assert "mute" in exact_session_specs + assert "mute" not in blocked_specs + + +def test_plugin_tool_disabled_session_take_precedence_over_allowed_session( + monkeypatch: pytest.MonkeyPatch, +) -> None: + service = ComponentQueryService() + registry = ComponentRegistry() + supervisor = SimpleNamespace(component_registry=registry) + monkeypatch.setattr(service, "_iter_supervisors", lambda: [supervisor]) + + registry.register_plugin_components( + "mute_plugin", + [ + { + "name": "mute", + "component_type": "TOOL", + "chat_scope": "group", + "allowed_session": ["qq:10001"], + "metadata": {"description": "mute group member"}, + } + ], + ) + registry.set_component_enabled("mute_plugin.mute", False, session_id="allowed-session") + + visible_specs = service.get_llm_available_tool_specs( + context=ToolAvailabilityContext( + session_id="visible-session", + is_group_chat=True, + group_id="10001", + platform="qq", + ) + ) + disabled_specs = service.get_llm_available_tool_specs( + context=ToolAvailabilityContext( + session_id="allowed-session", + is_group_chat=True, + group_id="10001", + platform="qq", + ) + ) + + entry = registry.get_component("mute_plugin.mute") + assert entry is not None + assert entry.disabled_session == {"allowed-session"} + assert "mute" in visible_specs + assert "mute" not in disabled_specs + + +def test_mute_plugin_exports_allowed_groups_as_component_allowed_session() -> None: + module_path = "plugins/MutePlugin/plugin.py" + spec = importlib.util.spec_from_file_location("mute_plugin_under_test", module_path) + assert spec is not None + assert spec.loader is not None + module = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = module + spec.loader.exec_module(module) + module.MutePluginConfig.model_rebuild() + + plugin = module.MutePlugin() + plugin.set_plugin_config({"permissions": {"allowed_groups": ["qq:10001", "raw-group-id"]}}) + + mute_components = [component for component in plugin.get_components() if component.get("name") == "mute"] + + assert len(mute_components) == 1 + assert mute_components[0]["chat_scope"] == "group" + assert mute_components[0]["allowed_session"] == ["qq:10001", "raw-group-id"] + assert "allowed_session" not in mute_components[0]["metadata"] diff --git a/src/plugin_runtime/component_query.py b/src/plugin_runtime/component_query.py index 2fb01797..f5c3e1f7 100644 --- a/src/plugin_runtime/component_query.py +++ b/src/plugin_runtime/component_query.py @@ -91,6 +91,8 @@ class ComponentQueryService: session_id = context.session_id if context is not None else None is_group_chat = context.is_group_chat if context is not None else None + group_id = context.group_id if context is not None else None + platform = context.platform if context is not None else None collected_entries: list[tuple["PluginSupervisor", "ComponentEntry"]] = [] for supervisor in self._iter_supervisors(): for component in supervisor.component_registry.get_components_by_type( @@ -98,6 +100,8 @@ class ComponentQueryService: enabled_only=enabled_only, session_id=session_id, is_group_chat=is_group_chat, + group_id=group_id, + platform=platform, ): collected_entries.append((supervisor, component)) return collected_entries diff --git a/src/plugin_runtime/host/component_registry.py b/src/plugin_runtime/host/component_registry.py index fceb7828..2b6e4c76 100644 --- a/src/plugin_runtime/host/component_registry.py +++ b/src/plugin_runtime/host/component_registry.py @@ -96,6 +96,7 @@ class ComponentEntry: "compiled_pattern", "disabled_session", "chat_scope", + "allowed_session", ) def __init__( @@ -105,6 +106,7 @@ class ComponentEntry: plugin_id: str, metadata: Dict[str, Any], chat_scope: str = "all", + allowed_session: Optional[List[str]] = None, ) -> None: self.name: str = name self.full_name: str = f"{plugin_id}.{name}" @@ -114,7 +116,11 @@ class ComponentEntry: self.enabled: bool = metadata.get("enabled", True) self.disabled_session: Set[str] = set() self.chat_scope: ComponentChatScope = _normalize_chat_scope(chat_scope) - + self.allowed_session: Set[str] = { + str(session_id).strip() + for session_id in (allowed_session or []) + if str(session_id).strip() + } class ActionEntry(ComponentEntry): """Action 组件条目""" @@ -126,8 +132,9 @@ class ActionEntry(ComponentEntry): plugin_id: str, metadata: Dict[str, Any], chat_scope: str = "all", + allowed_session: Optional[List[str]] = None, ) -> None: - super().__init__(name, component_type, plugin_id, metadata, chat_scope) + super().__init__(name, component_type, plugin_id, metadata, chat_scope, allowed_session) class CommandEntry(ComponentEntry): @@ -140,8 +147,9 @@ class CommandEntry(ComponentEntry): plugin_id: str, metadata: Dict[str, Any], chat_scope: str = "all", + allowed_session: Optional[List[str]] = None, ) -> None: - super().__init__(name, component_type, plugin_id, metadata, chat_scope) + super().__init__(name, component_type, plugin_id, metadata, chat_scope, allowed_session) self.aliases: List[str] = metadata.get("aliases", []) self.compiled_pattern: Optional[re.Pattern] = None if pattern := metadata.get("command_pattern", ""): @@ -161,6 +169,7 @@ class ToolEntry(ComponentEntry): plugin_id: str, metadata: Dict[str, Any], chat_scope: str = "all", + allowed_session: Optional[List[str]] = None, ) -> None: self.description: str = str(metadata.get("description", "") or "").strip() self.brief_description: str = str( @@ -172,7 +181,7 @@ class ToolEntry(ComponentEntry): self.detailed_description: str = detailed_description self.invoke_method: str = str(metadata.get("invoke_method", "plugin.invoke_tool") or "plugin.invoke_tool").strip() self.legacy_component_type: str = str(metadata.get("legacy_component_type", "") or "").strip() - super().__init__(name, component_type, plugin_id, metadata, chat_scope) + super().__init__(name, component_type, plugin_id, metadata, chat_scope, allowed_session) if not self.detailed_description: parameters_schema = self._get_parameters_schema() @@ -248,11 +257,12 @@ class EventHandlerEntry(ComponentEntry): plugin_id: str, metadata: Dict[str, Any], chat_scope: str = "all", + allowed_session: Optional[List[str]] = None, ) -> None: self.event_type: str = metadata.get("event_type", "") self.weight: int = metadata.get("weight", 0) self.intercept_message: bool = metadata.get("intercept_message", False) - super().__init__(name, component_type, plugin_id, metadata, chat_scope) + super().__init__(name, component_type, plugin_id, metadata, chat_scope, allowed_session) class HookHandlerEntry(ComponentEntry): @@ -265,13 +275,14 @@ class HookHandlerEntry(ComponentEntry): plugin_id: str, metadata: Dict[str, Any], chat_scope: str = "all", + allowed_session: Optional[List[str]] = None, ) -> None: self.hook: str = self._normalize_hook_name(metadata.get("hook", "")) self.mode: str = self._normalize_mode(metadata.get("mode", "blocking")) self.order: str = self._normalize_order(metadata.get("order", "normal")) self.timeout_ms: int = self._normalize_timeout_ms(metadata.get("timeout_ms", 0)) self.error_policy: str = self._normalize_error_policy(metadata.get("error_policy", "skip")) - super().__init__(name, component_type, plugin_id, metadata, chat_scope) + super().__init__(name, component_type, plugin_id, metadata, chat_scope, allowed_session) @staticmethod def _normalize_error_policy(raw_value: Any) -> str: @@ -397,13 +408,14 @@ class MessageGatewayEntry(ComponentEntry): plugin_id: str, metadata: Dict[str, Any], chat_scope: str = "all", + allowed_session: Optional[List[str]] = None, ) -> None: self.route_type: str = self._normalize_route_type(metadata.get("route_type", "")) self.platform: str = str(metadata.get("platform", "") or "").strip() self.protocol: str = str(metadata.get("protocol", "") or "").strip() self.account_id: str = str(metadata.get("account_id", "") or "").strip() self.scope: str = str(metadata.get("scope", "") or "").strip() - super().__init__(name, component_type, plugin_id, metadata, chat_scope) + super().__init__(name, component_type, plugin_id, metadata, chat_scope, allowed_session) @staticmethod def _normalize_route_type(raw_value: Any) -> str: @@ -644,6 +656,7 @@ class ComponentRegistry: plugin_id: str, metadata: Dict[str, Any], chat_scope: str = "all", + allowed_session: Optional[List[str]] = None, ) -> ComponentEntry: """根据声明构造组件条目。 @@ -665,18 +678,60 @@ class ComponentRegistry: normalized_metadata = dict(metadata) if normalized_type == ComponentTypes.ACTION: normalized_metadata = self._convert_action_metadata_to_tool_metadata(name, normalized_metadata) - component = ToolEntry(name, ComponentTypes.TOOL.value, plugin_id, normalized_metadata, chat_scope) + component = ToolEntry( + name, + ComponentTypes.TOOL.value, + plugin_id, + normalized_metadata, + chat_scope, + allowed_session, + ) elif normalized_type == ComponentTypes.COMMAND: - component = CommandEntry(name, normalized_type.value, plugin_id, normalized_metadata, chat_scope) + component = CommandEntry( + name, + normalized_type.value, + plugin_id, + normalized_metadata, + chat_scope, + allowed_session, + ) elif normalized_type == ComponentTypes.TOOL: - component = ToolEntry(name, normalized_type.value, plugin_id, normalized_metadata, chat_scope) + component = ToolEntry( + name, + normalized_type.value, + plugin_id, + normalized_metadata, + chat_scope, + allowed_session, + ) elif normalized_type == ComponentTypes.EVENT_HANDLER: - component = EventHandlerEntry(name, normalized_type.value, plugin_id, normalized_metadata, chat_scope) + component = EventHandlerEntry( + name, + normalized_type.value, + plugin_id, + normalized_metadata, + chat_scope, + allowed_session, + ) elif normalized_type == ComponentTypes.HOOK_HANDLER: - component = HookHandlerEntry(name, normalized_type.value, plugin_id, normalized_metadata, chat_scope) + component = HookHandlerEntry( + name, + normalized_type.value, + plugin_id, + normalized_metadata, + chat_scope, + allowed_session, + ) self._validate_hook_handler_entry(component) elif normalized_type == ComponentTypes.MESSAGE_GATEWAY: - component = MessageGatewayEntry(name, normalized_type.value, plugin_id, normalized_metadata, chat_scope) + component = MessageGatewayEntry( + name, + normalized_type.value, + plugin_id, + normalized_metadata, + chat_scope, + allowed_session, + ) else: raise ComponentRegistrationError( f"组件类型 {component_type} 不存在", @@ -735,6 +790,7 @@ class ComponentRegistry: plugin_id: str, metadata: Dict[str, Any], chat_scope: str = "all", + allowed_session: Optional[List[str]] = None, ) -> bool: """注册单个组件。 @@ -751,7 +807,14 @@ class ComponentRegistry: ComponentRegistrationError: 组件声明不合法时抛出。 """ - component = self._build_component_entry(name, component_type, plugin_id, metadata, chat_scope) + component = self._build_component_entry( + name, + component_type, + plugin_id, + metadata, + chat_scope, + allowed_session, + ) self._add_component_entry(component) return True @@ -780,6 +843,12 @@ class ComponentRegistry: else {} ) chat_scope = str(component_data.get("chat_scope", raw_metadata.pop("chat_scope", "all")) or "all") + raw_allowed_session = component_data.get("allowed_session", raw_metadata.pop("allowed_session", [])) + allowed_session = ( + [str(item).strip() for item in raw_allowed_session if str(item).strip()] + if isinstance(raw_allowed_session, list) + else [] + ) prepared_components.append( self._build_component_entry( name=str(component_data.get("name", "") or ""), @@ -787,6 +856,7 @@ class ComponentRegistry: plugin_id=plugin_id, metadata=raw_metadata, chat_scope=chat_scope, + allowed_session=allowed_session, ) ) @@ -816,6 +886,8 @@ class ComponentRegistry: component: ComponentEntry, session_id: Optional[str] = None, is_group_chat: Optional[bool] = None, + group_id: Optional[str] = None, + platform: Optional[str] = None, ): if session_id and session_id in component.disabled_session: return False @@ -824,6 +896,13 @@ class ComponentRegistry: return False if component.chat_scope == "private" and is_group_chat is not False: return False + if component.allowed_session: + allowed_candidates = {str(session_id or "").strip(), str(group_id or "").strip()} + if platform and group_id: + allowed_candidates.add(f"{platform}:{group_id}") + allowed_candidates.discard("") + if component.allowed_session.isdisjoint(allowed_candidates): + return False return component.enabled def toggle_component_status(self, full_name: str, enabled: bool, session_id: Optional[str] = None) -> bool: @@ -900,6 +979,8 @@ class ComponentRegistry: enabled_only: bool = True, session_id: Optional[str] = None, is_group_chat: Optional[bool] = None, + group_id: Optional[str] = None, + platform: Optional[str] = None, ) -> List[ComponentEntry]: """按类型查询组件 @@ -926,13 +1007,17 @@ class ComponentRegistry: return [ component for component in action_components - if self.check_component_enabled(component, session_id, is_group_chat) + if self.check_component_enabled(component, session_id, is_group_chat, group_id, platform) ] return action_components type_dict = self._by_type.get(comp_type, {}) if enabled_only: - return [c for c in type_dict.values() if self.check_component_enabled(c, session_id, is_group_chat)] + return [ + c + for c in type_dict.values() + if self.check_component_enabled(c, session_id, is_group_chat, group_id, platform) + ] return list(type_dict.values()) def get_components_by_plugin( diff --git a/src/plugin_runtime/protocol/envelope.py b/src/plugin_runtime/protocol/envelope.py index e4bebba9..81d8ec33 100644 --- a/src/plugin_runtime/protocol/envelope.py +++ b/src/plugin_runtime/protocol/envelope.py @@ -191,6 +191,10 @@ class ComponentDeclaration(BaseModel): """组件类型:`action`/`command`/`tool`/`event_handler`/`hook_handler`/`message_gateway`""" plugin_id: str = Field(description="所属插件 ID") """所属插件 ID""" + chat_scope: str = Field(default="all", description="组件适用聊天类型:all/group/private") + """组件适用聊天类型。""" + allowed_session: List[str] = Field(default_factory=list, description="允许暴露该组件的会话 ID 或平台作用域 ID") + """允许暴露该组件的具体会话。空列表表示不限制。""" metadata: Dict[str, Any] = Field(default_factory=dict, description="组件元数据") """组件元数据""" diff --git a/src/plugin_runtime/runner/runner_main.py b/src/plugin_runtime/runner/runner_main.py index e485567f..bb8c3f4c 100644 --- a/src/plugin_runtime/runner/runner_main.py +++ b/src/plugin_runtime/runner/runner_main.py @@ -1003,6 +1003,14 @@ class PluginRunner: name=component_name, component_type=str(comp_info.get("type", "") or "").strip(), plugin_id=meta.plugin_id, + chat_scope=str(comp_info.get("chat_scope", "all") or "all").strip(), + allowed_session=[ + str(item).strip() + for item in comp_info.get("allowed_session", []) + if str(item).strip() + ] + if isinstance(comp_info.get("allowed_session"), list) + else [], metadata=component_metadata, ) )