Merge branch 'r-dev' of https://github.com/A-Dawn/MaiBot into r-dev
This commit is contained in:
@@ -11,8 +11,7 @@ from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
from rich.traceback import install
|
||||
|
||||
from src.chat.message_receive.chat_manager import BotChatSession
|
||||
from src.chat.replyer.group_generator import DefaultReplyer
|
||||
from src.chat.replyer.private_generator import PrivateReplyer
|
||||
from src.chat.replyer.maisaka_generator import MaisakaReplyGenerator
|
||||
from src.chat.replyer.replyer_manager import replyer_manager
|
||||
from src.chat.utils.utils import process_llm_response
|
||||
from src.common.data_models.message_component_data_model import MessageSequence, TextComponent
|
||||
@@ -20,8 +19,8 @@ from src.common.logger import get_logger
|
||||
from src.core.types import ActionInfo
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.common.data_models.info_data_model import ActionPlannerInfo
|
||||
from src.common.data_models.llm_data_model import LLMGenerationDataModel
|
||||
from src.common.data_models.planned_action_data_models import PlannedAction
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
|
||||
install(extra_lines=3)
|
||||
@@ -38,7 +37,7 @@ def _get_replyer(
|
||||
chat_stream: Optional[BotChatSession] = None,
|
||||
chat_id: Optional[str] = None,
|
||||
request_type: str = "replyer",
|
||||
) -> Optional[DefaultReplyer | PrivateReplyer]:
|
||||
) -> Optional[MaisakaReplyGenerator]:
|
||||
"""获取回复器对象"""
|
||||
if not chat_id and not chat_stream:
|
||||
raise ValueError("chat_stream 和 chat_id 不可均为空")
|
||||
@@ -100,7 +99,7 @@ async def generate_reply(
|
||||
extra_info: str = "",
|
||||
reply_reason: str = "",
|
||||
available_actions: Optional[Dict[str, ActionInfo]] = None,
|
||||
chosen_actions: Optional[List["ActionPlannerInfo"]] = None,
|
||||
chosen_actions: Optional[List["PlannedAction"]] = None,
|
||||
unknown_words: Optional[List[str]] = None,
|
||||
enable_splitter: bool = True,
|
||||
enable_chinese_typo: bool = True,
|
||||
|
||||
@@ -267,6 +267,46 @@ def _parse_data_url_image(image_url: str) -> Tuple[str, str]:
|
||||
return image_format, image_base64
|
||||
|
||||
|
||||
def _append_image_content(message_builder: MessageBuilder, content_item: Any) -> bool:
|
||||
"""向消息构建器追加图片片段。
|
||||
|
||||
兼容两种输入格式:
|
||||
1. 旧序列化格式中的 `(image_format, image_base64)` 元组。
|
||||
2. 标准字典片段中的 Data URL 或 `image_format`/`image_base64` 字段。
|
||||
"""
|
||||
|
||||
if isinstance(content_item, (tuple, list)) and len(content_item) == 2:
|
||||
image_format, image_base64 = content_item
|
||||
if not isinstance(image_format, str) or not isinstance(image_base64, str):
|
||||
raise ValueError("图片元组片段必须包含字符串类型的 image_format 和 image_base64")
|
||||
|
||||
message_builder.add_image_content(image_format=image_format, image_base64=image_base64)
|
||||
return True
|
||||
|
||||
if not isinstance(content_item, dict):
|
||||
return False
|
||||
|
||||
part_type = str(content_item.get("type", "text")).strip().lower()
|
||||
if part_type not in {"image", "image_url", "input_image"}:
|
||||
return False
|
||||
|
||||
image_url = content_item.get("image_url")
|
||||
if isinstance(image_url, dict):
|
||||
image_url = image_url.get("url")
|
||||
if isinstance(image_url, str):
|
||||
image_format, image_base64 = _parse_data_url_image(image_url)
|
||||
message_builder.add_image_content(image_format=image_format, image_base64=image_base64)
|
||||
return True
|
||||
|
||||
image_format = content_item.get("image_format")
|
||||
image_base64 = content_item.get("image_base64")
|
||||
if isinstance(image_format, str) and isinstance(image_base64, str):
|
||||
message_builder.add_image_content(image_format=image_format, image_base64=image_base64)
|
||||
return True
|
||||
|
||||
raise ValueError("图片片段缺少可识别的图片数据")
|
||||
|
||||
|
||||
def _append_content_parts(message_builder: MessageBuilder, content: Any) -> None:
|
||||
"""将原始消息内容追加到内部消息构建器。
|
||||
|
||||
@@ -293,8 +333,10 @@ def _append_content_parts(message_builder: MessageBuilder, content: Any) -> None
|
||||
if isinstance(content_item, str):
|
||||
message_builder.add_text_content(content_item)
|
||||
continue
|
||||
if _append_image_content(message_builder, content_item):
|
||||
continue
|
||||
if not isinstance(content_item, dict):
|
||||
raise ValueError("消息内容列表中仅支持字符串或字典片段")
|
||||
raise ValueError("消息内容列表中仅支持字符串、图片元组或字典片段")
|
||||
|
||||
part_type = str(content_item.get("type", "text")).strip().lower()
|
||||
if part_type == "text":
|
||||
@@ -304,22 +346,6 @@ def _append_content_parts(message_builder: MessageBuilder, content: Any) -> None
|
||||
message_builder.add_text_content(text_content)
|
||||
continue
|
||||
|
||||
if part_type in {"image", "image_url", "input_image"}:
|
||||
image_url = content_item.get("image_url")
|
||||
if isinstance(image_url, dict):
|
||||
image_url = image_url.get("url")
|
||||
if isinstance(image_url, str):
|
||||
image_format, image_base64 = _parse_data_url_image(image_url)
|
||||
message_builder.add_image_content(image_format=image_format, image_base64=image_base64)
|
||||
continue
|
||||
|
||||
image_format = content_item.get("image_format")
|
||||
image_base64 = content_item.get("image_base64")
|
||||
if isinstance(image_format, str) and isinstance(image_base64, str):
|
||||
message_builder.add_image_content(image_format=image_format, image_base64=image_base64)
|
||||
continue
|
||||
raise ValueError("图片片段缺少可识别的图片数据")
|
||||
|
||||
raise ValueError(f"不支持的消息片段类型: {part_type}")
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user