diff --git a/pytests/image_sys_test/image_manager_test.py b/pytests/image_sys_test/image_manager_test.py index ccc852df..9c4143f3 100644 --- a/pytests/image_sys_test/image_manager_test.py +++ b/pytests/image_sys_test/image_manager_test.py @@ -18,23 +18,33 @@ class DummyLogger: class DummySession: + def __init__(self): + self.record = None + def exec(self, *a, **k): + record = self.record + class R: def first(self): - return None + return record def yield_per(self, n): - return iter(()) + if record is None: + return iter(()) + return iter((record,)) return R() - def add(self, *a, **k): - pass + def add(self, record, *a, **k): + self.record = record def flush(self, *a, **k): pass def delete(self, *a, **k): + self.record = None + + def expunge(self, *a, **k): pass def __enter__(self): @@ -48,18 +58,35 @@ class DummyMaiImage: def __init__(self, full_path=None, image_bytes=None): self.full_path = full_path self.image_bytes = image_bytes + self.file_hash = "dummy-hash" self.image_format = "png" self.description = "" self.vlm_processed = False @classmethod def from_db_instance(cls, record): - return cls() + image = cls(full_path=getattr(record, "full_path", None)) + image.file_hash = getattr(record, "image_hash", "dummy-hash") + image.description = getattr(record, "description", "") + image.vlm_processed = getattr(record, "vlm_processed", False) + return image def to_db_instance(self): - return types.SimpleNamespace(id=1, full_path=str(self.full_path) if self.full_path is not None else "") + return types.SimpleNamespace( + description=self.description, + full_path=str(self.full_path) if self.full_path is not None else "", + id=1, + image_hash=self.file_hash, + image_type="image", + last_used_time=None, + no_file_flag=False, + query_count=0, + register_time=None, + vlm_processed=self.vlm_processed, + ) async def calculate_hash_format(self): + self.file_hash = "dummy-hash" return None @@ -71,6 +98,14 @@ class DummyLLMRequest: return ("dummy description", {}) +class DummyLLMServiceClient: + def __init__(self, *a, **k): + pass + + async def generate_response_for_image(self, prompt, image_base64, image_format, options=None): + return types.SimpleNamespace(response="dummy description") + + class DummySelect: def __init__(self, *a, **k): pass @@ -82,19 +117,58 @@ class DummySelect: return self +class DetachedRecord: + def __init__(self, description="cached description", vlm_processed=True): + self._detached = False + self._description = description + self._vlm_processed = vlm_processed + + @property + def description(self): + if not self._detached: + raise RuntimeError("attribute refresh operation cannot proceed") + return self._description + + @property + def vlm_processed(self): + if not self._detached: + raise RuntimeError("attribute refresh operation cannot proceed") + return self._vlm_processed + + +class DetachedRecordSession(DummySession): + def __init__(self, record): + self.record = record + + def exec(self, *a, **k): + record = self.record + + class R: + def first(self): + return record + + return R() + + def expunge(self, record): + record._detached = True + + @pytest.fixture(autouse=True) def patch_external_dependencies(monkeypatch): # Provide dummy implementations as modules so that importing image_manager is safe # Patch LLMRequest llm_mod = types.SimpleNamespace(LLMRequest=DummyLLMRequest) monkeypatch.setitem(sys.modules, "src.llm_models.utils_model", llm_mod) + llm_service_mod = types.SimpleNamespace(LLMServiceClient=DummyLLMServiceClient) + monkeypatch.setitem(sys.modules, "src.services.llm_service", llm_service_mod) # Patch logger logger_mod = types.SimpleNamespace(get_logger=lambda name: DummyLogger()) monkeypatch.setitem(sys.modules, "src.common.logger", logger_mod) # Patch DB session provider - db_mod = types.SimpleNamespace(get_db_session=lambda: DummySession()) + shared_session = DummySession() + db_mod = types.SimpleNamespace(get_db_session=lambda: shared_session) monkeypatch.setitem(sys.modules, "src.common.database.database", db_mod) # Patch database model types @@ -110,11 +184,13 @@ def patch_external_dependencies(monkeypatch): monkeypatch.setitem(sys.modules, "sqlmodel", sql_mod) # Patch config values used at import-time - cfg = types.SimpleNamespace(personality=types.SimpleNamespace(visual_style="test-style")) - model_cfg = types.SimpleNamespace(model_task_config=types.SimpleNamespace(vlm="test-vlm")) - config_mod = types.SimpleNamespace(global_config=cfg, model_config=model_cfg) + cfg = types.SimpleNamespace(visual=types.SimpleNamespace(visual_style="test-style")) + config_mod = types.SimpleNamespace(global_config=cfg) monkeypatch.setitem(sys.modules, "src.config.config", config_mod) + llm_options_mod = types.SimpleNamespace(LLMImageOptions=lambda **kwargs: types.SimpleNamespace(**kwargs)) + monkeypatch.setitem(sys.modules, "src.common.data_models.llm_service_data_models", llm_options_mod) + # If module already imported, reload it to apply patches mod_name = "src.chat.image_system.image_manager" if mod_name in sys.modules: @@ -198,3 +274,16 @@ async def test_save_image_and_process_and_cleanup(tmp_path): # cleanup should run without error mgr.cleanup_invalid_descriptions_in_db() + + +@pytest.mark.asyncio +async def test_get_image_description_returns_cached_description_after_session_closed(monkeypatch, tmp_path): + image_manager = _load_image_manager_module(tmp_path) + + cached_record = DetachedRecord() + monkeypatch.setattr(image_manager, "get_db_session", lambda: DetachedRecordSession(cached_record)) + + mgr = image_manager.ImageManager() + desc = await mgr.get_image_description(image_hash="cached-hash", wait_for_build=False) + + assert desc == "cached description" diff --git a/src/chat/image_system/image_manager.py b/src/chat/image_system/image_manager.py index 42364cb8..4ab7dbf5 100644 --- a/src/chat/image_system/image_manager.py +++ b/src/chat/image_system/image_manager.py @@ -49,7 +49,11 @@ class ImageManager: """根据哈希获取图片记录。""" with get_db_session() as session: statement = select(Images).filter_by(image_hash=image_hash, image_type=ImageType.IMAGE).limit(1) - return session.exec(statement).first() + record = session.exec(statement).first() + if record is not None: + # 返回会话外使用的只读记录,避免在会话关闭后触发属性刷新。 + session.expunge(record) + return record def _normalize_image_registration_fields(self, record: Images) -> bool: """Normalize accidental emoji registration fields on image records."""