- Added support for scheduling background tasks to build emoji and image descriptions when not found in cache. - Improved error handling and logging for emoji and image processing. - Updated `SessionMessage` processing to allow for optional heavy media analysis and voice transcription. - Refactored logging messages for better clarity and consistency across various modules. - Introduced a new function to build outbound log previews for messages, enhancing logging capabilities.
974 lines
36 KiB
Python
974 lines
36 KiB
Python
from typing import Any, AsyncIterator, Callable, Coroutine, Dict, List, Optional, Tuple, cast
|
||
|
||
import asyncio
|
||
import base64
|
||
import io
|
||
import json
|
||
|
||
from google import genai
|
||
from google.genai.errors import (
|
||
ClientError,
|
||
FunctionInvocationError,
|
||
ServerError,
|
||
UnknownFunctionCallArgumentError,
|
||
UnsupportedFunctionError,
|
||
)
|
||
from google.genai.types import (
|
||
Candidate,
|
||
Content,
|
||
ContentListUnion,
|
||
ContentUnion,
|
||
EmbedContentConfig,
|
||
EmbedContentResponse,
|
||
FunctionDeclaration,
|
||
GenerateContentConfig,
|
||
GenerateContentResponse,
|
||
GoogleSearch,
|
||
HarmBlockThreshold,
|
||
HarmCategory,
|
||
HttpOptions,
|
||
Part,
|
||
SafetySetting,
|
||
ThinkingConfig,
|
||
Tool,
|
||
)
|
||
|
||
from src.common.logger import get_logger
|
||
from src.config.model_configs import APIProvider
|
||
from src.llm_models.exceptions import (
|
||
EmptyResponseException,
|
||
NetworkConnectionError,
|
||
ReqAbortException,
|
||
RespNotOkException,
|
||
RespParseException,
|
||
)
|
||
from src.llm_models.payload_content.message import ImageMessagePart, Message, RoleType, TextMessagePart
|
||
from src.llm_models.payload_content.resp_format import RespFormat, RespFormatType
|
||
from src.llm_models.payload_content.tool_option import ToolCall, ToolOption
|
||
|
||
from .adapter_base import (
|
||
AdapterClient,
|
||
ProviderResponseParser,
|
||
ProviderStreamResponseHandler,
|
||
await_task_with_interrupt,
|
||
)
|
||
from .base_client import (
|
||
APIResponse,
|
||
AudioTranscriptionRequest,
|
||
EmbeddingRequest,
|
||
ResponseRequest,
|
||
UsageTuple,
|
||
client_registry,
|
||
)
|
||
|
||
# Module-level logger used by every helper in this Gemini client adapter.
logger = get_logger("Gemini客户端")

# Type of a Gemini streaming response handler: it receives the async iterator
# of response chunks plus an optional interrupt event, and resolves to the
# unified response together with optional usage information.
GeminiStreamResponseHandler = Callable[
    [AsyncIterator[GenerateContentResponse], asyncio.Event | None],
    Coroutine[Any, Any, Tuple[APIResponse, Optional[UsageTuple]]],
]

# Type of a Gemini non-streaming response parser: it maps one complete
# response to the unified response together with optional usage information.
GeminiResponseParser = Callable[
    [GenerateContentResponse],
    Tuple[APIResponse, Optional[UsageTuple]],
]
|
||
|
||
THINKING_BUDGET_LIMITS: Dict[str, Dict[str, int | bool]] = {
|
||
"gemini-2.5-flash": {"min": 1, "max": 24576, "can_disable": True},
|
||
"gemini-2.5-flash-lite": {"min": 512, "max": 24576, "can_disable": True},
|
||
"gemini-2.5-pro": {"min": 128, "max": 32768, "can_disable": False},
|
||
}
|
||
"""不同 Gemini 模型允许的思考预算范围。"""
|
||
|
||
THINKING_BUDGET_AUTO = -1
|
||
"""自动思考预算模式,由模型自行决定。"""
|
||
|
||
THINKING_BUDGET_DISABLED = 0
|
||
"""禁用思考预算模式。仅部分模型支持。"""
|
||
|
||
# Default safety policy: every listed harm category is set to BLOCK_NONE.
# Per the original note, this is meant to keep Gemini from returning empty
# responses on some content.
GEMINI_SAFE_SETTINGS: List[SafetySetting] = [
    SafetySetting(category=harm_category, threshold=HarmBlockThreshold.BLOCK_NONE)
    for harm_category in (
        HarmCategory.HARM_CATEGORY_HATE_SPEECH,
        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
        HarmCategory.HARM_CATEGORY_HARASSMENT,
        HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
        HarmCategory.HARM_CATEGORY_CIVIC_INTEGRITY,
    )
]
|
||
|
||
# Extra-parameter names that this client interprets itself and therefore
# never forwards verbatim to `GenerateContentConfig`.
GENERATE_CONFIG_RESERVED_EXTRA_PARAMS = {
    "audio_mime_type",
    "enable_google_search",
    "include_thoughts",
    "thinking_budget",
    "transcription_prompt",
}

# Extra-parameter names that may be forwarded to `EmbedContentConfig`.
EMBED_CONFIG_SUPPORTED_EXTRA_PARAMS = {
    "auto_truncate",
    "mime_type",
    "output_dimensionality",
    "task_type",
    "title",
}
|
||
|
||
|
||
def _normalize_image_mime_type(image_format: str) -> str:
    """Turn an image format name into an image MIME type.

    Args:
        image_format: Image format name, e.g. `png` or `jpg`. Matching is
            case-insensitive.

    Returns:
        str: `image/jpeg` for `jpg` and `jpeg`; otherwise `image/` followed
        by the lower-cased format name.
    """
    subtype = image_format.lower()
    # "jpg" is only an alias; the MIME subtype is always spelled "jpeg".
    if subtype == "jpg":
        subtype = "jpeg"
    return f"image/{subtype}"
|
||
|
||
|
||
def _build_non_tool_parts(message: Message) -> List[Part]:
    """Convert the text and image segments of a message into Gemini parts.

    Args:
        message: Internal unified message object.

    Returns:
        List[Part]: One `Part` per text or image segment, in message order.
        Segments of any other type are skipped.
    """
    gemini_parts: List[Part] = []
    for segment in message.parts:
        if isinstance(segment, TextMessagePart):
            gemini_parts.append(Part.from_text(text=segment.text))
        elif isinstance(segment, ImageMessagePart):
            # Images are stored base64-encoded; Gemini wants raw bytes + MIME type.
            image_bytes = base64.b64decode(segment.image_base64)
            mime_type = _normalize_image_mime_type(segment.normalized_image_format)
            gemini_parts.append(Part.from_bytes(data=image_bytes, mime_type=mime_type))
    return gemini_parts
|
||
|
||
|
||
def _normalize_function_response_payload(message: Message) -> Dict[str, Any]:
    """Turn an internal tool-result message into a Gemini function-response payload.

    Args:
        message: Tool-result message; only its `content` attribute is read.

    Returns:
        Dict[str, Any]: Object suitable for `Part.from_function_response()`:

        * non-string content -> `{"result": content}`
        * blank string -> `{}`
        * string holding a JSON object -> that object
        * string holding any other JSON value -> `{"result": <parsed value>}`
        * string that is not valid JSON -> `{"result": <original string>}`
    """
    raw = message.content
    if not isinstance(raw, str):
        return {"result": raw}

    trimmed = raw.strip()
    if not trimmed:
        return {}

    try:
        decoded = json.loads(trimmed)
    except json.JSONDecodeError:
        # Not JSON: hand back the original (untrimmed) text.
        return {"result": raw}

    return decoded if isinstance(decoded, dict) else {"result": decoded}
|
||
|
||
|
||
def _get_candidates(response: GenerateContentResponse) -> List[Candidate]:
    """Return the candidate list of a Gemini response without ever yielding `None`.

    Args:
        response: Gemini response object.

    Returns:
        List[Candidate]: The response's own candidate list when it is
        non-empty; otherwise a new empty list.
    """
    found = response.candidates
    return found if found else []
|
||
|
||
|
||
def _extract_response_json_schema(response_format: RespFormat) -> Dict[str, object] | None:
    """Pull the JSON Schema out of an internal response-format definition.

    Args:
        response_format: Output format definition.

    Returns:
        Dict[str, object] | None: Whatever `get_schema_object()` returns,
        typed for use as `response_json_schema`; `None` when it returns `None`.
    """
    schema = response_format.get_schema_object()
    # `cast` is a typing-only no-op; the object is returned unchanged.
    return None if schema is None else cast(Dict[str, object], schema)
|
||
|
||
|
||
def _convert_messages(messages: List[Message]) -> Tuple[ContentListUnion, str | None]:
    """Translate the internal message list into Gemini contents.

    System messages are collected into a single system instruction. User
    messages become `user` contents, assistant messages become `model`
    contents (including their function calls), and tool results become
    `tool` contents holding one function-response part.

    Args:
        messages: Internal unified message list.

    Returns:
        Tuple[ContentListUnion, str | None]: `contents` and the optional
        `system_instruction` (system texts joined by blank lines, or `None`
        when there is no system message).

    Raises:
        ValueError: If a system message has no text, a tool message lacks
            `tool_call_id` or refers to an id no earlier assistant message
            declared, or a message has an unsupported role.
    """
    gemini_contents: List[ContentUnion] = []
    system_texts: List[str] = []
    # call_id -> function name, recorded from assistant tool calls so a later
    # tool-result message can be labelled with the function it answers.
    function_name_by_id: Dict[str, str] = {}

    for msg in messages:
        role = msg.role

        if role == RoleType.System:
            text = msg.get_text_content().strip()
            if not text:
                raise ValueError("Gemini 的 system message 必须为非空文本")
            system_texts.append(text)

        elif role == RoleType.User:
            gemini_contents.append(Content(role="user", parts=_build_non_tool_parts(msg)))

        elif role == RoleType.Assistant:
            model_parts = _build_non_tool_parts(msg)
            for call in msg.tool_calls or ():
                model_parts.append(Part.from_function_call(name=call.func_name, args=call.args or {}))
                function_name_by_id[call.call_id] = call.func_name
            gemini_contents.append(Content(role="model", parts=model_parts))

        elif role == RoleType.Tool:
            call_id = msg.tool_call_id
            if not call_id:
                raise ValueError("Gemini 工具结果消息缺少 tool_call_id")
            function_name = function_name_by_id.get(call_id)
            if not function_name:
                raise ValueError(f"Gemini 无法根据 tool_call_id={call_id} 找到对应的工具名称")
            response_part = Part.from_function_response(
                name=function_name,
                response=_normalize_function_response_payload(msg),
            )
            gemini_contents.append(Content(role="tool", parts=[response_part]))

        else:
            raise ValueError(f"不支持的消息角色: {role}")

    # Every collected text is already stripped and non-empty, so a plain join
    # is enough; an empty join result collapses to None.
    joined_system_text = "\n\n".join(system_texts)
    return gemini_contents, (joined_system_text or None)
|
||
|
||
|
||
def _build_tools(tool_options: List[ToolOption]) -> List[Tool]:
    """Convert internal tool definitions into the Gemini `Tool` list.

    Args:
        tool_options: Internal unified tool definitions.

    Returns:
        List[Tool]: A single `Tool` wrapping all function declarations, or
        an empty list when `tool_options` is empty.
    """
    declarations: List[FunctionDeclaration] = []
    for option in tool_options:
        fields: Dict[str, Any] = dict(name=option.name, description=option.description)
        schema = option.parameters_schema
        # The schema key is only sent when the tool actually declares one.
        if schema is not None:
            fields["parameters_json_schema"] = schema
        declarations.append(FunctionDeclaration(**fields))

    if not declarations:
        return []
    return [Tool(function_declarations=declarations)]
|
||
|
||
|
||
def _extract_usage_record(response: GenerateContentResponse) -> Optional[UsageTuple]:
    """Read token usage from a Gemini response.

    Args:
        response: Gemini response object.

    Returns:
        Optional[UsageTuple]: `(prompt, completion, total)` token counts, or
        `None` when the response has no `usage_metadata`. The completion
        count is candidate tokens plus thought tokens; a missing or `None`
        counter counts as 0.
    """
    usage = getattr(response, "usage_metadata", None)
    if usage is None:
        return None

    def counter(field_name: str) -> int:
        # Absent attributes and explicit None both count as zero.
        return getattr(usage, field_name, 0) or 0

    prompt_count = counter("prompt_token_count")
    completion_count = counter("candidates_token_count") + counter("thoughts_token_count")
    total_count = counter("total_token_count")
    return prompt_count, completion_count, total_count
|
||
|
||
|
||
def _extract_finish_reason(response: GenerateContentResponse | None) -> str | None:
    """Return the finish reason reported by a Gemini response.

    Args:
        response: Gemini response object, or `None`.

    Returns:
        str | None: `str()` of the first truthy `finish_reason` (falling
        back to `finishReason`) among the candidates; `None` when the
        response is `None`, has no candidates, or none carries a reason.
    """
    if response is None:
        return None

    # An empty candidate list simply means the loop body never runs.
    for candidate in _get_candidates(response):
        reason = getattr(candidate, "finish_reason", None) or getattr(candidate, "finishReason", None)
        if reason:
            return str(reason)
    return None
|
||
|
||
|
||
def _warn_if_max_tokens_truncated(
    response: GenerateContentResponse | None,
    content: str | None,
    tool_calls: List[ToolCall] | None,
) -> None:
    """Log a warning when Gemini stopped because of the token limit.

    Nothing is logged unless the finish reason contains `MAX_TOKENS`. The
    wording depends on whether any visible output (non-blank text or at
    least one tool call) was produced.

    Args:
        response: Gemini response object.
        content: Visible text parsed so far.
        tool_calls: Tool calls parsed so far.
    """
    reason = _extract_finish_reason(response)
    if reason is None or "MAX_TOKENS" not in reason:
        return

    if (content and content.strip()) or tool_calls:
        logger.warning(
            "Gemini 响应因达到 max_tokens 限制被部分截断,可能影响回复完整性,建议调整模型 max_tokens 配置。"
        )
    else:
        logger.warning("Gemini 响应因达到 max_tokens 限制被截断,且未返回可见输出,请检查模型 max_tokens 配置。")
|
||
|
||
|
||
def _collect_function_calls(response: GenerateContentResponse) -> List[ToolCall]:
    """Extract the tool calls contained in a Gemini response.

    The response-level `function_calls` attribute is used when it is
    non-empty; otherwise every part of every candidate is scanned for a
    `function_call`.

    Args:
        response: Gemini response object.

    Returns:
        List[ToolCall]: Normalized tool calls. A call without an id gets
        `gemini-tool-call-<n>`, where `n` is its 1-based position in this
        response.

    Raises:
        RespParseException: If a call has no non-empty string name, or its
            arguments are not a dict.
    """
    pending = getattr(response, "function_calls", None)
    candidates = _get_candidates(response)
    if candidates and not pending:
        pending = [
            embedded_call
            for candidate in candidates
            for part in (getattr(getattr(candidate, "content", None), "parts", None) or [])
            if (embedded_call := getattr(part, "function_call", None)) is not None
        ]

    if not pending:
        return []

    collected: List[ToolCall] = []
    for position, raw_call in enumerate(pending, start=1):
        name = getattr(raw_call, "name", None)
        identifier = getattr(raw_call, "id", None) or f"gemini-tool-call-{position}"
        arguments = getattr(raw_call, "args", None) or {}
        if not (isinstance(name, str) and name):
            raise RespParseException(response, "响应解析失败,Gemini 工具调用缺少 name 字段")
        if not isinstance(arguments, dict):
            raise RespParseException(response, "响应解析失败,Gemini 工具调用参数无法解析为字典")
        collected.append(ToolCall(call_id=identifier, func_name=name, args=arguments))
    return collected
|
||
|
||
|
||
def _process_stream_chunk(
    chunk: GenerateContentResponse,
    content_buffer: io.StringIO,
    tool_calls_buffer: List[ToolCall],
    response: APIResponse,
) -> None:
    """Fold one streamed Gemini chunk into the accumulated state.

    Text parts flagged as thoughts are appended to
    `response.reasoning_content`; all other text goes to `content_buffer`.
    Tool calls found in the chunk are appended to `tool_calls_buffer`.

    Args:
        chunk: The current streamed response chunk.
        content_buffer: Buffer collecting the visible text.
        tool_calls_buffer: Buffer collecting tool calls across chunks.
        response: The unified response object being accumulated.

    Raises:
        RespParseException: Propagated from `_collect_function_calls` when a
            function call in the chunk is malformed.
    """
    for candidate in _get_candidates(chunk):
        candidate_content = getattr(candidate, "content", None)
        for part in getattr(candidate_content, "parts", None) or []:
            part_text = getattr(part, "text", None)
            if not part_text:
                continue
            if getattr(part, "thought", False):
                response.reasoning_content = (response.reasoning_content or "") + part_text
            else:
                content_buffer.write(part_text)

    # `_collect_function_calls` numbers id-less calls from 1 within each
    # chunk, so calls arriving in different chunks would share the same
    # synthetic id. Re-number colliding synthetic ids so every buffered call
    # stays uniquely addressable by its call_id. Ids supplied by the API are
    # never rewritten.
    synthetic_prefix = "gemini-tool-call-"
    used_ids = {buffered_call.call_id for buffered_call in tool_calls_buffer}
    for tool_call in _collect_function_calls(chunk):
        call_id = tool_call.call_id
        if call_id in used_ids and isinstance(call_id, str) and call_id.startswith(synthetic_prefix):
            sequence = len(tool_calls_buffer) + 1
            call_id = f"{synthetic_prefix}{sequence}"
            while call_id in used_ids:
                sequence += 1
                call_id = f"{synthetic_prefix}{sequence}"
            tool_call = ToolCall(call_id=call_id, func_name=tool_call.func_name, args=tool_call.args)
        used_ids.add(call_id)
        tool_calls_buffer.append(tool_call)
|
||
|
||
|
||
def _build_stream_api_response(
    content_buffer: io.StringIO,
    tool_calls_buffer: List[ToolCall],
    last_response: GenerateContentResponse | None,
    response: APIResponse,
) -> APIResponse:
    """Finalize the unified response from the streaming buffers.

    The text buffer is always closed by this function.

    Args:
        content_buffer: Buffer holding the visible text.
        tool_calls_buffer: Buffer holding the collected tool calls.
        last_response: The last Gemini chunk received; stored as `raw_data`.
        response: The response object accumulated so far.

    Returns:
        APIResponse: The same `response` object, now completed.

    Raises:
        EmptyResponseException: If the response has no text, no tool calls
            and no reasoning content.
    """
    has_text = content_buffer.tell() > 0
    if has_text:
        response.content = content_buffer.getvalue()
    content_buffer.close()

    if tool_calls_buffer:
        # Copy so the response does not alias the caller's buffer.
        response.tool_calls = [*tool_calls_buffer]
    response.raw_data = last_response

    _warn_if_max_tokens_truncated(last_response, response.content, response.tool_calls)

    if not (response.content or response.tool_calls or response.reasoning_content):
        raise EmptyResponseException()
    return response
|
||
|
||
|
||
async def _default_stream_response_handler(
    response_stream: AsyncIterator[GenerateContentResponse],
    interrupt_flag: asyncio.Event | None,
) -> Tuple[APIResponse, Optional[UsageTuple]]:
    """Consume a Gemini response stream and build the unified response.

    Args:
        response_stream: Async iterator of Gemini response chunks.
        interrupt_flag: Optional external interrupt event, checked after
            each chunk is received.

    Returns:
        Tuple[APIResponse, Optional[UsageTuple]]: The unified response and
        the most recent usage information seen in the stream (or `None`).

    Raises:
        ReqAbortException: If `interrupt_flag` is set while streaming.
        EmptyResponseException: Propagated from `_build_stream_api_response`
            when the stream produced no text, tool calls or reasoning.
    """
    content_buffer = io.StringIO()
    tool_calls_buffer: List[ToolCall] = []
    api_response = APIResponse()
    usage_record: Optional[UsageTuple] = None
    last_response: GenerateContentResponse | None = None

    try:
        async for chunk in response_stream:
            last_response = chunk
            if interrupt_flag and interrupt_flag.is_set():
                raise ReqAbortException("请求被外部信号中断")
            _process_stream_chunk(chunk, content_buffer, tool_calls_buffer, api_response)
            # Keep the latest non-empty usage record; later chunks override earlier ones.
            usage_record = _extract_usage_record(chunk) or usage_record
        return _build_stream_api_response(content_buffer, tool_calls_buffer, last_response, api_response), usage_record
    finally:
        # `finally` rather than `except Exception`: task cancellation raises
        # asyncio.CancelledError, a BaseException, which would otherwise skip
        # the cleanup. On success the builder has already closed the buffer.
        if not content_buffer.closed:
            content_buffer.close()
|
||
|
||
|
||
def _default_normal_response_parser(
    response: GenerateContentResponse,
) -> Tuple[APIResponse, Optional[UsageTuple]]:
    """Parse a complete (non-streaming) Gemini response.

    Thought parts are concatenated into `reasoning_content`. The remaining
    text parts are joined and stripped to form `content`; when that is
    empty, the response's own `text` attribute is used instead.

    Args:
        response: Gemini response object.

    Returns:
        Tuple[APIResponse, Optional[UsageTuple]]: The unified response and
        optional usage information.

    Raises:
        EmptyResponseException: If the response has no text, no tool calls
            and no reasoning content.
    """
    parsed = APIResponse(raw_data=response)
    text_chunks: List[str] = []

    for candidate in _get_candidates(response):
        candidate_content = getattr(candidate, "content", None)
        for part in getattr(candidate_content, "parts", None) or []:
            text = getattr(part, "text", None)
            if not text:
                continue
            if getattr(part, "thought", False):
                parsed.reasoning_content = (parsed.reasoning_content or "") + text
            else:
                text_chunks.append(text)

    visible_text = "".join(text_chunks).strip()
    # `response.text` is only consulted when no visible text part was found.
    parsed.content = visible_text if visible_text else getattr(response, "text", None)

    function_calls = _collect_function_calls(response)
    if function_calls:
        parsed.tool_calls = function_calls

    usage = _extract_usage_record(response)
    _warn_if_max_tokens_truncated(response, parsed.content, parsed.tool_calls)

    if not (parsed.content or parsed.tool_calls or parsed.reasoning_content):
        raise EmptyResponseException("响应中既无文本内容也无工具调用")
    return parsed, usage
|
||
|
||
|
||
def _build_http_options(api_provider: APIProvider) -> HttpOptions:
    """Build the Gemini SDK `HttpOptions` from the provider configuration.

    A trailing path segment that looks like an API version (`v` followed by
    a digit, alphanumeric only, e.g. `v1` or `v1beta`) is split off the base
    URL and passed as `api_version`. A host name that merely starts with
    "v" (e.g. `https://vertex.example.com`) is never taken for a version.

    Args:
        api_provider: API provider configuration; `timeout` (seconds, may be
            `None`) and `base_url` are read.

    Returns:
        HttpOptions: Gemini SDK HTTP options object.
    """
    options: Dict[str, Any] = {}
    if api_provider.timeout is not None:
        # Provider timeout is scaled by 1000 (seconds -> milliseconds).
        options["timeout"] = int(api_provider.timeout * 1000)

    # Tolerate a missing base URL instead of failing on `.strip()`.
    base_url = (api_provider.base_url or "").strip()
    if base_url:
        trimmed_url = base_url.rstrip("/")
        prefix, separator, last_segment = trimmed_url.rpartition("/")
        is_version_segment = (
            bool(separator)
            and bool(prefix)
            # A prefix ending in "/" is just the scheme ("https:/"), which
            # means the last segment is the host, not a path segment.
            and not prefix.endswith("/")
            and len(last_segment) >= 2
            and last_segment[0] == "v"
            and last_segment[1].isdigit()
            and last_segment.isalnum()
        )
        if is_version_segment:
            options["base_url"] = f"{prefix}/"
            options["api_version"] = last_segment
        else:
            options["base_url"] = f"{trimmed_url}/"

    return HttpOptions(**options)
|
||
|
||
|
||
def _filter_generate_content_extra_params(extra_params: Dict[str, Any]) -> Dict[str, Any]:
    """Select the extra parameters that may be forwarded to `GenerateContentConfig`.

    A key is kept only when it is not reserved for this client's own
    handling and is a declared field of `GenerateContentConfig`.

    Args:
        extra_params: Model-level extra parameters.

    Returns:
        Dict[str, Any]: New dict with the forwardable entries, in their
        original order.
    """
    return {
        name: value
        for name, value in extra_params.items()
        if name not in GENERATE_CONFIG_RESERVED_EXTRA_PARAMS and name in GenerateContentConfig.model_fields
    }
|
||
|
||
|
||
def _build_embed_content_config(extra_params: Dict[str, Any]) -> EmbedContentConfig:
    """Build the Gemini embedding configuration.

    `task_type` is always set and defaults to `SEMANTIC_SIMILARITY`. Every
    other supported key is copied only when present in `extra_params`.

    Args:
        extra_params: Model-level extra parameters.

    Returns:
        EmbedContentConfig: Gemini embedding configuration object.
    """
    settings: Dict[str, Any] = {"task_type": extra_params.get("task_type", "SEMANTIC_SIMILARITY")}
    settings.update(
        {
            name: extra_params[name]
            for name in EMBED_CONFIG_SUPPORTED_EXTRA_PARAMS
            if name != "task_type" and name in extra_params
        }
    )
    return EmbedContentConfig(**settings)
|
||
|
||
|
||
@client_registry.register_client_class("gemini")
class GeminiClient(AdapterClient[AsyncIterator[GenerateContentResponse], GenerateContentResponse]):
    """Client adapter built on the official Gemini SDK (`google.genai`).

    Registered in `client_registry` under the key `"gemini"`.
    """

    # Underlying SDK client; created in `__init__`.
    client: genai.Client

    def __init__(self, api_provider: APIProvider) -> None:
        """Create the SDK client from the provider configuration.

        Args:
            api_provider: API provider configuration (API key, base URL, timeout).
        """
        super().__init__(api_provider)
        self.client = genai.Client(
            api_key=api_provider.api_key,
            http_options=_build_http_options(api_provider),
        )

    @staticmethod
    def clamp_thinking_budget(extra_params: Dict[str, Any] | None, model_id: str) -> int:
        """Clamp the requested thinking budget to what the model allows.

        Limits are looked up by exact model id first, then by the longest
        configured id that is a `-`-delimited prefix of `model_id`.

        Args:
            extra_params: Request extra parameters; `thinking_budget` is read
                when present. A value that cannot be converted to `int`
                falls back to automatic mode with a warning.
            model_id: Identifier of the model in use.

        Returns:
            int: `THINKING_BUDGET_AUTO`, `THINKING_BUDGET_DISABLED` (only
            when the model allows disabling), or a budget inside the model's
            `[min, max]` range. Models without configured limits always get
            `THINKING_BUDGET_AUTO`.
        """
        thinking_budget = THINKING_BUDGET_AUTO
        if extra_params and "thinking_budget" in extra_params:
            try:
                thinking_budget = int(extra_params["thinking_budget"])
            except (TypeError, ValueError):
                logger.warning(f"无效的 thinking_budget={extra_params['thinking_budget']},已回退为自动模式")

        limits: Dict[str, int | bool] | None = None
        if model_id in THINKING_BUDGET_LIMITS:
            limits = THINKING_BUDGET_LIMITS[model_id]
        else:
            # Longest prefix first, so "gemini-2.5-flash-lite-x" matches the
            # "-lite" entry rather than plain "gemini-2.5-flash".
            for candidate_prefix in sorted(THINKING_BUDGET_LIMITS.keys(), key=len, reverse=True):
                if model_id == candidate_prefix or model_id.startswith(f"{candidate_prefix}-"):
                    limits = THINKING_BUDGET_LIMITS[candidate_prefix]
                    break

        if thinking_budget == THINKING_BUDGET_AUTO:
            return THINKING_BUDGET_AUTO

        if thinking_budget == THINKING_BUDGET_DISABLED:
            if limits and bool(limits.get("can_disable", False)):
                return THINKING_BUDGET_DISABLED
            if limits:
                # Disabling is not allowed for this model: use its minimum instead.
                minimum_value = int(limits["min"])
                logger.warning(f"模型 {model_id} 不支持禁用思考预算,已回退为最小值 {minimum_value}")
                return minimum_value
            return THINKING_BUDGET_AUTO

        if limits is None:
            logger.warning(f"模型 {model_id} 未配置思考预算范围,已回退为自动模式")
            return THINKING_BUDGET_AUTO

        minimum_value = int(limits["min"])
        maximum_value = int(limits["max"])
        if thinking_budget < minimum_value:
            logger.warning(f"模型 {model_id} 的 thinking_budget={thinking_budget} 过小,已调整为 {minimum_value}")
            return minimum_value
        if thinking_budget > maximum_value:
            logger.warning(f"模型 {model_id} 的 thinking_budget={thinking_budget} 过大,已调整为 {maximum_value}")
            return maximum_value
        return thinking_budget

    @staticmethod
    def _resolve_model_identifier(model_identifier: str, extra_params: Dict[str, Any]) -> Tuple[str, bool]:
        """Resolve the model identifier actually sent to Gemini.

        A `-search` suffix is stripped from the identifier and switches
        Google Search on, regardless of `enable_google_search`.

        Args:
            model_identifier: Configured model identifier.
            extra_params: Model-level extra parameters; `enable_google_search`
                is read.

        Returns:
            Tuple[str, bool]: `(resolved model identifier, whether Google
            Search is enabled)`.
        """
        enable_google_search = bool(extra_params.get("enable_google_search", False))
        resolved_model_identifier = model_identifier
        if resolved_model_identifier.endswith("-search"):
            resolved_model_identifier = resolved_model_identifier.removesuffix("-search")
            enable_google_search = True
        return resolved_model_identifier, enable_google_search

    def _build_generation_config(
        self,
        *,
        model_identifier: str,
        system_instruction: str | None,
        tool_options: List[ToolOption] | None,
        response_format: RespFormat | None,
        max_tokens: int | None,
        temperature: float | None,
        extra_params: Dict[str, Any],
        enable_google_search: bool,
    ) -> GenerateContentConfig:
        """Build the Gemini generation configuration.

        Values passed through `extra_params` take precedence: each default
        below is applied only when the corresponding key is not already set.

        Args:
            model_identifier: Model identifier actually used for the request.
            system_instruction: System instruction text.
            tool_options: Internal tool definitions.
            response_format: Output format definition.
            max_tokens: Maximum number of output tokens.
            temperature: Sampling temperature.
            extra_params: Model-level extra parameters.
            enable_google_search: Whether to append the Google Search tool.

        Returns:
            GenerateContentConfig: Gemini generation configuration object.
        """
        config_payload = _filter_generate_content_extra_params(extra_params)

        if max_tokens is not None and "max_output_tokens" not in config_payload:
            config_payload["max_output_tokens"] = max_tokens
        if temperature is not None and "temperature" not in config_payload:
            config_payload["temperature"] = temperature
        if system_instruction and "system_instruction" not in config_payload:
            config_payload["system_instruction"] = system_instruction
        if "response_modalities" not in config_payload:
            config_payload["response_modalities"] = ["TEXT"]
        if "safety_settings" not in config_payload:
            config_payload["safety_settings"] = GEMINI_SAFE_SETTINGS
        if "thinking_config" not in config_payload:
            config_payload["thinking_config"] = ThinkingConfig(
                include_thoughts=bool(extra_params.get("include_thoughts", True)),
                thinking_budget=self.clamp_thinking_budget(extra_params, model_identifier),
            )

        tools = _build_tools(tool_options) if tool_options else []
        if enable_google_search:
            tools.append(Tool(google_search=GoogleSearch()))
        if tools:
            if "tools" in config_payload:
                # Merge with tools supplied via extra_params rather than replacing them.
                existing_tools = config_payload["tools"]
                if isinstance(existing_tools, list):
                    config_payload["tools"] = [*existing_tools, *tools]
                else:
                    config_payload["tools"] = [existing_tools, *tools]
            else:
                config_payload["tools"] = tools

        if response_format is not None:
            if response_format.format_type == RespFormatType.TEXT:
                config_payload.setdefault("response_mime_type", "text/plain")
            elif response_format.format_type == RespFormatType.JSON_OBJ:
                config_payload.setdefault("response_mime_type", "application/json")
            elif response_format.format_type == RespFormatType.JSON_SCHEMA:
                config_payload.setdefault("response_mime_type", "application/json")
                response_json_schema = _extract_response_json_schema(response_format)
                # Do not override a schema already supplied in either form.
                if (
                    response_json_schema is not None
                    and "response_json_schema" not in config_payload
                    and "response_schema" not in config_payload
                ):
                    config_payload["response_json_schema"] = response_json_schema

        return GenerateContentConfig(**config_payload)

    def _build_default_stream_response_handler(
        self,
        request: ResponseRequest,
    ) -> ProviderStreamResponseHandler[AsyncIterator[GenerateContentResponse]]:
        """Return the default Gemini streaming response handler.

        Args:
            request: Unified response request (unused).

        Returns:
            ProviderStreamResponseHandler[AsyncIterator[GenerateContentResponse]]:
            The module-level default streaming handler.
        """
        del request
        return _default_stream_response_handler

    def _build_default_response_parser(
        self,
        request: ResponseRequest,
    ) -> ProviderResponseParser[GenerateContentResponse]:
        """Return the default Gemini non-streaming response parser.

        Args:
            request: Unified response request (unused).

        Returns:
            ProviderResponseParser[GenerateContentResponse]: The module-level
            default non-streaming parser.
        """
        del request
        return _default_normal_response_parser

    async def _execute_response_request(
        self,
        request: ResponseRequest,
        stream_response_handler: ProviderStreamResponseHandler[AsyncIterator[GenerateContentResponse]],
        response_parser: ProviderResponseParser[GenerateContentResponse],
    ) -> Tuple[APIResponse, UsageTuple | None]:
        """Run a Gemini text / multimodal generation request.

        Streaming is used when `model_info.force_stream_mode` is set.

        Args:
            request: Unified response request.
            stream_response_handler: Handler for the streaming path.
            response_parser: Parser for the non-streaming path.

        Returns:
            Tuple[APIResponse, UsageTuple | None]: Unified response and
            optional usage information.

        Raises:
            ReqAbortException: If the request is interrupted.
            RespNotOkException: If the SDK reports a client or server error.
            RespParseException: If the response cannot be parsed, or the SDK
                reports a function-call error.
            EmptyResponseException: If the response carries no usable output.
            NetworkConnectionError: For any other failure.
        """
        model_info = request.model_info
        contents, system_instruction = _convert_messages(request.message_list)
        model_identifier, enable_google_search = self._resolve_model_identifier(
            model_info.model_identifier,
            request.extra_params,
        )
        generation_config = self._build_generation_config(
            model_identifier=model_identifier,
            system_instruction=system_instruction,
            tool_options=request.tool_options,
            response_format=request.response_format,
            max_tokens=request.max_tokens,
            temperature=request.temperature,
            extra_params=request.extra_params,
            enable_google_search=enable_google_search,
        )

        try:
            if model_info.force_stream_mode:
                stream_task: asyncio.Task[AsyncIterator[GenerateContentResponse]] = asyncio.create_task(
                    self.client.aio.models.generate_content_stream(
                        model=model_identifier,
                        contents=contents,
                        config=generation_config,
                    )
                )
                raw_response_stream = cast(
                    AsyncIterator[GenerateContentResponse],
                    await await_task_with_interrupt(stream_task, request.interrupt_flag),
                )
                return await stream_response_handler(raw_response_stream, request.interrupt_flag)

            completion_task: asyncio.Task[GenerateContentResponse] = asyncio.create_task(
                self.client.aio.models.generate_content(
                    model=model_identifier,
                    contents=contents,
                    config=generation_config,
                )
            )
            raw_response = cast(
                GenerateContentResponse,
                await await_task_with_interrupt(completion_task, request.interrupt_flag),
            )
            return response_parser(raw_response)
        except (
            ReqAbortException,
            EmptyResponseException,
            RespParseException,
            RespNotOkException,
            NetworkConnectionError,
        ):
            # Already one of our own error types (e.g. raised by the parser
            # or stream handler): let it through unchanged rather than
            # relabelling it as a network failure below.
            raise
        except (ClientError, ServerError) as exc:
            status_code = int(getattr(exc, "code", 500) or 500)
            raise RespNotOkException(status_code, str(exc)) from exc
        except (UnknownFunctionCallArgumentError, UnsupportedFunctionError, FunctionInvocationError) as exc:
            raise RespParseException(None, f"Gemini 工具调用参数错误: {exc}") from exc
        except Exception as exc:
            raise NetworkConnectionError(str(exc)) from exc

    async def _execute_embedding_request(
        self,
        request: EmbeddingRequest,
    ) -> Tuple[APIResponse, UsageTuple | None]:
        """Run a Gemini text embedding request.

        Args:
            request: Unified embedding request.

        Returns:
            Tuple[APIResponse, UsageTuple | None]: Unified response holding
            the first embedding's values, plus a usage tuple. Usage is the
            billable character count from the response metadata when
            available, otherwise `len(embedding_input)`.

        Raises:
            RespNotOkException: If the SDK reports a client or server error.
            NetworkConnectionError: For any other failure during the call.
            RespParseException: If the response contains no embeddings.
        """
        model_info = request.model_info
        embedding_input = request.embedding_input
        extra_params = request.extra_params
        embed_config = _build_embed_content_config(extra_params)

        try:
            raw_response: EmbedContentResponse = await self.client.aio.models.embed_content(
                model=model_info.model_identifier,
                contents=embedding_input,
                config=embed_config,
            )
        except (ClientError, ServerError) as exc:
            status_code = int(getattr(exc, "code", 500) or 500)
            raise RespNotOkException(status_code, str(exc)) from exc
        except Exception as exc:
            raise NetworkConnectionError(str(exc)) from exc

        response = APIResponse(raw_data=raw_response)
        if raw_response.embeddings:
            response.embedding = raw_response.embeddings[0].values
        else:
            raise RespParseException(raw_response, "响应解析失败,缺失 embeddings 字段")

        billable_character_count = 0
        if raw_response.metadata is not None:
            billable_character_count = getattr(raw_response.metadata, "billable_character_count", 0) or 0
        usage_record: UsageTuple = (
            billable_character_count or len(embedding_input),
            0,
            billable_character_count or len(embedding_input),
        )
        return response, usage_record

    async def _execute_audio_transcription_request(
        self,
        request: AudioTranscriptionRequest,
    ) -> Tuple[APIResponse, UsageTuple | None]:
        """Run a Gemini audio transcription request.

        The audio is sent as inline bytes together with a text prompt. Both
        the prompt and the audio MIME type can be overridden through
        `extra_params` (`transcription_prompt`, `audio_mime_type`).

        Args:
            request: Unified audio transcription request.

        Returns:
            Tuple[APIResponse, UsageTuple | None]: Unified response and
            optional usage information.

        Raises:
            RespNotOkException: If the SDK reports a client or server error.
            NetworkConnectionError: For any other failure during the call.
            RespParseException: If the response cannot be parsed.
            EmptyResponseException: If the response carries no usable output.
        """
        model_info = request.model_info
        audio_base64 = request.audio_base64
        max_tokens = request.max_tokens
        extra_params = request.extra_params
        model_identifier, _ = self._resolve_model_identifier(model_info.model_identifier, extra_params)

        transcription_prompt = str(
            extra_params.get(
                "transcription_prompt",
                "Generate a transcript of the speech. The language of the transcript should match the speech.",
            )
        )
        audio_mime_type = str(extra_params.get("audio_mime_type", "audio/wav"))
        contents: List[ContentUnion] = [
            Content(
                role="user",
                parts=[
                    Part.from_text(text=transcription_prompt),
                    Part.from_bytes(data=base64.b64decode(audio_base64), mime_type=audio_mime_type),
                ],
            )
        ]
        generation_config = self._build_generation_config(
            model_identifier=model_identifier,
            system_instruction=None,
            tool_options=None,
            response_format=None,
            max_tokens=max_tokens,
            temperature=None,
            extra_params=extra_params,
            enable_google_search=False,
        )

        try:
            raw_response: GenerateContentResponse = await self.client.aio.models.generate_content(
                model=model_identifier,
                contents=contents,
                config=generation_config,
            )
        except (ClientError, ServerError) as exc:
            status_code = int(getattr(exc, "code", 500) or 500)
            raise RespNotOkException(status_code, str(exc)) from exc
        except Exception as exc:
            raise NetworkConnectionError(str(exc)) from exc

        # Parse outside the try block so that EmptyResponseException and
        # RespParseException reach the caller as themselves instead of being
        # relabelled as NetworkConnectionError.
        return _default_normal_response_parser(raw_response)

    def get_support_image_formats(self) -> List[str]:
        """Return the image formats this client accepts for Gemini.

        Returns:
            List[str]: Supported image format names.
        """
        return ["png", "jpg", "jpeg", "webp", "heic", "heif"]
|