from __future__ import annotations from types import SimpleNamespace import numpy as np import pytest from A_memorix.core.embedding import api_adapter as api_adapter_module from A_memorix.core.embedding.api_adapter import EmbeddingAPIAdapter from A_memorix.core.utils.runtime_self_check import run_embedding_runtime_self_check class _FakeEmbeddingClient: def __init__(self, *, natural_dimension: int = 12) -> None: self.natural_dimension = int(natural_dimension) self.requests = [] async def get_embedding(self, request): self.requests.append(request) requested_dimension = request.extra_params.get("dimensions") if requested_dimension is None: requested_dimension = request.extra_params.get("output_dimensionality") dimension = int(requested_dimension or self.natural_dimension) return SimpleNamespace(embedding=[1.0] * dimension) def _build_adapter( monkeypatch: pytest.MonkeyPatch, *, client_type: str, configured_dimension: int = 1024, effective_dimension: int | None = None, model_extra_params: dict | None = None, ): adapter = EmbeddingAPIAdapter(default_dimension=configured_dimension) if effective_dimension is not None: adapter._dimension = int(effective_dimension) adapter._dimension_detected = True fake_client = _FakeEmbeddingClient() model_info = SimpleNamespace( name="embedding-model", api_provider="provider-1", model_identifier="embedding-model-id", extra_params=dict(model_extra_params or {}), ) provider = SimpleNamespace(name="provider-1", client_type=client_type) monkeypatch.setattr(adapter, "_resolve_candidate_model_names", lambda: ["embedding-model"]) monkeypatch.setattr(adapter, "_find_model_info", lambda model_name: model_info) monkeypatch.setattr(adapter, "_find_provider", lambda provider_name: provider) monkeypatch.setattr( api_adapter_module.client_registry, "get_client_class_instance", lambda api_provider, force_new=True: fake_client, ) return adapter, fake_client @pytest.mark.asyncio async def test_encode_uses_canonical_dimension_for_openai_provider(monkeypatch): adapter, fake_client = _build_adapter( monkeypatch, client_type="openai", configured_dimension=1024, effective_dimension=1024, model_extra_params={"task_type": "SEMANTIC_SIMILARITY"}, ) embedding = await adapter.encode("北塔木梯") request = fake_client.requests[-1] assert request.extra_params["dimensions"] == 1024 assert "output_dimensionality" not in request.extra_params assert request.extra_params["task_type"] == "SEMANTIC_SIMILARITY" assert embedding.shape == (1024,) @pytest.mark.asyncio async def test_encode_explicit_dimension_override_wins(monkeypatch): adapter, fake_client = _build_adapter( monkeypatch, client_type="openai", configured_dimension=1024, effective_dimension=1024, ) embedding = await adapter.encode("海潮图", dimensions=256) request = fake_client.requests[-1] assert request.extra_params["dimensions"] == 256 assert "output_dimensionality" not in request.extra_params assert embedding.shape == (256,) @pytest.mark.asyncio async def test_encode_maps_dimension_to_gemini_output_dimensionality(monkeypatch): adapter, fake_client = _build_adapter( monkeypatch, client_type="gemini", configured_dimension=1024, effective_dimension=768, ) embedding = await adapter.encode("广播站") request = fake_client.requests[-1] assert request.extra_params["output_dimensionality"] == 768 assert "dimensions" not in request.extra_params assert embedding.shape == (768,) @pytest.mark.asyncio async def test_encode_does_not_force_dimension_for_unsupported_provider(monkeypatch): adapter, fake_client = _build_adapter( monkeypatch, client_type="custom", configured_dimension=1024, effective_dimension=640, model_extra_params={ "dimensions": 999, "output_dimensionality": 888, "custom_flag": "keep-me", }, ) embedding = await adapter.encode("蓝漆铁盒") request = fake_client.requests[-1] assert "dimensions" not in request.extra_params assert "output_dimensionality" not in request.extra_params assert request.extra_params["custom_flag"] == "keep-me" assert embedding.shape == (fake_client.natural_dimension,) @pytest.mark.asyncio async def test_runtime_self_check_reports_requested_dimension_without_explicit_override(): class _FakeEmbeddingManager: def __init__(self) -> None: self.detected_dimension = 384 self.encode_calls = [] async def _detect_dimension(self) -> int: return self.detected_dimension def get_requested_dimension(self) -> int: return self.detected_dimension async def encode(self, text): self.encode_calls.append(text) return np.ones(self.detected_dimension, dtype=np.float32) manager = _FakeEmbeddingManager() report = await run_embedding_runtime_self_check( config={"embedding": {"dimension": 1024}}, vector_store=SimpleNamespace(dimension=384), embedding_manager=manager, ) assert report["ok"] is True assert report["configured_dimension"] == 1024 assert report["requested_dimension"] == 384 assert report["detected_dimension"] == 384 assert report["encoded_dimension"] == 384 assert manager.encode_calls == ["A_Memorix runtime self check"] @pytest.mark.asyncio async def test_encode_batch_keeps_batch_local_indexes_when_cache_hits_previous_batch(monkeypatch): adapter = EmbeddingAPIAdapter(default_dimension=4, enable_cache=True) adapter._dimension = 4 adapter._dimension_detected = True async def fake_detect_dimension() -> int: return 4 async def fake_get_embedding_direct(text: str, dimensions: int | None = None): del dimensions base = float(ord(str(text)[0])) return [base, base + 1.0, base + 2.0, base + 3.0] monkeypatch.setattr(adapter, "_detect_dimension", fake_detect_dimension) monkeypatch.setattr(adapter, "_get_embedding_direct", fake_get_embedding_direct) embeddings = await adapter.encode(["A", "B", "A", "C"], batch_size=2) assert embeddings.shape == (4, 4) assert np.array_equal(embeddings[0], embeddings[2]) assert embeddings[1][0] == float(ord("B")) assert embeddings[3][0] == float(ord("C"))