让Gemini的图像可用,修复部分typing
This commit is contained in:
@@ -1,9 +1,7 @@
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Any
|
||||
|
||||
from openai import AsyncStream
|
||||
from openai.types.chat import ChatCompletionChunk, ChatCompletion
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, Any, Optional
|
||||
|
||||
from src.config.api_ada_configs import ModelInfo, APIProvider
|
||||
from ..payload_content.message import Message
|
||||
@@ -58,7 +56,7 @@ class APIResponse:
|
||||
"""响应原始数据"""
|
||||
|
||||
|
||||
class BaseClient:
|
||||
class BaseClient(ABC):
|
||||
"""
|
||||
基础客户端
|
||||
"""
|
||||
@@ -68,6 +66,7 @@ class BaseClient:
|
||||
def __init__(self, api_provider: APIProvider):
|
||||
self.api_provider = api_provider
|
||||
|
||||
@abstractmethod
|
||||
async def get_response(
|
||||
self,
|
||||
model_info: ModelInfo,
|
||||
@@ -76,12 +75,10 @@ class BaseClient:
|
||||
max_tokens: int = 1024,
|
||||
temperature: float = 0.7,
|
||||
response_format: RespFormat | None = None,
|
||||
stream_response_handler: Callable[
|
||||
[AsyncStream[ChatCompletionChunk], asyncio.Event | None],
|
||||
tuple[APIResponse, tuple[int, int, int]],
|
||||
]
|
||||
| None = None,
|
||||
async_response_parser: Callable[[ChatCompletion], tuple[APIResponse, tuple[int, int, int]]] | None = None,
|
||||
stream_response_handler: Optional[
|
||||
Callable[[Any, asyncio.Event | None], tuple[APIResponse, tuple[int, int, int]]]
|
||||
] = None,
|
||||
async_response_parser: Callable[[Any], tuple[APIResponse, tuple[int, int, int]]] | None = None,
|
||||
interrupt_flag: asyncio.Event | None = None,
|
||||
extra_params: dict[str, Any] | None = None,
|
||||
) -> APIResponse:
|
||||
@@ -98,8 +95,9 @@ class BaseClient:
|
||||
:param interrupt_flag: 中断信号量(可选,默认为None)
|
||||
:return: (响应文本, 推理文本, 工具调用, 其他数据)
|
||||
"""
|
||||
raise RuntimeError("This method should be overridden in subclasses")
|
||||
raise NotImplementedError("'get_response' method should be overridden in subclasses")
|
||||
|
||||
@abstractmethod
|
||||
async def get_embedding(
|
||||
self,
|
||||
model_info: ModelInfo,
|
||||
@@ -112,8 +110,9 @@ class BaseClient:
|
||||
:param embedding_input: 嵌入输入文本
|
||||
:return: 嵌入响应
|
||||
"""
|
||||
raise RuntimeError("This method should be overridden in subclasses")
|
||||
raise NotImplementedError("'get_embedding' method should be overridden in subclasses")
|
||||
|
||||
@abstractmethod
|
||||
async def get_audio_transcriptions(
|
||||
self,
|
||||
model_info: ModelInfo,
|
||||
@@ -127,7 +126,15 @@ class BaseClient:
|
||||
:extra_params: 附加的请求参数
|
||||
:return: 音频转录响应
|
||||
"""
|
||||
raise RuntimeError("This method should be overridden in subclasses")
|
||||
raise NotImplementedError("'get_audio_transcriptions' method should be overridden in subclasses")
|
||||
|
||||
@abstractmethod
|
||||
def get_support_image_formats(self) -> list[str]:
|
||||
"""
|
||||
获取支持的图片格式
|
||||
:return: 支持的图片格式列表
|
||||
"""
|
||||
raise NotImplementedError("'get_support_image_formats' method should be overridden in subclasses")
|
||||
|
||||
|
||||
class ClientRegistry:
|
||||
@@ -137,7 +144,8 @@ class ClientRegistry:
|
||||
def register_client_class(self, client_type: str):
|
||||
"""
|
||||
注册API客户端类
|
||||
:param client_class: API客户端类
|
||||
Args:
|
||||
client_class: API客户端类
|
||||
"""
|
||||
|
||||
def decorator(cls: type[BaseClient]) -> type[BaseClient]:
|
||||
|
||||
Reference in New Issue
Block a user