fix:修复数据库问题
This commit is contained in:
@@ -18,23 +18,33 @@ class DummyLogger:
|
|||||||
|
|
||||||
|
|
||||||
class DummySession:
|
class DummySession:
|
||||||
|
def __init__(self):
|
||||||
|
self.record = None
|
||||||
|
|
||||||
def exec(self, *a, **k):
|
def exec(self, *a, **k):
|
||||||
|
record = self.record
|
||||||
|
|
||||||
class R:
|
class R:
|
||||||
def first(self):
|
def first(self):
|
||||||
return None
|
return record
|
||||||
|
|
||||||
def yield_per(self, n):
|
def yield_per(self, n):
|
||||||
return iter(())
|
if record is None:
|
||||||
|
return iter(())
|
||||||
|
return iter((record,))
|
||||||
|
|
||||||
return R()
|
return R()
|
||||||
|
|
||||||
def add(self, *a, **k):
|
def add(self, record, *a, **k):
|
||||||
pass
|
self.record = record
|
||||||
|
|
||||||
def flush(self, *a, **k):
|
def flush(self, *a, **k):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def delete(self, *a, **k):
|
def delete(self, *a, **k):
|
||||||
|
self.record = None
|
||||||
|
|
||||||
|
def expunge(self, *a, **k):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
@@ -48,18 +58,35 @@ class DummyMaiImage:
|
|||||||
def __init__(self, full_path=None, image_bytes=None):
|
def __init__(self, full_path=None, image_bytes=None):
|
||||||
self.full_path = full_path
|
self.full_path = full_path
|
||||||
self.image_bytes = image_bytes
|
self.image_bytes = image_bytes
|
||||||
|
self.file_hash = "dummy-hash"
|
||||||
self.image_format = "png"
|
self.image_format = "png"
|
||||||
self.description = ""
|
self.description = ""
|
||||||
self.vlm_processed = False
|
self.vlm_processed = False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_db_instance(cls, record):
|
def from_db_instance(cls, record):
|
||||||
return cls()
|
image = cls(full_path=getattr(record, "full_path", None))
|
||||||
|
image.file_hash = getattr(record, "image_hash", "dummy-hash")
|
||||||
|
image.description = getattr(record, "description", "")
|
||||||
|
image.vlm_processed = getattr(record, "vlm_processed", False)
|
||||||
|
return image
|
||||||
|
|
||||||
def to_db_instance(self):
|
def to_db_instance(self):
|
||||||
return types.SimpleNamespace(id=1, full_path=str(self.full_path) if self.full_path is not None else "")
|
return types.SimpleNamespace(
|
||||||
|
description=self.description,
|
||||||
|
full_path=str(self.full_path) if self.full_path is not None else "",
|
||||||
|
id=1,
|
||||||
|
image_hash=self.file_hash,
|
||||||
|
image_type="image",
|
||||||
|
last_used_time=None,
|
||||||
|
no_file_flag=False,
|
||||||
|
query_count=0,
|
||||||
|
register_time=None,
|
||||||
|
vlm_processed=self.vlm_processed,
|
||||||
|
)
|
||||||
|
|
||||||
async def calculate_hash_format(self):
|
async def calculate_hash_format(self):
|
||||||
|
self.file_hash = "dummy-hash"
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
@@ -71,6 +98,14 @@ class DummyLLMRequest:
|
|||||||
return ("dummy description", {})
|
return ("dummy description", {})
|
||||||
|
|
||||||
|
|
||||||
|
class DummyLLMServiceClient:
|
||||||
|
def __init__(self, *a, **k):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def generate_response_for_image(self, prompt, image_base64, image_format, options=None):
|
||||||
|
return types.SimpleNamespace(response="dummy description")
|
||||||
|
|
||||||
|
|
||||||
class DummySelect:
|
class DummySelect:
|
||||||
def __init__(self, *a, **k):
|
def __init__(self, *a, **k):
|
||||||
pass
|
pass
|
||||||
@@ -82,19 +117,58 @@ class DummySelect:
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
class DetachedRecord:
|
||||||
|
def __init__(self, description="cached description", vlm_processed=True):
|
||||||
|
self._detached = False
|
||||||
|
self._description = description
|
||||||
|
self._vlm_processed = vlm_processed
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self):
|
||||||
|
if not self._detached:
|
||||||
|
raise RuntimeError("attribute refresh operation cannot proceed")
|
||||||
|
return self._description
|
||||||
|
|
||||||
|
@property
|
||||||
|
def vlm_processed(self):
|
||||||
|
if not self._detached:
|
||||||
|
raise RuntimeError("attribute refresh operation cannot proceed")
|
||||||
|
return self._vlm_processed
|
||||||
|
|
||||||
|
|
||||||
|
class DetachedRecordSession(DummySession):
|
||||||
|
def __init__(self, record):
|
||||||
|
self.record = record
|
||||||
|
|
||||||
|
def exec(self, *a, **k):
|
||||||
|
record = self.record
|
||||||
|
|
||||||
|
class R:
|
||||||
|
def first(self):
|
||||||
|
return record
|
||||||
|
|
||||||
|
return R()
|
||||||
|
|
||||||
|
def expunge(self, record):
|
||||||
|
record._detached = True
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def patch_external_dependencies(monkeypatch):
|
def patch_external_dependencies(monkeypatch):
|
||||||
# Provide dummy implementations as modules so that importing image_manager is safe
|
# Provide dummy implementations as modules so that importing image_manager is safe
|
||||||
# Patch LLMRequest
|
# Patch LLMRequest
|
||||||
llm_mod = types.SimpleNamespace(LLMRequest=DummyLLMRequest)
|
llm_mod = types.SimpleNamespace(LLMRequest=DummyLLMRequest)
|
||||||
monkeypatch.setitem(sys.modules, "src.llm_models.utils_model", llm_mod)
|
monkeypatch.setitem(sys.modules, "src.llm_models.utils_model", llm_mod)
|
||||||
|
llm_service_mod = types.SimpleNamespace(LLMServiceClient=DummyLLMServiceClient)
|
||||||
|
monkeypatch.setitem(sys.modules, "src.services.llm_service", llm_service_mod)
|
||||||
|
|
||||||
# Patch logger
|
# Patch logger
|
||||||
logger_mod = types.SimpleNamespace(get_logger=lambda name: DummyLogger())
|
logger_mod = types.SimpleNamespace(get_logger=lambda name: DummyLogger())
|
||||||
monkeypatch.setitem(sys.modules, "src.common.logger", logger_mod)
|
monkeypatch.setitem(sys.modules, "src.common.logger", logger_mod)
|
||||||
|
|
||||||
# Patch DB session provider
|
# Patch DB session provider
|
||||||
db_mod = types.SimpleNamespace(get_db_session=lambda: DummySession())
|
shared_session = DummySession()
|
||||||
|
db_mod = types.SimpleNamespace(get_db_session=lambda: shared_session)
|
||||||
monkeypatch.setitem(sys.modules, "src.common.database.database", db_mod)
|
monkeypatch.setitem(sys.modules, "src.common.database.database", db_mod)
|
||||||
|
|
||||||
# Patch database model types
|
# Patch database model types
|
||||||
@@ -110,11 +184,13 @@ def patch_external_dependencies(monkeypatch):
|
|||||||
monkeypatch.setitem(sys.modules, "sqlmodel", sql_mod)
|
monkeypatch.setitem(sys.modules, "sqlmodel", sql_mod)
|
||||||
|
|
||||||
# Patch config values used at import-time
|
# Patch config values used at import-time
|
||||||
cfg = types.SimpleNamespace(personality=types.SimpleNamespace(visual_style="test-style"))
|
cfg = types.SimpleNamespace(visual=types.SimpleNamespace(visual_style="test-style"))
|
||||||
model_cfg = types.SimpleNamespace(model_task_config=types.SimpleNamespace(vlm="test-vlm"))
|
config_mod = types.SimpleNamespace(global_config=cfg)
|
||||||
config_mod = types.SimpleNamespace(global_config=cfg, model_config=model_cfg)
|
|
||||||
monkeypatch.setitem(sys.modules, "src.config.config", config_mod)
|
monkeypatch.setitem(sys.modules, "src.config.config", config_mod)
|
||||||
|
|
||||||
|
llm_options_mod = types.SimpleNamespace(LLMImageOptions=lambda **kwargs: types.SimpleNamespace(**kwargs))
|
||||||
|
monkeypatch.setitem(sys.modules, "src.common.data_models.llm_service_data_models", llm_options_mod)
|
||||||
|
|
||||||
# If module already imported, reload it to apply patches
|
# If module already imported, reload it to apply patches
|
||||||
mod_name = "src.chat.image_system.image_manager"
|
mod_name = "src.chat.image_system.image_manager"
|
||||||
if mod_name in sys.modules:
|
if mod_name in sys.modules:
|
||||||
@@ -198,3 +274,16 @@ async def test_save_image_and_process_and_cleanup(tmp_path):
|
|||||||
|
|
||||||
# cleanup should run without error
|
# cleanup should run without error
|
||||||
mgr.cleanup_invalid_descriptions_in_db()
|
mgr.cleanup_invalid_descriptions_in_db()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_image_description_returns_cached_description_after_session_closed(monkeypatch, tmp_path):
|
||||||
|
image_manager = _load_image_manager_module(tmp_path)
|
||||||
|
|
||||||
|
cached_record = DetachedRecord()
|
||||||
|
monkeypatch.setattr(image_manager, "get_db_session", lambda: DetachedRecordSession(cached_record))
|
||||||
|
|
||||||
|
mgr = image_manager.ImageManager()
|
||||||
|
desc = await mgr.get_image_description(image_hash="cached-hash", wait_for_build=False)
|
||||||
|
|
||||||
|
assert desc == "cached description"
|
||||||
|
|||||||
@@ -49,7 +49,11 @@ class ImageManager:
|
|||||||
"""根据哈希获取图片记录。"""
|
"""根据哈希获取图片记录。"""
|
||||||
with get_db_session() as session:
|
with get_db_session() as session:
|
||||||
statement = select(Images).filter_by(image_hash=image_hash, image_type=ImageType.IMAGE).limit(1)
|
statement = select(Images).filter_by(image_hash=image_hash, image_type=ImageType.IMAGE).limit(1)
|
||||||
return session.exec(statement).first()
|
record = session.exec(statement).first()
|
||||||
|
if record is not None:
|
||||||
|
# 返回会话外使用的只读记录,避免在会话关闭后触发属性刷新。
|
||||||
|
session.expunge(record)
|
||||||
|
return record
|
||||||
|
|
||||||
def _normalize_image_registration_fields(self, record: Images) -> bool:
|
def _normalize_image_registration_fields(self, record: Images) -> bool:
|
||||||
"""Normalize accidental emoji registration fields on image records."""
|
"""Normalize accidental emoji registration fields on image records."""
|
||||||
|
|||||||
Reference in New Issue
Block a user