feat:采用tool索引展开方式压缩tool,移除tool过滤器
This commit is contained in:
@@ -21,7 +21,7 @@ from src.common.data_models.mai_message_data_model import GroupInfo, UserInfo
|
||||
from src.common.logger import get_logger
|
||||
from src.common.utils.utils_config import ChatConfigUtils, ExpressionConfigUtils
|
||||
from src.config.config import global_config
|
||||
from src.core.tooling import ToolRegistry
|
||||
from src.core.tooling import ToolRegistry, ToolSpec
|
||||
from src.learners.expression_learner import ExpressionLearner
|
||||
from src.learners.jargon_miner import JargonMiner
|
||||
from src.llm_models.payload_content.resp_format import RespFormat
|
||||
@@ -100,6 +100,9 @@ class MaisakaHeartFlowChatting:
|
||||
self._planner_interrupt_flag: Optional[asyncio.Event] = None
|
||||
self._planner_interrupt_requested = False
|
||||
self._planner_interrupt_consecutive_count = 0
|
||||
self._current_action_tool_names: set[str] = set()
|
||||
self.discovered_tool_names: set[str] = set()
|
||||
self.deferred_tool_specs_by_name: dict[str, ToolSpec] = {}
|
||||
self._planner_interrupt_max_consecutive_count = max(
|
||||
0,
|
||||
int(global_config.chat.planner_interrupt_max_consecutive_count),
|
||||
@@ -440,6 +443,117 @@ class MaisakaHeartFlowChatting:
|
||||
tool_definitions=[] if tool_definitions is None else tool_definitions,
|
||||
)
|
||||
|
||||
def set_current_action_tool_names(self, tool_names: Sequence[str]) -> None:
|
||||
"""记录当前 Action Loop 已实际暴露给 planner 的工具名集合。"""
|
||||
|
||||
self._current_action_tool_names = {tool_name for tool_name in tool_names if str(tool_name).strip()}
|
||||
|
||||
def is_action_tool_currently_available(self, tool_name: str) -> bool:
|
||||
"""判断指定工具在当前 Action Loop 轮次中是否真实可用。"""
|
||||
|
||||
normalized_name = str(tool_name).strip()
|
||||
return bool(normalized_name) and normalized_name in self._current_action_tool_names
|
||||
|
||||
def update_deferred_tool_specs(self, deferred_tool_specs: Sequence[ToolSpec]) -> None:
|
||||
"""刷新当前会话的 deferred tools 池,并清理失效的已发现工具。"""
|
||||
|
||||
next_specs_by_name: dict[str, ToolSpec] = {}
|
||||
for tool_spec in deferred_tool_specs:
|
||||
normalized_name = tool_spec.name.strip()
|
||||
if not normalized_name:
|
||||
continue
|
||||
next_specs_by_name[normalized_name] = tool_spec
|
||||
|
||||
self.deferred_tool_specs_by_name = next_specs_by_name
|
||||
self.discovered_tool_names.intersection_update(next_specs_by_name.keys())
|
||||
|
||||
def get_discovered_deferred_tool_specs(self) -> list[ToolSpec]:
|
||||
"""返回当前会话中已发现、且仍然有效的 deferred tools。"""
|
||||
|
||||
return [
|
||||
tool_spec
|
||||
for tool_name, tool_spec in self.deferred_tool_specs_by_name.items()
|
||||
if tool_name in self.discovered_tool_names
|
||||
]
|
||||
|
||||
def build_deferred_tools_reminder(self) -> str:
|
||||
"""构造供 planner 使用的 deferred tools 提示消息。"""
|
||||
|
||||
undiscovered_tool_names = [
|
||||
tool_name
|
||||
for tool_name in self.deferred_tool_specs_by_name
|
||||
if tool_name not in self.discovered_tool_names
|
||||
]
|
||||
if not undiscovered_tool_names:
|
||||
return ""
|
||||
|
||||
tool_lines = [f"{index}. {tool_name}" for index, tool_name in enumerate(undiscovered_tool_names, start=1)]
|
||||
reminder_lines = [
|
||||
"<system-reminder>",
|
||||
"以下工具当前未直接暴露给你,但可以通过 tool_search 工具发现并在后续轮次中使用:",
|
||||
*tool_lines,
|
||||
"",
|
||||
"如需其中某个工具,请先调用 tool_search。tool_search 只负责发现工具,不直接执行业务。",
|
||||
"</system-reminder>",
|
||||
]
|
||||
return "\n".join(reminder_lines)
|
||||
|
||||
def search_deferred_tool_specs(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
limit: int,
|
||||
) -> list[ToolSpec]:
|
||||
"""按名称或简要描述搜索 deferred tools。"""
|
||||
|
||||
normalized_query = " ".join(query.lower().split()).strip()
|
||||
if not normalized_query:
|
||||
return []
|
||||
|
||||
scored_matches: list[tuple[int, str, ToolSpec]] = []
|
||||
query_terms = [term for term in normalized_query.replace("_", " ").replace("-", " ").split() if term]
|
||||
for tool_name, tool_spec in self.deferred_tool_specs_by_name.items():
|
||||
lower_name = tool_name.lower()
|
||||
lower_description = tool_spec.brief_description.lower()
|
||||
score = 0
|
||||
|
||||
if normalized_query == lower_name:
|
||||
score += 1000
|
||||
if lower_name.startswith(normalized_query):
|
||||
score += 300
|
||||
if normalized_query in lower_name:
|
||||
score += 200
|
||||
if normalized_query in lower_description:
|
||||
score += 100
|
||||
|
||||
for query_term in query_terms:
|
||||
if query_term in lower_name:
|
||||
score += 25
|
||||
if query_term in lower_description:
|
||||
score += 10
|
||||
|
||||
if score <= 0:
|
||||
continue
|
||||
|
||||
scored_matches.append((score, tool_name, tool_spec))
|
||||
|
||||
scored_matches.sort(key=lambda item: (-item[0], item[1]))
|
||||
return [tool_spec for _, _, tool_spec in scored_matches[: max(1, limit)]]
|
||||
|
||||
def discover_deferred_tools(self, tool_names: Sequence[str]) -> list[str]:
|
||||
"""将指定 deferred tools 标记为已发现,并返回本次新发现的工具名。"""
|
||||
|
||||
newly_discovered_tool_names: list[str] = []
|
||||
for raw_tool_name in tool_names:
|
||||
normalized_name = str(raw_tool_name).strip()
|
||||
if not normalized_name or normalized_name not in self.deferred_tool_specs_by_name:
|
||||
continue
|
||||
if normalized_name in self.discovered_tool_names:
|
||||
continue
|
||||
self.discovered_tool_names.add(normalized_name)
|
||||
newly_discovered_tool_names.append(normalized_name)
|
||||
return newly_discovered_tool_names
|
||||
|
||||
def _has_pending_messages(self) -> bool:
|
||||
return self._last_processed_index < len(self.message_cache)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user