This commit is contained in:
SengokuCola
2026-05-09 02:33:03 +08:00
7 changed files with 251 additions and 37 deletions

View File

@@ -0,0 +1,187 @@
"""模型路由测试
验证 Gemini 提供商连接测试会使用查询参数传递 API Key
并且不会回退到 OpenAI 兼容接口使用的 Bearer 认证方式。
"""
import importlib
import sys
from types import ModuleType
from typing import Any
import pytest
def load_model_routes(monkeypatch: pytest.MonkeyPatch):
    """Import the model router with config/auth dependencies stubbed out.

    Stub modules are registered in ``sys.modules`` before the import so the
    router does not trigger real configuration loading or authentication
    during tests.
    """
    fake_config = ModuleType("src.config.config")
    fake_config.CONFIG_DIR = "."
    monkeypatch.setitem(sys.modules, "src.config.config", fake_config)

    fake_dependencies = ModuleType("src.webui.dependencies")

    async def require_auth():
        return "test-token"

    fake_dependencies.require_auth = require_auth
    monkeypatch.setitem(sys.modules, "src.webui.dependencies", fake_dependencies)

    # Drop any cached copy so the stubbed dependencies take effect on import.
    sys.modules.pop("src.webui.routers.model", None)
    return importlib.import_module("src.webui.routers.model")
class FakeResponse:
    """Minimal stand-in for an HTTP response, carrying only a status code."""

    def __init__(self, status_code: int) -> None:
        # The routes under test only ever inspect ``status_code``.
        self.status_code = status_code
def build_async_client_factory(
    responses: list[FakeResponse],
    calls: list[dict[str, Any]],
):
    """Build an ``httpx.AsyncClient`` stand-in that records every GET call.

    Each ``get`` appends a ``{"url", "headers", "params"}`` record to *calls*
    and returns the next entry from *responses*.
    """
    pending = iter(responses)

    class FakeAsyncClient:
        def __init__(self, *args: Any, **kwargs: Any):
            # Constructor arguments (timeout, redirects, ...) are kept but unused.
            self.args = args
            self.kwargs = kwargs

        async def __aenter__(self) -> "FakeAsyncClient":
            return self

        async def __aexit__(self, exc_type, exc, tb) -> bool:
            # Never suppress exceptions raised inside the ``async with`` body.
            return False

        async def get(
            self,
            url: str,
            headers: dict[str, Any] | None = None,
            params: dict[str, Any] | None = None,
        ) -> FakeResponse:
            record = {"url": url, "headers": headers or {}, "params": params or {}}
            calls.append(record)
            return next(pending)

    return FakeAsyncClient
@pytest.mark.asyncio
async def test_test_provider_connection_uses_query_api_key_for_gemini(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Gemini connection tests must pass the API key via a query parameter."""
    routes = load_model_routes(monkeypatch)

    recorded: list[dict[str, Any]] = []
    monkeypatch.setattr(
        routes.httpx,
        "AsyncClient",
        build_async_client_factory(
            responses=[FakeResponse(200), FakeResponse(200)],
            calls=recorded,
        ),
    )

    outcome = await routes.test_provider_connection(
        base_url="https://generativelanguage.googleapis.com/v1beta",
        api_key="valid-gemini-key",
        client_type="gemini",
    )

    assert outcome["network_ok"] is True
    assert outcome["api_key_valid"] is True
    assert len(recorded) == 2

    # First call probes connectivity; second validates the key against /models.
    probe, validation = recorded
    assert probe["url"] == "https://generativelanguage.googleapis.com/v1beta"
    assert probe["headers"] == {}
    assert probe["params"] == {}
    assert validation["url"] == "https://generativelanguage.googleapis.com/v1beta/models"
    assert validation["params"] == {"key": "valid-gemini-key"}
    assert validation["headers"] == {"Content-Type": "application/json"}
    assert "Authorization" not in validation["headers"]
@pytest.mark.asyncio
async def test_test_provider_connection_uses_bearer_auth_for_openai_compatible(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Non-Gemini providers must keep using Bearer authentication."""
    routes = load_model_routes(monkeypatch)

    recorded: list[dict[str, Any]] = []
    monkeypatch.setattr(
        routes.httpx,
        "AsyncClient",
        build_async_client_factory(
            responses=[FakeResponse(200), FakeResponse(200)],
            calls=recorded,
        ),
    )

    outcome = await routes.test_provider_connection(
        base_url="https://example.com/v1",
        api_key="valid-openai-key",
        client_type="openai",
    )

    assert outcome["network_ok"] is True
    assert outcome["api_key_valid"] is True
    assert len(recorded) == 2

    # Only the key-validation call (the second request) carries auth details.
    validation = recorded[1]
    assert validation["url"] == "https://example.com/v1/models"
    assert validation["params"] == {}
    assert validation["headers"]["Content-Type"] == "application/json"
    assert validation["headers"]["Authorization"] == "Bearer valid-openai-key"
@pytest.mark.asyncio
async def test_test_provider_connection_by_name_forwards_provider_client_type(
    monkeypatch: pytest.MonkeyPatch,
    tmp_path,
) -> None:
    """Testing by provider name should forward the configured client_type."""
    routes = load_model_routes(monkeypatch)

    config_path = tmp_path / "model_config.toml"
    config_path.write_text(
        """
[[api_providers]]
name = "Gemini"
base_url = "https://generativelanguage.googleapis.com/v1beta"
api_key = "valid-gemini-key"
client_type = "gemini"
""".strip(),
        encoding="utf-8",
    )
    monkeypatch.setattr(routes, "CONFIG_DIR", str(tmp_path))

    forwarded: dict[str, Any] = {}

    async def record_connection_test(**kwargs: Any) -> dict[str, Any]:
        forwarded.update(kwargs)
        return {
            "network_ok": True,
            "api_key_valid": True,
            "latency_ms": 12.34,
            "error": None,
            "http_status": 200,
        }

    monkeypatch.setattr(routes, "test_provider_connection", record_connection_test)

    result = await routes.test_provider_connection_by_name(provider_name="Gemini")

    assert result["network_ok"] is True
    assert result["api_key_valid"] is True
    assert forwarded == {
        "base_url": "https://generativelanguage.googleapis.com/v1beta",
        "api_key": "valid-gemini-key",
        "client_type": "gemini",
    }

View File

@@ -16,6 +16,7 @@ from src.common.logger import get_logger
from ..storage import VectorStore, GraphStore, MetadataStore
from ..embedding import EmbeddingAPIAdapter
from ..utils.matcher import AhoCorasick
from ..utils.metadata import coerce_metadata_dict
from ..utils.time_parser import format_timestamp
from .graph_relation_recall import GraphRelationRecallConfig, GraphRelationRecallService
from .pagerank import PersonalizedPageRank, PageRankConfig
@@ -482,7 +483,7 @@ class DualPathRetriever:
score=float(item.score),
result_type=item.result_type,
source=item.source,
metadata=dict(item.metadata or {}),
metadata=coerce_metadata_dict(item.metadata),
)
def _extract_graph_seed_entities(self, query: str, limit: int = 2) -> List[str]:
@@ -762,7 +763,7 @@ class DualPathRetriever:
existing = self._clone_retrieval_result(item)
merged[item.hash_value] = existing
else:
for key, value in dict(item.metadata or {}).items():
for key, value in coerce_metadata_dict(item.metadata).items():
if key not in existing.metadata or existing.metadata.get(key) in (None, "", []):
existing.metadata[key] = value
source_sets.setdefault(item.hash_value, set()).add(str(item.source or "").strip() or "relation_search")

View File

@@ -25,6 +25,7 @@ from ..utils.episode_retrieval_service import EpisodeRetrievalService
from ..utils.episode_segmentation_service import EpisodeSegmentationService
from ..utils.episode_service import EpisodeService
from ..utils.hash import compute_hash, normalize_text
from ..utils.metadata import coerce_metadata_dict
from ..utils.person_profile_service import PersonProfileService
from ..utils.relation_write_service import RelationWriteService
from ..utils.retrieval_tuning_manager import RetrievalTuningManager
@@ -871,7 +872,7 @@ class SDKMemoryKernel:
"detail": "chat_filtered",
}
summary_meta = dict(metadata or {})
summary_meta = coerce_metadata_dict(metadata)
summary_meta.setdefault("kind", "chat_summary")
if not str(text or "").strip() or bool(summary_meta.get("generate_from_chat", False)):
result = await self.summarize_chat_stream(
@@ -961,7 +962,7 @@ class SDKMemoryKernel:
participant_tokens = self._tokens(participants)
entity_tokens = self._merge_tokens(entities, person_tokens, participant_tokens)
source = self._build_source(source_type, chat_id, person_tokens)
paragraph_meta = dict(metadata or {})
paragraph_meta = coerce_metadata_dict(metadata)
paragraph_meta.update(
{
"external_id": external_token,

View File

@@ -0,0 +1,11 @@
from __future__ import annotations
from collections.abc import Mapping
from typing import Any, Dict
def coerce_metadata_dict(value: Any) -> Dict[str, Any]:
    """Normalize *value* into a plain dict.

    Mapping inputs are shallow-copied into a new ``dict``; any other input
    (including ``None``) yields a fresh empty dict.
    """
    return dict(value) if isinstance(value, Mapping) else {}

View File

@@ -27,6 +27,7 @@ from ..retrieval import (
GraphRelationRecallConfig,
)
from ..storage import MetadataStore, GraphStore, VectorStore
from .metadata import coerce_metadata_dict
logger = get_logger("A_Memorix.PersonProfileService")
@@ -334,7 +335,7 @@ class PersonProfileService:
if not pid:
return False
metadata = self._metadata_dict(relation.get("metadata"))
metadata = coerce_metadata_dict(relation.get("metadata"))
if str(metadata.get("person_id", "") or "").strip() == pid:
return True
if pid in self._list_tokens(metadata.get("person_ids")):
@@ -350,7 +351,7 @@ class PersonProfileService:
payload = {
"hash": source_paragraph,
"source": str(paragraph.get("source", "") or ""),
"metadata": self._metadata_dict(paragraph.get("metadata")),
"metadata": coerce_metadata_dict(paragraph.get("metadata")),
}
return self._is_evidence_bound_to_person(payload, person_id=pid)
@@ -385,15 +386,11 @@ class PersonProfileService:
"score": 1.1,
"content": content[:220],
"source": str(row.get("source", "") or source),
"metadata": dict(row.get("metadata", {}) or {}),
"metadata": coerce_metadata_dict(row.get("metadata")),
}
)
return self._filter_stale_paragraph_evidence(evidence)
@staticmethod
def _metadata_dict(value: Any) -> Dict[str, Any]:
return dict(value) if isinstance(value, dict) else {}
@staticmethod
def _list_tokens(value: Any) -> List[str]:
if value is None:
@@ -414,7 +411,7 @@ class PersonProfileService:
if not pid:
return False
metadata = self._metadata_dict(item.get("metadata"))
metadata = coerce_metadata_dict(item.get("metadata"))
source = str(item.get("source", "") or metadata.get("source", "") or "").strip()
if source == f"person_fact:{pid}":
return True
@@ -440,15 +437,15 @@ class PersonProfileService:
paragraph_hash: str,
metadata: Dict[str, Any],
) -> Tuple[Dict[str, Any], str]:
merged = self._metadata_dict(metadata)
merged = coerce_metadata_dict(metadata)
source = str(merged.get("source", "") or "").strip()
try:
paragraph = self.metadata_store.get_paragraph(paragraph_hash)
except Exception:
paragraph = None
if isinstance(paragraph, dict):
paragraph_metadata = paragraph.get("metadata", {}) or {}
if isinstance(paragraph_metadata, dict):
paragraph_metadata = coerce_metadata_dict(paragraph.get("metadata"))
if paragraph_metadata:
merged = {**paragraph_metadata, **merged}
source = source or str(paragraph.get("source", "") or "").strip()
source_type = str(merged.get("source_type", "") or "").strip() or self._source_type_from_source(source)
@@ -538,7 +535,7 @@ class PersonProfileService:
"score": 0.0,
"content": str(para.get("content", ""))[:180],
"source": str(para.get("source", "") or ""),
"metadata": self._metadata_dict(para.get("metadata")),
"metadata": coerce_metadata_dict(para.get("metadata")),
}
)
if not self._is_evidence_bound_to_person(fallback[-1], person_id=person_id):
@@ -562,18 +559,18 @@ class PersonProfileService:
logger.warning(f"向量证据召回失败: alias={alias}, err={e}")
continue
for item in results:
h = str(getattr(item, "hash_value", "") or "")
h = str(item.hash_value or "")
if not h or h in seen_hash:
continue
metadata, source = self._enrich_paragraph_evidence_metadata(
h,
self._metadata_dict(getattr(item, "metadata", {})),
coerce_metadata_dict(item.metadata),
)
payload = {
"hash": h,
"type": str(getattr(item, "result_type", "")),
"score": float(getattr(item, "score", 0.0) or 0.0),
"content": str(getattr(item, "content", "") or "")[:220],
"type": str(item.result_type),
"score": float(item.score or 0.0),
"content": str(item.content or "")[:220],
"source": source,
"metadata": metadata,
}

View File

@@ -14,7 +14,8 @@ from typing import Any, Dict, List, Optional, Tuple
from src.common.logger import get_logger
from ..retrieval import TemporalQueryOptions
from ..retrieval import RetrievalResult, TemporalQueryOptions
from .metadata import coerce_metadata_dict
from .search_postprocess import (
apply_safe_content_dedup,
maybe_apply_smart_path_fallback,
@@ -286,8 +287,11 @@ class SearchExecutionService:
)
async def _executor() -> Dict[str, Any]:
original_ppr = bool(getattr(retriever.config, "enable_ppr", True))
setattr(retriever.config, "enable_ppr", bool(request.enable_ppr))
retriever_config = getattr(retriever, "config", None)
has_runtime_ppr_switch = retriever_config is not None and hasattr(retriever_config, "enable_ppr")
original_ppr = bool(retriever_config.enable_ppr) if has_runtime_ppr_switch else None
if has_runtime_ppr_switch:
retriever_config.enable_ppr = bool(request.enable_ppr)
started_at = time.time()
try:
retrieved = await retriever.retrieve(
@@ -380,7 +384,8 @@ class SearchExecutionService:
elapsed_ms = (time.time() - started_at) * 1000.0
return {"results": retrieved, "elapsed_ms": elapsed_ms}
finally:
setattr(retriever.config, "enable_ppr", original_ppr)
if has_runtime_ppr_switch:
retriever_config.enable_ppr = bool(original_ppr)
dedup_hit = False
try:
@@ -421,18 +426,18 @@ class SearchExecutionService:
)
@staticmethod
def to_serializable_results(results: List[Any]) -> List[Dict[str, Any]]:
def to_serializable_results(results: List[RetrievalResult]) -> List[Dict[str, Any]]:
serialized: List[Dict[str, Any]] = []
for item in results:
metadata = dict(getattr(item, "metadata", {}) or {})
metadata = coerce_metadata_dict(item.metadata)
if "time_meta" not in metadata:
metadata["time_meta"] = {}
serialized.append(
{
"hash": getattr(item, "hash_value", ""),
"type": getattr(item, "result_type", ""),
"score": float(getattr(item, "score", 0.0)),
"content": getattr(item, "content", ""),
"hash": item.hash_value,
"type": item.result_type,
"score": float(item.score),
"content": item.content,
"metadata": metadata,
}
)

View File

@@ -296,6 +296,7 @@ async def get_models_by_url(
async def test_provider_connection(
base_url: str = Query(..., description="提供商的基础 URL"),
api_key: Optional[str] = Query(None, description="API Key可选用于验证 Key 有效性)"),
client_type: str = Query("openai", description="客户端类型 (openai | gemini)"),
):
"""
测试提供商连接状态
@@ -359,13 +360,19 @@ async def test_provider_connection(
try:
start_time = time.time()
async with httpx.AsyncClient(timeout=15.0, follow_redirects=True) as client:
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
}
headers = {"Content-Type": "application/json"}
params = {}
if client_type == "gemini":
# Gemini 使用 URL 参数传递 API Key
params["key"] = api_key
else:
# OpenAI 兼容格式使用 Authorization 头
headers["Authorization"] = f"Bearer {api_key}"
# 尝试获取模型列表
models_url = f"{base_url}/models"
response = await client.get(models_url, headers=headers)
response = await client.get(models_url, headers=headers, params=params)
if response.status_code == 200:
result["api_key_valid"] = True
@@ -408,9 +415,14 @@ async def test_provider_connection_by_name(
base_url = provider.get("base_url", "")
api_key = provider.get("api_key", "")
client_type = provider.get("client_type", "openai")
if not base_url:
raise HTTPException(status_code=400, detail="提供商配置缺少 base_url")
# 调用测试接口
return await test_provider_connection(base_url=base_url, api_key=api_key or None)
return await test_provider_connection(
base_url=base_url,
api_key=api_key if api_key else None,
client_type=client_type,
)