import asyncio import base64 import binascii import io import json import re from dataclasses import dataclass, field from typing import Any, Callable, Coroutine, Dict, List, Tuple, cast from json_repair import repair_json from openai import APIConnectionError, APIStatusError, AsyncOpenAI, AsyncStream from openai._types import FileTypes, Omit, omit from openai.types.chat import ( ChatCompletion, ChatCompletionAssistantMessageParam, ChatCompletionChunk, ChatCompletionContentPartImageParam, ChatCompletionContentPartParam, ChatCompletionContentPartTextParam, ChatCompletionMessageFunctionToolCallParam, ChatCompletionMessageParam, ChatCompletionSystemMessageParam, ChatCompletionToolMessageParam, ChatCompletionToolParam, ChatCompletionUserMessageParam, ) from openai.types.shared_params.function_definition import FunctionDefinition from openai.types.chat.chat_completion_chunk import ChoiceDelta from PIL import Image as PILImage from src.common.logger import get_logger from src.config.model_configs import APIProvider, ReasoningParseMode, ToolArgumentParseMode from src.llm_models.exceptions import ( EmptyResponseException, NetworkConnectionError, ReqAbortException, RespNotOkException, RespParseException, ) from src.llm_models.openai_compat import ( build_openai_compatible_client_config, split_openai_request_overrides, ) from src.llm_models.payload_content.message import ImageMessagePart, Message, RoleType, TextMessagePart from src.llm_models.payload_content.resp_format import RespFormat, RespFormatType from src.llm_models.payload_content.tool_option import ToolCall, ToolOption from .adapter_base import ( AdapterClient, ProviderResponseParser, ProviderStreamResponseHandler, await_task_with_interrupt, ) from .base_client import ( APIResponse, AudioTranscriptionRequest, EmbeddingRequest, ResponseRequest, UsageTuple, client_registry, ) from ..request_snapshot import ( attach_request_snapshot, has_request_snapshot, save_failed_request_snapshot, serialize_audio_request_snapshot, serialize_embedding_request_snapshot, serialize_response_request_snapshot, ) logger = get_logger("llm_models") SUPPORTED_OPENAI_IMAGE_FORMATS = {"jpeg", "png", "webp"} """OpenAI 兼容图片输入稳定支持的格式集合。""" THINK_CONTENT_PATTERN = re.compile( r"(?P.*?)(?P.*)|(?P.*)|(?P.+)", re.DOTALL, ) """用于解析 `` 推理块的正则表达式。""" XML_TOOL_CALL_PATTERN = re.compile(r"\s*(?P.*?)\s*", re.DOTALL | re.IGNORECASE) """用于兜底解析模型以 XML 文本返回的工具调用。 这是一个暂时性兼容方案,专门处理“思维链内容里夹带工具调用”的情况; 后续如果上游稳定返回标准 tool_calls 字段,这里可能会调整或移除。 """ XML_FUNCTION_CALL_PATTERN = re.compile( r"[A-Za-z0-9_.-]+)>\s*(?P.*?)\s*", re.DOTALL | re.IGNORECASE, ) """用于从 XML 风格工具调用块中提取函数名与参数。""" XML_PARAMETER_PATTERN = re.compile( r"[A-Za-z0-9_.-]+)>\s*(?P.*?)\s*", re.DOTALL | re.IGNORECASE, ) """用于从 XML 风格工具调用块中提取参数列表。""" CHAT_COMPLETIONS_RESERVED_EXTRA_BODY_KEYS = { "max_tokens", "messages", "model", "response_format", "stream", "temperature", "tools", } """由当前客户端显式承载、不应再落入 `extra_body` 的字段集合。""" OpenAIStreamResponseHandler = Callable[ [AsyncStream[ChatCompletionChunk], asyncio.Event | None], Coroutine[Any, Any, Tuple[APIResponse, UsageTuple | None]], ] """OpenAI 流式响应处理函数类型。""" OpenAIResponseParser = Callable[[ChatCompletion], Tuple[APIResponse, UsageTuple | None]] """OpenAI 非流式响应解析函数类型。""" def _normalize_reasoning_parse_mode(parse_mode: str | ReasoningParseMode) -> ReasoningParseMode: """将配置中的推理解析模式收敛为枚举值。 Args: parse_mode: 原始解析模式配置。 Returns: ReasoningParseMode: 规范化后的解析模式;未知值会回退为 `AUTO`。 """ if isinstance(parse_mode, ReasoningParseMode): return parse_mode try: return ReasoningParseMode(parse_mode) except ValueError: logger.warning(f"未识别的推理解析模式 {parse_mode},已回退为 auto") return ReasoningParseMode.AUTO def _normalize_tool_argument_parse_mode(parse_mode: str | ToolArgumentParseMode) -> ToolArgumentParseMode: """将配置中的工具参数解析模式收敛为枚举值。 Args: parse_mode: 原始解析模式配置。 Returns: ToolArgumentParseMode: 规范化后的解析模式;未知值会回退为 `AUTO`。 """ if isinstance(parse_mode, ToolArgumentParseMode): return parse_mode try: return ToolArgumentParseMode(parse_mode) except ValueError: logger.warning(f"未识别的工具参数解析模式 {parse_mode},已回退为 auto") return ToolArgumentParseMode.AUTO def _build_text_content_part(text: str) -> ChatCompletionContentPartTextParam: """构建文本内容片段。 Args: text: 文本内容。 Returns: ChatCompletionContentPartTextParam: OpenAI 兼容的文本片段。 """ return { "type": "text", "text": text, } def _build_image_content_part(part: ImageMessagePart) -> ChatCompletionContentPartImageParam: """构建图片内容片段。 Args: part: 内部图片片段。 Returns: ChatCompletionContentPartImageParam: OpenAI 兼容的图片片段。 """ normalized_image = _normalize_image_part_for_openai(part) if normalized_image is None: raise ValueError("图片数据无效,无法构建图片消息片段") image_format, image_base64 = normalized_image return { "type": "image_url", "image_url": { "url": f"data:image/{image_format};base64,{image_base64}", }, } def _normalize_image_part_for_openai(part: ImageMessagePart) -> Tuple[str, str] | None: """将图片片段规范化为 OpenAI 兼容格式。 Args: part: 内部图片片段。 Returns: Tuple[str, str] | None: `(image_format, image_base64)`;无法解析时返回 `None`。 """ try: image_bytes = base64.b64decode(part.image_base64, validate=True) except (binascii.Error, ValueError) as exc: logger.warning(f"图片 Base64 解码失败,已跳过该图片片段: {exc}") return None try: with PILImage.open(io.BytesIO(image_bytes)) as image: image_format = (image.format or part.normalized_image_format).lower() if image_format in {"jpg", "jpeg"}: image_format = "jpeg" if image_format in SUPPORTED_OPENAI_IMAGE_FORMATS: return image_format, part.image_base64 if image_format == "gif": frame_count = getattr(image, "n_frames", 1) frames: List[PILImage.Image] = [] durations: List[int] = [] for frame_index in range(frame_count): image.seek(frame_index) frame = image.copy() if frame.mode not in {"RGB", "RGBA"}: frame = frame.convert("RGBA") frames.append(frame) durations.append(int(image.info.get("duration", 100) or 100)) output_buffer = io.BytesIO() save_kwargs: Dict[str, Any] = { "format": "WEBP", "save_all": True, "append_images": frames[1:], "duration": durations, "loop": int(image.info.get("loop", 0) or 0), } if frame_count > 1: save_kwargs["lossless"] = True frames[0].save(output_buffer, **save_kwargs) converted_base64 = base64.b64encode(output_buffer.getvalue()).decode("utf-8") return "webp", converted_base64 image.seek(0) normalized_image = image.copy() if normalized_image.mode not in {"RGB", "RGBA"}: normalized_image = normalized_image.convert("RGBA") output_buffer = io.BytesIO() normalized_image.save(output_buffer, format="PNG") converted_base64 = base64.b64encode(output_buffer.getvalue()).decode("utf-8") return "png", converted_base64 except Exception as exc: logger.warning(f"图片内容无法被识别为有效图片,已跳过该图片片段: {exc}") return None def _convert_response_format(response_format: RespFormat | None) -> Any: """将内部响应格式转换为 OpenAI 兼容结构。 Args: response_format: 内部响应格式定义。 Returns: Any: OpenAI SDK 可接受的响应格式参数;未指定时返回 `omit`。 """ if response_format is None or response_format.format_type == RespFormatType.TEXT: return omit if response_format.format_type == RespFormatType.JSON_OBJ: return {"type": "json_object"} if response_format.format_type == RespFormatType.JSON_SCHEMA: return { "type": "json_schema", "json_schema": response_format.schema, } return omit def _convert_text_only_message_content( message: Message, ) -> str | List[ChatCompletionContentPartTextParam]: """将仅允许文本的消息转换为 OpenAI 兼容内容。 Args: message: 内部统一消息对象。 Returns: str | List[ChatCompletionContentPartTextParam]: 文本内容结构。 Raises: ValueError: 当消息中包含非文本片段时抛出。 """ if not message.parts: return "" if len(message.parts) == 1 and isinstance(message.parts[0], TextMessagePart): return message.parts[0].text content: List[ChatCompletionContentPartTextParam] = [] for part in message.parts: if not isinstance(part, TextMessagePart): raise ValueError(f"{message.role.value} 消息仅支持文本片段") content.append(_build_text_content_part(part.text)) return content def _convert_user_message_content(message: Message) -> str | List[ChatCompletionContentPartParam]: """将用户消息转换为 OpenAI 兼容内容。 Args: message: 内部统一消息对象。 Returns: str | List[ChatCompletionContentPartParam]: 用户消息内容结构。 """ if len(message.parts) == 1 and isinstance(message.parts[0], TextMessagePart): return message.parts[0].text content: List[ChatCompletionContentPartParam] = [] for part in message.parts: if isinstance(part, TextMessagePart): content.append(_build_text_content_part(part.text)) continue normalized_image = _normalize_image_part_for_openai(part) if normalized_image is None: content.append(_build_text_content_part("[图片内容不可用]")) continue image_format, image_base64 = normalized_image content.append( { "type": "image_url", "image_url": { "url": f"data:image/{image_format};base64,{image_base64}", }, } ) return content def _convert_assistant_tool_calls(tool_calls: List[ToolCall]) -> List[ChatCompletionMessageFunctionToolCallParam]: """将内部工具调用转换为 OpenAI assistant tool_calls 结构。 Args: tool_calls: 内部工具调用列表。 Returns: List[ChatCompletionMessageFunctionToolCallParam]: OpenAI 兼容工具调用结构。 """ converted_tool_calls: List[ChatCompletionMessageFunctionToolCallParam] = [] for tool_call in tool_calls: converted_tool_calls.append( { "id": tool_call.call_id, "type": "function", "function": { "name": tool_call.func_name, "arguments": json.dumps(tool_call.args or {}, ensure_ascii=False), }, } ) return converted_tool_calls def _sanitize_messages_for_toolless_request(messages: List[Message]) -> List[Message]: """在无工具请求时清洗历史工具调用链,避免兼容接口拒收消息。""" sanitized_messages: List[Message] = [] for message in messages: if message.role == RoleType.Tool: continue if message.role == RoleType.Assistant and message.tool_calls: if not message.parts: continue assistant_message = Message( role=message.role, parts=list(message.parts), tool_call_id=message.tool_call_id, tool_name=message.tool_name, tool_calls=None, ) sanitized_messages.append(assistant_message) continue sanitized_messages.append(message) return sanitized_messages def _convert_messages(messages: List[Message]) -> List[ChatCompletionMessageParam]: """将内部消息列表转换为 OpenAI 兼容消息列表。 Args: messages: 内部统一消息列表。 Returns: List[ChatCompletionMessageParam]: OpenAI SDK 所需的消息结构列表。 """ converted_messages: List[ChatCompletionMessageParam] = [] for message in messages: if message.role == RoleType.System: system_payload: ChatCompletionSystemMessageParam = { "role": "system", "content": _convert_text_only_message_content(message), } converted_messages.append(system_payload) continue if message.role == RoleType.User: user_payload: ChatCompletionUserMessageParam = { "role": "user", "content": _convert_user_message_content(message), } converted_messages.append(user_payload) continue if message.role == RoleType.Assistant: assistant_payload: ChatCompletionAssistantMessageParam = { "role": "assistant", "content": None if not message.parts and message.tool_calls else _convert_text_only_message_content(message), } if message.tool_calls: assistant_payload["tool_calls"] = _convert_assistant_tool_calls(message.tool_calls) converted_messages.append(assistant_payload) continue if message.role == RoleType.Tool: if message.tool_call_id is None: raise ValueError("Tool 消息缺少 tool_call_id") tool_payload: ChatCompletionToolMessageParam = { "role": "tool", "content": _convert_text_only_message_content(message), "tool_call_id": message.tool_call_id, } converted_messages.append(tool_payload) continue raise ValueError(f"不支持的消息角色:{message.role}") return converted_messages def _convert_tool_options(tool_options: List[ToolOption]) -> List[ChatCompletionToolParam]: """将工具定义转换为 OpenAI 兼容的工具列表。 Args: tool_options: 内部统一工具定义列表。 Returns: List[ChatCompletionToolParam]: OpenAI SDK 所需的工具定义列表。 """ converted_tools: List[ChatCompletionToolParam] = [] for tool_option in tool_options: parameters_schema = cast( Dict[str, object], tool_option.parameters_schema or {"type": "object", "properties": {}}, ) function_schema: FunctionDefinition = { "name": tool_option.name, "description": tool_option.description, "parameters": parameters_schema, } converted_tools.append( { "type": "function", "function": function_schema, } ) return converted_tools def _extract_usage_record(usage: Any) -> UsageTuple | None: """从响应对象中提取 usage 三元组。 Args: usage: OpenAI SDK 返回的 usage 对象。 Returns: UsageTuple | None: `(prompt_tokens, completion_tokens, total_tokens)`。 """ if usage is None: return None return ( getattr(usage, "prompt_tokens", 0) or 0, getattr(usage, "completion_tokens", 0) or 0, getattr(usage, "total_tokens", 0) or 0, ) def _parse_tool_arguments( raw_arguments: str, parse_mode: ToolArgumentParseMode, response: Any, ) -> Dict[str, Any]: """解析工具调用参数字符串。 Args: raw_arguments: 工具调用参数原始字符串。 parse_mode: 参数解析模式。 response: 原始响应对象,用于异常上下文。 Returns: Dict[str, Any]: 解析后的参数字典。 Raises: RespParseException: 当参数无法解析为字典时抛出。 """ try: if parse_mode == ToolArgumentParseMode.STRICT: arguments: Any = json.loads(raw_arguments) elif parse_mode == ToolArgumentParseMode.REPAIR: arguments = repair_json(raw_arguments, return_objects=True, logging=False) else: arguments = repair_json(raw_arguments, return_objects=True, logging=False) if isinstance(arguments, str) and parse_mode in { ToolArgumentParseMode.AUTO, ToolArgumentParseMode.DOUBLE_DECODE, }: arguments = repair_json(arguments, return_objects=True, logging=False) except json.JSONDecodeError as exc: raise RespParseException(response, f"响应解析失败,无法解析工具调用参数。原始参数:{raw_arguments}") from exc if not isinstance(arguments, dict): raise RespParseException( response, f"响应解析失败,工具调用参数必须解析为字典,实际类型为 {type(arguments).__name__}。", ) return arguments def _extract_reasoning_and_content( content: str, parse_mode: ReasoningParseMode, ) -> Tuple[str | None, str | None]: """从文本内容中提取推理内容与正式输出。 Args: content: 模型返回的文本内容。 parse_mode: 推理解析模式。 Returns: Tuple[str | None, str | None]: `(reasoning_content, content)`。 """ if parse_mode in {ReasoningParseMode.NATIVE, ReasoningParseMode.NONE}: return None, content match = THINK_CONTENT_PATTERN.match(content) if not match: return None, content if match.group("think") is not None: reasoning_content = match.group("think").strip() or None final_content = match.group("content").strip() or None return reasoning_content, final_content if match.group("think_unclosed") is not None: return match.group("think_unclosed").strip() or None, None return None, match.group("content_only").strip() or None def _extract_xml_tool_calls( raw_text: str | None, parse_mode: ToolArgumentParseMode, response: Any, ) -> Tuple[str | None, List[ToolCall] | None]: """从 XML 风格文本中兜底提取工具调用。""" if not isinstance(raw_text, str) or not raw_text.strip(): return raw_text, None tool_calls: List[ToolCall] = [] def _coerce_xml_parameter_value(raw_value: str) -> Any: normalized_value = raw_value.strip() if not normalized_value: return "" lowered_value = normalized_value.lower() if lowered_value == "true": return True if lowered_value == "false": return False if lowered_value in {"null", "none"}: return None if normalized_value.startswith(("{", "[")): try: return repair_json(normalized_value, return_objects=True, logging=False) except Exception: return normalized_value return normalized_value def _parse_xml_parameters(raw_arguments: str) -> Dict[str, Any] | None: parameters = { match.group("name").strip(): _coerce_xml_parameter_value(match.group("value")) for match in XML_PARAMETER_PATTERN.finditer(raw_arguments) } return parameters or None def _replace_tool_call(match: re.Match[str]) -> str: body = match.group("body") function_match = XML_FUNCTION_CALL_PATTERN.search(body) if function_match is None: return match.group(0) function_name = function_match.group("name").strip() raw_arguments = function_match.group("arguments").strip() arguments = _parse_xml_parameters(raw_arguments) if arguments is None: arguments = _parse_tool_arguments(raw_arguments, parse_mode, response) if raw_arguments else {} tool_calls.append( ToolCall( call_id=f"xml_tool_call_{len(tool_calls) + 1}", func_name=function_name, args=arguments, ) ) return "" cleaned_text = XML_TOOL_CALL_PATTERN.sub(_replace_tool_call, raw_text).strip() or None return cleaned_text, tool_calls or None def _log_length_truncation(finish_reason: str | None, model_name: str | None) -> None: """记录因长度截断导致的告警日志。 Args: finish_reason: OpenAI 兼容接口返回的完成原因。 model_name: 上游返回的模型标识。 """ if finish_reason == "length": logger.info(f"模型{model_name or ''}因为超过最大 max_token 限制,可能仅输出部分内容,可视情况调整") def _apply_xml_tool_call_fallback( response: APIResponse, parse_mode: ToolArgumentParseMode, raw_response: Any, ) -> None: """当上游未返回标准 tool_calls 时,尝试从 XML 文本兜底解析。 这是一个暂时性处理方法,用来兼容思维链中混入工具调用的返回格式, 后续可能随着模型或上游接口的规范化而变更。 """ if response.tool_calls: return reasoning_content, tool_calls = _extract_xml_tool_calls(response.reasoning_content, parse_mode, raw_response) if reasoning_content != response.reasoning_content: response.reasoning_content = reasoning_content if tool_calls: response.tool_calls = tool_calls if not response.content and reasoning_content: response.content = reasoning_content response.reasoning_content = None logger.warning("OpenAI 兼容响应未返回标准 tool_calls,已从 XML 文本兜底解析工具调用") return cleaned_content, tool_calls = _extract_xml_tool_calls(response.content, parse_mode, raw_response) if cleaned_content != response.content: response.content = cleaned_content if tool_calls: response.tool_calls = tool_calls logger.warning("OpenAI 兼容响应未返回标准 tool_calls,已从 XML 文本兜底解析工具调用") def _coerce_openai_argument(value: Any) -> Any | Omit: """将可选参数转换为 OpenAI SDK 期望的值。 Args: value: 原始参数值。 Returns: Any | Omit: `None` 会被转换为 `omit`,其余值原样返回。 """ if value is None: return omit return value def _snapshot_openai_argument(value: Any | Omit) -> Any | None: """将 OpenAI SDK 参数转换为适合写入快照的普通值。""" if value is omit: return None return value def _build_api_status_message(error: APIStatusError) -> str: """构建更适合记录和展示的状态错误信息。 Args: error: OpenAI SDK 抛出的状态错误。 Returns: str: 拼装后的错误信息。 """ message_parts: List[str] = [] if getattr(error, "message", None): message_parts.append(str(error.message)) response_text = getattr(getattr(error, "response", None), "text", None) if response_text: message_parts.append(str(response_text)) if message_parts: return " | ".join(message_parts) return f"上游接口返回状态码 {error.status_code}" @dataclass(slots=True) class _StreamedToolCallState: """流式工具调用累积状态。""" index: int call_id: str = "" function_name: str = "" arguments_buffer: io.StringIO = field(default_factory=io.StringIO) def append_arguments(self, arguments_chunk: str) -> None: """追加一段工具调用参数字符串。 Args: arguments_chunk: 参数增量片段。 """ self.arguments_buffer.write(arguments_chunk) def close(self) -> None: """关闭内部缓存。""" if not self.arguments_buffer.closed: self.arguments_buffer.close() class _OpenAIStreamAccumulator: """OpenAI 兼容流式响应累积器。""" def __init__( self, reasoning_parse_mode: ReasoningParseMode, tool_argument_parse_mode: ToolArgumentParseMode, ) -> None: """初始化累积器。 Args: reasoning_parse_mode: 推理内容解析模式。 tool_argument_parse_mode: 工具参数解析模式。 """ self.reasoning_parse_mode = reasoning_parse_mode self.tool_argument_parse_mode = tool_argument_parse_mode self.reasoning_buffer = io.StringIO() self.content_buffer = io.StringIO() self.tool_call_states: Dict[int, _StreamedToolCallState] = {} self.finish_reason: str | None = None self.model_name: str | None = None self._using_native_reasoning = False def capture_event_metadata(self, event: ChatCompletionChunk) -> None: """捕获事件中的完成原因和模型名。 Args: event: 当前流式事件。 """ if getattr(event, "model", None) and not self.model_name: self.model_name = event.model if getattr(event, "choices", None): finish_reason = getattr(event.choices[0], "finish_reason", None) if finish_reason: self.finish_reason = finish_reason def process_delta(self, delta: ChoiceDelta) -> None: """处理一个增量块。 Args: delta: 当前增量对象。 """ self._process_reasoning_delta(delta) self._process_tool_call_delta(delta) def _process_reasoning_delta(self, delta: ChoiceDelta) -> None: """处理推理内容与正式内容。 Args: delta: 当前增量对象。 """ native_reasoning = getattr(delta, "reasoning_content", None) if isinstance(native_reasoning, str) and native_reasoning: self._using_native_reasoning = True if self.reasoning_parse_mode != ReasoningParseMode.NONE: self.reasoning_buffer.write(native_reasoning) return content_chunk = getattr(delta, "content", None) if not isinstance(content_chunk, str) or content_chunk == "": return if self.reasoning_parse_mode == ReasoningParseMode.NONE: self.content_buffer.write(content_chunk) return if self.reasoning_parse_mode == ReasoningParseMode.NATIVE: self.content_buffer.write(content_chunk) return self.content_buffer.write(content_chunk) def _process_tool_call_delta(self, delta: ChoiceDelta) -> None: """处理工具调用增量。 Args: delta: 当前增量对象。 """ tool_call_deltas = getattr(delta, "tool_calls", None) or [] for tool_call_delta in tool_call_deltas: state = self.tool_call_states.setdefault(tool_call_delta.index, _StreamedToolCallState(index=tool_call_delta.index)) if tool_call_delta.id: state.call_id = tool_call_delta.id function = tool_call_delta.function if function is not None and function.name: state.function_name = function.name if function is not None and function.arguments: state.append_arguments(function.arguments) def build_response(self) -> APIResponse: """构建最终 APIResponse 对象。 Returns: APIResponse: 累积完成的响应对象。 Raises: EmptyResponseException: 当响应中既无可见内容也无工具调用时抛出。 RespParseException: 当工具调用结构不完整时抛出。 """ response = APIResponse() content = self.content_buffer.getvalue().strip() reasoning_content = self.reasoning_buffer.getvalue().strip() if not self._using_native_reasoning and self.reasoning_parse_mode != ReasoningParseMode.NONE and content: parsed_reasoning_content, parsed_content = _extract_reasoning_and_content( content=content, parse_mode=self.reasoning_parse_mode, ) if parsed_reasoning_content: reasoning_content = parsed_reasoning_content content = parsed_content or "" if reasoning_content: response.reasoning_content = reasoning_content if content: response.content = content if self.tool_call_states: response.tool_calls = [] for index in sorted(self.tool_call_states): state = self.tool_call_states[index] if not state.function_name: raise RespParseException(None, f"响应解析失败,工具调用 {index} 缺少函数名。") raw_arguments = state.arguments_buffer.getvalue().strip() arguments = ( _parse_tool_arguments(raw_arguments, self.tool_argument_parse_mode, None) if raw_arguments else None ) call_id = state.call_id or f"tool_call_{index}" response.tool_calls.append(ToolCall(call_id=call_id, func_name=state.function_name, args=arguments)) response.raw_data = {"model": self.model_name} if self.model_name else None _apply_xml_tool_call_fallback(response, self.tool_argument_parse_mode, response.raw_data) if not response.content and not response.tool_calls: raise EmptyResponseException(response.raw_data) return response def close(self) -> None: """关闭内部缓冲区。""" if not self.reasoning_buffer.closed: self.reasoning_buffer.close() if not self.content_buffer.closed: self.content_buffer.close() for state in self.tool_call_states.values(): state.close() async def _default_stream_response_handler( resp_stream: AsyncStream[ChatCompletionChunk], interrupt_flag: asyncio.Event | None, *, reasoning_parse_mode: ReasoningParseMode, tool_argument_parse_mode: ToolArgumentParseMode, ) -> Tuple[APIResponse, UsageTuple | None]: """处理 OpenAI 兼容流式响应。 Args: resp_stream: OpenAI SDK 返回的流式响应对象。 interrupt_flag: 外部中断标记。 reasoning_parse_mode: 推理内容解析模式。 tool_argument_parse_mode: 工具参数解析模式。 Returns: Tuple[APIResponse, UsageTuple | None]: 解析后的响应与 usage 统计。 """ accumulator = _OpenAIStreamAccumulator( reasoning_parse_mode=reasoning_parse_mode, tool_argument_parse_mode=tool_argument_parse_mode, ) usage_record: UsageTuple | None = None try: async for event in resp_stream: if interrupt_flag and interrupt_flag.is_set(): raise ReqAbortException("请求被外部信号中断") accumulator.capture_event_metadata(event) event_usage = _extract_usage_record(getattr(event, "usage", None)) if event_usage is not None: usage_record = event_usage if not getattr(event, "choices", None): continue accumulator.process_delta(event.choices[0].delta) response = accumulator.build_response() model_name = None if isinstance(response.raw_data, dict): model_name = response.raw_data.get("model") _log_length_truncation(accumulator.finish_reason, model_name) return response, usage_record finally: accumulator.close() def _default_normal_response_parser( resp: ChatCompletion, *, reasoning_parse_mode: ReasoningParseMode, tool_argument_parse_mode: ToolArgumentParseMode, ) -> Tuple[APIResponse, UsageTuple | None]: """解析 OpenAI 兼容的非流式响应。 Args: resp: OpenAI SDK 返回的聊天补全响应。 reasoning_parse_mode: 推理内容解析模式。 tool_argument_parse_mode: 工具参数解析模式。 Returns: Tuple[APIResponse, UsageTuple | None]: 解析后的响应与 usage 统计。 Raises: EmptyResponseException: 当 choices 为空或响应内容为空时抛出。 """ choices = getattr(resp, "choices", None) if not choices: raise EmptyResponseException(resp, "响应解析失败,choices 为空或缺失") api_response = APIResponse() message_part = choices[0].message native_reasoning = getattr(message_part, "reasoning_content", None) message_content = message_part.content if isinstance(message_part.content, str) else None if isinstance(native_reasoning, str) and native_reasoning and reasoning_parse_mode != ReasoningParseMode.NONE: api_response.reasoning_content = native_reasoning api_response.content = message_content elif isinstance(message_content, str) and message_content: reasoning_content, final_content = _extract_reasoning_and_content( content=message_content, parse_mode=reasoning_parse_mode, ) api_response.reasoning_content = reasoning_content api_response.content = final_content tool_calls = getattr(message_part, "tool_calls", None) or [] if tool_calls: api_response.tool_calls = [] for tool_call in tool_calls: if tool_call.type != "function": raise RespParseException(resp, f"响应解析失败,暂不支持工具调用类型 {tool_call.type}。") raw_arguments = tool_call.function.arguments or "" arguments = _parse_tool_arguments(raw_arguments, tool_argument_parse_mode, resp) api_response.tool_calls.append( ToolCall( call_id=tool_call.id, func_name=tool_call.function.name, args=arguments, ) ) usage_record = _extract_usage_record(getattr(resp, "usage", None)) api_response.raw_data = resp finish_reason = getattr(resp.choices[0], "finish_reason", None) _log_length_truncation(finish_reason, getattr(resp, "model", None)) _apply_xml_tool_call_fallback(api_response, tool_argument_parse_mode, resp) if not api_response.content and not api_response.tool_calls: raise EmptyResponseException(resp) return api_response, usage_record @client_registry.register_client_class("openai") class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletion]): """OpenAI 兼容客户端。""" client: AsyncOpenAI reasoning_parse_mode: ReasoningParseMode tool_argument_parse_mode: ToolArgumentParseMode def __init__(self, api_provider: APIProvider) -> None: """初始化 OpenAI 兼容客户端。 Args: api_provider: API 提供商配置。 """ super().__init__(api_provider) client_config = build_openai_compatible_client_config(api_provider) self.reasoning_parse_mode = _normalize_reasoning_parse_mode(api_provider.reasoning_parse_mode) self.tool_argument_parse_mode = _normalize_tool_argument_parse_mode(api_provider.tool_argument_parse_mode) self.client = AsyncOpenAI( api_key=client_config.api_key, organization=api_provider.organization, project=api_provider.project, base_url=client_config.base_url, timeout=api_provider.timeout, max_retries=api_provider.max_retry, default_headers=client_config.default_headers or None, default_query=client_config.default_query or None, ) def _build_default_stream_response_handler( self, request: ResponseRequest, ) -> ProviderStreamResponseHandler[AsyncStream[ChatCompletionChunk]]: """构建 OpenAI 默认流式响应处理器。 Args: request: 统一响应请求对象。 Returns: ProviderStreamResponseHandler[AsyncStream[ChatCompletionChunk]]: 默认流式处理器。 """ del request async def default_stream_handler( resp_stream: AsyncStream[ChatCompletionChunk], flag: asyncio.Event | None, ) -> Tuple[APIResponse, UsageTuple | None]: """包装默认流式解析器。""" return await _default_stream_response_handler( resp_stream, flag, reasoning_parse_mode=self.reasoning_parse_mode, tool_argument_parse_mode=self.tool_argument_parse_mode, ) return default_stream_handler def _build_default_response_parser( self, request: ResponseRequest, ) -> ProviderResponseParser[ChatCompletion]: """构建 OpenAI 默认非流式响应解析器。 Args: request: 统一响应请求对象。 Returns: ProviderResponseParser[ChatCompletion]: 默认非流式解析器。 """ del request def default_response_parser( response: ChatCompletion, ) -> Tuple[APIResponse, UsageTuple | None]: """包装默认非流式解析器。""" return _default_normal_response_parser( response, reasoning_parse_mode=self.reasoning_parse_mode, tool_argument_parse_mode=self.tool_argument_parse_mode, ) return default_response_parser async def _execute_response_request( self, request: ResponseRequest, stream_response_handler: ProviderStreamResponseHandler[AsyncStream[ChatCompletionChunk]], response_parser: ProviderResponseParser[ChatCompletion], ) -> Tuple[APIResponse, UsageTuple | None]: """执行 OpenAI 兼容的文本/多模态响应请求。 Args: request: 统一响应请求对象。 stream_response_handler: 流式响应处理器。 response_parser: 非流式响应解析器。 Returns: Tuple[APIResponse, UsageTuple | None]: 统一响应对象与可选使用量信息。 """ snapshot_provider_request = { "base_url": self.api_provider.base_url, "endpoint": "/chat/completions", "method": "POST", "operation": "chat.completions.create", "organization": self.api_provider.organization, "project": self.api_provider.project, "request_kwargs": {}, } model_info = request.model_info try: request_messages = ( list(request.message_list) if request.tool_options else _sanitize_messages_for_toolless_request(request.message_list) ) messages_payload: List[ChatCompletionMessageParam] = _convert_messages(request_messages) tools_payload: List[ChatCompletionToolParam] | None = ( _convert_tool_options(request.tool_options) if request.tool_options else None ) openai_response_format = _convert_response_format(request.response_format) request_overrides = split_openai_request_overrides( request.extra_params, reserved_body_keys=CHAT_COMPLETIONS_RESERVED_EXTRA_BODY_KEYS, ) temperature_argument = ( omit if "temperature" in request_overrides.extra_body else _coerce_openai_argument(request.temperature) ) max_tokens_argument = ( omit if "max_tokens" in request_overrides.extra_body or "max_completion_tokens" in request_overrides.extra_body else _coerce_openai_argument(request.max_tokens) ) snapshot_provider_request["request_kwargs"] = { "extra_body": request_overrides.extra_body or None, "extra_headers": request_overrides.extra_headers or None, "extra_query": request_overrides.extra_query or None, "max_tokens": _snapshot_openai_argument(max_tokens_argument), "messages": messages_payload, "model": model_info.model_identifier, "response_format": _snapshot_openai_argument(openai_response_format), "stream": bool(model_info.force_stream_mode), "temperature": _snapshot_openai_argument(temperature_argument), "tools": tools_payload, } if model_info.force_stream_mode: stream_task: asyncio.Task[AsyncStream[ChatCompletionChunk]] = asyncio.create_task( self.client.chat.completions.create( model=model_info.model_identifier, messages=messages_payload, tools=tools_payload or omit, temperature=temperature_argument, max_tokens=max_tokens_argument, stream=True, response_format=openai_response_format, extra_headers=request_overrides.extra_headers or None, extra_query=request_overrides.extra_query or None, extra_body=request_overrides.extra_body or None, ) ) raw_response = cast( AsyncStream[ChatCompletionChunk], await await_task_with_interrupt(stream_task, request.interrupt_flag), ) return await stream_response_handler(raw_response, request.interrupt_flag) completion_task: asyncio.Task[ChatCompletion] = asyncio.create_task( self.client.chat.completions.create( model=model_info.model_identifier, messages=messages_payload, tools=tools_payload or omit, temperature=temperature_argument, max_tokens=max_tokens_argument, stream=False, response_format=openai_response_format, extra_headers=request_overrides.extra_headers or None, extra_query=request_overrides.extra_query or None, extra_body=request_overrides.extra_body or None, ) ) raw_response = cast( ChatCompletion, await await_task_with_interrupt(completion_task, request.interrupt_flag), ) return response_parser(raw_response) except (EmptyResponseException, RespParseException) as exc: snapshot_path = save_failed_request_snapshot( api_provider=self.api_provider, client_type="openai", error=exc, internal_request=serialize_response_request_snapshot(request), model_info=model_info, operation="chat.completions.create", provider_request=snapshot_provider_request, ) attach_request_snapshot(exc, snapshot_path) raise except APIConnectionError as exc: snapshot_path = save_failed_request_snapshot( api_provider=self.api_provider, client_type="openai", error=exc, internal_request=serialize_response_request_snapshot(request), model_info=model_info, operation="chat.completions.create", provider_request=snapshot_provider_request, ) wrapped_error = NetworkConnectionError(str(exc)) attach_request_snapshot(wrapped_error, snapshot_path) raise wrapped_error from exc except APIStatusError as exc: snapshot_path = save_failed_request_snapshot( api_provider=self.api_provider, client_type="openai", error=exc, internal_request=serialize_response_request_snapshot(request), model_info=model_info, operation="chat.completions.create", provider_request=snapshot_provider_request, ) wrapped_error = RespNotOkException(exc.status_code, _build_api_status_message(exc)) attach_request_snapshot(wrapped_error, snapshot_path) raise wrapped_error from exc except ReqAbortException: raise except Exception as exc: if has_request_snapshot(exc): raise snapshot_path = save_failed_request_snapshot( api_provider=self.api_provider, client_type="openai", error=exc, internal_request=serialize_response_request_snapshot(request), model_info=model_info, operation="chat.completions.create", provider_request=snapshot_provider_request, ) attach_request_snapshot(exc, snapshot_path) raise async def _execute_embedding_request( self, request: EmbeddingRequest, ) -> Tuple[APIResponse, UsageTuple | None]: """执行 OpenAI 兼容的文本嵌入请求。 Args: request: 统一嵌入请求对象。 Returns: Tuple[APIResponse, UsageTuple | None]: 统一响应对象与可选使用量信息。 """ model_info = request.model_info embedding_input = request.embedding_input extra_params = request.extra_params snapshot_provider_request = { "base_url": self.api_provider.base_url, "endpoint": "/embeddings", "method": "POST", "operation": "embeddings.create", "organization": self.api_provider.organization, "project": self.api_provider.project, "request_kwargs": {}, } try: request_overrides = split_openai_request_overrides(extra_params) snapshot_provider_request["request_kwargs"] = { "extra_body": request_overrides.extra_body or None, "extra_headers": request_overrides.extra_headers or None, "extra_query": request_overrides.extra_query or None, "input": embedding_input, "model": model_info.model_identifier, } raw_response = await self.client.embeddings.create( model=model_info.model_identifier, input=embedding_input, extra_headers=request_overrides.extra_headers or None, extra_query=request_overrides.extra_query or None, extra_body=request_overrides.extra_body or None, ) except APIConnectionError as exc: snapshot_path = save_failed_request_snapshot( api_provider=self.api_provider, client_type="openai", error=exc, internal_request=serialize_embedding_request_snapshot(request), model_info=model_info, operation="embeddings.create", provider_request=snapshot_provider_request, ) wrapped_error = NetworkConnectionError(str(exc)) attach_request_snapshot(wrapped_error, snapshot_path) raise wrapped_error from exc except APIStatusError as exc: snapshot_path = save_failed_request_snapshot( api_provider=self.api_provider, client_type="openai", error=exc, internal_request=serialize_embedding_request_snapshot(request), model_info=model_info, operation="embeddings.create", provider_request=snapshot_provider_request, ) wrapped_error = RespNotOkException(exc.status_code, _build_api_status_message(exc)) attach_request_snapshot(wrapped_error, snapshot_path) raise wrapped_error from exc except Exception as exc: if has_request_snapshot(exc): raise snapshot_path = save_failed_request_snapshot( api_provider=self.api_provider, client_type="openai", error=exc, internal_request=serialize_embedding_request_snapshot(request), model_info=model_info, operation="embeddings.create", provider_request=snapshot_provider_request, ) attach_request_snapshot(exc, snapshot_path) raise response = APIResponse() if not raw_response.data: exc = RespParseException(raw_response, "嵌入响应解析失败,缺少 embeddings 数据。") snapshot_path = save_failed_request_snapshot( api_provider=self.api_provider, client_type="openai", error=exc, internal_request=serialize_embedding_request_snapshot(request), model_info=model_info, operation="embeddings.create", provider_request=snapshot_provider_request, ) attach_request_snapshot(exc, snapshot_path) raise exc if raw_response.data: response.embedding = raw_response.data[0].embedding else: raise RespParseException(raw_response, "响应解析失败,缺失嵌入数据。") usage_record = _extract_usage_record(getattr(raw_response, "usage", None)) return response, usage_record async def _execute_audio_transcription_request( self, request: AudioTranscriptionRequest, ) -> Tuple[APIResponse, UsageTuple | None]: """执行 OpenAI 兼容的音频转录请求。 Args: request: 统一音频转录请求对象。 Returns: Tuple[APIResponse, UsageTuple | None]: 统一响应对象与可选使用量信息。 """ model_info = request.model_info audio_base64 = request.audio_base64 extra_params = request.extra_params snapshot_provider_request = { "base_url": self.api_provider.base_url, "endpoint": "/audio/transcriptions", "method": "POST", "operation": "audio.transcriptions.create", "organization": self.api_provider.organization, "project": self.api_provider.project, "request_kwargs": {}, } try: request_overrides = split_openai_request_overrides(extra_params) audio_file: FileTypes = ("audio.wav", io.BytesIO(base64.b64decode(audio_base64))) snapshot_provider_request["request_kwargs"] = { "audio_base64": audio_base64, "extra_body": request_overrides.extra_body or None, "extra_headers": request_overrides.extra_headers or None, "extra_query": request_overrides.extra_query or None, "file_name": "audio.wav", "model": model_info.model_identifier, } raw_response = await self.client.audio.transcriptions.create( model=model_info.model_identifier, file=audio_file, extra_headers=request_overrides.extra_headers or None, extra_query=request_overrides.extra_query or None, extra_body=request_overrides.extra_body or None, ) except APIConnectionError as exc: snapshot_path = save_failed_request_snapshot( api_provider=self.api_provider, client_type="openai", error=exc, internal_request=serialize_audio_request_snapshot(request), model_info=model_info, operation="audio.transcriptions.create", provider_request=snapshot_provider_request, ) wrapped_error = NetworkConnectionError(str(exc)) attach_request_snapshot(wrapped_error, snapshot_path) raise wrapped_error from exc except APIStatusError as exc: snapshot_path = save_failed_request_snapshot( api_provider=self.api_provider, client_type="openai", error=exc, internal_request=serialize_audio_request_snapshot(request), model_info=model_info, operation="audio.transcriptions.create", provider_request=snapshot_provider_request, ) wrapped_error = RespNotOkException(exc.status_code, _build_api_status_message(exc)) attach_request_snapshot(wrapped_error, snapshot_path) raise wrapped_error from exc except Exception as exc: if has_request_snapshot(exc): raise snapshot_path = save_failed_request_snapshot( api_provider=self.api_provider, client_type="openai", error=exc, internal_request=serialize_audio_request_snapshot(request), model_info=model_info, operation="audio.transcriptions.create", provider_request=snapshot_provider_request, ) attach_request_snapshot(exc, snapshot_path) raise response = APIResponse() transcription_text = raw_response if isinstance(raw_response, str) else getattr(raw_response, "text", None) if not isinstance(transcription_text, str): exc = RespParseException(raw_response, "音频转写响应解析失败,缺少文本内容。") snapshot_path = save_failed_request_snapshot( api_provider=self.api_provider, client_type="openai", error=exc, internal_request=serialize_audio_request_snapshot(request), model_info=model_info, operation="audio.transcriptions.create", provider_request=snapshot_provider_request, ) attach_request_snapshot(exc, snapshot_path) raise exc if isinstance(transcription_text, str): response.content = transcription_text return response, None raise RespParseException(raw_response, "响应解析失败,缺失转录文本。") def get_support_image_formats(self) -> List[str]: """获取支持的图片格式列表。 Returns: List[str]: 当前客户端支持的图片格式列表。 """ return ["jpg", "jpeg", "png", "webp", "gif"]