让Gemini的图像可用,修复部分typing

This commit is contained in:
UnCLAS-Prommer
2025-08-03 00:49:19 +08:00
parent 38930b0ceb
commit 9afa549aee
5 changed files with 88 additions and 57 deletions

View File

@@ -1,9 +1,7 @@
import asyncio
from dataclasses import dataclass
from typing import Callable, Any
from openai import AsyncStream
from openai.types.chat import ChatCompletionChunk, ChatCompletion
from abc import ABC, abstractmethod
from typing import Callable, Any, Optional
from src.config.api_ada_configs import ModelInfo, APIProvider
from ..payload_content.message import Message
@@ -58,7 +56,7 @@ class APIResponse:
"""响应原始数据"""
class BaseClient:
class BaseClient(ABC):
"""
基础客户端
"""
@@ -68,6 +66,7 @@ class BaseClient:
def __init__(self, api_provider: APIProvider):
self.api_provider = api_provider
@abstractmethod
async def get_response(
self,
model_info: ModelInfo,
@@ -76,12 +75,10 @@ class BaseClient:
max_tokens: int = 1024,
temperature: float = 0.7,
response_format: RespFormat | None = None,
stream_response_handler: Callable[
[AsyncStream[ChatCompletionChunk], asyncio.Event | None],
tuple[APIResponse, tuple[int, int, int]],
]
| None = None,
async_response_parser: Callable[[ChatCompletion], tuple[APIResponse, tuple[int, int, int]]] | None = None,
stream_response_handler: Optional[
Callable[[Any, asyncio.Event | None], tuple[APIResponse, tuple[int, int, int]]]
] = None,
async_response_parser: Callable[[Any], tuple[APIResponse, tuple[int, int, int]]] | None = None,
interrupt_flag: asyncio.Event | None = None,
extra_params: dict[str, Any] | None = None,
) -> APIResponse:
@@ -98,8 +95,9 @@ class BaseClient:
:param interrupt_flag: 中断信号量可选默认为None
:return: (响应文本, 推理文本, 工具调用, 其他数据)
"""
raise RuntimeError("This method should be overridden in subclasses")
raise NotImplementedError("'get_response' method should be overridden in subclasses")
@abstractmethod
async def get_embedding(
self,
model_info: ModelInfo,
@@ -112,8 +110,9 @@ class BaseClient:
:param embedding_input: 嵌入输入文本
:return: 嵌入响应
"""
raise RuntimeError("This method should be overridden in subclasses")
raise NotImplementedError("'get_embedding' method should be overridden in subclasses")
@abstractmethod
async def get_audio_transcriptions(
self,
model_info: ModelInfo,
@@ -127,7 +126,15 @@ class BaseClient:
:extra_params: 附加的请求参数
:return: 音频转录响应
"""
raise RuntimeError("This method should be overridden in subclasses")
raise NotImplementedError("'get_audio_transcriptions' method should be overridden in subclasses")
@abstractmethod
def get_support_image_formats(self) -> list[str]:
"""
获取支持的图片格式
:return: 支持的图片格式列表
"""
raise NotImplementedError("'get_support_image_formats' method should be overridden in subclasses")
class ClientRegistry:
@@ -137,7 +144,8 @@ class ClientRegistry:
def register_client_class(self, client_type: str):
"""
注册API客户端类
:param client_class: API客户端类
Args:
client_class: API客户端类
"""
def decorator(cls: type[BaseClient]) -> type[BaseClient]: