# ruff: noqa: B025 from typing import Any, AsyncIterator, Callable, Coroutine, Dict, List, Optional, Tuple, cast import asyncio import base64 import binascii import io import json from google import genai from google.genai.errors import ( ClientError, FunctionInvocationError, ServerError, UnknownFunctionCallArgumentError, UnsupportedFunctionError, ) from google.genai.types import ( Candidate, Content, ContentListUnion, ContentUnion, EmbedContentConfig, EmbedContentResponse, FunctionDeclaration, FunctionCall, FunctionResponse, GenerateContentConfig, GenerateContentResponse, GoogleSearch, HarmBlockThreshold, HarmCategory, HttpOptions, Part, SafetySetting, ThinkingConfig, Tool, ) from src.common.logger import get_logger from src.config.model_configs import APIProvider from src.llm_models.exceptions import ( EmptyResponseException, NetworkConnectionError, ReqAbortException, RespNotOkException, RespParseException, ) from src.llm_models.payload_content.message import ImageMessagePart, Message, RoleType, TextMessagePart from src.llm_models.payload_content.resp_format import RespFormat, RespFormatType from src.llm_models.payload_content.tool_option import ToolCall, ToolOption from .adapter_base import ( AdapterClient, ProviderResponseParser, ProviderStreamResponseHandler, await_task_with_interrupt, ) from .base_client import ( APIResponse, AudioTranscriptionRequest, EmbeddingRequest, ResponseRequest, UsageTuple, client_registry, ) from ..request_snapshot import ( attach_request_snapshot, has_request_snapshot, save_failed_request_snapshot, serialize_audio_request_snapshot, serialize_embedding_request_snapshot, serialize_response_request_snapshot, ) logger = get_logger("Gemini客户端") GeminiStreamResponseHandler = Callable[ [AsyncIterator[GenerateContentResponse], asyncio.Event | None], Coroutine[Any, Any, Tuple[APIResponse, Optional[UsageTuple]]], ] """Gemini 流式响应处理函数类型。""" GeminiResponseParser = Callable[[GenerateContentResponse], Tuple[APIResponse, Optional[UsageTuple]]] """Gemini 非流式响应解析函数类型。""" THINKING_BUDGET_LIMITS: Dict[str, Dict[str, int | bool]] = { "gemini-2.5-flash": {"min": 1, "max": 24576, "can_disable": True}, "gemini-2.5-flash-lite": {"min": 512, "max": 24576, "can_disable": True}, "gemini-2.5-pro": {"min": 128, "max": 32768, "can_disable": False}, } """不同 Gemini 模型允许的思考预算范围。""" THINKING_BUDGET_AUTO = -1 """自动思考预算模式,由模型自行决定。""" THINKING_BUDGET_DISABLED = 0 """禁用思考预算模式。仅部分模型支持。""" GEMINI_SAFE_SETTINGS: List[SafetySetting] = [ SafetySetting(category=HarmCategory.HARM_CATEGORY_HATE_SPEECH, threshold=HarmBlockThreshold.BLOCK_NONE), SafetySetting(category=HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, threshold=HarmBlockThreshold.BLOCK_NONE), SafetySetting(category=HarmCategory.HARM_CATEGORY_HARASSMENT, threshold=HarmBlockThreshold.BLOCK_NONE), SafetySetting(category=HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, threshold=HarmBlockThreshold.BLOCK_NONE), SafetySetting(category=HarmCategory.HARM_CATEGORY_CIVIC_INTEGRITY, threshold=HarmBlockThreshold.BLOCK_NONE), ] """默认安全策略,避免 Gemini 在部分内容上返回空响应。""" GENERATE_CONFIG_RESERVED_EXTRA_PARAMS = { "thinking_budget", "include_thoughts", "enable_google_search", "transcription_prompt", "audio_mime_type", } """由当前客户端自行处理、不再直接透传给 `GenerateContentConfig` 的额外参数。""" EMBED_CONFIG_SUPPORTED_EXTRA_PARAMS = { "task_type", "title", "output_dimensionality", "mime_type", "auto_truncate", } """可透传给 `EmbedContentConfig` 的额外参数字段。""" GEMINI_EXTRA_CONTENT_PROVIDER_KEY = "google" GEMINI_EXTRA_CONTENT_THOUGHT_SIGNATURE_KEY = "thought_signature" GEMINI_FALLBACK_THOUGHT_SIGNATURE = b"skip_thought_signature_validator" """当历史 function call 没有原始 thought signature 时,使用官方允许的占位签名跳过校验。""" def _normalize_image_mime_type(image_format: str) -> str: """将图片格式名称转换为标准 MIME 类型。 Args: image_format: 图片格式名,例如 `png`、`jpg`。 Returns: str: 规范化后的图片 MIME 类型。 """ normalized_image_format = image_format.lower() if normalized_image_format in {"jpg", "jpeg"}: return "image/jpeg" return f"image/{normalized_image_format}" def _build_non_tool_parts(message: Message) -> List[Part]: """将消息中的文本与图片片段转换为 Gemini `Part` 列表。 Args: message: 内部统一消息对象。 Returns: List[Part]: Gemini 所需的内容片段列表。 """ converted_parts: List[Part] = [] for message_part in message.parts: if isinstance(message_part, TextMessagePart): converted_parts.append(Part.from_text(text=message_part.text)) continue if isinstance(message_part, ImageMessagePart): converted_parts.append( Part.from_bytes( data=base64.b64decode(message_part.image_base64), mime_type=_normalize_image_mime_type(message_part.normalized_image_format), ) ) return converted_parts def _normalize_function_response_payload(message: Message) -> Dict[str, Any]: """将内部工具结果消息转换为 Gemini 函数响应负载。 Args: message: 工具结果消息。 Returns: Dict[str, Any]: 可用于 `Part.from_function_response()` 的响应对象。 """ content = message.content if isinstance(content, str): stripped_content = content.strip() if not stripped_content: return {} try: parsed_content = json.loads(stripped_content) except json.JSONDecodeError: return {"result": content} if isinstance(parsed_content, dict): return parsed_content return {"result": parsed_content} return {"result": content} def _build_gemini_tool_call_extra_content(thought_signature: bytes | None) -> Dict[str, Any] | None: """将 Gemini thought signature 编码为内部工具调用附加信息。""" if not thought_signature: return None return { GEMINI_EXTRA_CONTENT_PROVIDER_KEY: { GEMINI_EXTRA_CONTENT_THOUGHT_SIGNATURE_KEY: base64.b64encode(thought_signature).decode("ascii") } } def _extract_gemini_thought_signature(tool_call: ToolCall) -> bytes | None: """从内部工具调用附加信息中提取 Gemini thought signature。""" if not tool_call.extra_content: return None provider_payload = tool_call.extra_content.get(GEMINI_EXTRA_CONTENT_PROVIDER_KEY) if not isinstance(provider_payload, dict): return None raw_thought_signature = provider_payload.get(GEMINI_EXTRA_CONTENT_THOUGHT_SIGNATURE_KEY) if isinstance(raw_thought_signature, bytes): return raw_thought_signature if not isinstance(raw_thought_signature, str): return None normalized_signature = raw_thought_signature.strip() if not normalized_signature: return None try: return base64.b64decode(normalized_signature.encode("ascii"), validate=True) except (binascii.Error, ValueError): return normalized_signature.encode("utf-8") def _build_gemini_function_call_part( tool_call: ToolCall, *, inject_fallback_signature: bool, ) -> Part: """根据内部工具调用构建 Gemini function call part。""" thought_signature = _extract_gemini_thought_signature(tool_call) if thought_signature is None and inject_fallback_signature: thought_signature = GEMINI_FALLBACK_THOUGHT_SIGNATURE return Part( function_call=FunctionCall( id=tool_call.call_id, name=tool_call.func_name, args=tool_call.args or {}, ), thought_signature=thought_signature, ) def _get_candidates(response: GenerateContentResponse) -> List[Candidate]: """安全获取 Gemini 响应中的候选列表。 Args: response: Gemini 响应对象。 Returns: List[Candidate]: 非空时返回原候选列表,否则返回空列表。 """ return response.candidates or [] def _extract_response_json_schema(response_format: RespFormat) -> Dict[str, object] | None: """从内部响应格式中提取可供 Gemini 使用的 JSON Schema。 Args: response_format: 输出格式定义。 Returns: Dict[str, object] | None: 可直接传给 `response_json_schema` 的 JSON Schema。 """ schema_payload = response_format.get_schema_object() if schema_payload is None: return None return cast(Dict[str, object], schema_payload) def _convert_messages(messages: List[Message]) -> Tuple[ContentListUnion, str | None]: """将内部统一消息列表转换为 Gemini 内容结构。 Args: messages: 内部统一消息列表。 Returns: Tuple[ContentListUnion, str | None]: `contents` 与可选的 `system_instruction`。 Raises: ValueError: 当消息结构无法映射到 Gemini 内容模型时抛出。 """ contents: List[ContentUnion] = [] system_instruction_chunks: List[str] = [] tool_name_by_call_id: Dict[str, str] = {} for message in messages: if message.role == RoleType.System: system_text = message.get_text_content().strip() if not system_text: raise ValueError("Gemini 的 system message 必须为非空文本") system_instruction_chunks.append(system_text) continue if message.role == RoleType.User: contents.append(Content(role="user", parts=_build_non_tool_parts(message))) continue if message.role == RoleType.Assistant: assistant_parts = _build_non_tool_parts(message) if message.tool_calls: for tool_call_index, tool_call in enumerate(message.tool_calls): assistant_parts.append( _build_gemini_function_call_part( tool_call, inject_fallback_signature=tool_call_index == 0, ) ) tool_name_by_call_id[tool_call.call_id] = tool_call.func_name contents.append(Content(role="model", parts=assistant_parts)) continue if message.role == RoleType.Tool: if not message.tool_call_id: raise ValueError("Gemini 工具结果消息缺少 tool_call_id") tool_name = (message.tool_name or tool_name_by_call_id.get(message.tool_call_id, "")).strip() if not tool_name: raise ValueError( f"Gemini 无法根据 tool_call_id={message.tool_call_id} 找到对应的工具名称," "且消息中未携带 tool_name" ) tool_name_by_call_id[message.tool_call_id] = tool_name function_response_part = Part( function_response=FunctionResponse( id=message.tool_call_id, name=tool_name, response=_normalize_function_response_payload(message), ) ) contents.append(Content(role="tool", parts=[function_response_part])) continue raise ValueError(f"不支持的消息角色: {message.role}") system_instruction = "\n\n".join(chunk for chunk in system_instruction_chunks if chunk.strip()) or None return contents, system_instruction def _build_tools(tool_options: List[ToolOption]) -> List[Tool]: """将内部工具定义转换为 Gemini `Tool` 列表。 Args: tool_options: 内部统一工具定义列表。 Returns: List[Tool]: Gemini 所需工具列表。 """ function_declarations: List[FunctionDeclaration] = [] for tool_option in tool_options: payload: Dict[str, Any] = { "name": tool_option.name, "description": tool_option.description, } if tool_option.parameters_schema is not None: payload["parameters_json_schema"] = tool_option.parameters_schema function_declarations.append(FunctionDeclaration(**payload)) return [Tool(function_declarations=function_declarations)] if function_declarations else [] def _extract_usage_record(response: GenerateContentResponse) -> Optional[UsageTuple]: """从 Gemini 响应中提取使用量信息。 Args: response: Gemini 响应对象。 Returns: Optional[UsageTuple]: 统一的使用量三元组;缺失时返回 `None`。 """ usage_metadata = getattr(response, "usage_metadata", None) if usage_metadata is None: return None prompt_tokens = getattr(usage_metadata, "prompt_token_count", 0) or 0 completion_tokens = ( (getattr(usage_metadata, "candidates_token_count", 0) or 0) + (getattr(usage_metadata, "thoughts_token_count", 0) or 0) ) total_tokens = getattr(usage_metadata, "total_token_count", 0) or 0 return prompt_tokens, completion_tokens, total_tokens def _extract_finish_reason(response: GenerateContentResponse | None) -> str | None: """提取 Gemini 响应的结束原因。 Args: response: Gemini 响应对象。 Returns: str | None: 结束原因字符串;获取失败时返回 `None`。 """ if response is None: return None candidates = _get_candidates(response) if not candidates: return None for candidate in candidates: finish_reason = getattr(candidate, "finish_reason", None) or getattr(candidate, "finishReason", None) if finish_reason: return str(finish_reason) return None def _warn_if_max_tokens_truncated( response: GenerateContentResponse | None, content: str | None, tool_calls: List[ToolCall] | None, ) -> None: """在 Gemini 因 token 限制截断时输出警告。 Args: response: Gemini 响应对象。 content: 已解析的可见文本内容。 tool_calls: 已解析的工具调用列表。 """ finish_reason = _extract_finish_reason(response) if finish_reason is None or "MAX_TOKENS" not in finish_reason: return has_visible_output = bool((content and content.strip()) or tool_calls) if has_visible_output: logger.warning( "Gemini 响应因达到 max_tokens 限制被部分截断,可能影响回复完整性,建议调整模型 max_tokens 配置。" ) return logger.warning("Gemini 响应因达到 max_tokens 限制被截断,且未返回可见输出,请检查模型 max_tokens 配置。") def _collect_function_calls(response: GenerateContentResponse) -> List[ToolCall]: """从 Gemini 响应中提取工具调用列表。 Args: response: Gemini 响应对象。 Returns: List[ToolCall]: 规范化后的工具调用列表。 Raises: RespParseException: 当函数调用结构不合法时抛出。 """ candidates = _get_candidates(response) tool_calls: List[ToolCall] = [] for candidate in candidates: content = getattr(candidate, "content", None) parts = getattr(content, "parts", None) or [] for part in parts: function_call = getattr(part, "function_call", None) if function_call is None: continue call_name = getattr(function_call, "name", None) call_id = getattr(function_call, "id", None) or f"gemini-tool-call-{len(tool_calls) + 1}" call_args = getattr(function_call, "args", None) or {} if not isinstance(call_name, str) or not call_name: raise RespParseException(response, "响应解析失败,Gemini 工具调用缺少 name 字段") if not isinstance(call_args, dict): raise RespParseException(response, "响应解析失败,Gemini 工具调用参数无法解析为字典") tool_calls.append( ToolCall( call_id=call_id, func_name=call_name, args=call_args, extra_content=_build_gemini_tool_call_extra_content(getattr(part, "thought_signature", None)), ) ) if tool_calls: return tool_calls raw_function_calls = getattr(response, "function_calls", None) if not raw_function_calls: return [] for index, function_call in enumerate(raw_function_calls, start=1): call_name = getattr(function_call, "name", None) call_id = getattr(function_call, "id", None) or f"gemini-tool-call-{index}" call_args = getattr(function_call, "args", None) or {} if not isinstance(call_name, str) or not call_name: raise RespParseException(response, "响应解析失败,Gemini 工具调用缺少 name 字段") if not isinstance(call_args, dict): raise RespParseException(response, "响应解析失败,Gemini 工具调用参数无法解析为字典") tool_calls.append(ToolCall(call_id=call_id, func_name=call_name, args=call_args)) return tool_calls def _process_stream_chunk( chunk: GenerateContentResponse, content_buffer: io.StringIO, tool_calls_buffer: List[ToolCall], response: APIResponse, ) -> None: """处理单个 Gemini 流式响应块。 Args: chunk: 当前流式响应块。 content_buffer: 正文缓冲区。 tool_calls_buffer: 工具调用缓冲区。 response: 当前累积的统一响应对象。 """ candidates = _get_candidates(chunk) for candidate in candidates: content = getattr(candidate, "content", None) parts = getattr(content, "parts", None) or [] for part in parts: part_text = getattr(part, "text", None) if not part_text: continue if getattr(part, "thought", False): response.reasoning_content = (response.reasoning_content or "") + part_text else: content_buffer.write(part_text) tool_calls_buffer.extend(_collect_function_calls(chunk)) def _build_stream_api_response( content_buffer: io.StringIO, tool_calls_buffer: List[ToolCall], last_response: GenerateContentResponse | None, response: APIResponse, ) -> APIResponse: """根据流式缓冲区内容构建统一响应对象。 Args: content_buffer: 正文缓冲区。 tool_calls_buffer: 工具调用缓冲区。 last_response: 最后一个 Gemini 响应块。 response: 已累积的响应对象。 Returns: APIResponse: 构建完成的统一响应对象。 Raises: EmptyResponseException: 响应中既无正文也无工具调用且无思考内容时抛出。 """ if content_buffer.tell() > 0: response.content = content_buffer.getvalue() content_buffer.close() if tool_calls_buffer: response.tool_calls = list(tool_calls_buffer) response.raw_data = last_response _warn_if_max_tokens_truncated(last_response, response.content, response.tool_calls) if not response.content and not response.tool_calls and not response.reasoning_content: raise EmptyResponseException(last_response) return response async def _default_stream_response_handler( response_stream: AsyncIterator[GenerateContentResponse], interrupt_flag: asyncio.Event | None, ) -> Tuple[APIResponse, Optional[UsageTuple]]: """处理 Gemini 流式响应。 Args: response_stream: Gemini 异步流式响应迭代器。 interrupt_flag: 外部中断标记。 Returns: Tuple[APIResponse, Optional[UsageTuple]]: 统一响应对象与可选的使用量信息。 """ content_buffer = io.StringIO() tool_calls_buffer: List[ToolCall] = [] api_response = APIResponse() usage_record: Optional[UsageTuple] = None last_response: GenerateContentResponse | None = None try: async for chunk in response_stream: last_response = chunk if interrupt_flag and interrupt_flag.is_set(): raise ReqAbortException("请求被外部信号中断") _process_stream_chunk(chunk, content_buffer, tool_calls_buffer, api_response) usage_record = _extract_usage_record(chunk) or usage_record return _build_stream_api_response(content_buffer, tool_calls_buffer, last_response, api_response), usage_record except Exception: if not content_buffer.closed: content_buffer.close() raise def _default_normal_response_parser( response: GenerateContentResponse, ) -> Tuple[APIResponse, Optional[UsageTuple]]: """解析 Gemini 非流式响应。 Args: response: Gemini 响应对象。 Returns: Tuple[APIResponse, Optional[UsageTuple]]: 统一响应对象与可选的使用量信息。 Raises: EmptyResponseException: 响应中既无正文也无工具调用且无思考内容时抛出。 """ api_response = APIResponse(raw_data=response) visible_parts: List[str] = [] for candidate in _get_candidates(response): content = getattr(candidate, "content", None) parts = getattr(content, "parts", None) or [] for part in parts: part_text = getattr(part, "text", None) if not part_text: continue if getattr(part, "thought", False): api_response.reasoning_content = (api_response.reasoning_content or "") + part_text else: visible_parts.append(part_text) api_response.content = "".join(visible_parts).strip() or getattr(response, "text", None) tool_calls = _collect_function_calls(response) if tool_calls: api_response.tool_calls = tool_calls usage_record = _extract_usage_record(response) _warn_if_max_tokens_truncated(response, api_response.content, api_response.tool_calls) if not api_response.content and not api_response.tool_calls and not api_response.reasoning_content: raise EmptyResponseException(response, "响应中既无文本内容也无工具调用") return api_response, usage_record def _build_http_options(api_provider: APIProvider) -> HttpOptions: """根据 Provider 配置构建 Gemini SDK 的 `HttpOptions`。 Args: api_provider: API 提供商配置。 Returns: HttpOptions: Gemini SDK HTTP 选项对象。 """ http_options_payload: Dict[str, Any] = {} if api_provider.timeout is not None: http_options_payload["timeout"] = int(api_provider.timeout * 1000) base_url = api_provider.base_url.strip() if base_url: normalized_base_url = base_url.rstrip("/") version_candidate = normalized_base_url.rsplit("/", 1) if len(version_candidate) == 2 and version_candidate[1].startswith("v"): http_options_payload["base_url"] = f"{version_candidate[0]}/" http_options_payload["api_version"] = version_candidate[1] else: http_options_payload["base_url"] = f"{normalized_base_url}/" return HttpOptions(**http_options_payload) def _filter_generate_content_extra_params(extra_params: Dict[str, Any]) -> Dict[str, Any]: """筛选可透传给 `GenerateContentConfig` 的额外参数。 Args: extra_params: 模型级额外参数。 Returns: Dict[str, Any]: 可直接透传到 `GenerateContentConfig` 的字段字典。 """ filtered_params: Dict[str, Any] = {} for key, value in extra_params.items(): if key in GENERATE_CONFIG_RESERVED_EXTRA_PARAMS: continue if key in GenerateContentConfig.model_fields: filtered_params[key] = value return filtered_params def _build_embed_content_config(extra_params: Dict[str, Any]) -> EmbedContentConfig: """构建 Gemini 嵌入配置。 Args: extra_params: 模型级额外参数。 Returns: EmbedContentConfig: Gemini 嵌入配置对象。 """ config_payload: Dict[str, Any] = {"task_type": extra_params.get("task_type", "SEMANTIC_SIMILARITY")} for key in EMBED_CONFIG_SUPPORTED_EXTRA_PARAMS: if key == "task_type": continue if key in extra_params: config_payload[key] = extra_params[key] return EmbedContentConfig(**config_payload) @client_registry.register_client_class("gemini") class GeminiClient(AdapterClient[AsyncIterator[GenerateContentResponse], GenerateContentResponse]): """Gemini 官方 SDK 客户端适配器。""" client: genai.Client def __init__(self, api_provider: APIProvider) -> None: """初始化 Gemini 客户端。 Args: api_provider: API 提供商配置。 """ super().__init__(api_provider) self.client = genai.Client( api_key=api_provider.api_key, http_options=_build_http_options(api_provider), ) @staticmethod def clamp_thinking_budget(extra_params: Dict[str, Any] | None, model_id: str) -> int: """将思考预算裁剪到模型允许的范围内。 Args: extra_params: 请求额外参数。 model_id: 当前模型标识。 Returns: int: 裁剪后的思考预算值。 """ thinking_budget = THINKING_BUDGET_AUTO if extra_params and "thinking_budget" in extra_params: try: thinking_budget = int(extra_params["thinking_budget"]) except (TypeError, ValueError): logger.warning(f"无效的 thinking_budget={extra_params['thinking_budget']},已回退为自动模式") limits: Dict[str, int | bool] | None = None if model_id in THINKING_BUDGET_LIMITS: limits = THINKING_BUDGET_LIMITS[model_id] else: for candidate_prefix in sorted(THINKING_BUDGET_LIMITS.keys(), key=len, reverse=True): if model_id == candidate_prefix or model_id.startswith(f"{candidate_prefix}-"): limits = THINKING_BUDGET_LIMITS[candidate_prefix] break if thinking_budget == THINKING_BUDGET_AUTO: return THINKING_BUDGET_AUTO if thinking_budget == THINKING_BUDGET_DISABLED: if limits and bool(limits.get("can_disable", False)): return THINKING_BUDGET_DISABLED if limits: minimum_value = int(limits["min"]) logger.warning(f"模型 {model_id} 不支持禁用思考预算,已回退为最小值 {minimum_value}") return minimum_value return THINKING_BUDGET_AUTO if limits is None: logger.warning(f"模型 {model_id} 未配置思考预算范围,已回退为自动模式") return THINKING_BUDGET_AUTO minimum_value = int(limits["min"]) maximum_value = int(limits["max"]) if thinking_budget < minimum_value: logger.warning(f"模型 {model_id} 的 thinking_budget={thinking_budget} 过小,已调整为 {minimum_value}") return minimum_value if thinking_budget > maximum_value: logger.warning(f"模型 {model_id} 的 thinking_budget={thinking_budget} 过大,已调整为 {maximum_value}") return maximum_value return thinking_budget @staticmethod def _resolve_model_identifier(model_identifier: str, extra_params: Dict[str, Any]) -> Tuple[str, bool]: """解析请求实际使用的 Gemini 模型标识。 Args: model_identifier: 原始模型标识。 extra_params: 模型级额外参数。 Returns: Tuple[str, bool]: `(实际模型标识, 是否启用 Google Search)`。 """ enable_google_search = bool(extra_params.get("enable_google_search", False)) resolved_model_identifier = model_identifier if resolved_model_identifier.endswith("-search"): resolved_model_identifier = resolved_model_identifier.removesuffix("-search") enable_google_search = True return resolved_model_identifier, enable_google_search def _build_generation_config( self, *, model_identifier: str, system_instruction: str | None, tool_options: List[ToolOption] | None, response_format: RespFormat | None, max_tokens: int | None, temperature: float | None, extra_params: Dict[str, Any], enable_google_search: bool, ) -> GenerateContentConfig: """构建 Gemini 生成配置。 Args: model_identifier: 当前请求实际使用的模型标识。 system_instruction: 系统指令文本。 tool_options: 内部工具定义列表。 response_format: 输出格式定义。 max_tokens: 最大输出 token 数。 temperature: 温度参数。 extra_params: 模型级额外参数。 enable_google_search: 是否自动追加 Google Search 工具。 Returns: GenerateContentConfig: Gemini 生成配置对象。 """ config_payload = _filter_generate_content_extra_params(extra_params) if max_tokens is not None and "max_output_tokens" not in config_payload: config_payload["max_output_tokens"] = max_tokens if temperature is not None and "temperature" not in config_payload: config_payload["temperature"] = temperature if system_instruction and "system_instruction" not in config_payload: config_payload["system_instruction"] = system_instruction if "response_modalities" not in config_payload: config_payload["response_modalities"] = ["TEXT"] if "safety_settings" not in config_payload: config_payload["safety_settings"] = GEMINI_SAFE_SETTINGS if "thinking_config" not in config_payload: config_payload["thinking_config"] = ThinkingConfig( include_thoughts=bool(extra_params.get("include_thoughts", True)), thinking_budget=self.clamp_thinking_budget(extra_params, model_identifier), ) tools = _build_tools(tool_options) if tool_options else [] if enable_google_search: tools.append(Tool(google_search=GoogleSearch())) if tools: if "tools" in config_payload: existing_tools = config_payload["tools"] if isinstance(existing_tools, list): config_payload["tools"] = [*existing_tools, *tools] else: config_payload["tools"] = [existing_tools, *tools] else: config_payload["tools"] = tools if response_format is not None: if response_format.format_type == RespFormatType.TEXT: config_payload.setdefault("response_mime_type", "text/plain") elif response_format.format_type == RespFormatType.JSON_OBJ: config_payload.setdefault("response_mime_type", "application/json") elif response_format.format_type == RespFormatType.JSON_SCHEMA: config_payload.setdefault("response_mime_type", "application/json") response_json_schema = _extract_response_json_schema(response_format) if ( response_json_schema is not None and "response_json_schema" not in config_payload and "response_schema" not in config_payload ): config_payload["response_json_schema"] = response_json_schema return GenerateContentConfig(**config_payload) def _build_default_stream_response_handler( self, request: ResponseRequest, ) -> ProviderStreamResponseHandler[AsyncIterator[GenerateContentResponse]]: """构建 Gemini 默认流式响应处理器。 Args: request: 统一响应请求对象。 Returns: ProviderStreamResponseHandler[AsyncIterator[GenerateContentResponse]]: 默认流式处理器。 """ del request return _default_stream_response_handler def _build_default_response_parser( self, request: ResponseRequest, ) -> ProviderResponseParser[GenerateContentResponse]: """构建 Gemini 默认非流式响应解析器。 Args: request: 统一响应请求对象。 Returns: ProviderResponseParser[GenerateContentResponse]: 默认非流式解析器。 """ del request return _default_normal_response_parser async def _execute_response_request( self, request: ResponseRequest, stream_response_handler: ProviderStreamResponseHandler[AsyncIterator[GenerateContentResponse]], response_parser: ProviderResponseParser[GenerateContentResponse], ) -> Tuple[APIResponse, UsageTuple | None]: """执行 Gemini 的文本/多模态响应请求。 Args: request: 统一响应请求对象。 stream_response_handler: 流式响应处理器。 response_parser: 非流式响应解析器。 Returns: Tuple[APIResponse, UsageTuple | None]: 统一响应对象与可选使用量信息。 """ model_info = request.model_info snapshot_provider_request = { "base_url": self.api_provider.base_url, "endpoint": "/models/{model}:generateContent", "method": "POST", "operation": "models.generate_content", "request_kwargs": {}, } try: contents, system_instruction = _convert_messages(request.message_list) model_identifier, enable_google_search = self._resolve_model_identifier( model_info.model_identifier, request.extra_params, ) generation_config = self._build_generation_config( model_identifier=model_identifier, system_instruction=system_instruction, tool_options=request.tool_options, response_format=request.response_format, max_tokens=request.max_tokens, temperature=request.temperature, extra_params=request.extra_params, enable_google_search=enable_google_search, ) snapshot_provider_request["request_kwargs"] = { "config": generation_config, "contents": contents, "enable_google_search": enable_google_search, "model": model_identifier, "system_instruction": system_instruction, } if model_info.force_stream_mode: stream_task: asyncio.Task[AsyncIterator[GenerateContentResponse]] = asyncio.create_task( self.client.aio.models.generate_content_stream( model=model_identifier, contents=contents, config=generation_config, ) ) raw_response_stream = cast( AsyncIterator[GenerateContentResponse], await await_task_with_interrupt(stream_task, request.interrupt_flag), ) return await stream_response_handler(raw_response_stream, request.interrupt_flag) completion_task: asyncio.Task[GenerateContentResponse] = asyncio.create_task( self.client.aio.models.generate_content( model=model_identifier, contents=contents, config=generation_config, ) ) raw_response = cast( GenerateContentResponse, await await_task_with_interrupt(completion_task, request.interrupt_flag), ) return response_parser(raw_response) except ReqAbortException: raise except (ClientError, ServerError) as exc: status_code = int(getattr(exc, "code", 500) or 500) snapshot_path = save_failed_request_snapshot( api_provider=self.api_provider, client_type="gemini", error=exc, internal_request=serialize_response_request_snapshot(request), model_info=model_info, operation="models.generate_content", provider_request=snapshot_provider_request, ) wrapped_error = RespNotOkException(status_code, str(exc)) attach_request_snapshot(wrapped_error, snapshot_path) raise wrapped_error from exc except (UnknownFunctionCallArgumentError, UnsupportedFunctionError, FunctionInvocationError) as exc: wrapped_error = RespParseException(None, f"Gemini 工具调用参数错误: {exc}") snapshot_path = save_failed_request_snapshot( api_provider=self.api_provider, client_type="gemini", error=wrapped_error, internal_request=serialize_response_request_snapshot(request), model_info=model_info, operation="models.generate_content", provider_request=snapshot_provider_request, ) attach_request_snapshot(wrapped_error, snapshot_path) raise wrapped_error from exc except EmptyResponseException as exc: snapshot_path = save_failed_request_snapshot( api_provider=self.api_provider, client_type="gemini", error=exc, internal_request=serialize_response_request_snapshot(request), model_info=model_info, operation="models.generate_content", provider_request=snapshot_provider_request, ) attach_request_snapshot(exc, snapshot_path) raise except Exception as exc: if has_request_snapshot(exc): raise snapshot_path = save_failed_request_snapshot( api_provider=self.api_provider, client_type="gemini", error=exc, internal_request=serialize_response_request_snapshot(request), model_info=model_info, operation="models.generate_content", provider_request=snapshot_provider_request, ) wrapped_error = ( exc if isinstance(exc, (EmptyResponseException, RespParseException)) else NetworkConnectionError(str(exc)) ) attach_request_snapshot(wrapped_error, snapshot_path) if wrapped_error is exc: raise raise wrapped_error from exc except (UnknownFunctionCallArgumentError, UnsupportedFunctionError, FunctionInvocationError) as exc: raise RespParseException(None, f"Gemini 工具调用参数错误: {exc}") from exc except EmptyResponseException: raise except Exception as exc: raise NetworkConnectionError(str(exc)) from exc async def _execute_embedding_request( self, request: EmbeddingRequest, ) -> Tuple[APIResponse, UsageTuple | None]: """执行 Gemini 文本嵌入请求。 Args: request: 统一嵌入请求对象。 Returns: Tuple[APIResponse, UsageTuple | None]: 统一响应对象与可选使用量信息。 """ model_info = request.model_info embedding_input = request.embedding_input extra_params = request.extra_params snapshot_provider_request = { "base_url": self.api_provider.base_url, "endpoint": "/models/{model}:embedContent", "method": "POST", "operation": "models.embed_content", "request_kwargs": {}, } try: embed_config = _build_embed_content_config(extra_params) snapshot_provider_request["request_kwargs"] = { "config": embed_config, "contents": embedding_input, "model": model_info.model_identifier, } raw_response: EmbedContentResponse = await self.client.aio.models.embed_content( model=model_info.model_identifier, contents=embedding_input, config=embed_config, ) except (ClientError, ServerError) as exc: status_code = int(getattr(exc, "code", 500) or 500) snapshot_path = save_failed_request_snapshot( api_provider=self.api_provider, client_type="gemini", error=exc, internal_request=serialize_embedding_request_snapshot(request), model_info=model_info, operation="models.embed_content", provider_request=snapshot_provider_request, ) wrapped_error = RespNotOkException(status_code, str(exc)) attach_request_snapshot(wrapped_error, snapshot_path) raise wrapped_error from exc except Exception as exc: if has_request_snapshot(exc): raise snapshot_path = save_failed_request_snapshot( api_provider=self.api_provider, client_type="gemini", error=exc, internal_request=serialize_embedding_request_snapshot(request), model_info=model_info, operation="models.embed_content", provider_request=snapshot_provider_request, ) wrapped_error = ( exc if isinstance(exc, (EmptyResponseException, RespParseException)) else NetworkConnectionError(str(exc)) ) attach_request_snapshot(wrapped_error, snapshot_path) if wrapped_error is exc: raise raise wrapped_error from exc response = APIResponse(raw_data=raw_response) if not raw_response.embeddings: exc = RespParseException(raw_response, "Gemini 嵌入响应解析失败,缺少 embeddings 字段。") snapshot_path = save_failed_request_snapshot( api_provider=self.api_provider, client_type="gemini", error=exc, internal_request=serialize_embedding_request_snapshot(request), model_info=model_info, operation="models.embed_content", provider_request=snapshot_provider_request, ) attach_request_snapshot(exc, snapshot_path) raise exc if raw_response.embeddings: response.embedding = raw_response.embeddings[0].values else: raise RespParseException(raw_response, "响应解析失败,缺失 embeddings 字段") billable_character_count = 0 if raw_response.metadata is not None: billable_character_count = getattr(raw_response.metadata, "billable_character_count", 0) or 0 usage_record: UsageTuple = ( billable_character_count or len(embedding_input), 0, billable_character_count or len(embedding_input), ) return response, usage_record async def _execute_audio_transcription_request( self, request: AudioTranscriptionRequest, ) -> Tuple[APIResponse, UsageTuple | None]: """执行 Gemini 音频转录请求。 Args: request: 统一音频转录请求对象。 Returns: Tuple[APIResponse, UsageTuple | None]: 统一响应对象与可选使用量信息。 """ model_info = request.model_info audio_base64 = request.audio_base64 max_tokens = request.max_tokens extra_params = request.extra_params snapshot_provider_request = { "base_url": self.api_provider.base_url, "endpoint": "/models/{model}:generateContent", "method": "POST", "operation": "models.generate_content", "request_kwargs": {}, } transcription_prompt = str( extra_params.get( "transcription_prompt", "Generate a transcript of the speech. The language of the transcript should match the speech.", ) ) audio_mime_type = str(extra_params.get("audio_mime_type", "audio/wav")) try: model_identifier, _ = self._resolve_model_identifier(model_info.model_identifier, extra_params) contents: List[ContentUnion] = [ Content( role="user", parts=[ Part.from_text(text=transcription_prompt), Part.from_bytes(data=base64.b64decode(audio_base64), mime_type=audio_mime_type), ], ) ] generation_config = self._build_generation_config( model_identifier=model_identifier, system_instruction=None, tool_options=None, response_format=None, max_tokens=max_tokens, temperature=None, extra_params=extra_params, enable_google_search=False, ) snapshot_provider_request["request_kwargs"] = { "audio_base64": audio_base64, "audio_mime_type": audio_mime_type, "config": generation_config, "contents": contents, "model": model_identifier, "transcription_prompt": transcription_prompt, } raw_response: GenerateContentResponse = await self.client.aio.models.generate_content( model=model_identifier, contents=contents, config=generation_config, ) response, usage_record = _default_normal_response_parser(raw_response) except (ClientError, ServerError) as exc: status_code = int(getattr(exc, "code", 500) or 500) snapshot_path = save_failed_request_snapshot( api_provider=self.api_provider, client_type="gemini", error=exc, internal_request=serialize_audio_request_snapshot(request), model_info=model_info, operation="models.generate_content", provider_request=snapshot_provider_request, ) wrapped_error = RespNotOkException(status_code, str(exc)) attach_request_snapshot(wrapped_error, snapshot_path) raise wrapped_error from exc except Exception as exc: if has_request_snapshot(exc): raise snapshot_path = save_failed_request_snapshot( api_provider=self.api_provider, client_type="gemini", error=exc, internal_request=serialize_audio_request_snapshot(request), model_info=model_info, operation="models.generate_content", provider_request=snapshot_provider_request, ) wrapped_error = ( exc if isinstance(exc, (EmptyResponseException, RespParseException)) else NetworkConnectionError(str(exc)) ) attach_request_snapshot(wrapped_error, snapshot_path) if wrapped_error is exc: raise raise wrapped_error from exc return response, usage_record def get_support_image_formats(self) -> List[str]: """获取 Gemini 当前支持的图片格式列表。 Returns: List[str]: 当前客户端支持的图片格式列表。 """ return ["png", "jpg", "jpeg", "webp", "heic", "heif"]