feat:修复门控多重result问题,新增at动作,插件现在运行chat_id指定或chat_type指定

This commit is contained in:
SengokuCola
2026-04-22 00:11:14 +08:00
parent 363c0a77b7
commit f1563ede65
12 changed files with 833 additions and 122 deletions

View File

@@ -0,0 +1,176 @@
from importlib import util
from pathlib import Path
from types import ModuleType, SimpleNamespace
from typing import Any
import sys
import pytest
from src.common.data_models.message_component_data_model import AtComponent, TextComponent
from src.core.tooling import ToolExecutionResult, ToolInvocation
_MISSING_MODULE = object()
_module_overrides: dict[str, object] = {}
def _override_module(module_name: str, module: ModuleType) -> None:
_module_overrides[module_name] = sys.modules.get(module_name, _MISSING_MODULE)
sys.modules[module_name] = module
def _restore_overridden_modules() -> None:
for module_name, previous_module in reversed(_module_overrides.items()):
if previous_module is _MISSING_MODULE:
sys.modules.pop(module_name, None)
else:
sys.modules[module_name] = previous_module
_module_overrides.clear()
fake_cli_sender_module = ModuleType("src.cli.maisaka_cli_sender")
fake_cli_sender_module.CLI_PLATFORM_NAME = "cli"
fake_cli_sender_module.render_cli_message = lambda text: text
fake_cli_module = ModuleType("src.cli")
fake_cli_module.maisaka_cli_sender = fake_cli_sender_module
fake_send_service_module = ModuleType("src.services.send_service")
fake_send_service_module._send_to_target_with_message = None
fake_services_module = ModuleType("src.services")
fake_services_module.send_service = fake_send_service_module
_override_module("src.cli", fake_cli_module)
_override_module("src.cli.maisaka_cli_sender", fake_cli_sender_module)
_override_module("src.services", fake_services_module)
_override_module("src.services.send_service", fake_send_service_module)
AT_TOOL_PATH = Path(__file__).resolve().parents[1] / "src" / "maisaka" / "builtin_tool" / "at.py"
at_tool_spec = util.spec_from_file_location("_test_maisaka_builtin_at_tool", AT_TOOL_PATH)
assert at_tool_spec is not None and at_tool_spec.loader is not None
at_tool = util.module_from_spec(at_tool_spec)
sys.modules["_test_maisaka_builtin_at_tool"] = at_tool
try:
at_tool_spec.loader.exec_module(at_tool)
finally:
_restore_overridden_modules()
class _ToolCtx:
def __init__(self, runtime: SimpleNamespace) -> None:
self.runtime = runtime
@staticmethod
def build_success_result(
tool_name: str,
content: str = "",
structured_content: Any = None,
metadata: dict[str, Any] | None = None,
) -> ToolExecutionResult:
return ToolExecutionResult(
tool_name=tool_name,
success=True,
content=content,
structured_content=structured_content,
metadata=dict(metadata or {}),
)
@staticmethod
def build_failure_result(
tool_name: str,
error_message: str,
structured_content: Any = None,
metadata: dict[str, Any] | None = None,
) -> ToolExecutionResult:
return ToolExecutionResult(
tool_name=tool_name,
success=False,
error_message=error_message,
structured_content=structured_content,
metadata=dict(metadata or {}),
)
def append_guided_reply_to_chat_history(self, reply_text: str) -> None:
self.runtime._chat_history.append(reply_text)
def _build_tool_ctx(*, group_id: str = "group-1") -> _ToolCtx:
target_message = SimpleNamespace(
message_info=SimpleNamespace(
user_info=SimpleNamespace(
user_id="target-user-1",
user_nickname="目标昵称",
user_cardname="群名片",
)
)
)
runtime = SimpleNamespace(
_source_messages_by_id={"msg-1": target_message},
chat_stream=SimpleNamespace(platform="qq", group_id=group_id),
session_id="session-1",
log_prefix="[test-at]",
_record_reply_sent=lambda: None,
_chat_history=[],
)
return _ToolCtx(runtime=runtime)
def test_at_tool_spec_does_not_embed_visibility_metadata() -> None:
tool_spec = at_tool.get_tool_spec()
assert tool_spec.name == "at"
assert "deferred" not in tool_spec.metadata
assert "visibility" not in tool_spec.metadata
@pytest.mark.asyncio
async def test_at_tool_sends_at_component_by_msg_id(monkeypatch: pytest.MonkeyPatch) -> None:
captured: dict[str, Any] = {}
async def fake_send_to_target_with_message(**kwargs: Any) -> object:
captured.update(kwargs)
return SimpleNamespace(message_id="sent-msg-1")
monkeypatch.setattr(at_tool.send_service, "_send_to_target_with_message", fake_send_to_target_with_message)
result = await at_tool.handle_tool(
_build_tool_ctx(),
ToolInvocation(tool_name="at", arguments={"msg_id": "msg-1", "text": "看这里"}),
)
assert result.success is True
assert result.structured_content["target_user_id"] == "target-user-1"
assert result.structured_content["target_user_name"] == "群名片"
assert captured["stream_id"] == "session-1"
assert captured["display_message"] == "@群名片 看这里"
assert captured["sync_to_maisaka_history"] is True
assert captured["maisaka_source_kind"] == "guided_reply"
components = captured["message_sequence"].components
assert isinstance(components[0], AtComponent)
assert components[0].target_user_id == "target-user-1"
assert components[0].target_user_nickname == "目标昵称"
assert components[0].target_user_cardname == "群名片"
assert isinstance(components[1], TextComponent)
assert components[1].text == " 看这里"
@pytest.mark.asyncio
async def test_at_tool_rejects_private_chat() -> None:
result = await at_tool.handle_tool(
_build_tool_ctx(group_id=""),
ToolInvocation(tool_name="at", arguments={"msg_id": "msg-1"}),
)
assert result.success is False
assert "群聊" in result.error_message
@pytest.mark.asyncio
async def test_at_tool_rejects_unknown_msg_id() -> None:
result = await at_tool.handle_tool(
_build_tool_ctx(),
ToolInvocation(tool_name="at", arguments={"msg_id": "missing-msg"}),
)
assert result.success is False
assert result.structured_content == {"msg_id": "missing-msg"}

View File

@@ -0,0 +1,96 @@
from types import SimpleNamespace
import pytest
from src.core.tooling import ToolAvailabilityContext, ToolRegistry
from src.maisaka.tool_provider import MaisakaBuiltinToolProvider
from src.plugin_runtime.component_query import ComponentQueryService
from src.plugin_runtime.host.component_registry import ComponentRegistry
@pytest.mark.asyncio
async def test_builtin_at_is_exposed_only_in_group_chats() -> None:
registry = ToolRegistry()
registry.register_provider(MaisakaBuiltinToolProvider())
group_specs = await registry.list_tools(ToolAvailabilityContext(session_id="group-1", is_group_chat=True))
private_specs = await registry.list_tools(ToolAvailabilityContext(session_id="private-1", is_group_chat=False))
default_specs = await registry.list_tools()
assert "at" in {tool_spec.name for tool_spec in group_specs}
assert "at" not in {tool_spec.name for tool_spec in private_specs}
assert "at" in {tool_spec.name for tool_spec in default_specs}
def test_plugin_tool_chat_scope_uses_component_field(monkeypatch: pytest.MonkeyPatch) -> None:
service = ComponentQueryService()
registry = ComponentRegistry()
supervisor = SimpleNamespace(component_registry=registry)
monkeypatch.setattr(service, "_iter_supervisors", lambda: [supervisor])
registry.register_plugin_components(
"scope_plugin",
[
{
"name": "group_tool",
"component_type": "TOOL",
"chat_scope": "group",
"metadata": {"description": "group only"},
},
{
"name": "private_tool",
"component_type": "TOOL",
"chat_scope": "private",
"metadata": {"description": "private only"},
},
{
"name": "all_tool",
"component_type": "TOOL",
"metadata": {"description": "all chats"},
},
],
)
group_specs = service.get_llm_available_tool_specs(
context=ToolAvailabilityContext(session_id="group-1", is_group_chat=True)
)
private_specs = service.get_llm_available_tool_specs(
context=ToolAvailabilityContext(session_id="private-1", is_group_chat=False)
)
group_entry = registry.get_component("scope_plugin.group_tool")
assert group_entry is not None
assert group_entry.chat_scope == "group"
assert "chat_scope" not in group_entry.metadata
assert set(group_specs) == {"group_tool", "all_tool"}
assert set(private_specs) == {"private_tool", "all_tool"}
def test_plugin_tool_session_disable_still_filters_specific_chat(monkeypatch: pytest.MonkeyPatch) -> None:
service = ComponentQueryService()
registry = ComponentRegistry()
supervisor = SimpleNamespace(component_registry=registry)
monkeypatch.setattr(service, "_iter_supervisors", lambda: [supervisor])
registry.register_plugin_components(
"mute_plugin",
[
{
"name": "mute",
"component_type": "TOOL",
"chat_scope": "group",
"metadata": {"description": "mute group member"},
}
],
)
registry.set_component_enabled("mute_plugin.mute", False, session_id="group-disabled")
disabled_specs = service.get_llm_available_tool_specs(
context=ToolAvailabilityContext(session_id="group-disabled", is_group_chat=True)
)
enabled_specs = service.get_llm_available_tool_specs(
context=ToolAvailabilityContext(session_id="group-enabled", is_group_chat=True)
)
assert "mute" not in disabled_specs
assert "mute" in enabled_specs

View File

@@ -222,6 +222,18 @@ class ToolExecutionContext:
metadata: Dict[str, Any] = field(default_factory=dict)
@dataclass(slots=True)
class ToolAvailabilityContext:
"""工具暴露可用性判断上下文。"""
session_id: str = ""
stream_id: str = ""
is_group_chat: bool | None = None
group_id: str = ""
user_id: str = ""
platform: str = ""
@dataclass(slots=True)
class ToolExecutionResult:
"""统一工具执行结果。"""
@@ -264,7 +276,10 @@ class ToolProvider(Protocol):
provider_name: str
provider_type: str
async def list_tools(self) -> list[ToolSpec]:
async def list_tools(
self,
context: Optional[ToolAvailabilityContext] = None,
) -> list[ToolSpec]:
"""列出当前 Provider 暴露的全部工具。"""
...
@@ -308,7 +323,10 @@ class ToolRegistry:
self._providers = [item for item in self._providers if item.provider_name != provider_name]
async def list_tools(self) -> list[ToolSpec]:
async def list_tools(
self,
context: Optional[ToolAvailabilityContext] = None,
) -> list[ToolSpec]:
"""按 Provider 顺序列出全部去重后的工具。
Returns:
@@ -319,7 +337,7 @@ class ToolRegistry:
seen_names: set[str] = set()
for provider in self._providers:
provider_specs = await provider.list_tools()
provider_specs = await provider.list_tools(context)
for spec in provider_specs:
if not spec.enabled:
continue
@@ -332,7 +350,11 @@ class ToolRegistry:
collected_specs.append(spec)
return collected_specs
async def get_tool_spec(self, tool_name: str) -> Optional[ToolSpec]:
async def get_tool_spec(
self,
tool_name: str,
context: Optional[ToolAvailabilityContext] = None,
) -> Optional[ToolSpec]:
"""查询指定工具声明。
Args:
@@ -342,12 +364,16 @@ class ToolRegistry:
Optional[ToolSpec]: 匹配到的工具声明。
"""
for spec in await self.list_tools():
for spec in await self.list_tools(context):
if spec.name == tool_name:
return spec
return None
async def has_tool(self, tool_name: str) -> bool:
async def has_tool(
self,
tool_name: str,
context: Optional[ToolAvailabilityContext] = None,
) -> bool:
"""判断指定工具是否存在。
Args:
@@ -357,16 +383,19 @@ class ToolRegistry:
bool: 是否存在。
"""
return await self.get_tool_spec(tool_name) is not None
return await self.get_tool_spec(tool_name, context) is not None
async def get_llm_definitions(self) -> list[ToolDefinitionInput]:
async def get_llm_definitions(
self,
context: Optional[ToolAvailabilityContext] = None,
) -> list[ToolDefinitionInput]:
"""获取供 LLM 使用的工具定义列表。
Returns:
list[ToolDefinitionInput]: 统一工具定义列表。
"""
return [spec.to_llm_definition() for spec in await self.list_tools()]
return [spec.to_llm_definition() for spec in await self.list_tools(context)]
async def invoke(
self,

View File

@@ -1,12 +1,16 @@
"""Maisaka 内置工具聚合入口。"""
from collections.abc import Awaitable, Callable
from typing import Dict, List, Optional
from copy import deepcopy
from dataclasses import dataclass
from typing import Dict, List, Literal, Optional
from src.config.config import global_config
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec
from src.core.tooling import ToolAvailabilityContext, ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec
from src.llm_models.payload_content.tool_option import ToolDefinitionInput
from .at import get_tool_spec as get_at_tool_spec
from .at import handle_tool as handle_at_tool
from .context import BuiltinToolRuntimeContext
from .continue_tool import get_tool_spec as get_continue_tool_spec
from .continue_tool import handle_tool as handle_continue_tool
@@ -32,93 +36,152 @@ from .wait import get_tool_spec as get_wait_tool_spec
from .wait import handle_tool as handle_wait_tool
BuiltinToolHandler = Callable[[ToolInvocation, Optional[ToolExecutionContext]], Awaitable[ToolExecutionResult]]
BuiltinToolRawHandler = Callable[
[BuiltinToolRuntimeContext, ToolInvocation, Optional[ToolExecutionContext]],
Awaitable[ToolExecutionResult],
]
BuiltinToolStage = Literal["timing", "action"]
BuiltinToolVisibility = Literal["visible", "deferred", "hidden"]
BuiltinToolChatScope = Literal["all", "group", "private"]
def get_timing_tool_specs() -> List[ToolSpec]:
"""获取 Timing Gate 阶段可用的内置工具声明。"""
@dataclass(frozen=True)
class BuiltinToolEntry:
"""内置工具目录项,集中声明工具所属阶段与默认可见性。"""
return [
get_wait_tool_spec(),
get_no_reply_tool_spec(),
get_continue_tool_spec(),
]
name: str
get_spec: Callable[[], ToolSpec]
handle_tool: BuiltinToolRawHandler
stage: BuiltinToolStage
visibility: BuiltinToolVisibility = "visible"
chat_scope: BuiltinToolChatScope = "all"
def build_spec(self) -> ToolSpec:
"""生成带统一可见性元数据的工具声明。"""
tool_spec = deepcopy(self.get_spec())
tool_spec.metadata["builtin_stage"] = self.stage
tool_spec.metadata["visibility"] = self.visibility
return tool_spec
def get_action_tool_specs() -> List[ToolSpec]:
"""获取 Action Loop 阶段可用的内置工具声明。"""
def _get_query_memory_tool_spec() -> ToolSpec:
"""根据配置生成 query_memory 工具声明。"""
return [
get_finish_tool_spec(),
get_reply_tool_spec(),
get_view_complex_message_tool_spec(),
get_query_jargon_tool_spec(),
get_query_memory_tool_spec(enabled=bool(global_config.memory.enable_memory_query_tool)),
get_send_emoji_tool_spec(),
get_tool_search_tool_spec(),
]
return get_query_memory_tool_spec(enabled=bool(global_config.memory.enable_memory_query_tool))
def get_builtin_tool_specs() -> List[ToolSpec]:
"""获取默认暴露的 Maisaka 内置工具声明。"""
return get_action_tool_specs()
BUILTIN_TOOL_ENTRIES: List[BuiltinToolEntry] = [
BuiltinToolEntry("wait", get_wait_tool_spec, handle_wait_tool, stage="timing"),
BuiltinToolEntry("no_reply", get_no_reply_tool_spec, handle_no_reply_tool, stage="timing"),
BuiltinToolEntry("continue", get_continue_tool_spec, handle_continue_tool, stage="timing"),
BuiltinToolEntry("finish", get_finish_tool_spec, handle_finish_tool, stage="action"),
BuiltinToolEntry("reply", get_reply_tool_spec, handle_reply_tool, stage="action"),
BuiltinToolEntry(
"view_complex_message",
get_view_complex_message_tool_spec,
handle_view_complex_message_tool,
stage="action",
),
BuiltinToolEntry("query_jargon", get_query_jargon_tool_spec, handle_query_jargon_tool, stage="action"),
BuiltinToolEntry("query_memory", _get_query_memory_tool_spec, handle_query_memory_tool, stage="action"),
BuiltinToolEntry(
"query_person_info",
get_query_person_info_tool_spec,
handle_query_person_info_tool,
stage="action",
visibility="hidden",
),
BuiltinToolEntry("send_emoji", get_send_emoji_tool_spec, handle_send_emoji_tool, stage="action"),
BuiltinToolEntry(
"at",
get_at_tool_spec,
handle_at_tool,
stage="action",
visibility="deferred",
chat_scope="group",
),
BuiltinToolEntry("tool_search", get_tool_search_tool_spec, handle_tool_search_tool, stage="action"),
]
def get_all_builtin_tool_specs() -> List[ToolSpec]:
def _get_builtin_tool_entries(
*,
stage: Optional[BuiltinToolStage] = None,
visibility: Optional[BuiltinToolVisibility] = None,
context: Optional[ToolAvailabilityContext] = None,
) -> List[BuiltinToolEntry]:
"""按阶段与可见性筛选内置工具目录项。"""
entries = BUILTIN_TOOL_ENTRIES
if stage is not None:
entries = [entry for entry in entries if entry.stage == stage]
if visibility is not None:
entries = [entry for entry in entries if entry.visibility == visibility]
if context is not None:
entries = [entry for entry in entries if _is_builtin_tool_available(entry, context)]
return entries
def _is_builtin_tool_available(entry: BuiltinToolEntry, context: ToolAvailabilityContext) -> bool:
"""判断内置工具是否适用于当前聊天。"""
if entry.chat_scope == "all":
return True
if entry.chat_scope == "group":
return context.is_group_chat is True
if entry.chat_scope == "private":
return context.is_group_chat is False
return True
def get_builtin_tool_visibility(tool_spec: ToolSpec) -> BuiltinToolVisibility:
"""读取工具声明里的可见性。"""
raw_visibility = str(tool_spec.metadata.get("visibility") or "").strip()
if raw_visibility == "deferred":
return "deferred"
if raw_visibility == "hidden":
return "hidden"
return "visible"
def is_builtin_tool_in_action_stage(tool_spec: ToolSpec) -> bool:
"""判断内置工具是否属于 Action Loop 阶段。"""
return str(tool_spec.metadata.get("builtin_stage") or "").strip() == "action"
def get_all_builtin_tool_specs(context: Optional[ToolAvailabilityContext] = None) -> List[ToolSpec]:
"""获取全部内置工具声明。"""
return [
*get_timing_tool_specs(),
get_finish_tool_spec(),
get_reply_tool_spec(),
get_view_complex_message_tool_spec(),
get_query_jargon_tool_spec(),
get_query_memory_tool_spec(enabled=True),
get_query_person_info_tool_spec(),
get_send_emoji_tool_spec(),
get_tool_search_tool_spec(),
]
return [entry.build_spec() for entry in _get_builtin_tool_entries(context=context)]
def get_timing_tools() -> List[ToolDefinitionInput]:
"""获取 Timing Gate 阶段的兼容工具定义。"""
return [tool_spec.to_llm_definition() for tool_spec in get_timing_tool_specs()]
def get_action_tools() -> List[ToolDefinitionInput]:
"""获取 Action Loop 阶段的兼容工具定义。"""
return [tool_spec.to_llm_definition() for tool_spec in get_action_tool_specs()]
tool_specs = [
entry.build_spec()
for entry in _get_builtin_tool_entries(stage="timing", visibility="visible")
]
return [tool_spec.to_llm_definition() for tool_spec in tool_specs if tool_spec.enabled]
def get_builtin_tools() -> List[ToolDefinitionInput]:
"""获取默认暴露给模型层的内置工具定义。"""
return get_action_tools()
tool_specs = [
entry.build_spec()
for entry in _get_builtin_tool_entries(stage="action", visibility="visible")
]
return [tool_spec.to_llm_definition() for tool_spec in tool_specs if tool_spec.enabled]
def build_builtin_tool_handlers(tool_ctx: BuiltinToolRuntimeContext) -> Dict[str, BuiltinToolHandler]:
"""构建内置工具处理器映射。"""
return {
"continue": lambda invocation, context=None: handle_continue_tool(tool_ctx, invocation, context),
"finish": lambda invocation, context=None: handle_finish_tool(tool_ctx, invocation, context),
"reply": lambda invocation, context=None: handle_reply_tool(tool_ctx, invocation, context),
"no_reply": lambda invocation, context=None: handle_no_reply_tool(tool_ctx, invocation, context),
"query_jargon": lambda invocation, context=None: handle_query_jargon_tool(tool_ctx, invocation, context),
"query_memory": lambda invocation, context=None: handle_query_memory_tool(tool_ctx, invocation, context),
"query_person_info": lambda invocation, context=None: handle_query_person_info_tool(
tool_ctx,
invocation,
context,
),
"wait": lambda invocation, context=None: handle_wait_tool(tool_ctx, invocation, context),
"send_emoji": lambda invocation, context=None: handle_send_emoji_tool(tool_ctx, invocation, context),
"tool_search": lambda invocation, context=None: handle_tool_search_tool(tool_ctx, invocation, context),
"view_complex_message": lambda invocation, context=None: handle_view_complex_message_tool(
tool_ctx,
invocation,
context,
),
entry.name: lambda invocation, context=None, entry=entry: entry.handle_tool(tool_ctx, invocation, context)
for entry in BUILTIN_TOOL_ENTRIES
}

View File

@@ -0,0 +1,186 @@
"""Maisaka 内置 at 工具。"""
from typing import Any, Optional, TYPE_CHECKING
from src.cli.maisaka_cli_sender import CLI_PLATFORM_NAME, render_cli_message
from src.common.data_models.message_component_data_model import AtComponent, MessageSequence, TextComponent
from src.common.logger import get_logger
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec
from src.services import send_service
if TYPE_CHECKING:
from .context import BuiltinToolRuntimeContext
logger = get_logger("maisaka_builtin_at")
def get_tool_spec() -> ToolSpec:
"""获取 at 工具声明。"""
return ToolSpec(
name="at",
brief_description="根据一条已知 msg_id 找到发言用户,并发送一条 @ 该用户的消息。",
detailed_description=(
"参数说明:\n"
"- msg_idstring必填。要 @ 的目标用户发过的消息编号。\n"
"- textstring可选。@ 后追加发送的短文本;只想单独 @ 人时留空。\n"
"请优先从上下文里选择一条明确属于目标用户的 msg_id不要凭昵称或印象猜测用户。"
),
parameters_schema={
"type": "object",
"properties": {
"msg_id": {
"type": "string",
"description": "要 @ 的目标用户发过的消息编号。",
},
"text": {
"type": "string",
"description": "@ 后追加发送的短文本;只想单独 @ 人时留空。",
"default": "",
},
},
"required": ["msg_id"],
},
provider_name="maisaka_builtin",
provider_type="builtin",
)
def _get_target_user_info(target_message: Any) -> tuple[str, str, str]:
"""从目标消息中提取可用于构造 at 组件的用户信息。"""
message_info = getattr(target_message, "message_info", None)
user_info = getattr(message_info, "user_info", None)
target_user_id = str(getattr(user_info, "user_id", "") or "").strip()
target_user_nickname = str(getattr(user_info, "user_nickname", "") or "").strip()
target_user_cardname = str(getattr(user_info, "user_cardname", "") or "").strip()
return target_user_id, target_user_nickname, target_user_cardname
def _build_at_message_sequence(
*,
target_user_id: str,
target_user_nickname: str = "",
target_user_cardname: str = "",
text: str = "",
) -> MessageSequence:
"""构造 @ 用户的消息组件序列。"""
components = [
AtComponent(
target_user_id=target_user_id,
target_user_nickname=target_user_nickname or None,
target_user_cardname=target_user_cardname or None,
)
]
normalized_text = text.strip()
if normalized_text:
components.append(TextComponent(f" {normalized_text}"))
return MessageSequence(components=components)
async def handle_tool(
tool_ctx: "BuiltinToolRuntimeContext",
invocation: ToolInvocation,
context: Optional[ToolExecutionContext] = None,
) -> ToolExecutionResult:
"""执行 at 内置工具。"""
del context
target_message_id = str(invocation.arguments.get("msg_id") or "").strip()
text = str(invocation.arguments.get("text") or "").strip()
if not target_message_id:
return tool_ctx.build_failure_result(
invocation.tool_name,
"at 工具需要提供有效的 `msg_id` 参数。",
)
if not str(getattr(tool_ctx.runtime.chat_stream, "group_id", "") or "").strip():
return tool_ctx.build_failure_result(
invocation.tool_name,
"at 工具只能在群聊中使用。",
structured_content={"msg_id": target_message_id},
)
target_message = tool_ctx.runtime._source_messages_by_id.get(target_message_id)
if target_message is None:
return tool_ctx.build_failure_result(
invocation.tool_name,
f"未找到要 @ 的目标消息msg_id={target_message_id}",
structured_content={"msg_id": target_message_id},
)
target_user_id, target_user_nickname, target_user_cardname = _get_target_user_info(target_message)
if not target_user_id:
return tool_ctx.build_failure_result(
invocation.tool_name,
f"目标消息缺少有效用户 IDmsg_id={target_message_id}",
structured_content={"msg_id": target_message_id},
)
target_user_name = target_user_cardname or target_user_nickname or target_user_id
message_sequence = _build_at_message_sequence(
target_user_id=target_user_id,
target_user_nickname=target_user_nickname,
target_user_cardname=target_user_cardname,
text=text,
)
display_message = f"@{target_user_name}" + (f" {text}" if text else "")
try:
if tool_ctx.runtime.chat_stream.platform == CLI_PLATFORM_NAME:
render_cli_message(display_message)
tool_ctx.append_guided_reply_to_chat_history(display_message)
sent_message = None
sent = True
else:
sent_message = await send_service._send_to_target_with_message(
message_sequence=message_sequence,
stream_id=tool_ctx.runtime.session_id,
display_message=display_message,
typing=False,
storage_message=True,
show_log=True,
sync_to_maisaka_history=True,
maisaka_source_kind="guided_reply",
)
sent = sent_message is not None
except Exception as exc:
logger.exception(
f"{tool_ctx.runtime.log_prefix} 发送 at 消息时发生异常: msg_id={target_message_id} user_id={target_user_id}"
)
return tool_ctx.build_failure_result(
invocation.tool_name,
f"发送 at 消息时发生异常:{exc}",
structured_content={
"msg_id": target_message_id,
"target_user_id": target_user_id,
"target_user_name": target_user_name,
},
)
if not sent:
return tool_ctx.build_failure_result(
invocation.tool_name,
"at 消息发送失败。",
structured_content={
"msg_id": target_message_id,
"target_user_id": target_user_id,
"target_user_name": target_user_name,
},
)
sent_message_id = str(getattr(sent_message, "message_id", "") or "").strip() if sent_message is not None else ""
tool_ctx.runtime._record_reply_sent()
return tool_ctx.build_success_result(
invocation.tool_name,
f"已 @ {target_user_name}",
structured_content={
"msg_id": target_message_id,
"target_user_id": target_user_id,
"target_user_name": target_user_name,
"text": text,
"sent_message_id": sent_message_id,
},
)

View File

@@ -12,7 +12,7 @@ from src.common.logger import get_logger
from src.common.prompt_i18n import load_prompt
from src.common.utils.utils_session import SessionUtils
from src.config.config import global_config
from src.core.tooling import ToolRegistry
from src.core.tooling import ToolAvailabilityContext, ToolRegistry
from src.llm_models.model_client.base_client import BaseClient
from src.llm_models.payload_content.message import Message, MessageBuilder, RoleType
from src.llm_models.payload_content.resp_format import RespFormat
@@ -502,7 +502,13 @@ class MaisakaChatLoopService:
if tool_definitions is not None:
all_tools = list(tool_definitions)
elif self._tool_registry is not None:
tool_specs = await self._tool_registry.list_tools()
tool_specs = await self._tool_registry.list_tools(
ToolAvailabilityContext(
session_id=self._session_id,
stream_id=self._session_id,
is_group_chat=self._is_group_chat,
)
)
all_tools = [tool_spec.to_llm_definition() for tool_spec in tool_specs]
else:
all_tools = [*get_builtin_tools(), *self._extra_tools]

View File

@@ -14,14 +14,14 @@ from src.chat.message_receive.message import SessionMessage
from src.common.data_models.message_component_data_model import EmojiComponent, ImageComponent, MessageSequence
from src.common.logger import get_logger
from src.common.prompt_i18n import load_prompt
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec
from src.core.tooling import ToolAvailabilityContext, ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec
from src.llm_models.exceptions import ReqAbortException
from src.llm_models.payload_content.tool_option import ToolCall
from src.services import database_service as database_api
from src.services.memory_service import memory_service
from .builtin_tool import get_action_tool_specs
from .builtin_tool import build_builtin_tool_handlers as build_split_builtin_tool_handlers
from .builtin_tool import get_builtin_tool_visibility, is_builtin_tool_in_action_stage
from .builtin_tool import get_timing_tools
from .chat_loop_service import ChatResponse
from .chat_history_visual_refresher import refresh_chat_history_visual_placeholders
@@ -54,8 +54,6 @@ logger = get_logger("maisaka_reasoning_engine")
TIMING_GATE_CONTEXT_LIMIT = 24
TIMING_GATE_MAX_TOKENS = 384
TIMING_GATE_TOOL_NAMES = {"continue", "no_reply", "wait"}
ACTION_HIDDEN_TOOL_NAMES = {"continue", "no_reply"}
ACTION_BUILTIN_TOOL_NAMES = {tool_spec.name for tool_spec in get_action_tool_specs()}
class MaisakaReasoningEngine:
@@ -175,15 +173,19 @@ class MaisakaReasoningEngine:
self._runtime.set_current_action_tool_names([])
return [], ""
tool_specs = await self._runtime._tool_registry.list_tools()
availability_context = self._build_tool_availability_context()
tool_specs = await self._runtime._tool_registry.list_tools(availability_context)
visible_builtin_tool_specs: list[ToolSpec] = []
deferred_tool_specs: list[ToolSpec] = []
for tool_spec in tool_specs:
if tool_spec.name in ACTION_HIDDEN_TOOL_NAMES:
continue
if tool_spec.provider_name == "maisaka_builtin":
if tool_spec.name in ACTION_BUILTIN_TOOL_NAMES:
if not is_builtin_tool_in_action_stage(tool_spec):
continue
visibility = get_builtin_tool_visibility(tool_spec)
if visibility == "visible":
visible_builtin_tool_specs.append(tool_spec)
elif visibility == "deferred":
deferred_tool_specs.append(tool_spec)
continue
deferred_tool_specs.append(tool_spec)
@@ -877,6 +879,19 @@ class MaisakaReasoningEngine:
reasoning=latest_thought,
)
def _build_tool_availability_context(self) -> ToolAvailabilityContext:
"""构造当前聊天的工具暴露上下文。"""
chat_stream = self._runtime.chat_stream
return ToolAvailabilityContext(
session_id=self._runtime.session_id,
stream_id=self._runtime.session_id,
is_group_chat=chat_stream.is_group_session,
group_id=str(getattr(chat_stream, "group_id", "") or "").strip(),
user_id=str(getattr(chat_stream, "user_id", "") or "").strip(),
platform=str(getattr(chat_stream, "platform", "") or "").strip(),
)
def _build_tool_execution_context(
self,
latest_thought: str,
@@ -1197,6 +1212,8 @@ class MaisakaReasoningEngine:
source_kind="timing_gate",
)
)
if tool_call.func_name == "wait":
return
self._append_tool_execution_result(tool_call, result)
def _build_tool_result_summary(self, tool_call: ToolCall, result: ToolExecutionResult) -> str:
@@ -1290,9 +1307,10 @@ class MaisakaReasoningEngine:
return False, tool_result_summaries, tool_monitor_results
execution_context = self._build_tool_execution_context(latest_thought, anchor_message)
availability_context = self._build_tool_availability_context()
tool_spec_map = {
tool_spec.name: tool_spec
for tool_spec in await self._runtime._tool_registry.list_tools()
for tool_spec in await self._runtime._tool_registry.list_tools(availability_context)
}
total_tool_count = len(tool_calls)
for tool_index, tool_call in enumerate(tool_calls, start=1):

View File

@@ -5,7 +5,14 @@ from __future__ import annotations
from collections.abc import Awaitable, Callable
from typing import Dict, Optional
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolProvider, ToolSpec
from src.core.tooling import (
ToolAvailabilityContext,
ToolExecutionContext,
ToolExecutionResult,
ToolInvocation,
ToolProvider,
ToolSpec,
)
from .builtin_tool import get_all_builtin_tool_specs
@@ -27,10 +34,13 @@ class MaisakaBuiltinToolProvider(ToolProvider):
self._handlers = dict(handlers or {})
async def list_tools(self) -> list[ToolSpec]:
async def list_tools(
self,
context: Optional[ToolAvailabilityContext] = None,
) -> list[ToolSpec]:
"""列出全部内置工具。"""
return list(get_all_builtin_tool_specs())
return list(get_all_builtin_tool_specs(context))
async def invoke(
self,

View File

@@ -4,7 +4,14 @@ from __future__ import annotations
from typing import Optional
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolProvider, ToolSpec
from src.core.tooling import (
ToolAvailabilityContext,
ToolExecutionContext,
ToolExecutionResult,
ToolInvocation,
ToolProvider,
ToolSpec,
)
from .manager import MCPManager
@@ -24,9 +31,13 @@ class MCPToolProvider(ToolProvider):
self._manager = manager
async def list_tools(self) -> list[ToolSpec]:
async def list_tools(
self,
context: Optional[ToolAvailabilityContext] = None,
) -> list[ToolSpec]:
"""列出全部 MCP 工具。"""
del context
return self._manager.get_tool_specs()
async def invoke(

View File

@@ -11,6 +11,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional, Tupl
from src.common.logger import get_logger
from src.core.tooling import (
ToolAvailabilityContext,
ToolExecutionContext,
ToolExecutionResult,
ToolInvocation,
@@ -72,6 +73,7 @@ class ComponentQueryService:
component_type: ComponentType,
*,
enabled_only: bool = True,
context: Optional[ToolAvailabilityContext] = None,
) -> list[tuple["PluginSupervisor", "ComponentEntry"]]:
"""遍历指定类型的全部组件条目。
@@ -87,11 +89,15 @@ class ComponentQueryService:
if host_component_type is None:
return []
session_id = context.session_id if context is not None else None
is_group_chat = context.is_group_chat if context is not None else None
collected_entries: list[tuple["PluginSupervisor", "ComponentEntry"]] = []
for supervisor in self._iter_supervisors():
for component in supervisor.component_registry.get_components_by_type(
host_component_type,
enabled_only=enabled_only,
session_id=session_id,
is_group_chat=is_group_chat,
):
collected_entries.append((supervisor, component))
return collected_entries
@@ -657,7 +663,10 @@ class ComponentQueryService:
tool_entry = cast("ToolEntry", entry)
return self._build_tool_executor(supervisor, tool_entry.plugin_id, tool_entry.name, tool_entry.invoke_method)
def get_llm_available_tool_specs(self) -> Dict[str, ToolSpec]:
def get_llm_available_tool_specs(
self,
context: Optional[ToolAvailabilityContext] = None,
) -> Dict[str, ToolSpec]:
"""获取当前可供 LLM 使用的统一工具声明集合。
Returns:
@@ -665,7 +674,7 @@ class ComponentQueryService:
"""
collected_specs: Dict[str, ToolSpec] = {}
for _supervisor, entry in self._iter_component_entries(ComponentType.TOOL):
for _supervisor, entry in self._iter_component_entries(ComponentType.TOOL, context=context):
if entry.name in collected_specs:
self._log_duplicate_component(ComponentType.TOOL, entry.name)
continue

View File

@@ -10,7 +10,7 @@
"""
from enum import Enum
from typing import Any, Dict, List, Optional, Set, TypedDict, Tuple
from typing import Any, Dict, List, Literal, Optional, Set, Tuple, TypedDict
import contextlib
import re
@@ -58,6 +58,20 @@ class ComponentTypes(str, Enum):
MESSAGE_GATEWAY = "MESSAGE_GATEWAY"
ComponentChatScope = Literal["all", "group", "private"]
def _normalize_chat_scope(raw_value: Any) -> ComponentChatScope:
"""规范化组件聊天类型适用范围。"""
normalized_value = str(raw_value or "all").strip().lower()
if normalized_value == "group":
return "group"
if normalized_value == "private":
return "private"
return "all"
class StatusDict(TypedDict):
total: int
action: int
@@ -81,9 +95,17 @@ class ComponentEntry:
"enabled",
"compiled_pattern",
"disabled_session",
"chat_scope",
)
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
def __init__(
self,
name: str,
component_type: str,
plugin_id: str,
metadata: Dict[str, Any],
chat_scope: str = "all",
) -> None:
self.name: str = name
self.full_name: str = f"{plugin_id}.{name}"
self.component_type: ComponentTypes = ComponentTypes(component_type)
@@ -91,20 +113,35 @@ class ComponentEntry:
self.metadata: Dict[str, Any] = metadata
self.enabled: bool = metadata.get("enabled", True)
self.disabled_session: Set[str] = set()
self.chat_scope: ComponentChatScope = _normalize_chat_scope(chat_scope)
class ActionEntry(ComponentEntry):
"""Action 组件条目"""
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
super().__init__(name, component_type, plugin_id, metadata)
def __init__(
self,
name: str,
component_type: str,
plugin_id: str,
metadata: Dict[str, Any],
chat_scope: str = "all",
) -> None:
super().__init__(name, component_type, plugin_id, metadata, chat_scope)
class CommandEntry(ComponentEntry):
"""Command 组件条目"""
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
super().__init__(name, component_type, plugin_id, metadata)
def __init__(
self,
name: str,
component_type: str,
plugin_id: str,
metadata: Dict[str, Any],
chat_scope: str = "all",
) -> None:
super().__init__(name, component_type, plugin_id, metadata, chat_scope)
self.aliases: List[str] = metadata.get("aliases", [])
self.compiled_pattern: Optional[re.Pattern] = None
if pattern := metadata.get("command_pattern", ""):
@@ -117,7 +154,14 @@ class CommandEntry(ComponentEntry):
class ToolEntry(ComponentEntry):
"""Tool 组件条目"""
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
def __init__(
self,
name: str,
component_type: str,
plugin_id: str,
metadata: Dict[str, Any],
chat_scope: str = "all",
) -> None:
self.description: str = str(metadata.get("description", "") or "").strip()
self.brief_description: str = str(
metadata.get("brief_description", self.description) or self.description or f"工具 {name}"
@@ -128,7 +172,7 @@ class ToolEntry(ComponentEntry):
self.detailed_description: str = detailed_description
self.invoke_method: str = str(metadata.get("invoke_method", "plugin.invoke_tool") or "plugin.invoke_tool").strip()
self.legacy_component_type: str = str(metadata.get("legacy_component_type", "") or "").strip()
super().__init__(name, component_type, plugin_id, metadata)
super().__init__(name, component_type, plugin_id, metadata, chat_scope)
if not self.detailed_description:
parameters_schema = self._get_parameters_schema()
@@ -197,23 +241,37 @@ class ToolEntry(ComponentEntry):
class EventHandlerEntry(ComponentEntry):
"""EventHandler 组件条目"""
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
def __init__(
self,
name: str,
component_type: str,
plugin_id: str,
metadata: Dict[str, Any],
chat_scope: str = "all",
) -> None:
self.event_type: str = metadata.get("event_type", "")
self.weight: int = metadata.get("weight", 0)
self.intercept_message: bool = metadata.get("intercept_message", False)
super().__init__(name, component_type, plugin_id, metadata)
super().__init__(name, component_type, plugin_id, metadata, chat_scope)
class HookHandlerEntry(ComponentEntry):
"""HookHandler 组件条目。"""
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
def __init__(
self,
name: str,
component_type: str,
plugin_id: str,
metadata: Dict[str, Any],
chat_scope: str = "all",
) -> None:
self.hook: str = self._normalize_hook_name(metadata.get("hook", ""))
self.mode: str = self._normalize_mode(metadata.get("mode", "blocking"))
self.order: str = self._normalize_order(metadata.get("order", "normal"))
self.timeout_ms: int = self._normalize_timeout_ms(metadata.get("timeout_ms", 0))
self.error_policy: str = self._normalize_error_policy(metadata.get("error_policy", "skip"))
super().__init__(name, component_type, plugin_id, metadata)
super().__init__(name, component_type, plugin_id, metadata, chat_scope)
@staticmethod
def _normalize_error_policy(raw_value: Any) -> str:
@@ -332,13 +390,20 @@ class HookHandlerEntry(ComponentEntry):
class MessageGatewayEntry(ComponentEntry):
"""MessageGateway 组件条目"""
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
def __init__(
self,
name: str,
component_type: str,
plugin_id: str,
metadata: Dict[str, Any],
chat_scope: str = "all",
) -> None:
self.route_type: str = self._normalize_route_type(metadata.get("route_type", ""))
self.platform: str = str(metadata.get("platform", "") or "").strip()
self.protocol: str = str(metadata.get("protocol", "") or "").strip()
self.account_id: str = str(metadata.get("account_id", "") or "").strip()
self.scope: str = str(metadata.get("scope", "") or "").strip()
super().__init__(name, component_type, plugin_id, metadata)
super().__init__(name, component_type, plugin_id, metadata, chat_scope)
@staticmethod
def _normalize_route_type(raw_value: Any) -> str:
@@ -578,6 +643,7 @@ class ComponentRegistry:
component_type: str,
plugin_id: str,
metadata: Dict[str, Any],
chat_scope: str = "all",
) -> ComponentEntry:
"""根据声明构造组件条目。
@@ -599,18 +665,18 @@ class ComponentRegistry:
normalized_metadata = dict(metadata)
if normalized_type == ComponentTypes.ACTION:
normalized_metadata = self._convert_action_metadata_to_tool_metadata(name, normalized_metadata)
component = ToolEntry(name, ComponentTypes.TOOL.value, plugin_id, normalized_metadata)
component = ToolEntry(name, ComponentTypes.TOOL.value, plugin_id, normalized_metadata, chat_scope)
elif normalized_type == ComponentTypes.COMMAND:
component = CommandEntry(name, normalized_type.value, plugin_id, normalized_metadata)
component = CommandEntry(name, normalized_type.value, plugin_id, normalized_metadata, chat_scope)
elif normalized_type == ComponentTypes.TOOL:
component = ToolEntry(name, normalized_type.value, plugin_id, normalized_metadata)
component = ToolEntry(name, normalized_type.value, plugin_id, normalized_metadata, chat_scope)
elif normalized_type == ComponentTypes.EVENT_HANDLER:
component = EventHandlerEntry(name, normalized_type.value, plugin_id, normalized_metadata)
component = EventHandlerEntry(name, normalized_type.value, plugin_id, normalized_metadata, chat_scope)
elif normalized_type == ComponentTypes.HOOK_HANDLER:
component = HookHandlerEntry(name, normalized_type.value, plugin_id, normalized_metadata)
component = HookHandlerEntry(name, normalized_type.value, plugin_id, normalized_metadata, chat_scope)
self._validate_hook_handler_entry(component)
elif normalized_type == ComponentTypes.MESSAGE_GATEWAY:
component = MessageGatewayEntry(name, normalized_type.value, plugin_id, normalized_metadata)
component = MessageGatewayEntry(name, normalized_type.value, plugin_id, normalized_metadata, chat_scope)
else:
raise ComponentRegistrationError(
f"组件类型 {component_type} 不存在",
@@ -662,7 +728,14 @@ class ComponentRegistry:
self._by_plugin.setdefault(component.plugin_id, []).append(component)
# ====== 注册 / 注销 ======
def register_component(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> bool:
def register_component(
self,
name: str,
component_type: str,
plugin_id: str,
metadata: Dict[str, Any],
chat_scope: str = "all",
) -> bool:
"""注册单个组件。
Args:
@@ -678,7 +751,7 @@ class ComponentRegistry:
ComponentRegistrationError: 组件声明不合法时抛出。
"""
component = self._build_component_entry(name, component_type, plugin_id, metadata)
component = self._build_component_entry(name, component_type, plugin_id, metadata, chat_scope)
self._add_component_entry(component)
return True
@@ -701,14 +774,19 @@ class ComponentRegistry:
prepared_components: List[ComponentEntry] = []
for component_data in components:
raw_metadata = (
dict(component_data.get("metadata", {}))
if isinstance(component_data.get("metadata"), dict)
else {}
)
chat_scope = str(component_data.get("chat_scope", raw_metadata.pop("chat_scope", "all")) or "all")
prepared_components.append(
self._build_component_entry(
name=str(component_data.get("name", "") or ""),
component_type=str(component_data.get("component_type", "") or ""),
plugin_id=plugin_id,
metadata=component_data.get("metadata", {})
if isinstance(component_data.get("metadata"), dict)
else {},
metadata=raw_metadata,
chat_scope=chat_scope,
)
)
@@ -733,9 +811,19 @@ class ComponentRegistry:
return len(comps)
# ====== 启用 / 禁用 ======
def check_component_enabled(self, component: ComponentEntry, session_id: Optional[str] = None):
def check_component_enabled(
self,
component: ComponentEntry,
session_id: Optional[str] = None,
is_group_chat: Optional[bool] = None,
):
if session_id and session_id in component.disabled_session:
return False
if is_group_chat is not None:
if component.chat_scope == "group" and is_group_chat is not True:
return False
if component.chat_scope == "private" and is_group_chat is not False:
return False
return component.enabled
def toggle_component_status(self, full_name: str, enabled: bool, session_id: Optional[str] = None) -> bool:
@@ -806,7 +894,12 @@ class ComponentRegistry:
return self._components.get(full_name)
def get_components_by_type(
self, component_type: str, *, enabled_only: bool = True, session_id: Optional[str] = None
self,
component_type: str,
*,
enabled_only: bool = True,
session_id: Optional[str] = None,
is_group_chat: Optional[bool] = None,
) -> List[ComponentEntry]:
"""按类型查询组件
@@ -830,12 +923,16 @@ class ComponentRegistry:
if self._is_legacy_action_component(component)
]
if enabled_only:
return [component for component in action_components if self.check_component_enabled(component, session_id)]
return [
component
for component in action_components
if self.check_component_enabled(component, session_id, is_group_chat)
]
return action_components
type_dict = self._by_type.get(comp_type, {})
if enabled_only:
return [c for c in type_dict.values() if self.check_component_enabled(c, session_id)]
return [c for c in type_dict.values() if self.check_component_enabled(c, session_id, is_group_chat)]
return list(type_dict.values())
def get_components_by_plugin(

View File

@@ -4,7 +4,14 @@ from __future__ import annotations
from typing import Optional
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolProvider, ToolSpec
from src.core.tooling import (
ToolAvailabilityContext,
ToolExecutionContext,
ToolExecutionResult,
ToolInvocation,
ToolProvider,
ToolSpec,
)
from .component_query import component_query_service
@@ -15,10 +22,13 @@ class PluginToolProvider(ToolProvider):
provider_name = "plugin_runtime"
provider_type = "plugin"
async def list_tools(self) -> list[ToolSpec]:
async def list_tools(
self,
context: Optional[ToolAvailabilityContext] = None,
) -> list[ToolSpec]:
"""列出插件运行时当前可用的工具声明。"""
return list(component_query_service.get_llm_available_tool_specs().values())
return list(component_query_service.get_llm_available_tool_specs(context=context).values())
async def invoke(
self,