"""Model-list proxy API routes.

Proxy endpoints that fetch the list of available models from each AI vendor's
API, plus endpoints that test network connectivity / API-key validity for a
provider.
"""

import os
import time
from typing import Dict, List, Optional, Tuple

import httpx
import tomlkit
from fastapi import APIRouter, Depends, HTTPException, Query

from src.common.logger import get_logger
from src.config.config import CONFIG_DIR
from src.config.model_configs import APIProvider
from src.llm_models.openai_compat import build_openai_compatible_client_config, normalize_openai_base_url
from src.webui.dependencies import require_auth
from src.webui.utils.network_security import validate_public_url

logger = get_logger("webui")

router = APIRouter(prefix="/models", tags=["models"], dependencies=[Depends(require_auth)])

# Model fetcher settings per API format: model-list endpoint + response parser name.
# NOTE(review): not referenced by the code visible here; presumably used elsewhere — keep.
MODEL_FETCHER_CONFIG = {
    # OpenAI-compatible providers
    "openai": {
        "endpoint": "/models",
        "parser": "openai",
    },
    # Gemini format
    "gemini": {
        "endpoint": "/models",
        "parser": "gemini",
    },
}

# Gemini reports model names as "models/<id>"; only the <id> part is exposed.
_GEMINI_MODEL_PREFIX = "models/"


def _normalize_url(url: str) -> str:
    """Normalize a base URL with ``normalize_openai_base_url`` (drops the trailing slash).

    Empty/falsy input yields ``""``.
    """
    return normalize_openai_base_url(url) if url else ""


def _normalize_endpoint(endpoint: Optional[str]) -> str:
    """Return ``endpoint`` guaranteed to start with ``/`` (or ``""`` when empty).

    The endpoint is appended verbatim to a base URL that has already passed
    ``validate_public_url``. Without a leading slash, values such as
    ``@10.0.0.1/x``, ``.evil.com/x`` or ``:8080/x`` would change the host or
    port of the final URL and bypass that validation.
    """
    if not endpoint:
        return ""
    return endpoint if endpoint.startswith("/") else f"/{endpoint}"


def _model_config_path() -> str:
    """Return the path of ``model_config.toml`` inside ``CONFIG_DIR``."""
    return os.path.join(CONFIG_DIR, "model_config.toml")


def _parse_openai_response(data: Dict) -> List[Dict]:
    """Parse an OpenAI-style model-list response.

    Expected shape: ``{"data": [{"id": "gpt-4", "object": "model", ...}]}``.

    Returns:
        One ``{"id", "name", "owned_by"}`` dict per entry that is a dict with an
        ``id``. Any other payload shape yields an empty list.
    """
    items = data.get("data") if isinstance(data, dict) else None
    if not isinstance(items, list):
        return []
    return [
        {
            "id": model["id"],
            "name": model.get("name") or model["id"],
            "owned_by": model.get("owned_by", ""),
        }
        for model in items
        if isinstance(model, dict) and "id" in model
    ]


def _parse_gemini_response(data: Dict) -> List[Dict]:
    """Parse a Gemini-style model-list response.

    Expected shape: ``{"models": [{"name": "models/gemini-pro", "displayName": "Gemini Pro", ...}]}``.

    Returns:
        One ``{"id", "name", "owned_by": "google"}`` dict per entry whose
        ``name`` is a string. The ``models/`` prefix is stripped from the id.
        Any other payload shape yields an empty list.
    """
    items = data.get("models") if isinstance(data, dict) else None
    if not isinstance(items, list):
        return []

    models = []
    for model in items:
        if not isinstance(model, dict):
            continue
        raw_name = model.get("name")
        # A non-string name would previously crash on .startswith(); skip it instead.
        if not isinstance(raw_name, str):
            continue
        if raw_name.startswith(_GEMINI_MODEL_PREFIX):
            model_id = raw_name[len(_GEMINI_MODEL_PREFIX):]
        else:
            model_id = raw_name
        models.append(
            {
                "id": model_id,
                "name": model.get("displayName") or model_id,
                "owned_by": "google",
            }
        )
    return models


# Parser name -> response parser. Used both to validate ``parser`` up front and to dispatch.
_RESPONSE_PARSERS = {
    "openai": _parse_openai_response,
    "gemini": _parse_gemini_response,
}


def _build_auth_parts(
    base_url: str,
    api_key: str,
    client_type: str,
    auth_type: str = "bearer",
    auth_header_name: str = "Authorization",
    auth_header_prefix: str = "Bearer",
    auth_query_name: str = "api_key",
    default_headers: Optional[Dict[str, str]] = None,
    default_query: Optional[Dict[str, str]] = None,
) -> Tuple[Dict[str, str], Dict[str, str]]:
    """Build the headers and query params that authenticate a model-list request.

    Gemini receives the key as the ``key`` query parameter. Every other client
    type goes through ``build_openai_compatible_client_config``, so the
    configured auth mode and default headers/query are applied.

    Returns:
        ``(headers, params)`` to pass to ``httpx``.
    """
    headers: Dict[str, str] = {}
    params: Dict[str, str] = {}

    if client_type == "gemini":
        params["key"] = api_key
        return headers, params

    provider = APIProvider(
        name="webui-openai-compatible-fetcher",
        base_url=base_url,
        api_key=api_key,
        client_type="openai",
        auth_type=auth_type,
        auth_header_name=auth_header_name,
        auth_header_prefix=auth_header_prefix,
        auth_query_name=auth_query_name,
        default_headers=default_headers or {},
        default_query=default_query or {},
    )
    client_config = build_openai_compatible_client_config(provider)
    headers.update(client_config.default_headers or {})
    params.update(client_config.default_query or {})

    # In the "default Bearer" case, build_openai_compatible_client_config leaves the
    # key in client_config.api_key for the OpenAI SDK to inject as Authorization,
    # instead of writing it into default_headers. We call httpx directly here, so
    # the header has to be added by hand. HTTP header names are case-insensitive,
    # so the presence check is too.
    has_authorization = any(str(name).lower() == "authorization" for name in headers)
    if client_config.api_key and not has_authorization:
        headers["Authorization"] = f"Bearer {client_config.api_key}"

    return headers, params


async def _fetch_models_from_provider(
    base_url: str,
    api_key: str,
    endpoint: str,
    parser: str,
    client_type: str = "openai",
    auth_type: str = "bearer",
    auth_header_name: str = "Authorization",
    auth_header_prefix: str = "Bearer",
    auth_query_name: str = "api_key",
    default_headers: Optional[Dict[str, str]] = None,
    default_query: Optional[Dict[str, str]] = None,
) -> List[Dict]:
    """Fetch and parse the model list from a provider's API.

    Args:
        base_url: Provider base URL; must pass ``validate_public_url``.
        api_key: API key.
        endpoint: Model-list endpoint path, appended to ``base_url``. A leading
            ``/`` is enforced so it cannot alter the validated host.
        parser: Response parser name (``openai`` | ``gemini``).
        client_type: Client type; ``gemini`` passes the key as a query param.
        auth_type: Auth mode for OpenAI-compatible APIs.
        auth_header_name: Header name used for header auth.
        auth_header_prefix: Header value prefix used for header auth.
        auth_query_name: Query parameter name used for query auth.
        default_headers: Extra headers to send.
        default_query: Extra query parameters to send.

    Returns:
        List[Dict]: The parsed model list.

    Raises:
        HTTPException: 400 for an unsupported parser or a rejected URL; 504 on
            timeout; 502 for upstream HTTP errors (never 401/403, because the
            frontend treats those as WebUI auth failures); 500 otherwise.
    """
    # Validate the parser before any network traffic so the API key is never sent
    # for a request whose response we would reject anyway.
    parse = _RESPONSE_PARSERS.get(parser)
    if parse is None:
        raise HTTPException(status_code=400, detail=f"不支持的解析器类型: {parser}")

    try:
        base_url = validate_public_url(_normalize_url(base_url))
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e

    url = f"{base_url}{_normalize_endpoint(endpoint)}"
    headers, params = _build_auth_parts(
        base_url=base_url,
        api_key=api_key,
        client_type=client_type,
        auth_type=auth_type,
        auth_header_name=auth_header_name,
        auth_header_prefix=auth_header_prefix,
        auth_query_name=auth_query_name,
        default_headers=default_headers,
        default_query=default_query,
    )

    try:
        # httpx does not follow redirects by default, so the request cannot be
        # bounced to a host that validate_public_url never checked.
        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.get(url, headers=headers, params=params)
            response.raise_for_status()
            data = response.json()
    except httpx.TimeoutException as e:
        raise HTTPException(status_code=504, detail="请求超时,请稍后重试") from e
    except httpx.HTTPStatusError as e:
        # Use 502 Bad Gateway rather than the upstream 401/403: the frontend's
        # fetchWithAuth treats 401 as a WebUI authentication failure.
        status = e.response.status_code
        if status == 401:
            raise HTTPException(status_code=502, detail="API Key 无效或已过期") from e
        if status == 403:
            raise HTTPException(status_code=502, detail="没有权限访问模型列表,请检查 API Key 权限") from e
        if status == 404:
            raise HTTPException(status_code=502, detail="该提供商不支持获取模型列表") from e
        raise HTTPException(
            status_code=502, detail=f"上游服务请求失败 ({e.response.status_code}): {e.response.text[:200]}"
        ) from e
    except Exception as e:
        logger.error(f"获取模型列表失败: {e}")
        raise HTTPException(status_code=500, detail=f"获取模型列表失败: {str(e)}") from e

    return parse(data)


def _get_provider_config(provider_name: str) -> Optional[Dict]:
    """Look up a provider entry in ``model_config.toml``.

    Args:
        provider_name: Provider name (matched against each entry's ``name``).

    Returns:
        The provider entry as a dict, or ``None`` if the file is missing, the
        provider is not found, or reading/parsing fails (the failure is logged).
    """
    config_path = _model_config_path()
    if not os.path.exists(config_path):
        return None

    try:
        with open(config_path, "r", encoding="utf-8") as f:
            config_data = tomlkit.load(f)
        providers = config_data.get("api_providers", [])
        provider = next((item for item in providers if item.get("name") == provider_name), None)
        return dict(provider) if provider is not None else None
    except Exception as e:
        logger.error(f"读取提供商配置失败: {e}")
        return None


@router.get("/list")
async def get_provider_models(
    provider_name: str = Query(..., description="提供商名称"),
    parser: str = Query("openai", description="响应解析器类型 (openai | gemini)"),
    endpoint: str = Query("/models", description="获取模型列表的端点"),
):
    """Return the available models of a provider configured in ``model_config.toml``.

    The provider is looked up by name, and its ``base_url``, ``api_key``,
    ``client_type`` and auth settings are used to call the model-list endpoint.
    When the caller leaves ``endpoint`` at the default ``/models``, the
    provider's ``model_list_endpoint`` (if configured) takes precedence.
    """
    provider_config = _get_provider_config(provider_name)
    if not provider_config:
        raise HTTPException(status_code=404, detail=f"未找到提供商: {provider_name}")

    base_url = provider_config.get("base_url")
    api_key = provider_config.get("api_key")
    client_type = provider_config.get("client_type", "openai")

    if not base_url:
        raise HTTPException(status_code=400, detail="提供商配置缺少 base_url")
    if not api_key:
        raise HTTPException(status_code=400, detail="提供商配置缺少 api_key")

    # Only an explicit, non-default endpoint from the caller overrides the configured one.
    resolved_endpoint = provider_config.get("model_list_endpoint", endpoint) if endpoint == "/models" else endpoint

    models = await _fetch_models_from_provider(
        base_url=base_url,
        api_key=api_key,
        endpoint=resolved_endpoint,
        parser=parser,
        client_type=client_type,
        auth_type=provider_config.get("auth_type", "bearer"),
        auth_header_name=provider_config.get("auth_header_name", "Authorization"),
        auth_header_prefix=provider_config.get("auth_header_prefix", "Bearer"),
        auth_query_name=provider_config.get("auth_query_name", "api_key"),
        default_headers=provider_config.get("default_headers", {}),
        default_query=provider_config.get("default_query", {}),
    )

    return {
        "success": True,
        "models": models,
        "provider": provider_name,
        "count": len(models),
    }


@router.get("/list-by-url")
async def get_models_by_url(
    base_url: str = Query(..., description="提供商的基础 URL"),
    api_key: str = Query(..., description="API Key"),
    parser: str = Query("openai", description="响应解析器类型 (openai | gemini)"),
    endpoint: str = Query("/models", description="获取模型列表的端点"),
    client_type: str = Query("openai", description="客户端类型 (openai | gemini)"),
    auth_type: str = Query("bearer", description="鉴权方式 (bearer | header | query | none)"),
    auth_header_name: str = Query("Authorization", description="Header 鉴权名称"),
    auth_header_prefix: str = Query("Bearer", description="Header 鉴权前缀"),
    auth_query_name: str = Query("api_key", description="Query 鉴权参数名"),
):
    """Fetch a model list directly from an ad-hoc base URL and API key.

    NOTE(review): the API key arrives as a query parameter, so it may end up in
    access logs; consider moving it to a header/body in a future API revision.
    """
    models = await _fetch_models_from_provider(
        base_url=base_url,
        api_key=api_key,
        endpoint=endpoint,
        parser=parser,
        client_type=client_type,
        auth_type=auth_type,
        auth_header_name=auth_header_name,
        auth_header_prefix=auth_header_prefix,
        auth_query_name=auth_query_name,
    )

    return {
        "success": True,
        "models": models,
        "count": len(models),
    }


async def _run_connection_test(
    base_url: str,
    api_key: Optional[str],
    client_type: str,
    *,
    endpoint: str = "/models",
    auth_type: str = "bearer",
    auth_header_name: str = "Authorization",
    auth_header_prefix: str = "Bearer",
    auth_query_name: str = "api_key",
    default_headers: Optional[Dict[str, str]] = None,
    default_query: Optional[Dict[str, str]] = None,
) -> Dict:
    """Run the two-step connectivity / API-key test (shared by both test routes).

    Redirects are NOT followed. ``validate_public_url`` only checks the initial
    URL, so following a 3xx could reach internal addresses (SSRF). A redirect
    still counts as "network reachable" in step 1.

    Raises:
        HTTPException: 400 when ``base_url`` is empty or rejected by
            ``validate_public_url``.
    """
    base_url = _normalize_url(base_url)
    if not base_url:
        raise HTTPException(status_code=400, detail="base_url 不能为空")
    try:
        base_url = validate_public_url(base_url)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e

    result = {
        "network_ok": False,
        "api_key_valid": None,
        "latency_ms": None,
        "error": None,
        "http_status": None,
    }

    # Step 1: network reachability — a plain GET on base_url without credentials.
    try:
        started = time.perf_counter()
        async with httpx.AsyncClient(timeout=10.0, follow_redirects=False) as client:
            response = await client.get(base_url)
            elapsed_ms = (time.perf_counter() - started) * 1000
            result["network_ok"] = True
            result["latency_ms"] = round(elapsed_ms, 2)
            result["http_status"] = response.status_code
    except httpx.ConnectError as e:
        result["error"] = f"连接失败:无法连接到服务器 ({str(e)})"
        return result
    except httpx.TimeoutException:
        result["error"] = "连接超时:服务器响应时间过长"
        return result
    except httpx.RequestError as e:
        result["error"] = f"请求错误:{str(e)}"
        return result
    except Exception as e:
        result["error"] = f"未知错误:{str(e)}"
        return result

    # Step 2: when an API key is supplied, try the model-list endpoint with it.
    if api_key:
        try:
            auth_headers, params = _build_auth_parts(
                base_url=base_url,
                api_key=api_key,
                client_type=client_type,
                auth_type=auth_type,
                auth_header_name=auth_header_name,
                auth_header_prefix=auth_header_prefix,
                auth_query_name=auth_query_name,
                default_headers=default_headers,
                default_query=default_query,
            )
            headers = {"Content-Type": "application/json", **auth_headers}
            models_url = f"{base_url}{_normalize_endpoint(endpoint)}"
            async with httpx.AsyncClient(timeout=15.0, follow_redirects=False) as client:
                response = await client.get(models_url, headers=headers, params=params)

            if response.status_code == 200:
                result["api_key_valid"] = True
            elif response.status_code in (401, 403):
                result["api_key_valid"] = False
                result["error"] = "API Key 无效或已过期"
            else:
                # Other statuses likely mean the endpoint is unsupported; the key may still be valid.
                result["api_key_valid"] = None
        except Exception as e:
            # Best-effort: a failed key check must not affect the reachability result.
            logger.warning(f"API Key 验证失败: {e}")
            result["api_key_valid"] = None

    return result


@router.get("/test-connection")
async def test_provider_connection(
    base_url: str = Query(..., description="提供商的基础 URL"),
    api_key: Optional[str] = Query(None, description="API Key(可选,用于验证 Key 有效性)"),
    client_type: str = Query("openai", description="客户端类型 (openai | gemini)"),
):
    """Test a provider's connection status.

    Two steps:
    1. Reachability: send a GET to ``base_url`` and check that it connects.
    2. API-key check (optional): if ``api_key`` is given, request the model list
       to see whether the key is accepted (Bearer auth, or ``key`` param for Gemini).

    Returns a dict with:
    - network_ok: whether the server is reachable
    - api_key_valid: True/False when determinable, otherwise None
    - latency_ms: reachability round-trip time in milliseconds
    - http_status: status code of the reachability request
    - error: error message, if any
    """
    return await _run_connection_test(base_url, api_key, client_type)


@router.post("/test-connection-by-name")
async def test_provider_connection_by_name(
    provider_name: str = Query(..., description="提供商名称"),
):
    """Test a provider's connection using its settings from ``model_config.toml``.

    Uses the provider's configured auth mode, default headers/query and
    ``model_list_endpoint``. Previously the key was always sent as a Bearer token
    to ``/models``, which misreported header/query-auth providers as invalid.
    """
    if not os.path.exists(_model_config_path()):
        raise HTTPException(status_code=404, detail="配置文件不存在")

    provider = _get_provider_config(provider_name)
    if not provider:
        raise HTTPException(status_code=404, detail=f"未找到提供商: {provider_name}")

    base_url = provider.get("base_url", "")
    if not base_url:
        raise HTTPException(status_code=400, detail="提供商配置缺少 base_url")

    return await _run_connection_test(
        base_url,
        provider.get("api_key", "") or None,
        provider.get("client_type", "openai"),
        endpoint=provider.get("model_list_endpoint", "/models"),
        auth_type=provider.get("auth_type", "bearer"),
        auth_header_name=provider.get("auth_header_name", "Authorization"),
        auth_header_prefix=provider.get("auth_header_prefix", "Bearer"),
        auth_query_name=provider.get("auth_query_name", "api_key"),
        default_headers=provider.get("default_headers", {}),
        default_query=provider.get("default_query", {}),
    )