feat: Enhance OpenAI compatibility and introduce unified LLM service data models
- Refactored model fetching logic to support various authentication methods for OpenAI-compatible APIs. - Introduced new data models for LLM service requests and responses to standardize interactions across layers. - Added an adapter base class for unified request execution across different providers. - Implemented utility functions for building OpenAI-compatible client configurations and request overrides.
This commit is contained in:
@@ -13,6 +13,8 @@ from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import CONFIG_DIR
|
||||
from src.config.model_configs import APIProvider
|
||||
from src.llm_models.openai_compat import build_openai_compatible_client_config, normalize_openai_base_url
|
||||
from src.webui.dependencies import require_auth
|
||||
from src.webui.utils.network_security import validate_public_url
|
||||
|
||||
@@ -35,8 +37,8 @@ MODEL_FETCHER_CONFIG = {
|
||||
|
||||
|
||||
def _normalize_url(url: str) -> str:
|
||||
"""规范化 URL(去掉尾部斜杠)"""
|
||||
return url.rstrip("/") if url else ""
|
||||
"""规范化 URL(去掉尾部斜杠)。"""
|
||||
return normalize_openai_base_url(url) if url else ""
|
||||
|
||||
|
||||
def _parse_openai_response(data: Dict) -> List[Dict]:
|
||||
@@ -89,19 +91,30 @@ async def _fetch_models_from_provider(
|
||||
endpoint: str,
|
||||
parser: str,
|
||||
client_type: str = "openai",
|
||||
auth_type: str = "bearer",
|
||||
auth_header_name: str = "Authorization",
|
||||
auth_header_prefix: str = "Bearer",
|
||||
auth_query_name: str = "api_key",
|
||||
default_headers: Optional[Dict[str, str]] = None,
|
||||
default_query: Optional[Dict[str, str]] = None,
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
从提供商 API 获取模型列表
|
||||
"""从提供商 API 获取模型列表。
|
||||
|
||||
Args:
|
||||
base_url: 提供商的基础 URL
|
||||
api_key: API 密钥
|
||||
endpoint: 获取模型列表的端点
|
||||
parser: 响应解析器类型 ('openai' | 'gemini')
|
||||
client_type: 客户端类型 ('openai' | 'gemini')
|
||||
base_url: 提供商的基础 URL。
|
||||
api_key: API 密钥。
|
||||
endpoint: 获取模型列表的端点。
|
||||
parser: 响应解析器类型。
|
||||
client_type: 客户端类型。
|
||||
auth_type: OpenAI 兼容接口的鉴权方式。
|
||||
auth_header_name: Header 鉴权时使用的请求头名称。
|
||||
auth_header_prefix: Header 鉴权时使用的请求头前缀。
|
||||
auth_query_name: Query 鉴权时使用的查询参数名称。
|
||||
default_headers: 默认附带的请求头。
|
||||
default_query: 默认附带的查询参数。
|
||||
|
||||
Returns:
|
||||
模型列表
|
||||
List[Dict]: 解析后的模型列表。
|
||||
"""
|
||||
try:
|
||||
base_url = validate_public_url(_normalize_url(base_url))
|
||||
@@ -118,8 +131,21 @@ async def _fetch_models_from_provider(
|
||||
# Gemini 使用 URL 参数传递 API Key
|
||||
params["key"] = api_key
|
||||
else:
|
||||
# OpenAI 兼容格式使用 Authorization 头
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
provider = APIProvider(
|
||||
name="webui-openai-compatible-fetcher",
|
||||
base_url=base_url,
|
||||
api_key=api_key,
|
||||
client_type="openai",
|
||||
auth_type=auth_type,
|
||||
auth_header_name=auth_header_name,
|
||||
auth_header_prefix=auth_header_prefix,
|
||||
auth_query_name=auth_query_name,
|
||||
default_headers=default_headers or {},
|
||||
default_query=default_query or {},
|
||||
)
|
||||
client_config = build_openai_compatible_client_config(provider)
|
||||
headers.update(client_config.default_headers)
|
||||
params.update(client_config.default_query)
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
@@ -186,10 +212,9 @@ async def get_provider_models(
|
||||
parser: str = Query("openai", description="响应解析器类型 (openai | gemini)"),
|
||||
endpoint: str = Query("/models", description="获取模型列表的端点"),
|
||||
):
|
||||
"""
|
||||
获取指定提供商的可用模型列表
|
||||
"""获取指定提供商的可用模型列表。
|
||||
|
||||
通过提供商名称查找配置,然后请求对应的模型列表端点
|
||||
通过提供商名称查找配置,然后请求对应的模型列表端点。
|
||||
"""
|
||||
# 获取提供商配置
|
||||
provider_config = _get_provider_config(provider_name)
|
||||
@@ -205,13 +230,21 @@ async def get_provider_models(
|
||||
if not api_key:
|
||||
raise HTTPException(status_code=400, detail="提供商配置缺少 api_key")
|
||||
|
||||
resolved_endpoint = provider_config.get("model_list_endpoint", endpoint) if endpoint == "/models" else endpoint
|
||||
|
||||
# 获取模型列表
|
||||
models = await _fetch_models_from_provider(
|
||||
base_url=base_url,
|
||||
api_key=api_key,
|
||||
endpoint=endpoint,
|
||||
endpoint=resolved_endpoint,
|
||||
parser=parser,
|
||||
client_type=client_type,
|
||||
auth_type=provider_config.get("auth_type", "bearer"),
|
||||
auth_header_name=provider_config.get("auth_header_name", "Authorization"),
|
||||
auth_header_prefix=provider_config.get("auth_header_prefix", "Bearer"),
|
||||
auth_query_name=provider_config.get("auth_query_name", "api_key"),
|
||||
default_headers=provider_config.get("default_headers", {}),
|
||||
default_query=provider_config.get("default_query", {}),
|
||||
)
|
||||
|
||||
return {
|
||||
@@ -229,16 +262,22 @@ async def get_models_by_url(
|
||||
parser: str = Query("openai", description="响应解析器类型 (openai | gemini)"),
|
||||
endpoint: str = Query("/models", description="获取模型列表的端点"),
|
||||
client_type: str = Query("openai", description="客户端类型 (openai | gemini)"),
|
||||
auth_type: str = Query("bearer", description="鉴权方式 (bearer | header | query | none)"),
|
||||
auth_header_name: str = Query("Authorization", description="Header 鉴权名称"),
|
||||
auth_header_prefix: str = Query("Bearer", description="Header 鉴权前缀"),
|
||||
auth_query_name: str = Query("api_key", description="Query 鉴权参数名"),
|
||||
):
|
||||
"""
|
||||
通过 URL 直接获取模型列表(用于自定义提供商)
|
||||
"""
|
||||
"""通过 URL 直接获取模型列表。"""
|
||||
models = await _fetch_models_from_provider(
|
||||
base_url=base_url,
|
||||
api_key=api_key,
|
||||
endpoint=endpoint,
|
||||
parser=parser,
|
||||
client_type=client_type,
|
||||
auth_type=auth_type,
|
||||
auth_header_name=auth_header_name,
|
||||
auth_header_prefix=auth_header_prefix,
|
||||
auth_query_name=auth_query_name,
|
||||
)
|
||||
|
||||
return {
|
||||
|
||||
Reference in New Issue
Block a user