feat:优化对多模态/非多模态replyer的配置
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Tuple
|
||||
|
||||
import random
|
||||
|
||||
@@ -23,6 +23,7 @@ from src.common.data_models.reply_generation_data_models import (
|
||||
from src.common.logger import get_logger
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
from src.config.config import global_config
|
||||
from src.config.model_configs import ModelInfo
|
||||
from src.core.types import ActionInfo
|
||||
from src.llm_models.payload_content.message import Message, MessageBuilder, RoleType
|
||||
from src.maisaka.context_messages import (
|
||||
@@ -59,13 +60,15 @@ class BaseMaisakaReplyGenerator:
|
||||
request_type: str = "maisaka_replyer",
|
||||
llm_client_cls: Any,
|
||||
load_prompt_func: Callable[..., str],
|
||||
enable_visual_message: bool,
|
||||
enable_visual_message: Optional[bool],
|
||||
replyer_mode: Literal["text", "multimodal", "auto"],
|
||||
) -> None:
|
||||
self.chat_stream = chat_stream
|
||||
self.request_type = request_type
|
||||
self._llm_client_cls = llm_client_cls
|
||||
self._load_prompt = load_prompt_func
|
||||
self._enable_visual_message = enable_visual_message
|
||||
self._replyer_mode = replyer_mode
|
||||
self.express_model = llm_client_cls(
|
||||
task_name="replyer",
|
||||
request_type=request_type,
|
||||
@@ -265,8 +268,9 @@ class BaseMaisakaReplyGenerator:
|
||||
def _build_visual_user_message(
|
||||
self,
|
||||
message: SessionBackedMessage,
|
||||
enable_visual_message: bool,
|
||||
) -> Optional[Message]:
|
||||
if not self._enable_visual_message:
|
||||
if not enable_visual_message:
|
||||
return None
|
||||
|
||||
raw_message = clone_message_sequence(message.raw_message)
|
||||
@@ -283,7 +287,11 @@ class BaseMaisakaReplyGenerator:
|
||||
)
|
||||
return visual_message.to_llm_message()
|
||||
|
||||
def _build_history_messages(self, chat_history: List[LLMContextMessage]) -> List[Message]:
|
||||
def _build_history_messages(
|
||||
self,
|
||||
chat_history: List[LLMContextMessage],
|
||||
enable_visual_message: bool,
|
||||
) -> List[Message]:
|
||||
bot_nickname = global_config.bot.nickname.strip() or "Bot"
|
||||
default_user_name = global_config.maisaka.cli_user_name.strip() or "User"
|
||||
messages: List[Message] = []
|
||||
@@ -300,7 +308,7 @@ class BaseMaisakaReplyGenerator:
|
||||
)
|
||||
continue
|
||||
|
||||
visual_message = self._build_visual_user_message(message)
|
||||
visual_message = self._build_visual_user_message(message, enable_visual_message)
|
||||
if visual_message is not None:
|
||||
messages.append(visual_message)
|
||||
continue
|
||||
@@ -337,6 +345,7 @@ class BaseMaisakaReplyGenerator:
|
||||
reply_reason: str,
|
||||
expression_habits: str = "",
|
||||
stream_id: Optional[str] = None,
|
||||
enable_visual_message: bool = False,
|
||||
) -> List[Message]:
|
||||
messages: List[Message] = []
|
||||
system_prompt = self._build_system_prompt(
|
||||
@@ -348,10 +357,21 @@ class BaseMaisakaReplyGenerator:
|
||||
instruction = self._build_reply_instruction()
|
||||
|
||||
messages.append(MessageBuilder().set_role(RoleType.System).add_text_content(system_prompt).build())
|
||||
messages.extend(self._build_history_messages(chat_history))
|
||||
messages.extend(self._build_history_messages(chat_history, enable_visual_message))
|
||||
messages.append(MessageBuilder().set_role(RoleType.User).add_text_content(instruction).build())
|
||||
return messages
|
||||
|
||||
def _resolve_enable_visual_message(self, model_info: Optional[ModelInfo] = None) -> bool:
|
||||
if self._enable_visual_message is not None:
|
||||
return self._enable_visual_message
|
||||
if self._replyer_mode == "multimodal":
|
||||
if model_info is not None and not model_info.visual:
|
||||
raise ValueError(f"replyer_mode=multimodal,但模型 '{model_info.name}' 未开启 visual,无法使用多模态 replyer")
|
||||
return True
|
||||
if self._replyer_mode == "text":
|
||||
return False
|
||||
return bool(model_info.visual) if model_info is not None else False
|
||||
|
||||
def _resolve_session_id(self, stream_id: Optional[str]) -> str:
|
||||
if stream_id:
|
||||
return stream_id
|
||||
@@ -494,7 +514,19 @@ class BaseMaisakaReplyGenerator:
|
||||
show_replyer_prompt = bool(getattr(global_config.debug, "show_replyer_prompt", False))
|
||||
show_replyer_reasoning = bool(getattr(global_config.debug, "show_replyer_reasoning", False))
|
||||
|
||||
def message_factory(_client: object) -> List[Message]:
|
||||
def message_factory(_client: object, model_info: Optional[ModelInfo] = None) -> List[Message]:
|
||||
nonlocal prompt_ms, prompt_preview, request_messages
|
||||
prompt_started_at = time.perf_counter()
|
||||
request_messages = self._build_request_messages(
|
||||
chat_history=filtered_history,
|
||||
reply_message=reply_message,
|
||||
reply_reason=reply_reason or "",
|
||||
expression_habits=merged_expression_habits,
|
||||
stream_id=stream_id,
|
||||
enable_visual_message=self._resolve_enable_visual_message(model_info),
|
||||
)
|
||||
prompt_ms = round((time.perf_counter() - prompt_started_at) * 1000, 2)
|
||||
prompt_preview = PromptCLIVisualizer._build_prompt_dump_text(request_messages)
|
||||
return request_messages
|
||||
|
||||
result.completion.request_prompt = prompt_preview
|
||||
@@ -531,6 +563,8 @@ class BaseMaisakaReplyGenerator:
|
||||
)
|
||||
return finalize(False)
|
||||
|
||||
result.completion.request_prompt = prompt_preview
|
||||
result.request_messages = serialize_prompt_messages(request_messages)
|
||||
llm_ms = round((time.perf_counter() - llm_started_at) * 1000, 2)
|
||||
response_text = (generation_result.response or "").strip()
|
||||
result.success = bool(response_text)
|
||||
|
||||
Reference in New Issue
Block a user