解决openai_client的lint问题

This commit is contained in:
UnCLAS-Prommer
2025-07-31 00:49:59 +08:00
parent 5413c41a01
commit 82b5230df1
3 changed files with 97 additions and 195 deletions

View File

@@ -1,7 +1,7 @@
import asyncio
import io
from collections.abc import Iterable
from typing import Callable, Iterator, TypeVar, AsyncIterator
from typing import Callable, Iterator, TypeVar, AsyncIterator, Optional, Coroutine, Any
from google import genai
from google.genai import types
@@ -14,11 +14,9 @@ from google.genai.errors import (
FunctionInvocationError,
)
from .base_client import APIResponse, UsageRecord
from src.config.api_ada_configs import ModelInfo, APIProvider
from . import BaseClient
from src.common.logger import get_logger
from .base_client import APIResponse, UsageRecord, BaseClient
from ..exceptions import (
RespParseException,
NetworkConnectionError,
@@ -29,7 +27,6 @@ from ..payload_content.message import Message, RoleType
from ..payload_content.resp_format import RespFormat, RespFormatType
from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall
logger = get_logger("Gemini客户端")
T = TypeVar("T")
@@ -63,11 +60,7 @@ def _convert_messages(
content = []
for item in message.content:
if isinstance(item, tuple):
content.append(
types.Part.from_bytes(
data=item[1], mime_type=f"image/{item[0].lower()}"
)
)
content.append(types.Part.from_bytes(data=item[1], mime_type=f"image/{item[0].lower()}"))
elif isinstance(item, str):
content.append(types.Part.from_text(item))
else:
@@ -122,20 +115,15 @@ def _convert_tool_options(tool_options: list[ToolOption]) -> list[FunctionDeclar
:param tool_option: 工具选项对象
:return: 转换后的Gemini工具选项对象
"""
ret = {
ret: dict[str, Any] = {
"name": tool_option.name,
"description": tool_option.description,
}
if tool_option.params:
ret["parameters"] = {
"type": "object",
"properties": {
param.name: _convert_tool_param(param)
for param in tool_option.params
},
"required": [
param.name for param in tool_option.params if param.required
],
"properties": {param.name: _convert_tool_param(param) for param in tool_option.params},
"required": [param.name for param in tool_option.params if param.required],
}
ret1 = types.FunctionDeclaration(**ret)
return ret1
@@ -157,12 +145,8 @@ def _process_delta(
if delta.function_calls: # 为什么不用hasattr呢是因为这个属性一定有即使是个空的
for call in delta.function_calls:
try:
if not isinstance(
call.args, dict
): # gemini返回的function call参数就是dict格式的了
raise RespParseException(
delta, "响应解析失败,工具调用参数无法解析为字典类型"
)
if not isinstance(call.args, dict): # gemini返回的function call参数就是dict格式的了
raise RespParseException(delta, "响应解析失败,工具调用参数无法解析为字典类型")
tool_calls_buffer.append(
(
call.id,
@@ -178,6 +162,7 @@ def _build_stream_api_resp(
_fc_delta_buffer: io.StringIO,
_tool_calls_buffer: list[tuple[str, str, dict]],
) -> APIResponse:
# sourcery skip: simplify-len-comparison, use-assigned-variable
resp = APIResponse()
if _fc_delta_buffer.tell() > 0:
@@ -193,8 +178,7 @@ def _build_stream_api_resp(
if not isinstance(arguments, dict):
raise RespParseException(
None,
"响应解析失败,工具调用参数无法解析为字典类型。工具调用参数原始响应:\n"
f"{arguments_buffer}",
f"响应解析失败,工具调用参数无法解析为字典类型。工具调用参数原始响应:\n{arguments_buffer}",
)
else:
arguments = None
@@ -218,16 +202,14 @@ async def _to_async_iterable(iterable: Iterable[T]) -> AsyncIterator[T]:
async def _default_stream_response_handler(
resp_stream: Iterator[GenerateContentResponse],
interrupt_flag: asyncio.Event | None,
) -> tuple[APIResponse, tuple[int, int, int]]:
) -> tuple[APIResponse, Optional[tuple[int, int, int]]]:
"""
流式响应处理函数 - 处理Gemini API的流式响应
:param resp_stream: 流式响应对象,是一个神秘的iterator我完全不知道这个玩意能不能跑不过遍历一遍之后它就空了如果跑不了一点的话可以考虑改成别的东西
:return: APIResponse对象
"""
_fc_delta_buffer = io.StringIO() # 正式内容缓冲区,用于存储接收到的正式内容
_tool_calls_buffer: list[
tuple[str, str, dict]
] = [] # 工具调用缓冲区,用于存储接收到的工具调用
_tool_calls_buffer: list[tuple[str, str, dict]] = [] # 工具调用缓冲区,用于存储接收到的工具调用
_usage_record = None # 使用情况记录
def _insure_buffer_closed():
@@ -250,8 +232,7 @@ async def _default_stream_response_handler(
# 如果有使用情况则将其存储在APIResponse对象中
_usage_record = (
chunk.usage_metadata.prompt_token_count,
chunk.usage_metadata.candidates_token_count
+ chunk.usage_metadata.thoughts_token_count,
chunk.usage_metadata.candidates_token_count + chunk.usage_metadata.thoughts_token_count,
chunk.usage_metadata.total_token_count,
)
try:
@@ -267,7 +248,7 @@ async def _default_stream_response_handler(
def _default_normal_response_parser(
resp: GenerateContentResponse,
) -> tuple[APIResponse, tuple[int, int, int]]:
) -> tuple[APIResponse, Optional[tuple[int, int, int]]]:
"""
解析对话补全响应 - 将Gemini API响应解析为APIResponse对象
:param resp: 响应对象
@@ -286,20 +267,15 @@ def _default_normal_response_parser(
for call in resp.function_calls:
try:
if not isinstance(call.args, dict):
raise RespParseException(
resp, "响应解析失败,工具调用参数无法解析为字典类型"
)
raise RespParseException(resp, "响应解析失败,工具调用参数无法解析为字典类型")
api_response.tool_calls.append(ToolCall(call.id, call.name, call.args))
except Exception as e:
raise RespParseException(
resp, "响应解析失败,无法解析工具调用参数"
) from e
raise RespParseException(resp, "响应解析失败,无法解析工具调用参数") from e
if resp.usage_metadata:
_usage_record = (
resp.usage_metadata.prompt_token_count,
resp.usage_metadata.candidates_token_count
+ resp.usage_metadata.thoughts_token_count,
resp.usage_metadata.candidates_token_count + resp.usage_metadata.thoughts_token_count,
resp.usage_metadata.total_token_count,
)
else:
@@ -311,55 +287,13 @@ def _default_normal_response_parser(
class GeminiClient(BaseClient):
client: genai.Client
def __init__(self, api_provider: APIProvider):
super().__init__(api_provider)
# 不再在初始化时创建固定的client而是在请求时动态创建
self._clients_cache = {} # API Key -> genai.Client 的缓存
def _get_client(self, api_key: str = None) -> genai.Client:
"""获取或创建对应API Key的客户端"""
if api_key is None:
api_key = self.api_provider.get_current_api_key()
if not api_key:
raise ValueError(f"API Provider '{self.api_provider.name}' 没有可用的API Key")
# 使用缓存避免重复创建客户端
if api_key not in self._clients_cache:
self._clients_cache[api_key] = genai.Client(api_key=api_key)
return self._clients_cache[api_key]
async def _execute_with_fallback(self, func, *args, **kwargs):
"""执行请求并在失败时切换API Key"""
current_api_key = self.api_provider.get_current_api_key()
max_attempts = len(self.api_provider.api_keys) if self.api_provider.api_keys else 1
for attempt in range(max_attempts):
try:
client = self._get_client(current_api_key)
result = await func(client, *args, **kwargs)
# 成功时重置失败计数
self.api_provider.reset_key_failures(current_api_key)
return result
except (ClientError, ServerError) as e:
# 记录失败并尝试下一个API Key
logger.warning(f"API Key失败 (尝试 {attempt + 1}/{max_attempts}): {str(e)}")
if attempt < max_attempts - 1: # 还有重试机会
next_api_key = self.api_provider.mark_key_failed(current_api_key)
if next_api_key and next_api_key != current_api_key:
current_api_key = next_api_key
logger.info(f"切换到下一个API Key: {current_api_key[:8]}***{current_api_key[-4:]}")
continue
# 所有API Key都失败了重新抛出异常
raise RespNotOkException(e.status_code, e.message) from e
except Exception as e:
# 其他异常直接抛出
raise e
self.client = genai.Client(
api_key=api_provider.api_key,
) # 这里和openai不一样gemini会自己决定自己是否需要retry
async def get_response(
self,
@@ -370,12 +304,15 @@ class GeminiClient(BaseClient):
temperature: float = 0.7,
thinking_budget: int = 0,
response_format: RespFormat | None = None,
stream_response_handler: Callable[
[Iterator[GenerateContentResponse], asyncio.Event | None], APIResponse
]
| None = None,
async_response_parser: Callable[[GenerateContentResponse], APIResponse]
| None = None,
stream_response_handler: Optional[
Callable[
[Iterator[GenerateContentResponse], asyncio.Event | None],
Coroutine[Any, Any, tuple[APIResponse, Optional[tuple[int, int, int]]]],
]
] = None,
async_response_parser: Optional[
Callable[[GenerateContentResponse], tuple[APIResponse, Optional[tuple[int, int, int]]]]
] = None,
interrupt_flag: asyncio.Event | None = None,
) -> APIResponse:
"""
@@ -392,39 +329,6 @@ class GeminiClient(BaseClient):
:param interrupt_flag: 中断信号量可选默认为None
:return: (响应文本, 推理文本, 工具调用, 其他数据)
"""
return await self._execute_with_fallback(
self._get_response_internal,
model_info,
message_list,
tool_options,
max_tokens,
temperature,
thinking_budget,
response_format,
stream_response_handler,
async_response_parser,
interrupt_flag,
)
async def _get_response_internal(
self,
client: genai.Client,
model_info: ModelInfo,
message_list: list[Message],
tool_options: list[ToolOption] | None = None,
max_tokens: int = 1024,
temperature: float = 0.7,
thinking_budget: int = 0,
response_format: RespFormat | None = None,
stream_response_handler: Callable[
[Iterator[GenerateContentResponse], asyncio.Event | None], APIResponse
]
| None = None,
async_response_parser: Callable[[GenerateContentResponse], APIResponse]
| None = None,
interrupt_flag: asyncio.Event | None = None,
) -> APIResponse:
"""内部方法执行实际的API调用"""
if stream_response_handler is None:
stream_response_handler = _default_stream_response_handler
@@ -462,7 +366,7 @@ class GeminiClient(BaseClient):
try:
if model_info.force_stream_mode:
req_task = asyncio.create_task(
client.aio.models.generate_content_stream(
self.client.aio.models.generate_content_stream(
model=model_info.model_identifier,
contents=messages[0],
config=generation_config,
@@ -474,12 +378,10 @@ class GeminiClient(BaseClient):
req_task.cancel()
raise ReqAbortException("请求被外部信号中断")
await asyncio.sleep(0.1) # 等待0.1秒后再次检查任务&中断信号量状态
resp, usage_record = await stream_response_handler(
req_task.result(), interrupt_flag
)
resp, usage_record = await stream_response_handler(req_task.result(), interrupt_flag)
else:
req_task = asyncio.create_task(
client.aio.models.generate_content(
self.client.aio.models.generate_content(
model=model_info.model_identifier,
contents=messages[0],
config=generation_config,
@@ -495,13 +397,13 @@ class GeminiClient(BaseClient):
resp, usage_record = async_response_parser(req_task.result())
except (ClientError, ServerError) as e:
# 重封装ClientError和ServerError为RespNotOkException
raise RespNotOkException(e.status_code, e.message) from e
raise RespNotOkException(e.status_code, e.message) from None
except (
UnknownFunctionCallArgumentError,
UnsupportedFunctionError,
FunctionInvocationError,
) as e:
raise ValueError(f"工具类型错误:请检查工具选项和参数:{str(e)}") from e
raise ValueError(f"工具类型错误:请检查工具选项和参数:{str(e)}") from None
except Exception as e:
raise NetworkConnectionError() from e
@@ -527,30 +429,15 @@ class GeminiClient(BaseClient):
:param embedding_input: 嵌入输入文本
:return: 嵌入响应
"""
return await self._execute_with_fallback(
self._get_embedding_internal,
model_info,
embedding_input,
)
async def _get_embedding_internal(
self,
client: genai.Client,
model_info: ModelInfo,
embedding_input: str,
) -> APIResponse:
"""内部方法执行实际的嵌入API调用"""
try:
raw_response: types.EmbedContentResponse = (
await client.aio.models.embed_content(
model=model_info.model_identifier,
contents=embedding_input,
config=types.EmbedContentConfig(task_type="SEMANTIC_SIMILARITY"),
)
raw_response: types.EmbedContentResponse = await self.client.aio.models.embed_content(
model=model_info.model_identifier,
contents=embedding_input,
config=types.EmbedContentConfig(task_type="SEMANTIC_SIMILARITY"),
)
except (ClientError, ServerError) as e:
# 重封装ClientError和ServerError为RespNotOkException
raise RespNotOkException(e.status_code) from e
raise RespNotOkException(e.status_code) from None
except Exception as e:
raise NetworkConnectionError() from e