Files
mai-bot/pytests/test_napcat_history_recovery.py

377 lines
12 KiB
Python

"""NapCat 历史补拉与恢复状态测试。"""
from __future__ import annotations
from importlib import import_module, util
from pathlib import Path
from typing import Any, Dict, List
import logging
import sys
from types import SimpleNamespace
import pytest
# Directory two levels above this file (the parent of the folder that holds it).
PROJECT_ROOT = Path(__file__).resolve().parents[1]
# Locations of installed plugins, plugin templates and the plugin SDK package.
PLUGINS_ROOT = PROJECT_ROOT.joinpath("plugins")
PLUGIN_TEMPLATE_ROOT = PROJECT_ROOT.joinpath("plugin-templates")
SDK_ROOT = PROJECT_ROOT.joinpath("packages", "maibot-plugin-sdk")
# The NapCat adapter may live either under plugins/ or under plugin-templates/.
NAPCAT_PLUGIN_DIR = PLUGINS_ROOT.joinpath("MaiBot-Napcat-Adapter")
NAPCAT_TEMPLATE_DIR = PLUGIN_TEMPLATE_ROOT.joinpath("MaiBot-Napcat-Adapter")
# Name under which the adapter package is registered in ``sys.modules`` for these tests.
NAPCAT_TEST_MODULE = "_test_napcat_adapter_history_recovery"

# Put the SDK directory at the front of the import path, once.
for import_path in [str(SDK_ROOT)]:
    if import_path in sys.path:
        continue
    sys.path.insert(0, import_path)
class _FakeGatewayCapability:
"""用于测试入站注入的网关替身。"""
def __init__(self) -> None:
"""初始化测试替身。"""
self.calls: List[Dict[str, Any]] = []
async def route_message(
self,
gateway_name: str,
message: Dict[str, Any],
*,
route_metadata: Dict[str, Any] | None = None,
external_message_id: str = "",
dedupe_key: str = "",
) -> bool:
"""记录入站注入请求并始终模拟成功。"""
self.calls.append(
{
"gateway_name": gateway_name,
"message": dict(message),
"route_metadata": dict(route_metadata or {}),
"external_message_id": external_message_id,
"dedupe_key": dedupe_key,
}
)
return True
def _resolve_napcat_plugin_dir() -> Path:
    """Return the NapCat plugin directory to use for the tests.

    The installed plugin directory wins when it exists as a directory;
    otherwise the template directory is returned (without checking that
    it exists).
    """
    return NAPCAT_PLUGIN_DIR if NAPCAT_PLUGIN_DIR.is_dir() else NAPCAT_TEMPLATE_DIR
def _load_napcat_module(module_suffix: str) -> Any:
    """Dynamically import a submodule of the NapCat plugin under the test package name.

    On first use the plugin's ``plugin.py`` is executed and registered in
    ``sys.modules`` as ``NAPCAT_TEST_MODULE``, with the plugin directory as
    its submodule search location. Later calls reuse that registration.

    Args:
        module_suffix: Dotted path of the submodule relative to the plugin
            package, e.g. ``"services.query_service"``.

    Returns:
        The imported submodule.

    Raises:
        ImportError: If no module spec or loader can be created for ``plugin.py``.
    """
    plugin_dir = _resolve_napcat_plugin_dir()
    qualified_name = f"{NAPCAT_TEST_MODULE}.{module_suffix}"

    # Package already registered by an earlier call: just import the submodule.
    if NAPCAT_TEST_MODULE in sys.modules:
        return import_module(qualified_name)

    plugin_path = plugin_dir / "plugin.py"
    package_spec = util.spec_from_file_location(
        NAPCAT_TEST_MODULE,
        plugin_path,
        submodule_search_locations=[str(plugin_dir)],
    )
    package_loader = None if package_spec is None else package_spec.loader
    if package_loader is None:
        raise ImportError(f"无法为 NapCat 插件创建模块规格: {plugin_path}")

    package_module = util.module_from_spec(package_spec)
    # Register before executing so imports made while plugin.py runs can find the package.
    sys.modules[NAPCAT_TEST_MODULE] = package_module
    try:
        package_loader.exec_module(package_module)
    except Exception:
        # Do not leave a half-initialised package behind.
        sys.modules.pop(NAPCAT_TEST_MODULE, None)
        raise

    return import_module(qualified_name)
def _load_history_recovery_store_cls() -> Any:
    """Dynamically load and return the ``NapCatHistoryRecoveryStore`` class."""
    store_module = _load_napcat_module("services.history_recovery_store")
    return store_module.NapCatHistoryRecoveryStore
def _load_query_service_cls() -> Any:
    """Dynamically load and return the ``NapCatQueryService`` class."""
    query_module = _load_napcat_module("services.query_service")
    return query_module.NapCatQueryService
def _load_router_cls() -> Any:
    """Dynamically load and return the ``NapCatEventRouter`` class."""
    router_module = _load_napcat_module("runtime.router")
    return router_module.NapCatEventRouter
class _FakeActionService:
    """Action-service stand-in handed to the query service under test."""

    def __init__(self, response_data: Any) -> None:
        """Remember the canned response and start an empty call log.

        Args:
            response_data: Value returned by every ``safe_call_action_data`` call.
        """
        self.action_data_calls: List[Dict[str, Any]] = []
        self._response_data = response_data

    async def safe_call_action_data(self, action_name: str, params: Dict[str, Any]) -> Any:
        """Log the requested action with a copy of its params, then return the canned response."""
        call_record = {"action_name": action_name, "params": dict(params)}
        self.action_data_calls.append(call_record)
        return self._response_data
@pytest.mark.asyncio
async def test_history_recovery_store_persists_checkpoint_and_seen_state(tmp_path: Path) -> None:
    """The history recovery store should persist checkpoints and recovered-message markers."""
    recovery_store = _load_history_recovery_store_cls()(
        logger=logging.getLogger("test.napcat.history_store"),
        storage_path=tmp_path / "history.sqlite3",
    )
    await recovery_store.load()

    # Every call below targets the same account / scope / chat.
    chat_key = {
        "account_id": "10001",
        "scope": "primary",
        "chat_type": "group",
        "chat_id": "20001",
    }

    # The newer message is recorded first, then an older one for the same chat.
    for message_id, message_time, message_seq in (("msg-2", 200.0, 2), ("msg-1", 100.0, 1)):
        await recovery_store.record_checkpoint(
            **chat_key,
            message_id=message_id,
            message_time=message_time,
            message_seq=message_seq,
        )

    await recovery_store.mark_recovered_message_seen(**chat_key, external_message_id="history-1")

    stored_checkpoints = await recovery_store.list_checkpoints("10001", scope="primary")
    assert len(stored_checkpoints) == 1
    # The single checkpoint must still point at the newer message.
    only_checkpoint = stored_checkpoints[0]
    assert only_checkpoint.last_message_id == "msg-2"
    assert only_checkpoint.last_message_seq == 2

    already_seen = await recovery_store.has_recovered_message_seen(
        **chat_key,
        external_message_id="history-1",
    )
    assert already_seen is True
@pytest.mark.asyncio
async def test_query_service_wraps_group_and_friend_history_actions() -> None:
    """The query service should wrap the history endpoints using the official action names."""
    canned_history = [{"message_id": "msg-1"}]
    action_service = _FakeActionService(canned_history)
    query_service = _load_query_service_cls()(
        action_service=action_service,
        logger=logging.getLogger("test.napcat.history_query"),
    )

    group_payload = await query_service.get_group_message_history("20001", message_seq=123, count=10)
    friend_payload = await query_service.get_friend_message_history("30001", count=5, reverse_order=True)

    # Both wrappers hand back the action service's response.
    assert group_payload == [{"message_id": "msg-1"}]
    assert friend_payload == [{"message_id": "msg-1"}]

    expected_group_call = {
        "action_name": "get_group_msg_history",
        "params": {"group_id": "20001", "count": 10, "reverse_order": False, "message_seq": 123},
    }
    expected_friend_call = {
        "action_name": "get_friend_msg_history",
        "params": {"user_id": "30001", "count": 5, "reverse_order": True},
    }
    assert action_service.action_data_calls == [expected_group_call, expected_friend_call]
@pytest.mark.asyncio
async def test_router_recover_recent_history_reinjects_messages_in_order(tmp_path: Path) -> None:
    """Reconnect recovery should re-inject history messages into the inbound path in time order."""
    store_cls = _load_history_recovery_store_cls()
    event_router_cls = _load_router_cls()

    gateway_capability = _FakeGatewayCapability()

    def _load_test_settings() -> SimpleNamespace:
        return SimpleNamespace(
            napcat_server=SimpleNamespace(connection_id="primary", heartbeat_interval=30.0),
            filters=SimpleNamespace(ignore_self_message=True),
            chat=SimpleNamespace(ban_qq_bot=False),
        )

    router = event_router_cls(
        gateway_capability=gateway_capability,
        logger=logging.getLogger("test.napcat.history_router"),
        gateway_name="napcat_gateway",
        load_settings=_load_test_settings,
    )

    # Seed one checkpoint (message_seq=10) for the group chat.
    history_store = store_cls(
        logger=logging.getLogger("test.napcat.history_router.store"),
        storage_path=tmp_path / "router.sqlite3",
    )
    await history_store.load()
    await history_store.record_checkpoint(
        account_id="10001",
        scope="primary",
        chat_type="group",
        chat_id="20001",
        message_id="msg-1",
        message_time=100.0,
        message_seq=10,
    )

    def _group_event(
        *,
        sender_id: str,
        nickname: str,
        message_id: str,
        message_seq: int,
        sent_at: int,
        text: str,
    ) -> Dict[str, Any]:
        """Build one group-message payload for the fake history response."""
        return {
            "post_type": "message",
            "message_type": "group",
            "self_id": "10001",
            "group_id": "20001",
            "user_id": sender_id,
            "message_id": message_id,
            "message_seq": message_seq,
            "time": sent_at,
            "message": [{"type": "text", "data": {"text": text}}],
            "sender": {"user_id": sender_id, "nickname": nickname},
        }

    history_calls: List[Dict[str, Any]] = []
    # Deliberately newest-first: msg-3 is listed before msg-2.
    history_payloads = [
        _group_event(
            sender_id="30002",
            nickname="用户二",
            message_id="msg-3",
            message_seq=12,
            sent_at=102,
            text="第三条",
        ),
        _group_event(
            sender_id="30001",
            nickname="用户一",
            message_id="msg-2",
            message_seq=11,
            sent_at=101,
            text="第二条",
        ),
    ]

    class _FakeQueryService:
        """Query service stand-in that logs group-history requests."""

        async def get_group_message_history(
            self,
            group_id: str,
            *,
            message_seq: int | None = None,
            count: int = 20,
            reverse_order: bool = False,
        ) -> List[Dict[str, Any]]:
            request_record = {
                "group_id": group_id,
                "message_seq": message_seq,
                "count": count,
                "reverse_order": reverse_order,
            }
            history_calls.append(request_record)
            return list(history_payloads)

        async def get_friend_message_history(self, user_id: str, **kwargs: Any) -> List[Dict[str, Any]]:
            del user_id, kwargs
            return []

    class _FakeInboundCodec:
        """Inbound codec stand-in that maps a payload to a minimal message dict."""

        @staticmethod
        async def build_message_dict(
            payload: Dict[str, Any],
            self_id: str,
            sender_user_id: str,
            sender: Dict[str, Any],
        ) -> Dict[str, Any]:
            del self_id, sender_user_id, sender
            plain_text = str(payload["message"][0]["data"]["text"])
            user_info = {"user_id": str(payload["user_id"]), "user_nickname": "测试用户"}
            group_info = {"group_id": str(payload["group_id"]), "group_name": "测试群"}
            return {
                "message_id": str(payload["message_id"]),
                "platform": "qq",
                "timestamp": str(float(payload["time"])),
                "message_info": {
                    "user_info": user_info,
                    "group_info": group_info,
                    "additional_config": {},
                },
                "raw_message": [{"type": "text", "data": plain_text}],
                "processed_plain_text": plain_text,
                "display_message": plain_text,
                "is_mentioned": False,
                "is_at": False,
                "is_emoji": False,
                "is_picture": False,
                "is_command": False,
                "is_notify": False,
                "session_id": "",
            }

    # Assemble the runtime handed to the router; the remaining parts are inert stubs.
    runtime_state = SimpleNamespace(
        report_connected=lambda *args, **kwargs: _noop_async(),
        report_disconnected=_noop_async,
    )
    chat_filter = SimpleNamespace(is_inbound_chat_allowed=lambda *args, **kwargs: True)
    official_bot_guard = SimpleNamespace(
        should_reject=lambda *args, **kwargs: _return_false_async(),
        clear_cache=lambda: None,
    )
    heartbeat_monitor = SimpleNamespace(start=_noop_async, stop=_noop_async)
    ban_tracker = SimpleNamespace(start=_noop_async, stop=_noop_async, record_notice=_noop_async)
    notice_codec = SimpleNamespace(
        handle_meta_event=_noop_async,
        build_notice_message_dict=_return_none_async,
    )
    router.bind_runtime(
        SimpleNamespace(
            runtime_state=runtime_state,
            chat_filter=chat_filter,
            official_bot_guard=official_bot_guard,
            inbound_codec=_FakeInboundCodec(),
            history_recovery_store=history_store,
            query_service=_FakeQueryService(),
            heartbeat_monitor=heartbeat_monitor,
            ban_tracker=ban_tracker,
            notice_codec=notice_codec,
        )
    )

    await router._recover_recent_history(self_id="10001", scope="primary")

    # Exactly one group-history request, made with the checkpoint's message_seq.
    expected_request = {
        "group_id": "20001",
        "message_seq": 10,
        "count": 20,
        "reverse_order": False,
    }
    assert history_calls == [expected_request]

    # The gateway must receive the messages oldest-first.
    expected_order = ["msg-2", "msg-3"]
    routed_calls = gateway_capability.calls
    assert [call["external_message_id"] for call in routed_calls] == expected_order
    assert [call["message"]["message_id"] for call in routed_calls] == expected_order
async def _noop_async(*args: Any, **kwargs: Any) -> None:
    """Accept any arguments, do nothing, and resolve to ``None``."""
    del args, kwargs
async def _return_false_async(*args: Any, **kwargs: Any) -> bool:
    """Async test double that ignores its arguments and resolves to ``False``."""
    del args, kwargs
    return False
async def _return_none_async(*args: Any, **kwargs: Any) -> None:
    """Async test double that ignores its arguments and resolves to ``None``."""
    del args, kwargs
    return None