Treat blank OpenAI-compatible tool call arguments as an empty dict so parameterless tools such as finish can execute with providers that return an empty string. Also trim model identifiers during config normalization to avoid leading whitespace leaking into requests and snapshots.
1555 lines
59 KiB
Python
1555 lines
59 KiB
Python
import asyncio
|
||
import base64
|
||
import binascii
|
||
import io
|
||
import json
|
||
import re
|
||
from dataclasses import dataclass, field
|
||
from datetime import datetime
|
||
from pathlib import Path
|
||
from typing import Any, Callable, Coroutine, Dict, List, Tuple, cast
|
||
from urllib.parse import urlparse
|
||
from uuid import uuid4
|
||
|
||
from json_repair import repair_json
|
||
from openai import APIConnectionError, APIStatusError, AsyncOpenAI, AsyncStream
|
||
from openai._types import FileTypes, Omit, omit
|
||
from openai.types.chat import (
|
||
ChatCompletion,
|
||
ChatCompletionAssistantMessageParam,
|
||
ChatCompletionChunk,
|
||
ChatCompletionContentPartImageParam,
|
||
ChatCompletionContentPartParam,
|
||
ChatCompletionContentPartTextParam,
|
||
ChatCompletionMessageFunctionToolCallParam,
|
||
ChatCompletionMessageParam,
|
||
ChatCompletionSystemMessageParam,
|
||
ChatCompletionToolMessageParam,
|
||
ChatCompletionToolParam,
|
||
ChatCompletionUserMessageParam,
|
||
)
|
||
from openai.types.shared_params.function_definition import FunctionDefinition
|
||
from openai.types.chat.chat_completion_chunk import ChoiceDelta
|
||
from PIL import Image as PILImage
|
||
|
||
from src.common.logger import get_logger
|
||
from src.config.model_configs import APIProvider, ReasoningParseMode, ToolArgumentParseMode
|
||
from src.llm_models.exceptions import (
|
||
EmptyResponseException,
|
||
NetworkConnectionError,
|
||
ReqAbortException,
|
||
RespNotOkException,
|
||
RespParseException,
|
||
)
|
||
from src.llm_models.openai_compat import (
|
||
build_openai_compatible_client_config,
|
||
split_openai_request_overrides,
|
||
)
|
||
from src.llm_models.payload_content.message import ImageMessagePart, Message, RoleType, TextMessagePart
|
||
from src.llm_models.payload_content.resp_format import RespFormat, RespFormatType
|
||
from src.llm_models.payload_content.tool_option import ToolCall, ToolOption
|
||
|
||
from .adapter_base import (
|
||
AdapterClient,
|
||
ProviderResponseParser,
|
||
ProviderStreamResponseHandler,
|
||
await_task_with_interrupt,
|
||
)
|
||
from .base_client import (
|
||
APIResponse,
|
||
AudioTranscriptionRequest,
|
||
EmbeddingRequest,
|
||
ResponseRequest,
|
||
UsageTuple,
|
||
client_registry,
|
||
)
|
||
from ..request_snapshot import (
|
||
attach_request_snapshot,
|
||
has_request_snapshot,
|
||
save_failed_request_snapshot,
|
||
serialize_audio_request_snapshot,
|
||
serialize_embedding_request_snapshot,
|
||
serialize_response_request_snapshot,
|
||
)
|
||
|
||
logger = get_logger("llm_models")

# Directory where debug provider-request dumps are written.
DEBUG_REPLY_CACHE_DIR = Path("logs/debug_reply_cache")

SUPPORTED_OPENAI_IMAGE_FORMATS = {"jpeg", "png", "webp"}
"""Image formats reliably supported by OpenAI-compatible image input."""

THINK_CONTENT_PATTERN = re.compile(
    r"<think>(?P<think>.*?)</think>(?P<content>.*)|<think>(?P<think_unclosed>.*)|(?P<content_only>.+)",
    re.DOTALL,
)
"""Regular expression used to parse `<think>` reasoning blocks."""

XML_TOOL_CALL_PATTERN = re.compile(r"<tool_call>\s*(?P<body>.*?)\s*</tool_call>", re.DOTALL | re.IGNORECASE)
"""Fallback parser for tool calls the model returned as XML text.

This is a temporary compatibility measure aimed at tool calls embedded in
chain-of-thought content; it may be adjusted or removed once upstream
reliably returns the standard tool_calls field.
"""

XML_FUNCTION_CALL_PATTERN = re.compile(
    r"<function=(?P<name>[A-Za-z0-9_.-]+)>\s*(?P<arguments>.*?)\s*</function>",
    re.DOTALL | re.IGNORECASE,
)
"""Extracts the function name and arguments from an XML-style tool call block."""

XML_PARAMETER_PATTERN = re.compile(
    r"<parameter=(?P<name>[A-Za-z0-9_.-]+)>\s*(?P<value>.*?)\s*</parameter>",
    re.DOTALL | re.IGNORECASE,
)
"""Extracts the parameter list from an XML-style tool call block."""

CHAT_COMPLETIONS_RESERVED_EXTRA_BODY_KEYS = {
    "max_tokens",
    "messages",
    "model",
    "response_format",
    "stream",
    "temperature",
    "tools",
}
"""Fields carried explicitly by this client that must not leak into `extra_body`."""

OpenAIStreamResponseHandler = Callable[
    [AsyncStream[ChatCompletionChunk], asyncio.Event | None],
    Coroutine[Any, Any, Tuple[APIResponse, UsageTuple | None]],
]
"""Type of an OpenAI streaming-response handler function."""

OpenAIResponseParser = Callable[[ChatCompletion], Tuple[APIResponse, UsageTuple | None]]
"""Type of an OpenAI non-streaming response parser function."""

PROVIDER_REASONING_KEYS_BY_DOMAIN: Dict[str, str] = {
    "api.groq.com": "reasoning",
    "openrouter.ai": "reasoning",
}
"""Native reasoning field name keyed by provider domain."""
|
||
|
||
|
||
def _build_debug_provider_request_filename(model_name: str) -> str:
|
||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
||
raw_name = f"provider_{timestamp}_{model_name or 'unknown'}.json"
|
||
return "".join(char if char.isalnum() or char in ("-", "_", ".") else "_" for char in raw_name)
|
||
|
||
|
||
def _save_debug_provider_request_payload(model_name: str, request_payload: Dict[str, Any]) -> None:
    """Persist the raw provider request body for the debug model, when enabled.

    Only the hard-coded debug model "deepseek-v4p" is ever recorded, and only
    while `debug.record_reply_request` is switched on; write failures are
    logged and swallowed so the outgoing request is never affected.

    Args:
        model_name: Model identifier of the outgoing request.
        request_payload: JSON-serializable request body to dump.
    """
    if model_name != "deepseek-v4p":
        return

    # Imported lazily to avoid a module-level dependency on global config.
    from src.config.config import global_config

    if not global_config.debug.record_reply_request:
        return

    try:
        DEBUG_REPLY_CACHE_DIR.mkdir(parents=True, exist_ok=True)
        dump_path = DEBUG_REPLY_CACHE_DIR / _build_debug_provider_request_filename(model_name)
        with dump_path.open("w", encoding="utf-8") as dump_file:
            json.dump(request_payload, dump_file, ensure_ascii=False, indent=2)
        logger.info(f"DeepSeek provider 请求体已保存: {dump_path.resolve()}")
    except Exception as exc:
        logger.warning(f"保存 DeepSeek provider 请求体失败: {exc}")
|
||
|
||
|
||
def _build_fallback_tool_call_id(prefix: str) -> str:
|
||
"""为缺失原始调用 ID 的工具调用生成唯一兜底标识。"""
|
||
|
||
normalized_prefix = str(prefix).strip() or "tool_call"
|
||
return f"{normalized_prefix}_{uuid4().hex}"
|
||
|
||
|
||
def _build_reasoning_key(api_provider: APIProvider) -> str:
    """Resolve the native reasoning field name to read for this provider.

    Args:
        api_provider: Provider config whose `base_url` selects the field name.

    Returns:
        str: Domain-specific reasoning key, or "reasoning_content" when the
        base URL is blank or the host has no override.
    """
    base_url = api_provider.base_url.strip()
    if not base_url:
        return "reasoning_content"
    # urlparse only populates hostname when a scheme (or leading //) is present.
    url_to_parse = base_url if "://" in base_url else f"//{base_url}"
    hostname = (urlparse(url_to_parse).hostname or "").lower()
    return PROVIDER_REASONING_KEYS_BY_DOMAIN.get(hostname, "reasoning_content")
|
||
|
||
|
||
def _extract_reasoning_content(message_part: Any, reasoning_key: str) -> str | None:
|
||
"""从 OpenAI 兼容响应对象中读取原生推理内容。
|
||
|
||
不同兼容服务商对推理字段命名并不完全一致。这里集中处理字段访问,
|
||
避免解析路径里散落 provider 特判;具体字段名由 provider 决定。
|
||
"""
|
||
native_reasoning = getattr(message_part, reasoning_key, None)
|
||
if isinstance(native_reasoning, str) and native_reasoning:
|
||
return native_reasoning
|
||
return None
|
||
|
||
|
||
def _normalize_reasoning_parse_mode(parse_mode: str | ReasoningParseMode) -> ReasoningParseMode:
    """Coerce the configured reasoning parse mode into the enum.

    Args:
        parse_mode: Raw parse-mode value from configuration.

    Returns:
        ReasoningParseMode: Normalized mode; unrecognized values fall back to `AUTO`.
    """
    if isinstance(parse_mode, ReasoningParseMode):
        return parse_mode
    try:
        normalized = ReasoningParseMode(parse_mode)
    except ValueError:
        logger.warning(f"未识别的推理解析模式 {parse_mode},已回退为 auto")
        normalized = ReasoningParseMode.AUTO
    return normalized
|
||
|
||
|
||
def _normalize_tool_argument_parse_mode(parse_mode: str | ToolArgumentParseMode) -> ToolArgumentParseMode:
    """Coerce the configured tool-argument parse mode into the enum.

    Args:
        parse_mode: Raw parse-mode value from configuration.

    Returns:
        ToolArgumentParseMode: Normalized mode; unrecognized values fall back to `AUTO`.
    """
    if isinstance(parse_mode, ToolArgumentParseMode):
        return parse_mode
    try:
        normalized = ToolArgumentParseMode(parse_mode)
    except ValueError:
        logger.warning(f"未识别的工具参数解析模式 {parse_mode},已回退为 auto")
        normalized = ToolArgumentParseMode.AUTO
    return normalized
|
||
|
||
|
||
def _build_text_content_part(text: str) -> ChatCompletionContentPartTextParam:
|
||
"""构建文本内容片段。
|
||
|
||
Args:
|
||
text: 文本内容。
|
||
|
||
Returns:
|
||
ChatCompletionContentPartTextParam: OpenAI 兼容的文本片段。
|
||
"""
|
||
return {
|
||
"type": "text",
|
||
"text": text,
|
||
}
|
||
|
||
|
||
def _build_image_content_part(part: ImageMessagePart) -> ChatCompletionContentPartImageParam:
    """Convert an internal image part into an OpenAI-compatible image part.

    Args:
        part: Internal image message part.

    Returns:
        ChatCompletionContentPartImageParam: Image part carrying a data URL.

    Raises:
        ValueError: If the image payload cannot be normalized.
    """
    normalized = _normalize_image_part_for_openai(part)
    if normalized is None:
        raise ValueError("图片数据无效,无法构建图片消息片段")

    fmt, payload_base64 = normalized
    return {
        "type": "image_url",
        "image_url": {"url": f"data:image/{fmt};base64,{payload_base64}"},
    }
|
||
|
||
|
||
def _normalize_image_part_for_openai(part: ImageMessagePart) -> Tuple[str, str] | None:
    """Normalize an image part into an OpenAI-compatible format.

    Args:
        part: Internal image message part.

    Returns:
        Tuple[str, str] | None: `(image_format, image_base64)`; `None` when the
        payload cannot be base64-decoded or identified as an image.
    """
    try:
        image_bytes = base64.b64decode(part.image_base64, validate=True)
    except (binascii.Error, ValueError) as exc:
        logger.warning(f"图片 Base64 解码失败,已跳过该图片片段: {exc}")
        return None

    try:
        with PILImage.open(io.BytesIO(image_bytes)) as image:
            # Prefer the format PIL detects; fall back to the caller-declared one.
            image_format = (image.format or part.normalized_image_format).lower()
            if image_format in {"jpg", "jpeg"}:
                image_format = "jpeg"

            # Already supported: pass the original base64 through unchanged.
            if image_format in SUPPORTED_OPENAI_IMAGE_FORMATS:
                return image_format, part.image_base64

            if image_format == "gif":
                # Re-encode (possibly animated) GIFs as WEBP, preserving frame
                # timings and the loop count.
                frame_count = getattr(image, "n_frames", 1)
                frames: List[PILImage.Image] = []
                durations: List[int] = []

                for frame_index in range(frame_count):
                    image.seek(frame_index)
                    frame = image.copy()
                    if frame.mode not in {"RGB", "RGBA"}:
                        frame = frame.convert("RGBA")
                    frames.append(frame)
                    # Default to 100 ms when a frame carries no duration metadata.
                    durations.append(int(image.info.get("duration", 100) or 100))

                output_buffer = io.BytesIO()
                save_kwargs: Dict[str, Any] = {
                    "format": "WEBP",
                    "save_all": True,
                    "append_images": frames[1:],
                    "duration": durations,
                    "loop": int(image.info.get("loop", 0) or 0),
                }
                if frame_count > 1:
                    save_kwargs["lossless"] = True

                frames[0].save(output_buffer, **save_kwargs)
                converted_base64 = base64.b64encode(output_buffer.getvalue()).decode("utf-8")
                return "webp", converted_base64

            # Any other format: re-encode a single frame as PNG.
            image.seek(0)
            normalized_image = image.copy()
            if normalized_image.mode not in {"RGB", "RGBA"}:
                normalized_image = normalized_image.convert("RGBA")

            output_buffer = io.BytesIO()
            normalized_image.save(output_buffer, format="PNG")
            converted_base64 = base64.b64encode(output_buffer.getvalue()).decode("utf-8")
            return "png", converted_base64
    except Exception as exc:
        logger.warning(f"图片内容无法被识别为有效图片,已跳过该图片片段: {exc}")
        return None
|
||
|
||
|
||
def _convert_response_format(response_format: RespFormat | None) -> Any:
    """Translate the internal response-format definition into the OpenAI shape.

    Args:
        response_format: Internal response-format definition.

    Returns:
        Any: A response_format parameter the OpenAI SDK accepts; `omit` when
        no explicit format should be sent.
    """
    if response_format is None:
        return omit
    fmt = response_format.format_type
    if fmt == RespFormatType.JSON_OBJ:
        return {"type": "json_object"}
    if fmt == RespFormatType.JSON_SCHEMA:
        return {"type": "json_schema", "json_schema": response_format.schema}
    # TEXT and any unrecognized format both mean "no explicit response_format".
    return omit
|
||
|
||
|
||
def _convert_text_only_message_content(
    message: Message,
) -> str | List[ChatCompletionContentPartTextParam]:
    """Convert a text-only message into OpenAI-compatible content.

    Args:
        message: Internal unified message.

    Returns:
        str | List[ChatCompletionContentPartTextParam]: A bare string on the
        empty/single-text fast paths, otherwise a list of text parts.

    Raises:
        ValueError: If the message contains any non-text part.
    """
    parts = message.parts
    if not parts:
        return ""
    if len(parts) == 1 and isinstance(parts[0], TextMessagePart):
        return parts[0].text

    converted: List[ChatCompletionContentPartTextParam] = []
    for part in parts:
        if not isinstance(part, TextMessagePart):
            raise ValueError(f"{message.role.value} 消息仅支持文本片段")
        converted.append(_build_text_content_part(part.text))
    return converted
|
||
|
||
|
||
def _convert_user_message_content(message: Message) -> str | List[ChatCompletionContentPartParam]:
    """Convert a user message into OpenAI-compatible content.

    Single-text messages collapse to a bare string. Otherwise every part is
    converted individually; images that cannot be normalized are replaced by
    a text placeholder instead of failing the whole message.

    Args:
        message: Internal unified message.

    Returns:
        str | List[ChatCompletionContentPartParam]: User message content.
    """
    parts = message.parts
    if len(parts) == 1 and isinstance(parts[0], TextMessagePart):
        return parts[0].text

    converted: List[ChatCompletionContentPartParam] = []
    for part in parts:
        if isinstance(part, TextMessagePart):
            converted.append(_build_text_content_part(part.text))
            continue

        normalized = _normalize_image_part_for_openai(part)
        if normalized is None:
            converted.append(_build_text_content_part("[图片内容不可用]"))
            continue

        fmt, payload_base64 = normalized
        converted.append(
            {
                "type": "image_url",
                "image_url": {"url": f"data:image/{fmt};base64,{payload_base64}"},
            }
        )
    return converted
|
||
|
||
|
||
def _convert_assistant_tool_calls(tool_calls: List[ToolCall]) -> List[ChatCompletionMessageFunctionToolCallParam]:
|
||
"""将内部工具调用转换为 OpenAI assistant tool_calls 结构。
|
||
|
||
Args:
|
||
tool_calls: 内部工具调用列表。
|
||
|
||
Returns:
|
||
List[ChatCompletionMessageFunctionToolCallParam]: OpenAI 兼容工具调用结构。
|
||
"""
|
||
converted_tool_calls: List[ChatCompletionMessageFunctionToolCallParam] = []
|
||
for tool_call in tool_calls:
|
||
converted_tool_calls.append(
|
||
{
|
||
"id": tool_call.call_id,
|
||
"type": "function",
|
||
"function": {
|
||
"name": tool_call.func_name,
|
||
"arguments": json.dumps(tool_call.args or {}, ensure_ascii=False),
|
||
},
|
||
}
|
||
)
|
||
return converted_tool_calls
|
||
|
||
|
||
def _sanitize_messages_for_toolless_request(messages: List[Message]) -> List[Message]:
    """Strip historical tool-call traffic before a tool-less request.

    Tool-result messages are dropped entirely; assistant messages that carried
    tool calls are kept only when they also have visible content, and are
    re-emitted without the `tool_calls` payload so compatible endpoints do not
    reject the conversation.

    Args:
        messages: Original conversation history.

    Returns:
        List[Message]: Sanitized history safe to send without tool definitions.
    """
    cleaned: List[Message] = []
    for original in messages:
        if original.role == RoleType.Tool:
            continue

        if original.role == RoleType.Assistant and original.tool_calls:
            if not original.parts:
                # Pure tool-call turn with no text: nothing worth keeping.
                continue
            cleaned.append(
                Message(
                    role=original.role,
                    parts=list(original.parts),
                    tool_call_id=original.tool_call_id,
                    tool_name=original.tool_name,
                    tool_calls=None,
                )
            )
            continue

        cleaned.append(original)

    return cleaned
|
||
|
||
|
||
def _convert_messages(messages: List[Message]) -> List[ChatCompletionMessageParam]:
    """Convert internal messages into the OpenAI-compatible message list.

    Args:
        messages: Internal unified messages.

    Returns:
        List[ChatCompletionMessageParam]: Message structures for the OpenAI SDK.

    Raises:
        ValueError: If a tool message lacks `tool_call_id` or a role is unsupported.
    """
    converted: List[ChatCompletionMessageParam] = []
    for message in messages:
        role = message.role
        if role == RoleType.System:
            system_message: ChatCompletionSystemMessageParam = {
                "role": "system",
                "content": _convert_text_only_message_content(message),
            }
            converted.append(system_message)
        elif role == RoleType.User:
            user_message: ChatCompletionUserMessageParam = {
                "role": "user",
                "content": _convert_user_message_content(message),
            }
            converted.append(user_message)
        elif role == RoleType.Assistant:
            # A pure tool-call turn carries no text; OpenAI encodes that as null content.
            only_tool_calls = not message.parts and message.tool_calls
            assistant_message: ChatCompletionAssistantMessageParam = {
                "role": "assistant",
                "content": None if only_tool_calls else _convert_text_only_message_content(message),
            }
            if message.tool_calls:
                assistant_message["tool_calls"] = _convert_assistant_tool_calls(message.tool_calls)
            converted.append(assistant_message)
        elif role == RoleType.Tool:
            if message.tool_call_id is None:
                raise ValueError("Tool 消息缺少 tool_call_id")
            tool_message: ChatCompletionToolMessageParam = {
                "role": "tool",
                "content": _convert_text_only_message_content(message),
                "tool_call_id": message.tool_call_id,
            }
            converted.append(tool_message)
        else:
            raise ValueError(f"不支持的消息角色:{message.role}")

    return converted
|
||
|
||
|
||
def _convert_tool_options(tool_options: List[ToolOption]) -> List[ChatCompletionToolParam]:
|
||
"""将工具定义转换为 OpenAI 兼容的工具列表。
|
||
|
||
Args:
|
||
tool_options: 内部统一工具定义列表。
|
||
|
||
Returns:
|
||
List[ChatCompletionToolParam]: OpenAI SDK 所需的工具定义列表。
|
||
"""
|
||
converted_tools: List[ChatCompletionToolParam] = []
|
||
for tool_option in tool_options:
|
||
parameters_schema = cast(
|
||
Dict[str, object],
|
||
tool_option.parameters_schema or {"type": "object", "properties": {}},
|
||
)
|
||
function_schema: FunctionDefinition = {
|
||
"name": tool_option.name,
|
||
"description": tool_option.description,
|
||
"parameters": parameters_schema,
|
||
}
|
||
converted_tools.append(
|
||
{
|
||
"type": "function",
|
||
"function": function_schema,
|
||
}
|
||
)
|
||
return converted_tools
|
||
|
||
|
||
def _extract_usage_record(usage: Any) -> UsageTuple | None:
|
||
"""从响应对象中提取 usage 三元组。
|
||
|
||
Args:
|
||
usage: OpenAI SDK 返回的 usage 对象。
|
||
|
||
Returns:
|
||
UsageTuple | None: `(prompt_tokens, completion_tokens, total_tokens)`。
|
||
"""
|
||
if usage is None:
|
||
return None
|
||
prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0
|
||
prompt_cache_hit_tokens = getattr(usage, "prompt_cache_hit_tokens", 0) or 0
|
||
prompt_cache_miss_tokens = getattr(usage, "prompt_cache_miss_tokens", 0) or 0
|
||
prompt_tokens_details = getattr(usage, "prompt_tokens_details", None)
|
||
if prompt_cache_hit_tokens == 0 and prompt_tokens_details is not None:
|
||
prompt_cache_hit_tokens = getattr(prompt_tokens_details, "cached_tokens", 0) or 0
|
||
if prompt_cache_miss_tokens == 0 and prompt_cache_hit_tokens > 0:
|
||
prompt_cache_miss_tokens = max(prompt_tokens - prompt_cache_hit_tokens, 0)
|
||
return (
|
||
prompt_tokens,
|
||
getattr(usage, "completion_tokens", 0) or 0,
|
||
getattr(usage, "total_tokens", 0) or 0,
|
||
prompt_cache_hit_tokens,
|
||
prompt_cache_miss_tokens,
|
||
)
|
||
|
||
|
||
def _parse_tool_arguments(
    raw_arguments: str,
    parse_mode: ToolArgumentParseMode,
    response: Any,
) -> Dict[str, Any]:
    """Parse a raw tool-call argument string into a dict.

    Blank argument strings (as emitted by some providers for parameterless
    tools such as `finish`) are treated as an empty dict.

    Args:
        raw_arguments: Raw argument string from the tool call.
        parse_mode: Argument parsing mode.
        response: Original response object, used for exception context.

    Returns:
        Dict[str, Any]: Parsed argument mapping.

    Raises:
        RespParseException: If the arguments cannot be parsed into a dict.
    """
    if not raw_arguments.strip():
        return {}

    try:
        if parse_mode == ToolArgumentParseMode.STRICT:
            parsed: Any = json.loads(raw_arguments)
        else:
            # Every non-strict mode starts with a best-effort repair pass.
            parsed = repair_json(raw_arguments, return_objects=True, logging=False)
            # Some providers double-encode arguments as a JSON string; only
            # AUTO and DOUBLE_DECODE take a second decoding pass.
            if isinstance(parsed, str) and parse_mode in {
                ToolArgumentParseMode.AUTO,
                ToolArgumentParseMode.DOUBLE_DECODE,
            }:
                parsed = repair_json(parsed, return_objects=True, logging=False)
    except json.JSONDecodeError as exc:
        raise RespParseException(response, f"响应解析失败,无法解析工具调用参数。原始参数:{raw_arguments}") from exc

    if not isinstance(parsed, dict):
        raise RespParseException(
            response,
            f"响应解析失败,工具调用参数必须解析为字典,实际类型为 {type(parsed).__name__}。",
        )
    return parsed
|
||
|
||
|
||
def _extract_reasoning_and_content(
    content: str,
    parse_mode: ReasoningParseMode,
) -> Tuple[str | None, str | None]:
    """Split model text into reasoning and final output via `<think>` tags.

    Args:
        content: Text content returned by the model.
        parse_mode: Reasoning parse mode; NATIVE and NONE skip tag parsing.

    Returns:
        Tuple[str | None, str | None]: `(reasoning_content, content)`.
    """
    if parse_mode in {ReasoningParseMode.NATIVE, ReasoningParseMode.NONE}:
        return None, content

    match = THINK_CONTENT_PATTERN.match(content)
    if match is None:
        return None, content

    closed_think = match.group("think")
    if closed_think is not None:
        return closed_think.strip() or None, match.group("content").strip() or None

    unclosed_think = match.group("think_unclosed")
    if unclosed_think is not None:
        # <think> never closed: everything after the tag is reasoning, no final output.
        return unclosed_think.strip() or None, None

    return None, match.group("content_only").strip() or None
|
||
|
||
|
||
def _extract_xml_tool_calls(
    raw_text: str | None,
    parse_mode: ToolArgumentParseMode,
    response: Any,
) -> Tuple[str | None, List[ToolCall] | None]:
    """Best-effort extraction of tool calls from XML-style text.

    Returns:
        Tuple[str | None, List[ToolCall] | None]: The input text with all
        recognized `<tool_call>` blocks removed (or `None` when nothing
        remains), and the extracted calls (or `None` when none were found).
    """
    if not isinstance(raw_text, str) or not raw_text.strip():
        return raw_text, None

    tool_calls: List[ToolCall] = []

    def _coerce_xml_parameter_value(raw_value: str) -> Any:
        # Map XML parameter text onto Python values: booleans, null/none,
        # JSON objects/arrays (best-effort repaired), or the trimmed string.
        normalized_value = raw_value.strip()
        if not normalized_value:
            return ""
        lowered_value = normalized_value.lower()
        if lowered_value == "true":
            return True
        if lowered_value == "false":
            return False
        if lowered_value in {"null", "none"}:
            return None
        if normalized_value.startswith(("{", "[")):
            try:
                return repair_json(normalized_value, return_objects=True, logging=False)
            except Exception:
                # Unrepairable JSON-looking text stays a plain string.
                return normalized_value
        return normalized_value

    def _parse_xml_parameters(raw_arguments: str) -> Dict[str, Any] | None:
        parameters = {
            match.group("name").strip(): _coerce_xml_parameter_value(match.group("value"))
            for match in XML_PARAMETER_PATTERN.finditer(raw_arguments)
        }
        return parameters or None

    def _replace_tool_call(match: re.Match[str]) -> str:
        # Substitution callback: record the call and delete the block from
        # the text; unrecognized blocks are left untouched.
        body = match.group("body")
        function_match = XML_FUNCTION_CALL_PATTERN.search(body)
        if function_match is None:
            return match.group(0)

        function_name = function_match.group("name").strip()
        raw_arguments = function_match.group("arguments").strip()
        arguments = _parse_xml_parameters(raw_arguments)
        if arguments is None:
            # No <parameter=...> entries: fall back to JSON-style parsing.
            arguments = _parse_tool_arguments(raw_arguments, parse_mode, response) if raw_arguments else {}
        tool_calls.append(
            ToolCall(
                call_id=_build_fallback_tool_call_id("xml_tool_call"),
                func_name=function_name,
                args=arguments,
            )
        )
        return ""

    cleaned_text = XML_TOOL_CALL_PATTERN.sub(_replace_tool_call, raw_text).strip() or None
    return cleaned_text, tool_calls or None
|
||
|
||
|
||
def _log_length_truncation(finish_reason: str | None, model_name: str | None) -> None:
    """Log a notice when the completion was cut off by the max-token limit.

    Args:
        finish_reason: Finish reason reported by the OpenAI-compatible API.
        model_name: Model identifier reported upstream.
    """
    if finish_reason != "length":
        return
    logger.info(f"模型{model_name or ''}因为超过最大 max_token 限制,可能仅输出部分内容,可视情况调整")
|
||
|
||
|
||
def _apply_xml_tool_call_fallback(
    response: APIResponse,
    parse_mode: ToolArgumentParseMode,
    raw_response: Any,
) -> None:
    """Fall back to XML-text parsing when upstream returned no standard tool_calls.

    Temporary workaround for responses that embed tool calls inside
    chain-of-thought content; may change once upstream returns standard
    tool_calls. Mutates `response` in place.
    """
    if response.tool_calls:
        return

    # First look inside the reasoning content, where misbehaving providers
    # tend to bury the tool call.
    reasoning_content, tool_calls = _extract_xml_tool_calls(response.reasoning_content, parse_mode, raw_response)
    if reasoning_content != response.reasoning_content:
        response.reasoning_content = reasoning_content
    if tool_calls:
        response.tool_calls = tool_calls
        if not response.content and reasoning_content:
            # Promote leftover reasoning text to visible content so the
            # response is not treated as empty downstream.
            response.content = reasoning_content
            response.reasoning_content = None
        logger.warning("OpenAI 兼容响应未返回标准 tool_calls,已从 XML 文本兜底解析工具调用")
        return

    # Otherwise scan the visible content for XML tool-call blocks.
    cleaned_content, tool_calls = _extract_xml_tool_calls(response.content, parse_mode, raw_response)
    if cleaned_content != response.content:
        response.content = cleaned_content
    if tool_calls:
        response.tool_calls = tool_calls
        logger.warning("OpenAI 兼容响应未返回标准 tool_calls,已从 XML 文本兜底解析工具调用")
|
||
|
||
|
||
def _coerce_openai_argument(value: Any) -> Any | Omit:
    """Map optional argument values onto what the OpenAI SDK expects.

    Args:
        value: Original argument value.

    Returns:
        Any | Omit: `None` becomes the SDK `omit` sentinel so the field is
        dropped from the request; everything else passes through unchanged.
    """
    return omit if value is None else value
|
||
|
||
|
||
def _snapshot_openai_argument(value: Any | Omit) -> Any | None:
    """Convert an OpenAI SDK argument into a plain value for request snapshots.

    The `omit` sentinel becomes `None`; every other value is kept as-is.
    """
    return None if value is omit else value
|
||
|
||
|
||
def _build_api_status_message(error: APIStatusError) -> str:
|
||
"""构建更适合记录和展示的状态错误信息。
|
||
|
||
Args:
|
||
error: OpenAI SDK 抛出的状态错误。
|
||
|
||
Returns:
|
||
str: 拼装后的错误信息。
|
||
"""
|
||
message_parts: List[str] = []
|
||
if getattr(error, "message", None):
|
||
message_parts.append(str(error.message))
|
||
response_text = getattr(getattr(error, "response", None), "text", None)
|
||
if response_text:
|
||
message_parts.append(str(response_text))
|
||
if message_parts:
|
||
return " | ".join(message_parts)
|
||
return f"上游接口返回状态码 {error.status_code}"
|
||
|
||
|
||
@dataclass(slots=True)
|
||
class _StreamedToolCallState:
|
||
"""流式工具调用累积状态。"""
|
||
|
||
index: int
|
||
call_id: str = ""
|
||
function_name: str = ""
|
||
arguments_buffer: io.StringIO = field(default_factory=io.StringIO)
|
||
|
||
def append_arguments(self, arguments_chunk: str) -> None:
|
||
"""追加一段工具调用参数字符串。
|
||
|
||
Args:
|
||
arguments_chunk: 参数增量片段。
|
||
"""
|
||
self.arguments_buffer.write(arguments_chunk)
|
||
|
||
def close(self) -> None:
|
||
"""关闭内部缓存。"""
|
||
if not self.arguments_buffer.closed:
|
||
self.arguments_buffer.close()
|
||
|
||
|
||
class _OpenAIStreamAccumulator:
    """Accumulator for OpenAI-compatible streaming responses."""

    def __init__(
        self,
        reasoning_parse_mode: ReasoningParseMode,
        tool_argument_parse_mode: ToolArgumentParseMode,
        reasoning_key: str,
    ) -> None:
        """Initialize the accumulator.

        Args:
            reasoning_parse_mode: Reasoning-content parse mode.
            tool_argument_parse_mode: Tool-argument parse mode.
            reasoning_key: Name of the native reasoning field to read.
        """
        self.reasoning_parse_mode = reasoning_parse_mode
        self.tool_argument_parse_mode = tool_argument_parse_mode
        self.reasoning_key = reasoning_key
        self.reasoning_buffer = io.StringIO()
        self.content_buffer = io.StringIO()
        self.tool_call_states: Dict[int, _StreamedToolCallState] = {}
        self.finish_reason: str | None = None
        self.model_name: str | None = None
        # Flipped once a native reasoning field appears; suppresses later
        # <think>-tag parsing of the buffered content.
        self._using_native_reasoning = False

    def capture_event_metadata(self, event: ChatCompletionChunk) -> None:
        """Capture the finish reason and model name from a stream event.

        Args:
            event: Current streaming event.
        """
        if getattr(event, "model", None) and not self.model_name:
            self.model_name = event.model
        if getattr(event, "choices", None):
            finish_reason = getattr(event.choices[0], "finish_reason", None)
            if finish_reason:
                self.finish_reason = finish_reason

    def process_delta(self, delta: ChoiceDelta) -> None:
        """Process one incremental delta.

        Args:
            delta: Current delta object.
        """
        self._process_reasoning_delta(delta)
        self._process_tool_call_delta(delta)

    def _process_reasoning_delta(self, delta: ChoiceDelta) -> None:
        """Route a delta's reasoning vs. regular content into the buffers.

        Args:
            delta: Current delta object.
        """
        native_reasoning = _extract_reasoning_content(delta, self.reasoning_key)
        if native_reasoning is not None:
            self._using_native_reasoning = True
            if self.reasoning_parse_mode != ReasoningParseMode.NONE:
                self.reasoning_buffer.write(native_reasoning)
            return

        content_chunk = getattr(delta, "content", None)
        if not isinstance(content_chunk, str) or content_chunk == "":
            return

        if self.reasoning_parse_mode == ReasoningParseMode.NONE:
            self.content_buffer.write(content_chunk)
            return

        if self.reasoning_parse_mode == ReasoningParseMode.NATIVE:
            self.content_buffer.write(content_chunk)
            return

        # Tag-based modes: buffer everything now; <think> parsing happens in
        # build_response() once the full text is available.
        self.content_buffer.write(content_chunk)

    def _process_tool_call_delta(self, delta: ChoiceDelta) -> None:
        """Accumulate tool-call fragments from a delta, keyed by index.

        Args:
            delta: Current delta object.
        """
        tool_call_deltas = getattr(delta, "tool_calls", None) or []
        for tool_call_delta in tool_call_deltas:
            state = self.tool_call_states.setdefault(tool_call_delta.index, _StreamedToolCallState(index=tool_call_delta.index))
            if tool_call_delta.id:
                state.call_id = tool_call_delta.id
            function = tool_call_delta.function
            if function is not None and function.name:
                state.function_name = function.name
            if function is not None and function.arguments:
                state.append_arguments(function.arguments)

    def build_response(self) -> APIResponse:
        """Build the final APIResponse object.

        Returns:
            APIResponse: The fully accumulated response.

        Raises:
            EmptyResponseException: If the response ends up with neither
                visible content nor tool calls.
            RespParseException: If an accumulated tool call is structurally
                incomplete (missing function name / unparsable arguments).
        """
        response = APIResponse()

        content = self.content_buffer.getvalue().strip()
        reasoning_content = self.reasoning_buffer.getvalue().strip()
        # Only attempt <think>-tag extraction when no native reasoning arrived.
        if not self._using_native_reasoning and self.reasoning_parse_mode != ReasoningParseMode.NONE and content:
            parsed_reasoning_content, parsed_content = _extract_reasoning_and_content(
                content=content,
                parse_mode=self.reasoning_parse_mode,
            )
            if parsed_reasoning_content:
                reasoning_content = parsed_reasoning_content
                content = parsed_content or ""
        if reasoning_content:
            response.reasoning_content = reasoning_content
        if content:
            response.content = content

        if self.tool_call_states:
            response.tool_calls = []
            for index in sorted(self.tool_call_states):
                state = self.tool_call_states[index]
                if not state.function_name:
                    raise RespParseException(None, f"响应解析失败,工具调用 {index} 缺少函数名。")
                raw_arguments = state.arguments_buffer.getvalue().strip()
                arguments = (
                    _parse_tool_arguments(raw_arguments, self.tool_argument_parse_mode, None)
                    if raw_arguments
                    else None
                )
                # Providers may stream tool calls without an id; synthesize one.
                call_id = state.call_id or _build_fallback_tool_call_id(f"tool_call_{index}")
                response.tool_calls.append(ToolCall(call_id=call_id, func_name=state.function_name, args=arguments))

        response.raw_data = {"model": self.model_name} if self.model_name else None
        # Last resort: recover tool calls embedded as XML in the text.
        _apply_xml_tool_call_fallback(response, self.tool_argument_parse_mode, response.raw_data)

        if not response.content and not response.tool_calls:
            raise EmptyResponseException(response.raw_data)

        return response

    def close(self) -> None:
        """Close the internal buffers."""
        if not self.reasoning_buffer.closed:
            self.reasoning_buffer.close()
        if not self.content_buffer.closed:
            self.content_buffer.close()
        for state in self.tool_call_states.values():
            state.close()
|
||
|
||
|
||
async def _default_stream_response_handler(
    resp_stream: AsyncStream[ChatCompletionChunk],
    interrupt_flag: asyncio.Event | None,
    *,
    reasoning_parse_mode: ReasoningParseMode,
    tool_argument_parse_mode: ToolArgumentParseMode,
    reasoning_key: str,
) -> Tuple[APIResponse, UsageTuple | None]:
    """Consume an OpenAI-compatible streaming response.

    Args:
        resp_stream: Streaming response object returned by the OpenAI SDK.
        interrupt_flag: External interruption flag.
        reasoning_parse_mode: Reasoning-content parse mode.
        tool_argument_parse_mode: Tool-argument parse mode.
        reasoning_key: Name of the native reasoning field to read.

    Returns:
        Tuple[APIResponse, UsageTuple | None]: Parsed response and usage stats.

    Raises:
        ReqAbortException: If the external interrupt flag is set mid-stream.
    """
    accumulator = _OpenAIStreamAccumulator(
        reasoning_parse_mode=reasoning_parse_mode,
        tool_argument_parse_mode=tool_argument_parse_mode,
        reasoning_key=reasoning_key,
    )
    usage_record: UsageTuple | None = None

    try:
        async for event in resp_stream:
            if interrupt_flag and interrupt_flag.is_set():
                raise ReqAbortException("请求被外部信号中断")

            accumulator.capture_event_metadata(event)
            # Usage typically arrives on the final chunk; keep the latest seen.
            event_usage = _extract_usage_record(getattr(event, "usage", None))
            if event_usage is not None:
                usage_record = event_usage

            if not getattr(event, "choices", None):
                continue

            accumulator.process_delta(event.choices[0].delta)

        response = accumulator.build_response()
        model_name = None
        if isinstance(response.raw_data, dict):
            model_name = response.raw_data.get("model")
        _log_length_truncation(accumulator.finish_reason, model_name)
        return response, usage_record
    finally:
        # Always release the accumulator's StringIO buffers, even on abort/error.
        accumulator.close()
|
||
|
||
|
||
def _default_normal_response_parser(
    resp: ChatCompletion,
    *,
    reasoning_parse_mode: ReasoningParseMode,
    tool_argument_parse_mode: ToolArgumentParseMode,
    reasoning_key: str,
) -> Tuple[APIResponse, UsageTuple | None]:
    """Parse an OpenAI-compatible non-streaming response.

    Args:
        resp: Chat completion response returned by the OpenAI SDK.
        reasoning_parse_mode: Parse mode for reasoning content.
        tool_argument_parse_mode: Parse mode for tool-call arguments.
        reasoning_key: Name of the native reasoning field allowed to be read.

    Returns:
        Tuple[APIResponse, UsageTuple | None]: Parsed response and usage stats.

    Raises:
        EmptyResponseException: When choices are missing or the response is empty.
        RespParseException: When a tool call has an unsupported type.
    """
    choices = getattr(resp, "choices", None)
    if not choices:
        raise EmptyResponseException(resp, "响应解析失败,choices 为空或缺失")

    api_response = APIResponse()
    message_part = choices[0].message
    native_reasoning = _extract_reasoning_content(message_part, reasoning_key)
    message_content = message_part.content if isinstance(message_part.content, str) else None

    if native_reasoning is not None and reasoning_parse_mode != ReasoningParseMode.NONE:
        # Native reasoning field wins; content is used as-is.
        api_response.reasoning_content = native_reasoning
        api_response.content = message_content
    elif isinstance(message_content, str) and message_content:
        # Otherwise try splitting reasoning out of the plain content.
        reasoning_content, final_content = _extract_reasoning_and_content(
            content=message_content,
            parse_mode=reasoning_parse_mode,
        )
        api_response.reasoning_content = reasoning_content
        api_response.content = final_content

    tool_calls = getattr(message_part, "tool_calls", None) or []
    if tool_calls:
        api_response.tool_calls = []
        for tool_call in tool_calls:
            if tool_call.type != "function":
                raise RespParseException(resp, f"响应解析失败,暂不支持工具调用类型 {tool_call.type}。")
            raw_arguments = (tool_call.function.arguments or "").strip()
            # Some providers return an empty string for parameterless tools
            # (e.g. finish). Treat blank arguments as an empty dict so those
            # tool calls can still execute instead of failing to parse.
            arguments = (
                _parse_tool_arguments(raw_arguments, tool_argument_parse_mode, resp)
                if raw_arguments
                else {}
            )
            api_response.tool_calls.append(
                ToolCall(
                    call_id=tool_call.id,
                    func_name=tool_call.function.name,
                    args=arguments,
                )
            )

    usage_record = _extract_usage_record(getattr(resp, "usage", None))
    api_response.raw_data = resp

    finish_reason = getattr(resp.choices[0], "finish_reason", None)
    _log_length_truncation(finish_reason, getattr(resp, "model", None))
    _apply_xml_tool_call_fallback(api_response, tool_argument_parse_mode, resp)

    if not api_response.content and not api_response.tool_calls:
        raise EmptyResponseException(resp)

    return api_response, usage_record
|
||
|
||
|
||
@client_registry.register_client_class("openai")
class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletion]):
    """OpenAI-compatible client."""

    # Underlying asynchronous OpenAI SDK client instance.
    client: AsyncOpenAI
    # Parse mode for reasoning content in responses.
    reasoning_parse_mode: ReasoningParseMode
    # Name of the native reasoning field allowed to be read from responses.
    reasoning_key: str
    # Parse mode for tool-call arguments.
    tool_argument_parse_mode: ToolArgumentParseMode
||
def __init__(self, api_provider: APIProvider) -> None:
|
||
"""初始化 OpenAI 兼容客户端。
|
||
|
||
Args:
|
||
api_provider: API 提供商配置。
|
||
"""
|
||
super().__init__(api_provider)
|
||
client_config = build_openai_compatible_client_config(api_provider)
|
||
self.reasoning_parse_mode = _normalize_reasoning_parse_mode(api_provider.reasoning_parse_mode)
|
||
self.reasoning_key = _build_reasoning_key(api_provider)
|
||
self.tool_argument_parse_mode = _normalize_tool_argument_parse_mode(api_provider.tool_argument_parse_mode)
|
||
self.client = AsyncOpenAI(
|
||
api_key=client_config.api_key,
|
||
organization=api_provider.organization,
|
||
project=api_provider.project,
|
||
base_url=client_config.base_url,
|
||
timeout=api_provider.timeout,
|
||
max_retries=api_provider.max_retry,
|
||
default_headers=client_config.default_headers or None,
|
||
default_query=client_config.default_query or None,
|
||
)
|
||
|
||
def _build_default_stream_response_handler(
|
||
self,
|
||
request: ResponseRequest,
|
||
) -> ProviderStreamResponseHandler[AsyncStream[ChatCompletionChunk]]:
|
||
"""构建 OpenAI 默认流式响应处理器。
|
||
|
||
Args:
|
||
request: 统一响应请求对象。
|
||
|
||
Returns:
|
||
ProviderStreamResponseHandler[AsyncStream[ChatCompletionChunk]]: 默认流式处理器。
|
||
"""
|
||
del request
|
||
|
||
async def default_stream_handler(
|
||
resp_stream: AsyncStream[ChatCompletionChunk],
|
||
flag: asyncio.Event | None,
|
||
) -> Tuple[APIResponse, UsageTuple | None]:
|
||
"""包装默认流式解析器。"""
|
||
return await _default_stream_response_handler(
|
||
resp_stream,
|
||
flag,
|
||
reasoning_parse_mode=self.reasoning_parse_mode,
|
||
tool_argument_parse_mode=self.tool_argument_parse_mode,
|
||
reasoning_key=self.reasoning_key,
|
||
)
|
||
|
||
return default_stream_handler
|
||
|
||
def _build_default_response_parser(
|
||
self,
|
||
request: ResponseRequest,
|
||
) -> ProviderResponseParser[ChatCompletion]:
|
||
"""构建 OpenAI 默认非流式响应解析器。
|
||
|
||
Args:
|
||
request: 统一响应请求对象。
|
||
|
||
Returns:
|
||
ProviderResponseParser[ChatCompletion]: 默认非流式解析器。
|
||
"""
|
||
del request
|
||
|
||
def default_response_parser(
|
||
response: ChatCompletion,
|
||
) -> Tuple[APIResponse, UsageTuple | None]:
|
||
"""包装默认非流式解析器。"""
|
||
return _default_normal_response_parser(
|
||
response,
|
||
reasoning_parse_mode=self.reasoning_parse_mode,
|
||
tool_argument_parse_mode=self.tool_argument_parse_mode,
|
||
reasoning_key=self.reasoning_key,
|
||
)
|
||
|
||
return default_response_parser
|
||
|
||
    async def _execute_response_request(
        self,
        request: ResponseRequest,
        stream_response_handler: ProviderStreamResponseHandler[AsyncStream[ChatCompletionChunk]],
        response_parser: ProviderResponseParser[ChatCompletion],
    ) -> Tuple[APIResponse, UsageTuple | None]:
        """Execute an OpenAI-compatible text/multimodal response request.

        Args:
            request: Unified response request object.
            stream_response_handler: Streaming response handler.
            response_parser: Non-streaming response parser.

        Returns:
            Tuple[APIResponse, UsageTuple | None]: Unified response object and optional usage info.
        """
        # Provider-side request description captured up front so every failure
        # path below can persist a complete snapshot for debugging.
        snapshot_provider_request = {
            "base_url": self.api_provider.base_url,
            "endpoint": "/chat/completions",
            "method": "POST",
            "operation": "chat.completions.create",
            "organization": self.api_provider.organization,
            "project": self.api_provider.project,
            "request_kwargs": {},
        }
        model_info = request.model_info

        try:
            # Tool-less requests run their messages through a sanitizer first;
            # requests with tools pass the message list through unchanged.
            request_messages = (
                list(request.message_list)
                if request.tool_options
                else _sanitize_messages_for_toolless_request(request.message_list)
            )
            messages_payload: List[ChatCompletionMessageParam] = _convert_messages(request_messages)
            tools_payload: List[ChatCompletionToolParam] | None = (
                _convert_tool_options(request.tool_options) if request.tool_options else None
            )
            openai_response_format = _convert_response_format(request.response_format)
            request_overrides = split_openai_request_overrides(
                request.extra_params,
                reserved_body_keys=CHAT_COMPLETIONS_RESERVED_EXTRA_BODY_KEYS,
            )

            # Explicit extra_body overrides win over request-level values;
            # `omit` keeps the SDK from sending the argument at all.
            temperature_argument = (
                omit if "temperature" in request_overrides.extra_body else _coerce_openai_argument(request.temperature)
            )
            max_tokens_argument = (
                omit
                if "max_tokens" in request_overrides.extra_body or "max_completion_tokens" in request_overrides.extra_body
                else _coerce_openai_argument(request.max_tokens)
            )
            snapshot_provider_request["request_kwargs"] = {
                "extra_body": request_overrides.extra_body or None,
                "extra_headers": request_overrides.extra_headers or None,
                "extra_query": request_overrides.extra_query or None,
                "max_tokens": _snapshot_openai_argument(max_tokens_argument),
                "messages": messages_payload,
                "model": model_info.model_identifier,
                "response_format": _snapshot_openai_argument(openai_response_format),
                "stream": bool(model_info.force_stream_mode),
                "temperature": _snapshot_openai_argument(temperature_argument),
                "tools": tools_payload,
            }
            _save_debug_provider_request_payload(
                model_info.name,
                {
                    "base_url": self.api_provider.base_url,
                    "endpoint": "/chat/completions",
                    "model_name": model_info.name,
                    "model_identifier": model_info.model_identifier,
                    "created_at": datetime.now().isoformat(timespec="seconds"),
                    "request_kwargs": snapshot_provider_request["request_kwargs"],
                },
            )

            if model_info.force_stream_mode:
                # Streaming path: create the request as a task so it can be
                # cancelled via the interrupt flag while awaiting.
                stream_task: asyncio.Task[AsyncStream[ChatCompletionChunk]] = asyncio.create_task(
                    self.client.chat.completions.create(
                        model=model_info.model_identifier,
                        messages=messages_payload,
                        tools=tools_payload or omit,
                        temperature=temperature_argument,
                        max_tokens=max_tokens_argument,
                        stream=True,
                        response_format=openai_response_format,
                        extra_headers=request_overrides.extra_headers or None,
                        extra_query=request_overrides.extra_query or None,
                        extra_body=request_overrides.extra_body or None,
                    )
                )
                raw_response = cast(
                    AsyncStream[ChatCompletionChunk],
                    await await_task_with_interrupt(stream_task, request.interrupt_flag),
                )
                return await stream_response_handler(raw_response, request.interrupt_flag)

            # Non-streaming path: same request shape with stream=False.
            completion_task: asyncio.Task[ChatCompletion] = asyncio.create_task(
                self.client.chat.completions.create(
                    model=model_info.model_identifier,
                    messages=messages_payload,
                    tools=tools_payload or omit,
                    temperature=temperature_argument,
                    max_tokens=max_tokens_argument,
                    stream=False,
                    response_format=openai_response_format,
                    extra_headers=request_overrides.extra_headers or None,
                    extra_query=request_overrides.extra_query or None,
                    extra_body=request_overrides.extra_body or None,
                )
            )
            raw_response = cast(
                ChatCompletion,
                await await_task_with_interrupt(completion_task, request.interrupt_flag),
            )
            return response_parser(raw_response)
        except (EmptyResponseException, RespParseException) as exc:
            # Parse-level failures: snapshot, attach, then re-raise unchanged.
            snapshot_path = save_failed_request_snapshot(
                api_provider=self.api_provider,
                client_type="openai",
                error=exc,
                internal_request=serialize_response_request_snapshot(request),
                model_info=model_info,
                operation="chat.completions.create",
                provider_request=snapshot_provider_request,
            )
            attach_request_snapshot(exc, snapshot_path)
            raise
        except APIConnectionError as exc:
            # Connection failures are wrapped into the project exception type.
            snapshot_path = save_failed_request_snapshot(
                api_provider=self.api_provider,
                client_type="openai",
                error=exc,
                internal_request=serialize_response_request_snapshot(request),
                model_info=model_info,
                operation="chat.completions.create",
                provider_request=snapshot_provider_request,
            )
            wrapped_error = NetworkConnectionError(str(exc))
            attach_request_snapshot(wrapped_error, snapshot_path)
            raise wrapped_error from exc
        except APIStatusError as exc:
            # HTTP error statuses become RespNotOkException with the status code.
            snapshot_path = save_failed_request_snapshot(
                api_provider=self.api_provider,
                client_type="openai",
                error=exc,
                internal_request=serialize_response_request_snapshot(request),
                model_info=model_info,
                operation="chat.completions.create",
                provider_request=snapshot_provider_request,
            )
            wrapped_error = RespNotOkException(exc.status_code, _build_api_status_message(exc))
            attach_request_snapshot(wrapped_error, snapshot_path)
            raise wrapped_error from exc
        except ReqAbortException:
            # External interruption propagates as-is without snapshotting.
            raise
        except Exception as exc:
            # Unknown failures: snapshot once, skipping if a snapshot was
            # already attached further down the call stack.
            if has_request_snapshot(exc):
                raise
            snapshot_path = save_failed_request_snapshot(
                api_provider=self.api_provider,
                client_type="openai",
                error=exc,
                internal_request=serialize_response_request_snapshot(request),
                model_info=model_info,
                operation="chat.completions.create",
                provider_request=snapshot_provider_request,
            )
            attach_request_snapshot(exc, snapshot_path)
            raise
|
||
|
||
async def _execute_embedding_request(
|
||
self,
|
||
request: EmbeddingRequest,
|
||
) -> Tuple[APIResponse, UsageTuple | None]:
|
||
"""执行 OpenAI 兼容的文本嵌入请求。
|
||
|
||
Args:
|
||
request: 统一嵌入请求对象。
|
||
|
||
Returns:
|
||
Tuple[APIResponse, UsageTuple | None]: 统一响应对象与可选使用量信息。
|
||
"""
|
||
model_info = request.model_info
|
||
embedding_input = request.embedding_input
|
||
extra_params = request.extra_params
|
||
snapshot_provider_request = {
|
||
"base_url": self.api_provider.base_url,
|
||
"endpoint": "/embeddings",
|
||
"method": "POST",
|
||
"operation": "embeddings.create",
|
||
"organization": self.api_provider.organization,
|
||
"project": self.api_provider.project,
|
||
"request_kwargs": {},
|
||
}
|
||
|
||
try:
|
||
request_overrides = split_openai_request_overrides(extra_params)
|
||
snapshot_provider_request["request_kwargs"] = {
|
||
"extra_body": request_overrides.extra_body or None,
|
||
"extra_headers": request_overrides.extra_headers or None,
|
||
"extra_query": request_overrides.extra_query or None,
|
||
"input": embedding_input,
|
||
"model": model_info.model_identifier,
|
||
}
|
||
raw_response = await self.client.embeddings.create(
|
||
model=model_info.model_identifier,
|
||
input=embedding_input,
|
||
extra_headers=request_overrides.extra_headers or None,
|
||
extra_query=request_overrides.extra_query or None,
|
||
extra_body=request_overrides.extra_body or None,
|
||
)
|
||
except APIConnectionError as exc:
|
||
snapshot_path = save_failed_request_snapshot(
|
||
api_provider=self.api_provider,
|
||
client_type="openai",
|
||
error=exc,
|
||
internal_request=serialize_embedding_request_snapshot(request),
|
||
model_info=model_info,
|
||
operation="embeddings.create",
|
||
provider_request=snapshot_provider_request,
|
||
)
|
||
wrapped_error = NetworkConnectionError(str(exc))
|
||
attach_request_snapshot(wrapped_error, snapshot_path)
|
||
raise wrapped_error from exc
|
||
except APIStatusError as exc:
|
||
snapshot_path = save_failed_request_snapshot(
|
||
api_provider=self.api_provider,
|
||
client_type="openai",
|
||
error=exc,
|
||
internal_request=serialize_embedding_request_snapshot(request),
|
||
model_info=model_info,
|
||
operation="embeddings.create",
|
||
provider_request=snapshot_provider_request,
|
||
)
|
||
wrapped_error = RespNotOkException(exc.status_code, _build_api_status_message(exc))
|
||
attach_request_snapshot(wrapped_error, snapshot_path)
|
||
raise wrapped_error from exc
|
||
except Exception as exc:
|
||
if has_request_snapshot(exc):
|
||
raise
|
||
snapshot_path = save_failed_request_snapshot(
|
||
api_provider=self.api_provider,
|
||
client_type="openai",
|
||
error=exc,
|
||
internal_request=serialize_embedding_request_snapshot(request),
|
||
model_info=model_info,
|
||
operation="embeddings.create",
|
||
provider_request=snapshot_provider_request,
|
||
)
|
||
attach_request_snapshot(exc, snapshot_path)
|
||
raise
|
||
|
||
response = APIResponse()
|
||
if not raw_response.data:
|
||
exc = RespParseException(raw_response, "嵌入响应解析失败,缺少 embeddings 数据。")
|
||
snapshot_path = save_failed_request_snapshot(
|
||
api_provider=self.api_provider,
|
||
client_type="openai",
|
||
error=exc,
|
||
internal_request=serialize_embedding_request_snapshot(request),
|
||
model_info=model_info,
|
||
operation="embeddings.create",
|
||
provider_request=snapshot_provider_request,
|
||
)
|
||
attach_request_snapshot(exc, snapshot_path)
|
||
raise exc
|
||
if raw_response.data:
|
||
response.embedding = raw_response.data[0].embedding
|
||
else:
|
||
raise RespParseException(raw_response, "响应解析失败,缺失嵌入数据。")
|
||
|
||
usage_record = _extract_usage_record(getattr(raw_response, "usage", None))
|
||
return response, usage_record
|
||
|
||
async def _execute_audio_transcription_request(
|
||
self,
|
||
request: AudioTranscriptionRequest,
|
||
) -> Tuple[APIResponse, UsageTuple | None]:
|
||
"""执行 OpenAI 兼容的音频转录请求。
|
||
|
||
Args:
|
||
request: 统一音频转录请求对象。
|
||
|
||
Returns:
|
||
Tuple[APIResponse, UsageTuple | None]: 统一响应对象与可选使用量信息。
|
||
"""
|
||
model_info = request.model_info
|
||
audio_base64 = request.audio_base64
|
||
extra_params = request.extra_params
|
||
snapshot_provider_request = {
|
||
"base_url": self.api_provider.base_url,
|
||
"endpoint": "/audio/transcriptions",
|
||
"method": "POST",
|
||
"operation": "audio.transcriptions.create",
|
||
"organization": self.api_provider.organization,
|
||
"project": self.api_provider.project,
|
||
"request_kwargs": {},
|
||
}
|
||
|
||
try:
|
||
request_overrides = split_openai_request_overrides(extra_params)
|
||
audio_file: FileTypes = ("audio.wav", io.BytesIO(base64.b64decode(audio_base64)))
|
||
snapshot_provider_request["request_kwargs"] = {
|
||
"audio_base64": audio_base64,
|
||
"extra_body": request_overrides.extra_body or None,
|
||
"extra_headers": request_overrides.extra_headers or None,
|
||
"extra_query": request_overrides.extra_query or None,
|
||
"file_name": "audio.wav",
|
||
"model": model_info.model_identifier,
|
||
}
|
||
raw_response = await self.client.audio.transcriptions.create(
|
||
model=model_info.model_identifier,
|
||
file=audio_file,
|
||
extra_headers=request_overrides.extra_headers or None,
|
||
extra_query=request_overrides.extra_query or None,
|
||
extra_body=request_overrides.extra_body or None,
|
||
)
|
||
except APIConnectionError as exc:
|
||
snapshot_path = save_failed_request_snapshot(
|
||
api_provider=self.api_provider,
|
||
client_type="openai",
|
||
error=exc,
|
||
internal_request=serialize_audio_request_snapshot(request),
|
||
model_info=model_info,
|
||
operation="audio.transcriptions.create",
|
||
provider_request=snapshot_provider_request,
|
||
)
|
||
wrapped_error = NetworkConnectionError(str(exc))
|
||
attach_request_snapshot(wrapped_error, snapshot_path)
|
||
raise wrapped_error from exc
|
||
except APIStatusError as exc:
|
||
snapshot_path = save_failed_request_snapshot(
|
||
api_provider=self.api_provider,
|
||
client_type="openai",
|
||
error=exc,
|
||
internal_request=serialize_audio_request_snapshot(request),
|
||
model_info=model_info,
|
||
operation="audio.transcriptions.create",
|
||
provider_request=snapshot_provider_request,
|
||
)
|
||
wrapped_error = RespNotOkException(exc.status_code, _build_api_status_message(exc))
|
||
attach_request_snapshot(wrapped_error, snapshot_path)
|
||
raise wrapped_error from exc
|
||
except Exception as exc:
|
||
if has_request_snapshot(exc):
|
||
raise
|
||
snapshot_path = save_failed_request_snapshot(
|
||
api_provider=self.api_provider,
|
||
client_type="openai",
|
||
error=exc,
|
||
internal_request=serialize_audio_request_snapshot(request),
|
||
model_info=model_info,
|
||
operation="audio.transcriptions.create",
|
||
provider_request=snapshot_provider_request,
|
||
)
|
||
attach_request_snapshot(exc, snapshot_path)
|
||
raise
|
||
|
||
response = APIResponse()
|
||
transcription_text = raw_response if isinstance(raw_response, str) else getattr(raw_response, "text", None)
|
||
if not isinstance(transcription_text, str):
|
||
exc = RespParseException(raw_response, "音频转写响应解析失败,缺少文本内容。")
|
||
snapshot_path = save_failed_request_snapshot(
|
||
api_provider=self.api_provider,
|
||
client_type="openai",
|
||
error=exc,
|
||
internal_request=serialize_audio_request_snapshot(request),
|
||
model_info=model_info,
|
||
operation="audio.transcriptions.create",
|
||
provider_request=snapshot_provider_request,
|
||
)
|
||
attach_request_snapshot(exc, snapshot_path)
|
||
raise exc
|
||
if isinstance(transcription_text, str):
|
||
response.content = transcription_text
|
||
return response, None
|
||
raise RespParseException(raw_response, "响应解析失败,缺失转录文本。")
|
||
|
||
def get_support_image_formats(self) -> List[str]:
|
||
"""获取支持的图片格式列表。
|
||
|
||
Returns:
|
||
List[str]: 当前客户端支持的图片格式列表。
|
||
"""
|
||
return ["jpg", "jpeg", "png", "webp", "gif"]
|