diff --git a/pytests/webui/test_model_routes.py b/pytests/webui/test_model_routes.py new file mode 100644 index 00000000..0e05ad87 --- /dev/null +++ b/pytests/webui/test_model_routes.py @@ -0,0 +1,187 @@ +"""模型路由测试 + +验证 Gemini 提供商连接测试会使用查询参数传递 API Key, +并且不会回退到 OpenAI 兼容接口使用的 Bearer 认证方式。 +""" + +import importlib +import sys +from types import ModuleType +from typing import Any + +import pytest + + +def load_model_routes(monkeypatch: pytest.MonkeyPatch): + """在导入路由前 stub 配置与认证依赖模块,避免测试时触发真实初始化。""" + config_module = ModuleType("src.config.config") + config_module.__dict__["CONFIG_DIR"] = "." + monkeypatch.setitem(sys.modules, "src.config.config", config_module) + + dependencies_module = ModuleType("src.webui.dependencies") + + async def require_auth(): + return "test-token" + + dependencies_module.__dict__["require_auth"] = require_auth + monkeypatch.setitem(sys.modules, "src.webui.dependencies", dependencies_module) + + sys.modules.pop("src.webui.routers.model", None) + return importlib.import_module("src.webui.routers.model") + + +class FakeResponse: + """简化版 HTTP 响应对象。""" + + def __init__(self, status_code: int): + self.status_code = status_code + + +def build_async_client_factory( + responses: list[FakeResponse], + calls: list[dict[str, Any]], +): + """构造一个可记录请求参数的 AsyncClient 替身。""" + + response_iter = iter(responses) + + class FakeAsyncClient: + def __init__(self, *args: Any, **kwargs: Any): + self.args = args + self.kwargs = kwargs + + async def __aenter__(self) -> "FakeAsyncClient": + return self + + async def __aexit__(self, exc_type, exc, tb) -> bool: + return False + + async def get( + self, + url: str, + headers: dict[str, Any] | None = None, + params: dict[str, Any] | None = None, + ) -> FakeResponse: + calls.append( + { + "url": url, + "headers": headers or {}, + "params": params or {}, + } + ) + return next(response_iter) + + return FakeAsyncClient + + +@pytest.mark.asyncio +async def test_test_provider_connection_uses_query_api_key_for_gemini( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Gemini 连接测试应通过查询参数传递 API Key。""" + model_routes = load_model_routes(monkeypatch) + calls: list[dict[str, Any]] = [] + fake_client_class = build_async_client_factory( + responses=[FakeResponse(200), FakeResponse(200)], + calls=calls, + ) + monkeypatch.setattr(model_routes.httpx, "AsyncClient", fake_client_class) + + result = await model_routes.test_provider_connection( + base_url="https://generativelanguage.googleapis.com/v1beta", + api_key="valid-gemini-key", + client_type="gemini", + ) + + assert result["network_ok"] is True + assert result["api_key_valid"] is True + assert len(calls) == 2 + + network_call = calls[0] + validation_call = calls[1] + + assert network_call["url"] == "https://generativelanguage.googleapis.com/v1beta" + assert network_call["headers"] == {} + assert network_call["params"] == {} + + assert validation_call["url"] == "https://generativelanguage.googleapis.com/v1beta/models" + assert validation_call["params"] == {"key": "valid-gemini-key"} + assert validation_call["headers"] == {"Content-Type": "application/json"} + assert "Authorization" not in validation_call["headers"] + + +@pytest.mark.asyncio +async def test_test_provider_connection_uses_bearer_auth_for_openai_compatible( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """非 Gemini 提供商连接测试应继续使用 Bearer 认证。""" + model_routes = load_model_routes(monkeypatch) + calls: list[dict[str, Any]] = [] + fake_client_class = build_async_client_factory( + responses=[FakeResponse(200), FakeResponse(200)], + calls=calls, + ) + monkeypatch.setattr(model_routes.httpx, "AsyncClient", fake_client_class) + + result = await model_routes.test_provider_connection( + base_url="https://example.com/v1", + api_key="valid-openai-key", + client_type="openai", + ) + + assert result["network_ok"] is True + assert result["api_key_valid"] is True + assert len(calls) == 2 + + validation_call = calls[1] + + assert validation_call["url"] == "https://example.com/v1/models" + assert validation_call["params"] == {} + assert validation_call["headers"]["Content-Type"] == "application/json" + assert validation_call["headers"]["Authorization"] == "Bearer valid-openai-key" + + +@pytest.mark.asyncio +async def test_test_provider_connection_by_name_forwards_provider_client_type( + monkeypatch: pytest.MonkeyPatch, + tmp_path, +) -> None: + """按提供商名称测试连接时,应透传配置中的 client_type。""" + model_routes = load_model_routes(monkeypatch) + config_path = tmp_path / "model_config.toml" + config_path.write_text( + """ +[[api_providers]] +name = "Gemini" +base_url = "https://generativelanguage.googleapis.com/v1beta" +api_key = "valid-gemini-key" +client_type = "gemini" +""".strip(), + encoding="utf-8", + ) + + monkeypatch.setattr(model_routes, "CONFIG_DIR", str(tmp_path)) + + captured_kwargs: dict[str, Any] = {} + + async def fake_test_provider_connection(**kwargs: Any) -> dict[str, Any]: + captured_kwargs.update(kwargs) + return { + "network_ok": True, + "api_key_valid": True, + "latency_ms": 12.34, + "error": None, + "http_status": 200, + } + + monkeypatch.setattr(model_routes, "test_provider_connection", fake_test_provider_connection) + + result = await model_routes.test_provider_connection_by_name(provider_name="Gemini") + + assert result["network_ok"] is True + assert result["api_key_valid"] is True + assert captured_kwargs == { + "base_url": "https://generativelanguage.googleapis.com/v1beta", + "api_key": "valid-gemini-key", + "client_type": "gemini", + } \ No newline at end of file diff --git a/src/A_memorix/core/retrieval/dual_path.py b/src/A_memorix/core/retrieval/dual_path.py index 245bafea..996c02f8 100644 --- a/src/A_memorix/core/retrieval/dual_path.py +++ b/src/A_memorix/core/retrieval/dual_path.py @@ -16,6 +16,7 @@ from src.common.logger import get_logger from ..storage import VectorStore, GraphStore, MetadataStore from ..embedding import EmbeddingAPIAdapter from ..utils.matcher import AhoCorasick +from ..utils.metadata import coerce_metadata_dict from ..utils.time_parser import format_timestamp from .graph_relation_recall import GraphRelationRecallConfig, GraphRelationRecallService from .pagerank import PersonalizedPageRank, PageRankConfig @@ -482,7 +483,7 @@ class DualPathRetriever: score=float(item.score), result_type=item.result_type, source=item.source, - metadata=dict(item.metadata or {}), + metadata=coerce_metadata_dict(item.metadata), ) def _extract_graph_seed_entities(self, query: str, limit: int = 2) -> List[str]: @@ -762,7 +763,7 @@ class DualPathRetriever: existing = self._clone_retrieval_result(item) merged[item.hash_value] = existing else: - for key, value in dict(item.metadata or {}).items(): + for key, value in coerce_metadata_dict(item.metadata).items(): if key not in existing.metadata or existing.metadata.get(key) in (None, "", []): existing.metadata[key] = value source_sets.setdefault(item.hash_value, set()).add(str(item.source or "").strip() or "relation_search") diff --git a/src/A_memorix/core/runtime/sdk_memory_kernel.py b/src/A_memorix/core/runtime/sdk_memory_kernel.py index dfbbbd77..26ff503a 100644 --- a/src/A_memorix/core/runtime/sdk_memory_kernel.py +++ b/src/A_memorix/core/runtime/sdk_memory_kernel.py @@ -25,6 +25,7 @@ from ..utils.episode_retrieval_service import EpisodeRetrievalService from ..utils.episode_segmentation_service import EpisodeSegmentationService from ..utils.episode_service import EpisodeService from ..utils.hash import compute_hash, normalize_text +from ..utils.metadata import coerce_metadata_dict from ..utils.person_profile_service import PersonProfileService from ..utils.relation_write_service import RelationWriteService from ..utils.retrieval_tuning_manager import RetrievalTuningManager @@ -871,7 +872,7 @@ class SDKMemoryKernel: "detail": "chat_filtered", } - summary_meta = dict(metadata or {}) + summary_meta = coerce_metadata_dict(metadata) summary_meta.setdefault("kind", "chat_summary") if not str(text or "").strip() or bool(summary_meta.get("generate_from_chat", False)): result = await self.summarize_chat_stream( @@ -961,7 +962,7 @@ class SDKMemoryKernel: participant_tokens = self._tokens(participants) entity_tokens = self._merge_tokens(entities, person_tokens, participant_tokens) source = self._build_source(source_type, chat_id, person_tokens) - paragraph_meta = dict(metadata or {}) + paragraph_meta = coerce_metadata_dict(metadata) paragraph_meta.update( { "external_id": external_token, diff --git a/src/A_memorix/core/utils/metadata.py b/src/A_memorix/core/utils/metadata.py new file mode 100644 index 00000000..5a1dafc1 --- /dev/null +++ b/src/A_memorix/core/utils/metadata.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any, Dict + + +def coerce_metadata_dict(value: Any) -> Dict[str, Any]: + """返回字典,如果输入值不是字典则返回空字典。""" + if isinstance(value, Mapping): + return dict(value) + return {} diff --git a/src/A_memorix/core/utils/person_profile_service.py b/src/A_memorix/core/utils/person_profile_service.py index 081eaa66..eec531e0 100644 --- a/src/A_memorix/core/utils/person_profile_service.py +++ b/src/A_memorix/core/utils/person_profile_service.py @@ -27,6 +27,7 @@ from ..retrieval import ( GraphRelationRecallConfig, ) from ..storage import MetadataStore, GraphStore, VectorStore +from .metadata import coerce_metadata_dict logger = get_logger("A_Memorix.PersonProfileService") @@ -334,7 +335,7 @@ class PersonProfileService: if not pid: return False - metadata = self._metadata_dict(relation.get("metadata")) + metadata = coerce_metadata_dict(relation.get("metadata")) if str(metadata.get("person_id", "") or "").strip() == pid: return True if pid in self._list_tokens(metadata.get("person_ids")): @@ -350,7 +351,7 @@ class PersonProfileService: payload = { "hash": source_paragraph, "source": str(paragraph.get("source", "") or ""), - "metadata": self._metadata_dict(paragraph.get("metadata")), + "metadata": coerce_metadata_dict(paragraph.get("metadata")), } return self._is_evidence_bound_to_person(payload, person_id=pid) @@ -385,15 +386,11 @@ class PersonProfileService: "score": 1.1, "content": content[:220], "source": str(row.get("source", "") or source), - "metadata": dict(row.get("metadata", {}) or {}), + "metadata": coerce_metadata_dict(row.get("metadata")), } ) return self._filter_stale_paragraph_evidence(evidence) - @staticmethod - def _metadata_dict(value: Any) -> Dict[str, Any]: - return dict(value) if isinstance(value, dict) else {} - @staticmethod def _list_tokens(value: Any) -> List[str]: if value is None: @@ -414,7 +411,7 @@ class PersonProfileService: if not pid: return False - metadata = self._metadata_dict(item.get("metadata")) + metadata = coerce_metadata_dict(item.get("metadata")) source = str(item.get("source", "") or metadata.get("source", "") or "").strip() if source == f"person_fact:{pid}": return True @@ -440,15 +437,15 @@ class PersonProfileService: paragraph_hash: str, metadata: Dict[str, Any], ) -> Tuple[Dict[str, Any], str]: - merged = self._metadata_dict(metadata) + merged = coerce_metadata_dict(metadata) source = str(merged.get("source", "") or "").strip() try: paragraph = self.metadata_store.get_paragraph(paragraph_hash) except Exception: paragraph = None if isinstance(paragraph, dict): - paragraph_metadata = paragraph.get("metadata", {}) or {} - if isinstance(paragraph_metadata, dict): + paragraph_metadata = coerce_metadata_dict(paragraph.get("metadata")) + if paragraph_metadata: merged = {**paragraph_metadata, **merged} source = source or str(paragraph.get("source", "") or "").strip() source_type = str(merged.get("source_type", "") or "").strip() or self._source_type_from_source(source) @@ -538,7 +535,7 @@ class PersonProfileService: "score": 0.0, "content": str(para.get("content", ""))[:180], "source": str(para.get("source", "") or ""), - "metadata": self._metadata_dict(para.get("metadata")), + "metadata": coerce_metadata_dict(para.get("metadata")), } ) if not self._is_evidence_bound_to_person(fallback[-1], person_id=person_id): @@ -562,18 +559,18 @@ class PersonProfileService: logger.warning(f"向量证据召回失败: alias={alias}, err={e}") continue for item in results: - h = str(getattr(item, "hash_value", "") or "") + h = str(item.hash_value or "") if not h or h in seen_hash: continue metadata, source = self._enrich_paragraph_evidence_metadata( h, - self._metadata_dict(getattr(item, "metadata", {})), + coerce_metadata_dict(item.metadata), ) payload = { "hash": h, - "type": str(getattr(item, "result_type", "")), - "score": float(getattr(item, "score", 0.0) or 0.0), - "content": str(getattr(item, "content", "") or "")[:220], + "type": str(item.result_type), + "score": float(item.score or 0.0), + "content": str(item.content or "")[:220], "source": source, "metadata": metadata, } diff --git a/src/A_memorix/core/utils/search_execution_service.py b/src/A_memorix/core/utils/search_execution_service.py index ace051e9..f05c4820 100644 --- a/src/A_memorix/core/utils/search_execution_service.py +++ b/src/A_memorix/core/utils/search_execution_service.py @@ -14,7 +14,8 @@ from typing import Any, Dict, List, Optional, Tuple from src.common.logger import get_logger -from ..retrieval import TemporalQueryOptions +from ..retrieval import RetrievalResult, TemporalQueryOptions +from .metadata import coerce_metadata_dict from .search_postprocess import ( apply_safe_content_dedup, maybe_apply_smart_path_fallback, @@ -286,8 +287,11 @@ class SearchExecutionService: ) async def _executor() -> Dict[str, Any]: - original_ppr = bool(getattr(retriever.config, "enable_ppr", True)) - setattr(retriever.config, "enable_ppr", bool(request.enable_ppr)) + retriever_config = getattr(retriever, "config", None) + has_runtime_ppr_switch = retriever_config is not None and hasattr(retriever_config, "enable_ppr") + original_ppr = bool(retriever_config.enable_ppr) if has_runtime_ppr_switch else None + if has_runtime_ppr_switch: + retriever_config.enable_ppr = bool(request.enable_ppr) started_at = time.time() try: retrieved = await retriever.retrieve( @@ -380,7 +384,8 @@ class SearchExecutionService: elapsed_ms = (time.time() - started_at) * 1000.0 return {"results": retrieved, "elapsed_ms": elapsed_ms} finally: - setattr(retriever.config, "enable_ppr", original_ppr) + if has_runtime_ppr_switch: + retriever_config.enable_ppr = bool(original_ppr) dedup_hit = False try: @@ -421,18 +426,18 @@ class SearchExecutionService: ) @staticmethod - def to_serializable_results(results: List[Any]) -> List[Dict[str, Any]]: + def to_serializable_results(results: List[RetrievalResult]) -> List[Dict[str, Any]]: serialized: List[Dict[str, Any]] = [] for item in results: - metadata = dict(getattr(item, "metadata", {}) or {}) + metadata = coerce_metadata_dict(item.metadata) if "time_meta" not in metadata: metadata["time_meta"] = {} serialized.append( { - "hash": getattr(item, "hash_value", ""), - "type": getattr(item, "result_type", ""), - "score": float(getattr(item, "score", 0.0)), - "content": getattr(item, "content", ""), + "hash": item.hash_value, + "type": item.result_type, + "score": float(item.score), + "content": item.content, "metadata": metadata, } ) diff --git a/src/webui/routers/model.py b/src/webui/routers/model.py index 27323835..cf2ad1b0 100644 --- a/src/webui/routers/model.py +++ b/src/webui/routers/model.py @@ -296,6 +296,7 @@ async def get_models_by_url( async def test_provider_connection( base_url: str = Query(..., description="提供商的基础 URL"), api_key: Optional[str] = Query(None, description="API Key(可选,用于验证 Key 有效性)"), + client_type: str = Query("openai", description="客户端类型 (openai | gemini)"), ): """ 测试提供商连接状态 @@ -359,13 +360,19 @@ async def test_provider_connection( try: start_time = time.time() async with httpx.AsyncClient(timeout=15.0, follow_redirects=True) as client: - headers = { - "Authorization": f"Bearer {api_key}", - "Content-Type": "application/json", - } + headers = {"Content-Type": "application/json"} + params = {} + + if client_type == "gemini": + # Gemini 使用 URL 参数传递 API Key + params["key"] = api_key + else: + # OpenAI 兼容格式使用 Authorization 头 + headers["Authorization"] = f"Bearer {api_key}" + # 尝试获取模型列表 models_url = f"{base_url}/models" - response = await client.get(models_url, headers=headers) + response = await client.get(models_url, headers=headers, params=params) if response.status_code == 200: result["api_key_valid"] = True @@ -408,9 +415,14 @@ async def test_provider_connection_by_name( base_url = provider.get("base_url", "") api_key = provider.get("api_key", "") + client_type = provider.get("client_type", "openai") if not base_url: raise HTTPException(status_code=400, detail="提供商配置缺少 base_url") # 调用测试接口 - return await test_provider_connection(base_url=base_url, api_key=api_key or None) + return await test_provider_connection( + base_url=base_url, + api_key=api_key if api_key else None, + client_type=client_type, + )