feat:采用tool索引展开方式压缩tool,移除tool过滤器

This commit is contained in:
SengokuCola
2026-04-09 20:29:01 +08:00
parent 243b8deb43
commit 0852c38e81
10 changed files with 359 additions and 319 deletions

View File

@@ -21,7 +21,7 @@ from src.common.data_models.mai_message_data_model import GroupInfo, UserInfo
from src.common.logger import get_logger
from src.common.utils.utils_config import ChatConfigUtils, ExpressionConfigUtils
from src.config.config import global_config
from src.core.tooling import ToolRegistry
from src.core.tooling import ToolRegistry, ToolSpec
from src.learners.expression_learner import ExpressionLearner
from src.learners.jargon_miner import JargonMiner
from src.llm_models.payload_content.resp_format import RespFormat
@@ -100,6 +100,9 @@ class MaisakaHeartFlowChatting:
self._planner_interrupt_flag: Optional[asyncio.Event] = None
self._planner_interrupt_requested = False
self._planner_interrupt_consecutive_count = 0
self._current_action_tool_names: set[str] = set()
self.discovered_tool_names: set[str] = set()
self.deferred_tool_specs_by_name: dict[str, ToolSpec] = {}
self._planner_interrupt_max_consecutive_count = max(
0,
int(global_config.chat.planner_interrupt_max_consecutive_count),
@@ -440,6 +443,117 @@ class MaisakaHeartFlowChatting:
tool_definitions=[] if tool_definitions is None else tool_definitions,
)
def set_current_action_tool_names(self, tool_names: Sequence[str]) -> None:
"""记录当前 Action Loop 已实际暴露给 planner 的工具名集合。"""
self._current_action_tool_names = {tool_name for tool_name in tool_names if str(tool_name).strip()}
def is_action_tool_currently_available(self, tool_name: str) -> bool:
"""判断指定工具在当前 Action Loop 轮次中是否真实可用。"""
normalized_name = str(tool_name).strip()
return bool(normalized_name) and normalized_name in self._current_action_tool_names
def update_deferred_tool_specs(self, deferred_tool_specs: Sequence[ToolSpec]) -> None:
"""刷新当前会话的 deferred tools 池,并清理失效的已发现工具。"""
next_specs_by_name: dict[str, ToolSpec] = {}
for tool_spec in deferred_tool_specs:
normalized_name = tool_spec.name.strip()
if not normalized_name:
continue
next_specs_by_name[normalized_name] = tool_spec
self.deferred_tool_specs_by_name = next_specs_by_name
self.discovered_tool_names.intersection_update(next_specs_by_name.keys())
def get_discovered_deferred_tool_specs(self) -> list[ToolSpec]:
"""返回当前会话中已发现、且仍然有效的 deferred tools。"""
return [
tool_spec
for tool_name, tool_spec in self.deferred_tool_specs_by_name.items()
if tool_name in self.discovered_tool_names
]
def build_deferred_tools_reminder(self) -> str:
"""构造供 planner 使用的 deferred tools 提示消息。"""
undiscovered_tool_names = [
tool_name
for tool_name in self.deferred_tool_specs_by_name
if tool_name not in self.discovered_tool_names
]
if not undiscovered_tool_names:
return ""
tool_lines = [f"{index}. {tool_name}" for index, tool_name in enumerate(undiscovered_tool_names, start=1)]
reminder_lines = [
"<system-reminder>",
"以下工具当前未直接暴露给你,但可以通过 tool_search 工具发现并在后续轮次中使用:",
*tool_lines,
"",
"如需其中某个工具,请先调用 tool_search。tool_search 只负责发现工具,不直接执行业务。",
"</system-reminder>",
]
return "\n".join(reminder_lines)
def search_deferred_tool_specs(
self,
query: str,
*,
limit: int,
) -> list[ToolSpec]:
"""按名称或简要描述搜索 deferred tools。"""
normalized_query = " ".join(query.lower().split()).strip()
if not normalized_query:
return []
scored_matches: list[tuple[int, str, ToolSpec]] = []
query_terms = [term for term in normalized_query.replace("_", " ").replace("-", " ").split() if term]
for tool_name, tool_spec in self.deferred_tool_specs_by_name.items():
lower_name = tool_name.lower()
lower_description = tool_spec.brief_description.lower()
score = 0
if normalized_query == lower_name:
score += 1000
if lower_name.startswith(normalized_query):
score += 300
if normalized_query in lower_name:
score += 200
if normalized_query in lower_description:
score += 100
for query_term in query_terms:
if query_term in lower_name:
score += 25
if query_term in lower_description:
score += 10
if score <= 0:
continue
scored_matches.append((score, tool_name, tool_spec))
scored_matches.sort(key=lambda item: (-item[0], item[1]))
return [tool_spec for _, _, tool_spec in scored_matches[: max(1, limit)]]
def discover_deferred_tools(self, tool_names: Sequence[str]) -> list[str]:
"""将指定 deferred tools 标记为已发现,并返回本次新发现的工具名。"""
newly_discovered_tool_names: list[str] = []
for raw_tool_name in tool_names:
normalized_name = str(raw_tool_name).strip()
if not normalized_name or normalized_name not in self.deferred_tool_specs_by_name:
continue
if normalized_name in self.discovered_tool_names:
continue
self.discovered_tool_names.add(normalized_name)
newly_discovered_tool_names.append(normalized_name)
return newly_discovered_tool_names
def _has_pending_messages(self) -> bool:
return self._last_processed_index < len(self.message_cache)