chore: import deployable mai-bot source tree
This commit is contained in:
2637
pytests/image_sys_test/emoji_manager_test.py
Normal file
2637
pytests/image_sys_test/emoji_manager_test.py
Normal file
File diff suppressed because it is too large
Load Diff
295
pytests/image_sys_test/image_manager_test.py
Normal file
295
pytests/image_sys_test/image_manager_test.py
Normal file
@@ -0,0 +1,295 @@
|
||||
import sys
|
||||
import types
|
||||
import importlib
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
import importlib.util
|
||||
|
||||
|
||||
class DummyLogger:
|
||||
def info(self, *a, **k):
|
||||
pass
|
||||
|
||||
def warning(self, *a, **k):
|
||||
pass
|
||||
|
||||
def error(self, *a, **k):
|
||||
pass
|
||||
|
||||
|
||||
class DummySession:
|
||||
def __init__(self):
|
||||
self.record = None
|
||||
|
||||
def exec(self, *a, **k):
|
||||
record = self.record
|
||||
|
||||
class R:
|
||||
def first(self):
|
||||
return record
|
||||
|
||||
def yield_per(self, n):
|
||||
if record is None:
|
||||
return iter(())
|
||||
return iter((record,))
|
||||
|
||||
return R()
|
||||
|
||||
def add(self, record, *a, **k):
|
||||
self.record = record
|
||||
|
||||
def flush(self, *a, **k):
|
||||
pass
|
||||
|
||||
def delete(self, *a, **k):
|
||||
self.record = None
|
||||
|
||||
def expunge(self, *a, **k):
|
||||
pass
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
|
||||
class DummyMaiImage:
|
||||
def __init__(self, full_path=None, image_bytes=None):
|
||||
self.full_path = full_path
|
||||
self.image_bytes = image_bytes
|
||||
self.file_hash = "dummy-hash"
|
||||
self.image_format = "png"
|
||||
self.description = ""
|
||||
self.vlm_processed = False
|
||||
|
||||
@classmethod
|
||||
def from_db_instance(cls, record):
|
||||
image = cls(full_path=getattr(record, "full_path", None))
|
||||
image.file_hash = getattr(record, "image_hash", "dummy-hash")
|
||||
image.description = getattr(record, "description", "")
|
||||
image.vlm_processed = getattr(record, "vlm_processed", False)
|
||||
return image
|
||||
|
||||
def to_db_instance(self):
|
||||
return types.SimpleNamespace(
|
||||
description=self.description,
|
||||
full_path=str(self.full_path) if self.full_path is not None else "",
|
||||
id=1,
|
||||
image_hash=self.file_hash,
|
||||
image_type="image",
|
||||
last_used_time=None,
|
||||
no_file_flag=False,
|
||||
query_count=0,
|
||||
register_time=None,
|
||||
vlm_processed=self.vlm_processed,
|
||||
)
|
||||
|
||||
async def calculate_hash_format(self):
|
||||
self.file_hash = "dummy-hash"
|
||||
return None
|
||||
|
||||
|
||||
class DummyLLMRequest:
|
||||
def __init__(self, *a, **k):
|
||||
pass
|
||||
|
||||
async def generate_response_for_image(self, prompt, image_base64, image_format, temp):
|
||||
return ("dummy description", {})
|
||||
|
||||
|
||||
class DummyLLMServiceClient:
|
||||
def __init__(self, *a, **k):
|
||||
pass
|
||||
|
||||
async def generate_response_for_image(self, prompt, image_base64, image_format, options=None):
|
||||
return types.SimpleNamespace(response="dummy description")
|
||||
|
||||
|
||||
class DummySelect:
|
||||
def __init__(self, *a, **k):
|
||||
pass
|
||||
|
||||
def filter_by(self, *a, **k):
|
||||
return self
|
||||
|
||||
def limit(self, n):
|
||||
return self
|
||||
|
||||
|
||||
class DetachedRecord:
|
||||
def __init__(self, description="cached description", vlm_processed=True):
|
||||
self._detached = False
|
||||
self._description = description
|
||||
self._vlm_processed = vlm_processed
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
if not self._detached:
|
||||
raise RuntimeError("attribute refresh operation cannot proceed")
|
||||
return self._description
|
||||
|
||||
@property
|
||||
def vlm_processed(self):
|
||||
if not self._detached:
|
||||
raise RuntimeError("attribute refresh operation cannot proceed")
|
||||
return self._vlm_processed
|
||||
|
||||
|
||||
class DetachedRecordSession(DummySession):
|
||||
def __init__(self, record):
|
||||
self.record = record
|
||||
|
||||
def exec(self, *a, **k):
|
||||
record = self.record
|
||||
|
||||
class R:
|
||||
def first(self):
|
||||
return record
|
||||
|
||||
return R()
|
||||
|
||||
def expunge(self, record):
|
||||
record._detached = True
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_external_dependencies(monkeypatch):
|
||||
# Provide dummy implementations as modules so that importing image_manager is safe
|
||||
# Patch LLMRequest
|
||||
llm_mod = types.SimpleNamespace(LLMRequest=DummyLLMRequest)
|
||||
monkeypatch.setitem(sys.modules, "src.llm_models.utils_model", llm_mod)
|
||||
llm_service_mod = types.SimpleNamespace(LLMServiceClient=DummyLLMServiceClient)
|
||||
monkeypatch.setitem(sys.modules, "src.services.llm_service", llm_service_mod)
|
||||
|
||||
# Patch logger
|
||||
logger_mod = types.SimpleNamespace(get_logger=lambda name: DummyLogger())
|
||||
monkeypatch.setitem(sys.modules, "src.common.logger", logger_mod)
|
||||
|
||||
# Patch DB session provider
|
||||
shared_session = DummySession()
|
||||
db_mod = types.SimpleNamespace(get_db_session=lambda: shared_session)
|
||||
monkeypatch.setitem(sys.modules, "src.common.database.database", db_mod)
|
||||
|
||||
# Patch database model types
|
||||
db_model_mod = types.SimpleNamespace(Images=types.SimpleNamespace, ImageType=types.SimpleNamespace(IMAGE="image"))
|
||||
monkeypatch.setitem(sys.modules, "src.common.database.database_model", db_model_mod)
|
||||
|
||||
# Patch MaiImage data model
|
||||
data_model_mod = types.SimpleNamespace(MaiImage=DummyMaiImage)
|
||||
monkeypatch.setitem(sys.modules, "src.common.data_models.image_data_model", data_model_mod)
|
||||
|
||||
# Patch SQLModel select function
|
||||
sql_mod = types.SimpleNamespace(select=lambda *a, **k: DummySelect())
|
||||
monkeypatch.setitem(sys.modules, "sqlmodel", sql_mod)
|
||||
|
||||
# Patch prompt manager used to build image description prompt.
|
||||
class _PromptManager:
|
||||
def get_prompt(self, _name):
|
||||
return types.SimpleNamespace()
|
||||
|
||||
async def render_prompt(self, _prompt):
|
||||
return "test-style"
|
||||
|
||||
prompt_manager_mod = types.SimpleNamespace(prompt_manager=_PromptManager())
|
||||
monkeypatch.setitem(sys.modules, "src.prompt.prompt_manager", prompt_manager_mod)
|
||||
|
||||
llm_options_mod = types.SimpleNamespace(LLMImageOptions=lambda **kwargs: types.SimpleNamespace(**kwargs))
|
||||
monkeypatch.setitem(sys.modules, "src.common.data_models.llm_service_data_models", llm_options_mod)
|
||||
|
||||
# If module already imported, reload it to apply patches
|
||||
mod_name = "src.chat.image_system.image_manager"
|
||||
if mod_name in sys.modules:
|
||||
importlib.reload(sys.modules[mod_name])
|
||||
|
||||
yield
|
||||
|
||||
|
||||
def _load_image_manager_module(tmp_path=None):
|
||||
repo_root = Path(__file__).parent.parent.parent
|
||||
file_path = repo_root / "src" / "chat" / "image_system" / "image_manager.py"
|
||||
spec = importlib.util.spec_from_file_location("image_manager_test_loaded", str(file_path))
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
sys.modules[spec.name] = mod
|
||||
spec.loader.exec_module(mod)
|
||||
# Redirect IMAGE_DIR to pytest's tmp_path when provided
|
||||
try:
|
||||
if tmp_path is not None:
|
||||
tmpdir = Path(tmp_path)
|
||||
tmpdir.mkdir(parents=True, exist_ok=True)
|
||||
mod.IMAGE_DIR = tmpdir
|
||||
except Exception:
|
||||
pass
|
||||
return mod
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_image_description_generates(tmp_path):
|
||||
image_manager = _load_image_manager_module(tmp_path)
|
||||
|
||||
mgr = image_manager.ImageManager()
|
||||
desc = await mgr.get_image_description(image_bytes=b"abc")
|
||||
assert desc == "dummy description"
|
||||
|
||||
|
||||
def test_get_image_from_db_none(tmp_path):
|
||||
image_manager = _load_image_manager_module(tmp_path)
|
||||
|
||||
mgr = image_manager.ImageManager()
|
||||
assert mgr.get_image_from_db("nohash") is None
|
||||
|
||||
|
||||
def test_register_image_to_db(tmp_path):
|
||||
image_manager = _load_image_manager_module(tmp_path)
|
||||
|
||||
mgr = image_manager.ImageManager()
|
||||
p = tmp_path / "img.png"
|
||||
p.write_bytes(b"data")
|
||||
img = DummyMaiImage(full_path=p, image_bytes=b"data")
|
||||
assert mgr.register_image_to_db(img) is True
|
||||
|
||||
|
||||
def test_update_image_description_not_found(tmp_path):
|
||||
image_manager = _load_image_manager_module(tmp_path)
|
||||
|
||||
mgr = image_manager.ImageManager()
|
||||
img = DummyMaiImage()
|
||||
img.file_hash = "nohash"
|
||||
img.description = "desc"
|
||||
assert mgr.update_image_description(img) is False
|
||||
|
||||
|
||||
def test_delete_image_not_found(tmp_path):
|
||||
image_manager = _load_image_manager_module(tmp_path)
|
||||
|
||||
mgr = image_manager.ImageManager()
|
||||
img = DummyMaiImage()
|
||||
img.file_hash = "nohash"
|
||||
img.full_path = tmp_path = None
|
||||
assert mgr.delete_image(img) is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_image_and_process_and_cleanup(tmp_path):
|
||||
image_manager = _load_image_manager_module(tmp_path)
|
||||
|
||||
mgr = image_manager.ImageManager()
|
||||
# call save_image_and_process
|
||||
image = await mgr.save_image_and_process(b"binarydata")
|
||||
assert getattr(image, "description", None) == "dummy description"
|
||||
|
||||
# cleanup should run without error
|
||||
mgr.cleanup_invalid_descriptions_in_db()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_image_description_returns_cached_description_after_session_closed(monkeypatch, tmp_path):
|
||||
image_manager = _load_image_manager_module(tmp_path)
|
||||
|
||||
cached_record = DetachedRecord()
|
||||
monkeypatch.setattr(image_manager, "get_db_session", lambda: DetachedRecordSession(cached_record))
|
||||
|
||||
mgr = image_manager.ImageManager()
|
||||
desc = await mgr.get_image_description(image_hash="cached-hash", wait_for_build=False)
|
||||
|
||||
assert desc == "cached description"
|
||||
91
pytests/image_sys_test/test_image_data_model.py
Normal file
91
pytests/image_sys_test/test_image_data_model.py
Normal file
@@ -0,0 +1,91 @@
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
import io
|
||||
|
||||
from PIL import Image as PILImage
|
||||
import pytest
|
||||
|
||||
from src.common.data_models.image_data_model import MaiEmoji, MaiImage
|
||||
|
||||
|
||||
def _build_test_image_bytes(image_format: str) -> bytes:
|
||||
image = PILImage.new("RGB", (8, 8), color="white")
|
||||
buffer = io.BytesIO()
|
||||
image.save(buffer, format=image_format)
|
||||
return buffer.getvalue()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_hash_format_updates_runtime_path_metadata(tmp_path: Path) -> None:
|
||||
image_bytes = _build_test_image_bytes("JPEG")
|
||||
tmp_file_path = tmp_path / "emoji.tmp"
|
||||
tmp_file_path.write_bytes(image_bytes)
|
||||
|
||||
emoji = MaiEmoji(full_path=tmp_file_path, image_bytes=image_bytes)
|
||||
|
||||
assert await emoji.calculate_hash_format() is True
|
||||
assert emoji.image_format == "jpeg"
|
||||
assert emoji.full_path.suffix == ".jpeg"
|
||||
assert emoji.file_name == emoji.full_path.name
|
||||
assert emoji.dir_path == tmp_path.resolve()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_hash_format_reuses_existing_target_file(tmp_path: Path) -> None:
|
||||
image_bytes = _build_test_image_bytes("JPEG")
|
||||
tmp_file_path = tmp_path / "emoji.tmp"
|
||||
target_file_path = tmp_path / "emoji.jpeg"
|
||||
tmp_file_path.write_bytes(image_bytes)
|
||||
target_file_path.write_bytes(image_bytes)
|
||||
|
||||
emoji = MaiEmoji(full_path=tmp_file_path, image_bytes=image_bytes)
|
||||
|
||||
assert await emoji.calculate_hash_format() is True
|
||||
assert emoji.full_path == target_file_path.resolve()
|
||||
assert emoji.file_name == target_file_path.name
|
||||
assert not tmp_file_path.exists()
|
||||
assert target_file_path.exists()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("model_cls", "extra_fields"),
|
||||
[
|
||||
(
|
||||
MaiEmoji,
|
||||
{
|
||||
"description": "",
|
||||
"last_used_time": None,
|
||||
"query_count": 0,
|
||||
"register_time": None,
|
||||
},
|
||||
),
|
||||
(
|
||||
MaiImage,
|
||||
{
|
||||
"description": "",
|
||||
"vlm_processed": False,
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_from_db_instance_restores_image_format_from_path(
|
||||
tmp_path: Path,
|
||||
model_cls: type[MaiEmoji] | type[MaiImage],
|
||||
extra_fields: dict[str, object],
|
||||
) -> None:
|
||||
image_path = tmp_path / "cached.png"
|
||||
image_path.write_bytes(_build_test_image_bytes("PNG"))
|
||||
|
||||
record = SimpleNamespace(
|
||||
no_file_flag=False,
|
||||
image_hash="hash",
|
||||
full_path=str(image_path),
|
||||
**extra_fields,
|
||||
)
|
||||
|
||||
image = model_cls.from_db_instance(record)
|
||||
|
||||
assert image.full_path == image_path.resolve()
|
||||
assert image.file_name == image_path.name
|
||||
assert image.image_format == "png"
|
||||
Reference in New Issue
Block a user