feat:新增napcat断线后重连重新拉取历史消息的机制

This commit is contained in:
LoveLosita
2026-05-11 15:03:36 +08:00
parent 388f965d09
commit a566313c5f
9 changed files with 1154 additions and 8 deletions

View File

@@ -8,3 +8,6 @@ DEFAULT_RECONNECT_DELAY_SEC = 5.0
DEFAULT_HEARTBEAT_INTERVAL_SEC = 30.0
DEFAULT_ACTION_TIMEOUT_SEC = 15.0
DEFAULT_CHAT_LIST_TYPE = "whitelist"
DEFAULT_HISTORY_RECOVERY_BATCH_SIZE = 20
DEFAULT_HISTORY_RECOVERY_CHECKPOINT_LIMIT = 50
DEFAULT_HISTORY_RECOVERY_SEEN_TTL_SEC = 86400.0 * 7

View File

@@ -14,6 +14,7 @@ from ..services import (
NapCatActionService,
NapCatBanStateStore,
NapCatBanTracker,
NapCatHistoryRecoveryStore,
NapCatOfficialBotGuard,
NapCatQueryService,
)
@@ -66,6 +67,7 @@ class NapCatRuntimeBuilder:
action_service = NapCatActionService(self._logger, transport)
query_service = NapCatQueryService(action_service, self._logger)
ban_state_store = NapCatBanStateStore(self._logger)
history_recovery_store = NapCatHistoryRecoveryStore(self._logger)
inbound_codec = NapCatInboundCodec(self._logger, query_service)
notice_codec = NapCatNoticeCodec(self._logger, query_service)
runtime_state = NapCatRuntimeStateManager(
@@ -92,6 +94,7 @@ class NapCatRuntimeBuilder:
ban_tracker=ban_tracker,
chat_filter=chat_filter,
heartbeat_monitor=heartbeat_monitor,
history_recovery_store=history_recovery_store,
inbound_codec=inbound_codec,
notice_codec=notice_codec,
official_bot_guard=official_bot_guard,

View File

@@ -14,6 +14,7 @@ from ..services import (
NapCatActionService,
NapCatBanStateStore,
NapCatBanTracker,
NapCatHistoryRecoveryStore,
NapCatOfficialBotGuard,
NapCatQueryService,
)
@@ -29,6 +30,7 @@ class NapCatRuntimeBundle:
ban_tracker: NapCatBanTracker
chat_filter: NapCatChatFilter
heartbeat_monitor: NapCatHeartbeatMonitor
history_recovery_store: NapCatHistoryRecoveryStore
inbound_codec: NapCatInboundCodec
notice_codec: NapCatNoticeCodec
official_bot_guard: NapCatOfficialBotGuard

View File

@@ -2,11 +2,14 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Callable, Dict, Mapping, Optional, Protocol
import asyncio
from ..config import NapCatPluginSettings
from ..constants import DEFAULT_HISTORY_RECOVERY_BATCH_SIZE, DEFAULT_HISTORY_RECOVERY_CHECKPOINT_LIMIT
from ..services import NapCatChatCheckpoint
from ..types import NapCatPayloadDict
from .bundle import NapCatRuntimeBundle
@@ -27,6 +30,14 @@ class _GatewayCapabilityProtocol(Protocol):
...
@dataclass(frozen=True)
class _NapCatChatIdentity:
"""描述一条 NapCat 消息所属的会话身份。"""
chat_type: str
chat_id: str
class NapCatEventRouter:
"""协调 NapCat 运行时组件处理各类平台事件。"""
@@ -50,6 +61,7 @@ class NapCatEventRouter:
self._gateway_name = gateway_name
self._load_settings = load_settings
self._runtime: Optional[NapCatRuntimeBundle] = None
self._recovery_task: Optional[asyncio.Task[None]] = None
def bind_runtime(self, runtime: NapCatRuntimeBundle) -> None:
"""绑定当前路由器使用的运行时依赖。
@@ -64,6 +76,7 @@ class NapCatEventRouter:
runtime = self._runtime
if runtime is None:
return
self._cancel_recovery_task()
runtime.official_bot_guard.clear_cache()
async def handle_transport_payload(self, payload: NapCatPayloadDict) -> None:
@@ -82,7 +95,7 @@ class NapCatEventRouter:
if post_type == "meta_event":
await self.handle_meta_event(payload)
async def handle_inbound_message(self, payload: NapCatPayloadDict) -> None:
async def handle_inbound_message(self, payload: NapCatPayloadDict) -> bool:
"""处理单条 NapCat 入站消息并注入 Host。
Args:
@@ -101,25 +114,25 @@ class NapCatEventRouter:
sender_user_id = str(payload.get("user_id") or sender.get("user_id") or "").strip()
if not sender_user_id:
return
return False
group_id = str(payload.get("group_id") or "").strip()
if self_id and sender_user_id == self_id and settings.filters.ignore_self_message:
return
return False
if not runtime.chat_filter.is_inbound_chat_allowed(sender_user_id, group_id, settings.chat):
return
return False
if await runtime.official_bot_guard.should_reject(
sender_user_id=sender_user_id,
group_id=group_id,
ban_qq_bot=settings.chat.ban_qq_bot,
):
return
return False
try:
message_dict = await runtime.inbound_codec.build_message_dict(payload, self_id, sender_user_id, sender)
except ValueError as exc:
self._logger.warning(f"NapCat 入站消息格式不受支持,已丢弃: {exc}")
return
return False
route_metadata = self._build_route_metadata(self_id, settings.napcat_server.connection_id)
external_message_id = str(payload.get("message_id") or "").strip()
@@ -132,6 +145,15 @@ class NapCatEventRouter:
)
if not accepted:
self._logger.debug(f"Host 丢弃了 NapCat 入站消息: {external_message_id or '无消息 ID'}")
return False
await self._record_inbound_checkpoint(
payload=payload,
self_id=self_id,
external_message_id=external_message_id or str(message_dict.get("message_id") or "").strip(),
scope=settings.napcat_server.connection_id,
)
return True
async def handle_notice_event(self, payload: NapCatPayloadDict) -> None:
"""处理 NapCat ``notice`` 事件并注入 Host。
@@ -232,6 +254,8 @@ class NapCatEventRouter:
await runtime.runtime_state.report_connected(self_id, settings.napcat_server)
await runtime.heartbeat_monitor.start(self_id, settings.napcat_server.heartbeat_interval)
await runtime.ban_tracker.start()
await runtime.history_recovery_store.load()
self._schedule_history_recovery(self_id=self_id, scope=settings.napcat_server.connection_id)
return
except asyncio.CancelledError:
raise
@@ -279,6 +303,274 @@ class NapCatEventRouter:
raise RuntimeError("NapCat 运行时尚未初始化")
return runtime
def _schedule_history_recovery(self, self_id: str, scope: str) -> None:
"""在连接恢复后调度一次历史补拉任务。"""
self._cancel_recovery_task()
runtime = self._runtime
if runtime is None:
return
self._recovery_task = asyncio.create_task(
self._recover_recent_history(self_id=self_id, scope=scope),
name="napcat_adapter.history_recovery",
)
def _cancel_recovery_task(self) -> None:
"""取消当前仍在运行的历史补拉任务。"""
recovery_task = self._recovery_task
self._recovery_task = None
if recovery_task is not None and not recovery_task.done():
recovery_task.cancel()
async def _recover_recent_history(self, *, self_id: str, scope: str) -> None:
"""按 checkpoint 列表逐个尝试补拉断线期间遗漏的消息。"""
runtime = self._require_runtime()
checkpoints = await runtime.history_recovery_store.list_checkpoints(
self_id,
scope=scope,
limit=DEFAULT_HISTORY_RECOVERY_CHECKPOINT_LIMIT,
)
if not checkpoints:
return
recovered_count = 0
for checkpoint in checkpoints:
recovered_count += await self._recover_chat_history_from_checkpoint(
self_id=self_id,
scope=scope,
checkpoint=checkpoint,
)
if recovered_count > 0:
self._logger.info(f"NapCat 历史补拉完成,共补回 {recovered_count} 条消息")
async def _recover_chat_history_from_checkpoint(
self,
*,
self_id: str,
scope: str,
checkpoint: NapCatChatCheckpoint,
) -> int:
"""针对单个会话执行一次小批量历史补拉。"""
runtime = self._require_runtime()
history_messages = await self._query_history_messages(checkpoint, limit=DEFAULT_HISTORY_RECOVERY_BATCH_SIZE)
if not history_messages:
return 0
ordered_messages = sorted(
history_messages,
key=lambda item: (
self._extract_message_timestamp(item),
self._extract_message_seq(item),
str(item.get("message_id") or "").strip(),
),
)
recovered_count = 0
for history_payload in ordered_messages:
external_message_id = str(history_payload.get("message_id") or "").strip()
if not external_message_id:
continue
if external_message_id == checkpoint.last_message_id:
continue
if await runtime.history_recovery_store.has_recovered_message_seen(
account_id=self_id,
scope=scope,
chat_type=checkpoint.chat_type,
chat_id=checkpoint.chat_id,
external_message_id=external_message_id,
):
continue
if not self._is_message_after_checkpoint(history_payload, checkpoint):
continue
accepted = await self._reinject_history_payload(history_payload, self_id=self_id)
if not accepted:
continue
await runtime.history_recovery_store.mark_recovered_message_seen(
account_id=self_id,
scope=scope,
chat_type=checkpoint.chat_type,
chat_id=checkpoint.chat_id,
external_message_id=external_message_id,
)
recovered_count += 1
return recovered_count
async def _query_history_messages(
self,
checkpoint: NapCatChatCheckpoint,
*,
limit: int,
) -> list[NapCatPayloadDict]:
"""查询某个会话在 checkpoint 之后的一小批历史消息。"""
runtime = self._require_runtime()
payload_collections: list[list[NapCatPayloadDict]] = []
if checkpoint.last_message_seq is not None:
payload_collections.append(
await self._fetch_history_messages(
chat_type=checkpoint.chat_type,
chat_id=checkpoint.chat_id,
message_seq=checkpoint.last_message_seq,
limit=limit,
)
)
payload_collections.append(
await self._fetch_history_messages(
chat_type=checkpoint.chat_type,
chat_id=checkpoint.chat_id,
message_seq=None,
limit=limit,
)
)
merged_payloads: list[NapCatPayloadDict] = []
seen_message_ids: set[str] = set()
for payloads in payload_collections:
for payload in payloads:
external_message_id = str(payload.get("message_id") or "").strip()
dedupe_key = external_message_id or repr(sorted(payload.items()))
if dedupe_key in seen_message_ids:
continue
seen_message_ids.add(dedupe_key)
merged_payloads.append(payload)
return merged_payloads
async def _fetch_history_messages(
self,
*,
chat_type: str,
chat_id: str,
message_seq: int | None,
limit: int,
) -> list[NapCatPayloadDict]:
"""调用查询服务获取一批历史消息。"""
runtime = self._require_runtime()
if chat_type == "group":
history_payloads = await runtime.query_service.get_group_message_history(
chat_id,
message_seq=message_seq,
count=limit,
reverse_order=False,
)
elif chat_type == "private":
history_payloads = await runtime.query_service.get_friend_message_history(
chat_id,
message_seq=message_seq,
count=limit,
reverse_order=False,
)
else:
return []
if history_payloads is None:
return []
return [dict(payload) for payload in history_payloads if isinstance(payload, Mapping)]
async def _reinject_history_payload(self, payload: NapCatPayloadDict, *, self_id: str) -> bool:
"""将补拉到的历史消息重新送回实时入站路径。"""
try:
normalized_payload = dict(payload)
if self_id and not str(normalized_payload.get("self_id") or "").strip():
normalized_payload["self_id"] = self_id
return await self.handle_inbound_message(normalized_payload)
except asyncio.CancelledError:
raise
except Exception as exc:
external_message_id = str(payload.get("message_id") or "").strip() or "unknown"
self._logger.warning(f"NapCat 历史消息补拉注入失败: message_id={external_message_id} error={exc}")
return False
async def _record_inbound_checkpoint(
self,
*,
payload: NapCatPayloadDict,
self_id: str,
external_message_id: str,
scope: str,
) -> None:
"""在消息被 Host 接受后更新该会话的最新 checkpoint。"""
runtime = self._require_runtime()
chat_identity = self._extract_chat_identity(payload)
if chat_identity is None:
return
await runtime.history_recovery_store.record_checkpoint(
account_id=self_id,
scope=scope,
chat_type=chat_identity.chat_type,
chat_id=chat_identity.chat_id,
message_id=external_message_id,
message_time=self._extract_message_timestamp(payload),
message_seq=self._extract_message_seq(payload),
)
@staticmethod
def _extract_chat_identity(payload: Mapping[str, Any]) -> _NapCatChatIdentity | None:
"""从 NapCat 载荷中提取会话身份。"""
group_id = str(payload.get("group_id") or "").strip()
user_id = str(payload.get("user_id") or "").strip()
if group_id:
return _NapCatChatIdentity(chat_type="group", chat_id=group_id)
if user_id:
return _NapCatChatIdentity(chat_type="private", chat_id=user_id)
return None
@staticmethod
def _extract_message_seq(payload: Mapping[str, Any]) -> int | None:
"""从 NapCat 载荷中提取历史接口可复用的消息序号。"""
for field_name in ("message_seq", "messageSeq", "msg_seq"):
raw_value = payload.get(field_name)
if raw_value is None or str(raw_value).strip() == "":
continue
try:
return int(raw_value)
except (TypeError, ValueError):
continue
return None
@staticmethod
def _extract_message_timestamp(payload: Mapping[str, Any]) -> float:
"""从 NapCat 载荷中提取消息时间戳。"""
raw_timestamp = payload.get("time")
if isinstance(raw_timestamp, (int, float)):
return float(raw_timestamp)
return 0.0
@classmethod
def _is_message_after_checkpoint(
cls,
payload: Mapping[str, Any],
checkpoint: NapCatChatCheckpoint,
) -> bool:
"""判断历史消息是否位于 checkpoint 之后。"""
payload_message_id = str(payload.get("message_id") or "").strip()
if payload_message_id == checkpoint.last_message_id:
return False
payload_message_seq = cls._extract_message_seq(payload)
if payload_message_seq is not None and checkpoint.last_message_seq is not None:
return payload_message_seq > checkpoint.last_message_seq
payload_timestamp = cls._extract_message_timestamp(payload)
if payload_timestamp != checkpoint.last_message_time:
return payload_timestamp > checkpoint.last_message_time
return True
@staticmethod
def _build_route_metadata(self_id: str, connection_id: str) -> Dict[str, Any]:
"""构造注入 Host 时使用的路由元数据。

View File

@@ -3,6 +3,7 @@
from .action_service import NapCatActionService
from .ban_tracker import NapCatBanTracker
from .ban_state_store import NapCatBanRecord, NapCatBanStateStore
from .history_recovery_store import NapCatChatCheckpoint, NapCatHistoryRecoveryStore
from .official_bot_guard import NapCatOfficialBotGuard
from .query_service import NapCatQueryService
@@ -11,6 +12,8 @@ __all__ = [
"NapCatBanRecord",
"NapCatBanStateStore",
"NapCatBanTracker",
"NapCatChatCheckpoint",
"NapCatHistoryRecoveryStore",
"NapCatOfficialBotGuard",
"NapCatQueryService",
]

View File

@@ -0,0 +1,423 @@
"""NapCat 历史补拉状态持久化仓库。"""
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, List, Optional, TypeVar
import asyncio
import sqlite3
import time
from ..constants import DEFAULT_HISTORY_RECOVERY_SEEN_TTL_SEC
_PROJECT_ROOT = Path(__file__).resolve().parents[3]
_DEFAULT_STORAGE_PATH = _PROJECT_ROOT / "data" / "napcat_adapter" / "history_recovery.sqlite3"
_SCHEMA_STATEMENTS = (
"""
CREATE TABLE IF NOT EXISTS napcat_chat_checkpoint (
account_id TEXT NOT NULL,
scope TEXT NOT NULL,
chat_type TEXT NOT NULL,
chat_id TEXT NOT NULL,
last_message_id TEXT NOT NULL,
last_message_time REAL NOT NULL,
last_message_seq INTEGER,
updated_at REAL NOT NULL,
PRIMARY KEY (account_id, scope, chat_type, chat_id)
)
""",
"""
CREATE INDEX IF NOT EXISTS ix_napcat_chat_checkpoint_updated_at
ON napcat_chat_checkpoint (updated_at DESC)
""",
"""
CREATE TABLE IF NOT EXISTS napcat_recovery_seen (
account_id TEXT NOT NULL,
scope TEXT NOT NULL,
chat_type TEXT NOT NULL,
chat_id TEXT NOT NULL,
external_message_id TEXT NOT NULL,
seen_at REAL NOT NULL,
PRIMARY KEY (account_id, scope, chat_type, chat_id, external_message_id)
)
""",
"""
CREATE INDEX IF NOT EXISTS ix_napcat_recovery_seen_seen_at
ON napcat_recovery_seen (seen_at DESC)
""",
)
T = TypeVar("T")
@dataclass(frozen=True)
class NapCatChatCheckpoint:
"""描述一个会话的最近入站锚点。"""
account_id: str
scope: str
chat_type: str
chat_id: str
last_message_id: str
last_message_time: float
last_message_seq: int | None
updated_at: float
@classmethod
def from_row(cls, row: sqlite3.Row) -> "NapCatChatCheckpoint":
"""从 SQLite 行对象恢复 checkpoint。"""
last_message_seq = row["last_message_seq"]
normalized_seq = int(last_message_seq) if isinstance(last_message_seq, int) else None
return cls(
account_id=str(row["account_id"] or "").strip(),
scope=str(row["scope"] or "").strip(),
chat_type=str(row["chat_type"] or "").strip(),
chat_id=str(row["chat_id"] or "").strip(),
last_message_id=str(row["last_message_id"] or "").strip(),
last_message_time=float(row["last_message_time"] or 0.0),
last_message_seq=normalized_seq,
updated_at=float(row["updated_at"] or 0.0),
)
class NapCatHistoryRecoveryStore:
"""负责持久化历史补拉所需的会话状态与去重状态。"""
def __init__(self, logger: Any, storage_path: Path = _DEFAULT_STORAGE_PATH) -> None:
"""初始化历史补拉状态仓库。"""
self._logger = logger
self._storage_path = storage_path
self._store_lock = asyncio.Lock()
self._schema_ready = False
async def load(self) -> None:
"""初始化 SQLite 文件并清理过期去重记录。"""
await self._execute_locked(self._ensure_schema)
pruned_count = await self.prune_recovery_seen(DEFAULT_HISTORY_RECOVERY_SEEN_TTL_SEC)
if pruned_count > 0:
self._logger.debug(f"NapCat 历史补拉去重表已清理 {pruned_count} 条过期记录")
async def list_checkpoints(self, account_id: str, scope: str = "", limit: int = 50) -> List[NapCatChatCheckpoint]:
"""列出指定账号与作用域下的最近会话 checkpoint。"""
normalized_account_id = str(account_id or "").strip()
if not normalized_account_id:
return []
normalized_scope = self._normalize_scope(scope)
normalized_limit = max(1, int(limit))
def _operation(conn: sqlite3.Connection) -> List[NapCatChatCheckpoint]:
cursor = conn.execute(
"""
SELECT
account_id,
scope,
chat_type,
chat_id,
last_message_id,
last_message_time,
last_message_seq,
updated_at
FROM napcat_chat_checkpoint
WHERE account_id = ? AND scope = ?
ORDER BY updated_at DESC
LIMIT ?
""",
(normalized_account_id, normalized_scope, normalized_limit),
)
return [NapCatChatCheckpoint.from_row(row) for row in cursor.fetchall()]
return await self._execute_locked(_operation)
async def record_checkpoint(
self,
*,
account_id: str,
scope: str = "",
chat_type: str,
chat_id: str,
message_id: str,
message_time: float,
message_seq: int | None = None,
) -> None:
"""记录一条已被 Host 接受的最新入站消息锚点。"""
normalized_account_id = str(account_id or "").strip()
normalized_scope = self._normalize_scope(scope)
normalized_chat_type = str(chat_type or "").strip()
normalized_chat_id = str(chat_id or "").strip()
normalized_message_id = str(message_id or "").strip()
if not (
normalized_account_id
and normalized_chat_type
and normalized_chat_id
and normalized_message_id
):
return
normalized_message_time = float(message_time or 0.0)
normalized_message_seq = self._normalize_message_seq(message_seq)
updated_at = time.time()
def _operation(conn: sqlite3.Connection) -> None:
cursor = conn.execute(
"""
SELECT last_message_id, last_message_time, last_message_seq
FROM napcat_chat_checkpoint
WHERE account_id = ? AND scope = ? AND chat_type = ? AND chat_id = ?
""",
(
normalized_account_id,
normalized_scope,
normalized_chat_type,
normalized_chat_id,
),
)
existing_row = cursor.fetchone()
if existing_row is not None and not self._should_advance_checkpoint(
existing_row=existing_row,
message_id=normalized_message_id,
message_time=normalized_message_time,
message_seq=normalized_message_seq,
):
return
conn.execute(
"""
INSERT INTO napcat_chat_checkpoint (
account_id,
scope,
chat_type,
chat_id,
last_message_id,
last_message_time,
last_message_seq,
updated_at
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(account_id, scope, chat_type, chat_id) DO UPDATE SET
last_message_id = excluded.last_message_id,
last_message_time = excluded.last_message_time,
last_message_seq = excluded.last_message_seq,
updated_at = excluded.updated_at
""",
(
normalized_account_id,
normalized_scope,
normalized_chat_type,
normalized_chat_id,
normalized_message_id,
normalized_message_time,
normalized_message_seq,
updated_at,
),
)
await self._execute_locked(_operation)
async def has_recovered_message_seen(
self,
*,
account_id: str,
scope: str = "",
chat_type: str,
chat_id: str,
external_message_id: str,
) -> bool:
"""判断某条历史补拉消息是否已经被当前仓库记录过。"""
normalized_account_id = str(account_id or "").strip()
normalized_scope = self._normalize_scope(scope)
normalized_chat_type = str(chat_type or "").strip()
normalized_chat_id = str(chat_id or "").strip()
normalized_external_message_id = str(external_message_id or "").strip()
if not (
normalized_account_id
and normalized_chat_type
and normalized_chat_id
and normalized_external_message_id
):
return False
def _operation(conn: sqlite3.Connection) -> bool:
cursor = conn.execute(
"""
SELECT 1
FROM napcat_recovery_seen
WHERE account_id = ?
AND scope = ?
AND chat_type = ?
AND chat_id = ?
AND external_message_id = ?
LIMIT 1
""",
(
normalized_account_id,
normalized_scope,
normalized_chat_type,
normalized_chat_id,
normalized_external_message_id,
),
)
return cursor.fetchone() is not None
return await self._execute_locked(_operation)
async def mark_recovered_message_seen(
self,
*,
account_id: str,
scope: str = "",
chat_type: str,
chat_id: str,
external_message_id: str,
) -> None:
"""将一条历史补拉消息标记为已尝试处理。"""
normalized_account_id = str(account_id or "").strip()
normalized_scope = self._normalize_scope(scope)
normalized_chat_type = str(chat_type or "").strip()
normalized_chat_id = str(chat_id or "").strip()
normalized_external_message_id = str(external_message_id or "").strip()
if not (
normalized_account_id
and normalized_chat_type
and normalized_chat_id
and normalized_external_message_id
):
return
def _operation(conn: sqlite3.Connection) -> None:
conn.execute(
"""
INSERT OR REPLACE INTO napcat_recovery_seen (
account_id,
scope,
chat_type,
chat_id,
external_message_id,
seen_at
) VALUES (?, ?, ?, ?, ?, ?)
""",
(
normalized_account_id,
normalized_scope,
normalized_chat_type,
normalized_chat_id,
normalized_external_message_id,
time.time(),
),
)
await self._execute_locked(_operation)
async def prune_recovery_seen(self, ttl_seconds: float) -> int:
"""删除超过保留期的历史补拉去重记录。"""
normalized_ttl_seconds = max(0.0, float(ttl_seconds or 0.0))
if normalized_ttl_seconds <= 0.0:
return 0
cutoff_timestamp = time.time() - normalized_ttl_seconds
def _operation(conn: sqlite3.Connection) -> int:
cursor = conn.execute(
"DELETE FROM napcat_recovery_seen WHERE seen_at < ?",
(cutoff_timestamp,),
)
return int(cursor.rowcount or 0)
return await self._execute_locked(_operation)
async def _execute_locked(self, operation: Callable[[sqlite3.Connection], T]) -> T:
"""在锁保护下打开 SQLite 并执行一次原子操作。"""
async with self._store_lock:
conn = self._open_connection()
try:
self._ensure_schema(conn)
result = operation(conn)
conn.commit()
return result
except Exception:
conn.rollback()
raise
finally:
conn.close()
def _open_connection(self) -> sqlite3.Connection:
"""打开一个带 Row 工厂的 SQLite 连接。"""
self._storage_path.parent.mkdir(parents=True, exist_ok=True)
conn = sqlite3.connect(str(self._storage_path))
conn.row_factory = sqlite3.Row
return conn
def _ensure_schema(self, conn: sqlite3.Connection) -> None:
"""确保 SQLite 表结构已经准备完成。"""
if self._schema_ready:
return
for statement in _SCHEMA_STATEMENTS:
conn.execute(statement)
self._schema_ready = True
@staticmethod
def _normalize_scope(scope: str | None) -> str:
"""将空作用域统一折叠为空字符串。"""
return str(scope or "").strip()
@staticmethod
def _normalize_message_seq(message_seq: object) -> int | None:
"""将消息序号规范化为可选整数。"""
try:
if message_seq is None or str(message_seq).strip() == "":
return None
return int(message_seq)
except (TypeError, ValueError):
return None
@classmethod
def _should_advance_checkpoint(
cls,
*,
existing_row: sqlite3.Row,
message_id: str,
message_time: float,
message_seq: int | None,
) -> bool:
"""判断新的锚点是否应覆盖旧锚点。"""
existing_message_id = str(existing_row["last_message_id"] or "").strip()
existing_message_time = float(existing_row["last_message_time"] or 0.0)
existing_message_seq = cls._normalize_message_seq(existing_row["last_message_seq"])
if message_seq is not None and existing_message_seq is not None:
if message_seq != existing_message_seq:
return message_seq > existing_message_seq
if message_id == existing_message_id:
return False
return message_time >= existing_message_time
if message_time != existing_message_time:
return message_time > existing_message_time
if message_id == existing_message_id:
return False
if message_seq is not None and existing_message_seq is None:
return True
return True

View File

@@ -180,6 +180,46 @@ class NapCatQueryService:
response_data = await self._safe_call_action_data("get_msg", {"message_id": message_id})
return response_data if isinstance(response_data, dict) else None
async def get_friend_message_history(
self,
user_id: str,
*,
message_seq: int | None = None,
count: int = 20,
reverse_order: bool = False,
) -> Optional[NapCatPayloadList]:
"""获取私聊历史消息列表。"""
params: NapCatActionResponse = {
"user_id": user_id,
"count": max(1, int(count)),
"reverse_order": bool(reverse_order),
}
if message_seq is not None:
params["message_seq"] = int(message_seq)
response_data = await self._safe_call_action_data("get_friend_msg_history", params)
return self._normalize_payload_list(response_data, action_name="get_friend_msg_history")
async def get_group_message_history(
self,
group_id: str,
*,
message_seq: int | None = None,
count: int = 20,
reverse_order: bool = False,
) -> Optional[NapCatPayloadList]:
"""获取群聊历史消息列表。"""
params: NapCatActionResponse = {
"group_id": group_id,
"count": max(1, int(count)),
"reverse_order": bool(reverse_order),
}
if message_seq is not None:
params["message_seq"] = int(message_seq)
response_data = await self._safe_call_action_data("get_group_msg_history", params)
return self._normalize_payload_list(response_data, action_name="get_group_msg_history")
async def get_forward_message(
self,
message_id: Optional[str] = None,

View File

@@ -12,8 +12,10 @@ import pytest
PROJECT_ROOT = Path(__file__).resolve().parents[1]
PLUGINS_ROOT = PROJECT_ROOT / "plugins"
PLUGIN_TEMPLATE_ROOT = PROJECT_ROOT / "plugin-templates"
SDK_ROOT = PROJECT_ROOT / "packages" / "maibot-plugin-sdk"
NAPCAT_PLUGIN_DIR = PLUGINS_ROOT / "MaiBot-Napcat-Adapter"
NAPCAT_TEMPLATE_DIR = PLUGIN_TEMPLATE_ROOT / "MaiBot-Napcat-Adapter"
NAPCAT_TEST_MODULE = "_test_napcat_adapter"
for import_path in (str(SDK_ROOT),):
@@ -204,12 +206,14 @@ def _load_napcat_sdk_modules() -> Tuple[Any, Any, Any, Any]:
依次返回常量模块、配置模块、插件模块和运行时状态模块。
"""
plugin_dir = NAPCAT_PLUGIN_DIR if NAPCAT_PLUGIN_DIR.is_dir() else NAPCAT_TEMPLATE_DIR
if NAPCAT_TEST_MODULE not in sys.modules:
plugin_path = NAPCAT_PLUGIN_DIR / "plugin.py"
plugin_path = plugin_dir / "plugin.py"
spec = util.spec_from_file_location(
NAPCAT_TEST_MODULE,
plugin_path,
submodule_search_locations=[str(NAPCAT_PLUGIN_DIR)],
submodule_search_locations=[str(plugin_dir)],
)
if spec is None or spec.loader is None:
raise ImportError(f"无法为 NapCat 插件创建模块规格: {plugin_path}")

View File

@@ -0,0 +1,376 @@
"""NapCat 历史补拉与恢复状态测试。"""
from __future__ import annotations
from importlib import import_module, util
from pathlib import Path
from typing import Any, Dict, List
import logging
import sys
from types import SimpleNamespace
import pytest
PROJECT_ROOT = Path(__file__).resolve().parents[1]
PLUGINS_ROOT = PROJECT_ROOT / "plugins"
PLUGIN_TEMPLATE_ROOT = PROJECT_ROOT / "plugin-templates"
SDK_ROOT = PROJECT_ROOT / "packages" / "maibot-plugin-sdk"
NAPCAT_PLUGIN_DIR = PLUGINS_ROOT / "MaiBot-Napcat-Adapter"
NAPCAT_TEMPLATE_DIR = PLUGIN_TEMPLATE_ROOT / "MaiBot-Napcat-Adapter"
NAPCAT_TEST_MODULE = "_test_napcat_adapter_history_recovery"
for import_path in (str(SDK_ROOT),):
if import_path not in sys.path:
sys.path.insert(0, import_path)
class _FakeGatewayCapability:
"""用于测试入站注入的网关替身。"""
def __init__(self) -> None:
"""初始化测试替身。"""
self.calls: List[Dict[str, Any]] = []
async def route_message(
self,
gateway_name: str,
message: Dict[str, Any],
*,
route_metadata: Dict[str, Any] | None = None,
external_message_id: str = "",
dedupe_key: str = "",
) -> bool:
"""记录入站注入请求并始终模拟成功。"""
self.calls.append(
{
"gateway_name": gateway_name,
"message": dict(message),
"route_metadata": dict(route_metadata or {}),
"external_message_id": external_message_id,
"dedupe_key": dedupe_key,
}
)
return True
def _resolve_napcat_plugin_dir() -> Path:
"""返回当前测试可用的 NapCat 插件目录。"""
if NAPCAT_PLUGIN_DIR.is_dir():
return NAPCAT_PLUGIN_DIR
return NAPCAT_TEMPLATE_DIR
def _load_napcat_module(module_suffix: str) -> Any:
"""动态加载 NapCat 测试模块。"""
plugin_dir = _resolve_napcat_plugin_dir()
if NAPCAT_TEST_MODULE not in sys.modules:
plugin_path = plugin_dir / "plugin.py"
spec = util.spec_from_file_location(
NAPCAT_TEST_MODULE,
plugin_path,
submodule_search_locations=[str(plugin_dir)],
)
if spec is None or spec.loader is None:
raise ImportError(f"无法为 NapCat 插件创建模块规格: {plugin_path}")
module = util.module_from_spec(spec)
sys.modules[NAPCAT_TEST_MODULE] = module
try:
spec.loader.exec_module(module)
except Exception:
sys.modules.pop(NAPCAT_TEST_MODULE, None)
raise
return import_module(f"{NAPCAT_TEST_MODULE}.{module_suffix}")
def _load_history_recovery_store_cls() -> Any:
"""动态加载历史恢复状态仓库类。"""
return _load_napcat_module("services.history_recovery_store").NapCatHistoryRecoveryStore
def _load_query_service_cls() -> Any:
"""动态加载查询服务类。"""
return _load_napcat_module("services.query_service").NapCatQueryService
def _load_router_cls() -> Any:
"""动态加载事件路由器类。"""
return _load_napcat_module("runtime.router").NapCatEventRouter
class _FakeActionService:
"""用于查询服务的动作服务替身。"""
def __init__(self, response_data: Any) -> None:
"""初始化动作服务替身。"""
self._response_data = response_data
self.action_data_calls: List[Dict[str, Any]] = []
async def safe_call_action_data(self, action_name: str, params: Dict[str, Any]) -> Any:
"""记录安全查询动作。"""
self.action_data_calls.append({"action_name": action_name, "params": dict(params)})
return self._response_data
@pytest.mark.asyncio
async def test_history_recovery_store_persists_checkpoint_and_seen_state(tmp_path: Path) -> None:
"""历史恢复状态仓库应持久化 checkpoint 与已补拉标记。"""
store_cls = _load_history_recovery_store_cls()
store = store_cls(
logger=logging.getLogger("test.napcat.history_store"),
storage_path=tmp_path / "history.sqlite3",
)
await store.load()
await store.record_checkpoint(
account_id="10001",
scope="primary",
chat_type="group",
chat_id="20001",
message_id="msg-2",
message_time=200.0,
message_seq=2,
)
await store.record_checkpoint(
account_id="10001",
scope="primary",
chat_type="group",
chat_id="20001",
message_id="msg-1",
message_time=100.0,
message_seq=1,
)
await store.mark_recovered_message_seen(
account_id="10001",
scope="primary",
chat_type="group",
chat_id="20001",
external_message_id="history-1",
)
checkpoints = await store.list_checkpoints("10001", scope="primary")
assert len(checkpoints) == 1
assert checkpoints[0].last_message_id == "msg-2"
assert checkpoints[0].last_message_seq == 2
assert (
await store.has_recovered_message_seen(
account_id="10001",
scope="primary",
chat_type="group",
chat_id="20001",
external_message_id="history-1",
)
is True
)
@pytest.mark.asyncio
async def test_query_service_wraps_group_and_friend_history_actions() -> None:
"""查询服务应按官方动作名封装历史消息接口。"""
query_service_cls = _load_query_service_cls()
action_service = _FakeActionService([{"message_id": "msg-1"}])
query_service = query_service_cls(
action_service=action_service,
logger=logging.getLogger("test.napcat.history_query"),
)
group_payload = await query_service.get_group_message_history("20001", message_seq=123, count=10)
friend_payload = await query_service.get_friend_message_history("30001", count=5, reverse_order=True)
assert group_payload == [{"message_id": "msg-1"}]
assert friend_payload == [{"message_id": "msg-1"}]
assert action_service.action_data_calls == [
{
"action_name": "get_group_msg_history",
"params": {"group_id": "20001", "count": 10, "reverse_order": False, "message_seq": 123},
},
{
"action_name": "get_friend_msg_history",
"params": {"user_id": "30001", "count": 5, "reverse_order": True},
},
]
@pytest.mark.asyncio
async def test_router_recover_recent_history_reinjects_messages_in_order(tmp_path: Path) -> None:
"""重连补拉应按时间顺序将历史消息重新注入原入站路径。"""
history_store_cls = _load_history_recovery_store_cls()
router_cls = _load_router_cls()
gateway_capability = _FakeGatewayCapability()
router = router_cls(
gateway_capability=gateway_capability,
logger=logging.getLogger("test.napcat.history_router"),
gateway_name="napcat_gateway",
load_settings=lambda: SimpleNamespace(
napcat_server=SimpleNamespace(connection_id="primary", heartbeat_interval=30.0),
filters=SimpleNamespace(ignore_self_message=True),
chat=SimpleNamespace(ban_qq_bot=False),
),
)
history_store = history_store_cls(
logger=logging.getLogger("test.napcat.history_router.store"),
storage_path=tmp_path / "router.sqlite3",
)
await history_store.load()
await history_store.record_checkpoint(
account_id="10001",
scope="primary",
chat_type="group",
chat_id="20001",
message_id="msg-1",
message_time=100.0,
message_seq=10,
)
history_calls: List[Dict[str, Any]] = []
history_payloads = [
{
"post_type": "message",
"message_type": "group",
"self_id": "10001",
"group_id": "20001",
"user_id": "30002",
"message_id": "msg-3",
"message_seq": 12,
"time": 102,
"message": [{"type": "text", "data": {"text": "第三条"}}],
"sender": {"user_id": "30002", "nickname": "用户二"},
},
{
"post_type": "message",
"message_type": "group",
"self_id": "10001",
"group_id": "20001",
"user_id": "30001",
"message_id": "msg-2",
"message_seq": 11,
"time": 101,
"message": [{"type": "text", "data": {"text": "第二条"}}],
"sender": {"user_id": "30001", "nickname": "用户一"},
},
]
class _FakeQueryService:
async def get_group_message_history(
self,
group_id: str,
*,
message_seq: int | None = None,
count: int = 20,
reverse_order: bool = False,
) -> List[Dict[str, Any]]:
history_calls.append(
{
"group_id": group_id,
"message_seq": message_seq,
"count": count,
"reverse_order": reverse_order,
}
)
return list(history_payloads)
async def get_friend_message_history(self, user_id: str, **kwargs: Any) -> List[Dict[str, Any]]:
del user_id
del kwargs
return []
class _FakeInboundCodec:
@staticmethod
async def build_message_dict(
payload: Dict[str, Any],
self_id: str,
sender_user_id: str,
sender: Dict[str, Any],
) -> Dict[str, Any]:
del self_id
del sender_user_id
del sender
return {
"message_id": str(payload["message_id"]),
"platform": "qq",
"timestamp": str(float(payload["time"])),
"message_info": {
"user_info": {"user_id": str(payload["user_id"]), "user_nickname": "测试用户"},
"group_info": {"group_id": str(payload["group_id"]), "group_name": "测试群"},
"additional_config": {},
},
"raw_message": [{"type": "text", "data": str(payload["message"][0]["data"]["text"])}],
"processed_plain_text": str(payload["message"][0]["data"]["text"]),
"display_message": str(payload["message"][0]["data"]["text"]),
"is_mentioned": False,
"is_at": False,
"is_emoji": False,
"is_picture": False,
"is_command": False,
"is_notify": False,
"session_id": "",
}
router.bind_runtime(
SimpleNamespace(
runtime_state=SimpleNamespace(report_connected=lambda *args, **kwargs: _noop_async(), report_disconnected=_noop_async),
chat_filter=SimpleNamespace(is_inbound_chat_allowed=lambda *args, **kwargs: True),
official_bot_guard=SimpleNamespace(
should_reject=lambda *args, **kwargs: _return_false_async(),
clear_cache=lambda: None,
),
inbound_codec=_FakeInboundCodec(),
history_recovery_store=history_store,
query_service=_FakeQueryService(),
heartbeat_monitor=SimpleNamespace(start=_noop_async, stop=_noop_async),
ban_tracker=SimpleNamespace(start=_noop_async, stop=_noop_async, record_notice=_noop_async),
notice_codec=SimpleNamespace(handle_meta_event=_noop_async, build_notice_message_dict=_return_none_async),
)
)
await router._recover_recent_history(self_id="10001", scope="primary")
assert history_calls == [
{
"group_id": "20001",
"message_seq": 10,
"count": 20,
"reverse_order": False,
}
]
assert [call["external_message_id"] for call in gateway_capability.calls] == ["msg-2", "msg-3"]
assert [call["message"]["message_id"] for call in gateway_capability.calls] == ["msg-2", "msg-3"]
async def _noop_async(*args: Any, **kwargs: Any) -> None:
"""无操作异步函数。"""
del args
del kwargs
async def _return_false_async(*args: Any, **kwargs: Any) -> bool:
"""返回 ``False`` 的异步测试替身。"""
del args
del kwargs
return False
async def _return_none_async(*args: Any, **kwargs: Any) -> None:
"""返回 ``None`` 的异步测试替身。"""
del args
del kwargs
return None