From 78415e89c1076a9319c12c1f7d3be22390fa7a74 Mon Sep 17 00:00:00 2001 From: anderwer Date: Thu, 26 Mar 2026 08:58:49 +0800 Subject: [PATCH] test(webui): cover gemini provider connection auth --- pytests/webui/test_model_routes.py | 187 +++++++++++++++++++++++++++++ 1 file changed, 187 insertions(+) create mode 100644 pytests/webui/test_model_routes.py diff --git a/pytests/webui/test_model_routes.py b/pytests/webui/test_model_routes.py new file mode 100644 index 00000000..0e05ad87 --- /dev/null +++ b/pytests/webui/test_model_routes.py @@ -0,0 +1,187 @@ +"""模型路由测试 + +验证 Gemini 提供商连接测试会使用查询参数传递 API Key, +并且不会回退到 OpenAI 兼容接口使用的 Bearer 认证方式。 +""" + +import importlib +import sys +from types import ModuleType +from typing import Any + +import pytest + + +def load_model_routes(monkeypatch: pytest.MonkeyPatch): + """在导入路由前 stub 配置与认证依赖模块,避免测试时触发真实初始化。""" + config_module = ModuleType("src.config.config") + config_module.__dict__["CONFIG_DIR"] = "." + monkeypatch.setitem(sys.modules, "src.config.config", config_module) + + dependencies_module = ModuleType("src.webui.dependencies") + + async def require_auth(): + return "test-token" + + dependencies_module.__dict__["require_auth"] = require_auth + monkeypatch.setitem(sys.modules, "src.webui.dependencies", dependencies_module) + + sys.modules.pop("src.webui.routers.model", None) + return importlib.import_module("src.webui.routers.model") + + +class FakeResponse: + """简化版 HTTP 响应对象。""" + + def __init__(self, status_code: int): + self.status_code = status_code + + +def build_async_client_factory( + responses: list[FakeResponse], + calls: list[dict[str, Any]], +): + """构造一个可记录请求参数的 AsyncClient 替身。""" + + response_iter = iter(responses) + + class FakeAsyncClient: + def __init__(self, *args: Any, **kwargs: Any): + self.args = args + self.kwargs = kwargs + + async def __aenter__(self) -> "FakeAsyncClient": + return self + + async def __aexit__(self, exc_type, exc, tb) -> bool: + return False + + async def get( + self, + url: str, + headers: dict[str, Any] | None = None, + params: dict[str, Any] | None = None, + ) -> FakeResponse: + calls.append( + { + "url": url, + "headers": headers or {}, + "params": params or {}, + } + ) + return next(response_iter) + + return FakeAsyncClient + + +@pytest.mark.asyncio +async def test_test_provider_connection_uses_query_api_key_for_gemini( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Gemini 连接测试应通过查询参数传递 API Key。""" + model_routes = load_model_routes(monkeypatch) + calls: list[dict[str, Any]] = [] + fake_client_class = build_async_client_factory( + responses=[FakeResponse(200), FakeResponse(200)], + calls=calls, + ) + monkeypatch.setattr(model_routes.httpx, "AsyncClient", fake_client_class) + + result = await model_routes.test_provider_connection( + base_url="https://generativelanguage.googleapis.com/v1beta", + api_key="valid-gemini-key", + client_type="gemini", + ) + + assert result["network_ok"] is True + assert result["api_key_valid"] is True + assert len(calls) == 2 + + network_call = calls[0] + validation_call = calls[1] + + assert network_call["url"] == "https://generativelanguage.googleapis.com/v1beta" + assert network_call["headers"] == {} + assert network_call["params"] == {} + + assert validation_call["url"] == "https://generativelanguage.googleapis.com/v1beta/models" + assert validation_call["params"] == {"key": "valid-gemini-key"} + assert validation_call["headers"] == {"Content-Type": "application/json"} + assert "Authorization" not in validation_call["headers"] + + +@pytest.mark.asyncio +async def test_test_provider_connection_uses_bearer_auth_for_openai_compatible( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """非 Gemini 提供商连接测试应继续使用 Bearer 认证。""" + model_routes = load_model_routes(monkeypatch) + calls: list[dict[str, Any]] = [] + fake_client_class = build_async_client_factory( + responses=[FakeResponse(200), FakeResponse(200)], + calls=calls, + ) + monkeypatch.setattr(model_routes.httpx, "AsyncClient", fake_client_class) + + result = await model_routes.test_provider_connection( + base_url="https://example.com/v1", + api_key="valid-openai-key", + client_type="openai", + ) + + assert result["network_ok"] is True + assert result["api_key_valid"] is True + assert len(calls) == 2 + + validation_call = calls[1] + + assert validation_call["url"] == "https://example.com/v1/models" + assert validation_call["params"] == {} + assert validation_call["headers"]["Content-Type"] == "application/json" + assert validation_call["headers"]["Authorization"] == "Bearer valid-openai-key" + + +@pytest.mark.asyncio +async def test_test_provider_connection_by_name_forwards_provider_client_type( + monkeypatch: pytest.MonkeyPatch, + tmp_path, +) -> None: + """按提供商名称测试连接时,应透传配置中的 client_type。""" + model_routes = load_model_routes(monkeypatch) + config_path = tmp_path / "model_config.toml" + config_path.write_text( + """ +[[api_providers]] +name = "Gemini" +base_url = "https://generativelanguage.googleapis.com/v1beta" +api_key = "valid-gemini-key" +client_type = "gemini" +""".strip(), + encoding="utf-8", + ) + + monkeypatch.setattr(model_routes, "CONFIG_DIR", str(tmp_path)) + + captured_kwargs: dict[str, Any] = {} + + async def fake_test_provider_connection(**kwargs: Any) -> dict[str, Any]: + captured_kwargs.update(kwargs) + return { + "network_ok": True, + "api_key_valid": True, + "latency_ms": 12.34, + "error": None, + "http_status": 200, + } + + monkeypatch.setattr(model_routes, "test_provider_connection", fake_test_provider_connection) + + result = await model_routes.test_provider_connection_by_name(provider_name="Gemini") + + assert result["network_ok"] is True + assert result["api_key_valid"] is True + assert captured_kwargs == { + "base_url": "https://generativelanguage.googleapis.com/v1beta", + "api_key": "valid-gemini-key", + "client_type": "gemini", + } \ No newline at end of file