SengokuCola
2026-05-09 02:33:03 +08:00
7 changed files with 251 additions and 37 deletions

View File

@@ -0,0 +1,187 @@
"""模型路由测试
验证 Gemini 提供商连接测试会使用查询参数传递 API Key
并且不会回退到 OpenAI 兼容接口使用的 Bearer 认证方式。
"""
import importlib
import sys
from types import ModuleType
from typing import Any
import pytest
def load_model_routes(monkeypatch: pytest.MonkeyPatch):
"""在导入路由前 stub 配置与认证依赖模块,避免测试时触发真实初始化。"""
config_module = ModuleType("src.config.config")
config_module.__dict__["CONFIG_DIR"] = "."
monkeypatch.setitem(sys.modules, "src.config.config", config_module)
dependencies_module = ModuleType("src.webui.dependencies")
async def require_auth():
return "test-token"
dependencies_module.__dict__["require_auth"] = require_auth
monkeypatch.setitem(sys.modules, "src.webui.dependencies", dependencies_module)
sys.modules.pop("src.webui.routers.model", None)
return importlib.import_module("src.webui.routers.model")
class FakeResponse:
"""简化版 HTTP 响应对象。"""
def __init__(self, status_code: int):
self.status_code = status_code
def build_async_client_factory(
responses: list[FakeResponse],
calls: list[dict[str, Any]],
):
"""构造一个可记录请求参数的 AsyncClient 替身。"""
response_iter = iter(responses)
class FakeAsyncClient:
def __init__(self, *args: Any, **kwargs: Any):
self.args = args
self.kwargs = kwargs
async def __aenter__(self) -> "FakeAsyncClient":
return self
async def __aexit__(self, exc_type, exc, tb) -> bool:
return False
async def get(
self,
url: str,
headers: dict[str, Any] | None = None,
params: dict[str, Any] | None = None,
) -> FakeResponse:
calls.append(
{
"url": url,
"headers": headers or {},
"params": params or {},
}
)
return next(response_iter)
return FakeAsyncClient
@pytest.mark.asyncio
async def test_test_provider_connection_uses_query_api_key_for_gemini(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Gemini 连接测试应通过查询参数传递 API Key。"""
model_routes = load_model_routes(monkeypatch)
calls: list[dict[str, Any]] = []
fake_client_class = build_async_client_factory(
responses=[FakeResponse(200), FakeResponse(200)],
calls=calls,
)
monkeypatch.setattr(model_routes.httpx, "AsyncClient", fake_client_class)
result = await model_routes.test_provider_connection(
base_url="https://generativelanguage.googleapis.com/v1beta",
api_key="valid-gemini-key",
client_type="gemini",
)
assert result["network_ok"] is True
assert result["api_key_valid"] is True
assert len(calls) == 2
network_call = calls[0]
validation_call = calls[1]
assert network_call["url"] == "https://generativelanguage.googleapis.com/v1beta"
assert network_call["headers"] == {}
assert network_call["params"] == {}
assert validation_call["url"] == "https://generativelanguage.googleapis.com/v1beta/models"
assert validation_call["params"] == {"key": "valid-gemini-key"}
assert validation_call["headers"] == {"Content-Type": "application/json"}
assert "Authorization" not in validation_call["headers"]
@pytest.mark.asyncio
async def test_test_provider_connection_uses_bearer_auth_for_openai_compatible(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""非 Gemini 提供商连接测试应继续使用 Bearer 认证。"""
model_routes = load_model_routes(monkeypatch)
calls: list[dict[str, Any]] = []
fake_client_class = build_async_client_factory(
responses=[FakeResponse(200), FakeResponse(200)],
calls=calls,
)
monkeypatch.setattr(model_routes.httpx, "AsyncClient", fake_client_class)
result = await model_routes.test_provider_connection(
base_url="https://example.com/v1",
api_key="valid-openai-key",
client_type="openai",
)
assert result["network_ok"] is True
assert result["api_key_valid"] is True
assert len(calls) == 2
validation_call = calls[1]
assert validation_call["url"] == "https://example.com/v1/models"
assert validation_call["params"] == {}
assert validation_call["headers"]["Content-Type"] == "application/json"
assert validation_call["headers"]["Authorization"] == "Bearer valid-openai-key"
@pytest.mark.asyncio
async def test_test_provider_connection_by_name_forwards_provider_client_type(
monkeypatch: pytest.MonkeyPatch,
tmp_path,
) -> None:
"""按提供商名称测试连接时,应透传配置中的 client_type。"""
model_routes = load_model_routes(monkeypatch)
config_path = tmp_path / "model_config.toml"
config_path.write_text(
"""
[[api_providers]]
name = "Gemini"
base_url = "https://generativelanguage.googleapis.com/v1beta"
api_key = "valid-gemini-key"
client_type = "gemini"
""".strip(),
encoding="utf-8",
)
monkeypatch.setattr(model_routes, "CONFIG_DIR", str(tmp_path))
captured_kwargs: dict[str, Any] = {}
async def fake_test_provider_connection(**kwargs: Any) -> dict[str, Any]:
captured_kwargs.update(kwargs)
return {
"network_ok": True,
"api_key_valid": True,
"latency_ms": 12.34,
"error": None,
"http_status": 200,
}
monkeypatch.setattr(model_routes, "test_provider_connection", fake_test_provider_connection)
result = await model_routes.test_provider_connection_by_name(provider_name="Gemini")
assert result["network_ok"] is True
assert result["api_key_valid"] is True
assert captured_kwargs == {
"base_url": "https://generativelanguage.googleapis.com/v1beta",
"api_key": "valid-gemini-key",
"client_type": "gemini",
}
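The fake client above mimics just enough of httpx.AsyncClient for the route's usage pattern (async context manager plus a recorded get()). Below is a standalone sketch of how the stand-in behaves, using an invented URL and key value; it assumes FakeResponse and build_async_client_factory from the test module above are in scope.

import asyncio
from typing import Any

calls: list[dict[str, Any]] = []
FakeClient = build_async_client_factory(responses=[FakeResponse(200)], calls=calls)

async def demo() -> None:
    # Same call shape as the route under test: async context manager + .get()
    async with FakeClient(timeout=15.0, follow_redirects=True) as client:
        response = await client.get("https://example.invalid/v1beta/models", params={"key": "demo"})
    assert response.status_code == 200
    assert calls[0]["params"] == {"key": "demo"}
    assert calls[0]["headers"] == {}

asyncio.run(demo())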

View File

@@ -16,6 +16,7 @@ from src.common.logger import get_logger
from ..storage import VectorStore, GraphStore, MetadataStore
from ..embedding import EmbeddingAPIAdapter
from ..utils.matcher import AhoCorasick
+from ..utils.metadata import coerce_metadata_dict
from ..utils.time_parser import format_timestamp
from .graph_relation_recall import GraphRelationRecallConfig, GraphRelationRecallService
from .pagerank import PersonalizedPageRank, PageRankConfig
@@ -482,7 +483,7 @@ class DualPathRetriever:
score=float(item.score),
result_type=item.result_type,
source=item.source,
-metadata=dict(item.metadata or {}),
+metadata=coerce_metadata_dict(item.metadata),
)
def _extract_graph_seed_entities(self, query: str, limit: int = 2) -> List[str]:
@@ -762,7 +763,7 @@ class DualPathRetriever:
existing = self._clone_retrieval_result(item)
merged[item.hash_value] = existing
else:
-for key, value in dict(item.metadata or {}).items():
+for key, value in coerce_metadata_dict(item.metadata).items():
if key not in existing.metadata or existing.metadata.get(key) in (None, "", []):
existing.metadata[key] = value
source_sets.setdefault(item.hash_value, set()).add(str(item.source or "").strip() or "relation_search")
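The merge rule in the second hunk only fills keys that are missing or empty on the already-merged result. A standalone sketch of that behaviour, with invented metadata values:

# Invented metadata values; illustrates the fill-only-empty merge used above.
existing_metadata = {"chat_id": "c1", "person_id": "", "kind": "paragraph"}
incoming_metadata = {"person_id": "p42", "kind": "relation", "extra": "x"}
for key, value in incoming_metadata.items():
    if key not in existing_metadata or existing_metadata.get(key) in (None, "", []):
        existing_metadata[key] = value
assert existing_metadata == {"chat_id": "c1", "person_id": "p42", "kind": "paragraph", "extra": "x"}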

View File

@@ -25,6 +25,7 @@ from ..utils.episode_retrieval_service import EpisodeRetrievalService
from ..utils.episode_segmentation_service import EpisodeSegmentationService
from ..utils.episode_service import EpisodeService
from ..utils.hash import compute_hash, normalize_text
+from ..utils.metadata import coerce_metadata_dict
from ..utils.person_profile_service import PersonProfileService
from ..utils.relation_write_service import RelationWriteService
from ..utils.retrieval_tuning_manager import RetrievalTuningManager
@@ -871,7 +872,7 @@ class SDKMemoryKernel:
"detail": "chat_filtered", "detail": "chat_filtered",
} }
summary_meta = dict(metadata or {}) summary_meta = coerce_metadata_dict(metadata)
summary_meta.setdefault("kind", "chat_summary") summary_meta.setdefault("kind", "chat_summary")
if not str(text or "").strip() or bool(summary_meta.get("generate_from_chat", False)): if not str(text or "").strip() or bool(summary_meta.get("generate_from_chat", False)):
result = await self.summarize_chat_stream( result = await self.summarize_chat_stream(
@@ -961,7 +962,7 @@ class SDKMemoryKernel:
participant_tokens = self._tokens(participants)
entity_tokens = self._merge_tokens(entities, person_tokens, participant_tokens)
source = self._build_source(source_type, chat_id, person_tokens)
-paragraph_meta = dict(metadata or {})
+paragraph_meta = coerce_metadata_dict(metadata)
paragraph_meta.update(
{
"external_id": external_token,

View File

@@ -0,0 +1,11 @@
from __future__ import annotations

from collections.abc import Mapping
from typing import Any, Dict


def coerce_metadata_dict(value: Any) -> Dict[str, Any]:
    """Return the value as a plain dict; return an empty dict if it is not a mapping."""
    if isinstance(value, Mapping):
        return dict(value)
    return {}
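The helper deliberately accepts any value, so callers can drop their dict(x or {}) guards. A quick illustration of the cases it normalizes, assuming coerce_metadata_dict from the file above is in scope (the inputs are made-up examples):

from collections import OrderedDict

assert coerce_metadata_dict({"kind": "chat_summary"}) == {"kind": "chat_summary"}
assert coerce_metadata_dict(OrderedDict(person_id="p1")) == {"person_id": "p1"}  # any Mapping is copied
assert coerce_metadata_dict(None) == {}            # missing metadata
assert coerce_metadata_dict("not-a-dict") == {}    # wrong type is swallowed instead of raising
assert coerce_metadata_dict([("k", "v")]) == {}    # unlike dict(...), pair lists are rejected too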

View File

@@ -27,6 +27,7 @@ from ..retrieval import (
GraphRelationRecallConfig,
)
from ..storage import MetadataStore, GraphStore, VectorStore
+from .metadata import coerce_metadata_dict
logger = get_logger("A_Memorix.PersonProfileService")
@@ -334,7 +335,7 @@ class PersonProfileService:
if not pid:
return False
-metadata = self._metadata_dict(relation.get("metadata"))
+metadata = coerce_metadata_dict(relation.get("metadata"))
if str(metadata.get("person_id", "") or "").strip() == pid:
return True
if pid in self._list_tokens(metadata.get("person_ids")):
@@ -350,7 +351,7 @@ class PersonProfileService:
payload = {
"hash": source_paragraph,
"source": str(paragraph.get("source", "") or ""),
-"metadata": self._metadata_dict(paragraph.get("metadata")),
+"metadata": coerce_metadata_dict(paragraph.get("metadata")),
}
return self._is_evidence_bound_to_person(payload, person_id=pid)
@@ -385,15 +386,11 @@ class PersonProfileService:
"score": 1.1, "score": 1.1,
"content": content[:220], "content": content[:220],
"source": str(row.get("source", "") or source), "source": str(row.get("source", "") or source),
"metadata": dict(row.get("metadata", {}) or {}), "metadata": coerce_metadata_dict(row.get("metadata")),
} }
) )
return self._filter_stale_paragraph_evidence(evidence) return self._filter_stale_paragraph_evidence(evidence)
@staticmethod
def _metadata_dict(value: Any) -> Dict[str, Any]:
return dict(value) if isinstance(value, dict) else {}
@staticmethod @staticmethod
def _list_tokens(value: Any) -> List[str]: def _list_tokens(value: Any) -> List[str]:
if value is None: if value is None:
@@ -414,7 +411,7 @@ class PersonProfileService:
if not pid:
return False
-metadata = self._metadata_dict(item.get("metadata"))
+metadata = coerce_metadata_dict(item.get("metadata"))
source = str(item.get("source", "") or metadata.get("source", "") or "").strip()
if source == f"person_fact:{pid}":
return True
@@ -440,15 +437,15 @@ class PersonProfileService:
paragraph_hash: str,
metadata: Dict[str, Any],
) -> Tuple[Dict[str, Any], str]:
-merged = self._metadata_dict(metadata)
+merged = coerce_metadata_dict(metadata)
source = str(merged.get("source", "") or "").strip()
try:
paragraph = self.metadata_store.get_paragraph(paragraph_hash)
except Exception:
paragraph = None
if isinstance(paragraph, dict):
-paragraph_metadata = paragraph.get("metadata", {}) or {}
-if isinstance(paragraph_metadata, dict):
+paragraph_metadata = coerce_metadata_dict(paragraph.get("metadata"))
+if paragraph_metadata:
merged = {**paragraph_metadata, **merged}
source = source or str(paragraph.get("source", "") or "").strip()
source_type = str(merged.get("source_type", "") or "").strip() or self._source_type_from_source(source)
@@ -538,7 +535,7 @@ class PersonProfileService:
"score": 0.0, "score": 0.0,
"content": str(para.get("content", ""))[:180], "content": str(para.get("content", ""))[:180],
"source": str(para.get("source", "") or ""), "source": str(para.get("source", "") or ""),
"metadata": self._metadata_dict(para.get("metadata")), "metadata": coerce_metadata_dict(para.get("metadata")),
} }
) )
if not self._is_evidence_bound_to_person(fallback[-1], person_id=person_id): if not self._is_evidence_bound_to_person(fallback[-1], person_id=person_id):
@@ -562,18 +559,18 @@ class PersonProfileService:
logger.warning(f"向量证据召回失败: alias={alias}, err={e}")
continue
for item in results:
-h = str(getattr(item, "hash_value", "") or "")
+h = str(item.hash_value or "")
if not h or h in seen_hash:
continue
metadata, source = self._enrich_paragraph_evidence_metadata(
h,
-self._metadata_dict(getattr(item, "metadata", {})),
+coerce_metadata_dict(item.metadata),
)
payload = {
"hash": h,
-"type": str(getattr(item, "result_type", "")),
-"score": float(getattr(item, "score", 0.0) or 0.0),
-"content": str(getattr(item, "content", "") or "")[:220],
+"type": str(item.result_type),
+"score": float(item.score or 0.0),
+"content": str(item.content or "")[:220],
"source": source,
"metadata": metadata,
}
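In _enrich_paragraph_evidence_metadata, the stored paragraph metadata only backfills missing keys while the caller-supplied metadata keeps precedence. A small sketch of that merge order with invented values:

# Invented values; caller metadata wins, paragraph metadata only backfills.
paragraph_metadata = {"source_type": "chat", "chat_id": "c1", "person_id": "p1"}
merged = {"person_id": "p42", "score_origin": "vector"}
merged = {**paragraph_metadata, **merged}
assert merged == {
    "source_type": "chat",
    "chat_id": "c1",
    "person_id": "p42",      # caller-supplied value kept
    "score_origin": "vector",
}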

View File

@@ -14,7 +14,8 @@ from typing import Any, Dict, List, Optional, Tuple
from src.common.logger import get_logger
-from ..retrieval import TemporalQueryOptions
+from ..retrieval import RetrievalResult, TemporalQueryOptions
+from .metadata import coerce_metadata_dict
from .search_postprocess import (
apply_safe_content_dedup,
maybe_apply_smart_path_fallback,
@@ -286,8 +287,11 @@ class SearchExecutionService:
)
async def _executor() -> Dict[str, Any]:
-original_ppr = bool(getattr(retriever.config, "enable_ppr", True))
-setattr(retriever.config, "enable_ppr", bool(request.enable_ppr))
+retriever_config = getattr(retriever, "config", None)
+has_runtime_ppr_switch = retriever_config is not None and hasattr(retriever_config, "enable_ppr")
+original_ppr = bool(retriever_config.enable_ppr) if has_runtime_ppr_switch else None
+if has_runtime_ppr_switch:
+retriever_config.enable_ppr = bool(request.enable_ppr)
started_at = time.time()
try:
retrieved = await retriever.retrieve(
@@ -380,7 +384,8 @@ class SearchExecutionService:
elapsed_ms = (time.time() - started_at) * 1000.0
return {"results": retrieved, "elapsed_ms": elapsed_ms}
finally:
-setattr(retriever.config, "enable_ppr", original_ppr)
+if has_runtime_ppr_switch:
+retriever_config.enable_ppr = bool(original_ppr)
dedup_hit = False
try:
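The guard introduced here leaves retrievers without a runtime enable_ppr switch untouched instead of raising. A standalone sketch of the pattern, where the dummy retriever objects are invented stand-ins:

from types import SimpleNamespace

def flip_ppr_switch(retriever, requested: bool):
    # Mirrors the guard above: only toggle enable_ppr when the config exposes it.
    retriever_config = getattr(retriever, "config", None)
    has_switch = retriever_config is not None and hasattr(retriever_config, "enable_ppr")
    original = bool(retriever_config.enable_ppr) if has_switch else None
    if has_switch:
        retriever_config.enable_ppr = requested
    return has_switch, original

with_switch = SimpleNamespace(config=SimpleNamespace(enable_ppr=True))
without_switch = SimpleNamespace(config=None)
assert flip_ppr_switch(with_switch, False) == (True, True)
assert with_switch.config.enable_ppr is False
assert flip_ppr_switch(without_switch, False) == (False, None)  # no attribute error any more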
@@ -421,18 +426,18 @@ class SearchExecutionService:
)
@staticmethod
-def to_serializable_results(results: List[Any]) -> List[Dict[str, Any]]:
+def to_serializable_results(results: List[RetrievalResult]) -> List[Dict[str, Any]]:
serialized: List[Dict[str, Any]] = []
for item in results:
-metadata = dict(getattr(item, "metadata", {}) or {})
+metadata = coerce_metadata_dict(item.metadata)
if "time_meta" not in metadata:
metadata["time_meta"] = {}
serialized.append(
{
-"hash": getattr(item, "hash_value", ""),
-"type": getattr(item, "result_type", ""),
-"score": float(getattr(item, "score", 0.0)),
-"content": getattr(item, "content", ""),
+"hash": item.hash_value,
+"type": item.result_type,
+"score": float(item.score),
+"content": item.content,
"metadata": metadata,
}
)
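For reference, the row shape to_serializable_results emits per result, sketched with a stand-in object that carries only the attributes the serializer reads (all values are invented):

from types import SimpleNamespace

item = SimpleNamespace(
    hash_value="abc123",
    result_type="paragraph",
    score=0.87,
    content="some recalled text",
    metadata={"chat_id": "c1"},
)
metadata = dict(item.metadata)        # coerce_metadata_dict(item.metadata) in the real code
metadata.setdefault("time_meta", {})  # every serialized row carries a time_meta key
row = {
    "hash": item.hash_value,
    "type": item.result_type,
    "score": float(item.score),
    "content": item.content,
    "metadata": metadata,
}
# {'hash': 'abc123', 'type': 'paragraph', 'score': 0.87,
#  'content': 'some recalled text', 'metadata': {'chat_id': 'c1', 'time_meta': {}}}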

View File

@@ -296,6 +296,7 @@ async def get_models_by_url(
async def test_provider_connection(
base_url: str = Query(..., description="提供商的基础 URL"),
api_key: Optional[str] = Query(None, description="API Key(可选,用于验证 Key 有效性)"),
+client_type: str = Query("openai", description="客户端类型 (openai | gemini)"),
):
"""
Test the provider connection status.
@@ -359,13 +360,19 @@ async def test_provider_connection(
try:
start_time = time.time()
async with httpx.AsyncClient(timeout=15.0, follow_redirects=True) as client:
-headers = {
-"Authorization": f"Bearer {api_key}",
-"Content-Type": "application/json",
-}
+headers = {"Content-Type": "application/json"}
+params = {}
+if client_type == "gemini":
+# Gemini passes the API Key as a URL query parameter
+params["key"] = api_key
+else:
+# OpenAI-compatible providers use the Authorization header
+headers["Authorization"] = f"Bearer {api_key}"
# Try to fetch the model list
models_url = f"{base_url}/models"
-response = await client.get(models_url, headers=headers)
+response = await client.get(models_url, headers=headers, params=params)
if response.status_code == 200:
result["api_key_valid"] = True
@@ -408,9 +415,14 @@ async def test_provider_connection_by_name(
base_url = provider.get("base_url", "")
api_key = provider.get("api_key", "")
+client_type = provider.get("client_type", "openai")
if not base_url:
raise HTTPException(status_code=400, detail="提供商配置缺少 base_url")
# Delegate to the connection test endpoint
-return await test_provider_connection(base_url=base_url, api_key=api_key or None)
+return await test_provider_connection(
+base_url=base_url,
+api_key=api_key if api_key else None,
+client_type=client_type,
+)
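A minimal sketch of the by-name flow that now forwards client_type. It assumes Python 3.11+ tomllib and the model_config.toml layout used in the tests above; the function name resolve_provider is illustrative, not the route's actual helper:

import tomllib
from pathlib import Path
from typing import Any, Dict

def resolve_provider(config_dir: str, provider_name: str) -> Dict[str, Any]:
    config = tomllib.loads((Path(config_dir) / "model_config.toml").read_text(encoding="utf-8"))
    for provider in config.get("api_providers", []):
        if provider.get("name") == provider_name:
            return {
                "base_url": provider.get("base_url", ""),
                "api_key": provider.get("api_key", "") or None,
                "client_type": provider.get("client_type", "openai"),
            }
    raise KeyError(provider_name)

# resolve_provider(CONFIG_DIR, "Gemini")
# -> {"base_url": "https://generativelanguage.googleapis.com/v1beta",
#     "api_key": "valid-gemini-key", "client_type": "gemini"}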