Merge remote-tracking branch 'upstream/r-dev' into sync/pr-1564-upstream-20260331
# Conflicts: # src/chat/brain_chat/PFC/conversation.py # src/chat/brain_chat/PFC/pfc_KnowledgeFetcher.py # src/chat/knowledge/lpmm_ops.py
This commit is contained in:
@@ -29,6 +29,19 @@ class _RuntimeComponentManagerProtocol(Protocol):
|
||||
|
||||
def _build_api_unavailable_error(self, entry: "APIEntry") -> str: ...
|
||||
|
||||
def _collect_api_reference_matches(
|
||||
self,
|
||||
caller_plugin_id: str,
|
||||
normalized_api_name: str,
|
||||
normalized_version: str,
|
||||
) -> tuple[List[tuple["PluginSupervisor", "APIEntry"]], List[tuple["PluginSupervisor", "APIEntry"]], bool]: ...
|
||||
|
||||
def _collect_api_toggle_reference_matches(
|
||||
self,
|
||||
normalized_name: str,
|
||||
normalized_version: str,
|
||||
) -> List[tuple["PluginSupervisor", "APIEntry"]]: ...
|
||||
|
||||
def _get_supervisor_for_plugin(self, plugin_id: str) -> Optional["PluginSupervisor"]: ...
|
||||
|
||||
def _resolve_api_target(
|
||||
@@ -58,6 +71,73 @@ class _RuntimeComponentManagerProtocol(Protocol):
|
||||
|
||||
|
||||
class RuntimeComponentCapabilityMixin:
|
||||
def _collect_api_reference_matches(
|
||||
self: _RuntimeComponentManagerProtocol,
|
||||
caller_plugin_id: str,
|
||||
normalized_api_name: str,
|
||||
normalized_version: str,
|
||||
) -> tuple[List[tuple["PluginSupervisor", "APIEntry"]], List[tuple["PluginSupervisor", "APIEntry"]], bool]:
|
||||
"""按 API 完整名或短名精确收集匹配项。
|
||||
|
||||
该辅助方法用于兼容名字中本身包含 ``.`` 的 API。对于这类 API,
|
||||
不能简单按最后一个点号拆成 ``plugin_id.api_name``。
|
||||
|
||||
Args:
|
||||
caller_plugin_id: 调用方插件 ID。
|
||||
normalized_api_name: 已规范化的 API 名称。
|
||||
normalized_version: 已规范化的版本号。
|
||||
|
||||
Returns:
|
||||
tuple[List[tuple[PluginSupervisor, APIEntry]], List[tuple[PluginSupervisor, APIEntry]], bool]:
|
||||
依次为可见且启用的匹配项、可见但已禁用的匹配项、是否存在不可见匹配项。
|
||||
"""
|
||||
|
||||
visible_enabled_matches: List[tuple["PluginSupervisor", "APIEntry"]] = []
|
||||
visible_disabled_matches: List[tuple["PluginSupervisor", "APIEntry"]] = []
|
||||
hidden_match_exists = False
|
||||
|
||||
for supervisor in self.supervisors:
|
||||
for entry in supervisor.api_registry.get_apis(
|
||||
version=normalized_version,
|
||||
enabled_only=False,
|
||||
):
|
||||
if entry.name != normalized_api_name and entry.full_name != normalized_api_name:
|
||||
continue
|
||||
if self._is_api_visible_to_plugin(entry, caller_plugin_id):
|
||||
if entry.enabled:
|
||||
visible_enabled_matches.append((supervisor, entry))
|
||||
else:
|
||||
visible_disabled_matches.append((supervisor, entry))
|
||||
else:
|
||||
hidden_match_exists = True
|
||||
|
||||
return visible_enabled_matches, visible_disabled_matches, hidden_match_exists
|
||||
|
||||
def _collect_api_toggle_reference_matches(
|
||||
self: _RuntimeComponentManagerProtocol,
|
||||
normalized_name: str,
|
||||
normalized_version: str,
|
||||
) -> List[tuple["PluginSupervisor", "APIEntry"]]:
|
||||
"""按 API 完整名或短名精确收集启停操作匹配项。
|
||||
|
||||
Args:
|
||||
normalized_name: 已规范化的 API 名称。
|
||||
normalized_version: 已规范化的版本号。
|
||||
|
||||
Returns:
|
||||
List[tuple[PluginSupervisor, APIEntry]]: 匹配到的 API 条目列表。
|
||||
"""
|
||||
|
||||
matches: List[tuple["PluginSupervisor", "APIEntry"]] = []
|
||||
for supervisor in self.supervisors:
|
||||
for entry in supervisor.api_registry.get_apis(
|
||||
version=normalized_version,
|
||||
enabled_only=False,
|
||||
):
|
||||
if entry.name == normalized_name or entry.full_name == normalized_name:
|
||||
matches.append((supervisor, entry))
|
||||
return matches
|
||||
|
||||
@staticmethod
|
||||
def _normalize_component_type(component_type: str) -> str:
|
||||
"""规范化组件类型名称。
|
||||
@@ -69,7 +149,10 @@ class RuntimeComponentCapabilityMixin:
|
||||
str: 统一转为大写后的组件类型名。
|
||||
"""
|
||||
|
||||
return str(component_type or "").strip().upper()
|
||||
normalized_component_type = str(component_type or "").strip().upper()
|
||||
if normalized_component_type == "ACTION":
|
||||
return "TOOL"
|
||||
return normalized_component_type
|
||||
|
||||
@classmethod
|
||||
def _is_api_component_type(cls, component_type: str) -> bool:
|
||||
@@ -190,6 +273,20 @@ class RuntimeComponentCapabilityMixin:
|
||||
if not normalized_api_name:
|
||||
return None, None, "缺少必要参数 api_name"
|
||||
|
||||
exact_visible_enabled_matches, exact_visible_disabled_matches, exact_hidden_match_exists = (
|
||||
self._collect_api_reference_matches(caller_plugin_id, normalized_api_name, normalized_version)
|
||||
)
|
||||
if len(exact_visible_enabled_matches) == 1:
|
||||
return exact_visible_enabled_matches[0][0], exact_visible_enabled_matches[0][1], None
|
||||
if len(exact_visible_enabled_matches) > 1:
|
||||
return None, None, f"API 名称不唯一: {normalized_api_name},请显式指定 version"
|
||||
if exact_visible_disabled_matches:
|
||||
if len(exact_visible_disabled_matches) == 1:
|
||||
return None, None, self._build_api_unavailable_error(exact_visible_disabled_matches[0][1])
|
||||
return None, None, f"API {normalized_api_name} 存在多个已下线版本,请显式指定 version"
|
||||
if exact_hidden_match_exists:
|
||||
return None, None, f"API {normalized_api_name} 未公开,禁止跨插件调用"
|
||||
|
||||
if "." in normalized_api_name:
|
||||
target_plugin_id, target_api_name = normalized_api_name.rsplit(".", 1)
|
||||
try:
|
||||
@@ -207,9 +304,7 @@ class RuntimeComponentCapabilityMixin:
|
||||
enabled_only=False,
|
||||
)
|
||||
visible_enabled_entries = [
|
||||
entry
|
||||
for entry in entries
|
||||
if self._is_api_visible_to_plugin(entry, caller_plugin_id) and entry.enabled
|
||||
entry for entry in entries if self._is_api_visible_to_plugin(entry, caller_plugin_id) and entry.enabled
|
||||
]
|
||||
visible_disabled_entries = [
|
||||
entry
|
||||
@@ -281,6 +376,12 @@ class RuntimeComponentCapabilityMixin:
|
||||
if not normalized_name:
|
||||
return None, None, "缺少必要参数 name"
|
||||
|
||||
exact_matches = self._collect_api_toggle_reference_matches(normalized_name, normalized_version)
|
||||
if len(exact_matches) == 1:
|
||||
return exact_matches[0][0], exact_matches[0][1], None
|
||||
if len(exact_matches) > 1:
|
||||
return None, None, f"API 名称不唯一: {normalized_name},请显式指定 version"
|
||||
|
||||
if "." in normalized_name:
|
||||
plugin_id, api_name = normalized_name.rsplit(".", 1)
|
||||
try:
|
||||
|
||||
@@ -1,33 +1,80 @@
|
||||
from typing import Any, Dict
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.llm_models.payload_content.tool_option import ToolCall
|
||||
|
||||
logger = get_logger("plugin_runtime.integration")
|
||||
|
||||
|
||||
def _get_nested_config_value(source: Any, key: str, default: Any = None) -> Any:
|
||||
"""从嵌套对象或字典中读取配置值。
|
||||
|
||||
Args:
|
||||
source: 配置对象或字典。
|
||||
key: 以点号分隔的路径。
|
||||
default: 未命中时返回的默认值。
|
||||
|
||||
Returns:
|
||||
Any: 命中的值;读取失败时返回默认值。
|
||||
"""
|
||||
current = source
|
||||
try:
|
||||
for part in key.split("."):
|
||||
if isinstance(current, dict) and part in current:
|
||||
current = current[part]
|
||||
elif hasattr(current, part):
|
||||
continue
|
||||
if hasattr(current, part):
|
||||
current = getattr(current, part)
|
||||
else:
|
||||
raise KeyError(part)
|
||||
continue
|
||||
raise KeyError(part)
|
||||
return current
|
||||
except Exception:
|
||||
return default
|
||||
|
||||
|
||||
def _normalize_prompt_arg(prompt: Any) -> str | List[Dict[str, Any]]:
|
||||
"""校验并规范化插件传入的提示参数。
|
||||
|
||||
Args:
|
||||
prompt: 原始提示参数。
|
||||
|
||||
Returns:
|
||||
str | List[Dict[str, Any]]: 规范化后的提示输入。
|
||||
|
||||
Raises:
|
||||
ValueError: 提示参数缺失或结构不受支持时抛出。
|
||||
"""
|
||||
if isinstance(prompt, str):
|
||||
if not prompt.strip():
|
||||
raise ValueError("缺少必要参数 prompt")
|
||||
return prompt
|
||||
if isinstance(prompt, list) and prompt:
|
||||
for index, prompt_message in enumerate(prompt, start=1):
|
||||
if not isinstance(prompt_message, dict):
|
||||
raise ValueError(f"prompt 第 {index} 项必须为字典")
|
||||
return prompt
|
||||
raise ValueError("缺少必要参数 prompt")
|
||||
|
||||
|
||||
class RuntimeCoreCapabilityMixin:
|
||||
"""插件运行时的核心能力混入。"""
|
||||
|
||||
async def _cap_send_text(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""向指定流发送文本消息。
|
||||
|
||||
Args:
|
||||
plugin_id: 插件标识。
|
||||
capability: 能力名称。
|
||||
args: 能力调用参数。
|
||||
|
||||
Returns:
|
||||
Any: 能力执行结果。
|
||||
"""
|
||||
del plugin_id, capability
|
||||
from src.services import send_service as send_api
|
||||
|
||||
text: str = args.get("text", "")
|
||||
stream_id: str = args.get("stream_id", "")
|
||||
text = str(args.get("text", ""))
|
||||
stream_id = str(args.get("stream_id", ""))
|
||||
if not text or not stream_id:
|
||||
return {"success": False, "error": "缺少必要参数 text 或 stream_id"}
|
||||
|
||||
@@ -35,20 +82,31 @@ class RuntimeCoreCapabilityMixin:
|
||||
result = await send_api.text_to_stream(
|
||||
text=text,
|
||||
stream_id=stream_id,
|
||||
typing=args.get("typing", False),
|
||||
set_reply=args.get("set_reply", False),
|
||||
storage_message=args.get("storage_message", True),
|
||||
typing=bool(args.get("typing", False)),
|
||||
set_reply=bool(args.get("set_reply", False)),
|
||||
storage_message=bool(args.get("storage_message", True)),
|
||||
)
|
||||
return {"success": result}
|
||||
except Exception as e:
|
||||
logger.error(f"[cap.send.text] 执行失败: {e}", exc_info=True)
|
||||
return {"success": False, "error": str(e)}
|
||||
except Exception as exc:
|
||||
logger.error(f"[cap.send.text] 执行失败: {exc}", exc_info=True)
|
||||
return {"success": False, "error": str(exc)}
|
||||
|
||||
async def _cap_send_emoji(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""向指定流发送表情图片。
|
||||
|
||||
Args:
|
||||
plugin_id: 插件标识。
|
||||
capability: 能力名称。
|
||||
args: 能力调用参数。
|
||||
|
||||
Returns:
|
||||
Any: 能力执行结果。
|
||||
"""
|
||||
del plugin_id, capability
|
||||
from src.services import send_service as send_api
|
||||
|
||||
emoji_base64: str = args.get("emoji_base64", "")
|
||||
stream_id: str = args.get("stream_id", "")
|
||||
emoji_base64 = str(args.get("emoji_base64", ""))
|
||||
stream_id = str(args.get("stream_id", ""))
|
||||
if not emoji_base64 or not stream_id:
|
||||
return {"success": False, "error": "缺少必要参数 emoji_base64 或 stream_id"}
|
||||
|
||||
@@ -56,18 +114,29 @@ class RuntimeCoreCapabilityMixin:
|
||||
result = await send_api.emoji_to_stream(
|
||||
emoji_base64=emoji_base64,
|
||||
stream_id=stream_id,
|
||||
storage_message=args.get("storage_message", True),
|
||||
storage_message=bool(args.get("storage_message", True)),
|
||||
)
|
||||
return {"success": result}
|
||||
except Exception as e:
|
||||
logger.error(f"[cap.send.emoji] 执行失败: {e}", exc_info=True)
|
||||
return {"success": False, "error": str(e)}
|
||||
except Exception as exc:
|
||||
logger.error(f"[cap.send.emoji] 执行失败: {exc}", exc_info=True)
|
||||
return {"success": False, "error": str(exc)}
|
||||
|
||||
async def _cap_send_image(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""向指定流发送图片。
|
||||
|
||||
Args:
|
||||
plugin_id: 插件标识。
|
||||
capability: 能力名称。
|
||||
args: 能力调用参数。
|
||||
|
||||
Returns:
|
||||
Any: 能力执行结果。
|
||||
"""
|
||||
del plugin_id, capability
|
||||
from src.services import send_service as send_api
|
||||
|
||||
image_base64: str = args.get("image_base64", "")
|
||||
stream_id: str = args.get("stream_id", "")
|
||||
image_base64 = str(args.get("image_base64", ""))
|
||||
stream_id = str(args.get("stream_id", ""))
|
||||
if not image_base64 or not stream_id:
|
||||
return {"success": False, "error": "缺少必要参数 image_base64 或 stream_id"}
|
||||
|
||||
@@ -75,18 +144,29 @@ class RuntimeCoreCapabilityMixin:
|
||||
result = await send_api.image_to_stream(
|
||||
image_base64=image_base64,
|
||||
stream_id=stream_id,
|
||||
storage_message=args.get("storage_message", True),
|
||||
storage_message=bool(args.get("storage_message", True)),
|
||||
)
|
||||
return {"success": result}
|
||||
except Exception as e:
|
||||
logger.error(f"[cap.send.image] 执行失败: {e}", exc_info=True)
|
||||
return {"success": False, "error": str(e)}
|
||||
except Exception as exc:
|
||||
logger.error(f"[cap.send.image] 执行失败: {exc}", exc_info=True)
|
||||
return {"success": False, "error": str(exc)}
|
||||
|
||||
async def _cap_send_command(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""向指定流发送命令消息。
|
||||
|
||||
Args:
|
||||
plugin_id: 插件标识。
|
||||
capability: 能力名称。
|
||||
args: 能力调用参数。
|
||||
|
||||
Returns:
|
||||
Any: 能力执行结果。
|
||||
"""
|
||||
del plugin_id, capability
|
||||
from src.services import send_service as send_api
|
||||
|
||||
command = args.get("command", "")
|
||||
stream_id: str = args.get("stream_id", "")
|
||||
command = str(args.get("command", ""))
|
||||
stream_id = str(args.get("stream_id", ""))
|
||||
if not command or not stream_id:
|
||||
return {"success": False, "error": "缺少必要参数 command 或 stream_id"}
|
||||
|
||||
@@ -95,22 +175,33 @@ class RuntimeCoreCapabilityMixin:
|
||||
message_type="command",
|
||||
content=command,
|
||||
stream_id=stream_id,
|
||||
storage_message=args.get("storage_message", True),
|
||||
display_message=args.get("display_message", ""),
|
||||
storage_message=bool(args.get("storage_message", True)),
|
||||
display_message=str(args.get("display_message", "")),
|
||||
)
|
||||
return {"success": result}
|
||||
except Exception as e:
|
||||
logger.error(f"[cap.send.command] 执行失败: {e}", exc_info=True)
|
||||
return {"success": False, "error": str(e)}
|
||||
except Exception as exc:
|
||||
logger.error(f"[cap.send.command] 执行失败: {exc}", exc_info=True)
|
||||
return {"success": False, "error": str(exc)}
|
||||
|
||||
async def _cap_send_custom(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""向指定流发送自定义消息。
|
||||
|
||||
Args:
|
||||
plugin_id: 插件标识。
|
||||
capability: 能力名称。
|
||||
args: 能力调用参数。
|
||||
|
||||
Returns:
|
||||
Any: 能力执行结果。
|
||||
"""
|
||||
del plugin_id, capability
|
||||
from src.services import send_service as send_api
|
||||
|
||||
message_type: str = args.get("message_type", "") or args.get("custom_type", "")
|
||||
message_type = str(args.get("message_type", "") or args.get("custom_type", ""))
|
||||
content = args.get("content")
|
||||
if content is None:
|
||||
content = args.get("data", "")
|
||||
stream_id: str = args.get("stream_id", "")
|
||||
stream_id = str(args.get("stream_id", ""))
|
||||
if not message_type or not stream_id:
|
||||
return {"success": False, "error": "缺少必要参数 message_type 或 stream_id"}
|
||||
|
||||
@@ -119,114 +210,116 @@ class RuntimeCoreCapabilityMixin:
|
||||
message_type=message_type,
|
||||
content=content,
|
||||
stream_id=stream_id,
|
||||
display_message=args.get("display_message", ""),
|
||||
typing=args.get("typing", False),
|
||||
storage_message=args.get("storage_message", True),
|
||||
display_message=str(args.get("display_message", "")),
|
||||
typing=bool(args.get("typing", False)),
|
||||
storage_message=bool(args.get("storage_message", True)),
|
||||
)
|
||||
return {"success": result}
|
||||
except Exception as e:
|
||||
logger.error(f"[cap.send.custom] 执行失败: {e}", exc_info=True)
|
||||
return {"success": False, "error": str(e)}
|
||||
except Exception as exc:
|
||||
logger.error(f"[cap.send.custom] 执行失败: {exc}", exc_info=True)
|
||||
return {"success": False, "error": str(exc)}
|
||||
|
||||
async def _cap_llm_generate(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""执行无工具的 LLM 生成能力。
|
||||
|
||||
Args:
|
||||
plugin_id: 插件标识。
|
||||
capability: 能力名称。
|
||||
args: 能力调用参数。
|
||||
|
||||
Returns:
|
||||
Any: 标准化后的 LLM 响应结构。
|
||||
"""
|
||||
del capability
|
||||
from src.services import llm_service as llm_api
|
||||
|
||||
prompt: str = args.get("prompt", "")
|
||||
if not prompt:
|
||||
return {"success": False, "error": "缺少必要参数 prompt"}
|
||||
|
||||
model_name: str = args.get("model", "") or args.get("model_name", "")
|
||||
temperature = args.get("temperature")
|
||||
max_tokens = args.get("max_tokens")
|
||||
|
||||
try:
|
||||
models = llm_api.get_available_models()
|
||||
if model_name and model_name in models:
|
||||
model_config = models[model_name]
|
||||
else:
|
||||
if not models:
|
||||
return {"success": False, "error": "没有可用的模型配置"}
|
||||
model_config = next(iter(models.values()))
|
||||
|
||||
success, response, reasoning, used_model = await llm_api.generate_with_model(
|
||||
prompt=prompt,
|
||||
model_config=model_config,
|
||||
request_type=f"plugin.{plugin_id}",
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
prompt = _normalize_prompt_arg(args.get("prompt"))
|
||||
task_name = llm_api.resolve_task_name(str(args.get("model", "") or args.get("model_name", "")))
|
||||
result = await llm_api.generate(
|
||||
llm_api.LLMServiceRequest(
|
||||
task_name=task_name,
|
||||
request_type=f"plugin.{plugin_id}",
|
||||
prompt=prompt,
|
||||
temperature=args.get("temperature"),
|
||||
max_tokens=args.get("max_tokens"),
|
||||
)
|
||||
)
|
||||
return {
|
||||
"success": success,
|
||||
"response": response,
|
||||
"reasoning": reasoning,
|
||||
"model_name": used_model,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"[cap.llm.generate] 执行失败: {e}", exc_info=True)
|
||||
return {"success": False, "error": str(e)}
|
||||
return result.to_capability_payload()
|
||||
except Exception as exc:
|
||||
logger.error(f"[cap.llm.generate] 执行失败: {exc}", exc_info=True)
|
||||
return {"success": False, "error": str(exc)}
|
||||
|
||||
async def _cap_llm_generate_with_tools(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""执行带工具的 LLM 生成能力。
|
||||
|
||||
Args:
|
||||
plugin_id: 插件标识。
|
||||
capability: 能力名称。
|
||||
args: 能力调用参数。
|
||||
|
||||
Returns:
|
||||
Any: 标准化后的 LLM 响应结构。
|
||||
"""
|
||||
del capability
|
||||
from src.services import llm_service as llm_api
|
||||
|
||||
prompt: str = args.get("prompt", "")
|
||||
if not prompt:
|
||||
return {"success": False, "error": "缺少必要参数 prompt"}
|
||||
|
||||
model_name: str = args.get("model", "") or args.get("model_name", "")
|
||||
tool_options = args.get("tools") or args.get("tool_options")
|
||||
temperature = args.get("temperature")
|
||||
max_tokens = args.get("max_tokens")
|
||||
if tool_options is not None and not isinstance(tool_options, list):
|
||||
return {"success": False, "error": "tools 必须为列表"}
|
||||
|
||||
try:
|
||||
models = llm_api.get_available_models()
|
||||
if model_name and model_name in models:
|
||||
model_config = models[model_name]
|
||||
else:
|
||||
if not models:
|
||||
return {"success": False, "error": "没有可用的模型配置"}
|
||||
model_config = next(iter(models.values()))
|
||||
|
||||
success, response, reasoning, used_model, tool_calls = await llm_api.generate_with_model_with_tools(
|
||||
prompt=prompt,
|
||||
model_config=model_config,
|
||||
tool_options=tool_options,
|
||||
request_type=f"plugin.{plugin_id}",
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
prompt = _normalize_prompt_arg(args.get("prompt"))
|
||||
task_name = llm_api.resolve_task_name(str(args.get("model", "") or args.get("model_name", "")))
|
||||
result = await llm_api.generate(
|
||||
llm_api.LLMServiceRequest(
|
||||
task_name=task_name,
|
||||
request_type=f"plugin.{plugin_id}",
|
||||
prompt=prompt,
|
||||
tool_options=tool_options,
|
||||
temperature=args.get("temperature"),
|
||||
max_tokens=args.get("max_tokens"),
|
||||
)
|
||||
)
|
||||
serialized_tool_calls = None
|
||||
if tool_calls:
|
||||
serialized_tool_calls = [
|
||||
{
|
||||
"id": tool_call.call_id,
|
||||
"function": {"name": tool_call.func_name, "arguments": tool_call.args or {}},
|
||||
}
|
||||
for tool_call in tool_calls
|
||||
if isinstance(tool_call, ToolCall)
|
||||
]
|
||||
return {
|
||||
"success": success,
|
||||
"response": response,
|
||||
"reasoning": reasoning,
|
||||
"model_name": used_model,
|
||||
"tool_calls": serialized_tool_calls,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"[cap.llm.generate_with_tools] 执行失败: {e}", exc_info=True)
|
||||
return {"success": False, "error": str(e)}
|
||||
return result.to_capability_payload()
|
||||
except Exception as exc:
|
||||
logger.error(f"[cap.llm.generate_with_tools] 执行失败: {exc}", exc_info=True)
|
||||
return {"success": False, "error": str(exc)}
|
||||
|
||||
async def _cap_llm_get_available_models(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""获取当前宿主可用的模型任务列表。
|
||||
|
||||
Args:
|
||||
plugin_id: 插件标识。
|
||||
capability: 能力名称。
|
||||
args: 能力调用参数。
|
||||
|
||||
Returns:
|
||||
Any: 可用模型列表。
|
||||
"""
|
||||
del plugin_id, capability, args
|
||||
from src.services import llm_service as llm_api
|
||||
|
||||
try:
|
||||
models = llm_api.get_available_models()
|
||||
return {"success": True, "models": list(models.keys())}
|
||||
except Exception as e:
|
||||
logger.error(f"[cap.llm.get_available_models] 执行失败: {e}", exc_info=True)
|
||||
return {"success": False, "error": str(e)}
|
||||
except Exception as exc:
|
||||
logger.error(f"[cap.llm.get_available_models] 执行失败: {exc}", exc_info=True)
|
||||
return {"success": False, "error": str(exc)}
|
||||
|
||||
async def _cap_config_get(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
key: str = args.get("key", "")
|
||||
"""读取宿主全局配置中的单个字段。
|
||||
|
||||
Args:
|
||||
plugin_id: 插件标识。
|
||||
capability: 能力名称。
|
||||
args: 能力调用参数。
|
||||
|
||||
Returns:
|
||||
Any: 配置读取结果。
|
||||
"""
|
||||
del plugin_id, capability
|
||||
key = str(args.get("key", ""))
|
||||
default = args.get("default")
|
||||
if not key:
|
||||
return {"success": False, "value": None, "error": "缺少必要参数 key"}
|
||||
@@ -234,37 +327,57 @@ class RuntimeCoreCapabilityMixin:
|
||||
try:
|
||||
value = _get_nested_config_value(global_config, key, default)
|
||||
return {"success": True, "value": value}
|
||||
except Exception as e:
|
||||
return {"success": False, "value": None, "error": str(e)}
|
||||
except Exception as exc:
|
||||
return {"success": False, "value": None, "error": str(exc)}
|
||||
|
||||
async def _cap_config_get_plugin(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""读取指定插件的配置。
|
||||
|
||||
Args:
|
||||
plugin_id: 当前插件标识。
|
||||
capability: 能力名称。
|
||||
args: 能力调用参数。
|
||||
|
||||
Returns:
|
||||
Any: 配置读取结果。
|
||||
"""
|
||||
del capability
|
||||
from src.plugin_runtime.component_query import component_query_service
|
||||
|
||||
plugin_name: str = args.get("plugin_name", plugin_id)
|
||||
key: str = args.get("key", "")
|
||||
plugin_name = str(args.get("plugin_name", plugin_id))
|
||||
key = str(args.get("key", ""))
|
||||
default = args.get("default")
|
||||
|
||||
try:
|
||||
config = component_query_service.get_plugin_config(plugin_name)
|
||||
if config is None:
|
||||
return {"success": False, "value": default, "error": f"未找到插件 {plugin_name} 的配置"}
|
||||
|
||||
if key:
|
||||
value = _get_nested_config_value(config, key, default)
|
||||
return {"success": True, "value": value}
|
||||
|
||||
return {"success": True, "value": config}
|
||||
except Exception as e:
|
||||
return {"success": False, "value": default, "error": str(e)}
|
||||
except Exception as exc:
|
||||
return {"success": False, "value": default, "error": str(exc)}
|
||||
|
||||
async def _cap_config_get_all(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""读取指定插件的全部配置。
|
||||
|
||||
Args:
|
||||
plugin_id: 当前插件标识。
|
||||
capability: 能力名称。
|
||||
args: 能力调用参数。
|
||||
|
||||
Returns:
|
||||
Any: 配置读取结果。
|
||||
"""
|
||||
del capability
|
||||
from src.plugin_runtime.component_query import component_query_service
|
||||
|
||||
plugin_name: str = args.get("plugin_name", plugin_id)
|
||||
plugin_name = str(args.get("plugin_name", plugin_id))
|
||||
try:
|
||||
config = component_query_service.get_plugin_config(plugin_name)
|
||||
if config is None:
|
||||
return {"success": True, "value": {}}
|
||||
return {"success": True, "value": config}
|
||||
except Exception as e:
|
||||
return {"success": False, "value": {}, "error": str(e)}
|
||||
except Exception as exc:
|
||||
return {"success": False, "value": {}, "error": str(exc)}
|
||||
|
||||
@@ -1,16 +1,23 @@
|
||||
"""插件运行时统一组件查询服务。
|
||||
|
||||
该模块统一从插件运行时的 Host ComponentRegistry 中聚合只读视图,
|
||||
供 HFC/PFC、Planner、ToolExecutor 和运行时能力层查询与调用。
|
||||
供 HFC、ToolExecutor 和运行时能力层查询与调用。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional, Tuple, cast
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.core.tooling import (
|
||||
ToolExecutionContext,
|
||||
ToolExecutionResult,
|
||||
ToolInvocation,
|
||||
ToolSpec,
|
||||
build_tool_detailed_description,
|
||||
)
|
||||
from src.core.types import ActionActivationType, ActionInfo, CommandInfo, ComponentInfo, ComponentType, ToolInfo
|
||||
from src.llm_models.payload_content.tool_option import ToolParamType
|
||||
from src.llm_models.payload_content.tool_option import normalize_tool_option
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.plugin_runtime.host.component_registry import ActionEntry, CommandEntry, ComponentEntry, ToolEntry
|
||||
@@ -28,13 +35,6 @@ _HOST_COMPONENT_TYPE_MAP: Dict[ComponentType, str] = {
|
||||
ComponentType.COMMAND: "COMMAND",
|
||||
ComponentType.TOOL: "TOOL",
|
||||
}
|
||||
_TOOL_PARAM_TYPE_MAP: Dict[str, ToolParamType] = {
|
||||
"string": ToolParamType.STRING,
|
||||
"integer": ToolParamType.INTEGER,
|
||||
"float": ToolParamType.FLOAT,
|
||||
"boolean": ToolParamType.BOOLEAN,
|
||||
"bool": ToolParamType.BOOLEAN,
|
||||
}
|
||||
|
||||
|
||||
class ComponentQueryService:
|
||||
@@ -146,36 +146,25 @@ class ComponentQueryService:
|
||||
metadata = dict(entry.metadata)
|
||||
raw_action_parameters = metadata.get("action_parameters")
|
||||
action_parameters = (
|
||||
{
|
||||
str(param_name): str(param_description)
|
||||
for param_name, param_description in raw_action_parameters.items()
|
||||
}
|
||||
{str(param_name): str(param_description) for param_name, param_description in raw_action_parameters.items()}
|
||||
if isinstance(raw_action_parameters, dict)
|
||||
else {}
|
||||
)
|
||||
action_require = [
|
||||
str(item)
|
||||
for item in (metadata.get("action_require") or [])
|
||||
if item is not None and str(item).strip()
|
||||
str(item) for item in (metadata.get("action_require") or []) if item is not None and str(item).strip()
|
||||
]
|
||||
associated_types = [
|
||||
str(item)
|
||||
for item in (metadata.get("associated_types") or [])
|
||||
if item is not None and str(item).strip()
|
||||
str(item) for item in (metadata.get("associated_types") or []) if item is not None and str(item).strip()
|
||||
]
|
||||
activation_keywords = [
|
||||
str(item)
|
||||
for item in (metadata.get("activation_keywords") or [])
|
||||
if item is not None and str(item).strip()
|
||||
str(item) for item in (metadata.get("activation_keywords") or []) if item is not None and str(item).strip()
|
||||
]
|
||||
|
||||
return ActionInfo(
|
||||
name=entry.name,
|
||||
component_type=ComponentType.ACTION,
|
||||
description=str(metadata.get("description", "") or ""),
|
||||
enabled=bool(entry.enabled),
|
||||
plugin_name=entry.plugin_id,
|
||||
metadata=metadata,
|
||||
action_parameters=action_parameters,
|
||||
action_require=action_require,
|
||||
associated_types=associated_types,
|
||||
@@ -202,72 +191,48 @@ class ComponentQueryService:
|
||||
metadata = dict(entry.metadata)
|
||||
return CommandInfo(
|
||||
name=entry.name,
|
||||
component_type=ComponentType.COMMAND,
|
||||
description=str(metadata.get("description", "") or ""),
|
||||
enabled=bool(entry.enabled),
|
||||
plugin_name=entry.plugin_id,
|
||||
metadata=metadata,
|
||||
command_pattern=str(metadata.get("command_pattern", "") or ""),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _coerce_tool_param_type(raw_value: Any) -> ToolParamType:
|
||||
"""规范化工具参数类型。
|
||||
|
||||
Args:
|
||||
raw_value: 原始工具参数类型值。
|
||||
|
||||
Returns:
|
||||
ToolParamType: 规范化后的工具参数类型。
|
||||
"""
|
||||
|
||||
normalized_value = str(raw_value or "").strip().lower()
|
||||
return _TOOL_PARAM_TYPE_MAP.get(normalized_value, ToolParamType.STRING)
|
||||
|
||||
@staticmethod
|
||||
def _build_tool_parameters(entry: "ToolEntry") -> list[tuple[str, ToolParamType, str, bool, list[str] | None]]:
|
||||
"""将运行时工具参数元数据转换为核心 ToolInfo 参数列表。
|
||||
def _build_tool_definition(entry: "ToolEntry") -> dict[str, Any]:
|
||||
"""将运行时 Tool 条目转换为原始工具定义字典。
|
||||
|
||||
Args:
|
||||
entry: 插件运行时中的 Tool 条目。
|
||||
|
||||
Returns:
|
||||
list[tuple[str, ToolParamType, str, bool, list[str] | None]]: 转换后的参数列表。
|
||||
dict[str, Any]: 可交给 `normalize_tool_option()` 的原始工具定义。
|
||||
"""
|
||||
raw_definition: dict[str, Any] = {
|
||||
"name": entry.name,
|
||||
"description": entry.description,
|
||||
}
|
||||
if isinstance(entry.parameters_raw, dict) and entry.parameters_raw:
|
||||
raw_definition["parameters_schema"] = entry.parameters_raw
|
||||
return raw_definition
|
||||
if isinstance(entry.parameters, list) and entry.parameters:
|
||||
raw_definition["parameters"] = entry.parameters
|
||||
return raw_definition
|
||||
if isinstance(entry.parameters_raw, list) and entry.parameters_raw:
|
||||
raw_definition["parameters"] = entry.parameters_raw
|
||||
return raw_definition
|
||||
return raw_definition
|
||||
|
||||
structured_parameters = entry.parameters if isinstance(entry.parameters, list) else []
|
||||
if not structured_parameters and isinstance(entry.parameters_raw, dict):
|
||||
structured_parameters = [
|
||||
{"name": key, **value}
|
||||
for key, value in entry.parameters_raw.items()
|
||||
if isinstance(value, dict)
|
||||
]
|
||||
@staticmethod
|
||||
def _build_tool_parameters_schema(entry: "ToolEntry") -> dict[str, Any] | None:
|
||||
"""将运行时 Tool 条目转换为对象级参数 Schema。
|
||||
|
||||
normalized_parameters: list[tuple[str, ToolParamType, str, bool, list[str] | None]] = []
|
||||
for parameter in structured_parameters:
|
||||
if not isinstance(parameter, dict):
|
||||
continue
|
||||
Args:
|
||||
entry: 插件运行时中的 Tool 条目。
|
||||
|
||||
parameter_name = str(parameter.get("name", "") or "").strip()
|
||||
if not parameter_name:
|
||||
continue
|
||||
|
||||
enum_values = parameter.get("enum")
|
||||
normalized_enum_values = (
|
||||
[str(item) for item in enum_values if item is not None]
|
||||
if isinstance(enum_values, list)
|
||||
else None
|
||||
)
|
||||
normalized_parameters.append(
|
||||
(
|
||||
parameter_name,
|
||||
ComponentQueryService._coerce_tool_param_type(parameter.get("param_type") or parameter.get("type")),
|
||||
str(parameter.get("description", "") or ""),
|
||||
bool(parameter.get("required", True)),
|
||||
normalized_enum_values,
|
||||
)
|
||||
)
|
||||
return normalized_parameters
|
||||
Returns:
|
||||
dict[str, Any] | None: 规范化后的对象级参数 Schema。
|
||||
"""
|
||||
normalized_option = normalize_tool_option(ComponentQueryService._build_tool_definition(entry))
|
||||
return normalized_option.parameters_schema
|
||||
|
||||
@staticmethod
|
||||
def _build_tool_info(entry: "ToolEntry") -> ToolInfo:
|
||||
@@ -282,13 +247,36 @@ class ComponentQueryService:
|
||||
|
||||
return ToolInfo(
|
||||
name=entry.name,
|
||||
component_type=ComponentType.TOOL,
|
||||
description=entry.description,
|
||||
description=entry.brief_description or entry.description,
|
||||
enabled=bool(entry.enabled),
|
||||
plugin_name=entry.plugin_id,
|
||||
metadata=dict(entry.metadata),
|
||||
tool_parameters=ComponentQueryService._build_tool_parameters(entry),
|
||||
tool_description=entry.description,
|
||||
parameters_schema=ComponentQueryService._build_tool_parameters_schema(entry),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _build_tool_spec(entry: "ToolEntry") -> ToolSpec:
|
||||
"""将运行时 Tool 条目转换为统一工具声明。
|
||||
|
||||
Args:
|
||||
entry: 插件运行时中的 Tool 条目。
|
||||
|
||||
Returns:
|
||||
ToolSpec: 统一工具声明。
|
||||
"""
|
||||
|
||||
parameters_schema = ComponentQueryService._build_tool_parameters_schema(entry)
|
||||
return ToolSpec(
|
||||
name=entry.name,
|
||||
brief_description=entry.brief_description or entry.description or f"工具 {entry.name}",
|
||||
detailed_description=entry.detailed_description or build_tool_detailed_description(parameters_schema),
|
||||
parameters_schema=parameters_schema,
|
||||
provider_name=entry.plugin_id,
|
||||
provider_type="plugin",
|
||||
metadata={
|
||||
"plugin_id": entry.plugin_id,
|
||||
"invoke_method": entry.invoke_method,
|
||||
"legacy_component_type": entry.legacy_component_type,
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -478,9 +466,14 @@ class ComponentQueryService:
|
||||
message = kwargs.get("message")
|
||||
matched_groups = kwargs.get("matched_groups")
|
||||
plugin_config = kwargs.get("plugin_config")
|
||||
message_info = getattr(message, "message_info", None)
|
||||
group_info = getattr(message_info, "group_info", None)
|
||||
user_info = getattr(message_info, "user_info", None)
|
||||
invoke_args: Dict[str, Any] = {
|
||||
"text": str(getattr(message, "processed_plain_text", "") or ""),
|
||||
"stream_id": str(getattr(message, "session_id", "") or ""),
|
||||
"group_id": str(getattr(group_info, "group_id", "") or ""),
|
||||
"user_id": str(getattr(user_info, "user_id", "") or ""),
|
||||
"matched_groups": matched_groups if isinstance(matched_groups, dict) else {},
|
||||
}
|
||||
if isinstance(plugin_config, dict):
|
||||
@@ -515,7 +508,12 @@ class ComponentQueryService:
|
||||
return _executor
|
||||
|
||||
@staticmethod
|
||||
def _build_tool_executor(supervisor: "PluginSupervisor", plugin_id: str, component_name: str) -> ToolExecutor:
|
||||
def _build_tool_executor(
|
||||
supervisor: "PluginSupervisor",
|
||||
plugin_id: str,
|
||||
component_name: str,
|
||||
invoke_method: str = "plugin.invoke_tool",
|
||||
) -> ToolExecutor:
|
||||
"""构造工具执行 RPC 闭包。
|
||||
|
||||
Args:
|
||||
@@ -539,7 +537,7 @@ class ComponentQueryService:
|
||||
|
||||
try:
|
||||
response = await supervisor.invoke_plugin(
|
||||
method="plugin.invoke_tool",
|
||||
method=invoke_method,
|
||||
plugin_id=plugin_id,
|
||||
component_name=component_name,
|
||||
args=function_args,
|
||||
@@ -655,7 +653,162 @@ class ComponentQueryService:
|
||||
if matched_entry is None:
|
||||
return None
|
||||
supervisor, entry = matched_entry
|
||||
return self._build_tool_executor(supervisor, entry.plugin_id, entry.name)
|
||||
tool_entry = cast("ToolEntry", entry)
|
||||
return self._build_tool_executor(supervisor, tool_entry.plugin_id, tool_entry.name, tool_entry.invoke_method)
|
||||
|
||||
def get_llm_available_tool_specs(self) -> Dict[str, ToolSpec]:
|
||||
"""获取当前可供 LLM 使用的统一工具声明集合。
|
||||
|
||||
Returns:
|
||||
Dict[str, ToolSpec]: 工具名到工具声明的映射。
|
||||
"""
|
||||
|
||||
collected_specs: Dict[str, ToolSpec] = {}
|
||||
for _supervisor, entry in self._iter_component_entries(ComponentType.TOOL):
|
||||
if entry.name in collected_specs:
|
||||
self._log_duplicate_component(ComponentType.TOOL, entry.name)
|
||||
continue
|
||||
collected_specs[entry.name] = self._build_tool_spec(entry) # type: ignore[arg-type]
|
||||
return collected_specs
|
||||
|
||||
@staticmethod
|
||||
def _build_tool_invocation_payload(
|
||||
entry: "ToolEntry",
|
||||
invocation: ToolInvocation,
|
||||
context: Optional[ToolExecutionContext],
|
||||
) -> Dict[str, Any]:
|
||||
"""构造插件工具执行时发送给 Runner 的参数。
|
||||
|
||||
Args:
|
||||
entry: 目标工具条目。
|
||||
invocation: 统一工具调用请求。
|
||||
context: 统一工具执行上下文。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 发往 Runner 的参数字典。
|
||||
"""
|
||||
|
||||
payload = dict(invocation.arguments)
|
||||
if entry.invoke_method == "plugin.invoke_action":
|
||||
stream_id = context.stream_id if context is not None else invocation.stream_id
|
||||
reasoning = context.reasoning if context is not None else invocation.reasoning
|
||||
payload = {
|
||||
**payload,
|
||||
"stream_id": stream_id,
|
||||
"chat_id": stream_id,
|
||||
"reasoning": reasoning,
|
||||
"action_data": dict(invocation.arguments),
|
||||
}
|
||||
return payload
|
||||
|
||||
@staticmethod
|
||||
def _parse_tool_invoke_result(
|
||||
entry: "ToolEntry",
|
||||
result: Any,
|
||||
) -> ToolExecutionResult:
|
||||
"""将插件组件返回值转换为统一工具执行结果。
|
||||
|
||||
Args:
|
||||
entry: 目标工具条目。
|
||||
result: 插件组件原始返回值。
|
||||
|
||||
Returns:
|
||||
ToolExecutionResult: 统一执行结果。
|
||||
"""
|
||||
|
||||
if isinstance(result, dict):
|
||||
success = bool(result.get("success", True))
|
||||
content = str(result.get("content", result.get("message", "")) or "").strip()
|
||||
error_message = ""
|
||||
if not success:
|
||||
error_message = str(result.get("error", result.get("message", "插件工具执行失败")) or "").strip()
|
||||
return ToolExecutionResult(
|
||||
tool_name=entry.name,
|
||||
success=success,
|
||||
content=content,
|
||||
error_message=error_message,
|
||||
structured_content=result,
|
||||
metadata={"plugin_id": entry.plugin_id},
|
||||
)
|
||||
|
||||
if isinstance(result, (list, tuple)) and result:
|
||||
if isinstance(result[0], bool):
|
||||
success = bool(result[0])
|
||||
message = "" if len(result) < 2 or result[1] is None else str(result[1]).strip()
|
||||
return ToolExecutionResult(
|
||||
tool_name=entry.name,
|
||||
success=success,
|
||||
content=message if success else "",
|
||||
error_message="" if success else message,
|
||||
structured_content=list(result),
|
||||
metadata={"plugin_id": entry.plugin_id},
|
||||
)
|
||||
|
||||
normalized_content = "" if result is None else str(result).strip()
|
||||
return ToolExecutionResult(
|
||||
tool_name=entry.name,
|
||||
success=True,
|
||||
content=normalized_content,
|
||||
structured_content=result,
|
||||
metadata={"plugin_id": entry.plugin_id},
|
||||
)
|
||||
|
||||
async def invoke_tool_as_tool(
|
||||
self,
|
||||
invocation: ToolInvocation,
|
||||
context: Optional[ToolExecutionContext] = None,
|
||||
) -> ToolExecutionResult:
|
||||
"""按统一工具语义执行插件工具。
|
||||
|
||||
Args:
|
||||
invocation: 统一工具调用请求。
|
||||
context: 执行上下文。
|
||||
|
||||
Returns:
|
||||
ToolExecutionResult: 统一工具执行结果。
|
||||
"""
|
||||
|
||||
matched_entry = self._get_unique_component_entry(ComponentType.TOOL, invocation.tool_name)
|
||||
if matched_entry is None:
|
||||
return ToolExecutionResult(
|
||||
tool_name=invocation.tool_name,
|
||||
success=False,
|
||||
error_message=f"未找到插件工具:{invocation.tool_name}",
|
||||
)
|
||||
|
||||
supervisor, entry = matched_entry
|
||||
tool_entry = cast("ToolEntry", entry)
|
||||
invoke_payload = self._build_tool_invocation_payload(tool_entry, invocation, context)
|
||||
|
||||
try:
|
||||
response = await supervisor.invoke_plugin(
|
||||
method=tool_entry.invoke_method,
|
||||
plugin_id=tool_entry.plugin_id,
|
||||
component_name=tool_entry.name,
|
||||
args=invoke_payload,
|
||||
timeout_ms=30000,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"运行时工具 {tool_entry.plugin_id}.{tool_entry.name} 执行失败: {exc}", exc_info=True)
|
||||
return ToolExecutionResult(
|
||||
tool_name=tool_entry.name,
|
||||
success=False,
|
||||
error_message=str(exc),
|
||||
metadata={"plugin_id": tool_entry.plugin_id},
|
||||
)
|
||||
|
||||
payload = response.payload if isinstance(response.payload, dict) else {}
|
||||
transport_success = bool(payload.get("success", False))
|
||||
result = payload.get("result")
|
||||
if not transport_success:
|
||||
return ToolExecutionResult(
|
||||
tool_name=tool_entry.name,
|
||||
success=False,
|
||||
error_message="" if result is None else str(result),
|
||||
structured_content=result,
|
||||
metadata={"plugin_id": tool_entry.plugin_id},
|
||||
)
|
||||
return self._parse_tool_invoke_result(tool_entry, result)
|
||||
|
||||
def get_llm_available_tools(self) -> Dict[str, ToolInfo]:
|
||||
"""获取当前可供 LLM 选择的工具集合。
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Host-side ComponentRegistry
|
||||
"""Host 侧组件注册表。
|
||||
|
||||
对齐旧系统 component_registry.py 的核心能力:
|
||||
- 按类型注册组件(action / command / tool / event_handler / workflow_handler / message_gateway)
|
||||
- 按类型注册组件(action / command / tool / event_handler / hook_handler / message_gateway)
|
||||
- 命名空间 (plugin_id.component_name)
|
||||
- 命令正则匹配
|
||||
- 组件启用/禁用
|
||||
@@ -16,6 +16,7 @@ import contextlib
|
||||
import re
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.core.tooling import build_tool_detailed_description
|
||||
|
||||
logger = get_logger("plugin_runtime.host.component_registry")
|
||||
|
||||
@@ -89,11 +90,81 @@ class ToolEntry(ComponentEntry):
|
||||
"""Tool 组件条目"""
|
||||
|
||||
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
|
||||
self.description: str = metadata.get("description", "")
|
||||
self.description: str = str(metadata.get("description", "") or "").strip()
|
||||
self.brief_description: str = str(
|
||||
metadata.get("brief_description", self.description) or self.description or f"工具 {name}"
|
||||
).strip()
|
||||
self.parameters: List[Dict[str, Any]] = metadata.get("parameters", [])
|
||||
self.parameters_raw: List[Dict[str, Any]] = metadata.get("parameters_raw", [])
|
||||
self.parameters_raw: Dict[str, Any] | List[Dict[str, Any]] = metadata.get("parameters_raw", {})
|
||||
detailed_description = str(metadata.get("detailed_description", "") or "").strip()
|
||||
self.detailed_description: str = detailed_description
|
||||
self.invoke_method: str = str(metadata.get("invoke_method", "plugin.invoke_tool") or "plugin.invoke_tool").strip()
|
||||
self.legacy_component_type: str = str(metadata.get("legacy_component_type", "") or "").strip()
|
||||
super().__init__(name, component_type, plugin_id, metadata)
|
||||
|
||||
if not self.detailed_description:
|
||||
parameters_schema = self._get_parameters_schema()
|
||||
self.detailed_description = build_tool_detailed_description(parameters_schema)
|
||||
|
||||
def _get_parameters_schema(self) -> Dict[str, Any] | None:
|
||||
"""获取当前工具条目的对象级参数 Schema。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any] | None: 归一化后的参数 Schema。
|
||||
"""
|
||||
|
||||
if isinstance(self.parameters_raw, dict) and self.parameters_raw:
|
||||
if self.parameters_raw.get("type") == "object" or "properties" in self.parameters_raw:
|
||||
return dict(self.parameters_raw)
|
||||
|
||||
required_names: List[str] = []
|
||||
normalized_properties: Dict[str, Any] = {}
|
||||
for property_name, property_schema in self.parameters_raw.items():
|
||||
if not isinstance(property_schema, dict):
|
||||
continue
|
||||
property_schema_copy = dict(property_schema)
|
||||
if bool(property_schema_copy.pop("required", False)):
|
||||
required_names.append(str(property_name))
|
||||
normalized_properties[str(property_name)] = property_schema_copy
|
||||
|
||||
schema: Dict[str, Any] = {
|
||||
"type": "object",
|
||||
"properties": normalized_properties,
|
||||
}
|
||||
if required_names:
|
||||
schema["required"] = required_names
|
||||
return schema
|
||||
|
||||
if isinstance(self.parameters, list) and self.parameters:
|
||||
properties: Dict[str, Any] = {}
|
||||
required_names: List[str] = []
|
||||
for parameter in self.parameters:
|
||||
if not isinstance(parameter, dict):
|
||||
continue
|
||||
parameter_name = str(parameter.get("name", "") or "").strip()
|
||||
if not parameter_name:
|
||||
continue
|
||||
if bool(parameter.get("required", False)):
|
||||
required_names.append(parameter_name)
|
||||
properties[parameter_name] = {
|
||||
key: value
|
||||
for key, value in parameter.items()
|
||||
if key not in {"name", "required", "param_type"}
|
||||
}
|
||||
properties[parameter_name]["type"] = str(
|
||||
parameter.get("type", parameter.get("param_type", "string")) or "string"
|
||||
)
|
||||
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
}
|
||||
if required_names:
|
||||
schema["required"] = required_names
|
||||
return schema
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class EventHandlerEntry(ComponentEntry):
|
||||
"""EventHandler 组件条目"""
|
||||
@@ -106,14 +177,129 @@ class EventHandlerEntry(ComponentEntry):
|
||||
|
||||
|
||||
class HookHandlerEntry(ComponentEntry):
|
||||
"""WorkflowHandler 组件条目"""
|
||||
"""HookHandler 组件条目。"""
|
||||
|
||||
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
|
||||
self.stage: str = metadata.get("stage", "")
|
||||
self.priority: int = metadata.get("priority", 0)
|
||||
self.blocking: bool = metadata.get("blocking", False)
|
||||
self.hook: str = self._normalize_hook_name(metadata.get("hook", ""))
|
||||
self.mode: str = self._normalize_mode(metadata.get("mode", "blocking"))
|
||||
self.order: str = self._normalize_order(metadata.get("order", "normal"))
|
||||
self.timeout_ms: int = self._normalize_timeout_ms(metadata.get("timeout_ms", 0))
|
||||
self.error_policy: str = self._normalize_error_policy(metadata.get("error_policy", "skip"))
|
||||
super().__init__(name, component_type, plugin_id, metadata)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_error_policy(raw_value: Any) -> str:
|
||||
"""规范化 Hook 异常处理策略。
|
||||
|
||||
Args:
|
||||
raw_value: 原始异常处理策略值。
|
||||
|
||||
Returns:
|
||||
str: 规范化后的异常处理策略。
|
||||
|
||||
Raises:
|
||||
ValueError: 当异常处理策略不受支持时抛出。
|
||||
"""
|
||||
|
||||
normalized_source = getattr(raw_value, "value", raw_value)
|
||||
normalized_value = str(normalized_source or "").strip().lower() or "skip"
|
||||
if normalized_value not in {"abort", "skip", "log"}:
|
||||
raise ValueError(f"HookHandler 异常处理策略不合法: {raw_value}")
|
||||
return normalized_value
|
||||
|
||||
@staticmethod
|
||||
def _normalize_hook_name(raw_value: Any) -> str:
|
||||
"""规范化命名 Hook 名称。
|
||||
|
||||
Args:
|
||||
raw_value: 原始 Hook 名称。
|
||||
|
||||
Returns:
|
||||
str: 去空白后的 Hook 名称。
|
||||
|
||||
Raises:
|
||||
ValueError: 当 Hook 名称为空时抛出。
|
||||
"""
|
||||
|
||||
normalized_source = getattr(raw_value, "value", raw_value)
|
||||
if not (normalized_value := str(normalized_source or "").strip()):
|
||||
raise ValueError("HookHandler 的 hook 名称不能为空")
|
||||
return normalized_value
|
||||
|
||||
@staticmethod
|
||||
def _normalize_mode(raw_value: Any) -> str:
|
||||
"""规范化 Hook 处理模式。
|
||||
|
||||
Args:
|
||||
raw_value: 原始模式值。
|
||||
|
||||
Returns:
|
||||
str: 规范化后的模式。
|
||||
|
||||
Raises:
|
||||
ValueError: 当模式不受支持时抛出。
|
||||
"""
|
||||
|
||||
normalized_source = getattr(raw_value, "value", raw_value)
|
||||
normalized_value = str(normalized_source or "").strip().lower() or "blocking"
|
||||
if normalized_value not in {"blocking", "observe"}:
|
||||
raise ValueError(f"HookHandler 模式不合法: {raw_value}")
|
||||
return normalized_value
|
||||
|
||||
@staticmethod
|
||||
def _normalize_order(raw_value: Any) -> str:
|
||||
"""规范化 Hook 顺序槽位。
|
||||
|
||||
Args:
|
||||
raw_value: 原始顺序值。
|
||||
|
||||
Returns:
|
||||
str: 规范化后的顺序槽位。
|
||||
|
||||
Raises:
|
||||
ValueError: 当顺序值不受支持时抛出。
|
||||
"""
|
||||
|
||||
normalized_source = getattr(raw_value, "value", raw_value)
|
||||
normalized_value = str(normalized_source or "").strip().lower() or "normal"
|
||||
if normalized_value not in {"early", "normal", "late"}:
|
||||
raise ValueError(f"HookHandler 顺序槽位不合法: {raw_value}")
|
||||
return normalized_value
|
||||
|
||||
@staticmethod
|
||||
def _normalize_timeout_ms(raw_value: Any) -> int:
|
||||
"""规范化 Hook 超时配置。
|
||||
|
||||
Args:
|
||||
raw_value: 原始超时值。
|
||||
|
||||
Returns:
|
||||
int: 规范化后的超时毫秒数。
|
||||
|
||||
Raises:
|
||||
ValueError: 当超时值为负数或无法转换为整数时抛出。
|
||||
"""
|
||||
|
||||
try:
|
||||
timeout_ms = int(raw_value or 0)
|
||||
except (TypeError, ValueError) as exc:
|
||||
raise ValueError(f"HookHandler 超时配置不合法: {raw_value}") from exc
|
||||
if timeout_ms < 0:
|
||||
raise ValueError(f"HookHandler 超时配置不能为负数: {raw_value}")
|
||||
return timeout_ms
|
||||
|
||||
@property
|
||||
def is_blocking(self) -> bool:
|
||||
"""返回当前 Hook 是否为阻塞模式。"""
|
||||
|
||||
return self.mode == "blocking"
|
||||
|
||||
@property
|
||||
def is_observe(self) -> bool:
|
||||
"""返回当前 Hook 是否为观察模式。"""
|
||||
|
||||
return self.mode == "observe"
|
||||
|
||||
|
||||
class MessageGatewayEntry(ComponentEntry):
|
||||
"""MessageGateway 组件条目"""
|
||||
@@ -167,7 +353,7 @@ class MessageGatewayEntry(ComponentEntry):
|
||||
|
||||
|
||||
class ComponentRegistry:
|
||||
"""Host-side 组件注册表
|
||||
"""Host 侧组件注册表。
|
||||
|
||||
由 Supervisor 在收到 plugin.register_components 时调用。
|
||||
供业务层查询可用组件、匹配命令、调度 action/event 等。
|
||||
@@ -185,6 +371,86 @@ class ComponentRegistry:
|
||||
# 按插件索引
|
||||
self._by_plugin: Dict[str, List[ComponentEntry]] = {}
|
||||
|
||||
@staticmethod
|
||||
def _convert_action_metadata_to_tool_metadata(
|
||||
name: str,
|
||||
metadata: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
"""将旧 Action 元数据转换为统一 Tool 元数据。
|
||||
|
||||
Args:
|
||||
name: 组件名称。
|
||||
metadata: Action 原始元数据。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 转换后的 Tool 元数据。
|
||||
"""
|
||||
|
||||
action_parameters = metadata.get("action_parameters")
|
||||
parameters_schema: Dict[str, Any] | None = None
|
||||
if isinstance(action_parameters, dict) and action_parameters:
|
||||
properties: Dict[str, Any] = {}
|
||||
for parameter_name, parameter_description in action_parameters.items():
|
||||
normalized_name = str(parameter_name or "").strip()
|
||||
if not normalized_name:
|
||||
continue
|
||||
properties[normalized_name] = {
|
||||
"type": "string",
|
||||
"description": str(parameter_description or "").strip() or "兼容旧 Action 参数",
|
||||
}
|
||||
if properties:
|
||||
parameters_schema = {
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
}
|
||||
|
||||
detailed_parts: List[str] = []
|
||||
if parameters_schema is not None:
|
||||
parameter_description = build_tool_detailed_description(parameters_schema)
|
||||
if parameter_description:
|
||||
detailed_parts.append(parameter_description)
|
||||
|
||||
action_require = [
|
||||
str(item).strip()
|
||||
for item in (metadata.get("action_require") or [])
|
||||
if str(item).strip()
|
||||
]
|
||||
if action_require:
|
||||
detailed_parts.append("使用建议:\n" + "\n".join(f"- {item}" for item in action_require))
|
||||
|
||||
associated_types = [
|
||||
str(item).strip()
|
||||
for item in (metadata.get("associated_types") or [])
|
||||
if str(item).strip()
|
||||
]
|
||||
if associated_types:
|
||||
detailed_parts.append(f"适用消息类型:{'、'.join(associated_types)}。")
|
||||
|
||||
activation_type = str(metadata.get("activation_type", "always") or "always").strip()
|
||||
activation_keywords = [
|
||||
str(item).strip()
|
||||
for item in (metadata.get("activation_keywords") or [])
|
||||
if str(item).strip()
|
||||
]
|
||||
activation_lines = [f"兼容旧 Action 激活方式:{activation_type}。"]
|
||||
if activation_keywords:
|
||||
activation_lines.append(f"激活关键词:{'、'.join(activation_keywords)}。")
|
||||
if str(metadata.get("action_prompt", "") or "").strip():
|
||||
activation_lines.append(f"原始 Action 提示语:{str(metadata['action_prompt']).strip()}。")
|
||||
detailed_parts.append("\n".join(activation_lines))
|
||||
|
||||
brief_description = str(metadata.get("brief_description", metadata.get("description", "") or f"工具 {name}")).strip()
|
||||
return {
|
||||
**metadata,
|
||||
"description": brief_description,
|
||||
"brief_description": brief_description,
|
||||
"detailed_description": "\n\n".join(part for part in detailed_parts if part).strip(),
|
||||
"parameters_raw": parameters_schema or {},
|
||||
"invoke_method": "plugin.invoke_action",
|
||||
"legacy_action": True,
|
||||
"legacy_component_type": "ACTION",
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _normalize_component_type(component_type: str) -> ComponentTypes:
|
||||
"""规范化组件类型输入。
|
||||
@@ -223,18 +489,20 @@ class ComponentRegistry:
|
||||
"""
|
||||
try:
|
||||
normalized_type = self._normalize_component_type(component_type)
|
||||
normalized_metadata = dict(metadata)
|
||||
if normalized_type == ComponentTypes.ACTION:
|
||||
comp = ActionEntry(name, normalized_type.value, plugin_id, metadata)
|
||||
normalized_metadata = self._convert_action_metadata_to_tool_metadata(name, normalized_metadata)
|
||||
comp = ToolEntry(name, ComponentTypes.TOOL.value, plugin_id, normalized_metadata)
|
||||
elif normalized_type == ComponentTypes.COMMAND:
|
||||
comp = CommandEntry(name, normalized_type.value, plugin_id, metadata)
|
||||
comp = CommandEntry(name, normalized_type.value, plugin_id, normalized_metadata)
|
||||
elif normalized_type == ComponentTypes.TOOL:
|
||||
comp = ToolEntry(name, normalized_type.value, plugin_id, metadata)
|
||||
comp = ToolEntry(name, normalized_type.value, plugin_id, normalized_metadata)
|
||||
elif normalized_type == ComponentTypes.EVENT_HANDLER:
|
||||
comp = EventHandlerEntry(name, normalized_type.value, plugin_id, metadata)
|
||||
comp = EventHandlerEntry(name, normalized_type.value, plugin_id, normalized_metadata)
|
||||
elif normalized_type == ComponentTypes.HOOK_HANDLER:
|
||||
comp = HookHandlerEntry(name, normalized_type.value, plugin_id, metadata)
|
||||
comp = HookHandlerEntry(name, normalized_type.value, plugin_id, normalized_metadata)
|
||||
elif normalized_type == ComponentTypes.MESSAGE_GATEWAY:
|
||||
comp = MessageGatewayEntry(name, normalized_type.value, plugin_id, metadata)
|
||||
comp = MessageGatewayEntry(name, normalized_type.value, plugin_id, normalized_metadata)
|
||||
else:
|
||||
raise ValueError(f"组件类型 {component_type} 不存在")
|
||||
except ValueError:
|
||||
@@ -454,16 +722,17 @@ class ComponentRegistry:
|
||||
return handlers
|
||||
|
||||
def get_hook_handlers(
|
||||
self, stage: str, *, enabled_only: bool = True, session_id: Optional[str] = None
|
||||
self, hook_name: str, *, enabled_only: bool = True, session_id: Optional[str] = None
|
||||
) -> List[HookHandlerEntry]:
|
||||
"""获取特定 hook 阶段的所有步骤,按 priority 降序。
|
||||
"""获取订阅指定命名 Hook 的全部处理器。
|
||||
|
||||
Args:
|
||||
stage: hook 名称
|
||||
enabled_only: 是否仅返回启用的组件
|
||||
session_id: 可选的会话ID,若提供则考虑会话禁用状态
|
||||
hook_name: 目标 Hook 名称。
|
||||
enabled_only: 是否仅返回启用的组件。
|
||||
session_id: 可选的会话 ID,若提供则考虑会话禁用状态。
|
||||
|
||||
Returns:
|
||||
handlers (List[HookHandlerEntry]): 符合条件的 HookHandler 组件列表,按 priority 降序排序
|
||||
List[HookHandlerEntry]: 符合条件的 HookHandler 组件列表。
|
||||
"""
|
||||
handlers: List[HookHandlerEntry] = []
|
||||
for comp in self._by_type.get(ComponentTypes.HOOK_HANDLER, {}).values():
|
||||
@@ -471,11 +740,37 @@ class ComponentRegistry:
|
||||
continue
|
||||
if not isinstance(comp, HookHandlerEntry):
|
||||
continue
|
||||
if comp.stage == stage:
|
||||
if comp.hook == hook_name:
|
||||
handlers.append(comp)
|
||||
handlers.sort(key=lambda c: c.priority, reverse=True)
|
||||
handlers.sort(key=lambda comp: (self._get_hook_mode_rank(comp.mode), self._get_hook_order_rank(comp.order), comp.plugin_id, comp.name))
|
||||
return handlers
|
||||
|
||||
@staticmethod
|
||||
def _get_hook_mode_rank(mode: str) -> int:
|
||||
"""返回 Hook 模式的排序权重。
|
||||
|
||||
Args:
|
||||
mode: Hook 模式字符串。
|
||||
|
||||
Returns:
|
||||
int: 越小表示越靠前。
|
||||
"""
|
||||
|
||||
return {"blocking": 0, "observe": 1}.get(mode, 99)
|
||||
|
||||
@staticmethod
|
||||
def _get_hook_order_rank(order: str) -> int:
|
||||
"""返回 Hook 顺序槽位的排序权重。
|
||||
|
||||
Args:
|
||||
order: Hook 顺序槽位字符串。
|
||||
|
||||
Returns:
|
||||
int: 越小表示越靠前。
|
||||
"""
|
||||
|
||||
return {"early": 0, "normal": 1, "late": 2}.get(order, 99)
|
||||
|
||||
def get_message_gateway(
|
||||
self,
|
||||
plugin_id: str,
|
||||
@@ -566,8 +861,13 @@ class ComponentRegistry:
|
||||
Returns:
|
||||
stats (StatusDict): 组件统计信息,包括总数、各类型数量、插件数量等
|
||||
"""
|
||||
stats: StatusDict = {"total": len(self._components)} # type: ignore
|
||||
for comp_type, type_dict in self._by_type.items():
|
||||
stats[comp_type.value.lower()] = len(type_dict)
|
||||
stats["plugins"] = len(self._by_plugin)
|
||||
return stats
|
||||
return StatusDict(
|
||||
total=len(self._components),
|
||||
action=len(self._by_type[ComponentTypes.ACTION]),
|
||||
command=len(self._by_type[ComponentTypes.COMMAND]),
|
||||
tool=len(self._by_type[ComponentTypes.TOOL]),
|
||||
event_handler=len(self._by_type[ComponentTypes.EVENT_HANDLER]),
|
||||
hook_handler=len(self._by_type[ComponentTypes.HOOK_HANDLER]),
|
||||
message_gateway=len(self._by_type[ComponentTypes.MESSAGE_GATEWAY]),
|
||||
plugins=len(self._by_plugin),
|
||||
)
|
||||
|
||||
@@ -1,166 +1,670 @@
|
||||
"""
|
||||
Hook Dispatch 系统
|
||||
"""命名 Hook 分发系统。
|
||||
|
||||
插件可以注册自己的Hook,当特定函数被调用时,Hook Dispatch系统会将调用转发给插件的Hook处理函数。
|
||||
每个Hook的参数随Hook点位确定,因此参数是易变的。插件开发者需要根据Hook点位的定义来编写Hook处理函数。
|
||||
在参数/返回值匹配的情况下允许修改参数/返回值。
|
||||
主程序可以在任意执行点触发一个命名 Hook,Host 会收集所有订阅该 Hook 的
|
||||
插件处理器,并按照固定的全局顺序调度执行。
|
||||
|
||||
HookDispatcher 负责:
|
||||
1. 按 stage 查询已注册的 hook_handler(通过 ComponentRegistry)
|
||||
2. 按 priority 排序,区分 blocking 和非 blocking 模式
|
||||
3. blocking 模式:依次同步调用,支持修改参数/提前终止
|
||||
4. 非 blocking 模式:异步调用,不阻塞主流程
|
||||
5. 支持通过 global_config.plugin_runtime.hook_blocking_timeout_sec 设置超时上限
|
||||
排序规则如下:
|
||||
|
||||
1. `blocking` 先于 `observe`
|
||||
2. `early` 先于 `normal` 先于 `late`
|
||||
3. 内置插件先于第三方插件
|
||||
4. `plugin_id`
|
||||
5. `handler_name`
|
||||
|
||||
其中:
|
||||
|
||||
- `blocking` 处理器串行执行,可修改 `kwargs`,也可中止本次 Hook 调用。
|
||||
- `observe` 处理器后台并发执行,只允许旁路观察,不参与主流程控制。
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Set
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING
|
||||
import contextlib
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .component_registry import HookHandlerEntry
|
||||
from .supervisor import PluginRunnerSupervisor
|
||||
from .component_registry import ComponentRegistry, HookHandlerEntry
|
||||
|
||||
logger = get_logger("plugin_runtime.host.hook_dispatcher")
|
||||
|
||||
|
||||
@dataclass
|
||||
class HookResult:
|
||||
"""单个 HookHandler 的执行结果"""
|
||||
@dataclass(slots=True)
|
||||
class HookSpec:
|
||||
"""命名 Hook 的静态规格定义。
|
||||
|
||||
Attributes:
|
||||
name: Hook 的唯一名称。
|
||||
description: Hook 描述。
|
||||
default_timeout_ms: 默认超时毫秒数;为 `0` 时退回系统默认值。
|
||||
allow_blocking: 是否允许注册阻塞处理器。
|
||||
allow_observe: 是否允许注册观察处理器。
|
||||
allow_abort: 是否允许处理器中止当前 Hook 调用。
|
||||
allow_kwargs_mutation: 是否允许阻塞处理器修改 `kwargs`。
|
||||
"""
|
||||
|
||||
name: str
|
||||
description: str = ""
|
||||
default_timeout_ms: int = 0
|
||||
allow_blocking: bool = True
|
||||
allow_observe: bool = True
|
||||
allow_abort: bool = True
|
||||
allow_kwargs_mutation: bool = True
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class HookHandlerExecutionResult:
|
||||
"""单个 HookHandler 的执行结果。
|
||||
|
||||
Attributes:
|
||||
handler_name: 完整处理器名称,格式通常为 `plugin_id.component_name`。
|
||||
plugin_id: 处理器所属插件 ID。
|
||||
success: 本次调用是否成功。
|
||||
action: 当前处理器要求的控制动作,仅支持 `continue` 或 `abort`。
|
||||
modified_kwargs: 处理器返回的修改后参数字典。
|
||||
custom_result: 处理器返回的附加结果。
|
||||
error_message: 失败时的错误描述。
|
||||
"""
|
||||
|
||||
handler_name: str
|
||||
success: bool = field(default=True)
|
||||
continue_processing: bool = field(default=True)
|
||||
modified_kwargs: Optional[Dict[str, Any]] = field(default=None)
|
||||
custom_result: Any = field(default=None)
|
||||
plugin_id: str
|
||||
success: bool = True
|
||||
action: str = "continue"
|
||||
modified_kwargs: Optional[Dict[str, Any]] = None
|
||||
custom_result: Any = None
|
||||
error_message: str = ""
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class HookDispatchResult:
|
||||
"""一次命名 Hook 调用的聚合结果。
|
||||
|
||||
Attributes:
|
||||
hook_name: 本次调用的 Hook 名称。
|
||||
kwargs: 经阻塞处理器串行处理后的最终参数字典。
|
||||
aborted: 是否被某个处理器中止。
|
||||
stopped_by: 若被中止,记录触发中止的完整处理器名称。
|
||||
custom_results: 阻塞处理器返回的附加结果列表。
|
||||
errors: 本次调用中记录到的错误信息列表。
|
||||
"""
|
||||
|
||||
hook_name: str
|
||||
kwargs: Dict[str, Any] = field(default_factory=dict)
|
||||
aborted: bool = False
|
||||
stopped_by: Optional[str] = None
|
||||
custom_results: List[Any] = field(default_factory=list)
|
||||
errors: List[str] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _HookInvocationTarget:
|
||||
"""内部使用的 Hook 调度目标。
|
||||
|
||||
Attributes:
|
||||
supervisor: 负责该处理器的 Supervisor。
|
||||
entry: Hook 处理器条目。
|
||||
source_rank: 插件来源权重,内置插件为 `0`,第三方插件为 `1`。
|
||||
"""
|
||||
|
||||
supervisor: "PluginRunnerSupervisor"
|
||||
entry: "HookHandlerEntry"
|
||||
source_rank: int
|
||||
|
||||
|
||||
class HookDispatcher:
|
||||
"""Host-side Hook 分发器
|
||||
"""命名 Hook 分发器。"""
|
||||
|
||||
由业务层调用 hook_dispatch(),
|
||||
内部通过 ComponentRegistry 查询 handler,
|
||||
再通过提供的 invoke_fn 回调 RPC 到 Runner 执行。
|
||||
"""
|
||||
|
||||
def __init__(self, component_registry: "ComponentRegistry") -> None:
|
||||
"""初始化 HookDispatcher
|
||||
def __init__(
|
||||
self,
|
||||
supervisors_provider: Optional[Callable[[], Sequence["PluginRunnerSupervisor"]]] = None,
|
||||
) -> None:
|
||||
"""初始化 Hook 分发器。
|
||||
|
||||
Args:
|
||||
component_registry: ComponentRegistry 实例,用于查询已注册的 hook_handler
|
||||
supervisors_provider: 可选的 Supervisor 提供器。若调用 `invoke_hook()`
|
||||
时未显式传入 `supervisors`,则使用该回调获取目标 Supervisor 列表。
|
||||
"""
|
||||
self._component_registry: "ComponentRegistry" = component_registry
|
||||
self._background_tasks: Set[asyncio.Task] = set()
|
||||
|
||||
self._background_tasks: Set[asyncio.Task[Any]] = set()
|
||||
self._hook_specs: Dict[str, HookSpec] = {}
|
||||
self._supervisors_provider = supervisors_provider
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""停止 HookDispatcher,取消所有未完成的后台任务"""
|
||||
"""停止分发器并取消所有未完成的观察任务。"""
|
||||
|
||||
for task in self._background_tasks:
|
||||
task.cancel()
|
||||
await asyncio.gather(*self._background_tasks, return_exceptions=True)
|
||||
self._background_tasks.clear()
|
||||
|
||||
async def hook_dispatch(
|
||||
self,
|
||||
stage: str,
|
||||
supervisor: "PluginRunnerSupervisor",
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
"""分发 hook 到所有对应 handler 的便捷方法。
|
||||
|
||||
内置了通过 PluginRunnerSupervisor.invoke_plugin 调用 plugin 的逻辑,
|
||||
无需调用方手动构造 invoke_fn 闭包。
|
||||
def register_hook_spec(self, spec: HookSpec) -> None:
|
||||
"""注册单个命名 Hook 规格。
|
||||
|
||||
Args:
|
||||
stage: hook 名称
|
||||
supervisor: PluginRunnerSupervisor 实例,用于调用 invoke_plugin
|
||||
**kwargs: 关键字参数,会展开传递给 handler
|
||||
spec: 需要注册的 Hook 规格。
|
||||
"""
|
||||
|
||||
normalized_name = self._normalize_hook_name(spec.name)
|
||||
self._hook_specs[normalized_name] = HookSpec(
|
||||
name=normalized_name,
|
||||
description=spec.description,
|
||||
default_timeout_ms=max(int(spec.default_timeout_ms), 0),
|
||||
allow_blocking=bool(spec.allow_blocking),
|
||||
allow_observe=bool(spec.allow_observe),
|
||||
allow_abort=bool(spec.allow_abort),
|
||||
allow_kwargs_mutation=bool(spec.allow_kwargs_mutation),
|
||||
)
|
||||
|
||||
def register_hook_specs(self, specs: Sequence[HookSpec]) -> None:
|
||||
"""批量注册命名 Hook 规格。
|
||||
|
||||
Args:
|
||||
specs: 需要注册的 Hook 规格序列。
|
||||
"""
|
||||
|
||||
for spec in specs:
|
||||
self.register_hook_spec(spec)
|
||||
|
||||
def get_hook_spec(self, hook_name: str) -> HookSpec:
|
||||
"""获取指定 Hook 的规格定义。
|
||||
|
||||
Args:
|
||||
hook_name: Hook 名称。
|
||||
|
||||
Returns:
|
||||
modified_kwargs (Dict[str, Any]): 经过所有 handler 修改后的关键字参数
|
||||
HookSpec: 若未显式注册,则返回按系统默认值生成的运行时规格。
|
||||
"""
|
||||
handler_entries = self._component_registry.get_hook_handlers(stage)
|
||||
if not handler_entries:
|
||||
return kwargs
|
||||
|
||||
current_kwargs = kwargs.copy()
|
||||
blocking_handlers: List["HookHandlerEntry"] = []
|
||||
non_blocking_handlers: List["HookHandlerEntry"] = []
|
||||
normalized_name = self._normalize_hook_name(hook_name)
|
||||
if normalized_name in self._hook_specs:
|
||||
return self._hook_specs[normalized_name]
|
||||
|
||||
# 分离 blocking 和非 blocking handler
|
||||
for entry in handler_entries:
|
||||
if entry.blocking:
|
||||
blocking_handlers.append(entry)
|
||||
else:
|
||||
non_blocking_handlers.append(entry)
|
||||
return HookSpec(
|
||||
name=normalized_name,
|
||||
default_timeout_ms=self._get_default_timeout_ms(),
|
||||
)
|
||||
|
||||
# 处理 blocking handlers(同步调用,支持修改参数/提前终止)
|
||||
timeout = global_config.plugin_runtime.hook_blocking_timeout_sec or 30.0
|
||||
for entry in blocking_handlers:
|
||||
hook_args = {"stage": stage, **current_kwargs}
|
||||
try:
|
||||
# 应用超时控制
|
||||
result = await asyncio.wait_for(
|
||||
self._invoke_handler(supervisor, entry, hook_args),
|
||||
timeout=timeout,
|
||||
async def invoke_hook(
|
||||
self,
|
||||
hook_name: str,
|
||||
supervisors: Optional[Sequence["PluginRunnerSupervisor"]] = None,
|
||||
**kwargs: Any,
|
||||
) -> HookDispatchResult:
|
||||
"""触发一次命名 Hook 调用。
|
||||
|
||||
Args:
|
||||
hook_name: 本次触发的 Hook 名称。
|
||||
supervisors: 当前运行时中所有可参与分发的 Supervisor;留空时使用绑定的提供器。
|
||||
**kwargs: 传递给 Hook 处理器的关键字参数。
|
||||
|
||||
Returns:
|
||||
HookDispatchResult: 聚合后的 Hook 调用结果。
|
||||
"""
|
||||
|
||||
resolved_supervisors = list(supervisors) if supervisors is not None else list(self._resolve_supervisors())
|
||||
normalized_hook_name = self._normalize_hook_name(hook_name)
|
||||
hook_spec = self.get_hook_spec(normalized_hook_name)
|
||||
current_kwargs: Dict[str, Any] = dict(kwargs)
|
||||
dispatch_result = HookDispatchResult(hook_name=normalized_hook_name, kwargs=dict(current_kwargs))
|
||||
invocation_targets = self._collect_invocation_targets(normalized_hook_name, resolved_supervisors)
|
||||
|
||||
if not invocation_targets:
|
||||
return dispatch_result
|
||||
|
||||
for target in invocation_targets:
|
||||
if target.entry.is_observe:
|
||||
self._schedule_observe_handler(
|
||||
hook_name=normalized_hook_name,
|
||||
hook_spec=hook_spec,
|
||||
target=target,
|
||||
kwargs=current_kwargs,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"Blocking HookHandler {entry.full_name} 执行超时 (>{timeout}秒),跳过")
|
||||
result = HookResult(handler_name=entry.full_name, success=False, continue_processing=True)
|
||||
continue
|
||||
|
||||
if result:
|
||||
if result.modified_kwargs is not None:
|
||||
current_kwargs = result.modified_kwargs
|
||||
if not result.continue_processing:
|
||||
logger.info(f"HookHandler {entry.full_name} 终止了后续处理")
|
||||
break
|
||||
if not hook_spec.allow_blocking:
|
||||
error_message = (
|
||||
f"Hook {normalized_hook_name} 不允许 blocking 处理器,"
|
||||
f"已跳过 {target.entry.full_name}"
|
||||
)
|
||||
logger.warning(error_message)
|
||||
dispatch_result.errors.append(error_message)
|
||||
continue
|
||||
|
||||
# 处理 non-blocking handlers(异步调用,不阻塞主流程)
|
||||
for entry in non_blocking_handlers:
|
||||
async_kwargs = current_kwargs.copy()
|
||||
hook_args = {"stage": stage, **async_kwargs}
|
||||
task = asyncio.create_task(
|
||||
asyncio.wait_for(self._invoke_handler(supervisor, entry, hook_args), timeout=timeout)
|
||||
execution_result = await self._invoke_handler(
|
||||
hook_name=normalized_hook_name,
|
||||
hook_spec=hook_spec,
|
||||
target=target,
|
||||
kwargs=current_kwargs,
|
||||
)
|
||||
self._merge_blocking_result(
|
||||
hook_spec=hook_spec,
|
||||
target=target,
|
||||
execution_result=execution_result,
|
||||
dispatch_result=dispatch_result,
|
||||
)
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
|
||||
return current_kwargs
|
||||
current_kwargs = dict(dispatch_result.kwargs)
|
||||
if dispatch_result.aborted:
|
||||
break
|
||||
|
||||
return dispatch_result
|
||||
|
||||
def _resolve_supervisors(self) -> Sequence["PluginRunnerSupervisor"]:
|
||||
"""解析当前调用应使用的 Supervisor 列表。
|
||||
|
||||
Returns:
|
||||
Sequence[PluginRunnerSupervisor]: 可参与本次 Hook 调度的 Supervisor 序列。
|
||||
|
||||
Raises:
|
||||
ValueError: 当未传入 `supervisors` 且分发器也未绑定提供器时抛出。
|
||||
"""
|
||||
|
||||
if self._supervisors_provider is None:
|
||||
raise ValueError("当前 HookDispatcher 未绑定 supervisors_provider,请显式传入 supervisors")
|
||||
return self._supervisors_provider()
|
||||
|
||||
def _collect_invocation_targets(
|
||||
self,
|
||||
hook_name: str,
|
||||
supervisors: Sequence["PluginRunnerSupervisor"],
|
||||
) -> List[_HookInvocationTarget]:
|
||||
"""收集并排序本次 Hook 调用的全部处理器目标。
|
||||
|
||||
Args:
|
||||
hook_name: 目标 Hook 名称。
|
||||
supervisors: 当前参与调度的 Supervisor 序列。
|
||||
|
||||
Returns:
|
||||
List[_HookInvocationTarget]: 已完成全局排序的处理器目标列表。
|
||||
"""
|
||||
|
||||
invocation_targets: List[_HookInvocationTarget] = []
|
||||
for supervisor in supervisors:
|
||||
source_rank = self._get_supervisor_source_rank(supervisor)
|
||||
for entry in supervisor.component_registry.get_hook_handlers(hook_name):
|
||||
invocation_targets.append(
|
||||
_HookInvocationTarget(
|
||||
supervisor=supervisor,
|
||||
entry=entry,
|
||||
source_rank=source_rank,
|
||||
)
|
||||
)
|
||||
|
||||
invocation_targets.sort(key=self._build_sort_key)
|
||||
return invocation_targets
|
||||
|
||||
@staticmethod
|
||||
def _build_sort_key(target: _HookInvocationTarget) -> tuple[int, int, int, str, str]:
|
||||
"""构造 Hook 处理器的全局排序键。
|
||||
|
||||
Args:
|
||||
target: 待排序的处理器目标。
|
||||
|
||||
Returns:
|
||||
tuple[int, int, int, str, str]: 全局排序键。
|
||||
"""
|
||||
|
||||
return (
|
||||
HookDispatcher._get_mode_rank(target.entry.mode),
|
||||
HookDispatcher._get_order_rank(target.entry.order),
|
||||
target.source_rank,
|
||||
target.entry.plugin_id,
|
||||
target.entry.name,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_default_timeout_ms() -> int:
|
||||
"""读取系统级默认 Hook 超时。
|
||||
|
||||
Returns:
|
||||
int: 默认超时毫秒数。
|
||||
"""
|
||||
|
||||
timeout_seconds = float(global_config.plugin_runtime.hook_blocking_timeout_sec or 30.0)
|
||||
return max(int(timeout_seconds * 1000), 1)
|
||||
|
||||
@staticmethod
|
||||
def _get_mode_rank(mode: str) -> int:
|
||||
"""返回 Hook 模式的排序权重。
|
||||
|
||||
Args:
|
||||
mode: Hook 模式。
|
||||
|
||||
Returns:
|
||||
int: 越小表示越靠前。
|
||||
"""
|
||||
|
||||
return {"blocking": 0, "observe": 1}.get(mode, 99)
|
||||
|
||||
@staticmethod
|
||||
def _get_order_rank(order: str) -> int:
|
||||
"""返回 Hook 顺序槽位的排序权重。
|
||||
|
||||
Args:
|
||||
order: Hook 顺序槽位。
|
||||
|
||||
Returns:
|
||||
int: 越小表示越靠前。
|
||||
"""
|
||||
|
||||
return {"early": 0, "normal": 1, "late": 2}.get(order, 99)
|
||||
|
||||
@staticmethod
|
||||
def _get_supervisor_source_rank(supervisor: "PluginRunnerSupervisor") -> int:
|
||||
"""返回 Supervisor 的来源排序权重。
|
||||
|
||||
Args:
|
||||
supervisor: 目标 Supervisor。
|
||||
|
||||
Returns:
|
||||
int: 内置插件返回 `0`,第三方插件返回 `1`。
|
||||
"""
|
||||
|
||||
return 0 if supervisor.group_name == "builtin" else 1
|
||||
|
||||
@staticmethod
|
||||
def _normalize_hook_name(hook_name: str) -> str:
|
||||
"""规范化命名 Hook 名称。
|
||||
|
||||
Args:
|
||||
hook_name: 原始 Hook 名称。
|
||||
|
||||
Returns:
|
||||
str: 规范化后的 Hook 名称。
|
||||
|
||||
Raises:
|
||||
ValueError: 当 Hook 名称为空时抛出。
|
||||
"""
|
||||
|
||||
normalized_name = str(hook_name or "").strip()
|
||||
if not normalized_name:
|
||||
raise ValueError("Hook 名称不能为空")
|
||||
return normalized_name
|
||||
|
||||
def _resolve_timeout_ms(self, hook_spec: HookSpec, target: _HookInvocationTarget) -> int:
|
||||
"""计算单个处理器的实际超时。
|
||||
|
||||
Args:
|
||||
hook_spec: 当前 Hook 的规格定义。
|
||||
target: 当前执行目标。
|
||||
|
||||
Returns:
|
||||
int: 最终生效的超时毫秒数。
|
||||
"""
|
||||
|
||||
if target.entry.timeout_ms > 0:
|
||||
return target.entry.timeout_ms
|
||||
if hook_spec.default_timeout_ms > 0:
|
||||
return hook_spec.default_timeout_ms
|
||||
return self._get_default_timeout_ms()
|
||||
|
||||
async def _invoke_handler(
|
||||
self,
|
||||
supervisor: "PluginRunnerSupervisor",
|
||||
handler_entry: "HookHandlerEntry",
|
||||
args: Dict[str, Any],
|
||||
) -> Optional[HookResult]:
|
||||
"""调用单个 handler 并收集结果。
|
||||
hook_name: str,
|
||||
hook_spec: HookSpec,
|
||||
target: _HookInvocationTarget,
|
||||
kwargs: Dict[str, Any],
|
||||
) -> HookHandlerExecutionResult:
|
||||
"""执行单个 Hook 处理器。
|
||||
|
||||
Args:
|
||||
supervisor: PluginRunnerSupervisor 实例
|
||||
handler_entry: HookHandlerEntry 实例
|
||||
args: 传递给 handler 的参数字典
|
||||
stage: hook 名称
|
||||
hook_name: 当前 Hook 名称。
|
||||
hook_spec: 当前 Hook 规格。
|
||||
target: 当前执行目标。
|
||||
kwargs: 当前参数字典。
|
||||
|
||||
Returns:
|
||||
Optional[HookResult]: 执行结果,如果执行失败则返回 None
|
||||
HookHandlerExecutionResult: 处理器执行结果。
|
||||
"""
|
||||
try:
|
||||
resp_envelope = await supervisor.invoke_plugin(
|
||||
"plugin.invoke_hook", handler_entry.plugin_id, handler_entry.name, args
|
||||
)
|
||||
resp = resp_envelope.payload
|
||||
result = HookResult(
|
||||
handler_name=handler_entry.full_name,
|
||||
success=resp.get("success", True),
|
||||
continue_processing=resp.get("continue_processing", True),
|
||||
modified_kwargs=resp.get("modified_kwargs"),
|
||||
custom_result=resp.get("custom_result"),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"HookHandler {handler_entry.full_name} 执行失败:{e}", exc_info=True)
|
||||
result = HookResult(handler_name=handler_entry.full_name, success=False, continue_processing=True)
|
||||
|
||||
return result
|
||||
timeout_ms = self._resolve_timeout_ms(hook_spec, target)
|
||||
request_args: Dict[str, Any] = {"hook_name": hook_name, **dict(kwargs)}
|
||||
|
||||
try:
|
||||
response_envelope = await asyncio.wait_for(
|
||||
target.supervisor.invoke_plugin(
|
||||
"plugin.invoke_hook",
|
||||
target.entry.plugin_id,
|
||||
target.entry.name,
|
||||
request_args,
|
||||
timeout_ms=timeout_ms,
|
||||
),
|
||||
timeout=max(timeout_ms / 1000.0, 0.001),
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
error_message = (
|
||||
f"HookHandler {target.entry.full_name} 执行超时,已超过 {timeout_ms}ms"
|
||||
)
|
||||
logger.error(error_message)
|
||||
return HookHandlerExecutionResult(
|
||||
handler_name=target.entry.full_name,
|
||||
plugin_id=target.entry.plugin_id,
|
||||
success=False,
|
||||
error_message=error_message,
|
||||
)
|
||||
except Exception as exc:
|
||||
error_message = f"HookHandler {target.entry.full_name} 执行失败: {exc}"
|
||||
logger.error(error_message, exc_info=True)
|
||||
return HookHandlerExecutionResult(
|
||||
handler_name=target.entry.full_name,
|
||||
plugin_id=target.entry.plugin_id,
|
||||
success=False,
|
||||
error_message=error_message,
|
||||
)
|
||||
|
||||
response_payload = response_envelope.payload
|
||||
if not isinstance(response_payload, dict):
|
||||
return HookHandlerExecutionResult(
|
||||
handler_name=target.entry.full_name,
|
||||
plugin_id=target.entry.plugin_id,
|
||||
custom_result=response_payload,
|
||||
)
|
||||
|
||||
return HookHandlerExecutionResult(
|
||||
handler_name=target.entry.full_name,
|
||||
plugin_id=target.entry.plugin_id,
|
||||
success=bool(response_payload.get("success", True)),
|
||||
action=self._normalize_action(response_payload.get("action", "continue")),
|
||||
modified_kwargs=self._extract_modified_kwargs(response_payload.get("modified_kwargs")),
|
||||
custom_result=response_payload.get("custom_result"),
|
||||
error_message=str(response_payload.get("error_message", "") or ""),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _extract_modified_kwargs(raw_value: Any) -> Optional[Dict[str, Any]]:
|
||||
"""提取并校验处理器返回的 `modified_kwargs`。
|
||||
|
||||
Args:
|
||||
raw_value: 原始返回值。
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: 合法时返回字典,否则返回 `None`。
|
||||
"""
|
||||
|
||||
if raw_value is None:
|
||||
return None
|
||||
if isinstance(raw_value, dict):
|
||||
return dict(raw_value)
|
||||
logger.warning("HookHandler 返回的 modified_kwargs 不是字典,已忽略")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _normalize_action(raw_value: Any) -> str:
|
||||
"""规范化处理器动作返回值。
|
||||
|
||||
Args:
|
||||
raw_value: 原始动作值。
|
||||
|
||||
Returns:
|
||||
str: 规范化后的动作值,仅支持 `continue` 或 `abort`。
|
||||
"""
|
||||
|
||||
normalized_value = str(raw_value or "").strip().lower() or "continue"
|
||||
if normalized_value not in {"continue", "abort"}:
|
||||
logger.warning(f"未知的 Hook action: {raw_value},已按 continue 处理")
|
||||
return "continue"
|
||||
return normalized_value
|
||||
|
||||
def _merge_blocking_result(
|
||||
self,
|
||||
hook_spec: HookSpec,
|
||||
target: _HookInvocationTarget,
|
||||
execution_result: HookHandlerExecutionResult,
|
||||
dispatch_result: HookDispatchResult,
|
||||
) -> None:
|
||||
"""合并阻塞处理器结果到聚合结果。
|
||||
|
||||
Args:
|
||||
hook_spec: 当前 Hook 规格。
|
||||
target: 当前执行目标。
|
||||
execution_result: 当前处理器执行结果。
|
||||
dispatch_result: 当前聚合结果对象。
|
||||
"""
|
||||
|
||||
if execution_result.custom_result is not None:
|
||||
dispatch_result.custom_results.append(execution_result.custom_result)
|
||||
|
||||
if not execution_result.success:
|
||||
error_message = execution_result.error_message or f"HookHandler {target.entry.full_name} 执行失败"
|
||||
dispatch_result.errors.append(error_message)
|
||||
self._apply_error_policy(target, hook_spec, dispatch_result, error_message)
|
||||
return
|
||||
|
||||
if execution_result.modified_kwargs is not None:
|
||||
if hook_spec.allow_kwargs_mutation:
|
||||
dispatch_result.kwargs = dict(execution_result.modified_kwargs)
|
||||
else:
|
||||
error_message = (
|
||||
f"Hook {dispatch_result.hook_name} 不允许修改 kwargs,"
|
||||
f"已忽略 {target.entry.full_name} 的 modified_kwargs"
|
||||
)
|
||||
logger.warning(error_message)
|
||||
dispatch_result.errors.append(error_message)
|
||||
|
||||
if execution_result.action == "abort":
|
||||
if hook_spec.allow_abort:
|
||||
dispatch_result.aborted = True
|
||||
dispatch_result.stopped_by = target.entry.full_name
|
||||
logger.info(f"HookHandler {target.entry.full_name} 中止了 Hook {dispatch_result.hook_name}")
|
||||
else:
|
||||
error_message = (
|
||||
f"Hook {dispatch_result.hook_name} 不允许 abort,"
|
||||
f"已忽略 {target.entry.full_name} 的 abort 请求"
|
||||
)
|
||||
logger.warning(error_message)
|
||||
dispatch_result.errors.append(error_message)
|
||||
|
||||
def _apply_error_policy(
|
||||
self,
|
||||
target: _HookInvocationTarget,
|
||||
hook_spec: HookSpec,
|
||||
dispatch_result: HookDispatchResult,
|
||||
error_message: str,
|
||||
) -> None:
|
||||
"""根据错误策略处理阻塞处理器失败。
|
||||
|
||||
Args:
|
||||
target: 触发错误的处理器目标。
|
||||
hook_spec: 当前 Hook 规格。
|
||||
dispatch_result: 当前聚合结果对象。
|
||||
error_message: 需要记录的错误描述。
|
||||
"""
|
||||
|
||||
if target.entry.error_policy != "abort":
|
||||
return
|
||||
if not hook_spec.allow_abort:
|
||||
logger.warning(
|
||||
f"Hook {dispatch_result.hook_name} 禁止 abort,"
|
||||
f"已将 {target.entry.full_name} 的错误策略按 skip 处理"
|
||||
)
|
||||
return
|
||||
|
||||
dispatch_result.aborted = True
|
||||
dispatch_result.stopped_by = target.entry.full_name
|
||||
logger.warning(
|
||||
f"HookHandler {target.entry.full_name} 因错误策略 abort "
|
||||
f"中止了 Hook {dispatch_result.hook_name}: {error_message}"
|
||||
)
|
||||
|
||||
def _schedule_observe_handler(
|
||||
self,
|
||||
hook_name: str,
|
||||
hook_spec: HookSpec,
|
||||
target: _HookInvocationTarget,
|
||||
kwargs: Dict[str, Any],
|
||||
) -> None:
|
||||
"""后台调度观察型处理器。
|
||||
|
||||
Args:
|
||||
hook_name: 当前 Hook 名称。
|
||||
hook_spec: 当前 Hook 规格。
|
||||
target: 当前观察型处理器目标。
|
||||
kwargs: 调用参数快照。
|
||||
"""
|
||||
|
||||
if not hook_spec.allow_observe:
|
||||
logger.warning(f"Hook {hook_name} 不允许 observe 处理器,已跳过 {target.entry.full_name}")
|
||||
return
|
||||
|
||||
task = asyncio.create_task(
|
||||
self._run_observe_handler(
|
||||
hook_name=hook_name,
|
||||
hook_spec=hook_spec,
|
||||
target=target,
|
||||
kwargs=dict(kwargs),
|
||||
)
|
||||
)
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._handle_background_task_done)
|
||||
|
||||
async def _run_observe_handler(
|
||||
self,
|
||||
hook_name: str,
|
||||
hook_spec: HookSpec,
|
||||
target: _HookInvocationTarget,
|
||||
kwargs: Dict[str, Any],
|
||||
) -> None:
|
||||
"""执行观察型处理器并吞掉控制流副作用。
|
||||
|
||||
Args:
|
||||
hook_name: 当前 Hook 名称。
|
||||
hook_spec: 当前 Hook 规格。
|
||||
target: 当前观察型处理器目标。
|
||||
kwargs: 调用参数快照。
|
||||
"""
|
||||
|
||||
execution_result = await self._invoke_handler(
|
||||
hook_name=hook_name,
|
||||
hook_spec=hook_spec,
|
||||
target=target,
|
||||
kwargs=kwargs,
|
||||
)
|
||||
|
||||
if not execution_result.success:
|
||||
logger.warning(
|
||||
f"观察型 HookHandler {target.entry.full_name} 执行失败: "
|
||||
f"{execution_result.error_message or '未知错误'}"
|
||||
)
|
||||
return
|
||||
|
||||
if execution_result.modified_kwargs is not None:
|
||||
logger.warning(f"观察型 HookHandler {target.entry.full_name} 返回了 modified_kwargs,已忽略")
|
||||
if execution_result.action == "abort":
|
||||
logger.warning(f"观察型 HookHandler {target.entry.full_name} 请求 abort,已忽略")
|
||||
|
||||
def _handle_background_task_done(self, task: asyncio.Task[Any]) -> None:
|
||||
"""处理观察任务完成回调。
|
||||
|
||||
Args:
|
||||
task: 已完成的后台任务。
|
||||
"""
|
||||
|
||||
self._background_tasks.discard(task)
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
exception = task.exception()
|
||||
if exception is not None:
|
||||
logger.error(f"观察型 Hook 后台任务执行失败: {exception}")
|
||||
|
||||
@@ -49,7 +49,7 @@ from .api_registry import APIRegistry
|
||||
from .capability_service import CapabilityService
|
||||
from .component_registry import ComponentRegistry
|
||||
from .event_dispatcher import EventDispatcher
|
||||
from .hook_dispatcher import HookDispatcher
|
||||
from .hook_dispatcher import HookDispatchResult, HookDispatcher
|
||||
from .logger_bridge import RunnerLogBridge
|
||||
from .message_gateway import MessageGateway
|
||||
from .rpc_server import RPCServer
|
||||
@@ -80,6 +80,7 @@ class PluginRunnerSupervisor:
|
||||
def __init__(
|
||||
self,
|
||||
plugin_dirs: Optional[List[Path]] = None,
|
||||
group_name: str = "third_party",
|
||||
socket_path: Optional[str] = None,
|
||||
health_check_interval_sec: Optional[float] = None,
|
||||
max_restart_attempts: Optional[int] = None,
|
||||
@@ -89,12 +90,14 @@ class PluginRunnerSupervisor:
|
||||
|
||||
Args:
|
||||
plugin_dirs: 由当前 Runner 负责加载的插件目录列表。
|
||||
group_name: 当前 Supervisor 所属运行时分组名称。
|
||||
socket_path: 自定义 IPC 地址;留空时由传输层自动生成。
|
||||
health_check_interval_sec: 健康检查间隔,单位秒。
|
||||
max_restart_attempts: 自动重启 Runner 的最大次数。
|
||||
runner_spawn_timeout_sec: 等待 Runner 建连并就绪的超时时间,单位秒。
|
||||
"""
|
||||
runtime_config = global_config.plugin_runtime
|
||||
self._group_name: str = str(group_name or "third_party").strip() or "third_party"
|
||||
self._plugin_dirs: List[Path] = plugin_dirs or []
|
||||
self._health_interval: float = health_check_interval_sec or runtime_config.health_check_interval_sec or 30.0
|
||||
self._runner_spawn_timeout: float = (
|
||||
@@ -108,7 +111,7 @@ class PluginRunnerSupervisor:
|
||||
self._api_registry = APIRegistry()
|
||||
self._component_registry = ComponentRegistry()
|
||||
self._event_dispatcher = EventDispatcher(self._component_registry)
|
||||
self._hook_dispatcher = HookDispatcher(self._component_registry)
|
||||
self._hook_dispatcher = HookDispatcher(lambda: [self])
|
||||
self._message_gateway = MessageGateway(self._component_registry)
|
||||
self._log_bridge = RunnerLogBridge()
|
||||
|
||||
@@ -133,6 +136,12 @@ class PluginRunnerSupervisor:
|
||||
"""返回授权管理器。"""
|
||||
return self._authorization
|
||||
|
||||
@property
|
||||
def group_name(self) -> str:
|
||||
"""返回当前 Supervisor 的运行时分组名称。"""
|
||||
|
||||
return self._group_name
|
||||
|
||||
@property
|
||||
def capability_service(self) -> CapabilityService:
|
||||
"""返回能力服务。"""
|
||||
@@ -243,17 +252,18 @@ class PluginRunnerSupervisor:
|
||||
"""
|
||||
return await self._event_dispatcher.dispatch_event(event_type, self, message, extra_args)
|
||||
|
||||
async def dispatch_hook(self, stage: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
"""分发 Hook 到已注册的 Hook 处理器。
|
||||
async def invoke_hook(self, hook_name: str, **kwargs: Any) -> HookDispatchResult:
|
||||
"""在当前 Supervisor 内触发一次命名 Hook 调用。
|
||||
|
||||
Args:
|
||||
stage: Hook 阶段名称。
|
||||
**kwargs: 传递给 Hook 的关键字参数。
|
||||
hook_name: 本次触发的 Hook 名称。
|
||||
**kwargs: 传递给 Hook 处理器的关键字参数。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 经 Hook 修改后的参数字典。
|
||||
HookDispatchResult: 聚合后的 Hook 调用结果。
|
||||
"""
|
||||
return await self._hook_dispatcher.hook_dispatch(stage, self, **kwargs)
|
||||
|
||||
return await self._hook_dispatcher.invoke_hook(hook_name, **kwargs)
|
||||
|
||||
async def send_message_to_external(
|
||||
self,
|
||||
|
||||
@@ -3,8 +3,9 @@
|
||||
提供 PluginRuntimeManager 单例,负责:
|
||||
1. 管理双 PluginSupervisor 的生命周期(内置插件 / 第三方插件各一个子进程)
|
||||
2. 将 EventType 桥接到运行时的 event dispatch
|
||||
3. 在运行时的 ComponentRegistry 中查找命令
|
||||
4. 提供统一的能力实现注册接口,使插件可以调用主程序功能
|
||||
3. 触发跨 Supervisor 的命名 Hook 调用
|
||||
4. 在运行时的 ComponentRegistry 中查找命令
|
||||
5. 提供统一的能力实现注册接口,使插件可以调用主程序功能
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
@@ -24,6 +25,7 @@ from src.plugin_runtime.capabilities import (
|
||||
RuntimeDataCapabilityMixin,
|
||||
)
|
||||
from src.plugin_runtime.capabilities.registry import register_capability_impls
|
||||
from src.plugin_runtime.host.hook_dispatcher import HookDispatchResult, HookDispatcher, HookSpec
|
||||
from src.plugin_runtime.host.message_utils import MessageDict, PluginMessageUtils
|
||||
from src.plugin_runtime.runner.manifest_validator import ManifestValidator
|
||||
|
||||
@@ -72,6 +74,7 @@ class PluginRuntimeManager(
|
||||
self._manifest_validator: ManifestValidator = ManifestValidator()
|
||||
self._config_reload_callback: Callable[[Sequence[str]], Awaitable[None]] = self._handle_main_config_reload
|
||||
self._config_reload_callback_registered: bool = False
|
||||
self._hook_dispatcher: HookDispatcher = HookDispatcher(lambda: self.supervisors)
|
||||
|
||||
async def _dispatch_platform_inbound(self, envelope: InboundMessageEnvelope) -> None:
|
||||
"""接收 Platform IO 审核后的入站消息并送入主消息链。
|
||||
@@ -182,6 +185,7 @@ class PluginRuntimeManager(
|
||||
if builtin_dirs:
|
||||
self._builtin_supervisor = PluginSupervisor(
|
||||
plugin_dirs=builtin_dirs,
|
||||
group_name="builtin",
|
||||
socket_path=builtin_socket,
|
||||
)
|
||||
self._register_capability_impls(self._builtin_supervisor)
|
||||
@@ -189,6 +193,7 @@ class PluginRuntimeManager(
|
||||
if third_party_dirs:
|
||||
self._third_party_supervisor = PluginSupervisor(
|
||||
plugin_dirs=third_party_dirs,
|
||||
group_name="third_party",
|
||||
socket_path=third_party_socket,
|
||||
)
|
||||
self._register_capability_impls(self._third_party_supervisor)
|
||||
@@ -235,6 +240,7 @@ class PluginRuntimeManager(
|
||||
await platform_io_manager.stop()
|
||||
except Exception as platform_io_exc:
|
||||
logger.warning(f"Platform IO 停止失败: {platform_io_exc}")
|
||||
await self._hook_dispatcher.stop()
|
||||
self._started = False
|
||||
self._builtin_supervisor = None
|
||||
self._third_party_supervisor = None
|
||||
@@ -274,6 +280,7 @@ class PluginRuntimeManager(
|
||||
else:
|
||||
logger.info("插件运行时已停止")
|
||||
finally:
|
||||
await self._hook_dispatcher.stop()
|
||||
self._started = False
|
||||
self._builtin_supervisor = None
|
||||
self._third_party_supervisor = None
|
||||
@@ -284,11 +291,41 @@ class PluginRuntimeManager(
|
||||
"""返回插件运行时是否处于启动状态。"""
|
||||
return self._started
|
||||
|
||||
@property
|
||||
def hook_dispatcher(self) -> HookDispatcher:
|
||||
"""返回跨 Supervisor 的命名 Hook 分发器。"""
|
||||
|
||||
return self._hook_dispatcher
|
||||
|
||||
@property
|
||||
def invoke_dispatcher(self) -> HookDispatcher:
|
||||
"""返回命名 Hook 分发器的兼容别名。"""
|
||||
|
||||
return self._hook_dispatcher
|
||||
|
||||
@property
|
||||
def supervisors(self) -> List["PluginSupervisor"]:
|
||||
"""获取所有活跃的 Supervisor"""
|
||||
return [s for s in (self._builtin_supervisor, self._third_party_supervisor) if s is not None]
|
||||
|
||||
def register_hook_spec(self, spec: HookSpec) -> None:
|
||||
"""注册单个命名 Hook 规格。
|
||||
|
||||
Args:
|
||||
spec: 需要注册的 Hook 规格。
|
||||
"""
|
||||
|
||||
self._hook_dispatcher.register_hook_spec(spec)
|
||||
|
||||
def register_hook_specs(self, specs: Sequence[HookSpec]) -> None:
|
||||
"""批量注册命名 Hook 规格。
|
||||
|
||||
Args:
|
||||
specs: 需要注册的 Hook 规格序列。
|
||||
"""
|
||||
|
||||
self._hook_dispatcher.register_hook_specs(specs)
|
||||
|
||||
def _build_registered_dependency_map(self) -> Dict[str, Set[str]]:
|
||||
"""根据当前已注册插件构建全局依赖图。"""
|
||||
|
||||
@@ -588,6 +625,19 @@ class PluginRuntimeManager(
|
||||
|
||||
return True, modified
|
||||
|
||||
async def invoke_hook(self, hook_name: str, **kwargs: Any) -> HookDispatchResult:
|
||||
"""触发一次跨 Supervisor 的命名 Hook 调用。
|
||||
|
||||
Args:
|
||||
hook_name: 本次触发的 Hook 名称。
|
||||
**kwargs: 传递给 Hook 处理器的关键字参数。
|
||||
|
||||
Returns:
|
||||
HookDispatchResult: 聚合后的 Hook 调用结果。
|
||||
"""
|
||||
|
||||
return await self._hook_dispatcher.invoke_hook(hook_name, **kwargs)
|
||||
|
||||
# ─── 命令查找 ──────────────────────────────────────────────
|
||||
|
||||
def find_command_by_text(self, text: str) -> Optional[Dict[str, Any]]:
|
||||
|
||||
@@ -164,7 +164,7 @@ class RunnerIPCLogHandler(logging.Handler):
|
||||
return f"{event_text} {' '.join(extras)}".strip()
|
||||
return event_text
|
||||
|
||||
# format() 会处理 %s 参数替换和 exc_info 文本拼接。
|
||||
# format() 会处理占位参数替换和 exc_info 文本拼接。
|
||||
return self.format(record)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -330,7 +330,6 @@ class PluginRunner:
|
||||
self._rpc_client.register_method("plugin.invoke_message_gateway", self._handle_invoke)
|
||||
self._rpc_client.register_method("plugin.emit_event", self._handle_event_invoke)
|
||||
self._rpc_client.register_method("plugin.invoke_hook", self._handle_hook_invoke)
|
||||
self._rpc_client.register_method("plugin.invoke_workflow_step", self._handle_workflow_step)
|
||||
self._rpc_client.register_method("plugin.health", self._handle_health)
|
||||
self._rpc_client.register_method("plugin.prepare_shutdown", self._handle_prepare_shutdown)
|
||||
self._rpc_client.register_method("plugin.shutdown", self._handle_shutdown)
|
||||
@@ -1053,73 +1052,28 @@ class PluginRunner:
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"插件 {plugin_id} hook_handler {component_name} 执行异常: {exc}", exc_info=True)
|
||||
return envelope.make_response(payload={"success": False, "continue_processing": True})
|
||||
return envelope.make_response(
|
||||
payload={
|
||||
"success": False,
|
||||
"action": "continue",
|
||||
"error_message": str(exc),
|
||||
}
|
||||
)
|
||||
|
||||
if raw is None:
|
||||
result = {"success": True, "continue_processing": True}
|
||||
result = {"success": True, "action": "continue"}
|
||||
elif isinstance(raw, dict):
|
||||
result = {
|
||||
"success": True,
|
||||
"continue_processing": raw.get("continue_processing", True),
|
||||
"action": str(raw.get("action", "continue") or "continue").strip().lower() or "continue",
|
||||
"modified_kwargs": raw.get("modified_kwargs"),
|
||||
"custom_result": raw.get("custom_result"),
|
||||
}
|
||||
else:
|
||||
result = {"success": True, "continue_processing": True, "custom_result": raw}
|
||||
result = {"success": True, "action": "continue", "custom_result": raw}
|
||||
|
||||
return envelope.make_response(payload=result)
|
||||
|
||||
async def _handle_workflow_step(self, envelope: Envelope) -> Envelope:
|
||||
"""处理 WorkflowStep 调用请求
|
||||
|
||||
与通用 invoke 不同,会将返回值规范化为
|
||||
{hook_result, modified_message, stage_output} 格式。
|
||||
"""
|
||||
try:
|
||||
invoke = InvokePayload.model_validate(envelope.payload)
|
||||
except Exception as e:
|
||||
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(e))
|
||||
|
||||
plugin_id = envelope.plugin_id
|
||||
meta = self._loader.get_plugin(plugin_id)
|
||||
if meta is None:
|
||||
return envelope.make_error_response(
|
||||
ErrorCode.E_PLUGIN_NOT_FOUND.value,
|
||||
f"插件 {plugin_id} 未加载",
|
||||
)
|
||||
|
||||
component_name = invoke.component_name
|
||||
handler_method = self._resolve_component_handler(meta, component_name)
|
||||
|
||||
if handler_method is None or not callable(handler_method):
|
||||
return envelope.make_error_response(
|
||||
ErrorCode.E_METHOD_NOT_ALLOWED.value,
|
||||
f"插件 {plugin_id} 无组件: {component_name}",
|
||||
)
|
||||
|
||||
try:
|
||||
raw = (
|
||||
await handler_method(**invoke.args)
|
||||
if inspect.iscoroutinefunction(handler_method)
|
||||
else handler_method(**invoke.args)
|
||||
)
|
||||
|
||||
# 规范化返回值
|
||||
if isinstance(raw, str):
|
||||
result = {"hook_result": raw}
|
||||
elif isinstance(raw, dict):
|
||||
result = raw
|
||||
result.setdefault("hook_result", "continue")
|
||||
else:
|
||||
result = {"hook_result": "continue"}
|
||||
|
||||
resp_payload = InvokeResultPayload(success=True, result=result)
|
||||
return envelope.make_response(payload=resp_payload.model_dump())
|
||||
except Exception as e:
|
||||
logger.error(f"插件 {plugin_id} workflow_step {component_name} 执行异常: {e}", exc_info=True)
|
||||
resp_payload = InvokeResultPayload(success=False, result=str(e))
|
||||
return envelope.make_response(payload=resp_payload.model_dump())
|
||||
|
||||
async def _handle_health(self, envelope: Envelope) -> Envelope:
|
||||
"""处理健康检查"""
|
||||
uptime_ms = int((time.monotonic() - self._start_time) * 1000)
|
||||
|
||||
48
src/plugin_runtime/tool_provider.py
Normal file
48
src/plugin_runtime/tool_provider.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""插件运行时工具 Provider。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolProvider, ToolSpec
|
||||
|
||||
from .component_query import component_query_service
|
||||
|
||||
|
||||
class PluginToolProvider(ToolProvider):
|
||||
"""将插件 Tool 与兼容旧 Action 暴露为统一工具 Provider。"""
|
||||
|
||||
provider_name = "plugin_runtime"
|
||||
provider_type = "plugin"
|
||||
|
||||
async def list_tools(self) -> list[ToolSpec]:
|
||||
"""列出插件运行时当前可用的工具声明。"""
|
||||
|
||||
return list(component_query_service.get_llm_available_tool_specs().values())
|
||||
|
||||
async def invoke(
|
||||
self,
|
||||
invocation: ToolInvocation,
|
||||
context: Optional[ToolExecutionContext] = None,
|
||||
) -> ToolExecutionResult:
|
||||
"""执行插件工具或兼容旧 Action 的工具调用。
|
||||
|
||||
Args:
|
||||
invocation: 工具调用请求。
|
||||
context: 执行上下文。
|
||||
|
||||
Returns:
|
||||
ToolExecutionResult: 工具执行结果。
|
||||
"""
|
||||
|
||||
return await component_query_service.invoke_tool_as_tool(
|
||||
invocation=invocation,
|
||||
context=context,
|
||||
)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""关闭 Provider。
|
||||
|
||||
插件运行时工具 Provider 不持有独立资源。
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user