@@ -6,7 +6,9 @@ import { useTour } from './use-tour'
|
|||||||
// Joyride 主题配置
|
// Joyride 主题配置
|
||||||
const joyrideStyles = {
|
const joyrideStyles = {
|
||||||
options: {
|
options: {
|
||||||
zIndex: 10000,
|
// 提到 portal 容器(99999)之上,确保 overlay/spotlight/tooltip 都在最上层;
|
||||||
|
// overlay 的 z-index 由 react-joyride 内部基于 options.zIndex 推算,必须大于 floater 才能让 tooltip 按钮可点击。
|
||||||
|
zIndex: 100000,
|
||||||
primaryColor: 'hsl(var(--color-primary))',
|
primaryColor: 'hsl(var(--color-primary))',
|
||||||
textColor: 'hsl(var(--color-foreground))',
|
textColor: 'hsl(var(--color-foreground))',
|
||||||
backgroundColor: 'hsl(var(--color-background))',
|
backgroundColor: 'hsl(var(--color-background))',
|
||||||
@@ -197,13 +199,6 @@ export function TourRenderer() {
|
|||||||
locale={locale}
|
locale={locale}
|
||||||
scrollOffset={80}
|
scrollOffset={80}
|
||||||
scrollToFirstStep
|
scrollToFirstStep
|
||||||
floaterProps={{
|
|
||||||
styles: {
|
|
||||||
floater: {
|
|
||||||
zIndex: 99999,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}}
|
|
||||||
/>
|
/>
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from src.core.tooling import ToolExecutionResult, ToolInvocation
|
from src.core.tooling import ToolExecutionResult, ToolInvocation
|
||||||
@@ -8,6 +9,7 @@ from src.llm_models.payload_content.tool_option import ToolCall
|
|||||||
from src.maisaka.chat_loop_service import ChatResponse, MaisakaChatLoopService
|
from src.maisaka.chat_loop_service import ChatResponse, MaisakaChatLoopService
|
||||||
from src.maisaka.context_messages import AssistantMessage, TIMING_GATE_INVALID_TOOL_HINT_SOURCE
|
from src.maisaka.context_messages import AssistantMessage, TIMING_GATE_INVALID_TOOL_HINT_SOURCE
|
||||||
from src.maisaka.reasoning_engine import MaisakaReasoningEngine
|
from src.maisaka.reasoning_engine import MaisakaReasoningEngine
|
||||||
|
from src.maisaka.runtime import MaisakaHeartFlowChatting
|
||||||
|
|
||||||
|
|
||||||
def _build_chat_response(tool_calls: list[ToolCall]) -> ChatResponse:
|
def _build_chat_response(tool_calls: list[ToolCall]) -> ChatResponse:
|
||||||
@@ -173,6 +175,29 @@ def test_timing_gate_invalid_tool_hint_only_visible_to_timing_gate() -> None:
|
|||||||
assert planner_history == []
|
assert planner_history == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_forced_timing_trigger_bypasses_message_frequency_threshold() -> None:
|
||||||
|
runtime = SimpleNamespace(
|
||||||
|
_STATE_WAIT="wait",
|
||||||
|
_agent_state="stop",
|
||||||
|
_message_turn_scheduled=False,
|
||||||
|
_internal_turn_queue=asyncio.Queue(),
|
||||||
|
_has_pending_messages=lambda: True,
|
||||||
|
_get_pending_message_count=lambda: 1,
|
||||||
|
_has_forced_timing_trigger=lambda: True,
|
||||||
|
_cancel_deferred_message_turn_task=lambda: None,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _fail_get_message_trigger_threshold() -> int:
|
||||||
|
raise AssertionError("@/提及必回不应被普通聊天频率阈值拦住")
|
||||||
|
|
||||||
|
runtime._get_message_trigger_threshold = _fail_get_message_trigger_threshold
|
||||||
|
|
||||||
|
MaisakaHeartFlowChatting._schedule_message_turn(runtime) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
assert runtime._message_turn_scheduled is True
|
||||||
|
assert runtime._internal_turn_queue.get_nowait() == "message"
|
||||||
|
|
||||||
|
|
||||||
def test_finish_tool_is_not_written_back_to_history() -> None:
|
def test_finish_tool_is_not_written_back_to_history() -> None:
|
||||||
finish_call = ToolCall(call_id="finish-call", func_name="finish", args={})
|
finish_call = ToolCall(call_id="finish-call", func_name="finish", args={})
|
||||||
reply_call = ToolCall(call_id="reply-call", func_name="reply", args={})
|
reply_call = ToolCall(call_id="reply-call", func_name="reply", args={})
|
||||||
@@ -213,3 +238,47 @@ def test_finish_tool_removes_empty_assistant_history_message() -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert runtime._chat_history == []
|
assert runtime._chat_history == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_timing_gate_head_trim_keeps_short_history() -> None:
|
||||||
|
messages = [
|
||||||
|
AssistantMessage(content="第一条消息", timestamp=datetime.now()),
|
||||||
|
AssistantMessage(content="第二条消息", timestamp=datetime.now()),
|
||||||
|
]
|
||||||
|
|
||||||
|
trimmed_messages = MaisakaHeartFlowChatting._drop_head_context_messages(
|
||||||
|
messages,
|
||||||
|
drop_context_count=3,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert trimmed_messages == messages
|
||||||
|
|
||||||
|
|
||||||
|
def test_timing_gate_head_trim_keeps_history_within_config_limit() -> None:
|
||||||
|
messages = [
|
||||||
|
AssistantMessage(content=f"消息 {index}", timestamp=datetime.now())
|
||||||
|
for index in range(10)
|
||||||
|
]
|
||||||
|
|
||||||
|
trimmed_messages = MaisakaHeartFlowChatting._drop_head_context_messages(
|
||||||
|
messages,
|
||||||
|
drop_context_count=7,
|
||||||
|
trim_threshold_context_count=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert trimmed_messages == messages
|
||||||
|
|
||||||
|
|
||||||
|
def test_timing_gate_head_trim_applies_after_config_limit_exceeded() -> None:
|
||||||
|
messages = [
|
||||||
|
AssistantMessage(content=f"消息 {index}", timestamp=datetime.now())
|
||||||
|
for index in range(11)
|
||||||
|
]
|
||||||
|
|
||||||
|
trimmed_messages = MaisakaHeartFlowChatting._drop_head_context_messages(
|
||||||
|
messages,
|
||||||
|
drop_context_count=7,
|
||||||
|
trim_threshold_context_count=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert trimmed_messages == messages[7:]
|
||||||
|
|||||||
@@ -4,8 +4,6 @@ from collections.abc import Awaitable, Callable, Sequence
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Optional, TYPE_CHECKING
|
from typing import Any, Optional, TYPE_CHECKING
|
||||||
|
|
||||||
import random
|
|
||||||
|
|
||||||
from src.chat.message_receive.chat_manager import chat_manager
|
from src.chat.message_receive.chat_manager import chat_manager
|
||||||
from src.cli.maisaka_cli_sender import CLI_PLATFORM_NAME, render_cli_message
|
from src.cli.maisaka_cli_sender import CLI_PLATFORM_NAME, render_cli_message
|
||||||
from src.common.data_models.image_data_model import MaiEmoji
|
from src.common.data_models.image_data_model import MaiEmoji
|
||||||
@@ -121,45 +119,13 @@ def _normalize_emotions(emoji: MaiEmoji) -> list[str]:
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
async def select_emoji_for_maisaka(
|
|
||||||
*,
|
|
||||||
requested_emotion: str = "",
|
|
||||||
reasoning: str = "",
|
|
||||||
context_texts: Sequence[str] | None = None,
|
|
||||||
sample_size: int = 30,
|
|
||||||
) -> tuple[MaiEmoji | None, str]:
|
|
||||||
"""为 Maisaka 选择一个合适的表情。"""
|
|
||||||
|
|
||||||
del reasoning, context_texts
|
|
||||||
|
|
||||||
available_emojis = list(emoji_manager.emojis)
|
|
||||||
if not available_emojis:
|
|
||||||
return None, ""
|
|
||||||
|
|
||||||
normalized_requested_emotion = requested_emotion.strip()
|
|
||||||
if normalized_requested_emotion:
|
|
||||||
matched_emojis = [
|
|
||||||
emoji
|
|
||||||
for emoji in available_emojis
|
|
||||||
if normalized_requested_emotion.lower() in (emotion.lower() for emotion in _normalize_emotions(emoji))
|
|
||||||
]
|
|
||||||
if matched_emojis:
|
|
||||||
return random.choice(matched_emojis), normalized_requested_emotion
|
|
||||||
|
|
||||||
sampled_emojis = random.sample(
|
|
||||||
available_emojis,
|
|
||||||
min(max(sample_size, 1), len(available_emojis)),
|
|
||||||
)
|
|
||||||
return random.choice(sampled_emojis), ""
|
|
||||||
|
|
||||||
|
|
||||||
async def send_emoji_for_maisaka(
|
async def send_emoji_for_maisaka(
|
||||||
*,
|
*,
|
||||||
stream_id: str,
|
stream_id: str,
|
||||||
|
emoji_selector: EmojiSelector,
|
||||||
requested_emotion: str = "",
|
requested_emotion: str = "",
|
||||||
reasoning: str = "",
|
reasoning: str = "",
|
||||||
context_texts: Sequence[str] | None = None,
|
context_texts: Sequence[str] | None = None,
|
||||||
emoji_selector: EmojiSelector | None = None,
|
|
||||||
) -> MaisakaEmojiSendResult:
|
) -> MaisakaEmojiSendResult:
|
||||||
"""为 Maisaka 选择并发送一个表情。"""
|
"""为 Maisaka 选择并发送一个表情。"""
|
||||||
|
|
||||||
@@ -194,20 +160,12 @@ async def send_emoji_for_maisaka(
|
|||||||
normalized_context_texts = _normalize_context_texts(before_select_kwargs.get("context_texts"))
|
normalized_context_texts = _normalize_context_texts(before_select_kwargs.get("context_texts"))
|
||||||
sample_size = _coerce_positive_int(before_select_kwargs.get("sample_size"), sample_size)
|
sample_size = _coerce_positive_int(before_select_kwargs.get("sample_size"), sample_size)
|
||||||
|
|
||||||
if emoji_selector is None:
|
selected_emoji, matched_emotion = await emoji_selector(
|
||||||
selected_emoji, matched_emotion = await select_emoji_for_maisaka(
|
normalized_requested_emotion,
|
||||||
requested_emotion=normalized_requested_emotion,
|
normalized_reasoning,
|
||||||
reasoning=normalized_reasoning,
|
normalized_context_texts,
|
||||||
context_texts=normalized_context_texts,
|
sample_size,
|
||||||
sample_size=sample_size,
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
selected_emoji, matched_emotion = await emoji_selector(
|
|
||||||
normalized_requested_emotion,
|
|
||||||
normalized_reasoning,
|
|
||||||
normalized_context_texts,
|
|
||||||
sample_size,
|
|
||||||
)
|
|
||||||
after_select_result = await _get_runtime_manager().invoke_hook(
|
after_select_result = await _get_runtime_manager().invoke_hook(
|
||||||
"emoji.maisaka.after_select",
|
"emoji.maisaka.after_select",
|
||||||
stream_id=stream_id,
|
stream_id=stream_id,
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
from json import dumps
|
||||||
from random import sample
|
from random import sample
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
@@ -17,9 +18,8 @@ from src.emoji_system.maisaka_tool import send_emoji_for_maisaka
|
|||||||
from src.common.data_models.image_data_model import MaiEmoji
|
from src.common.data_models.image_data_model import MaiEmoji
|
||||||
from src.common.data_models.message_component_data_model import ImageComponent, MessageSequence, TextComponent
|
from src.common.data_models.message_component_data_model import ImageComponent, MessageSequence, TextComponent
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config
|
from src.config.config import config_manager, global_config
|
||||||
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec
|
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec
|
||||||
from src.llm_models.payload_content.resp_format import RespFormat, RespFormatType
|
|
||||||
from src.llm_models.payload_content.message import MessageBuilder, RoleType
|
from src.llm_models.payload_content.message import MessageBuilder, RoleType
|
||||||
from src.maisaka.context_messages import (
|
from src.maisaka.context_messages import (
|
||||||
LLMContextMessage,
|
LLMContextMessage,
|
||||||
@@ -221,6 +221,7 @@ def _build_send_emoji_monitor_detail(
|
|||||||
detail: Dict[str, Any] = {}
|
detail: Dict[str, Any] = {}
|
||||||
if isinstance(request_messages, list) and request_messages:
|
if isinstance(request_messages, list) and request_messages:
|
||||||
detail["request_messages"] = request_messages
|
detail["request_messages"] = request_messages
|
||||||
|
detail["prompt_text"] = dumps(request_messages, ensure_ascii=False, indent=2)
|
||||||
if reasoning_text.strip():
|
if reasoning_text.strip():
|
||||||
detail["reasoning_text"] = reasoning_text.strip()
|
detail["reasoning_text"] = reasoning_text.strip()
|
||||||
if output_text.strip():
|
if output_text.strip():
|
||||||
@@ -279,6 +280,24 @@ def _build_send_emoji_monitor_metadata(
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_emoji_selector_model_task_name() -> str:
|
||||||
|
"""根据 planner 模型视觉能力选择表情选择子代理的模型任务。"""
|
||||||
|
|
||||||
|
model_config = config_manager.get_model_config()
|
||||||
|
planner_models = [
|
||||||
|
model_name
|
||||||
|
for model_name in model_config.model_task_config.planner.model_list
|
||||||
|
if str(model_name).strip()
|
||||||
|
]
|
||||||
|
models_by_name = {model.name: model for model in model_config.models}
|
||||||
|
if planner_models and all(
|
||||||
|
model_name in models_by_name and models_by_name[model_name].visual
|
||||||
|
for model_name in planner_models
|
||||||
|
):
|
||||||
|
return "planner"
|
||||||
|
return "vlm"
|
||||||
|
|
||||||
|
|
||||||
async def _select_emoji_with_sub_agent(
|
async def _select_emoji_with_sub_agent(
|
||||||
tool_ctx: BuiltinToolRuntimeContext,
|
tool_ctx: BuiltinToolRuntimeContext,
|
||||||
reasoning: str,
|
reasoning: str,
|
||||||
@@ -326,7 +345,8 @@ async def _select_emoji_with_sub_agent(
|
|||||||
prompt_llm_message = prompt_message.to_llm_message()
|
prompt_llm_message = prompt_message.to_llm_message()
|
||||||
if prompt_llm_message is not None:
|
if prompt_llm_message is not None:
|
||||||
request_messages.append(prompt_llm_message)
|
request_messages.append(prompt_llm_message)
|
||||||
candidate_llm_message = candidate_message.to_llm_message()
|
candidate_to_llm_message = getattr(candidate_message, "to_llm_message", None)
|
||||||
|
candidate_llm_message = candidate_to_llm_message() if callable(candidate_to_llm_message) else None
|
||||||
if candidate_llm_message is not None:
|
if candidate_llm_message is not None:
|
||||||
request_messages.append(candidate_llm_message)
|
request_messages.append(candidate_llm_message)
|
||||||
serialized_request_messages = serialize_prompt_messages(request_messages)
|
serialized_request_messages = serialize_prompt_messages(request_messages)
|
||||||
@@ -337,10 +357,7 @@ async def _select_emoji_with_sub_agent(
|
|||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
extra_messages=[prompt_message, candidate_message],
|
extra_messages=[prompt_message, candidate_message],
|
||||||
max_tokens=_EMOJI_SUB_AGENT_MAX_TOKENS,
|
max_tokens=_EMOJI_SUB_AGENT_MAX_TOKENS,
|
||||||
response_format=RespFormat(
|
model_task_name=_resolve_emoji_selector_model_task_name(),
|
||||||
format_type=RespFormatType.JSON_SCHEMA,
|
|
||||||
schema=EmojiSelectionResult,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
selection_duration_ms = round((datetime.now() - selection_started_at).total_seconds() * 1000, 2)
|
selection_duration_ms = round((datetime.now() - selection_started_at).total_seconds() * 1000, 2)
|
||||||
|
|
||||||
@@ -409,12 +426,16 @@ async def handle_tool(
|
|||||||
"reason": "",
|
"reason": "",
|
||||||
}
|
}
|
||||||
selection_metadata: Dict[str, Any] = {"reason": "", "monitor_detail": {}}
|
selection_metadata: Dict[str, Any] = {"reason": "", "monitor_detail": {}}
|
||||||
|
requested_emotion = ""
|
||||||
|
if isinstance(invocation.arguments, dict):
|
||||||
|
requested_emotion = str(invocation.arguments.get("emotion") or "").strip()
|
||||||
|
|
||||||
logger.info(f"{tool_ctx.runtime.log_prefix} 触发表情包发送工具")
|
logger.info(f"{tool_ctx.runtime.log_prefix} 触发表情包发送工具")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
send_result = await send_emoji_for_maisaka(
|
send_result = await send_emoji_for_maisaka(
|
||||||
stream_id=tool_ctx.runtime.session_id,
|
stream_id=tool_ctx.runtime.session_id,
|
||||||
|
requested_emotion=requested_emotion,
|
||||||
reasoning=tool_ctx.engine.last_reasoning_content,
|
reasoning=tool_ctx.engine.last_reasoning_content,
|
||||||
context_texts=context_texts,
|
context_texts=context_texts,
|
||||||
emoji_selector=lambda _requested_emotion, reasoning, context_texts, sample_size: _select_emoji_with_sub_agent(
|
emoji_selector=lambda _requested_emotion, reasoning, context_texts, sample_size: _select_emoji_with_sub_agent(
|
||||||
|
|||||||
@@ -194,6 +194,7 @@ class MaisakaChatLoopService:
|
|||||||
session_id: Optional[str] = None,
|
session_id: Optional[str] = None,
|
||||||
is_group_chat: Optional[bool] = None,
|
is_group_chat: Optional[bool] = None,
|
||||||
max_tokens: int = 2048,
|
max_tokens: int = 2048,
|
||||||
|
model_task_name: str = "planner",
|
||||||
) -> None:
|
) -> None:
|
||||||
"""初始化 Maisaka 对话循环服务。
|
"""初始化 Maisaka 对话循环服务。
|
||||||
|
|
||||||
@@ -205,6 +206,7 @@ class MaisakaChatLoopService:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
self._max_tokens = max_tokens
|
self._max_tokens = max_tokens
|
||||||
|
self._model_task_name = model_task_name.strip() or "planner"
|
||||||
self._is_group_chat = is_group_chat
|
self._is_group_chat = is_group_chat
|
||||||
self._session_id = session_id or ""
|
self._session_id = session_id or ""
|
||||||
self._extra_tools: List[ToolOption] = []
|
self._extra_tools: List[ToolOption] = []
|
||||||
@@ -236,17 +238,18 @@ class MaisakaChatLoopService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _get_llm_chat_client(self, request_kind: str) -> LLMServiceClient:
|
def _get_llm_chat_client(self, request_kind: str) -> LLMServiceClient:
|
||||||
"""获取当前请求类型对应的 planner LLM 客户端。"""
|
"""获取当前请求类型对应的 LLM 客户端。"""
|
||||||
|
|
||||||
request_type = self._resolve_llm_request_type(request_kind)
|
request_type = self._resolve_llm_request_type(request_kind)
|
||||||
llm_client = self._llm_chat_clients.get(request_type)
|
client_key = f"{self._model_task_name}:{request_type}"
|
||||||
|
llm_client = self._llm_chat_clients.get(client_key)
|
||||||
if llm_client is None:
|
if llm_client is None:
|
||||||
llm_client = LLMServiceClient(
|
llm_client = LLMServiceClient(
|
||||||
task_name="planner",
|
task_name=self._model_task_name,
|
||||||
request_type=request_type,
|
request_type=request_type,
|
||||||
session_id=self._session_id,
|
session_id=self._session_id,
|
||||||
)
|
)
|
||||||
self._llm_chat_clients[request_type] = llm_client
|
self._llm_chat_clients[client_key] = llm_client
|
||||||
return llm_client
|
return llm_client
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@@ -473,13 +473,18 @@ class MaisakaHeartFlowChatting:
|
|||||||
def _update_message_trigger_state(self, message: SessionMessage) -> None:
|
def _update_message_trigger_state(self, message: SessionMessage) -> None:
|
||||||
"""补齐消息中的 @/提及 标记,并在命中时启用强制 continue。"""
|
"""补齐消息中的 @/提及 标记,并在命中时启用强制 continue。"""
|
||||||
|
|
||||||
detected_mentioned, detected_at, _ = is_mentioned_bot_in_message(message)
|
detected_mentioned, detected_at, reply_probability_boost = is_mentioned_bot_in_message(message)
|
||||||
if detected_at:
|
if detected_at:
|
||||||
message.is_at = True
|
message.is_at = True
|
||||||
if detected_mentioned:
|
if detected_mentioned:
|
||||||
message.is_mentioned = True
|
message.is_mentioned = True
|
||||||
|
|
||||||
if not message.is_at and not message.is_mentioned:
|
should_force_reply = (
|
||||||
|
reply_probability_boost >= 1.0
|
||||||
|
or (message.is_at and global_config.chat.at_bot_inevitable_reply)
|
||||||
|
or (message.is_mentioned and global_config.chat.mentioned_bot_reply)
|
||||||
|
)
|
||||||
|
if not should_force_reply or (not message.is_at and not message.is_mentioned):
|
||||||
return
|
return
|
||||||
|
|
||||||
self._arm_force_next_timing_continue(
|
self._arm_force_next_timing_continue(
|
||||||
@@ -537,6 +542,11 @@ class MaisakaHeartFlowChatting:
|
|||||||
self._force_next_timing_reason = ""
|
self._force_next_timing_reason = ""
|
||||||
return reason
|
return reason
|
||||||
|
|
||||||
|
def _has_forced_timing_trigger(self) -> bool:
|
||||||
|
"""判断是否已有 @/提及必回触发,需绕过普通频率阈值。"""
|
||||||
|
|
||||||
|
return self._force_next_timing_continue
|
||||||
|
|
||||||
def _bind_planner_interrupt_flag(self, interrupt_flag: asyncio.Event) -> None:
|
def _bind_planner_interrupt_flag(self, interrupt_flag: asyncio.Event) -> None:
|
||||||
"""绑定当前可打断请求使用的中断标记。"""
|
"""绑定当前可打断请求使用的中断标记。"""
|
||||||
self._planner_interrupt_flag = interrupt_flag
|
self._planner_interrupt_flag = interrupt_flag
|
||||||
@@ -590,6 +600,7 @@ class MaisakaHeartFlowChatting:
|
|||||||
extra_messages: Optional[Sequence[LLMContextMessage]] = None,
|
extra_messages: Optional[Sequence[LLMContextMessage]] = None,
|
||||||
interrupt_flag: asyncio.Event | None = None,
|
interrupt_flag: asyncio.Event | None = None,
|
||||||
max_tokens: int = 512,
|
max_tokens: int = 512,
|
||||||
|
model_task_name: str = "planner",
|
||||||
response_format: RespFormat | None = None,
|
response_format: RespFormat | None = None,
|
||||||
tool_definitions: Optional[Sequence[ToolDefinitionInput]] = None,
|
tool_definitions: Optional[Sequence[ToolDefinitionInput]] = None,
|
||||||
) -> ChatResponse:
|
) -> ChatResponse:
|
||||||
@@ -603,6 +614,7 @@ class MaisakaHeartFlowChatting:
|
|||||||
sub_agent_history = self._drop_head_context_messages(
|
sub_agent_history = self._drop_head_context_messages(
|
||||||
selected_history,
|
selected_history,
|
||||||
drop_head_context_count,
|
drop_head_context_count,
|
||||||
|
trim_threshold_context_count=context_message_limit,
|
||||||
)
|
)
|
||||||
if extra_messages:
|
if extra_messages:
|
||||||
sub_agent_history.extend(list(extra_messages))
|
sub_agent_history.extend(list(extra_messages))
|
||||||
@@ -612,6 +624,7 @@ class MaisakaHeartFlowChatting:
|
|||||||
session_id=self.session_id,
|
session_id=self.session_id,
|
||||||
is_group_chat=self.chat_stream.is_group_session,
|
is_group_chat=self.chat_stream.is_group_session,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
|
model_task_name=model_task_name,
|
||||||
)
|
)
|
||||||
sub_agent.set_interrupt_flag(interrupt_flag)
|
sub_agent.set_interrupt_flag(interrupt_flag)
|
||||||
return await sub_agent.chat_loop_step(
|
return await sub_agent.chat_loop_step(
|
||||||
@@ -625,12 +638,21 @@ class MaisakaHeartFlowChatting:
|
|||||||
def _drop_head_context_messages(
|
def _drop_head_context_messages(
|
||||||
chat_history: Sequence[LLMContextMessage],
|
chat_history: Sequence[LLMContextMessage],
|
||||||
drop_context_count: int,
|
drop_context_count: int,
|
||||||
|
*,
|
||||||
|
trim_threshold_context_count: int | None = None,
|
||||||
) -> list[LLMContextMessage]:
|
) -> list[LLMContextMessage]:
|
||||||
"""从已选上下文头部丢弃指定数量的普通上下文消息。"""
|
"""从已选上下文头部丢弃指定数量的普通上下文消息。"""
|
||||||
|
|
||||||
if drop_context_count <= 0:
|
if drop_context_count <= 0:
|
||||||
return list(chat_history)
|
return list(chat_history)
|
||||||
|
|
||||||
|
context_message_count = sum(1 for message in chat_history if message.count_in_context)
|
||||||
|
if trim_threshold_context_count is not None and context_message_count <= trim_threshold_context_count:
|
||||||
|
return list(chat_history)
|
||||||
|
|
||||||
|
if context_message_count <= drop_context_count:
|
||||||
|
return list(chat_history)
|
||||||
|
|
||||||
first_kept_index = 0
|
first_kept_index = 0
|
||||||
dropped_context_count = 0
|
dropped_context_count = 0
|
||||||
while (
|
while (
|
||||||
@@ -867,6 +889,12 @@ class MaisakaHeartFlowChatting:
|
|||||||
if pending_count <= 0:
|
if pending_count <= 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if self._has_forced_timing_trigger():
|
||||||
|
self._cancel_deferred_message_turn_task()
|
||||||
|
self._message_turn_scheduled = True
|
||||||
|
self._internal_turn_queue.put_nowait("message")
|
||||||
|
return
|
||||||
|
|
||||||
trigger_threshold = self._get_message_trigger_threshold()
|
trigger_threshold = self._get_message_trigger_threshold()
|
||||||
if pending_count >= trigger_threshold or self._should_trigger_message_turn_by_idle_compensation(
|
if pending_count >= trigger_threshold or self._should_trigger_message_turn_by_idle_compensation(
|
||||||
pending_count=pending_count,
|
pending_count=pending_count,
|
||||||
|
|||||||
Reference in New Issue
Block a user