feat:优化对多模态/非多模态replyer的配置

This commit is contained in:
SengokuCola
2026-04-11 19:30:23 +08:00
parent c0230fc313
commit d9b3440169
12 changed files with 150 additions and 44 deletions

View File

@@ -1,7 +1,7 @@
import time
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Tuple
import random
@@ -23,6 +23,7 @@ from src.common.data_models.reply_generation_data_models import (
from src.common.logger import get_logger
from src.common.utils.utils_session import SessionUtils
from src.config.config import global_config
from src.config.model_configs import ModelInfo
from src.core.types import ActionInfo
from src.llm_models.payload_content.message import Message, MessageBuilder, RoleType
from src.maisaka.context_messages import (
@@ -59,13 +60,15 @@ class BaseMaisakaReplyGenerator:
request_type: str = "maisaka_replyer",
llm_client_cls: Any,
load_prompt_func: Callable[..., str],
enable_visual_message: bool,
enable_visual_message: Optional[bool],
replyer_mode: Literal["text", "multimodal", "auto"],
) -> None:
self.chat_stream = chat_stream
self.request_type = request_type
self._llm_client_cls = llm_client_cls
self._load_prompt = load_prompt_func
self._enable_visual_message = enable_visual_message
self._replyer_mode = replyer_mode
self.express_model = llm_client_cls(
task_name="replyer",
request_type=request_type,
@@ -265,8 +268,9 @@ class BaseMaisakaReplyGenerator:
def _build_visual_user_message(
self,
message: SessionBackedMessage,
enable_visual_message: bool,
) -> Optional[Message]:
if not self._enable_visual_message:
if not enable_visual_message:
return None
raw_message = clone_message_sequence(message.raw_message)
@@ -283,7 +287,11 @@ class BaseMaisakaReplyGenerator:
)
return visual_message.to_llm_message()
def _build_history_messages(self, chat_history: List[LLMContextMessage]) -> List[Message]:
def _build_history_messages(
self,
chat_history: List[LLMContextMessage],
enable_visual_message: bool,
) -> List[Message]:
bot_nickname = global_config.bot.nickname.strip() or "Bot"
default_user_name = global_config.maisaka.cli_user_name.strip() or "User"
messages: List[Message] = []
@@ -300,7 +308,7 @@ class BaseMaisakaReplyGenerator:
)
continue
visual_message = self._build_visual_user_message(message)
visual_message = self._build_visual_user_message(message, enable_visual_message)
if visual_message is not None:
messages.append(visual_message)
continue
@@ -337,6 +345,7 @@ class BaseMaisakaReplyGenerator:
reply_reason: str,
expression_habits: str = "",
stream_id: Optional[str] = None,
enable_visual_message: bool = False,
) -> List[Message]:
messages: List[Message] = []
system_prompt = self._build_system_prompt(
@@ -348,10 +357,21 @@ class BaseMaisakaReplyGenerator:
instruction = self._build_reply_instruction()
messages.append(MessageBuilder().set_role(RoleType.System).add_text_content(system_prompt).build())
messages.extend(self._build_history_messages(chat_history))
messages.extend(self._build_history_messages(chat_history, enable_visual_message))
messages.append(MessageBuilder().set_role(RoleType.User).add_text_content(instruction).build())
return messages
def _resolve_enable_visual_message(self, model_info: Optional[ModelInfo] = None) -> bool:
if self._enable_visual_message is not None:
return self._enable_visual_message
if self._replyer_mode == "multimodal":
if model_info is not None and not model_info.visual:
raise ValueError(f"replyer_mode=multimodal但模型 '{model_info.name}' 未开启 visual无法使用多模态 replyer")
return True
if self._replyer_mode == "text":
return False
return bool(model_info.visual) if model_info is not None else False
def _resolve_session_id(self, stream_id: Optional[str]) -> str:
if stream_id:
return stream_id
@@ -494,7 +514,19 @@ class BaseMaisakaReplyGenerator:
show_replyer_prompt = bool(getattr(global_config.debug, "show_replyer_prompt", False))
show_replyer_reasoning = bool(getattr(global_config.debug, "show_replyer_reasoning", False))
def message_factory(_client: object) -> List[Message]:
def message_factory(_client: object, model_info: Optional[ModelInfo] = None) -> List[Message]:
nonlocal prompt_ms, prompt_preview, request_messages
prompt_started_at = time.perf_counter()
request_messages = self._build_request_messages(
chat_history=filtered_history,
reply_message=reply_message,
reply_reason=reply_reason or "",
expression_habits=merged_expression_habits,
stream_id=stream_id,
enable_visual_message=self._resolve_enable_visual_message(model_info),
)
prompt_ms = round((time.perf_counter() - prompt_started_at) * 1000, 2)
prompt_preview = PromptCLIVisualizer._build_prompt_dump_text(request_messages)
return request_messages
result.completion.request_prompt = prompt_preview
@@ -531,6 +563,8 @@ class BaseMaisakaReplyGenerator:
)
return finalize(False)
result.completion.request_prompt = prompt_preview
result.request_messages = serialize_prompt_messages(request_messages)
llm_ms = round((time.perf_counter() - llm_started_at) * 1000, 2)
response_text = (generation_result.response or "").strip()
result.success = bool(response_text)