feat: Enhance OpenAI compatibility and introduce unified LLM service data models
- Refactored model fetching logic to support various authentication methods for OpenAI-compatible APIs. - Introduced new data models for LLM service requests and responses to standardize interactions across layers. - Added an adapter base class for unified request execution across different providers. - Implemented utility functions for building OpenAI-compatible client configurations and request overrides.
This commit is contained in:
@@ -1,14 +1,15 @@
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, Any, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Coroutine, Dict, List, Tuple, Type
|
||||
|
||||
import asyncio
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import config_manager
|
||||
from src.config.model_configs import ModelInfo, APIProvider
|
||||
from ..payload_content.message import Message
|
||||
from ..payload_content.resp_format import RespFormat
|
||||
from ..payload_content.tool_option import ToolOption, ToolCall
|
||||
from src.config.model_configs import APIProvider, ModelInfo
|
||||
from src.llm_models.payload_content.message import Message
|
||||
from src.llm_models.payload_content.resp_format import RespFormat
|
||||
from src.llm_models.payload_content.tool_option import ToolCall, ToolOption
|
||||
|
||||
# Module-level logger for the model client registry.
logger = get_logger("model_client_registry")
|
||||
|
||||
@@ -47,10 +48,10 @@ class APIResponse:
|
||||
reasoning_content: str | None = None
|
||||
"""推理内容"""
|
||||
|
||||
tool_calls: list[ToolCall] | None = None
|
||||
tool_calls: List[ToolCall] | None = None
|
||||
"""工具调用 [(工具名称, 工具参数), ...]"""
|
||||
|
||||
embedding: list[float] | None = None
|
||||
embedding: List[float] | None = None
|
||||
"""嵌入向量"""
|
||||
|
||||
usage: UsageRecord | None = None
|
||||
@@ -60,6 +61,82 @@ class APIResponse:
|
||||
"""响应原始数据"""
|
||||
|
||||
|
||||
# Unified usage triple: (prompt_tokens, completion_tokens, total_tokens).
UsageTuple = Tuple[int, int, int]
"""统一的使用量三元组类型,顺序为 `(prompt_tokens, completion_tokens, total_tokens)`。"""

# Unified streaming-response handler type: an async callable that receives the
# provider's raw stream object plus an optional interrupt event, and returns
# the parsed APIResponse together with an optional usage triple.
StreamResponseHandler = Callable[
    [Any, asyncio.Event | None],
    Coroutine[Any, Any, Tuple["APIResponse", UsageTuple | None]],
]
"""统一的流式响应处理函数类型。"""

# Unified non-streaming response parser type: maps the provider's raw response
# object to an APIResponse plus an optional usage triple.
ResponseParser = Callable[[Any], Tuple["APIResponse", UsageTuple | None]]
"""统一的非流式响应解析函数类型。"""
|
||||
|
||||
|
||||
@dataclass(slots=True)
class ResponseRequest:
    """Unified text/multimodal response request.

    Carries everything a client needs for one chat/completion call: the
    target model, the conversation payload, sampling options, optional tool
    definitions, and optional custom stream/response handlers.
    """

    model_info: ModelInfo
    message_list: List[Message]
    tool_options: List[ToolOption] | None = None
    max_tokens: int | None = None
    temperature: float | None = None
    response_format: RespFormat | None = None
    stream_response_handler: StreamResponseHandler | None = None
    async_response_parser: ResponseParser | None = None
    interrupt_flag: asyncio.Event | None = None
    extra_params: Dict[str, Any] = field(default_factory=dict)

    def copy_with(self, **changes: Any) -> "ResponseRequest":
        """Create a new request based on this one with selected overrides.

        Mutable container fields (lists and dicts, i.e. ``message_list``,
        ``tool_options`` and ``extra_params``) are shallow-copied so the new
        request does not alias this request's containers.

        Args:
            **changes: Field values to override on the copy.

        Returns:
            ResponseRequest: The copied request object.
        """
        # Local import: `fields` is not imported at module level.
        from dataclasses import fields

        # Iterate the declared fields instead of keeping a hand-written field
        # mapping, so this method cannot silently drift out of sync when new
        # fields are added to the dataclass.
        payload: Dict[str, Any] = {}
        for spec in fields(self):
            value = getattr(self, spec.name)
            # Shallow-copy mutable containers to avoid cross-request aliasing.
            if isinstance(value, list):
                value = list(value)
            elif isinstance(value, dict):
                value = dict(value)
            payload[spec.name] = value
        payload.update(changes)
        return ResponseRequest(**payload)
|
||||
|
||||
|
||||
@dataclass(slots=True)
class EmbeddingRequest:
    """Unified embedding request.

    Attributes:
        model_info: Configuration of the model to run the embedding against.
        embedding_input: The input text to embed.
        extra_params: Additional provider-specific request parameters; how
            they are applied depends on the client implementation.
    """

    model_info: ModelInfo
    embedding_input: str
    extra_params: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass(slots=True)
class AudioTranscriptionRequest:
    """Unified audio transcription request.

    Attributes:
        model_info: Configuration of the model to run the transcription.
        audio_base64: Base64-encoded audio payload.
        max_tokens: Optional cap on generated tokens; ``None`` leaves it to
            the client/provider default.
        extra_params: Additional provider-specific request parameters; how
            they are applied depends on the client implementation.
    """

    model_info: ModelInfo
    audio_base64: str
    max_tokens: int | None = None
    extra_params: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
# Union of every request type a client entry point may accept.
ClientRequest = ResponseRequest | EmbeddingRequest | AudioTranscriptionRequest
"""统一客户端请求类型。"""
|
||||
|
||||
|
||||
class BaseClient(ABC):
|
||||
"""
|
||||
基础客户端
|
||||
@@ -67,97 +144,82 @@ class BaseClient(ABC):
|
||||
|
||||
api_provider: APIProvider
|
||||
|
||||
def __init__(self, api_provider: APIProvider):
|
||||
def __init__(self, api_provider: APIProvider) -> None:
|
||||
"""初始化基础客户端。
|
||||
|
||||
Args:
|
||||
api_provider: API 提供商配置。
|
||||
"""
|
||||
self.api_provider = api_provider
|
||||
|
||||
@abstractmethod
|
||||
async def get_response(
|
||||
self,
|
||||
model_info: ModelInfo,
|
||||
message_list: list[Message],
|
||||
tool_options: list[ToolOption] | None = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
temperature: Optional[float] = None,
|
||||
response_format: RespFormat | None = None,
|
||||
stream_response_handler: Optional[
|
||||
Callable[[Any, asyncio.Event | None], tuple[APIResponse, tuple[int, int, int]]]
|
||||
] = None,
|
||||
async_response_parser: Callable[[Any], tuple[APIResponse, tuple[int, int, int]]] | None = None,
|
||||
interrupt_flag: asyncio.Event | None = None,
|
||||
extra_params: dict[str, Any] | None = None,
|
||||
) -> APIResponse:
|
||||
"""
|
||||
获取对话响应
|
||||
:param model_info: 模型信息
|
||||
:param message_list: 对话体
|
||||
:param tool_options: 工具选项(可选,默认为None)
|
||||
:param max_tokens: 最大token数(可选,默认为1024)
|
||||
:param temperature: 温度(可选,默认为0.7)
|
||||
:param response_format: 响应格式(可选,默认为 NotGiven )
|
||||
:param stream_response_handler: 流式响应处理函数(可选)
|
||||
:param async_response_parser: 响应解析函数(可选)
|
||||
:param interrupt_flag: 中断信号量(可选,默认为None)
|
||||
:return: (响应文本, 推理文本, 工具调用, 其他数据)
|
||||
async def get_response(self, request: ResponseRequest) -> APIResponse:
|
||||
"""获取对话响应。
|
||||
|
||||
Args:
|
||||
request: 统一响应请求对象。
|
||||
|
||||
Returns:
|
||||
APIResponse: 统一响应对象。
|
||||
"""
|
||||
raise NotImplementedError("'get_response' method should be overridden in subclasses")
|
||||
|
||||
@abstractmethod
|
||||
async def get_embedding(
|
||||
self,
|
||||
model_info: ModelInfo,
|
||||
embedding_input: str,
|
||||
extra_params: dict[str, Any] | None = None,
|
||||
) -> APIResponse:
|
||||
"""
|
||||
获取文本嵌入
|
||||
:param model_info: 模型信息
|
||||
:param embedding_input: 嵌入输入文本
|
||||
:return: 嵌入响应
|
||||
async def get_embedding(self, request: EmbeddingRequest) -> APIResponse:
|
||||
"""获取文本嵌入。
|
||||
|
||||
Args:
|
||||
request: 统一嵌入请求对象。
|
||||
|
||||
Returns:
|
||||
APIResponse: 嵌入响应。
|
||||
"""
|
||||
raise NotImplementedError("'get_embedding' method should be overridden in subclasses")
|
||||
|
||||
@abstractmethod
|
||||
async def get_audio_transcriptions(
|
||||
self,
|
||||
model_info: ModelInfo,
|
||||
audio_base64: str,
|
||||
max_tokens: Optional[int] = None,
|
||||
extra_params: dict[str, Any] | None = None,
|
||||
) -> APIResponse:
|
||||
"""
|
||||
获取音频转录
|
||||
:param model_info: 模型信息
|
||||
:param audio_base64: base64编码的音频数据
|
||||
:extra_params: 附加的请求参数
|
||||
:return: 音频转录响应
|
||||
async def get_audio_transcriptions(self, request: AudioTranscriptionRequest) -> APIResponse:
|
||||
"""获取音频转录。
|
||||
|
||||
Args:
|
||||
request: 统一音频转录请求对象。
|
||||
|
||||
Returns:
|
||||
APIResponse: 音频转录响应。
|
||||
"""
|
||||
raise NotImplementedError("'get_audio_transcriptions' method should be overridden in subclasses")
|
||||
|
||||
@abstractmethod
|
||||
def get_support_image_formats(self) -> list[str]:
|
||||
"""
|
||||
获取支持的图片格式
|
||||
:return: 支持的图片格式列表
|
||||
def get_support_image_formats(self) -> List[str]:
|
||||
"""获取支持的图片格式。
|
||||
|
||||
Returns:
|
||||
List[str]: 支持的图片格式列表。
|
||||
"""
|
||||
raise NotImplementedError("'get_support_image_formats' method should be overridden in subclasses")
|
||||
|
||||
|
||||
class ClientRegistry:
|
||||
"""客户端注册表。"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.client_registry: dict[str, type[BaseClient]] = {}
|
||||
"""初始化注册表并绑定配置重载回调。"""
|
||||
self.client_registry: Dict[str, Type[BaseClient]] = {}
|
||||
"""APIProvider.type -> BaseClient的映射表"""
|
||||
self.client_instance_cache: dict[str, BaseClient] = {}
|
||||
self.client_instance_cache: Dict[str, BaseClient] = {}
|
||||
"""APIProvider.name -> BaseClient的映射表"""
|
||||
config_manager.register_reload_callback(self.clear_client_instance_cache)
|
||||
|
||||
def register_client_class(self, client_type: str):
|
||||
"""
|
||||
注册API客户端类
|
||||
def register_client_class(self, client_type: str) -> Callable[[Type[BaseClient]], Type[BaseClient]]:
|
||||
"""注册 API 客户端类。
|
||||
|
||||
Args:
|
||||
client_class: API客户端类
|
||||
client_type: 客户端类型标识。
|
||||
|
||||
Returns:
|
||||
Callable[[Type[BaseClient]], Type[BaseClient]]: 装饰器函数。
|
||||
"""
|
||||
|
||||
def decorator(cls: type[BaseClient]) -> type[BaseClient]:
|
||||
def decorator(cls: Type[BaseClient]) -> Type[BaseClient]:
|
||||
if not issubclass(cls, BaseClient):
|
||||
raise TypeError(f"{cls.__name__} is not a subclass of BaseClient")
|
||||
self.client_registry[client_type] = cls
|
||||
@@ -165,14 +227,15 @@ class ClientRegistry:
|
||||
|
||||
return decorator
|
||||
|
||||
def get_client_class_instance(self, api_provider: APIProvider, force_new=False) -> BaseClient:
|
||||
"""
|
||||
获取注册的API客户端实例
|
||||
def get_client_class_instance(self, api_provider: APIProvider, force_new: bool = False) -> BaseClient:
|
||||
"""获取注册的 API 客户端实例。
|
||||
|
||||
Args:
|
||||
api_provider: APIProvider实例
|
||||
force_new: 是否强制创建新实例(用于解决事件循环问题)
|
||||
api_provider: APIProvider 实例。
|
||||
force_new: 是否强制创建新实例。
|
||||
|
||||
Returns:
|
||||
BaseClient: 注册的API客户端实例
|
||||
BaseClient: 注册的 API 客户端实例。
|
||||
"""
|
||||
from . import ensure_client_type_loaded
|
||||
|
||||
@@ -194,6 +257,7 @@ class ClientRegistry:
|
||||
return self.client_instance_cache[api_provider.name]
|
||||
|
||||
def clear_client_instance_cache(self) -> None:
    """Clear all cached client instances.

    Registered as a config-reload callback in ``__init__``, so clients built
    from a stale configuration are dropped and rebuilt on next access.
    """
    self.client_instance_cache.clear()
    logger.info("检测到配置重载,已清空LLM客户端实例缓存")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user