feat:优化timing门控逻辑,减少消耗,提高速度

This commit is contained in:
SengokuCola
2026-04-09 13:56:34 +08:00
parent daef71b7e9
commit b28481d205
17 changed files with 371 additions and 49 deletions

View File

@@ -22,6 +22,7 @@ from src.common.prompt_i18n import load_prompt
from src.common.utils.utils_session import SessionUtils
from src.config.config import global_config
from src.core.types import ActionInfo
from src.llm_models.payload_content.message import Message, MessageBuilder, RoleType
from src.services.llm_service import LLMServiceClient
from src.maisaka.context_messages import (
@@ -33,6 +34,7 @@ from src.maisaka.context_messages import (
)
from src.maisaka.message_adapter import parse_speaker_content
from src.maisaka.prompt_cli_renderer import PromptCLIVisualizer
from src.plugin_runtime.hook_payloads import serialize_prompt_messages
from .maisaka_expression_selector import maisaka_expression_selector
@@ -255,15 +257,15 @@ class MaisakaReplyGenerator:
return "在该聊天中的注意事项:\n" + "\n\n".join(prompt_lines) + "\n"
def _build_prompt(
def _build_request_messages(
self,
chat_history: List[LLMContextMessage],
reply_message: Optional[SessionMessage],
reply_reason: str,
expression_habits: str = "",
stream_id: Optional[str] = None,
) -> str:
"""构建 Maisaka replyer 提示词"""
) -> List[Message]:
"""构建 Maisaka replyer 请求消息列表"""
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
formatted_history = self._format_chat_history(chat_history)
target_message_block = self._build_target_message_block(reply_message)
@@ -297,7 +299,10 @@ class MaisakaReplyGenerator:
user_sections.append("现在,你说:")
user_prompt = "\n\n".join(user_sections)
return f"System: {system_prompt}\n\nUser: {user_prompt}"
return [
MessageBuilder().set_role(RoleType.System).add_text_content(system_prompt).build(),
MessageBuilder().set_role(RoleType.User).add_text_content(user_prompt).build(),
]
def _resolve_session_id(self, stream_id: Optional[str]) -> str:
"""解析当前回复使用的会话 ID。"""
@@ -425,7 +430,7 @@ class MaisakaReplyGenerator:
prompt_started_at = time.perf_counter()
try:
prompt = self._build_prompt(
request_messages = self._build_request_messages(
chat_history=filtered_history,
reply_message=reply_message,
reply_reason=reply_reason or "",
@@ -443,7 +448,9 @@ class MaisakaReplyGenerator:
return finalize(False)
prompt_ms = round((time.perf_counter() - prompt_started_at) * 1000, 2)
result.completion.request_prompt = prompt
request_prompt = PromptCLIVisualizer._build_prompt_dump_text(request_messages)
result.completion.request_prompt = request_prompt
result.request_messages = serialize_prompt_messages(request_messages)
show_replyer_prompt = bool(getattr(global_config.debug, "show_replyer_prompt", False))
show_replyer_reasoning = bool(getattr(global_config.debug, "show_replyer_reasoning", False))
preview_chat_id = self._resolve_session_id(stream_id) or "unknown"
@@ -451,22 +458,28 @@ class MaisakaReplyGenerator:
if show_replyer_prompt:
console.print(
Panel(
PromptCLIVisualizer.build_text_access_panel(
prompt,
PromptCLIVisualizer.build_prompt_access_panel(
request_messages,
category="replyer",
chat_id=preview_chat_id,
request_kind="replyer",
subtitle=f"ID: {preview_chat_id}",
selection_reason=f"ID: {preview_chat_id}",
image_display_mode="path_link" if global_config.maisaka.show_image_path else "legacy",
),
title="Maisaka 回复器 Prompt",
title="Maisaka Replyer Prompt",
border_style="bright_yellow",
padding=(0, 1),
)
)
def message_factory(_client: object) -> List[Message]:
return request_messages
llm_started_at = time.perf_counter()
try:
generation_result = await self.express_model.generate_response(prompt)
generation_result = await self.express_model.generate_response_with_messages(
message_factory=message_factory
)
except Exception as exc:
logger.exception("Maisaka 回复器调用失败")
result.error_message = str(exc)
@@ -481,7 +494,7 @@ class MaisakaReplyGenerator:
response_text = (generation_result.response or "").strip()
result.success = bool(response_text)
result.completion = LLMCompletionResult(
request_prompt=prompt,
request_prompt=request_prompt,
response_text=response_text,
reasoning_text=generation_result.reasoning or "",
model_name=generation_result.model_name or "",

View File

@@ -42,6 +42,7 @@ from src.maisaka.context_messages import (
)
from src.maisaka.message_adapter import clone_message_sequence, parse_speaker_content
from src.maisaka.prompt_cli_renderer import PromptCLIVisualizer
from src.plugin_runtime.hook_payloads import serialize_prompt_messages
from .maisaka_expression_selector import maisaka_expression_selector
@@ -267,21 +268,13 @@ class MaisakaReplyGenerator:
message: SessionBackedMessage,
default_user_name: str,
) -> Optional[Message]:
speaker_name, _ = parse_speaker_content(message.processed_plain_text.strip())
visible_speaker = speaker_name or default_user_name
raw_message = clone_message_sequence(message.raw_message)
if not raw_message.components:
raw_message = MessageSequence([TextComponent(f"[{visible_speaker}]")])
elif isinstance(raw_message.components[0], TextComponent):
first_text = raw_message.components[0].text or ""
raw_message.components[0] = TextComponent(f"[{visible_speaker}]{first_text}")
else:
raw_message.components.insert(0, TextComponent(f"[{visible_speaker}]"))
raw_message = MessageSequence([TextComponent(message.processed_plain_text)])
multimodal_message = SessionBackedMessage(
raw_message=raw_message,
visible_text=f"[{visible_speaker}]{message.processed_plain_text}",
visible_text=message.processed_plain_text,
timestamp=message.timestamp,
message_id=message.message_id,
original_message=message.original_message,
@@ -520,16 +513,18 @@ class MaisakaReplyGenerator:
return request_messages
result.completion.request_prompt = prompt_preview
result.request_messages = serialize_prompt_messages(request_messages)
preview_chat_id = self._resolve_session_id(stream_id)
replyer_prompt_section: RenderableType | None = None
if show_replyer_prompt:
replyer_prompt_section = Panel(
PromptCLIVisualizer.build_text_access_panel(
prompt_preview,
PromptCLIVisualizer.build_prompt_access_panel(
request_messages,
category="replyer",
chat_id=preview_chat_id,
request_kind="replyer",
subtitle=f"ID: {preview_chat_id}",
selection_reason=f"ID: {preview_chat_id}",
image_display_mode="path_link" if global_config.maisaka.show_image_path else "legacy",
),
title="Reply Prompt",
border_style="bright_yellow",

View File

@@ -14,6 +14,7 @@ from . import BaseDataModel
if TYPE_CHECKING:
from src.common.data_models.message_component_data_model import MessageSequence
from src.common.data_models.llm_service_data_models import PromptMessage
from src.llm_models.payload_content.tool_option import ToolCall
@@ -121,6 +122,10 @@ class ReplyGenerationResult(BaseDataModel):
default=None,
metadata={"description": "供监控层直接消费的通用 tool 展示详情。"},
)
request_messages: List["PromptMessage"] = field(
default_factory=list,
metadata={"description": "本次 replyer 实际发送给模型的消息列表。"},
)
def build_reply_monitor_detail(result: ReplyGenerationResult) -> Dict[str, Any]:
@@ -133,6 +138,8 @@ def build_reply_monitor_detail(result: ReplyGenerationResult) -> Dict[str, Any]:
if prompt_text:
detail["prompt_text"] = prompt_text
if result.request_messages:
detail["request_messages"] = result.request_messages
if reasoning_text:
detail["reasoning_text"] = reasoning_text
if output_text:

View File

@@ -233,7 +233,7 @@ class ExpressionAutoCheckTask(AsyncTask):
failed_count = 0
for i, expression in enumerate(expressions, 1):
logger.info(f"正在评估 [{i}/{len(expressions)}]: ID={expression.id}")
logger.debug(f"正在评估 [{i}/{len(expressions)}]: ID={expression.id}")
if await self._evaluate_expression(expression):
passed_count += 1

View File

@@ -10,6 +10,8 @@ from src.llm_models.payload_content.tool_option import ToolDefinitionInput
from .context import BuiltinToolRuntimeContext
from .continue_tool import get_tool_spec as get_continue_tool_spec
from .continue_tool import handle_tool as handle_continue_tool
from .finish import get_tool_spec as get_finish_tool_spec
from .finish import handle_tool as handle_finish_tool
from .no_reply import get_tool_spec as get_no_reply_tool_spec
from .no_reply import handle_tool as handle_no_reply_tool
from .query_jargon import get_tool_spec as get_query_jargon_tool_spec
@@ -44,6 +46,7 @@ def get_action_tool_specs() -> List[ToolSpec]:
"""获取 Action Loop 阶段可用的内置工具声明。"""
return [
get_finish_tool_spec(),
get_reply_tool_spec(),
get_view_complex_message_tool_spec(),
get_query_jargon_tool_spec(),
@@ -63,6 +66,7 @@ def get_all_builtin_tool_specs() -> List[ToolSpec]:
return [
*get_timing_tool_specs(),
get_finish_tool_spec(),
get_reply_tool_spec(),
get_view_complex_message_tool_spec(),
get_query_jargon_tool_spec(),
@@ -95,6 +99,7 @@ def build_builtin_tool_handlers(tool_ctx: BuiltinToolRuntimeContext) -> Dict[str
return {
"continue": lambda invocation, context=None: handle_continue_tool(tool_ctx, invocation, context),
"finish": lambda invocation, context=None: handle_finish_tool(tool_ctx, invocation, context),
"reply": lambda invocation, context=None: handle_reply_tool(tool_ctx, invocation, context),
"no_reply": lambda invocation, context=None: handle_no_reply_tool(tool_ctx, invocation, context),
"query_jargon": lambda invocation, context=None: handle_query_jargon_tool(tool_ctx, invocation, context),

View File

@@ -0,0 +1,34 @@
"""finish 内置工具。"""
from typing import Optional
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec
from .context import BuiltinToolRuntimeContext
def get_tool_spec() -> ToolSpec:
    """Return the declaration of the built-in ``finish`` tool.

    The spec advertises a tool that ends the current round of thinking and
    waits for new external messages before continuing.
    """
    # Collect the declaration fields once, then build the spec in one call.
    spec_fields = {
        "name": "finish",
        "brief_description": "结束本轮思考,等待后续新的外部消息再继续。",
        "provider_name": "maisaka_builtin",
        "provider_type": "builtin",
    }
    return ToolSpec(**spec_fields)
async def handle_tool(
    tool_ctx: BuiltinToolRuntimeContext,
    invocation: ToolInvocation,
    context: Optional[ToolExecutionContext] = None,
) -> ToolExecutionResult:
    """Execute the built-in ``finish`` tool.

    Transitions the runtime into its stop state and reports success with a
    ``pause_execution`` marker so the loop waits for new external messages.
    """
    # The shared handler signature carries an execution context we don't use.
    del context
    # Stop the reasoning loop before reporting the result back to the caller.
    tool_ctx.runtime._enter_stop_state()
    result_text = "当前对话循环已结束本轮思考,等待新的外部消息到来。"
    return tool_ctx.build_success_result(
        invocation.tool_name,
        result_text,
        metadata={"pause_execution": True},
    )

View File

@@ -52,7 +52,7 @@ logger = get_logger("maisaka_reasoning_engine")
TIMING_GATE_CONTEXT_LIMIT = 24
TIMING_GATE_MAX_TOKENS = 384
TIMING_GATE_TOOL_NAMES = {"continue", "no_reply", "wait"}
ACTION_HIDDEN_TOOL_NAMES = {"continue", "no_reply", "wait"}
ACTION_HIDDEN_TOOL_NAMES = {"continue", "no_reply"}
ACTION_BUILTIN_TOOL_NAMES = {tool_spec.name for tool_spec in get_action_tool_specs()}
@@ -297,6 +297,21 @@ class MaisakaReasoningEngine:
[f"- continue [强制跳过]: {reason}"],
)
@staticmethod
def _mark_timing_gate_completed(timing_action: str) -> bool:
"""根据门控动作决定下一轮是否还需要重新执行 timing。"""
return timing_action != "continue"
@staticmethod
def _should_retry_planner_after_interrupt(
*,
round_index: int,
max_internal_rounds: int,
has_pending_messages: bool,
) -> bool:
return has_pending_messages and round_index + 1 < max_internal_rounds
async def run_loop(self) -> None:
"""独立消费消息批次,并执行对应的内部思考轮次。"""
try:
@@ -313,7 +328,7 @@ class MaisakaReasoningEngine:
if self._runtime._has_pending_messages()
else []
)
if not timeout_triggered and not cached_messages and not message_triggered:
if not timeout_triggered and not cached_messages:
continue
self._runtime._agent_state = self._runtime._STATE_RUNNING
@@ -335,6 +350,7 @@ class MaisakaReasoningEngine:
self._trim_chat_history()
try:
timing_gate_required = True
for round_index in range(self._runtime._max_internal_rounds):
cycle_detail = self._start_cycle()
self._runtime._log_cycle_started(cycle_detail, round_index)
@@ -363,27 +379,36 @@ class MaisakaReasoningEngine:
f"{self._runtime.log_prefix} 本轮思考前已刷新 {refreshed_message_count} 条视觉占位历史消息"
)
timing_started_at = time.time()
timing_action, timing_response, timing_tool_results = await self._run_timing_gate(anchor_message)
timing_duration_ms = (time.time() - timing_started_at) * 1000
cycle_detail.time_records["timing_gate"] = timing_duration_ms / 1000
await emit_timing_gate_result(
session_id=self._runtime.session_id,
cycle_id=cycle_detail.cycle_id,
action=timing_action,
content=timing_response.content,
tool_calls=timing_response.tool_calls,
messages=[],
prompt_tokens=timing_response.prompt_tokens,
selected_history_count=timing_response.selected_history_count,
duration_ms=timing_duration_ms,
)
if timing_action != "continue":
logger.info(
f"{self._runtime.log_prefix} Timing Gate 结束当前回合: "
f"回合={round_index + 1} 动作={timing_action}"
if timing_gate_required:
timing_started_at = time.time()
timing_action, timing_response, timing_tool_results = await self._run_timing_gate(
anchor_message
)
timing_duration_ms = (time.time() - timing_started_at) * 1000
cycle_detail.time_records["timing_gate"] = timing_duration_ms / 1000
await emit_timing_gate_result(
session_id=self._runtime.session_id,
cycle_id=cycle_detail.cycle_id,
action=timing_action,
content=timing_response.content,
tool_calls=timing_response.tool_calls,
messages=[],
prompt_tokens=timing_response.prompt_tokens,
selected_history_count=timing_response.selected_history_count,
duration_ms=timing_duration_ms,
)
timing_gate_required = self._mark_timing_gate_completed(timing_action)
if timing_action != "continue":
logger.info(
f"{self._runtime.log_prefix} Timing Gate 结束当前回合: "
f"回合={round_index + 1} 动作={timing_action}"
)
break
else:
logger.info(
f"{self._runtime.log_prefix} 跳过 Timing Gate继续执行 Planner: "
f"回合={round_index + 1}"
)
break
planner_started_at = time.time()
action_tool_definitions = await self._build_action_tool_definitions()
@@ -436,7 +461,28 @@ class MaisakaReasoningEngine:
f"打断时间={interrupted_at:.3f} "
f"耗时={interrupted_at - planner_started_at:.3f}"
)
break
if not self._should_retry_planner_after_interrupt(
round_index=round_index,
max_internal_rounds=self._runtime._max_internal_rounds,
has_pending_messages=self._runtime._has_pending_messages(),
):
break
await self._runtime._wait_for_message_quiet_period()
self._runtime._message_turn_scheduled = False
interrupted_messages = self._runtime._collect_pending_messages()
if not interrupted_messages:
break
asyncio.create_task(self._runtime._trigger_batch_learning(interrupted_messages))
self._append_wait_interrupted_message_if_needed()
await self._ingest_messages(interrupted_messages)
anchor_message = interrupted_messages[-1]
logger.info(
f"{self._runtime.log_prefix} 保持活跃状态,跳过 Timing Gate 直接重试 Planner: "
f"回合={round_index + 2}"
)
continue
finally:
completed_cycle = self._end_cycle(cycle_detail)
self._runtime._render_context_usage_panel(
@@ -933,6 +979,9 @@ class MaisakaReasoningEngine:
if invocation.tool_name == "no_reply":
return "你暂停了当前对话循环,等待新的外部消息。"
if invocation.tool_name == "finish":
return "你结束了本轮思考,等待新的外部消息后再继续。"
if invocation.tool_name == "continue":
return "你允许当前对话继续进入下一轮完整思考与工具执行。"

View File

@@ -30,6 +30,7 @@ from src.mcp_module import MCPManager
from src.mcp_module.host_llm_bridge import MCPHostLLMBridge
from src.mcp_module.provider import MCPToolProvider
from src.plugin_runtime.tool_provider import PluginToolProvider
from src.plugin_runtime.hook_payloads import deserialize_prompt_messages
from .chat_loop_service import ChatResponse, MaisakaChatLoopService
from .context_messages import LLMContextMessage
@@ -941,6 +942,7 @@ class MaisakaHeartFlowChatting:
*,
tool_name: str,
prompt_text: str,
request_messages: Optional[list[Any]] = None,
tool_call_id: str,
) -> Panel:
"""将工具 prompt 渲染为可点击查看的预览入口。"""
@@ -950,6 +952,26 @@ class MaisakaHeartFlowChatting:
if tool_call_id:
subtitle += f"\n调用ID: {tool_call_id}"
if isinstance(request_messages, list) and request_messages:
try:
normalized_messages = deserialize_prompt_messages(request_messages)
except Exception as exc:
logger.warning(f"工具 {tool_name} 的 request_messages 无法反序列化,已回退为文本预览: {exc}")
else:
return Panel(
PromptCLIVisualizer.build_prompt_access_panel(
normalized_messages,
category=labels["prompt_category"],
chat_id=self.session_id,
request_kind=labels["request_kind"],
selection_reason=subtitle,
image_display_mode="path_link" if global_config.maisaka.show_image_path else "legacy",
),
title=labels["prompt_title"],
border_style="bright_yellow",
padding=(0, 1),
)
return Panel(
PromptCLIVisualizer.build_text_access_panel(
prompt_text,
@@ -1019,6 +1041,7 @@ class MaisakaHeartFlowChatting:
self._build_tool_prompt_access_panel(
tool_name=tool_name,
prompt_text=prompt_text,
request_messages=detail.get("request_messages") if isinstance(detail.get("request_messages"), list) else None,
tool_call_id=tool_call_id,
)
)