from collections.abc import Iterable
from dataclasses import dataclass, field
from typing import Any, Callable, Coroutine, Dict, List, Tuple, cast

import asyncio
import base64
import io
import json
import re

from json_repair import repair_json
from openai import APIConnectionError, APIStatusError, AsyncOpenAI, AsyncStream
from openai._types import FileTypes, Omit, omit
from openai.types.chat import (
    ChatCompletion,
    ChatCompletionAssistantMessageParam,
    ChatCompletionChunk,
    ChatCompletionContentPartImageParam,
    ChatCompletionContentPartParam,
    ChatCompletionContentPartTextParam,
    ChatCompletionMessageFunctionToolCallParam,
    ChatCompletionMessageParam,
    ChatCompletionSystemMessageParam,
    ChatCompletionToolMessageParam,
    ChatCompletionToolParam,
    ChatCompletionUserMessageParam,
)
from openai.types.shared_params.function_definition import FunctionDefinition
from openai.types.chat.chat_completion_chunk import ChoiceDelta

from src.common.logger import get_logger
from src.config.model_configs import APIProvider, ReasoningParseMode, ToolArgumentParseMode
from src.llm_models.exceptions import (
    EmptyResponseException,
    NetworkConnectionError,
    ReqAbortException,
    RespNotOkException,
    RespParseException,
)
from src.llm_models.openai_compat import (
    build_openai_compatible_client_config,
    split_openai_request_overrides,
)
from src.llm_models.payload_content.message import ImageMessagePart, Message, RoleType, TextMessagePart
from src.llm_models.payload_content.resp_format import RespFormat, RespFormatType
from src.llm_models.payload_content.tool_option import ToolCall, ToolOption

from .adapter_base import (
    AdapterClient,
    ProviderResponseParser,
    ProviderStreamResponseHandler,
    await_task_with_interrupt,
)
from .base_client import (
    APIResponse,
    AudioTranscriptionRequest,
    EmbeddingRequest,
    ResponseRequest,
    UsageTuple,
    client_registry,
)

logger = get_logger("llm_models")

THINK_CONTENT_PATTERN = re.compile(
    r"<think>(?P<think>.*?)</think>(?P<content>.*)|<think>(?P<think_unclosed>.*)|(?P<content_only>.+)",
    re.DOTALL,
)
"""Regular expression used to parse `<think>` reasoning blocks."""

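# Illustrative behavior of THINK_CONTENT_PATTERN (the three alternatives are tried in order):
#     "<think>plan</think>answer"  -> think="plan", content="answer"
#     "<think>still thinking"      -> think_unclosed="still thinking"
#     "plain answer"               -> content_only="plain answer"
# _extract_reasoning_and_content() below relies on exactly this grouping.
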
CHAT_COMPLETIONS_RESERVED_EXTRA_BODY_KEYS = {
    "max_tokens",
    "messages",
    "model",
    "response_format",
    "stream",
    "temperature",
    "tools",
}
"""Fields carried explicitly by this client and therefore not allowed to fall into `extra_body`."""

OpenAIStreamResponseHandler = Callable[
    [AsyncStream[ChatCompletionChunk], asyncio.Event | None],
    Coroutine[Any, Any, Tuple[APIResponse, UsageTuple | None]],
]
"""Type of an OpenAI streaming response handler."""

OpenAIResponseParser = Callable[[ChatCompletion], Tuple[APIResponse, UsageTuple | None]]
"""Type of an OpenAI non-streaming response parser."""


def _normalize_reasoning_parse_mode(parse_mode: str | ReasoningParseMode) -> ReasoningParseMode:
    """Normalize the configured reasoning parse mode into an enum value.

    Args:
        parse_mode: Raw parse-mode value from the configuration.

    Returns:
        ReasoningParseMode: The normalized parse mode; unknown values fall back to `AUTO`.
    """
    if isinstance(parse_mode, ReasoningParseMode):
        return parse_mode
    try:
        return ReasoningParseMode(parse_mode)
    except ValueError:
        logger.warning(f"Unrecognized reasoning parse mode {parse_mode}, falling back to auto")
        return ReasoningParseMode.AUTO


def _normalize_tool_argument_parse_mode(parse_mode: str | ToolArgumentParseMode) -> ToolArgumentParseMode:
    """Normalize the configured tool-argument parse mode into an enum value.

    Args:
        parse_mode: Raw parse-mode value from the configuration.

    Returns:
        ToolArgumentParseMode: The normalized parse mode; unknown values fall back to `AUTO`.
    """
    if isinstance(parse_mode, ToolArgumentParseMode):
        return parse_mode
    try:
        return ToolArgumentParseMode(parse_mode)
    except ValueError:
        logger.warning(f"Unrecognized tool-argument parse mode {parse_mode}, falling back to auto")
        return ToolArgumentParseMode.AUTO


def _build_text_content_part(text: str) -> ChatCompletionContentPartTextParam:
    """Build a text content part.

    Args:
        text: Text content.

    Returns:
        ChatCompletionContentPartTextParam: An OpenAI-compatible text part.
    """
    return {
        "type": "text",
        "text": text,
    }


def _build_image_content_part(part: ImageMessagePart) -> ChatCompletionContentPartImageParam:
    """Build an image content part.

    Args:
        part: Internal image part.

    Returns:
        ChatCompletionContentPartImageParam: An OpenAI-compatible image part.
    """
    return {
        "type": "image_url",
        "image_url": {
            "url": f"data:image/{part.normalized_image_format};base64,{part.image_base64}",
        },
    }


def _convert_response_format(response_format: RespFormat | None) -> Any:
    """Convert the internal response format into an OpenAI-compatible structure.

    Args:
        response_format: Internal response-format definition.

    Returns:
        Any: A response-format argument accepted by the OpenAI SDK; `omit` when unspecified.
    """
    if response_format is None or response_format.format_type == RespFormatType.TEXT:
        return omit
    if response_format.format_type == RespFormatType.JSON_OBJ:
        return {"type": "json_object"}
    if response_format.format_type == RespFormatType.JSON_SCHEMA:
        return {
            "type": "json_schema",
            "json_schema": response_format.schema,
        }
    return omit

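# Illustrative shape of the converted value for RespFormatType.JSON_SCHEMA, assuming
# RespFormat.schema already carries the OpenAI "json_schema" wrapper (name/schema/strict):
#     {"type": "json_schema", "json_schema": {"name": "my_schema", "schema": {...}, "strict": True}}
# JSON_OBJ simply becomes {"type": "json_object"}; TEXT and None collapse to `omit`.
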
def _convert_text_only_message_content(
    message: Message,
) -> str | List[ChatCompletionContentPartTextParam]:
    """Convert a text-only message into OpenAI-compatible content.

    Args:
        message: Internal unified message object.

    Returns:
        str | List[ChatCompletionContentPartTextParam]: Text content structure.

    Raises:
        ValueError: If the message contains non-text parts.
    """
    if not message.parts:
        return ""
    if len(message.parts) == 1 and isinstance(message.parts[0], TextMessagePart):
        return message.parts[0].text

    content: List[ChatCompletionContentPartTextParam] = []
    for part in message.parts:
        if not isinstance(part, TextMessagePart):
            raise ValueError(f"{message.role.value} messages only support text parts")
        content.append(_build_text_content_part(part.text))
    return content


def _convert_user_message_content(message: Message) -> str | List[ChatCompletionContentPartParam]:
    """Convert a user message into OpenAI-compatible content.

    Args:
        message: Internal unified message object.

    Returns:
        str | List[ChatCompletionContentPartParam]: User message content structure.
    """
    if len(message.parts) == 1 and isinstance(message.parts[0], TextMessagePart):
        return message.parts[0].text

    content: List[ChatCompletionContentPartParam] = []
    for part in message.parts:
        if isinstance(part, TextMessagePart):
            content.append(_build_text_content_part(part.text))
            continue
        content.append(_build_image_content_part(part))
    return content


def _convert_assistant_tool_calls(tool_calls: List[ToolCall]) -> List[ChatCompletionMessageFunctionToolCallParam]:
    """Convert internal tool calls into the OpenAI assistant tool_calls structure.

    Args:
        tool_calls: Internal tool-call list.

    Returns:
        List[ChatCompletionMessageFunctionToolCallParam]: OpenAI-compatible tool-call structures.
    """
    converted_tool_calls: List[ChatCompletionMessageFunctionToolCallParam] = []
    for tool_call in tool_calls:
        converted_tool_calls.append(
            {
                "id": tool_call.call_id,
                "type": "function",
                "function": {
                    "name": tool_call.func_name,
                    "arguments": json.dumps(tool_call.args or {}, ensure_ascii=False),
                },
            }
        )
    return converted_tool_calls


def _convert_messages(messages: List[Message]) -> List[ChatCompletionMessageParam]:
    """Convert the internal message list into an OpenAI-compatible message list.

    Args:
        messages: Internal unified message list.

    Returns:
        List[ChatCompletionMessageParam]: Message structures expected by the OpenAI SDK.
    """
    converted_messages: List[ChatCompletionMessageParam] = []
    for message in messages:
        if message.role == RoleType.System:
            system_payload: ChatCompletionSystemMessageParam = {
                "role": "system",
                "content": _convert_text_only_message_content(message),
            }
            converted_messages.append(system_payload)
            continue

        if message.role == RoleType.User:
            user_payload: ChatCompletionUserMessageParam = {
                "role": "user",
                "content": _convert_user_message_content(message),
            }
            converted_messages.append(user_payload)
            continue

        if message.role == RoleType.Assistant:
            assistant_payload: ChatCompletionAssistantMessageParam = {
                "role": "assistant",
                "content": None if not message.parts and message.tool_calls else _convert_text_only_message_content(message),
            }
            if message.tool_calls:
                assistant_payload["tool_calls"] = _convert_assistant_tool_calls(message.tool_calls)
            converted_messages.append(assistant_payload)
            continue

        if message.role == RoleType.Tool:
            if message.tool_call_id is None:
                raise ValueError("Tool message is missing tool_call_id")
            tool_payload: ChatCompletionToolMessageParam = {
                "role": "tool",
                "content": _convert_text_only_message_content(message),
                "tool_call_id": message.tool_call_id,
            }
            converted_messages.append(tool_payload)
            continue

        raise ValueError(f"Unsupported message role: {message.role}")

    return converted_messages

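# Illustrative conversion (assuming the Message/TextMessagePart constructors used elsewhere
# in this package):
#     _convert_messages([Message(role=RoleType.User, parts=[TextMessagePart(text="hi")])])
#     -> [{"role": "user", "content": "hi"}]
# A user message mixing text and image parts is instead emitted as a list of
# {"type": "text", ...} / {"type": "image_url", ...} content parts.
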
def _convert_tool_options(tool_options: List[ToolOption]) -> List[ChatCompletionToolParam]:
    """Convert tool definitions into an OpenAI-compatible tool list.

    Args:
        tool_options: Internal unified tool definitions.

    Returns:
        List[ChatCompletionToolParam]: Tool definitions expected by the OpenAI SDK.
    """
    converted_tools: List[ChatCompletionToolParam] = []
    for tool_option in tool_options:
        function_schema: FunctionDefinition = {
            "name": tool_option.name,
            "description": tool_option.description,
        }
        parameters_schema = tool_option.parameters_schema
        if parameters_schema is not None:
            function_schema["parameters"] = cast(Dict[str, object], parameters_schema)
        converted_tools.append(
            {
                "type": "function",
                "function": function_schema,
            }
        )
    return converted_tools


def _extract_usage_record(usage: Any) -> UsageTuple | None:
    """Extract the usage triple from a response object.

    Args:
        usage: Usage object returned by the OpenAI SDK.

    Returns:
        UsageTuple | None: `(prompt_tokens, completion_tokens, total_tokens)`.
    """
    if usage is None:
        return None
    return (
        getattr(usage, "prompt_tokens", 0) or 0,
        getattr(usage, "completion_tokens", 0) or 0,
        getattr(usage, "total_tokens", 0) or 0,
    )


def _parse_tool_arguments(
    raw_arguments: str,
    parse_mode: ToolArgumentParseMode,
    response: Any,
) -> Dict[str, Any]:
    """Parse a tool-call argument string.

    Args:
        raw_arguments: Raw tool-call argument string.
        parse_mode: Argument parse mode.
        response: Original response object, used for exception context.

    Returns:
        Dict[str, Any]: Parsed argument dictionary.

    Raises:
        RespParseException: If the arguments cannot be parsed into a dictionary.
    """
    try:
        if parse_mode == ToolArgumentParseMode.STRICT:
            arguments: Any = json.loads(raw_arguments)
        elif parse_mode == ToolArgumentParseMode.REPAIR:
            arguments = repair_json(raw_arguments, return_objects=True, logging=False)
        else:
            arguments = repair_json(raw_arguments, return_objects=True, logging=False)
            if isinstance(arguments, str) and parse_mode in {
                ToolArgumentParseMode.AUTO,
                ToolArgumentParseMode.DOUBLE_DECODE,
            }:
                arguments = repair_json(arguments, return_objects=True, logging=False)
    except json.JSONDecodeError as exc:
        raise RespParseException(response, f"Response parsing failed: unable to parse tool-call arguments. Raw arguments: {raw_arguments}") from exc

    if not isinstance(arguments, dict):
        raise RespParseException(
            response,
            f"Response parsing failed: tool-call arguments must parse to a dictionary, got {type(arguments).__name__}.",
        )
    return arguments

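# Why the second repair_json() pass above: some providers double-encode the arguments, i.e.
# the payload is a JSON string whose value is itself JSON text such as {"city": "Tokyo"}.
# The first pass then returns that inner text as a plain str, so AUTO / DOUBLE_DECODE run
# one more pass to obtain the actual dictionary.
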
def _extract_reasoning_and_content(
    content: str,
    parse_mode: ReasoningParseMode,
) -> Tuple[str | None, str | None]:
    """Extract reasoning content and the final output from a text body.

    Args:
        content: Text content returned by the model.
        parse_mode: Reasoning parse mode.

    Returns:
        Tuple[str | None, str | None]: `(reasoning_content, content)`.
    """
    if parse_mode in {ReasoningParseMode.NATIVE, ReasoningParseMode.NONE}:
        return None, content

    match = THINK_CONTENT_PATTERN.match(content)
    if not match:
        return None, content
    if match.group("think") is not None:
        reasoning_content = match.group("think").strip() or None
        final_content = match.group("content").strip() or None
        return reasoning_content, final_content
    if match.group("think_unclosed") is not None:
        return match.group("think_unclosed").strip() or None, None
    return None, match.group("content_only").strip() or None


def _log_length_truncation(finish_reason: str | None, model_name: str | None) -> None:
    """Log a notice when the output was truncated by the length limit.

    Args:
        finish_reason: Finish reason returned by the OpenAI-compatible API.
        model_name: Model identifier returned by the upstream service.
    """
    if finish_reason == "length":
        logger.info(f"Model {model_name or ''} hit the max_tokens limit and may have returned partial output; consider raising the limit")


def _coerce_openai_argument(value: Any) -> Any | Omit:
    """Coerce an optional argument into the value expected by the OpenAI SDK.

    Args:
        value: Raw argument value.

    Returns:
        Any | Omit: `None` is converted to `omit`; all other values are returned unchanged.
    """
    if value is None:
        return omit
    return value


def _build_api_status_message(error: APIStatusError) -> str:
    """Build a status-error message better suited for logging and display.

    Args:
        error: Status error raised by the OpenAI SDK.

    Returns:
        str: The assembled error message.
    """
    message_parts: List[str] = []
    if getattr(error, "message", None):
        message_parts.append(str(error.message))
    response_text = getattr(getattr(error, "response", None), "text", None)
    if response_text:
        message_parts.append(str(response_text)[:300])
    if message_parts:
        return " | ".join(message_parts)
    return f"Upstream API returned status code {error.status_code}"


@dataclass(slots=True)
class _StreamedToolCallState:
    """Accumulated state of a single streamed tool call."""

    index: int
    call_id: str = ""
    function_name: str = ""
    arguments_buffer: io.StringIO = field(default_factory=io.StringIO)

    def append_arguments(self, arguments_chunk: str) -> None:
        """Append a chunk of the tool-call argument string.

        Args:
            arguments_chunk: Incremental argument fragment.
        """
        self.arguments_buffer.write(arguments_chunk)

    def close(self) -> None:
        """Close the internal buffer."""
        if not self.arguments_buffer.closed:
            self.arguments_buffer.close()


class _OpenAIStreamAccumulator:
    """Accumulator for OpenAI-compatible streaming responses."""

    def __init__(
        self,
        reasoning_parse_mode: ReasoningParseMode,
        tool_argument_parse_mode: ToolArgumentParseMode,
    ) -> None:
        """Initialize the accumulator.

        Args:
            reasoning_parse_mode: Reasoning-content parse mode.
            tool_argument_parse_mode: Tool-argument parse mode.
        """
        self.reasoning_parse_mode = reasoning_parse_mode
        self.tool_argument_parse_mode = tool_argument_parse_mode
        self.reasoning_buffer = io.StringIO()
        self.content_buffer = io.StringIO()
        self.tool_call_states: Dict[int, _StreamedToolCallState] = {}
        self.finish_reason: str | None = None
        self.model_name: str | None = None
        self._using_native_reasoning = False

    def capture_event_metadata(self, event: ChatCompletionChunk) -> None:
        """Capture the finish reason and model name from an event.

        Args:
            event: Current stream event.
        """
        if getattr(event, "model", None) and not self.model_name:
            self.model_name = event.model
        if getattr(event, "choices", None):
            finish_reason = getattr(event.choices[0], "finish_reason", None)
            if finish_reason:
                self.finish_reason = finish_reason

    def process_delta(self, delta: ChoiceDelta) -> None:
        """Process one incremental chunk.

        Args:
            delta: Current delta object.
        """
        self._process_reasoning_delta(delta)
        self._process_tool_call_delta(delta)

    def _process_reasoning_delta(self, delta: ChoiceDelta) -> None:
        """Process reasoning content and regular content.

        Args:
            delta: Current delta object.
        """
        native_reasoning = getattr(delta, "reasoning_content", None)
        if isinstance(native_reasoning, str) and native_reasoning:
            self._using_native_reasoning = True
            if self.reasoning_parse_mode != ReasoningParseMode.NONE:
                self.reasoning_buffer.write(native_reasoning)
            return

        content_chunk = getattr(delta, "content", None)
        if not isinstance(content_chunk, str) or content_chunk == "":
            return

        if self.reasoning_parse_mode == ReasoningParseMode.NONE:
            self.content_buffer.write(content_chunk)
            return

        if self.reasoning_parse_mode == ReasoningParseMode.NATIVE:
            self.content_buffer.write(content_chunk)
            return

        self.content_buffer.write(content_chunk)

    def _process_tool_call_delta(self, delta: ChoiceDelta) -> None:
        """Process tool-call deltas.

        Args:
            delta: Current delta object.
        """
        tool_call_deltas = getattr(delta, "tool_calls", None) or []
        for tool_call_delta in tool_call_deltas:
            state = self.tool_call_states.setdefault(tool_call_delta.index, _StreamedToolCallState(index=tool_call_delta.index))
            if tool_call_delta.id:
                state.call_id = tool_call_delta.id
            function = tool_call_delta.function
            if function is not None and function.name:
                state.function_name = function.name
            if function is not None and function.arguments:
                state.append_arguments(function.arguments)

    def build_response(self) -> APIResponse:
        """Build the final APIResponse object.

        Returns:
            APIResponse: The fully accumulated response object.

        Raises:
            EmptyResponseException: If the response has neither visible content nor tool calls.
            RespParseException: If a tool-call structure is incomplete.
        """
        response = APIResponse()

        content = self.content_buffer.getvalue().strip()
        reasoning_content = self.reasoning_buffer.getvalue().strip()
        if not self._using_native_reasoning and self.reasoning_parse_mode != ReasoningParseMode.NONE and content:
            parsed_reasoning_content, parsed_content = _extract_reasoning_and_content(
                content=content,
                parse_mode=self.reasoning_parse_mode,
            )
            if parsed_reasoning_content:
                reasoning_content = parsed_reasoning_content
                content = parsed_content or ""
        if reasoning_content:
            response.reasoning_content = reasoning_content
        if content:
            response.content = content

        if self.tool_call_states:
            response.tool_calls = []
            for index in sorted(self.tool_call_states):
                state = self.tool_call_states[index]
                if not state.function_name:
                    raise RespParseException(None, f"Response parsing failed: tool call {index} is missing a function name.")
                raw_arguments = state.arguments_buffer.getvalue().strip()
                arguments = (
                    _parse_tool_arguments(raw_arguments, self.tool_argument_parse_mode, None)
                    if raw_arguments
                    else None
                )
                call_id = state.call_id or f"tool_call_{index}"
                response.tool_calls.append(ToolCall(call_id=call_id, func_name=state.function_name, args=arguments))

        response.raw_data = {"model": self.model_name} if self.model_name else None

        if not response.content and not response.tool_calls:
            raise EmptyResponseException()

        return response

    def close(self) -> None:
        """Close the internal buffers."""
        if not self.reasoning_buffer.closed:
            self.reasoning_buffer.close()
        if not self.content_buffer.closed:
            self.content_buffer.close()
        for state in self.tool_call_states.values():
            state.close()

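# Sketch of how a streamed tool call is assembled (illustrative chunk sequence, not a real
# transcript): a first delta may carry only {index: 0, id: "call_1", function: {name: "get_weather"}},
# followed by deltas carrying argument fragments such as '{"city": ' and '"Tokyo"}'.
# All fragments sharing the same index land in one _StreamedToolCallState, and build_response()
# parses the concatenated argument string via _parse_tool_arguments().

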
async def _default_stream_response_handler(
    resp_stream: AsyncStream[ChatCompletionChunk],
    interrupt_flag: asyncio.Event | None,
    *,
    reasoning_parse_mode: ReasoningParseMode,
    tool_argument_parse_mode: ToolArgumentParseMode,
) -> Tuple[APIResponse, UsageTuple | None]:
    """Handle an OpenAI-compatible streaming response.

    Args:
        resp_stream: Streaming response object returned by the OpenAI SDK.
        interrupt_flag: External interrupt flag.
        reasoning_parse_mode: Reasoning-content parse mode.
        tool_argument_parse_mode: Tool-argument parse mode.

    Returns:
        Tuple[APIResponse, UsageTuple | None]: The parsed response and usage statistics.
    """
    accumulator = _OpenAIStreamAccumulator(
        reasoning_parse_mode=reasoning_parse_mode,
        tool_argument_parse_mode=tool_argument_parse_mode,
    )
    usage_record: UsageTuple | None = None

    try:
        async for event in resp_stream:
            if interrupt_flag and interrupt_flag.is_set():
                raise ReqAbortException("Request was interrupted by an external signal")

            accumulator.capture_event_metadata(event)
            event_usage = _extract_usage_record(getattr(event, "usage", None))
            if event_usage is not None:
                usage_record = event_usage

            if not getattr(event, "choices", None):
                continue

            accumulator.process_delta(event.choices[0].delta)

        response = accumulator.build_response()
        model_name = None
        if isinstance(response.raw_data, dict):
            model_name = response.raw_data.get("model")
        _log_length_truncation(accumulator.finish_reason, model_name)
        return response, usage_record
    finally:
        accumulator.close()


def _default_normal_response_parser(
    resp: ChatCompletion,
    *,
    reasoning_parse_mode: ReasoningParseMode,
    tool_argument_parse_mode: ToolArgumentParseMode,
) -> Tuple[APIResponse, UsageTuple | None]:
    """Parse an OpenAI-compatible non-streaming response.

    Args:
        resp: Chat completion response returned by the OpenAI SDK.
        reasoning_parse_mode: Reasoning-content parse mode.
        tool_argument_parse_mode: Tool-argument parse mode.

    Returns:
        Tuple[APIResponse, UsageTuple | None]: The parsed response and usage statistics.

    Raises:
        EmptyResponseException: If choices are missing or the response content is empty.
    """
    choices = getattr(resp, "choices", None)
    if not choices:
        raise EmptyResponseException("Response parsing failed: choices are empty or missing")

    api_response = APIResponse()
    message_part = choices[0].message
    native_reasoning = getattr(message_part, "reasoning_content", None)
    message_content = message_part.content if isinstance(message_part.content, str) else None

    if isinstance(native_reasoning, str) and native_reasoning and reasoning_parse_mode != ReasoningParseMode.NONE:
        api_response.reasoning_content = native_reasoning
        api_response.content = message_content
    elif isinstance(message_content, str) and message_content:
        reasoning_content, final_content = _extract_reasoning_and_content(
            content=message_content,
            parse_mode=reasoning_parse_mode,
        )
        api_response.reasoning_content = reasoning_content
        api_response.content = final_content

    tool_calls = getattr(message_part, "tool_calls", None) or []
    if tool_calls:
        api_response.tool_calls = []
        for tool_call in tool_calls:
            if tool_call.type != "function":
                raise RespParseException(resp, f"Response parsing failed: unsupported tool-call type {tool_call.type}.")
            raw_arguments = tool_call.function.arguments or ""
            arguments = _parse_tool_arguments(raw_arguments, tool_argument_parse_mode, resp)
            api_response.tool_calls.append(
                ToolCall(
                    call_id=tool_call.id,
                    func_name=tool_call.function.name,
                    args=arguments,
                )
            )

    usage_record = _extract_usage_record(getattr(resp, "usage", None))
    api_response.raw_data = resp

    finish_reason = getattr(resp.choices[0], "finish_reason", None)
    _log_length_truncation(finish_reason, getattr(resp, "model", None))

    if not api_response.content and not api_response.tool_calls:
        raise EmptyResponseException()

    return api_response, usage_record


@client_registry.register_client_class("openai")
class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletion]):
    """OpenAI-compatible client."""

    client: AsyncOpenAI
    reasoning_parse_mode: ReasoningParseMode
    tool_argument_parse_mode: ToolArgumentParseMode

    def __init__(self, api_provider: APIProvider) -> None:
        """Initialize the OpenAI-compatible client.

        Args:
            api_provider: API provider configuration.
        """
        super().__init__(api_provider)
        client_config = build_openai_compatible_client_config(api_provider)
        self.reasoning_parse_mode = _normalize_reasoning_parse_mode(api_provider.reasoning_parse_mode)
        self.tool_argument_parse_mode = _normalize_tool_argument_parse_mode(api_provider.tool_argument_parse_mode)
        self.client = AsyncOpenAI(
            api_key=client_config.api_key,
            organization=api_provider.organization,
            project=api_provider.project,
            base_url=client_config.base_url,
            timeout=api_provider.timeout,
            max_retries=api_provider.max_retry,
            default_headers=client_config.default_headers or None,
            default_query=client_config.default_query or None,
        )

    def _build_default_stream_response_handler(
        self,
        request: ResponseRequest,
    ) -> ProviderStreamResponseHandler[AsyncStream[ChatCompletionChunk]]:
        """Build the default OpenAI streaming response handler.

        Args:
            request: Unified response request object.

        Returns:
            ProviderStreamResponseHandler[AsyncStream[ChatCompletionChunk]]: The default streaming handler.
        """
        del request

        async def default_stream_handler(
            resp_stream: AsyncStream[ChatCompletionChunk],
            flag: asyncio.Event | None,
        ) -> Tuple[APIResponse, UsageTuple | None]:
            """Wrap the default streaming parser."""
            return await _default_stream_response_handler(
                resp_stream,
                flag,
                reasoning_parse_mode=self.reasoning_parse_mode,
                tool_argument_parse_mode=self.tool_argument_parse_mode,
            )

        return default_stream_handler

    def _build_default_response_parser(
        self,
        request: ResponseRequest,
    ) -> ProviderResponseParser[ChatCompletion]:
        """Build the default OpenAI non-streaming response parser.

        Args:
            request: Unified response request object.

        Returns:
            ProviderResponseParser[ChatCompletion]: The default non-streaming parser.
        """
        del request

        def default_response_parser(
            response: ChatCompletion,
        ) -> Tuple[APIResponse, UsageTuple | None]:
            """Wrap the default non-streaming parser."""
            return _default_normal_response_parser(
                response,
                reasoning_parse_mode=self.reasoning_parse_mode,
                tool_argument_parse_mode=self.tool_argument_parse_mode,
            )

        return default_response_parser

    async def _execute_response_request(
        self,
        request: ResponseRequest,
        stream_response_handler: ProviderStreamResponseHandler[AsyncStream[ChatCompletionChunk]],
        response_parser: ProviderResponseParser[ChatCompletion],
    ) -> Tuple[APIResponse, UsageTuple | None]:
        """Execute an OpenAI-compatible text/multimodal response request.

        Args:
            request: Unified response request object.
            stream_response_handler: Streaming response handler.
            response_parser: Non-streaming response parser.

        Returns:
            Tuple[APIResponse, UsageTuple | None]: Unified response object and optional usage info.
        """
        model_info = request.model_info
        messages: Iterable[ChatCompletionMessageParam] = _convert_messages(request.message_list)
        tools: Iterable[ChatCompletionToolParam] | Omit = (
            _convert_tool_options(request.tool_options) if request.tool_options else omit
        )
        openai_response_format = _convert_response_format(request.response_format)
        request_overrides = split_openai_request_overrides(
            request.extra_params,
            reserved_body_keys=CHAT_COMPLETIONS_RESERVED_EXTRA_BODY_KEYS,
        )

        temperature_argument = (
            omit if "temperature" in request_overrides.extra_body else _coerce_openai_argument(request.temperature)
        )
        max_tokens_argument = (
            omit
            if "max_tokens" in request_overrides.extra_body or "max_completion_tokens" in request_overrides.extra_body
            else _coerce_openai_argument(request.max_tokens)
        )

        try:
            if model_info.force_stream_mode:
                stream_task: asyncio.Task[AsyncStream[ChatCompletionChunk]] = asyncio.create_task(
                    self.client.chat.completions.create(
                        model=model_info.model_identifier,
                        messages=messages,
                        tools=tools,
                        temperature=temperature_argument,
                        max_tokens=max_tokens_argument,
                        stream=True,
                        response_format=openai_response_format,
                        extra_headers=request_overrides.extra_headers or None,
                        extra_query=request_overrides.extra_query or None,
                        extra_body=request_overrides.extra_body or None,
                    )
                )
                raw_response = cast(
                    AsyncStream[ChatCompletionChunk],
                    await await_task_with_interrupt(stream_task, request.interrupt_flag),
                )
                return await stream_response_handler(raw_response, request.interrupt_flag)

            completion_task: asyncio.Task[ChatCompletion] = asyncio.create_task(
                self.client.chat.completions.create(
                    model=model_info.model_identifier,
                    messages=messages,
                    tools=tools,
                    temperature=temperature_argument,
                    max_tokens=max_tokens_argument,
                    stream=False,
                    response_format=openai_response_format,
                    extra_headers=request_overrides.extra_headers or None,
                    extra_query=request_overrides.extra_query or None,
                    extra_body=request_overrides.extra_body or None,
                )
            )
            raw_response = cast(
                ChatCompletion,
                await await_task_with_interrupt(completion_task, request.interrupt_flag),
            )
            return response_parser(raw_response)
        except APIConnectionError as exc:
            raise NetworkConnectionError(str(exc)) from exc
        except APIStatusError as exc:
            raise RespNotOkException(exc.status_code, _build_api_status_message(exc)) from exc

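    # Note on the override handling above: when the caller already routes "temperature",
    # "max_tokens" or "max_completion_tokens" through the extra_body overrides, the matching
    # keyword argument is set to `omit`, so the same parameter is never sent to the API twice.
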
    async def _execute_embedding_request(
        self,
        request: EmbeddingRequest,
    ) -> Tuple[APIResponse, UsageTuple | None]:
        """Execute an OpenAI-compatible text embedding request.

        Args:
            request: Unified embedding request object.

        Returns:
            Tuple[APIResponse, UsageTuple | None]: Unified response object and optional usage info.
        """
        model_info = request.model_info
        embedding_input = request.embedding_input
        extra_params = request.extra_params
        request_overrides = split_openai_request_overrides(extra_params)

        try:
            raw_response = await self.client.embeddings.create(
                model=model_info.model_identifier,
                input=embedding_input,
                extra_headers=request_overrides.extra_headers or None,
                extra_query=request_overrides.extra_query or None,
                extra_body=request_overrides.extra_body or None,
            )
        except APIConnectionError as exc:
            raise NetworkConnectionError(str(exc)) from exc
        except APIStatusError as exc:
            raise RespNotOkException(exc.status_code, _build_api_status_message(exc)) from exc

        response = APIResponse()
        if raw_response.data:
            response.embedding = raw_response.data[0].embedding
        else:
            raise RespParseException(raw_response, "Response parsing failed: embedding data is missing.")

        usage_record = _extract_usage_record(getattr(raw_response, "usage", None))
        return response, usage_record

    async def _execute_audio_transcription_request(
        self,
        request: AudioTranscriptionRequest,
    ) -> Tuple[APIResponse, UsageTuple | None]:
        """Execute an OpenAI-compatible audio transcription request.

        Args:
            request: Unified audio transcription request object.

        Returns:
            Tuple[APIResponse, UsageTuple | None]: Unified response object and optional usage info.
        """
        model_info = request.model_info
        audio_base64 = request.audio_base64
        extra_params = request.extra_params
        request_overrides = split_openai_request_overrides(extra_params)
        audio_file: FileTypes = ("audio.wav", io.BytesIO(base64.b64decode(audio_base64)))

        try:
            raw_response = await self.client.audio.transcriptions.create(
                model=model_info.model_identifier,
                file=audio_file,
                extra_headers=request_overrides.extra_headers or None,
                extra_query=request_overrides.extra_query or None,
                extra_body=request_overrides.extra_body or None,
            )
        except APIConnectionError as exc:
            raise NetworkConnectionError(str(exc)) from exc
        except APIStatusError as exc:
            raise RespNotOkException(exc.status_code, _build_api_status_message(exc)) from exc

        response = APIResponse()
        transcription_text = raw_response if isinstance(raw_response, str) else getattr(raw_response, "text", None)
        if isinstance(transcription_text, str):
            response.content = transcription_text
            return response, None
        raise RespParseException(raw_response, "Response parsing failed: transcription text is missing.")

    def get_support_image_formats(self) -> List[str]:
        """Get the list of supported image formats.

        Returns:
            List[str]: Image formats supported by this client.
        """
        return ["jpg", "jpeg", "png", "webp", "gif"]
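
# Minimal construction sketch (illustrative only; the provider entry normally comes from the
# model configuration rather than being built by hand):
#     provider = APIProvider(...)          # one entry from model_configs
#     client = OpenaiClient(provider)
# Thanks to @client_registry.register_client_class("openai"), the framework can also resolve
# this client by the "openai" provider type through client_registry instead of importing it.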