feat:修复门控多重result问题,新增at动作,插件现在运行chat_id指定或chat_type指定
This commit is contained in:
@@ -222,6 +222,18 @@ class ToolExecutionContext:
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ToolAvailabilityContext:
|
||||
"""工具暴露可用性判断上下文。"""
|
||||
|
||||
session_id: str = ""
|
||||
stream_id: str = ""
|
||||
is_group_chat: bool | None = None
|
||||
group_id: str = ""
|
||||
user_id: str = ""
|
||||
platform: str = ""
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ToolExecutionResult:
|
||||
"""统一工具执行结果。"""
|
||||
@@ -264,7 +276,10 @@ class ToolProvider(Protocol):
|
||||
provider_name: str
|
||||
provider_type: str
|
||||
|
||||
async def list_tools(self) -> list[ToolSpec]:
|
||||
async def list_tools(
|
||||
self,
|
||||
context: Optional[ToolAvailabilityContext] = None,
|
||||
) -> list[ToolSpec]:
|
||||
"""列出当前 Provider 暴露的全部工具。"""
|
||||
...
|
||||
|
||||
@@ -308,7 +323,10 @@ class ToolRegistry:
|
||||
|
||||
self._providers = [item for item in self._providers if item.provider_name != provider_name]
|
||||
|
||||
async def list_tools(self) -> list[ToolSpec]:
|
||||
async def list_tools(
|
||||
self,
|
||||
context: Optional[ToolAvailabilityContext] = None,
|
||||
) -> list[ToolSpec]:
|
||||
"""按 Provider 顺序列出全部去重后的工具。
|
||||
|
||||
Returns:
|
||||
@@ -319,7 +337,7 @@ class ToolRegistry:
|
||||
seen_names: set[str] = set()
|
||||
|
||||
for provider in self._providers:
|
||||
provider_specs = await provider.list_tools()
|
||||
provider_specs = await provider.list_tools(context)
|
||||
for spec in provider_specs:
|
||||
if not spec.enabled:
|
||||
continue
|
||||
@@ -332,7 +350,11 @@ class ToolRegistry:
|
||||
collected_specs.append(spec)
|
||||
return collected_specs
|
||||
|
||||
async def get_tool_spec(self, tool_name: str) -> Optional[ToolSpec]:
|
||||
async def get_tool_spec(
|
||||
self,
|
||||
tool_name: str,
|
||||
context: Optional[ToolAvailabilityContext] = None,
|
||||
) -> Optional[ToolSpec]:
|
||||
"""查询指定工具声明。
|
||||
|
||||
Args:
|
||||
@@ -342,12 +364,16 @@ class ToolRegistry:
|
||||
Optional[ToolSpec]: 匹配到的工具声明。
|
||||
"""
|
||||
|
||||
for spec in await self.list_tools():
|
||||
for spec in await self.list_tools(context):
|
||||
if spec.name == tool_name:
|
||||
return spec
|
||||
return None
|
||||
|
||||
async def has_tool(self, tool_name: str) -> bool:
|
||||
async def has_tool(
|
||||
self,
|
||||
tool_name: str,
|
||||
context: Optional[ToolAvailabilityContext] = None,
|
||||
) -> bool:
|
||||
"""判断指定工具是否存在。
|
||||
|
||||
Args:
|
||||
@@ -357,16 +383,19 @@ class ToolRegistry:
|
||||
bool: 是否存在。
|
||||
"""
|
||||
|
||||
return await self.get_tool_spec(tool_name) is not None
|
||||
return await self.get_tool_spec(tool_name, context) is not None
|
||||
|
||||
async def get_llm_definitions(self) -> list[ToolDefinitionInput]:
|
||||
async def get_llm_definitions(
|
||||
self,
|
||||
context: Optional[ToolAvailabilityContext] = None,
|
||||
) -> list[ToolDefinitionInput]:
|
||||
"""获取供 LLM 使用的工具定义列表。
|
||||
|
||||
Returns:
|
||||
list[ToolDefinitionInput]: 统一工具定义列表。
|
||||
"""
|
||||
|
||||
return [spec.to_llm_definition() for spec in await self.list_tools()]
|
||||
return [spec.to_llm_definition() for spec in await self.list_tools(context)]
|
||||
|
||||
async def invoke(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user