feat:修复门控多重result问题,新增at动作,插件现在运行chat_id指定或chat_type指定
This commit is contained in:
176
pytests/test_maisaka_builtin_at.py
Normal file
176
pytests/test_maisaka_builtin_at.py
Normal file
@@ -0,0 +1,176 @@
|
||||
from importlib import util
|
||||
from pathlib import Path
|
||||
from types import ModuleType, SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
from src.common.data_models.message_component_data_model import AtComponent, TextComponent
|
||||
from src.core.tooling import ToolExecutionResult, ToolInvocation
|
||||
|
||||
_MISSING_MODULE = object()
|
||||
_module_overrides: dict[str, object] = {}
|
||||
|
||||
|
||||
def _override_module(module_name: str, module: ModuleType) -> None:
|
||||
_module_overrides[module_name] = sys.modules.get(module_name, _MISSING_MODULE)
|
||||
sys.modules[module_name] = module
|
||||
|
||||
|
||||
def _restore_overridden_modules() -> None:
|
||||
for module_name, previous_module in reversed(_module_overrides.items()):
|
||||
if previous_module is _MISSING_MODULE:
|
||||
sys.modules.pop(module_name, None)
|
||||
else:
|
||||
sys.modules[module_name] = previous_module
|
||||
_module_overrides.clear()
|
||||
|
||||
|
||||
fake_cli_sender_module = ModuleType("src.cli.maisaka_cli_sender")
|
||||
fake_cli_sender_module.CLI_PLATFORM_NAME = "cli"
|
||||
fake_cli_sender_module.render_cli_message = lambda text: text
|
||||
fake_cli_module = ModuleType("src.cli")
|
||||
fake_cli_module.maisaka_cli_sender = fake_cli_sender_module
|
||||
|
||||
fake_send_service_module = ModuleType("src.services.send_service")
|
||||
fake_send_service_module._send_to_target_with_message = None
|
||||
fake_services_module = ModuleType("src.services")
|
||||
fake_services_module.send_service = fake_send_service_module
|
||||
|
||||
_override_module("src.cli", fake_cli_module)
|
||||
_override_module("src.cli.maisaka_cli_sender", fake_cli_sender_module)
|
||||
_override_module("src.services", fake_services_module)
|
||||
_override_module("src.services.send_service", fake_send_service_module)
|
||||
|
||||
AT_TOOL_PATH = Path(__file__).resolve().parents[1] / "src" / "maisaka" / "builtin_tool" / "at.py"
|
||||
at_tool_spec = util.spec_from_file_location("_test_maisaka_builtin_at_tool", AT_TOOL_PATH)
|
||||
assert at_tool_spec is not None and at_tool_spec.loader is not None
|
||||
at_tool = util.module_from_spec(at_tool_spec)
|
||||
sys.modules["_test_maisaka_builtin_at_tool"] = at_tool
|
||||
try:
|
||||
at_tool_spec.loader.exec_module(at_tool)
|
||||
finally:
|
||||
_restore_overridden_modules()
|
||||
|
||||
|
||||
class _ToolCtx:
|
||||
def __init__(self, runtime: SimpleNamespace) -> None:
|
||||
self.runtime = runtime
|
||||
|
||||
@staticmethod
|
||||
def build_success_result(
|
||||
tool_name: str,
|
||||
content: str = "",
|
||||
structured_content: Any = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> ToolExecutionResult:
|
||||
return ToolExecutionResult(
|
||||
tool_name=tool_name,
|
||||
success=True,
|
||||
content=content,
|
||||
structured_content=structured_content,
|
||||
metadata=dict(metadata or {}),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def build_failure_result(
|
||||
tool_name: str,
|
||||
error_message: str,
|
||||
structured_content: Any = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> ToolExecutionResult:
|
||||
return ToolExecutionResult(
|
||||
tool_name=tool_name,
|
||||
success=False,
|
||||
error_message=error_message,
|
||||
structured_content=structured_content,
|
||||
metadata=dict(metadata or {}),
|
||||
)
|
||||
|
||||
def append_guided_reply_to_chat_history(self, reply_text: str) -> None:
|
||||
self.runtime._chat_history.append(reply_text)
|
||||
|
||||
|
||||
def _build_tool_ctx(*, group_id: str = "group-1") -> _ToolCtx:
|
||||
target_message = SimpleNamespace(
|
||||
message_info=SimpleNamespace(
|
||||
user_info=SimpleNamespace(
|
||||
user_id="target-user-1",
|
||||
user_nickname="目标昵称",
|
||||
user_cardname="群名片",
|
||||
)
|
||||
)
|
||||
)
|
||||
runtime = SimpleNamespace(
|
||||
_source_messages_by_id={"msg-1": target_message},
|
||||
chat_stream=SimpleNamespace(platform="qq", group_id=group_id),
|
||||
session_id="session-1",
|
||||
log_prefix="[test-at]",
|
||||
_record_reply_sent=lambda: None,
|
||||
_chat_history=[],
|
||||
)
|
||||
return _ToolCtx(runtime=runtime)
|
||||
|
||||
|
||||
def test_at_tool_spec_does_not_embed_visibility_metadata() -> None:
|
||||
tool_spec = at_tool.get_tool_spec()
|
||||
|
||||
assert tool_spec.name == "at"
|
||||
assert "deferred" not in tool_spec.metadata
|
||||
assert "visibility" not in tool_spec.metadata
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_at_tool_sends_at_component_by_msg_id(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
captured: dict[str, Any] = {}
|
||||
|
||||
async def fake_send_to_target_with_message(**kwargs: Any) -> object:
|
||||
captured.update(kwargs)
|
||||
return SimpleNamespace(message_id="sent-msg-1")
|
||||
|
||||
monkeypatch.setattr(at_tool.send_service, "_send_to_target_with_message", fake_send_to_target_with_message)
|
||||
|
||||
result = await at_tool.handle_tool(
|
||||
_build_tool_ctx(),
|
||||
ToolInvocation(tool_name="at", arguments={"msg_id": "msg-1", "text": "看这里"}),
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.structured_content["target_user_id"] == "target-user-1"
|
||||
assert result.structured_content["target_user_name"] == "群名片"
|
||||
assert captured["stream_id"] == "session-1"
|
||||
assert captured["display_message"] == "@群名片 看这里"
|
||||
assert captured["sync_to_maisaka_history"] is True
|
||||
assert captured["maisaka_source_kind"] == "guided_reply"
|
||||
|
||||
components = captured["message_sequence"].components
|
||||
assert isinstance(components[0], AtComponent)
|
||||
assert components[0].target_user_id == "target-user-1"
|
||||
assert components[0].target_user_nickname == "目标昵称"
|
||||
assert components[0].target_user_cardname == "群名片"
|
||||
assert isinstance(components[1], TextComponent)
|
||||
assert components[1].text == " 看这里"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_at_tool_rejects_private_chat() -> None:
|
||||
result = await at_tool.handle_tool(
|
||||
_build_tool_ctx(group_id=""),
|
||||
ToolInvocation(tool_name="at", arguments={"msg_id": "msg-1"}),
|
||||
)
|
||||
|
||||
assert result.success is False
|
||||
assert "群聊" in result.error_message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_at_tool_rejects_unknown_msg_id() -> None:
|
||||
result = await at_tool.handle_tool(
|
||||
_build_tool_ctx(),
|
||||
ToolInvocation(tool_name="at", arguments={"msg_id": "missing-msg"}),
|
||||
)
|
||||
|
||||
assert result.success is False
|
||||
assert result.structured_content == {"msg_id": "missing-msg"}
|
||||
96
pytests/test_tool_availability.py
Normal file
96
pytests/test_tool_availability.py
Normal file
@@ -0,0 +1,96 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from src.core.tooling import ToolAvailabilityContext, ToolRegistry
|
||||
from src.maisaka.tool_provider import MaisakaBuiltinToolProvider
|
||||
from src.plugin_runtime.component_query import ComponentQueryService
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_builtin_at_is_exposed_only_in_group_chats() -> None:
|
||||
registry = ToolRegistry()
|
||||
registry.register_provider(MaisakaBuiltinToolProvider())
|
||||
|
||||
group_specs = await registry.list_tools(ToolAvailabilityContext(session_id="group-1", is_group_chat=True))
|
||||
private_specs = await registry.list_tools(ToolAvailabilityContext(session_id="private-1", is_group_chat=False))
|
||||
default_specs = await registry.list_tools()
|
||||
|
||||
assert "at" in {tool_spec.name for tool_spec in group_specs}
|
||||
assert "at" not in {tool_spec.name for tool_spec in private_specs}
|
||||
assert "at" in {tool_spec.name for tool_spec in default_specs}
|
||||
|
||||
|
||||
def test_plugin_tool_chat_scope_uses_component_field(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
service = ComponentQueryService()
|
||||
registry = ComponentRegistry()
|
||||
supervisor = SimpleNamespace(component_registry=registry)
|
||||
monkeypatch.setattr(service, "_iter_supervisors", lambda: [supervisor])
|
||||
|
||||
registry.register_plugin_components(
|
||||
"scope_plugin",
|
||||
[
|
||||
{
|
||||
"name": "group_tool",
|
||||
"component_type": "TOOL",
|
||||
"chat_scope": "group",
|
||||
"metadata": {"description": "group only"},
|
||||
},
|
||||
{
|
||||
"name": "private_tool",
|
||||
"component_type": "TOOL",
|
||||
"chat_scope": "private",
|
||||
"metadata": {"description": "private only"},
|
||||
},
|
||||
{
|
||||
"name": "all_tool",
|
||||
"component_type": "TOOL",
|
||||
"metadata": {"description": "all chats"},
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
group_specs = service.get_llm_available_tool_specs(
|
||||
context=ToolAvailabilityContext(session_id="group-1", is_group_chat=True)
|
||||
)
|
||||
private_specs = service.get_llm_available_tool_specs(
|
||||
context=ToolAvailabilityContext(session_id="private-1", is_group_chat=False)
|
||||
)
|
||||
|
||||
group_entry = registry.get_component("scope_plugin.group_tool")
|
||||
assert group_entry is not None
|
||||
assert group_entry.chat_scope == "group"
|
||||
assert "chat_scope" not in group_entry.metadata
|
||||
assert set(group_specs) == {"group_tool", "all_tool"}
|
||||
assert set(private_specs) == {"private_tool", "all_tool"}
|
||||
|
||||
|
||||
def test_plugin_tool_session_disable_still_filters_specific_chat(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
service = ComponentQueryService()
|
||||
registry = ComponentRegistry()
|
||||
supervisor = SimpleNamespace(component_registry=registry)
|
||||
monkeypatch.setattr(service, "_iter_supervisors", lambda: [supervisor])
|
||||
|
||||
registry.register_plugin_components(
|
||||
"mute_plugin",
|
||||
[
|
||||
{
|
||||
"name": "mute",
|
||||
"component_type": "TOOL",
|
||||
"chat_scope": "group",
|
||||
"metadata": {"description": "mute group member"},
|
||||
}
|
||||
],
|
||||
)
|
||||
registry.set_component_enabled("mute_plugin.mute", False, session_id="group-disabled")
|
||||
|
||||
disabled_specs = service.get_llm_available_tool_specs(
|
||||
context=ToolAvailabilityContext(session_id="group-disabled", is_group_chat=True)
|
||||
)
|
||||
enabled_specs = service.get_llm_available_tool_specs(
|
||||
context=ToolAvailabilityContext(session_id="group-enabled", is_group_chat=True)
|
||||
)
|
||||
|
||||
assert "mute" not in disabled_specs
|
||||
assert "mute" in enabled_specs
|
||||
@@ -222,6 +222,18 @@ class ToolExecutionContext:
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ToolAvailabilityContext:
|
||||
"""工具暴露可用性判断上下文。"""
|
||||
|
||||
session_id: str = ""
|
||||
stream_id: str = ""
|
||||
is_group_chat: bool | None = None
|
||||
group_id: str = ""
|
||||
user_id: str = ""
|
||||
platform: str = ""
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ToolExecutionResult:
|
||||
"""统一工具执行结果。"""
|
||||
@@ -264,7 +276,10 @@ class ToolProvider(Protocol):
|
||||
provider_name: str
|
||||
provider_type: str
|
||||
|
||||
async def list_tools(self) -> list[ToolSpec]:
|
||||
async def list_tools(
|
||||
self,
|
||||
context: Optional[ToolAvailabilityContext] = None,
|
||||
) -> list[ToolSpec]:
|
||||
"""列出当前 Provider 暴露的全部工具。"""
|
||||
...
|
||||
|
||||
@@ -308,7 +323,10 @@ class ToolRegistry:
|
||||
|
||||
self._providers = [item for item in self._providers if item.provider_name != provider_name]
|
||||
|
||||
async def list_tools(self) -> list[ToolSpec]:
|
||||
async def list_tools(
|
||||
self,
|
||||
context: Optional[ToolAvailabilityContext] = None,
|
||||
) -> list[ToolSpec]:
|
||||
"""按 Provider 顺序列出全部去重后的工具。
|
||||
|
||||
Returns:
|
||||
@@ -319,7 +337,7 @@ class ToolRegistry:
|
||||
seen_names: set[str] = set()
|
||||
|
||||
for provider in self._providers:
|
||||
provider_specs = await provider.list_tools()
|
||||
provider_specs = await provider.list_tools(context)
|
||||
for spec in provider_specs:
|
||||
if not spec.enabled:
|
||||
continue
|
||||
@@ -332,7 +350,11 @@ class ToolRegistry:
|
||||
collected_specs.append(spec)
|
||||
return collected_specs
|
||||
|
||||
async def get_tool_spec(self, tool_name: str) -> Optional[ToolSpec]:
|
||||
async def get_tool_spec(
|
||||
self,
|
||||
tool_name: str,
|
||||
context: Optional[ToolAvailabilityContext] = None,
|
||||
) -> Optional[ToolSpec]:
|
||||
"""查询指定工具声明。
|
||||
|
||||
Args:
|
||||
@@ -342,12 +364,16 @@ class ToolRegistry:
|
||||
Optional[ToolSpec]: 匹配到的工具声明。
|
||||
"""
|
||||
|
||||
for spec in await self.list_tools():
|
||||
for spec in await self.list_tools(context):
|
||||
if spec.name == tool_name:
|
||||
return spec
|
||||
return None
|
||||
|
||||
async def has_tool(self, tool_name: str) -> bool:
|
||||
async def has_tool(
|
||||
self,
|
||||
tool_name: str,
|
||||
context: Optional[ToolAvailabilityContext] = None,
|
||||
) -> bool:
|
||||
"""判断指定工具是否存在。
|
||||
|
||||
Args:
|
||||
@@ -357,16 +383,19 @@ class ToolRegistry:
|
||||
bool: 是否存在。
|
||||
"""
|
||||
|
||||
return await self.get_tool_spec(tool_name) is not None
|
||||
return await self.get_tool_spec(tool_name, context) is not None
|
||||
|
||||
async def get_llm_definitions(self) -> list[ToolDefinitionInput]:
|
||||
async def get_llm_definitions(
|
||||
self,
|
||||
context: Optional[ToolAvailabilityContext] = None,
|
||||
) -> list[ToolDefinitionInput]:
|
||||
"""获取供 LLM 使用的工具定义列表。
|
||||
|
||||
Returns:
|
||||
list[ToolDefinitionInput]: 统一工具定义列表。
|
||||
"""
|
||||
|
||||
return [spec.to_llm_definition() for spec in await self.list_tools()]
|
||||
return [spec.to_llm_definition() for spec in await self.list_tools(context)]
|
||||
|
||||
async def invoke(
|
||||
self,
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
"""Maisaka 内置工具聚合入口。"""
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Dict, List, Optional
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Literal, Optional
|
||||
|
||||
from src.config.config import global_config
|
||||
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec
|
||||
from src.core.tooling import ToolAvailabilityContext, ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec
|
||||
from src.llm_models.payload_content.tool_option import ToolDefinitionInput
|
||||
|
||||
from .at import get_tool_spec as get_at_tool_spec
|
||||
from .at import handle_tool as handle_at_tool
|
||||
from .context import BuiltinToolRuntimeContext
|
||||
from .continue_tool import get_tool_spec as get_continue_tool_spec
|
||||
from .continue_tool import handle_tool as handle_continue_tool
|
||||
@@ -32,93 +36,152 @@ from .wait import get_tool_spec as get_wait_tool_spec
|
||||
from .wait import handle_tool as handle_wait_tool
|
||||
|
||||
BuiltinToolHandler = Callable[[ToolInvocation, Optional[ToolExecutionContext]], Awaitable[ToolExecutionResult]]
|
||||
BuiltinToolRawHandler = Callable[
|
||||
[BuiltinToolRuntimeContext, ToolInvocation, Optional[ToolExecutionContext]],
|
||||
Awaitable[ToolExecutionResult],
|
||||
]
|
||||
BuiltinToolStage = Literal["timing", "action"]
|
||||
BuiltinToolVisibility = Literal["visible", "deferred", "hidden"]
|
||||
BuiltinToolChatScope = Literal["all", "group", "private"]
|
||||
|
||||
|
||||
def get_timing_tool_specs() -> List[ToolSpec]:
|
||||
"""获取 Timing Gate 阶段可用的内置工具声明。"""
|
||||
@dataclass(frozen=True)
|
||||
class BuiltinToolEntry:
|
||||
"""内置工具目录项,集中声明工具所属阶段与默认可见性。"""
|
||||
|
||||
return [
|
||||
get_wait_tool_spec(),
|
||||
get_no_reply_tool_spec(),
|
||||
get_continue_tool_spec(),
|
||||
]
|
||||
name: str
|
||||
get_spec: Callable[[], ToolSpec]
|
||||
handle_tool: BuiltinToolRawHandler
|
||||
stage: BuiltinToolStage
|
||||
visibility: BuiltinToolVisibility = "visible"
|
||||
chat_scope: BuiltinToolChatScope = "all"
|
||||
|
||||
def build_spec(self) -> ToolSpec:
|
||||
"""生成带统一可见性元数据的工具声明。"""
|
||||
|
||||
tool_spec = deepcopy(self.get_spec())
|
||||
tool_spec.metadata["builtin_stage"] = self.stage
|
||||
tool_spec.metadata["visibility"] = self.visibility
|
||||
return tool_spec
|
||||
|
||||
|
||||
def get_action_tool_specs() -> List[ToolSpec]:
|
||||
"""获取 Action Loop 阶段可用的内置工具声明。"""
|
||||
def _get_query_memory_tool_spec() -> ToolSpec:
|
||||
"""根据配置生成 query_memory 工具声明。"""
|
||||
|
||||
return [
|
||||
get_finish_tool_spec(),
|
||||
get_reply_tool_spec(),
|
||||
get_view_complex_message_tool_spec(),
|
||||
get_query_jargon_tool_spec(),
|
||||
get_query_memory_tool_spec(enabled=bool(global_config.memory.enable_memory_query_tool)),
|
||||
get_send_emoji_tool_spec(),
|
||||
get_tool_search_tool_spec(),
|
||||
]
|
||||
return get_query_memory_tool_spec(enabled=bool(global_config.memory.enable_memory_query_tool))
|
||||
|
||||
|
||||
def get_builtin_tool_specs() -> List[ToolSpec]:
|
||||
"""获取默认暴露的 Maisaka 内置工具声明。"""
|
||||
|
||||
return get_action_tool_specs()
|
||||
BUILTIN_TOOL_ENTRIES: List[BuiltinToolEntry] = [
|
||||
BuiltinToolEntry("wait", get_wait_tool_spec, handle_wait_tool, stage="timing"),
|
||||
BuiltinToolEntry("no_reply", get_no_reply_tool_spec, handle_no_reply_tool, stage="timing"),
|
||||
BuiltinToolEntry("continue", get_continue_tool_spec, handle_continue_tool, stage="timing"),
|
||||
BuiltinToolEntry("finish", get_finish_tool_spec, handle_finish_tool, stage="action"),
|
||||
BuiltinToolEntry("reply", get_reply_tool_spec, handle_reply_tool, stage="action"),
|
||||
BuiltinToolEntry(
|
||||
"view_complex_message",
|
||||
get_view_complex_message_tool_spec,
|
||||
handle_view_complex_message_tool,
|
||||
stage="action",
|
||||
),
|
||||
BuiltinToolEntry("query_jargon", get_query_jargon_tool_spec, handle_query_jargon_tool, stage="action"),
|
||||
BuiltinToolEntry("query_memory", _get_query_memory_tool_spec, handle_query_memory_tool, stage="action"),
|
||||
BuiltinToolEntry(
|
||||
"query_person_info",
|
||||
get_query_person_info_tool_spec,
|
||||
handle_query_person_info_tool,
|
||||
stage="action",
|
||||
visibility="hidden",
|
||||
),
|
||||
BuiltinToolEntry("send_emoji", get_send_emoji_tool_spec, handle_send_emoji_tool, stage="action"),
|
||||
BuiltinToolEntry(
|
||||
"at",
|
||||
get_at_tool_spec,
|
||||
handle_at_tool,
|
||||
stage="action",
|
||||
visibility="deferred",
|
||||
chat_scope="group",
|
||||
),
|
||||
BuiltinToolEntry("tool_search", get_tool_search_tool_spec, handle_tool_search_tool, stage="action"),
|
||||
]
|
||||
|
||||
|
||||
def get_all_builtin_tool_specs() -> List[ToolSpec]:
|
||||
def _get_builtin_tool_entries(
|
||||
*,
|
||||
stage: Optional[BuiltinToolStage] = None,
|
||||
visibility: Optional[BuiltinToolVisibility] = None,
|
||||
context: Optional[ToolAvailabilityContext] = None,
|
||||
) -> List[BuiltinToolEntry]:
|
||||
"""按阶段与可见性筛选内置工具目录项。"""
|
||||
|
||||
entries = BUILTIN_TOOL_ENTRIES
|
||||
if stage is not None:
|
||||
entries = [entry for entry in entries if entry.stage == stage]
|
||||
if visibility is not None:
|
||||
entries = [entry for entry in entries if entry.visibility == visibility]
|
||||
if context is not None:
|
||||
entries = [entry for entry in entries if _is_builtin_tool_available(entry, context)]
|
||||
return entries
|
||||
|
||||
|
||||
def _is_builtin_tool_available(entry: BuiltinToolEntry, context: ToolAvailabilityContext) -> bool:
|
||||
"""判断内置工具是否适用于当前聊天。"""
|
||||
|
||||
if entry.chat_scope == "all":
|
||||
return True
|
||||
if entry.chat_scope == "group":
|
||||
return context.is_group_chat is True
|
||||
if entry.chat_scope == "private":
|
||||
return context.is_group_chat is False
|
||||
return True
|
||||
|
||||
|
||||
def get_builtin_tool_visibility(tool_spec: ToolSpec) -> BuiltinToolVisibility:
|
||||
"""读取工具声明里的可见性。"""
|
||||
|
||||
raw_visibility = str(tool_spec.metadata.get("visibility") or "").strip()
|
||||
if raw_visibility == "deferred":
|
||||
return "deferred"
|
||||
if raw_visibility == "hidden":
|
||||
return "hidden"
|
||||
return "visible"
|
||||
|
||||
|
||||
def is_builtin_tool_in_action_stage(tool_spec: ToolSpec) -> bool:
|
||||
"""判断内置工具是否属于 Action Loop 阶段。"""
|
||||
|
||||
return str(tool_spec.metadata.get("builtin_stage") or "").strip() == "action"
|
||||
|
||||
|
||||
def get_all_builtin_tool_specs(context: Optional[ToolAvailabilityContext] = None) -> List[ToolSpec]:
|
||||
"""获取全部内置工具声明。"""
|
||||
|
||||
return [
|
||||
*get_timing_tool_specs(),
|
||||
get_finish_tool_spec(),
|
||||
get_reply_tool_spec(),
|
||||
get_view_complex_message_tool_spec(),
|
||||
get_query_jargon_tool_spec(),
|
||||
get_query_memory_tool_spec(enabled=True),
|
||||
get_query_person_info_tool_spec(),
|
||||
get_send_emoji_tool_spec(),
|
||||
get_tool_search_tool_spec(),
|
||||
]
|
||||
return [entry.build_spec() for entry in _get_builtin_tool_entries(context=context)]
|
||||
|
||||
|
||||
def get_timing_tools() -> List[ToolDefinitionInput]:
|
||||
"""获取 Timing Gate 阶段的兼容工具定义。"""
|
||||
|
||||
return [tool_spec.to_llm_definition() for tool_spec in get_timing_tool_specs()]
|
||||
|
||||
|
||||
def get_action_tools() -> List[ToolDefinitionInput]:
|
||||
"""获取 Action Loop 阶段的兼容工具定义。"""
|
||||
|
||||
return [tool_spec.to_llm_definition() for tool_spec in get_action_tool_specs()]
|
||||
tool_specs = [
|
||||
entry.build_spec()
|
||||
for entry in _get_builtin_tool_entries(stage="timing", visibility="visible")
|
||||
]
|
||||
return [tool_spec.to_llm_definition() for tool_spec in tool_specs if tool_spec.enabled]
|
||||
|
||||
|
||||
def get_builtin_tools() -> List[ToolDefinitionInput]:
|
||||
"""获取默认暴露给模型层的内置工具定义。"""
|
||||
|
||||
return get_action_tools()
|
||||
tool_specs = [
|
||||
entry.build_spec()
|
||||
for entry in _get_builtin_tool_entries(stage="action", visibility="visible")
|
||||
]
|
||||
return [tool_spec.to_llm_definition() for tool_spec in tool_specs if tool_spec.enabled]
|
||||
|
||||
|
||||
def build_builtin_tool_handlers(tool_ctx: BuiltinToolRuntimeContext) -> Dict[str, BuiltinToolHandler]:
|
||||
"""构建内置工具处理器映射。"""
|
||||
|
||||
return {
|
||||
"continue": lambda invocation, context=None: handle_continue_tool(tool_ctx, invocation, context),
|
||||
"finish": lambda invocation, context=None: handle_finish_tool(tool_ctx, invocation, context),
|
||||
"reply": lambda invocation, context=None: handle_reply_tool(tool_ctx, invocation, context),
|
||||
"no_reply": lambda invocation, context=None: handle_no_reply_tool(tool_ctx, invocation, context),
|
||||
"query_jargon": lambda invocation, context=None: handle_query_jargon_tool(tool_ctx, invocation, context),
|
||||
"query_memory": lambda invocation, context=None: handle_query_memory_tool(tool_ctx, invocation, context),
|
||||
"query_person_info": lambda invocation, context=None: handle_query_person_info_tool(
|
||||
tool_ctx,
|
||||
invocation,
|
||||
context,
|
||||
),
|
||||
"wait": lambda invocation, context=None: handle_wait_tool(tool_ctx, invocation, context),
|
||||
"send_emoji": lambda invocation, context=None: handle_send_emoji_tool(tool_ctx, invocation, context),
|
||||
"tool_search": lambda invocation, context=None: handle_tool_search_tool(tool_ctx, invocation, context),
|
||||
"view_complex_message": lambda invocation, context=None: handle_view_complex_message_tool(
|
||||
tool_ctx,
|
||||
invocation,
|
||||
context,
|
||||
),
|
||||
entry.name: lambda invocation, context=None, entry=entry: entry.handle_tool(tool_ctx, invocation, context)
|
||||
for entry in BUILTIN_TOOL_ENTRIES
|
||||
}
|
||||
|
||||
186
src/maisaka/builtin_tool/at.py
Normal file
186
src/maisaka/builtin_tool/at.py
Normal file
@@ -0,0 +1,186 @@
|
||||
"""Maisaka 内置 at 工具。"""
|
||||
|
||||
from typing import Any, Optional, TYPE_CHECKING
|
||||
|
||||
from src.cli.maisaka_cli_sender import CLI_PLATFORM_NAME, render_cli_message
|
||||
from src.common.data_models.message_component_data_model import AtComponent, MessageSequence, TextComponent
|
||||
from src.common.logger import get_logger
|
||||
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec
|
||||
from src.services import send_service
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .context import BuiltinToolRuntimeContext
|
||||
|
||||
logger = get_logger("maisaka_builtin_at")
|
||||
|
||||
|
||||
def get_tool_spec() -> ToolSpec:
|
||||
"""获取 at 工具声明。"""
|
||||
|
||||
return ToolSpec(
|
||||
name="at",
|
||||
brief_description="根据一条已知 msg_id 找到发言用户,并发送一条 @ 该用户的消息。",
|
||||
detailed_description=(
|
||||
"参数说明:\n"
|
||||
"- msg_id:string,必填。要 @ 的目标用户发过的消息编号。\n"
|
||||
"- text:string,可选。@ 后追加发送的短文本;只想单独 @ 人时留空。\n"
|
||||
"请优先从上下文里选择一条明确属于目标用户的 msg_id,不要凭昵称或印象猜测用户。"
|
||||
),
|
||||
parameters_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"msg_id": {
|
||||
"type": "string",
|
||||
"description": "要 @ 的目标用户发过的消息编号。",
|
||||
},
|
||||
"text": {
|
||||
"type": "string",
|
||||
"description": "@ 后追加发送的短文本;只想单独 @ 人时留空。",
|
||||
"default": "",
|
||||
},
|
||||
},
|
||||
"required": ["msg_id"],
|
||||
},
|
||||
provider_name="maisaka_builtin",
|
||||
provider_type="builtin",
|
||||
)
|
||||
|
||||
|
||||
def _get_target_user_info(target_message: Any) -> tuple[str, str, str]:
|
||||
"""从目标消息中提取可用于构造 at 组件的用户信息。"""
|
||||
|
||||
message_info = getattr(target_message, "message_info", None)
|
||||
user_info = getattr(message_info, "user_info", None)
|
||||
target_user_id = str(getattr(user_info, "user_id", "") or "").strip()
|
||||
target_user_nickname = str(getattr(user_info, "user_nickname", "") or "").strip()
|
||||
target_user_cardname = str(getattr(user_info, "user_cardname", "") or "").strip()
|
||||
return target_user_id, target_user_nickname, target_user_cardname
|
||||
|
||||
|
||||
def _build_at_message_sequence(
|
||||
*,
|
||||
target_user_id: str,
|
||||
target_user_nickname: str = "",
|
||||
target_user_cardname: str = "",
|
||||
text: str = "",
|
||||
) -> MessageSequence:
|
||||
"""构造 @ 用户的消息组件序列。"""
|
||||
|
||||
components = [
|
||||
AtComponent(
|
||||
target_user_id=target_user_id,
|
||||
target_user_nickname=target_user_nickname or None,
|
||||
target_user_cardname=target_user_cardname or None,
|
||||
)
|
||||
]
|
||||
normalized_text = text.strip()
|
||||
if normalized_text:
|
||||
components.append(TextComponent(f" {normalized_text}"))
|
||||
return MessageSequence(components=components)
|
||||
|
||||
|
||||
async def handle_tool(
|
||||
tool_ctx: "BuiltinToolRuntimeContext",
|
||||
invocation: ToolInvocation,
|
||||
context: Optional[ToolExecutionContext] = None,
|
||||
) -> ToolExecutionResult:
|
||||
"""执行 at 内置工具。"""
|
||||
|
||||
del context
|
||||
target_message_id = str(invocation.arguments.get("msg_id") or "").strip()
|
||||
text = str(invocation.arguments.get("text") or "").strip()
|
||||
|
||||
if not target_message_id:
|
||||
return tool_ctx.build_failure_result(
|
||||
invocation.tool_name,
|
||||
"at 工具需要提供有效的 `msg_id` 参数。",
|
||||
)
|
||||
|
||||
if not str(getattr(tool_ctx.runtime.chat_stream, "group_id", "") or "").strip():
|
||||
return tool_ctx.build_failure_result(
|
||||
invocation.tool_name,
|
||||
"at 工具只能在群聊中使用。",
|
||||
structured_content={"msg_id": target_message_id},
|
||||
)
|
||||
|
||||
target_message = tool_ctx.runtime._source_messages_by_id.get(target_message_id)
|
||||
if target_message is None:
|
||||
return tool_ctx.build_failure_result(
|
||||
invocation.tool_name,
|
||||
f"未找到要 @ 的目标消息,msg_id={target_message_id}",
|
||||
structured_content={"msg_id": target_message_id},
|
||||
)
|
||||
|
||||
target_user_id, target_user_nickname, target_user_cardname = _get_target_user_info(target_message)
|
||||
if not target_user_id:
|
||||
return tool_ctx.build_failure_result(
|
||||
invocation.tool_name,
|
||||
f"目标消息缺少有效用户 ID,msg_id={target_message_id}",
|
||||
structured_content={"msg_id": target_message_id},
|
||||
)
|
||||
|
||||
target_user_name = target_user_cardname or target_user_nickname or target_user_id
|
||||
message_sequence = _build_at_message_sequence(
|
||||
target_user_id=target_user_id,
|
||||
target_user_nickname=target_user_nickname,
|
||||
target_user_cardname=target_user_cardname,
|
||||
text=text,
|
||||
)
|
||||
display_message = f"@{target_user_name}" + (f" {text}" if text else "")
|
||||
|
||||
try:
|
||||
if tool_ctx.runtime.chat_stream.platform == CLI_PLATFORM_NAME:
|
||||
render_cli_message(display_message)
|
||||
tool_ctx.append_guided_reply_to_chat_history(display_message)
|
||||
sent_message = None
|
||||
sent = True
|
||||
else:
|
||||
sent_message = await send_service._send_to_target_with_message(
|
||||
message_sequence=message_sequence,
|
||||
stream_id=tool_ctx.runtime.session_id,
|
||||
display_message=display_message,
|
||||
typing=False,
|
||||
storage_message=True,
|
||||
show_log=True,
|
||||
sync_to_maisaka_history=True,
|
||||
maisaka_source_kind="guided_reply",
|
||||
)
|
||||
sent = sent_message is not None
|
||||
except Exception as exc:
|
||||
logger.exception(
|
||||
f"{tool_ctx.runtime.log_prefix} 发送 at 消息时发生异常: msg_id={target_message_id} user_id={target_user_id}"
|
||||
)
|
||||
return tool_ctx.build_failure_result(
|
||||
invocation.tool_name,
|
||||
f"发送 at 消息时发生异常:{exc}",
|
||||
structured_content={
|
||||
"msg_id": target_message_id,
|
||||
"target_user_id": target_user_id,
|
||||
"target_user_name": target_user_name,
|
||||
},
|
||||
)
|
||||
|
||||
if not sent:
|
||||
return tool_ctx.build_failure_result(
|
||||
invocation.tool_name,
|
||||
"at 消息发送失败。",
|
||||
structured_content={
|
||||
"msg_id": target_message_id,
|
||||
"target_user_id": target_user_id,
|
||||
"target_user_name": target_user_name,
|
||||
},
|
||||
)
|
||||
|
||||
sent_message_id = str(getattr(sent_message, "message_id", "") or "").strip() if sent_message is not None else ""
|
||||
tool_ctx.runtime._record_reply_sent()
|
||||
return tool_ctx.build_success_result(
|
||||
invocation.tool_name,
|
||||
f"已 @ {target_user_name}。",
|
||||
structured_content={
|
||||
"msg_id": target_message_id,
|
||||
"target_user_id": target_user_id,
|
||||
"target_user_name": target_user_name,
|
||||
"text": text,
|
||||
"sent_message_id": sent_message_id,
|
||||
},
|
||||
)
|
||||
@@ -12,7 +12,7 @@ from src.common.logger import get_logger
|
||||
from src.common.prompt_i18n import load_prompt
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
from src.config.config import global_config
|
||||
from src.core.tooling import ToolRegistry
|
||||
from src.core.tooling import ToolAvailabilityContext, ToolRegistry
|
||||
from src.llm_models.model_client.base_client import BaseClient
|
||||
from src.llm_models.payload_content.message import Message, MessageBuilder, RoleType
|
||||
from src.llm_models.payload_content.resp_format import RespFormat
|
||||
@@ -502,7 +502,13 @@ class MaisakaChatLoopService:
|
||||
if tool_definitions is not None:
|
||||
all_tools = list(tool_definitions)
|
||||
elif self._tool_registry is not None:
|
||||
tool_specs = await self._tool_registry.list_tools()
|
||||
tool_specs = await self._tool_registry.list_tools(
|
||||
ToolAvailabilityContext(
|
||||
session_id=self._session_id,
|
||||
stream_id=self._session_id,
|
||||
is_group_chat=self._is_group_chat,
|
||||
)
|
||||
)
|
||||
all_tools = [tool_spec.to_llm_definition() for tool_spec in tool_specs]
|
||||
else:
|
||||
all_tools = [*get_builtin_tools(), *self._extra_tools]
|
||||
|
||||
@@ -14,14 +14,14 @@ from src.chat.message_receive.message import SessionMessage
|
||||
from src.common.data_models.message_component_data_model import EmojiComponent, ImageComponent, MessageSequence
|
||||
from src.common.logger import get_logger
|
||||
from src.common.prompt_i18n import load_prompt
|
||||
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec
|
||||
from src.core.tooling import ToolAvailabilityContext, ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec
|
||||
from src.llm_models.exceptions import ReqAbortException
|
||||
from src.llm_models.payload_content.tool_option import ToolCall
|
||||
from src.services import database_service as database_api
|
||||
from src.services.memory_service import memory_service
|
||||
|
||||
from .builtin_tool import get_action_tool_specs
|
||||
from .builtin_tool import build_builtin_tool_handlers as build_split_builtin_tool_handlers
|
||||
from .builtin_tool import get_builtin_tool_visibility, is_builtin_tool_in_action_stage
|
||||
from .builtin_tool import get_timing_tools
|
||||
from .chat_loop_service import ChatResponse
|
||||
from .chat_history_visual_refresher import refresh_chat_history_visual_placeholders
|
||||
@@ -54,8 +54,6 @@ logger = get_logger("maisaka_reasoning_engine")
|
||||
TIMING_GATE_CONTEXT_LIMIT = 24
|
||||
TIMING_GATE_MAX_TOKENS = 384
|
||||
TIMING_GATE_TOOL_NAMES = {"continue", "no_reply", "wait"}
|
||||
ACTION_HIDDEN_TOOL_NAMES = {"continue", "no_reply"}
|
||||
ACTION_BUILTIN_TOOL_NAMES = {tool_spec.name for tool_spec in get_action_tool_specs()}
|
||||
|
||||
|
||||
class MaisakaReasoningEngine:
|
||||
@@ -175,15 +173,19 @@ class MaisakaReasoningEngine:
|
||||
self._runtime.set_current_action_tool_names([])
|
||||
return [], ""
|
||||
|
||||
tool_specs = await self._runtime._tool_registry.list_tools()
|
||||
availability_context = self._build_tool_availability_context()
|
||||
tool_specs = await self._runtime._tool_registry.list_tools(availability_context)
|
||||
visible_builtin_tool_specs: list[ToolSpec] = []
|
||||
deferred_tool_specs: list[ToolSpec] = []
|
||||
for tool_spec in tool_specs:
|
||||
if tool_spec.name in ACTION_HIDDEN_TOOL_NAMES:
|
||||
continue
|
||||
if tool_spec.provider_name == "maisaka_builtin":
|
||||
if tool_spec.name in ACTION_BUILTIN_TOOL_NAMES:
|
||||
if not is_builtin_tool_in_action_stage(tool_spec):
|
||||
continue
|
||||
visibility = get_builtin_tool_visibility(tool_spec)
|
||||
if visibility == "visible":
|
||||
visible_builtin_tool_specs.append(tool_spec)
|
||||
elif visibility == "deferred":
|
||||
deferred_tool_specs.append(tool_spec)
|
||||
continue
|
||||
deferred_tool_specs.append(tool_spec)
|
||||
|
||||
@@ -877,6 +879,19 @@ class MaisakaReasoningEngine:
|
||||
reasoning=latest_thought,
|
||||
)
|
||||
|
||||
def _build_tool_availability_context(self) -> ToolAvailabilityContext:
|
||||
"""构造当前聊天的工具暴露上下文。"""
|
||||
|
||||
chat_stream = self._runtime.chat_stream
|
||||
return ToolAvailabilityContext(
|
||||
session_id=self._runtime.session_id,
|
||||
stream_id=self._runtime.session_id,
|
||||
is_group_chat=chat_stream.is_group_session,
|
||||
group_id=str(getattr(chat_stream, "group_id", "") or "").strip(),
|
||||
user_id=str(getattr(chat_stream, "user_id", "") or "").strip(),
|
||||
platform=str(getattr(chat_stream, "platform", "") or "").strip(),
|
||||
)
|
||||
|
||||
def _build_tool_execution_context(
|
||||
self,
|
||||
latest_thought: str,
|
||||
@@ -1197,6 +1212,8 @@ class MaisakaReasoningEngine:
|
||||
source_kind="timing_gate",
|
||||
)
|
||||
)
|
||||
if tool_call.func_name == "wait":
|
||||
return
|
||||
self._append_tool_execution_result(tool_call, result)
|
||||
|
||||
def _build_tool_result_summary(self, tool_call: ToolCall, result: ToolExecutionResult) -> str:
|
||||
@@ -1290,9 +1307,10 @@ class MaisakaReasoningEngine:
|
||||
return False, tool_result_summaries, tool_monitor_results
|
||||
|
||||
execution_context = self._build_tool_execution_context(latest_thought, anchor_message)
|
||||
availability_context = self._build_tool_availability_context()
|
||||
tool_spec_map = {
|
||||
tool_spec.name: tool_spec
|
||||
for tool_spec in await self._runtime._tool_registry.list_tools()
|
||||
for tool_spec in await self._runtime._tool_registry.list_tools(availability_context)
|
||||
}
|
||||
total_tool_count = len(tool_calls)
|
||||
for tool_index, tool_call in enumerate(tool_calls, start=1):
|
||||
|
||||
@@ -5,7 +5,14 @@ from __future__ import annotations
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Dict, Optional
|
||||
|
||||
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolProvider, ToolSpec
|
||||
from src.core.tooling import (
|
||||
ToolAvailabilityContext,
|
||||
ToolExecutionContext,
|
||||
ToolExecutionResult,
|
||||
ToolInvocation,
|
||||
ToolProvider,
|
||||
ToolSpec,
|
||||
)
|
||||
|
||||
from .builtin_tool import get_all_builtin_tool_specs
|
||||
|
||||
@@ -27,10 +34,13 @@ class MaisakaBuiltinToolProvider(ToolProvider):
|
||||
|
||||
self._handlers = dict(handlers or {})
|
||||
|
||||
async def list_tools(self) -> list[ToolSpec]:
|
||||
async def list_tools(
|
||||
self,
|
||||
context: Optional[ToolAvailabilityContext] = None,
|
||||
) -> list[ToolSpec]:
|
||||
"""列出全部内置工具。"""
|
||||
|
||||
return list(get_all_builtin_tool_specs())
|
||||
return list(get_all_builtin_tool_specs(context))
|
||||
|
||||
async def invoke(
|
||||
self,
|
||||
|
||||
@@ -4,7 +4,14 @@ from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolProvider, ToolSpec
|
||||
from src.core.tooling import (
|
||||
ToolAvailabilityContext,
|
||||
ToolExecutionContext,
|
||||
ToolExecutionResult,
|
||||
ToolInvocation,
|
||||
ToolProvider,
|
||||
ToolSpec,
|
||||
)
|
||||
|
||||
from .manager import MCPManager
|
||||
|
||||
@@ -24,9 +31,13 @@ class MCPToolProvider(ToolProvider):
|
||||
|
||||
self._manager = manager
|
||||
|
||||
async def list_tools(self) -> list[ToolSpec]:
|
||||
async def list_tools(
|
||||
self,
|
||||
context: Optional[ToolAvailabilityContext] = None,
|
||||
) -> list[ToolSpec]:
|
||||
"""列出全部 MCP 工具。"""
|
||||
|
||||
del context
|
||||
return self._manager.get_tool_specs()
|
||||
|
||||
async def invoke(
|
||||
|
||||
@@ -11,6 +11,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional, Tupl
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.core.tooling import (
|
||||
ToolAvailabilityContext,
|
||||
ToolExecutionContext,
|
||||
ToolExecutionResult,
|
||||
ToolInvocation,
|
||||
@@ -72,6 +73,7 @@ class ComponentQueryService:
|
||||
component_type: ComponentType,
|
||||
*,
|
||||
enabled_only: bool = True,
|
||||
context: Optional[ToolAvailabilityContext] = None,
|
||||
) -> list[tuple["PluginSupervisor", "ComponentEntry"]]:
|
||||
"""遍历指定类型的全部组件条目。
|
||||
|
||||
@@ -87,11 +89,15 @@ class ComponentQueryService:
|
||||
if host_component_type is None:
|
||||
return []
|
||||
|
||||
session_id = context.session_id if context is not None else None
|
||||
is_group_chat = context.is_group_chat if context is not None else None
|
||||
collected_entries: list[tuple["PluginSupervisor", "ComponentEntry"]] = []
|
||||
for supervisor in self._iter_supervisors():
|
||||
for component in supervisor.component_registry.get_components_by_type(
|
||||
host_component_type,
|
||||
enabled_only=enabled_only,
|
||||
session_id=session_id,
|
||||
is_group_chat=is_group_chat,
|
||||
):
|
||||
collected_entries.append((supervisor, component))
|
||||
return collected_entries
|
||||
@@ -657,7 +663,10 @@ class ComponentQueryService:
|
||||
tool_entry = cast("ToolEntry", entry)
|
||||
return self._build_tool_executor(supervisor, tool_entry.plugin_id, tool_entry.name, tool_entry.invoke_method)
|
||||
|
||||
def get_llm_available_tool_specs(self) -> Dict[str, ToolSpec]:
|
||||
def get_llm_available_tool_specs(
|
||||
self,
|
||||
context: Optional[ToolAvailabilityContext] = None,
|
||||
) -> Dict[str, ToolSpec]:
|
||||
"""获取当前可供 LLM 使用的统一工具声明集合。
|
||||
|
||||
Returns:
|
||||
@@ -665,7 +674,7 @@ class ComponentQueryService:
|
||||
"""
|
||||
|
||||
collected_specs: Dict[str, ToolSpec] = {}
|
||||
for _supervisor, entry in self._iter_component_entries(ComponentType.TOOL):
|
||||
for _supervisor, entry in self._iter_component_entries(ComponentType.TOOL, context=context):
|
||||
if entry.name in collected_specs:
|
||||
self._log_duplicate_component(ComponentType.TOOL, entry.name)
|
||||
continue
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Set, TypedDict, Tuple
|
||||
from typing import Any, Dict, List, Literal, Optional, Set, Tuple, TypedDict
|
||||
|
||||
import contextlib
|
||||
import re
|
||||
@@ -58,6 +58,20 @@ class ComponentTypes(str, Enum):
|
||||
MESSAGE_GATEWAY = "MESSAGE_GATEWAY"
|
||||
|
||||
|
||||
ComponentChatScope = Literal["all", "group", "private"]
|
||||
|
||||
|
||||
def _normalize_chat_scope(raw_value: Any) -> ComponentChatScope:
|
||||
"""规范化组件聊天类型适用范围。"""
|
||||
|
||||
normalized_value = str(raw_value or "all").strip().lower()
|
||||
if normalized_value == "group":
|
||||
return "group"
|
||||
if normalized_value == "private":
|
||||
return "private"
|
||||
return "all"
|
||||
|
||||
|
||||
class StatusDict(TypedDict):
|
||||
total: int
|
||||
action: int
|
||||
@@ -81,9 +95,17 @@ class ComponentEntry:
|
||||
"enabled",
|
||||
"compiled_pattern",
|
||||
"disabled_session",
|
||||
"chat_scope",
|
||||
)
|
||||
|
||||
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
component_type: str,
|
||||
plugin_id: str,
|
||||
metadata: Dict[str, Any],
|
||||
chat_scope: str = "all",
|
||||
) -> None:
|
||||
self.name: str = name
|
||||
self.full_name: str = f"{plugin_id}.{name}"
|
||||
self.component_type: ComponentTypes = ComponentTypes(component_type)
|
||||
@@ -91,20 +113,35 @@ class ComponentEntry:
|
||||
self.metadata: Dict[str, Any] = metadata
|
||||
self.enabled: bool = metadata.get("enabled", True)
|
||||
self.disabled_session: Set[str] = set()
|
||||
self.chat_scope: ComponentChatScope = _normalize_chat_scope(chat_scope)
|
||||
|
||||
|
||||
class ActionEntry(ComponentEntry):
|
||||
"""Action 组件条目"""
|
||||
|
||||
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
|
||||
super().__init__(name, component_type, plugin_id, metadata)
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
component_type: str,
|
||||
plugin_id: str,
|
||||
metadata: Dict[str, Any],
|
||||
chat_scope: str = "all",
|
||||
) -> None:
|
||||
super().__init__(name, component_type, plugin_id, metadata, chat_scope)
|
||||
|
||||
|
||||
class CommandEntry(ComponentEntry):
|
||||
"""Command 组件条目"""
|
||||
|
||||
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
|
||||
super().__init__(name, component_type, plugin_id, metadata)
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
component_type: str,
|
||||
plugin_id: str,
|
||||
metadata: Dict[str, Any],
|
||||
chat_scope: str = "all",
|
||||
) -> None:
|
||||
super().__init__(name, component_type, plugin_id, metadata, chat_scope)
|
||||
self.aliases: List[str] = metadata.get("aliases", [])
|
||||
self.compiled_pattern: Optional[re.Pattern] = None
|
||||
if pattern := metadata.get("command_pattern", ""):
|
||||
@@ -117,7 +154,14 @@ class CommandEntry(ComponentEntry):
|
||||
class ToolEntry(ComponentEntry):
|
||||
"""Tool 组件条目"""
|
||||
|
||||
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
component_type: str,
|
||||
plugin_id: str,
|
||||
metadata: Dict[str, Any],
|
||||
chat_scope: str = "all",
|
||||
) -> None:
|
||||
self.description: str = str(metadata.get("description", "") or "").strip()
|
||||
self.brief_description: str = str(
|
||||
metadata.get("brief_description", self.description) or self.description or f"工具 {name}"
|
||||
@@ -128,7 +172,7 @@ class ToolEntry(ComponentEntry):
|
||||
self.detailed_description: str = detailed_description
|
||||
self.invoke_method: str = str(metadata.get("invoke_method", "plugin.invoke_tool") or "plugin.invoke_tool").strip()
|
||||
self.legacy_component_type: str = str(metadata.get("legacy_component_type", "") or "").strip()
|
||||
super().__init__(name, component_type, plugin_id, metadata)
|
||||
super().__init__(name, component_type, plugin_id, metadata, chat_scope)
|
||||
|
||||
if not self.detailed_description:
|
||||
parameters_schema = self._get_parameters_schema()
|
||||
@@ -197,23 +241,37 @@ class ToolEntry(ComponentEntry):
|
||||
class EventHandlerEntry(ComponentEntry):
|
||||
"""EventHandler 组件条目"""
|
||||
|
||||
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
component_type: str,
|
||||
plugin_id: str,
|
||||
metadata: Dict[str, Any],
|
||||
chat_scope: str = "all",
|
||||
) -> None:
|
||||
self.event_type: str = metadata.get("event_type", "")
|
||||
self.weight: int = metadata.get("weight", 0)
|
||||
self.intercept_message: bool = metadata.get("intercept_message", False)
|
||||
super().__init__(name, component_type, plugin_id, metadata)
|
||||
super().__init__(name, component_type, plugin_id, metadata, chat_scope)
|
||||
|
||||
|
||||
class HookHandlerEntry(ComponentEntry):
|
||||
"""HookHandler 组件条目。"""
|
||||
|
||||
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
component_type: str,
|
||||
plugin_id: str,
|
||||
metadata: Dict[str, Any],
|
||||
chat_scope: str = "all",
|
||||
) -> None:
|
||||
self.hook: str = self._normalize_hook_name(metadata.get("hook", ""))
|
||||
self.mode: str = self._normalize_mode(metadata.get("mode", "blocking"))
|
||||
self.order: str = self._normalize_order(metadata.get("order", "normal"))
|
||||
self.timeout_ms: int = self._normalize_timeout_ms(metadata.get("timeout_ms", 0))
|
||||
self.error_policy: str = self._normalize_error_policy(metadata.get("error_policy", "skip"))
|
||||
super().__init__(name, component_type, plugin_id, metadata)
|
||||
super().__init__(name, component_type, plugin_id, metadata, chat_scope)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_error_policy(raw_value: Any) -> str:
|
||||
@@ -332,13 +390,20 @@ class HookHandlerEntry(ComponentEntry):
|
||||
class MessageGatewayEntry(ComponentEntry):
|
||||
"""MessageGateway 组件条目"""
|
||||
|
||||
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
component_type: str,
|
||||
plugin_id: str,
|
||||
metadata: Dict[str, Any],
|
||||
chat_scope: str = "all",
|
||||
) -> None:
|
||||
self.route_type: str = self._normalize_route_type(metadata.get("route_type", ""))
|
||||
self.platform: str = str(metadata.get("platform", "") or "").strip()
|
||||
self.protocol: str = str(metadata.get("protocol", "") or "").strip()
|
||||
self.account_id: str = str(metadata.get("account_id", "") or "").strip()
|
||||
self.scope: str = str(metadata.get("scope", "") or "").strip()
|
||||
super().__init__(name, component_type, plugin_id, metadata)
|
||||
super().__init__(name, component_type, plugin_id, metadata, chat_scope)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_route_type(raw_value: Any) -> str:
|
||||
@@ -578,6 +643,7 @@ class ComponentRegistry:
|
||||
component_type: str,
|
||||
plugin_id: str,
|
||||
metadata: Dict[str, Any],
|
||||
chat_scope: str = "all",
|
||||
) -> ComponentEntry:
|
||||
"""根据声明构造组件条目。
|
||||
|
||||
@@ -599,18 +665,18 @@ class ComponentRegistry:
|
||||
normalized_metadata = dict(metadata)
|
||||
if normalized_type == ComponentTypes.ACTION:
|
||||
normalized_metadata = self._convert_action_metadata_to_tool_metadata(name, normalized_metadata)
|
||||
component = ToolEntry(name, ComponentTypes.TOOL.value, plugin_id, normalized_metadata)
|
||||
component = ToolEntry(name, ComponentTypes.TOOL.value, plugin_id, normalized_metadata, chat_scope)
|
||||
elif normalized_type == ComponentTypes.COMMAND:
|
||||
component = CommandEntry(name, normalized_type.value, plugin_id, normalized_metadata)
|
||||
component = CommandEntry(name, normalized_type.value, plugin_id, normalized_metadata, chat_scope)
|
||||
elif normalized_type == ComponentTypes.TOOL:
|
||||
component = ToolEntry(name, normalized_type.value, plugin_id, normalized_metadata)
|
||||
component = ToolEntry(name, normalized_type.value, plugin_id, normalized_metadata, chat_scope)
|
||||
elif normalized_type == ComponentTypes.EVENT_HANDLER:
|
||||
component = EventHandlerEntry(name, normalized_type.value, plugin_id, normalized_metadata)
|
||||
component = EventHandlerEntry(name, normalized_type.value, plugin_id, normalized_metadata, chat_scope)
|
||||
elif normalized_type == ComponentTypes.HOOK_HANDLER:
|
||||
component = HookHandlerEntry(name, normalized_type.value, plugin_id, normalized_metadata)
|
||||
component = HookHandlerEntry(name, normalized_type.value, plugin_id, normalized_metadata, chat_scope)
|
||||
self._validate_hook_handler_entry(component)
|
||||
elif normalized_type == ComponentTypes.MESSAGE_GATEWAY:
|
||||
component = MessageGatewayEntry(name, normalized_type.value, plugin_id, normalized_metadata)
|
||||
component = MessageGatewayEntry(name, normalized_type.value, plugin_id, normalized_metadata, chat_scope)
|
||||
else:
|
||||
raise ComponentRegistrationError(
|
||||
f"组件类型 {component_type} 不存在",
|
||||
@@ -662,7 +728,14 @@ class ComponentRegistry:
|
||||
self._by_plugin.setdefault(component.plugin_id, []).append(component)
|
||||
|
||||
# ====== 注册 / 注销 ======
|
||||
def register_component(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> bool:
|
||||
def register_component(
|
||||
self,
|
||||
name: str,
|
||||
component_type: str,
|
||||
plugin_id: str,
|
||||
metadata: Dict[str, Any],
|
||||
chat_scope: str = "all",
|
||||
) -> bool:
|
||||
"""注册单个组件。
|
||||
|
||||
Args:
|
||||
@@ -678,7 +751,7 @@ class ComponentRegistry:
|
||||
ComponentRegistrationError: 组件声明不合法时抛出。
|
||||
"""
|
||||
|
||||
component = self._build_component_entry(name, component_type, plugin_id, metadata)
|
||||
component = self._build_component_entry(name, component_type, plugin_id, metadata, chat_scope)
|
||||
self._add_component_entry(component)
|
||||
return True
|
||||
|
||||
@@ -701,14 +774,19 @@ class ComponentRegistry:
|
||||
|
||||
prepared_components: List[ComponentEntry] = []
|
||||
for component_data in components:
|
||||
raw_metadata = (
|
||||
dict(component_data.get("metadata", {}))
|
||||
if isinstance(component_data.get("metadata"), dict)
|
||||
else {}
|
||||
)
|
||||
chat_scope = str(component_data.get("chat_scope", raw_metadata.pop("chat_scope", "all")) or "all")
|
||||
prepared_components.append(
|
||||
self._build_component_entry(
|
||||
name=str(component_data.get("name", "") or ""),
|
||||
component_type=str(component_data.get("component_type", "") or ""),
|
||||
plugin_id=plugin_id,
|
||||
metadata=component_data.get("metadata", {})
|
||||
if isinstance(component_data.get("metadata"), dict)
|
||||
else {},
|
||||
metadata=raw_metadata,
|
||||
chat_scope=chat_scope,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -733,9 +811,19 @@ class ComponentRegistry:
|
||||
return len(comps)
|
||||
|
||||
# ====== 启用 / 禁用 ======
|
||||
def check_component_enabled(self, component: ComponentEntry, session_id: Optional[str] = None):
|
||||
def check_component_enabled(
|
||||
self,
|
||||
component: ComponentEntry,
|
||||
session_id: Optional[str] = None,
|
||||
is_group_chat: Optional[bool] = None,
|
||||
):
|
||||
if session_id and session_id in component.disabled_session:
|
||||
return False
|
||||
if is_group_chat is not None:
|
||||
if component.chat_scope == "group" and is_group_chat is not True:
|
||||
return False
|
||||
if component.chat_scope == "private" and is_group_chat is not False:
|
||||
return False
|
||||
return component.enabled
|
||||
|
||||
def toggle_component_status(self, full_name: str, enabled: bool, session_id: Optional[str] = None) -> bool:
|
||||
@@ -806,7 +894,12 @@ class ComponentRegistry:
|
||||
return self._components.get(full_name)
|
||||
|
||||
def get_components_by_type(
|
||||
self, component_type: str, *, enabled_only: bool = True, session_id: Optional[str] = None
|
||||
self,
|
||||
component_type: str,
|
||||
*,
|
||||
enabled_only: bool = True,
|
||||
session_id: Optional[str] = None,
|
||||
is_group_chat: Optional[bool] = None,
|
||||
) -> List[ComponentEntry]:
|
||||
"""按类型查询组件
|
||||
|
||||
@@ -830,12 +923,16 @@ class ComponentRegistry:
|
||||
if self._is_legacy_action_component(component)
|
||||
]
|
||||
if enabled_only:
|
||||
return [component for component in action_components if self.check_component_enabled(component, session_id)]
|
||||
return [
|
||||
component
|
||||
for component in action_components
|
||||
if self.check_component_enabled(component, session_id, is_group_chat)
|
||||
]
|
||||
return action_components
|
||||
|
||||
type_dict = self._by_type.get(comp_type, {})
|
||||
if enabled_only:
|
||||
return [c for c in type_dict.values() if self.check_component_enabled(c, session_id)]
|
||||
return [c for c in type_dict.values() if self.check_component_enabled(c, session_id, is_group_chat)]
|
||||
return list(type_dict.values())
|
||||
|
||||
def get_components_by_plugin(
|
||||
|
||||
@@ -4,7 +4,14 @@ from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolProvider, ToolSpec
|
||||
from src.core.tooling import (
|
||||
ToolAvailabilityContext,
|
||||
ToolExecutionContext,
|
||||
ToolExecutionResult,
|
||||
ToolInvocation,
|
||||
ToolProvider,
|
||||
ToolSpec,
|
||||
)
|
||||
|
||||
from .component_query import component_query_service
|
||||
|
||||
@@ -15,10 +22,13 @@ class PluginToolProvider(ToolProvider):
|
||||
provider_name = "plugin_runtime"
|
||||
provider_type = "plugin"
|
||||
|
||||
async def list_tools(self) -> list[ToolSpec]:
|
||||
async def list_tools(
|
||||
self,
|
||||
context: Optional[ToolAvailabilityContext] = None,
|
||||
) -> list[ToolSpec]:
|
||||
"""列出插件运行时当前可用的工具声明。"""
|
||||
|
||||
return list(component_query_service.get_llm_available_tool_specs().values())
|
||||
return list(component_query_service.get_llm_available_tool_specs(context=context).values())
|
||||
|
||||
async def invoke(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user