Merge branch 'dev' of https://github.com/Mai-with-u/MaiBot into dev
This commit is contained in:
187
pytests/webui/test_model_routes.py
Normal file
187
pytests/webui/test_model_routes.py
Normal file
@@ -0,0 +1,187 @@
|
|||||||
|
"""模型路由测试
|
||||||
|
|
||||||
|
验证 Gemini 提供商连接测试会使用查询参数传递 API Key,
|
||||||
|
并且不会回退到 OpenAI 兼容接口使用的 Bearer 认证方式。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
import sys
|
||||||
|
from types import ModuleType
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
def load_model_routes(monkeypatch: pytest.MonkeyPatch):
|
||||||
|
"""在导入路由前 stub 配置与认证依赖模块,避免测试时触发真实初始化。"""
|
||||||
|
config_module = ModuleType("src.config.config")
|
||||||
|
config_module.__dict__["CONFIG_DIR"] = "."
|
||||||
|
monkeypatch.setitem(sys.modules, "src.config.config", config_module)
|
||||||
|
|
||||||
|
dependencies_module = ModuleType("src.webui.dependencies")
|
||||||
|
|
||||||
|
async def require_auth():
|
||||||
|
return "test-token"
|
||||||
|
|
||||||
|
dependencies_module.__dict__["require_auth"] = require_auth
|
||||||
|
monkeypatch.setitem(sys.modules, "src.webui.dependencies", dependencies_module)
|
||||||
|
|
||||||
|
sys.modules.pop("src.webui.routers.model", None)
|
||||||
|
return importlib.import_module("src.webui.routers.model")
|
||||||
|
|
||||||
|
|
||||||
|
class FakeResponse:
|
||||||
|
"""简化版 HTTP 响应对象。"""
|
||||||
|
|
||||||
|
def __init__(self, status_code: int):
|
||||||
|
self.status_code = status_code
|
||||||
|
|
||||||
|
|
||||||
|
def build_async_client_factory(
|
||||||
|
responses: list[FakeResponse],
|
||||||
|
calls: list[dict[str, Any]],
|
||||||
|
):
|
||||||
|
"""构造一个可记录请求参数的 AsyncClient 替身。"""
|
||||||
|
|
||||||
|
response_iter = iter(responses)
|
||||||
|
|
||||||
|
class FakeAsyncClient:
|
||||||
|
def __init__(self, *args: Any, **kwargs: Any):
|
||||||
|
self.args = args
|
||||||
|
self.kwargs = kwargs
|
||||||
|
|
||||||
|
async def __aenter__(self) -> "FakeAsyncClient":
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc, tb) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def get(
|
||||||
|
self,
|
||||||
|
url: str,
|
||||||
|
headers: dict[str, Any] | None = None,
|
||||||
|
params: dict[str, Any] | None = None,
|
||||||
|
) -> FakeResponse:
|
||||||
|
calls.append(
|
||||||
|
{
|
||||||
|
"url": url,
|
||||||
|
"headers": headers or {},
|
||||||
|
"params": params or {},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return next(response_iter)
|
||||||
|
|
||||||
|
return FakeAsyncClient
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_test_provider_connection_uses_query_api_key_for_gemini(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
) -> None:
|
||||||
|
"""Gemini 连接测试应通过查询参数传递 API Key。"""
|
||||||
|
model_routes = load_model_routes(monkeypatch)
|
||||||
|
calls: list[dict[str, Any]] = []
|
||||||
|
fake_client_class = build_async_client_factory(
|
||||||
|
responses=[FakeResponse(200), FakeResponse(200)],
|
||||||
|
calls=calls,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(model_routes.httpx, "AsyncClient", fake_client_class)
|
||||||
|
|
||||||
|
result = await model_routes.test_provider_connection(
|
||||||
|
base_url="https://generativelanguage.googleapis.com/v1beta",
|
||||||
|
api_key="valid-gemini-key",
|
||||||
|
client_type="gemini",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["network_ok"] is True
|
||||||
|
assert result["api_key_valid"] is True
|
||||||
|
assert len(calls) == 2
|
||||||
|
|
||||||
|
network_call = calls[0]
|
||||||
|
validation_call = calls[1]
|
||||||
|
|
||||||
|
assert network_call["url"] == "https://generativelanguage.googleapis.com/v1beta"
|
||||||
|
assert network_call["headers"] == {}
|
||||||
|
assert network_call["params"] == {}
|
||||||
|
|
||||||
|
assert validation_call["url"] == "https://generativelanguage.googleapis.com/v1beta/models"
|
||||||
|
assert validation_call["params"] == {"key": "valid-gemini-key"}
|
||||||
|
assert validation_call["headers"] == {"Content-Type": "application/json"}
|
||||||
|
assert "Authorization" not in validation_call["headers"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_test_provider_connection_uses_bearer_auth_for_openai_compatible(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
) -> None:
|
||||||
|
"""非 Gemini 提供商连接测试应继续使用 Bearer 认证。"""
|
||||||
|
model_routes = load_model_routes(monkeypatch)
|
||||||
|
calls: list[dict[str, Any]] = []
|
||||||
|
fake_client_class = build_async_client_factory(
|
||||||
|
responses=[FakeResponse(200), FakeResponse(200)],
|
||||||
|
calls=calls,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(model_routes.httpx, "AsyncClient", fake_client_class)
|
||||||
|
|
||||||
|
result = await model_routes.test_provider_connection(
|
||||||
|
base_url="https://example.com/v1",
|
||||||
|
api_key="valid-openai-key",
|
||||||
|
client_type="openai",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["network_ok"] is True
|
||||||
|
assert result["api_key_valid"] is True
|
||||||
|
assert len(calls) == 2
|
||||||
|
|
||||||
|
validation_call = calls[1]
|
||||||
|
|
||||||
|
assert validation_call["url"] == "https://example.com/v1/models"
|
||||||
|
assert validation_call["params"] == {}
|
||||||
|
assert validation_call["headers"]["Content-Type"] == "application/json"
|
||||||
|
assert validation_call["headers"]["Authorization"] == "Bearer valid-openai-key"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_test_provider_connection_by_name_forwards_provider_client_type(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
tmp_path,
|
||||||
|
) -> None:
|
||||||
|
"""按提供商名称测试连接时,应透传配置中的 client_type。"""
|
||||||
|
model_routes = load_model_routes(monkeypatch)
|
||||||
|
config_path = tmp_path / "model_config.toml"
|
||||||
|
config_path.write_text(
|
||||||
|
"""
|
||||||
|
[[api_providers]]
|
||||||
|
name = "Gemini"
|
||||||
|
base_url = "https://generativelanguage.googleapis.com/v1beta"
|
||||||
|
api_key = "valid-gemini-key"
|
||||||
|
client_type = "gemini"
|
||||||
|
""".strip(),
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.setattr(model_routes, "CONFIG_DIR", str(tmp_path))
|
||||||
|
|
||||||
|
captured_kwargs: dict[str, Any] = {}
|
||||||
|
|
||||||
|
async def fake_test_provider_connection(**kwargs: Any) -> dict[str, Any]:
|
||||||
|
captured_kwargs.update(kwargs)
|
||||||
|
return {
|
||||||
|
"network_ok": True,
|
||||||
|
"api_key_valid": True,
|
||||||
|
"latency_ms": 12.34,
|
||||||
|
"error": None,
|
||||||
|
"http_status": 200,
|
||||||
|
}
|
||||||
|
|
||||||
|
monkeypatch.setattr(model_routes, "test_provider_connection", fake_test_provider_connection)
|
||||||
|
|
||||||
|
result = await model_routes.test_provider_connection_by_name(provider_name="Gemini")
|
||||||
|
|
||||||
|
assert result["network_ok"] is True
|
||||||
|
assert result["api_key_valid"] is True
|
||||||
|
assert captured_kwargs == {
|
||||||
|
"base_url": "https://generativelanguage.googleapis.com/v1beta",
|
||||||
|
"api_key": "valid-gemini-key",
|
||||||
|
"client_type": "gemini",
|
||||||
|
}
|
||||||
@@ -16,6 +16,7 @@ from src.common.logger import get_logger
|
|||||||
from ..storage import VectorStore, GraphStore, MetadataStore
|
from ..storage import VectorStore, GraphStore, MetadataStore
|
||||||
from ..embedding import EmbeddingAPIAdapter
|
from ..embedding import EmbeddingAPIAdapter
|
||||||
from ..utils.matcher import AhoCorasick
|
from ..utils.matcher import AhoCorasick
|
||||||
|
from ..utils.metadata import coerce_metadata_dict
|
||||||
from ..utils.time_parser import format_timestamp
|
from ..utils.time_parser import format_timestamp
|
||||||
from .graph_relation_recall import GraphRelationRecallConfig, GraphRelationRecallService
|
from .graph_relation_recall import GraphRelationRecallConfig, GraphRelationRecallService
|
||||||
from .pagerank import PersonalizedPageRank, PageRankConfig
|
from .pagerank import PersonalizedPageRank, PageRankConfig
|
||||||
@@ -482,7 +483,7 @@ class DualPathRetriever:
|
|||||||
score=float(item.score),
|
score=float(item.score),
|
||||||
result_type=item.result_type,
|
result_type=item.result_type,
|
||||||
source=item.source,
|
source=item.source,
|
||||||
metadata=dict(item.metadata or {}),
|
metadata=coerce_metadata_dict(item.metadata),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _extract_graph_seed_entities(self, query: str, limit: int = 2) -> List[str]:
|
def _extract_graph_seed_entities(self, query: str, limit: int = 2) -> List[str]:
|
||||||
@@ -762,7 +763,7 @@ class DualPathRetriever:
|
|||||||
existing = self._clone_retrieval_result(item)
|
existing = self._clone_retrieval_result(item)
|
||||||
merged[item.hash_value] = existing
|
merged[item.hash_value] = existing
|
||||||
else:
|
else:
|
||||||
for key, value in dict(item.metadata or {}).items():
|
for key, value in coerce_metadata_dict(item.metadata).items():
|
||||||
if key not in existing.metadata or existing.metadata.get(key) in (None, "", []):
|
if key not in existing.metadata or existing.metadata.get(key) in (None, "", []):
|
||||||
existing.metadata[key] = value
|
existing.metadata[key] = value
|
||||||
source_sets.setdefault(item.hash_value, set()).add(str(item.source or "").strip() or "relation_search")
|
source_sets.setdefault(item.hash_value, set()).add(str(item.source or "").strip() or "relation_search")
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ from ..utils.episode_retrieval_service import EpisodeRetrievalService
|
|||||||
from ..utils.episode_segmentation_service import EpisodeSegmentationService
|
from ..utils.episode_segmentation_service import EpisodeSegmentationService
|
||||||
from ..utils.episode_service import EpisodeService
|
from ..utils.episode_service import EpisodeService
|
||||||
from ..utils.hash import compute_hash, normalize_text
|
from ..utils.hash import compute_hash, normalize_text
|
||||||
|
from ..utils.metadata import coerce_metadata_dict
|
||||||
from ..utils.person_profile_service import PersonProfileService
|
from ..utils.person_profile_service import PersonProfileService
|
||||||
from ..utils.relation_write_service import RelationWriteService
|
from ..utils.relation_write_service import RelationWriteService
|
||||||
from ..utils.retrieval_tuning_manager import RetrievalTuningManager
|
from ..utils.retrieval_tuning_manager import RetrievalTuningManager
|
||||||
@@ -871,7 +872,7 @@ class SDKMemoryKernel:
|
|||||||
"detail": "chat_filtered",
|
"detail": "chat_filtered",
|
||||||
}
|
}
|
||||||
|
|
||||||
summary_meta = dict(metadata or {})
|
summary_meta = coerce_metadata_dict(metadata)
|
||||||
summary_meta.setdefault("kind", "chat_summary")
|
summary_meta.setdefault("kind", "chat_summary")
|
||||||
if not str(text or "").strip() or bool(summary_meta.get("generate_from_chat", False)):
|
if not str(text or "").strip() or bool(summary_meta.get("generate_from_chat", False)):
|
||||||
result = await self.summarize_chat_stream(
|
result = await self.summarize_chat_stream(
|
||||||
@@ -961,7 +962,7 @@ class SDKMemoryKernel:
|
|||||||
participant_tokens = self._tokens(participants)
|
participant_tokens = self._tokens(participants)
|
||||||
entity_tokens = self._merge_tokens(entities, person_tokens, participant_tokens)
|
entity_tokens = self._merge_tokens(entities, person_tokens, participant_tokens)
|
||||||
source = self._build_source(source_type, chat_id, person_tokens)
|
source = self._build_source(source_type, chat_id, person_tokens)
|
||||||
paragraph_meta = dict(metadata or {})
|
paragraph_meta = coerce_metadata_dict(metadata)
|
||||||
paragraph_meta.update(
|
paragraph_meta.update(
|
||||||
{
|
{
|
||||||
"external_id": external_token,
|
"external_id": external_token,
|
||||||
|
|||||||
11
src/A_memorix/core/utils/metadata.py
Normal file
11
src/A_memorix/core/utils/metadata.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Mapping
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
|
||||||
|
def coerce_metadata_dict(value: Any) -> Dict[str, Any]:
|
||||||
|
"""返回字典,如果输入值不是字典则返回空字典。"""
|
||||||
|
if isinstance(value, Mapping):
|
||||||
|
return dict(value)
|
||||||
|
return {}
|
||||||
@@ -27,6 +27,7 @@ from ..retrieval import (
|
|||||||
GraphRelationRecallConfig,
|
GraphRelationRecallConfig,
|
||||||
)
|
)
|
||||||
from ..storage import MetadataStore, GraphStore, VectorStore
|
from ..storage import MetadataStore, GraphStore, VectorStore
|
||||||
|
from .metadata import coerce_metadata_dict
|
||||||
|
|
||||||
logger = get_logger("A_Memorix.PersonProfileService")
|
logger = get_logger("A_Memorix.PersonProfileService")
|
||||||
|
|
||||||
@@ -334,7 +335,7 @@ class PersonProfileService:
|
|||||||
if not pid:
|
if not pid:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
metadata = self._metadata_dict(relation.get("metadata"))
|
metadata = coerce_metadata_dict(relation.get("metadata"))
|
||||||
if str(metadata.get("person_id", "") or "").strip() == pid:
|
if str(metadata.get("person_id", "") or "").strip() == pid:
|
||||||
return True
|
return True
|
||||||
if pid in self._list_tokens(metadata.get("person_ids")):
|
if pid in self._list_tokens(metadata.get("person_ids")):
|
||||||
@@ -350,7 +351,7 @@ class PersonProfileService:
|
|||||||
payload = {
|
payload = {
|
||||||
"hash": source_paragraph,
|
"hash": source_paragraph,
|
||||||
"source": str(paragraph.get("source", "") or ""),
|
"source": str(paragraph.get("source", "") or ""),
|
||||||
"metadata": self._metadata_dict(paragraph.get("metadata")),
|
"metadata": coerce_metadata_dict(paragraph.get("metadata")),
|
||||||
}
|
}
|
||||||
return self._is_evidence_bound_to_person(payload, person_id=pid)
|
return self._is_evidence_bound_to_person(payload, person_id=pid)
|
||||||
|
|
||||||
@@ -385,15 +386,11 @@ class PersonProfileService:
|
|||||||
"score": 1.1,
|
"score": 1.1,
|
||||||
"content": content[:220],
|
"content": content[:220],
|
||||||
"source": str(row.get("source", "") or source),
|
"source": str(row.get("source", "") or source),
|
||||||
"metadata": dict(row.get("metadata", {}) or {}),
|
"metadata": coerce_metadata_dict(row.get("metadata")),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return self._filter_stale_paragraph_evidence(evidence)
|
return self._filter_stale_paragraph_evidence(evidence)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _metadata_dict(value: Any) -> Dict[str, Any]:
|
|
||||||
return dict(value) if isinstance(value, dict) else {}
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _list_tokens(value: Any) -> List[str]:
|
def _list_tokens(value: Any) -> List[str]:
|
||||||
if value is None:
|
if value is None:
|
||||||
@@ -414,7 +411,7 @@ class PersonProfileService:
|
|||||||
if not pid:
|
if not pid:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
metadata = self._metadata_dict(item.get("metadata"))
|
metadata = coerce_metadata_dict(item.get("metadata"))
|
||||||
source = str(item.get("source", "") or metadata.get("source", "") or "").strip()
|
source = str(item.get("source", "") or metadata.get("source", "") or "").strip()
|
||||||
if source == f"person_fact:{pid}":
|
if source == f"person_fact:{pid}":
|
||||||
return True
|
return True
|
||||||
@@ -440,15 +437,15 @@ class PersonProfileService:
|
|||||||
paragraph_hash: str,
|
paragraph_hash: str,
|
||||||
metadata: Dict[str, Any],
|
metadata: Dict[str, Any],
|
||||||
) -> Tuple[Dict[str, Any], str]:
|
) -> Tuple[Dict[str, Any], str]:
|
||||||
merged = self._metadata_dict(metadata)
|
merged = coerce_metadata_dict(metadata)
|
||||||
source = str(merged.get("source", "") or "").strip()
|
source = str(merged.get("source", "") or "").strip()
|
||||||
try:
|
try:
|
||||||
paragraph = self.metadata_store.get_paragraph(paragraph_hash)
|
paragraph = self.metadata_store.get_paragraph(paragraph_hash)
|
||||||
except Exception:
|
except Exception:
|
||||||
paragraph = None
|
paragraph = None
|
||||||
if isinstance(paragraph, dict):
|
if isinstance(paragraph, dict):
|
||||||
paragraph_metadata = paragraph.get("metadata", {}) or {}
|
paragraph_metadata = coerce_metadata_dict(paragraph.get("metadata"))
|
||||||
if isinstance(paragraph_metadata, dict):
|
if paragraph_metadata:
|
||||||
merged = {**paragraph_metadata, **merged}
|
merged = {**paragraph_metadata, **merged}
|
||||||
source = source or str(paragraph.get("source", "") or "").strip()
|
source = source or str(paragraph.get("source", "") or "").strip()
|
||||||
source_type = str(merged.get("source_type", "") or "").strip() or self._source_type_from_source(source)
|
source_type = str(merged.get("source_type", "") or "").strip() or self._source_type_from_source(source)
|
||||||
@@ -538,7 +535,7 @@ class PersonProfileService:
|
|||||||
"score": 0.0,
|
"score": 0.0,
|
||||||
"content": str(para.get("content", ""))[:180],
|
"content": str(para.get("content", ""))[:180],
|
||||||
"source": str(para.get("source", "") or ""),
|
"source": str(para.get("source", "") or ""),
|
||||||
"metadata": self._metadata_dict(para.get("metadata")),
|
"metadata": coerce_metadata_dict(para.get("metadata")),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
if not self._is_evidence_bound_to_person(fallback[-1], person_id=person_id):
|
if not self._is_evidence_bound_to_person(fallback[-1], person_id=person_id):
|
||||||
@@ -562,18 +559,18 @@ class PersonProfileService:
|
|||||||
logger.warning(f"向量证据召回失败: alias={alias}, err={e}")
|
logger.warning(f"向量证据召回失败: alias={alias}, err={e}")
|
||||||
continue
|
continue
|
||||||
for item in results:
|
for item in results:
|
||||||
h = str(getattr(item, "hash_value", "") or "")
|
h = str(item.hash_value or "")
|
||||||
if not h or h in seen_hash:
|
if not h or h in seen_hash:
|
||||||
continue
|
continue
|
||||||
metadata, source = self._enrich_paragraph_evidence_metadata(
|
metadata, source = self._enrich_paragraph_evidence_metadata(
|
||||||
h,
|
h,
|
||||||
self._metadata_dict(getattr(item, "metadata", {})),
|
coerce_metadata_dict(item.metadata),
|
||||||
)
|
)
|
||||||
payload = {
|
payload = {
|
||||||
"hash": h,
|
"hash": h,
|
||||||
"type": str(getattr(item, "result_type", "")),
|
"type": str(item.result_type),
|
||||||
"score": float(getattr(item, "score", 0.0) or 0.0),
|
"score": float(item.score or 0.0),
|
||||||
"content": str(getattr(item, "content", "") or "")[:220],
|
"content": str(item.content or "")[:220],
|
||||||
"source": source,
|
"source": source,
|
||||||
"metadata": metadata,
|
"metadata": metadata,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,7 +14,8 @@ from typing import Any, Dict, List, Optional, Tuple
|
|||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
from ..retrieval import TemporalQueryOptions
|
from ..retrieval import RetrievalResult, TemporalQueryOptions
|
||||||
|
from .metadata import coerce_metadata_dict
|
||||||
from .search_postprocess import (
|
from .search_postprocess import (
|
||||||
apply_safe_content_dedup,
|
apply_safe_content_dedup,
|
||||||
maybe_apply_smart_path_fallback,
|
maybe_apply_smart_path_fallback,
|
||||||
@@ -286,8 +287,11 @@ class SearchExecutionService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def _executor() -> Dict[str, Any]:
|
async def _executor() -> Dict[str, Any]:
|
||||||
original_ppr = bool(getattr(retriever.config, "enable_ppr", True))
|
retriever_config = getattr(retriever, "config", None)
|
||||||
setattr(retriever.config, "enable_ppr", bool(request.enable_ppr))
|
has_runtime_ppr_switch = retriever_config is not None and hasattr(retriever_config, "enable_ppr")
|
||||||
|
original_ppr = bool(retriever_config.enable_ppr) if has_runtime_ppr_switch else None
|
||||||
|
if has_runtime_ppr_switch:
|
||||||
|
retriever_config.enable_ppr = bool(request.enable_ppr)
|
||||||
started_at = time.time()
|
started_at = time.time()
|
||||||
try:
|
try:
|
||||||
retrieved = await retriever.retrieve(
|
retrieved = await retriever.retrieve(
|
||||||
@@ -380,7 +384,8 @@ class SearchExecutionService:
|
|||||||
elapsed_ms = (time.time() - started_at) * 1000.0
|
elapsed_ms = (time.time() - started_at) * 1000.0
|
||||||
return {"results": retrieved, "elapsed_ms": elapsed_ms}
|
return {"results": retrieved, "elapsed_ms": elapsed_ms}
|
||||||
finally:
|
finally:
|
||||||
setattr(retriever.config, "enable_ppr", original_ppr)
|
if has_runtime_ppr_switch:
|
||||||
|
retriever_config.enable_ppr = bool(original_ppr)
|
||||||
|
|
||||||
dedup_hit = False
|
dedup_hit = False
|
||||||
try:
|
try:
|
||||||
@@ -421,18 +426,18 @@ class SearchExecutionService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def to_serializable_results(results: List[Any]) -> List[Dict[str, Any]]:
|
def to_serializable_results(results: List[RetrievalResult]) -> List[Dict[str, Any]]:
|
||||||
serialized: List[Dict[str, Any]] = []
|
serialized: List[Dict[str, Any]] = []
|
||||||
for item in results:
|
for item in results:
|
||||||
metadata = dict(getattr(item, "metadata", {}) or {})
|
metadata = coerce_metadata_dict(item.metadata)
|
||||||
if "time_meta" not in metadata:
|
if "time_meta" not in metadata:
|
||||||
metadata["time_meta"] = {}
|
metadata["time_meta"] = {}
|
||||||
serialized.append(
|
serialized.append(
|
||||||
{
|
{
|
||||||
"hash": getattr(item, "hash_value", ""),
|
"hash": item.hash_value,
|
||||||
"type": getattr(item, "result_type", ""),
|
"type": item.result_type,
|
||||||
"score": float(getattr(item, "score", 0.0)),
|
"score": float(item.score),
|
||||||
"content": getattr(item, "content", ""),
|
"content": item.content,
|
||||||
"metadata": metadata,
|
"metadata": metadata,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -296,6 +296,7 @@ async def get_models_by_url(
|
|||||||
async def test_provider_connection(
|
async def test_provider_connection(
|
||||||
base_url: str = Query(..., description="提供商的基础 URL"),
|
base_url: str = Query(..., description="提供商的基础 URL"),
|
||||||
api_key: Optional[str] = Query(None, description="API Key(可选,用于验证 Key 有效性)"),
|
api_key: Optional[str] = Query(None, description="API Key(可选,用于验证 Key 有效性)"),
|
||||||
|
client_type: str = Query("openai", description="客户端类型 (openai | gemini)"),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
测试提供商连接状态
|
测试提供商连接状态
|
||||||
@@ -359,13 +360,19 @@ async def test_provider_connection(
|
|||||||
try:
|
try:
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
async with httpx.AsyncClient(timeout=15.0, follow_redirects=True) as client:
|
async with httpx.AsyncClient(timeout=15.0, follow_redirects=True) as client:
|
||||||
headers = {
|
headers = {"Content-Type": "application/json"}
|
||||||
"Authorization": f"Bearer {api_key}",
|
params = {}
|
||||||
"Content-Type": "application/json",
|
|
||||||
}
|
if client_type == "gemini":
|
||||||
|
# Gemini 使用 URL 参数传递 API Key
|
||||||
|
params["key"] = api_key
|
||||||
|
else:
|
||||||
|
# OpenAI 兼容格式使用 Authorization 头
|
||||||
|
headers["Authorization"] = f"Bearer {api_key}"
|
||||||
|
|
||||||
# 尝试获取模型列表
|
# 尝试获取模型列表
|
||||||
models_url = f"{base_url}/models"
|
models_url = f"{base_url}/models"
|
||||||
response = await client.get(models_url, headers=headers)
|
response = await client.get(models_url, headers=headers, params=params)
|
||||||
|
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
result["api_key_valid"] = True
|
result["api_key_valid"] = True
|
||||||
@@ -408,9 +415,14 @@ async def test_provider_connection_by_name(
|
|||||||
|
|
||||||
base_url = provider.get("base_url", "")
|
base_url = provider.get("base_url", "")
|
||||||
api_key = provider.get("api_key", "")
|
api_key = provider.get("api_key", "")
|
||||||
|
client_type = provider.get("client_type", "openai")
|
||||||
|
|
||||||
if not base_url:
|
if not base_url:
|
||||||
raise HTTPException(status_code=400, detail="提供商配置缺少 base_url")
|
raise HTTPException(status_code=400, detail="提供商配置缺少 base_url")
|
||||||
|
|
||||||
# 调用测试接口
|
# 调用测试接口
|
||||||
return await test_provider_connection(base_url=base_url, api_key=api_key or None)
|
return await test_provider_connection(
|
||||||
|
base_url=base_url,
|
||||||
|
api_key=api_key if api_key else None,
|
||||||
|
client_type=client_type,
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user