From 43db0aa9abf228c7901861f4ea133033c3079a9a Mon Sep 17 00:00:00 2001 From: anderwer Date: Mon, 16 Mar 2026 17:59:08 +0800 Subject: [PATCH 1/2] fix: validate gemini provider tests with query api key --- src/webui/routers/model.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/src/webui/routers/model.py b/src/webui/routers/model.py index 2f67aca5..f8e0fcc0 100644 --- a/src/webui/routers/model.py +++ b/src/webui/routers/model.py @@ -252,6 +252,7 @@ async def get_models_by_url( async def test_provider_connection( base_url: str = Query(..., description="提供商的基础 URL"), api_key: Optional[str] = Query(None, description="API Key(可选,用于验证 Key 有效性)"), + client_type: str = Query("openai", description="客户端类型 (openai | gemini)"), ): """ 测试提供商连接状态 @@ -315,13 +316,19 @@ async def test_provider_connection( try: start_time = time.time() async with httpx.AsyncClient(timeout=15.0, follow_redirects=True) as client: - headers = { - "Authorization": f"Bearer {api_key}", - "Content-Type": "application/json", - } + headers = {"Content-Type": "application/json"} + params = {} + + if client_type == "gemini": + # Gemini 使用 URL 参数传递 API Key + params["key"] = api_key + else: + # OpenAI 兼容格式使用 Authorization 头 + headers["Authorization"] = f"Bearer {api_key}" + # 尝试获取模型列表 models_url = f"{base_url}/models" - response = await client.get(models_url, headers=headers) + response = await client.get(models_url, headers=headers, params=params) if response.status_code == 200: result["api_key_valid"] = True @@ -364,9 +371,14 @@ async def test_provider_connection_by_name( base_url = provider.get("base_url", "") api_key = provider.get("api_key", "") + client_type = provider.get("client_type", "openai") if not base_url: raise HTTPException(status_code=400, detail="提供商配置缺少 base_url") # 调用测试接口 - return await test_provider_connection(base_url=base_url, api_key=api_key or None) + return await test_provider_connection( + base_url=base_url, + api_key=api_key if api_key else None, + client_type=client_type, + ) From 78415e89c1076a9319c12c1f7d3be22390fa7a74 Mon Sep 17 00:00:00 2001 From: anderwer Date: Thu, 26 Mar 2026 08:58:49 +0800 Subject: [PATCH 2/2] test(webui): cover gemini provider connection auth --- pytests/webui/test_model_routes.py | 187 +++++++++++++++++++++++++++++ 1 file changed, 187 insertions(+) create mode 100644 pytests/webui/test_model_routes.py diff --git a/pytests/webui/test_model_routes.py b/pytests/webui/test_model_routes.py new file mode 100644 index 00000000..0e05ad87 --- /dev/null +++ b/pytests/webui/test_model_routes.py @@ -0,0 +1,187 @@ +"""模型路由测试 + +验证 Gemini 提供商连接测试会使用查询参数传递 API Key, +并且不会回退到 OpenAI 兼容接口使用的 Bearer 认证方式。 +""" + +import importlib +import sys +from types import ModuleType +from typing import Any + +import pytest + + +def load_model_routes(monkeypatch: pytest.MonkeyPatch): + """在导入路由前 stub 配置与认证依赖模块,避免测试时触发真实初始化。""" + config_module = ModuleType("src.config.config") + config_module.__dict__["CONFIG_DIR"] = "." + monkeypatch.setitem(sys.modules, "src.config.config", config_module) + + dependencies_module = ModuleType("src.webui.dependencies") + + async def require_auth(): + return "test-token" + + dependencies_module.__dict__["require_auth"] = require_auth + monkeypatch.setitem(sys.modules, "src.webui.dependencies", dependencies_module) + + sys.modules.pop("src.webui.routers.model", None) + return importlib.import_module("src.webui.routers.model") + + +class FakeResponse: + """简化版 HTTP 响应对象。""" + + def __init__(self, status_code: int): + self.status_code = status_code + + +def build_async_client_factory( + responses: list[FakeResponse], + calls: list[dict[str, Any]], +): + """构造一个可记录请求参数的 AsyncClient 替身。""" + + response_iter = iter(responses) + + class FakeAsyncClient: + def __init__(self, *args: Any, **kwargs: Any): + self.args = args + self.kwargs = kwargs + + async def __aenter__(self) -> "FakeAsyncClient": + return self + + async def __aexit__(self, exc_type, exc, tb) -> bool: + return False + + async def get( + self, + url: str, + headers: dict[str, Any] | None = None, + params: dict[str, Any] | None = None, + ) -> FakeResponse: + calls.append( + { + "url": url, + "headers": headers or {}, + "params": params or {}, + } + ) + return next(response_iter) + + return FakeAsyncClient + + +@pytest.mark.asyncio +async def test_test_provider_connection_uses_query_api_key_for_gemini( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Gemini 连接测试应通过查询参数传递 API Key。""" + model_routes = load_model_routes(monkeypatch) + calls: list[dict[str, Any]] = [] + fake_client_class = build_async_client_factory( + responses=[FakeResponse(200), FakeResponse(200)], + calls=calls, + ) + monkeypatch.setattr(model_routes.httpx, "AsyncClient", fake_client_class) + + result = await model_routes.test_provider_connection( + base_url="https://generativelanguage.googleapis.com/v1beta", + api_key="valid-gemini-key", + client_type="gemini", + ) + + assert result["network_ok"] is True + assert result["api_key_valid"] is True + assert len(calls) == 2 + + network_call = calls[0] + validation_call = calls[1] + + assert network_call["url"] == "https://generativelanguage.googleapis.com/v1beta" + assert network_call["headers"] == {} + assert network_call["params"] == {} + + assert validation_call["url"] == "https://generativelanguage.googleapis.com/v1beta/models" + assert validation_call["params"] == {"key": "valid-gemini-key"} + assert validation_call["headers"] == {"Content-Type": "application/json"} + assert "Authorization" not in validation_call["headers"] + + +@pytest.mark.asyncio +async def test_test_provider_connection_uses_bearer_auth_for_openai_compatible( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """非 Gemini 提供商连接测试应继续使用 Bearer 认证。""" + model_routes = load_model_routes(monkeypatch) + calls: list[dict[str, Any]] = [] + fake_client_class = build_async_client_factory( + responses=[FakeResponse(200), FakeResponse(200)], + calls=calls, + ) + monkeypatch.setattr(model_routes.httpx, "AsyncClient", fake_client_class) + + result = await model_routes.test_provider_connection( + base_url="https://example.com/v1", + api_key="valid-openai-key", + client_type="openai", + ) + + assert result["network_ok"] is True + assert result["api_key_valid"] is True + assert len(calls) == 2 + + validation_call = calls[1] + + assert validation_call["url"] == "https://example.com/v1/models" + assert validation_call["params"] == {} + assert validation_call["headers"]["Content-Type"] == "application/json" + assert validation_call["headers"]["Authorization"] == "Bearer valid-openai-key" + + +@pytest.mark.asyncio +async def test_test_provider_connection_by_name_forwards_provider_client_type( + monkeypatch: pytest.MonkeyPatch, + tmp_path, +) -> None: + """按提供商名称测试连接时,应透传配置中的 client_type。""" + model_routes = load_model_routes(monkeypatch) + config_path = tmp_path / "model_config.toml" + config_path.write_text( + """ +[[api_providers]] +name = "Gemini" +base_url = "https://generativelanguage.googleapis.com/v1beta" +api_key = "valid-gemini-key" +client_type = "gemini" +""".strip(), + encoding="utf-8", + ) + + monkeypatch.setattr(model_routes, "CONFIG_DIR", str(tmp_path)) + + captured_kwargs: dict[str, Any] = {} + + async def fake_test_provider_connection(**kwargs: Any) -> dict[str, Any]: + captured_kwargs.update(kwargs) + return { + "network_ok": True, + "api_key_valid": True, + "latency_ms": 12.34, + "error": None, + "http_status": 200, + } + + monkeypatch.setattr(model_routes, "test_provider_connection", fake_test_provider_connection) + + result = await model_routes.test_provider_connection_by_name(provider_name="Gemini") + + assert result["network_ok"] is True + assert result["api_key_valid"] is True + assert captured_kwargs == { + "base_url": "https://generativelanguage.googleapis.com/v1beta", + "api_key": "valid-gemini-key", + "client_type": "gemini", + } \ No newline at end of file