feat:新增按顺序选择 fix:修复timing gate意外tool问题
This commit is contained in:
148
pytests/test_maisaka_timing_gate.py
Normal file
148
pytests/test_maisaka_timing_gate.py
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from src.core.tooling import ToolExecutionResult
|
||||||
|
from src.llm_models.payload_content.tool_option import ToolCall
|
||||||
|
from src.maisaka.chat_loop_service import ChatResponse, MaisakaChatLoopService
|
||||||
|
from src.maisaka.context_messages import AssistantMessage, TIMING_GATE_INVALID_TOOL_HINT_SOURCE
|
||||||
|
from src.maisaka.reasoning_engine import MaisakaReasoningEngine
|
||||||
|
|
||||||
|
|
||||||
|
def _build_chat_response(tool_calls: list[ToolCall]) -> ChatResponse:
|
||||||
|
return ChatResponse(
|
||||||
|
content="The model returned an invalid timing tool.",
|
||||||
|
tool_calls=tool_calls,
|
||||||
|
request_messages=[],
|
||||||
|
raw_message=AssistantMessage(
|
||||||
|
content="",
|
||||||
|
timestamp=datetime.now(),
|
||||||
|
source_kind="perception",
|
||||||
|
),
|
||||||
|
selected_history_count=1,
|
||||||
|
tool_count=len(tool_calls),
|
||||||
|
prompt_tokens=10,
|
||||||
|
built_message_count=1,
|
||||||
|
completion_tokens=3,
|
||||||
|
total_tokens=13,
|
||||||
|
prompt_section=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_timing_gate_invalid_tool_defaults_to_no_reply(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
runtime = SimpleNamespace(
|
||||||
|
_force_next_timing_continue=False,
|
||||||
|
_chat_history=[],
|
||||||
|
log_prefix="[test]",
|
||||||
|
stopped=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _enter_stop_state() -> None:
|
||||||
|
runtime.stopped = True
|
||||||
|
|
||||||
|
runtime._enter_stop_state = _enter_stop_state
|
||||||
|
engine = MaisakaReasoningEngine(runtime) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
async def _fake_timing_gate_sub_agent(**kwargs: object) -> ChatResponse:
|
||||||
|
del kwargs
|
||||||
|
return _build_chat_response([
|
||||||
|
ToolCall(call_id="invalid-timing-tool", func_name="finish", args={}),
|
||||||
|
])
|
||||||
|
|
||||||
|
async def _fail_invoke_tool_call(*args: object, **kwargs: object) -> None:
|
||||||
|
del args, kwargs
|
||||||
|
raise AssertionError("invalid timing tools must not be executed")
|
||||||
|
|
||||||
|
monkeypatch.setattr(engine, "_run_timing_gate_sub_agent", _fake_timing_gate_sub_agent)
|
||||||
|
monkeypatch.setattr(engine, "_invoke_tool_call", _fail_invoke_tool_call)
|
||||||
|
|
||||||
|
action, response, tool_results, tool_monitor_results = await engine._run_timing_gate(object()) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
assert action == "no_reply"
|
||||||
|
assert response.tool_calls[0].func_name == "finish"
|
||||||
|
assert runtime.stopped is True
|
||||||
|
assert tool_monitor_results == []
|
||||||
|
assert len(runtime._chat_history) == 1
|
||||||
|
assert runtime._chat_history[0].source == TIMING_GATE_INVALID_TOOL_HINT_SOURCE
|
||||||
|
assert "finish" in runtime._chat_history[0].processed_plain_text
|
||||||
|
assert tool_results == [
|
||||||
|
"- no_reply [非法 Timing 工具]: 返回了 finish,已停止本轮并等待新消息",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_timing_gate_invalid_tool_hint_keeps_only_latest() -> None:
|
||||||
|
old_hint = SimpleNamespace(source=TIMING_GATE_INVALID_TOOL_HINT_SOURCE)
|
||||||
|
runtime = SimpleNamespace(_chat_history=[old_hint])
|
||||||
|
engine = MaisakaReasoningEngine(runtime) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
engine._append_timing_gate_invalid_tool_hint("finish")
|
||||||
|
engine._append_timing_gate_invalid_tool_hint("reply")
|
||||||
|
|
||||||
|
assert len(runtime._chat_history) == 1
|
||||||
|
hint_message = runtime._chat_history[0]
|
||||||
|
assert hint_message.source == TIMING_GATE_INVALID_TOOL_HINT_SOURCE
|
||||||
|
assert "reply" in hint_message.processed_plain_text
|
||||||
|
assert "finish" not in hint_message.processed_plain_text
|
||||||
|
|
||||||
|
|
||||||
|
def test_timing_gate_invalid_tool_hint_only_visible_to_timing_gate() -> None:
|
||||||
|
runtime = SimpleNamespace(_chat_history=[])
|
||||||
|
engine = MaisakaReasoningEngine(runtime) # type: ignore[arg-type]
|
||||||
|
engine._append_timing_gate_invalid_tool_hint("finish")
|
||||||
|
hint_message = runtime._chat_history[0]
|
||||||
|
|
||||||
|
timing_history = MaisakaChatLoopService._filter_history_for_request_kind(
|
||||||
|
[hint_message],
|
||||||
|
request_kind="timing_gate",
|
||||||
|
)
|
||||||
|
planner_history = MaisakaChatLoopService._filter_history_for_request_kind(
|
||||||
|
[hint_message],
|
||||||
|
request_kind="planner",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert timing_history == [hint_message]
|
||||||
|
assert planner_history == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_finish_tool_is_not_written_back_to_history() -> None:
|
||||||
|
finish_call = ToolCall(call_id="finish-call", func_name="finish", args={})
|
||||||
|
reply_call = ToolCall(call_id="reply-call", func_name="reply", args={})
|
||||||
|
assistant_message = AssistantMessage(
|
||||||
|
content="当前不需要继续回复。",
|
||||||
|
timestamp=datetime.now(),
|
||||||
|
tool_calls=[finish_call, reply_call],
|
||||||
|
)
|
||||||
|
runtime = SimpleNamespace(_chat_history=[assistant_message])
|
||||||
|
engine = MaisakaReasoningEngine(runtime) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
engine._append_tool_execution_result(
|
||||||
|
finish_call,
|
||||||
|
ToolExecutionResult(
|
||||||
|
tool_name="finish",
|
||||||
|
success=True,
|
||||||
|
content="当前对话循环已结束本轮思考,等待新的消息到来。",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert runtime._chat_history == [assistant_message]
|
||||||
|
assert [tool_call.func_name for tool_call in assistant_message.tool_calls] == ["reply"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_finish_tool_removes_empty_assistant_history_message() -> None:
|
||||||
|
finish_call = ToolCall(call_id="finish-call", func_name="finish", args={})
|
||||||
|
assistant_message = AssistantMessage(
|
||||||
|
content="",
|
||||||
|
timestamp=datetime.now(),
|
||||||
|
tool_calls=[finish_call],
|
||||||
|
)
|
||||||
|
runtime = SimpleNamespace(_chat_history=[assistant_message])
|
||||||
|
engine = MaisakaReasoningEngine(runtime) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
engine._append_tool_execution_result(
|
||||||
|
finish_call,
|
||||||
|
ToolExecutionResult(tool_name="finish", success=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert runtime._chat_history == []
|
||||||
@@ -57,7 +57,7 @@ MODEL_CONFIG_PATH: Path = (CONFIG_DIR / "model_config.toml").resolve().absolute(
|
|||||||
LEGACY_ENV_PATH: Path = (PROJECT_ROOT / ".env").resolve().absolute()
|
LEGACY_ENV_PATH: Path = (PROJECT_ROOT / ".env").resolve().absolute()
|
||||||
MMC_VERSION: str = "1.0.0"
|
MMC_VERSION: str = "1.0.0"
|
||||||
CONFIG_VERSION: str = "8.9.17"
|
CONFIG_VERSION: str = "8.9.17"
|
||||||
MODEL_CONFIG_VERSION: str = "1.14.2"
|
MODEL_CONFIG_VERSION: str = "1.14.3"
|
||||||
|
|
||||||
logger = get_logger("config")
|
logger = get_logger("config")
|
||||||
|
|
||||||
|
|||||||
@@ -406,9 +406,10 @@ class TaskConfig(ConfigBase):
|
|||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"x-widget": "select",
|
"x-widget": "select",
|
||||||
"x-icon": "shuffle",
|
"x-icon": "shuffle",
|
||||||
|
"options": ["balance", "random", "sequential"],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
"""模型选择策略:balance(负载均衡)或 random(随机选择)"""
|
"""模型选择策略:balance(负载均衡)、random(随机选择)或 sequential(按配置顺序优先选择)"""
|
||||||
|
|
||||||
|
|
||||||
class ModelTaskConfig(ConfigBase):
|
class ModelTaskConfig(ConfigBase):
|
||||||
|
|||||||
@@ -683,11 +683,16 @@ class LLMOrchestrator:
|
|||||||
|
|
||||||
ensure_configured_clients_loaded()
|
ensure_configured_clients_loaded()
|
||||||
|
|
||||||
strategy = self.model_for_task.selection_strategy.lower()
|
strategy = self.model_for_task.selection_strategy.strip().lower()
|
||||||
|
|
||||||
if strategy == "random":
|
if strategy == "random":
|
||||||
# 随机选择策略
|
# 随机选择策略
|
||||||
selected_model_name = random.choice(list(available_models.keys()))
|
selected_model_name = random.choice(list(available_models.keys()))
|
||||||
|
elif strategy == "sequential":
|
||||||
|
# 顺序优先策略:按照配置顺序选择第一个尚未失败的模型。
|
||||||
|
selected_model_name = next(
|
||||||
|
model_name for model_name in self.model_for_task.model_list if model_name in available_models
|
||||||
|
)
|
||||||
elif strategy == "balance":
|
elif strategy == "balance":
|
||||||
# 负载均衡策略:根据总tokens和惩罚值选择
|
# 负载均衡策略:根据总tokens和惩罚值选择
|
||||||
selected_model_name = min(
|
selected_model_name = min(
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ from .builtin_tool import get_builtin_tools
|
|||||||
from .context_messages import (
|
from .context_messages import (
|
||||||
AssistantMessage,
|
AssistantMessage,
|
||||||
LLMContextMessage,
|
LLMContextMessage,
|
||||||
|
TIMING_GATE_INVALID_TOOL_HINT_SOURCE,
|
||||||
ToolResultMessage,
|
ToolResultMessage,
|
||||||
build_llm_message_from_context,
|
build_llm_message_from_context,
|
||||||
)
|
)
|
||||||
@@ -704,6 +705,15 @@ class MaisakaChatLoopService:
|
|||||||
) -> List[LLMContextMessage]:
|
) -> List[LLMContextMessage]:
|
||||||
"""按请求类型过滤不应暴露的历史工具链。"""
|
"""按请求类型过滤不应暴露的历史工具链。"""
|
||||||
|
|
||||||
|
if request_kind == "timing_gate":
|
||||||
|
return selected_history
|
||||||
|
|
||||||
|
selected_history = [
|
||||||
|
message
|
||||||
|
for message in selected_history
|
||||||
|
if message.source != TIMING_GATE_INVALID_TOOL_HINT_SOURCE
|
||||||
|
]
|
||||||
|
|
||||||
if request_kind != "planner":
|
if request_kind != "planner":
|
||||||
return selected_history
|
return selected_history
|
||||||
|
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ from src.llm_models.payload_content.tool_option import ToolCall
|
|||||||
from .message_adapter import parse_speaker_content
|
from .message_adapter import parse_speaker_content
|
||||||
|
|
||||||
FORWARD_PREVIEW_LIMIT = 4
|
FORWARD_PREVIEW_LIMIT = 4
|
||||||
|
TIMING_GATE_INVALID_TOOL_HINT_SOURCE = "timing_gate_invalid_tool_hint"
|
||||||
|
|
||||||
|
|
||||||
def _guess_image_format(image_bytes: bytes) -> Optional[str]:
|
def _guess_image_format(image_bytes: bytes) -> Optional[str]:
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import traceback
|
|||||||
|
|
||||||
from src.chat.heart_flow.heartFC_utils import CycleDetail
|
from src.chat.heart_flow.heartFC_utils import CycleDetail
|
||||||
from src.chat.message_receive.message import SessionMessage
|
from src.chat.message_receive.message import SessionMessage
|
||||||
from src.common.data_models.message_component_data_model import EmojiComponent, ImageComponent, MessageSequence
|
from src.common.data_models.message_component_data_model import EmojiComponent, ImageComponent, MessageSequence, TextComponent
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.common.prompt_i18n import load_prompt
|
from src.common.prompt_i18n import load_prompt
|
||||||
from src.core.tooling import ToolAvailabilityContext, ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec
|
from src.core.tooling import ToolAvailabilityContext, ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec
|
||||||
@@ -31,6 +31,7 @@ from .context_messages import (
|
|||||||
ComplexSessionMessage,
|
ComplexSessionMessage,
|
||||||
LLMContextMessage,
|
LLMContextMessage,
|
||||||
SessionBackedMessage,
|
SessionBackedMessage,
|
||||||
|
TIMING_GATE_INVALID_TOOL_HINT_SOURCE,
|
||||||
ToolResultMessage,
|
ToolResultMessage,
|
||||||
contains_complex_message,
|
contains_complex_message,
|
||||||
)
|
)
|
||||||
@@ -54,6 +55,7 @@ logger = get_logger("maisaka_reasoning_engine")
|
|||||||
TIMING_GATE_CONTEXT_LIMIT = 24
|
TIMING_GATE_CONTEXT_LIMIT = 24
|
||||||
TIMING_GATE_MAX_TOKENS = 384
|
TIMING_GATE_MAX_TOKENS = 384
|
||||||
TIMING_GATE_TOOL_NAMES = {"continue", "no_reply", "wait"}
|
TIMING_GATE_TOOL_NAMES = {"continue", "no_reply", "wait"}
|
||||||
|
HISTORY_SILENT_TOOL_NAMES = {"finish"}
|
||||||
|
|
||||||
|
|
||||||
class MaisakaReasoningEngine:
|
class MaisakaReasoningEngine:
|
||||||
@@ -259,8 +261,21 @@ class MaisakaReasoningEngine:
|
|||||||
break
|
break
|
||||||
|
|
||||||
if selected_tool_call is None:
|
if selected_tool_call is None:
|
||||||
logger.warning(f"{self._runtime.log_prefix} Timing Gate 未返回有效控制工具,默认继续执行 Action Loop")
|
invalid_tool_names = [
|
||||||
return "continue", response, tool_result_summaries, tool_monitor_results
|
str(tool_call.func_name).strip()
|
||||||
|
for tool_call in response.tool_calls
|
||||||
|
if str(tool_call.func_name).strip()
|
||||||
|
]
|
||||||
|
invalid_tool_text = "、".join(invalid_tool_names) if invalid_tool_names else "无工具"
|
||||||
|
logger.warning(
|
||||||
|
f"{self._runtime.log_prefix} Timing Gate 未返回有效控制工具:{invalid_tool_text},将按 no_reply 处理"
|
||||||
|
)
|
||||||
|
self._append_timing_gate_invalid_tool_hint(invalid_tool_text)
|
||||||
|
self._runtime._enter_stop_state()
|
||||||
|
tool_result_summaries.append(
|
||||||
|
f"- no_reply [非法 Timing 工具]: 返回了 {invalid_tool_text},已停止本轮并等待新消息"
|
||||||
|
)
|
||||||
|
return "no_reply", response, tool_result_summaries, tool_monitor_results
|
||||||
|
|
||||||
append_history = False
|
append_history = False
|
||||||
store_record = selected_tool_call.func_name != "continue"
|
store_record = selected_tool_call.func_name != "continue"
|
||||||
@@ -286,9 +301,13 @@ class MaisakaReasoningEngine:
|
|||||||
timing_action = str(result.metadata.get("timing_action") or selected_tool_call.func_name).strip()
|
timing_action = str(result.metadata.get("timing_action") or selected_tool_call.func_name).strip()
|
||||||
if timing_action not in TIMING_GATE_TOOL_NAMES:
|
if timing_action not in TIMING_GATE_TOOL_NAMES:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"{self._runtime.log_prefix} Timing Gate 返回未知动作 {timing_action!r},将按 continue 处理"
|
f"{self._runtime.log_prefix} Timing Gate 返回未知动作 {timing_action!r},将按 no_reply 处理"
|
||||||
)
|
)
|
||||||
return "continue", response, tool_result_summaries, tool_monitor_results
|
self._runtime._enter_stop_state()
|
||||||
|
tool_result_summaries.append(
|
||||||
|
f"- no_reply [未知 Timing 动作]: 返回了 {timing_action!r},已停止本轮并等待新消息"
|
||||||
|
)
|
||||||
|
return "no_reply", response, tool_result_summaries, tool_monitor_results
|
||||||
return timing_action, response, tool_result_summaries, tool_monitor_results
|
return timing_action, response, tool_result_summaries, tool_monitor_results
|
||||||
|
|
||||||
def _build_forced_continue_timing_result(
|
def _build_forced_continue_timing_result(
|
||||||
@@ -324,6 +343,29 @@ class MaisakaReasoningEngine:
|
|||||||
[],
|
[],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _append_timing_gate_invalid_tool_hint(self, invalid_tool_text: str) -> None:
|
||||||
|
"""写入一条仅 Timing Gate 可见的非法工具提示,并保证最多保留最新一条。"""
|
||||||
|
|
||||||
|
self._runtime._chat_history = [
|
||||||
|
message
|
||||||
|
for message in self._runtime._chat_history
|
||||||
|
if message.source != TIMING_GATE_INVALID_TOOL_HINT_SOURCE
|
||||||
|
]
|
||||||
|
normalized_tool_text = invalid_tool_text.strip() or "无工具"
|
||||||
|
hint_content = (
|
||||||
|
"Timing Gate 上一轮选择了非法工具:"
|
||||||
|
f"{normalized_tool_text}。\n"
|
||||||
|
"Timing Gate 只能调用 continue、wait 或 no_reply 中的一个工具。"
|
||||||
|
)
|
||||||
|
self._runtime._chat_history.append(
|
||||||
|
SessionBackedMessage(
|
||||||
|
raw_message=MessageSequence([TextComponent(hint_content)]),
|
||||||
|
visible_text=hint_content,
|
||||||
|
timestamp=datetime.now(),
|
||||||
|
source_kind=TIMING_GATE_INVALID_TOOL_HINT_SOURCE,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _mark_timing_gate_completed(timing_action: str) -> bool:
|
def _mark_timing_gate_completed(timing_action: str) -> bool:
|
||||||
"""根据门控动作决定下一轮是否还需要重新执行 timing。"""
|
"""根据门控动作决定下一轮是否还需要重新执行 timing。"""
|
||||||
@@ -1210,6 +1252,10 @@ class MaisakaReasoningEngine:
|
|||||||
result: 统一工具执行结果。
|
result: 统一工具执行结果。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if tool_call.func_name in HISTORY_SILENT_TOOL_NAMES:
|
||||||
|
self._remove_tool_call_from_history(tool_call)
|
||||||
|
return
|
||||||
|
|
||||||
history_content = result.get_history_content()
|
history_content = result.get_history_content()
|
||||||
if not history_content:
|
if not history_content:
|
||||||
history_content = "工具执行成功。" if result.success else f"工具 {tool_call.func_name} 执行失败。"
|
history_content = "工具执行成功。" if result.success else f"工具 {tool_call.func_name} 执行失败。"
|
||||||
@@ -1224,6 +1270,34 @@ class MaisakaReasoningEngine:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _remove_tool_call_from_history(self, tool_call: ToolCall) -> None:
|
||||||
|
"""从历史里的 assistant 消息中移除控制类工具调用。"""
|
||||||
|
|
||||||
|
tool_call_id = str(tool_call.call_id or "").strip()
|
||||||
|
if not tool_call_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
for index in range(len(self._runtime._chat_history) - 1, -1, -1):
|
||||||
|
message = self._runtime._chat_history[index]
|
||||||
|
if not isinstance(message, AssistantMessage) or not message.tool_calls:
|
||||||
|
continue
|
||||||
|
|
||||||
|
remaining_tool_calls = [
|
||||||
|
existing_tool_call
|
||||||
|
for existing_tool_call in message.tool_calls
|
||||||
|
if str(existing_tool_call.call_id or "").strip() != tool_call_id
|
||||||
|
]
|
||||||
|
if len(remaining_tool_calls) == len(message.tool_calls):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if remaining_tool_calls:
|
||||||
|
message.tool_calls = remaining_tool_calls
|
||||||
|
elif message.content.strip():
|
||||||
|
message.tool_calls = []
|
||||||
|
else:
|
||||||
|
del self._runtime._chat_history[index]
|
||||||
|
return
|
||||||
|
|
||||||
def _append_timing_gate_execution_result(
|
def _append_timing_gate_execution_result(
|
||||||
self,
|
self,
|
||||||
response: ChatResponse,
|
response: ChatResponse,
|
||||||
|
|||||||
Reference in New Issue
Block a user