feat:新增napcat断线后重连重新拉取历史消息的机制

This commit is contained in:
LoveLosita
2026-05-11 15:03:36 +08:00
parent 388f965d09
commit a566313c5f
9 changed files with 1154 additions and 8 deletions

View File

@@ -8,3 +8,6 @@ DEFAULT_RECONNECT_DELAY_SEC = 5.0
DEFAULT_HEARTBEAT_INTERVAL_SEC = 30.0 DEFAULT_HEARTBEAT_INTERVAL_SEC = 30.0
DEFAULT_ACTION_TIMEOUT_SEC = 15.0 DEFAULT_ACTION_TIMEOUT_SEC = 15.0
DEFAULT_CHAT_LIST_TYPE = "whitelist" DEFAULT_CHAT_LIST_TYPE = "whitelist"
DEFAULT_HISTORY_RECOVERY_BATCH_SIZE = 20
DEFAULT_HISTORY_RECOVERY_CHECKPOINT_LIMIT = 50
DEFAULT_HISTORY_RECOVERY_SEEN_TTL_SEC = 86400.0 * 7

View File

@@ -14,6 +14,7 @@ from ..services import (
NapCatActionService, NapCatActionService,
NapCatBanStateStore, NapCatBanStateStore,
NapCatBanTracker, NapCatBanTracker,
NapCatHistoryRecoveryStore,
NapCatOfficialBotGuard, NapCatOfficialBotGuard,
NapCatQueryService, NapCatQueryService,
) )
@@ -66,6 +67,7 @@ class NapCatRuntimeBuilder:
action_service = NapCatActionService(self._logger, transport) action_service = NapCatActionService(self._logger, transport)
query_service = NapCatQueryService(action_service, self._logger) query_service = NapCatQueryService(action_service, self._logger)
ban_state_store = NapCatBanStateStore(self._logger) ban_state_store = NapCatBanStateStore(self._logger)
history_recovery_store = NapCatHistoryRecoveryStore(self._logger)
inbound_codec = NapCatInboundCodec(self._logger, query_service) inbound_codec = NapCatInboundCodec(self._logger, query_service)
notice_codec = NapCatNoticeCodec(self._logger, query_service) notice_codec = NapCatNoticeCodec(self._logger, query_service)
runtime_state = NapCatRuntimeStateManager( runtime_state = NapCatRuntimeStateManager(
@@ -92,6 +94,7 @@ class NapCatRuntimeBuilder:
ban_tracker=ban_tracker, ban_tracker=ban_tracker,
chat_filter=chat_filter, chat_filter=chat_filter,
heartbeat_monitor=heartbeat_monitor, heartbeat_monitor=heartbeat_monitor,
history_recovery_store=history_recovery_store,
inbound_codec=inbound_codec, inbound_codec=inbound_codec,
notice_codec=notice_codec, notice_codec=notice_codec,
official_bot_guard=official_bot_guard, official_bot_guard=official_bot_guard,

View File

@@ -14,6 +14,7 @@ from ..services import (
NapCatActionService, NapCatActionService,
NapCatBanStateStore, NapCatBanStateStore,
NapCatBanTracker, NapCatBanTracker,
NapCatHistoryRecoveryStore,
NapCatOfficialBotGuard, NapCatOfficialBotGuard,
NapCatQueryService, NapCatQueryService,
) )
@@ -29,6 +30,7 @@ class NapCatRuntimeBundle:
ban_tracker: NapCatBanTracker ban_tracker: NapCatBanTracker
chat_filter: NapCatChatFilter chat_filter: NapCatChatFilter
heartbeat_monitor: NapCatHeartbeatMonitor heartbeat_monitor: NapCatHeartbeatMonitor
history_recovery_store: NapCatHistoryRecoveryStore
inbound_codec: NapCatInboundCodec inbound_codec: NapCatInboundCodec
notice_codec: NapCatNoticeCodec notice_codec: NapCatNoticeCodec
official_bot_guard: NapCatOfficialBotGuard official_bot_guard: NapCatOfficialBotGuard

View File

@@ -2,11 +2,14 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Callable, Dict, Mapping, Optional, Protocol from typing import Any, Callable, Dict, Mapping, Optional, Protocol
import asyncio import asyncio
from ..config import NapCatPluginSettings from ..config import NapCatPluginSettings
from ..constants import DEFAULT_HISTORY_RECOVERY_BATCH_SIZE, DEFAULT_HISTORY_RECOVERY_CHECKPOINT_LIMIT
from ..services import NapCatChatCheckpoint
from ..types import NapCatPayloadDict from ..types import NapCatPayloadDict
from .bundle import NapCatRuntimeBundle from .bundle import NapCatRuntimeBundle
@@ -27,6 +30,14 @@ class _GatewayCapabilityProtocol(Protocol):
... ...
@dataclass(frozen=True)
class _NapCatChatIdentity:
"""描述一条 NapCat 消息所属的会话身份。"""
chat_type: str
chat_id: str
class NapCatEventRouter: class NapCatEventRouter:
"""协调 NapCat 运行时组件处理各类平台事件。""" """协调 NapCat 运行时组件处理各类平台事件。"""
@@ -50,6 +61,7 @@ class NapCatEventRouter:
self._gateway_name = gateway_name self._gateway_name = gateway_name
self._load_settings = load_settings self._load_settings = load_settings
self._runtime: Optional[NapCatRuntimeBundle] = None self._runtime: Optional[NapCatRuntimeBundle] = None
self._recovery_task: Optional[asyncio.Task[None]] = None
def bind_runtime(self, runtime: NapCatRuntimeBundle) -> None: def bind_runtime(self, runtime: NapCatRuntimeBundle) -> None:
"""绑定当前路由器使用的运行时依赖。 """绑定当前路由器使用的运行时依赖。
@@ -64,6 +76,7 @@ class NapCatEventRouter:
runtime = self._runtime runtime = self._runtime
if runtime is None: if runtime is None:
return return
self._cancel_recovery_task()
runtime.official_bot_guard.clear_cache() runtime.official_bot_guard.clear_cache()
async def handle_transport_payload(self, payload: NapCatPayloadDict) -> None: async def handle_transport_payload(self, payload: NapCatPayloadDict) -> None:
@@ -82,7 +95,7 @@ class NapCatEventRouter:
if post_type == "meta_event": if post_type == "meta_event":
await self.handle_meta_event(payload) await self.handle_meta_event(payload)
async def handle_inbound_message(self, payload: NapCatPayloadDict) -> None: async def handle_inbound_message(self, payload: NapCatPayloadDict) -> bool:
"""处理单条 NapCat 入站消息并注入 Host。 """处理单条 NapCat 入站消息并注入 Host。
Args: Args:
@@ -101,25 +114,25 @@ class NapCatEventRouter:
sender_user_id = str(payload.get("user_id") or sender.get("user_id") or "").strip() sender_user_id = str(payload.get("user_id") or sender.get("user_id") or "").strip()
if not sender_user_id: if not sender_user_id:
return return False
group_id = str(payload.get("group_id") or "").strip() group_id = str(payload.get("group_id") or "").strip()
if self_id and sender_user_id == self_id and settings.filters.ignore_self_message: if self_id and sender_user_id == self_id and settings.filters.ignore_self_message:
return return False
if not runtime.chat_filter.is_inbound_chat_allowed(sender_user_id, group_id, settings.chat): if not runtime.chat_filter.is_inbound_chat_allowed(sender_user_id, group_id, settings.chat):
return return False
if await runtime.official_bot_guard.should_reject( if await runtime.official_bot_guard.should_reject(
sender_user_id=sender_user_id, sender_user_id=sender_user_id,
group_id=group_id, group_id=group_id,
ban_qq_bot=settings.chat.ban_qq_bot, ban_qq_bot=settings.chat.ban_qq_bot,
): ):
return return False
try: try:
message_dict = await runtime.inbound_codec.build_message_dict(payload, self_id, sender_user_id, sender) message_dict = await runtime.inbound_codec.build_message_dict(payload, self_id, sender_user_id, sender)
except ValueError as exc: except ValueError as exc:
self._logger.warning(f"NapCat 入站消息格式不受支持,已丢弃: {exc}") self._logger.warning(f"NapCat 入站消息格式不受支持,已丢弃: {exc}")
return return False
route_metadata = self._build_route_metadata(self_id, settings.napcat_server.connection_id) route_metadata = self._build_route_metadata(self_id, settings.napcat_server.connection_id)
external_message_id = str(payload.get("message_id") or "").strip() external_message_id = str(payload.get("message_id") or "").strip()
@@ -132,6 +145,15 @@ class NapCatEventRouter:
) )
if not accepted: if not accepted:
self._logger.debug(f"Host 丢弃了 NapCat 入站消息: {external_message_id or '无消息 ID'}") self._logger.debug(f"Host 丢弃了 NapCat 入站消息: {external_message_id or '无消息 ID'}")
return False
await self._record_inbound_checkpoint(
payload=payload,
self_id=self_id,
external_message_id=external_message_id or str(message_dict.get("message_id") or "").strip(),
scope=settings.napcat_server.connection_id,
)
return True
async def handle_notice_event(self, payload: NapCatPayloadDict) -> None: async def handle_notice_event(self, payload: NapCatPayloadDict) -> None:
"""处理 NapCat ``notice`` 事件并注入 Host。 """处理 NapCat ``notice`` 事件并注入 Host。
@@ -232,6 +254,8 @@ class NapCatEventRouter:
await runtime.runtime_state.report_connected(self_id, settings.napcat_server) await runtime.runtime_state.report_connected(self_id, settings.napcat_server)
await runtime.heartbeat_monitor.start(self_id, settings.napcat_server.heartbeat_interval) await runtime.heartbeat_monitor.start(self_id, settings.napcat_server.heartbeat_interval)
await runtime.ban_tracker.start() await runtime.ban_tracker.start()
await runtime.history_recovery_store.load()
self._schedule_history_recovery(self_id=self_id, scope=settings.napcat_server.connection_id)
return return
except asyncio.CancelledError: except asyncio.CancelledError:
raise raise
@@ -279,6 +303,274 @@ class NapCatEventRouter:
raise RuntimeError("NapCat 运行时尚未初始化") raise RuntimeError("NapCat 运行时尚未初始化")
return runtime return runtime
def _schedule_history_recovery(self, self_id: str, scope: str) -> None:
"""在连接恢复后调度一次历史补拉任务。"""
self._cancel_recovery_task()
runtime = self._runtime
if runtime is None:
return
self._recovery_task = asyncio.create_task(
self._recover_recent_history(self_id=self_id, scope=scope),
name="napcat_adapter.history_recovery",
)
def _cancel_recovery_task(self) -> None:
"""取消当前仍在运行的历史补拉任务。"""
recovery_task = self._recovery_task
self._recovery_task = None
if recovery_task is not None and not recovery_task.done():
recovery_task.cancel()
async def _recover_recent_history(self, *, self_id: str, scope: str) -> None:
"""按 checkpoint 列表逐个尝试补拉断线期间遗漏的消息。"""
runtime = self._require_runtime()
checkpoints = await runtime.history_recovery_store.list_checkpoints(
self_id,
scope=scope,
limit=DEFAULT_HISTORY_RECOVERY_CHECKPOINT_LIMIT,
)
if not checkpoints:
return
recovered_count = 0
for checkpoint in checkpoints:
recovered_count += await self._recover_chat_history_from_checkpoint(
self_id=self_id,
scope=scope,
checkpoint=checkpoint,
)
if recovered_count > 0:
self._logger.info(f"NapCat 历史补拉完成,共补回 {recovered_count} 条消息")
async def _recover_chat_history_from_checkpoint(
self,
*,
self_id: str,
scope: str,
checkpoint: NapCatChatCheckpoint,
) -> int:
"""针对单个会话执行一次小批量历史补拉。"""
runtime = self._require_runtime()
history_messages = await self._query_history_messages(checkpoint, limit=DEFAULT_HISTORY_RECOVERY_BATCH_SIZE)
if not history_messages:
return 0
ordered_messages = sorted(
history_messages,
key=lambda item: (
self._extract_message_timestamp(item),
self._extract_message_seq(item),
str(item.get("message_id") or "").strip(),
),
)
recovered_count = 0
for history_payload in ordered_messages:
external_message_id = str(history_payload.get("message_id") or "").strip()
if not external_message_id:
continue
if external_message_id == checkpoint.last_message_id:
continue
if await runtime.history_recovery_store.has_recovered_message_seen(
account_id=self_id,
scope=scope,
chat_type=checkpoint.chat_type,
chat_id=checkpoint.chat_id,
external_message_id=external_message_id,
):
continue
if not self._is_message_after_checkpoint(history_payload, checkpoint):
continue
accepted = await self._reinject_history_payload(history_payload, self_id=self_id)
if not accepted:
continue
await runtime.history_recovery_store.mark_recovered_message_seen(
account_id=self_id,
scope=scope,
chat_type=checkpoint.chat_type,
chat_id=checkpoint.chat_id,
external_message_id=external_message_id,
)
recovered_count += 1
return recovered_count
async def _query_history_messages(
self,
checkpoint: NapCatChatCheckpoint,
*,
limit: int,
) -> list[NapCatPayloadDict]:
"""查询某个会话在 checkpoint 之后的一小批历史消息。"""
runtime = self._require_runtime()
payload_collections: list[list[NapCatPayloadDict]] = []
if checkpoint.last_message_seq is not None:
payload_collections.append(
await self._fetch_history_messages(
chat_type=checkpoint.chat_type,
chat_id=checkpoint.chat_id,
message_seq=checkpoint.last_message_seq,
limit=limit,
)
)
payload_collections.append(
await self._fetch_history_messages(
chat_type=checkpoint.chat_type,
chat_id=checkpoint.chat_id,
message_seq=None,
limit=limit,
)
)
merged_payloads: list[NapCatPayloadDict] = []
seen_message_ids: set[str] = set()
for payloads in payload_collections:
for payload in payloads:
external_message_id = str(payload.get("message_id") or "").strip()
dedupe_key = external_message_id or repr(sorted(payload.items()))
if dedupe_key in seen_message_ids:
continue
seen_message_ids.add(dedupe_key)
merged_payloads.append(payload)
return merged_payloads
async def _fetch_history_messages(
self,
*,
chat_type: str,
chat_id: str,
message_seq: int | None,
limit: int,
) -> list[NapCatPayloadDict]:
"""调用查询服务获取一批历史消息。"""
runtime = self._require_runtime()
if chat_type == "group":
history_payloads = await runtime.query_service.get_group_message_history(
chat_id,
message_seq=message_seq,
count=limit,
reverse_order=False,
)
elif chat_type == "private":
history_payloads = await runtime.query_service.get_friend_message_history(
chat_id,
message_seq=message_seq,
count=limit,
reverse_order=False,
)
else:
return []
if history_payloads is None:
return []
return [dict(payload) for payload in history_payloads if isinstance(payload, Mapping)]
async def _reinject_history_payload(self, payload: NapCatPayloadDict, *, self_id: str) -> bool:
"""将补拉到的历史消息重新送回实时入站路径。"""
try:
normalized_payload = dict(payload)
if self_id and not str(normalized_payload.get("self_id") or "").strip():
normalized_payload["self_id"] = self_id
return await self.handle_inbound_message(normalized_payload)
except asyncio.CancelledError:
raise
except Exception as exc:
external_message_id = str(payload.get("message_id") or "").strip() or "unknown"
self._logger.warning(f"NapCat 历史消息补拉注入失败: message_id={external_message_id} error={exc}")
return False
async def _record_inbound_checkpoint(
self,
*,
payload: NapCatPayloadDict,
self_id: str,
external_message_id: str,
scope: str,
) -> None:
"""在消息被 Host 接受后更新该会话的最新 checkpoint。"""
runtime = self._require_runtime()
chat_identity = self._extract_chat_identity(payload)
if chat_identity is None:
return
await runtime.history_recovery_store.record_checkpoint(
account_id=self_id,
scope=scope,
chat_type=chat_identity.chat_type,
chat_id=chat_identity.chat_id,
message_id=external_message_id,
message_time=self._extract_message_timestamp(payload),
message_seq=self._extract_message_seq(payload),
)
@staticmethod
def _extract_chat_identity(payload: Mapping[str, Any]) -> _NapCatChatIdentity | None:
"""从 NapCat 载荷中提取会话身份。"""
group_id = str(payload.get("group_id") or "").strip()
user_id = str(payload.get("user_id") or "").strip()
if group_id:
return _NapCatChatIdentity(chat_type="group", chat_id=group_id)
if user_id:
return _NapCatChatIdentity(chat_type="private", chat_id=user_id)
return None
@staticmethod
def _extract_message_seq(payload: Mapping[str, Any]) -> int | None:
"""从 NapCat 载荷中提取历史接口可复用的消息序号。"""
for field_name in ("message_seq", "messageSeq", "msg_seq"):
raw_value = payload.get(field_name)
if raw_value is None or str(raw_value).strip() == "":
continue
try:
return int(raw_value)
except (TypeError, ValueError):
continue
return None
@staticmethod
def _extract_message_timestamp(payload: Mapping[str, Any]) -> float:
"""从 NapCat 载荷中提取消息时间戳。"""
raw_timestamp = payload.get("time")
if isinstance(raw_timestamp, (int, float)):
return float(raw_timestamp)
return 0.0
@classmethod
def _is_message_after_checkpoint(
cls,
payload: Mapping[str, Any],
checkpoint: NapCatChatCheckpoint,
) -> bool:
"""判断历史消息是否位于 checkpoint 之后。"""
payload_message_id = str(payload.get("message_id") or "").strip()
if payload_message_id == checkpoint.last_message_id:
return False
payload_message_seq = cls._extract_message_seq(payload)
if payload_message_seq is not None and checkpoint.last_message_seq is not None:
return payload_message_seq > checkpoint.last_message_seq
payload_timestamp = cls._extract_message_timestamp(payload)
if payload_timestamp != checkpoint.last_message_time:
return payload_timestamp > checkpoint.last_message_time
return True
@staticmethod @staticmethod
def _build_route_metadata(self_id: str, connection_id: str) -> Dict[str, Any]: def _build_route_metadata(self_id: str, connection_id: str) -> Dict[str, Any]:
"""构造注入 Host 时使用的路由元数据。 """构造注入 Host 时使用的路由元数据。

View File

@@ -3,6 +3,7 @@
from .action_service import NapCatActionService from .action_service import NapCatActionService
from .ban_tracker import NapCatBanTracker from .ban_tracker import NapCatBanTracker
from .ban_state_store import NapCatBanRecord, NapCatBanStateStore from .ban_state_store import NapCatBanRecord, NapCatBanStateStore
from .history_recovery_store import NapCatChatCheckpoint, NapCatHistoryRecoveryStore
from .official_bot_guard import NapCatOfficialBotGuard from .official_bot_guard import NapCatOfficialBotGuard
from .query_service import NapCatQueryService from .query_service import NapCatQueryService
@@ -11,6 +12,8 @@ __all__ = [
"NapCatBanRecord", "NapCatBanRecord",
"NapCatBanStateStore", "NapCatBanStateStore",
"NapCatBanTracker", "NapCatBanTracker",
"NapCatChatCheckpoint",
"NapCatHistoryRecoveryStore",
"NapCatOfficialBotGuard", "NapCatOfficialBotGuard",
"NapCatQueryService", "NapCatQueryService",
] ]

View File

@@ -0,0 +1,423 @@
"""NapCat 历史补拉状态持久化仓库。"""
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, List, Optional, TypeVar
import asyncio
import sqlite3
import time
from ..constants import DEFAULT_HISTORY_RECOVERY_SEEN_TTL_SEC
_PROJECT_ROOT = Path(__file__).resolve().parents[3]
_DEFAULT_STORAGE_PATH = _PROJECT_ROOT / "data" / "napcat_adapter" / "history_recovery.sqlite3"
_SCHEMA_STATEMENTS = (
"""
CREATE TABLE IF NOT EXISTS napcat_chat_checkpoint (
account_id TEXT NOT NULL,
scope TEXT NOT NULL,
chat_type TEXT NOT NULL,
chat_id TEXT NOT NULL,
last_message_id TEXT NOT NULL,
last_message_time REAL NOT NULL,
last_message_seq INTEGER,
updated_at REAL NOT NULL,
PRIMARY KEY (account_id, scope, chat_type, chat_id)
)
""",
"""
CREATE INDEX IF NOT EXISTS ix_napcat_chat_checkpoint_updated_at
ON napcat_chat_checkpoint (updated_at DESC)
""",
"""
CREATE TABLE IF NOT EXISTS napcat_recovery_seen (
account_id TEXT NOT NULL,
scope TEXT NOT NULL,
chat_type TEXT NOT NULL,
chat_id TEXT NOT NULL,
external_message_id TEXT NOT NULL,
seen_at REAL NOT NULL,
PRIMARY KEY (account_id, scope, chat_type, chat_id, external_message_id)
)
""",
"""
CREATE INDEX IF NOT EXISTS ix_napcat_recovery_seen_seen_at
ON napcat_recovery_seen (seen_at DESC)
""",
)
T = TypeVar("T")
@dataclass(frozen=True)
class NapCatChatCheckpoint:
"""描述一个会话的最近入站锚点。"""
account_id: str
scope: str
chat_type: str
chat_id: str
last_message_id: str
last_message_time: float
last_message_seq: int | None
updated_at: float
@classmethod
def from_row(cls, row: sqlite3.Row) -> "NapCatChatCheckpoint":
"""从 SQLite 行对象恢复 checkpoint。"""
last_message_seq = row["last_message_seq"]
normalized_seq = int(last_message_seq) if isinstance(last_message_seq, int) else None
return cls(
account_id=str(row["account_id"] or "").strip(),
scope=str(row["scope"] or "").strip(),
chat_type=str(row["chat_type"] or "").strip(),
chat_id=str(row["chat_id"] or "").strip(),
last_message_id=str(row["last_message_id"] or "").strip(),
last_message_time=float(row["last_message_time"] or 0.0),
last_message_seq=normalized_seq,
updated_at=float(row["updated_at"] or 0.0),
)
class NapCatHistoryRecoveryStore:
"""负责持久化历史补拉所需的会话状态与去重状态。"""
def __init__(self, logger: Any, storage_path: Path = _DEFAULT_STORAGE_PATH) -> None:
"""初始化历史补拉状态仓库。"""
self._logger = logger
self._storage_path = storage_path
self._store_lock = asyncio.Lock()
self._schema_ready = False
async def load(self) -> None:
"""初始化 SQLite 文件并清理过期去重记录。"""
await self._execute_locked(self._ensure_schema)
pruned_count = await self.prune_recovery_seen(DEFAULT_HISTORY_RECOVERY_SEEN_TTL_SEC)
if pruned_count > 0:
self._logger.debug(f"NapCat 历史补拉去重表已清理 {pruned_count} 条过期记录")
async def list_checkpoints(self, account_id: str, scope: str = "", limit: int = 50) -> List[NapCatChatCheckpoint]:
"""列出指定账号与作用域下的最近会话 checkpoint。"""
normalized_account_id = str(account_id or "").strip()
if not normalized_account_id:
return []
normalized_scope = self._normalize_scope(scope)
normalized_limit = max(1, int(limit))
def _operation(conn: sqlite3.Connection) -> List[NapCatChatCheckpoint]:
cursor = conn.execute(
"""
SELECT
account_id,
scope,
chat_type,
chat_id,
last_message_id,
last_message_time,
last_message_seq,
updated_at
FROM napcat_chat_checkpoint
WHERE account_id = ? AND scope = ?
ORDER BY updated_at DESC
LIMIT ?
""",
(normalized_account_id, normalized_scope, normalized_limit),
)
return [NapCatChatCheckpoint.from_row(row) for row in cursor.fetchall()]
return await self._execute_locked(_operation)
async def record_checkpoint(
self,
*,
account_id: str,
scope: str = "",
chat_type: str,
chat_id: str,
message_id: str,
message_time: float,
message_seq: int | None = None,
) -> None:
"""记录一条已被 Host 接受的最新入站消息锚点。"""
normalized_account_id = str(account_id or "").strip()
normalized_scope = self._normalize_scope(scope)
normalized_chat_type = str(chat_type or "").strip()
normalized_chat_id = str(chat_id or "").strip()
normalized_message_id = str(message_id or "").strip()
if not (
normalized_account_id
and normalized_chat_type
and normalized_chat_id
and normalized_message_id
):
return
normalized_message_time = float(message_time or 0.0)
normalized_message_seq = self._normalize_message_seq(message_seq)
updated_at = time.time()
def _operation(conn: sqlite3.Connection) -> None:
cursor = conn.execute(
"""
SELECT last_message_id, last_message_time, last_message_seq
FROM napcat_chat_checkpoint
WHERE account_id = ? AND scope = ? AND chat_type = ? AND chat_id = ?
""",
(
normalized_account_id,
normalized_scope,
normalized_chat_type,
normalized_chat_id,
),
)
existing_row = cursor.fetchone()
if existing_row is not None and not self._should_advance_checkpoint(
existing_row=existing_row,
message_id=normalized_message_id,
message_time=normalized_message_time,
message_seq=normalized_message_seq,
):
return
conn.execute(
"""
INSERT INTO napcat_chat_checkpoint (
account_id,
scope,
chat_type,
chat_id,
last_message_id,
last_message_time,
last_message_seq,
updated_at
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(account_id, scope, chat_type, chat_id) DO UPDATE SET
last_message_id = excluded.last_message_id,
last_message_time = excluded.last_message_time,
last_message_seq = excluded.last_message_seq,
updated_at = excluded.updated_at
""",
(
normalized_account_id,
normalized_scope,
normalized_chat_type,
normalized_chat_id,
normalized_message_id,
normalized_message_time,
normalized_message_seq,
updated_at,
),
)
await self._execute_locked(_operation)
async def has_recovered_message_seen(
self,
*,
account_id: str,
scope: str = "",
chat_type: str,
chat_id: str,
external_message_id: str,
) -> bool:
"""判断某条历史补拉消息是否已经被当前仓库记录过。"""
normalized_account_id = str(account_id or "").strip()
normalized_scope = self._normalize_scope(scope)
normalized_chat_type = str(chat_type or "").strip()
normalized_chat_id = str(chat_id or "").strip()
normalized_external_message_id = str(external_message_id or "").strip()
if not (
normalized_account_id
and normalized_chat_type
and normalized_chat_id
and normalized_external_message_id
):
return False
def _operation(conn: sqlite3.Connection) -> bool:
cursor = conn.execute(
"""
SELECT 1
FROM napcat_recovery_seen
WHERE account_id = ?
AND scope = ?
AND chat_type = ?
AND chat_id = ?
AND external_message_id = ?
LIMIT 1
""",
(
normalized_account_id,
normalized_scope,
normalized_chat_type,
normalized_chat_id,
normalized_external_message_id,
),
)
return cursor.fetchone() is not None
return await self._execute_locked(_operation)
async def mark_recovered_message_seen(
self,
*,
account_id: str,
scope: str = "",
chat_type: str,
chat_id: str,
external_message_id: str,
) -> None:
"""将一条历史补拉消息标记为已尝试处理。"""
normalized_account_id = str(account_id or "").strip()
normalized_scope = self._normalize_scope(scope)
normalized_chat_type = str(chat_type or "").strip()
normalized_chat_id = str(chat_id or "").strip()
normalized_external_message_id = str(external_message_id or "").strip()
if not (
normalized_account_id
and normalized_chat_type
and normalized_chat_id
and normalized_external_message_id
):
return
def _operation(conn: sqlite3.Connection) -> None:
conn.execute(
"""
INSERT OR REPLACE INTO napcat_recovery_seen (
account_id,
scope,
chat_type,
chat_id,
external_message_id,
seen_at
) VALUES (?, ?, ?, ?, ?, ?)
""",
(
normalized_account_id,
normalized_scope,
normalized_chat_type,
normalized_chat_id,
normalized_external_message_id,
time.time(),
),
)
await self._execute_locked(_operation)
async def prune_recovery_seen(self, ttl_seconds: float) -> int:
"""删除超过保留期的历史补拉去重记录。"""
normalized_ttl_seconds = max(0.0, float(ttl_seconds or 0.0))
if normalized_ttl_seconds <= 0.0:
return 0
cutoff_timestamp = time.time() - normalized_ttl_seconds
def _operation(conn: sqlite3.Connection) -> int:
cursor = conn.execute(
"DELETE FROM napcat_recovery_seen WHERE seen_at < ?",
(cutoff_timestamp,),
)
return int(cursor.rowcount or 0)
return await self._execute_locked(_operation)
async def _execute_locked(self, operation: Callable[[sqlite3.Connection], T]) -> T:
"""在锁保护下打开 SQLite 并执行一次原子操作。"""
async with self._store_lock:
conn = self._open_connection()
try:
self._ensure_schema(conn)
result = operation(conn)
conn.commit()
return result
except Exception:
conn.rollback()
raise
finally:
conn.close()
def _open_connection(self) -> sqlite3.Connection:
"""打开一个带 Row 工厂的 SQLite 连接。"""
self._storage_path.parent.mkdir(parents=True, exist_ok=True)
conn = sqlite3.connect(str(self._storage_path))
conn.row_factory = sqlite3.Row
return conn
def _ensure_schema(self, conn: sqlite3.Connection) -> None:
"""确保 SQLite 表结构已经准备完成。"""
if self._schema_ready:
return
for statement in _SCHEMA_STATEMENTS:
conn.execute(statement)
self._schema_ready = True
@staticmethod
def _normalize_scope(scope: str | None) -> str:
"""将空作用域统一折叠为空字符串。"""
return str(scope or "").strip()
@staticmethod
def _normalize_message_seq(message_seq: object) -> int | None:
"""将消息序号规范化为可选整数。"""
try:
if message_seq is None or str(message_seq).strip() == "":
return None
return int(message_seq)
except (TypeError, ValueError):
return None
@classmethod
def _should_advance_checkpoint(
cls,
*,
existing_row: sqlite3.Row,
message_id: str,
message_time: float,
message_seq: int | None,
) -> bool:
"""判断新的锚点是否应覆盖旧锚点。"""
existing_message_id = str(existing_row["last_message_id"] or "").strip()
existing_message_time = float(existing_row["last_message_time"] or 0.0)
existing_message_seq = cls._normalize_message_seq(existing_row["last_message_seq"])
if message_seq is not None and existing_message_seq is not None:
if message_seq != existing_message_seq:
return message_seq > existing_message_seq
if message_id == existing_message_id:
return False
return message_time >= existing_message_time
if message_time != existing_message_time:
return message_time > existing_message_time
if message_id == existing_message_id:
return False
if message_seq is not None and existing_message_seq is None:
return True
return True

View File

@@ -180,6 +180,46 @@ class NapCatQueryService:
response_data = await self._safe_call_action_data("get_msg", {"message_id": message_id}) response_data = await self._safe_call_action_data("get_msg", {"message_id": message_id})
return response_data if isinstance(response_data, dict) else None return response_data if isinstance(response_data, dict) else None
async def get_friend_message_history(
self,
user_id: str,
*,
message_seq: int | None = None,
count: int = 20,
reverse_order: bool = False,
) -> Optional[NapCatPayloadList]:
"""获取私聊历史消息列表。"""
params: NapCatActionResponse = {
"user_id": user_id,
"count": max(1, int(count)),
"reverse_order": bool(reverse_order),
}
if message_seq is not None:
params["message_seq"] = int(message_seq)
response_data = await self._safe_call_action_data("get_friend_msg_history", params)
return self._normalize_payload_list(response_data, action_name="get_friend_msg_history")
async def get_group_message_history(
self,
group_id: str,
*,
message_seq: int | None = None,
count: int = 20,
reverse_order: bool = False,
) -> Optional[NapCatPayloadList]:
"""获取群聊历史消息列表。"""
params: NapCatActionResponse = {
"group_id": group_id,
"count": max(1, int(count)),
"reverse_order": bool(reverse_order),
}
if message_seq is not None:
params["message_seq"] = int(message_seq)
response_data = await self._safe_call_action_data("get_group_msg_history", params)
return self._normalize_payload_list(response_data, action_name="get_group_msg_history")
async def get_forward_message( async def get_forward_message(
self, self,
message_id: Optional[str] = None, message_id: Optional[str] = None,

View File

@@ -12,8 +12,10 @@ import pytest
PROJECT_ROOT = Path(__file__).resolve().parents[1] PROJECT_ROOT = Path(__file__).resolve().parents[1]
PLUGINS_ROOT = PROJECT_ROOT / "plugins" PLUGINS_ROOT = PROJECT_ROOT / "plugins"
PLUGIN_TEMPLATE_ROOT = PROJECT_ROOT / "plugin-templates"
SDK_ROOT = PROJECT_ROOT / "packages" / "maibot-plugin-sdk" SDK_ROOT = PROJECT_ROOT / "packages" / "maibot-plugin-sdk"
NAPCAT_PLUGIN_DIR = PLUGINS_ROOT / "MaiBot-Napcat-Adapter" NAPCAT_PLUGIN_DIR = PLUGINS_ROOT / "MaiBot-Napcat-Adapter"
NAPCAT_TEMPLATE_DIR = PLUGIN_TEMPLATE_ROOT / "MaiBot-Napcat-Adapter"
NAPCAT_TEST_MODULE = "_test_napcat_adapter" NAPCAT_TEST_MODULE = "_test_napcat_adapter"
for import_path in (str(SDK_ROOT),): for import_path in (str(SDK_ROOT),):
@@ -204,12 +206,14 @@ def _load_napcat_sdk_modules() -> Tuple[Any, Any, Any, Any]:
依次返回常量模块、配置模块、插件模块和运行时状态模块。 依次返回常量模块、配置模块、插件模块和运行时状态模块。
""" """
plugin_dir = NAPCAT_PLUGIN_DIR if NAPCAT_PLUGIN_DIR.is_dir() else NAPCAT_TEMPLATE_DIR
if NAPCAT_TEST_MODULE not in sys.modules: if NAPCAT_TEST_MODULE not in sys.modules:
plugin_path = NAPCAT_PLUGIN_DIR / "plugin.py" plugin_path = plugin_dir / "plugin.py"
spec = util.spec_from_file_location( spec = util.spec_from_file_location(
NAPCAT_TEST_MODULE, NAPCAT_TEST_MODULE,
plugin_path, plugin_path,
submodule_search_locations=[str(NAPCAT_PLUGIN_DIR)], submodule_search_locations=[str(plugin_dir)],
) )
if spec is None or spec.loader is None: if spec is None or spec.loader is None:
raise ImportError(f"无法为 NapCat 插件创建模块规格: {plugin_path}") raise ImportError(f"无法为 NapCat 插件创建模块规格: {plugin_path}")

View File

@@ -0,0 +1,376 @@
"""NapCat 历史补拉与恢复状态测试。"""
from __future__ import annotations
from importlib import import_module, util
from pathlib import Path
from typing import Any, Dict, List
import logging
import sys
from types import SimpleNamespace
import pytest
PROJECT_ROOT = Path(__file__).resolve().parents[1]
PLUGINS_ROOT = PROJECT_ROOT / "plugins"
PLUGIN_TEMPLATE_ROOT = PROJECT_ROOT / "plugin-templates"
SDK_ROOT = PROJECT_ROOT / "packages" / "maibot-plugin-sdk"
NAPCAT_PLUGIN_DIR = PLUGINS_ROOT / "MaiBot-Napcat-Adapter"
NAPCAT_TEMPLATE_DIR = PLUGIN_TEMPLATE_ROOT / "MaiBot-Napcat-Adapter"
NAPCAT_TEST_MODULE = "_test_napcat_adapter_history_recovery"
for import_path in (str(SDK_ROOT),):
if import_path not in sys.path:
sys.path.insert(0, import_path)
class _FakeGatewayCapability:
"""用于测试入站注入的网关替身。"""
def __init__(self) -> None:
"""初始化测试替身。"""
self.calls: List[Dict[str, Any]] = []
async def route_message(
self,
gateway_name: str,
message: Dict[str, Any],
*,
route_metadata: Dict[str, Any] | None = None,
external_message_id: str = "",
dedupe_key: str = "",
) -> bool:
"""记录入站注入请求并始终模拟成功。"""
self.calls.append(
{
"gateway_name": gateway_name,
"message": dict(message),
"route_metadata": dict(route_metadata or {}),
"external_message_id": external_message_id,
"dedupe_key": dedupe_key,
}
)
return True
def _resolve_napcat_plugin_dir() -> Path:
"""返回当前测试可用的 NapCat 插件目录。"""
if NAPCAT_PLUGIN_DIR.is_dir():
return NAPCAT_PLUGIN_DIR
return NAPCAT_TEMPLATE_DIR
def _load_napcat_module(module_suffix: str) -> Any:
"""动态加载 NapCat 测试模块。"""
plugin_dir = _resolve_napcat_plugin_dir()
if NAPCAT_TEST_MODULE not in sys.modules:
plugin_path = plugin_dir / "plugin.py"
spec = util.spec_from_file_location(
NAPCAT_TEST_MODULE,
plugin_path,
submodule_search_locations=[str(plugin_dir)],
)
if spec is None or spec.loader is None:
raise ImportError(f"无法为 NapCat 插件创建模块规格: {plugin_path}")
module = util.module_from_spec(spec)
sys.modules[NAPCAT_TEST_MODULE] = module
try:
spec.loader.exec_module(module)
except Exception:
sys.modules.pop(NAPCAT_TEST_MODULE, None)
raise
return import_module(f"{NAPCAT_TEST_MODULE}.{module_suffix}")
def _load_history_recovery_store_cls() -> Any:
"""动态加载历史恢复状态仓库类。"""
return _load_napcat_module("services.history_recovery_store").NapCatHistoryRecoveryStore
def _load_query_service_cls() -> Any:
"""动态加载查询服务类。"""
return _load_napcat_module("services.query_service").NapCatQueryService
def _load_router_cls() -> Any:
"""动态加载事件路由器类。"""
return _load_napcat_module("runtime.router").NapCatEventRouter
class _FakeActionService:
"""用于查询服务的动作服务替身。"""
def __init__(self, response_data: Any) -> None:
"""初始化动作服务替身。"""
self._response_data = response_data
self.action_data_calls: List[Dict[str, Any]] = []
async def safe_call_action_data(self, action_name: str, params: Dict[str, Any]) -> Any:
"""记录安全查询动作。"""
self.action_data_calls.append({"action_name": action_name, "params": dict(params)})
return self._response_data
@pytest.mark.asyncio
async def test_history_recovery_store_persists_checkpoint_and_seen_state(tmp_path: Path) -> None:
"""历史恢复状态仓库应持久化 checkpoint 与已补拉标记。"""
store_cls = _load_history_recovery_store_cls()
store = store_cls(
logger=logging.getLogger("test.napcat.history_store"),
storage_path=tmp_path / "history.sqlite3",
)
await store.load()
await store.record_checkpoint(
account_id="10001",
scope="primary",
chat_type="group",
chat_id="20001",
message_id="msg-2",
message_time=200.0,
message_seq=2,
)
await store.record_checkpoint(
account_id="10001",
scope="primary",
chat_type="group",
chat_id="20001",
message_id="msg-1",
message_time=100.0,
message_seq=1,
)
await store.mark_recovered_message_seen(
account_id="10001",
scope="primary",
chat_type="group",
chat_id="20001",
external_message_id="history-1",
)
checkpoints = await store.list_checkpoints("10001", scope="primary")
assert len(checkpoints) == 1
assert checkpoints[0].last_message_id == "msg-2"
assert checkpoints[0].last_message_seq == 2
assert (
await store.has_recovered_message_seen(
account_id="10001",
scope="primary",
chat_type="group",
chat_id="20001",
external_message_id="history-1",
)
is True
)
@pytest.mark.asyncio
async def test_query_service_wraps_group_and_friend_history_actions() -> None:
"""查询服务应按官方动作名封装历史消息接口。"""
query_service_cls = _load_query_service_cls()
action_service = _FakeActionService([{"message_id": "msg-1"}])
query_service = query_service_cls(
action_service=action_service,
logger=logging.getLogger("test.napcat.history_query"),
)
group_payload = await query_service.get_group_message_history("20001", message_seq=123, count=10)
friend_payload = await query_service.get_friend_message_history("30001", count=5, reverse_order=True)
assert group_payload == [{"message_id": "msg-1"}]
assert friend_payload == [{"message_id": "msg-1"}]
assert action_service.action_data_calls == [
{
"action_name": "get_group_msg_history",
"params": {"group_id": "20001", "count": 10, "reverse_order": False, "message_seq": 123},
},
{
"action_name": "get_friend_msg_history",
"params": {"user_id": "30001", "count": 5, "reverse_order": True},
},
]
@pytest.mark.asyncio
async def test_router_recover_recent_history_reinjects_messages_in_order(tmp_path: Path) -> None:
"""重连补拉应按时间顺序将历史消息重新注入原入站路径。"""
history_store_cls = _load_history_recovery_store_cls()
router_cls = _load_router_cls()
gateway_capability = _FakeGatewayCapability()
router = router_cls(
gateway_capability=gateway_capability,
logger=logging.getLogger("test.napcat.history_router"),
gateway_name="napcat_gateway",
load_settings=lambda: SimpleNamespace(
napcat_server=SimpleNamespace(connection_id="primary", heartbeat_interval=30.0),
filters=SimpleNamespace(ignore_self_message=True),
chat=SimpleNamespace(ban_qq_bot=False),
),
)
history_store = history_store_cls(
logger=logging.getLogger("test.napcat.history_router.store"),
storage_path=tmp_path / "router.sqlite3",
)
await history_store.load()
await history_store.record_checkpoint(
account_id="10001",
scope="primary",
chat_type="group",
chat_id="20001",
message_id="msg-1",
message_time=100.0,
message_seq=10,
)
history_calls: List[Dict[str, Any]] = []
history_payloads = [
{
"post_type": "message",
"message_type": "group",
"self_id": "10001",
"group_id": "20001",
"user_id": "30002",
"message_id": "msg-3",
"message_seq": 12,
"time": 102,
"message": [{"type": "text", "data": {"text": "第三条"}}],
"sender": {"user_id": "30002", "nickname": "用户二"},
},
{
"post_type": "message",
"message_type": "group",
"self_id": "10001",
"group_id": "20001",
"user_id": "30001",
"message_id": "msg-2",
"message_seq": 11,
"time": 101,
"message": [{"type": "text", "data": {"text": "第二条"}}],
"sender": {"user_id": "30001", "nickname": "用户一"},
},
]
class _FakeQueryService:
async def get_group_message_history(
self,
group_id: str,
*,
message_seq: int | None = None,
count: int = 20,
reverse_order: bool = False,
) -> List[Dict[str, Any]]:
history_calls.append(
{
"group_id": group_id,
"message_seq": message_seq,
"count": count,
"reverse_order": reverse_order,
}
)
return list(history_payloads)
async def get_friend_message_history(self, user_id: str, **kwargs: Any) -> List[Dict[str, Any]]:
del user_id
del kwargs
return []
class _FakeInboundCodec:
@staticmethod
async def build_message_dict(
payload: Dict[str, Any],
self_id: str,
sender_user_id: str,
sender: Dict[str, Any],
) -> Dict[str, Any]:
del self_id
del sender_user_id
del sender
return {
"message_id": str(payload["message_id"]),
"platform": "qq",
"timestamp": str(float(payload["time"])),
"message_info": {
"user_info": {"user_id": str(payload["user_id"]), "user_nickname": "测试用户"},
"group_info": {"group_id": str(payload["group_id"]), "group_name": "测试群"},
"additional_config": {},
},
"raw_message": [{"type": "text", "data": str(payload["message"][0]["data"]["text"])}],
"processed_plain_text": str(payload["message"][0]["data"]["text"]),
"display_message": str(payload["message"][0]["data"]["text"]),
"is_mentioned": False,
"is_at": False,
"is_emoji": False,
"is_picture": False,
"is_command": False,
"is_notify": False,
"session_id": "",
}
router.bind_runtime(
SimpleNamespace(
runtime_state=SimpleNamespace(report_connected=lambda *args, **kwargs: _noop_async(), report_disconnected=_noop_async),
chat_filter=SimpleNamespace(is_inbound_chat_allowed=lambda *args, **kwargs: True),
official_bot_guard=SimpleNamespace(
should_reject=lambda *args, **kwargs: _return_false_async(),
clear_cache=lambda: None,
),
inbound_codec=_FakeInboundCodec(),
history_recovery_store=history_store,
query_service=_FakeQueryService(),
heartbeat_monitor=SimpleNamespace(start=_noop_async, stop=_noop_async),
ban_tracker=SimpleNamespace(start=_noop_async, stop=_noop_async, record_notice=_noop_async),
notice_codec=SimpleNamespace(handle_meta_event=_noop_async, build_notice_message_dict=_return_none_async),
)
)
await router._recover_recent_history(self_id="10001", scope="primary")
assert history_calls == [
{
"group_id": "20001",
"message_seq": 10,
"count": 20,
"reverse_order": False,
}
]
assert [call["external_message_id"] for call in gateway_capability.calls] == ["msg-2", "msg-3"]
assert [call["message"]["message_id"] for call in gateway_capability.calls] == ["msg-2", "msg-3"]
async def _noop_async(*args: Any, **kwargs: Any) -> None:
"""无操作异步函数。"""
del args
del kwargs
async def _return_false_async(*args: Any, **kwargs: Any) -> bool:
"""返回 ``False`` 的异步测试替身。"""
del args
del kwargs
return False
async def _return_none_async(*args: Any, **kwargs: Any) -> None:
"""返回 ``None`` 的异步测试替身。"""
del args
del kwargs
return None