feat:采用tool索引展开方式压缩tool,移除tool过滤器

This commit is contained in:
SengokuCola
2026-04-09 20:29:01 +08:00
parent 243b8deb43
commit 0852c38e81
10 changed files with 359 additions and 319 deletions

View File

@@ -5,20 +5,18 @@ from datetime import datetime
from typing import Any, List, Optional, Sequence
import asyncio
import json
import random
from pydantic import BaseModel, Field as PydanticField
from rich.console import RenderableType
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
from src.common.logger import get_logger
from src.common.prompt_i18n import load_prompt
from src.common.utils.utils_session import SessionUtils
from src.config.config import global_config
from src.core.tooling import ToolRegistry, ToolSpec
from src.core.tooling import ToolRegistry
from src.llm_models.model_client.base_client import BaseClient
from src.llm_models.payload_content.message import Message, MessageBuilder, RoleType
from src.llm_models.payload_content.resp_format import RespFormat, RespFormatType
from src.llm_models.payload_content.resp_format import RespFormat
from src.llm_models.payload_content.tool_option import ToolCall, ToolDefinitionInput, ToolOption, normalize_tool_options
from src.plugin_runtime.hook_payloads import (
deserialize_prompt_messages,
@@ -54,13 +52,6 @@ class ChatResponse:
prompt_section: Optional[RenderableType] = None
class ToolFilterSelection(BaseModel):
    """Response schema for the tool pre-filter LLM call."""

    selected_tool_names: list[str] = PydanticField(default_factory=list)
    """Names of the candidate tools kept after pre-filtering."""
logger = get_logger("maisaka_chat_loop")
@@ -217,10 +208,6 @@ class MaisakaChatLoopService:
else:
self._chat_system_prompt = chat_system_prompt
self._llm_chat = LLMServiceClient(task_name="planner", request_type="maisaka_planner")
self._tool_filter_llm = LLMServiceClient(
task_name=global_config.maisaka.tool_filter_task_name,
request_type="maisaka_tool_filter",
)
@property
def personality_prompt(self) -> str:
@@ -399,6 +386,7 @@ class MaisakaChatLoopService:
self,
selected_history: List[LLMContextMessage],
*,
injected_user_messages: Sequence[str] | None = None,
system_prompt: Optional[str] = None,
) -> List[Message]:
"""构造发给大模型的消息列表。
@@ -420,61 +408,19 @@ class MaisakaChatLoopService:
if llm_message is not None:
messages.append(llm_message)
for injected_message in injected_user_messages or []:
normalized_message = str(injected_message or "").strip()
if not normalized_message:
continue
messages.append(
MessageBuilder()
.set_role(RoleType.User)
.add_text_content(normalized_message)
.build()
)
return messages
@staticmethod
def _is_builtin_tool_spec(tool_spec: ToolSpec) -> bool:
"""判断一个工具是否属于默认内置工具。
Args:
tool_spec: 待判断的工具声明。
Returns:
bool: 是否为默认内置工具。
"""
return tool_spec.provider_type == "builtin" or tool_spec.provider_name == "maisaka_builtin"
@classmethod
def _split_builtin_and_candidate_tools(
    cls,
    tool_specs: List[ToolSpec],
) -> tuple[List[ToolSpec], List[ToolSpec]]:
    """Partition tool declarations into built-ins and filterable candidates.

    Args:
        tool_specs: All currently available tool declarations.

    Returns:
        tuple[List[ToolSpec], List[ToolSpec]]: ``(built-in tools, filterable tools)``.
    """
    builtins: List[ToolSpec] = []
    candidates: List[ToolSpec] = []
    for spec in tool_specs:
        # Route each spec to exactly one bucket, preserving input order.
        bucket = builtins if cls._is_builtin_tool_spec(spec) else candidates
        bucket.append(spec)
    return builtins, candidates
@staticmethod
def _truncate_tool_filter_text(text: str, max_length: int = 180) -> str:
"""截断工具筛选阶段展示的文本。
Args:
text: 原始文本。
max_length: 最长保留字符数。
Returns:
str: 截断后的文本。
"""
normalized_text = text.strip()
if len(normalized_text) <= max_length:
return normalized_text
return f"{normalized_text[: max_length - 1]}"
@staticmethod
def _build_tool_names_log_text(tool_definitions: Sequence[ToolDefinitionInput]) -> str:
"""构造 planner 请求前的工具列表日志文本。
@@ -503,211 +449,11 @@ class MaisakaChatLoopService:
return "".join(tool_names)
@staticmethod
def _build_tool_spec_names_log_text(tool_specs: Sequence[ToolSpec]) -> str:
"""构造 ToolSpec 列表的工具名日志文本。"""
tool_names = [tool_spec.name for tool_spec in tool_specs if tool_spec.name]
if not tool_names:
return "[无工具]"
return "".join(tool_names)
def _build_tool_filter_prompt(
    self,
    selected_history: List[LLMContextMessage],
    candidate_tool_specs: List[ToolSpec],
    max_keep: int,
) -> str:
    """Build the small-model prompt for tool pre-filtering.

    Args:
        selected_history: Conversation context already selected for this turn.
        candidate_tool_specs: Non-built-in candidate tools.
        max_keep: Maximum number of candidate tools to keep.

    Returns:
        str: The prompt sent to the small pre-filter model.
    """
    history_lines: List[str] = []
    # Only the last 10 messages are shown, each truncated to 200 chars.
    for message in selected_history[-10:]:
        plain_text = message.processed_plain_text.strip()
        if not plain_text:
            continue
        history_lines.append(
            f"- {message.role}: {self._truncate_tool_filter_text(plain_text, max_length=200)}"
        )
    if history_lines:
        history_section = "\n".join(history_lines)
    else:
        # Placeholder line when no usable conversation context exists.
        history_section = "- 当前没有可用的对话上下文。"
    # One line per candidate tool: name plus its (possibly empty) brief description.
    tool_lines = [
        f"- {tool_spec.name}: {tool_spec.brief_description.strip() or '无简要描述'}"
        for tool_spec in candidate_tool_specs
    ]
    tool_section = "\n".join(tool_lines) if tool_lines else "- 当前没有候选工具。"
    # The instructions demand a strict JSON object so the response can be
    # parsed by _parse_tool_filter_response.
    return (
        "你是 Maisaka 的工具预筛选器。\n"
        "你的任务是在正式进入 planner 前,根据当前情景从候选工具中挑出最可能马上会用到的工具。\n"
        "默认内置工具已经自动保留,不在候选列表中,你不需要再次选择它们。\n"
        "你只能参考工具的简要描述,不要假设未描述的隐藏能力。\n"
        f"最多保留 {max_keep} 个候选工具;如果都不合适,可以返回空数组。\n"
        "请严格返回 JSON 对象,格式为:"
        '{"selected_tool_names":["工具名1","工具名2"]}\n\n'
        f"【最近对话】\n{history_section}\n\n"
        f"【候选工具(仅简要描述)】\n{tool_section}"
    )
@staticmethod
def _parse_tool_filter_response(
response_text: str,
candidate_tool_specs: List[ToolSpec],
max_keep: int,
) -> List[ToolSpec] | None:
"""解析工具预筛选响应。
Args:
response_text: 小模型返回的原始文本。
candidate_tool_specs: 非内置候选工具列表。
max_keep: 最多保留的候选工具数量。
Returns:
List[ToolSpec] | None: 成功解析时返回筛选后的工具列表;解析失败时返回 ``None``。
"""
normalized_response = response_text.strip()
if not normalized_response:
return None
selected_tool_names: List[str]
try:
selected_tool_names = ToolFilterSelection.model_validate_json(normalized_response).selected_tool_names
except Exception:
try:
parsed_payload = json.loads(normalized_response)
except json.JSONDecodeError:
return None
if isinstance(parsed_payload, dict):
raw_tool_names = parsed_payload.get("selected_tool_names", [])
elif isinstance(parsed_payload, list):
raw_tool_names = parsed_payload
else:
return None
if not isinstance(raw_tool_names, list):
return None
selected_tool_names = []
for item in raw_tool_names:
normalized_name = str(item).strip()
if normalized_name:
selected_tool_names.append(normalized_name)
candidate_map = {tool_spec.name: tool_spec for tool_spec in candidate_tool_specs}
filtered_tool_specs: List[ToolSpec] = []
seen_names: set[str] = set()
for tool_name in selected_tool_names:
normalized_name = tool_name.strip()
if not normalized_name or normalized_name in seen_names:
continue
tool_spec = candidate_map.get(normalized_name)
if tool_spec is None:
continue
seen_names.add(normalized_name)
filtered_tool_specs.append(tool_spec)
if len(filtered_tool_specs) >= max_keep:
break
return filtered_tool_specs
async def _filter_tool_specs_for_planner(
    self,
    selected_history: List[LLMContextMessage],
    tool_specs: List[ToolSpec],
) -> List[ToolSpec]:
    """Run a fast pre-filter over the tools before handing them to the planner.

    Args:
        selected_history: Conversation context already selected for this turn.
        tool_specs: All currently available tool declarations.

    Returns:
        List[ToolSpec]: The tool declarations finally given to the planner.
    """
    # Clamp config values to at least 1 to avoid degenerate thresholds.
    threshold = max(1, int(global_config.maisaka.tool_filter_threshold))
    max_keep = max(1, int(global_config.maisaka.tool_filter_max_keep))
    if len(tool_specs) <= threshold:
        # Few enough tools — skip the LLM round-trip entirely.
        return tool_specs
    builtin_tool_specs, candidate_tool_specs = self._split_builtin_and_candidate_tools(tool_specs)
    if not candidate_tool_specs:
        return tool_specs
    if len(candidate_tool_specs) <= max_keep:
        # All candidates already fit under the cap; keep built-ins plus candidates.
        return [*builtin_tool_specs, *candidate_tool_specs]
    filter_prompt = self._build_tool_filter_prompt(selected_history, candidate_tool_specs, max_keep)
    logger.info(
        "工具预筛选开始: "
        f"总工具数={len(tool_specs)} "
        f"内置工具数={len(builtin_tool_specs)} "
        f"候选工具数={len(candidate_tool_specs)} "
        f"最多保留候选数={max_keep} "
        f"过滤前全部工具名={self._build_tool_spec_names_log_text(tool_specs)}"
    )
    try:
        generation_result = await self._tool_filter_llm.generate_response(
            prompt=filter_prompt,
            options=LLMGenerationOptions(
                # Temperature 0 for a deterministic selection; the response
                # is constrained to the ToolFilterSelection JSON schema.
                temperature=0.0,
                max_tokens=256,
                response_format=RespFormat(
                    format_type=RespFormatType.JSON_SCHEMA,
                    schema=ToolFilterSelection,
                ),
            ),
        )
    except Exception as exc:
        # Fail open: any LLM error keeps the full tool set.
        logger.warning(f"工具预筛选失败,保留全部工具。错误={exc}")
        return tool_specs
    filtered_candidate_tool_specs = self._parse_tool_filter_response(
        generation_result.response or "",
        candidate_tool_specs,
        max_keep,
    )
    if filtered_candidate_tool_specs is None:
        # Unparseable response — fail open as well.
        logger.warning(
            "工具预筛选返回结果无法解析,保留全部工具。"
            f" 原始返回={generation_result.response or ''!r}"
        )
        return tool_specs
    filtered_tool_specs = [*builtin_tool_specs, *filtered_candidate_tool_specs]
    if not filtered_tool_specs:
        # Never hand the planner an empty tool list.
        logger.warning("工具预筛选得到空结果,保留全部工具以避免主流程失去工具能力。")
        return tool_specs
    logger.info(
        "工具预筛选完成: "
        f"筛选前总数={len(tool_specs)} "
        f"筛选后总数={len(filtered_tool_specs)} "
        f"过滤后全部工具名={self._build_tool_spec_names_log_text(filtered_tool_specs)} "
        f"保留候选工具={[tool_spec.name for tool_spec in filtered_candidate_tool_specs]}"
    )
    return filtered_tool_specs
async def chat_loop_step(
self,
chat_history: List[LLMContextMessage],
*,
injected_user_messages: Sequence[str] | None = None,
request_kind: str = "planner",
response_format: RespFormat | None = None,
tool_definitions: Sequence[ToolDefinitionInput] | None = None,
@@ -724,7 +470,10 @@ class MaisakaChatLoopService:
if not self._prompts_loaded:
await self.ensure_chat_prompt_loaded()
selected_history, selection_reason = self.select_llm_context_messages(chat_history)
built_messages = self._build_request_messages(selected_history)
built_messages = self._build_request_messages(
selected_history,
injected_user_messages=injected_user_messages,
)
def message_factory(_client: BaseClient) -> List[Message]:
"""返回当前轮次已经构建好的请求消息。
@@ -744,8 +493,7 @@ class MaisakaChatLoopService:
all_tools = list(tool_definitions)
elif self._tool_registry is not None:
tool_specs = await self._tool_registry.list_tools()
filtered_tool_specs = await self._filter_tool_specs_for_planner(selected_history, tool_specs)
all_tools = [tool_spec.to_llm_definition() for tool_spec in filtered_tool_specs]
all_tools = [tool_spec.to_llm_definition() for tool_spec in tool_specs]
else:
all_tools = [*get_builtin_tools(), *self._extra_tools]