feat: 添加 NapCat 适配器的入站消息编解码功能,增强插件配置更新逻辑和数据库交互测试
This commit is contained in:
81
pytests/common_test/test_expression_learner.py
Normal file
81
pytests/common_test/test_expression_learner.py
Normal file
@@ -0,0 +1,81 @@
|
||||
"""测试表达方式学习器的数据库读取行为。"""
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import Generator
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from sqlmodel import Session, SQLModel, create_engine
|
||||
|
||||
from src.bw_learner.expression_learner import ExpressionLearner
|
||||
from src.common.database.database_model import Expression
|
||||
|
||||
|
||||
@pytest.fixture(name="expression_learner_engine")
|
||||
def expression_learner_engine_fixture() -> Generator:
|
||||
"""创建用于表达方式学习器测试的内存数据库引擎。
|
||||
|
||||
Yields:
|
||||
Generator: 供测试使用的 SQLite 内存引擎。
|
||||
"""
|
||||
engine = create_engine(
|
||||
"sqlite://",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
SQLModel.metadata.create_all(engine)
|
||||
yield engine
|
||||
|
||||
|
||||
def test_find_similar_expression_uses_read_only_session_and_history_content(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
expression_learner_engine,
|
||||
) -> None:
|
||||
"""查找相似表达方式时,应能在离开会话后安全使用结果,并比较历史情景内容。"""
|
||||
import src.bw_learner.expression_learner as expression_learner_module
|
||||
|
||||
with Session(expression_learner_engine) as session:
|
||||
session.add(
|
||||
Expression(
|
||||
situation="发送汗滴表情",
|
||||
style="发送💦表情符号",
|
||||
content_list='["表达情绪高涨或生理反应"]',
|
||||
count=1,
|
||||
session_id="session-a",
|
||||
checked=False,
|
||||
rejected=False,
|
||||
)
|
||||
)
|
||||
session.commit()
|
||||
|
||||
@contextmanager
|
||||
def fake_get_db_session(auto_commit: bool = True) -> Generator[Session, None, None]:
|
||||
"""构造带自动提交语义的测试会话工厂。
|
||||
|
||||
Args:
|
||||
auto_commit: 退出上下文时是否自动提交。
|
||||
|
||||
Yields:
|
||||
Generator[Session, None, None]: SQLModel 会话对象。
|
||||
"""
|
||||
session = Session(expression_learner_engine)
|
||||
try:
|
||||
yield session
|
||||
if auto_commit:
|
||||
session.commit()
|
||||
except Exception:
|
||||
session.rollback()
|
||||
raise
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
monkeypatch.setattr(expression_learner_module, "get_db_session", fake_get_db_session)
|
||||
|
||||
learner = ExpressionLearner(session_id="session-a")
|
||||
result = learner._find_similar_expression("表达情绪高涨或生理反应")
|
||||
|
||||
assert result is not None
|
||||
expression, similarity = result
|
||||
assert expression.item_id is not None
|
||||
assert expression.style == "发送💦表情符号"
|
||||
assert similarity == pytest.approx(1.0)
|
||||
90
pytests/common_test/test_jargon_miner.py
Normal file
90
pytests/common_test/test_jargon_miner.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""测试黑话学习器的数据库读取行为。"""
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import Generator
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from sqlmodel import Session, SQLModel, create_engine, select
|
||||
|
||||
from src.bw_learner.jargon_miner import JargonMiner
|
||||
from src.common.database.database_model import Jargon
|
||||
|
||||
|
||||
@pytest.fixture(name="jargon_miner_engine")
|
||||
def jargon_miner_engine_fixture() -> Generator:
|
||||
"""创建用于黑话学习器测试的内存数据库引擎。
|
||||
|
||||
Yields:
|
||||
Generator: 供测试使用的 SQLite 内存引擎。
|
||||
"""
|
||||
engine = create_engine(
|
||||
"sqlite://",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
SQLModel.metadata.create_all(engine)
|
||||
yield engine
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_extracted_entries_updates_existing_jargon_without_detached_session(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
jargon_miner_engine,
|
||||
) -> None:
|
||||
"""更新已有黑话时,不应因会话关闭导致 ORM 实例失效。"""
|
||||
import src.bw_learner.jargon_miner as jargon_miner_module
|
||||
|
||||
with Session(jargon_miner_engine) as session:
|
||||
session.add(
|
||||
Jargon(
|
||||
content="VF8V4L",
|
||||
raw_content='["[1] first"]',
|
||||
meaning="",
|
||||
session_id_dict='{"session-a": 1}',
|
||||
count=0,
|
||||
is_jargon=True,
|
||||
is_complete=False,
|
||||
is_global=False,
|
||||
last_inference_count=0,
|
||||
)
|
||||
)
|
||||
session.commit()
|
||||
|
||||
@contextmanager
|
||||
def fake_get_db_session(auto_commit: bool = True) -> Generator[Session, None, None]:
|
||||
"""构造带自动提交语义的测试会话工厂。
|
||||
|
||||
Args:
|
||||
auto_commit: 退出上下文时是否自动提交。
|
||||
|
||||
Yields:
|
||||
Generator[Session, None, None]: SQLModel 会话对象。
|
||||
"""
|
||||
session = Session(jargon_miner_engine)
|
||||
try:
|
||||
yield session
|
||||
if auto_commit:
|
||||
session.commit()
|
||||
except Exception:
|
||||
session.rollback()
|
||||
raise
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
monkeypatch.setattr(jargon_miner_module, "get_db_session", fake_get_db_session)
|
||||
|
||||
jargon_miner = JargonMiner(session_id="session-a", session_name="测试群")
|
||||
await jargon_miner.process_extracted_entries(
|
||||
[{"content": "VF8V4L", "raw_content": {"[2] second"}}],
|
||||
)
|
||||
|
||||
with Session(jargon_miner_engine) as session:
|
||||
db_jargon = session.exec(select(Jargon).where(Jargon.content == "VF8V4L")).one()
|
||||
|
||||
assert db_jargon.count == 1
|
||||
assert db_jargon.session_id_dict == '{"session-a": 2}'
|
||||
assert sorted(db_jargon.raw_content and __import__("json").loads(db_jargon.raw_content)) == [
|
||||
"[1] first",
|
||||
"[2] second",
|
||||
]
|
||||
@@ -3,12 +3,16 @@ from typing import Any, Dict
|
||||
|
||||
import importlib
|
||||
import sys
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
BUILT_IN_PLUGIN_ROOT = Path(__file__).resolve().parents[1] / "src" / "plugins" / "built_in"
|
||||
if str(BUILT_IN_PLUGIN_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(BUILT_IN_PLUGIN_ROOT))
|
||||
|
||||
NapCatInboundCodec = importlib.import_module("napcat_adapter.codec_inbound").NapCatInboundCodec
|
||||
NapCatOutboundCodec = importlib.import_module("napcat_adapter.codec_outbound").NapCatOutboundCodec
|
||||
|
||||
|
||||
@@ -68,3 +72,80 @@ def test_napcat_outbound_codec_builds_private_action_from_route_metadata() -> No
|
||||
|
||||
assert action_name == "send_private_msg"
|
||||
assert params == {"message": [{"type": "text", "data": {"text": "hello"}}], "user_id": "30001"}
|
||||
|
||||
|
||||
class DummyQueryService:
|
||||
"""用于测试的轻量查询服务。"""
|
||||
|
||||
async def download_binary(self, url: str) -> bytes:
|
||||
"""返回固定图片二进制。
|
||||
|
||||
Args:
|
||||
url: 图片地址。
|
||||
|
||||
Returns:
|
||||
bytes: 固定测试图片二进制。
|
||||
"""
|
||||
if url:
|
||||
return b"image-bytes"
|
||||
return b""
|
||||
|
||||
async def get_message_detail(self, message_id: str) -> Dict[str, Any] | None:
|
||||
"""返回空消息详情。
|
||||
|
||||
Args:
|
||||
message_id: 目标消息 ID。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any] | None: 固定空结果。
|
||||
"""
|
||||
del message_id
|
||||
return None
|
||||
|
||||
async def get_record_detail(self, file_name: str, file_id: str | None = None) -> Dict[str, Any] | None:
|
||||
"""返回空语音详情。
|
||||
|
||||
Args:
|
||||
file_name: 语音文件名。
|
||||
file_id: 可选文件 ID。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any] | None: 固定空结果。
|
||||
"""
|
||||
del file_name
|
||||
del file_id
|
||||
return None
|
||||
|
||||
async def get_forward_message(self, message_id: str) -> Dict[str, Any] | None:
|
||||
"""返回空转发详情。
|
||||
|
||||
Args:
|
||||
message_id: 转发消息 ID。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any] | None: 固定空结果。
|
||||
"""
|
||||
del message_id
|
||||
return None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_napcat_inbound_codec_parses_cq_string_image_segments() -> None:
|
||||
codec = NapCatInboundCodec(SimpleNamespace(debug=lambda message: None), DummyQueryService())
|
||||
payload = {
|
||||
"message": "[CQ:image,file=test.png,sub_type=0,url=https://example.com/test.png][CQ:at,qq=10001] 看到是国人直接给你封了",
|
||||
}
|
||||
|
||||
raw_message, is_at = await codec.convert_segments(payload, "10001")
|
||||
|
||||
assert raw_message[0]["type"] == "image"
|
||||
assert raw_message[1] == {
|
||||
"type": "at",
|
||||
"data": {
|
||||
"target_user_id": "10001",
|
||||
"target_user_nickname": None,
|
||||
"target_user_cardname": None,
|
||||
},
|
||||
}
|
||||
assert raw_message[2] == {"type": "text", "data": " 看到是国人直接给你封了"}
|
||||
assert is_at is True
|
||||
|
||||
60
pytests/test_napcat_adapter_plugin.py
Normal file
60
pytests/test_napcat_adapter_plugin.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""NapCat 插件入口行为测试。"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from types import SimpleNamespace
|
||||
|
||||
import importlib
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
BUILT_IN_PLUGIN_ROOT = Path(__file__).resolve().parents[1] / "src" / "plugins" / "built_in"
|
||||
if str(BUILT_IN_PLUGIN_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(BUILT_IN_PLUGIN_ROOT))
|
||||
|
||||
NapCatAdapterPlugin = importlib.import_module("napcat_adapter.plugin").NapCatAdapterPlugin
|
||||
|
||||
|
||||
class DummyLogger:
|
||||
"""用于测试的轻量日志对象。"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""初始化测试日志对象。"""
|
||||
self.debug_messages: List[str] = []
|
||||
|
||||
def debug(self, message: str) -> None:
|
||||
"""记录调试日志。
|
||||
|
||||
Args:
|
||||
message: 待记录的日志内容。
|
||||
"""
|
||||
self.debug_messages.append(message)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_config_update_refreshes_settings_and_restarts(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""配置更新时应刷新插件配置、清空旧 settings,并触发连接重启。"""
|
||||
plugin = NapCatAdapterPlugin()
|
||||
plugin._ctx = SimpleNamespace(logger=DummyLogger())
|
||||
plugin._settings = object()
|
||||
|
||||
restart_calls: List[dict] = []
|
||||
|
||||
async def fake_restart() -> None:
|
||||
"""记录一次重启调用。"""
|
||||
restart_calls.append(dict(plugin._plugin_config))
|
||||
|
||||
monkeypatch.setattr(plugin, "_restart_connection_if_needed", fake_restart)
|
||||
|
||||
new_config = {
|
||||
"plugin": {"enabled": True, "config_version": "0.1.0"},
|
||||
"napcat_server": {"host": "127.0.0.1", "port": 3001},
|
||||
}
|
||||
await plugin.on_config_update(new_config, "v2")
|
||||
|
||||
assert plugin._plugin_config == new_config
|
||||
assert plugin._settings is None
|
||||
assert restart_calls == [new_config]
|
||||
assert plugin.ctx.logger.debug_messages == ["NapCat 适配器收到配置更新通知: v2"]
|
||||
@@ -2238,6 +2238,7 @@ class TestIntegration:
|
||||
async def test_handle_plugin_source_changes_only_reload_matching_supervisor(self, monkeypatch, tmp_path):
|
||||
from src.config.file_watcher import FileChange
|
||||
from src.plugin_runtime import integration as integration_module
|
||||
import json
|
||||
|
||||
builtin_root = tmp_path / "src" / "plugins" / "built_in"
|
||||
thirdparty_root = tmp_path / "plugins"
|
||||
@@ -2247,6 +2248,10 @@ class TestIntegration:
|
||||
beta_dir.mkdir(parents=True)
|
||||
(alpha_dir / "config.toml").write_text("enabled = true\n", encoding="utf-8")
|
||||
(beta_dir / "config.toml").write_text("enabled = false\n", encoding="utf-8")
|
||||
(alpha_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8")
|
||||
(beta_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8")
|
||||
(alpha_dir / "_manifest.json").write_text(json.dumps({"name": "alpha"}), encoding="utf-8")
|
||||
(beta_dir / "_manifest.json").write_text(json.dumps({"name": "beta"}), encoding="utf-8")
|
||||
|
||||
monkeypatch.chdir(tmp_path)
|
||||
|
||||
@@ -2257,8 +2262,8 @@ class TestIntegration:
|
||||
self.reload_reasons = []
|
||||
self.config_updates = []
|
||||
|
||||
async def reload_plugins(self, reason="manual"):
|
||||
self.reload_reasons.append(reason)
|
||||
async def reload_plugins(self, plugin_ids=None, reason="manual"):
|
||||
self.reload_reasons.append((plugin_ids, reason))
|
||||
|
||||
async def notify_plugin_config_updated(self, plugin_id, config_data, config_version=""):
|
||||
self.config_updates.append((plugin_id, config_data, config_version))
|
||||
@@ -2283,13 +2288,13 @@ class TestIntegration:
|
||||
await manager._handle_plugin_source_changes(changes)
|
||||
|
||||
assert manager._builtin_supervisor.reload_reasons == []
|
||||
assert manager._third_party_supervisor.reload_reasons == ["file_watcher"]
|
||||
assert manager._third_party_supervisor.reload_reasons == [(["beta"], "file_watcher")]
|
||||
assert manager._builtin_supervisor.config_updates == []
|
||||
assert manager._third_party_supervisor.config_updates == []
|
||||
assert refresh_calls == [True]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_plugin_config_changes_only_notify_target_plugin(self, monkeypatch, tmp_path):
|
||||
async def test_handle_plugin_config_changes_only_reload_target_plugin(self, monkeypatch, tmp_path):
|
||||
from src.plugin_runtime import integration as integration_module
|
||||
from src.config.file_watcher import FileChange
|
||||
|
||||
@@ -2308,27 +2313,35 @@ class TestIntegration:
|
||||
def __init__(self, plugin_dirs, plugins):
|
||||
self._plugin_dirs = plugin_dirs
|
||||
self._registered_plugins = {plugin_id: object() for plugin_id in plugins}
|
||||
self.config_updates = []
|
||||
self.reload_calls = []
|
||||
|
||||
async def notify_plugin_config_updated(self, plugin_id, config_data, config_version=""):
|
||||
self.config_updates.append((plugin_id, config_data, config_version))
|
||||
async def reload_plugin(self, plugin_id, reason="manual"):
|
||||
self.reload_calls.append((plugin_id, reason))
|
||||
return True
|
||||
|
||||
manager = integration_module.PluginRuntimeManager()
|
||||
manager._started = True
|
||||
manager._builtin_supervisor = FakeSupervisor([builtin_root], ["alpha"])
|
||||
manager._third_party_supervisor = FakeSupervisor([thirdparty_root], ["beta"])
|
||||
refresh_calls = []
|
||||
|
||||
def fake_refresh() -> None:
|
||||
refresh_calls.append(True)
|
||||
|
||||
manager._refresh_plugin_config_watch_subscriptions = fake_refresh
|
||||
|
||||
await manager._handle_plugin_config_changes(
|
||||
"alpha",
|
||||
[FileChange(change_type=1, path=alpha_dir / "config.toml")],
|
||||
)
|
||||
|
||||
assert manager._builtin_supervisor.config_updates == [("alpha", {"enabled": True}, "")]
|
||||
assert manager._third_party_supervisor.config_updates == []
|
||||
assert manager._builtin_supervisor.reload_calls == [("alpha", "config_file_changed")]
|
||||
assert manager._third_party_supervisor.reload_calls == []
|
||||
assert refresh_calls == [True]
|
||||
|
||||
def test_refresh_plugin_config_watch_subscriptions_registers_per_plugin(self, tmp_path):
|
||||
from src.plugin_runtime import integration as integration_module
|
||||
import json
|
||||
|
||||
builtin_root = tmp_path / "src" / "plugins" / "built_in"
|
||||
thirdparty_root = tmp_path / "plugins"
|
||||
@@ -2336,6 +2349,10 @@ class TestIntegration:
|
||||
beta_dir = thirdparty_root / "beta"
|
||||
alpha_dir.mkdir(parents=True)
|
||||
beta_dir.mkdir(parents=True)
|
||||
(alpha_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8")
|
||||
(beta_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8")
|
||||
(alpha_dir / "_manifest.json").write_text(json.dumps({"name": "alpha"}), encoding="utf-8")
|
||||
(beta_dir / "_manifest.json").write_text(json.dumps({"name": "beta"}), encoding="utf-8")
|
||||
|
||||
class FakeWatcher:
|
||||
def __init__(self):
|
||||
|
||||
Reference in New Issue
Block a user