解决openai_client的lint问题
This commit is contained in:
@@ -3,7 +3,8 @@ import io
|
||||
import json
|
||||
import re
|
||||
from collections.abc import Iterable
|
||||
from typing import Callable, Any
|
||||
from typing import Callable, Any, Coroutine, Optional
|
||||
from json_repair import repair_json
|
||||
|
||||
from openai import (
|
||||
AsyncOpenAI,
|
||||
@@ -20,11 +21,9 @@ from openai.types.chat import (
|
||||
)
|
||||
from openai.types.chat.chat_completion_chunk import ChoiceDelta
|
||||
|
||||
from .base_client import APIResponse, UsageRecord
|
||||
from src.config.api_ada_configs import ModelInfo, APIProvider
|
||||
from .base_client import BaseClient, client_registry
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from .base_client import APIResponse, UsageRecord, BaseClient, client_registry
|
||||
from ..exceptions import (
|
||||
RespParseException,
|
||||
NetworkConnectionError,
|
||||
@@ -82,7 +81,7 @@ def _convert_messages(messages: list[Message]) -> list[ChatCompletionMessagePara
|
||||
raise ValueError("无法触及的代码:请使用MessageBuilder类构建消息对象")
|
||||
ret["tool_call_id"] = message.tool_call_id
|
||||
|
||||
return ret
|
||||
return ret # type: ignore
|
||||
|
||||
return [_convert_message_item(message) for message in messages]
|
||||
|
||||
@@ -143,10 +142,10 @@ def _process_delta(
|
||||
# 接收content
|
||||
if has_rc_attr_flag:
|
||||
# 有独立的推理内容块,则无需考虑content内容的判读
|
||||
if hasattr(delta, "reasoning_content") and delta.reasoning_content:
|
||||
if hasattr(delta, "reasoning_content") and delta.reasoning_content: # type: ignore
|
||||
# 如果有推理内容,则将其写入推理内容缓冲区
|
||||
assert isinstance(delta.reasoning_content, str)
|
||||
rc_delta_buffer.write(delta.reasoning_content)
|
||||
assert isinstance(delta.reasoning_content, str) # type: ignore
|
||||
rc_delta_buffer.write(delta.reasoning_content) # type: ignore
|
||||
elif delta.content:
|
||||
# 如果有正式内容,则将其写入正式内容缓冲区
|
||||
fc_delta_buffer.write(delta.content)
|
||||
@@ -173,15 +172,18 @@ def _process_delta(
|
||||
|
||||
if tool_call_delta.index >= len(tool_calls_buffer):
|
||||
# 调用索引号大于等于缓冲区长度,说明是新的工具调用
|
||||
tool_calls_buffer.append(
|
||||
(
|
||||
tool_call_delta.id,
|
||||
tool_call_delta.function.name,
|
||||
io.StringIO(),
|
||||
if tool_call_delta.id and tool_call_delta.function and tool_call_delta.function.name:
|
||||
tool_calls_buffer.append(
|
||||
(
|
||||
tool_call_delta.id,
|
||||
tool_call_delta.function.name,
|
||||
io.StringIO(),
|
||||
)
|
||||
)
|
||||
)
|
||||
else:
|
||||
logger.warning("工具调用索引号大于等于缓冲区长度,但缺少ID或函数信息。")
|
||||
|
||||
if tool_call_delta.function.arguments:
|
||||
if tool_call_delta.function and tool_call_delta.function.arguments:
|
||||
# 如果有工具调用参数,则添加到对应的工具调用的参数串缓冲区中
|
||||
tool_calls_buffer[tool_call_delta.index][2].write(tool_call_delta.function.arguments)
|
||||
|
||||
@@ -212,7 +214,7 @@ def _build_stream_api_resp(
|
||||
raw_arg_data = arguments_buffer.getvalue()
|
||||
arguments_buffer.close()
|
||||
try:
|
||||
arguments = json.loads(raw_arg_data)
|
||||
arguments = json.loads(repair_json(raw_arg_data))
|
||||
if not isinstance(arguments, dict):
|
||||
raise RespParseException(
|
||||
None,
|
||||
@@ -235,7 +237,7 @@ def _build_stream_api_resp(
|
||||
async def _default_stream_response_handler(
|
||||
resp_stream: AsyncStream[ChatCompletionChunk],
|
||||
interrupt_flag: asyncio.Event | None,
|
||||
) -> tuple[APIResponse, tuple[int, int, int]]:
|
||||
) -> tuple[APIResponse, Optional[tuple[int, int, int]]]:
|
||||
"""
|
||||
流式响应处理函数 - 处理OpenAI API的流式响应
|
||||
:param resp_stream: 流式响应对象
|
||||
@@ -309,7 +311,7 @@ pattern = re.compile(
|
||||
|
||||
def _default_normal_response_parser(
|
||||
resp: ChatCompletion,
|
||||
) -> tuple[APIResponse, tuple[int, int, int]]:
|
||||
) -> tuple[APIResponse, Optional[tuple[int, int, int]]]:
|
||||
"""
|
||||
解析对话补全响应 - 将OpenAI API响应解析为APIResponse对象
|
||||
:param resp: 响应对象
|
||||
@@ -343,7 +345,7 @@ def _default_normal_response_parser(
|
||||
api_response.tool_calls = []
|
||||
for call in message_part.tool_calls:
|
||||
try:
|
||||
arguments = json.loads(call.function.arguments)
|
||||
arguments = json.loads(repair_json(call.function.arguments))
|
||||
if not isinstance(arguments, dict):
|
||||
raise RespParseException(resp, "响应解析失败,工具调用参数无法解析为字典类型")
|
||||
api_response.tool_calls.append(ToolCall(call.id, call.function.name, arguments))
|
||||
@@ -384,26 +386,31 @@ class OpenaiClient(BaseClient):
|
||||
max_tokens: int = 1024,
|
||||
temperature: float = 0.7,
|
||||
response_format: RespFormat | None = None,
|
||||
stream_response_handler: Callable[
|
||||
[AsyncStream[ChatCompletionChunk], asyncio.Event | None],
|
||||
tuple[APIResponse, tuple[int, int, int]],
|
||||
]
|
||||
| None = None,
|
||||
async_response_parser: Callable[[ChatCompletion], tuple[APIResponse, tuple[int, int, int]]] | None = None,
|
||||
stream_response_handler: Optional[
|
||||
Callable[
|
||||
[AsyncStream[ChatCompletionChunk], asyncio.Event | None],
|
||||
Coroutine[Any, Any, tuple[APIResponse, Optional[tuple[int, int, int]]]],
|
||||
]
|
||||
] = None,
|
||||
async_response_parser: Optional[
|
||||
Callable[[ChatCompletion], tuple[APIResponse, Optional[tuple[int, int, int]]]]
|
||||
] = None,
|
||||
interrupt_flag: asyncio.Event | None = None,
|
||||
) -> APIResponse:
|
||||
"""
|
||||
获取对话响应
|
||||
:param model_info: 模型信息
|
||||
:param message_list: 对话体
|
||||
:param tool_options: 工具选项(可选,默认为None)
|
||||
:param max_tokens: 最大token数(可选,默认为1024)
|
||||
:param temperature: 温度(可选,默认为0.7)
|
||||
:param response_format: 响应格式(可选,默认为 NotGiven )
|
||||
:param stream_response_handler: 流式响应处理函数(可选,默认为default_stream_response_handler)
|
||||
:param async_response_parser: 响应解析函数(可选,默认为default_response_parser)
|
||||
:param interrupt_flag: 中断信号量(可选,默认为None)
|
||||
:return: (响应文本, 推理文本, 工具调用, 其他数据)
|
||||
Args:
|
||||
model_info: 模型信息
|
||||
message_list: 对话体
|
||||
tool_options: 工具选项(可选,默认为None)
|
||||
max_tokens: 最大token数(可选,默认为1024)
|
||||
temperature: 温度(可选,默认为0.7)
|
||||
response_format: 响应格式(可选,默认为 NotGiven )
|
||||
stream_response_handler: 流式响应处理函数(可选,默认为default_stream_response_handler)
|
||||
async_response_parser: 响应解析函数(可选,默认为default_response_parser)
|
||||
interrupt_flag: 中断信号量(可选,默认为None)
|
||||
Returns:
|
||||
(响应文本, 推理文本, 工具调用, 其他数据)
|
||||
"""
|
||||
if stream_response_handler is None:
|
||||
stream_response_handler = _default_stream_response_handler
|
||||
@@ -414,7 +421,7 @@ class OpenaiClient(BaseClient):
|
||||
# 将messages构造为OpenAI API所需的格式
|
||||
messages: Iterable[ChatCompletionMessageParam] = _convert_messages(message_list)
|
||||
# 将tool_options转换为OpenAI API所需的格式
|
||||
tools: Iterable[ChatCompletionToolParam] = _convert_tool_options(tool_options) if tool_options else NOT_GIVEN
|
||||
tools: Iterable[ChatCompletionToolParam] = _convert_tool_options(tool_options) if tool_options else NOT_GIVEN # type: ignore
|
||||
|
||||
try:
|
||||
if model_info.force_stream_mode:
|
||||
@@ -426,7 +433,7 @@ class OpenaiClient(BaseClient):
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
stream=True,
|
||||
response_format=response_format.to_dict() if response_format else NOT_GIVEN,
|
||||
response_format=NOT_GIVEN,
|
||||
)
|
||||
)
|
||||
while not req_task.done():
|
||||
@@ -447,7 +454,7 @@ class OpenaiClient(BaseClient):
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
stream=False,
|
||||
response_format=response_format.to_dict() if response_format else NOT_GIVEN,
|
||||
response_format=NOT_GIVEN,
|
||||
)
|
||||
)
|
||||
while not req_task.done():
|
||||
@@ -514,9 +521,9 @@ class OpenaiClient(BaseClient):
|
||||
response.usage = UsageRecord(
|
||||
model_name=model_info.name,
|
||||
provider_name=model_info.api_provider,
|
||||
prompt_tokens=raw_response.usage.prompt_tokens,
|
||||
completion_tokens=raw_response.usage.completion_tokens,
|
||||
total_tokens=raw_response.usage.total_tokens,
|
||||
prompt_tokens=raw_response.usage.prompt_tokens or 0,
|
||||
completion_tokens=raw_response.usage.completion_tokens or 0, # type: ignore
|
||||
total_tokens=raw_response.usage.total_tokens or 0,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
Reference in New Issue
Block a user