解决openai_client的lint问题

This commit is contained in:
UnCLAS-Prommer
2025-07-31 00:49:59 +08:00
parent 5413c41a01
commit 82b5230df1
3 changed files with 97 additions and 195 deletions

View File

@@ -3,7 +3,8 @@ import io
import json
import re
from collections.abc import Iterable
from typing import Callable, Any
from typing import Callable, Any, Coroutine, Optional
from json_repair import repair_json
from openai import (
AsyncOpenAI,
@@ -20,11 +21,9 @@ from openai.types.chat import (
)
from openai.types.chat.chat_completion_chunk import ChoiceDelta
from .base_client import APIResponse, UsageRecord
from src.config.api_ada_configs import ModelInfo, APIProvider
from .base_client import BaseClient, client_registry
from src.common.logger import get_logger
from .base_client import APIResponse, UsageRecord, BaseClient, client_registry
from ..exceptions import (
RespParseException,
NetworkConnectionError,
@@ -82,7 +81,7 @@ def _convert_messages(messages: list[Message]) -> list[ChatCompletionMessagePara
raise ValueError("无法触及的代码请使用MessageBuilder类构建消息对象")
ret["tool_call_id"] = message.tool_call_id
return ret
return ret # type: ignore
return [_convert_message_item(message) for message in messages]
@@ -143,10 +142,10 @@ def _process_delta(
# 接收content
if has_rc_attr_flag:
# 有独立的推理内容块则无需考虑content内容的判读
if hasattr(delta, "reasoning_content") and delta.reasoning_content:
if hasattr(delta, "reasoning_content") and delta.reasoning_content: # type: ignore
# 如果有推理内容,则将其写入推理内容缓冲区
assert isinstance(delta.reasoning_content, str)
rc_delta_buffer.write(delta.reasoning_content)
assert isinstance(delta.reasoning_content, str) # type: ignore
rc_delta_buffer.write(delta.reasoning_content) # type: ignore
elif delta.content:
# 如果有正式内容,则将其写入正式内容缓冲区
fc_delta_buffer.write(delta.content)
@@ -173,15 +172,18 @@ def _process_delta(
if tool_call_delta.index >= len(tool_calls_buffer):
# 调用索引号大于等于缓冲区长度,说明是新的工具调用
tool_calls_buffer.append(
(
tool_call_delta.id,
tool_call_delta.function.name,
io.StringIO(),
if tool_call_delta.id and tool_call_delta.function and tool_call_delta.function.name:
tool_calls_buffer.append(
(
tool_call_delta.id,
tool_call_delta.function.name,
io.StringIO(),
)
)
)
else:
logger.warning("工具调用索引号大于等于缓冲区长度但缺少ID或函数信息。")
if tool_call_delta.function.arguments:
if tool_call_delta.function and tool_call_delta.function.arguments:
# 如果有工具调用参数,则添加到对应的工具调用的参数串缓冲区中
tool_calls_buffer[tool_call_delta.index][2].write(tool_call_delta.function.arguments)
@@ -212,7 +214,7 @@ def _build_stream_api_resp(
raw_arg_data = arguments_buffer.getvalue()
arguments_buffer.close()
try:
arguments = json.loads(raw_arg_data)
arguments = json.loads(repair_json(raw_arg_data))
if not isinstance(arguments, dict):
raise RespParseException(
None,
@@ -235,7 +237,7 @@ def _build_stream_api_resp(
async def _default_stream_response_handler(
resp_stream: AsyncStream[ChatCompletionChunk],
interrupt_flag: asyncio.Event | None,
) -> tuple[APIResponse, tuple[int, int, int]]:
) -> tuple[APIResponse, Optional[tuple[int, int, int]]]:
"""
流式响应处理函数 - 处理OpenAI API的流式响应
:param resp_stream: 流式响应对象
@@ -309,7 +311,7 @@ pattern = re.compile(
def _default_normal_response_parser(
resp: ChatCompletion,
) -> tuple[APIResponse, tuple[int, int, int]]:
) -> tuple[APIResponse, Optional[tuple[int, int, int]]]:
"""
解析对话补全响应 - 将OpenAI API响应解析为APIResponse对象
:param resp: 响应对象
@@ -343,7 +345,7 @@ def _default_normal_response_parser(
api_response.tool_calls = []
for call in message_part.tool_calls:
try:
arguments = json.loads(call.function.arguments)
arguments = json.loads(repair_json(call.function.arguments))
if not isinstance(arguments, dict):
raise RespParseException(resp, "响应解析失败,工具调用参数无法解析为字典类型")
api_response.tool_calls.append(ToolCall(call.id, call.function.name, arguments))
@@ -384,26 +386,31 @@ class OpenaiClient(BaseClient):
max_tokens: int = 1024,
temperature: float = 0.7,
response_format: RespFormat | None = None,
stream_response_handler: Callable[
[AsyncStream[ChatCompletionChunk], asyncio.Event | None],
tuple[APIResponse, tuple[int, int, int]],
]
| None = None,
async_response_parser: Callable[[ChatCompletion], tuple[APIResponse, tuple[int, int, int]]] | None = None,
stream_response_handler: Optional[
Callable[
[AsyncStream[ChatCompletionChunk], asyncio.Event | None],
Coroutine[Any, Any, tuple[APIResponse, Optional[tuple[int, int, int]]]],
]
] = None,
async_response_parser: Optional[
Callable[[ChatCompletion], tuple[APIResponse, Optional[tuple[int, int, int]]]]
] = None,
interrupt_flag: asyncio.Event | None = None,
) -> APIResponse:
"""
获取对话响应
:param model_info: 模型信息
:param message_list: 对话体
:param tool_options: 工具选项可选默认为None
:param max_tokens: 最大token数(可选,默认为1024
:param temperature: 温度(可选,默认为0.7
:param response_format: 响应格式(可选,默认为 NotGiven
:param stream_response_handler: 流式响应处理函数可选默认为default_stream_response_handler
:param async_response_parser: 响应解析函数可选默认为default_response_parser
:param interrupt_flag: 中断信号量可选默认为None
:return: (响应文本, 推理文本, 工具调用, 其他数据)
Args:
model_info: 模型信息
message_list: 对话体
tool_options: 工具选项(可选,默认为None
max_tokens: 最大token数(可选,默认为1024
temperature: 温度可选默认为0.7
response_format: 响应格式(可选,默认为 NotGiven
stream_response_handler: 流式响应处理函数可选默认为default_stream_response_handler
async_response_parser: 响应解析函数可选默认为default_response_parser
interrupt_flag: 中断信号量可选默认为None
Returns:
(响应文本, 推理文本, 工具调用, 其他数据)
"""
if stream_response_handler is None:
stream_response_handler = _default_stream_response_handler
@@ -414,7 +421,7 @@ class OpenaiClient(BaseClient):
# 将messages构造为OpenAI API所需的格式
messages: Iterable[ChatCompletionMessageParam] = _convert_messages(message_list)
# 将tool_options转换为OpenAI API所需的格式
tools: Iterable[ChatCompletionToolParam] = _convert_tool_options(tool_options) if tool_options else NOT_GIVEN
tools: Iterable[ChatCompletionToolParam] = _convert_tool_options(tool_options) if tool_options else NOT_GIVEN # type: ignore
try:
if model_info.force_stream_mode:
@@ -426,7 +433,7 @@ class OpenaiClient(BaseClient):
temperature=temperature,
max_tokens=max_tokens,
stream=True,
response_format=response_format.to_dict() if response_format else NOT_GIVEN,
response_format=NOT_GIVEN,
)
)
while not req_task.done():
@@ -447,7 +454,7 @@ class OpenaiClient(BaseClient):
temperature=temperature,
max_tokens=max_tokens,
stream=False,
response_format=response_format.to_dict() if response_format else NOT_GIVEN,
response_format=NOT_GIVEN,
)
)
while not req_task.done():
@@ -514,9 +521,9 @@ class OpenaiClient(BaseClient):
response.usage = UsageRecord(
model_name=model_info.name,
provider_name=model_info.api_provider,
prompt_tokens=raw_response.usage.prompt_tokens,
completion_tokens=raw_response.usage.completion_tokens,
total_tokens=raw_response.usage.total_tokens,
prompt_tokens=raw_response.usage.prompt_tokens or 0,
completion_tokens=raw_response.usage.completion_tokens or 0, # type: ignore
total_tokens=raw_response.usage.total_tokens or 0,
)
return response