feat: Enhance OpenAI compatibility and introduce unified LLM service data models

- Refactored model fetching logic to support various authentication methods for OpenAI-compatible APIs.
- Introduced new data models for LLM service requests and responses to standardize interactions across layers.
- Added an adapter base class for unified request execution across different providers.
- Implemented utility functions for building OpenAI-compatible client configurations and request overrides.
This commit is contained in:
DrSmoothl
2026-03-26 16:15:42 +08:00
parent 6e7daae55d
commit 777d4cb0d2
48 changed files with 5443 additions and 2945 deletions

View File

@@ -1,14 +1,15 @@
import asyncio
from dataclasses import dataclass
from abc import ABC, abstractmethod
from typing import Callable, Any, Optional
from dataclasses import dataclass, field
from typing import Any, Callable, Coroutine, Dict, List, Tuple, Type
import asyncio
from src.common.logger import get_logger
from src.config.config import config_manager
from src.config.model_configs import ModelInfo, APIProvider
from ..payload_content.message import Message
from ..payload_content.resp_format import RespFormat
from ..payload_content.tool_option import ToolOption, ToolCall
from src.config.model_configs import APIProvider, ModelInfo
from src.llm_models.payload_content.message import Message
from src.llm_models.payload_content.resp_format import RespFormat
from src.llm_models.payload_content.tool_option import ToolCall, ToolOption
logger = get_logger("model_client_registry")
@@ -47,10 +48,10 @@ class APIResponse:
reasoning_content: str | None = None
"""推理内容"""
tool_calls: list[ToolCall] | None = None
tool_calls: List[ToolCall] | None = None
"""工具调用 [(工具名称, 工具参数), ...]"""
embedding: list[float] | None = None
embedding: List[float] | None = None
"""嵌入向量"""
usage: UsageRecord | None = None
@@ -60,6 +61,82 @@ class APIResponse:
"""响应原始数据"""
UsageTuple = Tuple[int, int, int]
"""统一的使用量三元组类型,顺序为 `(prompt_tokens, completion_tokens, total_tokens)`。"""
StreamResponseHandler = Callable[
[Any, asyncio.Event | None],
Coroutine[Any, Any, Tuple["APIResponse", UsageTuple | None]],
]
"""统一的流式响应处理函数类型。"""
ResponseParser = Callable[[Any], Tuple["APIResponse", UsageTuple | None]]
"""统一的非流式响应解析函数类型。"""
@dataclass(slots=True)
class ResponseRequest:
"""统一的文本/多模态响应请求。"""
model_info: ModelInfo
message_list: List[Message]
tool_options: List[ToolOption] | None = None
max_tokens: int | None = None
temperature: float | None = None
response_format: RespFormat | None = None
stream_response_handler: StreamResponseHandler | None = None
async_response_parser: ResponseParser | None = None
interrupt_flag: asyncio.Event | None = None
extra_params: Dict[str, Any] = field(default_factory=dict)
def copy_with(self, **changes: Any) -> "ResponseRequest":
"""基于当前请求创建一个带局部变更的新请求。
Args:
**changes: 需要覆盖的字段值。
Returns:
ResponseRequest: 复制后的请求对象。
"""
payload = {
"model_info": self.model_info,
"message_list": list(self.message_list),
"tool_options": None if self.tool_options is None else list(self.tool_options),
"max_tokens": self.max_tokens,
"temperature": self.temperature,
"response_format": self.response_format,
"stream_response_handler": self.stream_response_handler,
"async_response_parser": self.async_response_parser,
"interrupt_flag": self.interrupt_flag,
"extra_params": dict(self.extra_params),
}
payload.update(changes)
return ResponseRequest(**payload)
@dataclass(slots=True)
class EmbeddingRequest:
"""统一的嵌入请求。"""
model_info: ModelInfo
embedding_input: str
extra_params: Dict[str, Any] = field(default_factory=dict)
@dataclass(slots=True)
class AudioTranscriptionRequest:
"""统一的音频转录请求。"""
model_info: ModelInfo
audio_base64: str
max_tokens: int | None = None
extra_params: Dict[str, Any] = field(default_factory=dict)
ClientRequest = ResponseRequest | EmbeddingRequest | AudioTranscriptionRequest
"""统一客户端请求类型。"""
class BaseClient(ABC):
"""
基础客户端
@@ -67,97 +144,82 @@ class BaseClient(ABC):
api_provider: APIProvider
def __init__(self, api_provider: APIProvider):
def __init__(self, api_provider: APIProvider) -> None:
"""初始化基础客户端。
Args:
api_provider: API 提供商配置。
"""
self.api_provider = api_provider
@abstractmethod
async def get_response(
self,
model_info: ModelInfo,
message_list: list[Message],
tool_options: list[ToolOption] | None = None,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
response_format: RespFormat | None = None,
stream_response_handler: Optional[
Callable[[Any, asyncio.Event | None], tuple[APIResponse, tuple[int, int, int]]]
] = None,
async_response_parser: Callable[[Any], tuple[APIResponse, tuple[int, int, int]]] | None = None,
interrupt_flag: asyncio.Event | None = None,
extra_params: dict[str, Any] | None = None,
) -> APIResponse:
"""
获取对话响应
:param model_info: 模型信息
:param message_list: 对话体
:param tool_options: 工具选项可选默认为None
:param max_tokens: 最大token数可选默认为1024
:param temperature: 温度可选默认为0.7
:param response_format: 响应格式(可选,默认为 NotGiven
:param stream_response_handler: 流式响应处理函数(可选)
:param async_response_parser: 响应解析函数(可选)
:param interrupt_flag: 中断信号量可选默认为None
:return: (响应文本, 推理文本, 工具调用, 其他数据)
async def get_response(self, request: ResponseRequest) -> APIResponse:
"""获取对话响应。
Args:
request: 统一响应请求对象。
Returns:
APIResponse: 统一响应对象。
"""
raise NotImplementedError("'get_response' method should be overridden in subclasses")
@abstractmethod
async def get_embedding(
self,
model_info: ModelInfo,
embedding_input: str,
extra_params: dict[str, Any] | None = None,
) -> APIResponse:
"""
获取文本嵌入
:param model_info: 模型信息
:param embedding_input: 嵌入输入文本
:return: 嵌入响应
async def get_embedding(self, request: EmbeddingRequest) -> APIResponse:
"""获取文本嵌入。
Args:
request: 统一嵌入请求对象。
Returns:
APIResponse: 嵌入响应。
"""
raise NotImplementedError("'get_embedding' method should be overridden in subclasses")
@abstractmethod
async def get_audio_transcriptions(
self,
model_info: ModelInfo,
audio_base64: str,
max_tokens: Optional[int] = None,
extra_params: dict[str, Any] | None = None,
) -> APIResponse:
"""
获取音频转录
:param model_info: 模型信息
:param audio_base64: base64编码的音频数据
:extra_params: 附加的请求参数
:return: 音频转录响应
async def get_audio_transcriptions(self, request: AudioTranscriptionRequest) -> APIResponse:
"""获取音频转录。
Args:
request: 统一音频转录请求对象。
Returns:
APIResponse: 音频转录响应。
"""
raise NotImplementedError("'get_audio_transcriptions' method should be overridden in subclasses")
@abstractmethod
def get_support_image_formats(self) -> list[str]:
"""
获取支持的图片格式
:return: 支持的图片格式列表
def get_support_image_formats(self) -> List[str]:
"""获取支持的图片格式。
Returns:
List[str]: 支持的图片格式列表。
"""
raise NotImplementedError("'get_support_image_formats' method should be overridden in subclasses")
class ClientRegistry:
"""客户端注册表。"""
def __init__(self) -> None:
self.client_registry: dict[str, type[BaseClient]] = {}
"""初始化注册表并绑定配置重载回调。"""
self.client_registry: Dict[str, Type[BaseClient]] = {}
"""APIProvider.type -> BaseClient的映射表"""
self.client_instance_cache: dict[str, BaseClient] = {}
self.client_instance_cache: Dict[str, BaseClient] = {}
"""APIProvider.name -> BaseClient的映射表"""
config_manager.register_reload_callback(self.clear_client_instance_cache)
def register_client_class(self, client_type: str):
"""
注册API客户端类
def register_client_class(self, client_type: str) -> Callable[[Type[BaseClient]], Type[BaseClient]]:
"""注册 API 客户端类。
Args:
client_class: API客户端类
client_type: 客户端类型标识。
Returns:
Callable[[Type[BaseClient]], Type[BaseClient]]: 装饰器函数。
"""
def decorator(cls: type[BaseClient]) -> type[BaseClient]:
def decorator(cls: Type[BaseClient]) -> Type[BaseClient]:
if not issubclass(cls, BaseClient):
raise TypeError(f"{cls.__name__} is not a subclass of BaseClient")
self.client_registry[client_type] = cls
@@ -165,14 +227,15 @@ class ClientRegistry:
return decorator
def get_client_class_instance(self, api_provider: APIProvider, force_new=False) -> BaseClient:
"""
获取注册的API客户端实例
def get_client_class_instance(self, api_provider: APIProvider, force_new: bool = False) -> BaseClient:
"""获取注册的 API 客户端实例。
Args:
api_provider: APIProvider实例
force_new: 是否强制创建新实例(用于解决事件循环问题)
api_provider: APIProvider 实例
force_new: 是否强制创建新实例
Returns:
BaseClient: 注册的API客户端实例
BaseClient: 注册的 API 客户端实例
"""
from . import ensure_client_type_loaded
@@ -194,6 +257,7 @@ class ClientRegistry:
return self.client_instance_cache[api_provider.name]
def clear_client_instance_cache(self) -> None:
"""清空客户端实例缓存。"""
self.client_instance_cache.clear()
logger.info("检测到配置重载已清空LLM客户端实例缓存")