让Gemini的图像可用,修复部分typing

This commit is contained in:
UnCLAS-Prommer
2025-08-03 00:49:19 +08:00
parent 38930b0ceb
commit 9afa549aee
5 changed files with 88 additions and 57 deletions

View File

@@ -1,8 +1,8 @@
raise DeprecationWarning("Genimi Client is not fully available yet. Please remove your Gemini API Provider")
import asyncio
import io
import base64
from collections.abc import Iterable
from typing import Callable, Iterator, TypeVar, AsyncIterator, Optional, Coroutine, Any
from typing import Callable, Iterator, TypeVar, AsyncIterator, Optional, Coroutine, Any, List
from google import genai
from google.genai import types
@@ -17,7 +17,7 @@ from google.genai.errors import (
from src.config.api_ada_configs import ModelInfo, APIProvider
from .base_client import APIResponse, UsageRecord, BaseClient
from .base_client import APIResponse, UsageRecord, BaseClient, client_registry
from ..exceptions import (
RespParseException,
NetworkConnectionError,
@@ -54,20 +54,21 @@ def _convert_messages(
role = "user"
# 添加Content
content: types.Part | list
if isinstance(message.content, str):
content = types.Part.from_text(message.content)
content = [types.Part.from_text(text=message.content)]
elif isinstance(message.content, list):
content = []
content: List[types.Part] = []
for item in message.content:
if isinstance(item, tuple):
content.append(types.Part.from_bytes(data=item[1], mime_type=f"image/{item[0].lower()}"))
content.append(
types.Part.from_bytes(data=base64.b64decode(item[1]), mime_type=f"image/{item[0].lower()}")
)
elif isinstance(item, str):
content.append(types.Part.from_text(item))
content.append(types.Part.from_text(text=item))
else:
raise RuntimeError("无法触及的代码请使用MessageBuilder类构建消息对象")
return types.Content(role=role, content=content)
return types.Content(role=role, parts=content)
temp_list: list[types.Content] = []
system_instructions: list[str] = []
@@ -76,7 +77,7 @@ def _convert_messages(
if isinstance(message.content, str):
system_instructions.append(message.content)
else:
raise RuntimeError("你tm怎么往system里面塞图片base64")
raise ValueError("你tm怎么往system里面塞图片base64")
elif message.role == RoleType.Tool:
if not message.tool_call_id:
raise ValueError("无法触及的代码请使用MessageBuilder类构建消息对象")
@@ -135,9 +136,9 @@ def _convert_tool_options(tool_options: list[ToolOption]) -> list[FunctionDeclar
def _process_delta(
delta: GenerateContentResponse,
fc_delta_buffer: io.StringIO,
tool_calls_buffer: list[tuple[str, str, dict]],
tool_calls_buffer: list[tuple[str, str, dict[str, Any]]],
):
if not hasattr(delta, "candidates") or len(delta.candidates) == 0:
if not hasattr(delta, "candidates") or not delta.candidates:
raise RespParseException(delta, "响应解析失败缺失candidates字段")
if delta.text:
@@ -148,11 +149,13 @@ def _process_delta(
try:
if not isinstance(call.args, dict): # gemini返回的function call参数就是dict格式的了
raise RespParseException(delta, "响应解析失败,工具调用参数无法解析为字典类型")
if not call.id or not call.name:
raise RespParseException(delta, "响应解析失败工具调用缺失id或name字段")
tool_calls_buffer.append(
(
call.id,
call.name,
call.args,
call.args or {}, # 如果args是None则转换为一个空字典
)
)
except Exception as e:
@@ -201,7 +204,7 @@ async def _to_async_iterable(iterable: Iterable[T]) -> AsyncIterator[T]:
async def _default_stream_response_handler(
resp_stream: Iterator[GenerateContentResponse],
resp_stream: AsyncIterator[GenerateContentResponse],
interrupt_flag: asyncio.Event | None,
) -> tuple[APIResponse, Optional[tuple[int, int, int]]]:
"""
@@ -232,9 +235,9 @@ async def _default_stream_response_handler(
if chunk.usage_metadata:
# 如果有使用情况则将其存储在APIResponse对象中
_usage_record = (
chunk.usage_metadata.prompt_token_count,
chunk.usage_metadata.candidates_token_count + chunk.usage_metadata.thoughts_token_count,
chunk.usage_metadata.total_token_count,
chunk.usage_metadata.prompt_token_count or 0,
(chunk.usage_metadata.candidates_token_count or 0) + (chunk.usage_metadata.thoughts_token_count or 0),
chunk.usage_metadata.total_token_count or 0,
)
try:
return _build_stream_api_resp(
@@ -257,7 +260,7 @@ def _default_normal_response_parser(
"""
api_response = APIResponse()
if not hasattr(resp, "candidates") or len(resp.candidates) == 0:
if not hasattr(resp, "candidates") or not resp.candidates:
raise RespParseException(resp, "响应解析失败缺失candidates字段")
if resp.text:
@@ -269,15 +272,17 @@ def _default_normal_response_parser(
try:
if not isinstance(call.args, dict):
raise RespParseException(resp, "响应解析失败,工具调用参数无法解析为字典类型")
api_response.tool_calls.append(ToolCall(call.id, call.name, call.args))
if not call.id or not call.name:
raise RespParseException(resp, "响应解析失败工具调用缺失id或name字段")
api_response.tool_calls.append(ToolCall(call.id, call.name, call.args or {}))
except Exception as e:
raise RespParseException(resp, "响应解析失败,无法解析工具调用参数") from e
if resp.usage_metadata:
_usage_record = (
resp.usage_metadata.prompt_token_count,
resp.usage_metadata.candidates_token_count + resp.usage_metadata.thoughts_token_count,
resp.usage_metadata.total_token_count,
resp.usage_metadata.prompt_token_count or 0,
(resp.usage_metadata.candidates_token_count or 0) + (resp.usage_metadata.thoughts_token_count or 0),
resp.usage_metadata.total_token_count or 0,
)
else:
_usage_record = None
@@ -287,6 +292,7 @@ def _default_normal_response_parser(
return api_response, _usage_record
@client_registry.register_client_class("gemini")
class GeminiClient(BaseClient):
client: genai.Client
@@ -307,7 +313,7 @@ class GeminiClient(BaseClient):
response_format: RespFormat | None = None,
stream_response_handler: Optional[
Callable[
[Iterator[GenerateContentResponse], asyncio.Event | None],
[AsyncIterator[GenerateContentResponse], asyncio.Event | None],
Coroutine[Any, Any, tuple[APIResponse, Optional[tuple[int, int, int]]]],
]
] = None,
@@ -398,7 +404,7 @@ class GeminiClient(BaseClient):
resp, usage_record = async_response_parser(req_task.result())
except (ClientError, ServerError) as e:
# 重封装ClientError和ServerError为RespNotOkException
raise RespNotOkException(e.status_code, e.message) from None
raise RespNotOkException(e.code, e.message) from None
except (
UnknownFunctionCallArgumentError,
UnsupportedFunctionError,
@@ -438,14 +444,14 @@ class GeminiClient(BaseClient):
)
except (ClientError, ServerError) as e:
# 重封装ClientError和ServerError为RespNotOkException
raise RespNotOkException(e.status_code) from None
raise RespNotOkException(e.code) from None
except Exception as e:
raise NetworkConnectionError() from e
response = APIResponse()
# 解析嵌入响应和使用情况
if hasattr(raw_response, "embeddings"):
if hasattr(raw_response, "embeddings") and raw_response.embeddings:
response.embedding = raw_response.embeddings[0].values
else:
raise RespParseException(raw_response, "响应解析失败缺失embeddings字段")
@@ -459,3 +465,10 @@ class GeminiClient(BaseClient):
)
return response
def get_support_image_formats(self) -> list[str]:
    """
    Return the image formats supported by this client.

    :return: list of supported image format names (lowercase, without a leading dot)
    """
    # A list literal is evaluated on every call, so each caller receives its own
    # list and may modify it without affecting later calls.
    return ["png", "jpg", "jpeg", "webp", "heic", "heif"]