WebUI 前端 & 后端超级大重构
This commit is contained in:
@@ -6,27 +6,18 @@
|
||||
|
||||
import os
|
||||
import httpx
|
||||
from fastapi import APIRouter, HTTPException, Query, Depends, Cookie, Header
|
||||
from typing import Optional
|
||||
|
||||
import tomlkit
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import CONFIG_DIR
|
||||
from src.webui.core import verify_auth_token_from_cookie_or_header
|
||||
from src.webui.dependencies import require_auth
|
||||
|
||||
logger = get_logger("webui")
|
||||
|
||||
router = APIRouter(prefix="/models", tags=["models"])
|
||||
|
||||
|
||||
def require_auth(
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> bool:
|
||||
"""认证依赖:验证用户是否已登录"""
|
||||
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||
|
||||
|
||||
router = APIRouter(prefix="/models", tags=["models"], dependencies=[Depends(require_auth)])
|
||||
# 模型获取器配置
|
||||
MODEL_FETCHER_CONFIG = {
|
||||
# OpenAI 兼容格式的提供商
|
||||
@@ -44,9 +35,7 @@ MODEL_FETCHER_CONFIG = {
|
||||
|
||||
def _normalize_url(url: str) -> str:
|
||||
"""规范化 URL(去掉尾部斜杠)"""
|
||||
if not url:
|
||||
return ""
|
||||
return url.rstrip("/")
|
||||
return url.rstrip("/") if url else ""
|
||||
|
||||
|
||||
def _parse_openai_response(data: dict) -> list[dict]:
|
||||
@@ -55,18 +44,18 @@ def _parse_openai_response(data: dict) -> list[dict]:
|
||||
|
||||
格式: { "data": [{ "id": "gpt-4", "object": "model", ... }] }
|
||||
"""
|
||||
models = []
|
||||
if "data" in data and isinstance(data["data"], list):
|
||||
for model in data["data"]:
|
||||
if isinstance(model, dict) and "id" in model:
|
||||
models.append(
|
||||
{
|
||||
"id": model["id"],
|
||||
"name": model.get("name") or model["id"],
|
||||
"owned_by": model.get("owned_by", ""),
|
||||
}
|
||||
)
|
||||
return models
|
||||
if "data" not in data or not isinstance(data["data"], list):
|
||||
return []
|
||||
|
||||
return [
|
||||
{
|
||||
"id": model["id"],
|
||||
"name": model.get("name") or model["id"],
|
||||
"owned_by": model.get("owned_by", ""),
|
||||
}
|
||||
for model in data["data"]
|
||||
if isinstance(model, dict) and "id" in model
|
||||
]
|
||||
|
||||
|
||||
def _parse_gemini_response(data: dict) -> list[dict]:
|
||||
@@ -178,11 +167,8 @@ def _get_provider_config(provider_name: str) -> Optional[dict]:
|
||||
config_data = tomlkit.load(f)
|
||||
|
||||
providers = config_data.get("api_providers", [])
|
||||
for provider in providers:
|
||||
if provider.get("name") == provider_name:
|
||||
return dict(provider)
|
||||
|
||||
return None
|
||||
provider = next((provider for provider in providers if provider.get("name") == provider_name), None)
|
||||
return dict(provider) if provider is not None else None
|
||||
except Exception as e:
|
||||
logger.error(f"读取提供商配置失败: {e}")
|
||||
return None
|
||||
@@ -193,7 +179,6 @@ async def get_provider_models(
|
||||
provider_name: str = Query(..., description="提供商名称"),
|
||||
parser: str = Query("openai", description="响应解析器类型 (openai | gemini)"),
|
||||
endpoint: str = Query("/models", description="获取模型列表的端点"),
|
||||
_auth: bool = Depends(require_auth),
|
||||
):
|
||||
"""
|
||||
获取指定提供商的可用模型列表
|
||||
@@ -238,7 +223,6 @@ async def get_models_by_url(
|
||||
parser: str = Query("openai", description="响应解析器类型 (openai | gemini)"),
|
||||
endpoint: str = Query("/models", description="获取模型列表的端点"),
|
||||
client_type: str = Query("openai", description="客户端类型 (openai | gemini)"),
|
||||
_auth: bool = Depends(require_auth),
|
||||
):
|
||||
"""
|
||||
通过 URL 直接获取模型列表(用于自定义提供商)
|
||||
@@ -262,7 +246,6 @@ async def get_models_by_url(
|
||||
async def test_provider_connection(
|
||||
base_url: str = Query(..., description="提供商的基础 URL"),
|
||||
api_key: Optional[str] = Query(None, description="API Key(可选,用于验证 Key 有效性)"),
|
||||
_auth: bool = Depends(require_auth),
|
||||
):
|
||||
"""
|
||||
测试提供商连接状态
|
||||
@@ -349,7 +332,6 @@ async def test_provider_connection(
|
||||
@router.post("/test-connection-by-name")
|
||||
async def test_provider_connection_by_name(
|
||||
provider_name: str = Query(..., description="提供商名称"),
|
||||
_auth: bool = Depends(require_auth),
|
||||
):
|
||||
"""
|
||||
通过提供商名称测试连接(从配置文件读取信息)
|
||||
@@ -364,11 +346,7 @@ async def test_provider_connection_by_name(
|
||||
|
||||
# 查找提供商
|
||||
providers = config.get("api_providers", [])
|
||||
provider = None
|
||||
for p in providers:
|
||||
if p.get("name") == provider_name:
|
||||
provider = p
|
||||
break
|
||||
provider = next((item for item in providers if item.get("name") == provider_name), None)
|
||||
|
||||
if not provider:
|
||||
raise HTTPException(status_code=404, detail=f"未找到提供商: {provider_name}")
|
||||
@@ -380,4 +358,4 @@ async def test_provider_connection_by_name(
|
||||
raise HTTPException(status_code=400, detail="提供商配置缺少 base_url")
|
||||
|
||||
# 调用测试接口
|
||||
return await test_provider_connection(base_url=base_url, api_key=api_key if api_key else None)
|
||||
return await test_provider_connection(base_url=base_url, api_key=api_key or None)
|
||||
|
||||
Reference in New Issue
Block a user