feat:修复门控多重result问题,新增at动作,插件现在运行chat_id指定或chat_type指定
This commit is contained in:
@@ -11,6 +11,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional, Tupl
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.core.tooling import (
|
||||
ToolAvailabilityContext,
|
||||
ToolExecutionContext,
|
||||
ToolExecutionResult,
|
||||
ToolInvocation,
|
||||
@@ -72,6 +73,7 @@ class ComponentQueryService:
|
||||
component_type: ComponentType,
|
||||
*,
|
||||
enabled_only: bool = True,
|
||||
context: Optional[ToolAvailabilityContext] = None,
|
||||
) -> list[tuple["PluginSupervisor", "ComponentEntry"]]:
|
||||
"""遍历指定类型的全部组件条目。
|
||||
|
||||
@@ -87,11 +89,15 @@ class ComponentQueryService:
|
||||
if host_component_type is None:
|
||||
return []
|
||||
|
||||
session_id = context.session_id if context is not None else None
|
||||
is_group_chat = context.is_group_chat if context is not None else None
|
||||
collected_entries: list[tuple["PluginSupervisor", "ComponentEntry"]] = []
|
||||
for supervisor in self._iter_supervisors():
|
||||
for component in supervisor.component_registry.get_components_by_type(
|
||||
host_component_type,
|
||||
enabled_only=enabled_only,
|
||||
session_id=session_id,
|
||||
is_group_chat=is_group_chat,
|
||||
):
|
||||
collected_entries.append((supervisor, component))
|
||||
return collected_entries
|
||||
@@ -657,7 +663,10 @@ class ComponentQueryService:
|
||||
tool_entry = cast("ToolEntry", entry)
|
||||
return self._build_tool_executor(supervisor, tool_entry.plugin_id, tool_entry.name, tool_entry.invoke_method)
|
||||
|
||||
def get_llm_available_tool_specs(self) -> Dict[str, ToolSpec]:
|
||||
def get_llm_available_tool_specs(
|
||||
self,
|
||||
context: Optional[ToolAvailabilityContext] = None,
|
||||
) -> Dict[str, ToolSpec]:
|
||||
"""获取当前可供 LLM 使用的统一工具声明集合。
|
||||
|
||||
Returns:
|
||||
@@ -665,7 +674,7 @@ class ComponentQueryService:
|
||||
"""
|
||||
|
||||
collected_specs: Dict[str, ToolSpec] = {}
|
||||
for _supervisor, entry in self._iter_component_entries(ComponentType.TOOL):
|
||||
for _supervisor, entry in self._iter_component_entries(ComponentType.TOOL, context=context):
|
||||
if entry.name in collected_specs:
|
||||
self._log_duplicate_component(ComponentType.TOOL, entry.name)
|
||||
continue
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Set, TypedDict, Tuple
|
||||
from typing import Any, Dict, List, Literal, Optional, Set, Tuple, TypedDict
|
||||
|
||||
import contextlib
|
||||
import re
|
||||
@@ -58,6 +58,20 @@ class ComponentTypes(str, Enum):
|
||||
MESSAGE_GATEWAY = "MESSAGE_GATEWAY"
|
||||
|
||||
|
||||
ComponentChatScope = Literal["all", "group", "private"]
|
||||
|
||||
|
||||
def _normalize_chat_scope(raw_value: Any) -> ComponentChatScope:
|
||||
"""规范化组件聊天类型适用范围。"""
|
||||
|
||||
normalized_value = str(raw_value or "all").strip().lower()
|
||||
if normalized_value == "group":
|
||||
return "group"
|
||||
if normalized_value == "private":
|
||||
return "private"
|
||||
return "all"
|
||||
|
||||
|
||||
class StatusDict(TypedDict):
|
||||
total: int
|
||||
action: int
|
||||
@@ -81,9 +95,17 @@ class ComponentEntry:
|
||||
"enabled",
|
||||
"compiled_pattern",
|
||||
"disabled_session",
|
||||
"chat_scope",
|
||||
)
|
||||
|
||||
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
component_type: str,
|
||||
plugin_id: str,
|
||||
metadata: Dict[str, Any],
|
||||
chat_scope: str = "all",
|
||||
) -> None:
|
||||
self.name: str = name
|
||||
self.full_name: str = f"{plugin_id}.{name}"
|
||||
self.component_type: ComponentTypes = ComponentTypes(component_type)
|
||||
@@ -91,20 +113,35 @@ class ComponentEntry:
|
||||
self.metadata: Dict[str, Any] = metadata
|
||||
self.enabled: bool = metadata.get("enabled", True)
|
||||
self.disabled_session: Set[str] = set()
|
||||
self.chat_scope: ComponentChatScope = _normalize_chat_scope(chat_scope)
|
||||
|
||||
|
||||
class ActionEntry(ComponentEntry):
|
||||
"""Action 组件条目"""
|
||||
|
||||
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
|
||||
super().__init__(name, component_type, plugin_id, metadata)
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
component_type: str,
|
||||
plugin_id: str,
|
||||
metadata: Dict[str, Any],
|
||||
chat_scope: str = "all",
|
||||
) -> None:
|
||||
super().__init__(name, component_type, plugin_id, metadata, chat_scope)
|
||||
|
||||
|
||||
class CommandEntry(ComponentEntry):
|
||||
"""Command 组件条目"""
|
||||
|
||||
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
|
||||
super().__init__(name, component_type, plugin_id, metadata)
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
component_type: str,
|
||||
plugin_id: str,
|
||||
metadata: Dict[str, Any],
|
||||
chat_scope: str = "all",
|
||||
) -> None:
|
||||
super().__init__(name, component_type, plugin_id, metadata, chat_scope)
|
||||
self.aliases: List[str] = metadata.get("aliases", [])
|
||||
self.compiled_pattern: Optional[re.Pattern] = None
|
||||
if pattern := metadata.get("command_pattern", ""):
|
||||
@@ -117,7 +154,14 @@ class CommandEntry(ComponentEntry):
|
||||
class ToolEntry(ComponentEntry):
|
||||
"""Tool 组件条目"""
|
||||
|
||||
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
component_type: str,
|
||||
plugin_id: str,
|
||||
metadata: Dict[str, Any],
|
||||
chat_scope: str = "all",
|
||||
) -> None:
|
||||
self.description: str = str(metadata.get("description", "") or "").strip()
|
||||
self.brief_description: str = str(
|
||||
metadata.get("brief_description", self.description) or self.description or f"工具 {name}"
|
||||
@@ -128,7 +172,7 @@ class ToolEntry(ComponentEntry):
|
||||
self.detailed_description: str = detailed_description
|
||||
self.invoke_method: str = str(metadata.get("invoke_method", "plugin.invoke_tool") or "plugin.invoke_tool").strip()
|
||||
self.legacy_component_type: str = str(metadata.get("legacy_component_type", "") or "").strip()
|
||||
super().__init__(name, component_type, plugin_id, metadata)
|
||||
super().__init__(name, component_type, plugin_id, metadata, chat_scope)
|
||||
|
||||
if not self.detailed_description:
|
||||
parameters_schema = self._get_parameters_schema()
|
||||
@@ -197,23 +241,37 @@ class ToolEntry(ComponentEntry):
|
||||
class EventHandlerEntry(ComponentEntry):
|
||||
"""EventHandler 组件条目"""
|
||||
|
||||
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
component_type: str,
|
||||
plugin_id: str,
|
||||
metadata: Dict[str, Any],
|
||||
chat_scope: str = "all",
|
||||
) -> None:
|
||||
self.event_type: str = metadata.get("event_type", "")
|
||||
self.weight: int = metadata.get("weight", 0)
|
||||
self.intercept_message: bool = metadata.get("intercept_message", False)
|
||||
super().__init__(name, component_type, plugin_id, metadata)
|
||||
super().__init__(name, component_type, plugin_id, metadata, chat_scope)
|
||||
|
||||
|
||||
class HookHandlerEntry(ComponentEntry):
|
||||
"""HookHandler 组件条目。"""
|
||||
|
||||
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
component_type: str,
|
||||
plugin_id: str,
|
||||
metadata: Dict[str, Any],
|
||||
chat_scope: str = "all",
|
||||
) -> None:
|
||||
self.hook: str = self._normalize_hook_name(metadata.get("hook", ""))
|
||||
self.mode: str = self._normalize_mode(metadata.get("mode", "blocking"))
|
||||
self.order: str = self._normalize_order(metadata.get("order", "normal"))
|
||||
self.timeout_ms: int = self._normalize_timeout_ms(metadata.get("timeout_ms", 0))
|
||||
self.error_policy: str = self._normalize_error_policy(metadata.get("error_policy", "skip"))
|
||||
super().__init__(name, component_type, plugin_id, metadata)
|
||||
super().__init__(name, component_type, plugin_id, metadata, chat_scope)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_error_policy(raw_value: Any) -> str:
|
||||
@@ -332,13 +390,20 @@ class HookHandlerEntry(ComponentEntry):
|
||||
class MessageGatewayEntry(ComponentEntry):
|
||||
"""MessageGateway 组件条目"""
|
||||
|
||||
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
component_type: str,
|
||||
plugin_id: str,
|
||||
metadata: Dict[str, Any],
|
||||
chat_scope: str = "all",
|
||||
) -> None:
|
||||
self.route_type: str = self._normalize_route_type(metadata.get("route_type", ""))
|
||||
self.platform: str = str(metadata.get("platform", "") or "").strip()
|
||||
self.protocol: str = str(metadata.get("protocol", "") or "").strip()
|
||||
self.account_id: str = str(metadata.get("account_id", "") or "").strip()
|
||||
self.scope: str = str(metadata.get("scope", "") or "").strip()
|
||||
super().__init__(name, component_type, plugin_id, metadata)
|
||||
super().__init__(name, component_type, plugin_id, metadata, chat_scope)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_route_type(raw_value: Any) -> str:
|
||||
@@ -578,6 +643,7 @@ class ComponentRegistry:
|
||||
component_type: str,
|
||||
plugin_id: str,
|
||||
metadata: Dict[str, Any],
|
||||
chat_scope: str = "all",
|
||||
) -> ComponentEntry:
|
||||
"""根据声明构造组件条目。
|
||||
|
||||
@@ -599,18 +665,18 @@ class ComponentRegistry:
|
||||
normalized_metadata = dict(metadata)
|
||||
if normalized_type == ComponentTypes.ACTION:
|
||||
normalized_metadata = self._convert_action_metadata_to_tool_metadata(name, normalized_metadata)
|
||||
component = ToolEntry(name, ComponentTypes.TOOL.value, plugin_id, normalized_metadata)
|
||||
component = ToolEntry(name, ComponentTypes.TOOL.value, plugin_id, normalized_metadata, chat_scope)
|
||||
elif normalized_type == ComponentTypes.COMMAND:
|
||||
component = CommandEntry(name, normalized_type.value, plugin_id, normalized_metadata)
|
||||
component = CommandEntry(name, normalized_type.value, plugin_id, normalized_metadata, chat_scope)
|
||||
elif normalized_type == ComponentTypes.TOOL:
|
||||
component = ToolEntry(name, normalized_type.value, plugin_id, normalized_metadata)
|
||||
component = ToolEntry(name, normalized_type.value, plugin_id, normalized_metadata, chat_scope)
|
||||
elif normalized_type == ComponentTypes.EVENT_HANDLER:
|
||||
component = EventHandlerEntry(name, normalized_type.value, plugin_id, normalized_metadata)
|
||||
component = EventHandlerEntry(name, normalized_type.value, plugin_id, normalized_metadata, chat_scope)
|
||||
elif normalized_type == ComponentTypes.HOOK_HANDLER:
|
||||
component = HookHandlerEntry(name, normalized_type.value, plugin_id, normalized_metadata)
|
||||
component = HookHandlerEntry(name, normalized_type.value, plugin_id, normalized_metadata, chat_scope)
|
||||
self._validate_hook_handler_entry(component)
|
||||
elif normalized_type == ComponentTypes.MESSAGE_GATEWAY:
|
||||
component = MessageGatewayEntry(name, normalized_type.value, plugin_id, normalized_metadata)
|
||||
component = MessageGatewayEntry(name, normalized_type.value, plugin_id, normalized_metadata, chat_scope)
|
||||
else:
|
||||
raise ComponentRegistrationError(
|
||||
f"组件类型 {component_type} 不存在",
|
||||
@@ -662,7 +728,14 @@ class ComponentRegistry:
|
||||
self._by_plugin.setdefault(component.plugin_id, []).append(component)
|
||||
|
||||
# ====== 注册 / 注销 ======
|
||||
def register_component(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> bool:
|
||||
def register_component(
|
||||
self,
|
||||
name: str,
|
||||
component_type: str,
|
||||
plugin_id: str,
|
||||
metadata: Dict[str, Any],
|
||||
chat_scope: str = "all",
|
||||
) -> bool:
|
||||
"""注册单个组件。
|
||||
|
||||
Args:
|
||||
@@ -678,7 +751,7 @@ class ComponentRegistry:
|
||||
ComponentRegistrationError: 组件声明不合法时抛出。
|
||||
"""
|
||||
|
||||
component = self._build_component_entry(name, component_type, plugin_id, metadata)
|
||||
component = self._build_component_entry(name, component_type, plugin_id, metadata, chat_scope)
|
||||
self._add_component_entry(component)
|
||||
return True
|
||||
|
||||
@@ -701,14 +774,19 @@ class ComponentRegistry:
|
||||
|
||||
prepared_components: List[ComponentEntry] = []
|
||||
for component_data in components:
|
||||
raw_metadata = (
|
||||
dict(component_data.get("metadata", {}))
|
||||
if isinstance(component_data.get("metadata"), dict)
|
||||
else {}
|
||||
)
|
||||
chat_scope = str(component_data.get("chat_scope", raw_metadata.pop("chat_scope", "all")) or "all")
|
||||
prepared_components.append(
|
||||
self._build_component_entry(
|
||||
name=str(component_data.get("name", "") or ""),
|
||||
component_type=str(component_data.get("component_type", "") or ""),
|
||||
plugin_id=plugin_id,
|
||||
metadata=component_data.get("metadata", {})
|
||||
if isinstance(component_data.get("metadata"), dict)
|
||||
else {},
|
||||
metadata=raw_metadata,
|
||||
chat_scope=chat_scope,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -733,9 +811,19 @@ class ComponentRegistry:
|
||||
return len(comps)
|
||||
|
||||
# ====== 启用 / 禁用 ======
|
||||
def check_component_enabled(self, component: ComponentEntry, session_id: Optional[str] = None):
|
||||
def check_component_enabled(
|
||||
self,
|
||||
component: ComponentEntry,
|
||||
session_id: Optional[str] = None,
|
||||
is_group_chat: Optional[bool] = None,
|
||||
):
|
||||
if session_id and session_id in component.disabled_session:
|
||||
return False
|
||||
if is_group_chat is not None:
|
||||
if component.chat_scope == "group" and is_group_chat is not True:
|
||||
return False
|
||||
if component.chat_scope == "private" and is_group_chat is not False:
|
||||
return False
|
||||
return component.enabled
|
||||
|
||||
def toggle_component_status(self, full_name: str, enabled: bool, session_id: Optional[str] = None) -> bool:
|
||||
@@ -806,7 +894,12 @@ class ComponentRegistry:
|
||||
return self._components.get(full_name)
|
||||
|
||||
def get_components_by_type(
|
||||
self, component_type: str, *, enabled_only: bool = True, session_id: Optional[str] = None
|
||||
self,
|
||||
component_type: str,
|
||||
*,
|
||||
enabled_only: bool = True,
|
||||
session_id: Optional[str] = None,
|
||||
is_group_chat: Optional[bool] = None,
|
||||
) -> List[ComponentEntry]:
|
||||
"""按类型查询组件
|
||||
|
||||
@@ -830,12 +923,16 @@ class ComponentRegistry:
|
||||
if self._is_legacy_action_component(component)
|
||||
]
|
||||
if enabled_only:
|
||||
return [component for component in action_components if self.check_component_enabled(component, session_id)]
|
||||
return [
|
||||
component
|
||||
for component in action_components
|
||||
if self.check_component_enabled(component, session_id, is_group_chat)
|
||||
]
|
||||
return action_components
|
||||
|
||||
type_dict = self._by_type.get(comp_type, {})
|
||||
if enabled_only:
|
||||
return [c for c in type_dict.values() if self.check_component_enabled(c, session_id)]
|
||||
return [c for c in type_dict.values() if self.check_component_enabled(c, session_id, is_group_chat)]
|
||||
return list(type_dict.values())
|
||||
|
||||
def get_components_by_plugin(
|
||||
|
||||
@@ -4,7 +4,14 @@ from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolProvider, ToolSpec
|
||||
from src.core.tooling import (
|
||||
ToolAvailabilityContext,
|
||||
ToolExecutionContext,
|
||||
ToolExecutionResult,
|
||||
ToolInvocation,
|
||||
ToolProvider,
|
||||
ToolSpec,
|
||||
)
|
||||
|
||||
from .component_query import component_query_service
|
||||
|
||||
@@ -15,10 +22,13 @@ class PluginToolProvider(ToolProvider):
|
||||
provider_name = "plugin_runtime"
|
||||
provider_type = "plugin"
|
||||
|
||||
async def list_tools(self) -> list[ToolSpec]:
|
||||
async def list_tools(
|
||||
self,
|
||||
context: Optional[ToolAvailabilityContext] = None,
|
||||
) -> list[ToolSpec]:
|
||||
"""列出插件运行时当前可用的工具声明。"""
|
||||
|
||||
return list(component_query_service.get_llm_available_tool_specs().values())
|
||||
return list(component_query_service.get_llm_available_tool_specs(context=context).values())
|
||||
|
||||
async def invoke(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user