feat: improve configuration for the multimodal/non-multimodal replyer
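Summary: the boolean switch global_config.visual.multimodal_replyer is replaced by a three-valued setting global_config.visual.replyer_mode ("text" | "multimodal" | "auto"), and the text-vs-multimodal decision is deferred until the concrete model serving the request is known. Below is a minimal sketch of what the corresponding configuration entry might look like; the TOML layout and the [visual] section name are assumptions, only the key name and its allowed values appear in the diff:

    [visual]
    # "text" = never send images, "multimodal" = always send them (requires a
    # visual-capable model), "auto" = follow the selected model's visual flag
    replyer_mode = "auto"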
@@ -1,10 +1,8 @@
from datetime import datetime
from typing import Any, Callable, Optional

from src.chat.message_receive.chat_manager import BotChatSession
from src.common.prompt_i18n import load_prompt
from src.config.config import global_config
from src.maisaka.context_messages import SessionBackedMessage
from src.services.llm_service import LLMServiceClient

from .maisaka_generator_base import BaseMaisakaReplyGenerator
@@ -26,9 +24,6 @@ class MaisakaReplyGenerator(BaseMaisakaReplyGenerator):
             request_type=request_type,
             llm_client_cls=llm_client_cls or LLMServiceClient,
             load_prompt_func=load_prompt_func or load_prompt,
-            enable_visual_message=(
-                global_config.visual.multimodal_replyer
-                if enable_visual_message is None
-                else enable_visual_message
-            ),
+            enable_visual_message=enable_visual_message,
+            replyer_mode=global_config.visual.replyer_mode,
         )

@@ -1,7 +1,7 @@
 import time
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
+from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Tuple

 import random

@@ -23,6 +23,7 @@ from src.common.data_models.reply_generation_data_models import (
 from src.common.logger import get_logger
 from src.common.utils.utils_session import SessionUtils
 from src.config.config import global_config
+from src.config.model_configs import ModelInfo
 from src.core.types import ActionInfo
 from src.llm_models.payload_content.message import Message, MessageBuilder, RoleType
 from src.maisaka.context_messages import (
@@ -59,13 +60,15 @@ class BaseMaisakaReplyGenerator:
         request_type: str = "maisaka_replyer",
         llm_client_cls: Any,
         load_prompt_func: Callable[..., str],
-        enable_visual_message: bool,
+        enable_visual_message: Optional[bool],
+        replyer_mode: Literal["text", "multimodal", "auto"],
     ) -> None:
         self.chat_stream = chat_stream
         self.request_type = request_type
         self._llm_client_cls = llm_client_cls
         self._load_prompt = load_prompt_func
         self._enable_visual_message = enable_visual_message
+        self._replyer_mode = replyer_mode
         self.express_model = llm_client_cls(
             task_name="replyer",
             request_type=request_type,
@@ -265,8 +268,9 @@ class BaseMaisakaReplyGenerator:
     def _build_visual_user_message(
         self,
         message: SessionBackedMessage,
+        enable_visual_message: bool,
     ) -> Optional[Message]:
-        if not self._enable_visual_message:
+        if not enable_visual_message:
             return None

         raw_message = clone_message_sequence(message.raw_message)
@@ -283,7 +287,11 @@ class BaseMaisakaReplyGenerator:
         )
         return visual_message.to_llm_message()

-    def _build_history_messages(self, chat_history: List[LLMContextMessage]) -> List[Message]:
+    def _build_history_messages(
+        self,
+        chat_history: List[LLMContextMessage],
+        enable_visual_message: bool,
+    ) -> List[Message]:
         bot_nickname = global_config.bot.nickname.strip() or "Bot"
         default_user_name = global_config.maisaka.cli_user_name.strip() or "User"
         messages: List[Message] = []
@@ -300,7 +308,7 @@ class BaseMaisakaReplyGenerator:
                 )
                 continue

-            visual_message = self._build_visual_user_message(message)
+            visual_message = self._build_visual_user_message(message, enable_visual_message)
             if visual_message is not None:
                 messages.append(visual_message)
                 continue
@@ -337,6 +345,7 @@ class BaseMaisakaReplyGenerator:
         reply_reason: str,
         expression_habits: str = "",
         stream_id: Optional[str] = None,
+        enable_visual_message: bool = False,
     ) -> List[Message]:
         messages: List[Message] = []
         system_prompt = self._build_system_prompt(
@@ -348,10 +357,21 @@ class BaseMaisakaReplyGenerator:
         instruction = self._build_reply_instruction()

         messages.append(MessageBuilder().set_role(RoleType.System).add_text_content(system_prompt).build())
-        messages.extend(self._build_history_messages(chat_history))
+        messages.extend(self._build_history_messages(chat_history, enable_visual_message))
         messages.append(MessageBuilder().set_role(RoleType.User).add_text_content(instruction).build())
         return messages

+    def _resolve_enable_visual_message(self, model_info: Optional[ModelInfo] = None) -> bool:
+        if self._enable_visual_message is not None:
+            return self._enable_visual_message
+        if self._replyer_mode == "multimodal":
+            if model_info is not None and not model_info.visual:
+                raise ValueError(f"replyer_mode=multimodal, but model '{model_info.name}' does not have visual enabled, so the multimodal replyer cannot be used")
+            return True
+        if self._replyer_mode == "text":
+            return False
+        return bool(model_info.visual) if model_info is not None else False
+
     def _resolve_session_id(self, stream_id: Optional[str]) -> str:
         if stream_id:
             return stream_id
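The new _resolve_enable_visual_message above is the heart of the change: an explicit enable_visual_message passed to the constructor always wins, otherwise replyer_mode decides, with "auto" deferring to the model's own visual capability. A standalone paraphrase of that decision table (the function and parameter names below are illustrative only, not part of the codebase):

    from typing import Literal, Optional

    def resolve_visual(
        override: Optional[bool],
        mode: Literal["text", "multimodal", "auto"],
        model_supports_visual: Optional[bool],
    ) -> bool:
        # An explicit override always takes precedence over the mode.
        if override is not None:
            return override
        # "multimodal" insists on visual messages and fails loudly if the
        # model is known not to support them.
        if mode == "multimodal":
            if model_supports_visual is False:
                raise ValueError("replyer_mode=multimodal requires a visual-capable model")
            return True
        # "text" never sends visual messages.
        if mode == "text":
            return False
        # "auto": follow the model's capability, defaulting to text-only when unknown.
        return bool(model_supports_visual)

For example, resolve_visual(None, "auto", True) is True, while resolve_visual(False, "multimodal", True) is False because the explicit override wins.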
@@ -494,7 +514,19 @@ class BaseMaisakaReplyGenerator:
         show_replyer_prompt = bool(getattr(global_config.debug, "show_replyer_prompt", False))
         show_replyer_reasoning = bool(getattr(global_config.debug, "show_replyer_reasoning", False))

-        def message_factory(_client: object) -> List[Message]:
+        def message_factory(_client: object, model_info: Optional[ModelInfo] = None) -> List[Message]:
+            nonlocal prompt_ms, prompt_preview, request_messages
+            prompt_started_at = time.perf_counter()
+            request_messages = self._build_request_messages(
+                chat_history=filtered_history,
+                reply_message=reply_message,
+                reply_reason=reply_reason or "",
+                expression_habits=merged_expression_habits,
+                stream_id=stream_id,
+                enable_visual_message=self._resolve_enable_visual_message(model_info),
+            )
+            prompt_ms = round((time.perf_counter() - prompt_started_at) * 1000, 2)
+            prompt_preview = PromptCLIVisualizer._build_prompt_dump_text(request_messages)
             return request_messages

         result.completion.request_prompt = prompt_preview
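Because the visual decision now depends on the model actually chosen by the LLM client, the prompt is assembled lazily inside message_factory (which receives the resolved ModelInfo) instead of once up front. A tiny sketch of that callback shape, with placeholder types standing in for the project's client and ModelInfo (nothing below is the project's real API):

    from typing import Callable, List, Optional

    class DummyModel:
        def __init__(self, name: str, visual: bool) -> None:
            self.name = name
            self.visual = visual

    MessageFactory = Callable[[object, Optional[DummyModel]], List[str]]

    def demo_factory(_client: object, model: Optional[DummyModel] = None) -> List[str]:
        # Build the message list only once the concrete model is known.
        visual = bool(model.visual) if model is not None else False
        return ["system prompt", "history with images" if visual else "history, text only"]

    def generate(factory: MessageFactory, model: DummyModel) -> List[str]:
        # The client resolves a model first, then asks the factory for
        # messages tailored to that model's capabilities.
        return factory(object(), model)

    print(generate(demo_factory, DummyModel("demo-visual-model", visual=True)))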
@@ -531,6 +563,8 @@ class BaseMaisakaReplyGenerator:
             )
             return finalize(False)

+        result.completion.request_prompt = prompt_preview
+        result.request_messages = serialize_prompt_messages(request_messages)
         llm_ms = round((time.perf_counter() - llm_started_at) * 1000, 2)
         response_text = (generation_result.response or "").strip()
         result.success = bool(response_text)

@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Any, Dict, Optional
+from typing import Any, Dict, Optional

 from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager
 from src.config.config import global_config
@@ -6,10 +6,6 @@ from src.common.logger import get_logger
 from .maisaka_generator import MaisakaReplyGenerator

-if TYPE_CHECKING:
-    from src.chat.replyer.group_generator import DefaultReplyer
-    from src.chat.replyer.private_generator import PrivateReplyer
-
 logger = get_logger("ReplyerManager")

@@ -22,7 +18,7 @@ class ReplyerManager:
     @staticmethod
     def _get_maisaka_generator_type() -> str:
         """Return the message mode of the Maisaka replyer under the current configuration."""
-        return "multimodal" if global_config.visual.multimodal_replyer else "legacy"
+        return global_config.visual.replyer_mode

     def get_replyer(
         self,
@@ -30,7 +26,7 @@ class ReplyerManager:
         chat_id: Optional[str] = None,
         request_type: str = "replyer",
         replyer_type: str = "default",
-    ) -> Optional["DefaultReplyer | PrivateReplyer | Any"]:
+    ) -> Optional[MaisakaReplyGenerator]:
         """Get an instance by session and replyer type."""
         stream_id = chat_stream.session_id if chat_stream else chat_id
         if not stream_id: