91 lines
2.8 KiB
Python
91 lines
2.8 KiB
Python
"""测试黑话学习器的数据库读取行为。"""
|
|
|
|
from contextlib import contextmanager
|
|
from typing import Generator
|
|
|
|
import pytest
|
|
from sqlalchemy.pool import StaticPool
|
|
from sqlmodel import Session, SQLModel, create_engine, select
|
|
|
|
from src.bw_learner.jargon_miner import JargonMiner
|
|
from src.common.database.database_model import Jargon
|
|
|
|
|
|
@pytest.fixture(name="jargon_miner_engine")
|
|
def jargon_miner_engine_fixture() -> Generator:
|
|
"""创建用于黑话学习器测试的内存数据库引擎。
|
|
|
|
Yields:
|
|
Generator: 供测试使用的 SQLite 内存引擎。
|
|
"""
|
|
engine = create_engine(
|
|
"sqlite://",
|
|
connect_args={"check_same_thread": False},
|
|
poolclass=StaticPool,
|
|
)
|
|
SQLModel.metadata.create_all(engine)
|
|
yield engine
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_process_extracted_entries_updates_existing_jargon_without_detached_session(
|
|
monkeypatch: pytest.MonkeyPatch,
|
|
jargon_miner_engine,
|
|
) -> None:
|
|
"""更新已有黑话时,不应因会话关闭导致 ORM 实例失效。"""
|
|
import src.bw_learner.jargon_miner as jargon_miner_module
|
|
|
|
with Session(jargon_miner_engine) as session:
|
|
session.add(
|
|
Jargon(
|
|
content="VF8V4L",
|
|
raw_content='["[1] first"]',
|
|
meaning="",
|
|
session_id_dict='{"session-a": 1}',
|
|
count=0,
|
|
is_jargon=True,
|
|
is_complete=False,
|
|
is_global=False,
|
|
last_inference_count=0,
|
|
)
|
|
)
|
|
session.commit()
|
|
|
|
@contextmanager
|
|
def fake_get_db_session(auto_commit: bool = True) -> Generator[Session, None, None]:
|
|
"""构造带自动提交语义的测试会话工厂。
|
|
|
|
Args:
|
|
auto_commit: 退出上下文时是否自动提交。
|
|
|
|
Yields:
|
|
Generator[Session, None, None]: SQLModel 会话对象。
|
|
"""
|
|
session = Session(jargon_miner_engine)
|
|
try:
|
|
yield session
|
|
if auto_commit:
|
|
session.commit()
|
|
except Exception:
|
|
session.rollback()
|
|
raise
|
|
finally:
|
|
session.close()
|
|
|
|
monkeypatch.setattr(jargon_miner_module, "get_db_session", fake_get_db_session)
|
|
|
|
jargon_miner = JargonMiner(session_id="session-a", session_name="测试群")
|
|
await jargon_miner.process_extracted_entries(
|
|
[{"content": "VF8V4L", "raw_content": {"[2] second"}}],
|
|
)
|
|
|
|
with Session(jargon_miner_engine) as session:
|
|
db_jargon = session.exec(select(Jargon).where(Jargon.content == "VF8V4L")).one()
|
|
|
|
assert db_jargon.count == 1
|
|
assert db_jargon.session_id_dict == '{"session-a": 2}'
|
|
assert sorted(db_jargon.raw_content and __import__("json").loads(db_jargon.raw_content)) == [
|
|
"[1] first",
|
|
"[2] second",
|
|
]
|