Files
mai-bot/src/llm_models/model_client/openai_client.py
DrSmoothl 0a08973c41 feat: Enhance emoji and image management with asynchronous background processing
- Added support for scheduling background tasks to build emoji and image descriptions when not found in cache.
- Improved error handling and logging for emoji and image processing.
- Updated `SessionMessage` processing to allow for optional heavy media analysis and voice transcription.
- Refactored logging messages for better clarity and consistency across various modules.
- Introduced a new function to build outbound log previews for messages, enhancing logging capabilities.
2026-03-26 23:03:47 +08:00

1001 lines
36 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from collections.abc import Iterable
from dataclasses import dataclass, field
from typing import Any, Callable, Coroutine, Dict, List, Tuple, cast
import asyncio
import base64
import io
import json
import re
from json_repair import repair_json
from openai import APIConnectionError, APIStatusError, AsyncOpenAI, AsyncStream
from openai._types import FileTypes, Omit, omit
from openai.types.chat import (
ChatCompletion,
ChatCompletionAssistantMessageParam,
ChatCompletionChunk,
ChatCompletionContentPartImageParam,
ChatCompletionContentPartParam,
ChatCompletionContentPartTextParam,
ChatCompletionMessageFunctionToolCallParam,
ChatCompletionMessageParam,
ChatCompletionSystemMessageParam,
ChatCompletionToolMessageParam,
ChatCompletionToolParam,
ChatCompletionUserMessageParam,
)
from openai.types.shared_params.function_definition import FunctionDefinition
from openai.types.chat.chat_completion_chunk import ChoiceDelta
from src.common.logger import get_logger
from src.config.model_configs import APIProvider, ReasoningParseMode, ToolArgumentParseMode
from src.llm_models.exceptions import (
EmptyResponseException,
NetworkConnectionError,
ReqAbortException,
RespNotOkException,
RespParseException,
)
from src.llm_models.openai_compat import (
build_openai_compatible_client_config,
split_openai_request_overrides,
)
from src.llm_models.payload_content.message import ImageMessagePart, Message, RoleType, TextMessagePart
from src.llm_models.payload_content.resp_format import RespFormat, RespFormatType
from src.llm_models.payload_content.tool_option import ToolCall, ToolOption
from .adapter_base import (
AdapterClient,
ProviderResponseParser,
ProviderStreamResponseHandler,
await_task_with_interrupt,
)
from .base_client import (
APIResponse,
AudioTranscriptionRequest,
EmbeddingRequest,
ResponseRequest,
UsageTuple,
client_registry,
)
logger = get_logger("llm_models")
THINK_CONTENT_PATTERN = re.compile(
r"<think>(?P<think>.*?)</think>(?P<content>.*)|<think>(?P<think_unclosed>.*)|(?P<content_only>.+)",
re.DOTALL,
)
"""用于解析 `<think>` 推理块的正则表达式。"""
CHAT_COMPLETIONS_RESERVED_EXTRA_BODY_KEYS = {
"max_tokens",
"messages",
"model",
"response_format",
"stream",
"temperature",
"tools",
}
"""由当前客户端显式承载、不应再落入 `extra_body` 的字段集合。"""
OpenAIStreamResponseHandler = Callable[
[AsyncStream[ChatCompletionChunk], asyncio.Event | None],
Coroutine[Any, Any, Tuple[APIResponse, UsageTuple | None]],
]
"""OpenAI 流式响应处理函数类型。"""
OpenAIResponseParser = Callable[[ChatCompletion], Tuple[APIResponse, UsageTuple | None]]
"""OpenAI 非流式响应解析函数类型。"""
def _normalize_reasoning_parse_mode(parse_mode: str | ReasoningParseMode) -> ReasoningParseMode:
"""将配置中的推理解析模式收敛为枚举值。
Args:
parse_mode: 原始解析模式配置。
Returns:
ReasoningParseMode: 规范化后的解析模式;未知值会回退为 `AUTO`。
"""
if isinstance(parse_mode, ReasoningParseMode):
return parse_mode
try:
return ReasoningParseMode(parse_mode)
except ValueError:
logger.warning(f"未识别的推理解析模式 {parse_mode},已回退为 auto")
return ReasoningParseMode.AUTO
def _normalize_tool_argument_parse_mode(parse_mode: str | ToolArgumentParseMode) -> ToolArgumentParseMode:
"""将配置中的工具参数解析模式收敛为枚举值。
Args:
parse_mode: 原始解析模式配置。
Returns:
ToolArgumentParseMode: 规范化后的解析模式;未知值会回退为 `AUTO`。
"""
if isinstance(parse_mode, ToolArgumentParseMode):
return parse_mode
try:
return ToolArgumentParseMode(parse_mode)
except ValueError:
logger.warning(f"未识别的工具参数解析模式 {parse_mode},已回退为 auto")
return ToolArgumentParseMode.AUTO
def _build_text_content_part(text: str) -> ChatCompletionContentPartTextParam:
"""构建文本内容片段。
Args:
text: 文本内容。
Returns:
ChatCompletionContentPartTextParam: OpenAI 兼容的文本片段。
"""
return {
"type": "text",
"text": text,
}
def _build_image_content_part(part: ImageMessagePart) -> ChatCompletionContentPartImageParam:
"""构建图片内容片段。
Args:
part: 内部图片片段。
Returns:
ChatCompletionContentPartImageParam: OpenAI 兼容的图片片段。
"""
return {
"type": "image_url",
"image_url": {
"url": f"data:image/{part.normalized_image_format};base64,{part.image_base64}",
},
}
def _convert_response_format(response_format: RespFormat | None) -> Any:
"""将内部响应格式转换为 OpenAI 兼容结构。
Args:
response_format: 内部响应格式定义。
Returns:
Any: OpenAI SDK 可接受的响应格式参数;未指定时返回 `omit`。
"""
if response_format is None or response_format.format_type == RespFormatType.TEXT:
return omit
if response_format.format_type == RespFormatType.JSON_OBJ:
return {"type": "json_object"}
if response_format.format_type == RespFormatType.JSON_SCHEMA:
return {
"type": "json_schema",
"json_schema": response_format.schema,
}
return omit
def _convert_text_only_message_content(
message: Message,
) -> str | List[ChatCompletionContentPartTextParam]:
"""将仅允许文本的消息转换为 OpenAI 兼容内容。
Args:
message: 内部统一消息对象。
Returns:
str | List[ChatCompletionContentPartTextParam]: 文本内容结构。
Raises:
ValueError: 当消息中包含非文本片段时抛出。
"""
if not message.parts:
return ""
if len(message.parts) == 1 and isinstance(message.parts[0], TextMessagePart):
return message.parts[0].text
content: List[ChatCompletionContentPartTextParam] = []
for part in message.parts:
if not isinstance(part, TextMessagePart):
raise ValueError(f"{message.role.value} 消息仅支持文本片段")
content.append(_build_text_content_part(part.text))
return content
def _convert_user_message_content(message: Message) -> str | List[ChatCompletionContentPartParam]:
"""将用户消息转换为 OpenAI 兼容内容。
Args:
message: 内部统一消息对象。
Returns:
str | List[ChatCompletionContentPartParam]: 用户消息内容结构。
"""
if len(message.parts) == 1 and isinstance(message.parts[0], TextMessagePart):
return message.parts[0].text
content: List[ChatCompletionContentPartParam] = []
for part in message.parts:
if isinstance(part, TextMessagePart):
content.append(_build_text_content_part(part.text))
continue
content.append(_build_image_content_part(part))
return content
def _convert_assistant_tool_calls(tool_calls: List[ToolCall]) -> List[ChatCompletionMessageFunctionToolCallParam]:
"""将内部工具调用转换为 OpenAI assistant tool_calls 结构。
Args:
tool_calls: 内部工具调用列表。
Returns:
List[ChatCompletionMessageFunctionToolCallParam]: OpenAI 兼容工具调用结构。
"""
converted_tool_calls: List[ChatCompletionMessageFunctionToolCallParam] = []
for tool_call in tool_calls:
converted_tool_calls.append(
{
"id": tool_call.call_id,
"type": "function",
"function": {
"name": tool_call.func_name,
"arguments": json.dumps(tool_call.args or {}, ensure_ascii=False),
},
}
)
return converted_tool_calls
def _convert_messages(messages: List[Message]) -> List[ChatCompletionMessageParam]:
"""将内部消息列表转换为 OpenAI 兼容消息列表。
Args:
messages: 内部统一消息列表。
Returns:
List[ChatCompletionMessageParam]: OpenAI SDK 所需的消息结构列表。
"""
converted_messages: List[ChatCompletionMessageParam] = []
for message in messages:
if message.role == RoleType.System:
system_payload: ChatCompletionSystemMessageParam = {
"role": "system",
"content": _convert_text_only_message_content(message),
}
converted_messages.append(system_payload)
continue
if message.role == RoleType.User:
user_payload: ChatCompletionUserMessageParam = {
"role": "user",
"content": _convert_user_message_content(message),
}
converted_messages.append(user_payload)
continue
if message.role == RoleType.Assistant:
assistant_payload: ChatCompletionAssistantMessageParam = {
"role": "assistant",
"content": None if not message.parts and message.tool_calls else _convert_text_only_message_content(message),
}
if message.tool_calls:
assistant_payload["tool_calls"] = _convert_assistant_tool_calls(message.tool_calls)
converted_messages.append(assistant_payload)
continue
if message.role == RoleType.Tool:
if message.tool_call_id is None:
raise ValueError("Tool 消息缺少 tool_call_id")
tool_payload: ChatCompletionToolMessageParam = {
"role": "tool",
"content": _convert_text_only_message_content(message),
"tool_call_id": message.tool_call_id,
}
converted_messages.append(tool_payload)
continue
raise ValueError(f"不支持的消息角色:{message.role}")
return converted_messages
def _convert_tool_options(tool_options: List[ToolOption]) -> List[ChatCompletionToolParam]:
"""将工具定义转换为 OpenAI 兼容的工具列表。
Args:
tool_options: 内部统一工具定义列表。
Returns:
List[ChatCompletionToolParam]: OpenAI SDK 所需的工具定义列表。
"""
converted_tools: List[ChatCompletionToolParam] = []
for tool_option in tool_options:
function_schema: FunctionDefinition = {
"name": tool_option.name,
"description": tool_option.description,
}
parameters_schema = tool_option.parameters_schema
if parameters_schema is not None:
function_schema["parameters"] = cast(Dict[str, object], parameters_schema)
converted_tools.append(
{
"type": "function",
"function": function_schema,
}
)
return converted_tools
def _extract_usage_record(usage: Any) -> UsageTuple | None:
"""从响应对象中提取 usage 三元组。
Args:
usage: OpenAI SDK 返回的 usage 对象。
Returns:
UsageTuple | None: `(prompt_tokens, completion_tokens, total_tokens)`。
"""
if usage is None:
return None
return (
getattr(usage, "prompt_tokens", 0) or 0,
getattr(usage, "completion_tokens", 0) or 0,
getattr(usage, "total_tokens", 0) or 0,
)
def _parse_tool_arguments(
raw_arguments: str,
parse_mode: ToolArgumentParseMode,
response: Any,
) -> Dict[str, Any]:
"""解析工具调用参数字符串。
Args:
raw_arguments: 工具调用参数原始字符串。
parse_mode: 参数解析模式。
response: 原始响应对象,用于异常上下文。
Returns:
Dict[str, Any]: 解析后的参数字典。
Raises:
RespParseException: 当参数无法解析为字典时抛出。
"""
try:
if parse_mode == ToolArgumentParseMode.STRICT:
arguments: Any = json.loads(raw_arguments)
elif parse_mode == ToolArgumentParseMode.REPAIR:
arguments = repair_json(raw_arguments, return_objects=True, logging=False)
else:
arguments = repair_json(raw_arguments, return_objects=True, logging=False)
if isinstance(arguments, str) and parse_mode in {
ToolArgumentParseMode.AUTO,
ToolArgumentParseMode.DOUBLE_DECODE,
}:
arguments = repair_json(arguments, return_objects=True, logging=False)
except json.JSONDecodeError as exc:
raise RespParseException(response, f"响应解析失败,无法解析工具调用参数。原始参数:{raw_arguments}") from exc
if not isinstance(arguments, dict):
raise RespParseException(
response,
f"响应解析失败,工具调用参数必须解析为字典,实际类型为 {type(arguments).__name__}",
)
return arguments
def _extract_reasoning_and_content(
content: str,
parse_mode: ReasoningParseMode,
) -> Tuple[str | None, str | None]:
"""从文本内容中提取推理内容与正式输出。
Args:
content: 模型返回的文本内容。
parse_mode: 推理解析模式。
Returns:
Tuple[str | None, str | None]: `(reasoning_content, content)`。
"""
if parse_mode in {ReasoningParseMode.NATIVE, ReasoningParseMode.NONE}:
return None, content
match = THINK_CONTENT_PATTERN.match(content)
if not match:
return None, content
if match.group("think") is not None:
reasoning_content = match.group("think").strip() or None
final_content = match.group("content").strip() or None
return reasoning_content, final_content
if match.group("think_unclosed") is not None:
return match.group("think_unclosed").strip() or None, None
return None, match.group("content_only").strip() or None
def _log_length_truncation(finish_reason: str | None, model_name: str | None) -> None:
"""记录因长度截断导致的告警日志。
Args:
finish_reason: OpenAI 兼容接口返回的完成原因。
model_name: 上游返回的模型标识。
"""
if finish_reason == "length":
logger.info(f"模型{model_name or ''}因为超过最大 max_token 限制,可能仅输出部分内容,可视情况调整")
def _coerce_openai_argument(value: Any) -> Any | Omit:
"""将可选参数转换为 OpenAI SDK 期望的值。
Args:
value: 原始参数值。
Returns:
Any | Omit: `None` 会被转换为 `omit`,其余值原样返回。
"""
if value is None:
return omit
return value
def _build_api_status_message(error: APIStatusError) -> str:
"""构建更适合记录和展示的状态错误信息。
Args:
error: OpenAI SDK 抛出的状态错误。
Returns:
str: 拼装后的错误信息。
"""
message_parts: List[str] = []
if getattr(error, "message", None):
message_parts.append(str(error.message))
response_text = getattr(getattr(error, "response", None), "text", None)
if response_text:
message_parts.append(str(response_text)[:300])
if message_parts:
return " | ".join(message_parts)
return f"上游接口返回状态码 {error.status_code}"
@dataclass(slots=True)
class _StreamedToolCallState:
"""流式工具调用累积状态。"""
index: int
call_id: str = ""
function_name: str = ""
arguments_buffer: io.StringIO = field(default_factory=io.StringIO)
def append_arguments(self, arguments_chunk: str) -> None:
"""追加一段工具调用参数字符串。
Args:
arguments_chunk: 参数增量片段。
"""
self.arguments_buffer.write(arguments_chunk)
def close(self) -> None:
"""关闭内部缓存。"""
if not self.arguments_buffer.closed:
self.arguments_buffer.close()
class _OpenAIStreamAccumulator:
"""OpenAI 兼容流式响应累积器。"""
def __init__(
self,
reasoning_parse_mode: ReasoningParseMode,
tool_argument_parse_mode: ToolArgumentParseMode,
) -> None:
"""初始化累积器。
Args:
reasoning_parse_mode: 推理内容解析模式。
tool_argument_parse_mode: 工具参数解析模式。
"""
self.reasoning_parse_mode = reasoning_parse_mode
self.tool_argument_parse_mode = tool_argument_parse_mode
self.reasoning_buffer = io.StringIO()
self.content_buffer = io.StringIO()
self.tool_call_states: Dict[int, _StreamedToolCallState] = {}
self.finish_reason: str | None = None
self.model_name: str | None = None
self._using_native_reasoning = False
def capture_event_metadata(self, event: ChatCompletionChunk) -> None:
"""捕获事件中的完成原因和模型名。
Args:
event: 当前流式事件。
"""
if getattr(event, "model", None) and not self.model_name:
self.model_name = event.model
if getattr(event, "choices", None):
finish_reason = getattr(event.choices[0], "finish_reason", None)
if finish_reason:
self.finish_reason = finish_reason
def process_delta(self, delta: ChoiceDelta) -> None:
"""处理一个增量块。
Args:
delta: 当前增量对象。
"""
self._process_reasoning_delta(delta)
self._process_tool_call_delta(delta)
def _process_reasoning_delta(self, delta: ChoiceDelta) -> None:
"""处理推理内容与正式内容。
Args:
delta: 当前增量对象。
"""
native_reasoning = getattr(delta, "reasoning_content", None)
if isinstance(native_reasoning, str) and native_reasoning:
self._using_native_reasoning = True
if self.reasoning_parse_mode != ReasoningParseMode.NONE:
self.reasoning_buffer.write(native_reasoning)
return
content_chunk = getattr(delta, "content", None)
if not isinstance(content_chunk, str) or content_chunk == "":
return
if self.reasoning_parse_mode == ReasoningParseMode.NONE:
self.content_buffer.write(content_chunk)
return
if self.reasoning_parse_mode == ReasoningParseMode.NATIVE:
self.content_buffer.write(content_chunk)
return
self.content_buffer.write(content_chunk)
def _process_tool_call_delta(self, delta: ChoiceDelta) -> None:
"""处理工具调用增量。
Args:
delta: 当前增量对象。
"""
tool_call_deltas = getattr(delta, "tool_calls", None) or []
for tool_call_delta in tool_call_deltas:
state = self.tool_call_states.setdefault(tool_call_delta.index, _StreamedToolCallState(index=tool_call_delta.index))
if tool_call_delta.id:
state.call_id = tool_call_delta.id
function = tool_call_delta.function
if function is not None and function.name:
state.function_name = function.name
if function is not None and function.arguments:
state.append_arguments(function.arguments)
def build_response(self) -> APIResponse:
"""构建最终 APIResponse 对象。
Returns:
APIResponse: 累积完成的响应对象。
Raises:
EmptyResponseException: 当响应中既无可见内容也无工具调用时抛出。
RespParseException: 当工具调用结构不完整时抛出。
"""
response = APIResponse()
content = self.content_buffer.getvalue().strip()
reasoning_content = self.reasoning_buffer.getvalue().strip()
if not self._using_native_reasoning and self.reasoning_parse_mode != ReasoningParseMode.NONE and content:
parsed_reasoning_content, parsed_content = _extract_reasoning_and_content(
content=content,
parse_mode=self.reasoning_parse_mode,
)
if parsed_reasoning_content:
reasoning_content = parsed_reasoning_content
content = parsed_content or ""
if reasoning_content:
response.reasoning_content = reasoning_content
if content:
response.content = content
if self.tool_call_states:
response.tool_calls = []
for index in sorted(self.tool_call_states):
state = self.tool_call_states[index]
if not state.function_name:
raise RespParseException(None, f"响应解析失败,工具调用 {index} 缺少函数名。")
raw_arguments = state.arguments_buffer.getvalue().strip()
arguments = (
_parse_tool_arguments(raw_arguments, self.tool_argument_parse_mode, None)
if raw_arguments
else None
)
call_id = state.call_id or f"tool_call_{index}"
response.tool_calls.append(ToolCall(call_id=call_id, func_name=state.function_name, args=arguments))
response.raw_data = {"model": self.model_name} if self.model_name else None
if not response.content and not response.tool_calls:
raise EmptyResponseException()
return response
def close(self) -> None:
"""关闭内部缓冲区。"""
if not self.reasoning_buffer.closed:
self.reasoning_buffer.close()
if not self.content_buffer.closed:
self.content_buffer.close()
for state in self.tool_call_states.values():
state.close()
async def _default_stream_response_handler(
resp_stream: AsyncStream[ChatCompletionChunk],
interrupt_flag: asyncio.Event | None,
*,
reasoning_parse_mode: ReasoningParseMode,
tool_argument_parse_mode: ToolArgumentParseMode,
) -> Tuple[APIResponse, UsageTuple | None]:
"""处理 OpenAI 兼容流式响应。
Args:
resp_stream: OpenAI SDK 返回的流式响应对象。
interrupt_flag: 外部中断标记。
reasoning_parse_mode: 推理内容解析模式。
tool_argument_parse_mode: 工具参数解析模式。
Returns:
Tuple[APIResponse, UsageTuple | None]: 解析后的响应与 usage 统计。
"""
accumulator = _OpenAIStreamAccumulator(
reasoning_parse_mode=reasoning_parse_mode,
tool_argument_parse_mode=tool_argument_parse_mode,
)
usage_record: UsageTuple | None = None
try:
async for event in resp_stream:
if interrupt_flag and interrupt_flag.is_set():
raise ReqAbortException("请求被外部信号中断")
accumulator.capture_event_metadata(event)
event_usage = _extract_usage_record(getattr(event, "usage", None))
if event_usage is not None:
usage_record = event_usage
if not getattr(event, "choices", None):
continue
accumulator.process_delta(event.choices[0].delta)
response = accumulator.build_response()
model_name = None
if isinstance(response.raw_data, dict):
model_name = response.raw_data.get("model")
_log_length_truncation(accumulator.finish_reason, model_name)
return response, usage_record
finally:
accumulator.close()
def _default_normal_response_parser(
resp: ChatCompletion,
*,
reasoning_parse_mode: ReasoningParseMode,
tool_argument_parse_mode: ToolArgumentParseMode,
) -> Tuple[APIResponse, UsageTuple | None]:
"""解析 OpenAI 兼容的非流式响应。
Args:
resp: OpenAI SDK 返回的聊天补全响应。
reasoning_parse_mode: 推理内容解析模式。
tool_argument_parse_mode: 工具参数解析模式。
Returns:
Tuple[APIResponse, UsageTuple | None]: 解析后的响应与 usage 统计。
Raises:
EmptyResponseException: 当 choices 为空或响应内容为空时抛出。
"""
choices = getattr(resp, "choices", None)
if not choices:
raise EmptyResponseException("响应解析失败choices 为空或缺失")
api_response = APIResponse()
message_part = choices[0].message
native_reasoning = getattr(message_part, "reasoning_content", None)
message_content = message_part.content if isinstance(message_part.content, str) else None
if isinstance(native_reasoning, str) and native_reasoning and reasoning_parse_mode != ReasoningParseMode.NONE:
api_response.reasoning_content = native_reasoning
api_response.content = message_content
elif isinstance(message_content, str) and message_content:
reasoning_content, final_content = _extract_reasoning_and_content(
content=message_content,
parse_mode=reasoning_parse_mode,
)
api_response.reasoning_content = reasoning_content
api_response.content = final_content
tool_calls = getattr(message_part, "tool_calls", None) or []
if tool_calls:
api_response.tool_calls = []
for tool_call in tool_calls:
if tool_call.type != "function":
raise RespParseException(resp, f"响应解析失败,暂不支持工具调用类型 {tool_call.type}")
raw_arguments = tool_call.function.arguments or ""
arguments = _parse_tool_arguments(raw_arguments, tool_argument_parse_mode, resp)
api_response.tool_calls.append(
ToolCall(
call_id=tool_call.id,
func_name=tool_call.function.name,
args=arguments,
)
)
usage_record = _extract_usage_record(getattr(resp, "usage", None))
api_response.raw_data = resp
finish_reason = getattr(resp.choices[0], "finish_reason", None)
_log_length_truncation(finish_reason, getattr(resp, "model", None))
if not api_response.content and not api_response.tool_calls:
raise EmptyResponseException()
return api_response, usage_record
@client_registry.register_client_class("openai")
class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletion]):
"""OpenAI 兼容客户端。"""
client: AsyncOpenAI
reasoning_parse_mode: ReasoningParseMode
tool_argument_parse_mode: ToolArgumentParseMode
def __init__(self, api_provider: APIProvider) -> None:
"""初始化 OpenAI 兼容客户端。
Args:
api_provider: API 提供商配置。
"""
super().__init__(api_provider)
client_config = build_openai_compatible_client_config(api_provider)
self.reasoning_parse_mode = _normalize_reasoning_parse_mode(api_provider.reasoning_parse_mode)
self.tool_argument_parse_mode = _normalize_tool_argument_parse_mode(api_provider.tool_argument_parse_mode)
self.client = AsyncOpenAI(
api_key=client_config.api_key,
organization=api_provider.organization,
project=api_provider.project,
base_url=client_config.base_url,
timeout=api_provider.timeout,
max_retries=api_provider.max_retry,
default_headers=client_config.default_headers or None,
default_query=client_config.default_query or None,
)
def _build_default_stream_response_handler(
self,
request: ResponseRequest,
) -> ProviderStreamResponseHandler[AsyncStream[ChatCompletionChunk]]:
"""构建 OpenAI 默认流式响应处理器。
Args:
request: 统一响应请求对象。
Returns:
ProviderStreamResponseHandler[AsyncStream[ChatCompletionChunk]]: 默认流式处理器。
"""
del request
async def default_stream_handler(
resp_stream: AsyncStream[ChatCompletionChunk],
flag: asyncio.Event | None,
) -> Tuple[APIResponse, UsageTuple | None]:
"""包装默认流式解析器。"""
return await _default_stream_response_handler(
resp_stream,
flag,
reasoning_parse_mode=self.reasoning_parse_mode,
tool_argument_parse_mode=self.tool_argument_parse_mode,
)
return default_stream_handler
def _build_default_response_parser(
self,
request: ResponseRequest,
) -> ProviderResponseParser[ChatCompletion]:
"""构建 OpenAI 默认非流式响应解析器。
Args:
request: 统一响应请求对象。
Returns:
ProviderResponseParser[ChatCompletion]: 默认非流式解析器。
"""
del request
def default_response_parser(
response: ChatCompletion,
) -> Tuple[APIResponse, UsageTuple | None]:
"""包装默认非流式解析器。"""
return _default_normal_response_parser(
response,
reasoning_parse_mode=self.reasoning_parse_mode,
tool_argument_parse_mode=self.tool_argument_parse_mode,
)
return default_response_parser
async def _execute_response_request(
self,
request: ResponseRequest,
stream_response_handler: ProviderStreamResponseHandler[AsyncStream[ChatCompletionChunk]],
response_parser: ProviderResponseParser[ChatCompletion],
) -> Tuple[APIResponse, UsageTuple | None]:
"""执行 OpenAI 兼容的文本/多模态响应请求。
Args:
request: 统一响应请求对象。
stream_response_handler: 流式响应处理器。
response_parser: 非流式响应解析器。
Returns:
Tuple[APIResponse, UsageTuple | None]: 统一响应对象与可选使用量信息。
"""
model_info = request.model_info
messages: Iterable[ChatCompletionMessageParam] = _convert_messages(request.message_list)
tools: Iterable[ChatCompletionToolParam] | Omit = (
_convert_tool_options(request.tool_options) if request.tool_options else omit
)
openai_response_format = _convert_response_format(request.response_format)
request_overrides = split_openai_request_overrides(
request.extra_params,
reserved_body_keys=CHAT_COMPLETIONS_RESERVED_EXTRA_BODY_KEYS,
)
temperature_argument = (
omit if "temperature" in request_overrides.extra_body else _coerce_openai_argument(request.temperature)
)
max_tokens_argument = (
omit
if "max_tokens" in request_overrides.extra_body or "max_completion_tokens" in request_overrides.extra_body
else _coerce_openai_argument(request.max_tokens)
)
try:
if model_info.force_stream_mode:
stream_task: asyncio.Task[AsyncStream[ChatCompletionChunk]] = asyncio.create_task(
self.client.chat.completions.create(
model=model_info.model_identifier,
messages=messages,
tools=tools,
temperature=temperature_argument,
max_tokens=max_tokens_argument,
stream=True,
response_format=openai_response_format,
extra_headers=request_overrides.extra_headers or None,
extra_query=request_overrides.extra_query or None,
extra_body=request_overrides.extra_body or None,
)
)
raw_response = cast(
AsyncStream[ChatCompletionChunk],
await await_task_with_interrupt(stream_task, request.interrupt_flag),
)
return await stream_response_handler(raw_response, request.interrupt_flag)
completion_task: asyncio.Task[ChatCompletion] = asyncio.create_task(
self.client.chat.completions.create(
model=model_info.model_identifier,
messages=messages,
tools=tools,
temperature=temperature_argument,
max_tokens=max_tokens_argument,
stream=False,
response_format=openai_response_format,
extra_headers=request_overrides.extra_headers or None,
extra_query=request_overrides.extra_query or None,
extra_body=request_overrides.extra_body or None,
)
)
raw_response = cast(
ChatCompletion,
await await_task_with_interrupt(completion_task, request.interrupt_flag),
)
return response_parser(raw_response)
except APIConnectionError as exc:
raise NetworkConnectionError(str(exc)) from exc
except APIStatusError as exc:
raise RespNotOkException(exc.status_code, _build_api_status_message(exc)) from exc
async def _execute_embedding_request(
self,
request: EmbeddingRequest,
) -> Tuple[APIResponse, UsageTuple | None]:
"""执行 OpenAI 兼容的文本嵌入请求。
Args:
request: 统一嵌入请求对象。
Returns:
Tuple[APIResponse, UsageTuple | None]: 统一响应对象与可选使用量信息。
"""
model_info = request.model_info
embedding_input = request.embedding_input
extra_params = request.extra_params
request_overrides = split_openai_request_overrides(extra_params)
try:
raw_response = await self.client.embeddings.create(
model=model_info.model_identifier,
input=embedding_input,
extra_headers=request_overrides.extra_headers or None,
extra_query=request_overrides.extra_query or None,
extra_body=request_overrides.extra_body or None,
)
except APIConnectionError as exc:
raise NetworkConnectionError(str(exc)) from exc
except APIStatusError as exc:
raise RespNotOkException(exc.status_code, _build_api_status_message(exc)) from exc
response = APIResponse()
if raw_response.data:
response.embedding = raw_response.data[0].embedding
else:
raise RespParseException(raw_response, "响应解析失败,缺失嵌入数据。")
usage_record = _extract_usage_record(getattr(raw_response, "usage", None))
return response, usage_record
async def _execute_audio_transcription_request(
self,
request: AudioTranscriptionRequest,
) -> Tuple[APIResponse, UsageTuple | None]:
"""执行 OpenAI 兼容的音频转录请求。
Args:
request: 统一音频转录请求对象。
Returns:
Tuple[APIResponse, UsageTuple | None]: 统一响应对象与可选使用量信息。
"""
model_info = request.model_info
audio_base64 = request.audio_base64
extra_params = request.extra_params
request_overrides = split_openai_request_overrides(extra_params)
audio_file: FileTypes = ("audio.wav", io.BytesIO(base64.b64decode(audio_base64)))
try:
raw_response = await self.client.audio.transcriptions.create(
model=model_info.model_identifier,
file=audio_file,
extra_headers=request_overrides.extra_headers or None,
extra_query=request_overrides.extra_query or None,
extra_body=request_overrides.extra_body or None,
)
except APIConnectionError as exc:
raise NetworkConnectionError(str(exc)) from exc
except APIStatusError as exc:
raise RespNotOkException(exc.status_code, _build_api_status_message(exc)) from exc
response = APIResponse()
transcription_text = raw_response if isinstance(raw_response, str) else getattr(raw_response, "text", None)
if isinstance(transcription_text, str):
response.content = transcription_text
return response, None
raise RespParseException(raw_response, "响应解析失败,缺失转录文本。")
def get_support_image_formats(self) -> List[str]:
"""获取支持的图片格式列表。
Returns:
List[str]: 当前客户端支持的图片格式列表。
"""
return ["jpg", "jpeg", "png", "webp", "gif"]