feat:采用tool索引展开方式压缩tool,移除tool过滤器
This commit is contained in:
@@ -5,20 +5,18 @@ from datetime import datetime
|
||||
from typing import Any, List, Optional, Sequence
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import random
|
||||
|
||||
from pydantic import BaseModel, Field as PydanticField
|
||||
from rich.console import RenderableType
|
||||
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
|
||||
from src.common.logger import get_logger
|
||||
from src.common.prompt_i18n import load_prompt
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
from src.config.config import global_config
|
||||
from src.core.tooling import ToolRegistry, ToolSpec
|
||||
from src.core.tooling import ToolRegistry
|
||||
from src.llm_models.model_client.base_client import BaseClient
|
||||
from src.llm_models.payload_content.message import Message, MessageBuilder, RoleType
|
||||
from src.llm_models.payload_content.resp_format import RespFormat, RespFormatType
|
||||
from src.llm_models.payload_content.resp_format import RespFormat
|
||||
from src.llm_models.payload_content.tool_option import ToolCall, ToolDefinitionInput, ToolOption, normalize_tool_options
|
||||
from src.plugin_runtime.hook_payloads import (
|
||||
deserialize_prompt_messages,
|
||||
@@ -54,13 +52,6 @@ class ChatResponse:
|
||||
prompt_section: Optional[RenderableType] = None
|
||||
|
||||
|
||||
class ToolFilterSelection(BaseModel):
    # NOTE: the class docstring below is intentionally left untouched. Pydantic
    # copies a model's docstring into the generated JSON schema, and this model
    # is handed to the LLM as a response schema, so the text is runtime data.
    """工具筛选响应。"""

    # Names of the candidate tools kept after the pre-filter step.
    selected_tool_names: list[str] = PydanticField(default_factory=list)
|
||||
|
||||
|
||||
logger = get_logger("maisaka_chat_loop")
|
||||
|
||||
|
||||
@@ -217,10 +208,6 @@ class MaisakaChatLoopService:
|
||||
else:
|
||||
self._chat_system_prompt = chat_system_prompt
|
||||
self._llm_chat = LLMServiceClient(task_name="planner", request_type="maisaka_planner")
|
||||
self._tool_filter_llm = LLMServiceClient(
|
||||
task_name=global_config.maisaka.tool_filter_task_name,
|
||||
request_type="maisaka_tool_filter",
|
||||
)
|
||||
|
||||
@property
|
||||
def personality_prompt(self) -> str:
|
||||
@@ -399,6 +386,7 @@ class MaisakaChatLoopService:
|
||||
self,
|
||||
selected_history: List[LLMContextMessage],
|
||||
*,
|
||||
injected_user_messages: Sequence[str] | None = None,
|
||||
system_prompt: Optional[str] = None,
|
||||
) -> List[Message]:
|
||||
"""构造发给大模型的消息列表。
|
||||
@@ -420,61 +408,19 @@ class MaisakaChatLoopService:
|
||||
if llm_message is not None:
|
||||
messages.append(llm_message)
|
||||
|
||||
for injected_message in injected_user_messages or []:
|
||||
normalized_message = str(injected_message or "").strip()
|
||||
if not normalized_message:
|
||||
continue
|
||||
messages.append(
|
||||
MessageBuilder()
|
||||
.set_role(RoleType.User)
|
||||
.add_text_content(normalized_message)
|
||||
.build()
|
||||
)
|
||||
|
||||
return messages
|
||||
|
||||
@staticmethod
def _is_builtin_tool_spec(tool_spec: ToolSpec) -> bool:
    """Tell whether a tool declaration belongs to the default built-in set.

    Args:
        tool_spec: Tool declaration to inspect.

    Returns:
        bool: ``True`` when the provider type is ``"builtin"`` or the provider
        name is ``"maisaka_builtin"``; ``False`` otherwise.
    """
    if tool_spec.provider_type == "builtin":
        return True
    return tool_spec.provider_name == "maisaka_builtin"
|
||||
|
||||
@classmethod
def _split_builtin_and_candidate_tools(
    cls,
    tool_specs: List[ToolSpec],
) -> tuple[List[ToolSpec], List[ToolSpec]]:
    """Partition tool declarations into built-in tools and filterable candidates.

    Args:
        tool_specs: Every tool declaration currently available.

    Returns:
        tuple[List[ToolSpec], List[ToolSpec]]: ``(built-in tools, candidate
        tools)``, each keeping the input order.
    """
    partitions: tuple[List[ToolSpec], List[ToolSpec]] = ([], [])
    for spec in tool_specs:
        # Index 0 collects built-ins, index 1 collects everything else.
        bucket = partitions[0] if cls._is_builtin_tool_spec(spec) else partitions[1]
        bucket.append(spec)
    return partitions
|
||||
|
||||
@staticmethod
def _truncate_tool_filter_text(text: str, max_length: int = 180) -> str:
    """Strip *text* and cap it at ``max_length`` characters for the filter prompt.

    Args:
        text: Raw text to display.
        max_length: Maximum number of characters to keep, including the
            trailing ellipsis that marks a truncation.

    Returns:
        str: The stripped text when it already fits; otherwise the first
        ``max_length - 1`` characters followed by ``"…"``. A non-positive
        ``max_length`` yields an empty string.
    """
    normalized_text = text.strip()
    if len(normalized_text) <= max_length:
        return normalized_text
    if max_length <= 0:
        # No room even for the ellipsis. Without this guard the slice bound
        # below would go negative and keep almost the entire text.
        return ""
    return f"{normalized_text[: max_length - 1]}…"
|
||||
|
||||
@staticmethod
|
||||
def _build_tool_names_log_text(tool_definitions: Sequence[ToolDefinitionInput]) -> str:
|
||||
"""构造 planner 请求前的工具列表日志文本。
|
||||
@@ -503,211 +449,11 @@ class MaisakaChatLoopService:
|
||||
|
||||
return "、".join(tool_names)
|
||||
|
||||
@staticmethod
def _build_tool_spec_names_log_text(tool_specs: Sequence[ToolSpec]) -> str:
    """Render the names of a ToolSpec sequence as a single log-friendly string.

    Specs whose ``name`` is falsy are ignored. When nothing remains, a fixed
    placeholder is returned instead of an empty string.
    """
    collected_names = []
    for spec in tool_specs:
        if spec.name:
            collected_names.append(spec.name)

    return "、".join(collected_names) if collected_names else "[无工具]"
|
||||
|
||||
def _build_tool_filter_prompt(
    self,
    selected_history: List[LLMContextMessage],
    candidate_tool_specs: List[ToolSpec],
    max_keep: int,
) -> str:
    """Assemble the prompt given to the small model that pre-filters tools.

    Args:
        selected_history: Conversation context already chosen for the request;
            only the last ten entries are quoted.
        candidate_tool_specs: Non-built-in tools the model may choose from.
        max_keep: Upper bound on how many candidates the model may keep.

    Returns:
        str: The complete prompt text for the tool pre-filter model.
    """
    # Quote recent messages, skipping those with no visible text.
    recent_lines = [
        f"- {message.role}: {self._truncate_tool_filter_text(stripped_text, max_length=200)}"
        for message in selected_history[-10:]
        if (stripped_text := message.processed_plain_text.strip())
    ]
    if not recent_lines:
        history_section = "- 当前没有可用的对话上下文。"
    else:
        history_section = "\n".join(recent_lines)

    # One bullet per candidate, falling back to a placeholder description.
    candidate_lines = []
    for spec in candidate_tool_specs:
        description = spec.brief_description.strip() or "无简要描述"
        candidate_lines.append(f"- {spec.name}: {description}")
    if not candidate_lines:
        tool_section = "- 当前没有候选工具。"
    else:
        tool_section = "\n".join(candidate_lines)

    prompt_parts = [
        "你是 Maisaka 的工具预筛选器。\n",
        "你的任务是在正式进入 planner 前,根据当前情景从候选工具中挑出最可能马上会用到的工具。\n",
        "默认内置工具已经自动保留,不在候选列表中,你不需要再次选择它们。\n",
        "你只能参考工具的简要描述,不要假设未描述的隐藏能力。\n",
        f"最多保留 {max_keep} 个候选工具;如果都不合适,可以返回空数组。\n",
        "请严格返回 JSON 对象,格式为:",
        '{"selected_tool_names":["工具名1","工具名2"]}\n\n',
        f"【最近对话】\n{history_section}\n\n",
        f"【候选工具(仅简要描述)】\n{tool_section}",
    ]
    return "".join(prompt_parts)
|
||||
|
||||
@staticmethod
def _parse_tool_filter_response(
    response_text: str,
    candidate_tool_specs: List[ToolSpec],
    max_keep: int,
) -> List[ToolSpec] | None:
    """Parse the tool pre-filter model's answer into a list of tool specs.

    The answer is expected to be a JSON object of the form
    ``{"selected_tool_names": [...]}``; a bare JSON array of names is accepted
    as well. A surrounding Markdown code fence (three backticks, optionally
    followed by a language tag) is removed before parsing.

    Args:
        response_text: Raw text returned by the small model.
        candidate_tool_specs: Non-built-in candidate tools; selected names are
            resolved against this list.
        max_keep: Maximum number of candidates to return. A non-positive value
            yields an empty list.

    Returns:
        List[ToolSpec] | None: The selected specs in the order the model named
        them, de-duplicated and with unknown names dropped. ``None`` when the
        response is empty or cannot be interpreted.
    """
    normalized_response = response_text.strip()

    if normalized_response.startswith("```"):
        # Some models wrap JSON in a Markdown fence even when asked for raw
        # JSON; unwrap it instead of treating the answer as unparseable.
        fenced_body = normalized_response[3:]
        closing_index = fenced_body.rfind("```")
        if closing_index != -1:
            fenced_body = fenced_body[:closing_index]
        opening_line, newline, remainder = fenced_body.partition("\n")
        # Drop the rest of the opening fence line (e.g. "json") unless the
        # payload itself starts right there.
        if newline and not opening_line.lstrip().startswith(("{", "[")):
            fenced_body = remainder
        normalized_response = fenced_body.strip()

    if not normalized_response:
        return None

    selected_tool_names: List[str]
    try:
        selected_tool_names = ToolFilterSelection.model_validate_json(normalized_response).selected_tool_names
    except Exception:
        # Deliberately broad: any strict-validation failure falls through to
        # the lenient JSON parsing below rather than aborting the filter.
        try:
            parsed_payload = json.loads(normalized_response)
        except json.JSONDecodeError:
            return None

        if isinstance(parsed_payload, dict):
            raw_tool_names = parsed_payload.get("selected_tool_names", [])
        elif isinstance(parsed_payload, list):
            raw_tool_names = parsed_payload
        else:
            return None

        if not isinstance(raw_tool_names, list):
            return None

        stringified_names = (str(item).strip() for item in raw_tool_names)
        selected_tool_names = [name for name in stringified_names if name]

    candidate_map = {tool_spec.name: tool_spec for tool_spec in candidate_tool_specs}
    filtered_tool_specs: List[ToolSpec] = []
    seen_names: set[str] = set()
    for tool_name in selected_tool_names:
        # Check the bound before appending so max_keep <= 0 returns nothing.
        if len(filtered_tool_specs) >= max_keep:
            break

        normalized_name = tool_name.strip()
        if not normalized_name or normalized_name in seen_names:
            continue

        tool_spec = candidate_map.get(normalized_name)
        if tool_spec is None:
            # The model named a tool that is not among the candidates.
            continue

        seen_names.add(normalized_name)
        filtered_tool_specs.append(tool_spec)

    return filtered_tool_specs
|
||||
|
||||
async def _filter_tool_specs_for_planner(
    self,
    selected_history: List[LLMContextMessage],
    tool_specs: List[ToolSpec],
) -> List[ToolSpec]:
    """Quickly narrow the tool list before it is handed to the planner.

    Built-in tools are always kept. The remaining candidates are narrowed by a
    small-model call only when the total tool count exceeds the configured
    threshold and the candidate count exceeds the configured keep limit. Any
    failure (request error, unparseable answer, empty result) falls back to
    the unfiltered list.

    Args:
        selected_history: Conversation context already chosen for the request.
        tool_specs: Every tool declaration currently available.

    Returns:
        List[ToolSpec]: Tool declarations to pass on to the planner.
    """
    threshold = max(1, int(global_config.maisaka.tool_filter_threshold))
    max_keep = max(1, int(global_config.maisaka.tool_filter_max_keep))

    # Few enough tools overall: nothing to filter.
    if len(tool_specs) <= threshold:
        return tool_specs

    builtin_specs, candidate_specs = self._split_builtin_and_candidate_tools(tool_specs)
    if not candidate_specs:
        return tool_specs
    if len(candidate_specs) <= max_keep:
        return builtin_specs + candidate_specs

    filter_prompt = self._build_tool_filter_prompt(selected_history, candidate_specs, max_keep)
    start_message = (
        "工具预筛选开始: "
        f"总工具数={len(tool_specs)} "
        f"内置工具数={len(builtin_specs)} "
        f"候选工具数={len(candidate_specs)} "
        f"最多保留候选数={max_keep} "
        f"过滤前全部工具名={self._build_tool_spec_names_log_text(tool_specs)}"
    )
    logger.info(start_message)

    generation_options = LLMGenerationOptions(
        temperature=0.0,
        max_tokens=256,
        response_format=RespFormat(
            format_type=RespFormatType.JSON_SCHEMA,
            schema=ToolFilterSelection,
        ),
    )
    try:
        generation_result = await self._tool_filter_llm.generate_response(
            prompt=filter_prompt,
            options=generation_options,
        )
    except Exception as exc:
        # Best effort: a failed pre-filter must not block the planner.
        logger.warning(f"工具预筛选失败,保留全部工具。错误={exc}")
        return tool_specs

    raw_response = generation_result.response or ""
    kept_candidates = self._parse_tool_filter_response(raw_response, candidate_specs, max_keep)
    if kept_candidates is None:
        logger.warning(
            "工具预筛选返回结果无法解析,保留全部工具。"
            f" 原始返回={raw_response!r}"
        )
        return tool_specs

    filtered_tool_specs = builtin_specs + kept_candidates
    if not filtered_tool_specs:
        logger.warning("工具预筛选得到空结果,保留全部工具以避免主流程失去工具能力。")
        return tool_specs

    kept_candidate_names = [spec.name for spec in kept_candidates]
    logger.info(
        "工具预筛选完成: "
        f"筛选前总数={len(tool_specs)} "
        f"筛选后总数={len(filtered_tool_specs)} "
        f"过滤后全部工具名={self._build_tool_spec_names_log_text(filtered_tool_specs)} "
        f"保留候选工具={kept_candidate_names}"
    )
    return filtered_tool_specs
|
||||
|
||||
async def chat_loop_step(
|
||||
self,
|
||||
chat_history: List[LLMContextMessage],
|
||||
*,
|
||||
injected_user_messages: Sequence[str] | None = None,
|
||||
request_kind: str = "planner",
|
||||
response_format: RespFormat | None = None,
|
||||
tool_definitions: Sequence[ToolDefinitionInput] | None = None,
|
||||
@@ -724,7 +470,10 @@ class MaisakaChatLoopService:
|
||||
if not self._prompts_loaded:
|
||||
await self.ensure_chat_prompt_loaded()
|
||||
selected_history, selection_reason = self.select_llm_context_messages(chat_history)
|
||||
built_messages = self._build_request_messages(selected_history)
|
||||
built_messages = self._build_request_messages(
|
||||
selected_history,
|
||||
injected_user_messages=injected_user_messages,
|
||||
)
|
||||
|
||||
def message_factory(_client: BaseClient) -> List[Message]:
|
||||
"""返回当前轮次已经构建好的请求消息。
|
||||
@@ -744,8 +493,7 @@ class MaisakaChatLoopService:
|
||||
all_tools = list(tool_definitions)
|
||||
elif self._tool_registry is not None:
|
||||
tool_specs = await self._tool_registry.list_tools()
|
||||
filtered_tool_specs = await self._filter_tool_specs_for_planner(selected_history, tool_specs)
|
||||
all_tools = [tool_spec.to_llm_definition() for tool_spec in filtered_tool_specs]
|
||||
all_tools = [tool_spec.to_llm_definition() for tool_spec in tool_specs]
|
||||
else:
|
||||
all_tools = [*get_builtin_tools(), *self._extra_tools]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user