From f1563ede654a798be29fa71584fa700bb419539f Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Wed, 22 Apr 2026 00:11:14 +0800 Subject: [PATCH] =?UTF-8?q?feat=EF=BC=9A=E4=BF=AE=E5=A4=8D=E9=97=A8?= =?UTF-8?q?=E6=8E=A7=E5=A4=9A=E9=87=8Dresult=E9=97=AE=E9=A2=98=EF=BC=8C?= =?UTF-8?q?=E6=96=B0=E5=A2=9Eat=E5=8A=A8=E4=BD=9C=EF=BC=8C=E6=8F=92?= =?UTF-8?q?=E4=BB=B6=E7=8E=B0=E5=9C=A8=E8=BF=90=E8=A1=8Cchat=5Fid=E6=8C=87?= =?UTF-8?q?=E5=AE=9A=E6=88=96chat=5Ftype=E6=8C=87=E5=AE=9A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pytests/test_maisaka_builtin_at.py | 176 ++++++++++++++++ pytests/test_tool_availability.py | 96 +++++++++ src/core/tooling.py | 47 ++++- src/maisaka/builtin_tool/__init__.py | 189 ++++++++++++------ src/maisaka/builtin_tool/at.py | 186 +++++++++++++++++ src/maisaka/chat_loop_service.py | 10 +- src/maisaka/reasoning_engine.py | 36 +++- src/maisaka/tool_provider.py | 16 +- src/mcp_module/provider.py | 15 +- src/plugin_runtime/component_query.py | 13 +- src/plugin_runtime/host/component_registry.py | 155 +++++++++++--- src/plugin_runtime/tool_provider.py | 16 +- 12 files changed, 833 insertions(+), 122 deletions(-) create mode 100644 pytests/test_maisaka_builtin_at.py create mode 100644 pytests/test_tool_availability.py create mode 100644 src/maisaka/builtin_tool/at.py diff --git a/pytests/test_maisaka_builtin_at.py b/pytests/test_maisaka_builtin_at.py new file mode 100644 index 00000000..0867340e --- /dev/null +++ b/pytests/test_maisaka_builtin_at.py @@ -0,0 +1,176 @@ +from importlib import util +from pathlib import Path +from types import ModuleType, SimpleNamespace +from typing import Any + +import sys + +import pytest + +from src.common.data_models.message_component_data_model import AtComponent, TextComponent +from src.core.tooling import ToolExecutionResult, ToolInvocation + +_MISSING_MODULE = object() +_module_overrides: dict[str, object] = {} + + +def _override_module(module_name: str, module: ModuleType) -> None: + _module_overrides[module_name] = sys.modules.get(module_name, _MISSING_MODULE) + sys.modules[module_name] = module + + +def _restore_overridden_modules() -> None: + for module_name, previous_module in reversed(_module_overrides.items()): + if previous_module is _MISSING_MODULE: + sys.modules.pop(module_name, None) + else: + sys.modules[module_name] = previous_module + _module_overrides.clear() + + +fake_cli_sender_module = ModuleType("src.cli.maisaka_cli_sender") +fake_cli_sender_module.CLI_PLATFORM_NAME = "cli" +fake_cli_sender_module.render_cli_message = lambda text: text +fake_cli_module = ModuleType("src.cli") +fake_cli_module.maisaka_cli_sender = fake_cli_sender_module + +fake_send_service_module = ModuleType("src.services.send_service") +fake_send_service_module._send_to_target_with_message = None +fake_services_module = ModuleType("src.services") +fake_services_module.send_service = fake_send_service_module + +_override_module("src.cli", fake_cli_module) +_override_module("src.cli.maisaka_cli_sender", fake_cli_sender_module) +_override_module("src.services", fake_services_module) +_override_module("src.services.send_service", fake_send_service_module) + +AT_TOOL_PATH = Path(__file__).resolve().parents[1] / "src" / "maisaka" / "builtin_tool" / "at.py" +at_tool_spec = util.spec_from_file_location("_test_maisaka_builtin_at_tool", AT_TOOL_PATH) +assert at_tool_spec is not None and at_tool_spec.loader is not None +at_tool = util.module_from_spec(at_tool_spec) +sys.modules["_test_maisaka_builtin_at_tool"] = at_tool +try: + at_tool_spec.loader.exec_module(at_tool) +finally: + _restore_overridden_modules() + + +class _ToolCtx: + def __init__(self, runtime: SimpleNamespace) -> None: + self.runtime = runtime + + @staticmethod + def build_success_result( + tool_name: str, + content: str = "", + structured_content: Any = None, + metadata: dict[str, Any] | None = None, + ) -> ToolExecutionResult: + return ToolExecutionResult( + tool_name=tool_name, + success=True, + content=content, + structured_content=structured_content, + metadata=dict(metadata or {}), + ) + + @staticmethod + def build_failure_result( + tool_name: str, + error_message: str, + structured_content: Any = None, + metadata: dict[str, Any] | None = None, + ) -> ToolExecutionResult: + return ToolExecutionResult( + tool_name=tool_name, + success=False, + error_message=error_message, + structured_content=structured_content, + metadata=dict(metadata or {}), + ) + + def append_guided_reply_to_chat_history(self, reply_text: str) -> None: + self.runtime._chat_history.append(reply_text) + + +def _build_tool_ctx(*, group_id: str = "group-1") -> _ToolCtx: + target_message = SimpleNamespace( + message_info=SimpleNamespace( + user_info=SimpleNamespace( + user_id="target-user-1", + user_nickname="目标昵称", + user_cardname="群名片", + ) + ) + ) + runtime = SimpleNamespace( + _source_messages_by_id={"msg-1": target_message}, + chat_stream=SimpleNamespace(platform="qq", group_id=group_id), + session_id="session-1", + log_prefix="[test-at]", + _record_reply_sent=lambda: None, + _chat_history=[], + ) + return _ToolCtx(runtime=runtime) + + +def test_at_tool_spec_does_not_embed_visibility_metadata() -> None: + tool_spec = at_tool.get_tool_spec() + + assert tool_spec.name == "at" + assert "deferred" not in tool_spec.metadata + assert "visibility" not in tool_spec.metadata + + +@pytest.mark.asyncio +async def test_at_tool_sends_at_component_by_msg_id(monkeypatch: pytest.MonkeyPatch) -> None: + captured: dict[str, Any] = {} + + async def fake_send_to_target_with_message(**kwargs: Any) -> object: + captured.update(kwargs) + return SimpleNamespace(message_id="sent-msg-1") + + monkeypatch.setattr(at_tool.send_service, "_send_to_target_with_message", fake_send_to_target_with_message) + + result = await at_tool.handle_tool( + _build_tool_ctx(), + ToolInvocation(tool_name="at", arguments={"msg_id": "msg-1", "text": "看这里"}), + ) + + assert result.success is True + assert result.structured_content["target_user_id"] == "target-user-1" + assert result.structured_content["target_user_name"] == "群名片" + assert captured["stream_id"] == "session-1" + assert captured["display_message"] == "@群名片 看这里" + assert captured["sync_to_maisaka_history"] is True + assert captured["maisaka_source_kind"] == "guided_reply" + + components = captured["message_sequence"].components + assert isinstance(components[0], AtComponent) + assert components[0].target_user_id == "target-user-1" + assert components[0].target_user_nickname == "目标昵称" + assert components[0].target_user_cardname == "群名片" + assert isinstance(components[1], TextComponent) + assert components[1].text == " 看这里" + + +@pytest.mark.asyncio +async def test_at_tool_rejects_private_chat() -> None: + result = await at_tool.handle_tool( + _build_tool_ctx(group_id=""), + ToolInvocation(tool_name="at", arguments={"msg_id": "msg-1"}), + ) + + assert result.success is False + assert "群聊" in result.error_message + + +@pytest.mark.asyncio +async def test_at_tool_rejects_unknown_msg_id() -> None: + result = await at_tool.handle_tool( + _build_tool_ctx(), + ToolInvocation(tool_name="at", arguments={"msg_id": "missing-msg"}), + ) + + assert result.success is False + assert result.structured_content == {"msg_id": "missing-msg"} diff --git a/pytests/test_tool_availability.py b/pytests/test_tool_availability.py new file mode 100644 index 00000000..9946900c --- /dev/null +++ b/pytests/test_tool_availability.py @@ -0,0 +1,96 @@ +from types import SimpleNamespace + +import pytest + +from src.core.tooling import ToolAvailabilityContext, ToolRegistry +from src.maisaka.tool_provider import MaisakaBuiltinToolProvider +from src.plugin_runtime.component_query import ComponentQueryService +from src.plugin_runtime.host.component_registry import ComponentRegistry + + +@pytest.mark.asyncio +async def test_builtin_at_is_exposed_only_in_group_chats() -> None: + registry = ToolRegistry() + registry.register_provider(MaisakaBuiltinToolProvider()) + + group_specs = await registry.list_tools(ToolAvailabilityContext(session_id="group-1", is_group_chat=True)) + private_specs = await registry.list_tools(ToolAvailabilityContext(session_id="private-1", is_group_chat=False)) + default_specs = await registry.list_tools() + + assert "at" in {tool_spec.name for tool_spec in group_specs} + assert "at" not in {tool_spec.name for tool_spec in private_specs} + assert "at" in {tool_spec.name for tool_spec in default_specs} + + +def test_plugin_tool_chat_scope_uses_component_field(monkeypatch: pytest.MonkeyPatch) -> None: + service = ComponentQueryService() + registry = ComponentRegistry() + supervisor = SimpleNamespace(component_registry=registry) + monkeypatch.setattr(service, "_iter_supervisors", lambda: [supervisor]) + + registry.register_plugin_components( + "scope_plugin", + [ + { + "name": "group_tool", + "component_type": "TOOL", + "chat_scope": "group", + "metadata": {"description": "group only"}, + }, + { + "name": "private_tool", + "component_type": "TOOL", + "chat_scope": "private", + "metadata": {"description": "private only"}, + }, + { + "name": "all_tool", + "component_type": "TOOL", + "metadata": {"description": "all chats"}, + }, + ], + ) + + group_specs = service.get_llm_available_tool_specs( + context=ToolAvailabilityContext(session_id="group-1", is_group_chat=True) + ) + private_specs = service.get_llm_available_tool_specs( + context=ToolAvailabilityContext(session_id="private-1", is_group_chat=False) + ) + + group_entry = registry.get_component("scope_plugin.group_tool") + assert group_entry is not None + assert group_entry.chat_scope == "group" + assert "chat_scope" not in group_entry.metadata + assert set(group_specs) == {"group_tool", "all_tool"} + assert set(private_specs) == {"private_tool", "all_tool"} + + +def test_plugin_tool_session_disable_still_filters_specific_chat(monkeypatch: pytest.MonkeyPatch) -> None: + service = ComponentQueryService() + registry = ComponentRegistry() + supervisor = SimpleNamespace(component_registry=registry) + monkeypatch.setattr(service, "_iter_supervisors", lambda: [supervisor]) + + registry.register_plugin_components( + "mute_plugin", + [ + { + "name": "mute", + "component_type": "TOOL", + "chat_scope": "group", + "metadata": {"description": "mute group member"}, + } + ], + ) + registry.set_component_enabled("mute_plugin.mute", False, session_id="group-disabled") + + disabled_specs = service.get_llm_available_tool_specs( + context=ToolAvailabilityContext(session_id="group-disabled", is_group_chat=True) + ) + enabled_specs = service.get_llm_available_tool_specs( + context=ToolAvailabilityContext(session_id="group-enabled", is_group_chat=True) + ) + + assert "mute" not in disabled_specs + assert "mute" in enabled_specs diff --git a/src/core/tooling.py b/src/core/tooling.py index ea78ec74..3cf438d1 100644 --- a/src/core/tooling.py +++ b/src/core/tooling.py @@ -222,6 +222,18 @@ class ToolExecutionContext: metadata: Dict[str, Any] = field(default_factory=dict) +@dataclass(slots=True) +class ToolAvailabilityContext: + """工具暴露可用性判断上下文。""" + + session_id: str = "" + stream_id: str = "" + is_group_chat: bool | None = None + group_id: str = "" + user_id: str = "" + platform: str = "" + + @dataclass(slots=True) class ToolExecutionResult: """统一工具执行结果。""" @@ -264,7 +276,10 @@ class ToolProvider(Protocol): provider_name: str provider_type: str - async def list_tools(self) -> list[ToolSpec]: + async def list_tools( + self, + context: Optional[ToolAvailabilityContext] = None, + ) -> list[ToolSpec]: """列出当前 Provider 暴露的全部工具。""" ... @@ -308,7 +323,10 @@ class ToolRegistry: self._providers = [item for item in self._providers if item.provider_name != provider_name] - async def list_tools(self) -> list[ToolSpec]: + async def list_tools( + self, + context: Optional[ToolAvailabilityContext] = None, + ) -> list[ToolSpec]: """按 Provider 顺序列出全部去重后的工具。 Returns: @@ -319,7 +337,7 @@ class ToolRegistry: seen_names: set[str] = set() for provider in self._providers: - provider_specs = await provider.list_tools() + provider_specs = await provider.list_tools(context) for spec in provider_specs: if not spec.enabled: continue @@ -332,7 +350,11 @@ class ToolRegistry: collected_specs.append(spec) return collected_specs - async def get_tool_spec(self, tool_name: str) -> Optional[ToolSpec]: + async def get_tool_spec( + self, + tool_name: str, + context: Optional[ToolAvailabilityContext] = None, + ) -> Optional[ToolSpec]: """查询指定工具声明。 Args: @@ -342,12 +364,16 @@ class ToolRegistry: Optional[ToolSpec]: 匹配到的工具声明。 """ - for spec in await self.list_tools(): + for spec in await self.list_tools(context): if spec.name == tool_name: return spec return None - async def has_tool(self, tool_name: str) -> bool: + async def has_tool( + self, + tool_name: str, + context: Optional[ToolAvailabilityContext] = None, + ) -> bool: """判断指定工具是否存在。 Args: @@ -357,16 +383,19 @@ class ToolRegistry: bool: 是否存在。 """ - return await self.get_tool_spec(tool_name) is not None + return await self.get_tool_spec(tool_name, context) is not None - async def get_llm_definitions(self) -> list[ToolDefinitionInput]: + async def get_llm_definitions( + self, + context: Optional[ToolAvailabilityContext] = None, + ) -> list[ToolDefinitionInput]: """获取供 LLM 使用的工具定义列表。 Returns: list[ToolDefinitionInput]: 统一工具定义列表。 """ - return [spec.to_llm_definition() for spec in await self.list_tools()] + return [spec.to_llm_definition() for spec in await self.list_tools(context)] async def invoke( self, diff --git a/src/maisaka/builtin_tool/__init__.py b/src/maisaka/builtin_tool/__init__.py index 256a9a44..ba8dae2d 100644 --- a/src/maisaka/builtin_tool/__init__.py +++ b/src/maisaka/builtin_tool/__init__.py @@ -1,12 +1,16 @@ """Maisaka 内置工具聚合入口。""" from collections.abc import Awaitable, Callable -from typing import Dict, List, Optional +from copy import deepcopy +from dataclasses import dataclass +from typing import Dict, List, Literal, Optional from src.config.config import global_config -from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec +from src.core.tooling import ToolAvailabilityContext, ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec from src.llm_models.payload_content.tool_option import ToolDefinitionInput +from .at import get_tool_spec as get_at_tool_spec +from .at import handle_tool as handle_at_tool from .context import BuiltinToolRuntimeContext from .continue_tool import get_tool_spec as get_continue_tool_spec from .continue_tool import handle_tool as handle_continue_tool @@ -32,93 +36,152 @@ from .wait import get_tool_spec as get_wait_tool_spec from .wait import handle_tool as handle_wait_tool BuiltinToolHandler = Callable[[ToolInvocation, Optional[ToolExecutionContext]], Awaitable[ToolExecutionResult]] +BuiltinToolRawHandler = Callable[ + [BuiltinToolRuntimeContext, ToolInvocation, Optional[ToolExecutionContext]], + Awaitable[ToolExecutionResult], +] +BuiltinToolStage = Literal["timing", "action"] +BuiltinToolVisibility = Literal["visible", "deferred", "hidden"] +BuiltinToolChatScope = Literal["all", "group", "private"] -def get_timing_tool_specs() -> List[ToolSpec]: - """获取 Timing Gate 阶段可用的内置工具声明。""" +@dataclass(frozen=True) +class BuiltinToolEntry: + """内置工具目录项,集中声明工具所属阶段与默认可见性。""" - return [ - get_wait_tool_spec(), - get_no_reply_tool_spec(), - get_continue_tool_spec(), - ] + name: str + get_spec: Callable[[], ToolSpec] + handle_tool: BuiltinToolRawHandler + stage: BuiltinToolStage + visibility: BuiltinToolVisibility = "visible" + chat_scope: BuiltinToolChatScope = "all" + + def build_spec(self) -> ToolSpec: + """生成带统一可见性元数据的工具声明。""" + + tool_spec = deepcopy(self.get_spec()) + tool_spec.metadata["builtin_stage"] = self.stage + tool_spec.metadata["visibility"] = self.visibility + return tool_spec -def get_action_tool_specs() -> List[ToolSpec]: - """获取 Action Loop 阶段可用的内置工具声明。""" +def _get_query_memory_tool_spec() -> ToolSpec: + """根据配置生成 query_memory 工具声明。""" - return [ - get_finish_tool_spec(), - get_reply_tool_spec(), - get_view_complex_message_tool_spec(), - get_query_jargon_tool_spec(), - get_query_memory_tool_spec(enabled=bool(global_config.memory.enable_memory_query_tool)), - get_send_emoji_tool_spec(), - get_tool_search_tool_spec(), - ] + return get_query_memory_tool_spec(enabled=bool(global_config.memory.enable_memory_query_tool)) -def get_builtin_tool_specs() -> List[ToolSpec]: - """获取默认暴露的 Maisaka 内置工具声明。""" - - return get_action_tool_specs() +BUILTIN_TOOL_ENTRIES: List[BuiltinToolEntry] = [ + BuiltinToolEntry("wait", get_wait_tool_spec, handle_wait_tool, stage="timing"), + BuiltinToolEntry("no_reply", get_no_reply_tool_spec, handle_no_reply_tool, stage="timing"), + BuiltinToolEntry("continue", get_continue_tool_spec, handle_continue_tool, stage="timing"), + BuiltinToolEntry("finish", get_finish_tool_spec, handle_finish_tool, stage="action"), + BuiltinToolEntry("reply", get_reply_tool_spec, handle_reply_tool, stage="action"), + BuiltinToolEntry( + "view_complex_message", + get_view_complex_message_tool_spec, + handle_view_complex_message_tool, + stage="action", + ), + BuiltinToolEntry("query_jargon", get_query_jargon_tool_spec, handle_query_jargon_tool, stage="action"), + BuiltinToolEntry("query_memory", _get_query_memory_tool_spec, handle_query_memory_tool, stage="action"), + BuiltinToolEntry( + "query_person_info", + get_query_person_info_tool_spec, + handle_query_person_info_tool, + stage="action", + visibility="hidden", + ), + BuiltinToolEntry("send_emoji", get_send_emoji_tool_spec, handle_send_emoji_tool, stage="action"), + BuiltinToolEntry( + "at", + get_at_tool_spec, + handle_at_tool, + stage="action", + visibility="deferred", + chat_scope="group", + ), + BuiltinToolEntry("tool_search", get_tool_search_tool_spec, handle_tool_search_tool, stage="action"), +] -def get_all_builtin_tool_specs() -> List[ToolSpec]: +def _get_builtin_tool_entries( + *, + stage: Optional[BuiltinToolStage] = None, + visibility: Optional[BuiltinToolVisibility] = None, + context: Optional[ToolAvailabilityContext] = None, +) -> List[BuiltinToolEntry]: + """按阶段与可见性筛选内置工具目录项。""" + + entries = BUILTIN_TOOL_ENTRIES + if stage is not None: + entries = [entry for entry in entries if entry.stage == stage] + if visibility is not None: + entries = [entry for entry in entries if entry.visibility == visibility] + if context is not None: + entries = [entry for entry in entries if _is_builtin_tool_available(entry, context)] + return entries + + +def _is_builtin_tool_available(entry: BuiltinToolEntry, context: ToolAvailabilityContext) -> bool: + """判断内置工具是否适用于当前聊天。""" + + if entry.chat_scope == "all": + return True + if entry.chat_scope == "group": + return context.is_group_chat is True + if entry.chat_scope == "private": + return context.is_group_chat is False + return True + + +def get_builtin_tool_visibility(tool_spec: ToolSpec) -> BuiltinToolVisibility: + """读取工具声明里的可见性。""" + + raw_visibility = str(tool_spec.metadata.get("visibility") or "").strip() + if raw_visibility == "deferred": + return "deferred" + if raw_visibility == "hidden": + return "hidden" + return "visible" + + +def is_builtin_tool_in_action_stage(tool_spec: ToolSpec) -> bool: + """判断内置工具是否属于 Action Loop 阶段。""" + + return str(tool_spec.metadata.get("builtin_stage") or "").strip() == "action" + + +def get_all_builtin_tool_specs(context: Optional[ToolAvailabilityContext] = None) -> List[ToolSpec]: """获取全部内置工具声明。""" - return [ - *get_timing_tool_specs(), - get_finish_tool_spec(), - get_reply_tool_spec(), - get_view_complex_message_tool_spec(), - get_query_jargon_tool_spec(), - get_query_memory_tool_spec(enabled=True), - get_query_person_info_tool_spec(), - get_send_emoji_tool_spec(), - get_tool_search_tool_spec(), - ] + return [entry.build_spec() for entry in _get_builtin_tool_entries(context=context)] def get_timing_tools() -> List[ToolDefinitionInput]: """获取 Timing Gate 阶段的兼容工具定义。""" - return [tool_spec.to_llm_definition() for tool_spec in get_timing_tool_specs()] - - -def get_action_tools() -> List[ToolDefinitionInput]: - """获取 Action Loop 阶段的兼容工具定义。""" - - return [tool_spec.to_llm_definition() for tool_spec in get_action_tool_specs()] + tool_specs = [ + entry.build_spec() + for entry in _get_builtin_tool_entries(stage="timing", visibility="visible") + ] + return [tool_spec.to_llm_definition() for tool_spec in tool_specs if tool_spec.enabled] def get_builtin_tools() -> List[ToolDefinitionInput]: """获取默认暴露给模型层的内置工具定义。""" - return get_action_tools() + tool_specs = [ + entry.build_spec() + for entry in _get_builtin_tool_entries(stage="action", visibility="visible") + ] + return [tool_spec.to_llm_definition() for tool_spec in tool_specs if tool_spec.enabled] def build_builtin_tool_handlers(tool_ctx: BuiltinToolRuntimeContext) -> Dict[str, BuiltinToolHandler]: """构建内置工具处理器映射。""" return { - "continue": lambda invocation, context=None: handle_continue_tool(tool_ctx, invocation, context), - "finish": lambda invocation, context=None: handle_finish_tool(tool_ctx, invocation, context), - "reply": lambda invocation, context=None: handle_reply_tool(tool_ctx, invocation, context), - "no_reply": lambda invocation, context=None: handle_no_reply_tool(tool_ctx, invocation, context), - "query_jargon": lambda invocation, context=None: handle_query_jargon_tool(tool_ctx, invocation, context), - "query_memory": lambda invocation, context=None: handle_query_memory_tool(tool_ctx, invocation, context), - "query_person_info": lambda invocation, context=None: handle_query_person_info_tool( - tool_ctx, - invocation, - context, - ), - "wait": lambda invocation, context=None: handle_wait_tool(tool_ctx, invocation, context), - "send_emoji": lambda invocation, context=None: handle_send_emoji_tool(tool_ctx, invocation, context), - "tool_search": lambda invocation, context=None: handle_tool_search_tool(tool_ctx, invocation, context), - "view_complex_message": lambda invocation, context=None: handle_view_complex_message_tool( - tool_ctx, - invocation, - context, - ), + entry.name: lambda invocation, context=None, entry=entry: entry.handle_tool(tool_ctx, invocation, context) + for entry in BUILTIN_TOOL_ENTRIES } diff --git a/src/maisaka/builtin_tool/at.py b/src/maisaka/builtin_tool/at.py new file mode 100644 index 00000000..7f14ecba --- /dev/null +++ b/src/maisaka/builtin_tool/at.py @@ -0,0 +1,186 @@ +"""Maisaka 内置 at 工具。""" + +from typing import Any, Optional, TYPE_CHECKING + +from src.cli.maisaka_cli_sender import CLI_PLATFORM_NAME, render_cli_message +from src.common.data_models.message_component_data_model import AtComponent, MessageSequence, TextComponent +from src.common.logger import get_logger +from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec +from src.services import send_service + +if TYPE_CHECKING: + from .context import BuiltinToolRuntimeContext + +logger = get_logger("maisaka_builtin_at") + + +def get_tool_spec() -> ToolSpec: + """获取 at 工具声明。""" + + return ToolSpec( + name="at", + brief_description="根据一条已知 msg_id 找到发言用户,并发送一条 @ 该用户的消息。", + detailed_description=( + "参数说明:\n" + "- msg_id:string,必填。要 @ 的目标用户发过的消息编号。\n" + "- text:string,可选。@ 后追加发送的短文本;只想单独 @ 人时留空。\n" + "请优先从上下文里选择一条明确属于目标用户的 msg_id,不要凭昵称或印象猜测用户。" + ), + parameters_schema={ + "type": "object", + "properties": { + "msg_id": { + "type": "string", + "description": "要 @ 的目标用户发过的消息编号。", + }, + "text": { + "type": "string", + "description": "@ 后追加发送的短文本;只想单独 @ 人时留空。", + "default": "", + }, + }, + "required": ["msg_id"], + }, + provider_name="maisaka_builtin", + provider_type="builtin", + ) + + +def _get_target_user_info(target_message: Any) -> tuple[str, str, str]: + """从目标消息中提取可用于构造 at 组件的用户信息。""" + + message_info = getattr(target_message, "message_info", None) + user_info = getattr(message_info, "user_info", None) + target_user_id = str(getattr(user_info, "user_id", "") or "").strip() + target_user_nickname = str(getattr(user_info, "user_nickname", "") or "").strip() + target_user_cardname = str(getattr(user_info, "user_cardname", "") or "").strip() + return target_user_id, target_user_nickname, target_user_cardname + + +def _build_at_message_sequence( + *, + target_user_id: str, + target_user_nickname: str = "", + target_user_cardname: str = "", + text: str = "", +) -> MessageSequence: + """构造 @ 用户的消息组件序列。""" + + components = [ + AtComponent( + target_user_id=target_user_id, + target_user_nickname=target_user_nickname or None, + target_user_cardname=target_user_cardname or None, + ) + ] + normalized_text = text.strip() + if normalized_text: + components.append(TextComponent(f" {normalized_text}")) + return MessageSequence(components=components) + + +async def handle_tool( + tool_ctx: "BuiltinToolRuntimeContext", + invocation: ToolInvocation, + context: Optional[ToolExecutionContext] = None, +) -> ToolExecutionResult: + """执行 at 内置工具。""" + + del context + target_message_id = str(invocation.arguments.get("msg_id") or "").strip() + text = str(invocation.arguments.get("text") or "").strip() + + if not target_message_id: + return tool_ctx.build_failure_result( + invocation.tool_name, + "at 工具需要提供有效的 `msg_id` 参数。", + ) + + if not str(getattr(tool_ctx.runtime.chat_stream, "group_id", "") or "").strip(): + return tool_ctx.build_failure_result( + invocation.tool_name, + "at 工具只能在群聊中使用。", + structured_content={"msg_id": target_message_id}, + ) + + target_message = tool_ctx.runtime._source_messages_by_id.get(target_message_id) + if target_message is None: + return tool_ctx.build_failure_result( + invocation.tool_name, + f"未找到要 @ 的目标消息,msg_id={target_message_id}", + structured_content={"msg_id": target_message_id}, + ) + + target_user_id, target_user_nickname, target_user_cardname = _get_target_user_info(target_message) + if not target_user_id: + return tool_ctx.build_failure_result( + invocation.tool_name, + f"目标消息缺少有效用户 ID,msg_id={target_message_id}", + structured_content={"msg_id": target_message_id}, + ) + + target_user_name = target_user_cardname or target_user_nickname or target_user_id + message_sequence = _build_at_message_sequence( + target_user_id=target_user_id, + target_user_nickname=target_user_nickname, + target_user_cardname=target_user_cardname, + text=text, + ) + display_message = f"@{target_user_name}" + (f" {text}" if text else "") + + try: + if tool_ctx.runtime.chat_stream.platform == CLI_PLATFORM_NAME: + render_cli_message(display_message) + tool_ctx.append_guided_reply_to_chat_history(display_message) + sent_message = None + sent = True + else: + sent_message = await send_service._send_to_target_with_message( + message_sequence=message_sequence, + stream_id=tool_ctx.runtime.session_id, + display_message=display_message, + typing=False, + storage_message=True, + show_log=True, + sync_to_maisaka_history=True, + maisaka_source_kind="guided_reply", + ) + sent = sent_message is not None + except Exception as exc: + logger.exception( + f"{tool_ctx.runtime.log_prefix} 发送 at 消息时发生异常: msg_id={target_message_id} user_id={target_user_id}" + ) + return tool_ctx.build_failure_result( + invocation.tool_name, + f"发送 at 消息时发生异常:{exc}", + structured_content={ + "msg_id": target_message_id, + "target_user_id": target_user_id, + "target_user_name": target_user_name, + }, + ) + + if not sent: + return tool_ctx.build_failure_result( + invocation.tool_name, + "at 消息发送失败。", + structured_content={ + "msg_id": target_message_id, + "target_user_id": target_user_id, + "target_user_name": target_user_name, + }, + ) + + sent_message_id = str(getattr(sent_message, "message_id", "") or "").strip() if sent_message is not None else "" + tool_ctx.runtime._record_reply_sent() + return tool_ctx.build_success_result( + invocation.tool_name, + f"已 @ {target_user_name}。", + structured_content={ + "msg_id": target_message_id, + "target_user_id": target_user_id, + "target_user_name": target_user_name, + "text": text, + "sent_message_id": sent_message_id, + }, + ) diff --git a/src/maisaka/chat_loop_service.py b/src/maisaka/chat_loop_service.py index e14dd0c0..9363a13c 100644 --- a/src/maisaka/chat_loop_service.py +++ b/src/maisaka/chat_loop_service.py @@ -12,7 +12,7 @@ from src.common.logger import get_logger from src.common.prompt_i18n import load_prompt from src.common.utils.utils_session import SessionUtils from src.config.config import global_config -from src.core.tooling import ToolRegistry +from src.core.tooling import ToolAvailabilityContext, ToolRegistry from src.llm_models.model_client.base_client import BaseClient from src.llm_models.payload_content.message import Message, MessageBuilder, RoleType from src.llm_models.payload_content.resp_format import RespFormat @@ -502,7 +502,13 @@ class MaisakaChatLoopService: if tool_definitions is not None: all_tools = list(tool_definitions) elif self._tool_registry is not None: - tool_specs = await self._tool_registry.list_tools() + tool_specs = await self._tool_registry.list_tools( + ToolAvailabilityContext( + session_id=self._session_id, + stream_id=self._session_id, + is_group_chat=self._is_group_chat, + ) + ) all_tools = [tool_spec.to_llm_definition() for tool_spec in tool_specs] else: all_tools = [*get_builtin_tools(), *self._extra_tools] diff --git a/src/maisaka/reasoning_engine.py b/src/maisaka/reasoning_engine.py index 603ee2a3..1f4c5500 100644 --- a/src/maisaka/reasoning_engine.py +++ b/src/maisaka/reasoning_engine.py @@ -14,14 +14,14 @@ from src.chat.message_receive.message import SessionMessage from src.common.data_models.message_component_data_model import EmojiComponent, ImageComponent, MessageSequence from src.common.logger import get_logger from src.common.prompt_i18n import load_prompt -from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec +from src.core.tooling import ToolAvailabilityContext, ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec from src.llm_models.exceptions import ReqAbortException from src.llm_models.payload_content.tool_option import ToolCall from src.services import database_service as database_api from src.services.memory_service import memory_service -from .builtin_tool import get_action_tool_specs from .builtin_tool import build_builtin_tool_handlers as build_split_builtin_tool_handlers +from .builtin_tool import get_builtin_tool_visibility, is_builtin_tool_in_action_stage from .builtin_tool import get_timing_tools from .chat_loop_service import ChatResponse from .chat_history_visual_refresher import refresh_chat_history_visual_placeholders @@ -54,8 +54,6 @@ logger = get_logger("maisaka_reasoning_engine") TIMING_GATE_CONTEXT_LIMIT = 24 TIMING_GATE_MAX_TOKENS = 384 TIMING_GATE_TOOL_NAMES = {"continue", "no_reply", "wait"} -ACTION_HIDDEN_TOOL_NAMES = {"continue", "no_reply"} -ACTION_BUILTIN_TOOL_NAMES = {tool_spec.name for tool_spec in get_action_tool_specs()} class MaisakaReasoningEngine: @@ -175,15 +173,19 @@ class MaisakaReasoningEngine: self._runtime.set_current_action_tool_names([]) return [], "" - tool_specs = await self._runtime._tool_registry.list_tools() + availability_context = self._build_tool_availability_context() + tool_specs = await self._runtime._tool_registry.list_tools(availability_context) visible_builtin_tool_specs: list[ToolSpec] = [] deferred_tool_specs: list[ToolSpec] = [] for tool_spec in tool_specs: - if tool_spec.name in ACTION_HIDDEN_TOOL_NAMES: - continue if tool_spec.provider_name == "maisaka_builtin": - if tool_spec.name in ACTION_BUILTIN_TOOL_NAMES: + if not is_builtin_tool_in_action_stage(tool_spec): + continue + visibility = get_builtin_tool_visibility(tool_spec) + if visibility == "visible": visible_builtin_tool_specs.append(tool_spec) + elif visibility == "deferred": + deferred_tool_specs.append(tool_spec) continue deferred_tool_specs.append(tool_spec) @@ -877,6 +879,19 @@ class MaisakaReasoningEngine: reasoning=latest_thought, ) + def _build_tool_availability_context(self) -> ToolAvailabilityContext: + """构造当前聊天的工具暴露上下文。""" + + chat_stream = self._runtime.chat_stream + return ToolAvailabilityContext( + session_id=self._runtime.session_id, + stream_id=self._runtime.session_id, + is_group_chat=chat_stream.is_group_session, + group_id=str(getattr(chat_stream, "group_id", "") or "").strip(), + user_id=str(getattr(chat_stream, "user_id", "") or "").strip(), + platform=str(getattr(chat_stream, "platform", "") or "").strip(), + ) + def _build_tool_execution_context( self, latest_thought: str, @@ -1197,6 +1212,8 @@ class MaisakaReasoningEngine: source_kind="timing_gate", ) ) + if tool_call.func_name == "wait": + return self._append_tool_execution_result(tool_call, result) def _build_tool_result_summary(self, tool_call: ToolCall, result: ToolExecutionResult) -> str: @@ -1290,9 +1307,10 @@ class MaisakaReasoningEngine: return False, tool_result_summaries, tool_monitor_results execution_context = self._build_tool_execution_context(latest_thought, anchor_message) + availability_context = self._build_tool_availability_context() tool_spec_map = { tool_spec.name: tool_spec - for tool_spec in await self._runtime._tool_registry.list_tools() + for tool_spec in await self._runtime._tool_registry.list_tools(availability_context) } total_tool_count = len(tool_calls) for tool_index, tool_call in enumerate(tool_calls, start=1): diff --git a/src/maisaka/tool_provider.py b/src/maisaka/tool_provider.py index b26ba138..908e5ad6 100644 --- a/src/maisaka/tool_provider.py +++ b/src/maisaka/tool_provider.py @@ -5,7 +5,14 @@ from __future__ import annotations from collections.abc import Awaitable, Callable from typing import Dict, Optional -from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolProvider, ToolSpec +from src.core.tooling import ( + ToolAvailabilityContext, + ToolExecutionContext, + ToolExecutionResult, + ToolInvocation, + ToolProvider, + ToolSpec, +) from .builtin_tool import get_all_builtin_tool_specs @@ -27,10 +34,13 @@ class MaisakaBuiltinToolProvider(ToolProvider): self._handlers = dict(handlers or {}) - async def list_tools(self) -> list[ToolSpec]: + async def list_tools( + self, + context: Optional[ToolAvailabilityContext] = None, + ) -> list[ToolSpec]: """列出全部内置工具。""" - return list(get_all_builtin_tool_specs()) + return list(get_all_builtin_tool_specs(context)) async def invoke( self, diff --git a/src/mcp_module/provider.py b/src/mcp_module/provider.py index 84065eb8..9f8e0cd3 100644 --- a/src/mcp_module/provider.py +++ b/src/mcp_module/provider.py @@ -4,7 +4,14 @@ from __future__ import annotations from typing import Optional -from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolProvider, ToolSpec +from src.core.tooling import ( + ToolAvailabilityContext, + ToolExecutionContext, + ToolExecutionResult, + ToolInvocation, + ToolProvider, + ToolSpec, +) from .manager import MCPManager @@ -24,9 +31,13 @@ class MCPToolProvider(ToolProvider): self._manager = manager - async def list_tools(self) -> list[ToolSpec]: + async def list_tools( + self, + context: Optional[ToolAvailabilityContext] = None, + ) -> list[ToolSpec]: """列出全部 MCP 工具。""" + del context return self._manager.get_tool_specs() async def invoke( diff --git a/src/plugin_runtime/component_query.py b/src/plugin_runtime/component_query.py index dbd448f5..2fb01797 100644 --- a/src/plugin_runtime/component_query.py +++ b/src/plugin_runtime/component_query.py @@ -11,6 +11,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional, Tupl from src.common.logger import get_logger from src.core.tooling import ( + ToolAvailabilityContext, ToolExecutionContext, ToolExecutionResult, ToolInvocation, @@ -72,6 +73,7 @@ class ComponentQueryService: component_type: ComponentType, *, enabled_only: bool = True, + context: Optional[ToolAvailabilityContext] = None, ) -> list[tuple["PluginSupervisor", "ComponentEntry"]]: """遍历指定类型的全部组件条目。 @@ -87,11 +89,15 @@ class ComponentQueryService: if host_component_type is None: return [] + session_id = context.session_id if context is not None else None + is_group_chat = context.is_group_chat if context is not None else None collected_entries: list[tuple["PluginSupervisor", "ComponentEntry"]] = [] for supervisor in self._iter_supervisors(): for component in supervisor.component_registry.get_components_by_type( host_component_type, enabled_only=enabled_only, + session_id=session_id, + is_group_chat=is_group_chat, ): collected_entries.append((supervisor, component)) return collected_entries @@ -657,7 +663,10 @@ class ComponentQueryService: tool_entry = cast("ToolEntry", entry) return self._build_tool_executor(supervisor, tool_entry.plugin_id, tool_entry.name, tool_entry.invoke_method) - def get_llm_available_tool_specs(self) -> Dict[str, ToolSpec]: + def get_llm_available_tool_specs( + self, + context: Optional[ToolAvailabilityContext] = None, + ) -> Dict[str, ToolSpec]: """获取当前可供 LLM 使用的统一工具声明集合。 Returns: @@ -665,7 +674,7 @@ class ComponentQueryService: """ collected_specs: Dict[str, ToolSpec] = {} - for _supervisor, entry in self._iter_component_entries(ComponentType.TOOL): + for _supervisor, entry in self._iter_component_entries(ComponentType.TOOL, context=context): if entry.name in collected_specs: self._log_duplicate_component(ComponentType.TOOL, entry.name) continue diff --git a/src/plugin_runtime/host/component_registry.py b/src/plugin_runtime/host/component_registry.py index 07e4d8ea..fceb7828 100644 --- a/src/plugin_runtime/host/component_registry.py +++ b/src/plugin_runtime/host/component_registry.py @@ -10,7 +10,7 @@ """ from enum import Enum -from typing import Any, Dict, List, Optional, Set, TypedDict, Tuple +from typing import Any, Dict, List, Literal, Optional, Set, Tuple, TypedDict import contextlib import re @@ -58,6 +58,20 @@ class ComponentTypes(str, Enum): MESSAGE_GATEWAY = "MESSAGE_GATEWAY" +ComponentChatScope = Literal["all", "group", "private"] + + +def _normalize_chat_scope(raw_value: Any) -> ComponentChatScope: + """规范化组件聊天类型适用范围。""" + + normalized_value = str(raw_value or "all").strip().lower() + if normalized_value == "group": + return "group" + if normalized_value == "private": + return "private" + return "all" + + class StatusDict(TypedDict): total: int action: int @@ -81,9 +95,17 @@ class ComponentEntry: "enabled", "compiled_pattern", "disabled_session", + "chat_scope", ) - def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None: + def __init__( + self, + name: str, + component_type: str, + plugin_id: str, + metadata: Dict[str, Any], + chat_scope: str = "all", + ) -> None: self.name: str = name self.full_name: str = f"{plugin_id}.{name}" self.component_type: ComponentTypes = ComponentTypes(component_type) @@ -91,20 +113,35 @@ class ComponentEntry: self.metadata: Dict[str, Any] = metadata self.enabled: bool = metadata.get("enabled", True) self.disabled_session: Set[str] = set() + self.chat_scope: ComponentChatScope = _normalize_chat_scope(chat_scope) class ActionEntry(ComponentEntry): """Action 组件条目""" - def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None: - super().__init__(name, component_type, plugin_id, metadata) + def __init__( + self, + name: str, + component_type: str, + plugin_id: str, + metadata: Dict[str, Any], + chat_scope: str = "all", + ) -> None: + super().__init__(name, component_type, plugin_id, metadata, chat_scope) class CommandEntry(ComponentEntry): """Command 组件条目""" - def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None: - super().__init__(name, component_type, plugin_id, metadata) + def __init__( + self, + name: str, + component_type: str, + plugin_id: str, + metadata: Dict[str, Any], + chat_scope: str = "all", + ) -> None: + super().__init__(name, component_type, plugin_id, metadata, chat_scope) self.aliases: List[str] = metadata.get("aliases", []) self.compiled_pattern: Optional[re.Pattern] = None if pattern := metadata.get("command_pattern", ""): @@ -117,7 +154,14 @@ class CommandEntry(ComponentEntry): class ToolEntry(ComponentEntry): """Tool 组件条目""" - def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None: + def __init__( + self, + name: str, + component_type: str, + plugin_id: str, + metadata: Dict[str, Any], + chat_scope: str = "all", + ) -> None: self.description: str = str(metadata.get("description", "") or "").strip() self.brief_description: str = str( metadata.get("brief_description", self.description) or self.description or f"工具 {name}" @@ -128,7 +172,7 @@ class ToolEntry(ComponentEntry): self.detailed_description: str = detailed_description self.invoke_method: str = str(metadata.get("invoke_method", "plugin.invoke_tool") or "plugin.invoke_tool").strip() self.legacy_component_type: str = str(metadata.get("legacy_component_type", "") or "").strip() - super().__init__(name, component_type, plugin_id, metadata) + super().__init__(name, component_type, plugin_id, metadata, chat_scope) if not self.detailed_description: parameters_schema = self._get_parameters_schema() @@ -197,23 +241,37 @@ class ToolEntry(ComponentEntry): class EventHandlerEntry(ComponentEntry): """EventHandler 组件条目""" - def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None: + def __init__( + self, + name: str, + component_type: str, + plugin_id: str, + metadata: Dict[str, Any], + chat_scope: str = "all", + ) -> None: self.event_type: str = metadata.get("event_type", "") self.weight: int = metadata.get("weight", 0) self.intercept_message: bool = metadata.get("intercept_message", False) - super().__init__(name, component_type, plugin_id, metadata) + super().__init__(name, component_type, plugin_id, metadata, chat_scope) class HookHandlerEntry(ComponentEntry): """HookHandler 组件条目。""" - def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None: + def __init__( + self, + name: str, + component_type: str, + plugin_id: str, + metadata: Dict[str, Any], + chat_scope: str = "all", + ) -> None: self.hook: str = self._normalize_hook_name(metadata.get("hook", "")) self.mode: str = self._normalize_mode(metadata.get("mode", "blocking")) self.order: str = self._normalize_order(metadata.get("order", "normal")) self.timeout_ms: int = self._normalize_timeout_ms(metadata.get("timeout_ms", 0)) self.error_policy: str = self._normalize_error_policy(metadata.get("error_policy", "skip")) - super().__init__(name, component_type, plugin_id, metadata) + super().__init__(name, component_type, plugin_id, metadata, chat_scope) @staticmethod def _normalize_error_policy(raw_value: Any) -> str: @@ -332,13 +390,20 @@ class HookHandlerEntry(ComponentEntry): class MessageGatewayEntry(ComponentEntry): """MessageGateway 组件条目""" - def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None: + def __init__( + self, + name: str, + component_type: str, + plugin_id: str, + metadata: Dict[str, Any], + chat_scope: str = "all", + ) -> None: self.route_type: str = self._normalize_route_type(metadata.get("route_type", "")) self.platform: str = str(metadata.get("platform", "") or "").strip() self.protocol: str = str(metadata.get("protocol", "") or "").strip() self.account_id: str = str(metadata.get("account_id", "") or "").strip() self.scope: str = str(metadata.get("scope", "") or "").strip() - super().__init__(name, component_type, plugin_id, metadata) + super().__init__(name, component_type, plugin_id, metadata, chat_scope) @staticmethod def _normalize_route_type(raw_value: Any) -> str: @@ -578,6 +643,7 @@ class ComponentRegistry: component_type: str, plugin_id: str, metadata: Dict[str, Any], + chat_scope: str = "all", ) -> ComponentEntry: """根据声明构造组件条目。 @@ -599,18 +665,18 @@ class ComponentRegistry: normalized_metadata = dict(metadata) if normalized_type == ComponentTypes.ACTION: normalized_metadata = self._convert_action_metadata_to_tool_metadata(name, normalized_metadata) - component = ToolEntry(name, ComponentTypes.TOOL.value, plugin_id, normalized_metadata) + component = ToolEntry(name, ComponentTypes.TOOL.value, plugin_id, normalized_metadata, chat_scope) elif normalized_type == ComponentTypes.COMMAND: - component = CommandEntry(name, normalized_type.value, plugin_id, normalized_metadata) + component = CommandEntry(name, normalized_type.value, plugin_id, normalized_metadata, chat_scope) elif normalized_type == ComponentTypes.TOOL: - component = ToolEntry(name, normalized_type.value, plugin_id, normalized_metadata) + component = ToolEntry(name, normalized_type.value, plugin_id, normalized_metadata, chat_scope) elif normalized_type == ComponentTypes.EVENT_HANDLER: - component = EventHandlerEntry(name, normalized_type.value, plugin_id, normalized_metadata) + component = EventHandlerEntry(name, normalized_type.value, plugin_id, normalized_metadata, chat_scope) elif normalized_type == ComponentTypes.HOOK_HANDLER: - component = HookHandlerEntry(name, normalized_type.value, plugin_id, normalized_metadata) + component = HookHandlerEntry(name, normalized_type.value, plugin_id, normalized_metadata, chat_scope) self._validate_hook_handler_entry(component) elif normalized_type == ComponentTypes.MESSAGE_GATEWAY: - component = MessageGatewayEntry(name, normalized_type.value, plugin_id, normalized_metadata) + component = MessageGatewayEntry(name, normalized_type.value, plugin_id, normalized_metadata, chat_scope) else: raise ComponentRegistrationError( f"组件类型 {component_type} 不存在", @@ -662,7 +728,14 @@ class ComponentRegistry: self._by_plugin.setdefault(component.plugin_id, []).append(component) # ====== 注册 / 注销 ====== - def register_component(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> bool: + def register_component( + self, + name: str, + component_type: str, + plugin_id: str, + metadata: Dict[str, Any], + chat_scope: str = "all", + ) -> bool: """注册单个组件。 Args: @@ -678,7 +751,7 @@ class ComponentRegistry: ComponentRegistrationError: 组件声明不合法时抛出。 """ - component = self._build_component_entry(name, component_type, plugin_id, metadata) + component = self._build_component_entry(name, component_type, plugin_id, metadata, chat_scope) self._add_component_entry(component) return True @@ -701,14 +774,19 @@ class ComponentRegistry: prepared_components: List[ComponentEntry] = [] for component_data in components: + raw_metadata = ( + dict(component_data.get("metadata", {})) + if isinstance(component_data.get("metadata"), dict) + else {} + ) + chat_scope = str(component_data.get("chat_scope", raw_metadata.pop("chat_scope", "all")) or "all") prepared_components.append( self._build_component_entry( name=str(component_data.get("name", "") or ""), component_type=str(component_data.get("component_type", "") or ""), plugin_id=plugin_id, - metadata=component_data.get("metadata", {}) - if isinstance(component_data.get("metadata"), dict) - else {}, + metadata=raw_metadata, + chat_scope=chat_scope, ) ) @@ -733,9 +811,19 @@ class ComponentRegistry: return len(comps) # ====== 启用 / 禁用 ====== - def check_component_enabled(self, component: ComponentEntry, session_id: Optional[str] = None): + def check_component_enabled( + self, + component: ComponentEntry, + session_id: Optional[str] = None, + is_group_chat: Optional[bool] = None, + ): if session_id and session_id in component.disabled_session: return False + if is_group_chat is not None: + if component.chat_scope == "group" and is_group_chat is not True: + return False + if component.chat_scope == "private" and is_group_chat is not False: + return False return component.enabled def toggle_component_status(self, full_name: str, enabled: bool, session_id: Optional[str] = None) -> bool: @@ -806,7 +894,12 @@ class ComponentRegistry: return self._components.get(full_name) def get_components_by_type( - self, component_type: str, *, enabled_only: bool = True, session_id: Optional[str] = None + self, + component_type: str, + *, + enabled_only: bool = True, + session_id: Optional[str] = None, + is_group_chat: Optional[bool] = None, ) -> List[ComponentEntry]: """按类型查询组件 @@ -830,12 +923,16 @@ class ComponentRegistry: if self._is_legacy_action_component(component) ] if enabled_only: - return [component for component in action_components if self.check_component_enabled(component, session_id)] + return [ + component + for component in action_components + if self.check_component_enabled(component, session_id, is_group_chat) + ] return action_components type_dict = self._by_type.get(comp_type, {}) if enabled_only: - return [c for c in type_dict.values() if self.check_component_enabled(c, session_id)] + return [c for c in type_dict.values() if self.check_component_enabled(c, session_id, is_group_chat)] return list(type_dict.values()) def get_components_by_plugin( diff --git a/src/plugin_runtime/tool_provider.py b/src/plugin_runtime/tool_provider.py index 84bed06e..6ad48b79 100644 --- a/src/plugin_runtime/tool_provider.py +++ b/src/plugin_runtime/tool_provider.py @@ -4,7 +4,14 @@ from __future__ import annotations from typing import Optional -from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolProvider, ToolSpec +from src.core.tooling import ( + ToolAvailabilityContext, + ToolExecutionContext, + ToolExecutionResult, + ToolInvocation, + ToolProvider, + ToolSpec, +) from .component_query import component_query_service @@ -15,10 +22,13 @@ class PluginToolProvider(ToolProvider): provider_name = "plugin_runtime" provider_type = "plugin" - async def list_tools(self) -> list[ToolSpec]: + async def list_tools( + self, + context: Optional[ToolAvailabilityContext] = None, + ) -> list[ToolSpec]: """列出插件运行时当前可用的工具声明。""" - return list(component_query_service.get_llm_available_tool_specs().values()) + return list(component_query_service.get_llm_available_tool_specs(context=context).values()) async def invoke( self,