feat:修复gemini tool问题,简化表情包识别,修复非多模态plan图片识别
This commit is contained in:
@@ -249,9 +249,13 @@ def _convert_messages(messages: List[Message]) -> Tuple[ContentListUnion, str |
|
||||
if message.role == RoleType.Tool:
|
||||
if not message.tool_call_id:
|
||||
raise ValueError("Gemini 工具结果消息缺少 tool_call_id")
|
||||
tool_name = tool_name_by_call_id.get(message.tool_call_id)
|
||||
tool_name = (message.tool_name or tool_name_by_call_id.get(message.tool_call_id, "")).strip()
|
||||
if not tool_name:
|
||||
raise ValueError(f"Gemini 无法根据 tool_call_id={message.tool_call_id} 找到对应的工具名称")
|
||||
raise ValueError(
|
||||
f"Gemini 无法根据 tool_call_id={message.tool_call_id} 找到对应的工具名称,"
|
||||
"且消息中未携带 tool_name"
|
||||
)
|
||||
tool_name_by_call_id[message.tool_call_id] = tool_name
|
||||
function_response_part = Part.from_function_response(
|
||||
name=tool_name,
|
||||
response=_normalize_function_response_payload(message),
|
||||
|
||||
@@ -75,6 +75,7 @@ class Message:
|
||||
role: RoleType
|
||||
parts: List[MessagePart] = field(default_factory=list)
|
||||
tool_call_id: str | None = None
|
||||
tool_name: str | None = None
|
||||
tool_calls: List[ToolCall] | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
@@ -87,6 +88,8 @@ class Message:
|
||||
raise ValueError("消息内容不能为空")
|
||||
if self.role == RoleType.Tool and not self.tool_call_id:
|
||||
raise ValueError("Tool 角色的工具调用 ID 不能为空")
|
||||
if self.tool_name and self.role != RoleType.Tool:
|
||||
raise ValueError("仅当角色为 Tool 时才能设置工具名称")
|
||||
|
||||
@property
|
||||
def content(self) -> str | List[Tuple[str, str] | str]:
|
||||
@@ -122,7 +125,7 @@ class Message:
|
||||
"""
|
||||
return (
|
||||
f"Role: {self.role}, Parts: {self.parts}, "
|
||||
f"Tool Call ID: {self.tool_call_id}, Tool Calls: {self.tool_calls}"
|
||||
f"Tool Call ID: {self.tool_call_id}, Tool Name: {self.tool_name}, Tool Calls: {self.tool_calls}"
|
||||
)
|
||||
|
||||
|
||||
@@ -134,6 +137,7 @@ class MessageBuilder:
|
||||
self.__role: RoleType = RoleType.User
|
||||
self.__parts: List[MessagePart] = []
|
||||
self.__tool_call_id: str | None = None
|
||||
self.__tool_name: str | None = None
|
||||
self.__tool_calls: List[ToolCall] | None = None
|
||||
|
||||
def set_role(self, role: RoleType = RoleType.User) -> "MessageBuilder":
|
||||
@@ -247,6 +251,15 @@ class MessageBuilder:
|
||||
"""
|
||||
return self.set_tool_call_id(tool_call_id)
|
||||
|
||||
def set_tool_name(self, tool_name: str) -> "MessageBuilder":
|
||||
"""设置 Tool 消息对应的工具名称。"""
|
||||
if self.__role != RoleType.Tool:
|
||||
raise ValueError("仅当角色为 Tool 时才能设置工具名称")
|
||||
if not tool_name:
|
||||
raise ValueError("工具名称不能为空")
|
||||
self.__tool_name = tool_name
|
||||
return self
|
||||
|
||||
def set_tool_calls(self, tool_calls: List[ToolCall]) -> "MessageBuilder":
|
||||
"""设置助手消息中的工具调用列表。
|
||||
|
||||
@@ -276,5 +289,6 @@ class MessageBuilder:
|
||||
role=self.__role,
|
||||
parts=list(self.__parts),
|
||||
tool_call_id=self.__tool_call_id,
|
||||
tool_name=self.__tool_name,
|
||||
tool_calls=list(self.__tool_calls) if self.__tool_calls else None,
|
||||
)
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
from datetime import datetime
|
||||
import base64
|
||||
import io
|
||||
|
||||
from PIL import Image
|
||||
from datetime import datetime
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import ModelUsage, ModelUser
|
||||
from src.common.logger import get_logger
|
||||
from src.config.model_configs import ModelInfo
|
||||
from .payload_content.message import Message, MessageBuilder
|
||||
|
||||
from .model_client.base_client import UsageRecord
|
||||
from .payload_content.message import ImageMessagePart, Message, MessageBuilder, RoleType, TextMessagePart
|
||||
|
||||
logger = get_logger("消息压缩工具")
|
||||
|
||||
@@ -131,25 +132,32 @@ def compress_messages(messages: list[Message], img_target_size: int = 1 * 1024 *
|
||||
|
||||
return base64_data
|
||||
|
||||
compressed_messages = []
|
||||
for message in messages:
|
||||
if isinstance(message.content, list):
|
||||
# 检查content,如有图片则压缩
|
||||
message_builder = MessageBuilder()
|
||||
for content_item in message.content:
|
||||
if isinstance(content_item, tuple):
|
||||
# 图片,进行压缩
|
||||
message_builder.add_image_content(
|
||||
content_item[0],
|
||||
compress_base64_image(content_item[1], target_size=img_target_size),
|
||||
)
|
||||
else:
|
||||
message_builder.add_text_content(content_item)
|
||||
compressed_messages.append(message_builder.build())
|
||||
else:
|
||||
compressed_messages.append(message)
|
||||
def rebuild_message_with_compressed_images(message: Message) -> Message:
|
||||
"""重建消息并压缩其中的图片,同时保留角色与工具元信息。"""
|
||||
if not any(isinstance(part, ImageMessagePart) for part in message.parts):
|
||||
return message
|
||||
|
||||
return compressed_messages
|
||||
message_builder = MessageBuilder().set_role(message.role)
|
||||
if message.role == RoleType.Assistant and message.tool_calls:
|
||||
message_builder.set_tool_calls(message.tool_calls)
|
||||
if message.role == RoleType.Tool and message.tool_call_id:
|
||||
message_builder.set_tool_call_id(message.tool_call_id)
|
||||
if message.role == RoleType.Tool and message.tool_name:
|
||||
message_builder.set_tool_name(message.tool_name)
|
||||
|
||||
for message_part in message.parts:
|
||||
if isinstance(message_part, ImageMessagePart):
|
||||
message_builder.add_image_content(
|
||||
message_part.image_format,
|
||||
compress_base64_image(message_part.image_base64, target_size=img_target_size),
|
||||
)
|
||||
continue
|
||||
if isinstance(message_part, TextMessagePart):
|
||||
message_builder.add_text_content(message_part.text)
|
||||
|
||||
return message_builder.build()
|
||||
|
||||
return [rebuild_message_with_compressed_images(message) for message in messages]
|
||||
|
||||
|
||||
class LLMUsageRecorder:
|
||||
|
||||
Reference in New Issue
Block a user