From a566313c5f0b1355d1d99fa9f3cfc10d0a455e08 Mon Sep 17 00:00:00 2001 From: LoveLosita <2810873701@qq.com> Date: Mon, 11 May 2026 15:03:36 +0800 Subject: [PATCH] =?UTF-8?q?feat:=E6=96=B0=E5=A2=9Enapcat=E6=96=AD=E7=BA=BF?= =?UTF-8?q?=E5=90=8E=E9=87=8D=E8=BF=9E=E9=87=8D=E6=96=B0=E6=8B=89=E5=8F=96?= =?UTF-8?q?=E5=8E=86=E5=8F=B2=E6=B6=88=E6=81=AF=E7=9A=84=E6=9C=BA=E5=88=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../MaiBot-Napcat-Adapter/constants.py | 3 + .../MaiBot-Napcat-Adapter/runtime/builder.py | 3 + .../MaiBot-Napcat-Adapter/runtime/bundle.py | 2 + .../MaiBot-Napcat-Adapter/runtime/router.py | 304 ++++++++++++- .../services/__init__.py | 3 + .../services/history_recovery_store.py | 423 ++++++++++++++++++ .../services/query_service.py | 40 ++ pytests/test_napcat_adapter_sdk.py | 8 +- pytests/test_napcat_history_recovery.py | 376 ++++++++++++++++ 9 files changed, 1154 insertions(+), 8 deletions(-) create mode 100644 plugin-templates/MaiBot-Napcat-Adapter/services/history_recovery_store.py create mode 100644 pytests/test_napcat_history_recovery.py diff --git a/plugin-templates/MaiBot-Napcat-Adapter/constants.py b/plugin-templates/MaiBot-Napcat-Adapter/constants.py index 8aa3e7eb..95b8bcdc 100644 --- a/plugin-templates/MaiBot-Napcat-Adapter/constants.py +++ b/plugin-templates/MaiBot-Napcat-Adapter/constants.py @@ -8,3 +8,6 @@ DEFAULT_RECONNECT_DELAY_SEC = 5.0 DEFAULT_HEARTBEAT_INTERVAL_SEC = 30.0 DEFAULT_ACTION_TIMEOUT_SEC = 15.0 DEFAULT_CHAT_LIST_TYPE = "whitelist" +DEFAULT_HISTORY_RECOVERY_BATCH_SIZE = 20 +DEFAULT_HISTORY_RECOVERY_CHECKPOINT_LIMIT = 50 +DEFAULT_HISTORY_RECOVERY_SEEN_TTL_SEC = 86400.0 * 7 diff --git a/plugin-templates/MaiBot-Napcat-Adapter/runtime/builder.py b/plugin-templates/MaiBot-Napcat-Adapter/runtime/builder.py index 73321a0b..6720d882 100644 --- a/plugin-templates/MaiBot-Napcat-Adapter/runtime/builder.py +++ b/plugin-templates/MaiBot-Napcat-Adapter/runtime/builder.py @@ -14,6 +14,7 @@ from ..services import ( NapCatActionService, NapCatBanStateStore, NapCatBanTracker, + NapCatHistoryRecoveryStore, NapCatOfficialBotGuard, NapCatQueryService, ) @@ -66,6 +67,7 @@ class NapCatRuntimeBuilder: action_service = NapCatActionService(self._logger, transport) query_service = NapCatQueryService(action_service, self._logger) ban_state_store = NapCatBanStateStore(self._logger) + history_recovery_store = NapCatHistoryRecoveryStore(self._logger) inbound_codec = NapCatInboundCodec(self._logger, query_service) notice_codec = NapCatNoticeCodec(self._logger, query_service) runtime_state = NapCatRuntimeStateManager( @@ -92,6 +94,7 @@ class NapCatRuntimeBuilder: ban_tracker=ban_tracker, chat_filter=chat_filter, heartbeat_monitor=heartbeat_monitor, + history_recovery_store=history_recovery_store, inbound_codec=inbound_codec, notice_codec=notice_codec, official_bot_guard=official_bot_guard, diff --git a/plugin-templates/MaiBot-Napcat-Adapter/runtime/bundle.py b/plugin-templates/MaiBot-Napcat-Adapter/runtime/bundle.py index 046723e2..3045bfb1 100644 --- a/plugin-templates/MaiBot-Napcat-Adapter/runtime/bundle.py +++ b/plugin-templates/MaiBot-Napcat-Adapter/runtime/bundle.py @@ -14,6 +14,7 @@ from ..services import ( NapCatActionService, NapCatBanStateStore, NapCatBanTracker, + NapCatHistoryRecoveryStore, NapCatOfficialBotGuard, NapCatQueryService, ) @@ -29,6 +30,7 @@ class NapCatRuntimeBundle: ban_tracker: NapCatBanTracker chat_filter: NapCatChatFilter heartbeat_monitor: NapCatHeartbeatMonitor + history_recovery_store: NapCatHistoryRecoveryStore inbound_codec: NapCatInboundCodec notice_codec: NapCatNoticeCodec official_bot_guard: NapCatOfficialBotGuard diff --git a/plugin-templates/MaiBot-Napcat-Adapter/runtime/router.py b/plugin-templates/MaiBot-Napcat-Adapter/runtime/router.py index 715ea9ae..cf0a8536 100644 --- a/plugin-templates/MaiBot-Napcat-Adapter/runtime/router.py +++ b/plugin-templates/MaiBot-Napcat-Adapter/runtime/router.py @@ -2,11 +2,14 @@ from __future__ import annotations +from dataclasses import dataclass from typing import Any, Callable, Dict, Mapping, Optional, Protocol import asyncio from ..config import NapCatPluginSettings +from ..constants import DEFAULT_HISTORY_RECOVERY_BATCH_SIZE, DEFAULT_HISTORY_RECOVERY_CHECKPOINT_LIMIT +from ..services import NapCatChatCheckpoint from ..types import NapCatPayloadDict from .bundle import NapCatRuntimeBundle @@ -27,6 +30,14 @@ class _GatewayCapabilityProtocol(Protocol): ... +@dataclass(frozen=True) +class _NapCatChatIdentity: + """描述一条 NapCat 消息所属的会话身份。""" + + chat_type: str + chat_id: str + + class NapCatEventRouter: """协调 NapCat 运行时组件处理各类平台事件。""" @@ -50,6 +61,7 @@ class NapCatEventRouter: self._gateway_name = gateway_name self._load_settings = load_settings self._runtime: Optional[NapCatRuntimeBundle] = None + self._recovery_task: Optional[asyncio.Task[None]] = None def bind_runtime(self, runtime: NapCatRuntimeBundle) -> None: """绑定当前路由器使用的运行时依赖。 @@ -64,6 +76,7 @@ class NapCatEventRouter: runtime = self._runtime if runtime is None: return + self._cancel_recovery_task() runtime.official_bot_guard.clear_cache() async def handle_transport_payload(self, payload: NapCatPayloadDict) -> None: @@ -82,7 +95,7 @@ class NapCatEventRouter: if post_type == "meta_event": await self.handle_meta_event(payload) - async def handle_inbound_message(self, payload: NapCatPayloadDict) -> None: + async def handle_inbound_message(self, payload: NapCatPayloadDict) -> bool: """处理单条 NapCat 入站消息并注入 Host。 Args: @@ -101,25 +114,25 @@ class NapCatEventRouter: sender_user_id = str(payload.get("user_id") or sender.get("user_id") or "").strip() if not sender_user_id: - return + return False group_id = str(payload.get("group_id") or "").strip() if self_id and sender_user_id == self_id and settings.filters.ignore_self_message: - return + return False if not runtime.chat_filter.is_inbound_chat_allowed(sender_user_id, group_id, settings.chat): - return + return False if await runtime.official_bot_guard.should_reject( sender_user_id=sender_user_id, group_id=group_id, ban_qq_bot=settings.chat.ban_qq_bot, ): - return + return False try: message_dict = await runtime.inbound_codec.build_message_dict(payload, self_id, sender_user_id, sender) except ValueError as exc: self._logger.warning(f"NapCat 入站消息格式不受支持,已丢弃: {exc}") - return + return False route_metadata = self._build_route_metadata(self_id, settings.napcat_server.connection_id) external_message_id = str(payload.get("message_id") or "").strip() @@ -132,6 +145,15 @@ class NapCatEventRouter: ) if not accepted: self._logger.debug(f"Host 丢弃了 NapCat 入站消息: {external_message_id or '无消息 ID'}") + return False + + await self._record_inbound_checkpoint( + payload=payload, + self_id=self_id, + external_message_id=external_message_id or str(message_dict.get("message_id") or "").strip(), + scope=settings.napcat_server.connection_id, + ) + return True async def handle_notice_event(self, payload: NapCatPayloadDict) -> None: """处理 NapCat ``notice`` 事件并注入 Host。 @@ -232,6 +254,8 @@ class NapCatEventRouter: await runtime.runtime_state.report_connected(self_id, settings.napcat_server) await runtime.heartbeat_monitor.start(self_id, settings.napcat_server.heartbeat_interval) await runtime.ban_tracker.start() + await runtime.history_recovery_store.load() + self._schedule_history_recovery(self_id=self_id, scope=settings.napcat_server.connection_id) return except asyncio.CancelledError: raise @@ -279,6 +303,274 @@ class NapCatEventRouter: raise RuntimeError("NapCat 运行时尚未初始化") return runtime + def _schedule_history_recovery(self, self_id: str, scope: str) -> None: + """在连接恢复后调度一次历史补拉任务。""" + + self._cancel_recovery_task() + runtime = self._runtime + if runtime is None: + return + + self._recovery_task = asyncio.create_task( + self._recover_recent_history(self_id=self_id, scope=scope), + name="napcat_adapter.history_recovery", + ) + + def _cancel_recovery_task(self) -> None: + """取消当前仍在运行的历史补拉任务。""" + + recovery_task = self._recovery_task + self._recovery_task = None + if recovery_task is not None and not recovery_task.done(): + recovery_task.cancel() + + async def _recover_recent_history(self, *, self_id: str, scope: str) -> None: + """按 checkpoint 列表逐个尝试补拉断线期间遗漏的消息。""" + + runtime = self._require_runtime() + checkpoints = await runtime.history_recovery_store.list_checkpoints( + self_id, + scope=scope, + limit=DEFAULT_HISTORY_RECOVERY_CHECKPOINT_LIMIT, + ) + if not checkpoints: + return + + recovered_count = 0 + for checkpoint in checkpoints: + recovered_count += await self._recover_chat_history_from_checkpoint( + self_id=self_id, + scope=scope, + checkpoint=checkpoint, + ) + + if recovered_count > 0: + self._logger.info(f"NapCat 历史补拉完成,共补回 {recovered_count} 条消息") + + async def _recover_chat_history_from_checkpoint( + self, + *, + self_id: str, + scope: str, + checkpoint: NapCatChatCheckpoint, + ) -> int: + """针对单个会话执行一次小批量历史补拉。""" + + runtime = self._require_runtime() + history_messages = await self._query_history_messages(checkpoint, limit=DEFAULT_HISTORY_RECOVERY_BATCH_SIZE) + if not history_messages: + return 0 + + ordered_messages = sorted( + history_messages, + key=lambda item: ( + self._extract_message_timestamp(item), + self._extract_message_seq(item), + str(item.get("message_id") or "").strip(), + ), + ) + + recovered_count = 0 + for history_payload in ordered_messages: + external_message_id = str(history_payload.get("message_id") or "").strip() + if not external_message_id: + continue + if external_message_id == checkpoint.last_message_id: + continue + if await runtime.history_recovery_store.has_recovered_message_seen( + account_id=self_id, + scope=scope, + chat_type=checkpoint.chat_type, + chat_id=checkpoint.chat_id, + external_message_id=external_message_id, + ): + continue + if not self._is_message_after_checkpoint(history_payload, checkpoint): + continue + accepted = await self._reinject_history_payload(history_payload, self_id=self_id) + if not accepted: + continue + await runtime.history_recovery_store.mark_recovered_message_seen( + account_id=self_id, + scope=scope, + chat_type=checkpoint.chat_type, + chat_id=checkpoint.chat_id, + external_message_id=external_message_id, + ) + recovered_count += 1 + + return recovered_count + + async def _query_history_messages( + self, + checkpoint: NapCatChatCheckpoint, + *, + limit: int, + ) -> list[NapCatPayloadDict]: + """查询某个会话在 checkpoint 之后的一小批历史消息。""" + + runtime = self._require_runtime() + payload_collections: list[list[NapCatPayloadDict]] = [] + if checkpoint.last_message_seq is not None: + payload_collections.append( + await self._fetch_history_messages( + chat_type=checkpoint.chat_type, + chat_id=checkpoint.chat_id, + message_seq=checkpoint.last_message_seq, + limit=limit, + ) + ) + payload_collections.append( + await self._fetch_history_messages( + chat_type=checkpoint.chat_type, + chat_id=checkpoint.chat_id, + message_seq=None, + limit=limit, + ) + ) + + merged_payloads: list[NapCatPayloadDict] = [] + seen_message_ids: set[str] = set() + for payloads in payload_collections: + for payload in payloads: + external_message_id = str(payload.get("message_id") or "").strip() + dedupe_key = external_message_id or repr(sorted(payload.items())) + if dedupe_key in seen_message_ids: + continue + seen_message_ids.add(dedupe_key) + merged_payloads.append(payload) + return merged_payloads + + async def _fetch_history_messages( + self, + *, + chat_type: str, + chat_id: str, + message_seq: int | None, + limit: int, + ) -> list[NapCatPayloadDict]: + """调用查询服务获取一批历史消息。""" + + runtime = self._require_runtime() + if chat_type == "group": + history_payloads = await runtime.query_service.get_group_message_history( + chat_id, + message_seq=message_seq, + count=limit, + reverse_order=False, + ) + elif chat_type == "private": + history_payloads = await runtime.query_service.get_friend_message_history( + chat_id, + message_seq=message_seq, + count=limit, + reverse_order=False, + ) + else: + return [] + + if history_payloads is None: + return [] + return [dict(payload) for payload in history_payloads if isinstance(payload, Mapping)] + + async def _reinject_history_payload(self, payload: NapCatPayloadDict, *, self_id: str) -> bool: + """将补拉到的历史消息重新送回实时入站路径。""" + + try: + normalized_payload = dict(payload) + if self_id and not str(normalized_payload.get("self_id") or "").strip(): + normalized_payload["self_id"] = self_id + return await self.handle_inbound_message(normalized_payload) + except asyncio.CancelledError: + raise + except Exception as exc: + external_message_id = str(payload.get("message_id") or "").strip() or "unknown" + self._logger.warning(f"NapCat 历史消息补拉注入失败: message_id={external_message_id} error={exc}") + return False + + async def _record_inbound_checkpoint( + self, + *, + payload: NapCatPayloadDict, + self_id: str, + external_message_id: str, + scope: str, + ) -> None: + """在消息被 Host 接受后更新该会话的最新 checkpoint。""" + + runtime = self._require_runtime() + chat_identity = self._extract_chat_identity(payload) + if chat_identity is None: + return + + await runtime.history_recovery_store.record_checkpoint( + account_id=self_id, + scope=scope, + chat_type=chat_identity.chat_type, + chat_id=chat_identity.chat_id, + message_id=external_message_id, + message_time=self._extract_message_timestamp(payload), + message_seq=self._extract_message_seq(payload), + ) + + @staticmethod + def _extract_chat_identity(payload: Mapping[str, Any]) -> _NapCatChatIdentity | None: + """从 NapCat 载荷中提取会话身份。""" + + group_id = str(payload.get("group_id") or "").strip() + user_id = str(payload.get("user_id") or "").strip() + + if group_id: + return _NapCatChatIdentity(chat_type="group", chat_id=group_id) + if user_id: + return _NapCatChatIdentity(chat_type="private", chat_id=user_id) + return None + + @staticmethod + def _extract_message_seq(payload: Mapping[str, Any]) -> int | None: + """从 NapCat 载荷中提取历史接口可复用的消息序号。""" + + for field_name in ("message_seq", "messageSeq", "msg_seq"): + raw_value = payload.get(field_name) + if raw_value is None or str(raw_value).strip() == "": + continue + try: + return int(raw_value) + except (TypeError, ValueError): + continue + return None + + @staticmethod + def _extract_message_timestamp(payload: Mapping[str, Any]) -> float: + """从 NapCat 载荷中提取消息时间戳。""" + + raw_timestamp = payload.get("time") + if isinstance(raw_timestamp, (int, float)): + return float(raw_timestamp) + return 0.0 + + @classmethod + def _is_message_after_checkpoint( + cls, + payload: Mapping[str, Any], + checkpoint: NapCatChatCheckpoint, + ) -> bool: + """判断历史消息是否位于 checkpoint 之后。""" + + payload_message_id = str(payload.get("message_id") or "").strip() + if payload_message_id == checkpoint.last_message_id: + return False + + payload_message_seq = cls._extract_message_seq(payload) + if payload_message_seq is not None and checkpoint.last_message_seq is not None: + return payload_message_seq > checkpoint.last_message_seq + + payload_timestamp = cls._extract_message_timestamp(payload) + if payload_timestamp != checkpoint.last_message_time: + return payload_timestamp > checkpoint.last_message_time + + return True + @staticmethod def _build_route_metadata(self_id: str, connection_id: str) -> Dict[str, Any]: """构造注入 Host 时使用的路由元数据。 diff --git a/plugin-templates/MaiBot-Napcat-Adapter/services/__init__.py b/plugin-templates/MaiBot-Napcat-Adapter/services/__init__.py index 84d0543f..cc10abc3 100644 --- a/plugin-templates/MaiBot-Napcat-Adapter/services/__init__.py +++ b/plugin-templates/MaiBot-Napcat-Adapter/services/__init__.py @@ -3,6 +3,7 @@ from .action_service import NapCatActionService from .ban_tracker import NapCatBanTracker from .ban_state_store import NapCatBanRecord, NapCatBanStateStore +from .history_recovery_store import NapCatChatCheckpoint, NapCatHistoryRecoveryStore from .official_bot_guard import NapCatOfficialBotGuard from .query_service import NapCatQueryService @@ -11,6 +12,8 @@ __all__ = [ "NapCatBanRecord", "NapCatBanStateStore", "NapCatBanTracker", + "NapCatChatCheckpoint", + "NapCatHistoryRecoveryStore", "NapCatOfficialBotGuard", "NapCatQueryService", ] diff --git a/plugin-templates/MaiBot-Napcat-Adapter/services/history_recovery_store.py b/plugin-templates/MaiBot-Napcat-Adapter/services/history_recovery_store.py new file mode 100644 index 00000000..6dad2771 --- /dev/null +++ b/plugin-templates/MaiBot-Napcat-Adapter/services/history_recovery_store.py @@ -0,0 +1,423 @@ +"""NapCat 历史补拉状态持久化仓库。""" + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Callable, List, Optional, TypeVar + +import asyncio +import sqlite3 +import time + +from ..constants import DEFAULT_HISTORY_RECOVERY_SEEN_TTL_SEC + +_PROJECT_ROOT = Path(__file__).resolve().parents[3] +_DEFAULT_STORAGE_PATH = _PROJECT_ROOT / "data" / "napcat_adapter" / "history_recovery.sqlite3" + +_SCHEMA_STATEMENTS = ( + """ + CREATE TABLE IF NOT EXISTS napcat_chat_checkpoint ( + account_id TEXT NOT NULL, + scope TEXT NOT NULL, + chat_type TEXT NOT NULL, + chat_id TEXT NOT NULL, + last_message_id TEXT NOT NULL, + last_message_time REAL NOT NULL, + last_message_seq INTEGER, + updated_at REAL NOT NULL, + PRIMARY KEY (account_id, scope, chat_type, chat_id) + ) + """, + """ + CREATE INDEX IF NOT EXISTS ix_napcat_chat_checkpoint_updated_at + ON napcat_chat_checkpoint (updated_at DESC) + """, + """ + CREATE TABLE IF NOT EXISTS napcat_recovery_seen ( + account_id TEXT NOT NULL, + scope TEXT NOT NULL, + chat_type TEXT NOT NULL, + chat_id TEXT NOT NULL, + external_message_id TEXT NOT NULL, + seen_at REAL NOT NULL, + PRIMARY KEY (account_id, scope, chat_type, chat_id, external_message_id) + ) + """, + """ + CREATE INDEX IF NOT EXISTS ix_napcat_recovery_seen_seen_at + ON napcat_recovery_seen (seen_at DESC) + """, +) + +T = TypeVar("T") + + +@dataclass(frozen=True) +class NapCatChatCheckpoint: + """描述一个会话的最近入站锚点。""" + + account_id: str + scope: str + chat_type: str + chat_id: str + last_message_id: str + last_message_time: float + last_message_seq: int | None + updated_at: float + + @classmethod + def from_row(cls, row: sqlite3.Row) -> "NapCatChatCheckpoint": + """从 SQLite 行对象恢复 checkpoint。""" + + last_message_seq = row["last_message_seq"] + normalized_seq = int(last_message_seq) if isinstance(last_message_seq, int) else None + return cls( + account_id=str(row["account_id"] or "").strip(), + scope=str(row["scope"] or "").strip(), + chat_type=str(row["chat_type"] or "").strip(), + chat_id=str(row["chat_id"] or "").strip(), + last_message_id=str(row["last_message_id"] or "").strip(), + last_message_time=float(row["last_message_time"] or 0.0), + last_message_seq=normalized_seq, + updated_at=float(row["updated_at"] or 0.0), + ) + + +class NapCatHistoryRecoveryStore: + """负责持久化历史补拉所需的会话状态与去重状态。""" + + def __init__(self, logger: Any, storage_path: Path = _DEFAULT_STORAGE_PATH) -> None: + """初始化历史补拉状态仓库。""" + + self._logger = logger + self._storage_path = storage_path + self._store_lock = asyncio.Lock() + self._schema_ready = False + + async def load(self) -> None: + """初始化 SQLite 文件并清理过期去重记录。""" + + await self._execute_locked(self._ensure_schema) + pruned_count = await self.prune_recovery_seen(DEFAULT_HISTORY_RECOVERY_SEEN_TTL_SEC) + if pruned_count > 0: + self._logger.debug(f"NapCat 历史补拉去重表已清理 {pruned_count} 条过期记录") + + async def list_checkpoints(self, account_id: str, scope: str = "", limit: int = 50) -> List[NapCatChatCheckpoint]: + """列出指定账号与作用域下的最近会话 checkpoint。""" + + normalized_account_id = str(account_id or "").strip() + if not normalized_account_id: + return [] + + normalized_scope = self._normalize_scope(scope) + normalized_limit = max(1, int(limit)) + + def _operation(conn: sqlite3.Connection) -> List[NapCatChatCheckpoint]: + cursor = conn.execute( + """ + SELECT + account_id, + scope, + chat_type, + chat_id, + last_message_id, + last_message_time, + last_message_seq, + updated_at + FROM napcat_chat_checkpoint + WHERE account_id = ? AND scope = ? + ORDER BY updated_at DESC + LIMIT ? + """, + (normalized_account_id, normalized_scope, normalized_limit), + ) + return [NapCatChatCheckpoint.from_row(row) for row in cursor.fetchall()] + + return await self._execute_locked(_operation) + + async def record_checkpoint( + self, + *, + account_id: str, + scope: str = "", + chat_type: str, + chat_id: str, + message_id: str, + message_time: float, + message_seq: int | None = None, + ) -> None: + """记录一条已被 Host 接受的最新入站消息锚点。""" + + normalized_account_id = str(account_id or "").strip() + normalized_scope = self._normalize_scope(scope) + normalized_chat_type = str(chat_type or "").strip() + normalized_chat_id = str(chat_id or "").strip() + normalized_message_id = str(message_id or "").strip() + + if not ( + normalized_account_id + and normalized_chat_type + and normalized_chat_id + and normalized_message_id + ): + return + + normalized_message_time = float(message_time or 0.0) + normalized_message_seq = self._normalize_message_seq(message_seq) + updated_at = time.time() + + def _operation(conn: sqlite3.Connection) -> None: + cursor = conn.execute( + """ + SELECT last_message_id, last_message_time, last_message_seq + FROM napcat_chat_checkpoint + WHERE account_id = ? AND scope = ? AND chat_type = ? AND chat_id = ? + """, + ( + normalized_account_id, + normalized_scope, + normalized_chat_type, + normalized_chat_id, + ), + ) + existing_row = cursor.fetchone() + if existing_row is not None and not self._should_advance_checkpoint( + existing_row=existing_row, + message_id=normalized_message_id, + message_time=normalized_message_time, + message_seq=normalized_message_seq, + ): + return + + conn.execute( + """ + INSERT INTO napcat_chat_checkpoint ( + account_id, + scope, + chat_type, + chat_id, + last_message_id, + last_message_time, + last_message_seq, + updated_at + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(account_id, scope, chat_type, chat_id) DO UPDATE SET + last_message_id = excluded.last_message_id, + last_message_time = excluded.last_message_time, + last_message_seq = excluded.last_message_seq, + updated_at = excluded.updated_at + """, + ( + normalized_account_id, + normalized_scope, + normalized_chat_type, + normalized_chat_id, + normalized_message_id, + normalized_message_time, + normalized_message_seq, + updated_at, + ), + ) + + await self._execute_locked(_operation) + + async def has_recovered_message_seen( + self, + *, + account_id: str, + scope: str = "", + chat_type: str, + chat_id: str, + external_message_id: str, + ) -> bool: + """判断某条历史补拉消息是否已经被当前仓库记录过。""" + + normalized_account_id = str(account_id or "").strip() + normalized_scope = self._normalize_scope(scope) + normalized_chat_type = str(chat_type or "").strip() + normalized_chat_id = str(chat_id or "").strip() + normalized_external_message_id = str(external_message_id or "").strip() + + if not ( + normalized_account_id + and normalized_chat_type + and normalized_chat_id + and normalized_external_message_id + ): + return False + + def _operation(conn: sqlite3.Connection) -> bool: + cursor = conn.execute( + """ + SELECT 1 + FROM napcat_recovery_seen + WHERE account_id = ? + AND scope = ? + AND chat_type = ? + AND chat_id = ? + AND external_message_id = ? + LIMIT 1 + """, + ( + normalized_account_id, + normalized_scope, + normalized_chat_type, + normalized_chat_id, + normalized_external_message_id, + ), + ) + return cursor.fetchone() is not None + + return await self._execute_locked(_operation) + + async def mark_recovered_message_seen( + self, + *, + account_id: str, + scope: str = "", + chat_type: str, + chat_id: str, + external_message_id: str, + ) -> None: + """将一条历史补拉消息标记为已尝试处理。""" + + normalized_account_id = str(account_id or "").strip() + normalized_scope = self._normalize_scope(scope) + normalized_chat_type = str(chat_type or "").strip() + normalized_chat_id = str(chat_id or "").strip() + normalized_external_message_id = str(external_message_id or "").strip() + + if not ( + normalized_account_id + and normalized_chat_type + and normalized_chat_id + and normalized_external_message_id + ): + return + + def _operation(conn: sqlite3.Connection) -> None: + conn.execute( + """ + INSERT OR REPLACE INTO napcat_recovery_seen ( + account_id, + scope, + chat_type, + chat_id, + external_message_id, + seen_at + ) VALUES (?, ?, ?, ?, ?, ?) + """, + ( + normalized_account_id, + normalized_scope, + normalized_chat_type, + normalized_chat_id, + normalized_external_message_id, + time.time(), + ), + ) + + await self._execute_locked(_operation) + + async def prune_recovery_seen(self, ttl_seconds: float) -> int: + """删除超过保留期的历史补拉去重记录。""" + + normalized_ttl_seconds = max(0.0, float(ttl_seconds or 0.0)) + if normalized_ttl_seconds <= 0.0: + return 0 + + cutoff_timestamp = time.time() - normalized_ttl_seconds + + def _operation(conn: sqlite3.Connection) -> int: + cursor = conn.execute( + "DELETE FROM napcat_recovery_seen WHERE seen_at < ?", + (cutoff_timestamp,), + ) + return int(cursor.rowcount or 0) + + return await self._execute_locked(_operation) + + async def _execute_locked(self, operation: Callable[[sqlite3.Connection], T]) -> T: + """在锁保护下打开 SQLite 并执行一次原子操作。""" + + async with self._store_lock: + conn = self._open_connection() + try: + self._ensure_schema(conn) + result = operation(conn) + conn.commit() + return result + except Exception: + conn.rollback() + raise + finally: + conn.close() + + def _open_connection(self) -> sqlite3.Connection: + """打开一个带 Row 工厂的 SQLite 连接。""" + + self._storage_path.parent.mkdir(parents=True, exist_ok=True) + conn = sqlite3.connect(str(self._storage_path)) + conn.row_factory = sqlite3.Row + return conn + + def _ensure_schema(self, conn: sqlite3.Connection) -> None: + """确保 SQLite 表结构已经准备完成。""" + + if self._schema_ready: + return + + for statement in _SCHEMA_STATEMENTS: + conn.execute(statement) + self._schema_ready = True + + @staticmethod + def _normalize_scope(scope: str | None) -> str: + """将空作用域统一折叠为空字符串。""" + + return str(scope or "").strip() + + @staticmethod + def _normalize_message_seq(message_seq: object) -> int | None: + """将消息序号规范化为可选整数。""" + + try: + if message_seq is None or str(message_seq).strip() == "": + return None + return int(message_seq) + except (TypeError, ValueError): + return None + + @classmethod + def _should_advance_checkpoint( + cls, + *, + existing_row: sqlite3.Row, + message_id: str, + message_time: float, + message_seq: int | None, + ) -> bool: + """判断新的锚点是否应覆盖旧锚点。""" + + existing_message_id = str(existing_row["last_message_id"] or "").strip() + existing_message_time = float(existing_row["last_message_time"] or 0.0) + existing_message_seq = cls._normalize_message_seq(existing_row["last_message_seq"]) + + if message_seq is not None and existing_message_seq is not None: + if message_seq != existing_message_seq: + return message_seq > existing_message_seq + if message_id == existing_message_id: + return False + return message_time >= existing_message_time + + if message_time != existing_message_time: + return message_time > existing_message_time + + if message_id == existing_message_id: + return False + + if message_seq is not None and existing_message_seq is None: + return True + + return True diff --git a/plugin-templates/MaiBot-Napcat-Adapter/services/query_service.py b/plugin-templates/MaiBot-Napcat-Adapter/services/query_service.py index b24209ee..6a45a9a2 100644 --- a/plugin-templates/MaiBot-Napcat-Adapter/services/query_service.py +++ b/plugin-templates/MaiBot-Napcat-Adapter/services/query_service.py @@ -180,6 +180,46 @@ class NapCatQueryService: response_data = await self._safe_call_action_data("get_msg", {"message_id": message_id}) return response_data if isinstance(response_data, dict) else None + async def get_friend_message_history( + self, + user_id: str, + *, + message_seq: int | None = None, + count: int = 20, + reverse_order: bool = False, + ) -> Optional[NapCatPayloadList]: + """获取私聊历史消息列表。""" + + params: NapCatActionResponse = { + "user_id": user_id, + "count": max(1, int(count)), + "reverse_order": bool(reverse_order), + } + if message_seq is not None: + params["message_seq"] = int(message_seq) + response_data = await self._safe_call_action_data("get_friend_msg_history", params) + return self._normalize_payload_list(response_data, action_name="get_friend_msg_history") + + async def get_group_message_history( + self, + group_id: str, + *, + message_seq: int | None = None, + count: int = 20, + reverse_order: bool = False, + ) -> Optional[NapCatPayloadList]: + """获取群聊历史消息列表。""" + + params: NapCatActionResponse = { + "group_id": group_id, + "count": max(1, int(count)), + "reverse_order": bool(reverse_order), + } + if message_seq is not None: + params["message_seq"] = int(message_seq) + response_data = await self._safe_call_action_data("get_group_msg_history", params) + return self._normalize_payload_list(response_data, action_name="get_group_msg_history") + async def get_forward_message( self, message_id: Optional[str] = None, diff --git a/pytests/test_napcat_adapter_sdk.py b/pytests/test_napcat_adapter_sdk.py index 40c9aa4b..f53408b9 100644 --- a/pytests/test_napcat_adapter_sdk.py +++ b/pytests/test_napcat_adapter_sdk.py @@ -12,8 +12,10 @@ import pytest PROJECT_ROOT = Path(__file__).resolve().parents[1] PLUGINS_ROOT = PROJECT_ROOT / "plugins" +PLUGIN_TEMPLATE_ROOT = PROJECT_ROOT / "plugin-templates" SDK_ROOT = PROJECT_ROOT / "packages" / "maibot-plugin-sdk" NAPCAT_PLUGIN_DIR = PLUGINS_ROOT / "MaiBot-Napcat-Adapter" +NAPCAT_TEMPLATE_DIR = PLUGIN_TEMPLATE_ROOT / "MaiBot-Napcat-Adapter" NAPCAT_TEST_MODULE = "_test_napcat_adapter" for import_path in (str(SDK_ROOT),): @@ -204,12 +206,14 @@ def _load_napcat_sdk_modules() -> Tuple[Any, Any, Any, Any]: 依次返回常量模块、配置模块、插件模块和运行时状态模块。 """ + plugin_dir = NAPCAT_PLUGIN_DIR if NAPCAT_PLUGIN_DIR.is_dir() else NAPCAT_TEMPLATE_DIR + if NAPCAT_TEST_MODULE not in sys.modules: - plugin_path = NAPCAT_PLUGIN_DIR / "plugin.py" + plugin_path = plugin_dir / "plugin.py" spec = util.spec_from_file_location( NAPCAT_TEST_MODULE, plugin_path, - submodule_search_locations=[str(NAPCAT_PLUGIN_DIR)], + submodule_search_locations=[str(plugin_dir)], ) if spec is None or spec.loader is None: raise ImportError(f"无法为 NapCat 插件创建模块规格: {plugin_path}") diff --git a/pytests/test_napcat_history_recovery.py b/pytests/test_napcat_history_recovery.py new file mode 100644 index 00000000..30bc58f5 --- /dev/null +++ b/pytests/test_napcat_history_recovery.py @@ -0,0 +1,376 @@ +"""NapCat 历史补拉与恢复状态测试。""" + +from __future__ import annotations + +from importlib import import_module, util +from pathlib import Path +from typing import Any, Dict, List + +import logging +import sys +from types import SimpleNamespace + +import pytest + +PROJECT_ROOT = Path(__file__).resolve().parents[1] +PLUGINS_ROOT = PROJECT_ROOT / "plugins" +PLUGIN_TEMPLATE_ROOT = PROJECT_ROOT / "plugin-templates" +SDK_ROOT = PROJECT_ROOT / "packages" / "maibot-plugin-sdk" +NAPCAT_PLUGIN_DIR = PLUGINS_ROOT / "MaiBot-Napcat-Adapter" +NAPCAT_TEMPLATE_DIR = PLUGIN_TEMPLATE_ROOT / "MaiBot-Napcat-Adapter" +NAPCAT_TEST_MODULE = "_test_napcat_adapter_history_recovery" + +for import_path in (str(SDK_ROOT),): + if import_path not in sys.path: + sys.path.insert(0, import_path) + + +class _FakeGatewayCapability: + """用于测试入站注入的网关替身。""" + + def __init__(self) -> None: + """初始化测试替身。""" + + self.calls: List[Dict[str, Any]] = [] + + async def route_message( + self, + gateway_name: str, + message: Dict[str, Any], + *, + route_metadata: Dict[str, Any] | None = None, + external_message_id: str = "", + dedupe_key: str = "", + ) -> bool: + """记录入站注入请求并始终模拟成功。""" + + self.calls.append( + { + "gateway_name": gateway_name, + "message": dict(message), + "route_metadata": dict(route_metadata or {}), + "external_message_id": external_message_id, + "dedupe_key": dedupe_key, + } + ) + return True + + +def _resolve_napcat_plugin_dir() -> Path: + """返回当前测试可用的 NapCat 插件目录。""" + + if NAPCAT_PLUGIN_DIR.is_dir(): + return NAPCAT_PLUGIN_DIR + return NAPCAT_TEMPLATE_DIR + + +def _load_napcat_module(module_suffix: str) -> Any: + """动态加载 NapCat 测试模块。""" + + plugin_dir = _resolve_napcat_plugin_dir() + if NAPCAT_TEST_MODULE not in sys.modules: + plugin_path = plugin_dir / "plugin.py" + spec = util.spec_from_file_location( + NAPCAT_TEST_MODULE, + plugin_path, + submodule_search_locations=[str(plugin_dir)], + ) + if spec is None or spec.loader is None: + raise ImportError(f"无法为 NapCat 插件创建模块规格: {plugin_path}") + + module = util.module_from_spec(spec) + sys.modules[NAPCAT_TEST_MODULE] = module + try: + spec.loader.exec_module(module) + except Exception: + sys.modules.pop(NAPCAT_TEST_MODULE, None) + raise + + return import_module(f"{NAPCAT_TEST_MODULE}.{module_suffix}") + + +def _load_history_recovery_store_cls() -> Any: + """动态加载历史恢复状态仓库类。""" + + return _load_napcat_module("services.history_recovery_store").NapCatHistoryRecoveryStore + + +def _load_query_service_cls() -> Any: + """动态加载查询服务类。""" + + return _load_napcat_module("services.query_service").NapCatQueryService + + +def _load_router_cls() -> Any: + """动态加载事件路由器类。""" + + return _load_napcat_module("runtime.router").NapCatEventRouter + + +class _FakeActionService: + """用于查询服务的动作服务替身。""" + + def __init__(self, response_data: Any) -> None: + """初始化动作服务替身。""" + + self._response_data = response_data + self.action_data_calls: List[Dict[str, Any]] = [] + + async def safe_call_action_data(self, action_name: str, params: Dict[str, Any]) -> Any: + """记录安全查询动作。""" + + self.action_data_calls.append({"action_name": action_name, "params": dict(params)}) + return self._response_data + + +@pytest.mark.asyncio +async def test_history_recovery_store_persists_checkpoint_and_seen_state(tmp_path: Path) -> None: + """历史恢复状态仓库应持久化 checkpoint 与已补拉标记。""" + + store_cls = _load_history_recovery_store_cls() + store = store_cls( + logger=logging.getLogger("test.napcat.history_store"), + storage_path=tmp_path / "history.sqlite3", + ) + + await store.load() + await store.record_checkpoint( + account_id="10001", + scope="primary", + chat_type="group", + chat_id="20001", + message_id="msg-2", + message_time=200.0, + message_seq=2, + ) + await store.record_checkpoint( + account_id="10001", + scope="primary", + chat_type="group", + chat_id="20001", + message_id="msg-1", + message_time=100.0, + message_seq=1, + ) + await store.mark_recovered_message_seen( + account_id="10001", + scope="primary", + chat_type="group", + chat_id="20001", + external_message_id="history-1", + ) + + checkpoints = await store.list_checkpoints("10001", scope="primary") + + assert len(checkpoints) == 1 + assert checkpoints[0].last_message_id == "msg-2" + assert checkpoints[0].last_message_seq == 2 + assert ( + await store.has_recovered_message_seen( + account_id="10001", + scope="primary", + chat_type="group", + chat_id="20001", + external_message_id="history-1", + ) + is True + ) + + +@pytest.mark.asyncio +async def test_query_service_wraps_group_and_friend_history_actions() -> None: + """查询服务应按官方动作名封装历史消息接口。""" + + query_service_cls = _load_query_service_cls() + action_service = _FakeActionService([{"message_id": "msg-1"}]) + query_service = query_service_cls( + action_service=action_service, + logger=logging.getLogger("test.napcat.history_query"), + ) + + group_payload = await query_service.get_group_message_history("20001", message_seq=123, count=10) + friend_payload = await query_service.get_friend_message_history("30001", count=5, reverse_order=True) + + assert group_payload == [{"message_id": "msg-1"}] + assert friend_payload == [{"message_id": "msg-1"}] + assert action_service.action_data_calls == [ + { + "action_name": "get_group_msg_history", + "params": {"group_id": "20001", "count": 10, "reverse_order": False, "message_seq": 123}, + }, + { + "action_name": "get_friend_msg_history", + "params": {"user_id": "30001", "count": 5, "reverse_order": True}, + }, + ] + + +@pytest.mark.asyncio +async def test_router_recover_recent_history_reinjects_messages_in_order(tmp_path: Path) -> None: + """重连补拉应按时间顺序将历史消息重新注入原入站路径。""" + + history_store_cls = _load_history_recovery_store_cls() + router_cls = _load_router_cls() + gateway_capability = _FakeGatewayCapability() + router = router_cls( + gateway_capability=gateway_capability, + logger=logging.getLogger("test.napcat.history_router"), + gateway_name="napcat_gateway", + load_settings=lambda: SimpleNamespace( + napcat_server=SimpleNamespace(connection_id="primary", heartbeat_interval=30.0), + filters=SimpleNamespace(ignore_self_message=True), + chat=SimpleNamespace(ban_qq_bot=False), + ), + ) + + history_store = history_store_cls( + logger=logging.getLogger("test.napcat.history_router.store"), + storage_path=tmp_path / "router.sqlite3", + ) + await history_store.load() + await history_store.record_checkpoint( + account_id="10001", + scope="primary", + chat_type="group", + chat_id="20001", + message_id="msg-1", + message_time=100.0, + message_seq=10, + ) + + history_calls: List[Dict[str, Any]] = [] + history_payloads = [ + { + "post_type": "message", + "message_type": "group", + "self_id": "10001", + "group_id": "20001", + "user_id": "30002", + "message_id": "msg-3", + "message_seq": 12, + "time": 102, + "message": [{"type": "text", "data": {"text": "第三条"}}], + "sender": {"user_id": "30002", "nickname": "用户二"}, + }, + { + "post_type": "message", + "message_type": "group", + "self_id": "10001", + "group_id": "20001", + "user_id": "30001", + "message_id": "msg-2", + "message_seq": 11, + "time": 101, + "message": [{"type": "text", "data": {"text": "第二条"}}], + "sender": {"user_id": "30001", "nickname": "用户一"}, + }, + ] + + class _FakeQueryService: + async def get_group_message_history( + self, + group_id: str, + *, + message_seq: int | None = None, + count: int = 20, + reverse_order: bool = False, + ) -> List[Dict[str, Any]]: + history_calls.append( + { + "group_id": group_id, + "message_seq": message_seq, + "count": count, + "reverse_order": reverse_order, + } + ) + return list(history_payloads) + + async def get_friend_message_history(self, user_id: str, **kwargs: Any) -> List[Dict[str, Any]]: + del user_id + del kwargs + return [] + + class _FakeInboundCodec: + @staticmethod + async def build_message_dict( + payload: Dict[str, Any], + self_id: str, + sender_user_id: str, + sender: Dict[str, Any], + ) -> Dict[str, Any]: + del self_id + del sender_user_id + del sender + return { + "message_id": str(payload["message_id"]), + "platform": "qq", + "timestamp": str(float(payload["time"])), + "message_info": { + "user_info": {"user_id": str(payload["user_id"]), "user_nickname": "测试用户"}, + "group_info": {"group_id": str(payload["group_id"]), "group_name": "测试群"}, + "additional_config": {}, + }, + "raw_message": [{"type": "text", "data": str(payload["message"][0]["data"]["text"])}], + "processed_plain_text": str(payload["message"][0]["data"]["text"]), + "display_message": str(payload["message"][0]["data"]["text"]), + "is_mentioned": False, + "is_at": False, + "is_emoji": False, + "is_picture": False, + "is_command": False, + "is_notify": False, + "session_id": "", + } + + router.bind_runtime( + SimpleNamespace( + runtime_state=SimpleNamespace(report_connected=lambda *args, **kwargs: _noop_async(), report_disconnected=_noop_async), + chat_filter=SimpleNamespace(is_inbound_chat_allowed=lambda *args, **kwargs: True), + official_bot_guard=SimpleNamespace( + should_reject=lambda *args, **kwargs: _return_false_async(), + clear_cache=lambda: None, + ), + inbound_codec=_FakeInboundCodec(), + history_recovery_store=history_store, + query_service=_FakeQueryService(), + heartbeat_monitor=SimpleNamespace(start=_noop_async, stop=_noop_async), + ban_tracker=SimpleNamespace(start=_noop_async, stop=_noop_async, record_notice=_noop_async), + notice_codec=SimpleNamespace(handle_meta_event=_noop_async, build_notice_message_dict=_return_none_async), + ) + ) + + await router._recover_recent_history(self_id="10001", scope="primary") + + assert history_calls == [ + { + "group_id": "20001", + "message_seq": 10, + "count": 20, + "reverse_order": False, + } + ] + assert [call["external_message_id"] for call in gateway_capability.calls] == ["msg-2", "msg-3"] + assert [call["message"]["message_id"] for call in gateway_capability.calls] == ["msg-2", "msg-3"] + + +async def _noop_async(*args: Any, **kwargs: Any) -> None: + """无操作异步函数。""" + + del args + del kwargs + + +async def _return_false_async(*args: Any, **kwargs: Any) -> bool: + """返回 ``False`` 的异步测试替身。""" + + del args + del kwargs + return False + + +async def _return_none_async(*args: Any, **kwargs: Any) -> None: + """返回 ``None`` 的异步测试替身。""" + + del args + del kwargs + return None